<?xml version="1.0" encoding="UTF-8"?>
<rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom" xmlns:content="http://purl.org/rss/1.0/modules/content/">
  <channel>
    <title>Krishnatheja Vanka</title>
    <link>https://theja-vanka.github.io/blogs/</link>
    <description>Applied Scientist and Machine Learning Engineer writing about ML research, model deployment, and production systems.</description>
    <language>en-us</language>
    <atom:link href="https://theja-vanka.github.io/blogs/feed.xml" rel="self" type="application/rss+xml"/>
    <lastBuildDate>Mon, 15 Jun 2026 01:38:26 GMT</lastBuildDate>
    <item>
      <title><![CDATA[AlexNet: A Comprehensive Guide]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/foundation/alexnet/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/foundation/alexnet/</guid>
      <pubDate>Mon, 15 Jun 2026 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>tutorial</category>
      <category>beginner</category>
      <category>foundation model</category>
      <content:encoded><![CDATA[






<section id="alexnet-a-comprehensive-guide" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/foundation/alexnet/alex.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>AlexNet is a deep convolutional neural network (CNN) that fundamentally changed the landscape of computer vision and machine learning when it was introduced in 2012. Designed by Alex Krizhevsky, Ilya Sutskever, and Geoffrey Hinton at the University of Toronto, the network achieved a top-5 error rate of 15.3% on the ImageNet Large Scale Visual Recognition Challenge (ILSVRC-2012), compared to 26.2% achieved by the second-best entry. This ~11 percentage point gap was unprecedented and sent shockwaves through the research community.</p>
<p>Before AlexNet, classical computer vision techniques — hand-crafted feature extractors like HOG (Histogram of Oriented Gradients), SIFT (Scale-Invariant Feature Transform), and SURF — dominated competitive benchmarks. These methods required deep domain expertise and painstaking engineering. AlexNet demonstrated that deep learning, trained end-to-end on raw pixel data, could outperform these approaches by a substantial margin.</p>
<p>The paper describing AlexNet — <em>“ImageNet Classification with Deep Convolutional Neural Networks”</em> <span class="citation" data-cites="krizhevsky2012imagenet">[@krizhevsky2012imagenet]</span> — became one of the most cited papers in the history of computer science, and is widely regarded as the catalyst for the modern deep learning era.</p>
<hr>
</section>
<section id="historical-context" class="level2">
<h2 class="anchored" data-anchor-id="historical-context" id="historical-context">Historical Context</h2>
<section id="the-imagenet-dataset" class="level3">
<h3 class="anchored" data-anchor-id="the-imagenet-dataset" id="the-imagenet-dataset">The ImageNet Dataset</h3>
<p>To appreciate AlexNet’s significance, we must first understand the challenge it was designed to tackle. ImageNet is a massive visual database organized according to the WordNet hierarchy. For the ILSVRC competition, the dataset contained:</p>
<ul>
<li>~1.2 million training images</li>
<li>50,000 validation images</li>
<li>150,000 test images</li>
<li>1,000 object categories (classes)</li>
</ul>
<p>This scale was unprecedented at the time. Prior CNN architectures (like LeNet-5, introduced in 1998 for digit recognition) were trained on small datasets with grayscale images of a single domain. The sheer diversity and volume of ImageNet posed a completely different engineering and statistical challenge.</p>
</section>
<section id="the-state-of-deep-learning-before-2012" class="level3">
<h3 class="anchored" data-anchor-id="the-state-of-deep-learning-before-2012" id="the-state-of-deep-learning-before-2012">The State of Deep Learning Before 2012</h3>
<p>Neural networks had fallen somewhat out of favor in the 2000s. Despite theoretical appeal, they were difficult to train at scale due to:</p>
<ul>
<li><strong>Vanishing gradients</strong>: Deep networks were notoriously hard to train because gradients diminished as they backpropagated through many layers.</li>
<li><strong>Computational constraints</strong>: Training large networks on CPUs was prohibitively slow.</li>
<li><strong>Overfitting</strong>: With millions of parameters and limited regularization techniques, large models quickly overfit to small datasets.</li>
</ul>
<p>Researchers like Yann LeCun had demonstrated the power of CNNs for constrained domains (handwritten digits) <span class="citation" data-cites="lecun1998gradient">[@lecun1998gradient]</span>, but scaling to general object recognition remained elusive. Geoffrey Hinton’s group had been steadily working on deep network training through the 2000s (deep belief networks, restricted Boltzmann machines), laying the groundwork for what was to come.</p>
</section>
<section id="the-gpu-revolution" class="level3">
<h3 class="anchored" data-anchor-id="the-gpu-revolution" id="the-gpu-revolution">The GPU Revolution</h3>
<p>The critical enabling factor for AlexNet was the availability of fast, programmable GPUs — specifically NVIDIA’s CUDA platform (introduced in 2006–2007), which allowed general-purpose computation on graphics cards. By 2012, a pair of NVIDIA GTX 580 GPUs with 3 GB of VRAM each gave the Toronto team enough raw computational power to train a massive network in a tractable amount of time (about 5–6 days). This hardware innovation made AlexNet possible.</p>
<hr>
</section>
</section>
<section id="architecture-overview" class="level2">
<h2 class="anchored" data-anchor-id="architecture-overview" id="architecture-overview">Architecture Overview</h2>
<p>AlexNet is a deep convolutional neural network with <strong>8 learned layers</strong>: 5 convolutional layers and 3 fully connected layers. The paper states a fixed-size input of 224×224 RGB images, but the layer arithmetic only works out for 227×227 inputs. This off-by-one in the original paper is a common source of confusion, and 227×227 is used throughout this guide. The network outputs a probability distribution over 1,000 classes via a softmax function.</p>
<p>Here is a high-level summary of the architecture:</p>
<div id="tbl-arch" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-tbl figure">
<figcaption class="quarto-float-caption-top quarto-float-caption quarto-float-tbl" id="tbl-arch-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Table&nbsp;1: AlexNet architecture summary
</figcaption>
<div aria-describedby="tbl-arch-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<table class="caption-top table">
<colgroup>
<col style="width: 8%">
<col style="width: 31%">
<col style="width: 15%">
<col style="width: 44%">
</colgroup>
<thead>
<tr class="header">
<th>Layer</th>
<th>Type</th>
<th>Output Size</th>
<th>Key Parameters</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>Input</td>
<td>—</td>
<td>227×227×3</td>
<td>—</td>
</tr>
<tr class="even">
<td>Conv1</td>
<td>Conv + ReLU + LRN + Pool</td>
<td>27×27×96</td>
<td>96 filters, 11×11, stride 4</td>
</tr>
<tr class="odd">
<td>Conv2</td>
<td>Conv + ReLU + LRN + Pool</td>
<td>13×13×256</td>
<td>256 filters, 5×5, stride 1, pad 2</td>
</tr>
<tr class="even">
<td>Conv3</td>
<td>Conv + ReLU</td>
<td>13×13×384</td>
<td>384 filters, 3×3, stride 1, pad 1</td>
</tr>
<tr class="odd">
<td>Conv4</td>
<td>Conv + ReLU</td>
<td>13×13×384</td>
<td>384 filters, 3×3, stride 1, pad 1</td>
</tr>
<tr class="even">
<td>Conv5</td>
<td>Conv + ReLU + Pool</td>
<td>6×6×256</td>
<td>256 filters, 3×3, stride 1, pad 1</td>
</tr>
<tr class="odd">
<td>FC6</td>
<td>FC + ReLU + Dropout</td>
<td>4096</td>
<td>4096 neurons</td>
</tr>
<tr class="even">
<td>FC7</td>
<td>FC + ReLU + Dropout</td>
<td>4096</td>
<td>4096 neurons</td>
</tr>
<tr class="odd">
<td>FC8</td>
<td>FC + Softmax</td>
<td>1000</td>
<td>1000 neurons</td>
</tr>
</tbody>
</table>
</div>
</figure>
</div>
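<p>With the two-GPU grouping (Conv2, Conv4, and Conv5 connect only to feature maps on their own GPU), the total parameter count is roughly <strong>61 million</strong>, which the paper rounds to 60 million. This was extraordinarily large for its time. The often-quoted figure of roughly 62 million (62,378,344) counts those three layers as ordinary, ungrouped convolutions.</p>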
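<p>The parameter totals can be reproduced with a few lines of Python. This is an illustrative sketch: the helper names <code>conv_params</code>, <code>fc_params</code>, and <code>alexnet_params</code> are hypothetical, and it assumes that grouping splits a layer’s input channels evenly between the two GPUs, so each kernel sees only <code>in_ch // groups</code> channels. Toggling <code>grouped</code> shows how much the two-GPU connectivity changes the total.</p>

```python
def conv_params(out_ch, k, in_ch, groups=1):
    # each output channel owns a k x k x (in_ch / groups) kernel plus one bias
    return out_ch * (k * k * (in_ch // groups) + 1)


def fc_params(n_in, n_out):
    # dense weight matrix plus one bias per output neuron
    return n_in * n_out + n_out


def alexnet_params(grouped=True):
    g = 2 if grouped else 1
    layers = {
        "conv1": conv_params(96, 11, 3),
        "conv2": conv_params(256, 5, 96, groups=g),
        "conv3": conv_params(384, 3, 256),            # connected across both GPUs
        "conv4": conv_params(384, 3, 384, groups=g),
        "conv5": conv_params(256, 3, 384, groups=g),
        "fc6": fc_params(6 * 6 * 256, 4096),
        "fc7": fc_params(4096, 4096),
        "fc8": fc_params(4096, 1000),
    }
    return layers, sum(layers.values())


if __name__ == "__main__":
    for grouped in (True, False):
        _, total = alexnet_params(grouped)
        print("grouped" if grouped else "ungrouped", f"{total:,}")
```

<p>Under these assumptions the grouped network has 60,965,224 parameters and the ungrouped variant 62,378,344; the three fully connected layers account for 58,631,144 of either total.</p>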
<hr>
</section>
<section id="layer-by-layer-breakdown" class="level2">
<h2 class="anchored" data-anchor-id="layer-by-layer-breakdown" id="layer-by-layer-breakdown">Layer-by-Layer Breakdown</h2>
<section id="input" class="level3">
<h3 class="anchored" data-anchor-id="input" id="input">Input</h3>
<ul>
<li><strong>Size</strong>: 227×227×3 (height × width × RGB channels)</li>
<li>Images are preprocessed by subtracting the per-channel mean computed over the training set. This zero-centers the data, which helps with training stability.</li>
<li>During training, 227×227 patches are randomly cropped from 256×256 images (data augmentation — discussed in detail in <a href="#sec-augmentation" class="quarto-xref">Section&nbsp;1.5.6</a>).</li>
</ul>
</section>
<section id="layer-1-convolutional-layer-conv1" class="level3">
<h3 class="anchored" data-anchor-id="layer-1-convolutional-layer-conv1" id="layer-1-convolutional-layer-conv1">Layer 1 — Convolutional Layer (Conv1)</h3>
<p><strong>Operation</strong>: Convolution → ReLU → Local Response Normalization → Max Pooling</p>
<ul>
<li><strong>Filters</strong>: 96 kernels of size <strong>11×11×3</strong>, applied with <strong>stride 4</strong></li>
<li><strong>Output before pooling</strong>: (227 - 11) / 4 + 1 = <strong>55×55×96</strong></li>
<li><strong>LRN</strong>: Applied across channels (described in <a href="#sec-lrn" class="quarto-xref">Section&nbsp;1.5.3</a>)</li>
<li><strong>Max Pooling</strong>: 3×3 kernel, stride 2 → output <strong>27×27×96</strong></li>
<li><strong>Parameters</strong>: 96 × (11×11×3 + 1 bias) = 96 × 364 = <strong>34,944</strong></li>
</ul>
<p>The large 11×11 kernels in the first layer capture low-level features such as edges, colors, and basic textures at multiple orientations. The aggressive stride of 4 dramatically reduces spatial dimensions early, keeping computation tractable. The 96 filters learn a diverse set of Gabor-like edge detectors and color blobs — visualizations of these learned filters were famously included in the original paper and became iconic images in the deep learning literature.</p>
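<p>The output sizes in the table follow from the usual formula ⌊(n + 2p − k) / s⌋ + 1 for a window of size k, stride s, and padding p. A minimal sketch (hypothetical helper names) that traces an input through the convolution and pooling stages:</p>

```python
def out_size(n, k, s, p=0):
    # floor((n + 2p - k) / s) + 1, valid for both convolution and pooling windows
    return (n + 2 * p - k) // s + 1


# (name, kernel, stride, padding, output channels); None keeps the channel count (pooling)
ALEXNET_FEATURES = [
    ("conv1", 11, 4, 0, 96),
    ("pool1", 3, 2, 0, None),
    ("conv2", 5, 1, 2, 256),
    ("pool2", 3, 2, 0, None),
    ("conv3", 3, 1, 1, 384),
    ("conv4", 3, 1, 1, 384),
    ("conv5", 3, 1, 1, 256),
    ("pool5", 3, 2, 0, None),
]


def trace_shapes(size=227, channels=3):
    shapes = []
    for name, k, s, p, c in ALEXNET_FEATURES:
        size = out_size(size, k, s, p)
        channels = c if c is not None else channels
        shapes.append((name, size, size, channels))
    return shapes


if __name__ == "__main__":
    for name, h, w, c in trace_shapes():
        print(f"{name}: {h}x{w}x{c}")
```

<p>Calling <code>trace_shapes(224)</code> ends at 5×5 rather than 6×6, because 224 − 11 is not divisible by 4. torchvision’s AlexNet keeps 224×224 inputs and instead pads Conv1 by 2 pixels; <code>out_size(224, 11, 4, 2)</code> confirms that this also yields 55.</p>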
</section>
<section id="layer-2-convolutional-layer-conv2" class="level3">
<h3 class="anchored" data-anchor-id="layer-2-convolutional-layer-conv2" id="layer-2-convolutional-layer-conv2">Layer 2 — Convolutional Layer (Conv2)</h3>
<p><strong>Operation</strong>: Convolution → ReLU → Local Response Normalization → Max Pooling</p>
<ul>
<li><strong>Filters</strong>: 256 kernels of size <strong>5×5×48</strong>, split 128 per GPU; each kernel sees only the 48 Conv1 feature maps on its own GPU</li>
<li><strong>Stride</strong>: 1, <strong>Padding</strong>: 2 (same padding)</li>
<li><strong>Output before pooling</strong>: <strong>27×27×256</strong></li>
<li><strong>LRN</strong>: Applied</li>
<li><strong>Max Pooling</strong>: 3×3 kernel, stride 2 → output <strong>13×13×256</strong></li>
<li><strong>Parameters</strong>: 256 × (5×5×48 + 1) = 256 × 1,201 = <strong>307,456</strong></li>
</ul>
<p>The smaller 5×5 kernels in Conv2 build on the edge detectors from Conv1, combining them into more complex texture and shape detectors. The large increase in filter count (from 96 to 256) allows the network to represent a richer vocabulary of intermediate features. This layer captures corners, curves, and simple textures.</p>
</section>
<section id="layer-3-convolutional-layer-conv3" class="level3">
<h3 class="anchored" data-anchor-id="layer-3-convolutional-layer-conv3" id="layer-3-convolutional-layer-conv3">Layer 3 — Convolutional Layer (Conv3)</h3>
<p><strong>Operation</strong>: Convolution → ReLU</p>
<ul>
<li><strong>Filters</strong>: 384 kernels of size <strong>3×3×256</strong></li>
<li><strong>Stride</strong>: 1, <strong>Padding</strong>: 1 (same padding)</li>
<li><strong>Output</strong>: <strong>13×13×384</strong></li>
<li><strong>No pooling, no LRN</strong></li>
<li><strong>Parameters</strong>: 384 × (3×3×256 + 1) = 384 × 2,305 = <strong>885,120</strong></li>
</ul>
<p>Conv3 is the first layer where both GPU streams interact — the full 256-channel input (from both halves of Conv2) feeds into all 384 filters. This cross-GPU communication was a deliberate design choice to allow the two GPU streams to mix learned representations. Conv3 captures higher-level textures and object parts.</p>
</section>
<section id="layer-4-convolutional-layer-conv4" class="level3">
<h3 class="anchored" data-anchor-id="layer-4-convolutional-layer-conv4" id="layer-4-convolutional-layer-conv4">Layer 4 — Convolutional Layer (Conv4)</h3>
<p><strong>Operation</strong>: Convolution → ReLU</p>
<ul>
<li><strong>Filters</strong>: 384 kernels of size <strong>3×3×192</strong>, split 192 per GPU; each kernel sees only the 192 Conv3 feature maps on its own GPU</li>
<li><strong>Stride</strong>: 1, <strong>Padding</strong>: 1</li>
<li><strong>Output</strong>: <strong>13×13×384</strong></li>
<li><strong>No pooling, no LRN</strong></li>
<li><strong>Parameters</strong>: 384 × (3×3×192 + 1) = 384 × 1,729 = <strong>663,936</strong></li>
</ul>
<p>Conv4 continues refining high-level feature representations. The two GPU streams remain separate in this layer (unlike Conv3). Neurons in this layer have receptive fields covering large portions of the original input, allowing them to detect object parts and their spatial relationships.</p>
</section>
<section id="layer-5-convolutional-layer-conv5" class="level3">
<h3 class="anchored" data-anchor-id="layer-5-convolutional-layer-conv5" id="layer-5-convolutional-layer-conv5">Layer 5 — Convolutional Layer (Conv5)</h3>
<p><strong>Operation</strong>: Convolution → ReLU → Max Pooling</p>
<ul>
<li><strong>Filters</strong>: 256 kernels of size <strong>3×3×192</strong>, split 128 per GPU; each kernel sees only the 192 Conv4 feature maps on its own GPU</li>
<li><strong>Stride</strong>: 1, <strong>Padding</strong>: 1</li>
<li><strong>Output before pooling</strong>: <strong>13×13×256</strong></li>
<li><strong>Max Pooling</strong>: 3×3 kernel, stride 2 → output <strong>6×6×256</strong></li>
<li><strong>Parameters</strong>: 256 × (3×3×192 + 1) = 256 × 1,729 = <strong>442,624</strong></li>
</ul>
<p>Conv5 is the final convolutional layer. After the max pooling, the spatial map is reduced to 6×6, and the output is flattened to a 6×6×256 = <strong>9,216-dimensional vector</strong> before entering the fully connected layers. By this stage, each neuron in the feature map has a receptive field spanning the majority of the original 227×227 image.</p>
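<p>The receptive-field claim can be checked with the standard recurrence: each layer with kernel size k grows the receptive field by (k − 1) times the current spacing (the “jump”) between adjacent units, then multiplies the jump by its stride. A short sketch that ignores padding and image borders:</p>

```python
# (name, kernel size, stride) for every layer that changes the receptive field
LAYERS = [
    ("conv1", 11, 4), ("pool1", 3, 2), ("conv2", 5, 1), ("pool2", 3, 2),
    ("conv3", 3, 1), ("conv4", 3, 1), ("conv5", 3, 1), ("pool5", 3, 2),
]


def receptive_fields(layers):
    rf, jump = 1, 1  # field size and spacing between adjacent units, in input pixels
    result = {}
    for name, k, s in layers:
        rf += (k - 1) * jump  # each extra kernel tap reaches one jump further
        jump *= s             # striding spreads the next layer's units further apart
        result[name] = rf
    return result


if __name__ == "__main__":
    for name, rf in receptive_fields(LAYERS).items():
        print(f"{name}: {rf}x{rf}")
```

<p>This gives 163×163 pixels after Conv5 and 195×195 after the final pooling, out of a 227×227 input; Conv4 units already see 131×131 pixels.</p>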
</section>
<section id="layers-68-fully-connected-layers" class="level3">
<h3 class="anchored" data-anchor-id="layers-68-fully-connected-layers" id="layers-68-fully-connected-layers">Layers 6–8 — Fully Connected Layers</h3>
<p><strong>FC6</strong>:</p>
<ul>
<li><strong>Neurons</strong>: 4,096</li>
<li><strong>Operation</strong>: Linear → ReLU → Dropout (p=0.5)</li>
<li><strong>Input</strong>: 9,216-dimensional vector</li>
<li><strong>Parameters</strong>: 9,216 × 4,096 + 4,096 = <strong>37,752,832</strong></li>
</ul>
<p><strong>FC7</strong>:</p>
<ul>
<li><strong>Neurons</strong>: 4,096</li>
<li><strong>Operation</strong>: Linear → ReLU → Dropout (p=0.5)</li>
<li><strong>Input</strong>: 4,096-dimensional vector</li>
<li><strong>Parameters</strong>: 4,096 × 4,096 + 4,096 = <strong>16,781,312</strong></li>
</ul>
<p><strong>FC8</strong>:</p>
<ul>
<li><strong>Neurons</strong>: 1,000</li>
<li><strong>Operation</strong>: Linear → Softmax</li>
<li><strong>Input</strong>: 4,096-dimensional vector</li>
<li><strong>Parameters</strong>: 4,096 × 1,000 + 1,000 = <strong>4,097,000</strong></li>
</ul>
<p>The fully connected layers serve as the “classifier head” on top of the convolutional feature extractor. FC6 and FC7 learn complex non-linear combinations of the convolutional features. The 4,096-dimensional activations of FC6/FC7 became widely used as general-purpose image feature vectors (a precursor to modern transfer learning). FC8 maps these features to the 1,000 class logits, which are then normalized by softmax to produce a probability distribution.</p>
</section>
<section id="output-layer" class="level3">
<h3 class="anchored" data-anchor-id="output-layer" id="output-layer">Output Layer</h3>
<ul>
<li><strong>Neurons</strong>: 1,000 (one per ImageNet class)</li>
<li><strong>Activation</strong>: <strong>Softmax</strong></li>
</ul>
<p>The softmax function converts raw logits <span class="math inline">\(z_i\)</span> into probabilities:</p>
<p><span id="eq-softmax"><span class="math display">\[
P(\text{class} = i) = \frac{e^{z_i}}{\sum_j e^{z_j}}
\tag{1}\]</span></span></p>
<p>The predicted class is the one with the highest probability. During training, the cross-entropy loss between these probabilities and the one-hot encoded ground truth label is minimized.</p>
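<p>Eq. (1) and the cross-entropy loss are usually computed together in a numerically stable way. The plain-Python sketch below subtracts the largest logit before exponentiating (which leaves the result unchanged) and uses the log-sum-exp form of the loss for a single example with integer label <code>target</code>; the function names are illustrative.</p>

```python
import math


def softmax(logits):
    # shifting by the max logit leaves Eq. (1) unchanged but prevents overflow in exp
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, target):
    # with a one-hot label, cross-entropy is -log P(class = target)
    # = log(sum_j exp(z_j)) - z_target, computed via a stable log-sum-exp
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum - logits[target]


if __name__ == "__main__":
    logits = [2.0, 1.0, 0.1]
    probs = softmax(logits)
    print(probs, probs.index(max(probs)), cross_entropy(logits, 0))
```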
<hr>
</section>
</section>
<section id="key-innovations" class="level2">
<h2 class="anchored" data-anchor-id="key-innovations" id="key-innovations">Key Innovations</h2>
<p>AlexNet did not invent any single technique from scratch, but it brought together a set of innovations — some novel, some previously known but underutilized — into a package that decisively solved a hard practical problem. Each innovation is described in detail below.</p>
<section id="sec-relu" class="level3">
<h3 class="anchored" data-anchor-id="sec-relu" id="sec-relu">ReLU Activation</h3>
<p><strong>The Problem with Saturating Activations</strong></p>
<p>Prior to AlexNet, the most commonly used activation functions in neural networks were the <strong>sigmoid</strong> function:</p>
<p><span id="eq-sigmoid"><span class="math display">\[
\sigma(x) = \frac{1}{1 + e^{-x}}
\tag{2}\]</span></span></p>
<p>and the <strong>hyperbolic tangent (tanh)</strong>:</p>
<p><span id="eq-tanh"><span class="math display">\[
\tanh(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}}
\tag{3}\]</span></span></p>
<p>Both of these are <em>saturating</em> functions: their gradients approach zero as their inputs grow large in magnitude, and even at its peak the sigmoid’s derivative is only 0.25. This contributes to the <strong>vanishing gradient problem</strong>: during backpropagation through many layers these small factors multiply together, so gradients can shrink exponentially and deep networks become very difficult to train. Neurons in earlier layers receive almost no gradient signal and learn slowly, if at all.</p>
<p><strong>The ReLU Solution</strong></p>
<p>The <strong>Rectified Linear Unit (ReLU)</strong> activation function is defined as:</p>
<p><span id="eq-relu"><span class="math display">\[
f(x) = \max(0, x)
\tag{4}\]</span></span></p>
<p>Its key properties:</p>
<ul>
<li><strong>Non-saturating on the positive side</strong>: For <span class="math inline">\(x &gt; 0\)</span>, the gradient is always 1, which flows back undiminished during backpropagation.</li>
<li><strong>Sparsity</strong>: For <span class="math inline">\(x \leq 0\)</span>, the output is exactly 0, effectively silencing that neuron. In practice, a substantial fraction of neurons (often around half, depending on initialization and input) are inactive for any given input, which introduces a useful form of sparsity.</li>
<li><strong>Computational efficiency</strong>: Computing <span class="math inline">\(\max(0, x)\)</span> is trivially fast — no exponentials required.</li>
<li><strong>Fast convergence</strong>: Krizhevsky et al.&nbsp;showed that a four-layer convolutional network with ReLUs reached a 25% training error rate on CIFAR-10 six times faster than an equivalent network with tanh units.</li>
</ul>
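<p>The gradient behaviour described above is easy to check numerically. This illustrative snippet compares the ReLU derivative (using the common convention that the derivative at 0 is 0) with the derivatives of tanh and sigmoid as the input grows:</p>

```python
import math


def relu(x):
    return max(0.0, x)


def relu_grad(x):
    # subgradient convention: 0 for x <= 0, 1 for x > 0
    return 1.0 if x > 0 else 0.0


def tanh_grad(x):
    return 1.0 - math.tanh(x) ** 2


def sigmoid_grad(x):
    s = 1.0 / (1.0 + math.exp(-x))
    return s * (1.0 - s)


if __name__ == "__main__":
    for x in (0.0, 2.0, 5.0):
        print(x, relu_grad(x), round(tanh_grad(x), 6), round(sigmoid_grad(x), 6))
```

<p>At x = 5 the tanh derivative is already below 0.001 and the sigmoid derivative below 0.01, while the ReLU derivative stays at 1.</p>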
<p><strong>Dead ReLU Problem</strong></p>
<p>A known drawback of ReLU is that neurons can “die” — if the inputs to a ReLU neuron are always negative, it will output 0 for every input and its gradient will always be 0, meaning it will never update. Proper weight initialization and careful learning rate selection mitigate this. Later variants like Leaky ReLU, PReLU, and ELU address this issue more directly.</p>
<p>Despite this limitation, ReLU was transformative and remains the default activation function in most modern deep learning architectures.</p>
</section>
<section id="sec-gpu" class="level3">
<h3 class="anchored" data-anchor-id="sec-gpu" id="sec-gpu">GPU Training</h3>
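<p>Training AlexNet on the ImageNet dataset took approximately <strong>5–6 days</strong> using two NVIDIA GTX 580 GPUs, each with 3 GB of VRAM. A single forward pass over one 227×227 image requires roughly 0.7 billion multiply-add operations (about 1.1 billion if the grouped layers are counted as full convolutions). Repeated together with the backward pass for each of ~1.2 million images over ~90 epochs, this workload was impractical on contemporary CPUs.</p>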
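<p>A rough way to estimate per-image compute is to count one multiply-add per kernel weight per output position. The sketch below is a back-of-the-envelope count that assumes 227×227 inputs and ignores biases, ReLU, LRN, and pooling; <code>conv_macs</code> and <code>alexnet_macs</code> are hypothetical helpers.</p>

```python
def conv_macs(out_hw, out_ch, k, in_ch, groups=1):
    # one multiply-add per kernel weight for every output position and channel
    return out_hw * out_hw * out_ch * k * k * (in_ch // groups)


def alexnet_macs(grouped=True):
    g = 2 if grouped else 1
    convs = [
        conv_macs(55, 96, 11, 3),        # conv1
        conv_macs(27, 256, 5, 96, g),    # conv2
        conv_macs(13, 384, 3, 256),      # conv3 (cross-GPU)
        conv_macs(13, 384, 3, 384, g),   # conv4
        conv_macs(13, 256, 3, 384, g),   # conv5
    ]
    fcs = 9216 * 4096 + 4096 * 4096 + 4096 * 1000  # fc6, fc7, fc8 weight multiplies
    return sum(convs) + fcs


if __name__ == "__main__":
    print(f"grouped:   {alexnet_macs(True):,}")
    print(f"ungrouped: {alexnet_macs(False):,}")
```

<p>Under these assumptions the grouped network needs about 724 million multiply-adds per forward pass and the ungrouped variant about 1.14 billion. The backward pass costs roughly twice as much again, and training repeats all of this for every image in every epoch.</p>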
<p>The authors used NVIDIA’s <strong>CUDA</strong> platform to implement highly optimized GPU kernels for convolution, pooling, and matrix multiplication. Because a single GPU at the time didn’t have enough memory to hold all the parameters and activations, the network was split across two GPUs (see <a href="#sec-dualgpu" class="quarto-xref">Section&nbsp;1.9</a> for details on the dual-GPU split).</p>
<p>This work demonstrated that deep learning was not just a theoretical pursuit: it could be engineered efficiently at scale on commodity hardware. It catalyzed the field’s shift toward GPU-based training and fueled the growth of GPU-oriented deep learning frameworks, from Theano (which predates AlexNet) to Caffe, MXNet, TensorFlow, and PyTorch.</p>
</section>
<section id="sec-lrn" class="level3">
<h3 class="anchored" data-anchor-id="sec-lrn" id="sec-lrn">Local Response Normalization (LRN)</h3>
<p>Local Response Normalization is a form of lateral inhibition inspired by biological neuroscience — the idea that highly activated neurons suppress their neighbors, creating competition among features.</p>
<p>Let <span class="math inline">\(a^i_{x,y}\)</span> be the activity of a neuron computed by applying kernel <span class="math inline">\(i\)</span> at position <span class="math inline">\((x, y)\)</span> and then applying the ReLU. The normalized response <span class="math inline">\(b^i_{x,y}\)</span> is:</p>
<p><span id="eq-lrn"><span class="math display">\[
b^i_{x,y} = \frac{a^i_{x,y}}{\left(k + \alpha \sum_{j=\max(0,\, i-n/2)}^{\min(N-1,\, i+n/2)} \left(a^j_{x,y}\right)^2\right)^\beta}
\tag{5}\]</span></span></p>
<p>Where:</p>
<ul>
<li><span class="math inline">\(N\)</span> is the total number of feature maps</li>
<li><span class="math inline">\(n\)</span> is the number of adjacent kernel maps over which normalization occurs</li>
<li><span class="math inline">\(k\)</span>, <span class="math inline">\(\alpha\)</span>, <span class="math inline">\(\beta\)</span>, <span class="math inline">\(n\)</span> are hyperparameters (set to <span class="math inline">\(k=2\)</span>, <span class="math inline">\(\alpha=10^{-4}\)</span>, <span class="math inline">\(\beta=0.75\)</span>, <span class="math inline">\(n=5\)</span> in AlexNet)</li>
</ul>
<p>LRN normalizes across adjacent feature maps at the same spatial location, effectively making features compete across channels. The authors reported that LRN reduced their top-1 error rate by 1.4% and top-5 error rate by 1.2% on the validation set.</p>
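<p>Eq. (5) can be implemented directly for the activities of all N kernel maps at a single spatial position. This pure-Python sketch uses the AlexNet hyperparameters as defaults (the function name is hypothetical); a practical implementation would vectorize it over the whole feature map.</p>

```python
def lrn_across_channels(a, k=2.0, alpha=1e-4, beta=0.75, n=5):
    """Apply Eq. (5) to the activities a[0..N-1] of all kernel maps at one (x, y)."""
    N = len(a)
    b = []
    for i in range(N):
        # window of n neighbouring maps centred on i, clipped at the channel edges
        lo, hi = max(0, i - n // 2), min(N - 1, i + n // 2)
        sq_sum = sum(a[j] ** 2 for j in range(lo, hi + 1))
        b.append(a[i] / (k + alpha * sq_sum) ** beta)
    return b


if __name__ == "__main__":
    print(lrn_across_channels([0.0, 50.0, 100.0, 50.0, 0.0]))
```

<p>Some libraries parameterize LRN differently. PyTorch’s <code>torch.nn.LocalResponseNorm</code>, for example, divides α by the window size n, so its <code>alpha</code> argument is not interchangeable with the α in Eq. (5).</p>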
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Historical Note
</div>
</div>
<div class="callout-body-container callout-body">
<p>LRN fell out of use relatively quickly after AlexNet. Subsequent architectures found it provided marginal or no benefit, and Batch Normalization <span class="citation" data-cites="ioffe2015batch">[@ioffe2015batch]</span> emerged as a far more effective normalization strategy. LRN is rarely used today.</p>
</div>
</div>
</section>
<section id="sec-pooling" class="level3">
<h3 class="anchored" data-anchor-id="sec-pooling" id="sec-pooling">Overlapping Max Pooling</h3>
<p>Traditional pooling in CNNs used non-overlapping windows (i.e., the stride equaled the window size). AlexNet used <strong>overlapping max pooling</strong> with a pool size of <strong>3×3</strong> and stride <strong>2</strong>, meaning adjacent pooling windows overlap by 1 pixel in each direction.</p>
<p>Max pooling selects the maximum activation within each window:</p>
<p><span id="eq-maxpool"><span class="math display">\[
y_{i,j} = \max_{(p,q) \in \text{window at } (i,j)} x_{p,q}
\tag{6}\]</span></span></p>
<p>Overlapping max pooling has several effects:</p>
<ul>
<li><strong>Translation invariance</strong>: Like any max pooling, it makes the network slightly more robust to small translations of features.</li>
<li><strong>Better generalization</strong>: The authors reported that overlapping pooling reduced top-1 and top-5 error rates by 0.4% and 0.3%, respectively, compared to non-overlapping pooling (size 2, stride 2), which produces outputs of the same dimensions. They also observed that models with overlapping pooling were slightly harder to overfit.</li>
<li><strong>Richer information flow</strong>: Because windows overlap, activations on window boundaries can contribute to more than one pooled output, providing some redundancy.</li>
</ul>
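<p>A minimal pure-Python sketch of Eq. (6) with AlexNet’s 3×3, stride-2 windows and no padding. The helper <code>window_starts</code> (like <code>max_pool_2d</code>, a hypothetical name) is included only to make the overlap visible.</p>

```python
def max_pool_2d(x, size=3, stride=2):
    """Max pooling (Eq. 6) over a 2-D list x, without padding."""
    h, w = len(x), len(x[0])
    out_h = (h - size) // stride + 1
    out_w = (w - size) // stride + 1
    return [
        [
            max(x[i * stride + p][j * stride + q] for p in range(size) for q in range(size))
            for j in range(out_w)
        ]
        for i in range(out_h)
    ]


def window_starts(n, size, stride):
    # first index covered by each pooling window along one axis
    return [i * stride for i in range((n - size) // stride + 1)]


if __name__ == "__main__":
    grid = [[i * 5 + j for j in range(5)] for i in range(5)]
    print(max_pool_2d(grid))
    print(window_starts(7, 3, 2))
```

<p>For example, <code>window_starts(7, 3, 2)</code> returns <code>[0, 2, 4]</code>, so consecutive windows cover pixels 0–2, 2–4, and 4–6 and share one pixel. On a 55×55 map, both the overlapping 3×3/stride-2 scheme and the non-overlapping 2×2/stride-2 scheme yield 27 positions per axis, which is why the two can be compared directly.</p>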
</section>
<section id="sec-dropout" class="level3">
<h3 class="anchored" data-anchor-id="sec-dropout" id="sec-dropout">Dropout Regularization</h3>
<p><strong>The Overfitting Problem</strong></p>
<p>With ~62 million parameters and “only” 1.2 million training images, overfitting was a severe risk. A model with this many parameters can easily memorize the training set rather than learning generalizable features.</p>
<p><strong>What is Dropout?</strong></p>
<p>Dropout <span class="citation" data-cites="srivastava2014dropout">[@srivastava2014dropout]</span> is a regularization technique that, during each forward pass in training, randomly “drops” (sets to zero) each neuron’s activation with probability <span class="math inline">\(p\)</span> (typically <span class="math inline">\(p=0.5\)</span>). The neurons that are dropped do not contribute to the forward pass and do not receive gradient updates in the backward pass.</p>
<p>Mathematically, for a layer with activations <span class="math inline">\(\mathbf{h}\)</span>, the dropout mask <span class="math inline">\(\mathbf{m} \sim \text{Bernoulli}(1-p)\)</span> gives:</p>
<p><span id="eq-dropout"><span class="math display">\[
\mathbf{h}_{\text{dropped}} = \mathbf{h} \odot \mathbf{m}
\tag{7}\]</span></span></p>
<p>During inference (test time), all neurons are active, but their outputs are scaled by <span class="math inline">\((1-p)\)</span> to compensate for the fact that during training only <span class="math inline">\((1-p)\)</span> fraction of neurons were active on average. (Equivalently, using “inverted dropout” — the standard modern approach — you scale by <span class="math inline">\(\frac{1}{1-p}\)</span> during training and do no scaling at test time.)</p>
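<p>The two scaling conventions can be compared with a small sketch. The code is illustrative (hypothetical function names, plain Python lists) and takes a seeded random generator so the behaviour is reproducible.</p>

```python
import random


def dropout_train(h, p, rng, inverted=False):
    # keep each unit with probability 1 - p: the Bernoulli mask m in Eq. (7)
    mask = [1.0 if rng.random() >= p else 0.0 for _ in h]
    # inverted dropout rescales surviving units during training instead of at test time
    scale = 1.0 / (1.0 - p) if inverted else 1.0
    return [v * m * scale for v, m in zip(h, mask)]


def dropout_test(h, p, inverted=False):
    # original scheme: multiply by (1 - p) at test time; inverted dropout: identity
    return list(h) if inverted else [v * (1.0 - p) for v in h]


if __name__ == "__main__":
    rng = random.Random(0)
    print(dropout_train([1.0, 2.0, 3.0, 4.0], 0.5, rng))
    print(dropout_test([1.0, 2.0, 3.0, 4.0], 0.5))
```

<p>In both conventions the expected training-time output matches the test-time output; inverted dropout simply moves the rescaling into training so that inference needs no special handling.</p>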
<p><strong>Why Does Dropout Work?</strong></p>
<p>Dropout can be understood in several complementary ways:</p>
<ol type="1">
<li><p><strong>Ensemble approximation</strong>: Each forward pass uses a different random subset of neurons, effectively training an exponential number of different “thinned” networks. At test time, the full network approximates averaging over this ensemble.</p></li>
<li><p><strong>Prevention of co-adaptation</strong>: Neurons cannot rely on the presence of specific other neurons. This forces each neuron to learn features that are independently useful, rather than complex co-dependent features that are highly specific to the training data.</p></li>
<li><p><strong>Noise injection</strong>: Adding Bernoulli noise to activations acts as a regularizer that prevents overfitting.</p></li>
</ol>
<p>In AlexNet, dropout with <span class="math inline">\(p=0.5\)</span> was applied to the outputs of FC6 and FC7. The authors noted that dropout roughly doubled the number of iterations required to converge, but that without it their network exhibited substantial overfitting.</p>
</section>
<section id="sec-augmentation" class="level3">
<h3 class="anchored" data-anchor-id="sec-augmentation" id="sec-augmentation">Data Augmentation</h3>
<p>The second major technique used to combat overfitting was <strong>data augmentation</strong> — artificially increasing the effective size and diversity of the training dataset by applying label-preserving transformations to the images.</p>
<p>AlexNet used two forms of data augmentation:</p>
<p><strong>1. Random Cropping and Horizontal Flipping</strong></p>
<ul>
<li>Training images were rescaled so that the shorter side was 256 pixels, then center-cropped to <strong>256×256</strong>.</li>
<li>During training, <strong>227×227 patches</strong> were randomly extracted from random positions in the 256×256 image, along with their horizontal reflections. This gives <span class="math inline">\((256-227+1)^2 \times 2 = 900 \times 2 = 1{,}800\)</span> distinct patches per image (the paper, which describes 224×224 patches, quotes an increase by a factor of 2048).</li>
<li>At test time, 5 fixed crops (four corners + center) plus their horizontal reflections (10 patches total) were extracted, and the softmax probabilities were averaged over all 10 predictions.</li>
</ul>
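<p>A minimal sketch of the training-time crop-and-flip and the test-time ten-crop layout, assuming images are stored as H×W lists of pixel rows (a channel dimension would simply ride along inside each pixel). All function names are hypothetical.</p>

```python
import random


def random_crop_flip(img, crop=227, rng=random):
    """img: H x W list of rows (e.g. 256 x 256). Returns one random, possibly mirrored crop."""
    h, w = len(img), len(img[0])
    top = rng.randint(0, h - crop)    # randint is inclusive: h - crop + 1 possible offsets
    left = rng.randint(0, w - crop)
    patch = [row[left:left + crop] for row in img[top:top + crop]]
    if rng.random() < 0.5:
        patch = [row[::-1] for row in patch]  # horizontal reflection
    return patch


def ten_crop_offsets(size=256, crop=227):
    """(top, left, flipped) for the four corner crops, the center crop, and their mirrors."""
    m = size - crop
    c = m // 2
    offsets = [(0, 0), (0, m), (m, 0), (m, m), (c, c)]
    return [(t, l, f) for f in (False, True) for t, l in offsets]


def num_distinct_patches(size=256, crop=227):
    # every crop position, each with and without a horizontal flip
    return (size - crop + 1) ** 2 * 2
```

<p>With a 256×256 image and 227×227 crops, the center crop starts at offset 14 in both directions.</p>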
<p><strong>2. PCA Color Augmentation</strong></p>
<p>PCA was performed on the set of RGB pixel values across the training set. For each training image, random multiples of the found principal components were added to each pixel:</p>
<p><span id="eq-pca-aug"><span class="math display">\[
\Delta \mathbf{p} = [\mathbf{p}_1, \mathbf{p}_2, \mathbf{p}_3]\, [\alpha_1 \lambda_1,\; \alpha_2 \lambda_2,\; \alpha_3 \lambda_3]^\top
\tag{8}\]</span></span></p>
<p>where <span class="math inline">\(\mathbf{p}_i\)</span> and <span class="math inline">\(\lambda_i\)</span> are the eigenvectors and eigenvalues of the 3×3 RGB covariance matrix, and <span class="math inline">\(\alpha_i \sim \mathcal{N}(0, 0.1^2)\)</span> (zero mean, standard deviation 0.1) are random scalings. Each <span class="math inline">\(\alpha_i\)</span> is drawn once for all pixels of an image and redrawn the next time that image is used for training.</p>
<p>This augmentation captures the property that object identity is approximately invariant to changes in illumination color and intensity. The authors reported it reduced top-1 error by over 1%.</p>
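<p>Given the eigenvectors and eigenvalues of the RGB covariance matrix (computed once over the training pixels, for example with <code>numpy.linalg.eigh</code>; that step is not shown), Eq. (8) reduces to a few lines. This sketch assumes pixels are stored as <code>[r, g, b]</code> lists on the same scale used to compute the covariance; the function names are hypothetical.</p>

```python
import random


def pca_color_shift(eigvecs, eigvals, rng, sigma=0.1):
    """Return the RGB offset Delta p of Eq. (8).

    eigvecs: the three eigenvectors p_i (each a length-3 list) of the RGB covariance
    eigvals: the matching eigenvalues lambda_i
    """
    # alpha_i ~ N(0, sigma^2), drawn once per presentation of the image
    alphas = [rng.gauss(0.0, sigma) for _ in range(3)]
    # Delta p_c = sum_i p_i[c] * alpha_i * lambda_i
    return [sum(eigvecs[i][c] * alphas[i] * eigvals[i] for i in range(3)) for c in range(3)]


def apply_shift(pixels, delta):
    # the same offset is added to every pixel of the image
    return [[px[c] + delta[c] for c in range(3)] for px in pixels]
```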
<hr>
</section>
</section>
<section id="training-details" class="level2">
<h2 class="anchored" data-anchor-id="training-details" id="training-details">Training Details</h2>
<p>The AlexNet training procedure was carefully tuned with several key hyperparameter choices.</p>
<p><strong>Optimizer</strong>: Stochastic Gradient Descent (SGD) with momentum and weight decay. The update rule was:</p>
<p><span id="eq-sgd-v"><span class="math display">\[
\mathbf{v}_{i+1} = 0.9\, \mathbf{v}_i - 0.0005\, \varepsilon\, \mathbf{w}_i - \varepsilon \left.\frac{\partial L}{\partial \mathbf{w}}\right|_{\mathbf{w}_i}
\tag{9}\]</span></span></p>
<p><span id="eq-sgd-w"><span class="math display">\[
\mathbf{w}_{i+1} = \mathbf{w}_i + \mathbf{v}_{i+1}
\tag{10}\]</span></span></p>
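<p>where <span class="math inline">\(\mathbf{v}\)</span> is the velocity (momentum) variable, <span class="math inline">\(\varepsilon\)</span> is the learning rate, and <span class="math inline">\(\partial L / \partial \mathbf{w}\)</span> is the gradient of the loss with respect to the weights, evaluated at <span class="math inline">\(\mathbf{w}_i\)</span> and averaged over the <span class="math inline">\(i\)</span>-th mini-batch. The middle term implements weight decay.</p>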
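<p>Eqs. (9) and (10) translate directly into code. A plain-Python sketch of one update for a flat list of weights, with the AlexNet momentum and weight-decay values as defaults (the function name is hypothetical):</p>

```python
def sgd_momentum_step(w, v, grad, lr, momentum=0.9, weight_decay=0.0005):
    """One update of Eqs. (9)-(10). grad is the mini-batch-averaged gradient at w."""
    # Eq. (9): v <- 0.9 v - 0.0005 * lr * w - lr * dL/dw
    v_new = [momentum * vi - weight_decay * lr * wi - lr * gi
             for wi, vi, gi in zip(w, v, grad)]
    # Eq. (10): w <- w + v
    w_new = [wi + vi for wi, vi in zip(w, v_new)]
    return w_new, v_new


if __name__ == "__main__":
    w, v = [1.0, -2.0], [0.0, 0.0]
    for _ in range(3):
        w, v = sgd_momentum_step(w, v, grad=[0.5, -0.5], lr=0.01)
    print(w, v)
```

<p>This form places the learning rate inside the velocity. PyTorch’s <code>torch.optim.SGD</code> instead accumulates the raw gradient (plus weight decay) in its momentum buffer and multiplies by the learning rate afterwards; the two are equivalent at a constant learning rate but behave slightly differently right after a learning-rate drop.</p>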
<p>Key training hyperparameters are summarized below:</p>
<div id="tbl-hparams" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-tbl figure">
<figcaption class="quarto-float-caption-top quarto-float-caption quarto-float-tbl" id="tbl-hparams-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Table&nbsp;2: AlexNet training hyperparameters
</figcaption>
<div aria-describedby="tbl-hparams-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<table class="caption-top table">
<thead>
<tr class="header">
<th>Hyperparameter</th>
<th>Value</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>Optimizer</td>
<td>SGD</td>
</tr>
<tr class="even">
<td>Momentum</td>
<td>0.9</td>
</tr>
<tr class="odd">
<td>Weight decay</td>
<td>0.0005</td>
</tr>
<tr class="even">
<td>Initial learning rate</td>
<td>0.01</td>
</tr>
<tr class="odd">
<td>LR schedule</td>
<td>Divided by 10 manually when validation error stopped improving (3 times in total)</td>
</tr>
<tr class="even">
<td>Final learning rate</td>
<td>0.00001</td>
</tr>
<tr class="odd">
<td>Batch size</td>
<td>128</td>
</tr>
<tr class="even">
<td>Epochs</td>
<td>~90</td>
</tr>
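<tr class="odd">
<td>Weight init</td>
<td>Zero-mean Gaussian, std 0.01</td>
</tr>
<tr class="even">
<td>Bias init</td>
<td>1 in Conv2, Conv4, Conv5, FC6, FC7; 0 elsewhere</td>
</tr>
<tr class="odd">
<td>Hardware</td>
<td>2× NVIDIA GTX 580 (3 GB VRAM each)</td>
</tr>
<tr class="even">
<td>Training time</td>
<td>~5–6 days</td>
</tr>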
</tbody>
</table>
</div>
</figure>
</div>
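<p>The paper lowered the learning rate by hand. A plateau rule is a common automatic stand-in; the sketch below divides the rate by 10 whenever validation error has not improved for <code>patience</code> epochs, at most three times, which covers the 0.01 to 0.00001 range in Table 2. The <code>patience</code> value and the synthetic error curve are assumptions made for illustration, since the paper does not say how a plateau was judged.</p>

```python
def plateau_schedule(val_errors, lr=0.01, factor=0.1, patience=5, max_drops=3):
    """Learning rate used at each epoch under a simple plateau rule.

    Hypothetical stand-in for the paper's manual schedule: divide the rate
    by 10 once validation error has failed to improve for `patience`
    consecutive epochs, at most `max_drops` times (0.01 -> 1e-5).
    """
    lrs, best, stale, drops = [], float("inf"), 0, 0
    for err in val_errors:
        lrs.append(lr)                  # rate used during this epoch
        if err < best:
            best, stale = err, 0        # improvement resets the counter
        else:
            stale += 1
        if stale >= patience and drops < max_drops:
            lr *= factor                # the "divide by 10" step
            drops += 1
            stale = 0
    return lrs


# Synthetic validation-error curve that improves in bursts, then stalls
val_errors = [0.60, 0.50, 0.45] + [0.45] * 5 + [0.40] * 6 + [0.38] * 6 + [0.37] * 11
lrs = plateau_schedule(val_errors)
print([f"{x:g}" for x in sorted(set(lrs), reverse=True)])
# ['0.01', '0.001', '0.0001', '1e-05']
```

<p>In PyTorch, <code>torch.optim.lr_scheduler.ReduceLROnPlateau</code> with <code>mode="min"</code>, <code>factor=0.1</code>, and <code>min_lr=1e-5</code> implements a similar rule. Its default relative improvement threshold means it can treat very small gains as a plateau, unlike the strict comparison above.</p>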
<hr>
</section>
<section id="performance-and-results" class="level2">
<h2 class="anchored" data-anchor-id="performance-and-results" id="performance-and-results">Performance and Results</h2>
<p>AlexNet’s results at ILSVRC-2012 were startling:</p>
<div id="tbl-results-2012" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-tbl figure">
<figcaption class="quarto-float-caption-top quarto-float-caption quarto-float-tbl" id="tbl-results-2012-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
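Table&nbsp;3: ILSVRC-2012 results (top-1 errors and the single-model top-5 error are validation-set figures; 15.3% and 26.2% are official test-set results)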
</figcaption>
<div aria-describedby="tbl-results-2012-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<table class="caption-top table">
<thead>
<tr class="header">
<th>Model</th>
<th>Top-5 Error (%)</th>
<th>Top-1 Error (%)</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>AlexNet (1 model)</td>
<td>18.2</td>
<td>40.7</td>
</tr>
<tr class="even">
<td><strong>AlexNet (7-model ensemble)</strong></td>
<td><strong>15.3</strong></td>
<td><strong>36.7</strong></td>
</tr>
<tr class="odd">
<td>2nd place (non-CNN)</td>
<td>26.2</td>
<td>—</td>
</tr>
</tbody>
</table>
</div>
</figure>
</div>
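<p>The top-5 error rate is the fraction of test images for which the correct class is not among the model’s five highest-scoring predictions. The ensemble averaged the predictions of seven CNNs, two of which were first pretrained on the larger ImageNet Fall 2011 release. Its 15.3% top-5 error, against 26.2% for the runner-up (a non-neural pipeline built on Fisher Vectors), was a decisive victory.</p>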
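<p>As a concrete illustration of the metric, the NumPy sketch below computes top-1 and top-5 error from a matrix of class scores. The toy scores and labels are made up for the example; the official ILSVRC evaluation works from each entry’s submitted list of five labels per image, but the counting rule is the same.</p>

```python
import numpy as np


def topk_error(scores, labels, k):
    """Fraction of samples whose true label is NOT among the k highest scores."""
    topk = np.argsort(-scores, axis=1)[:, :k]    # class indices, best first
    hit = (topk == labels[:, None]).any(axis=1)  # is the true label in the top k?
    return 1.0 - hit.mean()


# Toy example: 3 images, 6 classes
scores = np.array([
    [0.1, 2.0, 0.3, 0.0, 0.5, 0.2],  # true class 1 is ranked 1st
    [1.5, 0.2, 1.4, 1.3, 1.2, 1.1],  # true class 4 is ranked 4th
    [0.0, 0.1, 0.2, 0.3, 0.4, 0.5],  # true class 0 is ranked 6th
])
labels = np.array([1, 4, 0])

print(topk_error(scores, labels, k=1))  # 0.666... (images 2 and 3 miss)
print(topk_error(scores, labels, k=5))  # 0.333... (only image 3 misses)
```

<p>For an ensemble such as the seven-CNN entry, the per-model softmax probabilities are averaged first, and the same top-k count is then applied to the averaged scores.</p>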
<p><strong>On ILSVRC-2010</strong> (where the test set labels were available):</p>
<div id="tbl-results-2010" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-tbl figure">
<figcaption class="quarto-float-caption-top quarto-float-caption quarto-float-tbl" id="tbl-results-2010-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Table&nbsp;4: ILSVRC-2010 results
</figcaption>
<div aria-describedby="tbl-results-2010-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<table class="caption-top table">
<thead>
<tr class="header">
<th>Model</th>
<th>Top-5 Error (%)</th>
<th>Top-1 Error (%)</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>AlexNet</td>
<td>17.0</td>
<td>37.5</td>
</tr>
<tr class="even">
<td>Sparse coding (ILSVRC-2010 winner)</td>
<td>28.2</td>
<td>47.1</td>
</tr>
<tr class="odd">
<td>SIFT + Fisher Vectors (best published result before AlexNet)</td>
<td>25.7</td>
<td>45.7</td>
</tr>
</tbody>
</table>
</div>
</figure>
</div>
<p>These results established that deep CNNs had categorically surpassed traditional computer vision pipelines on large-scale image classification.</p>
<hr>
</section>
<section id="parameter-count-and-complexity" class="level2">
<h2 class="anchored" data-anchor-id="parameter-count-and-complexity" id="parameter-count-and-complexity">Parameter Count and Complexity</h2>
<p>A detailed breakdown of the trainable parameters by layer:</p>
<div id="tbl-params" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-tbl figure">
<figcaption class="quarto-float-caption-top quarto-float-caption quarto-float-tbl" id="tbl-params-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Table&nbsp;5: Parameter counts by layer
</figcaption>
<div aria-describedby="tbl-params-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<table class="caption-top table">
<thead>
<tr class="header">
<th>Layer</th>
<th>Weight Parameters</th>
<th>Bias</th>
<th>Total</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>Conv1</td>
<td>11×11×3×96 = 34,848</td>
<td>96</td>
<td>34,944</td>
</tr>
<tr class="even">
<td>Conv2</td>
<td>5×5×96×256 = 614,400</td>
<td>256</td>
<td>614,656</td>
</tr>
<tr class="odd">
<td>Conv3</td>
<td>3×3×256×384 = 884,736</td>
<td>384</td>
<td>885,120</td>
</tr>
<tr class="even">
<td>Conv4</td>
<td>3×3×384×384 = 1,327,104</td>
<td>384</td>
<td>1,327,488</td>
</tr>
<tr class="odd">
<td>Conv5</td>
<td>3×3×384×256 = 884,736</td>
<td>256</td>
<td>884,992</td>
</tr>
<tr class="even">
<td>FC6</td>
<td>9216×4096 = 37,748,736</td>
<td>4,096</td>
<td>37,752,832</td>
</tr>
<tr class="odd">
<td>FC7</td>
<td>4096×4096 = 16,777,216</td>
<td>4,096</td>
<td>16,781,312</td>
</tr>
<tr class="even">
<td>FC8</td>
<td>4096×1000 = 4,096,000</td>
<td>1,000</td>
<td>4,097,000</td>
</tr>
<tr class="odd">
<td><strong>Total</strong></td>
<td>—</td>
<td>—</td>
<td><strong>62,378,344 (~62.4M)</strong></td>
</tr>
</tbody>
</table>
</div>
</figure>
</div>
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Important
</div>
</div>
<div class="callout-body-container callout-body">
<p>The vast majority of parameters (~94%) reside in the fully connected layers, particularly FC6 (~60% of all parameters). This observation motivated later architectures like GoogLeNet and ResNet to use global average pooling instead of large FC heads, dramatically reducing parameter counts.</p>
</div>
</div>
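<p>The per-layer counts follow from simple arithmetic: a convolution has <span class="math inline">\(k \times k \times (C_\text{in}/g) \times C_\text{out}\)</span> weights plus <span class="math inline">\(C_\text{out}\)</span> biases, where <span class="math inline">\(g\)</span> is the number of channel groups, and a fully connected layer has <span class="math inline">\(n_\text{in} \times n_\text{out}\)</span> weights plus <span class="math inline">\(n_\text{out}\)</span> biases. The plain-Python sketch below totals both the single-GPU variant used by the PyTorch implementation later in this post (<span class="math inline">\(g = 1\)</span> everywhere) and the original two-GPU layout, in which Conv2, Conv4, and Conv5 only see the channels on their own GPU (<span class="math inline">\(g = 2\)</span>). The helper names are illustrative.</p>

```python
def conv_params(k, c_in, c_out, groups=1):
    """k x k kernels over c_in / groups input channels, plus one bias per output."""
    return k * k * (c_in // groups) * c_out + c_out


def fc_params(n_in, n_out):
    """Dense weight matrix plus one bias per output neuron."""
    return n_in * n_out + n_out


def alexnet_params(grouped):
    """Per-layer parameter counts; grouped=True mimics the two-GPU split."""
    g = 2 if grouped else 1
    return {
        "conv1": conv_params(11, 3, 96),
        "conv2": conv_params(5, 96, 256, groups=g),
        "conv3": conv_params(3, 256, 384),          # sees both GPU streams
        "conv4": conv_params(3, 384, 384, groups=g),
        "conv5": conv_params(3, 384, 256, groups=g),
        "fc6": fc_params(256 * 6 * 6, 4096),
        "fc7": fc_params(4096, 4096),
        "fc8": fc_params(4096, 1000),
    }


single = alexnet_params(grouped=False)
total = sum(single.values())
fc_share = (single["fc6"] + single["fc7"] + single["fc8"]) / total

print(f"single-GPU total: {total:,}")                               # 62,378,344
print(f"two-GPU total:    {sum(alexnet_params(True).values()):,}")  # 60,965,224
print(f"FC share: {fc_share:.1%}, FC6 share: {single['fc6'] / total:.1%}")  # 94.0%, 60.5%
```

<p>The two-GPU total of roughly 61 million is consistent with the “60 million parameters” quoted in the paper’s abstract. In PyTorch, the grouped layers would be written with <code>nn.Conv2d(..., groups=2)</code>.</p>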
<hr>
</section>
<section id="sec-dualgpu" class="level2">
<h2 class="anchored" data-anchor-id="sec-dualgpu" id="sec-dualgpu">Dual-GPU Split Architecture</h2>
<p>Because each GTX 580 had only 3 GB of memory, AlexNet was split across two GPUs running in parallel, with half of each layer’s kernels (neurons) on each GPU. The GPUs exchanged data only at specific layers:</p>
<ul>
<li><strong>Conv1, Conv2, Conv4, Conv5</strong>: Each GPU computes its half of the kernels and, apart from Conv1 (whose input is the image itself), takes input only from the feature maps on the same GPU. There is no cross-GPU communication in these layers.</li>
<li><strong>Conv3, FC6, FC7, FC8</strong>: Kernels and neurons on each GPU receive the full output of the previous layer from both GPUs. Conv3 is the only convolutional layer where the two streams mix; the fully connected layers are, by definition, connected to every neuron in the preceding layer.</li>
</ul>
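<p>In a layer without cross-GPU communication, each GPU’s kernels connect only to the channels held on that GPU. Written as a single layer, this is a grouped convolution with two groups, whose weights form a block-diagonal matrix over channels. The NumPy sketch below checks that equivalence using a 1×1 convolution for brevity (AlexNet’s grouped layers use 5×5 and 3×3 kernels, but the channel connectivity is the same); the channel counts are arbitrary assumptions.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
C_IN, C_OUT, H, W = 8, 6, 4, 4
x = rng.standard_normal((C_IN, H, W))  # one image's feature maps


def conv1x1(x, weight):
    """1x1 convolution: a channel-mixing matrix applied at every pixel."""
    return np.einsum("oc,chw->ohw", weight, x)


# Each "GPU" owns half the input channels and half the output channels
w_gpu0 = rng.standard_normal((C_OUT // 2, C_IN // 2))
w_gpu1 = rng.standard_normal((C_OUT // 2, C_IN // 2))
y_split = np.concatenate([
    conv1x1(x[: C_IN // 2], w_gpu0),  # GPU 0 never sees channels 4..7
    conv1x1(x[C_IN // 2:], w_gpu1),   # GPU 1 never sees channels 0..3
])

# The same layer as one dense weight matrix: block-diagonal, zero across GPUs
w_dense = np.zeros((C_OUT, C_IN))
w_dense[: C_OUT // 2, : C_IN // 2] = w_gpu0
w_dense[C_OUT // 2:, C_IN // 2:] = w_gpu1
y_dense = conv1x1(x, w_dense)

print(np.allclose(y_split, y_dense))            # True
print(w_gpu0.size + w_gpu1.size, w_dense.size)  # 24 48
```

<p>A layer where the streams mix, such as Conv3, corresponds to a dense weight matrix with no zero blocks, which is why it requires data to cross between GPUs. The grouped form also has half the weights of the dense one.</p>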
<p>In practice, this architecture is a form of <strong>model parallelism</strong>. The GPU-to-GPU communication occurred via direct transfers over the PCIe bus, which was a non-trivial engineering challenge at the time.</p>
<p>Interestingly, the Conv1 kernels on the two GPUs specialize: those on one GPU are largely <strong>color-agnostic</strong> (edge and texture detectors), while those on the other are largely <strong>color-specific</strong>. The paper reports that this happens in every run, independent of the random initialization (up to which GPU takes which role); it emerges from training rather than being explicitly programmed.</p>
<hr>
</section>
<section id="strengths-and-limitations" class="level2">
<h2 class="anchored" data-anchor-id="strengths-and-limitations" id="strengths-and-limitations">Strengths and Limitations</h2>
<section id="strengths" class="level3">
<h3 class="anchored" data-anchor-id="strengths" id="strengths">Strengths</h3>
<ul>
<li><strong>End-to-end learning</strong>: Unlike classical pipelines, AlexNet learns features directly from raw pixels, requiring no hand-crafted feature engineering.</li>
<li><strong>Scalability</strong>: The architecture clearly benefits from more data and more computation — a property that subsequent years of research would confirm as a general principle.</li>
<li><strong>Transfer learning</strong>: Features learned by AlexNet generalize remarkably well to other visual tasks. Fine-tuning a pretrained AlexNet (or just using its FC6/FC7 activations as features) became a standard baseline for many vision tasks for years.</li>
<li><strong>Practical training innovations</strong>: ReLU, dropout, data augmentation, and GPU training are all practically important techniques that were packaged into a single working system.</li>
</ul>
</section>
<section id="limitations" class="level3">
<h3 class="anchored" data-anchor-id="limitations" id="limitations">Limitations</h3>
<ul>
<li><strong>Extremely large FC layers</strong>: The three fully connected layers hold ~94% of the parameters but perform only a small fraction of the computation. They dominate the model’s memory footprint, and the paper relied on dropout in these layers to keep them from overfitting.</li>
<li><strong>Fixed input size</strong>: The FC layers require a fixed-size input, which means all images must be resized to 227×227. This is inflexible for tasks requiring variable-resolution inputs.</li>
<li><strong>Large first-layer kernels</strong>: The 11×11 kernels with stride 4 in Conv1 discard fine spatial detail at the very first stage. Zeiler and Fergus <span class="citation" data-cites="zeiler2014visualizing">[@zeiler2014visualizing]</span> observed aliasing artifacts caused by this large stride and switched to 7×7 kernels with stride 2 in ZFNet.</li>
<li><strong>Local Response Normalization</strong>: LRN adds computation and extra hyperparameters. Simonyan and Zisserman <span class="citation" data-cites="simonyan2015very">[@simonyan2015very]</span> found that it did not improve accuracy in VGG, and later architectures largely dropped it.</li>
<li><strong>No skip connections</strong>: AlexNet’s plain sequential stack becomes hard to optimize when it is made much deeper. Residual connections <span class="citation" data-cites="he2016deep">[@he2016deep]</span> address this.</li>
<li><strong>Relatively shallow</strong>: With 8 learned layers (5 convolutional, 3 fully connected), AlexNet is shallow by modern standards; ResNets commonly use 50 to 152 layers.</li>
<li><strong>Dual-GPU complexity</strong>: The two-GPU split was a workaround for 3 GB of memory per card and complicated the implementation, although the paper reports that it slightly outperformed a one-GPU net with half as many kernels. Modern GPUs fit the entire network, so reimplementations usually drop the split.</li>
</ul>
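<p>The fixed-input constraint can be checked with the standard output-size formula <span class="math inline">\(\lfloor (n + 2p - k)/s \rfloor + 1\)</span> for a layer with kernel size <span class="math inline">\(k\)</span>, stride <span class="math inline">\(s\)</span>, and padding <span class="math inline">\(p\)</span>. The sketch below traces the spatial size through AlexNet’s convolution and pooling layers, using the kernel, stride, and padding values from the PyTorch implementation later in this post. A 227×227 input ends at the 6×6×256 = 9,216 values that FC6 expects, while 224×224 (the size stated in the paper) or 320×320 would not fit.</p>

```python
def out_size(n, k, s, p=0):
    """Output side length of a conv/pool layer: floor((n + 2p - k) / s) + 1."""
    return (n + 2 * p - k) // s + 1


# (name, kernel, stride, padding) for every conv and pool layer, in order
LAYERS = [
    ("conv1", 11, 4, 0), ("pool1", 3, 2, 0),
    ("conv2", 5, 1, 2), ("pool2", 3, 2, 0),
    ("conv3", 3, 1, 1), ("conv4", 3, 1, 1),
    ("conv5", 3, 1, 1), ("pool5", 3, 2, 0),
]


def final_map_size(n):
    """Side length of the pool5 output for an n x n input image."""
    for _name, k, s, p in LAYERS:
        n = out_size(n, k, s, p)
    return n


for side in (227, 224, 320):
    m = final_map_size(side)
    print(f"{side}x{side} input -> {m}x{m}x256 = {256 * m * m} values")
# 227x227 input -> 6x6x256 = 9216 values   (matches FC6)
# 224x224 input -> 5x5x256 = 6400 values
# 320x320 input -> 8x8x256 = 16384 values
```

<p>The <code>nn.AdaptiveAvgPool2d((6, 6))</code> layer in the PyTorch implementation sidesteps this limitation by resampling whatever map size reaches it to 6×6 before FC6.</p>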
<hr>
</section>
</section>
<section id="legacy-and-influence" class="level2">
<h2 class="anchored" data-anchor-id="legacy-and-influence" id="legacy-and-influence">Legacy and Influence</h2>
<p>AlexNet moved computer vision, and soon much of machine learning, from hand-engineered features to representations learned end to end by deep networks, a shift that continues to this day.</p>
<section id="direct-successors" class="level3">
<h3 class="anchored" data-anchor-id="direct-successors" id="direct-successors">Direct Successors</h3>
<ul>
<li><strong>ZFNet (2013)</strong> <span class="citation" data-cites="zeiler2014visualizing">[@zeiler2014visualizing]</span>: Matthew Zeiler and Rob Fergus made incremental improvements to AlexNet (smaller first-layer kernel, modified strides), winning ILSVRC-2013 with a top-5 error of ~11.7%.</li>
<li><strong>VGGNet (2014)</strong> <span class="citation" data-cites="simonyan2015very">[@simonyan2015very]</span>: Simonyan and Zisserman replaced all large kernels with stacks of small 3×3 convolutions and showed that increasing depth to 16–19 weight layers substantially improves accuracy. Top-5 error: ~7.3%.</li>
<li><strong>GoogLeNet/Inception (2014)</strong>: Szegedy et al.&nbsp;introduced Inception modules and global average pooling, massively reducing parameter count while improving accuracy. Top-5 error: ~6.7%.</li>
<li><strong>ResNet (2015)</strong> <span class="citation" data-cites="he2016deep">[@he2016deep]</span>: He et al.&nbsp;introduced residual (skip) connections, enabling training of networks with hundreds of layers. Its ensemble reached ~3.6% top-5 error on ImageNet, below the commonly cited ~5.1% estimate of human top-5 error.</li>
</ul>
</section>
<section id="broader-impact" class="level3">
<h3 class="anchored" data-anchor-id="broader-impact" id="broader-impact">Broader Impact</h3>
<ul>
<li><strong>Transfer learning revolution</strong>: AlexNet popularized pretraining a CNN on ImageNet and fine-tuning it for downstream tasks. The broader pretrain-then-adapt paradigm, later scaled up in large pretrained models such as BERT, GPT, and ViT, is arguably AlexNet’s most lasting contribution.</li>
<li><strong>Deep learning hardware ecosystem</strong>: AlexNet’s success accelerated the development of GPU hardware and software aimed specifically at deep learning, such as NVIDIA’s cuDNN library (first released in 2014). NVIDIA’s data center business, a minor part of its revenue in 2012, had become its largest segment by the early 2020s.</li>
<li><strong>Benchmark culture</strong>: The success of ILSVRC popularized the use of large-scale benchmarks to drive research progress. This benchmark-driven culture (for better or worse) shapes ML research to this day.</li>
<li><strong>Democratization</strong>: AlexNet demonstrated that groundbreaking results in AI could be achieved with commodity hardware (consumer GPUs), lowering the barrier to entry for researchers worldwide.</li>
<li><strong>Industry transformation</strong>: The dramatic demonstration of deep learning’s potential at ILSVRC-2012 triggered massive investment from tech companies, reshaping research labs at Google, Facebook, Microsoft, Baidu, and others within months.</li>
</ul>
<hr>
</section>
</section>
<section id="implementing-alexnet-in-pytorch" class="level2">
<h2 class="anchored" data-anchor-id="implementing-alexnet-in-pytorch" id="implementing-alexnet-in-pytorch">Implementing AlexNet in PyTorch</h2>
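<p>Below is a complete, annotated PyTorch implementation of AlexNet. It keeps the original layer sizes but merges each pair of GPU streams into a single ungrouped layer and omits Local Response Normalization.</p>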
<section id="model-definition" class="level3">
<h3 class="anchored" data-anchor-id="model-definition" id="model-definition">Model Definition</h3>
<div id="b9986d3f" class="cell" data-execution_count="1">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb1"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torchvision <span class="im">import</span> transforms</span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> AlexNet(nn.Module):</span>
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a><span class="co">    AlexNet: Krizhevsky, Sutskever, Hinton (2012).</span></span>
<span id="cb1-9"><a href="#cb1-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-10"><a href="#cb1-10" aria-hidden="true" tabindex="-1"></a><span class="co">    Architecture:</span></span>
<span id="cb1-11"><a href="#cb1-11" aria-hidden="true" tabindex="-1"></a><span class="co">        5 convolutional layers followed by 3 fully connected layers.</span></span>
<span id="cb1-12"><a href="#cb1-12" aria-hidden="true" tabindex="-1"></a><span class="co">        Uses ReLU activations, overlapping max pooling, and dropout.</span></span>
<span id="cb1-13"><a href="#cb1-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-14"><a href="#cb1-14" aria-hidden="true" tabindex="-1"></a><span class="co">    Args:</span></span>
<span id="cb1-15"><a href="#cb1-15" aria-hidden="true" tabindex="-1"></a><span class="co">        num_classes (int): Number of output classes. Default: 1000 (ImageNet).</span></span>
<span id="cb1-16"><a href="#cb1-16" aria-hidden="true" tabindex="-1"></a><span class="co">        dropout (float): Dropout probability in FC layers. Default: 0.5.</span></span>
<span id="cb1-17"><a href="#cb1-17" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb1-18"><a href="#cb1-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-19"><a href="#cb1-19" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, num_classes: <span class="bu">int</span> <span class="op">=</span> <span class="dv">1000</span>, dropout: <span class="bu">float</span> <span class="op">=</span> <span class="fl">0.5</span>) <span class="op">-&gt;</span> <span class="va">None</span>:</span>
<span id="cb1-20"><a href="#cb1-20" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>(AlexNet, <span class="va">self</span>).<span class="fu">__init__</span>()</span>
<span id="cb1-21"><a href="#cb1-21" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-22"><a href="#cb1-22" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Feature extractor (convolutional layers)</span></span>
<span id="cb1-23"><a href="#cb1-23" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.features <span class="op">=</span> nn.Sequential(</span>
<span id="cb1-24"><a href="#cb1-24" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Conv1: 227x227x3 -&gt; 55x55x96 -&gt; (pool) -&gt; 27x27x96</span></span>
<span id="cb1-25"><a href="#cb1-25" aria-hidden="true" tabindex="-1"></a>            nn.Conv2d(<span class="dv">3</span>, <span class="dv">96</span>, kernel_size<span class="op">=</span><span class="dv">11</span>, stride<span class="op">=</span><span class="dv">4</span>, padding<span class="op">=</span><span class="dv">0</span>),</span>
<span id="cb1-26"><a href="#cb1-26" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(inplace<span class="op">=</span><span class="va">True</span>),</span>
<span id="cb1-27"><a href="#cb1-27" aria-hidden="true" tabindex="-1"></a>            nn.MaxPool2d(kernel_size<span class="op">=</span><span class="dv">3</span>, stride<span class="op">=</span><span class="dv">2</span>),</span>
<span id="cb1-28"><a href="#cb1-28" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-29"><a href="#cb1-29" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Conv2: 27x27x96 -&gt; 27x27x256 -&gt; (pool) -&gt; 13x13x256</span></span>
<span id="cb1-30"><a href="#cb1-30" aria-hidden="true" tabindex="-1"></a>            nn.Conv2d(<span class="dv">96</span>, <span class="dv">256</span>, kernel_size<span class="op">=</span><span class="dv">5</span>, stride<span class="op">=</span><span class="dv">1</span>, padding<span class="op">=</span><span class="dv">2</span>),</span>
<span id="cb1-31"><a href="#cb1-31" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(inplace<span class="op">=</span><span class="va">True</span>),</span>
<span id="cb1-32"><a href="#cb1-32" aria-hidden="true" tabindex="-1"></a>            nn.MaxPool2d(kernel_size<span class="op">=</span><span class="dv">3</span>, stride<span class="op">=</span><span class="dv">2</span>),</span>
<span id="cb1-33"><a href="#cb1-33" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-34"><a href="#cb1-34" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Conv3: 13x13x256 -&gt; 13x13x384  (no pooling)</span></span>
<span id="cb1-35"><a href="#cb1-35" aria-hidden="true" tabindex="-1"></a>            nn.Conv2d(<span class="dv">256</span>, <span class="dv">384</span>, kernel_size<span class="op">=</span><span class="dv">3</span>, stride<span class="op">=</span><span class="dv">1</span>, padding<span class="op">=</span><span class="dv">1</span>),</span>
<span id="cb1-36"><a href="#cb1-36" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(inplace<span class="op">=</span><span class="va">True</span>),</span>
<span id="cb1-37"><a href="#cb1-37" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-38"><a href="#cb1-38" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Conv4: 13x13x384 -&gt; 13x13x384  (no pooling)</span></span>
<span id="cb1-39"><a href="#cb1-39" aria-hidden="true" tabindex="-1"></a>            nn.Conv2d(<span class="dv">384</span>, <span class="dv">384</span>, kernel_size<span class="op">=</span><span class="dv">3</span>, stride<span class="op">=</span><span class="dv">1</span>, padding<span class="op">=</span><span class="dv">1</span>),</span>
<span id="cb1-40"><a href="#cb1-40" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(inplace<span class="op">=</span><span class="va">True</span>),</span>
<span id="cb1-41"><a href="#cb1-41" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-42"><a href="#cb1-42" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Conv5: 13x13x384 -&gt; 13x13x256 -&gt; (pool) -&gt; 6x6x256</span></span>
<span id="cb1-43"><a href="#cb1-43" aria-hidden="true" tabindex="-1"></a>            nn.Conv2d(<span class="dv">384</span>, <span class="dv">256</span>, kernel_size<span class="op">=</span><span class="dv">3</span>, stride<span class="op">=</span><span class="dv">1</span>, padding<span class="op">=</span><span class="dv">1</span>),</span>
<span id="cb1-44"><a href="#cb1-44" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(inplace<span class="op">=</span><span class="va">True</span>),</span>
<span id="cb1-45"><a href="#cb1-45" aria-hidden="true" tabindex="-1"></a>            nn.MaxPool2d(kernel_size<span class="op">=</span><span class="dv">3</span>, stride<span class="op">=</span><span class="dv">2</span>),</span>
<span id="cb1-46"><a href="#cb1-46" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb1-47"><a href="#cb1-47" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-48"><a href="#cb1-48" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Adaptive average pooling (allows flexible input resolution)</span></span>
<span id="cb1-49"><a href="#cb1-49" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.avgpool <span class="op">=</span> nn.AdaptiveAvgPool2d((<span class="dv">6</span>, <span class="dv">6</span>))</span>
<span id="cb1-50"><a href="#cb1-50" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-51"><a href="#cb1-51" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Classifier head (fully connected layers)</span></span>
<span id="cb1-52"><a href="#cb1-52" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.classifier <span class="op">=</span> nn.Sequential(</span>
<span id="cb1-53"><a href="#cb1-53" aria-hidden="true" tabindex="-1"></a>            nn.Dropout(p<span class="op">=</span>dropout),</span>
<span id="cb1-54"><a href="#cb1-54" aria-hidden="true" tabindex="-1"></a>            nn.Linear(<span class="dv">256</span> <span class="op">*</span> <span class="dv">6</span> <span class="op">*</span> <span class="dv">6</span>, <span class="dv">4096</span>),   <span class="co"># FC6</span></span>
<span id="cb1-55"><a href="#cb1-55" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(inplace<span class="op">=</span><span class="va">True</span>),</span>
<span id="cb1-56"><a href="#cb1-56" aria-hidden="true" tabindex="-1"></a>            nn.Dropout(p<span class="op">=</span>dropout),</span>
<span id="cb1-57"><a href="#cb1-57" aria-hidden="true" tabindex="-1"></a>            nn.Linear(<span class="dv">4096</span>, <span class="dv">4096</span>),            <span class="co"># FC7</span></span>
<span id="cb1-58"><a href="#cb1-58" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(inplace<span class="op">=</span><span class="va">True</span>),</span>
<span id="cb1-59"><a href="#cb1-59" aria-hidden="true" tabindex="-1"></a>            nn.Linear(<span class="dv">4096</span>, num_classes),     <span class="co"># FC8</span></span>
<span id="cb1-60"><a href="#cb1-60" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb1-61"><a href="#cb1-61" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-62"><a href="#cb1-62" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>._initialize_weights()</span>
<span id="cb1-63"><a href="#cb1-63" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-64"><a href="#cb1-64" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x: torch.Tensor) <span class="op">-&gt;</span> torch.Tensor:</span>
<span id="cb1-65"><a href="#cb1-65" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.features(x)</span>
<span id="cb1-66"><a href="#cb1-66" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.avgpool(x)</span>
<span id="cb1-67"><a href="#cb1-67" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> torch.flatten(x, <span class="dv">1</span>)</span>
<span id="cb1-68"><a href="#cb1-68" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.classifier(x)</span>
<span id="cb1-69"><a href="#cb1-69" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x</span>
<span id="cb1-70"><a href="#cb1-70" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-71"><a href="#cb1-71" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _initialize_weights(<span class="va">self</span>) <span class="op">-&gt;</span> <span class="va">None</span>:</span>
<span id="cb1-72"><a href="#cb1-72" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Initialize weights and biases as in the original paper."""</span></span>
<span id="cb1-73"><a href="#cb1-73" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Paper: bias = 1 in Conv2, Conv4, Conv5, FC6 and FC7; bias = 0 elsewhere.</span></span>
<span id="cb1-74"><a href="#cb1-74" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Indices are positions inside self.features and self.classifier.</span></span>
<span id="cb1-75"><a href="#cb1-75" aria-hidden="true" tabindex="-1"></a>        ones <span class="op">=</span> {<span class="bu">id</span>(<span class="va">self</span>.features[i]) <span class="cf">for</span> i <span class="kw">in</span> (<span class="dv">3</span>, <span class="dv">8</span>, <span class="dv">10</span>)}</span>
<span id="cb1-76"><a href="#cb1-76" aria-hidden="true" tabindex="-1"></a>        ones <span class="op">|=</span> {<span class="bu">id</span>(<span class="va">self</span>.classifier[i]) <span class="cf">for</span> i <span class="kw">in</span> (<span class="dv">1</span>, <span class="dv">4</span>)}</span>
<span id="cb1-77"><a href="#cb1-77" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> m <span class="kw">in</span> <span class="va">self</span>.modules():</span>
<span id="cb1-78"><a href="#cb1-78" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="bu">isinstance</span>(m, (nn.Conv2d, nn.Linear)):</span>
<span id="cb1-79"><a href="#cb1-79" aria-hidden="true" tabindex="-1"></a>                nn.init.normal_(m.weight, mean<span class="op">=</span><span class="dv">0</span>, std<span class="op">=</span><span class="fl">0.01</span>)</span>
<span id="cb1-80"><a href="#cb1-80" aria-hidden="true" tabindex="-1"></a>                nn.init.constant_(m.bias, <span class="dv">1</span> <span class="cf">if</span> <span class="bu">id</span>(m) <span class="kw">in</span> ones <span class="cf">else</span> <span class="dv">0</span>)</span>
<span id="cb1-81"><a href="#cb1-81" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-82"><a href="#cb1-82" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> get_feature_vector(</span>
<span id="cb1-83"><a href="#cb1-83" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>, x: torch.Tensor, layer: <span class="bu">str</span> <span class="op">=</span> <span class="st">"fc7"</span></span>
<span id="cb1-84"><a href="#cb1-84" aria-hidden="true" tabindex="-1"></a>    ) <span class="op">-&gt;</span> torch.Tensor:</span>
<span id="cb1-85"><a href="#cb1-85" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""</span></span>
<span id="cb1-86"><a href="#cb1-86" aria-hidden="true" tabindex="-1"></a><span class="co">        Extract feature vectors from FC6 or FC7 for transfer learning.</span></span>
<span id="cb1-87"><a href="#cb1-87" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-88"><a href="#cb1-88" aria-hidden="true" tabindex="-1"></a><span class="co">        Args:</span></span>
<span id="cb1-89"><a href="#cb1-89" aria-hidden="true" tabindex="-1"></a><span class="co">            x: Input tensor of shape (batch, 3, H, W)</span></span>
<span id="cb1-90"><a href="#cb1-90" aria-hidden="true" tabindex="-1"></a><span class="co">            layer: 'fc6' or 'fc7'</span></span>
<span id="cb1-91"><a href="#cb1-91" aria-hidden="true" tabindex="-1"></a><span class="co">        Returns:</span></span>
<span id="cb1-92"><a href="#cb1-92" aria-hidden="true" tabindex="-1"></a><span class="co">            Feature vector of shape (batch, 4096)</span></span>
<span id="cb1-93"><a href="#cb1-93" aria-hidden="true" tabindex="-1"></a><span class="co">        """</span></span>
<span id="cb1-94"><a href="#cb1-94" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.features(x)</span>
<span id="cb1-95"><a href="#cb1-95" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.avgpool(x)</span>
<span id="cb1-96"><a href="#cb1-96" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> torch.flatten(x, <span class="dv">1</span>)</span>
<span id="cb1-97"><a href="#cb1-97" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.classifier[<span class="dv">0</span>](x)  <span class="co"># Dropout</span></span>
<span id="cb1-98"><a href="#cb1-98" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.classifier[<span class="dv">1</span>](x)  <span class="co"># Linear (FC6)</span></span>
<span id="cb1-99"><a href="#cb1-99" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.classifier[<span class="dv">2</span>](x)  <span class="co"># ReLU</span></span>
<span id="cb1-100"><a href="#cb1-100" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> layer <span class="op">==</span> <span class="st">"fc6"</span>:</span>
<span id="cb1-101"><a href="#cb1-101" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> x</span>
<span id="cb1-102"><a href="#cb1-102" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.classifier[<span class="dv">3</span>](x)  <span class="co"># Dropout</span></span>
<span id="cb1-103"><a href="#cb1-103" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.classifier[<span class="dv">4</span>](x)  <span class="co"># Linear (FC7)</span></span>
<span id="cb1-104"><a href="#cb1-104" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.classifier[<span class="dv">5</span>](x)  <span class="co"># ReLU</span></span>
<span id="cb1-105"><a href="#cb1-105" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x</span></code></pre></div></div>
</div>
</section>
<section id="example-usage" class="level3">
<h3 class="anchored" data-anchor-id="example-usage" id="example-usage">Example Usage</h3>
<div id="ce787328" class="cell" data-execution_count="2">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Standard ImageNet preprocessing</span></span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a>preprocess <span class="op">=</span> transforms.Compose([</span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a>    transforms.Resize(<span class="dv">256</span>),</span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a>    transforms.CenterCrop(<span class="dv">227</span>),</span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a>    transforms.ToTensor(),</span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a>    transforms.Normalize(</span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a>        mean<span class="op">=</span>[<span class="fl">0.485</span>, <span class="fl">0.456</span>, <span class="fl">0.406</span>],</span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a>        std<span class="op">=</span>[<span class="fl">0.229</span>, <span class="fl">0.224</span>, <span class="fl">0.225</span>],</span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a>    ),</span>
<span id="cb2-10"><a href="#cb2-10" aria-hidden="true" tabindex="-1"></a>])</span>
<span id="cb2-11"><a href="#cb2-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-12"><a href="#cb2-12" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> AlexNet(num_classes<span class="op">=</span><span class="dv">1000</span>)</span>
<span id="cb2-13"><a href="#cb2-13" aria-hidden="true" tabindex="-1"></a>model.<span class="bu">eval</span>()</span>
<span id="cb2-14"><a href="#cb2-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-15"><a href="#cb2-15" aria-hidden="true" tabindex="-1"></a><span class="co"># Parameter counts</span></span>
<span id="cb2-16"><a href="#cb2-16" aria-hidden="true" tabindex="-1"></a>total_params <span class="op">=</span> <span class="bu">sum</span>(p.numel() <span class="cf">for</span> p <span class="kw">in</span> model.parameters())</span>
<span id="cb2-17"><a href="#cb2-17" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Total parameters: </span><span class="sc">{</span>total_params<span class="sc">:,}</span><span class="ss">"</span>)</span>
<span id="cb2-18"><a href="#cb2-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-19"><a href="#cb2-19" aria-hidden="true" tabindex="-1"></a><span class="co"># Forward pass</span></span>
<span id="cb2-20"><a href="#cb2-20" aria-hidden="true" tabindex="-1"></a>dummy_input <span class="op">=</span> torch.randn(<span class="dv">4</span>, <span class="dv">3</span>, <span class="dv">227</span>, <span class="dv">227</span>)</span>
<span id="cb2-21"><a href="#cb2-21" aria-hidden="true" tabindex="-1"></a><span class="cf">with</span> torch.no_grad():</span>
<span id="cb2-22"><a href="#cb2-22" aria-hidden="true" tabindex="-1"></a>    output <span class="op">=</span> model(dummy_input)</span>
<span id="cb2-23"><a href="#cb2-23" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Input shape:  </span><span class="sc">{</span>dummy_input<span class="sc">.</span>shape<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb2-24"><a href="#cb2-24" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Output shape: </span><span class="sc">{</span>output<span class="sc">.</span>shape<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb2-25"><a href="#cb2-25" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-26"><a href="#cb2-26" aria-hidden="true" tabindex="-1"></a><span class="co"># Transfer learning features</span></span>
<span id="cb2-27"><a href="#cb2-27" aria-hidden="true" tabindex="-1"></a>features <span class="op">=</span> model.get_feature_vector(dummy_input, layer<span class="op">=</span><span class="st">"fc7"</span>)</span>
<span id="cb2-28"><a href="#cb2-28" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"FC7 feature vector shape: </span><span class="sc">{</span>features<span class="sc">.</span>shape<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
</div>
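<p>The 227×227 input used above is not arbitrary. With the paper's kernel sizes and strides, and no padding on the first convolution, 227 is the size that makes the last feature map exactly 6×6×256 = 9,216 values, which is the input width of the first fully connected layer. The sketch below traces the spatial size through the stack using the standard formula ⌊(n + 2p - k)/s⌋ + 1. The padding values follow the common single-tower reconstruction and are an assumption; they are not read from the <code>AlexNet</code> class defined earlier.</p>

```python
# Trace the spatial size of an AlexNet-style feature extractor.
# Assumed (kernel, stride, padding) per layer follow the common single-tower
# reconstruction of the 2012 paper; the AlexNet class in this post may differ.
LAYERS = [
    ("conv1", 11, 4, 0),
    ("pool1", 3, 2, 0),
    ("conv2", 5, 1, 2),
    ("pool2", 3, 2, 0),
    ("conv3", 3, 1, 1),
    ("conv4", 3, 1, 1),
    ("conv5", 3, 1, 1),
    ("pool5", 3, 2, 0),
]


def out_size(n, kernel, stride, padding):
    """Output width of a conv/pool layer: floor((n + 2p - k) / s) + 1."""
    return (n + 2 * padding - kernel) // stride + 1


def trace(input_size):
    """Return [(layer_name, spatial_size), ...] for a square input."""
    sizes = []
    n = input_size
    for name, k, s, p in LAYERS:
        n = out_size(n, k, s, p)
        sizes.append((name, n))
    return sizes


if __name__ == "__main__":
    for size in (227, 224):
        final = trace(size)[-1][1]
        print(f"input {size}: final map {final}x{final}x256 = {final * final * 256} features")
```

<p>With a 224-pixel input and no padding, the map shrinks to 5×5 instead, which is why 224-pixel implementations such as torchvision's add padding to the first convolution and an adaptive pooling layer before the classifier.</p>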
</section>
<section id="training-loop" class="level3">
<h3 class="anchored" data-anchor-id="training-loop" id="training-loop">Training Loop</h3>
<div id="7832dca8" class="cell" data-execution_count="3">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.optim <span class="im">as</span> optim</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> train_alexnet(model, train_loader, val_loader, num_epochs<span class="op">=</span><span class="dv">90</span>):</span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a>    device <span class="op">=</span> torch.device(<span class="st">"cuda"</span> <span class="cf">if</span> torch.cuda.is_available() <span class="cf">else</span> <span class="st">"cpu"</span>)</span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> model.to(device)</span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a>    criterion <span class="op">=</span> nn.CrossEntropyLoss()</span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a>    optimizer <span class="op">=</span> optim.SGD(</span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a>        model.parameters(),</span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a>        lr<span class="op">=</span><span class="fl">0.01</span>,</span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a>        momentum<span class="op">=</span><span class="fl">0.9</span>,</span>
<span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a>        weight_decay<span class="op">=</span><span class="fl">5e-4</span>,</span>
<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb3-15"><a href="#cb3-15" aria-hidden="true" tabindex="-1"></a>    scheduler <span class="op">=</span> optim.lr_scheduler.ReduceLROnPlateau(</span>
<span id="cb3-16"><a href="#cb3-16" aria-hidden="true" tabindex="-1"></a>        optimizer, mode<span class="op">=</span><span class="st">"min"</span>, factor<span class="op">=</span><span class="fl">0.1</span>, patience<span class="op">=</span><span class="dv">5</span>  <span class="co"># divide LR by 10 on a plateau, as in the paper</span></span>
<span id="cb3-17"><a href="#cb3-17" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb3-18"><a href="#cb3-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-19"><a href="#cb3-19" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(num_epochs):</span>
<span id="cb3-20"><a href="#cb3-20" aria-hidden="true" tabindex="-1"></a>        model.train()</span>
<span id="cb3-21"><a href="#cb3-21" aria-hidden="true" tabindex="-1"></a>        running_loss <span class="op">=</span> <span class="fl">0.0</span></span>
<span id="cb3-22"><a href="#cb3-22" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-23"><a href="#cb3-23" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> images, labels <span class="kw">in</span> train_loader:</span>
<span id="cb3-24"><a href="#cb3-24" aria-hidden="true" tabindex="-1"></a>            images, labels <span class="op">=</span> images.to(device), labels.to(device)</span>
<span id="cb3-25"><a href="#cb3-25" aria-hidden="true" tabindex="-1"></a>            optimizer.zero_grad()</span>
<span id="cb3-26"><a href="#cb3-26" aria-hidden="true" tabindex="-1"></a>            outputs <span class="op">=</span> model(images)</span>
<span id="cb3-27"><a href="#cb3-27" aria-hidden="true" tabindex="-1"></a>            loss <span class="op">=</span> criterion(outputs, labels)</span>
<span id="cb3-28"><a href="#cb3-28" aria-hidden="true" tabindex="-1"></a>            loss.backward()</span>
<span id="cb3-29"><a href="#cb3-29" aria-hidden="true" tabindex="-1"></a>            optimizer.step()</span>
<span id="cb3-30"><a href="#cb3-30" aria-hidden="true" tabindex="-1"></a>            running_loss <span class="op">+=</span> loss.item()</span>
<span id="cb3-31"><a href="#cb3-31" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-32"><a href="#cb3-32" aria-hidden="true" tabindex="-1"></a>        val_loss <span class="op">=</span> validate(model, val_loader, criterion, device)</span>
<span id="cb3-33"><a href="#cb3-33" aria-hidden="true" tabindex="-1"></a>        scheduler.step(val_loss)</span>
<span id="cb3-34"><a href="#cb3-34" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(</span>
<span id="cb3-35"><a href="#cb3-35" aria-hidden="true" tabindex="-1"></a>            <span class="ss">f"Epoch [</span><span class="sc">{</span>epoch<span class="op">+</span><span class="dv">1</span><span class="sc">}</span><span class="ss">/</span><span class="sc">{</span>num_epochs<span class="sc">}</span><span class="ss">] | "</span></span>
<span id="cb3-36"><a href="#cb3-36" aria-hidden="true" tabindex="-1"></a>            <span class="ss">f"Train Loss: </span><span class="sc">{</span>running_loss<span class="op">/</span><span class="bu">len</span>(train_loader)<span class="sc">:.4f}</span><span class="ss"> | "</span></span>
<span id="cb3-37"><a href="#cb3-37" aria-hidden="true" tabindex="-1"></a>            <span class="ss">f"Val Loss: </span><span class="sc">{</span>val_loss<span class="sc">:.4f}</span><span class="ss">"</span></span>
<span id="cb3-38"><a href="#cb3-38" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb3-39"><a href="#cb3-39" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-40"><a href="#cb3-40" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-41"><a href="#cb3-41" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> validate(model, val_loader, criterion, device):</span>
<span id="cb3-42"><a href="#cb3-42" aria-hidden="true" tabindex="-1"></a>    model.<span class="bu">eval</span>()</span>
<span id="cb3-43"><a href="#cb3-43" aria-hidden="true" tabindex="-1"></a>    total_loss, correct_top1, correct_top5, total <span class="op">=</span> <span class="fl">0.0</span>, <span class="dv">0</span>, <span class="dv">0</span>, <span class="dv">0</span></span>
<span id="cb3-44"><a href="#cb3-44" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-45"><a href="#cb3-45" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> torch.no_grad():</span>
<span id="cb3-46"><a href="#cb3-46" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> images, labels <span class="kw">in</span> val_loader:</span>
<span id="cb3-47"><a href="#cb3-47" aria-hidden="true" tabindex="-1"></a>            images, labels <span class="op">=</span> images.to(device), labels.to(device)</span>
<span id="cb3-48"><a href="#cb3-48" aria-hidden="true" tabindex="-1"></a>            outputs <span class="op">=</span> model(images)</span>
<span id="cb3-49"><a href="#cb3-49" aria-hidden="true" tabindex="-1"></a>            total_loss <span class="op">+=</span> criterion(outputs, labels).item()</span>
<span id="cb3-50"><a href="#cb3-50" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-51"><a href="#cb3-51" aria-hidden="true" tabindex="-1"></a>            _, predicted <span class="op">=</span> outputs.<span class="bu">max</span>(<span class="dv">1</span>)</span>
<span id="cb3-52"><a href="#cb3-52" aria-hidden="true" tabindex="-1"></a>            correct_top1 <span class="op">+=</span> predicted.eq(labels).<span class="bu">sum</span>().item()</span>
<span id="cb3-53"><a href="#cb3-53" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-54"><a href="#cb3-54" aria-hidden="true" tabindex="-1"></a>            _, top5_preds <span class="op">=</span> outputs.topk(<span class="dv">5</span>, dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb3-55"><a href="#cb3-55" aria-hidden="true" tabindex="-1"></a>            correct_top5 <span class="op">+=</span> (</span>
<span id="cb3-56"><a href="#cb3-56" aria-hidden="true" tabindex="-1"></a>                top5_preds.eq(labels.view(<span class="op">-</span><span class="dv">1</span>, <span class="dv">1</span>)).<span class="bu">any</span>(dim<span class="op">=</span><span class="dv">1</span>).<span class="bu">sum</span>().item()</span>
<span id="cb3-57"><a href="#cb3-57" aria-hidden="true" tabindex="-1"></a>            )</span>
<span id="cb3-58"><a href="#cb3-58" aria-hidden="true" tabindex="-1"></a>            total <span class="op">+=</span> labels.size(<span class="dv">0</span>)</span>
<span id="cb3-59"><a href="#cb3-59" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-60"><a href="#cb3-60" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Top-1 Accuracy: </span><span class="sc">{</span><span class="fl">100.</span><span class="op">*</span>correct_top1<span class="op">/</span>total<span class="sc">:.2f}</span><span class="ss">%"</span>)</span>
<span id="cb3-61"><a href="#cb3-61" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Top-5 Accuracy: </span><span class="sc">{</span><span class="fl">100.</span><span class="op">*</span>correct_top5<span class="op">/</span>total<span class="sc">:.2f}</span><span class="ss">%"</span>)</span>
<span id="cb3-62"><a href="#cb3-62" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> total_loss <span class="op">/</span> <span class="bu">len</span>(val_loader)</span></code></pre></div></div>
</div>
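<p>The paper lowered the learning rate by a factor of 10 by hand whenever validation error stopped improving. <code>ReduceLROnPlateau(factor=0.1)</code> automates the same rule. One detail worth checking is exactly when it fires: the rate drops only once the count of consecutive non-improving epochs <em>exceeds</em> <code>patience</code>. The toy run below feeds the scheduler a flat validation loss to show this. The single dummy parameter is a stand-in for a real model.</p>

```python
import torch
import torch.optim as optim

# A single dummy parameter stands in for the AlexNet weights.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = optim.SGD([param], lr=0.01, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.1, patience=5
)

lr_history = []
for epoch in range(1, 14):
    scheduler.step(1.0)  # validation loss never improves after epoch 1
    lr_history.append(optimizer.param_groups[0]["lr"])

for epoch, lr in enumerate(lr_history, start=1):
    print(f"epoch {epoch:2d}: lr = {lr:.0e}")
```

<p>Epoch 1 sets the best loss. The rate then stays at 0.01 through epoch 6 and drops at epoch 7, after six flat epochs. In <code>train_alexnet</code>, <code>patience=5</code> therefore means six consecutive non-improving epochs before each reduction.</p>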
<hr>
</section>
</section>
<section id="comparison-with-successor-architectures" class="level2">
<h2 class="anchored" data-anchor-id="comparison-with-successor-architectures" id="comparison-with-successor-architectures">Comparison with Successor Architectures</h2>
<div id="tbl-comparison" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-tbl figure">
<figcaption class="quarto-float-caption-top quarto-float-caption quarto-float-tbl" id="tbl-comparison-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Table&nbsp;6: AlexNet vs.&nbsp;successor architectures on ImageNet
</figcaption>
<div aria-describedby="tbl-comparison-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<table class="caption-top table">
<colgroup>
<col style="width: 19%">
<col style="width: 6%">
<col style="width: 7%">
<col style="width: 14%">
<col style="width: 8%">
<col style="width: 42%">
</colgroup>
<thead>
<tr class="header">
<th>Architecture</th>
<th>Year</th>
<th>Depth</th>
<th>Top-5 Error</th>
<th>Params</th>
<th>Key Innovation</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>AlexNet</td>
<td>2012</td>
<td>8</td>
<td>15.3%</td>
<td>62.3M</td>
<td>ReLU, GPU, dropout at scale</td>
</tr>
<tr class="even">
<td>ZFNet</td>
<td>2013</td>
<td>8</td>
<td>11.7%</td>
<td>~62M</td>
<td>Visualization, architectural tuning</td>
</tr>
<tr class="odd">
<td>VGGNet-16</td>
<td>2014</td>
<td>16</td>
<td>7.3%</td>
<td>138M</td>
<td>Deep stacks of small 3×3 kernels</td>
</tr>
<tr class="even">
<td>VGGNet-19</td>
<td>2014</td>
<td>19</td>
<td>7.3%</td>
<td>144M</td>
<td>Even deeper stack of 3×3 kernels</td>
</tr>
<tr class="odd">
<td>GoogLeNet</td>
<td>2014</td>
<td>22</td>
<td>6.7%</td>
<td>6.8M</td>
<td>Inception modules, global avg pool</td>
</tr>
<tr class="even">
<td>ResNet-50</td>
<td>2015</td>
<td>50</td>
<td>5.2%</td>
<td>25.6M</td>
<td>Residual connections</td>
</tr>
<tr class="odd">
<td>ResNet-152</td>
<td>2015</td>
<td>152</td>
<td>3.6%</td>
<td>60.2M</td>
<td>Very deep residual networks</td>
</tr>
<tr class="even">
<td>DenseNet-201</td>
<td>2017</td>
<td>201</td>
<td>~5.5%</td>
<td>20M</td>
<td>Dense connections between all layers</td>
</tr>
<tr class="odd">
<td>EfficientNet-B7</td>
<td>2019</td>
<td>813</td>
<td>~2.9%</td>
<td>66M</td>
<td>Compound scaling</td>
</tr>
<tr class="even">
<td>ViT-L/16</td>
<td>2021</td>
<td>—</td>
<td>~1.7%</td>
<td>307M</td>
<td>Vision Transformer: self-attention over image patches</td>
</tr>
</tbody>
</table>
</div>
</figure>
</div>
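<p>This table illustrates the broad trajectory of the field since 2012: deeper networks, smaller parameter counts relative to depth, and dramatically lower error rates. The error figures come from each model's paper or competition entry. They mix single-model and ensemble results (AlexNet's own 15.3% came from an ensemble of several CNNs), and the ViT-L/16 result relies on pre-training with far more data than ImageNet. The numbers therefore show the trend rather than a like-for-like benchmark. AlexNet's 15.3% looks modest by today's standards. Its contribution was never that it remained the state of the art. It was the existence proof that launched everything that followed.</p>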
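<p>The 62.3M parameter figure for AlexNet can be checked by hand. The sketch below sums weights and biases layer by layer. It assumes the single-tower layout without the paper's grouped convolutions; the grouped, two-GPU original has somewhat fewer parameters, which is why the paper quotes about 60 million. The sketch also shows where the parameters live: roughly 94% sit in the three fully connected layers, which is one reason later designs replaced the large dense head.</p>

```python
# Parameter count of an AlexNet-style network (single-tower layout, no grouped
# convolutions). Each entry: (name, weights_per_output_unit, output_units).
LAYERS = [
    ("conv1", 11 * 11 * 3, 96),
    ("conv2", 5 * 5 * 96, 256),
    ("conv3", 3 * 3 * 256, 384),
    ("conv4", 3 * 3 * 384, 384),
    ("conv5", 3 * 3 * 384, 256),
    ("fc6", 6 * 6 * 256, 4096),
    ("fc7", 4096, 4096),
    ("fc8", 4096, 1000),
]


def layer_params(fan_in, units):
    """Weights (fan_in per output unit) plus one bias per output unit."""
    return fan_in * units + units


counts = {name: layer_params(fan_in, units) for name, fan_in, units in LAYERS}
total = sum(counts.values())
fc_total = sum(v for k, v in counts.items() if k.startswith("fc"))

for name, n in counts.items():
    print(f"{name:6s} {n:>12,}")
print(f"{'total':6s} {total:>12,}")
print(f"share in fully connected layers: {fc_total / total:.1%}")
```

<p>The total comes to 62,378,344, the figure usually quoted as 62.3M. Of that, fc6 alone contributes 37,752,832.</p>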
<hr>
</section>
<section id="summary" class="level2">
<h2 class="anchored" data-anchor-id="summary" id="summary">Summary</h2>
<p>AlexNet did more than win a competition. It changed what researchers, engineers, and technology companies believed was possible with artificial intelligence. After ILSVRC-2012, hand-crafted feature pipelines quickly gave way to deep networks trained end to end, and the deep learning era began in earnest. AlexNet is both a historical artifact and a still-useful lesson in how to solve hard problems. Its contributions can be distilled as follows:</p>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>What AlexNet Got Right
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li>Deep, hierarchical feature learning from raw pixels works better than hand-crafted features at scale.</li>
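<li>ReLU activations train much faster than saturating ones such as tanh. The paper reports that a four-layer CNN with ReLUs reached 25% training error on CIFAR-10 six times faster than an equivalent tanh network.</li>
<li>Strong regularization (dropout plus data augmentation) is needed to generalize. Even 1.2 million training images were not enough to keep a 60-million-parameter model from overfitting without it.</li>
<li>GPU computation made large-scale deep learning feasible. Training took five to six days on two NVIDIA GTX 580 3GB GPUs.</li>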
<li>The right combination of architecture, regularization, and hardware can produce qualitatively transformative results.</li>
</ul>
</div>
</div>
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Warning</span>What Has Been Superseded
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li>Local Response Normalization → replaced by Batch Normalization.</li>
<li>Large fully connected heads → replaced by global average pooling.</li>
<li>Large first-layer kernels → replaced by stacks of small 3×3 kernels.</li>
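<li>Shallow depth (8 layers) → networks now commonly use tens to hundreds of layers, and residual networks have been trained with more than 1,000.</li>
<li>Dual-GPU model parallelism, a workaround for the 3 GB memory limit of the GTX 580 → unnecessary now that a single modern GPU easily holds the whole model.</li>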
</ul>
</div>
</div>



</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Training Computer Vision Models and Running Them with ONNX Runtime]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/deployment/onnx-cv/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/deployment/onnx-cv/</guid>
      <pubDate>Tue, 19 May 2026 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>mlops</category>
      <category>advanced</category>
      <content:encoded><![CDATA[






<section id="training-computer-vision-models-and-running-them-with-onnx-runtime" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/deployment/onnx-cv/onnx.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>Computer vision is one of the most active areas of applied machine learning. Whether you are building an image classifier, a real-time object detector, a segmentation model, or a pose estimator, the question after training is usually the same: how do you deploy the model efficiently across diverse hardware (cloud GPUs, edge CPUs, mobile SoCs, FPGAs, or browsers) without rewriting your inference code for every target?</p>
<p><strong>ONNX (Open Neural Network Exchange)</strong> and <strong>ONNX Runtime</strong> solve this problem. ONNX provides a standardized intermediate representation for neural network computation graphs, while ONNX Runtime is a high-performance inference engine that executes those graphs across a wide range of hardware backends.</p>
<p>This guide walks you through the entire lifecycle: training a vision model in PyTorch or TensorFlow, exporting it to the ONNX format, validating and optimizing the exported graph, running production-grade inference with ONNX Runtime, and deploying to various targets. By the end, you will have a reliable, reproducible workflow you can apply to nearly any computer vision project.</p>
<hr>
</section>
<section id="what-is-onnx" class="level2">
<h2 class="anchored" data-anchor-id="what-is-onnx" id="what-is-onnx">What is ONNX?</h2>
<p>ONNX is an open standard created jointly by Microsoft and Facebook (now Meta) in 2017. It is now a graduated project of the LF AI &amp; Data Foundation, part of the Linux Foundation, and is developed by the open ONNX community. Its core purpose is to let a model trained in one framework run in another framework or in a dedicated inference engine.</p>
<p>At its heart, an ONNX model is a <strong>protobuf-serialized computation graph</strong>. Each node in the graph corresponds to a mathematical operator (Conv, BatchNormalization, Relu, MaxPool, Gemm, etc.), and edges are typed tensors that flow between them.</p>
<p>Key concepts:</p>
<ul>
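<li><strong>Opset version</strong>: ONNX defines its operators in versioned <em>opsets</em>, and new ONNX releases add new opsets regularly (ONNX 1.17, released in 2024, already introduced opset 22). Export with the highest opset that both your exporter and every target runtime support. Newer opsets add operators and refine their semantics, but a model whose opset is newer than the runtime understands will fail to load.</li>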
<li><strong>IR version</strong>: The overall file format version, independent of opset.</li>
<li><strong>Initializers</strong>: Constant tensors (model weights) stored inside the graph. Models larger than the 2 GB protobuf limit store them in separate external data files instead.</li>
<li><strong>Dynamic shapes</strong>: Axes can be marked symbolic (e.g., <code>batch_size</code>, <code>height</code>) to allow variable-size inputs at runtime.</li>
</ul>
<hr>
</section>
<section id="prerequisites-and-environment-setup" class="level2">
<h2 class="anchored" data-anchor-id="prerequisites-and-environment-setup" id="prerequisites-and-environment-setup">Prerequisites and Environment Setup</h2>
<section id="python-environment" class="level3">
<h3 class="anchored" data-anchor-id="python-environment" id="python-environment">Python Environment</h3>
<p>It is best practice to use a virtual environment or conda environment per project.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Using conda</span></span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="ex">conda</span> create <span class="at">-n</span> cv-onnx python=3.11</span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a><span class="ex">conda</span> activate cv-onnx</span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Or using venv</span></span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a><span class="ex">python</span> <span class="at">-m</span> venv .venv</span>
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a><span class="bu">source</span> .venv/bin/activate   <span class="co"># Linux / macOS</span></span>
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a><span class="ex">.venv\Scripts\activate</span>      <span class="co"># Windows</span></span></code></pre></div></div>
</section>
<section id="core-packages" class="level3">
<h3 class="anchored" data-anchor-id="core-packages" id="core-packages">Core Packages</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb2"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Deep learning framework (choose one or both)</span></span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install torch torchvision torchaudio <span class="at">--index-url</span> https://download.pytorch.org/whl/cu121</span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install tensorflow tf2onnx   <span class="co"># tf2onnx converts TensorFlow/Keras models to ONNX</span></span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a><span class="co"># ONNX core</span></span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install onnx onnxscript</span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a><span class="co"># ONNX Runtime: install exactly ONE of the two packages below</span></span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install onnxruntime       <span class="co"># CPU only</span></span>
<span id="cb2-10"><a href="#cb2-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-11"><a href="#cb2-11" aria-hidden="true" tabindex="-1"></a><span class="co"># ...or, for NVIDIA GPUs (CUDA 12.x), this one instead</span></span>
<span id="cb2-12"><a href="#cb2-12" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install onnxruntime-gpu</span>
<span id="cb2-13"><a href="#cb2-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-14"><a href="#cb2-14" aria-hidden="true" tabindex="-1"></a><span class="co"># Graph optimization (quantization already ships inside onnxruntime as</span></span>
<span id="cb2-15"><a href="#cb2-15" aria-hidden="true" tabindex="-1"></a><span class="co"># onnxruntime.quantization; the old onnxruntime-tools package is deprecated)</span></span>
<span id="cb2-16"><a href="#cb2-16" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install onnxoptimizer</span>
<span id="cb2-17"><a href="#cb2-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-18"><a href="#cb2-18" aria-hidden="true" tabindex="-1"></a><span class="co"># Visualization</span></span>
<span id="cb2-19"><a href="#cb2-19" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install netron   <span class="co"># or open https://netron.app in a browser</span></span>
<span id="cb2-20"><a href="#cb2-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-21"><a href="#cb2-21" aria-hidden="true" tabindex="-1"></a><span class="co"># Utilities</span></span>
<span id="cb2-22"><a href="#cb2-22" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install numpy pillow opencv-python-headless matplotlib</span></code></pre></div></div>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Note
</div>
</div>
<div class="callout-body-container callout-body">
<p><code>onnxruntime</code> and <code>onnxruntime-gpu</code> are mutually exclusive packages. Install only one per environment. The GPU package automatically falls back to CPU when CUDA is unavailable.</p>
</div>
</div>
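<p>Both packages install the same <code>onnxruntime</code> Python module, so when both are present, whichever was installed last is the one that runs. A quick check with the standard library's <code>importlib.metadata</code> can catch this before it causes confusing provider errors. The helper name and the list of distribution names below are illustrative assumptions; extend the list if you use other ONNX Runtime builds.</p>

```python
from importlib import metadata

# Distribution names to look for; extend if you use other ORT builds.
ORT_DISTRIBUTIONS = ("onnxruntime", "onnxruntime-gpu")


def installed_ort_distributions(candidates=ORT_DISTRIBUTIONS):
    """Return (name, version) for each candidate distribution that is installed."""
    found = []
    for name in candidates:
        try:
            found.append((name, metadata.version(name)))
        except metadata.PackageNotFoundError:
            continue  # not installed in this environment
    return found


if __name__ == "__main__":
    found = installed_ort_distributions()
    if not found:
        print("No ONNX Runtime package found.")
    elif len(found) == 1:
        print("OK:", found[0])
    else:
        print("Conflict: more than one ONNX Runtime package installed:", found)
        print("Uninstall all of them, then reinstall exactly one.")
```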
</section>
<section id="verifying-the-installation" class="level3">
<h3 class="anchored" data-anchor-id="verifying-the-installation" id="verifying-the-installation">Verifying the Installation</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> onnx</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> onnxruntime <span class="im">as</span> ort</span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"ONNX version:           </span><span class="sc">{</span>onnx<span class="sc">.</span>__version__<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"ONNX Runtime version:   </span><span class="sc">{</span>ort<span class="sc">.</span>__version__<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"PyTorch version:        </span><span class="sc">{</span>torch<span class="sc">.</span>__version__<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Available ORT providers: </span><span class="sc">{</span>ort<span class="sc">.</span>get_available_providers()<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
<hr>
</section>
</section>
<section id="understanding-the-onnx-ecosystem" class="level2">
<h2 class="anchored" data-anchor-id="understanding-the-onnx-ecosystem" id="understanding-the-onnx-ecosystem">Understanding the ONNX Ecosystem</h2>
<p>Before diving into code, it helps to understand how the different components fit together. Training frameworks export to the ONNX intermediate representation, which is then consumed by ONNX Runtime or converted into other deployment backends.</p>
<div class="cell" data-eval="true" data-layout-align="center">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">%%{init: {
  "theme": "base"
}}%%

flowchart LR
    subgraph TRAIN["Training Frameworks"]
        direction TB
        PT["PyTorch"]
        TF["TensorFlow / Keras"]
        SK["scikit-learn"]
        JAX["JAX / Flax"]
    end

    subgraph ONNX_CORE["ONNX Ecosystem"]
        direction TB
        MODEL["ONNX Model (.onnx)"]
        OPT["Optimizer / Quantizer"]
        MODEL -- "graph opt + quantize" --&gt; OPT
        OPT -- "optimized model" --&gt; MODEL
    end

    subgraph DEPLOY["Deployment Targets"]
        direction TB
        ORT["ONNX Runtime Python · C++ · C# · Java"]
        WEB["ONNX Runtime Web WASM · WebGL"]
        TRT["TensorRT (NVIDIA)"]
        OVI["OpenVINO (Intel)"]
        CML["CoreML (Apple)"]
        WML["Windows ML"]
    end

    PT  -- export --&gt; MODEL
    TF  -- export --&gt; MODEL
    SK  -- export --&gt; MODEL
    JAX -- export --&gt; MODEL

    MODEL --&gt; ORT
    MODEL --&gt; WEB
    MODEL --&gt; TRT
    MODEL --&gt; OVI
    MODEL --&gt; CML
    MODEL --&gt; WML
</pre>
</div>
<p></p></figure><p></p>
</div>
</div>
</div>
<p>The <strong>ONNX Runtime (ORT)</strong> sits at the center of the deployment story. It supports multiple <strong>Execution Providers (EPs)</strong>:</p>
<table class="caption-top table">
<thead>
<tr class="header">
<th>Execution Provider</th>
<th>Hardware Target</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><code>CPUExecutionProvider</code></td>
<td>Any x86/ARM CPU</td>
</tr>
<tr class="even">
<td><code>CUDAExecutionProvider</code></td>
<td>NVIDIA GPUs (CUDA)</td>
</tr>
<tr class="odd">
<td><code>TensorrtExecutionProvider</code></td>
<td>NVIDIA GPUs (TensorRT)</td>
</tr>
<tr class="even">
<td><code>ROCMExecutionProvider</code></td>
<td>AMD GPUs</td>
</tr>
<tr class="odd">
<td><code>CoreMLExecutionProvider</code></td>
<td>Apple Silicon / iOS</td>
</tr>
<tr class="even">
<td><code>DirectMLExecutionProvider</code></td>
<td>Windows GPU via DirectML</td>
</tr>
<tr class="odd">
<td><code>OpenVINOExecutionProvider</code></td>
<td>Intel CPUs, iGPUs, VPUs</td>
</tr>
<tr class="even">
<td><code>QNNExecutionProvider</code></td>
<td>Qualcomm NPU</td>
</tr>
</tbody>
</table>
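<p>ONNX Runtime treats the provider list passed to <code>InferenceSession</code> as a priority order. A common pattern is to list the fastest providers you hope to use and filter that list against what the installed build actually offers. The sketch below shows this pattern; the preference order and the helper name <code>select_providers</code> are assumptions, not anything ONNX Runtime prescribes.</p>

```python
# Preferred execution providers, fastest first (an illustrative choice).
PREFERRED_PROVIDERS = (
    "TensorrtExecutionProvider",
    "CUDAExecutionProvider",
    "CoreMLExecutionProvider",
    "CPUExecutionProvider",
)


def select_providers(available, preferred=PREFERRED_PROVIDERS):
    """Keep preferred providers that are available, always ending with CPU."""
    chosen = [p for p in preferred if p in available]
    if "CPUExecutionProvider" not in chosen:
        chosen.append("CPUExecutionProvider")
    return chosen


if __name__ == "__main__":
    try:
        import onnxruntime as ort
    except ImportError:
        print("onnxruntime is not installed")
    else:
        providers = select_providers(ort.get_available_providers())
        print("Using providers:", providers)
        # "model.onnx" is a placeholder path:
        # session = ort.InferenceSession("model.onnx", providers=providers)
```

<p>Keeping the CPU provider last gives every operator a fallback, because nodes that the accelerated providers cannot handle are assigned to it.</p>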
<hr>
</section>
<section id="training-a-computer-vision-model" class="level2">
<h2 class="anchored" data-anchor-id="training-a-computer-vision-model" id="training-a-computer-vision-model">Training a Computer Vision Model</h2>
<section id="pytorch-workflow" class="level3">
<h3 class="anchored" data-anchor-id="pytorch-workflow" id="pytorch-workflow">PyTorch Workflow</h3>
<p>We will train a simple ResNet-18-based image classifier on CIFAR-10 as a concrete example. The principles generalize to any architecture.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="co"># train_cifar_pytorch.py</span></span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.optim <span class="im">as</span> optim</span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torchvision <span class="im">import</span> datasets, transforms, models</span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.utils.data <span class="im">import</span> DataLoader</span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a><span class="co"># ──────────────────────────────────────────────────────────────</span></span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a><span class="co"># 1. Hyperparameters</span></span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a><span class="co"># ──────────────────────────────────────────────────────────────</span></span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a>BATCH_SIZE   <span class="op">=</span> <span class="dv">128</span></span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a>NUM_EPOCHS   <span class="op">=</span> <span class="dv">20</span></span>
<span id="cb4-14"><a href="#cb4-14" aria-hidden="true" tabindex="-1"></a>LEARNING_RATE <span class="op">=</span> <span class="fl">1e-3</span></span>
<span id="cb4-15"><a href="#cb4-15" aria-hidden="true" tabindex="-1"></a>NUM_CLASSES  <span class="op">=</span> <span class="dv">10</span></span>
<span id="cb4-16"><a href="#cb4-16" aria-hidden="true" tabindex="-1"></a>DEVICE       <span class="op">=</span> torch.device(<span class="st">"cuda"</span> <span class="cf">if</span> torch.cuda.is_available() <span class="cf">else</span> <span class="st">"cpu"</span>)</span>
<span id="cb4-17"><a href="#cb4-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-18"><a href="#cb4-18" aria-hidden="true" tabindex="-1"></a><span class="co"># ──────────────────────────────────────────────────────────────</span></span>
<span id="cb4-19"><a href="#cb4-19" aria-hidden="true" tabindex="-1"></a><span class="co"># 2. Data pipeline</span></span>
<span id="cb4-20"><a href="#cb4-20" aria-hidden="true" tabindex="-1"></a><span class="co"># ──────────────────────────────────────────────────────────────</span></span>
<span id="cb4-21"><a href="#cb4-21" aria-hidden="true" tabindex="-1"></a>train_transform <span class="op">=</span> transforms.Compose([</span>
<span id="cb4-22"><a href="#cb4-22" aria-hidden="true" tabindex="-1"></a>    transforms.RandomHorizontalFlip(),</span>
<span id="cb4-23"><a href="#cb4-23" aria-hidden="true" tabindex="-1"></a>    transforms.RandomCrop(<span class="dv">32</span>, padding<span class="op">=</span><span class="dv">4</span>),</span>
<span id="cb4-24"><a href="#cb4-24" aria-hidden="true" tabindex="-1"></a>    transforms.ToTensor(),</span>
<span id="cb4-25"><a href="#cb4-25" aria-hidden="true" tabindex="-1"></a>    transforms.Normalize(mean<span class="op">=</span>[<span class="fl">0.4914</span>, <span class="fl">0.4822</span>, <span class="fl">0.4465</span>],</span>
<span id="cb4-26"><a href="#cb4-26" aria-hidden="true" tabindex="-1"></a>                         std<span class="op">=</span> [<span class="fl">0.2023</span>, <span class="fl">0.1994</span>, <span class="fl">0.2010</span>]),</span>
<span id="cb4-27"><a href="#cb4-27" aria-hidden="true" tabindex="-1"></a>])</span>
<span id="cb4-28"><a href="#cb4-28" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-29"><a href="#cb4-29" aria-hidden="true" tabindex="-1"></a>val_transform <span class="op">=</span> transforms.Compose([</span>
<span id="cb4-30"><a href="#cb4-30" aria-hidden="true" tabindex="-1"></a>    transforms.ToTensor(),</span>
<span id="cb4-31"><a href="#cb4-31" aria-hidden="true" tabindex="-1"></a>    transforms.Normalize(mean<span class="op">=</span>[<span class="fl">0.4914</span>, <span class="fl">0.4822</span>, <span class="fl">0.4465</span>],</span>
<span id="cb4-32"><a href="#cb4-32" aria-hidden="true" tabindex="-1"></a>                         std<span class="op">=</span> [<span class="fl">0.2023</span>, <span class="fl">0.1994</span>, <span class="fl">0.2010</span>]),</span>
<span id="cb4-33"><a href="#cb4-33" aria-hidden="true" tabindex="-1"></a>])</span>
<span id="cb4-34"><a href="#cb4-34" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-35"><a href="#cb4-35" aria-hidden="true" tabindex="-1"></a>train_dataset <span class="op">=</span> datasets.CIFAR10(root<span class="op">=</span><span class="st">"./data"</span>, train<span class="op">=</span><span class="va">True</span>,</span>
<span id="cb4-36"><a href="#cb4-36" aria-hidden="true" tabindex="-1"></a>                                  download<span class="op">=</span><span class="va">True</span>, transform<span class="op">=</span>train_transform)</span>
<span id="cb4-37"><a href="#cb4-37" aria-hidden="true" tabindex="-1"></a>val_dataset   <span class="op">=</span> datasets.CIFAR10(root<span class="op">=</span><span class="st">"./data"</span>, train<span class="op">=</span><span class="va">False</span>,</span>
<span id="cb4-38"><a href="#cb4-38" aria-hidden="true" tabindex="-1"></a>                                  download<span class="op">=</span><span class="va">True</span>, transform<span class="op">=</span>val_transform)  <span class="co"># test split doubles as validation set</span></span>
<span id="cb4-39"><a href="#cb4-39" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-40"><a href="#cb4-40" aria-hidden="true" tabindex="-1"></a>train_loader <span class="op">=</span> DataLoader(train_dataset, batch_size<span class="op">=</span>BATCH_SIZE,</span>
<span id="cb4-41"><a href="#cb4-41" aria-hidden="true" tabindex="-1"></a>                          shuffle<span class="op">=</span><span class="va">True</span>,  num_workers<span class="op">=</span><span class="dv">4</span>, pin_memory<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb4-42"><a href="#cb4-42" aria-hidden="true" tabindex="-1"></a>val_loader   <span class="op">=</span> DataLoader(val_dataset,   batch_size<span class="op">=</span>BATCH_SIZE,</span>
<span id="cb4-43"><a href="#cb4-43" aria-hidden="true" tabindex="-1"></a>                          shuffle<span class="op">=</span><span class="va">False</span>, num_workers<span class="op">=</span><span class="dv">4</span>, pin_memory<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb4-44"><a href="#cb4-44" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-45"><a href="#cb4-45" aria-hidden="true" tabindex="-1"></a><span class="co"># ──────────────────────────────────────────────────────────────</span></span>
<span id="cb4-46"><a href="#cb4-46" aria-hidden="true" tabindex="-1"></a><span class="co"># 3. Model definition</span></span>
<span id="cb4-47"><a href="#cb4-47" aria-hidden="true" tabindex="-1"></a><span class="co"># ──────────────────────────────────────────────────────────────</span></span>
<span id="cb4-48"><a href="#cb4-48" aria-hidden="true" tabindex="-1"></a><span class="co"># ResNet-18 adapted for CIFAR-10's 32×32 inputs</span></span>
<span id="cb4-49"><a href="#cb4-49" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> models.resnet18(weights<span class="op">=</span><span class="va">None</span>)</span>
<span id="cb4-50"><a href="#cb4-50" aria-hidden="true" tabindex="-1"></a><span class="co"># Replace the first conv to handle small spatial dimensions</span></span>
<span id="cb4-51"><a href="#cb4-51" aria-hidden="true" tabindex="-1"></a>model.conv1 <span class="op">=</span> nn.Conv2d(<span class="dv">3</span>, <span class="dv">64</span>, kernel_size<span class="op">=</span><span class="dv">3</span>, stride<span class="op">=</span><span class="dv">1</span>, padding<span class="op">=</span><span class="dv">1</span>, bias<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb4-52"><a href="#cb4-52" aria-hidden="true" tabindex="-1"></a>model.maxpool <span class="op">=</span> nn.Identity()   <span class="co"># remove aggressive spatial downsampling</span></span>
<span id="cb4-53"><a href="#cb4-53" aria-hidden="true" tabindex="-1"></a>model.fc <span class="op">=</span> nn.Linear(model.fc.in_features, NUM_CLASSES)</span>
<span id="cb4-54"><a href="#cb4-54" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> model.to(DEVICE)</span>
<span id="cb4-55"><a href="#cb4-55" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-56"><a href="#cb4-56" aria-hidden="true" tabindex="-1"></a><span class="co"># ──────────────────────────────────────────────────────────────</span></span>
<span id="cb4-57"><a href="#cb4-57" aria-hidden="true" tabindex="-1"></a><span class="co"># 4. Loss, optimizer, scheduler</span></span>
<span id="cb4-58"><a href="#cb4-58" aria-hidden="true" tabindex="-1"></a><span class="co"># ──────────────────────────────────────────────────────────────</span></span>
<span id="cb4-59"><a href="#cb4-59" aria-hidden="true" tabindex="-1"></a>criterion <span class="op">=</span> nn.CrossEntropyLoss(label_smoothing<span class="op">=</span><span class="fl">0.1</span>)</span>
<span id="cb4-60"><a href="#cb4-60" aria-hidden="true" tabindex="-1"></a>optimizer <span class="op">=</span> optim.AdamW(model.parameters(), lr<span class="op">=</span>LEARNING_RATE, weight_decay<span class="op">=</span><span class="fl">1e-4</span>)</span>
<span id="cb4-61"><a href="#cb4-61" aria-hidden="true" tabindex="-1"></a>scheduler <span class="op">=</span> optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max<span class="op">=</span>NUM_EPOCHS)</span>
<span id="cb4-62"><a href="#cb4-62" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-63"><a href="#cb4-63" aria-hidden="true" tabindex="-1"></a><span class="co"># ──────────────────────────────────────────────────────────────</span></span>
<span id="cb4-64"><a href="#cb4-64" aria-hidden="true" tabindex="-1"></a><span class="co"># 5. Training loop</span></span>
<span id="cb4-65"><a href="#cb4-65" aria-hidden="true" tabindex="-1"></a><span class="co"># ──────────────────────────────────────────────────────────────</span></span>
<span id="cb4-66"><a href="#cb4-66" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> train_one_epoch(model, loader, criterion, optimizer, device):</span>
<span id="cb4-67"><a href="#cb4-67" aria-hidden="true" tabindex="-1"></a>    model.train()</span>
<span id="cb4-68"><a href="#cb4-68" aria-hidden="true" tabindex="-1"></a>    running_loss, correct, total <span class="op">=</span> <span class="fl">0.0</span>, <span class="dv">0</span>, <span class="dv">0</span></span>
<span id="cb4-69"><a href="#cb4-69" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> images, labels <span class="kw">in</span> loader:</span>
<span id="cb4-70"><a href="#cb4-70" aria-hidden="true" tabindex="-1"></a>        images, labels <span class="op">=</span> images.to(device), labels.to(device)</span>
<span id="cb4-71"><a href="#cb4-71" aria-hidden="true" tabindex="-1"></a>        optimizer.zero_grad(set_to_none<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb4-72"><a href="#cb4-72" aria-hidden="true" tabindex="-1"></a>        outputs <span class="op">=</span> model(images)</span>
<span id="cb4-73"><a href="#cb4-73" aria-hidden="true" tabindex="-1"></a>        loss <span class="op">=</span> criterion(outputs, labels)</span>
<span id="cb4-74"><a href="#cb4-74" aria-hidden="true" tabindex="-1"></a>        loss.backward()</span>
<span id="cb4-75"><a href="#cb4-75" aria-hidden="true" tabindex="-1"></a>        optimizer.step()</span>
<span id="cb4-76"><a href="#cb4-76" aria-hidden="true" tabindex="-1"></a>        running_loss <span class="op">+=</span> loss.item() <span class="op">*</span> images.size(<span class="dv">0</span>)</span>
<span id="cb4-77"><a href="#cb4-77" aria-hidden="true" tabindex="-1"></a>        correct <span class="op">+=</span> (outputs.argmax(<span class="dv">1</span>) <span class="op">==</span> labels).<span class="bu">sum</span>().item()</span>
<span id="cb4-78"><a href="#cb4-78" aria-hidden="true" tabindex="-1"></a>        total   <span class="op">+=</span> images.size(<span class="dv">0</span>)</span>
<span id="cb4-79"><a href="#cb4-79" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> running_loss <span class="op">/</span> total, correct <span class="op">/</span> total</span>
<span id="cb4-80"><a href="#cb4-80" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-81"><a href="#cb4-81" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-82"><a href="#cb4-82" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> evaluate(model, loader, criterion, device):</span>
<span id="cb4-83"><a href="#cb4-83" aria-hidden="true" tabindex="-1"></a>    model.<span class="bu">eval</span>()</span>
<span id="cb4-84"><a href="#cb4-84" aria-hidden="true" tabindex="-1"></a>    running_loss, correct, total <span class="op">=</span> <span class="fl">0.0</span>, <span class="dv">0</span>, <span class="dv">0</span></span>
<span id="cb4-85"><a href="#cb4-85" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> torch.no_grad():</span>
<span id="cb4-86"><a href="#cb4-86" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> images, labels <span class="kw">in</span> loader:</span>
<span id="cb4-87"><a href="#cb4-87" aria-hidden="true" tabindex="-1"></a>            images, labels <span class="op">=</span> images.to(device), labels.to(device)</span>
<span id="cb4-88"><a href="#cb4-88" aria-hidden="true" tabindex="-1"></a>            outputs <span class="op">=</span> model(images)</span>
<span id="cb4-89"><a href="#cb4-89" aria-hidden="true" tabindex="-1"></a>            loss <span class="op">=</span> criterion(outputs, labels)</span>
<span id="cb4-90"><a href="#cb4-90" aria-hidden="true" tabindex="-1"></a>            running_loss <span class="op">+=</span> loss.item() <span class="op">*</span> images.size(<span class="dv">0</span>)</span>
<span id="cb4-91"><a href="#cb4-91" aria-hidden="true" tabindex="-1"></a>            correct <span class="op">+=</span> (outputs.argmax(<span class="dv">1</span>) <span class="op">==</span> labels).<span class="bu">sum</span>().item()</span>
<span id="cb4-92"><a href="#cb4-92" aria-hidden="true" tabindex="-1"></a>            total   <span class="op">+=</span> images.size(<span class="dv">0</span>)</span>
<span id="cb4-93"><a href="#cb4-93" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> running_loss <span class="op">/</span> total, correct <span class="op">/</span> total</span>
<span id="cb4-94"><a href="#cb4-94" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-95"><a href="#cb4-95" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:  <span class="co"># required when DataLoader workers use the spawn start method (e.g. Windows, macOS)</span></span>
<span id="cb4-96"><a href="#cb4-96" aria-hidden="true" tabindex="-1"></a>    best_val_acc <span class="op">=</span> <span class="fl">0.0</span></span>
<span id="cb4-97"><a href="#cb4-97" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">1</span>, NUM_EPOCHS <span class="op">+</span> <span class="dv">1</span>):</span>
<span id="cb4-98"><a href="#cb4-98" aria-hidden="true" tabindex="-1"></a>        train_loss, train_acc <span class="op">=</span> train_one_epoch(model, train_loader,</span>
<span id="cb4-99"><a href="#cb4-99" aria-hidden="true" tabindex="-1"></a>                                                criterion, optimizer, DEVICE)</span>
<span id="cb4-100"><a href="#cb4-100" aria-hidden="true" tabindex="-1"></a>        val_loss,   val_acc   <span class="op">=</span> evaluate(model, val_loader, criterion, DEVICE)</span>
<span id="cb4-101"><a href="#cb4-101" aria-hidden="true" tabindex="-1"></a>        scheduler.step()</span>
<span id="cb4-102"><a href="#cb4-102" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-103"><a href="#cb4-103" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Epoch </span><span class="sc">{</span>epoch<span class="sc">:02d}</span><span class="ss">/</span><span class="sc">{</span>NUM_EPOCHS<span class="sc">}</span><span class="ss"> | "</span></span>
<span id="cb4-104"><a href="#cb4-104" aria-hidden="true" tabindex="-1"></a>              <span class="ss">f"Train Loss: </span><span class="sc">{</span>train_loss<span class="sc">:.4f}</span><span class="ss"> Acc: </span><span class="sc">{</span>train_acc<span class="sc">:.4f}</span><span class="ss"> | "</span></span>
<span id="cb4-105"><a href="#cb4-105" aria-hidden="true" tabindex="-1"></a>              <span class="ss">f"Val Loss: </span><span class="sc">{</span>val_loss<span class="sc">:.4f}</span><span class="ss"> Acc: </span><span class="sc">{</span>val_acc<span class="sc">:.4f}</span><span class="ss">"</span>)</span>
<span id="cb4-106"><a href="#cb4-106" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-107"><a href="#cb4-107" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> val_acc <span class="op">&gt;</span> best_val_acc:</span>
<span id="cb4-108"><a href="#cb4-108" aria-hidden="true" tabindex="-1"></a>            best_val_acc <span class="op">=</span> val_acc</span>
<span id="cb4-109"><a href="#cb4-109" aria-hidden="true" tabindex="-1"></a>            torch.save(model.state_dict(), <span class="st">"best_resnet18_cifar10.pth"</span>)</span>
<span id="cb4-110"><a href="#cb4-110" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-111"><a href="#cb4-111" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Best validation accuracy: </span><span class="sc">{</span>best_val_acc<span class="sc">:.4f}</span><span class="ss">"</span>)</span></code></pre></div></div>
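<p>The criterion in this script uses <code>label_smoothing=0.1</code>. PyTorch implements this by replacing each one-hot target with a mixture of the one-hot vector and a uniform distribution over the <em>K</em> classes: (1 − ε)·one_hot + ε/K. The pure-Python sketch below computes the resulting target for CIFAR-10 (K = 10, ε = 0.1). <code>smoothed_target</code> is a hypothetical helper, not a PyTorch function:</p>

```python
# Sketch: the soft target distribution produced by label smoothing.
# smoothed_target is a hypothetical helper for illustration, not a PyTorch API.
def smoothed_target(true_class, num_classes=10, epsilon=0.1):
    """Mix a one-hot vector with a uniform distribution:
    (1 - epsilon) * one_hot + epsilon / num_classes."""
    uniform = epsilon / num_classes
    target = [uniform] * num_classes      # every class gets epsilon / K
    target[true_class] += 1.0 - epsilon   # the true class also keeps 1 - epsilon
    return target


t = smoothed_target(true_class=3)
print(round(t[3], 4), round(t[0], 4), round(sum(t), 4))
# 0.91 0.01 1.0
```

<p>The true class is never asked to reach probability 1.0, so the loss stops rewarding ever-larger logits for it. This is the usual motivation for smoothing.</p>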
</section>
<section id="tensorflow-keras-workflow" class="level3">
<h3 class="anchored" data-anchor-id="tensorflow-keras-workflow" id="tensorflow-keras-workflow">TensorFlow / Keras Workflow</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="co"># train_cifar_tf.py</span></span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> tensorflow <span class="im">as</span> tf</span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> tensorflow.keras <span class="im">import</span> layers, models, optimizers, callbacks</span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a><span class="co"># ──────────────────────────────────────────────────────────────</span></span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a><span class="co"># 1. Load and preprocess data</span></span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a><span class="co"># ──────────────────────────────────────────────────────────────</span></span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a>(x_train, y_train), (x_test, y_test) <span class="op">=</span> tf.keras.datasets.cifar10.load_data()</span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a><span class="co"># Keras EfficientNet models include their own Rescaling/Normalization layers</span></span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a><span class="co"># and expect float pixels in [0, 255], so only the dtype is converted here</span></span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a>x_train <span class="op">=</span> x_train.astype(<span class="st">"float32"</span>)</span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a>x_test  <span class="op">=</span> x_test.astype(<span class="st">"float32"</span>)</span>
<span id="cb5-20"><a href="#cb5-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-21"><a href="#cb5-21" aria-hidden="true" tabindex="-1"></a><span class="co"># ──────────────────────────────────────────────────────────────</span></span>
<span id="cb5-22"><a href="#cb5-22" aria-hidden="true" tabindex="-1"></a><span class="co"># 2. Model: EfficientNetB0 with custom head</span></span>
<span id="cb5-23"><a href="#cb5-23" aria-hidden="true" tabindex="-1"></a><span class="co"># ──────────────────────────────────────────────────────────────</span></span>
<span id="cb5-24"><a href="#cb5-24" aria-hidden="true" tabindex="-1"></a>base <span class="op">=</span> tf.keras.applications.EfficientNetB0(</span>
<span id="cb5-25"><a href="#cb5-25" aria-hidden="true" tabindex="-1"></a>    include_top<span class="op">=</span><span class="va">False</span>,</span>
<span id="cb5-26"><a href="#cb5-26" aria-hidden="true" tabindex="-1"></a>    weights<span class="op">=</span><span class="va">None</span>,</span>
<span id="cb5-27"><a href="#cb5-27" aria-hidden="true" tabindex="-1"></a>    input_shape<span class="op">=</span>(<span class="dv">32</span>, <span class="dv">32</span>, <span class="dv">3</span>),</span>
<span id="cb5-28"><a href="#cb5-28" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb5-29"><a href="#cb5-29" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-30"><a href="#cb5-30" aria-hidden="true" tabindex="-1"></a>inputs <span class="op">=</span> tf.keras.Input(shape<span class="op">=</span>(<span class="dv">32</span>, <span class="dv">32</span>, <span class="dv">3</span>))</span>
<span id="cb5-31"><a href="#cb5-31" aria-hidden="true" tabindex="-1"></a>x <span class="op">=</span> base(inputs)  <span class="co"># don't hard-code training=True: BatchNorm/Dropout must switch to inference mode in evaluate/predict</span></span>
<span id="cb5-32"><a href="#cb5-32" aria-hidden="true" tabindex="-1"></a>x <span class="op">=</span> layers.GlobalAveragePooling2D()(x)</span>
<span id="cb5-33"><a href="#cb5-33" aria-hidden="true" tabindex="-1"></a>x <span class="op">=</span> layers.Dropout(<span class="fl">0.3</span>)(x)</span>
<span id="cb5-34"><a href="#cb5-34" aria-hidden="true" tabindex="-1"></a>outputs <span class="op">=</span> layers.Dense(<span class="dv">10</span>, activation<span class="op">=</span><span class="st">"softmax"</span>)(x)</span>
<span id="cb5-35"><a href="#cb5-35" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-36"><a href="#cb5-36" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> tf.keras.Model(inputs, outputs)</span>
<span id="cb5-37"><a href="#cb5-37" aria-hidden="true" tabindex="-1"></a>model.<span class="bu">compile</span>(</span>
<span id="cb5-38"><a href="#cb5-38" aria-hidden="true" tabindex="-1"></a>    optimizer<span class="op">=</span>optimizers.Adam(learning_rate<span class="op">=</span><span class="fl">1e-3</span>),</span>
<span id="cb5-39"><a href="#cb5-39" aria-hidden="true" tabindex="-1"></a>    loss<span class="op">=</span><span class="st">"sparse_categorical_crossentropy"</span>,</span>
<span id="cb5-40"><a href="#cb5-40" aria-hidden="true" tabindex="-1"></a>    metrics<span class="op">=</span>[<span class="st">"accuracy"</span>],</span>
<span id="cb5-41"><a href="#cb5-41" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb5-42"><a href="#cb5-42" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-43"><a href="#cb5-43" aria-hidden="true" tabindex="-1"></a><span class="co"># ──────────────────────────────────────────────────────────────</span></span>
<span id="cb5-44"><a href="#cb5-44" aria-hidden="true" tabindex="-1"></a><span class="co"># 3. Training</span></span>
<span id="cb5-45"><a href="#cb5-45" aria-hidden="true" tabindex="-1"></a><span class="co"># ──────────────────────────────────────────────────────────────</span></span>
<span id="cb5-46"><a href="#cb5-46" aria-hidden="true" tabindex="-1"></a>cb <span class="op">=</span> [</span>
<span id="cb5-47"><a href="#cb5-47" aria-hidden="true" tabindex="-1"></a>    callbacks.ReduceLROnPlateau(patience<span class="op">=</span><span class="dv">5</span>, factor<span class="op">=</span><span class="fl">0.5</span>, verbose<span class="op">=</span><span class="dv">1</span>),</span>
<span id="cb5-48"><a href="#cb5-48" aria-hidden="true" tabindex="-1"></a>    callbacks.EarlyStopping(patience<span class="op">=</span><span class="dv">10</span>, restore_best_weights<span class="op">=</span><span class="va">True</span>),</span>
<span id="cb5-49"><a href="#cb5-49" aria-hidden="true" tabindex="-1"></a>    callbacks.ModelCheckpoint(<span class="st">"best_efficientnet_cifar10.h5"</span>,</span>
<span id="cb5-50"><a href="#cb5-50" aria-hidden="true" tabindex="-1"></a>                               save_best_only<span class="op">=</span><span class="va">True</span>),</span>
<span id="cb5-51"><a href="#cb5-51" aria-hidden="true" tabindex="-1"></a>]</span>
<span id="cb5-52"><a href="#cb5-52" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-53"><a href="#cb5-53" aria-hidden="true" tabindex="-1"></a>model.fit(</span>
<span id="cb5-54"><a href="#cb5-54" aria-hidden="true" tabindex="-1"></a>    x_train, y_train,</span>
<span id="cb5-55"><a href="#cb5-55" aria-hidden="true" tabindex="-1"></a>    validation_data<span class="op">=</span>(x_test, y_test),  <span class="co"># test split doubles as validation set</span></span>
<span id="cb5-56"><a href="#cb5-56" aria-hidden="true" tabindex="-1"></a>    epochs<span class="op">=</span><span class="dv">50</span>,</span>
<span id="cb5-57"><a href="#cb5-57" aria-hidden="true" tabindex="-1"></a>    batch_size<span class="op">=</span><span class="dv">128</span>,</span>
<span id="cb5-58"><a href="#cb5-58" aria-hidden="true" tabindex="-1"></a>    callbacks<span class="op">=</span>cb,</span>
<span id="cb5-59"><a href="#cb5-59" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
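<p>The two scripts also differ in tensor layout. Keras works channels-last (NHWC), <code>(N, 32, 32, 3)</code>, whereas the PyTorch model, and any ONNX file exported from it, expects channels-first (NCHW) input, <code>(N, 3, 32, 32)</code>. The NumPy sketch below reproduces the PyTorch <code>ToTensor()</code> + <code>Normalize(...)</code> preprocessing without torchvision. It assumes raw <code>uint8</code> images like those returned by <code>tf.keras.datasets.cifar10.load_data()</code>, and <code>to_nchw_normalized</code> is a hypothetical helper name:</p>

```python
import numpy as np

# CIFAR-10 channel statistics used by the PyTorch transforms above
MEAN = np.array([0.4914, 0.4822, 0.4465], dtype=np.float32)
STD = np.array([0.2023, 0.1994, 0.2010], dtype=np.float32)


def to_nchw_normalized(images_nhwc_uint8):
    """Mirror ToTensor() + Normalize() for a batch of HWC uint8 images.

    Input:  (N, H, W, 3) uint8 in [0, 255], channels-last as Keras loads it
    Output: (N, 3, H, W) float32, normalized per channel, channels-first
    """
    x = images_nhwc_uint8.astype(np.float32) / 255.0  # ToTensor scaling to [0, 1]
    x = (x - MEAN) / STD                               # broadcasts over the channel (last) axis
    return np.ascontiguousarray(x.transpose(0, 3, 1, 2))  # NHWC to NCHW


batch = np.random.randint(0, 256, size=(4, 32, 32, 3), dtype=np.uint8)
print(to_nchw_normalized(batch).shape)  # (4, 3, 32, 32)
```

<p>The division by 255 comes first because <code>Normalize</code> in the PyTorch pipeline operates on the output of <code>ToTensor()</code>, which is already in [0, 1]. Skipping that step would silently shift every input the model sees.</p>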
<hr>
</section>
</section>
<section id="exporting-a-trained-model-to-onnx" class="level2">
<h2 class="anchored" data-anchor-id="exporting-a-trained-model-to-onnx" id="exporting-a-trained-model-to-onnx">Exporting a Trained Model to ONNX</h2>
<section id="exporting-from-pytorch" class="level3">
<h3 class="anchored" data-anchor-id="exporting-from-pytorch" id="exporting-from-pytorch">Exporting from PyTorch</h3>
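<p>PyTorch has two export pathways. The classic <strong><code>torch.onnx.export</code></strong> builds the graph by TorchScript tracing: it runs the model once on an example input and records the operations that execute. The newer <strong><code>torch.onnx.dynamo_export</code></strong> (introduced in PyTorch 2.1) captures the graph with TorchDynamo instead, and recent releases expose this exporter as <code>torch.onnx.export(..., dynamo=True)</code>. The dynamo path targets models that tracing struggles with, but it is still maturing.</p>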
<div class="cell" data-eval="true" data-layout-align="center">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">flowchart TD
    A["Trained PyTorch Model (.pth weights)"] --&gt; B{"Export Strategy?"}
    B --&gt; C["Tracing torch.onnx.export"]
    B --&gt; D["Dynamo torch.onnx.dynamo_export (PyTorch ≥ 2.0)"]

    C --&gt; E["Standard CNNs ResNet · EfficientNet · YOLO"]
    C --&gt; F["Fixed control flow no data-dependent branches"]

    D --&gt; G["Transformers ViT · DETR · CLIP"]
    D --&gt; H["Dynamic control flow data-dependent branches"]

    E --&gt; I["ONNX Model (.onnx)"]
    F --&gt; I
    G --&gt; I
    H --&gt; I
</pre>
</div>
<p></p></figure><p></p>
</div>
</div>
</div>
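<p>The "fixed control flow" requirement on the tracing branch follows from how tracing works. A Python <code>if</code> on a tensor value is evaluated once, during the example run, and only the branch taken is recorded. The sketch below demonstrates this with <code>torch.jit.trace</code>, which uses the same TorchScript tracing mechanism that the classic exporter relies on for a regular <code>nn.Module</code>. The <code>Gate</code> module is a made-up toy example:</p>

```python
import warnings

import torch


class Gate(torch.nn.Module):
    """Toy module with a data-dependent branch (hypothetical example)."""

    def forward(self, x):
        if x.sum() > 0:  # Python-level branch on a tensor value
            return x * 2
        return x - 1


model = Gate().eval()
example = torch.ones(3)  # positive sum, so tracing records the `x * 2` branch

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # tracing emits a TracerWarning for this branch
    traced = torch.jit.trace(model, example)

negative = torch.full((3,), -3.0)
print(model(negative))   # eager takes `x - 1`: tensor([-4., -4., -4.])
print(traced(negative))  # trace replays `x * 2`: tensor([-6., -6., -6.])
```

<p>Note that nothing fails at trace time. The mismatch only appears when an input would take the other branch, so comparing eager and exported outputs on varied inputs is a cheap safeguard.</p>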
<section id="classic-export-tracing" class="level4">
<h4 class="anchored" data-anchor-id="classic-export-tracing">Classic Export (Tracing)</h4>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="co"># export_pytorch_to_onnx.py</span></span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torchvision <span class="im">import</span> models</span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> onnx</span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a><span class="co"># ──────────────────────────────────────────────────────────────</span></span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a><span class="co"># 1. Reconstruct the model and load weights</span></span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a><span class="co"># ──────────────────────────────────────────────────────────────</span></span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a>DEVICE <span class="op">=</span> torch.device(<span class="st">"cpu"</span>)  <span class="co"># CPU export needs no GPU; the ONNX graph itself is device-agnostic</span></span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> models.resnet18(weights<span class="op">=</span><span class="va">None</span>)</span>
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a>model.conv1 <span class="op">=</span> nn.Conv2d(<span class="dv">3</span>, <span class="dv">64</span>, kernel_size<span class="op">=</span><span class="dv">3</span>, stride<span class="op">=</span><span class="dv">1</span>, padding<span class="op">=</span><span class="dv">1</span>, bias<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb6-15"><a href="#cb6-15" aria-hidden="true" tabindex="-1"></a>model.maxpool <span class="op">=</span> nn.Identity()</span>
<span id="cb6-16"><a href="#cb6-16" aria-hidden="true" tabindex="-1"></a>model.fc <span class="op">=</span> nn.Linear(model.fc.in_features, <span class="dv">10</span>)</span>
<span id="cb6-17"><a href="#cb6-17" aria-hidden="true" tabindex="-1"></a>model.load_state_dict(torch.load(<span class="st">"best_resnet18_cifar10.pth"</span>, map_location<span class="op">=</span>DEVICE))</span>
<span id="cb6-18"><a href="#cb6-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-19"><a href="#cb6-19" aria-hidden="true" tabindex="-1"></a><span class="co"># ──────────────────────────────────────────────────────────────</span></span>
<span id="cb6-20"><a href="#cb6-20" aria-hidden="true" tabindex="-1"></a><span class="co"># 2. Set model to evaluation mode — CRITICAL</span></span>
<span id="cb6-21"><a href="#cb6-21" aria-hidden="true" tabindex="-1"></a><span class="co">#    This disables dropout and switches BatchNorm to eval statistics.</span></span>
<span id="cb6-22"><a href="#cb6-22" aria-hidden="true" tabindex="-1"></a><span class="co"># ──────────────────────────────────────────────────────────────</span></span>
<span id="cb6-23"><a href="#cb6-23" aria-hidden="true" tabindex="-1"></a>model.<span class="bu">eval</span>()</span>
<span id="cb6-24"><a href="#cb6-24" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-25"><a href="#cb6-25" aria-hidden="true" tabindex="-1"></a><span class="co"># ──────────────────────────────────────────────────────────────</span></span>
<span id="cb6-26"><a href="#cb6-26" aria-hidden="true" tabindex="-1"></a><span class="co"># 3. Create a representative dummy input</span></span>
<span id="cb6-27"><a href="#cb6-27" aria-hidden="true" tabindex="-1"></a><span class="co">#    Shape: (batch_size, channels, height, width)</span></span>
<span id="cb6-28"><a href="#cb6-28" aria-hidden="true" tabindex="-1"></a><span class="co"># ──────────────────────────────────────────────────────────────</span></span>
<span id="cb6-29"><a href="#cb6-29" aria-hidden="true" tabindex="-1"></a>dummy_input <span class="op">=</span> torch.randn(<span class="dv">1</span>, <span class="dv">3</span>, <span class="dv">32</span>, <span class="dv">32</span>, device<span class="op">=</span>DEVICE)</span>
<span id="cb6-30"><a href="#cb6-30" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-31"><a href="#cb6-31" aria-hidden="true" tabindex="-1"></a><span class="co"># ──────────────────────────────────────────────────────────────</span></span>
<span id="cb6-32"><a href="#cb6-32" aria-hidden="true" tabindex="-1"></a><span class="co"># 4. Export</span></span>
<span id="cb6-33"><a href="#cb6-33" aria-hidden="true" tabindex="-1"></a><span class="co"># ──────────────────────────────────────────────────────────────</span></span>
<span id="cb6-34"><a href="#cb6-34" aria-hidden="true" tabindex="-1"></a>ONNX_PATH <span class="op">=</span> <span class="st">"resnet18_cifar10.onnx"</span></span>
<span id="cb6-35"><a href="#cb6-35" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-36"><a href="#cb6-36" aria-hidden="true" tabindex="-1"></a>torch.onnx.export(</span>
<span id="cb6-37"><a href="#cb6-37" aria-hidden="true" tabindex="-1"></a>    model,</span>
<span id="cb6-38"><a href="#cb6-38" aria-hidden="true" tabindex="-1"></a>    dummy_input,</span>
<span id="cb6-39"><a href="#cb6-39" aria-hidden="true" tabindex="-1"></a>    ONNX_PATH,</span>
<span id="cb6-40"><a href="#cb6-40" aria-hidden="true" tabindex="-1"></a>    export_params<span class="op">=</span><span class="va">True</span>,          <span class="co"># store weights inside the .onnx file</span></span>
<span id="cb6-41"><a href="#cb6-41" aria-hidden="true" tabindex="-1"></a>    opset_version<span class="op">=</span><span class="dv">18</span>,            <span class="co"># target opset; 17–19 recommended</span></span>
<span id="cb6-42"><a href="#cb6-42" aria-hidden="true" tabindex="-1"></a>    do_constant_folding<span class="op">=</span><span class="va">True</span>,    <span class="co"># pre-compute constant subgraphs at export time</span></span>
<span id="cb6-43"><a href="#cb6-43" aria-hidden="true" tabindex="-1"></a>    input_names<span class="op">=</span>[<span class="st">"images"</span>],      <span class="co"># name the input tensor(s)</span></span>
<span id="cb6-44"><a href="#cb6-44" aria-hidden="true" tabindex="-1"></a>    output_names<span class="op">=</span>[<span class="st">"logits"</span>],     <span class="co"># name the output tensor(s)</span></span>
<span id="cb6-45"><a href="#cb6-45" aria-hidden="true" tabindex="-1"></a>    dynamic_axes<span class="op">=</span>{               <span class="co"># mark batch dimension as dynamic</span></span>
<span id="cb6-46"><a href="#cb6-46" aria-hidden="true" tabindex="-1"></a>        <span class="st">"images"</span>: {<span class="dv">0</span>: <span class="st">"batch_size"</span>},</span>
<span id="cb6-47"><a href="#cb6-47" aria-hidden="true" tabindex="-1"></a>        <span class="st">"logits"</span>: {<span class="dv">0</span>: <span class="st">"batch_size"</span>},</span>
<span id="cb6-48"><a href="#cb6-48" aria-hidden="true" tabindex="-1"></a>    },</span>
<span id="cb6-49"><a href="#cb6-49" aria-hidden="true" tabindex="-1"></a>    verbose<span class="op">=</span><span class="va">False</span>,</span>
<span id="cb6-50"><a href="#cb6-50" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb6-51"><a href="#cb6-51" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-52"><a href="#cb6-52" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Model exported to </span><span class="sc">{</span>ONNX_PATH<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb6-53"><a href="#cb6-53" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-54"><a href="#cb6-54" aria-hidden="true" tabindex="-1"></a><span class="co"># ──────────────────────────────────────────────────────────────</span></span>
<span id="cb6-55"><a href="#cb6-55" aria-hidden="true" tabindex="-1"></a><span class="co"># 5. Quick sanity check</span></span>
<span id="cb6-56"><a href="#cb6-56" aria-hidden="true" tabindex="-1"></a><span class="co"># ──────────────────────────────────────────────────────────────</span></span>
<span id="cb6-57"><a href="#cb6-57" aria-hidden="true" tabindex="-1"></a>onnx_model <span class="op">=</span> onnx.load(ONNX_PATH)</span>
<span id="cb6-58"><a href="#cb6-58" aria-hidden="true" tabindex="-1"></a>onnx.checker.check_model(onnx_model)</span>
<span id="cb6-59"><a href="#cb6-59" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"ONNX model is valid ✓"</span>)</span></code></pre></div></div>
</section>
<section id="dynamo-based-export-pytorch-2.0" class="level4">
<h4 class="anchored" data-anchor-id="dynamo-based-export-pytorch-2.0">Dynamo-Based Export (PyTorch ≥ 2.0)</h4>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.onnx</span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Reuses `model` and `dummy_input` from above. Dynamo builds the graph by analyzing</span></span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Python bytecode; on PyTorch 2.5+, torch.onnx.export(..., dynamo=True) is the preferred entry point.</span></span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a>export_output <span class="op">=</span> torch.onnx.dynamo_export(</span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a>    model,</span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a>    dummy_input,</span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb7-10"><a href="#cb7-10" aria-hidden="true" tabindex="-1"></a>export_output.save(<span class="st">"resnet18_cifar10_dynamo.onnx"</span>)</span></code></pre></div></div>
<div class="callout callout-style-default callout-tip callout-titled" title="When to use tracing vs. dynamo">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>When to use tracing vs.&nbsp;dynamo
</div>
</div>
<div class="callout-body-container callout-body">
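<p>Tracing records the single execution path taken by the example input, so a branch such as <code>if x.shape[0] &gt; 1:</code> is frozen into the graph; at most you get a <code>TracerWarning</code>. Dynamo (TorchDynamo + FX graph) analyzes the Python bytecode and copes with more dynamic behavior, although data-dependent branches may still need rewriting (for example with <code>torch.cond</code>). For standard CNN architectures, tracing is simpler and more mature; for transformer models with dynamic attention patterns, dynamo is preferred.</p>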
</div>
</div>
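<p>The tracing pitfall described in the tip can be reproduced in a few lines. The sketch below calls <code>torch.jit.trace</code> directly on a hypothetical toy module (<code>ShapeBranch</code>). It assumes that <code>torch.jit.trace</code> exhibits the same behavior as the tracing performed by the TorchScript-based <code>torch.onnx.export</code> path:</p>

```python
import torch


class ShapeBranch(torch.nn.Module):
    """Toy module whose output depends on a Python branch on the batch size."""

    def forward(self, x):
        if x.shape[0] > 1:  # the kind of branch tracing freezes
            return x * 2
        return x - 1


model = ShapeBranch().eval()

# Trace with a batch of 4: only the `x * 2` path is recorded.
# (torch.jit.trace may print a TracerWarning about the branch.)
traced = torch.jit.trace(model, torch.ones(4, 3))

single = torch.ones(1, 3)
eager_out = model(single)    # eager mode takes the `x - 1` branch
traced_out = traced(single)  # the trace replays the recorded `x * 2` path
print(eager_out)
print(traced_out)
```

<p>The traced module runs without error on the batch of 1, yet returns the wrong result. This is why the numerical comparison against the source framework matters even when export succeeds.</p>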
</section>
</section>
<section id="exporting-from-tensorflow-keras" class="level3">
<h3 class="anchored" data-anchor-id="exporting-from-tensorflow-keras" id="exporting-from-tensorflow-keras">Exporting from TensorFlow / Keras</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb8"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install tf2onnx</span></code></pre></div></div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb9"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="co"># export_tf_to_onnx.py</span></span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> tensorflow <span class="im">as</span> tf</span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> tf2onnx</span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> onnx</span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Load the saved Keras model</span></span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> tf.keras.models.load_model(<span class="st">"best_efficientnet_cifar10.h5"</span>)</span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Specify the input signature explicitly for reliable export</span></span>
<span id="cb9-11"><a href="#cb9-11" aria-hidden="true" tabindex="-1"></a>input_signature <span class="op">=</span> [</span>
<span id="cb9-12"><a href="#cb9-12" aria-hidden="true" tabindex="-1"></a>    tf.TensorSpec(shape<span class="op">=</span>[<span class="va">None</span>, <span class="dv">32</span>, <span class="dv">32</span>, <span class="dv">3</span>], dtype<span class="op">=</span>tf.float32, name<span class="op">=</span><span class="st">"images"</span>)</span>
<span id="cb9-13"><a href="#cb9-13" aria-hidden="true" tabindex="-1"></a>]</span>
<span id="cb9-14"><a href="#cb9-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-15"><a href="#cb9-15" aria-hidden="true" tabindex="-1"></a><span class="co"># Convert to ONNX</span></span>
<span id="cb9-16"><a href="#cb9-16" aria-hidden="true" tabindex="-1"></a>onnx_model, _ <span class="op">=</span> tf2onnx.convert.from_keras(</span>
<span id="cb9-17"><a href="#cb9-17" aria-hidden="true" tabindex="-1"></a>    model,</span>
<span id="cb9-18"><a href="#cb9-18" aria-hidden="true" tabindex="-1"></a>    input_signature<span class="op">=</span>input_signature,</span>
<span id="cb9-19"><a href="#cb9-19" aria-hidden="true" tabindex="-1"></a>    opset<span class="op">=</span><span class="dv">18</span>,</span>
<span id="cb9-20"><a href="#cb9-20" aria-hidden="true" tabindex="-1"></a>    output_path<span class="op">=</span><span class="st">"efficientnet_cifar10.onnx"</span>,</span>
<span id="cb9-21"><a href="#cb9-21" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb9-22"><a href="#cb9-22" aria-hidden="true" tabindex="-1"></a>onnx.checker.check_model(onnx_model)  <span class="co"># structural check of the converted graph</span></span>
<span id="cb9-23"><a href="#cb9-23" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"TensorFlow model successfully converted to ONNX ✓"</span>)</span></code></pre></div></div>
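<p>You can also convert a TensorFlow SavedModel directory from the command line. The tensor names passed to <code>--inputs</code> and <code>--outputs</code> (here <code>images:0</code> and <code>softmax:0</code>) must match the names in your model's graph:</p>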
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb10"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="ex">python</span> <span class="at">-m</span> tf2onnx.convert <span class="dt">\</span></span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a>    <span class="at">--saved-model</span> ./saved_model_dir <span class="dt">\</span></span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a>    <span class="at">--output</span> efficientnet_cifar10.onnx <span class="dt">\</span></span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a>    <span class="at">--opset</span> 18 <span class="dt">\</span></span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a>    <span class="at">--inputs</span> images:0<span class="pp">[</span><span class="ss">batch,32,32,3</span><span class="pp">]</span> <span class="dt">\</span></span>
<span id="cb10-6"><a href="#cb10-6" aria-hidden="true" tabindex="-1"></a>    <span class="at">--outputs</span> softmax:0</span></code></pre></div></div>
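<p>The Keras export above declares a channels-last input, <code>[None, 32, 32, 3]</code> (NHWC), while the PyTorch export expects channels-first (NCHW) batches. If the same preprocessing code feeds both models, this layout difference has to be handled explicitly. A minimal NumPy sketch, assuming your pipeline produces NCHW <code>float32</code> arrays:</p>

```python
import numpy as np

# A batch prepared channels-first (N, C, H, W), as fed to the PyTorch export.
rng = np.random.default_rng(0)
batch_nchw = rng.random((4, 3, 32, 32), dtype=np.float32)

# The Keras/tf2onnx model expects channels-last (N, H, W, C): move C to the end.
batch_nhwc = np.ascontiguousarray(batch_nchw.transpose(0, 2, 3, 1))
print(batch_nhwc.shape)  # (4, 32, 32, 3)

# Same pixel addressed in each layout: image 0, channel 2, row 5, column 7.
assert batch_nhwc[0, 5, 7, 2] == batch_nchw[0, 2, 5, 7]

# Hypothetical ONNX Runtime call; the input name is read from the session
# instead of being assumed:
# sess = ort.InferenceSession("efficientnet_cifar10.onnx", providers=["CPUExecutionProvider"])
# logits = sess.run(None, {sess.get_inputs()[0].name: batch_nhwc})[0]
```

<p>tf2onnx also offers an <code>inputs_as_nchw</code> option (<code>--inputs-as-nchw</code> on the command line) that moves this transpose into the graph itself. Check the documentation for your tf2onnx version before relying on it.</p>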
</section>
<section id="exporting-from-scikit-learn-sklearn-onnx" class="level3">
<h3 class="anchored" data-anchor-id="exporting-from-scikit-learn-sklearn-onnx" id="exporting-from-scikit-learn-sklearn-onnx">Exporting from scikit-learn (sklearn-onnx)</h3>
<p>While scikit-learn models are rarely used for deep vision, they appear in feature-based vision pipelines (e.g., HOG + SVM).</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb11"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install skl2onnx</span></code></pre></div></div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb12"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><a href="#cb12-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> sklearn.svm <span class="im">import</span> SVC</span>
<span id="cb12-2"><a href="#cb12-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> sklearn.pipeline <span class="im">import</span> Pipeline</span>
<span id="cb12-3"><a href="#cb12-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> sklearn.preprocessing <span class="im">import</span> StandardScaler</span>
<span id="cb12-4"><a href="#cb12-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> skl2onnx</span>
<span id="cb12-5"><a href="#cb12-5" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> skl2onnx <span class="im">import</span> convert_sklearn</span>
<span id="cb12-6"><a href="#cb12-6" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> skl2onnx.common.data_types <span class="im">import</span> FloatTensorType</span>
<span id="cb12-7"><a href="#cb12-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-8"><a href="#cb12-8" aria-hidden="true" tabindex="-1"></a><span class="co"># Assume `pipeline` is a trained Pipeline (e.g., StandardScaler followed by SVC)</span></span>
<span id="cb12-9"><a href="#cb12-9" aria-hidden="true" tabindex="-1"></a><span class="co"># on 1764-dim HOG features (32x32 images, 4x4-pixel cells, 2x2-cell blocks, 9 orientations)</span></span>
<span id="cb12-10"><a href="#cb12-10" aria-hidden="true" tabindex="-1"></a>initial_type <span class="op">=</span> [(<span class="st">"float_input"</span>, FloatTensorType([<span class="va">None</span>, <span class="dv">1764</span>]))]</span>
<span id="cb12-11"><a href="#cb12-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-12"><a href="#cb12-12" aria-hidden="true" tabindex="-1"></a>onnx_model <span class="op">=</span> convert_sklearn(pipeline, initial_types<span class="op">=</span>initial_type,</span>
<span id="cb12-13"><a href="#cb12-13" aria-hidden="true" tabindex="-1"></a>                              target_opset<span class="op">=</span><span class="dv">18</span>)</span>
<span id="cb12-14"><a href="#cb12-14" aria-hidden="true" tabindex="-1"></a><span class="cf">with</span> <span class="bu">open</span>(<span class="st">"hog_svm.onnx"</span>, <span class="st">"wb"</span>) <span class="im">as</span> f:</span>
<span id="cb12-15"><a href="#cb12-15" aria-hidden="true" tabindex="-1"></a>    f.write(onnx_model.SerializeToString())</span></code></pre></div></div>
<hr>
</section>
</section>
<section id="validating-and-inspecting-the-onnx-model" class="level2">
<h2 class="anchored" data-anchor-id="validating-and-inspecting-the-onnx-model" id="validating-and-inspecting-the-onnx-model">Validating and Inspecting the ONNX Model</h2>
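<p>Before deploying, always validate and inspect the exported model. Some export problems fail loudly (for example, an operator with no ONNX equivalent in the chosen opset). Others, such as a model left in training mode or an NCHW/NHWC layout mismatch, silently produce wrong predictions.</p>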
<div class="cell" data-eval="true" data-layout-align="center">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">flowchart TD
    A["Exported .onnx file"] --&gt; B["Structural Validation onnx.checker.check_model"]
    B --&gt; C{"Valid?"}
    C -- No --&gt; D["Fix export: check opset, custom ops, eval mode"]
    D --&gt; A
    C -- Yes --&gt; E["Shape Inference onnx.shape_inference.infer_shapes"]
    E --&gt; F["Numerical Validation Compare ORT vs source framework"]
    F --&gt; G{"Max diff &lt; 1e-4?"}
    G -- No --&gt; H["Investigate: NHWC/NCHW mismatch, Dropout not disabled, opset operator gap"]
    H --&gt; A
    G -- Yes --&gt; I["Visual Inspection Netron"]
    I --&gt; J["Model Ready for Optimization"]
</pre>
</div>
<p></p></figure><p></p>
</div>
</div>
</div>
<section id="structural-validation" class="level3">
<h3 class="anchored" data-anchor-id="structural-validation" id="structural-validation">Structural Validation</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb13"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> onnx</span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> onnx.load(<span class="st">"resnet18_cifar10.onnx"</span>)</span>
<span id="cb13-4"><a href="#cb13-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-5"><a href="#cb13-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Full graph validity check (type-checking, shape propagation)</span></span>
<span id="cb13-6"><a href="#cb13-6" aria-hidden="true" tabindex="-1"></a>onnx.checker.check_model(model, full_check<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb13-7"><a href="#cb13-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-8"><a href="#cb13-8" aria-hidden="true" tabindex="-1"></a><span class="co"># Print a human-readable summary</span></span>
<span id="cb13-9"><a href="#cb13-9" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(onnx.helper.printable_graph(model.graph))</span></code></pre></div></div>
</section>
<section id="inspecting-model-metadata" class="level3">
<h3 class="anchored" data-anchor-id="inspecting-model-metadata" id="inspecting-model-metadata">Inspecting Model Metadata</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb14"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb14-1"><a href="#cb14-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> onnx</span>
<span id="cb14-2"><a href="#cb14-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-3"><a href="#cb14-3" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> onnx.load(<span class="st">"resnet18_cifar10.onnx"</span>)</span>
<span id="cb14-4"><a href="#cb14-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-5"><a href="#cb14-5" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"IR version:      </span><span class="sc">{</span>model<span class="sc">.</span>ir_version<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb14-6"><a href="#cb14-6" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Opset imports:   </span><span class="sc">{</span>[op.version <span class="cf">for</span> op <span class="kw">in</span> model.opset_import]<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb14-7"><a href="#cb14-7" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Graph name:      </span><span class="sc">{</span>model<span class="sc">.</span>graph<span class="sc">.</span>name<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb14-8"><a href="#cb14-8" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"Inputs:"</span>)</span>
<span id="cb14-9"><a href="#cb14-9" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> inp <span class="kw">in</span> model.graph.<span class="bu">input</span>:</span>
<span id="cb14-10"><a href="#cb14-10" aria-hidden="true" tabindex="-1"></a>    shape <span class="op">=</span> [d.dim_value <span class="kw">or</span> d.dim_param</span>
<span id="cb14-11"><a href="#cb14-11" aria-hidden="true" tabindex="-1"></a>             <span class="cf">for</span> d <span class="kw">in</span> inp.<span class="bu">type</span>.tensor_type.shape.dim]</span>
<span id="cb14-12"><a href="#cb14-12" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"  </span><span class="sc">{</span>inp<span class="sc">.</span>name<span class="sc">}</span><span class="ss">: </span><span class="sc">{</span>shape<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb14-13"><a href="#cb14-13" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"Outputs:"</span>)</span>
<span id="cb14-14"><a href="#cb14-14" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> out <span class="kw">in</span> model.graph.output:</span>
<span id="cb14-15"><a href="#cb14-15" aria-hidden="true" tabindex="-1"></a>    shape <span class="op">=</span> [d.dim_value <span class="kw">or</span> d.dim_param</span>
<span id="cb14-16"><a href="#cb14-16" aria-hidden="true" tabindex="-1"></a>             <span class="cf">for</span> d <span class="kw">in</span> out.<span class="bu">type</span>.tensor_type.shape.dim]</span>
<span id="cb14-17"><a href="#cb14-17" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"  </span><span class="sc">{</span>out<span class="sc">.</span>name<span class="sc">}</span><span class="ss">: </span><span class="sc">{</span>shape<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
</section>
<section id="shape-inference" class="level3">
<h3 class="anchored" data-anchor-id="shape-inference" id="shape-inference">Shape Inference</h3>
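<p>ONNX can propagate shapes through the graph without running it. The inferred shapes of intermediate tensors are stored in <code>graph.value_info</code>, which makes shape errors easier to locate:</p>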
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb15"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb15-1"><a href="#cb15-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> onnx</span>
<span id="cb15-2"><a href="#cb15-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> onnx <span class="im">import</span> shape_inference</span>
<span id="cb15-3"><a href="#cb15-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-4"><a href="#cb15-4" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> onnx.load(<span class="st">"resnet18_cifar10.onnx"</span>)</span>
<span id="cb15-5"><a href="#cb15-5" aria-hidden="true" tabindex="-1"></a>inferred <span class="op">=</span> shape_inference.infer_shapes(model)</span>
<span id="cb15-6"><a href="#cb15-6" aria-hidden="true" tabindex="-1"></a>onnx.save(inferred, <span class="st">"resnet18_cifar10_inferred.onnx"</span>)</span>
<span id="cb15-7"><a href="#cb15-7" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"Shape inference complete. Intermediate shapes are now annotated."</span>)</span></code></pre></div></div>
</section>
<section id="numerical-validation-against-the-source-framework" class="level3">
<h3 class="anchored" data-anchor-id="numerical-validation-against-the-source-framework" id="numerical-validation-against-the-source-framework">Numerical Validation Against the Source Framework</h3>
<p>This is the most important validation step: run ONNX Runtime and the original framework on the same inputs and compare their outputs.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb16"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb16-1"><a href="#cb16-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb16-2"><a href="#cb16-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb16-3"><a href="#cb16-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> onnxruntime <span class="im">as</span> ort</span>
<span id="cb16-4"><a href="#cb16-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torchvision <span class="im">import</span> models</span>
<span id="cb16-5"><a href="#cb16-5" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb16-6"><a href="#cb16-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-7"><a href="#cb16-7" aria-hidden="true" tabindex="-1"></a><span class="co"># ── Original PyTorch model ──</span></span>
<span id="cb16-8"><a href="#cb16-8" aria-hidden="true" tabindex="-1"></a>DEVICE <span class="op">=</span> torch.device(<span class="st">"cpu"</span>)</span>
<span id="cb16-9"><a href="#cb16-9" aria-hidden="true" tabindex="-1"></a>pt_model <span class="op">=</span> models.resnet18(weights<span class="op">=</span><span class="va">None</span>)</span>
<span id="cb16-10"><a href="#cb16-10" aria-hidden="true" tabindex="-1"></a>pt_model.conv1 <span class="op">=</span> nn.Conv2d(<span class="dv">3</span>, <span class="dv">64</span>, kernel_size<span class="op">=</span><span class="dv">3</span>, stride<span class="op">=</span><span class="dv">1</span>, padding<span class="op">=</span><span class="dv">1</span>, bias<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb16-11"><a href="#cb16-11" aria-hidden="true" tabindex="-1"></a>pt_model.maxpool <span class="op">=</span> nn.Identity()</span>
<span id="cb16-12"><a href="#cb16-12" aria-hidden="true" tabindex="-1"></a>pt_model.fc <span class="op">=</span> nn.Linear(pt_model.fc.in_features, <span class="dv">10</span>)</span>
<span id="cb16-13"><a href="#cb16-13" aria-hidden="true" tabindex="-1"></a>pt_model.load_state_dict(torch.load(<span class="st">"best_resnet18_cifar10.pth"</span>, map_location<span class="op">=</span>DEVICE))</span>
<span id="cb16-14"><a href="#cb16-14" aria-hidden="true" tabindex="-1"></a>pt_model.<span class="bu">eval</span>()</span>
<span id="cb16-15"><a href="#cb16-15" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-16"><a href="#cb16-16" aria-hidden="true" tabindex="-1"></a><span class="co"># ── ONNX Runtime session ──</span></span>
<span id="cb16-17"><a href="#cb16-17" aria-hidden="true" tabindex="-1"></a>sess <span class="op">=</span> ort.InferenceSession(<span class="st">"resnet18_cifar10.onnx"</span>,</span>
<span id="cb16-18"><a href="#cb16-18" aria-hidden="true" tabindex="-1"></a>                             providers<span class="op">=</span>[<span class="st">"CPUExecutionProvider"</span>])</span>
<span id="cb16-19"><a href="#cb16-19" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-20"><a href="#cb16-20" aria-hidden="true" tabindex="-1"></a><span class="co"># ── Generate random test batch ──</span></span>
<span id="cb16-21"><a href="#cb16-21" aria-hidden="true" tabindex="-1"></a>np.random.seed(<span class="dv">42</span>)</span>
<span id="cb16-22"><a href="#cb16-22" aria-hidden="true" tabindex="-1"></a>dummy_np <span class="op">=</span> np.random.randn(<span class="dv">4</span>, <span class="dv">3</span>, <span class="dv">32</span>, <span class="dv">32</span>).astype(np.float32)</span>
<span id="cb16-23"><a href="#cb16-23" aria-hidden="true" tabindex="-1"></a>dummy_pt <span class="op">=</span> torch.from_numpy(dummy_np)</span>
<span id="cb16-24"><a href="#cb16-24" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-25"><a href="#cb16-25" aria-hidden="true" tabindex="-1"></a><span class="co"># ── Run both ──</span></span>
<span id="cb16-26"><a href="#cb16-26" aria-hidden="true" tabindex="-1"></a><span class="cf">with</span> torch.no_grad():</span>
<span id="cb16-27"><a href="#cb16-27" aria-hidden="true" tabindex="-1"></a>    pt_out <span class="op">=</span> pt_model(dummy_pt).numpy()</span>
<span id="cb16-28"><a href="#cb16-28" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-29"><a href="#cb16-29" aria-hidden="true" tabindex="-1"></a>ort_out <span class="op">=</span> sess.run(<span class="va">None</span>, {<span class="st">"images"</span>: dummy_np})[<span class="dv">0</span>]</span>
<span id="cb16-30"><a href="#cb16-30" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-31"><a href="#cb16-31" aria-hidden="true" tabindex="-1"></a><span class="co"># ── Compare ──</span></span>
<span id="cb16-32"><a href="#cb16-32" aria-hidden="true" tabindex="-1"></a>max_diff <span class="op">=</span> np.<span class="bu">abs</span>(pt_out <span class="op">-</span> ort_out).<span class="bu">max</span>()</span>
<span id="cb16-33"><a href="#cb16-33" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Max absolute difference: </span><span class="sc">{</span>max_diff<span class="sc">:.2e}</span><span class="ss">"</span>)</span>
<span id="cb16-34"><a href="#cb16-34" aria-hidden="true" tabindex="-1"></a><span class="cf">assert</span> max_diff <span class="op">&lt;</span> <span class="fl">1e-4</span>, <span class="ss">f"Outputs diverge! Max diff = </span><span class="sc">{</span>max_diff<span class="sc">}</span><span class="ss">"</span></span>
<span id="cb16-35"><a href="#cb16-35" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"Numerical validation passed ✓"</span>)</span></code></pre></div></div>
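<p>A single maximum-absolute-difference threshold can be misleading when logits are large. The sketch below shows a slightly more informative check, assuming outputs of shape <code>(N, num_classes)</code>. It combines a relative-plus-absolute tolerance with top-1 agreement. The helper name <code>compare_logits</code> and its tolerance defaults are illustrative choices, not values from the workflow above:</p>

```python
import numpy as np


def compare_logits(reference, candidate, atol=1e-4, rtol=1e-3):
    """Summarize how closely two (N, num_classes) logit arrays agree."""
    abs_diff = np.abs(reference - candidate)
    return {
        "max_abs_diff": float(abs_diff.max()),
        # elementwise |candidate - reference| <= atol + rtol * |reference|
        "allclose": bool(np.allclose(candidate, reference, rtol=rtol, atol=atol)),
        # fraction of samples whose predicted class (argmax) is identical
        "top1_agreement": float(
            (reference.argmax(axis=1) == candidate.argmax(axis=1)).mean()
        ),
    }


# Stand-in arrays; in practice pass pt_out and ort_out from the snippet above.
rng = np.random.default_rng(0)
pt_like = rng.normal(size=(4, 10)).astype(np.float32)
ort_like = pt_like + rng.normal(scale=1e-6, size=pt_like.shape).astype(np.float32)
report = compare_logits(pt_like, ort_like)
print(report)
```

<p>Top-1 agreement is what users actually observe. Running this comparison on a batch of real validation images, rather than random noise, exercises the input ranges the model sees in production.</p>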
</section>
<section id="visual-inspection-with-netron" class="level3">
<h3 class="anchored" data-anchor-id="visual-inspection-with-netron" id="visual-inspection-with-netron">Visual Inspection with Netron</h3>
<p><a href="https://netron.app">Netron</a> is a viewer for neural-network models, available in the browser and as a desktop application. Drag and drop your <code>.onnx</code> file onto it to browse the full operator graph, tensor shapes, node attributes, and weights. Besides ONNX, it opens many other formats (TFLite, Core ML, PyTorch, and more).</p>
<hr>
</section>
</section>
<section id="optimizing-the-onnx-model" class="level2">
<h2 class="anchored" data-anchor-id="optimizing-the-onnx-model" id="optimizing-the-onnx-model">Optimizing the ONNX Model</h2>
<section id="graph-optimizations" class="level3">
<h3 class="anchored" data-anchor-id="graph-optimizations" id="graph-optimizations">Graph Optimizations</h3>
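<p>ONNX Runtime applies graph optimizations automatically when it creates a session (the Python API defaults to <code>ORT_ENABLE_ALL</code>). You can also run them once offline and save the optimized graph to disk, so the work is not repeated at every startup.</p>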
<div class="cell" data-eval="true" data-layout-align="center">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">flowchart LR
    A["FP32 ONNX Model"] --&gt; B["Graph Optimization ORT_ENABLE_ALL"]
    B --&gt; C["Constant Folding pre-compute static subgraphs"]
    B --&gt; D["Redundant Node Elimination no-op Reshape, Identity"]
    B --&gt; E["Operator Fusion Conv + BN + ReLU → single kernel"]
    B --&gt; F["Layout Optimization NHWC ↔ NCHW reordering"]
    C &amp; D &amp; E &amp; F --&gt; G["Optimized FP32 Model"]
    G --&gt; H{"Need further speedup?"}
    H -- "Yes, latency-critical" --&gt; I["Static INT8 Quantization + calibration dataset"]
    H -- "Yes, no calib data" --&gt; J["Dynamic INT8 Quantization weights only"]
    H -- "No" --&gt; K["Deploy"]
    I --&gt; K
    J --&gt; K
</pre>
</div>
<p></p></figure><p></p>
</div>
</div>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb17"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb17-1"><a href="#cb17-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> onnxruntime <span class="im">import</span> SessionOptions, GraphOptimizationLevel, InferenceSession</span>
<span id="cb17-3"><a href="#cb17-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-4"><a href="#cb17-4" aria-hidden="true" tabindex="-1"></a><span class="co"># ── Option 1: Let ORT apply optimizations at session creation ──</span></span>
<span id="cb17-5"><a href="#cb17-5" aria-hidden="true" tabindex="-1"></a>opts <span class="op">=</span> SessionOptions()</span>
<span id="cb17-6"><a href="#cb17-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Levels: DISABLE_ALL, ENABLE_BASIC, ENABLE_EXTENDED, ENABLE_ALL</span></span>
<span id="cb17-7"><a href="#cb17-7" aria-hidden="true" tabindex="-1"></a>opts.graph_optimization_level <span class="op">=</span> GraphOptimizationLevel.ORT_ENABLE_ALL</span>
<span id="cb17-8"><a href="#cb17-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-9"><a href="#cb17-9" aria-hidden="true" tabindex="-1"></a><span class="co"># Save the optimized graph for inspection (ENABLE_ALL output is hardware-specific)</span></span>
<span id="cb17-10"><a href="#cb17-10" aria-hidden="true" tabindex="-1"></a>opts.optimized_model_filepath <span class="op">=</span> <span class="st">"resnet18_cifar10_optimized.onnx"</span></span>
<span id="cb17-11"><a href="#cb17-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-12"><a href="#cb17-12" aria-hidden="true" tabindex="-1"></a>sess <span class="op">=</span> InferenceSession(</span>
<span id="cb17-13"><a href="#cb17-13" aria-hidden="true" tabindex="-1"></a>    <span class="st">"resnet18_cifar10.onnx"</span>,</span>
<span id="cb17-14"><a href="#cb17-14" aria-hidden="true" tabindex="-1"></a>    sess_options<span class="op">=</span>opts,</span>
<span id="cb17-15"><a href="#cb17-15" aria-hidden="true" tabindex="-1"></a>    providers<span class="op">=</span>[<span class="st">"CPUExecutionProvider"</span>],</span>
<span id="cb17-16"><a href="#cb17-16" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb17-17"><a href="#cb17-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-18"><a href="#cb17-18" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"Optimized model saved to resnet18_cifar10_optimized.onnx"</span>)</span></code></pre></div></div>
<p>Depending on the optimization level, the applied passes include:</p>
<ul>
<li><strong>Constant folding</strong>: Pre-compute subgraphs whose inputs are all constants</li>
<li><strong>Redundant node elimination</strong>: Remove no-op nodes such as Identity or redundant Reshape</li>
<li><strong>Operator fusion</strong>: Fold BatchNorm into the preceding Conv's weights and fuse the activation (e.g., Conv + BatchNorm + ReLU) into a single kernel</li>
<li><strong>Layout optimization</strong> (<code>ORT_ENABLE_ALL</code> only): Reorder memory layouts for cache efficiency; on the CPU EP this converts eligible nodes to a blocked NCHWc layout, while some other EPs prefer NHWC</li>
</ul>
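<p>ORT performs the Conv + BatchNorm rewrite inside its optimizer. The NumPy sketch below only illustrates the algebra: inference-mode BatchNorm is a per-channel affine map, so it can be absorbed into the convolution's weights and bias. A 1×1 convolution, written with <code>einsum</code>, keeps the example short. For a k×k kernel of shape (C_out, C_in, kH, kW), the scale would broadcast as <code>s[:, None, None, None]</code>. The helper names are illustrative, not ORT APIs.</p>

```python
import numpy as np


def conv1x1(x, w, b):
    """1x1 convolution on NCHW input; w has shape (C_out, C_in), b has shape (C_out,)."""
    return np.einsum("oc,nchw->nohw", w, x) + b[None, :, None, None]


def batchnorm(y, gamma, beta, mean, var, eps=1e-5):
    """Inference-mode BatchNorm over channel axis 1 using running statistics."""
    s = gamma / np.sqrt(var + eps)
    return (y - mean[None, :, None, None]) * s[None, :, None, None] + beta[None, :, None, None]


def fold_bn(w, b, gamma, beta, mean, var, eps=1e-5):
    """Return conv weight/bias equivalent to conv followed by BatchNorm.

    BN(Wx + b) = s * (Wx + b - mean) + beta = (s * W) x + (s * (b - mean) + beta)
    """
    s = gamma / np.sqrt(var + eps)
    return w * s[:, None], (b - mean) * s + beta


rng = np.random.default_rng(0)
x = rng.normal(size=(2, 3, 4, 4))
w, b = rng.normal(size=(8, 3)), rng.normal(size=8)
gamma, beta = rng.uniform(0.5, 1.5, size=8), rng.normal(size=8)
mean, var = rng.normal(size=8), rng.uniform(0.5, 2.0, size=8)

# Unfused reference: Conv -> BatchNorm -> ReLU (three separate ops)
reference = np.maximum(batchnorm(conv1x1(x, w, b), gamma, beta, mean, var), 0.0)

# Fused: a single Conv with folded parameters, followed by ReLU
w_f, b_f = fold_bn(w, b, gamma, beta, mean, var)
fused = np.maximum(conv1x1(x, w_f, b_f), 0.0)

print("max difference:", np.abs(reference - fused).max())
```

<p>Because BatchNorm disappears entirely, the fused graph saves one full pass over the activation tensor per Conv layer.</p>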
</section>
<section id="quantization" class="level3">
<h3 class="anchored" data-anchor-id="quantization" id="quantization">Quantization</h3>
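<p>Quantization converts float32 weights and/or activations to 8-bit integers (int8 or uint8). Storing weights in 8 bits instead of 32 shrinks them roughly 4×. On hardware with fast integer kernels, inference is often 2–4× faster, although the actual gain depends on the model and the target CPU/GPU.</p>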
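<p>Under the hood, 8-bit quantization is an affine mapping <code>x ≈ (q - zero_point) * scale</code>. The sketch below implements the common asymmetric uint8 variant in NumPy and applies it to a random Conv-shaped weight tensor. It is a simplified illustration, assuming a single per-tensor range. ORT's own quantizer also supports symmetric int8 and per-channel scales, so its exact parameters may differ.</p>

```python
import numpy as np


def affine_params(x_min, x_max, qmin=0, qmax=255):
    """Scale and zero-point mapping [x_min, x_max] onto the integer range [qmin, qmax]."""
    # Widen the range to include 0 so that 0.0 is exactly representable
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
    scale = (x_max - x_min) / (qmax - qmin)
    zero_point = int(round(qmin - x_min / scale))
    return scale, zero_point


def quantize(x, scale, zero_point, qmin=0, qmax=255, dtype=np.uint8):
    """Map floats to integers: round(x / scale) + zero_point, clipped to [qmin, qmax]."""
    return np.clip(np.round(x / scale) + zero_point, qmin, qmax).astype(dtype)


def dequantize(q, scale, zero_point):
    """Map integers back to approximate floats."""
    return (q.astype(np.float32) - zero_point) * scale


rng = np.random.default_rng(0)
w = rng.normal(size=(64, 3, 3, 3)).astype(np.float32)  # Conv-shaped float32 weights

scale, zp = affine_params(float(w.min()), float(w.max()))
q = quantize(w, scale, zp)
w_hat = dequantize(q, scale, zp)

print(f"bytes: float32={w.nbytes}, uint8={q.nbytes}")
print(f"max round-trip error={np.abs(w - w_hat).max():.5f} (bound: scale/2={scale / 2:.5f})")
```

<p>Each value within the calibrated range is off by at most half a quantization step (<code>scale / 2</code>). A wider range, for example one caused by a single outlier, makes every step coarser, which is why the choice of range matters.</p>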
<section id="post-training-static-quantization-ptq" class="level4">
<h4 class="anchored" data-anchor-id="post-training-static-quantization-ptq">Post-Training Static Quantization (PTQ)</h4>
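<p>Static quantization requires a small, representative <strong>calibration dataset</strong>. ORT runs these inputs through the FP32 model, records the range of each activation tensor, and derives the scale and zero-point used to quantize it.</p>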
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb18"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb18-1"><a href="#cb18-1" aria-hidden="true" tabindex="-1"></a><span class="co"># quantize_static.py</span></span>
<span id="cb18-2"><a href="#cb18-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-3"><a href="#cb18-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb18-4"><a href="#cb18-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> onnxruntime.quantization <span class="im">import</span> (</span>
<span id="cb18-5"><a href="#cb18-5" aria-hidden="true" tabindex="-1"></a>    quantize_static,</span>
<span id="cb18-6"><a href="#cb18-6" aria-hidden="true" tabindex="-1"></a>    CalibrationDataReader,</span>
<span id="cb18-7"><a href="#cb18-7" aria-hidden="true" tabindex="-1"></a>    QuantFormat,</span>
<span id="cb18-8"><a href="#cb18-8" aria-hidden="true" tabindex="-1"></a>    QuantType,</span>
<span id="cb18-9"><a href="#cb18-9" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb18-10"><a href="#cb18-10" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torchvision <span class="im">import</span> datasets, transforms</span>
<span id="cb18-11"><a href="#cb18-11" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.utils.data <span class="im">import</span> DataLoader</span>
<span id="cb18-12"><a href="#cb18-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-13"><a href="#cb18-13" aria-hidden="true" tabindex="-1"></a><span class="co"># ──────────────────────────────────────────────────────────────</span></span>
<span id="cb18-14"><a href="#cb18-14" aria-hidden="true" tabindex="-1"></a><span class="co"># 1. Calibration data reader</span></span>
<span id="cb18-15"><a href="#cb18-15" aria-hidden="true" tabindex="-1"></a><span class="co"># ──────────────────────────────────────────────────────────────</span></span>
<span id="cb18-16"><a href="#cb18-16" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> CIFAR10CalibReader(CalibrationDataReader):</span>
<span id="cb18-17"><a href="#cb18-17" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, num_batches: <span class="bu">int</span> <span class="op">=</span> <span class="dv">20</span>, batch_size: <span class="bu">int</span> <span class="op">=</span> <span class="dv">32</span>):</span>
<span id="cb18-18"><a href="#cb18-18" aria-hidden="true" tabindex="-1"></a>        val_transform <span class="op">=</span> transforms.Compose([</span>
<span id="cb18-19"><a href="#cb18-19" aria-hidden="true" tabindex="-1"></a>            transforms.ToTensor(),</span>
<span id="cb18-20"><a href="#cb18-20" aria-hidden="true" tabindex="-1"></a>            transforms.Normalize([<span class="fl">0.4914</span>, <span class="fl">0.4822</span>, <span class="fl">0.4465</span>],</span>
<span id="cb18-21"><a href="#cb18-21" aria-hidden="true" tabindex="-1"></a>                                  [<span class="fl">0.2023</span>, <span class="fl">0.1994</span>, <span class="fl">0.2010</span>]),</span>
<span id="cb18-22"><a href="#cb18-22" aria-hidden="true" tabindex="-1"></a>        ])</span>
<span id="cb18-23"><a href="#cb18-23" aria-hidden="true" tabindex="-1"></a>        dataset <span class="op">=</span> datasets.CIFAR10(<span class="st">"./data"</span>, train<span class="op">=</span><span class="va">False</span>, download<span class="op">=</span><span class="va">True</span>,</span>
<span id="cb18-24"><a href="#cb18-24" aria-hidden="true" tabindex="-1"></a>                                    transform<span class="op">=</span>val_transform)</span>
<span id="cb18-25"><a href="#cb18-25" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.loader <span class="op">=</span> <span class="bu">iter</span>(</span>
<span id="cb18-26"><a href="#cb18-26" aria-hidden="true" tabindex="-1"></a>            DataLoader(dataset, batch_size<span class="op">=</span>batch_size, shuffle<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb18-27"><a href="#cb18-27" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb18-28"><a href="#cb18-28" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.num_batches <span class="op">=</span> num_batches</span>
<span id="cb18-29"><a href="#cb18-29" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.count <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb18-30"><a href="#cb18-30" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-31"><a href="#cb18-31" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> get_next(<span class="va">self</span>):</span>
<span id="cb18-32"><a href="#cb18-32" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="va">self</span>.count <span class="op">&gt;=</span> <span class="va">self</span>.num_batches:</span>
<span id="cb18-33"><a href="#cb18-33" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> <span class="va">None</span></span>
<span id="cb18-34"><a href="#cb18-34" aria-hidden="true" tabindex="-1"></a>        <span class="cf">try</span>:</span>
<span id="cb18-35"><a href="#cb18-35" aria-hidden="true" tabindex="-1"></a>            images, _ <span class="op">=</span> <span class="bu">next</span>(<span class="va">self</span>.loader)</span>
<span id="cb18-36"><a href="#cb18-36" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.count <span class="op">+=</span> <span class="dv">1</span></span>
<span id="cb18-37"><a href="#cb18-37" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> {<span class="st">"images"</span>: images.numpy()}</span>
<span id="cb18-38"><a href="#cb18-38" aria-hidden="true" tabindex="-1"></a>        <span class="cf">except</span> <span class="pp">StopIteration</span>:</span>
<span id="cb18-39"><a href="#cb18-39" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> <span class="va">None</span></span>
<span id="cb18-40"><a href="#cb18-40" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-41"><a href="#cb18-41" aria-hidden="true" tabindex="-1"></a><span class="co"># ──────────────────────────────────────────────────────────────</span></span>
<span id="cb18-42"><a href="#cb18-42" aria-hidden="true" tabindex="-1"></a><span class="co"># 2. Quantize</span></span>
<span id="cb18-43"><a href="#cb18-43" aria-hidden="true" tabindex="-1"></a><span class="co"># ──────────────────────────────────────────────────────────────</span></span>
<span id="cb18-44"><a href="#cb18-44" aria-hidden="true" tabindex="-1"></a>quantize_static(</span>
<span id="cb18-45"><a href="#cb18-45" aria-hidden="true" tabindex="-1"></a>    model_input<span class="op">=</span><span class="st">"resnet18_cifar10.onnx"</span>,  <span class="co"># FP32 export; ENABLE_ALL graphs hold CPU-specific ops</span></span>
<span id="cb18-46"><a href="#cb18-46" aria-hidden="true" tabindex="-1"></a>    model_output<span class="op">=</span><span class="st">"resnet18_cifar10_int8.onnx"</span>,</span>
<span id="cb18-47"><a href="#cb18-47" aria-hidden="true" tabindex="-1"></a>    calibration_data_reader<span class="op">=</span>CIFAR10CalibReader(num_batches<span class="op">=</span><span class="dv">20</span>),</span>
<span id="cb18-48"><a href="#cb18-48" aria-hidden="true" tabindex="-1"></a>    quant_format<span class="op">=</span>QuantFormat.QDQ,       <span class="co"># QDQ or QOperator</span></span>
<span id="cb18-49"><a href="#cb18-49" aria-hidden="true" tabindex="-1"></a>    activation_type<span class="op">=</span>QuantType.QUInt8,</span>
<span id="cb18-50"><a href="#cb18-50" aria-hidden="true" tabindex="-1"></a>    weight_type<span class="op">=</span>QuantType.QInt8,</span>
<span id="cb18-51"><a href="#cb18-51" aria-hidden="true" tabindex="-1"></a>    per_channel<span class="op">=</span><span class="va">True</span>,</span>
<span id="cb18-52"><a href="#cb18-52" aria-hidden="true" tabindex="-1"></a>    reduce_range<span class="op">=</span><span class="va">False</span>,</span>
<span id="cb18-53"><a href="#cb18-53" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb18-54"><a href="#cb18-54" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-55"><a href="#cb18-55" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"Static INT8 quantization complete ✓"</span>)</span></code></pre></div></div>
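<p>The calibration step boils down to watching activation ranges. ORT's default calibration method for <code>quantize_static</code> is MinMax; Entropy and Percentile are alternatives that clip outliers. The NumPy sketch below mimics the MinMax idea with a hypothetical <code>MinMaxObserver</code> class. It uses random non-negative arrays as a stand-in for real post-ReLU activations. It is not ORT's implementation.</p>

```python
import numpy as np


class MinMaxObserver:
    """Tracks the running min/max of one activation tensor across calibration batches.

    Hypothetical helper illustrating MinMax calibration; not an ORT class.
    """

    def __init__(self):
        self.running_min = np.inf
        self.running_max = -np.inf

    def update(self, activations: np.ndarray) -> None:
        self.running_min = min(self.running_min, float(activations.min()))
        self.running_max = max(self.running_max, float(activations.max()))

    def qparams(self, qmin: int = 0, qmax: int = 255):
        """Return (scale, zero_point) for asymmetric uint8 quantization."""
        lo, hi = min(self.running_min, 0.0), max(self.running_max, 0.0)
        scale = (hi - lo) / (qmax - qmin)
        zero_point = int(round(qmin - lo / scale))
        return scale, zero_point


rng = np.random.default_rng(0)
observer = MinMaxObserver()
for _ in range(20):  # 20 batches, mirroring num_batches=20 above
    # Synthetic stand-in for a post-ReLU feature map of shape (N, C, H, W)
    batch = np.maximum(rng.normal(size=(32, 64, 8, 8)), 0.0)
    observer.update(batch)

scale, zp = observer.qparams()
print(f"range=[{observer.running_min:.3f}, {observer.running_max:.3f}] "
      f"scale={scale:.5f} zero_point={zp}")
```

<p>Post-ReLU activations never go below zero, so the zero-point lands at 0 and all 256 levels cover the positive range. This is one reason a representative calibration set matters: unusual inputs that widen the range make every quantization step coarser.</p>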
</section>
<section id="post-training-dynamic-quantization-faster-no-calibration-data-needed" class="level4">
<h4 class="anchored" data-anchor-id="post-training-dynamic-quantization-faster-no-calibration-data-needed">Post-Training Dynamic Quantization (simpler, no calibration data needed)</h4>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb19"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb19-1"><a href="#cb19-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> onnxruntime.quantization <span class="im">import</span> quantize_dynamic, QuantType</span>
<span id="cb19-2"><a href="#cb19-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-3"><a href="#cb19-3" aria-hidden="true" tabindex="-1"></a>quantize_dynamic(</span>
<span id="cb19-4"><a href="#cb19-4" aria-hidden="true" tabindex="-1"></a>    model_input<span class="op">=</span><span class="st">"resnet18_cifar10.onnx"</span>,  <span class="co"># FP32 export, not the ENABLE_ALL graph</span></span>
<span id="cb19-5"><a href="#cb19-5" aria-hidden="true" tabindex="-1"></a>    model_output<span class="op">=</span><span class="st">"resnet18_cifar10_dynamic_int8.onnx"</span>,</span>
<span id="cb19-6"><a href="#cb19-6" aria-hidden="true" tabindex="-1"></a>    weight_type<span class="op">=</span>QuantType.QInt8,</span>
<span id="cb19-7"><a href="#cb19-7" aria-hidden="true" tabindex="-1"></a>    per_channel<span class="op">=</span><span class="va">True</span>,</span>
<span id="cb19-8"><a href="#cb19-8" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb19-9"><a href="#cb19-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-10"><a href="#cb19-10" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"Dynamic INT8 quantization complete ✓"</span>)</span></code></pre></div></div>
<div class="callout callout-style-default callout-tip callout-titled" title="Static vs. Dynamic Quantization">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Static vs.&nbsp;Dynamic Quantization
</div>
</div>
<div class="callout-body-container callout-body">
<p><strong>Dynamic quantization</strong> quantizes only the weights ahead of time. Activations are quantized on the fly at runtime from their observed range, so no calibration data is needed. It works well for MatMul/Gemm-heavy models such as transformers and RNNs, but is less effective for convolutions.</p>
<p><strong>Static quantization</strong> quantizes both weights and activations, using scales and zero-points precomputed from a calibration dataset. It usually gives faster inference, especially for CNNs, but requires a representative calibration set.</p>
</div>
</div>
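<p>To make the "weights offline, activations at runtime" split concrete, here is a NumPy sketch of a dynamically quantized linear layer. Weights are quantized once to symmetric int8 with one scale per output channel, as with <code>per_channel=True</code>. Each incoming activation tensor gets its own uint8 scale and zero-point, which is the behavior ONNX's <code>DynamicQuantizeLinear</code> operator describes. The accumulation runs in int32. This is an illustration of the technique under those assumptions, not ORT's kernel.</p>

```python
import numpy as np

rng = np.random.default_rng(0)

# Offline step: quantize the weights once (symmetric int8, per output channel)
W = rng.normal(size=(16, 64)).astype(np.float32)      # (out_features, in_features)
w_scale = np.abs(W).max(axis=1) / 127.0                 # one scale per output channel
W_q = np.round(W / w_scale[:, None]).astype(np.int8)


def dynamic_linear(x: np.ndarray) -> np.ndarray:
    """Approximate x @ W.T with int8 weights and activations quantized on the fly."""
    # Runtime step: derive a uint8 scale/zero-point from *this* input's range
    lo, hi = min(float(x.min()), 0.0), max(float(x.max()), 0.0)
    x_scale = (hi - lo) / 255.0
    x_zp = int(round(-lo / x_scale))
    x_q = np.clip(np.round(x / x_scale) + x_zp, 0, 255).astype(np.uint8)
    # Integer matmul with int32 accumulation, then rescale the result to float
    acc = (x_q.astype(np.int32) - x_zp) @ W_q.astype(np.int32).T
    return acc.astype(np.float32) * x_scale * w_scale[None, :]


x = rng.normal(size=(4, 64)).astype(np.float32)
y_ref = x @ W.T
y_q = dynamic_linear(x)
rel_err = np.abs(y_q - y_ref).max() / np.abs(y_ref).max()
print(f"max relative error: {rel_err:.4f}")
```

<p>Computing the activation range on every call is cheap for a MatMul but adds overhead per layer. Static quantization avoids that cost by fixing the ranges ahead of time.</p>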
</section>
</section>
<section id="pruning-before-export" class="level3">
<h3 class="anchored" data-anchor-id="pruning-before-export" id="pruning-before-export">Pruning Before Export</h3>
<p>For maximum compression, prune the model <em>before</em> exporting to ONNX. PyTorch’s <code>torch.nn.utils.prune</code> module makes this straightforward.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb20"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb20-1"><a href="#cb20-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb20-2"><a href="#cb20-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn.utils.prune <span class="im">as</span> prune</span>
<span id="cb20-3"><a href="#cb20-3" aria-hidden="true" tabindex="-1"></a></span>
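<span id="cb20-4"><a href="#cb20-4" aria-hidden="true" tabindex="-1"></a><span class="co"># `model` is the trained ResNet-18 from earlier in this guide</span></span>
<span id="cb20-5"><a href="#cb20-5" aria-hidden="true" tabindex="-1"></a><span class="co"># 1. Attach L1 (magnitude) pruning masks to every Conv2d layer</span></span>
<span id="cb20-6"><a href="#cb20-6" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> name, module <span class="kw">in</span> model.named_modules():</span>
<span id="cb20-7"><a href="#cb20-7" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> <span class="bu">isinstance</span>(module, torch.nn.Conv2d):</span>
<span id="cb20-8"><a href="#cb20-8" aria-hidden="true" tabindex="-1"></a>        prune.l1_unstructured(module, name<span class="op">=</span><span class="st">"weight"</span>, amount<span class="op">=</span><span class="fl">0.3</span>)  <span class="co"># 30% sparsity</span></span>
<span id="cb20-9"><a href="#cb20-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb20-10"><a href="#cb20-10" aria-hidden="true" tabindex="-1"></a><span class="co"># 2. Fine-tune for a few epochs while the masks are attached: a forward</span></span>
<span id="cb20-11"><a href="#cb20-11" aria-hidden="true" tabindex="-1"></a><span class="co">#    pre-hook re-applies each mask, so pruned weights stay at zero</span></span>
<span id="cb20-12"><a href="#cb20-12" aria-hidden="true" tabindex="-1"></a><span class="co"># ... (fine-tuning loop) ...</span></span>
<span id="cb20-13"><a href="#cb20-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb20-14"><a href="#cb20-14" aria-hidden="true" tabindex="-1"></a><span class="co"># 3. Make the pruning permanent (folds weight_orig * weight_mask into .weight)</span></span>
<span id="cb20-15"><a href="#cb20-15" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> module <span class="kw">in</span> model.modules():</span>
<span id="cb20-16"><a href="#cb20-16" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> <span class="bu">isinstance</span>(module, torch.nn.Conv2d):</span>
<span id="cb20-17"><a href="#cb20-17" aria-hidden="true" tabindex="-1"></a>        prune.remove(module, <span class="st">"weight"</span>)</span>
<span id="cb20-18"><a href="#cb20-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb20-19"><a href="#cb20-19" aria-hidden="true" tabindex="-1"></a><span class="co"># 4. Export to ONNX</span></span>
<span id="cb20-20"><a href="#cb20-20" aria-hidden="true" tabindex="-1"></a>model.<span class="bu">eval</span>()</span>
<span id="cb20-21"><a href="#cb20-21" aria-hidden="true" tabindex="-1"></a>dummy_input <span class="op">=</span> torch.randn(<span class="dv">1</span>, <span class="dv">3</span>, <span class="dv">32</span>, <span class="dv">32</span>)</span>
<span id="cb20-22"><a href="#cb20-22" aria-hidden="true" tabindex="-1"></a>torch.onnx.export(model, dummy_input, <span class="st">"resnet18_pruned.onnx"</span>,</span>
<span id="cb20-23"><a href="#cb20-23" aria-hidden="true" tabindex="-1"></a>                  input_names<span class="op">=</span>[<span class="st">"images"</span>], opset_version<span class="op">=</span><span class="dv">18</span>)</span></code></pre></div></div>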
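<p>Note that unstructured pruning only introduces <em>sparsity</em>. The zeroed weights are still stored in dense tensors, so the parameter count, the uncompressed ONNX file size, and the runtime of standard dense kernels are all unchanged. To get an actual speedup, you need either structured pruning (removing whole channels or filters, which shrinks the tensor shapes) or a runtime whose kernels exploit sparsity.</p>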
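<p>The difference is easy to see on a single Conv2d weight tensor. The NumPy sketch below applies both criteria at 30%. Unstructured pruning zeros the smallest individual weights, as <code>prune.l1_unstructured</code> does. Structured pruning drops the output channels with the smallest L1 norm. The random tensor is a stand-in for real trained weights. In a real network, removing output channels also means removing the matching input channels (and BatchNorm entries) of the next layer, which this sketch does not attempt.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(64, 32, 3, 3)).astype(np.float32)  # Conv2d weight: (C_out, C_in, kH, kW)

# Unstructured L1 pruning: zero the 30% of individual weights with the smallest |w|
k = int(0.3 * W.size)
threshold = np.sort(np.abs(W), axis=None)[k - 1]
mask = np.abs(W) > threshold
W_unstructured = W * mask

# Structured L1 pruning: drop the 30% of output channels with the smallest L1 norm
channel_norms = np.abs(W).sum(axis=(1, 2, 3))
n_keep = W.shape[0] - int(0.3 * W.shape[0])
keep = np.sort(np.argsort(channel_norms)[-n_keep:])  # keep original channel order
W_structured = W[keep]

print("unstructured:", W_unstructured.shape, f"{(W_unstructured == 0).mean():.0%} zeros")
print("structured:  ", W_structured.shape)
```

<p>The unstructured tensor keeps its full shape, so a dense kernel does exactly the same work. The structured tensor is genuinely smaller, which is where the speedup comes from.</p>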
<hr>
</section>
</section>
<section id="running-inference-with-onnx-runtime" class="level2">
<h2 class="anchored" data-anchor-id="running-inference-with-onnx-runtime" id="running-inference-with-onnx-runtime">Running Inference with ONNX Runtime</h2>
<section id="basic-inference-session" class="level3">
<h3 class="anchored" data-anchor-id="basic-inference-session" id="basic-inference-session">Basic Inference Session</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb21"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb21-1"><a href="#cb21-1" aria-hidden="true" tabindex="-1"></a><span class="co"># infer_basic.py</span></span>
<span id="cb21-2"><a href="#cb21-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-3"><a href="#cb21-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb21-4"><a href="#cb21-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> onnxruntime <span class="im">as</span> ort</span>
<span id="cb21-5"><a href="#cb21-5" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> PIL <span class="im">import</span> Image</span>
<span id="cb21-6"><a href="#cb21-6" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torchvision.transforms <span class="im">as</span> T</span>
<span id="cb21-7"><a href="#cb21-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-8"><a href="#cb21-8" aria-hidden="true" tabindex="-1"></a><span class="co"># ──────────────────────────────────────────────────────────────</span></span>
<span id="cb21-9"><a href="#cb21-9" aria-hidden="true" tabindex="-1"></a><span class="co"># 1. Create the inference session</span></span>
<span id="cb21-10"><a href="#cb21-10" aria-hidden="true" tabindex="-1"></a><span class="co"># ──────────────────────────────────────────────────────────────</span></span>
<span id="cb21-11"><a href="#cb21-11" aria-hidden="true" tabindex="-1"></a>sess <span class="op">=</span> ort.InferenceSession(</span>
<span id="cb21-12"><a href="#cb21-12" aria-hidden="true" tabindex="-1"></a>    <span class="st">"resnet18_cifar10.onnx"</span>,</span>
<span id="cb21-13"><a href="#cb21-13" aria-hidden="true" tabindex="-1"></a>    providers<span class="op">=</span>[<span class="st">"CPUExecutionProvider"</span>],</span>
<span id="cb21-14"><a href="#cb21-14" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb21-15"><a href="#cb21-15" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-16"><a href="#cb21-16" aria-hidden="true" tabindex="-1"></a><span class="co"># ──────────────────────────────────────────────────────────────</span></span>
<span id="cb21-17"><a href="#cb21-17" aria-hidden="true" tabindex="-1"></a><span class="co"># 2. Inspect input/output metadata</span></span>
<span id="cb21-18"><a href="#cb21-18" aria-hidden="true" tabindex="-1"></a><span class="co"># ──────────────────────────────────────────────────────────────</span></span>
<span id="cb21-19"><a href="#cb21-19" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> inp <span class="kw">in</span> sess.get_inputs():</span>
<span id="cb21-20"><a href="#cb21-20" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Input  name=</span><span class="sc">{</span>inp<span class="sc">.</span>name<span class="sc">!r}</span><span class="ss">  shape=</span><span class="sc">{</span>inp<span class="sc">.</span>shape<span class="sc">}</span><span class="ss">  dtype=</span><span class="sc">{</span>inp<span class="sc">.</span><span class="bu">type</span><span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb21-21"><a href="#cb21-21" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> out <span class="kw">in</span> sess.get_outputs():</span>
<span id="cb21-22"><a href="#cb21-22" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Output name=</span><span class="sc">{</span>out<span class="sc">.</span>name<span class="sc">!r}</span><span class="ss">  shape=</span><span class="sc">{</span>out<span class="sc">.</span>shape<span class="sc">}</span><span class="ss">  dtype=</span><span class="sc">{</span>out<span class="sc">.</span><span class="bu">type</span><span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb21-23"><a href="#cb21-23" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-24"><a href="#cb21-24" aria-hidden="true" tabindex="-1"></a><span class="co"># ──────────────────────────────────────────────────────────────</span></span>
<span id="cb21-25"><a href="#cb21-25" aria-hidden="true" tabindex="-1"></a><span class="co"># 3. Preprocess a single image</span></span>
<span id="cb21-26"><a href="#cb21-26" aria-hidden="true" tabindex="-1"></a><span class="co"># ──────────────────────────────────────────────────────────────</span></span>
<span id="cb21-27"><a href="#cb21-27" aria-hidden="true" tabindex="-1"></a>CLASSES <span class="op">=</span> [<span class="st">"airplane"</span>, <span class="st">"automobile"</span>, <span class="st">"bird"</span>, <span class="st">"cat"</span>, <span class="st">"deer"</span>,</span>
<span id="cb21-28"><a href="#cb21-28" aria-hidden="true" tabindex="-1"></a>           <span class="st">"dog"</span>, <span class="st">"frog"</span>, <span class="st">"horse"</span>, <span class="st">"ship"</span>, <span class="st">"truck"</span>]</span>
<span id="cb21-29"><a href="#cb21-29" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-30"><a href="#cb21-30" aria-hidden="true" tabindex="-1"></a>transform <span class="op">=</span> T.Compose([</span>
<span id="cb21-31"><a href="#cb21-31" aria-hidden="true" tabindex="-1"></a>    T.Resize((<span class="dv">32</span>, <span class="dv">32</span>)),</span>
<span id="cb21-32"><a href="#cb21-32" aria-hidden="true" tabindex="-1"></a>    T.ToTensor(),</span>
<span id="cb21-33"><a href="#cb21-33" aria-hidden="true" tabindex="-1"></a>    T.Normalize([<span class="fl">0.4914</span>, <span class="fl">0.4822</span>, <span class="fl">0.4465</span>], [<span class="fl">0.2023</span>, <span class="fl">0.1994</span>, <span class="fl">0.2010</span>]),</span>
<span id="cb21-34"><a href="#cb21-34" aria-hidden="true" tabindex="-1"></a>])</span>
<span id="cb21-35"><a href="#cb21-35" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-36"><a href="#cb21-36" aria-hidden="true" tabindex="-1"></a>image <span class="op">=</span> Image.<span class="bu">open</span>(<span class="st">"test_image.jpg"</span>).convert(<span class="st">"RGB"</span>)</span>
<span id="cb21-37"><a href="#cb21-37" aria-hidden="true" tabindex="-1"></a>tensor <span class="op">=</span> transform(image).unsqueeze(<span class="dv">0</span>).numpy()  <span class="co"># shape: (1, 3, 32, 32)</span></span>
<span id="cb21-38"><a href="#cb21-38" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-39"><a href="#cb21-39" aria-hidden="true" tabindex="-1"></a><span class="co"># ──────────────────────────────────────────────────────────────</span></span>
<span id="cb21-40"><a href="#cb21-40" aria-hidden="true" tabindex="-1"></a><span class="co"># 4. Run inference</span></span>
<span id="cb21-41"><a href="#cb21-41" aria-hidden="true" tabindex="-1"></a><span class="co"># ──────────────────────────────────────────────────────────────</span></span>
<span id="cb21-42"><a href="#cb21-42" aria-hidden="true" tabindex="-1"></a>input_name <span class="op">=</span> sess.get_inputs()[<span class="dv">0</span>].name   <span class="co"># "images"</span></span>
<span id="cb21-43"><a href="#cb21-43" aria-hidden="true" tabindex="-1"></a>outputs <span class="op">=</span> sess.run(<span class="va">None</span>, {input_name: tensor})</span>
<span id="cb21-44"><a href="#cb21-44" aria-hidden="true" tabindex="-1"></a>logits <span class="op">=</span> outputs[<span class="dv">0</span>]   <span class="co"># shape: (1, 10)</span></span>
<span id="cb21-45"><a href="#cb21-45" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-46"><a href="#cb21-46" aria-hidden="true" tabindex="-1"></a><span class="co"># ──────────────────────────────────────────────────────────────</span></span>
<span id="cb21-47"><a href="#cb21-47" aria-hidden="true" tabindex="-1"></a><span class="co"># 5. Decode prediction</span></span>
<span id="cb21-48"><a href="#cb21-48" aria-hidden="true" tabindex="-1"></a><span class="co"># ──────────────────────────────────────────────────────────────</span></span>
<span id="cb21-49"><a href="#cb21-49" aria-hidden="true" tabindex="-1"></a>shifted <span class="op">=</span> logits <span class="op">-</span> logits.<span class="bu">max</span>(axis<span class="op">=-</span><span class="dv">1</span>, keepdims<span class="op">=</span><span class="va">True</span>)  <span class="co"># subtract the max for numerical stability</span></span>
<span id="cb21-50"><a href="#cb21-50" aria-hidden="true" tabindex="-1"></a>probabilities <span class="op">=</span> np.exp(shifted) <span class="op">/</span> np.exp(shifted).<span class="bu">sum</span>(axis<span class="op">=-</span><span class="dv">1</span>, keepdims<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb21-51"><a href="#cb21-51" aria-hidden="true" tabindex="-1"></a>predicted_class <span class="op">=</span> probabilities.argmax(axis<span class="op">=-</span><span class="dv">1</span>)[<span class="dv">0</span>]</span>
<span id="cb21-52"><a href="#cb21-52" aria-hidden="true" tabindex="-1"></a>confidence <span class="op">=</span> probabilities[<span class="dv">0</span>, predicted_class]</span>
<span id="cb21-53"><a href="#cb21-53" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-54"><a href="#cb21-54" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Predicted: </span><span class="sc">{</span>CLASSES[predicted_class]<span class="sc">}</span><span class="ss"> (</span><span class="sc">{</span>confidence<span class="sc">:.1%}</span><span class="ss"> confidence)"</span>)</span></code></pre></div></div>
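<p>The decoding step turns logits into probabilities with a softmax. A naive <code>np.exp(logits)</code> overflows to <code>inf</code> in float32 once a logit exceeds roughly 88. Subtracting each row's maximum first leaves the result mathematically unchanged and avoids the overflow. The sketch below packages that into a reusable helper and adds a top-k decode, which is useful when you want to show the runner-up classes too. Both function names are illustrative.</p>

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax; subtracting the row max keeps np.exp from overflowing."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def top_k(probs: np.ndarray, k: int = 3):
    """Return (indices, probabilities) of the k most likely classes per row, best first."""
    idx = np.argsort(probs, axis=-1)[:, ::-1][:, :k]
    return idx, np.take_along_axis(probs, idx, axis=-1)


# Large logits: np.exp(logits) on its own would overflow to inf here
logits = np.array([[1000.0, 1001.0, 999.0]], dtype=np.float32)
probs = softmax(logits)
idx, p = top_k(probs, k=2)
print(idx, p)
```

<p>With the CIFAR-10 example above, <code>[CLASSES[i] for i in idx[0]]</code> would map the indices back to class names.</p>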
</section>
<section id="configuring-session-options" class="level3">
<h3 class="anchored" data-anchor-id="configuring-session-options" id="configuring-session-options">Configuring Session Options</h3>
<p><code>SessionOptions</code> is how you tune ORT’s behavior:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb22"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb22-1"><a href="#cb22-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> onnxruntime <span class="im">as</span> ort</span>
<span id="cb22-2"><a href="#cb22-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb22-3"><a href="#cb22-3" aria-hidden="true" tabindex="-1"></a>opts <span class="op">=</span> ort.SessionOptions()</span>
<span id="cb22-4"><a href="#cb22-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb22-5"><a href="#cb22-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Threading</span></span>
<span id="cb22-6"><a href="#cb22-6" aria-hidden="true" tabindex="-1"></a>opts.intra_op_num_threads <span class="op">=</span> <span class="dv">4</span>   <span class="co"># threads within a single operator (e.g., matrix mul)</span></span>
<span id="cb22-7"><a href="#cb22-7" aria-hidden="true" tabindex="-1"></a>opts.inter_op_num_threads <span class="op">=</span> <span class="dv">2</span>   <span class="co"># threads across independent operators; only used with ORT_PARALLEL execution mode</span></span>
<span id="cb22-8"><a href="#cb22-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb22-9"><a href="#cb22-9" aria-hidden="true" tabindex="-1"></a><span class="co"># Memory</span></span>
<span id="cb22-10"><a href="#cb22-10" aria-hidden="true" tabindex="-1"></a>opts.enable_cpu_mem_arena <span class="op">=</span> <span class="va">True</span>          <span class="co"># pre-allocate a memory arena</span></span>
<span id="cb22-11"><a href="#cb22-11" aria-hidden="true" tabindex="-1"></a>opts.enable_mem_pattern   <span class="op">=</span> <span class="va">True</span>          <span class="co"># pre-plan allocations from input shapes (best with fixed shapes)</span></span>
<span id="cb22-12"><a href="#cb22-12" aria-hidden="true" tabindex="-1"></a>opts.enable_mem_reuse     <span class="op">=</span> <span class="va">True</span>          <span class="co"># reuse intermediate buffers between operators</span></span>
<span id="cb22-13"><a href="#cb22-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb22-14"><a href="#cb22-14" aria-hidden="true" tabindex="-1"></a><span class="co"># Logging</span></span>
<span id="cb22-15"><a href="#cb22-15" aria-hidden="true" tabindex="-1"></a>opts.log_severity_level <span class="op">=</span> <span class="dv">3</span>   <span class="co"># 0=VERBOSE, 1=INFO, 2=</span><span class="al">WARNING</span><span class="co">, 3=ERROR, 4=FATAL</span></span>
<span id="cb22-16"><a href="#cb22-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb22-17"><a href="#cb22-17" aria-hidden="true" tabindex="-1"></a><span class="co"># Graph optimization (see above for levels)</span></span>
<span id="cb22-18"><a href="#cb22-18" aria-hidden="true" tabindex="-1"></a>opts.graph_optimization_level <span class="op">=</span> ort.GraphOptimizationLevel.ORT_ENABLE_ALL</span>
<span id="cb22-19"><a href="#cb22-19" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb22-20"><a href="#cb22-20" aria-hidden="true" tabindex="-1"></a><span class="co"># Profiling — dumps a JSON Chrome trace file</span></span>
<span id="cb22-21"><a href="#cb22-21" aria-hidden="true" tabindex="-1"></a>opts.enable_profiling <span class="op">=</span> <span class="va">False</span></span>
<span id="cb22-22"><a href="#cb22-22" aria-hidden="true" tabindex="-1"></a><span class="co"># opts.profile_file_prefix = "ort_profile"</span></span>
<span id="cb22-23"><a href="#cb22-23" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb22-24"><a href="#cb22-24" aria-hidden="true" tabindex="-1"></a>sess <span class="op">=</span> ort.InferenceSession(</span>
<span id="cb22-25"><a href="#cb22-25" aria-hidden="true" tabindex="-1"></a>    <span class="st">"resnet18_cifar10.onnx"</span>,</span>
<span id="cb22-26"><a href="#cb22-26" aria-hidden="true" tabindex="-1"></a>    sess_options<span class="op">=</span>opts,</span>
<span id="cb22-27"><a href="#cb22-27" aria-hidden="true" tabindex="-1"></a>    providers<span class="op">=</span>[<span class="st">"CPUExecutionProvider"</span>],</span>
<span id="cb22-28"><a href="#cb22-28" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
</section>
<section id="execution-providers" class="level3">
<h3 class="anchored" data-anchor-id="execution-providers" id="execution-providers">Execution Providers</h3>
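<p>When the session is created, ORT partitions the model graph across the execution providers (EPs) in the priority order you list them: each EP claims the nodes it can run, and nodes an EP cannot handle fall through to the next EP in the list. The CPU EP serves as the final fallback and covers the standard ONNX operator set. If no EP implements a node, session creation fails; the error does not surface later during <code>run()</code>.</p>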
<div class="cell" data-eval="true" data-layout-align="center">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">flowchart TD
    A["Inference Request"] --&gt; B["Try EP #1 e.g. CUDAExecutionProvider"]
    B --&gt; C{"Operator supported?"}
    C -- Yes --&gt; D["Run on GPU"]
    C -- No --&gt; E["Try EP #2 e.g. CPUExecutionProvider"]
    E --&gt; F{"Operator supported?"}
    F -- Yes --&gt; G["Run on CPU"]
    F -- No --&gt; H["RuntimeError: No EP can handle operator"]
    D --&gt; I["Output Tensor"]
    G --&gt; I
</pre>
</div>
<p></p></figure><p></p>
</div>
</div>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb23"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb23-1"><a href="#cb23-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> onnxruntime <span class="im">as</span> ort</span>
<span id="cb23-2"><a href="#cb23-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb23-3"><a href="#cb23-3" aria-hidden="true" tabindex="-1"></a><span class="co"># List EPs available on this machine</span></span>
<span id="cb23-4"><a href="#cb23-4" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(ort.get_available_providers())</span>
<span id="cb23-5"><a href="#cb23-5" aria-hidden="true" tabindex="-1"></a><span class="co"># e.g.: ['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider']</span></span>
<span id="cb23-6"><a href="#cb23-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb23-7"><a href="#cb23-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Use CUDA with CPU fallback</span></span>
<span id="cb23-8"><a href="#cb23-8" aria-hidden="true" tabindex="-1"></a>sess <span class="op">=</span> ort.InferenceSession(</span>
<span id="cb23-9"><a href="#cb23-9" aria-hidden="true" tabindex="-1"></a>    <span class="st">"resnet18_cifar10.onnx"</span>,</span>
<span id="cb23-10"><a href="#cb23-10" aria-hidden="true" tabindex="-1"></a>    providers<span class="op">=</span>[</span>
<span id="cb23-11"><a href="#cb23-11" aria-hidden="true" tabindex="-1"></a>        (<span class="st">"CUDAExecutionProvider"</span>, {</span>
<span id="cb23-12"><a href="#cb23-12" aria-hidden="true" tabindex="-1"></a>            <span class="st">"device_id"</span>: <span class="dv">0</span>,</span>
<span id="cb23-13"><a href="#cb23-13" aria-hidden="true" tabindex="-1"></a>            <span class="st">"arena_extend_strategy"</span>: <span class="st">"kNextPowerOfTwo"</span>,</span>
<span id="cb23-14"><a href="#cb23-14" aria-hidden="true" tabindex="-1"></a>            <span class="st">"gpu_mem_limit"</span>: <span class="dv">2</span> <span class="op">*</span> <span class="dv">1024</span> <span class="op">**</span> <span class="dv">3</span>,   <span class="co"># 2 GB</span></span>
<span id="cb23-15"><a href="#cb23-15" aria-hidden="true" tabindex="-1"></a>            <span class="st">"cudnn_conv_algo_search"</span>: <span class="st">"EXHAUSTIVE"</span>,</span>
<span id="cb23-16"><a href="#cb23-16" aria-hidden="true" tabindex="-1"></a>            <span class="st">"do_copy_in_default_stream"</span>: <span class="va">True</span>,</span>
<span id="cb23-17"><a href="#cb23-17" aria-hidden="true" tabindex="-1"></a>        }),</span>
<span id="cb23-18"><a href="#cb23-18" aria-hidden="true" tabindex="-1"></a>        <span class="st">"CPUExecutionProvider"</span>,</span>
<span id="cb23-19"><a href="#cb23-19" aria-hidden="true" tabindex="-1"></a>    ],</span>
<span id="cb23-20"><a href="#cb23-20" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
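<p>To make the priority rule concrete, here is a toy model of the partitioning step in plain Python. It is not ORT's actual partitioner, which queries each EP for the subgraphs it can take. The capability table <code>SUPPORTED</code> and the op name <code>CustomPostprocess</code> are hypothetical. The sketch shows only the rule itself: each node goes to the first EP in the list that claims it, and a node that no EP claims is an error.</p>

```python
from typing import Dict, List, Set

# Hypothetical capability table for illustration only. Real EPs report support
# per node (op type, opset, data types), not via a fixed list like this.
SUPPORTED: Dict[str, Set[str]] = {
    "CUDAExecutionProvider": {"Conv", "Relu", "MaxPool", "Gemm"},
    "CPUExecutionProvider": {"Conv", "Relu", "MaxPool", "Gemm", "CustomPostprocess"},
}


def assign_nodes(nodes: List[str], providers: List[str]) -> Dict[int, str]:
    """Toy partitioner: each node goes to the first provider, in list order, that supports it."""
    assignment: Dict[int, str] = {}
    for idx, op_type in enumerate(nodes):
        for ep in providers:
            if op_type in SUPPORTED.get(ep, set()):
                assignment[idx] = ep
                break
        else:
            # No provider claimed the node: analogous to session creation failing
            raise RuntimeError(f"No execution provider implements {op_type!r}")
    return assignment


graph = ["Conv", "Relu", "MaxPool", "CustomPostprocess", "Gemm"]
plan = assign_nodes(graph, ["CUDAExecutionProvider", "CPUExecutionProvider"])
for idx, ep in plan.items():
    print(graph[idx].ljust(18), "->", ep)
```

<p>If you swap the list so the CPU EP comes first, it claims every node and the GPU is never used. This is why the accelerator EP has to be listed before the CPU EP.</p>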
</section>
<section id="gpu-inference-with-cuda" class="level3">
<h3 class="anchored" data-anchor-id="gpu-inference-with-cuda" id="gpu-inference-with-cuda">GPU Inference with CUDA</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb24"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb24-1"><a href="#cb24-1" aria-hidden="true" tabindex="-1"></a><span class="co"># gpu_inference.py</span></span>
<span id="cb24-2"><a href="#cb24-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb24-3"><a href="#cb24-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb24-4"><a href="#cb24-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> onnxruntime <span class="im">as</span> ort</span>
<span id="cb24-5"><a href="#cb24-5" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb24-6"><a href="#cb24-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb24-7"><a href="#cb24-7" aria-hidden="true" tabindex="-1"></a><span class="co"># ── Create CUDA session ──</span></span>
<span id="cb24-8"><a href="#cb24-8" aria-hidden="true" tabindex="-1"></a>providers <span class="op">=</span> [</span>
<span id="cb24-9"><a href="#cb24-9" aria-hidden="true" tabindex="-1"></a>    (<span class="st">"CUDAExecutionProvider"</span>, {<span class="st">"device_id"</span>: <span class="dv">0</span>}),</span>
<span id="cb24-10"><a href="#cb24-10" aria-hidden="true" tabindex="-1"></a>    <span class="st">"CPUExecutionProvider"</span>,</span>
<span id="cb24-11"><a href="#cb24-11" aria-hidden="true" tabindex="-1"></a>]</span>
<span id="cb24-12"><a href="#cb24-12" aria-hidden="true" tabindex="-1"></a>sess <span class="op">=</span> ort.InferenceSession(<span class="st">"resnet18_cifar10.onnx"</span>, providers<span class="op">=</span>providers)</span>
<span id="cb24-13"><a href="#cb24-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb24-14"><a href="#cb24-14" aria-hidden="true" tabindex="-1"></a><span class="co"># Confirm which EPs are active; only CPUExecutionProvider means CUDA failed to load</span></span>
<span id="cb24-15"><a href="#cb24-15" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"Active providers:"</span>, sess.get_providers())</span>
<span id="cb24-16"><a href="#cb24-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb24-17"><a href="#cb24-17" aria-hidden="true" tabindex="-1"></a><span class="co"># ── IO Binding: zero-copy for GPU tensors ──</span></span>
<span id="cb24-18"><a href="#cb24-18" aria-hidden="true" tabindex="-1"></a><span class="co"># This avoids an implicit host↔device copy for each run() call.</span></span>
<span id="cb24-19"><a href="#cb24-19" aria-hidden="true" tabindex="-1"></a>io_binding <span class="op">=</span> sess.io_binding()</span>
<span id="cb24-20"><a href="#cb24-20" aria-hidden="true" tabindex="-1"></a><span class="co"># Allocate input tensor directly on GPU</span></span>
<span id="cb24-21"><a href="#cb24-21" aria-hidden="true" tabindex="-1"></a>gpu_input <span class="op">=</span> torch.randn(<span class="dv">8</span>, <span class="dv">3</span>, <span class="dv">32</span>, <span class="dv">32</span>, device<span class="op">=</span><span class="st">"cuda"</span>).contiguous()</span>
<span id="cb24-22"><a href="#cb24-22" aria-hidden="true" tabindex="-1"></a>torch.cuda.synchronize()  <span class="co"># torch kernels run asynchronously; finish writing the input before ORT reads it</span></span>
<span id="cb24-23"><a href="#cb24-23" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb24-24"><a href="#cb24-24" aria-hidden="true" tabindex="-1"></a>io_binding.bind_input(</span>
<span id="cb24-25"><a href="#cb24-25" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"images"</span>,</span>
<span id="cb24-26"><a href="#cb24-26" aria-hidden="true" tabindex="-1"></a>    device_type<span class="op">=</span><span class="st">"cuda"</span>,</span>
<span id="cb24-27"><a href="#cb24-27" aria-hidden="true" tabindex="-1"></a>    device_id<span class="op">=</span><span class="dv">0</span>,</span>
<span id="cb24-28"><a href="#cb24-28" aria-hidden="true" tabindex="-1"></a>    element_type<span class="op">=</span>np.float32,</span>
<span id="cb24-29"><a href="#cb24-29" aria-hidden="true" tabindex="-1"></a>    shape<span class="op">=</span><span class="bu">tuple</span>(gpu_input.shape),</span>
<span id="cb24-30"><a href="#cb24-30" aria-hidden="true" tabindex="-1"></a>    buffer_ptr<span class="op">=</span>gpu_input.data_ptr(),</span>
<span id="cb24-31"><a href="#cb24-31" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb24-32"><a href="#cb24-32" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb24-33"><a href="#cb24-33" aria-hidden="true" tabindex="-1"></a><span class="co"># Allocate output tensor</span></span>
<span id="cb24-34"><a href="#cb24-34" aria-hidden="true" tabindex="-1"></a>gpu_output <span class="op">=</span> torch.empty(<span class="dv">8</span>, <span class="dv">10</span>, device<span class="op">=</span><span class="st">"cuda"</span>).contiguous()</span>
<span id="cb24-35"><a href="#cb24-35" aria-hidden="true" tabindex="-1"></a>io_binding.bind_output(</span>
<span id="cb24-36"><a href="#cb24-36" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"logits"</span>,</span>
<span id="cb24-37"><a href="#cb24-37" aria-hidden="true" tabindex="-1"></a>    device_type<span class="op">=</span><span class="st">"cuda"</span>,</span>
<span id="cb24-38"><a href="#cb24-38" aria-hidden="true" tabindex="-1"></a>    device_id<span class="op">=</span><span class="dv">0</span>,</span>
<span id="cb24-39"><a href="#cb24-39" aria-hidden="true" tabindex="-1"></a>    element_type<span class="op">=</span>np.float32,</span>
<span id="cb24-40"><a href="#cb24-40" aria-hidden="true" tabindex="-1"></a>    shape<span class="op">=</span>(<span class="dv">8</span>, <span class="dv">10</span>),</span>
<span id="cb24-41"><a href="#cb24-41" aria-hidden="true" tabindex="-1"></a>    buffer_ptr<span class="op">=</span>gpu_output.data_ptr(),</span>
<span id="cb24-42"><a href="#cb24-42" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb24-43"><a href="#cb24-43" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb24-44"><a href="#cb24-44" aria-hidden="true" tabindex="-1"></a><span class="co"># Run without any host↔device copies</span></span>
<span id="cb24-45"><a href="#cb24-45" aria-hidden="true" tabindex="-1"></a>sess.run_with_iobinding(io_binding)</span>
<span id="cb24-46"><a href="#cb24-46" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb24-47"><a href="#cb24-47" aria-hidden="true" tabindex="-1"></a>logits <span class="op">=</span> gpu_output.cpu().numpy()</span>
<span id="cb24-48"><a href="#cb24-48" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"GPU inference output shape:"</span>, logits.shape)</span></code></pre></div></div>
</section>
<section id="batch-inference" class="level3">
<h3 class="anchored" data-anchor-id="batch-inference" id="batch-inference">Batch Inference</h3>
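<p>Processing images in batches amortizes per-call overhead (Python-to-runtime dispatch, kernel launches) across many samples and generally improves hardware utilization. The loop below assumes the model was exported with a dynamic batch dimension: the final batch is usually smaller than <code>batch_size</code>, and a model with a fixed batch dimension will reject it.</p>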
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb25"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb25-1"><a href="#cb25-1" aria-hidden="true" tabindex="-1"></a><span class="co"># batch_inference.py</span></span>
<span id="cb25-2"><a href="#cb25-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb25-3"><a href="#cb25-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb25-4"><a href="#cb25-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> onnxruntime <span class="im">as</span> ort</span>
<span id="cb25-5"><a href="#cb25-5" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> pathlib <span class="im">import</span> Path</span>
<span id="cb25-6"><a href="#cb25-6" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> PIL <span class="im">import</span> Image</span>
<span id="cb25-7"><a href="#cb25-7" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torchvision.transforms <span class="im">as</span> T</span>
<span id="cb25-8"><a href="#cb25-8" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> typing <span class="im">import</span> List</span>
<span id="cb25-9"><a href="#cb25-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb25-10"><a href="#cb25-10" aria-hidden="true" tabindex="-1"></a>CLASSES <span class="op">=</span> [<span class="st">"airplane"</span>, <span class="st">"automobile"</span>, <span class="st">"bird"</span>, <span class="st">"cat"</span>, <span class="st">"deer"</span>,</span>
<span id="cb25-11"><a href="#cb25-11" aria-hidden="true" tabindex="-1"></a>           <span class="st">"dog"</span>, <span class="st">"frog"</span>, <span class="st">"horse"</span>, <span class="st">"ship"</span>, <span class="st">"truck"</span>]</span>
<span id="cb25-12"><a href="#cb25-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb25-13"><a href="#cb25-13" aria-hidden="true" tabindex="-1"></a>transform <span class="op">=</span> T.Compose([</span>
<span id="cb25-14"><a href="#cb25-14" aria-hidden="true" tabindex="-1"></a>    T.Resize((<span class="dv">32</span>, <span class="dv">32</span>)),</span>
<span id="cb25-15"><a href="#cb25-15" aria-hidden="true" tabindex="-1"></a>    T.ToTensor(),</span>
<span id="cb25-16"><a href="#cb25-16" aria-hidden="true" tabindex="-1"></a>    T.Normalize([<span class="fl">0.4914</span>, <span class="fl">0.4822</span>, <span class="fl">0.4465</span>], [<span class="fl">0.2023</span>, <span class="fl">0.1994</span>, <span class="fl">0.2010</span>]),</span>
<span id="cb25-17"><a href="#cb25-17" aria-hidden="true" tabindex="-1"></a>])</span>
<span id="cb25-18"><a href="#cb25-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb25-19"><a href="#cb25-19" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> preprocess_batch(image_paths: List[Path]) <span class="op">-&gt;</span> np.ndarray:</span>
<span id="cb25-20"><a href="#cb25-20" aria-hidden="true" tabindex="-1"></a>    tensors <span class="op">=</span> []</span>
<span id="cb25-21"><a href="#cb25-21" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> p <span class="kw">in</span> image_paths:</span>
<span id="cb25-22"><a href="#cb25-22" aria-hidden="true" tabindex="-1"></a>        img <span class="op">=</span> Image.<span class="bu">open</span>(p).convert(<span class="st">"RGB"</span>)</span>
<span id="cb25-23"><a href="#cb25-23" aria-hidden="true" tabindex="-1"></a>        tensors.append(transform(img).numpy())</span>
<span id="cb25-24"><a href="#cb25-24" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> np.stack(tensors, axis<span class="op">=</span><span class="dv">0</span>)   <span class="co"># (N, 3, 32, 32)</span></span>
<span id="cb25-25"><a href="#cb25-25" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb25-26"><a href="#cb25-26" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb25-27"><a href="#cb25-27" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> infer_batch(sess: ort.InferenceSession,</span>
<span id="cb25-28"><a href="#cb25-28" aria-hidden="true" tabindex="-1"></a>                image_paths: List[Path],</span>
<span id="cb25-29"><a href="#cb25-29" aria-hidden="true" tabindex="-1"></a>                batch_size: <span class="bu">int</span> <span class="op">=</span> <span class="dv">32</span>) <span class="op">-&gt;</span> List[<span class="bu">str</span>]:</span>
<span id="cb25-30"><a href="#cb25-30" aria-hidden="true" tabindex="-1"></a>    input_name <span class="op">=</span> sess.get_inputs()[<span class="dv">0</span>].name</span>
<span id="cb25-31"><a href="#cb25-31" aria-hidden="true" tabindex="-1"></a>    predictions <span class="op">=</span> []</span>
<span id="cb25-32"><a href="#cb25-32" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb25-33"><a href="#cb25-33" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">0</span>, <span class="bu">len</span>(image_paths), batch_size):</span>
<span id="cb25-34"><a href="#cb25-34" aria-hidden="true" tabindex="-1"></a>        batch_paths <span class="op">=</span> image_paths[i : i <span class="op">+</span> batch_size]</span>
<span id="cb25-35"><a href="#cb25-35" aria-hidden="true" tabindex="-1"></a>        batch_np <span class="op">=</span> preprocess_batch(batch_paths)</span>
<span id="cb25-36"><a href="#cb25-36" aria-hidden="true" tabindex="-1"></a>        logits <span class="op">=</span> sess.run(<span class="va">None</span>, {input_name: batch_np})[<span class="dv">0</span>]</span>
<span id="cb25-37"><a href="#cb25-37" aria-hidden="true" tabindex="-1"></a>        batch_preds <span class="op">=</span> logits.argmax(axis<span class="op">=-</span><span class="dv">1</span>).tolist()</span>
<span id="cb25-38"><a href="#cb25-38" aria-hidden="true" tabindex="-1"></a>        predictions.extend([CLASSES[p] <span class="cf">for</span> p <span class="kw">in</span> batch_preds])</span>
<span id="cb25-39"><a href="#cb25-39" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb25-40"><a href="#cb25-40" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> predictions</span>
<span id="cb25-41"><a href="#cb25-41" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb25-42"><a href="#cb25-42" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb25-43"><a href="#cb25-43" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage</span></span>
<span id="cb25-44"><a href="#cb25-44" aria-hidden="true" tabindex="-1"></a>sess <span class="op">=</span> ort.InferenceSession(<span class="st">"resnet18_cifar10.onnx"</span>,</span>
<span id="cb25-45"><a href="#cb25-45" aria-hidden="true" tabindex="-1"></a>                             providers<span class="op">=</span>[<span class="st">"CPUExecutionProvider"</span>])</span>
<span id="cb25-46"><a href="#cb25-46" aria-hidden="true" tabindex="-1"></a>image_files <span class="op">=</span> <span class="bu">list</span>(Path(<span class="st">"./test_images"</span>).glob(<span class="st">"*.jpg"</span>))</span>
<span id="cb25-47"><a href="#cb25-47" aria-hidden="true" tabindex="-1"></a>results <span class="op">=</span> infer_batch(sess, image_files, batch_size<span class="op">=</span><span class="dv">64</span>)</span>
<span id="cb25-48"><a href="#cb25-48" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> path, pred <span class="kw">in</span> <span class="bu">zip</span>(image_files, results):</span>
<span id="cb25-49"><a href="#cb25-49" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"</span><span class="sc">{</span>path<span class="sc">.</span>name<span class="sc">}</span><span class="ss">: </span><span class="sc">{</span>pred<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
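<p>If the graph was instead exported with a fixed batch dimension, one workaround is to pad the final chunk up to the fixed size and drop the padded rows from the output. The following is a minimal sketch under that assumption. <code>pad_to_batch</code>, <code>run_fixed_batch</code> and <code>fake_model</code> are hypothetical helpers, and <code>run_model</code> stands in for a call such as <code>lambda b: sess.run(None, {input_name: b})[0]</code>.</p>

```python
import numpy as np
from typing import Callable


def pad_to_batch(batch: np.ndarray, batch_size: int) -> np.ndarray:
    """Pad along axis 0 by repeating the last sample until the batch has batch_size rows."""
    n = batch.shape[0]
    if n == batch_size:
        return batch
    if n > batch_size:
        raise ValueError(f"batch has {n} samples, more than batch_size={batch_size}")
    filler = np.repeat(batch[-1:], batch_size - n, axis=0)
    return np.concatenate([batch, filler], axis=0)


def run_fixed_batch(run_model: Callable[[np.ndarray], np.ndarray],
                    inputs: np.ndarray, batch_size: int) -> np.ndarray:
    """Run a fixed-batch model over all inputs, discarding outputs for padded rows."""
    outputs = []
    for start in range(0, inputs.shape[0], batch_size):
        chunk = inputs[start:start + batch_size]
        out = run_model(pad_to_batch(chunk, batch_size))
        outputs.append(out[: chunk.shape[0]])  # keep only the real samples
    return np.concatenate(outputs, axis=0)


# Hypothetical stand-in model: accepts exactly 4 samples, returns 10 "logits" each
def fake_model(x: np.ndarray) -> np.ndarray:
    assert x.shape[0] == 4, "fixed batch dimension violated"
    return x.reshape(x.shape[0], -1)[:, :10]


images = np.random.rand(10, 3, 32, 32).astype(np.float32)
logits = run_fixed_batch(fake_model, images, batch_size=4)
print(logits.shape)  # (10, 10)
```

<p>Repeating a real sample rather than padding with zeros keeps the filler rows in-distribution. Either choice works, though, because those outputs are discarded.</p>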
<hr>
</section>
</section>
<section id="preprocessing-and-postprocessing-pipelines" class="level2">
<h2 class="anchored" data-anchor-id="preprocessing-and-postprocessing-pipelines" id="preprocessing-and-postprocessing-pipelines">Preprocessing and Postprocessing Pipelines</h2>
<section id="image-classification" class="level3">
<h3 class="anchored" data-anchor-id="image-classification" id="image-classification">Image Classification</h3>
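<p>The complete classification pipeline, including softmax and top-k decoding, is shown below. For clarity the function creates its session on every call. Session creation (including graph optimization) is expensive, so in a service you would create the session once and reuse it.</p>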
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb26"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb26-1"><a href="#cb26-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb26-2"><a href="#cb26-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> onnxruntime <span class="im">as</span> ort</span>
<span id="cb26-3"><a href="#cb26-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> PIL <span class="im">import</span> Image</span>
<span id="cb26-4"><a href="#cb26-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torchvision.transforms <span class="im">as</span> T</span>
<span id="cb26-5"><a href="#cb26-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb26-6"><a href="#cb26-6" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> softmax(x: np.ndarray, axis: <span class="bu">int</span> <span class="op">=</span> <span class="op">-</span><span class="dv">1</span>) <span class="op">-&gt;</span> np.ndarray:</span>
<span id="cb26-7"><a href="#cb26-7" aria-hidden="true" tabindex="-1"></a>    e <span class="op">=</span> np.exp(x <span class="op">-</span> x.<span class="bu">max</span>(axis<span class="op">=</span>axis, keepdims<span class="op">=</span><span class="va">True</span>))</span>
<span id="cb26-8"><a href="#cb26-8" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> e <span class="op">/</span> e.<span class="bu">sum</span>(axis<span class="op">=</span>axis, keepdims<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb26-9"><a href="#cb26-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb26-10"><a href="#cb26-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb26-11"><a href="#cb26-11" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> classify_topk(model_path: <span class="bu">str</span>, image_path: <span class="bu">str</span>,</span>
<span id="cb26-12"><a href="#cb26-12" aria-hidden="true" tabindex="-1"></a>                  class_names: <span class="bu">list</span>, k: <span class="bu">int</span> <span class="op">=</span> <span class="dv">5</span>):</span>
<span id="cb26-13"><a href="#cb26-13" aria-hidden="true" tabindex="-1"></a>    sess <span class="op">=</span> ort.InferenceSession(model_path,</span>
<span id="cb26-14"><a href="#cb26-14" aria-hidden="true" tabindex="-1"></a>                                 providers<span class="op">=</span>[<span class="st">"CPUExecutionProvider"</span>])</span>
<span id="cb26-15"><a href="#cb26-15" aria-hidden="true" tabindex="-1"></a>    transform <span class="op">=</span> T.Compose([</span>
<span id="cb26-16"><a href="#cb26-16" aria-hidden="true" tabindex="-1"></a>        T.Resize((<span class="dv">32</span>, <span class="dv">32</span>)),</span>
<span id="cb26-17"><a href="#cb26-17" aria-hidden="true" tabindex="-1"></a>        T.ToTensor(),</span>
<span id="cb26-18"><a href="#cb26-18" aria-hidden="true" tabindex="-1"></a>        T.Normalize([<span class="fl">0.4914</span>, <span class="fl">0.4822</span>, <span class="fl">0.4465</span>], [<span class="fl">0.2023</span>, <span class="fl">0.1994</span>, <span class="fl">0.2010</span>]),</span>
<span id="cb26-19"><a href="#cb26-19" aria-hidden="true" tabindex="-1"></a>    ])</span>
<span id="cb26-20"><a href="#cb26-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb26-21"><a href="#cb26-21" aria-hidden="true" tabindex="-1"></a>    img <span class="op">=</span> Image.<span class="bu">open</span>(image_path).convert(<span class="st">"RGB"</span>)</span>
<span id="cb26-22"><a href="#cb26-22" aria-hidden="true" tabindex="-1"></a>    inp <span class="op">=</span> transform(img).unsqueeze(<span class="dv">0</span>).numpy()</span>
<span id="cb26-23"><a href="#cb26-23" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb26-24"><a href="#cb26-24" aria-hidden="true" tabindex="-1"></a>    logits <span class="op">=</span> sess.run(<span class="va">None</span>, {<span class="st">"images"</span>: inp})[<span class="dv">0</span>][<span class="dv">0</span>]   <span class="co"># (10,)</span></span>
<span id="cb26-25"><a href="#cb26-25" aria-hidden="true" tabindex="-1"></a>    probs  <span class="op">=</span> softmax(logits)</span>
<span id="cb26-26"><a href="#cb26-26" aria-hidden="true" tabindex="-1"></a>    top_k  <span class="op">=</span> probs.argsort()[::<span class="op">-</span><span class="dv">1</span>][:k]</span>
<span id="cb26-27"><a href="#cb26-27" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb26-28"><a href="#cb26-28" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Top-</span><span class="sc">{</span>k<span class="sc">}</span><span class="ss"> predictions:"</span>)</span>
<span id="cb26-29"><a href="#cb26-29" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> rank, idx <span class="kw">in</span> <span class="bu">enumerate</span>(top_k, <span class="dv">1</span>):</span>
<span id="cb26-30"><a href="#cb26-30" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"  </span><span class="sc">{</span>rank<span class="sc">}</span><span class="ss">. </span><span class="sc">{</span>class_names[idx]<span class="sc">:&lt;15}</span><span class="ss"> </span><span class="sc">{</span>probs[idx]<span class="sc">:.2%}</span><span class="ss">"</span>)</span></code></pre></div></div>
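<p><code>classify_topk</code> decodes one image at a time. When <code>sess.run</code> returns logits for a whole batch, the same decoding can be vectorized. The sketch below is an illustrative alternative, not part of the original pipeline. <code>batch_topk</code> is a hypothetical helper that uses <code>np.argpartition</code> to find the k largest probabilities per row without a full sort, then orders only those k.</p>

```python
import numpy as np
from typing import List, Tuple


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtract the row max for numerical stability before exponentiating
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def batch_topk(logits: np.ndarray, k: int = 5) -> Tuple[np.ndarray, np.ndarray]:
    """Return (indices, probabilities) of the top-k classes per row, highest first."""
    probs = softmax(logits, axis=-1)
    # argpartition moves the k largest values into the last k slots (unordered)
    part = np.argpartition(probs, -k, axis=-1)[:, -k:]
    part_probs = np.take_along_axis(probs, part, axis=-1)
    # Sort only those k entries in descending order
    order = np.argsort(-part_probs, axis=-1)
    top_idx = np.take_along_axis(part, order, axis=-1)
    top_probs = np.take_along_axis(part_probs, order, axis=-1)
    return top_idx, top_probs


CLASSES: List[str] = ["airplane", "automobile", "bird", "cat", "deer",
                      "dog", "frog", "horse", "ship", "truck"]

logits = np.array([[0.1, 2.0, 0.3, 5.0, 0.0, 1.0, 0.2, 0.4, 0.5, 3.0],
                   [4.0, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 3.5, 0.7]], dtype=np.float32)
idx, p = batch_topk(logits, k=3)
for row_idx, row_p in zip(idx, p):
    print([(CLASSES[i], round(float(q), 3)) for i, q in zip(row_idx, row_p)])
```

<p>With only ten CIFAR-10 classes a full <code>argsort</code> is just as fast. <code>argpartition</code> pays off mainly when the class count is large, such as the 1,000 ImageNet classes.</p>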
</section>
<section id="object-detection" class="level3">
<h3 class="anchored" data-anchor-id="object-detection" id="object-detection">Object Detection</h3>
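<p>For YOLO-family models such as YOLOv8 exported to ONNX, postprocessing involves bounding-box decoding and non-maximum suppression (NMS). Set-prediction models such as DETR are designed to output non-duplicated boxes directly, so they typically need only score thresholding and box rescaling, not NMS.</p>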
<div class="cell" data-eval="true" data-layout-align="center">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">flowchart TD
    A["Input Image BGR uint8"] --&gt; B["BGR → RGB cv2.cvtColor"]
    B --&gt; C["Letterbox Resize 640×640 with padding"]
    C --&gt; D["Normalize ÷255 → float32"]
    D --&gt; E["HWC → CHW np.transpose"]
    E --&gt; F["Add Batch Dim np.expand_dims"]
    F --&gt; G["ORT Session sess.run()"]
    G --&gt; H["Raw Output (1, 84, 8400)"]
    H --&gt; I["Transpose → (8400, 84) boxes xywh + class scores"]
    I --&gt; J["Confidence Filter score ≥ threshold"]
    J --&gt; K["xywh → xyxy bounding box decode"]
    K --&gt; L["NMS IoU-based deduplication"]
    L --&gt; M["Final Detections boxes · scores · class IDs"]
</pre>
</div>
<p></p></figure><p></p>
</div>
</div>
</div>
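<p>Steps H through K of the diagram, plus mapping boxes back through the letterbox, can be sketched in NumPy as follows. This sketch assumes the YOLOv8-style layout shown above: a <code>(1, 84, 8400)</code> tensor whose first 4 rows hold box center x, center y, width and height in letterboxed pixel coordinates, with the remaining 80 rows holding per-class scores. <code>decode_yolo_output</code> is a hypothetical helper. <code>ratio</code> and <code>pad</code> correspond to the values the <code>letterbox</code> function below returns. For example, a 1280×960 image letterboxed to 640×640 gives a ratio of 0.5 and padding <code>(0, 80)</code>.</p>

```python
import numpy as np
from typing import Tuple


def decode_yolo_output(raw: np.ndarray, ratio: float, pad: Tuple[int, int],
                       conf_threshold: float = 0.25) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Decode a (1, 4 + num_classes, num_anchors) tensor into xyxy boxes in original-image pixels."""
    preds = raw[0].T                      # (num_anchors, 4 + num_classes)
    boxes_xywh, class_scores = preds[:, :4], preds[:, 4:]

    # Best class per anchor and its score
    class_ids = class_scores.argmax(axis=1)
    scores = class_scores.max(axis=1)

    # Confidence filter
    keep = scores >= conf_threshold
    boxes_xywh, scores, class_ids = boxes_xywh[keep], scores[keep], class_ids[keep]

    # Center format -> corner format
    cx, cy, w, h = boxes_xywh.T
    boxes = np.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], axis=1)

    # Undo the letterbox: remove padding, then rescale to the original image size
    pad_left, pad_top = pad
    boxes[:, [0, 2]] -= pad_left
    boxes[:, [1, 3]] -= pad_top
    boxes /= ratio
    return boxes, scores, class_ids


# Synthetic output: 4 box rows + 80 class rows, 8400 anchors, one confident anchor
raw = np.zeros((1, 84, 8400), dtype=np.float32)
raw[0, :4, 0] = [320, 320, 100, 50]    # cx, cy, w, h in the 640x640 letterboxed frame
raw[0, 4 + 16, 0] = 0.9                # class 16 with score 0.9
boxes, scores, class_ids = decode_yolo_output(raw, ratio=0.5, pad=(0, 80))
print(boxes, scores, class_ids)
```

<p>The surviving boxes, scores and class IDs are what the NMS step then deduplicates.</p>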
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb27"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb27-1"><a href="#cb27-1" aria-hidden="true" tabindex="-1"></a><span class="co"># yolo_inference.py — demonstrates the postprocessing pattern</span></span>
<span id="cb27-2"><a href="#cb27-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb27-3"><a href="#cb27-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb27-4"><a href="#cb27-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> onnxruntime <span class="im">as</span> ort</span>
<span id="cb27-5"><a href="#cb27-5" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> cv2</span>
<span id="cb27-6"><a href="#cb27-6" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> typing <span class="im">import</span> List, Tuple</span>
<span id="cb27-7"><a href="#cb27-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb27-8"><a href="#cb27-8" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> letterbox(image: np.ndarray, target_size: Tuple[<span class="bu">int</span>, <span class="bu">int</span>] <span class="op">=</span> (<span class="dv">640</span>, <span class="dv">640</span>),</span>
<span id="cb27-9"><a href="#cb27-9" aria-hidden="true" tabindex="-1"></a>              fill_value: <span class="bu">int</span> <span class="op">=</span> <span class="dv">114</span>) <span class="op">-&gt;</span> Tuple[np.ndarray, <span class="bu">float</span>, Tuple[<span class="bu">int</span>, <span class="bu">int</span>]]:</span>
<span id="cb27-10"><a href="#cb27-10" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Resize with preserved aspect ratio and pad to square."""</span></span>
<span id="cb27-11"><a href="#cb27-11" aria-hidden="true" tabindex="-1"></a>    h, w <span class="op">=</span> image.shape[:<span class="dv">2</span>]</span>
<span id="cb27-12"><a href="#cb27-12" aria-hidden="true" tabindex="-1"></a>    th, tw <span class="op">=</span> target_size</span>
<span id="cb27-13"><a href="#cb27-13" aria-hidden="true" tabindex="-1"></a>    ratio <span class="op">=</span> <span class="bu">min</span>(th <span class="op">/</span> h, tw <span class="op">/</span> w)</span>
<span id="cb27-14"><a href="#cb27-14" aria-hidden="true" tabindex="-1"></a>    new_h, new_w <span class="op">=</span> <span class="bu">int</span>(h <span class="op">*</span> ratio), <span class="bu">int</span>(w <span class="op">*</span> ratio)</span>
<span id="cb27-15"><a href="#cb27-15" aria-hidden="true" tabindex="-1"></a>    resized <span class="op">=</span> cv2.resize(image, (new_w, new_h), interpolation<span class="op">=</span>cv2.INTER_LINEAR)</span>
<span id="cb27-16"><a href="#cb27-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb27-17"><a href="#cb27-17" aria-hidden="true" tabindex="-1"></a>    pad_top  <span class="op">=</span> (th <span class="op">-</span> new_h) <span class="op">//</span> <span class="dv">2</span></span>
<span id="cb27-18"><a href="#cb27-18" aria-hidden="true" tabindex="-1"></a>    pad_left <span class="op">=</span> (tw <span class="op">-</span> new_w) <span class="op">//</span> <span class="dv">2</span></span>
<span id="cb27-19"><a href="#cb27-19" aria-hidden="true" tabindex="-1"></a>    padded   <span class="op">=</span> np.full((th, tw, <span class="dv">3</span>), fill_value, dtype<span class="op">=</span>np.uint8)</span>
<span id="cb27-20"><a href="#cb27-20" aria-hidden="true" tabindex="-1"></a>    padded[pad_top:pad_top <span class="op">+</span> new_h, pad_left:pad_left <span class="op">+</span> new_w] <span class="op">=</span> resized</span>
<span id="cb27-21"><a href="#cb27-21" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb27-22"><a href="#cb27-22" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> padded, ratio, (pad_left, pad_top)</span>
<span id="cb27-23"><a href="#cb27-23" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb27-24"><a href="#cb27-24" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb27-25"><a href="#cb27-25" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> preprocess_detection(image_bgr: np.ndarray) <span class="op">-&gt;</span> Tuple[np.ndarray, <span class="bu">float</span>, <span class="bu">tuple</span>]:</span>
<span id="cb27-26"><a href="#cb27-26" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Convert BGR OpenCV image to ONNX-ready float32 tensor."""</span></span>
<span id="cb27-27"><a href="#cb27-27" aria-hidden="true" tabindex="-1"></a>    image_rgb <span class="op">=</span> cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)</span>
<span id="cb27-28"><a href="#cb27-28" aria-hidden="true" tabindex="-1"></a>    padded, ratio, padding <span class="op">=</span> letterbox(image_rgb, (<span class="dv">640</span>, <span class="dv">640</span>))</span>
<span id="cb27-29"><a href="#cb27-29" aria-hidden="true" tabindex="-1"></a>    blob <span class="op">=</span> padded.astype(np.float32) <span class="op">/</span> <span class="fl">255.0</span></span>
<span id="cb27-30"><a href="#cb27-30" aria-hidden="true" tabindex="-1"></a>    blob <span class="op">=</span> np.transpose(blob, (<span class="dv">2</span>, <span class="dv">0</span>, <span class="dv">1</span>))   <span class="co"># HWC → CHW</span></span>
<span id="cb27-31"><a href="#cb27-31" aria-hidden="true" tabindex="-1"></a>    blob <span class="op">=</span> np.expand_dims(blob, axis<span class="op">=</span><span class="dv">0</span>)    <span class="co"># add batch dim</span></span>
<span id="cb27-32"><a href="#cb27-32" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> blob, ratio, padding</span>
<span id="cb27-33"><a href="#cb27-33" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb27-34"><a href="#cb27-34" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb27-35"><a href="#cb27-35" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> nms(boxes: np.ndarray, scores: np.ndarray,</span>
<span id="cb27-36"><a href="#cb27-36" aria-hidden="true" tabindex="-1"></a>        iou_threshold: <span class="bu">float</span> <span class="op">=</span> <span class="fl">0.45</span>) <span class="op">-&gt;</span> List[<span class="bu">int</span>]:</span>
<span id="cb27-37"><a href="#cb27-37" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Simple greedy NMS."""</span></span>
<span id="cb27-38"><a href="#cb27-38" aria-hidden="true" tabindex="-1"></a>    x1, y1, x2, y2 <span class="op">=</span> boxes[:, <span class="dv">0</span>], boxes[:, <span class="dv">1</span>], boxes[:, <span class="dv">2</span>], boxes[:, <span class="dv">3</span>]</span>
<span id="cb27-39"><a href="#cb27-39" aria-hidden="true" tabindex="-1"></a>    areas <span class="op">=</span> (x2 <span class="op">-</span> x1) <span class="op">*</span> (y2 <span class="op">-</span> y1)</span>
<span id="cb27-40"><a href="#cb27-40" aria-hidden="true" tabindex="-1"></a>    order <span class="op">=</span> scores.argsort()[::<span class="op">-</span><span class="dv">1</span>]</span>
<span id="cb27-41"><a href="#cb27-41" aria-hidden="true" tabindex="-1"></a>    keep  <span class="op">=</span> []</span>
<span id="cb27-42"><a href="#cb27-42" aria-hidden="true" tabindex="-1"></a>    <span class="cf">while</span> order.size <span class="op">&gt;</span> <span class="dv">0</span>:</span>
<span id="cb27-43"><a href="#cb27-43" aria-hidden="true" tabindex="-1"></a>        i <span class="op">=</span> order[<span class="dv">0</span>]</span>
<span id="cb27-44"><a href="#cb27-44" aria-hidden="true" tabindex="-1"></a>        keep.append(i)</span>
<span id="cb27-45"><a href="#cb27-45" aria-hidden="true" tabindex="-1"></a>        xx1 <span class="op">=</span> np.maximum(x1[i], x1[order[<span class="dv">1</span>:]])</span>
<span id="cb27-46"><a href="#cb27-46" aria-hidden="true" tabindex="-1"></a>        yy1 <span class="op">=</span> np.maximum(y1[i], y1[order[<span class="dv">1</span>:]])</span>
<span id="cb27-47"><a href="#cb27-47" aria-hidden="true" tabindex="-1"></a>        xx2 <span class="op">=</span> np.minimum(x2[i], x2[order[<span class="dv">1</span>:]])</span>
<span id="cb27-48"><a href="#cb27-48" aria-hidden="true" tabindex="-1"></a>        yy2 <span class="op">=</span> np.minimum(y2[i], y2[order[<span class="dv">1</span>:]])</span>
<span id="cb27-49"><a href="#cb27-49" aria-hidden="true" tabindex="-1"></a>        inter <span class="op">=</span> np.maximum(<span class="dv">0</span>, xx2 <span class="op">-</span> xx1) <span class="op">*</span> np.maximum(<span class="dv">0</span>, yy2 <span class="op">-</span> yy1)</span>
<span id="cb27-50"><a href="#cb27-50" aria-hidden="true" tabindex="-1"></a>        iou   <span class="op">=</span> inter <span class="op">/</span> (areas[i] <span class="op">+</span> areas[order[<span class="dv">1</span>:]] <span class="op">-</span> inter)</span>
<span id="cb27-51"><a href="#cb27-51" aria-hidden="true" tabindex="-1"></a>        order <span class="op">=</span> order[<span class="dv">1</span>:][iou <span class="op">&lt;=</span> iou_threshold]</span>
<span id="cb27-52"><a href="#cb27-52" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> keep</span>
<span id="cb27-53"><a href="#cb27-53" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb27-54"><a href="#cb27-54" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb27-55"><a href="#cb27-55" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> detect(model_path: <span class="bu">str</span>, image_path: <span class="bu">str</span>,</span>
<span id="cb27-56"><a href="#cb27-56" aria-hidden="true" tabindex="-1"></a>           conf_threshold: <span class="bu">float</span> <span class="op">=</span> <span class="fl">0.25</span>):</span>
<span id="cb27-57"><a href="#cb27-57" aria-hidden="true" tabindex="-1"></a>    sess <span class="op">=</span> ort.InferenceSession(model_path,</span>
<span id="cb27-58"><a href="#cb27-58" aria-hidden="true" tabindex="-1"></a>                                 providers<span class="op">=</span>[<span class="st">"CUDAExecutionProvider"</span>,</span>
<span id="cb27-59"><a href="#cb27-59" aria-hidden="true" tabindex="-1"></a>                                            <span class="st">"CPUExecutionProvider"</span>])</span>
<span id="cb27-60"><a href="#cb27-60" aria-hidden="true" tabindex="-1"></a>    image_bgr <span class="op">=</span> cv2.imread(image_path)</span>
<span id="cb27-61"><a href="#cb27-61" aria-hidden="true" tabindex="-1"></a>    blob, ratio, (pad_left, pad_top) <span class="op">=</span> preprocess_detection(image_bgr)</span>
<span id="cb27-62"><a href="#cb27-62" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb27-63"><a href="#cb27-63" aria-hidden="true" tabindex="-1"></a>    input_name  <span class="op">=</span> sess.get_inputs()[<span class="dv">0</span>].name</span>
<span id="cb27-64"><a href="#cb27-64" aria-hidden="true" tabindex="-1"></a>    output_name <span class="op">=</span> sess.get_outputs()[<span class="dv">0</span>].name</span>
<span id="cb27-65"><a href="#cb27-65" aria-hidden="true" tabindex="-1"></a>    raw_output  <span class="op">=</span> sess.run([output_name], {input_name: blob})[<span class="dv">0</span>]</span>
<span id="cb27-66"><a href="#cb27-66" aria-hidden="true" tabindex="-1"></a>    <span class="co"># YOLOv8 output shape: (1, 84, 8400) — [batch, 4+num_classes, anchors]</span></span>
<span id="cb27-67"><a href="#cb27-67" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb27-68"><a href="#cb27-68" aria-hidden="true" tabindex="-1"></a>    predictions <span class="op">=</span> raw_output[<span class="dv">0</span>].T    <span class="co"># (8400, 84)</span></span>
<span id="cb27-69"><a href="#cb27-69" aria-hidden="true" tabindex="-1"></a>    boxes_xywh  <span class="op">=</span> predictions[:, :<span class="dv">4</span>]</span>
<span id="cb27-70"><a href="#cb27-70" aria-hidden="true" tabindex="-1"></a>    class_scores <span class="op">=</span> predictions[:, <span class="dv">4</span>:]</span>
<span id="cb27-71"><a href="#cb27-71" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb27-72"><a href="#cb27-72" aria-hidden="true" tabindex="-1"></a>    confidences <span class="op">=</span> class_scores.<span class="bu">max</span>(axis<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb27-73"><a href="#cb27-73" aria-hidden="true" tabindex="-1"></a>    class_ids   <span class="op">=</span> class_scores.argmax(axis<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb27-74"><a href="#cb27-74" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb27-75"><a href="#cb27-75" aria-hidden="true" tabindex="-1"></a>    mask <span class="op">=</span> confidences <span class="op">&gt;=</span> conf_threshold</span>
<span id="cb27-76"><a href="#cb27-76" aria-hidden="true" tabindex="-1"></a>    boxes_xywh  <span class="op">=</span> boxes_xywh[mask]</span>
<span id="cb27-77"><a href="#cb27-77" aria-hidden="true" tabindex="-1"></a>    confidences <span class="op">=</span> confidences[mask]</span>
<span id="cb27-78"><a href="#cb27-78" aria-hidden="true" tabindex="-1"></a>    class_ids   <span class="op">=</span> class_ids[mask]</span>
<span id="cb27-79"><a href="#cb27-79" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb27-80"><a href="#cb27-80" aria-hidden="true" tabindex="-1"></a>    <span class="co"># xywh → xyxy</span></span>
<span id="cb27-81"><a href="#cb27-81" aria-hidden="true" tabindex="-1"></a>    boxes_xyxy       <span class="op">=</span> np.empty_like(boxes_xywh)</span>
<span id="cb27-82"><a href="#cb27-82" aria-hidden="true" tabindex="-1"></a>    boxes_xyxy[:, <span class="dv">0</span>] <span class="op">=</span> boxes_xywh[:, <span class="dv">0</span>] <span class="op">-</span> boxes_xywh[:, <span class="dv">2</span>] <span class="op">/</span> <span class="dv">2</span></span>
<span id="cb27-83"><a href="#cb27-83" aria-hidden="true" tabindex="-1"></a>    boxes_xyxy[:, <span class="dv">1</span>] <span class="op">=</span> boxes_xywh[:, <span class="dv">1</span>] <span class="op">-</span> boxes_xywh[:, <span class="dv">3</span>] <span class="op">/</span> <span class="dv">2</span></span>
<span id="cb27-84"><a href="#cb27-84" aria-hidden="true" tabindex="-1"></a>    boxes_xyxy[:, <span class="dv">2</span>] <span class="op">=</span> boxes_xywh[:, <span class="dv">0</span>] <span class="op">+</span> boxes_xywh[:, <span class="dv">2</span>] <span class="op">/</span> <span class="dv">2</span></span>
<span id="cb27-85"><a href="#cb27-85" aria-hidden="true" tabindex="-1"></a>    boxes_xyxy[:, <span class="dv">3</span>] <span class="op">=</span> boxes_xywh[:, <span class="dv">1</span>] <span class="op">+</span> boxes_xywh[:, <span class="dv">3</span>] <span class="op">/</span> <span class="dv">2</span></span>
<span id="cb27-86"><a href="#cb27-86" aria-hidden="true" tabindex="-1"></a></span>
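<span id="cb27-87"><a href="#cb27-87" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Undo the letterbox: remove padding, then rescale to original-image pixels</span></span>
<span id="cb27-88"><a href="#cb27-88" aria-hidden="true" tabindex="-1"></a>    boxes_xyxy[:, [<span class="dv">0</span>, <span class="dv">2</span>]] <span class="op">-=</span> pad_left</span>
<span id="cb27-89"><a href="#cb27-89" aria-hidden="true" tabindex="-1"></a>    boxes_xyxy[:, [<span class="dv">1</span>, <span class="dv">3</span>]] <span class="op">-=</span> pad_top</span>
<span id="cb27-90"><a href="#cb27-90" aria-hidden="true" tabindex="-1"></a>    boxes_xyxy <span class="op">/=</span> ratio</span>
<span id="cb27-91"><a href="#cb27-91" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb27-92"><a href="#cb27-92" aria-hidden="true" tabindex="-1"></a>    keep <span class="op">=</span> nms(boxes_xyxy, confidences)   <span class="co"># class-agnostic: ignores class_ids</span></span>
<span id="cb27-93"><a href="#cb27-93" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Detected </span><span class="sc">{</span><span class="bu">len</span>(keep)<span class="sc">}</span><span class="ss"> objects"</span>)</span>
<span id="cb27-94"><a href="#cb27-94" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> boxes_xyxy[keep], confidences[keep], class_ids[keep]</span></code></pre></div></div>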
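<p>The <code>nms</code> helper above is class-agnostic, so a high-scoring box of one class can suppress an overlapping box of a different class. One way to make it class-aware without looping over classes is the coordinate-offset trick (one of the strategies torchvision's <code>batched_nms</code> uses): shift every box by an amount proportional to its class id so that boxes of different classes can never overlap, then run a single greedy pass. The sketch below re-implements greedy NMS so it is self-contained; <code>greedy_nms</code> and <code>batched_nms</code> are illustrative names, not part of any library used in this post.</p>

```python
import numpy as np


def greedy_nms(boxes: np.ndarray, scores: np.ndarray,
               iou_threshold: float = 0.45) -> list:
    """Greedy NMS over (N, 4) xyxy boxes; returns kept indices by descending score."""
    x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    areas = (x2 - x1) * (y2 - y1)
    order = scores.argsort()[::-1]
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(int(i))
        rest = order[1:]
        inter_w = np.maximum(0, np.minimum(x2[i], x2[rest]) - np.maximum(x1[i], x1[rest]))
        inter_h = np.maximum(0, np.minimum(y2[i], y2[rest]) - np.maximum(y1[i], y1[rest]))
        inter = inter_w * inter_h
        iou = inter / (areas[i] + areas[rest] - inter + 1e-9)  # epsilon avoids 0/0
        order = rest[iou <= iou_threshold]
    return keep


def batched_nms(boxes: np.ndarray, scores: np.ndarray, class_ids: np.ndarray,
                iou_threshold: float = 0.45) -> list:
    """Class-aware NMS via the coordinate-offset trick."""
    if boxes.shape[0] == 0:
        return []
    # An offset larger than the full coordinate span places each class in its
    # own region, so the IoU between boxes of different classes is always 0.
    span = float(boxes.max() - boxes.min()) + 1.0
    offsets = class_ids.astype(np.float64) * span
    return greedy_nms(boxes + offsets[:, None], scores, iou_threshold)
```

<p>In <code>detect()</code>, calling <code>batched_nms(boxes_xyxy, confidences, class_ids)</code> in place of <code>nms(boxes_xyxy, confidences)</code> would keep overlapping detections that belong to different classes.</p>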
</section>
<section id="semantic-segmentation" class="level3">
<h3 class="anchored" data-anchor-id="semantic-segmentation" id="semantic-segmentation">Semantic Segmentation</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb28"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb28-1"><a href="#cb28-1" aria-hidden="true" tabindex="-1"></a><span class="co"># segmentation_inference.py</span></span>
<span id="cb28-2"><a href="#cb28-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb28-3"><a href="#cb28-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb28-4"><a href="#cb28-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> onnxruntime <span class="im">as</span> ort</span>
<span id="cb28-5"><a href="#cb28-5" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> cv2</span>
<span id="cb28-6"><a href="#cb28-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb28-7"><a href="#cb28-7" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> run_segmentation(model_path: <span class="bu">str</span>, image_path: <span class="bu">str</span>,</span>
<span id="cb28-8"><a href="#cb28-8" aria-hidden="true" tabindex="-1"></a>                     input_size: <span class="bu">tuple</span> <span class="op">=</span> (<span class="dv">512</span>, <span class="dv">512</span>),</span>
<span id="cb28-9"><a href="#cb28-9" aria-hidden="true" tabindex="-1"></a>                     num_classes: <span class="bu">int</span> <span class="op">=</span> <span class="dv">21</span>):   <span class="co"># Pascal VOC: 20 classes + background</span></span>
<span id="cb28-10"><a href="#cb28-10" aria-hidden="true" tabindex="-1"></a>    sess <span class="op">=</span> ort.InferenceSession(model_path,</span>
<span id="cb28-11"><a href="#cb28-11" aria-hidden="true" tabindex="-1"></a>                                 providers<span class="op">=</span>[<span class="st">"CUDAExecutionProvider"</span>,</span>
<span id="cb28-12"><a href="#cb28-12" aria-hidden="true" tabindex="-1"></a>                                            <span class="st">"CPUExecutionProvider"</span>])</span>
<span id="cb28-13"><a href="#cb28-13" aria-hidden="true" tabindex="-1"></a>    image <span class="op">=</span> cv2.imread(image_path)</span>
<span id="cb28-14"><a href="#cb28-14" aria-hidden="true" tabindex="-1"></a>    original_shape <span class="op">=</span> image.shape[:<span class="dv">2</span>]</span>
<span id="cb28-15"><a href="#cb28-15" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb28-16"><a href="#cb28-16" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Preprocess</span></span>
<span id="cb28-17"><a href="#cb28-17" aria-hidden="true" tabindex="-1"></a>    resized <span class="op">=</span> cv2.resize(image, input_size)   <span class="co"># cv2 takes (width, height)</span></span>
<span id="cb28-18"><a href="#cb28-18" aria-hidden="true" tabindex="-1"></a>    blob <span class="op">=</span> resized[:, :, ::<span class="op">-</span><span class="dv">1</span>].astype(np.float32)   <span class="co"># BGR → RGB</span></span>
<span id="cb28-19"><a href="#cb28-19" aria-hidden="true" tabindex="-1"></a>    blob <span class="op">=</span> (blob <span class="op">/</span> <span class="fl">255.0</span> <span class="op">-</span> np.array([<span class="fl">0.485</span>, <span class="fl">0.456</span>, <span class="fl">0.406</span>], dtype<span class="op">=</span>np.float32)) <span class="op">\</span></span>
<span id="cb28-20"><a href="#cb28-20" aria-hidden="true" tabindex="-1"></a>         <span class="op">/</span> np.array([<span class="fl">0.229</span>, <span class="fl">0.224</span>, <span class="fl">0.225</span>], dtype<span class="op">=</span>np.float32)   <span class="co"># stay float32: a float32 model rejects float64 input</span></span>
<span id="cb28-21"><a href="#cb28-21" aria-hidden="true" tabindex="-1"></a>    blob <span class="op">=</span> np.transpose(blob, (<span class="dv">2</span>, <span class="dv">0</span>, <span class="dv">1</span>))[np.newaxis]   <span class="co"># (1, 3, H, W)</span></span>
<span id="cb28-22"><a href="#cb28-22" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb28-23"><a href="#cb28-23" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Inference</span></span>
<span id="cb28-24"><a href="#cb28-24" aria-hidden="true" tabindex="-1"></a>    input_name  <span class="op">=</span> sess.get_inputs()[<span class="dv">0</span>].name</span>
<span id="cb28-25"><a href="#cb28-25" aria-hidden="true" tabindex="-1"></a>    output      <span class="op">=</span> sess.run(<span class="va">None</span>, {input_name: blob})[<span class="dv">0</span>]</span>
<span id="cb28-26"><a href="#cb28-26" aria-hidden="true" tabindex="-1"></a>    <span class="co"># output shape: (1, num_classes, H, W)</span></span>
<span id="cb28-27"><a href="#cb28-27" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb28-28"><a href="#cb28-28" aria-hidden="true" tabindex="-1"></a>    seg_map <span class="op">=</span> output[<span class="dv">0</span>].argmax(axis<span class="op">=</span><span class="dv">0</span>).astype(np.uint8)   <span class="co"># (H, W)</span></span>
<span id="cb28-29"><a href="#cb28-29" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb28-30"><a href="#cb28-30" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Resize back to original</span></span>
<span id="cb28-31"><a href="#cb28-31" aria-hidden="true" tabindex="-1"></a>    seg_map <span class="op">=</span> cv2.resize(seg_map, (original_shape[<span class="dv">1</span>], original_shape[<span class="dv">0</span>]),</span>
<span id="cb28-32"><a href="#cb28-32" aria-hidden="true" tabindex="-1"></a>                          interpolation<span class="op">=</span>cv2.INTER_NEAREST)</span>
<span id="cb28-33"><a href="#cb28-33" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb28-34"><a href="#cb28-34" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Colorize for visualization</span></span>
<span id="cb28-35"><a href="#cb28-35" aria-hidden="true" tabindex="-1"></a>    palette <span class="op">=</span> np.random.default_rng(<span class="dv">0</span>).integers(<span class="dv">0</span>, <span class="dv">256</span>, (num_classes, <span class="dv">3</span>), dtype<span class="op">=</span>np.uint8)   <span class="co"># fixed seed: stable colors across runs</span></span>
<span id="cb28-36"><a href="#cb28-36" aria-hidden="true" tabindex="-1"></a>    colorized <span class="op">=</span> palette[seg_map]</span>
<span id="cb28-37"><a href="#cb28-37" aria-hidden="true" tabindex="-1"></a>    blended   <span class="op">=</span> cv2.addWeighted(image, <span class="fl">0.6</span>, colorized, <span class="fl">0.4</span>, <span class="dv">0</span>)</span>
<span id="cb28-38"><a href="#cb28-38" aria-hidden="true" tabindex="-1"></a>    cv2.imwrite(<span class="st">"segmentation_result.png"</span>, blended)</span>
<span id="cb28-39"><a href="#cb28-39" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Segmentation map shape: </span><span class="sc">{</span>seg_map<span class="sc">.</span>shape<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb28-40"><a href="#cb28-40" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> seg_map</span></code></pre></div></div>
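<p>A random palette can give two classes nearly identical colors. For Pascal VOC-style label maps, a common deterministic alternative is the bit-interleaving colormap that ships with the Pascal VOC development kit: the bits of each label index are spread across the R, G and B channels, starting from each channel's most significant bit. The sketch below assumes that scheme; <code>voc_colormap</code> is an illustrative name. It returns RGB, whereas OpenCV images are BGR.</p>

```python
import numpy as np


def voc_colormap(num_colors: int = 256) -> np.ndarray:
    """Deterministic Pascal VOC-style colormap, shape (num_colors, 3), RGB uint8."""
    cmap = np.zeros((num_colors, 3), dtype=np.uint8)
    for label in range(num_colors):
        r = g = b = 0
        c = label
        for j in range(8):
            # Take three bits per round and place them from the MSB downward.
            r |= ((c >> 0) & 1) << (7 - j)
            g |= ((c >> 1) & 1) << (7 - j)
            b |= ((c >> 2) & 1) << (7 - j)
            c >>= 3
        cmap[label] = (r, g, b)
    return cmap
```

<p>Inside <code>run_segmentation</code>, <code>palette = voc_colormap(num_classes)[:, ::-1]</code> would yield BGR colors that match the OpenCV image passed to <code>cv2.addWeighted</code>. Because the mapping is a bijection on 8-bit labels, every class index gets a distinct color.</p>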
<hr>
</section>
</section>
<section id="benchmarking-and-profiling" class="level2">
<h2 class="anchored" data-anchor-id="benchmarking-and-profiling" id="benchmarking-and-profiling">Benchmarking and Profiling</h2>
<section id="latency-and-throughput-benchmark" class="level3">
<h3 class="anchored" data-anchor-id="latency-and-throughput-benchmark" id="latency-and-throughput-benchmark">Latency and Throughput Benchmark</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb29"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb29-1"><a href="#cb29-1" aria-hidden="true" tabindex="-1"></a><span class="co"># benchmark.py</span></span>
<span id="cb29-2"><a href="#cb29-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb29-3"><a href="#cb29-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb29-4"><a href="#cb29-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb29-5"><a href="#cb29-5" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> onnxruntime <span class="im">as</span> ort</span>
<span id="cb29-6"><a href="#cb29-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb29-7"><a href="#cb29-7" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> benchmark(model_path: <span class="bu">str</span>,</span>
<span id="cb29-8"><a href="#cb29-8" aria-hidden="true" tabindex="-1"></a>              input_shape: <span class="bu">tuple</span> <span class="op">=</span> (<span class="dv">1</span>, <span class="dv">3</span>, <span class="dv">32</span>, <span class="dv">32</span>),</span>
<span id="cb29-9"><a href="#cb29-9" aria-hidden="true" tabindex="-1"></a>              warmup_runs: <span class="bu">int</span> <span class="op">=</span> <span class="dv">20</span>,</span>
<span id="cb29-10"><a href="#cb29-10" aria-hidden="true" tabindex="-1"></a>              benchmark_runs: <span class="bu">int</span> <span class="op">=</span> <span class="dv">200</span>,</span>
<span id="cb29-11"><a href="#cb29-11" aria-hidden="true" tabindex="-1"></a>              providers: <span class="bu">list</span> <span class="op">=</span> <span class="va">None</span>):</span>
<span id="cb29-12"><a href="#cb29-12" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> providers <span class="kw">is</span> <span class="va">None</span>:</span>
<span id="cb29-13"><a href="#cb29-13" aria-hidden="true" tabindex="-1"></a>        providers <span class="op">=</span> [<span class="st">"CPUExecutionProvider"</span>]</span>
<span id="cb29-14"><a href="#cb29-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb29-15"><a href="#cb29-15" aria-hidden="true" tabindex="-1"></a>    sess <span class="op">=</span> ort.InferenceSession(model_path, providers<span class="op">=</span>providers)</span>
<span id="cb29-16"><a href="#cb29-16" aria-hidden="true" tabindex="-1"></a>    input_name <span class="op">=</span> sess.get_inputs()[<span class="dv">0</span>].name</span>
<span id="cb29-17"><a href="#cb29-17" aria-hidden="true" tabindex="-1"></a>    dummy <span class="op">=</span> np.random.randn(<span class="op">*</span>input_shape).astype(np.float32)</span>
<span id="cb29-18"><a href="#cb29-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb29-19"><a href="#cb29-19" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Warm-up</span></span>
<span id="cb29-20"><a href="#cb29-20" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(warmup_runs):</span>
<span id="cb29-21"><a href="#cb29-21" aria-hidden="true" tabindex="-1"></a>        sess.run(<span class="va">None</span>, {input_name: dummy})</span>
<span id="cb29-22"><a href="#cb29-22" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb29-23"><a href="#cb29-23" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Benchmark</span></span>
<span id="cb29-24"><a href="#cb29-24" aria-hidden="true" tabindex="-1"></a>    latencies <span class="op">=</span> []</span>
<span id="cb29-25"><a href="#cb29-25" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(benchmark_runs):</span>
<span id="cb29-26"><a href="#cb29-26" aria-hidden="true" tabindex="-1"></a>        t0 <span class="op">=</span> time.perf_counter()</span>
<span id="cb29-27"><a href="#cb29-27" aria-hidden="true" tabindex="-1"></a>        sess.run(<span class="va">None</span>, {input_name: dummy})</span>
<span id="cb29-28"><a href="#cb29-28" aria-hidden="true" tabindex="-1"></a>        latencies.append((time.perf_counter() <span class="op">-</span> t0) <span class="op">*</span> <span class="dv">1000</span>)   <span class="co"># ms</span></span>
<span id="cb29-29"><a href="#cb29-29" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb29-30"><a href="#cb29-30" aria-hidden="true" tabindex="-1"></a>    latencies <span class="op">=</span> np.array(latencies)</span>
<span id="cb29-31"><a href="#cb29-31" aria-hidden="true" tabindex="-1"></a>    batch_size <span class="op">=</span> input_shape[<span class="dv">0</span>]</span>
<span id="cb29-32"><a href="#cb29-32" aria-hidden="true" tabindex="-1"></a>    fps <span class="op">=</span> batch_size <span class="op">/</span> (latencies.mean() <span class="op">/</span> <span class="dv">1000</span>)</span>
<span id="cb29-33"><a href="#cb29-33" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb29-34"><a href="#cb29-34" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Model: </span><span class="sc">{</span>model_path<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb29-35"><a href="#cb29-35" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Providers: </span><span class="sc">{</span>providers<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb29-36"><a href="#cb29-36" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Batch size: </span><span class="sc">{</span>batch_size<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb29-37"><a href="#cb29-37" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Latency — mean: </span><span class="sc">{</span>latencies<span class="sc">.</span>mean()<span class="sc">:.2f}</span><span class="ss"> ms  "</span></span>
<span id="cb29-38"><a href="#cb29-38" aria-hidden="true" tabindex="-1"></a>          <span class="ss">f"p50: </span><span class="sc">{</span>np<span class="sc">.</span>percentile(latencies, <span class="dv">50</span>)<span class="sc">:.2f}</span><span class="ss"> ms  "</span></span>
<span id="cb29-39"><a href="#cb29-39" aria-hidden="true" tabindex="-1"></a>          <span class="ss">f"p99: </span><span class="sc">{</span>np<span class="sc">.</span>percentile(latencies, <span class="dv">99</span>)<span class="sc">:.2f}</span><span class="ss"> ms"</span>)</span>
<span id="cb29-40"><a href="#cb29-40" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Throughput: </span><span class="sc">{</span>fps<span class="sc">:.1f}</span><span class="ss"> images/sec"</span>)</span>
<span id="cb29-41"><a href="#cb29-41" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> latencies</span>
<span id="cb29-42"><a href="#cb29-42" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb29-43"><a href="#cb29-43" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb29-44"><a href="#cb29-44" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb29-45"><a href="#cb29-45" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Compare FP32 vs INT8 at batch size 1</span></span>
<span id="cb29-46"><a href="#cb29-46" aria-hidden="true" tabindex="-1"></a>    benchmark(<span class="st">"resnet18_cifar10.onnx"</span>,      input_shape<span class="op">=</span>(<span class="dv">1</span>, <span class="dv">3</span>, <span class="dv">32</span>, <span class="dv">32</span>))</span>
<span id="cb29-47"><a href="#cb29-47" aria-hidden="true" tabindex="-1"></a>    benchmark(<span class="st">"resnet18_cifar10_int8.onnx"</span>, input_shape<span class="op">=</span>(<span class="dv">1</span>, <span class="dv">3</span>, <span class="dv">32</span>, <span class="dv">32</span>))</span>
<span id="cb29-48"><a href="#cb29-48" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Batch 32 requires a model exported with a dynamic batch axis</span></span>
<span id="cb29-49"><a href="#cb29-49" aria-hidden="true" tabindex="-1"></a>    benchmark(<span class="st">"resnet18_cifar10.onnx"</span>,      input_shape<span class="op">=</span>(<span class="dv">32</span>, <span class="dv">3</span>, <span class="dv">32</span>, <span class="dv">32</span>))</span></code></pre></div></div>
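<p><code>benchmark()</code> times <code>sess.run</code> alone, but in the detection and segmentation pipelines above, preprocessing (color conversion, resizing, normalization) and postprocessing (NMS, resizing the mask back) also count toward end-to-end latency. A small generic timer using the same warm-up and <code>time.perf_counter</code> pattern can be pointed at any stage; <code>time_stage</code> below is a hypothetical helper, sketched under that assumption.</p>

```python
import time
from typing import Any, Callable, Dict

import numpy as np


def time_stage(fn: Callable[..., Any], *args: Any,
               warmup: int = 5, runs: int = 50) -> Dict[str, float]:
    """Time fn(*args) in milliseconds after a warm-up phase."""
    for _ in range(warmup):          # let caches and lazy initialization settle
        fn(*args)
    samples_ms = np.empty(runs, dtype=np.float64)
    for k in range(runs):
        t0 = time.perf_counter()
        fn(*args)
        samples_ms[k] = (time.perf_counter() - t0) * 1000.0
    return {
        "mean_ms": float(samples_ms.mean()),
        "p50_ms": float(np.percentile(samples_ms, 50)),
        "p99_ms": float(np.percentile(samples_ms, 99)),
    }
```

<p>For example, <code>time_stage(preprocess_detection, image_bgr)</code> would report the per-frame preprocessing cost, which can then be compared against the <code>sess.run</code> latency from <code>benchmark()</code>.</p>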
</section>
<section id="profiling-operator-timings" class="level3">
<h3 class="anchored" data-anchor-id="profiling-operator-timings" id="profiling-operator-timings">Profiling Operator Timings</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb30"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb30-1"><a href="#cb30-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> onnxruntime <span class="im">as</span> ort</span>
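<span id="cb30-2"><a href="#cb30-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> json</span>
<span id="cb30-2a"><a href="#cb30-2a" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>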
<span id="cb30-3"><a href="#cb30-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb30-4"><a href="#cb30-4" aria-hidden="true" tabindex="-1"></a>opts <span class="op">=</span> ort.SessionOptions()</span>
<span id="cb30-5"><a href="#cb30-5" aria-hidden="true" tabindex="-1"></a>opts.enable_profiling <span class="op">=</span> <span class="va">True</span></span>
<span id="cb30-6"><a href="#cb30-6" aria-hidden="true" tabindex="-1"></a>opts.profile_file_prefix <span class="op">=</span> <span class="st">"ort_profile"</span></span>
<span id="cb30-7"><a href="#cb30-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb30-8"><a href="#cb30-8" aria-hidden="true" tabindex="-1"></a>sess <span class="op">=</span> ort.InferenceSession(<span class="st">"resnet18_cifar10.onnx"</span>,</span>
<span id="cb30-9"><a href="#cb30-9" aria-hidden="true" tabindex="-1"></a>                             sess_options<span class="op">=</span>opts,</span>
<span id="cb30-10"><a href="#cb30-10" aria-hidden="true" tabindex="-1"></a>                             providers<span class="op">=</span>[<span class="st">"CPUExecutionProvider"</span>])</span>
<span id="cb30-11"><a href="#cb30-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb30-12"><a href="#cb30-12" aria-hidden="true" tabindex="-1"></a>dummy <span class="op">=</span> np.random.randn(<span class="dv">1</span>, <span class="dv">3</span>, <span class="dv">32</span>, <span class="dv">32</span>).astype(np.float32)</span>
<span id="cb30-13"><a href="#cb30-13" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">50</span>):</span>
<span id="cb30-14"><a href="#cb30-14" aria-hidden="true" tabindex="-1"></a>    sess.run(<span class="va">None</span>, {sess.get_inputs()[<span class="dv">0</span>].name: dummy})</span>
<span id="cb30-15"><a href="#cb30-15" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb30-16"><a href="#cb30-16" aria-hidden="true" tabindex="-1"></a>profile_path <span class="op">=</span> sess.end_profiling()</span>
<span id="cb30-17"><a href="#cb30-17" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Profile saved to: </span><span class="sc">{</span>profile_path<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb30-18"><a href="#cb30-18" aria-hidden="true" tabindex="-1"></a></span>
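<span id="cb30-19"><a href="#cb30-19" aria-hidden="true" tabindex="-1"></a><span class="co"># Load the Chrome-trace JSON written by the profiler</span></span>
<span id="cb30-20"><a href="#cb30-20" aria-hidden="true" tabindex="-1"></a><span class="cf">with</span> <span class="bu">open</span>(profile_path) <span class="im">as</span> f:</span>
<span id="cb30-21"><a href="#cb30-21" aria-hidden="true" tabindex="-1"></a>    events <span class="op">=</span> json.load(f)</span>
<span id="cb30-22"><a href="#cb30-22" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb30-23"><a href="#cb30-23" aria-hidden="true" tabindex="-1"></a><span class="co"># Each node logs one *_kernel_time event per run (plus fence events);</span></span>
<span id="cb30-24"><a href="#cb30-24" aria-hidden="true" tabindex="-1"></a><span class="co"># sum kernel time per node across all 50 profiled runs</span></span>
<span id="cb30-25"><a href="#cb30-25" aria-hidden="true" tabindex="-1"></a>totals <span class="op">=</span> {}</span>
<span id="cb30-26"><a href="#cb30-26" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> e <span class="kw">in</span> events:</span>
<span id="cb30-27"><a href="#cb30-27" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> e.get(<span class="st">"cat"</span>) <span class="op">==</span> <span class="st">"Node"</span> <span class="kw">and</span> e.get(<span class="st">"name"</span>, <span class="st">""</span>).endswith(<span class="st">"_kernel_time"</span>):</span>
<span id="cb30-28"><a href="#cb30-28" aria-hidden="true" tabindex="-1"></a>        totals[e[<span class="st">"name"</span>]] <span class="op">=</span> totals.get(e[<span class="st">"name"</span>], <span class="dv">0</span>) <span class="op">+</span> e[<span class="st">"dur"</span>]</span>
<span id="cb30-29"><a href="#cb30-29" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb30-30"><a href="#cb30-30" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"Top-10 slowest operators (mean µs per run):"</span>)</span>
<span id="cb30-31"><a href="#cb30-31" aria-hidden="true" tabindex="-1"></a>ranked <span class="op">=</span> <span class="bu">sorted</span>(totals.items(), key<span class="op">=</span><span class="kw">lambda</span> kv: kv[<span class="dv">1</span>], reverse<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb30-32"><a href="#cb30-32" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> name, total <span class="kw">in</span> ranked[:<span class="dv">10</span>]:</span>
<span id="cb30-33"><a href="#cb30-33" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"  </span><span class="sc">{</span>name<span class="sc">:&lt;50}</span><span class="ss"> </span><span class="sc">{</span>total <span class="op">/</span> <span class="dv">50</span><span class="sc">:&gt;8.1f}</span><span class="ss"> µs"</span>)</span></code></pre></div></div>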
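<p>Per-node totals point at individual layers; grouping the same events by operator type shows which kind of operator (for example <code>Conv</code> versus <code>Gemm</code>) dominates overall. The sketch below assumes the event layout ONNX Runtime's profiler writes, where each node's kernel event has <code>cat == "Node"</code>, a name ending in <code>_kernel_time</code>, and an <code>args["op_name"]</code> field; <code>time_by_op_type</code> is an illustrative helper, and events missing <code>op_name</code> are counted under <code>"unknown"</code>.</p>

```python
from typing import Dict, List


def time_by_op_type(events: List[dict]) -> Dict[str, float]:
    """Sum kernel time (microseconds) per operator type, largest first."""
    totals: Dict[str, float] = {}
    for e in events:
        # Skip session-level events and the per-node fence_before/fence_after events.
        if e.get("cat") != "Node" or not e.get("name", "").endswith("_kernel_time"):
            continue
        op = e.get("args", {}).get("op_name", "unknown")
        totals[op] = totals.get(op, 0.0) + e["dur"]
    return dict(sorted(totals.items(), key=lambda kv: kv[1], reverse=True))
```

<p>With the profile loaded as above, <code>time_by_op_type(events)</code> returns a dictionary ordered from the most to the least expensive operator type.</p>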
<hr>
</section>
</section>
<section id="deploying-onnx-models" class="level2">
<h2 class="anchored" data-anchor-id="deploying-onnx-models" id="deploying-onnx-models">Deploying ONNX Models</h2>
<div class="cell" data-eval="true" data-layout-align="center">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">flowchart TD
    A["Optimized ONNX Model"] --&gt; B{"Deployment Target?"}

    B --&gt; C["Cloud / Server"]
    B --&gt; D["Edge / Embedded"]
    B --&gt; E["Browser"]
    B --&gt; F["Mobile"]

    C --&gt; C1["FastAPI + ONNX Runtime CUDA or CPU EP"]

    D --&gt; D1{"Hardware?"}
    D1 --&gt; D2["NVIDIA Jetson CUDA EP / TensorRT EP"]
    D1 --&gt; D3["ARM CPU Raspberry Pi CPU EP + NEON"]
    D1 --&gt; D4["Intel CPU/VPU OpenVINO EP"]

    E --&gt; E1["onnxruntime-web WASM or WebGL"]

    F --&gt; F1{"Platform?"}
    F1 --&gt; F2["Android QNN EP / NNAPI"]
    F1 --&gt; F3["iOS / macOS CoreML EP"]
</pre>
</div>
<p></p></figure><p></p>
</div>
</div>
</div>
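<p>In code, the decision tree above usually comes down to building an ordered <code>providers</code> list from the providers the installed ORT build actually offers, as reported by <code>ort.get_available_providers()</code>. The following is one way to do that. The helper name and the priority order are illustrative assumptions, not an ONNX Runtime recommendation.</p>

```python
def choose_providers(available, preferred=None):
    """Return the preferred execution providers that this ORT build offers.

    `available` is normally ort.get_available_providers(); the default
    priority order below is an illustrative assumption.
    """
    if preferred is None:
        preferred = [
            "TensorrtExecutionProvider",
            "CUDAExecutionProvider",
            "OpenVINOExecutionProvider",
            "CoreMLExecutionProvider",
            "CPUExecutionProvider",
        ]
    # Keep the priority order, dropping providers this build lacks
    chosen = [p for p in preferred if p in available]
    # Always keep the CPU provider as the final fallback
    if "CPUExecutionProvider" not in chosen:
        chosen.append("CPUExecutionProvider")
    return chosen
```

<p>For example, <code>ort.InferenceSession("resnet18_cifar10.onnx", providers=choose_providers(ort.get_available_providers()))</code>. Putting the CPU provider last gives ORT somewhere to run any nodes the accelerator cannot handle.</p>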
<section id="python-service-fastapi" class="level3">
<h3 class="anchored" data-anchor-id="python-service-fastapi" id="python-service-fastapi">Python Service (FastAPI)</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb31"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb31-1"><a href="#cb31-1" aria-hidden="true" tabindex="-1"></a><span class="co"># app.py — production-ready FastAPI inference server</span></span>
<span id="cb31-2"><a href="#cb31-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb31-3"><a href="#cb31-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> fastapi <span class="im">import</span> FastAPI, File, UploadFile, HTTPException</span>
<span id="cb31-4"><a href="#cb31-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> fastapi.responses <span class="im">import</span> JSONResponse</span>
<span id="cb31-5"><a href="#cb31-5" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> contextlib <span class="im">import</span> asynccontextmanager</span>
<span id="cb31-6"><a href="#cb31-6" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> PIL <span class="im">import</span> Image</span>
<span id="cb31-7"><a href="#cb31-7" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb31-8"><a href="#cb31-8" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> onnxruntime <span class="im">as</span> ort</span>
<span id="cb31-9"><a href="#cb31-9" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torchvision.transforms <span class="im">as</span> T</span>
<span id="cb31-10"><a href="#cb31-10" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> io, logging</span>
<span id="cb31-11"><a href="#cb31-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb31-12"><a href="#cb31-12" aria-hidden="true" tabindex="-1"></a>logger <span class="op">=</span> logging.getLogger(<span class="va">__name__</span>)</span>
<span id="cb31-13"><a href="#cb31-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb31-14"><a href="#cb31-14" aria-hidden="true" tabindex="-1"></a>CLASSES <span class="op">=</span> [<span class="st">"airplane"</span>, <span class="st">"automobile"</span>, <span class="st">"bird"</span>, <span class="st">"cat"</span>, <span class="st">"deer"</span>,</span>
<span id="cb31-15"><a href="#cb31-15" aria-hidden="true" tabindex="-1"></a>           <span class="st">"dog"</span>, <span class="st">"frog"</span>, <span class="st">"horse"</span>, <span class="st">"ship"</span>, <span class="st">"truck"</span>]</span>
<span id="cb31-16"><a href="#cb31-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb31-17"><a href="#cb31-17" aria-hidden="true" tabindex="-1"></a>transform <span class="op">=</span> T.Compose([</span>
<span id="cb31-18"><a href="#cb31-18" aria-hidden="true" tabindex="-1"></a>    T.Resize((<span class="dv">32</span>, <span class="dv">32</span>)),</span>
<span id="cb31-19"><a href="#cb31-19" aria-hidden="true" tabindex="-1"></a>    T.ToTensor(),</span>
<span id="cb31-20"><a href="#cb31-20" aria-hidden="true" tabindex="-1"></a>    T.Normalize([<span class="fl">0.4914</span>, <span class="fl">0.4822</span>, <span class="fl">0.4465</span>], [<span class="fl">0.2023</span>, <span class="fl">0.1994</span>, <span class="fl">0.2010</span>]),</span>
<span id="cb31-21"><a href="#cb31-21" aria-hidden="true" tabindex="-1"></a>])</span>
<span id="cb31-22"><a href="#cb31-22" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb31-23"><a href="#cb31-23" aria-hidden="true" tabindex="-1"></a><span class="co"># ── Global session (initialized once at startup) ──</span></span>
<span id="cb31-24"><a href="#cb31-24" aria-hidden="true" tabindex="-1"></a>session: ort.InferenceSession <span class="op">|</span> <span class="va">None</span> <span class="op">=</span> <span class="va">None</span></span>
<span id="cb31-25"><a href="#cb31-25" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb31-26"><a href="#cb31-26" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb31-27"><a href="#cb31-27" aria-hidden="true" tabindex="-1"></a><span class="at">@asynccontextmanager</span></span>
<span id="cb31-28"><a href="#cb31-28" aria-hidden="true" tabindex="-1"></a><span class="cf">async</span> <span class="kw">def</span> lifespan(app: FastAPI):</span>
<span id="cb31-29"><a href="#cb31-29" aria-hidden="true" tabindex="-1"></a>    <span class="kw">global</span> session</span>
<span id="cb31-30"><a href="#cb31-30" aria-hidden="true" tabindex="-1"></a>    logger.info(<span class="st">"Loading ONNX model…"</span>)</span>
<span id="cb31-31"><a href="#cb31-31" aria-hidden="true" tabindex="-1"></a>    opts <span class="op">=</span> ort.SessionOptions()</span>
<span id="cb31-32"><a href="#cb31-32" aria-hidden="true" tabindex="-1"></a>    opts.graph_optimization_level <span class="op">=</span> ort.GraphOptimizationLevel.ORT_ENABLE_ALL</span>
<span id="cb31-33"><a href="#cb31-33" aria-hidden="true" tabindex="-1"></a>    opts.intra_op_num_threads <span class="op">=</span> <span class="dv">4</span></span>
<span id="cb31-34"><a href="#cb31-34" aria-hidden="true" tabindex="-1"></a>    session <span class="op">=</span> ort.InferenceSession(</span>
<span id="cb31-35"><a href="#cb31-35" aria-hidden="true" tabindex="-1"></a>        <span class="st">"resnet18_cifar10.onnx"</span>,</span>
<span id="cb31-36"><a href="#cb31-36" aria-hidden="true" tabindex="-1"></a>        sess_options<span class="op">=</span>opts,</span>
<span id="cb31-37"><a href="#cb31-37" aria-hidden="true" tabindex="-1"></a>        providers<span class="op">=</span>[<span class="st">"CUDAExecutionProvider"</span>, <span class="st">"CPUExecutionProvider"</span>],</span>
<span id="cb31-38"><a href="#cb31-38" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb31-39"><a href="#cb31-39" aria-hidden="true" tabindex="-1"></a>    logger.info(<span class="st">"Model loaded ✓"</span>)</span>
<span id="cb31-40"><a href="#cb31-40" aria-hidden="true" tabindex="-1"></a>    <span class="cf">yield</span></span>
<span id="cb31-41"><a href="#cb31-41" aria-hidden="true" tabindex="-1"></a>    session <span class="op">=</span> <span class="va">None</span></span>
<span id="cb31-42"><a href="#cb31-42" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb31-43"><a href="#cb31-43" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb31-44"><a href="#cb31-44" aria-hidden="true" tabindex="-1"></a>app <span class="op">=</span> FastAPI(title<span class="op">=</span><span class="st">"CIFAR-10 Classifier"</span>, lifespan<span class="op">=</span>lifespan)</span>
<span id="cb31-45"><a href="#cb31-45" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb31-46"><a href="#cb31-46" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb31-47"><a href="#cb31-47" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> softmax(x: np.ndarray) <span class="op">-&gt;</span> np.ndarray:</span>
<span id="cb31-48"><a href="#cb31-48" aria-hidden="true" tabindex="-1"></a>    e <span class="op">=</span> np.exp(x <span class="op">-</span> x.<span class="bu">max</span>())</span>
<span id="cb31-49"><a href="#cb31-49" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> e <span class="op">/</span> e.<span class="bu">sum</span>()</span>
<span id="cb31-50"><a href="#cb31-50" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb31-51"><a href="#cb31-51" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb31-52"><a href="#cb31-52" aria-hidden="true" tabindex="-1"></a><span class="at">@app.post</span>(<span class="st">"/predict"</span>)</span>
<span id="cb31-53"><a href="#cb31-53" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> predict(<span class="bu">file</span>: UploadFile <span class="op">=</span> File(...)):</span>
<span id="cb31-54"><a href="#cb31-54" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> session <span class="kw">is</span> <span class="va">None</span>:</span>
<span id="cb31-55"><a href="#cb31-55" aria-hidden="true" tabindex="-1"></a>        <span class="cf">raise</span> HTTPException(status_code<span class="op">=</span><span class="dv">503</span>, detail<span class="op">=</span><span class="st">"Model not ready"</span>)</span>
<span id="cb31-56"><a href="#cb31-56" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Sync handler: FastAPI runs it in a threadpool, so the blocking session.run() call does not stall the event loop</span></span>
<span id="cb31-57"><a href="#cb31-57" aria-hidden="true" tabindex="-1"></a>    <span class="cf">try</span>:</span>
<span id="cb31-58"><a href="#cb31-58" aria-hidden="true" tabindex="-1"></a>        data <span class="op">=</span> <span class="bu">file</span>.<span class="bu">file</span>.read()</span>
<span id="cb31-59"><a href="#cb31-59" aria-hidden="true" tabindex="-1"></a>        image <span class="op">=</span> Image.<span class="bu">open</span>(io.BytesIO(data)).convert(<span class="st">"RGB"</span>)</span>
<span id="cb31-60"><a href="#cb31-60" aria-hidden="true" tabindex="-1"></a>    <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> exc:</span>
<span id="cb31-61"><a href="#cb31-61" aria-hidden="true" tabindex="-1"></a>        <span class="cf">raise</span> HTTPException(status_code<span class="op">=</span><span class="dv">400</span>, detail<span class="op">=</span><span class="ss">f"Invalid image: </span><span class="sc">{</span>exc<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb31-62"><a href="#cb31-62" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb31-63"><a href="#cb31-63" aria-hidden="true" tabindex="-1"></a>    tensor <span class="op">=</span> transform(image).unsqueeze(<span class="dv">0</span>).numpy()</span>
<span id="cb31-64"><a href="#cb31-64" aria-hidden="true" tabindex="-1"></a>    logits <span class="op">=</span> session.run(<span class="va">None</span>, {<span class="st">"images"</span>: tensor})[<span class="dv">0</span>][<span class="dv">0</span>]</span>
<span id="cb31-65"><a href="#cb31-65" aria-hidden="true" tabindex="-1"></a>    probs  <span class="op">=</span> softmax(logits)</span>
<span id="cb31-66"><a href="#cb31-66" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb31-67"><a href="#cb31-67" aria-hidden="true" tabindex="-1"></a>    top5_idx  <span class="op">=</span> probs.argsort()[::<span class="op">-</span><span class="dv">1</span>][:<span class="dv">5</span>]</span>
<span id="cb31-68"><a href="#cb31-68" aria-hidden="true" tabindex="-1"></a>    top5 <span class="op">=</span> [{<span class="st">"class"</span>: CLASSES[i], <span class="st">"probability"</span>: <span class="bu">float</span>(probs[i])}</span>
<span id="cb31-69"><a href="#cb31-69" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> i <span class="kw">in</span> top5_idx]</span>
<span id="cb31-70"><a href="#cb31-70" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb31-71"><a href="#cb31-71" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> JSONResponse({<span class="st">"top5"</span>: top5, <span class="st">"predicted"</span>: CLASSES[top5_idx[<span class="dv">0</span>]]})</span>
<span id="cb31-72"><a href="#cb31-72" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb31-73"><a href="#cb31-73" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb31-74"><a href="#cb31-74" aria-hidden="true" tabindex="-1"></a><span class="co"># Run with: uvicorn app:app --host 0.0.0.0 --port 8000 --workers 1</span></span></code></pre></div></div>
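<p>The server imports <code>torchvision</code> only for preprocessing, which pulls PyTorch into the deployment image. The following is a minimal NumPy sketch of the <code>ToTensor()</code> and <code>Normalize()</code> steps. It assumes the image has already been resized to 32×32 RGB, for instance with <code>image.resize((32, 32), Image.BILINEAR)</code>, which matches the default interpolation of <code>T.Resize</code>. The function name is hypothetical. Before swapping it in, compare its output against the torchvision pipeline on a few images.</p>

```python
import numpy as np

# CIFAR-10 channel statistics, copied from the T.Normalize call above
MEAN = np.array([0.4914, 0.4822, 0.4465], dtype=np.float32)
STD = np.array([0.2023, 0.1994, 0.2010], dtype=np.float32)


def preprocess_hwc_uint8(img):
    """Mimic T.ToTensor() + T.Normalize() for an already-resized RGB image.

    img: uint8 array of shape (32, 32, 3).
    Returns a float32 array of shape (1, 3, 32, 32) in NCHW layout.
    """
    x = img.astype(np.float32) / 255.0      # ToTensor: scale to [0, 1]
    x = (x - MEAN) / STD                    # per-channel normalize (broadcast over H, W)
    x = x.transpose(2, 0, 1)                # HWC to CHW
    return np.ascontiguousarray(x[np.newaxis])  # add the batch dimension
```

<p>Inside <code>predict</code>, <code>preprocess_hwc_uint8(np.asarray(image.resize((32, 32), Image.BILINEAR)))</code> would then replace <code>transform(image).unsqueeze(0).numpy()</code>.</p>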
</section>
<section id="edge-devices-and-mobile" class="level3">
<h3 class="anchored" data-anchor-id="edge-devices-and-mobile" id="edge-devices-and-mobile">Edge Devices and Mobile</h3>
<p>For edge deployment (Raspberry Pi, NVIDIA Jetson, Android, iOS), the recommended setup depends on the target hardware:</p>
<ol type="1">
<li><p><strong>ARM CPU</strong>: Use the <code>onnxruntime</code> Python package or the C/C++ shared library. ORT’s CPU provider is highly optimized via MLAS (Microsoft Linear Algebra Subprograms) and uses NEON intrinsics on ARM.</p></li>
<li><p><strong>NVIDIA Jetson</strong>: Install an <code>onnxruntime-gpu</code> build made for your JetPack release (ARM64 + CUDA). Alternatively, enable the TensorRT EP. It compiles the subgraphs that TensorRT supports into optimized engines and leaves the remaining nodes to the next provider in the list.</p></li>
<li><p><strong>Qualcomm SoC (Android)</strong>: Use the QNN execution provider with <code>onnxruntime-android</code> AAR.</p></li>
<li><p><strong>Apple Silicon / iOS</strong>: Use <code>CoreMLExecutionProvider</code> (available on macOS 12+ / iOS 15+).</p></li>
</ol>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb32"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb32-1"><a href="#cb32-1" aria-hidden="true" tabindex="-1"></a><span class="co"># CoreML on Apple Silicon</span></span>
<span id="cb32-2"><a href="#cb32-2" aria-hidden="true" tabindex="-1"></a>sess <span class="op">=</span> ort.InferenceSession(</span>
<span id="cb32-3"><a href="#cb32-3" aria-hidden="true" tabindex="-1"></a>    <span class="st">"resnet18_cifar10.onnx"</span>,</span>
<span id="cb32-4"><a href="#cb32-4" aria-hidden="true" tabindex="-1"></a>    providers<span class="op">=</span>[</span>
<span id="cb32-5"><a href="#cb32-5" aria-hidden="true" tabindex="-1"></a>        (<span class="st">"CoreMLExecutionProvider"</span>, {</span>
<span id="cb32-6"><a href="#cb32-6" aria-hidden="true" tabindex="-1"></a>            <span class="st">"MLComputeUnits"</span>: <span class="st">"ALL"</span>,   <span class="co"># or "CPUOnly", "CPUAndGPU", "CPUAndNeuralEngine"</span></span>
<span id="cb32-7"><a href="#cb32-7" aria-hidden="true" tabindex="-1"></a>        }),</span>
<span id="cb32-8"><a href="#cb32-8" aria-hidden="true" tabindex="-1"></a>        <span class="st">"CPUExecutionProvider"</span>,</span>
<span id="cb32-9"><a href="#cb32-9" aria-hidden="true" tabindex="-1"></a>    ],</span>
<span id="cb32-10"><a href="#cb32-10" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
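<p>On NVIDIA Jetson, the equivalent of the CoreML snippet is to put the TensorRT EP first, with CUDA and CPU behind it for subgraphs TensorRT cannot handle. The following sketch builds such a list. <code>trt_fp16_enable</code>, <code>trt_engine_cache_enable</code>, and <code>trt_engine_cache_path</code> are TensorRT EP options; the helper name and cache directory are assumptions. The list only helps if your <code>onnxruntime-gpu</code> build includes TensorRT support.</p>

```python
def jetson_providers(fp16=True, cache_dir="./trt_cache"):
    """Provider list for a TensorRT-enabled ORT GPU build (illustrative sketch).

    TensorRT EP first, with built engines cached on disk so later sessions
    can skip the slow build; CUDA and CPU follow as fallbacks.
    """
    trt_options = {
        "trt_fp16_enable": fp16,             # allow FP16 kernels where supported
        "trt_engine_cache_enable": True,     # persist built engines between runs
        "trt_engine_cache_path": cache_dir,  # hypothetical cache directory
    }
    return [
        ("TensorrtExecutionProvider", trt_options),
        "CUDAExecutionProvider",
        "CPUExecutionProvider",
    ]
```

<p>Pass the result to <code>ort.InferenceSession</code> as <code>providers=jetson_providers()</code>. The engine cache matters on embedded boards because building a TensorRT engine on first use can be slow.</p>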
</section>
<section id="onnx-runtime-web-browser" class="level3">
<h3 class="anchored" data-anchor-id="onnx-runtime-web-browser" id="onnx-runtime-web-browser">ONNX Runtime Web (Browser)</h3>
<p>ONNX Runtime Web (<code>onnxruntime-web</code>) runs ONNX models in the browser, using WebAssembly (WASM) on the CPU or WebGL on the GPU. Recent releases also provide a WebGPU backend.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb33"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb33-1"><a href="#cb33-1" aria-hidden="true" tabindex="-1"></a><span class="ex">npm</span> install onnxruntime-web</span></code></pre></div></div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb34"><pre class="sourceCode javascript code-with-copy"><code class="sourceCode javascript"><span id="cb34-1"><a href="#cb34-1" aria-hidden="true" tabindex="-1"></a><span class="co">// classifier.js</span></span>
<span id="cb34-2"><a href="#cb34-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> <span class="op">*</span> <span class="im">as</span> ort <span class="im">from</span> <span class="st">'onnxruntime-web'</span><span class="op">;</span></span>
<span id="cb34-3"><a href="#cb34-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb34-4"><a href="#cb34-4" aria-hidden="true" tabindex="-1"></a><span class="kw">const</span> CLASSES <span class="op">=</span> [<span class="st">'airplane'</span><span class="op">,</span><span class="st">'automobile'</span><span class="op">,</span><span class="st">'bird'</span><span class="op">,</span><span class="st">'cat'</span><span class="op">,</span><span class="st">'deer'</span><span class="op">,</span></span>
<span id="cb34-5"><a href="#cb34-5" aria-hidden="true" tabindex="-1"></a>                 <span class="st">'dog'</span><span class="op">,</span><span class="st">'frog'</span><span class="op">,</span><span class="st">'horse'</span><span class="op">,</span><span class="st">'ship'</span><span class="op">,</span><span class="st">'truck'</span>]<span class="op">;</span></span>
<span id="cb34-6"><a href="#cb34-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb34-7"><a href="#cb34-7" aria-hidden="true" tabindex="-1"></a><span class="co">// Create the session once and reuse it on every call</span></span>
<span id="cb34-8"><a href="#cb34-8" aria-hidden="true" tabindex="-1"></a><span class="kw">let</span> sessionPromise <span class="op">=</span> <span class="kw">null</span><span class="op">;</span></span>
<span id="cb34-9"><a href="#cb34-9" aria-hidden="true" tabindex="-1"></a><span class="kw">function</span> <span class="fu">getSession</span>() {</span>
<span id="cb34-10"><a href="#cb34-10" aria-hidden="true" tabindex="-1"></a>  <span class="cf">if</span> (<span class="op">!</span>sessionPromise) {</span>
<span id="cb34-11"><a href="#cb34-11" aria-hidden="true" tabindex="-1"></a>    sessionPromise <span class="op">=</span> ort<span class="op">.</span><span class="at">InferenceSession</span><span class="op">.</span><span class="fu">create</span>(<span class="st">'./resnet18_cifar10.onnx'</span><span class="op">,</span> {</span>
<span id="cb34-12"><a href="#cb34-12" aria-hidden="true" tabindex="-1"></a>      <span class="dt">executionProviders</span><span class="op">:</span> [<span class="st">'webgl'</span>]<span class="op">,</span>   <span class="co">// or 'wasm' for CPU</span></span>
<span id="cb34-13"><a href="#cb34-13" aria-hidden="true" tabindex="-1"></a>      <span class="dt">graphOptimizationLevel</span><span class="op">:</span> <span class="st">'all'</span><span class="op">,</span></span>
<span id="cb34-14"><a href="#cb34-14" aria-hidden="true" tabindex="-1"></a>    })<span class="op">;</span></span>
<span id="cb34-15"><a href="#cb34-15" aria-hidden="true" tabindex="-1"></a>  }</span>
<span id="cb34-16"><a href="#cb34-16" aria-hidden="true" tabindex="-1"></a>  <span class="cf">return</span> sessionPromise<span class="op">;</span></span>
<span id="cb34-17"><a href="#cb34-17" aria-hidden="true" tabindex="-1"></a>}</span>
<span id="cb34-18"><a href="#cb34-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb34-19"><a href="#cb34-19" aria-hidden="true" tabindex="-1"></a><span class="im">export</span> <span class="kw">async</span> <span class="kw">function</span> <span class="fu">classifyImage</span>(imageData) {</span>
<span id="cb34-20"><a href="#cb34-20" aria-hidden="true" tabindex="-1"></a>  <span class="kw">const</span> session <span class="op">=</span> <span class="cf">await</span> <span class="fu">getSession</span>()<span class="op">;</span></span>
<span id="cb34-21"><a href="#cb34-21" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb34-22"><a href="#cb34-22" aria-hidden="true" tabindex="-1"></a>  <span class="co">// imageData: normalized Float32Array of 1*3*32*32 values in NCHW order</span></span>
<span id="cb34-23"><a href="#cb34-23" aria-hidden="true" tabindex="-1"></a>  <span class="kw">const</span> tensor <span class="op">=</span> <span class="kw">new</span> ort<span class="op">.</span><span class="fu">Tensor</span>(<span class="st">'float32'</span><span class="op">,</span> imageData<span class="op">,</span> [<span class="dv">1</span><span class="op">,</span> <span class="dv">3</span><span class="op">,</span> <span class="dv">32</span><span class="op">,</span> <span class="dv">32</span>])<span class="op">;</span></span>
<span id="cb34-24"><a href="#cb34-24" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb34-25"><a href="#cb34-25" aria-hidden="true" tabindex="-1"></a>  <span class="kw">const</span> results <span class="op">=</span> <span class="cf">await</span> session<span class="op">.</span><span class="fu">run</span>({ <span class="dt">images</span><span class="op">:</span> tensor })<span class="op">;</span></span>
<span id="cb34-26"><a href="#cb34-26" aria-hidden="true" tabindex="-1"></a>  <span class="kw">const</span> logits  <span class="op">=</span> results[<span class="st">'logits'</span>]<span class="op">.</span><span class="at">data</span><span class="op">;</span>   <span class="co">// Float32Array of length 10</span></span>
<span id="cb34-27"><a href="#cb34-27" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb34-28"><a href="#cb34-28" aria-hidden="true" tabindex="-1"></a>  <span class="co">// Softmax and argmax</span></span>
<span id="cb34-29"><a href="#cb34-29" aria-hidden="true" tabindex="-1"></a>  <span class="kw">const</span> max  <span class="op">=</span> <span class="bu">Math</span><span class="op">.</span><span class="fu">max</span>(<span class="op">...</span>logits)<span class="op">;</span></span>
<span id="cb34-30"><a href="#cb34-30" aria-hidden="true" tabindex="-1"></a>  <span class="kw">const</span> exps <span class="op">=</span> logits<span class="op">.</span><span class="fu">map</span>(v <span class="kw">=&gt;</span> <span class="bu">Math</span><span class="op">.</span><span class="fu">exp</span>(v <span class="op">-</span> max))<span class="op">;</span></span>
<span id="cb34-31"><a href="#cb34-31" aria-hidden="true" tabindex="-1"></a>  <span class="kw">const</span> sum  <span class="op">=</span> exps<span class="op">.</span><span class="fu">reduce</span>((a<span class="op">,</span> b) <span class="kw">=&gt;</span> a <span class="op">+</span> b<span class="op">,</span> <span class="dv">0</span>)<span class="op">;</span></span>
<span id="cb34-32"><a href="#cb34-32" aria-hidden="true" tabindex="-1"></a>  <span class="kw">const</span> probs <span class="op">=</span> exps<span class="op">.</span><span class="fu">map</span>(v <span class="kw">=&gt;</span> v <span class="op">/</span> sum)<span class="op">;</span></span>
<span id="cb34-33"><a href="#cb34-33" aria-hidden="true" tabindex="-1"></a>  <span class="kw">const</span> classIdx <span class="op">=</span> probs<span class="op">.</span><span class="fu">indexOf</span>(<span class="bu">Math</span><span class="op">.</span><span class="fu">max</span>(<span class="op">...</span>probs))<span class="op">;</span></span>
<span id="cb34-34"><a href="#cb34-34" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb34-35"><a href="#cb34-35" aria-hidden="true" tabindex="-1"></a>  <span class="bu">console</span><span class="op">.</span><span class="fu">log</span>(<span class="vs">`Predicted: </span><span class="sc">${</span>CLASSES[classIdx]<span class="sc">}</span><span class="vs"> (</span><span class="sc">${</span>(probs[classIdx]<span class="op">*</span><span class="dv">100</span>)<span class="op">.</span><span class="fu">toFixed</span>(<span class="dv">1</span>)<span class="sc">}</span><span class="vs">%)`</span>)<span class="op">;</span></span>
<span id="cb34-36"><a href="#cb34-36" aria-hidden="true" tabindex="-1"></a>  <span class="cf">return</span> { <span class="dt">label</span><span class="op">:</span> CLASSES[classIdx]<span class="op">,</span> <span class="dt">probability</span><span class="op">:</span> probs[classIdx] }<span class="op">;</span></span>
<span id="cb34-37"><a href="#cb34-37" aria-hidden="true" tabindex="-1"></a>}</span></code></pre></div></div>
<hr>
</section>
</section>
<section id="common-pitfalls-and-troubleshooting" class="level2">
<h2 class="anchored" data-anchor-id="common-pitfalls-and-troubleshooting" id="common-pitfalls-and-troubleshooting">Common Pitfalls and Troubleshooting</h2>
<div class="cell" data-eval="true" data-layout-align="center">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">flowchart TD
    A["Inference produces bad results or crashes"] --&gt; B{"Error type?"}

    B --&gt; C["Inconsistent predictions low accuracy"]
    B --&gt; D["InvalidGraph: opset version error"]
    B --&gt; E["Shape mismatch at runtime"]
    B --&gt; F["Unregistered op or custom op error"]
    B --&gt; G["Quantization fails shape undefined"]
    B --&gt; H["NHWC/NCHW garbage output"]
    B --&gt; I["Slow first call only"]

    C --&gt; C1["Fix: call model.eval() before export"]
    D --&gt; D1["Fix: lower opset_version to 17 or 18"]
    E --&gt; E1["Fix: add dynamic_axes to export call"]
    F --&gt; F1["Fix: rewrite with standard ONNX primitives or register custom ORT kernel"]
    G --&gt; G1["Fix: run shape_inference before quantize_static"]
    H --&gt; H1["Fix: transpose input NCHW ↔ NHWC"]
    I --&gt; I1["Fix: add warm-up inference calls"]
</pre>
</div>
<p></p></figure><p></p>
</div>
</div>
</div>
<section id="model-not-in-eval-mode-before-export" class="level3">
<h3 class="anchored" data-anchor-id="model-not-in-eval-mode-before-export" id="model-not-in-eval-mode-before-export">1. Model Not in Eval Mode Before Export</h3>
<p><strong>Symptom</strong>: Predictions vary from run to run, or accuracy in ORT is noticeably lower than in PyTorch.</p>
<p><strong>Cause</strong>: In training mode, <code>torch.nn.Dropout</code> randomly zeroes activations. <code>BatchNorm</code> also normalizes with the statistics of the current batch instead of the running mean and variance accumulated during training. Exporting in this state can bake training-time behavior into the graph.</p>
<p><strong>Fix</strong>: Always call <code>model.eval()</code> before <code>torch.onnx.export()</code>.</p>
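<p>A quick way to catch this kind of bug is to compare PyTorch and ORT outputs on the same batch right after export. The helper below is a sketch. Its name and default tolerances are assumptions, and FP16 or quantized models usually need looser tolerances.</p>

```python
import numpy as np


def check_parity(reference, candidate, rtol=1e-3, atol=1e-5):
    """Compare reference (e.g. PyTorch) outputs with ONNX Runtime outputs.

    Returns (passed, max_abs_diff, top1_agreement), where top1_agreement is
    the fraction of samples whose arg-max class matches (last axis = classes).
    """
    reference = np.asarray(reference, dtype=np.float32)
    candidate = np.asarray(candidate, dtype=np.float32)
    max_abs_diff = float(np.max(np.abs(reference - candidate)))
    passed = bool(np.allclose(reference, candidate, rtol=rtol, atol=atol))
    top1 = float(np.mean(reference.argmax(-1) == candidate.argmax(-1)))
    return passed, max_abs_diff, top1
```

<p>For example, with <code>x = torch.randn(8, 3, 32, 32)</code>, compare <code>model.eval()(x).detach().numpy()</code> against <code>sess.run(None, {"images": x.numpy()})[0]</code>. A model exported in training mode will typically fail the tolerance check and often the top-1 agreement as well.</p>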
<hr>
</section>
<section id="opset-mismatch" class="level3">
<h3 class="anchored" data-anchor-id="opset-mismatch" id="opset-mismatch">2. Opset Mismatch</h3>
<p><strong>Symptom</strong>: <code>onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph: Node has unsupported opset version</code>.</p>
<p><strong>Cause</strong>: The model was exported with a higher opset than the installed ORT version supports, or it uses operators that exist only in newer opsets.</p>
<p><strong>Fix</strong>: Check the installed version with <code>ort.__version__</code> and compare it against the <a href="https://onnxruntime.ai/docs/reference/compatibility.html">ORT opset compatibility matrix</a>. Then re-export with a supported <code>opset_version</code>; <code>17</code> or <code>18</code> gives broad compatibility.</p>
<hr>
</section>
<section id="fixed-batch-size-in-the-model" class="level3">
<h3 class="anchored" data-anchor-id="fixed-batch-size-in-the-model" id="fixed-batch-size-in-the-model">3. Fixed Batch Size in the Model</h3>
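<p><strong>Symptom</strong>: Running a batch of 8 images through a model exported with <code>dummy_input = torch.randn(1, 3, 32, 32)</code> fails with a shape error.</p>
<p><strong>Cause</strong>: Unless told otherwise, the exporter records the dummy input's shape as fixed, so the batch dimension is hard-coded to 1.</p>
<p><strong>Fix</strong>: Name the inputs and outputs, and mark the batch axis as dynamic when exporting:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb35"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb35-1"><a href="#cb35-1" aria-hidden="true" tabindex="-1"></a>torch.onnx.export(</span>
<span id="cb35-2"><a href="#cb35-2" aria-hidden="true" tabindex="-1"></a>    model, dummy_input, path,</span>
<span id="cb35-3"><a href="#cb35-3" aria-hidden="true" tabindex="-1"></a>    input_names<span class="op">=</span>[<span class="st">"images"</span>],       <span class="co"># names referenced by dynamic_axes</span></span>
<span id="cb35-4"><a href="#cb35-4" aria-hidden="true" tabindex="-1"></a>    output_names<span class="op">=</span>[<span class="st">"logits"</span>],</span>
<span id="cb35-5"><a href="#cb35-5" aria-hidden="true" tabindex="-1"></a>    dynamic_axes<span class="op">=</span>{<span class="st">"images"</span>: {<span class="dv">0</span>: <span class="st">"batch"</span>}, <span class="st">"logits"</span>: {<span class="dv">0</span>: <span class="st">"batch"</span>}},</span>
<span id="cb35-6"><a href="#cb35-6" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>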
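<p>After re-exporting, you can confirm that the batch axis is really symbolic by inspecting the session's input metadata. <code>sess.get_inputs()[0].shape</code> reports fixed dimensions as integers and symbolic ones by name (as strings), and typically reports dimensions with no size information as <code>None</code>. The helper names below are hypothetical.</p>

```python
def describe_dims(shape):
    """Label each dimension of an ORT NodeArg shape as fixed, symbolic, or unknown.

    Fixed sizes are ints; symbolic axes (e.g. names given in dynamic_axes)
    are strings; anything else (such as None) is treated as unknown.
    """
    labels = []
    for dim in shape:
        if isinstance(dim, int):
            labels.append(f"fixed({dim})")
        elif isinstance(dim, str):
            labels.append(f"symbolic({dim})")
        else:
            labels.append("unknown")
    return labels


def has_dynamic_batch(shape):
    """True when the leading (batch) dimension is not a hard-coded integer."""
    return bool(shape) and not isinstance(shape[0], int)
```

<p>For the export above, <code>describe_dims(sess.get_inputs()[0].shape)</code> should return something like <code>["symbolic(batch)", "fixed(3)", "fixed(32)", "fixed(32)"]</code>.</p>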
<hr>
</section>
<section id="custom-unsupported-operators" class="level3">
<h3 class="anchored" data-anchor-id="custom-unsupported-operators" id="custom-unsupported-operators">4. Custom / Unsupported Operators</h3>
<p><strong>Symptom</strong>: <code>onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph: Node ... is not a registered function/op</code>.</p>
<p><strong>Cause</strong>: Your model uses a PyTorch custom op or a very new operator not yet in ORT.</p>
<p><strong>Fix</strong>: Either rewrite the custom op using standard ONNX primitives, or implement a <a href="https://onnxruntime.ai/docs/reference/operators/add-custom-op.html">custom ONNX Runtime operator</a> in C++.</p>
<hr>
</section>
<section id="shape-inference-failures-in-quantization" class="level3">
<h3 class="anchored" data-anchor-id="shape-inference-failures-in-quantization" id="shape-inference-failures-in-quantization">5. Shape Inference Failures in Quantization</h3>
<p><strong>Symptom</strong>: Quantization fails with <code>ValueError: Shape of input is not fully defined</code>.</p>
<p><strong>Fix</strong>: Run shape inference before quantization:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb36"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb36-1"><a href="#cb36-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> onnxruntime.quantization <span class="im">import</span> shape_inference</span>
<span id="cb36-2"><a href="#cb36-2" aria-hidden="true" tabindex="-1"></a>shape_inference.quant_pre_process(<span class="st">"model.onnx"</span>, <span class="st">"model_inferred.onnx"</span>)</span>
<span id="cb36-3"><a href="#cb36-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Then quantize model_inferred.onnx</span></span></code></pre></div></div>
<hr>
</section>
<section id="memory-layout-mismatches-nhwc-vs.-nchw" class="level3">
<h3 class="anchored" data-anchor-id="memory-layout-mismatches-nhwc-vs.-nchw" id="memory-layout-mismatches-nhwc-vs.-nchw">6. Memory Layout Mismatches (NHWC vs.&nbsp;NCHW)</h3>
<p><strong>Symptom</strong>: Garbage outputs when running TensorFlow-exported models.</p>
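<p><strong>Cause</strong>: TensorFlow uses NHWC (batch, height, width, channels) by default, while PyTorch uses NCHW. Converters such as tf2onnx keep the model's NHWC input layout unless told otherwise, so feeding NCHW data produces garbage.</p>
<p><strong>Fix</strong>: Transpose your NumPy array to the layout the model expects before feeding it to ORT. Alternatively, convert with tf2onnx's <code>--inputs-as-nchw</code> option, which adds the Transpose to the ONNX graph for you.</p>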
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb37"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb37-1"><a href="#cb37-1" aria-hidden="true" tabindex="-1"></a><span class="co"># TF model exported with NHWC — transpose input accordingly</span></span>
<span id="cb37-2"><a href="#cb37-2" aria-hidden="true" tabindex="-1"></a>image_nhwc <span class="op">=</span> image_nchw.transpose(<span class="dv">0</span>, <span class="dv">2</span>, <span class="dv">3</span>, <span class="dv">1</span>)   <span class="co"># NCHW → NHWC</span></span>
<span id="cb37-3"><a href="#cb37-3" aria-hidden="true" tabindex="-1"></a>sess.run(<span class="va">None</span>, {<span class="st">"input"</span>: image_nhwc})</span></code></pre></div></div>
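<p>When one pipeline has to serve both kinds of models, a small helper can choose the layout from the model's declared input shape. The heuristic below treats a channel axis of size 1 or 3 in the last position as NHWC. This is an assumption that suits typical image models, not a rule ONNX enforces, and the function name is hypothetical.</p>

```python
import numpy as np


def to_model_layout(x_nchw, model_input_shape):
    """Transpose an NCHW batch to NHWC if the model input looks channels-last.

    Heuristic (assumption): a 4-D input whose last dim is 1 or 3 and whose
    dim 1 is not 1 or 3 is treated as NHWC; otherwise x is returned unchanged.
    """
    shape = list(model_input_shape)
    channels_last = (
        len(shape) == 4
        and shape[-1] in (1, 3)
        and shape[1] not in (1, 3)
    )
    if channels_last:
        # NCHW to NHWC, made contiguous for ORT
        return np.ascontiguousarray(x_nchw.transpose(0, 2, 3, 1))
    return x_nchw
```

<p>Call it as <code>to_model_layout(batch, sess.get_inputs()[0].shape)</code>. For models whose layout the heuristic cannot decide, hard-code the transpose instead.</p>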
<hr>
</section>
<section id="slow-cold-start" class="level3">
<h3 class="anchored" data-anchor-id="slow-cold-start" id="slow-cold-start">7. Slow Cold-Start</h3>
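<p><strong>Symptom</strong>: The first inference call takes far longer (e.g., around 500 ms) than subsequent calls.</p>
<p><strong>Cause</strong>: The first run pays one-time setup costs. These include memory arena allocation, kernel selection, and provider-specific work such as cuDNN convolution algorithm search on the CUDA EP or engine building on the TensorRT EP.</p>
<p><strong>Fix</strong>: Run several warm-up inferences with representative input shapes before measuring latency or serving traffic.</p>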
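<p>The following is a minimal benchmarking sketch that discards warm-up runs before timing. It works with any zero-argument callable. The warm-up and iteration counts are illustrative defaults, and the function name is hypothetical.</p>

```python
import statistics
import time


def benchmark(run_once, warmup=5, iters=50):
    """Time a zero-argument inference callable after discarding warm-up runs.

    Returns mean, median, and approximate 95th-percentile latency in ms.
    """
    for _ in range(warmup):
        run_once()  # absorbs one-time allocation and kernel-selection costs
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        run_once()
        samples.append((time.perf_counter() - start) * 1000.0)
    samples.sort()
    return {
        "mean_ms": statistics.fmean(samples),
        "p50_ms": statistics.median(samples),
        "p95_ms": samples[int(0.95 * (len(samples) - 1))],
    }
```

<p>For example, call <code>benchmark(lambda: sess.run(None, {"images": x}))</code>. The same idea applies in the FastAPI server: a few <code>session.run</code> calls on dummy input inside <code>lifespan</code>, before <code>yield</code>, keep the first real request from paying the cold-start cost.</p>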
<hr>
</section>
</section>
<section id="advanced-topics" class="level2">
<h2 class="anchored" data-anchor-id="advanced-topics" id="advanced-topics">Advanced Topics</h2>
<section id="dynamic-axes-and-variable-batch-sizes" class="level3">
<h3 class="anchored" data-anchor-id="dynamic-axes-and-variable-batch-sizes" id="dynamic-axes-and-variable-batch-sizes">Dynamic Axes and Variable Batch Sizes</h3>
<p>Marking axes as symbolic lets one exported model accept inputs whose size varies along those axes:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb38"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb38-1"><a href="#cb38-1" aria-hidden="true" tabindex="-1"></a>torch.onnx.export(</span>
<span id="cb38-2"><a href="#cb38-2" aria-hidden="true" tabindex="-1"></a>    model, dummy_input, <span class="st">"model.onnx"</span>,</span>
<span id="cb38-3"><a href="#cb38-3" aria-hidden="true" tabindex="-1"></a>    input_names<span class="op">=</span>[<span class="st">"images"</span>],</span>
<span id="cb38-4"><a href="#cb38-4" aria-hidden="true" tabindex="-1"></a>    output_names<span class="op">=</span>[<span class="st">"logits"</span>],</span>
<span id="cb38-5"><a href="#cb38-5" aria-hidden="true" tabindex="-1"></a>    dynamic_axes<span class="op">=</span>{</span>
<span id="cb38-6"><a href="#cb38-6" aria-hidden="true" tabindex="-1"></a>        <span class="st">"images"</span>: {<span class="dv">0</span>: <span class="st">"batch_size"</span>, <span class="dv">2</span>: <span class="st">"height"</span>, <span class="dv">3</span>: <span class="st">"width"</span>},</span>
<span id="cb38-7"><a href="#cb38-7" aria-hidden="true" tabindex="-1"></a>        <span class="st">"logits"</span>: {<span class="dv">0</span>: <span class="st">"batch_size"</span>},</span>
<span id="cb38-8"><a href="#cb38-8" aria-hidden="true" tabindex="-1"></a>    },</span>
<span id="cb38-9"><a href="#cb38-9" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
<p>At inference time:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb39"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb39-1"><a href="#cb39-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Works for any batch size and any spatial resolution</span></span>
<span id="cb39-2"><a href="#cb39-2" aria-hidden="true" tabindex="-1"></a>sess.run(<span class="va">None</span>, {<span class="st">"images"</span>: np.random.randn(<span class="dv">16</span>, <span class="dv">3</span>, <span class="dv">224</span>, <span class="dv">224</span>).astype(np.float32)})</span>
<span id="cb39-3"><a href="#cb39-3" aria-hidden="true" tabindex="-1"></a>sess.run(<span class="va">None</span>, {<span class="st">"images"</span>: np.random.randn(<span class="dv">1</span>,  <span class="dv">3</span>, <span class="dv">512</span>, <span class="dv">512</span>).astype(np.float32)})</span></code></pre></div></div>
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Warning
</div>
</div>
<div class="callout-body-container callout-body">
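<p>Not all models tolerate dynamic spatial axes. Models whose layers assume a fixed input resolution, such as a flatten followed by a fully connected layer or the fixed positional embeddings of a ViT, may fail at runtime with a shape mismatch, or silently produce wrong results if the exporter baked shape-dependent values into the graph as constants.</p>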
</div>
</div>
</section>
<section id="custom-operators" class="level3">
<h3 class="anchored" data-anchor-id="custom-operators" id="custom-operators">Custom Operators</h3>
<p>If your model uses a custom PyTorch operator, you can register a corresponding ONNX operator and ORT kernel:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb40"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb40-1"><a href="#cb40-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Step 1: Register a custom symbolic function in PyTorch</span></span>
<span id="cb40-2"><a href="#cb40-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.onnx <span class="im">import</span> register_custom_op_symbolic</span>
<span id="cb40-3"><a href="#cb40-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb40-4"><a href="#cb40-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> my_custom_op_symbolic(g, <span class="bu">input</span>, weight):</span>
<span id="cb40-5"><a href="#cb40-5" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> g.op(<span class="st">"custom_domain::MyCustomOp"</span>, <span class="bu">input</span>, weight)</span>
<span id="cb40-6"><a href="#cb40-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb40-7"><a href="#cb40-7" aria-hidden="true" tabindex="-1"></a>register_custom_op_symbolic(<span class="st">"my_package::my_custom_op"</span>, my_custom_op_symbolic, <span class="dv">1</span>)</span>
<span id="cb40-8"><a href="#cb40-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb40-9"><a href="#cb40-9" aria-hidden="true" tabindex="-1"></a><span class="co"># Step 2: Implement the ORT kernel in C++ and compile as a shared library</span></span>
<span id="cb40-10"><a href="#cb40-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Step 3: Load the custom op library in ORT</span></span>
<span id="cb40-11"><a href="#cb40-11" aria-hidden="true" tabindex="-1"></a>opts <span class="op">=</span> ort.SessionOptions()</span>
<span id="cb40-12"><a href="#cb40-12" aria-hidden="true" tabindex="-1"></a>opts.register_custom_ops_library(<span class="st">"./libmy_custom_ops.so"</span>)</span>
<span id="cb40-13"><a href="#cb40-13" aria-hidden="true" tabindex="-1"></a>sess <span class="op">=</span> ort.InferenceSession(<span class="st">"model_with_custom_op.onnx"</span>, sess_options<span class="op">=</span>opts)</span></code></pre></div></div>
</section>
<section id="onnx-training-api" class="level3">
<h3 class="anchored" data-anchor-id="onnx-training-api" id="onnx-training-api">ONNX Training API</h3>
<p>ONNX Runtime also supports training, through two separate APIs. The on-device <strong>Training API</strong> (<code>onnxruntime.training.api</code>) runs fine-tuning from pre-generated ONNX training artifacts without a training-framework dependency at run time, which is useful for federated learning and on-device personalization. <code>ORTModule</code>, used in the snippet below, is a different path: it wraps a PyTorch <code>nn.Module</code> so that ORT executes the forward and backward passes while PyTorch still drives the training loop.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb41"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb41-1"><a href="#cb41-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Requires the onnxruntime-training package in addition to PyTorch</span></span>
<span id="cb41-2"><a href="#cb41-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb41-3"><a href="#cb41-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb41-4"><a href="#cb41-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> onnxruntime.training.ortmodule <span class="im">import</span> ORTModule</span>
<span id="cb41-5"><a href="#cb41-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb41-6"><a href="#cb41-6" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> YourModel()             <span class="co"># placeholder for any torch.nn.Module</span></span>
<span id="cb41-7"><a href="#cb41-7" aria-hidden="true" tabindex="-1"></a>ort_model <span class="op">=</span> ORTModule(model)    <span class="co"># wraps the model; ORT runs forward and backward</span></span>
<span id="cb41-8"><a href="#cb41-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb41-9"><a href="#cb41-9" aria-hidden="true" tabindex="-1"></a>criterion <span class="op">=</span> nn.CrossEntropyLoss()</span>
<span id="cb41-10"><a href="#cb41-10" aria-hidden="true" tabindex="-1"></a>optimizer <span class="op">=</span> torch.optim.Adam(ort_model.parameters())</span>
<span id="cb41-11"><a href="#cb41-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb41-12"><a href="#cb41-12" aria-hidden="true" tabindex="-1"></a><span class="co"># Training loop is identical to standard PyTorch (train_loader is a placeholder)</span></span>
<span id="cb41-13"><a href="#cb41-13" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> images, labels <span class="kw">in</span> train_loader:</span>
<span id="cb41-14"><a href="#cb41-14" aria-hidden="true" tabindex="-1"></a>    outputs <span class="op">=</span> ort_model(images)</span>
<span id="cb41-15"><a href="#cb41-15" aria-hidden="true" tabindex="-1"></a>    loss <span class="op">=</span> criterion(outputs, labels)</span>
<span id="cb41-16"><a href="#cb41-16" aria-hidden="true" tabindex="-1"></a>    loss.backward()</span>
<span id="cb41-17"><a href="#cb41-17" aria-hidden="true" tabindex="-1"></a>    optimizer.step()</span>
<span id="cb41-18"><a href="#cb41-18" aria-hidden="true" tabindex="-1"></a>    optimizer.zero_grad()</span></code></pre></div></div>
<hr>
</section>
</section>
<section id="summary-and-best-practices" class="level2">
<h2 class="anchored" data-anchor-id="summary-and-best-practices" id="summary-and-best-practices">Summary and Best Practices</h2>
<section id="workflow-checklist" class="level3">
<h3 class="anchored" data-anchor-id="workflow-checklist" id="workflow-checklist">Workflow Checklist</h3>
<ul class="task-list">
<li><label><input type="checkbox">Train your model with standard augmentation and regularization</label></li>
<li><label><input type="checkbox">Set <code>model.eval()</code> before any export</label></li>
<li><label><input type="checkbox">Export with <code>dynamic_axes</code> for the batch dimension (and optionally spatial dimensions)</label></li>
<li><label><input type="checkbox">Validate the exported model with <code>onnx.checker.check_model(..., full_check=True)</code></label></li>
<li><label><input type="checkbox">Numerically compare ORT output to the source framework output (max diff &lt; 1e-4)</label></li>
<li><label><input type="checkbox">Run shape inference to annotate intermediate tensor shapes</label></li>
<li><label><input type="checkbox">Apply graph optimization (<code>ORT_ENABLE_ALL</code>)</label></li>
<li><label><input type="checkbox">For resource-constrained targets: apply static INT8 quantization with a representative calibration set</label></li>
<li><label><input type="checkbox">Benchmark on target hardware with realistic batch sizes and input resolutions</label></li>
<li><label><input type="checkbox">Profile operator timings to identify bottlenecks</label></li>
<li><label><input type="checkbox">Always warm up the session before benchmarking or serving</label></li>
</ul>
</section>
<section id="optimization-priority" class="level3">
<h3 class="anchored" data-anchor-id="optimization-priority" id="optimization-priority">Optimization Priority</h3>
<p>The most impactful optimizations, roughly ordered by return on investment:</p>
<ol type="1">
<li><strong>Graph optimization</strong> — free, always apply</li>
<li><strong>INT8 quantization</strong> — often a 2–4× speedup on hardware with efficient INT8 kernels; accuracy loss is usually small with careful calibration</li>
<li><strong>Batching</strong> — substantially increases GPU utilization and throughput, at some cost in per-request latency</li>
<li><strong>IO binding</strong> — avoids redundant host↔device copies for GPU workloads</li>
<li><strong>TensorRT EP</strong> — typically the highest throughput on NVIDIA GPUs</li>
<li><strong>Structured pruning</strong> — reduces model FLOPs before export</li>
</ol>
<p>With these tools in hand, a model trained on a research GPU cluster can be reliably deployed on everything from a server rack to a Raspberry Pi to a browser tab.</p>
<hr>
<div class="callout callout-style-default callout-note callout-titled" title="Version Compatibility">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Version Compatibility
</div>
</div>
<div class="callout-body-container callout-body">
<p>This guide was written for ONNX opset 18, ONNX Runtime 1.18+, PyTorch 2.3+, and TensorFlow 2.16+. API details may change in future releases; always consult the <a href="https://onnxruntime.ai/docs/">official ONNX Runtime documentation</a> for the latest details.</p>
</div>
</div>



</section>
</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Building Neural Network Architectures Using Only ONNX]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/deployment/onnx-model/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/deployment/onnx-model/</guid>
      <pubDate>Tue, 19 May 2026 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>mlops</category>
      <category>advanced</category>
      <content:encoded><![CDATA[






<section id="building-neural-network-architectures-using-only-onnx" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/deployment/onnx-model/onnx.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>ONNX (Open Neural Network Exchange) is most commonly known as an export target — a format you dump a PyTorch or TensorFlow model into for deployment. But ONNX is also a fully self-contained, expressive intermediate representation that you can build directly, without ever touching a training framework. This guide treats ONNX as a first-class construction language, not a second-class export artifact.</p>
<p>Why would you want to build networks directly in ONNX?</p>
<p><strong>Portability without framework lock-in.</strong> An ONNX graph runs on any hardware that ONNX Runtime supports through its execution providers: CPU, CUDA, DirectML, TensorRT, OpenVINO, CoreML, and more. If you define your architecture in ONNX directly, there is no training framework to install or version-pin.</p>
<p><strong>Deterministic, inspectable graphs.</strong> When you export from PyTorch, the resulting graph depends on tracing or scripting heuristics that can produce surprising operator sequences. When you write ONNX directly, you know exactly what every node does.</p>
<p><strong>Extremely lightweight deployments.</strong> The <code>onnx</code> and <code>onnxruntime</code> packages have a far smaller dependency footprint than PyTorch or TensorFlow. For embedded systems, edge devices, or serverless inference, this matters enormously.</p>
<p><strong>Fine-grained graph surgery.</strong> If you need to fuse operators, insert quantization nodes, rewire connections, or experiment with non-standard topologies, working at the ONNX level directly gives you exact control with no framework abstractions in the way.</p>
<p><strong>Learning how neural networks really work.</strong> Building an architecture from raw matrix multiply and activation nodes forces you to understand every dimension, every weight layout, every broadcasting rule. It is an excellent exercise and deeply illuminating.</p>
<p>This guide assumes basic Python proficiency and some familiarity with neural network concepts (layers, activations, convolutions). It does not assume you have ever used PyTorch or TensorFlow.</p>
</section>
<section id="understanding-the-onnx-format" class="level2">
<h2 class="anchored" data-anchor-id="understanding-the-onnx-format" id="understanding-the-onnx-format">Understanding the ONNX Format</h2>
<p>An ONNX model is a serialized <a href="https://protobuf.dev/">Protocol Buffer</a> file. The <code>.onnx</code> extension is standard, but the file is just a binary proto. The schema is defined in the <a href="https://github.com/onnx/onnx/blob/main/onnx/onnx.proto">ONNX specification</a>.</p>
<p>At the top level, a <code>ModelProto</code> contains:</p>
<ul>
<li><strong><code>ir_version</code></strong>: The ONNX IR (Intermediate Representation) version.</li>
<li><strong><code>opset_import</code></strong>: Which operator sets (and which versions of them) this model uses. Most models use the default <code>""</code> domain with a version like <code>17</code> or <code>21</code>. (<code>onnx.helper.make_model</code> accepts this as the <code>opset_imports</code> keyword.)</li>
<li><strong><code>graph</code></strong>: A <code>GraphProto</code> — the actual computation graph.</li>
<li><strong><code>producer_name</code></strong>, <strong><code>producer_version</code></strong>, <strong><code>domain</code></strong>, <strong><code>model_version</code></strong>, <strong><code>doc_string</code></strong>: Metadata fields.</li>
</ul>
<p>The <code>GraphProto</code> contains:</p>
<ul>
<li><strong><code>node</code></strong>: A list of <code>NodeProto</code> objects. Each node is one operation.</li>
<li><strong><code>initializer</code></strong>: A list of <code>TensorProto</code> objects representing constant tensors — weights, biases, embedding tables, etc.</li>
<li><strong><code>input</code></strong>: A list of <code>ValueInfoProto</code> describing the graph’s external inputs (their names, types, and shapes).</li>
<li><strong><code>output</code></strong>: A list of <code>ValueInfoProto</code> describing the graph’s outputs.</li>
</ul>
<p>Each <code>NodeProto</code> contains:</p>
<ul>
<li><strong><code>op_type</code></strong>: The name of the operator, e.g., <code>"Gemm"</code>, <code>"Conv"</code>, <code>"Relu"</code>.</li>
<li><strong><code>domain</code></strong>: Usually <code>""</code> for standard ONNX ops.</li>
<li><strong><code>input</code></strong>: A list of string names — the tensors this node consumes.</li>
<li><strong><code>output</code></strong>: A list of string names — the tensors this node produces.</li>
<li><strong><code>attribute</code></strong>: A list of <code>AttributeProto</code> objects — static hyperparameters like kernel size, axis, epsilon, etc.</li>
</ul>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Tip
</div>
</div>
<div class="callout-body-container callout-body">
<p>Tensor names are just strings. They act as edges in the dataflow graph. If node A produces an output named <code>"relu_out"</code> and node B lists <code>"relu_out"</code> as an input, then B receives A’s output. This is the complete wiring mechanism.</p>
</div>
</div>
</section>
<section id="the-onnx-protobuf-schema" class="level2">
<h2 class="anchored" data-anchor-id="the-onnx-protobuf-schema" id="the-onnx-protobuf-schema">The ONNX Protobuf Schema</h2>
<p>You do not need to write raw protobuf. The <code>onnx</code> Python package provides a rich helper library (<code>onnx.helper</code>, <code>onnx.numpy_helper</code>, <code>onnx.checker</code>) that builds proto objects for you. However, understanding the schema directly will save you many hours of debugging.</p>
<section id="tensorproto-data-types" class="level3">
<h3 class="anchored" data-anchor-id="tensorproto-data-types" id="tensorproto-data-types">TensorProto Data Types</h3>
<p>Every tensor in ONNX has an element type, encoded as an integer enum:</p>
<table class="caption-top table">
<caption>ONNX TensorProto data type enum values</caption>
<thead>
<tr class="header">
<th>Enum Value</th>
<th>Name</th>
<th>Python/NumPy Equivalent</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>1</td>
<td><code>FLOAT</code></td>
<td><code>np.float32</code></td>
</tr>
<tr class="even">
<td>2</td>
<td><code>UINT8</code></td>
<td><code>np.uint8</code></td>
</tr>
<tr class="odd">
<td>3</td>
<td><code>INT8</code></td>
<td><code>np.int8</code></td>
</tr>
<tr class="even">
<td>4</td>
<td><code>UINT16</code></td>
<td><code>np.uint16</code></td>
</tr>
<tr class="odd">
<td>5</td>
<td><code>INT16</code></td>
<td><code>np.int16</code></td>
</tr>
<tr class="even">
<td>6</td>
<td><code>INT32</code></td>
<td><code>np.int32</code></td>
</tr>
<tr class="odd">
<td>7</td>
<td><code>INT64</code></td>
<td><code>np.int64</code></td>
</tr>
<tr class="even">
<td>8</td>
<td><code>STRING</code></td>
<td><code>bytes</code></td>
</tr>
<tr class="odd">
<td>9</td>
<td><code>BOOL</code></td>
<td><code>np.bool_</code></td>
</tr>
<tr class="even">
<td>10</td>
<td><code>FLOAT16</code></td>
<td><code>np.float16</code></td>
</tr>
<tr class="odd">
<td>11</td>
<td><code>DOUBLE</code></td>
<td><code>np.float64</code></td>
</tr>
<tr class="even">
<td>12</td>
<td><code>UINT32</code></td>
<td><code>np.uint32</code></td>
</tr>
<tr class="odd">
<td>13</td>
<td><code>UINT64</code></td>
<td><code>np.uint64</code></td>
</tr>
<tr class="even">
<td>14</td>
<td><code>COMPLEX64</code></td>
<td><code>np.complex64</code></td>
</tr>
<tr class="odd">
<td>15</td>
<td><code>COMPLEX128</code></td>
<td><code>np.complex128</code></td>
</tr>
<tr class="even">
<td>16</td>
<td><code>BFLOAT16</code></td>
<td>N/A (no native NumPy type)</td>
</tr>
</tbody>
</table>
<p>You reference these via <code>onnx.TensorProto.FLOAT</code>, <code>onnx.TensorProto.INT64</code>, etc.</p>
</section>
<section id="valueinfoproto-and-shape" class="level3">
<h3 class="anchored" data-anchor-id="valueinfoproto-and-shape" id="valueinfoproto-and-shape">ValueInfoProto and Shape</h3>
<p>Inputs and outputs are described by <code>ValueInfoProto</code>, which pairs a name with a type. The type is a <code>TypeProto</code>, which for tensors includes the element type and a shape. Shapes can have:</p>
<ul>
<li><strong>Fixed dimensions</strong>: <code>dim_value = 4</code> means exactly 4 elements on that axis.</li>
<li><strong>Symbolic dimensions</strong>: <code>dim_param = "batch_size"</code> means the dimension is variable but named. ONNX Runtime will accept any runtime value for it.</li>
<li><strong>Fully unknown dimensions</strong>: Neither <code>dim_value</code> nor <code>dim_param</code> is set — completely dynamic.</li>
</ul>
</section>
</section>
<section id="setting-up-your-environment" class="level2">
<h2 class="anchored" data-anchor-id="setting-up-your-environment" id="setting-up-your-environment">Setting Up Your Environment</h2>
<p>You need very few packages:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install onnx onnxruntime numpy</span></code></pre></div></div>
<p>For visualization (highly recommended):</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb2"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install netron</span></code></pre></div></div>
<p>Netron is a browser-based ONNX graph visualizer. You open a <code>.onnx</code> file in it and see the full computation graph rendered as a node diagram, with attributes, shapes, and connections all visible.</p>
<p>Verify your installation:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> onnx</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> onnxruntime <span class="im">as</span> ort</span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"ONNX version: </span><span class="sc">{</span>onnx<span class="sc">.</span>__version__<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"ONNX IR version: </span><span class="sc">{</span>onnx<span class="sc">.</span>IR_VERSION<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"ONNX Runtime version: </span><span class="sc">{</span>ort<span class="sc">.</span>__version__<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
</section>
<section id="core-building-blocks-onnx-operators" class="level2">
<h2 class="anchored" data-anchor-id="core-building-blocks-onnx-operators" id="core-building-blocks-onnx-operators">Core Building Blocks: ONNX Operators</h2>
<p>ONNX defines a large standard operator set. Here are the operators you will use most often when building architectures from scratch.</p>
<section id="linear-algebra" class="level3">
<h3 class="anchored" data-anchor-id="linear-algebra" id="linear-algebra">Linear Algebra</h3>
<p><strong><code>Gemm</code></strong> (General Matrix Multiply): Computes <code>alpha * A @ B + beta * C</code>. This is the workhorse for fully connected layers. Attributes include <code>transA</code>, <code>transB</code>, <code>alpha</code>, <code>beta</code>. The <code>C</code> input (bias) is optional.</p>
<p><strong><code>MatMul</code></strong>: Computes a simple matrix product <code>A @ B</code>, with numpy-style broadcasting for batched inputs. Has no attributes. Use this when you need raw matmul without the alpha/beta scaling of Gemm.</p>
<p><strong><code>Add</code></strong>, <strong><code>Sub</code></strong>, <strong><code>Mul</code></strong>, <strong><code>Div</code></strong>: Element-wise arithmetic with broadcasting.</p>
<p><strong><code>Transpose</code></strong>: Permutes axes. The <code>perm</code> attribute lists the new axis order; e.g., <code>perm=[0, 2, 1]</code> swaps the last two axes of a 3D tensor while keeping the batch axis first.</p>
</section>
<section id="activations" class="level3">
<h3 class="anchored" data-anchor-id="activations" id="activations">Activations</h3>
<p><strong><code>Relu</code></strong>: Element-wise <code>max(0, x)</code>. No attributes.</p>
<p><strong><code>Sigmoid</code></strong>: Element-wise <code>1 / (1 + exp(-x))</code>. No attributes.</p>
<p><strong><code>Tanh</code></strong>: Element-wise hyperbolic tangent. No attributes.</p>
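<p><strong><code>Gelu</code></strong>: Gaussian Error Linear Unit. Added to the standard domain in opset 20; with earlier opsets, compose it from primitives such as <code>Erf</code>, <code>Mul</code>, and <code>Add</code>.</p>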
<p><strong><code>Softmax</code></strong>: Softmax along a specified <code>axis</code>. Default axis is <code>-1</code>.</p>
<p><strong><code>LeakyRelu</code></strong>: <code>max(alpha * x, x)</code> with a configurable <code>alpha</code> attribute (default 0.01).</p>
<p><strong><code>Elu</code></strong>: Exponential Linear Unit. Attribute: <code>alpha</code>.</p>
</section>
<section id="normalization" class="level3">
<h3 class="anchored" data-anchor-id="normalization" id="normalization">Normalization</h3>
<p><strong><code>BatchNormalization</code></strong>: Normalizes each channel using running <code>mean</code> and <code>var</code> statistics, then scales and shifts it with learnable <code>scale</code> and <code>B</code> (bias) parameters. Has <code>epsilon</code> and <code>momentum</code> attributes. In inference mode (the default in ONNX), it uses the stored running statistics and has only one output; in training mode it normalizes with statistics computed across the batch instead.</p>
<p><strong><code>LayerNormalization</code></strong>: Normalizes over all axes from <code>axis</code> (default <code>-1</code>, i.e., the last axis) to the end. Introduced in opset 17. Essential for Transformer architectures.</p>
<p><strong><code>InstanceNormalization</code></strong>: Normalizes per-channel per-sample. Useful for style transfer networks.</p>
</section>
<section id="convolutions" class="level3">
<h3 class="anchored" data-anchor-id="convolutions" id="convolutions">Convolutions</h3>
<p><strong><code>Conv</code></strong>: N-dimensional convolution. Key attributes: <code>kernel_shape</code>, <code>strides</code>, <code>pads</code>, <code>dilations</code>, <code>group</code> (for grouped/depthwise convolutions), <code>auto_pad</code>. Inputs: <code>X</code> (data), <code>W</code> (weights), <code>B</code> (bias, optional).</p>
<p><strong><code>ConvTranspose</code></strong>: Transposed (fractionally-strided) convolution for upsampling. Same attribute set as <code>Conv</code> plus <code>output_padding</code>.</p>
<p><strong><code>MaxPool</code></strong>, <strong><code>AveragePool</code></strong>: Pooling with <code>kernel_shape</code>, <code>strides</code>, <code>pads</code>.</p>
<p><strong><code>GlobalAveragePool</code></strong>, <strong><code>GlobalMaxPool</code></strong>: Reduce each spatial map to a single value.</p>
</section>
<section id="recurrence" class="level3">
<h3 class="anchored" data-anchor-id="recurrence" id="recurrence">Recurrence</h3>
<p><strong><code>LSTM</code></strong>: A complete Long Short-Term Memory layer that processes an entire sequence in one node. Inputs: <code>X</code>, <code>W</code>, <code>R</code>, <code>B</code>, <code>sequence_lens</code>, <code>initial_h</code>, <code>initial_c</code>, <code>P</code>. Attributes: <code>hidden_size</code>, <code>direction</code> (<code>forward</code>, <code>reverse</code>, <code>bidirectional</code>).</p>
<p><strong><code>GRU</code></strong>: Gated Recurrent Unit. Similar interface to LSTM.</p>
<p><strong><code>RNN</code></strong>: Simple Elman RNN.</p>
</section>
<section id="shape-manipulation" class="level3">
<h3 class="anchored" data-anchor-id="shape-manipulation" id="shape-manipulation">Shape Manipulation</h3>
<p><strong><code>Reshape</code></strong>: Changes shape without copying data. Takes a <code>shape</code> tensor as the second input (not an attribute). Use <code>-1</code> for one inferred dimension.</p>
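<p><strong><code>Flatten</code></strong>: Reshapes the input to 2D: dimensions before <code>axis</code> (default 1) collapse into the first output dimension, and dimensions from <code>axis</code> onward into the second.</p>
<p><strong><code>Squeeze</code></strong>: Removes dimensions of size 1 at the given axes (all size-1 dimensions if no axes are given). Since opset 13, <code>axes</code> is an optional input rather than an attribute.</p>
<p><strong><code>Unsqueeze</code></strong>: Inserts dimensions of size 1 at the given axes. Since opset 13, <code>axes</code> is a required input rather than an attribute.</p>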
<p><strong><code>Concat</code></strong>: Concatenates tensors along a specified <code>axis</code>.</p>
<p><strong><code>Split</code></strong>: Splits a tensor into multiple outputs along an <code>axis</code>.</p>
<p><strong><code>Slice</code></strong>: Extracts a sub-tensor using the <code>starts</code>, <code>ends</code>, <code>axes</code>, and <code>steps</code> inputs.</p>
<p><strong><code>Gather</code></strong>: Index-based lookup (embedding table access, index selection).</p>
<p><strong><code>GatherElements</code></strong>: Gathers elements along a specified axis using an index tensor.</p>
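<p><strong><code>Scatter</code></strong>, <strong><code>ScatterElements</code></strong>: The write-side counterparts of Gather: they return a copy of the data tensor with <code>updates</code> written at the given indices. <code>Scatter</code> is deprecated (since opset 11) in favor of <code>ScatterElements</code>.</p>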
<p><strong><code>Pad</code></strong>: Pads a tensor with a constant, edge, reflect, or wrap strategy.</p>
<p><strong><code>Tile</code></strong>: Repeats a tensor along each axis a specified number of times.</p>
<p><strong><code>Expand</code></strong>: Broadcasts a tensor to a target shape.</p>
</section>
<section id="reduction" class="level3">
<h3 class="anchored" data-anchor-id="reduction" id="reduction">Reduction</h3>
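<p><strong><code>ReduceMean</code></strong>, <strong><code>ReduceSum</code></strong>, <strong><code>ReduceMax</code></strong>, <strong><code>ReduceMin</code></strong>, <strong><code>ReduceProd</code></strong>: Reduce along specified axes, with optional <code>keepdims</code>. Note that <code>axes</code> moved from an attribute to an optional input in opset 13 for <code>ReduceSum</code> and in opset 18 for the others.</p>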
<p><strong><code>ArgMax</code></strong>, <strong><code>ArgMin</code></strong>: Return the index of the max/min value along an axis.</p>
</section>
<section id="logical-and-comparison" class="level3">
<h3 class="anchored" data-anchor-id="logical-and-comparison" id="logical-and-comparison">Logical and Comparison</h3>
<p><strong><code>Equal</code></strong>, <strong><code>Less</code></strong>, <strong><code>Greater</code></strong>, <strong><code>LessOrEqual</code></strong>, <strong><code>GreaterOrEqual</code></strong>: Element-wise comparisons returning bool tensors.</p>
<p><strong><code>And</code></strong>, <strong><code>Or</code></strong>, <strong><code>Not</code></strong>, <strong><code>Xor</code></strong>: Boolean logic.</p>
<p><strong><code>Where</code></strong>: Selects elements from two tensors based on a bool condition tensor.</p>
</section>
<section id="miscellaneous" class="level3">
<h3 class="anchored" data-anchor-id="miscellaneous" id="miscellaneous">Miscellaneous</h3>
<p><strong><code>Cast</code></strong>: Converts element dtype, e.g., from <code>INT64</code> to <code>FLOAT</code>.</p>
<p><strong><code>Constant</code></strong>: Embeds a constant tensor directly as a node. Useful for small fixed values (scalars, shape vectors, indices) that you want inline in the graph rather than stored as an initializer.</p>
<p><strong><code>Shape</code></strong>: Returns the shape of a tensor as a 1D INT64 tensor.</p>
<p><strong><code>Size</code></strong>: Returns the total number of elements as a scalar INT64.</p>
<p><strong><code>Dropout</code></strong>: Applies dropout. Unless the optional <code>training_mode</code> input is true, it is a pass-through (no masking).</p>
<p><strong><code>Einsum</code></strong>: General einsum notation. Available from opset 12.</p>
</section>
</section>
<section id="constructing-graphs-with-the-onnx-helper-api" class="level2">
<h2 class="anchored" data-anchor-id="constructing-graphs-with-the-onnx-helper-api" id="constructing-graphs-with-the-onnx-helper-api">Constructing Graphs with the ONNX Helper API</h2>
<p>The <code>onnx.helper</code> module is your primary interface. Here is an overview of its key functions.</p>
<section id="onnx.helper.make_node" class="level3">
<h3 class="anchored" data-anchor-id="onnx.helper.make_node" id="onnx.helper.make_node"><code>onnx.helper.make_node</code></h3>
<p>Creates a <code>NodeProto</code>.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> onnx</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> onnx <span class="im">import</span> helper, TensorProto</span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a>node <span class="op">=</span> helper.make_node(</span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a>    op_type<span class="op">=</span><span class="st">"Relu"</span>,          <span class="co"># operator name</span></span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"linear_out"</span>],   <span class="co"># names of input tensors</span></span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a>    outputs<span class="op">=</span>[<span class="st">"relu_out"</span>],    <span class="co"># names of output tensors</span></span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"relu_1"</span>,           <span class="co"># optional: name for the node itself</span></span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
<p>For operators with attributes:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a>node <span class="op">=</span> helper.make_node(</span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a>    op_type<span class="op">=</span><span class="st">"Conv"</span>,</span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"x"</span>, <span class="st">"W"</span>, <span class="st">"b"</span>],</span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a>    outputs<span class="op">=</span>[<span class="st">"conv_out"</span>],</span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a>    kernel_shape<span class="op">=</span>[<span class="dv">3</span>, <span class="dv">3</span>],</span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a>    strides<span class="op">=</span>[<span class="dv">1</span>, <span class="dv">1</span>],</span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a>    pads<span class="op">=</span>[<span class="dv">1</span>, <span class="dv">1</span>, <span class="dv">1</span>, <span class="dv">1</span>],</span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"conv_1"</span>,</span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
<p>Attributes are passed as keyword arguments, and <code>make_node</code> infers each attribute's type from the Python value you pass (int → INT, float → FLOAT, str → STRING, list of ints → INTS, etc.).</p>
</section>
<section id="onnx.helper.make_tensor_value_info" class="level3">
<h3 class="anchored" data-anchor-id="onnx.helper.make_tensor_value_info" id="onnx.helper.make_tensor_value_info"><code>onnx.helper.make_tensor_value_info</code></h3>
<p>Creates a <code>ValueInfoProto</code> for describing graph inputs and outputs.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Fixed batch size of 1, 784 features</span></span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a>x_info <span class="op">=</span> helper.make_tensor_value_info(<span class="st">"x"</span>, TensorProto.FLOAT, [<span class="dv">1</span>, <span class="dv">784</span>])</span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Dynamic batch size (symbolic), 10 classes</span></span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a>y_info <span class="op">=</span> helper.make_tensor_value_info(<span class="st">"output"</span>, TensorProto.FLOAT, [<span class="st">"batch"</span>, <span class="dv">10</span>])</span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Completely dynamic shape</span></span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a>z_info <span class="op">=</span> helper.make_tensor_value_info(<span class="st">"z"</span>, TensorProto.FLOAT, <span class="va">None</span>)</span></code></pre></div></div>
</section>
<section id="onnx.numpy_helper.from_array" class="level3">
<h3 class="anchored" data-anchor-id="onnx.numpy_helper.from_array" id="onnx.numpy_helper.from_array"><code>onnx.numpy_helper.from_array</code></h3>
<p>Converts a NumPy array to a <code>TensorProto</code> for use as an initializer.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> onnx <span class="im">import</span> numpy_helper</span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a>W <span class="op">=</span> np.random.randn(<span class="dv">128</span>, <span class="dv">784</span>).astype(np.float32)</span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a>W_tensor <span class="op">=</span> numpy_helper.from_array(W, name<span class="op">=</span><span class="st">"fc1_weight"</span>)</span></code></pre></div></div>
</section>
<section id="onnx.helper.make_graph" class="level3">
<h3 class="anchored" data-anchor-id="onnx.helper.make_graph" id="onnx.helper.make_graph"><code>onnx.helper.make_graph</code></h3>
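<p>Assembles nodes, initializers, inputs, and outputs into a <code>GraphProto</code>. In this snippet, <code>node1</code> through <code>node3</code> and <code>b_tensor</code> are placeholders for nodes and initializers built as in the examples above. List the nodes in topological order: the checker requires every node input to be a graph input, an initializer, or the output of an earlier node.</p>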
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb8"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a>graph <span class="op">=</span> helper.make_graph(</span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a>    nodes<span class="op">=</span>[node1, node2, node3],</span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"my_mlp"</span>,</span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[x_info],</span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a>    outputs<span class="op">=</span>[y_info],</span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a>    initializer<span class="op">=</span>[W_tensor, b_tensor],</span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
</section>
<section id="onnx.helper.make_model" class="level3">
<h3 class="anchored" data-anchor-id="onnx.helper.make_model" id="onnx.helper.make_model"><code>onnx.helper.make_model</code></h3>
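<p>Wraps a graph in a <code>ModelProto</code> and records which operator set versions it uses. Setting <code>ir_version</code> explicitly is worthwhile: recent <code>onnx</code> releases stamp the newest IR version by default, which older runtimes may refuse to load. IR version 8 is the one released alongside opset 17.</p>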
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb9"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> helper.make_model(</span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a>    graph,</span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a>    opset_imports<span class="op">=</span>[helper.make_opsetid(<span class="st">""</span>, <span class="dv">17</span>)],  <span class="co"># opset 17 of the default domain</span></span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a>model.ir_version <span class="op">=</span> <span class="dv">8</span></span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a>model.producer_name <span class="op">=</span> <span class="st">"my_builder"</span></span></code></pre></div></div>
</section>
<section id="onnx.checker.check_model" class="level3">
<h3 class="anchored" data-anchor-id="onnx.checker.check_model" id="onnx.checker.check_model"><code>onnx.checker.check_model</code></h3>
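<p>Validates the model’s structural correctness, for example that every node input is defined before it is used and that attributes match each operator's schema. Passing <code>full_check=True</code> additionally runs strict shape inference. Always run it before saving or running a model.</p>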
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb10"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a>onnx.checker.check_model(model)</span></code></pre></div></div>
</section>
<section id="onnx.save" class="level3">
<h3 class="anchored" data-anchor-id="onnx.save" id="onnx.save"><code>onnx.save</code></h3>
<p>Serializes to a <code>.onnx</code> file.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb11"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a>onnx.save(model, <span class="st">"my_model.onnx"</span>)</span></code></pre></div></div>
</section>
</section>
<section id="building-a-linear-regression-model" class="level2">
<h2 class="anchored" data-anchor-id="building-a-linear-regression-model" id="building-a-linear-regression-model">Building a Linear Regression Model</h2>
<p>Let us start with the simplest possible “network”: a linear regression that computes <code>y = X @ W + b</code>.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb12"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><a href="#cb12-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb12-2"><a href="#cb12-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> onnx</span>
<span id="cb12-3"><a href="#cb12-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> onnx <span class="im">import</span> helper, TensorProto, numpy_helper</span>
<span id="cb12-4"><a href="#cb12-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-5"><a href="#cb12-5" aria-hidden="true" tabindex="-1"></a><span class="co"># ------------------------------------------------------------------ #</span></span>
<span id="cb12-6"><a href="#cb12-6" aria-hidden="true" tabindex="-1"></a><span class="co"># 1. Define weights and bias as numpy arrays                          #</span></span>
<span id="cb12-7"><a href="#cb12-7" aria-hidden="true" tabindex="-1"></a><span class="co"># ------------------------------------------------------------------ #</span></span>
<span id="cb12-8"><a href="#cb12-8" aria-hidden="true" tabindex="-1"></a>in_features  <span class="op">=</span> <span class="dv">8</span></span>
<span id="cb12-9"><a href="#cb12-9" aria-hidden="true" tabindex="-1"></a>out_features <span class="op">=</span> <span class="dv">1</span></span>
<span id="cb12-10"><a href="#cb12-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-11"><a href="#cb12-11" aria-hidden="true" tabindex="-1"></a>W_data <span class="op">=</span> np.random.randn(in_features, out_features).astype(np.float32)</span>
<span id="cb12-12"><a href="#cb12-12" aria-hidden="true" tabindex="-1"></a>b_data <span class="op">=</span> np.zeros(out_features, dtype<span class="op">=</span>np.float32)</span>
<span id="cb12-13"><a href="#cb12-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-14"><a href="#cb12-14" aria-hidden="true" tabindex="-1"></a><span class="co"># ------------------------------------------------------------------ #</span></span>
<span id="cb12-15"><a href="#cb12-15" aria-hidden="true" tabindex="-1"></a><span class="co"># 2. Convert to TensorProto initializers                              #</span></span>
<span id="cb12-16"><a href="#cb12-16" aria-hidden="true" tabindex="-1"></a><span class="co"># ------------------------------------------------------------------ #</span></span>
<span id="cb12-17"><a href="#cb12-17" aria-hidden="true" tabindex="-1"></a>W_init <span class="op">=</span> numpy_helper.from_array(W_data, name<span class="op">=</span><span class="st">"W"</span>)</span>
<span id="cb12-18"><a href="#cb12-18" aria-hidden="true" tabindex="-1"></a>b_init <span class="op">=</span> numpy_helper.from_array(b_data, name<span class="op">=</span><span class="st">"b"</span>)</span>
<span id="cb12-19"><a href="#cb12-19" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-20"><a href="#cb12-20" aria-hidden="true" tabindex="-1"></a><span class="co"># ------------------------------------------------------------------ #</span></span>
<span id="cb12-21"><a href="#cb12-21" aria-hidden="true" tabindex="-1"></a><span class="co"># 3. Define the graph's external input and output shapes              #</span></span>
<span id="cb12-22"><a href="#cb12-22" aria-hidden="true" tabindex="-1"></a><span class="co"># ------------------------------------------------------------------ #</span></span>
<span id="cb12-23"><a href="#cb12-23" aria-hidden="true" tabindex="-1"></a><span class="co"># Input: batch of samples, each with 8 features</span></span>
<span id="cb12-24"><a href="#cb12-24" aria-hidden="true" tabindex="-1"></a>x_info <span class="op">=</span> helper.make_tensor_value_info(<span class="st">"x"</span>, TensorProto.FLOAT, [<span class="st">"batch"</span>, in_features])</span>
<span id="cb12-25"><a href="#cb12-25" aria-hidden="true" tabindex="-1"></a><span class="co"># Output: batch of scalars</span></span>
<span id="cb12-26"><a href="#cb12-26" aria-hidden="true" tabindex="-1"></a>y_info <span class="op">=</span> helper.make_tensor_value_info(<span class="st">"y"</span>, TensorProto.FLOAT, [<span class="st">"batch"</span>, out_features])</span>
<span id="cb12-27"><a href="#cb12-27" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-28"><a href="#cb12-28" aria-hidden="true" tabindex="-1"></a><span class="co"># ------------------------------------------------------------------ #</span></span>
<span id="cb12-29"><a href="#cb12-29" aria-hidden="true" tabindex="-1"></a><span class="co"># 4. Define the computation node                                      #</span></span>
<span id="cb12-30"><a href="#cb12-30" aria-hidden="true" tabindex="-1"></a><span class="co"># ------------------------------------------------------------------ #</span></span>
<span id="cb12-31"><a href="#cb12-31" aria-hidden="true" tabindex="-1"></a><span class="co"># Gemm computes: alpha * A @ B + beta * C</span></span>
<span id="cb12-32"><a href="#cb12-32" aria-hidden="true" tabindex="-1"></a><span class="co"># We want: x @ W + b, which is: 1.0 * x @ W + 1.0 * b</span></span>
<span id="cb12-33"><a href="#cb12-33" aria-hidden="true" tabindex="-1"></a>gemm_node <span class="op">=</span> helper.make_node(</span>
<span id="cb12-34"><a href="#cb12-34" aria-hidden="true" tabindex="-1"></a>    op_type<span class="op">=</span><span class="st">"Gemm"</span>,</span>
<span id="cb12-35"><a href="#cb12-35" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"x"</span>, <span class="st">"W"</span>, <span class="st">"b"</span>],</span>
<span id="cb12-36"><a href="#cb12-36" aria-hidden="true" tabindex="-1"></a>    outputs<span class="op">=</span>[<span class="st">"y"</span>],</span>
<span id="cb12-37"><a href="#cb12-37" aria-hidden="true" tabindex="-1"></a>    alpha<span class="op">=</span><span class="fl">1.0</span>,</span>
<span id="cb12-38"><a href="#cb12-38" aria-hidden="true" tabindex="-1"></a>    beta<span class="op">=</span><span class="fl">1.0</span>,</span>
<span id="cb12-39"><a href="#cb12-39" aria-hidden="true" tabindex="-1"></a>    transB<span class="op">=</span><span class="dv">0</span>,  <span class="co"># W is already (in_features, out_features), no transpose needed</span></span>
<span id="cb12-40"><a href="#cb12-40" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"linear"</span>,</span>
<span id="cb12-41"><a href="#cb12-41" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb12-42"><a href="#cb12-42" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-43"><a href="#cb12-43" aria-hidden="true" tabindex="-1"></a><span class="co"># ------------------------------------------------------------------ #</span></span>
<span id="cb12-44"><a href="#cb12-44" aria-hidden="true" tabindex="-1"></a><span class="co"># 5. Build the graph                                                  #</span></span>
<span id="cb12-45"><a href="#cb12-45" aria-hidden="true" tabindex="-1"></a><span class="co"># ------------------------------------------------------------------ #</span></span>
<span id="cb12-46"><a href="#cb12-46" aria-hidden="true" tabindex="-1"></a>graph <span class="op">=</span> helper.make_graph(</span>
<span id="cb12-47"><a href="#cb12-47" aria-hidden="true" tabindex="-1"></a>    nodes<span class="op">=</span>[gemm_node],</span>
<span id="cb12-48"><a href="#cb12-48" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"linear_regression"</span>,</span>
<span id="cb12-49"><a href="#cb12-49" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[x_info],</span>
<span id="cb12-50"><a href="#cb12-50" aria-hidden="true" tabindex="-1"></a>    outputs<span class="op">=</span>[y_info],</span>
<span id="cb12-51"><a href="#cb12-51" aria-hidden="true" tabindex="-1"></a>    initializer<span class="op">=</span>[W_init, b_init],</span>
<span id="cb12-52"><a href="#cb12-52" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb12-53"><a href="#cb12-53" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-54"><a href="#cb12-54" aria-hidden="true" tabindex="-1"></a><span class="co"># ------------------------------------------------------------------ #</span></span>
<span id="cb12-55"><a href="#cb12-55" aria-hidden="true" tabindex="-1"></a><span class="co"># 6. Build the model                                                  #</span></span>
<span id="cb12-56"><a href="#cb12-56" aria-hidden="true" tabindex="-1"></a><span class="co"># ------------------------------------------------------------------ #</span></span>
<span id="cb12-57"><a href="#cb12-57" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> helper.make_model(</span>
<span id="cb12-58"><a href="#cb12-58" aria-hidden="true" tabindex="-1"></a>    graph,</span>
<span id="cb12-59"><a href="#cb12-59" aria-hidden="true" tabindex="-1"></a>    opset_imports<span class="op">=</span>[helper.make_opsetid(<span class="st">""</span>, <span class="dv">17</span>)],</span>
<span id="cb12-60"><a href="#cb12-60" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb12-61"><a href="#cb12-61" aria-hidden="true" tabindex="-1"></a>model.ir_version <span class="op">=</span> <span class="dv">8</span></span>
<span id="cb12-62"><a href="#cb12-62" aria-hidden="true" tabindex="-1"></a>model.producer_name <span class="op">=</span> <span class="st">"onnx_guide"</span></span>
<span id="cb12-63"><a href="#cb12-63" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-64"><a href="#cb12-64" aria-hidden="true" tabindex="-1"></a><span class="co"># ------------------------------------------------------------------ #</span></span>
<span id="cb12-65"><a href="#cb12-65" aria-hidden="true" tabindex="-1"></a><span class="co"># 7. Validate and save                                                #</span></span>
<span id="cb12-66"><a href="#cb12-66" aria-hidden="true" tabindex="-1"></a><span class="co"># ------------------------------------------------------------------ #</span></span>
<span id="cb12-67"><a href="#cb12-67" aria-hidden="true" tabindex="-1"></a>onnx.checker.check_model(model)</span>
<span id="cb12-68"><a href="#cb12-68" aria-hidden="true" tabindex="-1"></a>onnx.save(model, <span class="st">"linear_regression.onnx"</span>)</span>
<span id="cb12-69"><a href="#cb12-69" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"Model saved."</span>)</span></code></pre></div></div>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Key details
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li>The initializers <code>W</code> and <code>b</code> are listed in <code>initializer</code> and are available to nodes as named tensors in the graph. You do not list them as graph <code>inputs</code> because they are constants that do not vary across inference calls. (Since IR version 4, an initializer that also appears in <code>inputs</code> is treated as an overridable default instead.)</li>
<li><code>Gemm</code>’s <code>transB</code> attribute controls whether the second matrix is transposed before the multiplication. With <code>transB=0</code> and <code>W</code> shaped <code>[in_features, out_features]</code>, the node computes <code>x @ W + b</code>, giving output shape <code>[batch, out_features]</code>; the bias <code>b</code> of shape <code>[out_features]</code> is broadcast across the batch.</li>
<li>Symbolic dimensions like <code>"batch"</code> in shape specifications are stored as named <code>dim_param</code> entries rather than fixed sizes, so ONNX Runtime accepts any size on that axis at runtime.</li>
</ul>
</div>
</div>
</section>
<section id="building-a-multi-layer-perceptron-mlp" class="level2">
<h2 class="anchored" data-anchor-id="building-a-multi-layer-perceptron-mlp" id="building-a-multi-layer-perceptron-mlp">Building a Multi-Layer Perceptron (MLP)</h2>
<p>A multi-layer perceptron stacks fully connected layers with nonlinear activations between them. Here we build a 3-layer MLP for classification: input → hidden1 → hidden2 → logits → softmax.</p>
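<p>It can help to write down exactly what the graph should compute before building it. The NumPy sketch below is a reference forward pass for the same 784 → 256 → 128 → 10 architecture, with each weight stored as <code>[in, out]</code> so that every layer is <code>x @ W + b</code>, matching <code>Gemm</code> with its default <code>transB=0</code>. The function names and the He-style random initialization are assumptions for illustration; the output can later serve as the expected value when testing the ONNX model.</p>

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def softmax(z, axis=-1):
    # Subtract the row max for numerical stability; the result is unchanged
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def mlp_forward(x, params):
    """params: list of (W, b) pairs with W shaped [in_dim, out_dim]."""
    h = x
    for i, (W, b) in enumerate(params):
        h = h @ W + b
        if i < len(params) - 1:  # ReLU after every layer except the last (logits)
            h = relu(h)
    return softmax(h)

rng = np.random.default_rng(0)
dims = [784, 256, 128, 10]
params = [
    ((rng.standard_normal((d_in, d_out)) * np.sqrt(2.0 / d_in)).astype(np.float32),
     np.zeros(d_out, dtype=np.float32))
    for d_in, d_out in zip(dims[:-1], dims[1:])
]

x = rng.standard_normal((4, 784)).astype(np.float32)
probs = mlp_forward(x, params)
print(probs.shape, probs.sum(axis=-1))  # each row of probabilities sums to 1
```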
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb13"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> onnx</span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> onnx <span class="im">import</span> helper, TensorProto, numpy_helper</span>
<span id="cb13-4"><a href="#cb13-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-5"><a href="#cb13-5" aria-hidden="true" tabindex="-1"></a><span class="co"># ------------------------------------------------------------------ #</span></span>
<span id="cb13-6"><a href="#cb13-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Architecture hyperparameters                                        #</span></span>
<span id="cb13-7"><a href="#cb13-7" aria-hidden="true" tabindex="-1"></a><span class="co"># ------------------------------------------------------------------ #</span></span>
<span id="cb13-8"><a href="#cb13-8" aria-hidden="true" tabindex="-1"></a>in_dim  <span class="op">=</span> <span class="dv">784</span>   <span class="co"># e.g., MNIST flattened</span></span>
<span id="cb13-9"><a href="#cb13-9" aria-hidden="true" tabindex="-1"></a>h1_dim  <span class="op">=</span> <span class="dv">256</span></span>
<span id="cb13-10"><a href="#cb13-10" aria-hidden="true" tabindex="-1"></a>h2_dim  <span class="op">=</span> <span class="dv">128</span></span>
<span id="cb13-11"><a href="#cb13-11" aria-hidden="true" tabindex="-1"></a>out_dim <span class="op">=</span> <span class="dv">10</span>    <span class="co"># classes</span></span>
<span id="cb13-12"><a href="#cb13-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-13"><a href="#cb13-13" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> make_fc_weights(in_d, out_d, name_prefix):</span>
<span id="cb13-14"><a href="#cb13-14" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Create weight and bias initializers for a fully connected layer."""</span></span>
<span id="cb13-15"><a href="#cb13-15" aria-hidden="true" tabindex="-1"></a>    W <span class="op">=</span> (np.random.randn(in_d, out_d) <span class="op">*</span> np.sqrt(<span class="fl">2.0</span> <span class="op">/</span> in_d)).astype(np.float32)  <span class="co"># He init; cast last so W stays float32</span></span>
<span id="cb13-16"><a href="#cb13-16" aria-hidden="true" tabindex="-1"></a>    b <span class="op">=</span> np.zeros(out_d, dtype<span class="op">=</span>np.float32)</span>
<span id="cb13-17"><a href="#cb13-17" aria-hidden="true" tabindex="-1"></a>    W_init <span class="op">=</span> numpy_helper.from_array(W, name<span class="op">=</span><span class="ss">f"</span><span class="sc">{</span>name_prefix<span class="sc">}</span><span class="ss">_W"</span>)</span>
<span id="cb13-18"><a href="#cb13-18" aria-hidden="true" tabindex="-1"></a>    b_init <span class="op">=</span> numpy_helper.from_array(b, name<span class="op">=</span><span class="ss">f"</span><span class="sc">{</span>name_prefix<span class="sc">}</span><span class="ss">_b"</span>)</span>
<span id="cb13-19"><a href="#cb13-19" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> W_init, b_init</span>
<span id="cb13-20"><a href="#cb13-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-21"><a href="#cb13-21" aria-hidden="true" tabindex="-1"></a><span class="co"># ------------------------------------------------------------------ #</span></span>
<span id="cb13-22"><a href="#cb13-22" aria-hidden="true" tabindex="-1"></a><span class="co"># Initializers                                                        #</span></span>
<span id="cb13-23"><a href="#cb13-23" aria-hidden="true" tabindex="-1"></a><span class="co"># ------------------------------------------------------------------ #</span></span>
<span id="cb13-24"><a href="#cb13-24" aria-hidden="true" tabindex="-1"></a>fc1_W, fc1_b <span class="op">=</span> make_fc_weights(in_dim, h1_dim, <span class="st">"fc1"</span>)</span>
<span id="cb13-25"><a href="#cb13-25" aria-hidden="true" tabindex="-1"></a>fc2_W, fc2_b <span class="op">=</span> make_fc_weights(h1_dim, h2_dim, <span class="st">"fc2"</span>)</span>
<span id="cb13-26"><a href="#cb13-26" aria-hidden="true" tabindex="-1"></a>fc3_W, fc3_b <span class="op">=</span> make_fc_weights(h2_dim, out_dim, <span class="st">"fc3"</span>)</span>
<span id="cb13-27"><a href="#cb13-27" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-28"><a href="#cb13-28" aria-hidden="true" tabindex="-1"></a>all_initializers <span class="op">=</span> [fc1_W, fc1_b, fc2_W, fc2_b, fc3_W, fc3_b]</span>
<span id="cb13-29"><a href="#cb13-29" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-30"><a href="#cb13-30" aria-hidden="true" tabindex="-1"></a><span class="co"># ------------------------------------------------------------------ #</span></span>
<span id="cb13-31"><a href="#cb13-31" aria-hidden="true" tabindex="-1"></a><span class="co"># Nodes                                                               #</span></span>
<span id="cb13-32"><a href="#cb13-32" aria-hidden="true" tabindex="-1"></a><span class="co"># ------------------------------------------------------------------ #</span></span>
<span id="cb13-33"><a href="#cb13-33" aria-hidden="true" tabindex="-1"></a>nodes <span class="op">=</span> []</span>
<span id="cb13-34"><a href="#cb13-34" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-35"><a href="#cb13-35" aria-hidden="true" tabindex="-1"></a><span class="co"># Layer 1: Linear → ReLU</span></span>
<span id="cb13-36"><a href="#cb13-36" aria-hidden="true" tabindex="-1"></a>nodes.append(helper.make_node(</span>
<span id="cb13-37"><a href="#cb13-37" aria-hidden="true" tabindex="-1"></a>    <span class="st">"Gemm"</span>, inputs<span class="op">=</span>[<span class="st">"x"</span>, <span class="st">"fc1_W"</span>, <span class="st">"fc1_b"</span>], outputs<span class="op">=</span>[<span class="st">"fc1_out"</span>],</span>
<span id="cb13-38"><a href="#cb13-38" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"fc1"</span>, alpha<span class="op">=</span><span class="fl">1.0</span>, beta<span class="op">=</span><span class="fl">1.0</span>,</span>
<span id="cb13-39"><a href="#cb13-39" aria-hidden="true" tabindex="-1"></a>))</span>
<span id="cb13-40"><a href="#cb13-40" aria-hidden="true" tabindex="-1"></a>nodes.append(helper.make_node(</span>
<span id="cb13-41"><a href="#cb13-41" aria-hidden="true" tabindex="-1"></a>    <span class="st">"Relu"</span>, inputs<span class="op">=</span>[<span class="st">"fc1_out"</span>], outputs<span class="op">=</span>[<span class="st">"relu1_out"</span>],</span>
<span id="cb13-42"><a href="#cb13-42" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"relu1"</span>,</span>
<span id="cb13-43"><a href="#cb13-43" aria-hidden="true" tabindex="-1"></a>))</span>
<span id="cb13-44"><a href="#cb13-44" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-45"><a href="#cb13-45" aria-hidden="true" tabindex="-1"></a><span class="co"># Layer 2: Linear → ReLU</span></span>
<span id="cb13-46"><a href="#cb13-46" aria-hidden="true" tabindex="-1"></a>nodes.append(helper.make_node(</span>
<span id="cb13-47"><a href="#cb13-47" aria-hidden="true" tabindex="-1"></a>    <span class="st">"Gemm"</span>, inputs<span class="op">=</span>[<span class="st">"relu1_out"</span>, <span class="st">"fc2_W"</span>, <span class="st">"fc2_b"</span>], outputs<span class="op">=</span>[<span class="st">"fc2_out"</span>],</span>
<span id="cb13-48"><a href="#cb13-48" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"fc2"</span>, alpha<span class="op">=</span><span class="fl">1.0</span>, beta<span class="op">=</span><span class="fl">1.0</span>,</span>
<span id="cb13-49"><a href="#cb13-49" aria-hidden="true" tabindex="-1"></a>))</span>
<span id="cb13-50"><a href="#cb13-50" aria-hidden="true" tabindex="-1"></a>nodes.append(helper.make_node(</span>
<span id="cb13-51"><a href="#cb13-51" aria-hidden="true" tabindex="-1"></a>    <span class="st">"Relu"</span>, inputs<span class="op">=</span>[<span class="st">"fc2_out"</span>], outputs<span class="op">=</span>[<span class="st">"relu2_out"</span>],</span>
<span id="cb13-52"><a href="#cb13-52" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"relu2"</span>,</span>
<span id="cb13-53"><a href="#cb13-53" aria-hidden="true" tabindex="-1"></a>))</span>
<span id="cb13-54"><a href="#cb13-54" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-55"><a href="#cb13-55" aria-hidden="true" tabindex="-1"></a><span class="co"># Layer 3: Linear (logits)</span></span>
<span id="cb13-56"><a href="#cb13-56" aria-hidden="true" tabindex="-1"></a>nodes.append(helper.make_node(</span>
<span id="cb13-57"><a href="#cb13-57" aria-hidden="true" tabindex="-1"></a>    <span class="st">"Gemm"</span>, inputs<span class="op">=</span>[<span class="st">"relu2_out"</span>, <span class="st">"fc3_W"</span>, <span class="st">"fc3_b"</span>], outputs<span class="op">=</span>[<span class="st">"logits"</span>],</span>
<span id="cb13-58"><a href="#cb13-58" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"fc3"</span>, alpha<span class="op">=</span><span class="fl">1.0</span>, beta<span class="op">=</span><span class="fl">1.0</span>,</span>
<span id="cb13-59"><a href="#cb13-59" aria-hidden="true" tabindex="-1"></a>))</span>
<span id="cb13-60"><a href="#cb13-60" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-61"><a href="#cb13-61" aria-hidden="true" tabindex="-1"></a><span class="co"># Softmax over class dimension (axis=-1 is the default)</span></span>
<span id="cb13-62"><a href="#cb13-62" aria-hidden="true" tabindex="-1"></a>nodes.append(helper.make_node(</span>
<span id="cb13-63"><a href="#cb13-63" aria-hidden="true" tabindex="-1"></a>    <span class="st">"Softmax"</span>, inputs<span class="op">=</span>[<span class="st">"logits"</span>], outputs<span class="op">=</span>[<span class="st">"probs"</span>],</span>
<span id="cb13-64"><a href="#cb13-64" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"softmax"</span>, axis<span class="op">=-</span><span class="dv">1</span>,</span>
<span id="cb13-65"><a href="#cb13-65" aria-hidden="true" tabindex="-1"></a>))</span>
<span id="cb13-66"><a href="#cb13-66" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-67"><a href="#cb13-67" aria-hidden="true" tabindex="-1"></a><span class="co"># ------------------------------------------------------------------ #</span></span>
<span id="cb13-68"><a href="#cb13-68" aria-hidden="true" tabindex="-1"></a><span class="co"># Graph inputs / outputs                                              #</span></span>
<span id="cb13-69"><a href="#cb13-69" aria-hidden="true" tabindex="-1"></a><span class="co"># ------------------------------------------------------------------ #</span></span>
<span id="cb13-70"><a href="#cb13-70" aria-hidden="true" tabindex="-1"></a>x_info    <span class="op">=</span> helper.make_tensor_value_info(<span class="st">"x"</span>,     TensorProto.FLOAT, [<span class="st">"batch"</span>, in_dim])</span>
<span id="cb13-71"><a href="#cb13-71" aria-hidden="true" tabindex="-1"></a>prob_info <span class="op">=</span> helper.make_tensor_value_info(<span class="st">"probs"</span>, TensorProto.FLOAT, [<span class="st">"batch"</span>, out_dim])</span>
<span id="cb13-72"><a href="#cb13-72" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-73"><a href="#cb13-73" aria-hidden="true" tabindex="-1"></a><span class="co"># ------------------------------------------------------------------ #</span></span>
<span id="cb13-74"><a href="#cb13-74" aria-hidden="true" tabindex="-1"></a><span class="co"># Assemble and save                                                   #</span></span>
<span id="cb13-75"><a href="#cb13-75" aria-hidden="true" tabindex="-1"></a><span class="co"># ------------------------------------------------------------------ #</span></span>
<span id="cb13-76"><a href="#cb13-76" aria-hidden="true" tabindex="-1"></a>graph <span class="op">=</span> helper.make_graph(</span>
<span id="cb13-77"><a href="#cb13-77" aria-hidden="true" tabindex="-1"></a>    nodes, <span class="st">"mlp"</span>,</span>
<span id="cb13-78"><a href="#cb13-78" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[x_info],</span>
<span id="cb13-79"><a href="#cb13-79" aria-hidden="true" tabindex="-1"></a>    outputs<span class="op">=</span>[prob_info],</span>
<span id="cb13-80"><a href="#cb13-80" aria-hidden="true" tabindex="-1"></a>    initializer<span class="op">=</span>all_initializers,</span>
<span id="cb13-81"><a href="#cb13-81" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb13-82"><a href="#cb13-82" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> helper.make_model(graph, opset_imports<span class="op">=</span>[helper.make_opsetid(<span class="st">""</span>, <span class="dv">17</span>)])</span>
<span id="cb13-83"><a href="#cb13-83" aria-hidden="true" tabindex="-1"></a>model.ir_version <span class="op">=</span> <span class="dv">8</span></span>
<span id="cb13-84"><a href="#cb13-84" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-85"><a href="#cb13-85" aria-hidden="true" tabindex="-1"></a>onnx.checker.check_model(model)</span>
<span id="cb13-86"><a href="#cb13-86" aria-hidden="true" tabindex="-1"></a>onnx.save(model, <span class="st">"mlp.onnx"</span>)</span>
<span id="cb13-87"><a href="#cb13-87" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"MLP saved."</span>)</span></code></pre></div></div>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Key observations
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
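<li>Weight names used as <code>make_node</code> inputs must exactly match the names given to <code>numpy_helper.from_array</code>. Even a one-character mismatch leaves a node input undefined, which <code>onnx.checker.check_model</code> reports as a validation error (and a runtime would reject when loading the model).</li>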
<li>He initialization (scaling standard-normal samples by <code>np.sqrt(2.0 / in_d)</code>) is baked into the weight values at construction time. ONNX has no notion of an initialization scheme; initializers are just constant tensors.</li>
<li><code>Gemm</code> expects <code>W</code> shaped <code>[in_dim, out_dim]</code> when <code>transB=0</code>. Frameworks that store weights as <code>[out_dim, in_dim]</code> (for example, PyTorch’s <code>nn.Linear</code>) can keep that layout and set <code>transB=1</code> instead; both are valid.</li>
</ul>
</div>
</div>
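<p>A small NumPy sketch can make the <code>transB</code> point concrete. The <code>gemm</code> helper below is a hypothetical illustration of the formula in the ONNX <code>Gemm</code> spec (<code>Y = alpha * A' @ B' + beta * C</code>, where <code>A'</code> and <code>B'</code> are optionally transposed), not an ONNX API. It checks that an <code>[in_dim, out_dim]</code> weight with <code>transB=0</code> and the same weight stored as <code>[out_dim, in_dim]</code> with <code>transB=1</code> give identical outputs.</p>

```python
import numpy as np

def gemm(A, B, C, alpha=1.0, beta=1.0, transA=0, transB=0):
    """Hypothetical NumPy sketch of ONNX Gemm: Y = alpha * A' @ B' + beta * C."""
    A_ = A.T if transA else A
    B_ = B.T if transB else B
    return alpha * (A_ @ B_) + beta * C

rng = np.random.default_rng(0)
batch, in_dim, out_dim = 4, 8, 3

x = rng.standard_normal((batch, in_dim)).astype(np.float32)
b = np.zeros(out_dim, dtype=np.float32)

# Layout used in this guide: W is [in_dim, out_dim] with He scaling, transB=0
W_in_out = (rng.standard_normal((in_dim, out_dim)) * np.sqrt(2.0 / in_dim)).astype(np.float32)

# nn.Linear-style layout: the same weights stored as [out_dim, in_dim], used with transB=1
W_out_in = np.ascontiguousarray(W_in_out.T)

y_a = gemm(x, W_in_out, b, transB=0)
y_b = gemm(x, W_out_in, b, transB=1)

print(y_a.shape)              # (4, 3)
print(np.allclose(y_a, y_b))  # True
```

<p>Keeping the <code>[out_dim, in_dim]</code> layout and flipping <code>transB</code> avoids an explicit transpose when copying weights from such frameworks.</p>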
</section>
<section id="building-a-convolutional-neural-network-cnn" class="level2">
<h2 class="anchored" data-anchor-id="building-a-convolutional-neural-network-cnn" id="building-a-convolutional-neural-network-cnn">Building a Convolutional Neural Network (CNN)</h2>
<p>CNNs require managing multi-dimensional weight tensors. In ONNX, the <code>Conv</code> operator expects:</p>
<ul>
<li>Input <code>X</code>: shape <code>[batch, in_channels, height, width]</code> (NCHW format).</li>
<li>Weight <code>W</code>: shape <code>[out_channels, in_channels/group, kernel_h, kernel_w]</code>.</li>
<li>Bias <code>B</code>: shape <code>[out_channels]</code> (optional).</li>
</ul>
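<p>To see how these shapes interact, here is a deliberately naive NumPy sketch of what <code>Conv</code> computes in the 2D case. It assumes dilation 1 and explicit <code>pads</code> in ONNX order, and the function name <code>conv2d_nchw</code> is hypothetical; it is meant for checking shapes on small inputs, not as a substitute for a real runtime.</p>

```python
import numpy as np

def conv2d_nchw(X, W, B=None, pads=(0, 0, 0, 0), strides=(1, 1), group=1):
    """Naive NumPy sketch of 2D Conv (no dilation).

    X: [N, C, H, Wd]; W: [M, C // group, kH, kW]; B: [M];
    pads: [top, left, bottom, right] (all begins, then all ends).
    """
    N, C, H, Wd = X.shape
    M, C_per_g, kH, kW = W.shape
    assert C == C_per_g * group, "weight dim 1 must equal in_channels / group"
    assert M % group == 0, "out_channels must be divisible by group"
    top, left, bottom, right = pads
    Xp = np.pad(X, ((0, 0), (0, 0), (top, bottom), (left, right)))  # zero padding
    sH, sW = strides
    H_out = (H + top + bottom - kH) // sH + 1
    W_out = (Wd + left + right - kW) // sW + 1
    Y = np.zeros((N, M, H_out, W_out), dtype=X.dtype)
    M_per_g = M // group
    for m in range(M):
        g = m // M_per_g                           # group that filter m belongs to
        xs = Xp[:, g * C_per_g:(g + 1) * C_per_g]  # input channels seen by filter m
        for i in range(H_out):
            for j in range(W_out):
                patch = xs[:, :, i * sH:i * sH + kH, j * sW:j * sW + kW]
                Y[:, m, i, j] = np.sum(patch * W[m], axis=(1, 2, 3))
        if B is not None:
            Y[:, m] += B[m]
    return Y

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 32, 32)).astype(np.float32)
w = rng.standard_normal((32, 3, 3, 3)).astype(np.float32)
b = np.zeros(32, dtype=np.float32)

y = conv2d_nchw(x, w, b, pads=(1, 1, 1, 1))
print(y.shape)  # (2, 32, 32, 32)
```

<p>With a 3×3 kernel, stride 1 and one pixel of padding on every side, the spatial size is preserved, which is why each convolution in the CNN below keeps its input resolution and only the max-pools halve it.</p>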
<p>Here we build a small CNN for image classification (CIFAR-style input: 3×32×32 → 10 classes).</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb14"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb14-1"><a href="#cb14-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb14-2"><a href="#cb14-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> onnx</span>
<span id="cb14-3"><a href="#cb14-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> onnx <span class="im">import</span> helper, TensorProto, numpy_helper</span>
<span id="cb14-4"><a href="#cb14-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-5"><a href="#cb14-5" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> conv_weight(out_ch, in_ch, kH, kW, name):</span>
<span id="cb14-6"><a href="#cb14-6" aria-hidden="true" tabindex="-1"></a>    fan_in <span class="op">=</span> in_ch <span class="op">*</span> kH <span class="op">*</span> kW</span>
<span id="cb14-7"><a href="#cb14-7" aria-hidden="true" tabindex="-1"></a>    W <span class="op">=</span> np.random.randn(out_ch, in_ch, kH, kW).astype(np.float32) <span class="op">*</span> np.sqrt(<span class="fl">2.0</span> <span class="op">/</span> fan_in)</span>
<span id="cb14-8"><a href="#cb14-8" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> numpy_helper.from_array(W, name<span class="op">=</span>name)</span>
<span id="cb14-9"><a href="#cb14-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-10"><a href="#cb14-10" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> conv_bias(out_ch, name):</span>
<span id="cb14-11"><a href="#cb14-11" aria-hidden="true" tabindex="-1"></a>    b <span class="op">=</span> np.zeros(out_ch, dtype<span class="op">=</span>np.float32)</span>
<span id="cb14-12"><a href="#cb14-12" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> numpy_helper.from_array(b, name<span class="op">=</span>name)</span>
<span id="cb14-13"><a href="#cb14-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-14"><a href="#cb14-14" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> bn_params(channels, name_prefix):</span>
<span id="cb14-15"><a href="#cb14-15" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""BatchNorm scale (gamma), bias (beta), running mean, running var."""</span></span>
<span id="cb14-16"><a href="#cb14-16" aria-hidden="true" tabindex="-1"></a>    scale <span class="op">=</span> numpy_helper.from_array(np.ones(channels,  dtype<span class="op">=</span>np.float32), name<span class="op">=</span><span class="ss">f"</span><span class="sc">{</span>name_prefix<span class="sc">}</span><span class="ss">_scale"</span>)</span>
<span id="cb14-17"><a href="#cb14-17" aria-hidden="true" tabindex="-1"></a>    bias  <span class="op">=</span> numpy_helper.from_array(np.zeros(channels, dtype<span class="op">=</span>np.float32), name<span class="op">=</span><span class="ss">f"</span><span class="sc">{</span>name_prefix<span class="sc">}</span><span class="ss">_bias"</span>)</span>
<span id="cb14-18"><a href="#cb14-18" aria-hidden="true" tabindex="-1"></a>    mean  <span class="op">=</span> numpy_helper.from_array(np.zeros(channels, dtype<span class="op">=</span>np.float32), name<span class="op">=</span><span class="ss">f"</span><span class="sc">{</span>name_prefix<span class="sc">}</span><span class="ss">_mean"</span>)</span>
<span id="cb14-19"><a href="#cb14-19" aria-hidden="true" tabindex="-1"></a>    var   <span class="op">=</span> numpy_helper.from_array(np.ones(channels,  dtype<span class="op">=</span>np.float32), name<span class="op">=</span><span class="ss">f"</span><span class="sc">{</span>name_prefix<span class="sc">}</span><span class="ss">_var"</span>)</span>
<span id="cb14-20"><a href="#cb14-20" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> scale, bias, mean, var</span>
<span id="cb14-21"><a href="#cb14-21" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-22"><a href="#cb14-22" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> fc_params(in_d, out_d, name_prefix):</span>
<span id="cb14-23"><a href="#cb14-23" aria-hidden="true" tabindex="-1"></a>    W <span class="op">=</span> np.random.randn(in_d, out_d).astype(np.float32) <span class="op">*</span> np.sqrt(<span class="fl">2.0</span> <span class="op">/</span> in_d)</span>
<span id="cb14-24"><a href="#cb14-24" aria-hidden="true" tabindex="-1"></a>    b <span class="op">=</span> np.zeros(out_d, dtype<span class="op">=</span>np.float32)</span>
<span id="cb14-25"><a href="#cb14-25" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> (numpy_helper.from_array(W, name<span class="op">=</span><span class="ss">f"</span><span class="sc">{</span>name_prefix<span class="sc">}</span><span class="ss">_W"</span>),</span>
<span id="cb14-26"><a href="#cb14-26" aria-hidden="true" tabindex="-1"></a>            numpy_helper.from_array(b, name<span class="op">=</span><span class="ss">f"</span><span class="sc">{</span>name_prefix<span class="sc">}</span><span class="ss">_b"</span>))</span>
<span id="cb14-27"><a href="#cb14-27" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-28"><a href="#cb14-28" aria-hidden="true" tabindex="-1"></a><span class="co"># ------------------------------------------------------------------ #</span></span>
<span id="cb14-29"><a href="#cb14-29" aria-hidden="true" tabindex="-1"></a><span class="co"># Initializers                                                        #</span></span>
<span id="cb14-30"><a href="#cb14-30" aria-hidden="true" tabindex="-1"></a><span class="co"># ------------------------------------------------------------------ #</span></span>
<span id="cb14-31"><a href="#cb14-31" aria-hidden="true" tabindex="-1"></a>inits <span class="op">=</span> []</span>
<span id="cb14-32"><a href="#cb14-32" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-33"><a href="#cb14-33" aria-hidden="true" tabindex="-1"></a><span class="co"># Conv block 1: 3 → 32 channels, 3×3 kernel</span></span>
<span id="cb14-34"><a href="#cb14-34" aria-hidden="true" tabindex="-1"></a>inits <span class="op">+=</span> [conv_weight(<span class="dv">32</span>, <span class="dv">3</span>,  <span class="dv">3</span>, <span class="dv">3</span>, <span class="st">"conv1_W"</span>), conv_bias(<span class="dv">32</span>, <span class="st">"conv1_b"</span>)]</span>
<span id="cb14-35"><a href="#cb14-35" aria-hidden="true" tabindex="-1"></a>inits <span class="op">+=</span> <span class="bu">list</span>(bn_params(<span class="dv">32</span>, <span class="st">"bn1"</span>))</span>
<span id="cb14-36"><a href="#cb14-36" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-37"><a href="#cb14-37" aria-hidden="true" tabindex="-1"></a><span class="co"># Conv block 2: 32 → 64 channels, 3×3 kernel</span></span>
<span id="cb14-38"><a href="#cb14-38" aria-hidden="true" tabindex="-1"></a>inits <span class="op">+=</span> [conv_weight(<span class="dv">64</span>, <span class="dv">32</span>, <span class="dv">3</span>, <span class="dv">3</span>, <span class="st">"conv2_W"</span>), conv_bias(<span class="dv">64</span>, <span class="st">"conv2_b"</span>)]</span>
<span id="cb14-39"><a href="#cb14-39" aria-hidden="true" tabindex="-1"></a>inits <span class="op">+=</span> <span class="bu">list</span>(bn_params(<span class="dv">64</span>, <span class="st">"bn2"</span>))</span>
<span id="cb14-40"><a href="#cb14-40" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-41"><a href="#cb14-41" aria-hidden="true" tabindex="-1"></a><span class="co"># Conv block 3: 64 → 128 channels, 3×3 kernel</span></span>
<span id="cb14-42"><a href="#cb14-42" aria-hidden="true" tabindex="-1"></a>inits <span class="op">+=</span> [conv_weight(<span class="dv">128</span>, <span class="dv">64</span>, <span class="dv">3</span>, <span class="dv">3</span>, <span class="st">"conv3_W"</span>), conv_bias(<span class="dv">128</span>, <span class="st">"conv3_b"</span>)]</span>
<span id="cb14-43"><a href="#cb14-43" aria-hidden="true" tabindex="-1"></a>inits <span class="op">+=</span> <span class="bu">list</span>(bn_params(<span class="dv">128</span>, <span class="st">"bn3"</span>))</span>
<span id="cb14-44"><a href="#cb14-44" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-45"><a href="#cb14-45" aria-hidden="true" tabindex="-1"></a><span class="co"># Fully connected layers</span></span>
<span id="cb14-46"><a href="#cb14-46" aria-hidden="true" tabindex="-1"></a><span class="co"># After 3 max-pools on 32×32 input: 32/2/2/2 = 4 spatial → 128 * 4 * 4 = 2048</span></span>
<span id="cb14-47"><a href="#cb14-47" aria-hidden="true" tabindex="-1"></a>inits <span class="op">+=</span> <span class="bu">list</span>(fc_params(<span class="dv">128</span> <span class="op">*</span> <span class="dv">4</span> <span class="op">*</span> <span class="dv">4</span>, <span class="dv">256</span>, <span class="st">"fc1"</span>))</span>
<span id="cb14-48"><a href="#cb14-48" aria-hidden="true" tabindex="-1"></a>inits <span class="op">+=</span> <span class="bu">list</span>(fc_params(<span class="dv">256</span>, <span class="dv">10</span>, <span class="st">"fc2"</span>))</span>
<span id="cb14-49"><a href="#cb14-49" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-50"><a href="#cb14-50" aria-hidden="true" tabindex="-1"></a><span class="co"># ------------------------------------------------------------------ #</span></span>
<span id="cb14-51"><a href="#cb14-51" aria-hidden="true" tabindex="-1"></a><span class="co"># Nodes                                                               #</span></span>
<span id="cb14-52"><a href="#cb14-52" aria-hidden="true" tabindex="-1"></a><span class="co"># ------------------------------------------------------------------ #</span></span>
<span id="cb14-53"><a href="#cb14-53" aria-hidden="true" tabindex="-1"></a>nodes <span class="op">=</span> []</span>
<span id="cb14-54"><a href="#cb14-54" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-55"><a href="#cb14-55" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> conv_bn_relu(x_name, conv_w, conv_b, bn_prefix, out_name, kH<span class="op">=</span><span class="dv">3</span>, kW<span class="op">=</span><span class="dv">3</span>, pad<span class="op">=</span><span class="dv">1</span>):</span>
<span id="cb14-56"><a href="#cb14-56" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Returns a list of nodes: Conv → BatchNorm → Relu."""</span></span>
<span id="cb14-57"><a href="#cb14-57" aria-hidden="true" tabindex="-1"></a>    conv_out <span class="op">=</span> <span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_conv"</span></span>
<span id="cb14-58"><a href="#cb14-58" aria-hidden="true" tabindex="-1"></a>    bn_out   <span class="op">=</span> <span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_bn"</span></span>
<span id="cb14-59"><a href="#cb14-59" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> [</span>
<span id="cb14-60"><a href="#cb14-60" aria-hidden="true" tabindex="-1"></a>        helper.make_node(<span class="st">"Conv"</span>,</span>
<span id="cb14-61"><a href="#cb14-61" aria-hidden="true" tabindex="-1"></a>            inputs<span class="op">=</span>[x_name, conv_w, conv_b],</span>
<span id="cb14-62"><a href="#cb14-62" aria-hidden="true" tabindex="-1"></a>            outputs<span class="op">=</span>[conv_out],</span>
<span id="cb14-63"><a href="#cb14-63" aria-hidden="true" tabindex="-1"></a>            kernel_shape<span class="op">=</span>[kH, kW],</span>
<span id="cb14-64"><a href="#cb14-64" aria-hidden="true" tabindex="-1"></a>            strides<span class="op">=</span>[<span class="dv">1</span>, <span class="dv">1</span>],</span>
<span id="cb14-65"><a href="#cb14-65" aria-hidden="true" tabindex="-1"></a>            pads<span class="op">=</span>[pad, pad, pad, pad],</span>
<span id="cb14-66"><a href="#cb14-66" aria-hidden="true" tabindex="-1"></a>            name<span class="op">=</span><span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_conv_op"</span>,</span>
<span id="cb14-67"><a href="#cb14-67" aria-hidden="true" tabindex="-1"></a>        ),</span>
<span id="cb14-68"><a href="#cb14-68" aria-hidden="true" tabindex="-1"></a>        helper.make_node(<span class="st">"BatchNormalization"</span>,</span>
<span id="cb14-69"><a href="#cb14-69" aria-hidden="true" tabindex="-1"></a>            inputs<span class="op">=</span>[conv_out,</span>
<span id="cb14-70"><a href="#cb14-70" aria-hidden="true" tabindex="-1"></a>                    <span class="ss">f"</span><span class="sc">{</span>bn_prefix<span class="sc">}</span><span class="ss">_scale"</span>, <span class="ss">f"</span><span class="sc">{</span>bn_prefix<span class="sc">}</span><span class="ss">_bias"</span>,</span>
<span id="cb14-71"><a href="#cb14-71" aria-hidden="true" tabindex="-1"></a>                    <span class="ss">f"</span><span class="sc">{</span>bn_prefix<span class="sc">}</span><span class="ss">_mean"</span>,  <span class="ss">f"</span><span class="sc">{</span>bn_prefix<span class="sc">}</span><span class="ss">_var"</span>],</span>
<span id="cb14-72"><a href="#cb14-72" aria-hidden="true" tabindex="-1"></a>            outputs<span class="op">=</span>[bn_out],</span>
<span id="cb14-73"><a href="#cb14-73" aria-hidden="true" tabindex="-1"></a>            epsilon<span class="op">=</span><span class="fl">1e-5</span>,</span>
<span id="cb14-74"><a href="#cb14-74" aria-hidden="true" tabindex="-1"></a>            momentum<span class="op">=</span><span class="fl">0.9</span>,</span>
<span id="cb14-75"><a href="#cb14-75" aria-hidden="true" tabindex="-1"></a>            name<span class="op">=</span><span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_bn_op"</span>,</span>
<span id="cb14-76"><a href="#cb14-76" aria-hidden="true" tabindex="-1"></a>        ),</span>
<span id="cb14-77"><a href="#cb14-77" aria-hidden="true" tabindex="-1"></a>        helper.make_node(<span class="st">"Relu"</span>,</span>
<span id="cb14-78"><a href="#cb14-78" aria-hidden="true" tabindex="-1"></a>            inputs<span class="op">=</span>[bn_out],</span>
<span id="cb14-79"><a href="#cb14-79" aria-hidden="true" tabindex="-1"></a>            outputs<span class="op">=</span>[out_name],</span>
<span id="cb14-80"><a href="#cb14-80" aria-hidden="true" tabindex="-1"></a>            name<span class="op">=</span><span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_relu_op"</span>,</span>
<span id="cb14-81"><a href="#cb14-81" aria-hidden="true" tabindex="-1"></a>        ),</span>
<span id="cb14-82"><a href="#cb14-82" aria-hidden="true" tabindex="-1"></a>    ]</span>
<span id="cb14-83"><a href="#cb14-83" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-84"><a href="#cb14-84" aria-hidden="true" tabindex="-1"></a><span class="co"># Block 1 + MaxPool</span></span>
<span id="cb14-85"><a href="#cb14-85" aria-hidden="true" tabindex="-1"></a>nodes <span class="op">+=</span> conv_bn_relu(<span class="st">"x"</span>, <span class="st">"conv1_W"</span>, <span class="st">"conv1_b"</span>, <span class="st">"bn1"</span>, <span class="st">"block1_out"</span>)</span>
<span id="cb14-86"><a href="#cb14-86" aria-hidden="true" tabindex="-1"></a>nodes.append(helper.make_node(<span class="st">"MaxPool"</span>,</span>
<span id="cb14-87"><a href="#cb14-87" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"block1_out"</span>], outputs<span class="op">=</span>[<span class="st">"pool1_out"</span>],</span>
<span id="cb14-88"><a href="#cb14-88" aria-hidden="true" tabindex="-1"></a>    kernel_shape<span class="op">=</span>[<span class="dv">2</span>, <span class="dv">2</span>], strides<span class="op">=</span>[<span class="dv">2</span>, <span class="dv">2</span>], name<span class="op">=</span><span class="st">"pool1"</span>,</span>
<span id="cb14-89"><a href="#cb14-89" aria-hidden="true" tabindex="-1"></a>))</span>
<span id="cb14-90"><a href="#cb14-90" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-91"><a href="#cb14-91" aria-hidden="true" tabindex="-1"></a><span class="co"># Block 2 + MaxPool</span></span>
<span id="cb14-92"><a href="#cb14-92" aria-hidden="true" tabindex="-1"></a>nodes <span class="op">+=</span> conv_bn_relu(<span class="st">"pool1_out"</span>, <span class="st">"conv2_W"</span>, <span class="st">"conv2_b"</span>, <span class="st">"bn2"</span>, <span class="st">"block2_out"</span>)</span>
<span id="cb14-93"><a href="#cb14-93" aria-hidden="true" tabindex="-1"></a>nodes.append(helper.make_node(<span class="st">"MaxPool"</span>,</span>
<span id="cb14-94"><a href="#cb14-94" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"block2_out"</span>], outputs<span class="op">=</span>[<span class="st">"pool2_out"</span>],</span>
<span id="cb14-95"><a href="#cb14-95" aria-hidden="true" tabindex="-1"></a>    kernel_shape<span class="op">=</span>[<span class="dv">2</span>, <span class="dv">2</span>], strides<span class="op">=</span>[<span class="dv">2</span>, <span class="dv">2</span>], name<span class="op">=</span><span class="st">"pool2"</span>,</span>
<span id="cb14-96"><a href="#cb14-96" aria-hidden="true" tabindex="-1"></a>))</span>
<span id="cb14-97"><a href="#cb14-97" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-98"><a href="#cb14-98" aria-hidden="true" tabindex="-1"></a><span class="co"># Block 3 + MaxPool</span></span>
<span id="cb14-99"><a href="#cb14-99" aria-hidden="true" tabindex="-1"></a>nodes <span class="op">+=</span> conv_bn_relu(<span class="st">"pool2_out"</span>, <span class="st">"conv3_W"</span>, <span class="st">"conv3_b"</span>, <span class="st">"bn3"</span>, <span class="st">"block3_out"</span>)</span>
<span id="cb14-100"><a href="#cb14-100" aria-hidden="true" tabindex="-1"></a>nodes.append(helper.make_node(<span class="st">"MaxPool"</span>,</span>
<span id="cb14-101"><a href="#cb14-101" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"block3_out"</span>], outputs<span class="op">=</span>[<span class="st">"pool3_out"</span>],</span>
<span id="cb14-102"><a href="#cb14-102" aria-hidden="true" tabindex="-1"></a>    kernel_shape<span class="op">=</span>[<span class="dv">2</span>, <span class="dv">2</span>], strides<span class="op">=</span>[<span class="dv">2</span>, <span class="dv">2</span>], name<span class="op">=</span><span class="st">"pool3"</span>,</span>
<span id="cb14-103"><a href="#cb14-103" aria-hidden="true" tabindex="-1"></a>))</span>
<span id="cb14-104"><a href="#cb14-104" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-105"><a href="#cb14-105" aria-hidden="true" tabindex="-1"></a><span class="co"># Flatten: [batch, 128, 4, 4] → [batch, 2048]</span></span>
<span id="cb14-106"><a href="#cb14-106" aria-hidden="true" tabindex="-1"></a>nodes.append(helper.make_node(<span class="st">"Flatten"</span>,</span>
<span id="cb14-107"><a href="#cb14-107" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"pool3_out"</span>], outputs<span class="op">=</span>[<span class="st">"flat_out"</span>],</span>
<span id="cb14-108"><a href="#cb14-108" aria-hidden="true" tabindex="-1"></a>    axis<span class="op">=</span><span class="dv">1</span>, name<span class="op">=</span><span class="st">"flatten"</span>,</span>
<span id="cb14-109"><a href="#cb14-109" aria-hidden="true" tabindex="-1"></a>))</span>
<span id="cb14-110"><a href="#cb14-110" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-111"><a href="#cb14-111" aria-hidden="true" tabindex="-1"></a><span class="co"># FC1 + ReLU</span></span>
<span id="cb14-112"><a href="#cb14-112" aria-hidden="true" tabindex="-1"></a>nodes.append(helper.make_node(<span class="st">"Gemm"</span>,</span>
<span id="cb14-113"><a href="#cb14-113" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"flat_out"</span>, <span class="st">"fc1_W"</span>, <span class="st">"fc1_b"</span>], outputs<span class="op">=</span>[<span class="st">"fc1_out"</span>],</span>
<span id="cb14-114"><a href="#cb14-114" aria-hidden="true" tabindex="-1"></a>    alpha<span class="op">=</span><span class="fl">1.0</span>, beta<span class="op">=</span><span class="fl">1.0</span>, name<span class="op">=</span><span class="st">"fc1"</span>,</span>
<span id="cb14-115"><a href="#cb14-115" aria-hidden="true" tabindex="-1"></a>))</span>
<span id="cb14-116"><a href="#cb14-116" aria-hidden="true" tabindex="-1"></a>nodes.append(helper.make_node(<span class="st">"Relu"</span>,</span>
<span id="cb14-117"><a href="#cb14-117" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"fc1_out"</span>], outputs<span class="op">=</span>[<span class="st">"fc1_relu"</span>],</span>
<span id="cb14-118"><a href="#cb14-118" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"fc1_relu"</span>,</span>
<span id="cb14-119"><a href="#cb14-119" aria-hidden="true" tabindex="-1"></a>))</span>
<span id="cb14-120"><a href="#cb14-120" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-121"><a href="#cb14-121" aria-hidden="true" tabindex="-1"></a><span class="co"># FC2 (logits) + Softmax</span></span>
<span id="cb14-122"><a href="#cb14-122" aria-hidden="true" tabindex="-1"></a>nodes.append(helper.make_node(<span class="st">"Gemm"</span>,</span>
<span id="cb14-123"><a href="#cb14-123" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"fc1_relu"</span>, <span class="st">"fc2_W"</span>, <span class="st">"fc2_b"</span>], outputs<span class="op">=</span>[<span class="st">"logits"</span>],</span>
<span id="cb14-124"><a href="#cb14-124" aria-hidden="true" tabindex="-1"></a>    alpha<span class="op">=</span><span class="fl">1.0</span>, beta<span class="op">=</span><span class="fl">1.0</span>, name<span class="op">=</span><span class="st">"fc2"</span>,</span>
<span id="cb14-125"><a href="#cb14-125" aria-hidden="true" tabindex="-1"></a>))</span>
<span id="cb14-126"><a href="#cb14-126" aria-hidden="true" tabindex="-1"></a>nodes.append(helper.make_node(<span class="st">"Softmax"</span>,</span>
<span id="cb14-127"><a href="#cb14-127" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"logits"</span>], outputs<span class="op">=</span>[<span class="st">"probs"</span>],</span>
<span id="cb14-128"><a href="#cb14-128" aria-hidden="true" tabindex="-1"></a>    axis<span class="op">=-</span><span class="dv">1</span>, name<span class="op">=</span><span class="st">"softmax"</span>,</span>
<span id="cb14-129"><a href="#cb14-129" aria-hidden="true" tabindex="-1"></a>))</span>
<span id="cb14-130"><a href="#cb14-130" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-131"><a href="#cb14-131" aria-hidden="true" tabindex="-1"></a><span class="co"># ------------------------------------------------------------------ #</span></span>
<span id="cb14-132"><a href="#cb14-132" aria-hidden="true" tabindex="-1"></a><span class="co"># Graph, model, validate, save                                       #</span></span>
<span id="cb14-133"><a href="#cb14-133" aria-hidden="true" tabindex="-1"></a><span class="co"># ------------------------------------------------------------------ #</span></span>
<span id="cb14-134"><a href="#cb14-134" aria-hidden="true" tabindex="-1"></a>x_info    <span class="op">=</span> helper.make_tensor_value_info(<span class="st">"x"</span>,     TensorProto.FLOAT, [<span class="st">"batch"</span>, <span class="dv">3</span>, <span class="dv">32</span>, <span class="dv">32</span>])</span>
<span id="cb14-135"><a href="#cb14-135" aria-hidden="true" tabindex="-1"></a>prob_info <span class="op">=</span> helper.make_tensor_value_info(<span class="st">"probs"</span>, TensorProto.FLOAT, [<span class="st">"batch"</span>, <span class="dv">10</span>])</span>
<span id="cb14-136"><a href="#cb14-136" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-137"><a href="#cb14-137" aria-hidden="true" tabindex="-1"></a>graph <span class="op">=</span> helper.make_graph(nodes, <span class="st">"cnn_classifier"</span>,</span>
<span id="cb14-138"><a href="#cb14-138" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[x_info], outputs<span class="op">=</span>[prob_info], initializer<span class="op">=</span>inits)</span>
<span id="cb14-139"><a href="#cb14-139" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> helper.make_model(graph, opset_imports<span class="op">=</span>[helper.make_opsetid(<span class="st">""</span>, <span class="dv">17</span>)])</span>
<span id="cb14-140"><a href="#cb14-140" aria-hidden="true" tabindex="-1"></a>model.ir_version <span class="op">=</span> <span class="dv">8</span></span>
<span id="cb14-141"><a href="#cb14-141" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-142"><a href="#cb14-142" aria-hidden="true" tabindex="-1"></a>onnx.checker.check_model(model)</span>
<span id="cb14-143"><a href="#cb14-143" aria-hidden="true" tabindex="-1"></a>onnx.save(model, <span class="st">"cnn_classifier.onnx"</span>)</span>
<span id="cb14-144"><a href="#cb14-144" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"CNN saved."</span>)</span></code></pre></div></div>
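<p>The hard-coded <code>128 * 4 * 4</code> for <code>fc1</code> is easy to get wrong when the architecture changes. A short pure-Python sketch (the helper names are hypothetical) can trace the shapes using the output-size rule for explicit padding and dilation 1, <code>floor((size + pad_begin + pad_end - kernel) / stride) + 1</code>, which also covers <code>MaxPool</code> with no padding and the default <code>ceil_mode=0</code>:</p>

```python
def conv_out(size, k, stride=1, pad_begin=0, pad_end=0):
    """Output length along one spatial axis (dilation 1, floor rounding)."""
    return (size + pad_begin + pad_end - k) // stride + 1

def pool_out(size, k=2, stride=2):
    """Unpadded MaxPool with ceil_mode=0 follows the same floor formula."""
    return conv_out(size, k, stride)

channels = [3, 32, 64, 128]
h = w = 32
for block, out_ch in enumerate(channels[1:], start=1):
    h, w = conv_out(h, 3, 1, 1, 1), conv_out(w, 3, 1, 1, 1)  # 3x3 conv, pad 1
    h, w = pool_out(h), pool_out(w)                           # 2x2 max-pool, stride 2
    print(f"block{block}: [batch, {out_ch}, {h}, {w}]")

flat_dim = channels[-1] * h * w
print("flatten:", flat_dim)  # 2048
```

<p>Deriving <code>flat_dim</code> this way, rather than typing it by hand, keeps <code>fc1</code> consistent if you change the input resolution or the number of blocks.</p>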
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Warning</span>Important notes on the CNN
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li><strong>NCHW layout</strong>: ONNX <code>Conv</code> assumes channel-first ordering. If your input data is NHWC, insert a <code>Transpose</code> with <code>perm=[0, 3, 1, 2]</code> before the first convolution.</li>
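<li><strong><code>pads</code> attribute</strong>: ONNX lists all begin pads followed by all end pads, <code>[x1_begin, x2_begin, ..., x1_end, x2_end]</code>. For a 2D convolution on NCHW input this is <code>[pad_top, pad_left, pad_bottom, pad_right]</code>, which is the same thing as <code>[pad_h_begin, pad_w_begin, pad_h_end, pad_w_end]</code>. The ordering to watch out for is PyTorch’s <code>torch.nn.functional.pad</code>, which starts from the last axis: <code>(left, right, top, bottom)</code>.</li>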
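<li><strong><code>BatchNormalization</code> in inference mode</strong>: From opset 14 onward (this model uses opset 17), <code>BatchNormalization</code> has a <code>training_mode</code> attribute that defaults to <code>0</code>. In inference mode it produces a single output, the normalized tensor, computed from the stored <code>input_mean</code> and <code>input_var</code>; the optional <code>running_mean</code> and <code>running_var</code> outputs are only valid when <code>training_mode=1</code>. Opsets 9 to 13 defined up to five outputs (<code>Y</code>, <code>mean</code>, <code>var</code>, <code>saved_mean</code>, <code>saved_var</code>), so a BN node with more than one output usually comes from a training-mode export or an older opset.</li>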
<li><strong><code>Flatten</code> axis</strong>: <code>axis=1</code> means flatten from dimension 1 onward, preserving the batch dimension. The result is <code>[batch, 128*4*4]</code>.</li>
</ul>
</div>
</div>
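<p>The shape bookkeeping behind these notes can be checked without building a graph. The sketch below uses plain Python; the helper names are hypothetical (not part of the <code>onnx</code> package), and <code>conv_out_size</code> assumes the usual floor-mode output formula shared by <code>Conv</code> and <code>MaxPool</code> with <code>ceil_mode=0</code>.</p>

```python
import math

def torch_pad2d_to_onnx(torch_pad):
    """Convert torch.nn.functional.pad ordering (left, right, top, bottom),
    which starts from the last axis, into ONNX Conv pads for NCHW input:
    [H_begin, W_begin, H_end, W_end] == [top, left, bottom, right]."""
    left, right, top, bottom = torch_pad
    return [top, left, bottom, right]

def conv_out_size(size, kernel, stride=1, pad_begin=0, pad_end=0, dilation=1):
    """Output length along one spatial axis (floor mode)."""
    effective_kernel = dilation * (kernel - 1) + 1
    return math.floor((size + pad_begin + pad_end - effective_kernel) / stride) + 1

def flatten_shape(shape, axis=1):
    """ONNX Flatten result: [prod(shape[:axis]), prod(shape[axis:])]."""
    return [math.prod(shape[:axis]), math.prod(shape[axis:])]

print(torch_pad2d_to_onnx((1, 2, 3, 4)))                    # [3, 1, 4, 2]
print(conv_out_size(32, kernel=3, pad_begin=1, pad_end=1))  # 32 ("same" padding)
print(conv_out_size(32, kernel=2, stride=2))                # 16 (2x2 pooling)
print(flatten_shape([8, 128, 4, 4]))                        # [8, 2048]
```

<p>Running these numbers for each layer before writing the <code>Flatten</code> and <code>Gemm</code> nodes is a cheap way to make sure the fully connected weight shape matches what the convolutional stack actually produces.</p>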
</section>
<section id="building-a-recurrent-neural-network-rnnlstm" class="level2">
<h2 class="anchored" data-anchor-id="building-a-recurrent-neural-network-rnnlstm" id="building-a-recurrent-neural-network-rnnlstm">Building a Recurrent Neural Network (RNN/LSTM)</h2>
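<p>ONNX’s <code>LSTM</code> operator encodes an entire LSTM layer (every timestep, and optionally both directions) in a single node, much like a single-layer <code>torch.nn.LSTM</code>; there is no need to unroll the recurrence into per-timestep nodes. Unlike <code>torch.nn.LSTM</code>, though, one node covers only one layer, so a stacked LSTM needs one <code>LSTM</code> node per layer. The representation is compact, but the weight layout requires care.</p>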
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Important</span>Gate order difference
</div>
</div>
<div class="callout-body-container callout-body">
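<p>The ONNX LSTM gate order is <strong>IOFC</strong> (Input, Output, Forget, Cell), while PyTorch uses <strong>IFGO</strong> (Input, Forget, Cell, Output; PyTorch names the cell gate <code>g</code>). When you copy weights between the two frameworks, the four gate blocks along the <code>4 * hidden_size</code> axis of <code>W</code>, <code>R</code>, and both halves of <code>B</code> must be reordered.</p>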
</div>
</div>
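<p>A minimal NumPy sketch of that reordering, assuming a single-layer, forward-only LSTM whose parameters have already been converted to NumPy arrays (for example with <code>lstm.weight_ih_l0.detach().numpy()</code>). The function names are illustrative, not part of either library. PyTorch stores two biases, <code>bias_ih_l0</code> and <code>bias_hh_l0</code>, which map to the two halves of the ONNX <code>B</code> tensor.</p>

```python
import numpy as np

def ifgo_to_iofc(w: np.ndarray) -> np.ndarray:
    """Reorder the four gate blocks along axis 0 from PyTorch's (i, f, g, o)
    to ONNX's (i, o, f, c). Works for 2D weights and 1D biases."""
    i, f, g, o = np.split(w, 4, axis=0)
    return np.concatenate([i, o, f, g], axis=0)

def torch_lstm_to_onnx(weight_ih, weight_hh, bias_ih, bias_hh):
    """Map single-layer, single-direction parameters shaped like
    torch.nn.LSTM's weight_ih_l0 [4H, input_size], weight_hh_l0 [4H, H],
    bias_ih_l0 [4H], bias_hh_l0 [4H] to ONNX W, R, B with num_directions=1."""
    W = ifgo_to_iofc(weight_ih)[np.newaxis]            # [1, 4H, input_size]
    R = ifgo_to_iofc(weight_hh)[np.newaxis]            # [1, 4H, H]
    B = np.concatenate([ifgo_to_iofc(bias_ih),         # Wb half
                        ifgo_to_iofc(bias_hh)])        # Rb half
    B = B[np.newaxis]                                  # [1, 8H]
    return W.astype(np.float32), R.astype(np.float32), B.astype(np.float32)
```

<p>The resulting arrays can be passed to <code>numpy_helper.from_array</code> exactly like the randomly initialised <code>W_data</code>, <code>R_data</code>, and <code>B_data</code> in the example that follows.</p>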
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb15"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb15-1"><a href="#cb15-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb15-2"><a href="#cb15-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> onnx</span>
<span id="cb15-3"><a href="#cb15-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> onnx <span class="im">import</span> helper, TensorProto, numpy_helper</span>
<span id="cb15-4"><a href="#cb15-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-5"><a href="#cb15-5" aria-hidden="true" tabindex="-1"></a>seq_len    <span class="op">=</span> <span class="dv">20</span>     <span class="co"># sequence length</span></span>
<span id="cb15-6"><a href="#cb15-6" aria-hidden="true" tabindex="-1"></a>batch_size <span class="op">=</span> <span class="dv">4</span>      <span class="co"># batch size</span></span>
<span id="cb15-7"><a href="#cb15-7" aria-hidden="true" tabindex="-1"></a>input_size <span class="op">=</span> <span class="dv">16</span>     <span class="co"># features per timestep</span></span>
<span id="cb15-8"><a href="#cb15-8" aria-hidden="true" tabindex="-1"></a>hidden_size <span class="op">=</span> <span class="dv">32</span>    <span class="co"># LSTM hidden dim</span></span>
<span id="cb15-9"><a href="#cb15-9" aria-hidden="true" tabindex="-1"></a>num_layers <span class="op">=</span> <span class="dv">1</span></span>
<span id="cb15-10"><a href="#cb15-10" aria-hidden="true" tabindex="-1"></a>directions <span class="op">=</span> <span class="dv">1</span>      <span class="co"># 1 for forward, 2 for bidirectional</span></span>
<span id="cb15-11"><a href="#cb15-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-12"><a href="#cb15-12" aria-hidden="true" tabindex="-1"></a><span class="co"># ------------------------------------------------------------------ #</span></span>
<span id="cb15-13"><a href="#cb15-13" aria-hidden="true" tabindex="-1"></a><span class="co"># LSTM Weight Layout (ONNX standard):                                 #</span></span>
<span id="cb15-14"><a href="#cb15-14" aria-hidden="true" tabindex="-1"></a><span class="co"># W shape: [directions, 4 * hidden_size, input_size]                  #</span></span>
<span id="cb15-15"><a href="#cb15-15" aria-hidden="true" tabindex="-1"></a><span class="co"># R shape: [directions, 4 * hidden_size, hidden_size]                 #</span></span>
<span id="cb15-16"><a href="#cb15-16" aria-hidden="true" tabindex="-1"></a><span class="co"># B shape: [directions, 8 * hidden_size]  (W_bias concat R_bias)      #</span></span>
<span id="cb15-17"><a href="#cb15-17" aria-hidden="true" tabindex="-1"></a><span class="co"># ------------------------------------------------------------------ #</span></span>
<span id="cb15-18"><a href="#cb15-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-19"><a href="#cb15-19" aria-hidden="true" tabindex="-1"></a>W_data <span class="op">=</span> np.random.randn(directions, <span class="dv">4</span> <span class="op">*</span> hidden_size, input_size).astype(np.float32)</span>
<span id="cb15-20"><a href="#cb15-20" aria-hidden="true" tabindex="-1"></a>R_data <span class="op">=</span> np.random.randn(directions, <span class="dv">4</span> <span class="op">*</span> hidden_size, hidden_size).astype(np.float32)</span>
<span id="cb15-21"><a href="#cb15-21" aria-hidden="true" tabindex="-1"></a>B_data <span class="op">=</span> np.zeros((directions, <span class="dv">8</span> <span class="op">*</span> hidden_size), dtype<span class="op">=</span>np.float32)</span>
<span id="cb15-22"><a href="#cb15-22" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-23"><a href="#cb15-23" aria-hidden="true" tabindex="-1"></a>W_init <span class="op">=</span> numpy_helper.from_array(W_data, name<span class="op">=</span><span class="st">"lstm_W"</span>)</span>
<span id="cb15-24"><a href="#cb15-24" aria-hidden="true" tabindex="-1"></a>R_init <span class="op">=</span> numpy_helper.from_array(R_data, name<span class="op">=</span><span class="st">"lstm_R"</span>)</span>
<span id="cb15-25"><a href="#cb15-25" aria-hidden="true" tabindex="-1"></a>B_init <span class="op">=</span> numpy_helper.from_array(B_data, name<span class="op">=</span><span class="st">"lstm_B"</span>)</span>
<span id="cb15-26"><a href="#cb15-26" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-27"><a href="#cb15-27" aria-hidden="true" tabindex="-1"></a><span class="co"># ------------------------------------------------------------------ #</span></span>
<span id="cb15-28"><a href="#cb15-28" aria-hidden="true" tabindex="-1"></a><span class="co"># LSTM node                                                           #</span></span>
<span id="cb15-29"><a href="#cb15-29" aria-hidden="true" tabindex="-1"></a><span class="co"># ------------------------------------------------------------------ #</span></span>
<span id="cb15-30"><a href="#cb15-30" aria-hidden="true" tabindex="-1"></a><span class="co"># Inputs:  X, W, R, B, sequence_lens (optional), initial_h, initial_c, P (peepholes)</span></span>
<span id="cb15-31"><a href="#cb15-31" aria-hidden="true" tabindex="-1"></a><span class="co"># Outputs: Y (all hidden states), Y_h (final hidden), Y_c (final cell)</span></span>
<span id="cb15-32"><a href="#cb15-32" aria-hidden="true" tabindex="-1"></a>lstm_node <span class="op">=</span> helper.make_node(</span>
<span id="cb15-33"><a href="#cb15-33" aria-hidden="true" tabindex="-1"></a>    <span class="st">"LSTM"</span>,</span>
<span id="cb15-34"><a href="#cb15-34" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"x"</span>, <span class="st">"lstm_W"</span>, <span class="st">"lstm_R"</span>, <span class="st">"lstm_B"</span>],  <span class="co"># omit optional inputs</span></span>
<span id="cb15-35"><a href="#cb15-35" aria-hidden="true" tabindex="-1"></a>    outputs<span class="op">=</span>[<span class="st">"Y"</span>, <span class="st">"Y_h"</span>, <span class="st">"Y_c"</span>],</span>
<span id="cb15-36"><a href="#cb15-36" aria-hidden="true" tabindex="-1"></a>    hidden_size<span class="op">=</span>hidden_size,</span>
<span id="cb15-37"><a href="#cb15-37" aria-hidden="true" tabindex="-1"></a>    direction<span class="op">=</span><span class="st">"forward"</span>,</span>
<span id="cb15-38"><a href="#cb15-38" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"lstm_layer"</span>,</span>
<span id="cb15-39"><a href="#cb15-39" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb15-40"><a href="#cb15-40" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-41"><a href="#cb15-41" aria-hidden="true" tabindex="-1"></a><span class="co"># Y shape:   [seq_len, directions, batch, hidden_size]</span></span>
<span id="cb15-42"><a href="#cb15-42" aria-hidden="true" tabindex="-1"></a><span class="co"># Y_h shape: [directions, batch, hidden_size]</span></span>
<span id="cb15-43"><a href="#cb15-43" aria-hidden="true" tabindex="-1"></a><span class="co"># Y_c shape: [directions, batch, hidden_size]</span></span>
<span id="cb15-44"><a href="#cb15-44" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-45"><a href="#cb15-45" aria-hidden="true" tabindex="-1"></a><span class="co"># We want the final hidden state: Y_h, shape [1, batch, hidden_size]</span></span>
<span id="cb15-46"><a href="#cb15-46" aria-hidden="true" tabindex="-1"></a><span class="co"># Squeeze the directions dimension:</span></span>
<span id="cb15-47"><a href="#cb15-47" aria-hidden="true" tabindex="-1"></a>squeeze_axes <span class="op">=</span> numpy_helper.from_array(np.array([<span class="dv">0</span>], dtype<span class="op">=</span>np.int64), name<span class="op">=</span><span class="st">"squeeze_axes"</span>)</span>
<span id="cb15-48"><a href="#cb15-48" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-49"><a href="#cb15-49" aria-hidden="true" tabindex="-1"></a>squeeze_node <span class="op">=</span> helper.make_node(</span>
<span id="cb15-50"><a href="#cb15-50" aria-hidden="true" tabindex="-1"></a>    <span class="st">"Squeeze"</span>,</span>
<span id="cb15-51"><a href="#cb15-51" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"Y_h"</span>, <span class="st">"squeeze_axes"</span>],</span>
<span id="cb15-52"><a href="#cb15-52" aria-hidden="true" tabindex="-1"></a>    outputs<span class="op">=</span>[<span class="st">"h_final"</span>],   <span class="co"># shape: [batch, hidden_size]</span></span>
<span id="cb15-53"><a href="#cb15-53" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"squeeze_h"</span>,</span>
<span id="cb15-54"><a href="#cb15-54" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb15-55"><a href="#cb15-55" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-56"><a href="#cb15-56" aria-hidden="true" tabindex="-1"></a><span class="co"># Final classifier</span></span>
<span id="cb15-57"><a href="#cb15-57" aria-hidden="true" tabindex="-1"></a>fc_W_data <span class="op">=</span> np.random.randn(hidden_size, <span class="dv">5</span>).astype(np.float32)  <span class="co"># 5 output classes</span></span>
<span id="cb15-58"><a href="#cb15-58" aria-hidden="true" tabindex="-1"></a>fc_b_data <span class="op">=</span> np.zeros(<span class="dv">5</span>, dtype<span class="op">=</span>np.float32)</span>
<span id="cb15-59"><a href="#cb15-59" aria-hidden="true" tabindex="-1"></a>fc_W_init <span class="op">=</span> numpy_helper.from_array(fc_W_data, name<span class="op">=</span><span class="st">"fc_W"</span>)</span>
<span id="cb15-60"><a href="#cb15-60" aria-hidden="true" tabindex="-1"></a>fc_b_init <span class="op">=</span> numpy_helper.from_array(fc_b_data, name<span class="op">=</span><span class="st">"fc_b"</span>)</span>
<span id="cb15-61"><a href="#cb15-61" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-62"><a href="#cb15-62" aria-hidden="true" tabindex="-1"></a>fc_node <span class="op">=</span> helper.make_node(</span>
<span id="cb15-63"><a href="#cb15-63" aria-hidden="true" tabindex="-1"></a>    <span class="st">"Gemm"</span>,</span>
<span id="cb15-64"><a href="#cb15-64" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"h_final"</span>, <span class="st">"fc_W"</span>, <span class="st">"fc_b"</span>],</span>
<span id="cb15-65"><a href="#cb15-65" aria-hidden="true" tabindex="-1"></a>    outputs<span class="op">=</span>[<span class="st">"logits"</span>],</span>
<span id="cb15-66"><a href="#cb15-66" aria-hidden="true" tabindex="-1"></a>    alpha<span class="op">=</span><span class="fl">1.0</span>, beta<span class="op">=</span><span class="fl">1.0</span>,</span>
<span id="cb15-67"><a href="#cb15-67" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"fc_out"</span>,</span>
<span id="cb15-68"><a href="#cb15-68" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb15-69"><a href="#cb15-69" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-70"><a href="#cb15-70" aria-hidden="true" tabindex="-1"></a>softmax_node <span class="op">=</span> helper.make_node(</span>
<span id="cb15-71"><a href="#cb15-71" aria-hidden="true" tabindex="-1"></a>    <span class="st">"Softmax"</span>,</span>
<span id="cb15-72"><a href="#cb15-72" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"logits"</span>],</span>
<span id="cb15-73"><a href="#cb15-73" aria-hidden="true" tabindex="-1"></a>    outputs<span class="op">=</span>[<span class="st">"probs"</span>],</span>
<span id="cb15-74"><a href="#cb15-74" aria-hidden="true" tabindex="-1"></a>    axis<span class="op">=-</span><span class="dv">1</span>,</span>
<span id="cb15-75"><a href="#cb15-75" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"softmax"</span>,</span>
<span id="cb15-76"><a href="#cb15-76" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb15-77"><a href="#cb15-77" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-78"><a href="#cb15-78" aria-hidden="true" tabindex="-1"></a><span class="co"># ------------------------------------------------------------------ #</span></span>
<span id="cb15-79"><a href="#cb15-79" aria-hidden="true" tabindex="-1"></a><span class="co"># Graph assembly                                                      #</span></span>
<span id="cb15-80"><a href="#cb15-80" aria-hidden="true" tabindex="-1"></a><span class="co"># ------------------------------------------------------------------ #</span></span>
<span id="cb15-81"><a href="#cb15-81" aria-hidden="true" tabindex="-1"></a><span class="co"># X: [seq_len, batch, input_size] — ONNX LSTM uses time-first layout</span></span>
<span id="cb15-82"><a href="#cb15-82" aria-hidden="true" tabindex="-1"></a>x_info    <span class="op">=</span> helper.make_tensor_value_info(<span class="st">"x"</span>, TensorProto.FLOAT,</span>
<span id="cb15-83"><a href="#cb15-83" aria-hidden="true" tabindex="-1"></a>                                           [seq_len, <span class="st">"batch"</span>, input_size])</span>
<span id="cb15-84"><a href="#cb15-84" aria-hidden="true" tabindex="-1"></a>prob_info <span class="op">=</span> helper.make_tensor_value_info(<span class="st">"probs"</span>, TensorProto.FLOAT, [<span class="st">"batch"</span>, <span class="dv">5</span>])</span>
<span id="cb15-85"><a href="#cb15-85" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-86"><a href="#cb15-86" aria-hidden="true" tabindex="-1"></a>graph <span class="op">=</span> helper.make_graph(</span>
<span id="cb15-87"><a href="#cb15-87" aria-hidden="true" tabindex="-1"></a>    [lstm_node, squeeze_node, fc_node, softmax_node],</span>
<span id="cb15-88"><a href="#cb15-88" aria-hidden="true" tabindex="-1"></a>    <span class="st">"lstm_classifier"</span>,</span>
<span id="cb15-89"><a href="#cb15-89" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[x_info],</span>
<span id="cb15-90"><a href="#cb15-90" aria-hidden="true" tabindex="-1"></a>    outputs<span class="op">=</span>[prob_info],</span>
<span id="cb15-91"><a href="#cb15-91" aria-hidden="true" tabindex="-1"></a>    initializer<span class="op">=</span>[W_init, R_init, B_init, squeeze_axes, fc_W_init, fc_b_init],</span>
<span id="cb15-92"><a href="#cb15-92" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb15-93"><a href="#cb15-93" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> helper.make_model(graph, opset_imports<span class="op">=</span>[helper.make_opsetid(<span class="st">""</span>, <span class="dv">17</span>)])</span>
<span id="cb15-94"><a href="#cb15-94" aria-hidden="true" tabindex="-1"></a>model.ir_version <span class="op">=</span> <span class="dv">8</span></span>
<span id="cb15-95"><a href="#cb15-95" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-96"><a href="#cb15-96" aria-hidden="true" tabindex="-1"></a>onnx.checker.check_model(model)</span>
<span id="cb15-97"><a href="#cb15-97" aria-hidden="true" tabindex="-1"></a>onnx.save(model, <span class="st">"lstm_classifier.onnx"</span>)</span>
<span id="cb15-98"><a href="#cb15-98" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"LSTM model saved."</span>)</span></code></pre></div></div>
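<p>By default (<code>layout=0</code>), ONNX LSTM expects <code>X</code> in <code>[seq_len, batch, input_size]</code> order (time-first). Opset 14 added a <code>layout</code> attribute, and <code>layout=1</code> accepts batch-first <code>X</code> directly, but it also makes <code>Y</code> and <code>Y_h</code> batch-first, so the <code>Squeeze</code> above would need axis 1 instead of 0. Alternatively, keep the default layout and convert batch-first <code>[batch, seq_len, input_size]</code> data with a <code>Transpose</code> node before the LSTM:</p>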
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb16"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb16-1"><a href="#cb16-1" aria-hidden="true" tabindex="-1"></a>transpose_node <span class="op">=</span> helper.make_node(</span>
<span id="cb16-2"><a href="#cb16-2" aria-hidden="true" tabindex="-1"></a>    <span class="st">"Transpose"</span>,</span>
<span id="cb16-3"><a href="#cb16-3" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"x_batch_first"</span>],</span>
<span id="cb16-4"><a href="#cb16-4" aria-hidden="true" tabindex="-1"></a>    outputs<span class="op">=</span>[<span class="st">"x"</span>],</span>
<span id="cb16-5"><a href="#cb16-5" aria-hidden="true" tabindex="-1"></a>    perm<span class="op">=</span>[<span class="dv">1</span>, <span class="dv">0</span>, <span class="dv">2</span>],  <span class="co"># swap seq and batch dimensions</span></span>
<span id="cb16-6"><a href="#cb16-6" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"batch_to_seq_first"</span>,</span>
<span id="cb16-7"><a href="#cb16-7" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
</section>
<section id="building-a-transformer-block" class="level2">
<h2 class="anchored" data-anchor-id="building-a-transformer-block" id="building-a-transformer-block">Building a Transformer Block</h2>
<p>A Transformer block is the most involved architecture to assemble in raw ONNX, but it is an outstanding exercise in understanding attention. We build a single encoder block: multi-head self-attention followed by a feed-forward network, both with residual connections and layer normalization.</p>
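<p>Before assembling the graph, it can help to have a plain NumPy reference for the attention half of the block to check shapes and outputs against. The sketch below is an illustration, not part of the ONNX model: its argument names mirror the initializers defined in the ONNX code (<code>W_Q</code>, <code>b_Q</code>, and so on), it follows the same reshape-then-transpose head split, and it applies the same <code>1 / sqrt(d_k)</code> scale. The softmax subtracts the row maximum for numerical stability, which does not change the result.</p>

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # stabilise before exponentiating
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_self_attention(x, W_Q, W_K, W_V, W_O, b_Q, b_K, b_V, b_O, n_heads):
    """x: [batch, seq, d_model]; weights: [d_model, d_model]; biases: [d_model]."""
    batch, seq, d_model = x.shape
    d_k = d_model // n_heads

    def split_heads(t):
        # [batch, seq, d_model] -> [batch, seq, n_heads, d_k] -> [batch, n_heads, seq, d_k]
        return t.reshape(batch, seq, n_heads, d_k).transpose(0, 2, 1, 3)

    Q = split_heads(x @ W_Q + b_Q)
    K = split_heads(x @ W_K + b_K)
    V = split_heads(x @ W_V + b_V)

    scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(d_k)   # [batch, n_heads, seq, seq]
    attn = softmax(scores, axis=-1)                       # each row sums to 1
    heads = attn @ V                                      # [batch, n_heads, seq, d_k]

    # Undo the head split: [batch, seq, n_heads, d_k] -> [batch, seq, d_model]
    merged = heads.transpose(0, 2, 1, 3).reshape(batch, seq, d_model)
    return merged @ W_O + b_O
```

<p>Because the reshape splits <code>d_model</code> in row-major order, head <code>h</code> sees exactly columns <code>h * d_k</code> to <code>(h + 1) * d_k - 1</code> of each projection, which is why a single <code>[d_model, d_model]</code> matrix per projection is enough.</p>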
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb17"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb17-1"><a href="#cb17-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb17-2"><a href="#cb17-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> onnx</span>
<span id="cb17-3"><a href="#cb17-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> onnx <span class="im">import</span> helper, TensorProto, numpy_helper</span>
<span id="cb17-4"><a href="#cb17-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-5"><a href="#cb17-5" aria-hidden="true" tabindex="-1"></a>d_model   <span class="op">=</span> <span class="dv">64</span>    <span class="co"># embedding dimension</span></span>
<span id="cb17-6"><a href="#cb17-6" aria-hidden="true" tabindex="-1"></a>n_heads   <span class="op">=</span> <span class="dv">4</span>     <span class="co"># attention heads</span></span>
<span id="cb17-7"><a href="#cb17-7" aria-hidden="true" tabindex="-1"></a>d_k       <span class="op">=</span> d_model <span class="op">//</span> n_heads  <span class="co"># key/query dimension per head = 16</span></span>
<span id="cb17-8"><a href="#cb17-8" aria-hidden="true" tabindex="-1"></a>d_ff      <span class="op">=</span> <span class="dv">256</span>   <span class="co"># feed-forward inner dimension</span></span>
<span id="cb17-9"><a href="#cb17-9" aria-hidden="true" tabindex="-1"></a>seq_len   <span class="op">=</span> <span class="dv">10</span>    <span class="co"># sequence length (fixed for this example)</span></span>
<span id="cb17-10"><a href="#cb17-10" aria-hidden="true" tabindex="-1"></a>eps       <span class="op">=</span> <span class="fl">1e-6</span></span>
<span id="cb17-11"><a href="#cb17-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-12"><a href="#cb17-12" aria-hidden="true" tabindex="-1"></a>rng <span class="op">=</span> np.random.default_rng(<span class="dv">42</span>)</span>
<span id="cb17-13"><a href="#cb17-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-14"><a href="#cb17-14" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> rand_f32(shape, name):</span>
<span id="cb17-15"><a href="#cb17-15" aria-hidden="true" tabindex="-1"></a>    data <span class="op">=</span> rng.standard_normal(shape).astype(np.float32) <span class="op">*</span> <span class="fl">0.02</span></span>
<span id="cb17-16"><a href="#cb17-16" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> numpy_helper.from_array(data, name<span class="op">=</span>name)</span>
<span id="cb17-17"><a href="#cb17-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-18"><a href="#cb17-18" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> zeros_f32(shape, name):</span>
<span id="cb17-19"><a href="#cb17-19" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> numpy_helper.from_array(np.zeros(shape, dtype<span class="op">=</span>np.float32), name<span class="op">=</span>name)</span>
<span id="cb17-20"><a href="#cb17-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-21"><a href="#cb17-21" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> ones_f32(shape, name):</span>
<span id="cb17-22"><a href="#cb17-22" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> numpy_helper.from_array(np.ones(shape, dtype<span class="op">=</span>np.float32), name<span class="op">=</span>name)</span>
<span id="cb17-23"><a href="#cb17-23" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-24"><a href="#cb17-24" aria-hidden="true" tabindex="-1"></a>inits  <span class="op">=</span> []</span>
<span id="cb17-25"><a href="#cb17-25" aria-hidden="true" tabindex="-1"></a>nodes  <span class="op">=</span> []</span>
<span id="cb17-26"><a href="#cb17-26" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-27"><a href="#cb17-27" aria-hidden="true" tabindex="-1"></a><span class="co"># ================================================================== #</span></span>
<span id="cb17-28"><a href="#cb17-28" aria-hidden="true" tabindex="-1"></a><span class="co"># Projection weights for Q, K, V, and output                        #</span></span>
<span id="cb17-29"><a href="#cb17-29" aria-hidden="true" tabindex="-1"></a><span class="co"># [d_model, d_model] — we will split heads in-graph                  #</span></span>
<span id="cb17-30"><a href="#cb17-30" aria-hidden="true" tabindex="-1"></a><span class="co"># ================================================================== #</span></span>
<span id="cb17-31"><a href="#cb17-31" aria-hidden="true" tabindex="-1"></a>inits <span class="op">+=</span> [rand_f32((d_model, d_model), <span class="st">"W_Q"</span>),</span>
<span id="cb17-32"><a href="#cb17-32" aria-hidden="true" tabindex="-1"></a>          rand_f32((d_model, d_model), <span class="st">"W_K"</span>),</span>
<span id="cb17-33"><a href="#cb17-33" aria-hidden="true" tabindex="-1"></a>          rand_f32((d_model, d_model), <span class="st">"W_V"</span>),</span>
<span id="cb17-34"><a href="#cb17-34" aria-hidden="true" tabindex="-1"></a>          rand_f32((d_model, d_model), <span class="st">"W_O"</span>),</span>
<span id="cb17-35"><a href="#cb17-35" aria-hidden="true" tabindex="-1"></a>          zeros_f32((d_model,), <span class="st">"b_Q"</span>),</span>
<span id="cb17-36"><a href="#cb17-36" aria-hidden="true" tabindex="-1"></a>          zeros_f32((d_model,), <span class="st">"b_K"</span>),</span>
<span id="cb17-37"><a href="#cb17-37" aria-hidden="true" tabindex="-1"></a>          zeros_f32((d_model,), <span class="st">"b_V"</span>),</span>
<span id="cb17-38"><a href="#cb17-38" aria-hidden="true" tabindex="-1"></a>          zeros_f32((d_model,), <span class="st">"b_O"</span>)]</span>
<span id="cb17-39"><a href="#cb17-39" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-40"><a href="#cb17-40" aria-hidden="true" tabindex="-1"></a><span class="co"># Feed-forward weights</span></span>
<span id="cb17-41"><a href="#cb17-41" aria-hidden="true" tabindex="-1"></a>inits <span class="op">+=</span> [rand_f32((d_model, d_ff), <span class="st">"W_ff1"</span>), zeros_f32((d_ff,),    <span class="st">"b_ff1"</span>),</span>
<span id="cb17-42"><a href="#cb17-42" aria-hidden="true" tabindex="-1"></a>          rand_f32((d_ff, d_model), <span class="st">"W_ff2"</span>), zeros_f32((d_model,), <span class="st">"b_ff2"</span>)]</span>
<span id="cb17-43"><a href="#cb17-43" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-44"><a href="#cb17-44" aria-hidden="true" tabindex="-1"></a><span class="co"># LayerNorm parameters (two sets: after attention, after FFN)</span></span>
<span id="cb17-45"><a href="#cb17-45" aria-hidden="true" tabindex="-1"></a>inits <span class="op">+=</span> [ones_f32((d_model,),  <span class="st">"ln1_scale"</span>), zeros_f32((d_model,), <span class="st">"ln1_bias"</span>),</span>
<span id="cb17-46"><a href="#cb17-46" aria-hidden="true" tabindex="-1"></a>          ones_f32((d_model,),  <span class="st">"ln2_scale"</span>), zeros_f32((d_model,), <span class="st">"ln2_bias"</span>)]</span>
<span id="cb17-47"><a href="#cb17-47" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-48"><a href="#cb17-48" aria-hidden="true" tabindex="-1"></a><span class="co"># Scale factor for attention: 1 / sqrt(d_k)</span></span>
<span id="cb17-49"><a href="#cb17-49" aria-hidden="true" tabindex="-1"></a>scale_val <span class="op">=</span> np.array(<span class="fl">1.0</span> <span class="op">/</span> np.sqrt(d_k), dtype<span class="op">=</span>np.float32)</span>
<span id="cb17-50"><a href="#cb17-50" aria-hidden="true" tabindex="-1"></a>inits.append(numpy_helper.from_array(scale_val, name<span class="op">=</span><span class="st">"attn_scale"</span>))</span>
<span id="cb17-51"><a href="#cb17-51" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-52"><a href="#cb17-52" aria-hidden="true" tabindex="-1"></a><span class="co"># Shape constants for Reshape operations</span></span>
<span id="cb17-53"><a href="#cb17-53" aria-hidden="true" tabindex="-1"></a>reshape_to_heads <span class="op">=</span> np.array([<span class="op">-</span><span class="dv">1</span>, seq_len, n_heads, d_k], dtype<span class="op">=</span>np.int64)</span>
<span id="cb17-54"><a href="#cb17-54" aria-hidden="true" tabindex="-1"></a>inits.append(numpy_helper.from_array(reshape_to_heads, name<span class="op">=</span><span class="st">"shape_heads"</span>))</span>
<span id="cb17-55"><a href="#cb17-55" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-56"><a href="#cb17-56" aria-hidden="true" tabindex="-1"></a>restore_shape <span class="op">=</span> np.array([<span class="op">-</span><span class="dv">1</span>, seq_len, d_model], dtype<span class="op">=</span>np.int64)</span>
<span id="cb17-57"><a href="#cb17-57" aria-hidden="true" tabindex="-1"></a>inits.append(numpy_helper.from_array(restore_shape, name<span class="op">=</span><span class="st">"shape_restore"</span>))</span>
<span id="cb17-58"><a href="#cb17-58" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-59"><a href="#cb17-59" aria-hidden="true" tabindex="-1"></a><span class="co"># ================================================================== #</span></span>
<span id="cb17-60"><a href="#cb17-60" aria-hidden="true" tabindex="-1"></a><span class="co"># MULTI-HEAD SELF-</span><span class="al">ATTENTION</span><span class="co">                                           #</span></span>
<span id="cb17-61"><a href="#cb17-61" aria-hidden="true" tabindex="-1"></a><span class="co"># ================================================================== #</span></span>
<span id="cb17-62"><a href="#cb17-62" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-63"><a href="#cb17-63" aria-hidden="true" tabindex="-1"></a><span class="co"># --- Compute Q, K, V projections ---</span></span>
<span id="cb17-64"><a href="#cb17-64" aria-hidden="true" tabindex="-1"></a><span class="co"># MatMul: [batch, seq, d_model] @ [d_model, d_model] → [batch, seq, d_model]</span></span>
<span id="cb17-65"><a href="#cb17-65" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> letter <span class="kw">in</span> [<span class="st">"Q"</span>, <span class="st">"K"</span>, <span class="st">"V"</span>]:</span>
<span id="cb17-66"><a href="#cb17-66" aria-hidden="true" tabindex="-1"></a>    nodes.append(helper.make_node(<span class="st">"MatMul"</span>,</span>
<span id="cb17-67"><a href="#cb17-67" aria-hidden="true" tabindex="-1"></a>        inputs<span class="op">=</span>[<span class="st">"x"</span>, <span class="ss">f"W_</span><span class="sc">{</span>letter<span class="sc">}</span><span class="ss">"</span>],</span>
<span id="cb17-68"><a href="#cb17-68" aria-hidden="true" tabindex="-1"></a>        outputs<span class="op">=</span>[<span class="ss">f"</span><span class="sc">{</span>letter<span class="sc">}</span><span class="ss">_proj"</span>],</span>
<span id="cb17-69"><a href="#cb17-69" aria-hidden="true" tabindex="-1"></a>        name<span class="op">=</span><span class="ss">f"matmul_</span><span class="sc">{</span>letter<span class="sc">}</span><span class="ss">"</span>,</span>
<span id="cb17-70"><a href="#cb17-70" aria-hidden="true" tabindex="-1"></a>    ))</span>
<span id="cb17-71"><a href="#cb17-71" aria-hidden="true" tabindex="-1"></a>    nodes.append(helper.make_node(<span class="st">"Add"</span>,</span>
<span id="cb17-72"><a href="#cb17-72" aria-hidden="true" tabindex="-1"></a>        inputs<span class="op">=</span>[<span class="ss">f"</span><span class="sc">{</span>letter<span class="sc">}</span><span class="ss">_proj"</span>, <span class="ss">f"b_</span><span class="sc">{</span>letter<span class="sc">}</span><span class="ss">"</span>],</span>
<span id="cb17-73"><a href="#cb17-73" aria-hidden="true" tabindex="-1"></a>        outputs<span class="op">=</span>[<span class="ss">f"</span><span class="sc">{</span>letter<span class="sc">}</span><span class="ss">"</span>],</span>
<span id="cb17-74"><a href="#cb17-74" aria-hidden="true" tabindex="-1"></a>        name<span class="op">=</span><span class="ss">f"add_bias_</span><span class="sc">{</span>letter<span class="sc">}</span><span class="ss">"</span>,</span>
<span id="cb17-75"><a href="#cb17-75" aria-hidden="true" tabindex="-1"></a>    ))</span>
<span id="cb17-76"><a href="#cb17-76" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-77"><a href="#cb17-77" aria-hidden="true" tabindex="-1"></a><span class="co"># --- Reshape to [batch, seq, n_heads, d_k] ---</span></span>
<span id="cb17-78"><a href="#cb17-78" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> letter <span class="kw">in</span> [<span class="st">"Q"</span>, <span class="st">"K"</span>, <span class="st">"V"</span>]:</span>
<span id="cb17-79"><a href="#cb17-79" aria-hidden="true" tabindex="-1"></a>    nodes.append(helper.make_node(<span class="st">"Reshape"</span>,</span>
<span id="cb17-80"><a href="#cb17-80" aria-hidden="true" tabindex="-1"></a>        inputs<span class="op">=</span>[letter, <span class="st">"shape_heads"</span>],</span>
<span id="cb17-81"><a href="#cb17-81" aria-hidden="true" tabindex="-1"></a>        outputs<span class="op">=</span>[<span class="ss">f"</span><span class="sc">{</span>letter<span class="sc">}</span><span class="ss">_h"</span>],</span>
<span id="cb17-82"><a href="#cb17-82" aria-hidden="true" tabindex="-1"></a>        name<span class="op">=</span><span class="ss">f"reshape_</span><span class="sc">{</span>letter<span class="sc">}</span><span class="ss">"</span>,</span>
<span id="cb17-83"><a href="#cb17-83" aria-hidden="true" tabindex="-1"></a>    ))</span>
<span id="cb17-84"><a href="#cb17-84" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-85"><a href="#cb17-85" aria-hidden="true" tabindex="-1"></a><span class="co"># --- Transpose to [batch, n_heads, seq, d_k] ---</span></span>
<span id="cb17-86"><a href="#cb17-86" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> letter <span class="kw">in</span> [<span class="st">"Q"</span>, <span class="st">"K"</span>, <span class="st">"V"</span>]:</span>
<span id="cb17-87"><a href="#cb17-87" aria-hidden="true" tabindex="-1"></a>    nodes.append(helper.make_node(<span class="st">"Transpose"</span>,</span>
<span id="cb17-88"><a href="#cb17-88" aria-hidden="true" tabindex="-1"></a>        inputs<span class="op">=</span>[<span class="ss">f"</span><span class="sc">{</span>letter<span class="sc">}</span><span class="ss">_h"</span>],</span>
<span id="cb17-89"><a href="#cb17-89" aria-hidden="true" tabindex="-1"></a>        outputs<span class="op">=</span>[<span class="ss">f"</span><span class="sc">{</span>letter<span class="sc">}</span><span class="ss">_t"</span>],</span>
<span id="cb17-90"><a href="#cb17-90" aria-hidden="true" tabindex="-1"></a>        perm<span class="op">=</span>[<span class="dv">0</span>, <span class="dv">2</span>, <span class="dv">1</span>, <span class="dv">3</span>],</span>
<span id="cb17-91"><a href="#cb17-91" aria-hidden="true" tabindex="-1"></a>        name<span class="op">=</span><span class="ss">f"transpose_</span><span class="sc">{</span>letter<span class="sc">}</span><span class="ss">"</span>,</span>
<span id="cb17-92"><a href="#cb17-92" aria-hidden="true" tabindex="-1"></a>    ))</span>
<span id="cb17-93"><a href="#cb17-93" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-94"><a href="#cb17-94" aria-hidden="true" tabindex="-1"></a><span class="co"># --- Attention scores: Q @ K^T ---</span></span>
<span id="cb17-95"><a href="#cb17-95" aria-hidden="true" tabindex="-1"></a>nodes.append(helper.make_node(<span class="st">"Transpose"</span>,</span>
<span id="cb17-96"><a href="#cb17-96" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"K_t"</span>],</span>
<span id="cb17-97"><a href="#cb17-97" aria-hidden="true" tabindex="-1"></a>    outputs<span class="op">=</span>[<span class="st">"K_t_T"</span>],</span>
<span id="cb17-98"><a href="#cb17-98" aria-hidden="true" tabindex="-1"></a>    perm<span class="op">=</span>[<span class="dv">0</span>, <span class="dv">1</span>, <span class="dv">3</span>, <span class="dv">2</span>],</span>
<span id="cb17-99"><a href="#cb17-99" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"transpose_K_for_attn"</span>,</span>
<span id="cb17-100"><a href="#cb17-100" aria-hidden="true" tabindex="-1"></a>))</span>
<span id="cb17-101"><a href="#cb17-101" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-102"><a href="#cb17-102" aria-hidden="true" tabindex="-1"></a>nodes.append(helper.make_node(<span class="st">"MatMul"</span>,</span>
<span id="cb17-103"><a href="#cb17-103" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"Q_t"</span>, <span class="st">"K_t_T"</span>],</span>
<span id="cb17-104"><a href="#cb17-104" aria-hidden="true" tabindex="-1"></a>    outputs<span class="op">=</span>[<span class="st">"raw_scores"</span>],</span>
<span id="cb17-105"><a href="#cb17-105" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"attn_scores"</span>,</span>
<span id="cb17-106"><a href="#cb17-106" aria-hidden="true" tabindex="-1"></a>))</span>
<span id="cb17-107"><a href="#cb17-107" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-108"><a href="#cb17-108" aria-hidden="true" tabindex="-1"></a>nodes.append(helper.make_node(<span class="st">"Mul"</span>,</span>
<span id="cb17-109"><a href="#cb17-109" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"raw_scores"</span>, <span class="st">"attn_scale"</span>],</span>
<span id="cb17-110"><a href="#cb17-110" aria-hidden="true" tabindex="-1"></a>    outputs<span class="op">=</span>[<span class="st">"scaled_scores"</span>],</span>
<span id="cb17-111"><a href="#cb17-111" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"scale_scores"</span>,</span>
<span id="cb17-112"><a href="#cb17-112" aria-hidden="true" tabindex="-1"></a>))</span>
<span id="cb17-113"><a href="#cb17-113" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-114"><a href="#cb17-114" aria-hidden="true" tabindex="-1"></a>nodes.append(helper.make_node(<span class="st">"Softmax"</span>,</span>
<span id="cb17-115"><a href="#cb17-115" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"scaled_scores"</span>],</span>
<span id="cb17-116"><a href="#cb17-116" aria-hidden="true" tabindex="-1"></a>    outputs<span class="op">=</span>[<span class="st">"attn_weights"</span>],</span>
<span id="cb17-117"><a href="#cb17-117" aria-hidden="true" tabindex="-1"></a>    axis<span class="op">=-</span><span class="dv">1</span>,</span>
<span id="cb17-118"><a href="#cb17-118" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"attn_softmax"</span>,</span>
<span id="cb17-119"><a href="#cb17-119" aria-hidden="true" tabindex="-1"></a>))</span>
<span id="cb17-120"><a href="#cb17-120" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-121"><a href="#cb17-121" aria-hidden="true" tabindex="-1"></a>nodes.append(helper.make_node(<span class="st">"MatMul"</span>,</span>
<span id="cb17-122"><a href="#cb17-122" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"attn_weights"</span>, <span class="st">"V_t"</span>],</span>
<span id="cb17-123"><a href="#cb17-123" aria-hidden="true" tabindex="-1"></a>    outputs<span class="op">=</span>[<span class="st">"context_t"</span>],</span>
<span id="cb17-124"><a href="#cb17-124" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"attn_context"</span>,</span>
<span id="cb17-125"><a href="#cb17-125" aria-hidden="true" tabindex="-1"></a>))</span>
<span id="cb17-126"><a href="#cb17-126" aria-hidden="true" tabindex="-1"></a><span class="co"># --- Merge heads: [batch, n_heads, seq, d_k] → [batch, seq, d_model] ---</span></span>
<span id="cb17-127"><a href="#cb17-127" aria-hidden="true" tabindex="-1"></a>nodes.append(helper.make_node(<span class="st">"Transpose"</span>,</span>
<span id="cb17-128"><a href="#cb17-128" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"context_t"</span>],</span>
<span id="cb17-129"><a href="#cb17-129" aria-hidden="true" tabindex="-1"></a>    outputs<span class="op">=</span>[<span class="st">"context_h"</span>],</span>
<span id="cb17-130"><a href="#cb17-130" aria-hidden="true" tabindex="-1"></a>    perm<span class="op">=</span>[<span class="dv">0</span>, <span class="dv">2</span>, <span class="dv">1</span>, <span class="dv">3</span>],</span>
<span id="cb17-131"><a href="#cb17-131" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"transpose_context"</span>,</span>
<span id="cb17-132"><a href="#cb17-132" aria-hidden="true" tabindex="-1"></a>))</span>
<span id="cb17-133"><a href="#cb17-133" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-134"><a href="#cb17-134" aria-hidden="true" tabindex="-1"></a>nodes.append(helper.make_node(<span class="st">"Reshape"</span>,</span>
<span id="cb17-135"><a href="#cb17-135" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"context_h"</span>, <span class="st">"shape_restore"</span>],</span>
<span id="cb17-136"><a href="#cb17-136" aria-hidden="true" tabindex="-1"></a>    outputs<span class="op">=</span>[<span class="st">"context"</span>],</span>
<span id="cb17-137"><a href="#cb17-137" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"reshape_context"</span>,</span>
<span id="cb17-138"><a href="#cb17-138" aria-hidden="true" tabindex="-1"></a>))</span>
<span id="cb17-139"><a href="#cb17-139" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-140"><a href="#cb17-140" aria-hidden="true" tabindex="-1"></a>nodes.append(helper.make_node(<span class="st">"MatMul"</span>,</span>
<span id="cb17-141"><a href="#cb17-141" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"context"</span>, <span class="st">"W_O"</span>],</span>
<span id="cb17-142"><a href="#cb17-142" aria-hidden="true" tabindex="-1"></a>    outputs<span class="op">=</span>[<span class="st">"attn_out_proj"</span>],</span>
<span id="cb17-143"><a href="#cb17-143" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"output_proj"</span>,</span>
<span id="cb17-144"><a href="#cb17-144" aria-hidden="true" tabindex="-1"></a>))</span>
<span id="cb17-145"><a href="#cb17-145" aria-hidden="true" tabindex="-1"></a>nodes.append(helper.make_node(<span class="st">"Add"</span>,</span>
<span id="cb17-146"><a href="#cb17-146" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"attn_out_proj"</span>, <span class="st">"b_O"</span>],</span>
<span id="cb17-147"><a href="#cb17-147" aria-hidden="true" tabindex="-1"></a>    outputs<span class="op">=</span>[<span class="st">"attn_out"</span>],</span>
<span id="cb17-148"><a href="#cb17-148" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"add_output_bias"</span>,</span>
<span id="cb17-149"><a href="#cb17-149" aria-hidden="true" tabindex="-1"></a>))</span>
<span id="cb17-150"><a href="#cb17-150" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-151"><a href="#cb17-151" aria-hidden="true" tabindex="-1"></a><span class="co"># Residual + LayerNorm</span></span>
<span id="cb17-152"><a href="#cb17-152" aria-hidden="true" tabindex="-1"></a>nodes.append(helper.make_node(<span class="st">"Add"</span>,</span>
<span id="cb17-153"><a href="#cb17-153" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"x"</span>, <span class="st">"attn_out"</span>],</span>
<span id="cb17-154"><a href="#cb17-154" aria-hidden="true" tabindex="-1"></a>    outputs<span class="op">=</span>[<span class="st">"residual1"</span>],</span>
<span id="cb17-155"><a href="#cb17-155" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"residual1"</span>,</span>
<span id="cb17-156"><a href="#cb17-156" aria-hidden="true" tabindex="-1"></a>))</span>
<span id="cb17-157"><a href="#cb17-157" aria-hidden="true" tabindex="-1"></a>nodes.append(helper.make_node(<span class="st">"LayerNormalization"</span>,</span>
<span id="cb17-158"><a href="#cb17-158" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"residual1"</span>, <span class="st">"ln1_scale"</span>, <span class="st">"ln1_bias"</span>],</span>
<span id="cb17-159"><a href="#cb17-159" aria-hidden="true" tabindex="-1"></a>    outputs<span class="op">=</span>[<span class="st">"ln1_out"</span>],</span>
<span id="cb17-160"><a href="#cb17-160" aria-hidden="true" tabindex="-1"></a>    axis<span class="op">=-</span><span class="dv">1</span>,</span>
<span id="cb17-161"><a href="#cb17-161" aria-hidden="true" tabindex="-1"></a>    epsilon<span class="op">=</span>eps,</span>
<span id="cb17-162"><a href="#cb17-162" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"layernorm1"</span>,</span>
<span id="cb17-163"><a href="#cb17-163" aria-hidden="true" tabindex="-1"></a>))</span>
<span id="cb17-164"><a href="#cb17-164" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-165"><a href="#cb17-165" aria-hidden="true" tabindex="-1"></a><span class="co"># ================================================================== #</span></span>
<span id="cb17-166"><a href="#cb17-166" aria-hidden="true" tabindex="-1"></a><span class="co"># FEED-FORWARD NETWORK                                                #</span></span>
<span id="cb17-167"><a href="#cb17-167" aria-hidden="true" tabindex="-1"></a><span class="co"># ================================================================== #</span></span>
<span id="cb17-168"><a href="#cb17-168" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-169"><a href="#cb17-169" aria-hidden="true" tabindex="-1"></a>nodes.append(helper.make_node(<span class="st">"MatMul"</span>,</span>
<span id="cb17-170"><a href="#cb17-170" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"ln1_out"</span>, <span class="st">"W_ff1"</span>],</span>
<span id="cb17-171"><a href="#cb17-171" aria-hidden="true" tabindex="-1"></a>    outputs<span class="op">=</span>[<span class="st">"ff1_proj"</span>],</span>
<span id="cb17-172"><a href="#cb17-172" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"ff1_proj"</span>,</span>
<span id="cb17-173"><a href="#cb17-173" aria-hidden="true" tabindex="-1"></a>))</span>
<span id="cb17-174"><a href="#cb17-174" aria-hidden="true" tabindex="-1"></a>nodes.append(helper.make_node(<span class="st">"Add"</span>,</span>
<span id="cb17-175"><a href="#cb17-175" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"ff1_proj"</span>, <span class="st">"b_ff1"</span>],</span>
<span id="cb17-176"><a href="#cb17-176" aria-hidden="true" tabindex="-1"></a>    outputs<span class="op">=</span>[<span class="st">"ff1"</span>],</span>
<span id="cb17-177"><a href="#cb17-177" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"ff1_bias"</span>,</span>
<span id="cb17-178"><a href="#cb17-178" aria-hidden="true" tabindex="-1"></a>))</span>
<span id="cb17-179"><a href="#cb17-179" aria-hidden="true" tabindex="-1"></a>nodes.append(helper.make_node(<span class="st">"Relu"</span>,</span>
<span id="cb17-180"><a href="#cb17-180" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"ff1"</span>],</span>
<span id="cb17-181"><a href="#cb17-181" aria-hidden="true" tabindex="-1"></a>    outputs<span class="op">=</span>[<span class="st">"ff1_relu"</span>],</span>
<span id="cb17-182"><a href="#cb17-182" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"ff1_relu"</span>,</span>
<span id="cb17-183"><a href="#cb17-183" aria-hidden="true" tabindex="-1"></a>))</span>
<span id="cb17-184"><a href="#cb17-184" aria-hidden="true" tabindex="-1"></a>nodes.append(helper.make_node(<span class="st">"MatMul"</span>,</span>
<span id="cb17-185"><a href="#cb17-185" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"ff1_relu"</span>, <span class="st">"W_ff2"</span>],</span>
<span id="cb17-186"><a href="#cb17-186" aria-hidden="true" tabindex="-1"></a>    outputs<span class="op">=</span>[<span class="st">"ff2_proj"</span>],</span>
<span id="cb17-187"><a href="#cb17-187" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"ff2_proj"</span>,</span>
<span id="cb17-188"><a href="#cb17-188" aria-hidden="true" tabindex="-1"></a>))</span>
<span id="cb17-189"><a href="#cb17-189" aria-hidden="true" tabindex="-1"></a>nodes.append(helper.make_node(<span class="st">"Add"</span>,</span>
<span id="cb17-190"><a href="#cb17-190" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"ff2_proj"</span>, <span class="st">"b_ff2"</span>],</span>
<span id="cb17-191"><a href="#cb17-191" aria-hidden="true" tabindex="-1"></a>    outputs<span class="op">=</span>[<span class="st">"ff2"</span>],</span>
<span id="cb17-192"><a href="#cb17-192" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"ff2_bias"</span>,</span>
<span id="cb17-193"><a href="#cb17-193" aria-hidden="true" tabindex="-1"></a>))</span>
<span id="cb17-194"><a href="#cb17-194" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-195"><a href="#cb17-195" aria-hidden="true" tabindex="-1"></a><span class="co"># Residual + LayerNorm</span></span>
<span id="cb17-196"><a href="#cb17-196" aria-hidden="true" tabindex="-1"></a>nodes.append(helper.make_node(<span class="st">"Add"</span>,</span>
<span id="cb17-197"><a href="#cb17-197" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"ln1_out"</span>, <span class="st">"ff2"</span>],</span>
<span id="cb17-198"><a href="#cb17-198" aria-hidden="true" tabindex="-1"></a>    outputs<span class="op">=</span>[<span class="st">"residual2"</span>],</span>
<span id="cb17-199"><a href="#cb17-199" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"residual2"</span>,</span>
<span id="cb17-200"><a href="#cb17-200" aria-hidden="true" tabindex="-1"></a>))</span>
<span id="cb17-201"><a href="#cb17-201" aria-hidden="true" tabindex="-1"></a>nodes.append(helper.make_node(<span class="st">"LayerNormalization"</span>,</span>
<span id="cb17-202"><a href="#cb17-202" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"residual2"</span>, <span class="st">"ln2_scale"</span>, <span class="st">"ln2_bias"</span>],</span>
<span id="cb17-203"><a href="#cb17-203" aria-hidden="true" tabindex="-1"></a>    outputs<span class="op">=</span>[<span class="st">"output"</span>],</span>
<span id="cb17-204"><a href="#cb17-204" aria-hidden="true" tabindex="-1"></a>    axis<span class="op">=-</span><span class="dv">1</span>,</span>
<span id="cb17-205"><a href="#cb17-205" aria-hidden="true" tabindex="-1"></a>    epsilon<span class="op">=</span>eps,</span>
<span id="cb17-206"><a href="#cb17-206" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"layernorm2"</span>,</span>
<span id="cb17-207"><a href="#cb17-207" aria-hidden="true" tabindex="-1"></a>))</span>
<span id="cb17-208"><a href="#cb17-208" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-209"><a href="#cb17-209" aria-hidden="true" tabindex="-1"></a><span class="co"># ================================================================== #</span></span>
<span id="cb17-210"><a href="#cb17-210" aria-hidden="true" tabindex="-1"></a><span class="co"># Graph assembly                                                      #</span></span>
<span id="cb17-211"><a href="#cb17-211" aria-hidden="true" tabindex="-1"></a><span class="co"># ================================================================== #</span></span>
<span id="cb17-212"><a href="#cb17-212" aria-hidden="true" tabindex="-1"></a>x_info   <span class="op">=</span> helper.make_tensor_value_info(<span class="st">"x"</span>,      TensorProto.FLOAT, [<span class="st">"batch"</span>, seq_len, d_model])</span>
<span id="cb17-213"><a href="#cb17-213" aria-hidden="true" tabindex="-1"></a>out_info <span class="op">=</span> helper.make_tensor_value_info(<span class="st">"output"</span>, TensorProto.FLOAT, [<span class="st">"batch"</span>, seq_len, d_model])</span>
<span id="cb17-214"><a href="#cb17-214" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-215"><a href="#cb17-215" aria-hidden="true" tabindex="-1"></a>graph <span class="op">=</span> helper.make_graph(nodes, <span class="st">"transformer_encoder_block"</span>,</span>
<span id="cb17-216"><a href="#cb17-216" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[x_info], outputs<span class="op">=</span>[out_info], initializer<span class="op">=</span>inits)</span>
<span id="cb17-217"><a href="#cb17-217" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> helper.make_model(graph, opset_imports<span class="op">=</span>[helper.make_opsetid(<span class="st">""</span>, <span class="dv">17</span>)])</span>
<span id="cb17-218"><a href="#cb17-218" aria-hidden="true" tabindex="-1"></a>model.ir_version <span class="op">=</span> <span class="dv">8</span></span>
<span id="cb17-219"><a href="#cb17-219" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-220"><a href="#cb17-220" aria-hidden="true" tabindex="-1"></a>onnx.checker.check_model(model)</span>
<span id="cb17-221"><a href="#cb17-221" aria-hidden="true" tabindex="-1"></a>onnx.save(model, <span class="st">"transformer_block.onnx"</span>)</span>
<span id="cb17-222"><a href="#cb17-222" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"Transformer block saved."</span>)</span></code></pre></div></div>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Transformer-specific ONNX patterns
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
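<li><strong>3D MatMul</strong>: ONNX’s <code>MatMul</code> follows <code>numpy.matmul</code> semantics. When one operand is a 2D weight <code>[d_model, d_model]</code> and the other is a 3D activation <code>[batch, seq, d_model]</code>, the weight is broadcast across the batch dimension automatically, so no <code>Reshape</code> or <code>Expand</code> is needed.</li>
<li><strong>Reshape + Transpose for multi-head attention</strong>: Head splitting is entirely explicit. A <code>Reshape</code> turns the projected Q/K/V from <code>[batch, seq, d_model]</code> into <code>[batch, seq, n_heads, d_k]</code>, and a <code>Transpose</code> with <code>perm=[0, 2, 1, 3]</code> moves the head axis to the second position so that <code>MatMul</code> batches over <code>[batch, n_heads]</code>. The inverse Transpose + Reshape pair merges the heads back after attention.</li>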
<li><strong><code>LayerNormalization</code></strong>: Available from opset 17. It takes <code>scale</code> and <code>bias</code> as separate inputs (not attributes), and normalizes along all axes from <code>axis</code> to the last.</li>
<li><strong>Broadcasting of the scale scalar</strong>: The <code>attn_scale</code> constant is a scalar <code>np.float32</code> value. ONNX’s <code>Mul</code> operator broadcasts it across the entire <code>[batch, heads, seq, seq]</code> scores tensor without any reshape.</li>
</ul>
</div>
</div>
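<p>The head-splitting pattern and the <code>LayerNormalization</code> semantics above can be sanity-checked outside ONNX. The following is a minimal NumPy sketch, not part of the exported graph: the helper names (<code>split_heads</code>, <code>merge_heads</code>, <code>attention_batched</code>, <code>attention_loop</code>, <code>layer_norm</code>) and the toy sizes are illustrative assumptions, and <code>attn_scale</code> is assumed to be <code>1/sqrt(d_k)</code>. It mirrors the node order (Reshape, Transpose, Q @ K<sup>T</sup>, scale, Softmax, @ V, Transpose, Reshape) and compares the result with a naive per-head loop.</p>

```python
import numpy as np

def split_heads(t, n_heads):
    # Reshape [batch, seq, d_model] -> [batch, seq, n_heads, d_k],
    # then Transpose perm=[0, 2, 1, 3] -> [batch, n_heads, seq, d_k]
    b, s, d = t.shape
    return t.reshape(b, s, n_heads, d // n_heads).transpose(0, 2, 1, 3)

def merge_heads(t):
    # Transpose perm=[0, 2, 1, 3] -> [batch, seq, n_heads, d_k],
    # then Reshape -> [batch, seq, d_model]
    b, h, s, dk = t.shape
    return t.transpose(0, 2, 1, 3).reshape(b, s, h * dk)

def softmax(z):
    # Numerically stable softmax over the last axis (Softmax with axis=-1)
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def attention_batched(q, k, v, n_heads):
    # Same order as the ONNX nodes; attn_scale is ASSUMED to be 1/sqrt(d_k)
    qh, kh, vh = (split_heads(t, n_heads) for t in (q, k, v))
    scale = 1.0 / np.sqrt(qh.shape[-1])
    scores = (qh @ kh.transpose(0, 1, 3, 2)) * scale   # K^T via perm=[0, 1, 3, 2]
    return merge_heads(softmax(scores) @ vh)

def attention_loop(q, k, v, n_heads):
    # Reference: each head owns a contiguous slice of d_k columns
    b, s, d = q.shape
    dk = d // n_heads
    out = np.zeros_like(q)
    for i in range(b):
        for h in range(n_heads):
            cols = slice(h * dk, (h + 1) * dk)
            w = softmax(q[i, :, cols] @ k[i, :, cols].T / np.sqrt(dk))
            out[i, :, cols] = w @ v[i, :, cols]
    return out

def layer_norm(x, scale, bias, eps=1e-5):
    # LayerNormalization with axis=-1: statistics over the last axis only
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps) * scale + bias

rng = np.random.default_rng(0)
batch, seq, d_model, n_heads = 2, 5, 16, 4
x = rng.standard_normal((batch, seq, d_model))
W_Q, W_K, W_V = (rng.standard_normal((d_model, d_model)) for _ in range(3))

# A 2D weight broadcasts over the batch axis, as ONNX MatMul does
q, k, v = x @ W_Q, x @ W_K, x @ W_V
print(q.shape)  # (2, 5, 16)
print(np.allclose(attention_batched(q, k, v, n_heads),
                  attention_loop(q, k, v, n_heads)))  # True

y = layer_norm(x, np.ones(d_model), np.zeros(d_model))
print(np.allclose(y.mean(axis=-1), 0.0))  # True
```

<p>If the two attention functions ever disagree, the usual culprit is a wrong <code>perm</code> or a reshape that splits <code>d_model</code> as <code>[d_k, n_heads]</code> instead of <code>[n_heads, d_k]</code>, which interleaves the heads’ columns.</p>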
</section>
<section id="building-a-residual-resnet-style-block" class="level2">
<h2 class="anchored" data-anchor-id="building-a-residual-resnet-style-block" id="building-a-residual-resnet-style-block">Building a Residual (ResNet-style) Block</h2>
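<p>Residual connections help very deep networks train by giving each block an identity path around its transformation. ONNX has no dedicated residual operator: a skip connection is simply an <code>Add</code> node whose inputs are the block’s output and a tensor produced earlier in the graph (the block input, or a projection of it when the shapes differ).</p>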
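<p>Numerically, the skip connection is just an elementwise sum; the only subtlety is making the shapes agree. The NumPy sketch below is illustrative only: <code>conv1x1</code> and <code>residual_add</code> are hypothetical helpers, not ONNX APIs, and the tensor sizes are arbitrary. It shows the identity shortcut and the strided 1×1 projection shortcut needed when the channel count or stride changes. Because ONNX’s <code>Add</code> applies NumPy-style broadcasting, the sketch checks shapes explicitly so that a missing projection fails loudly.</p>

```python
import numpy as np

def conv1x1(x, w, stride=1):
    # Hypothetical helper: a 1x1 Conv without padding is a per-pixel channel matmul.
    # x: [N, C_in, H, W], w: [C_out, C_in]; stride keeps every stride-th pixel.
    x = x[:, :, ::stride, ::stride]
    return np.einsum("oc,nchw->nohw", w, x)

def residual_add(x, branch_out, proj_w=None, stride=1):
    # Hypothetical helper mirroring Add(shortcut, branch_out).
    # Identity shortcut when shapes match; otherwise a 1x1 projection of x.
    shortcut = x if proj_w is None else conv1x1(x, proj_w, stride)
    if shortcut.shape != branch_out.shape:
        raise ValueError(f"shape mismatch: {shortcut.shape} vs {branch_out.shape}")
    return shortcut + branch_out

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 64, 8, 8))            # block input, NCHW

# 64 -> 64 channels, stride 1: identity shortcut
f_same = rng.standard_normal((2, 64, 8, 8))       # stand-in for the conv branch output
print(residual_add(x, f_same).shape)              # (2, 64, 8, 8)

# 64 -> 128 channels, stride 2: a 3x3 branch with pad 1 halves H and W,
# so the shortcut needs a strided 1x1 projection to match
f_down = rng.standard_normal((2, 128, 4, 4))
w_proj = rng.standard_normal((128, 64)) * np.sqrt(2.0 / 64)
print(residual_add(x, f_down, proj_w=w_proj, stride=2).shape)  # (2, 128, 4, 4)
```

<p>In an ONNX graph the same projection would be a <code>Conv</code> with <code>kernel_shape=[1, 1]</code>, the block’s stride, and no padding.</p>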
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb18"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb18-1"><a href="#cb18-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb18-2"><a href="#cb18-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> onnx</span>
<span id="cb18-3"><a href="#cb18-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> onnx <span class="im">import</span> helper, TensorProto, numpy_helper</span>
<span id="cb18-4"><a href="#cb18-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-5"><a href="#cb18-5" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> make_conv_bn_relu(x_name, out_name, in_ch, out_ch, stride, inits, nodes, kH<span class="op">=</span><span class="dv">3</span>, kW<span class="op">=</span><span class="dv">3</span>):</span>
<span id="cb18-6"><a href="#cb18-6" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Adds Conv → BN → ReLU nodes and their initializers in-place."""</span></span>
<span id="cb18-7"><a href="#cb18-7" aria-hidden="true" tabindex="-1"></a>    fan_in  <span class="op">=</span> in_ch <span class="op">*</span> kH <span class="op">*</span> kW</span>
<span id="cb18-8"><a href="#cb18-8" aria-hidden="true" tabindex="-1"></a>    W_data  <span class="op">=</span> np.random.randn(out_ch, in_ch, kH, kW).astype(np.float32) <span class="op">*</span> np.sqrt(<span class="fl">2.0</span> <span class="op">/</span> fan_in)</span>
<span id="cb18-9"><a href="#cb18-9" aria-hidden="true" tabindex="-1"></a>    W_init  <span class="op">=</span> numpy_helper.from_array(W_data, name<span class="op">=</span><span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_cW"</span>)</span>
<span id="cb18-10"><a href="#cb18-10" aria-hidden="true" tabindex="-1"></a>    b_init  <span class="op">=</span> numpy_helper.from_array(np.zeros(out_ch, dtype<span class="op">=</span>np.float32), name<span class="op">=</span><span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_cb"</span>)</span>
<span id="cb18-11"><a href="#cb18-11" aria-hidden="true" tabindex="-1"></a>    sc_init <span class="op">=</span> numpy_helper.from_array(np.ones(out_ch,  dtype<span class="op">=</span>np.float32), name<span class="op">=</span><span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_bns"</span>)</span>
<span id="cb18-12"><a href="#cb18-12" aria-hidden="true" tabindex="-1"></a>    bi_init <span class="op">=</span> numpy_helper.from_array(np.zeros(out_ch, dtype<span class="op">=</span>np.float32), name<span class="op">=</span><span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_bnb"</span>)</span>
<span id="cb18-13"><a href="#cb18-13" aria-hidden="true" tabindex="-1"></a>    mn_init <span class="op">=</span> numpy_helper.from_array(np.zeros(out_ch, dtype<span class="op">=</span>np.float32), name<span class="op">=</span><span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_bnm"</span>)</span>
<span id="cb18-14"><a href="#cb18-14" aria-hidden="true" tabindex="-1"></a>    vr_init <span class="op">=</span> numpy_helper.from_array(np.ones(out_ch,  dtype<span class="op">=</span>np.float32), name<span class="op">=</span><span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_bnv"</span>)</span>
<span id="cb18-15"><a href="#cb18-15" aria-hidden="true" tabindex="-1"></a>    inits <span class="op">+=</span> [W_init, b_init, sc_init, bi_init, mn_init, vr_init]</span>
<span id="cb18-16"><a href="#cb18-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-17"><a href="#cb18-17" aria-hidden="true" tabindex="-1"></a>    conv_out <span class="op">=</span> <span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_c"</span></span>
<span id="cb18-18"><a href="#cb18-18" aria-hidden="true" tabindex="-1"></a>    bn_out   <span class="op">=</span> <span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_bn"</span></span>
<span id="cb18-19"><a href="#cb18-19" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-20"><a href="#cb18-20" aria-hidden="true" tabindex="-1"></a>    nodes.append(helper.make_node(<span class="st">"Conv"</span>,</span>
<span id="cb18-21"><a href="#cb18-21" aria-hidden="true" tabindex="-1"></a>        inputs<span class="op">=</span>[x_name, <span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_cW"</span>, <span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_cb"</span>],</span>
<span id="cb18-22"><a href="#cb18-22" aria-hidden="true" tabindex="-1"></a>        outputs<span class="op">=</span>[conv_out],</span>
<span id="cb18-23"><a href="#cb18-23" aria-hidden="true" tabindex="-1"></a>        kernel_shape<span class="op">=</span>[kH, kW],</span>
<span id="cb18-24"><a href="#cb18-24" aria-hidden="true" tabindex="-1"></a>        strides<span class="op">=</span>[stride, stride],</span>
<span id="cb18-25"><a href="#cb18-25" aria-hidden="true" tabindex="-1"></a>        pads<span class="op">=</span>[kH<span class="op">//</span><span class="dv">2</span>, kW<span class="op">//</span><span class="dv">2</span>, kH<span class="op">//</span><span class="dv">2</span>, kW<span class="op">//</span><span class="dv">2</span>],</span>
<span id="cb18-26"><a href="#cb18-26" aria-hidden="true" tabindex="-1"></a>        name<span class="op">=</span><span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_conv"</span>,</span>
<span id="cb18-27"><a href="#cb18-27" aria-hidden="true" tabindex="-1"></a>    ))</span>
<span id="cb18-28"><a href="#cb18-28" aria-hidden="true" tabindex="-1"></a>    nodes.append(helper.make_node(<span class="st">"BatchNormalization"</span>,</span>
<span id="cb18-29"><a href="#cb18-29" aria-hidden="true" tabindex="-1"></a>        inputs<span class="op">=</span>[conv_out, <span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_bns"</span>, <span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_bnb"</span>,</span>
<span id="cb18-30"><a href="#cb18-30" aria-hidden="true" tabindex="-1"></a>                <span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_bnm"</span>, <span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_bnv"</span>],</span>
<span id="cb18-31"><a href="#cb18-31" aria-hidden="true" tabindex="-1"></a>        outputs<span class="op">=</span>[bn_out],</span>
<span id="cb18-32"><a href="#cb18-32" aria-hidden="true" tabindex="-1"></a>        epsilon<span class="op">=</span><span class="fl">1e-5</span>, momentum<span class="op">=</span><span class="fl">0.9</span>,  <span class="co"># ONNX convention, equal to PyTorch momentum=0.1 (only used in training mode)</span></span>
<span id="cb18-33"><a href="#cb18-33" aria-hidden="true" tabindex="-1"></a>        name<span class="op">=</span><span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_bn"</span>,</span>
<span id="cb18-34"><a href="#cb18-34" aria-hidden="true" tabindex="-1"></a>    ))</span>
<span id="cb18-35"><a href="#cb18-35" aria-hidden="true" tabindex="-1"></a>    nodes.append(helper.make_node(<span class="st">"Relu"</span>,</span>
<span id="cb18-36"><a href="#cb18-36" aria-hidden="true" tabindex="-1"></a>        inputs<span class="op">=</span>[bn_out],</span>
<span id="cb18-37"><a href="#cb18-37" aria-hidden="true" tabindex="-1"></a>        outputs<span class="op">=</span>[out_name],</span>
<span id="cb18-38"><a href="#cb18-38" aria-hidden="true" tabindex="-1"></a>        name<span class="op">=</span><span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_relu"</span>,</span>
<span id="cb18-39"><a href="#cb18-39" aria-hidden="true" tabindex="-1"></a>    ))</span>
<span id="cb18-40"><a href="#cb18-40" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-41"><a href="#cb18-41" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> make_residual_block(x_name, out_name, in_ch, out_ch, stride, inits, nodes):</span>
<span id="cb18-42"><a href="#cb18-42" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb18-43"><a href="#cb18-43" aria-hidden="true" tabindex="-1"></a><span class="co">    A basic ResNet residual block.</span></span>
<span id="cb18-44"><a href="#cb18-44" aria-hidden="true" tabindex="-1"></a><span class="co">    If in_ch != out_ch or stride != 1, a 1x1 projection shortcut is added.</span></span>
<span id="cb18-45"><a href="#cb18-45" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb18-46"><a href="#cb18-46" aria-hidden="true" tabindex="-1"></a>    mid_name <span class="op">=</span> <span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_mid"</span></span>
<span id="cb18-47"><a href="#cb18-47" aria-hidden="true" tabindex="-1"></a>    make_conv_bn_relu(x_name, mid_name, in_ch, out_ch, stride, inits, nodes)</span>
<span id="cb18-48"><a href="#cb18-48" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-49"><a href="#cb18-49" aria-hidden="true" tabindex="-1"></a>    fan_in  <span class="op">=</span> out_ch <span class="op">*</span> <span class="dv">3</span> <span class="op">*</span> <span class="dv">3</span></span>
<span id="cb18-50"><a href="#cb18-50" aria-hidden="true" tabindex="-1"></a>    W2_data <span class="op">=</span> np.random.randn(out_ch, out_ch, <span class="dv">3</span>, <span class="dv">3</span>).astype(np.float32) <span class="op">*</span> np.sqrt(<span class="fl">2.0</span> <span class="op">/</span> fan_in)</span>
<span id="cb18-51"><a href="#cb18-51" aria-hidden="true" tabindex="-1"></a>    W2_init <span class="op">=</span> numpy_helper.from_array(W2_data, name<span class="op">=</span><span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_c2W"</span>)</span>
<span id="cb18-52"><a href="#cb18-52" aria-hidden="true" tabindex="-1"></a>    b2_init <span class="op">=</span> numpy_helper.from_array(np.zeros(out_ch, dtype<span class="op">=</span>np.float32), name<span class="op">=</span><span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_c2b"</span>)</span>
<span id="cb18-53"><a href="#cb18-53" aria-hidden="true" tabindex="-1"></a>    sc2     <span class="op">=</span> numpy_helper.from_array(np.ones(out_ch,  dtype<span class="op">=</span>np.float32), name<span class="op">=</span><span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_bn2s"</span>)</span>
<span id="cb18-54"><a href="#cb18-54" aria-hidden="true" tabindex="-1"></a>    bi2     <span class="op">=</span> numpy_helper.from_array(np.zeros(out_ch, dtype<span class="op">=</span>np.float32), name<span class="op">=</span><span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_bn2b"</span>)</span>
<span id="cb18-55"><a href="#cb18-55" aria-hidden="true" tabindex="-1"></a>    mn2     <span class="op">=</span> numpy_helper.from_array(np.zeros(out_ch, dtype<span class="op">=</span>np.float32), name<span class="op">=</span><span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_bn2m"</span>)</span>
<span id="cb18-56"><a href="#cb18-56" aria-hidden="true" tabindex="-1"></a>    vr2     <span class="op">=</span> numpy_helper.from_array(np.ones(out_ch,  dtype<span class="op">=</span>np.float32), name<span class="op">=</span><span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_bn2v"</span>)</span>
<span id="cb18-57"><a href="#cb18-57" aria-hidden="true" tabindex="-1"></a>    inits  <span class="op">+=</span> [W2_init, b2_init, sc2, bi2, mn2, vr2]</span>
<span id="cb18-58"><a href="#cb18-58" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-59"><a href="#cb18-59" aria-hidden="true" tabindex="-1"></a>    conv2_out <span class="op">=</span> <span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_c2"</span></span>
<span id="cb18-60"><a href="#cb18-60" aria-hidden="true" tabindex="-1"></a>    bn2_out   <span class="op">=</span> <span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_bn2"</span></span>
<span id="cb18-61"><a href="#cb18-61" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-62"><a href="#cb18-62" aria-hidden="true" tabindex="-1"></a>    nodes.append(helper.make_node(<span class="st">"Conv"</span>,</span>
<span id="cb18-63"><a href="#cb18-63" aria-hidden="true" tabindex="-1"></a>        inputs<span class="op">=</span>[mid_name, <span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_c2W"</span>, <span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_c2b"</span>],</span>
<span id="cb18-64"><a href="#cb18-64" aria-hidden="true" tabindex="-1"></a>        outputs<span class="op">=</span>[conv2_out],</span>
<span id="cb18-65"><a href="#cb18-65" aria-hidden="true" tabindex="-1"></a>        kernel_shape<span class="op">=</span>[<span class="dv">3</span>, <span class="dv">3</span>], strides<span class="op">=</span>[<span class="dv">1</span>, <span class="dv">1</span>], pads<span class="op">=</span>[<span class="dv">1</span>, <span class="dv">1</span>, <span class="dv">1</span>, <span class="dv">1</span>],</span>
<span id="cb18-66"><a href="#cb18-66" aria-hidden="true" tabindex="-1"></a>        name<span class="op">=</span><span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_conv2"</span>,</span>
<span id="cb18-67"><a href="#cb18-67" aria-hidden="true" tabindex="-1"></a>    ))</span>
<span id="cb18-68"><a href="#cb18-68" aria-hidden="true" tabindex="-1"></a>    nodes.append(helper.make_node(<span class="st">"BatchNormalization"</span>,</span>
<span id="cb18-69"><a href="#cb18-69" aria-hidden="true" tabindex="-1"></a>        inputs<span class="op">=</span>[conv2_out, <span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_bn2s"</span>, <span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_bn2b"</span>,</span>
<span id="cb18-70"><a href="#cb18-70" aria-hidden="true" tabindex="-1"></a>                <span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_bn2m"</span>, <span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_bn2v"</span>],</span>
<span id="cb18-71"><a href="#cb18-71" aria-hidden="true" tabindex="-1"></a>        outputs<span class="op">=</span>[bn2_out],</span>
<span id="cb18-72"><a href="#cb18-72" aria-hidden="true" tabindex="-1"></a>        epsilon<span class="op">=</span><span class="fl">1e-5</span>, momentum<span class="op">=</span><span class="fl">0.1</span>,</span>
<span id="cb18-73"><a href="#cb18-73" aria-hidden="true" tabindex="-1"></a>        name<span class="op">=</span><span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_bn2"</span>,</span>
<span id="cb18-74"><a href="#cb18-74" aria-hidden="true" tabindex="-1"></a>    ))</span>
<span id="cb18-75"><a href="#cb18-75" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-76"><a href="#cb18-76" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> in_ch <span class="op">!=</span> out_ch <span class="kw">or</span> stride <span class="op">!=</span> <span class="dv">1</span>:</span>
<span id="cb18-77"><a href="#cb18-77" aria-hidden="true" tabindex="-1"></a>        Ws_data <span class="op">=</span> np.random.randn(out_ch, in_ch, <span class="dv">1</span>, <span class="dv">1</span>).astype(np.float32) <span class="op">*</span> np.sqrt(<span class="fl">2.0</span> <span class="op">/</span> in_ch)</span>
<span id="cb18-78"><a href="#cb18-78" aria-hidden="true" tabindex="-1"></a>        Ws_init <span class="op">=</span> numpy_helper.from_array(Ws_data, name<span class="op">=</span><span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_sW"</span>)</span>
<span id="cb18-79"><a href="#cb18-79" aria-hidden="true" tabindex="-1"></a>        bs_init <span class="op">=</span> numpy_helper.from_array(np.zeros(out_ch, dtype<span class="op">=</span>np.float32), name<span class="op">=</span><span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_sb"</span>)</span>
<span id="cb18-80"><a href="#cb18-80" aria-hidden="true" tabindex="-1"></a>        inits  <span class="op">+=</span> [Ws_init, bs_init]</span>
<span id="cb18-81"><a href="#cb18-81" aria-hidden="true" tabindex="-1"></a>        shortcut_name <span class="op">=</span> <span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_shortcut"</span></span>
<span id="cb18-82"><a href="#cb18-82" aria-hidden="true" tabindex="-1"></a>        nodes.append(helper.make_node(<span class="st">"Conv"</span>,</span>
<span id="cb18-83"><a href="#cb18-83" aria-hidden="true" tabindex="-1"></a>            inputs<span class="op">=</span>[x_name, <span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_sW"</span>, <span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_sb"</span>],</span>
<span id="cb18-84"><a href="#cb18-84" aria-hidden="true" tabindex="-1"></a>            outputs<span class="op">=</span>[shortcut_name],</span>
<span id="cb18-85"><a href="#cb18-85" aria-hidden="true" tabindex="-1"></a>            kernel_shape<span class="op">=</span>[<span class="dv">1</span>, <span class="dv">1</span>], strides<span class="op">=</span>[stride, stride], pads<span class="op">=</span>[<span class="dv">0</span>, <span class="dv">0</span>, <span class="dv">0</span>, <span class="dv">0</span>],</span>
<span id="cb18-86"><a href="#cb18-86" aria-hidden="true" tabindex="-1"></a>            name<span class="op">=</span><span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_shortcut_conv"</span>,</span>
<span id="cb18-87"><a href="#cb18-87" aria-hidden="true" tabindex="-1"></a>        ))</span>
<span id="cb18-88"><a href="#cb18-88" aria-hidden="true" tabindex="-1"></a>    <span class="cf">else</span>:</span>
<span id="cb18-89"><a href="#cb18-89" aria-hidden="true" tabindex="-1"></a>        shortcut_name <span class="op">=</span> x_name</span>
<span id="cb18-90"><a href="#cb18-90" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-91"><a href="#cb18-91" aria-hidden="true" tabindex="-1"></a>    nodes.append(helper.make_node(<span class="st">"Add"</span>,</span>
<span id="cb18-92"><a href="#cb18-92" aria-hidden="true" tabindex="-1"></a>        inputs<span class="op">=</span>[bn2_out, shortcut_name],</span>
<span id="cb18-93"><a href="#cb18-93" aria-hidden="true" tabindex="-1"></a>        outputs<span class="op">=</span>[<span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_sum"</span>],</span>
<span id="cb18-94"><a href="#cb18-94" aria-hidden="true" tabindex="-1"></a>        name<span class="op">=</span><span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_add"</span>,</span>
<span id="cb18-95"><a href="#cb18-95" aria-hidden="true" tabindex="-1"></a>    ))</span>
<span id="cb18-96"><a href="#cb18-96" aria-hidden="true" tabindex="-1"></a>    nodes.append(helper.make_node(<span class="st">"Relu"</span>,</span>
<span id="cb18-97"><a href="#cb18-97" aria-hidden="true" tabindex="-1"></a>        inputs<span class="op">=</span>[<span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_sum"</span>],</span>
<span id="cb18-98"><a href="#cb18-98" aria-hidden="true" tabindex="-1"></a>        outputs<span class="op">=</span>[out_name],</span>
<span id="cb18-99"><a href="#cb18-99" aria-hidden="true" tabindex="-1"></a>        name<span class="op">=</span><span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss">_relu_final"</span>,</span>
<span id="cb18-100"><a href="#cb18-100" aria-hidden="true" tabindex="-1"></a>    ))</span>
<span id="cb18-101"><a href="#cb18-101" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-102"><a href="#cb18-102" aria-hidden="true" tabindex="-1"></a><span class="co"># ------------------------------------------------------------------ #</span></span>
<span id="cb18-103"><a href="#cb18-103" aria-hidden="true" tabindex="-1"></a><span class="co"># Build a tiny ResNet                                                 #</span></span>
<span id="cb18-104"><a href="#cb18-104" aria-hidden="true" tabindex="-1"></a><span class="co"># ------------------------------------------------------------------ #</span></span>
<span id="cb18-105"><a href="#cb18-105" aria-hidden="true" tabindex="-1"></a>inits <span class="op">=</span> []</span>
<span id="cb18-106"><a href="#cb18-106" aria-hidden="true" tabindex="-1"></a>nodes <span class="op">=</span> []</span>
<span id="cb18-107"><a href="#cb18-107" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-108"><a href="#cb18-108" aria-hidden="true" tabindex="-1"></a>make_conv_bn_relu(<span class="st">"x"</span>, <span class="st">"stem_out"</span>, in_ch<span class="op">=</span><span class="dv">3</span>, out_ch<span class="op">=</span><span class="dv">64</span>, stride<span class="op">=</span><span class="dv">2</span>, inits<span class="op">=</span>inits, nodes<span class="op">=</span>nodes, kH<span class="op">=</span><span class="dv">7</span>, kW<span class="op">=</span><span class="dv">7</span>)</span>
<span id="cb18-109"><a href="#cb18-109" aria-hidden="true" tabindex="-1"></a>nodes.append(helper.make_node(<span class="st">"MaxPool"</span>,</span>
<span id="cb18-110"><a href="#cb18-110" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"stem_out"</span>], outputs<span class="op">=</span>[<span class="st">"pool_out"</span>],</span>
<span id="cb18-111"><a href="#cb18-111" aria-hidden="true" tabindex="-1"></a>    kernel_shape<span class="op">=</span>[<span class="dv">3</span>, <span class="dv">3</span>], strides<span class="op">=</span>[<span class="dv">2</span>, <span class="dv">2</span>], pads<span class="op">=</span>[<span class="dv">1</span>, <span class="dv">1</span>, <span class="dv">1</span>, <span class="dv">1</span>],</span>
<span id="cb18-112"><a href="#cb18-112" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"stem_pool"</span>,</span>
<span id="cb18-113"><a href="#cb18-113" aria-hidden="true" tabindex="-1"></a>))</span>
<span id="cb18-114"><a href="#cb18-114" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-115"><a href="#cb18-115" aria-hidden="true" tabindex="-1"></a>make_residual_block(<span class="st">"pool_out"</span>, <span class="st">"layer1a"</span>, in_ch<span class="op">=</span><span class="dv">64</span>, out_ch<span class="op">=</span><span class="dv">64</span>, stride<span class="op">=</span><span class="dv">1</span>, inits<span class="op">=</span>inits, nodes<span class="op">=</span>nodes)</span>
<span id="cb18-116"><a href="#cb18-116" aria-hidden="true" tabindex="-1"></a>make_residual_block(<span class="st">"layer1a"</span>,  <span class="st">"layer1b"</span>, in_ch<span class="op">=</span><span class="dv">64</span>, out_ch<span class="op">=</span><span class="dv">64</span>, stride<span class="op">=</span><span class="dv">1</span>, inits<span class="op">=</span>inits, nodes<span class="op">=</span>nodes)</span>
<span id="cb18-117"><a href="#cb18-117" aria-hidden="true" tabindex="-1"></a>make_residual_block(<span class="st">"layer1b"</span>, <span class="st">"layer2a"</span>, in_ch<span class="op">=</span><span class="dv">64</span>,  out_ch<span class="op">=</span><span class="dv">128</span>, stride<span class="op">=</span><span class="dv">2</span>, inits<span class="op">=</span>inits, nodes<span class="op">=</span>nodes)</span>
<span id="cb18-118"><a href="#cb18-118" aria-hidden="true" tabindex="-1"></a>make_residual_block(<span class="st">"layer2a"</span>, <span class="st">"layer2b"</span>, in_ch<span class="op">=</span><span class="dv">128</span>, out_ch<span class="op">=</span><span class="dv">128</span>, stride<span class="op">=</span><span class="dv">1</span>, inits<span class="op">=</span>inits, nodes<span class="op">=</span>nodes)</span>
<span id="cb18-119"><a href="#cb18-119" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-120"><a href="#cb18-120" aria-hidden="true" tabindex="-1"></a>nodes.append(helper.make_node(<span class="st">"GlobalAveragePool"</span>,</span>
<span id="cb18-121"><a href="#cb18-121" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"layer2b"</span>], outputs<span class="op">=</span>[<span class="st">"gap_out"</span>],</span>
<span id="cb18-122"><a href="#cb18-122" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"global_avg_pool"</span>,</span>
<span id="cb18-123"><a href="#cb18-123" aria-hidden="true" tabindex="-1"></a>))</span>
<span id="cb18-124"><a href="#cb18-124" aria-hidden="true" tabindex="-1"></a>nodes.append(helper.make_node(<span class="st">"Flatten"</span>,</span>
<span id="cb18-125"><a href="#cb18-125" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"gap_out"</span>], outputs<span class="op">=</span>[<span class="st">"flat_out"</span>],</span>
<span id="cb18-126"><a href="#cb18-126" aria-hidden="true" tabindex="-1"></a>    axis<span class="op">=</span><span class="dv">1</span>, name<span class="op">=</span><span class="st">"flatten"</span>,</span>
<span id="cb18-127"><a href="#cb18-127" aria-hidden="true" tabindex="-1"></a>))</span>
<span id="cb18-128"><a href="#cb18-128" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-129"><a href="#cb18-129" aria-hidden="true" tabindex="-1"></a>fc_W <span class="op">=</span> numpy_helper.from_array(</span>
<span id="cb18-130"><a href="#cb18-130" aria-hidden="true" tabindex="-1"></a>    np.random.randn(<span class="dv">128</span>, <span class="dv">10</span>).astype(np.float32) <span class="op">*</span> <span class="fl">0.01</span>, name<span class="op">=</span><span class="st">"fc_W"</span>)</span>
<span id="cb18-131"><a href="#cb18-131" aria-hidden="true" tabindex="-1"></a>fc_b <span class="op">=</span> numpy_helper.from_array(np.zeros(<span class="dv">10</span>, dtype<span class="op">=</span>np.float32), name<span class="op">=</span><span class="st">"fc_b"</span>)</span>
<span id="cb18-132"><a href="#cb18-132" aria-hidden="true" tabindex="-1"></a>inits <span class="op">+=</span> [fc_W, fc_b]</span>
<span id="cb18-133"><a href="#cb18-133" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-134"><a href="#cb18-134" aria-hidden="true" tabindex="-1"></a>nodes.append(helper.make_node(<span class="st">"Gemm"</span>,</span>
<span id="cb18-135"><a href="#cb18-135" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"flat_out"</span>, <span class="st">"fc_W"</span>, <span class="st">"fc_b"</span>],</span>
<span id="cb18-136"><a href="#cb18-136" aria-hidden="true" tabindex="-1"></a>    outputs<span class="op">=</span>[<span class="st">"logits"</span>], alpha<span class="op">=</span><span class="fl">1.0</span>, beta<span class="op">=</span><span class="fl">1.0</span>, name<span class="op">=</span><span class="st">"classifier"</span>,</span>
<span id="cb18-137"><a href="#cb18-137" aria-hidden="true" tabindex="-1"></a>))</span>
<span id="cb18-138"><a href="#cb18-138" aria-hidden="true" tabindex="-1"></a>nodes.append(helper.make_node(<span class="st">"Softmax"</span>,</span>
<span id="cb18-139"><a href="#cb18-139" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"logits"</span>], outputs<span class="op">=</span>[<span class="st">"probs"</span>], axis<span class="op">=-</span><span class="dv">1</span>, name<span class="op">=</span><span class="st">"softmax"</span>,</span>
<span id="cb18-140"><a href="#cb18-140" aria-hidden="true" tabindex="-1"></a>))</span>
<span id="cb18-141"><a href="#cb18-141" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-142"><a href="#cb18-142" aria-hidden="true" tabindex="-1"></a>x_info    <span class="op">=</span> helper.make_tensor_value_info(<span class="st">"x"</span>,     TensorProto.FLOAT, [<span class="st">"batch"</span>, <span class="dv">3</span>, <span class="dv">224</span>, <span class="dv">224</span>])</span>
<span id="cb18-143"><a href="#cb18-143" aria-hidden="true" tabindex="-1"></a>prob_info <span class="op">=</span> helper.make_tensor_value_info(<span class="st">"probs"</span>, TensorProto.FLOAT, [<span class="st">"batch"</span>, <span class="dv">10</span>])</span>
<span id="cb18-144"><a href="#cb18-144" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-145"><a href="#cb18-145" aria-hidden="true" tabindex="-1"></a>graph <span class="op">=</span> helper.make_graph(nodes, <span class="st">"tiny_resnet"</span>,</span>
<span id="cb18-146"><a href="#cb18-146" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[x_info], outputs<span class="op">=</span>[prob_info], initializer<span class="op">=</span>inits)</span>
<span id="cb18-147"><a href="#cb18-147" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> helper.make_model(graph, opset_imports<span class="op">=</span>[helper.make_opsetid(<span class="st">""</span>, <span class="dv">17</span>)])</span>
<span id="cb18-148"><a href="#cb18-148" aria-hidden="true" tabindex="-1"></a>model.ir_version <span class="op">=</span> <span class="dv">8</span></span>
<span id="cb18-149"><a href="#cb18-149" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-150"><a href="#cb18-150" aria-hidden="true" tabindex="-1"></a>onnx.checker.check_model(model)</span>
<span id="cb18-151"><a href="#cb18-151" aria-hidden="true" tabindex="-1"></a>onnx.save(model, <span class="st">"tiny_resnet.onnx"</span>)</span>
<span id="cb18-152"><a href="#cb18-152" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"ResNet-style model saved."</span>)</span></code></pre></div></div>
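<p>The <code>BatchNormalization</code> nodes above run in inference mode, where the ONNX formula is Y = (X - mean) / sqrt(var + epsilon) * scale + B, applied per channel of an NCHW tensor. The NumPy sketch below spells that out. The function name <code>batchnorm_inference</code> is hypothetical and is not part of the <code>onnx</code> package. <code>make_residual_block</code> initializes its second normalization with scale 1, bias 0, running mean 0, and running variance 1, so in a freshly built block that normalization is almost an identity map. This is expected for an untrained model.</p>

```python
import numpy as np

def batchnorm_inference(x, scale, bias, mean, var, epsilon=1e-5):
    """Inference-mode BatchNormalization on an NCHW tensor:
    Y = (X - mean) / sqrt(var + epsilon) * scale + B, per channel."""
    shape = (1, -1, 1, 1)  # broadcast the per-channel vectors over N, H and W
    return ((x - mean.reshape(shape)) / np.sqrt(var.reshape(shape) + epsilon)
            * scale.reshape(shape) + bias.reshape(shape))

# Parameters as make_residual_block creates them for its second BN:
# scale = 1, bias = 0, running mean = 0, running variance = 1.
channels = 64
x = np.random.default_rng(0).standard_normal((2, channels, 8, 8)).astype(np.float32)
ones = np.ones(channels, dtype=np.float32)
zeros = np.zeros(channels, dtype=np.float32)

y = batchnorm_inference(x, scale=ones, bias=zeros, mean=zeros, var=ones)

# The fresh BN only rescales by 1 / sqrt(1 + 1e-5), i.e. it is nearly an identity.
print(np.allclose(y, x / np.sqrt(1.0 + 1e-5)))  # True
```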
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Tip
</div>
</div>
<div class="callout-body-container callout-body">
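<p>The residual block pattern is elegant in ONNX because the skip connection is just a tensor name. In <code>make_residual_block</code>, the same input name <code>x_name</code> feeds the first convolution of the main path. When the shapes already match, it also goes straight into the <code>Add</code> node. Otherwise it feeds the 1x1 projection <code>Conv</code>, and that projection's output is added instead. ONNX wires nodes together purely by matching input and output names, so the graph structure encodes the skip without any special syntax.</p>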
</div>
</div>
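<p>A quick way to double-check the wiring is to track the spatial size by hand. The usual Conv/MaxPool formula is floor((size + 2*pad - kernel) / stride) + 1. The body of <code>make_conv_bn_relu</code> is not shown in this section, so the sketch below assumes it pads each kernel by <code>k // 2</code>. That means pad 3 for the 7x7 stem and pad 1 for the 3x3 convolutions; treat those pads as an assumption.</p>

```python
def conv_out(size, kernel, stride, pad):
    """Output length along one spatial axis of Conv/MaxPool with symmetric padding."""
    return (size + 2 * pad - kernel) // stride + 1

# Stem: 7x7 conv, stride 2 (pad 3 assumed), then the 3x3 MaxPool, stride 2, pad 1
h = conv_out(224, kernel=7, stride=2, pad=3)   # 112
h = conv_out(h, kernel=3, stride=2, pad=1)
print("after stem:", h)                         # after stem: 56

# layer1a / layer1b: 3x3 convs with stride 1 and pad 1 keep the size
h1 = conv_out(h, kernel=3, stride=1, pad=1)

# layer2a main path: 3x3 stride-2 conv (pad 1 assumed), then the 3x3 stride-1 conv
main = conv_out(conv_out(h1, kernel=3, stride=2, pad=1), kernel=3, stride=1, pad=1)
# layer2a projection shortcut: 1x1 conv, stride 2, pad 0
shortcut = conv_out(h1, kernel=1, stride=2, pad=0)
print("layer2a main vs shortcut:", main, shortcut)  # layer2a main vs shortcut: 28 28
```

<p>The main path and the projection shortcut of <code>layer2a</code> both land on 28x28, which is what lets the <code>Add</code> node combine them. After <code>GlobalAveragePool</code> and <code>Flatten</code> only the 128 channels remain, matching the first dimension of <code>fc_W</code>.</p>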
</section>
<section id="initializers-constants-and-weight-management" class="level2">
<h2 class="anchored" data-anchor-id="initializers-constants-and-weight-management" id="initializers-constants-and-weight-management">Initializers, Constants, and Weight Management</h2>
<p>There are two ways to embed constant data in an ONNX graph.</p>
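<p><strong>Initializers</strong> are <code>TensorProto</code> objects stored in <code>graph.initializer</code>. They hold parameters (weights, biases) and other constant tensors, and they are the preferred home for large parameter tensors. You can save them outside the main <code>.onnx</code> protobuf in the external data format, for example with <code>onnx.save_model(model, path, save_as_external_data=True)</code>. That format is required for models above the 2 GB protobuf size limit, and it lets some runtimes memory-map the weights at load time.</p>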
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb19"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb19-1"><a href="#cb19-1" aria-hidden="true" tabindex="-1"></a>W <span class="op">=</span> numpy_helper.from_array(np.eye(<span class="dv">64</span>, dtype<span class="op">=</span>np.float32), name<span class="op">=</span><span class="st">"identity_W"</span>)</span>
<span id="cb19-2"><a href="#cb19-2" aria-hidden="true" tabindex="-1"></a>initializers <span class="op">=</span> [W]  <span class="co"># later passed as helper.make_graph(..., initializer=initializers)</span></span></code></pre></div></div>
<p><strong><code>Constant</code> nodes</strong> embed a tensor directly inside a <code>NodeProto</code> as its <code>value</code> attribute. Use them for small scalars or integer constants that are needed mid-graph (such as reshape targets):</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb20"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb20-1"><a href="#cb20-1" aria-hidden="true" tabindex="-1"></a>const_node <span class="op">=</span> helper.make_node(</span>
<span id="cb20-2"><a href="#cb20-2" aria-hidden="true" tabindex="-1"></a>    <span class="st">"Constant"</span>,</span>
<span id="cb20-3"><a href="#cb20-3" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[],</span>
<span id="cb20-4"><a href="#cb20-4" aria-hidden="true" tabindex="-1"></a>    outputs<span class="op">=</span>[<span class="st">"const_value"</span>],</span>
<span id="cb20-5"><a href="#cb20-5" aria-hidden="true" tabindex="-1"></a>    value<span class="op">=</span>helper.make_tensor(</span>
<span id="cb20-6"><a href="#cb20-6" aria-hidden="true" tabindex="-1"></a>        name<span class="op">=</span><span class="st">""</span>,</span>
<span id="cb20-7"><a href="#cb20-7" aria-hidden="true" tabindex="-1"></a>        data_type<span class="op">=</span>TensorProto.FLOAT,</span>
<span id="cb20-8"><a href="#cb20-8" aria-hidden="true" tabindex="-1"></a>        dims<span class="op">=</span>[],          <span class="co"># scalar</span></span>
<span id="cb20-9"><a href="#cb20-9" aria-hidden="true" tabindex="-1"></a>        vals<span class="op">=</span>[<span class="fl">0.5</span>],</span>
<span id="cb20-10"><a href="#cb20-10" aria-hidden="true" tabindex="-1"></a>    ),</span>
<span id="cb20-11"><a href="#cb20-11" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
<p>Integer shape tensors can also be stored as initializers. This is common with <code>Reshape</code>, whose <code>shape</code> input must be <code>int64</code>:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb21"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb21-1"><a href="#cb21-1" aria-hidden="true" tabindex="-1"></a>shape_const <span class="op">=</span> numpy_helper.from_array(</span>
<span id="cb21-2"><a href="#cb21-2" aria-hidden="true" tabindex="-1"></a>    np.array([<span class="op">-</span><span class="dv">1</span>, <span class="dv">128</span>], dtype<span class="op">=</span>np.int64), name<span class="op">=</span><span class="st">"reshape_target"</span></span>
<span id="cb21-3"><a href="#cb21-3" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
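<p>The <code>-1</code> in <code>reshape_target</code> asks the runtime to infer that dimension from the element count. By default (<code>allowzero=0</code>), ONNX <code>Reshape</code> also gives <code>0</code> a special meaning: copy the input dimension at that position. The pure-Python helper below is a sketch of those rules as the operator spec describes them, and it is useful for predicting output shapes. Its name, <code>onnx_reshape_shape</code>, is hypothetical, and it does not cover <code>allowzero=1</code>.</p>

```python
import math

def onnx_reshape_shape(input_shape, target):
    """Resolve a Reshape target shape under the default allowzero=0 rules."""
    # 0 means "copy the input dimension at this index"
    out = [input_shape[i] if d == 0 else d for i, d in enumerate(target)]
    if out.count(-1) not in (0, 1):
        raise ValueError("at most one dimension may be -1")
    total = math.prod(input_shape)
    if -1 in out:
        # Infer the -1 dimension from the remaining element count
        known = math.prod(d for d in out if d != -1)
        if known == 0 or total % known:
            raise ValueError(f"cannot reshape {input_shape} to {target}")
        out[out.index(-1)] = total // known
    if math.prod(out) != total:
        raise ValueError(f"cannot reshape {input_shape} to {target}")
    return out

# GlobalAveragePool output for a batch of 8 with 128 channels
print(onnx_reshape_shape([8, 128, 1, 1], [-1, 128]))  # [8, 128]
print(onnx_reshape_shape([8, 128, 1, 1], [0, -1]))    # [8, 128]
```

<p>NumPy's <code>reshape</code> understands <code>-1</code> but not the ONNX meaning of <code>0</code>. Keep that in mind when you compute reshape targets in NumPy and then store them as initializers.</p>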
<section id="weight-initialization-strategies" class="level3">
<h3 class="anchored" data-anchor-id="weight-initialization-strategies" id="weight-initialization-strategies">Weight Initialization Strategies</h3>
<p>When you build a graph by hand, the weights start out as NumPy arrays that you convert with <code>numpy_helper.from_array</code>, so you have to apply initialization schemes yourself:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb22"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb22-1"><a href="#cb22-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb22-2"><a href="#cb22-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb22-3"><a href="#cb22-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Fan-in/fan-out follow PyTorch-style layouts: Conv (out, in, kH, kW), dense (out, in)</span></span>
<span id="cb22-4"><a href="#cb22-4" aria-hidden="true" tabindex="-1"></a><span class="co"># He (Kaiming) initialization for ReLU networks</span></span>
<span id="cb22-5"><a href="#cb22-5" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> he_init(shape):</span>
<span id="cb22-6"><a href="#cb22-6" aria-hidden="true" tabindex="-1"></a>    fan_in <span class="op">=</span> <span class="bu">int</span>(np.prod(shape[<span class="dv">1</span>:])) <span class="cf">if</span> <span class="bu">len</span>(shape) <span class="op">&gt;</span> <span class="dv">1</span> <span class="cf">else</span> shape[<span class="dv">0</span>]</span>
<span id="cb22-7"><a href="#cb22-7" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> np.random.randn(<span class="op">*</span>shape).astype(np.float32) <span class="op">*</span> np.sqrt(<span class="fl">2.0</span> <span class="op">/</span> fan_in)</span>
<span id="cb22-8"><a href="#cb22-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb22-9"><a href="#cb22-9" aria-hidden="true" tabindex="-1"></a><span class="co"># Glorot (Xavier) uniform initialization for tanh/sigmoid</span></span>
<span id="cb22-10"><a href="#cb22-10" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> glorot_init(shape):</span>
<span id="cb22-11"><a href="#cb22-11" aria-hidden="true" tabindex="-1"></a>    rf      <span class="op">=</span> <span class="bu">int</span>(np.prod(shape[<span class="dv">2</span>:])) <span class="cf">if</span> <span class="bu">len</span>(shape) <span class="op">&gt;</span> <span class="dv">2</span> <span class="cf">else</span> <span class="dv">1</span>  <span class="co"># receptive field (kH*kW for conv)</span></span>
<span id="cb22-12"><a href="#cb22-12" aria-hidden="true" tabindex="-1"></a>    fan_in  <span class="op">=</span> shape[<span class="dv">1</span>] <span class="op">*</span> rf <span class="cf">if</span> <span class="bu">len</span>(shape) <span class="op">&gt;</span> <span class="dv">1</span> <span class="cf">else</span> shape[<span class="dv">0</span>]</span>
<span id="cb22-13"><a href="#cb22-13" aria-hidden="true" tabindex="-1"></a>    fan_out <span class="op">=</span> shape[<span class="dv">0</span>] <span class="op">*</span> rf</span>
<span id="cb22-14"><a href="#cb22-14" aria-hidden="true" tabindex="-1"></a>    limit   <span class="op">=</span> np.sqrt(<span class="fl">6.0</span> <span class="op">/</span> (fan_in <span class="op">+</span> fan_out))</span>
<span id="cb22-15"><a href="#cb22-15" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> np.random.uniform(<span class="op">-</span>limit, limit, shape).astype(np.float32)</span>
<span id="cb22-16"><a href="#cb22-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb22-17"><a href="#cb22-17" aria-hidden="true" tabindex="-1"></a><span class="co"># Orthogonal initialization (good for RNNs)</span></span>
<span id="cb22-18"><a href="#cb22-18" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> orthogonal_init(shape):</span>
<span id="cb22-19"><a href="#cb22-19" aria-hidden="true" tabindex="-1"></a>    flat <span class="op">=</span> np.random.randn(shape[<span class="dv">0</span>], <span class="bu">int</span>(np.prod(shape[<span class="dv">1</span>:]))).astype(np.float32)</span>
<span id="cb22-20"><a href="#cb22-20" aria-hidden="true" tabindex="-1"></a>    U, _, Vt <span class="op">=</span> np.linalg.svd(flat, full_matrices<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb22-21"><a href="#cb22-21" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> (U <span class="cf">if</span> U.shape <span class="op">==</span> flat.shape <span class="cf">else</span> Vt).reshape(shape)</span></code></pre></div></div>
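<p>Why the factor 2 / fan_in in He initialization? A ReLU zeroes roughly half of a zero-mean pre-activation, so each layer has to double the variance to keep activation magnitudes stable. The NumPy experiment below is a sketch that uses plain dense layers with (out, in) weights instead of ONNX nodes. The width, depth, and helper names are arbitrary illustrative choices. It compares He-scaled weights with a fixed 0.01 standard deviation (the scale used for <code>fc_W</code> in the tiny ResNet) across ten ReLU layers.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
width, depth = 512, 10

def he_dense(fan_in, fan_out):
    # (out, in) layout with std = sqrt(2 / fan_in)
    return rng.standard_normal((fan_out, fan_in)) * np.sqrt(2.0 / fan_in)

def naive_dense(fan_in, fan_out):
    # Fixed small std, independent of the layer width
    return rng.standard_normal((fan_out, fan_in)) * 0.01

def mean_square_after(init_fn):
    h = rng.standard_normal((256, width))  # unit-variance inputs, batch of 256
    for _ in range(depth):
        h = np.maximum(h @ init_fn(width, width).T, 0.0)  # dense layer + ReLU
    return float(np.mean(h ** 2))

he_ms = mean_square_after(he_dense)
naive_ms = mean_square_after(naive_dense)
print(f"He:    {he_ms:.3f}")     # stays roughly 1
print(f"naive: {naive_ms:.3e}")  # shrinks by about 0.0256 per layer
```

<p>The same argument applies to convolution weights, where fan_in = in_ch * kH * kW. That is what <code>np.prod(shape[1:])</code> computes for an <code>(out, in, kH, kW)</code> weight shape.</p>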
</section>
</section>
<section id="shape-inference-and-validation" class="level2">
<h2 class="anchored" data-anchor-id="shape-inference-and-validation" id="shape-inference-and-validation">Shape Inference and Validation</h2>
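<p>ONNX provides static shape inference. <code>onnx.shape_inference.infer_shapes</code> propagates element types and shapes through the graph and records the results for intermediate tensors in <code>graph.value_info</code>, so you can check them before running the model. Dimensions that depend on runtime data may remain unknown. By default, inference also stops quietly at an inconsistency rather than raising; pass <code>strict_mode=True</code> to turn such problems into errors.</p>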
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb23"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb23-1"><a href="#cb23-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> onnx</span>
<span id="cb23-2"><a href="#cb23-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> onnx <span class="im">import</span> shape_inference</span>
<span id="cb23-3"><a href="#cb23-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb23-4"><a href="#cb23-4" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> onnx.load(<span class="st">"my_model.onnx"</span>)</span>
<span id="cb23-5"><a href="#cb23-5" aria-hidden="true" tabindex="-1"></a>inferred_model <span class="op">=</span> shape_inference.infer_shapes(model)</span>
<span id="cb23-6"><a href="#cb23-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb23-7"><a href="#cb23-7" aria-hidden="true" tabindex="-1"></a><span class="co"># value_info holds inferred intermediate tensors (graph inputs/outputs live in graph.input/graph.output)</span></span>
<span id="cb23-8"><a href="#cb23-8" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> vi <span class="kw">in</span> inferred_model.graph.value_info:</span>
<span id="cb23-9"><a href="#cb23-9" aria-hidden="true" tabindex="-1"></a>    t <span class="op">=</span> vi.<span class="bu">type</span>.tensor_type</span>
<span id="cb23-10"><a href="#cb23-10" aria-hidden="true" tabindex="-1"></a>    shape <span class="op">=</span> [d.dim_value <span class="cf">if</span> d.HasField(<span class="st">"dim_value"</span>) <span class="cf">else</span> d.dim_param</span>
<span id="cb23-11"><a href="#cb23-11" aria-hidden="true" tabindex="-1"></a>             <span class="cf">for</span> d <span class="kw">in</span> t.shape.dim]</span>
<span id="cb23-12"><a href="#cb23-12" aria-hidden="true" tabindex="-1"></a>    dtype <span class="op">=</span> onnx.TensorProto.DataType.Name(t.elem_type)  <span class="co"># e.g. "FLOAT" instead of 1</span></span>
<span id="cb23-13"><a href="#cb23-13" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"  </span><span class="sc">{</span>vi<span class="sc">.</span>name<span class="sc">}</span><span class="ss">: </span><span class="sc">{</span>dtype<span class="sc">}</span><span class="ss"> </span><span class="sc">{</span>shape<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
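<p>To make the idea concrete, the sketch below propagates shapes through a hypothetical MLP chain (MatMul, Add, Relu, MatMul, Add, Softmax) in plain Python. It applies the same per-operator rules that shape inference uses for these ops: a 2-D MatMul of [N, K] and [K, M] yields [N, M], a broadcasting Add aligns dimensions from the right, elementwise ops keep their input shape, and a symbolic batch dimension such as <code>"N"</code> is carried through unchanged. It is a toy for intuition, not ONNX's implementation, and the tensor names and sizes are illustrative.</p>

```python
def matmul_shape(a, b):
    # 2-D MatMul rule: [N, K] x [K, M] -> [N, M]; inner dims must agree
    if len(a) != 2 or len(b) != 2:
        raise ValueError("this toy only handles 2-D MatMul")
    if a[1] != b[0]:
        raise ValueError(f"MatMul inner dims differ: {a} x {b}")
    return [a[0], b[1]]

def broadcast_shape(a, b):
    # NumPy-style broadcasting, right-aligned; symbolic dims (strings) must match exactly
    out = []
    for i in range(1, max(len(a), len(b)) + 1):
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da == db or db == 1:
            out.append(da)
        elif da == 1:
            out.append(db)
        else:
            raise ValueError(f"cannot broadcast {a} and {b}")
    return out[::-1]

# Hypothetical graph: x -> MatMul(W1) -> Add(b1) -> Relu -> MatMul(W2) -> Add(b2) -> Softmax
shapes = {"x": ["N", 784], "W1": [784, 128], "b1": [128], "W2": [128, 10], "b2": [10]}
shapes["mm1_out"] = matmul_shape(shapes["x"], shapes["W1"])
shapes["add1_out"] = broadcast_shape(shapes["mm1_out"], shapes["b1"])
shapes["relu1_out"] = list(shapes["add1_out"])   # elementwise: shape unchanged
shapes["mm2_out"] = matmul_shape(shapes["relu1_out"], shapes["W2"])
shapes["logits"] = broadcast_shape(shapes["mm2_out"], shapes["b2"])
shapes["probs"] = list(shapes["logits"])         # Softmax keeps the shape

for name in ["mm1_out", "add1_out", "relu1_out", "mm2_out", "logits", "probs"]:
    print(f"{name:10s} {shapes[name]}")
```

<p>A mismatch such as feeding <code>["N", 784]</code> into a <code>[128, 10]</code> weight raises immediately here, which is the kind of inconsistency shape inference can surface before any data flows through the model.</p>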
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Tip
</div>
</div>
<div class="callout-body-container callout-body">
<p>Always run both <code>onnx.checker.check_model</code> (structural validity) and <code>shape_inference.infer_shapes</code> (shape consistency) after building a model. The checker catches malformed protos; shape inference catches shape and type mismatches before you waste time debugging at runtime. Note that <code>infer_shapes</code> does not raise on inconsistencies by default; pass <code>strict_mode=True</code> to make it fail loudly.</p>
</div>
</div>
<section id="checking-shapes-programmatically" class="level3">
<h3 class="anchored" data-anchor-id="checking-shapes-programmatically" id="checking-shapes-programmatically">Checking Shapes Programmatically</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb24"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb24-1"><a href="#cb24-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> get_shape(model, tensor_name):</span>
<span id="cb24-2"><a href="#cb24-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Return the inferred shape of any named tensor in the model."""</span></span>
<span id="cb24-3"><a href="#cb24-3" aria-hidden="true" tabindex="-1"></a>    inferred <span class="op">=</span> shape_inference.infer_shapes(model)</span>
<span id="cb24-4"><a href="#cb24-4" aria-hidden="true" tabindex="-1"></a>    all_vi   <span class="op">=</span> (<span class="bu">list</span>(inferred.graph.<span class="bu">input</span>)</span>
<span id="cb24-5"><a href="#cb24-5" aria-hidden="true" tabindex="-1"></a>               <span class="op">+</span> <span class="bu">list</span>(inferred.graph.value_info)</span>
<span id="cb24-6"><a href="#cb24-6" aria-hidden="true" tabindex="-1"></a>               <span class="op">+</span> <span class="bu">list</span>(inferred.graph.output))</span>
<span id="cb24-7"><a href="#cb24-7" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> vi <span class="kw">in</span> all_vi:</span>
<span id="cb24-8"><a href="#cb24-8" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> vi.name <span class="op">==</span> tensor_name:</span>
<span id="cb24-9"><a href="#cb24-9" aria-hidden="true" tabindex="-1"></a>            t <span class="op">=</span> vi.<span class="bu">type</span>.tensor_type</span>
<span id="cb24-10"><a href="#cb24-10" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> [d.dim_value <span class="kw">or</span> d.dim_param <span class="cf">for</span> d <span class="kw">in</span> t.shape.dim]</span>
<span id="cb24-11"><a href="#cb24-11" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> <span class="va">None</span></span>
<span id="cb24-12"><a href="#cb24-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb24-13"><a href="#cb24-13" aria-hidden="true" tabindex="-1"></a>shape <span class="op">=</span> get_shape(model, <span class="st">"relu1_out"</span>)</span>
<span id="cb24-14"><a href="#cb24-14" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"relu1_out shape: </span><span class="sc">{</span>shape<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
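<p>Once you can look up a shape, a natural next step is to assert it in a test. The helper below is a sketch of our own, not an ONNX API: it compares a shape list as returned by <code>get_shape</code> against an expected pattern, where <code>None</code> in the pattern accepts any dimension (for example a symbolic batch size) and integers must match exactly.</p>

```python
def shape_matches(actual, expected):
    """Return True if `actual` matches the `expected` pattern.

    `actual` is a list like those returned by get_shape: ints for static dims,
    strings for symbolic dims, or None if the tensor was not found.
    In `expected`, None is a wildcard that accepts any dimension.
    """
    if actual is None or len(actual) != len(expected):
        return False
    return all(e is None or a == e for a, e in zip(actual, expected))

# Illustrative checks for a layer with a hidden size of 128
print(shape_matches(["batch", 128], [None, 128]))  # True: symbolic batch dim accepted
print(shape_matches([8, 64], [None, 128]))         # False: wrong hidden size
print(shape_matches(None, [None, 128]))            # False: tensor not found
```

<p>In a test suite this becomes a one-liner such as <code>assert shape_matches(get_shape(model, "relu1_out"), [None, 128])</code>, where 128 stands in for whatever hidden size your model actually uses.</p>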
</section>
</section>
<section id="running-inference-with-onnx-runtime" class="level2">
<h2 class="anchored" data-anchor-id="running-inference-with-onnx-runtime" id="running-inference-with-onnx-runtime">Running Inference with ONNX Runtime</h2>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb25"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb25-1"><a href="#cb25-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> onnxruntime <span class="im">as</span> ort</span>
<span id="cb25-2"><a href="#cb25-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb25-3"><a href="#cb25-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb25-4"><a href="#cb25-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Load the session</span></span>
<span id="cb25-5"><a href="#cb25-5" aria-hidden="true" tabindex="-1"></a>sess <span class="op">=</span> ort.InferenceSession(<span class="st">"mlp.onnx"</span>,</span>
<span id="cb25-6"><a href="#cb25-6" aria-hidden="true" tabindex="-1"></a>    providers<span class="op">=</span>[<span class="st">"CUDAExecutionProvider"</span>, <span class="st">"CPUExecutionProvider"</span>])</span>
<span id="cb25-7"><a href="#cb25-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb25-8"><a href="#cb25-8" aria-hidden="true" tabindex="-1"></a><span class="co"># Inspect inputs and outputs</span></span>
<span id="cb25-9"><a href="#cb25-9" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> inp <span class="kw">in</span> sess.get_inputs():</span>
<span id="cb25-10"><a href="#cb25-10" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Input:  </span><span class="sc">{</span>inp<span class="sc">.</span>name<span class="sc">}</span><span class="ss"> | shape: </span><span class="sc">{</span>inp<span class="sc">.</span>shape<span class="sc">}</span><span class="ss"> | type: </span><span class="sc">{</span>inp<span class="sc">.</span><span class="bu">type</span><span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb25-11"><a href="#cb25-11" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> out <span class="kw">in</span> sess.get_outputs():</span>
<span id="cb25-12"><a href="#cb25-12" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Output: </span><span class="sc">{</span>out<span class="sc">.</span>name<span class="sc">}</span><span class="ss"> | shape: </span><span class="sc">{</span>out<span class="sc">.</span>shape<span class="sc">}</span><span class="ss"> | type: </span><span class="sc">{</span>out<span class="sc">.</span><span class="bu">type</span><span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb25-13"><a href="#cb25-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb25-14"><a href="#cb25-14" aria-hidden="true" tabindex="-1"></a><span class="co"># Run inference</span></span>
<span id="cb25-15"><a href="#cb25-15" aria-hidden="true" tabindex="-1"></a>x <span class="op">=</span> np.random.randn(<span class="dv">8</span>, <span class="dv">784</span>).astype(np.float32)</span>
<span id="cb25-16"><a href="#cb25-16" aria-hidden="true" tabindex="-1"></a>outputs <span class="op">=</span> sess.run(</span>
<span id="cb25-17"><a href="#cb25-17" aria-hidden="true" tabindex="-1"></a>    output_names<span class="op">=</span>[<span class="st">"probs"</span>],  <span class="co"># pass None to return all outputs</span></span>
<span id="cb25-18"><a href="#cb25-18" aria-hidden="true" tabindex="-1"></a>    input_feed<span class="op">=</span>{<span class="st">"x"</span>: x},</span>
<span id="cb25-19"><a href="#cb25-19" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb25-20"><a href="#cb25-20" aria-hidden="true" tabindex="-1"></a>probs <span class="op">=</span> outputs[<span class="dv">0</span>]</span>
<span id="cb25-21"><a href="#cb25-21" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Output shape: </span><span class="sc">{</span>probs<span class="sc">.</span>shape<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb25-22"><a href="#cb25-22" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Predictions:  </span><span class="sc">{</span>probs<span class="sc">.</span>argmax(axis<span class="op">=</span><span class="dv">1</span>)<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
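<p>A common follow-up to running a session is checking its output against an independent reference. The sketch below assumes the MLP computes MatMul, Add, Relu, MatMul, Add, Softmax with weights named <code>W1</code>, <code>b1</code>, <code>W2</code>, <code>b2</code>; these names are hypothetical, and in practice you would load the real values from <code>graph.initializer</code> with <code>onnx.numpy_helper.to_array</code>. To keep the sketch self-contained, random weights are used and the "runtime" result is simulated by the same computation in float64; in real use, <code>probs</code> comes from <code>sess.run</code>.</p>

```python
import numpy as np

def softmax(z):
    # Subtract the row max for numerical stability before exponentiating
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def reference_mlp(x, W1, b1, W2, b2):
    # Assumed architecture: MatMul -> Add -> Relu -> MatMul -> Add -> Softmax
    h = np.maximum(x @ W1 + b1, 0.0)
    return softmax(h @ W2 + b2)

rng = np.random.default_rng(0)
# Hypothetical weights; replace with the model's initializers
W1 = rng.standard_normal((784, 128)).astype(np.float32) * 0.05
b1 = np.zeros(128, dtype=np.float32)
W2 = rng.standard_normal((128, 10)).astype(np.float32) * 0.05
b2 = np.zeros(10, dtype=np.float32)
x = rng.standard_normal((8, 784)).astype(np.float32)

expected = reference_mlp(x, W1, b1, W2, b2)
# Stand-in for the ONNX Runtime result: same maths in float64, cast back to float32
probs = reference_mlp(*(a.astype(np.float64) for a in (x, W1, b1, W2, b2))).astype(np.float32)

# float32 kernels on different backends rarely agree bit-for-bit, so compare with a tolerance
np.testing.assert_allclose(probs, expected, rtol=1e-4, atol=1e-6)
print("max abs diff:", np.abs(probs - expected).max())
```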
<section id="choosing-an-execution-provider" class="level3">
<h3 class="anchored" data-anchor-id="choosing-an-execution-provider" id="choosing-an-execution-provider">Choosing an Execution Provider</h3>
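<p>ONNX Runtime supports multiple hardware backends, called execution providers. Pass them in priority order: ONNX Runtime assigns each node to the first provider in the list that supports it, so listing <code>CPUExecutionProvider</code> last gives every remaining node a fallback.</p>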
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb26"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb26-1"><a href="#cb26-1" aria-hidden="true" tabindex="-1"></a>sess <span class="op">=</span> ort.InferenceSession(<span class="st">"model.onnx"</span>, providers<span class="op">=</span>[</span>
<span id="cb26-2"><a href="#cb26-2" aria-hidden="true" tabindex="-1"></a>    (<span class="st">"TensorrtExecutionProvider"</span>, {<span class="st">"device_id"</span>: <span class="dv">0</span>}),</span>
<span id="cb26-3"><a href="#cb26-3" aria-hidden="true" tabindex="-1"></a>    (<span class="st">"CUDAExecutionProvider"</span>,     {<span class="st">"device_id"</span>: <span class="dv">0</span>}),</span>
<span id="cb26-4"><a href="#cb26-4" aria-hidden="true" tabindex="-1"></a>    <span class="st">"CPUExecutionProvider"</span>,</span>
<span id="cb26-5"><a href="#cb26-5" aria-hidden="true" tabindex="-1"></a>])</span>
<span id="cb26-6"><a href="#cb26-6" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(sess.get_providers())  <span class="co"># shows which providers were actually activated</span></span></code></pre></div></div>
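<p>A provider that is not part of the installed onnxruntime build cannot be used, and it will not appear in <code>sess.get_providers()</code>. Filtering the request up front makes the outcome explicit. The helper below is a small sketch (the name <code>select_providers</code> is ours, not part of ONNX Runtime): it keeps the requested entries, with or without option dictionaries, whose names appear in the list returned by <code>ort.get_available_providers()</code>, and it preserves priority order.</p>

```python
def select_providers(preferred, available):
    """Filter `preferred` to providers present in `available`, keeping priority order.

    Entries may be plain names ("CPUExecutionProvider") or (name, options) tuples,
    matching the forms that ort.InferenceSession(providers=...) accepts.
    """
    available = set(available)
    selected = []
    for entry in preferred:
        name = entry[0] if isinstance(entry, tuple) else entry
        if name in available:
            selected.append(entry)
    if not selected:
        raise RuntimeError(f"none of {preferred} are available; got {sorted(available)}")
    return selected

preferred = [
    ("TensorrtExecutionProvider", {"device_id": 0}),
    ("CUDAExecutionProvider", {"device_id": 0}),
    "CPUExecutionProvider",
]
# On a real machine pass ort.get_available_providers(); a CPU-only build is simulated here
print(select_providers(preferred, ["CPUExecutionProvider"]))
```

<p>Usage: <code>ort.InferenceSession("model.onnx", providers=select_providers(preferred, ort.get_available_providers()))</code>.</p>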
</section>
<section id="session-options" class="level3">
<h3 class="anchored" data-anchor-id="session-options" id="session-options">Session Options</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb27"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb27-1"><a href="#cb27-1" aria-hidden="true" tabindex="-1"></a>opts <span class="op">=</span> ort.SessionOptions()</span>
<span id="cb27-2"><a href="#cb27-2" aria-hidden="true" tabindex="-1"></a>opts.graph_optimization_level <span class="op">=</span> ort.GraphOptimizationLevel.ORT_ENABLE_ALL</span>
<span id="cb27-3"><a href="#cb27-3" aria-hidden="true" tabindex="-1"></a>opts.intra_op_num_threads <span class="op">=</span> <span class="dv">4</span></span>
<span id="cb27-4"><a href="#cb27-4" aria-hidden="true" tabindex="-1"></a>opts.enable_profiling <span class="op">=</span> <span class="va">False</span></span>
<span id="cb27-5"><a href="#cb27-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb27-6"><a href="#cb27-6" aria-hidden="true" tabindex="-1"></a>sess <span class="op">=</span> ort.InferenceSession(<span class="st">"model.onnx"</span>, sess_options<span class="op">=</span>opts,</span>
<span id="cb27-7"><a href="#cb27-7" aria-hidden="true" tabindex="-1"></a>    providers<span class="op">=</span>[<span class="st">"CPUExecutionProvider"</span>])</span></code></pre></div></div>
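<p>Whether <code>ORT_ENABLE_ALL</code> or a particular <code>intra_op_num_threads</code> value actually helps depends on the model and the machine, so it is worth measuring. The sketch below times any zero-argument callable after a few warm-up calls and reports the median latency. Wrapping <code>sess.run</code> in a lambda lets you compare sessions built with different <code>SessionOptions</code>. The function name and default counts are our own choices.</p>

```python
import statistics
import time

def median_latency_ms(fn, warmup=5, runs=50):
    """Call `fn` `warmup` times untimed, then `runs` times timed; return the median in ms."""
    for _ in range(warmup):
        fn()  # early calls are often slower (memory allocation, caching)
    timings = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        timings.append((time.perf_counter() - start) * 1000.0)
    return statistics.median(timings)

# Stand-in workload so the sketch runs on its own; replace with a session call,
# e.g. lambda: sess.run(None, {"x": x})
print(f"{median_latency_ms(lambda: sum(range(10_000))):.3f} ms")
```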
</section>
</section>
<section id="inspecting-and-debugging-onnx-graphs" class="level2">
<h2 class="anchored" data-anchor-id="inspecting-and-debugging-onnx-graphs" id="inspecting-and-debugging-onnx-graphs">Inspecting and Debugging ONNX Graphs</h2>
<section id="printing-graph-structure" class="level3">
<h3 class="anchored" data-anchor-id="printing-graph-structure" id="printing-graph-structure">Printing Graph Structure</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb28"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb28-1"><a href="#cb28-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> onnx</span>
<span id="cb28-2"><a href="#cb28-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb28-3"><a href="#cb28-3" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> onnx.load(<span class="st">"model.onnx"</span>)</span>
<span id="cb28-4"><a href="#cb28-4" aria-hidden="true" tabindex="-1"></a>graph <span class="op">=</span> model.graph</span>
<span id="cb28-5"><a href="#cb28-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb28-6"><a href="#cb28-6" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Opset: </span><span class="sc">{</span>model<span class="sc">.</span>opset_import[<span class="dv">0</span>]<span class="sc">.</span>version<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb28-7"><a href="#cb28-7" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"</span><span class="ch">\n</span><span class="ss">Inputs:"</span>)</span>
<span id="cb28-8"><a href="#cb28-8" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> inp <span class="kw">in</span> graph.<span class="bu">input</span>:</span>
<span id="cb28-9"><a href="#cb28-9" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"  </span><span class="sc">{</span>inp<span class="sc">.</span>name<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb28-10"><a href="#cb28-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb28-11"><a href="#cb28-11" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"</span><span class="ch">\n</span><span class="ss">Outputs:"</span>)</span>
<span id="cb28-12"><a href="#cb28-12" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> out <span class="kw">in</span> graph.output:</span>
<span id="cb28-13"><a href="#cb28-13" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"  </span><span class="sc">{</span>out<span class="sc">.</span>name<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb28-14"><a href="#cb28-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb28-15"><a href="#cb28-15" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"</span><span class="ch">\n</span><span class="ss">Initializers: </span><span class="sc">{</span><span class="bu">len</span>(graph.initializer)<span class="sc">}</span><span class="ss"> tensors"</span>)</span>
<span id="cb28-16"><a href="#cb28-16" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> init <span class="kw">in</span> graph.initializer:</span>
<span id="cb28-17"><a href="#cb28-17" aria-hidden="true" tabindex="-1"></a>    shape <span class="op">=</span> <span class="bu">list</span>(init.dims)</span>
<span id="cb28-18"><a href="#cb28-18" aria-hidden="true" tabindex="-1"></a>    dtype <span class="op">=</span> onnx.TensorProto.DataType.Name(init.data_type)  <span class="co"># e.g. "FLOAT" instead of 1</span></span>
<span id="cb28-19"><a href="#cb28-19" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"  </span><span class="sc">{</span>init<span class="sc">.</span>name<span class="sc">:30s}</span><span class="ss"> shape=</span><span class="sc">{</span>shape<span class="sc">}</span><span class="ss">, dtype=</span><span class="sc">{</span>dtype<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb28-20"><a href="#cb28-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb28-21"><a href="#cb28-21" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"</span><span class="ch">\n</span><span class="ss">Nodes (</span><span class="sc">{</span><span class="bu">len</span>(graph.node)<span class="sc">}</span><span class="ss"> total):"</span>)</span>
<span id="cb28-22"><a href="#cb28-22" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> node <span class="kw">in</span> graph.node:</span>
<span id="cb28-23"><a href="#cb28-23" aria-hidden="true" tabindex="-1"></a>    attrs <span class="op">=</span> {a.name: onnx.helper.get_attribute_value(a) <span class="cf">for</span> a <span class="kw">in</span> node.attribute}</span>
<span id="cb28-24"><a href="#cb28-24" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"  [</span><span class="sc">{</span>node<span class="sc">.</span>op_type<span class="sc">:20s}</span><span class="ss">] </span><span class="sc">{</span><span class="bu">list</span>(node.<span class="bu">input</span>)<span class="sc">}</span><span class="ss"> → </span><span class="sc">{</span><span class="bu">list</span>(node.output)<span class="sc">}</span><span class="ss"> </span><span class="sc">{</span>attrs<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
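<p>The initializer dimensions printed above are also enough to estimate model size. The helper below is a sketch that multiplies out each initializer's dims to count parameters. The 4-bytes-per-element figure assumes float32 weights, so adjust it for other dtypes, and the example dims are those of a hypothetical 784-128-10 MLP.</p>

```python
import math

def count_parameters(all_dims):
    """Sum the element counts of initializers given as lists of dims.

    For a loaded model: count_parameters([list(i.dims) for i in graph.initializer])
    """
    # math.prod([]) == 1, so a scalar initializer (no dims) counts as one element
    return sum(math.prod(dims) for dims in all_dims)

# Hypothetical MLP initializers: W1, b1, W2, b2
dims = [[784, 128], [128], [128, 10], [10]]
n = count_parameters(dims)
print(f"{n:,} parameters, about {n * 4 / 1e6:.2f} MB as float32")
```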
</section>
<section id="extracting-intermediate-outputs" class="level3">
<h3 class="anchored" data-anchor-id="extracting-intermediate-outputs" id="extracting-intermediate-outputs">Extracting Intermediate Outputs</h3>
<p>You can expose intermediate tensors as additional graph outputs for debugging:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb29"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb29-1"><a href="#cb29-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> onnx</span>
<span id="cb29-2"><a href="#cb29-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> onnx <span class="im">import</span> shape_inference</span>
<span id="cb29-3"><a href="#cb29-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb29-4"><a href="#cb29-4" aria-hidden="true" tabindex="-1"></a>model    <span class="op">=</span> onnx.load(<span class="st">"mlp.onnx"</span>)</span>
<span id="cb29-5"><a href="#cb29-5" aria-hidden="true" tabindex="-1"></a>inferred <span class="op">=</span> shape_inference.infer_shapes(model)</span>
<span id="cb29-6"><a href="#cb29-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb29-7"><a href="#cb29-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Identify the value_info for intermediate tensor "relu1_out"</span></span>
<span id="cb29-8"><a href="#cb29-8" aria-hidden="true" tabindex="-1"></a>vi_to_expose <span class="op">=</span> <span class="bu">next</span>(</span>
<span id="cb29-9"><a href="#cb29-9" aria-hidden="true" tabindex="-1"></a>    (vi <span class="cf">for</span> vi <span class="kw">in</span> inferred.graph.value_info <span class="cf">if</span> vi.name <span class="op">==</span> <span class="st">"relu1_out"</span>), <span class="va">None</span>)</span>
<span id="cb29-10"><a href="#cb29-10" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> vi_to_expose <span class="kw">is</span> <span class="va">None</span>:</span>
<span id="cb29-11"><a href="#cb29-11" aria-hidden="true" tabindex="-1"></a>    <span class="cf">raise</span> <span class="pp">ValueError</span>(<span class="st">"relu1_out not found in value_info; check the tensor name"</span>)</span>
<span id="cb29-13"><a href="#cb29-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb29-14"><a href="#cb29-14" aria-hidden="true" tabindex="-1"></a><span class="co"># Add it as a graph output</span></span>
<span id="cb29-15"><a href="#cb29-15" aria-hidden="true" tabindex="-1"></a>debug_model <span class="op">=</span> onnx.ModelProto()</span>
<span id="cb29-16"><a href="#cb29-16" aria-hidden="true" tabindex="-1"></a>debug_model.CopyFrom(inferred)</span>
<span id="cb29-17"><a href="#cb29-17" aria-hidden="true" tabindex="-1"></a>debug_model.graph.output.append(vi_to_expose)</span>
<span id="cb29-18"><a href="#cb29-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb29-19"><a href="#cb29-19" aria-hidden="true" tabindex="-1"></a>onnx.save(debug_model, <span class="st">"mlp_debug.onnx"</span>)</span></code></pre></div></div>
</section>
<section id="using-netron" class="level3">
<h3 class="anchored" data-anchor-id="using-netron" id="using-netron">Using Netron</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb30"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb30-1"><a href="#cb30-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> subprocess</span>
<span id="cb30-2"><a href="#cb30-2" aria-hidden="true" tabindex="-1"></a>subprocess.Popen([<span class="st">"netron"</span>, <span class="st">"model.onnx"</span>])</span>
<span id="cb30-3"><a href="#cb30-3" aria-hidden="true" tabindex="-1"></a><span class="co"># or just open the file directly in the Netron app</span></span></code></pre></div></div>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Tip
</div>
</div>
<div class="callout-body-container callout-body">
<p>Netron renders the full computation graph interactively, either as a desktop app or in a browser. Each node shows its op type, attributes, and input/output tensor names; intermediate shapes appear only if the file you open contains them (for example, a model saved after running shape inference). It is one of the most useful tools for understanding and debugging ONNX models.</p>
</div>
</div>
</section>
</section>
<section id="advanced-techniques-control-flow-subgraphs-and-custom-ops" class="level2">
<h2 class="anchored" data-anchor-id="advanced-techniques-control-flow-subgraphs-and-custom-ops" id="advanced-techniques-control-flow-subgraphs-and-custom-ops">Advanced Techniques: Control Flow, Subgraphs, and Custom Ops</h2>
<section id="control-flow-if-loop-scan" class="level3">
<h3 class="anchored" data-anchor-id="control-flow-if-loop-scan" id="control-flow-if-loop-scan">Control Flow: <code>If</code>, <code>Loop</code>, <code>Scan</code></h3>
<p>ONNX supports limited control flow via three special operators. These operators contain subgraphs (nested <code>GraphProto</code> objects) inside their attributes.</p>
<p><strong><code>If</code></strong>: Conditional execution. Takes a boolean scalar condition and contains two subgraph attributes: <code>then_branch</code> and <code>else_branch</code>.</p>
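<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb31"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb31-1"><a href="#cb31-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> onnx <span class="im">import</span> helper, TensorProto</span>
<span id="cb31-2"><a href="#cb31-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb31-3"><a href="#cb31-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Each branch is a complete GraphProto with no inputs; it may read outer-scope tensors such as "x"</span></span>
<span id="cb31-4"><a href="#cb31-4" aria-hidden="true" tabindex="-1"></a>then_graph <span class="op">=</span> helper.make_graph(</span>
<span id="cb31-5"><a href="#cb31-5" aria-hidden="true" tabindex="-1"></a>    [helper.make_node(<span class="st">"Relu"</span>, [<span class="st">"x"</span>], [<span class="st">"then_y"</span>])], <span class="st">"then_branch"</span>, [],</span>
<span id="cb31-6"><a href="#cb31-6" aria-hidden="true" tabindex="-1"></a>    [helper.make_tensor_value_info(<span class="st">"then_y"</span>, TensorProto.FLOAT, <span class="va">None</span>)])</span>
<span id="cb31-7"><a href="#cb31-7" aria-hidden="true" tabindex="-1"></a>else_graph <span class="op">=</span> helper.make_graph(</span>
<span id="cb31-8"><a href="#cb31-8" aria-hidden="true" tabindex="-1"></a>    [helper.make_node(<span class="st">"Neg"</span>, [<span class="st">"x"</span>], [<span class="st">"else_y"</span>])], <span class="st">"else_branch"</span>, [],</span>
<span id="cb31-9"><a href="#cb31-9" aria-hidden="true" tabindex="-1"></a>    [helper.make_tensor_value_info(<span class="st">"else_y"</span>, TensorProto.FLOAT, <span class="va">None</span>)])</span>
<span id="cb31-10"><a href="#cb31-10" aria-hidden="true" tabindex="-1"></a>if_node <span class="op">=</span> helper.make_node(</span>
<span id="cb31-11"><a href="#cb31-11" aria-hidden="true" tabindex="-1"></a>    <span class="st">"If"</span>,</span>
<span id="cb31-12"><a href="#cb31-12" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"condition"</span>],  <span class="co"># boolean scalar</span></span>
<span id="cb31-13"><a href="#cb31-13" aria-hidden="true" tabindex="-1"></a>    outputs<span class="op">=</span>[<span class="st">"result"</span>],</span>
<span id="cb31-14"><a href="#cb31-14" aria-hidden="true" tabindex="-1"></a>    then_branch<span class="op">=</span>then_graph,</span>
<span id="cb31-15"><a href="#cb31-15" aria-hidden="true" tabindex="-1"></a>    else_branch<span class="op">=</span>else_graph,</span>
<span id="cb31-16"><a href="#cb31-16" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>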
<p><strong><code>Loop</code></strong>: A counted or condition-based loop. Takes a trip count, initial condition, and initial state tensors, and runs a body subgraph repeatedly.</p>
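<p><strong><code>Scan</code></strong>: Iterates a body subgraph along one axis of its scan inputs (axis 0 by default, typically the time axis), carrying state tensors from one iteration to the next and stacking the per-step outputs. Useful for custom RNNs.</p>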
<p>These operators are powerful but complex. Their subgraphs must be complete valid <code>GraphProto</code> objects with their own inputs and outputs. Building them requires careful management of variable names and scoping.</p>
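<p>The semantics of <code>Scan</code> can be emulated in a few lines of NumPy, which helps when reasoning about what a body graph must compute. The sketch below is our own emulation, not an ONNX API: <code>body</code> plays the role of the subgraph, taking the current state and one slice along axis 0 and returning the new state plus a per-step output. The final state and the stacked per-step outputs correspond to <code>Scan</code>'s two groups of outputs. The running-sum body is only an example.</p>

```python
import numpy as np

def scan(body, initial_state, sequence):
    """Emulate ONNX Scan with one state and one scan input iterated along axis 0."""
    state = initial_state
    step_outputs = []
    for x_t in sequence:                   # iterating a NumPy array walks axis 0
        state, y_t = body(state, x_t)
        step_outputs.append(y_t)
    return state, np.stack(step_outputs)   # final state, outputs stacked along axis 0

def running_sum_body(state, x_t):
    # Example body: accumulate the inputs and emit the running total at each step
    new_state = state + x_t
    return new_state, new_state

seq = np.arange(12, dtype=np.float32).reshape(4, 3)   # 4 time steps, 3 features
final, outputs = scan(running_sum_body, np.zeros(3, dtype=np.float32), seq)
print(final)    # equals seq.sum(axis=0)
print(outputs)  # equals np.cumsum(seq, axis=0)
```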
</section>
<section id="custom-operators" class="level3">
<h3 class="anchored" data-anchor-id="custom-operators" id="custom-operators">Custom Operators</h3>
<p>If you need an operation not in the ONNX standard set, you can define a custom operator with a non-standard domain:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb32"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb32-1"><a href="#cb32-1" aria-hidden="true" tabindex="-1"></a>custom_node <span class="op">=</span> helper.make_node(</span>
<span id="cb32-2"><a href="#cb32-2" aria-hidden="true" tabindex="-1"></a>    op_type<span class="op">=</span><span class="st">"MySpecialOp"</span>,</span>
<span id="cb32-3"><a href="#cb32-3" aria-hidden="true" tabindex="-1"></a>    domain<span class="op">=</span><span class="st">"com.mycompany"</span>,</span>
<span id="cb32-4"><a href="#cb32-4" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"x"</span>],</span>
<span id="cb32-5"><a href="#cb32-5" aria-hidden="true" tabindex="-1"></a>    outputs<span class="op">=</span>[<span class="st">"y"</span>],</span>
<span id="cb32-6"><a href="#cb32-6" aria-hidden="true" tabindex="-1"></a>    my_custom_attr<span class="op">=</span><span class="dv">42</span>,</span>
<span id="cb32-7"><a href="#cb32-7" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"custom_op_1"</span>,</span>
<span id="cb32-8"><a href="#cb32-8" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb32-9"><a href="#cb32-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb32-10"><a href="#cb32-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Register the custom domain in the opset imports</span></span>
<span id="cb32-11"><a href="#cb32-11" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> helper.make_model(</span>
<span id="cb32-12"><a href="#cb32-12" aria-hidden="true" tabindex="-1"></a>    graph,</span>
<span id="cb32-13"><a href="#cb32-13" aria-hidden="true" tabindex="-1"></a>    opset_imports<span class="op">=</span>[</span>
<span id="cb32-14"><a href="#cb32-14" aria-hidden="true" tabindex="-1"></a>        helper.make_opsetid(<span class="st">""</span>, <span class="dv">17</span>),</span>
<span id="cb32-15"><a href="#cb32-15" aria-hidden="true" tabindex="-1"></a>        helper.make_opsetid(<span class="st">"com.mycompany"</span>, <span class="dv">1</span>),</span>
<span id="cb32-16"><a href="#cb32-16" aria-hidden="true" tabindex="-1"></a>    ],</span>
<span id="cb32-17"><a href="#cb32-17" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
<p>To run a model that contains custom ops with ONNX Runtime, you must register an implementation before creating the session. C++ kernels are compiled into a shared library and loaded with <code>SessionOptions.register_custom_ops_library</code>; Python implementations can be supplied through the separate onnxruntime-extensions package.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb33"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb33-1"><a href="#cb33-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> onnxruntime <span class="im">as</span> ort</span>
<span id="cb33-2"><a href="#cb33-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb33-3"><a href="#cb33-3" aria-hidden="true" tabindex="-1"></a>opts <span class="op">=</span> ort.SessionOptions()</span>
<span id="cb33-4"><a href="#cb33-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Register compiled C++ kernels before creating the session. The library must</span></span>
<span id="cb33-5"><a href="#cb33-5" aria-hidden="true" tabindex="-1"></a><span class="co"># implement MySpecialOp in the com.mycompany domain. This path is hypothetical;</span></span>
<span id="cb33-6"><a href="#cb33-6" aria-hidden="true" tabindex="-1"></a><span class="co"># point it at your own build output (.so / .dll / .dylib).</span></span>
<span id="cb33-7"><a href="#cb33-7" aria-hidden="true" tabindex="-1"></a>opts.register_custom_ops_library(<span class="st">"./libmy_custom_ops.so"</span>)</span>
<span id="cb33-8"><a href="#cb33-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb33-9"><a href="#cb33-9" aria-hidden="true" tabindex="-1"></a>sess <span class="op">=</span> ort.InferenceSession(</span>
<span id="cb33-10"><a href="#cb33-10" aria-hidden="true" tabindex="-1"></a>    <span class="st">"model_with_custom_op.onnx"</span>,</span>
<span id="cb33-11"><a href="#cb33-11" aria-hidden="true" tabindex="-1"></a>    sess_options<span class="op">=</span>opts,</span>
<span id="cb33-12"><a href="#cb33-12" aria-hidden="true" tabindex="-1"></a>    providers<span class="op">=</span>[<span class="st">"CPUExecutionProvider"</span>],</span>
<span id="cb33-13"><a href="#cb33-13" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
</section>
<section id="function-based-operators" class="level3">
<h3 class="anchored" data-anchor-id="function-based-operators" id="function-based-operators">Function-Based Operators</h3>
<p>ONNX also allows you to define <code>FunctionProto</code> objects — named, reusable operator definitions composed of existing ONNX ops. These let you package composite operations (like a Transformer block) as a single named op that expands to primitives during execution:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb34"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb34-1"><a href="#cb34-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> onnx <span class="im">import</span> helper, TensorProto</span>
<span id="cb34-2"><a href="#cb34-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb34-3"><a href="#cb34-3" aria-hidden="true" tabindex="-1"></a>func <span class="op">=</span> helper.make_function(</span>
<span id="cb34-4"><a href="#cb34-4" aria-hidden="true" tabindex="-1"></a>    domain<span class="op">=</span><span class="st">"com.myarch"</span>,</span>
<span id="cb34-5"><a href="#cb34-5" aria-hidden="true" tabindex="-1"></a>    fname<span class="op">=</span><span class="st">"LayerNormFunc"</span>,</span>
<span id="cb34-6"><a href="#cb34-6" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"X"</span>, <span class="st">"scale"</span>, <span class="st">"bias"</span>],</span>
<span id="cb34-7"><a href="#cb34-7" aria-hidden="true" tabindex="-1"></a>    outputs<span class="op">=</span>[<span class="st">"Y"</span>],</span>
<span id="cb34-8"><a href="#cb34-8" aria-hidden="true" tabindex="-1"></a>    nodes<span class="op">=</span>[...],  <span class="co"># the expanded graph nodes</span></span>
<span id="cb34-9"><a href="#cb34-9" aria-hidden="true" tabindex="-1"></a>    opset_imports<span class="op">=</span>[helper.make_opsetid(<span class="st">""</span>, <span class="dv">17</span>)],</span>
<span id="cb34-10"><a href="#cb34-10" aria-hidden="true" tabindex="-1"></a>)</span>
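<span id="cb34-11"><a href="#cb34-11" aria-hidden="true" tabindex="-1"></a><span class="co"># `model` is the ModelProto whose nodes call LayerNormFunc</span></span>
<span id="cb34-12"><a href="#cb34-12" aria-hidden="true" tabindex="-1"></a>model.functions.append(func)</span>
<span id="cb34-13"><a href="#cb34-13" aria-hidden="true" tabindex="-1"></a><span class="co"># Those nodes use domain="com.myarch", so the model must import that domain too</span></span>
<span id="cb34-14"><a href="#cb34-14" aria-hidden="true" tabindex="-1"></a>model.opset_import.append(helper.make_opsetid(<span class="st">"com.myarch"</span>, <span class="dv">1</span>))</span></code></pre></div></div>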
</section>
</section>
<section id="optimization-and-graph-transformations" class="level2">
<h2 class="anchored" data-anchor-id="optimization-and-graph-transformations" id="optimization-and-graph-transformations">Optimization and Graph Transformations</h2>
<p>Raw hand-built ONNX graphs are often not as efficient as they could be. Several tools exist to optimize them.</p>
<section id="onnx-runtime-graph-optimizations" class="level3">
<h3 class="anchored" data-anchor-id="onnx-runtime-graph-optimizations" id="onnx-runtime-graph-optimizations">ONNX Runtime Graph Optimizations</h3>
<p>The simplest approach is to let ONNX Runtime’s optimizer do the work at load time:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb35"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb35-1"><a href="#cb35-1" aria-hidden="true" tabindex="-1"></a>opts <span class="op">=</span> ort.SessionOptions()</span>
<span id="cb35-2"><a href="#cb35-2" aria-hidden="true" tabindex="-1"></a>opts.graph_optimization_level <span class="op">=</span> ort.GraphOptimizationLevel.ORT_ENABLE_ALL</span>
<span id="cb35-3"><a href="#cb35-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb35-4"><a href="#cb35-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Optionally, save the optimized model for inspection</span></span>
<span id="cb35-5"><a href="#cb35-5" aria-hidden="true" tabindex="-1"></a>opts.optimized_model_filepath <span class="op">=</span> <span class="st">"optimized_model.onnx"</span></span>
<span id="cb35-6"><a href="#cb35-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb35-7"><a href="#cb35-7" aria-hidden="true" tabindex="-1"></a>sess <span class="op">=</span> ort.InferenceSession(<span class="st">"model.onnx"</span>, sess_options<span class="op">=</span>opts,</span>
<span id="cb35-8"><a href="#cb35-8" aria-hidden="true" tabindex="-1"></a>    providers<span class="op">=</span>[<span class="st">"CPUExecutionProvider"</span>])</span></code></pre></div></div>
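<p>ONNX Runtime applies operator fusions (for example, folding a <code>BatchNormalization</code> into the preceding <code>Conv</code>, and fusing a <code>Conv</code> with a following <code>Relu</code>), dead-code elimination, constant folding, and more. With <code>ORT_ENABLE_ALL</code>, the saved optimized model can include hardware-specific layout changes, so treat it as an inspection artifact for the machine it was produced on rather than a portable model.</p>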
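<p>The Conv/BatchNormalization fusion works because inference-mode batch norm is a per-channel affine map that can be absorbed into the convolution's weights and bias. The NumPy sketch below shows that arithmetic on a 1×1 convolution (kept small so it can be written as an <code>einsum</code>); it illustrates the general technique and is not ONNX Runtime's internal code.</p>

```python
import numpy as np

def conv1x1(x, w, b):
    """1x1 convolution on NCHW input; w has shape [C_out, C_in], b has shape [C_out]."""
    return np.einsum("oc,nchw->nohw", w, x) + b[None, :, None, None]

def batchnorm(y, gamma, beta, mean, var, eps=1e-5):
    """Inference-mode BatchNormalization over channel axis 1."""
    def per_channel(v):
        return v[None, :, None, None]
    return per_channel(gamma) * (y - per_channel(mean)) / np.sqrt(per_channel(var) + eps) + per_channel(beta)

def fold_bn_into_conv(w, b, gamma, beta, mean, var, eps=1e-5):
    """Return (w', b') such that conv(x, w', b') equals batchnorm(conv(x, w, b))."""
    scale = gamma / np.sqrt(var + eps)                       # one factor per output channel
    w_folded = w * scale.reshape(-1, *([1] * (w.ndim - 1)))  # scale each output channel's filter
    b_folded = (b - mean) * scale + beta
    return w_folded, b_folded

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4, 4))
w, b = rng.standard_normal((5, 3)), rng.standard_normal(5)
gamma, beta = rng.standard_normal(5), rng.standard_normal(5)
mean, var = rng.standard_normal(5), rng.uniform(0.5, 2.0, 5)

w_f, b_f = fold_bn_into_conv(w, b, gamma, beta, mean, var)
unfused = batchnorm(conv1x1(x, w, b), gamma, beta, mean, var)
fused = conv1x1(x, w_f, b_f)
```

<p>The folded model runs one operator instead of two and produces the same outputs up to floating-point rounding.</p>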
</section>
<section id="onnx-simplifier" class="level3">
<h3 class="anchored" data-anchor-id="onnx-simplifier" id="onnx-simplifier">ONNX Simplifier</h3>
<p><code>onnx-simplifier</code> is a third-party tool that applies constant folding and other simplifications:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb36"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb36-1"><a href="#cb36-1" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install onnxsim</span>
<span id="cb36-2"><a href="#cb36-2" aria-hidden="true" tabindex="-1"></a><span class="ex">python</span> <span class="at">-m</span> onnxsim model.onnx simplified_model.onnx</span></code></pre></div></div>
<p>Or programmatically:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb37"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb37-1"><a href="#cb37-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> onnxsim <span class="im">import</span> simplify</span>
<span id="cb37-2"><a href="#cb37-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> onnx</span>
<span id="cb37-3"><a href="#cb37-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb37-4"><a href="#cb37-4" aria-hidden="true" tabindex="-1"></a>model      <span class="op">=</span> onnx.load(<span class="st">"model.onnx"</span>)</span>
<span id="cb37-5"><a href="#cb37-5" aria-hidden="true" tabindex="-1"></a>simplified, check <span class="op">=</span> simplify(model)</span>
<span id="cb37-6"><a href="#cb37-6" aria-hidden="true" tabindex="-1"></a><span class="cf">assert</span> check, <span class="st">"Simplified ONNX model could not be validated!"</span></span>
<span id="cb37-7"><a href="#cb37-7" aria-hidden="true" tabindex="-1"></a>onnx.save(simplified, <span class="st">"simplified_model.onnx"</span>)</span></code></pre></div></div>
</section>
<section id="manual-graph-surgery-with-onnx.helper-and-onnx.compose" class="level3">
<h3 class="anchored" data-anchor-id="manual-graph-surgery-with-onnx.helper-and-onnx.compose" id="manual-graph-surgery-with-onnx.helper-and-onnx.compose">Manual Graph Surgery with <code>onnx.helper</code> and <code>onnx.compose</code></h3>
<p>The <code>onnx.compose</code> module (ONNX ≥ 1.13) provides <code>merge_models</code> and <code>add_prefix</code> utilities for combining and modifying graphs:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb38"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb38-1"><a href="#cb38-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> onnx <span class="im">import</span> compose</span>
<span id="cb38-2"><a href="#cb38-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb38-3"><a href="#cb38-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Merge two models sequentially (output of model1 feeds input of model2)</span></span>
<span id="cb38-4"><a href="#cb38-4" aria-hidden="true" tabindex="-1"></a>combined <span class="op">=</span> compose.merge_models(</span>
<span id="cb38-5"><a href="#cb38-5" aria-hidden="true" tabindex="-1"></a>    model1, model2,</span>
<span id="cb38-6"><a href="#cb38-6" aria-hidden="true" tabindex="-1"></a>    io_map<span class="op">=</span>[(<span class="st">"model1_output"</span>, <span class="st">"model2_input"</span>)],</span>
<span id="cb38-7"><a href="#cb38-7" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
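<p>For lower-level graph surgery (removing nodes, inserting nodes, rewiring edges), edit the <code>graph.node</code> repeated field in place. Removing or inserting a node is only half the job: every consumer of the affected tensor must be rewired as well, or the graph ends up reading a tensor that no longer exists, or ignoring the new node entirely:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb39"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb39-1"><a href="#cb39-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> onnx</span>
<span id="cb39-2"><a href="#cb39-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> onnx <span class="im">import</span> helper</span>
<span id="cb39-3"><a href="#cb39-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb39-4"><a href="#cb39-4" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> onnx.load(<span class="st">"model.onnx"</span>)</span>
<span id="cb39-5"><a href="#cb39-5" aria-hidden="true" tabindex="-1"></a>graph <span class="op">=</span> model.graph</span>
<span id="cb39-6"><a href="#cb39-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb39-7"><a href="#cb39-7" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> rewire(graph, old, new):</span>
<span id="cb39-8"><a href="#cb39-8" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Point every node input that reads `old` at `new` instead."""</span></span>
<span id="cb39-9"><a href="#cb39-9" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> n <span class="kw">in</span> graph.node:</span>
<span id="cb39-10"><a href="#cb39-10" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> i, name <span class="kw">in</span> <span class="bu">enumerate</span>(n.input):</span>
<span id="cb39-11"><a href="#cb39-11" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> name <span class="op">==</span> old:</span>
<span id="cb39-12"><a href="#cb39-12" aria-hidden="true" tabindex="-1"></a>                n.input[i] <span class="op">=</span> new</span>
<span id="cb39-13"><a href="#cb39-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb39-14"><a href="#cb39-14" aria-hidden="true" tabindex="-1"></a><span class="co"># Bypass and remove a one-in/one-out node (e.g. a Relu) whose output is not a graph output</span></span>
<span id="cb39-15"><a href="#cb39-15" aria-hidden="true" tabindex="-1"></a>victim <span class="op">=</span> <span class="bu">next</span>(n <span class="cf">for</span> n <span class="kw">in</span> graph.node <span class="cf">if</span> n.name <span class="op">==</span> <span class="st">"relu_to_remove"</span>)</span>
<span id="cb39-16"><a href="#cb39-16" aria-hidden="true" tabindex="-1"></a>rewire(graph, victim.output[<span class="dv">0</span>], victim.input[<span class="dv">0</span>])</span>
<span id="cb39-17"><a href="#cb39-17" aria-hidden="true" tabindex="-1"></a>graph.node.remove(victim)</span>
<span id="cb39-18"><a href="#cb39-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb39-19"><a href="#cb39-19" aria-hidden="true" tabindex="-1"></a><span class="co"># Insert a Tanh after "linear": rewire consumers first so the Tanh keeps "linear_out" as its input</span></span>
<span id="cb39-20"><a href="#cb39-20" aria-hidden="true" tabindex="-1"></a>rewire(graph, <span class="st">"linear_out"</span>, <span class="st">"tanh_out"</span>)</span>
<span id="cb39-21"><a href="#cb39-21" aria-hidden="true" tabindex="-1"></a>new_node <span class="op">=</span> helper.make_node(<span class="st">"Tanh"</span>, inputs<span class="op">=</span>[<span class="st">"linear_out"</span>], outputs<span class="op">=</span>[<span class="st">"tanh_out"</span>], name<span class="op">=</span><span class="st">"tanh"</span>)</span>
<span id="cb39-22"><a href="#cb39-22" aria-hidden="true" tabindex="-1"></a>insert_idx <span class="op">=</span> <span class="bu">next</span>(i <span class="cf">for</span> i, n <span class="kw">in</span> <span class="bu">enumerate</span>(graph.node) <span class="cf">if</span> n.name <span class="op">==</span> <span class="st">"linear"</span>)</span>
<span id="cb39-23"><a href="#cb39-23" aria-hidden="true" tabindex="-1"></a>graph.node.insert(insert_idx <span class="op">+</span> <span class="dv">1</span>, new_node)</span>
<span id="cb39-24"><a href="#cb39-24" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb39-25"><a href="#cb39-25" aria-hidden="true" tabindex="-1"></a>onnx.checker.check_model(model)</span>
<span id="cb39-26"><a href="#cb39-26" aria-hidden="true" tabindex="-1"></a>onnx.save(model, <span class="st">"modified_model.onnx"</span>)</span></code></pre></div></div>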
</section>
<section id="quantization" class="level3">
<h3 class="anchored" data-anchor-id="quantization" id="quantization">Quantization</h3>
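<p>ONNX Runtime provides post-training quantization tools. Dynamic quantization stores the weights as 8-bit integers and computes activation scales on the fly at inference time, so it needs no calibration data:</p>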
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb40"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb40-1"><a href="#cb40-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> onnxruntime.quantization <span class="im">import</span> quantize_dynamic, QuantType</span>
<span id="cb40-2"><a href="#cb40-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb40-3"><a href="#cb40-3" aria-hidden="true" tabindex="-1"></a>quantize_dynamic(</span>
<span id="cb40-4"><a href="#cb40-4" aria-hidden="true" tabindex="-1"></a>    model_input<span class="op">=</span><span class="st">"model.onnx"</span>,</span>
<span id="cb40-5"><a href="#cb40-5" aria-hidden="true" tabindex="-1"></a>    model_output<span class="op">=</span><span class="st">"model_quant.onnx"</span>,</span>
<span id="cb40-6"><a href="#cb40-6" aria-hidden="true" tabindex="-1"></a>    weight_type<span class="op">=</span>QuantType.QInt8,</span>
<span id="cb40-7"><a href="#cb40-7" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
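<p>Static quantization also fixes the activation scales ahead of time, so it needs a calibration data reader that feeds representative inputs. Its <code>get_next</code> method must return one input dictionary per call and <code>None</code> once the data is exhausted:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb41"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb41-1"><a href="#cb41-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb41-2"><a href="#cb41-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> onnxruntime.quantization <span class="im">import</span> (</span>
<span id="cb41-3"><a href="#cb41-3" aria-hidden="true" tabindex="-1"></a>    CalibrationDataReader, QuantFormat, QuantType, quantize_static,</span>
<span id="cb41-4"><a href="#cb41-4" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb41-5"><a href="#cb41-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb41-6"><a href="#cb41-6" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MyCalibReader(CalibrationDataReader):</span>
<span id="cb41-7"><a href="#cb41-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, batches):</span>
<span id="cb41-8"><a href="#cb41-8" aria-hidden="true" tabindex="-1"></a>        <span class="co"># batches: list of {input_name: np.ndarray} feed dicts</span></span>
<span id="cb41-9"><a href="#cb41-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>._batches <span class="op">=</span> <span class="bu">iter</span>(batches)</span>
<span id="cb41-10"><a href="#cb41-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb41-11"><a href="#cb41-11" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> get_next(<span class="va">self</span>):</span>
<span id="cb41-12"><a href="#cb41-12" aria-hidden="true" tabindex="-1"></a>        <span class="co"># One feed dict per call; None tells the quantizer the data is exhausted</span></span>
<span id="cb41-13"><a href="#cb41-13" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="bu">next</span>(<span class="va">self</span>._batches, <span class="va">None</span>)</span>
<span id="cb41-14"><a href="#cb41-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb41-15"><a href="#cb41-15" aria-hidden="true" tabindex="-1"></a><span class="co"># Stand-in data: replace the input name, shape, and values with real samples</span></span>
<span id="cb41-16"><a href="#cb41-16" aria-hidden="true" tabindex="-1"></a>calib <span class="op">=</span> [{<span class="st">"input"</span>: np.random.rand(<span class="dv">1</span>, <span class="dv">3</span>, <span class="dv">224</span>, <span class="dv">224</span>).astype(np.float32)} <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">16</span>)]</span>
<span id="cb41-17"><a href="#cb41-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb41-18"><a href="#cb41-18" aria-hidden="true" tabindex="-1"></a>quantize_static(</span>
<span id="cb41-19"><a href="#cb41-19" aria-hidden="true" tabindex="-1"></a>    model_input<span class="op">=</span><span class="st">"model.onnx"</span>,</span>
<span id="cb41-20"><a href="#cb41-20" aria-hidden="true" tabindex="-1"></a>    model_output<span class="op">=</span><span class="st">"model_quant_static.onnx"</span>,</span>
<span id="cb41-21"><a href="#cb41-21" aria-hidden="true" tabindex="-1"></a>    calibration_data_reader<span class="op">=</span>MyCalibReader(calib),</span>
<span id="cb41-22"><a href="#cb41-22" aria-hidden="true" tabindex="-1"></a>    quant_format<span class="op">=</span>QuantFormat.QDQ,  <span class="co"># a QuantFormat (QDQ or QOperator), not a QuantType</span></span>
<span id="cb41-23"><a href="#cb41-23" aria-hidden="true" tabindex="-1"></a>    activation_type<span class="op">=</span>QuantType.QInt8,</span>
<span id="cb41-24"><a href="#cb41-24" aria-hidden="true" tabindex="-1"></a>    weight_type<span class="op">=</span>QuantType.QInt8,</span>
<span id="cb41-25"><a href="#cb41-25" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>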
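<p>What calibration buys you is a scale and zero point for each activation tensor. The NumPy sketch below shows the standard min/max affine int8 mapping that such parameters describe. It is a simplified illustration (per-tensor, signed int8, plain min/max range), not the exact procedure ONNX Runtime follows, which depends on <code>calibrate_method</code> and the other quantization options.</p>

```python
import numpy as np

QMIN, QMAX = -128, 127  # signed int8 range

def affine_qparams(x_min, x_max):
    """Scale and zero point mapping [x_min, x_max] onto [QMIN, QMAX]."""
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)  # range must contain 0 exactly
    scale = (x_max - x_min) / (QMAX - QMIN) or 1.0   # guard against an all-zero tensor
    zero_point = int(round(QMIN - x_min / scale))
    return scale, zero_point

def quantize(x, scale, zero_point):
    return np.clip(np.round(x / scale) + zero_point, QMIN, QMAX).astype(np.int8)

def dequantize(q, scale, zero_point):
    return (q.astype(np.float32) - zero_point) * scale

# Stand-in for activations collected while running calibration batches
rng = np.random.default_rng(0)
activations = rng.normal(loc=1.0, scale=2.0, size=10_000).astype(np.float32)

scale, zp = affine_qparams(float(activations.min()), float(activations.max()))
restored = dequantize(quantize(activations, scale, zp), scale, zp)
```

<p>Values inside the calibrated range come back within half a quantization step, and 0.0 maps exactly to the zero point, which matters for padding and ReLU outputs. Inputs outside the calibrated range are clipped, which is why the calibration data should be representative.</p>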
</section>
</section>
<section id="best-practices-and-common-pitfalls" class="level2">
<h2 class="anchored" data-anchor-id="best-practices-and-common-pitfalls" id="best-practices-and-common-pitfalls">Best Practices and Common Pitfalls</h2>
<section id="always-use-unique-tensor-names" class="level3">
<h3 class="anchored" data-anchor-id="always-use-unique-tensor-names" id="always-use-unique-tensor-names">Always Use Unique Tensor Names</h3>
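<p>Every tensor name in the graph must be unique. ONNX graphs are in static single assignment (SSA) form: each name is produced by exactly one node, graph input, or initializer. If two nodes write the same name, <code>onnx.checker.check_model</code> rejects the model, and tooling that skips the checker can silently wire up the wrong tensor. A simple convention is to prefix names with the layer or block name:</p>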
<pre><code>"block2_conv1_out"  rather than  "conv_out"</code></pre>
</section>
<section id="match-opset-to-your-runtime" class="level3">
<h3 class="anchored" data-anchor-id="match-opset-to-your-runtime" id="match-opset-to-your-runtime">Match Opset to Your Runtime</h3>
<p>ONNX Runtime versions support specific ONNX opset ranges. Using an opset that is too new will cause load failures. Check the ONNX Runtime release notes for the supported opset range, and pin your <code>opset_imports</code> accordingly. Opset 17 is a safe choice for most current runtimes as of 2025.</p>
</section>
<section id="initializer-vs.-graph-input-know-the-difference" class="level3">
<h3 class="anchored" data-anchor-id="initializer-vs.-graph-input-know-the-difference" id="initializer-vs.-graph-input-know-the-difference">Initializer vs.&nbsp;Graph Input: Know the Difference</h3>
<p>Initializers represent constant parameters that are part of the model. Graph inputs are external tensors provided at inference time. Do not list your weights in <code>graph.input</code>; they belong only in <code>graph.initializer</code>. A name that appears in both places is treated as an overridable input with a default value, and ONNX Runtime warns that this prevents optimizations such as constant folding.</p>
<p>In older ONNX IR versions (IR &lt; 4), initializers were required to also appear as graph inputs. From IR version 4 onward, this is no longer needed. For opset 17, set <code>model.ir_version = 8</code> and list weights only as initializers.</p>
</section>
<section id="check-data-types-carefully" class="level3">
<h3 class="anchored" data-anchor-id="check-data-types-carefully" id="check-data-types-carefully">Check Data Types Carefully</h3>
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Warning
</div>
</div>
<div class="callout-body-container callout-body">
<p>All of the following will cause silent incorrect results or runtime errors if you mix them up:</p>
<ul>
<li>Mixing <code>float32</code> and <code>float64</code> inputs/weights without an explicit <code>Cast</code>.</li>
<li>Passing NumPy's default integer arrays (<code>int64</code> on most platforms) where the model expects <code>int32</code>.</li>
<li>Passing NHWC image data to a Conv that expects NCHW.</li>
</ul>
</div>
</div>
<p>Always verify numpy dtypes when constructing initializers:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb43"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb43-1"><a href="#cb43-1" aria-hidden="true" tabindex="-1"></a>W <span class="op">=</span> my_array.astype(np.float32)  <span class="co"># always explicit</span></span></code></pre></div></div>
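<p>The same discipline applies to runtime inputs. A minimal preprocessing sketch, assuming a uint8 NHWC image batch and a model that expects float32 NCHW values scaled to [0, 1] (the scaling is an assumption; use whatever normalization your model was trained with):</p>

```python
import numpy as np

def to_nchw_float32(images_nhwc: np.ndarray) -> np.ndarray:
    """uint8 [N, H, W, C] -> float32 [N, C, H, W] in [0, 1]."""
    if images_nhwc.ndim != 4:
        raise ValueError(f"expected a 4-D NHWC batch, got shape {images_nhwc.shape}")
    x = images_nhwc.astype(np.float32) / 255.0            # explicit dtype, never rely on defaults
    return np.ascontiguousarray(x.transpose(0, 3, 1, 2))  # move channels to axis 1

def to_int32(indices) -> np.ndarray:
    """Index tensors: NumPy usually defaults to int64, so cast when the model declares int32."""
    return np.asarray(indices).astype(np.int32)

batch = np.random.default_rng(0).integers(0, 256, size=(2, 8, 6, 3), dtype=np.uint8)
x = to_nchw_float32(batch)
```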
</section>
<section id="pads-are-symmetric-lists-not-single-values" class="level3">
<h3 class="anchored" data-anchor-id="pads-are-symmetric-lists-not-single-values" id="pads-are-symmetric-lists-not-single-values">Pads Are Begin/End Lists, Not Single Values</h3>
<p>The <code>pads</code> attribute on <code>Conv</code> and <code>MaxPool</code> is a flat list holding every begin value followed by every end value: <code>[pad_h_begin, pad_w_begin, pad_h_end, pad_w_end]</code> for 2D, and <code>[pad_d_begin, pad_h_begin, pad_w_begin, pad_d_end, pad_h_end, pad_w_end]</code> for 3D. Do not pass a single integer; symmetric padding of 1 on a 2D kernel is <code>[1, 1, 1, 1]</code>.</p>
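<p>A small sketch of building the list and checking the resulting spatial size. The helper names are illustrative; the size function is the standard convolution output-size formula with explicit begin/end padding, stride, and dilation (floor division):</p>

```python
def onnx_pads(begin, end):
    """Flatten per-axis padding into ONNX order: all begin values, then all end values."""
    if len(begin) != len(end):
        raise ValueError("begin and end must cover the same spatial axes")
    return list(begin) + list(end)

def symmetric_pads(pad: int, spatial_dims: int = 2):
    """PyTorch-style `padding=pad` expressed as an ONNX pads list."""
    return onnx_pads([pad] * spatial_dims, [pad] * spatial_dims)

def conv_output_size(size, kernel, stride=1, pad_begin=0, pad_end=0, dilation=1):
    """Output length along one spatial axis for explicit begin/end padding."""
    return (size + pad_begin + pad_end - dilation * (kernel - 1) - 1) // stride + 1

# Asymmetric 2D padding: one extra row at the bottom and one extra column at the right
pads = onnx_pads(begin=[0, 0], end=[1, 1])  # [pad_h_begin, pad_w_begin, pad_h_end, pad_w_end]
```

<p>For example, an 11×11 kernel with stride 4 and padding 2 on a 224-pixel axis gives <code>conv_output_size(224, 11, stride=4, pad_begin=2, pad_end=2)</code>, which is 55.</p>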
</section>
<section id="use-squeeze-and-unsqueeze-on-axes-inputs-opset-13" class="level3">
<h3 class="anchored" data-anchor-id="use-squeeze-and-unsqueeze-on-axes-inputs-opset-13" id="use-squeeze-and-unsqueeze-on-axes-inputs-opset-13">Pass <code>axes</code> to <code>Squeeze</code> and <code>Unsqueeze</code> as an Input (Opset ≥ 13)</h3>
<p>In ONNX opset 13+, the <code>axes</code> argument to <code>Squeeze</code> and <code>Unsqueeze</code> moved from an attribute to an input tensor, so you must supply it as a constant INT64 tensor. The input is required for <code>Unsqueeze</code>; for <code>Squeeze</code> it is optional, and omitting it removes every dimension of size 1:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb44"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb44-1"><a href="#cb44-1" aria-hidden="true" tabindex="-1"></a>axes_const <span class="op">=</span> numpy_helper.from_array(np.array([<span class="dv">0</span>], dtype<span class="op">=</span>np.int64), name<span class="op">=</span><span class="st">"squeeze_axes"</span>)</span>
<span id="cb44-2"><a href="#cb44-2" aria-hidden="true" tabindex="-1"></a>inits.append(axes_const)</span>
<span id="cb44-3"><a href="#cb44-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb44-4"><a href="#cb44-4" aria-hidden="true" tabindex="-1"></a>squeeze_node <span class="op">=</span> helper.make_node(<span class="st">"Squeeze"</span>,</span>
<span id="cb44-5"><a href="#cb44-5" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"my_tensor"</span>, <span class="st">"squeeze_axes"</span>],</span>
<span id="cb44-6"><a href="#cb44-6" aria-hidden="true" tabindex="-1"></a>    outputs<span class="op">=</span>[<span class="st">"squeezed"</span>],</span>
<span id="cb44-7"><a href="#cb44-7" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"squeeze"</span>,</span>
<span id="cb44-8"><a href="#cb44-8" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
</section>
<section id="reshape-takes-a-tensor-input-not-an-attribute" class="level3">
<h3 class="anchored" data-anchor-id="reshape-takes-a-tensor-input-not-an-attribute" id="reshape-takes-a-tensor-input-not-an-attribute"><code>Reshape</code> Takes a Tensor Input, Not an Attribute</h3>
<p>In opset 5+, the target shape for <code>Reshape</code> is a 1D INT64 tensor input, not an attribute. Store it as an initializer:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb45"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb45-1"><a href="#cb45-1" aria-hidden="true" tabindex="-1"></a>target_shape <span class="op">=</span> numpy_helper.from_array(np.array([<span class="op">-</span><span class="dv">1</span>, <span class="dv">128</span>], dtype<span class="op">=</span>np.int64), name<span class="op">=</span><span class="st">"tgt_shape"</span>)</span>
<span id="cb45-2"><a href="#cb45-2" aria-hidden="true" tabindex="-1"></a>inits.append(target_shape)</span>
<span id="cb45-3"><a href="#cb45-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb45-4"><a href="#cb45-4" aria-hidden="true" tabindex="-1"></a>reshape_node <span class="op">=</span> helper.make_node(<span class="st">"Reshape"</span>,</span>
<span id="cb45-5"><a href="#cb45-5" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[<span class="st">"flat_input"</span>, <span class="st">"tgt_shape"</span>],</span>
<span id="cb45-6"><a href="#cb45-6" aria-hidden="true" tabindex="-1"></a>    outputs<span class="op">=</span>[<span class="st">"reshaped"</span>],</span>
<span id="cb45-7"><a href="#cb45-7" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
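<p>The shape tensor may also contain the special values <code>-1</code> (infer this dimension from the remaining size) and <code>0</code> (copy the input dimension at the same position, unless the <code>allowzero</code> attribute added in opset 14 is set). A pure-Python sketch of how those rules resolve, useful for sanity-checking a shape constant before baking it into the graph:</p>

```python
import math

def resolve_reshape(input_shape, target, allowzero=False):
    """Resolve an ONNX Reshape target (with 0 and -1 placeholders) into concrete dims."""
    out = []
    for i, dim in enumerate(target):
        if dim == 0 and not allowzero:
            out.append(input_shape[i])  # 0 copies the input dimension at index i
        else:
            out.append(dim)
    if out.count(-1) > 1:
        raise ValueError("at most one dimension may be -1")
    total = math.prod(input_shape)
    if -1 in out:
        known = math.prod(d for d in out if d != -1)
        if known == 0 or total % known:
            raise ValueError(f"cannot infer -1 for {list(input_shape)} -> {list(target)}")
        out[out.index(-1)] = total // known
    if math.prod(out) != total:
        raise ValueError(f"element count mismatch: {list(input_shape)} -> {out}")
    return out
```

<p>For instance, the <code>[-1, 128]</code> target above turns a <code>[4, 256]</code> input into <code>[8, 128]</code>.</p>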
</section>
<section id="profile-before-optimizing" class="level3">
<h3 class="anchored" data-anchor-id="profile-before-optimizing" id="profile-before-optimizing">Profile Before Optimizing</h3>
<p>ONNX Runtime provides built-in profiling. Enable it to find bottleneck operators before spending time on manual optimizations:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb46"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb46-1"><a href="#cb46-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> onnxruntime <span class="im">as</span> ort</span>
<span id="cb46-2"><a href="#cb46-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb46-3"><a href="#cb46-3" aria-hidden="true" tabindex="-1"></a>opts <span class="op">=</span> ort.SessionOptions()</span>
<span id="cb46-4"><a href="#cb46-4" aria-hidden="true" tabindex="-1"></a>opts.enable_profiling <span class="op">=</span> <span class="va">True</span></span>
<span id="cb46-5"><a href="#cb46-5" aria-hidden="true" tabindex="-1"></a>sess <span class="op">=</span> ort.InferenceSession(<span class="st">"model.onnx"</span>, sess_options<span class="op">=</span>opts, providers<span class="op">=</span>[<span class="st">"CPUExecutionProvider"</span>])</span>
<span id="cb46-6"><a href="#cb46-6" aria-hidden="true" tabindex="-1"></a>sess.run(<span class="va">None</span>, feeds)  <span class="co"># feeds: your {input_name: np.ndarray} dict; run a few representative batches</span></span>
<span id="cb46-7"><a href="#cb46-7" aria-hidden="true" tabindex="-1"></a>prof_file <span class="op">=</span> sess.end_profiling()  <span class="co"># returns path to JSON profile</span></span>
<span id="cb46-8"><a href="#cb46-8" aria-hidden="true" tabindex="-1"></a><span class="co"># Open it in Chrome at chrome://tracing or in the Perfetto UI</span></span></code></pre></div></div>
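<p>The profile is a JSON array of trace events, so a few lines of Python can rank operators by total kernel time. This sketch assumes the event layout that ONNX Runtime profiles use for per-node timings (events with <code>"cat": "Node"</code>, a name ending in <code>_kernel_time</code>, a <code>dur</code> in microseconds, and the operator type under <code>args.op_name</code>); check it against a real profile file before relying on it.</p>

```python
import json
from collections import defaultdict

def summarize_events(events, top=10):
    """Total kernel time (microseconds) per operator type, slowest first."""
    totals = defaultdict(int)
    for ev in events:
        # Keep per-node kernel events; skip session-level and fence events
        if ev.get("cat") == "Node" and ev.get("name", "").endswith("_kernel_time"):
            totals[ev.get("args", {}).get("op_name", "unknown")] += ev.get("dur", 0)
    return sorted(totals.items(), key=lambda kv: kv[1], reverse=True)[:top]

def op_time_summary(profile_path, top=10):
    """Load an ONNX Runtime profile file (e.g. the path from end_profiling) and summarize it."""
    with open(profile_path) as f:
        return summarize_events(json.load(f), top)

# Usage: for op, micros in op_time_summary(prof_file): print(f"{op:20s} {micros} us")
```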
</section>
</section>
<section id="reference-commonly-used-onnx-operators" class="level2">
<h2 class="anchored" data-anchor-id="reference-commonly-used-onnx-operators" id="reference-commonly-used-onnx-operators">Reference: Commonly Used ONNX Operators</h2>
<p>Below is a quick-reference table of the operators used most frequently in architecture construction, with their key attributes and input/output conventions.</p>
<table class="caption-top table">
<caption>Commonly used ONNX operators with key attributes and output shapes</caption>
<colgroup>
<col style="width: 25%">
<col style="width: 25%">
<col style="width: 25%">
<col style="width: 25%">
</colgroup>
<thead>
<tr class="header">
<th>Operator</th>
<th>Key Inputs</th>
<th>Key Attributes</th>
<th>Output Shape (example)</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><code>Gemm</code></td>
<td>A, B, C (bias)</td>
<td><code>transA</code>, <code>transB</code>, <code>alpha</code>, <code>beta</code></td>
<td><code>[M, N]</code></td>
</tr>
<tr class="even">
<td><code>MatMul</code></td>
<td>A, B</td>
<td>—</td>
<td><code>[..., M, N]</code></td>
</tr>
<tr class="odd">
<td><code>Conv</code></td>
<td>X, W, B</td>
<td><code>kernel_shape</code>, <code>strides</code>, <code>pads</code>, <code>dilations</code>, <code>group</code></td>
<td><code>[N, C_out, H_out, W_out]</code></td>
</tr>
<tr class="even">
<td><code>ConvTranspose</code></td>
<td>X, W, B</td>
<td><code>kernel_shape</code>, <code>strides</code>, <code>pads</code>, <code>output_padding</code></td>
<td><code>[N, C_out, H_out, W_out]</code></td>
</tr>
<tr class="odd">
<td><code>BatchNormalization</code></td>
<td>X, scale, B, mean, var</td>
<td><code>epsilon</code>, <code>momentum</code></td>
<td>same as X</td>
</tr>
<tr class="even">
<td><code>LayerNormalization</code></td>
<td>X, scale, B</td>
<td><code>axis</code>, <code>epsilon</code></td>
<td>same as X</td>
</tr>
<tr class="odd">
<td><code>Relu</code></td>
<td>X</td>
<td>—</td>
<td>same as X</td>
</tr>
<tr class="even">
<td><code>Sigmoid</code></td>
<td>X</td>
<td>—</td>
<td>same as X</td>
</tr>
<tr class="odd">
<td><code>Tanh</code></td>
<td>X</td>
<td>—</td>
<td>same as X</td>
</tr>
<tr class="even">
<td><code>Softmax</code></td>
<td>X</td>
<td><code>axis</code></td>
<td>same as X</td>
</tr>
<tr class="odd">
<td><code>Gelu</code> (opset ≥ 20)</td>
<td>X</td>
<td><code>approximate</code></td>
<td>same as X</td>
</tr>
<tr class="even">
<td><code>MaxPool</code></td>
<td>X</td>
<td><code>kernel_shape</code>, <code>strides</code>, <code>pads</code></td>
<td><code>[N, C, H_out, W_out]</code></td>
</tr>
<tr class="odd">
<td><code>GlobalAveragePool</code></td>
<td>X</td>
<td>—</td>
<td><code>[N, C, 1, 1]</code></td>
</tr>
<tr class="even">
<td><code>Reshape</code></td>
<td>data, shape</td>
<td>—</td>
<td>as specified by <code>shape</code></td>
</tr>
<tr class="odd">
<td><code>Flatten</code></td>
<td>X</td>
<td><code>axis</code></td>
<td><code>[N, M]</code></td>
</tr>
<tr class="even">
<td><code>Transpose</code></td>
<td>X</td>
<td><code>perm</code></td>
<td>permuted axes</td>
</tr>
<tr class="odd">
<td><code>Squeeze</code></td>
<td>X, axes</td>
<td>—</td>
<td>removes specified dims</td>
</tr>
<tr class="even">
<td><code>Unsqueeze</code></td>
<td>X, axes</td>
<td>—</td>
<td>inserts specified dims</td>
</tr>
<tr class="odd">
<td><code>Concat</code></td>
<td>inputs…</td>
<td><code>axis</code></td>
<td>concatenated</td>
</tr>
<tr class="even">
<td><code>Split</code></td>
<td>X, split (an input from opset 13)</td>
<td><code>axis</code> (plus <code>num_outputs</code> from opset 18)</td>
<td>list of tensors</td>
</tr>
<tr class="odd">
<td><code>Add</code></td>
<td>A, B</td>
<td>—</td>
<td>broadcast shape</td>
</tr>
<tr class="even">
<td><code>Mul</code></td>
<td>A, B</td>
<td>—</td>
<td>broadcast shape</td>
</tr>
<tr class="odd">
<td><code>ReduceMean</code></td>
<td>X (plus <code>axes</code> as an input from opset 18)</td>
<td><code>axes</code> (attribute before opset 18), <code>keepdims</code></td>
<td>reduced shape</td>
</tr>
<tr class="even">
<td><code>Cast</code></td>
<td>X</td>
<td><code>to</code> (dtype enum)</td>
<td>same shape, new dtype</td>
</tr>
<tr class="odd">
<td><code>Gather</code></td>
<td>data, indices</td>
<td><code>axis</code></td>
<td>indexed shape</td>
</tr>
<tr class="even">
<td><code>LSTM</code></td>
<td>X, W, R, B</td>
<td><code>hidden_size</code>, <code>direction</code></td>
<td><code>Y</code>, <code>Y_h</code>, <code>Y_c</code></td>
</tr>
<tr class="odd">
<td><code>GRU</code></td>
<td>X, W, R, B</td>
<td><code>hidden_size</code>, <code>direction</code></td>
<td><code>Y</code>, <code>Y_h</code></td>
</tr>
<tr class="even">
<td><code>Where</code></td>
<td>cond, X, Y</td>
<td>—</td>
<td>broadcast shape</td>
</tr>
<tr class="odd">
<td><code>Einsum</code></td>
<td>inputs…</td>
<td><code>equation</code></td>
<td>per equation</td>
</tr>
<tr class="even">
<td><code>Constant</code></td>
<td>—</td>
<td><code>value</code></td>
<td>shape of value</td>
</tr>
<tr class="odd">
<td><code>Shape</code></td>
<td>X</td>
<td>—</td>
<td><code>[rank(X)]</code> INT64</td>
</tr>
<tr class="even">
<td><code>Expand</code></td>
<td>X, shape</td>
<td>—</td>
<td>broadcast target shape</td>
</tr>
</tbody>
</table>
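<p>Most rows in the table map directly onto NumPy. As one example, <code>Gemm</code> computes <code>alpha * A' @ B' + beta * C</code>, where <code>A'</code> and <code>B'</code> are optionally transposed by <code>transA</code> and <code>transB</code>. The sketch below mirrors that definition (a reference for checking shapes, not a runtime kernel) and shows why <code>transB=1</code> lets a PyTorch-style <code>[out_features, in_features]</code> weight be used without a separate <code>Transpose</code>:</p>

```python
import numpy as np

def gemm(a, b, c=None, alpha=1.0, beta=1.0, trans_a=0, trans_b=0):
    """Y = alpha * A' @ B' + beta * C, with C broadcast to [M, N]."""
    a = a.T if trans_a else a
    b = b.T if trans_b else b
    y = alpha * (a @ b)
    if c is not None:
        y = y + beta * c
    return y

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3))       # [M, K] batch of inputs
weight = rng.standard_normal((4, 3))  # [N, K], PyTorch Linear layout
bias = rng.standard_normal(4)         # [N], broadcast across rows
y = gemm(x, weight, bias, trans_b=1)
```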
</section>
<section id="further-reading" class="level2">
<h2 class="anchored" data-anchor-id="further-reading" id="further-reading">Further Reading</h2>
<p>For everything beyond this guide, the following resources are authoritative:</p>
<ul>
<li><strong>ONNX Operator Specification</strong>: <a href="https://onnx.ai/onnx/operators/">https://onnx.ai/onnx/operators/</a> — the canonical reference for every operator, every opset version, and every attribute’s exact semantics.</li>
<li><strong>ONNX Protobuf Schema</strong>: <a href="https://github.com/onnx/onnx/blob/main/onnx/onnx.proto">https://github.com/onnx/onnx/blob/main/onnx/onnx.proto</a></li>
<li><strong>ONNX Runtime Documentation</strong>: <a href="https://onnxruntime.ai/docs/">https://onnxruntime.ai/docs/</a></li>
<li><strong>ONNX Runtime Python API Reference</strong>: <a href="https://onnxruntime.ai/docs/api/python/api_summary.html">https://onnxruntime.ai/docs/api/python/api_summary.html</a></li>
<li><strong>Netron Visualizer</strong>: <a href="https://netron.app/">https://netron.app/</a></li>
</ul>
<hr>
<div class="callout callout-style-simple callout-note no-icon">
<div class="callout-body d-flex">
<div class="callout-icon-container">
<i class="callout-icon no-icon"></i>
</div>
<div class="callout-body-container">
<p>Guide written for ONNX opset 17, ONNX Runtime 1.18+, and Python 3.10+. All code examples use <code>numpy</code> 1.24+ and <code>onnx</code> 1.15+.</p>
</div>
</div>
</div>



</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Milvus for Computer Vision: An In-Depth Guide]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/deployment/milvus/milvus-computer-vision-guide/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/deployment/milvus/milvus-computer-vision-guide/</guid>
      <pubDate>Thu, 07 May 2026 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>mlops</category>
      <category>advanced</category>
      <content:encoded><![CDATA[






<section id="milvus-for-computer-vision-an-in-depth-guide" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/deployment/milvus/milvus-computer-vision-guide/milvus.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>I’ve spent a fair bit of time working on computer vision systems — the kind that start small, manageable, and almost deceptively simple, and then quietly spiral in scale until the infrastructure holding them together starts creaking at the seams. For a while I was getting by with fairly standard approaches: storing image embeddings in flat files, querying with NumPy, eventually graduating to something like FAISS. It worked. Until it didn’t.</p>
<p>The turning point came when the dataset crossed a threshold where search latency, even with approximate methods, became genuinely painful in production. I needed something that could handle tens of millions of vectors, support filtered queries alongside similarity search, and not require me to completely rebuild the data layer every few months as the system grew. That’s when I came across Milvus.</p>
<p>What struck me first was how deliberately it had been designed for exactly this class of problem. It wasn’t a general-purpose database with a vector plugin bolted on — it was built from the ground up around the idea that your primary query is “find me things that look like this,” and everything else (filtering, metadata, persistence, scalability) flows from that. Getting started was surprisingly approachable once I understood the core concepts, and scaling from a local prototype to a distributed deployment was far more incremental than I’d expected.</p>
<p>This guide is what I wish I’d had when I started. It covers Milvus from the very beginning — what vector databases are, how embeddings work, and why you need dedicated infrastructure for this kind of search — all the way through four real computer vision use cases, three deployment modes, and the performance tuning details that actually matter in practice. Whether you’re prototyping on a laptop or planning a production system handling billions of vectors, the path forward is here.</p>
<hr>
</section>
<section id="table-of-contents" class="level2">
<h2 class="anchored" data-anchor-id="table-of-contents" id="table-of-contents">Table of Contents</h2>
<ol type="1">
<li><a href="#sec-whatisit">What Is a Vector Database — And Why Do You Need One?</a></li>
<li><a href="#sec-milvus">Introducing Milvus</a></li>
<li><a href="#sec-concepts">Core Concepts You Must Understand</a></li>
<li><a href="#sec-cvpipeline">How Computer Vision Meets Vector Search</a></li>
<li><a href="#sec-setup">Setting Up Your Environment</a></li>
<li><a href="#sec-deployment">Deployment Options: Lite → Docker → Kubernetes</a></li>
<li><a href="#sec-collections">Working with Collections and Schemas</a></li>
<li><a href="#sec-inserting">Inserting Embedding Vectors</a></li>
<li><a href="#sec-indexes">Index Types and When to Use Each</a></li>
<li><a href="#sec-querying">Querying and Searching</a></li>
<li><a href="#sec-imgsimilarity">Use Case 1 — Image Similarity Search</a></li>
<li><a href="#sec-facerecog">Use Case 2 — Face Recognition</a></li>
<li><a href="#sec-objectdetect">Use Case 3 — Object Detection &amp; Retrieval</a></li>
<li><a href="#sec-videosearch">Use Case 4 — Video Frame Search</a></li>
<li><a href="#sec-partitions">Partitions, Filtering, and Hybrid Search</a></li>
<li><a href="#sec-performance">Performance Tuning and Best Practices</a></li>
<li><a href="#sec-security">Security and Access Control</a></li>
<li><a href="#sec-monitoring">Monitoring and Observability</a></li>
<li><a href="#sec-pitfalls">Common Pitfalls and How to Avoid Them</a></li>
<li><a href="#sec-glossary">Glossary</a></li>
</ol>
<hr>
</section>
<section id="sec-whatisit" class="level2">
<h2 class="anchored" data-anchor-id="sec-whatisit" id="sec-whatisit">1. What Is a Vector Database — And Why Do You Need One?</h2>
<section id="the-problem-with-traditional-databases" class="level3">
<h3 class="anchored" data-anchor-id="the-problem-with-traditional-databases" id="the-problem-with-traditional-databases">The Problem with Traditional Databases</h3>
<p>Traditional relational databases (PostgreSQL, MySQL, SQLite) store and retrieve data that is <strong>exactly defined</strong> — rows, columns, integers, strings, dates. When you want to find a user named “Alice,” you write:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode sql code-with-copy"><code class="sourceCode sql"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="kw">SELECT</span> <span class="op">*</span> <span class="kw">FROM</span> users <span class="kw">WHERE</span> name <span class="op">=</span> <span class="st">'Alice'</span>;</span></code></pre></div></div>
<p>This works perfectly for exact matches. But computer vision operates in an entirely different paradigm. Imagine you have a photo of a dog and you want to find all similar-looking dogs in a database of one million photos. There is no exact match to look for. The question is not “find this exact image” — it is “find images that look like this image.”</p>
<p>Traditional databases cannot answer that question efficiently. You could compare pixel-by-pixel, but that would be catastrophically slow and would fail even for the same dog photographed twice under different lighting conditions.</p>
</section>
<section id="the-role-of-embeddings" class="level3">
<h3 class="anchored" data-anchor-id="the-role-of-embeddings" id="the-role-of-embeddings">The Role of Embeddings</h3>
<p>The key insight that makes modern computer vision work is this: <strong>neural networks can compress the semantic meaning of an image into a compact numerical vector</strong> — called an <strong>embedding</strong> or <strong>feature vector</strong>.</p>
<p>An embedding is simply a list of floating-point numbers. For example, a 512-dimensional embedding is a list of 512 floats:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a>[<span class="fl">0.023</span>, <span class="op">-</span><span class="fl">0.412</span>, <span class="fl">0.881</span>, <span class="fl">0.003</span>, <span class="op">-</span><span class="fl">0.667</span>, ..., <span class="fl">0.142</span>]  <span class="co"># 512 values total</span></span></code></pre></div></div>
<p>What makes embeddings so useful is that a well-trained network learns to place <strong>semantically similar images close together</strong> in this high-dimensional space. With a face-recognition model, two photos of the same person taken from different angles and under different lighting produce embeddings that are numerically close to each other. With a general-purpose image model, a cat and a dog typically end up closer to each other than a cat and an airplane.</p>
<p>“Close” in this context is measured by mathematical distance functions:</p>
<ul>
<li><strong>Cosine similarity</strong> — measures the angle between two vectors (ignores magnitude; good for normalized embeddings)</li>
<li><strong>Euclidean distance (L2)</strong> — measures the straight-line distance between two points in space</li>
<li><strong>Inner product (IP)</strong> — dot product; useful for recommendation systems and unnormalized embeddings</li>
</ul>
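<p>To make the three metrics concrete, here is a minimal pure-Python sketch. The 3-dimensional “embeddings” are made-up toy numbers chosen for illustration (real embeddings have hundreds of dimensions), and the function names are hypothetical helpers, not part of any library.</p>

```python
import math

def l2_distance(a: list[float], b: list[float]) -> float:
    """Euclidean (L2) distance: smaller means more similar."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def inner_product(a: list[float], b: list[float]) -> float:
    """Dot product: larger means more similar."""
    return sum(x * y for x, y in zip(a, b))

def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine of the angle between a and b (assumes non-zero vectors): larger means more similar."""
    return inner_product(a, b) / (math.sqrt(inner_product(a, a)) * math.sqrt(inner_product(b, b)))

# Toy 3-D "embeddings" (invented numbers, for illustration only)
cat = [0.9, 0.1, 0.0]
dog = [0.8, 0.3, 0.1]
airplane = [0.0, 0.2, 0.95]

print("cosine(cat, dog):     ", round(cosine_similarity(cat, dog), 3))
print("cosine(cat, airplane):", round(cosine_similarity(cat, airplane), 3))
print("L2(cat, dog):         ", round(l2_distance(cat, dog), 3))
print("L2(cat, airplane):    ", round(l2_distance(cat, airplane), 3))
```

<p>Note the opposite directions: for L2 a smaller number means “more similar,” while for inner product and cosine a larger number does.</p>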
</section>
<section id="why-you-need-a-dedicated-vector-database" class="level3">
<h3 class="anchored" data-anchor-id="why-you-need-a-dedicated-vector-database" id="why-you-need-a-dedicated-vector-database">Why You Need a Dedicated Vector Database</h3>
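<p>Once you have millions of embeddings, the core operation is answering “find me the k vectors nearest to this query vector” as quickly as possible. This is <strong>k-nearest-neighbor (k-NN) search</strong>; at scale it is usually answered approximately, which is called <strong>Approximate Nearest Neighbor (ANN) search</strong>.</p>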
<p>A naive approach (compare query against every single vector) is called <strong>exact search</strong> or <strong>brute-force search</strong>. It works fine for thousands of vectors, but:</p>
<ul>
<li>At 1 million vectors of 512 dimensions, a brute-force search involves 512 million floating-point multiplications per query</li>
<li>At 100 million vectors, this becomes computationally untenable for real-time applications</li>
</ul>
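<p>The brute-force approach is easy to write, which is exactly why its cost is easy to see: every query touches every stored vector. A minimal pure-Python sketch (the data is random, and <code>brute_force_knn</code> is a hypothetical helper):</p>

```python
import heapq
import math
import random

def brute_force_knn(query, vectors, k):
    """Exact k-NN by L2 distance: scores every stored vector, O(N * D) work per query."""
    scored = ((math.dist(query, v), i) for i, v in enumerate(vectors))
    # nsmallest keeps only k candidates in a heap instead of sorting all N scores
    return [i for _, i in heapq.nsmallest(k, scored)]

random.seed(0)
dim = 8
database = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(1_000)]
query = database[42]  # a stored vector should come back as its own nearest neighbor

print(brute_force_knn(query, database, k=5))
```

<p>Each call computes N distances of D dimensions each; with N = 1,000,000 and D = 512 that is the 512 million multiplications mentioned above, for every single query.</p>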
<p><strong>Vector databases</strong> solve this by building <strong>indexes</strong> — clever data structures that allow you to skip most of the comparisons and still find results that are very close to the true nearest neighbors. This is the “approximate” in ANN: you trade a small amount of accuracy for enormous speed gains.</p>
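<p>The accuracy side of that trade is usually measured as <strong>recall@k</strong>: the fraction of the true k nearest neighbors (found by exact search) that the approximate search also returned. A small sketch with hypothetical result lists:</p>

```python
def recall_at_k(approx_ids: list[int], exact_ids: list[int]) -> float:
    """Fraction of the exact top-k neighbors that the approximate search recovered."""
    k = len(exact_ids)
    return len(set(approx_ids) & set(exact_ids)) / k

exact = [7, 3, 9, 1, 4]    # hypothetical brute-force top-5
approx = [7, 3, 1, 12, 4]  # hypothetical ANN top-5: missed 9, returned 12 instead

print(recall_at_k(approx, exact))  # 0.8
```

<p>Index tuning is largely a matter of choosing how much recall you are willing to give up for how much speed.</p>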
<p>Milvus is one of the most powerful, production-ready, and feature-rich open-source vector databases available today.</p>
<hr>
</section>
</section>
<section id="sec-milvus" class="level2">
<h2 class="anchored" data-anchor-id="sec-milvus" id="sec-milvus">2. Introducing Milvus</h2>
<section id="what-is-milvus" class="level3">
<h3 class="anchored" data-anchor-id="what-is-milvus" id="what-is-milvus">What Is Milvus?</h3>
<p>Milvus is an <strong>open-source vector database</strong> built specifically for storing, indexing, and searching high-dimensional vector embeddings at massive scale. It was originally created by Zilliz and donated to the LF AI &amp; Data Foundation.</p>
<p>Key properties of Milvus:</p>
<ul>
<li>Stores <strong>billions of vectors</strong> with sub-second query latency</li>
<li>Supports multiple <strong>index algorithms</strong> (IVF, HNSW, FLAT, ScaNN, DiskANN, and more)</li>
<li>Supports <strong>multiple distance metrics</strong> (L2, IP, Cosine)</li>
<li>Has a <strong>rich filtering system</strong> — combine vector search with scalar attribute filters (like SQL WHERE clauses)</li>
<li>Supports <strong>multi-tenancy</strong> through partitions and collections</li>
<li>Offers <strong>three deployment modes</strong>: Milvus Lite (local, no server), Standalone (single-node Docker), and Distributed (Kubernetes cluster)</li>
<li>First-class <strong>Python SDK</strong> (PyMilvus), plus SDKs for Go, Java, and Node.js, and a RESTful API</li>
</ul>
</section>
<section id="milvus-vs.-alternatives" class="level3">
<h3 class="anchored" data-anchor-id="milvus-vs.-alternatives" id="milvus-vs.-alternatives">Milvus vs.&nbsp;Alternatives</h3>
<table class="caption-top table">
<colgroup>
<col style="width: 20%">
<col style="width: 20%">
<col style="width: 20%">
<col style="width: 20%">
<col style="width: 20%">
</colgroup>
<thead>
<tr class="header">
<th>Feature</th>
<th>Milvus</th>
<th>Pinecone</th>
<th>Weaviate</th>
<th>pgvector</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>Open source</td>
<td>✅</td>
<td>❌ (cloud only)</td>
<td>✅</td>
<td>✅</td>
</tr>
<tr class="even">
<td>Scale</td>
<td>Billions</td>
<td>Millions</td>
<td>Millions</td>
<td>Millions</td>
</tr>
<tr class="odd">
<td>Deployment</td>
<td>Lite/Docker/K8s</td>
<td>Managed cloud</td>
<td>Docker/K8s</td>
<td>PostgreSQL extension</td>
</tr>
<tr class="even">
<td>Hybrid filtering</td>
<td>✅ Rich</td>
<td>✅</td>
<td>✅</td>
<td>✅</td>
</tr>
<tr class="odd">
<td>GPU indexing</td>
<td>✅</td>
<td>❌</td>
<td>❌</td>
<td>❌</td>
</tr>
<tr class="even">
<td>Best for</td>
<td>Production scale</td>
<td>Quick SaaS start</td>
<td>Semantic search</td>
<td>Existing Postgres apps</td>
</tr>
</tbody>
</table>
<p>For computer vision at scale, Milvus is a leading choice because of its support for very large datasets, GPU-accelerated indexing, and mature Python ecosystem.</p>
</section>
<section id="milvus-architecture-overview" class="level3">
<h3 class="anchored" data-anchor-id="milvus-architecture-overview" id="milvus-architecture-overview">Milvus Architecture Overview</h3>
<p>Milvus has a <strong>layered, disaggregated architecture</strong> — each layer can be scaled independently:</p>
<div class="cell" data-eval="true" data-layout-align="center">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">graph TD
    A["Client (SDK / REST)"]
    B["Access Layer: Proxy nodes (load balancing, routing)"]
    C["Coordinator Layer: RootCoord · QueryCoord · DataCoord · IndexCoord"]
    D["Worker Layer: QueryNode · DataNode · IndexNode"]
    E["Storage Layer: etcd (metadata) · MinIO/S3 (object store) · Message Queue (Pulsar/Kafka)"]

    A --&gt; B
    B --&gt; C
    C --&gt; D
    D --&gt; E

    style A fill:#4A90D9,color:#fff,stroke:#2c6faa
    style B fill:#5BA85A,color:#fff,stroke:#3d7a3d
    style C fill:#E8A838,color:#fff,stroke:#b07a1a
    style D fill:#D95F5F,color:#fff,stroke:#a03030
    style E fill:#8B6BB1,color:#fff,stroke:#5c3d8a
</pre>
</div>
<p></p></figure><p></p>
</div>
</div>
</div>
<p><strong>In plain English:</strong></p>
<ul>
<li><strong>Proxy nodes</strong> receive client requests and route them</li>
<li><strong>Coordinators</strong> manage cluster metadata, query planning, and data distribution</li>
<li><strong>Worker nodes</strong> do the actual heavy lifting: storing data, building indexes, executing searches</li>
<li><strong>Storage</strong> is separated from compute — data lives in object storage (S3/MinIO), metadata in etcd</li>
</ul>
<p>This separation is what allows Milvus to scale each component independently. You can add more QueryNodes to handle more queries without touching DataNodes.</p>
<hr>
</section>
</section>
<section id="sec-concepts" class="level2">
<h2 class="anchored" data-anchor-id="sec-concepts" id="sec-concepts">3. Core Concepts You Must Understand</h2>
<p>Before writing a single line of code, you need to internalize these concepts. They map to familiar database concepts but have important differences.</p>
<section id="collection" class="level3">
<h3 class="anchored" data-anchor-id="collection" id="collection">3.1 Collection</h3>
<p>A <strong>collection</strong> in Milvus is analogous to a <strong>table</strong> in a relational database. It is the top-level container that holds your data.</p>
<p>Each collection has:</p>
<ul>
<li>A <strong>schema</strong> — defines the fields (columns) and their types</li>
<li>One or more <strong>indexes</strong> — built on the vector field(s) to enable fast ANN search</li>
<li>Optional <strong>partitions</strong> — logical subdivisions within a collection</li>
</ul>
<p><strong>Example analogy:</strong> SQL Table <code>face_embeddings</code> → Milvus Collection <code>face_embeddings</code></p>
</section>
<section id="schema-and-fields" class="level3">
<h3 class="anchored" data-anchor-id="schema-and-fields" id="schema-and-fields">3.2 Schema and Fields</h3>
<p>A Milvus schema defines the structure of every entity (row) in the collection. Each schema must have:</p>
<ol type="1">
<li><strong>A primary key field</strong> — a unique ID for each entity. Can be <code>INT64</code> (auto-generated or user-provided) or <code>VARCHAR</code>.</li>
<li><strong>At least one vector field</strong> — stores the embedding. Must specify the number of dimensions.</li>
<li><strong>Optional scalar fields</strong> — additional metadata like file path, label, timestamp, confidence score.</li>
</ol>
<p>Supported scalar field types:</p>
<ul>
<li><code>INT8</code>, <code>INT16</code>, <code>INT32</code>, <code>INT64</code></li>
<li><code>FLOAT</code>, <code>DOUBLE</code></li>
<li><code>BOOL</code></li>
<li><code>VARCHAR</code> (up to 65,535 characters)</li>
<li><code>JSON</code> — unstructured key-value data (powerful for flexible metadata)</li>
<li><code>ARRAY</code> — fixed-type arrays</li>
</ul>
<p><strong>Supported vector field types:</strong></p>
<ul>
<li><code>FLOAT_VECTOR</code> — 32-bit floating point vectors (most common)</li>
<li><code>BINARY_VECTOR</code> — packed binary vectors (more compact, useful for hashing-based embeddings)</li>
<li><code>FLOAT16_VECTOR</code> — 16-bit half-precision (reduces memory, slight accuracy tradeoff)</li>
<li><code>BFLOAT16_VECTOR</code> — brain float 16 (popular in ML hardware)</li>
<li><code>SPARSE_FLOAT_VECTOR</code> — for sparse representations (BM25, SPLADE)</li>
</ul>
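<p>To see why <code>BINARY_VECTOR</code> is so compact, consider how a binary embedding is stored: each dimension is a single bit, packed eight per byte, and vectors are compared with bitwise metrics such as Hamming distance. The sketch below is illustrative only; the helper names are hypothetical, and it packs bits most-significant-first (as NumPy’s <code>np.packbits</code> does by default), which is an assumption you should check against the PyMilvus documentation for the byte layout it expects.</p>

```python
def pack_bits(bits: list[int]) -> bytes:
    """Pack 0/1 values into bytes, most significant bit first (length must be a multiple of 8)."""
    if len(bits) % 8:
        raise ValueError("bit length must be a multiple of 8")
    out = bytearray()
    for start in range(0, len(bits), 8):
        byte = 0
        for bit in bits[start:start + 8]:
            byte = (byte << 1) | bit
        out.append(byte)
    return bytes(out)

def hamming(a: bytes, b: bytes) -> int:
    """Number of differing bits between two packed binary vectors."""
    return sum(bin(x ^ y).count("1") for x, y in zip(a, b))

# A 512-dimensional binary vector occupies 64 bytes; as FLOAT_VECTOR it would take 512 * 4 = 2048 bytes
v1 = pack_bits([1, 0] * 256)
v2 = pack_bits([1, 1] * 256)
print(len(v1), "bytes; Hamming distance =", hamming(v1, v2))
```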
</section>
<section id="entity" class="level3">
<h3 class="anchored" data-anchor-id="entity" id="entity">3.3 Entity</h3>
<p>An <strong>entity</strong> is a single record (row) in a collection. It contains values for all fields defined in the schema. When you insert data, you insert entities.</p>
</section>
<section id="segment" class="level3">
<h3 class="anchored" data-anchor-id="segment" id="segment">3.4 Segment</h3>
<p>Internally, Milvus divides the data in a collection into <strong>segments</strong>. Newly inserted data lands in <strong>growing segments</strong>, which still accept writes. When a growing segment reaches a size threshold, it is “sealed”: it becomes immutable, and an index is built on it individually.</p>
<p>You rarely interact with segments directly, but understanding them explains behaviors like “why don’t my newly inserted vectors appear in search results immediately?”</p>
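<p>The growing-to-sealed lifecycle can be pictured with a toy model. This is a hypothetical sketch, not Milvus’s implementation: real segments seal based on size in bytes and other policies, whereas here a simple entity count stands in for the threshold.</p>

```python
class ToyCollection:
    """Hypothetical model of growing vs. sealed segments (not Milvus internals)."""

    def __init__(self, seal_threshold: int):
        self.seal_threshold = seal_threshold
        self.growing: list[int] = []        # mutable: still accepts inserts
        self.sealed: list[list[int]] = []   # immutable: each would get its own index

    def insert(self, entity_id: int) -> None:
        self.growing.append(entity_id)
        if len(self.growing) >= self.seal_threshold:
            # Seal the full segment and start a fresh growing one
            self.sealed.append(self.growing)
            self.growing = []

collection = ToyCollection(seal_threshold=4)
for entity_id in range(10):
    collection.insert(entity_id)

print("sealed:", collection.sealed)    # two full, immutable segments
print("growing:", collection.growing)  # recent inserts not yet sealed
```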
</section>
<section id="partition" class="level3">
<h3 class="anchored" data-anchor-id="partition" id="partition">3.5 Partition</h3>
<p>A <strong>partition</strong> is a logical subdivision of a collection. Think of it as a sub-table that can be searched independently or together.</p>
<p><strong>Why use partitions?</strong></p>
<ul>
<li>To scope searches to a subset of data (e.g., search only videos from “2024”)</li>
<li>To logically separate data (e.g., one partition per camera, one per user)</li>
<li>To improve query performance when you know which partition to target</li>
</ul>
<p>Every collection has a default partition called <code>_default</code>.</p>
</section>
<section id="index" class="level3">
<h3 class="anchored" data-anchor-id="index" id="index">3.6 Index</h3>
<p>An <strong>index</strong> is a data structure built on a vector field that makes ANN search fast. Milvus supports many index types:</p>
<ul>
<li><strong>FLAT</strong> — brute-force exact search. Perfect accuracy, slow at scale.</li>
<li><strong>IVF_FLAT</strong> — inverted file index. Divides vectors into clusters; searches only relevant clusters.</li>
<li><strong>IVF_SQ8</strong> — like IVF_FLAT but with scalar quantization (compresses vectors to 8-bit; saves memory).</li>
<li><strong>IVF_PQ</strong> — product quantization; extreme compression, lower accuracy.</li>
<li><strong>HNSW</strong> — Hierarchical Navigable Small World graph. Excellent speed/accuracy tradeoff; the gold standard for most use cases.</li>
<li><strong>SCANN</strong> — Google’s ScaNN algorithm; highly optimized for recall.</li>
<li><strong>DiskANN</strong> — designed for datasets too large to fit in RAM; stores index on disk.</li>
<li><strong>GPU_IVF_FLAT</strong>, <strong>GPU_CAGRA</strong> — GPU-accelerated variants.</li>
</ul>
<p>Choosing the right index is one of the most important decisions in your Milvus deployment. We cover this in detail in Section 9.</p>
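<p>The idea behind IVF can be sketched in a few lines of pure Python. This toy version is an illustration under simplifying assumptions, not Milvus’s implementation: real IVF trains its <code>nlist</code> cluster centroids with k-means, whereas this sketch simply samples them from the data. At query time only the <code>nprobe</code> closest clusters are scanned; the names <code>nlist</code> and <code>nprobe</code> mirror the parameters Milvus uses for its IVF indexes, while the helper functions are hypothetical.</p>

```python
import math
import random

def nearest_centroid(point, centroids):
    """Index of the centroid closest to point (L2 distance)."""
    return min(range(len(centroids)), key=lambda c: math.dist(point, centroids[c]))

def build_ivf(vectors, centroids):
    """Inverted lists: centroid index -> ids of the vectors assigned to it."""
    lists = {c: [] for c in range(len(centroids))}
    for i, v in enumerate(vectors):
        lists[nearest_centroid(v, centroids)].append(i)
    return lists

def ivf_search(query, vectors, centroids, lists, k, nprobe):
    """Scan only the nprobe clusters closest to the query; return (top-k ids, number scanned)."""
    probe = sorted(range(len(centroids)), key=lambda c: math.dist(query, centroids[c]))[:nprobe]
    candidates = [i for c in probe for i in lists[c]]
    candidates.sort(key=lambda i: math.dist(query, vectors[i]))
    return candidates[:k], len(candidates)

random.seed(1)
vectors = [[random.random() for _ in range(4)] for _ in range(2_000)]
nlist = 16
centroids = random.sample(vectors, nlist)  # stand-in for k-means training
lists = build_ivf(vectors, centroids)

top, scanned = ivf_search(vectors[0], vectors, centroids, lists, k=5, nprobe=2)
print(top, f"(scanned {scanned} of {len(vectors)} vectors)")
```

<p>Raising <code>nprobe</code> scans more clusters, which raises recall and cost together; with <code>nprobe</code> equal to the number of clusters, the search degenerates into exact brute-force search.</p>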
</section>
<section id="distance-metrics" class="level3">
<h3 class="anchored" data-anchor-id="distance-metrics" id="distance-metrics">3.7 Distance Metrics</h3>
<p>When performing a vector search, Milvus computes a <strong>distance</strong> (or similarity score) between the query vector and every candidate vector. The three metrics supported for floating-point vectors are:</p>
<p><strong>L2 (Euclidean Distance)</strong> <span class="math display">\[
d(a, b) = \sqrt{\sum_i (a_i - b_i)^2}
\]</span> Lower = more similar. Best for embeddings that are not normalized to unit length.</p>
<p><strong>IP (Inner Product / Dot Product)</strong> <span class="math display">\[
d(a, b) = \sum_i a_i \cdot b_i
\]</span> Higher = more similar. For normalized vectors, IP is equivalent to cosine similarity.</p>
<p><strong>Cosine</strong></p>
<p><span class="math display">\[
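d(a, b) = \frac{a \cdot b}{\|a\| \, \|b\|}
\]</span></p>
<p>Higher = more similar (range -1 to 1). Milvus reports this cosine similarity as the search score; the corresponding cosine distance is <span class="math inline">\(1 - d(a, b)\)</span>. Because it depends only on the angle between the vectors, it is invariant to vector magnitude.</p>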
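<p><strong>Rule of thumb:</strong> If your embeddings are L2-normalized (either by the model or by you before insertion), <strong>IP</strong> and <strong>Cosine</strong> produce the same ranking. If they are not normalized, use <strong>Cosine</strong> when only direction should matter, or <strong>L2</strong> when magnitude carries meaning.</p>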
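<p>Why do the metrics agree on unit vectors? For <span class="math inline">\(\|a\| = \|b\| = 1\)</span>, expanding the squared Euclidean distance gives <span class="math inline">\(\|a - b\|^2 = \|a\|^2 + \|b\|^2 - 2\, a \cdot b = 2 - 2\, a \cdot b\)</span>, so sorting by ascending L2 distance is the same as sorting by descending inner product (which, for unit vectors, equals cosine similarity). The pure-Python sketch below checks this numerically; the vectors are synthetic and the helper names are hypothetical.</p>

```python
import math
import random

def unit(v):
    """Scale v to unit L2 norm."""
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

random.seed(3)
dim = 16
query = unit([random.gauss(0, 1) for _ in range(dim)])
db = [unit([random.gauss(0, 1) for _ in range(dim)]) for _ in range(100)]

ip = [dot(query, v) for v in db]
l2_sq = [sum((x - y) ** 2 for x, y in zip(query, v)) for v in db]

# Identity check: ||q - v||^2 == 2 - 2 * (q . v) for unit vectors
max_err = max(abs(d - (2 - 2 * s)) for d, s in zip(l2_sq, ip))
rank_by_ip = sorted(range(len(db)), key=lambda i: -ip[i])
rank_by_l2 = sorted(range(len(db)), key=lambda i: l2_sq[i])
print(f"max identity error: {max_err:.2e}; same top-5: {rank_by_ip[:5] == rank_by_l2[:5]}")
```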
<hr>
</section>
</section>
<section id="sec-cvpipeline" class="level2">
<h2 class="anchored" data-anchor-id="sec-cvpipeline" id="sec-cvpipeline">4. How Computer Vision Meets Vector Search</h2>
<section id="the-general-pipeline" class="level3">
<h3 class="anchored" data-anchor-id="the-general-pipeline" id="the-general-pipeline">The General Pipeline</h3>
<p>Every computer vision application that uses Milvus follows the same fundamental pipeline:</p>
<div class="cell" data-eval="true" data-layout-align="center">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">graph LR
    A["Raw Image (or frame)"]
    B["Embedding Model (CNN, ViT, etc.)"]
    C["Feature Vector f₁, f₂, ..., fₙ"]
    D[("Milvus Collection: id · vector · metadata (e.g. 1 · [...] · dog.jpg, 2 · [...] · cat.png)")]
    E["Query: embed new image → search k-NN → return IDs"]

    A --&gt; B
    B --&gt; C
    C --&gt; D
    D --&gt; E

    style A fill:#E8F4FD,stroke:#4A90D9
    style B fill:#FEF9E7,stroke:#E8A838
    style C fill:#EAF7EA,stroke:#5BA85A
    style D fill:#F4ECF7,stroke:#8B6BB1
    style E fill:#FDEDEC,stroke:#D95F5F
</pre>
</div>
<p></p></figure><p></p>
</div>
</div>
</div>
<p><strong>Two phases:</strong></p>
<ol type="1">
<li><strong>Ingestion (offline):</strong> Extract embeddings from all your images and insert them into Milvus along with metadata.</li>
<li><strong>Query (online):</strong> For a new query image, extract its embedding, send it to Milvus, receive the IDs of the most similar images.</li>
</ol>
</section>
<section id="choosing-the-right-embedding-dimensionality" class="level3">
<h3 class="anchored" data-anchor-id="choosing-the-right-embedding-dimensionality" id="choosing-the-right-embedding-dimensionality">Choosing the Right Embedding Dimensionality</h3>
<p>Different models produce embeddings of different sizes:</p>
<table class="caption-top table">
<thead>
<tr class="header">
<th>Model Family</th>
<th>Typical Dimensions</th>
<th>Notes</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>ResNet-50 (pool layer)</td>
<td>2048</td>
<td>Large; very expressive</td>
</tr>
<tr class="even">
<td>EfficientNet-B0</td>
<td>1280</td>
<td>Good accuracy/size tradeoff</td>
</tr>
<tr class="odd">
<td>CLIP ViT-B/32</td>
<td>512</td>
<td>Multi-modal (text+image)</td>
</tr>
<tr class="even">
<td>CLIP ViT-L/14</td>
<td>768</td>
<td>Larger, more accurate</td>
</tr>
<tr class="odd">
<td>DINOv2 ViT-S/14</td>
<td>384</td>
<td>Efficient, self-supervised</td>
</tr>
<tr class="even">
<td>DINOv2 ViT-g/14</td>
<td>1536</td>
<td>Highest quality, expensive</td>
</tr>
<tr class="odd">
<td>Face (ArcFace, FaceNet)</td>
<td>128–512</td>
<td>Specialized for identity</td>
</tr>
</tbody>
</table>
<p><strong>Higher dimensions = more expressive but more memory and slower search.</strong> Always test with your target data to find the right model for your use case.</p>
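<p>The memory side of that tradeoff is easy to estimate: raw vector storage is N × D × (bytes per component), so it grows linearly with dimension. The sketch below computes the raw footprint for a hypothetical 10-million-image collection using dimensions from the table. It ignores index overhead (for example HNSW’s graph links) and scalar metadata, so treat the results as a lower bound.</p>

```python
# Bytes per component for Milvus's dense vector field types
BYTES_PER_COMPONENT = {"FLOAT_VECTOR": 4, "FLOAT16_VECTOR": 2, "BFLOAT16_VECTOR": 2}

def raw_vector_gib(num_vectors: int, dim: int, vector_type: str = "FLOAT_VECTOR") -> float:
    """Raw size of the vectors alone in GiB (no index overhead, no metadata)."""
    return num_vectors * dim * BYTES_PER_COMPONENT[vector_type] / 2**30

N = 10_000_000  # hypothetical collection size
for model, dim in [("DINOv2 ViT-S/14", 384), ("CLIP ViT-B/32", 512), ("ResNet-50", 2048)]:
    fp32 = raw_vector_gib(N, dim)
    fp16 = raw_vector_gib(N, dim, "FLOAT16_VECTOR")
    print(f"{model:>16}: {fp32:5.1f} GiB as float32, {fp16:5.1f} GiB as float16")
```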
</section>
<section id="normalization" class="level3">
<h3 class="anchored" data-anchor-id="normalization" id="normalization">Normalization</h3>
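<p>Inner-product search only behaves like cosine similarity when your vectors are <strong>L2-normalized</strong> (unit length), and mixing normalized and unnormalized vectors in one collection skews rankings. If you plan to use IP, normalize both the stored vectors and your query vectors:</p>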
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> normalize(vector: np.ndarray) <span class="op">-&gt;</span> np.ndarray:</span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Normalize a vector to unit length (L2 norm = 1)."""</span></span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a>    norm <span class="op">=</span> np.linalg.norm(vector)</span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> norm <span class="op">==</span> <span class="dv">0</span>:</span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> vector</span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> vector <span class="op">/</span> norm</span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a><span class="co"># For a batch of vectors (shape: [N, D])</span></span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> normalize_batch(vectors: np.ndarray) <span class="op">-&gt;</span> np.ndarray:</span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a>    norms <span class="op">=</span> np.linalg.norm(vectors, axis<span class="op">=</span><span class="dv">1</span>, keepdims<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a>    norms <span class="op">=</span> np.where(norms <span class="op">==</span> <span class="dv">0</span>, <span class="dv">1</span>, norms)  <span class="co"># avoid division by zero</span></span>
<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> vectors <span class="op">/</span> norms</span></code></pre></div></div>
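<p>Check your model’s documentation rather than assuming its output is normalized: many feature extractors, including common CLIP and DINOv2 pipelines, return unnormalized features by default. Normalizing a vector that is already unit length is harmless.</p>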
<hr>
</section>
</section>
<section id="sec-setup" class="level2">
<h2 class="anchored" data-anchor-id="sec-setup" id="sec-setup">5. Setting Up Your Environment</h2>
<section id="python-prerequisites" class="level3">
<h3 class="anchored" data-anchor-id="python-prerequisites" id="python-prerequisites">Python Prerequisites</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb4"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Create and activate a virtual environment (recommended)</span></span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a><span class="ex">python</span> <span class="at">-m</span> venv milvus-cv-env</span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a><span class="bu">source</span> milvus-cv-env/bin/activate  <span class="co"># Linux/Mac</span></span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a><span class="co"># milvus-cv-env\Scripts\activate   # Windows</span></span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Install the Milvus Python SDK</span></span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install pymilvus</span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a><span class="co"># Install pymilvus with MilvusClient support (recommended, includes model utilities)</span></span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install <span class="st">"pymilvus[model]"</span></span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a><span class="co"># Common CV libraries</span></span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install numpy pillow</span>
<span id="cb4-14"><a href="#cb4-14" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install torch torchvision  <span class="co"># if using PyTorch models</span></span></code></pre></div></div>
</section>
<section id="verifying-the-installation" class="level3">
<h3 class="anchored" data-anchor-id="verifying-the-installation" id="verifying-the-installation">Verifying the Installation</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> pymilvus</span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(pymilvus.__version__)  <span class="co"># Should print e.g. "2.4.x"</span></span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> pymilvus <span class="im">import</span> MilvusClient</span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"PyMilvus installed correctly"</span>)</span></code></pre></div></div>
</section>
<section id="sdk-version-compatibility" class="level3">
<h3 class="anchored" data-anchor-id="sdk-version-compatibility" id="sdk-version-compatibility">SDK Version Compatibility</h3>
<p>Always match your SDK version to your Milvus server version. Milvus uses semantic versioning (<code>MAJOR.MINOR.PATCH</code>). The SDK minor version should match the server minor version.</p>
<table class="caption-top table">
<thead>
<tr class="header">
<th>Milvus Server</th>
<th>PyMilvus SDK</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>2.4.x</td>
<td>2.4.x</td>
</tr>
<tr class="even">
<td>2.3.x</td>
<td>2.3.x</td>
</tr>
<tr class="odd">
<td>2.2.x</td>
<td>2.2.x</td>
</tr>
</tbody>
</table>
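<p>If you want to catch a mismatch early (for example in a startup check), comparing the <code>MAJOR.MINOR</code> components is enough. The helper below is a hypothetical sketch: the SDK version can come from <code>pymilvus.__version__</code>, while how you obtain the server version string (deployment manifest, image tag, or an SDK call) is left as an assumption for your setup.</p>

```python
def minor_versions_match(sdk_version: str, server_version: str) -> bool:
    """True when MAJOR.MINOR agree, e.g. SDK 2.4.3 with server v2.4.0."""
    def major_minor(version: str) -> tuple[int, int]:
        major, minor = version.lstrip("v").split(".")[:2]
        return int(major), int(minor)
    return major_minor(sdk_version) == major_minor(server_version)

# Example values only; substitute the versions from your own environment
print(minor_versions_match("2.4.3", "v2.4.0"))
print(minor_versions_match("2.3.7", "2.4.0"))
```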
<hr>
</section>
</section>
<section id="sec-deployment" class="level2">
<h2 class="anchored" data-anchor-id="sec-deployment" id="sec-deployment">6. Deployment Options: Lite → Docker → Kubernetes</h2>
<section id="milvus-lite-local-development" class="level3">
<h3 class="anchored" data-anchor-id="milvus-lite-local-development" id="milvus-lite-local-development">6.1 Milvus Lite (Local Development)</h3>
<p><strong>Milvus Lite</strong> is a lightweight, serverless version of Milvus that runs entirely in-process: no server to start and no Docker required. It stores all data in a single local file.</p>
<p><strong>Ideal for:</strong> prototyping, unit tests, notebooks, offline processing on a single machine.</p>
<p><strong>Limitations:</strong></p>
<ul>
<li>Not suitable for production (single process, limited concurrency)</li>
<li>No distributed indexing, no GPU support</li>
<li>Maximum dataset size is limited by local RAM/disk</li>
</ul>
<p><strong>Installation:</strong></p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb6"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install milvus-lite  <span class="co"># already included in pymilvus &gt;= 2.4.2</span></span></code></pre></div></div>
<p><strong>Usage:</strong></p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> pymilvus <span class="im">import</span> MilvusClient</span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Pass a file path — Milvus Lite creates/opens a local database file</span></span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a>client <span class="op">=</span> MilvusClient(<span class="st">"./my_cv_database.db"</span>)</span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"Connected to Milvus Lite"</span>)</span></code></pre></div></div>
<p>That’s it. No servers, no configuration. The database file is portable and can be copied between machines.</p>
<p><strong>Checking stored data:</strong></p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb8"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="co"># List all collections in this database</span></span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a>collections <span class="op">=</span> client.list_collections()</span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(collections)</span></code></pre></div></div>
<p><strong>When to move beyond Milvus Lite:</strong></p>
<ul>
<li>Your dataset exceeds a few million vectors</li>
<li>You need multi-user concurrent access</li>
<li>You need production reliability (backups, replication, crash recovery)</li>
<li>You want GPU-accelerated indexing</li>
</ul>
<hr>
</section>
<section id="standalone-milvus-docker-docker-compose" class="level3">
<h3 class="anchored" data-anchor-id="standalone-milvus-docker-docker-compose" id="standalone-milvus-docker-docker-compose">6.2 Standalone Milvus (Docker / Docker Compose)</h3>
<p><strong>Standalone Milvus</strong> runs Milvus as a set of Docker containers on a single machine. It includes all components: the Milvus server, etcd (for metadata), and MinIO (for object storage).</p>
<p><strong>Ideal for:</strong> single-machine production use, team development environments, moderate-scale deployments (tens of millions of vectors).</p>
<section id="installing-docker" class="level4">
<h4 class="anchored" data-anchor-id="installing-docker">Installing Docker</h4>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb9"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Ubuntu/Debian</span></span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a><span class="fu">sudo</span> apt-get update <span class="kw">&amp;&amp;</span> <span class="fu">sudo</span> apt-get install docker.io docker-compose-plugin <span class="at">-y</span></span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a><span class="fu">sudo</span> systemctl enable <span class="at">--now</span> docker</span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a><span class="fu">sudo</span> usermod <span class="at">-aG</span> docker <span class="va">$USER</span>  <span class="co"># allow running docker without sudo (re-login required)</span></span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a><span class="co"># macOS — Install Docker Desktop from https://www.docker.com/products/docker-desktop/</span></span></code></pre></div></div>
</section>
<section id="starting-standalone-milvus-with-docker-compose" class="level4">
<h4 class="anchored" data-anchor-id="starting-standalone-milvus-with-docker-compose">Starting Standalone Milvus with Docker Compose</h4>
<p>Download the official <code>docker-compose.yml</code>:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb10"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Download the compose file</span></span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a><span class="fu">wget</span> https://github.com/milvus-io/milvus/releases/download/v2.4.0/milvus-standalone-docker-compose.yml <span class="dt">\</span></span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a>     <span class="at">-O</span> docker-compose.yml</span></code></pre></div></div>
<p>The file looks like this (simplified):</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb11"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a><span class="fu">version</span><span class="kw">:</span><span class="at"> </span><span class="st">'3.5'</span></span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a><span class="fu">services</span><span class="kw">:</span></span>
<span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">etcd</span><span class="kw">:</span></span>
<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">container_name</span><span class="kw">:</span><span class="at"> milvus-etcd</span></span>
<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">image</span><span class="kw">:</span><span class="at"> quay.io/coreos/etcd:v3.5.5</span></span>
<span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">environment</span><span class="kw">:</span></span>
<span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> ETCD_AUTO_COMPACTION_MODE=revision</span></span>
<span id="cb11-9"><a href="#cb11-9" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> ETCD_AUTO_COMPACTION_RETENTION=1000</span></span>
<span id="cb11-10"><a href="#cb11-10" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> ETCD_QUOTA_BACKEND_BYTES=4294967296</span></span>
<span id="cb11-11"><a href="#cb11-11" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> ETCD_SNAPSHOT_COUNT=50000</span></span>
<span id="cb11-12"><a href="#cb11-12" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">volumes</span><span class="kw">:</span></span>
<span id="cb11-13"><a href="#cb11-13" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/etcd:/etcd</span></span>
<span id="cb11-14"><a href="#cb11-14" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">command</span><span class="kw">:</span><span class="at"> etcd -advertise-client-urls=http://127.0.0.1:2379</span></span>
<span id="cb11-15"><a href="#cb11-15" aria-hidden="true" tabindex="-1"></a><span class="at">      -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd</span></span>
<span id="cb11-16"><a href="#cb11-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-17"><a href="#cb11-17" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">minio</span><span class="kw">:</span></span>
<span id="cb11-18"><a href="#cb11-18" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">container_name</span><span class="kw">:</span><span class="at"> milvus-minio</span></span>
<span id="cb11-19"><a href="#cb11-19" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">image</span><span class="kw">:</span><span class="at"> minio/minio:RELEASE.2023-03-13T19-46-17Z</span></span>
<span id="cb11-20"><a href="#cb11-20" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">environment</span><span class="kw">:</span></span>
<span id="cb11-21"><a href="#cb11-21" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">MINIO_ACCESS_KEY</span><span class="kw">:</span><span class="at"> minioadmin</span></span>
<span id="cb11-22"><a href="#cb11-22" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">MINIO_SECRET_KEY</span><span class="kw">:</span><span class="at"> minioadmin</span></span>
<span id="cb11-23"><a href="#cb11-23" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">volumes</span><span class="kw">:</span></span>
<span id="cb11-24"><a href="#cb11-24" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/minio:/minio_data</span></span>
<span id="cb11-25"><a href="#cb11-25" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">command</span><span class="kw">:</span><span class="at"> minio server /minio_data</span></span>
<span id="cb11-26"><a href="#cb11-26" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">healthcheck</span><span class="kw">:</span></span>
<span id="cb11-27"><a href="#cb11-27" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">test</span><span class="kw">:</span><span class="at"> </span><span class="kw">[</span><span class="st">"CMD"</span><span class="kw">,</span><span class="at"> </span><span class="st">"curl"</span><span class="kw">,</span><span class="at"> </span><span class="st">"-f"</span><span class="kw">,</span><span class="at"> </span><span class="st">"http://localhost:9000/minio/health/live"</span><span class="kw">]</span></span>
<span id="cb11-28"><a href="#cb11-28" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-29"><a href="#cb11-29" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">standalone</span><span class="kw">:</span></span>
<span id="cb11-30"><a href="#cb11-30" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">container_name</span><span class="kw">:</span><span class="at"> milvus-standalone</span></span>
<span id="cb11-31"><a href="#cb11-31" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">image</span><span class="kw">:</span><span class="at"> milvusdb/milvus:v2.4.0</span></span>
<span id="cb11-32"><a href="#cb11-32" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">command</span><span class="kw">:</span><span class="at"> </span><span class="kw">[</span><span class="st">"milvus"</span><span class="kw">,</span><span class="at"> </span><span class="st">"run"</span><span class="kw">,</span><span class="at"> </span><span class="st">"standalone"</span><span class="kw">]</span></span>
<span id="cb11-33"><a href="#cb11-33" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">environment</span><span class="kw">:</span></span>
<span id="cb11-34"><a href="#cb11-34" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">ETCD_ENDPOINTS</span><span class="kw">:</span><span class="at"> etcd:2379</span></span>
<span id="cb11-35"><a href="#cb11-35" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">MINIO_ADDRESS</span><span class="kw">:</span><span class="at"> minio:9000</span></span>
<span id="cb11-36"><a href="#cb11-36" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">volumes</span><span class="kw">:</span></span>
<span id="cb11-37"><a href="#cb11-37" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/milvus:/var/lib/milvus</span></span>
<span id="cb11-38"><a href="#cb11-38" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">ports</span><span class="kw">:</span></span>
<span id="cb11-39"><a href="#cb11-39" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="st">"19530:19530"</span><span class="co">   # gRPC port (SDK connects here)</span></span>
<span id="cb11-40"><a href="#cb11-40" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="st">"9091:9091"</span><span class="co">     # HTTP/metrics port</span></span>
<span id="cb11-41"><a href="#cb11-41" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">depends_on</span><span class="kw">:</span></span>
<span id="cb11-42"><a href="#cb11-42" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> etcd</span></span>
<span id="cb11-43"><a href="#cb11-43" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> minio</span></span>
<span id="cb11-44"><a href="#cb11-44" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-45"><a href="#cb11-45" aria-hidden="true" tabindex="-1"></a><span class="fu">networks</span><span class="kw">:</span></span>
<span id="cb11-46"><a href="#cb11-46" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">default</span><span class="kw">:</span></span>
<span id="cb11-47"><a href="#cb11-47" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">name</span><span class="kw">:</span><span class="at"> milvus</span></span></code></pre></div></div>
<p><strong>Start it:</strong></p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb12"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb12-1"><a href="#cb12-1" aria-hidden="true" tabindex="-1"></a><span class="ex">docker</span> compose up <span class="at">-d</span></span>
<span id="cb12-2"><a href="#cb12-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-3"><a href="#cb12-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Check that all three containers are running</span></span>
<span id="cb12-4"><a href="#cb12-4" aria-hidden="true" tabindex="-1"></a><span class="ex">docker</span> compose ps</span></code></pre></div></div>
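<p>Expected output (abridged to the relevant columns; the exact STATUS text, such as an uptime or a health state, varies by Compose version):</p>
<pre><code>NAME                 STATUS
milvus-etcd          Up
milvus-minio         Up (healthy)
milvus-standalone    Up</code></pre>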
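<p>Containers that report Up are not necessarily ready to accept connections; Milvus can take a little while to finish starting. Milvus serves a health endpoint at <code>/healthz</code> on the 9091 port published above, so a small polling loop can wait for it before your application connects. This is a sketch using only the Python standard library; <code>wait_for_milvus</code> and its timeout defaults are illustrative assumptions.</p>

```python
import time
import urllib.error
import urllib.request


def wait_for_milvus(url: str = "http://localhost:9091/healthz",
                    timeout_s: float = 120.0,
                    poll_s: float = 2.0) -> bool:
    """Poll a Milvus health endpoint until it returns HTTP 200 or time runs out."""
    deadline = time.monotonic() + timeout_s
    while True:
        try:
            with urllib.request.urlopen(url, timeout=poll_s) as resp:
                if resp.status == 200:
                    return True
        except (urllib.error.URLError, OSError):
            pass  # not up yet: connection refused, reset, or timed out
        if time.monotonic() >= deadline:
            return False
        time.sleep(poll_s)
```

<p>Calling <code>wait_for_milvus()</code> right after <code>docker compose up -d</code> returns <code>True</code> once the server is healthy, or <code>False</code> if it does not come up within the timeout.</p>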
<p><strong>Connect from Python:</strong></p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb14"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb14-1"><a href="#cb14-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> pymilvus <span class="im">import</span> MilvusClient</span>
<span id="cb14-2"><a href="#cb14-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-3"><a href="#cb14-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Connect to the running Milvus server</span></span>
<span id="cb14-4"><a href="#cb14-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Default port is 19530</span></span>
<span id="cb14-5"><a href="#cb14-5" aria-hidden="true" tabindex="-1"></a>client <span class="op">=</span> MilvusClient(uri<span class="op">=</span><span class="st">"http://localhost:19530"</span>)</span>
<span id="cb14-6"><a href="#cb14-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-7"><a href="#cb14-7" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"Connected to Milvus Standalone"</span>)</span></code></pre></div></div>
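<p>Creating the client only shows that a connection was established. A short round trip (create a throwaway collection, insert, search, drop) exercises etcd, MinIO, and the Milvus server together. The sketch below assumes pymilvus 2.4's <code>MilvusClient</code> quick-setup mode, which creates an <code>id</code> primary key and a <code>vector</code> field; the collection name and both helper functions are hypothetical, and <code>consistency_level="Strong"</code> is set so the freshly inserted rows are visible to the search.</p>

```python
import random


def make_vectors(n: int, dim: int, seed: int = 0) -> list:
    """Return n reproducible random float vectors of length dim."""
    rng = random.Random(seed)
    return [[rng.random() for _ in range(dim)] for _ in range(n)]


def smoke_test(uri: str = "http://localhost:19530",
               name: str = "deploy_smoke_test",  # hypothetical throwaway collection
               dim: int = 8) -> list:
    """Create a collection, insert rows, run one search, then drop the collection."""
    # Imported here so make_vectors stays usable without pymilvus installed
    from pymilvus import MilvusClient

    client = MilvusClient(uri=uri)
    if client.has_collection(collection_name=name):
        client.drop_collection(collection_name=name)

    # Quick-setup mode: "id" primary key + "vector" field, indexed and loaded
    client.create_collection(collection_name=name, dimension=dim,
                             consistency_level="Strong")
    try:
        vectors = make_vectors(100, dim)
        rows = [{"id": i, "vector": v} for i, v in enumerate(vectors)]
        client.insert(collection_name=name, data=rows)
        results = client.search(collection_name=name, data=[vectors[0]], limit=3)
        return list(results[0])  # hits for the single query vector
    finally:
        client.drop_collection(collection_name=name)
```

<p>Running <code>print(smoke_test())</code> against a live instance should print up to three hits, each a dict carrying an <code>id</code> and a <code>distance</code>.</p>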
<p><strong>Stop and remove containers:</strong></p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb15"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb15-1"><a href="#cb15-1" aria-hidden="true" tabindex="-1"></a><span class="ex">docker</span> compose down           <span class="co"># Stop containers, preserve data volumes</span></span>
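<span id="cb15-2"><a href="#cb15-2" aria-hidden="true" tabindex="-1"></a><span class="ex">docker</span> compose down <span class="at">-v</span>        <span class="co"># Also removes named/anonymous volumes; the ./volumes bind mounts used here survive</span></span>
<span id="cb15-3"><a href="#cb15-3" aria-hidden="true" tabindex="-1"></a><span class="fu">sudo</span> rm <span class="at">-rf</span> volumes        <span class="co"># Delete all Milvus data on disk (destructive!)</span></span></code></pre></div></div>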
</section>
<section id="persistent-volumes" class="level4">
<h4 class="anchored" data-anchor-id="persistent-volumes">Persistent Volumes</h4>
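<p>The compose file bind-mounts each service's data under <code>${DOCKER_VOLUME_DIRECTORY:-.}/volumes/</code>. When <code>DOCKER_VOLUME_DIRECTORY</code> is unset, Compose resolves this path relative to the project directory (the folder containing <code>docker-compose.yml</code>), so data lands in <code>./volumes/</code> next to the compose file. Back up this directory to preserve your data, ideally after <code>docker compose down</code> so etcd and MinIO are not writing to it during the copy.</p>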
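<p>The path logic can be mirrored in a few lines to locate and archive the data. This is a sketch under stated assumptions: it reads <code>DOCKER_VOLUME_DIRECTORY</code> from the process environment only (Compose can also take it from a <code>.env</code> file in the project directory), and <code>milvus_volume_dir</code> and <code>backup_volumes</code> are hypothetical helpers. Files written by the containers may be owned by root, so the backup may need elevated privileges.</p>

```python
import os
import shutil
from pathlib import Path


def milvus_volume_dir(project_dir: str) -> Path:
    """Mirror ${DOCKER_VOLUME_DIRECTORY:-.}/volumes from the compose file.

    The ":-" form falls back to "." when the variable is unset OR empty.
    An absolute DOCKER_VOLUME_DIRECTORY replaces project_dir entirely.
    """
    base = os.environ.get("DOCKER_VOLUME_DIRECTORY") or "."
    return (Path(project_dir) / base / "volumes").resolve()


def backup_volumes(project_dir: str, dest_dir: str) -> str:
    """Archive the volumes directory to a .tar.gz (stop the stack first)."""
    src = milvus_volume_dir(project_dir)
    if not src.is_dir():
        raise FileNotFoundError(f"no volumes directory at {src}")
    Path(dest_dir).mkdir(parents=True, exist_ok=True)
    archive_base = Path(dest_dir) / "milvus-volumes-backup"
    # Archive entries are stored as volumes/etcd/..., volumes/minio/..., etc.
    return shutil.make_archive(str(archive_base), "gztar",
                               root_dir=src.parent, base_dir=src.name)
```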
</section>
<section id="resource-recommendations-for-standalone" class="level4">
<h4 class="anchored" data-anchor-id="resource-recommendations-for-standalone">Resource Recommendations for Standalone</h4>
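<p>The figures below are rough starting points. Memory needs depend heavily on vector dimension, index type, and replica count, so validate them against your own data.</p>
<table class="caption-top table">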
<thead>
<tr class="header">
<th>Dataset Size</th>
<th>RAM</th>
<th>CPU</th>
<th>Disk</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>&lt; 10M vectors</td>
<td>16 GB</td>
<td>4 cores</td>
<td>100 GB SSD</td>
</tr>
<tr class="even">
<td>10–50M vectors</td>
<td>32–64 GB</td>
<td>8 cores</td>
<td>500 GB SSD</td>
</tr>
<tr class="odd">
<td>50–100M vectors</td>
<td>64–128 GB</td>
<td>16 cores</td>
<td>1 TB SSD</td>
</tr>
</tbody>
</table>
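<p>A quick way to sanity-check the RAM column is to compute the raw size of the vectors themselves: count × dimension × 4 bytes for float32. At 768 dimensions, 10M vectors already take about 28.6 GiB before any index overhead, which suggests the first row assumes lower-dimensional vectors or a compressed or disk-based index. The helper below is an illustrative sketch; it deliberately ignores index structures, scalar fields, and replicas, all of which add to the total.</p>

```python
def raw_vector_gib(num_vectors: int, dim: int, bytes_per_value: int = 4) -> float:
    """Raw size of the vectors alone in GiB (float32 = 4 bytes per value).

    Excludes index overhead (e.g. graph links), scalar fields, and replicas.
    """
    return num_vectors * dim * bytes_per_value / 2**30


for dim in (128, 768, 1536):
    print(f"10M x {dim}-d float32: {raw_vector_gib(10_000_000, dim):.1f} GiB")
```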
<hr>
</section>
</section>
<section id="distributed-milvus-on-kubernetes" class="level3">
<h3 class="anchored" data-anchor-id="distributed-milvus-on-kubernetes" id="distributed-milvus-on-kubernetes">6.3 Distributed Milvus on Kubernetes</h3>
<p><strong>Distributed Milvus</strong> is the full production-grade deployment. Each component (Proxy, the coordinators, QueryNode, DataNode, and IndexNode) runs in its own pod, so the access and worker layers can be scaled independently.</p>
<p><strong>Ideal for:</strong> billion-scale datasets, high-availability requirements, multi-region deployments, enterprise use cases.</p>
<section id="prerequisites" class="level4">
<h4 class="anchored" data-anchor-id="prerequisites">Prerequisites</h4>
<ul>
<li>A running Kubernetes cluster (EKS, GKE, AKS, or self-hosted with kubeadm)</li>
<li><code>kubectl</code> configured to access your cluster</li>
<li><code>helm</code> (Kubernetes package manager) installed</li>
</ul>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb16"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb16-1"><a href="#cb16-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Install Helm</span></span>
<span id="cb16-2"><a href="#cb16-2" aria-hidden="true" tabindex="-1"></a><span class="ex">curl</span> https://raw.githubusercontent.com/helm/helm/main/scripts/get-helm-3 <span class="kw">|</span> <span class="fu">bash</span></span>
<span id="cb16-3"><a href="#cb16-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-4"><a href="#cb16-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Verify</span></span>
<span id="cb16-5"><a href="#cb16-5" aria-hidden="true" tabindex="-1"></a><span class="ex">helm</span> version</span></code></pre></div></div>
</section>
<section id="adding-the-milvus-helm-repository" class="level4">
<h4 class="anchored" data-anchor-id="adding-the-milvus-helm-repository">Adding the Milvus Helm Repository</h4>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb17"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb17-1"><a href="#cb17-1" aria-hidden="true" tabindex="-1"></a><span class="ex">helm</span> repo add milvus https://zilliztech.github.io/milvus-helm/</span>
<span id="cb17-2"><a href="#cb17-2" aria-hidden="true" tabindex="-1"></a><span class="ex">helm</span> repo update</span></code></pre></div></div>
</section>
<section id="minimal-distributed-deployment" class="level4">
<h4 class="anchored" data-anchor-id="minimal-distributed-deployment">Minimal Distributed Deployment</h4>
<p>Create a <code>values.yaml</code> to customize your deployment:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb18"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb18-1"><a href="#cb18-1" aria-hidden="true" tabindex="-1"></a><span class="co"># values.yaml — Minimal distributed Milvus configuration</span></span>
<span id="cb18-2"><a href="#cb18-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-3"><a href="#cb18-3" aria-hidden="true" tabindex="-1"></a><span class="fu">cluster</span><span class="kw">:</span></span>
<span id="cb18-4"><a href="#cb18-4" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">enabled</span><span class="kw">:</span><span class="at"> </span><span class="ch">true</span><span class="co">  # Enable distributed mode</span></span>
<span id="cb18-5"><a href="#cb18-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-6"><a href="#cb18-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Component replica counts</span></span>
<span id="cb18-7"><a href="#cb18-7" aria-hidden="true" tabindex="-1"></a><span class="fu">proxy</span><span class="kw">:</span></span>
<span id="cb18-8"><a href="#cb18-8" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">replicas</span><span class="kw">:</span><span class="at"> </span><span class="dv">2</span></span>
<span id="cb18-9"><a href="#cb18-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-10"><a href="#cb18-10" aria-hidden="true" tabindex="-1"></a><span class="fu">queryNode</span><span class="kw">:</span></span>
<span id="cb18-11"><a href="#cb18-11" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">replicas</span><span class="kw">:</span><span class="at"> </span><span class="dv">2</span></span>
<span id="cb18-12"><a href="#cb18-12" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">resources</span><span class="kw">:</span></span>
<span id="cb18-13"><a href="#cb18-13" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">requests</span><span class="kw">:</span></span>
<span id="cb18-14"><a href="#cb18-14" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">memory</span><span class="kw">:</span><span class="at"> </span><span class="st">"8Gi"</span></span>
<span id="cb18-15"><a href="#cb18-15" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">cpu</span><span class="kw">:</span><span class="at"> </span><span class="st">"2"</span></span>
<span id="cb18-16"><a href="#cb18-16" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">limits</span><span class="kw">:</span></span>
<span id="cb18-17"><a href="#cb18-17" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">memory</span><span class="kw">:</span><span class="at"> </span><span class="st">"16Gi"</span></span>
<span id="cb18-18"><a href="#cb18-18" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">cpu</span><span class="kw">:</span><span class="at"> </span><span class="st">"4"</span></span>
<span id="cb18-19"><a href="#cb18-19" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-20"><a href="#cb18-20" aria-hidden="true" tabindex="-1"></a><span class="fu">dataNode</span><span class="kw">:</span></span>
<span id="cb18-21"><a href="#cb18-21" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">replicas</span><span class="kw">:</span><span class="at"> </span><span class="dv">1</span></span>
<span id="cb18-22"><a href="#cb18-22" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">resources</span><span class="kw">:</span></span>
<span id="cb18-23"><a href="#cb18-23" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">requests</span><span class="kw">:</span></span>
<span id="cb18-24"><a href="#cb18-24" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">memory</span><span class="kw">:</span><span class="at"> </span><span class="st">"4Gi"</span></span>
<span id="cb18-25"><a href="#cb18-25" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">cpu</span><span class="kw">:</span><span class="at"> </span><span class="st">"1"</span></span>
<span id="cb18-26"><a href="#cb18-26" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-27"><a href="#cb18-27" aria-hidden="true" tabindex="-1"></a><span class="fu">indexNode</span><span class="kw">:</span></span>
<span id="cb18-28"><a href="#cb18-28" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">replicas</span><span class="kw">:</span><span class="at"> </span><span class="dv">1</span></span>
<span id="cb18-29"><a href="#cb18-29" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">resources</span><span class="kw">:</span></span>
<span id="cb18-30"><a href="#cb18-30" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">requests</span><span class="kw">:</span></span>
<span id="cb18-31"><a href="#cb18-31" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">memory</span><span class="kw">:</span><span class="at"> </span><span class="st">"8Gi"</span></span>
<span id="cb18-32"><a href="#cb18-32" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">cpu</span><span class="kw">:</span><span class="at"> </span><span class="st">"4"</span></span>
<span id="cb18-33"><a href="#cb18-33" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-34"><a href="#cb18-34" aria-hidden="true" tabindex="-1"></a><span class="co"># Message queue (Pulsar for distributed mode)</span></span>
<span id="cb18-35"><a href="#cb18-35" aria-hidden="true" tabindex="-1"></a><span class="fu">pulsar</span><span class="kw">:</span></span>
<span id="cb18-36"><a href="#cb18-36" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">enabled</span><span class="kw">:</span><span class="at"> </span><span class="ch">true</span></span>
<span id="cb18-37"><a href="#cb18-37" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-38"><a href="#cb18-38" aria-hidden="true" tabindex="-1"></a><span class="co"># Object storage (MinIO deployed alongside)</span></span>
<span id="cb18-39"><a href="#cb18-39" aria-hidden="true" tabindex="-1"></a><span class="fu">minio</span><span class="kw">:</span></span>
<span id="cb18-40"><a href="#cb18-40" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">enabled</span><span class="kw">:</span><span class="at"> </span><span class="ch">true</span></span>
<span id="cb18-41"><a href="#cb18-41" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">mode</span><span class="kw">:</span><span class="at"> distributed</span></span>
<span id="cb18-42"><a href="#cb18-42" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">replicas</span><span class="kw">:</span><span class="at"> </span><span class="dv">4</span></span>
<span id="cb18-43"><a href="#cb18-43" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-44"><a href="#cb18-44" aria-hidden="true" tabindex="-1"></a><span class="co"># Metadata store</span></span>
<span id="cb18-45"><a href="#cb18-45" aria-hidden="true" tabindex="-1"></a><span class="fu">etcd</span><span class="kw">:</span></span>
<span id="cb18-46"><a href="#cb18-46" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">replicaCount</span><span class="kw">:</span><span class="at"> </span><span class="dv">3</span><span class="co">  # use an odd member count; 3 members keep quorum with 1 failure</span></span>
<span id="cb18-47"><a href="#cb18-47" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-48"><a href="#cb18-48" aria-hidden="true" tabindex="-1"></a><span class="co"># Expose the service</span></span>
<span id="cb18-49"><a href="#cb18-49" aria-hidden="true" tabindex="-1"></a><span class="fu">service</span><span class="kw">:</span></span>
<span id="cb18-50"><a href="#cb18-50" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">type</span><span class="kw">:</span><span class="at"> LoadBalancer</span></span></code></pre></div></div>
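<p>The etcd comment refers to Raft majority voting: a cluster of n members needs n // 2 + 1 of them reachable to keep accepting writes. The short sketch below tabulates this and shows why an even count adds no fault tolerance (4 members survive one failure, the same as 3).</p>

```python
def quorum(members: int) -> int:
    """Votes needed for a majority in a Raft cluster such as etcd."""
    return members // 2 + 1


def tolerated_failures(members: int) -> int:
    """Members that can be lost while a majority is still reachable."""
    return members - quorum(members)


for n in range(1, 6):
    print(f"{n} members: quorum={quorum(n)}, tolerates {tolerated_failures(n)} failure(s)")
```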
<p><strong>Deploy:</strong></p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb19"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb19-1"><a href="#cb19-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Create a dedicated namespace</span></span>
<span id="cb19-2"><a href="#cb19-2" aria-hidden="true" tabindex="-1"></a><span class="ex">kubectl</span> create namespace milvus</span>
<span id="cb19-3"><a href="#cb19-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-4"><a href="#cb19-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Deploy Milvus</span></span>
<span id="cb19-5"><a href="#cb19-5" aria-hidden="true" tabindex="-1"></a><span class="ex">helm</span> install milvus milvus/milvus <span class="dt">\</span></span>
<span id="cb19-6"><a href="#cb19-6" aria-hidden="true" tabindex="-1"></a>  <span class="at">--namespace</span> milvus <span class="dt">\</span></span>
<span id="cb19-7"><a href="#cb19-7" aria-hidden="true" tabindex="-1"></a>  <span class="at">-f</span> values.yaml <span class="dt">\</span></span>
<span id="cb19-8"><a href="#cb19-8" aria-hidden="true" tabindex="-1"></a>  <span class="at">--timeout</span> 15m <span class="dt">\</span></span>
<span id="cb19-9"><a href="#cb19-9" aria-hidden="true" tabindex="-1"></a>  <span class="at">--wait</span></span>
<span id="cb19-10"><a href="#cb19-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-11"><a href="#cb19-11" aria-hidden="true" tabindex="-1"></a><span class="co"># Check pod status</span></span>
<span id="cb19-12"><a href="#cb19-12" aria-hidden="true" tabindex="-1"></a><span class="ex">kubectl</span> get pods <span class="at">-n</span> milvus</span></code></pre></div></div>
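<p>Pods similar to the following should appear (abridged: the Pulsar pods are omitted, <code>xxx</code> stands for generated suffixes, and the coordinator pods vary by release; Milvus 2.3 and later merged IndexCoord into DataCoord, so older releases also show an <code>indexcoord</code> pod):</p>
<pre><code>NAME                                  READY   STATUS
milvus-datacoord-xxx                  1/1     Running
milvus-datanode-xxx                   1/1     Running
milvus-etcd-0                         1/1     Running
milvus-etcd-1                         1/1     Running
milvus-etcd-2                         1/1     Running
milvus-indexnode-xxx                  1/1     Running
milvus-minio-0                        1/1     Running
milvus-minio-1                        1/1     Running
milvus-minio-2                        1/1     Running
milvus-minio-3                        1/1     Running
milvus-proxy-xxx                      1/1     Running
milvus-proxy-xxx                      1/1     Running
milvus-querycoord-xxx                 1/1     Running
milvus-querynode-xxx                  1/1     Running
milvus-querynode-xxx                  1/1     Running
milvus-rootcoord-xxx                  1/1     Running</code></pre>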
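<p>Because the Pulsar, etcd, and MinIO subcharts add many pods, scanning the list by eye gets tedious. The sketch below parses <code>kubectl get pods -o json</code> and reports any pod whose <code>Ready</code> condition is not <code>True</code>, skipping pods that completed successfully (such as one-off init jobs). The helper names are hypothetical.</p>

```python
import json
import subprocess


def not_ready_pods(pods: dict) -> list:
    """Names of pods whose Ready condition is not "True" in `kubectl get pods -o json` output."""
    pending = []
    for item in pods.get("items", []):
        status = item.get("status", {})
        if status.get("phase") == "Succeeded":
            continue  # finished job pods are never Ready, and that is fine
        conditions = status.get("conditions", [])
        ready = any(c.get("type") == "Ready" and c.get("status") == "True"
                    for c in conditions)
        if not ready:
            pending.append(item["metadata"]["name"])
    return pending


def check_namespace(namespace: str = "milvus") -> list:
    """Run kubectl against the current context and return the not-ready pod names."""
    out = subprocess.run(
        ["kubectl", "get", "pods", "-n", namespace, "-o", "json"],
        check=True, capture_output=True, text=True,
    ).stdout
    return not_ready_pods(json.loads(out))
```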
<p><strong>Get the external IP:</strong></p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb21"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb21-1"><a href="#cb21-1" aria-hidden="true" tabindex="-1"></a><span class="ex">kubectl</span> get svc <span class="at">-n</span> milvus milvus</span>
<span id="cb21-2"><a href="#cb21-2" aria-hidden="true" tabindex="-1"></a><span class="co"># EXTERNAL-IP column shows the load balancer IP</span></span></code></pre></div></div>
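<p>On some cloud providers (AWS load balancers, for example) the service publishes a DNS hostname instead of an IP, and the EXTERNAL-IP column shows <code>&lt;pending&gt;</code> until provisioning finishes. The sketch below reads either field from the service's JSON and builds a connection URI; the function names are hypothetical, and port 19530 is assumed to be exposed as in the chart defaults.</p>

```python
import json
import subprocess


def external_address(svc: dict):
    """Load balancer IP or hostname from `kubectl get svc -o json` output, or None if pending."""
    ingress = svc.get("status", {}).get("loadBalancer", {}).get("ingress") or []
    if not ingress:
        return None  # still provisioning, or not a LoadBalancer service
    return ingress[0].get("ip") or ingress[0].get("hostname")


def milvus_uri(namespace: str = "milvus", service: str = "milvus", port: int = 19530):
    """Build an http:// URI for MilvusClient, or return None while the address is pending."""
    out = subprocess.run(
        ["kubectl", "get", "svc", service, "-n", namespace, "-o", "json"],
        check=True, capture_output=True, text=True,
    ).stdout
    addr = external_address(json.loads(out))
    return None if addr is None else f"http://{addr}:{port}"
```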
<p><strong>Connect from Python:</strong></p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb22"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb22-1"><a href="#cb22-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> pymilvus <span class="im">import</span> MilvusClient</span>
<span id="cb22-2"><a href="#cb22-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb22-3"><a href="#cb22-3" aria-hidden="true" tabindex="-1"></a>MILVUS_HOST <span class="op">=</span> <span class="st">"YOUR_EXTERNAL_IP"</span>  <span class="co"># from kubectl get svc</span></span>
<span id="cb22-4"><a href="#cb22-4" aria-hidden="true" tabindex="-1"></a>client <span class="op">=</span> MilvusClient(uri<span class="op">=</span><span class="ss">f"http://</span><span class="sc">{</span>MILVUS_HOST<span class="sc">}</span><span class="ss">:19530"</span>)</span>
<span id="cb22-5"><a href="#cb22-5" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"Connected to Milvus Distributed"</span>)</span></code></pre></div></div>
</section>
<section id="scaling-components" class="level4">
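<h4 class="anchored" data-anchor-id="scaling-components">Scaling Components</h4>
<p>Replica counts changed with <code>kubectl scale</code> are not recorded in the Helm release, so a later <code>helm upgrade</code> typically resets them to the values in <code>values.yaml</code>. Treat these commands as quick, temporary adjustments:</p>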
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb23"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb23-1"><a href="#cb23-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Scale QueryNodes to handle more concurrent searches</span></span>
<span id="cb23-2"><a href="#cb23-2" aria-hidden="true" tabindex="-1"></a><span class="ex">kubectl</span> scale deployment milvus-querynode <span class="at">-n</span> milvus <span class="at">--replicas</span><span class="op">=</span>5</span>
<span id="cb23-3"><a href="#cb23-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb23-4"><a href="#cb23-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Scale DataNodes to increase data ingestion throughput</span></span>
<span id="cb23-5"><a href="#cb23-5" aria-hidden="true" tabindex="-1"></a><span class="ex">kubectl</span> scale deployment milvus-datanode <span class="at">-n</span> milvus <span class="at">--replicas</span><span class="op">=</span>3</span>
<span id="cb23-6"><a href="#cb23-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb23-7"><a href="#cb23-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Scale IndexNodes for faster index building</span></span>
<span id="cb23-8"><a href="#cb23-8" aria-hidden="true" tabindex="-1"></a><span class="ex">kubectl</span> scale deployment milvus-indexnode <span class="at">-n</span> milvus <span class="at">--replicas</span><span class="op">=</span>2</span></code></pre></div></div>
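<p>To make replica changes survive upgrades, they can be recorded in the Helm release instead. The sketch below only builds the command line for <code>helm upgrade --reuse-values --set ...</code>, using the same value keys as <code>values.yaml</code>; <code>helm_scale_args</code> is a hypothetical helper. Editing <code>values.yaml</code> and re-running <code>helm upgrade -f values.yaml</code> is the more reproducible alternative, since the file stays the single source of truth.</p>

```python
def helm_scale_args(release: str, chart: str, namespace: str, replicas: dict) -> list:
    """Build a `helm upgrade` argv that persists replica counts in the release values."""
    args = ["helm", "upgrade", release, chart,
            "--namespace", namespace, "--reuse-values"]
    # Keys match values.yaml: queryNode.replicas, dataNode.replicas, ...
    for component, count in sorted(replicas.items()):
        args += ["--set", f"{component}.replicas={count}"]
    return args


cmd = helm_scale_args("milvus", "milvus/milvus", "milvus",
                      {"queryNode": 5, "dataNode": 3, "indexNode": 2})
print(" ".join(cmd))
# To apply against a live cluster: subprocess.run(cmd, check=True)
```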
</section>
<section id="gpu-support-on-kubernetes" class="level4">
<h4 class="anchored" data-anchor-id="gpu-support-on-kubernetes">GPU Support on Kubernetes</h4>
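<p>To enable GPU-accelerated indexing, run a GPU build of the Milvus image (tags ending in <code>-gpu</code>), request GPUs for the IndexNode pods, and schedule them onto GPU nodes with node selectors and tolerations:</p>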
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb24"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb24-1"><a href="#cb24-1" aria-hidden="true" tabindex="-1"></a><span class="co"># In values.yaml — GPU configuration for IndexNode</span></span>
<span id="cb24-2"><a href="#cb24-2" aria-hidden="true" tabindex="-1"></a><span class="fu">indexNode</span><span class="kw">:</span></span>
<span id="cb24-3"><a href="#cb24-3" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">replicas</span><span class="kw">:</span><span class="at"> </span><span class="dv">1</span></span>
<span id="cb24-4"><a href="#cb24-4" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">resources</span><span class="kw">:</span></span>
<span id="cb24-5"><a href="#cb24-5" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">limits</span><span class="kw">:</span></span>
<span id="cb24-6"><a href="#cb24-6" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">nvidia.com/gpu</span><span class="kw">:</span><span class="at"> </span><span class="dv">1</span><span class="co">  # Request 1 GPU per pod</span></span>
<span id="cb24-7"><a href="#cb24-7" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">nodeSelector</span><span class="kw">:</span></span>
<span id="cb24-8"><a href="#cb24-8" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">accelerator</span><span class="kw">:</span><span class="at"> nvidia-gpu</span></span>
<span id="cb24-9"><a href="#cb24-9" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">tolerations</span><span class="kw">:</span></span>
<span id="cb24-10"><a href="#cb24-10" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="kw">-</span><span class="at"> </span><span class="fu">key</span><span class="kw">:</span><span class="at"> </span><span class="st">"nvidia.com/gpu"</span></span>
<span id="cb24-11"><a href="#cb24-11" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">operator</span><span class="kw">:</span><span class="at"> </span><span class="st">"Exists"</span></span>
<span id="cb24-12"><a href="#cb24-12" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">effect</span><span class="kw">:</span><span class="at"> </span><span class="st">"NoSchedule"</span></span></code></pre></div></div>
<p>Nodes only advertise the <code>nvidia.com/gpu</code> resource once the NVIDIA device plugin is running in the cluster, so install it as well:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb25"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb25-1"><a href="#cb25-1" aria-hidden="true" tabindex="-1"></a><span class="ex">kubectl</span> apply <span class="at">-f</span> https://raw.githubusercontent.com/NVIDIA/k8s-device-plugin/v0.14.0/nvidia-device-plugin.yml</span></code></pre></div></div>
<hr>
</section>
</section>
</section>
<section id="sec-collections" class="level2">
<h2 class="anchored" data-anchor-id="sec-collections" id="sec-collections">7. Working with Collections and Schemas</h2>
<section id="the-milvusclient-api" class="level3">
<h3 class="anchored" data-anchor-id="the-milvusclient-api" id="the-milvusclient-api">The MilvusClient API</h3>
<p>PyMilvus offers two API styles:</p>
<ul>
<li><strong><code>MilvusClient</code></strong> — simplified, high-level API (recommended for most use cases)</li>
<li><strong><code>connections</code> + <code>Collection</code></strong> — lower-level ORM-style API (more control)</li>
</ul>
<p>This guide uses <code>MilvusClient</code> throughout, as it is the modern recommended approach.</p>
</section>
<section id="connecting-works-for-all-deployment-modes" class="level3">
<h3 class="anchored" data-anchor-id="connecting-works-for-all-deployment-modes" id="connecting-works-for-all-deployment-modes">Connecting (works for all deployment modes)</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb26"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb26-1"><a href="#cb26-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> pymilvus <span class="im">import</span> MilvusClient</span>
<span id="cb26-2"><a href="#cb26-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb26-3"><a href="#cb26-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Use ONE of the following. Milvus Lite (embedded; data is stored in a local file):</span></span>
<span id="cb26-4"><a href="#cb26-4" aria-hidden="true" tabindex="-1"></a>client <span class="op">=</span> MilvusClient(<span class="st">"./cv_database.db"</span>)</span>
<span id="cb26-5"><a href="#cb26-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb26-6"><a href="#cb26-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Standalone (Docker)</span></span>
<span id="cb26-7"><a href="#cb26-7" aria-hidden="true" tabindex="-1"></a>client <span class="op">=</span> MilvusClient(uri<span class="op">=</span><span class="st">"http://localhost:19530"</span>)</span>
<span id="cb26-8"><a href="#cb26-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb26-9"><a href="#cb26-9" aria-hidden="true" tabindex="-1"></a><span class="co"># With authentication (if enabled)</span></span>
<span id="cb26-10"><a href="#cb26-10" aria-hidden="true" tabindex="-1"></a>client <span class="op">=</span> MilvusClient(</span>
<span id="cb26-11"><a href="#cb26-11" aria-hidden="true" tabindex="-1"></a>    uri<span class="op">=</span><span class="st">"http://localhost:19530"</span>,</span>
<span id="cb26-12"><a href="#cb26-12" aria-hidden="true" tabindex="-1"></a>    token<span class="op">=</span><span class="st">"root:Milvus"</span>  <span class="co"># format: "username:password"</span></span>
<span id="cb26-13"><a href="#cb26-13" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb26-14"><a href="#cb26-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb26-15"><a href="#cb26-15" aria-hidden="true" tabindex="-1"></a><span class="co"># Distributed (Kubernetes)</span></span>
<span id="cb26-16"><a href="#cb26-16" aria-hidden="true" tabindex="-1"></a>client <span class="op">=</span> MilvusClient(uri<span class="op">=</span><span class="st">"http://EXTERNAL_IP:19530"</span>)</span></code></pre></div></div>
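<p>The four forms above differ only in the arguments passed to <code>MilvusClient</code>. One way to switch between them without editing code is to read the URI and token from the environment. The sketch below uses two hypothetical variable names, <code>MILVUS_URI</code> and <code>MILVUS_TOKEN</code>. pymilvus does not read these variables itself, and the Lite file path used as the default is simply the example path from above.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import os


def milvus_client_kwargs(env=None):
    """Build keyword arguments for MilvusClient from environment variables.

    MILVUS_URI and MILVUS_TOKEN are hypothetical names chosen for this sketch.
    """
    env = os.environ if env is None else env
    # Fall back to a local Milvus Lite file when no server URI is configured
    kwargs = {"uri": env.get("MILVUS_URI", "./cv_database.db")}
    token = env.get("MILVUS_TOKEN")
    if token:
        # Only pass a token when authentication is actually configured
        kwargs["token"] = token
    return kwargs


# Usage (requires pymilvus):
# from pymilvus import MilvusClient
# client = MilvusClient(**milvus_client_kwargs())
</code></pre></div></div>
<p>Keeping the token in the environment also avoids committing credentials such as <code>root:Milvus</code> to source control.</p>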
</section>
<section id="defining-a-schema" class="level3">
<h3 class="anchored" data-anchor-id="defining-a-schema" id="defining-a-schema">Defining a Schema</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb27"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb27-1"><a href="#cb27-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> pymilvus <span class="im">import</span> MilvusClient, DataType</span>
<span id="cb27-2"><a href="#cb27-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb27-3"><a href="#cb27-3" aria-hidden="true" tabindex="-1"></a>client <span class="op">=</span> MilvusClient(<span class="st">"./cv_database.db"</span>)</span>
<span id="cb27-4"><a href="#cb27-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb27-5"><a href="#cb27-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Create a schema</span></span>
<span id="cb27-6"><a href="#cb27-6" aria-hidden="true" tabindex="-1"></a>schema <span class="op">=</span> client.create_schema(</span>
<span id="cb27-7"><a href="#cb27-7" aria-hidden="true" tabindex="-1"></a>    auto_id<span class="op">=</span><span class="va">True</span>,           <span class="co"># Milvus auto-generates the primary key</span></span>
<span id="cb27-8"><a href="#cb27-8" aria-hidden="true" tabindex="-1"></a>    enable_dynamic_field<span class="op">=</span><span class="va">True</span>,  <span class="co"># Allow inserting extra fields not in schema</span></span>
<span id="cb27-9"><a href="#cb27-9" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb27-10"><a href="#cb27-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb27-11"><a href="#cb27-11" aria-hidden="true" tabindex="-1"></a><span class="co"># Add the primary key field</span></span>
<span id="cb27-12"><a href="#cb27-12" aria-hidden="true" tabindex="-1"></a>schema.add_field(</span>
<span id="cb27-13"><a href="#cb27-13" aria-hidden="true" tabindex="-1"></a>    field_name<span class="op">=</span><span class="st">"id"</span>,</span>
<span id="cb27-14"><a href="#cb27-14" aria-hidden="true" tabindex="-1"></a>    datatype<span class="op">=</span>DataType.INT64,</span>
<span id="cb27-15"><a href="#cb27-15" aria-hidden="true" tabindex="-1"></a>    is_primary<span class="op">=</span><span class="va">True</span>,</span>
<span id="cb27-16"><a href="#cb27-16" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb27-17"><a href="#cb27-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb27-18"><a href="#cb27-18" aria-hidden="true" tabindex="-1"></a><span class="co"># Add the vector field — CRITICAL: dim must match your embedding model's output size</span></span>
<span id="cb27-19"><a href="#cb27-19" aria-hidden="true" tabindex="-1"></a>schema.add_field(</span>
<span id="cb27-20"><a href="#cb27-20" aria-hidden="true" tabindex="-1"></a>    field_name<span class="op">=</span><span class="st">"embedding"</span>,</span>
<span id="cb27-21"><a href="#cb27-21" aria-hidden="true" tabindex="-1"></a>    datatype<span class="op">=</span>DataType.FLOAT_VECTOR,</span>
<span id="cb27-22"><a href="#cb27-22" aria-hidden="true" tabindex="-1"></a>    dim<span class="op">=</span><span class="dv">512</span>,  <span class="co"># Change this to match your model (e.g., 768, 1536, 2048)</span></span>
<span id="cb27-23"><a href="#cb27-23" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb27-24"><a href="#cb27-24" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb27-25"><a href="#cb27-25" aria-hidden="true" tabindex="-1"></a><span class="co"># Add scalar metadata fields</span></span>
<span id="cb27-26"><a href="#cb27-26" aria-hidden="true" tabindex="-1"></a>schema.add_field(</span>
<span id="cb27-27"><a href="#cb27-27" aria-hidden="true" tabindex="-1"></a>    field_name<span class="op">=</span><span class="st">"image_path"</span>,</span>
<span id="cb27-28"><a href="#cb27-28" aria-hidden="true" tabindex="-1"></a>    datatype<span class="op">=</span>DataType.VARCHAR,</span>
<span id="cb27-29"><a href="#cb27-29" aria-hidden="true" tabindex="-1"></a>    max_length<span class="op">=</span><span class="dv">1024</span>,</span>
<span id="cb27-30"><a href="#cb27-30" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb27-31"><a href="#cb27-31" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb27-32"><a href="#cb27-32" aria-hidden="true" tabindex="-1"></a>schema.add_field(</span>
<span id="cb27-33"><a href="#cb27-33" aria-hidden="true" tabindex="-1"></a>    field_name<span class="op">=</span><span class="st">"label"</span>,</span>
<span id="cb27-34"><a href="#cb27-34" aria-hidden="true" tabindex="-1"></a>    datatype<span class="op">=</span>DataType.VARCHAR,</span>
<span id="cb27-35"><a href="#cb27-35" aria-hidden="true" tabindex="-1"></a>    max_length<span class="op">=</span><span class="dv">128</span>,</span>
<span id="cb27-36"><a href="#cb27-36" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb27-37"><a href="#cb27-37" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb27-38"><a href="#cb27-38" aria-hidden="true" tabindex="-1"></a>schema.add_field(</span>
<span id="cb27-39"><a href="#cb27-39" aria-hidden="true" tabindex="-1"></a>    field_name<span class="op">=</span><span class="st">"confidence"</span>,</span>
<span id="cb27-40"><a href="#cb27-40" aria-hidden="true" tabindex="-1"></a>    datatype<span class="op">=</span>DataType.FLOAT,</span>
<span id="cb27-41"><a href="#cb27-41" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb27-42"><a href="#cb27-42" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb27-43"><a href="#cb27-43" aria-hidden="true" tabindex="-1"></a>schema.add_field(</span>
<span id="cb27-44"><a href="#cb27-44" aria-hidden="true" tabindex="-1"></a>    field_name<span class="op">=</span><span class="st">"timestamp"</span>,</span>
<span id="cb27-45"><a href="#cb27-45" aria-hidden="true" tabindex="-1"></a>    datatype<span class="op">=</span>DataType.INT64,  <span class="co"># store as Unix epoch milliseconds</span></span>
<span id="cb27-46"><a href="#cb27-46" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
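<p>The vector dimension and the VARCHAR lengths are fixed once the schema exists, so it can help to validate entities on the client before they reach Milvus. Milvus should itself reject a vector whose length differs from <code>dim</code>, but a client-side check reports the problem earlier in the pipeline. The following is a minimal sketch that assumes a hand-maintained mirror of the schema above. The names <code>EMBEDDING_DIM</code>, <code>VARCHAR_LIMITS</code>, <code>REQUIRED_FIELDS</code>, and <code>validate_entity</code> are hypothetical.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import math

# Hypothetical client-side mirror of the schema defined above
EMBEDDING_DIM = 512
VARCHAR_LIMITS = {"image_path": 1024, "label": 128}
REQUIRED_FIELDS = ("embedding", "image_path", "label", "confidence", "timestamp")


def validate_entity(entity):
    """Return a list of problems found in one entity dict (empty if none)."""
    problems = []
    for name in REQUIRED_FIELDS:
        if name not in entity:
            problems.append(f"missing field: {name}")

    vector = entity.get("embedding")
    if vector is not None:
        if len(vector) != EMBEDDING_DIM:
            problems.append(f"embedding has dim {len(vector)}, expected {EMBEDDING_DIM}")
        elif not all(math.isfinite(x) for x in vector):
            problems.append("embedding contains NaN or infinity")

    for name, limit in VARCHAR_LIMITS.items():
        value = entity.get(name)
        # Count UTF-8 bytes: the stricter reading of max_length for non-ASCII text
        if isinstance(value, str) and len(value.encode("utf-8")) > limit:
            problems.append(f"{name} exceeds max_length={limit}")
    return problems
</code></pre></div></div>
<p>Extra keys are not flagged because <code>enable_dynamic_field=True</code> allows them.</p>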
</section>
<section id="creating-index-parameters" class="level3">
<h3 class="anchored" data-anchor-id="creating-index-parameters" id="creating-index-parameters">Creating Index Parameters</h3>
<p>Before creating the collection, define how the vector field should be indexed:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb28"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb28-1"><a href="#cb28-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> pymilvus <span class="im">import</span> MilvusClient</span>
<span id="cb28-2"><a href="#cb28-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb28-3"><a href="#cb28-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Define index parameters for the vector field</span></span>
<span id="cb28-4"><a href="#cb28-4" aria-hidden="true" tabindex="-1"></a>index_params <span class="op">=</span> client.prepare_index_params()</span>
<span id="cb28-5"><a href="#cb28-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb28-6"><a href="#cb28-6" aria-hidden="true" tabindex="-1"></a>index_params.add_index(</span>
<span id="cb28-7"><a href="#cb28-7" aria-hidden="true" tabindex="-1"></a>    field_name<span class="op">=</span><span class="st">"embedding"</span>,      <span class="co"># must match your vector field name</span></span>
<span id="cb28-8"><a href="#cb28-8" aria-hidden="true" tabindex="-1"></a>    index_type<span class="op">=</span><span class="st">"HNSW"</span>,           <span class="co"># index algorithm (see Section 9 for all options)</span></span>
<span id="cb28-9"><a href="#cb28-9" aria-hidden="true" tabindex="-1"></a>    metric_type<span class="op">=</span><span class="st">"COSINE"</span>,        <span class="co"># distance metric: L2, IP, or COSINE</span></span>
<span id="cb28-10"><a href="#cb28-10" aria-hidden="true" tabindex="-1"></a>    params<span class="op">=</span>{</span>
<span id="cb28-11"><a href="#cb28-11" aria-hidden="true" tabindex="-1"></a>        <span class="st">"M"</span>: <span class="dv">16</span>,                 <span class="co"># HNSW: number of neighbors per node (8–64; higher = better recall, more memory)</span></span>
<span id="cb28-12"><a href="#cb28-12" aria-hidden="true" tabindex="-1"></a>        <span class="st">"efConstruction"</span>: <span class="dv">200</span>,   <span class="co"># HNSW: build-time search depth (higher = better quality index, slower build)</span></span>
<span id="cb28-13"><a href="#cb28-13" aria-hidden="true" tabindex="-1"></a>    }</span>
<span id="cb28-14"><a href="#cb28-14" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb28-15"><a href="#cb28-15" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb28-16"><a href="#cb28-16" aria-hidden="true" tabindex="-1"></a><span class="co"># Also create an index on a scalar field for fast filtering</span></span>
<span id="cb28-17"><a href="#cb28-17" aria-hidden="true" tabindex="-1"></a>index_params.add_index(</span>
<span id="cb28-18"><a href="#cb28-18" aria-hidden="true" tabindex="-1"></a>    field_name<span class="op">=</span><span class="st">"label"</span>,</span>
<span id="cb28-19"><a href="#cb28-19" aria-hidden="true" tabindex="-1"></a>    index_type<span class="op">=</span><span class="st">"Trie"</span>,           <span class="co"># prefix-tree index for VARCHAR fields ("INVERTED" is an alternative)</span></span>
<span id="cb28-20"><a href="#cb28-20" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
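<p>When embeddings are normalized to unit length, as the embedding helpers in the next section do, the choice of <code>metric_type</code> matters less than it might seem. For unit vectors, inner product equals cosine similarity and squared L2 distance equals 2 - 2 cos(a, b). All three metrics therefore return the same nearest neighbors, with L2 ranking by ascending distance instead of descending similarity. The pure-Python sketch below checks these identities numerically. The helper names are illustrative and are not Milvus functions.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import math


def inner_product(a, b):
    # IP metric: larger means more similar
    return sum(x * y for x, y in zip(a, b))


def l2_squared(a, b):
    # Squared Euclidean distance: smaller means more similar
    return sum((x - y) ** 2 for x, y in zip(a, b))


def cosine(a, b):
    # COSINE metric: inner product divided by the product of the norms
    return inner_product(a, b) / math.sqrt(inner_product(a, a) * inner_product(b, b))


def normalize(v):
    # Scale a vector to unit length
    norm = math.sqrt(inner_product(v, v))
    return [x / norm for x in v]


a = normalize([0.2, 0.9, -0.4])
b = normalize([0.1, 0.8, 0.3])

# For unit vectors both differences are zero up to floating-point error
print(inner_product(a, b) - cosine(a, b))
print(l2_squared(a, b) - (2 - 2 * cosine(a, b)))
</code></pre></div></div>
<p>For vectors that are not normalized, IP and COSINE can rank results differently, so the metric should match how the embedding model was trained.</p>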
</section>
<section id="creating-the-collection" class="level3">
<h3 class="anchored" data-anchor-id="creating-the-collection" id="creating-the-collection">Creating the Collection</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb29"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb29-1"><a href="#cb29-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Create the collection with the schema and index parameters</span></span>
<span id="cb29-2"><a href="#cb29-2" aria-hidden="true" tabindex="-1"></a>client.create_collection(</span>
<span id="cb29-3"><a href="#cb29-3" aria-hidden="true" tabindex="-1"></a>    collection_name<span class="op">=</span><span class="st">"image_embeddings"</span>,</span>
<span id="cb29-4"><a href="#cb29-4" aria-hidden="true" tabindex="-1"></a>    schema<span class="op">=</span>schema,</span>
<span id="cb29-5"><a href="#cb29-5" aria-hidden="true" tabindex="-1"></a>    index_params<span class="op">=</span>index_params,</span>
<span id="cb29-6"><a href="#cb29-6" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb29-7"><a href="#cb29-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb29-8"><a href="#cb29-8" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"Collection created successfully"</span>)</span>
<span id="cb29-9"><a href="#cb29-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb29-10"><a href="#cb29-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Verify it exists</span></span>
<span id="cb29-11"><a href="#cb29-11" aria-hidden="true" tabindex="-1"></a>collections <span class="op">=</span> client.list_collections()</span>
<span id="cb29-12"><a href="#cb29-12" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Collections: </span><span class="sc">{</span>collections<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb29-13"><a href="#cb29-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb29-14"><a href="#cb29-14" aria-hidden="true" tabindex="-1"></a><span class="co"># Get collection info</span></span>
<span id="cb29-15"><a href="#cb29-15" aria-hidden="true" tabindex="-1"></a>info <span class="op">=</span> client.describe_collection(<span class="st">"image_embeddings"</span>)</span>
<span id="cb29-16"><a href="#cb29-16" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(info)</span></code></pre></div></div>
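<p>Ingestion scripts are often re-run, so it can be convenient to create the collection only when it is missing. The following is a minimal sketch that uses <code>MilvusClient.has_collection</code>. <code>ensure_collection</code> is a hypothetical helper name, and the sketch does not check whether an existing collection's schema matches <code>schema</code>.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">def ensure_collection(client, collection_name, schema, index_params):
    """Create the collection only if it does not already exist.

    Returns True when a new collection was created, False otherwise.
    """
    if client.has_collection(collection_name):
        # Leave existing data untouched on re-runs
        return False
    client.create_collection(
        collection_name=collection_name,
        schema=schema,
        index_params=index_params,
    )
    return True


# Usage with the objects defined above:
# ensure_collection(client, "image_embeddings", schema, index_params)
</code></pre></div></div>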
</section>
<section id="quick-collection-creation-simplified-api" class="level3">
<h3 class="anchored" data-anchor-id="quick-collection-creation-simplified-api" id="quick-collection-creation-simplified-api">Quick Collection Creation (Simplified API)</h3>
<p>For rapid prototyping, MilvusClient allows creating a collection with just a dimension:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb30"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb30-1"><a href="#cb30-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Creates a collection with auto schema: id (INT64 PK) + vector (FLOAT_VECTOR)</span></span>
<span id="cb30-2"><a href="#cb30-2" aria-hidden="true" tabindex="-1"></a>client.create_collection(</span>
<span id="cb30-3"><a href="#cb30-3" aria-hidden="true" tabindex="-1"></a>    collection_name<span class="op">=</span><span class="st">"quick_test"</span>,</span>
<span id="cb30-4"><a href="#cb30-4" aria-hidden="true" tabindex="-1"></a>    dimension<span class="op">=</span><span class="dv">512</span>,</span>
<span id="cb30-5"><a href="#cb30-5" aria-hidden="true" tabindex="-1"></a>    metric_type<span class="op">=</span><span class="st">"COSINE"</span>,</span>
<span id="cb30-6"><a href="#cb30-6" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb30-7"><a href="#cb30-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Handy for testing, but you cannot declare typed scalar fields; extra keys are stored in the dynamic field</span></span></code></pre></div></div>
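<p>Quick-setup collections use the default field names <code>id</code> and <code>vector</code>, and <code>auto_id</code> stays off unless you pass <code>auto_id=True</code>. As a result, each inserted row must carry its own primary key. The sketch below shows one way to build such rows. <code>quick_rows</code> is a hypothetical helper.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">def quick_rows(vectors, start_id=0):
    """Build insert rows for a quick-setup collection.

    Assumes the default field names "id" and "vector" and that auto_id was
    left disabled, so every row carries an explicit primary key.
    """
    return [
        {"id": start_id + offset, "vector": list(vector)}
        for offset, vector in enumerate(vectors)
    ]


# Usage (with a client from the examples above):
# client.insert(collection_name="quick_test", data=quick_rows(vectors))
</code></pre></div></div>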
</section>
<section id="dropping-a-collection" class="level3">
<h3 class="anchored" data-anchor-id="dropping-a-collection" id="dropping-a-collection">Dropping a Collection</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb31"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb31-1"><a href="#cb31-1" aria-hidden="true" tabindex="-1"></a><span class="co"># </span><span class="al">WARNING</span><span class="co">: This permanently deletes all data in the collection</span></span>
<span id="cb31-2"><a href="#cb31-2" aria-hidden="true" tabindex="-1"></a>client.drop_collection(<span class="st">"image_embeddings"</span>)</span></code></pre></div></div>
<hr>
</section>
</section>
<section id="sec-inserting" class="level2">
<h2 class="anchored" data-anchor-id="sec-inserting" id="sec-inserting">8. Inserting Embedding Vectors</h2>
<section id="basic-insertion" class="level3">
<h3 class="anchored" data-anchor-id="basic-insertion" id="basic-insertion">Basic Insertion</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb32"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb32-1"><a href="#cb32-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb32-2"><a href="#cb32-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb32-3"><a href="#cb32-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb32-4"><a href="#cb32-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Simulate embedding extraction</span></span>
<span id="cb32-5"><a href="#cb32-5" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> mock_embed(n: <span class="bu">int</span>, dim: <span class="bu">int</span> <span class="op">=</span> <span class="dv">512</span>) <span class="op">-&gt;</span> <span class="bu">list</span>:</span>
<span id="cb32-6"><a href="#cb32-6" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Generate random normalized vectors to simulate embeddings."""</span></span>
<span id="cb32-7"><a href="#cb32-7" aria-hidden="true" tabindex="-1"></a>    vectors <span class="op">=</span> np.random.randn(n, dim).astype(np.float32)</span>
<span id="cb32-8"><a href="#cb32-8" aria-hidden="true" tabindex="-1"></a>    norms <span class="op">=</span> np.linalg.norm(vectors, axis<span class="op">=</span><span class="dv">1</span>, keepdims<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb32-9"><a href="#cb32-9" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> (vectors <span class="op">/</span> norms).tolist()</span>
<span id="cb32-10"><a href="#cb32-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb32-11"><a href="#cb32-11" aria-hidden="true" tabindex="-1"></a><span class="co"># Prepare data as a list of dicts (one dict per entity)</span></span>
<span id="cb32-12"><a href="#cb32-12" aria-hidden="true" tabindex="-1"></a>data <span class="op">=</span> [</span>
<span id="cb32-13"><a href="#cb32-13" aria-hidden="true" tabindex="-1"></a>    {</span>
<span id="cb32-14"><a href="#cb32-14" aria-hidden="true" tabindex="-1"></a>        <span class="co"># "id" is omitted because auto_id=True</span></span>
<span id="cb32-15"><a href="#cb32-15" aria-hidden="true" tabindex="-1"></a>        <span class="st">"embedding"</span>: mock_embed(<span class="dv">1</span>, dim<span class="op">=</span><span class="dv">512</span>)[<span class="dv">0</span>],</span>
<span id="cb32-16"><a href="#cb32-16" aria-hidden="true" tabindex="-1"></a>        <span class="st">"image_path"</span>: <span class="st">"/dataset/images/dog_001.jpg"</span>,</span>
<span id="cb32-17"><a href="#cb32-17" aria-hidden="true" tabindex="-1"></a>        <span class="st">"label"</span>: <span class="st">"dog"</span>,</span>
<span id="cb32-18"><a href="#cb32-18" aria-hidden="true" tabindex="-1"></a>        <span class="st">"confidence"</span>: <span class="fl">0.97</span>,</span>
<span id="cb32-19"><a href="#cb32-19" aria-hidden="true" tabindex="-1"></a>        <span class="st">"timestamp"</span>: <span class="bu">int</span>(time.time() <span class="op">*</span> <span class="dv">1000</span>),</span>
<span id="cb32-20"><a href="#cb32-20" aria-hidden="true" tabindex="-1"></a>    },</span>
<span id="cb32-21"><a href="#cb32-21" aria-hidden="true" tabindex="-1"></a>    {</span>
<span id="cb32-22"><a href="#cb32-22" aria-hidden="true" tabindex="-1"></a>        <span class="st">"embedding"</span>: mock_embed(<span class="dv">1</span>, dim<span class="op">=</span><span class="dv">512</span>)[<span class="dv">0</span>],</span>
<span id="cb32-23"><a href="#cb32-23" aria-hidden="true" tabindex="-1"></a>        <span class="st">"image_path"</span>: <span class="st">"/dataset/images/cat_002.jpg"</span>,</span>
<span id="cb32-24"><a href="#cb32-24" aria-hidden="true" tabindex="-1"></a>        <span class="st">"label"</span>: <span class="st">"cat"</span>,</span>
<span id="cb32-25"><a href="#cb32-25" aria-hidden="true" tabindex="-1"></a>        <span class="st">"confidence"</span>: <span class="fl">0.92</span>,</span>
<span id="cb32-26"><a href="#cb32-26" aria-hidden="true" tabindex="-1"></a>        <span class="st">"timestamp"</span>: <span class="bu">int</span>(time.time() <span class="op">*</span> <span class="dv">1000</span>),</span>
<span id="cb32-27"><a href="#cb32-27" aria-hidden="true" tabindex="-1"></a>    },</span>
<span id="cb32-28"><a href="#cb32-28" aria-hidden="true" tabindex="-1"></a>]</span>
<span id="cb32-29"><a href="#cb32-29" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb32-30"><a href="#cb32-30" aria-hidden="true" tabindex="-1"></a><span class="co"># Insert the data</span></span>
<span id="cb32-31"><a href="#cb32-31" aria-hidden="true" tabindex="-1"></a>result <span class="op">=</span> client.insert(</span>
<span id="cb32-32"><a href="#cb32-32" aria-hidden="true" tabindex="-1"></a>    collection_name<span class="op">=</span><span class="st">"image_embeddings"</span>,</span>
<span id="cb32-33"><a href="#cb32-33" aria-hidden="true" tabindex="-1"></a>    data<span class="op">=</span>data,</span>
<span id="cb32-34"><a href="#cb32-34" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb32-35"><a href="#cb32-35" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb32-36"><a href="#cb32-36" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Inserted </span><span class="sc">{</span>result[<span class="st">'insert_count'</span>]<span class="sc">}</span><span class="ss"> entities"</span>)</span>
<span id="cb32-37"><a href="#cb32-37" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Primary keys: </span><span class="sc">{</span>result[<span class="st">'ids'</span>]<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
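<p>With <code>auto_id=True</code>, the keys in <code>result['ids']</code> are the only handle on the new entities for later lookups or deletes, so it is worth recording which key belongs to which image. The following is a minimal sketch that assumes the returned keys appear in the same order as the input rows. <code>map_ids_to_paths</code> is a hypothetical helper.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">def map_ids_to_paths(insert_result, rows):
    """Pair auto-generated primary keys with the image paths that were inserted.

    Assumes insert_result["ids"] lists keys in the same order as rows.
    """
    ids = list(insert_result["ids"])
    if len(ids) != len(rows):
        raise ValueError(f"got {len(ids)} ids for {len(rows)} rows")
    return {pk: row["image_path"] for pk, row in zip(ids, rows)}


# Usage with the insert above:
# id_to_path = map_ids_to_paths(result, data)
</code></pre></div></div>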
</section>
<section id="batch-insertion-production-pattern" class="level3">
<h3 class="anchored" data-anchor-id="batch-insertion-production-pattern" id="batch-insertion-production-pattern">Batch Insertion (Production Pattern)</h3>
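<p>For large datasets, insert in batches rather than one entity at a time. Batches of roughly <strong>1,000–10,000 entities</strong> per insert call are a reasonable starting point. They are large enough to amortize per-request overhead yet small enough to stay within the server's request-size limit. A typical pattern looks like this:</p>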
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb33"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb33-1"><a href="#cb33-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
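<span id="cb33-2"><a href="#cb33-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb33-2a"><a href="#cb33-2a" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> pymilvus <span class="im">import</span> MilvusClient</span>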
<span id="cb33-3"><a href="#cb33-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb33-4"><a href="#cb33-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> embed_batch(image_paths: <span class="bu">list</span>, dim: <span class="bu">int</span> <span class="op">=</span> <span class="dv">512</span>) <span class="op">-&gt;</span> <span class="bu">list</span>:</span>
<span id="cb33-5"><a href="#cb33-5" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb33-6"><a href="#cb33-6" aria-hidden="true" tabindex="-1"></a><span class="co">    Placeholder function — replace with your actual embedding model call.</span></span>
<span id="cb33-7"><a href="#cb33-7" aria-hidden="true" tabindex="-1"></a><span class="co">    Should return a list of normalized float vectors.</span></span>
<span id="cb33-8"><a href="#cb33-8" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb33-9"><a href="#cb33-9" aria-hidden="true" tabindex="-1"></a>    n <span class="op">=</span> <span class="bu">len</span>(image_paths)</span>
<span id="cb33-10"><a href="#cb33-10" aria-hidden="true" tabindex="-1"></a>    vectors <span class="op">=</span> np.random.randn(n, dim).astype(np.float32)</span>
<span id="cb33-11"><a href="#cb33-11" aria-hidden="true" tabindex="-1"></a>    norms <span class="op">=</span> np.linalg.norm(vectors, axis<span class="op">=</span><span class="dv">1</span>, keepdims<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb33-12"><a href="#cb33-12" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> (vectors <span class="op">/</span> norms).tolist()</span>
<span id="cb33-13"><a href="#cb33-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb33-14"><a href="#cb33-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb33-15"><a href="#cb33-15" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> insert_images_in_batches(</span>
<span id="cb33-16"><a href="#cb33-16" aria-hidden="true" tabindex="-1"></a>    client: MilvusClient,</span>
<span id="cb33-17"><a href="#cb33-17" aria-hidden="true" tabindex="-1"></a>    collection_name: <span class="bu">str</span>,</span>
<span id="cb33-18"><a href="#cb33-18" aria-hidden="true" tabindex="-1"></a>    image_paths: <span class="bu">list</span>,</span>
<span id="cb33-19"><a href="#cb33-19" aria-hidden="true" tabindex="-1"></a>    labels: <span class="bu">list</span>,</span>
<span id="cb33-20"><a href="#cb33-20" aria-hidden="true" tabindex="-1"></a>    batch_size: <span class="bu">int</span> <span class="op">=</span> <span class="dv">2000</span>,</span>
<span id="cb33-21"><a href="#cb33-21" aria-hidden="true" tabindex="-1"></a>    embedding_dim: <span class="bu">int</span> <span class="op">=</span> <span class="dv">512</span>,</span>
<span id="cb33-22"><a href="#cb33-22" aria-hidden="true" tabindex="-1"></a>):</span>
<span id="cb33-23"><a href="#cb33-23" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb33-24"><a href="#cb33-24" aria-hidden="true" tabindex="-1"></a><span class="co">    Extracts embeddings from images and inserts them into Milvus in batches.</span></span>
<span id="cb33-25"><a href="#cb33-25" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb33-26"><a href="#cb33-26" aria-hidden="true" tabindex="-1"></a>    total <span class="op">=</span> <span class="bu">len</span>(image_paths)</span>
<span id="cb33-27"><a href="#cb33-27" aria-hidden="true" tabindex="-1"></a>    inserted <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb33-28"><a href="#cb33-28" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb33-29"><a href="#cb33-29" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> start <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">0</span>, total, batch_size):</span>
<span id="cb33-30"><a href="#cb33-30" aria-hidden="true" tabindex="-1"></a>        end <span class="op">=</span> <span class="bu">min</span>(start <span class="op">+</span> batch_size, total)</span>
<span id="cb33-31"><a href="#cb33-31" aria-hidden="true" tabindex="-1"></a>        batch_paths <span class="op">=</span> image_paths[start:end]</span>
<span id="cb33-32"><a href="#cb33-32" aria-hidden="true" tabindex="-1"></a>        batch_labels <span class="op">=</span> labels[start:end]</span>
<span id="cb33-33"><a href="#cb33-33" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb33-34"><a href="#cb33-34" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Extract embeddings for this batch</span></span>
<span id="cb33-35"><a href="#cb33-35" aria-hidden="true" tabindex="-1"></a>        batch_embeddings <span class="op">=</span> embed_batch(batch_paths, dim<span class="op">=</span>embedding_dim)</span>
<span id="cb33-36"><a href="#cb33-36" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb33-37"><a href="#cb33-37" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Build the data list</span></span>
<span id="cb33-38"><a href="#cb33-38" aria-hidden="true" tabindex="-1"></a>        batch_data <span class="op">=</span> [</span>
<span id="cb33-39"><a href="#cb33-39" aria-hidden="true" tabindex="-1"></a>            {</span>
<span id="cb33-40"><a href="#cb33-40" aria-hidden="true" tabindex="-1"></a>                <span class="st">"embedding"</span>: batch_embeddings[i],</span>
<span id="cb33-41"><a href="#cb33-41" aria-hidden="true" tabindex="-1"></a>                <span class="st">"image_path"</span>: batch_paths[i],</span>
<span id="cb33-42"><a href="#cb33-42" aria-hidden="true" tabindex="-1"></a>                <span class="st">"label"</span>: batch_labels[i],</span>
<span id="cb33-43"><a href="#cb33-43" aria-hidden="true" tabindex="-1"></a>                <span class="st">"confidence"</span>: <span class="fl">1.0</span>,  <span class="co"># placeholder</span></span>
<span id="cb33-44"><a href="#cb33-44" aria-hidden="true" tabindex="-1"></a>                <span class="st">"timestamp"</span>: <span class="bu">int</span>(time.time() <span class="op">*</span> <span class="dv">1000</span>),</span>
<span id="cb33-45"><a href="#cb33-45" aria-hidden="true" tabindex="-1"></a>            }</span>
<span id="cb33-46"><a href="#cb33-46" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="bu">len</span>(batch_paths))</span>
<span id="cb33-47"><a href="#cb33-47" aria-hidden="true" tabindex="-1"></a>        ]</span>
<span id="cb33-48"><a href="#cb33-48" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb33-49"><a href="#cb33-49" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Insert</span></span>
<span id="cb33-50"><a href="#cb33-50" aria-hidden="true" tabindex="-1"></a>        result <span class="op">=</span> client.insert(</span>
<span id="cb33-51"><a href="#cb33-51" aria-hidden="true" tabindex="-1"></a>            collection_name<span class="op">=</span>collection_name,</span>
<span id="cb33-52"><a href="#cb33-52" aria-hidden="true" tabindex="-1"></a>            data<span class="op">=</span>batch_data,</span>
<span id="cb33-53"><a href="#cb33-53" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb33-54"><a href="#cb33-54" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb33-55"><a href="#cb33-55" aria-hidden="true" tabindex="-1"></a>        inserted <span class="op">+=</span> result[<span class="st">"insert_count"</span>]</span>
<span id="cb33-56"><a href="#cb33-56" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Progress: </span><span class="sc">{</span>inserted<span class="sc">}</span><span class="ss">/</span><span class="sc">{</span>total<span class="sc">}</span><span class="ss"> (</span><span class="sc">{</span><span class="dv">100</span><span class="op">*</span>inserted<span class="op">/</span>total<span class="sc">:.1f}</span><span class="ss">%)"</span>)</span>
<span id="cb33-57"><a href="#cb33-57" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb33-58"><a href="#cb33-58" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Done! Inserted </span><span class="sc">{</span>inserted<span class="sc">}</span><span class="ss"> entities total."</span>)</span>
<span id="cb33-59"><a href="#cb33-59" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> inserted</span>
<span id="cb33-60"><a href="#cb33-60" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb33-61"><a href="#cb33-61" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb33-62"><a href="#cb33-62" aria-hidden="true" tabindex="-1"></a><span class="co"># Example usage</span></span>
<span id="cb33-63"><a href="#cb33-63" aria-hidden="true" tabindex="-1"></a>image_paths <span class="op">=</span> [<span class="ss">f"/data/images/img_</span><span class="sc">{</span>i<span class="sc">:06d}</span><span class="ss">.jpg"</span> <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">100_000</span>)]</span>
<span id="cb33-64"><a href="#cb33-64" aria-hidden="true" tabindex="-1"></a>labels <span class="op">=</span> [<span class="st">"dog"</span> <span class="cf">if</span> i <span class="op">%</span> <span class="dv">2</span> <span class="op">==</span> <span class="dv">0</span> <span class="cf">else</span> <span class="st">"cat"</span> <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">100_000</span>)]</span>
<span id="cb33-65"><a href="#cb33-65" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb33-66"><a href="#cb33-66" aria-hidden="true" tabindex="-1"></a>insert_images_in_batches(</span>
<span id="cb33-67"><a href="#cb33-67" aria-hidden="true" tabindex="-1"></a>    client<span class="op">=</span>client,</span>
<span id="cb33-68"><a href="#cb33-68" aria-hidden="true" tabindex="-1"></a>    collection_name<span class="op">=</span><span class="st">"image_embeddings"</span>,</span>
<span id="cb33-69"><a href="#cb33-69" aria-hidden="true" tabindex="-1"></a>    image_paths<span class="op">=</span>image_paths,</span>
<span id="cb33-70"><a href="#cb33-70" aria-hidden="true" tabindex="-1"></a>    labels<span class="op">=</span>labels,</span>
<span id="cb33-71"><a href="#cb33-71" aria-hidden="true" tabindex="-1"></a>    batch_size<span class="op">=</span><span class="dv">2000</span>,</span>
<span id="cb33-72"><a href="#cb33-72" aria-hidden="true" tabindex="-1"></a>    embedding_dim<span class="op">=</span><span class="dv">512</span>,</span>
<span id="cb33-73"><a href="#cb33-73" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
</section>
<section id="upsert-insert-or-update" class="level3">
<h3 class="anchored" data-anchor-id="upsert-insert-or-update" id="upsert-insert-or-update">Upsert (Insert or Update)</h3>
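<p>If an entity with the given primary key already exists, <code>upsert</code> replaces it; otherwise the entity is inserted as new. Because matching is done on the primary key, each record you upsert must include it:</p>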
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb34"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb34-1"><a href="#cb34-1" aria-hidden="true" tabindex="-1"></a>result <span class="op">=</span> client.upsert(</span>
<span id="cb34-2"><a href="#cb34-2" aria-hidden="true" tabindex="-1"></a>    collection_name<span class="op">=</span><span class="st">"image_embeddings"</span>,</span>
<span id="cb34-3"><a href="#cb34-3" aria-hidden="true" tabindex="-1"></a>    data<span class="op">=</span>[</span>
<span id="cb34-4"><a href="#cb34-4" aria-hidden="true" tabindex="-1"></a>        {</span>
<span id="cb34-5"><a href="#cb34-5" aria-hidden="true" tabindex="-1"></a>            <span class="st">"id"</span>: <span class="dv">42</span>,                  <span class="co"># specify the existing ID to update</span></span>
<span id="cb34-6"><a href="#cb34-6" aria-hidden="true" tabindex="-1"></a>            <span class="st">"embedding"</span>: new_vector,</span>
<span id="cb34-7"><a href="#cb34-7" aria-hidden="true" tabindex="-1"></a>            <span class="st">"image_path"</span>: <span class="st">"/updated/path/image.jpg"</span>,</span>
<span id="cb34-8"><a href="#cb34-8" aria-hidden="true" tabindex="-1"></a>            <span class="st">"label"</span>: <span class="st">"updated_label"</span>,</span>
<span id="cb34-9"><a href="#cb34-9" aria-hidden="true" tabindex="-1"></a>            <span class="st">"confidence"</span>: <span class="fl">0.99</span>,</span>
<span id="cb34-10"><a href="#cb34-10" aria-hidden="true" tabindex="-1"></a>            <span class="st">"timestamp"</span>: <span class="bu">int</span>(time.time() <span class="op">*</span> <span class="dv">1000</span>),</span>
<span id="cb34-11"><a href="#cb34-11" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb34-12"><a href="#cb34-12" aria-hidden="true" tabindex="-1"></a>    ],</span>
<span id="cb34-13"><a href="#cb34-13" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
</section>
<section id="deleting-entities" class="level3">
<h3 class="anchored" data-anchor-id="deleting-entities" id="deleting-entities">Deleting Entities</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb35"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb35-1"><a href="#cb35-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Delete by primary key</span></span>
<span id="cb35-2"><a href="#cb35-2" aria-hidden="true" tabindex="-1"></a>client.delete(</span>
<span id="cb35-3"><a href="#cb35-3" aria-hidden="true" tabindex="-1"></a>    collection_name<span class="op">=</span><span class="st">"image_embeddings"</span>,</span>
<span id="cb35-4"><a href="#cb35-4" aria-hidden="true" tabindex="-1"></a>    ids<span class="op">=</span>[<span class="dv">1</span>, <span class="dv">2</span>, <span class="dv">3</span>, <span class="dv">42</span>],</span>
<span id="cb35-5"><a href="#cb35-5" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb35-6"><a href="#cb35-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb35-7"><a href="#cb35-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Delete by filter expression</span></span>
<span id="cb35-8"><a href="#cb35-8" aria-hidden="true" tabindex="-1"></a>client.delete(</span>
<span id="cb35-9"><a href="#cb35-9" aria-hidden="true" tabindex="-1"></a>    collection_name<span class="op">=</span><span class="st">"image_embeddings"</span>,</span>
<span id="cb35-10"><a href="#cb35-10" aria-hidden="true" tabindex="-1"></a>    <span class="bu">filter</span><span class="op">=</span><span class="st">"label == 'cat'"</span>,</span>
<span id="cb35-11"><a href="#cb35-11" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
</section>
<section id="data-freshness-and-the-growing-segment-delay" class="level3">
<h3 class="anchored" data-anchor-id="data-freshness-and-the-growing-segment-delay" id="data-freshness-and-the-growing-segment-delay">Data Freshness and the “Growing Segment” Delay</h3>
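<p>Newly inserted data first lands in a <strong>growing segment</strong>, which is not yet covered by the index you built. Searches over growing segments fall back to brute-force scanning (recent Milvus versions may use a lightweight interim index instead), so they are slower than indexed searches. Milvus seals and indexes segments automatically as they fill up; if you need them sealed and persisted right away, for example after a one-off bulk load, you can force a flush. Avoid flushing after every small insert, because frequent flushes can leave behind many small segments:</p>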
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb36"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb36-1"><a href="#cb36-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Force flush — seals all growing segments and ensures data is persisted</span></span>
<span id="cb36-2"><a href="#cb36-2" aria-hidden="true" tabindex="-1"></a>client.flush(collection_name<span class="op">=</span><span class="st">"image_embeddings"</span>)</span></code></pre></div></div>
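<p>After a flush, Milvus builds the index on the sealed segments asynchronously. Index building does not control visibility: data in growing segments is already searchable, just more slowly. Whether a search is guaranteed to see your most recent writes is governed by the <strong>consistency level</strong>. For queries that must see every write acknowledged before the search was issued, set <code>consistency_level="Strong"</code>:</p>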
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb37"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb37-1"><a href="#cb37-1" aria-hidden="true" tabindex="-1"></a>results <span class="op">=</span> client.search(</span>
<span id="cb37-2"><a href="#cb37-2" aria-hidden="true" tabindex="-1"></a>    collection_name<span class="op">=</span><span class="st">"image_embeddings"</span>,</span>
<span id="cb37-3"><a href="#cb37-3" aria-hidden="true" tabindex="-1"></a>    data<span class="op">=</span>[query_vector],</span>
<span id="cb37-4"><a href="#cb37-4" aria-hidden="true" tabindex="-1"></a>    limit<span class="op">=</span><span class="dv">10</span>,</span>
<span id="cb37-5"><a href="#cb37-5" aria-hidden="true" tabindex="-1"></a>    consistency_level<span class="op">=</span><span class="st">"Strong"</span>,  <span class="co"># waits for latest data to be visible</span></span>
<span id="cb37-6"><a href="#cb37-6" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
<p>Consistency levels:</p>
<ul>
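<li><code>"Strong"</code>: always sees the latest data; highest consistency, highest latency</li>
<li><code>"Bounded"</code>: may miss writes from the last few seconds (a bounded staleness window); a good default for most CV use cases</li>
<li><code>"Session"</code>: a client always sees its own writes, but not necessarily other clients’ most recent writes</li>
<li><code>"Eventually"</code>: fastest; may miss very recent inserts</li>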
</ul>
<hr>
</section>
</section>
<section id="sec-indexes" class="level2">
<h2 class="anchored" data-anchor-id="sec-indexes" id="sec-indexes">9. Index Types and When to Use Each</h2>
<p>Choosing the right index is crucial for balancing <strong>search speed</strong>, <strong>recall accuracy</strong>, and <strong>memory usage</strong>. Here is a detailed breakdown of the main vector index types Milvus offers.</p>
<section id="flat-exact-search-brute-force" class="level3">
<h3 class="anchored" data-anchor-id="flat-exact-search-brute-force" id="flat-exact-search-brute-force">FLAT (Exact Search / Brute Force)</h3>
<p><strong>How it works:</strong> Compares the query vector against every single vector in the collection. No approximation — always returns the true nearest neighbors.</p>
<p><strong>Parameters:</strong> None.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb38"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb38-1"><a href="#cb38-1" aria-hidden="true" tabindex="-1"></a>index_params.add_index(</span>
<span id="cb38-2"><a href="#cb38-2" aria-hidden="true" tabindex="-1"></a>    field_name<span class="op">=</span><span class="st">"embedding"</span>,</span>
<span id="cb38-3"><a href="#cb38-3" aria-hidden="true" tabindex="-1"></a>    index_type<span class="op">=</span><span class="st">"FLAT"</span>,</span>
<span id="cb38-4"><a href="#cb38-4" aria-hidden="true" tabindex="-1"></a>    metric_type<span class="op">=</span><span class="st">"COSINE"</span>,</span>
<span id="cb38-5"><a href="#cb38-5" aria-hidden="true" tabindex="-1"></a>    params<span class="op">=</span>{},</span>
<span id="cb38-6"><a href="#cb38-6" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
<p><strong>Pros:</strong> 100% recall (always finds the true nearest neighbors); no build time.</p>
<p><strong>Cons:</strong> O(N·d) work per query (N vectors of dimension d), so latency grows linearly with collection size; usually impractical beyond ~500K vectors when low latency matters.</p>
<p><strong>Best for:</strong> Exact-search requirements, small datasets (up to ~500K vectors), and computing ground truth when benchmarking other indexes.</p>
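<p>To make “no approximation” concrete, here is a minimal NumPy sketch of what a FLAT index computes for <code>metric_type="COSINE"</code>. It is not Milvus code; the function name <code>flat_cosine_search</code> and the random data are illustrative assumptions. Every stored vector is scored against the query, which is exactly where the O(N·d) cost comes from.</p>

```python
import numpy as np

def flat_cosine_search(database: np.ndarray, query: np.ndarray, k: int):
    """Exact top-k search by cosine similarity (illustrative, not Milvus code).

    Scores the query against every row of `database`: O(N * d) work per query,
    with no approximation, which is what a FLAT index does conceptually.
    """
    # Normalize so that a dot product equals cosine similarity
    db_norm = database / np.linalg.norm(database, axis=1, keepdims=True)
    q_norm = query / np.linalg.norm(query)
    scores = db_norm @ q_norm                      # shape (N,)
    # argpartition selects the k best in O(N); then sort only those k
    top_k = np.argpartition(-scores, k - 1)[:k]
    top_k = top_k[np.argsort(-scores[top_k])]
    return top_k, scores[top_k]

rng = np.random.default_rng(0)
vectors = rng.standard_normal((10_000, 512)).astype(np.float32)

# Query with a stored vector: it should come back as its own nearest neighbor
ids, sims = flat_cosine_search(vectors, vectors[123], k=5)
print(ids, np.round(sims, 3))
```

<p>Because the result is exact, this kind of scan is also how you would produce the ground-truth neighbors used to measure the recall of the approximate indexes below.</p>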
<hr>
</section>
<section id="ivf_flat-inverted-file-index" class="level3">
<h3 class="anchored" data-anchor-id="ivf_flat-inverted-file-index" id="ivf_flat-inverted-file-index">IVF_FLAT (Inverted File Index)</h3>
<p><strong>How it works:</strong> During index building, vectors are clustered into <code>nlist</code> Voronoi cells using k-means. Each vector is assigned to its nearest cluster centroid. At query time, the <code>nprobe</code> nearest cluster centroids are identified, and only the vectors in those clusters are searched.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb39"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb39-1"><a href="#cb39-1" aria-hidden="true" tabindex="-1"></a>index_params.add_index(</span>
<span id="cb39-2"><a href="#cb39-2" aria-hidden="true" tabindex="-1"></a>    field_name<span class="op">=</span><span class="st">"embedding"</span>,</span>
<span id="cb39-3"><a href="#cb39-3" aria-hidden="true" tabindex="-1"></a>    index_type<span class="op">=</span><span class="st">"IVF_FLAT"</span>,</span>
<span id="cb39-4"><a href="#cb39-4" aria-hidden="true" tabindex="-1"></a>    metric_type<span class="op">=</span><span class="st">"L2"</span>,</span>
<span id="cb39-5"><a href="#cb39-5" aria-hidden="true" tabindex="-1"></a>    params<span class="op">=</span>{</span>
<span id="cb39-6"><a href="#cb39-6" aria-hidden="true" tabindex="-1"></a>        <span class="st">"nlist"</span>: <span class="dv">1024</span>,  <span class="co"># number of clusters. Rule of thumb: sqrt(N) where N = dataset size</span></span>
<span id="cb39-7"><a href="#cb39-7" aria-hidden="true" tabindex="-1"></a>    }</span>
<span id="cb39-8"><a href="#cb39-8" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
<p><strong>Search parameters</strong> (set at query time):</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb40"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb40-1"><a href="#cb40-1" aria-hidden="true" tabindex="-1"></a>search_params <span class="op">=</span> {</span>
<span id="cb40-2"><a href="#cb40-2" aria-hidden="true" tabindex="-1"></a>    <span class="st">"nprobe"</span>: <span class="dv">16</span>,  <span class="co"># number of clusters to search (higher = better recall, slower query)</span></span>
<span id="cb40-3"><a href="#cb40-3" aria-hidden="true" tabindex="-1"></a>}</span></code></pre></div></div>
<p><strong>nlist and nprobe tradeoffs:</strong></p>
<ul>
<li><code>nlist</code> = 1024, <code>nprobe</code> = 1: very fast, low recall</li>
<li><code>nlist</code> = 1024, <code>nprobe</code> = 64: slower, high recall</li>
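<li><code>nprobe</code> must be between 1 and <code>nlist</code>; at <code>nprobe</code> = <code>nlist</code> every cluster is scanned, which amounts to an exact (FLAT) search</li>
<li>Typical: <code>nprobe = nlist / 16</code> to <code>nlist / 8</code> (64 to 128 for <code>nlist</code> = 1024; the <code>nprobe</code> = 16 above sits at the faster, lower-recall end)</li>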
</ul>
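<p>The build-then-probe procedure can be sketched in a few dozen lines of NumPy. This is a toy model under stated assumptions, not Milvus’s implementation: <code>ToyIVFFlat</code>, <code>kmeans</code>, and <code>recall_at_k</code> are hypothetical helpers, k-means runs a fixed 10 Lloyd iterations, and the data are random 32-dim vectors. It measures recall@10 against an exact scan as <code>nprobe</code> grows.</p>

```python
import numpy as np

def sq_dists(a, b):
    """Pairwise squared L2 distances between rows of a and rows of b."""
    return (a ** 2).sum(1)[:, None] - 2 * a @ b.T + (b ** 2).sum(1)[None, :]

def kmeans(x, n_clusters, n_iter=10, seed=0):
    """Plain Lloyd's k-means; returns (centroids, assignments)."""
    rng = np.random.default_rng(seed)
    centroids = x[rng.choice(len(x), n_clusters, replace=False)].copy()
    for _ in range(n_iter):
        assign = sq_dists(x, centroids).argmin(1)
        for c in range(n_clusters):
            members = x[assign == c]
            if len(members):  # keep the old centroid if a cluster empties
                centroids[c] = members.mean(0)
    return centroids, sq_dists(x, centroids).argmin(1)

class ToyIVFFlat:
    """Toy IVF_FLAT: k-means coarse quantizer + inverted lists of raw vectors."""

    def __init__(self, nlist):
        self.nlist = nlist

    def build(self, data):
        self.data = data
        self.centroids, assign = kmeans(data, self.nlist)
        # Inverted lists: cluster id -> row ids of the vectors in that cell
        self.lists = [np.flatnonzero(assign == c) for c in range(self.nlist)]

    def search(self, q, k, nprobe):
        # 1) Find the nprobe centroids closest to the query
        probe = np.argsort(sq_dists(q[None, :], self.centroids)[0])[:nprobe]
        # 2) Exhaustively scan only the vectors stored in those cells
        cand = np.concatenate([self.lists[c] for c in probe])
        d = ((self.data[cand] - q) ** 2).sum(1)
        return cand[np.argsort(d)[:k]]

rng = np.random.default_rng(0)
data = rng.standard_normal((5000, 32)).astype(np.float32)
queries = rng.standard_normal((50, 32)).astype(np.float32)
index = ToyIVFFlat(nlist=64)
index.build(data)

def recall_at_k(k, nprobe):
    """Fraction of the exact top-k that the IVF search also returns."""
    hits = 0
    for q in queries:
        exact = np.argsort(((data - q) ** 2).sum(1))[:k]
        hits += len(set(exact.tolist()) & set(index.search(q, k, nprobe).tolist()))
    return hits / (k * len(queries))

for nprobe in (1, 8, 64):  # nprobe = nlist scans every cell, i.e. exact search
    print(f"nprobe={nprobe:>2}  recall@10={recall_at_k(10, nprobe):.2f}")
```

<p>Since the probed cells for a larger <code>nprobe</code> always include those for a smaller one, recall can only stay the same or rise as <code>nprobe</code> grows, while the number of vectors scanned (and hence latency) rises with it.</p>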
<p><strong>Best for:</strong> Medium datasets (1M–100M vectors), balanced recall/speed.</p>
<hr>
</section>
<section id="ivf_sq8-ivf-scalar-quantization" class="level3">
<h3 class="anchored" data-anchor-id="ivf_sq8-ivf-scalar-quantization" id="ivf_sq8-ivf-scalar-quantization">IVF_SQ8 (IVF + Scalar Quantization)</h3>
<p><strong>How it works:</strong> Same as IVF_FLAT, but vectors are compressed from 32-bit floats to 8-bit integers (scalar quantization). Reduces memory by ~4x.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb41"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb41-1"><a href="#cb41-1" aria-hidden="true" tabindex="-1"></a>index_params.add_index(</span>
<span id="cb41-2"><a href="#cb41-2" aria-hidden="true" tabindex="-1"></a>    field_name<span class="op">=</span><span class="st">"embedding"</span>,</span>
<span id="cb41-3"><a href="#cb41-3" aria-hidden="true" tabindex="-1"></a>    index_type<span class="op">=</span><span class="st">"IVF_SQ8"</span>,</span>
<span id="cb41-4"><a href="#cb41-4" aria-hidden="true" tabindex="-1"></a>    metric_type<span class="op">=</span><span class="st">"L2"</span>,</span>
<span id="cb41-5"><a href="#cb41-5" aria-hidden="true" tabindex="-1"></a>    params<span class="op">=</span>{<span class="st">"nlist"</span>: <span class="dv">1024</span>},</span>
<span id="cb41-6"><a href="#cb41-6" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
<p><strong>Memory reduction:</strong> A 512-dim float32 vector takes 2048 bytes. IVF_SQ8 compresses it to 512 bytes.</p>
<p><strong>Recall impact:</strong> Slight degradation vs.&nbsp;IVF_FLAT (typically 0.5–2% lower recall@10).</p>
<p><strong>Best for:</strong> When you have memory constraints but can tolerate a small accuracy drop.</p>
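<p>The 2048-to-512-byte arithmetic above can be checked with a small scalar-quantization sketch. It assumes the simplest SQ8 scheme, a per-dimension min/max range mapped linearly onto the integers 0–255; Milvus’s exact training procedure may differ, and the helper names are hypothetical.</p>

```python
import numpy as np

def train_sq8(x):
    """Learn a per-dimension [min, max] range that is mapped onto codes 0..255."""
    lo = x.min(axis=0)
    hi = x.max(axis=0)
    # Width of one quantization step per dimension (guard against constant dims)
    scale = np.where(hi > lo, (hi - lo) / 255.0, 1.0).astype(np.float32)
    return lo, scale

def encode_sq8(x, lo, scale):
    """float32 -> uint8: one byte per dimension instead of four."""
    return np.clip(np.round((x - lo) / scale), 0, 255).astype(np.uint8)

def decode_sq8(codes, lo, scale):
    """uint8 -> approximate float32 (the nearest grid point to the original)."""
    return codes.astype(np.float32) * scale + lo

rng = np.random.default_rng(0)
x = rng.standard_normal((1000, 512)).astype(np.float32)
lo, scale = train_sq8(x)
codes = encode_sq8(x, lo, scale)
recon = decode_sq8(codes, lo, scale)

print(x[0].nbytes, codes[0].nbytes)  # 2048 512
print(f"max abs error: {np.abs(x - recon).max():.4f}")
```

<p>Each value is off by at most half a quantization step, which is why the recall loss is small: distances between reconstructed vectors stay close to the originals.</p>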
<hr>
</section>
<section id="ivf_pq-ivf-product-quantization" class="level3">
<h3 class="anchored" data-anchor-id="ivf_pq-ivf-product-quantization" id="ivf_pq-ivf-product-quantization">IVF_PQ (IVF + Product Quantization)</h3>
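<p><strong>How it works:</strong> Splits each vector into <code>m</code> equal-length sub-vectors and learns a separate codebook of 2<sup>nbits</sup> centroids for each sub-space. Each sub-vector is then stored as the <code>nbits</code>-bit index of its nearest centroid, so a vector costs <code>m * nbits / 8</code> bytes. The compression is extreme: with <code>nbits = 8</code>, a 512-dim float32 vector shrinks to just 8 or 16 bytes (<code>m = 8</code> or <code>16</code>).</p>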
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb42"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb42-1"><a href="#cb42-1" aria-hidden="true" tabindex="-1"></a>index_params.add_index(</span>
<span id="cb42-2"><a href="#cb42-2" aria-hidden="true" tabindex="-1"></a>    field_name<span class="op">=</span><span class="st">"embedding"</span>,</span>
<span id="cb42-3"><a href="#cb42-3" aria-hidden="true" tabindex="-1"></a>    index_type<span class="op">=</span><span class="st">"IVF_PQ"</span>,</span>
<span id="cb42-4"><a href="#cb42-4" aria-hidden="true" tabindex="-1"></a>    metric_type<span class="op">=</span><span class="st">"L2"</span>,</span>
<span id="cb42-5"><a href="#cb42-5" aria-hidden="true" tabindex="-1"></a>    params<span class="op">=</span>{</span>
<span id="cb42-6"><a href="#cb42-6" aria-hidden="true" tabindex="-1"></a>        <span class="st">"nlist"</span>: <span class="dv">1024</span>,</span>
<span id="cb42-7"><a href="#cb42-7" aria-hidden="true" tabindex="-1"></a>        <span class="st">"m"</span>: <span class="dv">8</span>,       <span class="co"># number of sub-quantizers (must divide evenly into dim)</span></span>
<span id="cb42-8"><a href="#cb42-8" aria-hidden="true" tabindex="-1"></a>        <span class="st">"nbits"</span>: <span class="dv">8</span>,   <span class="co"># bits per sub-quantizer code (typically 8)</span></span>
<span id="cb42-9"><a href="#cb42-9" aria-hidden="true" tabindex="-1"></a>    }</span>
<span id="cb42-10"><a href="#cb42-10" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
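<p><strong>Memory reduction:</strong> With the parameters above (<code>m = 8</code>, <code>nbits = 8</code>), each 512-dim vector drops from 2048 bytes to 8 bytes, a 256x reduction before codebook and index overhead. A larger <code>m</code> keeps more detail at a lower ratio (e.g. <code>m = 64</code> gives 64 bytes, 32x).</p>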
<p><strong>Recall impact:</strong> Significant — typically 5–15% lower recall@10 than FLAT.</p>
<p><strong>Best for:</strong> Billion-scale datasets where memory is severely constrained.</p>
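<p>A compact product-quantization sketch makes the byte count and the accuracy cost visible. It assumes plain k-means codebooks (10 Lloyd iterations per sub-space) trained on random data; Milvus’s training details may differ, and <code>train_pq</code>, <code>encode_pq</code>, and <code>decode_pq</code> are hypothetical helper names. The parameters match the example above: 512 dimensions, <code>m = 8</code>, <code>nbits = 8</code>.</p>

```python
import numpy as np

def kmeans(x, k, n_iter=10, seed=0):
    """Plain Lloyd's k-means returning the k centroids."""
    rng = np.random.default_rng(seed)
    centroids = x[rng.choice(len(x), k, replace=False)].copy()
    for _ in range(n_iter):
        d2 = (x ** 2).sum(1)[:, None] - 2 * x @ centroids.T + (centroids ** 2).sum(1)[None, :]
        assign = d2.argmin(1)
        for c in range(k):
            members = x[assign == c]
            if len(members):  # keep the old centroid if a cluster empties
                centroids[c] = members.mean(0)
    return centroids

def train_pq(x, m, nbits):
    """Learn one codebook of 2**nbits centroids per sub-space."""
    dim = x.shape[1]
    assert dim % m == 0, "dim must be divisible by m"
    assert nbits <= 8, "this sketch stores codes as uint8"
    sub = dim // m
    return [kmeans(np.ascontiguousarray(x[:, i * sub:(i + 1) * sub]), 2 ** nbits, seed=i)
            for i in range(m)]

def encode_pq(x, codebooks):
    """Replace each sub-vector by the index of its nearest centroid."""
    m = len(codebooks)
    sub = x.shape[1] // m
    codes = np.empty((len(x), m), dtype=np.uint8)
    for i, cb in enumerate(codebooks):
        xs = x[:, i * sub:(i + 1) * sub]
        d2 = (xs ** 2).sum(1)[:, None] - 2 * xs @ cb.T + (cb ** 2).sum(1)[None, :]
        codes[:, i] = d2.argmin(1)
    return codes

def decode_pq(codes, codebooks):
    """Concatenate the selected centroids to approximate the original vector."""
    return np.hstack([cb[codes[:, i]] for i, cb in enumerate(codebooks)])

rng = np.random.default_rng(0)
x = rng.standard_normal((2000, 512)).astype(np.float32)
codebooks = train_pq(x, m=8, nbits=8)
codes = encode_pq(x, codebooks)

print(x[0].nbytes, codes[0].nbytes)  # 2048 8
rel_err = np.linalg.norm(x - decode_pq(codes, codebooks)) / np.linalg.norm(x)
print(f"relative reconstruction error: {rel_err:.2f}")
```

<p>The codebooks (here 8 × 256 centroids of 64 floats) are shared by every vector, so their cost is amortized over the whole collection. The large reconstruction error on this unstructured random data illustrates why PQ trades noticeably more recall than SQ8; real embeddings usually compress better.</p>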
<hr>
</section>
<section id="hnsw-hierarchical-navigable-small-world" class="level3">
<h3 class="anchored" data-anchor-id="hnsw-hierarchical-navigable-small-world" id="hnsw-hierarchical-navigable-small-world">HNSW (Hierarchical Navigable Small World)</h3>
<p><strong>How it works:</strong> Builds a multi-layer graph where nodes are vectors and edges connect nearby vectors. The top layers are “highways” (sparse, long-range connections) and the bottom layer is a dense neighborhood graph. Search navigates from the top layer down, greedily following the nearest neighbor at each hop.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb43"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb43-1"><a href="#cb43-1" aria-hidden="true" tabindex="-1"></a>index_params.add_index(</span>
<span id="cb43-2"><a href="#cb43-2" aria-hidden="true" tabindex="-1"></a>    field_name<span class="op">=</span><span class="st">"embedding"</span>,</span>
<span id="cb43-3"><a href="#cb43-3" aria-hidden="true" tabindex="-1"></a>    index_type<span class="op">=</span><span class="st">"HNSW"</span>,</span>
<span id="cb43-4"><a href="#cb43-4" aria-hidden="true" tabindex="-1"></a>    metric_type<span class="op">=</span><span class="st">"COSINE"</span>,</span>
<span id="cb43-5"><a href="#cb43-5" aria-hidden="true" tabindex="-1"></a>    params<span class="op">=</span>{</span>
<span id="cb43-6"><a href="#cb43-6" aria-hidden="true" tabindex="-1"></a>        <span class="st">"M"</span>: <span class="dv">16</span>,              <span class="co"># max connections per node per layer</span></span>
<span id="cb43-7"><a href="#cb43-7" aria-hidden="true" tabindex="-1"></a>                              <span class="co"># Range: 4–64. Higher = better recall, more memory, slower build</span></span>
<span id="cb43-8"><a href="#cb43-8" aria-hidden="true" tabindex="-1"></a>        <span class="st">"efConstruction"</span>: <span class="dv">200</span>, <span class="co"># search width during index construction</span></span>
<span id="cb43-9"><a href="#cb43-9" aria-hidden="true" tabindex="-1"></a>                              <span class="co"># Range: 8–512. Higher = better quality, slower build</span></span>
<span id="cb43-10"><a href="#cb43-10" aria-hidden="true" tabindex="-1"></a>    }</span>
<span id="cb43-11"><a href="#cb43-11" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
<p><strong>Search parameters:</strong></p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb44"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb44-1"><a href="#cb44-1" aria-hidden="true" tabindex="-1"></a>search_params <span class="op">=</span> {</span>
<span id="cb44-2"><a href="#cb44-2" aria-hidden="true" tabindex="-1"></a>    <span class="st">"ef"</span>: <span class="dv">100</span>,  <span class="co"># search-time expansion factor (must be &gt;= limit/top_k)</span></span>
<span id="cb44-3"><a href="#cb44-3" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Higher = better recall, slower queries</span></span>
<span id="cb44-4"><a href="#cb44-4" aria-hidden="true" tabindex="-1"></a>}</span></code></pre></div></div>
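<p><strong>Pros:</strong> Among the fastest CPU indexes at high recall; no cluster-boundary misses as in IVF (where a true neighbor sitting in an unprobed cluster is never seen); recall degrades smoothly as <code>ef</code> is lowered.</p>
<p><strong>Cons:</strong> Higher memory footprint (full vectors plus graph links); longer index build time.</p>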
<p><strong>Best for:</strong> Most production computer vision use cases — the best default choice.</p>
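<p>A full HNSW implementation is too long to show here, but the role of <code>ef</code> can be illustrated with a simplified, single-layer sketch. Assumptions: the graph is a plain k-nearest-neighbor graph built by brute force (real HNSW builds its layers incrementally with <code>efConstruction</code> and a neighbor-selection heuristic), and the search always starts from node 0 instead of an entry point found by descending the upper layers. The search loop itself follows the best-first, <code>ef</code>-bounded procedure HNSW runs on each layer; <code>build_knn_graph</code> and <code>beam_search</code> are hypothetical names.</p>

```python
import heapq
import numpy as np

def build_knn_graph(x, degree):
    """Brute-force k-NN graph: each node links to its `degree` nearest neighbors."""
    d2 = (x ** 2).sum(1)[:, None] - 2 * x @ x.T + (x ** 2).sum(1)[None, :]
    np.fill_diagonal(d2, np.inf)  # no self-loops
    return np.argsort(d2, axis=1)[:, :degree]

def beam_search(x, graph, query, k, ef, entry=0):
    """Best-first graph search keeping the `ef` closest nodes seen so far."""
    def dist(i):
        return float(((x[i] - query) ** 2).sum())

    visited = {entry}
    d0 = dist(entry)
    candidates = [(d0, entry)]   # min-heap: closest unexpanded node first
    best = [(-d0, entry)]        # max-heap (negated): the ef best results so far
    while candidates:
        d, node = heapq.heappop(candidates)
        if d > -best[0][0]:      # nearest candidate is worse than our worst result
            break
        for nb in graph[node]:
            nb = int(nb)
            if nb in visited:
                continue
            visited.add(nb)
            dn = dist(nb)
            if len(best) < ef or dn < -best[0][0]:
                heapq.heappush(candidates, (dn, nb))
                heapq.heappush(best, (-dn, nb))
                if len(best) > ef:
                    heapq.heappop(best)  # evict the current worst result
    # Return the k closest ids, nearest first
    return [i for _, i in sorted((-nd, i) for nd, i in best)][:k]

rng = np.random.default_rng(0)
x = rng.standard_normal((2000, 16)).astype(np.float32)
graph = build_knn_graph(x, degree=16)
q = rng.standard_normal(16).astype(np.float32)
exact = set(np.argsort(((x - q) ** 2).sum(1))[:10].tolist())

for ef in (10, 50, 200):  # larger ef = wider beam = more nodes visited
    found = beam_search(x, graph, q, k=10, ef=ef)
    print(f"ef={ef:>3}  recall@10={len(exact & set(found)) / 10:.1f}")
```

<p>This also shows why <code>ef</code> must be at least the requested <code>limit</code>: the beam holds at most <code>ef</code> results, so a smaller beam cannot return <code>limit</code> neighbors. A plain k-NN graph like this one can leave some nodes unreachable, which is one of the problems HNSW’s construction heuristic and upper layers are designed to avoid.</p>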
<p><strong>Typical <code>M</code> and <code>efConstruction</code> values:</strong></p>
<table class="caption-top table">
<thead>
<tr class="header">
<th>Use Case</th>
<th>M</th>
<th>efConstruction</th>
<th>Notes</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>High-speed, medium recall</td>
<td>8</td>
<td>100</td>
<td>Fastest queries</td>
</tr>
<tr class="even">
<td>Balanced (recommended)</td>
<td>16</td>
<td>200</td>
<td>Best starting point</td>
</tr>
<tr class="odd">
<td>High recall</td>
<td>32</td>
<td>400</td>
<td>Better accuracy, ~2x graph-link memory</td>
</tr>
<tr class="even">
<td>Max recall</td>
<td>64</td>
<td>512</td>
<td>Use only if recall is critical</td>
</tr>
</tbody>
</table>
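<p>How much memory <code>M</code> actually costs can be estimated with a back-of-the-envelope formula. The sketch below assumes an hnswlib-style layout (up to 2·M neighbor ids per node on layer 0, up to M on each upper layer, 4-byte ids and floats, a node reaching layer l ≥ 1 with probability M<sup>−l</sup>) and ignores allocator and metadata overhead, so treat its output as a rough estimate, not Milvus’s exact footprint. <code>hnsw_memory_estimate</code> is a hypothetical helper.</p>

```python
def hnsw_memory_estimate(n, dim, M, bytes_per_float=4, bytes_per_id=4):
    """Rough HNSW memory in bytes, split into (raw vectors, graph links).

    Assumes an hnswlib-style layout: up to 2*M neighbor ids per node on
    layer 0, up to M ids on each upper layer, and a node reaching layer
    l >= 1 with probability M**-l. Ignores allocator and metadata overhead.
    """
    vectors = n * dim * bytes_per_float
    layer0 = n * 2 * M * bytes_per_id
    # Expected upper-layer slots per node: sum_{l>=1} M * M**-l = M / (M - 1)
    upper = n * M / (M - 1) * bytes_per_id
    return vectors, layer0 + upper

for M in (16, 32):
    vec, links = hnsw_memory_estimate(n=1_000_000, dim=512, M=M)
    print(f"M={M}: vectors {vec / 1e9:.2f} GB, links {links / 1e9:.2f} GB")
```

<p>Under these assumptions, 1M 512-dim vectors need about 2.05 GB for the raw vectors, while the links take about 0.13 GB at <code>M = 16</code> and 0.26 GB at <code>M = 32</code>. Doubling <code>M</code> roughly doubles the graph, but for high-dimensional embeddings the vectors themselves still dominate total memory.</p>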
<hr>
</section>
<section id="scann" class="level3">
<h3 class="anchored" data-anchor-id="scann" id="scann">SCANN</h3>
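<p>Milvus’s SCANN index is based on Google’s ScaNN algorithm. Like IVF_PQ, it clusters vectors (<code>nlist</code>) and compresses them with product quantization, but it uses SIMD-friendly quantized distance computation to scan candidates faster. Setting <code>with_raw_data=True</code> keeps the original vectors alongside the compressed codes so candidates can be re-ranked with exact distances, at the cost of extra memory. Its recall/speed tradeoff is often competitive with HNSW:</p>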
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb45"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb45-1"><a href="#cb45-1" aria-hidden="true" tabindex="-1"></a>index_params.add_index(</span>
<span id="cb45-2"><a href="#cb45-2" aria-hidden="true" tabindex="-1"></a>    field_name<span class="op">=</span><span class="st">"embedding"</span>,</span>
<span id="cb45-3"><a href="#cb45-3" aria-hidden="true" tabindex="-1"></a>    index_type<span class="op">=</span><span class="st">"SCANN"</span>,</span>
<span id="cb45-4"><a href="#cb45-4" aria-hidden="true" tabindex="-1"></a>    metric_type<span class="op">=</span><span class="st">"COSINE"</span>,</span>
<span id="cb45-5"><a href="#cb45-5" aria-hidden="true" tabindex="-1"></a>    params<span class="op">=</span>{</span>
<span id="cb45-6"><a href="#cb45-6" aria-hidden="true" tabindex="-1"></a>        <span class="st">"nlist"</span>: <span class="dv">1024</span>,</span>
<span id="cb45-7"><a href="#cb45-7" aria-hidden="true" tabindex="-1"></a>        <span class="st">"with_raw_data"</span>: <span class="va">True</span>,</span>
<span id="cb45-8"><a href="#cb45-8" aria-hidden="true" tabindex="-1"></a>    }</span>
<span id="cb45-9"><a href="#cb45-9" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
<hr>
</section>
<section id="diskann-disk-based-ann" class="level3">
<h3 class="anchored" data-anchor-id="diskann-disk-based-ann" id="diskann-disk-based-ann">DiskANN (Disk-Based ANN)</h3>
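<p><strong>How it works:</strong> Builds a graph index that lives mostly on disk (SSD), keeping only compressed (product-quantized) vectors in RAM to guide the graph traversal; full-precision vectors and neighbor lists are read from disk on demand. This makes it possible to search datasets that are too large to fit in RAM.</p>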
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb46"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb46-1"><a href="#cb46-1" aria-hidden="true" tabindex="-1"></a>index_params.add_index(</span>
<span id="cb46-2"><a href="#cb46-2" aria-hidden="true" tabindex="-1"></a>    field_name<span class="op">=</span><span class="st">"embedding"</span>,</span>
<span id="cb46-3"><a href="#cb46-3" aria-hidden="true" tabindex="-1"></a>    index_type<span class="op">=</span><span class="st">"DISKANN"</span>,</span>
<span id="cb46-4"><a href="#cb46-4" aria-hidden="true" tabindex="-1"></a>    metric_type<span class="op">=</span><span class="st">"L2"</span>,</span>
<span id="cb46-5"><a href="#cb46-5" aria-hidden="true" tabindex="-1"></a>    params<span class="op">=</span>{},</span>
<span id="cb46-6"><a href="#cb46-6" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
<p><strong>Requirements:</strong> A fast NVMe SSD. Query latency is higher than with RAM-based indexes (typically 5–30 ms vs.&nbsp;1–5 ms) but far lower than a brute-force scan of the same data.</p>
<p><strong>Best for:</strong> Truly massive datasets (100M+ vectors) on a single node.</p>
<hr>
</section>
<section id="gpu-indexes" class="level3">
<h3 class="anchored" data-anchor-id="gpu-indexes" id="gpu-indexes">GPU Indexes</h3>
<p>Available when your Milvus deployment has GPU-enabled nodes:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb47"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb47-1"><a href="#cb47-1" aria-hidden="true" tabindex="-1"></a><span class="co"># GPU-accelerated IVF_FLAT</span></span>
<span id="cb47-2"><a href="#cb47-2" aria-hidden="true" tabindex="-1"></a>index_params.add_index(</span>
<span id="cb47-3"><a href="#cb47-3" aria-hidden="true" tabindex="-1"></a>    field_name<span class="op">=</span><span class="st">"embedding"</span>,</span>
<span id="cb47-4"><a href="#cb47-4" aria-hidden="true" tabindex="-1"></a>    index_type<span class="op">=</span><span class="st">"GPU_IVF_FLAT"</span>,</span>
<span id="cb47-5"><a href="#cb47-5" aria-hidden="true" tabindex="-1"></a>    metric_type<span class="op">=</span><span class="st">"L2"</span>,</span>
<span id="cb47-6"><a href="#cb47-6" aria-hidden="true" tabindex="-1"></a>    params<span class="op">=</span>{<span class="st">"nlist"</span>: <span class="dv">1024</span>},</span>
<span id="cb47-7"><a href="#cb47-7" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb47-8"><a href="#cb47-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb47-9"><a href="#cb47-9" aria-hidden="true" tabindex="-1"></a><span class="co"># GPU-accelerated CAGRA (graph-based, state of the art for GPU)</span></span>
<span id="cb47-10"><a href="#cb47-10" aria-hidden="true" tabindex="-1"></a>index_params.add_index(</span>
<span id="cb47-11"><a href="#cb47-11" aria-hidden="true" tabindex="-1"></a>    field_name<span class="op">=</span><span class="st">"embedding"</span>,</span>
<span id="cb47-12"><a href="#cb47-12" aria-hidden="true" tabindex="-1"></a>    index_type<span class="op">=</span><span class="st">"GPU_CAGRA"</span>,</span>
<span id="cb47-13"><a href="#cb47-13" aria-hidden="true" tabindex="-1"></a>    metric_type<span class="op">=</span><span class="st">"L2"</span>,</span>
<span id="cb47-14"><a href="#cb47-14" aria-hidden="true" tabindex="-1"></a>    params<span class="op">=</span>{</span>
<span id="cb47-15"><a href="#cb47-15" aria-hidden="true" tabindex="-1"></a>        <span class="st">"intermediate_graph_degree"</span>: <span class="dv">64</span>,</span>
<span id="cb47-16"><a href="#cb47-16" aria-hidden="true" tabindex="-1"></a>        <span class="st">"graph_degree"</span>: <span class="dv">32</span>,</span>
<span id="cb47-17"><a href="#cb47-17" aria-hidden="true" tabindex="-1"></a>    }</span>
<span id="cb47-18"><a href="#cb47-18" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
<p><strong>Speedups:</strong> Depending on hardware, dataset, and batch size, GPU indexes can build 10–100x faster than CPU indexes and answer queries 5–20x faster; query gains are largest when many queries are batched together.</p>
<hr>
</section>
<section id="index-selection-summary" class="level3">
<h3 class="anchored" data-anchor-id="index-selection-summary" id="index-selection-summary">Index Selection Summary</h3>
<pre><code>Small dataset (&lt; 500K)?        → FLAT
Medium dataset, low memory?    → IVF_SQ8 or IVF_PQ
Medium dataset, good memory?   → IVF_FLAT or HNSW
Large dataset, best recall?    → HNSW (M=16, efConstruction=200)
Huge dataset, memory limited?  → DiskANN
GPU available?                 → GPU_CAGRA or GPU_IVF_FLAT</code></pre>
<hr>
</section>
</section>
<section id="sec-querying" class="level2">
<h2 class="anchored" data-anchor-id="sec-querying" id="sec-querying">10. Querying and Searching</h2>
<section id="vector-similarity-search" class="level3">
<h3 class="anchored" data-anchor-id="vector-similarity-search" id="vector-similarity-search">Vector Similarity Search</h3>
<p>The core operation in Milvus is finding the <code>k</code> vectors most similar to a query vector:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb49"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb49-1"><a href="#cb49-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb49-2"><a href="#cb49-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb49-3"><a href="#cb49-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Simulate a query embedding (in practice, this comes from embedding your query image)</span></span>
<span id="cb49-4"><a href="#cb49-4" aria-hidden="true" tabindex="-1"></a>query_vector <span class="op">=</span> np.random.randn(<span class="dv">512</span>).astype(np.float32)</span>
<span id="cb49-5"><a href="#cb49-5" aria-hidden="true" tabindex="-1"></a>query_vector <span class="op">=</span> (query_vector <span class="op">/</span> np.linalg.norm(query_vector)).tolist()</span>
<span id="cb49-6"><a href="#cb49-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb49-7"><a href="#cb49-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Perform the search</span></span>
<span id="cb49-8"><a href="#cb49-8" aria-hidden="true" tabindex="-1"></a>results <span class="op">=</span> client.search(</span>
<span id="cb49-9"><a href="#cb49-9" aria-hidden="true" tabindex="-1"></a>    collection_name<span class="op">=</span><span class="st">"image_embeddings"</span>,</span>
<span id="cb49-10"><a href="#cb49-10" aria-hidden="true" tabindex="-1"></a>    data<span class="op">=</span>[query_vector],          <span class="co"># list of query vectors (supports batch queries)</span></span>
<span id="cb49-11"><a href="#cb49-11" aria-hidden="true" tabindex="-1"></a>    limit<span class="op">=</span><span class="dv">10</span>,                     <span class="co"># return top 10 most similar</span></span>
<span id="cb49-12"><a href="#cb49-12" aria-hidden="true" tabindex="-1"></a>    output_fields<span class="op">=</span>[<span class="st">"image_path"</span>, <span class="st">"label"</span>, <span class="st">"confidence"</span>],</span>
<span id="cb49-13"><a href="#cb49-13" aria-hidden="true" tabindex="-1"></a>    search_params<span class="op">=</span>{<span class="st">"params"</span>: {<span class="st">"ef"</span>: <span class="dv">100</span>}},  <span class="co"># HNSW search width, should be &gt;= limit (omit for FLAT)</span></span>
<span id="cb49-14"><a href="#cb49-14" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb49-15"><a href="#cb49-15" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb49-16"><a href="#cb49-16" aria-hidden="true" tabindex="-1"></a><span class="co"># Results is a list of lists (one inner list per query vector)</span></span>
<span id="cb49-17"><a href="#cb49-17" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> hit <span class="kw">in</span> results[<span class="dv">0</span>]:</span>
<span id="cb49-18"><a href="#cb49-18" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"ID: </span><span class="sc">{</span>hit[<span class="st">'id'</span>]<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb49-19"><a href="#cb49-19" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Distance: </span><span class="sc">{</span>hit[<span class="st">'distance'</span>]<span class="sc">:.4f}</span><span class="ss">"</span>)</span>
<span id="cb49-20"><a href="#cb49-20" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Image: </span><span class="sc">{</span>hit[<span class="st">'entity'</span>][<span class="st">'image_path'</span>]<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb49-21"><a href="#cb49-21" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Label: </span><span class="sc">{</span>hit[<span class="st">'entity'</span>][<span class="st">'label'</span>]<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb49-22"><a href="#cb49-22" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>()</span></code></pre></div></div>
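<p>It can help to see what a search computes before an index approximates it. The sketch below is a plain-Python exhaustive search, conceptually what a FLAT index does, over a hypothetical in-memory <code>dict</code> of vectors. It also shows why <code>hit['distance']</code> must be read per metric: for <code>L2</code> a smaller value means more similar, while for <code>COSINE</code> a larger value does. For unit-length vectors the squared L2 distance equals <code>2 - 2 * cos(a, b)</code>, so both metrics produce the same ranking, which is why normalizing embeddings before insertion is a common habit.</p>

```python
import math
from typing import Dict, List, Sequence


def squared_l2(a: Sequence[float], b: Sequence[float]) -> float:
    # Squared Euclidean distance; skipping the square root keeps the ranking unchanged.
    return sum((x - y) ** 2 for x, y in zip(a, b))


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def exact_search(
    query: Sequence[float],
    vectors: Dict[int, Sequence[float]],
    limit: int,
    metric: str = "L2",
) -> List[dict]:
    """Score every stored vector and keep the best `limit`, shaped like Milvus hits."""
    if metric == "L2":
        scored = [(pk, squared_l2(query, v)) for pk, v in vectors.items()]
        scored.sort(key=lambda item: item[1])  # smaller = more similar
    elif metric == "COSINE":
        scored = [(pk, cosine(query, v)) for pk, v in vectors.items()]
        scored.sort(key=lambda item: item[1], reverse=True)  # larger = more similar
    else:
        raise ValueError(f"unsupported metric: {metric}")
    return [{"id": pk, "distance": score} for pk, score in scored[:limit]]


# Toy unit-length vectors (made-up data, not from a real collection).
stored = {1: [1.0, 0.0], 2: [0.6, 0.8], 3: [0.0, 1.0]}
query = [0.8, 0.6]

print([hit["id"] for hit in exact_search(query, stored, limit=3, metric="L2")])      # [2, 1, 3]
print([hit["id"] for hit in exact_search(query, stored, limit=3, metric="COSINE")])  # [2, 1, 3]
```

<p>An ANN index such as HNSW trades this exhaustive scan for a much smaller set of comparisons, which is why its results can occasionally differ from the exact ranking.</p>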
</section>
<section id="batch-queries" class="level3">
<h3 class="anchored" data-anchor-id="batch-queries" id="batch-queries">Batch Queries</h3>
<p>Pass several query vectors in a single call. This avoids a separate request per query and is usually much faster than calling <code>client.search</code> in a loop:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb50"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb50-1"><a href="#cb50-1" aria-hidden="true" tabindex="-1"></a>query_vectors <span class="op">=</span> [</span>
<span id="cb50-2"><a href="#cb50-2" aria-hidden="true" tabindex="-1"></a>    np.random.randn(<span class="dv">512</span>).astype(np.float32).tolist()</span>
<span id="cb50-3"><a href="#cb50-3" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">5</span>)</span>
<span id="cb50-4"><a href="#cb50-4" aria-hidden="true" tabindex="-1"></a>]</span>
<span id="cb50-5"><a href="#cb50-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb50-6"><a href="#cb50-6" aria-hidden="true" tabindex="-1"></a>results <span class="op">=</span> client.search(</span>
<span id="cb50-7"><a href="#cb50-7" aria-hidden="true" tabindex="-1"></a>    collection_name<span class="op">=</span><span class="st">"image_embeddings"</span>,</span>
<span id="cb50-8"><a href="#cb50-8" aria-hidden="true" tabindex="-1"></a>    data<span class="op">=</span>query_vectors,</span>
<span id="cb50-9"><a href="#cb50-9" aria-hidden="true" tabindex="-1"></a>    limit<span class="op">=</span><span class="dv">10</span>,</span>
<span id="cb50-10"><a href="#cb50-10" aria-hidden="true" tabindex="-1"></a>    output_fields<span class="op">=</span>[<span class="st">"image_path"</span>, <span class="st">"label"</span>],</span>
<span id="cb50-11"><a href="#cb50-11" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb50-12"><a href="#cb50-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb50-13"><a href="#cb50-13" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> query_idx, query_results <span class="kw">in</span> <span class="bu">enumerate</span>(results):</span>
<span id="cb50-14"><a href="#cb50-14" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Query </span><span class="sc">{</span>query_idx<span class="sc">}</span><span class="ss"> top results:"</span>)</span>
<span id="cb50-15"><a href="#cb50-15" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> hit <span class="kw">in</span> query_results:</span>
<span id="cb50-16"><a href="#cb50-16" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"  </span><span class="sc">{</span>hit[<span class="st">'entity'</span>][<span class="st">'image_path'</span>]<span class="sc">}</span><span class="ss"> (distance: </span><span class="sc">{</span>hit[<span class="st">'distance'</span>]<span class="sc">:.4f}</span><span class="ss">)"</span>)</span></code></pre></div></div>
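<p>Because a batch search returns one list of hits per query vector, downstream code often wants flat rows instead. The sketch below, a hypothetical helper named <code>flatten_batch_results</code>, relies only on the access pattern already used above (<code>hit['id']</code>, <code>hit['distance']</code>, <code>hit['entity'][field]</code>). It is exercised on a hand-written mock with that shape rather than a live result.</p>

```python
from typing import Any, Dict, List, Sequence


def flatten_batch_results(
    results: Sequence[Sequence[Dict[str, Any]]],
    fields: Sequence[str],
) -> List[Dict[str, Any]]:
    """Convert per-query hit lists into one row per (query, hit) pair."""
    rows = []
    for query_idx, hits in enumerate(results):
        # Hits within each inner list are ordered best-first, so rank follows position.
        for rank, hit in enumerate(hits, start=1):
            row = {"query_idx": query_idx, "rank": rank,
                   "id": hit["id"], "distance": hit["distance"]}
            for name in fields:
                row[name] = hit["entity"][name]
            rows.append(row)
    return rows


# Hand-written mock shaped like a two-query search result (not real output).
mock_results = [
    [{"id": 7, "distance": 0.91, "entity": {"image_path": "a.jpg"}},
     {"id": 3, "distance": 0.85, "entity": {"image_path": "b.jpg"}}],
    [{"id": 5, "distance": 0.77, "entity": {"image_path": "c.jpg"}}],
]
for row in flatten_batch_results(mock_results, fields=["image_path"]):
    print(row)
```

<p>Rows in this form can be written straight to a CSV file or loaded into a DataFrame for analysis.</p>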
</section>
<section id="filtered-vector-search" class="level3">
<h3 class="anchored" data-anchor-id="filtered-vector-search" id="filtered-vector-search">Filtered Vector Search</h3>
<p>Combine vector similarity with scalar attribute filtering by passing a boolean expression as <code>filter</code>:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb51"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb51-1"><a href="#cb51-1" aria-hidden="true" tabindex="-1"></a>results <span class="op">=</span> client.search(</span>
<span id="cb51-2"><a href="#cb51-2" aria-hidden="true" tabindex="-1"></a>    collection_name<span class="op">=</span><span class="st">"image_embeddings"</span>,</span>
<span id="cb51-3"><a href="#cb51-3" aria-hidden="true" tabindex="-1"></a>    data<span class="op">=</span>[query_vector],</span>
<span id="cb51-4"><a href="#cb51-4" aria-hidden="true" tabindex="-1"></a>    limit<span class="op">=</span><span class="dv">10</span>,</span>
<span id="cb51-5"><a href="#cb51-5" aria-hidden="true" tabindex="-1"></a>    <span class="bu">filter</span><span class="op">=</span><span class="st">"label == 'dog'"</span>,</span>
<span id="cb51-6"><a href="#cb51-6" aria-hidden="true" tabindex="-1"></a>    output_fields<span class="op">=</span>[<span class="st">"image_path"</span>, <span class="st">"label"</span>, <span class="st">"confidence"</span>],</span>
<span id="cb51-7"><a href="#cb51-7" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
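<p>Where the filter is applied affects how many results come back. As we understand Milvus's design, the expression is evaluated as part of the search, so <code>limit</code> counts only entities that match; naively filtering an unfiltered top-k afterwards can return fewer hits, or none. The plain-Python sketch below contrasts the two strategies on made-up entities with a precomputed similarity score (<code>sim</code> is a hypothetical field, larger meaning more similar).</p>

```python
from typing import Callable, Dict, List

Entity = Dict[str, object]


def top_k_prefilter(entities: List[Entity], predicate: Callable[[Entity], bool], k: int) -> List[Entity]:
    # Restrict the candidate set first, then rank: returns up to k matching hits.
    candidates = [e for e in entities if predicate(e)]
    return sorted(candidates, key=lambda e: e["sim"], reverse=True)[:k]


def top_k_postfilter(entities: List[Entity], predicate: Callable[[Entity], bool], k: int) -> List[Entity]:
    # Rank everything, keep k, then filter: matches outside the top k are lost.
    top = sorted(entities, key=lambda e: e["sim"], reverse=True)[:k]
    return [e for e in top if predicate(e)]


def is_dog(entity: Entity) -> bool:
    return entity["label"] == "dog"


# Made-up entities: the two most similar images are cats.
entities = [
    {"id": 1, "label": "cat", "sim": 0.99},
    {"id": 2, "label": "cat", "sim": 0.95},
    {"id": 3, "label": "dog", "sim": 0.90},
    {"id": 4, "label": "dog", "sim": 0.80},
]

print([e["id"] for e in top_k_prefilter(entities, is_dog, k=2)])   # [3, 4]
print([e["id"] for e in top_k_postfilter(entities, is_dog, k=2)])  # []
```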
<p><strong>Filter expression syntax:</strong></p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb52"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb52-1"><a href="#cb52-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Comparison operators</span></span>
<span id="cb52-2"><a href="#cb52-2" aria-hidden="true" tabindex="-1"></a><span class="co">"confidence &gt; 0.9"</span></span>
<span id="cb52-3"><a href="#cb52-3" aria-hidden="true" tabindex="-1"></a><span class="co">"timestamp &gt;= 1700000000000"</span></span>
<span id="cb52-4"><a href="#cb52-4" aria-hidden="true" tabindex="-1"></a><span class="co">"label != 'cat'"</span></span>
<span id="cb52-5"><a href="#cb52-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb52-6"><a href="#cb52-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Logical operators</span></span>
<span id="cb52-7"><a href="#cb52-7" aria-hidden="true" tabindex="-1"></a><span class="co">"label == 'dog' AND confidence &gt; 0.8"</span></span>
<span id="cb52-8"><a href="#cb52-8" aria-hidden="true" tabindex="-1"></a><span class="co">"label in ['dog', 'cat']"</span></span>
<span id="cb52-9"><a href="#cb52-9" aria-hidden="true" tabindex="-1"></a><span class="co">"NOT (label in ['background', 'unknown'])"</span></span>
<span id="cb52-10"><a href="#cb52-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb52-11"><a href="#cb52-11" aria-hidden="true" tabindex="-1"></a><span class="co"># String operations</span></span>
<span id="cb52-12"><a href="#cb52-12" aria-hidden="true" tabindex="-1"></a><span class="co">"image_path like '/dataset/train/%'"</span></span>
<span id="cb52-13"><a href="#cb52-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb52-14"><a href="#cb52-14" aria-hidden="true" tabindex="-1"></a><span class="co"># Range</span></span>
<span id="cb52-15"><a href="#cb52-15" aria-hidden="true" tabindex="-1"></a><span class="co">"confidence &gt; 0.7 AND confidence &lt; 0.95"</span></span>
<span id="cb52-16"><a href="#cb52-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb52-17"><a href="#cb52-17" aria-hidden="true" tabindex="-1"></a><span class="co"># JSON field access</span></span>
<span id="cb52-18"><a href="#cb52-18" aria-hidden="true" tabindex="-1"></a><span class="co">"metadata['camera_id'] == 'cam_01'"</span></span></code></pre></div></div>
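<p>Assembling these strings with ad hoc f-strings gets error-prone once values come from user input. A hypothetical helper, <code>build_filter</code>, can produce the clauses shown above from Python values and join them with <code>AND</code>. To avoid depending on escape rules we have not verified, it rejects string values containing a single quote or backslash; note that a <code>%</code> inside <code>path_prefix</code> would still act as a wildcard. The field names (<code>label</code>, <code>confidence</code>, <code>image_path</code>) follow the examples in this section.</p>

```python
from typing import Iterable, Optional


def _quote(value: str) -> str:
    # Conservative quoting: refuse characters whose escaping rules we do not rely on.
    if "'" in value or "\\" in value:
        raise ValueError(f"unsupported character in filter value: {value!r}")
    return f"'{value}'"


def build_filter(
    label: Optional[str] = None,
    labels: Optional[Iterable[str]] = None,
    min_confidence: Optional[float] = None,
    path_prefix: Optional[str] = None,
) -> str:
    """Combine optional conditions into one filter expression string."""
    clauses = []
    if label is not None:
        clauses.append(f"label == {_quote(label)}")
    if labels is not None:
        clauses.append("label in [" + ", ".join(_quote(v) for v in labels) + "]")
    if min_confidence is not None:
        clauses.append(f"confidence > {float(min_confidence)}")
    if path_prefix is not None:
        clauses.append(f"image_path like {_quote(path_prefix + '%')}")
    # An empty string matches the default filter="" (no filtering).
    return " AND ".join(clauses)


print(build_filter(label="dog", min_confidence=0.8))
# label == 'dog' AND confidence > 0.8
```

<p>The result can be passed directly as <code>filter=build_filter(...)</code> to <code>client.search</code> or <code>client.query</code>.</p>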
</section>
<section id="scalar-query-no-vector-search" class="level3">
<h3 class="anchored" data-anchor-id="scalar-query-no-vector-search" id="scalar-query-no-vector-search">Scalar Query (No Vector Search)</h3>
<p>Retrieve entities by scalar attributes only:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb53"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb53-1"><a href="#cb53-1" aria-hidden="true" tabindex="-1"></a>results <span class="op">=</span> client.query(</span>
<span id="cb53-2"><a href="#cb53-2" aria-hidden="true" tabindex="-1"></a>    collection_name<span class="op">=</span><span class="st">"image_embeddings"</span>,</span>
<span id="cb53-3"><a href="#cb53-3" aria-hidden="true" tabindex="-1"></a>    <span class="bu">filter</span><span class="op">=</span><span class="st">"label == 'dog' AND confidence &gt; 0.9"</span>,</span>
<span id="cb53-4"><a href="#cb53-4" aria-hidden="true" tabindex="-1"></a>    output_fields<span class="op">=</span>[<span class="st">"id"</span>, <span class="st">"image_path"</span>, <span class="st">"label"</span>, <span class="st">"confidence"</span>],</span>
<span id="cb53-5"><a href="#cb53-5" aria-hidden="true" tabindex="-1"></a>    limit<span class="op">=</span><span class="dv">100</span>,</span>
<span id="cb53-6"><a href="#cb53-6" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
</section>
<section id="get-entity-by-id" class="level3">
<h3 class="anchored" data-anchor-id="get-entity-by-id" id="get-entity-by-id">Get Entity by ID</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb54"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb54-1"><a href="#cb54-1" aria-hidden="true" tabindex="-1"></a>entities <span class="op">=</span> client.get(</span>
<span id="cb54-2"><a href="#cb54-2" aria-hidden="true" tabindex="-1"></a>    collection_name<span class="op">=</span><span class="st">"image_embeddings"</span>,</span>
<span id="cb54-3"><a href="#cb54-3" aria-hidden="true" tabindex="-1"></a>    ids<span class="op">=</span>[<span class="dv">1</span>, <span class="dv">2</span>, <span class="dv">42</span>],</span>
<span id="cb54-4"><a href="#cb54-4" aria-hidden="true" tabindex="-1"></a>    output_fields<span class="op">=</span>[<span class="st">"image_path"</span>, <span class="st">"label"</span>],</span>
<span id="cb54-5"><a href="#cb54-5" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
<hr>
</section>
</section>
<section id="sec-imgsimilarity" class="level2">
<h2 class="anchored" data-anchor-id="sec-imgsimilarity" id="sec-imgsimilarity">11. Use Case 1 — Image Similarity Search</h2>
<p>Image similarity search is the foundational computer vision use case for Milvus. Given a query image, find the most visually similar images in a large dataset. Applications include reverse image search, product visual search, duplicate detection, and content-based image retrieval (CBIR).</p>
<section id="architecture" class="level3">
<h3 class="anchored" data-anchor-id="architecture" id="architecture">Architecture</h3>
<div class="cell" data-eval="true" data-layout-align="center">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">graph TD
    A["User uploads query image"]
    B["Embedding Model (ResNet, CLIP, DINOv2, etc.)"]
    C["query_vector (e.g. 512-dim float array)"]
    D["Milvus search (HNSW + COSINE)"]
    E["Top-K similar image IDs + distances + metadata"]
    F["Fetch thumbnails from storage by path"]
    G["Return results to user"]

    A --&gt; B
    B --&gt; C
    C --&gt; D
    D --&gt; E
    E --&gt; F
    F --&gt; G

    style A fill:#E8F4FD,stroke:#4A90D9
    style B fill:#FEF9E7,stroke:#E8A838
    style C fill:#EAF7EA,stroke:#5BA85A
    style D fill:#F4ECF7,stroke:#8B6BB1
    style E fill:#FDEDEC,stroke:#D95F5F
    style F fill:#EBF5FB,stroke:#2980B9
    style G fill:#EAFAF1,stroke:#27AE60
</pre>
</div>
<p></p></figure><p></p>
</div>
</div>
</div>
</section>
<section id="full-implementation" class="level3">
<h3 class="anchored" data-anchor-id="full-implementation" id="full-implementation">Full Implementation</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb55"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb55-1"><a href="#cb55-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> pymilvus <span class="im">import</span> MilvusClient, DataType</span>
<span id="cb55-2"><a href="#cb55-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb55-3"><a href="#cb55-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb55-4"><a href="#cb55-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb55-5"><a href="#cb55-5" aria-hidden="true" tabindex="-1"></a><span class="co"># ─── Configuration ────────────────────────────────────────────────────────────</span></span>
<span id="cb55-6"><a href="#cb55-6" aria-hidden="true" tabindex="-1"></a>COLLECTION_NAME <span class="op">=</span> <span class="st">"image_similarity"</span></span>
<span id="cb55-7"><a href="#cb55-7" aria-hidden="true" tabindex="-1"></a>EMBEDDING_DIM <span class="op">=</span> <span class="dv">512</span></span>
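<span id="cb55-8"><a href="#cb55-8" aria-hidden="true" tabindex="-1"></a>MILVUS_URI <span class="op">=</span> <span class="st">"./image_similarity.db"</span>  <span class="co"># Milvus Lite file (vector index is always FLAT); use e.g. "http://localhost:19530" for a server</span></span>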
<span id="cb55-9"><a href="#cb55-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb55-10"><a href="#cb55-10" aria-hidden="true" tabindex="-1"></a>client <span class="op">=</span> MilvusClient(MILVUS_URI)</span>
<span id="cb55-11"><a href="#cb55-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb55-12"><a href="#cb55-12" aria-hidden="true" tabindex="-1"></a><span class="co"># ─── Create Collection ────────────────────────────────────────────────────────</span></span>
<span id="cb55-13"><a href="#cb55-13" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> create_image_similarity_collection():</span>
<span id="cb55-14"><a href="#cb55-14" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> client.has_collection(COLLECTION_NAME):</span>
<span id="cb55-15"><a href="#cb55-15" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Collection '</span><span class="sc">{</span>COLLECTION_NAME<span class="sc">}</span><span class="ss">' already exists."</span>)</span>
<span id="cb55-16"><a href="#cb55-16" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span></span>
<span id="cb55-17"><a href="#cb55-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb55-18"><a href="#cb55-18" aria-hidden="true" tabindex="-1"></a>    schema <span class="op">=</span> client.create_schema(auto_id<span class="op">=</span><span class="va">True</span>, enable_dynamic_field<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb55-19"><a href="#cb55-19" aria-hidden="true" tabindex="-1"></a>    schema.add_field(<span class="st">"id"</span>, DataType.INT64, is_primary<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb55-20"><a href="#cb55-20" aria-hidden="true" tabindex="-1"></a>    schema.add_field(<span class="st">"embedding"</span>, DataType.FLOAT_VECTOR, dim<span class="op">=</span>EMBEDDING_DIM)</span>
<span id="cb55-21"><a href="#cb55-21" aria-hidden="true" tabindex="-1"></a>    schema.add_field(<span class="st">"image_path"</span>, DataType.VARCHAR, max_length<span class="op">=</span><span class="dv">1024</span>)</span>
<span id="cb55-22"><a href="#cb55-22" aria-hidden="true" tabindex="-1"></a>    schema.add_field(<span class="st">"category"</span>, DataType.VARCHAR, max_length<span class="op">=</span><span class="dv">128</span>)</span>
<span id="cb55-23"><a href="#cb55-23" aria-hidden="true" tabindex="-1"></a>    schema.add_field(<span class="st">"dataset_split"</span>, DataType.VARCHAR, max_length<span class="op">=</span><span class="dv">16</span>)</span>
<span id="cb55-24"><a href="#cb55-24" aria-hidden="true" tabindex="-1"></a>    schema.add_field(<span class="st">"width"</span>, DataType.INT32)</span>
<span id="cb55-25"><a href="#cb55-25" aria-hidden="true" tabindex="-1"></a>    schema.add_field(<span class="st">"height"</span>, DataType.INT32)</span>
<span id="cb55-26"><a href="#cb55-26" aria-hidden="true" tabindex="-1"></a>    schema.add_field(<span class="st">"file_size_bytes"</span>, DataType.INT64)</span>
<span id="cb55-27"><a href="#cb55-27" aria-hidden="true" tabindex="-1"></a>    schema.add_field(<span class="st">"inserted_at"</span>, DataType.INT64)</span>
<span id="cb55-28"><a href="#cb55-28" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb55-29"><a href="#cb55-29" aria-hidden="true" tabindex="-1"></a>    index_params <span class="op">=</span> client.prepare_index_params()</span>
<span id="cb55-30"><a href="#cb55-30" aria-hidden="true" tabindex="-1"></a>    index_params.add_index(</span>
<span id="cb55-31"><a href="#cb55-31" aria-hidden="true" tabindex="-1"></a>        field_name<span class="op">=</span><span class="st">"embedding"</span>,</span>
<span id="cb55-32"><a href="#cb55-32" aria-hidden="true" tabindex="-1"></a>        index_type<span class="op">=</span><span class="st">"HNSW"</span>,</span>
<span id="cb55-33"><a href="#cb55-33" aria-hidden="true" tabindex="-1"></a>        metric_type<span class="op">=</span><span class="st">"COSINE"</span>,</span>
<span id="cb55-34"><a href="#cb55-34" aria-hidden="true" tabindex="-1"></a>        params<span class="op">=</span>{<span class="st">"M"</span>: <span class="dv">16</span>, <span class="st">"efConstruction"</span>: <span class="dv">200</span>},</span>
<span id="cb55-35"><a href="#cb55-35" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb55-36"><a href="#cb55-36" aria-hidden="true" tabindex="-1"></a>    index_params.add_index(field_name<span class="op">=</span><span class="st">"category"</span>, index_type<span class="op">=</span><span class="st">"Trie"</span>)</span>
<span id="cb55-37"><a href="#cb55-37" aria-hidden="true" tabindex="-1"></a>    index_params.add_index(field_name<span class="op">=</span><span class="st">"dataset_split"</span>, index_type<span class="op">=</span><span class="st">"Trie"</span>)</span>
<span id="cb55-38"><a href="#cb55-38" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb55-39"><a href="#cb55-39" aria-hidden="true" tabindex="-1"></a>    client.create_collection(</span>
<span id="cb55-40"><a href="#cb55-40" aria-hidden="true" tabindex="-1"></a>        collection_name<span class="op">=</span>COLLECTION_NAME,</span>
<span id="cb55-41"><a href="#cb55-41" aria-hidden="true" tabindex="-1"></a>        schema<span class="op">=</span>schema,</span>
<span id="cb55-42"><a href="#cb55-42" aria-hidden="true" tabindex="-1"></a>        index_params<span class="op">=</span>index_params,</span>
<span id="cb55-43"><a href="#cb55-43" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb55-44"><a href="#cb55-44" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Created collection '</span><span class="sc">{</span>COLLECTION_NAME<span class="sc">}</span><span class="ss">'"</span>)</span>
<span id="cb55-45"><a href="#cb55-45" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb55-46"><a href="#cb55-46" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb55-47"><a href="#cb55-47" aria-hidden="true" tabindex="-1"></a><span class="co"># ─── Embedding Function (Model-Agnostic Placeholder) ─────────────────────────</span></span>
<span id="cb55-48"><a href="#cb55-48" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> extract_embedding(image_path: <span class="bu">str</span>) <span class="op">-&gt;</span> np.ndarray:</span>
<span id="cb55-49"><a href="#cb55-49" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb55-50"><a href="#cb55-50" aria-hidden="true" tabindex="-1"></a><span class="co">    Replace this function with your actual embedding model call.</span></span>
<span id="cb55-51"><a href="#cb55-51" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb55-52"><a href="#cb55-52" aria-hidden="true" tabindex="-1"></a><span class="co">    Example with torchvision ResNet-50 (2048-dim output, so set EMBEDDING_DIM = 2048; load the model once, not per call):</span></span>
<span id="cb55-53"><a href="#cb55-53" aria-hidden="true" tabindex="-1"></a><span class="co">        from torchvision import models, transforms</span></span>
<span id="cb55-54"><a href="#cb55-54" aria-hidden="true" tabindex="-1"></a><span class="co">        from PIL import Image</span></span>
<span id="cb55-55"><a href="#cb55-55" aria-hidden="true" tabindex="-1"></a><span class="co">        import torch</span></span>
<span id="cb55-56"><a href="#cb55-56" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb55-57"><a href="#cb55-57" aria-hidden="true" tabindex="-1"></a><span class="co">        model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)</span></span>
<span id="cb55-58"><a href="#cb55-58" aria-hidden="true" tabindex="-1"></a><span class="co">        model.eval()</span></span>
<span id="cb55-59"><a href="#cb55-59" aria-hidden="true" tabindex="-1"></a><span class="co">        embedding_model = torch.nn.Sequential(*list(model.children())[:-1])</span></span>
<span id="cb55-60"><a href="#cb55-60" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb55-61"><a href="#cb55-61" aria-hidden="true" tabindex="-1"></a><span class="co">        transform = transforms.Compose([</span></span>
<span id="cb55-62"><a href="#cb55-62" aria-hidden="true" tabindex="-1"></a><span class="co">            transforms.Resize(256),</span></span>
<span id="cb55-63"><a href="#cb55-63" aria-hidden="true" tabindex="-1"></a><span class="co">            transforms.CenterCrop(224),</span></span>
<span id="cb55-64"><a href="#cb55-64" aria-hidden="true" tabindex="-1"></a><span class="co">            transforms.ToTensor(),</span></span>
<span id="cb55-65"><a href="#cb55-65" aria-hidden="true" tabindex="-1"></a><span class="co">            transforms.Normalize(mean=[0.485, 0.456, 0.406],</span></span>
<span id="cb55-66"><a href="#cb55-66" aria-hidden="true" tabindex="-1"></a><span class="co">                                 std=[0.229, 0.224, 0.225]),</span></span>
<span id="cb55-67"><a href="#cb55-67" aria-hidden="true" tabindex="-1"></a><span class="co">        ])</span></span>
<span id="cb55-68"><a href="#cb55-68" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb55-69"><a href="#cb55-69" aria-hidden="true" tabindex="-1"></a><span class="co">        img = Image.open(image_path).convert("RGB")</span></span>
<span id="cb55-70"><a href="#cb55-70" aria-hidden="true" tabindex="-1"></a><span class="co">        tensor = transform(img).unsqueeze(0)</span></span>
<span id="cb55-71"><a href="#cb55-71" aria-hidden="true" tabindex="-1"></a><span class="co">        with torch.no_grad():</span></span>
<span id="cb55-72"><a href="#cb55-72" aria-hidden="true" tabindex="-1"></a><span class="co">            embedding = embedding_model(tensor).squeeze().numpy()</span></span>
<span id="cb55-73"><a href="#cb55-73" aria-hidden="true" tabindex="-1"></a><span class="co">        embedding = embedding / np.linalg.norm(embedding)</span></span>
<span id="cb55-74"><a href="#cb55-74" aria-hidden="true" tabindex="-1"></a><span class="co">        return embedding</span></span>
<span id="cb55-75"><a href="#cb55-75" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb55-76"><a href="#cb55-76" aria-hidden="true" tabindex="-1"></a>    vec <span class="op">=</span> np.random.randn(EMBEDDING_DIM).astype(np.float32)</span>
<span id="cb55-77"><a href="#cb55-77" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> vec <span class="op">/</span> np.linalg.norm(vec)</span>
<span id="cb55-78"><a href="#cb55-78" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb55-79"><a href="#cb55-79" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb55-80"><a href="#cb55-80" aria-hidden="true" tabindex="-1"></a><span class="co"># ─── Ingest Images ────────────────────────────────────────────────────────────</span></span>
<span id="cb55-81"><a href="#cb55-81" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> ingest_images(image_records: <span class="bu">list</span>, batch_size: <span class="bu">int</span> <span class="op">=</span> <span class="dv">2000</span>):</span>
<span id="cb55-82"><a href="#cb55-82" aria-hidden="true" tabindex="-1"></a>    total <span class="op">=</span> <span class="bu">len</span>(image_records)</span>
<span id="cb55-83"><a href="#cb55-83" aria-hidden="true" tabindex="-1"></a>    inserted <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb55-84"><a href="#cb55-84" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb55-85"><a href="#cb55-85" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> start <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">0</span>, total, batch_size):</span>
<span id="cb55-86"><a href="#cb55-86" aria-hidden="true" tabindex="-1"></a>        batch <span class="op">=</span> image_records[start : start <span class="op">+</span> batch_size]</span>
<span id="cb55-87"><a href="#cb55-87" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb55-88"><a href="#cb55-88" aria-hidden="true" tabindex="-1"></a>        data <span class="op">=</span> []</span>
<span id="cb55-89"><a href="#cb55-89" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> record <span class="kw">in</span> batch:</span>
<span id="cb55-90"><a href="#cb55-90" aria-hidden="true" tabindex="-1"></a>            embedding <span class="op">=</span> extract_embedding(record[<span class="st">"path"</span>])</span>
<span id="cb55-91"><a href="#cb55-91" aria-hidden="true" tabindex="-1"></a>            data.append({</span>
<span id="cb55-92"><a href="#cb55-92" aria-hidden="true" tabindex="-1"></a>                <span class="st">"embedding"</span>: embedding.tolist(),</span>
<span id="cb55-93"><a href="#cb55-93" aria-hidden="true" tabindex="-1"></a>                <span class="st">"image_path"</span>: record[<span class="st">"path"</span>],</span>
<span id="cb55-94"><a href="#cb55-94" aria-hidden="true" tabindex="-1"></a>                <span class="st">"category"</span>: record[<span class="st">"category"</span>],</span>
<span id="cb55-95"><a href="#cb55-95" aria-hidden="true" tabindex="-1"></a>                <span class="st">"dataset_split"</span>: record[<span class="st">"split"</span>],</span>
<span id="cb55-96"><a href="#cb55-96" aria-hidden="true" tabindex="-1"></a>                <span class="st">"width"</span>: record[<span class="st">"width"</span>],</span>
<span id="cb55-97"><a href="#cb55-97" aria-hidden="true" tabindex="-1"></a>                <span class="st">"height"</span>: record[<span class="st">"height"</span>],</span>
<span id="cb55-98"><a href="#cb55-98" aria-hidden="true" tabindex="-1"></a>                <span class="st">"file_size_bytes"</span>: record[<span class="st">"size"</span>],</span>
<span id="cb55-99"><a href="#cb55-99" aria-hidden="true" tabindex="-1"></a>                <span class="st">"inserted_at"</span>: <span class="bu">int</span>(time.time() <span class="op">*</span> <span class="dv">1000</span>),</span>
<span id="cb55-100"><a href="#cb55-100" aria-hidden="true" tabindex="-1"></a>            })</span>
<span id="cb55-101"><a href="#cb55-101" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb55-102"><a href="#cb55-102" aria-hidden="true" tabindex="-1"></a>        result <span class="op">=</span> client.insert(collection_name<span class="op">=</span>COLLECTION_NAME, data<span class="op">=</span>data)</span>
<span id="cb55-103"><a href="#cb55-103" aria-hidden="true" tabindex="-1"></a>        inserted <span class="op">+=</span> result[<span class="st">"insert_count"</span>]</span>
<span id="cb55-104"><a href="#cb55-104" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Ingested </span><span class="sc">{</span>inserted<span class="sc">}</span><span class="ss">/</span><span class="sc">{</span>total<span class="sc">}</span><span class="ss"> images"</span>)</span>
<span id="cb55-105"><a href="#cb55-105" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb55-106"><a href="#cb55-106" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> inserted</span>
<span id="cb55-107"><a href="#cb55-107" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb55-108"><a href="#cb55-108" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb55-109"><a href="#cb55-109" aria-hidden="true" tabindex="-1"></a><span class="co"># ─── Search ───────────────────────────────────────────────────────────────────</span></span>
<span id="cb55-110"><a href="#cb55-110" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> find_similar_images(</span>
<span id="cb55-111"><a href="#cb55-111" aria-hidden="true" tabindex="-1"></a>    query_image_path: <span class="bu">str</span>,</span>
<span id="cb55-112"><a href="#cb55-112" aria-hidden="true" tabindex="-1"></a>    top_k: <span class="bu">int</span> <span class="op">=</span> <span class="dv">10</span>,</span>
<span id="cb55-113"><a href="#cb55-113" aria-hidden="true" tabindex="-1"></a>    category_filter: <span class="bu">str</span> <span class="op">=</span> <span class="va">None</span>,</span>
<span id="cb55-114"><a href="#cb55-114" aria-hidden="true" tabindex="-1"></a>    min_dimension: <span class="bu">int</span> <span class="op">=</span> <span class="va">None</span>,</span>
<span id="cb55-115"><a href="#cb55-115" aria-hidden="true" tabindex="-1"></a>) <span class="op">-&gt;</span> <span class="bu">list</span>:</span>
<span id="cb55-116"><a href="#cb55-116" aria-hidden="true" tabindex="-1"></a>    query_embedding <span class="op">=</span> extract_embedding(query_image_path)</span>
<span id="cb55-117"><a href="#cb55-117" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb55-118"><a href="#cb55-118" aria-hidden="true" tabindex="-1"></a>    filters <span class="op">=</span> []</span>
<span id="cb55-119"><a href="#cb55-119" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> category_filter:</span>
<span id="cb55-120"><a href="#cb55-120" aria-hidden="true" tabindex="-1"></a>        filters.append(<span class="ss">f"category == '</span><span class="sc">{</span>category_filter<span class="sc">}</span><span class="ss">'"</span>)</span>
<span id="cb55-121"><a href="#cb55-121" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> min_dimension:</span>
<span id="cb55-122"><a href="#cb55-122" aria-hidden="true" tabindex="-1"></a>        filters.append(<span class="ss">f"width &gt;= </span><span class="sc">{</span>min_dimension<span class="sc">}</span><span class="ss"> AND height &gt;= </span><span class="sc">{</span>min_dimension<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb55-123"><a href="#cb55-123" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb55-124"><a href="#cb55-124" aria-hidden="true" tabindex="-1"></a>    filter_expr <span class="op">=</span> <span class="st">" AND "</span>.join(filters) <span class="cf">if</span> filters <span class="cf">else</span> <span class="va">None</span></span>
<span id="cb55-125"><a href="#cb55-125" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb55-126"><a href="#cb55-126" aria-hidden="true" tabindex="-1"></a>    results <span class="op">=</span> client.search(</span>
<span id="cb55-127"><a href="#cb55-127" aria-hidden="true" tabindex="-1"></a>        collection_name<span class="op">=</span>COLLECTION_NAME,</span>
<span id="cb55-128"><a href="#cb55-128" aria-hidden="true" tabindex="-1"></a>        data<span class="op">=</span>[query_embedding.tolist()],</span>
<span id="cb55-129"><a href="#cb55-129" aria-hidden="true" tabindex="-1"></a>        limit<span class="op">=</span>top_k,</span>
<span id="cb55-130"><a href="#cb55-130" aria-hidden="true" tabindex="-1"></a>        <span class="bu">filter</span><span class="op">=</span>filter_expr,</span>
<span id="cb55-131"><a href="#cb55-131" aria-hidden="true" tabindex="-1"></a>        search_params<span class="op">=</span>{<span class="st">"params"</span>: {<span class="st">"ef"</span>: <span class="bu">max</span>(top_k <span class="op">*</span> <span class="dv">10</span>, <span class="dv">100</span>)}},</span>
<span id="cb55-132"><a href="#cb55-132" aria-hidden="true" tabindex="-1"></a>        output_fields<span class="op">=</span>[<span class="st">"image_path"</span>, <span class="st">"category"</span>, <span class="st">"width"</span>, <span class="st">"height"</span>],</span>
<span id="cb55-133"><a href="#cb55-133" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb55-134"><a href="#cb55-134" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb55-135"><a href="#cb55-135" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> [</span>
<span id="cb55-136"><a href="#cb55-136" aria-hidden="true" tabindex="-1"></a>        {</span>
<span id="cb55-137"><a href="#cb55-137" aria-hidden="true" tabindex="-1"></a>            <span class="st">"id"</span>: hit[<span class="st">"id"</span>],</span>
<span id="cb55-138"><a href="#cb55-138" aria-hidden="true" tabindex="-1"></a>            <span class="st">"image_path"</span>: hit[<span class="st">"entity"</span>][<span class="st">"image_path"</span>],</span>
<span id="cb55-139"><a href="#cb55-139" aria-hidden="true" tabindex="-1"></a>            <span class="st">"similarity"</span>: hit[<span class="st">"distance"</span>],</span>
<span id="cb55-140"><a href="#cb55-140" aria-hidden="true" tabindex="-1"></a>            <span class="st">"category"</span>: hit[<span class="st">"entity"</span>][<span class="st">"category"</span>],</span>
<span id="cb55-141"><a href="#cb55-141" aria-hidden="true" tabindex="-1"></a>            <span class="st">"width"</span>: hit[<span class="st">"entity"</span>][<span class="st">"width"</span>],</span>
<span id="cb55-142"><a href="#cb55-142" aria-hidden="true" tabindex="-1"></a>            <span class="st">"height"</span>: hit[<span class="st">"entity"</span>][<span class="st">"height"</span>],</span>
<span id="cb55-143"><a href="#cb55-143" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb55-144"><a href="#cb55-144" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> hit <span class="kw">in</span> results[<span class="dv">0</span>]</span>
<span id="cb55-145"><a href="#cb55-145" aria-hidden="true" tabindex="-1"></a>    ]</span>
<span id="cb55-146"><a href="#cb55-146" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb55-147"><a href="#cb55-147" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb55-148"><a href="#cb55-148" aria-hidden="true" tabindex="-1"></a><span class="co"># ─── Duplicate Detection ──────────────────────────────────────────────────────</span></span>
<span id="cb55-149"><a href="#cb55-149" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> find_near_duplicates(similarity_threshold: <span class="bu">float</span> <span class="op">=</span> <span class="fl">0.98</span>, batch_size: <span class="bu">int</span> <span class="op">=</span> <span class="dv">100</span>):</span>
<span id="cb55-150"><a href="#cb55-150" aria-hidden="true" tabindex="-1"></a>    duplicates <span class="op">=</span> []</span>
<span id="cb55-151"><a href="#cb55-151" aria-hidden="true" tabindex="-1"></a>    offset <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb55-152"><a href="#cb55-152" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb55-153"><a href="#cb55-153" aria-hidden="true" tabindex="-1"></a>    <span class="cf">while</span> <span class="va">True</span>:</span>
<span id="cb55-154"><a href="#cb55-154" aria-hidden="true" tabindex="-1"></a>        entities <span class="op">=</span> client.query(</span>
<span id="cb55-155"><a href="#cb55-155" aria-hidden="true" tabindex="-1"></a>            collection_name<span class="op">=</span>COLLECTION_NAME,</span>
<span id="cb55-156"><a href="#cb55-156" aria-hidden="true" tabindex="-1"></a>            <span class="bu">filter</span><span class="op">=</span><span class="st">"id &gt; 0"</span>,</span>
<span id="cb55-157"><a href="#cb55-157" aria-hidden="true" tabindex="-1"></a>            output_fields<span class="op">=</span>[<span class="st">"id"</span>, <span class="st">"embedding"</span>, <span class="st">"image_path"</span>],</span>
<span id="cb55-158"><a href="#cb55-158" aria-hidden="true" tabindex="-1"></a>            limit<span class="op">=</span>batch_size,</span>
<span id="cb55-159"><a href="#cb55-159" aria-hidden="true" tabindex="-1"></a>            offset<span class="op">=</span>offset,  <span class="co"># note: Milvus caps offset + limit (16,384 by default)</span></span>
<span id="cb55-160"><a href="#cb55-160" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb55-161"><a href="#cb55-161" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb55-162"><a href="#cb55-162" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="kw">not</span> entities:</span>
<span id="cb55-163"><a href="#cb55-163" aria-hidden="true" tabindex="-1"></a>            <span class="cf">break</span></span>
<span id="cb55-164"><a href="#cb55-164" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb55-165"><a href="#cb55-165" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> entity <span class="kw">in</span> entities:</span>
<span id="cb55-166"><a href="#cb55-166" aria-hidden="true" tabindex="-1"></a>            results <span class="op">=</span> client.search(</span>
<span id="cb55-167"><a href="#cb55-167" aria-hidden="true" tabindex="-1"></a>                collection_name<span class="op">=</span>COLLECTION_NAME,</span>
<span id="cb55-168"><a href="#cb55-168" aria-hidden="true" tabindex="-1"></a>                data<span class="op">=</span>[entity[<span class="st">"embedding"</span>]],</span>
<span id="cb55-169"><a href="#cb55-169" aria-hidden="true" tabindex="-1"></a>                limit<span class="op">=</span><span class="dv">5</span>,</span>
<span id="cb55-170"><a href="#cb55-170" aria-hidden="true" tabindex="-1"></a>                search_params<span class="op">=</span>{<span class="st">"params"</span>: {<span class="st">"ef"</span>: <span class="dv">50</span>}},</span>
<span id="cb55-171"><a href="#cb55-171" aria-hidden="true" tabindex="-1"></a>                output_fields<span class="op">=</span>[<span class="st">"image_path"</span>],</span>
<span id="cb55-172"><a href="#cb55-172" aria-hidden="true" tabindex="-1"></a>            )</span>
<span id="cb55-173"><a href="#cb55-173" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb55-174"><a href="#cb55-174" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> hit <span class="kw">in</span> results[<span class="dv">0</span>]:</span>
<span id="cb55-175"><a href="#cb55-175" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Skip the query itself by ID: ANN search does not guarantee it ranks first</span></span>
<span id="cb55-176"><a href="#cb55-176" aria-hidden="true" tabindex="-1"></a>                <span class="cf">if</span> hit[<span class="st">"id"</span>] <span class="op">!=</span> entity[<span class="st">"id"</span>] <span class="kw">and</span> hit[<span class="st">"distance"</span>] <span class="op">&gt;=</span> similarity_threshold:</span>
<span id="cb55-177"><a href="#cb55-177" aria-hidden="true" tabindex="-1"></a>                    pair <span class="op">=</span> <span class="bu">tuple</span>(<span class="bu">sorted</span>([entity[<span class="st">"id"</span>], hit[<span class="st">"id"</span>]]))</span>
<span id="cb55-178"><a href="#cb55-178" aria-hidden="true" tabindex="-1"></a>                    <span class="cf">if</span> <span class="kw">not</span> <span class="bu">any</span>(d[:<span class="dv">2</span>] <span class="op">==</span> pair <span class="cf">for</span> d <span class="kw">in</span> duplicates):</span>
<span id="cb55-179"><a href="#cb55-179" aria-hidden="true" tabindex="-1"></a>                        duplicates.append((pair[<span class="dv">0</span>], pair[<span class="dv">1</span>], hit[<span class="st">"distance"</span>]))</span>
<span id="cb55-180"><a href="#cb55-180" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb55-181"><a href="#cb55-181" aria-hidden="true" tabindex="-1"></a>        offset <span class="op">+=</span> batch_size</span>
<span id="cb55-182"><a href="#cb55-182" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb55-183"><a href="#cb55-183" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> duplicates</span></code></pre></div></div>
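<p><code>find_near_duplicates</code> returns pairs, so a cluster of three or more near-identical images shows up as several overlapping pairs. A minimal sketch for merging those pairs into groups, assuming the <code>(id_a, id_b, similarity)</code> tuple format returned above, is a union-find (disjoint-set) pass. <code>group_duplicates</code> is a hypothetical helper, not part of pymilvus.</p>

```python
def group_duplicates(pairs):
    """Merge (id_a, id_b, similarity) tuples into connected groups of IDs.

    Hypothetical helper: a union-find (disjoint-set) structure with path
    halving, so transitive duplicates (A~B, B~C) end up in one group.
    """
    parent = {}

    def find(x):
        parent.setdefault(x, x)
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving keeps trees shallow
            x = parent[x]
        return x

    def union(a, b):
        root_a, root_b = find(a), find(b)
        if root_a != root_b:
            parent[root_b] = root_a

    for id_a, id_b, _similarity in pairs:
        union(id_a, id_b)

    groups = {}
    for node in parent:
        groups.setdefault(find(node), []).append(node)
    # Deterministic output: largest groups first, IDs ascending within a group.
    return sorted((sorted(g) for g in groups.values()), key=lambda g: (-len(g), g[0]))


pairs = [(1, 2, 0.99), (2, 3, 0.985), (7, 9, 0.991)]
print(group_duplicates(pairs))  # [[1, 2, 3], [7, 9]]
```

<p>Grouping is transitive: if A matches B and B matches C, all three land in one group even when A and C fall below the threshold. Each group can then be reduced to one representative, with the rest flagged for review.</p>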
<hr>
</section>
</section>
<section id="sec-facerecog" class="level2">
<h2 class="anchored" data-anchor-id="sec-facerecog" id="sec-facerecog">12. Use Case 2 — Face Recognition</h2>
<p>Face recognition is one of the highest-stakes computer vision applications. The core workflow involves face detection, alignment, embedding extraction, storage in Milvus, and similarity search for identity lookup.</p>
<section id="important-notes-on-face-recognition-ethics-and-legality" class="level3">
<h3 class="anchored" data-anchor-id="important-notes-on-face-recognition-ethics-and-legality" id="important-notes-on-face-recognition-ethics-and-legality">Important Notes on Face Recognition Ethics and Legality</h3>
<p>Face recognition systems raise serious privacy concerns. Before building and deploying such a system:</p>
<ul>
<li>Ensure you have <strong>explicit consent</strong> from individuals whose faces you are storing</li>
<li>Comply with applicable regulations (GDPR, CCPA, BIPA, etc.)</li>
<li>Implement appropriate data retention and deletion policies</li>
<li>Consider the risk of false positives in high-stakes applications (security, law enforcement)</li>
</ul>
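<p>Deletion and retention policies map onto two fields in the schema below: <code>person_id</code> identifies everything enrolled for one individual, and <code>enrolled_at</code> holds a millisecond timestamp. As a sketch, assuming person IDs are restricted to letters, digits, <code>-</code> and <code>_</code> (a hypothetical policy that lets them be quoted in a filter string without escaping), the filter expressions for an erasure request and for an age-based purge could be built as follows; <code>erasure_filter</code> and <code>retention_filter</code> are hypothetical names.</p>

```python
import re
import time
from typing import Optional

# Hypothetical ID policy: letters, digits, "-" and "_" only, at most 64 chars
# (matching the VARCHAR(64) person_id field), so no escaping is needed.
_SAFE_ID = re.compile(r"[A-Za-z0-9_-]{1,64}")

MS_PER_DAY = 24 * 60 * 60 * 1000


def erasure_filter(person_id: str) -> str:
    """Filter expression matching every embedding enrolled for one person."""
    if not _SAFE_ID.fullmatch(person_id):
        raise ValueError(f"unexpected characters in person_id: {person_id!r}")
    return f'person_id == "{person_id}"'


def retention_filter(max_age_days: int, now_ms: Optional[int] = None) -> str:
    """Filter expression matching embeddings enrolled more than max_age_days ago.

    enrolled_at is assumed to be milliseconds since the Unix epoch.
    """
    if now_ms is None:
        now_ms = int(time.time() * 1000)
    cutoff_ms = now_ms - max_age_days * MS_PER_DAY
    return f"enrolled_at < {cutoff_ms}"


print(erasure_filter("emp-0042"))                    # person_id == "emp-0042"
print(retention_filter(365, now_ms=40_000_000_000))  # enrolled_at < 8464000000
```

<p>Each string can then be passed as <code>client.delete(collection_name=COLLECTION_NAME, filter=...)</code>. Milvus marks deleted entities first and physically removes them during compaction, which is worth keeping in mind when a regulation sets a deletion deadline.</p>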
</section>
<section id="identity-schema-design" class="level3">
<h3 class="anchored" data-anchor-id="identity-schema-design" id="identity-schema-design">Identity Schema Design</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb56"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb56-1"><a href="#cb56-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> pymilvus <span class="im">import</span> MilvusClient, DataType</span>
<span id="cb56-2"><a href="#cb56-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb56-3"><a href="#cb56-3" aria-hidden="true" tabindex="-1"></a>COLLECTION_NAME <span class="op">=</span> <span class="st">"face_identities"</span></span>
<span id="cb56-4"><a href="#cb56-4" aria-hidden="true" tabindex="-1"></a>FACE_EMBEDDING_DIM <span class="op">=</span> <span class="dv">512</span></span>
<span id="cb56-5"><a href="#cb56-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb56-6"><a href="#cb56-6" aria-hidden="true" tabindex="-1"></a>client <span class="op">=</span> MilvusClient(<span class="st">"./face_db.db"</span>)</span>
<span id="cb56-7"><a href="#cb56-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb56-8"><a href="#cb56-8" aria-hidden="true" tabindex="-1"></a>schema <span class="op">=</span> client.create_schema(auto_id<span class="op">=</span><span class="va">True</span>, enable_dynamic_field<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb56-9"><a href="#cb56-9" aria-hidden="true" tabindex="-1"></a>schema.add_field(<span class="st">"id"</span>, DataType.INT64, is_primary<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb56-10"><a href="#cb56-10" aria-hidden="true" tabindex="-1"></a>schema.add_field(<span class="st">"embedding"</span>, DataType.FLOAT_VECTOR, dim<span class="op">=</span>FACE_EMBEDDING_DIM)</span>
<span id="cb56-11"><a href="#cb56-11" aria-hidden="true" tabindex="-1"></a>schema.add_field(<span class="st">"person_id"</span>, DataType.VARCHAR, max_length<span class="op">=</span><span class="dv">64</span>)</span>
<span id="cb56-12"><a href="#cb56-12" aria-hidden="true" tabindex="-1"></a>schema.add_field(<span class="st">"person_name"</span>, DataType.VARCHAR, max_length<span class="op">=</span><span class="dv">128</span>)</span>
<span id="cb56-13"><a href="#cb56-13" aria-hidden="true" tabindex="-1"></a>schema.add_field(<span class="st">"confidence_score"</span>, DataType.FLOAT)</span>
<span id="cb56-14"><a href="#cb56-14" aria-hidden="true" tabindex="-1"></a>schema.add_field(<span class="st">"source_image"</span>, DataType.VARCHAR, max_length<span class="op">=</span><span class="dv">1024</span>)</span>
<span id="cb56-15"><a href="#cb56-15" aria-hidden="true" tabindex="-1"></a>schema.add_field(<span class="st">"enrolled_at"</span>, DataType.INT64)</span>
<span id="cb56-16"><a href="#cb56-16" aria-hidden="true" tabindex="-1"></a>schema.add_field(<span class="st">"is_active"</span>, DataType.BOOL)</span>
<span id="cb56-17"><a href="#cb56-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb56-18"><a href="#cb56-18" aria-hidden="true" tabindex="-1"></a>index_params <span class="op">=</span> client.prepare_index_params()</span>
<span id="cb56-19"><a href="#cb56-19" aria-hidden="true" tabindex="-1"></a>index_params.add_index(</span>
<span id="cb56-20"><a href="#cb56-20" aria-hidden="true" tabindex="-1"></a>    field_name<span class="op">=</span><span class="st">"embedding"</span>,</span>
<span id="cb56-21"><a href="#cb56-21" aria-hidden="true" tabindex="-1"></a>    index_type<span class="op">=</span><span class="st">"HNSW"</span>,</span>
<span id="cb56-22"><a href="#cb56-22" aria-hidden="true" tabindex="-1"></a>    metric_type<span class="op">=</span><span class="st">"IP"</span>,</span>
<span id="cb56-23"><a href="#cb56-23" aria-hidden="true" tabindex="-1"></a>    params<span class="op">=</span>{<span class="st">"M"</span>: <span class="dv">16</span>, <span class="st">"efConstruction"</span>: <span class="dv">200</span>},</span>
<span id="cb56-24"><a href="#cb56-24" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb56-25"><a href="#cb56-25" aria-hidden="true" tabindex="-1"></a>index_params.add_index(field_name<span class="op">=</span><span class="st">"person_id"</span>, index_type<span class="op">=</span><span class="st">"Trie"</span>)</span>
<span id="cb56-26"><a href="#cb56-26" aria-hidden="true" tabindex="-1"></a>index_params.add_index(field_name<span class="op">=</span><span class="st">"is_active"</span>, index_type<span class="op">=</span><span class="st">"BITMAP"</span>)</span>
<span id="cb56-27"><a href="#cb56-27" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb56-28"><a href="#cb56-28" aria-hidden="true" tabindex="-1"></a>client.create_collection(</span>
<span id="cb56-29"><a href="#cb56-29" aria-hidden="true" tabindex="-1"></a>    collection_name<span class="op">=</span>COLLECTION_NAME,</span>
<span id="cb56-30"><a href="#cb56-30" aria-hidden="true" tabindex="-1"></a>    schema<span class="op">=</span>schema,</span>
<span id="cb56-31"><a href="#cb56-31" aria-hidden="true" tabindex="-1"></a>    index_params<span class="op">=</span>index_params,</span>
<span id="cb56-32"><a href="#cb56-32" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
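<p>The index uses <code>metric_type="IP"</code> (inner product). Inner product equals cosine similarity only when both vectors have unit L2 norm, which is why the placeholder <code>extract_face_embedding</code> in the next subsection divides by <code>np.linalg.norm</code> before returning. The standard-library sketch below (helper names are illustrative) checks that identity on two small vectors:</p>

```python
import math


def inner_product(a, b):
    """Plain dot product: the score that metric_type="IP" ranks by."""
    return sum(x * y for x, y in zip(a, b))


def l2_normalize(vec):
    """Scale a vector to unit length (L2 norm of 1)."""
    norm = math.sqrt(inner_product(vec, vec))
    if norm == 0:
        raise ValueError("cannot normalize a zero vector")
    return [x / norm for x in vec]


def cosine_similarity(a, b):
    """Dot product divided by the product of the two vector lengths."""
    return inner_product(a, b) / math.sqrt(inner_product(a, a) * inner_product(b, b))


a = [3.0, 4.0, 0.0]   # length 5
b = [4.0, 3.0, 12.0]  # length 13

print(inner_product(a, b))                              # 24.0, grows with vector length
print(cosine_similarity(a, b))                          # 24 / 65, about 0.369
print(inner_product(l2_normalize(a), l2_normalize(b)))  # same value, up to rounding
```

<p>If the embedding model does not normalize its output, either normalize before insertion and at query time, or declare the index with <code>metric_type="COSINE"</code>. With normalized vectors, the <code>similarity_threshold</code> used during recognition is a cosine similarity in the range [-1, 1].</p>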
</section>
<section id="enrolling-identities" class="level3">
<h3 class="anchored" data-anchor-id="enrolling-identities" id="enrolling-identities">Enrolling Identities</h3>
<p>Each person can have several enrolled face embeddings. Storing multiple embeddings per person, captured under different poses, lighting conditions, and expressions, makes recognition more robust:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb57"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb57-1"><a href="#cb57-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb57-2"><a href="#cb57-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb57-3"><a href="#cb57-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb57-4"><a href="#cb57-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> extract_face_embedding(aligned_face_image_path: <span class="bu">str</span>) <span class="op">-&gt;</span> np.ndarray:</span>
<span id="cb57-5"><a href="#cb57-5" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb57-6"><a href="#cb57-6" aria-hidden="true" tabindex="-1"></a><span class="co">    Placeholder — replace with your actual face recognition model.</span></span>
<span id="cb57-7"><a href="#cb57-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb57-8"><a href="#cb57-8" aria-hidden="true" tabindex="-1"></a><span class="co">    Example frameworks:</span></span>
<span id="cb57-9"><a href="#cb57-9" aria-hidden="true" tabindex="-1"></a><span class="co">    - InsightFace (ArcFace): pip install insightface</span></span>
<span id="cb57-10"><a href="#cb57-10" aria-hidden="true" tabindex="-1"></a><span class="co">    - deepface: pip install deepface</span></span>
<span id="cb57-11"><a href="#cb57-11" aria-hidden="true" tabindex="-1"></a><span class="co">    - facenet-pytorch: pip install facenet-pytorch</span></span>
<span id="cb57-12"><a href="#cb57-12" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb57-13"><a href="#cb57-13" aria-hidden="true" tabindex="-1"></a>    vec <span class="op">=</span> np.random.randn(FACE_EMBEDDING_DIM).astype(np.float32)</span>
<span id="cb57-14"><a href="#cb57-14" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> vec <span class="op">/</span> np.linalg.norm(vec)</span>
<span id="cb57-15"><a href="#cb57-15" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb57-16"><a href="#cb57-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb57-17"><a href="#cb57-17" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> assess_face_quality(image_path: <span class="bu">str</span>) <span class="op">-&gt;</span> <span class="bu">float</span>:</span>
<span id="cb57-18"><a href="#cb57-18" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb57-19"><a href="#cb57-19" aria-hidden="true" tabindex="-1"></a><span class="co">    Estimate the quality of a face image for enrollment (0.0–1.0).</span></span>
<span id="cb57-20"><a href="#cb57-20" aria-hidden="true" tabindex="-1"></a><span class="co">    In practice, use a dedicated face quality assessment model.</span></span>
<span id="cb57-21"><a href="#cb57-21" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb57-22"><a href="#cb57-22" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> <span class="fl">0.95</span>  <span class="co"># placeholder</span></span>
<span id="cb57-23"><a href="#cb57-23" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb57-24"><a href="#cb57-24" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb57-25"><a href="#cb57-25" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> enroll_person(</span>
<span id="cb57-26"><a href="#cb57-26" aria-hidden="true" tabindex="-1"></a>    person_id: <span class="bu">str</span>,</span>
<span id="cb57-27"><a href="#cb57-27" aria-hidden="true" tabindex="-1"></a>    person_name: <span class="bu">str</span>,</span>
<span id="cb57-28"><a href="#cb57-28" aria-hidden="true" tabindex="-1"></a>    face_image_paths: <span class="bu">list</span>,</span>
<span id="cb57-29"><a href="#cb57-29" aria-hidden="true" tabindex="-1"></a>    min_quality_threshold: <span class="bu">float</span> <span class="op">=</span> <span class="fl">0.7</span>,</span>
<span id="cb57-30"><a href="#cb57-30" aria-hidden="true" tabindex="-1"></a>):</span>
<span id="cb57-31"><a href="#cb57-31" aria-hidden="true" tabindex="-1"></a>    enrolled_count <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb57-32"><a href="#cb57-32" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb57-33"><a href="#cb57-33" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> image_path <span class="kw">in</span> face_image_paths:</span>
<span id="cb57-34"><a href="#cb57-34" aria-hidden="true" tabindex="-1"></a>        quality <span class="op">=</span> assess_face_quality(image_path)</span>
<span id="cb57-35"><a href="#cb57-35" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb57-36"><a href="#cb57-36" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> quality <span class="op">&lt;</span> min_quality_threshold:</span>
<span id="cb57-37"><a href="#cb57-37" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f"Skipping </span><span class="sc">{</span>image_path<span class="sc">}</span><span class="ss"> — quality </span><span class="sc">{</span>quality<span class="sc">:.2f}</span><span class="ss"> below threshold"</span>)</span>
<span id="cb57-38"><a href="#cb57-38" aria-hidden="true" tabindex="-1"></a>            <span class="cf">continue</span></span>
<span id="cb57-39"><a href="#cb57-39" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb57-40"><a href="#cb57-40" aria-hidden="true" tabindex="-1"></a>        embedding <span class="op">=</span> extract_face_embedding(image_path)</span>
<span id="cb57-41"><a href="#cb57-41" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb57-42"><a href="#cb57-42" aria-hidden="true" tabindex="-1"></a>        client.insert(</span>
<span id="cb57-43"><a href="#cb57-43" aria-hidden="true" tabindex="-1"></a>            collection_name<span class="op">=</span>COLLECTION_NAME,</span>
<span id="cb57-44"><a href="#cb57-44" aria-hidden="true" tabindex="-1"></a>            data<span class="op">=</span>[{</span>
<span id="cb57-45"><a href="#cb57-45" aria-hidden="true" tabindex="-1"></a>                <span class="st">"embedding"</span>: embedding.tolist(),</span>
<span id="cb57-46"><a href="#cb57-46" aria-hidden="true" tabindex="-1"></a>                <span class="st">"person_id"</span>: person_id,</span>
<span id="cb57-47"><a href="#cb57-47" aria-hidden="true" tabindex="-1"></a>                <span class="st">"person_name"</span>: person_name,</span>
<span id="cb57-48"><a href="#cb57-48" aria-hidden="true" tabindex="-1"></a>                <span class="st">"confidence_score"</span>: quality,</span>
<span id="cb57-49"><a href="#cb57-49" aria-hidden="true" tabindex="-1"></a>                <span class="st">"source_image"</span>: image_path,</span>
<span id="cb57-50"><a href="#cb57-50" aria-hidden="true" tabindex="-1"></a>                <span class="st">"enrolled_at"</span>: <span class="bu">int</span>(time.time() <span class="op">*</span> <span class="dv">1000</span>),</span>
<span id="cb57-51"><a href="#cb57-51" aria-hidden="true" tabindex="-1"></a>                <span class="st">"is_active"</span>: <span class="va">True</span>,</span>
<span id="cb57-52"><a href="#cb57-52" aria-hidden="true" tabindex="-1"></a>            }]</span>
<span id="cb57-53"><a href="#cb57-53" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb57-54"><a href="#cb57-54" aria-hidden="true" tabindex="-1"></a>        enrolled_count <span class="op">+=</span> <span class="dv">1</span></span>
<span id="cb57-55"><a href="#cb57-55" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb57-56"><a href="#cb57-56" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Enrolled </span><span class="sc">{</span>enrolled_count<span class="sc">}</span><span class="ss"> faces for </span><span class="sc">{</span>person_name<span class="sc">}</span><span class="ss"> (</span><span class="sc">{</span>person_id<span class="sc">}</span><span class="ss">)"</span>)</span>
<span id="cb57-57"><a href="#cb57-57" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> enrolled_count</span></code></pre></div></div>
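<p><code>enroll_person</code> calls <code>client.insert</code> once per accepted image, which costs one round trip per face. A variant that builds all rows first and writes them with a single insert is sketched below. It assumes the quality and embedding functions are passed in as callables; <code>build_enrollment_rows</code>, the stand-in callables, and the <code>emp-0042</code> ID are hypothetical, and all rows built in one call share a single <code>enrolled_at</code> timestamp.</p>

```python
import time
from typing import Callable, Sequence


def build_enrollment_rows(
    person_id: str,
    person_name: str,
    face_image_paths: Sequence[str],
    quality_fn: Callable[[str], float],
    embed_fn: Callable[[str], Sequence[float]],
    min_quality_threshold: float = 0.7,
) -> list:
    """Return one row dict per face image that passes the quality gate.

    Field names follow the face_identities schema defined earlier.
    """
    enrolled_at = int(time.time() * 1000)  # one shared timestamp for the batch
    rows = []
    for image_path in face_image_paths:
        quality = quality_fn(image_path)
        if quality < min_quality_threshold:
            continue  # skip low-quality faces, as enroll_person does
        rows.append({
            "embedding": list(embed_fn(image_path)),
            "person_id": person_id,
            "person_name": person_name,
            "confidence_score": quality,
            "source_image": image_path,
            "enrolled_at": enrolled_at,
            "is_active": True,
        })
    return rows


# Stand-in callables for illustration; replace with real quality and embedding models.
fake_quality = {"a.jpg": 0.9, "b.jpg": 0.4, "c.jpg": 0.8}.get
rows = build_enrollment_rows(
    "emp-0042", "Ada", ["a.jpg", "b.jpg", "c.jpg"],
    quality_fn=fake_quality, embed_fn=lambda path: [0.0, 1.0],
)
print([r["source_image"] for r in rows])  # ['a.jpg', 'c.jpg']
# A single write would then be: client.insert(collection_name=COLLECTION_NAME, data=rows)
```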
</section>
<section id="recognition-1n-search" class="level3">
<h3 class="anchored" data-anchor-id="recognition-1n-search" id="recognition-1n-search">Recognition (1:N Search)</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb58"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb58-1"><a href="#cb58-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> recognize_face(</span>
<span id="cb58-2"><a href="#cb58-2" aria-hidden="true" tabindex="-1"></a>    query_face_path: <span class="bu">str</span>,</span>
<span id="cb58-3"><a href="#cb58-3" aria-hidden="true" tabindex="-1"></a>    top_k: <span class="bu">int</span> <span class="op">=</span> <span class="dv">5</span>,</span>
<span id="cb58-4"><a href="#cb58-4" aria-hidden="true" tabindex="-1"></a>    similarity_threshold: <span class="bu">float</span> <span class="op">=</span> <span class="fl">0.7</span>,</span>
<span id="cb58-5"><a href="#cb58-5" aria-hidden="true" tabindex="-1"></a>) <span class="op">-&gt;</span> <span class="bu">dict</span>:</span>
<span id="cb58-6"><a href="#cb58-6" aria-hidden="true" tabindex="-1"></a>    query_embedding <span class="op">=</span> extract_face_embedding(query_face_path)</span>
<span id="cb58-7"><a href="#cb58-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb58-8"><a href="#cb58-8" aria-hidden="true" tabindex="-1"></a>    results <span class="op">=</span> client.search(</span>
<span id="cb58-9"><a href="#cb58-9" aria-hidden="true" tabindex="-1"></a>        collection_name<span class="op">=</span>COLLECTION_NAME,</span>
<span id="cb58-10"><a href="#cb58-10" aria-hidden="true" tabindex="-1"></a>        data<span class="op">=</span>[query_embedding.tolist()],</span>
<span id="cb58-11"><a href="#cb58-11" aria-hidden="true" tabindex="-1"></a>        limit<span class="op">=</span>top_k,</span>
<span id="cb58-12"><a href="#cb58-12" aria-hidden="true" tabindex="-1"></a>        <span class="bu">filter</span><span class="op">=</span><span class="st">"is_active == True"</span>,</span>
<span id="cb58-13"><a href="#cb58-13" aria-hidden="true" tabindex="-1"></a>        search_params<span class="op">=</span>{<span class="st">"params"</span>: {<span class="st">"ef"</span>: <span class="dv">200</span>}},</span>
<span id="cb58-14"><a href="#cb58-14" aria-hidden="true" tabindex="-1"></a>        output_fields<span class="op">=</span>[<span class="st">"person_id"</span>, <span class="st">"person_name"</span>, <span class="st">"confidence_score"</span>],</span>
<span id="cb58-15"><a href="#cb58-15" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb58-16"><a href="#cb58-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb58-17"><a href="#cb58-17" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> <span class="kw">not</span> results <span class="kw">or</span> <span class="kw">not</span> results[<span class="dv">0</span>]:</span>
<span id="cb58-18"><a href="#cb58-18" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {<span class="st">"status"</span>: <span class="st">"unknown"</span>, <span class="st">"reason"</span>: <span class="st">"no results"</span>}</span>
<span id="cb58-19"><a href="#cb58-19" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb58-20"><a href="#cb58-20" aria-hidden="true" tabindex="-1"></a>    top_hit <span class="op">=</span> results[<span class="dv">0</span>][<span class="dv">0</span>]</span>
<span id="cb58-21"><a href="#cb58-21" aria-hidden="true" tabindex="-1"></a>    top_similarity <span class="op">=</span> top_hit[<span class="st">"distance"</span>]</span>
<span id="cb58-22"><a href="#cb58-22" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb58-23"><a href="#cb58-23" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> top_similarity <span class="op">&lt;</span> similarity_threshold:</span>
<span id="cb58-24"><a href="#cb58-24" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {</span>
<span id="cb58-25"><a href="#cb58-25" aria-hidden="true" tabindex="-1"></a>            <span class="st">"status"</span>: <span class="st">"unknown"</span>,</span>
<span id="cb58-26"><a href="#cb58-26" aria-hidden="true" tabindex="-1"></a>            <span class="st">"best_match"</span>: {<span class="st">"person_id"</span>: top_hit[<span class="st">"entity"</span>][<span class="st">"person_id"</span>], <span class="st">"similarity"</span>: top_similarity},</span>
<span id="cb58-27"><a href="#cb58-27" aria-hidden="true" tabindex="-1"></a>            <span class="st">"reason"</span>: <span class="ss">f"similarity </span><span class="sc">{</span>top_similarity<span class="sc">:.4f}</span><span class="ss"> below threshold </span><span class="sc">{</span>similarity_threshold<span class="sc">}</span><span class="ss">"</span>,</span>
<span id="cb58-28"><a href="#cb58-28" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb58-29"><a href="#cb58-29" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb58-30"><a href="#cb58-30" aria-hidden="true" tabindex="-1"></a>    person_votes <span class="op">=</span> {}</span>
<span id="cb58-31"><a href="#cb58-31" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> hit <span class="kw">in</span> results[<span class="dv">0</span>]:</span>
<span id="cb58-32"><a href="#cb58-32" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> hit[<span class="st">"distance"</span>] <span class="op">&gt;=</span> similarity_threshold:</span>
<span id="cb58-33"><a href="#cb58-33" aria-hidden="true" tabindex="-1"></a>            pid <span class="op">=</span> hit[<span class="st">"entity"</span>][<span class="st">"person_id"</span>]</span>
<span id="cb58-34"><a href="#cb58-34" aria-hidden="true" tabindex="-1"></a>            person_votes.setdefault(pid, []).append(hit[<span class="st">"distance"</span>])</span>
<span id="cb58-35"><a href="#cb58-35" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb58-36"><a href="#cb58-36" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> <span class="kw">not</span> person_votes:</span>
<span id="cb58-37"><a href="#cb58-37" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {<span class="st">"status"</span>: <span class="st">"unknown"</span>, <span class="st">"reason"</span>: <span class="st">"no votes above threshold"</span>}</span>
<span id="cb58-38"><a href="#cb58-38" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb58-39"><a href="#cb58-39" aria-hidden="true" tabindex="-1"></a>    best_person <span class="op">=</span> <span class="bu">max</span>(person_votes, key<span class="op">=</span><span class="kw">lambda</span> pid: <span class="bu">sum</span>(person_votes[pid]) <span class="op">/</span> <span class="bu">len</span>(person_votes[pid]))</span>
<span id="cb58-40"><a href="#cb58-40" aria-hidden="true" tabindex="-1"></a>    avg_similarity <span class="op">=</span> <span class="bu">sum</span>(person_votes[best_person]) <span class="op">/</span> <span class="bu">len</span>(person_votes[best_person])</span>
<span id="cb58-41"><a href="#cb58-41" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb58-42"><a href="#cb58-42" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> {</span>
<span id="cb58-43"><a href="#cb58-43" aria-hidden="true" tabindex="-1"></a>        <span class="st">"status"</span>: <span class="st">"recognized"</span>,</span>
<span id="cb58-44"><a href="#cb58-44" aria-hidden="true" tabindex="-1"></a>        <span class="st">"person_id"</span>: best_person,</span>
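<span id="cb58-45"><a href="#cb58-45" aria-hidden="true" tabindex="-1"></a>        <span class="st">"person_name"</span>: <span class="bu">next</span>(hit[<span class="st">"entity"</span>][<span class="st">"person_name"</span>] <span class="cf">for</span> hit <span class="kw">in</span> results[<span class="dv">0</span>] <span class="cf">if</span> hit[<span class="st">"entity"</span>][<span class="st">"person_id"</span>] <span class="op">==</span> best_person),</span>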
<span id="cb58-46"><a href="#cb58-46" aria-hidden="true" tabindex="-1"></a>        <span class="st">"similarity"</span>: avg_similarity,</span>
<span id="cb58-47"><a href="#cb58-47" aria-hidden="true" tabindex="-1"></a>        <span class="st">"num_matching_embeddings"</span>: <span class="bu">len</span>(person_votes[best_person]),</span>
<span id="cb58-48"><a href="#cb58-48" aria-hidden="true" tabindex="-1"></a>    }</span>
<span id="cb58-49"><a href="#cb58-49" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb58-50"><a href="#cb58-50" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb58-51"><a href="#cb58-51" aria-hidden="true" tabindex="-1"></a><span class="co"># ─── Verification (1:1) ───────────────────────────────────────────────────────</span></span>
<span id="cb58-52"><a href="#cb58-52" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> verify_identity(image_path_1: <span class="bu">str</span>, image_path_2: <span class="bu">str</span>, threshold: <span class="bu">float</span> <span class="op">=</span> <span class="fl">0.7</span>) <span class="op">-&gt;</span> <span class="bu">dict</span>:</span>
<span id="cb58-53"><a href="#cb58-53" aria-hidden="true" tabindex="-1"></a>    emb1 <span class="op">=</span> extract_face_embedding(image_path_1)</span>
<span id="cb58-54"><a href="#cb58-54" aria-hidden="true" tabindex="-1"></a>    emb2 <span class="op">=</span> extract_face_embedding(image_path_2)</span>
<span id="cb58-55"><a href="#cb58-55" aria-hidden="true" tabindex="-1"></a>    similarity <span class="op">=</span> <span class="bu">float</span>(np.dot(emb1, emb2))  <span class="co"># equals cosine similarity only if both embeddings are L2-normalized</span></span>
<span id="cb58-56"><a href="#cb58-56" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> {<span class="st">"same_person"</span>: similarity <span class="op">&gt;=</span> threshold, <span class="st">"similarity"</span>: similarity, <span class="st">"threshold"</span>: threshold}</span>
<span id="cb58-57"><a href="#cb58-57" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb58-58"><a href="#cb58-58" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb58-59"><a href="#cb58-59" aria-hidden="true" tabindex="-1"></a><span class="co"># ─── Removing an Identity ─────────────────────────────────────────────────────</span></span>
<span id="cb58-60"><a href="#cb58-60" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> deactivate_person(person_id: <span class="bu">str</span>):</span>
<span id="cb58-61"><a href="#cb58-61" aria-hidden="true" tabindex="-1"></a>    entities <span class="op">=</span> client.query(</span>
<span id="cb58-62"><a href="#cb58-62" aria-hidden="true" tabindex="-1"></a>        collection_name<span class="op">=</span>COLLECTION_NAME,</span>
<span id="cb58-63"><a href="#cb58-63" aria-hidden="true" tabindex="-1"></a>        <span class="bu">filter</span><span class="op">=</span><span class="ss">f"person_id == '</span><span class="sc">{</span>person_id<span class="sc">}</span><span class="ss">'"</span>,</span>
<span id="cb58-64"><a href="#cb58-64" aria-hidden="true" tabindex="-1"></a>        output_fields<span class="op">=</span>[<span class="st">"id"</span>],</span>
<span id="cb58-65"><a href="#cb58-65" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb58-66"><a href="#cb58-66" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> <span class="kw">not</span> entities:</span>
<span id="cb58-67"><a href="#cb58-67" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"No enrollments found for person_id: </span><span class="sc">{</span>person_id<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb58-68"><a href="#cb58-68" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span></span>
<span id="cb58-69"><a href="#cb58-69" aria-hidden="true" tabindex="-1"></a>    client.delete(collection_name<span class="op">=</span>COLLECTION_NAME, ids<span class="op">=</span>[e[<span class="st">"id"</span>] <span class="cf">for</span> e <span class="kw">in</span> entities])</span>
<span id="cb58-70"><a href="#cb58-70" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Deleted </span><span class="sc">{</span><span class="bu">len</span>(entities)<span class="sc">}</span><span class="ss"> enrollments for person </span><span class="sc">{</span>person_id<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
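<p>To see how the per-person averaging in the identification function above behaves, the sketch below reimplements only its voting step as a standalone function (<code>aggregate_votes</code> is a hypothetical name) and feeds it mock hits shaped like the <code>{"distance": ..., "entity": {...}}</code> dictionaries that code reads from <code>client.search</code>. The similarity values are made up for illustration.</p>

```python
from collections import defaultdict


def aggregate_votes(hits: list[dict], threshold: float) -> dict:
    """Mirror the per-person averaging used for identification (illustrative sketch)."""
    votes: dict[str, list[float]] = defaultdict(list)
    for hit in hits:
        # Only hits at or above the threshold cast a vote
        if hit["distance"] >= threshold:
            votes[hit["entity"]["person_id"]].append(hit["distance"])
    # Mean similarity per person; the caller picks the highest mean
    return {pid: sum(scores) / len(scores) for pid, scores in votes.items()}


# Mock hits shaped like MilvusClient search results (hypothetical values)
mock_hits = [
    {"distance": 0.82, "entity": {"person_id": "bob"}},
    {"distance": 0.80, "entity": {"person_id": "alice"}},
    {"distance": 0.76, "entity": {"person_id": "alice"}},
    {"distance": 0.55, "entity": {"person_id": "carol"}},
]

averages = aggregate_votes(mock_hits, threshold=0.7)
best = max(averages, key=averages.get)
print(averages, best)
```

<p>Here Bob's single match at 0.82 outranks Alice's two matches averaging 0.78, because the ranking uses the mean similarity rather than the number of matching enrollments, and Carol is dropped for falling below the threshold. If several agreeing enrollments should count for more, a sum or a count-weighted mean is an alternative scoring choice.</p>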
</section>
<section id="similarity-thresholds-for-face-recognition" class="level3">
<h3 class="anchored" data-anchor-id="similarity-thresholds-for-face-recognition" id="similarity-thresholds-for-face-recognition">Similarity Thresholds for Face Recognition</h3>
<p>Thresholds vary significantly by model, training data, and preprocessing, so treat the ranges below as rough starting points and always calibrate on your target dataset:</p>
<table class="caption-top table">
<colgroup>
<col style="width: 33%">
<col style="width: 33%">
<col style="width: 33%">
</colgroup>
<thead>
<tr class="header">
<th>Model</th>
<th>Typical Threshold (IP/Cosine)</th>
<th>Notes</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>ArcFace (ResNet-50)</td>
<td>0.65–0.75</td>
<td>Additive angular margin loss; strong, widely used baseline</td>
</tr>
<tr class="even">
<td>FaceNet (Inception)</td>
<td>0.70–0.80</td>
<td>Triplet loss; general-purpose</td>
</tr>
<tr class="odd">
<td>AdaFace</td>
<td>0.60–0.70</td>
<td>Quality-adaptive margin; designed for low-quality images</td>
</tr>
<tr class="even">
<td>Your custom model</td>
<td>Must be calibrated</td>
<td>Use ROC curve on held-out set</td>
</tr>
</tbody>
</table>
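<p>The ranges above are on the inner-product/cosine scale. Some face models, FaceNet among them, are often evaluated with a Euclidean distance threshold instead. For L2-normalized embeddings <span class="math inline">\(\mathbf{a}, \mathbf{b}\)</span> the two scales are interchangeable:</p>
<p><span class="math display">\[
\lVert \mathbf{a} - \mathbf{b} \rVert_2^2 = \lVert \mathbf{a} \rVert_2^2 + \lVert \mathbf{b} \rVert_2^2 - 2\,\mathbf{a}^\top \mathbf{b} = 2 - 2\cos(\mathbf{a}, \mathbf{b})
\]</span></p>
<p>so a Euclidean distance threshold <span class="math inline">\(d\)</span> corresponds to a cosine threshold <span class="math inline">\(\tau = 1 - d^2/2\)</span>. For example, <span class="math inline">\(d = 1.0\)</span> maps to <span class="math inline">\(\tau = 0.5\)</span>. The relation holds only when both embeddings are unit-length.</p>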
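<p><strong>Calibration approach:</strong> Score a labeled validation set as genuine (same-person) and impostor (different-person) pairs, plot the ROC curve, and pick the threshold that meets your target false acceptance rate (FAR). The false rejection rate (FRR) at that threshold then follows from the genuine scores, so pushing FAR down generally pushes FRR up.</p>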
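<p>A minimal sketch of the FAR-based threshold choice, assuming you already have similarity scores for labeled impostor and genuine validation pairs and that "similarity at or above threshold" means a match. The function names are hypothetical, and the normally distributed scores are synthetic stand-ins used only to exercise the code; they say nothing about any real model.</p>

```python
import numpy as np


def threshold_at_far(impostor_scores, target_far: float) -> float:
    """Smallest threshold whose false acceptance rate (FAR) is at most target_far.

    Assumed decision rule: similarity >= threshold means "same person".
    """
    scores = np.sort(np.asarray(impostor_scores, dtype=np.float64))
    n = scores.size
    # How many impostor pairs we may accept while staying within target_far
    allowed = int(np.floor(target_far * n))
    if allowed >= n:
        return float(scores[0])  # accepting everything is within budget
    # Only scores strictly above this cut are accepted, i.e. at most `allowed` impostors
    cut = scores[n - allowed - 1]
    return float(np.nextafter(cut, np.inf))


def frr_at(genuine_scores, threshold: float) -> float:
    """False rejection rate: fraction of genuine (same-person) pairs below threshold."""
    return float(np.mean(np.asarray(genuine_scores) < threshold))


# Synthetic scores, for illustration only (not real model output)
rng = np.random.default_rng(0)
genuine = rng.normal(0.75, 0.08, 5_000)
impostor = rng.normal(0.20, 0.10, 50_000)

t = threshold_at_far(impostor, target_far=1e-3)
far = float(np.mean(impostor >= t))
print(f"threshold={t:.3f}  FAR={far:.4%}  FRR={frr_at(genuine, t):.2%}")
```

<p>Raising <code>target_far</code> lowers the threshold and with it the FRR. Which way to lean depends on whether false accepts (admitting the wrong person) or false rejects (turning away an enrolled person) are costlier in your application.</p>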
<hr>
</section>
</section>
<section id="sec-objectdetect" class="level2">
<h2 class="anchored" data-anchor-id="sec-objectdetect" id="sec-objectdetect">13. Use Case 3 — Object Detection &amp; Retrieval</h2>
<p>In object detection pipelines, you first detect objects in an image (bounding boxes + class labels), then embed each detected region for downstream retrieval. Applications include defect detection in manufacturing, retail shelf monitoring, medical imaging, and autonomous driving data curation.</p>
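<p>Cropping each detected region means mapping the detector's box onto pixel coordinates. The sketch below assumes boxes arrive as normalized <code>(x1, y1, x2, y2)</code> values in [0, 1], which is what the YOLO example further down requests via <code>xyxyn</code>; <code>to_pixel_box</code> and <code>clipped_area_fraction</code> are hypothetical helper names, and the 640x480 image size is made up.</p>

```python
def to_pixel_box(x1: float, y1: float, x2: float, y2: float,
                 width: int, height: int) -> tuple[int, int, int, int]:
    """Convert a normalized (x1, y1, x2, y2) box to clamped integer pixel coordinates."""
    def clamp(v: float, limit: int) -> int:
        # Scale to pixels, round, and keep inside [0, limit]
        return min(max(int(round(v * limit)), 0), limit)

    left, top = clamp(x1, width), clamp(y1, height)
    right, bottom = clamp(x2, width), clamp(y2, height)
    if right <= left or bottom <= top:
        raise ValueError("degenerate box after clamping")
    return left, top, right, bottom


def clipped_area_fraction(x1: float, y1: float, x2: float, y2: float) -> float:
    """Fraction of the image covered by a normalized box, clipped to [0, 1]."""
    w = max(0.0, min(x2, 1.0) - max(x1, 0.0))
    h = max(0.0, min(y2, 1.0) - max(y1, 0.0))
    return w * h


# The mock "car" detection from detect_objects, on a hypothetical 640x480 image
box = to_pixel_box(0.1, 0.2, 0.4, 0.8, width=640, height=480)
print(box, clipped_area_fraction(0.1, 0.2, 0.4, 0.8))
```

<p>Unlike the inline <code>int(detection.x1*w)</code> in the docstring example, this rounds rather than truncates and clamps to the image bounds, so a box that spills slightly past an edge still yields a valid crop. It also clips the area to the image, whereas the ingestion code computes <code>(x2 - x1) * (y2 - y1)</code> directly; either is reasonable if the detector already clips its outputs.</p>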
<section id="architecture-1" class="level3">
<h3 class="anchored" data-anchor-id="architecture-1" id="architecture-1">Architecture</h3>
<div class="cell" data-eval="true" data-layout-align="center">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">graph TD
    A["Input Image"]
    B["Object Detector (YOLO, Faster R-CNN, DETR, etc.)"]
    C["Bounding Boxes + Class Labels"]
    D["Region Cropping: crop each detected region"]
    E["Embedding Model (same as or different from the detector)"]
    F["Region Embeddings"]
    G[("Milvus (source_image · bbox · class · score)")]

    A --&gt; B
    B --&gt; C
    C --&gt; D
    D --&gt; E
    E --&gt; F
    F --&gt; G

    style A fill:#E8F4FD,stroke:#4A90D9
    style B fill:#FEF9E7,stroke:#E8A838
    style C fill:#EAF7EA,stroke:#5BA85A
    style D fill:#FDF2E9,stroke:#E67E22
    style E fill:#F4ECF7,stroke:#8B6BB1
    style F fill:#FDEDEC,stroke:#D95F5F
    style G fill:#EAFAF1,stroke:#27AE60
</pre>
</div>
<p></p></figure><p></p>
</div>
</div>
</div>
</section>
<section id="schema-for-object-detections" class="level3">
<h3 class="anchored" data-anchor-id="schema-for-object-detections" id="schema-for-object-detections">Schema for Object Detections</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb59"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb59-1"><a href="#cb59-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> pymilvus <span class="im">import</span> MilvusClient, DataType</span>
<span id="cb59-2"><a href="#cb59-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb59-3"><a href="#cb59-3" aria-hidden="true" tabindex="-1"></a>COLLECTION_NAME <span class="op">=</span> <span class="st">"object_detections"</span></span>
<span id="cb59-4"><a href="#cb59-4" aria-hidden="true" tabindex="-1"></a>REGION_EMBEDDING_DIM <span class="op">=</span> <span class="dv">512</span></span>
<span id="cb59-5"><a href="#cb59-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb59-6"><a href="#cb59-6" aria-hidden="true" tabindex="-1"></a>client <span class="op">=</span> MilvusClient(<span class="st">"./object_detection.db"</span>)</span>
<span id="cb59-7"><a href="#cb59-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb59-8"><a href="#cb59-8" aria-hidden="true" tabindex="-1"></a>schema <span class="op">=</span> client.create_schema(auto_id<span class="op">=</span><span class="va">True</span>, enable_dynamic_field<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb59-9"><a href="#cb59-9" aria-hidden="true" tabindex="-1"></a>schema.add_field(<span class="st">"id"</span>, DataType.INT64, is_primary<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb59-10"><a href="#cb59-10" aria-hidden="true" tabindex="-1"></a>schema.add_field(<span class="st">"embedding"</span>, DataType.FLOAT_VECTOR, dim<span class="op">=</span>REGION_EMBEDDING_DIM)</span>
<span id="cb59-11"><a href="#cb59-11" aria-hidden="true" tabindex="-1"></a>schema.add_field(<span class="st">"source_image_path"</span>, DataType.VARCHAR, max_length<span class="op">=</span><span class="dv">1024</span>)</span>
<span id="cb59-12"><a href="#cb59-12" aria-hidden="true" tabindex="-1"></a>schema.add_field(<span class="st">"source_image_id"</span>, DataType.VARCHAR, max_length<span class="op">=</span><span class="dv">64</span>)</span>
<span id="cb59-13"><a href="#cb59-13" aria-hidden="true" tabindex="-1"></a>schema.add_field(<span class="st">"bbox_x1"</span>, DataType.FLOAT)</span>
<span id="cb59-14"><a href="#cb59-14" aria-hidden="true" tabindex="-1"></a>schema.add_field(<span class="st">"bbox_y1"</span>, DataType.FLOAT)</span>
<span id="cb59-15"><a href="#cb59-15" aria-hidden="true" tabindex="-1"></a>schema.add_field(<span class="st">"bbox_x2"</span>, DataType.FLOAT)</span>
<span id="cb59-16"><a href="#cb59-16" aria-hidden="true" tabindex="-1"></a>schema.add_field(<span class="st">"bbox_y2"</span>, DataType.FLOAT)</span>
<span id="cb59-17"><a href="#cb59-17" aria-hidden="true" tabindex="-1"></a>schema.add_field(<span class="st">"class_name"</span>, DataType.VARCHAR, max_length<span class="op">=</span><span class="dv">64</span>)</span>
<span id="cb59-18"><a href="#cb59-18" aria-hidden="true" tabindex="-1"></a>schema.add_field(<span class="st">"class_id"</span>, DataType.INT32)</span>
<span id="cb59-19"><a href="#cb59-19" aria-hidden="true" tabindex="-1"></a>schema.add_field(<span class="st">"detection_score"</span>, DataType.FLOAT)</span>
<span id="cb59-20"><a href="#cb59-20" aria-hidden="true" tabindex="-1"></a>schema.add_field(<span class="st">"area_fraction"</span>, DataType.FLOAT)</span>
<span id="cb59-21"><a href="#cb59-21" aria-hidden="true" tabindex="-1"></a>schema.add_field(<span class="st">"detected_at"</span>, DataType.INT64)</span>
<span id="cb59-22"><a href="#cb59-22" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb59-23"><a href="#cb59-23" aria-hidden="true" tabindex="-1"></a>index_params <span class="op">=</span> client.prepare_index_params()</span>
<span id="cb59-24"><a href="#cb59-24" aria-hidden="true" tabindex="-1"></a>index_params.add_index(</span>
<span id="cb59-25"><a href="#cb59-25" aria-hidden="true" tabindex="-1"></a>    field_name<span class="op">=</span><span class="st">"embedding"</span>,</span>
<span id="cb59-26"><a href="#cb59-26" aria-hidden="true" tabindex="-1"></a>    index_type<span class="op">=</span><span class="st">"HNSW"</span>,</span>
<span id="cb59-27"><a href="#cb59-27" aria-hidden="true" tabindex="-1"></a>    metric_type<span class="op">=</span><span class="st">"COSINE"</span>,</span>
<span id="cb59-28"><a href="#cb59-28" aria-hidden="true" tabindex="-1"></a>    params<span class="op">=</span>{<span class="st">"M"</span>: <span class="dv">16</span>, <span class="st">"efConstruction"</span>: <span class="dv">200</span>},</span>
<span id="cb59-29"><a href="#cb59-29" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb59-30"><a href="#cb59-30" aria-hidden="true" tabindex="-1"></a>index_params.add_index(field_name<span class="op">=</span><span class="st">"class_name"</span>, index_type<span class="op">=</span><span class="st">"Trie"</span>)</span>
<span id="cb59-31"><a href="#cb59-31" aria-hidden="true" tabindex="-1"></a>index_params.add_index(field_name<span class="op">=</span><span class="st">"class_id"</span>, index_type<span class="op">=</span><span class="st">"STL_SORT"</span>)</span>
<span id="cb59-32"><a href="#cb59-32" aria-hidden="true" tabindex="-1"></a>index_params.add_index(field_name<span class="op">=</span><span class="st">"detection_score"</span>, index_type<span class="op">=</span><span class="st">"STL_SORT"</span>)</span>
<span id="cb59-33"><a href="#cb59-33" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb59-34"><a href="#cb59-34" aria-hidden="true" tabindex="-1"></a>client.create_collection(</span>
<span id="cb59-35"><a href="#cb59-35" aria-hidden="true" tabindex="-1"></a>    collection_name<span class="op">=</span>COLLECTION_NAME,</span>
<span id="cb59-36"><a href="#cb59-36" aria-hidden="true" tabindex="-1"></a>    schema<span class="op">=</span>schema,</span>
<span id="cb59-37"><a href="#cb59-37" aria-hidden="true" tabindex="-1"></a>    index_params<span class="op">=</span>index_params,</span>
<span id="cb59-38"><a href="#cb59-38" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
</section>
<section id="processing-a-detection-pipeline" class="level3">
<h3 class="anchored" data-anchor-id="processing-a-detection-pipeline" id="processing-a-detection-pipeline">Processing a Detection Pipeline</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb60"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb60-1"><a href="#cb60-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb60-2"><a href="#cb60-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb60-3"><a href="#cb60-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> dataclasses <span class="im">import</span> dataclass</span>
<span id="cb60-4"><a href="#cb60-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb60-5"><a href="#cb60-5" aria-hidden="true" tabindex="-1"></a><span class="at">@dataclass</span></span>
<span id="cb60-6"><a href="#cb60-6" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> Detection:</span>
<span id="cb60-7"><a href="#cb60-7" aria-hidden="true" tabindex="-1"></a>    class_name: <span class="bu">str</span></span>
<span id="cb60-8"><a href="#cb60-8" aria-hidden="true" tabindex="-1"></a>    class_id: <span class="bu">int</span></span>
<span id="cb60-9"><a href="#cb60-9" aria-hidden="true" tabindex="-1"></a>    score: <span class="bu">float</span></span>
<span id="cb60-10"><a href="#cb60-10" aria-hidden="true" tabindex="-1"></a>    x1: <span class="bu">float</span></span>
<span id="cb60-11"><a href="#cb60-11" aria-hidden="true" tabindex="-1"></a>    y1: <span class="bu">float</span></span>
<span id="cb60-12"><a href="#cb60-12" aria-hidden="true" tabindex="-1"></a>    x2: <span class="bu">float</span></span>
<span id="cb60-13"><a href="#cb60-13" aria-hidden="true" tabindex="-1"></a>    y2: <span class="bu">float</span></span>
<span id="cb60-14"><a href="#cb60-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb60-15"><a href="#cb60-15" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb60-16"><a href="#cb60-16" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> detect_objects(image_path: <span class="bu">str</span>) <span class="op">-&gt;</span> <span class="bu">list</span>[Detection]:</span>
<span id="cb60-17"><a href="#cb60-17" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb60-18"><a href="#cb60-18" aria-hidden="true" tabindex="-1"></a><span class="co">    Placeholder for your object detection model.</span></span>
<span id="cb60-19"><a href="#cb60-19" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb60-20"><a href="#cb60-20" aria-hidden="true" tabindex="-1"></a><span class="co">    Example with Ultralytics YOLO:</span></span>
<span id="cb60-21"><a href="#cb60-21" aria-hidden="true" tabindex="-1"></a><span class="co">        from ultralytics import YOLO</span></span>
<span id="cb60-22"><a href="#cb60-22" aria-hidden="true" tabindex="-1"></a><span class="co">        model = YOLO("yolov8n.pt")</span></span>
<span id="cb60-23"><a href="#cb60-23" aria-hidden="true" tabindex="-1"></a><span class="co">        results = model(image_path)</span></span>
<span id="cb60-24"><a href="#cb60-24" aria-hidden="true" tabindex="-1"></a><span class="co">        detections = []</span></span>
<span id="cb60-25"><a href="#cb60-25" aria-hidden="true" tabindex="-1"></a><span class="co">        for box in results[0].boxes:</span></span>
<span id="cb60-26"><a href="#cb60-26" aria-hidden="true" tabindex="-1"></a><span class="co">            x1, y1, x2, y2 = box.xyxyn[0].tolist()</span></span>
<span id="cb60-27"><a href="#cb60-27" aria-hidden="true" tabindex="-1"></a><span class="co">            detections.append(Detection(</span></span>
<span id="cb60-28"><a href="#cb60-28" aria-hidden="true" tabindex="-1"></a><span class="co">                class_name=model.names[int(box.cls)],</span></span>
<span id="cb60-29"><a href="#cb60-29" aria-hidden="true" tabindex="-1"></a><span class="co">                class_id=int(box.cls),</span></span>
<span id="cb60-30"><a href="#cb60-30" aria-hidden="true" tabindex="-1"></a><span class="co">                score=float(box.conf),</span></span>
<span id="cb60-31"><a href="#cb60-31" aria-hidden="true" tabindex="-1"></a><span class="co">                x1=x1, y1=y1, x2=x2, y2=y2,</span></span>
<span id="cb60-32"><a href="#cb60-32" aria-hidden="true" tabindex="-1"></a><span class="co">            ))</span></span>
<span id="cb60-33"><a href="#cb60-33" aria-hidden="true" tabindex="-1"></a><span class="co">        return detections</span></span>
<span id="cb60-34"><a href="#cb60-34" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb60-35"><a href="#cb60-35" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> [</span>
<span id="cb60-36"><a href="#cb60-36" aria-hidden="true" tabindex="-1"></a>        Detection(<span class="st">"car"</span>, <span class="dv">2</span>, <span class="fl">0.95</span>, <span class="fl">0.1</span>, <span class="fl">0.2</span>, <span class="fl">0.4</span>, <span class="fl">0.8</span>),</span>
<span id="cb60-37"><a href="#cb60-37" aria-hidden="true" tabindex="-1"></a>        Detection(<span class="st">"person"</span>, <span class="dv">0</span>, <span class="fl">0.87</span>, <span class="fl">0.5</span>, <span class="fl">0.1</span>, <span class="fl">0.7</span>, <span class="fl">0.9</span>),</span>
<span id="cb60-38"><a href="#cb60-38" aria-hidden="true" tabindex="-1"></a>    ]</span>
<span id="cb60-39"><a href="#cb60-39" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb60-40"><a href="#cb60-40" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb60-41"><a href="#cb60-41" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> extract_region_embedding(image_path: <span class="bu">str</span>, detection: Detection) <span class="op">-&gt;</span> np.ndarray:</span>
<span id="cb60-42"><a href="#cb60-42" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb60-43"><a href="#cb60-43" aria-hidden="true" tabindex="-1"></a><span class="co">    Crop the detected region and extract its embedding.</span></span>
<span id="cb60-44"><a href="#cb60-44" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb60-45"><a href="#cb60-45" aria-hidden="true" tabindex="-1"></a><span class="co">    Example with PIL:</span></span>
<span id="cb60-46"><a href="#cb60-46" aria-hidden="true" tabindex="-1"></a><span class="co">        from PIL import Image</span></span>
<span id="cb60-47"><a href="#cb60-47" aria-hidden="true" tabindex="-1"></a><span class="co">        img = Image.open(image_path).convert("RGB")</span></span>
<span id="cb60-48"><a href="#cb60-48" aria-hidden="true" tabindex="-1"></a><span class="co">        w, h = img.size</span></span>
<span id="cb60-49"><a href="#cb60-49" aria-hidden="true" tabindex="-1"></a><span class="co">        box = (int(detection.x1*w), int(detection.y1*h),</span></span>
<span id="cb60-50"><a href="#cb60-50" aria-hidden="true" tabindex="-1"></a><span class="co">               int(detection.x2*w), int(detection.y2*h))</span></span>
<span id="cb60-51"><a href="#cb60-51" aria-hidden="true" tabindex="-1"></a><span class="co">        region = img.crop(box)</span></span>
<span id="cb60-52"><a href="#cb60-52" aria-hidden="true" tabindex="-1"></a><span class="co">        # Pass through embedding model...</span></span>
<span id="cb60-53"><a href="#cb60-53" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb60-54"><a href="#cb60-54" aria-hidden="true" tabindex="-1"></a>    vec <span class="op">=</span> np.random.randn(REGION_EMBEDDING_DIM).astype(np.float32)</span>
<span id="cb60-55"><a href="#cb60-55" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> vec <span class="op">/</span> np.linalg.norm(vec)</span>
<span id="cb60-56"><a href="#cb60-56" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb60-57"><a href="#cb60-57" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb60-58"><a href="#cb60-58" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> process_and_ingest_image(image_path: <span class="bu">str</span>, image_id: <span class="bu">str</span>):</span>
<span id="cb60-59"><a href="#cb60-59" aria-hidden="true" tabindex="-1"></a>    detections <span class="op">=</span> detect_objects(image_path)</span>
<span id="cb60-60"><a href="#cb60-60" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> <span class="kw">not</span> detections:</span>
<span id="cb60-61"><a href="#cb60-61" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="dv">0</span></span>
<span id="cb60-62"><a href="#cb60-62" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb60-63"><a href="#cb60-63" aria-hidden="true" tabindex="-1"></a>    data <span class="op">=</span> []</span>
<span id="cb60-64"><a href="#cb60-64" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> det <span class="kw">in</span> detections:</span>
<span id="cb60-65"><a href="#cb60-65" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> det.score <span class="op">&lt;</span> <span class="fl">0.5</span>:</span>
<span id="cb60-66"><a href="#cb60-66" aria-hidden="true" tabindex="-1"></a>            <span class="cf">continue</span></span>
<span id="cb60-67"><a href="#cb60-67" aria-hidden="true" tabindex="-1"></a>        embedding <span class="op">=</span> extract_region_embedding(image_path, det)</span>
<span id="cb60-68"><a href="#cb60-68" aria-hidden="true" tabindex="-1"></a>        area <span class="op">=</span> (det.x2 <span class="op">-</span> det.x1) <span class="op">*</span> (det.y2 <span class="op">-</span> det.y1)</span>
<span id="cb60-69"><a href="#cb60-69" aria-hidden="true" tabindex="-1"></a>        data.append({</span>
<span id="cb60-70"><a href="#cb60-70" aria-hidden="true" tabindex="-1"></a>            <span class="st">"embedding"</span>: embedding.tolist(),</span>
<span id="cb60-71"><a href="#cb60-71" aria-hidden="true" tabindex="-1"></a>            <span class="st">"source_image_path"</span>: image_path,</span>
<span id="cb60-72"><a href="#cb60-72" aria-hidden="true" tabindex="-1"></a>            <span class="st">"source_image_id"</span>: image_id,</span>
<span id="cb60-73"><a href="#cb60-73" aria-hidden="true" tabindex="-1"></a>            <span class="st">"bbox_x1"</span>: det.x1, <span class="st">"bbox_y1"</span>: det.y1,</span>
<span id="cb60-74"><a href="#cb60-74" aria-hidden="true" tabindex="-1"></a>            <span class="st">"bbox_x2"</span>: det.x2, <span class="st">"bbox_y2"</span>: det.y2,</span>
<span id="cb60-75"><a href="#cb60-75" aria-hidden="true" tabindex="-1"></a>            <span class="st">"class_name"</span>: det.class_name,</span>
<span id="cb60-76"><a href="#cb60-76" aria-hidden="true" tabindex="-1"></a>            <span class="st">"class_id"</span>: det.class_id,</span>
<span id="cb60-77"><a href="#cb60-77" aria-hidden="true" tabindex="-1"></a>            <span class="st">"detection_score"</span>: det.score,</span>
<span id="cb60-78"><a href="#cb60-78" aria-hidden="true" tabindex="-1"></a>            <span class="st">"area_fraction"</span>: area,</span>
<span id="cb60-79"><a href="#cb60-79" aria-hidden="true" tabindex="-1"></a>            <span class="st">"detected_at"</span>: <span class="bu">int</span>(time.time() <span class="op">*</span> <span class="dv">1000</span>),</span>
<span id="cb60-80"><a href="#cb60-80" aria-hidden="true" tabindex="-1"></a>        })</span>
<span id="cb60-81"><a href="#cb60-81" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb60-82"><a href="#cb60-82" aria-hidden="true" tabindex="-1"></a>    result <span class="op">=</span> client.insert(collection_name<span class="op">=</span>COLLECTION_NAME, data<span class="op">=</span>data)</span>
<span id="cb60-83"><a href="#cb60-83" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> result[<span class="st">"insert_count"</span>]</span>
<span id="cb60-84"><a href="#cb60-84" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb60-85"><a href="#cb60-85" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb60-86"><a href="#cb60-86" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> find_similar_objects(</span>
<span id="cb60-87"><a href="#cb60-87" aria-hidden="true" tabindex="-1"></a>    query_image_path: <span class="bu">str</span>,</span>
<span id="cb60-88"><a href="#cb60-88" aria-hidden="true" tabindex="-1"></a>    query_detection: Detection,</span>
<span id="cb60-89"><a href="#cb60-89" aria-hidden="true" tabindex="-1"></a>    top_k: <span class="bu">int</span> <span class="op">=</span> <span class="dv">10</span>,</span>
<span id="cb60-90"><a href="#cb60-90" aria-hidden="true" tabindex="-1"></a>    same_class_only: <span class="bu">bool</span> <span class="op">=</span> <span class="va">True</span>,</span>
<span id="cb60-91"><a href="#cb60-91" aria-hidden="true" tabindex="-1"></a>    min_score: <span class="bu">float</span> <span class="op">=</span> <span class="fl">0.7</span>,</span>
<span id="cb60-92"><a href="#cb60-92" aria-hidden="true" tabindex="-1"></a>) <span class="op">-&gt;</span> <span class="bu">list</span>:</span>
<span id="cb60-93"><a href="#cb60-93" aria-hidden="true" tabindex="-1"></a>    query_embedding <span class="op">=</span> extract_region_embedding(query_image_path, query_detection)</span>
<span id="cb60-94"><a href="#cb60-94" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb60-95"><a href="#cb60-95" aria-hidden="true" tabindex="-1"></a>    filters <span class="op">=</span> [<span class="ss">f"detection_score &gt;= </span><span class="sc">{</span>min_score<span class="sc">}</span><span class="ss">"</span>]</span>
<span id="cb60-96"><a href="#cb60-96" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> same_class_only:</span>
<span id="cb60-97"><a href="#cb60-97" aria-hidden="true" tabindex="-1"></a>        filters.append(<span class="ss">f"class_name == '</span><span class="sc">{</span>query_detection<span class="sc">.</span>class_name<span class="sc">}</span><span class="ss">'"</span>)</span>
<span id="cb60-98"><a href="#cb60-98" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb60-99"><a href="#cb60-99" aria-hidden="true" tabindex="-1"></a>    results <span class="op">=</span> client.search(</span>
<span id="cb60-100"><a href="#cb60-100" aria-hidden="true" tabindex="-1"></a>        collection_name<span class="op">=</span>COLLECTION_NAME,</span>
<span id="cb60-101"><a href="#cb60-101" aria-hidden="true" tabindex="-1"></a>        data<span class="op">=</span>[query_embedding.tolist()],</span>
<span id="cb60-102"><a href="#cb60-102" aria-hidden="true" tabindex="-1"></a>        limit<span class="op">=</span>top_k,</span>
<span id="cb60-103"><a href="#cb60-103" aria-hidden="true" tabindex="-1"></a>        <span class="bu">filter</span><span class="op">=</span><span class="st">" AND "</span>.join(filters),</span>
<span id="cb60-104"><a href="#cb60-104" aria-hidden="true" tabindex="-1"></a>        search_params<span class="op">=</span>{<span class="st">"ef"</span>: <span class="dv">150</span>},</span>
<span id="cb60-105"><a href="#cb60-105" aria-hidden="true" tabindex="-1"></a>        output_fields<span class="op">=</span>[</span>
<span id="cb60-106"><a href="#cb60-106" aria-hidden="true" tabindex="-1"></a>            <span class="st">"source_image_path"</span>, <span class="st">"class_name"</span>, <span class="st">"detection_score"</span>,</span>
<span id="cb60-107"><a href="#cb60-107" aria-hidden="true" tabindex="-1"></a>            <span class="st">"bbox_x1"</span>, <span class="st">"bbox_y1"</span>, <span class="st">"bbox_x2"</span>, <span class="st">"bbox_y2"</span></span>
<span id="cb60-108"><a href="#cb60-108" aria-hidden="true" tabindex="-1"></a>        ],</span>
<span id="cb60-109"><a href="#cb60-109" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb60-110"><a href="#cb60-110" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb60-111"><a href="#cb60-111" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> [</span>
<span id="cb60-112"><a href="#cb60-112" aria-hidden="true" tabindex="-1"></a>        {</span>
<span id="cb60-113"><a href="#cb60-113" aria-hidden="true" tabindex="-1"></a>            <span class="st">"image_path"</span>: hit[<span class="st">"entity"</span>][<span class="st">"source_image_path"</span>],</span>
<span id="cb60-114"><a href="#cb60-114" aria-hidden="true" tabindex="-1"></a>            <span class="st">"similarity"</span>: hit[<span class="st">"distance"</span>],</span>
<span id="cb60-115"><a href="#cb60-115" aria-hidden="true" tabindex="-1"></a>            <span class="st">"class_name"</span>: hit[<span class="st">"entity"</span>][<span class="st">"class_name"</span>],</span>
<span id="cb60-116"><a href="#cb60-116" aria-hidden="true" tabindex="-1"></a>            <span class="st">"detection_score"</span>: hit[<span class="st">"entity"</span>][<span class="st">"detection_score"</span>],</span>
<span id="cb60-117"><a href="#cb60-117" aria-hidden="true" tabindex="-1"></a>            <span class="st">"bbox"</span>: {</span>
<span id="cb60-118"><a href="#cb60-118" aria-hidden="true" tabindex="-1"></a>                <span class="st">"x1"</span>: hit[<span class="st">"entity"</span>][<span class="st">"bbox_x1"</span>], <span class="st">"y1"</span>: hit[<span class="st">"entity"</span>][<span class="st">"bbox_y1"</span>],</span>
<span id="cb60-119"><a href="#cb60-119" aria-hidden="true" tabindex="-1"></a>                <span class="st">"x2"</span>: hit[<span class="st">"entity"</span>][<span class="st">"bbox_x2"</span>], <span class="st">"y2"</span>: hit[<span class="st">"entity"</span>][<span class="st">"bbox_y2"</span>],</span>
<span id="cb60-120"><a href="#cb60-120" aria-hidden="true" tabindex="-1"></a>            }</span>
<span id="cb60-121"><a href="#cb60-121" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb60-122"><a href="#cb60-122" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> hit <span class="kw">in</span> results[<span class="dv">0</span>]</span>
<span id="cb60-123"><a href="#cb60-123" aria-hidden="true" tabindex="-1"></a>    ]</span></code></pre></div></div>
<hr>
</section>
</section>
<section id="sec-videosearch" class="level2">
<h2 class="anchored" data-anchor-id="sec-videosearch" id="sec-videosearch">14. Use Case 4 — Video Frame Search</h2>
<p>Video frame search enables you to find specific moments in a video library by content — “find all frames that look like this scene,” “find the first time this logo appears,” or “find all shots of people wearing red jackets.”</p>
<section id="key-challenges-in-video-search" class="level3">
<h3 class="anchored" data-anchor-id="key-challenges-in-video-search" id="key-challenges-in-video-search">Key Challenges in Video Search</h3>
<ol type="1">
<li><strong>Temporal redundancy</strong>: consecutive frames are nearly identical, so embedding every single frame mostly adds redundant vectors.</li>
<li><strong>Scale</strong>: a 1-hour video at 30 fps has 108,000 frames, and a large video library can easily reach billions of frames.</li>
<li><strong>Efficient storage</strong>: each vector needs just enough metadata to locate the exact frame it came from (video ID, timestamp, frame index).</li>
</ol>
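<p>To put numbers on the scale problem, the sketch below estimates how many vectors, and how much raw vector storage, a video library produces with and without sampling. It assumes 30 fps video and 512-dimensional float32 embeddings (matching <code>FRAME_EMBEDDING_DIM</code> in the schema later in this section), uses a hypothetical 10,000-hour library, and ignores index and metadata overhead, so treat the results as lower bounds.</p>

```python
# Rough sizing for a frame-embedding index. Illustrative assumptions:
# 30 fps source video, 512-d float32 embeddings, raw vectors only
# (no HNSW graph, scalar fields, or replicas counted).

SECONDS_PER_HOUR = 3600
BYTES_PER_FLOAT32 = 4


def frames_per_hour(fps: float) -> int:
    """Decoded frames in one hour of video."""
    return int(fps * SECONDS_PER_HOUR)


def vectors_per_hour(sample_every_n_seconds: float) -> int:
    """Embeddings stored per hour when keeping one frame every N seconds."""
    return int(SECONDS_PER_HOUR / sample_every_n_seconds)


def raw_vector_bytes(num_vectors: int, dim: int = 512) -> int:
    """Bytes for the raw float32 vectors alone."""
    return num_vectors * dim * BYTES_PER_FLOAT32


if __name__ == "__main__":
    hours = 10_000  # hypothetical library size
    every_frame = frames_per_hour(30) * hours  # 1,080,000,000 vectors
    sampled = vectors_per_hour(1.0) * hours    # 36,000,000 vectors
    print(f"every frame:   {every_frame:,} vectors, "
          f"{raw_vector_bytes(every_frame) / 1e12:.1f} TB raw")
    print(f"1 frame / sec: {sampled:,} vectors, "
          f"{raw_vector_bytes(sampled) / 1e9:.1f} GB raw")
```

<p>Keeping one frame per second divides the vector count by the frame rate, here a factor of 30, which is why the rest of this section samples frames rather than embedding all of them.</p>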
</section>
<section id="frame-sampling-strategies" class="level3">
<h3 class="anchored" data-anchor-id="frame-sampling-strategies" id="frame-sampling-strategies">Frame Sampling Strategies</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb61"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb61-1"><a href="#cb61-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> get_keyframe_indices(total_frames: <span class="bu">int</span>, fps: <span class="bu">float</span>, strategy: <span class="bu">str</span> <span class="op">=</span> <span class="st">"every_n_seconds"</span>, interval: <span class="bu">float</span> <span class="op">=</span> <span class="fl">1.0</span>):</span>
<span id="cb61-2"><a href="#cb61-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb61-3"><a href="#cb61-3" aria-hidden="true" tabindex="-1"></a><span class="co">    Returns frame indices to sample based on the chosen strategy.</span></span>
<span id="cb61-4"><a href="#cb61-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb61-5"><a href="#cb61-5" aria-hidden="true" tabindex="-1"></a><span class="co">    Strategies:</span></span>
<span id="cb61-6"><a href="#cb61-6" aria-hidden="true" tabindex="-1"></a><span class="co">    - "every_n_seconds": sample one frame every N seconds (N = interval)</span></span>
<span id="cb61-7"><a href="#cb61-7" aria-hidden="true" tabindex="-1"></a><span class="co">    - "every_n_frames": sample every Nth frame (N = int(interval))</span></span>
<span id="cb61-8"><a href="#cb61-8" aria-hidden="true" tabindex="-1"></a><span class="co">    - "uniform": sample N evenly spaced frames (N = int(interval))</span></span>
<span id="cb61-9"><a href="#cb61-9" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb61-10"><a href="#cb61-10" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> strategy <span class="op">==</span> <span class="st">"every_n_seconds"</span>:</span>
<span id="cb61-11"><a href="#cb61-11" aria-hidden="true" tabindex="-1"></a>        step <span class="op">=</span> <span class="bu">max</span>(<span class="dv">1</span>, <span class="bu">round</span>(fps <span class="op">*</span> interval))  <span class="co"># round, not int: 29.97 fps gives a step of 30, not 29</span></span>
<span id="cb61-12"><a href="#cb61-12" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="bu">list</span>(<span class="bu">range</span>(<span class="dv">0</span>, total_frames, step))</span>
<span id="cb61-13"><a href="#cb61-13" aria-hidden="true" tabindex="-1"></a>    <span class="cf">elif</span> strategy <span class="op">==</span> <span class="st">"every_n_frames"</span>:</span>
<span id="cb61-14"><a href="#cb61-14" aria-hidden="true" tabindex="-1"></a>        step <span class="op">=</span> <span class="bu">max</span>(<span class="dv">1</span>, <span class="bu">int</span>(interval))</span>
<span id="cb61-15"><a href="#cb61-15" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="bu">list</span>(<span class="bu">range</span>(<span class="dv">0</span>, total_frames, step))</span>
<span id="cb61-16"><a href="#cb61-16" aria-hidden="true" tabindex="-1"></a>    <span class="cf">elif</span> strategy <span class="op">==</span> <span class="st">"uniform"</span>:</span>
<span id="cb61-17"><a href="#cb61-17" aria-hidden="true" tabindex="-1"></a>        n_samples <span class="op">=</span> <span class="bu">int</span>(interval)</span>
<span id="cb61-18"><a href="#cb61-18" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> n_samples <span class="op">&gt;=</span> total_frames:</span>
<span id="cb61-19"><a href="#cb61-19" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> <span class="bu">list</span>(<span class="bu">range</span>(total_frames))</span>
<span id="cb61-20"><a href="#cb61-20" aria-hidden="true" tabindex="-1"></a>        step <span class="op">=</span> total_frames <span class="op">/</span> n_samples</span>
<span id="cb61-21"><a href="#cb61-21" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> [<span class="bu">int</span>(i <span class="op">*</span> step) <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(n_samples)]</span>
<span id="cb61-22"><a href="#cb61-22" aria-hidden="true" tabindex="-1"></a>    <span class="cf">else</span>:</span>
<span id="cb61-23"><a href="#cb61-23" aria-hidden="true" tabindex="-1"></a>        <span class="cf">raise</span> <span class="pp">ValueError</span>(<span class="ss">f"Unknown strategy: </span><span class="sc">{</span>strategy<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
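<p>The OpenCV example in the <code>extract_frame</code> docstring further down reopens the file and seeks with <code>CAP_PROP_POS_FRAMES</code> for every sampled frame. Seeking typically means restarting decoding from an earlier keyframe, so on long videos it can be slow, and with some codecs the landed position may not be exact. An alternative is to decode once from start to finish and keep only the indices that <code>get_keyframe_indices</code> returns. This is a minimal sketch: <code>select_frames</code> and <code>opencv_frames</code> are hypothetical helper names, and OpenCV is only required if you call <code>opencv_frames</code>.</p>

```python
from typing import Iterable, Iterator, Tuple, TypeVar

Frame = TypeVar("Frame")


def select_frames(
    frames: Iterable[Frame], wanted_indices: Iterable[int]
) -> Iterator[Tuple[int, Frame]]:
    """Yield (index, frame) for each wanted index in one sequential pass.

    `frames` can be any iterable of decoded frames in playback order, so no
    seeking is needed. Decoding stops once the last wanted index is reached.
    """
    # Negative indices can never match; sort and de-duplicate the rest
    wanted = sorted({i for i in wanted_indices if i >= 0})
    if not wanted:
        return
    pos = 0  # position of the next index we are waiting for
    for idx, frame in enumerate(frames):
        if idx == wanted[pos]:
            yield idx, frame
            pos += 1
            if pos == len(wanted):
                return  # nothing left to collect, stop pulling frames


def opencv_frames(video_path: str) -> Iterator:
    """Hypothetical helper: decode a video front to back as RGB arrays."""
    import cv2  # imported lazily so select_frames works without OpenCV

    cap = cv2.VideoCapture(video_path)
    try:
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            yield cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    finally:
        cap.release()
```

<p>With this in place, <code>select_frames(opencv_frames(video_path), frame_indices)</code> could stand in for the per-frame <code>extract_frame</code> calls inside <code>process_video</code>, at the cost of also decoding the frames you skip.</p>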
</section>
<section id="schema-for-video-frames" class="level3">
<h3 class="anchored" data-anchor-id="schema-for-video-frames" id="schema-for-video-frames">Schema for Video Frames</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb62"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb62-1"><a href="#cb62-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> pymilvus <span class="im">import</span> MilvusClient, DataType</span>
<span id="cb62-2"><a href="#cb62-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb62-3"><a href="#cb62-3" aria-hidden="true" tabindex="-1"></a>COLLECTION_NAME <span class="op">=</span> <span class="st">"video_frames"</span></span>
<span id="cb62-4"><a href="#cb62-4" aria-hidden="true" tabindex="-1"></a>FRAME_EMBEDDING_DIM <span class="op">=</span> <span class="dv">512</span></span>
<span id="cb62-5"><a href="#cb62-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb62-6"><a href="#cb62-6" aria-hidden="true" tabindex="-1"></a>client <span class="op">=</span> MilvusClient(<span class="st">"./video_search.db"</span>)</span>
<span id="cb62-7"><a href="#cb62-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb62-8"><a href="#cb62-8" aria-hidden="true" tabindex="-1"></a>schema <span class="op">=</span> client.create_schema(auto_id<span class="op">=</span><span class="va">True</span>, enable_dynamic_field<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb62-9"><a href="#cb62-9" aria-hidden="true" tabindex="-1"></a>schema.add_field(<span class="st">"id"</span>, DataType.INT64, is_primary<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb62-10"><a href="#cb62-10" aria-hidden="true" tabindex="-1"></a>schema.add_field(<span class="st">"embedding"</span>, DataType.FLOAT_VECTOR, dim<span class="op">=</span>FRAME_EMBEDDING_DIM)</span>
<span id="cb62-11"><a href="#cb62-11" aria-hidden="true" tabindex="-1"></a>schema.add_field(<span class="st">"video_id"</span>, DataType.VARCHAR, max_length<span class="op">=</span><span class="dv">64</span>)</span>
<span id="cb62-12"><a href="#cb62-12" aria-hidden="true" tabindex="-1"></a>schema.add_field(<span class="st">"video_path"</span>, DataType.VARCHAR, max_length<span class="op">=</span><span class="dv">1024</span>)</span>
<span id="cb62-13"><a href="#cb62-13" aria-hidden="true" tabindex="-1"></a>schema.add_field(<span class="st">"video_title"</span>, DataType.VARCHAR, max_length<span class="op">=</span><span class="dv">256</span>)</span>
<span id="cb62-14"><a href="#cb62-14" aria-hidden="true" tabindex="-1"></a>schema.add_field(<span class="st">"channel"</span>, DataType.VARCHAR, max_length<span class="op">=</span><span class="dv">128</span>)</span>
<span id="cb62-15"><a href="#cb62-15" aria-hidden="true" tabindex="-1"></a>schema.add_field(<span class="st">"frame_index"</span>, DataType.INT64)</span>
<span id="cb62-16"><a href="#cb62-16" aria-hidden="true" tabindex="-1"></a>schema.add_field(<span class="st">"timestamp_ms"</span>, DataType.INT64)</span>
<span id="cb62-17"><a href="#cb62-17" aria-hidden="true" tabindex="-1"></a>schema.add_field(<span class="st">"fps"</span>, DataType.FLOAT)</span>
<span id="cb62-18"><a href="#cb62-18" aria-hidden="true" tabindex="-1"></a>schema.add_field(<span class="st">"scene_tag"</span>, DataType.VARCHAR, max_length<span class="op">=</span><span class="dv">64</span>)</span>
<span id="cb62-19"><a href="#cb62-19" aria-hidden="true" tabindex="-1"></a>schema.add_field(<span class="st">"has_faces"</span>, DataType.BOOL)</span>
<span id="cb62-20"><a href="#cb62-20" aria-hidden="true" tabindex="-1"></a>schema.add_field(<span class="st">"has_text"</span>, DataType.BOOL)</span>
<span id="cb62-21"><a href="#cb62-21" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb62-22"><a href="#cb62-22" aria-hidden="true" tabindex="-1"></a>index_params <span class="op">=</span> client.prepare_index_params()</span>
<span id="cb62-23"><a href="#cb62-23" aria-hidden="true" tabindex="-1"></a>index_params.add_index(</span>
<span id="cb62-24"><a href="#cb62-24" aria-hidden="true" tabindex="-1"></a>    field_name<span class="op">=</span><span class="st">"embedding"</span>,</span>
<span id="cb62-25"><a href="#cb62-25" aria-hidden="true" tabindex="-1"></a>    index_type<span class="op">=</span><span class="st">"HNSW"</span>,</span>
<span id="cb62-26"><a href="#cb62-26" aria-hidden="true" tabindex="-1"></a>    metric_type<span class="op">=</span><span class="st">"COSINE"</span>,</span>
<span id="cb62-27"><a href="#cb62-27" aria-hidden="true" tabindex="-1"></a>    params<span class="op">=</span>{<span class="st">"M"</span>: <span class="dv">16</span>, <span class="st">"efConstruction"</span>: <span class="dv">200</span>},</span>
<span id="cb62-28"><a href="#cb62-28" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb62-29"><a href="#cb62-29" aria-hidden="true" tabindex="-1"></a>index_params.add_index(field_name<span class="op">=</span><span class="st">"video_id"</span>, index_type<span class="op">=</span><span class="st">"Trie"</span>)</span>
<span id="cb62-30"><a href="#cb62-30" aria-hidden="true" tabindex="-1"></a>index_params.add_index(field_name<span class="op">=</span><span class="st">"channel"</span>, index_type<span class="op">=</span><span class="st">"Trie"</span>)</span>
<span id="cb62-31"><a href="#cb62-31" aria-hidden="true" tabindex="-1"></a>index_params.add_index(field_name<span class="op">=</span><span class="st">"timestamp_ms"</span>, index_type<span class="op">=</span><span class="st">"STL_SORT"</span>)</span>
<span id="cb62-32"><a href="#cb62-32" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb62-33"><a href="#cb62-33" aria-hidden="true" tabindex="-1"></a>client.create_collection(</span>
<span id="cb62-34"><a href="#cb62-34" aria-hidden="true" tabindex="-1"></a>    collection_name<span class="op">=</span>COLLECTION_NAME,</span>
<span id="cb62-35"><a href="#cb62-35" aria-hidden="true" tabindex="-1"></a>    schema<span class="op">=</span>schema,</span>
<span id="cb62-36"><a href="#cb62-36" aria-hidden="true" tabindex="-1"></a>    index_params<span class="op">=</span>index_params,</span>
<span id="cb62-37"><a href="#cb62-37" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
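<p>The schema stores <code>frame_index</code> alongside <code>timestamp_ms</code>, and the integer index is worth keeping. Converting a frame to a truncated millisecond timestamp and back is lossy: at 29.97 fps, frame 30 starts at about 1001.001 ms, which is stored as 1001, and <code>int((1001 / 1000) * 29.97)</code> evaluates to 29. The sketch below uses the same truncating formulas as <code>process_video</code> and <code>find_similar_frames</code> (the helper names are hypothetical). Truncation is the right choice for an arbitrary timestamp, since it picks the frame on screen at that instant, but when the timestamp itself came from <code>timestamp_ms</code>, rounding recovers the original frame.</p>

```python
# Hypothetical helpers mirroring the conversions used in this section.

def frame_to_ms(frame_index: int, fps: float) -> int:
    """Truncating frame -> milliseconds, as in process_video."""
    return int((frame_index / fps) * 1000)


def ms_to_frame_truncate(timestamp_ms: int, fps: float) -> int:
    """Truncating milliseconds -> frame, as in find_similar_frames."""
    return int((timestamp_ms / 1000) * fps)


def ms_to_frame_round(timestamp_ms: int, fps: float) -> int:
    """Rounding inverse for timestamps produced by frame_to_ms.

    Truncating to whole ms loses less than 1 ms, i.e. less than fps/1000
    of a frame, so rounding recovers the frame exactly while fps < 500.
    """
    return round((timestamp_ms / 1000) * fps)


if __name__ == "__main__":
    fps = 29.97
    ms = frame_to_ms(30, fps)
    print(ms)                             # 1001
    print(ms_to_frame_truncate(ms, fps))  # 29, one frame early
    print(ms_to_frame_round(ms, fps))     # 30
```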
</section>
<section id="processing-a-video" class="level3">
<h3 class="anchored" data-anchor-id="processing-a-video" id="processing-a-video">Processing a Video</h3>
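<p><code>process_video</code> stores every sampled frame, so a static shot still produces one near-identical vector per second. One optional refinement, not part of the pipeline shown here, is to drop a sampled frame when its embedding is almost the same as the last frame kept. The sketch assumes embeddings arrive in temporal order; <code>drop_near_duplicates</code> and the 0.95 cosine threshold are illustrative choices to tune on your own data.</p>

```python
import numpy as np


def drop_near_duplicates(embeddings: np.ndarray, threshold: float = 0.95) -> list[int]:
    """Return indices of rows to keep, skipping near-duplicate neighbours.

    A row is kept only if its cosine similarity to the most recently *kept*
    row is below `threshold`. Rows are assumed to be in temporal order.
    """
    if len(embeddings) == 0:
        return []
    # Normalise rows so that a dot product equals cosine similarity
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    unit = embeddings / np.clip(norms, 1e-12, None)

    keep = [0]  # always keep the first sampled frame
    for i in range(1, len(unit)):
        if float(unit[i] @ unit[keep[-1]]) < threshold:
            keep.append(i)
    return keep
```

<p>Comparing against the last kept frame rather than the immediately preceding one means a slow pan that drifts a little each second still produces a new vector once it has moved far enough from the last stored view.</p>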
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb63"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb63-1"><a href="#cb63-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb63-2"><a href="#cb63-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb63-3"><a href="#cb63-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb63-4"><a href="#cb63-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> extract_frame(video_path: <span class="bu">str</span>, frame_index: <span class="bu">int</span>) <span class="op">-&gt;</span> np.ndarray:</span>
<span id="cb63-5"><a href="#cb63-5" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb63-6"><a href="#cb63-6" aria-hidden="true" tabindex="-1"></a><span class="co">    Extract a single frame from a video.</span></span>
<span id="cb63-7"><a href="#cb63-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb63-8"><a href="#cb63-8" aria-hidden="true" tabindex="-1"></a><span class="co">    Example with OpenCV:</span></span>
<span id="cb63-9"><a href="#cb63-9" aria-hidden="true" tabindex="-1"></a><span class="co">        import cv2</span></span>
<span id="cb63-10"><a href="#cb63-10" aria-hidden="true" tabindex="-1"></a><span class="co">        cap = cv2.VideoCapture(video_path)</span></span>
<span id="cb63-11"><a href="#cb63-11" aria-hidden="true" tabindex="-1"></a><span class="co">        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)</span></span>
<span id="cb63-12"><a href="#cb63-12" aria-hidden="true" tabindex="-1"></a><span class="co">        ret, frame = cap.read()</span></span>
<span id="cb63-13"><a href="#cb63-13" aria-hidden="true" tabindex="-1"></a><span class="co">        cap.release()</span></span>
<span id="cb63-14"><a href="#cb63-14" aria-hidden="true" tabindex="-1"></a><span class="co">        if ret:</span></span>
<span id="cb63-15"><a href="#cb63-15" aria-hidden="true" tabindex="-1"></a><span class="co">            return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)</span></span>
<span id="cb63-16"><a href="#cb63-16" aria-hidden="true" tabindex="-1"></a><span class="co">        return None</span></span>
<span id="cb63-17"><a href="#cb63-17" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb63-18"><a href="#cb63-18" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> np.random.randint(<span class="dv">0</span>, <span class="dv">256</span>, (<span class="dv">480</span>, <span class="dv">640</span>, <span class="dv">3</span>), dtype<span class="op">=</span>np.uint8)  <span class="co"># placeholder: random RGB frame so the example runs without a video</span></span>
<span id="cb63-19"><a href="#cb63-19" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb63-20"><a href="#cb63-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb63-21"><a href="#cb63-21" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> embed_frame(frame: np.ndarray) <span class="op">-&gt;</span> np.ndarray:</span>
<span id="cb63-22"><a href="#cb63-22" aria-hidden="true" tabindex="-1"></a>    vec <span class="op">=</span> np.random.randn(FRAME_EMBEDDING_DIM).astype(np.float32)  <span class="co"># placeholder: swap in a real image encoder</span></span>
<span id="cb63-23"><a href="#cb63-23" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> vec <span class="op">/</span> np.linalg.norm(vec)</span>
<span id="cb63-24"><a href="#cb63-24" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb63-25"><a href="#cb63-25" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb63-26"><a href="#cb63-26" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> get_video_metadata(video_path: <span class="bu">str</span>) <span class="op">-&gt;</span> <span class="bu">dict</span>:</span>
<span id="cb63-27"><a href="#cb63-27" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Return the total frame count and FPS of a video.</span></span>
<span id="cb63-28"><a href="#cb63-28" aria-hidden="true" tabindex="-1"></a><span class="co">    Example with OpenCV:</span></span>
<span id="cb63-29"><a href="#cb63-29" aria-hidden="true" tabindex="-1"></a><span class="co">        import cv2</span></span>
<span id="cb63-30"><a href="#cb63-30" aria-hidden="true" tabindex="-1"></a><span class="co">        cap = cv2.VideoCapture(video_path)</span></span>
<span id="cb63-31"><a href="#cb63-31" aria-hidden="true" tabindex="-1"></a><span class="co">        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))</span></span>
<span id="cb63-32"><a href="#cb63-32" aria-hidden="true" tabindex="-1"></a><span class="co">        fps = cap.get(cv2.CAP_PROP_FPS)</span></span>
<span id="cb63-33"><a href="#cb63-33" aria-hidden="true" tabindex="-1"></a><span class="co">        cap.release()</span></span>
<span id="cb63-34"><a href="#cb63-34" aria-hidden="true" tabindex="-1"></a><span class="co">        return {"total_frames": total_frames, "fps": fps}</span></span>
<span id="cb63-35"><a href="#cb63-35" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb63-36"><a href="#cb63-36" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> {<span class="st">"total_frames"</span>: <span class="dv">3000</span>, <span class="st">"fps"</span>: <span class="fl">30.0</span>}  <span class="co"># placeholder: a 100-second clip at 30 fps</span></span>
<span id="cb63-37"><a href="#cb63-37" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb63-38"><a href="#cb63-38" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb63-39"><a href="#cb63-39" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> process_video(</span>
<span id="cb63-40"><a href="#cb63-40" aria-hidden="true" tabindex="-1"></a>    video_path: <span class="bu">str</span>,</span>
<span id="cb63-41"><a href="#cb63-41" aria-hidden="true" tabindex="-1"></a>    video_id: <span class="bu">str</span>,</span>
<span id="cb63-42"><a href="#cb63-42" aria-hidden="true" tabindex="-1"></a>    video_title: <span class="bu">str</span> <span class="op">=</span> <span class="st">""</span>,</span>
<span id="cb63-43"><a href="#cb63-43" aria-hidden="true" tabindex="-1"></a>    channel: <span class="bu">str</span> <span class="op">=</span> <span class="st">""</span>,</span>
<span id="cb63-44"><a href="#cb63-44" aria-hidden="true" tabindex="-1"></a>    sample_every_n_seconds: <span class="bu">float</span> <span class="op">=</span> <span class="fl">1.0</span>,</span>
<span id="cb63-45"><a href="#cb63-45" aria-hidden="true" tabindex="-1"></a>    batch_size: <span class="bu">int</span> <span class="op">=</span> <span class="dv">256</span>,</span>
<span id="cb63-46"><a href="#cb63-46" aria-hidden="true" tabindex="-1"></a>):</span>
<span id="cb63-47"><a href="#cb63-47" aria-hidden="true" tabindex="-1"></a>    meta <span class="op">=</span> get_video_metadata(video_path)</span>
<span id="cb63-48"><a href="#cb63-48" aria-hidden="true" tabindex="-1"></a>    total_frames <span class="op">=</span> meta[<span class="st">"total_frames"</span>]</span>
<span id="cb63-49"><a href="#cb63-49" aria-hidden="true" tabindex="-1"></a>    fps <span class="op">=</span> meta[<span class="st">"fps"</span>]</span>
<span id="cb63-50"><a href="#cb63-50" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb63-51"><a href="#cb63-51" aria-hidden="true" tabindex="-1"></a>    frame_indices <span class="op">=</span> get_keyframe_indices(</span>
<span id="cb63-52"><a href="#cb63-52" aria-hidden="true" tabindex="-1"></a>        total_frames, fps, strategy<span class="op">=</span><span class="st">"every_n_seconds"</span>, interval<span class="op">=</span>sample_every_n_seconds</span>
<span id="cb63-53"><a href="#cb63-53" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb63-54"><a href="#cb63-54" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb63-55"><a href="#cb63-55" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Processing </span><span class="sc">{</span>video_path<span class="sc">}</span><span class="ss"> — sampling </span><span class="sc">{</span><span class="bu">len</span>(frame_indices)<span class="sc">}</span><span class="ss"> frames"</span>)</span>
<span id="cb63-56"><a href="#cb63-56" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb63-57"><a href="#cb63-57" aria-hidden="true" tabindex="-1"></a>    inserted <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb63-58"><a href="#cb63-58" aria-hidden="true" tabindex="-1"></a>    data_buffer <span class="op">=</span> []</span>
<span id="cb63-59"><a href="#cb63-59" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb63-60"><a href="#cb63-60" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> frame_idx <span class="kw">in</span> frame_indices:</span>
<span id="cb63-61"><a href="#cb63-61" aria-hidden="true" tabindex="-1"></a>        frame <span class="op">=</span> extract_frame(video_path, frame_idx)</span>
<span id="cb63-62"><a href="#cb63-62" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> frame <span class="kw">is</span> <span class="va">None</span>:</span>
<span id="cb63-63"><a href="#cb63-63" aria-hidden="true" tabindex="-1"></a>            <span class="cf">continue</span></span>
<span id="cb63-64"><a href="#cb63-64" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb63-65"><a href="#cb63-65" aria-hidden="true" tabindex="-1"></a>        embedding <span class="op">=</span> embed_frame(frame)</span>
<span id="cb63-66"><a href="#cb63-66" aria-hidden="true" tabindex="-1"></a>        timestamp_ms <span class="op">=</span> <span class="bu">int</span>((frame_idx <span class="op">/</span> fps) <span class="op">*</span> <span class="dv">1000</span>)</span>
<span id="cb63-67"><a href="#cb63-67" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb63-68"><a href="#cb63-68" aria-hidden="true" tabindex="-1"></a>        data_buffer.append({</span>
<span id="cb63-69"><a href="#cb63-69" aria-hidden="true" tabindex="-1"></a>            <span class="st">"embedding"</span>: embedding.tolist(),</span>
<span id="cb63-70"><a href="#cb63-70" aria-hidden="true" tabindex="-1"></a>            <span class="st">"video_id"</span>: video_id,</span>
<span id="cb63-71"><a href="#cb63-71" aria-hidden="true" tabindex="-1"></a>            <span class="st">"video_path"</span>: video_path,</span>
<span id="cb63-72"><a href="#cb63-72" aria-hidden="true" tabindex="-1"></a>            <span class="st">"video_title"</span>: video_title,</span>
<span id="cb63-73"><a href="#cb63-73" aria-hidden="true" tabindex="-1"></a>            <span class="st">"channel"</span>: channel,</span>
<span id="cb63-74"><a href="#cb63-74" aria-hidden="true" tabindex="-1"></a>            <span class="st">"frame_index"</span>: frame_idx,</span>
<span id="cb63-75"><a href="#cb63-75" aria-hidden="true" tabindex="-1"></a>            <span class="st">"timestamp_ms"</span>: timestamp_ms,</span>
<span id="cb63-76"><a href="#cb63-76" aria-hidden="true" tabindex="-1"></a>            <span class="st">"fps"</span>: fps,</span>
<span id="cb63-77"><a href="#cb63-77" aria-hidden="true" tabindex="-1"></a>            <span class="st">"has_faces"</span>: <span class="va">False</span>,</span>
<span id="cb63-78"><a href="#cb63-78" aria-hidden="true" tabindex="-1"></a>            <span class="st">"has_text"</span>: <span class="va">False</span>,</span>
<span id="cb63-79"><a href="#cb63-79" aria-hidden="true" tabindex="-1"></a>            <span class="st">"scene_tag"</span>: <span class="st">"unknown"</span>,</span>
<span id="cb63-80"><a href="#cb63-80" aria-hidden="true" tabindex="-1"></a>        })</span>
<span id="cb63-81"><a href="#cb63-81" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb63-82"><a href="#cb63-82" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="bu">len</span>(data_buffer) <span class="op">&gt;=</span> batch_size:</span>
<span id="cb63-83"><a href="#cb63-83" aria-hidden="true" tabindex="-1"></a>            result <span class="op">=</span> client.insert(collection_name<span class="op">=</span>COLLECTION_NAME, data<span class="op">=</span>data_buffer)</span>
<span id="cb63-84"><a href="#cb63-84" aria-hidden="true" tabindex="-1"></a>            inserted <span class="op">+=</span> result[<span class="st">"insert_count"</span>]</span>
<span id="cb63-85"><a href="#cb63-85" aria-hidden="true" tabindex="-1"></a>            data_buffer <span class="op">=</span> []</span>
<span id="cb63-86"><a href="#cb63-86" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f"  Inserted </span><span class="sc">{</span>inserted<span class="sc">}</span><span class="ss"> frames so far..."</span>)</span>
<span id="cb63-87"><a href="#cb63-87" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb63-88"><a href="#cb63-88" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> data_buffer:</span>
<span id="cb63-89"><a href="#cb63-89" aria-hidden="true" tabindex="-1"></a>        result <span class="op">=</span> client.insert(collection_name<span class="op">=</span>COLLECTION_NAME, data<span class="op">=</span>data_buffer)</span>
<span id="cb63-90"><a href="#cb63-90" aria-hidden="true" tabindex="-1"></a>        inserted <span class="op">+=</span> result[<span class="st">"insert_count"</span>]</span>
<span id="cb63-91"><a href="#cb63-91" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb63-92"><a href="#cb63-92" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Done: inserted </span><span class="sc">{</span>inserted<span class="sc">}</span><span class="ss"> frames"</span>)</span>
<span id="cb63-93"><a href="#cb63-93" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> inserted</span>
<span id="cb63-94"><a href="#cb63-94" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb63-95"><a href="#cb63-95" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb63-96"><a href="#cb63-96" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> find_similar_frames(</span>
<span id="cb63-97"><a href="#cb63-97" aria-hidden="true" tabindex="-1"></a>    query_image_path: <span class="bu">str</span> <span class="op">=</span> <span class="va">None</span>,</span>
<span id="cb63-98"><a href="#cb63-98" aria-hidden="true" tabindex="-1"></a>    query_video_path: <span class="bu">str</span> <span class="op">=</span> <span class="va">None</span>,</span>
<span id="cb63-99"><a href="#cb63-99" aria-hidden="true" tabindex="-1"></a>    query_timestamp_ms: <span class="bu">int</span> <span class="op">=</span> <span class="va">None</span>,</span>
<span id="cb63-100"><a href="#cb63-100" aria-hidden="true" tabindex="-1"></a>    top_k: <span class="bu">int</span> <span class="op">=</span> <span class="dv">20</span>,</span>
<span id="cb63-101"><a href="#cb63-101" aria-hidden="true" tabindex="-1"></a>    channel_filter: <span class="bu">str</span> <span class="op">=</span> <span class="va">None</span>,</span>
<span id="cb63-102"><a href="#cb63-102" aria-hidden="true" tabindex="-1"></a>    time_range_ms: <span class="bu">tuple</span> <span class="op">=</span> <span class="va">None</span>,</span>
<span id="cb63-103"><a href="#cb63-103" aria-hidden="true" tabindex="-1"></a>) <span class="op">-&gt;</span> <span class="bu">list</span>:</span>
<span id="cb63-104"><a href="#cb63-104" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> query_image_path:</span>
<span id="cb63-105"><a href="#cb63-105" aria-hidden="true" tabindex="-1"></a>        frame <span class="op">=</span> np.random.randint(<span class="dv">0</span>, <span class="dv">256</span>, (<span class="dv">480</span>, <span class="dv">640</span>, <span class="dv">3</span>), dtype<span class="op">=</span>np.uint8)  <span class="co"># placeholder: load and decode query_image_path here</span></span>
<span id="cb63-106"><a href="#cb63-106" aria-hidden="true" tabindex="-1"></a>        query_embedding <span class="op">=</span> embed_frame(frame)</span>
<span id="cb63-107"><a href="#cb63-107" aria-hidden="true" tabindex="-1"></a>    <span class="cf">elif</span> query_video_path <span class="kw">and</span> query_timestamp_ms <span class="kw">is</span> <span class="kw">not</span> <span class="va">None</span>:</span>
<span id="cb63-108"><a href="#cb63-108" aria-hidden="true" tabindex="-1"></a>        meta <span class="op">=</span> get_video_metadata(query_video_path)</span>
<span id="cb63-109"><a href="#cb63-109" aria-hidden="true" tabindex="-1"></a>        frame_idx <span class="op">=</span> <span class="bu">int</span>((query_timestamp_ms <span class="op">/</span> <span class="dv">1000</span>) <span class="op">*</span> meta[<span class="st">"fps"</span>])</span>
<span id="cb63-110"><a href="#cb63-110" aria-hidden="true" tabindex="-1"></a>        frame <span class="op">=</span> extract_frame(query_video_path, frame_idx)</span>
<span id="cb63-111"><a href="#cb63-111" aria-hidden="true" tabindex="-1"></a>        query_embedding <span class="op">=</span> embed_frame(frame)</span>
<span id="cb63-112"><a href="#cb63-112" aria-hidden="true" tabindex="-1"></a>    <span class="cf">else</span>:</span>
<span id="cb63-113"><a href="#cb63-113" aria-hidden="true" tabindex="-1"></a>        <span class="cf">raise</span> <span class="pp">ValueError</span>(<span class="st">"Must provide query_image_path or (query_video_path + query_timestamp_ms)"</span>)</span>
<span id="cb63-114"><a href="#cb63-114" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb63-115"><a href="#cb63-115" aria-hidden="true" tabindex="-1"></a>    filters <span class="op">=</span> []</span>
<span id="cb63-116"><a href="#cb63-116" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> channel_filter:</span>
<span id="cb63-117"><a href="#cb63-117" aria-hidden="true" tabindex="-1"></a>        filters.append(<span class="ss">f"channel == '</span><span class="sc">{</span>channel_filter<span class="sc">}</span><span class="ss">'"</span>)</span>
<span id="cb63-118"><a href="#cb63-118" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> time_range_ms:</span>
<span id="cb63-119"><a href="#cb63-119" aria-hidden="true" tabindex="-1"></a>        start_ms, end_ms <span class="op">=</span> time_range_ms</span>
<span id="cb63-120"><a href="#cb63-120" aria-hidden="true" tabindex="-1"></a>        filters.append(<span class="ss">f"timestamp_ms &gt;= </span><span class="sc">{</span>start_ms<span class="sc">}</span><span class="ss"> AND timestamp_ms &lt;= </span><span class="sc">{</span>end_ms<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb63-121"><a href="#cb63-121" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb63-122"><a href="#cb63-122" aria-hidden="true" tabindex="-1"></a>    results <span class="op">=</span> client.search(</span>
<span id="cb63-123"><a href="#cb63-123" aria-hidden="true" tabindex="-1"></a>        collection_name<span class="op">=</span>COLLECTION_NAME,</span>
<span id="cb63-124"><a href="#cb63-124" aria-hidden="true" tabindex="-1"></a>        data<span class="op">=</span>[query_embedding.tolist()],</span>
<span id="cb63-125"><a href="#cb63-125" aria-hidden="true" tabindex="-1"></a>        limit<span class="op">=</span>top_k,</span>
<span id="cb63-126"><a href="#cb63-126" aria-hidden="true" tabindex="-1"></a>        <span class="bu">filter</span><span class="op">=</span><span class="st">" AND "</span>.join(filters) <span class="cf">if</span> filters <span class="cf">else</span> <span class="va">None</span>,</span>
<span id="cb63-127"><a href="#cb63-127" aria-hidden="true" tabindex="-1"></a>        search_params<span class="op">=</span>{<span class="st">"ef"</span>: <span class="dv">200</span>},</span>
<span id="cb63-128"><a href="#cb63-128" aria-hidden="true" tabindex="-1"></a>        output_fields<span class="op">=</span>[<span class="st">"video_id"</span>, <span class="st">"video_title"</span>, <span class="st">"video_path"</span>, <span class="st">"frame_index"</span>, <span class="st">"timestamp_ms"</span>, <span class="st">"channel"</span>],</span>
<span id="cb63-129"><a href="#cb63-129" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb63-130"><a href="#cb63-130" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb63-131"><a href="#cb63-131" aria-hidden="true" tabindex="-1"></a>    hits <span class="op">=</span> []</span>
<span id="cb63-132"><a href="#cb63-132" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> hit <span class="kw">in</span> results[<span class="dv">0</span>]:</span>
<span id="cb63-133"><a href="#cb63-133" aria-hidden="true" tabindex="-1"></a>        ts <span class="op">=</span> hit[<span class="st">"entity"</span>][<span class="st">"timestamp_ms"</span>]</span>
<span id="cb63-134"><a href="#cb63-134" aria-hidden="true" tabindex="-1"></a>        hours <span class="op">=</span> ts <span class="op">//</span> <span class="dv">3_600_000</span></span>
<span id="cb63-135"><a href="#cb63-135" aria-hidden="true" tabindex="-1"></a>        minutes <span class="op">=</span> (ts <span class="op">%</span> <span class="dv">3_600_000</span>) <span class="op">//</span> <span class="dv">60_000</span></span>
<span id="cb63-136"><a href="#cb63-136" aria-hidden="true" tabindex="-1"></a>        seconds <span class="op">=</span> (ts <span class="op">%</span> <span class="dv">60_000</span>) <span class="op">/</span> <span class="dv">1000</span></span>
<span id="cb63-137"><a href="#cb63-137" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb63-138"><a href="#cb63-138" aria-hidden="true" tabindex="-1"></a>        hits.append({</span>
<span id="cb63-139"><a href="#cb63-139" aria-hidden="true" tabindex="-1"></a>            <span class="st">"video_id"</span>: hit[<span class="st">"entity"</span>][<span class="st">"video_id"</span>],</span>
<span id="cb63-140"><a href="#cb63-140" aria-hidden="true" tabindex="-1"></a>            <span class="st">"video_title"</span>: hit[<span class="st">"entity"</span>][<span class="st">"video_title"</span>],</span>
<span id="cb63-141"><a href="#cb63-141" aria-hidden="true" tabindex="-1"></a>            <span class="st">"frame_index"</span>: hit[<span class="st">"entity"</span>][<span class="st">"frame_index"</span>],</span>
<span id="cb63-142"><a href="#cb63-142" aria-hidden="true" tabindex="-1"></a>            <span class="st">"timestamp_ms"</span>: ts,</span>
<span id="cb63-143"><a href="#cb63-143" aria-hidden="true" tabindex="-1"></a>            <span class="st">"timestamp_str"</span>: <span class="ss">f"</span><span class="sc">{</span>hours<span class="sc">:02d}</span><span class="ss">:</span><span class="sc">{</span>minutes<span class="sc">:02d}</span><span class="ss">:</span><span class="sc">{</span>seconds<span class="sc">:05.2f}</span><span class="ss">"</span>,</span>
<span id="cb63-144"><a href="#cb63-144" aria-hidden="true" tabindex="-1"></a>            <span class="st">"similarity"</span>: hit[<span class="st">"distance"</span>],</span>
<span id="cb63-145"><a href="#cb63-145" aria-hidden="true" tabindex="-1"></a>            <span class="st">"channel"</span>: hit[<span class="st">"entity"</span>][<span class="st">"channel"</span>],</span>
<span id="cb63-146"><a href="#cb63-146" aria-hidden="true" tabindex="-1"></a>        })</span>
<span id="cb63-147"><a href="#cb63-147" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb63-148"><a href="#cb63-148" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> hits</span></code></pre></div></div>
<hr>
</section>
</section>
<section id="sec-partitions" class="level2">
<h2 class="anchored" data-anchor-id="sec-partitions" id="sec-partitions">15. Partitions, Filtering, and Hybrid Search</h2>
<section id="partitions" class="level3">
<h3 class="anchored" data-anchor-id="partitions" id="partitions">Partitions</h3>
<p>Partitions are logical subdivisions within a collection. Passing <code>partition_names</code> to a search restricts it to those partitions, so less data is scanned and queries can be noticeably faster when you know in advance which partition holds the relevant records.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb64"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb64-1"><a href="#cb64-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Create partitions (e.g., by year for a video archive)</span></span>
<span id="cb64-2"><a href="#cb64-2" aria-hidden="true" tabindex="-1"></a>client.create_partition(collection_name<span class="op">=</span><span class="st">"video_frames"</span>, partition_name<span class="op">=</span><span class="st">"2023"</span>)</span>
<span id="cb64-3"><a href="#cb64-3" aria-hidden="true" tabindex="-1"></a>client.create_partition(collection_name<span class="op">=</span><span class="st">"video_frames"</span>, partition_name<span class="op">=</span><span class="st">"2024"</span>)</span>
<span id="cb64-4"><a href="#cb64-4" aria-hidden="true" tabindex="-1"></a>client.create_partition(collection_name<span class="op">=</span><span class="st">"video_frames"</span>, partition_name<span class="op">=</span><span class="st">"2025"</span>)</span>
<span id="cb64-5"><a href="#cb64-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb64-6"><a href="#cb64-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Insert into a specific partition</span></span>
<span id="cb64-7"><a href="#cb64-7" aria-hidden="true" tabindex="-1"></a>client.insert(</span>
<span id="cb64-8"><a href="#cb64-8" aria-hidden="true" tabindex="-1"></a>    collection_name<span class="op">=</span><span class="st">"video_frames"</span>,</span>
<span id="cb64-9"><a href="#cb64-9" aria-hidden="true" tabindex="-1"></a>    data<span class="op">=</span>my_data_2024,</span>
<span id="cb64-10"><a href="#cb64-10" aria-hidden="true" tabindex="-1"></a>    partition_name<span class="op">=</span><span class="st">"2024"</span>,</span>
<span id="cb64-11"><a href="#cb64-11" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb64-12"><a href="#cb64-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb64-13"><a href="#cb64-13" aria-hidden="true" tabindex="-1"></a><span class="co"># Search only in the "2024" partition</span></span>
<span id="cb64-14"><a href="#cb64-14" aria-hidden="true" tabindex="-1"></a>results <span class="op">=</span> client.search(</span>
<span id="cb64-15"><a href="#cb64-15" aria-hidden="true" tabindex="-1"></a>    collection_name<span class="op">=</span><span class="st">"video_frames"</span>,</span>
<span id="cb64-16"><a href="#cb64-16" aria-hidden="true" tabindex="-1"></a>    data<span class="op">=</span>[query_vector],</span>
<span id="cb64-17"><a href="#cb64-17" aria-hidden="true" tabindex="-1"></a>    limit<span class="op">=</span><span class="dv">10</span>,</span>
<span id="cb64-18"><a href="#cb64-18" aria-hidden="true" tabindex="-1"></a>    partition_names<span class="op">=</span>[<span class="st">"2024"</span>],</span>
<span id="cb64-19"><a href="#cb64-19" aria-hidden="true" tabindex="-1"></a>    search_params<span class="op">=</span>{<span class="st">"ef"</span>: <span class="dv">100</span>},</span>
<span id="cb64-20"><a href="#cb64-20" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb64-21"><a href="#cb64-21" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb64-22"><a href="#cb64-22" aria-hidden="true" tabindex="-1"></a><span class="co"># Search across multiple partitions</span></span>
<span id="cb64-23"><a href="#cb64-23" aria-hidden="true" tabindex="-1"></a>results <span class="op">=</span> client.search(</span>
<span id="cb64-24"><a href="#cb64-24" aria-hidden="true" tabindex="-1"></a>    collection_name<span class="op">=</span><span class="st">"video_frames"</span>,</span>
<span id="cb64-25"><a href="#cb64-25" aria-hidden="true" tabindex="-1"></a>    data<span class="op">=</span>[query_vector],</span>
<span id="cb64-26"><a href="#cb64-26" aria-hidden="true" tabindex="-1"></a>    limit<span class="op">=</span><span class="dv">10</span>,</span>
<span id="cb64-27"><a href="#cb64-27" aria-hidden="true" tabindex="-1"></a>    partition_names<span class="op">=</span>[<span class="st">"2024"</span>, <span class="st">"2025"</span>],</span>
<span id="cb64-28"><a href="#cb64-28" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
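<p>Because each <code>insert</code> call targets a single <code>partition_name</code>, mixed records usually have to be grouped by partition first. A minimal sketch, assuming each record carries a hypothetical <code>recorded_at_ms</code> field (Unix epoch milliseconds) and that partitions are named by UTC year as above; <code>group_by_year_partition</code> is an illustrative helper, not a pymilvus API:</p>

```python
from collections import defaultdict
from datetime import datetime, timezone


def group_by_year_partition(records, time_field="recorded_at_ms"):
    """Group records into {partition_name: [records]} by UTC year.

    Hypothetical helper: assumes each record has an integer Unix-epoch
    millisecond field (named by time_field) and that partitions are
    named after the year, as in "2023", "2024", "2025" above.
    """
    groups = defaultdict(list)
    for record in records:
        moment = datetime.fromtimestamp(record[time_field] / 1000, tz=timezone.utc)
        groups[str(moment.year)].append(record)
    return dict(groups)


# One insert call per partition:
# for partition_name, rows in group_by_year_partition(all_records).items():
#     client.insert(collection_name="video_frames", data=rows, partition_name=partition_name)
```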
<p><strong>Partition Design Guidelines:</strong></p>
<ul>
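<li>Use partitions for a small number of coarse, low-cardinality splits that queries routinely target (year, region, dataset split)</li>
<li>Keep the partition count modest: Milvus enforces a per-collection limit (configurable via <code>rootCoord.maxPartitionNum</code>), and many small partitions add overhead</li>
<li>For high-cardinality keys such as <code>user_id</code> or <code>camera_id</code>, prefer a partition-key field (<code>is_partition_key=True</code>) or a scalar filter over one manual partition per value</li>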
</ul>
</section>
<section id="advanced-filter-expressions" class="level3">
<h3 class="anchored" data-anchor-id="advanced-filter-expressions" id="advanced-filter-expressions">Advanced Filter Expressions</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb65"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb65-1"><a href="#cb65-1" aria-hidden="true" tabindex="-1"></a><span class="co"># String operations</span></span>
<span id="cb65-2"><a href="#cb65-2" aria-hidden="true" tabindex="-1"></a><span class="bu">filter</span><span class="op">=</span><span class="st">"label in ['dog', 'cat', 'bird']"</span></span>
<span id="cb65-3"><a href="#cb65-3" aria-hidden="true" tabindex="-1"></a><span class="bu">filter</span><span class="op">=</span><span class="st">"image_path like '/dataset/train/%'"</span></span>
<span id="cb65-4"><a href="#cb65-4" aria-hidden="true" tabindex="-1"></a><span class="bu">filter</span><span class="op">=</span><span class="st">"NOT (label in ['background', 'unknown'])"</span></span>
<span id="cb65-5"><a href="#cb65-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb65-6"><a href="#cb65-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Numeric comparisons</span></span>
<span id="cb65-7"><a href="#cb65-7" aria-hidden="true" tabindex="-1"></a><span class="bu">filter</span><span class="op">=</span><span class="st">"confidence &gt; 0.85 AND detection_score &lt; 0.99"</span></span>
<span id="cb65-8"><a href="#cb65-8" aria-hidden="true" tabindex="-1"></a><span class="bu">filter</span><span class="op">=</span><span class="st">"width &gt;= 1920 AND height &gt;= 1080"</span></span>
<span id="cb65-9"><a href="#cb65-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb65-10"><a href="#cb65-10" aria-hidden="true" tabindex="-1"></a><span class="co"># JSON field access</span></span>
<span id="cb65-11"><a href="#cb65-11" aria-hidden="true" tabindex="-1"></a><span class="bu">filter</span><span class="op">=</span><span class="st">"metadata['camera_id'] == 'cam_01'"</span></span>
<span id="cb65-12"><a href="#cb65-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb65-13"><a href="#cb65-13" aria-hidden="true" tabindex="-1"></a><span class="co"># Combining conditions</span></span>
<span id="cb65-14"><a href="#cb65-14" aria-hidden="true" tabindex="-1"></a><span class="bu">filter</span><span class="op">=</span><span class="st">"(label == 'dog' OR label == 'cat') AND confidence &gt; 0.9 AND dataset_split == 'train'"</span></span>
<span id="cb65-15"><a href="#cb65-15" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb65-16"><a href="#cb65-16" aria-hidden="true" tabindex="-1"></a><span class="co"># Array containment</span></span>
<span id="cb65-17"><a href="#cb65-17" aria-hidden="true" tabindex="-1"></a><span class="bu">filter</span><span class="op">=</span><span class="st">"ARRAY_CONTAINS(tags, 'outdoor')"</span></span></code></pre></div></div>
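<p>The search function earlier in this post interpolates <code>channel_filter</code> directly into an f-string, so a value containing a quote produces an invalid or altered expression. A minimal sketch of safer composition, assuming Milvus filter strings accept backslash-escaped quotes inside double-quoted literals (worth checking against your Milvus version); <code>quote</code>, <code>in_clause</code>, and <code>and_all</code> are hypothetical helpers, not pymilvus APIs:</p>

```python
def quote(value: str) -> str:
    """Return value as a double-quoted filter string literal, escaping backslashes and quotes."""
    escaped = value.replace("\\", "\\\\").replace('"', '\\"')
    return f'"{escaped}"'


def in_clause(field: str, values: list) -> str:
    """Build a `field in [...]` membership clause from a list of strings."""
    return f"{field} in [{', '.join(quote(v) for v in values)}]"


def and_all(clauses: list) -> str:
    """AND together non-empty clauses, parenthesizing each so an OR cannot leak across."""
    return " AND ".join(f"({c})" for c in clauses if c)


channel = 'news "live"'  # a value that would break a naive f-string filter
expr = and_all([in_clause("label", ["dog", "cat"]), f"channel == {quote(channel)}", "confidence > 0.9"])
print(expr)
# (label in ["dog", "cat"]) AND (channel == "news \"live\"") AND (confidence > 0.9)
```

<p>Parenthesizing every clause costs nothing and keeps a caller-supplied <code>OR</code> from silently widening the whole filter.</p>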
</section>
<section id="hybrid-search-vector-full-text-search" class="level3">
<h3 class="anchored" data-anchor-id="hybrid-search-vector-full-text-search" id="hybrid-search-vector-full-text-search">Hybrid Search (Vector + Full-Text Search)</h3>
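<p>Milvus supports hybrid search (multi-vector search, added in Milvus 2.4): several ANN requests, for example one over a dense embedding and one over a sparse BM25 or learned-sparse embedding, run in a single call, and their ranked results are fused by a reranker such as Reciprocal Rank Fusion (RRF). Milvus 2.5 added built-in BM25 full-text search, which can supply the sparse side:</p>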
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb66"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb66-1"><a href="#cb66-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> pymilvus <span class="im">import</span> MilvusClient, AnnSearchRequest, RRFRanker, WeightedRanker</span>
<span id="cb66-2"><a href="#cb66-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb66-3"><a href="#cb66-3" aria-hidden="true" tabindex="-1"></a>dense_request <span class="op">=</span> AnnSearchRequest(</span>
<span id="cb66-4"><a href="#cb66-4" aria-hidden="true" tabindex="-1"></a>    data<span class="op">=</span>[dense_query_vector],</span>
<span id="cb66-5"><a href="#cb66-5" aria-hidden="true" tabindex="-1"></a>    anns_field<span class="op">=</span><span class="st">"dense_embedding"</span>,</span>
<span id="cb66-6"><a href="#cb66-6" aria-hidden="true" tabindex="-1"></a>    param<span class="op">=</span>{<span class="st">"metric_type"</span>: <span class="st">"COSINE"</span>, <span class="st">"params"</span>: {<span class="st">"ef"</span>: <span class="dv">100</span>}},</span>
<span id="cb66-7"><a href="#cb66-7" aria-hidden="true" tabindex="-1"></a>    limit<span class="op">=</span><span class="dv">20</span>,</span>
<span id="cb66-8"><a href="#cb66-8" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb66-9"><a href="#cb66-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb66-10"><a href="#cb66-10" aria-hidden="true" tabindex="-1"></a>sparse_request <span class="op">=</span> AnnSearchRequest(</span>
<span id="cb66-11"><a href="#cb66-11" aria-hidden="true" tabindex="-1"></a>    data<span class="op">=</span>[sparse_query_vector],</span>
<span id="cb66-12"><a href="#cb66-12" aria-hidden="true" tabindex="-1"></a>    anns_field<span class="op">=</span><span class="st">"sparse_embedding"</span>,</span>
<span id="cb66-13"><a href="#cb66-13" aria-hidden="true" tabindex="-1"></a>    param<span class="op">=</span>{<span class="st">"metric_type"</span>: <span class="st">"IP"</span>, <span class="st">"params"</span>: {}},</span>
<span id="cb66-14"><a href="#cb66-14" aria-hidden="true" tabindex="-1"></a>    limit<span class="op">=</span><span class="dv">20</span>,</span>
<span id="cb66-15"><a href="#cb66-15" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb66-16"><a href="#cb66-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb66-17"><a href="#cb66-17" aria-hidden="true" tabindex="-1"></a>results <span class="op">=</span> client.hybrid_search(</span>
<span id="cb66-18"><a href="#cb66-18" aria-hidden="true" tabindex="-1"></a>    collection_name<span class="op">=</span><span class="st">"multimodal_index"</span>,</span>
<span id="cb66-19"><a href="#cb66-19" aria-hidden="true" tabindex="-1"></a>    reqs<span class="op">=</span>[dense_request, sparse_request],</span>
<span id="cb66-20"><a href="#cb66-20" aria-hidden="true" tabindex="-1"></a>    ranker<span class="op">=</span>RRFRanker(k<span class="op">=</span><span class="dv">60</span>),  <span class="co"># or WeightedRanker(0.7, 0.3) to weight dense vs. sparse scores</span></span>
<span id="cb66-21"><a href="#cb66-21" aria-hidden="true" tabindex="-1"></a>    limit<span class="op">=</span><span class="dv">10</span>,</span>
<span id="cb66-22"><a href="#cb66-22" aria-hidden="true" tabindex="-1"></a>    output_fields<span class="op">=</span>[<span class="st">"image_path"</span>, <span class="st">"caption"</span>],</span>
<span id="cb66-23"><a href="#cb66-23" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
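<p>The standard RRF formula scores each document by summing 1 / (k + rank) over every result list it appears in, with ranks starting at 1. Documents ranked well by several retrievers rise to the top, and a larger <code>k</code> (60 above) flattens the advantage of the very top ranks. A pure-Python sketch to show the arithmetic; it illustrates the formula and is not Milvus's internal implementation:</p>

```python
def rrf_fuse(ranked_lists, k=60):
    """Fuse best-first lists of IDs with Reciprocal Rank Fusion.

    Each document scores sum(1 / (k + rank)) over the lists containing it,
    with rank starting at 1. Returns (id, score) pairs, best first.
    """
    scores = {}
    for ranked in ranked_lists:
        for rank, doc_id in enumerate(ranked, start=1):
            scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (k + rank)
    return sorted(scores.items(), key=lambda item: item[1], reverse=True)


dense_hits = ["a", "b", "c"]   # hypothetical dense-search IDs, best first
sparse_hits = ["b", "c", "d"]  # hypothetical sparse-search IDs, best first
fused = rrf_fuse([dense_hits, sparse_hits], k=60)
print([doc_id for doc_id, _ in fused])  # ['b', 'c', 'a', 'd']
```

<p>Note that <code>b</code> wins without being first in either list: appearing near the top of both beats topping only one.</p>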
<hr>
</section>
</section>
<section id="sec-performance" class="level2">
<h2 class="anchored" data-anchor-id="sec-performance" id="sec-performance">16. Performance Tuning and Best Practices</h2>
<section id="index-parameter-tuning-for-hnsw" class="level3">
<h3 class="anchored" data-anchor-id="index-parameter-tuning-for-hnsw" id="index-parameter-tuning-for-hnsw">16.1 Index Parameter Tuning for HNSW</h3>
<p><strong>Build-time (<code>M</code> and <code>efConstruction</code>):</strong></p>
<table class="caption-top table">
<thead>
<tr class="header">
<th>Dataset Size</th>
<th>M</th>
<th>efConstruction</th>
<th>Build Time</th>
<th>Memory</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>&lt; 1M vectors</td>
<td>8</td>
<td>100</td>
<td>Fast</td>
<td>Low</td>
</tr>
<tr class="even">
<td>1M–10M</td>
<td>16</td>
<td>200</td>
<td>Moderate</td>
<td>Moderate</td>
</tr>
<tr class="odd">
<td>10M–100M</td>
<td>16–32</td>
<td>200–400</td>
<td>Slow</td>
<td>High</td>
</tr>
<tr class="even">
<td>100M+</td>
<td>16</td>
<td>200</td>
<td>Very slow</td>
<td>Very high</td>
</tr>
</tbody>
</table>
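<p>To put the Memory column in rough numbers: a common back-of-the-envelope estimate for HNSW over float32 vectors is about <code>4·d + 8·M</code> bytes per vector, i.e. the raw vector plus roughly <code>2·M</code> 4-byte neighbor IDs on the bottom graph layer. A sketch of that estimate, assuming 512-dimensional embeddings for the example; it ignores upper layers, scalar fields, and Milvus-specific overhead, so treat the result as a lower bound:</p>

```python
def hnsw_memory_estimate_gib(num_vectors, dim, m, bytes_per_dim=4):
    """Rough lower-bound memory estimate (GiB) for an HNSW index.

    Per vector: the raw vector (dim * bytes_per_dim bytes; 4 for float32)
    plus about 2*M neighbor IDs of 4 bytes each on the bottom graph layer.
    Upper graph layers, scalar fields, and server overhead are ignored.
    """
    bytes_per_vector = dim * bytes_per_dim + 2 * m * 4
    return num_vectors * bytes_per_vector / 1024**3


# Example: 512-dimensional float32 embeddings at the M values from the table
for n, m in [(1_000_000, 8), (10_000_000, 16), (100_000_000, 16)]:
    print(f"{n:,} vectors, M={m}: ~{hnsw_memory_estimate_gib(n, 512, m):.1f} GiB")
```

<p>For typical embedding sizes the raw vectors dominate, which is why raising <code>M</code> changes memory less than the table's qualitative labels might suggest.</p>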
<p><strong>Query-time (<code>ef</code>):</strong></p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb67"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb67-1"><a href="#cb67-1" aria-hidden="true" tabindex="-1"></a><span class="co"># ef must be &gt;= limit (top_k)</span></span>
<span id="cb67-2"><a href="#cb67-2" aria-hidden="true" tabindex="-1"></a>search_params <span class="op">=</span> {<span class="st">"ef"</span>: <span class="dv">50</span>}    <span class="co"># Fast, lower recall</span></span>
<span id="cb67-3"><a href="#cb67-3" aria-hidden="true" tabindex="-1"></a>search_params <span class="op">=</span> {<span class="st">"ef"</span>: <span class="dv">100</span>}   <span class="co"># Balanced (recommended starting point)</span></span>
<span id="cb67-4"><a href="#cb67-4" aria-hidden="true" tabindex="-1"></a>search_params <span class="op">=</span> {<span class="st">"ef"</span>: <span class="dv">500</span>}   <span class="co"># High recall</span></span>
<span id="cb67-5"><a href="#cb67-5" aria-hidden="true" tabindex="-1"></a>search_params <span class="op">=</span> {<span class="st">"ef"</span>: <span class="dv">2000</span>}  <span class="co"># Maximum recall (approaching FLAT accuracy)</span></span></code></pre></div></div>
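<p>A practical way to pick <code>ef</code> is to measure recall@k on a sample of queries against exact results (for example from a FLAT index or a brute-force NumPy search) and raise <code>ef</code> until recall reaches an acceptable level. A minimal sketch of the metric, assuming you already have the approximate and exact top-k ID lists for each query; <code>recall_at_k</code> is a hypothetical helper:</p>

```python
def recall_at_k(approx_ids, exact_ids, k):
    """Mean fraction of each query's exact top-k IDs that the ANN search also returned.

    approx_ids / exact_ids: one best-first list of IDs per query, in the same query order.
    """
    if len(approx_ids) != len(exact_ids) or not exact_ids:
        raise ValueError("need the same, non-zero number of approximate and exact result lists")
    total = 0.0
    for approx, exact in zip(approx_ids, exact_ids):
        truth = set(exact[:k])
        total += len(truth.intersection(approx[:k])) / len(truth)
    return total / len(exact_ids)


# Hypothetical IDs for two queries: each ANN result misses one true neighbor
approx = [[1, 2, 3], [7, 8, 10]]
exact = [[1, 2, 4], [7, 8, 9]]
print(recall_at_k(approx, exact, k=3))  # 0.6666666666666666
```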
</section>
<section id="batch-insertion-performance" class="level3">
<h3 class="anchored" data-anchor-id="batch-insertion-performance" id="batch-insertion-performance">16.2 Batch Insertion Performance</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb68"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb68-1"><a href="#cb68-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Bad: one insert call (and one network round trip) per record</span></span>
<span id="cb68-2"><a href="#cb68-2" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> record <span class="kw">in</span> all_records:</span>
<span id="cb68-3"><a href="#cb68-3" aria-hidden="true" tabindex="-1"></a>    client.insert(collection_name<span class="op">=</span><span class="st">"..."</span>, data<span class="op">=</span>[record])  <span class="co"># Very slow!</span></span>
<span id="cb68-4"><a href="#cb68-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb68-5"><a href="#cb68-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Good: insert in batches of a few thousand rows</span></span>
<span id="cb68-6"><a href="#cb68-6" aria-hidden="true" tabindex="-1"></a>BATCH_SIZE <span class="op">=</span> <span class="dv">2000</span></span>
<span id="cb68-7"><a href="#cb68-7" aria-hidden="true" tabindex="-1"></a>batches <span class="op">=</span> [all_records[i:i <span class="op">+</span> BATCH_SIZE] <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">0</span>, <span class="bu">len</span>(all_records), BATCH_SIZE)]</span>
<span id="cb68-8"><a href="#cb68-8" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> batch <span class="kw">in</span> batches:</span>
<span id="cb68-9"><a href="#cb68-9" aria-hidden="true" tabindex="-1"></a>    client.insert(collection_name<span class="op">=</span><span class="st">"..."</span>, data<span class="op">=</span>batch)</span>
<span id="cb68-10"><a href="#cb68-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb68-11"><a href="#cb68-11" aria-hidden="true" tabindex="-1"></a><span class="co"># Even better: embed and insert batches in parallel worker threads</span></span>
<span id="cb68-12"><a href="#cb68-12" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> concurrent.futures <span class="im">import</span> ThreadPoolExecutor, as_completed</span>
<span id="cb68-13"><a href="#cb68-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb68-14"><a href="#cb68-14" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> embed_and_insert(batch):</span>
<span id="cb68-15"><a href="#cb68-15" aria-hidden="true" tabindex="-1"></a>    embeddings <span class="op">=</span> embed_batch([r[<span class="st">"path"</span>] <span class="cf">for</span> r <span class="kw">in</span> batch])</span>
<span id="cb68-16"><a href="#cb68-16" aria-hidden="true" tabindex="-1"></a>    data <span class="op">=</span> [{<span class="st">"embedding"</span>: emb, <span class="op">**</span>meta} <span class="cf">for</span> emb, meta <span class="kw">in</span> <span class="bu">zip</span>(embeddings, batch)]</span>
<span id="cb68-17"><a href="#cb68-17" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> client.insert(collection_name<span class="op">=</span><span class="st">"..."</span>, data<span class="op">=</span>data)</span>
<span id="cb68-18"><a href="#cb68-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb68-19"><a href="#cb68-19" aria-hidden="true" tabindex="-1"></a><span class="cf">with</span> ThreadPoolExecutor(max_workers<span class="op">=</span><span class="dv">4</span>) <span class="im">as</span> executor:</span>
<span id="cb68-20"><a href="#cb68-20" aria-hidden="true" tabindex="-1"></a>    futures <span class="op">=</span> [executor.submit(embed_and_insert, batch) <span class="cf">for</span> batch <span class="kw">in</span> batches]</span>
<span id="cb68-21"><a href="#cb68-21" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> future <span class="kw">in</span> as_completed(futures):</span>
<span id="cb68-22"><a href="#cb68-22" aria-hidden="true" tabindex="-1"></a>        future.result()  <span class="co"># re-raise any exception raised inside a worker</span></span></code></pre></div></div>
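<p>Batching needs a way to split records into fixed-size groups. When data arrives as a generator or file stream rather than an in-memory list, a helper that works on any iterable keeps memory bounded. A minimal sketch; <code>chunked</code> is a hypothetical local helper (Python 3.12's <code>itertools.batched</code> does the same job but yields tuples):</p>

```python
from itertools import islice


def chunked(iterable, batch_size):
    """Yield lists of up to batch_size items from any iterable (the last may be shorter).

    Works on generators and other streams, not just lists, so the full
    dataset never has to be held in memory at once.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    iterator = iter(iterable)
    while True:
        batch = list(islice(iterator, batch_size))
        if not batch:
            return
        yield batch


# Usage: for batch in chunked(record_stream, batch_size=2000): client.insert(collection_name="...", data=batch)
```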
</section>
<section id="memory-management" class="level3">
<h3 class="anchored" data-anchor-id="memory-management" id="memory-management">16.3 Memory Management</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb69"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb69-1"><a href="#cb69-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Load collection into memory before querying</span></span>
<span id="cb69-2"><a href="#cb69-2" aria-hidden="true" tabindex="-1"></a>client.load_collection(<span class="st">"image_embeddings"</span>)</span>
<span id="cb69-3"><a href="#cb69-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb69-4"><a href="#cb69-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Release collection from memory when not actively querying</span></span>
<span id="cb69-5"><a href="#cb69-5" aria-hidden="true" tabindex="-1"></a>client.release_collection(<span class="st">"image_embeddings"</span>)</span>
<span id="cb69-6"><a href="#cb69-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb69-7"><a href="#cb69-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Load only specific partitions into memory</span></span>
<span id="cb69-8"><a href="#cb69-8" aria-hidden="true" tabindex="-1"></a>client.load_partitions(<span class="st">"image_embeddings"</span>, partition_names<span class="op">=</span>[<span class="st">"2024"</span>])</span></code></pre></div></div>
</section>
<section id="query-cache" class="level3">
<h3 class="anchored" data-anchor-id="query-cache" id="query-cache">16.4 Query Cache</h3>
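<p>For repeated identical queries, you can cache results at the application level; keep the TTL short, because cached results do not reflect inserts or deletes made after they were stored:</p>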
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb70"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb70-1"><a href="#cb70-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> hashlib</span>
<span id="cb70-2"><a href="#cb70-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> json</span>
<span id="cb70-3"><a href="#cb70-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb70-4"><a href="#cb70-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb70-5"><a href="#cb70-5" aria-hidden="true" tabindex="-1"></a>_search_cache <span class="op">=</span> {}  <span class="co"># cache_key -&gt; (result, cached_at); unbounded in this sketch</span></span>
<span id="cb70-6"><a href="#cb70-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb70-7"><a href="#cb70-7" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> cached_search(collection_name, vector, limit, <span class="bu">filter</span><span class="op">=</span><span class="st">""</span>, ttl_seconds<span class="op">=</span><span class="dv">300</span>):</span>
<span id="cb70-8"><a href="#cb70-8" aria-hidden="true" tabindex="-1"></a>    <span class="co"># float() makes NumPy scalars JSON-serializable; rounding keeps the key stable</span></span>
<span id="cb70-9"><a href="#cb70-9" aria-hidden="true" tabindex="-1"></a>    vec_bytes <span class="op">=</span> json.dumps([<span class="bu">round</span>(<span class="bu">float</span>(v), <span class="dv">6</span>) <span class="cf">for</span> v <span class="kw">in</span> vector]).encode()</span>
<span id="cb70-10"><a href="#cb70-10" aria-hidden="true" tabindex="-1"></a>    cache_key <span class="op">=</span> <span class="ss">f"</span><span class="sc">{</span>collection_name<span class="sc">}</span><span class="ss">:</span><span class="sc">{</span>hashlib<span class="sc">.</span>sha256(vec_bytes)<span class="sc">.</span>hexdigest()<span class="sc">}</span><span class="ss">:</span><span class="sc">{</span>limit<span class="sc">}</span><span class="ss">:</span><span class="sc">{</span><span class="bu">filter</span><span class="sc">}</span><span class="ss">"</span></span>
<span id="cb70-11"><a href="#cb70-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb70-12"><a href="#cb70-12" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> cache_key <span class="kw">in</span> _search_cache:</span>
<span id="cb70-13"><a href="#cb70-13" aria-hidden="true" tabindex="-1"></a>        cached_result, cached_at <span class="op">=</span> _search_cache[cache_key]</span>
<span id="cb70-14"><a href="#cb70-14" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> time.time() <span class="op">-</span> cached_at <span class="op">&lt;</span> ttl_seconds:</span>
<span id="cb70-15"><a href="#cb70-15" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> cached_result</span>
<span id="cb70-16"><a href="#cb70-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb70-17"><a href="#cb70-17" aria-hidden="true" tabindex="-1"></a>    result <span class="op">=</span> client.search(</span>
<span id="cb70-18"><a href="#cb70-18" aria-hidden="true" tabindex="-1"></a>        collection_name<span class="op">=</span>collection_name,</span>
<span id="cb70-19"><a href="#cb70-19" aria-hidden="true" tabindex="-1"></a>        data<span class="op">=</span>[vector],</span>
<span id="cb70-20"><a href="#cb70-20" aria-hidden="true" tabindex="-1"></a>        limit<span class="op">=</span>limit,</span>
<span id="cb70-21"><a href="#cb70-21" aria-hidden="true" tabindex="-1"></a>        <span class="bu">filter</span><span class="op">=</span><span class="bu">filter</span>,</span>
<span id="cb70-22"><a href="#cb70-22" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb70-23"><a href="#cb70-23" aria-hidden="true" tabindex="-1"></a>    _search_cache[cache_key] <span class="op">=</span> (result, time.time())</span>
<span id="cb70-24"><a href="#cb70-24" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> result</span></code></pre></div></div>
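<p>A plain dictionary cache grows without bound and never drops expired entries. A minimal sketch of a size-bounded variant that evicts the least recently used entry and recomputes stale ones; <code>TTLLRUCache</code> is a hypothetical class, and the <code>compute</code> callback stands in for the <code>client.search</code> call:</p>

```python
import time
from collections import OrderedDict


class TTLLRUCache:
    """Bounded cache: entries expire after ttl_seconds; least recently used is evicted first."""

    def __init__(self, max_entries=1024, ttl_seconds=300, clock=time.monotonic):
        self.max_entries = max_entries
        self.ttl_seconds = ttl_seconds
        self.clock = clock  # injectable for testing
        self._data = OrderedDict()  # key -> (value, stored_at), oldest use first

    def get_or_compute(self, key, compute):
        now = self.clock()
        entry = self._data.get(key)
        if entry is not None and now - entry[1] < self.ttl_seconds:
            self._data.move_to_end(key)  # mark as most recently used
            return entry[0]
        value = compute()  # cache miss or expired entry
        self._data[key] = (value, now)
        self._data.move_to_end(key)
        while len(self._data) > self.max_entries:
            self._data.popitem(last=False)  # evict the least recently used entry
        return value


# Usage (cache_key built as in cached_search; client assumed to exist):
# cache = TTLLRUCache(max_entries=10_000, ttl_seconds=300)
# result = cache.get_or_compute(cache_key, lambda: client.search(collection_name="...", data=[vector], limit=10))
```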
</section>
<section id="monitoring-query-performance" class="level3">
<h3 class="anchored" data-anchor-id="monitoring-query-performance" id="monitoring-query-performance">16.5 Monitoring Query Performance</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb71"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb71-1"><a href="#cb71-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb71-2"><a href="#cb71-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb71-3"><a href="#cb71-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> timed_search(client, collection_name, query_vector, limit<span class="op">=</span><span class="dv">10</span>, <span class="op">**</span>kwargs):</span>
<span id="cb71-4"><a href="#cb71-4" aria-hidden="true" tabindex="-1"></a>    start <span class="op">=</span> time.perf_counter()</span>
<span id="cb71-5"><a href="#cb71-5" aria-hidden="true" tabindex="-1"></a>    results <span class="op">=</span> client.search(</span>
<span id="cb71-6"><a href="#cb71-6" aria-hidden="true" tabindex="-1"></a>        collection_name<span class="op">=</span>collection_name,</span>
<span id="cb71-7"><a href="#cb71-7" aria-hidden="true" tabindex="-1"></a>        data<span class="op">=</span>[query_vector],</span>
<span id="cb71-8"><a href="#cb71-8" aria-hidden="true" tabindex="-1"></a>        limit<span class="op">=</span>limit,</span>
<span id="cb71-9"><a href="#cb71-9" aria-hidden="true" tabindex="-1"></a>        <span class="op">**</span>kwargs,</span>
<span id="cb71-10"><a href="#cb71-10" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb71-11"><a href="#cb71-11" aria-hidden="true" tabindex="-1"></a>    latency_ms <span class="op">=</span> (time.perf_counter() <span class="op">-</span> start) <span class="op">*</span> <span class="dv">1000</span></span>
<span id="cb71-12"><a href="#cb71-12" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Search latency: </span><span class="sc">{</span>latency_ms<span class="sc">:.2f}</span><span class="ss">ms | Results: </span><span class="sc">{</span><span class="bu">len</span>(results[<span class="dv">0</span>])<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb71-13"><a href="#cb71-13" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> results, latency_ms</span></code></pre></div></div>
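<p>A single timed call is noisy. Early searches can be slower because of cold caches and connection setup, and network jitter varies from call to call. The sketch below is a hypothetical helper, not part of PyMilvus. It times any zero-argument search callable, discards a few warm-up runs, and reports the remaining runs as percentiles in milliseconds.</p>

```python
import time
from statistics import mean, quantiles
from typing import Callable, Dict


def benchmark_latency(search_fn: Callable[[], object], runs: int = 100, warmup: int = 5) -> Dict[str, float]:
    """Time repeated calls to search_fn and summarize latency in milliseconds.

    Hypothetical helper: search_fn is any zero-argument callable that performs one search.
    """
    if runs < 2:
        raise ValueError("need at least 2 timed runs to compute percentiles")
    for _ in range(warmup):
        search_fn()  # warm-up calls run but are not recorded
    samples = []
    for _ in range(runs):
        start = time.perf_counter()
        search_fn()
        samples.append((time.perf_counter() - start) * 1000)
    cuts = quantiles(samples, n=100)  # cuts[k - 1] approximates the k-th percentile
    return {"mean": mean(samples), "p50": cuts[49], "p95": cuts[94], "p99": cuts[98], "max": max(samples)}


# Usage with a real client (not executed here):
# from functools import partial
# stats = benchmark_latency(partial(client.search, collection_name="image_embeddings",
#                                   data=[query_vector], limit=10))
```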
</section>
<section id="schema-design-best-practices" class="level3">
<h3 class="anchored" data-anchor-id="schema-design-best-practices" id="schema-design-best-practices">16.6 Schema Design Best Practices</h3>
<ul>
<li><strong>Minimize the number of fields.</strong> Each additional field adds memory overhead and slows insertion.</li>
<li><strong>Use <code>enable_dynamic_field=True</code> cautiously.</strong> Dynamic fields are stored as JSON internally, which is slower to filter than typed fields.</li>
<li><strong>Use INT64 for timestamps</strong>, not VARCHAR. Numeric comparisons are much faster.</li>
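<li><strong>Normalize your vectors before insertion if you use the IP metric.</strong> Inner product on unnormalized vectors favors large-magnitude vectors over well-aligned ones. COSINE is scale-invariant, so normalization is optional with that metric.</li>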
<li><strong>Choose appropriate VARCHAR lengths.</strong> Don’t set <code>max_length=65535</code> for short strings.</li>
</ul>
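<p>The INT64-timestamp advice can be made concrete with a minimal sketch. It converts timezone-aware datetimes to epoch milliseconds using exact integer arithmetic, then builds a Milvus boolean filter expression for a half-open time range. The field name <code>created_at</code>, the millisecond resolution, and the helper names are assumptions, so adapt them to your schema.</p>

```python
from datetime import datetime, timedelta, timezone

_EPOCH = datetime(1970, 1, 1, tzinfo=timezone.utc)


def to_epoch_ms(dt: datetime) -> int:
    """Convert a timezone-aware datetime to integer milliseconds since the Unix epoch."""
    if dt.tzinfo is None:
        raise ValueError("use timezone-aware datetimes to avoid local-time ambiguity")
    # timedelta // timedelta is exact integer division, so there is no float rounding
    return (dt - _EPOCH) // timedelta(milliseconds=1)


def time_range_filter(field: str, start: datetime, end: datetime) -> str:
    """Build a filter string selecting start <= field < end (hypothetical helper)."""
    start_ms, end_ms = to_epoch_ms(start), to_epoch_ms(end)
    if start_ms >= end_ms:
        raise ValueError("start must be earlier than end")
    return f"{field} >= {start_ms} and {field} < {end_ms}"


expr = time_range_filter(
    "created_at",
    datetime(2024, 1, 1, tzinfo=timezone.utc),
    datetime(2024, 2, 1, tzinfo=timezone.utc),
)
print(expr)  # created_at >= 1704067200000 and created_at < 1706745600000
# Pass it as the filter argument, e.g. client.search(..., filter=expr)
```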
<hr>
</section>
</section>
<section id="sec-security" class="level2">
<h2 class="anchored" data-anchor-id="sec-security" id="sec-security">17. Security and Access Control</h2>
<section id="authentication" class="level3">
<h3 class="anchored" data-anchor-id="authentication" id="authentication">Authentication</h3>
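<p>Enable authentication on your Milvus instance to prevent unauthorized access. Once authentication is on, every client must present a token. The built-in <code>root</code> user starts with the password <code>Milvus</code>, so change it right away, for example with <code>client.update_password()</code>.</p>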
<p><strong>In <code>docker-compose.yml</code>:</strong></p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb72"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb72-1"><a href="#cb72-1" aria-hidden="true" tabindex="-1"></a><span class="fu">standalone</span><span class="kw">:</span></span>
<span id="cb72-2"><a href="#cb72-2" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">environment</span><span class="kw">:</span></span>
<span id="cb72-3"><a href="#cb72-3" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">COMMON_SECURITY_AUTHORIZATIONENABLED</span><span class="kw">:</span><span class="at"> </span><span class="st">"true"</span></span></code></pre></div></div>
<p><strong>In Python:</strong></p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb73"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb73-1"><a href="#cb73-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> pymilvus <span class="im">import</span> MilvusClient</span>
<span id="cb73-2"><a href="#cb73-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb73-3"><a href="#cb73-3" aria-hidden="true" tabindex="-1"></a>client <span class="op">=</span> MilvusClient(</span>
<span id="cb73-4"><a href="#cb73-4" aria-hidden="true" tabindex="-1"></a>    uri<span class="op">=</span><span class="st">"http://localhost:19530"</span>,</span>
<span id="cb73-5"><a href="#cb73-5" aria-hidden="true" tabindex="-1"></a>    token<span class="op">=</span><span class="st">"root:Milvus"</span>,  <span class="co"># default root credentials; change the password after first login</span></span>
<span id="cb73-6"><a href="#cb73-6" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb73-7"><a href="#cb73-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb73-8"><a href="#cb73-8" aria-hidden="true" tabindex="-1"></a><span class="co"># Create a new user</span></span>
<span id="cb73-9"><a href="#cb73-9" aria-hidden="true" tabindex="-1"></a>client.create_user(user_name<span class="op">=</span><span class="st">"cv_app_user"</span>, password<span class="op">=</span><span class="st">"StrongP@ssword123"</span>)</span>
<span id="cb73-10"><a href="#cb73-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb73-11"><a href="#cb73-11" aria-hidden="true" tabindex="-1"></a><span class="co"># Grant an existing role (admin and public are built in; custom roles are covered below)</span></span>
<span id="cb73-12"><a href="#cb73-12" aria-hidden="true" tabindex="-1"></a>client.grant_role(user_name<span class="op">=</span><span class="st">"cv_app_user"</span>, role_name<span class="op">=</span><span class="st">"public"</span>)</span>
<span id="cb73-13"><a href="#cb73-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb73-14"><a href="#cb73-14" aria-hidden="true" tabindex="-1"></a><span class="co"># List users</span></span>
<span id="cb73-15"><a href="#cb73-15" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(client.list_users())</span></code></pre></div></div>
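<p>Hard-coding passwords, as this example does for brevity, is risky outside a notebook. The minimal sketch below assembles the <code>user:password</code> token from environment variables instead. The variable names <code>MILVUS_USER</code> and <code>MILVUS_PASSWORD</code> and the helper name are hypothetical, so substitute whatever your deployment or secret manager injects.</p>

```python
import os


def milvus_token_from_env(user_var: str = "MILVUS_USER", password_var: str = "MILVUS_PASSWORD") -> str:
    """Return a 'user:password' token for MilvusClient(token=...) built from environment variables."""
    user = os.environ.get(user_var)
    password = os.environ.get(password_var)
    missing = [name for name, value in ((user_var, user), (password_var, password)) if not value]
    if missing:
        # Fail fast instead of silently falling back to default credentials such as root:Milvus
        raise RuntimeError(f"missing environment variables: {', '.join(missing)}")
    return f"{user}:{password}"


# Usage (requires a running Milvus with authentication enabled):
# client = MilvusClient(uri="http://localhost:19530", token=milvus_token_from_env())
```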
</section>
<section id="role-based-access-control-rbac" class="level3">
<h3 class="anchored" data-anchor-id="role-based-access-control-rbac" id="role-based-access-control-rbac">Role-Based Access Control (RBAC)</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb74"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb74-1"><a href="#cb74-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Create a custom role</span></span>
<span id="cb74-2"><a href="#cb74-2" aria-hidden="true" tabindex="-1"></a>client.create_role(role_name<span class="op">=</span><span class="st">"cv_readonly"</span>)</span>
<span id="cb74-3"><a href="#cb74-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb74-4"><a href="#cb74-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Grant specific privileges</span></span>
<span id="cb74-5"><a href="#cb74-5" aria-hidden="true" tabindex="-1"></a>client.grant_privilege(</span>
<span id="cb74-6"><a href="#cb74-6" aria-hidden="true" tabindex="-1"></a>    role_name<span class="op">=</span><span class="st">"cv_readonly"</span>,</span>
<span id="cb74-7"><a href="#cb74-7" aria-hidden="true" tabindex="-1"></a>    object_type<span class="op">=</span><span class="st">"Collection"</span>,</span>
<span id="cb74-8"><a href="#cb74-8" aria-hidden="true" tabindex="-1"></a>    object_name<span class="op">=</span><span class="st">"image_embeddings"</span>,</span>
<span id="cb74-9"><a href="#cb74-9" aria-hidden="true" tabindex="-1"></a>    privilege<span class="op">=</span><span class="st">"Query"</span>,</span>
<span id="cb74-10"><a href="#cb74-10" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb74-11"><a href="#cb74-11" aria-hidden="true" tabindex="-1"></a>client.grant_privilege(</span>
<span id="cb74-12"><a href="#cb74-12" aria-hidden="true" tabindex="-1"></a>    role_name<span class="op">=</span><span class="st">"cv_readonly"</span>,</span>
<span id="cb74-13"><a href="#cb74-13" aria-hidden="true" tabindex="-1"></a>    object_type<span class="op">=</span><span class="st">"Collection"</span>,</span>
<span id="cb74-14"><a href="#cb74-14" aria-hidden="true" tabindex="-1"></a>    object_name<span class="op">=</span><span class="st">"image_embeddings"</span>,</span>
<span id="cb74-15"><a href="#cb74-15" aria-hidden="true" tabindex="-1"></a>    privilege<span class="op">=</span><span class="st">"Search"</span>,</span>
<span id="cb74-16"><a href="#cb74-16" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb74-17"><a href="#cb74-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb74-18"><a href="#cb74-18" aria-hidden="true" tabindex="-1"></a><span class="co"># Assign role to user</span></span>
<span id="cb74-19"><a href="#cb74-19" aria-hidden="true" tabindex="-1"></a>client.grant_role(user_name<span class="op">=</span><span class="st">"cv_app_user"</span>, role_name<span class="op">=</span><span class="st">"cv_readonly"</span>)</span></code></pre></div></div>
</section>
<section id="tlsssl-encryption" class="level3">
<h3 class="anchored" data-anchor-id="tlsssl-encryption" id="tlsssl-encryption">TLS/SSL Encryption</h3>
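<p>TLS must also be enabled on the server side, through the <code>tls</code> certificate paths and <code>common.security.tlsMode</code> in <code>milvus.yaml</code>. With one-way TLS, the client only needs the server certificate to verify the connection:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb75"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb75-1"><a href="#cb75-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> pymilvus <span class="im">import</span> MilvusClient</span>
<span id="cb75-2"><a href="#cb75-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb75-3"><a href="#cb75-3" aria-hidden="true" tabindex="-1"></a>client <span class="op">=</span> MilvusClient(</span>
<span id="cb75-4"><a href="#cb75-4" aria-hidden="true" tabindex="-1"></a>    uri<span class="op">=</span><span class="st">"https://milvus.example.com:19530"</span>,</span>
<span id="cb75-5"><a href="#cb75-5" aria-hidden="true" tabindex="-1"></a>    token<span class="op">=</span><span class="st">"username:password"</span>,</span>
<span id="cb75-6"><a href="#cb75-6" aria-hidden="true" tabindex="-1"></a>    secure<span class="op">=</span><span class="va">True</span>,</span>
<span id="cb75-7"><a href="#cb75-7" aria-hidden="true" tabindex="-1"></a>    server_pem_path<span class="op">=</span><span class="st">"/path/to/server.pem"</span>,  <span class="co"># server certificate used for verification</span></span>
<span id="cb75-8"><a href="#cb75-8" aria-hidden="true" tabindex="-1"></a>    server_name<span class="op">=</span><span class="st">"milvus.example.com"</span>,  <span class="co"># must match the name in the certificate</span></span>
<span id="cb75-9"><a href="#cb75-9" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>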
<hr>
</section>
</section>
<section id="sec-monitoring" class="level2">
<h2 class="anchored" data-anchor-id="sec-monitoring" id="sec-monitoring">18. Monitoring and Observability</h2>
<section id="milvus-metrics" class="level3">
<h3 class="anchored" data-anchor-id="milvus-metrics" id="milvus-metrics">Milvus Metrics</h3>
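<p>Milvus exposes Prometheus metrics at <code>http://milvus-host:9091/metrics</code>. The key metrics to monitor are listed below. Metric names can change between releases, so confirm them against the <code>/metrics</code> output of your deployment:</p>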
<table class="caption-top table">
<colgroup>
<col style="width: 33%">
<col style="width: 33%">
<col style="width: 33%">
</colgroup>
<thead>
<tr class="header">
<th>Metric</th>
<th>Description</th>
<th>Alert if</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><code>milvus_proxy_search_latency_sum</code></td>
<td>Search latency (cumulative histogram sum)</td>
<td>p99 &gt; 500 ms (compute p99 from the histogram buckets, not from <code>_sum</code>)</td>
</tr>
<tr class="even">
<td><code>milvus_querynode_collection_num</code></td>
<td>Collections loaded</td>
<td>Higher than expected (memory pressure on query nodes)</td>
</tr>
<tr class="odd">
<td><code>milvus_datanode_insert_buffer_size</code></td>
<td>Insert buffer size</td>
<td>Near limit</td>
</tr>
<tr class="even">
<td><code>milvus_rootcoord_proxy_num</code></td>
<td>Active proxies</td>
<td>Drops to 0</td>
</tr>
<tr class="odd">
<td><code>milvus_segment_count</code></td>
<td>Total segments</td>
<td>Sustained, unexpected growth</td>
</tr>
</tbody>
</table>
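<p>For a Prometheus histogram, <code>_sum</code> divided by <code>_count</code> gives only the mean latency. Percentiles come from the cumulative <code>_bucket</code> series, which Prometheus’s <code>histogram_quantile()</code> interpolates linearly within the bucket that contains the requested rank. The sketch below reproduces that estimate in Python so you can reason about, or unit-test, alert thresholds. It assumes non-negative observations such as latencies. The bucket values in the example are made up for illustration.</p>

```python
import math
from typing import List, Tuple


def histogram_quantile(q: float, buckets: List[Tuple[float, float]]) -> float:
    """Estimate the q-quantile from cumulative histogram buckets, Prometheus-style.

    buckets: (upper_bound, cumulative_count) pairs sorted by upper_bound, ending with
    (math.inf, total_count). Assumes non-negative observations such as latencies.
    """
    if not 0.0 <= q <= 1.0:
        raise ValueError("q must be between 0 and 1")
    total = buckets[-1][1]
    if total == 0:
        return math.nan
    rank = q * total
    prev_bound, prev_count = 0.0, 0.0
    for bound, count in buckets:
        if count >= rank:
            if math.isinf(bound):
                # Rank falls in the +Inf bucket: report the highest finite upper bound
                return prev_bound
            in_bucket = count - prev_count
            if in_bucket == 0:
                return bound
            # Linear interpolation inside the bucket that contains the rank
            return prev_bound + (bound - prev_bound) * (rank - prev_count) / in_bucket
        prev_bound, prev_count = bound, count
    return prev_bound


# Example buckets in milliseconds (made-up values): (upper bound, cumulative count)
example = [(50, 600), (100, 900), (250, 980), (500, 995), (1000, 1000), (math.inf, 1000)]
print(round(histogram_quantile(0.99, example), 1))  # 416.7, below a 500 ms p99 threshold
```

<p>In PromQL the equivalent is <code>histogram_quantile(0.99, sum by (le) (rate(METRIC_bucket[5m])))</code>, where <code>METRIC</code> is a placeholder for the latency histogram your Milvus version exposes.</p>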
</section>
<section id="setting-up-grafana-dashboard" class="level3">
<h3 class="anchored" data-anchor-id="setting-up-grafana-dashboard" id="setting-up-grafana-dashboard">Setting Up Grafana Dashboard</h3>
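<p>Point a Prometheus scrape job at the metrics endpoint above and add that Prometheus server as a Grafana data source. Then load the Milvus dashboard:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb76"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb76-1"><a href="#cb76-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Download the official Milvus dashboard JSON (ID: 19120 on grafana.com),</span></span>
<span id="cb76-2"><a href="#cb76-2" aria-hidden="true" tabindex="-1"></a><span class="co"># then import it in Grafana via Dashboards &gt; New &gt; Import</span></span>
<span id="cb76-3"><a href="#cb76-3" aria-hidden="true" tabindex="-1"></a><span class="fu">wget</span> https://raw.githubusercontent.com/milvus-io/milvus/master/deployments/monitoring/grafana/milvus-dashboard.json</span></code></pre></div></div>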
</section>
<section id="application-level-monitoring" class="level3">
<h3 class="anchored" data-anchor-id="application-level-monitoring" id="application-level-monitoring">Application-Level Monitoring</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb77"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb77-1"><a href="#cb77-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb77-2"><a href="#cb77-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> collections <span class="im">import</span> defaultdict</span>
<span id="cb77-3"><a href="#cb77-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> statistics <span class="im">import</span> mean, quantiles</span>
<span id="cb77-4"><a href="#cb77-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb77-5"><a href="#cb77-5" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MilvusMonitor:</span>
<span id="cb77-6"><a href="#cb77-6" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>):</span>
<span id="cb77-7"><a href="#cb77-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.latencies <span class="op">=</span> defaultdict(<span class="bu">list</span>)</span>
<span id="cb77-8"><a href="#cb77-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.error_counts <span class="op">=</span> defaultdict(<span class="bu">int</span>)</span>
<span id="cb77-9"><a href="#cb77-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb77-10"><a href="#cb77-10" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> record_search(<span class="va">self</span>, collection: <span class="bu">str</span>, latency_ms: <span class="bu">float</span>, success: <span class="bu">bool</span>):</span>
<span id="cb77-11"><a href="#cb77-11" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> success:</span>
<span id="cb77-12"><a href="#cb77-12" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.latencies[collection].append(latency_ms)</span>
<span id="cb77-13"><a href="#cb77-13" aria-hidden="true" tabindex="-1"></a>        <span class="cf">else</span>:</span>
<span id="cb77-14"><a href="#cb77-14" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.error_counts[collection] <span class="op">+=</span> <span class="dv">1</span></span>
<span id="cb77-15"><a href="#cb77-15" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb77-16"><a href="#cb77-16" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> report(<span class="va">self</span>):</span>
<span id="cb77-17"><a href="#cb77-17" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Cover collections that only ever failed, not just those with latencies</span></span>
<span id="cb77-18"><a href="#cb77-18" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> collection <span class="kw">in</span> <span class="bu">sorted</span>(<span class="bu">set</span>(<span class="va">self</span>.latencies) <span class="op">|</span> <span class="bu">set</span>(<span class="va">self</span>.error_counts)):</span>
<span id="cb77-19"><a href="#cb77-19" aria-hidden="true" tabindex="-1"></a>            lats <span class="op">=</span> <span class="va">self</span>.latencies[collection]</span>
<span id="cb77-20"><a href="#cb77-20" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f"Collection: </span><span class="sc">{</span>collection<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb77-21"><a href="#cb77-21" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f"  Searches: </span><span class="sc">{</span><span class="bu">len</span>(lats)<span class="sc">}</span><span class="ss">, Errors: </span><span class="sc">{</span><span class="va">self</span><span class="sc">.</span>error_counts[collection]<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb77-22"><a href="#cb77-22" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="bu">len</span>(lats) <span class="op">&lt;</span> <span class="dv">2</span>:</span>
<span id="cb77-23"><a href="#cb77-23" aria-hidden="true" tabindex="-1"></a>                <span class="cf">continue</span>  <span class="co"># skip percentiles until at least two samples exist</span></span>
<span id="cb77-24"><a href="#cb77-24" aria-hidden="true" tabindex="-1"></a>            cuts <span class="op">=</span> quantiles(lats, n<span class="op">=</span><span class="dv">100</span>)  <span class="co"># 99 cut points: cuts[k - 1] is the k-th percentile</span></span>
<span id="cb77-25"><a href="#cb77-25" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f"  Latency (ms): mean </span><span class="sc">{</span>mean(lats)<span class="sc">:.1f}</span><span class="ss">, p50 </span><span class="sc">{</span>cuts[<span class="dv">49</span>]<span class="sc">:.1f}</span><span class="ss">, "</span></span>
<span id="cb77-26"><a href="#cb77-26" aria-hidden="true" tabindex="-1"></a>                  <span class="ss">f"p95 </span><span class="sc">{</span>cuts[<span class="dv">94</span>]<span class="sc">:.1f}</span><span class="ss">, p99 </span><span class="sc">{</span>cuts[<span class="dv">98</span>]<span class="sc">:.1f}</span><span class="ss">"</span>)</span>
<span id="cb77-27"><a href="#cb77-27" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb77-28"><a href="#cb77-28" aria-hidden="true" tabindex="-1"></a>monitor <span class="op">=</span> MilvusMonitor()</span>
<span id="cb77-29"><a href="#cb77-29" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb77-30"><a href="#cb77-30" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> monitored_search(collection_name, query_vector, limit<span class="op">=</span><span class="dv">10</span>, <span class="op">**</span>kwargs):</span>
<span id="cb77-31"><a href="#cb77-31" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Assumes `client` is the MilvusClient created earlier in this guide</span></span>
<span id="cb77-32"><a href="#cb77-32" aria-hidden="true" tabindex="-1"></a>    start <span class="op">=</span> time.perf_counter()</span>
<span id="cb77-33"><a href="#cb77-33" aria-hidden="true" tabindex="-1"></a>    success <span class="op">=</span> <span class="va">True</span></span>
<span id="cb77-34"><a href="#cb77-34" aria-hidden="true" tabindex="-1"></a>    <span class="cf">try</span>:</span>
<span id="cb77-35"><a href="#cb77-35" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> client.search(collection_name<span class="op">=</span>collection_name, data<span class="op">=</span>[query_vector], limit<span class="op">=</span>limit, <span class="op">**</span>kwargs)</span>
<span id="cb77-36"><a href="#cb77-36" aria-hidden="true" tabindex="-1"></a>    <span class="cf">except</span> <span class="pp">Exception</span>:</span>
<span id="cb77-37"><a href="#cb77-37" aria-hidden="true" tabindex="-1"></a>        success <span class="op">=</span> <span class="va">False</span></span>
<span id="cb77-38"><a href="#cb77-38" aria-hidden="true" tabindex="-1"></a>        <span class="cf">raise</span></span>
<span id="cb77-39"><a href="#cb77-39" aria-hidden="true" tabindex="-1"></a>    <span class="cf">finally</span>:</span>
<span id="cb77-40"><a href="#cb77-40" aria-hidden="true" tabindex="-1"></a>        monitor.record_search(collection_name, (time.perf_counter() <span class="op">-</span> start) <span class="op">*</span> <span class="dv">1000</span>, success)</span></code></pre></div></div>
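<p><code>MilvusMonitor</code> keeps every latency sample forever, so its memory grows without bound in a long-running service. The variant below assumes that recent behavior is what you want to alert on. It keeps only the last <code>window</code> samples per collection using <code>collections.deque(maxlen=...)</code>, and it guards updates with a lock in case searches run on several threads. The class name and default window size are illustrative choices.</p>

```python
import threading
from collections import defaultdict, deque
from statistics import quantiles


class RollingLatencyMonitor:
    """Per-collection latency tracker that retains only the most recent samples."""

    def __init__(self, window: int = 1000):
        self._lock = threading.Lock()
        # Each deque drops its oldest sample once `window` samples are stored
        self._latencies = defaultdict(lambda: deque(maxlen=window))
        self._errors = defaultdict(int)

    def record_search(self, collection: str, latency_ms: float, success: bool) -> None:
        with self._lock:
            if success:
                self._latencies[collection].append(latency_ms)
            else:
                self._errors[collection] += 1

    def percentile(self, collection: str, pct: int) -> float:
        """Return the pct-th percentile (1-99) over the current window."""
        if not 1 <= pct <= 99:
            raise ValueError("pct must be between 1 and 99")
        with self._lock:
            samples = list(self._latencies.get(collection, ()))
        if len(samples) < 2:
            raise ValueError("need at least two samples")
        return quantiles(samples, n=100)[pct - 1]

    def error_count(self, collection: str) -> int:
        with self._lock:
            return self._errors.get(collection, 0)
```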
<hr>
</section>
</section>
<section id="sec-pitfalls" class="level2">
<h2 class="anchored" data-anchor-id="sec-pitfalls" id="sec-pitfalls">19. Common Pitfalls and How to Avoid Them</h2>
<section id="pitfall-1-mismatched-embedding-dimensions" class="level3">
<h3 class="anchored" data-anchor-id="pitfall-1-mismatched-embedding-dimensions" id="pitfall-1-mismatched-embedding-dimensions">Pitfall 1: Mismatched Embedding Dimensions</h3>
<p><strong>Problem:</strong> You created a collection with <code>dim=512</code> but insert vectors of size 768. Milvus rejects the insert with a dimension mismatch error.</p>
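<p><strong>Solution:</strong> Validate dimensions before inserting, and reject the whole batch on the first mismatch:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb78"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb78-1"><a href="#cb78-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> safe_insert(client, collection_name, data, expected_dim, vector_field<span class="op">=</span><span class="st">"embedding"</span>):</span>
<span id="cb78-2"><a href="#cb78-2" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Raise instead of using assert: assertions are stripped when Python runs with -O</span></span>
<span id="cb78-3"><a href="#cb78-3" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> i, entity <span class="kw">in</span> <span class="bu">enumerate</span>(data):</span>
<span id="cb78-4"><a href="#cb78-4" aria-hidden="true" tabindex="-1"></a>        vec <span class="op">=</span> entity.get(vector_field)</span>
<span id="cb78-5"><a href="#cb78-5" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> vec <span class="kw">is</span> <span class="va">None</span> <span class="kw">or</span> <span class="bu">len</span>(vec) <span class="op">!=</span> expected_dim:</span>
<span id="cb78-6"><a href="#cb78-6" aria-hidden="true" tabindex="-1"></a>            got <span class="op">=</span> <span class="st">"missing"</span> <span class="cf">if</span> vec <span class="kw">is</span> <span class="va">None</span> <span class="cf">else</span> <span class="bu">len</span>(vec)</span>
<span id="cb78-7"><a href="#cb78-7" aria-hidden="true" tabindex="-1"></a>            <span class="cf">raise</span> <span class="pp">ValueError</span>(<span class="ss">f"Entity </span><span class="sc">{</span>i<span class="sc">}</span><span class="ss">: expected dim </span><span class="sc">{</span>expected_dim<span class="sc">}</span><span class="ss">, got </span><span class="sc">{</span>got<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb78-8"><a href="#cb78-8" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Only reached when every entity passed, so a bad batch is never partially inserted</span></span>
<span id="cb78-9"><a href="#cb78-9" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> client.insert(collection_name<span class="op">=</span>collection_name, data<span class="op">=</span>data)</span></code></pre></div></div>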
</section>
<section id="pitfall-2-searching-before-loading" class="level3">
<h3 class="anchored" data-anchor-id="pitfall-2-searching-before-loading" id="pitfall-2-searching-before-loading">Pitfall 2: Searching Before Loading</h3>
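<p><strong>Problem:</strong> Milvus only searches collections that are loaded into memory. <code>MilvusClient.create_collection()</code> loads the collection automatically when you supply index parameters. A collection must be loaded explicitly before searching if it was created through the ORM API, created without an index, or released earlier.</p>
<p><strong>Solution:</strong></p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb79"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb79-1"><a href="#cb79-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> pymilvus <span class="im">import</span> Collection, connections</span>
<span id="cb79-2"><a href="#cb79-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb79-3"><a href="#cb79-3" aria-hidden="true" tabindex="-1"></a>connections.connect(uri<span class="op">=</span><span class="st">"http://localhost:19530"</span>)  <span class="co"># the ORM API needs an open connection</span></span>
<span id="cb79-4"><a href="#cb79-4" aria-hidden="true" tabindex="-1"></a>col <span class="op">=</span> Collection(<span class="st">"image_embeddings"</span>)</span>
<span id="cb79-5"><a href="#cb79-5" aria-hidden="true" tabindex="-1"></a>col.load()  <span class="co"># MilvusClient equivalent: client.load_collection("image_embeddings")</span></span></code></pre></div></div>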
</section>
<section id="pitfall-3-not-normalizing-vectors-for-cosineip-metrics" class="level3">
<h3 class="anchored" data-anchor-id="pitfall-3-not-normalizing-vectors-for-cosineip-metrics" id="pitfall-3-not-normalizing-vectors-for-cosineip-metrics">Pitfall 3: Not Normalizing Vectors for Cosine/IP Metrics</h3>
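<p><strong>Problem:</strong> With the IP metric, scores on unnormalized vectors grow with vector length, so a long vector can outrank a better-aligned short one. The COSINE metric is unaffected by length. On normalized vectors, IP and COSINE give identical scores.</p>
<p><strong>Solution:</strong></p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb80"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb80-1"><a href="#cb80-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb80-2"><a href="#cb80-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb80-3"><a href="#cb80-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> l2_normalize(x, eps<span class="op">=</span><span class="fl">1e-12</span>):</span>
<span id="cb80-4"><a href="#cb80-4" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Scale a vector, or each row of a 2-D batch, to unit L2 norm (eps guards zero vectors)."""</span></span>
<span id="cb80-5"><a href="#cb80-5" aria-hidden="true" tabindex="-1"></a>    x <span class="op">=</span> np.asarray(x, dtype<span class="op">=</span>np.float32)</span>
<span id="cb80-6"><a href="#cb80-6" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> x <span class="op">/</span> np.maximum(np.linalg.norm(x, axis<span class="op">=-</span><span class="dv">1</span>, keepdims<span class="op">=</span><span class="va">True</span>), eps)</span></code></pre></div></div>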
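<p>The NumPy sketch below shows why normalization matters for IP by ranking two stored vectors against a query. Under raw inner product, a long vector pointing in a worse direction outranks a short vector that is better aligned. After L2 normalization, IP produces the cosine ranking. The vectors are toy values chosen for illustration.</p>

```python
import numpy as np


def rank_by_ip(query, stored):
    """Return stored-vector indices sorted from highest to lowest inner product."""
    scores = np.asarray(stored, dtype=np.float64) @ np.asarray(query, dtype=np.float64)
    return np.argsort(-scores).tolist()


def unit(x):
    """L2-normalize a vector or each row of a 2-D array."""
    x = np.asarray(x, dtype=np.float64)
    return x / np.linalg.norm(x, axis=-1, keepdims=True)


query = np.array([1.0, 0.0])
stored = np.array([
    [0.9, 0.1],  # well aligned with the query, small magnitude
    [5.0, 5.0],  # 45 degrees off, but large magnitude
])

print(rank_by_ip(query, stored))              # [1, 0]: magnitude wins under raw IP
print(rank_by_ip(unit(query), unit(stored)))  # [0, 1]: direction wins once normalized
```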
</section>
<section id="pitfall-4-setting-nprobe-too-low-ivf-indexes" class="level3">
<h3 class="anchored" data-anchor-id="pitfall-4-setting-nprobe-too-low-ivf-indexes" id="pitfall-4-setting-nprobe-too-low-ivf-indexes">Pitfall 4: Setting <code>nprobe</code> Too Low (IVF Indexes)</h3>
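<p><strong>Problem:</strong> An IVF search visits only <code>nprobe</code> of the <code>nlist</code> clusters. A low value (e.g., 1 or 2) misses true neighbors that fall in other clusters, so recall drops sharply.</p>
<p><strong>Solution:</strong> Start with <code>nprobe = nlist / 16</code> as a rough heuristic and pass it at query time through <code>search_params={"params": {"nprobe": ...}}</code>. Then benchmark recall and latency as you adjust it. Never use <code>nprobe=1</code> in production without measurement.</p>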
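<p>Measuring recall means comparing the IDs that an IVF search returns against the exact nearest neighbors. The minimal sketch below assumes you have the approximate result IDs, for example from <code>client.search(..., search_params={"params": {"nprobe": 16}})</code>. It also assumes the raw vectors are available in NumPy with row index equal to primary key. It computes exact top-k under L2 by brute force and averages the overlap. For IP or COSINE, rank by the corresponding score instead. Brute force is only practical on a sample of queries, and the helper names are illustrative.</p>

```python
import numpy as np


def exact_top_k_l2(queries: np.ndarray, corpus: np.ndarray, k: int) -> np.ndarray:
    """Indices of the k nearest corpus rows per query under squared L2 distance (brute force)."""
    # ||q - c||^2 = ||q||^2 - 2 q.c + ||c||^2, computed for all query/corpus pairs at once
    d2 = (
        np.sum(queries**2, axis=1, keepdims=True)
        - 2.0 * queries @ corpus.T
        + np.sum(corpus**2, axis=1)
    )
    return np.argsort(d2, axis=1)[:, :k]


def recall_at_k(approx_ids, exact_ids, k: int) -> float:
    """Mean fraction of the true top-k IDs that appear in the returned top-k, over all queries."""
    hits = [len(set(a[:k]) & set(e[:k])) / k for a, e in zip(approx_ids, exact_ids)]
    return float(np.mean(hits))


rng = np.random.default_rng(0)
corpus = rng.normal(size=(1000, 32))
queries = rng.normal(size=(5, 32))
truth = exact_top_k_l2(queries, corpus, k=10)
print(recall_at_k(truth, truth, k=10))  # 1.0 (sanity check: exact results against themselves)
```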
</section>
<section id="pitfall-5-growing-segments-and-slow-queries-on-fresh-data" class="level3">
<h3 class="anchored" data-anchor-id="pitfall-5-growing-segments-and-slow-queries-on-fresh-data" id="pitfall-5-growing-segments-and-slow-queries-on-fresh-data">Pitfall 5: Growing Segments and Slow Queries on Fresh Data</h3>
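<p><strong>Problem:</strong> Freshly inserted data sits in unsealed “growing segments” that the collection’s main index does not cover yet. Depending on version and configuration, these segments are searched by brute force or through a lightweight interim index. As a result, benchmarks run right after a bulk insert look slower than steady state.</p>
<p><strong>Solution:</strong> After a bulk load, flush once to seal the segments and let index building finish. Avoid calling <code>flush()</code> after every small insert, because that creates many small segments.</p>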
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb81"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb81-1"><a href="#cb81-1" aria-hidden="true" tabindex="-1"></a>client.flush(<span class="st">"image_embeddings"</span>)</span>
<span id="cb81-2"><a href="#cb81-2" aria-hidden="true" tabindex="-1"></a><span class="co"># Then wait for index building to complete before running benchmarks</span></span></code></pre></div></div>
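<p>One way to wait is to poll <code>client.describe_index()</code> until no rows are pending. This sketch assumes the returned dict contains a <code>pending_index_rows</code> key, as PyMilvus 2.4 output does. It deliberately raises <code>KeyError</code> if that key is absent, so verify the name against your version. The helper name, polling interval, and timeout are arbitrary choices.</p>

```python
import time


def wait_for_index(client, collection_name: str, index_name: str,
                   timeout_s: float = 600.0, poll_s: float = 2.0) -> dict:
    """Poll describe_index() until no rows are waiting to be indexed (hypothetical helper)."""
    deadline = time.monotonic() + timeout_s
    while True:
        info = client.describe_index(collection_name=collection_name, index_name=index_name)
        # Assumption: the response reports 'pending_index_rows'; a missing key raises KeyError
        if info["pending_index_rows"] == 0:
            return info
        if time.monotonic() >= deadline:
            raise TimeoutError(f"index {index_name!r} still has pending rows after {timeout_s}s")
        time.sleep(poll_s)


# Usage (requires a running Milvus; "embedding" as the index name is an assumption):
# client.flush("image_embeddings")
# wait_for_index(client, "image_embeddings", "embedding")
```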
</section>
<section id="pitfall-6-varchar-filter-on-unindexed-fields" class="level3">
<h3 class="anchored" data-anchor-id="pitfall-6-varchar-filter-on-unindexed-fields" id="pitfall-6-varchar-filter-on-unindexed-fields">Pitfall 6: VARCHAR Filter on Unindexed Fields</h3>
<p><strong>Problem:</strong> Filtering on a VARCHAR field without a scalar index forces a full scan.</p>
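<p><strong>Solution:</strong> Create scalar indexes on the fields you filter by, and choose the index type by field type and cardinality:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb82"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb82-1"><a href="#cb82-1" aria-hidden="true" tabindex="-1"></a><span class="co"># `client` is a MilvusClient connected to your instance</span></span>
<span id="cb82-2"><a href="#cb82-2" aria-hidden="true" tabindex="-1"></a>index_params <span class="op">=</span> client.prepare_index_params()</span>
<span id="cb82-3"><a href="#cb82-3" aria-hidden="true" tabindex="-1"></a>index_params.add_index(field_name<span class="op">=</span><span class="st">"label"</span>, index_type<span class="op">=</span><span class="st">"Trie"</span>)  <span class="co"># VARCHAR field</span></span>
<span id="cb82-4"><a href="#cb82-4" aria-hidden="true" tabindex="-1"></a>index_params.add_index(field_name<span class="op">=</span><span class="st">"score"</span>, index_type<span class="op">=</span><span class="st">"STL_SORT"</span>)  <span class="co"># numeric field, range filters</span></span>
<span id="cb82-5"><a href="#cb82-5" aria-hidden="true" tabindex="-1"></a>index_params.add_index(field_name<span class="op">=</span><span class="st">"flags"</span>, index_type<span class="op">=</span><span class="st">"BITMAP"</span>)  <span class="co"># low-cardinality field</span></span>
<span id="cb82-6"><a href="#cb82-6" aria-hidden="true" tabindex="-1"></a>client.create_index(collection_name<span class="op">=</span><span class="st">"image_embeddings"</span>, index_params<span class="op">=</span>index_params)</span></code></pre></div></div>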
</section>
<section id="pitfall-7-using-auto_idfalse-without-providing-unique-ids" class="level3">
<h3 class="anchored" data-anchor-id="pitfall-7-using-auto_idfalse-without-providing-unique-ids" id="pitfall-7-using-auto_idfalse-without-providing-unique-ids">Pitfall 7: Using <code>auto_id=False</code> Without Providing Unique IDs</h3>
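<p><strong>Problem:</strong> Milvus does not deduplicate primary keys on insert. Inserting an ID that already exists stores a second entity with the same key, and searches or queries may then return either copy.</p>
<p><strong>Solution:</strong> Use <code>auto_id=True</code> unless you have a strong reason to manage IDs yourself. When you do manage them, use <code>upsert()</code> whenever a write may replace an existing entity.</p>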
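<p>If you manage IDs yourself, deriving them deterministically from a natural key, such as an image path, makes re-ingestion idempotent when combined with <code>upsert()</code>. The minimal sketch below assumes an INT64 primary key. It hashes the key with SHA-256 and keeps 63 bits, so the result is a non-negative signed 64-bit integer. By the birthday bound, the chance of any collision among 10 million keys is roughly 5 in a million. The batch helper therefore still checks for duplicates. Both helper names are illustrative.</p>

```python
import hashlib
from typing import Iterable, List


def stable_int64_id(key: str) -> int:
    """Map a string key to a non-negative INT64 via SHA-256 (first 8 bytes, top bit cleared)."""
    digest = hashlib.sha256(key.encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "big") & 0x7FFF_FFFF_FFFF_FFFF


def ids_for_batch(keys: Iterable[str]) -> List[int]:
    """Compute IDs for a batch and refuse batches that would contain duplicate primary keys."""
    keys = list(keys)
    ids = [stable_int64_id(k) for k in keys]
    if len(set(ids)) != len(ids):
        raise ValueError("duplicate IDs in batch (repeated key or hash collision)")
    return ids


# Usage sketch (requires a running Milvus; `client`, `paths`, and `vectors` are assumed to exist):
# data = [{"id": i, "embedding": v} for i, v in zip(ids_for_batch(paths), vectors)]
# client.upsert(collection_name="image_embeddings", data=data)
```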
</section>
<section id="pitfall-8-confusing-distance-values-by-metric-type" class="level3">
<h3 class="anchored" data-anchor-id="pitfall-8-confusing-distance-values-by-metric-type" id="pitfall-8-confusing-distance-values-by-metric-type">Pitfall 8: Confusing Distance Values by Metric Type</h3>
<p><strong>Problem:</strong> For L2, a <strong>lower</strong> distance means more similar. For IP and COSINE, a <strong>higher</strong> score means more similar. Misinterpreting this leads to sorting results or applying thresholds in the wrong direction.</p>
<p><strong>Solution:</strong> Trust Milvus’s returned sort order, which always runs from most to least similar. Take care only when you compare raw scores yourself, for example when applying a threshold or mixing results from collections that use different metrics.</p>
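<p>Direction matters whenever you apply your own cutoff instead of relying on the sort order. In Milvus, smaller L2 values are better, while larger IP and COSINE values are better. The small sketch below keeps only hits that pass a similarity threshold. It assumes each hit is a dict with a <code>distance</code> key, as in PyMilvus 2.4 search output, and that the metric name matches the one your index uses. The helper names are illustrative.</p>

```python
from typing import Dict, List

# Metrics where a larger score means "more similar"; L2 is the opposite
HIGHER_IS_BETTER = {"IP": True, "COSINE": True, "L2": False}


def passes_threshold(metric: str, distance: float, threshold: float) -> bool:
    """True if `distance` is at least as similar as `threshold` under the given metric."""
    try:
        higher_better = HIGHER_IS_BETTER[metric.upper()]
    except KeyError:
        raise ValueError(f"unsupported metric: {metric}") from None
    return distance >= threshold if higher_better else distance <= threshold


def filter_hits(hits: List[Dict], metric: str, threshold: float) -> List[Dict]:
    """Keep hits (already ordered by Milvus) whose distance passes the threshold."""
    return [h for h in hits if passes_threshold(metric, h["distance"], threshold)]


hits = [{"id": 7, "distance": 0.92}, {"id": 3, "distance": 0.55}]
print([h["id"] for h in filter_hits(hits, "COSINE", 0.8)])  # [7]
```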
<hr>
</section>
</section>
<section id="sec-glossary" class="level2">
<h2 class="anchored" data-anchor-id="sec-glossary" id="sec-glossary">20. Glossary</h2>
<p><strong>ANN (Approximate Nearest Neighbor):</strong> A search approach that finds results very close to the true nearest neighbors, trading a small amount of accuracy for enormous speed gains.</p>
<p><strong>BM25:</strong> A sparse retrieval algorithm based on term frequency and inverse document frequency. Used in hybrid search alongside dense vector search.</p>
<p><strong>Collection:</strong> The top-level data container in Milvus, analogous to a table in a relational database.</p>
<p><strong>Cosine Similarity:</strong> A similarity measure equal to the cosine of the angle between two vectors. Values range from -1 (opposite directions) to 1 (same direction). In Milvus, a higher COSINE score means more similar.</p>
<p><strong>DiskANN:</strong> A graph-based ANN index designed to work with data stored on disk rather than RAM.</p>
<p><strong>Embedding / Feature Vector:</strong> A compact numerical representation of complex data (images, text, audio) produced by a neural network. Similar inputs produce numerically close embeddings.</p>
<p><strong>Entity:</strong> A single record (row) in a Milvus collection.</p>
<p><strong>etcd:</strong> A distributed key-value store used by Milvus to store cluster metadata, configuration, and service discovery information.</p>
<p><strong>HNSW (Hierarchical Navigable Small World):</strong> A graph-based ANN index that builds a multi-layer proximity graph for fast nearest neighbor search. It is a common default when the index fits in memory, offering high recall at low latency at the cost of extra memory for the graph.</p>
<p><strong>Inner Product (IP):</strong> The dot product of two vectors. For normalized (unit) vectors, IP equals cosine similarity.</p>
<p><strong>IVF (Inverted File Index):</strong> A family of ANN indexes that clusters vectors into Voronoi cells and searches only the nearest clusters at query time.</p>
<p><strong>L2 (Euclidean Distance):</strong> The straight-line distance between two points in Euclidean space.</p>
<p><strong>MinIO:</strong> An open-source, S3-compatible object storage system used by Milvus to persist vector data and index files.</p>
<p><strong>Milvus Lite:</strong> An embedded, serverless version of Milvus that runs entirely in-process. Best for development and prototyping.</p>
<p><strong>Normalization (L2 normalization):</strong> The process of scaling a vector to unit length (L2 norm = 1). It is needed when you use IP as a stand-in for cosine similarity. The COSINE metric is unaffected by vector length.</p>
<p><strong>Partition:</strong> A logical subdivision of a Milvus collection that can be searched independently.</p>
<p><strong>Primary Key:</strong> A unique identifier for each entity in a collection.</p>
<p><strong>Product Quantization (PQ):</strong> A vector compression technique that divides vectors into sub-vectors and quantizes each independently.</p>
<p><strong>PyMilvus:</strong> The official Python SDK for Milvus.</p>
<p><strong>Recall@K:</strong> The fraction of the true K nearest neighbors that appear in the returned K results.</p>
<p><strong>Scalar Field:</strong> A non-vector field in a Milvus schema used for metadata storage and filtered search.</p>
<p><strong>Schema:</strong> The definition of the fields and their data types in a Milvus collection.</p>
<p><strong>Segment:</strong> An internal data chunk within a Milvus collection. Growing segments hold new data; sealed segments are immutable and are indexed once index building completes.</p>
<p><strong>Sparse Vector:</strong> A vector representation where most values are zero, stored as a list of (index, value) pairs.</p>
<p><strong>UPSERT:</strong> An operation that inserts an entity if it does not exist, or updates it if it does.</p>
<p><strong>Vector Database:</strong> A specialized database designed to store, index, and efficiently search high-dimensional vector embeddings using approximate nearest neighbor algorithms.</p>
<hr>
<p><em>Guide version: May 2026 | Milvus 2.4.x+ | PyMilvus 2.4.x+</em></p>
<p><em>For the latest Milvus documentation, visit <a href="https://milvus.io/docs">milvus.io/docs</a></em></p>



</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[TIPSv2: Advancing Vision-Language Pretraining with Enhanced Patch-Text Alignment]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/dino/tips/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/dino/tips/</guid>
      <pubDate>Thu, 23 Apr 2026 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>tutorial</category>
      <category>intermediate</category>
      <content:encoded><![CDATA[






<section id="tipsv2-advancing-vision-language-pretraining-with-enhanced-patch-text-alignment" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/dino/tips/tips.png" class="img-fluid"></p>
<section id="introduction-and-motivation" class="level2">
<h2 class="anchored" data-anchor-id="introduction-and-motivation" id="introduction-and-motivation">Introduction and Motivation</h2>
<p>Vision-language models (VLMs) have become a cornerstone of modern computer vision and multimodal AI. Systems like CLIP, SigLIP, ALIGN, and their descendants have demonstrated remarkable capability at associating images with textual descriptions, enabling zero-shot classification, cross-modal retrieval, and a growing ecosystem of downstream multimodal tasks. However, despite their strong global image-text alignment abilities, these models share a common and often underappreciated weakness: <strong>they fail to align individual image patches with the corresponding textual concepts</strong>.</p>
<p>This limitation is not merely academic. In applications such as semantic segmentation, object detection, depth estimation, visual question answering, and referring expression comprehension, the model must understand <em>where</em> in an image a concept lives, not merely <em>whether</em> a concept is present. A model that can recognize “a dog” in a scene but cannot precisely localize the dog’s spatial extent in the feature space is fundamentally limited for such dense understanding tasks.</p>
<p><strong>TIPSv2</strong>, the second generation of <em>Text-Image Pretraining with Spatial Awareness</em>, is a family of foundational vision-language models developed by Google DeepMind that directly and systematically addresses this challenge. Accepted at CVPR 2026, TIPSv2 introduces three innovations, <strong>iBOT++</strong>, <strong>Head-only EMA</strong>, and <strong>Multi-Granularity Captions</strong>, which together substantially improve dense patch-text alignment without sacrificing global representation quality. The resulting model family achieves strong, often state-of-the-art performance across a broad suite of tasks, including zero-shot semantic segmentation, monocular depth estimation, image-text retrieval, and standard image classification.</p>
<p>What makes TIPSv2 particularly compelling is that its central innovations were not conceived in a vacuum. They arose from a counter-intuitive empirical observation uncovered during controlled experiments with knowledge distillation — a finding that then inspired the core design of the pretraining objective.</p>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Note
</div>
</div>
<div class="callout-body-container callout-body">
<p><strong>Accepted at CVPR 2026</strong> | arXiv: <a href="https://arxiv.org/abs/2604.12012">2604.12012</a> | Project Page: <a href="https://gdm-tipsv2.github.io/">gdm-tipsv2.github.io</a> | Code: <a href="https://github.com/google-deepmind/tips">google-deepmind/tips</a></p>
</div>
</div>
<hr>
</section>
<section id="background-the-tips-lineage" class="level2">
<h2 class="anchored" data-anchor-id="background-the-tips-lineage" id="background-the-tips-lineage">Background: The TIPS Lineage</h2>
<p>To appreciate TIPSv2 fully, it is essential to understand its predecessor, <strong>TIPS (Text-Image Pretraining with Spatial Awareness)</strong>, which was published at <strong>ICLR 2025</strong>.</p>
<section id="what-tips-v1-did" class="level3">
<h3 class="anchored" data-anchor-id="what-tips-v1-did" id="what-tips-v1-did">What TIPS (v1) Did</h3>
<p>The original TIPS model identified a fundamental problem with standard contrastive vision-language pretraining: models trained with objectives like CLIP’s InfoNCE loss operate at the level of global image embeddings, aggregating all spatial information into a single vector. While this is excellent for global classification and retrieval, the resulting patch-level features are not aligned with text in any explicit way — they tend to be entangled and spatially incoherent.</p>
<p>TIPS addressed this in two main ways:</p>
<p><strong>Synthetic Caption Replacement.</strong> Rather than training on raw, noisy web-scraped image-caption pairs, TIPS replaced these captions with synthetically generated textual descriptions produced by capable captioning models. These synthetic captions are semantically richer, more spatially descriptive, and significantly less noisy than typical alt-text from the web.</p>
<p><strong>Combining Contrastive and Masked Image Modeling.</strong> TIPS combined CLIP-style contrastive learning (for global image-text alignment) with masked image modeling (MIM) in the style of iBOT (Image BERT Pre-Training with Online Tokenizer). The MIM component encourages the model to develop spatially coherent patch representations, since it must reconstruct masked patches from the remaining visible context.</p>
<p>Together, these two ideas yielded a model validated on a comprehensive suite of 9 tasks and 20 datasets, displaying strong performance that matched or exceeded other recent vision encoders — particularly on dense spatial understanding tasks.</p>
</section>
<section id="what-tipsv2-builds-upon" class="level3">
<h3 class="anchored" data-anchor-id="what-tipsv2-builds-upon" id="what-tipsv2-builds-upon">What TIPSv2 Builds Upon</h3>
<p>TIPSv2 inherits the foundational ideas of TIPS but goes significantly further. Rather than simply scaling up or tuning existing components, the TIPSv2 team performed a careful analysis that led to three orthogonal, complementary innovations. Each is motivated by a concrete empirical observation, and the ablations show that each contributes meaningfully to final performance.</p>
<hr>
</section>
</section>
<section id="the-core-problem-dense-patch-text-misalignment" class="level2">
<h2 class="anchored" data-anchor-id="the-core-problem-dense-patch-text-misalignment" id="the-core-problem-dense-patch-text-misalignment">The Core Problem: Dense Patch-Text Misalignment</h2>
<section id="understanding-patch-level-representations" class="level3">
<h3 class="anchored" data-anchor-id="understanding-patch-level-representations" id="understanding-patch-level-representations">Understanding Patch-Level Representations</h3>
<p>In Vision Transformer (ViT) based architectures, an image is divided into a grid of non-overlapping patches (e.g., 14×14 pixel patches, giving a 16×16 grid of 256 patch tokens for a 224×224 input to a ViT-B/14). These patch tokens, along with a <code>[CLS]</code> token representing the global image embedding, are processed by self-attention layers to produce the final representations.</p>
<p>In a globally-trained model like CLIP, the <code>[CLS]</code> token embedding is explicitly trained to align with text embeddings via contrastive loss. The individual patch tokens, however, receive no direct text supervision.</p>
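<p>The following short NumPy sketch makes the patch arithmetic above concrete. It is a generic ViT-style patchify step, not TIPSv2’s own code. The function name <code>patchify</code> is hypothetical. A real ViT follows this reshape with a learned linear projection (often implemented as a stride-14 convolution) and then adds position embeddings and the <code>[CLS]</code> token.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import numpy as np


def patchify(image, patch_size):
    """Split an (H, W, C) image into non-overlapping square patches.

    Returns an array of shape (num_patches, patch_size * patch_size * C):
    one flattened row per patch token, in row-major grid order.
    """
    h, w, c = image.shape
    if h % patch_size or w % patch_size:
        raise ValueError("image height and width must be divisible by patch_size")
    rows, cols = h // patch_size, w // patch_size
    # Split each spatial axis into (grid index, offset inside the patch) ...
    patches = image.reshape(rows, patch_size, cols, patch_size, c)
    # ... then move the two grid indices to the front.
    patches = patches.transpose(0, 2, 1, 3, 4)
    return patches.reshape(rows * cols, patch_size * patch_size * c)


image = np.zeros((224, 224, 3), dtype=np.float32)
tokens = patchify(image, patch_size=14)
print(tokens.shape)  # (256, 588): a 16 x 16 grid of 14 x 14 x 3 patches</code></pre></div>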
</section>
<section id="why-this-matters-in-practice" class="level3">
<h3 class="anchored" data-anchor-id="why-this-matters-in-practice" id="why-this-matters-in-practice">Why This Matters in Practice</h3>
<p>The consequence of patch-text misalignment is measurable. When one visualizes the feature similarity maps of CLIP patch tokens with respect to text queries, the resulting maps tend to be diffuse and spatially incoherent.</p>
<p>This directly limits performance on:</p>
<ul>
<li><strong>Semantic segmentation</strong> — requires associating region-level features with class names</li>
<li><strong>Object detection</strong> — requires localizing objects within a spatial grid</li>
<li><strong>Depth estimation</strong> — requires per-pixel feature quality and coherence</li>
<li><strong>Open-vocabulary dense prediction</strong> — requires generalizable patch-level semantics</li>
</ul>
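<p>Once patch and text embeddings live in the same space, the similarity maps described above are straightforward to compute. The NumPy sketch below assumes that the patch tokens have already been projected into the shared image-text embedding space. The function name and the toy data are hypothetical. For a well-aligned model, the map should be high where the queried concept sits. A globally trained model typically produces a diffuse map instead.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import numpy as np


def patch_text_similarity_map(patch_tokens, text_embedding, grid_hw):
    """Cosine similarity between every patch token and one text embedding.

    patch_tokens: (N, D) patch features from an image encoder.
    text_embedding: (D,) embedding of a text query such as "a dog".
    grid_hw: (rows, cols) of the patch grid, with rows * cols == N.
    Returns a (rows, cols) similarity map with values in [-1, 1].
    """
    patches = patch_tokens / np.linalg.norm(patch_tokens, axis=1, keepdims=True)
    text = text_embedding / np.linalg.norm(text_embedding)
    return (patches @ text).reshape(grid_hw)


# Toy data (hypothetical): the top half of a 16 x 16 grid "contains" the concept.
rng = np.random.default_rng(0)
text = rng.normal(size=64)
patches = rng.normal(size=(256, 64))
patches[:128] += 5.0 * text
sim = patch_text_similarity_map(patches, text, (16, 16))
print(sim[:8].mean(), sim[8:].mean())  # high in the top half, near zero below</code></pre></div>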
</section>
<section id="prior-approaches-and-their-limits" class="level3">
<h3 class="anchored" data-anchor-id="prior-approaches-and-their-limits" id="prior-approaches-and-their-limits">Prior Approaches and Their Limits</h3>
<p>Several prior works have attempted to improve spatial understanding in vision-language models:</p>
<ul>
<li><strong>DINOv2</strong> introduced self-supervised pretraining with excellent spatial features but lacks text alignment, limiting its utility for language-grounded tasks.</li>
<li><strong>SILC</strong> and related works combine self-supervised and image-text objectives but with limited patch-level text supervision.</li>
<li><strong>RegionCLIP, CLIPSelf, MaskCLIP</strong> propose post-hoc or fine-tuning-based approaches to improve patch-level features, but do not address the fundamental gap at pretraining.</li>
</ul>
<p>TIPSv2’s contribution is to solve this problem <strong>directly during pretraining</strong>, in a principled way that is computationally tractable and scalable.</p>
<hr>
</section>
</section>
<section id="a-surprising-discovery-the-distillation-phenomenon" class="level2">
<h2 class="anchored" data-anchor-id="a-surprising-discovery-the-distillation-phenomenon" id="a-surprising-discovery-the-distillation-phenomenon">A Surprising Discovery: The Distillation Phenomenon</h2>
<section id="the-observation" class="level3">
<h3 class="anchored" data-anchor-id="the-observation" id="the-observation">The Observation</h3>
<p>A central motivation for TIPSv2’s design choices is an unexpected empirical discovery made during exploratory experiments with knowledge distillation. The TIPSv2 authors trained student models to distill representations from a large teacher model (ViT-g) at the <strong>patch level</strong> — the student was trained to reproduce the teacher’s patch token representations, not just the global <code>[CLS]</code> embedding.</p>
<p>The result was striking: <strong>the patch-level text alignment of the distilled student model substantially surpassed that of the teacher model</strong>.</p>
<p>This is a counter-intuitive finding. Naively, one would expect distillation to produce a student that approximates but does not exceed the teacher. Yet patch-level distillation acted as a powerful regularizer that forced the student to develop more semantically coherent, text-aligned patch representations than the teacher ever had.</p>
</section>
<section id="why-does-this-happen" class="level3">
<h3 class="anchored" data-anchor-id="why-does-this-happen" id="why-does-this-happen">Why Does This Happen?</h3>
<p>The authors’ interpretation is that patch-level distillation imposes a strong constraint: the student must make every patch token predictive and consistent. The distillation loss penalizes any patch token that is not representationally coherent with the corresponding patch in the teacher’s embedding space. Combined with the text supervision inherited from the teacher, this pushes the student’s patch representations toward semantic clusters that correspond to recognizable visual concepts.</p>
<p>In essence, patch-level distillation acts like a <strong>spatial regularizer</strong> that promotes the emergence of text-aligned, spatially coherent patch features.</p>
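<p>A minimal sketch makes the contrast between global and patch-level distillation concrete. The source does not specify the exact distillation objective the TIPSv2 authors used, so this version assumes a simple per-token cosine distance. The function names are hypothetical. The point it illustrates is structural: a <code>[CLS]</code>-only loss is blind to the patch tokens, while the patch-level loss gives every spatial position its own error signal.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import numpy as np


def cosine_distance(a, b):
    """Row-wise 1 - cosine similarity between two (N, D) arrays."""
    a = a / np.linalg.norm(a, axis=-1, keepdims=True)
    b = b / np.linalg.norm(b, axis=-1, keepdims=True)
    return 1.0 - np.sum(a * b, axis=-1)


def distillation_loss(student_tokens, teacher_tokens, patch_level=True):
    """Distillation loss on ViT outputs of shape (1 + N, D), [CLS] first.

    patch_level=False matches only the [CLS] token (global distillation).
    patch_level=True also matches every patch token, so each spatial
    position receives its own training signal.
    """
    cls_term = cosine_distance(student_tokens[:1], teacher_tokens[:1]).mean()
    if not patch_level:
        return cls_term
    patch_term = cosine_distance(student_tokens[1:], teacher_tokens[1:]).mean()
    return cls_term + patch_term


rng = np.random.default_rng(0)
teacher = rng.normal(size=(257, 32))      # [CLS] + 256 patch tokens
student = teacher.copy()
student[1:] = rng.normal(size=(256, 32))  # patches disagree, [CLS] agrees
print(distillation_loss(student, teacher, patch_level=False))  # close to 0
print(distillation_loss(student, teacher, patch_level=True))   # close to 1</code></pre></div>
<p>In this toy example the student’s patch tokens are pure noise, yet the global loss is essentially zero. Only the patch-level loss registers the disagreement.</p>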
</section>
<section id="the-design-insight" class="level3">
<h3 class="anchored" data-anchor-id="the-design-insight" id="the-design-insight">The Design Insight</h3>
<p>This discovery raised an obvious and actionable question: if patch-level distillation produces better patch-text alignment than the teacher itself, can we design a pretraining objective that mimics this effect <strong>without</strong> requiring a separate distillation stage?</p>
<p>The answer is <strong>yes</strong>, and this insight is the genesis of <strong>iBOT++</strong>, TIPSv2’s first and most impactful innovation.</p>
<hr>
</section>
</section>
<section id="tipsv2-architecture-and-model-family" class="level2">
<h2 class="anchored" data-anchor-id="tipsv2-architecture-and-model-family" id="tipsv2-architecture-and-model-family">TIPSv2 Architecture and Model Family</h2>
<section id="vision-transformer-backbone" class="level3">
<h3 class="anchored" data-anchor-id="vision-transformer-backbone" id="vision-transformer-backbone">Vision Transformer Backbone</h3>
<p>TIPSv2 uses the Vision Transformer (ViT) architecture as its image encoder across all model sizes:</p>
<table class="caption-top table">
<colgroup>
<col style="width: 35%">
<col style="width: 65%">
</colgroup>
<thead>
<tr class="header">
<th>Model</th>
<th>Description</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><strong>ViT-B/14</strong></td>
<td>Base-sized model with 14×14 patch size (<code>tipsv2-b14</code>)</td>
</tr>
<tr class="even">
<td><strong>ViT-L/14</strong></td>
<td>Large-sized model with 14×14 patch size</td>
</tr>
<tr class="odd">
<td><strong>ViT-g/14</strong></td>
<td>Giant-sized model, the largest and highest-performing variant</td>
</tr>
<tr class="even">
<td><strong>SO-400m</strong></td>
<td>Shape-optimized ViT (the SoViT-400m architecture) with roughly 400M parameters</td>
</tr>
</tbody>
</table>
</section>
<section id="training-hierarchy" class="level3">
<h3 class="anchored" data-anchor-id="training-hierarchy" id="training-hierarchy">Training Hierarchy</h3>
<p>The model family is trained in two stages:</p>
<p><strong>Stage 1: Direct Pretraining of ViT-g.</strong> The giant model is pretrained from scratch using the full TIPSv2 objective (iBOT++, Head-only EMA, and Multi-Granularity Captions). This serves as the base teacher model.</p>
<p><strong>Stage 2: Patch-Level Distillation for Smaller Models.</strong> The ViT-B, ViT-L, and SO-400m models are trained via patch-level knowledge distillation from the ViT-g teacher. This deliberately exploits the patch-text alignment gain that the distillation experiments revealed.</p>
</section>
<section id="text-encoder-and-projection-heads" class="level3">
<h3 class="anchored" data-anchor-id="text-encoder-and-projection-heads" id="text-encoder-and-projection-heads">Text Encoder and Projection Heads</h3>
<p>TIPSv2 employs a text encoder trained alongside the image encoder using contrastive objectives. Both image and text encoders attach lightweight MLP projection heads that map representations to the shared embedding space. The design of these projection heads is central to the Head-only EMA strategy.</p>
<hr>
</section>
</section>
<section id="key-innovation-1-ibot-extending-the-self-supervised-loss" class="level2">
<h2 class="anchored" data-anchor-id="key-innovation-1-ibot-extending-the-self-supervised-loss" id="key-innovation-1-ibot-extending-the-self-supervised-loss">Key Innovation 1 — iBOT++: Extending the Self-Supervised Loss</h2>
<section id="background-ibot-and-masked-image-modeling" class="level3">
<h3 class="anchored" data-anchor-id="background-ibot-and-masked-image-modeling" id="background-ibot-and-masked-image-modeling">Background: iBOT and Masked Image Modeling</h3>
<p>iBOT (Image BERT Pre-Training with Online Tokenizer) is a self-supervised pretraining technique for ViTs that combines masked image modeling with online tokenization. In standard iBOT:</p>
<ol type="1">
<li>A random subset of patches is <strong>masked</strong> (replaced with a learnable mask token).</li>
<li>The model is trained to predict the representations of masked patches, using a momentum-updated teacher as the target.</li>
<li>The <code>[CLS]</code> token is aligned across two augmented views via a self-supervised classification loss (DINO-style).</li>
</ol>
<p>The key signal comes <strong>exclusively from masked patches</strong> — visible patches do not directly contribute to the MIM loss.</p>
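<p>The masking step can be sketched as follows. This is a simplification, not iBOT’s reference implementation: it masks patches uniformly at random with a fixed ratio, whereas iBOT itself uses blockwise masking with a randomly sampled ratio. The function name <code>mask_patches</code> is hypothetical. The returned boolean array marks exactly the positions that receive a patch loss under standard iBOT.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import numpy as np


def mask_patches(patch_tokens, mask_token, mask_ratio, rng):
    """Simplified iBOT-style masking: replace a random subset of patches.

    patch_tokens: (N, D) patch embeddings of one image (no [CLS] token).
    mask_token: (D,) learnable embedding that stands in for hidden patches.
    Returns the masked tokens and a boolean (N,) array that is True at
    masked positions, i.e. where the standard iBOT patch loss applies.
    """
    n = patch_tokens.shape[0]
    num_masked = int(round(mask_ratio * n))
    masked_idx = rng.choice(n, size=num_masked, replace=False)
    is_masked = np.zeros(n, dtype=bool)
    is_masked[masked_idx] = True
    masked_tokens = patch_tokens.copy()
    masked_tokens[is_masked] = mask_token
    return masked_tokens, is_masked


rng = np.random.default_rng(0)
tokens = rng.normal(size=(256, 8))
masked, is_masked = mask_patches(tokens, np.zeros(8), mask_ratio=0.3, rng=rng)
print(is_masked.sum(), "of", is_masked.size, "positions receive a patch loss")
# 77 of 256 positions receive a patch loss</code></pre></div>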
</section>
<section id="the-limitation-of-masking-only-supervision" class="level3">
<h3 class="anchored" data-anchor-id="the-limitation-of-masking-only-supervision" id="the-limitation-of-masking-only-supervision">The Limitation of Masking-Only Supervision</h3>
<p>This masking-only paradigm has an implicit inefficiency: at any given training step, most patches (those not masked) contribute nothing to the patch-level self-supervised objective. In light of the distillation discovery, this is a missed opportunity. If supervising every patch position is what made the distilled students so well aligned, then enforcing patch-level representation consistency for visible patches as well should improve patch-text alignment.</p>
</section>
<section id="ibot-all-tokens-contribute" class="level3">
<h3 class="anchored" data-anchor-id="ibot-all-tokens-contribute" id="ibot-all-tokens-contribute">iBOT++: All Tokens Contribute</h3>
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Important
</div>
</div>
<div class="callout-body-container callout-body">
<p><strong>iBOT++</strong> extends the patch-level self-supervised loss to <strong>ALL patch tokens</strong> — both masked and unmasked — rather than only to masked patches.</p>
</div>
</div>
<p>At each training step, the iBOT++ loss computes a representation consistency target for every patch in the image, using the momentum teacher as the target generator. This means even visible patches must align with the teacher’s patch embeddings.</p>
<p>This change:</p>
<ul>
<li>Forces semantically coherent, consistent patch representations across all spatial locations</li>
<li>Propagates dense patch-level gradients at every step</li>
<li>Mimics the effect of patch-level distillation within the pretraining loop</li>
<li>Produces dramatically smoother, more spatially coherent feature maps</li>
</ul>
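<p>The difference between iBOT and iBOT++ comes down to which positions are averaged in the patch loss. The sketch below assumes DINO/iBOT-style projection heads that output logits over <code>K</code> prototypes. It uses the student and teacher temperatures common in those methods (0.1 and 0.04) as illustrative defaults. For brevity, it omits teacher centering and gradient stopping. This is not TIPSv2’s implementation; it only illustrates the choice between masked-only and all-token averaging.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import numpy as np


def softmax(x, temperature):
    z = x / temperature
    z = z - z.max(axis=-1, keepdims=True)  # for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def log_softmax(x, temperature):
    z = x / temperature
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def patch_loss(student_logits, teacher_logits, is_masked, all_tokens,
               student_temp=0.1, teacher_temp=0.04):
    """Cross-entropy between teacher and student patch distributions.

    student_logits, teacher_logits: (N, K) head outputs for N patches.
    The student saw the masked image; the teacher saw the unmasked one.
    all_tokens=False averages over masked patches only (iBOT);
    all_tokens=True averages over every patch (the iBOT++ idea).
    """
    targets = softmax(teacher_logits, teacher_temp)  # treated as constants
    per_patch = -(targets * log_softmax(student_logits, student_temp)).sum(axis=-1)
    weights = np.ones_like(per_patch) if all_tokens else is_masked.astype(float)
    return (per_patch * weights).sum() / weights.sum()


rng = np.random.default_rng(0)
student_logits = rng.normal(size=(256, 16))
teacher_logits = rng.normal(size=(256, 16))
is_masked = np.zeros(256, dtype=bool)
is_masked[:77] = True  # e.g. 77 masked patches, as in a 30% masking step
ibot = patch_loss(student_logits, teacher_logits, is_masked, all_tokens=False)
ibot_pp = patch_loss(student_logits, teacher_logits, is_masked, all_tokens=True)
print(ibot, ibot_pp)</code></pre></div>
<p>Under iBOT, only the 77 masked rows contribute to the loss. Under iBOT++, all 256 rows contribute, so every patch receives a gradient at every step.</p>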
</section>
<section id="quantitative-impact-of-ibot" class="level3">
<h3 class="anchored" data-anchor-id="quantitative-impact-of-ibot" id="quantitative-impact-of-ibot">Quantitative Impact of iBOT++</h3>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Tip
</div>
</div>
<div class="callout-body-container callout-body">
<p>Adding iBOT++ improved zero-shot semantic segmentation performance by <strong>+14.1 mIoU</strong> over the standard iBOT baseline, a large gain by any standard.</p>
</div>
</div>
<p>Qualitative visualizations confirm this: iBOT++ models produce attention maps and PCA-based feature visualizations that clearly delineate object boundaries, textures, and semantic regions far more distinctly than standard iBOT-trained counterparts.</p>
</section>
<section id="why-this-works-connecting-to-the-distillation-insight" class="level3">
<h3 class="anchored" data-anchor-id="why-this-works-connecting-to-the-distillation-insight" id="why-this-works-connecting-to-the-distillation-insight">Why This Works: Connecting to the Distillation Insight</h3>
<p>The connection to the distillation discovery is direct. In distillation, all patch positions receive a loss signal, and iBOT++ replicates this regime by applying the MIM-style loss to all positions. The momentum teacher plays the role of the large pretrained teacher in the distillation setup, with one difference: rather than being fixed, it evolves alongside the student throughout pretraining.</p>
<hr>
</section>
</section>
<section id="key-innovation-2-head-only-ema-efficient-teacher-student-training" class="level2">
<h2 class="anchored" data-anchor-id="key-innovation-2-head-only-ema-efficient-teacher-student-training" id="key-innovation-2-head-only-ema-efficient-teacher-student-training">Key Innovation 2 — Head-Only EMA: Efficient Teacher-Student Training</h2>
<section id="the-standard-ema-teacher-in-self-supervised-learning" class="level3">
<h3 class="anchored" data-anchor-id="the-standard-ema-teacher-in-self-supervised-learning" id="the-standard-ema-teacher-in-self-supervised-learning">The Standard EMA Teacher in Self-Supervised Learning</h3>
<p>In methods like DINO and iBOT, the teacher network is maintained as an exponential moving average (EMA) of the student’s parameters:</p>
<p><span class="math display">\[\theta_t \leftarrow \lambda \cdot \theta_t + (1 - \lambda) \cdot \theta_s\]</span></p>
<p>where <span class="math inline">\(\lambda\)</span> is a momentum coefficient (typically <span class="math inline">\(\approx 0.999\)</span>).</p>
<p>This produces a stable teacher that provides high-quality targets. <strong>However</strong>, maintaining a full teacher network means keeping a second complete copy of the backbone weights, which for large models adds substantially to memory use and training cost.</p>
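<p>The EMA update above amounts to one line per parameter. For self-containment, the sketch below stores parameters as a dictionary of NumPy arrays. This representation is an assumption; frameworks apply the same rule to every tensor in the teacher.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import numpy as np


def ema_update(teacher, student, momentum=0.999):
    """In-place EMA step: theta_t = momentum * theta_t + (1 - momentum) * theta_s.

    teacher, student: dicts mapping parameter names to float NumPy arrays
    of matching shapes. Only the teacher is modified.
    """
    for name, s_param in student.items():
        teacher[name] *= momentum
        teacher[name] += (1.0 - momentum) * s_param


student = {"w": np.ones(3)}
teacher = {"w": np.zeros(3)}
for _ in range(1000):
    ema_update(teacher, student, momentum=0.999)
print(teacher["w"])  # each entry is about 0.632, i.e. 1 - 0.999**1000</code></pre></div>
<p>With <span class="math inline">\(\lambda = 0.999\)</span>, the teacher closes only about 63% of the gap to a fixed student after 1,000 steps (<span class="math inline">\(1 - 0.999^{1000} \approx 0.632\)</span>). This slow, smooth tracking is what makes it a stable target generator.</p>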
</section>
<section id="the-head-only-ema-strategy" class="level3">
<h3 class="anchored" data-anchor-id="the-head-only-ema-strategy" id="the-head-only-ema-strategy">The Head-Only EMA Strategy</h3>
<p>TIPSv2 introduces a more efficient variant: <strong>Head-only EMA</strong>. The key enabling observation is that TIPSv2 has <strong>text supervision</strong>, which fundamentally changes training dynamics compared to purely self-supervised approaches.</p>
<p>With text supervision, the contrastive image-text alignment loss provides a powerful anchor that prevents collapse — even without a full-backbone EMA. The language signal enforces that representations must remain semantically meaningful and discriminative.</p>
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Important
</div>
</div>
<div class="callout-body-container callout-body">
<p>In <strong>Head-only EMA</strong>, the EMA update is applied only to the <strong>projection heads</strong> (lightweight MLP heads), while the teacher encoder is set equal to the student encoder at each step.</p>
</div>
</div>
<p>In effect:</p>
<ul>
<li>The teacher backbone <strong>is</strong> the student backbone (no separate copy is needed for the encoder).</li>
<li>Only the much smaller projection heads maintain EMA-updated parameters.</li>
</ul>
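<p>Structurally, Head-only EMA means the teacher reuses the student’s backbone object and keeps a separate, EMA-smoothed copy of the projection head only. The sketch below illustrates that bookkeeping. The parameter dictionaries, their sizes, and the function name are hypothetical, and this is not TIPSv2’s training code.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import numpy as np


def head_only_ema_teacher(student, teacher_head, momentum=0.999):
    """Build the teacher for one step under a head-only EMA scheme (sketch).

    student: dict with "backbone" and "head" entries, each a dict of arrays.
    teacher_head: dict of arrays holding the EMA copy of the head.
    The teacher backbone is the student backbone itself (shared, not
    copied); only the head is smoothed with an exponential moving average.
    """
    for name, s_param in student["head"].items():
        teacher_head[name] = momentum * teacher_head[name] + (1.0 - momentum) * s_param
    return {"backbone": student["backbone"], "head": teacher_head}


rng = np.random.default_rng(0)
student = {
    "backbone": {"w": rng.normal(size=(1024, 1024))},  # hypothetical sizes
    "head": {"w": rng.normal(size=(1024, 256))},
}
teacher_head = {k: v.copy() for k, v in student["head"].items()}
teacher = head_only_ema_teacher(student, teacher_head)
extra = sum(v.size for v in teacher_head.values())
print(teacher["backbone"] is student["backbone"])  # True: no backbone copy
print(extra)  # 262144 extra parameters, for the head only</code></pre></div>
<p>In this toy setup, only the head’s 262,144 parameters are duplicated. The backbone, which dominates the parameter count of a real ViT, is shared rather than copied.</p>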
</section>
<section id="benefits-of-head-only-ema" class="level3">
<h3 class="anchored" data-anchor-id="benefits-of-head-only-ema" id="benefits-of-head-only-ema">Benefits of Head-Only EMA</h3>
<table class="caption-top table">
<colgroup>
<col style="width: 50%">
<col style="width: 50%">
</colgroup>
<thead>
<tr class="header">
<th>Benefit</th>
<th>Details</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><strong>Memory Efficiency</strong></td>
<td>Eliminates the separate EMA copy of the teacher backbone; for a ViT-g with roughly 1B parameters, that copy alone occupies about 4 GB in fp32</td>
</tr>
<tr class="even">
<td><strong>Training Throughput</strong></td>
<td>~<strong>42% fewer parameters</strong> held during training, since the teacher backbone copy is dropped, with a corresponding improvement in throughput</td>
</tr>
<tr class="odd">
<td><strong>Performance Retention</strong></td>
<td>Performance is comparable to full EMA, demonstrating that text supervision prevents collapse</td>
</tr>
</tbody>
</table>
</section>
<section id="connection-to-distillation" class="level3">
<h3 class="anchored" data-anchor-id="connection-to-distillation" id="connection-to-distillation">Connection to Distillation</h3>
<p>Head-only EMA is also conceptually motivated by the distillation setting: in patch-level distillation, the teacher encoder is completely fixed. Head-only EMA approximates this in spirit — encoder-level EMA is eliminated, and only the projection heads maintain temporal momentum smoothing.</p>
<hr>
</section>
</section>
<section id="key-innovation-3-multi-granularity-captions-richer-text-supervision" class="level2">
<h2 class="anchored" data-anchor-id="key-innovation-3-multi-granularity-captions-richer-text-supervision" id="key-innovation-3-multi-granularity-captions-richer-text-supervision">Key Innovation 3 — Multi-Granularity Captions: Richer Text Supervision</h2>
<section id="the-problem-with-standard-image-text-pairs" class="level3">
<h3 class="anchored" data-anchor-id="the-problem-with-standard-image-text-pairs" id="the-problem-with-standard-image-text-pairs">The Problem with Standard Image-Text Pairs</h3>
<p>Most large-scale VLMs are trained on web-sourced captions that are short, noisy alt-text strings describing only the most salient element (e.g., “a cat”) without spatial or relational detail. The original TIPS model already demonstrated that <strong>synthetic captions</strong> significantly improve representation quality.</p>
<p>TIPSv2 takes this further with a multi-granularity approach.</p>
</section>
<section id="three-levels-of-textual-granularity" class="level3">
<h3 class="anchored" data-anchor-id="three-levels-of-textual-granularity" id="three-levels-of-textual-granularity">Three Levels of Textual Granularity</h3>
<p>During TIPSv2 pretraining, each image is paired with captions at three distinct granularity levels:</p>
<p><strong>Short captions (web-scale).</strong> Brief, general descriptions of the overall image content. They provide a coarse global semantic signal and help the model learn broad visual-semantic associations.</p>
<p><strong>Medium-length detailed captions (PaliGemma-generated).</strong> Descriptions generated by PaliGemma that name more objects, describe attributes (color, shape, texture, size), and capture spatial relationships. They provide a richer, intermediate-level signal.</p>
<p><strong>Long, comprehensive captions (Gemini-generated).</strong> Highly detailed, multi-sentence descriptions covering fine-grained attributes, scene context, inter-object relationships, spatial layout, and subtle semantic details. This is the richest and most informative level.</p>
</section>
<section id="caption-sampling-strategy" class="level3">
<h3 class="anchored" data-anchor-id="caption-sampling-strategy" id="caption-sampling-strategy">Caption Sampling Strategy</h3>
<p>A key design choice is the <strong>random sampling strategy</strong>: during training, for each image, the model randomly samples from the available caption granularities. This introduces diversity, prevents overfitting to any single caption style, and teaches the model to be robust to varying levels of textual specificity.</p>
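<p>The per-image sampling step can be sketched in a few lines. The source does not state the mixing proportions, so this sketch assumes uniform sampling by default and accepts optional weights. The function name is hypothetical. The short and medium captions are invented for illustration; the long one reuses this section’s fire-hydrant example.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import random


def sample_caption(captions, rng, weights=None):
    """Pick one caption granularity for an image at this training step.

    captions: dict mapping a granularity name ("short", "medium", ...)
    to a caption string. rng: a random.Random instance.
    weights: optional dict of sampling weights; uniform if omitted.
    """
    names = sorted(captions)  # fixed order so results are reproducible
    w = None if weights is None else [weights[name] for name in names]
    name = rng.choices(names, weights=w, k=1)[0]
    return name, captions[name]


captions = {
    "short": "a fire hydrant",
    "medium": "a red fire hydrant by the curb next to a yellow sign",
    "long": ("a red fire hydrant near the curb, partially obscured by autumn "
             "leaves, with a yellow parking sign to its left"),
}
rng = random.Random(0)
for _ in range(3):
    print(sample_caption(captions, rng))</code></pre></div>
<p>Because a fresh draw happens every time an image is seen, the same image is paired with different granularities over the course of training.</p>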
</section>
<section id="why-multi-granularity-captions-improve-patch-text-alignment" class="level3">
<h3 class="anchored" data-anchor-id="why-multi-granularity-captions-improve-patch-text-alignment" id="why-multi-granularity-captions-improve-patch-text-alignment">Why Multi-Granularity Captions Improve Patch-Text Alignment</h3>
<p>When a long, detailed caption describes <em>“a red fire hydrant near the curb, partially obscured by autumn leaves, with a yellow parking sign to its left,”</em> the model must develop image representations that encode these spatial and attribute details to align with the caption. This directly pushes patch representations toward being semantically informative about their local visual content.</p>
<hr>
</section>
</section>
<section id="pretraining-objectives-putting-it-all-together" class="level2">
<h2 class="anchored" data-anchor-id="pretraining-objectives-putting-it-all-together" id="pretraining-objectives-putting-it-all-together">Pretraining Objectives: Putting It All Together</h2>
<p>TIPSv2’s pretraining combines multiple objectives into a single training loss.</p>
<section id="contrastive-image-text-alignment-loss" class="level3">
<h3 class="anchored" data-anchor-id="contrastive-image-text-alignment-loss" id="contrastive-image-text-alignment-loss">Contrastive Image-Text Alignment Loss</h3>
<p>The foundational objective is a <strong>CLIP-style contrastive loss</strong> (or SigLIP-style sigmoid loss for the SO-400m variant) between global image embeddings and text embeddings:</p>
<ul>
<li>The image encoder produces a <code>[CLS]</code> token embedding for each image.</li>
<li>The text encoder produces an embedding for each caption.</li>
<li>A cross-modal contrastive loss (InfoNCE or sigmoid binary cross-entropy) aligns matched pairs and pushes apart mismatched pairs.</li>
</ul>
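<p>The InfoNCE variant of this objective can be sketched in NumPy as a symmetric cross-entropy over the batch similarity matrix. A temperature of 0.07 is CLIP’s initial value; here it is used as a fixed illustrative constant, whereas CLIP learns it. The sigmoid variant is not shown, and the function names are hypothetical.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import numpy as np


def log_softmax_rows(x):
    x = x - x.max(axis=1, keepdims=True)  # for numerical stability
    return x - np.log(np.exp(x).sum(axis=1, keepdims=True))


def clip_contrastive_loss(image_emb, text_emb, temperature=0.07):
    """Symmetric InfoNCE loss over a batch of B matched image-text pairs.

    image_emb, text_emb: (B, D) arrays. Row i of each forms a positive
    pair; every other row in the batch serves as a negative.
    """
    img = image_emb / np.linalg.norm(image_emb, axis=1, keepdims=True)
    txt = text_emb / np.linalg.norm(text_emb, axis=1, keepdims=True)
    logits = img @ txt.T / temperature  # (B, B) scaled cosine similarities
    diag = np.arange(len(logits))
    loss_i2t = -log_softmax_rows(logits)[diag, diag].mean()    # image to text
    loss_t2i = -log_softmax_rows(logits.T)[diag, diag].mean()  # text to image
    return 0.5 * (loss_i2t + loss_t2i)


rng = np.random.default_rng(0)
images = rng.normal(size=(8, 32))
texts = images + 0.1 * rng.normal(size=(8, 32))  # hypothetical aligned pairs
aligned = clip_contrastive_loss(images, texts)
mismatched = clip_contrastive_loss(images, texts[::-1])  # pairs scrambled
print(aligned, mismatched)  # the aligned batch has the much lower loss</code></pre></div>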
</section>
<section id="ibot-self-supervised-loss" class="level3">
<h3 class="anchored" data-anchor-id="ibot-self-supervised-loss" id="ibot-self-supervised-loss">iBOT++ Self-Supervised Loss</h3>
<p>The iBOT++ patch-level loss operates alongside the contrastive loss:</p>
<ol type="1">
<li>Two augmented views of each image are passed through the student encoder.</li>
<li>A momentum-updated teacher (with head-only EMA on projection heads) produces target representations.</li>
<li>For <strong>every</strong> patch token in both views, a distribution prediction loss is computed.</li>
<li>A <code>[CLS]</code>-level self-supervised classification loss (DINO-style) is also applied.</li>
</ol>
</section>
<section id="combined-loss-function" class="level3">
<h3 class="anchored" data-anchor-id="combined-loss-function" id="combined-loss-function">Combined Loss Function</h3>
<p>The final training loss is a weighted combination:</p>
<p><span class="math display">\[\mathcal{L}_{\text{total}} = \alpha \cdot \mathcal{L}_{\text{contrastive}} + \beta \cdot \mathcal{L}_{\text{iBOT++}}\]</span></p>
<p>where <span class="math inline">\(\alpha\)</span> and <span class="math inline">\(\beta\)</span> are hyperparameters balancing global alignment and dense patch-level alignment.</p>
<hr>
</section>
</section>
<section id="evaluation-protocol-9-tasks-20-datasets" class="level2">
<h2 class="anchored" data-anchor-id="evaluation-protocol-9-tasks-20-datasets" id="evaluation-protocol-9-tasks-20-datasets">Evaluation Protocol: 9 Tasks, 20 Datasets</h2>
<p>One of TIPSv2’s distinguishing features is the scope and rigor of its evaluation.</p>
<section id="global-image-text-tasks-7-evaluations" class="level3">
<h3 class="anchored" data-anchor-id="global-image-text-tasks-7-evaluations" id="global-image-text-tasks-7-evaluations">Global Image-Text Tasks (7 Evaluations)</h3>
<ul>
<li><strong>Zero-shot image classification</strong> (ImageNet) — standard measure of global semantic recognition</li>
<li><strong>Image-text retrieval</strong> — matching images to captions and vice versa (COCO, Flickr30k)</li>
<li><strong>Image captioning</strong> (DOCCI) — generating or retrieving descriptive captions</li>
</ul>
</section>
<section id="dense-image-understanding-tasks-9-evaluations" class="level3">
<h3 class="anchored" data-anchor-id="dense-image-understanding-tasks-9-evaluations" id="dense-image-understanding-tasks-9-evaluations">Dense Image Understanding Tasks (9 Evaluations)</h3>
<ul>
<li><strong>Zero-shot semantic segmentation</strong> — identifying and delineating semantic regions without task-specific fine-tuning (PASCAL VOC, ADE20k, COCO-Stuff, Pascal Context)</li>
<li><strong>Semantic segmentation with linear probing</strong> — evaluating patch features with a linear classifier</li>
<li><strong>Depth estimation</strong> (NYUv2) — monocular depth prediction from a single image with frozen features</li>
<li><strong>Open-vocabulary dense prediction</strong> — generalizing segmentation to unseen categories</li>
</ul>
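<p>The zero-shot segmentation protocol can be illustrated by labeling each patch with its most similar class prompt. This minimal sketch makes several simplifying assumptions: a single prompt per class, output at patch resolution, and no upsampling or prompt ensembling, all of which real evaluation pipelines typically add. The toy embeddings and the function name are hypothetical.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import numpy as np


def zero_shot_segment(patch_tokens, class_text_emb, grid_hw):
    """Label each patch with the class whose text embedding is most similar.

    patch_tokens: (N, D) frozen patch features.
    class_text_emb: (C, D) embeddings of one prompt per class,
    e.g. "a photo of a dog".
    Returns a (rows, cols) array of class indices at patch resolution.
    """
    p = patch_tokens / np.linalg.norm(patch_tokens, axis=1, keepdims=True)
    t = class_text_emb / np.linalg.norm(class_text_emb, axis=1, keepdims=True)
    return (p @ t.T).argmax(axis=1).reshape(grid_hw)


rng = np.random.default_rng(0)
classes = rng.normal(size=(2, 64))  # hypothetical "sky" and "grass" prompts
patches = np.concatenate([
    classes[0] + 0.1 * rng.normal(size=(128, 64)),  # top half of the grid
    classes[1] + 0.1 * rng.normal(size=(128, 64)),  # bottom half
])
labels = zero_shot_segment(patches, classes, (16, 16))
print(labels)</code></pre></div>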
</section>
<section id="evaluation-regime-frozen-features" class="level3">
<h3 class="anchored" data-anchor-id="evaluation-regime-frozen-features" id="evaluation-regime-frozen-features">Evaluation Regime: Frozen Features</h3>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Note
</div>
</div>
<div class="callout-body-container callout-body">
<p>Most benchmarks are conducted with <strong>frozen encoder features</strong>: the encoder weights are not fine-tuned on the downstream task, so results directly reflect the quality of the pretrained representations. This makes it a demanding and informative evaluation regime for foundation models.</p>
</div>
</div>
<hr>
</section>
</section>
<section id="experimental-results-and-benchmarks" class="level2">
<h2 class="anchored" data-anchor-id="experimental-results-and-benchmarks" id="experimental-results-and-benchmarks">Experimental Results and Benchmarks</h2>
<section id="dense-understanding-segmentation" class="level3">
<h3 class="anchored" data-anchor-id="dense-understanding-segmentation" id="dense-understanding-segmentation">Dense Understanding: Segmentation</h3>
<p>TIPSv2 achieves <strong>state-of-the-art performance on all four zero-shot semantic segmentation benchmarks</strong> evaluated:</p>
<ul>
<li><strong>iBOT++ alone</strong> improves zero-shot segmentation by <strong>+14.1 mIoU</strong> vs.&nbsp;the standard iBOT baseline.</li>
<li>TIPSv2 outperforms both <strong>SILC</strong> and <strong>DINOv2</strong> across all four segmentation datasets.</li>
<li>Qualitative results on PASCAL VOC and COCO-Stuff show cleanly delineated semantic boundaries.</li>
</ul>
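<p>mIoU averages the per-class intersection-over-union between predicted and ground-truth label maps. A +14.1 mIoU gain therefore means 14.1 percentage points on that scale. The sketch below computes mIoU for a single image. Standard benchmark code usually accumulates intersections and unions over the whole dataset and handles an ignore label; both are omitted here.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import numpy as np


def mean_iou(pred, target, num_classes):
    """Mean intersection-over-union over classes present in pred or target.

    pred, target: integer label maps of the same shape.
    Classes that appear in neither map are skipped rather than scored as 0.
    """
    ious = []
    for c in range(num_classes):
        p, t = pred == c, target == c
        union = np.logical_or(p, t).sum()
        if union == 0:
            continue
        ious.append(np.logical_and(p, t).sum() / union)
    return float(np.mean(ious))


target = np.array([[0, 0, 1, 1],
                   [0, 0, 1, 1]])
pred = np.array([[0, 0, 0, 1],
                 [0, 0, 1, 1]])
print(round(mean_iou(pred, target, num_classes=2), 3))  # 0.775 = (0.8 + 0.75) / 2</code></pre></div>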
</section>
<section id="global-tasks-classification-and-retrieval" class="level3">
<h3 class="anchored" data-anchor-id="global-tasks-classification-and-retrieval" id="global-tasks-classification-and-retrieval">Global Tasks: Classification and Retrieval</h3>
<ul>
<li>Achieves best or second-best performance in <strong>5 out of 7 global evaluations</strong>.</li>
<li>On COCO image-text retrieval and DOCCI captioning, <strong>TIPSv2 outperforms models with 56% more parameters</strong>.</li>
<li>Zero-shot ImageNet classification remains strong — dense alignment improvements do not compromise global discriminability.</li>
</ul>
</section>
<section id="depth-estimation" class="level3">
<h3 class="anchored" data-anchor-id="depth-estimation" id="depth-estimation">Depth Estimation</h3>
<p>On NYUv2 monocular depth estimation with frozen features, TIPSv2 achieves best or second-best results, suggesting that spatially coherent patch representations also encode meaningful metric depth information.</p>
</section>
<section id="summary-best-or-second-best-across-the-board" class="level3">
<h3 class="anchored" data-anchor-id="summary-best-or-second-best-across-the-board" id="summary-best-or-second-best-across-the-board">Summary: Best or Second-Best Across the Board</h3>
<table class="caption-top table">
<thead>
<tr class="header">
<th>Category</th>
<th>Performance</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>Global evaluations</td>
<td>Best or 2nd-best in <strong>5 of 7</strong></td>
</tr>
<tr class="even">
<td>Dense understanding evaluations</td>
<td>Best or 2nd-best in <strong>7 of 9</strong></td>
</tr>
<tr class="odd">
<td>Zero-shot segmentation benchmarks</td>
<td><strong>State-of-the-art on all 4</strong></td>
</tr>
</tbody>
</table>
<p>This breadth of strong performance across qualitatively different task types is unusual. Most models specialize in either global alignment (CLIP) or dense tasks (DINOv2); TIPSv2 achieves strong results on both families simultaneously.</p>
<hr>
</section>
</section>
<section id="comparison-with-prior-work" class="level2">
<h2 class="anchored" data-anchor-id="comparison-with-prior-work" id="comparison-with-prior-work">Comparison with Prior Work</h2>
<section id="vs.-clip-siglip" class="level3">
<h3 class="anchored" data-anchor-id="vs.-clip-siglip" id="vs.-clip-siglip">vs.&nbsp;CLIP / SigLIP</h3>
<p><strong>CLIP</strong> and <strong>SigLIP</strong> excel at image classification and image-text retrieval but have limited spatial awareness due to their purely global training objective. TIPSv2 significantly outperforms them on dense tasks while remaining competitive on global tasks.</p>
</section>
<section id="vs.-dinov2" class="level3">
<h3 class="anchored" data-anchor-id="vs.-dinov2" id="vs.-dinov2">vs.&nbsp;DINOv2</h3>
<p><strong>DINOv2</strong> is known for excellent patch-level representations and strong dense task performance. However, DINOv2 has no text alignment — it cannot support cross-modal retrieval or language-grounded zero-shot classification. TIPSv2 surpasses DINOv2 on zero-shot segmentation while also performing strongly on text-grounded tasks that DINOv2 cannot natively address.</p>
</section>
<section id="vs.-silc" class="level3">
<h3 class="anchored" data-anchor-id="vs.-silc" id="vs.-silc">vs.&nbsp;SILC</h3>
<p><strong>SILC</strong> combines self-supervised and image-text learning objectives, making it a close conceptual relative of TIPS and TIPSv2. TIPSv2 outperforms SILC on dense segmentation benchmarks, demonstrating that iBOT++ and multi-granularity captions provide meaningful gains.</p>
</section>
<section id="vs.-pe-core-vit-g" class="level3">
<h3 class="anchored" data-anchor-id="vs.-pe-core-vit-g" id="vs.-pe-core-vit-g">vs.&nbsp;PE-core ViT-G</h3>
<p><strong>PE-core</strong> (Perception Encoder) ViT-G is a substantially larger vision-language model, with roughly 56% more parameters. Despite this capacity gap, TIPSv2 outperforms PE-core ViT-G on the COCO and DOCCI evaluations.</p>
</section>
<section id="vs.-tips-v1" class="level3">
<h3 class="anchored" data-anchor-id="vs.-tips-v1" id="vs.-tips-v1">vs.&nbsp;TIPS (v1)</h3>
<p>TIPSv2 improves upon TIPS on virtually all benchmarks, with the most pronounced gains on dense tasks. iBOT++ accounts for the bulk of the dense task improvement, multi-granularity captions primarily improve global text-image tasks, and head-only EMA improves training efficiency without sacrificing performance.</p>
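<p>iBOT++, as summarized in this post, extends the patch-level self-distillation loss from the masked tokens (as in iBOT) to all tokens. The NumPy sketch below isolates that difference with random logits: the same teacher-to-student cross-entropy is averaged either over masked positions only or over every position. The temperatures and function names are assumptions, and teacher centering and the separate masked/unmasked forward passes are omitted; this is an illustration, not the TIPSv2 training code.</p>

```python
import numpy as np


def softmax(x, temp):
    z = x / temp
    z = z - z.max(axis=-1, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def log_softmax(x, temp):
    z = x / temp
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def patch_distill_loss(student_logits, teacher_logits, token_weights,
                       t_student=0.1, t_teacher=0.04):
    """Teacher-to-student cross-entropy over prototypes, averaged over the
    tokens selected by token_weights (shape: N_tokens)."""
    # Teacher side is a fixed target (stop-gradient in real training)
    p_teacher = softmax(teacher_logits, t_teacher)
    log_p_student = log_softmax(student_logits, t_student)
    per_token = -(p_teacher * log_p_student).sum(axis=-1)  # (N_tokens,)
    w = token_weights.astype(float)
    return float((per_token * w).sum() / w.sum())


rng = np.random.default_rng(0)
n_tokens, n_prototypes = 16, 32
student = rng.normal(size=(n_tokens, n_prototypes))
teacher = rng.normal(size=(n_tokens, n_prototypes))
mask = rng.random(n_tokens) < 0.4  # True = token was masked in the student input
mask[0] = True  # guarantee at least one masked token in this toy example

ibot_loss = patch_distill_loss(student, teacher, mask)  # masked tokens only
ibot_pp_loss = patch_distill_loss(student, teacher, np.ones(n_tokens, bool))  # all tokens
print(f"iBOT: {ibot_loss:.3f}  iBOT++: {ibot_pp_loss:.3f}")
```

<p>In the iBOT-style variant, visible tokens contribute nothing to the patch loss; in the iBOT++-style variant, every patch token receives a training signal.</p>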
<hr>
</section>
</section>
<section id="practical-applications-and-downstream-tasks" class="level2">
<h2 class="anchored" data-anchor-id="practical-applications-and-downstream-tasks" id="practical-applications-and-downstream-tasks">Practical Applications and Downstream Tasks</h2>
<section id="zero-shot-semantic-segmentation" class="level3">
<h3 class="anchored" data-anchor-id="zero-shot-semantic-segmentation" id="zero-shot-semantic-segmentation">Zero-Shot Semantic Segmentation</h3>
<p>TIPSv2’s strong patch-text alignment makes it directly applicable to open-vocabulary semantic segmentation without task-specific fine-tuning. By computing cosine similarity between patch embeddings and text embeddings of class names, one can generate segmentation maps that correctly delineate semantic regions.</p>
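<p>The procedure just described can be sketched independently of any particular model. The NumPy example below uses random arrays in place of real encoder outputs (an assumption for illustration): it L2-normalizes patch and text embeddings so that a dot product equals cosine similarity, assigns each patch its best-matching class, reshapes the labels to the patch grid, and upsamples them to pixel resolution by nearest-neighbor repetition. The 16×16 grid and patch size of 14 are illustrative, matching a 224-pixel input to a ViT-B/14.</p>

```python
import numpy as np


def l2_normalize(x, axis=-1, eps=1e-8):
    return x / (np.linalg.norm(x, axis=axis, keepdims=True) + eps)


def zero_shot_segment(patch_emb, text_emb, grid_hw, patch_size):
    """Label each patch with its most similar class, then upsample to pixels.

    patch_emb: (N_patches, D) patch embeddings in row-major grid order
    text_emb:  (N_classes, D) class-name embeddings in the same space
    grid_hw:   (rows, cols) of the patch grid, with rows * cols == N_patches
    """
    sims = l2_normalize(patch_emb) @ l2_normalize(text_emb).T  # cosine similarities
    patch_labels = sims.argmax(axis=1).reshape(grid_hw)
    # Nearest-neighbor upsampling: repeat each patch label over its pixel block
    return np.repeat(np.repeat(patch_labels, patch_size, axis=0), patch_size, axis=1)


rng = np.random.default_rng(0)
class_names = ["sky", "tree", "road"]
text_emb = rng.normal(size=(len(class_names), 64))  # stand-in text embeddings

# Stand-in patch embeddings: top half close to "sky", bottom half close to "road"
grid = (16, 16)
labels_true = np.zeros(grid, dtype=int)
labels_true[8:, :] = 2
patch_emb = text_emb[labels_true.ravel()] + 0.1 * rng.normal(size=(grid[0] * grid[1], 64))

mask = zero_shot_segment(patch_emb, text_emb, grid, patch_size=14)
print(mask.shape)  # (224, 224)
```

<p>Upsampling the per-class similarity scores (for example bilinearly) before taking the argmax is a common alternative that gives smoother boundaries than repeating hard labels.</p>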
</section>
<section id="multimodal-retrieval-and-search" class="level3">
<h3 class="anchored" data-anchor-id="multimodal-retrieval-and-search" id="multimodal-retrieval-and-search">Multimodal Retrieval and Search</h3>
<p>The strong global image-text alignment makes TIPSv2 suitable as a backbone for large-scale multimodal search engines. Applications range from e-commerce visual search to scientific image database querying.</p>
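<p>At its core, embedding-based retrieval ranks a gallery by cosine similarity to the query. The NumPy sketch below shows text-to-image search over a small index of image embeddings, together with the Recall@K metric commonly reported for COCO-style retrieval. The embeddings are random stand-ins (an assumption), and a production search engine would typically replace the brute-force matrix product with an approximate nearest-neighbor index.</p>

```python
import numpy as np


def l2_normalize(x, eps=1e-8):
    return x / (np.linalg.norm(x, axis=-1, keepdims=True) + eps)


def top_k(query_emb, gallery_emb, k):
    """Indices of the k gallery items most similar (cosine) to each query."""
    sims = l2_normalize(query_emb) @ l2_normalize(gallery_emb).T  # (Q, G)
    return np.argsort(-sims, axis=1)[:, :k]


def recall_at_k(query_emb, gallery_emb, gt_index, k):
    """Fraction of queries whose ground-truth item appears in the top k."""
    hits = top_k(query_emb, gallery_emb, k) == gt_index[:, None]
    return float(hits.any(axis=1).mean())


rng = np.random.default_rng(0)
image_emb = rng.normal(size=(1000, 128))  # indexed images
# Paired captions: noisy copies of the first 50 image embeddings
text_emb = image_emb[:50] + 0.5 * rng.normal(size=(50, 128))
gt = np.arange(50)  # caption i describes image i

print("Recall@1:", recall_at_k(text_emb, image_emb, gt, k=1))
print("Recall@5:", recall_at_k(text_emb, image_emb, gt, k=5))
```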
</section>
<section id="monocular-depth-estimation" class="level3">
<h3 class="anchored" data-anchor-id="monocular-depth-estimation" id="monocular-depth-estimation">Monocular Depth Estimation</h3>
<p>The spatially coherent patch features encode metric depth information surprisingly well, enabling monocular depth estimation with simple linear probing. Applications include robotics, augmented reality, and 3D scene understanding.</p>
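<p>A “linear probe” keeps the encoder frozen and fits only a linear map from its features to the target. The NumPy sketch below fits a ridge-regression probe from per-patch features to per-patch depth in closed form. Synthetic features with a planted linear relationship stand in for real TIPSv2 patch embeddings, and the ridge strength is arbitrary; published depth-probing protocols differ in details, for instance some treat depth as classification over discretized bins and upsample predictions to pixel resolution.</p>

```python
import numpy as np


def fit_ridge_probe(features, targets, alpha=1e-2):
    """Closed-form ridge regression: W = (X^T X + alpha I)^-1 X^T y.

    features: (N, D) frozen patch features; a bias column is appended.
    targets:  (N,) per-patch depth values.
    """
    X = np.hstack([features, np.ones((features.shape[0], 1))])
    reg = alpha * np.eye(X.shape[1])
    reg[-1, -1] = 0.0  # do not penalize the bias term
    return np.linalg.solve(X.T @ X + reg, X.T @ targets)


def predict_depth(features, weights):
    X = np.hstack([features, np.ones((features.shape[0], 1))])
    return X @ weights


rng = np.random.default_rng(0)
n_train, n_test, dim = 4000, 500, 32
true_w = rng.normal(size=dim)
feats = rng.normal(size=(n_train + n_test, dim))
# Synthetic depth in meters: linear in the features, plus small noise
depth = feats @ true_w + 2.0 + 0.05 * rng.normal(size=n_train + n_test)

w = fit_ridge_probe(feats[:n_train], depth[:n_train])
pred = predict_depth(feats[n_train:], w)
rmse = np.sqrt(np.mean((pred - depth[n_train:]) ** 2))
print(f"test RMSE: {rmse:.3f} m")
```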
</section>
<section id="foundation-for-multimodal-large-language-models" class="level3">
<h3 class="anchored" data-anchor-id="foundation-for-multimodal-large-language-models" id="foundation-for-multimodal-large-language-models">Foundation for Multimodal Large Language Models</h3>
<p>High-quality vision encoders are a critical component of MLLMs such as PaLI, LLaVA, InstructBLIP, and Gemini. TIPSv2’s combination of strong global text alignment and rich patch-level semantics makes it an excellent candidate as a visual backbone for MLLMs.</p>
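<p>In LLaVA-style MLLMs, the vision encoder’s patch features are mapped into the language model’s token-embedding space by a small projector (a linear layer or a two-layer MLP) and then placed in the input sequence alongside the text tokens. The NumPy sketch below shows that data flow with a randomly initialized two-layer MLP; the dimensions (768 for a ViT-B encoder, 2048 for the LLM), the class name, and the placement of visual tokens before the text are illustrative assumptions, not a TIPSv2 integration.</p>

```python
import numpy as np


def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


class MLPProjector:
    """Two-layer MLP mapping vision features (d_vis) to LLM embeddings (d_llm)."""

    def __init__(self, d_vis, d_llm, rng):
        self.w1 = rng.normal(scale=d_vis ** -0.5, size=(d_vis, d_llm))
        self.b1 = np.zeros(d_llm)
        self.w2 = rng.normal(scale=d_llm ** -0.5, size=(d_llm, d_llm))
        self.b2 = np.zeros(d_llm)

    def __call__(self, x):
        return gelu(x @ self.w1 + self.b1) @ self.w2 + self.b2


rng = np.random.default_rng(0)
d_vis, d_llm = 768, 2048
patch_features = rng.normal(size=(256, d_vis))    # frozen encoder output, e.g. a 16x16 grid
text_token_embeds = rng.normal(size=(12, d_llm))  # embedded prompt tokens

projector = MLPProjector(d_vis, d_llm, rng)
visual_tokens = projector(patch_features)  # (256, d_llm)

# In this sketch, visual tokens simply precede the text tokens in the LLM input
llm_input = np.concatenate([visual_tokens, text_token_embeds], axis=0)
print(llm_input.shape)  # (268, 2048)
```

<p>Typically the encoder stays frozen and only the projector (and optionally the LLM) is trained, so the quality of the patch features sets an upper bound on what the MLLM can see.</p>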
</section>
<section id="zero-shot-visual-question-answering" class="level3">
<h3 class="anchored" data-anchor-id="zero-shot-visual-question-answering" id="zero-shot-visual-question-answering">Zero-Shot Visual Question Answering</h3>
<p>Because TIPSv2 patch representations carry rich spatial semantics, downstream VQA models built on them may localize the regions relevant to a question more accurately, particularly for questions that require spatial reasoning.</p>
</section>
<section id="referring-expression-comprehension" class="level3">
<h3 class="anchored" data-anchor-id="referring-expression-comprehension" id="referring-expression-comprehension">Referring Expression Comprehension</h3>
<p>TIPSv2’s multi-granularity caption training, which mixes short, medium, and long synthetic captions, makes it well suited to fine-grained grounded comprehension of expressions such as “the second person from the left wearing a red hat.”</p>
<hr>
</section>
</section>
<section id="model-weights-and-usage" class="level2">
<h2 class="anchored" data-anchor-id="model-weights-and-usage" id="model-weights-and-usage">Model Weights and Usage</h2>
<section id="publicly-released-models" class="level3">
<h3 class="anchored" data-anchor-id="publicly-released-models" id="publicly-released-models">Publicly Released Models</h3>
<p>The TIPSv2 team has released pre-trained model weights via Hugging Face:</p>
<ul>
<li><strong><code>google/tipsv2-b14</code></strong>: ViT-B/14 model, distilled from a ViT-g teacher</li>
<li>Additional model sizes (ViT-L, ViT-g) are available via the <a href="https://gdm-tipsv2.github.io/">project page</a></li>
</ul>
</section>
<section id="code-repository" class="level3">
<h3 class="anchored" data-anchor-id="code-repository" id="code-repository">Code Repository</h3>
<p>Full training and evaluation code is at <a href="https://github.com/google-deepmind/tips">github.com/google-deepmind/tips</a>, covering both TIPSv2 (CVPR 2026) and TIPS (ICLR 2025), including pretraining code, distillation pipeline, evaluation scripts, and pre-trained checkpoints.</p>
</section>
<section id="example-usage-huggingface" class="level3">
<h3 class="anchored" data-anchor-id="example-usage-huggingface" id="example-usage-huggingface">Example Usage (HuggingFace)</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> transformers <span class="im">import</span> AutoModel, AutoProcessor</span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> PIL <span class="im">import</span> Image</span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Load model and processor. If the checkpoint ships custom modeling code,</span></span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a><span class="co"># from_pretrained also needs trust_remote_code=True (check the model card).</span></span>
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> AutoModel.from_pretrained(<span class="st">"google/tipsv2-b14"</span>)</span>
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a>processor <span class="op">=</span> AutoProcessor.from_pretrained(<span class="st">"google/tipsv2-b14"</span>)</span>
<span id="cb1-9"><a href="#cb1-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-10"><a href="#cb1-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Encode an image (convert to RGB so grayscale or RGBA files also work)</span></span>
<span id="cb1-11"><a href="#cb1-11" aria-hidden="true" tabindex="-1"></a>image <span class="op">=</span> Image.<span class="bu">open</span>(<span class="st">"example.jpg"</span>).convert(<span class="st">"RGB"</span>)</span>
<span id="cb1-12"><a href="#cb1-12" aria-hidden="true" tabindex="-1"></a>inputs <span class="op">=</span> processor(images<span class="op">=</span>image, return_tensors<span class="op">=</span><span class="st">"pt"</span>)</span>
<span id="cb1-13"><a href="#cb1-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-14"><a href="#cb1-14" aria-hidden="true" tabindex="-1"></a><span class="cf">with</span> torch.no_grad():</span>
<span id="cb1-15"><a href="#cb1-15" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Assumes the output exposes last_hidden_state; in CLIPModel, by contrast,</span></span>
<span id="cb1-16"><a href="#cb1-16" aria-hidden="true" tabindex="-1"></a>    <span class="co"># get_image_features returns a pooled embedding tensor.</span></span>
<span id="cb1-17"><a href="#cb1-17" aria-hidden="true" tabindex="-1"></a>    outputs <span class="op">=</span> model.get_image_features(<span class="op">**</span>inputs)</span>
<span id="cb1-18"><a href="#cb1-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-19"><a href="#cb1-19" aria-hidden="true" tabindex="-1"></a><span class="co"># Patch-level representations (assumes one leading [CLS] token, no registers)</span></span>
<span id="cb1-20"><a href="#cb1-20" aria-hidden="true" tabindex="-1"></a>patch_features <span class="op">=</span> outputs.last_hidden_state[:, <span class="dv">1</span>:, :]</span>
<span id="cb1-21"><a href="#cb1-21" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-22"><a href="#cb1-22" aria-hidden="true" tabindex="-1"></a><span class="co"># Global [CLS] representation</span></span>
<span id="cb1-23"><a href="#cb1-23" aria-hidden="true" tabindex="-1"></a>cls_feature <span class="op">=</span> outputs.last_hidden_state[:, <span class="dv">0</span>, :]</span></code></pre></div></div>
</section>
<section id="zero-shot-segmentation-example" class="level3">
<h3 class="anchored" data-anchor-id="zero-shot-segmentation-example" id="zero-shot-segmentation-example">Zero-Shot Segmentation Example</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn.functional <span class="im">as</span> F</span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> PIL <span class="im">import</span> Image</span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> transformers <span class="im">import</span> AutoModel, AutoTokenizer, AutoProcessor</span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> AutoModel.from_pretrained(<span class="st">"google/tipsv2-b14"</span>)</span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a>processor <span class="op">=</span> AutoProcessor.from_pretrained(<span class="st">"google/tipsv2-b14"</span>)</span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a>tokenizer <span class="op">=</span> AutoTokenizer.from_pretrained(<span class="st">"google/tipsv2-b14"</span>)</span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-10"><a href="#cb2-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Class names for zero-shot segmentation</span></span>
<span id="cb2-11"><a href="#cb2-11" aria-hidden="true" tabindex="-1"></a>class_names <span class="op">=</span> [<span class="st">"sky"</span>, <span class="st">"tree"</span>, <span class="st">"road"</span>, <span class="st">"car"</span>, <span class="st">"person"</span>, <span class="st">"building"</span>]</span>
<span id="cb2-12"><a href="#cb2-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-13"><a href="#cb2-13" aria-hidden="true" tabindex="-1"></a><span class="co"># Encode class names as text</span></span>
<span id="cb2-14"><a href="#cb2-14" aria-hidden="true" tabindex="-1"></a>text_inputs <span class="op">=</span> tokenizer(class_names, padding<span class="op">=</span><span class="va">True</span>, return_tensors<span class="op">=</span><span class="st">"pt"</span>)</span>
<span id="cb2-15"><a href="#cb2-15" aria-hidden="true" tabindex="-1"></a><span class="cf">with</span> torch.no_grad():</span>
<span id="cb2-16"><a href="#cb2-16" aria-hidden="true" tabindex="-1"></a>    text_features <span class="op">=</span> model.get_text_features(<span class="op">**</span>text_inputs)  <span class="co"># (N_classes, D)</span></span>
<span id="cb2-17"><a href="#cb2-17" aria-hidden="true" tabindex="-1"></a>    text_features <span class="op">=</span> F.normalize(text_features, dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb2-18"><a href="#cb2-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-19"><a href="#cb2-19" aria-hidden="true" tabindex="-1"></a><span class="co"># Encode image and get patch features</span></span>
<span id="cb2-20"><a href="#cb2-20" aria-hidden="true" tabindex="-1"></a>image <span class="op">=</span> Image.<span class="bu">open</span>(<span class="st">"scene.jpg"</span>).convert(<span class="st">"RGB"</span>)</span>
<span id="cb2-21"><a href="#cb2-21" aria-hidden="true" tabindex="-1"></a>img_inputs <span class="op">=</span> processor(images<span class="op">=</span>image, return_tensors<span class="op">=</span><span class="st">"pt"</span>)</span>
<span id="cb2-22"><a href="#cb2-22" aria-hidden="true" tabindex="-1"></a><span class="cf">with</span> torch.no_grad():</span>
<span id="cb2-23"><a href="#cb2-23" aria-hidden="true" tabindex="-1"></a>    img_outputs <span class="op">=</span> model.get_image_features(<span class="op">**</span>img_inputs)</span>
<span id="cb2-24"><a href="#cb2-24" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Assumes patch tokens share the text embedding space (same D) and that</span></span>
<span id="cb2-25"><a href="#cb2-25" aria-hidden="true" tabindex="-1"></a>    <span class="co"># index 0 is the only non-patch token; check the model card.</span></span>
<span id="cb2-26"><a href="#cb2-26" aria-hidden="true" tabindex="-1"></a>    patch_features <span class="op">=</span> img_outputs.last_hidden_state[:, <span class="dv">1</span>:, :]  <span class="co"># (B, N_patches, D)</span></span>
<span id="cb2-27"><a href="#cb2-27" aria-hidden="true" tabindex="-1"></a>    patch_features <span class="op">=</span> F.normalize(patch_features, dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb2-28"><a href="#cb2-28" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-29"><a href="#cb2-29" aria-hidden="true" tabindex="-1"></a><span class="co"># Cosine similarity of every patch with every class: (B, N_patches, N_classes)</span></span>
<span id="cb2-30"><a href="#cb2-30" aria-hidden="true" tabindex="-1"></a>similarity <span class="op">=</span> torch.einsum(<span class="st">"bpd,cd-&gt;bpc"</span>, patch_features, text_features)</span>
<span id="cb2-31"><a href="#cb2-31" aria-hidden="true" tabindex="-1"></a>segmentation_map <span class="op">=</span> similarity.argmax(dim<span class="op">=-</span><span class="dv">1</span>)  <span class="co"># (B, N_patches) class indices</span></span></code></pre></div></div>
<hr>
</section>
</section>
<section id="broader-impact-and-limitations" class="level2">
<h2 class="anchored" data-anchor-id="broader-impact-and-limitations" id="broader-impact-and-limitations">Broader Impact and Limitations</h2>
<section id="positive-impacts" class="level3">
<h3 class="anchored" data-anchor-id="positive-impacts" id="positive-impacts">Positive Impacts</h3>
<p>TIPSv2’s strong patch-text alignment capabilities have the potential to significantly advance:</p>
<ul>
<li><strong>Accessibility technology</strong> — more accurate image descriptions for visually impaired users</li>
<li><strong>Medical imaging</strong> — precise region-level understanding without expensive annotation</li>
<li><strong>Scientific image analysis</strong> — automated understanding of spatial patterns in microscopy, satellite imagery, etc.</li>
<li><strong>Robotics and embodied AI</strong> — spatially grounded understanding for manipulation and navigation</li>
<li><strong>Efficient AI</strong> — the head-only EMA strategy reduces training resource requirements</li>
</ul>
</section>
<section id="potential-concerns" class="level3">
<h3 class="anchored" data-anchor-id="potential-concerns" id="potential-concerns">Potential Concerns</h3>
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Warning
</div>
</div>
<div class="callout-body-container callout-body">
<p>Like all large vision-language models, TIPSv2 inherits risks associated with this class of systems.</p>
</div>
</div>
<p><strong>Bias and fairness.</strong> Models trained on web-scale data may encode societal biases. The use of synthetic captions from PaliGemma and Gemini could propagate or transform existing biases.</p>
<p><strong>Privacy.</strong> Large models trained on web-scraped image-text pairs may have memorized aspects of training data.</p>
<p><strong>Misuse.</strong> Highly capable vision-language encoders can be components of surveillance systems or other dual-use applications.</p>
</section>
<section id="limitations" class="level3">
<h3 class="anchored" data-anchor-id="limitations" id="limitations">Limitations</h3>
<p><strong>Dense task performance vs.&nbsp;task-specific models.</strong> While TIPSv2 achieves impressive zero-shot and frozen-feature performance, fully fine-tuned task-specific models (Mask2Former, Depth Anything) typically outperform frozen foundation models on their target benchmarks.</p>
<p><strong>Text encoder scope.</strong> TIPSv2’s text encoder is not a large language model — its language understanding is bounded by what can be learned from paired image-text training.</p>
<p><strong>Compute requirements at scale.</strong> Despite the efficiency gains from head-only EMA, training ViT-g scale models with the full pretraining objective still requires significant computational resources.</p>
<hr>
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>TIPSv2 represents a carefully engineered and empirically grounded advance in vision-language pretraining. By tracing its design choices back to a single surprising empirical observation — that patch-level distillation produces better patch-text alignment than the teacher model itself — the paper develops a coherent set of three complementary innovations:</p>
<p><strong>iBOT++</strong> extends self-supervised patch-level loss to all tokens, delivering +14.1 mIoU gains on zero-shot segmentation alone.</p>
<p><strong>Head-only EMA</strong> leverages the text supervision signal to eliminate the need for a full-backbone EMA teacher, reducing the number of parameters maintained during training by ~42% and improving throughput without sacrificing performance.</p>
<p><strong>Multi-Granularity Captions</strong> provides richer, spatially-detailed text supervision by mixing short, medium, and long synthetic captions from PaliGemma and Gemini.</p>
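<p>To make the head-only EMA idea concrete, the NumPy sketch below contrasts it with a conventional full EMA teacher. It assumes, for illustration only, that the teacher reuses the student’s current backbone weights directly and keeps an exponential moving average of the projection-head weights alone. The parameter shapes, momentum, and dictionary layout are hypothetical and do not reproduce the paper’s architecture or its ~42% figure.</p>

```python
import numpy as np


def ema_update(teacher, student, momentum):
    """In-place EMA: teacher <- m * teacher + (1 - m) * student, per array."""
    for name, s_param in student.items():
        teacher[name] *= momentum
        teacher[name] += (1.0 - momentum) * s_param


def count_params(params):
    return sum(p.size for p in params.values())


def teacher_features(x, backbone, teacher_head):
    """Head-only teacher forward pass: current student backbone, EMA head."""
    return (x @ backbone["blocks"]) @ teacher_head["proj"]


rng = np.random.default_rng(0)
# Hypothetical student parameters: a large backbone and a small projection head
backbone = {"blocks": rng.normal(size=(1000, 100))}
head = {"proj": rng.normal(size=(100, 64))}

# Conventional EMA teacher: separate copies of backbone AND head
full_teacher = {k: v.copy() for k, v in {**backbone, **head}.items()}
# Head-only EMA (assumed form): reuse `backbone` as-is, keep an EMA copy of the head
teacher_head = {k: v.copy() for k, v in head.items()}

for _ in range(3):
    # Stand-in for an optimizer step on the student
    backbone["blocks"] += 0.01 * rng.normal(size=(1000, 100))
    head["proj"] += 0.01 * rng.normal(size=(100, 64))
    ema_update(full_teacher, {**backbone, **head}, momentum=0.996)
    ema_update(teacher_head, head, momentum=0.996)

print("extra params, full EMA:", count_params(full_teacher))       # 106400
print("extra params, head-only EMA:", count_params(teacher_head))  # 6400
```

<p>In this toy setup the savings come entirely from not storing and updating a second copy of the backbone, which dominates the parameter count.</p>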
<p>Together, these innovations produce a model family that achieves state-of-the-art performance on all four zero-shot segmentation benchmarks, best or second-best on the majority of its 20-dataset evaluation suite, and strong global image-text alignment — often matching or surpassing models with significantly more parameters.</p>
<p>TIPSv2 is a testament to the value of careful empirical investigation: sometimes the best improvements come not from scaling compute or data, but from understanding <em>why</em> a model works the way it does, and designing training procedures that deliberately cultivate the mechanisms responsible for success.</p>
<hr>
</section>
<section id="references-and-further-reading" class="level2">
<h2 class="anchored" data-anchor-id="references-and-further-reading" id="references-and-further-reading">References and Further Reading</h2>
<section id="primary-sources" class="level3">
<h3 class="anchored" data-anchor-id="primary-sources" id="primary-sources">Primary Sources</h3>
<ul>
<li><strong>TIPSv2 Paper:</strong> Cao, B., et al.&nbsp;“TIPSv2: Advancing Vision-Language Pretraining with Enhanced Patch-Text Alignment.” <em>CVPR 2026</em>. <a href="https://arxiv.org/abs/2604.12012">arXiv:2604.12012</a></li>
<li><strong>TIPS (v1) Paper:</strong> “TIPS: Text-Image Pretraining with Spatial Awareness.” <em>ICLR 2025</em>. <a href="https://arxiv.org/abs/2410.16512">arXiv:2410.16512</a></li>
<li><strong>TIPSv2 Project Page:</strong> <a href="https://gdm-tipsv2.github.io/">gdm-tipsv2.github.io</a></li>
<li><strong>GitHub (TIPS + TIPSv2):</strong> <a href="https://github.com/google-deepmind/tips">github.com/google-deepmind/tips</a></li>
<li><strong>HuggingFace Model Hub:</strong> <a href="https://huggingface.co/google/tipsv2-b14">google/tipsv2-b14</a></li>
</ul>
</section>
<section id="related-work" class="level3">
<h3 class="anchored" data-anchor-id="related-work" id="related-work">Related Work</h3>
<ul>
<li><strong>CLIP:</strong> Radford, A., et al.&nbsp;“Learning Transferable Visual Models From Natural Language Supervision.” <em>ICML 2021</em>.</li>
<li><strong>SigLIP:</strong> Zhai, X., et al.&nbsp;“Sigmoid Loss for Language Image Pre-Training.” <em>ICCV 2023</em>.</li>
<li><strong>DINOv2:</strong> Oquab, M., et al.&nbsp;“DINOv2: Learning Robust Visual Features without Supervision.” <em>TMLR 2023</em>.</li>
<li><strong>iBOT:</strong> Zhou, J., et al.&nbsp;“iBOT: Image BERT Pre-Training with Online Tokenizer.” <em>ICLR 2022</em>.</li>
<li><strong>DINO:</strong> Caron, M., et al.&nbsp;“Emerging Properties in Self-Supervised Vision Transformers.” <em>ICCV 2021</em>.</li>
<li><strong>SILC:</strong> Naeem, M.F., et al.&nbsp;“SILC: Improving Vision Language Pretraining with Self-Distillation.” 2023.</li>
<li><strong>PaliGemma:</strong> Google DeepMind’s vision-language model used for medium-granularity caption generation.</li>
</ul>
</section>
<section id="survey-and-context-reading" class="level3">
<h3 class="anchored" data-anchor-id="survey-and-context-reading" id="survey-and-context-reading">Survey and Context Reading</h3>
<ul>
<li><strong>Vision Transformer (ViT):</strong> Dosovitskiy, A., et al.&nbsp;“An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale.” <em>ICLR 2021</em>.</li>
<li><strong>Masked Autoencoders:</strong> He, K., et al.&nbsp;“Masked Autoencoders Are Scalable Vision Learners.” <em>CVPR 2022</em>.</li>
<li><strong>Vision-Language Pretraining Survey:</strong> <a href="https://paperswithcode.com/task/vision-language-pre-training">Papers With Code — Vision-Language Pre-Training</a>.</li>
</ul>



</section>
</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[GitHub Actions × MLFlow CI/CD for Computer Vision]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/deployment/github-mlflow/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/deployment/github-mlflow/</guid>
      <pubDate>Thu, 09 Apr 2026 00:00:00 GMT</pubDate>
      <description><![CDATA[A practitioner’s guide to building reliable, reproducible, and observable CV model pipelines using GitHub Actions and MLFlow.]]></description>
      <category>code</category>
      <category>mlops</category>
      <category>advanced</category>
      <content:encoded><![CDATA[






<section id="github-actions-mlflow-cicd-for-computer-vision" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/deployment/github-mlflow/cicd.png" class="img-fluid"></p>
<section id="philosophy-guiding-principles" class="level2">
<h2 class="anchored" data-anchor-id="philosophy-guiding-principles" id="philosophy-guiding-principles">Philosophy &amp; Guiding Principles</h2>
<p>Operational excellence in CV production systems rests on four pillars:</p>
<table class="table-striped table-hover caption-top table">
<caption>Four pillars of operational excellence</caption>
<colgroup>
<col style="width: 50%">
<col style="width: 50%">
</colgroup>
<thead>
<tr class="header">
<th>Pillar</th>
<th>What it means in practice</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><strong>Reproducibility</strong></td>
<td>Every training run can be re-created identically from a commit SHA + data hash</td>
</tr>
<tr class="even">
<td><strong>Observability</strong></td>
<td>Every metric, artifact, and environment is logged and queryable</td>
</tr>
<tr class="odd">
<td><strong>Automation</strong></td>
<td>Humans approve transitions; machines do everything else</td>
</tr>
<tr class="even">
<td><strong>Fail Fast</strong></td>
<td>Catch regressions on cheap compute (unit tests, sanity checks) before expensive GPU training</td>
</tr>
</tbody>
</table>
<p>These principles drive every recommendation in this guide.</p>
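<p>The reproducibility pillar relies on naming every training run by its commit SHA plus a hash of its data. One minimal way to compute such a fingerprint in Python is sketched below: it hashes every file under a data directory in a stable, sorted order, includes relative paths so that renames change the result, and combines that hash with the commit SHA. Falling back to the <code>GITHUB_SHA</code> environment variable, which GitHub Actions sets for workflow runs, is one option; the function names and the 12-character short form are our own conventions, and a dedicated data-versioning tool may suit large datasets better.</p>

```python
import hashlib
import os
import tempfile
from pathlib import Path


def file_sha256(path, chunk_size=1024 * 1024):
    """SHA-256 of one file, read in chunks so large files are never fully loaded."""
    h = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def hash_directory(root):
    """Hash every file under root in sorted relative-path order.

    Relative paths are included so that renaming or moving a file
    changes the result, not only editing its contents.
    """
    root = Path(root)
    files = sorted(
        (p for p in root.rglob("*") if p.is_file()),
        key=lambda p: p.relative_to(root).as_posix(),
    )
    digest = hashlib.sha256()
    for path in files:
        rel = path.relative_to(root).as_posix()
        digest.update(f"{rel}\0{file_sha256(path)}\n".encode("utf-8"))
    return digest.hexdigest()


def run_fingerprint(data_dir, commit_sha=None):
    """Identify a run by code version plus data version."""
    commit_sha = commit_sha or os.environ.get("GITHUB_SHA", "unknown")
    return f"{commit_sha[:12]}-{hash_directory(data_dir)[:12]}"


# Usage: fingerprint a tiny throwaway dataset
with tempfile.TemporaryDirectory() as tmp:
    Path(tmp, "train").mkdir()
    Path(tmp, "train", "img_0001.txt").write_text("pixels")
    print(run_fingerprint(tmp, commit_sha="0123456789abcdef"))
```

<p>Logging this fingerprint as a run tag makes it possible to look up, for any registered model, exactly which code and data produced it.</p>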
<hr>
</section>
<section id="repository-project-structure" class="level2">
<h2 class="anchored" data-anchor-id="repository-project-structure" id="repository-project-structure">Repository &amp; Project Structure</h2>
<p>Organizing your monorepo consistently makes workflow triggers predictable and avoids accidental pipeline skips.</p>
<pre><code>repo root/
├── .github/
│   ├── workflows/
│   │   ├── ci.yml          # On every PR
│   │   ├── train.yml       # Merge to main / manual
│   │   ├── evaluate.yml    # Post-training gate
│   │   └── deploy.yml      # Registry stage promotion
│   └── actions/
│       └── setup-mlflow/   # Reusable composite action
├── src/
│   ├── data/               # Loading, augmentation, versioning
│   ├── models/             # Architecture definitions
│   ├── training/           # Loops, callbacks
│   ├── evaluation/         # Metrics, visualisations
│   └── serving/            # Inference wrapper
├── configs/
│   ├── base.yaml           # Shared hyperparameters
│   ├── experiment/         # Hydra overrides
│   └── deployment/         # Serving config per env
├── tests/
│   ├── unit/
│   ├── integration/
│   └── smoke/              # Fast inference checks
├── mlflow/
│   └── MLproject           # Reproducible runs
├── scripts/
│   ├── register_model.py
│   ├── compare_runs.py
│   └── promote_model.py
└── Makefile
</code></pre>
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Important</span>Key Rule
</div>
</div>
<div class="callout-body-container callout-body">
<p>Keep model training code, serving code, and infrastructure config in the same repository. Splitting a CV pipeline across repositories causes drift between what was trained and what is served.</p>
</div>
</div>
<hr>
</section>
<section id="mlflow-setup-for-cv-pipelines" class="level2">
<h2 class="anchored" data-anchor-id="mlflow-setup-for-cv-pipelines" id="mlflow-setup-for-cv-pipelines">MLFlow Setup for CV Pipelines</h2>
<section id="mlproject-file" class="level3">
<h3 class="anchored" data-anchor-id="mlproject-file" id="mlproject-file">MLProject File</h3>
<p>The <code>MLproject</code> file is the contract between your code and MLFlow’s runner. Always define it — it makes runs reproducible from any environment.</p>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>mlflow/MLproject</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb2" data-filename="mlflow/MLproject"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="fu">name</span><span class="kw">:</span><span class="at"> cv-pipeline</span></span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a><span class="fu">conda_env</span><span class="kw">:</span><span class="at"> conda.yaml</span><span class="co">   # or docker_env / python_env</span></span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a><span class="fu">entry_points</span><span class="kw">:</span></span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">train</span><span class="kw">:</span></span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">parameters</span><span class="kw">:</span></span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">config_path</span><span class="kw">:</span><span class="at">  </span><span class="kw">{</span><span class="fu">type</span><span class="kw">:</span><span class="at"> str</span><span class="kw">,</span><span class="at"> </span><span class="fu">default</span><span class="kw">:</span><span class="at"> </span><span class="st">"configs/base.yaml"</span><span class="kw">}</span></span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">data_version</span><span class="kw">:</span><span class="at"> </span><span class="kw">{</span><span class="fu">type</span><span class="kw">:</span><span class="at"> str</span><span class="kw">}</span></span>
<span id="cb2-10"><a href="#cb2-10" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">run_name</span><span class="kw">:</span><span class="at">     </span><span class="kw">{</span><span class="fu">type</span><span class="kw">:</span><span class="at"> str</span><span class="kw">,</span><span class="at"> </span><span class="fu">default</span><span class="kw">:</span><span class="at"> </span><span class="st">"unnamed"</span><span class="kw">}</span></span>
<span id="cb2-11"><a href="#cb2-11" aria-hidden="true" tabindex="-1"></a><span class="fu">    command</span><span class="kw">: </span><span class="ch">&gt;</span></span>
<span id="cb2-12"><a href="#cb2-12" aria-hidden="true" tabindex="-1"></a>      python -m src.training.train</span>
<span id="cb2-13"><a href="#cb2-13" aria-hidden="true" tabindex="-1"></a>      --config {config_path}</span>
<span id="cb2-14"><a href="#cb2-14" aria-hidden="true" tabindex="-1"></a>      --data-version {data_version}</span>
<span id="cb2-15"><a href="#cb2-15" aria-hidden="true" tabindex="-1"></a>      --run-name {run_name}</span>
<span id="cb2-16"><a href="#cb2-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-17"><a href="#cb2-17" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">evaluate</span><span class="kw">:</span></span>
<span id="cb2-18"><a href="#cb2-18" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">parameters</span><span class="kw">:</span></span>
<span id="cb2-19"><a href="#cb2-19" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">run_id</span><span class="kw">:</span><span class="at">      </span><span class="kw">{</span><span class="fu">type</span><span class="kw">:</span><span class="at"> str</span><span class="kw">}</span></span>
<span id="cb2-20"><a href="#cb2-20" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">dataset</span><span class="kw">:</span><span class="at">     </span><span class="kw">{</span><span class="fu">type</span><span class="kw">:</span><span class="at"> str</span><span class="kw">,</span><span class="at"> </span><span class="fu">default</span><span class="kw">:</span><span class="at"> </span><span class="st">"val"</span><span class="kw">}</span></span>
<span id="cb2-21"><a href="#cb2-21" aria-hidden="true" tabindex="-1"></a><span class="fu">    command</span><span class="kw">: </span><span class="ch">&gt;</span></span>
<span id="cb2-22"><a href="#cb2-22" aria-hidden="true" tabindex="-1"></a>      python -m src.evaluation.evaluate</span>
<span id="cb2-23"><a href="#cb2-23" aria-hidden="true" tabindex="-1"></a>      --run-id {run_id}</span>
<span id="cb2-24"><a href="#cb2-24" aria-hidden="true" tabindex="-1"></a>      --dataset {dataset}</span></code></pre></div></div>
</div>
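<p>Here is what the parameter contract of these entry points means in practice. A parameter without a <code>default</code> (<code>data_version</code>, <code>run_id</code>) must be supplied on every run. Each <code>{name}</code> placeholder in <code>command</code> is then replaced by its resolved value. The sketch below mimics that behaviour in plain Python for the <code>train</code> entry point. <code>render_command</code>, <code>TRAIN_PARAMS</code> and <code>TRAIN_COMMAND</code> are hypothetical illustrations, not MLflow's API. Shell-quoting each value is this sketch's own choice.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import shlex

# Hypothetical stdlib illustration of MLproject-style parameter handling;
# this is NOT MLflow's implementation, only the same idea in miniature.
TRAIN_PARAMS = {
    "config_path": {"type": "str", "default": "configs/base.yaml"},
    "data_version": {"type": "str"},  # no default, so the caller must supply it
    "run_name": {"type": "str", "default": "unnamed"},
}
TRAIN_COMMAND = (
    "python -m src.training.train"
    " --config {config_path}"
    " --data-version {data_version}"
    " --run-name {run_name}"
)


def render_command(command, param_specs, user_params):
    """Fill defaults, reject missing required params, then substitute placeholders."""
    values = {}
    missing = []
    for name, spec in param_specs.items():
        if name in user_params:
            values[name] = str(user_params[name])
        elif "default" in spec:
            values[name] = str(spec["default"])
        else:
            missing.append(name)
    if missing:
        raise ValueError("No value given for required parameters: " + ", ".join(missing))
    # Shell-quote every value so a run name containing spaces stays a single argument.
    return command.format(**{k: shlex.quote(v) for k, v in values.items()})


print(render_command(TRAIN_COMMAND, TRAIN_PARAMS,
                     {"data_version": "v3", "run_name": "resnet50 baseline"}))
# python -m src.training.train --config configs/base.yaml --data-version v3 --run-name 'resnet50 baseline'</code></pre></div>
<p>Failing fast on a missing <code>data_version</code> is deliberate. A training run that silently fell back to some default dataset version could not be reproduced later.</p>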
</section>
<section id="logging-cv-artifacts-what-to-always-log" class="level3">
<h3 class="anchored" data-anchor-id="logging-cv-artifacts-what-to-always-log" id="logging-cv-artifacts-what-to-always-log">Logging CV Artifacts — What to Always Log</h3>
<p>Log generously during training. Storage is cheap; missing data when debugging a production incident is expensive.</p>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>src/training/train.py</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb3" data-filename="src/training/train.py"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> os</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> mlflow</span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> mlflow.pytorch</span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a><span class="co"># flatten_dict, train_one_epoch, plot_confusion_matrix, plot_pr_curves and</span></span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a><span class="co"># log_prediction_grid are project helpers; model, loader, val_loader, optimizer,</span></span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a><span class="co"># scheduler, sample_input and sample_output are built elsewhere (not shown).</span></span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> training_run(config, data_version, config_path):</span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a>    mlflow.set_experiment(config.experiment_name)</span>
<span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> mlflow.start_run(run_name<span class="op">=</span>config.run_name) <span class="im">as</span> run:</span>
<span id="cb3-15"><a href="#cb3-15" aria-hidden="true" tabindex="-1"></a>        <span class="co"># --- Tags: non-numeric metadata ---</span></span>
<span id="cb3-16"><a href="#cb3-16" aria-hidden="true" tabindex="-1"></a>        mlflow.set_tags({</span>
<span id="cb3-17"><a href="#cb3-17" aria-hidden="true" tabindex="-1"></a>            <span class="st">"git.commit"</span>:    os.environ.get(<span class="st">"GITHUB_SHA"</span>, <span class="st">"local"</span>),</span>
<span id="cb3-18"><a href="#cb3-18" aria-hidden="true" tabindex="-1"></a>            <span class="st">"git.branch"</span>:    os.environ.get(<span class="st">"GITHUB_REF_NAME"</span>, <span class="st">"local"</span>),</span>
<span id="cb3-19"><a href="#cb3-19" aria-hidden="true" tabindex="-1"></a>            <span class="st">"data.version"</span>:  data_version,</span>
<span id="cb3-20"><a href="#cb3-20" aria-hidden="true" tabindex="-1"></a>            <span class="st">"model.arch"</span>:    config.model.architecture,</span>
<span id="cb3-21"><a href="#cb3-21" aria-hidden="true" tabindex="-1"></a>            <span class="st">"triggered_by"</span>:  os.environ.get(<span class="st">"GITHUB_ACTOR"</span>, <span class="st">"local"</span>),</span>
<span id="cb3-22"><a href="#cb3-22" aria-hidden="true" tabindex="-1"></a>        })</span>
<span id="cb3-23"><a href="#cb3-23" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-24"><a href="#cb3-24" aria-hidden="true" tabindex="-1"></a>        <span class="co"># --- Params: hyperparameters &amp; config ---</span></span>
<span id="cb3-25"><a href="#cb3-25" aria-hidden="true" tabindex="-1"></a>        mlflow.log_params(flatten_dict(config))   <span class="co"># log full config, not just LR/BS</span></span>
<span id="cb3-26"><a href="#cb3-26" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-27"><a href="#cb3-27" aria-hidden="true" tabindex="-1"></a>        <span class="co"># --- Training loop ---</span></span>
<span id="cb3-28"><a href="#cb3-28" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(config.epochs):</span>
<span id="cb3-29"><a href="#cb3-29" aria-hidden="true" tabindex="-1"></a>            metrics <span class="op">=</span> train_one_epoch(model, loader, optimizer)</span>
<span id="cb3-30"><a href="#cb3-30" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-31"><a href="#cb3-31" aria-hidden="true" tabindex="-1"></a>            mlflow.log_metrics({</span>
<span id="cb3-32"><a href="#cb3-32" aria-hidden="true" tabindex="-1"></a>                <span class="st">"train/loss"</span>:       metrics.loss,</span>
<span id="cb3-33"><a href="#cb3-33" aria-hidden="true" tabindex="-1"></a>                <span class="st">"train/lr"</span>:         scheduler.get_last_lr()[<span class="dv">0</span>],</span>
<span id="cb3-34"><a href="#cb3-34" aria-hidden="true" tabindex="-1"></a>                <span class="st">"val/mAP"</span>:          metrics.val_map,</span>
<span id="cb3-35"><a href="#cb3-35" aria-hidden="true" tabindex="-1"></a>                <span class="st">"val/mAP_50"</span>:       metrics.val_map_50,</span>
<span id="cb3-36"><a href="#cb3-36" aria-hidden="true" tabindex="-1"></a>                <span class="st">"val/precision"</span>:    metrics.precision,</span>
<span id="cb3-37"><a href="#cb3-37" aria-hidden="true" tabindex="-1"></a>                <span class="st">"val/recall"</span>:       metrics.recall,</span>
<span id="cb3-38"><a href="#cb3-38" aria-hidden="true" tabindex="-1"></a>                <span class="st">"gpu/memory_mb"</span>:    torch.cuda.max_memory_allocated() <span class="op">//</span> <span class="fl">1e6</span>,</span>
<span id="cb3-39"><a href="#cb3-39" aria-hidden="true" tabindex="-1"></a>            }, step<span class="op">=</span>epoch)</span>
<span id="cb3-40"><a href="#cb3-40" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-41"><a href="#cb3-41" aria-hidden="true" tabindex="-1"></a>        <span class="co"># --- CV-specific artifacts ---</span></span>
<span id="cb3-42"><a href="#cb3-42" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Confusion matrix image</span></span>
<span id="cb3-43"><a href="#cb3-43" aria-hidden="true" tabindex="-1"></a>        mlflow.log_figure(plot_confusion_matrix(model, val_loader), <span class="st">"eval/confusion_matrix.png"</span>)</span>
<span id="cb3-44"><a href="#cb3-44" aria-hidden="true" tabindex="-1"></a>        <span class="co"># PR curve per class</span></span>
<span id="cb3-45"><a href="#cb3-45" aria-hidden="true" tabindex="-1"></a>        mlflow.log_figure(plot_pr_curves(model, val_loader), <span class="st">"eval/pr_curves.png"</span>)</span>
<span id="cb3-46"><a href="#cb3-46" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Sample predictions (good + failure cases)</span></span>
<span id="cb3-47"><a href="#cb3-47" aria-hidden="true" tabindex="-1"></a>        log_prediction_grid(model, val_loader, run, n<span class="op">=</span><span class="dv">16</span>)</span>
<span id="cb3-48"><a href="#cb3-48" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Model weights + signature (sample_input / sample_output: small NumPy arrays)</span></span>
<span id="cb3-49"><a href="#cb3-49" aria-hidden="true" tabindex="-1"></a>        signature <span class="op">=</span> mlflow.models.infer_signature(sample_input, sample_output)</span>
<span id="cb3-50"><a href="#cb3-50" aria-hidden="true" tabindex="-1"></a>        mlflow.pytorch.log_model(model, <span class="st">"model"</span>, signature<span class="op">=</span>signature)</span>
<span id="cb3-51"><a href="#cb3-51" aria-hidden="true" tabindex="-1"></a>        <span class="co"># The exact config file this run used, for reproduction</span></span>
<span id="cb3-52"><a href="#cb3-52" aria-hidden="true" tabindex="-1"></a>        mlflow.log_artifact(config_path, <span class="st">"config"</span>)</span>
<span id="cb3-53"><a href="#cb3-53" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-54"><a href="#cb3-54" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> run.info.run_id</span></code></pre></div></div>
</div>
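<p>The training code calls a <code>flatten_dict</code> helper that the post does not show. Below is one possible version. It assumes the config can first be converted to a plain nested dict; with OmegaConf, for instance, <code>OmegaConf.to_container(cfg, resolve=True)</code> does this, but adapt the step to your config library. The helper joins nested keys with dots so that each leaf becomes one MLflow param. MLflow stores param values as strings.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">def flatten_dict(cfg, parent_key="", sep="."):
    """Flatten a nested dict into {"model.architecture": ..., "optim.lr": ...}."""
    items = {}
    for key, value in cfg.items():
        full_key = parent_key + sep + str(key) if parent_key else str(key)
        if isinstance(value, dict):
            # Recurse into sub-sections, carrying the dotted prefix along
            items.update(flatten_dict(value, full_key, sep))
        else:
            items[full_key] = value
    return items


# Hypothetical example config, already converted to plain dicts/lists
config = {
    "experiment_name": "detector",
    "epochs": 50,
    "model": {"architecture": "resnet50", "pretrained": True},
    "optim": {"lr": 0.001, "betas": [0.9, 0.999]},
}
print(flatten_dict(config))
# {'experiment_name': 'detector', 'epochs': 50, 'model.architecture': 'resnet50', 'model.pretrained': True, 'optim.lr': 0.001, 'optim.betas': [0.9, 0.999]}</code></pre></div>
<p>Dotted keys keep the logged params searchable. For example, filtering runs on <code>params."model.architecture"</code> works without knowing how the config file was nested.</p>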
</section>
<section id="inputoutput-signature" class="level3">
<h3 class="anchored" data-anchor-id="inputoutput-signature" id="inputoutput-signature">Input/Output Signature</h3>
<p>Always define a model signature. MLflow checks incoming requests against it at serving time, so shape and dtype mismatches caused by a broken preprocessing step are rejected before they reach users.</p>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>src/training/signature.py</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb4" data-filename="src/training/signature.py"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> mlflow.pytorch</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> mlflow.models.signature <span class="im">import</span> ModelSignature</span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> mlflow.types.schema <span class="im">import</span> Schema, TensorSpec</span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a><span class="co"># For a BCHW image classifier; -1 leaves the batch dimension variable</span></span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a>input_schema  <span class="op">=</span> Schema([TensorSpec(np.dtype(np.float32), (<span class="op">-</span><span class="dv">1</span>, <span class="dv">3</span>, <span class="dv">224</span>, <span class="dv">224</span>), <span class="st">"image"</span>)])</span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a>output_schema <span class="op">=</span> Schema([TensorSpec(np.dtype(np.float32), (<span class="op">-</span><span class="dv">1</span>, <span class="dv">1000</span>),         <span class="st">"logits"</span>)])</span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a>signature     <span class="op">=</span> ModelSignature(inputs<span class="op">=</span>input_schema, outputs<span class="op">=</span>output_schema)</span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a><span class="co"># `model` is the trained torch.nn.Module from the training run</span></span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a>mlflow.pytorch.log_model(model, <span class="st">"model"</span>, signature<span class="op">=</span>signature)</span></code></pre></div></div>
</div>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Why signatures matter
</div>
</div>
<div class="callout-body-container callout-body">
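<p>A model registered without an input/output schema gives MLflow's serving layer nothing to validate requests against. A mis-shaped or mis-typed input (HWC instead of CHW, uint8 instead of float32) therefore reaches the model unchecked. Without a declared schema there is also no contract to write automated inference-time assertions against, and such mismatches are a common source of silent production errors.</p>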
</div>
</div>
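<p>The sketch below shows, in plain Python, the kind of check a <code>(-1, 3, 224, 224)</code> float32 <code>TensorSpec</code> implies. <code>-1</code> matches any batch size; every other axis and the dtype must match exactly. <code>check_tensor</code> is a hypothetical helper for illustration. MLflow's own enforcement lives in its pyfunc serving layer and differs in detail.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python"># Conceptual sketch of a TensorSpec-style check; not MLflow's implementation.
EXPECTED_SHAPE = (-1, 3, 224, 224)   # -1 means "any size" (the batch dimension)
EXPECTED_DTYPE = "float32"


def check_tensor(shape, dtype, expected_shape=EXPECTED_SHAPE, expected_dtype=EXPECTED_DTYPE):
    """Return a list of problems; an empty list means the input conforms."""
    problems = []
    if dtype != expected_dtype:
        problems.append(f"dtype {dtype} != expected {expected_dtype}")
    if len(shape) != len(expected_shape):
        problems.append(f"rank {len(shape)} != expected {len(expected_shape)}")
    else:
        for axis, (got, want) in enumerate(zip(shape, expected_shape)):
            if want != -1 and got != want:
                problems.append(f"axis {axis}: got {got}, expected {want}")
    return problems


print(check_tensor((8, 3, 224, 224), "float32"))  # []
print(check_tensor((8, 224, 224, 3), "uint8"))
# ['dtype uint8 != expected float32', 'axis 1: got 224, expected 3', 'axis 3: got 3, expected 224']</code></pre></div>
<p>The second call is the classic preprocessing bug: raw HWC uint8 pixels sent where normalized CHW float32 tensors were expected. A signature only covers shape and dtype, though. It cannot tell whether pixels were normalized the same way as during training.</p>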
<hr>
</section>
</section>
<section id="github-actions-workflow-architecture" class="level2">
<h2 class="anchored" data-anchor-id="github-actions-workflow-architecture" id="github-actions-workflow-architecture">GitHub Actions Workflow Architecture</h2>
<section id="event-to-workflow-mapping" class="level3">
<h3 class="anchored" data-anchor-id="event-to-workflow-mapping" id="event-to-workflow-mapping">Event-to-Workflow Mapping</h3>
<p>Design workflows around <strong>what changed</strong> and <strong>who initiated</strong> the change, not just which branch.</p>
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div id="fig-trigger-map" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-fig figure">
<div aria-describedby="fig-trigger-map-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<div>
<pre class="mermaid mermaid-js" data-label="fig-trigger-map">flowchart TD
    PR["🔀 Pull request opened / updated"]
    MERGE["✅ Merge to main"]
    MANUAL["🖱️ Manual dispatch (workflow_dispatch)"]
    WEBHOOK["🔔 MLFlow webhook / registry event"]

    PR --&gt; CI["ci.yml: lint · unit tests · smoke inference · config validation"]

    MERGE --&gt; TRAIN["train.yml: full training job, logs to MLFlow"]
    TRAIN --&gt;|on success| EVAL["evaluate.yml: quality gates · model comparison"]
    EVAL --&gt;|on pass| REG["📋 Opens PR to promote model in registry"]

    MANUAL --&gt; TRAIN2["train.yml: re-train with custom params (experiments)"]

    WEBHOOK --&gt; DEPLOY["deploy.yml: deploy 'Production'-staged model to serving infra"]

    style CI fill:#d4edda,stroke:#28a745
    style TRAIN fill:#cce5ff,stroke:#004085
    style TRAIN2 fill:#cce5ff,stroke:#004085
    style EVAL fill:#fff3cd,stroke:#856404
    style REG fill:#e2d9f3,stroke:#6f42c1
    style DEPLOY fill:#f8d7da,stroke:#721c24
</pre>
</div>
</div>
<figcaption class="quarto-float-caption-bottom quarto-float-caption quarto-float-fig" id="fig-trigger-map-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Figure&nbsp;1: GitHub event → workflow mapping
</figcaption>
</figure>
</div>
</div>
</div>
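<p>The mapping in Figure&nbsp;1 can also be written as a lookup table keyed on event name and branch. This is a hypothetical, stdlib-only mirror of the diagram, useful for reviewing which events start what; in GitHub Actions the real routing lives in each workflow's <code>on:</code> block. One assumption is labeled in the code: the MLFlow registry webhook is taken to arrive as a <code>repository_dispatch</code> event.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python"># Hypothetical mirror of Figure 1 (it ignores the path filters in ci.yml).
# In GitHub Actions the real routing lives in each workflow's `on:` block.
ROUTES = {
    # (event name, branch or None for "any branch"): workflows, each gated on the previous one
    ("pull_request", None): ["ci.yml"],
    ("push", "main"): ["train.yml", "evaluate.yml"],
    ("workflow_dispatch", None): ["train.yml"],
    # Assumption: the MLFlow registry webhook reaches GitHub as a repository_dispatch event.
    ("repository_dispatch", None): ["deploy.yml"],
}


def workflows_for(event_name, branch=None):
    """Return the workflow chain an event starts; an empty list means nothing runs."""
    specific = ROUTES.get((event_name, branch))
    fallback = ROUTES.get((event_name, None))
    return list(specific or fallback or [])


print(workflows_for("push", "main"))          # ['train.yml', 'evaluate.yml']
print(workflows_for("push", "feature/x"))     # []
print(workflows_for("pull_request", "main"))  # ['ci.yml']</code></pre></div>
<p>The <code>push</code> to a feature branch maps to nothing, which matches the diagram. Only a merge to <code>main</code> pays for a full training run.</p>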
</section>
<section id="reusable-composite-action-for-mlflow-setup" class="level3">
<h3 class="anchored" data-anchor-id="reusable-composite-action-for-mlflow-setup" id="reusable-composite-action-for-mlflow-setup">Reusable Composite Action for MLFlow Setup</h3>
<p>A composite action keeps MLFlow setup in one place instead of duplicating it in every workflow.</p>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>.github/actions/setup-mlflow/action.yml</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb5" data-filename=".github/actions/setup-mlflow/action.yml"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="fu">name</span><span class="kw">:</span><span class="at"> Setup MLFlow</span></span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a><span class="fu">description</span><span class="kw">:</span><span class="at"> Installs dependencies and points the MLFlow client at the tracking server</span></span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a><span class="fu">inputs</span><span class="kw">:</span></span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">mlflow-tracking-uri</span><span class="kw">:</span></span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">description</span><span class="kw">:</span><span class="at"> MLFlow tracking server URI, exported as MLFLOW_TRACKING_URI</span></span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">required</span><span class="kw">:</span><span class="at"> </span><span class="ch">true</span></span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">mlflow-s3-bucket</span><span class="kw">:</span></span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">description</span><span class="kw">:</span><span class="at"> Endpoint URL of the S3-compatible artifact store (e.g. MinIO), exported as MLFLOW_S3_ENDPOINT_URL; leave empty for AWS S3</span></span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">required</span><span class="kw">:</span><span class="at"> </span><span class="ch">false</span></span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">default</span><span class="kw">:</span><span class="at"> </span><span class="st">""</span></span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">python-version</span><span class="kw">:</span></span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">description</span><span class="kw">:</span><span class="at"> Python version passed to actions/setup-python</span></span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">required</span><span class="kw">:</span><span class="at"> </span><span class="ch">false</span></span>
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">default</span><span class="kw">:</span><span class="at"> </span><span class="st">"3.11"</span></span>
<span id="cb5-16"><a href="#cb5-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-17"><a href="#cb5-17" aria-hidden="true" tabindex="-1"></a><span class="fu">runs</span><span class="kw">:</span></span>
<span id="cb5-18"><a href="#cb5-18" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">using</span><span class="kw">:</span><span class="at"> composite</span></span>
<span id="cb5-19"><a href="#cb5-19" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">steps</span><span class="kw">:</span></span>
<span id="cb5-20"><a href="#cb5-20" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="kw">-</span><span class="at"> </span><span class="fu">uses</span><span class="kw">:</span><span class="at"> actions/setup-python@v5</span></span>
<span id="cb5-21"><a href="#cb5-21" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">with</span><span class="kw">:</span></span>
<span id="cb5-22"><a href="#cb5-22" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">python-version</span><span class="kw">:</span><span class="at"> ${{ inputs.python-version }}</span></span>
<span id="cb5-23"><a href="#cb5-23" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">cache</span><span class="kw">:</span><span class="at"> pip</span></span>
<span id="cb5-24"><a href="#cb5-24" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-25"><a href="#cb5-25" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> Install dependencies</span></span>
<span id="cb5-26"><a href="#cb5-26" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">shell</span><span class="kw">:</span><span class="at"> bash</span></span>
<span id="cb5-27"><a href="#cb5-27" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">run</span><span class="kw">:</span><span class="at"> pip install -r requirements.txt</span></span>
<span id="cb5-28"><a href="#cb5-28" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-29"><a href="#cb5-29" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> Configure MLFlow</span></span>
<span id="cb5-30"><a href="#cb5-30" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">shell</span><span class="kw">:</span><span class="at"> bash</span></span>
<span id="cb5-31"><a href="#cb5-31" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">env</span><span class="kw">:</span></span>
<span id="cb5-32"><a href="#cb5-32" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">MLFLOW_TRACKING_URI</span><span class="kw">:</span><span class="at">      ${{ inputs.mlflow-tracking-uri }}</span></span>
<span id="cb5-33"><a href="#cb5-33" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">MLFLOW_S3_ENDPOINT_URL</span><span class="kw">:</span><span class="at">   ${{ inputs.mlflow-s3-bucket }}</span></span>
<span id="cb5-34"><a href="#cb5-34" aria-hidden="true" tabindex="-1"></a><span class="fu">      run</span><span class="kw">: </span><span class="ch">|</span></span>
<span id="cb5-35"><a href="#cb5-35" aria-hidden="true" tabindex="-1"></a>        echo "MLFLOW_TRACKING_URI=$MLFLOW_TRACKING_URI" &gt;&gt; "$GITHUB_ENV"</span>
<span id="cb5-36"><a href="#cb5-36" aria-hidden="true" tabindex="-1"></a>        # Only export an endpoint override when one was given (AWS S3 needs none)</span>
<span id="cb5-37"><a href="#cb5-37" aria-hidden="true" tabindex="-1"></a>        if [ -n "$MLFLOW_S3_ENDPOINT_URL" ]; then</span>
<span id="cb5-38"><a href="#cb5-38" aria-hidden="true" tabindex="-1"></a>          echo "MLFLOW_S3_ENDPOINT_URL=$MLFLOW_S3_ENDPOINT_URL" &gt;&gt; "$GITHUB_ENV"</span>
<span id="cb5-39"><a href="#cb5-39" aria-hidden="true" tabindex="-1"></a>        fi</span></code></pre></div></div>
</div>
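<p>The "Configure MLFlow" step appends <code>KEY=value</code> lines to the file named by <code>$GITHUB_ENV</code>. It has to, because an <code>env:</code> block only applies to the step it is declared on, while values written to that file are visible to every later step in the job. The Python sketch below shows the same mechanism, which is useful if your setup logic lives in a Python script rather than bash. <code>export_env</code> and <code>read_env_file</code> are hypothetical helpers, and the reader only approximates what the runner does between steps.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import os
import tempfile


def export_env(values, env_file=None):
    """Append KEY=value lines to the file GitHub Actions exposes as $GITHUB_ENV.

    Multiline values need GitHub's delimiter syntax, which this sketch rejects.
    """
    env_file = env_file or os.environ["GITHUB_ENV"]
    with open(env_file, "a", encoding="utf-8") as fh:
        for key, value in values.items():
            if "\n" in value:
                raise ValueError(f"{key}: multiline values are not supported by this sketch")
            fh.write(f"{key}={value}\n")


def read_env_file(path):
    """Parse KEY=value lines (split on the first '=' only), approximating the runner."""
    with open(path, encoding="utf-8") as fh:
        return dict(line.rstrip("\n").split("=", 1) for line in fh if line.strip())


# Local demo: a temporary file stands in for $GITHUB_ENV.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "github_env")
    export_env({"MLFLOW_TRACKING_URI": "http://localhost:5000"}, path)
    print(read_env_file(path))  # {'MLFLOW_TRACKING_URI': 'http://localhost:5000'}</code></pre></div>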
<hr>
</section>
</section>
<section id="ci-pipeline-validate-before-you-merge" class="level2">
<h2 class="anchored" data-anchor-id="ci-pipeline-validate-before-you-merge" id="ci-pipeline-validate-before-you-merge">CI Pipeline — Validate Before You Merge</h2>
<p>The goal of CI is to give fast, cheap signal on PRs — <strong>no GPU, no real training</strong>.</p>
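<p>The <code>config-validation</code> job in the workflow below runs <code>scripts/validate_configs.py</code>, which this post does not show. Here is one possible core for it, as a hypothetical sketch. It assumes each YAML file has already been parsed into a dict (for example with PyYAML's <code>yaml.safe_load</code>). It also assumes the required keys are the ones <code>src/training/train.py</code> reads from the config; adjust <code>REQUIRED</code> to your real schema.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python"># Hypothetical core of scripts/validate_configs.py; YAML loading is omitted.
REQUIRED = {
    # Keys the training code reads from the config, with their expected types
    "experiment_name": str,
    "run_name": str,
    "epochs": int,
    "model.architecture": str,
}


def get_path(cfg, dotted):
    """Follow a dotted key like "model.architecture" through nested dicts; None if absent."""
    node = cfg
    for part in dotted.split("."):
        if not isinstance(node, dict) or part not in node:
            return None
        node = node[part]
    return node


def validate_config(cfg):
    """Return a list of error strings; an empty list means the config is valid."""
    errors = []
    for key, expected_type in REQUIRED.items():
        value = get_path(cfg, key)
        if value is None:
            errors.append(f"missing key: {key}")
        elif isinstance(value, bool) or not isinstance(value, expected_type):
            # bool is a subclass of int in Python, so reject it explicitly
            errors.append(f"{key}: expected {expected_type.__name__}, got {type(value).__name__}")
    return errors


good = {"experiment_name": "detector", "run_name": "baseline", "epochs": 50,
        "model": {"architecture": "resnet50"}}
bad = {"experiment_name": "detector", "epochs": "50", "model": {}}
print(validate_config(good))  # []
print(validate_config(bad))
# ['missing key: run_name', 'epochs: expected int, got str', 'missing key: model.architecture']</code></pre></div>
<p>In CI, the script would loop over every file under <code>configs/</code> and exit with a non-zero status if any file reports errors. That exit status is what fails the job, minutes before a GPU run would have crashed on the same typo.</p>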
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>.github/workflows/ci.yml</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb6" data-filename=".github/workflows/ci.yml"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="fu">name</span><span class="kw">:</span><span class="at"> CI</span></span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a><span class="fu">on</span><span class="kw">:</span></span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">pull_request</span><span class="kw">:</span></span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">branches</span><span class="kw">:</span><span class="at"> </span><span class="kw">[</span><span class="at">main</span><span class="kw">,</span><span class="at"> develop</span><span class="kw">]</span></span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">paths</span><span class="kw">:</span></span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="st">"src/**"</span></span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="st">"configs/**"</span></span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="st">"tests/**"</span></span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="st">"requirements*.txt"</span></span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a><span class="fu">concurrency</span><span class="kw">:</span></span>
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">group</span><span class="kw">:</span><span class="at"> ci-${{ github.ref }}</span></span>
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">cancel-in-progress</span><span class="kw">:</span><span class="at"> </span><span class="ch">true</span><span class="co">         # Cancel superseded runs when the PR gets new commits</span></span>
<span id="cb6-15"><a href="#cb6-15" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-16"><a href="#cb6-16" aria-hidden="true" tabindex="-1"></a><span class="fu">jobs</span><span class="kw">:</span></span>
<span id="cb6-17"><a href="#cb6-17" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">lint-and-type-check</span><span class="kw">:</span></span>
<span id="cb6-18"><a href="#cb6-18" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">runs-on</span><span class="kw">:</span><span class="at"> ubuntu-latest</span></span>
<span id="cb6-19"><a href="#cb6-19" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">steps</span><span class="kw">:</span></span>
<span id="cb6-20"><a href="#cb6-20" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="fu">uses</span><span class="kw">:</span><span class="at"> actions/checkout@v4</span></span>
<span id="cb6-21"><a href="#cb6-21" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="fu">uses</span><span class="kw">:</span><span class="at"> actions/setup-python@v5</span></span>
<span id="cb6-22"><a href="#cb6-22" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">with</span><span class="kw">:</span><span class="at"> </span><span class="kw">{</span><span class="at"> </span><span class="fu">python-version</span><span class="kw">:</span><span class="at"> </span><span class="st">"3.11"</span><span class="kw">,</span><span class="at"> </span><span class="fu">cache</span><span class="kw">:</span><span class="at"> pip </span><span class="kw">}</span></span>
<span id="cb6-23"><a href="#cb6-23" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="fu">run</span><span class="kw">:</span><span class="at"> pip install ruff mypy</span></span>
<span id="cb6-24"><a href="#cb6-24" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="fu">run</span><span class="kw">:</span><span class="at"> ruff check src/ tests/</span></span>
<span id="cb6-25"><a href="#cb6-25" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="fu">run</span><span class="kw">:</span><span class="at"> mypy src/ --ignore-missing-imports</span></span>
<span id="cb6-26"><a href="#cb6-26" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-27"><a href="#cb6-27" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">unit-tests</span><span class="kw">:</span></span>
<span id="cb6-28"><a href="#cb6-28" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">runs-on</span><span class="kw">:</span><span class="at"> ubuntu-latest</span></span>
<span id="cb6-29"><a href="#cb6-29" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">needs</span><span class="kw">:</span><span class="at"> lint-and-type-check</span></span>
<span id="cb6-30"><a href="#cb6-30" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">steps</span><span class="kw">:</span></span>
<span id="cb6-31"><a href="#cb6-31" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="fu">uses</span><span class="kw">:</span><span class="at"> actions/checkout@v4</span></span>
<span id="cb6-32"><a href="#cb6-32" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="fu">uses</span><span class="kw">:</span><span class="at"> ./.github/actions/setup-mlflow</span></span>
<span id="cb6-33"><a href="#cb6-33" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">with</span><span class="kw">:</span></span>
<span id="cb6-34"><a href="#cb6-34" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="fu">mlflow-tracking-uri</span><span class="kw">:</span><span class="at"> http://localhost:5000</span><span class="co">     # local ephemeral server</span></span>
<span id="cb6-35"><a href="#cb6-35" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="fu">mlflow-s3-bucket</span><span class="kw">:</span><span class="at">    </span><span class="st">""</span></span>
<span id="cb6-36"><a href="#cb6-36" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> Start local MLFlow server and wait until it is healthy</span></span>
<span id="cb6-37"><a href="#cb6-37" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">run</span><span class="kw">:</span><span class="at"> mlflow server --backend-store-uri sqlite:///mlflow.db &amp; timeout 60 bash -c 'until curl -sf http://localhost:5000/health; do sleep 1; done'</span></span>
<span id="cb6-38"><a href="#cb6-38" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> Run unit tests</span></span>
<span id="cb6-39"><a href="#cb6-39" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">run</span><span class="kw">:</span><span class="at"> pytest tests/unit/ -v --tb=short --cov=src --cov-report=xml</span></span>
<span id="cb6-40"><a href="#cb6-40" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="fu">uses</span><span class="kw">:</span><span class="at"> codecov/codecov-action@v4</span></span>
<span id="cb6-41"><a href="#cb6-41" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-42"><a href="#cb6-42" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">config-validation</span><span class="kw">:</span></span>
<span id="cb6-43"><a href="#cb6-43" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">runs-on</span><span class="kw">:</span><span class="at"> ubuntu-latest</span></span>
<span id="cb6-44"><a href="#cb6-44" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">steps</span><span class="kw">:</span></span>
<span id="cb6-45"><a href="#cb6-45" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="fu">uses</span><span class="kw">:</span><span class="at"> actions/checkout@v4</span></span>
<span id="cb6-46"><a href="#cb6-46" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="fu">uses</span><span class="kw">:</span><span class="at"> actions/setup-python@v5</span></span>
<span id="cb6-47"><a href="#cb6-47" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">with</span><span class="kw">:</span><span class="at"> </span><span class="kw">{</span><span class="at"> </span><span class="fu">python-version</span><span class="kw">:</span><span class="at"> </span><span class="st">"3.11"</span><span class="kw">,</span><span class="at"> </span><span class="fu">cache</span><span class="kw">:</span><span class="at"> pip </span><span class="kw">}</span></span>
<span id="cb6-48"><a href="#cb6-48" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> Validate all YAML configs</span></span>
<span id="cb6-49"><a href="#cb6-49" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">run</span><span class="kw">:</span><span class="at"> python scripts/validate_configs.py configs/</span></span>
<span id="cb6-50"><a href="#cb6-50" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-51"><a href="#cb6-51" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">smoke-inference</span><span class="kw">:</span></span>
<span id="cb6-52"><a href="#cb6-52" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">runs-on</span><span class="kw">:</span><span class="at"> ubuntu-latest</span></span>
<span id="cb6-53"><a href="#cb6-53" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">needs</span><span class="kw">:</span><span class="at"> unit-tests</span></span>
<span id="cb6-54"><a href="#cb6-54" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">steps</span><span class="kw">:</span></span>
<span id="cb6-55"><a href="#cb6-55" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="fu">uses</span><span class="kw">:</span><span class="at"> actions/checkout@v4</span></span>
<span id="cb6-56"><a href="#cb6-56" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="fu">uses</span><span class="kw">:</span><span class="at"> ./.github/actions/setup-mlflow</span></span>
<span id="cb6-57"><a href="#cb6-57" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">with</span><span class="kw">:</span></span>
<span id="cb6-58"><a href="#cb6-58" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="fu">mlflow-tracking-uri</span><span class="kw">:</span><span class="at"> ${{ secrets.MLFLOW_TRACKING_URI }}</span></span>
<span id="cb6-59"><a href="#cb6-59" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="fu">mlflow-s3-bucket</span><span class="kw">:</span><span class="at">    ${{ secrets.MLFLOW_S3_BUCKET }}</span></span>
<span id="cb6-60"><a href="#cb6-60" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> Run smoke test with current Production model</span></span>
<span id="cb6-61"><a href="#cb6-61" aria-hidden="true" tabindex="-1"></a><span class="fu">        run</span><span class="kw">: </span><span class="ch">|</span></span>
<span id="cb6-62"><a href="#cb6-62" aria-hidden="true" tabindex="-1"></a>          python scripts/smoke_test.py \</span>
<span id="cb6-63"><a href="#cb6-63" aria-hidden="true" tabindex="-1"></a>            --model-stage Production \</span>
<span id="cb6-64"><a href="#cb6-64" aria-hidden="true" tabindex="-1"></a>            --n-images 10 \</span>
<span id="cb6-65"><a href="#cb6-65" aria-hidden="true" tabindex="-1"></a>            --max-latency-ms 200</span></code></pre></div></div>
</div>
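<p>The <code>config-validation</code> job only needs <code>scripts/validate_configs.py</code> to exit non-zero when a config is broken. Below is one minimal way to write it. It assumes every file under <code>configs/</code> should parse to a YAML mapping and that a few files also have required top-level keys. The <code>SCHEMAS</code> table and the <code>model</code>/<code>training</code>/<code>data</code> keys are hypothetical placeholders. Files without a schema entry, such as <code>deployment/thresholds.yaml</code>, only get the parse check. PyYAML is imported inside <code>main()</code>, so the checking logic has no third-party dependency.</p>

```python
import sys
from pathlib import Path
from typing import Optional

# Hypothetical per-file schemas: file name -> {top-level key: expected type}.
# Files without an entry are only checked for parseable YAML with a mapping at the top.
SCHEMAS = {
    "base.yaml": {"model": dict, "training": dict, "data": dict},
}


def check_config(cfg, required: Optional[dict] = None) -> list[str]:
    """Return human-readable problems for one parsed config; an empty list means valid."""
    if not isinstance(cfg, dict):
        return [f"top level must be a mapping, got {type(cfg).__name__}"]
    problems = []
    for key, expected in (required or {}).items():
        if key not in cfg:
            problems.append(f"missing required key '{key}'")
        elif not isinstance(cfg[key], expected):
            problems.append(
                f"'{key}' should be {expected.__name__}, got {type(cfg[key]).__name__}"
            )
    return problems


def main(config_dir: str) -> int:
    import yaml  # PyYAML; imported here so check_config() has no third-party dependency

    root = Path(config_dir)
    paths = sorted([*root.rglob("*.yaml"), *root.rglob("*.yml")])
    if not paths:
        print(f"No YAML files found under {config_dir}")
        return 1  # an empty or mistyped directory should not pass silently
    failed = False
    for path in paths:
        try:
            cfg = yaml.safe_load(path.read_text(encoding="utf-8"))
        except yaml.YAMLError as exc:
            print(f"{path}: invalid YAML: {exc}")
            failed = True
            continue
        for problem in check_config(cfg, SCHEMAS.get(path.name)):
            print(f"{path}: {problem}")
            failed = True
    print(f"Checked {len(paths)} file(s): {'FAILED' if failed else 'OK'}")
    return 1 if failed else 0


def cli() -> None:
    """Entry point for `python scripts/validate_configs.py configs/`."""
    sys.exit(main(sys.argv[1] if len(sys.argv) > 1 else "configs"))
```

<p>In the actual script file, call <code>cli()</code> under an <code>if __name__ == "__main__":</code> guard. All problems are printed before the script exits, so a single CI run reports every broken file instead of stopping at the first one.</p>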
<section id="smoke-test-script-pattern" class="level3">
<h3 class="anchored" data-anchor-id="smoke-test-script-pattern" id="smoke-test-script-pattern">Smoke Test Script Pattern</h3>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>scripts/smoke_test.py</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb7" data-filename="scripts/smoke_test.py"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> argparse</span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> sys</span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> mlflow.pyfunc</span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> run_smoke_test(stage: <span class="bu">str</span>, n_images: <span class="bu">int</span>, max_latency_ms: <span class="bu">float</span>) <span class="op">-&gt;</span> <span class="bu">int</span>:</span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Load the registered model, time one batch, and return a process exit code."""</span></span>
<span id="cb7-10"><a href="#cb7-10" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> mlflow.pyfunc.load_model(<span class="ss">f"models:/cv-model/</span><span class="sc">{</span>stage<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb7-11"><a href="#cb7-11" aria-hidden="true" tabindex="-1"></a>    dummy_batch <span class="op">=</span> np.random.rand(n_images, <span class="dv">3</span>, <span class="dv">224</span>, <span class="dv">224</span>).astype(np.float32)</span>
<span id="cb7-12"><a href="#cb7-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-13"><a href="#cb7-13" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Warm-up call so one-off initialization is not counted as inference latency</span></span>
<span id="cb7-14"><a href="#cb7-14" aria-hidden="true" tabindex="-1"></a>    model.predict(dummy_batch)</span>
<span id="cb7-15"><a href="#cb7-15" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-16"><a href="#cb7-16" aria-hidden="true" tabindex="-1"></a>    t0 <span class="op">=</span> time.perf_counter()</span>
<span id="cb7-17"><a href="#cb7-17" aria-hidden="true" tabindex="-1"></a>    preds <span class="op">=</span> model.predict(dummy_batch)</span>
<span id="cb7-18"><a href="#cb7-18" aria-hidden="true" tabindex="-1"></a>    latency_ms <span class="op">=</span> (time.perf_counter() <span class="op">-</span> t0) <span class="op">*</span> <span class="dv">1000</span></span>
<span id="cb7-19"><a href="#cb7-19" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Latency: </span><span class="sc">{</span>latency_ms<span class="sc">:.1f}</span><span class="ss">ms for </span><span class="sc">{</span>n_images<span class="sc">}</span><span class="ss"> images"</span>)</span>
<span id="cb7-20"><a href="#cb7-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-21"><a href="#cb7-21" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Explicit checks rather than assert: assert statements are stripped under python -O</span></span>
<span id="cb7-22"><a href="#cb7-22" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> latency_ms <span class="op">&gt;=</span> max_latency_ms:</span>
<span id="cb7-23"><a href="#cb7-23" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Smoke test FAILED: </span><span class="sc">{</span>latency_ms<span class="sc">:.1f}</span><span class="ss">ms &gt;= </span><span class="sc">{</span>max_latency_ms<span class="sc">}</span><span class="ss">ms threshold"</span>)</span>
<span id="cb7-24"><a href="#cb7-24" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="dv">1</span></span>
<span id="cb7-25"><a href="#cb7-25" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> <span class="bu">len</span>(preds) <span class="op">!=</span> n_images:  <span class="co"># len() works for ndarray, DataFrame and list outputs</span></span>
<span id="cb7-26"><a href="#cb7-26" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="st">"Smoke test FAILED: output batch size mismatch"</span>)</span>
<span id="cb7-27"><a href="#cb7-27" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="dv">1</span></span>
<span id="cb7-28"><a href="#cb7-28" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="st">"Smoke test PASSED ✓"</span>)</span>
<span id="cb7-29"><a href="#cb7-29" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> <span class="dv">0</span></span>
<span id="cb7-30"><a href="#cb7-30" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-31"><a href="#cb7-31" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb7-32"><a href="#cb7-32" aria-hidden="true" tabindex="-1"></a>    p <span class="op">=</span> argparse.ArgumentParser()</span>
<span id="cb7-33"><a href="#cb7-33" aria-hidden="true" tabindex="-1"></a>    p.add_argument(<span class="st">"--model-stage"</span>,    default<span class="op">=</span><span class="st">"Production"</span>)</span>
<span id="cb7-34"><a href="#cb7-34" aria-hidden="true" tabindex="-1"></a>    p.add_argument(<span class="st">"--n-images"</span>,       <span class="bu">type</span><span class="op">=</span><span class="bu">int</span>,   default<span class="op">=</span><span class="dv">10</span>)</span>
<span id="cb7-35"><a href="#cb7-35" aria-hidden="true" tabindex="-1"></a>    p.add_argument(<span class="st">"--max-latency-ms"</span>, <span class="bu">type</span><span class="op">=</span><span class="bu">float</span>, default<span class="op">=</span><span class="fl">200.0</span>)</span>
<span id="cb7-36"><a href="#cb7-36" aria-hidden="true" tabindex="-1"></a>    args <span class="op">=</span> p.parse_args()</span>
<span id="cb7-37"><a href="#cb7-37" aria-hidden="true" tabindex="-1"></a>    sys.exit(run_smoke_test(args.model_stage, args.n_images, args.max_latency_ms))</span></code></pre></div></div>
</div>
<hr>
</section>
</section>
<section id="cd-pipeline-promote-register-deploy" class="level2">
<h2 class="anchored" data-anchor-id="cd-pipeline-promote-register-deploy" id="cd-pipeline-promote-register-deploy">CD Pipeline — Promote, Register, Deploy</h2>
<section id="training-workflow" class="level3">
<h3 class="anchored" data-anchor-id="training-workflow" id="training-workflow">Training Workflow</h3>
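<p>Training jobs are expensive, so gate them carefully. Expose <code>workflow_dispatch</code> for manual runs, and auto-trigger only on pushes to <code>main</code> that touch model code, training code, or the base config (the <code>paths</code> filter). Pointing the job at the <code>training</code> environment adds a manual approval pause, provided that environment has required reviewers configured.</p>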
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>.github/workflows/train.yml</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb8" data-filename=".github/workflows/train.yml"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="fu">name</span><span class="kw">:</span><span class="at"> Train</span></span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a><span class="fu">on</span><span class="kw">:</span></span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">push</span><span class="kw">:</span></span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">branches</span><span class="kw">:</span><span class="at"> </span><span class="kw">[</span><span class="at">main</span><span class="kw">]</span></span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">paths</span><span class="kw">:</span><span class="at"> </span><span class="kw">[</span><span class="st">"src/models/**"</span><span class="kw">,</span><span class="at"> </span><span class="st">"src/training/**"</span><span class="kw">,</span><span class="at"> </span><span class="st">"configs/base.yaml"</span><span class="kw">]</span></span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">workflow_dispatch</span><span class="kw">:</span></span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">inputs</span><span class="kw">:</span></span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">config_override</span><span class="kw">:</span></span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">description</span><span class="kw">:</span><span class="at"> </span><span class="st">"Config file path (relative to configs/)"</span></span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">default</span><span class="kw">:</span><span class="at"> </span><span class="st">"base.yaml"</span></span>
<span id="cb8-12"><a href="#cb8-12" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">data_version</span><span class="kw">:</span></span>
<span id="cb8-13"><a href="#cb8-13" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">description</span><span class="kw">:</span><span class="at"> </span><span class="st">"DVC/data version tag"</span></span>
<span id="cb8-14"><a href="#cb8-14" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">required</span><span class="kw">:</span><span class="at"> </span><span class="ch">true</span></span>
<span id="cb8-15"><a href="#cb8-15" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-16"><a href="#cb8-16" aria-hidden="true" tabindex="-1"></a><span class="fu">jobs</span><span class="kw">:</span></span>
<span id="cb8-17"><a href="#cb8-17" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">train</span><span class="kw">:</span></span>
<span id="cb8-18"><a href="#cb8-18" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">runs-on</span><span class="kw">:</span><span class="at"> </span><span class="kw">[</span><span class="at">self-hosted</span><span class="kw">,</span><span class="at"> gpu</span><span class="kw">]</span><span class="co">        # GPU runner required</span></span>
<span id="cb8-19"><a href="#cb8-19" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">timeout-minutes</span><span class="kw">:</span><span class="at"> </span><span class="dv">360</span></span>
<span id="cb8-20"><a href="#cb8-20" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">environment</span><span class="kw">:</span><span class="at"> training</span><span class="co">              # Requires manual approval gate in GitHub Environments</span></span>
<span id="cb8-21"><a href="#cb8-21" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-22"><a href="#cb8-22" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">steps</span><span class="kw">:</span></span>
<span id="cb8-23"><a href="#cb8-23" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="fu">uses</span><span class="kw">:</span><span class="at"> actions/checkout@v4</span></span>
<span id="cb8-24"><a href="#cb8-24" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-25"><a href="#cb8-25" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="fu">uses</span><span class="kw">:</span><span class="at"> ./.github/actions/setup-mlflow</span></span>
<span id="cb8-26"><a href="#cb8-26" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">with</span><span class="kw">:</span></span>
<span id="cb8-27"><a href="#cb8-27" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="fu">mlflow-tracking-uri</span><span class="kw">:</span><span class="at"> ${{ secrets.MLFLOW_TRACKING_URI }}</span></span>
<span id="cb8-28"><a href="#cb8-28" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="fu">mlflow-s3-bucket</span><span class="kw">:</span><span class="at">    ${{ secrets.MLFLOW_S3_BUCKET }}</span></span>
<span id="cb8-29"><a href="#cb8-29" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-30"><a href="#cb8-30" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> Pull data with DVC</span></span>
<span id="cb8-31"><a href="#cb8-31" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">env</span><span class="kw">:</span></span>
<span id="cb8-32"><a href="#cb8-32" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="fu">AWS_ACCESS_KEY_ID</span><span class="kw">:</span><span class="at">     ${{ secrets.DVC_AWS_KEY }}</span></span>
<span id="cb8-33"><a href="#cb8-33" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="fu">AWS_SECRET_ACCESS_KEY</span><span class="kw">:</span><span class="at"> ${{ secrets.DVC_AWS_SECRET }}</span></span>
<span id="cb8-34"><a href="#cb8-34" aria-hidden="true" tabindex="-1"></a><span class="fu">        run</span><span class="kw">: </span><span class="ch">|</span></span>
<span id="cb8-35"><a href="#cb8-35" aria-hidden="true" tabindex="-1"></a>          dvc pull data/processed/${{ inputs.data_version || 'latest' }}</span>
<span id="cb8-36"><a href="#cb8-36" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-37"><a href="#cb8-37" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> Launch MLFlow training run</span></span>
<span id="cb8-38"><a href="#cb8-38" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">id</span><span class="kw">:</span><span class="at"> training</span></span>
<span id="cb8-39"><a href="#cb8-39" aria-hidden="true" tabindex="-1"></a><span class="fu">        run</span><span class="kw">: </span><span class="ch">|</span></span>
<span id="cb8-40"><a href="#cb8-40" aria-hidden="true" tabindex="-1"></a>          set -o pipefail  # fail the step if mlflow run fails or no "Run ID:" line is printed</span>
<span id="cb8-41"><a href="#cb8-41" aria-hidden="true" tabindex="-1"></a>          RUN_ID=$(python -m mlflow run mlflow/ \</span>
<span id="cb8-42"><a href="#cb8-42" aria-hidden="true" tabindex="-1"></a>            -P config_path=configs/${{ inputs.config_override || 'base.yaml' }} \</span>
<span id="cb8-43"><a href="#cb8-43" aria-hidden="true" tabindex="-1"></a>            -P data_version=${{ inputs.data_version || 'latest' }} \</span>
<span id="cb8-44"><a href="#cb8-44" aria-hidden="true" tabindex="-1"></a>            -P run_name="ci-${{ github.sha }}" \</span>
<span id="cb8-45"><a href="#cb8-45" aria-hidden="true" tabindex="-1"></a>            --env-manager local \</span>
<span id="cb8-46"><a href="#cb8-46" aria-hidden="true" tabindex="-1"></a>            2&gt;&amp;1 | grep "Run ID:" | awk '{print $NF}')</span>
<span id="cb8-47"><a href="#cb8-47" aria-hidden="true" tabindex="-1"></a>          echo "run_id=$RUN_ID" &gt;&gt; "$GITHUB_OUTPUT"</span>
<span id="cb8-48"><a href="#cb8-48" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-49"><a href="#cb8-49" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> Export run ID as artifact</span></span>
<span id="cb8-50"><a href="#cb8-50" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">run</span><span class="kw">:</span><span class="at"> echo "${{ steps.training.outputs.run_id }}" &gt; run_id.txt</span></span>
<span id="cb8-51"><a href="#cb8-51" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-52"><a href="#cb8-52" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="fu">uses</span><span class="kw">:</span><span class="at"> actions/upload-artifact@v4</span></span>
<span id="cb8-53"><a href="#cb8-53" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">with</span><span class="kw">:</span></span>
<span id="cb8-54"><a href="#cb8-54" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="fu">name</span><span class="kw">:</span><span class="at"> training-run-id</span></span>
<span id="cb8-55"><a href="#cb8-55" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="fu">path</span><span class="kw">:</span><span class="at"> run_id.txt</span></span>
<span id="cb8-56"><a href="#cb8-56" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-57"><a href="#cb8-57" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">outputs</span><span class="kw">:</span></span>
<span id="cb8-58"><a href="#cb8-58" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">run_id</span><span class="kw">:</span><span class="at"> ${{ steps.training.outputs.run_id }}</span></span>
<span id="cb8-59"><a href="#cb8-59" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-60"><a href="#cb8-60" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">evaluate</span><span class="kw">:</span></span>
<span id="cb8-61"><a href="#cb8-61" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">needs</span><span class="kw">:</span><span class="at"> train</span></span>
<span id="cb8-62"><a href="#cb8-62" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">uses</span><span class="kw">:</span><span class="at"> ./.github/workflows/evaluate.yml</span></span>
<span id="cb8-63"><a href="#cb8-63" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">with</span><span class="kw">:</span></span>
<span id="cb8-64"><a href="#cb8-64" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">run_id</span><span class="kw">:</span><span class="at"> ${{ needs.train.outputs.run_id }}</span></span>
<span id="cb8-65"><a href="#cb8-65" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">secrets</span><span class="kw">:</span><span class="at"> inherit</span></span></code></pre></div></div>
</div>
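<p>Scraping the run ID from <code>mlflow run</code> output only works if the training code prints a <code>Run ID:</code> line in exactly that format. A sturdier alternative is to have the training entry point append the ID to the file GitHub Actions exposes as <code>$GITHUB_OUTPUT</code>, since any process in a step can set step outputs that way. The sketch below assumes two things. First, the entry point calls <code>mlflow.start_run()</code> itself. Second, because <code>mlflow run --env-manager local</code> launches the entry point as a child process, that process should inherit the step's environment, including <code>GITHUB_OUTPUT</code>. The helper name <code>export_step_output</code> is hypothetical.</p>

```python
import os


def export_step_output(name: str, value: str) -> bool:
    """Append name=value to the GitHub Actions step-output file, if one is configured.

    Returns True when the output was written, and False when $GITHUB_OUTPUT is unset
    (for example when the training script is run on a laptop).
    """
    if "\n" in value:
        # Only the single-line name=value form is handled; multi-line values
        # need GitHub's delimiter syntax instead.
        raise ValueError("value must be a single line")
    output_path = os.environ.get("GITHUB_OUTPUT")
    if not output_path:
        return False
    with open(output_path, "a", encoding="utf-8") as fh:
        fh.write(f"{name}={value}\n")
    return True


# Inside the training entry point (hypothetical layout):
#
#   import mlflow
#   with mlflow.start_run() as run:
#       export_step_output("run_id", run.info.run_id)
#       ...  # train and log as usual
```

<p>With this in place, the <code>training</code> step can drop the <code>grep</code>/<code>awk</code> pipeline, and <code>steps.training.outputs.run_id</code> still resolves, because the output is written by the entry point rather than parsed from logs.</p>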
</section>
<section id="evaluation-quality-gate-workflow" class="level3">
<h3 class="anchored" data-anchor-id="evaluation-quality-gate-workflow" id="evaluation-quality-gate-workflow">Evaluation &amp; Quality Gate Workflow</h3>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>.github/workflows/evaluate.yml</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb9" data-filename=".github/workflows/evaluate.yml"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="fu">name</span><span class="kw">:</span><span class="at"> Evaluate</span></span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a><span class="fu">on</span><span class="kw">:</span></span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">workflow_call</span><span class="kw">:</span></span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">inputs</span><span class="kw">:</span></span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">run_id</span><span class="kw">:</span></span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">required</span><span class="kw">:</span><span class="at"> </span><span class="ch">true</span></span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">type</span><span class="kw">:</span><span class="at"> string</span></span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a><span class="fu">jobs</span><span class="kw">:</span></span>
<span id="cb9-11"><a href="#cb9-11" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">quality-gate</span><span class="kw">:</span></span>
<span id="cb9-12"><a href="#cb9-12" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">runs-on</span><span class="kw">:</span><span class="at"> </span><span class="kw">[</span><span class="at">self-hosted</span><span class="kw">,</span><span class="at"> gpu</span><span class="kw">]</span></span>
<span id="cb9-13"><a href="#cb9-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-14"><a href="#cb9-14" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">steps</span><span class="kw">:</span></span>
<span id="cb9-15"><a href="#cb9-15" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="fu">uses</span><span class="kw">:</span><span class="at"> actions/checkout@v4</span></span>
<span id="cb9-16"><a href="#cb9-16" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="fu">uses</span><span class="kw">:</span><span class="at"> ./.github/actions/setup-mlflow</span></span>
<span id="cb9-17"><a href="#cb9-17" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">with</span><span class="kw">:</span></span>
<span id="cb9-18"><a href="#cb9-18" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="fu">mlflow-tracking-uri</span><span class="kw">:</span><span class="at"> ${{ secrets.MLFLOW_TRACKING_URI }}</span></span>
<span id="cb9-19"><a href="#cb9-19" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="fu">mlflow-s3-bucket</span><span class="kw">:</span><span class="at">    ${{ secrets.MLFLOW_S3_BUCKET }}</span></span>
<span id="cb9-20"><a href="#cb9-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-21"><a href="#cb9-21" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> Run evaluation suite</span></span>
<span id="cb9-22"><a href="#cb9-22" aria-hidden="true" tabindex="-1"></a><span class="fu">        run</span><span class="kw">: </span><span class="ch">|</span></span>
<span id="cb9-23"><a href="#cb9-23" aria-hidden="true" tabindex="-1"></a>          python -m src.evaluation.evaluate \</span>
<span id="cb9-24"><a href="#cb9-24" aria-hidden="true" tabindex="-1"></a>            --run-id ${{ inputs.run_id }} \</span>
<span id="cb9-25"><a href="#cb9-25" aria-hidden="true" tabindex="-1"></a>            --dataset test \</span>
<span id="cb9-26"><a href="#cb9-26" aria-hidden="true" tabindex="-1"></a>            --output-path eval_report.json</span>
<span id="cb9-27"><a href="#cb9-27" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-28"><a href="#cb9-28" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> Quality gate check</span></span>
<span id="cb9-29"><a href="#cb9-29" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">id</span><span class="kw">:</span><span class="at"> gate</span></span>
<span id="cb9-30"><a href="#cb9-30" aria-hidden="true" tabindex="-1"></a><span class="fu">        run</span><span class="kw">: </span><span class="ch">|</span></span>
<span id="cb9-31"><a href="#cb9-31" aria-hidden="true" tabindex="-1"></a>          python scripts/quality_gate.py \</span>
<span id="cb9-32"><a href="#cb9-32" aria-hidden="true" tabindex="-1"></a>            --run-id      ${{ inputs.run_id }} \</span>
<span id="cb9-33"><a href="#cb9-33" aria-hidden="true" tabindex="-1"></a>            --baseline    Production \</span>
<span id="cb9-34"><a href="#cb9-34" aria-hidden="true" tabindex="-1"></a>            --thresholds  configs/deployment/thresholds.yaml \</span>
<span id="cb9-35"><a href="#cb9-35" aria-hidden="true" tabindex="-1"></a>            --output      gate_result.json</span>
<span id="cb9-36"><a href="#cb9-36" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-37"><a href="#cb9-37" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> Upload evaluation report</span></span>
<span id="cb9-38"><a href="#cb9-38" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">uses</span><span class="kw">:</span><span class="at"> actions/upload-artifact@v4</span></span>
<span id="cb9-39"><a href="#cb9-39" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">with</span><span class="kw">:</span></span>
<span id="cb9-40"><a href="#cb9-40" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="fu">name</span><span class="kw">:</span><span class="at"> eval-report-${{ inputs.run_id }}</span></span>
<span id="cb9-41"><a href="#cb9-41" aria-hidden="true" tabindex="-1"></a><span class="fu">          path</span><span class="kw">: </span><span class="ch">|</span></span>
<span id="cb9-42"><a href="#cb9-42" aria-hidden="true" tabindex="-1"></a>            eval_report.json</span>
<span id="cb9-43"><a href="#cb9-43" aria-hidden="true" tabindex="-1"></a>            gate_result.json</span>
<span id="cb9-44"><a href="#cb9-44" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-45"><a href="#cb9-45" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> Register model if gate passes</span></span>
<span id="cb9-46"><a href="#cb9-46" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">if</span><span class="kw">:</span><span class="at"> ${{ steps.gate.outputs.passed == 'true' }}</span></span>
<span id="cb9-47"><a href="#cb9-47" aria-hidden="true" tabindex="-1"></a><span class="fu">        run</span><span class="kw">: </span><span class="ch">|</span></span>
<span id="cb9-48"><a href="#cb9-48" aria-hidden="true" tabindex="-1"></a>          python scripts/register_model.py \</span>
<span id="cb9-49"><a href="#cb9-49" aria-hidden="true" tabindex="-1"></a>            --run-id    ${{ inputs.run_id }} \</span>
<span id="cb9-50"><a href="#cb9-50" aria-hidden="true" tabindex="-1"></a>            --name      cv-model \</span>
<span id="cb9-51"><a href="#cb9-51" aria-hidden="true" tabindex="-1"></a>            --stage     Staging \</span>
<span id="cb9-52"><a href="#cb9-52" aria-hidden="true" tabindex="-1"></a>            --alias     "candidate-${{ github.sha }}"</span></code></pre></div></div>
</div>
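<p>The registration step reads <code>steps.gate.outputs.passed</code>, so <code>scripts/quality_gate.py</code> has to publish a <code>passed</code> step output. The script should also exit 0 when the gate fails: a non-zero exit would skip the report upload as well as registration. Below is a minimal sketch of the decision logic only. It makes three assumptions. First, every gated metric is "higher is better". Second, <code>thresholds.yaml</code> maps each metric to an absolute floor (<code>min</code>) and a largest allowed drop versus the Production baseline (<code>max_regression</code>); both key names are hypothetical. Third, loading metrics from MLFlow and writing <code>gate_result.json</code> happen elsewhere in the script.</p>

```python
from dataclasses import dataclass, field


@dataclass
class GateResult:
    passed: bool
    failures: list[str] = field(default_factory=list)


def evaluate_gate(candidate: dict, baseline: dict, thresholds: dict) -> GateResult:
    """Check candidate metrics against absolute floors and the Production baseline.

    thresholds (hypothetical schema), per metric name:
        {"min": absolute floor, "max_regression": largest allowed drop vs baseline}
    All metrics are assumed to be "higher is better".
    """
    failures = []
    for metric, rule in thresholds.items():
        if metric not in candidate:
            failures.append(f"{metric}: missing from candidate run")
            continue
        value = candidate[metric]
        floor = rule.get("min")
        if floor is not None and value < floor:
            failures.append(f"{metric}: {value:.4f} below floor {floor:.4f}")
        max_drop = rule.get("max_regression")
        # No baseline metric (e.g. nothing in Production yet) skips the regression check.
        if max_drop is not None and metric in baseline:
            drop = baseline[metric] - value
            if drop > max_drop:
                failures.append(
                    f"{metric}: dropped {drop:.4f} vs Production (allowed {max_drop:.4f})"
                )
    return GateResult(passed=not failures, failures=failures)


def to_step_output(result: GateResult) -> str:
    """Render the verdict as the key=value line to append to $GITHUB_OUTPUT."""
    return f"passed={'true' if result.passed else 'false'}\n"
```

<p>For example, with a hypothetical thresholds entry <code>{"top1_acc": {"min": 0.80, "max_regression": 0.01}}</code>, a candidate scoring 0.88 against a Production baseline of 0.90 fails the regression check even though it clears the floor. Collecting every failure, rather than stopping at the first, makes the uploaded <code>gate_result.json</code> more useful when a run is rejected.</p>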
<hr>
</section>
</section>
<section id="model-registry-workflow-with-mlflow" class="level2">
<h2 class="anchored" data-anchor-id="model-registry-workflow-with-mlflow" id="model-registry-workflow-with-mlflow">Model Registry Workflow with MLFlow</h2>
<section id="stage-transitions" class="level3">
<h3 class="anchored" data-anchor-id="stage-transitions" id="stage-transitions">Stage Transitions</h3>
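<p>Use MLFlow’s registry stages as a formal promotion pipeline: <code>None → Staging → Production</code>. Automation should never skip a stage; any jump past Staging should require manual approval. Be aware that MLFlow 2.9 deprecated registry stages in favor of model version aliases and tags, so on newer releases the same ladder is typically expressed by moving an alias between versions.</p>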
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div id="fig-promotion-ladder" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-fig figure">
<div aria-describedby="fig-promotion-ladder-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<div>
<pre class="mermaid mermaid-js" data-label="fig-promotion-ladder">flowchart TD
    RUN["🏋️ Training Run (GitHub Actions · GPU runner)"]
    STAGING["📦 Staging · registered candidate model"]
    PRODUCTION["🚀 Production · serving live traffic"]
    ARCHIVED["🗄️ Archived · retained for rollback"]

    RUN --&gt;|"quality gate passed · automated by evaluate.yml"| STAGING
    STAGING --&gt;|"manual approval in GitHub Environments OR integration tests pass · automated by deploy.yml"| PRODUCTION
    PRODUCTION --&gt;|"deprecate after N days or on next promotion"| ARCHIVED

    ARCHIVED -.-&gt;|"rollback path · rollback.yml"| PRODUCTION

    style RUN      fill:#cce5ff,stroke:#004085
    style STAGING  fill:#fff3cd,stroke:#856404
    style PRODUCTION fill:#d4edda,stroke:#28a745
    style ARCHIVED fill:#e2e3e5,stroke:#6c757d
</pre>
</div>
</div>
<figcaption class="quarto-float-caption-bottom quarto-float-caption quarto-float-fig" id="fig-promotion-ladder-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Figure&nbsp;2: MLFlow model registry stage promotion pipeline
</figcaption>
</figure>
</div>
</div>
</div>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>scripts/promote_model.py</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb10" data-filename="scripts/promote_model.py"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> os</span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> datetime <span class="im">import</span> datetime, timezone</span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> mlflow.tracking <span class="im">import</span> MlflowClient</span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> promote_to_production(model_name: <span class="bu">str</span>, staging_version: <span class="bu">str</span>):</span>
<span id="cb10-6"><a href="#cb10-6" aria-hidden="true" tabindex="-1"></a>    client <span class="op">=</span> MlflowClient()</span>
<span id="cb10-7"><a href="#cb10-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-8"><a href="#cb10-8" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Refuse to skip a stage: only a version currently in Staging may be promoted</span></span>
<span id="cb10-9"><a href="#cb10-9" aria-hidden="true" tabindex="-1"></a>    current <span class="op">=</span> client.get_model_version(model_name, staging_version).current_stage</span>
<span id="cb10-10"><a href="#cb10-10" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> current <span class="op">!=</span> <span class="st">"Staging"</span>:</span>
<span id="cb10-11"><a href="#cb10-11" aria-hidden="true" tabindex="-1"></a>        <span class="cf">raise</span> <span class="pp">ValueError</span>(<span class="ss">f"Version </span><span class="sc">{</span>staging_version<span class="sc">}</span><span class="ss"> is in '</span><span class="sc">{</span>current<span class="sc">}</span><span class="ss">', not 'Staging'"</span>)</span>
<span id="cb10-12"><a href="#cb10-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-13"><a href="#cb10-13" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Archive current Production before promoting</span></span>
<span id="cb10-14"><a href="#cb10-14" aria-hidden="true" tabindex="-1"></a>    prod_versions <span class="op">=</span> client.get_latest_versions(model_name, stages<span class="op">=</span>[<span class="st">"Production"</span>])</span>
<span id="cb10-15"><a href="#cb10-15" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> v <span class="kw">in</span> prod_versions:</span>
<span id="cb10-16"><a href="#cb10-16" aria-hidden="true" tabindex="-1"></a>        client.transition_model_version_stage(</span>
<span id="cb10-17"><a href="#cb10-17" aria-hidden="true" tabindex="-1"></a>            name<span class="op">=</span>model_name, version<span class="op">=</span>v.version, stage<span class="op">=</span><span class="st">"Archived"</span>,</span>
<span id="cb10-18"><a href="#cb10-18" aria-hidden="true" tabindex="-1"></a>            archive_existing_versions<span class="op">=</span><span class="va">False</span>,</span>
<span id="cb10-19"><a href="#cb10-19" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb10-20"><a href="#cb10-20" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Archived version </span><span class="sc">{</span>v<span class="sc">.</span>version<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb10-21"><a href="#cb10-21" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-22"><a href="#cb10-22" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Promote Staging to Production</span></span>
<span id="cb10-23"><a href="#cb10-23" aria-hidden="true" tabindex="-1"></a>    client.transition_model_version_stage(</span>
<span id="cb10-24"><a href="#cb10-24" aria-hidden="true" tabindex="-1"></a>        name<span class="op">=</span>model_name, version<span class="op">=</span>staging_version, stage<span class="op">=</span><span class="st">"Production"</span>,</span>
<span id="cb10-25"><a href="#cb10-25" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb10-26"><a href="#cb10-26" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Tag values must be strings, so fall back when GITHUB_ACTOR is unset</span></span>
<span id="cb10-27"><a href="#cb10-27" aria-hidden="true" tabindex="-1"></a>    client.set_model_version_tag(model_name, staging_version,</span>
<span id="cb10-28"><a href="#cb10-28" aria-hidden="true" tabindex="-1"></a>                                  <span class="st">"promoted_by"</span>, os.environ.get(<span class="st">"GITHUB_ACTOR"</span>, <span class="st">"unknown"</span>))</span>
<span id="cb10-29"><a href="#cb10-29" aria-hidden="true" tabindex="-1"></a>    client.set_model_version_tag(model_name, staging_version,</span>
<span id="cb10-30"><a href="#cb10-30" aria-hidden="true" tabindex="-1"></a>                                  <span class="st">"promoted_at"</span>, datetime.now(timezone.utc).isoformat())</span>
<span id="cb10-31"><a href="#cb10-31" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Promoted version </span><span class="sc">{</span>staging_version<span class="sc">}</span><span class="ss"> to Production ✓"</span>)</span></code></pre></div></div>
</div>
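<p>The promotion rule above can be encoded as a small, MLflow-independent policy check that a promotion script consults before calling <code>transition_model_version_stage</code>. This is a minimal sketch: the <code>TRANSITIONS</code> table, the <code>check_transition</code> helper, and the choice to treat rollback from Archived as manual-only are illustrative assumptions, not part of MLflow.</p>

```python
from typing import Optional

# Hypothetical policy table: (current_stage, target_stage) -> may automation do it alone?
# Pairs that are not listed are rejected outright.
TRANSITIONS = {
    ("None", "Staging"): True,          # quality gate passed (evaluate.yml)
    ("Staging", "Production"): True,    # integration tests passed (deploy.yml)
    ("Production", "Archived"): True,   # superseded by a newer promotion
    ("None", "Production"): False,      # stage skip: manual approval only
    ("Archived", "Production"): False,  # rollback: manual approval only (assumption)
}


def check_transition(current: str, target: str, approved_by: Optional[str] = None) -> bool:
    """Return True if moving a model version from `current` to `target` is allowed."""
    key = (current, target)
    if key not in TRANSITIONS:
        return False
    automated_ok = TRANSITIONS[key]
    # Manual-only moves need a named approver (e.g. a GitHub Environments reviewer)
    return automated_ok or approved_by is not None


print(check_transition("None", "Production"))                     # False
print(check_transition("None", "Production", approved_by="alice"))  # True
```

<p>In a promotion script, the current stage could be read from the model version's <code>current_stage</code> attribute and passed in, with <code>approved_by</code> set only when a reviewer has signed off. Keeping the policy as data makes it easy to review in a pull request.</p>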
</section>
<section id="quality-gate-script" class="level3">
<h3 class="anchored" data-anchor-id="quality-gate-script" id="quality-gate-script">Quality Gate Script</h3>
<p>Define acceptance thresholds in config, not hardcoded in scripts. This lets you tighten gates per dataset or model class without changing pipeline code.</p>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>configs/deployment/thresholds.yaml</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb11" data-filename="configs/deployment/thresholds.yaml"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a><span class="fu">min_improvement_over_baseline</span><span class="kw">:</span><span class="at"> </span><span class="fl">0.005</span><span class="co">   # absolute mAP gain ≥ 0.005 (0.5 points)</span></span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a><span class="fu">absolute_thresholds</span><span class="kw">:</span></span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">val/mAP</span><span class="kw">:</span><span class="at">       </span><span class="fl">0.72</span></span>
<span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">val/precision</span><span class="kw">:</span><span class="at"> </span><span class="fl">0.80</span></span>
<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">val/recall</span><span class="kw">:</span><span class="at">    </span><span class="fl">0.75</span></span>
<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a><span class="fu">regression_thresholds</span><span class="kw">:</span><span class="co">               # alert if drop is larger than these</span></span>
<span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">val/mAP</span><span class="kw">:</span><span class="at">       </span><span class="fl">0.02</span></span>
<span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a><span class="fu">max_latency_p95_ms</span><span class="kw">:</span><span class="at"> </span><span class="dv">150</span></span></code></pre></div></div>
</div>
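<p>The gate script subtracts raw mAP values, so <code>min_improvement_over_baseline: 0.005</code> means an absolute gain of 0.005 mAP (half a point), which is not the same as a 0.5% relative gain. The hypothetical helper below makes the difference concrete for an assumed baseline mAP of 0.74.</p>

```python
def required_map(baseline_map: float, min_improvement: float, relative: bool = False) -> float:
    """Smallest candidate mAP that clears the improvement gate.

    relative=False: min_improvement is in absolute mAP units (0.005 = +0.5 points).
    relative=True:  min_improvement is a fraction of the baseline (0.005 = +0.5%).
    """
    if relative:
        return baseline_map * (1.0 + min_improvement)
    return baseline_map + min_improvement


baseline = 0.74  # assumed Production mAP, for illustration only
print(f"absolute: {required_map(baseline, 0.005):.4f}")                 # absolute: 0.7450
print(f"relative: {required_map(baseline, 0.005, relative=True):.4f}")  # relative: 0.7437
```

<p>Whichever interpretation you choose, state it in the config comment so that tightening the gate later does not silently change its meaning.</p>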
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>scripts/quality_gate.py</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb12" data-filename="scripts/quality_gate.py"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><a href="#cb12-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> json</span>
<span id="cb12-2"><a href="#cb12-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> os</span>
<span id="cb12-3"><a href="#cb12-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> sys</span>
<span id="cb12-4"><a href="#cb12-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-5"><a href="#cb12-5" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> mlflow</span>
<span id="cb12-6"><a href="#cb12-6" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> yaml</span>
<span id="cb12-7"><a href="#cb12-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-8"><a href="#cb12-8" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> check_gate(run_id, baseline_stage, thresholds_path, output_path, model_name<span class="op">=</span><span class="st">"cv-model"</span>):</span>
<span id="cb12-9"><a href="#cb12-9" aria-hidden="true" tabindex="-1"></a>    client  <span class="op">=</span> mlflow.tracking.MlflowClient()</span>
<span id="cb12-10"><a href="#cb12-10" aria-hidden="true" tabindex="-1"></a>    run     <span class="op">=</span> client.get_run(run_id)</span>
<span id="cb12-11"><a href="#cb12-11" aria-hidden="true" tabindex="-1"></a>    metrics <span class="op">=</span> run.data.metrics</span>
<span id="cb12-12"><a href="#cb12-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-13"><a href="#cb12-13" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> <span class="bu">open</span>(thresholds_path) <span class="im">as</span> f:</span>
<span id="cb12-14"><a href="#cb12-14" aria-hidden="true" tabindex="-1"></a>        thresholds <span class="op">=</span> yaml.safe_load(f)</span>
<span id="cb12-15"><a href="#cb12-15" aria-hidden="true" tabindex="-1"></a>    results, passed <span class="op">=</span> {}, <span class="va">True</span></span>
<span id="cb12-16"><a href="#cb12-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-17"><a href="#cb12-17" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Absolute threshold checks</span></span>
<span id="cb12-18"><a href="#cb12-18" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> metric, min_val <span class="kw">in</span> thresholds[<span class="st">"absolute_thresholds"</span>].items():</span>
<span id="cb12-19"><a href="#cb12-19" aria-hidden="true" tabindex="-1"></a>        actual <span class="op">=</span> metrics.get(metric, <span class="fl">0.0</span>)</span>
<span id="cb12-20"><a href="#cb12-20" aria-hidden="true" tabindex="-1"></a>        ok     <span class="op">=</span> actual <span class="op">&gt;=</span> min_val</span>
<span id="cb12-21"><a href="#cb12-21" aria-hidden="true" tabindex="-1"></a>        results[metric] <span class="op">=</span> {<span class="st">"actual"</span>: actual, <span class="st">"threshold"</span>: min_val, <span class="st">"passed"</span>: ok}</span>
<span id="cb12-22"><a href="#cb12-22" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="kw">not</span> ok:</span>
<span id="cb12-23"><a href="#cb12-23" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f"FAIL </span><span class="sc">{</span>metric<span class="sc">}</span><span class="ss">: </span><span class="sc">{</span>actual<span class="sc">:.4f}</span><span class="ss"> &lt; </span><span class="sc">{</span>min_val<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb12-24"><a href="#cb12-24" aria-hidden="true" tabindex="-1"></a>            passed <span class="op">=</span> <span class="va">False</span></span>
<span id="cb12-25"><a href="#cb12-25" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-26"><a href="#cb12-26" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Latency ceiling (only checked if the run logged latency/p95_ms)</span></span>
<span id="cb12-27"><a href="#cb12-27" aria-hidden="true" tabindex="-1"></a>    p95, ceiling <span class="op">=</span> metrics.get(<span class="st">"latency/p95_ms"</span>), thresholds[<span class="st">"max_latency_p95_ms"</span>]</span>
<span id="cb12-28"><a href="#cb12-28" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> p95 <span class="kw">is</span> <span class="kw">not</span> <span class="va">None</span>:</span>
<span id="cb12-29"><a href="#cb12-29" aria-hidden="true" tabindex="-1"></a>        ok <span class="op">=</span> p95 <span class="op">&lt;=</span> ceiling</span>
<span id="cb12-30"><a href="#cb12-30" aria-hidden="true" tabindex="-1"></a>        results[<span class="st">"latency/p95_ms"</span>] <span class="op">=</span> {<span class="st">"actual"</span>: p95, <span class="st">"threshold"</span>: ceiling, <span class="st">"passed"</span>: ok}</span>
<span id="cb12-31"><a href="#cb12-31" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="kw">not</span> ok:</span>
<span id="cb12-32"><a href="#cb12-32" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f"FAIL latency: p95 </span><span class="sc">{</span>p95<span class="sc">:.1f}</span><span class="ss"> ms &gt; </span><span class="sc">{</span>ceiling<span class="sc">}</span><span class="ss"> ms"</span>)</span>
<span id="cb12-33"><a href="#cb12-33" aria-hidden="true" tabindex="-1"></a>            passed <span class="op">=</span> <span class="va">False</span></span>
<span id="cb12-34"><a href="#cb12-34" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-35"><a href="#cb12-35" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Improvement check vs baseline Production model; large drops also raise an alert</span></span>
<span id="cb12-36"><a href="#cb12-36" aria-hidden="true" tabindex="-1"></a>    <span class="cf">try</span>:</span>
<span id="cb12-37"><a href="#cb12-37" aria-hidden="true" tabindex="-1"></a>        baseline_versions <span class="op">=</span> client.get_latest_versions(model_name, stages<span class="op">=</span>[baseline_stage])</span>
<span id="cb12-38"><a href="#cb12-38" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> baseline_versions:</span>
<span id="cb12-39"><a href="#cb12-39" aria-hidden="true" tabindex="-1"></a>            baseline_run  <span class="op">=</span> client.get_run(baseline_versions[<span class="dv">0</span>].run_id)</span>
<span id="cb12-40"><a href="#cb12-40" aria-hidden="true" tabindex="-1"></a>            baseline_map  <span class="op">=</span> baseline_run.data.metrics.get(<span class="st">"val/mAP"</span>, <span class="fl">0.0</span>)</span>
<span id="cb12-41"><a href="#cb12-41" aria-hidden="true" tabindex="-1"></a>            candidate_map <span class="op">=</span> metrics.get(<span class="st">"val/mAP"</span>, <span class="fl">0.0</span>)</span>
<span id="cb12-42"><a href="#cb12-42" aria-hidden="true" tabindex="-1"></a>            delta     <span class="op">=</span> candidate_map <span class="op">-</span> baseline_map</span>
<span id="cb12-43"><a href="#cb12-43" aria-hidden="true" tabindex="-1"></a>            min_delta <span class="op">=</span> thresholds[<span class="st">"min_improvement_over_baseline"</span>]</span>
<span id="cb12-44"><a href="#cb12-44" aria-hidden="true" tabindex="-1"></a>            max_drop  <span class="op">=</span> thresholds[<span class="st">"regression_thresholds"</span>][<span class="st">"val/mAP"</span>]</span>
<span id="cb12-45"><a href="#cb12-45" aria-hidden="true" tabindex="-1"></a>            ok <span class="op">=</span> delta <span class="op">&gt;=</span> min_delta</span>
<span id="cb12-46"><a href="#cb12-46" aria-hidden="true" tabindex="-1"></a>            results[<span class="st">"regression_check"</span>] <span class="op">=</span> {</span>
<span id="cb12-47"><a href="#cb12-47" aria-hidden="true" tabindex="-1"></a>                <span class="st">"baseline_mAP"</span>: baseline_map, <span class="st">"candidate_mAP"</span>: candidate_map,</span>
<span id="cb12-48"><a href="#cb12-48" aria-hidden="true" tabindex="-1"></a>                <span class="st">"delta"</span>: delta, <span class="st">"min_delta"</span>: min_delta, <span class="st">"passed"</span>: ok</span>
<span id="cb12-49"><a href="#cb12-49" aria-hidden="true" tabindex="-1"></a>            }</span>
<span id="cb12-50"><a href="#cb12-50" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> delta <span class="op">&lt;</span> <span class="op">-</span>max_drop:</span>
<span id="cb12-51"><a href="#cb12-51" aria-hidden="true" tabindex="-1"></a>                <span class="bu">print</span>(<span class="ss">f"ALERT regression: mAP dropped by </span><span class="sc">{</span><span class="bu">abs</span>(delta)<span class="sc">:.4f}</span><span class="ss">"</span>)</span>
<span id="cb12-52"><a href="#cb12-52" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="kw">not</span> ok:</span>
<span id="cb12-53"><a href="#cb12-53" aria-hidden="true" tabindex="-1"></a>                <span class="bu">print</span>(<span class="ss">f"FAIL improvement: mAP delta </span><span class="sc">{</span>delta<span class="sc">:+.4f}</span><span class="ss"> &lt; required </span><span class="sc">{</span>min_delta<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb12-54"><a href="#cb12-54" aria-hidden="true" tabindex="-1"></a>                passed <span class="op">=</span> <span class="va">False</span></span>
<span id="cb12-55"><a href="#cb12-55" aria-hidden="true" tabindex="-1"></a>    <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:</span>
<span id="cb12-56"><a href="#cb12-56" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"WARN: Could not compare to baseline: </span><span class="sc">{</span>e<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb12-57"><a href="#cb12-57" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-58"><a href="#cb12-58" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> <span class="bu">open</span>(output_path, <span class="st">"w"</span>) <span class="im">as</span> f:</span>
<span id="cb12-59"><a href="#cb12-59" aria-hidden="true" tabindex="-1"></a>        json.dump({<span class="st">"passed"</span>: passed, <span class="st">"details"</span>: results}, f, indent<span class="op">=</span><span class="dv">2</span>)</span>
<span id="cb12-60"><a href="#cb12-60" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Gate result: </span><span class="sc">{</span><span class="st">'PASSED ✓'</span> <span class="cf">if</span> passed <span class="cf">else</span> <span class="st">'FAILED ✗'</span><span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb12-61"><a href="#cb12-61" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-62"><a href="#cb12-62" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Write GitHub Actions step output (key=value, one per line)</span></span>
<span id="cb12-63"><a href="#cb12-63" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> <span class="st">"GITHUB_OUTPUT"</span> <span class="kw">in</span> os.environ:</span>
<span id="cb12-64"><a href="#cb12-64" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> <span class="bu">open</span>(os.environ[<span class="st">"GITHUB_OUTPUT"</span>], <span class="st">"a"</span>) <span class="im">as</span> f:</span>
<span id="cb12-65"><a href="#cb12-65" aria-hidden="true" tabindex="-1"></a>            f.write(<span class="ss">f"passed=</span><span class="sc">{</span><span class="st">'true'</span> <span class="cf">if</span> passed <span class="cf">else</span> <span class="st">'false'</span><span class="sc">}</span><span class="ch">\n</span><span class="ss">"</span>)</span>
<span id="cb12-66"><a href="#cb12-66" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-67"><a href="#cb12-67" aria-hidden="true" tabindex="-1"></a>    sys.exit(<span class="dv">0</span> <span class="cf">if</span> passed <span class="cf">else</span> <span class="dv">1</span>)</span></code></pre></div></div>
</div>
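<p>Separating the pass/fail decision from MLflow lookups, file writes, and <code>GITHUB_OUTPUT</code> makes the gate unit-testable without a tracking server. A minimal sketch, assuming <code>thresholds.yaml</code> has already been parsed into a dict and that the baseline mAP (if any) was fetched by the caller; <code>evaluate_gate</code> is a hypothetical name, and skipping the latency check when no <code>latency/p95_ms</code> metric was logged is a design choice you may want to reverse.</p>

```python
from typing import Dict, Optional, Tuple


def evaluate_gate(
    metrics: Dict[str, float],
    thresholds: dict,
    baseline_map: Optional[float] = None,
) -> Tuple[bool, Dict[str, dict]]:
    """Pure gate logic: no MLflow calls, no files, no environment variables."""
    details: Dict[str, dict] = {}

    # Absolute floors: a missing metric counts as 0.0 and therefore fails
    for name, floor in thresholds["absolute_thresholds"].items():
        actual = metrics.get(name, 0.0)
        details[name] = {"actual": actual, "threshold": floor, "passed": actual >= floor}

    # Improvement over the current Production model, if one exists
    if baseline_map is not None:
        delta = metrics.get("val/mAP", 0.0) - baseline_map
        min_delta = thresholds["min_improvement_over_baseline"]
        details["baseline_delta"] = {"delta": delta, "threshold": min_delta,
                                     "passed": delta >= min_delta}

    # Latency ceiling, only checked when the run logged a p95 value
    if "latency/p95_ms" in metrics:
        p95 = metrics["latency/p95_ms"]
        ceiling = thresholds["max_latency_p95_ms"]
        details["latency/p95_ms"] = {"actual": p95, "threshold": ceiling, "passed": p95 <= ceiling}

    passed = all(d["passed"] for d in details.values())
    return passed, details
```

<p>With this split, the CLI script shrinks to fetching <code>run.data.metrics</code> and the baseline, calling the pure function, and writing the JSON report and step output; the decision logic itself can then be covered by ordinary unit tests.</p>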
<hr>
</section>
</section>
<section id="data-artifact-versioning" class="level2">
<h2 class="anchored" data-anchor-id="data-artifact-versioning" id="data-artifact-versioning">Data &amp; Artifact Versioning</h2>
<section id="dvc-mlflow-integration" class="level3">
<h3 class="anchored" data-anchor-id="dvc-mlflow-integration" id="dvc-mlflow-integration">DVC + MLFlow Integration</h3>
<p>Data versioning matters as much as code versioning in CV. Use DVC for data and MLFlow for model artifacts, and link the two explicitly by recording the data version on every training run.</p>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>src/training/data_utils.py</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb13" data-filename="src/training/data_utils.py"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> hashlib</span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> subprocess</span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> pathlib <span class="im">import</span> Path</span>
<span id="cb13-4"><a href="#cb13-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-5"><a href="#cb13-5" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> mlflow</span>
<span id="cb13-6"><a href="#cb13-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-7"><a href="#cb13-7" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> get_data_hash(data_dir: <span class="bu">str</span>) <span class="op">-&gt;</span> <span class="bu">str</span>:</span>
<span id="cb13-8"><a href="#cb13-8" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Short (12-char) SHA-256 of the DVC lock file for this dataset."""</span></span>
<span id="cb13-9"><a href="#cb13-9" aria-hidden="true" tabindex="-1"></a>    lock <span class="op">=</span> Path(data_dir).parent <span class="op">/</span> <span class="st">"dvc.lock"</span></span>
<span id="cb13-10"><a href="#cb13-10" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> hashlib.sha256(lock.read_bytes()).hexdigest()[:<span class="dv">12</span>]</span>
<span id="cb13-11"><a href="#cb13-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-12"><a href="#cb13-12" aria-hidden="true" tabindex="-1"></a><span class="co"># Log the data version alongside the model: the git commit pins the .dvc / dvc.lock files</span></span>
<span id="cb13-13"><a href="#cb13-13" aria-hidden="true" tabindex="-1"></a><span class="cf">with</span> mlflow.start_run():</span>
<span id="cb13-14"><a href="#cb13-14" aria-hidden="true" tabindex="-1"></a>    git_commit <span class="op">=</span> subprocess.check_output(</span>
<span id="cb13-15"><a href="#cb13-15" aria-hidden="true" tabindex="-1"></a>        [<span class="st">"git"</span>, <span class="st">"rev-parse"</span>, <span class="st">"HEAD"</span>], text<span class="op">=</span><span class="va">True</span></span>
<span id="cb13-16"><a href="#cb13-16" aria-hidden="true" tabindex="-1"></a>    ).strip()</span>
<span id="cb13-17"><a href="#cb13-17" aria-hidden="true" tabindex="-1"></a>    mlflow.set_tag(<span class="st">"data.git_commit"</span>, git_commit)</span>
<span id="cb13-18"><a href="#cb13-18" aria-hidden="true" tabindex="-1"></a>    mlflow.set_tag(<span class="st">"data.dvc_hash"</span>, get_data_hash(<span class="st">"data/processed"</span>))</span>
<span id="cb13-19"><a href="#cb13-19" aria-hidden="true" tabindex="-1"></a>    mlflow.log_artifact(<span class="st">"data.dvc"</span>, <span class="st">"data_version"</span>)    <span class="co"># log the .dvc pointer file</span></span></code></pre></div></div>
</div>
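<p>When a dataset is not produced by a DVC pipeline, there may be no <code>dvc.lock</code> to hash. One alternative, swapped in here rather than taken from DVC itself, is to fingerprint the processed directory directly from relative paths and file contents. A sketch, assuming the directory is small enough to walk on every run; <code>fingerprint_dir</code> is a hypothetical helper, and its result is not the hash DVC records in its own files.</p>

```python
import hashlib
from pathlib import Path


def _sha256_file(path: Path) -> str:
    """Hash one file in 1 MiB chunks so large files are not loaded whole."""
    h = hashlib.sha256()
    with path.open("rb") as fh:
        for chunk in iter(lambda: fh.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()


def fingerprint_dir(data_dir: str, n_chars: int = 12) -> str:
    """Deterministic SHA-256 over every file's relative path and contents."""
    root = Path(data_dir)
    # Sort by relative path so the result does not depend on filesystem listing order
    files = sorted(
        (p.relative_to(root).as_posix(), p) for p in root.rglob("*") if p.is_file()
    )
    digest = hashlib.sha256()
    for rel, path in files:
        # One "path NUL file-hash" record per file; paths cannot contain NUL
        digest.update(f"{rel}\0{_sha256_file(path)}\n".encode("utf-8"))
    return digest.hexdigest()[:n_chars]
```

<p>The result could be logged with something like <code>mlflow.set_tag("data.content_hash", fingerprint_dir("data/processed"))</code> (tag name hypothetical). Hashing relative paths as well as bytes means a renamed or edited file changes the fingerprint, while listing order does not.</p>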
</section>
<section id="artifact-storage-hierarchy" class="level3">
<h3 class="anchored" data-anchor-id="artifact-storage-hierarchy" id="artifact-storage-hierarchy">Artifact Storage Hierarchy</h3>
<p>Organise S3/artifact storage so old experiments are easy to find and prune:</p>
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div id="fig-s3-layout" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-fig figure">
<div aria-describedby="fig-s3-layout-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<div>
<pre class="mermaid mermaid-js" data-label="fig-s3-layout">flowchart TD
    BUCKET["🪣 s3://your-bucket/mlflow/"]
    EXP["{experiment_id}/"]
    RUN["{run_id}/"]
    ART["artifacts/"]
    MET["metrics/ MLFlow metric files auto-managed"]

    MODEL["model/ PyTorch · ONNX weights"]
    EVAL["eval/ Confusion matrices PR curves"]
    CONFIG["config/ Full config snapshot"]
    DATA["data_version/ DVC pointer file"]

    BUCKET --&gt; EXP
    EXP --&gt; RUN
    RUN --&gt; ART
    RUN --&gt; MET
    ART --&gt; MODEL
    ART --&gt; EVAL
    ART --&gt; CONFIG
    ART --&gt; DATA

    style BUCKET fill:#fff3cd,stroke:#856404
    style MODEL  fill:#cce5ff,stroke:#004085
    style EVAL   fill:#d4edda,stroke:#28a745
    style CONFIG fill:#e2d9f3,stroke:#6f42c1
    style DATA   fill:#f8d7da,stroke:#721c24
</pre>
</div>
</div>
<figcaption class="quarto-float-caption-bottom quarto-float-caption quarto-float-fig" id="fig-s3-layout-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Figure&nbsp;3: S3 artifact storage hierarchy under MLFlow
</figcaption>
</figure>
</div>
</div>
</div>
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Warning</span>Artifact Retention
</div>
</div>
<div class="callout-body-container callout-body">
<p>Set artifact retention policies at the <strong>storage level</strong> (S3 lifecycle rules, GCS Object Lifecycle). Don’t rely on deleting runs from the MLFlow UI: that only marks the run as deleted in the tracking store and leaves its binaries orphaned in object storage.</p>
</div>
</div>
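<p>To find orphaned binaries under the layout in Figure&nbsp;3, one approach is to compare object keys against the run IDs that are still active in the tracking server. A minimal sketch, assuming the keys were listed beforehand (for example with a paginated S3 listing) and the live run IDs were collected with <code>MlflowClient.search_runs</code>; <code>orphaned_run_prefixes</code> and the <code>mlflow/</code> root are illustrative.</p>

```python
from typing import Iterable, Set


def orphaned_run_prefixes(keys: Iterable[str], live_run_ids: Set[str],
                          root: str = "mlflow/") -> Set[str]:
    """Return '{root}{experiment_id}/{run_id}/' prefixes whose run is no longer live.

    Assumes keys follow root/experiment_id/run_id/... as in the storage hierarchy.
    """
    orphans: Set[str] = set()
    for key in keys:
        if not key.startswith(root):
            continue
        parts = key[len(root):].split("/")
        if len(parts) < 3:  # too shallow to be a run artifact (e.g. a README)
            continue
        experiment_id, run_id = parts[0], parts[1]
        if run_id not in live_run_ids:
            orphans.add(f"{root}{experiment_id}/{run_id}/")
    return orphans
```

<p>Treat the output as a review list rather than a delete list, and add the run IDs behind any registered model versions to <code>live_run_ids</code> so archived models kept for rollback are never flagged.</p>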
<hr>
</section>
</section>
<section id="cv-specific-quality-gates" class="level2">
<h2 class="anchored" data-anchor-id="cv-specific-quality-gates" id="cv-specific-quality-gates">CV-Specific Quality Gates</h2>
<p>Beyond mAP, production CV systems require domain-specific checks that generic ML pipelines miss.</p>
<section id="per-class-performance-gate" class="level3">
<h3 class="anchored" data-anchor-id="per-class-performance-gate" id="per-class-performance-gate">Per-Class Performance Gate</h3>
<p>A model that improves aggregate mAP but collapses a safety-critical class should be blocked.</p>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>scripts/per_class_gate.py</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb14" data-filename="scripts/per_class_gate.py"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb14-1"><a href="#cb14-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> mlflow</span>
<span id="cb14-2"><a href="#cb14-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-3"><a href="#cb14-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> check_per_class_thresholds(run_id: <span class="bu">str</span>, min_per_class_ap: <span class="bu">float</span> <span class="op">=</span> <span class="fl">0.60</span>) <span class="op">-&gt;</span> <span class="bu">bool</span>:</span>
<span id="cb14-4"><a href="#cb14-4" aria-hidden="true" tabindex="-1"></a>    client <span class="op">=</span> mlflow.tracking.MlflowClient()</span>
<span id="cb14-5"><a href="#cb14-5" aria-hidden="true" tabindex="-1"></a>    run    <span class="op">=</span> client.get_run(run_id)</span>
<span id="cb14-6"><a href="#cb14-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-7"><a href="#cb14-7" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Expect per-class AP logged as "class/{classname}/AP"; slice off prefix and suffix</span></span>
<span id="cb14-8"><a href="#cb14-8" aria-hidden="true" tabindex="-1"></a>    class_aps <span class="op">=</span> {</span>
<span id="cb14-9"><a href="#cb14-9" aria-hidden="true" tabindex="-1"></a>        k[<span class="bu">len</span>(<span class="st">"class/"</span>):<span class="op">-</span><span class="bu">len</span>(<span class="st">"/AP"</span>)]: v</span>
<span id="cb14-10"><a href="#cb14-10" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> k, v <span class="kw">in</span> run.data.metrics.items()</span>
<span id="cb14-11"><a href="#cb14-11" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> k.startswith(<span class="st">"class/"</span>) <span class="kw">and</span> k.endswith(<span class="st">"/AP"</span>)</span>
<span id="cb14-12"><a href="#cb14-12" aria-hidden="true" tabindex="-1"></a>    }</span>
<span id="cb14-13"><a href="#cb14-13" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> <span class="kw">not</span> class_aps:</span>
<span id="cb14-14"><a href="#cb14-14" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Fail closed: missing per-class metrics should not pass silently</span></span>
<span id="cb14-15"><a href="#cb14-15" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="st">"No per-class AP metrics found"</span>)</span>
<span id="cb14-16"><a href="#cb14-16" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">False</span></span>
<span id="cb14-17"><a href="#cb14-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-18"><a href="#cb14-18" aria-hidden="true" tabindex="-1"></a>    failing <span class="op">=</span> {cls: ap <span class="cf">for</span> cls, ap <span class="kw">in</span> class_aps.items() <span class="cf">if</span> ap <span class="op">&lt;</span> min_per_class_ap}</span>
<span id="cb14-19"><a href="#cb14-19" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> failing:</span>
<span id="cb14-20"><a href="#cb14-20" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="st">"Per-class failures:"</span>)</span>
<span id="cb14-21"><a href="#cb14-21" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> cls, ap <span class="kw">in</span> failing.items():</span>
<span id="cb14-22"><a href="#cb14-22" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f"  </span><span class="sc">{</span>cls<span class="sc">}</span><span class="ss">: AP=</span><span class="sc">{</span>ap<span class="sc">:.3f}</span><span class="ss"> &lt; </span><span class="sc">{</span>min_per_class_ap<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb14-23"><a href="#cb14-23" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">False</span></span>
<span id="cb14-24"><a href="#cb14-24" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> <span class="va">True</span></span></code></pre></div></div>
</div>
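<p>A single floor treats every class alike. A variation worth considering gives safety-critical classes a stricter floor and fails when a required class is missing from the logged metrics instead of passing silently. A sketch under assumed names: <code>per_class_failures</code>, the <code>overrides</code> map, and the example class names are hypothetical.</p>

```python
from typing import Dict, Iterable, Optional


def per_class_failures(
    class_aps: Dict[str, float],
    default_min_ap: float = 0.60,
    overrides: Optional[Dict[str, float]] = None,
    required_classes: Iterable[str] = (),
) -> Dict[str, Optional[float]]:
    """Map each failing class to its AP (None if the class was never logged)."""
    overrides = overrides or {}
    failures: Dict[str, Optional[float]] = {}
    for cls, ap in class_aps.items():
        # Safety-critical classes can carry a stricter floor via `overrides`
        if ap < overrides.get(cls, default_min_ap):
            failures[cls] = ap
    # A required class with no logged AP is a failure, not a silent pass
    for cls in required_classes:
        if cls not in class_aps:
            failures[cls] = None
    return failures


aps = {"car": 0.81, "pedestrian": 0.78, "cyclist": 0.58}  # illustrative values
print(per_class_failures(aps, overrides={"pedestrian": 0.85}))
# {'pedestrian': 0.78, 'cyclist': 0.58}
```

<p>The AP dict would come from the same <code>class/{classname}/AP</code> parsing as the gate above, and the overrides could live next to the other thresholds in <code>thresholds.yaml</code>.</p>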
</section>
<section id="inference-latency-gate" class="level3">
<h3 class="anchored" data-anchor-id="inference-latency-gate" id="inference-latency-gate">Inference Latency Gate</h3>
<p>Log latency during evaluation, not just accuracy: a model that is 2× slower is often a deployment blocker regardless of its mAP.</p>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>src/evaluation/latency.py</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb15" data-filename="src/evaluation/latency.py"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb15-1"><a href="#cb15-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb15-2"><a href="#cb15-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-3"><a href="#cb15-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> mlflow</span>
<span id="cb15-4"><a href="#cb15-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb15-5"><a href="#cb15-5" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb15-6"><a href="#cb15-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-7"><a href="#cb15-7" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> benchmark_inference(model, input_size<span class="op">=</span>(<span class="dv">1</span>, <span class="dv">3</span>, <span class="dv">640</span>, <span class="dv">640</span>), n_warmup<span class="op">=</span><span class="dv">10</span>, n_iters<span class="op">=</span><span class="dv">100</span>, device<span class="op">=</span><span class="st">"cuda"</span>):</span>
<span id="cb15-8"><a href="#cb15-8" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> model.to(device).<span class="bu">eval</span>()</span>
<span id="cb15-9"><a href="#cb15-9" aria-hidden="true" tabindex="-1"></a>    dummy <span class="op">=</span> torch.randn(<span class="op">*</span>input_size, device<span class="op">=</span>device)</span>
<span id="cb15-10"><a href="#cb15-10" aria-hidden="true" tabindex="-1"></a>    on_cuda <span class="op">=</span> torch.device(device).<span class="bu">type</span> <span class="op">==</span> <span class="st">"cuda"</span></span>
<span id="cb15-11"><a href="#cb15-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-12"><a href="#cb15-12" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> sync():</span>
<span id="cb15-13"><a href="#cb15-13" aria-hidden="true" tabindex="-1"></a>        <span class="co"># CUDA kernels run asynchronously; wait for them before reading the clock</span></span>
<span id="cb15-14"><a href="#cb15-14" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> on_cuda:</span>
<span id="cb15-15"><a href="#cb15-15" aria-hidden="true" tabindex="-1"></a>            torch.cuda.synchronize()</span>
<span id="cb15-16"><a href="#cb15-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-17"><a href="#cb15-17" aria-hidden="true" tabindex="-1"></a>    times <span class="op">=</span> []</span>
<span id="cb15-18"><a href="#cb15-18" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> torch.no_grad():</span>
<span id="cb15-19"><a href="#cb15-19" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Warm up</span></span>
<span id="cb15-20"><a href="#cb15-20" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(n_warmup):</span>
<span id="cb15-21"><a href="#cb15-21" aria-hidden="true" tabindex="-1"></a>            model(dummy)</span>
<span id="cb15-22"><a href="#cb15-22" aria-hidden="true" tabindex="-1"></a>        sync()</span>
<span id="cb15-23"><a href="#cb15-23" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-24"><a href="#cb15-24" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(n_iters):</span>
<span id="cb15-25"><a href="#cb15-25" aria-hidden="true" tabindex="-1"></a>            t0 <span class="op">=</span> time.perf_counter()</span>
<span id="cb15-26"><a href="#cb15-26" aria-hidden="true" tabindex="-1"></a>            model(dummy)</span>
<span id="cb15-27"><a href="#cb15-27" aria-hidden="true" tabindex="-1"></a>            sync()</span>
<span id="cb15-28"><a href="#cb15-28" aria-hidden="true" tabindex="-1"></a>            times.append((time.perf_counter() <span class="op">-</span> t0) <span class="op">*</span> <span class="dv">1000</span>)</span>
<span id="cb15-29"><a href="#cb15-29" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-30"><a href="#cb15-30" aria-hidden="true" tabindex="-1"></a>    stats <span class="op">=</span> {</span>
<span id="cb15-31"><a href="#cb15-31" aria-hidden="true" tabindex="-1"></a>        <span class="st">"latency/mean_ms"</span>: <span class="bu">float</span>(np.mean(times)),</span>
<span id="cb15-32"><a href="#cb15-32" aria-hidden="true" tabindex="-1"></a>        <span class="st">"latency/p95_ms"</span>:  <span class="bu">float</span>(np.percentile(times, <span class="dv">95</span>)),</span>
<span id="cb15-33"><a href="#cb15-33" aria-hidden="true" tabindex="-1"></a>        <span class="st">"latency/p99_ms"</span>:  <span class="bu">float</span>(np.percentile(times, <span class="dv">99</span>)),</span>
<span id="cb15-34"><a href="#cb15-34" aria-hidden="true" tabindex="-1"></a>    }</span>
<span id="cb15-35"><a href="#cb15-35" aria-hidden="true" tabindex="-1"></a>    mlflow.log_metrics(stats)</span>
<span id="cb15-36"><a href="#cb15-36" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> stats</span></code></pre></div></div>
</div>
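<p>The measured p95 can then be checked both against the absolute <code>max_latency_p95_ms</code> ceiling from <code>thresholds.yaml</code> and against the current Production model, which catches the “2× slower” case even when both models sit under the ceiling. A sketch: <code>latency_gate</code> and its <code>max_slowdown=2.0</code> default are assumptions, and the comparison is only meaningful when both p95 values were measured on the same hardware with the same input size.</p>

```python
from typing import Optional, Tuple


def latency_gate(
    candidate_p95_ms: float,
    max_p95_ms: float,
    baseline_p95_ms: Optional[float] = None,
    max_slowdown: float = 2.0,
) -> Tuple[bool, str]:
    """Fail on an absolute p95 ceiling or on a large slowdown versus the baseline."""
    if candidate_p95_ms > max_p95_ms:
        return False, f"p95 {candidate_p95_ms:.1f} ms exceeds ceiling {max_p95_ms:.1f} ms"
    if baseline_p95_ms is not None and baseline_p95_ms > 0:
        ratio = candidate_p95_ms / baseline_p95_ms
        if ratio >= max_slowdown:
            return False, f"p95 is {ratio:.2f}x the baseline ({baseline_p95_ms:.1f} ms)"
    return True, "ok"


print(latency_gate(100.0, 150.0, baseline_p95_ms=40.0))
# (False, 'p95 is 2.50x the baseline (40.0 ms)')
```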
</section>
<section id="distribution-shift-detection-gate-pre-deploy" class="level3">
<h3 class="anchored" data-anchor-id="distribution-shift-detection-gate-pre-deploy" id="distribution-shift-detection-gate-pre-deploy">Distribution Shift Detection Gate (Pre-Deploy)</h3>
<p>Before deploying to production, validate the candidate model on a held-out dataset that reflects recent production traffic, not just on the original test split.</p>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>In evaluate.yml — production distribution check</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb16" data-filename="In evaluate.yml — production distribution check"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb16-1"><a href="#cb16-1" aria-hidden="true" tabindex="-1"></a><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> Check distribution shift robustness</span></span>
<span id="cb16-2"><a href="#cb16-2" aria-hidden="true" tabindex="-1"></a><span class="fu">  run</span><span class="kw">: </span><span class="ch">|</span></span>
<span id="cb16-3"><a href="#cb16-3" aria-hidden="true" tabindex="-1"></a>    python scripts/eval_production_sample.py \</span>
<span id="cb16-4"><a href="#cb16-4" aria-hidden="true" tabindex="-1"></a>      --run-id ${{ inputs.run_id }} \</span>
<span id="cb16-5"><a href="#cb16-5" aria-hidden="true" tabindex="-1"></a>      --dataset-path data/production_sample/latest \</span>
<span id="cb16-6"><a href="#cb16-6" aria-hidden="true" tabindex="-1"></a>      --min-mAP 0.65          # Lower threshold for noisy production data</span></code></pre></div></div>
</div>
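<p>Besides the absolute <code>--min-mAP</code> floor, a large gap between test-split mAP and production-sample mAP suggests the test split no longer represents live traffic. A minimal sketch of such a check: <code>shift_gate</code> and its <code>max_gap=0.10</code> default are hypothetical, while <code>min_prod_map=0.65</code> mirrors the workflow step above; how <code>eval_production_sample.py</code> computes mAP is not shown here.</p>

```python
from typing import Tuple


def shift_gate(
    test_map: float,
    prod_sample_map: float,
    min_prod_map: float = 0.65,
    max_gap: float = 0.10,
) -> Tuple[bool, str]:
    """Check mAP on recent production data and its gap to the original test split."""
    gap = test_map - prod_sample_map
    if prod_sample_map < min_prod_map:
        return False, f"production-sample mAP {prod_sample_map:.3f} < {min_prod_map}"
    if gap > max_gap:
        # Passing the floor but trailing the test split badly hints at distribution shift
        return False, f"test/production gap {gap:.3f} > {max_gap}"
    return True, f"gap {gap:.3f}"
```

<p>A failing gap check is usually a signal to refresh the test split or collect new labelled production data, rather than to lower the floor.</p>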
<hr>
</section>
</section>
<section id="monitoring-drift-detection-in-production" class="level2">
<h2 class="anchored" data-anchor-id="monitoring-drift-detection-in-production" id="monitoring-drift-detection-in-production">Monitoring &amp; Drift Detection in Production</h2>
<p>Closing the loop between production and CI is what separates “deployed” from “operational.”</p>
<section id="scheduled-drift-detection-workflow" class="level3">
<h3 class="anchored" data-anchor-id="scheduled-drift-detection-workflow" id="scheduled-drift-detection-workflow">Scheduled Drift Detection Workflow</h3>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>.github/workflows/drift_monitor.yml</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb17" data-filename=".github/workflows/drift_monitor.yml"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb17-1"><a href="#cb17-1" aria-hidden="true" tabindex="-1"></a><span class="fu">name</span><span class="kw">:</span><span class="at"> Production Drift Monitor</span></span>
<span id="cb17-2"><a href="#cb17-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-3"><a href="#cb17-3" aria-hidden="true" tabindex="-1"></a><span class="fu">on</span><span class="kw">:</span></span>
<span id="cb17-4"><a href="#cb17-4" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">schedule</span><span class="kw">:</span></span>
<span id="cb17-5"><a href="#cb17-5" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="kw">-</span><span class="at"> </span><span class="fu">cron</span><span class="kw">:</span><span class="at"> </span><span class="st">"0 6 * * *"</span><span class="co">     # Daily at 06:00 UTC</span></span>
<span id="cb17-6"><a href="#cb17-6" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">workflow_dispatch</span><span class="kw">:</span></span>
<span id="cb17-7"><a href="#cb17-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-8"><a href="#cb17-8" aria-hidden="true" tabindex="-1"></a><span class="fu">jobs</span><span class="kw">:</span></span>
<span id="cb17-9"><a href="#cb17-9" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">detect-drift</span><span class="kw">:</span></span>
<span id="cb17-10"><a href="#cb17-10" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">runs-on</span><span class="kw">:</span><span class="at"> ubuntu-latest</span></span>
<span id="cb17-11"><a href="#cb17-11" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">steps</span><span class="kw">:</span></span>
<span id="cb17-12"><a href="#cb17-12" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="fu">uses</span><span class="kw">:</span><span class="at"> actions/checkout@v4</span></span>
<span id="cb17-13"><a href="#cb17-13" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="fu">uses</span><span class="kw">:</span><span class="at"> ./.github/actions/setup-mlflow</span></span>
<span id="cb17-14"><a href="#cb17-14" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">with</span><span class="kw">:</span></span>
<span id="cb17-15"><a href="#cb17-15" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="fu">mlflow-tracking-uri</span><span class="kw">:</span><span class="at"> ${{ secrets.MLFLOW_TRACKING_URI }}</span></span>
<span id="cb17-16"><a href="#cb17-16" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="fu">mlflow-s3-bucket</span><span class="kw">:</span><span class="at">    ${{ secrets.MLFLOW_S3_BUCKET }}</span></span>
<span id="cb17-17"><a href="#cb17-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-18"><a href="#cb17-18" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> Sample production predictions</span></span>
<span id="cb17-19"><a href="#cb17-19" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">run</span><span class="kw">:</span><span class="at"> python scripts/sample_production_logs.py --n 1000 --output prod_sample.parquet</span></span>
<span id="cb17-20"><a href="#cb17-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-21"><a href="#cb17-21" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> Run drift detection</span></span>
<span id="cb17-22"><a href="#cb17-22" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">id</span><span class="kw">:</span><span class="at"> drift</span></span>
<span id="cb17-23"><a href="#cb17-23" aria-hidden="true" tabindex="-1"></a><span class="fu">        run</span><span class="kw">: </span><span class="ch">|</span></span>
<span id="cb17-24"><a href="#cb17-24" aria-hidden="true" tabindex="-1"></a>          python scripts/detect_drift.py \</span>
<span id="cb17-25"><a href="#cb17-25" aria-hidden="true" tabindex="-1"></a>            --production-sample prod_sample.parquet \</span>
<span id="cb17-26"><a href="#cb17-26" aria-hidden="true" tabindex="-1"></a>            --reference-dataset data/processed/latest \</span>
<span id="cb17-27"><a href="#cb17-27" aria-hidden="true" tabindex="-1"></a>            --model-stage Production \</span>
<span id="cb17-28"><a href="#cb17-28" aria-hidden="true" tabindex="-1"></a>            --output drift_report.json</span>
<span id="cb17-29"><a href="#cb17-29" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-30"><a href="#cb17-30" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> Alert if drift detected</span></span>
<span id="cb17-31"><a href="#cb17-31" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">if</span><span class="kw">:</span><span class="at"> ${{ steps.drift.outputs.drift_detected == 'true' }}</span></span>
<span id="cb17-32"><a href="#cb17-32" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">uses</span><span class="kw">:</span><span class="at"> slackapi/slack-github-action@v1</span></span>
<span id="cb17-33"><a href="#cb17-33" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">with</span><span class="kw">:</span></span>
<span id="cb17-34"><a href="#cb17-34" aria-hidden="true" tabindex="-1"></a><span class="fu">          payload</span><span class="kw">: </span><span class="ch">|</span></span>
<span id="cb17-35"><a href="#cb17-35" aria-hidden="true" tabindex="-1"></a>            {"text": "⚠️ Production drift detected. mAP degraded by ${{ steps.drift.outputs.map_delta }}. Consider re-training."}</span>
<span id="cb17-36"><a href="#cb17-36" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">env</span><span class="kw">:</span></span>
<span id="cb17-37"><a href="#cb17-37" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="fu">SLACK_WEBHOOK_URL</span><span class="kw">:</span><span class="at"> ${{ secrets.SLACK_WEBHOOK_URL }}</span></span>
<span id="cb17-38"><a href="#cb17-38" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-39"><a href="#cb17-39" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> Log drift metrics to MLFlow</span></span>
<span id="cb17-40"><a href="#cb17-40" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">run</span><span class="kw">:</span><span class="at"> python scripts/log_drift_to_mlflow.py --report drift_report.json</span></span></code></pre></div></div>
</div>
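<p>The <code>if:</code> condition above only fires if <code>detect_drift.py</code> publishes <code>drift_detected</code> as a step output. A script does that by appending <code>name=value</code> lines to the file named in the <code>GITHUB_OUTPUT</code> environment variable. The sketch below shows one way the script's core could look. It assumes drift is measured with the Population Stability Index (PSI) over prediction confidence scores in [0, 1]. The function names, the 10 equal-width bins, and the 0.2 alert threshold are illustrative assumptions (0.2 is a common rule of thumb, not a standard). The sketch does not compute <code>map_delta</code>, because that requires labeled production data.</p>

```python
import os

import numpy as np


def population_stability_index(reference, production, bins=10, eps=1e-6):
    """PSI between two 1-D samples of confidence scores in [0, 1]."""
    edges = np.linspace(0.0, 1.0, bins + 1)
    ref_frac = np.histogram(reference, bins=edges)[0] / len(reference)
    prod_frac = np.histogram(production, bins=edges)[0] / len(production)
    # Clip empty bins so the log ratio stays finite
    ref_frac = np.clip(ref_frac, eps, None)
    prod_frac = np.clip(prod_frac, eps, None)
    return float(np.sum((prod_frac - ref_frac) * np.log(prod_frac / ref_frac)))


def write_step_outputs(outputs):
    """Append key=value pairs to $GITHUB_OUTPUT so later steps can read them."""
    path = os.environ.get("GITHUB_OUTPUT")
    if path is None:  # running locally, outside GitHub Actions
        return
    with open(path, "a", encoding="utf-8") as fh:
        for key, value in outputs.items():
            fh.write(f"{key}={value}\n")


def check_drift(reference, production, threshold=0.2):
    """Compute PSI, publish it as step outputs, and return (drifted, psi)."""
    psi = population_stability_index(reference, production)
    drifted = psi > threshold
    write_step_outputs({"drift_detected": str(drifted).lower(), "psi": f"{psi:.4f}"})
    return drifted, psi
```

<p>PSI on confidence scores is cheap and needs no labels, which suits a daily job. It signals that the model is seeing different data, not that accuracy has dropped. Treat an alert as a trigger for labeling and evaluation rather than as proof of degradation.</p>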
</section>
<section id="log-what-matters-in-serving" class="level3">
<h3 class="anchored" data-anchor-id="log-what-matters-in-serving" id="log-what-matters-in-serving">Log What Matters in Serving</h3>
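<p>In your inference service, emit the metrics your monitoring stack (Prometheus, CloudWatch, or similar) needs to spot degradation. Also keep the MLFlow run ID of the loaded model so that an alert can be traced back to the training run that produced it:</p>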
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>src/serving/monitored_model.py</strong></pre>
</div>
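<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb18" data-filename="src/serving/monitored_model.py"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb18-1"><a href="#cb18-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb18-2"><a href="#cb18-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-3"><a href="#cb18-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> mlflow</span>
<span id="cb18-4"><a href="#cb18-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> mlflow.tracking <span class="im">import</span> MlflowClient</span>
<span id="cb18-5"><a href="#cb18-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-6"><a href="#cb18-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-7"><a href="#cb18-7" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> emit_metric(name: <span class="bu">str</span>, value: <span class="bu">float</span>) <span class="op">-&gt;</span> <span class="va">None</span>:</span>
<span id="cb18-8"><a href="#cb18-8" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Hypothetical placeholder: swap in your metrics client (Prometheus, CloudWatch, StatsD, etc.)."""</span></span>
<span id="cb18-9"><a href="#cb18-9" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"</span><span class="sc">{</span>name<span class="sc">}</span><span class="ss">=</span><span class="sc">{</span>value<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb18-10"><a href="#cb18-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-11"><a href="#cb18-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-12"><a href="#cb18-12" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MonitoredCVModel:</span>
<span id="cb18-13"><a href="#cb18-13" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, model_name<span class="op">=</span><span class="st">"cv-model"</span>, stage<span class="op">=</span><span class="st">"Production"</span>):</span>
<span id="cb18-14"><a href="#cb18-14" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Resolve the stage to one concrete version so the loaded weights and</span></span>
<span id="cb18-15"><a href="#cb18-15" aria-hidden="true" tabindex="-1"></a>        <span class="co"># the recorded run_id cannot diverge if the stage moves between calls.</span></span>
<span id="cb18-16"><a href="#cb18-16" aria-hidden="true" tabindex="-1"></a>        version <span class="op">=</span> MlflowClient().get_latest_versions(model_name, [stage])[<span class="dv">0</span>]</span>
<span id="cb18-17"><a href="#cb18-17" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> mlflow.pyfunc.load_model(<span class="ss">f"models:/</span><span class="sc">{</span>model_name<span class="sc">}</span><span class="ss">/</span><span class="sc">{</span>version<span class="sc">.</span>version<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb18-18"><a href="#cb18-18" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.run_id <span class="op">=</span> version.run_id  <span class="co"># attach to alerts and dashboards for traceability</span></span>
<span id="cb18-19"><a href="#cb18-19" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-20"><a href="#cb18-20" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> predict(<span class="va">self</span>, image_batch):</span>
<span id="cb18-21"><a href="#cb18-21" aria-hidden="true" tabindex="-1"></a>        t0 <span class="op">=</span> time.perf_counter()</span>
<span id="cb18-22"><a href="#cb18-22" aria-hidden="true" tabindex="-1"></a>        result <span class="op">=</span> <span class="va">self</span>.model.predict(image_batch)</span>
<span id="cb18-23"><a href="#cb18-23" aria-hidden="true" tabindex="-1"></a>        latency_ms <span class="op">=</span> (time.perf_counter() <span class="op">-</span> t0) <span class="op">*</span> <span class="dv">1000</span></span>
<span id="cb18-24"><a href="#cb18-24" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-25"><a href="#cb18-25" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Emit to your metrics sink (Prometheus, CloudWatch, etc.)</span></span>
<span id="cb18-26"><a href="#cb18-26" aria-hidden="true" tabindex="-1"></a>        emit_metric(<span class="st">"inference.latency_ms"</span>, latency_ms)</span>
<span id="cb18-27"><a href="#cb18-27" aria-hidden="true" tabindex="-1"></a>        emit_metric(<span class="st">"inference.batch_size"</span>, <span class="bu">len</span>(image_batch))</span>
<span id="cb18-28"><a href="#cb18-28" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Assumes result is an (N, num_classes) array of class probabilities</span></span>
<span id="cb18-29"><a href="#cb18-29" aria-hidden="true" tabindex="-1"></a>        emit_metric(<span class="st">"inference.low_confidence_ratio"</span>,</span>
<span id="cb18-30"><a href="#cb18-30" aria-hidden="true" tabindex="-1"></a>                    <span class="bu">float</span>((result.<span class="bu">max</span>(axis<span class="op">=</span><span class="dv">1</span>) <span class="op">&lt;</span> <span class="fl">0.5</span>).mean()))</span>
<span id="cb18-31"><a href="#cb18-31" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-32"><a href="#cb18-32" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> result</span></code></pre></div></div>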
</div>
<hr>
</section>
</section>
<section id="security-secrets-management" class="level2">
<h2 class="anchored" data-anchor-id="security-secrets-management" id="security-secrets-management">Security &amp; Secrets Management</h2>
<section id="secrets-strategy" class="level3">
<h3 class="anchored" data-anchor-id="secrets-strategy" id="secrets-strategy">Secrets Strategy</h3>
<table class="table-striped table-hover caption-top table">
<caption>Secrets placement strategy</caption>
<colgroup>
<col style="width: 33%">
<col style="width: 33%">
<col style="width: 33%">
</colgroup>
<thead>
<tr class="header">
<th>Secret</th>
<th>Where</th>
<th>Notes</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><code>MLFLOW_TRACKING_URI</code></td>
<td>GitHub Environment secret</td>
<td>Scope to <code>training</code> and <code>deploy</code> environments only</td>
</tr>
<tr class="even">
<td><code>MLFLOW_TRACKING_TOKEN</code></td>
<td>GitHub Environment secret</td>
<td>Use short-lived tokens, rotate monthly</td>
</tr>
<tr class="odd">
<td><code>DVC_AWS_KEY / SECRET</code></td>
<td>GitHub Actions secret</td>
<td>Read-only IAM role — never write access from CI</td>
</tr>
<tr class="even">
<td><code>SLACK_WEBHOOK_URL</code></td>
<td>GitHub Actions secret</td>
<td>Use per-channel webhooks, not workspace tokens</td>
</tr>
<tr class="odd">
<td>Model serving credentials</td>
<td>External secret manager</td>
<td>Inject at deploy time, never in repo</td>
</tr>
</tbody>
</table>
</section>
<section id="prevent-secrets-from-leaking-into-mlflow" class="level3">
<h3 class="anchored" data-anchor-id="prevent-secrets-from-leaking-into-mlflow" id="prevent-secrets-from-leaking-into-mlflow">Prevent Secrets from Leaking into MLFlow</h3>
<p>It’s easy to accidentally log an entire config dict that contains credentials. Guard against it:</p>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>src/training/safe_logging.py</strong></pre>
</div>
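<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb19" data-filename="src/training/safe_logging.py"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb19-1"><a href="#cb19-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> mlflow</span>
<span id="cb19-2"><a href="#cb19-2" aria-hidden="true" tabindex="-1"></a>SENSITIVE_KEYS <span class="op">=</span> {<span class="st">"api_key"</span>, <span class="st">"password"</span>, <span class="st">"token"</span>, <span class="st">"secret"</span>, <span class="st">"aws_access_key"</span>}</span>
<span id="cb19-3"><a href="#cb19-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-4"><a href="#cb19-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> flatten_dict(d: <span class="bu">dict</span>, prefix: <span class="bu">str</span> <span class="op">=</span> <span class="st">""</span>) <span class="op">-&gt;</span> <span class="bu">dict</span>:</span>
<span id="cb19-5"><a href="#cb19-5" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Flatten nested dicts into dotted keys: {"db": {"password": x}} -&gt; {"db.password": x}."""</span></span>
<span id="cb19-6"><a href="#cb19-6" aria-hidden="true" tabindex="-1"></a>    out <span class="op">=</span> {}</span>
<span id="cb19-7"><a href="#cb19-7" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> k, v <span class="kw">in</span> d.items():</span>
<span id="cb19-8"><a href="#cb19-8" aria-hidden="true" tabindex="-1"></a>        key <span class="op">=</span> <span class="ss">f"</span><span class="sc">{</span>prefix<span class="sc">}</span><span class="ss">.</span><span class="sc">{</span>k<span class="sc">}</span><span class="ss">"</span> <span class="cf">if</span> prefix <span class="cf">else</span> <span class="bu">str</span>(k)</span>
<span id="cb19-9"><a href="#cb19-9" aria-hidden="true" tabindex="-1"></a>        out.update(flatten_dict(v, key) <span class="cf">if</span> <span class="bu">isinstance</span>(v, <span class="bu">dict</span>) <span class="cf">else</span> {key: v})</span>
<span id="cb19-10"><a href="#cb19-10" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> out</span>
<span id="cb19-11"><a href="#cb19-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-12"><a href="#cb19-12" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> safe_log_params(config: <span class="bu">dict</span>):</span>
<span id="cb19-13"><a href="#cb19-13" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Log params, redacting any key whose name contains a sensitive substring."""</span></span>
<span id="cb19-14"><a href="#cb19-14" aria-hidden="true" tabindex="-1"></a>    safe <span class="op">=</span> {</span>
<span id="cb19-15"><a href="#cb19-15" aria-hidden="true" tabindex="-1"></a>        k: <span class="st">"[REDACTED]"</span> <span class="cf">if</span> <span class="bu">any</span>(s <span class="kw">in</span> k.lower() <span class="cf">for</span> s <span class="kw">in</span> SENSITIVE_KEYS) <span class="cf">else</span> v</span>
<span id="cb19-16"><a href="#cb19-16" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> k, v <span class="kw">in</span> flatten_dict(config).items()</span>
<span id="cb19-17"><a href="#cb19-17" aria-hidden="true" tabindex="-1"></a>    }</span>
<span id="cb19-18"><a href="#cb19-18" aria-hidden="true" tabindex="-1"></a>    mlflow.log_params(safe)</span></code></pre></div></div>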
</div>
</section>
<section id="permissions-hardening-in-workflows" class="level3">
<h3 class="anchored" data-anchor-id="permissions-hardening-in-workflows" id="permissions-hardening-in-workflows">Permissions Hardening in Workflows</h3>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>Applies to every workflow file</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb20" data-filename="Applies to every workflow file"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb20-1"><a href="#cb20-1" aria-hidden="true" tabindex="-1"></a><span class="fu">permissions</span><span class="kw">:</span></span>
<span id="cb20-2"><a href="#cb20-2" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">contents</span><span class="kw">:</span><span class="at"> read</span><span class="co">            # Never write unless you explicitly need it</span></span>
<span id="cb20-3"><a href="#cb20-3" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">id-token</span><span class="kw">:</span><span class="at"> write</span><span class="co">           # Only if using OIDC for cloud auth</span></span>
<span id="cb20-4"><a href="#cb20-4" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">actions</span><span class="kw">:</span><span class="at"> read</span></span></code></pre></div></div>
</div>
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Important</span>Least-Privilege Default
</div>
</div>
<div class="callout-body-container callout-body">
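<p>Apply <code>permissions</code> at the workflow level as the default, then override per-job only where escalation is genuinely needed. If the block is omitted, <code>GITHUB_TOKEN</code> receives the repository's default permissions. In many organizations, particularly older ones, those defaults grant read/write access across most scopes. Once a <code>permissions</code> block is present, any scope it does not list is set to <code>none</code>.</p>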
</div>
</div>
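<p>To make the least-privilege default enforceable, a CI step can lint the workflow files themselves. The sketch below shows one hypothetical way to do this, assuming PyYAML is available in the CI image. It fails on a missing top-level <code>permissions</code> block or a blanket <code>write-all</code>. It only reports other write scopes (such as <code>id-token: write</code> for OIDC, or a per-job escalation) for human review, because those may be legitimate.</p>

```python
import sys
from pathlib import Path


def has_write(perms) -> bool:
    """True if a permissions value grants any write scope."""
    return perms == "write-all" or (isinstance(perms, dict) and "write" in perms.values())


def permission_findings(workflow: dict, name: str = "workflow"):
    """Return (errors, notes) for one parsed workflow file."""
    errors, notes = [], []
    top = workflow.get("permissions")
    if top is None:
        errors.append(f"{name}: no top-level 'permissions' block")
    elif top == "write-all":
        errors.append(f"{name}: top-level 'write-all' defeats least privilege")
    elif has_write(top):
        notes.append(f"{name}: workflow-level write scopes (review)")
    for job_id, job in (workflow.get("jobs") or {}).items():
        if has_write(job.get("permissions")):
            notes.append(f"{name}/{job_id}: job-level write escalation (review)")
    return errors, notes


if __name__ == "__main__":
    import yaml  # PyYAML; assumed to be installed

    all_errors = []
    for path in sorted(Path(".github/workflows").glob("*.y*ml")):
        errors, notes = permission_findings(yaml.safe_load(path.read_text()), path.name)
        all_errors += errors
        for line in errors + notes:
            print(line)
    sys.exit(1 if all_errors else 0)
```

<p>Failing only on the missing or blanket cases keeps the check strict about the default while leaving deliberate, reviewed escalations unblocked.</p>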
<hr>
</section>
</section>
<section id="rollback-strategy" class="level2">
<h2 class="anchored" data-anchor-id="rollback-strategy" id="rollback-strategy">Rollback Strategy</h2>
<p>Production CV models need a documented, tested rollback path — not a post-incident improvisation.</p>
<section id="automated-rollback-trigger" class="level3">
<h3 class="anchored" data-anchor-id="automated-rollback-trigger" id="automated-rollback-trigger">Automated Rollback Trigger</h3>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>.github/workflows/rollback.yml</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb21" data-filename=".github/workflows/rollback.yml"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb21-1"><a href="#cb21-1" aria-hidden="true" tabindex="-1"></a><span class="fu">name</span><span class="kw">:</span><span class="at"> Rollback Production Model</span></span>
<span id="cb21-2"><a href="#cb21-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-3"><a href="#cb21-3" aria-hidden="true" tabindex="-1"></a><span class="fu">on</span><span class="kw">:</span></span>
<span id="cb21-4"><a href="#cb21-4" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">workflow_dispatch</span><span class="kw">:</span></span>
<span id="cb21-5"><a href="#cb21-5" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">inputs</span><span class="kw">:</span></span>
<span id="cb21-6"><a href="#cb21-6" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">reason</span><span class="kw">:</span></span>
<span id="cb21-7"><a href="#cb21-7" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">description</span><span class="kw">:</span><span class="at"> </span><span class="st">"Reason for rollback"</span></span>
<span id="cb21-8"><a href="#cb21-8" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">required</span><span class="kw">:</span><span class="at"> </span><span class="ch">true</span></span>
<span id="cb21-9"><a href="#cb21-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-10"><a href="#cb21-10" aria-hidden="true" tabindex="-1"></a><span class="fu">jobs</span><span class="kw">:</span></span>
<span id="cb21-11"><a href="#cb21-11" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">rollback</span><span class="kw">:</span></span>
<span id="cb21-12"><a href="#cb21-12" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">runs-on</span><span class="kw">:</span><span class="at"> ubuntu-latest</span></span>
<span id="cb21-13"><a href="#cb21-13" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">environment</span><span class="kw">:</span><span class="at"> production-rollback</span><span class="co">    # Requires approval from on-call engineer</span></span>
<span id="cb21-14"><a href="#cb21-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-15"><a href="#cb21-15" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">steps</span><span class="kw">:</span></span>
<span id="cb21-16"><a href="#cb21-16" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="fu">uses</span><span class="kw">:</span><span class="at"> actions/checkout@v4</span></span>
<span id="cb21-17"><a href="#cb21-17" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="fu">uses</span><span class="kw">:</span><span class="at"> ./.github/actions/setup-mlflow</span></span>
<span id="cb21-18"><a href="#cb21-18" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">with</span><span class="kw">:</span></span>
<span id="cb21-19"><a href="#cb21-19" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="fu">mlflow-tracking-uri</span><span class="kw">:</span><span class="at"> ${{ secrets.MLFLOW_TRACKING_URI }}</span></span>
<span id="cb21-20"><a href="#cb21-20" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="fu">mlflow-s3-bucket</span><span class="kw">:</span><span class="at">    ${{ secrets.MLFLOW_S3_BUCKET }}</span></span>
<span id="cb21-21"><a href="#cb21-21" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-22"><a href="#cb21-22" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> Roll back to last Archived model</span></span>
<span id="cb21-23"><a href="#cb21-23" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">env</span><span class="kw">:</span></span>
<span id="cb21-24"><a href="#cb21-24" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="fu">ROLLBACK_REASON</span><span class="kw">:</span><span class="at"> ${{ inputs.reason }}</span><span class="co">   # Passed via env so free text is never spliced into the script</span></span>
<span id="cb21-25"><a href="#cb21-25" aria-hidden="true" tabindex="-1"></a><span class="fu">        run</span><span class="kw">: </span><span class="ch">|</span></span>
<span id="cb21-26"><a href="#cb21-26" aria-hidden="true" tabindex="-1"></a>          python scripts/rollback_model.py \</span>
<span id="cb21-27"><a href="#cb21-27" aria-hidden="true" tabindex="-1"></a>            --model-name   cv-model \</span>
<span id="cb21-28"><a href="#cb21-28" aria-hidden="true" tabindex="-1"></a>            --reason       "$ROLLBACK_REASON" \</span>
<span id="cb21-29"><a href="#cb21-29" aria-hidden="true" tabindex="-1"></a>            --initiated-by "$GITHUB_ACTOR"</span>
<span id="cb21-30"><a href="#cb21-30" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-31"><a href="#cb21-31" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> Notify team</span></span>
<span id="cb21-32"><a href="#cb21-32" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">uses</span><span class="kw">:</span><span class="at"> slackapi/slack-github-action@v1</span></span>
<span id="cb21-33"><a href="#cb21-33" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">with</span><span class="kw">:</span></span>
<span id="cb21-34"><a href="#cb21-34" aria-hidden="true" tabindex="-1"></a><span class="fu">          payload</span><span class="kw">: </span><span class="ch">|</span></span>
<span id="cb21-35"><a href="#cb21-35" aria-hidden="true" tabindex="-1"></a>            {"text": "🔄 *Production rollback executed* by ${{ github.actor }}. Reason: ${{ inputs.reason }}"}</span>
<span id="cb21-36"><a href="#cb21-36" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">env</span><span class="kw">:</span></span>
<span id="cb21-37"><a href="#cb21-37" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="fu">SLACK_WEBHOOK_URL</span><span class="kw">:</span><span class="at"> ${{ secrets.SLACK_WEBHOOK_URL }}</span></span></code></pre></div></div>
</div>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>scripts/rollback_model.py</strong></pre>
</div>
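<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb22" data-filename="scripts/rollback_model.py"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb22-1"><a href="#cb22-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> argparse</span>
<span id="cb22-2"><a href="#cb22-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb22-3"><a href="#cb22-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> mlflow.tracking <span class="im">import</span> MlflowClient</span>
<span id="cb22-4"><a href="#cb22-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb22-5"><a href="#cb22-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb22-6"><a href="#cb22-6" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> rollback(model_name: <span class="bu">str</span>, reason: <span class="bu">str</span>, initiated_by: <span class="bu">str</span>):</span>
<span id="cb22-7"><a href="#cb22-7" aria-hidden="true" tabindex="-1"></a>    client <span class="op">=</span> MlflowClient()</span>
<span id="cb22-8"><a href="#cb22-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb22-9"><a href="#cb22-9" aria-hidden="true" tabindex="-1"></a>    <span class="co"># get_latest_versions() returns at most one version per stage, so search all</span></span>
<span id="cb22-10"><a href="#cb22-10" aria-hidden="true" tabindex="-1"></a>    <span class="co"># versions and skip any that were themselves demoted by an earlier rollback.</span></span>
<span id="cb22-11"><a href="#cb22-11" aria-hidden="true" tabindex="-1"></a>    archived <span class="op">=</span> [</span>
<span id="cb22-12"><a href="#cb22-12" aria-hidden="true" tabindex="-1"></a>        client.get_model_version(model_name, v.version)  <span class="co"># re-fetch to get tags</span></span>
<span id="cb22-13"><a href="#cb22-13" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> v <span class="kw">in</span> client.search_model_versions(<span class="ss">f"name='</span><span class="sc">{</span>model_name<span class="sc">}</span><span class="ss">'"</span>)</span>
<span id="cb22-14"><a href="#cb22-14" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> v.current_stage <span class="op">==</span> <span class="st">"Archived"</span></span>
<span id="cb22-15"><a href="#cb22-15" aria-hidden="true" tabindex="-1"></a>    ]</span>
<span id="cb22-16"><a href="#cb22-16" aria-hidden="true" tabindex="-1"></a>    candidates <span class="op">=</span> [v <span class="cf">for</span> v <span class="kw">in</span> archived <span class="cf">if</span> <span class="st">"rolled_back_reason"</span> <span class="kw">not</span> <span class="kw">in</span> v.tags]</span>
<span id="cb22-17"><a href="#cb22-17" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> <span class="kw">not</span> candidates:</span>
<span id="cb22-18"><a href="#cb22-18" aria-hidden="true" tabindex="-1"></a>        <span class="cf">raise</span> <span class="pp">ValueError</span>(<span class="st">"No Archived version to roll back to"</span>)</span>
<span id="cb22-19"><a href="#cb22-19" aria-hidden="true" tabindex="-1"></a>    rollback_version <span class="op">=</span> <span class="bu">max</span>(candidates, key<span class="op">=</span><span class="kw">lambda</span> v: <span class="bu">int</span>(v.version))</span>
<span id="cb22-20"><a href="#cb22-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb22-21"><a href="#cb22-21" aria-hidden="true" tabindex="-1"></a>    current_prod <span class="op">=</span> client.get_latest_versions(model_name, stages<span class="op">=</span>[<span class="st">"Production"</span>])</span>
<span id="cb22-22"><a href="#cb22-22" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb22-23"><a href="#cb22-23" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Promote the rollback target and archive the current Production version in one</span></span>
<span id="cb22-24"><a href="#cb22-24" aria-hidden="true" tabindex="-1"></a>    <span class="co"># call, instead of demoting first and leaving a window with no Production model.</span></span>
<span id="cb22-25"><a href="#cb22-25" aria-hidden="true" tabindex="-1"></a>    client.transition_model_version_stage(</span>
<span id="cb22-26"><a href="#cb22-26" aria-hidden="true" tabindex="-1"></a>        model_name, rollback_version.version, <span class="st">"Production"</span>,</span>
<span id="cb22-27"><a href="#cb22-27" aria-hidden="true" tabindex="-1"></a>        archive_existing_versions<span class="op">=</span><span class="va">True</span>,</span>
<span id="cb22-28"><a href="#cb22-28" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb22-29"><a href="#cb22-29" aria-hidden="true" tabindex="-1"></a>    client.set_model_version_tag(model_name, rollback_version.version, <span class="st">"rollback_by"</span>, initiated_by)</span>
<span id="cb22-30"><a href="#cb22-30" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> v <span class="kw">in</span> current_prod:</span>
<span id="cb22-31"><a href="#cb22-31" aria-hidden="true" tabindex="-1"></a>        client.set_model_version_tag(model_name, v.version, <span class="st">"rolled_back_reason"</span>, reason)</span>
<span id="cb22-32"><a href="#cb22-32" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Rolled back to version </span><span class="sc">{</span>rollback_version<span class="sc">.</span>version<span class="sc">}</span><span class="ss"> ✓"</span>)</span>
<span id="cb22-33"><a href="#cb22-33" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb22-34"><a href="#cb22-34" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb22-35"><a href="#cb22-35" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb22-36"><a href="#cb22-36" aria-hidden="true" tabindex="-1"></a>    parser <span class="op">=</span> argparse.ArgumentParser(description<span class="op">=</span><span class="st">"Roll back the Production model"</span>)</span>
<span id="cb22-37"><a href="#cb22-37" aria-hidden="true" tabindex="-1"></a>    parser.add_argument(<span class="st">"--model-name"</span>, required<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb22-38"><a href="#cb22-38" aria-hidden="true" tabindex="-1"></a>    parser.add_argument(<span class="st">"--reason"</span>, required<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb22-39"><a href="#cb22-39" aria-hidden="true" tabindex="-1"></a>    parser.add_argument(<span class="st">"--initiated-by"</span>, required<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb22-40"><a href="#cb22-40" aria-hidden="true" tabindex="-1"></a>    args <span class="op">=</span> parser.parse_args()</span>
<span id="cb22-41"><a href="#cb22-41" aria-hidden="true" tabindex="-1"></a>    rollback(args.model_name, args.reason, args.initiated_by)</span></code></pre></div></div>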
</div>
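<p>A rollback path only counts as tested if its version-selection rule can be exercised without a live registry. One way to do that, assuming you factor the selection out of <code>rollback()</code>, is a pure function over registry-like records. In the sketch below, <code>pick_rollback_version</code> and <code>fake_version</code> are hypothetical. The <code>SimpleNamespace</code> stand-ins mirror the <code>version</code>, <code>current_stage</code>, and <code>tags</code> attributes of MLflow's <code>ModelVersion</code>. The rule shown picks the highest-numbered Archived version that does not carry the <code>rolled_back_reason</code> tag, which the rollback script sets on versions it demotes. It is one reasonable policy, not the only one.</p>

```python
from types import SimpleNamespace


def pick_rollback_version(versions):
    """Highest-numbered Archived version not tagged as rolled back; None if none qualify."""
    candidates = [
        v for v in versions
        if v.current_stage == "Archived" and "rolled_back_reason" not in v.tags
    ]
    # Versions are strings in the registry, so compare them numerically
    return max(candidates, key=lambda v: int(v.version), default=None)


def fake_version(version, stage, **tags):
    """Minimal stand-in for an MLflow ModelVersion in unit tests."""
    return SimpleNamespace(version=str(version), current_stage=stage, tags=tags)


if __name__ == "__main__":
    registry = [
        fake_version(3, "Archived"),
        fake_version(4, "Archived", rolled_back_reason="mAP regression"),
        fake_version(5, "Production"),
    ]
    # Version 4 was itself rolled back, so version 3 is the safe target.
    print(pick_rollback_version(registry).version)  # prints 3
```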
<hr>
</section>
</section>
<section id="anti-patterns-to-avoid" class="level2">
<h2 class="anchored" data-anchor-id="anti-patterns-to-avoid" id="anti-patterns-to-avoid">Anti-Patterns to Avoid</h2>
<p>These are the most common mistakes teams make when first building CV CI/CD pipelines.</p>
<div class="callout callout-style-default callout-caution callout-titled">
<div class="callout-header d-flex align-content-center collapsed" data-bs-toggle="collapse" data-bs-target=".callout-5-contents" aria-controls="callout-5" aria-expanded="false" aria-label="Toggle callout">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Caution</span>Training inside a GitHub Actions runner without a self-hosted GPU
</div>
<div class="callout-btn-toggle d-inline-block border-0 py-1 ps-1 pe-0 float-end"><i class="callout-toggle"></i></div>
</div>
<div id="callout-5" class="callout-5-contents callout-collapse collapse">
<div class="callout-body-container callout-body">
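<p>Standard GitHub-hosted runners have no GPU, and each job is capped at 6 hours of execution time. Training a real CV model on them will either time out or crawl along on CPU. GPU-equipped larger runners exist on paid plans, but they are billed at a premium and are still subject to the same job time limit. Route training to self-hosted GPU runners or cloud job services (e.g., AWS Batch, GCP Vertex AI).</p>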
</div>
</div>
</div>
<div class="callout callout-style-default callout-caution callout-titled">
<div class="callout-header d-flex align-content-center collapsed" data-bs-toggle="collapse" data-bs-target=".callout-6-contents" aria-controls="callout-6" aria-expanded="false" aria-label="Toggle callout">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Caution</span>Logging model weights without a signature
</div>
<div class="callout-btn-toggle d-inline-block border-0 py-1 ps-1 pe-0 float-end"><i class="callout-toggle"></i></div>
</div>
<div id="callout-6" class="callout-6-contents callout-collapse collapse">
<div class="callout-body-container callout-body">
<p>A model in the registry with no input/output signature is a liability. Serving loses automatic schema validation of incoming requests, and without a declared contract there is nothing reliable to write automated inference-time assertions against.</p>
</div>
</div>
</div>
<div class="callout callout-style-default callout-caution callout-titled">
<div class="callout-header d-flex align-content-center collapsed" data-bs-toggle="collapse" data-bs-target=".callout-7-contents" aria-controls="callout-7" aria-expanded="false" aria-label="Toggle callout">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Caution</span>Using <code>latest</code> as a data version tag in training
</div>
<div class="callout-btn-toggle d-inline-block border-0 py-1 ps-1 pe-0 float-end"><i class="callout-toggle"></i></div>
</div>
<div id="callout-7" class="callout-7-contents callout-collapse collapse">
<div class="callout-body-container callout-body">
<p><code>latest</code> is a moving target. Tag your DVC data versions with explicit identifiers and commit hashes so any run can be reproduced months later.</p>
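<p>A cheap way to enforce this is to validate the data revision before training starts. The guard below is a hypothetical sketch. It assumes that data versions reach the training script as git revisions, and that your team's convention is either a full commit SHA or a tag named like <code>data-v12</code> that is never moved. Adjust both patterns to your own naming scheme.</p>

```python
import re

# Hypothetical convention: a pinned data version is either a full 40-character
# git commit SHA or a release tag such as "data-v12" that the team never moves.
_SHA_RE = re.compile(r"[0-9a-f]{40}")
_TAG_RE = re.compile(r"data-v\d+")
_MOVING_REFS = {"latest", "head", "main", "master"}


def require_pinned_data_rev(rev: str) -> str:
    """Return rev unchanged if it is pinned; raise ValueError for moving references."""
    if rev.lower() in _MOVING_REFS:
        raise ValueError(f"{rev!r} is a moving reference; pin a commit SHA or data tag")
    if not (_SHA_RE.fullmatch(rev) or _TAG_RE.fullmatch(rev)):
        raise ValueError(f"{rev!r} does not look like a pinned data version")
    return rev
```

<p>Call it at the top of the training entry point, before the revision is used for any checkout or <code>dvc.api</code> read. Then log the validated value with the run so the exact data version is recorded.</p>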
</div>
</div>
</div>
<div class="callout callout-style-default callout-caution callout-titled">
<div class="callout-header d-flex align-content-center collapsed" data-bs-toggle="collapse" data-bs-target=".callout-8-contents" aria-controls="callout-8" aria-expanded="false" aria-label="Toggle callout">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Caution</span>Skipping per-class metrics in quality gates
</div>
<div class="callout-btn-toggle d-inline-block border-0 py-1 ps-1 pe-0 float-end"><i class="callout-toggle"></i></div>
</div>
<div id="callout-8" class="callout-8-contents callout-collapse collapse">
<div class="callout-body-container callout-body">
<p>Aggregate mAP can improve while a low-frequency class (e.g., a rare defect type) collapses. Always gate on per-class metrics for any safety- or business-critical class.</p>
</div>
</div>
</div>
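<p>A per-class gate can be a small pure function that compares each class's metric with its own floor. In this sketch the dictionary shapes, the class names, and the default floor are hypothetical; the per-class values would come from your evaluation step. A gated class that is missing from the results, or whose value is NaN, counts as a failure.</p>

```python
from typing import Dict, List


def failing_classes(
    per_class_metric: Dict[str, float],
    class_floors: Dict[str, float],
    default_floor: float = 0.0,
) -> List[str]:
    """Return the classes whose metric is below their floor (or missing), sorted by name.

    class_floors names the classes with explicit floors; every other class uses default_floor.
    """
    failures = []
    for name in sorted(set(per_class_metric) | set(class_floors)):
        floor = class_floors.get(name, default_floor)
        value = per_class_metric.get(name)
        # "not value >= floor" also rejects NaN, which compares False with everything.
        if value is None or not value >= floor:
            failures.append(name)
    return failures


# Example: aggregate numbers can look healthy while a rare class collapses.
per_class_map = {"car": 0.91, "person": 0.88, "rare_defect": 0.12}
print(failing_classes(per_class_map, {"rare_defect": 0.40}))
```

<p>A CI step would exit non-zero whenever the returned list is non-empty, independently of whether the aggregate mAP improved.</p>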
<div class="callout callout-style-default callout-caution callout-titled">
<div class="callout-header d-flex align-content-center collapsed" data-bs-toggle="collapse" data-bs-target=".callout-9-contents" aria-controls="callout-9" aria-expanded="false" aria-label="Toggle callout">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Caution</span>Hardcoding metric thresholds in workflow YAML
</div>
<div class="callout-btn-toggle d-inline-block border-0 py-1 ps-1 pe-0 float-end"><i class="callout-toggle"></i></div>
</div>
<div id="callout-9" class="callout-9-contents callout-collapse collapse">
<div class="callout-body-container callout-body">
<p>Thresholds embedded in workflow YAML can only change through a pipeline edit, clutter workflow diffs, and are hard to audit over time. Keep them in versioned config files that the quality gate scripts load at runtime.</p>
</div>
</div>
</div>
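<p>A minimal sketch of such a gate script follows. It assumes a JSON config (the path <code>configs/quality_gates.json</code> and its <code>min_metrics</code> key are hypothetical; JSON is used here only to stay within the standard library) committed next to the code, so every threshold change is a reviewed, versioned diff.</p>

```python
import json
from pathlib import Path
from typing import Dict


def load_thresholds(path: str) -> Dict[str, float]:
    """Read {"min_metrics": {metric_name: minimum}} from a versioned JSON config."""
    config = json.loads(Path(path).read_text(encoding="utf-8"))
    # Coerce to float at load time so a malformed value fails loudly before evaluation.
    return {name: float(value) for name, value in config["min_metrics"].items()}


def gate(metrics: Dict[str, float], thresholds: Dict[str, float]) -> Dict[str, str]:
    """Return {metric_name: reason} for every threshold that is not met."""
    problems = {}
    for name, minimum in thresholds.items():
        if name not in metrics:
            problems[name] = "missing"
        elif not metrics[name] >= minimum:
            problems[name] = f"{metrics[name]:.4f} below {minimum:.4f}"
    return problems
```

<p>The workflow step then only runs the script (for example <code>load_thresholds("configs/quality_gates.json")</code> followed by <code>gate(...)</code>) and fails if the returned dictionary is non-empty; the thresholds themselves never appear in the YAML.</p>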
<div class="callout callout-style-default callout-caution callout-titled">
<div class="callout-header d-flex align-content-center collapsed" data-bs-toggle="collapse" data-bs-target=".callout-10-contents" aria-controls="callout-10" aria-expanded="false" aria-label="Toggle callout">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Caution</span>Not testing the rollback path
</div>
<div class="callout-btn-toggle d-inline-block border-0 py-1 ps-1 pe-0 float-end"><i class="callout-toggle"></i></div>
</div>
<div id="callout-10" class="callout-10-contents callout-collapse collapse">
<div class="callout-body-container callout-body">
<p>Rollback procedures that have never been executed will fail under pressure. Run a rollback drill in staging at least once per quarter.</p>
</div>
</div>
</div>
<div class="callout callout-style-default callout-caution callout-titled">
<div class="callout-header d-flex align-content-center collapsed" data-bs-toggle="collapse" data-bs-target=".callout-11-contents" aria-controls="callout-11" aria-expanded="false" aria-label="Toggle callout">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Caution</span>Logging to MLFlow from matrix jobs without run naming
</div>
<div class="callout-btn-toggle d-inline-block border-0 py-1 ps-1 pe-0 float-end"><i class="callout-toggle"></i></div>
</div>
<div id="callout-11" class="callout-11-contents callout-collapse collapse">
<div class="callout-body-container callout-body">
<p>Parallel matrix jobs that all call <code>mlflow.start_run()</code> without unique <code>run_name</code> values fill the experiment with indistinguishable runs. Embed <code>github.sha</code>, the relevant <code>matrix.*</code> values, and a timestamp in every run name.</p>
</div>
</div>
</div>
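<p>One way to build such a name is a small helper that reads the job environment. <code>GITHUB_SHA</code> is set by the runner, but matrix values are not exported automatically, so this sketch assumes the workflow passes them through <code>env:</code> as <code>MATRIX_&lt;KEY&gt;</code> variables (a hypothetical naming convention). The result would be passed as <code>mlflow.start_run(run_name=...)</code>.</p>

```python
import os
from datetime import datetime, timezone
from typing import Mapping, Optional, Sequence


def ci_run_name(
    matrix_keys: Sequence[str],
    env: Optional[Mapping[str, str]] = None,
    now: Optional[datetime] = None,
) -> str:
    """Build a run name such as 'a3f8c12-backbone=resnet50-lr=0.001-20260615T013826Z'."""
    env = os.environ if env is None else env
    now = datetime.now(timezone.utc) if now is None else now
    sha = env.get("GITHUB_SHA", "local")[:7]  # short commit hash, "local" outside CI
    # Hypothetical convention: each matrix value is exported as MATRIX_<KEY> in the workflow.
    parts = [f"{key}={env.get('MATRIX_' + key.upper(), 'na')}" for key in matrix_keys]
    stamp = now.strftime("%Y%m%dT%H%M%SZ")  # UTC timestamp keeps re-runs distinct
    return "-".join([sha, *parts, stamp])
```

<p>Putting the commit first and the timestamp last keeps names sortable by code version while still separating re-runs of the same matrix cell.</p>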
<hr>
</section>
<section id="reference-snippets-cheatsheet" class="level2">
<h2 class="anchored" data-anchor-id="reference-snippets-cheatsheet" id="reference-snippets-cheatsheet">Reference Snippets Cheatsheet</h2>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>MLFlow CLI Quick Reference</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb23" data-filename="MLFlow CLI Quick Reference"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb23-1"><a href="#cb23-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Start a local MLFlow server for development</span></span>
<span id="cb23-2"><a href="#cb23-2" aria-hidden="true" tabindex="-1"></a><span class="ex">mlflow</span> server <span class="dt">\</span></span>
<span id="cb23-3"><a href="#cb23-3" aria-hidden="true" tabindex="-1"></a>  <span class="at">--backend-store-uri</span> sqlite:///mlflow.db <span class="dt">\</span></span>
<span id="cb23-4"><a href="#cb23-4" aria-hidden="true" tabindex="-1"></a>  <span class="at">--default-artifact-root</span> ./mlruns <span class="dt">\</span></span>
<span id="cb23-5"><a href="#cb23-5" aria-hidden="true" tabindex="-1"></a>  <span class="at">--host</span> 0.0.0.0 <span class="at">--port</span> 5000</span>
<span id="cb23-6"><a href="#cb23-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb23-7"><a href="#cb23-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Launch a reproducible run via MLProject</span></span>
<span id="cb23-8"><a href="#cb23-8" aria-hidden="true" tabindex="-1"></a><span class="ex">mlflow</span> run . <span class="at">-P</span> config_path=configs/base.yaml <span class="at">-P</span> data_version=v1.3.0</span>
<span id="cb23-9"><a href="#cb23-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb23-10"><a href="#cb23-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Inspect runs from the CLI (side-by-side comparison is done in the UI)</span></span>
<span id="cb23-11"><a href="#cb23-11" aria-hidden="true" tabindex="-1"></a><span class="ex">mlflow</span> runs describe <span class="at">--run-id</span> <span class="op">&lt;</span>run_a<span class="op">&gt;</span></span>
<span id="cb23-12"><a href="#cb23-12" aria-hidden="true" tabindex="-1"></a><span class="ex">mlflow</span> runs describe <span class="at">--run-id</span> <span class="op">&lt;</span>run_b<span class="op">&gt;</span></span>
<span id="cb23-13"><a href="#cb23-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb23-14"><a href="#cb23-14" aria-hidden="true" tabindex="-1"></a><span class="co"># List Production model versions (no built-in CLI command; use the Python client)</span></span>
<span id="cb23-15"><a href="#cb23-15" aria-hidden="true" tabindex="-1"></a><span class="ex">python</span> <span class="at">-c</span> <span class="st">"from mlflow import MlflowClient; print(MlflowClient().get_latest_versions('cv-model', stages=['Production']))"</span></span>
<span id="cb23-16"><a href="#cb23-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb23-17"><a href="#cb23-17" aria-hidden="true" tabindex="-1"></a><span class="co"># Promote a model version to Production (stage API; deprecated in recent 2.x releases in favour of aliases)</span></span>
<span id="cb23-18"><a href="#cb23-18" aria-hidden="true" tabindex="-1"></a><span class="ex">python</span> <span class="at">-c</span> <span class="st">"from mlflow import MlflowClient; MlflowClient().transition_model_version_stage('cv-model', '12', 'Production')"</span></span>
<span id="cb23-19"><a href="#cb23-19" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb23-20"><a href="#cb23-20" aria-hidden="true" tabindex="-1"></a><span class="co"># Serve a model locally for testing, using the current Python environment</span></span>
<span id="cb23-21"><a href="#cb23-21" aria-hidden="true" tabindex="-1"></a><span class="ex">mlflow</span> models serve <span class="at">-m</span> <span class="st">"models:/cv-model/Staging"</span> <span class="at">-p</span> 8080 <span class="at">--env-manager</span> local</span></code></pre></div></div>
</div>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>Minimal GitHub Actions context in MLFlow tags</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb24" data-filename="Minimal GitHub Actions context in MLFlow tags"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb24-1"><a href="#cb24-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> os</span>
<span id="cb24-2"><a href="#cb24-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> mlflow</span>
<span id="cb24-3"><a href="#cb24-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb24-4"><a href="#cb24-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Defaults keep the snippet usable outside CI (e.g. on a laptop)</span></span>
<span id="cb24-5"><a href="#cb24-5" aria-hidden="true" tabindex="-1"></a>mlflow.set_tags({</span>
<span id="cb24-6"><a href="#cb24-6" aria-hidden="true" tabindex="-1"></a>    <span class="st">"ci.sha"</span>:        os.environ.get(<span class="st">"GITHUB_SHA"</span>, <span class="st">"local"</span>),</span>
<span id="cb24-7"><a href="#cb24-7" aria-hidden="true" tabindex="-1"></a>    <span class="st">"ci.run_id"</span>:     os.environ.get(<span class="st">"GITHUB_RUN_ID"</span>, <span class="st">"local"</span>),</span>
<span id="cb24-8"><a href="#cb24-8" aria-hidden="true" tabindex="-1"></a>    <span class="st">"ci.run_number"</span>: os.environ.get(<span class="st">"GITHUB_RUN_NUMBER"</span>, <span class="st">"0"</span>),</span>
<span id="cb24-9"><a href="#cb24-9" aria-hidden="true" tabindex="-1"></a>    <span class="st">"ci.actor"</span>:      os.environ.get(<span class="st">"GITHUB_ACTOR"</span>, <span class="st">"local"</span>),</span>
<span id="cb24-10"><a href="#cb24-10" aria-hidden="true" tabindex="-1"></a>    <span class="st">"ci.workflow"</span>:   os.environ.get(<span class="st">"GITHUB_WORKFLOW"</span>, <span class="st">"local"</span>),</span>
<span id="cb24-11"><a href="#cb24-11" aria-hidden="true" tabindex="-1"></a>    <span class="st">"ci.ref"</span>:        os.environ.get(<span class="st">"GITHUB_REF"</span>, <span class="st">"local"</span>),</span>
<span id="cb24-12"><a href="#cb24-12" aria-hidden="true" tabindex="-1"></a>})</span></code></pre></div></div>
</div>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>Self-hosted GPU runner label convention</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb25" data-filename="Self-hosted GPU runner label convention"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb25-1"><a href="#cb25-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Always pin GPU type for reproducible benchmarks</span></span>
<span id="cb25-2"><a href="#cb25-2" aria-hidden="true" tabindex="-1"></a><span class="fu">runs-on</span><span class="kw">:</span><span class="at"> </span><span class="kw">[</span><span class="at">self-hosted</span><span class="kw">,</span><span class="at"> linux</span><span class="kw">,</span><span class="at"> gpu</span><span class="kw">,</span><span class="at"> t4</span><span class="kw">]</span></span></code></pre></div></div>
</div>
<hr>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Version Note
</div>
</div>
<div class="callout-body-container callout-body">
<p>Covers MLFlow 2.x and GitHub Actions runner v2.x</p>
</div>
</div>


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[MLflow Best Practices for Computer Vision (Deep Learning)]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/deployment/mlflow-oe/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/deployment/mlflow-oe/</guid>
      <pubDate>Thu, 09 Apr 2026 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>mlops</category>
      <category>advanced</category>
      <content:encoded><![CDATA[






<section id="mlflow-best-practices-for-computer-vision-deep-learning" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/deployment/mlflow-oe/flow.png" class="img-fluid"></p>
<section id="sec-experiment-tracking" class="level2">
<h2 class="anchored" data-anchor-id="sec-experiment-tracking" id="sec-experiment-tracking">Experiment Tracking</h2>
<p>This guide shows how to run MLflow with operational discipline in production computer-vision projects, starting with how experiments and runs are tracked.</p>
<hr>
<div class="callout callout-style-simple callout-note">
<div class="callout-body d-flex">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-body-container">
<p><strong>Compatibility:</strong> MLflow ≥ 2.4 · PyTorch ≥ 2.0 · Python ≥ 3.10</p>
</div>
</div>
</div>
<hr>
<section id="structure-experiments-hierarchically" class="level3">
<h3 class="anchored" data-anchor-id="structure-experiments-hierarchically" id="structure-experiments-hierarchically">Structure Experiments Hierarchically</h3>
<p>Organise experiments to mirror your project structure. Avoid dumping all runs into a single experiment.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> mlflow</span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a><span class="co"># One experiment per model family or research objective. set_experiment() creates the</span></span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a><span class="co"># experiment if needed and makes it active, so each script activates only its own:</span></span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a>mlflow.set_experiment(<span class="st">"resnet-backbone-ablations"</span>)     <span class="co"># backbone ablation script</span></span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a><span class="co"># mlflow.set_experiment("yolov8-object-detection-v2")  # detection training script</span></span></code></pre></div></div>
</section>
<section id="use-nested-runs-for-multi-stage-pipelines" class="level3">
<h3 class="anchored" data-anchor-id="use-nested-runs-for-multi-stage-pipelines" id="use-nested-runs-for-multi-stage-pipelines">Use Nested Runs for Multi-Stage Pipelines</h3>
<p>CV pipelines typically consist of preprocessing → training → evaluation → post-processing. Model each stage as a child run.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="cf">with</span> mlflow.start_run(run_name<span class="op">=</span><span class="st">"full-pipeline"</span>) <span class="im">as</span> parent_run:</span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> mlflow.start_run(run_name<span class="op">=</span><span class="st">"data-augmentation"</span>, nested<span class="op">=</span><span class="va">True</span>):</span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a>        mlflow.log_params({<span class="st">"augment_strategy"</span>: <span class="st">"mosaic"</span>, <span class="st">"img_size"</span>: <span class="dv">640</span>})</span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> mlflow.start_run(run_name<span class="op">=</span><span class="st">"training"</span>, nested<span class="op">=</span><span class="va">True</span>):</span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a>        mlflow.log_params({<span class="st">"epochs"</span>: <span class="dv">100</span>, <span class="st">"optimizer"</span>: <span class="st">"AdamW"</span>})</span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> mlflow.start_run(run_name<span class="op">=</span><span class="st">"evaluation"</span>, nested<span class="op">=</span><span class="va">True</span>):</span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a>        mlflow.log_metrics({<span class="st">"mAP50"</span>: <span class="fl">0.87</span>, <span class="st">"mAP50-95"</span>: <span class="fl">0.63</span>})</span></code></pre></div></div>
</section>
<section id="tag-runs-consistently" class="level3">
<h3 class="anchored" data-anchor-id="tag-runs-consistently" id="tag-runs-consistently">Tag Runs Consistently</h3>
<p>Tags are queryable — use them as first-class metadata for filtering and governance.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> os</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a>mlflow.set_tags({</span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a>    <span class="st">"task"</span>: <span class="st">"object-detection"</span>,</span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a>    <span class="st">"backbone"</span>: <span class="st">"ResNet50"</span>,</span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a>    <span class="st">"dataset"</span>: <span class="st">"COCO-2017"</span>,</span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a>    <span class="st">"env"</span>: <span class="st">"production"</span>,</span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a>    <span class="st">"team"</span>: <span class="st">"cv-platform"</span>,</span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a>    <span class="st">"git_commit"</span>: os.getenv(<span class="st">"GIT_COMMIT_SHA"</span>, <span class="st">"unknown"</span>),</span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a>})</span></code></pre></div></div>
<p><strong>Recommended Tag Schema:</strong></p>
<div id="tbl-tag-schema" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-tbl figure">
<figcaption class="quarto-float-caption-top quarto-float-caption quarto-float-tbl" id="tbl-tag-schema-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Table&nbsp;1: Recommended MLflow tag schema for CV runs
</figcaption>
<div aria-describedby="tbl-tag-schema-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<table class="caption-top table">
<thead>
<tr class="header">
<th>Tag Key</th>
<th>Example Value</th>
<th>Purpose</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><code>task</code></td>
<td><code>segmentation</code></td>
<td>CV task type</td>
</tr>
<tr class="even">
<td><code>backbone</code></td>
<td><code>EfficientNetV2-L</code></td>
<td>Architecture family</td>
</tr>
<tr class="odd">
<td><code>dataset</code></td>
<td><code>COCO-2017</code></td>
<td>Dataset identifier</td>
</tr>
<tr class="even">
<td><code>env</code></td>
<td><code>staging</code> / <code>production</code></td>
<td>Deployment stage</td>
</tr>
<tr class="odd">
<td><code>git_commit</code></td>
<td><code>a3f8c12</code></td>
<td>Code reproducibility</td>
</tr>
<tr class="even">
<td><code>hardware</code></td>
<td><code>A100-80GB</code></td>
<td>Training hardware</td>
</tr>
</tbody>
</table>
</div>
</figure>
</div>
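<p>To keep runs from drifting away from this schema, a training script can validate its tags before calling <code>mlflow.set_tags</code>. The required keys below mirror Table 1; the allowed <code>env</code> values and the helper name are assumptions for illustration.</p>

```python
from typing import Dict, List

# Required keys taken from the recommended tag schema (Table 1).
REQUIRED_TAGS = ("task", "backbone", "dataset", "env", "git_commit", "hardware")
ALLOWED_ENVS = {"staging", "production"}  # assumption: only these two stages are used


def tag_problems(tags: Dict[str, str]) -> List[str]:
    """Return human-readable problems with a tag dict; an empty list means it conforms."""
    problems = [f"missing tag: {key}" for key in REQUIRED_TAGS if not tags.get(key)]
    env = tags.get("env")
    if env and env not in ALLOWED_ENVS:
        problems.append(f"unexpected env value: {env}")
    # The tagging example falls back to "unknown" when GIT_COMMIT_SHA is unset.
    if tags.get("git_commit") == "unknown":
        problems.append("git_commit is 'unknown'; the run cannot be traced to code")
    return problems
```

<p>Failing hard on a non-empty list is stricter than logging a warning, but it keeps tag-based queries and governance reports trustworthy.</p>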
<hr>
</section>
</section>
<section id="sec-model-logging" class="level2">
<h2 class="anchored" data-anchor-id="sec-model-logging" id="sec-model-logging">Model Logging and Registration</h2>
<section id="log-models-with-signatures-and-input-examples" class="level3">
<h3 class="anchored" data-anchor-id="log-models-with-signatures-and-input-examples" id="log-models-with-signatures-and-input-examples">Log Models with Signatures and Input Examples</h3>
<p>Always include a model signature and a representative input example. This is critical for serving CV models correctly — it prevents type/shape mismatches at inference time.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> mlflow.pytorch</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Random data with the production input shape/dtype: the signature records only shape and dtype</span></span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a>sample_input <span class="op">=</span> np.random.rand(<span class="dv">1</span>, <span class="dv">3</span>, <span class="dv">224</span>, <span class="dv">224</span>).astype(np.float32)</span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a>sample_output <span class="op">=</span> model.<span class="bu">eval</span>()(torch.from_numpy(sample_input)).detach().numpy()  <span class="co"># `model`: your trained nn.Module</span></span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a>signature <span class="op">=</span> mlflow.models.infer_signature(sample_input, sample_output)</span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a>mlflow.pytorch.log_model(</span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a>    pytorch_model<span class="op">=</span>model,</span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a>    artifact_path<span class="op">=</span><span class="st">"model"</span>,</span>
<span id="cb4-14"><a href="#cb4-14" aria-hidden="true" tabindex="-1"></a>    signature<span class="op">=</span>signature,</span>
<span id="cb4-15"><a href="#cb4-15" aria-hidden="true" tabindex="-1"></a>    input_example<span class="op">=</span>sample_input,</span>
<span id="cb4-16"><a href="#cb4-16" aria-hidden="true" tabindex="-1"></a>    registered_model_name<span class="op">=</span><span class="st">"cv-resnet50-classifier"</span>,</span>
<span id="cb4-17"><a href="#cb4-17" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
</section>
<section id="use-the-model-registry-with-stage-transitions" class="level3">
<h3 class="anchored" data-anchor-id="use-the-model-registry-with-stage-transitions" id="use-the-model-registry-with-stage-transitions">Use the Model Registry with Stage Transitions</h3>
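<p>The Model Registry tracks each version's lifecycle stage: <code>None → Staging → Production → Archived</code>. MLflow itself does not enforce this order (any stage can be set directly), so promotion gates have to come from your CI or review process. Recent MLflow 2.x releases deprecate stages in favour of model version aliases, but the stage API shown here still works in 2.x.</p>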
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> mlflow.tracking <span class="im">import</span> MlflowClient</span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a>client <span class="op">=</span> MlflowClient()</span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Transition a validated model to production</span></span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a>client.transition_model_version_stage(</span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"cv-resnet50-classifier"</span>,</span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a>    version<span class="op">=</span><span class="dv">3</span>,</span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a>    stage<span class="op">=</span><span class="st">"Production"</span>,</span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a>    archive_existing_versions<span class="op">=</span><span class="va">True</span>,  <span class="co"># Auto-archive old production version</span></span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
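<p>Because <code>transition_model_version_stage</code> accepts any target stage, an ordered promotion path has to live in your own tooling. A minimal sketch of such a check, run before calling the client, might look like this; the allowed-transition table is an assumed team policy, not MLflow behaviour.</p>

```python
# Assumed team policy: permitted target stages, keyed by the current stage.
ALLOWED_TRANSITIONS = {
    "None": {"Staging", "Archived"},
    "Staging": {"Production", "Archived"},
    "Production": {"Archived"},
    "Archived": set(),
}


def check_transition(current_stage: str, target_stage: str) -> None:
    """Raise ValueError if current_stage -> target_stage violates the policy."""
    allowed = ALLOWED_TRANSITIONS.get(current_stage)
    if allowed is None:
        raise ValueError(f"unknown stage: {current_stage!r}")
    if target_stage not in allowed:
        raise ValueError(
            f"{current_stage} -> {target_stage} is not allowed; "
            f"permitted targets: {sorted(allowed) or 'none'}"
        )
```

<p>In a promotion script this would run as <code>check_transition(version.current_stage, "Production")</code>, where <code>current_stage</code> is the attribute on the <code>ModelVersion</code> object the client returns. A rollback path (for example <code>Archived → Production</code>) would need its own explicit entry in the table.</p>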
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Warning
</div>
</div>
<div class="callout-body-container callout-body">
<p><strong>Always archive old production versions.</strong> Never leave two versions in <code>Production</code> simultaneously unless you are intentionally running A/B traffic splits.</p>
</div>
</div>
</section>
<section id="custom-pyfuncs-for-prepost-processing" class="level3">
<h3 class="anchored" data-anchor-id="custom-pyfuncs-for-prepost-processing" id="custom-pyfuncs-for-prepost-processing">Custom PyFuncs for Pre/Post-Processing</h3>
<p>Wrap inference-time preprocessing (resize, normalise) and postprocessing (NMS, softmax, decode boxes) into the model artifact itself using <code>mlflow.pyfunc</code>. This keeps the serving pipeline from drifting away from the one used in evaluation; training-time augmentation stays out of the wrapper.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> mlflow.pyfunc</span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> CVModelWrapper(mlflow.pyfunc.PythonModel):</span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> load_context(<span class="va">self</span>, context):</span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a>        <span class="im">import</span> torch</span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a>        <span class="co"># The artifact is a fully pickled nn.Module; PyTorch &gt;= 2.6 defaults to weights_only=True,</span></span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a>        <span class="co"># which refuses to load it. map_location="cpu" lets GPU-trained weights load on CPU hosts.</span></span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> torch.load(</span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a>            context.artifacts[<span class="st">"model_path"</span>], map_location<span class="op">=</span><span class="st">"cpu"</span>, weights_only<span class="op">=</span><span class="va">False</span></span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model.<span class="bu">eval</span>()</span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> predict(<span class="va">self</span>, context, model_input):</span>
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a>        <span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb6-15"><a href="#cb6-15" aria-hidden="true" tabindex="-1"></a>        <span class="im">import</span> torch</span>
<span id="cb6-16"><a href="#cb6-16" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Preprocess: uint8 NCHW batch -&gt; floats in [0, 1] -&gt; ImageNet per-channel normalisation</span></span>
<span id="cb6-17"><a href="#cb6-17" aria-hidden="true" tabindex="-1"></a>        tensor <span class="op">=</span> torch.as_tensor(np.asarray(model_input)).<span class="bu">float</span>() <span class="op">/</span> <span class="fl">255.0</span></span>
<span id="cb6-18"><a href="#cb6-18" aria-hidden="true" tabindex="-1"></a>        mean <span class="op">=</span> torch.tensor([<span class="fl">0.485</span>, <span class="fl">0.456</span>, <span class="fl">0.406</span>]).view(<span class="dv">1</span>, <span class="dv">3</span>, <span class="dv">1</span>, <span class="dv">1</span>)</span>
<span id="cb6-19"><a href="#cb6-19" aria-hidden="true" tabindex="-1"></a>        std <span class="op">=</span> torch.tensor([<span class="fl">0.229</span>, <span class="fl">0.224</span>, <span class="fl">0.225</span>]).view(<span class="dv">1</span>, <span class="dv">3</span>, <span class="dv">1</span>, <span class="dv">1</span>)</span>
<span id="cb6-20"><a href="#cb6-20" aria-hidden="true" tabindex="-1"></a>        tensor <span class="op">=</span> (tensor <span class="op">-</span> mean) <span class="op">/</span> std</span>
<span id="cb6-21"><a href="#cb6-21" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-22"><a href="#cb6-22" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Inference</span></span>
<span id="cb6-23"><a href="#cb6-23" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> torch.no_grad():</span>
<span id="cb6-24"><a href="#cb6-24" aria-hidden="true" tabindex="-1"></a>            logits <span class="op">=</span> <span class="va">self</span>.model(tensor)</span>
<span id="cb6-25"><a href="#cb6-25" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-26"><a href="#cb6-26" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Postprocess: class probabilities</span></span>
<span id="cb6-27"><a href="#cb6-27" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> logits.softmax(dim<span class="op">=-</span><span class="dv">1</span>).numpy()</span>
<span id="cb6-28"><a href="#cb6-28" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-29"><a href="#cb6-29" aria-hidden="true" tabindex="-1"></a>mlflow.pyfunc.log_model(</span>
<span id="cb6-30"><a href="#cb6-30" aria-hidden="true" tabindex="-1"></a>    artifact_path<span class="op">=</span><span class="st">"cv-model-wrapped"</span>,</span>
<span id="cb6-31"><a href="#cb6-31" aria-hidden="true" tabindex="-1"></a>    python_model<span class="op">=</span>CVModelWrapper(),</span>
<span id="cb6-32"><a href="#cb6-32" aria-hidden="true" tabindex="-1"></a>    artifacts<span class="op">=</span>{<span class="st">"model_path"</span>: <span class="st">"path/to/model.pt"</span>},</span>
<span id="cb6-33"><a href="#cb6-33" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
<hr>
</section>
</section>
<section id="sec-artifact-management" class="level2">
<h2 class="anchored" data-anchor-id="sec-artifact-management" id="sec-artifact-management">Artifact Management</h2>
<section id="what-to-log-as-artifacts-cv-specific" class="level3">
<h3 class="anchored" data-anchor-id="what-to-log-as-artifacts-cv-specific" id="what-to-log-as-artifacts-cv-specific">What to Log as Artifacts (CV-Specific)</h3>
<div id="tbl-artifacts" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-tbl figure">
<figcaption class="quarto-float-caption-top quarto-float-caption quarto-float-tbl" id="tbl-artifacts-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Table&nbsp;2: CV-specific artifacts and when to log them
</figcaption>
<div aria-describedby="tbl-artifacts-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<table class="caption-top table">
<colgroup>
<col style="width: 37%">
<col style="width: 21%">
<col style="width: 40%">
</colgroup>
<thead>
<tr class="header">
<th>Artifact</th>
<th>When to Log</th>
<th>Why</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>Sample predictions (images)</td>
<td>End of each epoch</td>
<td>Visual debugging of model behaviour</td>
</tr>
<tr class="even">
<td>Confusion matrix (as PNG)</td>
<td>Post-evaluation</td>
<td>Class-level error analysis</td>
</tr>
<tr class="odd">
<td>PR / ROC curves</td>
<td>Post-evaluation</td>
<td>Threshold selection guidance</td>
</tr>
<tr class="even">
<td>Augmentation samples</td>
<td>Pre-training</td>
<td>Verify augmentation pipeline</td>
</tr>
<tr class="odd">
<td>Class activation maps (Grad-CAM)</td>
<td>Debugging runs</td>
<td>Explainability</td>
</tr>
<tr class="even">
<td>ONNX / TorchScript exports</td>
<td>Release candidates</td>
<td>Deployment-ready formats</td>
</tr>
<tr class="odd">
<td>Training config YAML</td>
<td>Every run</td>
<td>Full reproducibility</td>
</tr>
</tbody>
</table>
</div>
</figure>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> matplotlib.pyplot <span class="im">as</span> plt</span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Log a batch of predictions as an image grid</span></span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a>fig, axes <span class="op">=</span> plt.subplots(<span class="dv">2</span>, <span class="dv">4</span>, figsize<span class="op">=</span>(<span class="dv">16</span>, <span class="dv">8</span>))</span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> i, ax <span class="kw">in</span> <span class="bu">enumerate</span>(axes.flat):</span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a>    ax.imshow(pred_images[i])</span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a>    ax.set_title(<span class="ss">f"Pred: </span><span class="sc">{</span>pred_labels[i]<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a>plt.tight_layout()</span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a>plt.savefig(<span class="st">"/tmp/predictions_epoch_10.png"</span>)</span>
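<span id="cb7-10"><a href="#cb7-10" aria-hidden="true" tabindex="-1"></a>mlflow.log_artifact(<span class="st">"/tmp/predictions_epoch_10.png"</span>, artifact_path<span class="op">=</span><span class="st">"visualisations"</span>)</span>
<span id="cb7-11"><a href="#cb7-11" aria-hidden="true" tabindex="-1"></a>plt.close(fig)  <span class="co"># free the figure; open figures accumulate when this runs every epoch</span></span></code></pre></div></div>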
</section>
<section id="log-config-files-not-just-parameters" class="level3">
<h3 class="anchored" data-anchor-id="log-config-files-not-just-parameters" id="log-config-files-not-just-parameters">Log Config Files, Not Just Parameters</h3>
<p>Log the full YAML/JSON config alongside individual parameters. This is your single source of truth for reproducibility.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb8"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a>mlflow.log_artifact(<span class="st">"configs/train_config.yaml"</span>, artifact_path<span class="op">=</span><span class="st">"configs"</span>)</span></code></pre></div></div>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Tip
</div>
</div>
<div class="callout-body-container callout-body">
<p>Logging the config file ensures you can fully reconstruct the training environment even if individual <code>log_params</code> calls are incomplete or inconsistent.</p>
</div>
</div>
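<p>YAML configs are usually nested, while <code>mlflow.log_params</code> expects a flat mapping of names to values. Below is a minimal sketch, assuming the config has already been loaded into a plain <code>dict</code> (for example with PyYAML's <code>yaml.safe_load</code>). It flattens the config into dot-separated keys so the same values can also be searched and compared as individual parameters. The helper name <code>flatten_config</code> is hypothetical.</p>

```python
from typing import Any, Dict


def flatten_config(cfg: Dict[str, Any], parent: str = "", sep: str = ".") -> Dict[str, Any]:
    """Flatten a nested config dict into {"a.b.c": value} pairs (hypothetical helper)."""
    flat: Dict[str, Any] = {}
    for key, value in cfg.items():
        name = f"{parent}{sep}{key}" if parent else str(key)
        if isinstance(value, dict):
            # Recurse into sub-sections such as "optimizer" or "augment"
            flat.update(flatten_config(value, name, sep))
        else:
            flat[name] = value
    return flat


config = {
    "optimizer": {"name": "AdamW", "lr": 1e-4},
    "data": {"img_size": 640, "augment": {"mosaic": True, "mixup": 0.1}},
}
params = flatten_config(config)
# mlflow.log_params(params)  # inside an active run
```

<p>Non-dict values such as lists are kept as single entries. MLflow stores parameter values as strings, so the YAML artifact remains the authoritative copy of the original types.</p>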
<hr>
</section>
</section>
<section id="sec-dataset-versioning" class="level2">
<h2 class="anchored" data-anchor-id="sec-dataset-versioning" id="sec-dataset-versioning">Dataset Versioning &amp; Lineage</h2>
<section id="use-mlflow.log_input-mlflow-2.4" class="level3">
<h3 class="anchored" data-anchor-id="use-mlflow.log_input-mlflow-2.4" id="use-mlflow.log_input-mlflow-2.4">Use <code>mlflow.log_input()</code> (MLflow ≥ 2.4)</h3>
<p>Track exact dataset versions to make runs reproducible and auditable.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb9"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a>dataset <span class="op">=</span> mlflow.data.from_numpy(</span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a>    features<span class="op">=</span>X_train,</span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a>    targets<span class="op">=</span>y_train,</span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"coco-detection-train"</span>,</span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a>    source<span class="op">=</span><span class="st">"s3://your-bucket/datasets/coco/2017/train/"</span>,</span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a><span class="cf">with</span> mlflow.start_run():</span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a>    mlflow.log_input(dataset, context<span class="op">=</span><span class="st">"training"</span>)</span></code></pre></div></div>
</section>
<section id="record-dataset-hashes" class="level3">
<h3 class="anchored" data-anchor-id="record-dataset-hashes" id="record-dataset-hashes">Record Dataset Hashes</h3>
<p>For local or cached datasets, compute and log a SHA-256 checksum:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb10"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> hashlib</span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> dataset_hash(path: <span class="bu">str</span>) <span class="op">-&gt;</span> <span class="bu">str</span>:</span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a>    h <span class="op">=</span> hashlib.sha256()</span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> <span class="bu">open</span>(path, <span class="st">"rb"</span>) <span class="im">as</span> f:</span>
<span id="cb10-6"><a href="#cb10-6" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> chunk <span class="kw">in</span> <span class="bu">iter</span>(<span class="kw">lambda</span>: f.read(<span class="dv">65536</span>), <span class="st">b""</span>):</span>
<span id="cb10-7"><a href="#cb10-7" aria-hidden="true" tabindex="-1"></a>            h.update(chunk)</span>
<span id="cb10-8"><a href="#cb10-8" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> h.hexdigest()</span>
<span id="cb10-9"><a href="#cb10-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-10"><a href="#cb10-10" aria-hidden="true" tabindex="-1"></a>mlflow.log_param(<span class="st">"train_dataset_sha256"</span>, dataset_hash(<span class="st">"/data/train.tar.gz"</span>))</span></code></pre></div></div>
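<p>Image datasets are often a directory tree rather than a single archive. One way to extend the same idea is sketched below as a hypothetical <code>directory_hash</code> helper. It walks the tree in sorted order and feeds each file's relative path and its own SHA-256 into a single digest, so the result changes if any file is added, removed, renamed, or modified.</p>

```python
import hashlib
import os


def _file_sha256(path: str, chunk_size: int = 65536) -> str:
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def directory_hash(root: str) -> str:
    """Deterministic SHA-256 over all files under `root` (hypothetical helper).

    Each file contributes one record: its POSIX-style relative path, a NUL byte,
    the file's own SHA-256 hex digest and a newline. Sorting makes the result
    independent of the order in which the filesystem lists entries.
    """
    h = hashlib.sha256()
    for dirpath, dirnames, filenames in os.walk(root):
        # os.walk honours in-place edits, so sub-directories are visited in sorted order
        dirnames.sort()
        for name in sorted(filenames):
            full = os.path.join(dirpath, name)
            rel = os.path.relpath(full, root).replace(os.sep, "/")
            h.update(rel.encode("utf-8") + b"\0" + _file_sha256(full).encode("ascii") + b"\n")
    return h.hexdigest()
```

<p>Usage would look like <code>mlflow.log_param("train_dataset_sha256", directory_hash("/data/coco/train2017"))</code>, where the path is only an example. Empty directories do not affect the digest, and every byte is read, so on large datasets it may be worth caching the result per dataset version.</p>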
<hr>
</section>
</section>
<section id="sec-hyperparameters" class="level2">
<h2 class="anchored" data-anchor-id="sec-hyperparameters" id="sec-hyperparameters">Hyperparameter Management</h2>
<section id="log-all-hyperparameters-including-implicit-ones" class="level3">
<h3 class="anchored" data-anchor-id="log-all-hyperparameters-including-implicit-ones" id="log-all-hyperparameters-including-implicit-ones">Log All Hyperparameters — Including Implicit Ones</h3>
<p>Don’t stop at the obvious parameters. CV training has many implicit settings (augmentation strengths, warmup length, backbone freezing, mixed precision) that also affect results.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb11"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a>mlflow.log_params({</span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Optimiser</span></span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a>    <span class="st">"optimizer"</span>: <span class="st">"AdamW"</span>,</span>
<span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a>    <span class="st">"lr"</span>: <span class="fl">1e-4</span>,</span>
<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a>    <span class="st">"weight_decay"</span>: <span class="fl">1e-2</span>,</span>
<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a>    <span class="st">"lr_scheduler"</span>: <span class="st">"cosine_annealing"</span>,</span>
<span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a>    <span class="st">"warmup_epochs"</span>: <span class="dv">5</span>,</span>
<span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-9"><a href="#cb11-9" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Data</span></span>
<span id="cb11-10"><a href="#cb11-10" aria-hidden="true" tabindex="-1"></a>    <span class="st">"img_size"</span>: <span class="dv">640</span>,</span>
<span id="cb11-11"><a href="#cb11-11" aria-hidden="true" tabindex="-1"></a>    <span class="st">"batch_size"</span>: <span class="dv">32</span>,</span>
<span id="cb11-12"><a href="#cb11-12" aria-hidden="true" tabindex="-1"></a>    <span class="st">"num_workers"</span>: <span class="dv">8</span>,</span>
<span id="cb11-13"><a href="#cb11-13" aria-hidden="true" tabindex="-1"></a>    <span class="st">"augment_mosaic"</span>: <span class="va">True</span>,</span>
<span id="cb11-14"><a href="#cb11-14" aria-hidden="true" tabindex="-1"></a>    <span class="st">"augment_mixup"</span>: <span class="fl">0.1</span>,</span>
<span id="cb11-15"><a href="#cb11-15" aria-hidden="true" tabindex="-1"></a>    <span class="st">"augment_hsv_h"</span>: <span class="fl">0.015</span>,</span>
<span id="cb11-16"><a href="#cb11-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-17"><a href="#cb11-17" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Architecture</span></span>
<span id="cb11-18"><a href="#cb11-18" aria-hidden="true" tabindex="-1"></a>    <span class="st">"backbone"</span>: <span class="st">"EfficientNetV2-L"</span>,</span>
<span id="cb11-19"><a href="#cb11-19" aria-hidden="true" tabindex="-1"></a>    <span class="st">"pretrained"</span>: <span class="va">True</span>,</span>
<span id="cb11-20"><a href="#cb11-20" aria-hidden="true" tabindex="-1"></a>    <span class="st">"freeze_backbone_epochs"</span>: <span class="dv">10</span>,</span>
<span id="cb11-21"><a href="#cb11-21" aria-hidden="true" tabindex="-1"></a>    <span class="st">"dropout"</span>: <span class="fl">0.2</span>,</span>
<span id="cb11-22"><a href="#cb11-22" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-23"><a href="#cb11-23" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Training</span></span>
<span id="cb11-24"><a href="#cb11-24" aria-hidden="true" tabindex="-1"></a>    <span class="st">"epochs"</span>: <span class="dv">200</span>,</span>
<span id="cb11-25"><a href="#cb11-25" aria-hidden="true" tabindex="-1"></a>    <span class="st">"early_stopping_patience"</span>: <span class="dv">15</span>,</span>
<span id="cb11-26"><a href="#cb11-26" aria-hidden="true" tabindex="-1"></a>    <span class="st">"amp"</span>: <span class="va">True</span>,       <span class="co"># Automatic mixed precision</span></span>
<span id="cb11-27"><a href="#cb11-27" aria-hidden="true" tabindex="-1"></a>    <span class="st">"gradient_clip"</span>: <span class="fl">10.0</span>,</span>
<span id="cb11-28"><a href="#cb11-28" aria-hidden="true" tabindex="-1"></a>    <span class="st">"seed"</span>: <span class="dv">42</span>,</span>
<span id="cb11-29"><a href="#cb11-29" aria-hidden="true" tabindex="-1"></a>})</span></code></pre></div></div>
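<p>One way to keep implicit settings from slipping through is to declare every training option, including rarely changed ones, as a field with a default in a single config object, and then log the whole object. The sketch below uses a hypothetical <code>TrainConfig</code> dataclass. <code>dataclasses.asdict</code> returns every field, so defaults that nobody passed explicitly still end up in the run.</p>

```python
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class TrainConfig:
    # Settings that are usually tuned explicitly
    lr: float = 1e-4
    batch_size: int = 32
    epochs: int = 200
    # "Implicit" settings that are easy to forget to log
    warmup_epochs: int = 5
    augment_hsv_h: float = 0.015
    amp: bool = True
    gradient_clip: float = 10.0
    seed: int = 42


cfg = TrainConfig(lr=3e-4, batch_size=64)  # only two values overridden
params = asdict(cfg)  # contains every field: the overrides plus all defaults
# mlflow.log_params(params)  # inside an active run
```

<p>Because the dataclass is frozen, its attributes cannot be reassigned later in the script. This helps keep the values logged at the start of the run in line with the ones training actually used.</p>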
</section>
<section id="integrate-with-optuna-ray-tune-for-hpo" class="level3">
<h3 class="anchored" data-anchor-id="integrate-with-optuna-ray-tune-for-hpo" id="integrate-with-optuna-ray-tune-for-hpo">Integrate with Optuna / Ray Tune for HPO</h3>
<p>When running hyperparameter optimisation sweeps, each trial should be its own MLflow run, nested under a parent sweep run.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb12"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><a href="#cb12-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> optuna</span>
<span id="cb12-2"><a href="#cb12-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-3"><a href="#cb12-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> objective(trial):</span>
<span id="cb12-4"><a href="#cb12-4" aria-hidden="true" tabindex="-1"></a>    lr <span class="op">=</span> trial.suggest_float(<span class="st">"lr"</span>, <span class="fl">1e-5</span>, <span class="fl">1e-2</span>, log<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb12-5"><a href="#cb12-5" aria-hidden="true" tabindex="-1"></a>    dropout <span class="op">=</span> trial.suggest_float(<span class="st">"dropout"</span>, <span class="fl">0.1</span>, <span class="fl">0.5</span>)</span>
<span id="cb12-6"><a href="#cb12-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-7"><a href="#cb12-7" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> mlflow.start_run(nested<span class="op">=</span><span class="va">True</span>):</span>
<span id="cb12-8"><a href="#cb12-8" aria-hidden="true" tabindex="-1"></a>        mlflow.log_params({<span class="st">"lr"</span>: lr, <span class="st">"dropout"</span>: dropout})</span>
<span id="cb12-9"><a href="#cb12-9" aria-hidden="true" tabindex="-1"></a>        val_map <span class="op">=</span> train_and_evaluate(lr<span class="op">=</span>lr, dropout<span class="op">=</span>dropout)</span>
<span id="cb12-10"><a href="#cb12-10" aria-hidden="true" tabindex="-1"></a>        mlflow.log_metric(<span class="st">"val_mAP50"</span>, val_map)</span>
<span id="cb12-11"><a href="#cb12-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-12"><a href="#cb12-12" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> val_map</span>
<span id="cb12-13"><a href="#cb12-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-14"><a href="#cb12-14" aria-hidden="true" tabindex="-1"></a><span class="cf">with</span> mlflow.start_run(run_name<span class="op">=</span><span class="st">"hpo-sweep"</span>):</span>
<span id="cb12-15"><a href="#cb12-15" aria-hidden="true" tabindex="-1"></a>    study <span class="op">=</span> optuna.create_study(direction<span class="op">=</span><span class="st">"maximize"</span>)</span>
<span id="cb12-16"><a href="#cb12-16" aria-hidden="true" tabindex="-1"></a>    study.optimize(objective, n_trials<span class="op">=</span><span class="dv">50</span>)</span></code></pre></div></div>
<hr>
</section>
</section>
<section id="sec-metrics" class="level2">
<h2 class="anchored" data-anchor-id="sec-metrics" id="sec-metrics">Metrics &amp; Evaluation</h2>
<section id="log-metrics-at-the-right-granularity" class="level3">
<h3 class="anchored" data-anchor-id="log-metrics-at-the-right-granularity" id="log-metrics-at-the-right-granularity">Log Metrics at the Right Granularity</h3>
<p>Log per-step metrics for loss curves, per-epoch metrics for validation scores, and summary metrics at run end.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb13"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(num_epochs):</span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a>    train_loss <span class="op">=</span> run_training_epoch(...)</span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a>    val_map, val_map95 <span class="op">=</span> run_validation(...)</span>
<span id="cb13-4"><a href="#cb13-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-5"><a href="#cb13-5" aria-hidden="true" tabindex="-1"></a>    mlflow.log_metrics({</span>
<span id="cb13-6"><a href="#cb13-6" aria-hidden="true" tabindex="-1"></a>        <span class="st">"train/loss"</span>: train_loss,</span>
<span id="cb13-7"><a href="#cb13-7" aria-hidden="true" tabindex="-1"></a>        <span class="st">"val/mAP50"</span>: val_map,</span>
<span id="cb13-8"><a href="#cb13-8" aria-hidden="true" tabindex="-1"></a>        <span class="st">"val/mAP50-95"</span>: val_map95,</span>
<span id="cb13-9"><a href="#cb13-9" aria-hidden="true" tabindex="-1"></a>        <span class="st">"lr"</span>: scheduler.get_last_lr()[<span class="dv">0</span>],</span>
<span id="cb13-10"><a href="#cb13-10" aria-hidden="true" tabindex="-1"></a>    }, step<span class="op">=</span>epoch)</span>
<span id="cb13-11"><a href="#cb13-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-12"><a href="#cb13-12" aria-hidden="true" tabindex="-1"></a><span class="co"># Summary at end of training</span></span>
<span id="cb13-13"><a href="#cb13-13" aria-hidden="true" tabindex="-1"></a>mlflow.log_metrics({</span>
<span id="cb13-14"><a href="#cb13-14" aria-hidden="true" tabindex="-1"></a>    <span class="st">"best_val_mAP50"</span>: best_map,</span>
<span id="cb13-15"><a href="#cb13-15" aria-hidden="true" tabindex="-1"></a>    <span class="st">"best_epoch"</span>: best_epoch,</span>
<span id="cb13-16"><a href="#cb13-16" aria-hidden="true" tabindex="-1"></a>    <span class="st">"total_train_time_hrs"</span>: elapsed <span class="op">/</span> <span class="dv">3600</span>,</span>
<span id="cb13-17"><a href="#cb13-17" aria-hidden="true" tabindex="-1"></a>})</span></code></pre></div></div>
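<p>The loop above logs once per epoch. For per-step loss curves, a common pattern is to derive a global step from the epoch and batch index and to log only every N steps. This keeps the number of tracking calls manageable on long runs. The <code>StepLogger</code> class below is a hypothetical sketch. It accepts any function shaped like <code>log_fn(metrics, step=...)</code>, so <code>mlflow.log_metrics</code> (which takes a <code>step</code> keyword) can be passed in. In the example, a list-appending function stands in for it.</p>

```python
from typing import Callable, Dict


class StepLogger:
    """Log metrics every `every_n` global steps (hypothetical helper)."""

    def __init__(self, log_fn: Callable[..., None], steps_per_epoch: int, every_n: int = 50):
        self.log_fn = log_fn
        self.steps_per_epoch = steps_per_epoch
        self.every_n = every_n

    def global_step(self, epoch: int, batch_idx: int) -> int:
        # One monotonically increasing step across epochs, so curves share an x-axis
        return epoch * self.steps_per_epoch + batch_idx

    def maybe_log(self, metrics: Dict[str, float], epoch: int, batch_idx: int) -> bool:
        step = self.global_step(epoch, batch_idx)
        if step % self.every_n != 0:
            return False
        self.log_fn(metrics, step=step)
        return True


# Stand-in for mlflow.log_metrics so the example runs without a tracking server
logged = []
step_logger = StepLogger(lambda m, step: logged.append((step, m)), steps_per_epoch=100, every_n=50)
for epoch in range(2):
    for batch_idx in range(100):
        step_logger.maybe_log({"train/step_loss": 1.0 / (batch_idx + 1)}, epoch, batch_idx)
```

<p>In a real loop you would pass <code>mlflow.log_metrics</code> as <code>log_fn</code> and <code>len(train_loader)</code> as <code>steps_per_epoch</code> inside an active run. Per-step and per-epoch values use different keys here (<code>train/step_loss</code> versus <code>train/loss</code>) because they are plotted against different step scales.</p>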
</section>
<section id="log-task-specific-cv-metrics" class="level3">
<h3 class="anchored" data-anchor-id="log-task-specific-cv-metrics" id="log-task-specific-cv-metrics">Log Task-Specific CV Metrics</h3>
<div id="tbl-cv-metrics" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-tbl figure">
<figcaption class="quarto-float-caption-top quarto-float-caption quarto-float-tbl" id="tbl-cv-metrics-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Table&nbsp;3: Task-specific metrics for common CV tasks
</figcaption>
<div aria-describedby="tbl-cv-metrics-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<table class="caption-top table">
<colgroup>
<col style="width: 28%">
<col style="width: 71%">
</colgroup>
<thead>
<tr class="header">
<th>Task</th>
<th>Key Metrics to Log</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>Classification</td>
<td><code>top1_acc</code>, <code>top5_acc</code>, <code>per_class_f1</code></td>
</tr>
<tr class="even">
<td>Object Detection</td>
<td><code>mAP50</code>, <code>mAP50-95</code>, <code>precision</code>, <code>recall</code>, <code>FPS</code></td>
</tr>
<tr class="odd">
<td>Semantic Segmentation</td>
<td><code>mIoU</code>, <code>pixel_acc</code>, <code>per_class_IoU</code></td>
</tr>
<tr class="even">
<td>Instance Segmentation</td>
<td><code>mask_AP</code>, <code>bbox_AP</code></td>
</tr>
<tr class="odd">
<td>Anomaly Detection</td>
<td><code>AUROC</code>, <code>AUPRC</code>, <code>F1@best_threshold</code></td>
</tr>
<tr class="even">
<td>Depth Estimation</td>
<td><code>AbsRel</code>, <code>RMSE</code>, <code>delta_1</code></td>
</tr>
</tbody>
</table>
</div>
</figure>
</div>
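<p>For semantic segmentation, <code>mIoU</code>, <code>per_class_IoU</code> and <code>pixel_acc</code> can all be computed from a single confusion matrix accumulated over the validation set. The pure-Python sketch below assumes <code>conf[i][j]</code> counts pixels whose true class is <code>i</code> and whose predicted class is <code>j</code>. Per-class IoU is TP / (TP + FP + FN). Classes that appear in neither the ground truth nor the predictions are skipped rather than counted as zero. That is one common convention, so it is worth recording which one you use.</p>

```python
from typing import Dict, List, Sequence


def iou_from_confusion(conf: Sequence[Sequence[int]], class_names: List[str]) -> Dict[str, float]:
    """Per-class IoU, mIoU and pixel accuracy from a square confusion matrix."""
    n = len(conf)
    metrics: Dict[str, float] = {}
    ious = []
    for c in range(n):
        tp = conf[c][c]
        fn = sum(conf[c]) - tp                       # true class c, predicted something else
        fp = sum(conf[r][c] for r in range(n)) - tp  # predicted c, true class something else
        denom = tp + fp + fn
        if denom == 0:
            continue  # class absent from both ground truth and predictions
        iou = tp / denom
        metrics[f"val/IoU_{class_names[c]}"] = iou
        ious.append(iou)
    metrics["val/mIoU"] = sum(ious) / len(ious) if ious else 0.0
    total = sum(sum(row) for row in conf)
    metrics["val/pixel_acc"] = sum(conf[c][c] for c in range(n)) / total if total else 0.0
    return metrics


# Toy 3-class matrix: rows = true class, columns = predicted class
conf = [
    [50, 2, 0],
    [3, 45, 0],
    [0, 0, 0],  # "sky" never occurs and is never predicted
]
metrics = iou_from_confusion(conf, ["road", "car", "sky"])
# mlflow.log_metrics(metrics, step=epoch)  # inside an active run
```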
</section>
<section id="use-mlflow.evaluate-for-standardised-post-training-evaluation" class="level3">
<h3 class="anchored" data-anchor-id="use-mlflow.evaluate-for-standardised-post-training-evaluation" id="use-mlflow.evaluate-for-standardised-post-training-evaluation">Use <code>mlflow.evaluate()</code> for Standardised Post-Training Evaluation</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb14"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb14-1"><a href="#cb14-1" aria-hidden="true" tabindex="-1"></a>result <span class="op">=</span> mlflow.evaluate(</span>
<span id="cb14-2"><a href="#cb14-2" aria-hidden="true" tabindex="-1"></a>    model<span class="op">=</span><span class="ss">f"runs:/</span><span class="sc">{</span>run_id<span class="sc">}</span><span class="ss">/model"</span>,</span>
<span id="cb14-3"><a href="#cb14-3" aria-hidden="true" tabindex="-1"></a>    data<span class="op">=</span>test_dataset,</span>
<span id="cb14-4"><a href="#cb14-4" aria-hidden="true" tabindex="-1"></a>    targets<span class="op">=</span><span class="st">"labels"</span>,</span>
<span id="cb14-5"><a href="#cb14-5" aria-hidden="true" tabindex="-1"></a>    model_type<span class="op">=</span><span class="st">"classifier"</span>,</span>
<span id="cb14-6"><a href="#cb14-6" aria-hidden="true" tabindex="-1"></a>    evaluators<span class="op">=</span>[<span class="st">"default"</span>],</span>
<span id="cb14-7"><a href="#cb14-7" aria-hidden="true" tabindex="-1"></a>    <span class="co"># The default classifier evaluator already computes precision, recall and F1;</span></span>
<span id="cb14-8"><a href="#cb14-8" aria-hidden="true" tabindex="-1"></a>    <span class="co"># "average" controls how multiclass scores are aggregated</span></span>
<span id="cb14-9"><a href="#cb14-9" aria-hidden="true" tabindex="-1"></a>    evaluator_config<span class="op">=</span>{<span class="st">"average"</span>: <span class="st">"macro"</span>},</span>
<span id="cb14-10"><a href="#cb14-10" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb14-11"><a href="#cb14-11" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(result.metrics)</span></code></pre></div></div>
<hr>
</section>
</section>
<section id="sec-serving" class="level2">
<h2 class="anchored" data-anchor-id="sec-serving" id="sec-serving">Model Serving &amp; Deployment</h2>
<section id="export-to-onnx-and-log-as-artifact" class="level3">
<h3 class="anchored" data-anchor-id="export-to-onnx-and-log-as-artifact" id="export-to-onnx-and-log-as-artifact">Export to ONNX and Log as an Artifact</h3>
<p>For production inference, exporting to ONNX produces a framework-neutral model file that runtimes such as TensorRT, OpenVINO, and ONNX Runtime can load on different hardware targets.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb15"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb15-1"><a href="#cb15-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb15-2"><a href="#cb15-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-3"><a href="#cb15-3" aria-hidden="true" tabindex="-1"></a>model.eval()  <span class="co"># export in inference mode: dropout off, BatchNorm uses running stats</span></span>
<span id="cb15-4"><a href="#cb15-4" aria-hidden="true" tabindex="-1"></a>dummy_input <span class="op">=</span> torch.randn(<span class="dv">1</span>, <span class="dv">3</span>, <span class="dv">640</span>, <span class="dv">640</span>)  <span class="co"># must be on the same device as the model</span></span>
<span id="cb15-5"><a href="#cb15-5" aria-hidden="true" tabindex="-1"></a>torch.onnx.export(</span>
<span id="cb15-6"><a href="#cb15-6" aria-hidden="true" tabindex="-1"></a>    model,</span>
<span id="cb15-7"><a href="#cb15-7" aria-hidden="true" tabindex="-1"></a>    dummy_input,</span>
<span id="cb15-8"><a href="#cb15-8" aria-hidden="true" tabindex="-1"></a>    <span class="st">"/tmp/model.onnx"</span>,</span>
<span id="cb15-9"><a href="#cb15-9" aria-hidden="true" tabindex="-1"></a>    opset_version<span class="op">=</span><span class="dv">17</span>,</span>
<span id="cb15-10"><a href="#cb15-10" aria-hidden="true" tabindex="-1"></a>    input_names<span class="op">=</span>[<span class="st">"images"</span>],</span>
<span id="cb15-11"><a href="#cb15-11" aria-hidden="true" tabindex="-1"></a>    output_names<span class="op">=</span>[<span class="st">"output"</span>],</span>
<span id="cb15-12"><a href="#cb15-12" aria-hidden="true" tabindex="-1"></a>    dynamic_axes<span class="op">=</span>{<span class="st">"images"</span>: {<span class="dv">0</span>: <span class="st">"batch_size"</span>}, <span class="st">"output"</span>: {<span class="dv">0</span>: <span class="st">"batch_size"</span>}},</span>
<span id="cb15-13"><a href="#cb15-13" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb15-14"><a href="#cb15-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-15"><a href="#cb15-15" aria-hidden="true" tabindex="-1"></a><span class="cf">with</span> mlflow.start_run():</span>
<span id="cb15-16"><a href="#cb15-16" aria-hidden="true" tabindex="-1"></a>    mlflow.log_artifact(<span class="st">"/tmp/model.onnx"</span>, artifact_path<span class="op">=</span><span class="st">"onnx"</span>)</span></code></pre></div></div>
</section>
<section id="load-production-models-by-stage-not-by-run-id" class="level3">
<h3 class="anchored" data-anchor-id="load-production-models-by-stage-not-by-run-id" id="load-production-models-by-stage-not-by-run-id">Load Production Models by Stage, Not by Run ID</h3>
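<p>Never hardcode a <code>run_id</code> in serving code. Load by registry stage instead. Note that registry stages are deprecated from MLflow 2.9 onward in favour of aliases. On newer versions, the equivalent URI is <code>models:/cv-resnet50-classifier@champion</code>, with the alias assigned via <code>MlflowClient.set_registered_model_alias()</code>.</p>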
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb16"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb16-1"><a href="#cb16-1" aria-hidden="true" tabindex="-1"></a><span class="co"># ✅ Correct — stage-based loading</span></span>
<span id="cb16-2"><a href="#cb16-2" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> mlflow.pytorch.load_model(<span class="st">"models:/cv-resnet50-classifier/Production"</span>)</span>
<span id="cb16-3"><a href="#cb16-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-4"><a href="#cb16-4" aria-hidden="true" tabindex="-1"></a><span class="co"># ❌ Avoid — brittle, ties serving code to a specific run</span></span>
<span id="cb16-5"><a href="#cb16-5" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> mlflow.pytorch.load_model(<span class="st">"runs:/abc123xyz/model"</span>)</span></code></pre></div></div>
</section>
<section id="log-inference-latency-as-a-metric" class="level3">
<h3 class="anchored" data-anchor-id="log-inference-latency-as-a-metric" id="log-inference-latency-as-a-metric">Log Inference Latency as a Metric</h3>
<p>Track per-batch and per-image latency as part of your evaluation run:</p>
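<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb17"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb17-1"><a href="#cb17-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb17-2"><a href="#cb17-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb17-3"><a href="#cb17-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb17-4"><a href="#cb17-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-5"><a href="#cb17-5" aria-hidden="true" tabindex="-1"></a>WARMUP_BATCHES <span class="op">=</span> <span class="dv">3</span>  <span class="co"># early batches include one-off costs (CUDA init, cuDNN autotuning)</span></span>
<span id="cb17-6"><a href="#cb17-6" aria-hidden="true" tabindex="-1"></a>device <span class="op">=</span> <span class="bu">next</span>(model.parameters()).device</span>
<span id="cb17-7"><a href="#cb17-7" aria-hidden="true" tabindex="-1"></a>model.eval()</span>
<span id="cb17-8"><a href="#cb17-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-9"><a href="#cb17-9" aria-hidden="true" tabindex="-1"></a>latencies, n_images <span class="op">=</span> [], <span class="dv">0</span></span>
<span id="cb17-10"><a href="#cb17-10" aria-hidden="true" tabindex="-1"></a><span class="cf">with</span> torch.no_grad():</span>
<span id="cb17-11"><a href="#cb17-11" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> i, (images, _) <span class="kw">in</span> <span class="bu">enumerate</span>(test_loader):  <span class="co"># assumes (images, labels) batches</span></span>
<span id="cb17-12"><a href="#cb17-12" aria-hidden="true" tabindex="-1"></a>        images <span class="op">=</span> images.to(device)  <span class="co"># host-to-device copy is kept out of the timing</span></span>
<span id="cb17-13"><a href="#cb17-13" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> device.type <span class="op">==</span> <span class="st">"cuda"</span>:</span>
<span id="cb17-14"><a href="#cb17-14" aria-hidden="true" tabindex="-1"></a>            torch.cuda.synchronize()  <span class="co"># finish queued GPU work before starting the timer</span></span>
<span id="cb17-15"><a href="#cb17-15" aria-hidden="true" tabindex="-1"></a>        t0 <span class="op">=</span> time.perf_counter()</span>
<span id="cb17-16"><a href="#cb17-16" aria-hidden="true" tabindex="-1"></a>        _ <span class="op">=</span> model(images)</span>
<span id="cb17-17"><a href="#cb17-17" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> device.type <span class="op">==</span> <span class="st">"cuda"</span>:</span>
<span id="cb17-18"><a href="#cb17-18" aria-hidden="true" tabindex="-1"></a>            torch.cuda.synchronize()  <span class="co"># CUDA kernels run asynchronously; wait for them</span></span>
<span id="cb17-19"><a href="#cb17-19" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> i <span class="op">&gt;=</span> WARMUP_BATCHES:</span>
<span id="cb17-20"><a href="#cb17-20" aria-hidden="true" tabindex="-1"></a>            latencies.append((time.perf_counter() <span class="op">-</span> t0) <span class="op">*</span> <span class="dv">1000</span>)</span>
<span id="cb17-21"><a href="#cb17-21" aria-hidden="true" tabindex="-1"></a>            n_images <span class="op">+=</span> images.shape[<span class="dv">0</span>]</span>
<span id="cb17-22"><a href="#cb17-22" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-23"><a href="#cb17-23" aria-hidden="true" tabindex="-1"></a>mlflow.log_metrics({</span>
<span id="cb17-24"><a href="#cb17-24" aria-hidden="true" tabindex="-1"></a>    <span class="st">"p50_latency_ms"</span>: np.percentile(latencies, <span class="dv">50</span>),  <span class="co"># per batch</span></span>
<span id="cb17-25"><a href="#cb17-25" aria-hidden="true" tabindex="-1"></a>    <span class="st">"p95_latency_ms"</span>: np.percentile(latencies, <span class="dv">95</span>),</span>
<span id="cb17-26"><a href="#cb17-26" aria-hidden="true" tabindex="-1"></a>    <span class="st">"p99_latency_ms"</span>: np.percentile(latencies, <span class="dv">99</span>),</span>
<span id="cb17-27"><a href="#cb17-27" aria-hidden="true" tabindex="-1"></a>    <span class="st">"mean_latency_per_image_ms"</span>: <span class="bu">sum</span>(latencies) <span class="op">/</span> n_images,</span>
<span id="cb17-28"><a href="#cb17-28" aria-hidden="true" tabindex="-1"></a>    <span class="st">"throughput_imgs_per_sec"</span>: n_images <span class="op">/</span> (<span class="bu">sum</span>(latencies) <span class="op">/</span> <span class="dv">1000</span>),</span>
<span id="cb17-29"><a href="#cb17-29" aria-hidden="true" tabindex="-1"></a>})</span></code></pre></div></div>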
<hr>
</section>
</section>
<section id="sec-cicd" class="level2">
<h2 class="anchored" data-anchor-id="sec-cicd" id="sec-cicd">CI/CD Integration</h2>
<section id="gate-promotions-on-metric-thresholds" class="level3">
<h3 class="anchored" data-anchor-id="gate-promotions-on-metric-thresholds" id="gate-promotions-on-metric-thresholds">Gate Promotions on Metric Thresholds</h3>
<p>Never promote a model to production manually. Automate stage transitions with metric gates.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb18"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb18-1"><a href="#cb18-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> mlflow.tracking <span class="im">import</span> MlflowClient</span>
<span id="cb18-2"><a href="#cb18-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-3"><a href="#cb18-3" aria-hidden="true" tabindex="-1"></a>client <span class="op">=</span> MlflowClient()</span>
<span id="cb18-4"><a href="#cb18-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-5"><a href="#cb18-5" aria-hidden="true" tabindex="-1"></a>run <span class="op">=</span> client.get_run(candidate_run_id)</span>
<span id="cb18-6"><a href="#cb18-6" aria-hidden="true" tabindex="-1"></a>metrics <span class="op">=</span> run.data.metrics</span>
<span id="cb18-7"><a href="#cb18-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-8"><a href="#cb18-8" aria-hidden="true" tabindex="-1"></a>PRODUCTION_GATE <span class="op">=</span> {</span>
<span id="cb18-9"><a href="#cb18-9" aria-hidden="true" tabindex="-1"></a>    <span class="st">"val/mAP50"</span>: <span class="fl">0.85</span>,</span>
<span id="cb18-10"><a href="#cb18-10" aria-hidden="true" tabindex="-1"></a>    <span class="st">"p95_latency_ms"</span>: <span class="fl">50.0</span>,</span>
<span id="cb18-11"><a href="#cb18-11" aria-hidden="true" tabindex="-1"></a>}</span>
<span id="cb18-12"><a href="#cb18-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-13"><a href="#cb18-13" aria-hidden="true" tabindex="-1"></a>passed <span class="op">=</span> <span class="bu">all</span>(</span>
<span id="cb18-14"><a href="#cb18-14" aria-hidden="true" tabindex="-1"></a>    metrics.get(k, <span class="dv">0</span>) <span class="op">&gt;=</span> v <span class="cf">if</span> <span class="st">"latency"</span> <span class="kw">not</span> <span class="kw">in</span> k</span>
<span id="cb18-15"><a href="#cb18-15" aria-hidden="true" tabindex="-1"></a>    <span class="cf">else</span> metrics.get(k, <span class="dv">9999</span>) <span class="op">&lt;=</span> v</span>
<span id="cb18-16"><a href="#cb18-16" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> k, v <span class="kw">in</span> PRODUCTION_GATE.items()</span>
<span id="cb18-17"><a href="#cb18-17" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb18-18"><a href="#cb18-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-19"><a href="#cb18-19" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> passed:</span>
<span id="cb18-20"><a href="#cb18-20" aria-hidden="true" tabindex="-1"></a>    client.transition_model_version_stage(</span>
<span id="cb18-21"><a href="#cb18-21" aria-hidden="true" tabindex="-1"></a>        name<span class="op">=</span><span class="st">"cv-detector"</span>,</span>
<span id="cb18-22"><a href="#cb18-22" aria-hidden="true" tabindex="-1"></a>        version<span class="op">=</span>candidate_version,</span>
<span id="cb18-23"><a href="#cb18-23" aria-hidden="true" tabindex="-1"></a>        stage<span class="op">=</span><span class="st">"Production"</span>,</span>
<span id="cb18-24"><a href="#cb18-24" aria-hidden="true" tabindex="-1"></a>        archive_existing_versions<span class="op">=</span><span class="va">True</span>,</span>
<span id="cb18-25"><a href="#cb18-25" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb18-26"><a href="#cb18-26" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="st">"✅ Promoted to Production"</span>)</span>
<span id="cb18-27"><a href="#cb18-27" aria-hidden="true" tabindex="-1"></a><span class="cf">else</span>:</span>
<span id="cb18-28"><a href="#cb18-28" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="st">"❌ Failed promotion gate"</span>)</span></code></pre></div></div>
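<p>The check above infers each metric's direction from whether its name contains <code>"latency"</code>. That silently breaks for a metric such as <code>model_size_mb</code>, which should also be minimised. A more explicit variant is sketched below with a hypothetical <code>check_gate</code> helper. It stores the comparison next to each threshold, treats a missing metric as a failure, and returns readable failure messages for CI logs.</p>

```python
from typing import Dict, List, Mapping, Tuple

# metric name -> (comparison, threshold); ">=" means higher is better
PRODUCTION_GATE: Dict[str, Tuple[str, float]] = {
    "val/mAP50": (">=", 0.85),
    "p95_latency_ms": ("<=", 50.0),
}


def check_gate(metrics: Mapping[str, float], gate: Mapping[str, Tuple[str, float]]) -> List[str]:
    """Return a list of failures; an empty list means the candidate passed."""
    failures = []
    for name, (op, threshold) in gate.items():
        if name not in metrics:
            failures.append(f"{name}: missing")
            continue
        value = metrics[name]
        if op == ">=":
            ok = value >= threshold
        elif op == "<=":
            ok = value <= threshold
        else:
            raise ValueError(f"unknown comparison {op!r} for {name}")
        if not ok:
            failures.append(f"{name}: {value} violates {op} {threshold}")
    return failures
```

<p>In CI this would be called as <code>check_gate(client.get_run(candidate_run_id).data.metrics, PRODUCTION_GATE)</code>. The stage transition runs only when the returned list is empty. Otherwise, the failures are printed and the job exits non-zero so the pipeline is marked as failed.</p>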
</section>
<section id="automate-comparison-against-current-champion" class="level3">
<h3 class="anchored" data-anchor-id="automate-comparison-against-current-champion" id="automate-comparison-against-current-champion">Automate Comparison Against Current Champion</h3>
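<p>Before any promotion, compare the challenger against the current champion. The comparison is only meaningful if both models were scored on the same data, ideally the same held-out test set. The snippet below compares each run's logged <code>val/mAP50</code>, which assumes both runs used the same validation split.</p>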
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb19"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb19-1"><a href="#cb19-1" aria-hidden="true" tabindex="-1"></a>champions <span class="op">=</span> client.get_latest_versions(<span class="st">"cv-detector"</span>, stages<span class="op">=</span>[<span class="st">"Production"</span>])</span>
<span id="cb19-2"><a href="#cb19-2" aria-hidden="true" tabindex="-1"></a><span class="co"># No Production version yet (first deployment) means there is no champion to beat</span></span>
<span id="cb19-3"><a href="#cb19-3" aria-hidden="true" tabindex="-1"></a>champion_map50 <span class="op">=</span> client.get_run(champions[<span class="dv">0</span>].run_id).data.metrics[<span class="st">"val/mAP50"</span>] <span class="cf">if</span> champions <span class="cf">else</span> <span class="bu">float</span>(<span class="st">"-inf"</span>)</span>
<span id="cb19-4"><a href="#cb19-4" aria-hidden="true" tabindex="-1"></a>challenger_metrics <span class="op">=</span> client.get_run(challenger_run_id).data.metrics</span>
<span id="cb19-5"><a href="#cb19-5" aria-hidden="true" tabindex="-1"></a>MIN_GAIN <span class="op">=</span> <span class="fl">0.005</span>  <span class="co"># minimum absolute mAP50 improvement required to replace the champion</span></span>
<span id="cb19-6"><a href="#cb19-6" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> challenger_metrics[<span class="st">"val/mAP50"</span>] <span class="op">&gt;</span> champion_map50 <span class="op">+</span> MIN_GAIN:</span>
<span id="cb19-7"><a href="#cb19-7" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="st">"Challenger beats champion — proceed with promotion"</span>)</span>
<span id="cb19-8"><a href="#cb19-8" aria-hidden="true" tabindex="-1"></a><span class="cf">else</span>:</span>
<span id="cb19-9"><a href="#cb19-9" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="st">"Challenger did not improve sufficiently — reject"</span>)</span></code></pre></div></div>
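<p>The check above compares a single metric, while the promotion checklist also lists latency and recall, where the direction of "better" differs per metric. A small helper can express each rule as a direction plus a minimum margin, with a margin of 0 meaning "no regression". This is a sketch under our own assumptions: <code>beats_champion</code> and <code>RULES</code> are hypothetical names, and the metric keys and margins are illustrative, not values from a real gate.</p>

```python
from typing import Dict, Tuple

# Hypothetical rules: metric name -> (direction, minimum margin).
# "max" = higher is better, "min" = lower is better; margin 0 means "no regression".
RULES: Dict[str, Tuple[str, float]] = {
    "val/mAP50": ("max", 0.005),
    "val/recall": ("max", 0.0),
    "p95_latency_ms": ("min", 0.0),
}


def beats_champion(challenger: Dict[str, float], champion: Dict[str, float],
                   rules: Dict[str, Tuple[str, float]] = RULES) -> Tuple[bool, Dict[str, bool]]:
    """Return (overall_pass, per_metric_pass); a metric missing from either run fails."""
    results: Dict[str, bool] = {}
    for name, (direction, margin) in rules.items():
        if name not in challenger or name not in champion:
            results[name] = False
        elif direction == "max":
            results[name] = challenger[name] >= champion[name] + margin
        elif direction == "min":
            results[name] = challenger[name] <= champion[name] - margin
        else:
            raise ValueError(f"Unknown direction {direction!r} for metric {name!r}")
    return all(results.values()), results
```

<p>Feed it the two <code>.data.metrics</code> dicts returned by <code>client.get_run(...)</code> and promote only when the first return value is <code>True</code>; the per-metric dict makes the rejection reason easy to log.</p>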
</section>
<section id="environment-reproducibility" class="level3">
<h3 class="anchored" data-anchor-id="environment-reproducibility" id="environment-reproducibility">Environment Reproducibility</h3>
<p>Always log the full environment alongside the model:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb20"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb20-1"><a href="#cb20-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> subprocess</span>
<span id="cb20-2"><a href="#cb20-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> sys</span>
<span id="cb20-3"><a href="#cb20-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Log pip freeze via the running interpreter, not whichever "pip" is first on PATH</span></span>
<span id="cb20-4"><a href="#cb20-4" aria-hidden="true" tabindex="-1"></a>pip_freeze <span class="op">=</span> subprocess.check_output([sys.executable, <span class="st">"-m"</span>, <span class="st">"pip"</span>, <span class="st">"freeze"</span>]).decode()</span>
<span id="cb20-5"><a href="#cb20-5" aria-hidden="true" tabindex="-1"></a><span class="cf">with</span> <span class="bu">open</span>(<span class="st">"/tmp/requirements.txt"</span>, <span class="st">"w"</span>) <span class="im">as</span> f:</span>
<span id="cb20-6"><a href="#cb20-6" aria-hidden="true" tabindex="-1"></a>    f.write(pip_freeze)</span>
<span id="cb20-7"><a href="#cb20-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb20-8"><a href="#cb20-8" aria-hidden="true" tabindex="-1"></a>mlflow.log_artifact(<span class="st">"/tmp/requirements.txt"</span>, artifact_path<span class="op">=</span><span class="st">"environment"</span>)</span>
<span id="cb20-9"><a href="#cb20-9" aria-hidden="true" tabindex="-1"></a><span class="co"># MLflow will also auto-capture conda.yaml / python_env.yaml when using log_model</span></span></code></pre></div></div>
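<p>Shelling out to pip assumes pip is installed in the training environment. A subprocess-free alternative, sketched here as our own assumption rather than an MLflow mechanism, reads installed distributions through the standard library's <code>importlib.metadata</code>. The helper name <code>frozen_requirements</code> is hypothetical, and its output is close to, but not identical with, <code>pip freeze</code> (editable and VCS installs appear as plain <code>name==version</code> pins).</p>

```python
from importlib import metadata


def frozen_requirements() -> str:
    """Return sorted 'name==version' lines for every installed distribution."""
    pins = set()
    for dist in metadata.distributions():
        name = dist.metadata.get("Name")
        if name:  # skip broken installs that have no name in their metadata
            pins.add(f"{name}=={dist.version}")
    # Case-insensitive sort so the file is stable across runs and easy to diff
    return "\n".join(sorted(pins, key=str.lower)) + "\n"
```

<p>The result can be logged without a temporary file, for example with <code>mlflow.log_text(frozen_requirements(), "environment/requirements.txt")</code>.</p>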
<hr>
</section>
</section>
<section id="sec-governance" class="level2">
<h2 class="anchored" data-anchor-id="sec-governance" id="sec-governance">Governance, Reproducibility &amp; Compliance</h2>
<section id="seed-everything" class="level3">
<h3 class="anchored" data-anchor-id="seed-everything" id="seed-everything">Seed Everything</h3>
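<p>Fix and log every random seed. CV augmentation pipelines draw from several RNGs (Python's <code>random</code>, NumPy, PyTorch, and libraries such as Albumentations), so seeding only one of them leaves the others nondeterministic. The cuDNN flags below make cuDNN convolutions deterministic at some cost in speed, but they do not cover every CUDA operation; <code>torch.use_deterministic_algorithms(True)</code> is the stricter switch.</p>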
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb21"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb21-1"><a href="#cb21-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> random</span>
<span id="cb21-2"><a href="#cb21-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb21-3"><a href="#cb21-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb21-4"><a href="#cb21-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-5"><a href="#cb21-5" aria-hidden="true" tabindex="-1"></a>SEED <span class="op">=</span> <span class="dv">42</span></span>
<span id="cb21-6"><a href="#cb21-6" aria-hidden="true" tabindex="-1"></a>random.seed(SEED)</span>
<span id="cb21-7"><a href="#cb21-7" aria-hidden="true" tabindex="-1"></a>np.random.seed(SEED)</span>
<span id="cb21-8"><a href="#cb21-8" aria-hidden="true" tabindex="-1"></a>torch.manual_seed(SEED)</span>
<span id="cb21-9"><a href="#cb21-9" aria-hidden="true" tabindex="-1"></a>torch.cuda.manual_seed_all(SEED)</span>
<span id="cb21-10"><a href="#cb21-10" aria-hidden="true" tabindex="-1"></a>torch.backends.cudnn.deterministic <span class="op">=</span> <span class="va">True</span></span>
<span id="cb21-11"><a href="#cb21-11" aria-hidden="true" tabindex="-1"></a>torch.backends.cudnn.benchmark <span class="op">=</span> <span class="va">False</span></span>
<span id="cb21-12"><a href="#cb21-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-13"><a href="#cb21-13" aria-hidden="true" tabindex="-1"></a>mlflow.log_param(<span class="st">"global_seed"</span>, SEED)</span></code></pre></div></div>
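<p>Global seeds are not the whole story: with <code>num_workers &gt; 0</code>, each DataLoader worker holds its own RNG state, so augmentation results can depend on which worker happens to load a sample. One alternative, sketched here under our own assumptions (<code>sample_rng</code>, <code>random_flip_and_crop</code>, and <code>GLOBAL_SEED</code> are hypothetical names), is to derive a fresh NumPy generator from the global seed, the epoch, and the sample index, so every sample's augmentation is reproducible regardless of worker count or load order.</p>

```python
import numpy as np

GLOBAL_SEED = 42  # hypothetical: the same value as SEED above


def sample_rng(epoch: int, index: int, seed: int = GLOBAL_SEED) -> np.random.Generator:
    """Independent, reproducible generator for one sample in one epoch."""
    # default_rng feeds the list to SeedSequence, which mixes all three integers
    return np.random.default_rng([seed, epoch, index])


def random_flip_and_crop(image: np.ndarray, rng: np.random.Generator, crop: int) -> np.ndarray:
    """Toy augmentation on an (H, W, C) array: random horizontal flip, then random crop."""
    if rng.random() < 0.5:
        image = image[:, ::-1]
    h, w = image.shape[:2]
    top = int(rng.integers(0, h - crop + 1))
    left = int(rng.integers(0, w - crop + 1))
    return image[top:top + crop, left:left + crop]
```

<p>Inside a dataset's <code>__getitem__(self, idx)</code> this would be called as <code>random_flip_and_crop(img, sample_rng(self.epoch, idx), crop=224)</code>, where <code>self.epoch</code> is a hypothetical attribute the training loop sets at the start of each epoch. With the default non-persistent workers, the dataset object is handed to fresh worker processes every epoch, so that update reaches them.</p>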
</section>
<section id="record-hardware-and-framework-versions" class="level3">
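<h3 class="anchored" data-anchor-id="record-hardware-and-framework-versions" id="record-hardware-and-framework-versions">Record Hardware and Framework Versions</h3>
<p>Log the software stack and GPU model with every run. Numerical results can differ across PyTorch, CUDA, and cuDNN versions and across GPU generations, so these values are needed to explain a metric shift that no code change accounts for.</p>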
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb22"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb22-1"><a href="#cb22-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb22-2"><a href="#cb22-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torchvision</span>
<span id="cb22-3"><a href="#cb22-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> platform</span>
<span id="cb22-4"><a href="#cb22-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb22-5"><a href="#cb22-5" aria-hidden="true" tabindex="-1"></a>mlflow.log_params({</span>
<span id="cb22-6"><a href="#cb22-6" aria-hidden="true" tabindex="-1"></a>    <span class="st">"python_version"</span>: platform.python_version(),</span>
<span id="cb22-7"><a href="#cb22-7" aria-hidden="true" tabindex="-1"></a>    <span class="st">"pytorch_version"</span>: torch.__version__,</span>
<span id="cb22-8"><a href="#cb22-8" aria-hidden="true" tabindex="-1"></a>    <span class="st">"torchvision_version"</span>: torchvision.__version__,</span>
<span id="cb22-9"><a href="#cb22-9" aria-hidden="true" tabindex="-1"></a>    <span class="st">"cuda_version"</span>: torch.version.cuda,</span>
<span id="cb22-10"><a href="#cb22-10" aria-hidden="true" tabindex="-1"></a>    <span class="st">"cudnn_version"</span>: <span class="bu">str</span>(torch.backends.cudnn.version()),</span>
<span id="cb22-11"><a href="#cb22-11" aria-hidden="true" tabindex="-1"></a>    <span class="st">"gpu_model"</span>: torch.cuda.get_device_name(<span class="dv">0</span>) <span class="cf">if</span> torch.cuda.is_available() <span class="cf">else</span> <span class="st">"CPU"</span>,</span>
<span id="cb22-12"><a href="#cb22-12" aria-hidden="true" tabindex="-1"></a>    <span class="st">"num_gpus"</span>: torch.cuda.device_count(),</span>
<span id="cb22-13"><a href="#cb22-13" aria-hidden="true" tabindex="-1"></a>})</span></code></pre></div></div>
</section>
<section id="store-model-cards-as-artifacts" class="level3">
<h3 class="anchored" data-anchor-id="store-model-cards-as-artifacts" id="store-model-cards-as-artifacts">Store Model Cards as Artifacts</h3>
<p>Document each registered model version with a model card (intended use, limitations, training data, fairness notes).</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb23"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb23-1"><a href="#cb23-1" aria-hidden="true" tabindex="-1"></a>model_card <span class="op">=</span> <span class="st">"""</span></span>
<span id="cb23-2"><a href="#cb23-2" aria-hidden="true" tabindex="-1"></a><span class="st"># Model Card: cv-resnet50-classifier v3</span></span>
<span id="cb23-3"><a href="#cb23-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb23-4"><a href="#cb23-4" aria-hidden="true" tabindex="-1"></a><span class="st">## Intended Use</span></span>
<span id="cb23-5"><a href="#cb23-5" aria-hidden="true" tabindex="-1"></a><span class="st">- Binary defect classification for manufacturing QC</span></span>
<span id="cb23-6"><a href="#cb23-6" aria-hidden="true" tabindex="-1"></a><span class="st">- Input: 224x224 RGB images</span></span>
<span id="cb23-7"><a href="#cb23-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb23-8"><a href="#cb23-8" aria-hidden="true" tabindex="-1"></a><span class="st">## Limitations</span></span>
<span id="cb23-9"><a href="#cb23-9" aria-hidden="true" tabindex="-1"></a><span class="st">- Not validated on night-time imagery</span></span>
<span id="cb23-10"><a href="#cb23-10" aria-hidden="true" tabindex="-1"></a><span class="st">- Class imbalance: defect rate ~3%</span></span>
<span id="cb23-11"><a href="#cb23-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb23-12"><a href="#cb23-12" aria-hidden="true" tabindex="-1"></a><span class="st">## Training Data</span></span>
<span id="cb23-13"><a href="#cb23-13" aria-hidden="true" tabindex="-1"></a><span class="st">- Source: Internal dataset, 2023-01 to 2024-06</span></span>
<span id="cb23-14"><a href="#cb23-14" aria-hidden="true" tabindex="-1"></a><span class="st">- 85k training / 15k validation images</span></span>
<span id="cb23-15"><a href="#cb23-15" aria-hidden="true" tabindex="-1"></a><span class="st">- SHA256: a3f8c12...</span></span>
<span id="cb23-16"><a href="#cb23-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb23-17"><a href="#cb23-17" aria-hidden="true" tabindex="-1"></a><span class="st">## Performance</span></span>
<span id="cb23-18"><a href="#cb23-18" aria-hidden="true" tabindex="-1"></a><span class="st">- val/top1_acc: 96.4%</span></span>
<span id="cb23-19"><a href="#cb23-19" aria-hidden="true" tabindex="-1"></a><span class="st">- p95_latency_ms: 12.3ms (A100)</span></span>
<span id="cb23-20"><a href="#cb23-20" aria-hidden="true" tabindex="-1"></a><span class="st">"""</span></span>
<span id="cb23-21"><a href="#cb23-21" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb23-22"><a href="#cb23-22" aria-hidden="true" tabindex="-1"></a><span class="cf">with</span> <span class="bu">open</span>(<span class="st">"/tmp/MODEL_CARD.md"</span>, <span class="st">"w"</span>) <span class="im">as</span> f:</span>
<span id="cb23-23"><a href="#cb23-23" aria-hidden="true" tabindex="-1"></a>    f.write(model_card)</span>
<span id="cb23-24"><a href="#cb23-24" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb23-25"><a href="#cb23-25" aria-hidden="true" tabindex="-1"></a>mlflow.log_artifact(<span class="st">"/tmp/MODEL_CARD.md"</span>, artifact_path<span class="op">=</span><span class="st">"governance"</span>)</span></code></pre></div></div>
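<p>The card's <code>SHA256</code> field, like the "Dataset version/hash logged" checklist item, needs a stable fingerprint of the training data. A minimal sketch, assuming the dataset is a directory of files (the function name and hashing scheme are our own, not an MLflow feature), hashes each file's relative path, size, and contents in a fixed order, so renaming, adding, or editing any file changes the result.</p>

```python
import hashlib
from pathlib import Path


def dataset_fingerprint(root: str) -> str:
    """SHA256 over every file under `root`: relative path, size, and contents, in sorted order."""
    root_path = Path(root)
    files = [p for p in root_path.rglob("*") if p.is_file()]
    # Sort by POSIX-style relative path so the order is the same on every OS
    files.sort(key=lambda p: p.relative_to(root_path).as_posix())
    digest = hashlib.sha256()
    for path in files:
        digest.update(path.relative_to(root_path).as_posix().encode("utf-8") + b"\0")
        # Size prefix keeps file boundaries unambiguous in the combined stream
        digest.update(str(path.stat().st_size).encode("ascii") + b"\0")
        with path.open("rb") as f:
            for chunk in iter(lambda: f.read(1 << 20), b""):  # stream in 1 MiB chunks
                digest.update(chunk)
    return digest.hexdigest()
```

<p>The value could then be recorded with <code>mlflow.set_tag("dataset_sha256", dataset_fingerprint("data/train"))</code> (tag name and path are hypothetical) and copied into the model card.</p>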
<hr>
</section>
</section>
<section id="sec-performance" class="level2">
<h2 class="anchored" data-anchor-id="sec-performance" id="sec-performance">Performance &amp; Scalability</h2>
<section id="avoid-logging-inside-the-training-loop" class="level3">
<h3 class="anchored" data-anchor-id="avoid-logging-inside-the-training-loop" id="avoid-logging-inside-the-training-loop">Avoid Logging Inside the Training Loop</h3>
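<p>Calling <code>mlflow.log_metric</code> on every step sends one write to the tracking backend per step (an HTTP request when a remote server is used), which slows training and bloats the metric history. Throttle logging to every N steps, or aggregate values and log them in batches.</p>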
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb24"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb24-1"><a href="#cb24-1" aria-hidden="true" tabindex="-1"></a><span class="co"># ❌ Too frequent — logs every step</span></span>
<span id="cb24-2"><a href="#cb24-2" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> step, batch <span class="kw">in</span> <span class="bu">enumerate</span>(train_loader):</span>
<span id="cb24-3"><a href="#cb24-3" aria-hidden="true" tabindex="-1"></a>    loss <span class="op">=</span> train_step(batch)</span>
<span id="cb24-4"><a href="#cb24-4" aria-hidden="true" tabindex="-1"></a>    mlflow.log_metric(<span class="st">"train/loss"</span>, loss, step<span class="op">=</span>step)  <span class="co"># Bottleneck!</span></span>
<span id="cb24-5"><a href="#cb24-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb24-6"><a href="#cb24-6" aria-hidden="true" tabindex="-1"></a><span class="co"># ✅ Log every N steps</span></span>
<span id="cb24-7"><a href="#cb24-7" aria-hidden="true" tabindex="-1"></a>LOG_INTERVAL <span class="op">=</span> <span class="dv">50</span></span>
<span id="cb24-8"><a href="#cb24-8" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> step, batch <span class="kw">in</span> <span class="bu">enumerate</span>(train_loader):</span>
<span id="cb24-9"><a href="#cb24-9" aria-hidden="true" tabindex="-1"></a>    loss <span class="op">=</span> train_step(batch)</span>
<span id="cb24-10"><a href="#cb24-10" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> step <span class="op">%</span> LOG_INTERVAL <span class="op">==</span> <span class="dv">0</span>:</span>
<span id="cb24-11"><a href="#cb24-11" aria-hidden="true" tabindex="-1"></a>        mlflow.log_metric(<span class="st">"train/loss"</span>, loss, step<span class="op">=</span>step)</span></code></pre></div></div>
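<p>Logging every N steps discards the values in between. A variant that keeps them averages each metric over the interval and sends one batched call per flush. The class below is a hypothetical helper, not part of MLflow; it takes the logging function as a parameter so it can wrap <code>mlflow.log_metrics</code>, which accepts a dict of metrics and a <code>step</code>.</p>

```python
from collections import defaultdict
from typing import Callable, Dict


class ThrottledMetricLogger:
    """Average metrics over `interval` steps and emit them in one batched call.

    `log_fn(metrics, step=...)` does the actual logging, e.g. mlflow.log_metrics.
    """

    def __init__(self, log_fn: Callable[..., None], interval: int = 50):
        self.log_fn = log_fn
        self.interval = interval
        self._sums: Dict[str, float] = defaultdict(float)
        self._counts: Dict[str, int] = defaultdict(int)
        self._pending_steps = 0

    def update(self, metrics: Dict[str, float], step: int) -> None:
        """Record one step's values (plain floats, e.g. loss.item())."""
        for name, value in metrics.items():
            self._sums[name] += float(value)
            self._counts[name] += 1
        self._pending_steps += 1
        if self._pending_steps >= self.interval:
            self.flush(step)

    def flush(self, step: int) -> None:
        """Log the mean of everything recorded since the last flush, then reset."""
        if self._pending_steps == 0:
            return
        means = {name: self._sums[name] / self._counts[name] for name in self._sums}
        self.log_fn(means, step=step)
        self._sums.clear()
        self._counts.clear()
        self._pending_steps = 0
```

<p>In the loop, create <code>ThrottledMetricLogger(mlflow.log_metrics, interval=LOG_INTERVAL)</code>, call <code>update({"train/loss": loss.item()}, step)</code> each step, and call <code>flush(step)</code> once after the loop so the last partial interval is not lost. Note that <code>.item()</code> synchronizes with the GPU on every call.</p>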
</section>
<section id="use-autologging-selectively" class="level3">
<h3 class="anchored" data-anchor-id="use-autologging-selectively" id="use-autologging-selectively">Use Autologging Selectively</h3>
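<p><code>mlflow.pytorch.autolog()</code> is convenient, but it hooks into PyTorch Lightning's <code>Trainer.fit()</code>, so a hand-written training loop gets little from it, and what it does capture follows the trainer's metric names and logging frequency rather than yours. Use it as a quick baseline during exploration, and log explicitly for runs you intend to compare or promote.</p>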
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb25"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb25-1"><a href="#cb25-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Exploration: enable autolog</span></span>
<span id="cb25-2"><a href="#cb25-2" aria-hidden="true" tabindex="-1"></a>mlflow.pytorch.autolog(log_every_n_epoch<span class="op">=</span><span class="dv">1</span>, log_models<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb25-3"><a href="#cb25-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb25-4"><a href="#cb25-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Production: disable autolog, log explicitly</span></span>
<span id="cb25-5"><a href="#cb25-5" aria-hidden="true" tabindex="-1"></a>mlflow.pytorch.autolog(disable<span class="op">=</span><span class="va">True</span>)</span></code></pre></div></div>
</section>
<section id="backend-storage-recommendations" class="level3">
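<h3 class="anchored" data-anchor-id="backend-storage-recommendations" id="backend-storage-recommendations">Backend Storage Recommendations</h3>
<p>MLflow stores two kinds of data separately: the tracking backend holds run metadata (parameters, metrics, tags, registry entries), while the artifact store holds files such as model weights, images, and logged configs. Choose each according to team scale:</p>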
<div id="tbl-backend" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-tbl figure">
<figcaption class="quarto-float-caption-top quarto-float-caption quarto-float-tbl" id="tbl-backend-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Table&nbsp;4: Backend storage options by team scale
</figcaption>
<div aria-describedby="tbl-backend-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<table class="caption-top table">
<colgroup>
<col style="width: 17%">
<col style="width: 41%">
<col style="width: 41%">
</colgroup>
<thead>
<tr class="header">
<th>Scale</th>
<th>Tracking Backend</th>
<th>Artifact Store</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>Local/solo</td>
<td>Local filesystem</td>
<td>Local filesystem</td>
</tr>
<tr class="even">
<td>Team</td>
<td>PostgreSQL + MLflow Server</td>
<td>S3 / GCS / Azure Blob</td>
</tr>
<tr class="odd">
<td>Enterprise</td>
<td>Managed MLflow (Databricks)</td>
<td>Object store + CDN</td>
</tr>
</tbody>
</table>
</div>
</figure>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb26"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb26-1"><a href="#cb26-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> mlflow</span>
<span id="cb26-2"><a href="#cb26-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb26-3"><a href="#cb26-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Point to a remote tracking server</span></span>
<span id="cb26-4"><a href="#cb26-4" aria-hidden="true" tabindex="-1"></a>mlflow.set_tracking_uri(<span class="st">"http://your-mlflow-server:5000"</span>)</span></code></pre></div></div>
<hr>
</section>
</section>
<section id="sec-checklists" class="level2">
<h2 class="anchored" data-anchor-id="sec-checklists" id="sec-checklists">Quick Reference Checklists</h2>
<section id="per-run-checklist" class="level3">
<h3 class="anchored" data-anchor-id="per-run-checklist" id="per-run-checklist">Per-Run Checklist</h3>
<div class="callout callout-style-default callout-important no-icon callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon no-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Important</span>✅ Before closing a run, verify:
</div>
</div>
<div class="callout-body-container callout-body">
<ul class="task-list">
<li><label><input type="checkbox">Experiment name reflects the research objective</label></li>
<li><label><input type="checkbox">All hyperparameters logged (including implicit/augmentation params)</label></li>
<li><label><input type="checkbox">Git commit SHA tagged</label></li>
<li><label><input type="checkbox">Training config YAML logged as artifact</label></li>
<li><label><input type="checkbox">Dataset version/hash logged</label></li>
<li><label><input type="checkbox">Model logged with signature and input example</label></li>
<li><label><input type="checkbox">Evaluation metrics logged per epoch + summary</label></li>
<li><label><input type="checkbox">Sample predictions visualised and logged</label></li>
<li><label><input type="checkbox">Hardware/framework versions logged</label></li>
<li><label><input type="checkbox">Global seeds logged and fixed</label></li>
</ul>
</div>
</div>
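<p>Parts of this checklist can be verified automatically before a run is closed or promoted. A minimal sketch, assuming the parameter names used earlier in this guide plus a hypothetical <code>dataset_sha256</code> tag, inspects the <code>params</code> and <code>tags</code> dicts exposed by <code>client.get_run(run_id).data</code>. <code>mlflow.source.git.commit</code> is the system tag MLflow typically sets when a run starts inside a Git repository; the helper itself is our own.</p>

```python
from typing import Dict, Iterable, List

# Keys assumed by this sketch: the three params appear in this guide's examples,
# "dataset_sha256" is a hypothetical tag name.
REQUIRED_PARAMS = ("global_seed", "python_version", "pytorch_version")
REQUIRED_TAGS = ("mlflow.source.git.commit", "dataset_sha256")


def missing_run_fields(params: Dict[str, str], tags: Dict[str, str],
                       required_params: Iterable[str] = REQUIRED_PARAMS,
                       required_tags: Iterable[str] = REQUIRED_TAGS) -> List[str]:
    """Return the required params/tags a run has not logged (empty list = complete)."""
    missing = [f"param:{key}" for key in required_params if key not in params]
    missing += [f"tag:{key}" for key in required_tags if key not in tags]
    return missing
```

<p>For example, <code>missing_run_fields(run.data.params, run.data.tags)</code> on a fetched run returns the unmet items, and a CI job can fail whenever the list is non-empty.</p>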
</section>
<section id="per-promotion-checklist" class="level3">
<h3 class="anchored" data-anchor-id="per-promotion-checklist" id="per-promotion-checklist">Per-Promotion Checklist</h3>
<div class="callout callout-style-default callout-important no-icon callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon no-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Important</span>✅ Before promoting to Production, verify:
</div>
</div>
<div class="callout-body-container callout-body">
<ul class="task-list">
<li><label><input type="checkbox">Challenger vs champion comparison passed</label></li>
<li><label><input type="checkbox">All metric gates satisfied (mAP, latency, recall)</label></li>
<li><label><input type="checkbox">ONNX export tested and logged</label></li>
<li><label><input type="checkbox">Model card updated</label></li>
<li><label><input type="checkbox">Old production version archived</label></li>
<li><label><input type="checkbox"><code>requirements.txt</code> artifact present</label></li>
</ul>
</div>
</div>


</section>
</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Grounding DINO Implementation Guide]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/dino/grounding-dino/implementation/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/dino/grounding-dino/implementation/</guid>
      <pubDate>Sun, 21 Dec 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>tutorial</category>
      <category>intermediate</category>
      <content:encoded><![CDATA[






<section id="grounding-dino-implementation-guide" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/dino/grounding-dino/implementation/groundingdino.png" class="img-fluid"></p>
<p>Grounding DINO, released by IDEA Research in 2023, is an open-set object detector that combines language understanding with visual detection. Given an image and a natural-language description, it localizes the objects the text refers to, including categories outside any fixed training label set, which makes it well suited to zero-shot object detection.</p>
<section id="key-features" class="level2">
<h2 class="anchored" data-anchor-id="key-features" id="key-features">Key Features</h2>
<ul>
<li><strong>Open-vocabulary detection</strong>: Detect objects using free-form text descriptions</li>
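<li><strong>Zero-shot capability</strong>: Detects new categories without task-specific fine-tuning, although fine-tuning on in-domain data can still improve accuracy</li>
<li><strong>High accuracy</strong>: The paper reports up to 52.5 AP on COCO in zero-shot transfer, i.e. without training on COCO data</li>
<li><strong>Flexible integration</strong>: Its boxes can prompt downstream models, for example Segment Anything (SAM) for segmentation, as in Grounded-SAM</li>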
</ul>
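<p>Because the vocabulary is plain text, a fixed set of categories is usually passed as one caption with names separated by " . ", the convention the GroundingDINO repository recommends and the format of <code>TEXT_PROMPT</code> in the detection example. A hypothetical helper that normalizes a class list into that form (lowercased, de-duplicated, ending with a period) might look like this:</p>

```python
from typing import Iterable


def build_caption(class_names: Iterable[str]) -> str:
    """Join category names into one caption such as 'cat . dog . person .'."""
    cleaned = []
    for name in class_names:
        name = name.strip().lower().rstrip(".").strip()
        if name and name not in cleaned:  # drop empties and duplicates, keep order
            cleaned.append(name)
    return " . ".join(cleaned) + " ." if cleaned else ""
```

<p>For example, <code>build_caption(["Cat", "dog", " person "])</code> returns <code>"cat . dog . person ."</code>.</p>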
</section>
<section id="architecture-components" class="level2">
<h2 class="anchored" data-anchor-id="architecture-components" id="architecture-components">Architecture Components</h2>
<section id="vision-backbone" class="level3">
<h3 class="anchored" data-anchor-id="vision-backbone" id="vision-backbone">Vision Backbone</h3>
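<p>Grounding DINO uses a Swin Transformer as the vision backbone to extract multi-scale visual features from input images. The <code>groundingdino_swint_ogc.pth</code> checkpoint used in this guide pairs with the Swin-T configuration.</p>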
</section>
<section id="language-backbone" class="level3">
<h3 class="anchored" data-anchor-id="language-backbone" id="language-backbone">Language Backbone</h3>
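<p>A BERT encoder (<code>bert-base-uncased</code> in the repository's configs) turns the text prompt into token-level features, so each detected box can later be matched to specific words in the prompt.</p>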
</section>
<section id="feature-enhancer" class="level3">
<h3 class="anchored" data-anchor-id="feature-enhancer" id="feature-enhancer">Feature Enhancer</h3>
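<p>The feature enhancer fuses the two streams. Each enhancer layer applies self-attention within each modality (deformable self-attention for the image features) followed by cross-attention in both directions, image-to-text and text-to-image, so each modality's features are conditioned on the other.</p>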
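<p>The bidirectional fusion idea can be shown with plain NumPy. This is a heavily simplified sketch, not the model's actual layer: single-head attention with no learned query/key/value projections, and none of the deformable attention, normalization, or feed-forward blocks the real enhancer uses. Image tokens attend to text tokens and vice versa, and each result is added back residually; all function names are hypothetical.</p>

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def cross_attend(queries: np.ndarray, context: np.ndarray) -> np.ndarray:
    """Single-head scaled dot-product attention of `queries` (N_q, d) over `context` (N_c, d)."""
    d = queries.shape[-1]
    weights = softmax(queries @ context.T / np.sqrt(d), axis=-1)  # (N_q, N_c), rows sum to 1
    return weights @ context


def fuse(image_tokens: np.ndarray, text_tokens: np.ndarray):
    """Image attends to text and text attends to image, each with a residual connection."""
    image_out = image_tokens + cross_attend(image_tokens, text_tokens)
    text_out = text_tokens + cross_attend(text_tokens, image_tokens)
    return image_out, text_out
```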
</section>
<section id="language-guided-query-selection" class="level3">
<h3 class="anchored" data-anchor-id="language-guided-query-selection" id="language-guided-query-selection">Language-Guided Query Selection</h3>
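<p>Rather than initializing decoder queries independently of the prompt, the model scores every image token against the text tokens and keeps the highest-scoring ones (900 in the paper) to initialize the decoder queries, so detection starts from the image regions the text refers to.</p>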
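<p>The selection step amounts to a dot-product similarity, a max over text tokens, and a top-K. A NumPy sketch under simplifying assumptions (already-flattened features of matching width; the function name is hypothetical):</p>

```python
import numpy as np


def language_guided_query_selection(image_tokens: np.ndarray,
                                    text_tokens: np.ndarray,
                                    num_queries: int = 900) -> np.ndarray:
    """Indices of the image tokens that best match any text token.

    image_tokens: (N_img, d) flattened multi-scale image features
    text_tokens:  (N_txt, d) text features from the language backbone
    """
    similarity = image_tokens @ text_tokens.T         # (N_img, N_txt) dot products
    best_match = similarity.max(axis=1)                # strongest text match per image token
    order = np.argsort(-best_match, kind="stable")    # descending; ties keep original order
    return order[:num_queries]                         # at most num_queries indices
```

<p>In the paper's design, following DINO, the selected tokens initialize the positional part of each query (as dynamic anchor boxes), while the content part of each query is learned.</p>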
</section>
<section id="cross-modality-decoder" class="level3">
<h3 class="anchored" data-anchor-id="cross-modality-decoder" id="cross-modality-decoder">Cross-Modality Decoder</h3>
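<p>A transformer decoder refines the selected queries. Each layer applies self-attention among the queries, cross-attention to the image features, and cross-attention to the text features; each final query yields a bounding box and a similarity score against every text token.</p>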
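<p>At inference time, each query's box is kept or discarded based on those token scores, which is what <code>BOX_THRESHOLD</code> and <code>TEXT_THRESHOLD</code> in the detection example control. A simplified NumPy sketch, assuming whole-word tokens (the real code works on BERT sub-word tokens and decodes phrases through the tokenizer), with a hypothetical function name: a query survives if its best token score exceeds the box threshold, and its label is built from every token above the text threshold.</p>

```python
import numpy as np


def filter_predictions(token_scores: np.ndarray, boxes: np.ndarray, tokens: list,
                       box_threshold: float = 0.35, text_threshold: float = 0.25):
    """Keep confident queries and label each with its above-threshold tokens.

    token_scores: (num_queries, num_tokens) sigmoid scores against each prompt token
    boxes:        (num_queries, 4) predicted boxes
    tokens:       prompt tokens, one per column of token_scores (simplified: whole words)
    """
    keep = token_scores.max(axis=1) > box_threshold  # best token must clear the box threshold
    results = []
    for scores, box in zip(token_scores[keep], boxes[keep]):
        # Separator tokens "." are dropped from the phrase
        phrase = " ".join(t for t, s in zip(tokens, scores) if s > text_threshold and t != ".")
        results.append((box, float(scores.max()), phrase))
    return results
```

<p>Raising the text threshold tends to shorten phrases to the most confident words; raising the box threshold removes whole detections.</p>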
</section>
</section>
<section id="installation" class="level2">
<h2 class="anchored" data-anchor-id="installation" id="installation">Installation</h2>
<section id="prerequisites" class="level3">
<h3 class="anchored" data-anchor-id="prerequisites" id="prerequisites">Prerequisites</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Create a virtual environment</span></span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="ex">python</span> <span class="at">-m</span> venv grounding_dino_env</span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a><span class="bu">source</span> grounding_dino_env/bin/activate  <span class="co"># On Windows: grounding_dino_env\Scripts\activate</span></span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Install PyTorch (adjust for your CUDA version)</span></span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install torch torchvision <span class="at">--index-url</span> https://download.pytorch.org/whl/cu118</span></code></pre></div></div>
</section>
<section id="install-grounding-dino" class="level3">
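<h3 class="anchored" data-anchor-id="install-grounding-dino" id="install-grounding-dino">Install Grounding DINO</h3>
<p>The editable install compiles a custom CUDA operator. According to the repository README, <code>CUDA_HOME</code> should point at your CUDA toolkit before installing; without a CUDA environment the package is built in CPU-only mode.</p>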
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb2"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Clone the repository</span></span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a><span class="fu">git</span> clone https://github.com/IDEA-Research/GroundingDINO.git</span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a><span class="bu">cd</span> GroundingDINO</span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Install the package in editable mode, along with its dependencies</span></span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install <span class="at">-e</span> .</span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a><span class="co"># Alternative: Install from PyPI (if available)</span></span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install groundingdino</span></code></pre></div></div>
</section>
<section id="download-model-weights" class="level3">
<h3 class="anchored" data-anchor-id="download-model-weights" id="download-model-weights">Download Model Weights</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb3"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Download pre-trained weights</span></span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a><span class="fu">mkdir</span> weights</span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a><span class="bu">cd</span> weights</span>
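<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a><span class="fu">wget</span> https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth</span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a><span class="bu">cd</span> ..  <span class="co"># back to the repository root; the Python examples use paths relative to it</span></span></code></pre></div></div>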
</section>
</section>
<section id="basic-implementation" class="level2">
<h2 class="anchored" data-anchor-id="basic-implementation" id="basic-implementation">Basic Implementation</h2>
<section id="simple-detection-example" class="level3">
<h3 class="anchored" data-anchor-id="simple-detection-example" id="simple-detection-example">Simple Detection Example</h3>
<div id="d51a5274" class="cell" data-execution_count="1">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> PIL <span class="im">import</span> Image</span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> groundingdino.util.inference <span class="im">import</span> load_model, load_image, predict, annotate</span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Load model</span></span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> load_model(</span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a>    <span class="st">"groundingdino/config/GroundingDINO_SwinT_OGC.py"</span>,</span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a>    <span class="st">"weights/groundingdino_swint_ogc.pth"</span></span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a><span class="co"># Load image</span></span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a>image_source, image <span class="op">=</span> load_image(<span class="st">"path/to/your/image.jpg"</span>)</span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-14"><a href="#cb4-14" aria-hidden="true" tabindex="-1"></a><span class="co"># Define text prompt</span></span>
<span id="cb4-15"><a href="#cb4-15" aria-hidden="true" tabindex="-1"></a>TEXT_PROMPT <span class="op">=</span> <span class="st">"cat . dog . person"</span></span>
<span id="cb4-16"><a href="#cb4-16" aria-hidden="true" tabindex="-1"></a>BOX_THRESHOLD <span class="op">=</span> <span class="fl">0.35</span></span>
<span id="cb4-17"><a href="#cb4-17" aria-hidden="true" tabindex="-1"></a>TEXT_THRESHOLD <span class="op">=</span> <span class="fl">0.25</span></span>
<span id="cb4-18"><a href="#cb4-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-19"><a href="#cb4-19" aria-hidden="true" tabindex="-1"></a><span class="co"># Run inference</span></span>
<span id="cb4-20"><a href="#cb4-20" aria-hidden="true" tabindex="-1"></a>boxes, logits, phrases <span class="op">=</span> predict(</span>
<span id="cb4-21"><a href="#cb4-21" aria-hidden="true" tabindex="-1"></a>    model<span class="op">=</span>model,</span>
<span id="cb4-22"><a href="#cb4-22" aria-hidden="true" tabindex="-1"></a>    image<span class="op">=</span>image,</span>
<span id="cb4-23"><a href="#cb4-23" aria-hidden="true" tabindex="-1"></a>    caption<span class="op">=</span>TEXT_PROMPT,</span>
<span id="cb4-24"><a href="#cb4-24" aria-hidden="true" tabindex="-1"></a>    box_threshold<span class="op">=</span>BOX_THRESHOLD,</span>
<span id="cb4-25"><a href="#cb4-25" aria-hidden="true" tabindex="-1"></a>    text_threshold<span class="op">=</span>TEXT_THRESHOLD</span>
<span id="cb4-26"><a href="#cb4-26" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb4-27"><a href="#cb4-27" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-28"><a href="#cb4-28" aria-hidden="true" tabindex="-1"></a><span class="co"># Visualize results</span></span>
<span id="cb4-29"><a href="#cb4-29" aria-hidden="true" tabindex="-1"></a>annotated_frame <span class="op">=</span> annotate(</span>
<span id="cb4-30"><a href="#cb4-30" aria-hidden="true" tabindex="-1"></a>    image_source<span class="op">=</span>image_source,</span>
<span id="cb4-31"><a href="#cb4-31" aria-hidden="true" tabindex="-1"></a>    boxes<span class="op">=</span>boxes,</span>
<span id="cb4-32"><a href="#cb4-32" aria-hidden="true" tabindex="-1"></a>    logits<span class="op">=</span>logits,</span>
<span id="cb4-33"><a href="#cb4-33" aria-hidden="true" tabindex="-1"></a>    phrases<span class="op">=</span>phrases</span>
<span id="cb4-34"><a href="#cb4-34" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb4-35"><a href="#cb4-35" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-36"><a href="#cb4-36" aria-hidden="true" tabindex="-1"></a><span class="co"># Save or display (annotate returns a BGR array, so flip channels to RGB for PIL)</span></span>
<span id="cb4-37"><a href="#cb4-37" aria-hidden="true" tabindex="-1"></a>Image.fromarray(annotated_frame[..., ::<span class="op">-</span><span class="dv">1</span>].copy()).save(<span class="st">"output.jpg"</span>)</span></code></pre></div></div>
</div>
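<p>The <code>boxes</code> returned by <code>predict</code> are normalized <code>[cx, cy, w, h]</code> values between 0 and 1; <code>annotate</code> multiplies them by the image width and height before drawing. For cropping or exporting detections you usually want absolute corner coordinates instead. The plain-Python sketch below does that conversion for a single box; <code>cxcywh_to_xyxy_pixels</code> is a hypothetical helper, not part of the <code>groundingdino</code> package.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">def cxcywh_to_xyxy_pixels(box, image_width, image_height):
    """Convert one normalized (cx, cy, w, h) box to pixel (x0, y0, x1, y1).

    Hypothetical helper for illustration; results are clamped to the image.
    """
    cx, cy, w, h = box

    def clamp(value, upper):
        # Keep coordinates inside [0, upper] so boxes touching the border stay valid
        return min(max(value, 0.0), upper)

    x0 = clamp((cx - w / 2) * image_width, image_width)
    y0 = clamp((cy - h / 2) * image_height, image_height)
    x1 = clamp((cx + w / 2) * image_width, image_width)
    y1 = clamp((cy + h / 2) * image_height, image_height)
    return x0, y0, x1, y1


# A box centred in a 640x480 image, 20% of the width and 50% of the height
print(cxcywh_to_xyxy_pixels((0.5, 0.5, 0.2, 0.5), 640, 480))</code></pre></div>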
</section>
<section id="custom-implementation" class="level3">
<h3 class="anchored" data-anchor-id="custom-implementation" id="custom-implementation">Custom Implementation</h3>
<div id="bece3280" class="cell" data-execution_count="2">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> groundingdino.util <span class="im">import</span> box_ops</span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> groundingdino.models <span class="im">import</span> build_model</span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> groundingdino.util.slconfig <span class="im">import</span> SLConfig</span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> groundingdino.util.utils <span class="im">import</span> clean_state_dict, get_phrases_from_posmap</span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> load_custom_model(config_path, checkpoint_path, device<span class="op">=</span><span class="st">'cuda'</span>):</span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Load Grounding DINO model with custom configuration"""</span></span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a>    args <span class="op">=</span> SLConfig.fromfile(config_path)</span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a>    args.device <span class="op">=</span> device</span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> build_model(args)</span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a>    checkpoint <span class="op">=</span> torch.load(checkpoint_path, map_location<span class="op">=</span><span class="st">'cpu'</span>)</span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a>    model.load_state_dict(clean_state_dict(checkpoint[<span class="st">'model'</span>]), strict<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a>    model.<span class="bu">eval</span>()</span>
<span id="cb5-16"><a href="#cb5-16" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> model.to(device)</span>
<span id="cb5-17"><a href="#cb5-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-18"><a href="#cb5-18" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> preprocess_caption(caption):</span>
<span id="cb5-19"><a href="#cb5-19" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Normalize caption text for model input"""</span></span>
<span id="cb5-20"><a href="#cb5-20" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Lowercase, trim, and ensure the caption ends with a period;</span></span>
<span id="cb5-21"><a href="#cb5-21" aria-hidden="true" tabindex="-1"></a>    <span class="co"># separate multiple object names with " . " (e.g. "cat . dog .")</span></span>
<span id="cb5-22"><a href="#cb5-22" aria-hidden="true" tabindex="-1"></a>    caption <span class="op">=</span> caption.lower().strip()</span>
<span id="cb5-23"><a href="#cb5-23" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> <span class="kw">not</span> caption.endswith(<span class="st">'.'</span>):</span>
<span id="cb5-24"><a href="#cb5-24" aria-hidden="true" tabindex="-1"></a>        caption <span class="op">=</span> caption <span class="op">+</span> <span class="st">'.'</span></span>
<span id="cb5-25"><a href="#cb5-25" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> caption</span>
<span id="cb5-26"><a href="#cb5-26" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-27"><a href="#cb5-27" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> detect_objects(model, image_tensor, caption, box_threshold<span class="op">=</span><span class="fl">0.35</span>, text_threshold<span class="op">=</span><span class="fl">0.25</span>):</span>
<span id="cb5-28"><a href="#cb5-28" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb5-29"><a href="#cb5-29" aria-hidden="true" tabindex="-1"></a><span class="co">    Run object detection</span></span>
<span id="cb5-30"><a href="#cb5-30" aria-hidden="true" tabindex="-1"></a><span class="co">    </span></span>
<span id="cb5-31"><a href="#cb5-31" aria-hidden="true" tabindex="-1"></a><span class="co">    Args:</span></span>
<span id="cb5-32"><a href="#cb5-32" aria-hidden="true" tabindex="-1"></a><span class="co">        model: Grounding DINO model</span></span>
<span id="cb5-33"><a href="#cb5-33" aria-hidden="true" tabindex="-1"></a><span class="co">        image_tensor: Preprocessed image tensor [C, H, W]</span></span>
<span id="cb5-34"><a href="#cb5-34" aria-hidden="true" tabindex="-1"></a><span class="co">        caption: Text description of objects to detect</span></span>
<span id="cb5-35"><a href="#cb5-35" aria-hidden="true" tabindex="-1"></a><span class="co">        box_threshold: Minimum best-token score for a query to be kept as a box</span></span>
<span id="cb5-36"><a href="#cb5-36" aria-hidden="true" tabindex="-1"></a><span class="co">        text_threshold: Minimum token score for a token to appear in a box's phrase</span></span>
<span id="cb5-37"><a href="#cb5-37" aria-hidden="true" tabindex="-1"></a><span class="co">    </span></span>
<span id="cb5-38"><a href="#cb5-38" aria-hidden="true" tabindex="-1"></a><span class="co">    Returns:</span></span>
<span id="cb5-39"><a href="#cb5-39" aria-hidden="true" tabindex="-1"></a><span class="co">        boxes: Detected bounding boxes in normalized [cx, cy, w, h] format (0-1)</span></span>
<span id="cb5-40"><a href="#cb5-40" aria-hidden="true" tabindex="-1"></a><span class="co">        scores: Confidence score for each box</span></span>
<span id="cb5-41"><a href="#cb5-41" aria-hidden="true" tabindex="-1"></a><span class="co">        phrases: Text phrase for each box</span></span>
<span id="cb5-42"><a href="#cb5-42" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb5-43"><a href="#cb5-43" aria-hidden="true" tabindex="-1"></a>    caption <span class="op">=</span> preprocess_caption(caption)</span>
<span id="cb5-44"><a href="#cb5-44" aria-hidden="true" tabindex="-1"></a>    device <span class="op">=</span> <span class="bu">next</span>(model.parameters()).device</span>
<span id="cb5-45"><a href="#cb5-45" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-46"><a href="#cb5-46" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> torch.no_grad():</span>
<span id="cb5-47"><a href="#cb5-47" aria-hidden="true" tabindex="-1"></a>        outputs <span class="op">=</span> model(image_tensor[<span class="va">None</span>].to(device), captions<span class="op">=</span>[caption])</span>
<span id="cb5-48"><a href="#cb5-48" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-49"><a href="#cb5-49" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Extract predictions: each query is scored against every text token</span></span>
<span id="cb5-50"><a href="#cb5-50" aria-hidden="true" tabindex="-1"></a>    logits <span class="op">=</span> outputs[<span class="st">"pred_logits"</span>].sigmoid()[<span class="dv">0</span>].cpu()  <span class="co"># [num_queries, 256]</span></span>
<span id="cb5-51"><a href="#cb5-51" aria-hidden="true" tabindex="-1"></a>    boxes <span class="op">=</span> outputs[<span class="st">"pred_boxes"</span>][<span class="dv">0</span>].cpu()  <span class="co"># [num_queries, 4]</span></span>
<span id="cb5-52"><a href="#cb5-52" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-53"><a href="#cb5-53" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Keep queries whose best token score exceeds box_threshold</span></span>
<span id="cb5-54"><a href="#cb5-54" aria-hidden="true" tabindex="-1"></a>    max_logits, _ <span class="op">=</span> logits.<span class="bu">max</span>(dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb5-55"><a href="#cb5-55" aria-hidden="true" tabindex="-1"></a>    mask <span class="op">=</span> max_logits <span class="op">&gt;</span> box_threshold</span>
<span id="cb5-56"><a href="#cb5-56" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-57"><a href="#cb5-57" aria-hidden="true" tabindex="-1"></a>    boxes <span class="op">=</span> boxes[mask]</span>
<span id="cb5-58"><a href="#cb5-58" aria-hidden="true" tabindex="-1"></a>    logits <span class="op">=</span> logits[mask]</span>
<span id="cb5-59"><a href="#cb5-59" aria-hidden="true" tabindex="-1"></a>    scores <span class="op">=</span> max_logits[mask]</span>
<span id="cb5-60"><a href="#cb5-60" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-61"><a href="#cb5-61" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Map tokens scoring above text_threshold back to words in the caption</span></span>
<span id="cb5-62"><a href="#cb5-62" aria-hidden="true" tabindex="-1"></a>    tokenizer <span class="op">=</span> model.tokenizer</span>
<span id="cb5-63"><a href="#cb5-63" aria-hidden="true" tabindex="-1"></a>    tokenized <span class="op">=</span> tokenizer(caption)</span>
<span id="cb5-64"><a href="#cb5-64" aria-hidden="true" tabindex="-1"></a>    phrases <span class="op">=</span> [</span>
<span id="cb5-65"><a href="#cb5-65" aria-hidden="true" tabindex="-1"></a>        get_phrases_from_posmap(logit <span class="op">&gt;</span> text_threshold, tokenized, tokenizer).replace(<span class="st">'.'</span>, <span class="st">''</span>)</span>
<span id="cb5-66"><a href="#cb5-66" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> logit <span class="kw">in</span> logits</span>
<span id="cb5-67"><a href="#cb5-67" aria-hidden="true" tabindex="-1"></a>    ]</span>
<span id="cb5-68"><a href="#cb5-68" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-69"><a href="#cb5-69" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> boxes, scores, phrases</span></code></pre></div></div>
</div>
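<p>The two thresholds act at different stages. Every query is scored against every caption token: <code>box_threshold</code> decides whether a query becomes a detection at all (using its best token score), while <code>text_threshold</code> decides which tokens make up that detection's label. The toy version below reproduces this logic on plain Python lists. It is a simplification: <code>select_detections</code> is a hypothetical helper, and it assumes whole-word tokens, whereas the real model works on BERT sub-word tokens that <code>get_phrases_from_posmap</code> maps back to text.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">def select_detections(query_token_scores, tokens, box_threshold=0.35, text_threshold=0.25):
    """Toy two-threshold filter over per-query token scores (hypothetical helper).

    query_token_scores: one list of sigmoid scores per query, aligned with tokens
    tokens: the caption split into whole-word tokens, e.g. ["cat", ".", "dog", "."]
    Returns a (query_index, score, phrase) tuple for every kept query.
    """
    detections = []
    for query_index, scores in enumerate(query_token_scores):
        best = max(scores)
        # box_threshold: is this query confident enough to be a detection?
        if best &lt;= box_threshold:
            continue
        # text_threshold: which tokens describe it? Periods only separate names.
        words = [tok for tok, s in zip(tokens, scores)
                 if s &gt; text_threshold and tok != "."]
        detections.append((query_index, best, " ".join(words)))
    return detections


tokens = ["cat", ".", "dog", "."]
scores = [
    [0.90, 0.10, 0.30, 0.00],  # confident query; label picks up both "cat" and "dog"
    [0.20, 0.10, 0.30, 0.10],  # best score 0.30 is not above box_threshold: dropped
    [0.05, 0.60, 0.50, 0.40],  # kept; only "dog" passes text_threshold
]
print(select_detections(scores, tokens))</code></pre></div>
<p>The first query shows why <code>text_threshold</code> needs tuning: set too low, a single box can be labelled with several object names from the prompt.</p>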
</section>
<section id="image-preprocessing" class="level3">
<h3 class="anchored" data-anchor-id="image-preprocessing" id="image-preprocessing">Image Preprocessing</h3>
<div id="367a42f6" class="cell" data-execution_count="3">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> cv2</span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torchvision.transforms <span class="im">import</span> Compose, ToTensor, Normalize</span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> preprocess_image(image_path, target_size<span class="op">=</span><span class="dv">800</span>, max_size<span class="op">=</span><span class="dv">1333</span>):</span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a><span class="co">    Preprocess image for Grounding DINO</span></span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a><span class="co">    </span></span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a><span class="co">    Args:</span></span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a><span class="co">        image_path: Path to input image</span></span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a><span class="co">        target_size: Target size for the shorter side</span></span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a><span class="co">        max_size: Upper bound for the longer side after resizing</span></span>
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a><span class="co">    </span></span>
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a><span class="co">    Returns:</span></span>
<span id="cb6-15"><a href="#cb6-15" aria-hidden="true" tabindex="-1"></a><span class="co">        image_tensor: Normalized image tensor [3, H, W]</span></span>
<span id="cb6-16"><a href="#cb6-16" aria-hidden="true" tabindex="-1"></a><span class="co">        original_size: Original image dimensions (H, W)</span></span>
<span id="cb6-17"><a href="#cb6-17" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb6-18"><a href="#cb6-18" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Read image (OpenCV loads BGR and returns None if the file cannot be read)</span></span>
<span id="cb6-19"><a href="#cb6-19" aria-hidden="true" tabindex="-1"></a>    image <span class="op">=</span> cv2.imread(image_path)</span>
<span id="cb6-20"><a href="#cb6-20" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> image <span class="kw">is</span> <span class="va">None</span>:</span>
<span id="cb6-21"><a href="#cb6-21" aria-hidden="true" tabindex="-1"></a>        <span class="cf">raise</span> <span class="bu">FileNotFoundError</span>(<span class="ss">f"Could not read image: </span><span class="sc">{</span>image_path<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb6-22"><a href="#cb6-22" aria-hidden="true" tabindex="-1"></a>    image <span class="op">=</span> cv2.cvtColor(image, cv2.COLOR_BGR2RGB)</span>
<span id="cb6-23"><a href="#cb6-23" aria-hidden="true" tabindex="-1"></a>    original_size <span class="op">=</span> image.shape[:<span class="dv">2</span>]</span>
<span id="cb6-24"><a href="#cb6-24" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-25"><a href="#cb6-25" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Scale the shorter side to target_size, but never let the longer</span></span>
<span id="cb6-26"><a href="#cb6-26" aria-hidden="true" tabindex="-1"></a>    <span class="co"># side exceed max_size; the aspect ratio is preserved</span></span>
<span id="cb6-27"><a href="#cb6-27" aria-hidden="true" tabindex="-1"></a>    h, w <span class="op">=</span> image.shape[:<span class="dv">2</span>]</span>
<span id="cb6-28"><a href="#cb6-28" aria-hidden="true" tabindex="-1"></a>    scale <span class="op">=</span> <span class="bu">min</span>(target_size <span class="op">/</span> <span class="bu">min</span>(h, w), max_size <span class="op">/</span> <span class="bu">max</span>(h, w))</span>
<span id="cb6-29"><a href="#cb6-29" aria-hidden="true" tabindex="-1"></a>    new_h, new_w <span class="op">=</span> <span class="bu">round</span>(h <span class="op">*</span> scale), <span class="bu">round</span>(w <span class="op">*</span> scale)</span>
<span id="cb6-30"><a href="#cb6-30" aria-hidden="true" tabindex="-1"></a>    image <span class="op">=</span> cv2.resize(image, (new_w, new_h))</span>
<span id="cb6-31"><a href="#cb6-31" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-32"><a href="#cb6-32" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Convert to a [0, 1] tensor and normalize with ImageNet mean/std</span></span>
<span id="cb6-33"><a href="#cb6-33" aria-hidden="true" tabindex="-1"></a>    transform <span class="op">=</span> Compose([</span>
<span id="cb6-34"><a href="#cb6-34" aria-hidden="true" tabindex="-1"></a>        ToTensor(),</span>
<span id="cb6-35"><a href="#cb6-35" aria-hidden="true" tabindex="-1"></a>        Normalize(mean<span class="op">=</span>[<span class="fl">0.485</span>, <span class="fl">0.456</span>, <span class="fl">0.406</span>], std<span class="op">=</span>[<span class="fl">0.229</span>, <span class="fl">0.224</span>, <span class="fl">0.225</span>])</span>
<span id="cb6-36"><a href="#cb6-36" aria-hidden="true" tabindex="-1"></a>    ])</span>
<span id="cb6-37"><a href="#cb6-37" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-38"><a href="#cb6-38" aria-hidden="true" tabindex="-1"></a>    image_tensor <span class="op">=</span> transform(image)</span>
<span id="cb6-39"><a href="#cb6-39" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> image_tensor, original_size</span></code></pre></div></div>
</div>
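<p>Capping the longer side matters for very wide or tall inputs: with only a shorter-side target of 800, a 500x2000 panorama would grow to 800x3200 pixels. Grounding DINO's <code>load_image</code> helper in <code>groundingdino.util.inference</code> resizes with a shorter side of 800 and a longer-side limit of 1333. The sketch below computes only the output shape under that rule; <code>resized_shape</code> is a hypothetical helper, and the library's own rounding can differ by a pixel.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">def resized_shape(height, width, target_size=800, max_size=1333):
    """Return (new_h, new_w): the shorter side is scaled to target_size unless
    that would push the longer side past max_size. Aspect ratio is preserved."""
    scale = min(target_size / min(height, width), max_size / max(height, width))
    return round(height * scale), round(width * scale)


for h, w in [(480, 640), (500, 2000)]:
    print((h, w), resized_shape(h, w))</code></pre></div>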
</section>
</section>
<section id="advanced-usage" class="level2">
<h2 class="anchored" data-anchor-id="advanced-usage" id="advanced-usage">Advanced Usage</h2>
<section id="batch-processing" class="level3">
<h3 class="anchored" data-anchor-id="batch-processing" id="batch-processing">Batch Processing</h3>
<div id="d93a8da5" class="cell" data-execution_count="4">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> batch_detect(model, image_paths, caption, batch_size<span class="op">=</span><span class="dv">4</span>):</span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Process multiple images in batches"""</span></span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a>    results <span class="op">=</span> []</span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a>    caption <span class="op">=</span> preprocess_caption(caption)</span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a>    device <span class="op">=</span> <span class="bu">next</span>(model.parameters()).device</span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">0</span>, <span class="bu">len</span>(image_paths), batch_size):</span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a>        batch_paths <span class="op">=</span> image_paths[i:i<span class="op">+</span>batch_size]</span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a>        batch_tensors <span class="op">=</span> []</span>
<span id="cb7-10"><a href="#cb7-10" aria-hidden="true" tabindex="-1"></a>        batch_sizes <span class="op">=</span> []</span>
<span id="cb7-11"><a href="#cb7-11" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-12"><a href="#cb7-12" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> path <span class="kw">in</span> batch_paths:</span>
<span id="cb7-13"><a href="#cb7-13" aria-hidden="true" tabindex="-1"></a>            tensor, size <span class="op">=</span> preprocess_image(path)</span>
<span id="cb7-14"><a href="#cb7-14" aria-hidden="true" tabindex="-1"></a>            batch_tensors.append(tensor.to(device))</span>
<span id="cb7-15"><a href="#cb7-15" aria-hidden="true" tabindex="-1"></a>            batch_sizes.append(size)</span>
<span id="cb7-16"><a href="#cb7-16" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-17"><a href="#cb7-17" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Pass the differently sized tensors as a list: the model pads them into a</span></span>
<span id="cb7-18"><a href="#cb7-18" aria-hidden="true" tabindex="-1"></a>        <span class="co"># NestedTensor with a padding mask, so padded pixels are masked out</span></span>
<span id="cb7-19"><a href="#cb7-19" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> torch.no_grad():</span>
<span id="cb7-20"><a href="#cb7-20" aria-hidden="true" tabindex="-1"></a>            outputs <span class="op">=</span> model(batch_tensors, captions<span class="op">=</span>[caption] <span class="op">*</span> <span class="bu">len</span>(batch_paths))</span>
<span id="cb7-21"><a href="#cb7-21" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-22"><a href="#cb7-22" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Process outputs for each image</span></span>
<span id="cb7-23"><a href="#cb7-23" aria-hidden="true" tabindex="-1"></a>        pred_boxes <span class="op">=</span> outputs[<span class="st">"pred_boxes"</span>].cpu()  <span class="co"># [B, num_queries, 4], normalized cx, cy, w, h</span></span>
<span id="cb7-24"><a href="#cb7-24" aria-hidden="true" tabindex="-1"></a>        pred_scores <span class="op">=</span> outputs[<span class="st">"pred_logits"</span>].sigmoid().cpu()  <span class="co"># [B, num_queries, 256]</span></span>
<span id="cb7-25"><a href="#cb7-25" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> j, (boxes, logits) <span class="kw">in</span> <span class="bu">enumerate</span>(<span class="bu">zip</span>(pred_boxes, pred_scores)):</span>
<span id="cb7-26"><a href="#cb7-26" aria-hidden="true" tabindex="-1"></a>            results.append({</span>
<span id="cb7-27"><a href="#cb7-27" aria-hidden="true" tabindex="-1"></a>                <span class="st">'image'</span>: batch_paths[j],</span>
<span id="cb7-28"><a href="#cb7-28" aria-hidden="true" tabindex="-1"></a>                <span class="st">'original_size'</span>: batch_sizes[j],</span>
<span id="cb7-29"><a href="#cb7-29" aria-hidden="true" tabindex="-1"></a>                <span class="st">'boxes'</span>: boxes,</span>
<span id="cb7-30"><a href="#cb7-30" aria-hidden="true" tabindex="-1"></a>                <span class="st">'logits'</span>: logits</span>
<span id="cb7-31"><a href="#cb7-31" aria-hidden="true" tabindex="-1"></a>            })</span>
<span id="cb7-32"><a href="#cb7-32" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-33"><a href="#cb7-33" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> results</span></code></pre></div></div>
</div>
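<p>Batching images of different sizes requires padding them to a common shape, and a DETR-style detector also needs to know which pixels are padding so that attention can ignore them. The NumPy sketch below illustrates that idea: it zero-pads each <code>[C, H, W]</code> array to the largest height and width in the batch and returns a boolean mask in which <code>True</code> marks padded pixels, the convention used by DETR-family <code>NestedTensor</code> masks. <code>pad_with_mask</code> is a hypothetical helper for illustration, not the library's implementation.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import numpy as np


def pad_with_mask(images):
    """Zero-pad [C, H, W] arrays to a shared size; True in the mask marks padding."""
    channels = images[0].shape[0]
    max_h = max(img.shape[1] for img in images)
    max_w = max(img.shape[2] for img in images)

    batch = np.zeros((len(images), channels, max_h, max_w), dtype=images[0].dtype)
    mask = np.ones((len(images), max_h, max_w), dtype=bool)
    for i, img in enumerate(images):
        _, h, w = img.shape
        batch[i, :, :h, :w] = img  # copy real pixels into the top-left corner
        mask[i, :h, :w] = False    # real pixels are not padding
    return batch, mask


batch, mask = pad_with_mask([np.ones((3, 2, 3)), np.ones((3, 4, 1))])
print(batch.shape, mask.shape)</code></pre></div>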
</section>
<section id="integration-with-segmentation" class="level3">
<h3 class="anchored" data-anchor-id="integration-with-segmentation" id="integration-with-segmentation">Integration with Segmentation</h3>
<div id="77cff362" class="cell" data-execution_count="5">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb8"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> combine_with_sam(grounding_model, sam_predictor, image_path, text_prompt):</span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a><span class="co">    Combine Grounding DINO with Segment Anything Model (SAM)</span></span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a><span class="co">    for text-prompted segmentation</span></span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Detect objects with Grounding DINO (load_image and predict come from</span></span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a>    <span class="co"># groundingdino.util.inference; sam_predictor is a segment_anything SamPredictor)</span></span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a>    image_source, image <span class="op">=</span> load_image(image_path)</span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a>    boxes, logits, phrases <span class="op">=</span> predict(</span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a>        model<span class="op">=</span>grounding_model,</span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a>        image<span class="op">=</span>image,</span>
<span id="cb8-12"><a href="#cb8-12" aria-hidden="true" tabindex="-1"></a>        caption<span class="op">=</span>text_prompt,</span>
<span id="cb8-13"><a href="#cb8-13" aria-hidden="true" tabindex="-1"></a>        box_threshold<span class="op">=</span><span class="fl">0.35</span>,</span>
<span id="cb8-14"><a href="#cb8-14" aria-hidden="true" tabindex="-1"></a>        text_threshold<span class="op">=</span><span class="fl">0.25</span></span>
<span id="cb8-15"><a href="#cb8-15" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb8-16"><a href="#cb8-16" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-17"><a href="#cb8-17" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Convert normalized [cx, cy, w, h] boxes to absolute [x0, y0, x1, y1] pixels</span></span>
<span id="cb8-18"><a href="#cb8-18" aria-hidden="true" tabindex="-1"></a>    h, w <span class="op">=</span> image_source.shape[:<span class="dv">2</span>]</span>
<span id="cb8-19"><a href="#cb8-19" aria-hidden="true" tabindex="-1"></a>    boxes_xyxy <span class="op">=</span> box_ops.box_cxcywh_to_xyxy(boxes) <span class="op">*</span> torch.Tensor([w, h, w, h])</span>
<span id="cb8-20"><a href="#cb8-20" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-21"><a href="#cb8-21" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Generate masks with SAM (image_source is the RGB uint8 array SAM expects)</span></span>
<span id="cb8-22"><a href="#cb8-22" aria-hidden="true" tabindex="-1"></a>    sam_predictor.set_image(image_source)</span>
<span id="cb8-23"><a href="#cb8-23" aria-hidden="true" tabindex="-1"></a>    transformed_boxes <span class="op">=</span> sam_predictor.transform.apply_boxes_torch(</span>
<span id="cb8-24"><a href="#cb8-24" aria-hidden="true" tabindex="-1"></a>        boxes_xyxy, image_source.shape[:<span class="dv">2</span>]</span>
<span id="cb8-25"><a href="#cb8-25" aria-hidden="true" tabindex="-1"></a>    ).to(sam_predictor.device)  <span class="co"># boxes must be on the same device as SAM</span></span>
<span id="cb8-26"><a href="#cb8-26" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-27"><a href="#cb8-27" aria-hidden="true" tabindex="-1"></a>    masks, scores, _ <span class="op">=</span> sam_predictor.predict_torch(</span>
<span id="cb8-28"><a href="#cb8-28" aria-hidden="true" tabindex="-1"></a>        point_coords<span class="op">=</span><span class="va">None</span>,</span>
<span id="cb8-29"><a href="#cb8-29" aria-hidden="true" tabindex="-1"></a>        point_labels<span class="op">=</span><span class="va">None</span>,</span>
<span id="cb8-30"><a href="#cb8-30" aria-hidden="true" tabindex="-1"></a>        boxes<span class="op">=</span>transformed_boxes,</span>
<span id="cb8-31"><a href="#cb8-31" aria-hidden="true" tabindex="-1"></a>        multimask_output<span class="op">=</span><span class="va">False</span></span>
<span id="cb8-32"><a href="#cb8-32" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb8-33"><a href="#cb8-33" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-34"><a href="#cb8-34" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> masks, boxes, phrases</span></code></pre></div></div>
</div>
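<p>With <code>multimask_output=False</code>, SAM's <code>predict_torch</code> returns one boolean mask per box, shaped <code>[N, 1, H, W]</code>. A common follow-up is to merge them into a single label map, for example to save one segmentation image per input. The NumPy sketch below assumes the masks have already been moved to the CPU with <code>masks.cpu().numpy()</code>; <code>masks_to_label_map</code> is a hypothetical helper. Each pixel gets the 1-based index of the detection covering it and 0 for background; where masks overlap, the later detection wins. Because the indices follow the order of <code>phrases</code>, label <code>i</code> corresponds to <code>phrases[i - 1]</code>.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import numpy as np


def masks_to_label_map(masks):
    """Merge boolean masks of shape [N, 1, H, W] into an [H, W] label map.

    0 = background, i + 1 = detection i. Later detections overwrite
    earlier ones where masks overlap.
    """
    n, _, h, w = masks.shape
    label_map = np.zeros((h, w), dtype=np.int32)
    for i in range(n):
        label_map[masks[i, 0]] = i + 1
    return label_map


masks = np.zeros((2, 1, 3, 3), dtype=bool)
masks[0, 0, :2, :2] = True  # detection 0: top-left 2x2 block
masks[1, 0, 1:, 1:] = True  # detection 1: bottom-right 2x2 block
print(masks_to_label_map(masks))</code></pre></div>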
</section>
<section id="fine-tuning-on-custom-dataset" class="level3">
<h3 class="anchored" data-anchor-id="fine-tuning-on-custom-dataset" id="fine-tuning-on-custom-dataset">Fine-tuning on Custom Dataset</h3>
<div id="78bef8bc" class="cell" data-execution_count="6">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb9"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.utils.data <span class="im">import</span> DataLoader</span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> groundingdino.datasets <span class="im">import</span> CocoDetection</span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> create_custom_dataloader(data_root, ann_file, batch_size<span class="op">=</span><span class="dv">4</span>):</span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Create dataloader for custom dataset"""</span></span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a>    dataset <span class="op">=</span> CocoDetection(</span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a>        img_folder<span class="op">=</span>data_root,</span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a>        ann_file<span class="op">=</span>ann_file,</span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a>        transforms<span class="op">=</span><span class="va">None</span>,  <span class="co"># Add custom transforms</span></span>
<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a>        return_masks<span class="op">=</span><span class="va">False</span></span>
<span id="cb9-11"><a href="#cb9-11" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb9-12"><a href="#cb9-12" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-13"><a href="#cb9-13" aria-hidden="true" tabindex="-1"></a>    dataloader <span class="op">=</span> DataLoader(</span>
<span id="cb9-14"><a href="#cb9-14" aria-hidden="true" tabindex="-1"></a>        dataset,</span>
<span id="cb9-15"><a href="#cb9-15" aria-hidden="true" tabindex="-1"></a>        batch_size<span class="op">=</span>batch_size,</span>
<span id="cb9-16"><a href="#cb9-16" aria-hidden="true" tabindex="-1"></a>        shuffle<span class="op">=</span><span class="va">True</span>,</span>
<span id="cb9-17"><a href="#cb9-17" aria-hidden="true" tabindex="-1"></a>        num_workers<span class="op">=</span><span class="dv">4</span>,</span>
<span id="cb9-18"><a href="#cb9-18" aria-hidden="true" tabindex="-1"></a>        collate_fn<span class="op">=</span><span class="kw">lambda</span> batch: <span class="bu">tuple</span>(<span class="bu">zip</span>(<span class="op">*</span>batch))  <span class="co"># Regroup samples into per-field tuples</span></span>
<span id="cb9-19"><a href="#cb9-19" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb9-20"><a href="#cb9-20" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-21"><a href="#cb9-21" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> dataloader</span>
<span id="cb9-22"><a href="#cb9-22" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-23"><a href="#cb9-23" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> fine_tune_model(model, train_loader, val_loader, epochs<span class="op">=</span><span class="dv">10</span>, lr<span class="op">=</span><span class="fl">1e-5</span>):</span>
<span id="cb9-24"><a href="#cb9-24" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Fine-tune Grounding DINO on custom data"""</span></span>
<span id="cb9-25"><a href="#cb9-25" aria-hidden="true" tabindex="-1"></a>    optimizer <span class="op">=</span> torch.optim.AdamW(model.parameters(), lr<span class="op">=</span>lr)</span>
<span id="cb9-26"><a href="#cb9-26" aria-hidden="true" tabindex="-1"></a>    <span class="co"># compute_loss is a placeholder for Grounding DINO's matched loss (object-token contrastive + L1 + GIoU)</span></span>
<span id="cb9-27"><a href="#cb9-27" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-28"><a href="#cb9-28" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(epochs):</span>
<span id="cb9-29"><a href="#cb9-29" aria-hidden="true" tabindex="-1"></a>        model.train()</span>
<span id="cb9-30"><a href="#cb9-30" aria-hidden="true" tabindex="-1"></a>        train_loss <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb9-31"><a href="#cb9-31" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-32"><a href="#cb9-32" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> batch <span class="kw">in</span> train_loader:</span>
<span id="cb9-33"><a href="#cb9-33" aria-hidden="true" tabindex="-1"></a>            images, targets, captions <span class="op">=</span> batch  <span class="co"># Assumes (image, target, caption) samples</span></span>
<span id="cb9-34"><a href="#cb9-34" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb9-35"><a href="#cb9-35" aria-hidden="true" tabindex="-1"></a>            optimizer.zero_grad()</span>
<span id="cb9-36"><a href="#cb9-36" aria-hidden="true" tabindex="-1"></a>            outputs <span class="op">=</span> model(<span class="bu">list</span>(images), captions<span class="op">=</span><span class="bu">list</span>(captions))</span>
<span id="cb9-37"><a href="#cb9-37" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb9-38"><a href="#cb9-38" aria-hidden="true" tabindex="-1"></a>            <span class="co"># compute_loss is a placeholder for the DETR-style set loss (matching + classification + box terms)</span></span>
<span id="cb9-39"><a href="#cb9-39" aria-hidden="true" tabindex="-1"></a>            loss <span class="op">=</span> compute_loss(outputs, targets)</span>
<span id="cb9-40"><a href="#cb9-40" aria-hidden="true" tabindex="-1"></a>            loss.backward()</span>
<span id="cb9-41"><a href="#cb9-41" aria-hidden="true" tabindex="-1"></a>            optimizer.step()</span>
<span id="cb9-42"><a href="#cb9-42" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb9-43"><a href="#cb9-43" aria-hidden="true" tabindex="-1"></a>            train_loss <span class="op">+=</span> loss.item()</span>
<span id="cb9-44"><a href="#cb9-44" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-45"><a href="#cb9-45" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Validation</span></span>
<span id="cb9-46"><a href="#cb9-46" aria-hidden="true" tabindex="-1"></a>        val_loss <span class="op">=</span> validate(model, val_loader)</span>
<span id="cb9-47"><a href="#cb9-47" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Epoch </span><span class="sc">{</span>epoch<span class="op">+</span><span class="dv">1</span><span class="sc">}</span><span class="ss">/</span><span class="sc">{</span>epochs<span class="sc">}</span><span class="ss">, Train Loss: </span><span class="sc">{</span>train_loss <span class="op">/</span> <span class="bu">len</span>(train_loader)<span class="sc">:.4f}</span><span class="ss">, Val Loss: </span><span class="sc">{</span>val_loss<span class="sc">:.4f}</span><span class="ss">"</span>)</span></code></pre></div></div>
</div>
</section>
</section>
<section id="performance-optimization" class="level2">
<h2 class="anchored" data-anchor-id="performance-optimization" id="performance-optimization">Performance Optimization</h2>
<section id="mixed-precision-training" class="level3">
<h3 class="anchored" data-anchor-id="mixed-precision-training" id="mixed-precision-training">Mixed Precision Training</h3>
<div id="79e87e5a" class="cell" data-execution_count="7">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb10"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.amp <span class="im">import</span> autocast, GradScaler</span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a>scaler <span class="op">=</span> GradScaler(<span class="st">"cuda"</span>)</span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a><span class="cf">with</span> autocast(<span class="st">"cuda"</span>):  <span class="co"># run the forward pass in mixed precision</span></span>
<span id="cb10-6"><a href="#cb10-6" aria-hidden="true" tabindex="-1"></a>    outputs <span class="op">=</span> model(images, captions<span class="op">=</span>captions)</span>
<span id="cb10-7"><a href="#cb10-7" aria-hidden="true" tabindex="-1"></a>    loss <span class="op">=</span> compute_loss(outputs, targets)</span>
<span id="cb10-8"><a href="#cb10-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-9"><a href="#cb10-9" aria-hidden="true" tabindex="-1"></a>scaler.scale(loss).backward()  <span class="co"># scale the loss to avoid fp16 gradient underflow</span></span>
<span id="cb10-10"><a href="#cb10-10" aria-hidden="true" tabindex="-1"></a>scaler.step(optimizer)  <span class="co"># unscales gradients; skips the step if they contain inf/NaN</span></span>
<span id="cb10-11"><a href="#cb10-11" aria-hidden="true" tabindex="-1"></a>scaler.update()  <span class="co"># adjust the scale factor for the next iteration</span></span></code></pre></div></div>
</div>
</section>
<section id="tensorrt-optimization" class="level3">
<h3 class="anchored" data-anchor-id="tensorrt-optimization" id="tensorrt-optimization">TensorRT Optimization</h3>
<div id="b25747c1" class="cell" data-execution_count="8">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb11"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch_tensorrt</span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Compile for TensorRT; assumes model is in eval mode on the GPU and takes only an image tensor (text inputs need a wrapper)</span></span>
<span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a>trt_model <span class="op">=</span> torch_tensorrt.<span class="bu">compile</span>(</span>
<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a>    model,</span>
<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a>    inputs<span class="op">=</span>[torch.randn(<span class="dv">1</span>, <span class="dv">3</span>, <span class="dv">800</span>, <span class="dv">800</span>).cuda()],</span>
<span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a>    enabled_precisions<span class="op">=</span>{torch.float16}</span>
<span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
</div>
</section>
<section id="onnx-export" class="level3">
<h3 class="anchored" data-anchor-id="onnx-export" id="onnx-export">ONNX Export</h3>
<div id="4af3e9be" class="cell" data-execution_count="9">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb12"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><a href="#cb12-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> export_to_onnx(model, output_path, input_size<span class="op">=</span>(<span class="dv">800</span>, <span class="dv">800</span>)):</span>
<span id="cb12-2"><a href="#cb12-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Export Grounding DINO to ONNX format"""</span></span>
<span id="cb12-3"><a href="#cb12-3" aria-hidden="true" tabindex="-1"></a>    dummy_image <span class="op">=</span> torch.randn(<span class="dv">1</span>, <span class="dv">3</span>, <span class="op">*</span>input_size).cuda()</span>
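<span id="cb12-4"><a href="#cb12-4" aria-hidden="true" tabindex="-1"></a>    dummy_caption <span class="op">=</span> [<span class="st">"cat . dog"</span>]  <span class="co"># caveat: ONNX graphs take tensors, not strings; practical exports tokenize the caption outside the model</span></span>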
<span id="cb12-5"><a href="#cb12-5" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-6"><a href="#cb12-6" aria-hidden="true" tabindex="-1"></a>    torch.onnx.export(</span>
<span id="cb12-7"><a href="#cb12-7" aria-hidden="true" tabindex="-1"></a>        model,</span>
<span id="cb12-8"><a href="#cb12-8" aria-hidden="true" tabindex="-1"></a>        (dummy_image, dummy_caption),</span>
<span id="cb12-9"><a href="#cb12-9" aria-hidden="true" tabindex="-1"></a>        output_path,</span>
<span id="cb12-10"><a href="#cb12-10" aria-hidden="true" tabindex="-1"></a>        export_params<span class="op">=</span><span class="va">True</span>,</span>
<span id="cb12-11"><a href="#cb12-11" aria-hidden="true" tabindex="-1"></a>        opset_version<span class="op">=</span><span class="dv">14</span>,</span>
<span id="cb12-12"><a href="#cb12-12" aria-hidden="true" tabindex="-1"></a>        do_constant_folding<span class="op">=</span><span class="va">True</span>,</span>
<span id="cb12-13"><a href="#cb12-13" aria-hidden="true" tabindex="-1"></a>        input_names<span class="op">=</span>[<span class="st">'image'</span>, <span class="st">'caption'</span>],</span>
<span id="cb12-14"><a href="#cb12-14" aria-hidden="true" tabindex="-1"></a>        output_names<span class="op">=</span>[<span class="st">'boxes'</span>, <span class="st">'logits'</span>],</span>
<span id="cb12-15"><a href="#cb12-15" aria-hidden="true" tabindex="-1"></a>        dynamic_axes<span class="op">=</span>{</span>
<span id="cb12-16"><a href="#cb12-16" aria-hidden="true" tabindex="-1"></a>            <span class="st">'image'</span>: {<span class="dv">0</span>: <span class="st">'batch_size'</span>},</span>
<span id="cb12-17"><a href="#cb12-17" aria-hidden="true" tabindex="-1"></a>            <span class="st">'boxes'</span>: {<span class="dv">0</span>: <span class="st">'batch_size'</span>},</span>
<span id="cb12-18"><a href="#cb12-18" aria-hidden="true" tabindex="-1"></a>            <span class="st">'logits'</span>: {<span class="dv">0</span>: <span class="st">'batch_size'</span>}</span>
<span id="cb12-19"><a href="#cb12-19" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb12-20"><a href="#cb12-20" aria-hidden="true" tabindex="-1"></a>    )</span></code></pre></div></div>
</div>
</section>
</section>
<section id="common-issues-and-solutions" class="level2">
<h2 class="anchored" data-anchor-id="common-issues-and-solutions" id="common-issues-and-solutions">Common Issues and Solutions</h2>
<section id="issue-1-cuda-out-of-memory" class="level3">
<h3 class="anchored" data-anchor-id="issue-1-cuda-out-of-memory" id="issue-1-cuda-out-of-memory">Issue 1: CUDA Out of Memory</h3>
<p><strong>Solution</strong>: Reduce batch size, use gradient accumulation, or resize images to smaller dimensions.</p>
<div id="5d1ef2fd" class="cell" data-execution_count="10">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb13"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Gradient accumulation</span></span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a>accumulation_steps <span class="op">=</span> <span class="dv">4</span></span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> i, batch <span class="kw">in</span> <span class="bu">enumerate</span>(dataloader):</span>
<span id="cb13-4"><a href="#cb13-4" aria-hidden="true" tabindex="-1"></a>    loss <span class="op">=</span> model(batch) <span class="op">/</span> accumulation_steps  <span class="co"># assumes model(batch) returns a scalar loss</span></span>
<span id="cb13-5"><a href="#cb13-5" aria-hidden="true" tabindex="-1"></a>    loss.backward()</span>
<span id="cb13-6"><a href="#cb13-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-7"><a href="#cb13-7" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> (i <span class="op">+</span> <span class="dv">1</span>) <span class="op">%</span> accumulation_steps <span class="op">==</span> <span class="dv">0</span>:</span>
<span id="cb13-8"><a href="#cb13-8" aria-hidden="true" tabindex="-1"></a>        optimizer.step()</span>
<span id="cb13-9"><a href="#cb13-9" aria-hidden="true" tabindex="-1"></a>        optimizer.zero_grad()</span></code></pre></div></div>
</div>
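<p>Why dividing by <code>accumulation_steps</code> works: with a mean-reduced loss and equal-sized micro-batches, summing the scaled micro-batch gradients reproduces the gradient of one large batch. The toy check below uses a small linear model instead of Grounding DINO (an assumption for illustration); layers that compute batch statistics, such as BatchNorm, would break the exact equivalence.</p>

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(8, 1)                 # toy stand-in for a real detector
x, y = torch.randn(16, 8), torch.randn(16, 1)
loss_fn = torch.nn.MSELoss()                  # mean reduction over the batch

# Gradient of one 16-sample batch
model.zero_grad()
loss_fn(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Same data as 4 micro-batches of 4, each loss divided by accumulation_steps
accumulation_steps = 4
model.zero_grad()
for xb, yb in zip(x.chunk(accumulation_steps), y.chunk(accumulation_steps)):
    (loss_fn(model(xb), yb) / accumulation_steps).backward()  # gradients add up in .grad
accum_grad = model.weight.grad.clone()

print(torch.allclose(full_grad, accum_grad, atol=1e-5))  # True
```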
</section>
<section id="issue-2-low-detection-accuracy" class="level3">
<h3 class="anchored" data-anchor-id="issue-2-low-detection-accuracy" id="issue-2-low-detection-accuracy">Issue 2: Low Detection Accuracy</h3>
<p><strong>Solution</strong>: Adjust thresholds, improve text prompts, or use more descriptive captions.</p>
<div id="5132c83b" class="cell" data-execution_count="11">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb14"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb14-1"><a href="#cb14-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Try different threshold combinations</span></span>
<span id="cb14-2"><a href="#cb14-2" aria-hidden="true" tabindex="-1"></a>box_thresholds <span class="op">=</span> [<span class="fl">0.25</span>, <span class="fl">0.35</span>, <span class="fl">0.45</span>]</span>
<span id="cb14-3"><a href="#cb14-3" aria-hidden="true" tabindex="-1"></a>text_thresholds <span class="op">=</span> [<span class="fl">0.20</span>, <span class="fl">0.25</span>, <span class="fl">0.30</span>]</span>
<span id="cb14-4"><a href="#cb14-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-5"><a href="#cb14-5" aria-hidden="true" tabindex="-1"></a>best_results <span class="op">=</span> <span class="va">None</span></span>
<span id="cb14-6"><a href="#cb14-6" aria-hidden="true" tabindex="-1"></a>best_score <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb14-7"><a href="#cb14-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-8"><a href="#cb14-8" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> box_th <span class="kw">in</span> box_thresholds:</span>
<span id="cb14-9"><a href="#cb14-9" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> text_th <span class="kw">in</span> text_thresholds:</span>
<span id="cb14-10"><a href="#cb14-10" aria-hidden="true" tabindex="-1"></a>        boxes, logits, phrases <span class="op">=</span> predict(model, image, caption, box_th, text_th)</span>
<span id="cb14-11"><a href="#cb14-11" aria-hidden="true" tabindex="-1"></a>        score <span class="op">=</span> evaluate_results(boxes, ground_truth)  <span class="co"># placeholder metric, e.g. F1 or mAP against ground_truth</span></span>
<span id="cb14-12"><a href="#cb14-12" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> score <span class="op">&gt;</span> best_score:</span>
<span id="cb14-13"><a href="#cb14-13" aria-hidden="true" tabindex="-1"></a>            best_score <span class="op">=</span> score</span>
<span id="cb14-14"><a href="#cb14-14" aria-hidden="true" tabindex="-1"></a>            best_results <span class="op">=</span> (boxes, logits, phrases)</span></code></pre></div></div>
</div>
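<p><code>evaluate_results</code> is not defined above. One possible stand-in, purely hypothetical, is an F1 score over greedy one-to-one matches at an IoU of at least 0.5. It assumes both prediction and ground-truth boxes are absolute <code>[x0, y0, x1, y1]</code> arrays, so boxes returned in normalized center format would need converting first.</p>

```python
import numpy as np

def box_iou_xyxy(a, b):
    """Pairwise IoU between (N, 4) and (M, 4) arrays of [x0, y0, x1, y1] boxes."""
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    x0 = np.maximum(a[:, None, 0], b[None, :, 0])
    y0 = np.maximum(a[:, None, 1], b[None, :, 1])
    x1 = np.minimum(a[:, None, 2], b[None, :, 2])
    y1 = np.minimum(a[:, None, 3], b[None, :, 3])
    inter = np.clip(x1 - x0, 0, None) * np.clip(y1 - y0, 0, None)
    area_a = (a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1])
    area_b = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
    return inter / (area_a[:, None] + area_b[None, :] - inter)

def evaluate_results(pred_boxes, gt_boxes, iou_thr=0.5):
    """Hypothetical metric: F1 of greedy one-to-one matches with IoU >= iou_thr."""
    pred_boxes = np.asarray(pred_boxes, dtype=float).reshape(-1, 4)
    gt_boxes = np.asarray(gt_boxes, dtype=float).reshape(-1, 4)
    if len(pred_boxes) == 0 or len(gt_boxes) == 0:
        return 0.0
    iou = box_iou_xyxy(pred_boxes, gt_boxes)
    matched_gt, tp = set(), 0
    for p in range(len(pred_boxes)):
        for g in np.argsort(-iou[p]):            # best-overlapping ground truth first
            if g not in matched_gt and iou[p, g] >= iou_thr:
                matched_gt.add(g)
                tp += 1
                break
    if tp == 0:
        return 0.0
    precision, recall = tp / len(pred_boxes), tp / len(gt_boxes)
    return 2 * precision * recall / (precision + recall)
```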
</section>
<section id="issue-3-slow-inference" class="level3">
<h3 class="anchored" data-anchor-id="issue-3-slow-inference" id="issue-3-slow-inference">Issue 3: Slow Inference</h3>
<p><strong>Solution</strong>: Use TensorRT, reduce image resolution, or batch process images.</p>
<div id="4f651a6e" class="cell" data-execution_count="12">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb15"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb15-1"><a href="#cb15-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> cv2</span>
<span id="cb15-2"><a href="#cb15-2" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> adaptive_resize(image, max_size<span class="op">=</span><span class="dv">1024</span>):</span>
<span id="cb15-3"><a href="#cb15-3" aria-hidden="true" tabindex="-1"></a>    h, w <span class="op">=</span> image.shape[:<span class="dv">2</span>]</span>
<span id="cb15-4"><a href="#cb15-4" aria-hidden="true" tabindex="-1"></a>    scale <span class="op">=</span> <span class="bu">min</span>(<span class="fl">1.0</span>, max_size <span class="op">/</span> <span class="bu">max</span>(h, w))  <span class="co"># shrink so the longer side is at most max_size; never upscale</span></span>
<span id="cb15-5"><a href="#cb15-5" aria-hidden="true" tabindex="-1"></a>    new_h, new_w <span class="op">=</span> <span class="bu">int</span>(h <span class="op">*</span> scale), <span class="bu">int</span>(w <span class="op">*</span> scale)</span>
<span id="cb15-6"><a href="#cb15-6" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> cv2.resize(image, (new_w, new_h), interpolation<span class="op">=</span>cv2.INTER_AREA)  <span class="co"># cv2 takes (width, height)</span></span></code></pre></div></div>
</div>
</section>
</section>
<section id="best-practices" class="level2">
<h2 class="anchored" data-anchor-id="best-practices" id="best-practices">Best Practices</h2>
<ol type="1">
<li><p><strong>Text Prompts</strong>: Use clear, specific descriptions separated by periods</p>
<ul>
<li>Good: <code>"red car . person wearing hat . traffic light"</code></li>
<li>Bad: <code>"things in the street"</code></li>
</ul></li>
<li><p><strong>Threshold Tuning</strong>: Start with the values used in the official demo (box threshold 0.35, text threshold 0.25) and adjust based on results</p>
<ul>
<li>Higher thresholds: Fewer false positives, may miss objects</li>
<li>Lower thresholds: More detections, more false positives</li>
</ul></li>
<li><p><strong>Image Quality</strong>: Use high-resolution images when possible</p>
<ul>
<li>Minimum recommended: 640x640</li>
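<li>Optimal: 800x800 or higher (the reference inference pipeline resizes the shorter side to 800 px, with the longer side capped at 1333 px)</li>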
</ul></li>
<li><p><strong>Batch Processing</strong>: Group similar-sized images to minimize padding overhead</p></li>
<li><p><strong>GPU Memory</strong>: Monitor usage and adjust batch size accordingly</p></li>
</ol>
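<p>The prompt and batching advice can be made concrete with two small helpers. Both are hypothetical utilities, not part of the Grounding DINO API: <code>build_caption</code> produces the period-separated prompt format recommended above, and <code>group_by_size</code> sorts images by <code>(height, width)</code> before chunking, so each batch holds similar sizes and needs little padding.</p>

```python
def build_caption(class_names):
    """Join class names into a lowercase, period-separated prompt ending in ' .'."""
    return " . ".join(name.strip().lower() for name in class_names) + " ."

def group_by_size(image_sizes, batch_size):
    """Return batches of image indices with similar (height, width).

    image_sizes: list of (height, width) tuples, one per image.
    """
    order = sorted(range(len(image_sizes)), key=lambda i: image_sizes[i])
    return [order[start:start + batch_size] for start in range(0, len(order), batch_size)]

print(build_caption(["Red car", "person wearing hat", "traffic light"]))
# red car . person wearing hat . traffic light .
print(group_by_size([(800, 1200), (600, 800), (800, 1200), (600, 800)], batch_size=2))
# [[1, 3], [0, 2]]
```

<p>Sorting is the simplest grouping strategy; bucketing by aspect ratio is a common alternative when image sizes vary continuously.</p>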
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>Grounding DINO provides a powerful framework for open-vocabulary object detection. Its ability to understand natural language makes it highly versatile for various computer vision applications, from autonomous driving to robotics and content moderation.</p>
</section>
<section id="resources" class="level2">
<h2 class="anchored" data-anchor-id="resources" id="resources">Resources</h2>
<ul>
<li><a href="https://github.com/IDEA-Research/GroundingDINO">Official GitHub Repository</a></li>
<li><a href="https://arxiv.org/abs/2303.05499">Research Paper</a></li>
<li><a href="https://huggingface.co/spaces/IDEA-Research/Grounding-DINO">Hugging Face Demo</a></li>
<li><a href="https://groundingdino.readthedocs.io/">API Documentation</a></li>
</ul>
</section>
<section id="citation" class="level2">
<h2 class="anchored" data-anchor-id="citation" id="citation">Citation</h2>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb16"><pre class="sourceCode bibtex code-with-copy"><code class="sourceCode bibtex"><span id="cb16-1"><a href="#cb16-1" aria-hidden="true" tabindex="-1"></a><span class="va">@article</span>{<span class="ot">liu2023grounding</span>,</span>
<span id="cb16-2"><a href="#cb16-2" aria-hidden="true" tabindex="-1"></a>  <span class="dt">title</span>={Grounding {DINO}: Marrying {DINO} with Grounded Pre-Training for Open-Set Object Detection},</span>
<span id="cb16-3"><a href="#cb16-3" aria-hidden="true" tabindex="-1"></a>  <span class="dt">author</span>={Liu, Shilong and Zeng, Zhaoyang and Ren, Tianhe and Li, Feng and Zhang, Hao and Yang, Jie and Li, Chunyuan and Yang, Jianwei and Su, Hang and Zhu, Jun and others},</span>
<span id="cb16-4"><a href="#cb16-4" aria-hidden="true" tabindex="-1"></a>  <span class="dt">journal</span>={arXiv preprint arXiv:2303.05499},</span>
<span id="cb16-5"><a href="#cb16-5" aria-hidden="true" tabindex="-1"></a>  <span class="dt">year</span>={2023}</span>
<span id="cb16-6"><a href="#cb16-6" aria-hidden="true" tabindex="-1"></a>}</span></code></pre></div></div>


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Mathematics Behind Grounding DINO]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/dino/grounding-dino/math/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/dino/grounding-dino/math/</guid>
      <pubDate>Sun, 21 Dec 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>tutorial</category>
      <category>intermediate</category>
      <content:encoded><![CDATA[






<section id="mathematics-behind-grounding-dino" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/dino/grounding-dino/math/groundingdino.png" class="img-fluid"></p>
<p>Grounding DINO (Grounding DINO: Marrying DINO with Grounded Pre-Training for Open-Set Object Detection) is a state-of-the-art open-set object detection model that combines vision and language modalities. It extends the DINO (DETR with Improved deNoising anchOr boxes) architecture to perform zero-shot object detection using natural language descriptions.</p>
<section id="core-architecture-components" class="level2">
<h2 class="anchored" data-anchor-id="core-architecture-components" id="core-architecture-components">Core Architecture Components</h2>
<section id="feature-extraction" class="level3">
<h3 class="anchored" data-anchor-id="feature-extraction" id="feature-extraction">Feature Extraction</h3>
<p><strong>Image Encoder</strong>: Grounding DINO uses a backbone network (typically Swin Transformer) to extract visual features:</p>
<p><span class="math display">\[
\mathbf{F}_{img} = \text{Backbone}(\mathbf{I}) \in \mathbb{R}^{H \times W \times C}
\]</span></p>
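<p>where <span class="math inline">\(\mathbf{I}\)</span> is the input image, and <span class="math inline">\(H, W, C\)</span> are the spatial dimensions and channel count. In practice the backbone produces several feature maps at different scales; these are flattened into one sequence of image tokens, and <span class="math inline">\(\mathbf{F}_{img}[i]\)</span> below denotes the <span class="math inline">\(i\)</span>-th token.</p>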
<p><strong>Text Encoder</strong>: A BERT-based encoder processes the text query:</p>
<p><span class="math display">\[
\mathbf{F}_{text} = \text{TextEncoder}(\mathbf{T}) \in \mathbb{R}^{L \times D}
\]</span></p>
<p>where <span class="math inline">\(\mathbf{T}\)</span> is the tokenized text, <span class="math inline">\(L\)</span> is the sequence length, and <span class="math inline">\(D\)</span> is the embedding dimension.</p>
</section>
<section id="feature-enhancement-module" class="level3">
<h3 class="anchored" data-anchor-id="feature-enhancement-module" id="feature-enhancement-module">Feature Enhancement Module</h3>
<p>The model employs a Feature Enhancer to strengthen features through multi-modal interactions:</p>
<p><span class="math display">\[
\mathbf{F}'_{img}, \mathbf{F}'_{text} = \text{FeatureEnhancer}(\mathbf{F}_{img}, \mathbf{F}_{text})
\]</span></p>
<p>This involves:</p>
<ul>
<li><strong>Deformable Self-Attention</strong> for image features</li>
<li><strong>Self-Attention</strong> for text features</li>
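<li><strong>Bidirectional Cross-Attention</strong> between modalities (image-to-text and text-to-image)</li>
<li><strong>Feed-Forward Networks</strong> that update each modality's features after attention</li>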
</ul>
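<p>A simplified sketch of one enhancer layer in PyTorch follows. To keep it short it assumes standard multi-head self-attention in place of deformable self-attention on the image tokens, omits padding masks, and assumes the text features have already been projected to the image feature dimension. The class and parameter names are illustrative, not the official module.</p>

```python
import torch
import torch.nn as nn

class FeatureEnhancerLayerSketch(nn.Module):
    """One simplified feature-enhancer layer (illustrative, not the official module)."""

    def __init__(self, d_model=256, n_heads=8, d_ffn=1024):
        super().__init__()
        def mha():
            return nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        def ffn():
            return nn.Sequential(nn.Linear(d_model, d_ffn), nn.ReLU(), nn.Linear(d_ffn, d_model))
        self.img_self = mha()            # stand-in for deformable self-attention
        self.txt_self = mha()            # text self-attention
        self.img_attends_txt = mha()     # image tokens query the text tokens
        self.txt_attends_img = mha()     # text tokens query the image tokens
        self.img_ffn, self.txt_ffn = ffn(), ffn()
        self.norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(6)])

    def forward(self, img, txt):
        # img: (B, N_img, d) flattened image tokens; txt: (B, L, d) text tokens
        img = self.norms[0](img + self.img_self(img, img, img)[0])
        txt = self.norms[1](txt + self.txt_self(txt, txt, txt)[0])
        # Bidirectional cross-attention; both directions read the features above
        img_x = self.norms[2](img + self.img_attends_txt(img, txt, txt)[0])
        txt_x = self.norms[3](txt + self.txt_attends_img(txt, img, img)[0])
        img_out = self.norms[4](img_x + self.img_ffn(img_x))
        txt_out = self.norms[5](txt_x + self.txt_ffn(txt_x))
        return img_out, txt_out

layer = FeatureEnhancerLayerSketch()
img_out, txt_out = layer(torch.randn(2, 100, 256), torch.randn(2, 12, 256))
print(img_out.shape, txt_out.shape)  # torch.Size([2, 100, 256]) torch.Size([2, 12, 256])
```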
</section>
<section id="language-guided-query-selection" class="level3">
<h3 class="anchored" data-anchor-id="language-guided-query-selection" id="language-guided-query-selection">Language-Guided Query Selection</h3>
<p>Grounding DINO introduces a novel query initialization mechanism that leverages text features:</p>
<p><span class="math display">\[
\mathbf{Q}_{init} = \text{QuerySelect}(\mathbf{F}'_{img}, \mathbf{F}'_{text})
\]</span></p>
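<p>Each image token <span class="math inline">\(i\)</span> is scored against the text tokens. The paper uses the dot product and keeps, for each image token, its best match over the text tokens:</p>
<p><span class="math display">\[
\text{Score}(i, j) = \mathbf{F}'_{img}[i] \cdot \mathbf{F}'_{text}[j], \qquad s_i = \max_{j} \text{Score}(i, j)
\]</span></p>
<p>The <span class="math inline">\(N_q\)</span> image tokens with the highest <span class="math inline">\(s_i\)</span> are selected (<span class="math inline">\(N_q = 900\)</span> in the paper). Following DINO's mixed query selection, their encoder outputs initialize the positional part of the decoder queries (the anchor boxes), while the content part of the queries remains learnable.</p>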
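<p>The selection step can be sketched in NumPy. <code>select_queries</code> is a hypothetical helper: it scores every image token against every text token (raw dot product by default, cosine similarity with <code>normalize=True</code>), ignores padded text tokens, takes the best score over text tokens, and returns the indices of the top <code>num_queries</code> image tokens.</p>

```python
import numpy as np

def select_queries(img_feats, txt_feats, num_queries, normalize=False, txt_mask=None):
    """Indices of the image tokens whose best text-token score is highest.

    img_feats: (N_img, d); txt_feats: (L, d); txt_mask: optional (L,) bool, True = real token.
    """
    if normalize:  # cosine similarity instead of the raw dot product
        img_feats = img_feats / np.linalg.norm(img_feats, axis=-1, keepdims=True)
        txt_feats = txt_feats / np.linalg.norm(txt_feats, axis=-1, keepdims=True)
    scores = img_feats @ txt_feats.T                              # Score(i, j), shape (N_img, L)
    if txt_mask is not None:
        scores = np.where(txt_mask[None, :], scores, -np.inf)     # ignore padding tokens
    best = scores.max(axis=1)                                     # max over text tokens j
    return np.argsort(-best)[:num_queries]                        # top-N_q image tokens, best first

rng = np.random.default_rng(0)
idx = select_queries(rng.normal(size=(5000, 256)), rng.normal(size=(12, 256)), num_queries=900)
print(idx.shape)  # (900,)
```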
</section>
</section>
<section id="transformer-decoder-architecture" class="level2">
<h2 class="anchored" data-anchor-id="transformer-decoder-architecture" id="transformer-decoder-architecture">Transformer Decoder Architecture</h2>
<section id="cross-modality-decoder" class="level3">
<h3 class="anchored" data-anchor-id="cross-modality-decoder" id="cross-modality-decoder">Cross-Modality Decoder</h3>
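<p>The decoder consists of multiple layers, each containing self-attention, image cross-attention, text cross-attention, and a feed-forward network (FFN). Layer normalization and the FFN are omitted from the equations below for brevity:</p>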
<p><strong>Self-Attention on Queries</strong>:</p>
<p><span class="math display">\[
\mathbf{Q}^{(l+1)} = \text{SelfAttn}(\mathbf{Q}^{(l)}) + \mathbf{Q}^{(l)}
\]</span></p>
<p><strong>Image Cross-Attention</strong> (Deformable Attention):</p>
<p><span class="math display">\[
\mathbf{Q}^{(l+1)} = \text{DeformAttn}(\mathbf{Q}^{(l+1)}, \mathbf{F}'_{img}) + \mathbf{Q}^{(l+1)}
\]</span></p>
<p>The deformable attention is computed as:</p>
<p><span class="math display">\[
\text{DeformAttn}(\mathbf{q}, \mathbf{p}_q, \mathbf{x}) = \sum_{m=1}^{M} \mathbf{W}_m \sum_{k=1}^{K} A_{mqk} \cdot \mathbf{W}'_m \mathbf{x}(\mathbf{p}_q + \Delta\mathbf{p}_{mqk})
\]</span></p>
<p>where:</p>
<ul>
<li><span class="math inline">\(M\)</span> is the number of attention heads</li>
<li><span class="math inline">\(K\)</span> is the number of sampling points</li>
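<li><span class="math inline">\(A_{mqk}\)</span> are attention weights predicted from the query and normalized so that <span class="math inline">\(\sum_{k} A_{mqk} = 1\)</span></li>
<li><span class="math inline">\(\Delta\mathbf{p}_{mqk}\)</span> are sampling offsets predicted from the query; because <span class="math inline">\(\mathbf{p}_q + \Delta\mathbf{p}_{mqk}\)</span> is generally fractional, <span class="math inline">\(\mathbf{x}(\cdot)\)</span> is read by bilinear interpolation</li>
<li><span class="math inline">\(\mathbf{p}_q\)</span> is the reference point of query <span class="math inline">\(\mathbf{q}\)</span></li>
<li><span class="math inline">\(\mathbf{W}_m\)</span> and <span class="math inline">\(\mathbf{W}'_m\)</span> are the learned output and value projections for head <span class="math inline">\(m\)</span></li>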
</ul>
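<p>A single-scale version of this formula can be sketched in PyTorch, following the structure of Deformable DETR's pure-PyTorch reference path. This is an assumption for illustration: Grounding DINO runs the multi-scale variant, which also sums over feature levels. Offsets and attention weights are predicted from the query, and <code>F.grid_sample</code> performs the bilinear reads of <span class="math inline">\(\mathbf{x}(\mathbf{p}_q + \Delta\mathbf{p}_{mqk})\)</span>.</p>

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SingleScaleDeformAttn(nn.Module):
    """Single-scale deformable attention (illustrative sketch)."""

    def __init__(self, d_model=256, n_heads=8, n_points=4):
        super().__init__()
        assert d_model % n_heads == 0
        self.h, self.k, self.dh = n_heads, n_points, d_model // n_heads
        self.offset_proj = nn.Linear(d_model, n_heads * n_points * 2)  # predicts Δp_mqk
        self.weight_proj = nn.Linear(d_model, n_heads * n_points)      # predicts A_mqk logits
        self.value_proj = nn.Linear(d_model, d_model)                  # W'_m for all heads
        self.out_proj = nn.Linear(d_model, d_model)                    # W_m, summed over heads

    def forward(self, query, ref_points, feat):
        # query: (B, Nq, d); ref_points: (B, Nq, 2) normalized (x, y) in [0, 1]; feat: (B, d, H, W)
        B, Nq, _ = query.shape
        H, W = feat.shape[-2:]
        value = self.value_proj(feat.flatten(2).transpose(1, 2))       # (B, HW, d)
        value = value.transpose(1, 2).reshape(B * self.h, self.dh, H, W)
        offsets = self.offset_proj(query).view(B, Nq, self.h, self.k, 2)
        attn = self.weight_proj(query).view(B, Nq, self.h, self.k).softmax(-1)  # sum_k A = 1
        # Offsets are in feature-map pixels; divide by (W, H) to stay in normalized units
        scale = torch.tensor([W, H], dtype=query.dtype, device=query.device)
        loc = ref_points[:, :, None, None, :] + offsets / scale          # (B, Nq, h, k, 2)
        grid = (2 * loc - 1).permute(0, 2, 1, 3, 4).reshape(B * self.h, Nq, self.k, 2)
        # Bilinear sampling of every head, query and point; grid_sample expects [-1, 1]
        sampled = F.grid_sample(value, grid, mode="bilinear",
                                padding_mode="zeros", align_corners=False)  # (B*h, dh, Nq, k)
        attn = attn.permute(0, 2, 1, 3).reshape(B * self.h, 1, Nq, self.k)
        out = (sampled * attn).sum(-1).view(B, self.h * self.dh, Nq).transpose(1, 2)
        return self.out_proj(out)                                         # (B, Nq, d)

layer = SingleScaleDeformAttn()
out = layer(torch.randn(2, 10, 256), torch.rand(2, 10, 2), torch.randn(2, 256, 32, 32))
print(out.shape)  # torch.Size([2, 10, 256])
```

<p>Each query reads only <span class="math inline">\(M \times K\)</span> sampled values instead of all <span class="math inline">\(H \times W\)</span> positions, which is what makes the attention cheap on high-resolution feature maps.</p>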
<p><strong>Text Cross-Attention</strong>:</p>
<p><span class="math display">\[
\mathbf{Q}^{(l+1)} = \text{TextAttn}(\mathbf{Q}^{(l+1)}, \mathbf{F}'_{text}) + \mathbf{Q}^{(l+1)}
\]</span></p>
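<p>Text cross-attention is standard scaled dot-product attention, with queries from the decoder and keys and values from the text features:</p>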
<p><span class="math display">\[
\text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d_k}}\right)\mathbf{V}
\]</span></p>
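<p>In the text cross-attention, <span class="math inline">\(\mathbf{Q}\)</span> comes from the decoder queries and <span class="math inline">\(\mathbf{K}, \mathbf{V}\)</span> from the text features. A NumPy sketch of the formula, with an assumed boolean mask for padded caption tokens:</p>

```python
import numpy as np

def cross_attention(q, k, v, key_mask=None):
    """softmax(Q K^T / sqrt(d_k)) V, optionally masking padded keys.

    q: (N_q, d_k); k: (L, d_k); v: (L, d_v); key_mask: (L,) bool, True = keep.
    Returns the attended values and the attention weights.
    """
    logits = q @ k.T / np.sqrt(q.shape[-1])
    if key_mask is not None:
        logits = np.where(key_mask[None, :], logits, -np.inf)   # padded keys get zero weight
    logits = logits - logits.max(axis=-1, keepdims=True)        # numerical stability
    weights = np.exp(logits)
    weights /= weights.sum(axis=-1, keepdims=True)              # each row sums to 1
    return weights @ v, weights

rng = np.random.default_rng(0)
out, w = cross_attention(rng.normal(size=(3, 8)), rng.normal(size=(5, 8)),
                         rng.normal(size=(5, 4)), np.array([True, True, True, False, False]))
print(out.shape, w.sum(axis=-1))  # (3, 4) and row sums of 1
```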
</section>
</section>
<section id="prediction-heads" class="level2">
<h2 class="anchored" data-anchor-id="prediction-heads" id="prediction-heads">Prediction Heads</h2>
<section id="classification-head" class="level3">
<h3 class="anchored" data-anchor-id="classification-head" id="classification-head">Classification Head</h3>
<p>For each query <span class="math inline">\(\mathbf{q}_i\)</span>, the model computes a logit against every text token as a dot product with the enhanced text features (padding tokens are masked out):</p>
<p><span class="math display">\[
\mathbf{s}_{ij} = \mathbf{q}_i \cdot \mathbf{F}'_{text}[j], \quad j = 1, \dots, L
\]</span></p>
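<p>Each token is scored independently with a sigmoid rather than a softmax, so a single query can match several tokens of a multi-token phrase. The classification score for token <span class="math inline">\(j\)</span> is:</p>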
<p><span class="math display">\[
p_{ij} = \text{sigmoid}(\mathbf{s}_{ij})
\]</span></p>
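<p>A NumPy sketch of the per-token scoring, assuming dot-product logits and a boolean mask that marks real (non-padding) tokens. Because the sigmoid is applied element-wise, a row of probabilities need not sum to one.</p>

```python
import numpy as np

def token_probabilities(queries, txt_feats, txt_mask):
    """p_ij = sigmoid(q_i . t_j) for every query i and text token j.

    queries: (N_q, d); txt_feats: (L, d); txt_mask: (L,) bool, True = real token.
    Padding tokens get probability 0.
    """
    logits = queries @ txt_feats.T            # s_ij, shape (N_q, L)
    probs = 1.0 / (1.0 + np.exp(-logits))     # element-wise sigmoid
    return np.where(txt_mask[None, :], probs, 0.0)

q = np.array([[1.0, 0.0], [0.0, 1.0]])
t = np.array([[2.0, 0.0], [0.0, 0.0], [0.0, -2.0]])
p = token_probabilities(q, t, np.array([True, True, False]))
print(p.round(3))
```

<p>Here the first query scores about 0.881 on token 0 and 0.5 on the all-zero token, so its row sums to more than one, which a softmax would never allow.</p>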
</section>
<section id="bounding-box-regression-head" class="level3">
<h3 class="anchored" data-anchor-id="bounding-box-regression-head" id="bounding-box-regression-head">Bounding Box Regression Head</h3>
<p>The box coordinates are predicted as:</p>
<p><span class="math display">\[
\mathbf{b}_i = \sigma(\text{FFN}_{box}(\mathbf{q}_i)) = [\hat{x}_c, \hat{y}_c, \hat{w}, \hat{h}]
\]</span></p>
<p>where <span class="math inline">\(\sigma\)</span> is the sigmoid function, and coordinates are normalized to [0, 1].</p>
<p>The predicted box in absolute coordinates:</p>
<p><span class="math display">\[
\begin{align}
x_c &amp;= \hat{x}_c \cdot W \\
y_c &amp;= \hat{y}_c \cdot H \\
w &amp;= \hat{w} \cdot W \\
h &amp;= \hat{h} \cdot H
\end{align}
\]</span></p>
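<p>Converting the normalized prediction to pixel corners follows directly from these equations. The sketch below is a hypothetical helper that returns <code>[x0, y0, x1, y1]</code>, the format most drawing and evaluation code expects:</p>

```python
import numpy as np

def cxcywh_norm_to_xyxy_abs(boxes, image_w, image_h):
    """Map normalized [cx, cy, w, h] boxes to absolute [x0, y0, x1, y1] pixels."""
    boxes = np.asarray(boxes, dtype=float)
    cx, cy = boxes[:, 0] * image_w, boxes[:, 1] * image_h   # x_c = x̂_c W, y_c = ŷ_c H
    w, h = boxes[:, 2] * image_w, boxes[:, 3] * image_h     # w = ŵ W, h = ĥ H
    return np.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], axis=1)

print(cxcywh_norm_to_xyxy_abs([[0.5, 0.5, 0.2, 0.4]], image_w=800, image_h=600))
# [[320. 180. 480. 420.]]
```

<p>Because the predictions are normalized, the same boxes can be mapped onto the original image even if the model saw a resized copy.</p>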
</section>
</section>
<section id="loss-functions" class="level2">
<h2 class="anchored" data-anchor-id="loss-functions" id="loss-functions">Loss Functions</h2>
<section id="bipartite-matching-loss" class="level3">
<h3 class="anchored" data-anchor-id="bipartite-matching-loss" id="bipartite-matching-loss">Bipartite Matching Loss</h3>
<p>Following DETR, Grounding DINO uses Hungarian matching to find optimal assignment between predictions and ground truth:</p>
<p><span class="math display">\[
\hat{\sigma} = \arg\min_{\sigma \in \mathfrak{S}_N} \sum_{i=1}^{N} \mathcal{L}_{match}(y_i, \hat{y}_{\sigma(i)})
\]</span></p>
<p>where <span class="math inline">\(\mathfrak{S}_N\)</span> is the set of all permutations of N elements.</p>
<p>The matching cost:</p>
<p><span class="math display">\[
\mathcal{L}_{match}(y_i, \hat{y}_j) = -\mathbb{1}_{\{c_i \neq \emptyset\}} \hat{p}_j(c_i) + \mathbb{1}_{\{c_i \neq \emptyset\}} \mathcal{L}_{box}(b_i, \hat{b}_j)
\]</span></p>
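<p>Enumerating every permutation is only feasible for tiny sets; DETR-style implementations solve the assignment with the Hungarian algorithm, for example via SciPy's <code>linear_sum_assignment</code>. The sketch below builds the cost from the two terms shown (negated class probability plus an L1 box distance, with illustrative weights) and omits the GIoU term that is often folded into <span class="math inline">\(\mathcal{L}_{box}\)</span>. In Grounding DINO the "classes" are text tokens, so treat <code>pred_probs</code> here as a simplification.</p>

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def hungarian_match(pred_probs, pred_boxes, gt_labels, gt_boxes, w_cls=1.0, w_l1=5.0):
    """Return (pred_idx, gt_idx) minimizing the total matching cost.

    pred_probs: (N, C) class probabilities; pred_boxes: (N, 4) normalized boxes;
    gt_labels: (M,) integer labels; gt_boxes: (M, 4). w_cls and w_l1 are illustrative.
    """
    pred_probs, pred_boxes = np.asarray(pred_probs, float), np.asarray(pred_boxes, float)
    gt_labels, gt_boxes = np.asarray(gt_labels, int), np.asarray(gt_boxes, float)
    cost_cls = -pred_probs[:, gt_labels]                                        # (N, M): -p_j(c_i)
    cost_box = np.abs(pred_boxes[:, None, :] - gt_boxes[None, :, :]).sum(-1)   # (N, M): L1 distance
    cost = w_cls * cost_cls + w_l1 * cost_box
    pred_idx, gt_idx = linear_sum_assignment(cost)   # optimal one-to-one assignment
    return pred_idx, gt_idx

pred_idx, gt_idx = hungarian_match(
    pred_probs=[[0.9, 0.1], [0.2, 0.8], [0.5, 0.5]],
    pred_boxes=[[0.1, 0.1, 0.2, 0.2], [0.7, 0.7, 0.2, 0.2], [0.5, 0.5, 0.5, 0.5]],
    gt_labels=[1, 0],
    gt_boxes=[[0.7, 0.7, 0.2, 0.2], [0.1, 0.1, 0.2, 0.2]],
)
print(dict(zip(pred_idx.tolist(), gt_idx.tolist())))
```

<p>This example pairs prediction 0 with ground truth 1 and prediction 1 with ground truth 0; prediction 2 is left unmatched and would be trained toward "no object".</p>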
</section>
<section id="total-loss" class="level3">
<h3 class="anchored" data-anchor-id="total-loss" id="total-loss">Total Loss</h3>
<p>After optimal matching, the total loss is:</p>
<p><span class="math display">\[
\mathcal{L} = \lambda_{cls}\mathcal{L}_{cls} + \lambda_{box}\mathcal{L}_{box} + \lambda_{giou}\mathcal{L}_{giou}
\]</span></p>
<p><strong>Classification Loss</strong> (Focal Loss):</p>
<p><span class="math display">\[
\mathcal{L}_{cls} = -\alpha(1-p_t)^\gamma \log(p_t)
\]</span></p>
<p>where <span class="math inline">\(p_t\)</span> is the model’s estimated probability for the correct class.</p>
<p><strong>Box L1 Loss</strong>:</p>
<p><span class="math display">\[
\mathcal{L}_{box} = \sum_{i=1}^{N} \mathbb{1}_{\{c_i \neq \emptyset\}} ||b_i - \hat{b}_{\sigma(i)}||_1
\]</span></p>
<p><strong>GIoU Loss</strong> (Generalized Intersection over Union):</p>
<p><span class="math display">\[
\mathcal{L}_{giou} = 1 - \text{GIoU}(b_i, \hat{b}_{\sigma(i)})
\]</span></p>
<p>where:</p>
<p><span class="math display">\[
\text{GIoU} = \text{IoU} - \frac{|C \setminus (A \cup B)|}{|C|}
\]</span></p>
<p><span class="math inline">\(C\)</span> is the smallest axis-aligned box enclosing both boxes <span class="math inline">\(A\)</span> and <span class="math inline">\(B\)</span>, and <span class="math inline">\(|\cdot|\)</span> denotes area.</p>
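<p>Both terms can be sketched in a few lines of NumPy. The focal loss below uses the common <span class="math inline">\(\alpha_t\)</span> convention (weight <span class="math inline">\(\alpha\)</span> for positives and <span class="math inline">\(1-\alpha\)</span> for negatives) with RetinaNet's usual defaults <span class="math inline">\(\alpha = 0.25, \gamma = 2\)</span>; these are assumptions, not values taken from this article.</p>

```python
import numpy as np

def giou(a, b):
    """Generalized IoU between two [x0, y0, x1, y1] boxes."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    # C: smallest axis-aligned box enclosing both A and B, so |C \ (A ∪ B)| = |C| - union
    c_area = (max(a[2], b[2]) - min(a[0], b[0])) * (max(a[3], b[3]) - min(a[1], b[1]))
    return inter / union - (c_area - union) / c_area

def sigmoid_focal_loss(p, target, alpha=0.25, gamma=2.0):
    """Element-wise focal loss for sigmoid probabilities p and 0/1 targets."""
    p, target = np.asarray(p, dtype=float), np.asarray(target, dtype=float)
    p_t = target * p + (1 - target) * (1 - p)              # probability of the true label
    alpha_t = target * alpha + (1 - target) * (1 - alpha)
    return -alpha_t * (1 - p_t) ** gamma * np.log(p_t)

print(1 - giou([0, 0, 1, 1], [2, 0, 3, 1]))  # GIoU loss for disjoint boxes: 1 + 1/3
```

<p>Unlike plain IoU, GIoU stays informative for disjoint boxes: it becomes negative as they move apart, so <span class="math inline">\(\mathcal{L}_{giou}\)</span> still provides a gradient.</p>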
</section>
</section>
<section id="contrastive-alignment" class="level2">
<h2 class="anchored" data-anchor-id="contrastive-alignment" id="contrastive-alignment">Contrastive Alignment</h2>
<section id="contrastive-learning-for-vision-language-alignment" class="level3">
<h3 class="anchored" data-anchor-id="contrastive-learning-for-vision-language-alignment" id="contrastive-learning-for-vision-language-alignment">Contrastive Learning for Vision-Language Alignment</h3>
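<p>During pre-training, Grounding DINO aligns image regions with text phrases contrastively. In the paper this alignment is carried out by the region-token classification described above (dot products between queries and text tokens, trained with focal loss, following GLIP); the generic InfoNCE form of such a region-phrase contrastive objective is:</p>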
<p><span class="math display">\[
\mathcal{L}_{contrast} = -\log \frac{\exp(\text{sim}(\mathbf{v}_i, \mathbf{t}_i)/\tau)}{\sum_{j=1}^{B} \exp(\text{sim}(\mathbf{v}_i, \mathbf{t}_j)/\tau)}
\]</span></p>
<p>where:</p>
<ul>
<li><span class="math inline">\(\mathbf{v}_i\)</span> is the visual embedding for region <span class="math inline">\(i\)</span></li>
<li><span class="math inline">\(\mathbf{t}_i\)</span> is the corresponding text embedding</li>
<li><span class="math inline">\(\tau\)</span> is the temperature parameter</li>
<li><span class="math inline">\(B\)</span> is the batch size</li>
</ul>
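<p>A NumPy sketch of this InfoNCE loss, assuming cosine similarity for <span class="math inline">\(\text{sim}\)</span> and matched pairs on the same row of the two embedding matrices; the temperature value is illustrative:</p>

```python
import numpy as np

def info_nce(region_emb, text_emb, tau=0.07):
    """Mean InfoNCE loss; row i of each (B, d) array is a matched region-phrase pair."""
    v = region_emb / np.linalg.norm(region_emb, axis=1, keepdims=True)
    t = text_emb / np.linalg.norm(text_emb, axis=1, keepdims=True)
    logits = v @ t.T / tau                                   # sim(v_i, t_j) / tau
    logits = logits - logits.max(axis=1, keepdims=True)      # stable log-sum-exp
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -np.mean(np.diag(log_probs))                      # positives lie on the diagonal

emb = np.eye(4)
print(info_nce(emb, emb), info_nce(emb, emb[::-1]))  # near 0 when aligned, large when shuffled
```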
</section>
</section>
<section id="key-mathematical-innovations" class="level2">
<h2 class="anchored" data-anchor-id="key-mathematical-innovations" id="key-mathematical-innovations">Key Mathematical Innovations</h2>
<section id="enhanced-feature-fusion" class="level3">
<h3 class="anchored" data-anchor-id="enhanced-feature-fusion" id="enhanced-feature-fusion">Enhanced Feature Fusion</h3>
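<p>Cross-modality fusion can be expressed with a gating mechanism. Assuming both features have been projected to a common dimension and aligned to the same tokens (for example, text features attended onto each image token), a gated mix is:</p>
<p><span class="math display">\[
\mathbf{F}_{fused} = \alpha \odot \mathbf{F}_{img} + (1-\alpha) \odot \mathbf{F}_{text}
\]</span></p>
<p>where <span class="math inline">\(\alpha = \sigma(\text{FFN}([\mathbf{F}_{img}; \mathbf{F}_{text}]))\)</span> is predicted per element from the concatenated features. In the released Grounding DINO code, the bi-directional attention block instead adds each cross-attention output back to its input residually, scaled by learnable factors.</p>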
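<p>A NumPy sketch of the gated mix, assuming both inputs are row-aligned and share a dimension; a single linear layer stands in for the FFN, and the weights <code>w</code> and <code>b</code> are hypothetical parameters:</p>

```python
import numpy as np

def gated_fusion(f_img, f_txt, w, b):
    """alpha * f_img + (1 - alpha) * f_txt with alpha = sigmoid([f_img; f_txt] w + b).

    f_img, f_txt: (N, d), row-aligned; w: (2d, d); b: (d,).
    """
    gate_logits = np.concatenate([f_img, f_txt], axis=-1) @ w + b
    alpha = 1.0 / (1.0 + np.exp(-gate_logits))      # per-element gate in (0, 1)
    return alpha * f_img + (1.0 - alpha) * f_txt, alpha

fused, alpha = gated_fusion(np.ones((3, 4)), np.zeros((3, 4)), np.zeros((8, 4)), np.zeros(4))
print(fused[0])  # zero weights give alpha = 0.5, an equal mix
```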
</section>
<section id="position-encoding" class="level3">
<h3 class="anchored" data-anchor-id="position-encoding" id="position-encoding">Position Encoding</h3>
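<p><strong>Image Position Encoding</strong>: 2D sine-cosine positional encoding. Half of the <span class="math inline">\(d\)</span> channels encode the row coordinate <span class="math inline">\(y\)</span> and half encode the column coordinate <span class="math inline">\(x\)</span>; for each coordinate <span class="math inline">\(u \in \{x, y\}\)</span> and <span class="math inline">\(i = 0, \dots, d/4 - 1\)</span>:</p>
<p><span class="math display">\[
\begin{align}
PE^{(u)}_{2i} &amp;= \sin\left(\frac{u}{10000^{4i/d}}\right) \\
PE^{(u)}_{2i+1} &amp;= \cos\left(\frac{u}{10000^{4i/d}}\right)
\end{align}
\]</span></p>
<p>The full encoding at position <span class="math inline">\((x, y)\)</span> is the concatenation <span class="math inline">\([PE^{(y)}; PE^{(x)}]\)</span>.</p>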
<p><strong>Text Position Encoding</strong>: Standard 1D positional encoding for sequence position.</p>
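<p>A NumPy sketch of the 2D image encoding, assuming raw integer pixel coordinates. In this sketch the first <span class="math inline">\(d/2\)</span> channels encode <span class="math inline">\(y\)</span> and the last <span class="math inline">\(d/2\)</span> encode <span class="math inline">\(x\)</span>, so each sin/cos pair uses the exponent <span class="math inline">\(2i/(d/2) = 4i/d\)</span>. Many DETR-style implementations also normalize coordinates (for example to <span class="math inline">\([0, 2\pi]\)</span>) before encoding.</p>

```python
import numpy as np

def sine_pos_encoding_2d(h, w, d, temperature=10000.0):
    """(h, w, d) encoding: first d/2 channels encode the row y, last d/2 the column x."""
    assert d % 4 == 0, "d must be divisible by 4"
    freqs = temperature ** (4 * np.arange(d // 4) / d)     # T^(4i/d), one per sin/cos pair

    def encode(u):                                         # u: (n,) -> (n, d/2)
        angles = u[:, None] / freqs[None, :]
        out = np.empty((len(u), d // 2))
        out[:, 0::2] = np.sin(angles)                      # channel 2i
        out[:, 1::2] = np.cos(angles)                      # channel 2i + 1
        return out

    pe_y = encode(np.arange(h, dtype=float))
    pe_x = encode(np.arange(w, dtype=float))
    return np.concatenate([
        np.broadcast_to(pe_y[:, None, :], (h, w, d // 2)),  # same y code along each row
        np.broadcast_to(pe_x[None, :, :], (h, w, d // 2)),  # same x code down each column
    ], axis=-1)

pe = sine_pos_encoding_2d(4, 5, 8)
print(pe.shape)  # (4, 5, 8)
```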
</section>
</section>
<section id="inference-process" class="level2">
<h2 class="anchored" data-anchor-id="inference-process" id="inference-process">Inference Process</h2>
<p>At inference time, given an image and text query:</p>
<ol type="1">
<li>Extract features: <span class="math inline">\(\mathbf{F}_{img}, \mathbf{F}_{text}\)</span></li>
<li>Enhance features through cross-attention</li>
<li>Initialize queries based on image-text similarity</li>
<li>Pass through decoder layers</li>
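<li>Generate box and token-score predictions for each query, keeping queries whose highest token probability exceeds a box threshold</li>
<li>Optionally apply NMS (Non-Maximum Suppression) to filter overlapping boxes; one-to-one matching during training already discourages duplicates, so NMS acts as a post-processing safeguard:</li>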
</ol>
<p><span class="math display">\[
\text{Keep box } i \text{ if } \text{IoU}(b_i, b_j) &lt; \theta \text{ for every kept box } j \text{ with a higher score}
\]</span></p>
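<p>A minimal greedy NMS sketch in NumPy follows. It assumes boxes in <code>(x1, y1, x2, y2)</code> pixel format with positive area, plus an illustrative threshold <span class="math inline">\(\theta = 0.5\)</span>. Boxes are visited in descending score order, and each candidate is compared only against higher-scoring boxes that have already been kept. The helpers <code>iou</code> and <code>nms</code> are hypothetical names. Production code would more likely call a library routine such as <code>torchvision.ops.nms</code>.</p>

```python
import numpy as np


def iou(box, boxes):
    """IoU of one box against an (M, 4) array; boxes are (x1, y1, x2, y2)."""
    x1 = np.maximum(box[0], boxes[:, 0])
    y1 = np.maximum(box[1], boxes[:, 1])
    x2 = np.minimum(box[2], boxes[:, 2])
    y2 = np.minimum(box[3], boxes[:, 3])
    inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / (area + areas - inter)


def nms(boxes, scores, theta=0.5):
    """Greedy NMS: indices of kept boxes, in descending score order."""
    boxes = np.asarray(boxes, dtype=float)
    keep = []
    for i in np.argsort(scores)[::-1]:                 # highest score first
        # Keep box i only if it overlaps every already-kept box by less than theta.
        if not keep or np.all(iou(boxes[i], boxes[keep]) < theta):
            keep.append(int(i))
    return keep


boxes = [[0, 0, 10, 10], [1, 1, 11, 11], [20, 20, 30, 30]]
scores = [0.9, 0.8, 0.7]
print(nms(boxes, scores))  # [0, 2]
```

<p>In this example, the second box overlaps the first with an IoU of 81/119 (about 0.68), so it is suppressed. The third box does not overlap either of the others and is kept.</p>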
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>Grounding DINO’s mathematical framework elegantly combines:</p>
<ul>
<li>Deformable attention for efficient multi-scale feature processing</li>
<li>Cross-modal attention for vision-language alignment</li>
<li>Contrastive learning for robust feature representations</li>
<li>Hungarian matching for optimal prediction-target assignment</li>
</ul>
<p>These components work together to enable open-vocabulary object detection, allowing the model to detect objects described by arbitrary text queries without fine-tuning on specific categories.</p>


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Programming Languages in Computer Vision & Machine Learning]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/ml-langs/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/ml-langs/</guid>
      <pubDate>Tue, 30 Sep 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>tutorial</category>
      <category>intermediate</category>
      <content:encoded><![CDATA[






<section id="programming-languages-in-computer-vision-machine-learning" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/ml-langs/langs.png" class="img-fluid"></p>
<p>When I look back at my career so far, it feels like a journey through different languages, each chapter shaping the way I think about solving problems.</p>
<p>It all began with Java in college. That was my first serious step into software development: a world of strong typing, object-oriented design, and enterprise-scale thinking. It gave me discipline in structure, patterns, and writing code that lasts.</p>
<p>From there, I transitioned into JavaScript during my time at TopRankers. It was like entering a different universe — one that was faster, more dynamic, and centered around creating immediate impact for users. JavaScript taught me how to think in terms of interactivity, responsiveness, and user experience.</p>
<p>At Waycool Foods, my journey deepened as I worked with both Python and JavaScript. Here, I started to bridge worlds — backend logic, data-driven decision-making, and the user-facing layer. This dual exposure helped me appreciate the power of Python’s simplicity and versatility, alongside the speed and ubiquity of JavaScript.</p>
<p>Then came Mareana, where my focus shifted entirely to Python and MLOps. This was the turning point — moving from writing applications to building scalable machine learning systems. It was about automation, pipelines, monitoring, and making sure models didn’t just work in a notebook but thrived in production. I learned how to bring discipline into the chaos of experimentation.</p>
<p>Now at Lytx, I find myself in the exciting realm of applied research. Here, Python is my closest ally — powering experiments in computer vision, machine learning, and deep learning. It’s no longer just about deployment, but about pushing boundaries, asking new questions, and finding answers in data.</p>
<p>Looking back, each language and role wasn’t just a skill upgrade — it was a mindset shift. Java gave me structure, JavaScript gave me adaptability, MLOps taught me scale, and research has taught me curiosity. Together, they form the story of how I grew from a developer into a practitioner of applied AI.</p>
<p>The choice of programming language for computer vision and machine learning projects depends on a careful balance of performance requirements, development speed, team expertise, and deployment constraints. This guide explores the four primary languages used in CV &amp; ML: Python, C++, JavaScript, and Go.</p>
</section>
<section id="python-in-cv-ml" class="level1">
<h1>Python in CV &amp; ML</h1>
<section id="overview" class="level2">
<h2 class="anchored" data-anchor-id="overview" id="overview">Overview</h2>
<p>Python dominates the machine learning and computer vision landscape, serving as the primary language for research, prototyping, and production deployment. Its extensive ecosystem and ease of use make it the de facto standard for ML practitioners.</p>
</section>
<section id="key-strengths" class="level2">
<h2 class="anchored" data-anchor-id="key-strengths" id="key-strengths">Key Strengths</h2>
<p><strong>Rich Ecosystem</strong>: Python boasts the most comprehensive collection of ML and CV libraries, with mature, well-documented frameworks that handle everything from data preprocessing to model deployment.</p>
<p><strong>Rapid Prototyping</strong>: The language’s intuitive syntax and interactive development environments (Jupyter notebooks, IPython) let researchers iterate quickly on ideas and visualize results in real time.</p>
<p><strong>Community &amp; Resources</strong>: With millions of practitioners worldwide, Python offers unparalleled community support, tutorials, pre-trained models, and solutions to common problems.</p>
<p><strong>Research-to-Production</strong>: Modern frameworks like PyTorch and TensorFlow provide clear paths from research prototypes to production systems, with tools for optimization and deployment.</p>
</section>
<section id="essential-libraries-frameworks" class="level2">
<h2 class="anchored" data-anchor-id="essential-libraries-frameworks" id="essential-libraries-frameworks">Essential Libraries &amp; Frameworks</h2>
<section id="deep-learning-frameworks" class="level3">
<h3 class="anchored" data-anchor-id="deep-learning-frameworks" id="deep-learning-frameworks">Deep Learning Frameworks</h3>
<p><strong>PyTorch</strong>: The preferred framework for research and increasingly for production. PyTorch’s dynamic computational graphs make debugging intuitive, while its eager execution model aligns with Python’s natural flow. Features include:</p>
<ul>
<li>TorchVision for computer vision tasks with pre-trained models (ResNet, Faster R-CNN, Vision Transformers)</li>
<li>TorchScript for converting models to production-ready formats</li>
<li>Native support for distributed training across multiple GPUs</li>
<li>Extensive ecosystem with libraries like PyTorch Lightning, Detectron2, and MMDetection</li>
</ul>
<p><strong>TensorFlow/Keras</strong>: Google’s framework excels in production environments with robust deployment tools. TensorFlow offers:</p>
<ul>
<li>Keras API for high-level, user-friendly model building</li>
<li>TensorFlow Serving for scalable model deployment</li>
<li>TensorFlow Lite for mobile and edge devices</li>
<li>TensorFlow.js for browser-based inference</li>
<li>Strong support for TPU acceleration</li>
</ul>
<p><strong>JAX</strong>: Emerging as a powerful tool for research, JAX combines NumPy-like syntax with automatic differentiation and XLA compilation for exceptional performance on GPUs and TPUs.</p>
</section>
<section id="computer-vision-libraries" class="level3">
<h3 class="anchored" data-anchor-id="computer-vision-libraries" id="computer-vision-libraries">Computer Vision Libraries</h3>
<p><strong>OpenCV (cv2)</strong>: The cornerstone of computer vision, OpenCV provides 2,500+ optimized algorithms for:</p>
<ul>
<li>Image processing (filtering, transformation, morphological operations)</li>
<li>Feature detection (SIFT, SURF, ORB, Harris corners; SURF requires the non-free opencv_contrib build)</li>
<li>Object detection (Haar cascades, HOG)</li>
<li>Camera calibration and 3D reconstruction</li>
<li>Video analysis and optical flow</li>
<li>Real-time face detection and tracking</li>
</ul>
<p><strong>Pillow (PIL)</strong>: Essential for image manipulation tasks including:</p>
<ul>
<li>Loading and saving images in various formats</li>
<li>Basic transformations (resize, crop, rotate)</li>
<li>Color space conversions</li>
<li>Image enhancement and filtering</li>
<li>Drawing and text overlay</li>
</ul>
<p><strong>scikit-image</strong>: Provides sophisticated algorithms for image processing research:</p>
<ul>
<li>Advanced segmentation (watershed, active contours)</li>
<li>Feature extraction (texture analysis, HOG descriptors)</li>
<li>Morphological operations</li>
<li>Image restoration and denoising</li>
<li>Geometric transformations</li>
</ul>
<p><strong>Albumentations</strong>: State-of-the-art data augmentation library offering 70+ transformation techniques optimized for speed, crucial for training robust models on limited datasets.</p>
</section>
<section id="machine-learning-libraries" class="level3">
<h3 class="anchored" data-anchor-id="machine-learning-libraries" id="machine-learning-libraries">Machine Learning Libraries</h3>
<p><strong>scikit-learn</strong>: The go-to library for traditional machine learning, offering:</p>
<ul>
<li>Classification algorithms (SVM, Random Forests, Gradient Boosting)</li>
<li>Clustering methods (K-means, DBSCAN, hierarchical clustering)</li>
<li>Dimensionality reduction (PCA, t-SNE; UMAP is provided by the separate umap-learn package)</li>
<li>Model evaluation and cross-validation tools</li>
<li>Feature engineering utilities</li>
</ul>
<p><strong>NumPy &amp; Pandas</strong>: Form the foundation of data manipulation:</p>
<ul>
<li>NumPy provides efficient array operations and linear algebra</li>
<li>Pandas excels at structured data handling and preprocessing</li>
<li>Both integrate seamlessly with all ML frameworks</li>
</ul>
<p><strong>Matplotlib &amp; Seaborn</strong>: Visualization libraries essential for:</p>
<ul>
<li>Exploring datasets and distributions</li>
<li>Visualizing model predictions and errors</li>
<li>Creating publication-quality figures</li>
<li>Understanding feature importance</li>
</ul>
</section>
</section>
<section id="practical-use-cases" class="level2">
<h2 class="anchored" data-anchor-id="practical-use-cases" id="practical-use-cases">Practical Use Cases</h2>
<p><strong>Image Classification</strong>: Building models to categorize images into predefined classes using CNNs like ResNet, EfficientNet, or Vision Transformers. Python’s frameworks make transfer learning straightforward, allowing practitioners to fine-tune pre-trained models on custom datasets with minimal code.</p>
<p><strong>Object Detection</strong>: Implementing real-time detection systems using architectures like YOLO, Faster R-CNN, or RetinaNet. Libraries like Detectron2 provide production-ready implementations with extensive customization options.</p>
<p><strong>Semantic Segmentation</strong>: Creating pixel-level predictions for medical imaging, autonomous vehicles, or satellite imagery with architectures such as U-Net and DeepLab. Mask R-CNN is used when individual object instances must also be separated (instance segmentation).</p>
<p><strong>Generative Models</strong>: Developing GANs, VAEs, and diffusion models for image synthesis, style transfer, and data augmentation. PyTorch’s flexibility makes implementing complex generator-discriminator architectures manageable.</p>
<p><strong>Natural Language Processing</strong>: Building transformers, BERT models, and large language models with the Hugging Face Transformers library, which has become the industry standard for NLP tasks.</p>
<p><strong>Time Series Analysis</strong>: Applying LSTMs, Transformers, and traditional statistical methods for forecasting, anomaly detection, and pattern recognition in temporal data.</p>
</section>
<section id="code-example" class="level2">
<h2 class="anchored" data-anchor-id="code-example" id="code-example">Code Example</h2>
<div id="d3bdc53f" class="cell" data-execution_count="1">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb1"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torchvision.models <span class="im">as</span> models</span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torchvision.transforms <span class="im">as</span> transforms</span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> PIL <span class="im">import</span> Image</span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Load ResNet-50 with its default ImageNet weights (torchvision 0.13+ weights API)</span></span>
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> models.resnet50(weights<span class="op">=</span>models.ResNet50_Weights.DEFAULT)</span>
<span id="cb1-9"><a href="#cb1-9" aria-hidden="true" tabindex="-1"></a>model.<span class="bu">eval</span>()</span>
<span id="cb1-10"><a href="#cb1-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-11"><a href="#cb1-11" aria-hidden="true" tabindex="-1"></a><span class="co"># Define image preprocessing pipeline</span></span>
<span id="cb1-12"><a href="#cb1-12" aria-hidden="true" tabindex="-1"></a>preprocess <span class="op">=</span> transforms.Compose([</span>
<span id="cb1-13"><a href="#cb1-13" aria-hidden="true" tabindex="-1"></a>    transforms.Resize(<span class="dv">256</span>),</span>
<span id="cb1-14"><a href="#cb1-14" aria-hidden="true" tabindex="-1"></a>    transforms.CenterCrop(<span class="dv">224</span>),</span>
<span id="cb1-15"><a href="#cb1-15" aria-hidden="true" tabindex="-1"></a>    transforms.ToTensor(),</span>
<span id="cb1-16"><a href="#cb1-16" aria-hidden="true" tabindex="-1"></a>    transforms.Normalize(mean<span class="op">=</span>[<span class="fl">0.485</span>, <span class="fl">0.456</span>, <span class="fl">0.406</span>], </span>
<span id="cb1-17"><a href="#cb1-17" aria-hidden="true" tabindex="-1"></a>                         std<span class="op">=</span>[<span class="fl">0.229</span>, <span class="fl">0.224</span>, <span class="fl">0.225</span>]),</span>
<span id="cb1-18"><a href="#cb1-18" aria-hidden="true" tabindex="-1"></a>])</span>
<span id="cb1-19"><a href="#cb1-19" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-20"><a href="#cb1-20" aria-hidden="true" tabindex="-1"></a><span class="co"># Load image as 3-channel RGB (handles grayscale/RGBA files) and preprocess</span></span>
<span id="cb1-21"><a href="#cb1-21" aria-hidden="true" tabindex="-1"></a>img <span class="op">=</span> Image.<span class="bu">open</span>(<span class="st">'image.jpg'</span>).convert(<span class="st">'RGB'</span>)</span>
<span id="cb1-22"><a href="#cb1-22" aria-hidden="true" tabindex="-1"></a>img_tensor <span class="op">=</span> preprocess(img).unsqueeze(<span class="dv">0</span>)</span>
<span id="cb1-23"><a href="#cb1-23" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-24"><a href="#cb1-24" aria-hidden="true" tabindex="-1"></a><span class="co"># Perform inference</span></span>
<span id="cb1-25"><a href="#cb1-25" aria-hidden="true" tabindex="-1"></a><span class="cf">with</span> torch.no_grad():</span>
<span id="cb1-26"><a href="#cb1-26" aria-hidden="true" tabindex="-1"></a>    output <span class="op">=</span> model(img_tensor)</span>
<span id="cb1-27"><a href="#cb1-27" aria-hidden="true" tabindex="-1"></a>    probabilities <span class="op">=</span> torch.nn.functional.softmax(output[<span class="dv">0</span>], dim<span class="op">=</span><span class="dv">0</span>)</span>
<span id="cb1-28"><a href="#cb1-28" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb1-29"><a href="#cb1-29" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Top prediction (ImageNet class index): </span><span class="sc">{</span>probabilities<span class="sc">.</span>argmax()<span class="sc">.</span>item()<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
</div>
</section>
<section id="performance-considerations" class="level2">
<h2 class="anchored" data-anchor-id="performance-considerations" id="performance-considerations">Performance Considerations</h2>
<p>While Python excels in development speed, raw computational performance comes primarily from underlying C/C++ implementations in libraries like NumPy, PyTorch, and TensorFlow. For production systems requiring maximum performance:</p>
<ul>
<li>Use compiled extensions (Cython, Numba)</li>
<li>Leverage GPU acceleration through CUDA</li>
<li>Optimize model architectures with quantization and pruning</li>
<li>Consider exporting models to TorchScript or ONNX so they can run in optimized inference runtimes</li>
</ul>
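<p>The point about C/C++ back ends can be seen in a small sketch that writes the same per-pixel normalization two ways: as a pure-Python loop and as a NumPy expression. Only the results are compared. The printed timings depend on hardware and no specific numbers are claimed, although the vectorized version is typically much faster because its loop runs in compiled code. The function names are hypothetical.</p>

```python
import time

import numpy as np


def normalize_loop(pixels, mean, std):
    """Pure-Python loop: each element is processed by the interpreter."""
    return [(p - mean) / std for p in pixels]


def normalize_vectorized(pixels, mean, std):
    """NumPy: the same arithmetic runs in compiled C loops over the whole array."""
    return (pixels - mean) / std


rng = np.random.default_rng(0)
pixels = rng.random(200_000)   # stand-in for one flattened image channel in [0, 1)
mean, std = 0.485, 0.229       # ImageNet red-channel statistics used in the example above

t0 = time.perf_counter()
slow = normalize_loop(pixels.tolist(), mean, std)
t1 = time.perf_counter()
fast = normalize_vectorized(pixels, mean, std)
t2 = time.perf_counter()

print(f"Python loop: {t1 - t0:.4f} s, NumPy: {t2 - t1:.4f} s")
```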
</section>
<section id="when-to-choose-python" class="level2">
<h2 class="anchored" data-anchor-id="when-to-choose-python" id="when-to-choose-python">When to Choose Python</h2>
<p>Python is the optimal choice when:</p>
<ul>
<li>Rapid prototyping and experimentation are priorities</li>
<li>Leveraging pre-trained models and established architectures</li>
<li>Working with a team of data scientists and researchers</li>
<li>Integrating with data processing pipelines</li>
<li>Building end-to-end ML applications with web frameworks (Flask, FastAPI)</li>
<li>Prioritizing development time over raw execution speed</li>
</ul>
</section>
</section>
<section id="c-in-cv-ml" class="level1">
<h1>C++ in CV &amp; ML</h1>
<section id="overview-1" class="level2">
<h2 class="anchored" data-anchor-id="overview-1" id="overview-1">Overview</h2>
<p>C++ serves as the high-performance backbone of computer vision and machine learning systems. While less common for model development, it’s essential for production deployments, embedded systems, and applications requiring real-time performance with minimal latency.</p>
</section>
<section id="key-strengths-1" class="level2">
<h2 class="anchored" data-anchor-id="key-strengths-1" id="key-strengths-1">Key Strengths</h2>
<p><strong>Unmatched Performance</strong>: C++ provides direct memory control, zero-overhead abstractions, and compilation to native machine code, enabling the fastest possible execution speeds for CV and ML workloads.</p>
<p><strong>Low-Level Control</strong>: Fine-grained management of memory allocation, threading, and hardware resources enables use-case-specific optimizations that are difficult or impossible to achieve from higher-level languages.</p>
<p><strong>Cross-Platform Deployment</strong>: C++ code compiles to native binaries for any platform, making it ideal for embedded systems, mobile devices, and edge computing scenarios where Python runtimes may be impractical.</p>
<p><strong>Industry Standard</strong>: Most production computer vision systems in robotics, autonomous vehicles, gaming, and AR/VR rely on C++ for their performance-critical components.</p>
</section>
<section id="essential-libraries-frameworks-1" class="level2">
<h2 class="anchored" data-anchor-id="essential-libraries-frameworks-1" id="essential-libraries-frameworks-1">Essential Libraries &amp; Frameworks</h2>
<section id="computer-vision" class="level3">
<h3 class="anchored" data-anchor-id="computer-vision" id="computer-vision">Computer Vision</h3>
<p><strong>OpenCV</strong>: Because OpenCV is implemented in C++, its native interface avoids binding overhead and provides the best performance for:</p>
<ul>
<li>Real-time video processing pipelines</li>
<li>Camera interface and hardware acceleration</li>
<li>GPU-accelerated operations via CUDA and OpenCL</li>
<li>Integration with specialized hardware (Intel RealSense, NVIDIA Jetson)</li>
<li>Custom algorithm implementation with full control</li>
</ul>
<p><strong>Dlib</strong>: A sophisticated C++ library excelling in:</p>
<ul>
<li>Face detection and landmark localization</li>
<li>Object tracking algorithms</li>
<li>Optimization routines for machine learning</li>
<li>Image processing utilities</li>
<li>Shape prediction models</li>
</ul>
<p><strong>Point Cloud Library (PCL)</strong>: Specialized for 3D computer vision:</p>
<ul>
<li>Point cloud processing and filtering</li>
<li>3D feature extraction and registration</li>
<li>Surface reconstruction and segmentation</li>
<li>Integration with depth sensors and LiDAR</li>
<li>Essential for robotics and autonomous systems</li>
</ul>
</section>
<section id="deep-learning" class="level3">
<h3 class="anchored" data-anchor-id="deep-learning" id="deep-learning">Deep Learning</h3>
<p><strong>LibTorch</strong>: PyTorch’s C++ API enables deployment of PyTorch models in production C++ applications:</p>
<ul>
<li>Load and run TorchScript models</li>
<li>Full computational graph control</li>
<li>Custom operator implementation</li>
<li>Integration with existing C++ codebases</li>
<li>Mobile deployment support</li>
</ul>
<p><strong>TensorFlow C++ API</strong>: Provides production-grade inference capabilities:</p>
<ul>
<li>Model serving and optimization</li>
<li>Hardware acceleration support</li>
<li>Custom operation implementation</li>
<li>Integration with TensorFlow ecosystem</li>
</ul>
<p><strong>ONNX Runtime</strong>: Cross-framework inference engine offering:</p>
<ul>
<li>Optimized execution for ONNX models</li>
<li>Hardware-specific acceleration (CPU, GPU, NPU)</li>
<li>Quantization and optimization tools</li>
<li>Support for models from PyTorch, TensorFlow, and others</li>
</ul>
<p><strong>Caffe</strong>: One of the earliest deep learning frameworks; now largely superseded, but still found in some legacy production systems:</p>
<ul>
<li>Efficient CNN implementation</li>
<li>Model Zoo with pre-trained networks</li>
<li>Focus on vision tasks</li>
<li>Mature and stable codebase</li>
</ul>
<p><strong>TensorRT</strong>: NVIDIA’s inference optimization engine:</p>
<ul>
<li>Layer fusion and kernel optimization</li>
<li>Reduced precision inference (INT8, FP16)</li>
<li>Platform-specific tuning for NVIDIA GPUs</li>
<li>Substantially faster inference than running the unoptimized model in its training framework, with speedups that vary by model, precision, and GPU</li>
</ul>
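<p>To make "reduced precision" concrete, here is a minimal NumPy sketch of symmetric per-tensor INT8 quantization and dequantization. It illustrates the general idea only and does not reproduce how TensorRT selects quantization scales during calibration. The helper names <code>quantize_int8</code> and <code>dequantize</code> are hypothetical.</p>

```python
import numpy as np


def quantize_int8(x):
    """Symmetric per-tensor INT8 quantization: x is approximated by scale * q."""
    scale = float(np.abs(x).max()) / 127.0         # assumes x is not all zeros
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale


def dequantize(q, scale):
    return q.astype(np.float32) * scale


rng = np.random.default_rng(0)
weights = rng.standard_normal(1000).astype(np.float32)
q, scale = quantize_int8(weights)
recovered = dequantize(q, scale)
max_err = float(np.abs(weights - recovered).max())

# Rounding error is at most half a quantization step (plus float32 noise).
print(q.dtype, max_err <= scale / 2 + 1e-5)  # int8 True
```

<p>Each value is stored in one byte instead of four, at the cost of a bounded rounding error. Deciding where that error is acceptable is what calibration tools are for.</p>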
</section>
<section id="machine-learning" class="level3">
<h3 class="anchored" data-anchor-id="machine-learning" id="machine-learning">Machine Learning</h3>
<p><strong>mlpack</strong>: Fast machine learning library implementing:</p>
<ul>
<li>Classification and regression algorithms</li>
<li>Clustering methods</li>
<li>Dimensionality reduction</li>
<li>Efficient implementations with template metaprogramming</li>
</ul>
<p><strong>Eigen</strong>: Header-only C++ template library for linear algebra, used inside many ML frameworks (TensorFlow among them):</p>
<ul>
<li>Matrix and vector operations</li>
<li>Solvers for linear systems</li>
<li>Decompositions and eigenvalue computations</li>
<li>SIMD optimization and vectorization</li>
</ul>
<p><strong>Shark</strong>: Comprehensive machine learning library with:</p>
<ul>
<li>Supervised and unsupervised learning algorithms</li>
<li>Neural network implementations</li>
<li>Evolutionary algorithms</li>
<li>Optimization routines</li>
</ul>
</section>
</section>
<section id="practical-use-cases-1" class="level2">
<h2 class="anchored" data-anchor-id="practical-use-cases-1" id="practical-use-cases-1">Practical Use Cases</h2>
<p><strong>Real-Time Computer Vision Systems</strong>: Building autonomous vehicle perception, industrial quality control, or robotics systems requiring processing at 30+ FPS with minimal latency. C++ enables tight integration with sensors and actuators.</p>
<p><strong>Edge AI Deployment</strong>: Deploying ML models on resource-constrained devices like Raspberry Pi, NVIDIA Jetson, or custom embedded hardware where memory footprint and power consumption are critical.</p>
<p><strong>High-Performance Inference Servers</strong>: Creating production inference systems handling thousands of requests per second, where every millisecond of latency matters for user experience or business metrics.</p>
<p><strong>Game AI &amp; Graphics</strong>: Implementing computer vision for gaming (player tracking, gesture recognition) or augmented reality applications requiring integration with game engines and rendering pipelines.</p>
<p><strong>Medical Imaging Systems</strong>: Developing FDA-approved medical devices or PACS systems requiring deterministic performance, regulatory compliance, and integration with specialized medical hardware.</p>
<p><strong>Custom Hardware Acceleration</strong>: Writing CUDA kernels or FPGA implementations for specialized computer vision algorithms, reaching performance that general-purpose frameworks often cannot match.</p>
</section>
<section id="code-example-1" class="level2">
<h2 class="anchored" data-anchor-id="code-example-1" id="code-example-1">Code Example</h2>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb2"><pre class="sourceCode cpp code-with-copy"><code class="sourceCode cpp"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="pp">#include </span><span class="im">&lt;opencv2/opencv.hpp&gt;</span></span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a><span class="pp">#include </span><span class="im">&lt;torch/script.h&gt;</span></span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a><span class="pp">#include </span><span class="im">&lt;iostream&gt;</span></span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a><span class="pp">#include </span><span class="im">&lt;vector&gt;</span></span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a><span class="dt">int</span> main<span class="op">()</span> <span class="op">{</span></span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a>    <span class="co">// Load TorchScript model</span></span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a>    torch<span class="op">::</span>jit<span class="op">::</span>script<span class="op">::</span>Module model<span class="op">;</span></span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a>    <span class="cf">try</span> <span class="op">{</span></span>
<span id="cb2-10"><a href="#cb2-10" aria-hidden="true" tabindex="-1"></a>        model <span class="op">=</span> torch<span class="op">::</span>jit<span class="op">::</span>load<span class="op">(</span><span class="st">"model.pt"</span><span class="op">);</span></span>
<span id="cb2-11"><a href="#cb2-11" aria-hidden="true" tabindex="-1"></a>        model<span class="op">.</span>eval<span class="op">();</span></span>
<span id="cb2-12"><a href="#cb2-12" aria-hidden="true" tabindex="-1"></a>    <span class="op">}</span> <span class="cf">catch</span> <span class="op">(</span><span class="at">const</span> c10<span class="op">::</span>Error<span class="op">&amp;</span> e<span class="op">)</span> <span class="op">{</span></span>
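<span id="cb2-13"><a href="#cb2-13" aria-hidden="true" tabindex="-1"></a>        <span class="bu">std::</span>cerr <span class="op">&lt;&lt;</span> <span class="st">"Error loading model: "</span> <span class="op">&lt;&lt;</span> e<span class="op">.</span>what<span class="op">()</span> <span class="op">&lt;&lt;</span> <span class="st">"</span><span class="sc">\n</span><span class="st">"</span><span class="op">;</span></span>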
<span id="cb2-14"><a href="#cb2-14" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="op">-</span><span class="dv">1</span><span class="op">;</span></span>
<span id="cb2-15"><a href="#cb2-15" aria-hidden="true" tabindex="-1"></a>    <span class="op">}</span></span>
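<span id="cb2-16"><a href="#cb2-16" aria-hidden="true" tabindex="-1"></a>    torch<span class="op">::</span>NoGradGuard no_grad<span class="op">;</span>  <span class="co">// disable autograd tracking for inference</span></span>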
<span id="cb2-17"><a href="#cb2-17" aria-hidden="true" tabindex="-1"></a>    <span class="co">// Open video capture</span></span>
<span id="cb2-18"><a href="#cb2-18" aria-hidden="true" tabindex="-1"></a>    cv<span class="op">::</span>VideoCapture cap<span class="op">(</span><span class="dv">0</span><span class="op">);</span></span>
<span id="cb2-19"><a href="#cb2-19" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> <span class="op">(!</span>cap<span class="op">.</span>isOpened<span class="op">())</span> <span class="op">{</span></span>
<span id="cb2-20"><a href="#cb2-20" aria-hidden="true" tabindex="-1"></a>        <span class="bu">std::</span>cerr <span class="op">&lt;&lt;</span> <span class="st">"Error opening camera</span><span class="sc">\n</span><span class="st">"</span><span class="op">;</span></span>
<span id="cb2-21"><a href="#cb2-21" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="op">-</span><span class="dv">1</span><span class="op">;</span></span>
<span id="cb2-22"><a href="#cb2-22" aria-hidden="true" tabindex="-1"></a>    <span class="op">}</span></span>
<span id="cb2-23"><a href="#cb2-23" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-24"><a href="#cb2-24" aria-hidden="true" tabindex="-1"></a>    cv<span class="op">::</span>Mat frame<span class="op">;</span></span>
<span id="cb2-25"><a href="#cb2-25" aria-hidden="true" tabindex="-1"></a>    <span class="cf">while</span> <span class="op">(</span><span class="kw">true</span><span class="op">)</span> <span class="op">{</span></span>
<span id="cb2-26"><a href="#cb2-26" aria-hidden="true" tabindex="-1"></a>        cap <span class="op">&gt;&gt;</span> frame<span class="op">;</span></span>
<span id="cb2-27"><a href="#cb2-27" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="op">(</span>frame<span class="op">.</span>empty<span class="op">())</span> <span class="cf">break</span><span class="op">;</span></span>
<span id="cb2-28"><a href="#cb2-28" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-29"><a href="#cb2-29" aria-hidden="true" tabindex="-1"></a>        <span class="co">// Preprocess image</span></span>
<span id="cb2-30"><a href="#cb2-30" aria-hidden="true" tabindex="-1"></a>        cv<span class="op">::</span>Mat rgb<span class="op">;</span></span>
<span id="cb2-31"><a href="#cb2-31" aria-hidden="true" tabindex="-1"></a>        cv<span class="op">::</span>cvtColor<span class="op">(</span>frame<span class="op">,</span> rgb<span class="op">,</span> cv<span class="op">::</span>COLOR_BGR2RGB<span class="op">);</span></span>
<span id="cb2-32"><a href="#cb2-32" aria-hidden="true" tabindex="-1"></a>        cv<span class="op">::</span>resize<span class="op">(</span>rgb<span class="op">,</span> rgb<span class="op">,</span> cv<span class="op">::</span>Size<span class="op">(</span><span class="dv">224</span><span class="op">,</span> <span class="dv">224</span><span class="op">));</span></span>
<span id="cb2-33"><a href="#cb2-33" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-34"><a href="#cb2-34" aria-hidden="true" tabindex="-1"></a>        <span class="co">// Convert to a float NCHW tensor scaled to [0, 1]; add mean/std normalization here if the model expects it</span></span>
<span id="cb2-35"><a href="#cb2-35" aria-hidden="true" tabindex="-1"></a>        torch<span class="op">::</span>Tensor tensor <span class="op">=</span> torch<span class="op">::</span>from_blob<span class="op">(</span></span>
<span id="cb2-36"><a href="#cb2-36" aria-hidden="true" tabindex="-1"></a>            rgb<span class="op">.</span>data<span class="op">,</span> <span class="op">{</span><span class="dv">1</span><span class="op">,</span> <span class="dv">224</span><span class="op">,</span> <span class="dv">224</span><span class="op">,</span> <span class="dv">3</span><span class="op">},</span> torch<span class="op">::</span>kByte</span>
<span id="cb2-37"><a href="#cb2-37" aria-hidden="true" tabindex="-1"></a>        <span class="op">).</span>permute<span class="op">({</span><span class="dv">0</span><span class="op">,</span> <span class="dv">3</span><span class="op">,</span> <span class="dv">1</span><span class="op">,</span> <span class="dv">2</span><span class="op">}).</span>to<span class="op">(</span>torch<span class="op">::</span>kFloat32<span class="op">)</span> <span class="op">/</span> <span class="fl">255.0</span><span class="op">;</span></span>
<span id="cb2-38"><a href="#cb2-38" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-39"><a href="#cb2-39" aria-hidden="true" tabindex="-1"></a>        <span class="co">// Inference</span></span>
<span id="cb2-40"><a href="#cb2-40" aria-hidden="true" tabindex="-1"></a>        <span class="kw">auto</span> output <span class="op">=</span> model<span class="op">.</span>forward<span class="op">({</span>tensor<span class="op">}).</span>toTensor<span class="op">();</span></span>
<span id="cb2-41"><a href="#cb2-41" aria-hidden="true" tabindex="-1"></a>        <span class="kw">auto</span> prediction <span class="op">=</span> output<span class="op">.</span>argmax<span class="op">(</span><span class="dv">1</span><span class="op">).</span>item<span class="op">&lt;</span><span class="dt">int</span><span class="op">&gt;();</span></span>
<span id="cb2-42"><a href="#cb2-42" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-43"><a href="#cb2-43" aria-hidden="true" tabindex="-1"></a>        <span class="co">// Display result</span></span>
<span id="cb2-44"><a href="#cb2-44" aria-hidden="true" tabindex="-1"></a>        cv<span class="op">::</span>putText<span class="op">(</span>frame<span class="op">,</span> <span class="st">"Class: "</span> <span class="op">+</span> <span class="bu">std::</span>to_string<span class="op">(</span>prediction<span class="op">),</span></span>
<span id="cb2-45"><a href="#cb2-45" aria-hidden="true" tabindex="-1"></a>                    cv<span class="op">::</span>Point<span class="op">(</span><span class="dv">10</span><span class="op">,</span> <span class="dv">30</span><span class="op">),</span> cv<span class="op">::</span>FONT_HERSHEY_SIMPLEX<span class="op">,</span></span>
<span id="cb2-46"><a href="#cb2-46" aria-hidden="true" tabindex="-1"></a>                    <span class="fl">1.0</span><span class="op">,</span> cv<span class="op">::</span>Scalar<span class="op">(</span><span class="dv">0</span><span class="op">,</span> <span class="dv">255</span><span class="op">,</span> <span class="dv">0</span><span class="op">),</span> <span class="dv">2</span><span class="op">);</span></span>
<span id="cb2-47"><a href="#cb2-47" aria-hidden="true" tabindex="-1"></a>        cv<span class="op">::</span>imshow<span class="op">(</span><span class="st">"Detection"</span><span class="op">,</span> frame<span class="op">);</span></span>
<span id="cb2-48"><a href="#cb2-48" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-49"><a href="#cb2-49" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="op">(</span>cv<span class="op">::</span>waitKey<span class="op">(</span><span class="dv">1</span><span class="op">)</span> <span class="op">==</span> <span class="dv">27</span><span class="op">)</span> <span class="cf">break</span><span class="op">;</span> <span class="co">// ESC to exit</span></span>
<span id="cb2-50"><a href="#cb2-50" aria-hidden="true" tabindex="-1"></a>    <span class="op">}</span></span>
<span id="cb2-51"><a href="#cb2-51" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-52"><a href="#cb2-52" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> <span class="dv">0</span><span class="op">;</span></span>
<span id="cb2-53"><a href="#cb2-53" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span></code></pre></div></div>
</section>
<section id="performance-optimization-techniques" class="level2">
<h2 class="anchored" data-anchor-id="performance-optimization-techniques" id="performance-optimization-techniques">Performance Optimization Techniques</h2>
<p><strong>SIMD Vectorization</strong>: Use SSE, AVX, or NEON instructions to process several pixels or matrix elements per instruction. For float32 data, a 128-bit register (SSE, NEON) holds 4 values and a 512-bit register (AVX-512) holds 16, so 4-16x is the theoretical ceiling; memory bandwidth often keeps real-world gains lower.</p>
<p><strong>Multi-threading</strong>: Parallelize CPU-bound work with OpenMP, TBB, or std::thread, distributing the workload across the available cores.</p>
<p><strong>GPU Acceleration</strong>: Write CUDA kernels for NVIDIA GPUs, or use OpenCL for cross-platform acceleration, to move compute-intensive operations onto massively parallel hardware.</p>
<p><strong>Memory Management</strong>: Minimize allocations in hot loops by preallocating buffers or pooling objects, use move semantics to avoid needless copies, and keep data in contiguous layouts to improve cache locality.</p>
<p><strong>Compiler Optimizations</strong>: Enable aggressive optimization flags (-O3, -march=native) and profile-guided optimization to extract maximum performance. Note that -march=native targets the build machine’s CPU, so the resulting binary may not run on older processors.</p>
</section>
<section id="when-to-choose-c" class="level2">
<h2 class="anchored" data-anchor-id="when-to-choose-c" id="when-to-choose-c">When to Choose C++</h2>
<p>C++ is the optimal choice when:</p>
<ul>
<li>Real-time performance with strict latency requirements is mandatory</li>
<li>Deploying to embedded systems or edge devices</li>
<li>Building production inference systems at scale</li>
<li>Integrating with existing C++ codebases or game engines</li>
<li>Developing for platforms without Python support</li>
<li>Requiring maximum control over hardware resources</li>
<li>Building commercial products where runtime licensing matters</li>
<li>Working with specialized hardware or custom accelerators</li>
</ul>
</section>
</section>
<section id="javascript-in-cv-ml" class="level1">
<h1>JavaScript in CV &amp; ML</h1>
<section id="overview-2" class="level2">
<h2 class="anchored" data-anchor-id="overview-2" id="overview-2">Overview</h2>
<p>JavaScript has emerged as a surprisingly capable platform for machine learning and computer vision, particularly for browser-based applications and interactive demos. While not matching Python’s ecosystem or C++’s performance, JavaScript’s ubiquity and zero-installation deployment make it valuable for specific use cases.</p>
</section>
<section id="key-strengths-2" class="level2">
<h2 class="anchored" data-anchor-id="key-strengths-2" id="key-strengths-2">Key Strengths</h2>
<p><strong>Browser-Native Execution</strong>: JavaScript runs directly in web browsers without installation, enabling instant deployment of ML models to billions of devices worldwide through simple URLs.</p>
<p><strong>Privacy-Preserving Computing</strong>: Client-side inference keeps sensitive data on user devices, crucial for healthcare, finance, or personal applications where data privacy is paramount.</p>
<p><strong>Interactive Experiences</strong>: JavaScript’s event-driven nature and DOM manipulation capabilities enable rich, responsive interfaces that react instantly to ML model predictions.</p>
<p><strong>Cross-Platform Reach</strong>: A single JavaScript codebase runs on desktops, mobile devices, and tablets through browsers, eliminating platform-specific development and distribution challenges.</p>
<p><strong>Server-Side Capabilities</strong>: Node.js enables JavaScript ML applications on servers, allowing full-stack JavaScript development with shared code between client and server.</p>
</section>
<section id="essential-libraries-frameworks-2" class="level2">
<h2 class="anchored" data-anchor-id="essential-libraries-frameworks-2" id="essential-libraries-frameworks-2">Essential Libraries &amp; Frameworks</h2>
<section id="deep-learning-1" class="level3">
<h3 class="anchored" data-anchor-id="deep-learning-1" id="deep-learning-1">Deep Learning</h3>
<p><strong>TensorFlow.js</strong>: The most comprehensive JavaScript ML library, offering:</p>
<ul>
<li>Pre-trained models for common tasks (image classification, object detection, pose estimation)</li>
<li>Model conversion from Python TensorFlow/Keras</li>
<li>Training capabilities directly in the browser</li>
<li>WebGL acceleration for GPU performance</li>
<li>Node.js backend for server-side execution</li>
<li>Transfer learning and fine-tuning support</li>
</ul>
<p><strong>ONNX.js</strong>: Microsoft’s runtime for ONNX models (now deprecated in favor of its successor, ONNX Runtime Web), providing:</p>
<ul>
<li>Cross-framework model support</li>
<li>WebGL and WebAssembly backends</li>
<li>Optimized inference performance</li>
<li>Broad model compatibility</li>
</ul>
<p><strong>Brain.js</strong>: Lightweight neural network library ideal for:</p>
<ul>
<li>Simple neural networks without heavy dependencies</li>
<li>Recurrent networks (LSTM, GRU)</li>
<li>Educational purposes and prototyping</li>
<li>Projects where TensorFlow.js is overkill</li>
</ul>
<p><strong>ml5.js</strong>: Built on TensorFlow.js, ml5.js provides:</p>
<ul>
<li>Beginner-friendly API for common tasks</li>
<li>Pre-trained models (PoseNet, BodyPix, FaceApi)</li>
<li>Extensive documentation and examples</li>
<li>Focus on creative coding and art projects</li>
</ul>
</section>
<section id="computer-vision-1" class="level3">
<h3 class="anchored" data-anchor-id="computer-vision-1" id="computer-vision-1">Computer Vision</h3>
<p><strong>OpenCV.js</strong>: WebAssembly port of OpenCV offering:</p>
<ul>
<li>Core image processing functions</li>
<li>Feature detection and matching</li>
<li>Video analysis capabilities</li>
<li>Camera access through WebRTC</li>
<li>Near-native performance for many operations</li>
</ul>
<p><strong>Tracking.js</strong>: Specialized library for:</p>
<ul>
<li>Face and object tracking in video</li>
<li>Color tracking and detection</li>
<li>Custom tracker implementation</li>
<li>Lightweight and focused functionality</li>
</ul>
<p><strong>PixiJS</strong>: While primarily a rendering engine, PixiJS provides:</p>
<ul>
<li>High-performance 2D graphics with WebGL</li>
<li>Image filters and effects</li>
<li>Real-time image manipulation</li>
<li>Integration with ML models for visualization</li>
</ul>
</section>
</section>
<section id="practical-use-cases-2" class="level2">
<h2 class="anchored" data-anchor-id="practical-use-cases-2" id="practical-use-cases-2">Practical Use Cases</h2>
<p><strong>Interactive ML Demos</strong>: Creating educational visualizations and interactive demonstrations where users can instantly experiment with models, adjust parameters, and see results without installation barriers.</p>
<p><strong>Real-Time Webcam Applications</strong>: Building accessible applications for pose estimation, face filters, gesture recognition, or virtual try-on experiences that run entirely in the browser with no server required.</p>
<p><strong>Privacy-Sensitive Applications</strong>: Developing healthcare diagnostic tools, personal finance analyzers, or document processing systems where data never leaves the user’s device, which simplifies compliance with privacy regulations.</p>
<p><strong>Progressive Web Apps</strong>: Creating installable web applications with offline ML capabilities, leveraging service workers to cache models and enable functionality without internet connectivity.</p>
<p><strong>IoT and Edge Browsers</strong>: Deploying ML models to embedded devices running lightweight browsers, enabling intelligent processing on resource-constrained hardware.</p>
<p><strong>A/B Testing and Experimentation</strong>: Rapidly deploying and testing different model versions to users without app store approval processes, enabling quick iteration based on real-world feedback.</p>
</section>
<section id="code-example-2" class="level2">
<h2 class="anchored" data-anchor-id="code-example-2" id="code-example-2">Code Example</h2>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb3"><pre class="sourceCode javascript code-with-copy"><code class="sourceCode javascript"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="co">// Assumes an ES module (for top-level await) with TensorFlow.js and the</span></span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a><span class="co">// @tensorflow-models/mobilenet package installed</span></span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> <span class="op">*</span> <span class="im">as</span> tf <span class="im">from</span> <span class="st">'@tensorflow/tfjs'</span><span class="op">;</span></span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> <span class="op">*</span> <span class="im">as</span> mobilenet <span class="im">from</span> <span class="st">'@tensorflow-models/mobilenet'</span><span class="op">;</span></span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a><span class="co">// Load MobileNet model for image classification</span></span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a><span class="kw">const</span> model <span class="op">=</span> <span class="cf">await</span> mobilenet<span class="op">.</span><span class="fu">load</span>()<span class="op">;</span></span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a><span class="co">// Get video stream from webcam</span></span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a><span class="kw">const</span> video <span class="op">=</span> <span class="bu">document</span><span class="op">.</span><span class="fu">getElementById</span>(<span class="st">'webcam'</span>)<span class="op">;</span></span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a><span class="kw">const</span> stream <span class="op">=</span> <span class="cf">await</span> <span class="bu">navigator</span><span class="op">.</span><span class="at">mediaDevices</span><span class="op">.</span><span class="fu">getUserMedia</span>({ <span class="dt">video</span><span class="op">:</span> <span class="kw">true</span> })<span class="op">;</span></span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a>video<span class="op">.</span><span class="at">srcObject</span> <span class="op">=</span> stream<span class="op">;</span></span>
<span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a><span class="co">// Classify frames continuously</span></span>
<span id="cb3-15"><a href="#cb3-15" aria-hidden="true" tabindex="-1"></a><span class="kw">async</span> <span class="kw">function</span> <span class="fu">classifyFrame</span>() {</span>
<span id="cb3-16"><a href="#cb3-16" aria-hidden="true" tabindex="-1"></a>    <span class="co">// classify() returns the top 3 predictions by default</span></span>
<span id="cb3-17"><a href="#cb3-17" aria-hidden="true" tabindex="-1"></a>    <span class="kw">const</span> predictions <span class="op">=</span> <span class="cf">await</span> model<span class="op">.</span><span class="fu">classify</span>(video)<span class="op">;</span></span>
<span id="cb3-18"><a href="#cb3-18" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-19"><a href="#cb3-19" aria-hidden="true" tabindex="-1"></a>    <span class="co">// Display top 3 predictions</span></span>
<span id="cb3-20"><a href="#cb3-20" aria-hidden="true" tabindex="-1"></a>    <span class="kw">const</span> resultsDiv <span class="op">=</span> <span class="bu">document</span><span class="op">.</span><span class="fu">getElementById</span>(<span class="st">'results'</span>)<span class="op">;</span></span>
<span id="cb3-21"><a href="#cb3-21" aria-hidden="true" tabindex="-1"></a>    resultsDiv<span class="op">.</span><span class="at">innerHTML</span> <span class="op">=</span> predictions</span>
<span id="cb3-22"><a href="#cb3-22" aria-hidden="true" tabindex="-1"></a>        <span class="op">.</span><span class="fu">slice</span>(<span class="dv">0</span><span class="op">,</span> <span class="dv">3</span>)</span>
<span id="cb3-23"><a href="#cb3-23" aria-hidden="true" tabindex="-1"></a>        <span class="op">.</span><span class="fu">map</span>(p <span class="kw">=&gt;</span> <span class="vs">`</span><span class="sc">${</span>p<span class="op">.</span><span class="at">className</span><span class="sc">}</span><span class="vs">: </span><span class="sc">${</span>(p<span class="op">.</span><span class="at">probability</span> <span class="op">*</span> <span class="dv">100</span>)<span class="op">.</span><span class="fu">toFixed</span>(<span class="dv">2</span>)<span class="sc">}</span><span class="vs">%`</span>)</span>
<span id="cb3-24"><a href="#cb3-24" aria-hidden="true" tabindex="-1"></a>        <span class="op">.</span><span class="fu">join</span>(<span class="st">'&lt;br&gt;'</span>)<span class="op">;</span></span>
<span id="cb3-25"><a href="#cb3-25" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-26"><a href="#cb3-26" aria-hidden="true" tabindex="-1"></a>    <span class="fu">requestAnimationFrame</span>(classifyFrame)<span class="op">;</span></span>
<span id="cb3-27"><a href="#cb3-27" aria-hidden="true" tabindex="-1"></a>}</span>
<span id="cb3-28"><a href="#cb3-28" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-29"><a href="#cb3-29" aria-hidden="true" tabindex="-1"></a><span class="co">// Start classifying once the first frame is available, then start playback</span></span>
<span id="cb3-30"><a href="#cb3-30" aria-hidden="true" tabindex="-1"></a>video<span class="op">.</span><span class="fu">addEventListener</span>(<span class="st">'loadeddata'</span><span class="op">,</span> () <span class="kw">=&gt;</span> {</span>
<span id="cb3-31"><a href="#cb3-31" aria-hidden="true" tabindex="-1"></a>    <span class="fu">classifyFrame</span>()<span class="op">;</span></span>
<span id="cb3-32"><a href="#cb3-32" aria-hidden="true" tabindex="-1"></a>})<span class="op">;</span></span>
<span id="cb3-33"><a href="#cb3-33" aria-hidden="true" tabindex="-1"></a><span class="cf">await</span> video<span class="op">.</span><span class="fu">play</span>()<span class="op">;</span></span>
<span id="cb3-34"><a href="#cb3-34" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-35"><a href="#cb3-35" aria-hidden="true" tabindex="-1"></a><span class="co">// Custom model inference example with TensorFlow.js</span></span>
<span id="cb3-36"><a href="#cb3-36" aria-hidden="true" tabindex="-1"></a><span class="kw">async</span> <span class="kw">function</span> <span class="fu">runCustomModel</span>() {</span>
<span id="cb3-37"><a href="#cb3-37" aria-hidden="true" tabindex="-1"></a>    <span class="kw">const</span> customModel <span class="op">=</span> <span class="cf">await</span> tf<span class="op">.</span><span class="fu">loadLayersModel</span>(<span class="st">'model/model.json'</span>)<span class="op">;</span></span>
<span id="cb3-38"><a href="#cb3-38" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-39"><a href="#cb3-39" aria-hidden="true" tabindex="-1"></a>    <span class="kw">const</span> img <span class="op">=</span> <span class="bu">document</span><span class="op">.</span><span class="fu">getElementById</span>(<span class="st">'input-image'</span>)<span class="op">;</span></span>
<span id="cb3-40"><a href="#cb3-40" aria-hidden="true" tabindex="-1"></a>    <span class="co">// tf.tidy() disposes the intermediate tensors created during preprocessing</span></span>
<span id="cb3-41"><a href="#cb3-41" aria-hidden="true" tabindex="-1"></a>    <span class="kw">const</span> input <span class="op">=</span> tf<span class="op">.</span><span class="fu">tidy</span>(() <span class="kw">=&gt;</span> tf<span class="op">.</span><span class="at">browser</span><span class="op">.</span><span class="fu">fromPixels</span>(img)</span>
<span id="cb3-42"><a href="#cb3-42" aria-hidden="true" tabindex="-1"></a>        <span class="op">.</span><span class="fu">resizeNearestNeighbor</span>([<span class="dv">224</span><span class="op">,</span> <span class="dv">224</span>])</span>
<span id="cb3-43"><a href="#cb3-43" aria-hidden="true" tabindex="-1"></a>        <span class="op">.</span><span class="fu">expandDims</span>()</span>
<span id="cb3-44"><a href="#cb3-44" aria-hidden="true" tabindex="-1"></a>        <span class="op">.</span><span class="fu">cast</span>(<span class="st">'float32'</span>)</span>
<span id="cb3-45"><a href="#cb3-45" aria-hidden="true" tabindex="-1"></a>        <span class="op">.</span><span class="fu">div</span>(<span class="fl">255.0</span>))<span class="op">;</span></span>
<span id="cb3-46"><a href="#cb3-46" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-47"><a href="#cb3-47" aria-hidden="true" tabindex="-1"></a>    <span class="kw">const</span> output <span class="op">=</span> customModel<span class="op">.</span><span class="fu">predict</span>(input)<span class="op">;</span></span>
<span id="cb3-48"><a href="#cb3-48" aria-hidden="true" tabindex="-1"></a>    <span class="kw">const</span> predictions <span class="op">=</span> <span class="cf">await</span> output<span class="op">.</span><span class="fu">data</span>()<span class="op">;</span></span>
<span id="cb3-49"><a href="#cb3-49" aria-hidden="true" tabindex="-1"></a>    <span class="bu">console</span><span class="op">.</span><span class="fu">log</span>(<span class="st">'Predictions:'</span><span class="op">,</span> predictions)<span class="op">;</span></span>
<span id="cb3-50"><a href="#cb3-50" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-51"><a href="#cb3-51" aria-hidden="true" tabindex="-1"></a>    <span class="co">// Dispose the input and output tensors to free GPU memory</span></span>
<span id="cb3-52"><a href="#cb3-52" aria-hidden="true" tabindex="-1"></a>    tf<span class="op">.</span><span class="fu">dispose</span>([input<span class="op">,</span> output])<span class="op">;</span></span>
<span id="cb3-53"><a href="#cb3-53" aria-hidden="true" tabindex="-1"></a>}</span></code></pre></div></div>
</section>
<section id="performance-considerations-1" class="level2">
<h2 class="anchored" data-anchor-id="performance-considerations-1" id="performance-considerations-1">Performance Considerations</h2>
<p><strong>WebGL Acceleration</strong>: TensorFlow.js uses WebGL for GPU acceleration, which can bring many operations to within roughly 2-3x of native performance, depending on the model and device. Check that WebGL is available and fall back to the CPU backend when it is not.</p>
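<p>The fallback can be written as a small helper that tries backends in order of preference. This is a minimal sketch. It assumes a <code>setBackend</code> function that behaves like TensorFlow.js’s <code>tf.setBackend(name)</code>, which resolves to a boolean indicating whether the backend initialized. The helper name <code>selectBackend</code> and the preference order are illustrative choices and are not part of the library.</p>

```javascript
// Try backends in order of preference and return the name of the first one
// that initializes. `setBackend` is assumed to behave like TensorFlow.js's
// tf.setBackend(name): it resolves to true on success, false otherwise.
async function selectBackend(setBackend, preferred = ['webgl', 'wasm', 'cpu']) {
  for (const name of preferred) {
    try {
      if (await setBackend(name)) {
        return name;
      }
    } catch {
      // Treat an initialization error the same as an unavailable backend
    }
  }
  throw new Error(`No usable backend among: ${preferred.join(', ')}`);
}
```

<p>In a page that loads TensorFlow.js, you could call this as <code>await selectBackend((name) =&gt; tf.setBackend(name))</code>. The <code>'wasm'</code> entry only succeeds if the separate <code>@tensorflow/tfjs-backend-wasm</code> package has been imported.</p>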
<p><strong>Model Size Optimization</strong>: Reduce model size through quantization (e.g., storing float32 weights as uint8, a 4x reduction in weight storage), pruning unnecessary weights, and efficient architectures such as MobileNet or SqueezeNet. Smaller models download faster and use less memory.</p>
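<p>Affine quantization is one way to implement the float32-to-uint8 step. It maps the weight range [min, max] onto the 256 integer levels. This is a minimal per-tensor sketch of one common scheme, not the exact format used by the TensorFlow.js converter, and the function names are hypothetical.</p>

```javascript
// Affine (asymmetric) uint8 quantization: map the range [min, max] of a
// float32 weight array onto the 256 levels 0..255.
// A minimal per-tensor sketch; production converters may use other schemes.
function quantizeUint8(weights) {
  // reduce() avoids spreading huge arrays into Math.min/Math.max
  const min = weights.reduce((m, w) => Math.min(m, w), Infinity);
  const max = weights.reduce((m, w) => Math.max(m, w), -Infinity);
  // A constant tensor would give scale 0, so fall back to 1
  const scale = max > min ? (max - min) / 255 : 1;
  const q = Uint8Array.from(weights, (w) => Math.round((w - min) / scale));
  return { q, scale, min };
}

// Dequantize: w is approximately q * scale + min
function dequantizeUint8({ q, scale, min }) {
  return Float32Array.from(q, (v) => v * scale + min);
}
```

<p>Each weight now takes one byte instead of four, and the round-trip error per weight is at most half of <code>scale</code>. Per-tensor min/max scaling like this is sensitive to outliers, which is why some toolchains quantize per channel instead.</p>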
<p><strong>WebAssembly</strong>: For compute-heavy operations not suited to WebGL, WebAssembly provides near-native performance, particularly beneficial for OpenCV.js operations.</p>
<p><strong>Lazy Loading</strong>: Split large models into chunks and load only necessary components to improve initial page load time and perceived performance.</p>
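<p>One way to defer loading is to wrap each model component in a loader that runs only on first use. This is a minimal sketch. It assumes each component is loaded by a function that returns a promise (for example <code>tf.loadGraphModel(url)</code> in TensorFlow.js). The helper name <code>lazy</code> and the model path in the usage note are hypothetical.</p>

```javascript
// Wrap a loader so the component is fetched only on first use; concurrent
// callers share one in-flight promise. `loadFn` is any function returning a
// promise, e.g. () => tf.loadGraphModel(url) in TensorFlow.js.
function lazy(loadFn) {
  let promise = null;
  return function get() {
    if (promise === null) {
      // Promise.resolve().then() also turns a synchronous throw into a rejection
      promise = Promise.resolve()
        .then(() => loadFn())
        .catch((err) => {
          promise = null; // allow a retry after a failed load
          throw err;
        });
    }
    return promise;
  };
}
```

<p>For example, <code>const getDetector = lazy(() =&gt; tf.loadGraphModel('models/detector/model.json'))</code> downloads the detector only when a feature first calls <code>await getDetector()</code>. Because concurrent callers share one promise, they do not trigger duplicate downloads.</p>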
<p><strong>Web Workers</strong>: Move intensive computations to background threads to prevent blocking the main thread and maintain responsive user interfaces.</p>
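<p>When inference runs in a worker, each call becomes an asynchronous message exchange. The sketch below shows the main-thread side: a promise-based wrapper that matches replies to requests by id. It assumes a worker script that answers every request it receives. The script name <code>inference-worker.js</code>, the <code>{ id, payload }</code> / <code>{ id, result, error }</code> message shape, and the <code>infer()</code> function in the comment are all hypothetical.</p>

```javascript
// Main-thread client: send jobs to a Web Worker and resolve each promise
// when the reply with the matching id arrives.
// Assumed (hypothetical) worker side, in inference-worker.js:
//   self.onmessage = async ({ data: { id, payload } }) => {
//     try {
//       self.postMessage({ id, result: await infer(payload) }); // infer() is hypothetical
//     } catch (e) {
//       self.postMessage({ id, error: String(e) });
//     }
//   };
function createWorkerClient(worker) {
  let nextId = 0;
  const pending = new Map(); // id -> { resolve, reject }

  worker.addEventListener('message', (event) => {
    const { id, result, error } = event.data;
    const entry = pending.get(id);
    if (!entry) return; // not a reply to one of our requests
    pending.delete(id);
    if (error !== undefined) {
      entry.reject(new Error(error));
    } else {
      entry.resolve(result);
    }
  });

  // Returns a promise for the worker's result; `transfer` lists buffers to
  // move (not copy) to the worker thread.
  return function run(payload, transfer = []) {
    const id = nextId++;
    return new Promise((resolve, reject) => {
      pending.set(id, { resolve, reject });
      worker.postMessage({ id, payload }, transfer);
    });
  };
}
```

<p>Usage would look like <code>const run = createWorkerClient(new Worker('inference-worker.js')); const scores = await run(imageData);</code>. Passing an <code>ArrayBuffer</code> in the <code>transfer</code> list moves it to the worker instead of copying it, which matters for large image buffers.</p>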
</section>
<section id="limitations" class="level2">
<h2 class="anchored" data-anchor-id="limitations" id="limitations">Limitations</h2>
<p><strong>Performance Gap</strong>: JavaScript inference is typically 5-20x slower than Python with CUDA for equivalent models, making it unsuitable for large models or batch processing.</p>
<p><strong>Memory Constraints</strong>: Browser memory limits vary by browser and device but are typically 2-4GB. These limits restrict model size and batch processing compared to server environments.</p>
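<p>To see how these limits apply, estimate the memory needed just to hold a model’s weights: parameter count times bytes per parameter. The sketch below uses AlexNet, which the original paper describes as having 60 million parameters. Activations, intermediate tensors, and framework overhead come on top of this figure. The function name is hypothetical.</p>

```javascript
// Rough size of a model's weights alone: parameters x bytes per parameter.
const BYTES_PER_PARAM = new Map([
  ['float32', 4],
  ['float16', 2],
  ['uint8', 1],
]);

function weightMegabytes(numParams, dtype = 'float32') {
  const bytes = BYTES_PER_PARAM.get(dtype);
  if (bytes === undefined) {
    throw new Error(`Unknown dtype: ${dtype}`);
  }
  return (numParams * bytes) / 1e6; // decimal megabytes
}

// AlexNet: about 60 million parameters
console.log(weightMegabytes(60e6));          // 240
console.log(weightMegabytes(60e6, 'uint8')); // 60
```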
<p><strong>Limited Training</strong>: While possible, training large models in browsers is impractical due to performance and memory constraints. JavaScript ML focuses primarily on inference.</p>
<p><strong>Ecosystem Maturity</strong>: Fewer pre-trained models, less community support, and limited documentation compared to Python’s mature ecosystem.</p>
</section>
<section id="when-to-choose-javascript" class="level2">
<h2 class="anchored" data-anchor-id="when-to-choose-javascript" id="when-to-choose-javascript">When to Choose JavaScript</h2>
<p>JavaScript is the optimal choice when:</p>
<ul>
<li>Zero-installation deployment to users is essential</li>
<li>Building privacy-preserving applications with client-side inference</li>
<li>Creating interactive demos or educational tools</li>
<li>Developing progressive web apps with offline ML capabilities</li>
<li>Prototyping ideas quickly for non-technical stakeholders</li>
<li>Leveraging existing web development skills and infrastructure</li>
<li>Building browser extensions with ML capabilities</li>
<li>Requiring cross-platform deployment without native code</li>
</ul>
</section>
</section>
<section id="golang-in-cv-ml" class="level1">
<h1>Golang in CV &amp; ML</h1>
<section id="overview-3" class="level2">
<h2 class="anchored" data-anchor-id="overview-3" id="overview-3">Overview</h2>
<p>Go (Golang) represents an emerging option for machine learning and computer vision, particularly suited for building production infrastructure, scalable services, and systems where Python’s performance limitations become apparent but C++’s complexity is unnecessary.</p>
</section>
<section id="key-strengths-3" class="level2">
<h2 class="anchored" data-anchor-id="key-strengths-3" id="key-strengths-3">Key Strengths</h2>
<p><strong>Exceptional Concurrency</strong>: Go’s goroutines and channels provide lightweight, elegant concurrency primitives perfect for parallel model inference, data pipeline processing, and handling multiple simultaneous requests.</p>
<p><strong>Production-Ready</strong>: Built-in tooling for testing, profiling, and deployment, combined with static typing and compile-time error checking, results in robust, maintainable production systems.</p>
<p><strong>Fast Compilation</strong>: Go compiles quickly, keeping development cycles short while still producing optimized native binaries. This helps bridge the gap between Python’s development speed and C++’s execution speed.</p>
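<p><strong>Simple Deployment</strong>: Pure-Go programs compile to a single binary with no runtime dependencies. This simplifies containerization and distribution and makes Go well suited to microservices and cloud-native ML systems. Note that cgo-based bindings such as GoCV and the TensorFlow Go API typically still require the native OpenCV or TensorFlow C libraries at runtime.</p>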
<p><strong>Resource Efficiency</strong>: Lower memory footprint and CPU usage compared to Python make Go attractive for cost-sensitive deployments and resource-constrained environments.</p>
</section>
<section id="essential-libraries-frameworks-3" class="level2">
<h2 class="anchored" data-anchor-id="essential-libraries-frameworks-3" id="essential-libraries-frameworks-3">Essential Libraries &amp; Frameworks</h2>
<section id="machine-learning-1" class="level3">
<h3 class="anchored" data-anchor-id="machine-learning-1" id="machine-learning-1">Machine Learning</h3>
<p><strong>Gorgonia</strong>: The primary deep learning library for Go, providing:</p>
<ul>
<li>Automatic differentiation and gradient computation</li>
<li>Neural network building blocks</li>
<li>CUDA support for GPU acceleration</li>
<li>Computation-graph design similar in spirit to Theano and TensorFlow</li>
<li>Active development and growing community</li>
</ul>
<p><strong>GoLearn</strong>: Comprehensive machine learning library offering:</p>
<ul>
<li>Decision trees and ensemble methods</li>
<li>Linear models and regularization</li>
<li>Clustering algorithms</li>
<li>Model evaluation and cross-validation</li>
<li>Scikit-learn-inspired API design</li>
</ul>
<p><strong>GoML</strong>: Focused on traditional ML algorithms with:</p>
<ul>
<li>Online learning implementations</li>
<li>Stochastic gradient descent variants</li>
<li>Perceptron and linear models</li>
<li>Clear, readable code for learning</li>
</ul>
<p><strong>TensorFlow Go Bindings</strong>: Official Go API for TensorFlow enabling:</p>
<ul>
<li>Loading and running SavedModel format models</li>
<li>Integration with TensorFlow ecosystem</li>
<li>Production inference deployment</li>
<li>Limited training capabilities</li>
</ul>
</section>
<section id="computer-vision-2" class="level3">
<h3 class="anchored" data-anchor-id="computer-vision-2" id="computer-vision-2">Computer Vision</h3>
<p><strong>GoCV</strong>: Go bindings for OpenCV 4, providing access to:</p>
<ul>
<li>Comprehensive image processing functions</li>
<li>Video capture and analysis</li>
<li>Face detection and recognition</li>
<li>Feature extraction and matching</li>
<li>Integration with cameras and video files</li>
<li>CUDA acceleration support</li>
</ul>
<p><strong>Gift (Go Image Filtering Toolkit)</strong>: Pure Go image processing with:</p>
<ul>
<li>Convolution and filters</li>
<li>Resampling algorithms</li>
<li>Histogram operations</li>
<li>Format conversion utilities</li>
</ul>
<p><strong>BImg</strong>: High-performance image manipulation using libvips:</p>
<ul>
<li>Fast resize and crop operations</li>
<li>Format conversion</li>
<li>Image pipeline processing</li>
<li>Optimized for web services</li>
</ul>
</section>
</section>
<section id="practical-use-cases-3" class="level2">
<h2 class="anchored" data-anchor-id="practical-use-cases-3" id="practical-use-cases-3">Practical Use Cases</h2>
<p><strong>ML Inference Microservices</strong>: Building scalable, containerized services that load pre-trained models and serve predictions via REST or gRPC APIs, handling thousands of concurrent requests efficiently.</p>
<p><strong>Data Pipeline Orchestration</strong>: Creating ETL pipelines that preprocess data, perform feature engineering, and feed processed data to models, leveraging Go’s concurrency for parallel processing of large datasets.</p>
<p><strong>Model Serving Infrastructure</strong>: Developing custom model serving frameworks with load balancing, A/B testing, and monitoring capabilities, where Go’s performance and simplicity can give it an edge over Python-based solutions.</p>
<p><strong>Real-Time Processing Systems</strong>: Building systems that process video streams or sensor data in real-time, applying ML models for anomaly detection, quality control, or monitoring applications.</p>
<p><strong>Edge Computing Gateways</strong>: Creating lightweight gateways for IoT devices that aggregate data, perform local inference, and manage communication with cloud services efficiently.</p>
<p><strong>CLI Tools for ML Operations</strong>: Developing command-line tools for model deployment, monitoring, data validation, and MLOps workflows, distributed as single binaries.</p>
</section>
<section id="code-example-3" class="level2">
<h2 class="anchored" data-anchor-id="code-example-3" id="code-example-3">Code Example</h2>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb4"><pre class="sourceCode go code-with-copy"><code class="sourceCode go"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="kw">package</span> main</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a><span class="kw">import</span> <span class="op">(</span></span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a>    <span class="st">"fmt"</span></span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a>    <span class="st">"gocv.io/x/gocv"</span></span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a>    tf <span class="st">"github.com/tensorflow/tensorflow/tensorflow/go"</span></span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a><span class="op">)</span></span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a><span class="kw">func</span> main<span class="op">()</span> <span class="op">{</span></span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a>    <span class="co">// Load TensorFlow model</span></span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a>    model<span class="op">,</span> err <span class="op">:=</span> tf<span class="op">.</span>LoadSavedModel<span class="op">(</span><span class="st">"model_path"</span><span class="op">,</span> <span class="op">[]</span><span class="dt">string</span><span class="op">{</span><span class="st">"serve"</span><span class="op">},</span> <span class="ot">nil</span><span class="op">)</span></span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> err <span class="op">!=</span> <span class="ot">nil</span> <span class="op">{</span></span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a>        <span class="bu">panic</span><span class="op">(</span>err<span class="op">)</span></span>
<span id="cb4-14"><a href="#cb4-14" aria-hidden="true" tabindex="-1"></a>    <span class="op">}</span></span>
<span id="cb4-15"><a href="#cb4-15" aria-hidden="true" tabindex="-1"></a>    <span class="cf">defer</span> model<span class="op">.</span>Session<span class="op">.</span>Close<span class="op">()</span></span>
<span id="cb4-16"><a href="#cb4-16" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-17"><a href="#cb4-17" aria-hidden="true" tabindex="-1"></a>    <span class="co">// Open webcam</span></span>
<span id="cb4-18"><a href="#cb4-18" aria-hidden="true" tabindex="-1"></a>    webcam<span class="op">,</span> err <span class="op">:=</span> gocv<span class="op">.</span>OpenVideoCapture<span class="op">(</span><span class="dv">0</span><span class="op">)</span></span>
<span id="cb4-19"><a href="#cb4-19" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> err <span class="op">!=</span> <span class="ot">nil</span> <span class="op">{</span></span>
<span id="cb4-20"><a href="#cb4-20" aria-hidden="true" tabindex="-1"></a>        <span class="bu">panic</span><span class="op">(</span>err<span class="op">)</span></span>
<span id="cb4-21"><a href="#cb4-21" aria-hidden="true" tabindex="-1"></a>    <span class="op">}</span></span>
<span id="cb4-22"><a href="#cb4-22" aria-hidden="true" tabindex="-1"></a>    <span class="cf">defer</span> webcam<span class="op">.</span>Close<span class="op">()</span></span>
<span id="cb4-23"><a href="#cb4-23" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-24"><a href="#cb4-24" aria-hidden="true" tabindex="-1"></a>    <span class="co">// Create window</span></span>
<span id="cb4-25"><a href="#cb4-25" aria-hidden="true" tabindex="-1"></a>    window <span class="op">:=</span> gocv<span class="op">.</span>NewWindow<span class="op">(</span><span class="st">"Detection"</span><span class="op">)</span></span>
<span id="cb4-26"><a href="#cb4-26" aria-hidden="true" tabindex="-1"></a>    <span class="cf">defer</span> window<span class="op">.</span>Close<span class="op">()</span></span>
<span id="cb4-27"><a href="#cb4-27" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-28"><a href="#cb4-28" aria-hidden="true" tabindex="-1"></a>    img <span class="op">:=</span> gocv<span class="op">.</span>NewMat<span class="op">()</span></span>
<span id="cb4-29"><a href="#cb4-29" aria-hidden="true" tabindex="-1"></a>    <span class="cf">defer</span> img<span class="op">.</span>Close<span class="op">()</span></span>
<span id="cb4-30"><a href="#cb4-30" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-31"><a href="#cb4-31" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> <span class="op">{</span></span>
<span id="cb4-32"><a href="#cb4-32" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> ok <span class="op">:=</span> webcam<span class="op">.</span>Read<span class="op">(&amp;</span>img<span class="op">);</span> <span class="op">!</span>ok <span class="op">{</span></span>
<span id="cb4-33"><a href="#cb4-33" aria-hidden="true" tabindex="-1"></a>            <span class="cf">break</span></span>
<span id="cb4-34"><a href="#cb4-34" aria-hidden="true" tabindex="-1"></a>        <span class="op">}</span></span>
<span id="cb4-35"><a href="#cb4-35" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> img<span class="op">.</span>Empty<span class="op">()</span> <span class="op">{</span></span>
<span id="cb4-36"><a href="#cb4-36" aria-hidden="true" tabindex="-1"></a>            <span class="cf">continue</span></span>
<span id="cb4-37"><a href="#cb4-37" aria-hidden="true" tabindex="-1"></a>        <span class="op">}</span></span>
<span id="cb4-38"><a href="#cb4-38" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-39"><a href="#cb4-39" aria-hidden="true" tabindex="-1"></a>        <span class="co">// Preprocess image</span></span>
<span id="cb4-40"><a href="#cb4-40" aria-hidden="true" tabindex="-1"></a>        resized <span class="op">:=</span> gocv<span class="op">.</span>NewMat<span class="op">()</span></span>
<span id="cb4-41"><a href="#cb4-41" aria-hidden="true" tabindex="-1"></a>        gocv<span class="op">.</span>Resize<span class="op">(</span>img<span class="op">,</span> <span class="op">&amp;</span>resized<span class="op">,</span> image<span class="op">.</span>Pt<span class="op">(</span><span class="dv">224</span><span class="op">,</span> <span class="dv">224</span><span class="op">),</span> <span class="dv">0</span><span class="op">,</span> <span class="dv">0</span><span class="op">,</span> gocv<span class="op">.</span>InterpolationLinear<span class="op">)</span></span>
<span id="cb4-42"><a href="#cb4-42" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-43"><a href="#cb4-43" aria-hidden="true" tabindex="-1"></a>        <span class="co">// Convert to float32 and normalize</span></span>
<span id="cb4-44"><a href="#cb4-44" aria-hidden="true" tabindex="-1"></a>        normalized <span class="op">:=</span> gocv<span class="op">.</span>NewMat<span class="op">()</span></span>
<span id="cb4-45"><a href="#cb4-45" aria-hidden="true" tabindex="-1"></a>        resized<span class="op">.</span>ConvertTo<span class="op">(&amp;</span>normalized<span class="op">,</span> gocv<span class="op">.</span>MatTypeCV32F<span class="op">)</span></span>
<span id="cb4-46"><a href="#cb4-46" aria-hidden="true" tabindex="-1"></a>        normalized<span class="op">.</span>DivideFloat<span class="op">(</span><span class="fl">255.0</span><span class="op">)</span></span>
<span id="cb4-47"><a href="#cb4-47" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-48"><a href="#cb4-48" aria-hidden="true" tabindex="-1"></a>        <span class="co">// Create the input tensor. tf.NewTensor copies the data, so the</span></span>
<span id="cb4-49"><a href="#cb4-49" aria-hidden="true" tabindex="-1"></a>        <span class="co">// intermediate Mats can be released as soon as it returns.</span></span>
<span id="cb4-50"><a href="#cb4-50" aria-hidden="true" tabindex="-1"></a>        tensor<span class="op">,</span> err <span class="op">:=</span> tf<span class="op">.</span>NewTensor<span class="op">(</span>convertMatToTensor<span class="op">(</span>normalized<span class="op">))</span></span>
<span id="cb4-51"><a href="#cb4-51" aria-hidden="true" tabindex="-1"></a>        resized<span class="op">.</span>Close<span class="op">()</span></span>
<span id="cb4-52"><a href="#cb4-52" aria-hidden="true" tabindex="-1"></a>        normalized<span class="op">.</span>Close<span class="op">()</span></span>
<span id="cb4-53"><a href="#cb4-53" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> err <span class="op">!=</span> <span class="ot">nil</span> <span class="op">{</span></span>
<span id="cb4-54"><a href="#cb4-54" aria-hidden="true" tabindex="-1"></a>            fmt<span class="op">.</span>Printf<span class="op">(</span><span class="st">"Tensor creation failed: %v</span><span class="ch">\n</span><span class="st">"</span><span class="op">,</span> err<span class="op">)</span></span>
<span id="cb4-55"><a href="#cb4-55" aria-hidden="true" tabindex="-1"></a>            <span class="cf">continue</span></span>
<span id="cb4-56"><a href="#cb4-56" aria-hidden="true" tabindex="-1"></a>        <span class="op">}</span></span>
<span id="cb4-57"><a href="#cb4-57" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-58"><a href="#cb4-58" aria-hidden="true" tabindex="-1"></a>        result<span class="op">,</span> err <span class="op">:=</span> model<span class="op">.</span>Session<span class="op">.</span>Run<span class="op">(</span></span>
<span id="cb4-59"><a href="#cb4-59" aria-hidden="true" tabindex="-1"></a>            <span class="kw">map</span><span class="op">[</span>tf<span class="op">.</span>Output<span class="op">]*</span>tf<span class="op">.</span>Tensor<span class="op">{</span></span>
<span id="cb4-60"><a href="#cb4-60" aria-hidden="true" tabindex="-1"></a>                model<span class="op">.</span>Graph<span class="op">.</span>Operation<span class="op">(</span><span class="st">"input"</span><span class="op">).</span>Output<span class="op">(</span><span class="dv">0</span><span class="op">):</span> tensor<span class="op">,</span></span>
<span id="cb4-61"><a href="#cb4-61" aria-hidden="true" tabindex="-1"></a>            <span class="op">},</span></span>
<span id="cb4-62"><a href="#cb4-62" aria-hidden="true" tabindex="-1"></a>            <span class="op">[]</span>tf<span class="op">.</span>Output<span class="op">{</span></span>
<span id="cb4-63"><a href="#cb4-63" aria-hidden="true" tabindex="-1"></a>                model<span class="op">.</span>Graph<span class="op">.</span>Operation<span class="op">(</span><span class="st">"output"</span><span class="op">).</span>Output<span class="op">(</span><span class="dv">0</span><span class="op">),</span></span>
<span id="cb4-64"><a href="#cb4-64" aria-hidden="true" tabindex="-1"></a>            <span class="op">},</span></span>
<span id="cb4-65"><a href="#cb4-65" aria-hidden="true" tabindex="-1"></a>            <span class="ot">nil</span><span class="op">,</span></span>
<span id="cb4-66"><a href="#cb4-66" aria-hidden="true" tabindex="-1"></a>        <span class="op">)</span></span>
<span id="cb4-67"><a href="#cb4-67" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> err <span class="op">!=</span> <span class="ot">nil</span> <span class="op">{</span></span>
<span id="cb4-68"><a href="#cb4-68" aria-hidden="true" tabindex="-1"></a>            fmt<span class="op">.</span>Printf<span class="op">(</span><span class="st">"Inference failed: %v</span><span class="ch">\n</span><span class="st">"</span><span class="op">,</span> err<span class="op">)</span></span>
<span id="cb4-69"><a href="#cb4-69" aria-hidden="true" tabindex="-1"></a>        <span class="op">}</span> <span class="cf">else</span> <span class="op">{</span></span>
<span id="cb4-70"><a href="#cb4-70" aria-hidden="true" tabindex="-1"></a>            predictions <span class="op">:=</span> result<span class="op">[</span><span class="dv">0</span><span class="op">].</span>Value<span class="op">().([][]</span><span class="dt">float32</span><span class="op">)</span></span>
<span id="cb4-71"><a href="#cb4-71" aria-hidden="true" tabindex="-1"></a>            fmt<span class="op">.</span>Printf<span class="op">(</span><span class="st">"Predictions: %v</span><span class="ch">\n</span><span class="st">"</span><span class="op">,</span> predictions<span class="op">)</span></span>
<span id="cb4-72"><a href="#cb4-72" aria-hidden="true" tabindex="-1"></a>        <span class="op">}</span></span>
<span id="cb4-73"><a href="#cb4-73" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-74"><a href="#cb4-74" aria-hidden="true" tabindex="-1"></a>        window<span class="op">.</span>IMShow<span class="op">(</span>img<span class="op">)</span></span>
<span id="cb4-75"><a href="#cb4-75" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> window<span class="op">.</span>WaitKey<span class="op">(</span><span class="dv">1</span><span class="op">)</span> <span class="op">==</span> <span class="dv">27</span> <span class="op">{</span> <span class="co">// Esc key exits</span></span>
<span id="cb4-76"><a href="#cb4-76" aria-hidden="true" tabindex="-1"></a>            <span class="cf">break</span></span>
<span id="cb4-77"><a href="#cb4-77" aria-hidden="true" tabindex="-1"></a>        <span class="op">}</span></span>
<span id="cb4-78"><a href="#cb4-78" aria-hidden="true" tabindex="-1"></a>    <span class="op">}</span></span>
<span id="cb4-79"><a href="#cb4-79" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span></code></pre></div></div>
</section>
<section id="integration-patterns" class="level2">
<h2 class="anchored" data-anchor-id="integration-patterns" id="integration-patterns">Integration Patterns</h2>
<p><strong>Python Model Training + Go Inference</strong>: A common pattern is to train models in Python with PyTorch or TensorFlow, export them to ONNX or SavedModel format, and serve inference from Go for production performance and scalability.</p>
<p><strong>Hybrid Services</strong>: Building services where Go handles HTTP routing, request validation, and concurrency management, while delegating actual inference to Python workers via gRPC or message queues.</p>
<p><strong>Batch Processing</strong>: Using Go to coordinate distributed batch inference across multiple workers (managing job queues and aggregating results), with goroutines and channels handling the concurrency.</p>
<p><strong>Feature Engineering</strong>: Implementing performance-critical feature extraction and data preprocessing in Go, producing features consumed by downstream Python models.</p>
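<p>The batch-processing pattern looks the same in any language: split the job into batches, fan them out to a pool of workers, and gather the results in input order. Below is a minimal sketch of that coordinator logic, written in Python with <code>concurrent.futures</code> for brevity (a Go version would use goroutines and a <code>sync.WaitGroup</code> in the same shape). <code>infer_batch</code> is a hypothetical stand-in for a call to a real inference worker.</p>

```python
from concurrent.futures import ThreadPoolExecutor


def infer_batch(batch):
    """Hypothetical stand-in for a remote inference call on one batch."""
    return [len(item) for item in batch]  # toy "prediction": length of each input


def run_batch_job(items, batch_size=4, max_workers=3):
    """Split items into batches, process them concurrently, return results in input order."""
    batches = [items[i:i + batch_size] for i in range(0, len(items), batch_size)]
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # pool.map yields results in submission order, even if batches finish out of order
        per_batch = list(pool.map(infer_batch, batches))
    # Flatten per-batch results back into one list aligned with `items`
    return [pred for batch_preds in per_batch for pred in batch_preds]


print(run_batch_job(["cat.jpg", "dog.png", "a.jpg", "bird.jpeg", "x.png"], batch_size=2))
# [7, 7, 5, 9, 5]
```

<p>Keeping the per-batch results ordered lets the coordinator stay simple: it never needs to match predictions back to inputs by ID.</p>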
</section>
<section id="performance-characteristics" class="level2">
<h2 class="anchored" data-anchor-id="performance-characteristics" id="performance-characteristics">Performance Characteristics</h2>
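<p>For code that runs in the language itself (request handling, preprocessing, orchestration), Go typically outperforms Python, often by 2-5x in end-to-end services, while using 30-50% less memory; the exact gap depends heavily on the workload. Go compiles to native binaries that approach C++ speed for many operations, and its garbage collector is tuned for the short pause times that server workloads need.</p>
<p>However, Go lacks mature equivalents of the optimized numerical libraries behind Python’s speed (NumPy’s BLAS/LAPACK integration, optimized convolution kernels), so model execution in pure-Go frameworks such as Gorgonia may not match Python frameworks using native acceleration. Bindings such as the TensorFlow Go package, used in the example above, avoid this gap because they call into TensorFlow’s native runtime rather than reimplementing it in Go.</p>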
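<p>A quick way to see why native libraries matter: the sketch below times the same sum computed by an interpreted Python loop and by the built-in <code>sum()</code>, whose loop runs in C. This is the same effect NumPy exploits at array scale; the measured ratio varies by machine and Python version, so treat any single number as illustrative.</p>

```python
import timeit

data = list(range(200_000))


def python_loop_sum(values):
    """Sum with an explicit loop: every iteration goes through the interpreter."""
    total = 0
    for v in values:
        total += v
    return total


# The built-in sum() performs the same loop in C; NumPy applies this idea to whole arrays
loop_t = timeit.timeit(lambda: python_loop_sum(data), number=10)
builtin_t = timeit.timeit(lambda: sum(data), number=10)
print(f"interpreted loop: {loop_t:.3f}s, built-in sum: {builtin_t:.3f}s, "
      f"ratio: {loop_t / builtin_t:.1f}x")
```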
</section>
<section id="limitations-1" class="level2">
<h2 class="anchored" data-anchor-id="limitations-1" id="limitations-1">Limitations</h2>
<p><strong>Immature Ecosystem</strong>: Go’s ML ecosystem is years behind Python, with fewer pre-trained models, less documentation, smaller communities, and ongoing API changes in core libraries.</p>
<p><strong>Limited GPU Support</strong>: While Gorgonia supports CUDA, its GPU acceleration is less mature and harder to configure than that of Python frameworks, which have had years of GPU-specific optimization.</p>
<p><strong>Training Capabilities</strong>: Training complex models in Go is impractical due to limited automatic differentiation frameworks and lack of training-focused tools and optimizations.</p>
<p><strong>Interoperability Friction</strong>: Integrating with Python-trained models often requires conversion steps, format compatibility checks, and debugging serialization issues.</p>
</section>
<section id="when-to-choose-go" class="level2">
<h2 class="anchored" data-anchor-id="when-to-choose-go" id="when-to-choose-go">When to Choose Go</h2>
<p>Go is a strong choice when:</p>
<ul>
<li>Building production inference services requiring high throughput</li>
<li>Developing a microservices architecture for ML systems</li>
<li>Creating CLI tools for ML operations and deployment</li>
<li>Implementing data processing pipelines with heavy concurrency</li>
<li>Deploying to resource-constrained cloud environments</li>
<li>Requiring simple deployment without Python dependencies</li>
<li>Building real-time processing systems with Go-native components</li>
<li>Needing better performance than Python without C++ complexity</li>
<li>Working in organizations with existing Go infrastructure</li>
</ul>
</section>
</section>
<section id="comparison-use-case-selection" class="level1">
<h1>Comparison &amp; Use Case Selection</h1>
<section id="performance-comparison" class="level2">
<h2 class="anchored" data-anchor-id="performance-comparison" id="performance-comparison">Performance Comparison</h2>
<section id="inference-speed" class="level3">
<h3 class="anchored" data-anchor-id="inference-speed" id="inference-speed">Inference Speed</h3>
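<p>Approximate relative speed for CPU-bound code, with C++ as the 1.0x baseline. Treat these as rough rules of thumb rather than benchmarks: when every language calls the same native runtime (ONNX Runtime, LibTorch, TensorFlow), the differences shrink to the cost of the surrounding glue code.</p>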
<ul>
<li><strong>C++</strong>: 1.0x (baseline, fastest)</li>
<li><strong>Go</strong>: 1.5-3x slower than C++</li>
<li><strong>Python (NumPy/optimized)</strong>: 2-4x slower than C++</li>
<li><strong>Python (pure)</strong>: 50-100x slower than C++</li>
<li><strong>JavaScript (WebGL)</strong>: 2-5x slower than C++ (GPU-backed, so not strictly CPU-bound)</li>
<li><strong>JavaScript (CPU)</strong>: 10-30x slower than C++</li>
</ul>
</section>
<section id="development-speed" class="level3">
<h3 class="anchored" data-anchor-id="development-speed" id="development-speed">Development Speed</h3>
<ul>
<li><strong>Python</strong>: Fastest (hours to prototype)</li>
<li><strong>JavaScript</strong>: Fast (hours to days)</li>
<li><strong>Go</strong>: Medium (days)</li>
<li><strong>C++</strong>: Slowest (days to weeks)</li>
</ul>
</section>
<section id="memory-efficiency" class="level3">
<h3 class="anchored" data-anchor-id="memory-efficiency" id="memory-efficiency">Memory Efficiency</h3>
<ul>
<li><strong>C++</strong>: Most efficient (full control)</li>
<li><strong>Go</strong>: Very efficient, with some garbage-collection overhead</li>
<li><strong>JavaScript</strong>: Moderate (browser constraints)</li>
<li><strong>Python</strong>: Least efficient (interpreter overhead)</li>
</ul>
</section>
</section>
<section id="selection-matrix" class="level2">
<h2 class="anchored" data-anchor-id="selection-matrix" id="selection-matrix">Selection Matrix</h2>
<div class="tabset-margin-container"></div><div class="panel-tabset">
<ul class="nav nav-tabs" role="tablist"><li class="nav-item" role="presentation"><a class="nav-link active" id="tabset-1-1-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-1" role="tab" aria-controls="tabset-1-1" aria-selected="true" href="">Python</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-1-2-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-2" role="tab" aria-controls="tabset-1-2" aria-selected="false" href="">C++</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-1-3-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-3" role="tab" aria-controls="tabset-1-3" aria-selected="false" href="">JavaScript</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-1-4-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-4" role="tab" aria-controls="tabset-1-4" aria-selected="false" href="">Go</a></li></ul>
<div class="tab-content">
<div id="tabset-1-1" class="tab-pane active" role="tabpanel" aria-labelledby="tabset-1-1-tab">
<p>Choose Python when:</p>
<ul>
<li>Research and experimentation are primary goals</li>
<li>Leveraging pre-trained models and established architectures</li>
<li>Rapid prototyping is essential</li>
<li>Working with data science teams</li>
<li>Building end-to-end ML pipelines</li>
<li>Using Jupyter notebooks for exploration</li>
<li>Requiring the richest ecosystem and community support</li>
</ul>
</div>
<div id="tabset-1-2" class="tab-pane" role="tabpanel" aria-labelledby="tabset-1-2-tab">
<p>Choose C++ when:</p>
<ul>
<li>Real-time performance with low latency is critical</li>
<li>Deploying to embedded or edge devices</li>
<li>Building production inference at massive scale</li>
<li>Integrating with game engines or robotics systems</li>
<li>Developing for platforms without high-level language support</li>
<li>Requiring custom hardware acceleration</li>
<li>Building commercial products with strict performance SLAs</li>
</ul>
</div>
<div id="tabset-1-3" class="tab-pane" role="tabpanel" aria-labelledby="tabset-1-3-tab">
<p>Choose JavaScript when:</p>
<ul>
<li>Deploying directly to web browsers</li>
<li>Building interactive demos and visualizations</li>
<li>Keeping inference on the client for privacy</li>
<li>Creating progressive web apps with ML</li>
<li>Zero-installation deployment is essential</li>
<li>Targeting the widest possible audience</li>
<li>Developing browser extensions with ML features</li>
</ul>
</div>
<div id="tabset-1-4" class="tab-pane" role="tabpanel" aria-labelledby="tabset-1-4-tab">
<p>Choose Go when:</p>
<ul>
<li>Building scalable microservices for inference</li>
<li>Developing ML infrastructure and tooling</li>
<li>Creating data processing pipelines</li>
<li>Deploying containerized services efficiently</li>
<li>Requiring better performance than Python without C++ complexity</li>
<li>Building CLI tools for MLOps</li>
<li>Working in Go-native environments</li>
</ul>
</div>
</div>
</div>
</section>
<section id="hybrid-approaches" class="level2">
<h2 class="anchored" data-anchor-id="hybrid-approaches" id="hybrid-approaches">Hybrid Approaches</h2>
<p>Many production ML systems combine several languages, using each for its strengths:</p>
<p><strong>Research → Production Pipeline</strong>:</p>
<ol type="1">
<li>Prototype and train models in Python (PyTorch/TensorFlow)</li>
<li>Convert to ONNX or TorchScript</li>
<li>Deploy inference in C++ or Go for performance</li>
<li>Use JavaScript for web-based demos and client applications</li>
</ol>
<p><strong>Microservices Architecture</strong>:</p>
<ul>
<li>Go services handle routing, load balancing, and orchestration</li>
<li>Python services perform model inference and complex data processing</li>
<li>C++ services handle real-time components and hardware interfaces</li>
<li>JavaScript clients provide user interfaces and client-side features</li>
</ul>
<p><strong>Edge-Cloud Hybrid</strong>:</p>
<ul>
<li>Train models in Python on cloud GPUs</li>
<li>Deploy lightweight models to edge devices in C++</li>
<li>Use Go for edge gateway aggregation and processing</li>
<li>Provide web interfaces with JavaScript for monitoring and control</li>
</ul>
</section>
<section id="future-trends" class="level2">
<h2 class="anchored" data-anchor-id="future-trends" id="future-trends">Future Trends</h2>
<p><strong>Python</strong>: Likely to remain dominant in research and development, with continued work on easing production deployment through compilation (<code>torch.compile</code> in PyTorch 2.x), type hints, and packaging improvements.</p>
<p><strong>C++</strong>: Remains essential for performance-critical production systems, with modern C++ standards (C++20, C++23) making the language more accessible while maintaining zero-overhead principles.</p>
<p><strong>JavaScript</strong>: Growing capabilities with WebGPU, now shipping in Chromium-based browsers and arriving in others, which offers better GPU compute support than WebGL and expands use cases for client-side inference.</p>
<p><strong>Go</strong>: Ecosystem maturation with better ML libraries, increased adoption for ML infrastructure, and improved interoperability with Python, making it increasingly viable for production deployments.</p>
</section>
<section id="practical-decision-framework" class="level2">
<h2 class="anchored" data-anchor-id="practical-decision-framework" id="practical-decision-framework">Practical Decision Framework</h2>
<p>When selecting a language for a CV/ML project, consider these factors in order:</p>
<ol type="1">
<li><strong>Deployment Target</strong>: Where will the model run? (cloud, edge, browser, mobile)</li>
<li><strong>Performance Requirements</strong>: What latency and throughput are needed?</li>
<li><strong>Team Expertise</strong>: What languages does your team know well?</li>
<li><strong>Development Timeline</strong>: How quickly do you need to deliver?</li>
<li><strong>Ecosystem Needs</strong>: What pre-trained models or libraries are required?</li>
<li><strong>Maintenance Burden</strong>: Who will maintain the code long-term?</li>
<li><strong>Integration Constraints</strong>: What existing systems must you integrate with?</li>
</ol>
</section>
<section id="cost-considerations" class="level2">
<h2 class="anchored" data-anchor-id="cost-considerations" id="cost-considerations">Cost Considerations</h2>
<section id="development-costs" class="level3">
<h3 class="anchored" data-anchor-id="development-costs" id="development-costs">Development Costs</h3>
<ul>
<li><strong>Python</strong>: Lowest (fast development, large talent pool)</li>
<li><strong>JavaScript</strong>: Low to moderate (web developers abundant)</li>
<li><strong>Go</strong>: Moderate (smaller talent pool than Python/JS)</li>
<li><strong>C++</strong>: Highest (longer development time, specialized skills)</li>
</ul>
</section>
<section id="infrastructure-costs" class="level3">
<h3 class="anchored" data-anchor-id="infrastructure-costs" id="infrastructure-costs">Infrastructure Costs</h3>
<ul>
<li><strong>C++</strong>: Lowest (efficient resource usage)</li>
<li><strong>Go</strong>: Low (efficient, good concurrency)</li>
<li><strong>Python</strong>: Moderate to high (higher memory/CPU needs)</li>
<li><strong>JavaScript</strong>: Variable (client-side inference adds no server compute cost; server-side is moderate)</li>
</ul>
<p><strong>Total Cost of Ownership</strong>: For many projects, Python’s lower development costs outweigh its higher infrastructure costs. C++ makes sense when infrastructure costs dominate or when latency and throughput requirements are hard constraints.</p>
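<p>The trade-off can be framed as a break-even calculation: the extra development cost of a faster implementation divided by the monthly infrastructure savings it delivers. A minimal sketch, where all the dollar figures are hypothetical placeholders rather than real estimates:</p>

```python
import math


def breakeven_months(extra_dev_cost, monthly_infra_python, monthly_infra_fast):
    """Months until a faster, costlier-to-build implementation pays for itself.

    Returns math.inf when the faster stack does not reduce infrastructure spend.
    """
    monthly_savings = monthly_infra_python - monthly_infra_fast
    if monthly_savings <= 0:
        return math.inf
    return extra_dev_cost / monthly_savings


# Hypothetical figures for illustration only
months = breakeven_months(extra_dev_cost=60_000,
                          monthly_infra_python=8_000,
                          monthly_infra_fast=3_000)
print(f"Break-even after {months:.1f} months")  # Break-even after 12.0 months
```

<p>With these made-up inputs the faster rewrite pays for itself after a year; if the service is expected to live for less time than the break-even point, the Python version is likely the cheaper option.</p>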
</section>
</section>
</section>
<section id="sec-advanced" class="level1">
<h1>Advanced Topics</h1>
<section id="cross-language-integration" class="level2">
<h2 class="anchored" data-anchor-id="cross-language-integration" id="cross-language-integration">Cross-Language Integration</h2>
<section id="python-c-integration" class="level3">
<h3 class="anchored" data-anchor-id="python-c-integration" id="python-c-integration">Python-C++ Integration</h3>
<p><strong>pybind11</strong>: A header-only C++ library for exposing C++ functions and classes to Python with little boilerplate:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb5"><pre class="sourceCode cpp code-with-copy"><code class="sourceCode cpp"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="pp">#include </span><span class="im">&lt;pybind11/pybind11.h&gt;</span></span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a><span class="dt">int</span> fast_compute<span class="op">(</span><span class="dt">int</span> n<span class="op">)</span> <span class="op">{</span></span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a>    <span class="co">// Performance-critical C++ code</span></span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> n <span class="op">*</span> n<span class="op">;</span></span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a>PYBIND11_MODULE<span class="op">(</span>example<span class="op">,</span> m<span class="op">)</span> <span class="op">{</span></span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a>    m<span class="op">.</span>def<span class="op">(</span><span class="st">"fast_compute"</span><span class="op">,</span> <span class="op">&amp;</span>fast_compute<span class="op">);</span></span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span></code></pre></div></div>
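<p><strong>ctypes</strong>: Call functions in a compiled C shared library directly from Python, without writing or building an extension module. C++ functions must be exported with <code>extern "C"</code> so their names are not mangled:</p>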
<div id="7c72db79" class="cell" data-execution_count="2">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> ctypes</span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a>lib <span class="op">=</span> ctypes.CDLL(<span class="st">'./libexample.so'</span>)</span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a>lib.fast_compute.argtypes <span class="op">=</span> [ctypes.c_int]</span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a>lib.fast_compute.restype <span class="op">=</span> ctypes.c_int</span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a>result <span class="op">=</span> lib.fast_compute(<span class="dv">42</span>)</span></code></pre></div></div>
</div>
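<p>When experimenting with ctypes, the C math library that ships with Linux and macOS can stand in for <code>libexample.so</code>. The sketch below, which assumes a POSIX system where <code>ctypes.util.find_library("m")</code> resolves, shows why declaring <code>argtypes</code> and <code>restype</code> matters: ctypes assumes every function returns a C <code>int</code> unless told otherwise, which would misread a <code>double</code> result.</p>

```python
import ctypes
import ctypes.util

# Locate the C math library (e.g. "libm.so.6" on Linux); assumes a POSIX system
libm = ctypes.CDLL(ctypes.util.find_library("m"))

# Declare the C signature double cos(double); without restype, ctypes would
# treat the return value as an int and silently produce garbage
libm.cos.argtypes = [ctypes.c_double]
libm.cos.restype = ctypes.c_double

print(libm.cos(0.0))  # 1.0
```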
<p><strong>Cython</strong>: Write Python-like code that compiles to C extensions:</p>
<div id="24066ec6" class="cell" data-execution_count="3">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="co"># cython_module.pyx</span></span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> fast_compute(<span class="bu">int</span> n):</span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a>    cdef <span class="bu">int</span> result <span class="op">=</span> n <span class="op">*</span> n</span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> result</span></code></pre></div></div>
</div>
</section>
<section id="go-python-integration" class="level3">
<h3 class="anchored" data-anchor-id="go-python-integration" id="go-python-integration">Go-Python Integration</h3>
<p><strong>gRPC</strong>: Language-agnostic RPC framework for microservices communication:</p>
<ul>
<li>Define service contracts in Protocol Buffers</li>
<li>Generate client/server code for both languages</li>
<li>Efficient binary serialization</li>
<li>Streaming support for large data</li>
</ul>
<p><strong>Message Queues</strong>: Decouple services using RabbitMQ, Kafka, or Redis:</p>
<ul>
<li>Python services publish inference requests</li>
<li>Go services consume and process</li>
<li>Asynchronous, scalable architecture</li>
<li>Fault tolerance and retry logic</li>
</ul>
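<p>The retry and fault-tolerance points can be sketched with Python’s standard <code>queue.Queue</code> standing in for a broker such as RabbitMQ or Kafka; real client libraries have their own acknowledgement and dead-letter mechanisms, which this sketch does not model. <code>flaky_infer</code> is a hypothetical handler that fails once per message to simulate a transient worker error.</p>

```python
import queue
import time


def consume(requests, handle, max_retries=3, base_delay=0.01):
    """Drain `requests`, retrying failed messages with exponential backoff.

    Returns (results, dead_letters); messages that still fail after
    max_retries attempts are set aside instead of blocking the queue.
    """
    results, dead_letters = [], []
    while True:
        try:
            msg = requests.get_nowait()
        except queue.Empty:
            return results, dead_letters
        for attempt in range(max_retries):
            try:
                results.append(handle(msg))
                break
            except Exception:  # a real consumer would catch narrower error types
                if attempt + 1 < max_retries:
                    time.sleep(base_delay * 2 ** attempt)  # exponential backoff
        else:
            dead_letters.append(msg)  # retries exhausted: set the message aside


# Hypothetical flaky handler: fails on the first attempt for each message
seen = set()


def flaky_infer(msg):
    if msg not in seen:
        seen.add(msg)
        raise RuntimeError("transient worker error")
    return msg.upper()


q = queue.Queue()
for m in ["a", "b"]:
    q.put(m)
outcome = consume(q, flaky_infer)
print(outcome)  # (['A', 'B'], [])
```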
</section>
</section>
<section id="model-conversion-and-interoperability" class="level2">
<h2 class="anchored" data-anchor-id="model-conversion-and-interoperability" id="model-conversion-and-interoperability">Model Conversion and Interoperability</h2>
<p><strong>ONNX (Open Neural Network Exchange)</strong>: Universal format for model interchange:</p>
<ul>
<li>Export from PyTorch, TensorFlow, or other frameworks</li>
<li>Import into C++, JavaScript, or Go runtimes</li>
<li>Preserve model behavior across platforms (up to small numerical differences, so compare outputs after export)</li>
<li>Optimize for specific hardware targets</li>
</ul>
<div id="e30ad7aa" class="cell" data-execution_count="4">
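<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb8"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Export a PyTorch model to ONNX</span></span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torchvision</span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Any nn.Module works; an untrained AlexNet stands in for your trained model</span></span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> torchvision.models.alexnet(weights<span class="op">=</span><span class="va">None</span>).<span class="bu">eval</span>()  <span class="co"># eval mode disables dropout</span></span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a>dummy_input <span class="op">=</span> torch.randn(<span class="dv">1</span>, <span class="dv">3</span>, <span class="dv">224</span>, <span class="dv">224</span>)  <span class="co"># example input that fixes the exported shapes</span></span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a>torch.onnx.export(model, dummy_input, <span class="st">"model.onnx"</span>,</span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a>                  input_names<span class="op">=</span>[<span class="st">"input"</span>], output_names<span class="op">=</span>[<span class="st">"output"</span>])</span></code></pre></div></div>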
</div>
<p><strong>TorchScript</strong>: PyTorch’s serialization format for production:</p>
<ul>
<li>Trace or script Python models</li>
<li>Load in C++ with LibTorch</li>
<li>Preserve data-dependent control flow when scripted (tracing records only the path taken by the example input)</li>
<li>Optimize for inference</li>
</ul>
<p><strong>SavedModel</strong>: TensorFlow’s standard format:</p>
<ul>
<li>Compatible with TensorFlow Serving</li>
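<li>Load directly in C++ and Go; in JavaScript, load with tfjs-node or convert with <code>tensorflowjs_converter</code></li>
<li>Can bundle preprocessing and postprocessing into the exported graph</li>
<li>Versioned by numbered directories when served with TensorFlow Serving</li>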
</ul>
</section>
<section id="deployment-strategies" class="level2">
<h2 class="anchored" data-anchor-id="deployment-strategies" id="deployment-strategies">Deployment Strategies</h2>
<p><strong>Containerization</strong>: Use Docker for consistent environments:</p>
<ul>
<li>Python: Include dependencies in requirements.txt</li>
<li>C++: Multi-stage builds for minimal images</li>
<li>Go: Scratch or distroless base images</li>
<li>JavaScript: Node.js or static file serving</li>
</ul>
<p><strong>Serverless</strong>: Deploy models without managing infrastructure:</p>
<ul>
<li>Python: AWS Lambda, Google Cloud Functions</li>
<li>JavaScript: Cloudflare Workers, Vercel</li>
<li>Go: Google Cloud Functions (native runtime) and AWS Lambda (via the OS-only <code>provided.al2023</code> runtime)</li>
<li>C++: Limited support, often via custom runtimes</li>
</ul>
<p><strong>Kubernetes</strong>: Orchestrate ML microservices at scale:</p>
<ul>
<li>Horizontal pod autoscaling for inference services</li>
<li>GPU scheduling and resource quotas</li>
<li>Service mesh for traffic management</li>
<li>Helm charts for deployment automation</li>
</ul>
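<p>Horizontal pod autoscaling follows a simple rule documented by Kubernetes: <code>desiredReplicas = ceil(currentReplicas * currentMetric / targetMetric)</code>, skipped when the ratio is within a tolerance (0.1 by default) of 1.0. The sketch below implements only that rule plus min/max clamping; the real controller also applies stabilization windows and special handling for unready pods and missing metrics.</p>

```python
import math


def desired_replicas(current_replicas, current_metric, target_metric,
                     min_replicas=1, max_replicas=10, tolerance=0.1):
    """Simplified HPA rule: ceil(current * current_metric / target_metric), clamped."""
    ratio = current_metric / target_metric
    if abs(ratio - 1.0) <= tolerance:
        return current_replicas  # within tolerance of the target: no change
    desired = math.ceil(current_replicas * ratio)
    return max(min_replicas, min(max_replicas, desired))


# 3 inference pods averaging 90% CPU against a 60% utilization target
print(desired_replicas(3, current_metric=90, target_metric=60))  # 5
```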
</section>
<section id="monitoring-and-observability" class="level2">
<h2 class="anchored" data-anchor-id="monitoring-and-observability" id="monitoring-and-observability">Monitoring and Observability</h2>
<p>Regardless of language choice, production ML systems require:</p>
<p><strong>Metrics Collection</strong>:</p>
<ul>
<li>Inference latency (p50, p95, p99)</li>
<li>Throughput (requests per second)</li>
<li>Model accuracy and drift detection</li>
<li>Resource utilization (CPU, memory, GPU)</li>
</ul>
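<p>Percentile latencies can be computed from raw samples with the standard library. This sketch uses <code>statistics.quantiles</code>, whose default "exclusive" interpolation may differ slightly from what a monitoring system reports (Prometheus, for example, estimates quantiles from histogram buckets). The sample latencies are hypothetical.</p>

```python
import statistics


def latency_summary(latencies_ms):
    """Return p50/p95/p99 of request latencies (requires at least two samples)."""
    # quantiles(n=100) returns the 99 cut points for percentiles 1 through 99
    cuts = statistics.quantiles(latencies_ms, n=100)
    return {"p50": cuts[49], "p95": cuts[94], "p99": cuts[98]}


samples = list(range(1, 101))  # hypothetical latencies of 1..100 ms
print(latency_summary(samples))
```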
<p><strong>Logging</strong>:</p>
<ul>
<li>Request/response logging for debugging</li>
<li>Error tracking and alerting</li>
<li>Model version and configuration tracking</li>
<li>A/B test result aggregation</li>
</ul>
<p><strong>Tracing</strong>:</p>
<ul>
<li>Distributed tracing for microservices</li>
<li>Identify bottlenecks in pipelines</li>
<li>Understand cross-service dependencies</li>
<li>Debug performance issues</li>
</ul>
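<p>Production tracing is usually done with OpenTelemetry, which also propagates span context across service boundaries. The core idea for finding bottlenecks, timing each pipeline stage as a nested span, fits in a few lines; this is an in-process sketch only, and the <code>sleep</code> calls are hypothetical stand-ins for real stages.</p>

```python
import time
from contextlib import contextmanager


@contextmanager
def span(name, records):
    """Append (name, seconds) for the wrapped block, like a minimal tracing span."""
    start = time.perf_counter()
    try:
        yield
    finally:
        records.append((name, time.perf_counter() - start))


records = []
with span("request", records):
    with span("preprocess", records):
        time.sleep(0.001)  # hypothetical fast stage
    with span("inference", records):
        time.sleep(0.02)  # hypothetical slow stage

# Inner spans close first, so "request" is appended last and covers both stages
stages = [r for r in records if r[0] != "request"]
bottleneck = max(stages, key=lambda r: r[1])
print(f"bottleneck: {bottleneck[0]} ({bottleneck[1] * 1000:.1f} ms)")
```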
</section>
</section>
<section id="learning-resources" class="level1">
<h1>Learning Resources</h1>
<section id="python-1" class="level2">
<h2 class="anchored" data-anchor-id="python-1" id="python-1">Python</h2>
<ul>
<li><strong>Official PyTorch Tutorials</strong>: <a href="https://pytorch.org/tutorials/">pytorch.org/tutorials</a></li>
<li><strong>TensorFlow Guides</strong>: <a href="https://tensorflow.org/tutorials">tensorflow.org/tutorials</a></li>
<li><strong>Fast.ai Course</strong>: Practical deep learning for coders</li>
<li><strong>Papers with Code</strong>: Browse implementations of latest research</li>
<li><strong>Kaggle</strong>: Competitions and notebooks for hands-on learning</li>
</ul>
</section>
<section id="c-1" class="level2">
<h2 class="anchored" data-anchor-id="c-1" id="c-1">C++</h2>
<ul>
<li><strong>Learn OpenCV</strong>: <a href="https://learnopencv.com">learnopencv.com</a> for practical tutorials</li>
<li><strong>LibTorch Documentation</strong>: <a href="https://pytorch.org/cppdocs">pytorch.org/cppdocs</a></li>
<li><strong>Modern C++ for CV</strong>: Focus on C++17/20 features</li>
<li><strong>CUDA Programming Guide</strong>: For GPU acceleration</li>
<li><strong>Effective Modern C++</strong>: Book by Scott Meyers</li>
</ul>
</section>
<section id="javascript-1" class="level2">
<h2 class="anchored" data-anchor-id="javascript-1" id="javascript-1">JavaScript</h2>
<ul>
<li><strong>TensorFlow.js Documentation</strong>: <a href="https://js.tensorflow.org">js.tensorflow.org</a></li>
<li><strong>ml5.js Examples</strong>: <a href="https://ml5js.org">ml5js.org</a> for creative coding</li>
<li><strong>WebGL Fundamentals</strong>: Understanding GPU acceleration</li>
<li><strong>JavaScript.info</strong>: Deep dive into modern JavaScript</li>
<li><strong>MDN Web Docs</strong>: Authoritative web API reference</li>
</ul>
</section>
<section id="go-1" class="level2">
<h2 class="anchored" data-anchor-id="go-1" id="go-1">Go</h2>
<ul>
<li><strong>Gorgonia Documentation</strong>: <a href="https://gorgonia.org">gorgonia.org</a></li>
<li><strong>GoCV Examples</strong>: <a href="https://gocv.io/getting-started">gocv.io/getting-started</a></li>
<li><strong>A Tour of Go</strong>: <a href="https://tour.golang.org">tour.golang.org</a> for language basics</li>
<li><strong>Go by Example</strong>: <a href="https://gobyexample.com">gobyexample.com</a> for practical patterns</li>
<li><strong>Effective Go</strong>: <a href="https://golang.org/doc/effective_go">golang.org/doc/effective_go</a></li>
</ul>
</section>
</section>
<section id="conclusion" class="level1">
<h1>Conclusion</h1>
<p>The choice of programming language for computer vision and machine learning projects depends on a careful balance of performance requirements, development speed, team expertise, and deployment constraints. While Python dominates research and initial development, production systems often benefit from C++’s performance, Go’s efficiency, or JavaScript’s accessibility.</p>
<p>The most successful ML systems typically leverage multiple languages, using each for its strengths: Python for experimentation and training, C++ for performance-critical components, Go for scalable infrastructure, and JavaScript for user interfaces. Understanding the capabilities and trade-offs of each language enables you to architect systems that are both powerful and maintainable.</p>
<p>As the field evolves, the boundaries between languages blur through improved interoperability tools, cross-compilation, and unified runtime environments. The key is not to seek a single “best” language, but to develop proficiency across multiple languages and understand when each is the right tool for the job.</p>
<p>Whether you’re building research prototypes, deploying models to millions of users, or creating interactive educational tools, understanding how these languages intersect with computer vision and machine learning will prepare you for a wide range of challenges in this rapidly advancing field.</p>
<section id="summary-table" class="level2">
<h2 class="anchored" data-anchor-id="summary-table" id="summary-table">Summary Table</h2>
<div id="tbl-summary" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-tbl figure">
<figcaption class="quarto-float-caption-top quarto-float-caption quarto-float-tbl" id="tbl-summary-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Table&nbsp;1: Language Comparison Summary
</figcaption>
<div aria-describedby="tbl-summary-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<table class="caption-top table">
<colgroup>
<col style="width: 19%">
<col style="width: 16%">
<col style="width: 20%">
<col style="width: 17%">
<col style="width: 25%">
</colgroup>
<thead>
<tr class="header">
<th>Language</th>
<th>Best For</th>
<th>Performance</th>
<th>Ecosystem</th>
<th>Learning Curve</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>Python</td>
<td>Research, Prototyping, Training</td>
<td>Moderate</td>
<td>Excellent</td>
<td>Easy</td>
</tr>
<tr class="even">
<td>C++</td>
<td>Production, Embedded, Real-time</td>
<td>Excellent</td>
<td>Good</td>
<td>Hard</td>
</tr>
<tr class="odd">
<td>JavaScript</td>
<td>Web Apps, Demos, Client-side</td>
<td>Moderate</td>
<td>Good</td>
<td>Easy</td>
</tr>
<tr class="even">
<td>Go</td>
<td>Infrastructure, Microservices</td>
<td>Good</td>
<td>Growing</td>
<td>Moderate</td>
</tr>
</tbody>
</table>
</div>
</figure>
</div>
<hr>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Additional Resources
</div>
</div>
<div class="callout-body-container callout-body">
<p>For more information on specific topics, refer to the linked documentation and tutorials throughout this guide. The ML/CV landscape evolves rapidly, so always check for the latest versions and best practices.</p>
</div>
</div>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Getting Started
</div>
</div>
<div class="callout-body-container callout-body">
<p>If you’re new to ML/CV, start with Python and PyTorch. Once comfortable, explore other languages based on your specific deployment needs and performance requirements.</p>
</div>
</div>



</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[SGLang: Comprehensive Guide to Structured Generation Language]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/deployment/sglang/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/deployment/sglang/</guid>
      <pubDate>Mon, 25 Aug 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>tutorial</category>
      <category>intermediate</category>
      <content:encoded><![CDATA[






<section id="sglang-comprehensive-guide-to-structured-generation-language" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/deployment/sglang/lang.png" class="img-fluid"></p>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>About This Guide
</div>
</div>
<div class="callout-body-container callout-body">
<p>This guide covers SGLang (Structured Generation Language), a framework for programming and serving large language models (LLMs) and vision-language models. It explains SGLang's frontend language and backend runtime, and shows how to install and use them.</p>
</div>
</div>
<section id="sec-introduction" class="level2">
<h2 class="anchored" data-anchor-id="sec-introduction" id="sec-introduction">Introduction</h2>
<p>SGLang (Structured Generation Language) is a framework for programming and serving large language models (LLMs) and vision-language models. SGLang co-designs the frontend programming interface and the backend runtime system. The goal is substantial throughput gains while keeping LLM programs simple and flexible to write.</p>
</section>
<section id="sec-what-is-sglang" class="level2">
<h2 class="anchored" data-anchor-id="sec-what-is-sglang" id="sec-what-is-sglang">What is SGLang?</h2>
<p>SGLang is a fast serving framework for large language models and vision-language models. It makes interaction with models faster and more controllable by co-designing the backend runtime and the frontend language. The frontend simplifies programming with primitives for generation and parallelism control. The runtime accelerates execution with optimizations such as RadixAttention.</p>
<section id="key-benefits" class="level3">
<h3 class="anchored" data-anchor-id="key-benefits" id="key-benefits">Key Benefits</h3>
<div class="tabset-margin-container"></div><div class="panel-tabset">
<ul class="nav nav-tabs" role="tablist"><li class="nav-item" role="presentation"><a class="nav-link active" id="tabset-1-1-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-1" role="tab" aria-controls="tabset-1-1" aria-selected="true" href="">Performance</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-1-2-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-2" role="tab" aria-controls="tabset-1-2" aria-selected="false" href="">Controllability</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-1-3-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-3" role="tab" aria-controls="tabset-1-3" aria-selected="false" href="">Expressiveness</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-1-4-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-4" role="tab" aria-controls="tabset-1-4" aria-selected="false" href="">Efficiency</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-1-5-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-5" role="tab" aria-controls="tabset-1-5" aria-selected="false" href="">Multimodal Support</a></li></ul>
<div class="tab-content">
<div id="tabset-1-1" class="tab-pane active" role="tabpanel" aria-labelledby="tabset-1-1-tab">
<p>Up to <strong>5x higher throughput</strong> than baseline inference systems in the SGLang authors’ benchmarks, driven largely by KV cache reuse and efficient batching.</p>
</div>
<div id="tabset-1-2" class="tab-pane" role="tabpanel" aria-labelledby="tabset-1-2-tab">
<p>Fine-grained control over generation processes with structured primitives and constraint handling.</p>
</div>
<div id="tabset-1-3" class="tab-pane" role="tabpanel" aria-labelledby="tabset-1-3-tab">
<p>Rich primitives for complex LLM programming patterns including parallel execution and multi-step reasoning.</p>
</div>
<div id="tabset-1-4" class="tab-pane" role="tabpanel" aria-labelledby="tabset-1-4-tab">
<p>Advanced caching and optimization techniques including RadixAttention for automatic KV cache reuse.</p>
</div>
<div id="tabset-1-5" class="tab-pane" role="tabpanel" aria-labelledby="tabset-1-5-tab">
<p>Native support for both language and vision-language models through a unified processing pipeline.</p>
</div>
</div>
</div>
</section>
</section>
<section id="sec-key-features" class="level2">
<h2 class="anchored" data-anchor-id="sec-key-features" id="sec-key-features">Key Features</h2>
<section id="frontend-language-features" class="level3">
<h3 class="anchored" data-anchor-id="frontend-language-features" id="frontend-language-features">Frontend Language Features</h3>
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">graph TD
    A[Frontend Language] --&gt; B[Embedded DSL]
    A --&gt; C[Generation Primitives]
    A --&gt; D[Parallelism Control]
    A --&gt; E[Structured Outputs]
    A --&gt; F[Template System]
    
    B --&gt; B1[Python Integration]
    C --&gt; C1["gen()" function]
    C --&gt; C2["select()" function]
    D --&gt; D1["fork()" for Parallel]
    E --&gt; E1[JSON/XML Support]
    F --&gt; F1[Dynamic Prompts]
</pre>
</div>
<p></p></figure><p></p>
</div>
</div>
</div>
<ul>
<li><strong>Embedded DSL</strong>: Domain-specific language embedded in Python</li>
<li><strong>Generation Primitives</strong>: Built-in functions for text generation and control</li>
<li><strong>Parallelism Control</strong>: Native support for parallel generation calls</li>
<li><strong>Structured Outputs</strong>: Easy handling of JSON, XML, and custom formats</li>
<li><strong>Template System</strong>: Powerful templating for dynamic prompt construction</li>
</ul>
</section>
<section id="backend-runtime-features" class="level3">
<h3 class="anchored" data-anchor-id="backend-runtime-features" id="backend-runtime-features">Backend Runtime Features</h3>
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">graph TD
    A[Backend Runtime] --&gt; B[RadixAttention]
    A --&gt; C[Zero-overhead Scheduler]
    A --&gt; D[Continuous Batching]
    A --&gt; E[Speculative Decoding]
    A --&gt; F[Multi-modal Processing]
    A --&gt; G[Quantization Support]
    A --&gt; H[Parallel Execution]
    
    B --&gt; B1[KV Cache Reuse]
    D --&gt; D1[Dynamic Batching]
    G --&gt; G1[FP4/FP8/INT4/AWQ/GPTQ]
    H --&gt; H1[Tensor/Pipeline/Expert/Data]
</pre>
</div>
<p></p></figure><p></p>
</div>
</div>
</div>
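<p>Continuous batching, listed in the diagram above, means the scheduler refills free batch slots as soon as individual requests finish. A static scheduler, by contrast, waits for the whole batch to drain before admitting new work. The toy simulation below sketches that idea only. It assumes each step decodes one token for every active request and uses made-up request lengths; it is not SGLang’s scheduler.</p>

```python
from collections import deque


def simulate(request_lengths, max_batch, continuous=True):
    """Return the number of decode steps needed to finish all requests.

    Each request needs `length` decode steps; every step advances each
    active request by one token.
    """
    waiting = deque(request_lengths)
    active = []  # remaining tokens for each running request
    steps = 0
    while waiting or active:
        # Continuous batching admits new work whenever a slot is free;
        # static batching admits a new batch only once all slots are empty.
        if continuous or not active:
            while waiting and len(active) < max_batch:
                active.append(waiting.popleft())
        steps += 1
        active = [r - 1 for r in active if r - 1 > 0]
    return steps


lengths = [2, 10, 2, 2, 2, 2]
print(simulate(lengths, max_batch=2, continuous=False))  # 14
print(simulate(lengths, max_batch=2, continuous=True))   # 10
```

<p>With static batching, the short requests queue behind the 10-token request. Continuous batching slots them in beside it as soon as a slot frees up.</p>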
</section>
</section>
<section id="sec-architecture" class="level2">
<h2 class="anchored" data-anchor-id="sec-architecture" id="sec-architecture">Architecture Overview</h2>
<p>SGLang’s architecture consists of two main components:</p>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center collapsed" data-bs-toggle="collapse" data-bs-target=".callout-2-contents" aria-controls="callout-2" aria-expanded="false" aria-label="Toggle callout">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Architecture Details
</div>
<div class="callout-btn-toggle d-inline-block border-0 py-1 ps-1 pe-0 float-end"><i class="callout-toggle"></i></div>
</div>
<div id="callout-2" class="callout-2-contents callout-collapse collapse">
<div class="callout-body-container callout-body">
<section id="frontend-language" class="level3">
<h3 class="anchored" data-anchor-id="frontend-language" id="frontend-language">1. Frontend Language</h3>
<p>The frontend provides a Python-embedded DSL that simplifies LLM programming with:</p>
<ul>
<li>Intuitive syntax for generation tasks</li>
<li>Built-in primitives for common patterns</li>
<li>Automatic optimization of generation calls</li>
<li>Type safety and error handling</li>
</ul>
</section>
<section id="backend-runtime" class="level3">
<h3 class="anchored" data-anchor-id="backend-runtime" id="backend-runtime">2. Backend Runtime</h3>
<p>The backend introduces RadixAttention, a technique for automatic and efficient KV cache reuse across multiple LLM generation calls. The runtime includes:</p>
<ul>
<li>High-performance serving engine</li>
<li>Advanced memory management</li>
<li>Automatic optimization passes</li>
<li>Multi-GPU/multi-node support</li>
</ul>
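<p>The following toy prefix cache over token IDs makes prefix-based KV cache reuse concrete. It is a sketch under simplifying assumptions:</p>
<ul>
<li>It uses a plain, uncompressed trie rather than a true radix tree with merged edges.</li>
<li>It stores placeholder strings instead of real KV tensors.</li>
<li>It has no eviction.</li>
</ul>
<p>The class and method names are hypothetical.</p>

```python
class PrefixCache:
    """Toy token-level trie standing in for RadixAttention's radix tree.

    Each node holds a placeholder for the KV entry of one token; a real
    runtime stores GPU tensors and evicts least-recently-used leaves.
    """

    def __init__(self):
        self.root = {}

    def match_prefix(self, tokens):
        """Return how many leading tokens already have cached KV entries."""
        node, matched = self.root, 0
        for tok in tokens:
            if tok not in node:
                break
            node = node[tok]["children"]
            matched += 1
        return matched

    def insert(self, tokens):
        """Cache (placeholder) KV entries for every token in `tokens`."""
        node = self.root
        for i, tok in enumerate(tokens):
            if tok not in node:
                node[tok] = {"kv": f"kv@{i}", "children": {}}
            node = node[tok]["children"]


cache = PrefixCache()
system_prompt = [101, 7, 8, 9]  # hypothetical token IDs for a shared prompt
cache.insert(system_prompt + [11, 12])

new_request = system_prompt + [21, 22, 23]
reused = cache.match_prefix(new_request)
print(f"reused {reused} tokens, computing {len(new_request) - reused} new ones")
```

<p>Because a token's KV entry depends on every token before it, the tree path (the full prefix) identifies a reusable entry, not the token alone.</p>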
</section>
</div>
</div>
</div>
</section>
<section id="sec-installation" class="level2">
<h2 class="anchored" data-anchor-id="sec-installation" id="sec-installation">Installation and Setup</h2>
<section id="prerequisites" class="level3">
<h3 class="anchored" data-anchor-id="prerequisites" id="prerequisites">Prerequisites</h3>
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Important</span>System Requirements
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li>Python 3.8 or higher</li>
<li>CUDA 11.8+ (for GPU acceleration)</li>
<li>PyTorch 2.0+</li>
</ul>
</div>
</div>
</section>
<section id="basic-installation" class="level3">
<h3 class="anchored" data-anchor-id="basic-installation" id="basic-installation">Basic Installation</h3>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>install.sh</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1" data-filename="install.sh"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Install from PyPI; the [all] extra includes the serving runtime</span></span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install <span class="at">--upgrade</span> pip</span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install <span class="st">"sglang[all]"</span></span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Or install from source (the Python package lives in python/)</span></span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a><span class="fu">git</span> clone https://github.com/sgl-project/sglang.git</span>
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a><span class="bu">cd</span> sglang</span>
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install <span class="at">-e</span> <span class="st">"python[all]"</span></span></code></pre></div></div>
</div>
</section>
<section id="gpu-support" class="level3">
<h3 class="anchored" data-anchor-id="gpu-support" id="gpu-support">GPU Support</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb2"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="co"># NVIDIA GPUs (CUDA): the standard install targets CUDA</span></span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install <span class="st">"sglang[all]"</span></span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a><span class="co"># AMD GPUs (ROCm): use the HIP extra; extra names have changed between</span></span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a><span class="co"># releases, so check the installation docs for your version</span></span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install <span class="st">"sglang[all_hip]"</span></span></code></pre></div></div>
</section>
<section id="docker-installation" class="level3">
<h3 class="anchored" data-anchor-id="docker-installation" id="docker-installation">Docker Installation</h3>
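<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb3"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Pull official Docker image</span></span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a><span class="ex">docker</span> pull lmsysorg/sglang:latest</span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Run with GPU support; the container needs a command that launches the server</span></span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a><span class="ex">docker</span> run <span class="at">--gpus</span> all <span class="at">--shm-size</span> 32g <span class="at">-p</span> 30000:30000 <span class="dt">\</span></span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a>    <span class="at">-v</span> ~/.cache/huggingface:/root/.cache/huggingface <span class="dt">\</span></span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a>    <span class="at">--ipc</span><span class="op">=</span>host lmsysorg/sglang:latest <span class="dt">\</span></span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a>    python3 <span class="at">-m</span> sglang.launch_server <span class="at">--model-path</span> Qwen/Qwen2.5-7B-Instruct <span class="dt">\</span></span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a>    <span class="at">--host</span> 0.0.0.0 <span class="at">--port</span> 30000</span></code></pre></div></div>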
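<p>With a server running on port 30000, you can query it over HTTP. The sketch below uses only Python’s standard library to build a request for SGLang’s native <code>/generate</code> endpoint, which takes a <code>text</code> prompt and a <code>sampling_params</code> object. The helper names and parameter values are illustrative. Nothing is sent until you call <code>send()</code>, so the request can be inspected without a live server.</p>

```python
import json
import urllib.request


def build_generate_request(prompt, host="http://localhost:30000",
                           max_new_tokens=32, temperature=0.0):
    """Build (but do not send) a POST request for the /generate endpoint."""
    body = {
        "text": prompt,
        "sampling_params": {"max_new_tokens": max_new_tokens, "temperature": temperature},
    }
    return urllib.request.Request(
        f"{host}/generate",
        data=json.dumps(body).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def send(req, timeout=60):
    """Send the request and return the generated text (needs a running server)."""
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read())["text"]


if __name__ == "__main__":
    request = build_generate_request("The capital of France is")
    print(request.get_method(), request.full_url)
    # print(send(request))  # uncomment once the server is up
```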
</section>
</section>
<section id="sec-core-concepts" class="level2">
<h2 class="anchored" data-anchor-id="sec-core-concepts" id="sec-core-concepts">Core Concepts</h2>
<section id="generation-functions" class="level3">
<h3 class="anchored" data-anchor-id="generation-functions" id="generation-functions">1. Generation Functions</h3>
<p>The core abstraction in SGLang is the generation function, which encapsulates prompts and generation logic:</p>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>basic_generation.py</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb4" data-filename="basic_generation.py"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> sglang <span class="im">as</span> sgl</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a><span class="at">@sgl.function</span></span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> simple_chat(s, user_message):</span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> sgl.user(user_message)</span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> sgl.assistant(sgl.gen(<span class="st">"response"</span>, max_tokens<span class="op">=</span><span class="dv">100</span>))</span></code></pre></div></div>
</div>
</section>
<section id="state-management" class="level3">
<h3 class="anchored" data-anchor-id="state-management" id="state-management">2. State Management</h3>
<p>SGLang uses a state object <code>s</code> to track conversation history and manage generation context:</p>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>state_management.py</strong></pre>
</div>
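<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb5" data-filename="state_management.py"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> sglang <span class="im">as</span> sgl</span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a><span class="at">@sgl.function</span></span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> multi_turn_chat(s, messages):</span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Give each turn its own variable name; reusing "response" would</span></span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a>    <span class="co"># overwrite the previous turn's output</span></span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> i, msg <span class="kw">in</span> <span class="bu">enumerate</span>(messages):</span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a>        s <span class="op">+=</span> sgl.user(msg)</span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a>        s <span class="op">+=</span> sgl.assistant(sgl.gen(<span class="ss">f"response_</span><span class="sc">{</span>i<span class="sc">}</span><span class="ss">"</span>, stop<span class="op">=</span><span class="st">"</span><span class="ch">\n</span><span class="st">"</span>))</span></code></pre></div></div>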
</div>
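<p>The <code>s</code> state model can be made concrete without a running model. The sketch below mimics it with a hypothetical <code>MockState</code> class and a fake model:</p>
<ul>
<li><code>+=</code> appends text.</li>
<li>A named generation stores its output.</li>
<li><code>s["name"]</code> reads the output back.</li>
</ul>
<p>This illustrates the programming model only; it is not how SGLang implements its state object. It also shows why reusing one variable name keeps only the latest output.</p>

```python
class Gen:
    """Hypothetical marker for a named generation call (plays the role of sgl.gen)."""

    def __init__(self, name):
        self.name = name


class MockState:
    """Hypothetical stand-in for SGLang's state object `s`."""

    def __init__(self, model):
        self.text = ""      # everything appended so far (the running prompt)
        self.vars = {}      # named outputs, read back with state["name"]
        self.model = model  # callable mapping prompt text to completion text

    def __iadd__(self, item):
        if isinstance(item, Gen):
            out = self.model(self.text)  # "generate" from the full context
            self.vars[item.name] = out   # an existing name is overwritten
            self.text += out
        else:
            self.text += item
        return self

    def __getitem__(self, name):
        return self.vars[name]


def length_model(prompt):
    # Fake model: reports how many characters of context it received
    return f"[{len(prompt)} chars]"


s = MockState(length_model)
for i, msg in enumerate(["Hi", "How are you?"]):
    s += f"User: {msg}\nAssistant: "
    s += Gen(f"response_{i}")
    s += "\n"

print(s["response_0"], s["response_1"])
```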
</section>
<section id="control-primitives" class="level3">
<h3 class="anchored" data-anchor-id="control-primitives" id="control-primitives">3. Control Primitives</h3>
<div class="tabset-margin-container"></div><div class="panel-tabset">
<ul class="nav nav-tabs" role="tablist"><li class="nav-item" role="presentation"><a class="nav-link active" id="tabset-2-1-tab" data-bs-toggle="tab" data-bs-target="#tabset-2-1" role="tab" aria-controls="tabset-2-1" aria-selected="true" href="">gen()</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-2-2-tab" data-bs-toggle="tab" data-bs-target="#tabset-2-2" role="tab" aria-controls="tabset-2-2" aria-selected="false" href="">select()</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-2-3-tab" data-bs-toggle="tab" data-bs-target="#tabset-2-3" role="tab" aria-controls="tabset-2-3" aria-selected="false" href="">fork()</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-2-4-tab" data-bs-toggle="tab" data-bs-target="#tabset-2-4" role="tab" aria-controls="tabset-2-4" aria-selected="false" href="">image()</a></li></ul>
<div class="tab-content">
<div id="tabset-2-1" class="tab-pane active" role="tabpanel" aria-labelledby="tabset-2-1-tab">
<p>Generate text with specified constraints and parameters.</p>
</div>
<div id="tabset-2-2" class="tab-pane" role="tabpanel" aria-labelledby="tabset-2-2-tab">
<p>Choose one option from a predefined list, such as the answer to a multiple-choice question.</p>
</div>
<div id="tabset-2-3" class="tab-pane" role="tabpanel" aria-labelledby="tabset-2-3-tab">
<p>Create parallel execution branches for concurrent processing.</p>
</div>
<div id="tabset-2-4" class="tab-pane" role="tabpanel" aria-labelledby="tabset-2-4-tab">
<p>Process image inputs for vision-language model tasks.</p>
</div>
</div>
</div>
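<p>Conceptually, <code>select()</code> does not sample free text. It scores each candidate continuation with the model and returns the highest-scoring one. SGLang’s documentation describes token-length-normalized log-probability as the default scoring method, at least in recent versions. The sketch below applies that rule to made-up per-token log-probabilities, with no model involved. The <code>select_choice</code> helper is hypothetical.</p>

```python
def select_choice(choice_logprobs):
    """Pick the choice with the highest mean per-token log-probability.

    `choice_logprobs` maps each candidate string to the log-probabilities
    a model assigned to that candidate's tokens (made-up values here).
    """
    def mean(xs):
        return sum(xs) / len(xs)

    scores = {choice: mean(lps) for choice, lps in choice_logprobs.items()}
    best = max(scores, key=scores.get)
    return best, scores


# Made-up numbers: "Yes" and "No" are one token each, "Not sure" is two tokens
logprobs = {
    "Yes": [-1.2],
    "No": [-1.6],
    "Not sure": [-0.9, -0.3],
}
best, scores = select_choice(logprobs)
print(best)
```

<p>Summing instead of averaging would give <code>Yes</code> and <code>Not sure</code> the same score (both total -1.2). This is one reason length normalization matters when choices differ in token count.</p>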
</section>
</section>
<section id="sec-frontend" class="level2">
<h2 class="anchored" data-anchor-id="sec-frontend" id="sec-frontend">Frontend Language Features</h2>
<section id="generation-primitives" class="level3">
<h3 class="anchored" data-anchor-id="generation-primitives" id="generation-primitives">Generation Primitives</h3>
<section id="basic-text-generation" class="level4">
<h4 class="anchored" data-anchor-id="basic-text-generation">Basic Text Generation</h4>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>story_generator.py</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb6" data-filename="story_generator.py"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> sglang <span class="im">as</span> sgl</span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a><span class="at">@sgl.function</span></span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> story_writer(s, theme):</span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> <span class="ss">f"Write a story about </span><span class="sc">{</span>theme<span class="sc">}</span><span class="ss">:</span><span class="ch">\n</span><span class="ss">"</span></span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> sgl.gen(<span class="st">"story"</span>, max_tokens<span class="op">=</span><span class="dv">500</span>, temperature<span class="op">=</span><span class="fl">0.7</span>)</span></code></pre></div></div>
</div>
</section>
<section id="structured-generation" class="level4">
<h4 class="anchored" data-anchor-id="structured-generation">Structured Generation</h4>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>json_generator.py</strong></pre>
</div>
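<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb7" data-filename="json_generator.py"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> sglang <span class="im">as</span> sgl</span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a><span class="at">@sgl.function</span></span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> json_generator(s, query):</span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> <span class="ss">f"Generate JSON for: </span><span class="sc">{</span>query<span class="sc">}</span><span class="ch">\n</span><span class="ss">"</span></span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a>    <span class="co"># This regex only forces a leading "{" and a trailing "}"; it does not</span></span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a>    <span class="co"># guarantee valid JSON. Spell out keys and value patterns for stricter output.</span></span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> sgl.gen(<span class="st">"json"</span>, max_tokens<span class="op">=</span><span class="dv">200</span>, regex<span class="op">=</span><span class="vs">r'</span><span class="ch">\{</span><span class="dv">.</span><span class="op">*</span><span class="ch">\}</span><span class="vs">'</span>)</span></code></pre></div></div>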
</div>
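<p>The pattern <code>\{.*\}</code> in the example above constrains only the outer braces, so output that matches it need not be valid JSON. The check below contrasts it with a field-level pattern of the kind used in regex-constrained decoding. It uses Python’s <code>re</code> and <code>json</code> modules rather than SGLang. The <code>name</code>/<code>age</code> schema is a hypothetical example.</p>

```python
import json
import re

# Pattern from the example above: only the outer braces are constrained
LOOSE = r"\{.*\}"
# Hypothetical field-level pattern: fixed keys, bounded name, 1-3 digit age
STRICT = r'\{"name": "[A-Za-z ]{1,32}", "age": [0-9]{1,3}\}'


def check(text):
    """Return (matches LOOSE, matches STRICT, parses as JSON) for `text`."""
    loose_ok = re.fullmatch(LOOSE, text) is not None
    strict_ok = re.fullmatch(STRICT, text) is not None
    try:
        json.loads(text)
        valid_json = True
    except json.JSONDecodeError:
        valid_json = False
    return loose_ok, strict_ok, valid_json


for sample in ['{"name": "Ada", "age": 36}', "{name: Ada, age: thirty-six}"]:
    print(sample, check(sample))
```

<p>Both samples satisfy the loose pattern, but only the first parses as JSON. The stricter pattern rejects the malformed one.</p>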
</section>
<section id="conditional-generation" class="level4">
<h4 class="anchored" data-anchor-id="conditional-generation">Conditional Generation</h4>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>conditional_response.py</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb8" data-filename="conditional_response.py"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> sglang <span class="im">as</span> sgl</span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a><span class="at">@sgl.function</span></span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> conditional_response(s, question, context):</span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> <span class="ss">f"Context: </span><span class="sc">{</span>context<span class="sc">}</span><span class="ch">\n</span><span class="ss">"</span></span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> <span class="ss">f"Question: </span><span class="sc">{</span>question<span class="sc">}</span><span class="ch">\n</span><span class="ss">"</span></span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a>    <span class="co"># First, determine if answerable</span></span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> <span class="st">"Is this answerable? "</span></span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> sgl.gen(<span class="st">"answerable"</span>, choices<span class="op">=</span>[<span class="st">"Yes"</span>, <span class="st">"No"</span>])</span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-12"><a href="#cb8-12" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> s[<span class="st">"answerable"</span>] <span class="op">==</span> <span class="st">"Yes"</span>:</span>
<span id="cb8-13"><a href="#cb8-13" aria-hidden="true" tabindex="-1"></a>        s <span class="op">+=</span> <span class="st">"</span><span class="ch">\n</span><span class="st">Answer: "</span></span>
<span id="cb8-14"><a href="#cb8-14" aria-hidden="true" tabindex="-1"></a>        s <span class="op">+=</span> sgl.gen(<span class="st">"answer"</span>, max_tokens<span class="op">=</span><span class="dv">100</span>)</span>
<span id="cb8-15"><a href="#cb8-15" aria-hidden="true" tabindex="-1"></a>    <span class="cf">else</span>:</span>
<span id="cb8-16"><a href="#cb8-16" aria-hidden="true" tabindex="-1"></a>        s <span class="op">+=</span> <span class="st">"</span><span class="ch">\n</span><span class="st">I don't have enough information to answer this question."</span></span></code></pre></div></div>
</div>
</section>
</section>
<section id="parallel-execution" class="level3">
<h3 class="anchored" data-anchor-id="parallel-execution" id="parallel-execution">Parallel Execution</h3>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>parallel_processing.py</strong></pre>
</div>
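<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb9" data-filename="parallel_processing.py"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> sglang <span class="im">as</span> sgl</span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a><span class="at">@sgl.function</span></span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> parallel_summarization(s, documents):</span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Fork the state into one branch per document; branches run in parallel</span></span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a>    forks <span class="op">=</span> s.fork(<span class="bu">len</span>(documents))</span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> branch, doc <span class="kw">in</span> <span class="bu">zip</span>(forks, documents):</span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a>        branch <span class="op">+=</span> <span class="ss">f"Summarize this document:</span><span class="ch">\n</span><span class="sc">{</span>doc<span class="sc">}</span><span class="ch">\n</span><span class="ss">Summary: "</span></span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a>        branch <span class="op">+=</span> sgl.gen(<span class="st">"summary"</span>, max_tokens<span class="op">=</span><span class="dv">128</span>, stop<span class="op">=</span><span class="st">"</span><span class="ch">\n\n</span><span class="st">"</span>)</span>
<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-11"><a href="#cb9-11" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Reading a branch's variable waits until that branch has produced it</span></span>
<span id="cb9-12"><a href="#cb9-12" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> i, branch <span class="kw">in</span> <span class="bu">enumerate</span>(forks):</span>
<span id="cb9-13"><a href="#cb9-13" aria-hidden="true" tabindex="-1"></a>        s <span class="op">+=</span> <span class="ss">f"Summary </span><span class="sc">{</span>i <span class="op">+</span> <span class="dv">1</span><span class="sc">}</span><span class="ss">: "</span> <span class="op">+</span> branch[<span class="st">"summary"</span>] <span class="op">+</span> <span class="st">"</span><span class="ch">\n</span><span class="st">"</span></span></code></pre></div></div>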
</div>
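<p>The fork/join idea behind <code>fork()</code> can be sketched without a model. Every branch starts from the same shared prefix and runs independently, and the parent collects the branch results in order. The sketch below uses <code>concurrent.futures</code> and a hypothetical <code>fake_summarize</code> function in place of real generation calls. In SGLang, the runtime schedules the branches itself, and the shared prefix is what RadixAttention can reuse across them.</p>

```python
from concurrent.futures import ThreadPoolExecutor


def fake_summarize(prompt):
    # Stand-in for a generation call: "summarize" as the document's first sentence
    doc = prompt.split("Document: ", 1)[1]
    return doc.split(". ")[0].rstrip(".") + "."


def parallel_summaries(documents, max_workers=4):
    prefix = "Summarize in one sentence.\n"  # shared by every branch
    prompts = [f"{prefix}Document: {doc}" for doc in documents]
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # map() preserves input order, so summaries[i] belongs to documents[i]
        summaries = list(pool.map(fake_summarize, prompts))
    return summaries


docs = [
    "Cats sleep a lot. They also purr.",
    "Rust has no garbage collector. It uses ownership.",
]
print(parallel_summaries(docs))
```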
</section>
<section id="template-system" class="level3">
<h3 class="anchored" data-anchor-id="template-system" id="template-system">Template System</h3>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>email_template.py</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb10" data-filename="email_template.py"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> sglang <span class="im">as</span> sgl</span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a><span class="at">@sgl.function</span></span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> email_generator(s, recipient, subject, tone<span class="op">=</span><span class="st">"professional"</span>):</span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> sgl.system(<span class="ss">f"Write emails in a </span><span class="sc">{</span>tone<span class="sc">}</span><span class="ss"> tone."</span>)</span>
<span id="cb10-6"><a href="#cb10-6" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Keep the template inside chat roles so chat models see a well-formed conversation</span></span>
<span id="cb10-7"><a href="#cb10-7" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> sgl.user(<span class="ss">f"To: </span><span class="sc">{</span>recipient<span class="sc">}</span><span class="ch">\n</span><span class="ss">Subject: </span><span class="sc">{</span>subject<span class="sc">}</span><span class="ch">\n\n</span><span class="ss">Write the email body."</span>)</span>
<span id="cb10-8"><a href="#cb10-8" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> sgl.assistant(sgl.gen(<span class="st">"body"</span>, max_tokens<span class="op">=</span><span class="dv">300</span>))</span></code></pre></div></div>
</div>
</section>
</section>
<section id="sec-backend" class="level2">
<h2 class="anchored" data-anchor-id="sec-backend" id="sec-backend">Backend Runtime</h2>
<section id="radixattention" class="level3">
<h3 class="anchored" data-anchor-id="radixattention" id="radixattention">RadixAttention</h3>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>RadixAttention Innovation
</div>
</div>
<div class="callout-body-container callout-body">
<p>RadixAttention structures and automates the reuse of Key-Value (KV) caches during runtime by storing them in a radix tree data structure.</p>
</div>
</div>
<p>This enables:</p>
<ul>
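<li><strong>Prefix Sharing</strong>: The KV cache for a common prompt prefix (a system prompt, few-shot examples, a shared document) is computed once and reused by later requests</li>
<li><strong>Memory Efficiency</strong>: A shared prefix is stored once in the tree rather than once per request</li>
<li><strong>Speed Improvements</strong>: On a cache hit, the runtime skips recomputing keys and values for the matched prefix and only processes the new suffix</li>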
</ul>
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">graph TD
    A[Input Prompts] --&gt; B[Radix Tree]
    B --&gt; C[Shared Prefixes]
    B --&gt; D[Unique Suffixes]
    C --&gt; E[KV Cache Reuse]
    D --&gt; F[New Computation]
    E --&gt; G[Performance Boost]
    F --&gt; G
</pre>
</div>
<p></p></figure><p></p>
</div>
</div>
</div>
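<p>The following toy sketch makes prefix matching over token IDs concrete. It is self-contained and makes simplifying assumptions. It uses an uncompressed trie (a radix tree also merges single-child chains into one edge). It only counts tokens and does not store KV tensors. The class and function names are hypothetical, and this is not SGLang's implementation.</p>

```python
class TrieNode:
    def __init__(self):
        # Maps the next token ID to the child node
        self.children = {}


class PrefixCache:
    """Toy prefix cache: remembers which token sequences already have KV entries."""

    def __init__(self):
        self.root = TrieNode()

    def match_prefix(self, tokens):
        """Return how many leading tokens are already cached."""
        node, matched = self.root, 0
        for tok in tokens:
            child = node.children.get(tok)
            if child is None:
                break
            node, matched = child, matched + 1
        return matched

    def insert(self, tokens):
        """Record that KV entries now exist for every prefix of tokens."""
        node = self.root
        for tok in tokens:
            node = node.children.setdefault(tok, TrieNode())


def serve(cache, tokens):
    """Return (reused, computed): tokens served from the cache vs. newly computed."""
    reused = cache.match_prefix(tokens)
    cache.insert(tokens)
    return reused, len(tokens) - reused


cache = PrefixCache()
system_prompt = [101, 7, 42, 9]                # hypothetical token IDs
print(serve(cache, system_prompt + [55, 56]))  # (0, 6): cold cache
print(serve(cache, system_prompt + [77]))      # (4, 1): shared prefix reused
```

<p>The second request shares only the four system-prompt tokens with the first, so only its one new token needs fresh computation.</p>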
</section>
<section id="continuous-batching" class="level3">
<h3 class="anchored" data-anchor-id="continuous-batching" id="continuous-batching">Continuous Batching</h3>
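<p>The runtime implements continuous batching. It schedules work one decoding step at a time rather than one whole batch at a time:</p>
<ul>
<li>New requests join the running batch as soon as capacity frees up, without waiting for the current batch to finish</li>
<li>Finished requests leave the batch immediately, so the batch size changes from step to step</li>
<li>Short requests no longer wait behind long ones, which keeps GPU utilization high</li>
</ul>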
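<p>A toy step-count simulation shows why admitting requests at every step helps. It assumes each request needs a known number of decoding steps. It ignores prefill cost and memory limits. The function names are hypothetical.</p>

```python
from collections import deque


def static_batching_steps(lengths, max_batch):
    """Each batch runs until its longest request finishes, then the next batch starts."""
    steps = 0
    for start in range(0, len(lengths), max_batch):
        steps += max(lengths[start:start + max_batch])
    return steps


def continuous_batching_steps(lengths, max_batch):
    """Requests are admitted and retired at every decoding step."""
    queue = deque(lengths)
    active = []  # remaining decode steps for each running request
    steps = 0
    while queue or active:
        # Fill any free slots from the waiting queue
        while queue and len(active) < max_batch:
            active.append(queue.popleft())
        # One decoding step emits one token for every running request
        active = [remaining - 1 for remaining in active]
        # Finished requests leave the batch immediately
        active = [remaining for remaining in active if remaining > 0]
        steps += 1
    return steps


lengths = [10, 2, 2, 2, 2, 10]  # decode steps needed per request
print(static_batching_steps(lengths, max_batch=2))      # 22
print(continuous_batching_steps(lengths, max_batch=2))  # 18
```

<p>With static batching, each short request occupies its slot until the long request in the same batch finishes. Continuous batching refills that slot immediately.</p>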
</section>
<section id="speculative-decoding" class="level3">
<h3 class="anchored" data-anchor-id="speculative-decoding" id="speculative-decoding">Speculative Decoding</h3>
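<p>Speculative decoding is an acceleration technique that leaves the model's output unchanged:</p>
<ul>
<li>A cheap draft model (or draft head) proposes several tokens ahead</li>
<li>The target model verifies all proposed tokens in a single forward pass</li>
<li>Proposed tokens are accepted up to the first mismatch, and the target model supplies the token at that position. Every step therefore yields at least one token.</li>
</ul>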
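<p>The accept/reject logic can be sketched for greedy decoding with two toy "models". Each is a plain function that maps a token list to the next token. This is a simplified illustration. Real systems score all drafted positions in one batched forward pass, and they use a probabilistic acceptance rule when sampling. All names here are hypothetical.</p>

```python
def speculative_greedy(target_next, draft_next, prompt, k, max_new):
    """Greedy speculative decoding; returns (tokens, verification_rounds)."""
    seq = list(prompt)
    rounds = 0
    while len(seq) - len(prompt) < max_new:
        # 1. The cheap draft model proposes k tokens autoregressively
        proposal, ctx = [], list(seq)
        for _ in range(k):
            tok = draft_next(ctx)
            proposal.append(tok)
            ctx.append(tok)
        # 2. The target model checks every proposed position
        #    (a real system scores them all in one batched forward pass)
        rounds += 1
        accepted, ctx = [], list(seq)
        for tok in proposal:
            expected = target_next(ctx)
            if expected != tok:
                accepted.append(expected)  # first mismatch: keep the target's token, drop the rest
                break
            accepted.append(tok)
            ctx.append(tok)
        else:
            accepted.append(target_next(ctx))  # every draft accepted: one bonus token
        seq.extend(accepted)
    return seq[:len(prompt) + max_new], rounds


def toy_target(ctx):
    # Toy target model: the next token is the previous one plus 1, modulo 10
    return (ctx[-1] + 1) % 10


def toy_draft(ctx):
    # Toy draft model: agrees with the target except right after a 5
    return 0 if ctx[-1] == 5 else (ctx[-1] + 1) % 10


tokens, rounds = speculative_greedy(toy_target, toy_draft, prompt=[0], k=3, max_new=8)
print(tokens)  # [0, 1, 2, 3, 4, 5, 6, 7, 8], identical to plain greedy decoding
print(rounds)  # 3 verification rounds instead of 8 sequential target steps
```

<p>If the draft model is always wrong, each round still yields one correct token from the target model. Output quality never degrades, and the worst case matches ordinary decoding.</p>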
</section>
</section>
<section id="sec-basic-examples" class="level2">
<h2 class="anchored" data-anchor-id="sec-basic-examples" id="sec-basic-examples">Basic Usage Examples</h2>
<section id="simple-text-generation" class="level3">
<h3 class="anchored" data-anchor-id="simple-text-generation" id="simple-text-generation">1. Simple Text Generation</h3>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>poem_generator.py</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb11" data-filename="poem_generator.py"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> sglang <span class="im">as</span> sgl</span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Set backend</span></span>
<span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a>sgl.set_default_backend(sgl.RuntimeEndpoint(<span class="st">"http://localhost:30000"</span>))</span>
<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a><span class="at">@sgl.function</span></span>
<span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> generate_poem(s, topic):</span>
<span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> <span class="ss">f"Write a haiku about </span><span class="sc">{</span>topic<span class="sc">}</span><span class="ss">:</span><span class="ch">\n</span><span class="ss">"</span></span>
<span id="cb11-9"><a href="#cb11-9" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> sgl.gen(<span class="st">"poem"</span>, max_tokens<span class="op">=</span><span class="dv">50</span>)</span>
<span id="cb11-10"><a href="#cb11-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-11"><a href="#cb11-11" aria-hidden="true" tabindex="-1"></a><span class="co"># Execute; .run() returns a state whose generated variables are accessible by name</span></span>
<span id="cb11-12"><a href="#cb11-12" aria-hidden="true" tabindex="-1"></a>result <span class="op">=</span> generate_poem.run(topic<span class="op">=</span><span class="st">"spring"</span>)</span>
<span id="cb11-13"><a href="#cb11-13" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(result[<span class="st">"poem"</span>])</span></code></pre></div></div>
</div>
</section>
<section id="multi-step-reasoning" class="level3">
<h3 class="anchored" data-anchor-id="multi-step-reasoning" id="multi-step-reasoning">2. Multi-step Reasoning</h3>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>math_solver.py</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb12" data-filename="math_solver.py"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><a href="#cb12-1" aria-hidden="true" tabindex="-1"></a><span class="at">@sgl.function</span></span>
<span id="cb12-2"><a href="#cb12-2" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> math_solver(s, problem):</span>
<span id="cb12-3"><a href="#cb12-3" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> <span class="ss">f"Problem: </span><span class="sc">{</span>problem<span class="sc">}</span><span class="ch">\n</span><span class="ss">"</span></span>
<span id="cb12-4"><a href="#cb12-4" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> <span class="st">"Let me solve this step by step.</span><span class="ch">\n</span><span class="st">"</span></span>
<span id="cb12-5"><a href="#cb12-5" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> <span class="st">"Step 1: "</span></span>
<span id="cb12-6"><a href="#cb12-6" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> sgl.gen(<span class="st">"step1"</span>, max_tokens<span class="op">=</span><span class="dv">50</span>, stop<span class="op">=</span><span class="st">"</span><span class="ch">\n</span><span class="st">"</span>)</span>
<span id="cb12-7"><a href="#cb12-7" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> <span class="st">"</span><span class="ch">\n</span><span class="st">Step 2: "</span></span>
<span id="cb12-8"><a href="#cb12-8" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> sgl.gen(<span class="st">"step2"</span>, max_tokens<span class="op">=</span><span class="dv">50</span>, stop<span class="op">=</span><span class="st">"</span><span class="ch">\n</span><span class="st">"</span>)</span>
<span id="cb12-9"><a href="#cb12-9" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> <span class="st">"</span><span class="ch">\n</span><span class="st">Therefore, the answer is: "</span></span>
<span id="cb12-10"><a href="#cb12-10" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> sgl.gen(<span class="st">"answer"</span>, max_tokens<span class="op">=</span><span class="dv">20</span>)</span>
<span id="cb12-11"><a href="#cb12-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-12"><a href="#cb12-12" aria-hidden="true" tabindex="-1"></a>result <span class="op">=</span> math_solver.run(problem<span class="op">=</span><span class="st">"What is 15% of 240?"</span>)</span></code></pre></div></div>
</div>
</section>
<section id="json-structured-output" class="level3">
<h3 class="anchored" data-anchor-id="json-structured-output" id="json-structured-output">3. JSON Structured Output</h3>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>info_extractor.py</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb13" data-filename="info_extractor.py"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a><span class="at">@sgl.function</span></span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> extract_info(s, text):</span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> <span class="ss">f"Extract key information from this text:</span><span class="ch">\n</span><span class="sc">{</span>text<span class="sc">}</span><span class="ch">\n</span><span class="ss">"</span></span>
<span id="cb13-4"><a href="#cb13-4" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> <span class="st">"Output as JSON:</span><span class="ch">\n</span><span class="st">"</span></span>
<span id="cb13-5"><a href="#cb13-5" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> sgl.gen(</span>
<span id="cb13-6"><a href="#cb13-6" aria-hidden="true" tabindex="-1"></a>        <span class="st">"info"</span>, </span>
<span id="cb13-7"><a href="#cb13-7" aria-hidden="true" tabindex="-1"></a>        max_tokens<span class="op">=</span><span class="dv">200</span>, </span>
<span id="cb13-8"><a href="#cb13-8" aria-hidden="true" tabindex="-1"></a>        regex<span class="op">=</span><span class="vs">r'\{"name": "[^"]+", "age": \d+, "location": "[^"]+"\}'</span>  <span class="co"># One fixed, parseable JSON shape</span></span>
<span id="cb13-9"><a href="#cb13-9" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb13-10"><a href="#cb13-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-11"><a href="#cb13-11" aria-hidden="true" tabindex="-1"></a>result <span class="op">=</span> extract_info.run(text<span class="op">=</span><span class="st">"John Smith is 30 years old and lives in New York."</span>)</span></code></pre></div></div>
</div>
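<p>Even with a regex constraint, the generated variable comes back as a string, so client code still has to parse it. The following minimal sketch assumes a fixed-order pattern for the three fields. The pattern and the helper name are illustrative, not part of SGLang. The sketch checks the text against the pattern and then decodes it with the standard <code>json</code> module.</p>

```python
import json
import re

# Illustrative fixed-order pattern for the three fields
PERSON_PATTERN = r'\{"name": "[^"]+", "age": \d+, "location": "[^"]+"\}'


def parse_person(text):
    """Validate generated text against the pattern, then decode it as JSON."""
    if re.fullmatch(PERSON_PATTERN, text) is None:
        raise ValueError(f"unexpected output: {text!r}")
    return json.loads(text)


info = parse_person('{"name": "John Smith", "age": 30, "location": "New York"}')
print(info["name"], info["age"])  # John Smith 30
```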
</section>
<section id="role-playing-conversation" class="level3">
<h3 class="anchored" data-anchor-id="role-playing-conversation" id="role-playing-conversation">4. Role-playing Conversation</h3>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>roleplay.py</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb14" data-filename="roleplay.py"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb14-1"><a href="#cb14-1" aria-hidden="true" tabindex="-1"></a><span class="at">@sgl.function</span></span>
<span id="cb14-2"><a href="#cb14-2" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> roleplay_chat(s, character, user_input):</span>
<span id="cb14-3"><a href="#cb14-3" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> sgl.system(<span class="ss">f"You are </span><span class="sc">{</span>character<span class="sc">}</span><span class="ss">. Stay in character."</span>)</span>
<span id="cb14-4"><a href="#cb14-4" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> sgl.user(user_input)</span>
<span id="cb14-5"><a href="#cb14-5" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> sgl.assistant(sgl.gen(<span class="st">"response"</span>, max_tokens<span class="op">=</span><span class="dv">150</span>))</span>
<span id="cb14-6"><a href="#cb14-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-7"><a href="#cb14-7" aria-hidden="true" tabindex="-1"></a>result <span class="op">=</span> roleplay_chat.run(character<span class="op">=</span><span class="st">"a wise old wizard"</span>, user_input<span class="op">=</span><span class="st">"How do I learn magic?"</span>)</span></code></pre></div></div>
</div>
</section>
</section>
<section id="sec-advanced-patterns" class="level2">
<h2 class="anchored" data-anchor-id="sec-advanced-patterns" id="sec-advanced-patterns">Advanced Programming Patterns</h2>
<section id="chain-of-thought-reasoning" class="level3">
<h3 class="anchored" data-anchor-id="chain-of-thought-reasoning" id="chain-of-thought-reasoning">1. Chain of Thought Reasoning</h3>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>cot_reasoning.py</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb15" data-filename="cot_reasoning.py"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb15-1"><a href="#cb15-1" aria-hidden="true" tabindex="-1"></a><span class="at">@sgl.function</span></span>
<span id="cb15-2"><a href="#cb15-2" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> cot_reasoning(s, question):</span>
<span id="cb15-3"><a href="#cb15-3" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> <span class="ss">f"Question: </span><span class="sc">{</span>question<span class="sc">}</span><span class="ch">\n</span><span class="ss">"</span></span>
<span id="cb15-4"><a href="#cb15-4" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> <span class="st">"Let me think through this step by step:</span><span class="ch">\n</span><span class="st">"</span></span>
<span id="cb15-5"><a href="#cb15-5" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-6"><a href="#cb15-6" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">3</span>):</span>
<span id="cb15-7"><a href="#cb15-7" aria-hidden="true" tabindex="-1"></a>        s <span class="op">+=</span> <span class="ss">f"Step </span><span class="sc">{</span>i<span class="op">+</span><span class="dv">1</span><span class="sc">}</span><span class="ss">: "</span></span>
<span id="cb15-8"><a href="#cb15-8" aria-hidden="true" tabindex="-1"></a>        s <span class="op">+=</span> sgl.gen(<span class="ss">f"step_</span><span class="sc">{</span>i<span class="op">+</span><span class="dv">1</span><span class="sc">}</span><span class="ss">"</span>, max_tokens<span class="op">=</span><span class="dv">100</span>, stop<span class="op">=</span><span class="st">"</span><span class="ch">\n</span><span class="st">"</span>)</span>
<span id="cb15-9"><a href="#cb15-9" aria-hidden="true" tabindex="-1"></a>        s <span class="op">+=</span> <span class="st">"</span><span class="ch">\n</span><span class="st">"</span></span>
<span id="cb15-10"><a href="#cb15-10" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-11"><a href="#cb15-11" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> <span class="st">"Final Answer: "</span></span>
<span id="cb15-12"><a href="#cb15-12" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> sgl.gen(<span class="st">"answer"</span>, max_tokens<span class="op">=</span><span class="dv">50</span>)</span></code></pre></div></div>
</div>
</section>
<section id="self-correction-loop" class="level3">
<h3 class="anchored" data-anchor-id="self-correction-loop" id="self-correction-loop">2. Self-Correction Loop</h3>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>self_correction.py</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb16" data-filename="self_correction.py"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb16-1"><a href="#cb16-1" aria-hidden="true" tabindex="-1"></a><span class="at">@sgl.function</span></span>
<span id="cb16-2"><a href="#cb16-2" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> self_correct(s, task, max_iterations<span class="op">=</span><span class="dv">3</span>):</span>
<span id="cb16-3"><a href="#cb16-3" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> <span class="ss">f"Task: </span><span class="sc">{</span>task<span class="sc">}</span><span class="ch">\n</span><span class="ss">"</span></span>
<span id="cb16-4"><a href="#cb16-4" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-5"><a href="#cb16-5" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(max_iterations):</span>
<span id="cb16-6"><a href="#cb16-6" aria-hidden="true" tabindex="-1"></a>        s <span class="op">+=</span> <span class="ss">f"Attempt </span><span class="sc">{</span>i<span class="op">+</span><span class="dv">1</span><span class="sc">}</span><span class="ss">: "</span></span>
<span id="cb16-7"><a href="#cb16-7" aria-hidden="true" tabindex="-1"></a>        s <span class="op">+=</span> sgl.gen(<span class="ss">f"attempt_</span><span class="sc">{</span>i<span class="op">+</span><span class="dv">1</span><span class="sc">}</span><span class="ss">"</span>, max_tokens<span class="op">=</span><span class="dv">200</span>)</span>
<span id="cb16-8"><a href="#cb16-8" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb16-9"><a href="#cb16-9" aria-hidden="true" tabindex="-1"></a>        s <span class="op">+=</span> <span class="st">"</span><span class="ch">\n</span><span class="st">Is this correct? "</span></span>
<span id="cb16-10"><a href="#cb16-10" aria-hidden="true" tabindex="-1"></a>        s <span class="op">+=</span> sgl.gen(<span class="st">"correct"</span>, choices<span class="op">=</span>[<span class="st">"Yes"</span>, <span class="st">"No"</span>])</span>
<span id="cb16-11"><a href="#cb16-11" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb16-12"><a href="#cb16-12" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> s[<span class="st">"correct"</span>] <span class="op">==</span> <span class="st">"Yes"</span>:</span>
<span id="cb16-13"><a href="#cb16-13" aria-hidden="true" tabindex="-1"></a>            <span class="cf">break</span></span>
<span id="cb16-14"><a href="#cb16-14" aria-hidden="true" tabindex="-1"></a>        <span class="cf">else</span>:</span>
<span id="cb16-15"><a href="#cb16-15" aria-hidden="true" tabindex="-1"></a>            s <span class="op">+=</span> <span class="st">"</span><span class="ch">\n</span><span class="st">Let me try again.</span><span class="ch">\n</span><span class="st">"</span></span></code></pre></div></div>
</div>
</section>
<section id="tree-search-generation" class="level3">
<h3 class="anchored" data-anchor-id="tree-search-generation" id="tree-search-generation">3. Tree Search Generation</h3>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>tree_search.py</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb17" data-filename="tree_search.py"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb17-1"><a href="#cb17-1" aria-hidden="true" tabindex="-1"></a><span class="at">@sgl.function</span></span>
<span id="cb17-2"><a href="#cb17-2" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> tree_search_story(s, prompt, branches<span class="op">=</span><span class="dv">3</span>, depth<span class="op">=</span><span class="dv">2</span>):</span>
<span id="cb17-3"><a href="#cb17-3" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> prompt</span>
<span id="cb17-4"><a href="#cb17-4" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb17-5"><a href="#cb17-5" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(depth):</span>
<span id="cb17-6"><a href="#cb17-6" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Fork independent branches that share the story so far as a cached prefix</span></span>
<span id="cb17-7"><a href="#cb17-7" aria-hidden="true" tabindex="-1"></a>        forks <span class="op">=</span> s.fork(branches)</span>
<span id="cb17-8"><a href="#cb17-8" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> fork <span class="kw">in</span> forks:</span>
<span id="cb17-9"><a href="#cb17-9" aria-hidden="true" tabindex="-1"></a>            fork <span class="op">+=</span> sgl.gen(<span class="st">"continuation"</span>, max_tokens<span class="op">=</span><span class="dv">50</span>, temperature<span class="op">=</span><span class="fl">0.8</span>)</span>
<span id="cb17-10"><a href="#cb17-10" aria-hidden="true" tabindex="-1"></a>        candidates <span class="op">=</span> [fork[<span class="st">"continuation"</span>] <span class="cf">for</span> fork <span class="kw">in</span> forks]</span>
<span id="cb17-11"><a href="#cb17-11" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb17-12"><a href="#cb17-12" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Select best candidate (simplified selection)</span></span>
<span id="cb17-13"><a href="#cb17-13" aria-hidden="true" tabindex="-1"></a>        best_idx <span class="op">=</span> <span class="dv">0</span>  <span class="co"># In practice, use a scoring function</span></span>
<span id="cb17-14"><a href="#cb17-14" aria-hidden="true" tabindex="-1"></a>        s <span class="op">+=</span> candidates[best_idx]</span></code></pre></div></div>
</div>
</section>
<section id="parallel-agent-collaboration" class="level3">
<h3 class="anchored" data-anchor-id="parallel-agent-collaboration" id="parallel-agent-collaboration">4. Parallel Agent Collaboration</h3>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>multi_agent.py</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb18" data-filename="multi_agent.py"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb18-1"><a href="#cb18-1" aria-hidden="true" tabindex="-1"></a><span class="at">@sgl.function</span></span>
<span id="cb18-2"><a href="#cb18-2" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> multi_agent_discussion(s, topic, agents, num_rounds<span class="op">=</span><span class="dv">3</span>):</span>
<span id="cb18-3"><a href="#cb18-3" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> <span class="ss">f"Topic: </span><span class="sc">{</span>topic<span class="sc">}</span><span class="ch">\n</span><span class="ss">"</span></span>
<span id="cb18-4"><a href="#cb18-4" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> <span class="st">"Discussion:</span><span class="ch">\n</span><span class="st">"</span></span>
<span id="cb18-5"><a href="#cb18-5" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-6"><a href="#cb18-6" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> round_idx <span class="kw">in</span> <span class="bu">range</span>(num_rounds):</span>
<span id="cb18-7"><a href="#cb18-7" aria-hidden="true" tabindex="-1"></a>        s <span class="op">+=</span> <span class="ss">f"</span><span class="ch">\n</span><span class="ss">Round </span><span class="sc">{</span>round_idx <span class="op">+</span> <span class="dv">1</span><span class="sc">}</span><span class="ss">:</span><span class="ch">\n</span><span class="ss">"</span></span>
<span id="cb18-8"><a href="#cb18-8" aria-hidden="true" tabindex="-1"></a>        <span class="co"># One fork per agent: every agent sees the same transcript and replies in parallel</span></span>
<span id="cb18-9"><a href="#cb18-9" aria-hidden="true" tabindex="-1"></a>        forks <span class="op">=</span> s.fork(<span class="bu">len</span>(agents))</span>
<span id="cb18-10"><a href="#cb18-10" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> agent, fork <span class="kw">in</span> <span class="bu">zip</span>(agents, forks):</span>
<span id="cb18-11"><a href="#cb18-11" aria-hidden="true" tabindex="-1"></a>            fork <span class="op">+=</span> <span class="ss">f"</span><span class="sc">{</span>agent<span class="sc">}</span><span class="ss">: "</span></span>
<span id="cb18-12"><a href="#cb18-12" aria-hidden="true" tabindex="-1"></a>            fork <span class="op">+=</span> sgl.gen(<span class="st">"reply"</span>, max_tokens<span class="op">=</span><span class="dv">100</span>, stop<span class="op">=</span><span class="st">"</span><span class="ch">\n</span><span class="st">"</span>)</span>
<span id="cb18-13"><a href="#cb18-13" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Append this round's replies to the shared transcript for the next round</span></span>
<span id="cb18-14"><a href="#cb18-14" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> agent, fork <span class="kw">in</span> <span class="bu">zip</span>(agents, forks):</span>
<span id="cb18-15"><a href="#cb18-15" aria-hidden="true" tabindex="-1"></a>            s <span class="op">+=</span> <span class="ss">f"</span><span class="sc">{</span>agent<span class="sc">}</span><span class="ss">: </span><span class="sc">{</span>fork[<span class="st">'reply'</span>]<span class="sc">}</span><span class="ch">\n</span><span class="ss">"</span></span></code></pre></div></div>
</div>
</section>
</section>
<section id="sec-performance" class="level2">
<h2 class="anchored" data-anchor-id="sec-performance" id="sec-performance">Performance Optimization</h2>
<section id="batch-processing" class="level3">
<h3 class="anchored" data-anchor-id="batch-processing" id="batch-processing">1. Batch Processing</h3>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Optimization Strategy
</div>
</div>
<div class="callout-body-container callout-body">
<p>Submit many independent inputs at once so the runtime can batch them on the server and keep the GPU busy.</p>
</div>
</div>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>batch_processing.py</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb19" data-filename="batch_processing.py"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb19-1"><a href="#cb19-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Define the program for a single input; batching happens when inputs are submitted</span></span>
<span id="cb19-2"><a href="#cb19-2" aria-hidden="true" tabindex="-1"></a><span class="at">@sgl.function</span></span>
<span id="cb19-3"><a href="#cb19-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> classify(s, text):</span>
<span id="cb19-4"><a href="#cb19-4" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> <span class="ss">f"Classify the sentiment of: </span><span class="sc">{</span>text<span class="sc">}</span><span class="ch">\n</span><span class="ss">Category: "</span></span>
<span id="cb19-5"><a href="#cb19-5" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> sgl.gen(<span class="st">"category"</span>, choices<span class="op">=</span>[<span class="st">"positive"</span>, <span class="st">"negative"</span>, <span class="st">"neutral"</span>])</span>
<span id="cb19-6"><a href="#cb19-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-7"><a href="#cb19-7" aria-hidden="true" tabindex="-1"></a><span class="co"># The server batches concurrent requests itself; no client-side batch size is needed</span></span>
<span id="cb19-8"><a href="#cb19-8" aria-hidden="true" tabindex="-1"></a>sgl.set_default_backend(sgl.RuntimeEndpoint(<span class="st">"http://localhost:30000"</span>))</span>
<span id="cb19-9"><a href="#cb19-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-10"><a href="#cb19-10" aria-hidden="true" tabindex="-1"></a><span class="co"># run_batch sends one request per kwargs dict and returns one state per input</span></span>
<span id="cb19-11"><a href="#cb19-11" aria-hidden="true" tabindex="-1"></a>texts <span class="op">=</span> [<span class="st">"I love this product!"</span>, <span class="st">"The delivery was late."</span>, <span class="st">"It arrived on Tuesday."</span>]</span>
<span id="cb19-12"><a href="#cb19-12" aria-hidden="true" tabindex="-1"></a>states <span class="op">=</span> classify.run_batch([{<span class="st">"text"</span>: t} <span class="cf">for</span> t <span class="kw">in</span> texts])</span>
<span id="cb19-13"><a href="#cb19-13" aria-hidden="true" tabindex="-1"></a>results <span class="op">=</span> [state[<span class="st">"category"</span>] <span class="cf">for</span> state <span class="kw">in</span> states]</span></code></pre></div></div>
</div>
</section>
<section id="caching-strategies" class="level3">
<h3 class="anchored" data-anchor-id="caching-strategies" id="caching-strategies">2. Caching Strategies</h3>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>caching.py</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb20" data-filename="caching.py"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb20-1"><a href="#cb20-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Prefix caching (RadixAttention) is automatic; order prompts so repeated text comes first</span></span>
<span id="cb20-2"><a href="#cb20-2" aria-hidden="true" tabindex="-1"></a><span class="at">@sgl.function</span></span>
<span id="cb20-3"><a href="#cb20-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> cached_qa(s, question, context):</span>
<span id="cb20-4"><a href="#cb20-4" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Shared context first, varying question last: requests over the same context share a cached prefix</span></span>
<span id="cb20-5"><a href="#cb20-5" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> <span class="ss">f"Context: </span><span class="sc">{</span>context<span class="sc">}</span><span class="ch">\n</span><span class="ss">"</span></span>
<span id="cb20-6"><a href="#cb20-6" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> <span class="ss">f"Question: </span><span class="sc">{</span>question<span class="sc">}</span><span class="ch">\n</span><span class="ss">"</span></span>
<span id="cb20-7"><a href="#cb20-7" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> <span class="st">"Answer: "</span></span>
<span id="cb20-8"><a href="#cb20-8" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> sgl.gen(<span class="st">"answer"</span>, max_tokens<span class="op">=</span><span class="dv">100</span>, temperature<span class="op">=</span><span class="fl">0.0</span>)  <span class="co"># Greedy decoding for reproducible answers; prefix reuse does not depend on temperature</span></span></code></pre></div></div>
</div>
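<p>Why put the shared context before the question? A prefix cache can only reuse work for the part of a new prompt that matches an earlier prompt exactly from the very start. The toy sketch below makes that concrete in plain Python; it is not SGLang code, <code>reusable_prefix</code> is a hypothetical helper, and characters stand in for tokens.</p>

```python
import os


def reusable_prefix(prompt: str, cached_prompts: list) -> int:
    """Length of the longest prefix `prompt` shares with any earlier prompt."""
    return max((len(os.path.commonprefix([prompt, c])) for c in cached_prompts), default=0)


context = "Context: " + "SGLang is a serving framework for LLMs. " * 5 + "\n"
questions = ["What is SGLang?", "Who uses SGLang?"]

# Context first (as in cached_qa): both prompts start with the same long context
context_first = [f"{context}Question: {q}\nAnswer: " for q in questions]
# Question first: the prompts diverge after a few characters
question_first = [f"Question: {q}\n{context}Answer: " for q in questions]

print(reusable_prefix(context_first[1], context_first[:1]))    # 222
print(reusable_prefix(question_first[1], question_first[:1]))  # 12
```

<p>With the context first, the second request shares everything up to the question; with the question first, almost nothing is reusable.</p>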
</section>
<section id="memory-management" class="level3">
<h3 class="anchored" data-anchor-id="memory-management" id="memory-management">3. Memory Management</h3>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>memory_management.py</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb21" data-filename="memory_management.py"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb21-1"><a href="#cb21-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Bound prompt size for long conversations (a shorter prompt needs less KV cache)</span></span>
<span id="cb21-2"><a href="#cb21-2" aria-hidden="true" tabindex="-1"></a><span class="at">@sgl.function</span></span>
<span id="cb21-3"><a href="#cb21-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> efficient_chat(s, messages, max_context_length<span class="op">=</span><span class="dv">2000</span>):</span>
<span id="cb21-4"><a href="#cb21-4" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Keep the newest messages whose combined length fits the budget (always at least one).</span></span>
<span id="cb21-5"><a href="#cb21-5" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Length is counted in characters here, only a rough proxy for tokens.</span></span>
<span id="cb21-6"><a href="#cb21-6" aria-hidden="true" tabindex="-1"></a>    kept, total <span class="op">=</span> [], <span class="dv">0</span></span>
<span id="cb21-7"><a href="#cb21-7" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> msg <span class="kw">in</span> <span class="bu">reversed</span>(messages):</span>
<span id="cb21-8"><a href="#cb21-8" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> kept <span class="kw">and</span> total <span class="op">+</span> <span class="bu">len</span>(msg) <span class="op">&gt;</span> max_context_length:</span>
<span id="cb21-9"><a href="#cb21-9" aria-hidden="true" tabindex="-1"></a>            <span class="cf">break</span></span>
<span id="cb21-10"><a href="#cb21-10" aria-hidden="true" tabindex="-1"></a>        kept.append(msg)</span>
<span id="cb21-11"><a href="#cb21-11" aria-hidden="true" tabindex="-1"></a>        total <span class="op">+=</span> <span class="bu">len</span>(msg)</span>
<span id="cb21-12"><a href="#cb21-12" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb21-13"><a href="#cb21-13" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> msg <span class="kw">in</span> <span class="bu">reversed</span>(kept):  <span class="co"># oldest kept message first</span></span>
<span id="cb21-14"><a href="#cb21-14" aria-hidden="true" tabindex="-1"></a>        s <span class="op">+=</span> sgl.user(msg)</span>
<span id="cb21-15"><a href="#cb21-15" aria-hidden="true" tabindex="-1"></a>        s <span class="op">+=</span> sgl.assistant(sgl.gen(<span class="st">"response"</span>, max_tokens<span class="op">=</span><span class="dv">150</span>))</span></code></pre></div></div>
</div>
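<p>Truncation helps because every token kept in the conversation occupies KV-cache memory on the GPU while it is being served. A back-of-the-envelope estimate per token is 2 (key and value) × layers × KV heads × head dimension × bytes per value. The sketch below plugs in Llama-2-7B's shape (32 layers, 32 KV heads, head dimension 128) with fp16 values; <code>kv_cache_bytes</code> is a hypothetical helper, and real servers add paging and allocator overhead on top.</p>

```python
def kv_cache_bytes(num_tokens: int, num_layers: int, num_kv_heads: int,
                   head_dim: int, bytes_per_value: int = 2) -> int:
    """Approximate KV-cache size: a key and a value vector per layer, KV head and token."""
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_value * num_tokens


# Llama-2-7B shape: 32 layers, 32 KV heads, head dimension 128; fp16 = 2 bytes
per_token = kv_cache_bytes(1, num_layers=32, num_kv_heads=32, head_dim=128)
print(per_token)                                  # 524288 (0.5 MiB per token)
print(kv_cache_bytes(2000, 32, 32, 128) / 2**20)  # 1000.0 (MiB for 2,000 tokens)
```

<p>At roughly half a mebibyte per token, a single 2,000-token conversation already pins about 1 GiB of GPU memory, which is why long histories are worth trimming.</p>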
</section>
</section>
<section id="sec-vision-language" class="level2">
<h2 class="anchored" data-anchor-id="sec-vision-language" id="sec-vision-language">Vision-Language Model Support</h2>
<section id="image-understanding" class="level3">
<h3 class="anchored" data-anchor-id="image-understanding" id="image-understanding">1. Image Understanding</h3>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>image_description.py</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb22" data-filename="image_description.py"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb22-1"><a href="#cb22-1" aria-hidden="true" tabindex="-1"></a><span class="at">@sgl.function</span></span>
<span id="cb22-2"><a href="#cb22-2" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> describe_image(s, image_path, detail_level<span class="op">=</span><span class="st">"medium"</span>):</span>
<span id="cb22-3"><a href="#cb22-3" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> sgl.image(image_path)</span>
<span id="cb22-4"><a href="#cb22-4" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> <span class="ss">f"Describe this image in </span><span class="sc">{</span>detail_level<span class="sc">}</span><span class="ss"> detail:</span><span class="ch">\n</span><span class="ss">"</span></span>
<span id="cb22-5"><a href="#cb22-5" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> sgl.gen(<span class="st">"description"</span>, max_tokens<span class="op">=</span><span class="dv">300</span>)</span>
<span id="cb22-6"><a href="#cb22-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb22-7"><a href="#cb22-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage</span></span>
<span id="cb22-8"><a href="#cb22-8" aria-hidden="true" tabindex="-1"></a>result <span class="op">=</span> describe_image(<span class="st">"/path/to/image.jpg"</span>, <span class="st">"high"</span>)</span></code></pre></div></div>
</div>
</section>
<section id="visual-question-answering" class="level3">
<h3 class="anchored" data-anchor-id="visual-question-answering" id="visual-question-answering">2. Visual Question Answering</h3>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>visual_qa.py</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb23" data-filename="visual_qa.py"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb23-1"><a href="#cb23-1" aria-hidden="true" tabindex="-1"></a><span class="at">@sgl.function</span></span>
<span id="cb23-2"><a href="#cb23-2" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> visual_qa(s, image_path, question):</span>
<span id="cb23-3"><a href="#cb23-3" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> sgl.image(image_path)</span>
<span id="cb23-4"><a href="#cb23-4" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> <span class="ss">f"Question: </span><span class="sc">{</span>question<span class="sc">}</span><span class="ch">\n</span><span class="ss">"</span></span>
<span id="cb23-5"><a href="#cb23-5" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> <span class="st">"Answer: "</span></span>
<span id="cb23-6"><a href="#cb23-6" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> sgl.gen(<span class="st">"answer"</span>, max_tokens<span class="op">=</span><span class="dv">150</span>)</span>
<span id="cb23-7"><a href="#cb23-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb23-8"><a href="#cb23-8" aria-hidden="true" tabindex="-1"></a>result <span class="op">=</span> visual_qa(<span class="st">"/path/to/chart.png"</span>, <span class="st">"What is the highest value in this chart?"</span>)</span></code></pre></div></div>
</div>
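<p>The same server can also be queried without the frontend language: SGLang exposes an OpenAI-compatible <code>/v1/chat/completions</code> route, and vision-language models accept images as base64 data URLs inside the message content. The sketch below only builds the request body; <code>build_vqa_request</code> is a hypothetical helper, and the <code>"default"</code> model name and <code>localhost:30000</code> address are assumptions to adjust for your deployment.</p>

```python
import base64
import json


def build_vqa_request(image_bytes: bytes, question: str, model: str = "default",
                      mime_type: str = "image/png", max_tokens: int = 150) -> dict:
    """OpenAI-style chat request with the image inlined as a base64 data URL."""
    encoded = base64.b64encode(image_bytes).decode("ascii")
    return {
        "model": model,  # assumption: replace with the model name your server reports
        "messages": [{
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{encoded}"}},
                {"type": "text", "text": question},
            ],
        }],
        "max_tokens": max_tokens,
    }


# Placeholder bytes; in practice: image_bytes = open("/path/to/chart.png", "rb").read()
image_bytes = b"\x89PNG\r\n\x1a\n"
payload = build_vqa_request(image_bytes, "What is the highest value in this chart?")
body = json.dumps(payload)  # POST to http://localhost:30000/v1/chat/completions (assumed address)
```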
</section>
<section id="multi-modal-reasoning" class="level3">
<h3 class="anchored" data-anchor-id="multi-modal-reasoning" id="multi-modal-reasoning">3. Multi-modal Reasoning</h3>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>multimodal_analysis.py</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb24" data-filename="multimodal_analysis.py"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb24-1"><a href="#cb24-1" aria-hidden="true" tabindex="-1"></a><span class="at">@sgl.function</span></span>
<span id="cb24-2"><a href="#cb24-2" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> multimodal_analysis(s, image_path, context):</span>
<span id="cb24-3"><a href="#cb24-3" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> <span class="ss">f"Context: </span><span class="sc">{</span>context<span class="sc">}</span><span class="ch">\n</span><span class="ss">"</span></span>
<span id="cb24-4"><a href="#cb24-4" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> sgl.image(image_path)</span>
<span id="cb24-5"><a href="#cb24-5" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> <span class="st">"Based on the context and image, analyze:</span><span class="ch">\n</span><span class="st">"</span></span>
<span id="cb24-6"><a href="#cb24-6" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> <span class="st">"1. Visual elements: "</span></span>
<span id="cb24-7"><a href="#cb24-7" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> sgl.gen(<span class="st">"visual"</span>, max_tokens<span class="op">=</span><span class="dv">100</span>, stop<span class="op">=</span><span class="st">"</span><span class="ch">\n</span><span class="st">"</span>)</span>
<span id="cb24-8"><a href="#cb24-8" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> <span class="st">"</span><span class="ch">\n</span><span class="st">2. Relationship to context: "</span></span>
<span id="cb24-9"><a href="#cb24-9" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> sgl.gen(<span class="st">"relationship"</span>, max_tokens<span class="op">=</span><span class="dv">100</span>, stop<span class="op">=</span><span class="st">"</span><span class="ch">\n</span><span class="st">"</span>)</span>
<span id="cb24-10"><a href="#cb24-10" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> <span class="st">"</span><span class="ch">\n</span><span class="st">3. Conclusion: "</span></span>
<span id="cb24-11"><a href="#cb24-11" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> sgl.gen(<span class="st">"conclusion"</span>, max_tokens<span class="op">=</span><span class="dv">100</span>)</span></code></pre></div></div>
</div>
</section>
</section>
<section id="sec-deployment" class="level2">
<h2 class="anchored" data-anchor-id="sec-deployment" id="sec-deployment">Deployment and Serving</h2>
<section id="starting-a-server" class="level3">
<h3 class="anchored" data-anchor-id="starting-a-server" id="starting-a-server">1. Starting a Server</h3>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>start_server.sh</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb25" data-filename="start_server.sh"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb25-1"><a href="#cb25-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Basic server startup</span></span>
<span id="cb25-2"><a href="#cb25-2" aria-hidden="true" tabindex="-1"></a><span class="ex">python</span> <span class="at">-m</span> sglang.launch_server <span class="at">--model-path</span> meta-llama/Llama-2-7b-chat-hf <span class="at">--port</span> 30000</span>
<span id="cb25-3"><a href="#cb25-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb25-4"><a href="#cb25-4" aria-hidden="true" tabindex="-1"></a><span class="co"># With specific configurations</span></span>
<span id="cb25-5"><a href="#cb25-5" aria-hidden="true" tabindex="-1"></a><span class="ex">python</span> <span class="at">-m</span> sglang.launch_server <span class="dt">\</span></span>
<span id="cb25-6"><a href="#cb25-6" aria-hidden="true" tabindex="-1"></a>    <span class="at">--model-path</span> meta-llama/Llama-2-7b-chat-hf <span class="dt">\</span></span>
<span id="cb25-7"><a href="#cb25-7" aria-hidden="true" tabindex="-1"></a>    <span class="at">--port</span> 30000 <span class="dt">\</span></span>
<span id="cb25-8"><a href="#cb25-8" aria-hidden="true" tabindex="-1"></a>    <span class="at">--host</span> 0.0.0.0 <span class="dt">\</span></span>
<span id="cb25-9"><a href="#cb25-9" aria-hidden="true" tabindex="-1"></a>    <span class="at">--tp-size</span> 2 <span class="dt">\</span></span>
<span id="cb25-10"><a href="#cb25-10" aria-hidden="true" tabindex="-1"></a>    <span class="at">--mem-fraction-static</span> 0.8</span></code></pre></div></div>
</div>
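<p>Loading model weights can take minutes, so scripts that start the server usually wait for it before sending traffic. A minimal polling sketch using only the standard library, assuming the server answers <code>GET /health</code> with status 200 once it is ready (<code>wait_for_server</code> is a hypothetical helper):</p>

```python
import time
import urllib.request


def wait_for_server(base_url: str, timeout: float = 300.0, interval: float = 2.0) -> bool:
    """Poll <base_url>/health until it returns HTTP 200 or `timeout` seconds pass."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            with urllib.request.urlopen(f"{base_url}/health", timeout=interval) as resp:
                if resp.status == 200:
                    return True
        except OSError:
            # Connection refused/reset, timeouts and urllib's URLError/HTTPError
            # are all OSError subclasses: the server is not ready yet, keep waiting
            pass
        time.sleep(interval)
    return False
```

<p>For example, <code>wait_for_server("http://localhost:30000")</code> returns <code>True</code> once the server launched above is up, or <code>False</code> after five minutes.</p>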
</section>
<section id="client-configuration" class="level3">
<h3 class="anchored" data-anchor-id="client-configuration" id="client-configuration">2. Client Configuration</h3>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>client_setup.py</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb26" data-filename="client_setup.py"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb26-1"><a href="#cb26-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> sglang <span class="im">as</span> sgl</span>
<span id="cb26-2"><a href="#cb26-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb26-3"><a href="#cb26-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Connect to local server</span></span>
<span id="cb26-4"><a href="#cb26-4" aria-hidden="true" tabindex="-1"></a>backend <span class="op">=</span> sgl.RuntimeEndpoint(<span class="st">"http://localhost:30000"</span>)</span>
<span id="cb26-5"><a href="#cb26-5" aria-hidden="true" tabindex="-1"></a>sgl.set_default_backend(backend)</span>
<span id="cb26-6"><a href="#cb26-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb26-7"><a href="#cb26-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Connect to a remote server launched with --api-key (sent as a Bearer token)</span></span>
<span id="cb26-8"><a href="#cb26-8" aria-hidden="true" tabindex="-1"></a>backend <span class="op">=</span> sgl.RuntimeEndpoint(</span>
<span id="cb26-9"><a href="#cb26-9" aria-hidden="true" tabindex="-1"></a>    <span class="st">"https://api.example.com"</span>,</span>
<span id="cb26-10"><a href="#cb26-10" aria-hidden="true" tabindex="-1"></a>    api_key<span class="op">=</span><span class="st">"your-token"</span>  <span class="co"># RuntimeEndpoint has no headers argument</span></span>
<span id="cb26-11"><a href="#cb26-11" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
</div>
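<p>Under the hood the frontend talks HTTP to the server, which also makes the server easy to call from other tools. The sketch below prepares a request to SGLang's native <code>/generate</code> route with the API key sent as a Bearer token; the payload fields (<code>text</code>, <code>sampling_params</code>, <code>max_new_tokens</code>) are assumed to match the native API, and <code>build_generate_request</code> is a hypothetical helper. The request is built but not sent.</p>

```python
import json
import urllib.request
from typing import Optional


def build_generate_request(base_url: str, prompt: str, api_key: Optional[str] = None,
                           max_new_tokens: int = 64) -> urllib.request.Request:
    """Prepare (but do not send) a POST to the server's native /generate route."""
    payload = {
        "text": prompt,
        "sampling_params": {"max_new_tokens": max_new_tokens, "temperature": 0.0},
    }
    headers = {"Content-Type": "application/json"}
    if api_key is not None:
        headers["Authorization"] = f"Bearer {api_key}"
    return urllib.request.Request(
        f"{base_url}/generate",
        data=json.dumps(payload).encode("utf-8"),
        headers=headers,
        method="POST",
    )


req = build_generate_request("http://localhost:30000", "The capital of France is",
                             api_key="your-token")
# To send it: json.loads(urllib.request.urlopen(req).read())["text"]
```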
</section>
<section id="load-balancing" class="level3">
<h3 class="anchored" data-anchor-id="load-balancing" id="load-balancing">3. Load Balancing</h3>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>load_balancing.py</strong></pre>
</div>
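<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb27" data-filename="load_balancing.py"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb27-1"><a href="#cb27-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> itertools</span>
<span id="cb27-2"><a href="#cb27-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> sglang <span class="im">as</span> sgl</span>
<span id="cb27-3"><a href="#cb27-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb27-4"><a href="#cb27-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Multiple endpoints for load distribution</span></span>
<span id="cb27-5"><a href="#cb27-5" aria-hidden="true" tabindex="-1"></a>endpoints <span class="op">=</span> [</span>
<span id="cb27-6"><a href="#cb27-6" aria-hidden="true" tabindex="-1"></a>    <span class="st">"http://server1:30000"</span>,</span>
<span id="cb27-7"><a href="#cb27-7" aria-hidden="true" tabindex="-1"></a>    <span class="st">"http://server2:30000"</span>,</span>
<span id="cb27-8"><a href="#cb27-8" aria-hidden="true" tabindex="-1"></a>    <span class="st">"http://server3:30000"</span>,</span>
<span id="cb27-9"><a href="#cb27-9" aria-hidden="true" tabindex="-1"></a>]</span>
<span id="cb27-10"><a href="#cb27-10" aria-hidden="true" tabindex="-1"></a><span class="co"># SGLang has no LoadBalancedEndpoint class; rotate over one RuntimeEndpoint per server</span></span>
<span id="cb27-11"><a href="#cb27-11" aria-hidden="true" tabindex="-1"></a>backends <span class="op">=</span> itertools.cycle([sgl.RuntimeEndpoint(url) <span class="cf">for</span> url <span class="kw">in</span> endpoints])</span>
<span id="cb27-12"><a href="#cb27-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb27-13"><a href="#cb27-13" aria-hidden="true" tabindex="-1"></a><span class="at">@sgl.function</span></span>
<span id="cb27-14"><a href="#cb27-14" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> qa(s, question):</span>
<span id="cb27-15"><a href="#cb27-15" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> question</span>
<span id="cb27-16"><a href="#cb27-16" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> sgl.gen(<span class="st">"answer"</span>, max_tokens<span class="op">=</span><span class="dv">64</span>)</span>
<span id="cb27-17"><a href="#cb27-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb27-18"><a href="#cb27-18" aria-hidden="true" tabindex="-1"></a>state <span class="op">=</span> qa.run(question<span class="op">=</span><span class="st">"What is SGLang?"</span>, backend<span class="op">=</span><span class="bu">next</span>(backends))  <span class="co"># next server per call</span></span></code></pre></div></div>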
</div>
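<p>Plain round robin spreads load evenly, but it also scatters requests that share a long prefix (a common system prompt, the same document) across servers, so each server rebuilds the same KV cache. A simplified illustration of cache-aware routing is to hash the start of the prompt and send matching prefixes to the same server. <code>pick_endpoint</code> is a hypothetical helper, not SGLang's router, and it ignores server load entirely.</p>

```python
import hashlib

endpoints = [
    "http://server1:30000",
    "http://server2:30000",
    "http://server3:30000",
]


def pick_endpoint(prompt: str, servers: list, prefix_chars: int = 200) -> str:
    """Choose a server from a stable hash of the prompt's first `prefix_chars` characters."""
    # sha256, unlike Python's built-in hash(), gives the same result in every process
    digest = hashlib.sha256(prompt[:prefix_chars].encode("utf-8")).digest()
    return servers[int.from_bytes(digest[:8], "big") % len(servers)]


document = "Context: " + "quarterly report text " * 20
print(pick_endpoint(document + "Question: What was revenue?", endpoints))
print(pick_endpoint(document + "Question: What were costs?", endpoints))  # same server
```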
</section>
<section id="production-deployment" class="level3">
<h3 class="anchored" data-anchor-id="production-deployment" id="production-deployment">4. Production Deployment</h3>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>docker-compose.yml</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb28" data-filename="docker-compose.yml"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb28-1"><a href="#cb28-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Docker Compose example</span></span>
<span id="cb28-2"><a href="#cb28-2" aria-hidden="true" tabindex="-1"></a><span class="fu">version</span><span class="kw">:</span><span class="at"> </span><span class="st">'3.8'</span></span>
<span id="cb28-3"><a href="#cb28-3" aria-hidden="true" tabindex="-1"></a><span class="fu">services</span><span class="kw">:</span></span>
<span id="cb28-4"><a href="#cb28-4" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">sglang-server</span><span class="kw">:</span></span>
<span id="cb28-5"><a href="#cb28-5" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">image</span><span class="kw">:</span><span class="at"> lmsysorg/sglang:latest</span></span>
<span id="cb28-6"><a href="#cb28-6" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">ports</span><span class="kw">:</span></span>
<span id="cb28-7"><a href="#cb28-7" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="st">"30000:30000"</span></span>
<span id="cb28-8"><a href="#cb28-8" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">volumes</span><span class="kw">:</span></span>
<span id="cb28-9"><a href="#cb28-9" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> ${HOME}/.cache/huggingface:/root/.cache/huggingface  </span><span class="co"># reuse downloaded weights</span></span>
<span id="cb28-10"><a href="#cb28-10" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">environment</span><span class="kw">:</span></span>
<span id="cb28-11"><a href="#cb28-11" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> HF_TOKEN=${HF_TOKEN}  </span><span class="co"># needed to download gated models such as Llama 2</span></span>
<span id="cb28-12"><a href="#cb28-12" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">ipc</span><span class="kw">:</span><span class="at"> host  </span><span class="co"># shared memory for multi-GPU (tensor parallel) communication</span></span>
<span id="cb28-13"><a href="#cb28-13" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="co"># Model, port and parallelism are launch_server flags, so pass them as the command</span></span>
<span id="cb28-14"><a href="#cb28-14" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">command</span><span class="kw">:</span><span class="at"> python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --host 0.0.0.0 --port 30000 --tp-size 2</span></span>
<span id="cb28-15"><a href="#cb28-15" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">deploy</span><span class="kw">:</span></span>
<span id="cb28-16"><a href="#cb28-16" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">resources</span><span class="kw">:</span></span>
<span id="cb28-17"><a href="#cb28-17" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">reservations</span><span class="kw">:</span></span>
<span id="cb28-18"><a href="#cb28-18" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="fu">devices</span><span class="kw">:</span></span>
<span id="cb28-19"><a href="#cb28-19" aria-hidden="true" tabindex="-1"></a><span class="at">            </span><span class="kw">-</span><span class="at"> </span><span class="fu">driver</span><span class="kw">:</span><span class="at"> nvidia</span></span>
<span id="cb28-20"><a href="#cb28-20" aria-hidden="true" tabindex="-1"></a><span class="at">              </span><span class="fu">count</span><span class="kw">:</span><span class="at"> </span><span class="dv">2</span></span>
<span id="cb28-21"><a href="#cb28-21" aria-hidden="true" tabindex="-1"></a><span class="at">              </span><span class="fu">capabilities</span><span class="kw">:</span><span class="at"> </span><span class="kw">[</span><span class="at">gpu</span><span class="kw">]</span></span></code></pre></div></div>
</div>
</section>
</section>
<section id="sec-best-practices" class="level2">
<h2 class="anchored" data-anchor-id="sec-best-practices" id="sec-best-practices">Best Practices</h2>
<section id="prompt-engineering" class="level3">
<h3 class="anchored" data-anchor-id="prompt-engineering" id="prompt-engineering">1. Prompt Engineering</h3>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>prompt_engineering.py</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb29" data-filename="prompt_engineering.py"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb29-1"><a href="#cb29-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Use clear, structured prompts</span></span>
<span id="cb29-2"><a href="#cb29-2" aria-hidden="true" tabindex="-1"></a><span class="at">@sgl.function</span></span>
<span id="cb29-3"><a href="#cb29-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> good_prompt(s, task, examples, query):</span>
<span id="cb29-4"><a href="#cb29-4" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> <span class="st">"Task: "</span> <span class="op">+</span> task <span class="op">+</span> <span class="st">"</span><span class="ch">\n\n</span><span class="st">"</span></span>
<span id="cb29-5"><a href="#cb29-5" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb29-6"><a href="#cb29-6" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Provide examples</span></span>
<span id="cb29-7"><a href="#cb29-7" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> i, example <span class="kw">in</span> <span class="bu">enumerate</span>(examples):</span>
<span id="cb29-8"><a href="#cb29-8" aria-hidden="true" tabindex="-1"></a>        s <span class="op">+=</span> <span class="ss">f"Example </span><span class="sc">{</span>i<span class="op">+</span><span class="dv">1</span><span class="sc">}</span><span class="ss">:</span><span class="ch">\n</span><span class="ss">"</span></span>
<span id="cb29-9"><a href="#cb29-9" aria-hidden="true" tabindex="-1"></a>        s <span class="op">+=</span> <span class="ss">f"Input: </span><span class="sc">{</span>example[<span class="st">'input'</span>]<span class="sc">}</span><span class="ch">\n</span><span class="ss">"</span></span>
<span id="cb29-10"><a href="#cb29-10" aria-hidden="true" tabindex="-1"></a>        s <span class="op">+=</span> <span class="ss">f"Output: </span><span class="sc">{</span>example[<span class="st">'output'</span>]<span class="sc">}</span><span class="ch">\n\n</span><span class="ss">"</span></span>
<span id="cb29-11"><a href="#cb29-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb29-12"><a href="#cb29-12" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> <span class="st">"Now, complete this task:</span><span class="ch">\n</span><span class="st">"</span></span>
<span id="cb29-13"><a href="#cb29-13" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> <span class="ss">f"Input: </span><span class="sc">{</span>query<span class="sc">}</span><span class="ch">\n</span><span class="ss">"</span>  <span class="co"># supplied by the caller, not generated</span></span>
<span id="cb29-14"><a href="#cb29-14" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> <span class="st">"Output: "</span> <span class="op">+</span> sgl.gen(<span class="st">"output"</span>, max_tokens<span class="op">=</span><span class="dv">200</span>, stop<span class="op">=</span><span class="st">"</span><span class="ch">\n\n</span><span class="st">"</span>)</span></code></pre></div></div>
</div>
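<p>Keeping the template's fixed text identical across calls also keeps its shared prefix cacheable. To inspect or unit-test the exact string such a template produces without calling a model, the same layout can be rebuilt as a plain function; <code>build_few_shot_prompt</code> and its <code>query</code> parameter are hypothetical names used for illustration.</p>

```python
def build_few_shot_prompt(task: str, examples: list, query: str) -> str:
    """Rebuild the few-shot layout as a plain string (no model call)."""
    parts = [f"Task: {task}\n\n"]
    for i, example in enumerate(examples, start=1):
        parts.append(f"Example {i}:\n")
        parts.append(f"Input: {example['input']}\n")
        parts.append(f"Output: {example['output']}\n\n")
    parts.append("Now, complete this task:\n")
    parts.append(f"Input: {query}\nOutput: ")
    return "".join(parts)


examples = [
    {"input": "hello", "output": "bonjour"},
    {"input": "goodbye", "output": "au revoir"},
]
prompt = build_few_shot_prompt("Translate English to French", examples, "thank you")
print(prompt)
```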
</section>
<section id="error-handling" class="level3">
<h3 class="anchored" data-anchor-id="error-handling" id="error-handling">2. Error Handling</h3>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>error_handling.py</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb30" data-filename="error_handling.py"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb30-1"><a href="#cb30-1" aria-hidden="true" tabindex="-1"></a><span class="at">@sgl.function</span></span>
<span id="cb30-2"><a href="#cb30-2" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> robust_generation(s, prompt):</span>
<span id="cb30-3"><a href="#cb30-3" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> prompt</span>
<span id="cb30-4"><a href="#cb30-4" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> sgl.gen(<span class="st">"response"</span>, max_tokens<span class="op">=</span><span class="dv">100</span>)  <span class="co"># sgl.gen takes no timeout argument</span></span>
<span id="cb30-5"><a href="#cb30-5" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb30-6"><a href="#cb30-6" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Reading a variable waits for that generation, so the program can branch on it</span></span>
<span id="cb30-7"><a href="#cb30-7" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> <span class="bu">len</span>(s[<span class="st">"response"</span>].strip()) <span class="op">==</span> <span class="dv">0</span>:</span>
<span id="cb30-8"><a href="#cb30-8" aria-hidden="true" tabindex="-1"></a>        s <span class="op">+=</span> <span class="st">"</span><span class="ch">\n</span><span class="st">Please provide a more detailed response: "</span></span>
<span id="cb30-9"><a href="#cb30-9" aria-hidden="true" tabindex="-1"></a>        s <span class="op">+=</span> sgl.gen(<span class="st">"retry"</span>, max_tokens<span class="op">=</span><span class="dv">150</span>)</span>
<span id="cb30-10"><a href="#cb30-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb30-11"><a href="#cb30-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb30-12"><a href="#cb30-12" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> safe_generate(prompt):</span>
<span id="cb30-13"><a href="#cb30-13" aria-hidden="true" tabindex="-1"></a>    <span class="co"># SGLang defines no GenerationError class; failures such as an unreachable</span></span>
<span id="cb30-14"><a href="#cb30-14" aria-hidden="true" tabindex="-1"></a>    <span class="co"># server surface as ordinary exceptions where the program is run</span></span>
<span id="cb30-15"><a href="#cb30-15" aria-hidden="true" tabindex="-1"></a>    <span class="cf">try</span>:</span>
<span id="cb30-16"><a href="#cb30-16" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> robust_generation.run(prompt<span class="op">=</span>prompt)[<span class="st">"response"</span>]</span>
<span id="cb30-17"><a href="#cb30-17" aria-hidden="true" tabindex="-1"></a>    <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:</span>
<span id="cb30-18"><a href="#cb30-18" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="ss">f"I apologize, but I cannot process this request. (</span><span class="sc">{</span>e<span class="sc">}</span><span class="ss">)"</span></span></code></pre></div></div>
</div>
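<p>Transient failures such as a restarting server or a dropped connection are often worth retrying before falling back. A generic retry-with-exponential-backoff wrapper, written as a hypothetical helper rather than an SGLang feature, might look like this:</p>

```python
import random
import time


def with_retries(fn, attempts: int = 3, base_delay: float = 0.5, retry_on=(Exception,)):
    """Call fn(); after a failure wait base_delay * 2**n plus jitter, then retry."""
    for attempt in range(attempts):
        try:
            return fn()
        except retry_on:
            if attempt == attempts - 1:
                raise  # out of attempts: let the caller handle the last error
            time.sleep(base_delay * 2 ** attempt + random.uniform(0, base_delay))
```

<p>It can wrap any call, for example <code>with_retries(lambda: robust_generation.run(prompt="..."))</code>. Narrowing <code>retry_on</code> to connection errors avoids retrying genuine bugs.</p>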
</section>
<section id="testing-strategies" class="level3">
<h3 class="anchored" data-anchor-id="testing-strategies" id="testing-strategies">3. Testing Strategies</h3>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>testing.py</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb31" data-filename="testing.py"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb31-1"><a href="#cb31-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> unittest</span>
<span id="cb31-2"><a href="#cb31-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> sglang <span class="im">as</span> sgl</span>
<span id="cb31-3"><a href="#cb31-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb31-4"><a href="#cb31-4" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> TestSGLangFunctions(unittest.TestCase):</span>
<span id="cb31-5"><a href="#cb31-5" aria-hidden="true" tabindex="-1"></a>    <span class="at">@classmethod</span></span>
<span id="cb31-6"><a href="#cb31-6" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> setUpClass(<span class="va">cls</span>):</span>
<span id="cb31-7"><a href="#cb31-7" aria-hidden="true" tabindex="-1"></a>        <span class="co"># There is no sgl.MockBackend, so these are integration tests against a</span></span>
<span id="cb31-8"><a href="#cb31-8" aria-hidden="true" tabindex="-1"></a>        <span class="co"># running server; the whole class is skipped if it cannot be reached</span></span>
<span id="cb31-9"><a href="#cb31-9" aria-hidden="true" tabindex="-1"></a>        <span class="cf">try</span>:</span>
<span id="cb31-10"><a href="#cb31-10" aria-hidden="true" tabindex="-1"></a>            sgl.set_default_backend(sgl.RuntimeEndpoint(<span class="st">"http://localhost:30000"</span>))</span>
<span id="cb31-11"><a href="#cb31-11" aria-hidden="true" tabindex="-1"></a>        <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:</span>
<span id="cb31-12"><a href="#cb31-12" aria-hidden="true" tabindex="-1"></a>            <span class="cf">raise</span> unittest.SkipTest(<span class="ss">f"SGLang server unavailable: </span><span class="sc">{</span>e<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb31-13"><a href="#cb31-13" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb31-14"><a href="#cb31-14" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> test_simple_generation(<span class="va">self</span>):</span>
<span id="cb31-15"><a href="#cb31-15" aria-hidden="true" tabindex="-1"></a>        <span class="at">@sgl.function</span></span>
<span id="cb31-16"><a href="#cb31-16" aria-hidden="true" tabindex="-1"></a>        <span class="kw">def</span> test_func(s):</span>
<span id="cb31-17"><a href="#cb31-17" aria-hidden="true" tabindex="-1"></a>            s <span class="op">+=</span> <span class="st">"Hello"</span></span>
<span id="cb31-18"><a href="#cb31-18" aria-hidden="true" tabindex="-1"></a>            s <span class="op">+=</span> sgl.gen(<span class="st">"response"</span>, max_tokens<span class="op">=</span><span class="dv">10</span>)</span>
<span id="cb31-19"><a href="#cb31-19" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb31-20"><a href="#cb31-20" aria-hidden="true" tabindex="-1"></a>        result <span class="op">=</span> test_func()</span>
<span id="cb31-21"><a href="#cb31-21" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.assertIsInstance(result[<span class="st">"response"</span>], <span class="bu">str</span>)</span>
<span id="cb31-22"><a href="#cb31-22" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb31-23"><a href="#cb31-23" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> test_structured_output(<span class="va">self</span>):</span>
<span id="cb31-24"><a href="#cb31-24" aria-hidden="true" tabindex="-1"></a>        <span class="at">@sgl.function</span></span>
<span id="cb31-25"><a href="#cb31-25" aria-hidden="true" tabindex="-1"></a>        <span class="kw">def</span> json_test(s):</span>
<span id="cb31-26"><a href="#cb31-26" aria-hidden="true" tabindex="-1"></a>            s <span class="op">+=</span> <span class="st">"Generate JSON: "</span></span>
<span id="cb31-27"><a href="#cb31-27" aria-hidden="true" tabindex="-1"></a>            s <span class="op">+=</span> sgl.gen(<span class="st">"json"</span>, regex<span class="op">=</span><span class="vs">r'</span><span class="ch">\{</span><span class="dv">.</span><span class="op">*</span><span class="ch">\}</span><span class="vs">'</span>)</span>
<span id="cb31-28"><a href="#cb31-28" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb31-29"><a href="#cb31-29" aria-hidden="true" tabindex="-1"></a>        result <span class="op">=</span> json_test()</span>
<span id="cb31-30"><a href="#cb31-30" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.assertTrue(result[<span class="st">"json"</span>].startswith(<span class="st">"{"</span>))</span>
<span id="cb31-31"><a href="#cb31-31" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb31-32"><a href="#cb31-32" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb31-33"><a href="#cb31-33" aria-hidden="true" tabindex="-1"></a>    unittest.main()</span></code></pre></div></div>
</div>
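<p>The regex <code>\{.*\}</code> only forces the output to start with <code>{</code> and end with <code>}</code>, so a string like <code>{oops}</code> passes the <code>startswith</code> check without being valid JSON. A stricter check parses the output with the standard library (<code>parse_json_object</code> is a hypothetical helper):</p>

```python
import json
from typing import Optional


def parse_json_object(text: str) -> Optional[dict]:
    """Return `text` parsed as a JSON object, or None if it is not one."""
    try:
        value = json.loads(text)
    except json.JSONDecodeError:
        return None
    # Valid JSON that is not an object (a list, a number, ...) also counts as a failure
    return value if isinstance(value, dict) else None


print(parse_json_object('{"name": "SGLang", "stars": 1}'))  # {'name': 'SGLang', 'stars': 1}
print(parse_json_object("{oops}"))                          # None
```

<p>Inside the test, <code>self.assertIsNotNone(parse_json_object(result["json"]))</code> then fails on malformed output.</p>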
</section>
<section id="monitoring-and-logging" class="level3">
<h3 class="anchored" data-anchor-id="monitoring-and-logging" id="monitoring-and-logging">4. Monitoring and Logging</h3>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>monitoring.py</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb32" data-filename="monitoring.py"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb32-1"><a href="#cb32-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> logging</span>
<span id="cb32-2"><a href="#cb32-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb32-3"><a href="#cb32-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb32-4"><a href="#cb32-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> sglang <span class="im">as</span> sgl</span>
<span id="cb32-5"><a href="#cb32-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb32-6"><a href="#cb32-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Set up logging</span></span>
<span id="cb32-7"><a href="#cb32-7" aria-hidden="true" tabindex="-1"></a>logging.basicConfig(level<span class="op">=</span>logging.INFO)</span>
<span id="cb32-8"><a href="#cb32-8" aria-hidden="true" tabindex="-1"></a>logger <span class="op">=</span> logging.getLogger(<span class="va">__name__</span>)</span>
<span id="cb32-9"><a href="#cb32-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb32-10"><a href="#cb32-10" aria-hidden="true" tabindex="-1"></a><span class="at">@sgl.function</span></span>
<span id="cb32-11"><a href="#cb32-11" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> generation(s, prompt):</span>
<span id="cb32-12"><a href="#cb32-12" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> prompt</span>
<span id="cb32-13"><a href="#cb32-13" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> sgl.gen(<span class="st">"response"</span>, max_tokens<span class="op">=</span><span class="dv">100</span>)</span>
<span id="cb32-14"><a href="#cb32-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb32-15"><a href="#cb32-15" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> monitored_generation(prompt):</span>
<span id="cb32-16"><a href="#cb32-16" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Statements inside an @sgl.function body are executed asynchronously by default,</span></span>
<span id="cb32-17"><a href="#cb32-17" aria-hidden="true" tabindex="-1"></a>    <span class="co"># so time the complete run() call rather than code inside the body</span></span>
<span id="cb32-18"><a href="#cb32-18" aria-hidden="true" tabindex="-1"></a>    start_time <span class="op">=</span> time.perf_counter()</span>
<span id="cb32-19"><a href="#cb32-19" aria-hidden="true" tabindex="-1"></a>    <span class="cf">try</span>:</span>
<span id="cb32-20"><a href="#cb32-20" aria-hidden="true" tabindex="-1"></a>        state <span class="op">=</span> generation.run(prompt<span class="op">=</span>prompt)</span>
<span id="cb32-21"><a href="#cb32-21" aria-hidden="true" tabindex="-1"></a>        response <span class="op">=</span> state[<span class="st">"response"</span>]  <span class="co"># waits until the generation has finished</span></span>
<span id="cb32-22"><a href="#cb32-22" aria-hidden="true" tabindex="-1"></a>    <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:</span>
<span id="cb32-23"><a href="#cb32-23" aria-hidden="true" tabindex="-1"></a>        logger.error(<span class="ss">f"Generation failed: </span><span class="sc">{</span>e<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb32-24"><a href="#cb32-24" aria-hidden="true" tabindex="-1"></a>        <span class="cf">raise</span></span>
<span id="cb32-25"><a href="#cb32-25" aria-hidden="true" tabindex="-1"></a>    duration <span class="op">=</span> time.perf_counter() <span class="op">-</span> start_time</span>
<span id="cb32-26"><a href="#cb32-26" aria-hidden="true" tabindex="-1"></a>    logger.info(<span class="ss">f"Generation completed in </span><span class="sc">{</span>duration<span class="sc">:.2f}</span><span class="ss">s"</span>)</span>
<span id="cb32-27"><a href="#cb32-27" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> response</span></code></pre></div></div>
</div>
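<p>Per-call log lines are easier to act on once they are aggregated into percentiles. The sketch below is one way to do that with only the Python standard library; <code>LatencyTracker</code> is a hypothetical helper (not part of SGLang), and the <code>time.sleep</code> call stands in for a real generation request.</p>

```python
import logging
import statistics
import time
from contextlib import contextmanager

logger = logging.getLogger("latency")


class LatencyTracker:
    """Hypothetical helper: records request durations and reports percentiles."""

    def __init__(self):
        self.durations = []

    @contextmanager
    def track(self, label="request"):
        start = time.perf_counter()
        try:
            yield
        finally:
            # Record the duration even if the tracked block raised
            elapsed = time.perf_counter() - start
            self.durations.append(elapsed)
            logger.info("%s took %.3fs", label, elapsed)

    def summary(self):
        count = len(self.durations)
        if count >= 2:
            # 99 cut points; "inclusive" keeps percentiles within the observed range
            cuts = statistics.quantiles(self.durations, n=100, method="inclusive")
            return {"count": count, "p50": cuts[49], "p95": cuts[94]}
        return {"count": count}


tracker = LatencyTracker()
for _ in range(5):
    with tracker.track("generation"):
        time.sleep(0.01)  # stand-in for a real SGLang call
print(tracker.summary())
```

<p>In a real service, the body of the <code>with</code> block would run the program and read its output (for example <code>state = monitored_generation.run(prompt=...)</code> followed by <code>state["response"]</code>), so each recorded duration covers a completed generation.</p>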
</section>
</section>
<section id="sec-comparisons" class="level2">
<h2 class="anchored" data-anchor-id="sec-comparisons" id="sec-comparisons">Comparison with Other Frameworks</h2>
<div class="tabset-margin-container"></div><div class="panel-tabset">
<ul class="nav nav-tabs" role="tablist"><li class="nav-item" role="presentation"><a class="nav-link active" id="tabset-3-1-tab" data-bs-toggle="tab" data-bs-target="#tabset-3-1" role="tab" aria-controls="tabset-3-1" aria-selected="true" href="">SGLang vs.&nbsp;LMQL</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-3-2-tab" data-bs-toggle="tab" data-bs-target="#tabset-3-2" role="tab" aria-controls="tabset-3-2" aria-selected="false" href="">SGLang vs.&nbsp;Guidance</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-3-3-tab" data-bs-toggle="tab" data-bs-target="#tabset-3-3" role="tab" aria-controls="tabset-3-3" aria-selected="false" href="">SGLang vs.&nbsp;LangChain</a></li></ul>
<div class="tab-content">
<div id="tabset-3-1" class="tab-pane active" role="tabpanel" aria-labelledby="tabset-3-1-tab">
<table class="caption-top table">
<thead>
<tr class="header">
<th>Feature</th>
<th>SGLang</th>
<th>LMQL</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>Performance</td>
<td>High (RadixAttention)</td>
<td>Medium</td>
</tr>
<tr class="even">
<td>Python Integration</td>
<td>Native embedding</td>
<td>External DSL</td>
</tr>
<tr class="odd">
<td>Caching</td>
<td>Automatic</td>
<td>Manual</td>
</tr>
<tr class="even">
<td>Parallelism</td>
<td>Built-in</td>
<td>Limited</td>
</tr>
</tbody>
</table>
</div>
<div id="tabset-3-2" class="tab-pane" role="tabpanel" aria-labelledby="tabset-3-2-tab">
<table class="caption-top table">
<thead>
<tr class="header">
<th>Feature</th>
<th>SGLang</th>
<th>Guidance</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>Runtime Optimization</td>
<td>Yes</td>
<td>Limited</td>
</tr>
<tr class="even">
<td>Structured Output</td>
<td>Advanced</td>
<td>Basic</td>
</tr>
<tr class="odd">
<td>Vision Support</td>
<td>Yes</td>
<td>No</td>
</tr>
<tr class="even">
<td>Deployment</td>
<td>Production-ready</td>
<td>Research-focused</td>
</tr>
</tbody>
</table>
</div>
<div id="tabset-3-3" class="tab-pane" role="tabpanel" aria-labelledby="tabset-3-3-tab">
<table class="caption-top table">
<thead>
<tr class="header">
<th>Feature</th>
<th>SGLang</th>
<th>LangChain</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>Level</td>
<td>Low-level control</td>
<td>High-level abstractions</td>
</tr>
<tr class="even">
<td>Performance</td>
<td>Optimized runtime</td>
<td>Variable</td>
</tr>
<tr class="odd">
<td>Flexibility</td>
<td>High</td>
<td>Medium</td>
</tr>
<tr class="even">
<td>Learning Curve</td>
<td>Moderate</td>
<td>Low</td>
</tr>
</tbody>
</table>
</div>
</div>
</div>
</section>
<section id="sec-troubleshooting" class="level2">
<h2 class="anchored" data-anchor-id="sec-troubleshooting" id="sec-troubleshooting">Troubleshooting</h2>
<section id="common-issues" class="level3">
<h3 class="anchored" data-anchor-id="common-issues" id="common-issues">Common Issues</h3>
<section id="connection-problems" class="level4">
<h4 class="anchored" data-anchor-id="connection-problems">1. Connection Problems</h4>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>debug_connection.py</strong></pre>
</div>
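<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb33" data-filename="debug_connection.py"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb33-1"><a href="#cb33-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Debug connection issues: probe the server's /health endpoint first</span></span>
<span id="cb33-2"><a href="#cb33-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> urllib.request</span>
<span id="cb33-3"><a href="#cb33-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb33-4"><a href="#cb33-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> sglang <span class="im">as</span> sgl</span>
<span id="cb33-5"><a href="#cb33-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb33-6"><a href="#cb33-6" aria-hidden="true" tabindex="-1"></a>SERVER_URL <span class="op">=</span> <span class="st">"http://localhost:30000"</span></span>
<span id="cb33-7"><a href="#cb33-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb33-8"><a href="#cb33-8" aria-hidden="true" tabindex="-1"></a><span class="cf">try</span>:</span>
<span id="cb33-9"><a href="#cb33-9" aria-hidden="true" tabindex="-1"></a>    urllib.request.urlopen(<span class="ss">f"</span><span class="sc">{</span>SERVER_URL<span class="sc">}</span><span class="ss">/health"</span>, timeout<span class="op">=</span><span class="dv">5</span>).close()</span>
<span id="cb33-10"><a href="#cb33-10" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="st">"Server is healthy"</span>)</span>
<span id="cb33-11"><a href="#cb33-11" aria-hidden="true" tabindex="-1"></a>    backend <span class="op">=</span> sgl.RuntimeEndpoint(SERVER_URL)</span>
<span id="cb33-12"><a href="#cb33-12" aria-hidden="true" tabindex="-1"></a><span class="cf">except</span> <span class="pp">OSError</span> <span class="im">as</span> e:  <span class="co"># URLError, refused connections and timeouts are all OSErrors</span></span>
<span id="cb33-13"><a href="#cb33-13" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Cannot connect to server (</span><span class="sc">{</span>e<span class="sc">}</span><span class="ss">). Check if it's running."</span>)</span></code></pre></div></div>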
</div>
</section>
<section id="memory-issues" class="level4">
<h4 class="anchored" data-anchor-id="memory-issues">2. Memory Issues</h4>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>memory_debug.sh</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb34" data-filename="memory_debug.sh"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb34-1"><a href="#cb34-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Monitor memory usage</span></span>
<span id="cb34-2"><a href="#cb34-2" aria-hidden="true" tabindex="-1"></a><span class="ex">nvidia-smi</span></span>
<span id="cb34-3"><a href="#cb34-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb34-4"><a href="#cb34-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Adjust memory settings</span></span>
<span id="cb34-5"><a href="#cb34-5" aria-hidden="true" tabindex="-1"></a><span class="ex">python</span> <span class="at">-m</span> sglang.launch_server <span class="dt">\</span></span>
<span id="cb34-6"><a href="#cb34-6" aria-hidden="true" tabindex="-1"></a>    <span class="at">--model-path</span> your-model <span class="dt">\</span></span>
<span id="cb34-7"><a href="#cb34-7" aria-hidden="true" tabindex="-1"></a>    <span class="at">--mem-fraction-static</span> 0.6  <span class="co"># Fraction of GPU memory for weights and KV cache; lower it on OOM</span></span></code></pre></div></div>
</div>
</section>
<section id="generation-timeouts" class="level4">
<h4 class="anchored" data-anchor-id="generation-timeouts">3. Generation Timeouts</h4>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>timeout_handling.py</strong></pre>
</div>
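<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb35" data-filename="timeout_handling.py"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb35-1"><a href="#cb35-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> concurrent.futures</span>
<span id="cb35-2"><a href="#cb35-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> sglang <span class="im">as</span> sgl</span>
<span id="cb35-3"><a href="#cb35-3" aria-hidden="true" tabindex="-1"></a><span class="at">@sgl.function</span></span>
<span id="cb35-4"><a href="#cb35-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> timeout_handling(s, prompt):</span>
<span id="cb35-5"><a href="#cb35-5" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> prompt</span>
<span id="cb35-6"><a href="#cb35-6" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> sgl.gen(<span class="st">"response"</span>, max_tokens<span class="op">=</span><span class="dv">100</span>)</span>
<span id="cb35-7"><a href="#cb35-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb35-8"><a href="#cb35-8" aria-hidden="true" tabindex="-1"></a>_pool <span class="op">=</span> concurrent.futures.ThreadPoolExecutor(max_workers<span class="op">=</span><span class="dv">4</span>)</span>
<span id="cb35-9"><a href="#cb35-9" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> generate_with_timeout(prompt, timeout_s<span class="op">=</span><span class="dv">30</span>):</span>
<span id="cb35-10"><a href="#cb35-10" aria-hidden="true" tabindex="-1"></a>    future <span class="op">=</span> _pool.submit(<span class="kw">lambda</span>: timeout_handling.run(prompt<span class="op">=</span>prompt)[<span class="st">"response"</span>])</span>
<span id="cb35-11"><a href="#cb35-11" aria-hidden="true" tabindex="-1"></a>    <span class="cf">try</span>:  <span class="co"># sgl.gen() takes no timeout argument, so bound the wait here instead</span></span>
<span id="cb35-12"><a href="#cb35-12" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> future.result(timeout<span class="op">=</span>timeout_s)</span>
<span id="cb35-13"><a href="#cb35-13" aria-hidden="true" tabindex="-1"></a>    <span class="cf">except</span> concurrent.futures.<span class="pp">TimeoutError</span>:</span>
<span id="cb35-14"><a href="#cb35-14" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="st">"Request timed out. Please try again."</span></span></code></pre></div></div>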
</div>
</section>
<section id="performance-issues" class="level4">
<h4 class="anchored" data-anchor-id="performance-issues">4. Performance Issues</h4>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>performance_debug.py</strong></pre>
</div>
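<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb36" data-filename="performance_debug.py"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb36-1"><a href="#cb36-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Measure the end-to-end latency of one program call</span></span>
<span id="cb36-2"><a href="#cb36-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb36-3"><a href="#cb36-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb36-4"><a href="#cb36-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> sglang <span class="im">as</span> sgl</span>
<span id="cb36-5"><a href="#cb36-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb36-6"><a href="#cb36-6" aria-hidden="true" tabindex="-1"></a><span class="at">@sgl.function</span></span>
<span id="cb36-7"><a href="#cb36-7" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> profiled_function(s, text):</span>
<span id="cb36-8"><a href="#cb36-8" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> text</span>
<span id="cb36-9"><a href="#cb36-9" aria-hidden="true" tabindex="-1"></a>    s <span class="op">+=</span> sgl.gen(<span class="st">"output"</span>, max_tokens<span class="op">=</span><span class="dv">100</span>)</span>
<span id="cb36-10"><a href="#cb36-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb36-11"><a href="#cb36-11" aria-hidden="true" tabindex="-1"></a>start <span class="op">=</span> time.perf_counter()</span>
<span id="cb36-12"><a href="#cb36-12" aria-hidden="true" tabindex="-1"></a>state <span class="op">=</span> profiled_function.run(text<span class="op">=</span><span class="st">"Explain RadixAttention in one paragraph."</span>)</span>
<span id="cb36-13"><a href="#cb36-13" aria-hidden="true" tabindex="-1"></a>output <span class="op">=</span> state[<span class="st">"output"</span>]  <span class="co"># waits for the generation if it is still running</span></span>
<span id="cb36-14"><a href="#cb36-14" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Generated </span><span class="sc">{</span><span class="bu">len</span>(output)<span class="sc">}</span><span class="ss"> characters in </span><span class="sc">{</span>time.perf_counter() <span class="op">-</span> start<span class="sc">:.2f}</span><span class="ss">s"</span>)</span></code></pre></div></div>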
</div>
</section>
</section>
<section id="debugging-tips" class="level3">
<h3 class="anchored" data-anchor-id="debugging-tips" id="debugging-tips">Debugging Tips</h3>
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Warning</span>Debugging Checklist
</div>
</div>
<div class="callout-body-container callout-body">
<ol type="1">
<li><p><strong>Enable Verbose Logging</strong></p>
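<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb37"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb37-1"><a href="#cb37-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> logging</span>
<span id="cb37-2"><a href="#cb37-2" aria-hidden="true" tabindex="-1"></a>logging.basicConfig(level<span class="op">=</span>logging.INFO)  <span class="co"># attach a console handler to the root logger</span></span>
<span id="cb37-3"><a href="#cb37-3" aria-hidden="true" tabindex="-1"></a>logging.getLogger(<span class="st">"sglang"</span>).setLevel(logging.DEBUG)  <span class="co"># DEBUG output only from SGLang's loggers</span></span></code></pre></div></div></li>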
<li><p><strong>Check Server Logs</strong></p>
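<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb38"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb38-1"><a href="#cb38-1" aria-hidden="true" tabindex="-1"></a><span class="co"># The server logs to the terminal; redirect its output to a file to follow it</span></span>
<span id="cb38-2"><a href="#cb38-2" aria-hidden="true" tabindex="-1"></a><span class="ex">python</span> <span class="at">-m</span> sglang.launch_server <span class="at">--model-path</span> your-model <span class="op">&gt;</span> sglang_server.log <span class="op">2&gt;&amp;1</span> <span class="kw">&amp;</span></span>
<span id="cb38-3"><a href="#cb38-3" aria-hidden="true" tabindex="-1"></a><span class="fu">tail</span> <span class="at">-f</span> sglang_server.log</span></code></pre></div></div></li>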
<li><p><strong>Use Mock Backend for Testing</strong></p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb39"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb39-1"><a href="#cb39-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Test logic without actual model calls</span></span>
<span id="cb39-2"><a href="#cb39-2" aria-hidden="true" tabindex="-1"></a>sgl.set_default_backend(sgl.MockBackend())</span></code></pre></div></div></li>
</ol>
</div>
</div>
</section>
</section>
<section id="sec-contributing" class="level2">
<h2 class="anchored" data-anchor-id="sec-contributing" id="sec-contributing">Contributing</h2>
<section id="development-setup" class="level3">
<h3 class="anchored" data-anchor-id="development-setup" id="development-setup">Development Setup</h3>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>dev_setup.sh</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb40" data-filename="dev_setup.sh"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb40-1"><a href="#cb40-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Clone repository</span></span>
<span id="cb40-2"><a href="#cb40-2" aria-hidden="true" tabindex="-1"></a><span class="fu">git</span> clone https://github.com/sgl-project/sglang.git</span>
<span id="cb40-3"><a href="#cb40-3" aria-hidden="true" tabindex="-1"></a><span class="bu">cd</span> sglang</span>
<span id="cb40-4"><a href="#cb40-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb40-5"><a href="#cb40-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Create development environment</span></span>
<span id="cb40-6"><a href="#cb40-6" aria-hidden="true" tabindex="-1"></a><span class="ex">conda</span> create <span class="at">-n</span> sglang-dev python=3.9</span>
<span id="cb40-7"><a href="#cb40-7" aria-hidden="true" tabindex="-1"></a><span class="ex">conda</span> activate sglang-dev</span>
<span id="cb40-8"><a href="#cb40-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb40-9"><a href="#cb40-9" aria-hidden="true" tabindex="-1"></a><span class="co"># Install in development mode (the Python package lives in the python/ directory)</span></span>
<span id="cb40-10"><a href="#cb40-10" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install <span class="at">-e</span> <span class="st">"python[all]"</span></span>
<span id="cb40-11"><a href="#cb40-11" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install <span class="at">-r</span> requirements-dev.txt</span></code></pre></div></div>
</div>
</section>
<section id="running-tests" class="level3">
<h3 class="anchored" data-anchor-id="running-tests" id="running-tests">Running Tests</h3>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>run_tests.sh</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb41" data-filename="run_tests.sh"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb41-1"><a href="#cb41-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Run all tests</span></span>
<span id="cb41-2"><a href="#cb41-2" aria-hidden="true" tabindex="-1"></a><span class="ex">python</span> <span class="at">-m</span> pytest tests/</span>
<span id="cb41-3"><a href="#cb41-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb41-4"><a href="#cb41-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Run specific test category</span></span>
<span id="cb41-5"><a href="#cb41-5" aria-hidden="true" tabindex="-1"></a><span class="ex">python</span> <span class="at">-m</span> pytest tests/test_frontend.py</span>
<span id="cb41-6"><a href="#cb41-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb41-7"><a href="#cb41-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Run with coverage</span></span>
<span id="cb41-8"><a href="#cb41-8" aria-hidden="true" tabindex="-1"></a><span class="ex">python</span> <span class="at">-m</span> pytest <span class="at">--cov</span><span class="op">=</span>sglang tests/</span></code></pre></div></div>
</div>
</section>
<section id="code-style" class="level3">
<h3 class="anchored" data-anchor-id="code-style" id="code-style">Code Style</h3>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>code_style.sh</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb42" data-filename="code_style.sh"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb42-1"><a href="#cb42-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Format code</span></span>
<span id="cb42-2"><a href="#cb42-2" aria-hidden="true" tabindex="-1"></a><span class="ex">black</span> sglang/</span>
<span id="cb42-3"><a href="#cb42-3" aria-hidden="true" tabindex="-1"></a><span class="ex">isort</span> sglang/</span>
<span id="cb42-4"><a href="#cb42-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb42-5"><a href="#cb42-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Check style</span></span>
<span id="cb42-6"><a href="#cb42-6" aria-hidden="true" tabindex="-1"></a><span class="ex">flake8</span> sglang/</span>
<span id="cb42-7"><a href="#cb42-7" aria-hidden="true" tabindex="-1"></a><span class="ex">mypy</span> sglang/</span></code></pre></div></div>
</div>
</section>
<section id="submitting-prs" class="level3">
<h3 class="anchored" data-anchor-id="submitting-prs" id="submitting-prs">Submitting PRs</h3>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Pull Request Guidelines
</div>
</div>
<div class="callout-body-container callout-body">
<ol type="1">
<li>Fork the repository</li>
<li>Create a feature branch</li>
<li>Add tests for new functionality</li>
<li>Update documentation</li>
<li>Submit pull request with clear description</li>
</ol>
</div>
</div>
</section>
</section>
<section id="sec-resources" class="level2">
<h2 class="anchored" data-anchor-id="sec-resources" id="sec-resources">Resources</h2>
<section id="official-documentation" class="level3">
<h3 class="anchored" data-anchor-id="official-documentation" id="official-documentation">Official Documentation</h3>
<ul>
<li><a href="https://docs.sglang.ai/">SGLang Documentation</a></li>
<li><a href="https://github.com/sgl-project/sglang">GitHub Repository</a></li>
<li><a href="https://arxiv.org/abs/2312.07104">Paper: SGLang: Efficient Execution of Structured Language Model Programs</a></li>
</ul>
</section>
<section id="community" class="level3">
<h3 class="anchored" data-anchor-id="community" id="community">Community</h3>
<ul>
<li><a href="https://github.com/sgl-project/sglang/discussions">GitHub Discussions</a></li>
<li><a href="https://discord.gg/sglang">Discord Server</a></li>
<li><a href="https://twitter.com/sglang_ai">Twitter Updates</a></li>
</ul>
</section>
<section id="examples-and-tutorials" class="level3">
<h3 class="anchored" data-anchor-id="examples-and-tutorials" id="examples-and-tutorials">Examples and Tutorials</h3>
<ul>
<li><a href="https://github.com/sgl-project/sglang/tree/main/examples">Official Examples</a></li>
<li><a href="https://github.com/sgl-project/sglang/tree/main/notebooks">Tutorial Notebooks</a></li>
<li><a href="https://github.com/sgl-project/sglang/tree/main/cookbook">Cookbook Recipes</a></li>
</ul>
</section>
<section id="related-projects" class="level3">
<h3 class="anchored" data-anchor-id="related-projects" id="related-projects">Related Projects</h3>
<ul>
<li><a href="https://sky.cs.berkeley.edu/project/sglang/">UC Berkeley Sky Computing Lab</a></li>
<li><a href="https://lmsys.org/">LMSYS Organization</a></li>
</ul>
<hr>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Final Note
</div>
</div>
<div class="callout-body-container callout-body">
<p>This guide covers the essential aspects of SGLang, from basic concepts to advanced usage patterns. As SGLang continues to evolve rapidly, always refer to the official documentation for the most current information and updates.</p>
<p>For questions and support, please visit the <a href="https://github.com/sgl-project/sglang/discussions">GitHub Discussions</a> or check the official documentation.</p>
</div>
</div>



</section>
</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Complete Guide to Mamba Transformers: Implementation and Theory]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/models/mamba/mamba-code/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/models/mamba/mamba-code/</guid>
      <pubDate>Sat, 23 Aug 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>tutorial</category>
      <category>intermediate</category>
      <content:encoded><![CDATA[






<section id="complete-guide-to-mamba-transformers-implementation-and-theory" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/models/mamba/mamba-code/mambac.png" class="img-fluid"></p>
<section id="sec-introduction" class="level2">
<h2 class="anchored" data-anchor-id="sec-introduction" id="sec-introduction">Introduction to Mamba</h2>
<p>Mamba is a sequence-modeling architecture that avoids the quadratic cost of self-attention by building on selective state space models (SSMs). Instead of attending over all previous tokens, it updates a fixed-size hidden state once per token, so its cost grows linearly with sequence length; the original paper reports language-modeling quality that matches or exceeds Transformers of comparable size.</p>
<section id="key-advantages" class="level3">
<h3 class="anchored" data-anchor-id="key-advantages" id="key-advantages">Key Advantages</h3>
<ul>
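<li><strong>Linear Complexity</strong>: <span class="math inline">\(O(L)\)</span> time in the sequence length <span class="math inline">\(L\)</span>, versus <span class="math inline">\(O(L^2)\)</span> for self-attention</li>
<li><strong>Selective Mechanism</strong>: the SSM parameters are computed from each input token, so the model can decide what to store in or drop from its state</li>
<li><strong>Hardware Efficiency</strong>: a hardware-aware scan kernel keeps the expanded state in fast on-chip GPU memory (SRAM) instead of writing it out to slower high-bandwidth memory</li>
<li><strong>Long Context</strong>: inference carries a fixed-size recurrent state instead of a key-value cache that grows with every token, which makes very long sequences practical</li>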
</ul>
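<p>The Long Context point can be made concrete with a little arithmetic. The sketch below counts inference memory per layer and per sequence, in elements rather than bytes. It assumes standard attention that caches one key and one value vector of width <code>d_model</code> for every past token, and it uses the defaults <code>d_state=16</code>, <code>d_conv=4</code>, <code>expand=2</code> that the <code>MambaBlock</code> later in this guide takes; <code>d_model = 768</code> is an arbitrary illustrative width and the function names are hypothetical.</p>

```python
def kv_cache_elements(seq_len, d_model):
    """Self-attention: a key and a value vector of width d_model for every past token."""
    return 2 * seq_len * d_model


def ssm_state_elements(d_model, d_state=16, d_conv=4, expand=2):
    """Mamba: a (d_inner x d_state) SSM state plus the last d_conv - 1 conv inputs."""
    d_inner = expand * d_model
    return d_inner * d_state + d_inner * (d_conv - 1)


d_model = 768  # hypothetical model width
for seq_len in (1_024, 4_096, 65_536):
    kv = kv_cache_elements(seq_len, d_model)
    ssm = ssm_state_elements(d_model)
    print(f"L={seq_len:>6}: KV cache {kv:>12,} elements vs. SSM state {ssm:,} elements")
```

<p>The KV cache grows linearly with the context (about 100.7 million elements at <span class="math inline">\(L = 65{,}536\)</span>), while the SSM state stays at 29,184 elements no matter how many tokens have been processed.</p>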
</section>
<section id="architecture-overview" class="level3">
<h3 class="anchored" data-anchor-id="architecture-overview" id="architecture-overview">Architecture Overview</h3>
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">graph LR
    A[Input] --&gt; B[Embedding]
    B --&gt; C[Mamba Blocks]
    C --&gt; D[Output Projection]
    D --&gt; E[Logits]
</pre>
</div>
<p></p></figure><p></p>
</div>
</div>
</div>
</section>
</section>
<section id="mathematical-foundation" class="level2">
<h2 class="anchored" data-anchor-id="mathematical-foundation" id="mathematical-foundation">Mathematical Foundation</h2>
<section id="state-space-models-ssms" class="level3">
<h3 class="anchored" data-anchor-id="state-space-models-ssms" id="state-space-models-ssms">State Space Models (SSMs)</h3>
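<p>Mamba builds on the continuous-time linear state space model, which maps an input signal <span class="math inline">\(u(t)\)</span> to an output <span class="math inline">\(y(t)\)</span> through a hidden state <span class="math inline">\(x(t)\)</span>:</p>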
<p><span class="math display">\[
\frac{dx}{dt} = Ax(t) + Bu(t)
\]</span></p>
<p><span class="math display">\[
y(t) = Cx(t) + Du(t)
\]</span></p>
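<p>To apply the model to a sequence of tokens, it is discretized with a step size <span class="math inline">\(\Delta\)</span>. The zero-order hold (ZOH) rule gives the recurrence:</p>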
<p><span class="math display">\[
x_k = \bar{A}x_{k-1} + \bar{B}u_k
\]</span></p>
<p><span class="math display">\[
y_k = Cx_k + Du_k
\]</span></p>
<p>Where:</p>
<ul>
<li><span class="math inline">\(\bar{A} = \exp(\Delta A)\)</span> (matrix exponential)</li>
<li><span class="math inline">\(\bar{B} = (\Delta A)^{-1}(\bar{A} - I)\Delta B\)</span></li>
<li><span class="math inline">\(\Delta\)</span> is the discretization step size</li>
</ul>
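<p>A scalar worked example helps check these formulas. With a one-dimensional state (Mamba's <span class="math inline">\(A\)</span> is diagonal, so each state dimension behaves this way independently), <span class="math inline">\(\bar{A} = e^{\Delta a}\)</span> and <span class="math inline">\(\bar{B} = (e^{\Delta a} - 1)\,b/a\)</span>. The helper names below are hypothetical, and the sketch assumes <span class="math inline">\(a \neq 0\)</span> (Mamba keeps <span class="math inline">\(A\)</span> negative).</p>

```python
import math


def discretize_zoh(a, b, delta):
    """Zero-order-hold discretization of the scalar SSM dx/dt = a*x + b*u (a != 0)."""
    a_bar = math.exp(delta * a)                          # Ā = exp(ΔA)
    b_bar = (a_bar - 1.0) / (delta * a) * (delta * b)    # B̄ = (ΔA)^-1 (Ā - I) ΔB
    return a_bar, b_bar


def run_ssm(us, a, b, c, d, delta):
    """x_k = Ā x_{k-1} + B̄ u_k and y_k = C x_k + D u_k, starting from x_0 = 0."""
    a_bar, b_bar = discretize_zoh(a, b, delta)
    x, ys = 0.0, []
    for u in us:
        x = a_bar * x + b_bar * u
        ys.append(c * x + d * u)
    return ys


# A constant input u = 1 settles at the continuous-time steady state x* = -b/a = 1
ys = run_ssm([1.0] * 200, a=-1.0, b=1.0, c=1.0, d=0.0, delta=0.1)
print(round(ys[-1], 6))  # 1.0
```

<p>ZOH reproduces the continuous-time steady state exactly for any step size. The first-order shortcut <span class="math inline">\(\bar{B} \approx \Delta B\)</span> would use 0.1 here instead of about 0.0952 and settle near 1.05 instead; that gap shrinks as <span class="math inline">\(\Delta\)</span> gets smaller.</p>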
</section>
<section id="selective-mechanism" class="level3">
<h3 class="anchored" data-anchor-id="selective-mechanism" id="selective-mechanism">Selective Mechanism</h3>
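<p>Mamba introduces selectivity by making <span class="math inline">\(B\)</span>, <span class="math inline">\(C\)</span>, and <span class="math inline">\(\Delta\)</span> functions of the input, computed per token, while <span class="math inline">\(A\)</span> remains a learned, input-independent parameter. In pseudocode:</p>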
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a>B <span class="op">=</span> Linear_B(x)    <span class="co"># Input-dependent B matrix</span></span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a>C <span class="op">=</span> Linear_C(x)    <span class="co"># Input-dependent C matrix  </span></span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a>Δ <span class="op">=</span> softplus(Linear_Δ(x))  <span class="co"># Input-dependent step size</span></span></code></pre></div></div>
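<p>The three lines above are pseudocode. A runnable PyTorch sketch of the same idea follows. The class name <code>SelectiveProjections</code>, the separate layers, and <code>dt_rank=8</code> are illustrative assumptions; reference Mamba code typically fuses these projections into one linear layer and routes <span class="math inline">\(\Delta\)</span> through a low-rank projection, which the <code>nn.Sequential</code> below imitates.</p>

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SelectiveProjections(nn.Module):
    """Hypothetical helper: computes per-token B, C and delta from the input."""

    def __init__(self, d_inner, d_state=16, dt_rank=8):
        super().__init__()
        self.to_B = nn.Linear(d_inner, d_state, bias=False)  # role of Linear_B
        self.to_C = nn.Linear(d_inner, d_state, bias=False)  # role of Linear_C
        # Role of Linear_Δ, factored through a small rank to keep it cheap
        self.to_delta = nn.Sequential(
            nn.Linear(d_inner, dt_rank, bias=False),
            nn.Linear(dt_rank, d_inner, bias=True),
        )

    def forward(self, x):
        # x: (batch, L, d_inner)
        B = self.to_B(x)                      # (batch, L, d_state)
        C = self.to_C(x)                      # (batch, L, d_state)
        delta = F.softplus(self.to_delta(x))  # (batch, L, d_inner), always positive
        return B, C, delta


torch.manual_seed(0)
proj = SelectiveProjections(d_inner=32)
B, C, delta = proj(torch.randn(2, 10, 32))
print(B.shape, C.shape, delta.shape)
```

<p>These shapes, <code>(batch, L, N)</code> for <code>B</code> and <code>C</code> and <code>(batch, L, D)</code> for <code>delta</code>, are the ones the <code>selective_scan</code> function below expects. The softplus keeps every step size positive, so <span class="math inline">\(e^{\Delta A}\)</span> stays a decay factor when <span class="math inline">\(A\)</span> is negative.</p>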
</section>
</section>
<section id="core-components" class="level2">
<h2 class="anchored" data-anchor-id="core-components" id="core-components">Core Components</h2>
<section id="selective-scan-algorithm" class="level3">
<h3 class="anchored" data-anchor-id="selective-scan-algorithm" id="selective-scan-algorithm">Selective Scan Algorithm</h3>
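<p>The heart of Mamba is the selective scan, which runs the discretized recurrence with per-token parameters, <span class="math inline">\(x_k = \bar{A}_k x_{k-1} + \bar{B}_k u_k\)</span> and <span class="math inline">\(y_k = C_k x_k + D u_k\)</span>. A straightforward reference implementation in PyTorch:</p>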
<div id="67ffd0c4" class="cell" data-execution_count="1">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn.functional <span class="im">as</span> F</span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> einops <span class="im">import</span> rearrange, repeat</span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> math</span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> selective_scan(u, delta, A, B, C, D):</span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a><span class="co">    Reference selective scan: a sequential recurrence over the sequence length</span></span>
<span id="cb2-10"><a href="#cb2-10" aria-hidden="true" tabindex="-1"></a><span class="co">    </span></span>
<span id="cb2-11"><a href="#cb2-11" aria-hidden="true" tabindex="-1"></a><span class="co">    Parameters:</span></span>
<span id="cb2-12"><a href="#cb2-12" aria-hidden="true" tabindex="-1"></a><span class="co">    -----------</span></span>
<span id="cb2-13"><a href="#cb2-13" aria-hidden="true" tabindex="-1"></a><span class="co">    u : torch.Tensor</span></span>
<span id="cb2-14"><a href="#cb2-14" aria-hidden="true" tabindex="-1"></a><span class="co">        Input sequence (B, L, D)</span></span>
<span id="cb2-15"><a href="#cb2-15" aria-hidden="true" tabindex="-1"></a><span class="co">    delta : torch.Tensor</span></span>
<span id="cb2-16"><a href="#cb2-16" aria-hidden="true" tabindex="-1"></a><span class="co">        Step sizes (B, L, D) </span></span>
<span id="cb2-17"><a href="#cb2-17" aria-hidden="true" tabindex="-1"></a><span class="co">    A : torch.Tensor</span></span>
<span id="cb2-18"><a href="#cb2-18" aria-hidden="true" tabindex="-1"></a><span class="co">        State matrix (D, N)</span></span>
<span id="cb2-19"><a href="#cb2-19" aria-hidden="true" tabindex="-1"></a><span class="co">    B : torch.Tensor</span></span>
<span id="cb2-20"><a href="#cb2-20" aria-hidden="true" tabindex="-1"></a><span class="co">        Input matrix (B, L, N)</span></span>
<span id="cb2-21"><a href="#cb2-21" aria-hidden="true" tabindex="-1"></a><span class="co">    C : torch.Tensor</span></span>
<span id="cb2-22"><a href="#cb2-22" aria-hidden="true" tabindex="-1"></a><span class="co">        Output matrix (B, L, N) </span></span>
<span id="cb2-23"><a href="#cb2-23" aria-hidden="true" tabindex="-1"></a><span class="co">    D : torch.Tensor</span></span>
<span id="cb2-24"><a href="#cb2-24" aria-hidden="true" tabindex="-1"></a><span class="co">        Feedthrough (D,)</span></span>
<span id="cb2-25"><a href="#cb2-25" aria-hidden="true" tabindex="-1"></a><span class="co">        </span></span>
<span id="cb2-26"><a href="#cb2-26" aria-hidden="true" tabindex="-1"></a><span class="co">    Returns:</span></span>
<span id="cb2-27"><a href="#cb2-27" aria-hidden="true" tabindex="-1"></a><span class="co">    --------</span></span>
<span id="cb2-28"><a href="#cb2-28" aria-hidden="true" tabindex="-1"></a><span class="co">    torch.Tensor</span></span>
<span id="cb2-29"><a href="#cb2-29" aria-hidden="true" tabindex="-1"></a><span class="co">        Output sequence (B, L, D)</span></span>
<span id="cb2-30"><a href="#cb2-30" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb2-31"><a href="#cb2-31" aria-hidden="true" tabindex="-1"></a>    deltaA <span class="op">=</span> torch.exp(delta.unsqueeze(<span class="op">-</span><span class="dv">1</span>) <span class="op">*</span> A)  <span class="co"># (B, L, D, N): A_bar = exp(delta * A)</span></span>
<span id="cb2-32"><a href="#cb2-32" aria-hidden="true" tabindex="-1"></a>    deltaB <span class="op">=</span> delta.unsqueeze(<span class="op">-</span><span class="dv">1</span>) <span class="op">*</span> B.unsqueeze(<span class="dv">2</span>)  <span class="co"># (B, L, D, N): simplified B_bar ≈ delta * B</span></span>
<span id="cb2-33"><a href="#cb2-33" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-34"><a href="#cb2-34" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Sequential scan over L; optimized kernels evaluate this recurrence with a parallel scan</span></span>
<span id="cb2-35"><a href="#cb2-35" aria-hidden="true" tabindex="-1"></a>    x <span class="op">=</span> torch.zeros(u.shape[<span class="dv">0</span>], u.shape[<span class="dv">2</span>], A.shape[<span class="op">-</span><span class="dv">1</span>], device<span class="op">=</span>u.device, dtype<span class="op">=</span>u.dtype)  <span class="co"># state (B, D, N)</span></span>
<span id="cb2-36"><a href="#cb2-36" aria-hidden="true" tabindex="-1"></a>    outputs <span class="op">=</span> []</span>
<span id="cb2-37"><a href="#cb2-37" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-38"><a href="#cb2-38" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(u.shape[<span class="dv">1</span>]):</span>
<span id="cb2-39"><a href="#cb2-39" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> deltaA[:, i] <span class="op">*</span> x <span class="op">+</span> deltaB[:, i] <span class="op">*</span> u[:, i].unsqueeze(<span class="op">-</span><span class="dv">1</span>)</span>
<span id="cb2-40"><a href="#cb2-40" aria-hidden="true" tabindex="-1"></a>        y <span class="op">=</span> torch.einsum(<span class="st">'bdn,bn-&gt;bd'</span>, x, C[:, i]) <span class="op">+</span> D <span class="op">*</span> u[:, i]</span>
<span id="cb2-41"><a href="#cb2-41" aria-hidden="true" tabindex="-1"></a>        outputs.append(y)</span>
<span id="cb2-42"><a href="#cb2-42" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-43"><a href="#cb2-43" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> torch.stack(outputs, dim<span class="op">=</span><span class="dv">1</span>)</span></code></pre></div></div>
</div>
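<p>To check the shapes in the scan above without PyTorch, here is a minimal NumPy sketch of the same recurrence, h_t = exp(Δ_t A) * h_{t-1} + (Δ_t B_t) x_t and y_t = C_t · h_t + D x_t, with the hidden state kept as a (batch, d_inner, d_state) array. The function name <code>selective_scan_ref</code> and the random inputs are hypothetical; A is built with the same real S4D-style initialization the MambaBlock constructors use (A = -exp(log(1..N))), so every decay factor exp(Δ A) lies strictly between 0 and 1 when Δ is positive.</p>

```python
import numpy as np


def selective_scan_ref(u, delta, A, B, C, D):
    """Sequential selective scan (hypothetical NumPy reference).

    u, delta: (batch, L, d_inner); A: (d_inner, N);
    B, C: (batch, L, N); D: (d_inner,). Returns y: (batch, L, d_inner).
    """
    batch, L, d_inner = u.shape
    N = A.shape[-1]
    # Discretize per time step: exp(delta * A) and delta * B, both (batch, L, d_inner, N)
    deltaA = np.exp(delta[..., None] * A)
    deltaB = delta[..., None] * B[:, :, None, :]
    h = np.zeros((batch, d_inner, N))  # hidden state per channel
    ys = []
    for t in range(L):
        # h_t = exp(delta_t A) * h_{t-1} + (delta_t B_t) * u_t
        h = deltaA[:, t] * h + deltaB[:, t] * u[:, t, :, None]
        # y_t = C_t . h_t (contract over N) plus the skip term D * u_t
        y = np.einsum("bdn,bn->bd", h, C[:, t]) + D * u[:, t]
        ys.append(y)
    return np.stack(ys, axis=1)


rng = np.random.default_rng(0)
batch, L, d_inner, N = 2, 5, 3, 4
A = -np.exp(np.log(np.tile(np.arange(1, N + 1, dtype=float), (d_inner, 1))))  # rows are -1..-N
u = rng.standard_normal((batch, L, d_inner))
delta = np.log1p(np.exp(rng.standard_normal((batch, L, d_inner))))  # softplus keeps delta positive
B = rng.standard_normal((batch, L, N))
C = rng.standard_normal((batch, L, N))
D = np.ones(d_inner)

y = selective_scan_ref(u, delta, A, B, C, D)
print(y.shape)  # (2, 5, 3)
```

Because the hidden state starts at zero, the first output reduces to C_0 · (Δ_0 B_0 x_0) + D x_0, which is an easy sanity check for any faster scan implementation.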
</section>
<section id="mamba-block-architecture" class="level3">
<h3 class="anchored" data-anchor-id="mamba-block-architecture" id="mamba-block-architecture">Mamba Block Architecture</h3>
<div id="69ac4b7b" class="cell" data-execution_count="2">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MambaBlock(nn.Module):</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a><span class="co">    Mamba block implementing a selective state space model (constructor only)</span></span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, d_model, d_state<span class="op">=</span><span class="dv">16</span>, d_conv<span class="op">=</span><span class="dv">4</span>, expand<span class="op">=</span><span class="dv">2</span>):</span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.d_model <span class="op">=</span> d_model</span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.d_state <span class="op">=</span> d_state</span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.d_conv <span class="op">=</span> d_conv</span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.d_inner <span class="op">=</span> <span class="bu">int</span>(expand <span class="op">*</span> d_model)</span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Input projection</span></span>
<span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.in_proj <span class="op">=</span> nn.Linear(d_model, <span class="va">self</span>.d_inner <span class="op">*</span> <span class="dv">2</span>, bias<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-15"><a href="#cb3-15" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Convolution layer</span></span>
<span id="cb3-16"><a href="#cb3-16" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.conv1d <span class="op">=</span> nn.Conv1d(</span>
<span id="cb3-17"><a href="#cb3-17" aria-hidden="true" tabindex="-1"></a>            in_channels<span class="op">=</span><span class="va">self</span>.d_inner,</span>
<span id="cb3-18"><a href="#cb3-18" aria-hidden="true" tabindex="-1"></a>            out_channels<span class="op">=</span><span class="va">self</span>.d_inner,</span>
<span id="cb3-19"><a href="#cb3-19" aria-hidden="true" tabindex="-1"></a>            kernel_size<span class="op">=</span>d_conv,</span>
<span id="cb3-20"><a href="#cb3-20" aria-hidden="true" tabindex="-1"></a>            bias<span class="op">=</span><span class="va">True</span>,</span>
<span id="cb3-21"><a href="#cb3-21" aria-hidden="true" tabindex="-1"></a>            groups<span class="op">=</span><span class="va">self</span>.d_inner,</span>
<span id="cb3-22"><a href="#cb3-22" aria-hidden="true" tabindex="-1"></a>            padding<span class="op">=</span>d_conv <span class="op">-</span> <span class="dv">1</span>,</span>
<span id="cb3-23"><a href="#cb3-23" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb3-24"><a href="#cb3-24" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-25"><a href="#cb3-25" aria-hidden="true" tabindex="-1"></a>        <span class="co"># SSM parameters</span></span>
<span id="cb3-26"><a href="#cb3-26" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.x_proj <span class="op">=</span> nn.Linear(<span class="va">self</span>.d_inner, <span class="va">self</span>.d_state <span class="op">*</span> <span class="dv">2</span>, bias<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb3-27"><a href="#cb3-27" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.dt_proj <span class="op">=</span> nn.Linear(<span class="va">self</span>.d_inner, <span class="va">self</span>.d_inner, bias<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb3-28"><a href="#cb3-28" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-29"><a href="#cb3-29" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Initialize A (real-valued S4D-Real init): store log(1..N) so that A = -exp(A_log) stays negative</span></span>
<span id="cb3-30"><a href="#cb3-30" aria-hidden="true" tabindex="-1"></a>        A <span class="op">=</span> repeat(torch.arange(<span class="dv">1</span>, <span class="va">self</span>.d_state <span class="op">+</span> <span class="dv">1</span>, dtype<span class="op">=</span>torch.float32), <span class="st">'n -&gt; d n'</span>, d<span class="op">=</span><span class="va">self</span>.d_inner)</span>
<span id="cb3-31"><a href="#cb3-31" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.A_log <span class="op">=</span> nn.Parameter(torch.log(A))</span>
<span id="cb3-32"><a href="#cb3-32" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-33"><a href="#cb3-33" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Output projection</span></span>
<span id="cb3-34"><a href="#cb3-34" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.out_proj <span class="op">=</span> nn.Linear(<span class="va">self</span>.d_inner, d_model, bias<span class="op">=</span><span class="va">False</span>)</span></code></pre></div></div>
</div>
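<p>The depthwise <code>nn.Conv1d</code> above uses <code>groups=d_inner</code> and <code>padding=d_conv - 1</code>. PyTorch pads both ends of the sequence, so the output has L + d_conv - 1 steps; keeping only the first L (as the forward pass of the enhanced block does) makes each output depend only on the current and previous d_conv - 1 inputs. The following NumPy sketch reproduces that behaviour for one sequence; <code>causal_depthwise_conv</code> is a hypothetical helper, and like PyTorch it computes cross-correlation (no kernel flip).</p>

```python
import numpy as np


def causal_depthwise_conv(x, weight, bias):
    """Hypothetical NumPy version of Conv1d(groups=C, padding=K-1) followed by [:, :, :L].

    x: (C, L) one sequence, channels first; weight: (C, K); bias: (C,).
    """
    C, L = x.shape
    K = weight.shape[1]
    # Zero-pad K - 1 steps on BOTH sides, as nn.Conv1d(padding=K - 1) does
    xp = np.pad(x, ((0, 0), (K - 1, K - 1)))
    full_len = L + K - 1
    out = np.empty((C, full_len))
    for t in range(full_len):
        # Cross-correlation: each channel uses only its own kernel (groups=C)
        out[:, t] = (xp[:, t:t + K] * weight).sum(axis=1) + bias
    # Truncating to the first L outputs leaves a causal convolution
    return out[:, :L]


rng = np.random.default_rng(0)
x = rng.standard_normal((3, 8))
w = rng.standard_normal((3, 4))
b = np.zeros(3)
y = causal_depthwise_conv(x, w, b)

# Changing a future input (t = 5) leaves earlier outputs unchanged
x2 = x.copy()
x2[:, 5] += 10.0
y2 = causal_depthwise_conv(x2, w, b)
print(np.allclose(y[:, :5], y2[:, :5]))  # True
```

Padding only on the left by d_conv - 1 would give the same result without the truncation step; the symmetric padding plus slicing is simply what <code>nn.Conv1d</code>'s single <code>padding</code> argument allows.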
</section>
</section>
<section id="complete-implementation" class="level2">
<h2 class="anchored" data-anchor-id="complete-implementation" id="complete-implementation">Complete Implementation</h2>
<section id="full-mamba-model" class="level3">
<h3 class="anchored" data-anchor-id="full-mamba-model" id="full-mamba-model">Full Mamba Model</h3>
<div id="afa6c235" class="cell" data-execution_count="3">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> Mamba(nn.Module):</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a><span class="co">    Complete Mamba model implementation</span></span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(</span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>,</span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a>        d_model: <span class="bu">int</span>,</span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a>        n_layer: <span class="bu">int</span>,</span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a>        vocab_size: <span class="bu">int</span>,</span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a>        d_state: <span class="bu">int</span> <span class="op">=</span> <span class="dv">16</span>,</span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a>        expand: <span class="bu">int</span> <span class="op">=</span> <span class="dv">2</span>,</span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a>        dt_rank: <span class="bu">str</span> <span class="op">=</span> <span class="st">"auto"</span>,</span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a>        d_conv: <span class="bu">int</span> <span class="op">=</span> <span class="dv">4</span>,</span>
<span id="cb4-14"><a href="#cb4-14" aria-hidden="true" tabindex="-1"></a>        conv_bias: <span class="bu">bool</span> <span class="op">=</span> <span class="va">True</span>,</span>
<span id="cb4-15"><a href="#cb4-15" aria-hidden="true" tabindex="-1"></a>        bias: <span class="bu">bool</span> <span class="op">=</span> <span class="va">False</span>,</span>
<span id="cb4-16"><a href="#cb4-16" aria-hidden="true" tabindex="-1"></a>    ):</span>
<span id="cb4-17"><a href="#cb4-17" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb4-18"><a href="#cb4-18" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.d_model <span class="op">=</span> d_model</span>
<span id="cb4-19"><a href="#cb4-19" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.n_layer <span class="op">=</span> n_layer</span>
<span id="cb4-20"><a href="#cb4-20" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.vocab_size <span class="op">=</span> vocab_size</span>
<span id="cb4-21"><a href="#cb4-21" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-22"><a href="#cb4-22" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Token embeddings</span></span>
<span id="cb4-23"><a href="#cb4-23" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.embedding <span class="op">=</span> nn.Embedding(vocab_size, d_model)</span>
<span id="cb4-24"><a href="#cb4-24" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-25"><a href="#cb4-25" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Mamba layers</span></span>
<span id="cb4-26"><a href="#cb4-26" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.layers <span class="op">=</span> nn.ModuleList([</span>
<span id="cb4-27"><a href="#cb4-27" aria-hidden="true" tabindex="-1"></a>            ResidualBlock(</span>
<span id="cb4-28"><a href="#cb4-28" aria-hidden="true" tabindex="-1"></a>                MambaBlock(</span>
<span id="cb4-29"><a href="#cb4-29" aria-hidden="true" tabindex="-1"></a>                    d_model<span class="op">=</span>d_model,</span>
<span id="cb4-30"><a href="#cb4-30" aria-hidden="true" tabindex="-1"></a>                    d_state<span class="op">=</span>d_state,</span>
<span id="cb4-31"><a href="#cb4-31" aria-hidden="true" tabindex="-1"></a>                    expand<span class="op">=</span>expand,</span>
<span id="cb4-32"><a href="#cb4-32" aria-hidden="true" tabindex="-1"></a>                    dt_rank<span class="op">=</span>dt_rank,</span>
<span id="cb4-33"><a href="#cb4-33" aria-hidden="true" tabindex="-1"></a>                    d_conv<span class="op">=</span>d_conv,</span>
<span id="cb4-34"><a href="#cb4-34" aria-hidden="true" tabindex="-1"></a>                    conv_bias<span class="op">=</span>conv_bias,</span>
<span id="cb4-35"><a href="#cb4-35" aria-hidden="true" tabindex="-1"></a>                    bias<span class="op">=</span>bias,</span>
<span id="cb4-36"><a href="#cb4-36" aria-hidden="true" tabindex="-1"></a>                )</span>
<span id="cb4-37"><a href="#cb4-37" aria-hidden="true" tabindex="-1"></a>            )</span>
<span id="cb4-38"><a href="#cb4-38" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(n_layer)</span>
<span id="cb4-39"><a href="#cb4-39" aria-hidden="true" tabindex="-1"></a>        ])</span>
<span id="cb4-40"><a href="#cb4-40" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-41"><a href="#cb4-41" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Final RMSNorm and language-model head</span></span>
<span id="cb4-42"><a href="#cb4-42" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.norm_f <span class="op">=</span> RMSNorm(d_model)</span>
<span id="cb4-43"><a href="#cb4-43" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.lm_head <span class="op">=</span> nn.Linear(d_model, vocab_size, bias<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb4-44"><a href="#cb4-44" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-45"><a href="#cb4-45" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Weight tying</span></span>
<span id="cb4-46"><a href="#cb4-46" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.lm_head.weight <span class="op">=</span> <span class="va">self</span>.embedding.weight</span>
<span id="cb4-47"><a href="#cb4-47" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-48"><a href="#cb4-48" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, input_ids):</span>
<span id="cb4-49"><a href="#cb4-49" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""</span></span>
<span id="cb4-50"><a href="#cb4-50" aria-hidden="true" tabindex="-1"></a><span class="co">        Forward pass</span></span>
<span id="cb4-51"><a href="#cb4-51" aria-hidden="true" tabindex="-1"></a><span class="co">        </span></span>
<span id="cb4-52"><a href="#cb4-52" aria-hidden="true" tabindex="-1"></a><span class="co">        Parameters:</span></span>
<span id="cb4-53"><a href="#cb4-53" aria-hidden="true" tabindex="-1"></a><span class="co">        -----------</span></span>
<span id="cb4-54"><a href="#cb4-54" aria-hidden="true" tabindex="-1"></a><span class="co">        input_ids : torch.Tensor</span></span>
<span id="cb4-55"><a href="#cb4-55" aria-hidden="true" tabindex="-1"></a><span class="co">            Input token ids (batch, seqlen)</span></span>
<span id="cb4-56"><a href="#cb4-56" aria-hidden="true" tabindex="-1"></a><span class="co">            </span></span>
<span id="cb4-57"><a href="#cb4-57" aria-hidden="true" tabindex="-1"></a><span class="co">        Returns:</span></span>
<span id="cb4-58"><a href="#cb4-58" aria-hidden="true" tabindex="-1"></a><span class="co">        --------</span></span>
<span id="cb4-59"><a href="#cb4-59" aria-hidden="true" tabindex="-1"></a><span class="co">        torch.Tensor</span></span>
<span id="cb4-60"><a href="#cb4-60" aria-hidden="true" tabindex="-1"></a><span class="co">            Logits (batch, seqlen, vocab_size)</span></span>
<span id="cb4-61"><a href="#cb4-61" aria-hidden="true" tabindex="-1"></a><span class="co">        """</span></span>
<span id="cb4-62"><a href="#cb4-62" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.embedding(input_ids)</span>
<span id="cb4-63"><a href="#cb4-63" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-64"><a href="#cb4-64" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> layer <span class="kw">in</span> <span class="va">self</span>.layers:</span>
<span id="cb4-65"><a href="#cb4-65" aria-hidden="true" tabindex="-1"></a>            x <span class="op">=</span> layer(x)</span>
<span id="cb4-66"><a href="#cb4-66" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb4-67"><a href="#cb4-67" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.norm_f(x)</span>
<span id="cb4-68"><a href="#cb4-68" aria-hidden="true" tabindex="-1"></a>        logits <span class="op">=</span> <span class="va">self</span>.lm_head(x)</span>
<span id="cb4-69"><a href="#cb4-69" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-70"><a href="#cb4-70" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> logits</span></code></pre></div></div>
</div>
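<p>Weight tying works because <code>nn.Embedding(vocab_size, d_model).weight</code> and <code>nn.Linear(d_model, vocab_size).weight</code> both have shape (vocab_size, d_model): the same matrix E maps token ids to vectors (row lookup) and maps final hidden states back to logits (h @ E.T). A minimal NumPy sketch of that shared use, with hypothetical sizes, plus the arithmetic for <code>dt_rank="auto"</code> (ceil(d_model / 16), as computed in the enhanced block) and the parameters saved by tying:</p>

```python
import math

import numpy as np

# Hypothetical sizes, chosen only for illustration
vocab_size, d_model = 1000, 64
rng = np.random.default_rng(0)

# One shared matrix E plays both roles, like embedding.weight and lm_head.weight
E = rng.standard_normal((vocab_size, d_model))

input_ids = np.array([[1, 5, 42], [7, 7, 999]])  # (batch=2, seqlen=3)
x = E[input_ids]                                 # embedding lookup: (2, 3, d_model)
# The Mamba layers and the final norm would transform x at this point
logits = x @ E.T                                 # tied lm_head: (2, 3, vocab_size)
print(x.shape, logits.shape)  # (2, 3, 64) (2, 3, 1000)

# dt_rank="auto" resolves to ceil(d_model / 16)
dt_rank = math.ceil(d_model / 16)
print(dt_rank)  # 4

# Tying removes one separate (vocab_size x d_model) output matrix
saved = vocab_size * d_model
print(saved)  # 64000
```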
</section>
<section id="enhanced-mambablock-implementation" class="level3">
<h3 class="anchored" data-anchor-id="enhanced-mambablock-implementation" id="enhanced-mambablock-implementation">Enhanced MambaBlock Implementation</h3>
<div id="87e80f9e" class="cell" data-execution_count="4">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MambaBlock(nn.Module):</span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(</span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>,</span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a>        d_model,</span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a>        d_state<span class="op">=</span><span class="dv">16</span>,</span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a>        expand<span class="op">=</span><span class="dv">2</span>,</span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a>        dt_rank<span class="op">=</span><span class="st">"auto"</span>,</span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a>        d_conv<span class="op">=</span><span class="dv">4</span>,</span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a>        conv_bias<span class="op">=</span><span class="va">True</span>,</span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a>        bias<span class="op">=</span><span class="va">False</span>,</span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a>    ):</span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.d_model <span class="op">=</span> d_model</span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.d_state <span class="op">=</span> d_state</span>
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.expand <span class="op">=</span> expand</span>
<span id="cb5-16"><a href="#cb5-16" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.d_inner <span class="op">=</span> <span class="bu">int</span>(<span class="va">self</span>.expand <span class="op">*</span> <span class="va">self</span>.d_model)</span>
<span id="cb5-17"><a href="#cb5-17" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.dt_rank <span class="op">=</span> math.ceil(<span class="va">self</span>.d_model <span class="op">/</span> <span class="dv">16</span>) <span class="cf">if</span> dt_rank <span class="op">==</span> <span class="st">"auto"</span> <span class="cf">else</span> dt_rank</span>
<span id="cb5-18"><a href="#cb5-18" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-19"><a href="#cb5-19" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Input projections</span></span>
<span id="cb5-20"><a href="#cb5-20" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.in_proj <span class="op">=</span> nn.Linear(<span class="va">self</span>.d_model, <span class="va">self</span>.d_inner <span class="op">*</span> <span class="dv">2</span>, bias<span class="op">=</span>bias)</span>
<span id="cb5-21"><a href="#cb5-21" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-22"><a href="#cb5-22" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Convolution</span></span>
<span id="cb5-23"><a href="#cb5-23" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.conv1d <span class="op">=</span> nn.Conv1d(</span>
<span id="cb5-24"><a href="#cb5-24" aria-hidden="true" tabindex="-1"></a>            in_channels<span class="op">=</span><span class="va">self</span>.d_inner,</span>
<span id="cb5-25"><a href="#cb5-25" aria-hidden="true" tabindex="-1"></a>            out_channels<span class="op">=</span><span class="va">self</span>.d_inner,</span>
<span id="cb5-26"><a href="#cb5-26" aria-hidden="true" tabindex="-1"></a>            bias<span class="op">=</span>conv_bias,</span>
<span id="cb5-27"><a href="#cb5-27" aria-hidden="true" tabindex="-1"></a>            kernel_size<span class="op">=</span>d_conv,</span>
<span id="cb5-28"><a href="#cb5-28" aria-hidden="true" tabindex="-1"></a>            groups<span class="op">=</span><span class="va">self</span>.d_inner,</span>
<span id="cb5-29"><a href="#cb5-29" aria-hidden="true" tabindex="-1"></a>            padding<span class="op">=</span>d_conv <span class="op">-</span> <span class="dv">1</span>,</span>
<span id="cb5-30"><a href="#cb5-30" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb5-31"><a href="#cb5-31" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-32"><a href="#cb5-32" aria-hidden="true" tabindex="-1"></a>        <span class="co"># SSM projections</span></span>
<span id="cb5-33"><a href="#cb5-33" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.x_proj <span class="op">=</span> nn.Linear(<span class="va">self</span>.d_inner, <span class="va">self</span>.dt_rank <span class="op">+</span> <span class="va">self</span>.d_state <span class="op">*</span> <span class="dv">2</span>, bias<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb5-34"><a href="#cb5-34" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.dt_proj <span class="op">=</span> nn.Linear(<span class="va">self</span>.dt_rank, <span class="va">self</span>.d_inner, bias<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb5-35"><a href="#cb5-35" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-36"><a href="#cb5-36" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Initialize dt projection</span></span>
<span id="cb5-37"><a href="#cb5-37" aria-hidden="true" tabindex="-1"></a>        dt_init_std <span class="op">=</span> <span class="va">self</span>.dt_rank<span class="op">**-</span><span class="fl">0.5</span> <span class="op">*</span> <span class="va">self</span>.d_model<span class="op">**-</span><span class="fl">0.5</span></span>
<span id="cb5-38"><a href="#cb5-38" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> torch.no_grad():</span>
<span id="cb5-39"><a href="#cb5-39" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.dt_proj.weight.uniform_(<span class="op">-</span>dt_init_std, dt_init_std)</span>
<span id="cb5-40"><a href="#cb5-40" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-41"><a href="#cb5-41" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Initialize A matrix (S4D initialization)</span></span>
<span id="cb5-42"><a href="#cb5-42" aria-hidden="true" tabindex="-1"></a>        A <span class="op">=</span> repeat(</span>
<span id="cb5-43"><a href="#cb5-43" aria-hidden="true" tabindex="-1"></a>            torch.arange(<span class="dv">1</span>, <span class="va">self</span>.d_state <span class="op">+</span> <span class="dv">1</span>, dtype<span class="op">=</span>torch.float32),</span>
<span id="cb5-44"><a href="#cb5-44" aria-hidden="true" tabindex="-1"></a>            <span class="st">"n -&gt; d n"</span>,</span>
<span id="cb5-45"><a href="#cb5-45" aria-hidden="true" tabindex="-1"></a>            d<span class="op">=</span><span class="va">self</span>.d_inner,</span>
<span id="cb5-46"><a href="#cb5-46" aria-hidden="true" tabindex="-1"></a>        ).contiguous()</span>
<span id="cb5-47"><a href="#cb5-47" aria-hidden="true" tabindex="-1"></a>        A_log <span class="op">=</span> torch.log(A)</span>
<span id="cb5-48"><a href="#cb5-48" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.A_log <span class="op">=</span> nn.Parameter(A_log)</span>
<span id="cb5-49"><a href="#cb5-49" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-50"><a href="#cb5-50" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Initialize D parameter</span></span>
<span id="cb5-51"><a href="#cb5-51" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.D <span class="op">=</span> nn.Parameter(torch.ones(<span class="va">self</span>.d_inner))</span>
<span id="cb5-52"><a href="#cb5-52" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-53"><a href="#cb5-53" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Output projection</span></span>
<span id="cb5-54"><a href="#cb5-54" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.out_proj <span class="op">=</span> nn.Linear(<span class="va">self</span>.d_inner, <span class="va">self</span>.d_model, bias<span class="op">=</span>bias)</span>
<span id="cb5-55"><a href="#cb5-55" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-56"><a href="#cb5-56" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb5-57"><a href="#cb5-57" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""</span></span>
<span id="cb5-58"><a href="#cb5-58" aria-hidden="true" tabindex="-1"></a><span class="co">        Forward pass through Mamba block</span></span>
<span id="cb5-59"><a href="#cb5-59" aria-hidden="true" tabindex="-1"></a><span class="co">        </span></span>
<span id="cb5-60"><a href="#cb5-60" aria-hidden="true" tabindex="-1"></a><span class="co">        Parameters:</span></span>
<span id="cb5-61"><a href="#cb5-61" aria-hidden="true" tabindex="-1"></a><span class="co">        -----------</span></span>
<span id="cb5-62"><a href="#cb5-62" aria-hidden="true" tabindex="-1"></a><span class="co">        x : torch.Tensor</span></span>
<span id="cb5-63"><a href="#cb5-63" aria-hidden="true" tabindex="-1"></a><span class="co">            Input tensor (B, L, D)</span></span>
<span id="cb5-64"><a href="#cb5-64" aria-hidden="true" tabindex="-1"></a><span class="co">            </span></span>
<span id="cb5-65"><a href="#cb5-65" aria-hidden="true" tabindex="-1"></a><span class="co">        Returns:</span></span>
<span id="cb5-66"><a href="#cb5-66" aria-hidden="true" tabindex="-1"></a><span class="co">        --------</span></span>
<span id="cb5-67"><a href="#cb5-67" aria-hidden="true" tabindex="-1"></a><span class="co">        torch.Tensor</span></span>
<span id="cb5-68"><a href="#cb5-68" aria-hidden="true" tabindex="-1"></a><span class="co">            Output tensor (B, L, D)</span></span>
<span id="cb5-69"><a href="#cb5-69" aria-hidden="true" tabindex="-1"></a><span class="co">        """</span></span>
<span id="cb5-70"><a href="#cb5-70" aria-hidden="true" tabindex="-1"></a>        (B, L, D) <span class="op">=</span> x.shape</span>
<span id="cb5-71"><a href="#cb5-71" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-72"><a href="#cb5-72" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Input projections</span></span>
<span id="cb5-73"><a href="#cb5-73" aria-hidden="true" tabindex="-1"></a>        x_and_res <span class="op">=</span> <span class="va">self</span>.in_proj(x)  <span class="co"># (B, L, 2 * d_inner)</span></span>
<span id="cb5-74"><a href="#cb5-74" aria-hidden="true" tabindex="-1"></a>        x, res <span class="op">=</span> x_and_res.split(split_size<span class="op">=</span>[<span class="va">self</span>.d_inner, <span class="va">self</span>.d_inner], dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb5-75"><a href="#cb5-75" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-76"><a href="#cb5-76" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Convolution</span></span>
<span id="cb5-77"><a href="#cb5-77" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> rearrange(x, <span class="st">'b l d -&gt; b d l'</span>)</span>
<span id="cb5-78"><a href="#cb5-78" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.conv1d(x)[:, :, :L]  <span class="co"># Keep the first L outputs so the convolution is causal</span></span>
<span id="cb5-79"><a href="#cb5-79" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> rearrange(x, <span class="st">'b d l -&gt; b l d'</span>)</span>
<span id="cb5-80"><a href="#cb5-80" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-81"><a href="#cb5-81" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Activation</span></span>
<span id="cb5-82"><a href="#cb5-82" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> F.silu(x)</span>
<span id="cb5-83"><a href="#cb5-83" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-84"><a href="#cb5-84" aria-hidden="true" tabindex="-1"></a>        <span class="co"># SSM</span></span>
<span id="cb5-85"><a href="#cb5-85" aria-hidden="true" tabindex="-1"></a>        y <span class="op">=</span> <span class="va">self</span>.ssm(x)</span>
<span id="cb5-86"><a href="#cb5-86" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-87"><a href="#cb5-87" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Gating and output projection</span></span>
<span id="cb5-88"><a href="#cb5-88" aria-hidden="true" tabindex="-1"></a>        y <span class="op">=</span> y <span class="op">*</span> F.silu(res)</span>
<span id="cb5-89"><a href="#cb5-89" aria-hidden="true" tabindex="-1"></a>        output <span class="op">=</span> <span class="va">self</span>.out_proj(y)</span>
<span id="cb5-90"><a href="#cb5-90" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-91"><a href="#cb5-91" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> output</span>
<span id="cb5-92"><a href="#cb5-92" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-93"><a href="#cb5-93" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> ssm(<span class="va">self</span>, x):</span>
<span id="cb5-94"><a href="#cb5-94" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""</span></span>
<span id="cb5-95"><a href="#cb5-95" aria-hidden="true" tabindex="-1"></a><span class="co">        Selective State Space Model computation</span></span>
<span id="cb5-96"><a href="#cb5-96" aria-hidden="true" tabindex="-1"></a><span class="co">        """</span></span>
<span id="cb5-97"><a href="#cb5-97" aria-hidden="true" tabindex="-1"></a>        (B, L, D) <span class="op">=</span> x.shape</span>
<span id="cb5-98"><a href="#cb5-98" aria-hidden="true" tabindex="-1"></a>        N <span class="op">=</span> <span class="va">self</span>.d_state</span>
<span id="cb5-99"><a href="#cb5-99" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-100"><a href="#cb5-100" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Extract A matrix</span></span>
<span id="cb5-101"><a href="#cb5-101" aria-hidden="true" tabindex="-1"></a>        A <span class="op">=</span> <span class="op">-</span>torch.exp(<span class="va">self</span>.A_log.<span class="bu">float</span>())  <span class="co"># (d_inner, d_state)</span></span>
<span id="cb5-102"><a href="#cb5-102" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-103"><a href="#cb5-103" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Compute Δ, B, C</span></span>
<span id="cb5-104"><a href="#cb5-104" aria-hidden="true" tabindex="-1"></a>        x_dbl <span class="op">=</span> <span class="va">self</span>.x_proj(x)  <span class="co"># (B, L, dt_rank + 2*d_state)</span></span>
<span id="cb5-105"><a href="#cb5-105" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-106"><a href="#cb5-106" aria-hidden="true" tabindex="-1"></a>        delta, B, C <span class="op">=</span> torch.split(</span>
<span id="cb5-107"><a href="#cb5-107" aria-hidden="true" tabindex="-1"></a>            x_dbl, [<span class="va">self</span>.dt_rank, N, N], dim<span class="op">=-</span><span class="dv">1</span></span>
<span id="cb5-108"><a href="#cb5-108" aria-hidden="true" tabindex="-1"></a>        )  <span class="co"># delta: (B, L, dt_rank), B, C: (B, L, d_state)</span></span>
<span id="cb5-109"><a href="#cb5-109" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-110"><a href="#cb5-110" aria-hidden="true" tabindex="-1"></a>        delta <span class="op">=</span> F.softplus(<span class="va">self</span>.dt_proj(delta))  <span class="co"># (B, L, d_inner)</span></span>
<span id="cb5-111"><a href="#cb5-111" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-112"><a href="#cb5-112" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Selective scan</span></span>
<span id="cb5-113"><a href="#cb5-113" aria-hidden="true" tabindex="-1"></a>        y <span class="op">=</span> <span class="va">self</span>.selective_scan(x, delta, A, B, C, <span class="va">self</span>.D)</span>
<span id="cb5-114"><a href="#cb5-114" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-115"><a href="#cb5-115" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> y</span>
<span id="cb5-116"><a href="#cb5-116" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-117"><a href="#cb5-117" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> selective_scan(<span class="va">self</span>, u, delta, A, B, C, D):</span>
<span id="cb5-118"><a href="#cb5-118" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""</span></span>
<span id="cb5-119"><a href="#cb5-119" aria-hidden="true" tabindex="-1"></a><span class="co">        Selective scan computed as a sequential recurrence over time steps (reference version)</span></span>
<span id="cb5-120"><a href="#cb5-120" aria-hidden="true" tabindex="-1"></a><span class="co">        """</span></span>
<span id="cb5-121"><a href="#cb5-121" aria-hidden="true" tabindex="-1"></a>        (batch, seq_len, d_inner) <span class="op">=</span> u.shape  <span class="co"># lowercase names keep the B and D arguments intact</span></span>
<span id="cb5-122"><a href="#cb5-122" aria-hidden="true" tabindex="-1"></a>        N <span class="op">=</span> A.shape[<span class="op">-</span><span class="dv">1</span>]</span>
<span id="cb5-123"><a href="#cb5-123" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-124"><a href="#cb5-124" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Discretize A and B</span></span>
<span id="cb5-125"><a href="#cb5-125" aria-hidden="true" tabindex="-1"></a>        deltaA <span class="op">=</span> torch.exp(<span class="va">self</span>.einsum(delta, A, <span class="st">'b l d, d n -&gt; b l d n'</span>))</span>
<span id="cb5-126"><a href="#cb5-126" aria-hidden="true" tabindex="-1"></a>        deltaB_u <span class="op">=</span> <span class="va">self</span>.einsum(delta, B, u, <span class="st">'b l d, b l n, b l d -&gt; b l d n'</span>)</span>
<span id="cb5-127"><a href="#cb5-127" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-128"><a href="#cb5-128" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Sequential scan over time steps (a simple reference loop, not a parallel scan)</span></span>
<span id="cb5-129"><a href="#cb5-129" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> torch.zeros((batch, d_inner, N), device<span class="op">=</span>deltaA.device, dtype<span class="op">=</span>deltaA.dtype)</span>
<span id="cb5-130"><a href="#cb5-130" aria-hidden="true" tabindex="-1"></a>        ys <span class="op">=</span> []</span>
<span id="cb5-131"><a href="#cb5-131" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-132"><a href="#cb5-132" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(seq_len):</span>
<span id="cb5-133"><a href="#cb5-133" aria-hidden="true" tabindex="-1"></a>            x <span class="op">=</span> deltaA[:, i] <span class="op">*</span> x <span class="op">+</span> deltaB_u[:, i]</span>
<span id="cb5-134"><a href="#cb5-134" aria-hidden="true" tabindex="-1"></a>            y <span class="op">=</span> <span class="va">self</span>.einsum(x, C[:, i], <span class="st">'b d n, b n -&gt; b d'</span>)</span>
<span id="cb5-135"><a href="#cb5-135" aria-hidden="true" tabindex="-1"></a>            ys.append(y)</span>
<span id="cb5-136"><a href="#cb5-136" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-137"><a href="#cb5-137" aria-hidden="true" tabindex="-1"></a>        y <span class="op">=</span> torch.stack(ys, dim<span class="op">=</span><span class="dv">1</span>)  <span class="co"># (B, L, D)</span></span>
<span id="cb5-138"><a href="#cb5-138" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-139"><a href="#cb5-139" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Add skip connection</span></span>
<span id="cb5-140"><a href="#cb5-140" aria-hidden="true" tabindex="-1"></a>        y <span class="op">=</span> y <span class="op">+</span> u <span class="op">*</span> D</span>
<span id="cb5-141"><a href="#cb5-141" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-142"><a href="#cb5-142" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> y</span>
<span id="cb5-143"><a href="#cb5-143" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-144"><a href="#cb5-144" aria-hidden="true" tabindex="-1"></a>    <span class="at">@staticmethod</span></span>
<span id="cb5-145"><a href="#cb5-145" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> einsum(<span class="op">*</span>args):</span>
<span id="cb5-146"><a href="#cb5-146" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Einsum helper: tensors first, equation string last (matching the calls above)"""</span></span>
<span id="cb5-147"><a href="#cb5-147" aria-hidden="true" tabindex="-1"></a>        <span class="co"># torch.einsum expects the equation first, so move it to the front</span></span>
<span id="cb5-148"><a href="#cb5-148" aria-hidden="true" tabindex="-1"></a>        <span class="op">*</span>operands, equation <span class="op">=</span> args</span>
<span id="cb5-149"><a href="#cb5-149" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> torch.einsum(equation, <span class="op">*</span>operands)</span></code></pre></div></div>
</div>
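<p>The loop in <code>selective_scan</code> applies the recurrence <code>x = deltaA[:, i] * x + deltaB_u[:, i]</code> one time step at a time. Because the recurrence is linear in the state, it can be unrolled: the state at step t is the sum, over every step s up to t, of <code>deltaB_u</code> at s multiplied by the product of the decays <code>deltaA</code> from s+1 through t. The sketch below checks that equivalence on a small example. It assumes a single sequence with no batch or state dimension, works with the log of the decay (in the model, log(deltaA) = delta · A, which is negative because A is negative), and uses hypothetical function names; the masked, quadratic-memory form is chosen for readability and is not the fused, hardware-aware scan kernel described in the Mamba paper.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import torch


def scan_sequential(log_a, b):
    """Reference loop: h_t = exp(log_a_t) * h_(t-1) + b_t, starting from h = 0."""
    h = torch.zeros(b.shape[1], dtype=b.dtype)
    outputs = []
    for t in range(b.shape[0]):
        h = torch.exp(log_a[t]) * h + b[t]
        outputs.append(h)
    return torch.stack(outputs)  # (L, D)


def scan_parallel(log_a, b):
    """Same recurrence evaluated for all steps at once via cumulative log-decays."""
    seq_len = b.shape[0]
    cum = torch.cumsum(log_a, dim=0)  # cum[t] = log_a[0] + ... + log_a[t]
    # diff[t, s] = cum[t] - cum[s] = total log-decay applied to b[s] by step t
    diff = cum.unsqueeze(1) - cum.unsqueeze(0)  # (L, L, D)
    causal = torch.tril(torch.ones(seq_len, seq_len)).bool()  # True where s is at most t
    # Mask future steps with -inf so exp() turns them into exact zeros
    diff = diff.masked_fill(~causal.unsqueeze(-1), float('-inf'))
    weights = torch.exp(diff)
    return (weights * b.unsqueeze(0)).sum(dim=1)  # (L, D)


torch.manual_seed(0)
log_a = -torch.rand(16, 4)  # negative, like delta * A
b = torch.randn(16, 4)      # stands in for deltaB_u
print(torch.allclose(scan_sequential(log_a, b), scan_parallel(log_a, b), atol=1e-5))
</code></pre></div>
<p>The unrolled form needs an L × L weight matrix per channel, so it only pays off for short sequences; it is shown here to make the closed-form expression concrete, while the loop above remains the simpler reference.</p>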
</section>
<section id="supporting-components" class="level3">
<h3 class="anchored" data-anchor-id="supporting-components" id="supporting-components">Supporting Components</h3>
<div id="c0e1e8ef" class="cell" data-execution_count="5">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> ResidualBlock(nn.Module):</span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Residual block with pre-normalization"""</span></span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, mixer):</span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.mixer <span class="op">=</span> mixer</span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.norm <span class="op">=</span> RMSNorm(mixer.d_model)</span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.mixer(<span class="va">self</span>.norm(x)) <span class="op">+</span> x</span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> RMSNorm(nn.Module):</span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Root Mean Square Layer Normalization"""</span></span>
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, d_model, eps<span class="op">=</span><span class="fl">1e-5</span>):</span>
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb6-15"><a href="#cb6-15" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.eps <span class="op">=</span> eps</span>
<span id="cb6-16"><a href="#cb6-16" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.weight <span class="op">=</span> nn.Parameter(torch.ones(d_model))</span>
<span id="cb6-17"><a href="#cb6-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-18"><a href="#cb6-18" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb6-19"><a href="#cb6-19" aria-hidden="true" tabindex="-1"></a>        output <span class="op">=</span> x <span class="op">*</span> torch.rsqrt(x.<span class="bu">pow</span>(<span class="dv">2</span>).mean(<span class="op">-</span><span class="dv">1</span>, keepdim<span class="op">=</span><span class="va">True</span>) <span class="op">+</span> <span class="va">self</span>.eps) <span class="op">*</span> <span class="va">self</span>.weight</span>
<span id="cb6-20"><a href="#cb6-20" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> output</span></code></pre></div></div>
</div>
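<p><code>RMSNorm</code> divides each feature vector by its root mean square and multiplies by a learned weight; unlike LayerNorm it neither subtracts the mean nor adds a bias. One way to see the relationship is that for inputs that already have zero mean the two coincide, because the (biased) variance then equals the mean square. The sketch below checks both properties with plain tensor operations; <code>rms_norm</code> is a hypothetical functional restatement of <code>RMSNorm.forward</code>, and the comparison uses <code>torch.nn.functional.layer_norm</code> without affine parameters.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import torch
import torch.nn.functional as F


def rms_norm(x, weight, eps=1e-5):
    """Functional form of RMSNorm: x / sqrt(mean(x^2) + eps) * weight."""
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight


torch.manual_seed(0)
d_model = 8
x = torch.randn(2, 5, d_model)  # (batch, seq_len, d_model)
weight = torch.ones(d_model)

y = rms_norm(x, weight)
# With a unit weight, each output vector has a root mean square very close to 1
print(y.pow(2).mean(-1).sqrt())

# For zero-mean inputs, RMSNorm and LayerNorm (no affine parameters) agree
x_centered = x - x.mean(-1, keepdim=True)
print(torch.allclose(rms_norm(x_centered, weight),
                     F.layer_norm(x_centered, (d_model,), eps=1e-5), atol=1e-5))
</code></pre></div>
<p>Skipping the mean subtraction saves one reduction per call, which is one reason RMSNorm is a common choice in recent sequence models; the learned <code>weight</code> still lets each channel choose its own scale.</p>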
</section>
</section>
<section id="training-and-optimization" class="level2">
<h2 class="anchored" data-anchor-id="training-and-optimization" id="training-and-optimization">Training and Optimization</h2>
<section id="training-configuration" class="level3">
<h3 class="anchored" data-anchor-id="training-configuration" id="training-configuration">Training Configuration</h3>
<div id="b86b21ce" class="cell" data-execution_count="6">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> dataclasses <span class="im">import</span> dataclass</span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a><span class="at">@dataclass</span></span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> TrainingConfig:</span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Configuration class for training hyperparameters"""</span></span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Model architecture</span></span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a>    d_model: <span class="bu">int</span> <span class="op">=</span> <span class="dv">768</span></span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a>    n_layer: <span class="bu">int</span> <span class="op">=</span> <span class="dv">24</span></span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a>    vocab_size: <span class="bu">int</span> <span class="op">=</span> <span class="dv">50257</span></span>
<span id="cb7-10"><a href="#cb7-10" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-11"><a href="#cb7-11" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Training parameters</span></span>
<span id="cb7-12"><a href="#cb7-12" aria-hidden="true" tabindex="-1"></a>    batch_size: <span class="bu">int</span> <span class="op">=</span> <span class="dv">32</span></span>
<span id="cb7-13"><a href="#cb7-13" aria-hidden="true" tabindex="-1"></a>    learning_rate: <span class="bu">float</span> <span class="op">=</span> <span class="fl">1e-4</span></span>
<span id="cb7-14"><a href="#cb7-14" aria-hidden="true" tabindex="-1"></a>    weight_decay: <span class="bu">float</span> <span class="op">=</span> <span class="fl">0.1</span></span>
<span id="cb7-15"><a href="#cb7-15" aria-hidden="true" tabindex="-1"></a>    max_seq_len: <span class="bu">int</span> <span class="op">=</span> <span class="dv">2048</span></span>
<span id="cb7-16"><a href="#cb7-16" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-17"><a href="#cb7-17" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Optimization</span></span>
<span id="cb7-18"><a href="#cb7-18" aria-hidden="true" tabindex="-1"></a>    warmup_steps: <span class="bu">int</span> <span class="op">=</span> <span class="dv">2000</span></span>
<span id="cb7-19"><a href="#cb7-19" aria-hidden="true" tabindex="-1"></a>    max_steps: <span class="bu">int</span> <span class="op">=</span> <span class="dv">100000</span></span>
<span id="cb7-20"><a href="#cb7-20" aria-hidden="true" tabindex="-1"></a>    eval_interval: <span class="bu">int</span> <span class="op">=</span> <span class="dv">1000</span></span>
<span id="cb7-21"><a href="#cb7-21" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-22"><a href="#cb7-22" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Hardware optimization</span></span>
<span id="cb7-23"><a href="#cb7-23" aria-hidden="true" tabindex="-1"></a>    mixed_precision: <span class="bu">bool</span> <span class="op">=</span> <span class="va">True</span></span>
<span id="cb7-24"><a href="#cb7-24" aria-hidden="true" tabindex="-1"></a>    gradient_checkpointing: <span class="bu">bool</span> <span class="op">=</span> <span class="va">True</span></span></code></pre></div></div>
</div>
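<p>The training defaults fix the token budget of a run: each optimizer step sees <code>batch_size</code> × <code>max_seq_len</code> tokens, assuming every sequence is packed to the full context length. The sketch below makes that arithmetic explicit. <code>TokenBudget</code> is a hypothetical helper, not part of the original code; it simply mirrors three of the default values above.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">from dataclasses import dataclass, replace


@dataclass
class TokenBudget:
    """Hypothetical helper mirroring three TrainingConfig defaults."""
    batch_size: int = 32
    max_seq_len: int = 2048
    max_steps: int = 100000

    @property
    def tokens_per_step(self):
        # Assumes every sequence in the batch is packed to max_seq_len tokens
        return self.batch_size * self.max_seq_len

    @property
    def total_tokens(self):
        return self.tokens_per_step * self.max_steps


budget = TokenBudget()
print(budget.tokens_per_step)   # 65536
print(budget.total_tokens)      # 6553600000

# Doubling the batch size doubles the tokens seen per step
bigger = replace(budget, batch_size=64)
print(bigger.tokens_per_step)   # 131072
</code></pre></div>
<p>With these defaults a full run covers roughly 6.55 billion tokens, an upper bound since padding or shorter documents reduce the number of real tokens per step.</p>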
</section>
<section id="optimizer-setup" class="level3">
<h3 class="anchored" data-anchor-id="optimizer-setup" id="optimizer-setup">Optimizer Setup</h3>
<div id="9a9184da" class="cell" data-execution_count="7">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb8"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> create_optimizer(model, config):</span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a><span class="co">    Create optimizer with proper weight decay configuration</span></span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a><span class="co">    </span></span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a><span class="co">    Parameters:</span></span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a><span class="co">    -----------</span></span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a><span class="co">    model : nn.Module</span></span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a><span class="co">        The model to optimize</span></span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a><span class="co">    config : TrainingConfig</span></span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a><span class="co">        Training configuration</span></span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a><span class="co">        </span></span>
<span id="cb8-12"><a href="#cb8-12" aria-hidden="true" tabindex="-1"></a><span class="co">    Returns:</span></span>
<span id="cb8-13"><a href="#cb8-13" aria-hidden="true" tabindex="-1"></a><span class="co">    --------</span></span>
<span id="cb8-14"><a href="#cb8-14" aria-hidden="true" tabindex="-1"></a><span class="co">    torch.optim.AdamW</span></span>
<span id="cb8-15"><a href="#cb8-15" aria-hidden="true" tabindex="-1"></a><span class="co">        Configured optimizer</span></span>
<span id="cb8-16"><a href="#cb8-16" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb8-17"><a href="#cb8-17" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Split parameters by their full names. Iterating named_parameters() once means each</span></span>
<span id="cb8-18"><a href="#cb8-18" aria-hidden="true" tabindex="-1"></a>    <span class="co"># tensor lands in exactly one group; nesting it inside named_modules() would revisit</span></span>
<span id="cb8-19"><a href="#cb8-19" aria-hidden="true" tabindex="-1"></a>    <span class="co"># nested parameters under different relative names and could put one tensor in both groups.</span></span>
<span id="cb8-20"><a href="#cb8-20" aria-hidden="true" tabindex="-1"></a>    decay, no_decay <span class="op">=</span> [], []</span>
<span id="cb8-21"><a href="#cb8-21" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> name, p <span class="kw">in</span> model.named_parameters():</span>
<span id="cb8-22"><a href="#cb8-22" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="kw">not</span> p.requires_grad:</span>
<span id="cb8-23"><a href="#cb8-23" aria-hidden="true" tabindex="-1"></a>            <span class="cf">continue</span></span>
<span id="cb8-24"><a href="#cb8-24" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="st">'bias'</span> <span class="kw">in</span> name <span class="kw">or</span> <span class="st">'norm'</span> <span class="kw">in</span> name <span class="kw">or</span> <span class="st">'embedding'</span> <span class="kw">in</span> name:</span>
<span id="cb8-25"><a href="#cb8-25" aria-hidden="true" tabindex="-1"></a>            no_decay.append(p)</span>
<span id="cb8-26"><a href="#cb8-26" aria-hidden="true" tabindex="-1"></a>        <span class="cf">else</span>:</span>
<span id="cb8-27"><a href="#cb8-27" aria-hidden="true" tabindex="-1"></a>            decay.append(p)</span>
<span id="cb8-28"><a href="#cb8-28" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-29"><a href="#cb8-29" aria-hidden="true" tabindex="-1"></a>    optim_groups <span class="op">=</span> [</span>
<span id="cb8-30"><a href="#cb8-30" aria-hidden="true" tabindex="-1"></a>        {<span class="st">'params'</span>: decay, <span class="st">'weight_decay'</span>: config.weight_decay},</span>
<span id="cb8-31"><a href="#cb8-31" aria-hidden="true" tabindex="-1"></a>        {<span class="st">'params'</span>: no_decay, <span class="st">'weight_decay'</span>: <span class="fl">0.0</span>},</span>
<span id="cb8-32"><a href="#cb8-32" aria-hidden="true" tabindex="-1"></a>    ]</span>
<span id="cb8-33"><a href="#cb8-33" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-34"><a href="#cb8-34" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> torch.optim.AdamW(optim_groups, lr<span class="op">=</span>config.learning_rate)</span></code></pre></div></div>
</div>
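<p>The grouping rule is purely name-based, so which tensors escape weight decay depends on how the model names its submodules. The sketch below applies the same substring test (<code>'bias'</code>, <code>'norm'</code>, or <code>'embedding'</code> in the full parameter name) to a hypothetical toy model whose attribute names are chosen for illustration, then builds <code>AdamW</code> from the two groups. PyTorch raises an error if one tensor appears in more than one parameter group, so constructing the optimizer doubles as a check that the split is clean.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import torch
import torch.nn as nn


class ToyModel(nn.Module):
    """Hypothetical model; attribute names chosen to exercise the naming rule."""
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(10, 4)
        self.proj = nn.Linear(4, 4)
        self.norm = nn.LayerNorm(4)


def split_by_name(model):
    """Return (decay_names, no_decay_names) using a substring test on full names."""
    decay, no_decay = [], []
    for name, _ in model.named_parameters():
        if 'bias' in name or 'norm' in name or 'embedding' in name:
            no_decay.append(name)
        else:
            decay.append(name)
    return decay, no_decay


model = ToyModel()
decay, no_decay = split_by_name(model)
print(decay)     # ['proj.weight']
print(no_decay)  # ['embedding.weight', 'proj.bias', 'norm.weight', 'norm.bias']

params = dict(model.named_parameters())
optimizer = torch.optim.AdamW(
    [
        {'params': [params[n] for n in decay], 'weight_decay': 0.1},
        {'params': [params[n] for n in no_decay], 'weight_decay': 0.0},
    ],
    lr=1e-4,
)
print([g['weight_decay'] for g in optimizer.param_groups])  # [0.1, 0.0]
</code></pre></div>
<p>Only the linear projection's weight matrix is decayed here; if a real model named its embedding table differently (say <code>wte</code>), the same rule would silently decay it, so it is worth printing the two groups once for any new architecture.</p>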
</section>
<section id="training-loop-implementation" class="level3">
<h3 class="anchored" data-anchor-id="training-loop-implementation" id="training-loop-implementation">Training Loop Implementation</h3>
<div id="647559f2" class="cell" data-execution_count="8">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb9"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MambaTrainer:</span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Comprehensive trainer for Mamba models"""</span></span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, model, config, train_loader, val_loader):</span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> model</span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.config <span class="op">=</span> config</span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.train_loader <span class="op">=</span> train_loader</span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.val_loader <span class="op">=</span> val_loader</span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.optimizer <span class="op">=</span> create_optimizer(model, config)</span>
<span id="cb9-11"><a href="#cb9-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.scheduler <span class="op">=</span> <span class="va">self</span>.create_scheduler()</span>
<span id="cb9-12"><a href="#cb9-12" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.scaler <span class="op">=</span> torch.cuda.amp.GradScaler() <span class="cf">if</span> config.mixed_precision <span class="cf">else</span> <span class="va">None</span></span>
<span id="cb9-13"><a href="#cb9-13" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-14"><a href="#cb9-14" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> create_scheduler(<span class="va">self</span>):</span>
<span id="cb9-15"><a href="#cb9-15" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Create cosine annealing scheduler with warmup"""</span></span>
<span id="cb9-16"><a href="#cb9-16" aria-hidden="true" tabindex="-1"></a>        <span class="kw">def</span> lr_lambda(step):</span>
<span id="cb9-17"><a href="#cb9-17" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> step <span class="op">&lt;</span> <span class="va">self</span>.config.warmup_steps:</span>
<span id="cb9-18"><a href="#cb9-18" aria-hidden="true" tabindex="-1"></a>                <span class="cf">return</span> step <span class="op">/</span> <span class="va">self</span>.config.warmup_steps</span>
<span id="cb9-19"><a href="#cb9-19" aria-hidden="true" tabindex="-1"></a>            <span class="cf">else</span>:</span>
<span id="cb9-20"><a href="#cb9-20" aria-hidden="true" tabindex="-1"></a>                progress <span class="op">=</span> (step <span class="op">-</span> <span class="va">self</span>.config.warmup_steps) <span class="op">/</span> <span class="op">\</span></span>
<span id="cb9-21"><a href="#cb9-21" aria-hidden="true" tabindex="-1"></a>                          (<span class="va">self</span>.config.max_steps <span class="op">-</span> <span class="va">self</span>.config.warmup_steps)</span>
<span id="cb9-22"><a href="#cb9-22" aria-hidden="true" tabindex="-1"></a>                <span class="cf">return</span> <span class="fl">0.5</span> <span class="op">*</span> (<span class="dv">1</span> <span class="op">+</span> math.cos(math.pi <span class="op">*</span> <span class="bu">min</span>(progress, <span class="fl">1.0</span>)))</span>
<span id="cb9-23"><a href="#cb9-23" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-24"><a href="#cb9-24" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> torch.optim.lr_scheduler.LambdaLR(<span class="va">self</span>.optimizer, lr_lambda)</span>
<span id="cb9-25"><a href="#cb9-25" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-26"><a href="#cb9-26" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> train_step(<span class="va">self</span>, batch):</span>
<span id="cb9-27"><a href="#cb9-27" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Single training step with mixed precision"""</span></span>
<span id="cb9-28"><a href="#cb9-28" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model.train()</span>
<span id="cb9-29"><a href="#cb9-29" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-30"><a href="#cb9-30" aria-hidden="true" tabindex="-1"></a>        input_ids <span class="op">=</span> batch[<span class="st">'input_ids'</span>].to(<span class="bu">next</span>(<span class="va">self</span>.model.parameters()).device)</span>
<span id="cb9-31"><a href="#cb9-31" aria-hidden="true" tabindex="-1"></a>        targets <span class="op">=</span> input_ids[:, <span class="dv">1</span>:].contiguous()</span>
<span id="cb9-32"><a href="#cb9-32" aria-hidden="true" tabindex="-1"></a>        input_ids <span class="op">=</span> input_ids[:, :<span class="op">-</span><span class="dv">1</span>].contiguous()</span>
<span id="cb9-33"><a href="#cb9-33" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-34"><a href="#cb9-34" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> torch.cuda.amp.autocast(enabled<span class="op">=</span><span class="va">self</span>.config.mixed_precision):</span>
<span id="cb9-35"><a href="#cb9-35" aria-hidden="true" tabindex="-1"></a>            logits <span class="op">=</span> <span class="va">self</span>.model(input_ids)</span>
<span id="cb9-36"><a href="#cb9-36" aria-hidden="true" tabindex="-1"></a>            loss <span class="op">=</span> F.cross_entropy(</span>
<span id="cb9-37"><a href="#cb9-37" aria-hidden="true" tabindex="-1"></a>                logits.view(<span class="op">-</span><span class="dv">1</span>, logits.size(<span class="op">-</span><span class="dv">1</span>)), </span>
<span id="cb9-38"><a href="#cb9-38" aria-hidden="true" tabindex="-1"></a>                targets.view(<span class="op">-</span><span class="dv">1</span>),</span>
<span id="cb9-39"><a href="#cb9-39" aria-hidden="true" tabindex="-1"></a>                ignore_index<span class="op">=-</span><span class="dv">1</span></span>
<span id="cb9-40"><a href="#cb9-40" aria-hidden="true" tabindex="-1"></a>            )</span>
<span id="cb9-41"><a href="#cb9-41" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-42"><a href="#cb9-42" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Backward pass with gradient scaling</span></span>
<span id="cb9-43"><a href="#cb9-43" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="va">self</span>.scaler:</span>
<span id="cb9-44"><a href="#cb9-44" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.scaler.scale(loss).backward()</span>
<span id="cb9-45"><a href="#cb9-45" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.scaler.unscale_(<span class="va">self</span>.optimizer)</span>
<span id="cb9-46"><a href="#cb9-46" aria-hidden="true" tabindex="-1"></a>            torch.nn.utils.clip_grad_norm_(<span class="va">self</span>.model.parameters(), <span class="fl">1.0</span>)</span>
<span id="cb9-47"><a href="#cb9-47" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.scaler.step(<span class="va">self</span>.optimizer)</span>
<span id="cb9-48"><a href="#cb9-48" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.scaler.update()</span>
<span id="cb9-49"><a href="#cb9-49" aria-hidden="true" tabindex="-1"></a>        <span class="cf">else</span>:</span>
<span id="cb9-50"><a href="#cb9-50" aria-hidden="true" tabindex="-1"></a>            loss.backward()</span>
<span id="cb9-51"><a href="#cb9-51" aria-hidden="true" tabindex="-1"></a>            torch.nn.utils.clip_grad_norm_(<span class="va">self</span>.model.parameters(), <span class="fl">1.0</span>)</span>
<span id="cb9-52"><a href="#cb9-52" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.optimizer.step()</span>
<span id="cb9-53"><a href="#cb9-53" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-54"><a href="#cb9-54" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.optimizer.zero_grad()</span>
<span id="cb9-55"><a href="#cb9-55" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.scheduler.step()</span>
<span id="cb9-56"><a href="#cb9-56" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-57"><a href="#cb9-57" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> loss.item()</span></code></pre></div></div>
</div>
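<p>Two pieces of arithmetic in the trainer are easy to check in isolation. First, <code>train_step</code> builds next-token targets by shifting the same sequence one position, so input position i is scored against token i + 1. Second, the warmup-plus-cosine multiplier from <code>lr_lambda</code> can be evaluated directly. The sketch below restates both as small standalone functions; the names <code>shift_for_next_token</code> and <code>lr_multiplier</code> are hypothetical, and <code>lr_multiplier</code> writes the two schedule phases as the minimum of a linear ramp and a cosine (with progress clamped to [0, 1]), which reproduces the piecewise warmup-then-cosine shape.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import math

import torch


def shift_for_next_token(input_ids):
    """Split a (batch, seq_len) tensor into model inputs and next-token targets."""
    inputs = input_ids[:, :-1].contiguous()
    targets = input_ids[:, 1:].contiguous()
    return inputs, targets


def lr_multiplier(step, warmup_steps=2000, max_steps=100000):
    """Linear warmup to 1.0, then cosine decay to 0.0 at max_steps."""
    ramp = step / warmup_steps
    progress = (step - warmup_steps) / (max_steps - warmup_steps)
    progress = min(max(progress, 0.0), 1.0)
    cosine = 0.5 * (1 + math.cos(math.pi * progress))
    # During warmup the ramp is below 1 while the cosine is exactly 1;
    # after warmup the ramp exceeds 1, so the cosine term takes over
    return min(ramp, cosine)


inputs, targets = shift_for_next_token(torch.tensor([[5, 6, 7, 8]]))
print(inputs.tolist(), targets.tolist())  # [[5, 6, 7]] [[6, 7, 8]]

for step in (0, 1000, 2000, 51000, 100000):
    print(step, round(lr_multiplier(step), 4))
# multipliers: 0.0, 0.5, 1.0, 0.5, 0.0
</code></pre></div>
<p>Because the multiplier is exactly 0 at step 0, the very first optimizer update uses a zero learning rate; that is harmless, but it is why the warmup length, rather than the peak rate alone, controls how gently training starts.</p>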
</section>
</section>
<section id="practical-applications" class="level2">
<h2 class="anchored" data-anchor-id="practical-applications" id="practical-applications">Practical Applications</h2>
<section id="text-generation" class="level3">
<h3 class="anchored" data-anchor-id="text-generation" id="text-generation">Text Generation</h3>
<div id="edc61fcc" class="cell" data-execution_count="9">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb10"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> generate_text(model, tokenizer, prompt, max_length<span class="op">=</span><span class="dv">100</span>, temperature<span class="op">=</span><span class="fl">0.8</span>):</span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a><span class="co">    Generate text using Mamba model</span></span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a><span class="co">    </span></span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a><span class="co">    Parameters:</span></span>
<span id="cb10-6"><a href="#cb10-6" aria-hidden="true" tabindex="-1"></a><span class="co">    -----------</span></span>
<span id="cb10-7"><a href="#cb10-7" aria-hidden="true" tabindex="-1"></a><span class="co">    model : Mamba</span></span>
<span id="cb10-8"><a href="#cb10-8" aria-hidden="true" tabindex="-1"></a><span class="co">        Trained Mamba model</span></span>
<span id="cb10-9"><a href="#cb10-9" aria-hidden="true" tabindex="-1"></a><span class="co">    tokenizer : Tokenizer</span></span>
<span id="cb10-10"><a href="#cb10-10" aria-hidden="true" tabindex="-1"></a><span class="co">        Text tokenizer</span></span>
<span id="cb10-11"><a href="#cb10-11" aria-hidden="true" tabindex="-1"></a><span class="co">    prompt : str</span></span>
<span id="cb10-12"><a href="#cb10-12" aria-hidden="true" tabindex="-1"></a><span class="co">        Input prompt</span></span>
<span id="cb10-13"><a href="#cb10-13" aria-hidden="true" tabindex="-1"></a><span class="co">    max_length : int</span></span>
<span id="cb10-14"><a href="#cb10-14" aria-hidden="true" tabindex="-1"></a><span class="co">        Maximum number of new tokens to generate</span></span>
<span id="cb10-15"><a href="#cb10-15" aria-hidden="true" tabindex="-1"></a><span class="co">    temperature : float</span></span>
<span id="cb10-16"><a href="#cb10-16" aria-hidden="true" tabindex="-1"></a><span class="co">        Sampling temperature (must be positive); lower values give more deterministic output</span></span>
<span id="cb10-17"><a href="#cb10-17" aria-hidden="true" tabindex="-1"></a><span class="co">        </span></span>
<span id="cb10-18"><a href="#cb10-18" aria-hidden="true" tabindex="-1"></a><span class="co">    Returns:</span></span>
<span id="cb10-19"><a href="#cb10-19" aria-hidden="true" tabindex="-1"></a><span class="co">    --------</span></span>
<span id="cb10-20"><a href="#cb10-20" aria-hidden="true" tabindex="-1"></a><span class="co">    str</span></span>
<span id="cb10-21"><a href="#cb10-21" aria-hidden="true" tabindex="-1"></a><span class="co">        Generated text</span></span>
<span id="cb10-22"><a href="#cb10-22" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb10-23"><a href="#cb10-23" aria-hidden="true" tabindex="-1"></a>    model.<span class="bu">eval</span>()</span>
<span id="cb10-24"><a href="#cb10-24" aria-hidden="true" tabindex="-1"></a>    device <span class="op">=</span> <span class="bu">next</span>(model.parameters()).device</span>
<span id="cb10-25"><a href="#cb10-25" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Tokenize prompt and move it to the model's device</span></span>
<span id="cb10-26"><a href="#cb10-26" aria-hidden="true" tabindex="-1"></a>    input_ids <span class="op">=</span> tokenizer.encode(prompt, return_tensors<span class="op">=</span><span class="st">'pt'</span>).to(device)</span>
<span id="cb10-27"><a href="#cb10-27" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-28"><a href="#cb10-28" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> torch.no_grad():</span>
<span id="cb10-29"><a href="#cb10-29" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(max_length):</span>
<span id="cb10-30"><a href="#cb10-30" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Forward pass over the whole sequence so far (simple, but recomputes the prefix every step)</span></span>
<span id="cb10-31"><a href="#cb10-31" aria-hidden="true" tabindex="-1"></a>            logits <span class="op">=</span> model(input_ids)</span>
<span id="cb10-32"><a href="#cb10-32" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb10-33"><a href="#cb10-33" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Sample next token</span></span>
<span id="cb10-34"><a href="#cb10-34" aria-hidden="true" tabindex="-1"></a>            next_token_logits <span class="op">=</span> logits[:, <span class="op">-</span><span class="dv">1</span>, :] <span class="op">/</span> temperature</span>
<span id="cb10-35"><a href="#cb10-35" aria-hidden="true" tabindex="-1"></a>            probs <span class="op">=</span> F.softmax(next_token_logits, dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb10-36"><a href="#cb10-36" aria-hidden="true" tabindex="-1"></a>            next_token <span class="op">=</span> torch.multinomial(probs, num_samples<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb10-37"><a href="#cb10-37" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb10-38"><a href="#cb10-38" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Append to sequence</span></span>
<span id="cb10-39"><a href="#cb10-39" aria-hidden="true" tabindex="-1"></a>            input_ids <span class="op">=</span> torch.cat([input_ids, next_token], dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb10-40"><a href="#cb10-40" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb10-41"><a href="#cb10-41" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Check for end token</span></span>
<span id="cb10-42"><a href="#cb10-42" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> next_token.item() <span class="op">==</span> tokenizer.eos_token_id:</span>
<span id="cb10-43"><a href="#cb10-43" aria-hidden="true" tabindex="-1"></a>                <span class="cf">break</span></span>
<span id="cb10-44"><a href="#cb10-44" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-45"><a href="#cb10-45" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> tokenizer.decode(input_ids[<span class="dv">0</span>], skip_special_tokens<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb10-46"><a href="#cb10-46" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-47"><a href="#cb10-47" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage example</span></span>
<span id="cb10-48"><a href="#cb10-48" aria-hidden="true" tabindex="-1"></a><span class="co"># prompt = "The future of artificial intelligence is"</span></span>
<span id="cb10-49"><a href="#cb10-49" aria-hidden="true" tabindex="-1"></a><span class="co"># generated = generate_text(model, tokenizer, prompt)</span></span>
<span id="cb10-50"><a href="#cb10-50" aria-hidden="true" tabindex="-1"></a><span class="co"># print(generated)</span></span></code></pre></div></div>
</div>
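<p>The <code>generate_text</code> loop above re-runs the model over the entire prefix at every step, so generating n tokens costs O(n<sup>2</sup>) work in total. Because a state space model is a recurrence, it can instead carry a fixed-size hidden state from one token to the next. The sketch below uses a toy, non-selective diagonal SSM with fixed random parameters <code>a</code>, <code>b</code>, <code>c</code> (an assumption for illustration; Mamba's parameters are input-dependent and its cache layout is implementation-specific) to check that stepping a cached state reproduces the outputs of a full recomputation.</p>

```python
import torch

# Toy diagonal linear SSM: h_t = a * h_{t-1} + b * x_t,  y_t = sum(c * h_t)
# Fixed random parameters are an illustrative assumption, not Mamba's selective SSM.
torch.manual_seed(0)
N = 16                   # state size
a = torch.rand(N) * 0.9  # per-channel decay factors in [0, 0.9)
b = torch.randn(N)
c = torch.randn(N)

def full_scan(xs):
    """Recompute the last output by scanning the whole prefix from a zero state."""
    h = torch.zeros(N)
    for x in xs:
        h = a * h + b * x
    return (c * h).sum()

def step(h, x):
    """Advance the cached state by one input: O(N) work regardless of prefix length."""
    h = a * h + b * x
    return h, (c * h).sum()

xs = torch.randn(50)  # a stream of scalar inputs
h = torch.zeros(N)
cached_outputs, recomputed_outputs = [], []
for t in range(len(xs)):
    h, y = step(h, xs[t])
    cached_outputs.append(y)
    recomputed_outputs.append(full_scan(xs[: t + 1]))

print(torch.allclose(torch.stack(cached_outputs), torch.stack(recomputed_outputs)))  # True
```

<p>The cached path does constant work per generated token, while the recomputation path grows with the prefix. This is the property that lets recurrent-mode inference avoid the quadratic total cost of the loop above.</p>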
</section>
<section id="document-classification" class="level3">
<h3 class="anchored" data-anchor-id="document-classification" id="document-classification">Document Classification</h3>
<div id="2a60ff73" class="cell" data-execution_count="10">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb11"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MambaClassifier(nn.Module):</span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Mamba-based document classifier"""</span></span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, mamba_model, num_classes):</span>
<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.mamba <span class="op">=</span> mamba_model</span>
<span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.classifier <span class="op">=</span> nn.Linear(mamba_model.d_model, num_classes)</span>
<span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-9"><a href="#cb11-9" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, input_ids, attention_mask<span class="op">=</span><span class="va">None</span>):</span>
<span id="cb11-10"><a href="#cb11-10" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""</span></span>
<span id="cb11-11"><a href="#cb11-11" aria-hidden="true" tabindex="-1"></a><span class="co">        Forward pass for classification</span></span>
<span id="cb11-12"><a href="#cb11-12" aria-hidden="true" tabindex="-1"></a><span class="co">        </span></span>
<span id="cb11-13"><a href="#cb11-13" aria-hidden="true" tabindex="-1"></a><span class="co">        Parameters:</span></span>
<span id="cb11-14"><a href="#cb11-14" aria-hidden="true" tabindex="-1"></a><span class="co">        -----------</span></span>
<span id="cb11-15"><a href="#cb11-15" aria-hidden="true" tabindex="-1"></a><span class="co">        input_ids : torch.Tensor</span></span>
<span id="cb11-16"><a href="#cb11-16" aria-hidden="true" tabindex="-1"></a><span class="co">            Input token IDs of shape (batch, seq_len)</span></span>
<span id="cb11-17"><a href="#cb11-17" aria-hidden="true" tabindex="-1"></a><span class="co">        attention_mask : torch.Tensor, optional</span></span>
<span id="cb11-18"><a href="#cb11-18" aria-hidden="true" tabindex="-1"></a><span class="co">            Padding mask (1 = real token, 0 = padding); Mamba has no attention, so it is used only for pooling</span></span>
<span id="cb11-19"><a href="#cb11-19" aria-hidden="true" tabindex="-1"></a><span class="co">            </span></span>
<span id="cb11-20"><a href="#cb11-20" aria-hidden="true" tabindex="-1"></a><span class="co">        Returns:</span></span>
<span id="cb11-21"><a href="#cb11-21" aria-hidden="true" tabindex="-1"></a><span class="co">        --------</span></span>
<span id="cb11-22"><a href="#cb11-22" aria-hidden="true" tabindex="-1"></a><span class="co">        torch.Tensor</span></span>
<span id="cb11-23"><a href="#cb11-23" aria-hidden="true" tabindex="-1"></a><span class="co">            Classification logits</span></span>
<span id="cb11-24"><a href="#cb11-24" aria-hidden="true" tabindex="-1"></a><span class="co">        """</span></span>
<span id="cb11-25"><a href="#cb11-25" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Get Mamba features</span></span>
<span id="cb11-26"><a href="#cb11-26" aria-hidden="true" tabindex="-1"></a>        hidden_states <span class="op">=</span> <span class="va">self</span>.mamba.embedding(input_ids)</span>
<span id="cb11-27"><a href="#cb11-27" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-28"><a href="#cb11-28" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> layer <span class="kw">in</span> <span class="va">self</span>.mamba.layers:</span>
<span id="cb11-29"><a href="#cb11-29" aria-hidden="true" tabindex="-1"></a>            hidden_states <span class="op">=</span> layer(hidden_states)</span>
<span id="cb11-30"><a href="#cb11-30" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-31"><a href="#cb11-31" aria-hidden="true" tabindex="-1"></a>        hidden_states <span class="op">=</span> <span class="va">self</span>.mamba.norm_f(hidden_states)</span>
<span id="cb11-32"><a href="#cb11-32" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-33"><a href="#cb11-33" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Mean-pool over the sequence, excluding padded positions when a mask is given</span></span>
<span id="cb11-34"><a href="#cb11-34" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> attention_mask <span class="kw">is</span> <span class="kw">not</span> <span class="va">None</span>:</span>
<span id="cb11-35"><a href="#cb11-35" aria-hidden="true" tabindex="-1"></a>            mask <span class="op">=</span> attention_mask.unsqueeze(<span class="op">-</span><span class="dv">1</span>).expand_as(hidden_states).<span class="bu">float</span>()</span>
<span id="cb11-36"><a href="#cb11-36" aria-hidden="true" tabindex="-1"></a>            pooled <span class="op">=</span> (hidden_states <span class="op">*</span> mask).<span class="bu">sum</span>(<span class="dv">1</span>) <span class="op">/</span> mask.<span class="bu">sum</span>(<span class="dv">1</span>).clamp(<span class="bu">min</span><span class="op">=</span><span class="fl">1e-9</span>)  <span class="co"># avoid 0/0 for all-padding rows</span></span>
<span id="cb11-37"><a href="#cb11-37" aria-hidden="true" tabindex="-1"></a>        <span class="cf">else</span>:</span>
<span id="cb11-38"><a href="#cb11-38" aria-hidden="true" tabindex="-1"></a>            pooled <span class="op">=</span> hidden_states.mean(<span class="dv">1</span>)</span>
<span id="cb11-39"><a href="#cb11-39" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-40"><a href="#cb11-40" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Classification</span></span>
<span id="cb11-41"><a href="#cb11-41" aria-hidden="true" tabindex="-1"></a>        logits <span class="op">=</span> <span class="va">self</span>.classifier(pooled)</span>
<span id="cb11-42"><a href="#cb11-42" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> logits</span></code></pre></div></div>
</div>
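<p>The pooling step in <code>MambaClassifier.forward</code> is a masked mean: padded positions are zeroed out and excluded from the denominator. Here is a standalone sketch of that computation on a small, hand-checkable batch. The function name <code>masked_mean_pool</code> is hypothetical.</p>

```python
import torch

def masked_mean_pool(hidden_states, mask):
    """Average hidden states over the sequence, ignoring positions where mask == 0.

    hidden_states: (batch, seq_len, d_model); mask: (batch, seq_len) of 0/1 values.
    """
    mask = mask.unsqueeze(-1).to(hidden_states.dtype)  # (batch, seq_len, 1), broadcasts over d_model
    summed = (hidden_states * mask).sum(dim=1)          # sum over real tokens only
    counts = mask.sum(dim=1).clamp(min=1e-9)            # number of real tokens, guarded against zero
    return summed / counts

# Two sequences; the second has one padding position at the end
hidden = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
                       [[1.0, 1.0], [3.0, 3.0], [100.0, 100.0]]])
mask = torch.tensor([[1, 1, 1],
                     [1, 1, 0]])

print(masked_mean_pool(hidden, mask))
# tensor([[3., 4.],
#         [2., 2.]])
```

<p>The padded vector <code>[100, 100]</code> has no influence on the second row. Because Mamba processes tokens causally, right-padding also leaves the hidden states of the real tokens unchanged, so the mask only has to keep padded positions out of the average.</p>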
</section>
</section>
<section id="performance-optimization" class="level2">
<h2 class="anchored" data-anchor-id="performance-optimization" id="performance-optimization">Performance Optimization</h2>
<section id="memory-optimization" class="level3">
<h3 class="anchored" data-anchor-id="memory-optimization" id="memory-optimization">Memory Optimization</h3>
<div id="7fd35770" class="cell" data-execution_count="11">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb12"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><a href="#cb12-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> OptimizedMamba(Mamba):</span>
<span id="cb12-2"><a href="#cb12-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Memory-optimized Mamba with gradient checkpointing"""</span></span>
<span id="cb12-3"><a href="#cb12-3" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-4"><a href="#cb12-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, <span class="op">*</span>args, <span class="op">**</span>kwargs):</span>
<span id="cb12-5"><a href="#cb12-5" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>(<span class="op">*</span>args, <span class="op">**</span>kwargs)</span>
<span id="cb12-6"><a href="#cb12-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.gradient_checkpointing <span class="op">=</span> <span class="va">True</span></span>
<span id="cb12-7"><a href="#cb12-7" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-8"><a href="#cb12-8" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, input_ids):</span>
<span id="cb12-9"><a href="#cb12-9" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Forward pass with optional gradient checkpointing"""</span></span>
<span id="cb12-10"><a href="#cb12-10" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.embedding(input_ids)</span>
<span id="cb12-11"><a href="#cb12-11" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-12"><a href="#cb12-12" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Recompute each layer's activations during backward instead of storing them (requires import torch.utils.checkpoint)</span></span>
<span id="cb12-13"><a href="#cb12-13" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> layer <span class="kw">in</span> <span class="va">self</span>.layers:</span>
<span id="cb12-14"><a href="#cb12-14" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="va">self</span>.gradient_checkpointing <span class="kw">and</span> <span class="va">self</span>.training:</span>
<span id="cb12-15"><a href="#cb12-15" aria-hidden="true" tabindex="-1"></a>                x <span class="op">=</span> torch.utils.checkpoint.checkpoint(layer, x, use_reentrant<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb12-16"><a href="#cb12-16" aria-hidden="true" tabindex="-1"></a>            <span class="cf">else</span>:</span>
<span id="cb12-17"><a href="#cb12-17" aria-hidden="true" tabindex="-1"></a>                x <span class="op">=</span> layer(x)</span>
<span id="cb12-18"><a href="#cb12-18" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb12-19"><a href="#cb12-19" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.norm_f(x)</span>
<span id="cb12-20"><a href="#cb12-20" aria-hidden="true" tabindex="-1"></a>        logits <span class="op">=</span> <span class="va">self</span>.lm_head(x)</span>
<span id="cb12-21"><a href="#cb12-21" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-22"><a href="#cb12-22" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> logits</span>
<span id="cb12-23"><a href="#cb12-23" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-24"><a href="#cb12-24" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> profile_memory(model, input_size):</span>
<span id="cb12-25"><a href="#cb12-25" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb12-26"><a href="#cb12-26" aria-hidden="true" tabindex="-1"></a><span class="co">    Profile memory usage of the model</span></span>
<span id="cb12-27"><a href="#cb12-27" aria-hidden="true" tabindex="-1"></a><span class="co">    </span></span>
<span id="cb12-28"><a href="#cb12-28" aria-hidden="true" tabindex="-1"></a><span class="co">    Parameters:</span></span>
<span id="cb12-29"><a href="#cb12-29" aria-hidden="true" tabindex="-1"></a><span class="co">    -----------</span></span>
<span id="cb12-30"><a href="#cb12-30" aria-hidden="true" tabindex="-1"></a><span class="co">    model : nn.Module</span></span>
<span id="cb12-31"><a href="#cb12-31" aria-hidden="true" tabindex="-1"></a><span class="co">        Model to profile</span></span>
<span id="cb12-32"><a href="#cb12-32" aria-hidden="true" tabindex="-1"></a><span class="co">    input_size : tuple</span></span>
<span id="cb12-33"><a href="#cb12-33" aria-hidden="true" tabindex="-1"></a><span class="co">        Shape of the dummy token-ID batch, e.g. (batch_size, seq_len)</span></span>
<span id="cb12-34"><a href="#cb12-34" aria-hidden="true" tabindex="-1"></a><span class="co">        </span></span>
<span id="cb12-35"><a href="#cb12-35" aria-hidden="true" tabindex="-1"></a><span class="co">    Returns:</span></span>
<span id="cb12-36"><a href="#cb12-36" aria-hidden="true" tabindex="-1"></a><span class="co">    --------</span></span>
<span id="cb12-37"><a href="#cb12-37" aria-hidden="true" tabindex="-1"></a><span class="co">    float</span></span>
<span id="cb12-38"><a href="#cb12-38" aria-hidden="true" tabindex="-1"></a><span class="co">        Peak memory usage in GB</span></span>
<span id="cb12-39"><a href="#cb12-39" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb12-40"><a href="#cb12-40" aria-hidden="true" tabindex="-1"></a>    device <span class="op">=</span> <span class="bu">next</span>(model.parameters()).device  <span class="co"># expects the model on a CUDA device</span></span>
<span id="cb12-41"><a href="#cb12-41" aria-hidden="true" tabindex="-1"></a>    dummy_input <span class="op">=</span> torch.randint(<span class="dv">0</span>, model.vocab_size, input_size, device<span class="op">=</span>device)</span>
<span id="cb12-42"><a href="#cb12-42" aria-hidden="true" tabindex="-1"></a>    torch.cuda.reset_peak_memory_stats(device)</span>
<span id="cb12-43"><a href="#cb12-43" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Mixed-precision forward pass; the backward pass runs outside autocast</span></span>
<span id="cb12-44"><a href="#cb12-44" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> torch.autocast(device_type<span class="op">=</span><span class="st">'cuda'</span>):</span>
<span id="cb12-45"><a href="#cb12-45" aria-hidden="true" tabindex="-1"></a>        output <span class="op">=</span> model(dummy_input)</span>
<span id="cb12-46"><a href="#cb12-46" aria-hidden="true" tabindex="-1"></a>        loss <span class="op">=</span> output.<span class="bu">sum</span>()</span>
<span id="cb12-47"><a href="#cb12-47" aria-hidden="true" tabindex="-1"></a>    loss.backward()</span>
<span id="cb12-48"><a href="#cb12-48" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-49"><a href="#cb12-49" aria-hidden="true" tabindex="-1"></a>    peak_memory <span class="op">=</span> torch.cuda.max_memory_allocated() <span class="op">/</span> <span class="dv">1024</span><span class="op">**</span><span class="dv">3</span>  <span class="co"># GB</span></span>
<span id="cb12-50"><a href="#cb12-50" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Peak memory usage: </span><span class="sc">{</span>peak_memory<span class="sc">:.2f}</span><span class="ss"> GB"</span>)</span>
<span id="cb12-51"><a href="#cb12-51" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-52"><a href="#cb12-52" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> peak_memory</span></code></pre></div></div>
</div>
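<p>Gradient checkpointing changes how much memory the backward pass needs, not what it computes. A quick CPU check confirms that gradients match with and without <code>torch.utils.checkpoint.checkpoint</code>. It uses a small stand-in stack of <code>nn.Linear</code> + <code>nn.GELU</code> blocks, which is an assumption for illustration rather than the Mamba layers themselves.</p>

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Small stand-in for a stack of Mamba layers (illustrative assumption)
torch.manual_seed(0)
layers = nn.ModuleList([nn.Sequential(nn.Linear(32, 32), nn.GELU()) for _ in range(4)])
x = torch.randn(8, 32)

def run(use_checkpoint):
    """One forward/backward pass; returns a copy of every parameter gradient."""
    layers.zero_grad()
    h = x
    for layer in layers:
        # With checkpointing, activations inside `layer` are discarded after the forward
        # pass and recomputed during backward, trading extra compute for lower memory.
        h = checkpoint(layer, h, use_reentrant=False) if use_checkpoint else layer(h)
    h.sum().backward()
    return [p.grad.clone() for p in layers.parameters()]

plain_grads = run(use_checkpoint=False)
ckpt_grads = run(use_checkpoint=True)
print(all(torch.allclose(g1, g2) for g1, g2 in zip(plain_grads, ckpt_grads)))  # True
```

<p><code>use_reentrant=False</code> is the variant PyTorch recommends. Unlike the reentrant variant, it still propagates gradients to a checkpointed layer's parameters when the layer's input tensor does not itself require grad, as with <code>x</code> here.</p>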
</section>
</section>
<section id="performance-comparison" class="level2">
<h2 class="anchored" data-anchor-id="performance-comparison" id="performance-comparison">Performance Comparison</h2>
<section id="complexity-analysis" class="level3">
<h3 class="anchored" data-anchor-id="complexity-analysis" id="complexity-analysis">Complexity Analysis</h3>
<table class="caption-top table">
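<caption>Per-layer computational complexity of Transformer self-attention versus Mamba, for sequence length <span class="math inline">\(L\)</span> and model dimension <span class="math inline">\(d\)</span> (Mamba's SSM state size is treated as a small constant)</caption>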
<thead>
<tr class="header">
<th>Metric</th>
<th>Transformer</th>
<th>Mamba</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>Time Complexity</td>
<td><span class="math inline">\(O(L^2d)\)</span></td>
<td><span class="math inline">\(O(Ld)\)</span></td>
</tr>
<tr class="even">
<td>Memory Complexity</td>
<td><span class="math inline">\(O(L^2)\)</span></td>
<td><span class="math inline">\(O(L)\)</span></td>
</tr>
<tr class="odd">
<td>Parallelization</td>
<td>High (attention)</td>
<td>Medium (selective scan)</td>
</tr>
<tr class="even">
<td>Long Context Scaling</td>
<td>Quadratic</td>
<td>Linear</td>
</tr>
</tbody>
</table>
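<p>To make the table concrete, the sketch below compares only the leading-order cost terms. It ignores constant factors, projections, and hardware effects. It assumes a model dimension d = 768 and an SSM state size N = 16 (an assumed, typical value), writing the attention cost as L<sup>2</sup>d and the selective-scan cost as LdN. The ratio between the two is then simply L / N, so each doubling of the sequence length doubles attention's relative cost.</p>

```python
# Leading-order operation counts per layer (simplifying assumption: constants ignored)
def attention_cost(L, d):
    """Attention scores: every position interacts with every other position."""
    return L * L * d

def ssm_scan_cost(L, d, N=16):
    """Selective scan: one state update of size d * N per position."""
    return L * d * N

d = 768
for L in [1024, 2048, 4096, 8192]:
    ratio = attention_cost(L, d) / ssm_scan_cost(L, d)
    print(f"L={L:5d}  attention/scan cost ratio = {ratio:.0f}")
```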
</section>
<section id="benchmarking-implementation" class="level3">
<h3 class="anchored" data-anchor-id="benchmarking-implementation" id="benchmarking-implementation">Benchmarking Implementation</h3>
<div id="6d1aae94" class="cell" data-execution_count="12">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb13"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> benchmark_models():</span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a><span class="co">    Compare Mamba vs Transformer performance across sequence lengths</span></span>
<span id="cb13-4"><a href="#cb13-4" aria-hidden="true" tabindex="-1"></a><span class="co">    </span></span>
<span id="cb13-5"><a href="#cb13-5" aria-hidden="true" tabindex="-1"></a><span class="co">    Returns:</span></span>
<span id="cb13-6"><a href="#cb13-6" aria-hidden="true" tabindex="-1"></a><span class="co">    --------</span></span>
<span id="cb13-7"><a href="#cb13-7" aria-hidden="true" tabindex="-1"></a><span class="co">    dict</span></span>
<span id="cb13-8"><a href="#cb13-8" aria-hidden="true" tabindex="-1"></a><span class="co">        Benchmark results containing memory and time measurements</span></span>
<span id="cb13-9"><a href="#cb13-9" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb13-10"><a href="#cb13-10" aria-hidden="true" tabindex="-1"></a>    sequence_lengths <span class="op">=</span> [<span class="dv">512</span>, <span class="dv">1024</span>, <span class="dv">2048</span>, <span class="dv">4096</span>, <span class="dv">8192</span>]</span>
<span id="cb13-11"><a href="#cb13-11" aria-hidden="true" tabindex="-1"></a>    results <span class="op">=</span> {</span>
<span id="cb13-12"><a href="#cb13-12" aria-hidden="true" tabindex="-1"></a>        <span class="st">'mamba'</span>: {<span class="st">'memory'</span>: [], <span class="st">'time'</span>: []},</span>
<span id="cb13-13"><a href="#cb13-13" aria-hidden="true" tabindex="-1"></a>        <span class="st">'transformer'</span>: {<span class="st">'memory'</span>: [], <span class="st">'time'</span>: []}</span>
<span id="cb13-14"><a href="#cb13-14" aria-hidden="true" tabindex="-1"></a>    }</span>
<span id="cb13-15"><a href="#cb13-15" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-16"><a href="#cb13-16" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> seq_len <span class="kw">in</span> sequence_lengths:</span>
<span id="cb13-17"><a href="#cb13-17" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Benchmark Mamba</span></span>
<span id="cb13-18"><a href="#cb13-18" aria-hidden="true" tabindex="-1"></a>        mamba_model <span class="op">=</span> Mamba(d_model<span class="op">=</span><span class="dv">768</span>, n_layer<span class="op">=</span><span class="dv">12</span>, vocab_size<span class="op">=</span><span class="dv">50257</span>).cuda()</span>
<span id="cb13-19"><a href="#cb13-19" aria-hidden="true" tabindex="-1"></a>        mamba_memory, mamba_time <span class="op">=</span> benchmark_single_model(mamba_model, seq_len)</span>
<span id="cb13-20"><a href="#cb13-20" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-21"><a href="#cb13-21" aria-hidden="true" tabindex="-1"></a>        <span class="co"># A Transformer baseline needs its own model, e.g. Hugging Face GPT-2 (limited to 1024 positions)</span></span>
<span id="cb13-22"><a href="#cb13-22" aria-hidden="true" tabindex="-1"></a>        <span class="co"># transformer_model = GPT2LMHeadModel.from_pretrained('gpt2').cuda()</span></span>
<span id="cb13-23"><a href="#cb13-23" aria-hidden="true" tabindex="-1"></a>        <span class="co"># transformer_memory, transformer_time = benchmark_single_model(transformer_model, seq_len)</span></span>
<span id="cb13-24"><a href="#cb13-24" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-25"><a href="#cb13-25" aria-hidden="true" tabindex="-1"></a>        results[<span class="st">'mamba'</span>][<span class="st">'memory'</span>].append(mamba_memory)</span>
<span id="cb13-26"><a href="#cb13-26" aria-hidden="true" tabindex="-1"></a>        results[<span class="st">'mamba'</span>][<span class="st">'time'</span>].append(mamba_time)</span>
<span id="cb13-27"><a href="#cb13-27" aria-hidden="true" tabindex="-1"></a>        <span class="co"># results['transformer']['memory'].append(transformer_memory)</span></span>
<span id="cb13-28"><a href="#cb13-28" aria-hidden="true" tabindex="-1"></a>        <span class="co"># results['transformer']['time'].append(transformer_time)</span></span>
<span id="cb13-29"><a href="#cb13-29" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-30"><a href="#cb13-30" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> results</span>
<span id="cb13-31"><a href="#cb13-31" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-32"><a href="#cb13-32" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> benchmark_single_model(model, seq_len):</span>
<span id="cb13-33"><a href="#cb13-33" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb13-34"><a href="#cb13-34" aria-hidden="true" tabindex="-1"></a><span class="co">    Measure peak GPU memory and wall-clock time of one forward/backward pass</span></span>
<span id="cb13-35"><a href="#cb13-35" aria-hidden="true" tabindex="-1"></a><span class="co">    </span></span>
<span id="cb13-36"><a href="#cb13-36" aria-hidden="true" tabindex="-1"></a><span class="co">    Parameters:</span></span>
<span id="cb13-37"><a href="#cb13-37" aria-hidden="true" tabindex="-1"></a><span class="co">    -----------</span></span>
<span id="cb13-38"><a href="#cb13-38" aria-hidden="true" tabindex="-1"></a><span class="co">    model : nn.Module</span></span>
<span id="cb13-39"><a href="#cb13-39" aria-hidden="true" tabindex="-1"></a><span class="co">        Model to benchmark</span></span>
<span id="cb13-40"><a href="#cb13-40" aria-hidden="true" tabindex="-1"></a><span class="co">    seq_len : int</span></span>
<span id="cb13-41"><a href="#cb13-41" aria-hidden="true" tabindex="-1"></a><span class="co">        Sequence length to test</span></span>
<span id="cb13-42"><a href="#cb13-42" aria-hidden="true" tabindex="-1"></a><span class="co">        </span></span>
<span id="cb13-43"><a href="#cb13-43" aria-hidden="true" tabindex="-1"></a><span class="co">    Returns:</span></span>
<span id="cb13-44"><a href="#cb13-44" aria-hidden="true" tabindex="-1"></a><span class="co">    --------</span></span>
<span id="cb13-45"><a href="#cb13-45" aria-hidden="true" tabindex="-1"></a><span class="co">    tuple</span></span>
<span id="cb13-46"><a href="#cb13-46" aria-hidden="true" tabindex="-1"></a><span class="co">        (memory_usage_gb, time_seconds)</span></span>
<span id="cb13-47"><a href="#cb13-47" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb13-48"><a href="#cb13-48" aria-hidden="true" tabindex="-1"></a>    <span class="im">import</span> time</span>
<span id="cb13-49"><a href="#cb13-49" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-50"><a href="#cb13-50" aria-hidden="true" tabindex="-1"></a>    batch_size <span class="op">=</span> <span class="dv">8</span></span>
<span id="cb13-51"><a href="#cb13-51" aria-hidden="true" tabindex="-1"></a>    vocab_size <span class="op">=</span> <span class="bu">getattr</span>(model, <span class="st">'vocab_size'</span>, <span class="dv">50257</span>)</span>
<span id="cb13-52"><a href="#cb13-52" aria-hidden="true" tabindex="-1"></a>    device <span class="op">=</span> <span class="bu">next</span>(model.parameters()).device  <span class="co"># assumes the model is already on a CUDA device</span></span>
<span id="cb13-53"><a href="#cb13-53" aria-hidden="true" tabindex="-1"></a>    input_ids <span class="op">=</span> torch.randint(<span class="dv">0</span>, vocab_size, (batch_size, seq_len), device<span class="op">=</span>device)</span>
<span id="cb13-54"><a href="#cb13-54" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Reset the peak-memory counter and drain pending GPU work before timing</span></span>
<span id="cb13-55"><a href="#cb13-55" aria-hidden="true" tabindex="-1"></a>    torch.cuda.reset_peak_memory_stats(device)</span>
<span id="cb13-56"><a href="#cb13-56" aria-hidden="true" tabindex="-1"></a>    torch.cuda.synchronize(device)</span>
<span id="cb13-57"><a href="#cb13-57" aria-hidden="true" tabindex="-1"></a>    start_time <span class="op">=</span> time.perf_counter()</span>
<span id="cb13-58"><a href="#cb13-58" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> torch.autocast(device_type<span class="op">=</span><span class="st">'cuda'</span>):</span>
<span id="cb13-59"><a href="#cb13-59" aria-hidden="true" tabindex="-1"></a>        output <span class="op">=</span> model(input_ids)</span>
<span id="cb13-60"><a href="#cb13-60" aria-hidden="true" tabindex="-1"></a>        loss <span class="op">=</span> output.logits.mean() <span class="cf">if</span> <span class="bu">hasattr</span>(output, <span class="st">'logits'</span>) <span class="cf">else</span> output.mean()</span>
<span id="cb13-61"><a href="#cb13-61" aria-hidden="true" tabindex="-1"></a>    loss.backward()  <span class="co"># backward pass runs outside autocast</span></span>
<span id="cb13-62"><a href="#cb13-62" aria-hidden="true" tabindex="-1"></a>    <span class="co"># CUDA kernels run asynchronously, so synchronize before stopping the clock</span></span>
<span id="cb13-63"><a href="#cb13-63" aria-hidden="true" tabindex="-1"></a>    torch.cuda.synchronize(device)</span>
<span id="cb13-64"><a href="#cb13-64" aria-hidden="true" tabindex="-1"></a>    end_time <span class="op">=</span> time.perf_counter()</span>
<span id="cb13-65"><a href="#cb13-65" aria-hidden="true" tabindex="-1"></a>    memory_used <span class="op">=</span> torch.cuda.max_memory_allocated(device) <span class="op">/</span> <span class="dv">1024</span><span class="op">**</span><span class="dv">3</span>  <span class="co"># GB</span></span>
<span id="cb13-66"><a href="#cb13-66" aria-hidden="true" tabindex="-1"></a>    time_taken <span class="op">=</span> end_time <span class="op">-</span> start_time</span>
<span id="cb13-67"><a href="#cb13-67" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-68"><a href="#cb13-68" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> memory_used, time_taken</span></code></pre></div></div>
</div>
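<p>Timing a single call is noisy. The first iteration pays one-off costs such as kernel selection and allocator growth. CUDA also returns control to Python before queued kernels finish. Below is a minimal, framework-agnostic sketch of a steadier timer. The helper name <code>benchmark</code> and its defaults are hypothetical. On a GPU you would pass <code>torch.cuda.synchronize</code> as <code>sync</code>, and a closure that runs one forward and backward pass as <code>step</code>.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import statistics
import time


def benchmark(step, sync=None, warmup=3, iters=10):
    """Hypothetical helper: median wall-clock seconds of step() over iters runs.

    step : zero-argument callable that runs one forward/backward pass.
    sync : optional zero-argument callable that blocks until queued device
           work finishes (e.g. torch.cuda.synchronize on a GPU).
    """
    # Warm-up runs absorb one-off costs before measurement starts.
    for _ in range(warmup):
        step()
    timings = []
    for _ in range(iters):
        if sync is not None:
            sync()  # do not count work queued before the timer starts
        start = time.perf_counter()
        step()
        if sync is not None:
            sync()  # wait for asynchronous kernels before reading the clock
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)


# CPU-only usage: time a small pure-Python workload.
elapsed = benchmark(lambda: sum(i * i for i in range(10_000)), warmup=1, iters=5)
print(f"median step time: {elapsed * 1e3:.3f} ms")
</code></pre></div>
<p>Reporting the median rather than the mean keeps a single slow outlier, such as a garbage-collection pause, from dominating the result.</p>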
</section>
</section>
<section id="advanced-extensions" class="level2">
<h2 class="anchored" data-anchor-id="advanced-extensions" id="advanced-extensions">Advanced Extensions</h2>
<section id="multi-modal-mamba" class="level3">
<h3 class="anchored" data-anchor-id="multi-modal-mamba" id="multi-modal-mamba">Multi-Modal Mamba</h3>
<div id="fbc6f79e" class="cell" data-execution_count="13">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb14"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb14-1"><a href="#cb14-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MultiModalMamba(nn.Module):</span>
<span id="cb14-2"><a href="#cb14-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Multi-modal Mamba for text and vision processing"""</span></span>
<span id="cb14-3"><a href="#cb14-3" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb14-4"><a href="#cb14-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, text_vocab_size, d_model, n_layer, vision_dim<span class="op">=</span><span class="dv">768</span>):</span>
<span id="cb14-5"><a href="#cb14-5" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb14-6"><a href="#cb14-6" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-7"><a href="#cb14-7" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Text processing</span></span>
<span id="cb14-8"><a href="#cb14-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.text_embedding <span class="op">=</span> nn.Embedding(text_vocab_size, d_model)</span>
<span id="cb14-9"><a href="#cb14-9" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-10"><a href="#cb14-10" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Vision processing</span></span>
<span id="cb14-11"><a href="#cb14-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.vision_encoder <span class="op">=</span> nn.Linear(vision_dim, d_model)  <span class="co"># Projects vision-encoder features (e.g. 768-dim ViT-Base) to d_model</span></span>
<span id="cb14-12"><a href="#cb14-12" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-13"><a href="#cb14-13" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Shared Mamba layers</span></span>
<span id="cb14-14"><a href="#cb14-14" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.mamba_layers <span class="op">=</span> nn.ModuleList([</span>
<span id="cb14-15"><a href="#cb14-15" aria-hidden="true" tabindex="-1"></a>            MambaBlock(d_model) <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(n_layer)</span>
<span id="cb14-16"><a href="#cb14-16" aria-hidden="true" tabindex="-1"></a>        ])</span>
<span id="cb14-17"><a href="#cb14-17" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-18"><a href="#cb14-18" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Modality fusion</span></span>
<span id="cb14-19"><a href="#cb14-19" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.fusion_layer <span class="op">=</span> nn.Linear(d_model <span class="op">*</span> <span class="dv">2</span>, d_model)</span>
<span id="cb14-20"><a href="#cb14-20" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-21"><a href="#cb14-21" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, text_ids, vision_features):</span>
<span id="cb14-22"><a href="#cb14-22" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""</span></span>
<span id="cb14-23"><a href="#cb14-23" aria-hidden="true" tabindex="-1"></a><span class="co">        Process multi-modal inputs</span></span>
<span id="cb14-24"><a href="#cb14-24" aria-hidden="true" tabindex="-1"></a><span class="co">        </span></span>
<span id="cb14-25"><a href="#cb14-25" aria-hidden="true" tabindex="-1"></a><span class="co">        Parameters:</span></span>
<span id="cb14-26"><a href="#cb14-26" aria-hidden="true" tabindex="-1"></a><span class="co">        -----------</span></span>
<span id="cb14-27"><a href="#cb14-27" aria-hidden="true" tabindex="-1"></a><span class="co">        text_ids : torch.Tensor</span></span>
<span id="cb14-28"><a href="#cb14-28" aria-hidden="true" tabindex="-1"></a><span class="co">            Token ids of shape (batch_size, seq_len)</span></span>
<span id="cb14-29"><a href="#cb14-29" aria-hidden="true" tabindex="-1"></a><span class="co">        vision_features : torch.Tensor</span></span>
<span id="cb14-30"><a href="#cb14-30" aria-hidden="true" tabindex="-1"></a><span class="co">            Vision-encoder features of shape (batch_size, seq_len, vision_dim);</span></span>
<span id="cb14-31"><a href="#cb14-31" aria-hidden="true" tabindex="-1"></a><span class="co">            seq_len must equal that of text_ids because fusion is per position</span></span>
<span id="cb14-32"><a href="#cb14-32" aria-hidden="true" tabindex="-1"></a><span class="co">        Returns:</span></span>
<span id="cb14-33"><a href="#cb14-33" aria-hidden="true" tabindex="-1"></a><span class="co">        --------</span></span>
<span id="cb14-34"><a href="#cb14-34" aria-hidden="true" tabindex="-1"></a><span class="co">        torch.Tensor</span></span>
<span id="cb14-35"><a href="#cb14-35" aria-hidden="true" tabindex="-1"></a><span class="co">            Fused representations of shape (batch_size, seq_len, d_model)</span></span>
<span id="cb14-36"><a href="#cb14-36" aria-hidden="true" tabindex="-1"></a><span class="co">        """</span></span>
<span id="cb14-37"><a href="#cb14-37" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Process text</span></span>
<span id="cb14-38"><a href="#cb14-38" aria-hidden="true" tabindex="-1"></a>        text_embeds <span class="op">=</span> <span class="va">self</span>.text_embedding(text_ids)</span>
<span id="cb14-39"><a href="#cb14-39" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-40"><a href="#cb14-40" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Process vision</span></span>
<span id="cb14-41"><a href="#cb14-41" aria-hidden="true" tabindex="-1"></a>        vision_embeds <span class="op">=</span> <span class="va">self</span>.vision_encoder(vision_features)</span>
<span id="cb14-42"><a href="#cb14-42" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-43"><a href="#cb14-43" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Combine modalities per position along the feature axis (sequence lengths must match)</span></span>
<span id="cb14-44"><a href="#cb14-44" aria-hidden="true" tabindex="-1"></a>        combined <span class="op">=</span> torch.cat([text_embeds, vision_embeds], dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb14-45"><a href="#cb14-45" aria-hidden="true" tabindex="-1"></a>        fused <span class="op">=</span> <span class="va">self</span>.fusion_layer(combined)</span>
<span id="cb14-46"><a href="#cb14-46" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-47"><a href="#cb14-47" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Process through Mamba</span></span>
<span id="cb14-48"><a href="#cb14-48" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> layer <span class="kw">in</span> <span class="va">self</span>.mamba_layers:</span>
<span id="cb14-49"><a href="#cb14-49" aria-hidden="true" tabindex="-1"></a>            fused <span class="op">=</span> layer(fused)</span>
<span id="cb14-50"><a href="#cb14-50" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb14-51"><a href="#cb14-51" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> fused</span></code></pre></div></div>
</div>
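<p><code>MultiModalMamba.forward</code> concatenates text and vision along the feature axis. That only works when both inputs have the same sequence length. When they differ (for example 49 image patches and 12 text tokens), one common alternative is to concatenate along the sequence axis instead. Below is a minimal sketch, assuming a 768-dimensional vision-encoder output. The class name <code>SequenceConcatFusion</code> is hypothetical. Its output would be fed to the Mamba layers in place of <code>fused</code>.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import torch
import torch.nn as nn


class SequenceConcatFusion(nn.Module):
    """Hypothetical fusion: project vision tokens to d_model and prepend them
    to the text embeddings along the sequence axis, so lengths may differ."""

    def __init__(self, text_vocab_size, d_model, vision_dim=768):
        super().__init__()
        self.text_embedding = nn.Embedding(text_vocab_size, d_model)
        self.vision_proj = nn.Linear(vision_dim, d_model)

    def forward(self, text_ids, vision_features):
        # text_ids: (batch, L_text); vision_features: (batch, L_vis, vision_dim)
        text_embeds = self.text_embedding(text_ids)        # (batch, L_text, d_model)
        vision_embeds = self.vision_proj(vision_features)  # (batch, L_vis, d_model)
        # dim=1 is the sequence axis: the result has L_vis + L_text positions
        return torch.cat([vision_embeds, text_embeds], dim=1)


fusion = SequenceConcatFusion(text_vocab_size=1000, d_model=64)
text_ids = torch.randint(0, 1000, (2, 12))
vision_features = torch.randn(2, 49, 768)  # e.g. a 7x7 grid of patch features
tokens = fusion(text_ids, vision_features)
print(tokens.shape)  # torch.Size([2, 61, 64])
</code></pre></div>
<p>A Mamba scan is causal, so placing the vision tokens first lets every text position condition on the whole image. The ordering is a design choice, not a requirement.</p>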
</section>
<section id="sparse-mamba-implementation" class="level3">
<h3 class="anchored" data-anchor-id="sparse-mamba-implementation" id="sparse-mamba-implementation">Sparse Mamba Implementation</h3>
<div id="82aadec4" class="cell" data-execution_count="14">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb15"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb15-1"><a href="#cb15-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> SparseMamba(MambaBlock):</span>
<span id="cb15-2"><a href="#cb15-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Mamba block whose A matrix is multiplied by a fixed random 0/1 mask (sparsity_ratio = fraction of entries zeroed)"""</span></span>
<span id="cb15-3"><a href="#cb15-3" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-4"><a href="#cb15-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, <span class="op">*</span>args, sparsity_ratio<span class="op">=</span><span class="fl">0.1</span>, <span class="op">**</span>kwargs):</span>
<span id="cb15-5"><a href="#cb15-5" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>(<span class="op">*</span>args, <span class="op">**</span>kwargs)</span>
<span id="cb15-6"><a href="#cb15-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.sparsity_ratio <span class="op">=</span> sparsity_ratio</span>
<span id="cb15-7"><a href="#cb15-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.register_buffer(<span class="st">'sparsity_mask'</span>, torch.ones(<span class="va">self</span>.d_inner, <span class="va">self</span>.d_state))</span>
<span id="cb15-8"><a href="#cb15-8" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-9"><a href="#cb15-9" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Initialize sparse connectivity</span></span>
<span id="cb15-10"><a href="#cb15-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>._initialize_sparse_mask()</span>
<span id="cb15-11"><a href="#cb15-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-12"><a href="#cb15-12" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _initialize_sparse_mask(<span class="va">self</span>):</span>
<span id="cb15-13"><a href="#cb15-13" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Initialize sparse connectivity pattern"""</span></span>
<span id="cb15-14"><a href="#cb15-14" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Random sparsity pattern</span></span>
<span id="cb15-15"><a href="#cb15-15" aria-hidden="true" tabindex="-1"></a>        num_connections <span class="op">=</span> <span class="bu">int</span>(<span class="va">self</span>.d_inner <span class="op">*</span> <span class="va">self</span>.d_state <span class="op">*</span> (<span class="dv">1</span> <span class="op">-</span> <span class="va">self</span>.sparsity_ratio))</span>
<span id="cb15-16"><a href="#cb15-16" aria-hidden="true" tabindex="-1"></a>        flat_mask <span class="op">=</span> torch.zeros(<span class="va">self</span>.d_inner <span class="op">*</span> <span class="va">self</span>.d_state, device<span class="op">=</span><span class="va">self</span>.sparsity_mask.device)</span>
<span id="cb15-17"><a href="#cb15-17" aria-hidden="true" tabindex="-1"></a>        indices <span class="op">=</span> torch.randperm(<span class="va">self</span>.d_inner <span class="op">*</span> <span class="va">self</span>.d_state, device<span class="op">=</span><span class="va">self</span>.sparsity_mask.device)[:num_connections]</span>
<span id="cb15-18"><a href="#cb15-18" aria-hidden="true" tabindex="-1"></a>        flat_mask[indices] <span class="op">=</span> <span class="dv">1</span></span>
<span id="cb15-19"><a href="#cb15-19" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.sparsity_mask.copy_(flat_mask.view(<span class="va">self</span>.d_inner, <span class="va">self</span>.d_state))  <span class="co"># update the registered buffer in place</span></span>
<span id="cb15-20"><a href="#cb15-20" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-21"><a href="#cb15-21" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> ssm(<span class="va">self</span>, x):</span>
<span id="cb15-22"><a href="#cb15-22" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""SSM computation with sparse connections"""</span></span>
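<span id="cb15-23"><a href="#cb15-23" aria-hidden="true" tabindex="-1"></a>        batch_size, seq_len, _ <span class="op">=</span> x.shape  <span class="co"># distinct names so B below is unambiguously the SSM input matrix</span></span>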
<span id="cb15-24"><a href="#cb15-24" aria-hidden="true" tabindex="-1"></a>        N <span class="op">=</span> <span class="va">self</span>.d_state</span>
<span id="cb15-25"><a href="#cb15-25" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-26"><a href="#cb15-26" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Apply sparsity mask to A matrix</span></span>
<span id="cb15-27"><a href="#cb15-27" aria-hidden="true" tabindex="-1"></a>        A <span class="op">=</span> <span class="op">-</span>torch.exp(<span class="va">self</span>.A_log.<span class="bu">float</span>())</span>
<span id="cb15-28"><a href="#cb15-28" aria-hidden="true" tabindex="-1"></a>        A <span class="op">=</span> A <span class="op">*</span> <span class="va">self</span>.sparsity_mask  <span class="co"># zeroed entries give exp(delta * 0) = 1: those states never decay</span></span>
<span id="cb15-29"><a href="#cb15-29" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-30"><a href="#cb15-30" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Rest of the SSM computation remains the same</span></span>
<span id="cb15-31"><a href="#cb15-31" aria-hidden="true" tabindex="-1"></a>        x_dbl <span class="op">=</span> <span class="va">self</span>.x_proj(x)</span>
<span id="cb15-32"><a href="#cb15-32" aria-hidden="true" tabindex="-1"></a>        delta, B, C <span class="op">=</span> torch.split(x_dbl, [<span class="va">self</span>.dt_rank, N, N], dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb15-33"><a href="#cb15-33" aria-hidden="true" tabindex="-1"></a>        delta <span class="op">=</span> F.softplus(<span class="va">self</span>.dt_proj(delta))</span>
<span id="cb15-34"><a href="#cb15-34" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-35"><a href="#cb15-35" aria-hidden="true" tabindex="-1"></a>        y <span class="op">=</span> <span class="va">self</span>.selective_scan(x, delta, A, B, C, <span class="va">self</span>.D)</span>
<span id="cb15-36"><a href="#cb15-36" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> y</span></code></pre></div></div>
</div>
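<p>Multiplying <code>A</code> by a 0/1 mask does not simply remove connections. Assume <code>selective_scan</code> uses the simplified discretization common in minimal Mamba implementations: a decay of <code>exp(delta * A)</code> and an input term of <code>delta * B * x</code>. Then a masked entry gets a decay of <code>exp(0) = 1</code>, and its state accumulates input without ever forgetting. The single-state scan below makes the difference visible. <code>scalar_ssm_scan</code> is a hypothetical plain-Python helper, not part of the model.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import math


def scalar_ssm_scan(xs, a, b=1.0, c=1.0, delta=0.1):
    """Hypothetical one-state scan with a fixed step size delta.

    Assumed discretization: decay = exp(delta * a), input term = delta * b * x.
    """
    decay = math.exp(delta * a)
    h = 0.0
    ys = []
    for x in xs:
        h = decay * h + delta * b * x  # state update
        ys.append(c * h)               # readout
    return ys


xs = [1.0] * 100                        # constant input
decaying = scalar_ssm_scan(xs, a=-1.0)  # unmasked: A = -exp(A_log) is negative
masked = scalar_ssm_scan(xs, a=0.0)     # masked: A * 0 = 0

print(f"a=-1.0: final state {decaying[-1]:.3f}")  # settles near delta / (1 - exp(-delta))
print(f"a= 0.0: final state {masked[-1]:.3f}")    # grows linearly, about 100 * 0.1
</code></pre></div>
<p>With a constant input, the unmasked state settles near 1.05, while the masked one keeps growing with sequence length. If the intent is reduced connectivity, one option to evaluate (untested here) is to apply the mask to the discretized input term instead, so masked states stay at zero.</p>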
</section>
<section id="mixture-of-experts-moe-mamba" class="level3">
<h3 class="anchored" data-anchor-id="mixture-of-experts-moe-mamba" id="mixture-of-experts-moe-mamba">Mixture of Experts (MoE) Mamba</h3>
<div id="863b6fca" class="cell" data-execution_count="15">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb16"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb16-1"><a href="#cb16-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MambaExpert(nn.Module):</span>
<span id="cb16-2"><a href="#cb16-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Individual expert in MoE Mamba"""</span></span>
<span id="cb16-3"><a href="#cb16-3" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-4"><a href="#cb16-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, d_model, expert_id):</span>
<span id="cb16-5"><a href="#cb16-5" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb16-6"><a href="#cb16-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.expert_id <span class="op">=</span> expert_id</span>
<span id="cb16-7"><a href="#cb16-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.mamba_block <span class="op">=</span> MambaBlock(d_model)</span>
<span id="cb16-8"><a href="#cb16-8" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb16-9"><a href="#cb16-9" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb16-10"><a href="#cb16-10" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.mamba_block(x)</span>
<span id="cb16-11"><a href="#cb16-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-12"><a href="#cb16-12" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MambaMoE(nn.Module):</span>
<span id="cb16-13"><a href="#cb16-13" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Mamba with Mixture of Experts"""</span></span>
<span id="cb16-14"><a href="#cb16-14" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-15"><a href="#cb16-15" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, d_model, num_experts<span class="op">=</span><span class="dv">8</span>, top_k<span class="op">=</span><span class="dv">2</span>):</span>
<span id="cb16-16"><a href="#cb16-16" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb16-17"><a href="#cb16-17" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.num_experts <span class="op">=</span> num_experts</span>
<span id="cb16-18"><a href="#cb16-18" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.top_k <span class="op">=</span> top_k</span>
<span id="cb16-19"><a href="#cb16-19" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb16-20"><a href="#cb16-20" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Router network</span></span>
<span id="cb16-21"><a href="#cb16-21" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.router <span class="op">=</span> nn.Linear(d_model, num_experts)</span>
<span id="cb16-22"><a href="#cb16-22" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb16-23"><a href="#cb16-23" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Expert networks</span></span>
<span id="cb16-24"><a href="#cb16-24" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.experts <span class="op">=</span> nn.ModuleList([</span>
<span id="cb16-25"><a href="#cb16-25" aria-hidden="true" tabindex="-1"></a>            MambaExpert(d_model, i) <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(num_experts)</span>
<span id="cb16-26"><a href="#cb16-26" aria-hidden="true" tabindex="-1"></a>        ])</span>
<span id="cb16-27"><a href="#cb16-27" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb16-28"><a href="#cb16-28" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Load balancing</span></span>
<span id="cb16-29"><a href="#cb16-29" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.load_balancing_loss_coeff <span class="op">=</span> <span class="fl">0.01</span></span>
<span id="cb16-30"><a href="#cb16-30" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-31"><a href="#cb16-31" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb16-32"><a href="#cb16-32" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""</span></span>
<span id="cb16-33"><a href="#cb16-33" aria-hidden="true" tabindex="-1"></a><span class="co">        Forward pass through MoE Mamba</span></span>
<span id="cb16-34"><a href="#cb16-34" aria-hidden="true" tabindex="-1"></a><span class="co">        </span></span>
<span id="cb16-35"><a href="#cb16-35" aria-hidden="true" tabindex="-1"></a><span class="co">        Parameters:</span></span>
<span id="cb16-36"><a href="#cb16-36" aria-hidden="true" tabindex="-1"></a><span class="co">        -----------</span></span>
<span id="cb16-37"><a href="#cb16-37" aria-hidden="true" tabindex="-1"></a><span class="co">        x : torch.Tensor</span></span>
<span id="cb16-38"><a href="#cb16-38" aria-hidden="true" tabindex="-1"></a><span class="co">            Input tensor (batch_size, seq_len, d_model)</span></span>
<span id="cb16-39"><a href="#cb16-39" aria-hidden="true" tabindex="-1"></a><span class="co">            </span></span>
<span id="cb16-40"><a href="#cb16-40" aria-hidden="true" tabindex="-1"></a><span class="co">        Returns:</span></span>
<span id="cb16-41"><a href="#cb16-41" aria-hidden="true" tabindex="-1"></a><span class="co">        --------</span></span>
<span id="cb16-42"><a href="#cb16-42" aria-hidden="true" tabindex="-1"></a><span class="co">        torch.Tensor</span></span>
<span id="cb16-43"><a href="#cb16-43" aria-hidden="true" tabindex="-1"></a><span class="co">            Output tensor (batch_size, seq_len, d_model)</span></span>
<span id="cb16-44"><a href="#cb16-44" aria-hidden="true" tabindex="-1"></a><span class="co">        """</span></span>
<span id="cb16-45"><a href="#cb16-45" aria-hidden="true" tabindex="-1"></a>        batch_size, seq_len, d_model <span class="op">=</span> x.shape</span>
<span id="cb16-46"><a href="#cb16-46" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb16-47"><a href="#cb16-47" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Flatten for routing</span></span>
<span id="cb16-48"><a href="#cb16-48" aria-hidden="true" tabindex="-1"></a>        x_flat <span class="op">=</span> x.reshape(<span class="op">-</span><span class="dv">1</span>, d_model)  <span class="co"># (batch_size * seq_len, d_model); reshape also accepts non-contiguous x</span></span>
<span id="cb16-49"><a href="#cb16-49" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb16-50"><a href="#cb16-50" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Route tokens to experts</span></span>
<span id="cb16-51"><a href="#cb16-51" aria-hidden="true" tabindex="-1"></a>        router_logits <span class="op">=</span> <span class="va">self</span>.router(x_flat)  <span class="co"># (batch_size * seq_len, num_experts)</span></span>
<span id="cb16-52"><a href="#cb16-52" aria-hidden="true" tabindex="-1"></a>        routing_weights <span class="op">=</span> F.softmax(router_logits, dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb16-53"><a href="#cb16-53" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb16-54"><a href="#cb16-54" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Select top-k experts</span></span>
<span id="cb16-55"><a href="#cb16-55" aria-hidden="true" tabindex="-1"></a>        top_k_weights, top_k_indices <span class="op">=</span> torch.topk(routing_weights, <span class="va">self</span>.top_k, dim<span class="op">=-</span><span class="dv">1</span>)</span>
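<span id="cb16-56"><a href="#cb16-56" aria-hidden="true" tabindex="-1"></a>        top_k_weights <span class="op">=</span> top_k_weights <span class="op">/</span> top_k_weights.<span class="bu">sum</span>(dim<span class="op">=-</span><span class="dv">1</span>, keepdim<span class="op">=</span><span class="va">True</span>)  <span class="co"># renormalize; a second softmax would flatten the weights</span></span>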
<span id="cb16-57"><a href="#cb16-57" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb16-58"><a href="#cb16-58" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Initialize output</span></span>
<span id="cb16-59"><a href="#cb16-59" aria-hidden="true" tabindex="-1"></a>        output <span class="op">=</span> torch.zeros_like(x_flat)</span>
<span id="cb16-60"><a href="#cb16-60" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb16-61"><a href="#cb16-61" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Process tokens through selected experts</span></span>
<span id="cb16-62"><a href="#cb16-62" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="va">self</span>.top_k):</span>
<span id="cb16-63"><a href="#cb16-63" aria-hidden="true" tabindex="-1"></a>            expert_indices <span class="op">=</span> top_k_indices[:, i]</span>
<span id="cb16-64"><a href="#cb16-64" aria-hidden="true" tabindex="-1"></a>            expert_weights <span class="op">=</span> top_k_weights[:, i].unsqueeze(<span class="op">-</span><span class="dv">1</span>)</span>
<span id="cb16-65"><a href="#cb16-65" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb16-66"><a href="#cb16-66" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Group tokens by expert</span></span>
<span id="cb16-67"><a href="#cb16-67" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> expert_id <span class="kw">in</span> <span class="bu">range</span>(<span class="va">self</span>.num_experts):</span>
<span id="cb16-68"><a href="#cb16-68" aria-hidden="true" tabindex="-1"></a>                mask <span class="op">=</span> expert_indices <span class="op">==</span> expert_id</span>
<span id="cb16-69"><a href="#cb16-69" aria-hidden="true" tabindex="-1"></a>                <span class="cf">if</span> mask.<span class="bu">any</span>():</span>
<span id="cb16-70"><a href="#cb16-70" aria-hidden="true" tabindex="-1"></a>                    expert_input <span class="op">=</span> x_flat[mask]</span>
<span id="cb16-71"><a href="#cb16-71" aria-hidden="true" tabindex="-1"></a>                    expert_output <span class="op">=</span> <span class="va">self</span>.experts[expert_id](</span>
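<span id="cb16-72"><a href="#cb16-72" aria-hidden="true" tabindex="-1"></a>                        expert_input.view(<span class="op">-</span><span class="dv">1</span>, <span class="dv">1</span>, d_model)  <span class="co"># each token becomes a length-1 sequence, so the expert's SSM sees no cross-token context</span></span>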
<span id="cb16-73"><a href="#cb16-73" aria-hidden="true" tabindex="-1"></a>                    ).view(<span class="op">-</span><span class="dv">1</span>, d_model)</span>
<span id="cb16-74"><a href="#cb16-74" aria-hidden="true" tabindex="-1"></a>                    </span>
<span id="cb16-75"><a href="#cb16-75" aria-hidden="true" tabindex="-1"></a>                    output[mask] <span class="op">+=</span> expert_weights[mask] <span class="op">*</span> expert_output</span>
<span id="cb16-76"><a href="#cb16-76" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb16-77"><a href="#cb16-77" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Load balancing loss</span></span>
<span id="cb16-78"><a href="#cb16-78" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="va">self</span>.training:</span>
<span id="cb16-79"><a href="#cb16-79" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Store the auxiliary loss so the training loop can add it to the main loss</span></span>
<span id="cb16-80"><a href="#cb16-80" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.aux_loss <span class="op">=</span> <span class="va">self</span>._compute_load_balancing_loss(routing_weights)</span>
<span id="cb16-81"><a href="#cb16-81" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb16-82"><a href="#cb16-82" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> output.view(batch_size, seq_len, d_model)</span>
<span id="cb16-83"><a href="#cb16-83" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-84"><a href="#cb16-84" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _compute_load_balancing_loss(<span class="va">self</span>, routing_weights):</span>
<span id="cb16-85"><a href="#cb16-85" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Compute load balancing loss for even expert utilization"""</span></span>
<span id="cb16-86"><a href="#cb16-86" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Mean router probability per expert (a soft proxy for the fraction of tokens routed)</span></span>
<span id="cb16-87"><a href="#cb16-87" aria-hidden="true" tabindex="-1"></a>        expert_usage <span class="op">=</span> routing_weights.<span class="bu">sum</span>(dim<span class="op">=</span><span class="dv">0</span>) <span class="op">/</span> routing_weights.shape[<span class="dv">0</span>]</span>
<span id="cb16-88"><a href="#cb16-88" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb16-89"><a href="#cb16-89" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Ideal usage (uniform distribution)</span></span>
<span id="cb16-90"><a href="#cb16-90" aria-hidden="true" tabindex="-1"></a>        ideal_usage <span class="op">=</span> <span class="fl">1.0</span> <span class="op">/</span> <span class="va">self</span>.num_experts</span>
<span id="cb16-91"><a href="#cb16-91" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb16-92"><a href="#cb16-92" aria-hidden="true" tabindex="-1"></a>        <span class="co"># L2 penalty for deviation from uniform usage</span></span>
<span id="cb16-93"><a href="#cb16-93" aria-hidden="true" tabindex="-1"></a>        load_balancing_loss <span class="op">=</span> torch.<span class="bu">sum</span>((expert_usage <span class="op">-</span> ideal_usage) <span class="op">**</span> <span class="dv">2</span>)</span>
<span id="cb16-94"><a href="#cb16-94" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb16-95"><a href="#cb16-95" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.load_balancing_loss_coeff <span class="op">*</span> load_balancing_loss</span></code></pre></div></div>
</div>
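<p>After <code>torch.topk</code>, the selected router probabilities need rescaling so the k weights sum to 1. Dividing by their sum does this and preserves the router's preference. Applying a second softmax to values already in [0, 1] instead pulls them toward an even split. The plain-Python sketch below compares the two, one token at a time; <code>softmax</code> and <code>top_k_routing</code> are hypothetical helpers. In the tensor code, the renormalization is <code>top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)</code>.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import math


def softmax(values):
    m = max(values)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_routing(logits, k):
    """Hypothetical helper returning (expert_indices, weights) for one token.

    The weights are the router probabilities of the k chosen experts,
    renormalized to sum to 1 rather than passed through a second softmax.
    """
    probs = softmax(logits)
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    selected = [probs[i] for i in top]
    norm = sum(selected)
    return top, [p / norm for p in selected]


logits = [2.0, 0.5, -1.0, 0.0]
experts, weights = top_k_routing(logits, k=2)
print(experts)  # [0, 1]
print([round(w, 3) for w in weights])

# For comparison: a second softmax over the selected probabilities.
probs = softmax(logits)
print([round(w, 3) for w in softmax([probs[i] for i in experts])])
</code></pre></div>
<p>For these logits, the renormalized weights are roughly 0.82 and 0.18. That equals the softmax of the two selected logits alone. The double softmax gives roughly 0.63 and 0.37.</p>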
</section>
<section id="bidirectional-mamba" class="level3">
<h3 class="anchored" data-anchor-id="bidirectional-mamba" id="bidirectional-mamba">Bidirectional Mamba</h3>
<div id="9110a9ef" class="cell" data-execution_count="16">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb17"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb17-1"><a href="#cb17-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> BidirectionalMamba(nn.Module):</span>
<span id="cb17-2"><a href="#cb17-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Bidirectional Mamba: one block scans the sequence forward, a second scans it in reverse, and a linear layer fuses both"""</span></span>
<span id="cb17-3"><a href="#cb17-3" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb17-4"><a href="#cb17-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, d_model, d_state<span class="op">=</span><span class="dv">16</span>, expand<span class="op">=</span><span class="dv">2</span>):</span>
<span id="cb17-5"><a href="#cb17-5" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb17-6"><a href="#cb17-6" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb17-7"><a href="#cb17-7" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Forward and backward Mamba blocks</span></span>
<span id="cb17-8"><a href="#cb17-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.forward_mamba <span class="op">=</span> MambaBlock(d_model, d_state, expand)</span>
<span id="cb17-9"><a href="#cb17-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.backward_mamba <span class="op">=</span> MambaBlock(d_model, d_state, expand)</span>
<span id="cb17-10"><a href="#cb17-10" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb17-11"><a href="#cb17-11" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Fusion layer</span></span>
<span id="cb17-12"><a href="#cb17-12" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.fusion <span class="op">=</span> nn.Linear(d_model <span class="op">*</span> <span class="dv">2</span>, d_model)</span>
<span id="cb17-13"><a href="#cb17-13" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb17-14"><a href="#cb17-14" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb17-15"><a href="#cb17-15" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""</span></span>
<span id="cb17-16"><a href="#cb17-16" aria-hidden="true" tabindex="-1"></a><span class="co">        Bidirectional processing of input sequence</span></span>
<span id="cb17-17"><a href="#cb17-17" aria-hidden="true" tabindex="-1"></a><span class="co">        </span></span>
<span id="cb17-18"><a href="#cb17-18" aria-hidden="true" tabindex="-1"></a><span class="co">        Parameters:</span></span>
<span id="cb17-19"><a href="#cb17-19" aria-hidden="true" tabindex="-1"></a><span class="co">        -----------</span></span>
<span id="cb17-20"><a href="#cb17-20" aria-hidden="true" tabindex="-1"></a><span class="co">        x : torch.Tensor</span></span>
<span id="cb17-21"><a href="#cb17-21" aria-hidden="true" tabindex="-1"></a><span class="co">            Input tensor (batch_size, seq_len, d_model)</span></span>
<span id="cb17-22"><a href="#cb17-22" aria-hidden="true" tabindex="-1"></a><span class="co">            </span></span>
<span id="cb17-23"><a href="#cb17-23" aria-hidden="true" tabindex="-1"></a><span class="co">        Returns:</span></span>
<span id="cb17-24"><a href="#cb17-24" aria-hidden="true" tabindex="-1"></a><span class="co">        --------</span></span>
<span id="cb17-25"><a href="#cb17-25" aria-hidden="true" tabindex="-1"></a><span class="co">        torch.Tensor</span></span>
<span id="cb17-26"><a href="#cb17-26" aria-hidden="true" tabindex="-1"></a><span class="co">            Fused forward/backward output (batch_size, seq_len, d_model)</span></span>
<span id="cb17-27"><a href="#cb17-27" aria-hidden="true" tabindex="-1"></a><span class="co">        """</span></span>
<span id="cb17-28"><a href="#cb17-28" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Forward direction</span></span>
<span id="cb17-29"><a href="#cb17-29" aria-hidden="true" tabindex="-1"></a>        forward_output <span class="op">=</span> <span class="va">self</span>.forward_mamba(x)</span>
<span id="cb17-30"><a href="#cb17-30" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb17-31"><a href="#cb17-31" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Backward direction (reverse sequence)</span></span>
<span id="cb17-32"><a href="#cb17-32" aria-hidden="true" tabindex="-1"></a>        backward_input <span class="op">=</span> torch.flip(x, dims<span class="op">=</span>[<span class="dv">1</span>])</span>
<span id="cb17-33"><a href="#cb17-33" aria-hidden="true" tabindex="-1"></a>        backward_output <span class="op">=</span> <span class="va">self</span>.backward_mamba(backward_input)</span>
<span id="cb17-34"><a href="#cb17-34" aria-hidden="true" tabindex="-1"></a>        backward_output <span class="op">=</span> torch.flip(backward_output, dims<span class="op">=</span>[<span class="dv">1</span>])</span>
<span id="cb17-35"><a href="#cb17-35" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb17-36"><a href="#cb17-36" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Combine forward and backward</span></span>
<span id="cb17-37"><a href="#cb17-37" aria-hidden="true" tabindex="-1"></a>        combined <span class="op">=</span> torch.cat([forward_output, backward_output], dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb17-38"><a href="#cb17-38" aria-hidden="true" tabindex="-1"></a>        output <span class="op">=</span> <span class="va">self</span>.fusion(combined)</span>
<span id="cb17-39"><a href="#cb17-39" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb17-40"><a href="#cb17-40" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> output</span></code></pre></div></div>
</div>
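<p>The flip, scan, flip-back pattern in <code>forward</code> can be sanity-checked without a GPU. The sketch below is a simplified stand-in, not the article's <code>MambaBlock</code>. It assumes a scalar linear recurrence h<sub>t</sub> = a·h<sub>t-1</sub> + x<sub>t</sub> (the hypothetical helper <code>causal_scan</code>) in place of the selective SSM, and a plain sum in place of the learned <code>fusion</code> layer.</p>

```python
from typing import List


def causal_scan(xs: List[float], a: float = 0.5):
    """Toy causal recurrence h_t = a * h_{t-1} + x_t (hypothetical stand-in for a Mamba block)."""
    h = 0.0
    out = []
    for x in xs:
        h = a * h + x
        out.append(h)
    return out


def bidirectional_scan(xs: List[float], a: float = 0.5):
    """Scan forward and over the reversed sequence, re-align, then combine."""
    fwd = causal_scan(xs, a)
    # Reverse the input, scan it, then reverse the result so position t lines up again
    bwd = causal_scan(xs[::-1], a)[::-1]
    # The article fuses with nn.Linear over the concatenation; a plain sum is used here
    return [f + b for f, b in zip(fwd, bwd)]


# An impulse at the last position
impulse = [0.0, 0.0, 0.0, 1.0]
print(causal_scan(impulse))         # [0.0, 0.0, 0.0, 1.0]
print(bidirectional_scan(impulse))  # [0.125, 0.25, 0.5, 2.0]
```

<p>In the forward scan the impulse never reaches position 0. The reversed scan carries it back to the start of the sequence, and this is the extra right-hand context the bidirectional variant is meant to provide.</p>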
</section>
</section>
<section id="model-analysis-and-interpretability" class="level2">
<h2 class="anchored" data-anchor-id="model-analysis-and-interpretability" id="model-analysis-and-interpretability">Model Analysis and Interpretability</h2>
<section id="visualization-tools" class="level3">
<h3 class="anchored" data-anchor-id="visualization-tools" id="visualization-tools">Visualization Tools</h3>
<div id="50d146ac" class="cell" data-execution_count="17">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb18"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb18-1"><a href="#cb18-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MambaVisualizer:</span>
<span id="cb18-2"><a href="#cb18-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Visualization tools for Mamba model analysis"""</span></span>
<span id="cb18-3"><a href="#cb18-3" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-4"><a href="#cb18-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, model, tokenizer<span class="op">=</span><span class="va">None</span>):</span>
<span id="cb18-5"><a href="#cb18-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> model</span>
<span id="cb18-6"><a href="#cb18-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.tokenizer <span class="op">=</span> tokenizer  <span class="co"># required by get_state_importance</span></span>
<span id="cb18-7"><a href="#cb18-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.activations <span class="op">=</span> {}</span>
<span id="cb18-8"><a href="#cb18-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.hooks <span class="op">=</span> []</span>
<span id="cb18-9"><a href="#cb18-9" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-10"><a href="#cb18-10" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> register_hooks(<span class="va">self</span>, detach<span class="op">=</span><span class="va">True</span>):</span>
<span id="cb18-11"><a href="#cb18-11" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Register hooks to capture intermediate activations"""</span></span>
<span id="cb18-12"><a href="#cb18-12" aria-hidden="true" tabindex="-1"></a>        <span class="kw">def</span> hook_fn(name):</span>
<span id="cb18-13"><a href="#cb18-13" aria-hidden="true" tabindex="-1"></a>            <span class="kw">def</span> hook(module, <span class="bu">input</span>, output):</span>
<span id="cb18-14"><a href="#cb18-14" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Keep the autograd graph only when gradients are needed later</span></span>
<span id="cb18-15"><a href="#cb18-15" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.activations[name] <span class="op">=</span> output.detach() <span class="cf">if</span> detach <span class="cf">else</span> output</span>
<span id="cb18-16"><a href="#cb18-16" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> hook</span>
<span id="cb18-17"><a href="#cb18-17" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-18"><a href="#cb18-18" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> name, module <span class="kw">in</span> <span class="va">self</span>.model.named_modules():</span>
<span id="cb18-19"><a href="#cb18-19" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="bu">isinstance</span>(module, MambaBlock):</span>
<span id="cb18-20"><a href="#cb18-20" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.hooks.append(</span>
<span id="cb18-21"><a href="#cb18-21" aria-hidden="true" tabindex="-1"></a>                    module.register_forward_hook(hook_fn(name))</span>
<span id="cb18-22"><a href="#cb18-22" aria-hidden="true" tabindex="-1"></a>                )</span>
<span id="cb18-23"><a href="#cb18-23" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-24"><a href="#cb18-24" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> get_state_importance(<span class="va">self</span>, input_text, layer_idx<span class="op">=-</span><span class="dv">1</span>):</span>
<span id="cb18-25"><a href="#cb18-25" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""</span></span>
<span id="cb18-26"><a href="#cb18-26" aria-hidden="true" tabindex="-1"></a><span class="co">        Compute importance scores similar to attention weights</span></span>
<span id="cb18-27"><a href="#cb18-27" aria-hidden="true" tabindex="-1"></a><span class="co">        </span></span>
<span id="cb18-28"><a href="#cb18-28" aria-hidden="true" tabindex="-1"></a><span class="co">        Parameters:</span></span>
<span id="cb18-29"><a href="#cb18-29" aria-hidden="true" tabindex="-1"></a><span class="co">        -----------</span></span>
<span id="cb18-30"><a href="#cb18-30" aria-hidden="true" tabindex="-1"></a><span class="co">        input_text : str</span></span>
<span id="cb18-31"><a href="#cb18-31" aria-hidden="true" tabindex="-1"></a><span class="co">            Input text to analyze</span></span>
<span id="cb18-32"><a href="#cb18-32" aria-hidden="true" tabindex="-1"></a><span class="co">        layer_idx : int</span></span>
<span id="cb18-33"><a href="#cb18-33" aria-hidden="true" tabindex="-1"></a><span class="co">            Index of the MambaBlock to analyze, in forward order (-1 = last)</span></span>
<span id="cb18-34"><a href="#cb18-34" aria-hidden="true" tabindex="-1"></a><span class="co">            </span></span>
<span id="cb18-35"><a href="#cb18-35" aria-hidden="true" tabindex="-1"></a><span class="co">        Returns:</span></span>
<span id="cb18-36"><a href="#cb18-36" aria-hidden="true" tabindex="-1"></a><span class="co">        --------</span></span>
<span id="cb18-37"><a href="#cb18-37" aria-hidden="true" tabindex="-1"></a><span class="co">        torch.Tensor</span></span>
<span id="cb18-38"><a href="#cb18-38" aria-hidden="true" tabindex="-1"></a><span class="co">            Importance scores for each position (softmax over the sequence)</span></span>
<span id="cb18-39"><a href="#cb18-39" aria-hidden="true" tabindex="-1"></a><span class="co">        """</span></span>
<span id="cb18-40"><a href="#cb18-40" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="va">self</span>.tokenizer <span class="kw">is</span> <span class="va">None</span>:</span>
<span id="cb18-41"><a href="#cb18-41" aria-hidden="true" tabindex="-1"></a>            <span class="cf">raise</span> <span class="pp">ValueError</span>(<span class="st">"get_state_importance requires a tokenizer"</span>)</span>
<span id="cb18-42"><a href="#cb18-42" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.activations <span class="op">=</span> {}</span>
<span id="cb18-43"><a href="#cb18-43" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.register_hooks(detach<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb18-44"><a href="#cb18-44" aria-hidden="true" tabindex="-1"></a>        <span class="cf">try</span>:</span>
<span id="cb18-45"><a href="#cb18-45" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Forward pass with autograd enabled: gradients are required below</span></span>
<span id="cb18-46"><a href="#cb18-46" aria-hidden="true" tabindex="-1"></a>            tokens <span class="op">=</span> <span class="va">self</span>.tokenizer.encode(input_text, return_tensors<span class="op">=</span><span class="st">'pt'</span>)</span>
<span id="cb18-47"><a href="#cb18-47" aria-hidden="true" tabindex="-1"></a>            output <span class="op">=</span> <span class="va">self</span>.model(tokens)</span>
<span id="cb18-48"><a href="#cb18-48" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb18-49"><a href="#cb18-49" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Hooks fire in forward order, so index the captured blocks by position</span></span>
<span id="cb18-50"><a href="#cb18-50" aria-hidden="true" tabindex="-1"></a>            activations <span class="op">=</span> <span class="bu">list</span>(<span class="va">self</span>.activations.values())[layer_idx]</span>
<span id="cb18-51"><a href="#cb18-51" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb18-52"><a href="#cb18-52" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Compute importance as gradient of output w.r.t. hidden states</span></span>
<span id="cb18-53"><a href="#cb18-53" aria-hidden="true" tabindex="-1"></a>            grads <span class="op">=</span> torch.autograd.grad(output.<span class="bu">sum</span>(), activations)[<span class="dv">0</span>]</span>
<span id="cb18-54"><a href="#cb18-54" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb18-55"><a href="#cb18-55" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Normalize importance scores over sequence positions</span></span>
<span id="cb18-56"><a href="#cb18-56" aria-hidden="true" tabindex="-1"></a>            importance <span class="op">=</span> F.softmax(grads.<span class="bu">abs</span>().<span class="bu">sum</span>(<span class="op">-</span><span class="dv">1</span>), dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb18-57"><a href="#cb18-57" aria-hidden="true" tabindex="-1"></a>        <span class="cf">finally</span>:</span>
<span id="cb18-58"><a href="#cb18-58" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.remove_hooks()</span>
<span id="cb18-59"><a href="#cb18-59" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> importance</span>
<span id="cb18-60"><a href="#cb18-60" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-61"><a href="#cb18-61" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> remove_hooks(<span class="va">self</span>):</span>
<span id="cb18-62"><a href="#cb18-62" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Remove all registered hooks"""</span></span>
<span id="cb18-63"><a href="#cb18-63" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> hook <span class="kw">in</span> <span class="va">self</span>.hooks:</span>
<span id="cb18-64"><a href="#cb18-64" aria-hidden="true" tabindex="-1"></a>            hook.remove()</span>
<span id="cb18-65"><a href="#cb18-65" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.hooks <span class="op">=</span> []</span>
<span id="cb18-66"><a href="#cb18-66" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-67"><a href="#cb18-67" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> analyze_state_space(model, input_sequence):</span>
<span id="cb18-68"><a href="#cb18-68" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb18-69"><a href="#cb18-69" aria-hidden="true" tabindex="-1"></a><span class="co">    Analyze the state space dynamics of Mamba</span></span>
<span id="cb18-70"><a href="#cb18-70" aria-hidden="true" tabindex="-1"></a><span class="co">    </span></span>
<span id="cb18-71"><a href="#cb18-71" aria-hidden="true" tabindex="-1"></a><span class="co">    Parameters:</span></span>
<span id="cb18-72"><a href="#cb18-72" aria-hidden="true" tabindex="-1"></a><span class="co">    -----------</span></span>
<span id="cb18-73"><a href="#cb18-73" aria-hidden="true" tabindex="-1"></a><span class="co">    model : Mamba</span></span>
<span id="cb18-74"><a href="#cb18-74" aria-hidden="true" tabindex="-1"></a><span class="co">        Trained Mamba model</span></span>
<span id="cb18-75"><a href="#cb18-75" aria-hidden="true" tabindex="-1"></a><span class="co">    input_sequence : torch.Tensor</span></span>
<span id="cb18-76"><a href="#cb18-76" aria-hidden="true" tabindex="-1"></a><span class="co">        Input sequence to analyze</span></span>
<span id="cb18-77"><a href="#cb18-77" aria-hidden="true" tabindex="-1"></a><span class="co">        </span></span>
<span id="cb18-78"><a href="#cb18-78" aria-hidden="true" tabindex="-1"></a><span class="co">    Returns:</span></span>
<span id="cb18-79"><a href="#cb18-79" aria-hidden="true" tabindex="-1"></a><span class="co">    --------</span></span>
<span id="cb18-80"><a href="#cb18-80" aria-hidden="true" tabindex="-1"></a><span class="co">    dict</span></span>
<span id="cb18-81"><a href="#cb18-81" aria-hidden="true" tabindex="-1"></a><span class="co">        Dictionary containing state analysis results</span></span>
<span id="cb18-82"><a href="#cb18-82" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb18-83"><a href="#cb18-83" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Extract state trajectories</span></span>
<span id="cb18-84"><a href="#cb18-84" aria-hidden="true" tabindex="-1"></a>    states <span class="op">=</span> []</span>
<span id="cb18-85"><a href="#cb18-85" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-86"><a href="#cb18-86" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> state_hook(module, <span class="bu">input</span>, output):</span>
<span id="cb18-87"><a href="#cb18-87" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Capture state evolution during selective scan</span></span>
<span id="cb18-88"><a href="#cb18-88" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="bu">hasattr</span>(module, <span class="st">'current_state'</span>):</span>
<span id="cb18-89"><a href="#cb18-89" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Assumes the SSM was modified to expose its intermediate states</span></span>
<span id="cb18-90"><a href="#cb18-90" aria-hidden="true" tabindex="-1"></a>            states.append(module.current_state.detach())</span>
<span id="cb18-91"><a href="#cb18-91" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-92"><a href="#cb18-92" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Register hooks</span></span>
<span id="cb18-93"><a href="#cb18-93" aria-hidden="true" tabindex="-1"></a>    hooks <span class="op">=</span> []</span>
<span id="cb18-94"><a href="#cb18-94" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> module <span class="kw">in</span> model.modules():</span>
<span id="cb18-95"><a href="#cb18-95" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="bu">isinstance</span>(module, MambaBlock):</span>
<span id="cb18-96"><a href="#cb18-96" aria-hidden="true" tabindex="-1"></a>            hooks.append(module.register_forward_hook(state_hook))</span>
<span id="cb18-97"><a href="#cb18-97" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-98"><a href="#cb18-98" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Forward pass</span></span>
<span id="cb18-99"><a href="#cb18-99" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> torch.no_grad():</span>
<span id="cb18-100"><a href="#cb18-100" aria-hidden="true" tabindex="-1"></a>        output <span class="op">=</span> model(input_sequence)</span>
<span id="cb18-101"><a href="#cb18-101" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-102"><a href="#cb18-102" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Remove hooks</span></span>
<span id="cb18-103"><a href="#cb18-103" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> hook <span class="kw">in</span> hooks:</span>
<span id="cb18-104"><a href="#cb18-104" aria-hidden="true" tabindex="-1"></a>        hook.remove()</span>
<span id="cb18-105"><a href="#cb18-105" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-106"><a href="#cb18-106" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Analyze state dynamics</span></span>
<span id="cb18-107"><a href="#cb18-107" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> states:</span>
<span id="cb18-108"><a href="#cb18-108" aria-hidden="true" tabindex="-1"></a>        state_tensor <span class="op">=</span> torch.stack(states, dim<span class="op">=</span><span class="dv">0</span>)  <span class="co"># (layers, batch, seq_len, state_dim)</span></span>
<span id="cb18-109"><a href="#cb18-109" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-110"><a href="#cb18-110" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Compute state change magnitudes between consecutive sequence positions</span></span>
<span id="cb18-111"><a href="#cb18-111" aria-hidden="true" tabindex="-1"></a>        state_changes <span class="op">=</span> torch.norm(state_tensor[:, :, <span class="dv">1</span>:] <span class="op">-</span> state_tensor[:, :, :<span class="op">-</span><span class="dv">1</span>], dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb18-112"><a href="#cb18-112" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-113"><a href="#cb18-113" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Identify critical transition points</span></span>
<span id="cb18-114"><a href="#cb18-114" aria-hidden="true" tabindex="-1"></a>        mean_change <span class="op">=</span> state_changes.mean()</span>
<span id="cb18-115"><a href="#cb18-115" aria-hidden="true" tabindex="-1"></a>        std_change <span class="op">=</span> state_changes.std()</span>
<span id="cb18-116"><a href="#cb18-116" aria-hidden="true" tabindex="-1"></a>        critical_points <span class="op">=</span> torch.where(state_changes <span class="op">&gt;</span> mean_change <span class="op">+</span> <span class="dv">2</span> <span class="op">*</span> std_change)</span>
<span id="cb18-117"><a href="#cb18-117" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-118"><a href="#cb18-118" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {</span>
<span id="cb18-119"><a href="#cb18-119" aria-hidden="true" tabindex="-1"></a>            <span class="st">'states'</span>: state_tensor,</span>
<span id="cb18-120"><a href="#cb18-120" aria-hidden="true" tabindex="-1"></a>            <span class="st">'state_changes'</span>: state_changes,</span>
<span id="cb18-121"><a href="#cb18-121" aria-hidden="true" tabindex="-1"></a>            <span class="st">'critical_points'</span>: critical_points</span>
<span id="cb18-122"><a href="#cb18-122" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb18-123"><a href="#cb18-123" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-124"><a href="#cb18-124" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> {<span class="st">'states'</span>: <span class="va">None</span>, <span class="st">'state_changes'</span>: <span class="va">None</span>, <span class="st">'critical_points'</span>: <span class="va">None</span>}</span></code></pre></div></div>
</div>
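<p>The scoring rule in <code>get_state_importance</code> (gradient magnitude per position, normalized with a softmax over the sequence) can be illustrated in plain Python. This sketch rests on stated assumptions. Central finite differences replace <code>torch.autograd.grad</code>, and a hypothetical toy recurrence h<sub>t</sub> = a·h<sub>t-1</sub> + x<sub>t</sub> (<code>last_state</code>) replaces a trained model. <code>finite_difference_saliency</code> is an illustrative helper, not a library function.</p>

```python
import math


def last_state(xs, a=0.5):
    """Final hidden state of the hypothetical toy recurrence h_t = a * h_{t-1} + x_t."""
    h = 0.0
    for x in xs:
        h = a * h + x
    return h


def finite_difference_saliency(f, xs, eps=1e-4):
    """Approximate |d f(xs) / d x_t| at every position t with central differences."""
    scores = []
    for t in range(len(xs)):
        up, down = list(xs), list(xs)
        up[t] += eps
        down[t] -= eps
        scores.append(abs(f(up) - f(down)) / (2 * eps))
    return scores


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


grads = finite_difference_saliency(last_state, [1.0, -2.0, 0.5, 3.0])
print([round(g, 3) for g in grads])  # [0.125, 0.25, 0.5, 1.0]

# Same normalization as F.softmax(grads.abs().sum(-1), dim=-1) with d_model = 1
importance = softmax(grads)
print(max(range(len(importance)), key=importance.__getitem__))  # 3
```

<p>For this linear recurrence the gradient with respect to x<sub>t</sub> is a<sup>T-1-t</sup>, so scores decay geometrically toward the start of the sequence. A softmax over raw gradient magnitudes is not scale-invariant. If that matters, dividing each score by their sum is one scale-invariant alternative.</p>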
</section>
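<p>The flagging rule at the end of <code>analyze_state_space</code> marks steps whose state change exceeds the mean change by more than two standard deviations. Below is a pure-Python sketch of that rule on a hypothetical one-dimensional trajectory. Here <code>critical_points</code> is an illustrative helper. <code>statistics.stdev</code> computes the sample (n - 1) standard deviation, which matches the default of <code>torch.std</code>.</p>

```python
import statistics


def critical_points(states, k=2.0):
    """Indices t whose step |states[t+1] - states[t]| exceeds mean + k * std of all steps.

    Needs at least three states (two steps) so the sample standard deviation is defined.
    """
    changes = [abs(b - a) for a, b in zip(states, states[1:])]
    threshold = statistics.mean(changes) + k * statistics.stdev(changes)
    return [t for t, c in enumerate(changes) if c > threshold]


# Hypothetical trajectory: small drift with one abrupt jump between positions 4 and 5
trajectory = [0.0, 0.1, 0.2, 0.3, 0.4, 5.4, 5.5, 5.6, 5.7, 5.8, 5.9]
print(critical_points(trajectory))  # [4]
```

<p>With n steps and the sample standard deviation, no point can lie more than (n - 1)/√n standard deviations from the mean. A two-sigma threshold can therefore never fire on fewer than six steps, so short sequences always return no critical points.</p>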
</section>
<section id="production-deployment" class="level2">
<h2 class="anchored" data-anchor-id="production-deployment" id="production-deployment">Production Deployment</h2>
<section id="model-serving-with-fastapi" class="level3">
<h3 class="anchored" data-anchor-id="model-serving-with-fastapi" id="model-serving-with-fastapi">Model Serving with FastAPI</h3>
<div id="fdf6d87e" class="cell" data-execution_count="18">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb19"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb19-1"><a href="#cb19-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> fastapi <span class="im">import</span> FastAPI, HTTPException</span>
<span id="cb19-2"><a href="#cb19-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> pydantic <span class="im">import</span> BaseModel</span>
<span id="cb19-3"><a href="#cb19-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> uvicorn</span>
<span id="cb19-4"><a href="#cb19-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> asyncio</span>
<span id="cb19-5"><a href="#cb19-5" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> typing <span class="im">import</span> List, Optional</span>
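<span id="cb19-6"><a href="#cb19-6" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb19-6a"><a href="#cb19-6a" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch  <span class="co"># needed for torch.compile and torch.no_grad below</span></span>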
<span id="cb19-7"><a href="#cb19-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-8"><a href="#cb19-8" aria-hidden="true" tabindex="-1"></a>app <span class="op">=</span> FastAPI(title<span class="op">=</span><span class="st">"Mamba Model API"</span>)</span>
<span id="cb19-9"><a href="#cb19-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-10"><a href="#cb19-10" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> GenerationRequest(BaseModel):</span>
<span id="cb19-11"><a href="#cb19-11" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Request model for text generation"""</span></span>
<span id="cb19-12"><a href="#cb19-12" aria-hidden="true" tabindex="-1"></a>    prompt: <span class="bu">str</span></span>
<span id="cb19-13"><a href="#cb19-13" aria-hidden="true" tabindex="-1"></a>    max_length: <span class="bu">int</span> <span class="op">=</span> <span class="dv">100</span></span>
<span id="cb19-14"><a href="#cb19-14" aria-hidden="true" tabindex="-1"></a>    temperature: <span class="bu">float</span> <span class="op">=</span> <span class="fl">0.8</span></span>
<span id="cb19-15"><a href="#cb19-15" aria-hidden="true" tabindex="-1"></a>    top_p: <span class="bu">float</span> <span class="op">=</span> <span class="fl">0.95</span></span>
<span id="cb19-16"><a href="#cb19-16" aria-hidden="true" tabindex="-1"></a>    num_return_sequences: <span class="bu">int</span> <span class="op">=</span> <span class="dv">1</span></span>
<span id="cb19-17"><a href="#cb19-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-18"><a href="#cb19-18" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> GenerationResponse(BaseModel):</span>
<span id="cb19-19"><a href="#cb19-19" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Response model for text generation"""</span></span>
<span id="cb19-20"><a href="#cb19-20" aria-hidden="true" tabindex="-1"></a>    generated_texts: List[<span class="bu">str</span>]</span>
<span id="cb19-21"><a href="#cb19-21" aria-hidden="true" tabindex="-1"></a>    generation_time: <span class="bu">float</span></span>
<span id="cb19-22"><a href="#cb19-22" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-23"><a href="#cb19-23" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MambaServer:</span>
<span id="cb19-24"><a href="#cb19-24" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Production server for Mamba model inference"""</span></span>
<span id="cb19-25"><a href="#cb19-25" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb19-26"><a href="#cb19-26" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, model_path: <span class="bu">str</span>, device: <span class="bu">str</span> <span class="op">=</span> <span class="st">"cuda"</span>):</span>
<span id="cb19-27"><a href="#cb19-27" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> <span class="va">self</span>.load_model(model_path, device)</span>
<span id="cb19-28"><a href="#cb19-28" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.tokenizer <span class="op">=</span> <span class="va">self</span>.load_tokenizer(model_path)</span>
<span id="cb19-29"><a href="#cb19-29" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.device <span class="op">=</span> device</span>
<span id="cb19-30"><a href="#cb19-30" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb19-31"><a href="#cb19-31" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> load_model(<span class="va">self</span>, model_path: <span class="bu">str</span>, device: <span class="bu">str</span>):</span>
<span id="cb19-32"><a href="#cb19-32" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Load optimized Mamba model for inference"""</span></span>
<span id="cb19-33"><a href="#cb19-33" aria-hidden="true" tabindex="-1"></a>        model <span class="op">=</span> Mamba.from_pretrained(model_path)</span>
<span id="cb19-34"><a href="#cb19-34" aria-hidden="true" tabindex="-1"></a>        model <span class="op">=</span> model.half().to(device)</span>
<span id="cb19-35"><a href="#cb19-35" aria-hidden="true" tabindex="-1"></a>        model.<span class="bu">eval</span>()</span>
<span id="cb19-36"><a href="#cb19-36" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb19-37"><a href="#cb19-37" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Compile for faster inference</span></span>
<span id="cb19-38"><a href="#cb19-38" aria-hidden="true" tabindex="-1"></a>        model <span class="op">=</span> torch.<span class="bu">compile</span>(model, mode<span class="op">=</span><span class="st">"max-autotune"</span>)</span>
<span id="cb19-39"><a href="#cb19-39" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb19-40"><a href="#cb19-40" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> model</span>
<span id="cb19-41"><a href="#cb19-41" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb19-42"><a href="#cb19-42" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> load_tokenizer(<span class="va">self</span>, model_path: <span class="bu">str</span>):</span>
<span id="cb19-43"><a href="#cb19-43" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Load tokenizer"""</span></span>
<span id="cb19-44"><a href="#cb19-44" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Assuming using HuggingFace tokenizer</span></span>
<span id="cb19-45"><a href="#cb19-45" aria-hidden="true" tabindex="-1"></a>        <span class="im">from</span> transformers <span class="im">import</span> AutoTokenizer</span>
<span id="cb19-46"><a href="#cb19-46" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> AutoTokenizer.from_pretrained(model_path)</span>
<span id="cb19-47"><a href="#cb19-47" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb19-48"><a href="#cb19-48" aria-hidden="true" tabindex="-1"></a>    <span class="cf">async</span> <span class="kw">def</span> generate(<span class="va">self</span>, request: GenerationRequest) <span class="op">-&gt;</span> GenerationResponse:</span>
<span id="cb19-49"><a href="#cb19-49" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Generate text asynchronously"""</span></span>
<span id="cb19-50"><a href="#cb19-50" aria-hidden="true" tabindex="-1"></a>        start_time <span class="op">=</span> time.time()</span>
<span id="cb19-51"><a href="#cb19-51" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb19-52"><a href="#cb19-52" aria-hidden="true" tabindex="-1"></a>        <span class="cf">try</span>:</span>
<span id="cb19-53"><a href="#cb19-53" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Tokenize input</span></span>
<span id="cb19-54"><a href="#cb19-54" aria-hidden="true" tabindex="-1"></a>            input_ids <span class="op">=</span> <span class="va">self</span>.tokenizer.encode(</span>
<span id="cb19-55"><a href="#cb19-55" aria-hidden="true" tabindex="-1"></a>                request.prompt, </span>
<span id="cb19-56"><a href="#cb19-56" aria-hidden="true" tabindex="-1"></a>                return_tensors<span class="op">=</span><span class="st">'pt'</span></span>
<span id="cb19-57"><a href="#cb19-57" aria-hidden="true" tabindex="-1"></a>            ).to(<span class="va">self</span>.device)</span>
<span id="cb19-58"><a href="#cb19-58" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb19-59"><a href="#cb19-59" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Generate</span></span>
<span id="cb19-60"><a href="#cb19-60" aria-hidden="true" tabindex="-1"></a>            <span class="cf">with</span> torch.no_grad():</span>
<span id="cb19-61"><a href="#cb19-61" aria-hidden="true" tabindex="-1"></a>                generated_sequences <span class="op">=</span> []</span>
<span id="cb19-62"><a href="#cb19-62" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb19-63"><a href="#cb19-63" aria-hidden="true" tabindex="-1"></a>                <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(request.num_return_sequences):</span>
<span id="cb19-64"><a href="#cb19-64" aria-hidden="true" tabindex="-1"></a>                    generated_ids <span class="op">=</span> <span class="cf">await</span> <span class="va">self</span>.generate_sequence(</span>
<span id="cb19-65"><a href="#cb19-65" aria-hidden="true" tabindex="-1"></a>                        input_ids, </span>
<span id="cb19-66"><a href="#cb19-66" aria-hidden="true" tabindex="-1"></a>                        request.max_length,</span>
<span id="cb19-67"><a href="#cb19-67" aria-hidden="true" tabindex="-1"></a>                        request.temperature,</span>
<span id="cb19-68"><a href="#cb19-68" aria-hidden="true" tabindex="-1"></a>                        request.top_p</span>
<span id="cb19-69"><a href="#cb19-69" aria-hidden="true" tabindex="-1"></a>                    )</span>
<span id="cb19-70"><a href="#cb19-70" aria-hidden="true" tabindex="-1"></a>                    </span>
<span id="cb19-71"><a href="#cb19-71" aria-hidden="true" tabindex="-1"></a>                    generated_text <span class="op">=</span> <span class="va">self</span>.tokenizer.decode(</span>
<span id="cb19-72"><a href="#cb19-72" aria-hidden="true" tabindex="-1"></a>                        generated_ids[<span class="dv">0</span>], </span>
<span id="cb19-73"><a href="#cb19-73" aria-hidden="true" tabindex="-1"></a>                        skip_special_tokens<span class="op">=</span><span class="va">True</span></span>
<span id="cb19-74"><a href="#cb19-74" aria-hidden="true" tabindex="-1"></a>                    )</span>
<span id="cb19-75"><a href="#cb19-75" aria-hidden="true" tabindex="-1"></a>                    generated_sequences.append(generated_text)</span>
<span id="cb19-76"><a href="#cb19-76" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb19-77"><a href="#cb19-77" aria-hidden="true" tabindex="-1"></a>            generation_time <span class="op">=</span> time.time() <span class="op">-</span> start_time</span>
<span id="cb19-78"><a href="#cb19-78" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb19-79"><a href="#cb19-79" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> GenerationResponse(</span>
<span id="cb19-80"><a href="#cb19-80" aria-hidden="true" tabindex="-1"></a>                generated_texts<span class="op">=</span>generated_sequences,</span>
<span id="cb19-81"><a href="#cb19-81" aria-hidden="true" tabindex="-1"></a>                generation_time<span class="op">=</span>generation_time</span>
<span id="cb19-82"><a href="#cb19-82" aria-hidden="true" tabindex="-1"></a>            )</span>
<span id="cb19-83"><a href="#cb19-83" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb19-84"><a href="#cb19-84" aria-hidden="true" tabindex="-1"></a>        <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:</span>
<span id="cb19-85"><a href="#cb19-85" aria-hidden="true" tabindex="-1"></a>            <span class="cf">raise</span> HTTPException(status_code<span class="op">=</span><span class="dv">500</span>, detail<span class="op">=</span><span class="bu">str</span>(e)) <span class="im">from</span> e</span>
<span id="cb19-86"><a href="#cb19-86" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb19-87"><a href="#cb19-87" aria-hidden="true" tabindex="-1"></a>    <span class="cf">async</span> <span class="kw">def</span> generate_sequence(<span class="va">self</span>, input_ids, max_length, temperature, top_p):</span>
<span id="cb19-88"><a href="#cb19-88" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Generate a single sequence with top-p sampling"""</span></span>
<span id="cb19-89"><a href="#cb19-89" aria-hidden="true" tabindex="-1"></a>        current_ids <span class="op">=</span> input_ids.clone()</span>
<span id="cb19-90"><a href="#cb19-90" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb19-91"><a href="#cb19-91" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(max_length):</span>
<span id="cb19-92"><a href="#cb19-92" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Run in a worker thread; grad mode is thread-local, so apply no_grad there</span></span>
<span id="cb19-93"><a href="#cb19-93" aria-hidden="true" tabindex="-1"></a>            logits <span class="op">=</span> <span class="cf">await</span> asyncio.get_running_loop().run_in_executor(</span>
<span id="cb19-94"><a href="#cb19-94" aria-hidden="true" tabindex="-1"></a>                <span class="va">None</span>, torch.no_grad()(<span class="va">self</span>.model), current_ids</span>
<span id="cb19-95"><a href="#cb19-95" aria-hidden="true" tabindex="-1"></a>            )</span>
<span id="cb19-96"><a href="#cb19-96" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb19-97"><a href="#cb19-97" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Sample next token</span></span>
<span id="cb19-98"><a href="#cb19-98" aria-hidden="true" tabindex="-1"></a>            next_token_logits <span class="op">=</span> logits[:, <span class="op">-</span><span class="dv">1</span>, :] <span class="op">/</span> temperature</span>
<span id="cb19-99"><a href="#cb19-99" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb19-100"><a href="#cb19-100" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Top-p sampling</span></span>
<span id="cb19-101"><a href="#cb19-101" aria-hidden="true" tabindex="-1"></a>            sorted_logits, sorted_indices <span class="op">=</span> torch.sort(next_token_logits, descending<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb19-102"><a href="#cb19-102" aria-hidden="true" tabindex="-1"></a>            cumulative_probs <span class="op">=</span> torch.cumsum(F.softmax(sorted_logits, dim<span class="op">=-</span><span class="dv">1</span>), dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb19-103"><a href="#cb19-103" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb19-104"><a href="#cb19-104" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Remove tokens with cumulative probability above threshold</span></span>
<span id="cb19-105"><a href="#cb19-105" aria-hidden="true" tabindex="-1"></a>            sorted_indices_to_remove <span class="op">=</span> cumulative_probs <span class="op">&gt;</span> top_p</span>
<span id="cb19-106"><a href="#cb19-106" aria-hidden="true" tabindex="-1"></a>            sorted_indices_to_remove[..., <span class="dv">1</span>:] <span class="op">=</span> sorted_indices_to_remove[..., :<span class="op">-</span><span class="dv">1</span>].clone()</span>
<span id="cb19-107"><a href="#cb19-107" aria-hidden="true" tabindex="-1"></a>            sorted_indices_to_remove[..., <span class="dv">0</span>] <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb19-108"><a href="#cb19-108" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb19-109"><a href="#cb19-109" aria-hidden="true" tabindex="-1"></a>            indices_to_remove <span class="op">=</span> sorted_indices_to_remove.scatter(<span class="dv">1</span>, sorted_indices, sorted_indices_to_remove)</span>
<span id="cb19-110"><a href="#cb19-110" aria-hidden="true" tabindex="-1"></a>            next_token_logits[indices_to_remove] <span class="op">=</span> <span class="op">-</span><span class="bu">float</span>(<span class="st">'Inf'</span>)</span>
<span id="cb19-111"><a href="#cb19-111" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb19-112"><a href="#cb19-112" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Sample</span></span>
<span id="cb19-113"><a href="#cb19-113" aria-hidden="true" tabindex="-1"></a>            probs <span class="op">=</span> F.softmax(next_token_logits, dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb19-114"><a href="#cb19-114" aria-hidden="true" tabindex="-1"></a>            next_token <span class="op">=</span> torch.multinomial(probs, num_samples<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb19-115"><a href="#cb19-115" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb19-116"><a href="#cb19-116" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Append token</span></span>
<span id="cb19-117"><a href="#cb19-117" aria-hidden="true" tabindex="-1"></a>            current_ids <span class="op">=</span> torch.cat([current_ids, next_token], dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb19-118"><a href="#cb19-118" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb19-119"><a href="#cb19-119" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Check for end token</span></span>
<span id="cb19-120"><a href="#cb19-120" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> next_token.item() <span class="op">==</span> <span class="va">self</span>.tokenizer.eos_token_id:</span>
<span id="cb19-121"><a href="#cb19-121" aria-hidden="true" tabindex="-1"></a>                <span class="cf">break</span></span>
<span id="cb19-122"><a href="#cb19-122" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb19-123"><a href="#cb19-123" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> current_ids</span>
<span id="cb19-124"><a href="#cb19-124" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-125"><a href="#cb19-125" aria-hidden="true" tabindex="-1"></a><span class="co"># Initialize once at startup (uncomment and set the model path); /generate relies on it</span></span>
<span id="cb19-126"><a href="#cb19-126" aria-hidden="true" tabindex="-1"></a><span class="co"># mamba_server = MambaServer("path/to/mamba/model")</span></span>
<span id="cb19-127"><a href="#cb19-127" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-128"><a href="#cb19-128" aria-hidden="true" tabindex="-1"></a><span class="at">@app.post</span>(<span class="st">"/generate"</span>, response_model<span class="op">=</span>GenerationResponse)</span>
<span id="cb19-129"><a href="#cb19-129" aria-hidden="true" tabindex="-1"></a><span class="cf">async</span> <span class="kw">def</span> generate_text(request: GenerationRequest):</span>
<span id="cb19-130"><a href="#cb19-130" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""API endpoint for text generation"""</span></span>
<span id="cb19-131"><a href="#cb19-131" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> <span class="cf">await</span> mamba_server.generate(request)</span>
<span id="cb19-132"><a href="#cb19-132" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-133"><a href="#cb19-133" aria-hidden="true" tabindex="-1"></a><span class="at">@app.get</span>(<span class="st">"/health"</span>)</span>
<span id="cb19-134"><a href="#cb19-134" aria-hidden="true" tabindex="-1"></a><span class="cf">async</span> <span class="kw">def</span> health_check():</span>
<span id="cb19-135"><a href="#cb19-135" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Health check endpoint"""</span></span>
<span id="cb19-136"><a href="#cb19-136" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> {<span class="st">"status"</span>: <span class="st">"healthy"</span>}</span>
<span id="cb19-137"><a href="#cb19-137" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-138"><a href="#cb19-138" aria-hidden="true" tabindex="-1"></a><span class="co"># if __name__ == "__main__":</span></span>
<span id="cb19-139"><a href="#cb19-139" aria-hidden="true" tabindex="-1"></a><span class="co">#     uvicorn.run(app, host="0.0.0.0", port=8000)</span></span></code></pre></div></div>
</div>
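<p>The nucleus (top-p) filtering step inside <code>generate_sequence</code> can be checked in isolation. The sketch below pulls that masking logic into a standalone helper, <code>top_p_filter</code> (a hypothetical name, not part of the server above), and applies it to a hand-made four-token distribution. It assumes a <code>(batch, vocab)</code> logits tensor, which is what <code>logits[:, -1, :]</code> produces.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import torch
import torch.nn.functional as F


def top_p_filter(logits, top_p):
    """Set logits outside the top-p nucleus to -inf (logits shape: [batch, vocab])."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    # Flag tokens whose cumulative probability exceeds top_p ...
    remove = torch.gt(cumulative_probs, top_p)
    # ... then shift right by one so the token that crosses the threshold is kept
    remove[..., 1:] = remove[..., :-1].clone()
    remove[..., 0] = False  # the most likely token is never removed

    # Scatter the mask from sorted order back to vocabulary order
    mask = remove.scatter(-1, sorted_indices, remove)
    return logits.masked_fill(mask, float("-inf"))


# Toy distribution over a 4-token vocabulary
logits = torch.log(torch.tensor([[0.5, 0.3, 0.15, 0.05]]))
filtered = top_p_filter(logits, top_p=0.7)
probs = F.softmax(filtered, dim=-1)
print(probs)  # about [[0.625, 0.375, 0.0, 0.0]]
</code></pre></div>
<p>With <code>top_p=0.7</code> the sorted cumulative probabilities are 0.5, 0.8, 0.95 and 1.0. Shifting the removal mask right by one keeps the token that first crosses the threshold, so the two most likely tokens survive and are renormalized to 0.625 and 0.375.</p>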
</section>
<section id="distributed-training-setup" class="level3">
<h3 class="anchored" data-anchor-id="distributed-training-setup" id="distributed-training-setup">Distributed Training Setup</h3>
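<p>Each process in the trainer below draws its batches through a <code>DistributedSampler</code>. Because <code>num_replicas</code> and <code>rank</code> can be passed explicitly, the sharding can be inspected on a single machine without initializing a process group. The helper <code>shard_indices</code> is a hypothetical name used only for this illustration.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">from torch.utils.data.distributed import DistributedSampler


def shard_indices(dataset_size, world_size):
    """Return the sample indices each rank would draw in one epoch (no shuffling)."""
    dataset = list(range(dataset_size))  # any sized object works as a dataset here
    shards = {}
    for rank in range(world_size):
        # Explicit num_replicas/rank means no call into torch.distributed
        sampler = DistributedSampler(
            dataset, num_replicas=world_size, rank=rank, shuffle=False
        )
        shards[rank] = list(sampler)
    return shards


print(shard_indices(5, 2))  # {0: [0, 2, 4], 1: [1, 3, 0]}
</code></pre></div>
<p>With 5 samples and 2 ranks, the sampler pads the index list by repeating sample 0 so that every rank sees the same number of samples and the ranks step together. The flip side is that metrics computed over <code>DistributedSampler</code> shards can count a few samples twice. When <code>shuffle=True</code>, calling <code>set_epoch(epoch)</code> makes the shuffle order differ from epoch to epoch.</p>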
<div id="aa9765b8" class="cell" data-execution_count="19">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb20"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb20-1"><a href="#cb20-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.distributed <span class="im">as</span> dist</span>
<span id="cb20-2"><a href="#cb20-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.nn.parallel <span class="im">import</span> DistributedDataParallel <span class="im">as</span> DDP</span>
<span id="cb20-3"><a href="#cb20-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.utils.data.distributed <span class="im">import</span> DistributedSampler</span>
<span id="cb20-4"><a href="#cb20-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> os</span>
<span id="cb20-5"><a href="#cb20-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb20-6"><a href="#cb20-6" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> DistributedMambaTrainer:</span>
<span id="cb20-7"><a href="#cb20-7" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Distributed trainer for large-scale Mamba training"""</span></span>
<span id="cb20-8"><a href="#cb20-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb20-9"><a href="#cb20-9" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, model, config, train_dataset, val_dataset):</span>
<span id="cb20-10"><a href="#cb20-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.config <span class="op">=</span> config</span>
<span id="cb20-11"><a href="#cb20-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.train_dataset <span class="op">=</span> train_dataset</span>
<span id="cb20-12"><a href="#cb20-12" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.val_dataset <span class="op">=</span> val_dataset</span>
<span id="cb20-13"><a href="#cb20-13" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb20-14"><a href="#cb20-14" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Initialize distributed training</span></span>
<span id="cb20-15"><a href="#cb20-15" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.setup_distributed()</span>
<span id="cb20-16"><a href="#cb20-16" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb20-17"><a href="#cb20-17" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Setup model</span></span>
<span id="cb20-18"><a href="#cb20-18" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> <span class="va">self</span>.setup_model(model)</span>
<span id="cb20-19"><a href="#cb20-19" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb20-20"><a href="#cb20-20" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Setup data loaders</span></span>
<span id="cb20-21"><a href="#cb20-21" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.train_loader, <span class="va">self</span>.val_loader <span class="op">=</span> <span class="va">self</span>.setup_data_loaders()</span>
<span id="cb20-22"><a href="#cb20-22" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb20-23"><a href="#cb20-23" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Setup optimizer and scheduler</span></span>
<span id="cb20-24"><a href="#cb20-24" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.optimizer <span class="op">=</span> create_optimizer(<span class="va">self</span>.model, config)</span>
<span id="cb20-25"><a href="#cb20-25" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.scheduler <span class="op">=</span> <span class="va">self</span>.create_scheduler()</span>
<span id="cb20-26"><a href="#cb20-26" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb20-27"><a href="#cb20-27" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> setup_distributed(<span class="va">self</span>):</span>
<span id="cb20-28"><a href="#cb20-28" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Initialize distributed training environment"""</span></span>
<span id="cb20-29"><a href="#cb20-29" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.local_rank <span class="op">=</span> <span class="bu">int</span>(os.environ[<span class="st">'LOCAL_RANK'</span>])</span>
<span id="cb20-30"><a href="#cb20-30" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.global_rank <span class="op">=</span> <span class="bu">int</span>(os.environ[<span class="st">'RANK'</span>])</span>
<span id="cb20-31"><a href="#cb20-31" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.world_size <span class="op">=</span> <span class="bu">int</span>(os.environ[<span class="st">'WORLD_SIZE'</span>])</span>
<span id="cb20-32"><a href="#cb20-32" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb20-33"><a href="#cb20-33" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Bind this process to its GPU before creating the NCCL process group</span></span>
<span id="cb20-34"><a href="#cb20-34" aria-hidden="true" tabindex="-1"></a>        torch.cuda.set_device(<span class="va">self</span>.local_rank)</span>
<span id="cb20-35"><a href="#cb20-35" aria-hidden="true" tabindex="-1"></a>        dist.init_process_group(backend<span class="op">=</span><span class="st">'nccl'</span>)</span>
<span id="cb20-36"><a href="#cb20-36" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb20-37"><a href="#cb20-37" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> setup_model(<span class="va">self</span>, model):</span>
<span id="cb20-38"><a href="#cb20-38" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Setup model for distributed training"""</span></span>
<span id="cb20-39"><a href="#cb20-39" aria-hidden="true" tabindex="-1"></a>        model <span class="op">=</span> model.to(<span class="va">self</span>.local_rank)</span>
<span id="cb20-40"><a href="#cb20-40" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb20-41"><a href="#cb20-41" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Wrap with DDP</span></span>
<span id="cb20-42"><a href="#cb20-42" aria-hidden="true" tabindex="-1"></a>        model <span class="op">=</span> DDP(</span>
<span id="cb20-43"><a href="#cb20-43" aria-hidden="true" tabindex="-1"></a>            model, </span>
<span id="cb20-44"><a href="#cb20-44" aria-hidden="true" tabindex="-1"></a>            device_ids<span class="op">=</span>[<span class="va">self</span>.local_rank],</span>
<span id="cb20-45"><a href="#cb20-45" aria-hidden="true" tabindex="-1"></a>            find_unused_parameters<span class="op">=</span><span class="va">False</span></span>
<span id="cb20-46"><a href="#cb20-46" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb20-47"><a href="#cb20-47" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb20-48"><a href="#cb20-48" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> model</span>
<span id="cb20-49"><a href="#cb20-49" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb20-50"><a href="#cb20-50" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> setup_data_loaders(<span class="va">self</span>):</span>
<span id="cb20-51"><a href="#cb20-51" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Setup distributed data loaders"""</span></span>
<span id="cb20-52"><a href="#cb20-52" aria-hidden="true" tabindex="-1"></a>        train_sampler <span class="op">=</span> DistributedSampler(</span>
<span id="cb20-53"><a href="#cb20-53" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.train_dataset,</span>
<span id="cb20-54"><a href="#cb20-54" aria-hidden="true" tabindex="-1"></a>            num_replicas<span class="op">=</span><span class="va">self</span>.world_size,</span>
<span id="cb20-55"><a href="#cb20-55" aria-hidden="true" tabindex="-1"></a>            rank<span class="op">=</span><span class="va">self</span>.global_rank,</span>
<span id="cb20-56"><a href="#cb20-56" aria-hidden="true" tabindex="-1"></a>            shuffle<span class="op">=</span><span class="va">True</span></span>
<span id="cb20-57"><a href="#cb20-57" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb20-58"><a href="#cb20-58" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb20-59"><a href="#cb20-59" aria-hidden="true" tabindex="-1"></a>        val_sampler <span class="op">=</span> DistributedSampler(</span>
<span id="cb20-60"><a href="#cb20-60" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.val_dataset,</span>
<span id="cb20-61"><a href="#cb20-61" aria-hidden="true" tabindex="-1"></a>            num_replicas<span class="op">=</span><span class="va">self</span>.world_size,</span>
<span id="cb20-62"><a href="#cb20-62" aria-hidden="true" tabindex="-1"></a>            rank<span class="op">=</span><span class="va">self</span>.global_rank,</span>
<span id="cb20-63"><a href="#cb20-63" aria-hidden="true" tabindex="-1"></a>            shuffle<span class="op">=</span><span class="va">False</span></span>
<span id="cb20-64"><a href="#cb20-64" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb20-65"><a href="#cb20-65" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb20-66"><a href="#cb20-66" aria-hidden="true" tabindex="-1"></a>        <span class="im">from</span> torch.utils.data <span class="im">import</span> DataLoader</span>
<span id="cb20-67"><a href="#cb20-67" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb20-68"><a href="#cb20-68" aria-hidden="true" tabindex="-1"></a>        train_loader <span class="op">=</span> DataLoader(</span>
<span id="cb20-69"><a href="#cb20-69" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.train_dataset,</span>
<span id="cb20-70"><a href="#cb20-70" aria-hidden="true" tabindex="-1"></a>            batch_size<span class="op">=</span><span class="va">self</span>.config.batch_size,</span>
<span id="cb20-71"><a href="#cb20-71" aria-hidden="true" tabindex="-1"></a>            sampler<span class="op">=</span>train_sampler,</span>
<span id="cb20-72"><a href="#cb20-72" aria-hidden="true" tabindex="-1"></a>            num_workers<span class="op">=</span><span class="dv">4</span>,</span>
<span id="cb20-73"><a href="#cb20-73" aria-hidden="true" tabindex="-1"></a>            pin_memory<span class="op">=</span><span class="va">True</span></span>
<span id="cb20-74"><a href="#cb20-74" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb20-75"><a href="#cb20-75" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb20-76"><a href="#cb20-76" aria-hidden="true" tabindex="-1"></a>        val_loader <span class="op">=</span> DataLoader(</span>
<span id="cb20-77"><a href="#cb20-77" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.val_dataset,</span>
<span id="cb20-78"><a href="#cb20-78" aria-hidden="true" tabindex="-1"></a>            batch_size<span class="op">=</span><span class="va">self</span>.config.batch_size,</span>
<span id="cb20-79"><a href="#cb20-79" aria-hidden="true" tabindex="-1"></a>            sampler<span class="op">=</span>val_sampler,</span>
<span id="cb20-80"><a href="#cb20-80" aria-hidden="true" tabindex="-1"></a>            num_workers<span class="op">=</span><span class="dv">4</span>,</span>
<span id="cb20-81"><a href="#cb20-81" aria-hidden="true" tabindex="-1"></a>            pin_memory<span class="op">=</span><span class="va">True</span></span>
<span id="cb20-82"><a href="#cb20-82" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb20-83"><a href="#cb20-83" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb20-84"><a href="#cb20-84" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> train_loader, val_loader</span>
<span id="cb20-85"><a href="#cb20-85" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb20-86"><a href="#cb20-86" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> train(<span class="va">self</span>):</span>
<span id="cb20-87"><a href="#cb20-87" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Main distributed training loop"""</span></span>
<span id="cb20-88"><a href="#cb20-88" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(<span class="va">self</span>.config.num_epochs):</span>
<span id="cb20-89"><a href="#cb20-89" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.train_loader.sampler.set_epoch(epoch)</span>
<span id="cb20-90"><a href="#cb20-90" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb20-91"><a href="#cb20-91" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Training</span></span>
<span id="cb20-92"><a href="#cb20-92" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.model.train()</span>
<span id="cb20-93"><a href="#cb20-93" aria-hidden="true" tabindex="-1"></a>            train_loss <span class="op">=</span> <span class="va">self</span>.train_epoch()</span>
<span id="cb20-94"><a href="#cb20-94" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb20-95"><a href="#cb20-95" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Validation: every rank evaluates its own DistributedSampler shard</span></span>
<span id="cb20-96"><a href="#cb20-96" aria-hidden="true" tabindex="-1"></a>            val_loss <span class="op">=</span> <span class="va">self</span>.validate()</span>
<span id="cb20-97"><a href="#cb20-97" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb20-98"><a href="#cb20-98" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Log and save checkpoints only on the main process</span></span>
<span id="cb20-99"><a href="#cb20-99" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="va">self</span>.global_rank <span class="op">==</span> <span class="dv">0</span>:</span>
<span id="cb20-100"><a href="#cb20-100" aria-hidden="true" tabindex="-1"></a>                <span class="bu">print</span>(<span class="ss">f"Epoch </span><span class="sc">{</span>epoch<span class="sc">}</span><span class="ss">: Train Loss: </span><span class="sc">{</span>train_loss<span class="sc">:.4f}</span><span class="ss">, Val Loss: </span><span class="sc">{</span>val_loss<span class="sc">:.4f}</span><span class="ss">"</span>)</span>
<span id="cb20-101"><a href="#cb20-101" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.save_checkpoint(epoch, train_loss, val_loss)</span>
<span id="cb20-102"><a href="#cb20-102" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb20-103"><a href="#cb20-103" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> train_epoch(<span class="va">self</span>):</span>
<span id="cb20-104"><a href="#cb20-104" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Train for one epoch with distributed synchronization"""</span></span>
<span id="cb20-105"><a href="#cb20-105" aria-hidden="true" tabindex="-1"></a>        total_loss <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb20-106"><a href="#cb20-106" aria-hidden="true" tabindex="-1"></a>        num_batches <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb20-107"><a href="#cb20-107" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb20-108"><a href="#cb20-108" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> batch <span class="kw">in</span> <span class="va">self</span>.train_loader:</span>
<span id="cb20-109"><a href="#cb20-109" aria-hidden="true" tabindex="-1"></a>            input_ids <span class="op">=</span> batch[<span class="st">'input_ids'</span>].to(<span class="va">self</span>.local_rank)</span>
<span id="cb20-110"><a href="#cb20-110" aria-hidden="true" tabindex="-1"></a>            targets <span class="op">=</span> input_ids[:, <span class="dv">1</span>:].contiguous()</span>
<span id="cb20-111"><a href="#cb20-111" aria-hidden="true" tabindex="-1"></a>            input_ids <span class="op">=</span> input_ids[:, :<span class="op">-</span><span class="dv">1</span>].contiguous()</span>
<span id="cb20-112"><a href="#cb20-112" aria-hidden="true" tabindex="-1"></a>            </span>
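<span id="cb20-113"><a href="#cb20-113" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Forward pass under bf16 autocast (needs no GradScaler; requires a bf16-capable GPU)</span></span>
<span id="cb20-114"><a href="#cb20-114" aria-hidden="true" tabindex="-1"></a>            <span class="cf">with</span> torch.autocast(device_type<span class="op">=</span><span class="st">'cuda'</span>, dtype<span class="op">=</span>torch.bfloat16):</span>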
<span id="cb20-115"><a href="#cb20-115" aria-hidden="true" tabindex="-1"></a>                logits <span class="op">=</span> <span class="va">self</span>.model(input_ids)</span>
<span id="cb20-116"><a href="#cb20-116" aria-hidden="true" tabindex="-1"></a>                loss <span class="op">=</span> F.cross_entropy(</span>
<span id="cb20-117"><a href="#cb20-117" aria-hidden="true" tabindex="-1"></a>                    logits.view(<span class="op">-</span><span class="dv">1</span>, logits.size(<span class="op">-</span><span class="dv">1</span>)),</span>
<span id="cb20-118"><a href="#cb20-118" aria-hidden="true" tabindex="-1"></a>                    targets.view(<span class="op">-</span><span class="dv">1</span>)</span>
<span id="cb20-119"><a href="#cb20-119" aria-hidden="true" tabindex="-1"></a>                )</span>
<span id="cb20-120"><a href="#cb20-120" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb20-121"><a href="#cb20-121" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Backward pass</span></span>
<span id="cb20-122"><a href="#cb20-122" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.optimizer.zero_grad()</span>
<span id="cb20-123"><a href="#cb20-123" aria-hidden="true" tabindex="-1"></a>            loss.backward()</span>
<span id="cb20-124"><a href="#cb20-124" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb20-125"><a href="#cb20-125" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Gradient clipping</span></span>
<span id="cb20-126"><a href="#cb20-126" aria-hidden="true" tabindex="-1"></a>            torch.nn.utils.clip_grad_norm_(<span class="va">self</span>.model.parameters(), <span class="fl">1.0</span>)</span>
<span id="cb20-127"><a href="#cb20-127" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb20-128"><a href="#cb20-128" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.optimizer.step()</span>
<span id="cb20-129"><a href="#cb20-129" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.scheduler.step()</span>
<span id="cb20-130"><a href="#cb20-130" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb20-131"><a href="#cb20-131" aria-hidden="true" tabindex="-1"></a>            total_loss <span class="op">+=</span> loss.item()</span>
<span id="cb20-132"><a href="#cb20-132" aria-hidden="true" tabindex="-1"></a>            num_batches <span class="op">+=</span> <span class="dv">1</span></span>
<span id="cb20-133"><a href="#cb20-133" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb20-134"><a href="#cb20-134" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Average loss across all processes</span></span>
<span id="cb20-135"><a href="#cb20-135" aria-hidden="true" tabindex="-1"></a>        avg_loss <span class="op">=</span> total_loss <span class="op">/</span> num_batches</span>
<span id="cb20-136"><a href="#cb20-136" aria-hidden="true" tabindex="-1"></a>        loss_tensor <span class="op">=</span> torch.tensor(avg_loss).to(<span class="va">self</span>.local_rank)</span>
<span id="cb20-137"><a href="#cb20-137" aria-hidden="true" tabindex="-1"></a>        dist.all_reduce(loss_tensor, op<span class="op">=</span>dist.ReduceOp.AVG)</span>
<span id="cb20-138"><a href="#cb20-138" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb20-139"><a href="#cb20-139" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> loss_tensor.item()</span>
<span id="cb20-140"><a href="#cb20-140" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb20-141"><a href="#cb20-141" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> save_checkpoint(<span class="va">self</span>, epoch, train_loss, val_loss):</span>
<span id="cb20-142"><a href="#cb20-142" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Save training checkpoint"""</span></span>
<span id="cb20-143"><a href="#cb20-143" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="va">self</span>.global_rank <span class="op">==</span> <span class="dv">0</span>:</span>
<span id="cb20-144"><a href="#cb20-144" aria-hidden="true" tabindex="-1"></a>            checkpoint <span class="op">=</span> {</span>
<span id="cb20-145"><a href="#cb20-145" aria-hidden="true" tabindex="-1"></a>                <span class="st">'epoch'</span>: epoch,</span>
<span id="cb20-146"><a href="#cb20-146" aria-hidden="true" tabindex="-1"></a>                <span class="st">'model_state_dict'</span>: <span class="va">self</span>.model.module.state_dict(),</span>
<span id="cb20-147"><a href="#cb20-147" aria-hidden="true" tabindex="-1"></a>                <span class="st">'optimizer_state_dict'</span>: <span class="va">self</span>.optimizer.state_dict(),</span>
<span id="cb20-148"><a href="#cb20-148" aria-hidden="true" tabindex="-1"></a>                <span class="st">'scheduler_state_dict'</span>: <span class="va">self</span>.scheduler.state_dict(),</span>
<span id="cb20-149"><a href="#cb20-149" aria-hidden="true" tabindex="-1"></a>                <span class="st">'train_loss'</span>: train_loss,</span>
<span id="cb20-150"><a href="#cb20-150" aria-hidden="true" tabindex="-1"></a>                <span class="st">'val_loss'</span>: val_loss,</span>
<span id="cb20-151"><a href="#cb20-151" aria-hidden="true" tabindex="-1"></a>                <span class="st">'config'</span>: <span class="va">self</span>.config</span>
<span id="cb20-152"><a href="#cb20-152" aria-hidden="true" tabindex="-1"></a>            }</span>
<span id="cb20-153"><a href="#cb20-153" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb20-154"><a href="#cb20-154" aria-hidden="true" tabindex="-1"></a>            torch.save(checkpoint, <span class="ss">f'checkpoint_epoch_</span><span class="sc">{</span>epoch<span class="sc">}</span><span class="ss">.pt'</span>)</span></code></pre></div></div>
</div>
</section>
</section>
<section id="experimental-features" class="level2">
<h2 class="anchored" data-anchor-id="experimental-features" id="experimental-features">Experimental Features</h2>
<section id="adaptive-computation-time-act" class="level3">
<h3 class="anchored" data-anchor-id="adaptive-computation-time-act" id="adaptive-computation-time-act">Adaptive Computation Time (ACT)</h3>
<div id="26d8a7ee" class="cell" data-execution_count="20">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb21"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb21-1"><a href="#cb21-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> ACTMamba(nn.Module):</span>
<span id="cb21-2"><a href="#cb21-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Mamba with Adaptive Computation Time"""</span></span>
<span id="cb21-3"><a href="#cb21-3" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb21-4"><a href="#cb21-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, d_model, max_computation_steps<span class="op">=</span><span class="dv">10</span>, threshold<span class="op">=</span><span class="fl">0.99</span>):</span>
<span id="cb21-5"><a href="#cb21-5" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb21-6"><a href="#cb21-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.max_computation_steps <span class="op">=</span> max_computation_steps</span>
<span id="cb21-7"><a href="#cb21-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.threshold <span class="op">=</span> threshold</span>
<span id="cb21-8"><a href="#cb21-8" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb21-9"><a href="#cb21-9" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Mamba layer</span></span>
<span id="cb21-10"><a href="#cb21-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.mamba <span class="op">=</span> MambaBlock(d_model)</span>
<span id="cb21-11"><a href="#cb21-11" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb21-12"><a href="#cb21-12" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Halting probability predictor</span></span>
<span id="cb21-13"><a href="#cb21-13" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.halting_predictor <span class="op">=</span> nn.Linear(d_model, <span class="dv">1</span>)</span>
<span id="cb21-14"><a href="#cb21-14" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb21-15"><a href="#cb21-15" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb21-16"><a href="#cb21-16" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""</span></span>
<span id="cb21-17"><a href="#cb21-17" aria-hidden="true" tabindex="-1"></a><span class="co">        Forward pass with adaptive computation time</span></span>
<span id="cb21-18"><a href="#cb21-18" aria-hidden="true" tabindex="-1"></a><span class="co">        </span></span>
<span id="cb21-19"><a href="#cb21-19" aria-hidden="true" tabindex="-1"></a><span class="co">        Parameters:</span></span>
<span id="cb21-20"><a href="#cb21-20" aria-hidden="true" tabindex="-1"></a><span class="co">        -----------</span></span>
<span id="cb21-21"><a href="#cb21-21" aria-hidden="true" tabindex="-1"></a><span class="co">        x : torch.Tensor</span></span>
<span id="cb21-22"><a href="#cb21-22" aria-hidden="true" tabindex="-1"></a><span class="co">            Input tensor (batch_size, seq_len, d_model)</span></span>
<span id="cb21-23"><a href="#cb21-23" aria-hidden="true" tabindex="-1"></a><span class="co">            </span></span>
<span id="cb21-24"><a href="#cb21-24" aria-hidden="true" tabindex="-1"></a><span class="co">        Returns:</span></span>
<span id="cb21-25"><a href="#cb21-25" aria-hidden="true" tabindex="-1"></a><span class="co">        --------</span></span>
<span id="cb21-26"><a href="#cb21-26" aria-hidden="true" tabindex="-1"></a><span class="co">        tuple</span></span>
<span id="cb21-27"><a href="#cb21-27" aria-hidden="true" tabindex="-1"></a><span class="co">            (output, ponder_cost) where ponder_cost is a regularization term</span></span>
<span id="cb21-28"><a href="#cb21-28" aria-hidden="true" tabindex="-1"></a><span class="co">        """</span></span>
<span id="cb21-29"><a href="#cb21-29" aria-hidden="true" tabindex="-1"></a>        batch_size, seq_len, d_model <span class="op">=</span> x.shape</span>
<span id="cb21-30"><a href="#cb21-30" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb21-31"><a href="#cb21-31" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Initialize states</span></span>
<span id="cb21-32"><a href="#cb21-32" aria-hidden="true" tabindex="-1"></a>        state <span class="op">=</span> x</span>
<span id="cb21-33"><a href="#cb21-33" aria-hidden="true" tabindex="-1"></a>        halting_probs <span class="op">=</span> torch.zeros(batch_size, seq_len, <span class="dv">1</span>, device<span class="op">=</span>x.device)</span>
<span id="cb21-34"><a href="#cb21-34" aria-hidden="true" tabindex="-1"></a>        remainders <span class="op">=</span> torch.ones(batch_size, seq_len, <span class="dv">1</span>, device<span class="op">=</span>x.device)</span>
<span id="cb21-35"><a href="#cb21-35" aria-hidden="true" tabindex="-1"></a>        n_updates <span class="op">=</span> torch.zeros(batch_size, seq_len, <span class="dv">1</span>, device<span class="op">=</span>x.device)</span>
<span id="cb21-36"><a href="#cb21-36" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb21-37"><a href="#cb21-37" aria-hidden="true" tabindex="-1"></a>        output <span class="op">=</span> torch.zeros_like(x)</span>
<span id="cb21-38"><a href="#cb21-38" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb21-39"><a href="#cb21-39" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> step <span class="kw">in</span> <span class="bu">range</span>(<span class="va">self</span>.max_computation_steps):</span>
<span id="cb21-40"><a href="#cb21-40" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Predict halting probability</span></span>
<span id="cb21-41"><a href="#cb21-41" aria-hidden="true" tabindex="-1"></a>            p <span class="op">=</span> torch.sigmoid(<span class="va">self</span>.halting_predictor(state))</span>
<span id="cb21-42"><a href="#cb21-42" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb21-43"><a href="#cb21-43" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Update halting probabilities</span></span>
<span id="cb21-44"><a href="#cb21-44" aria-hidden="true" tabindex="-1"></a>            still_running <span class="op">=</span> (halting_probs <span class="op">&lt;</span> <span class="va">self</span>.threshold).<span class="bu">float</span>()</span>
<span id="cb21-45"><a href="#cb21-45" aria-hidden="true" tabindex="-1"></a>            new_halted <span class="op">=</span> (halting_probs <span class="op">+</span> p <span class="op">*</span> still_running <span class="op">&gt;=</span> <span class="va">self</span>.threshold).<span class="bu">float</span>() <span class="op">*</span> still_running</span>
<span id="cb21-46"><a href="#cb21-46" aria-hidden="true" tabindex="-1"></a>            still_running <span class="op">=</span> still_running <span class="op">-</span> new_halted</span>
<span id="cb21-47"><a href="#cb21-47" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb21-48"><a href="#cb21-48" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Accumulate probability and shrink the remainder for tokens still running</span></span>
<span id="cb21-49"><a href="#cb21-49" aria-hidden="true" tabindex="-1"></a>            halting_probs <span class="op">=</span> halting_probs <span class="op">+</span> p <span class="op">*</span> still_running</span>
<span id="cb21-50"><a href="#cb21-50" aria-hidden="true" tabindex="-1"></a>            remainders <span class="op">=</span> remainders <span class="op">-</span> p <span class="op">*</span> still_running</span>
<span id="cb21-51"><a href="#cb21-51" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb21-52"><a href="#cb21-52" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Weight for this step: p if still running, the remainder if halting now</span></span>
<span id="cb21-53"><a href="#cb21-53" aria-hidden="true" tabindex="-1"></a>            step_weight <span class="op">=</span> p <span class="op">*</span> still_running <span class="op">+</span> new_halted <span class="op">*</span> remainders</span>
<span id="cb21-54"><a href="#cb21-54" aria-hidden="true" tabindex="-1"></a>            halting_probs <span class="op">=</span> halting_probs <span class="op">+</span> new_halted <span class="op">*</span> remainders  <span class="co"># halted tokens now total 1.0</span></span>
<span id="cb21-55"><a href="#cb21-55" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Apply Mamba transformation</span></span>
<span id="cb21-56"><a href="#cb21-56" aria-hidden="true" tabindex="-1"></a>            transformed_state <span class="op">=</span> <span class="va">self</span>.mamba(state)</span>
<span id="cb21-57"><a href="#cb21-57" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb21-58"><a href="#cb21-58" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Update output</span></span>
<span id="cb21-59"><a href="#cb21-59" aria-hidden="true" tabindex="-1"></a>            output <span class="op">=</span> output <span class="op">+</span> step_weight <span class="op">*</span> transformed_state</span>
<span id="cb21-60"><a href="#cb21-60" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb21-61"><a href="#cb21-61" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Update state for next iteration</span></span>
<span id="cb21-62"><a href="#cb21-62" aria-hidden="true" tabindex="-1"></a>            state <span class="op">=</span> transformed_state</span>
<span id="cb21-63"><a href="#cb21-63" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb21-64"><a href="#cb21-64" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Update computation counter</span></span>
<span id="cb21-65"><a href="#cb21-65" aria-hidden="true" tabindex="-1"></a>            n_updates <span class="op">=</span> n_updates <span class="op">+</span> still_running <span class="op">+</span> new_halted</span>
<span id="cb21-66"><a href="#cb21-66" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb21-67"><a href="#cb21-67" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Check if all sequences have halted</span></span>
<span id="cb21-68"><a href="#cb21-68" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> (halting_probs <span class="op">&gt;=</span> <span class="va">self</span>.threshold).<span class="bu">all</span>():</span>
<span id="cb21-69"><a href="#cb21-69" aria-hidden="true" tabindex="-1"></a>                <span class="cf">break</span></span>
<span id="cb21-70"><a href="#cb21-70" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb21-71"><a href="#cb21-71" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Ponder cost N + R (as in ACT); R carries the gradient, N alone has none</span></span>
<span id="cb21-72"><a href="#cb21-72" aria-hidden="true" tabindex="-1"></a>        ponder_cost <span class="op">=</span> (n_updates <span class="op">+</span> remainders).mean()</span>
<span id="cb21-73"><a href="#cb21-73" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb21-74"><a href="#cb21-74" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> output, ponder_cost</span></code></pre></div></div>
</div>
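<p>The halting rule is easier to follow for a single token. The NumPy sketch below is a simplified illustration, not part of <code>ACTMamba</code>, and its helper names (<code>act_weights</code>, <code>act_combine</code>) are hypothetical. Given the halting probability emitted at each step, it returns the weight each step’s state receives and the ponder cost <span class="math inline">\(N + R\)</span> from Graves’ ACT formulation. A step is weighted by its probability <span class="math inline">\(p_n\)</span> until the running total would reach the threshold; the halting step then takes whatever mass is left, so the weights sum to 1. Unlike the loop above, this sketch forces a halt at the last allowed step.</p>

```python
import numpy as np


def act_weights(halt_probs, threshold=0.99):
    """Hypothetical helper: (per-step weights, ponder cost N + R) for one token.

    halt_probs: halting probability p_n emitted at each step n.
    The last entry marks the maximum number of steps (forced halt).
    """
    if len(halt_probs) == 0:
        raise ValueError("need at least one computation step")
    weights = []
    cumulative = 0.0
    for n, p in enumerate(halt_probs):
        is_last_step = n == len(halt_probs) - 1
        if cumulative + p >= threshold or is_last_step:
            remainder = 1.0 - cumulative          # R: probability mass left over
            weights.append(remainder)
            n_steps = len(weights)                # N: steps actually computed
            return np.array(weights), n_steps + remainder
        weights.append(p)                         # still running: weight is p_n
        cumulative += p


def act_combine(states, halt_probs, threshold=0.99):
    """Hypothetical helper: weighted sum of per-step states (output += weight * state)."""
    weights, ponder = act_weights(halt_probs, threshold)
    states = np.asarray(states, dtype=float)[: len(weights)]
    return weights @ states, ponder


w, rho = act_weights([0.2, 0.5, 0.4])   # w is approx. [0.2, 0.5, 0.3], rho approx. 3.3
```

<p>Because the remainder <span class="math inline">\(R\)</span> depends smoothly on the halting probabilities, the <span class="math inline">\(N + R\)</span> cost gives the halting predictor a gradient, whereas the integer step count <span class="math inline">\(N\)</span> alone does not.</p>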
</section>
<section id="hierarchical-processing" class="level3">
<h3 class="anchored" data-anchor-id="hierarchical-processing" id="hierarchical-processing">Hierarchical Processing</h3>
<div id="5e5f6f45" class="cell" data-execution_count="21">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb22"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb22-1"><a href="#cb22-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> HierarchicalMamba(nn.Module):</span>
<span id="cb22-2"><a href="#cb22-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Hierarchical Mamba for multi-scale processing"""</span></span>
<span id="cb22-3"><a href="#cb22-3" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb22-4"><a href="#cb22-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, d_model, n_layer, hierarchy_levels<span class="op">=</span><span class="dv">3</span>):</span>
<span id="cb22-5"><a href="#cb22-5" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb22-6"><a href="#cb22-6" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb22-7"><a href="#cb22-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.hierarchy_levels <span class="op">=</span> hierarchy_levels</span>
<span id="cb22-8"><a href="#cb22-8" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb22-9"><a href="#cb22-9" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Different Mamba blocks for different hierarchical levels</span></span>
<span id="cb22-10"><a href="#cb22-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.local_mamba <span class="op">=</span> nn.ModuleList([</span>
<span id="cb22-11"><a href="#cb22-11" aria-hidden="true" tabindex="-1"></a>            MambaBlock(d_model, d_state<span class="op">=</span><span class="dv">16</span>) </span>
<span id="cb22-12"><a href="#cb22-12" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(n_layer <span class="op">//</span> hierarchy_levels)</span>
<span id="cb22-13"><a href="#cb22-13" aria-hidden="true" tabindex="-1"></a>        ])</span>
<span id="cb22-14"><a href="#cb22-14" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb22-15"><a href="#cb22-15" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.global_mamba <span class="op">=</span> nn.ModuleList([</span>
<span id="cb22-16"><a href="#cb22-16" aria-hidden="true" tabindex="-1"></a>            MambaBlock(d_model, d_state<span class="op">=</span><span class="dv">32</span>) </span>
<span id="cb22-17"><a href="#cb22-17" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(n_layer <span class="op">//</span> hierarchy_levels)</span>
<span id="cb22-18"><a href="#cb22-18" aria-hidden="true" tabindex="-1"></a>        ])</span>
<span id="cb22-19"><a href="#cb22-19" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb22-20"><a href="#cb22-20" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.cross_hierarchy <span class="op">=</span> nn.ModuleList([</span>
<span id="cb22-21"><a href="#cb22-21" aria-hidden="true" tabindex="-1"></a>            nn.MultiheadAttention(d_model, num_heads<span class="op">=</span><span class="dv">8</span>, batch_first<span class="op">=</span><span class="va">True</span>) </span>
<span id="cb22-22"><a href="#cb22-22" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(hierarchy_levels)</span>
<span id="cb22-23"><a href="#cb22-23" aria-hidden="true" tabindex="-1"></a>        ])</span>
<span id="cb22-24"><a href="#cb22-24" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb22-25"><a href="#cb22-25" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb22-26"><a href="#cb22-26" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""</span></span>
<span id="cb22-27"><a href="#cb22-27" aria-hidden="true" tabindex="-1"></a><span class="co">        Hierarchical processing of input</span></span>
<span id="cb22-28"><a href="#cb22-28" aria-hidden="true" tabindex="-1"></a><span class="co">        </span></span>
<span id="cb22-29"><a href="#cb22-29" aria-hidden="true" tabindex="-1"></a><span class="co">        Parameters:</span></span>
<span id="cb22-30"><a href="#cb22-30" aria-hidden="true" tabindex="-1"></a><span class="co">        -----------</span></span>
<span id="cb22-31"><a href="#cb22-31" aria-hidden="true" tabindex="-1"></a><span class="co">        x : torch.Tensor</span></span>
<span id="cb22-32"><a href="#cb22-32" aria-hidden="true" tabindex="-1"></a><span class="co">            Input tensor (batch_size, seq_len, d_model)</span></span>
<span id="cb22-33"><a href="#cb22-33" aria-hidden="true" tabindex="-1"></a><span class="co">            </span></span>
<span id="cb22-34"><a href="#cb22-34" aria-hidden="true" tabindex="-1"></a><span class="co">        Returns:</span></span>
<span id="cb22-35"><a href="#cb22-35" aria-hidden="true" tabindex="-1"></a><span class="co">        --------</span></span>
<span id="cb22-36"><a href="#cb22-36" aria-hidden="true" tabindex="-1"></a><span class="co">        torch.Tensor</span></span>
<span id="cb22-37"><a href="#cb22-37" aria-hidden="true" tabindex="-1"></a><span class="co">            Hierarchically processed output</span></span>
<span id="cb22-38"><a href="#cb22-38" aria-hidden="true" tabindex="-1"></a><span class="co">        """</span></span>
<span id="cb22-39"><a href="#cb22-39" aria-hidden="true" tabindex="-1"></a>        local_features <span class="op">=</span> x</span>
<span id="cb22-40"><a href="#cb22-40" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb22-41"><a href="#cb22-41" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Process at local level</span></span>
<span id="cb22-42"><a href="#cb22-42" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> layer <span class="kw">in</span> <span class="va">self</span>.local_mamba:</span>
<span id="cb22-43"><a href="#cb22-43" aria-hidden="true" tabindex="-1"></a>            local_features <span class="op">=</span> layer(local_features)</span>
<span id="cb22-44"><a href="#cb22-44" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb22-45"><a href="#cb22-45" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Global processing (with downsampling)</span></span>
<span id="cb22-46"><a href="#cb22-46" aria-hidden="true" tabindex="-1"></a>        global_features <span class="op">=</span> local_features[:, ::<span class="dv">4</span>, :]  <span class="co"># Sample every 4th token</span></span>
<span id="cb22-47"><a href="#cb22-47" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb22-48"><a href="#cb22-48" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> layer <span class="kw">in</span> <span class="va">self</span>.global_mamba:</span>
<span id="cb22-49"><a href="#cb22-49" aria-hidden="true" tabindex="-1"></a>            global_features <span class="op">=</span> layer(global_features)</span>
<span id="cb22-50"><a href="#cb22-50" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb22-51"><a href="#cb22-51" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Cross-hierarchy attention: every local token queries the downsampled global tokens (no causal mask)</span></span>
<span id="cb22-52"><a href="#cb22-52" aria-hidden="true" tabindex="-1"></a>        enhanced_local, _ <span class="op">=</span> <span class="va">self</span>.cross_hierarchy[<span class="dv">0</span>](</span>
<span id="cb22-53"><a href="#cb22-53" aria-hidden="true" tabindex="-1"></a>            local_features, global_features, global_features</span>
<span id="cb22-54"><a href="#cb22-54" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb22-55"><a href="#cb22-55" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb22-56"><a href="#cb22-56" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> enhanced_local <span class="op">+</span> local_features</span></code></pre></div></div>
</div>
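<p>The shape bookkeeping of the cross-hierarchy step can be checked without any Mamba layers. The sketch below is a simplified single-head version under stated assumptions: it uses NumPy, replaces the learned query/key/value projections of <code>nn.MultiheadAttention</code> with identity projections, and uses a hypothetical function name (<code>cross_scale_attention</code>). Every token of the <code>(batch, seq_len, d_model)</code> input acts as a query, while only every fourth token is kept as a key and value. The attention map therefore has one column per downsampled token (<code>seq_len / 4</code>, rounded up), and the residual output keeps the input shape, as in <code>HierarchicalMamba.forward</code>.</p>

```python
import numpy as np


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)   # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def cross_scale_attention(local, stride=4):
    """Hypothetical single-head local-to-global attention with identity projections.

    local: array of shape (batch, seq_len, d_model), batch dimension first.
    Returns (output, attention weights).
    """
    global_feats = local[:, ::stride, :]                     # (batch, n_global, d_model)
    d_model = local.shape[-1]
    scores = local @ global_feats.transpose(0, 2, 1) / np.sqrt(d_model)
    attn = softmax(scores, axis=-1)                           # (batch, seq_len, n_global)
    enhanced = attn @ global_feats                            # (batch, seq_len, d_model)
    return enhanced + local, attn                             # residual connection


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 10, 8))
out, attn = cross_scale_attention(x)
# 10 tokens with stride 4 keep indices 0, 4, 8, so attn has shape (2, 10, 3)
```

<p>Keeping the batch dimension first here matches the <code>(batch_size, seq_len, d_model)</code> layout documented in the class, which is also why the PyTorch attention layer needs to be told its inputs are batch-first.</p>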
</section>
</section>
<section id="conclusion-and-future-directions" class="level2">
<h2 class="anchored" data-anchor-id="conclusion-and-future-directions" id="conclusion-and-future-directions">Conclusion and Future Directions</h2>
<p>This guide has covered the implementation and practical applications of Mamba transformers, from fundamental concepts to advanced optimization techniques. The subsections below summarize Mamba’s key advantages, implementation considerations, open research directions, and practical applications.</p>
<section id="key-advantages-1" class="level3">
<h3 class="anchored" data-anchor-id="key-advantages-1" id="key-advantages-1">Key Advantages</h3>
<ol type="1">
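<li><p><strong>Linear Complexity</strong>: Mamba’s cost grows as <span class="math inline">\(O(L)\)</span> in the sequence length <span class="math inline">\(L\)</span>, compared to <span class="math inline">\(O(L^2)\)</span> for self-attention in traditional transformers, and autoregressive inference carries a fixed-size recurrent state rather than a key-value cache that grows with <span class="math inline">\(L\)</span>.</p></li>
<li><p><strong>Selective Mechanism</strong>: Because the step size <span class="math inline">\(\Delta\)</span> and the projections <span class="math inline">\(B\)</span> and <span class="math inline">\(C\)</span> are computed from the input, the model can decide, token by token, what to write into its state and what to let decay, instead of applying the same fixed dynamics to every token.</p></li>
<li><p><strong>Hardware Efficiency</strong>: The hardware-aware selective scan fuses discretization and the recurrence into a single GPU kernel that keeps the expanded state in fast on-chip SRAM rather than materializing it in high-bandwidth memory (HBM), and it recomputes intermediate states in the backward pass instead of storing them, reducing both memory traffic and activation memory.</p></li>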
<li><p><strong>Scalability</strong>: The linear scaling properties enable processing of much longer contexts than traditional attention-based models.</p></li>
</ol>
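<p>A back-of-the-envelope count makes the linear-versus-quadratic contrast concrete. The sketch below keeps only dominant terms and ignores constants and the input/output projections (which are linear in <span class="math inline">\(L\)</span> for both models): self-attention forms an <span class="math inline">\(L \times L\)</span> score matrix over <span class="math inline">\(d\)</span>-dimensional vectors, roughly <span class="math inline">\(L^2 d\)</span> operations, while the selective scan performs one update per token, channel, and state dimension, roughly <span class="math inline">\(L d N\)</span>. The helper names and the sizes <span class="math inline">\(d = 1024\)</span>, <span class="math inline">\(N = 16\)</span> are illustrative assumptions, not measurements.</p>

```python
def attention_ops(seq_len, d_model):
    """Hypothetical proxy: dominant term of self-attention (L x L scores over d dims)."""
    return seq_len ** 2 * d_model


def selective_scan_ops(seq_len, d_model, d_state):
    """Hypothetical proxy: one state update per token, channel and state dimension."""
    return seq_len * d_model * d_state


def ratio(seq_len, d_model=1024, d_state=16):
    return attention_ops(seq_len, d_model) / selective_scan_ops(seq_len, d_model, d_state)


for L in (1_024, 8_192, 65_536):
    print(f"L={L:>6}: attention/scan ~ {ratio(L):,.0f}x")
# The ratio reduces to L / N, so it grows linearly with sequence length:
# 64x at L=1024, 512x at L=8192, 4096x at L=65536.
```

<p>Doubling <span class="math inline">\(L\)</span> quadruples the attention proxy but only doubles the scan proxy; real speedups also depend heavily on kernels and memory bandwidth.</p>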
</section>
<section id="implementation-considerations" class="level3">
<h3 class="anchored" data-anchor-id="implementation-considerations" id="implementation-considerations">Implementation Considerations</h3>
<ul>
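<li><strong>State Space Modeling</strong>: The core selective scan needs careful implementation for numerical stability; a common choice is to parameterize <span class="math inline">\(A = -\exp(A_{\log})\)</span> so its entries stay negative and the discretized decay <span class="math inline">\(\exp(\Delta A)\)</span> stays between 0 and 1</li>
<li><strong>Memory Optimization</strong>: Gradient checkpointing and mixed-precision training are essential when training large models</li>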
<li><strong>Custom Kernels</strong>: Production deployments benefit significantly from optimized CUDA implementations</li>
</ul>
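<p>To illustrate the stability point, here is a minimal sequential selective scan for a single channel in NumPy. It is a sketch, and <code>selective_scan</code> is a hypothetical name. It assumes the parameterization <span class="math inline">\(A = -\exp(A_{\log})\)</span> (as in the reference Mamba implementation), zero-order-hold discretization of <span class="math inline">\(A\)</span>, and the common simplification <span class="math inline">\(\bar{B} \approx \Delta B\)</span>. Because every entry of <span class="math inline">\(A\)</span> is negative and <span class="math inline">\(\Delta\)</span> is positive, each decay factor <span class="math inline">\(\exp(\Delta A)\)</span> lies strictly between 0 and 1, so the state stays bounded for bounded inputs even over very long sequences.</p>

```python
import numpy as np


def selective_scan(x, delta, A_log, B, C):
    """Hypothetical sequential selective scan for one channel (illustrative sketch).

    x:     (L,)   input sequence
    delta: (L,)   positive, input-dependent step sizes
    A_log: (N,)   unconstrained parameter; A = -exp(A_log) is always negative
    B, C:  (L, N) input-dependent input and output projections
    """
    A = -np.exp(A_log)
    h = np.zeros(A.shape[0])
    y = np.empty(len(x))
    for t in range(len(x)):
        A_bar = np.exp(delta[t] * A)      # zero-order hold: every entry in (0, 1)
        B_bar = delta[t] * B[t]           # simplified discretization of B
        h = A_bar * h + B_bar * x[t]      # h_t = A_bar * h_{t-1} + B_bar * x_t
        y[t] = C[t] @ h                   # y_t = C_t h_t
    return y


L, N = 10_000, 4
y = selective_scan(
    x=np.ones(L),
    delta=np.full(L, 0.1),
    A_log=np.zeros(N),        # A = -1 in every state dimension
    B=np.ones((L, N)),
    C=np.ones((L, N)),
)
# y stays finite and settles at N * 0.1 / (1 - exp(-0.1)), about 4.2
```

<p>This Python loop is only for clarity; it visits tokens one at a time and is far slower than a fused CUDA kernel.</p>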
</section>
<section id="future-research-directions" class="level3">
<h3 class="anchored" data-anchor-id="future-research-directions" id="future-research-directions">Future Research Directions</h3>
<ol type="1">
<li><strong>Theoretical Analysis</strong>: Deeper understanding of the selective mechanism’s theoretical properties</li>
<li><strong>Architecture Improvements</strong>: Exploring hybrid architectures combining Mamba with other sequence modeling approaches</li>
<li><strong>Multi-modal Applications</strong>: Extending Mamba to vision, audio, and other modalities</li>
<li><strong>Hardware Optimization</strong>: Developing specialized hardware accelerators for selective scan operations</li>
</ol>
</section>
<section id="practical-applications-1" class="level3">
<h3 class="anchored" data-anchor-id="practical-applications-1" id="practical-applications-1">Practical Applications</h3>
<p>Mamba shows particular promise for:</p>
<ul>
<li><strong>Long Document Processing</strong>: Technical documents, legal texts, and scientific papers</li>
<li><strong>Time Series Analysis</strong>: Financial data, sensor measurements, and sequential predictions</li>
<li><strong>Code Generation</strong>: Software development with large codebases and long contexts</li>
<li><strong>Conversational AI</strong>: Multi-turn dialogues with extended conversation history</li>
</ul>
<p>The Mamba architecture represents a significant advance in sequence modeling, offering a compelling alternative to attention-based transformers with better scaling in sequence length. As the field evolves, Mamba’s linear complexity and selective processing make it a strong candidate building block for next-generation language models and other sequential AI systems.</p>
</section>
</section>
<section id="references" class="level2">
<h2 class="anchored" data-anchor-id="references" id="references">References</h2>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb23"><pre class="sourceCode bibtex code-with-copy"><code class="sourceCode bibtex"><span id="cb23-1"><a href="#cb23-1" aria-hidden="true" tabindex="-1"></a><span class="va">@article</span>{<span class="ot">gu2023mamba</span>,</span>
<span id="cb23-2"><a href="#cb23-2" aria-hidden="true" tabindex="-1"></a>  <span class="dt">title</span>={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},</span>
<span id="cb23-3"><a href="#cb23-3" aria-hidden="true" tabindex="-1"></a>  <span class="dt">author</span>={Gu, Albert and Dao, Tri},</span>
<span id="cb23-4"><a href="#cb23-4" aria-hidden="true" tabindex="-1"></a>  <span class="dt">journal</span>={arXiv preprint arXiv:2312.00752},</span>
<span id="cb23-5"><a href="#cb23-5" aria-hidden="true" tabindex="-1"></a>  <span class="dt">year</span>={2023}</span>
<span id="cb23-6"><a href="#cb23-6" aria-hidden="true" tabindex="-1"></a>}</span>
<span id="cb23-7"><a href="#cb23-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb23-8"><a href="#cb23-8" aria-hidden="true" tabindex="-1"></a><span class="va">@article</span>{<span class="ot">gu2021efficiently</span>,</span>
<span id="cb23-9"><a href="#cb23-9" aria-hidden="true" tabindex="-1"></a>  <span class="dt">title</span>={Efficiently modeling long sequences with structured state spaces},</span>
<span id="cb23-10"><a href="#cb23-10" aria-hidden="true" tabindex="-1"></a>  <span class="dt">author</span>={Gu, Albert and Goel, Karan and R{<span class="ch">\'</span>e}, Christopher},</span>
<span id="cb23-11"><a href="#cb23-11" aria-hidden="true" tabindex="-1"></a>  <span class="dt">journal</span>={arXiv preprint arXiv:2111.00396},</span>
<span id="cb23-12"><a href="#cb23-12" aria-hidden="true" tabindex="-1"></a>  <span class="dt">year</span>={2021}</span>
<span id="cb23-13"><a href="#cb23-13" aria-hidden="true" tabindex="-1"></a>}</span></code></pre></div></div>



</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Mathematics Behind Mamba Transformers: A Complete Guide]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/models/mamba/mamba-math/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/models/mamba/mamba-math/</guid>
      <pubDate>Sat, 23 Aug 2025 00:00:00 GMT</pubDate>
      
      <category>research</category>
      <category>advanced</category>
      <content:encoded><![CDATA[






<section id="mathematics-behind-mamba-transformers-a-complete-guide" class="level1 page-columns page-full">

<p><img src="https://theja-vanka.github.io/blogs/posts/models/mamba/mamba-math/mamath.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>Mamba represents a breakthrough in sequence modeling that addresses the quadratic complexity limitation of traditional transformers. Built on State Space Models (SSMs), Mamba introduces a selective mechanism that allows the model to dynamically focus on relevant information while maintaining linear computational complexity with respect to sequence length.</p>
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Important
</div>
</div>
<div class="callout-body-container callout-body">
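<p>The key innovation lies in making the SSM parameters input-dependent (in Mamba, the step size <span class="math inline">\(\Delta\)</span> and the matrices <span class="math inline">\(B\)</span> and <span class="math inline">\(C\)</span>), creating a selective state space that processes long sequences efficiently while recovering much of the content-based reasoning that made attention-based transformers successful.</p>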
</div>
</div>
</section>
<section id="foundation-state-space-models" class="level2">
<h2 class="anchored" data-anchor-id="foundation-state-space-models" id="foundation-state-space-models">Foundation: State Space Models</h2>
<section id="continuous-state-space-models" class="level3">
<h3 class="anchored" data-anchor-id="continuous-state-space-models" id="continuous-state-space-models">Continuous State Space Models</h3>
<p>State Space Models originate from control theory and signal processing. In continuous time, they are defined by:</p>
<p><span id="eq-continuous-ssm"><span class="math display">\[
\begin{align}
h'(t) &amp;= Ah(t) + Bx(t) \quad \text{(state equation)} \\
y(t) &amp;= Ch(t) + Dx(t) \quad \text{(output equation)}
\end{align}
\tag{1}\]</span></span></p>
<p>Where:</p>
<ul>
<li><span class="math inline">\(h(t) \in \mathbb{R}^N\)</span> is the state vector at time <span class="math inline">\(t\)</span></li>
<li><span class="math inline">\(x(t) \in \mathbb{R}\)</span> is the input signal</li>
<li><span class="math inline">\(y(t) \in \mathbb{R}\)</span> is the output signal</li>
<li><span class="math inline">\(A \in \mathbb{R}^{N \times N}\)</span> is the state transition matrix</li>
<li><span class="math inline">\(B \in \mathbb{R}^N\)</span> is the input matrix</li>
<li><span class="math inline">\(C \in \mathbb{R}^{1 \times N}\)</span> is the output matrix</li>
<li><span class="math inline">\(D \in \mathbb{R}\)</span> is the feedthrough term (often set to 0)</li>
</ul>
</section>
<section id="the-hippo-framework" class="level3">
<h3 class="anchored" data-anchor-id="the-hippo-framework" id="the-hippo-framework">The HiPPO Framework</h3>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>HiPPO Framework
</div>
</div>
<div class="callout-body-container callout-body">
<p>The HiPPO (High-order Polynomial Projection Operators) framework provides a principled way to initialize the A matrix. The key insight is to maintain a polynomial approximation of the input history.</p>
</div>
</div>
<p>For the Legendre polynomials case (LegS):</p>
<ul>
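<li>In the stable sign convention, the A matrix has entries <span class="math inline">\(A_{nk} = -(2n+1)^{1/2}(2k+1)^{1/2}\)</span> if <span class="math inline">\(n &gt; k\)</span>, <span class="math inline">\(A_{nk} = -(n+1)\)</span> if <span class="math inline">\(n = k\)</span>, and <span class="math inline">\(A_{nk} = 0\)</span> if <span class="math inline">\(n &lt; k\)</span></li>
<li>With this choice, the state vector holds the coefficients of the best approximation of the entire input history by the first <span class="math inline">\(N\)</span> Legendre polynomials (under a uniform weighting of the past)</li>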
</ul>
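<p>A minimal NumPy sketch of the LegS matrix is below. It uses the stable sign convention (all nonzero entries negative) and sets every entry above the diagonal (<span class="math inline">\(n &lt; k\)</span>) to zero. The function name <code>hippo_legs</code> is ours for illustration, not a library API.</p>

```python
import numpy as np


def hippo_legs(N: int) -> np.ndarray:
    """Return the N x N HiPPO-LegS state matrix in the stable (negative) sign convention."""
    n = np.arange(N)
    q = np.sqrt(2 * n + 1)               # (2n+1)^{1/2} for n = 0..N-1
    A = np.tril(np.outer(q, q), k=-1)    # (2n+1)^{1/2}(2k+1)^{1/2} for n > k, 0 for n <= k
    A = A + np.diag(n + 1.0)             # diagonal entries n + 1
    return -A                            # negate so the continuous dynamics decay


A = hippo_legs(4)
print(np.round(A, 3))
```

<p>Because the matrix is lower triangular, its eigenvalues are its diagonal entries <span class="math inline">\(-1, \ldots, -N\)</span>, all negative, which is what makes the continuous system stable.</p>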
</section>
</section>
<section id="from-continuous-to-discrete" class="level2">
<h2 class="anchored" data-anchor-id="from-continuous-to-discrete" id="from-continuous-to-discrete">From Continuous to Discrete</h2>
<section id="discretization-process" class="level3">
<h3 class="anchored" data-anchor-id="discretization-process" id="discretization-process">Discretization Process</h3>
<p>To apply SSMs to discrete sequences, we discretize the continuous system with a step size <span class="math inline">\(\Delta\)</span>. The Zero-Order Hold (ZOH) discretization, which assumes the input is held constant within each step, gives:</p>
<p><span id="eq-discrete-ssm"><span class="math display">\[
\begin{align}
h_k &amp;= \bar{A}h_{k-1} + \bar{B}x_k \\
y_k &amp;= Ch_k
\end{align}
\tag{2}\]</span></span></p>
<p>Where:</p>
<p><span id="eq-discretization"><span class="math display">\[
\begin{align}
\bar{A} &amp;= \exp(\Delta A) \\
\bar{B} &amp;= (\Delta A)^{-1}(\exp(\Delta A) - I)\Delta B
\end{align}
\tag{3}\]</span></span></p>
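<p>The ZOH formulas can be checked with a small NumPy sketch, assuming <span class="math inline">\(\Delta A\)</span> is invertible. To stay dependency-free it computes the matrix exponential by scaling and squaring with a truncated Taylor series; in practice <code>scipy.linalg.expm</code> is the usual choice. The helper names are illustrative.</p>

```python
import numpy as np


def expm(M: np.ndarray, terms: int = 30) -> np.ndarray:
    """Matrix exponential via scaling and squaring with a truncated Taylor series."""
    norm = np.linalg.norm(M, 1)
    s = int(np.ceil(np.log2(norm))) + 1 if norm > 1 else 0
    Ms = M / 2.0 ** s                    # scale so the Taylor series converges quickly
    result = np.eye(M.shape[0])
    term = np.eye(M.shape[0])
    for k in range(1, terms):
        term = term @ Ms / k             # Ms^k / k!
        result = result + term
    for _ in range(s):                   # undo the scaling: exp(M) = exp(M / 2^s)^(2^s)
        result = result @ result
    return result


def discretize_zoh(A: np.ndarray, B: np.ndarray, delta: float):
    """Zero-order hold: A_bar = exp(ΔA), B_bar = (ΔA)^{-1}(exp(ΔA) - I) ΔB."""
    dA = delta * A
    A_bar = expm(dA)
    # Solve (ΔA) B_bar = (exp(ΔA) - I) ΔB instead of forming the inverse explicitly
    B_bar = np.linalg.solve(dA, (A_bar - np.eye(A.shape[0])) @ (delta * B))
    return A_bar, B_bar


A = np.array([[-1.0, 0.0], [0.5, -2.0]])
B = np.array([1.0, 1.0])
A_bar, B_bar = discretize_zoh(A, B, delta=0.1)
print(A_bar, B_bar)
```

<p>For a scalar system this reduces to <span class="math inline">\(\bar{A} = e^{\Delta a}\)</span> and <span class="math inline">\(\bar{B} = (e^{\Delta a} - 1)\,b / a\)</span>, which is a convenient sanity check.</p>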
</section>
<section id="computational-forms" class="level3">
<h3 class="anchored" data-anchor-id="computational-forms" id="computational-forms">Computational Forms</h3>
<div class="tabset-margin-container"></div><div class="panel-tabset">
<ul class="nav nav-tabs" role="tablist"><li class="nav-item" role="presentation"><a class="nav-link active" id="tabset-1-1-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-1" role="tab" aria-controls="tabset-1-1" aria-selected="true" href="">Recurrent Form</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-1-2-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-2" role="tab" aria-controls="tabset-1-2" aria-selected="false" href="">Convolution Form</a></li></ul>
<div class="tab-content">
<div id="tabset-1-1" class="tab-pane active" role="tabpanel" aria-labelledby="tabset-1-1-tab">
<p><strong>For generation:</strong> <span class="math display">\[
\begin{align}
h_k &amp;= \bar{A}h_{k-1} + \bar{B}x_k \\
y_k &amp;= Ch_k
\end{align}
\]</span></p>
</div>
<div id="tabset-1-2" class="tab-pane" role="tabpanel" aria-labelledby="tabset-1-2-tab">
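<p><strong>For training:</strong> When <span class="math inline">\(\bar{A}\)</span>, <span class="math inline">\(\bar{B}\)</span>, and <span class="math inline">\(C\)</span> are the same at every step (a linear time-invariant system, as in S4), the SSM can be viewed as a convolution with kernel <span class="math inline">\(K\)</span>: <span class="math display">\[
K = (C\bar{B}, C\bar{A}\bar{B}, C\bar{A}^2\bar{B}, \ldots, C\bar{A}^{L-1}\bar{B})
\]</span> <span class="math display">\[
y = K * x
\]</span> Where <span class="math inline">\(*\)</span> denotes causal convolution and <span class="math inline">\(L\)</span> is the sequence length. Mamba’s input-dependent parameters break time invariance, so this form no longer applies and Mamba trains with a parallel scan instead.</p>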
</div>
</div>
</div>
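<p>For a time-invariant SSM the two forms give identical outputs. A small NumPy check with random <span class="math inline">\(\bar{A}\)</span>, <span class="math inline">\(\bar{B}\)</span>, <span class="math inline">\(C\)</span> (illustrative values, not trained parameters):</p>

```python
import numpy as np


def ssm_recurrent(A_bar, B_bar, C, x):
    """Run h_k = A_bar h_{k-1} + B_bar x_k, y_k = C h_k over a scalar input sequence."""
    h = np.zeros(A_bar.shape[0])
    ys = []
    for x_k in x:
        h = A_bar @ h + B_bar * x_k
        ys.append(C @ h)
    return np.array(ys)


def ssm_convolution(A_bar, B_bar, C, x):
    """Build K = (C B_bar, C A_bar B_bar, ..., C A_bar^{L-1} B_bar) and convolve causally."""
    L = len(x)
    K = np.empty(L)
    v = B_bar.copy()
    for i in range(L):
        K[i] = C @ v                     # C A_bar^i B_bar
        v = A_bar @ v
    return np.convolve(x, K)[:L]         # y_k = sum_{i=0}^{k} K[i] x[k - i]


rng = np.random.default_rng(0)
N, L = 4, 10
A_bar = 0.9 * np.eye(N) + 0.05 * rng.standard_normal((N, N))
B_bar = rng.standard_normal(N)
C = rng.standard_normal(N)
x = rng.standard_normal(L)
print(np.allclose(ssm_recurrent(A_bar, B_bar, C, x), ssm_convolution(A_bar, B_bar, C, x)))  # True
```

<p>The recurrent form costs a constant amount of work per generated token, while the convolutional form exposes the whole sequence at once, which is why S4-style models use it for training.</p>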
</section>
</section>
<section id="the-selection-mechanism" class="level2">
<h2 class="anchored" data-anchor-id="the-selection-mechanism" id="the-selection-mechanism">The Selection Mechanism</h2>
<section id="the-core-innovation" class="level3">
<h3 class="anchored" data-anchor-id="the-core-innovation" id="the-core-innovation">The Core Innovation</h3>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Key Innovation
</div>
</div>
<div class="callout-body-container callout-body">
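<p>Traditional SSMs use fixed parameters <span class="math inline">\(A\)</span>, <span class="math inline">\(B\)</span>, <span class="math inline">\(C\)</span>, and <span class="math inline">\(\Delta\)</span>. Mamba’s key innovation is making <span class="math inline">\(B\)</span>, <span class="math inline">\(C\)</span>, and <span class="math inline">\(\Delta\)</span> functions of the input. <span class="math inline">\(A\)</span> itself stays input-independent, but <span class="math inline">\(\bar{A} = \exp(\Delta A)\)</span> still varies with the input through <span class="math inline">\(\Delta\)</span>.</p>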
</div>
</div>
<p><span id="eq-selection"><span class="math display">\[
\begin{align}
B &amp;= s_B(x) \\
C &amp;= s_C(x) \\
\Delta &amp;= \tau(s_\Delta(x))
\end{align}
\tag{4}\]</span></span></p>
<p>Where:</p>
<ul>
<li><span class="math inline">\(s_B\)</span>, <span class="math inline">\(s_C\)</span>, <span class="math inline">\(s_\Delta\)</span> are learnable projection functions</li>
<li><span class="math inline">\(\tau\)</span> is typically the softplus function: <span class="math inline">\(\tau(x) = \log(1 + \exp(x))\)</span></li>
</ul>
</section>
<section id="selection-functions" class="level3">
<h3 class="anchored" data-anchor-id="selection-functions" id="selection-functions">Selection Functions</h3>
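<p>The selection functions are implemented as linear projections:</p>
<p><span id="eq-selection-functions"><span class="math display">\[
\begin{align}
s_B(x) &amp;= \text{Linear}_B(x) \quad \in \mathbb{R}^{B \times L \times N} \\
s_C(x) &amp;= \text{Linear}_C(x) \quad \in \mathbb{R}^{B \times L \times N} \\
s_\Delta(x) &amp;= \text{Broadcast}_D(\text{Linear}_\Delta(x)) \quad \in \mathbb{R}^{B \times L \times D}
\end{align}
\tag{5}\]</span></span></p>
<p>Where, in the shapes, <span class="math inline">\(B\)</span> is the batch size (not the input matrix), <span class="math inline">\(L\)</span> is the sequence length, <span class="math inline">\(D\)</span> is the number of channels, and <span class="math inline">\(N\)</span> is the state dimension. <span class="math inline">\(\text{Linear}_B\)</span> and <span class="math inline">\(\text{Linear}_C\)</span> map each token to <span class="math inline">\(N\)</span> features shared by all channels, while <span class="math inline">\(\text{Linear}_\Delta\)</span> maps each token to a single value that is broadcast to all <span class="math inline">\(D\)</span> channels.</p>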
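<p>A shape-level NumPy sketch of these projections follows. The weights are random placeholders, the per-channel bias added before the softplus is an assumption (it gives each channel its own baseline step size), and <code>Bsz</code> stands for the batch size to avoid a clash with the matrix <span class="math inline">\(B\)</span>.</p>

```python
import numpy as np


def softplus(z):
    return np.logaddexp(0.0, z)          # log(1 + exp(z)) without overflow


rng = np.random.default_rng(0)
Bsz, L, D, N = 2, 5, 8, 4                # batch size, sequence length, channels, state size

# Hypothetical projection weights (random here, learned in a real model)
W_B = rng.standard_normal((D, N)) / np.sqrt(D)       # Linear_B: D -> N
W_C = rng.standard_normal((D, N)) / np.sqrt(D)       # Linear_C: D -> N
w_delta = rng.standard_normal((D, 1)) / np.sqrt(D)   # Linear_Delta: D -> 1
delta_bias = np.linspace(-4.0, -1.0, D)              # assumed per-channel bias before softplus

x = rng.standard_normal((Bsz, L, D))
B = x @ W_B                                          # s_B(x): (Bsz, L, N)
C = x @ W_C                                          # s_C(x): (Bsz, L, N)
s_delta = np.broadcast_to(x @ w_delta, (Bsz, L, D))  # Broadcast_D(Linear_Delta(x)): (Bsz, L, D)
delta = softplus(delta_bias + s_delta)               # tau(...) > 0: (Bsz, L, D)

print(B.shape, C.shape, delta.shape)  # (2, 5, 4) (2, 5, 4) (2, 5, 8)
```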
</section>
<section id="mathematical-justification" class="level3">
<h3 class="anchored" data-anchor-id="mathematical-justification" id="mathematical-justification">Mathematical Justification</h3>
<p>The selection mechanism allows the model to:</p>
<ol type="1">
<li><strong>Filter irrelevant information</strong>: By modulating <span class="math inline">\(B\)</span>, the model controls what information enters the state</li>
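<li><strong>Focus on specific aspects</strong>: By modulating <span class="math inline">\(C\)</span>, the model controls what information is output</li>
<li><strong>Control information flow</strong>: By modulating <span class="math inline">\(\Delta\)</span>, the model controls the rate of state updates: a small <span class="math inline">\(\Delta\)</span> preserves the state and largely ignores the current input, while a large <span class="math inline">\(\Delta\)</span> resets the state toward the current input</li>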
</ol>
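<p>To make the role of <span class="math inline">\(\Delta\)</span> concrete, take one scalar channel with <span class="math inline">\(A = -1\)</span> and <span class="math inline">\(B = 1\)</span> (illustrative values). Under ZOH, <span class="math inline">\(\bar{B} = 1 - e^{-\Delta}\)</span>, so the update is the convex combination <span class="math inline">\(h_k = e^{-\Delta} h_{k-1} + (1 - e^{-\Delta}) x_k\)</span>. A minimal sketch:</p>

```python
import numpy as np


def zoh_scalar_step(h_prev, x_k, delta, a=-1.0, b=1.0):
    """One ZOH-discretized step of a scalar SSM channel: h_k = A_bar h_{k-1} + B_bar x_k."""
    A_bar = np.exp(delta * a)
    B_bar = (A_bar - 1.0) / a * b        # (ΔA)^{-1}(exp(ΔA) - 1) ΔB for scalars
    return A_bar * h_prev + B_bar * x_k


h_prev, x_k = 5.0, -3.0
for delta in (0.01, 1.0, 10.0):
    # small delta keeps h near 5 (input ignored); large delta drives h to about -3 (state reset)
    print(delta, float(zoh_scalar_step(h_prev, x_k, delta)))
```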
</section>
</section>
<section id="mamba-block-architecture" class="level2">
<h2 class="anchored" data-anchor-id="mamba-block-architecture" id="mamba-block-architecture">Mamba Block Architecture</h2>
<section id="complete-block-definition" class="level3">
<h3 class="anchored" data-anchor-id="complete-block-definition" id="complete-block-definition">Complete Block Definition</h3>
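<p>A Mamba block processes input <span class="math inline">\(x \in \mathbb{R}^{B \times L \times D}\)</span> as follows, where <span class="math inline">\(E\)</span> is the expanded inner dimension (<span class="math inline">\(E = 2D\)</span> in the default configuration):</p>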
<div id="e2f46e33" class="cell" data-execution_count="1">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb1"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Pseudocode for one Mamba block (E = expanded inner dimension)</span></span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> mamba_block(x):  <span class="co"># x ∈ R^(B×L×D)</span></span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a>    <span class="co"># 1. Input projection, split into an SSM branch and a gate branch</span></span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a>    x_prime <span class="op">=</span> Linear_in(x)  <span class="co"># ∈ R^(B×L×2E)</span></span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a>    x1, x2 <span class="op">=</span> split(x_prime)  <span class="co"># each ∈ R^(B×L×E)</span></span>
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a>    <span class="co"># 2. Causal depthwise convolution along the sequence, then SiLU</span></span>
<span id="cb1-9"><a href="#cb1-9" aria-hidden="true" tabindex="-1"></a>    x1 <span class="op">=</span> SiLU(CausalConv1d(x1))  <span class="co"># ∈ R^(B×L×E)</span></span>
<span id="cb1-10"><a href="#cb1-10" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb1-11"><a href="#cb1-11" aria-hidden="true" tabindex="-1"></a>    <span class="co"># 3. Selection parameters (functions of the input)</span></span>
<span id="cb1-12"><a href="#cb1-12" aria-hidden="true" tabindex="-1"></a>    B <span class="op">=</span> s_B(x1)  <span class="co"># ∈ R^(B×L×N)</span></span>
<span id="cb1-13"><a href="#cb1-13" aria-hidden="true" tabindex="-1"></a>    C <span class="op">=</span> s_C(x1)  <span class="co"># ∈ R^(B×L×N)</span></span>
<span id="cb1-14"><a href="#cb1-14" aria-hidden="true" tabindex="-1"></a>    Delta <span class="op">=</span> softplus(s_Delta(x1))  <span class="co"># ∈ R^(B×L×E), one step size per channel</span></span>
<span id="cb1-15"><a href="#cb1-15" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb1-16"><a href="#cb1-16" aria-hidden="true" tabindex="-1"></a>    <span class="co"># 4. Discretization (A ∈ R^(E×N) is input-independent and diagonal per channel)</span></span>
<span id="cb1-17"><a href="#cb1-17" aria-hidden="true" tabindex="-1"></a>    A_bar <span class="op">=</span> exp(Delta[..., <span class="va">None</span>] <span class="op">*</span> A)  <span class="co"># ∈ R^(B×L×E×N)</span></span>
<span id="cb1-18"><a href="#cb1-18" aria-hidden="true" tabindex="-1"></a>    B_bar <span class="op">=</span> Delta[..., <span class="va">None</span>] <span class="op">*</span> B[:, :, <span class="va">None</span>, :]  <span class="co"># first-order approximation of ZOH, ∈ R^(B×L×E×N)</span></span>
<span id="cb1-19"><a href="#cb1-19" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb1-20"><a href="#cb1-20" aria-hidden="true" tabindex="-1"></a>    <span class="co"># 5. Selective scan over the sequence</span></span>
<span id="cb1-21"><a href="#cb1-21" aria-hidden="true" tabindex="-1"></a>    y1 <span class="op">=</span> selective_scan(A_bar, B_bar, C, x1)  <span class="co"># ∈ R^(B×L×E)</span></span>
<span id="cb1-22"><a href="#cb1-22" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb1-23"><a href="#cb1-23" aria-hidden="true" tabindex="-1"></a>    <span class="co"># 6. Gating and output projection</span></span>
<span id="cb1-24"><a href="#cb1-24" aria-hidden="true" tabindex="-1"></a>    y <span class="op">=</span> y1 <span class="op">*</span> SiLU(x2)</span>
<span id="cb1-25"><a href="#cb1-25" aria-hidden="true" tabindex="-1"></a>    output <span class="op">=</span> Linear_out(y)  <span class="co"># ∈ R^(B×L×D)</span></span>
<span id="cb1-26"><a href="#cb1-26" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb1-27"><a href="#cb1-27" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> output</span></code></pre></div></div>
</div>
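<p>The pseudocode above can be turned into a small runnable NumPy sketch for intuition. The simplifications are ours, not the reference implementation: random weights from a hypothetical <code>init_params</code> helper, a full-rank <span class="math inline">\(\Delta\)</span> projection, the first-order approximation <span class="math inline">\(\bar{B} \approx \Delta B\)</span>, no skip term, and a plain Python loop instead of a fused parallel scan.</p>

```python
import numpy as np


def silu(z):
    return z / (1.0 + np.exp(-z))


def softplus(z):
    return np.logaddexp(0.0, z)


def causal_depthwise_conv(x, w):
    """x: (Bsz, L, E), w: (K, E). Output position t mixes x[t-K+1 .. t] per channel."""
    K, L = w.shape[0], x.shape[1]
    x_pad = np.pad(x, ((0, 0), (K - 1, 0), (0, 0)))      # left padding keeps it causal
    return sum(w[i] * x_pad[:, i:i + L, :] for i in range(K))


def selective_scan(A_bar, B_bar, C, x):
    """A_bar, B_bar: (Bsz, L, E, N); C: (Bsz, L, N); x: (Bsz, L, E) -> y: (Bsz, L, E)."""
    Bsz, L, E, N = A_bar.shape
    h = np.zeros((Bsz, E, N))
    ys = []
    for t in range(L):                                    # sequential loop, not a fused scan
        h = A_bar[:, t] * h + B_bar[:, t] * x[:, t, :, None]
        ys.append(np.einsum("ben,bn->be", h, C[:, t]))   # y_t = C_t^T h_t per channel
    return np.stack(ys, axis=1)


def mamba_block(x, p):
    """x: (Bsz, L, D); p: parameter dict from the hypothetical init_params below."""
    E = p["W_in"].shape[1] // 2
    x1, x2 = np.split(x @ p["W_in"], [E], axis=-1)          # each (Bsz, L, E)
    x1 = silu(causal_depthwise_conv(x1, p["conv_w"]))        # (Bsz, L, E)
    B = x1 @ p["W_B"]                                        # (Bsz, L, N)
    C = x1 @ p["W_C"]                                        # (Bsz, L, N)
    delta = softplus(x1 @ p["W_delta"] + p["delta_bias"])    # (Bsz, L, E)
    A = -np.exp(p["A_log"])                                  # (E, N), always negative
    A_bar = np.exp(delta[..., None] * A)                     # (Bsz, L, E, N)
    B_bar = delta[..., None] * B[:, :, None, :]              # first-order: B_bar ≈ ΔB
    y1 = selective_scan(A_bar, B_bar, C, x1)                 # (Bsz, L, E)
    return (y1 * silu(x2)) @ p["W_out"]                      # (Bsz, L, D)


def init_params(D, N, expand=2, K=4, seed=0):
    rng = np.random.default_rng(seed)
    E = expand * D
    return {
        "W_in": rng.standard_normal((D, 2 * E)) / np.sqrt(D),
        "conv_w": rng.standard_normal((K, E)) / np.sqrt(K),
        "W_B": rng.standard_normal((E, N)) / np.sqrt(E),
        "W_C": rng.standard_normal((E, N)) / np.sqrt(E),
        "W_delta": 0.1 * rng.standard_normal((E, E)) / np.sqrt(E),  # full rank for simplicity
        "delta_bias": np.full(E, -3.0),
        "A_log": np.log(np.tile(np.arange(1, N + 1, dtype=float), (E, 1))),
        "W_out": rng.standard_normal((E, D)) / np.sqrt(E),
    }


x = np.random.default_rng(1).standard_normal((2, 16, 8))
out = mamba_block(x, init_params(D=8, N=4))
print(out.shape)  # (2, 16, 8)
```

<p>Every step is either per-position or causal along the sequence, so changing a later token never changes earlier outputs, which is a quick property to test when writing your own version.</p>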
</section>
</section>
<section id="mathematical-formulations" class="level2">
<h2 class="anchored" data-anchor-id="mathematical-formulations" id="mathematical-formulations">Mathematical Formulations</h2>
<section id="selective-scan-algorithm" class="level3">
<h3 class="anchored" data-anchor-id="selective-scan-algorithm" id="selective-scan-algorithm">Selective Scan Algorithm</h3>
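<p>The core SSM computation for a sequence of length <span class="math inline">\(L\)</span>, run independently for each channel:</p>
<p><span id="eq-selective-scan"><span class="math display">\[
\begin{align}
h_0 &amp;= 0 \\
h_k &amp;= \bar{A}_k \odot h_{k-1} + \bar{B}_k x_k, \qquad k = 1, \ldots, L \\
y_k &amp;= C_k^\top h_k
\end{align}
\tag{6}\]</span></span></p>
<p>Where <span class="math inline">\(\odot\)</span> denotes element-wise multiplication: each <span class="math inline">\(\bar{A}_k\)</span> is diagonal and stored as a vector in <span class="math inline">\(\mathbb{R}^N\)</span>, <span class="math inline">\(x_k\)</span> is a scalar, and <span class="math inline">\(y_k\)</span> is the inner product of <span class="math inline">\(C_k\)</span> with the state.</p>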
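<p>A sequential reference implementation of this recurrence for a single channel, as a minimal NumPy sketch (the function name is ours). The tiny example at the end can be checked by hand.</p>

```python
import numpy as np


def selective_scan_ref(A_bar, B_bar, C, x):
    """Sequential reference scan for one channel.

    A_bar, B_bar, C: (L, N) per-step parameters (A_bar_k diagonal, stored as a vector)
    x: (L,) scalar inputs. Returns y: (L,).
    """
    L, N = A_bar.shape
    h = np.zeros(N)                          # h_0 = 0
    y = np.empty(L)
    for k in range(L):
        h = A_bar[k] * h + B_bar[k] * x[k]   # element-wise state update
        y[k] = C[k] @ h                      # y_k = C_k^T h_k
    return y


# Hand-checkable example with N = 1:
# h_1 = 2, y_1 = 1 * 2 = 2;  h_2 = 0.5 * 2 + 4 = 5, y_2 = 2 * 5 = 10
y = selective_scan_ref(
    A_bar=np.array([[0.5], [0.5]]),
    B_bar=np.array([[1.0], [1.0]]),
    C=np.array([[1.0], [2.0]]),
    x=np.array([2.0, 4.0]),
)
print(y)  # [ 2. 10.]
```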
</section>
<section id="parallel-scan-formulation" class="level3">
<h3 class="anchored" data-anchor-id="parallel-scan-formulation" id="parallel-scan-formulation">Parallel Scan Formulation</h3>
<p>For parallel computation, we can express the recurrence as:</p>
<p><span id="eq-parallel-scan"><span class="math display">\[
h_k = \left(\prod_{i=1}^k \bar{A}_i\right) \odot h_0 + \sum_{j=1}^k \left(\prod_{i=j+1}^k \bar{A}_i\right) \odot (\bar{B}_j \odot x_j)
\tag{7}\]</span></span></p>
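<p>Since <span class="math inline">\(h_0 = 0\)</span>, the first term vanishes. Because each step is an affine map <span class="math inline">\(h \mapsto \bar{A}_k \odot h + \bar{B}_k x_k\)</span> and composing affine maps is associative, all <span class="math inline">\(h_k\)</span> can be computed with a parallel prefix scan (the same algorithm family as parallel prefix sums) in <span class="math inline">\(O(\log L)\)</span> sequential steps rather than <span class="math inline">\(L\)</span>.</p>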
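<p>The closed form of Equation&nbsp;7 can be checked against the step-by-step recurrence. A small NumPy sketch with random per-step parameters (illustrative values); note the empty product when <span class="math inline">\(j = k\)</span> equals 1.</p>

```python
import numpy as np


def state_sequential(A_bar, B_bar, x):
    """h_1..h_L from h_k = A_bar_k * h_{k-1} + B_bar_k * x_k with h_0 = 0."""
    h = np.zeros(A_bar.shape[1])
    states = []
    for k in range(A_bar.shape[0]):
        h = A_bar[k] * h + B_bar[k] * x[k]
        states.append(h)
    return np.array(states)


def state_unrolled(A_bar, B_bar, x, k):
    """h_k (1-indexed) from the unrolled sum with h_0 = 0."""
    h = np.zeros(A_bar.shape[1])
    for j in range(1, k + 1):
        decay = np.prod(A_bar[j:k], axis=0)  # prod_{i=j+1}^{k} A_bar_i (0-based rows j..k-1)
        h = h + decay * B_bar[j - 1] * x[j - 1]
    return h


rng = np.random.default_rng(0)
L, N = 6, 3
A_bar = rng.uniform(0.5, 1.0, (L, N))
B_bar = rng.standard_normal((L, N))
x = rng.standard_normal(L)
H = state_sequential(A_bar, B_bar, x)
print(all(np.allclose(H[k - 1], state_unrolled(A_bar, B_bar, x, k)) for k in range(1, L + 1)))  # True
```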
</section>
<section id="matrix-form" class="level3">
<h3 class="anchored" data-anchor-id="matrix-form" id="matrix-form">Matrix Form</h3>
<p>The complete transformation can be written as:</p>
<p><span id="eq-matrix-form"><span class="math display">\[
Y = \text{SSM}(X; A, B, C, \Delta)
\tag{8}\]</span></span></p>
<p>Where each element is:</p>
<p><span id="eq-element-wise"><span class="math display">\[
Y[b,l,d] = \sum_{k=1}^l \sum_{n=1}^N C[b,l,n] \cdot \left(\prod_{j=k+1}^l \bar{A}[b,j,d,n]\right) \cdot \bar{B}[b,k,d,n] \cdot X[b,k,d]
\tag{9}\]</span></span></p>
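<p>For a fixed batch index <span class="math inline">\(b\)</span> and channel <span class="math inline">\(d\)</span>, Equation&nbsp;9 says the output is a matrix-vector product <span class="math inline">\(y = Mx\)</span> with a lower-triangular (causal) <span class="math inline">\(L \times L\)</span> matrix <span class="math inline">\(M\)</span>. A NumPy sketch that builds <span class="math inline">\(M\)</span> explicitly, using 0-based indices and random illustrative values, and compares it with the recurrence:</p>

```python
import numpy as np


def ssm_matrix(A_bar, B_bar, C):
    """L x L matrix M with y = M @ x for one (b, d) slice; A_bar, B_bar, C: (L, N).

    Entry M[l, k] = sum_n C[l, n] * prod_{j=k+1}^{l} A_bar[j, n] * B_bar[k, n] for k <= l.
    """
    L, N = A_bar.shape
    M = np.zeros((L, L))
    for l in range(L):
        for k in range(l + 1):
            decay = np.prod(A_bar[k + 1:l + 1], axis=0)  # empty product = 1 when k == l
            M[l, k] = np.sum(C[l] * decay * B_bar[k])
    return M


def scan(A_bar, B_bar, C, x):
    h, y = np.zeros(A_bar.shape[1]), np.empty(len(x))
    for k in range(len(x)):
        h = A_bar[k] * h + B_bar[k] * x[k]
        y[k] = C[k] @ h
    return y


rng = np.random.default_rng(0)
L, N = 5, 3
A_bar = rng.uniform(0.5, 1.0, (L, N))
B_bar, C = rng.standard_normal((L, N)), rng.standard_normal((L, N))
x = rng.standard_normal(L)
M = ssm_matrix(A_bar, B_bar, C)
print(np.allclose(M @ x, scan(A_bar, B_bar, C, x)))  # True
```

<p>Materializing <span class="math inline">\(M\)</span> costs <span class="math inline">\(O(L^2)\)</span>, so this view is useful for reasoning about causality, while the scan is what keeps the cost linear.</p>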
</section>
</section>
<section id="computational-efficiency" class="level2 page-columns page-full">
<h2 class="anchored" data-anchor-id="computational-efficiency" id="computational-efficiency">Computational Efficiency</h2>
<section id="complexity-analysis" class="level3 page-columns page-full">
<h3 class="anchored" data-anchor-id="complexity-analysis" id="complexity-analysis">Complexity Analysis</h3>

<div class="no-row-height column-margin column-container"><div class="">
<p>The linear scaling enables processing of very long sequences that would be prohibitive for standard transformers.</p>
</div></div><div id="tbl-complexity" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-tbl figure">
<figcaption class="quarto-float-caption-top quarto-float-caption quarto-float-tbl" id="tbl-complexity-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
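Table&nbsp;1: Complexity comparison where <span class="math inline">\(L\)</span> is sequence length, <span class="math inline">\(D\)</span> is model dimension, and <span class="math inline">\(N\)</span> is the SSM state dimension (typically small, e.g. 16)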
</figcaption>
<div aria-describedby="tbl-complexity-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<table class="caption-top table">
<thead>
<tr class="header">
<th>Model</th>
<th>Time Complexity</th>
<th>Memory Complexity</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><strong>Transformer Attention</strong></td>
<td><span class="math inline">\(O(L^2D)\)</span></td>
<td><span class="math inline">\(O(L^2)\)</span></td>
</tr>
<tr class="even">
<td><strong>Mamba</strong></td>
<td><span class="math inline">\(O(LDN)\)</span></td>
<td><span class="math inline">\(O(LD)\)</span></td>
</tr>
</tbody>
</table>
</div>
</figure>
</div>
</section>
<section id="hardware-aware-implementation" class="level3">
<h3 class="anchored" data-anchor-id="hardware-aware-implementation" id="hardware-aware-implementation">Hardware-Aware Implementation</h3>
<p>The selective scan can be implemented efficiently using:</p>
<ol type="1">
<li><strong>Parallel Scan</strong>: Using associative operations for parallel computation</li>
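<li><strong>Kernel Fusion</strong>: Combining discretization, the scan, and the multiplication by <span class="math inline">\(C\)</span> into a single GPU kernel</li>
<li><strong>Memory Optimization</strong>: Keeping the expanded states <span class="math inline">\(h_k\)</span> in fast on-chip SRAM instead of writing the full batch × length × channels × state tensor to GPU main memory, and recomputing them during the backward pass</li>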
</ol>
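<p>A back-of-the-envelope calculation shows why avoiding materialization matters. The sizes below are hypothetical (a mid-sized configuration in fp32); the point is that storing every state is <span class="math inline">\(N\)</span> times larger than the scan's own input.</p>

```python
def tensor_bytes(*shape: int, bytes_per_elem: int = 4) -> int:
    """Size in bytes of a dense tensor with the given shape (fp32 by default)."""
    total = bytes_per_elem
    for dim in shape:
        total *= dim
    return total


Bsz, L, D, N = 8, 2048, 768, 16      # hypothetical batch, length, model width, state size
E = 2 * D                            # expanded inner dimension of the block

states = tensor_bytes(Bsz, L, E, N)  # every h_k, if written out
scan_io = tensor_bytes(Bsz, L, E)    # the scan's input x (its output y has the same size)

print(f"states: {states / 2**30:.2f} GiB, scan input: {scan_io / 2**30:.3f} GiB, ratio: {states // scan_io}")
# states: 1.50 GiB, scan input: 0.094 GiB, ratio: 16
```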
</section>
<section id="scan-operation-optimization" class="level3">
<h3 class="anchored" data-anchor-id="scan-operation-optimization" id="scan-operation-optimization">Scan Operation Optimization</h3>
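<p>Each step of the recurrence is an affine map <span class="math inline">\(h \mapsto \bar{A}_k \odot h + \bar{B}_k x_k\)</span>, so the parallel scan runs over the pairs <span class="math inline">\((\bar{A}_k, \bar{B}_k x_k)\)</span> and, because <span class="math inline">\(h_0 = 0\)</span>, keeps the second component of each prefix result:</p>
<p><span id="eq-scan-optimization"><span class="math display">\[
(h_1, h_2, \ldots, h_L) = \text{scan}_{\bullet}\big((\bar{A}_1, \bar{B}_1 x_1), (\bar{A}_2, \bar{B}_2 x_2), \ldots, (\bar{A}_L, \bar{B}_L x_L)\big)
\tag{10}\]</span></span></p>
<p>Where the binary operator <span class="math inline">\(\bullet\)</span>, which applies step <span class="math inline">\(i\)</span> and then step <span class="math inline">\(j\)</span>, is:</p>
<p><span id="eq-binary-operator"><span class="math display">\[
(\bar{A}^i, b^i) \bullet (\bar{A}^j, b^j) = (\bar{A}^j \odot \bar{A}^i,\ \bar{A}^j \odot b^i + b^j)
\tag{11}\]</span></span></p>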
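<p>A minimal NumPy sketch of an associative scan with this operator, using the Hillis-Steele pattern (one common choice; a work-efficient Blelloch scan is another). It runs <span class="math inline">\(\lceil \log_2 L \rceil\)</span> rounds and is compared against the sequential loop on random illustrative data.</p>

```python
import numpy as np


def combine(left, right):
    """Binary operator of Equation 11: apply step `left` first, then step `right`."""
    a_i, b_i = left
    a_j, b_j = right
    return a_j * a_i, a_j * b_i + b_j


def parallel_scan(a, b):
    """Inclusive Hillis-Steele scan over pairs (a_k, b_k); a, b: (L, N).

    Within a round every position is updated independently, which is what a GPU
    would parallelize. Returns h_1..h_L, assuming h_0 = 0.
    """
    a, b = a.copy(), b.copy()
    L = a.shape[0]
    offset = 1
    while offset < L:
        # Position k absorbs the partial result ending at k - offset (computed before any writes)
        new_a, new_b = combine((a[:-offset], b[:-offset]), (a[offset:], b[offset:]))
        a[offset:], b[offset:] = new_a, new_b
        offset *= 2
    return b


def sequential_scan(a, b):
    h, out = np.zeros(a.shape[1]), []
    for a_k, b_k in zip(a, b):
        h = a_k * h + b_k
        out.append(h)
    return np.array(out)


rng = np.random.default_rng(0)
L, N = 13, 4
A_bar = rng.uniform(0.5, 1.0, (L, N))
Bx = rng.standard_normal((L, N)) * rng.standard_normal((L, 1))   # B_bar_k * x_k
print(np.allclose(parallel_scan(A_bar, Bx), sequential_scan(A_bar, Bx)))  # True
```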
</section>
</section>
<section id="comparison-with-transformers" class="level2">
<h2 class="anchored" data-anchor-id="comparison-with-transformers" id="comparison-with-transformers">Comparison with Transformers</h2>
<section id="attention-vs-selection" class="level3">
<h3 class="anchored" data-anchor-id="attention-vs-selection" id="attention-vs-selection">Attention vs Selection</h3>
<div class="tabset-margin-container"></div><div class="panel-tabset">
<ul class="nav nav-tabs" role="tablist"><li class="nav-item" role="presentation"><a class="nav-link active" id="tabset-2-1-tab" data-bs-toggle="tab" data-bs-target="#tabset-2-1" role="tab" aria-controls="tabset-2-1" aria-selected="true" href="">Transformer Attention</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-2-2-tab" data-bs-toggle="tab" data-bs-target="#tabset-2-2" role="tab" aria-controls="tabset-2-2" aria-selected="false" href="">Mamba Selection</a></li></ul>
<div class="tab-content">
<div id="tabset-2-1" class="tab-pane active" role="tabpanel" aria-labelledby="tabset-2-1-tab">
<p><span class="math display">\[
\text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V
\]</span></p>
<ul>
<li>Computes all pairwise interactions: <span class="math inline">\(O(L^2)\)</span></li>
<li>Global receptive field</li>
<li>Content-based selection</li>
</ul>
</div>
<div id="tabset-2-2" class="tab-pane" role="tabpanel" aria-labelledby="tabset-2-2-tab">
<p><span class="math display">\[
\text{Selection via } B(x), C(x), \Delta(x)
\]</span></p>
<ul>
<li>Input-dependent parameters: <span class="math inline">\(O(L)\)</span></li>
<li>Unbounded receptive field in principle, carried through the fixed-size state</li>
<li>Context-based filtering</li>
</ul>
</div>
</div>
</div>
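<p>The practical difference shows up at generation time. Below is a simplified decoding sketch (single attention head, single SSM channel, random placeholder values; the function names are hypothetical): attention must keep and revisit a key-value cache that grows by one entry per token, while the SSM only updates a state of fixed size <span class="math inline">\(N\)</span>.</p>

```python
import numpy as np


def attention_decode_step(q, k_new, v_new, cache):
    """Single-head attention for one new token: the KV cache grows by one entry per step."""
    cache["K"].append(k_new)
    cache["V"].append(v_new)
    K = np.stack(cache["K"])                 # (t, d): cost of this step grows with t
    V = np.stack(cache["V"])
    scores = K @ q / np.sqrt(q.shape[0])
    weights = np.exp(scores - scores.max())  # numerically stable softmax
    return (weights / weights.sum()) @ V


def ssm_decode_step(h, x_t, A_bar_t, B_bar_t, C_t):
    """One recurrent step for a single channel: the state h keeps a fixed size N."""
    h = A_bar_t * h + B_bar_t * x_t
    return h, C_t @ h


rng = np.random.default_rng(0)
d, N, T = 8, 4, 100
cache = {"K": [], "V": []}
h = np.zeros(N)
for t in range(T):
    attention_decode_step(*rng.standard_normal((3, d)), cache)
    h, y_t = ssm_decode_step(h, rng.standard_normal(), np.full(N, 0.9), rng.standard_normal(N), rng.standard_normal(N))

print(len(cache["K"]), h.shape)  # 100 (4,)
```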
</section>
<section id="information-flow" class="level3">
<h3 class="anchored" data-anchor-id="information-flow" id="information-flow">Information Flow</h3>
<div class="callout-compare">
<section id="transformers" class="level4">
<h4 class="anchored" data-anchor-id="transformers" id="transformers">Transformers</h4>
<ul>
<li>Information flows through attention weights</li>
<li>Each token can attend to all previous tokens</li>
<li>Requires causal masking for autoregressive generation</li>
</ul>
</section>
<section id="mamba" class="level4">
<h4 class="anchored" data-anchor-id="mamba" id="mamba">Mamba</h4>
<ul>
<li>Information flows through the state vector</li>
<li>State acts as a compressed representation of history</li>
<li>Naturally causal due to recurrent structure</li>
</ul>
</section>
</div>
</section>
</section>
<section id="implementation-details" class="level2">
<h2 class="anchored" data-anchor-id="implementation-details" id="implementation-details">Implementation Details</h2>
<section id="initialization-strategies" class="level3">
<h3 class="anchored" data-anchor-id="initialization-strategies" id="initialization-strategies">Initialization Strategies</h3>
<ol type="1">
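<li><strong>A Matrix</strong>: Initialize using HiPPO-LegS or a similar structured initialization; the Mamba reference code defaults to a diagonal real initialization (S4D-Real, <span class="math inline">\(A_n = -(n+1)\)</span>) for every channel</li>
<li><strong>B, C Projections</strong>: Standard Gaussian initialization scaled by dimension</li>
<li><strong><span class="math inline">\(\Delta\)</span> Projection</strong>: Initialize the bias so that <span class="math inline">\(\Delta\)</span> starts small (the reference implementation samples it log-uniformly between 0.001 and 0.1), which keeps <span class="math inline">\(\bar{A} = \exp(\Delta A)\)</span> close to 1 and the dynamics slow early in training</li>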
</ol>
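<p>A NumPy sketch of the <span class="math inline">\(A\)</span> and <span class="math inline">\(\Delta\)</span> initializations described above. Storing <span class="math inline">\(A\)</span> as <code>A_log</code> and setting the <span class="math inline">\(\Delta\)</span> bias through an inverse softplus are assumptions modeled on common practice; the helper names are ours.</p>

```python
import numpy as np


def softplus(z):
    return np.logaddexp(0.0, z)


def init_A_log(E: int, N: int) -> np.ndarray:
    """S4D-Real style: A[e, n] = -(n + 1); stored as A_log so A = -exp(A_log) stays negative."""
    return np.log(np.tile(np.arange(1, N + 1, dtype=np.float64), (E, 1)))


def init_delta_bias(E: int, dt_min: float = 1e-3, dt_max: float = 1e-1, seed: int = 0) -> np.ndarray:
    """Sample Δ log-uniformly in [dt_min, dt_max], then invert softplus to get a bias."""
    rng = np.random.default_rng(seed)
    dt = np.exp(rng.uniform(np.log(dt_min), np.log(dt_max), size=E))
    return dt + np.log(-np.expm1(-dt))      # softplus^{-1}(dt) = log(exp(dt) - 1)


E, N = 16, 4
A = -np.exp(init_A_log(E, N))              # (E, N), rows are [-1, -2, -3, -4]
delta = softplus(init_delta_bias(E))       # (E,), small positive step sizes
A_bar = np.exp(delta[:, None] * A)         # entries in (0, 1), close to 1
print(A[0], delta.min(), delta.max(), A_bar.min())
```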
</section>
<section id="numerical-stability" class="level3">
<h3 class="anchored" data-anchor-id="numerical-stability" id="numerical-stability">Numerical Stability</h3>
<p>Several techniques ensure stable computation:</p>
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Warning</span>Stability Considerations
</div>
</div>
<div class="callout-body-container callout-body">
<ol type="1">
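<li><strong>Clipping</strong>: Keep <span class="math inline">\(\Delta\)</span> in a bounded positive range (softplus plus clamping); with <span class="math inline">\(A\)</span> constrained to be negative, for example via <span class="math inline">\(A = -\exp(A_{\log})\)</span>, this keeps <span class="math inline">\(\bar{A} = \exp(\Delta A)\)</span> in <span class="math inline">\((0, 1)\)</span> so the state cannot blow up</li>
<li><strong>Recomputation</strong>: Recompute intermediate states during the backward pass instead of storing them</li>
<li><strong>Mixed Precision</strong>: Run numerically sensitive steps, such as the exponentials and the scan accumulation, in fp32 even when the rest of the model uses fp16 or bf16</li>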
</ol>
</div>
</div>
</section>
<section id="training-considerations" class="level3">
<h3 class="anchored" data-anchor-id="training-considerations" id="training-considerations">Training Considerations</h3>
<ul>
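<li><strong>Gradient Flow</strong>: Each step scales the state by <span class="math inline">\(\bar{A}_k\)</span>, whose entries lie in <span class="math inline">\((0, 1)\)</span> when <span class="math inline">\(A\)</span> is negative, so gradients through long spans shrink much as in a gated RNN; the initialization and learned values of <span class="math inline">\(\Delta\)</span> control how fast</li>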
<li><strong>Truncated BPTT</strong>: May use truncated backpropagation for very long sequences</li>
<li><strong>Regularization</strong>: Apply dropout to projections rather than the state itself</li>
</ul>
</section>
</section>
<section id="advanced-topics" class="level2">
<h2 class="anchored" data-anchor-id="advanced-topics" id="advanced-topics">Advanced Topics</h2>
<section id="multi-head-mamba" class="level3">
<h3 class="anchored" data-anchor-id="multi-head-mamba" id="multi-head-mamba">Multi-Head Mamba</h3>
<p>Similar to multi-head attention, Mamba can use multiple independent SSM heads:</p>
<p><span class="math display">\[
\text{MultiHead\_Mamba}(x) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W^O
\]</span></p>
<p>where <span class="math inline">\(\text{head}_i = \text{Mamba}_i(x)\)</span></p>
</section>
<section id="bidirectional-processing" class="level3">
<h3 class="anchored" data-anchor-id="bidirectional-processing" id="bidirectional-processing">Bidirectional Processing</h3>
<p>For non-causal applications, bidirectional Mamba processes sequences in both directions:</p>
<p><span id="eq-bidirectional"><span class="math display">\[
y = \text{Mamba}_{\text{forward}}(x) + \text{reverse}\big(\text{Mamba}_{\text{backward}}(\text{reverse}(x))\big)
\tag{12}\]</span></span></p>
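<p>The output of the backward pass has to be reversed again so that position <span class="math inline">\(k\)</span> lines up with <span class="math inline">\(x_k\)</span>. A minimal sketch, using a running sum as a stand-in causal operator (it is an SSM with <span class="math inline">\(\bar{A} = \bar{B} = C = 1\)</span>); a real model would use two separately parameterized Mamba layers.</p>

```python
import numpy as np


def bidirectional(forward_op, backward_op, x):
    """Combine a forward and a backward causal operator along axis 0."""
    y_forward = forward_op(x)
    y_backward = backward_op(x[::-1])[::-1]   # flip the output back so index k lines up with x_k
    return y_forward + y_backward


x = np.array([1.0, 2.0, 3.0, 4.0])
print(bidirectional(np.cumsum, np.cumsum, x))   # [11. 12. 13. 14.]
```

<p>With running sums, position <span class="math inline">\(k\)</span> sees the prefix up to <span class="math inline">\(k\)</span> plus the suffix from <span class="math inline">\(k\)</span>, i.e. the total plus <span class="math inline">\(x_k\)</span>, so every output depends on the whole sequence.</p>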
</section>
<section id="integration-with-other-mechanisms" class="level3">
<h3 class="anchored" data-anchor-id="integration-with-other-mechanisms" id="integration-with-other-mechanisms">Integration with Other Mechanisms</h3>
<p>Mamba blocks can be combined with:</p>
<ul>
<li><strong>MLP blocks</strong>: Following similar patterns to transformer architectures</li>
<li><strong>Convolution</strong>: For local pattern recognition</li>
<li><strong>Attention</strong>: For hybrid architectures</li>
</ul>
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Important</span>Key Contributions
</div>
</div>
<div class="callout-body-container callout-body">
<p>Mamba models represent a significant advance in sequence modeling by:</p>
<ol type="1">
<li><strong>Achieving Linear Complexity</strong>: <span class="math inline">\(O(L)\)</span> instead of <span class="math inline">\(O(L^2)\)</span> for sequence length <span class="math inline">\(L\)</span></li>
<li><strong>Maintaining Expressiveness</strong>: Through the selective mechanism</li>
<li><strong>Enabling Long Sequences</strong>: Practical processing of sequences with 100K+ tokens</li>
<li><strong>Preserving Parallelism</strong>: Training remains efficient through parallel scan</li>
</ol>
</div>
</div>
<p>The mathematical foundation built on selective state space models provides both theoretical rigor and practical efficiency, making Mamba a compelling alternative to traditional transformer architectures for many sequence modeling tasks.</p>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Key Insight
</div>
</div>
<div class="callout-body-container callout-body">
<p>The key insight is that by making SSM parameters input-dependent, we can maintain the benefits of both recurrent models (linear complexity, infinite receptive field) and transformers (parallelizable training, strong performance), opening new possibilities for efficient sequence modeling at scale.</p>
</div>
</div>
</section>
<section id="appendix" class="level2">
<h2 class="anchored" data-anchor-id="appendix" id="appendix">Appendix</h2>
<section id="mathematical-notation-summary" class="level3">
<h3 class="anchored" data-anchor-id="mathematical-notation-summary" id="mathematical-notation-summary">Mathematical Notation Summary</h3>
<table class="caption-top table">
<thead>
<tr class="header">
<th>Symbol</th>
<th>Description</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><span class="math inline">\(h(t), h_k\)</span></td>
<td>State vector (continuous/discrete)</td>
</tr>
<tr class="even">
<td><span class="math inline">\(x(t), x_k\)</span></td>
<td>Input signal/sequence</td>
</tr>
<tr class="odd">
<td><span class="math inline">\(y(t), y_k\)</span></td>
<td>Output signal/sequence</td>
</tr>
<tr class="even">
<td><span class="math inline">\(A, \bar{A}\)</span></td>
<td>State transition matrix</td>
</tr>
<tr class="odd">
<td><span class="math inline">\(B, \bar{B}\)</span></td>
<td>Input matrix</td>
</tr>
<tr class="even">
<td><span class="math inline">\(C\)</span></td>
<td>Output matrix</td>
</tr>
<tr class="odd">
<td><span class="math inline">\(\Delta\)</span></td>
<td>Discretization step size</td>
</tr>
<tr class="even">
<td><span class="math inline">\(L\)</span></td>
<td>Sequence length</td>
</tr>
<tr class="odd">
<td><span class="math inline">\(N\)</span></td>
<td>State dimension</td>
</tr>
<tr class="even">
<td><span class="math inline">\(D\)</span></td>
<td>Model dimension</td>
</tr>
</tbody>
</table>



</section>
</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Mamba Transformers: Revolutionizing Sequence Modeling with Selective State Space Models]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/models/mamba/mamba-transformer/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/models/mamba/mamba-transformer/</guid>
      <pubDate>Sat, 23 Aug 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>research</category>
      <category>advanced</category>
      <content:encoded><![CDATA[






<section id="mamba-transformers-revolutionizing-sequence-modeling-with-selective-state-space-models" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/models/mamba/mamba-transformer/mamba.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>Mamba represents a groundbreaking advancement in sequence modeling architecture, emerging as a compelling alternative to the dominant transformer paradigm. Introduced in late 2023 by Albert Gu and Tri Dao, Mamba addresses fundamental limitations of transformers while maintaining their modeling capabilities. This selective state space model (SSM) offers linear scaling with sequence length, making it particularly attractive for processing long sequences that would be computationally prohibitive for traditional attention-based models.</p>
</section>
<section id="background-the-need-for-better-sequence-models" class="level2">
<h2 class="anchored" data-anchor-id="background-the-need-for-better-sequence-models" id="background-the-need-for-better-sequence-models">Background: The Need for Better Sequence Models</h2>
<section id="limitations-of-transformers" class="level3">
<h3 class="anchored" data-anchor-id="limitations-of-transformers" id="limitations-of-transformers">Limitations of Transformers</h3>
<p>While transformers have achieved remarkable success across numerous domains, they face several critical challenges:</p>
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Warning</span>Key Transformer Limitations
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li><strong>Quadratic Complexity</strong>: The self-attention mechanism scales quadratically with sequence length (O(n²))</li>
<li><strong>Fixed Context Windows</strong>: Most implementations are constrained by fixed context windows</li>
<li><strong>Computational Inefficiency</strong>: Parallel attention can be inefficient during inference</li>
</ul>
</div>
</div>
<p><strong>Quadratic Complexity</strong>: The self-attention mechanism scales quadratically with sequence length (O(n²)), making it computationally expensive and memory-intensive for long sequences. This limitation becomes particularly problematic when processing documents, long conversations, or high-resolution images treated as sequences.</p>
<p><strong>Fixed Context Windows</strong>: Most transformer implementations are constrained by fixed context windows, limiting their ability to maintain coherence over very long sequences. Even with techniques like sliding windows or sparse attention, the fundamental scalability issues remain.</p>
<p><strong>Computational Inefficiency</strong>: The parallel nature of attention, while beneficial for training, can be inefficient during inference, especially for autoregressive generation: each new token attends to all previous tokens, so per-token cost and key-value cache memory grow with the context length.</p>
</section>
<section id="enter-state-space-models" class="level3">
<h3 class="anchored" data-anchor-id="enter-state-space-models" id="enter-state-space-models">Enter State Space Models</h3>
<p>State space models offer an elegant mathematical framework for sequence modeling that naturally handles variable-length sequences with linear complexity. These models maintain a hidden state that evolves over time, capturing dependencies across the sequence without the quadratic scaling issues of attention.</p>
<p>The core idea behind SSMs is to model sequences through a continuous-time dynamical system:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="co"># State Space Model equations</span></span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="co"># dx/dt = Ax + Bu</span></span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a><span class="co"># y = Cx + Du</span></span></code></pre></div></div>
<p>Where:</p>
<ul>
<li><code>x</code> represents the hidden state</li>
<li><code>u</code> is the input sequence</li>
<li><code>y</code> is the output sequence</li>
<li><code>A</code>, <code>B</code>, <code>C</code>, <code>D</code> are learned parameter matrices</li>
</ul>
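<p>To see the continuous system in action, here is a minimal NumPy sketch in this post’s notation (<code>x</code> is the state, <code>u</code> the input) using forward-Euler integration with a small time step. It assumes a scalar input; for a single state with <code>A = -1</code>, <code>B = 1</code>, <code>C = 1</code>, <code>D = 0</code> and a unit-step input, the exact output is <code>1 - exp(-t)</code>, which the simulation should track.</p>

```python
import numpy as np


def simulate_ssm(A, B, C, D, u, dt):
    """Forward-Euler integration of dx/dt = Ax + Bu, y = Cx + Du for a scalar input u[t]."""
    x = np.zeros(A.shape[0])
    ys = []
    for u_t in u:
        x = x + dt * (A @ x + B * u_t)
        ys.append(C @ x + D * u_t)
    return np.array(ys)


dt = 1e-3
t = np.arange(0, 5, dt)
u = np.ones_like(t)                        # unit step input
y = simulate_ssm(np.array([[-1.0]]), np.array([1.0]), np.array([1.0]), 0.0, u, dt)
print(y[-1], 1 - np.exp(-5))               # both ≈ 0.993
```

<p>The hidden state behaves like a leaky memory of past inputs; discretizing it properly (rather than with tiny Euler steps) is what lets SSMs run directly on token sequences.</p>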
</section>
</section>
<section id="the-mamba-architecture" class="level2">
<h2 class="anchored" data-anchor-id="the-mamba-architecture" id="the-mamba-architecture">The Mamba Architecture</h2>
<section id="selective-state-space-models" class="level3">
<h3 class="anchored" data-anchor-id="selective-state-space-models" id="selective-state-space-models">Selective State Space Models</h3>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Mamba’s Key Innovation
</div>
</div>
<div class="callout-body-container callout-body">
<p>Mamba’s key innovation lies in making the state space model “selective” - the ability to selectively retain or forget information based on the input context.</p>
</div>
</div>
<p>This selectivity is achieved through input-dependent parameters, allowing the model to dynamically adjust its behavior based on the content it’s processing.</p>
</section>
<section id="core-components" class="level3">
<h3 class="anchored" data-anchor-id="core-components" id="core-components">Core Components</h3>
<section id="selective-scan-algorithm" class="level4">
<h4 class="anchored" data-anchor-id="selective-scan-algorithm">Selective Scan Algorithm</h4>
<p>The heart of Mamba is the selective scan algorithm, which efficiently computes state transitions while maintaining the ability to selectively focus on relevant information. Unlike traditional SSMs with fixed parameters, Mamba’s parameters (the step size <code>Δ</code> and the <code>B</code> and <code>C</code> matrices) are functions of the input:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Input-dependent parameterization</span></span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a>B_t <span class="op">=</span> Linear_B(x_t)</span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a>C_t <span class="op">=</span> Linear_C(x_t)</span></code></pre></div></div>
<p>This input-dependent parameterization allows the model to gate information flow dynamically, similar to how LSTM gates control information retention and forgetting.</p>
</section>
<section id="hardware-efficient-implementation" class="level4">
<h4 class="anchored" data-anchor-id="hardware-efficient-implementation">Hardware-Efficient Implementation</h4>
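<p>One of Mamba’s significant achievements is its hardware-efficient implementation. The authors developed specialized CUDA kernels that avoid materializing the expanded hidden states in the GPU’s high-bandwidth memory (HBM). Instead, the parameters are loaded into fast on-chip SRAM. There, discretization, the scan, and the multiplication by <code>C</code> are fused into one kernel, and only the outputs are written back to HBM. During the backward pass, intermediate states are recomputed rather than stored. This sharply reduces memory traffic and makes long sequences practical to process.</p>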
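<p>Besides kernel fusion, the recurrence itself can be parallelized. A linear recurrence h<sub>t</sub> = a<sub>t</sub>h<sub>t−1</sub> + b<sub>t</sub> composes associatively, so it can be evaluated with a parallel scan, which is the approach the Mamba paper takes. The sketch below is not the CUDA kernel. It assumes scalar states (a diagonal state matrix behaves the same way, elementwise). For brevity it uses the simple Hillis-Steele scan, which does O(L log L) work rather than the work-efficient O(L) scan used in practice. It checks that combining (a, b) pairs reproduces the sequential result.</p>

```python
import numpy as np


def combine(left, right):
    """Compose two affine maps h -> a*h + b: apply `left` first, then `right`."""
    a1, b1 = left
    a2, b2 = right
    return a1 * a2, a2 * b1 + b2


def sequential_scan(a, b):
    """Reference: h_t = a_t * h_{t-1} + b_t with h_{-1} = 0."""
    h, out = 0.0, []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return np.array(out)


def parallel_scan(a, b):
    """Hillis-Steele inclusive scan: about log2(L) rounds, each parallel over t."""
    A = np.array(a, dtype=float)
    B = np.array(b, dtype=float)
    step = 1
    while step < len(A):
        # position t absorbs the segment that ends at t - step (all positions at once)
        A_new, B_new = combine((A[:-step], B[:-step]), (A[step:], B[step:]))
        A[step:], B[step:] = A_new, B_new
        step *= 2
    return B  # starting from h = 0, the accumulated offset B_t equals h_t


rng = np.random.default_rng(0)
a = rng.uniform(0.5, 1.0, size=10)   # decay factors, like exp(delta_t * A)
b = rng.standard_normal(10)          # inputs, like delta_t * B_t * x_t
print(np.allclose(parallel_scan(a, b), sequential_scan(a, b)))  # True
```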
</section>
<section id="the-mamba-block" class="level4">
<h4 class="anchored" data-anchor-id="the-mamba-block">The Mamba Block</h4>
<p>A single Mamba block consists of:</p>
<ul>
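<li><strong>Input Projection</strong>: Linear transformation that expands the input embeddings into two branches (by the expansion factor)</li>
<li><strong>Causal Convolution</strong>: A short depthwise 1D convolution with SiLU activation on the main branch</li>
<li><strong>Selective SSM Layer</strong>: The core selective state space computation</li>
<li><strong>Gating</strong>: The SSM output is multiplied elementwise by the SiLU-activated second branch</li>
<li><strong>Output Projection</strong>: Final linear transformation back to the model width</li>
<li><strong>Residual Connection</strong>: Skip connection for gradient flow</li>
<li><strong>Normalization</strong>: Layer normalization (RMSNorm in many implementations) for training stability</li>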
</ul>
<p>Multiple Mamba blocks are stacked to create deeper models, similar to transformer layers.</p>
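<p>The components of a block can be put together in a forward-pass sketch. This is an illustrative NumPy re-creation under stated assumptions, not the official implementation. It uses random weights, a single unbatched sequence, a sequential scan, and pre-normalization with RMSNorm. It also uses a full-rank projection for Δ (the paper uses a low-rank one), a convolution without bias, and the simplified discretization B̄<sub>t</sub> ≈ Δ<sub>t</sub>B<sub>t</sub>. The layout follows the block described in the Mamba paper: input projection into two branches, causal depthwise convolution and SiLU, selective SSM, SiLU gate, output projection, and residual connection. All function and parameter names are hypothetical.</p>

```python
import numpy as np


def silu(v):
    return v / (1.0 + np.exp(-v))


def softplus(v):
    return np.log1p(np.exp(v))


def rms_norm(v, eps=1e-5):
    return v / np.sqrt(np.mean(v ** 2, axis=-1, keepdims=True) + eps)


def init_params(d_model, d_state=16, expand=2, d_conv=4, seed=0):
    """Random weights for one block (hypothetical init scale of 0.02)."""
    rng = np.random.default_rng(seed)
    d_inner = expand * d_model

    def rand(*shape):
        return 0.02 * rng.standard_normal(shape)

    return {
        "W_in": rand(d_model, 2 * d_inner),    # input projection -> (x, z) branches
        "conv": rand(d_conv, d_inner),         # depthwise causal conv kernel
        "W_delta": rand(d_inner, d_inner),     # input -> step size (full rank here)
        "W_B": rand(d_inner, d_state),         # input -> B_t
        "W_C": rand(d_inner, d_state),         # input -> C_t
        "A_log": np.log(np.tile(np.arange(1.0, d_state + 1), (d_inner, 1))),
        "D": np.ones(d_inner),                 # skip connection inside the SSM
        "W_out": rand(d_inner, d_model),       # output projection
    }


def mamba_block(u, p):
    """Pre-norm residual block: u has shape (L, d_model); returns the same shape."""
    L = u.shape[0]
    x, z = np.split(rms_norm(u) @ p["W_in"], 2, axis=-1)   # main and gate branches

    # Causal depthwise convolution: output t only sees inputs t-d_conv+1 .. t
    k = p["conv"].shape[0]
    x_pad = np.vstack([np.zeros((k - 1, x.shape[1])), x])
    x = silu(sum(x_pad[i:i + L] * p["conv"][i] for i in range(k)))

    # Selective SSM: delta, B and C are all computed from the current input
    delta = softplus(x @ p["W_delta"])                      # (L, d_inner), positive
    B, C = x @ p["W_B"], x @ p["W_C"]                       # (L, d_state) each
    A = -np.exp(p["A_log"])                                 # (d_inner, d_state), negative
    h = np.zeros_like(A)
    y = np.empty_like(x)
    for t in range(L):
        A_bar = np.exp(delta[t][:, None] * A)               # per-channel decay
        h = A_bar * h + (delta[t][:, None] * B[t][None, :]) * x[t][:, None]
        y[t] = h @ C[t] + p["D"] * x[t]

    y = y * silu(z)                                         # multiplicative gate
    return u + y @ p["W_out"]                               # residual connection


params = init_params(d_model=8)
u = np.random.default_rng(1).standard_normal((12, 8))
out = mamba_block(u, params)
print(out.shape)  # (12, 8)
```

<p>Every stage is causal, so changing the last input token leaves all earlier outputs unchanged, as autoregressive generation requires. Stacking amounts to <code>for p in layers: u = mamba_block(u, p)</code>, followed by a final normalization.</p>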
</section>
</section>
<section id="mathematical-formulation" class="level3">
<h3 class="anchored" data-anchor-id="mathematical-formulation" id="mathematical-formulation">Mathematical Formulation</h3>
<p>The selective SSM in Mamba can be expressed as:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Selective SSM equations</span></span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a>h_t <span class="op">=</span> A <span class="op">*</span> h_{t<span class="op">-</span><span class="dv">1</span>} <span class="op">+</span> B_t <span class="op">*</span> x_t</span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a>y_t <span class="op">=</span> C_t <span class="op">*</span> h_t</span></code></pre></div></div>
<p>Where:</p>
<ul>
<li><code>h_t</code> is the hidden state at time step t</li>
<li><code>x_t</code> is the input at time step t</li>
<li><code>y_t</code> is the output at time step t</li>
<li><code>A</code> is a learned transition matrix (often initialized as a HiPPO matrix)</li>
<li><code>B_t</code> and <code>C_t</code> are input-dependent projection matrices</li>
</ul>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Note
</div>
</div>
<div class="callout-body-container callout-body">
<p>The selectivity comes from the fact that <code>B_t</code>, <code>C_t</code>, and the step size <code>Δ_t</code> vary with the input, allowing the model to adaptively control information flow.</p>
</div>
</div>
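<p>With the discretization written out explicitly, the selective recurrence can be run step by step. The sketch below covers a single channel and assumes a diagonal <code>A</code> with zero-order-hold discretization: Ā<sub>t</sub> = exp(Δ<sub>t</sub>A) and B̄<sub>t</sub> = (exp(Δ<sub>t</sub>A) − 1)/A · B<sub>t</sub>. (The Mamba reference code uses the simpler B̄<sub>t</sub> ≈ Δ<sub>t</sub>B<sub>t</sub>.) It uses hand-picked Δ values instead of learned ones. A large Δ<sub>t</sub> resets the state toward the current input, while a tiny Δ<sub>t</sub> ignores the input and preserves the state. This is the selectivity described in the note above.</p>

```python
import numpy as np


def selective_scan(x, delta, A, B, C):
    """Sequential selective scan for a single channel (illustrative reference).

    x:     (L,)   input sequence
    delta: (L,)   positive, input-dependent step sizes
    A:     (N,)   negative diagonal of the state matrix
    B, C:  (L, N) input-dependent projections
    """
    h = np.zeros_like(A)
    y = np.empty(len(x))
    for t in range(len(x)):
        A_bar = np.exp(delta[t] * A)            # how much of the old state to keep
        B_bar = (A_bar - 1.0) / A * B[t]        # zero-order-hold input matrix
        h = A_bar * h + B_bar * x[t]
        y[t] = C[t] @ h
    return y


A = np.array([-1.0])
B = np.ones((3, 1))
C = np.ones((3, 1))
x = np.array([2.0, 7.0, 0.0])

# Large delta at t=0 stores x_0; tiny deltas afterwards ignore x_1 and keep the memory
remember = selective_scan(x, np.array([10.0, 1e-3, 1e-3]), A, B, C)
# Large delta at t=1 instead overwrites the state with x_1
overwrite = selective_scan(x, np.array([10.0, 10.0, 1e-3]), A, B, C)
print(remember)   # approximately [2, 2, 2]
print(overwrite)  # approximately [2, 7, 7]
```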
</section>
</section>
<section id="key-innovations-and-advantages" class="level2">
<h2 class="anchored" data-anchor-id="key-innovations-and-advantages" id="key-innovations-and-advantages">Key Innovations and Advantages</h2>
<section id="linear-scaling" class="level3">
<h3 class="anchored" data-anchor-id="linear-scaling" id="linear-scaling">Linear Scaling</h3>
<p>Mamba’s most significant advantage is that its cost grows linearly with sequence length, O(n), whereas self-attention grows quadratically, O(n²). This makes it practical to process sequences of hundreds of thousands of tokens, and the Mamba paper reports that performance keeps improving on real data up to million-length sequences. Such scaling opens up new possibilities for modeling very long contexts.</p>
</section>
<section id="efficient-memory-usage" class="level3">
<h3 class="anchored" data-anchor-id="efficient-memory-usage" id="efficient-memory-usage">Efficient Memory Usage</h3>
<p>The hardware-aware implementation keeps training memory linear in sequence length, without the attention matrix that creates transformers’ memory bottleneck. During autoregressive inference the advantage is larger still: the recurrent state has a fixed size, no matter how many tokens have already been processed.</p>
</section>
<section id="strong-inductive-biases" class="level3">
<h3 class="anchored" data-anchor-id="strong-inductive-biases" id="strong-inductive-biases">Strong Inductive Biases</h3>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Natural Sequence Modeling Advantages
</div>
</div>
<div class="callout-body-container callout-body">
<p>The state space formulation provides natural inductive biases:</p>
<ul>
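<li><strong>Causality</strong>: Each output depends only on current and past inputs, which suits autoregressive generation</li>
<li><strong>Shared Dynamics Across Time</strong>: The same recurrence is applied at every step, so sequences of any length are handled without positional embeddings</li>
<li><strong>Stability</strong>: <code>A</code> is constrained to negative values, so the discretized transition exp(ΔA) stays between 0 and 1 and the hidden state cannot blow up over long sequences</li>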
</ul>
</div>
</div>
</section>
<section id="fast-inference" class="level3">
<h3 class="anchored" data-anchor-id="fast-inference" id="fast-inference">Fast Inference</h3>
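<p>During autoregressive generation, Mamba only needs to update its fixed-size hidden state for each new token. A transformer, by contrast, must attend over a key-value cache that grows with every generated token, so its per-token cost and memory rise with context length. This makes Mamba’s inference significantly faster, especially for long sequences.</p>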
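<p>A back-of-the-envelope calculation shows why the fixed-size state matters. The sketch counts only the per-sequence cache kept during generation, under hypothetical sizes: 24 layers, model width 2048, and FP16 values throughout. The transformer uses standard multi-head attention (no key-value sharing); Mamba uses expansion factor 2, state size 16, and convolution width 4.</p>

```python
BYTES_PER_VALUE = 2  # FP16


def kv_cache_bytes(n_tokens, n_layers=24, d_model=2048):
    """Attention cache: one key and one value vector per token, per layer."""
    return 2 * n_layers * n_tokens * d_model * BYTES_PER_VALUE


def mamba_state_bytes(n_layers=24, d_model=2048, expand=2, d_state=16, d_conv=4):
    """Mamba cache: SSM state plus the last d_conv - 1 inputs of the causal conv."""
    d_inner = expand * d_model
    per_layer = d_inner * d_state + d_inner * (d_conv - 1)
    return n_layers * per_layer * BYTES_PER_VALUE   # independent of sequence length


for n in (1_000, 100_000, 1_000_000):
    print(f"{n:>9,} tokens: KV cache {kv_cache_bytes(n) / 2**30:7.2f} GiB, "
          f"Mamba state {mamba_state_bytes() / 2**20:.2f} MiB")
```

<p>Under these assumptions the KV cache grows from about 0.18 GiB at 1,000 tokens to about 183 GiB at one million tokens, while the Mamba state stays at about 3.6 MiB. Techniques such as grouped-query attention shrink the KV cache by a constant factor but do not change its linear growth.</p>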
</section>
</section>
<section id="performance-and-capabilities" class="level2">
<h2 class="anchored" data-anchor-id="performance-and-capabilities" id="performance-and-capabilities">Performance and Capabilities</h2>
<section id="language-modeling" class="level3">
<h3 class="anchored" data-anchor-id="language-modeling" id="language-modeling">Language Modeling</h3>
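<p>Mamba has demonstrated competitive performance on language modeling benchmarks while requiring fewer computational resources. For example, the Mamba paper reports that Mamba-3B outperforms Transformers of the same size and matches Transformers twice its size. Key results include:</p>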
<ul>
<li><strong>Perplexity</strong>: Competitive or superior perplexity scores compared to transformers of similar size</li>
<li><strong>Scaling</strong>: Maintains performance advantages as model size increases</li>
<li><strong>Efficiency</strong>: Dramatically reduced inference time for long sequences</li>
</ul>
</section>
<section id="long-context-understanding" class="level3">
<h3 class="anchored" data-anchor-id="long-context-understanding" id="long-context-understanding">Long Context Understanding</h3>
<p>Because its cost grows linearly with context length, Mamba is a natural candidate for tasks requiring long-context understanding:</p>
<ul>
<li><strong>Document Processing</strong>: Processing entire books or long documents in a single pass</li>
<li><strong>Code Generation</strong>: Working with large codebases that have complex, long-range dependencies</li>
<li><strong>Conversation Modeling</strong>: Tracking context over very long dialogues</li>
</ul>
</section>
<section id="domain-specific-applications" class="level3">
<h3 class="anchored" data-anchor-id="domain-specific-applications" id="domain-specific-applications">Domain-Specific Applications</h3>
<p>Mamba’s efficiency makes it particularly suitable for:</p>
<ul>
<li><strong>Genomic Sequence Analysis</strong>: Processing DNA sequences with millions of base pairs</li>
<li><strong>Time Series Forecasting</strong>: Handling long temporal sequences efficiently</li>
<li><strong>Audio Processing</strong>: Managing long audio sequences for speech and music applications</li>
</ul>
</section>
</section>
<section id="architectural-variations-and-extensions" class="level2">
<h2 class="anchored" data-anchor-id="architectural-variations-and-extensions" id="architectural-variations-and-extensions">Architectural Variations and Extensions</h2>
<section id="mamba-2" class="level3">
<h3 class="anchored" data-anchor-id="mamba-2" id="mamba-2">Mamba-2</h3>
<p>The follow-up work, Mamba-2, introduced additional improvements:</p>
<ul>
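<li><strong>State Space Duality</strong>: Showing that a restricted class of selective SSMs (where <code>A</code> is a scalar times the identity) is equivalent to a form of masked linear attention, bridging the two families</li>
<li><strong>Improved Training Dynamics</strong>: Better gradient flow and training stability</li>
<li><strong>Enhanced Hardware Efficiency</strong>: The SSD algorithm relies mostly on matrix multiplications, which map well onto GPU tensor cores; the authors report a core layer that is 2-8x faster than Mamba’s selective scan</li>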
</ul>
</section>
<section id="hybrid-architectures" class="level3">
<h3 class="anchored" data-anchor-id="hybrid-architectures" id="hybrid-architectures">Hybrid Architectures</h3>
<p>Researchers have explored combining Mamba with other architectures:</p>
<ul>
<li><strong>Mamba-Transformer Hybrids</strong>: Using Mamba for long-range dependencies and transformers for complex reasoning</li>
<li><strong>Multi-Scale Mamba</strong>: Different Mamba layers operating at different temporal scales</li>
<li><strong>Attention-Augmented Mamba</strong>: Adding selective attention layers for specific tasks</li>
</ul>
</section>
</section>
<section id="implementation-considerations" class="level2">
<h2 class="anchored" data-anchor-id="implementation-considerations" id="implementation-considerations">Implementation Considerations</h2>
<section id="training-strategies" class="level3">
<h3 class="anchored" data-anchor-id="training-strategies" id="training-strategies">Training Strategies</h3>
<p>Training Mamba models requires specific considerations:</p>
<ul>
<li><strong>Initialization</strong>: Proper initialization of the A matrix (often using HiPPO-derived schemes) and of the step size Δ</li>
<li><strong>Learning Rate Scheduling</strong>: Different learning rates for different parameter groups</li>
<li><strong>Regularization</strong>: Parameter-specific treatment of SSM parameters, such as excluding them from weight decay</li>
</ul>
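<p>As a concrete example of SSM-specific initialization, the sketch below mirrors the scheme used by the Mamba reference code as we understand it; treat the exact ranges as assumptions to check against the implementation you use. <code>A</code> is stored as <code>A_log</code> and recovered as <code>A = -exp(A_log)</code>. It is initialized to A<sub>n</sub> = −(n+1), the real-valued “S4D-Real” scheme motivated by HiPPO theory. The step size Δ is sampled log-uniformly in [0.001, 0.1] and stored as a bias through an inverse softplus, so that applying softplus recovers it. The reference code also excludes <code>A_log</code> and <code>D</code> from weight decay, which is one example of treating SSM parameters as their own group. The function names are hypothetical.</p>

```python
import numpy as np


def init_ssm_params(d_inner, d_state=16, dt_min=1e-3, dt_max=1e-1, dt_floor=1e-4, seed=0):
    rng = np.random.default_rng(seed)
    # S4D-Real: A_n = -(n + 1); store log(-A) so that A = -exp(A_log) stays negative
    A = np.tile(np.arange(1.0, d_state + 1), (d_inner, 1))
    A_log = np.log(A)
    # Step size: log-uniform in [dt_min, dt_max], one value per channel
    dt = np.exp(rng.uniform(np.log(dt_min), np.log(dt_max), size=d_inner))
    dt = np.maximum(dt, dt_floor)
    # Inverse softplus, so that softplus(dt_bias) == dt at initialization
    dt_bias = dt + np.log(-np.expm1(-dt))
    D = np.ones(d_inner)                      # skip-connection weight
    return A_log, dt_bias, D


def softplus(v):
    return np.log1p(np.exp(v))


A_log, dt_bias, D = init_ssm_params(d_inner=8)
A = -np.exp(A_log)                            # how the forward pass recovers A
print(A[0])                                   # -1, -2, ..., -16
print(softplus(dt_bias))                      # initial step sizes within [0.001, 0.1]
```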
</section>
<section id="hyperparameter-tuning" class="level3">
<h3 class="anchored" data-anchor-id="hyperparameter-tuning" id="hyperparameter-tuning">Hyperparameter Tuning</h3>
<p>Key hyperparameters include:</p>
<ul>
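<li><strong>State Dimension</strong>: The size of the hidden state per channel (16 by default in the reference implementation)</li>
<li><strong>Expansion Factor</strong>: How much to expand the intermediate representations (2 by default)</li>
<li><strong>Number of Layers</strong>: Depth of the Mamba stack</li>
<li><strong>Step Size (Δ) Range</strong>: Δ controls the discretization of the continuous system; it is computed from the input, but the range used to initialize it is a hyperparameter</li>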
</ul>
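<p>These hyperparameters largely determine a block’s size. The sketch below counts parameters per block, assuming the layer layout of the reference Mamba block as we understand it. That layout has an input projection to two branches without bias, a depthwise convolution with bias, and a low-rank Δ projection of rank ⌈D/16⌉. It also includes projections for <code>B</code> and <code>C</code>, <code>A_log</code>, <code>D</code>, and an output projection. The dataclass and its field names are hypothetical. The count reproduces the rule of thumb from the Mamba paper that a block has roughly 3·E·D² parameters for expansion factor E and model width D.</p>

```python
import math
from dataclasses import dataclass


@dataclass
class MambaBlockConfig:          # hypothetical name; values mirror common defaults
    d_model: int = 768           # model width D
    d_state: int = 16            # SSM state dimension N
    expand: int = 2              # expansion factor E
    d_conv: int = 4              # causal convolution width
    dt_min: float = 1e-3         # range used to initialize the step size Δ
    dt_max: float = 1e-1

    @property
    def d_inner(self) -> int:
        return self.expand * self.d_model

    @property
    def dt_rank(self) -> int:
        return math.ceil(self.d_model / 16)


def block_params(c: MambaBlockConfig) -> int:
    d, di, n, r = c.d_model, c.d_inner, c.d_state, c.dt_rank
    in_proj = d * 2 * di                   # to the SSM branch and the gate branch
    conv = di * c.d_conv + di              # depthwise kernel + bias
    x_proj = di * (r + 2 * n)              # produces low-rank Δ, B_t and C_t
    dt_proj = r * di + di                  # expands Δ back to d_inner, with bias
    ssm = di * n + di                      # A_log and D
    out_proj = di * d
    return in_proj + conv + x_proj + dt_proj + ssm + out_proj


cfg = MambaBlockConfig()
print(block_params(cfg), 3 * cfg.expand * cfg.d_model ** 2)  # 3770880 3538944
```

<p>For D = 768 the exact count is about 6.6% above 3·E·D². The difference comes from the convolution, the Δ/B/C projections, and the SSM parameters.</p>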
</section>
<section id="hardware-requirements" class="level3">
<h3 class="anchored" data-anchor-id="hardware-requirements" id="hardware-requirements">Hardware Requirements</h3>
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Important</span>Hardware Considerations
</div>
</div>
<div class="callout-body-container callout-body">
<p>While more efficient than transformers for long sequences, Mamba still benefits from modern hardware for optimal performance.</p>
</div>
</div>
<p>In particular, Mamba benefits from:</p>
<ul>
<li><strong>High Memory Bandwidth</strong>: The scan is memory-bound rather than compute-bound, so throughput depends heavily on memory bandwidth</li>
<li><strong>Modern GPUs</strong>: CUDA kernels are optimized for recent architectures</li>
<li><strong>Sufficient VRAM</strong>: For storing model parameters and intermediate states</li>
</ul>
</section>
</section>
<section id="comparison-with-transformers" class="level2">
<h2 class="anchored" data-anchor-id="comparison-with-transformers" id="comparison-with-transformers">Comparison with Transformers</h2>
<section id="computational-complexity" class="level3">
<h3 class="anchored" data-anchor-id="computational-complexity" id="computational-complexity">Computational Complexity</h3>
<div id="tbl-complexity" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-tbl figure">
<figcaption class="quarto-float-caption-top quarto-float-caption quarto-float-tbl" id="tbl-complexity-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Table&nbsp;1: Computational complexity comparison between Transformers and Mamba
</figcaption>
<div aria-describedby="tbl-complexity-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<table class="caption-top table">
<thead>
<tr class="header">
<th>Aspect</th>
<th>Transformers</th>
<th>Mamba</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>Time Complexity</td>
<td>O(n²d)</td>
<td>O(ndN), with state size N</td>
</tr>
<tr class="even">
<td>Memory Complexity</td>
<td>O(n²) (standard attention)</td>
<td>O(n)</td>
</tr>
<tr class="odd">
<td>Parallelization</td>
<td>High (training)</td>
<td>Moderate (parallel scan during training)</td>
</tr>
<tr class="even">
<td>Inference Speed</td>
<td>Per-token cost grows with context length</td>
<td>Constant per-token cost</td>
<td>Fast</td>
</tr>
</tbody>
</table>
</div>
</figure>
</div>
</section>
<section id="task-performance" class="level3">
<h3 class="anchored" data-anchor-id="task-performance" id="task-performance">Task Performance</h3>
<ul>
<li><strong>Short Sequences</strong>: Transformers often maintain slight advantages</li>
<li><strong>Medium Sequences</strong>: Performance is generally comparable</li>
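<li><strong>Long Sequences</strong>: Mamba’s efficiency advantage is largest here and its quality is often competitive, though transformers still tend to do better on tasks that require exact recall or copying from far back in the context</li>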
<li><strong>Specialized Tasks</strong>: Task-dependent, with each architecture having strengths</li>
</ul>
</section>
<section id="practical-considerations" class="level3">
<h3 class="anchored" data-anchor-id="practical-considerations" id="practical-considerations">Practical Considerations</h3>
<ul>
<li><strong>Implementation Complexity</strong>: Mamba requires specialized kernels</li>
<li><strong>Ecosystem Maturity</strong>: Transformers have more extensive tooling and libraries</li>
<li><strong>Research Investment</strong>: Transformers have received more research attention</li>
<li><strong>Industry Adoption</strong>: Transformers currently dominate production systems</li>
</ul>
</section>
</section>
<section id="applications-and-use-cases" class="level2">
<h2 class="anchored" data-anchor-id="applications-and-use-cases" id="applications-and-use-cases">Applications and Use Cases</h2>
<section id="natural-language-processing" class="level3">
<h3 class="anchored" data-anchor-id="natural-language-processing" id="natural-language-processing">Natural Language Processing</h3>
<ul>
<li><strong>Long Document Summarization</strong>: Processing entire books or research papers</li>
<li><strong>Multi-Turn Dialogue</strong>: Maintaining context over extended conversations</li>
<li><strong>Code Analysis</strong>: Understanding large codebases with complex dependencies</li>
<li><strong>Legal Document Analysis</strong>: Processing lengthy contracts and legal texts</li>
</ul>
</section>
<section id="scientific-computing" class="level3">
<h3 class="anchored" data-anchor-id="scientific-computing" id="scientific-computing">Scientific Computing</h3>
<ul>
<li><strong>Genomics</strong>: Analyzing long DNA sequences for pattern recognition</li>
<li><strong>Climate Modeling</strong>: Processing long time series of climate data</li>
<li><strong>Protein Folding</strong>: Understanding long protein sequences and their structures</li>
<li><strong>Astronomical Data</strong>: Analyzing long time series from celestial observations</li>
</ul>
</section>
<section id="creative-applications" class="level3">
<h3 class="anchored" data-anchor-id="creative-applications" id="creative-applications">Creative Applications</h3>
<ul>
<li><strong>Music Generation</strong>: Composing long musical pieces with coherent structure</li>
<li><strong>Story Generation</strong>: Creating novels or long-form narratives</li>
<li><strong>Video Analysis</strong>: Processing long video sequences for content understanding</li>
<li><strong>Game AI</strong>: Maintaining long-term strategy and memory in game environments</li>
</ul>
</section>
</section>
<section id="challenges-and-limitations" class="level2">
<h2 class="anchored" data-anchor-id="challenges-and-limitations" id="challenges-and-limitations">Challenges and Limitations</h2>
<section id="current-limitations" class="level3">
<h3 class="anchored" data-anchor-id="current-limitations" id="current-limitations">Current Limitations</h3>
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Warning</span>Known Limitations
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li><strong>Parallel Training</strong>: Training parallelizes through a scan, but the scan maps less directly onto GPU matrix-multiply units than attention does</li>
<li><strong>Complex Reasoning</strong>: May struggle with complex multi-step reasoning tasks</li>
<li><strong>Established Benchmarks</strong>: Many benchmarks optimized for transformer architectures</li>
<li><strong>Implementation Complexity</strong>: Requires careful implementation for optimal performance</li>
</ul>
</div>
</div>
</section>
<section id="ongoing-research-challenges" class="level3">
<h3 class="anchored" data-anchor-id="ongoing-research-challenges" id="ongoing-research-challenges">Ongoing Research Challenges</h3>
<ul>
<li><strong>Theoretical Understanding</strong>: Deepening our understanding of why Mamba works so well</li>
<li><strong>Architectural Improvements</strong>: Developing better hybrid architectures</li>
<li><strong>Scaling Laws</strong>: Understanding how Mamba performance scales with model size</li>
<li><strong>Task-Specific Adaptations</strong>: Optimizing Mamba for specific domains and tasks</li>
</ul>
</section>
</section>
<section id="future-directions" class="level2">
<h2 class="anchored" data-anchor-id="future-directions" id="future-directions">Future Directions</h2>
<section id="research-opportunities" class="level3">
<h3 class="anchored" data-anchor-id="research-opportunities" id="research-opportunities">Research Opportunities</h3>
<ul>
<li><strong>Multimodal Extensions</strong>: Extending Mamba to vision, audio, and other modalities</li>
<li><strong>Architecture Search</strong>: Automatically discovering optimal Mamba configurations</li>
<li><strong>Theoretical Analysis</strong>: Better understanding the representational capabilities</li>
<li><strong>Efficiency Improvements</strong>: Further optimizations for specific hardware platforms</li>
</ul>
</section>
<section id="potential-breakthroughs" class="level3">
<h3 class="anchored" data-anchor-id="potential-breakthroughs" id="potential-breakthroughs">Potential Breakthroughs</h3>
<ul>
<li><strong>Universal Sequence Models</strong>: Models that can handle any type of sequence data</li>
<li><strong>Extreme Long Context</strong>: Processing sequences with billions of tokens</li>
<li><strong>Real-time Processing</strong>: Ultra-low latency inference for streaming applications</li>
<li><strong>Neuromorphic Implementation</strong>: Implementing Mamba on brain-inspired hardware</li>
</ul>
</section>
<section id="industry-implications" class="level3">
<h3 class="anchored" data-anchor-id="industry-implications" id="industry-implications">Industry Implications</h3>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Transformative Potential
</div>
</div>
<div class="callout-body-container callout-body">
<p>Mamba’s efficiency gains could enable:</p>
<ul>
<li><strong>Cost Reduction</strong>: Dramatically lower computational costs</li>
<li><strong>New Applications</strong>: Previously impossible applications due to efficiency gains</li>
<li><strong>Democratization</strong>: Making long-context modeling accessible to smaller organizations</li>
<li><strong>Sustainability</strong>: Reducing environmental impact of large-scale modeling</li>
</ul>
</div>
</div>
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>Mamba represents a paradigm shift in sequence modeling, offering a mathematically elegant and computationally efficient alternative to transformers. Its linear scaling properties, selection mechanism, and hardware-optimized implementation make it particularly compelling for applications involving long sequences.</p>
<p>While transformers continue to dominate many areas of machine learning, Mamba’s unique advantages position it as a crucial tool in the sequence modeling toolkit. The architecture’s efficiency gains are not merely incremental improvements but represent qualitative leaps that enable entirely new classes of applications.</p>
<p>As the field continues to evolve, we can expect to see increased adoption of Mamba-based models, particularly in domains where long-context understanding is crucial. The ongoing research into hybrid architectures, theoretical foundations, and domain-specific adaptations suggests that Mamba’s influence will only grow in the coming years.</p>
<p>The success of Mamba also highlights the importance of looking beyond attention mechanisms for sequence modeling solutions. By drawing inspiration from classical signal processing and control theory, the Mamba architecture demonstrates that innovative solutions often emerge from interdisciplinary approaches to longstanding problems.</p>
<p>For practitioners and researchers working with sequence data, Mamba offers a powerful new paradigm that combines theoretical elegance with practical efficiency. Whether used as a drop-in replacement for transformers or as part of hybrid architectures, Mamba represents a significant step forward in our quest to build more efficient and capable sequence models.</p>
</section>
<section id="references-and-further-reading" class="level2">
<h2 class="anchored" data-anchor-id="references-and-further-reading" id="references-and-further-reading">References and Further Reading</h2>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Key References
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li><strong>Original Mamba Paper</strong>: “Mamba: Linear-Time Sequence Modeling with Selective State Spaces” (Gu &amp; Dao, 2023)</li>
<li><strong>State Space Models</strong>: “Efficiently Modeling Long Sequences with Structured State Spaces” (Gu et al., 2022)</li>
<li><strong>HiPPO Theory</strong>: “HiPPO: Recurrent Memory with Optimal Polynomial Projections” (Gu et al., 2020)</li>
<li><strong>Implementation Details</strong>: Official Mamba repository and CUDA kernels</li>
<li><strong>Comparative Studies</strong>: Various papers comparing Mamba with transformers across different tasks</li>
<li><strong>Hardware Optimization</strong>: Papers on efficient implementation of state space models</li>
</ul>
</div>
</div>



</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Complete Guide to Quantization and Pruning]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/deployment/quantization/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/deployment/quantization/</guid>
      <pubDate>Fri, 22 Aug 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>research</category>
      <category>advanced</category>
      <content:encoded><![CDATA[






<section id="complete-guide-to-quantization-and-pruning" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/deployment/quantization/quart.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>Model compression techniques are essential for deploying deep learning models in resource-constrained environments. Two of the most effective approaches are quantization and pruning, which can significantly reduce model size, memory usage, and inference time while maintaining acceptable performance.</p>
<section id="why-model-compression-matters" class="level3">
<h3 class="anchored" data-anchor-id="why-model-compression-matters" id="why-model-compression-matters">Why Model Compression Matters</h3>
<p>Model compression addresses several critical challenges in deep learning deployment:</p>
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Important</span>Key Benefits of Model Compression
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li><strong>Memory Efficiency</strong>: Reduced memory footprint enables deployment on mobile devices and edge hardware</li>
<li><strong>Inference Speed</strong>: Faster computations through reduced precision arithmetic and fewer operations</li>
<li><strong>Energy Consumption</strong>: Lower power requirements for battery-powered devices</li>
<li><strong>Cost Reduction</strong>: Decreased cloud computing costs and hardware requirements</li>
<li><strong>Accessibility</strong>: Enables AI deployment in environments with limited computational resources</li>
</ul>
</div>
</div>
</section>
</section>
<section id="quantization" class="level2">
<h2 class="anchored" data-anchor-id="quantization" id="quantization">Quantization</h2>
<p>Quantization reduces the precision of model weights and activations from floating-point representations (typically 32-bit) to lower-bit representations (8-bit, 4-bit, or even binary).</p>
<section id="fundamentals-of-quantization" class="level3">
<h3 class="anchored" data-anchor-id="fundamentals-of-quantization" id="fundamentals-of-quantization">Fundamentals of Quantization</h3>
<section id="uniform-quantization" class="level4">
<h4 class="anchored" data-anchor-id="uniform-quantization">Uniform Quantization</h4>
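<p>The most common form maps continuous values to a finite set of evenly spaced discrete levels:</p>
<p><span id="eq-quantization"><span class="math display">\[Q(x) = \text{clamp}\left(\text{round}\left(\frac{x}{\text{scale}}\right) + \text{zero\_point},\ q_{\min},\ q_{\max}\right) \tag{1}\]</span></span></p>
<p>The original value is recovered approximately as <span class="math inline">\(\hat{x} = \text{scale} \cdot (Q(x) - \text{zero\_point})\)</span>, where:</p>
<ul>
<li><code>scale</code>: The step size between adjacent quantization levels</li>
<li><code>zero_point</code>: The integer that the real value 0 maps to, so that zero is represented exactly</li>
<li><code>q_min</code>, <code>q_max</code>: The bounds of the integer range, e.g. 0 and 255 for unsigned 8-bit</li>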
</ul>
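<p>Uniform quantization and its inverse can be sketched in NumPy for the asymmetric (affine) 8-bit case. The sketch uses q = clamp(round(x/scale) + zero_point, q<sub>min</sub>, q<sub>max</sub>) and x̂ = scale·(q − zero_point). It assumes per-tensor min/max calibration and the unsigned range [0, 255], with scale = (x<sub>max</sub> − x<sub>min</sub>)/(q<sub>max</sub> − q<sub>min</sub>) and zero_point = q<sub>min</sub> − round(x<sub>min</sub>/scale). The function names are hypothetical, and a tensor whose values are all zero is not handled.</p>

```python
import numpy as np


def affine_params(x_min, x_max, q_min=0, q_max=255):
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)   # the range must contain 0
    scale = (x_max - x_min) / (q_max - q_min)
    zero_point = int(round(q_min - x_min / scale))
    return scale, int(np.clip(zero_point, q_min, q_max))


def quantize(x, scale, zero_point, q_min=0, q_max=255):
    q = np.round(x / scale) + zero_point
    return np.clip(q, q_min, q_max).astype(np.int32)


def dequantize(q, scale, zero_point):
    return scale * (q.astype(np.float64) - zero_point)


x = np.array([-1.0, -0.25, 0.0, 0.4, 2.0])
scale, zp = affine_params(x.min(), x.max())
q = quantize(x, scale, zp)
x_hat = dequantize(q, scale, zp)
print(scale, zp, q)
print(np.abs(x - x_hat).max() <= scale / 2)   # True: error is at most half a step
```

<p>Because the range is widened to include 0, the real value 0 maps exactly to <code>zero_point</code> and dequantizes back to exactly 0. This matters for zero padding and ReLU outputs.</p>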
</section>
<section id="asymmetric-vs-symmetric-quantization" class="level4">
<h4 class="anchored" data-anchor-id="asymmetric-vs-symmetric-quantization">Asymmetric vs Symmetric Quantization</h4>
<p><strong>Symmetric Quantization</strong>: The range is symmetric around zero, so the zero point is fixed at 0</p>
<ul>
<li>Simpler implementation</li>
<li>Better for weights that are roughly centered around zero</li>
<li>Formula: <span class="math inline">\(Q(x) = \text{round}(x / \text{scale})\)</span></li>
</ul>
<p><strong>Asymmetric Quantization</strong>: Zero point can be anywhere in the range</p>
<ul>
<li>Better utilization of the quantization range</li>
<li>More suitable for activations (often non-negative)</li>
<li>Handles skewed distributions better</li>
</ul>
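<p>The range-utilization argument can be checked numerically. The sketch assumes per-tensor 8-bit quantization and synthetic non-negative “activations” (ReLU applied to Gaussian noise). The function names are hypothetical. Symmetric quantization reserves half of its levels for negative values that never occur, so its step size is about twice the asymmetric one and its error is correspondingly larger.</p>

```python
import numpy as np


def symmetric_quant(x, bits=8):
    q_max = 2 ** (bits - 1) - 1                      # 127 for INT8
    scale = np.abs(x).max() / q_max
    q = np.clip(np.round(x / scale), -q_max, q_max)  # zero point fixed at 0
    return q * scale, scale


def asymmetric_quant(x, bits=8):
    q_max = 2 ** bits - 1                            # 255 for UINT8
    x_min, x_max = min(x.min(), 0.0), max(x.max(), 0.0)
    scale = (x_max - x_min) / q_max
    zero_point = np.round(-x_min / scale)
    q = np.clip(np.round(x / scale) + zero_point, 0, q_max)
    return (q - zero_point) * scale, scale


rng = np.random.default_rng(0)
acts = np.maximum(rng.standard_normal(10_000), 0.0)  # ReLU outputs: all non-negative

for name, fn in [("symmetric", symmetric_quant), ("asymmetric", asymmetric_quant)]:
    x_hat, scale = fn(acts)
    mse = np.mean((acts - x_hat) ** 2)
    print(f"{name:10s} scale={scale:.5f} mse={mse:.2e}")
```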
</section>
</section>
<section id="types-of-quantization" class="level3">
<h3 class="anchored" data-anchor-id="types-of-quantization" id="types-of-quantization">Types of Quantization</h3>
<section id="post-training-quantization-ptq" class="level4">
<h4 class="anchored" data-anchor-id="post-training-quantization-ptq">Post-Training Quantization (PTQ)</h4>
<p>Quantizes a pre-trained model without retraining:</p>
<p><strong>Static PTQ</strong>: Uses a calibration dataset to determine quantization parameters</p>
<ul>
<li>Faster deployment</li>
<li>No retraining or labeled data required (only a small, representative calibration set)</li>
<li>May have accuracy degradation for complex models</li>
</ul>
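<p><strong>Dynamic PTQ</strong>: Quantizes weights ahead of time and computes activation quantization parameters at runtime, for each input</p>
<ul>
<li>Often more accurate than static PTQ, since activation ranges match the actual input</li>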
<li>Slightly higher inference overhead</li>
<li>No calibration dataset needed</li>
</ul>
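<p>The practical difference between the two PTQ variants is when the activation scale gets fixed. The sketch assumes symmetric per-tensor INT8 activations and a simple max-abs observer (real toolkits offer several observers). The calibration data and names are hypothetical. The static scale is computed once from calibration batches and the dynamic scale from each incoming input. An input with a larger range than the calibration data is therefore clipped only in the static case.</p>

```python
import numpy as np


def int8_scale(max_abs):
    return max_abs / 127.0


def fake_quant(x, scale):
    """Quantize to INT8 with the given scale and immediately dequantize."""
    return np.clip(np.round(x / scale), -127, 127) * scale


rng = np.random.default_rng(0)
calibration = [rng.standard_normal(256) for _ in range(8)]   # small unlabeled set

# Static PTQ: observe activations once, then freeze the scale
static_scale = int8_scale(max(np.abs(b).max() for b in calibration))

# At inference time an input arrives with a larger range than calibration saw
x = 3.0 * rng.standard_normal(256)
static_err = np.abs(x - fake_quant(x, static_scale)).max()

# Dynamic PTQ: the scale is recomputed from this particular input
dynamic_scale = int8_scale(np.abs(x).max())
dynamic_err = np.abs(x - fake_quant(x, dynamic_scale)).max()
print(f"static max error {static_err:.3f}, dynamic max error {dynamic_err:.3f}")
```

<p>The dynamic version pays for its accuracy with an extra pass over each activation tensor to find its range, which is the runtime overhead mentioned above.</p>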
</section>
<section id="quantization-aware-training-qat" class="level4">
<h4 class="anchored" data-anchor-id="quantization-aware-training-qat">Quantization-Aware Training (QAT)</h4>
<p>Simulates quantization effects during training:</p>
<ul>
<li>Higher accuracy preservation</li>
<li>Requires retraining the model</li>
<li>Longer development time but better results</li>
</ul>
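<p>QAT is commonly implemented with “fake quantization” (quantize, then immediately dequantize) in the forward pass and a straight-through estimator (STE) in the backward pass, which treats rounding as the identity when computing gradients. The toy sketch below shows only these mechanics. It assumes linear regression with hand-written gradients, symmetric 4-bit weights with a fixed scale, and a common STE variant that zeroes gradients for clipped weights. All sizes and names are hypothetical. The optimizer updates full-precision “latent” weights, but the loss is always computed with their quantized version, and only the quantized weights are deployed.</p>

```python
import numpy as np


def fake_quant(w, scale, q_max=7):
    """Symmetric 4-bit fake quantization: round to the integer grid, then dequantize."""
    return np.clip(np.round(w / scale), -q_max, q_max) * scale


rng = np.random.default_rng(0)
X = rng.standard_normal((200, 8))
w_true = rng.uniform(-1.0, 1.0, size=8)
y = X @ w_true

w = np.zeros(8)                       # latent full-precision weights
scale, lr = 1.0 / 7, 0.1              # fixed scale: the grid covers [-1, 1]
for _ in range(300):
    w_q = fake_quant(w, scale)                       # forward pass sees quantized weights
    grad_wq = 2.0 * X.T @ (X @ w_q - y) / len(y)     # gradient w.r.t. the quantized weights
    # STE: copy the gradient to the latent weights, except where clipping is active
    grad_w = grad_wq * (np.abs(w / scale) <= 7)
    w -= lr * grad_w

w_deployed = fake_quant(w, scale)     # the weights that ship, all on the 4-bit grid
print(np.mean((X @ w_deployed - y) ** 2))
```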
</section>
</section>
<section id="bit-width-considerations" class="level3">
<h3 class="anchored" data-anchor-id="bit-width-considerations" id="bit-width-considerations">Bit-width Considerations</h3>
<div id="tbl-bitwidth" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-tbl figure">
<figcaption class="quarto-float-caption-top quarto-float-caption quarto-float-tbl" id="tbl-bitwidth-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Table&nbsp;1: Quantization bit-width comparison
</figcaption>
<div aria-describedby="tbl-bitwidth-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<table class="caption-top table">
<thead>
<tr class="header">
<th>Bit-width</th>
<th>Compression (vs. FP32)</th>
<th>Accuracy Trade-off</th>
<th>Use Case</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>8-bit (INT8)</td>
<td>4x (2x vs. FP16)</td>
<td>Minimal</td>
<td>Most common, well-supported</td>
</tr>
<tr class="even">
<td>4-bit</td>
<td>8x</td>
<td>Moderate</td>
<td>Inference-only scenarios</td>
</tr>
<tr class="odd">
<td>Binary/Ternary</td>
<td>Up to 32x</td>
<td>Significant</td>
<td>Extreme compression needs</td>
</tr>
</tbody>
</table>
</div>
</figure>
</div>
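<p>The compression factors in the table follow directly from the number of bits per weight. For a hypothetical 7-billion-parameter model, ignoring the small overhead of storing scales and zero points, the weight storage works out as follows.</p>

```python
def weight_storage_gb(n_params, bits):
    """Storage for the weights alone, in gigabytes (1 GB = 1e9 bytes)."""
    return n_params * bits / 8 / 1e9


N_PARAMS = 7_000_000_000  # hypothetical 7B-parameter model
for name, bits in [("FP32", 32), ("FP16", 16), ("INT8", 8), ("INT4", 4),
                   ("Ternary (2-bit)", 2), ("Binary", 1)]:
    gb = weight_storage_gb(N_PARAMS, bits)
    print(f"{name:16s} {gb:6.2f} GB  ({32 // bits}x smaller than FP32)")
```

<p>Ternary weights carry log<sub>2</sub>3 ≈ 1.58 bits of information each but are usually packed into 2 bits, which is why the sketch counts them as 2-bit.</p>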
</section>
<section id="mixed-precision-quantization" class="level3">
<h3 class="anchored" data-anchor-id="mixed-precision-quantization" id="mixed-precision-quantization">Mixed-Precision Quantization</h3>
<p>Different layers use different precisions based on sensitivity analysis:</p>
<ul>
<li>Critical layers (e.g., first and last layers) kept at higher precision</li>
<li>Less sensitive layers quantized more aggressively</li>
<li>Automated search algorithms can determine the bit allocation under a size or latency budget</li>
</ul>
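<p>A toy version of sensitivity-driven bit allocation can be written as a greedy loop. This is an illustrative assumption, not one of the published search algorithms: each layer gets a hypothetical sensitivity score (for example, the accuracy drop when only that layer is quantized to 4 bits), everything starts at 8 bits, and the least sensitive layers are dropped to 4 bits until the model fits a storage budget.</p>

```python
# Toy greedy mixed-precision bit allocation (illustrative only).
# Sensitivity scores are hypothetical; real tools use RL, Hessian-based
# metrics, or differentiable search instead of this greedy rule.

def allocate_bits(layers, budget_bits):
    """layers: list of (name, num_params, sensitivity).
    Returns ({name: bits}, total_weight_bits)."""
    bits = {name: 8 for name, _, _ in layers}
    total = sum(n * 8 for _, n, _ in layers)
    # Lower least-sensitive layers first until the budget is met
    for name, n, _ in sorted(layers, key=lambda layer: layer[2]):
        if total <= budget_bits:
            break
        bits[name] = 4
        total -= n * 4
    return bits, total

layers = [
    ("conv1", 1_000, 0.90),    # first layer: small but sensitive
    ("conv2", 50_000, 0.05),
    ("conv3", 100_000, 0.10),
    ("fc", 10_000, 0.80),      # last layer: sensitive
]
bits, total = allocate_bits(layers, budget_bits=800_000)
```

<p>With these made-up numbers the two middle layers go to 4 bits while the first and last layers stay at 8 bits, matching the heuristic in the list above.</p>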
</section>
</section>
<section id="pruning" class="level2">
<h2 class="anchored" data-anchor-id="pruning" id="pruning">Pruning</h2>
<p>Pruning removes redundant or less important connections, neurons, or entire layers from neural networks.</p>
<section id="types-of-pruning" class="level3">
<h3 class="anchored" data-anchor-id="types-of-pruning" id="types-of-pruning">Types of Pruning</h3>
<section id="magnitude-based-pruning" class="level4">
<h4 class="anchored" data-anchor-id="magnitude-based-pruning">Magnitude-Based Pruning</h4>
<p>Removes weights with the smallest absolute values:</p>
<ul>
<li>Simple to implement</li>
<li>Works well for many architectures</li>
<li>May not capture weight importance accurately</li>
</ul>
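<p>Conceptually, magnitude pruning is just a sort by absolute value. The framework-free sketch below zeroes the requested fraction of weights on a flat list; the helper name and sample weights are hypothetical.</p>

```python
# Minimal magnitude pruning on a flat list of weights: zero the fraction
# `amount` with the smallest absolute value (ties broken by position,
# since sorted() is stable).

def magnitude_prune(weights, amount):
    k = int(amount * len(weights))  # number of weights to remove
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = set(order[:k])
    mask = [0 if i in pruned else 1 for i in range(len(weights))]
    return [w * m for w, m in zip(weights, mask)], mask

weights = [0.5, -0.01, 0.2, -0.8, 0.03, 0.0, -0.3, 0.05]
pruned_weights, mask = magnitude_prune(weights, amount=0.5)
```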
</section>
<section id="gradient-based-pruning" class="level4">
<h4 class="anchored" data-anchor-id="gradient-based-pruning">Gradient-Based Pruning</h4>
<p>Considers gradients to determine weight importance:</p>
<ul>
<li><strong>Fisher Information</strong>: Approximates second-order (Hessian) information using squared gradients</li>
<li><strong>SNIP</strong> (Single-shot Network Pruning): Prunes before training</li>
<li><strong>GraSP</strong>: Gradient Signal Preservation</li>
</ul>
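<p>SNIP (Lee et al., 2019) scores each connection at initialization by its sensitivity, |gradient × weight|, computed from one mini-batch. The sketch below applies that score to made-up weights and gradients; in practice the gradients would come from a single backward pass on training data.</p>

```python
# SNIP-style connection sensitivity: score = |w * g|, normalized, keep top-k.
# Weights and gradients below are hypothetical numbers.

def snip_mask(weights, grads, keep_ratio):
    scores = [abs(w * g) for w, g in zip(weights, grads)]
    total = sum(scores)
    normalized = [s / total for s in scores]  # SNIP normalizes the saliencies
    k = int(keep_ratio * len(weights))
    ranked = sorted(range(len(weights)), key=lambda i: normalized[i], reverse=True)
    keep = set(ranked[:k])
    return [1 if i in keep else 0 for i in range(len(weights))]

weights = [0.5, -0.4, 0.1, 0.3]
grads = [0.01, 0.2, 0.9, -0.05]
mask = snip_mask(weights, grads, keep_ratio=0.5)
```

<p>Unlike pure magnitude pruning, this keeps the small weight at index 2 because the loss is highly sensitive to it.</p>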
</section>
<section id="lottery-ticket-hypothesis" class="level4">
<h4 class="anchored" data-anchor-id="lottery-ticket-hypothesis">Lottery Ticket Hypothesis</h4>
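<p>Identifies sparse subnetworks ("winning tickets") that can be trained in isolation from their original initialization:</p>
<ul>
<li>Found by iterative magnitude pruning</li>
<li>Surviving weights are reset to their initial values, or rewound to an early training checkpoint for larger networks</li>
<li>Can match the original network's accuracy</li>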
</ul>
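<p>The iterative magnitude pruning loop behind winning tickets can be sketched as follows. This is a minimal outline, assuming a user-supplied <code>train_fn</code> that returns trained weights and keeps masked weights at zero; the <code>toy_train</code> function is a hypothetical stand-in for real training so the example runs.</p>

```python
# Sketch of iterative magnitude pruning (IMP) with rewinding.
# train_fn(weights, mask) is a placeholder for your own training loop.

def iterative_magnitude_pruning(init_weights, train_fn, rounds, prune_frac):
    mask = [1] * len(init_weights)
    for _ in range(rounds):
        # 1. Train the sparse network starting from the rewound weights
        start = [w * m for w, m in zip(init_weights, mask)]
        trained = train_fn(start, mask)
        # 2. Prune prune_frac of the *remaining* weights by trained magnitude
        alive = [i for i, m in enumerate(mask) if m]
        k = int(prune_frac * len(alive))
        for i in sorted(alive, key=lambda j: abs(trained[j]))[:k]:
            mask[i] = 0
        # 3. Rewind: the next round restarts from init_weights
        #    (or from an early-training checkpoint in later variants)
    return mask

def toy_train(weights, mask):
    # Hypothetical stand-in for training: scales each weight by its position
    return [w * (i + 1) * m for i, (w, m) in enumerate(zip(weights, mask))]

init = [0.4, -0.3, 0.22, 0.12, -0.05, 0.6, -0.7, 0.15, 0.25, -0.35]
mask = iterative_magnitude_pruning(init, toy_train, rounds=3, prune_frac=0.2)
```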
</section>
</section>
<section id="pruning-granularities" class="level3">
<h3 class="anchored" data-anchor-id="pruning-granularities" id="pruning-granularities">Pruning Granularities</h3>
<div class="tabset-margin-container"></div><div class="panel-tabset">
<ul class="nav nav-tabs" role="tablist"><li class="nav-item" role="presentation"><a class="nav-link active" id="tabset-1-1-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-1" role="tab" aria-controls="tabset-1-1" aria-selected="true" href="">Unstructured Pruning</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-1-2-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-2" role="tab" aria-controls="tabset-1-2" aria-selected="false" href="">Structured Pruning</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-1-3-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-3" role="tab" aria-controls="tabset-1-3" aria-selected="false" href="">Semi-Structured Pruning</a></li></ul>
<div class="tab-content">
<div id="tabset-1-1" class="tab-pane active" role="tabpanel" aria-labelledby="tabset-1-1-tab">
<p>Removes individual weights regardless of their position:</p>
<ul>
<li>Higher compression ratios possible</li>
<li>Irregular sparsity patterns</li>
<li>May not lead to actual speedup without specialized hardware</li>
</ul>
</div>
<div id="tabset-1-2" class="tab-pane" role="tabpanel" aria-labelledby="tabset-1-2-tab">
<p>Removes entire structures (channels, filters, layers):</p>
<ul>
<li><strong>Channel Pruning</strong>: Removes entire feature map channels</li>
<li><strong>Filter Pruning</strong>: Removes convolutional filters</li>
<li><strong>Block Pruning</strong>: Removes structured weight blocks</li>
</ul>
<p>Benefits:</p>
<ul>
<li>Real speedups on standard hardware once the pruned structures are physically removed</li>
<li>Maintains regular, dense computation patterns</li>
<li>Easier to implement in existing frameworks</li>
</ul>
</div>
<div id="tabset-1-3" class="tab-pane" role="tabpanel" aria-labelledby="tabset-1-3-tab">
<p>Balances compression and hardware efficiency:</p>
<ul>
<li>N:M sparsity (e.g., 2:4 sparsity removes 2 out of every 4 weights)</li>
<li>Accelerated by sparse Tensor Cores on NVIDIA Ampere and newer GPUs</li>
<li>Good compression with hardware acceleration</li>
</ul>
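<p>A 2:4 pattern keeps the two largest-magnitude weights in every group of four consecutive weights. The sketch below only produces the pattern on a plain list; it does not perform the compressed storage that sparse hardware uses, and the helper name is hypothetical.</p>

```python
# Build an N:M sparsity pattern: in each group of m consecutive weights,
# keep the n largest by magnitude and zero the rest.

def nm_sparsify(weights, n=2, m=4):
    if len(weights) % m:
        raise ValueError("length must be a multiple of m")
    out = []
    for start in range(0, len(weights), m):
        group = weights[start:start + m]
        keep = sorted(range(m), key=lambda i: abs(group[i]), reverse=True)[:n]
        out.extend(w if i in keep else 0.0 for i, w in enumerate(group))
    return out

row = [0.9, -0.1, 0.05, -0.7, 0.2, 0.3, -0.25, 0.01]
sparse_row = nm_sparsify(row)
```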
</div>
</div>
</div>
</section>
<section id="pruning-schedules" class="level3">
<h3 class="anchored" data-anchor-id="pruning-schedules" id="pruning-schedules">Pruning Schedules</h3>
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">flowchart TD
    A[Start Training] --&gt; B{Pruning Strategy}
    B --&gt;|One-Shot| C[Prune to Target Sparsity at Once]
    B --&gt;|Gradual| D[Remove Weights Incrementally]
    B --&gt;|Iterative| E[Cycle: Prune-Train-Recover]
    C --&gt; F[Simple Implementation]
    D --&gt; G[Better Accuracy Preservation]
    E --&gt; H[Typically Highest Accuracy Retention]
    F --&gt; I[May Cause Accuracy Drop]
    G --&gt; J[Network Adapts Gradually]
    H --&gt; K[Computationally Expensive]
</pre>
</div>
<p></p></figure><p></p>
</div>
</div>
</div>
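<p>One widely used gradual schedule is the cubic ramp from Zhu and Gupta (2017), which prunes quickly early on, while redundancy is high, and slowly near the target. The sketch below computes the target sparsity at a given training step; the step counts and sparsity levels in the usage line are hypothetical.</p>

```python
# Cubic gradual-pruning schedule (Zhu and Gupta, 2017):
# s_t = s_f + (s_i - s_f) * (1 - (t - t0) / (n * dt)) ** 3

def cubic_sparsity(step, s_init, s_final, start_step, n_steps, delta_t):
    """Target sparsity at `step`; constant before and after the pruning window."""
    # Fraction of the pruning window completed, clipped to [0, 1]
    progress = (step - start_step) / (n_steps * delta_t)
    progress = max(0.0, min(1.0, progress))
    return s_final + (s_init - s_final) * (1.0 - progress) ** 3

# Ramp from 0% to 90% sparsity over 10 pruning steps, one every 100 training steps
schedule = [cubic_sparsity(t, 0.0, 0.9, start_step=0, n_steps=10, delta_t=100)
            for t in range(0, 1001, 250)]
```

<p>At each pruning step the weights are masked up to the current target (for example with magnitude pruning), and training continues so the network can adapt before the next step.</p>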
</section>
</section>
<section id="advanced-techniques" class="level2">
<h2 class="anchored" data-anchor-id="advanced-techniques" id="advanced-techniques">Advanced Techniques</h2>
<section id="knowledge-distillation-with-compression" class="level3">
<h3 class="anchored" data-anchor-id="knowledge-distillation-with-compression" id="knowledge-distillation-with-compression">Knowledge Distillation with Compression</h3>
<p>Uses the original model as a teacher to guide training of the compressed model:</p>
<ul>
<li>Teacher-student framework during compression</li>
<li>Helps maintain performance while reducing model size</li>
<li>Often combined with quantization-aware training</li>
</ul>
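<p>The standard distillation loss (Hinton et al., 2015) mixes a temperature-softened KL term against the teacher with ordinary cross-entropy on the label. The sketch below assumes the teacher is the uncompressed model and the student is the pruned or quantized one; the temperature, weighting, and logits are hypothetical values.</p>

```python
import math

def softmax(logits, temperature=1.0):
    scaled = [z / temperature for z in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def distillation_loss(student_logits, teacher_logits, label, temperature=4.0, alpha=0.7):
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    # KL(teacher || student) on softened distributions, scaled by T^2
    kl = sum(pt * math.log(pt / ps) for pt, ps in zip(p_teacher, p_student))
    soft = kl * temperature ** 2
    # Ordinary cross-entropy with the hard label at T = 1
    hard = -math.log(softmax(student_logits)[label])
    return alpha * soft + (1 - alpha) * hard

teacher_logits = [2.0, 0.5, -1.0]  # hypothetical outputs of the uncompressed model
student_logits = [1.2, 0.8, -0.5]  # hypothetical outputs of the compressed model
loss = distillation_loss(student_logits, teacher_logits, label=0)
```

<p>The T² factor keeps the soft-target gradients on a similar scale as the temperature changes.</p>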
</section>
<section id="neural-architecture-search-nas-for-compression" class="level3">
<h3 class="anchored" data-anchor-id="neural-architecture-search-nas-for-compression" id="neural-architecture-search-nas-for-compression">Neural Architecture Search (NAS) for Compression</h3>
<p>Automated design of compressed architectures:</p>
<ul>
<li>Hardware-aware NAS considers deployment constraints</li>
<li>Co-optimization of architecture and compression</li>
<li>Differentiable NAS for quantization</li>
</ul>
</section>
<section id="lottery-ticket-hypothesis-variants" class="level3">
<h3 class="anchored" data-anchor-id="lottery-ticket-hypothesis-variants" id="lottery-ticket-hypothesis-variants">Lottery Ticket Hypothesis Variants</h3>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Key Variants
</div>
</div>
<div class="callout-body-container callout-body">
<p><strong>SNIP (Single-shot Network Pruning)</strong>:</p>
<ul>
<li>Prunes networks before training</li>
<li>Uses gradient information for importance scoring</li>
<li>Faster than iterative approaches</li>
</ul>
<p><strong>GraSP (Gradient Signal Preservation)</strong>:</p>
<ul>
<li>Maintains gradient flow through the network</li>
<li>Better performance on deep networks</li>
<li>Considers layer-wise interactions</li>
</ul>
</div>
</div>
</section>
</section>
<section id="implementation-examples" class="level2">
<h2 class="anchored" data-anchor-id="implementation-examples" id="implementation-examples">Implementation Examples</h2>
<section id="pytorch-quantization-example" class="level3">
<h3 class="anchored" data-anchor-id="pytorch-quantization-example" id="pytorch-quantization-example">PyTorch Quantization Example</h3>
<div id="pytorch-quantization" class="cell" data-execution_count="1">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb1"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.quantization <span class="im">as</span> quant</span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Define a simple model</span></span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> SimpleModel(nn.Module):</span>
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>):</span>
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb1-9"><a href="#cb1-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.conv1 <span class="op">=</span> nn.Conv2d(<span class="dv">3</span>, <span class="dv">32</span>, <span class="dv">3</span>)</span>
<span id="cb1-10"><a href="#cb1-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.conv2 <span class="op">=</span> nn.Conv2d(<span class="dv">32</span>, <span class="dv">64</span>, <span class="dv">3</span>)</span>
<span id="cb1-11"><a href="#cb1-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.fc <span class="op">=</span> nn.Linear(<span class="dv">64</span>, <span class="dv">10</span>)</span>
<span id="cb1-12"><a href="#cb1-12" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-13"><a href="#cb1-13" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb1-14"><a href="#cb1-14" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> torch.relu(<span class="va">self</span>.conv1(x))</span>
<span id="cb1-15"><a href="#cb1-15" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> torch.relu(<span class="va">self</span>.conv2(x))</span>
<span id="cb1-16"><a href="#cb1-16" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> torch.flatten(x, <span class="dv">1</span>)</span>
<span id="cb1-17"><a href="#cb1-17" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.fc(x)</span>
<span id="cb1-18"><a href="#cb1-18" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x</span>
<span id="cb1-19"><a href="#cb1-19" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-20"><a href="#cb1-20" aria-hidden="true" tabindex="-1"></a><span class="co"># Post-training quantization</span></span>
<span id="cb1-21"><a href="#cb1-21" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> SimpleModel()</span>
<span id="cb1-22"><a href="#cb1-22" aria-hidden="true" tabindex="-1"></a>model.<span class="bu">eval</span>()</span>
<span id="cb1-23"><a href="#cb1-23" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-24"><a href="#cb1-24" aria-hidden="true" tabindex="-1"></a><span class="co"># Prepare model for quantization</span></span>
<span id="cb1-25"><a href="#cb1-25" aria-hidden="true" tabindex="-1"></a>model.qconfig <span class="op">=</span> quant.get_default_qconfig(<span class="st">'fbgemm'</span>)</span>
<span id="cb1-26"><a href="#cb1-26" aria-hidden="true" tabindex="-1"></a>quant.prepare(model, inplace<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb1-27"><a href="#cb1-27" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-28"><a href="#cb1-28" aria-hidden="true" tabindex="-1"></a><span class="co"># Calibrate with sample data</span></span>
<span id="cb1-29"><a href="#cb1-29" aria-hidden="true" tabindex="-1"></a><span class="co"># calibrate_model(model, calibration_data)</span></span>
<span id="cb1-30"><a href="#cb1-30" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-31"><a href="#cb1-31" aria-hidden="true" tabindex="-1"></a><span class="co"># Convert to quantized model</span></span>
<span id="cb1-32"><a href="#cb1-32" aria-hidden="true" tabindex="-1"></a>quantized_model <span class="op">=</span> quant.convert(model, inplace<span class="op">=</span><span class="va">False</span>)</span></code></pre></div></div>
</div>
</section>
<section id="pruning-example" class="level3">
<h3 class="anchored" data-anchor-id="pruning-example" id="pruning-example">Pruning Example</h3>
<div id="pytorch-pruning" class="cell" data-execution_count="2">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn.utils.prune <span class="im">as</span> prune</span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Apply magnitude-based unstructured pruning</span></span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> SimpleModel()</span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a>parameters_to_prune <span class="op">=</span> [</span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a>    (model.conv1, <span class="st">'weight'</span>),</span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a>    (model.conv2, <span class="st">'weight'</span>),</span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a>    (model.fc, <span class="st">'weight'</span>),</span>
<span id="cb2-10"><a href="#cb2-10" aria-hidden="true" tabindex="-1"></a>]</span>
<span id="cb2-11"><a href="#cb2-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-12"><a href="#cb2-12" aria-hidden="true" tabindex="-1"></a><span class="co"># Prune 30% of weights globally</span></span>
<span id="cb2-13"><a href="#cb2-13" aria-hidden="true" tabindex="-1"></a>prune.global_unstructured(</span>
<span id="cb2-14"><a href="#cb2-14" aria-hidden="true" tabindex="-1"></a>    parameters_to_prune,</span>
<span id="cb2-15"><a href="#cb2-15" aria-hidden="true" tabindex="-1"></a>    pruning_method<span class="op">=</span>prune.L1Unstructured,</span>
<span id="cb2-16"><a href="#cb2-16" aria-hidden="true" tabindex="-1"></a>    amount<span class="op">=</span><span class="fl">0.3</span>,</span>
<span id="cb2-17"><a href="#cb2-17" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb2-18"><a href="#cb2-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-19"><a href="#cb2-19" aria-hidden="true" tabindex="-1"></a><span class="co"># Make pruning permanent</span></span>
<span id="cb2-20"><a href="#cb2-20" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> module, param <span class="kw">in</span> parameters_to_prune:</span>
<span id="cb2-21"><a href="#cb2-21" aria-hidden="true" tabindex="-1"></a>    prune.remove(module, param)</span></code></pre></div></div>
</div>
</section>
<section id="structured-pruning-implementation" class="level3">
<h3 class="anchored" data-anchor-id="structured-pruning-implementation" id="structured-pruning-implementation">Structured Pruning Implementation</h3>
<div id="structured-pruning" class="cell" data-execution_count="3">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn.utils.prune <span class="im">as</span> prune</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> channel_pruning(model, layer_name, amount):</span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Prune channels based on L1 norm of filters"""</span></span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a>    layer <span class="op">=</span> <span class="bu">getattr</span>(model, layer_name)</span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Calculate channel importance (L1 norm)</span></span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a>    importance <span class="op">=</span> torch.norm(layer.weight.data, p<span class="op">=</span><span class="dv">1</span>, dim<span class="op">=</span>[<span class="dv">1</span>, <span class="dv">2</span>, <span class="dv">3</span>])</span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Determine channels to prune</span></span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a>    num_channels <span class="op">=</span> <span class="bu">len</span>(importance)</span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a>    num_prune <span class="op">=</span> <span class="bu">int</span>(amount <span class="op">*</span> num_channels)</span>
<span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> num_prune <span class="op">&gt;</span> <span class="dv">0</span>:</span>
<span id="cb3-15"><a href="#cb3-15" aria-hidden="true" tabindex="-1"></a>        _, indices <span class="op">=</span> torch.topk(importance, num_prune, largest<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb3-16"><a href="#cb3-16" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-17"><a href="#cb3-17" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Create pruning mask</span></span>
<span id="cb3-18"><a href="#cb3-18" aria-hidden="true" tabindex="-1"></a>        prune.structured(layer, name<span class="op">=</span><span class="st">'weight'</span>, amount<span class="op">=</span>amount, </span>
<span id="cb3-19"><a href="#cb3-19" aria-hidden="true" tabindex="-1"></a>                        dim<span class="op">=</span><span class="dv">0</span>, importance_scores<span class="op">=</span>importance)</span>
<span id="cb3-20"><a href="#cb3-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-21"><a href="#cb3-21" aria-hidden="true" tabindex="-1"></a><span class="co"># Example usage</span></span>
<span id="cb3-22"><a href="#cb3-22" aria-hidden="true" tabindex="-1"></a>channel_pruning(model, <span class="st">'conv1'</span>, <span class="fl">0.5</span>)  <span class="co"># Prune 50% of channels</span></span></code></pre></div></div>
</div>
</section>
</section>
<section id="best-practices" class="level2">
<h2 class="anchored" data-anchor-id="best-practices" id="best-practices">Best Practices</h2>
<section id="quantization-best-practices" class="level3">
<h3 class="anchored" data-anchor-id="quantization-best-practices" id="quantization-best-practices">Quantization Best Practices</h3>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Quantization Guidelines
</div>
</div>
<div class="callout-body-container callout-body">
<ol type="1">
<li><strong>Start with 8-bit quantization</strong>: Best balance of compression and accuracy</li>
<li><strong>Use representative calibration data</strong>: It should match the distribution of real deployment inputs</li>
<li><strong>Layer sensitivity analysis</strong>: Identify which layers are most sensitive to quantization</li>
<li><strong>Gradual quantization</strong>: Start with less aggressive quantization and increase gradually</li>
<li><strong>Batch normalization folding</strong>: Combine BN parameters with preceding layer weights</li>
</ol>
</div>
</div>
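<p>Batch normalization folding is simple algebra: for a layer output <em>y = w·x + b</em> followed by BN, the result equals a single layer with weights scaled by γ/√(σ² + ε) and bias (b − μ)·γ/√(σ² + ε) + β. The sketch below checks this on one output channel with hypothetical numbers; in a real convolution every element of that channel's filter is scaled by the same factor.</p>

```python
import math

def fold_batchnorm(weight, bias, gamma, beta, running_mean, running_var, eps=1e-5):
    """Fold BN parameters for one output channel into its weights and bias."""
    factor = gamma / math.sqrt(running_var + eps)
    folded_weight = [w * factor for w in weight]
    folded_bias = (bias - running_mean) * factor + beta
    return folded_weight, folded_bias

# Hypothetical single output channel with a 2-element filter
x = [1.0, 2.0]
filt, b = [0.5, -1.0], 0.1
gamma, beta, mean, var = 2.0, 0.3, 0.2, 4.0

# Unfolded: layer output followed by BatchNorm (inference mode)
conv_out = sum(w * xi for w, xi in zip(filt, x)) + b
bn_out = gamma * (conv_out - mean) / math.sqrt(var + 1e-5) + beta

# Folded: one layer with adjusted weights and bias
fw, fb = fold_batchnorm(filt, b, gamma, beta, mean, var)
folded_out = sum(w * xi for w, xi in zip(fw, x)) + fb
```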
</section>
<section id="pruning-best-practices" class="level3">
<h3 class="anchored" data-anchor-id="pruning-best-practices" id="pruning-best-practices">Pruning Best Practices</h3>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Pruning Guidelines
</div>
</div>
<div class="callout-body-container callout-body">
<ol type="1">
<li><strong>Sensitivity analysis</strong>: Determine which layers/channels are most important</li>
<li><strong>Gradual pruning</strong>: Remove weights incrementally during training</li>
<li><strong>Fine-tuning</strong>: Always fine-tune after pruning to recover accuracy</li>
<li><strong>Layer-wise pruning ratios</strong>: Different layers may benefit from different pruning ratios</li>
<li><strong>Structured over unstructured</strong>: Prefer structured pruning when you need speedups on standard hardware</li>
</ol>
</div>
</div>
</section>
<section id="combined-approaches" class="level3">
<h3 class="anchored" data-anchor-id="combined-approaches" id="combined-approaches">Combined Approaches</h3>
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Warning</span>Important Considerations
</div>
</div>
<div class="callout-body-container callout-body">
<ol type="1">
<li><strong>Order matters</strong>: Generally prune first, then quantize</li>
<li><strong>Joint optimization</strong>: Consider both techniques simultaneously during training</li>
<li><strong>Hardware considerations</strong>: Align compression strategy with deployment hardware</li>
<li><strong>Validation throughout</strong>: Monitor accuracy at each compression stage</li>
</ol>
</div>
</div>
</section>
</section>
<section id="tools-and-frameworks" class="level2">
<h2 class="anchored" data-anchor-id="tools-and-frameworks" id="tools-and-frameworks">Tools and Frameworks</h2>
<div id="tbl-tools" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-tbl figure">
<figcaption class="quarto-float-caption-top quarto-float-caption quarto-float-tbl" id="tbl-tools-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Table&nbsp;2: Model compression tools comparison
</figcaption>
<div aria-describedby="tbl-tools-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<table class="caption-top table">
<colgroup>
<col style="width: 21%">
<col style="width: 25%">
<col style="width: 17%">
<col style="width: 35%">
</colgroup>
<thead>
<tr class="header">
<th>Framework</th>
<th>Quantization</th>
<th>Pruning</th>
<th>Special Features</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><strong>PyTorch</strong></td>
<td>torch.quantization</td>
<td>torch.nn.utils.prune</td>
<td>TorchScript optimization</td>
</tr>
<tr class="even">
<td><strong>TensorFlow</strong></td>
<td>Model Optimization Toolkit</td>
<td>Model Optimization Toolkit pruning API</td>
<td>TFLite for mobile</td>
</tr>
<tr class="odd">
<td><strong>NVIDIA TensorRT</strong></td>
<td>FP16 and INT8 precision (with calibration)</td>
<td>2:4 structured sparsity</td>
<td>High-performance inference, layer fusion</td>
</tr>
<tr class="even">
<td><strong>Intel Neural Compressor</strong></td>
<td>Cross-framework support</td>
<td>Structured and unstructured pruning</td>
<td>Accuracy-driven auto-tuning, hardware-specific optimizations</td>
</tr>
</tbody>
</table>
</div>
</figure>
</div>
<section id="specialized-tools" class="level3">
<h3 class="anchored" data-anchor-id="specialized-tools" id="specialized-tools">Specialized Tools</h3>
<p><strong>NVIDIA TensorRT</strong>:</p>
<ul>
<li>High-performance inference optimization</li>
<li>Reduced-precision (FP16 and INT8) inference</li>
<li>Layer fusion and kernel optimization</li>
</ul>
<p><strong>Intel Neural Compressor</strong>:</p>
<ul>
<li>Cross-framework quantization</li>
<li>Automatic accuracy-driven tuning</li>
<li>Hardware-specific optimizations</li>
</ul>
<p><strong>Apache TVM</strong>:</p>
<ul>
<li>Deep learning compiler stack</li>
<li>Auto-tuning for different hardware</li>
<li>Graph-level optimizations</li>
</ul>
<p><strong>ONNX Runtime</strong>:</p>
<ul>
<li>Cross-platform inference optimization</li>
<li>Dynamic and static quantization</li>
<li>Graph optimizations</li>
</ul>
</section>
</section>
<section id="sec-future" class="level2">
<h2 class="anchored" data-anchor-id="sec-future" id="sec-future">Future Directions</h2>
<section id="emerging-quantization-techniques" class="level3">
<h3 class="anchored" data-anchor-id="emerging-quantization-techniques" id="emerging-quantization-techniques">Emerging Quantization Techniques</h3>
<ul>
<li><strong>Mixed-bit Networks</strong>: Different precisions for different operations</li>
<li><strong>Learned Quantization</strong>: Neural networks learn quantization parameters</li>
<li><strong>Hardware-Software Co-design</strong>: Quantization schemes designed for specific hardware</li>
</ul>
</section>
<section id="advanced-pruning-methods" class="level3">
<h3 class="anchored" data-anchor-id="advanced-pruning-methods" id="advanced-pruning-methods">Advanced Pruning Methods</h3>
<ul>
<li><strong>Differentiable Pruning</strong>: End-to-end learning of sparse structures</li>
<li><strong>Dynamic Sparsity</strong>: Runtime adaptation of sparsity patterns</li>
<li><strong>Cross-layer Dependencies</strong>: Pruning decisions considering global network structure</li>
</ul>
</section>
<section id="integration-with-other-techniques" class="level3">
<h3 class="anchored" data-anchor-id="integration-with-other-techniques" id="integration-with-other-techniques">Integration with Other Techniques</h3>
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">graph TD
    A[Model Compression] --&gt; B[Neural Architecture Search]
    A --&gt; C[Federated Learning]
    A --&gt; D[Continual Learning]
    B --&gt; E[Joint Architecture &amp; Compression Optimization]
    C --&gt; F[Compression for Distributed Training]
    D --&gt; G[Maintaining Compression Benefits]
</pre>
</div>
<p></p></figure><p></p>
</div>
</div>
</div>
</section>
<section id="hardware-considerations" class="level3">
<h3 class="anchored" data-anchor-id="hardware-considerations" id="hardware-considerations">Hardware Considerations</h3>
<ul>
<li><strong>Specialized Accelerators</strong>: ASICs designed for sparse and low-precision computation</li>
<li><strong>In-memory Computing</strong>: Compression for neuromorphic and analog computing</li>
<li><strong>Edge AI Chips</strong>: Dedicated hardware for compressed model inference</li>
</ul>
</section>
</section>
<section id="sec-conclusion" class="level2">
<h2 class="anchored" data-anchor-id="sec-conclusion" id="sec-conclusion">Conclusion</h2>
<p>Quantization and pruning are essential techniques for practical deep learning deployment. Success requires understanding the trade-offs between compression ratio, accuracy preservation, and hardware compatibility. The field continues to evolve, with new methods steadily increasing how much a network can be compressed at a given accuracy.</p>
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Important</span>Key Takeaways
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li>Start with well-established techniques (8-bit quantization, magnitude pruning)</li>
<li>Always validate on representative data and deployment hardware</li>
<li>Consider the entire deployment pipeline, not just model accuracy</li>
<li>Combine multiple compression techniques for maximum benefit</li>
<li>Stay informed about hardware-specific optimizations and emerging methods</li>
</ul>
</div>
</div>
<p>The future of neural network compression likely lies in automated, hardware-aware optimization that accounts for the full set of deployment constraints (memory, latency, energy) while preserving task accuracy.</p>
<hr>
</section>
<section id="appendix-additional-resources" class="level2">
<h2 class="anchored" data-anchor-id="appendix-additional-resources" id="appendix-additional-resources">Appendix: Additional Resources</h2>
<section id="code-repositories" class="level3">
<h3 class="anchored" data-anchor-id="code-repositories" id="code-repositories">Code Repositories</h3>
<ul>
<li><a href="https://github.com/pytorch/pytorch">PyTorch Model Optimization</a></li>
<li><a href="https://github.com/tensorflow/model-optimization">TensorFlow Model Optimization</a></li>
<li><a href="https://github.com/intel/neural-compressor">Neural Compressor</a></li>
</ul>
</section>
<section id="research-papers" class="level3">
<h3 class="anchored" data-anchor-id="research-papers" id="research-papers">Research Papers</h3>
<ul>
<li>Lottery Ticket Hypothesis <span class="citation" data-cites="frankle2019lottery">[@frankle2019lottery]</span></li>
<li>Quantization and Training of Neural Networks <span class="citation" data-cites="jacob2018quantization">[@jacob2018quantization]</span></li>
<li>Structured Pruning Methods <span class="citation" data-cites="liu2017learning">[@liu2017learning]</span></li>
</ul>
</section>
<section id="datasets-for-evaluation" class="level3">
<h3 class="anchored" data-anchor-id="datasets-for-evaluation" id="datasets-for-evaluation">Datasets for Evaluation</h3>
<ul>
<li>ImageNet for computer vision models</li>
<li>GLUE benchmark for NLP models</li>
<li>Common Voice for speech models</li>
</ul>


</section>
</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Complete Guide to DINOv3: Self-Supervised Vision Transformers]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/dino/dino-v3/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/dino/dino-v3/</guid>
      <pubDate>Fri, 22 Aug 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>research</category>
      <category>advanced</category>
      <content:encoded><![CDATA[






<section id="complete-guide-to-dinov3-self-supervised-vision-transformers" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/dino/dino-v3/dino.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>DINOv3 is a self-supervised vision foundation model from Meta AI designed as a universal vision backbone: its frozen features are reported to reach state-of-the-art performance across diverse visual tasks without fine-tuning. It scales self-supervised learning to 1.7 billion training images and models of up to 7 billion parameters.</p>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Key Innovation
</div>
</div>
<div class="callout-body-container callout-body">
<p>DINOv3’s ability to produce high-quality, transferable features that work across different domains and tasks straight out of the box represents a significant advancement in foundation models for computer vision.</p>
</div>
</div>
</section>
<section id="what-is-dinov3" class="level2">
<h2 class="anchored" data-anchor-id="what-is-dinov3" id="what-is-dinov3">What is DINOv3?</h2>
<p>DINOv3 is a self-supervised learning method for computer vision that trains Vision Transformers (ViTs) to learn visual representations from unlabeled images. The resulting backbone is typically used frozen: its global (CLS) and patch-level features are passed to lightweight task-specific heads.</p>
<section id="core-principles" class="level3">
<h3 class="anchored" data-anchor-id="core-principles" id="core-principles">Core Principles</h3>
<p><strong>Self-Supervised Learning</strong>: DINOv3 learns by comparing different views of the same image, using a teacher-student framework where the model learns to predict consistent representations across augmented versions of images.</p>
<p><strong>Universal Features</strong>: Unlike traditional models trained for specific tasks, DINOv3 produces general-purpose visual features that transfer well to various downstream applications.</p>
<p><strong>Scalability</strong>: The architecture is designed to scale effectively with both dataset size and model parameters, enabling training on massive datasets.</p>
</section>
</section>
<section id="evolution-from-dino-to-dinov3" class="level2">
<h2 class="anchored" data-anchor-id="evolution-from-dino-to-dinov3" id="evolution-from-dino-to-dinov3">Evolution from DINO to DINOv3</h2>
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">timeline
    title Evolution of DINO Models
    
    2021 : DINO v1
         : Self-distillation with ViTs
         : Emergent segmentation properties
         : Limited scale
    
    2023 : DINO v2
         : Improved training methodology
         : Better data curation
         : Enhanced downstream performance
    
    2025 : DINO v3
         : Massive scale (1.7B images)
         : Universal backbone
         : 7B parameter models
         : State-of-the-art frozen performance
</pre>
</div>
<p></p></figure><p></p>
</div>
</div>
</div>
<section id="dino-2021" class="level3">
<h3 class="anchored" data-anchor-id="dino-2021" id="dino-2021">DINO (2021)</h3>
<ul>
<li>Introduced self-distillation with Vision Transformers</li>
<li>Demonstrated emergent segmentation properties</li>
<li>Limited to smaller scales and datasets</li>
</ul>
</section>
<section id="dinov2-2023" class="level3">
<h3 class="anchored" data-anchor-id="dinov2-2023" id="dinov2-2023">DINOv2 (2023)</h3>
<ul>
<li>Improved training methodology</li>
<li>Better data curation techniques</li>
<li>Enhanced performance on downstream tasks</li>
</ul>
</section>
<section id="dinov3-2024" class="level3">
<h3 class="anchored" data-anchor-id="dinov3-2024" id="dinov3-2024">DINOv3 (2025)</h3>
<ul>
<li>Massive scale: 1.7 billion images, 7 billion parameters</li>
<li>Frozen backbone reported to outperform specialized models on a broad range of tasks</li>
<li>Universal performance across domains (natural, aerial, medical images)</li>
<li>High-resolution feature extraction capabilities</li>
</ul>
</section>
</section>
<section id="key-features-and-capabilities" class="level2">
<h2 class="anchored" data-anchor-id="key-features-and-capabilities" id="key-features-and-capabilities">Key Features and Capabilities</h2>
<div class="tabset-margin-container"></div><div class="panel-tabset">
<ul class="nav nav-tabs" role="tablist"><li class="nav-item" role="presentation"><a class="nav-link active" id="tabset-1-1-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-1" role="tab" aria-controls="tabset-1-1" aria-selected="true" href="">Universal Vision Backbone</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-1-2-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-2" role="tab" aria-controls="tabset-1-2" aria-selected="false" href="">High-Resolution Features</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-1-3-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-3" role="tab" aria-controls="tabset-1-3" aria-selected="false" href="">Frozen Model Performance</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-1-4-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-4" role="tab" aria-controls="tabset-1-4" aria-selected="false" href="">Emergent Properties</a></li></ul>
<div class="tab-content">
<div id="tabset-1-1" class="tab-pane active" role="tabpanel" aria-labelledby="tabset-1-1-tab">
<ul>
<li>Single model works across multiple domains without fine-tuning</li>
<li>Consistent performance on natural images, satellite imagery, and specialized domains</li>
<li>Reduces the need for domain-specific backbone training</li>
</ul>
</div>
<div id="tabset-1-2" class="tab-pane" role="tabpanel" aria-labelledby="tabset-1-2-tab">
<ul>
<li>Produces detailed, semantically meaningful feature maps</li>
<li>Enables fine-grained understanding of image content</li>
<li>Supports dense prediction tasks effectively</li>
</ul>
</div>
<div id="tabset-1-3" class="tab-pane" role="tabpanel" aria-labelledby="tabset-1-3-tab">
<ul>
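<li>Achieves strong, often state-of-the-art results with the backbone frozen; only lightweight task heads are trained</li>
<li>One frozen backbone can be shared by several task heads, reducing training and deployment cost</li>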
<li>Simplifies integration into existing pipelines</li>
</ul>
</div>
<div id="tabset-1-4" class="tab-pane" role="tabpanel" aria-labelledby="tabset-1-4-tab">
<ul>
<li>Patch features that group pixels into semantically coherent regions without segmentation labels</li>
<li>Object localization without localization labels</li>
<li>Scene understanding and spatial reasoning</li>
</ul>
</div>
</div>
</div>
</section>
<section id="technical-architecture" class="level2">
<h2 class="anchored" data-anchor-id="technical-architecture" id="technical-architecture">Technical Architecture</h2>
<section id="vision-transformer-backbone" class="level3">
<h3 class="anchored" data-anchor-id="vision-transformer-backbone" id="vision-transformer-backbone">Vision Transformer Backbone</h3>
<p>DINOv3 builds on the Vision Transformer architecture. At a high level, an image flows through the backbone as follows:</p>
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">flowchart LR
    A[Input Image] --&gt; B[Patch Embedding]
    B --&gt; C[Positional Encoding]
    C --&gt; D[Transformer Blocks]
    D --&gt; E[Feature Extraction]
    E --&gt; F[Output Features]
</pre>
</div>
<p></p></figure><p></p>
</div>
</div>
</div>
</section>
<section id="self-distillation-framework" class="level3">
<h3 class="anchored" data-anchor-id="self-distillation-framework" id="self-distillation-framework">Self-Distillation Framework</h3>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Teacher-Student Learning
</div>
</div>
<div class="callout-body-container callout-body">
<p>Self-distillation uses two networks with the same architecture: a student, trained by gradient descent, and a teacher, whose weights are an exponential moving average (EMA) of the student's weights rather than being trained directly.</p>
</div>
</div>
<p><strong>Teacher Network</strong>:</p>
<ul>
<li>Exponential moving average of student weights</li>
<li>Produces stable target representations</li>
<li>Uses centering and sharpening operations</li>
</ul>
<p><strong>Student Network</strong>:</p>
<ul>
<li>Main learning network</li>
<li>Processes augmented image views</li>
<li>Minimizes the cross-entropy between its output distribution and the teacher's</li>
</ul>
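<p>The teacher-student objective can be sketched in a few lines of PyTorch. This is a simplified, DINO-style version (centering, temperature sharpening, and an EMA teacher update), not DINOv3's complete training objective. The temperatures and momentum values are illustrative assumptions, and small <code>nn.Linear</code> layers with random inputs stand in for the real backbone, projection head, and augmented views.</p>

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


def dino_loss(student_logits, teacher_logits, center, student_temp=0.1, teacher_temp=0.04):
    """Cross-entropy between centered, sharpened teacher targets and student predictions."""
    # Teacher: subtract the running center, sharpen with a low temperature, block gradients.
    teacher_probs = F.softmax((teacher_logits - center) / teacher_temp, dim=-1).detach()
    # Student: log-probabilities at a higher (softer) temperature.
    student_log_probs = F.log_softmax(student_logits / student_temp, dim=-1)
    return -(teacher_probs * student_log_probs).sum(dim=-1).mean()


@torch.no_grad()
def update_center(center, teacher_logits, momentum=0.9):
    # EMA of the mean teacher output; subtracting it discourages collapse to one dimension.
    return momentum * center + (1 - momentum) * teacher_logits.mean(dim=0)


@torch.no_grad()
def ema_update(teacher, student, momentum=0.996):
    # teacher <- momentum * teacher + (1 - momentum) * student, parameter by parameter.
    for p_t, p_s in zip(teacher.parameters(), student.parameters()):
        p_t.mul_(momentum).add_(p_s, alpha=1 - momentum)


# One training step with stand-in networks and random "views".
torch.manual_seed(0)
student = nn.Linear(32, 64)            # stand-in for backbone + projection head
teacher = copy.deepcopy(student)
for p in teacher.parameters():
    p.requires_grad_(False)            # the teacher is never updated by gradients
center = torch.zeros(64)
optimizer = torch.optim.SGD(student.parameters(), lr=0.1)

view_a, view_b = torch.randn(8, 32), torch.randn(8, 32)  # two "views" of 8 images
with torch.no_grad():
    teacher_out = teacher(view_a)      # teacher sees one view
student_out = student(view_b)          # student sees the other view
loss = dino_loss(student_out, teacher_out, center)

optimizer.zero_grad()
loss.backward()
optimizer.step()
ema_update(teacher, student)
center = update_center(center, teacher_out)
print(f"loss = {loss.item():.4f}")
```

<p>In the full method the loss is computed over every pair of teacher and student views (excluding a view paired with itself) and averaged; a single pair is shown here to keep the step readable.</p>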
</section>
<section id="key-components" class="level3">
<h3 class="anchored" data-anchor-id="key-components" id="key-components">Key Components</h3>
<ol type="1">
<li><strong>Patch Embedding</strong>: Splits the image into fixed-size patches (16×16 pixels in DINOv3) and linearly projects each patch into the embedding space</li>
<li><strong>Multi-Head Attention</strong>: Captures relationships between image patches</li>
<li><strong>Feed-Forward Networks</strong>: Transform each token independently after attention</li>
<li><strong>Layer Normalization</strong>: Stabilizes training</li>
<li><strong>CLS Token</strong>: A learned token whose final state serves as the global image representation</li>
</ol>
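<p>The components above fit together roughly as in the following minimal ViT-style encoder. It is a simplified sketch, not the DINOv3 implementation: it uses learned absolute position embeddings and PyTorch's built-in <code>nn.TransformerEncoderLayer</code>, and the sizes (<code>dim=192</code>, <code>depth=2</code>, <code>heads=3</code>) are arbitrary illustrative choices.</p>

```python
import torch
import torch.nn as nn


class TinyViT(nn.Module):
    """Minimal ViT-style encoder: patch embedding, CLS token, position embeddings, transformer blocks."""

    def __init__(self, img_size=224, patch_size=16, dim=192, depth=2, heads=3):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        # Patch embedding: a strided conv applies one linear projection to each non-overlapping patch.
        self.patch_embed = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size)
        # Learnable CLS token; its final state serves as the global image representation.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        # Learned absolute position embeddings (a simplification for this sketch).
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, dim))
        # Pre-norm blocks: LayerNorm -> multi-head attention -> LayerNorm -> feed-forward network.
        block = nn.TransformerEncoderLayer(dim, heads, dim_feedforward=4 * dim, activation="gelu",
                                           batch_first=True, norm_first=True)
        self.blocks = nn.TransformerEncoder(block, num_layers=depth, enable_nested_tensor=False)
        self.norm = nn.LayerNorm(dim)

    def forward(self, images):
        x = self.patch_embed(images)               # [B, dim, H/16, W/16]
        x = x.flatten(2).transpose(1, 2)           # [B, num_patches, dim]
        cls = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls, x], dim=1) + self.pos_embed
        return self.norm(self.blocks(x))           # [B, 1 + num_patches, dim]


model = TinyViT()
tokens = model(torch.randn(2, 3, 224, 224))
print(tokens.shape)  # torch.Size([2, 197, 192])
cls_embedding, patch_tokens = tokens[:, 0], tokens[:, 1:]
```

<p>With a 224×224 input and 16×16 patches, the image becomes 14 × 14 = 196 patch tokens plus one CLS token, hence 197 output tokens.</p>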
</section>
</section>
<section id="training-methodology" class="level2">
<h2 class="anchored" data-anchor-id="training-methodology" id="training-methodology">Training Methodology</h2>
<section id="dataset-curation" class="level3">
<h3 class="anchored" data-anchor-id="dataset-curation" id="dataset-curation">Dataset Curation</h3>
<ul>
<li><strong>Scale</strong>: 1.7 billion images from diverse sources</li>
<li><strong>Quality Control</strong>: Automated filtering and deduplication</li>
<li><strong>Diversity</strong>: Natural images, web content, satellite imagery</li>
<li><strong>Resolution</strong>: High-resolution training for detailed features</li>
</ul>
</section>
<section id="training-process" class="level3">
<h3 class="anchored" data-anchor-id="training-process" id="training-process">Training Process</h3>
<ol type="1">
<li><strong>Data Augmentation</strong>: Multiple views of each image through crops, color jittering, and geometric transforms</li>
<li><strong>Teacher-Student Learning</strong>: Student network learns to match teacher predictions</li>
<li><strong>Multi-Crop Strategy</strong>: Uses global and local crops for comprehensive understanding</li>
<li><strong>Loss Function</strong>: Cross-entropy between student and teacher outputs</li>
<li><strong>Optimization</strong>: AdamW optimizer with cosine learning rate schedule</li>
</ol>
</section>
<section id="training-infrastructure" class="level3">
<h3 class="anchored" data-anchor-id="training-infrastructure" id="training-infrastructure">Training Infrastructure</h3>
<ul>
<li>Distributed training across multiple GPUs</li>
<li>Gradient accumulation for effective large batch training</li>
<li>Mixed precision for memory efficiency</li>
<li>Checkpoint saving and resumption capabilities</li>
</ul>
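<p>Gradient accumulation and mixed precision combine naturally in one training loop. A minimal single-process sketch, assuming bfloat16 autocast (which, unlike float16, needs no gradient scaler) and a loader that yields <code>(inputs, targets)</code> pairs. <code>train_with_accumulation</code> is a hypothetical helper name; distributed training and checkpointing are left out.</p>

```python
import torch
import torch.nn as nn


def train_with_accumulation(model, loader, optimizer, criterion, accum_steps=4, device="cpu"):
    """One pass over `loader`, stepping the optimizer every `accum_steps` micro-batches.

    Effective batch size = loader batch size * accum_steps.
    """
    model.train()
    device_type = "cuda" if str(device).startswith("cuda") else "cpu"
    optimizer.zero_grad(set_to_none=True)
    losses, step = [], 0
    for step, (inputs, targets) in enumerate(loader, start=1):
        inputs, targets = inputs.to(device), targets.to(device)
        # Forward pass in bfloat16 where autocast deems it safe; compute the loss in float32.
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            logits = model(inputs)
        loss = criterion(logits.float(), targets)
        # Scale so the accumulated gradient is an average over the micro-batches.
        (loss / accum_steps).backward()
        losses.append(loss.item())
        if step % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
    if step % accum_steps != 0:
        # Flush a final, partial accumulation window.
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
    return losses


# Usage with a stand-in model and random data.
torch.manual_seed(0)
model = nn.Linear(16, 3)
batches = [(torch.randn(4, 16), torch.randint(0, 3, (4,))) for _ in range(6)]
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
losses = train_with_accumulation(model, batches, optimizer, nn.CrossEntropyLoss())
print(len(losses))  # 6
```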
</section>
</section>
<section id="model-variants-and-specifications" class="level2">
<h2 class="anchored" data-anchor-id="model-variants-and-specifications" id="model-variants-and-specifications">Model Variants and Specifications</h2>
<section id="available-models" class="level3">
<h3 class="anchored" data-anchor-id="available-models" id="available-models">Available Models</h3>
<div id="tbl-model-variants" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-tbl figure">
<figcaption class="quarto-float-caption-top quarto-float-caption quarto-float-tbl" id="tbl-model-variants-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Table&nbsp;1: Model variants and their specifications
</figcaption>
<div aria-describedby="tbl-model-variants-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<table class="caption-top table">
<colgroup>
<col style="width: 12%">
<col style="width: 18%">
<col style="width: 20%">
<col style="width: 31%">
<col style="width: 17%">
</colgroup>
<thead>
<tr class="header">
<th>Model</th>
<th>Parameters</th>
<th>Patch Size</th>
<th>Input Resolution</th>
<th>Use Case</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>DINOv3-ViT-S/16</td>
<td>22M</td>
<td>16×16</td>
<td>224×224+</td>
<td>Lightweight applications</td>
</tr>
<tr class="even">
<td>DINOv3-ViT-B/16</td>
<td>86M</td>
<td>16×16</td>
<td>224×224+</td>
<td>Balanced performance</td>
</tr>
<tr class="odd">
<td>DINOv3-ViT-L/16</td>
<td>307M</td>
<td>16×16</td>
<td>224×224+</td>
<td>High performance</td>
</tr>
<tr class="even">
<td>DINOv3-ViT-H+/16</td>
<td>840M</td>
<td>16×16</td>
<td>224×224+</td>
<td>Very high capability</td>
</tr>
<tr class="odd">
<td>DINOv3-ViT-7B/16</td>
<td>6.7B</td>
<td>16×16</td>
<td>224×224+</td>
<td>Research and high-end applications</td>
</tr>
</tbody>
</table>
</div>
</figure>
</div>
</section>
<section id="model-selection-guidelines" class="level3">
<h3 class="anchored" data-anchor-id="model-selection-guidelines" id="model-selection-guidelines">Model Selection Guidelines</h3>
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Important</span>Choosing the Right Model
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li><strong>Small (S)</strong>: Mobile and edge applications, real-time inference</li>
<li><strong>Base (B)</strong>: General purpose, good balance of speed and accuracy</li>
<li><strong>Large (L)</strong>: High-accuracy applications, research</li>
<li><strong>Huge (H+) and 7B</strong>: Maximum performance, resource-rich environments</li>
</ul>
</div>
</div>
</section>
</section>
<section id="installation-and-setup" class="level2">
<h2 class="anchored" data-anchor-id="installation-and-setup" id="installation-and-setup">Installation and Setup</h2>
<section id="prerequisites" class="level3">
<h3 class="anchored" data-anchor-id="prerequisites" id="prerequisites">Prerequisites</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Python 3.8+</span></span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="co"># PyTorch 1.12+</span></span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a><span class="co"># CUDA (for GPU acceleration)</span></span></code></pre></div></div>
</section>
<section id="installation-options" class="level3">
<h3 class="anchored" data-anchor-id="installation-options" id="installation-options">Installation Options</h3>
<section id="option-1-using-hugging-face-transformers" class="level4">
<h4 class="anchored" data-anchor-id="option-1-using-hugging-face-transformers">Option 1: Using Hugging Face Transformers</h4>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb2"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install transformers torch torchvision</span></code></pre></div></div>
</section>
<section id="option-2-from-source" class="level4">
<h4 class="anchored" data-anchor-id="option-2-from-source">Option 2: From Source</h4>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb3"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="fu">git</span> clone https://github.com/facebookresearch/dinov3.git</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a><span class="bu">cd</span> dinov3</span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install <span class="at">-e</span> .</span></code></pre></div></div>
</section>
<section id="option-3-using-pre-built-containers" class="level4">
<h4 class="anchored" data-anchor-id="option-3-using-pre-built-containers">Option 3: Using Pre-built Containers</h4>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb4"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="ex">docker</span> pull pytorch/pytorch:latest</span>
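<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a><span class="ex">docker</span> run <span class="at">--gpus</span> all <span class="at">-it</span> <span class="at">--rm</span> pytorch/pytorch:latest bash</span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Inside the container, install DINOv3 as in Option 1 or Option 2</span></span></code></pre></div></div>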
</section>
</section>
<section id="environment-setup" class="level3">
<h3 class="anchored" data-anchor-id="environment-setup" id="environment-setup">Environment Setup</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb5"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Create conda environment</span></span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a><span class="ex">conda</span> create <span class="at">-n</span> dinov3 python=3.9</span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a><span class="ex">conda</span> activate dinov3</span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Install dependencies</span></span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install torch torchvision torchaudio</span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install transformers pillow numpy matplotlib</span></code></pre></div></div>
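<p>After installation, a quick check shows which package versions ended up in the environment. This sketch uses only the standard library's <code>importlib.metadata</code>. The minimum versions in <code>MIN_VERSIONS</code> are assumptions (in particular, that DINOv3 support arrived in <code>transformers</code> 4.56), so adjust them to the release notes you are following.</p>

```python
import re
from importlib import metadata

# Assumed minimum (major, minor) versions; adjust to the release notes you follow.
MIN_VERSIONS = {"torch": (2, 0), "transformers": (4, 56)}


def parse_major_minor(version):
    """'4.56.1' -> (4, 56); tolerates suffixes such as '2.4.0+cu121' or '4.57.0.dev0'."""
    numbers = []
    for part in version.split(".")[:2]:
        match = re.match(r"\d+", part)
        numbers.append(int(match.group()) if match else 0)
    numbers += [0] * (2 - len(numbers))
    return tuple(numbers)


def check_environment(requirements=MIN_VERSIONS):
    """Map each package to (installed version or None, meets the minimum?)."""
    report = {}
    for package, minimum in requirements.items():
        try:
            installed = metadata.version(package)
        except metadata.PackageNotFoundError:
            report[package] = (None, False)
            continue
        report[package] = (installed, parse_major_minor(installed) >= minimum)
    return report


for name, (version, ok) in check_environment().items():
    status = "OK" if ok else "install or upgrade"
    print(name.ljust(14), (version or "not installed").ljust(16), status)
```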
</section>
</section>
<section id="usage-examples" class="level2">
<h2 class="anchored" data-anchor-id="usage-examples" id="usage-examples">Usage Examples</h2>
<section id="basic-feature-extraction" class="level3">
<h3 class="anchored" data-anchor-id="basic-feature-extraction" id="basic-feature-extraction">Basic Feature Extraction</h3>
<div id="basic-feature-extraction" class="cell" data-execution_count="1">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> transformers <span class="im">import</span> AutoImageProcessor, AutoModel  <span class="co"># needs a transformers release with DINOv3 support</span></span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> PIL <span class="im">import</span> Image</span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Load model and processor</span></span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a>processor <span class="op">=</span> AutoImageProcessor.from_pretrained(</span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a>    <span class="st">'facebook/dinov3-vits16-pretrain-lvd1689m'</span></span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> AutoModel.from_pretrained(</span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a>    <span class="st">'facebook/dinov3-vits16-pretrain-lvd1689m'</span></span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a><span class="co"># Load and process image</span></span>
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a>image <span class="op">=</span> Image.<span class="bu">open</span>(<span class="st">'path/to/your/image.jpg'</span>).convert(<span class="st">'RGB'</span>)  <span class="co"># ensure 3 channels</span></span>
<span id="cb6-15"><a href="#cb6-15" aria-hidden="true" tabindex="-1"></a>inputs <span class="op">=</span> processor(image, return_tensors<span class="op">=</span><span class="st">"pt"</span>)</span>
<span id="cb6-16"><a href="#cb6-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-17"><a href="#cb6-17" aria-hidden="true" tabindex="-1"></a><span class="co"># Extract features</span></span>
<span id="cb6-18"><a href="#cb6-18" aria-hidden="true" tabindex="-1"></a><span class="cf">with</span> torch.no_grad():</span>
<span id="cb6-19"><a href="#cb6-19" aria-hidden="true" tabindex="-1"></a>    outputs <span class="op">=</span> model(<span class="op">**</span>inputs)</span>
<span id="cb6-20"><a href="#cb6-20" aria-hidden="true" tabindex="-1"></a>    features <span class="op">=</span> outputs.last_hidden_state</span>
<span id="cb6-21"><a href="#cb6-21" aria-hidden="true" tabindex="-1"></a>    cls_token <span class="op">=</span> features[:, <span class="dv">0</span>]  <span class="co"># Global representation</span></span>
<span id="cb6-22"><a href="#cb6-22" aria-hidden="true" tabindex="-1"></a>    patch_features <span class="op">=</span> features[:, <span class="dv">1</span> <span class="op">+</span> model.config.num_register_tokens:]  <span class="co"># Patch-level features (skip CLS and register tokens)</span></span></code></pre></div></div>
</div>
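<p>The token layout matters when slicing <code>last_hidden_state</code>. The helper below separates the CLS token, the register tokens, and the patch tokens, then reshapes the patches into a 2-D feature map for dense tasks. It is a sketch assuming the layout CLS, then registers, then patches, a 16-pixel patch size, and 4 register tokens by default; <code>split_tokens</code> is a hypothetical helper, and a random tensor stands in for the model output so the example runs without downloading weights.</p>

```python
import torch
import torch.nn.functional as F


def split_tokens(last_hidden_state, image_height, image_width, patch_size=16, num_register_tokens=4):
    """Split ViT output tokens into (CLS, registers, patch feature map).

    last_hidden_state: [B, 1 + num_register_tokens + num_patches, D]
    Returns cls [B, D], registers [B, R, D], patch_map [B, D, H/patch, W/patch].
    """
    grid_h, grid_w = image_height // patch_size, image_width // patch_size
    cls = last_hidden_state[:, 0]
    registers = last_hidden_state[:, 1:1 + num_register_tokens]
    patches = last_hidden_state[:, 1 + num_register_tokens:]
    if patches.shape[1] != grid_h * grid_w:
        raise ValueError(f"expected {grid_h * grid_w} patch tokens, got {patches.shape[1]}")
    # [B, N, D] -> [B, D, N] -> [B, D, grid_h, grid_w], like a CNN feature map.
    patch_map = patches.transpose(1, 2).reshape(patches.shape[0], -1, grid_h, grid_w)
    return cls, registers, patch_map


# Stand-in for outputs.last_hidden_state: 2 images at 224x224, hidden size 384.
hidden = torch.randn(2, 1 + 4 + (224 // 16) ** 2, 384)
cls, registers, patch_map = split_tokens(hidden, 224, 224)
print(cls.shape, registers.shape, patch_map.shape)
# torch.Size([2, 384]) torch.Size([2, 4, 384]) torch.Size([2, 384, 14, 14])

# Cosine similarity between the two global (CLS) embeddings.
similarity = F.cosine_similarity(cls[0], cls[1], dim=0)
```

<p>With a real checkpoint, pass the actual image size and register count, for example <code>split_tokens(outputs.last_hidden_state, *inputs["pixel_values"].shape[-2:], num_register_tokens=model.config.num_register_tokens)</code>.</p>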
</section>
<section id="batch-processing" class="level3">
<h3 class="anchored" data-anchor-id="batch-processing" id="batch-processing">Batch Processing</h3>
<div id="batch-processing" class="cell" data-execution_count="2">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.utils.data <span class="im">import</span> DataLoader</span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torchvision <span class="im">import</span> transforms</span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> PIL <span class="im">import</span> Image</span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> os</span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Custom dataset class</span></span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> ImageDataset(torch.utils.data.Dataset):</span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, image_dir, transform<span class="op">=</span><span class="va">None</span>):</span>
<span id="cb7-10"><a href="#cb7-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.image_dir <span class="op">=</span> image_dir</span>
<span id="cb7-11"><a href="#cb7-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.image_files <span class="op">=</span> <span class="bu">sorted</span>(  <span class="co"># sorted for a deterministic order</span></span>
<span id="cb7-12"><a href="#cb7-12" aria-hidden="true" tabindex="-1"></a>            f <span class="cf">for</span> f <span class="kw">in</span> os.listdir(image_dir)</span>
<span id="cb7-13"><a href="#cb7-13" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> f.lower().endswith((<span class="st">'.jpg'</span>, <span class="st">'.jpeg'</span>, <span class="st">'.png'</span>))</span>
<span id="cb7-14"><a href="#cb7-14" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb7-15"><a href="#cb7-15" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.transform <span class="op">=</span> transform</span>
<span id="cb7-16"><a href="#cb7-16" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-17"><a href="#cb7-17" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__len__</span>(<span class="va">self</span>):</span>
<span id="cb7-18"><a href="#cb7-18" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="bu">len</span>(<span class="va">self</span>.image_files)</span>
<span id="cb7-19"><a href="#cb7-19" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-20"><a href="#cb7-20" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__getitem__</span>(<span class="va">self</span>, idx):</span>
<span id="cb7-21"><a href="#cb7-21" aria-hidden="true" tabindex="-1"></a>        image_path <span class="op">=</span> os.path.join(<span class="va">self</span>.image_dir, <span class="va">self</span>.image_files[idx])</span>
<span id="cb7-22"><a href="#cb7-22" aria-hidden="true" tabindex="-1"></a>        image <span class="op">=</span> Image.<span class="bu">open</span>(image_path).convert(<span class="st">'RGB'</span>)</span>
<span id="cb7-23"><a href="#cb7-23" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="va">self</span>.transform:</span>
<span id="cb7-24"><a href="#cb7-24" aria-hidden="true" tabindex="-1"></a>            image <span class="op">=</span> <span class="va">self</span>.transform(image)</span>
<span id="cb7-25"><a href="#cb7-25" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> image</span>
<span id="cb7-26"><a href="#cb7-26" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-27"><a href="#cb7-27" aria-hidden="true" tabindex="-1"></a><span class="co"># Setup data loading</span></span>
<span id="cb7-28"><a href="#cb7-28" aria-hidden="true" tabindex="-1"></a>transform <span class="op">=</span> transforms.Compose([</span>
<span id="cb7-29"><a href="#cb7-29" aria-hidden="true" tabindex="-1"></a>    transforms.Resize((<span class="dv">224</span>, <span class="dv">224</span>)),</span>
<span id="cb7-30"><a href="#cb7-30" aria-hidden="true" tabindex="-1"></a>    transforms.ToTensor(),</span>
<span id="cb7-31"><a href="#cb7-31" aria-hidden="true" tabindex="-1"></a>    transforms.Normalize(</span>
<span id="cb7-32"><a href="#cb7-32" aria-hidden="true" tabindex="-1"></a>        mean<span class="op">=</span>[<span class="fl">0.485</span>, <span class="fl">0.456</span>, <span class="fl">0.406</span>], </span>
<span id="cb7-33"><a href="#cb7-33" aria-hidden="true" tabindex="-1"></a>        std<span class="op">=</span>[<span class="fl">0.229</span>, <span class="fl">0.224</span>, <span class="fl">0.225</span>]</span>
<span id="cb7-34"><a href="#cb7-34" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb7-35"><a href="#cb7-35" aria-hidden="true" tabindex="-1"></a>])</span>
<span id="cb7-36"><a href="#cb7-36" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-37"><a href="#cb7-37" aria-hidden="true" tabindex="-1"></a>dataset <span class="op">=</span> ImageDataset(<span class="st">'path/to/images'</span>, transform<span class="op">=</span>transform)</span>
<span id="cb7-38"><a href="#cb7-38" aria-hidden="true" tabindex="-1"></a>dataloader <span class="op">=</span> DataLoader(dataset, batch_size<span class="op">=</span><span class="dv">32</span>, shuffle<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb7-39"><a href="#cb7-39" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-40"><a href="#cb7-40" aria-hidden="true" tabindex="-1"></a><span class="co"># Process batches</span></span>
<span id="cb7-41"><a href="#cb7-41" aria-hidden="true" tabindex="-1"></a>model.<span class="bu">eval</span>()  <span class="co"># `model` is the DINOv3 backbone loaded in the previous example</span></span>
<span id="cb7-42"><a href="#cb7-42" aria-hidden="true" tabindex="-1"></a>all_features <span class="op">=</span> []</span>
<span id="cb7-43"><a href="#cb7-43" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-44"><a href="#cb7-44" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> batch <span class="kw">in</span> dataloader:</span>
<span id="cb7-45"><a href="#cb7-45" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> torch.no_grad():</span>
<span id="cb7-46"><a href="#cb7-46" aria-hidden="true" tabindex="-1"></a>        outputs <span class="op">=</span> model(pixel_values<span class="op">=</span>batch)</span>
<span id="cb7-47"><a href="#cb7-47" aria-hidden="true" tabindex="-1"></a>        features <span class="op">=</span> outputs.last_hidden_state[:, <span class="dv">0</span>]  <span class="co"># CLS tokens</span></span>
<span id="cb7-48"><a href="#cb7-48" aria-hidden="true" tabindex="-1"></a>        all_features.append(features)</span>
<span id="cb7-49"><a href="#cb7-49" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-50"><a href="#cb7-50" aria-hidden="true" tabindex="-1"></a>all_features <span class="op">=</span> torch.cat(all_features, dim<span class="op">=</span><span class="dv">0</span>)</span></code></pre></div></div>
</div>
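<p>Once <code>all_features</code> holds one CLS embedding per image, a common next step is nearest-neighbour search by cosine similarity, for example to find similar images or near-duplicates. A minimal sketch: <code>top_k_neighbours</code> is a hypothetical helper, and a random tensor stands in for <code>all_features</code> so the example runs on its own.</p>

```python
import torch
import torch.nn.functional as F


def top_k_neighbours(features, queries, k=5):
    """Return (scores, indices) of the k most cosine-similar rows of `features` per query."""
    features = F.normalize(features, dim=-1)   # unit length: dot product == cosine similarity
    queries = F.normalize(queries, dim=-1)
    scores = queries @ features.T              # [num_queries, num_images]
    return scores.topk(k, dim=-1)


torch.manual_seed(0)
all_features = torch.randn(100, 384)          # stand-in for the CLS features computed above
scores, indices = top_k_neighbours(all_features, all_features[:3], k=5)
print(indices[:, 0])  # each query's best match is itself: tensor([0, 1, 2])
```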
</section>
<section id="fine-tuning-for-classification" class="level3">
<h3 class="anchored" data-anchor-id="fine-tuning-for-classification" id="fine-tuning-for-classification">Fine-tuning for Classification</h3>
<div id="classification-finetuning" class="cell" data-execution_count="3">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb8"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> transformers <span class="im">import</span> AutoModel</span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> DINOv3Classifier(nn.Module):</span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, num_classes<span class="op">=</span><span class="dv">1000</span>, </span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a>                 pretrained_model_name<span class="op">=</span><span class="st">'facebook/dinov3-vits16-pretrain-lvd1689m'</span>):</span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.backbone <span class="op">=</span> AutoModel.from_pretrained(pretrained_model_name)</span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.classifier <span class="op">=</span> nn.Linear(</span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.backbone.config.hidden_size, </span>
<span id="cb8-12"><a href="#cb8-12" aria-hidden="true" tabindex="-1"></a>            num_classes</span>
<span id="cb8-13"><a href="#cb8-13" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb8-14"><a href="#cb8-14" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-15"><a href="#cb8-15" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, pixel_values):</span>
<span id="cb8-16"><a href="#cb8-16" aria-hidden="true" tabindex="-1"></a>        outputs <span class="op">=</span> <span class="va">self</span>.backbone(pixel_values<span class="op">=</span>pixel_values)</span>
<span id="cb8-17"><a href="#cb8-17" aria-hidden="true" tabindex="-1"></a>        cls_token <span class="op">=</span> outputs.last_hidden_state[:, <span class="dv">0</span>]</span>
<span id="cb8-18"><a href="#cb8-18" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.classifier(cls_token)</span>
<span id="cb8-19"><a href="#cb8-19" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-20"><a href="#cb8-20" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage</span></span>
<span id="cb8-21"><a href="#cb8-21" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> DINOv3Classifier(num_classes<span class="op">=</span><span class="dv">10</span>)</span>
<span id="cb8-22"><a href="#cb8-22" aria-hidden="true" tabindex="-1"></a>optimizer <span class="op">=</span> torch.optim.AdamW(model.parameters(), lr<span class="op">=</span><span class="fl">1e-4</span>)</span>
<span id="cb8-23"><a href="#cb8-23" aria-hidden="true" tabindex="-1"></a>criterion <span class="op">=</span> nn.CrossEntropyLoss()</span>
<span id="cb8-24"><a href="#cb8-24" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-25"><a href="#cb8-25" aria-hidden="true" tabindex="-1"></a><span class="co"># Training loop would go here</span></span></code></pre></div></div>
</div>
</section>
<section id="semantic-segmentation-setup" class="level3">
<h3 class="anchored" data-anchor-id="semantic-segmentation-setup" id="semantic-segmentation-setup">Semantic Segmentation Setup</h3>
<div id="segmentation-setup" class="cell" data-execution_count="4">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb9"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> transformers <span class="im">import</span> DINOv3Model</span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> DINOv3Segmentation(nn.Module):</span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, num_classes, </span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a>                 pretrained_model_name<span class="op">=</span><span class="st">'facebook/dinov3-vits16-pretrain-lvd1689m'</span>):</span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.backbone <span class="op">=</span> DINOv3Model.from_pretrained(pretrained_model_name)</span>
<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.decode_head <span class="op">=</span> nn.Sequential(</span>
<span id="cb9-11"><a href="#cb9-11" aria-hidden="true" tabindex="-1"></a>            nn.Conv2d(<span class="va">self</span>.backbone.config.hidden_size, <span class="dv">256</span>, <span class="dv">3</span>, padding<span class="op">=</span><span class="dv">1</span>),</span>
<span id="cb9-12"><a href="#cb9-12" aria-hidden="true" tabindex="-1"></a>            nn.BatchNorm2d(<span class="dv">256</span>),</span>
<span id="cb9-13"><a href="#cb9-13" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(inplace<span class="op">=</span><span class="va">True</span>),</span>
<span id="cb9-14"><a href="#cb9-14" aria-hidden="true" tabindex="-1"></a>            nn.Conv2d(<span class="dv">256</span>, num_classes, <span class="dv">1</span>)</span>
<span id="cb9-15"><a href="#cb9-15" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb9-16"><a href="#cb9-16" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-17"><a href="#cb9-17" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, pixel_values):</span>
<span id="cb9-18"><a href="#cb9-18" aria-hidden="true" tabindex="-1"></a>        B, C, H, W <span class="op">=</span> pixel_values.shape</span>
<span id="cb9-19"><a href="#cb9-19" aria-hidden="true" tabindex="-1"></a>        outputs <span class="op">=</span> <span class="va">self</span>.backbone(pixel_values<span class="op">=</span>pixel_values)</span>
<span id="cb9-20"><a href="#cb9-20" aria-hidden="true" tabindex="-1"></a>        patch_features <span class="op">=</span> outputs.last_hidden_state[:, <span class="dv">1</span> <span class="op">+</span> <span class="va">self</span>.backbone.config.num_register_tokens:]  <span class="co"># Drop CLS and register tokens</span></span>
<span id="cb9-21"><a href="#cb9-21" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-22"><a href="#cb9-22" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Reshape to spatial dimensions</span></span>
<span id="cb9-23"><a href="#cb9-23" aria-hidden="true" tabindex="-1"></a>        patch_size <span class="op">=</span> <span class="va">self</span>.backbone.config.patch_size</span>
<span id="cb9-24"><a href="#cb9-24" aria-hidden="true" tabindex="-1"></a>        h_patches, w_patches <span class="op">=</span> H <span class="op">//</span> patch_size, W <span class="op">//</span> patch_size</span>
<span id="cb9-25"><a href="#cb9-25" aria-hidden="true" tabindex="-1"></a>        features <span class="op">=</span> patch_features.reshape(B, h_patches, w_patches, <span class="op">-</span><span class="dv">1</span>)</span>
<span id="cb9-26"><a href="#cb9-26" aria-hidden="true" tabindex="-1"></a>        features <span class="op">=</span> features.permute(<span class="dv">0</span>, <span class="dv">3</span>, <span class="dv">1</span>, <span class="dv">2</span>)  <span class="co"># (B, hidden_size, h_patches, w_patches)</span></span>
<span id="cb9-27"><a href="#cb9-27" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-28"><a href="#cb9-28" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Classify at patch resolution, then upsample the logits to the input size</span></span>
<span id="cb9-29"><a href="#cb9-29" aria-hidden="true" tabindex="-1"></a>        logits <span class="op">=</span> <span class="va">self</span>.decode_head(features)</span>
<span id="cb9-30"><a href="#cb9-30" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> nn.functional.interpolate(</span>
<span id="cb9-31"><a href="#cb9-31" aria-hidden="true" tabindex="-1"></a>            logits, size<span class="op">=</span>(H, W), mode<span class="op">=</span><span class="st">'bilinear'</span>, align_corners<span class="op">=</span><span class="va">False</span></span>
<span id="cb9-32"><a href="#cb9-32" aria-hidden="true" tabindex="-1"></a>        )</span></code></pre></div></div>
</div>
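<p>Reshaping the token sequence into a grid requires knowing where the patch tokens start. The arithmetic can be sketched in plain Python, assuming the layout described in the Hugging Face DINOv3 documentation: one CLS token, then <code>num_register_tokens</code> register tokens (assumed to be four here; check <code>config.num_register_tokens</code> for your checkpoint), then the patch tokens in row-major order. The helpers <code>token_layout</code> and <code>patch_index</code> are hypothetical names used only for illustration.</p>

```python
def token_layout(height, width, patch_size=16, num_register_tokens=4):
    """Return (h_patches, w_patches, patch_start, total_tokens) for a ViT input.

    Assumes the sequence is [CLS] + register tokens + patch tokens, with the
    patch tokens laid out row by row (row-major) over the patch grid.
    """
    if height % patch_size or width % patch_size:
        raise ValueError("height and width must be multiples of patch_size")
    h_patches = height // patch_size
    w_patches = width // patch_size
    patch_start = 1 + num_register_tokens  # index of the first patch token
    total_tokens = patch_start + h_patches * w_patches
    return h_patches, w_patches, patch_start, total_tokens


def patch_index(row, col, w_patches, patch_start):
    """Sequence index of the patch token at grid position (row, col)."""
    return patch_start + row * w_patches + col


print(token_layout(224, 224))        # (14, 14, 5, 201)
print(patch_index(13, 13, 14, 5))    # 200, the last token in the sequence
```

<p>For a 224 × 224 input, slicing off only the CLS token would leave 200 tokens for a 14 × 14 grid of 196 positions, so a reshape to <code>(B, h_patches, w_patches, -1)</code> would not line up with the spatial layout.</p>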
</section>
</section>
<section id="applications-and-use-cases" class="level2">
<h2 class="anchored" data-anchor-id="applications-and-use-cases" id="applications-and-use-cases">Applications and Use Cases</h2>
<section id="computer-vision-tasks" class="level3">
<h3 class="anchored" data-anchor-id="computer-vision-tasks" id="computer-vision-tasks">Computer Vision Tasks</h3>
<div class="tabset-margin-container"></div><div class="panel-tabset">
<ul class="nav nav-tabs" role="tablist"><li class="nav-item" role="presentation"><a class="nav-link active" id="tabset-2-1-tab" data-bs-toggle="tab" data-bs-target="#tabset-2-1" role="tab" aria-controls="tabset-2-1" aria-selected="true" href="">Object Detection</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-2-2-tab" data-bs-toggle="tab" data-bs-target="#tabset-2-2" role="tab" aria-controls="tabset-2-2" aria-selected="false" href="">Semantic Segmentation</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-2-3-tab" data-bs-toggle="tab" data-bs-target="#tabset-2-3" role="tab" aria-controls="tabset-2-3" aria-selected="false" href="">Instance Segmentation</a></li></ul>
<div class="tab-content">
<div id="tabset-2-1" class="tab-pane active" role="tabpanel" aria-labelledby="tabset-2-1-tab">
<ul>
<li>Use DINOv3 features with detection heads (DETR, Faster R-CNN)</li>
<li>Strong performance with a frozen backbone; only the detection head is trained</li>
<li>Works across diverse object categories</li>
</ul>
</div>
<div id="tabset-2-2" class="tab-pane" role="tabpanel" aria-labelledby="tabset-2-2-tab">
<ul>
<li>Dense pixel-level predictions</li>
<li>High-quality boundary detection</li>
<li>Effective for medical imaging, aerial imagery</li>
</ul>
</div>
<div id="tabset-2-3" class="tab-pane" role="tabpanel" aria-labelledby="tabset-2-3-tab">
<ul>
<li>Combines detection and segmentation</li>
<li>Useful for counting and analysis applications</li>
<li>Good generalization to new domains</li>
</ul>
</div>
</div>
</div>
</section>
<section id="content-understanding" class="level3">
<h3 class="anchored" data-anchor-id="content-understanding" id="content-understanding">Content Understanding</h3>
<p><strong>Image Retrieval</strong></p>
<ul>
<li>Use CLS token as global image descriptor</li>
<li>Efficient similarity search in large databases</li>
<li>Cross-domain retrieval capabilities</li>
</ul>
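<p>A minimal sketch of retrieval with global descriptors, assuming each database image has already been encoded into a CLS-token vector. The toy 3-dimensional vectors and the <code>retrieve</code> helper are hypothetical; real DINOv3 descriptors have hundreds or thousands of dimensions, and large collections would normally go into an approximate nearest-neighbour index rather than a linear scan.</p>

```python
import math


def cosine(a, b):
    """Cosine similarity between two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def retrieve(query, database, k=2):
    """Return the keys of the k database descriptors most similar to the query."""
    ranked = sorted(database, key=lambda key: cosine(query, database[key]), reverse=True)
    return ranked[:k]


# Hypothetical toy "CLS descriptors" keyed by image id
database = {
    "cat_1": [0.9, 0.1, 0.0],
    "cat_2": [0.8, 0.2, 0.1],
    "car_1": [0.0, 0.1, 0.9],
}
print(retrieve([1.0, 0.1, 0.0], database))  # ['cat_1', 'cat_2']
```

<p>Cosine similarity ignores vector length, which is why descriptors are often L2-normalized once at indexing time so that a plain dot product gives the same ranking.</p>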
<p><strong>Content Moderation</strong></p>
<ul>
<li>Detect inappropriate or harmful content</li>
<li>Classify image types and categories</li>
<li>Identify policy violations</li>
</ul>
<p><strong>Quality Assessment</strong></p>
<ul>
<li>Assess image quality and aesthetics</li>
<li>Detect blurriness, artifacts, or corruption</li>
<li>Content filtering and ranking</li>
</ul>
</section>
<section id="scientific-applications" class="level3">
<h3 class="anchored" data-anchor-id="scientific-applications" id="scientific-applications">Scientific Applications</h3>
<p><strong>Medical Imaging</strong></p>
<ul>
<li>Pathology analysis</li>
<li>Radiology image understanding</li>
<li>Drug discovery applications</li>
</ul>
<p><strong>Satellite Imagery</strong></p>
<ul>
<li>Land use classification</li>
<li>Environmental monitoring</li>
<li>Urban planning and development</li>
</ul>
<p><strong>Biological Research</strong></p>
<ul>
<li>Cell counting and classification</li>
<li>Microscopy image analysis</li>
<li>Species identification</li>
</ul>
</section>
<section id="creative-and-media-applications" class="level3">
<h3 class="anchored" data-anchor-id="creative-and-media-applications" id="creative-and-media-applications">Creative and Media Applications</h3>
<p><strong>Art and Design</strong></p>
<ul>
<li>Style transfer and generation</li>
<li>Content-aware editing</li>
<li>Creative asset organization</li>
</ul>
<p><strong>Video Analysis</strong></p>
<ul>
<li>Frame-level understanding</li>
<li>Action recognition</li>
<li>Video summarization</li>
</ul>
</section>
</section>
<section id="performance-and-benchmarks" class="level2">
<h2 class="anchored" data-anchor-id="performance-and-benchmarks" id="performance-and-benchmarks">Performance and Benchmarks</h2>
<section id="imagenet-classification" class="level3">
<h3 class="anchored" data-anchor-id="imagenet-classification" id="imagenet-classification">ImageNet Classification</h3>
<ul>
<li><strong>Linear Probing</strong>: 84.5% top-1 accuracy (ViT-G)</li>
<li><strong>k-NN Classification</strong>: 82.1% top-1 accuracy</li>
<li><strong>Few-shot Learning</strong>: Superior performance with limited data</li>
</ul>
</section>
<section id="dense-prediction-tasks" class="level3">
<h3 class="anchored" data-anchor-id="dense-prediction-tasks" id="dense-prediction-tasks">Dense Prediction Tasks</h3>
<ul>
<li><strong>ADE20K Segmentation</strong>: 58.8 mIoU</li>
<li><strong>COCO Detection</strong>: 59.3 AP (Mask R-CNN)</li>
<li><strong>Video Segmentation</strong>: State-of-the-art on DAVIS</li>
</ul>
</section>
<section id="cross-domain-performance" class="level3">
<h3 class="anchored" data-anchor-id="cross-domain-performance" id="cross-domain-performance">Cross-Domain Performance</h3>
<ul>
<li><strong>Natural Images</strong>: Excellent baseline performance</li>
<li><strong>Aerial Imagery</strong>: 15-20% improvement over supervised baselines</li>
<li><strong>Medical Images</strong>: Strong transfer learning capabilities</li>
</ul>
</section>
<section id="computational-efficiency" class="level3">
<h3 class="anchored" data-anchor-id="computational-efficiency" id="computational-efficiency">Computational Efficiency</h3>
<ul>
<li><strong>Inference Speed</strong>: Competitive with supervised models</li>
<li><strong>Memory Usage</strong>: Efficient attention mechanisms</li>
<li><strong>Scalability</strong>: Handles higher input resolutions, although global self-attention cost grows quadratically with the number of patch tokens</li>
</ul>
</section>
</section>
<section id="advantages-and-limitations" class="level2">
<h2 class="anchored" data-anchor-id="advantages-and-limitations" id="advantages-and-limitations">Advantages and Limitations</h2>
<section id="advantages" class="level3">
<h3 class="anchored" data-anchor-id="advantages" id="advantages">Advantages</h3>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Key Strengths
</div>
</div>
<div class="callout-body-container callout-body">
<p><strong>Universal Applicability</strong></p>
<ul>
<li>Single model for multiple tasks</li>
<li>No fine-tuning required for many applications</li>
<li>Consistent performance across domains</li>
</ul>
<p><strong>High-Quality Features</strong></p>
<ul>
<li>Rich semantic representations</li>
<li>Fine-grained spatial information</li>
<li>Emergent properties like segmentation</li>
</ul>
<p><strong>Scalability</strong></p>
<ul>
<li>Effective use of large datasets</li>
<li>Scales well with model size</li>
<li>Efficient training methodology</li>
</ul>
<p><strong>Research Impact</strong></p>
<ul>
<li>Pushes boundaries of self-supervised learning</li>
<li>Demonstrates viability of foundation models in vision</li>
<li>Enables new research directions</li>
</ul>
</div>
</div>
</section>
<section id="limitations" class="level3">
<h3 class="anchored" data-anchor-id="limitations" id="limitations">Limitations</h3>
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Warning</span>Current Constraints
</div>
</div>
<div class="callout-body-container callout-body">
<p><strong>Computational Requirements</strong></p>
<ul>
<li>Large models require significant resources</li>
<li>High memory usage during training</li>
<li>GPU-intensive inference for large variants</li>
</ul>
<p><strong>Data Dependency</strong></p>
<ul>
<li>Performance depends on training data quality</li>
<li>May have biases from training dataset</li>
<li>Limited performance on very specialized domains</li>
</ul>
<p><strong>Interpretability</strong></p>
<ul>
<li>Complex attention mechanisms</li>
<li>Difficult to understand learned representations</li>
<li>Black-box nature of transformers</li>
</ul>
<p><strong>Task-Specific Limitations</strong></p>
<ul>
<li>May not match specialized models for specific tasks</li>
<li>Requires additional components for some applications</li>
<li>Not optimized for real-time mobile applications</li>
</ul>
</div>
</div>
</section>
</section>
<section id="future-directions" class="level2">
<h2 class="anchored" data-anchor-id="future-directions" id="future-directions">Future Directions</h2>
<section id="technical-improvements" class="level3">
<h3 class="anchored" data-anchor-id="technical-improvements" id="technical-improvements">Technical Improvements</h3>
<p><strong>Architecture Enhancements</strong></p>
<ul>
<li>More efficient attention mechanisms</li>
<li>Better handling of high-resolution images</li>
<li>Improved spatial reasoning capabilities</li>
</ul>
<p><strong>Training Methodology</strong></p>
<ul>
<li>Better data curation strategies</li>
<li>More efficient self-supervised objectives</li>
<li>Multi-modal learning integration</li>
</ul>
<p><strong>Scalability</strong></p>
<ul>
<li>Even larger models and datasets</li>
<li>Better distributed training techniques</li>
<li>More efficient inference methods</li>
</ul>
</section>
<section id="application-areas" class="level3">
<h3 class="anchored" data-anchor-id="application-areas" id="application-areas">Application Areas</h3>
<p><strong>Multimodal Learning</strong></p>
<ul>
<li>Integration with language models</li>
<li>Vision-language understanding</li>
<li>Cross-modal retrieval and generation</li>
</ul>
<p><strong>Real-time Applications</strong></p>
<ul>
<li>Mobile and edge deployment</li>
<li>Real-time video processing</li>
<li>Interactive applications</li>
</ul>
<p><strong>Specialized Domains</strong></p>
<ul>
<li>Domain-specific fine-tuning strategies</li>
<li>Better handling of specialized imagery</li>
<li>Integration with domain knowledge</li>
</ul>
</section>
<section id="research-opportunities" class="level3">
<h3 class="anchored" data-anchor-id="research-opportunities" id="research-opportunities">Research Opportunities</h3>
<p><strong>Foundation Models</strong></p>
<ul>
<li>Vision-centric foundation models</li>
<li>Integration with other modalities</li>
<li>Unified multimodal architectures</li>
</ul>
<p><strong>Self-Supervised Learning</strong></p>
<ul>
<li>New pretext tasks and objectives</li>
<li>Better theoretical understanding</li>
<li>More efficient training methods</li>
</ul>
<p><strong>Transfer Learning</strong></p>
<ul>
<li>Better understanding of transferability</li>
<li>Improved few-shot learning</li>
<li>Domain adaptation techniques</li>
</ul>
</section>
</section>
<section id="sec-resources" class="level2">
<h2 class="anchored" data-anchor-id="sec-resources" id="sec-resources">Resources and References</h2>
<section id="official-resources" class="level3">
<h3 class="anchored" data-anchor-id="official-resources" id="official-resources">Official Resources</h3>
<ul>
<li><strong>GitHub Repository</strong>: <a href="https://github.com/facebookresearch/dinov3">facebookresearch/dinov3</a></li>
<li><strong>Hugging Face Models</strong>: <a href="https://huggingface.co/facebook">facebook/dinov3-*</a></li>
<li><strong>Meta AI Blog</strong>: Technical blog posts and announcements</li>
<li><strong>ArXiv Papers</strong>: Latest research publications</li>
</ul>
</section>
<section id="documentation-and-tutorials" class="level3">
<h3 class="anchored" data-anchor-id="documentation-and-tutorials" id="documentation-and-tutorials">Documentation and Tutorials</h3>
<ul>
<li><strong>Hugging Face Documentation</strong>: Comprehensive usage guides</li>
<li><strong>PyTorch Tutorials</strong>: Integration with PyTorch ecosystem</li>
<li><strong>Community Tutorials</strong>: Third-party guides and examples</li>
</ul>
</section>
<section id="related-work" class="level3">
<h3 class="anchored" data-anchor-id="related-work" id="related-work">Related Work</h3>
<ul>
<li><strong>DINO</strong>: Original self-distillation paper</li>
<li><strong>DINOv2</strong>: Direct predecessor of DINOv3</li>
<li><strong>Vision Transformers</strong>: Foundation architecture</li>
<li><strong>Self-Supervised Learning</strong>: Broader field context</li>
</ul>
</section>
<section id="community-and-support" class="level3">
<h3 class="anchored" data-anchor-id="community-and-support" id="community-and-support">Community and Support</h3>
<ul>
<li><strong>GitHub Issues</strong>: Bug reports and feature requests</li>
<li><strong>Research Community</strong>: Academic discussions and collaborations</li>
<li><strong>Industry Applications</strong>: Real-world deployment examples</li>
</ul>
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>DINOv3 represents a significant milestone in computer vision, demonstrating that self-supervised learning can produce universal visual features that rival or exceed specialized supervised models. Its ability to work across diverse domains without fine-tuning opens up new possibilities for practical applications and research directions.</p>
<p>The model’s success lies in its careful scaling of both data and model size, combined with effective self-supervised training techniques. As the field continues to evolve, DINOv3 provides a strong foundation for future developments in foundation models for computer vision.</p>
<p>Whether you’re a researcher exploring new frontiers in self-supervised learning or a practitioner looking to deploy state-of-the-art vision capabilities, DINOv3 offers a powerful and flexible solution that can adapt to a wide range of visual understanding tasks.</p>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Looking Forward
</div>
</div>
<div class="callout-body-container callout-body">
<p>The success of DINOv3 paves the way for even more powerful and universal vision models, potentially leading to truly general-purpose computer vision systems that can understand and analyze visual content across any domain.</p>
</div>
</div>


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Complete Guide to Reinforcement Learning]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/reinforced-learning/rl-basics/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/reinforced-learning/rl-basics/</guid>
      <pubDate>Fri, 22 Aug 2025 00:00:00 GMT</pubDate>
      
      <category>research</category>
      <category>intermediate</category>
      <content:encoded><![CDATA[






<section id="complete-guide-to-reinforcement-learning" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/reinforced-learning/rl-basics/rl.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>Reinforcement Learning (RL) is a paradigm of machine learning where an agent learns to make decisions by interacting with an environment to maximize cumulative rewards. Unlike supervised learning, where the correct answers are provided, or unsupervised learning, where patterns are discovered in data, reinforcement learning involves learning through trial and error based on feedback from the environment.</p>
<p>The inspiration for RL comes from behavioral psychology and how animals learn through rewards and punishments. This approach has proven remarkably effective for complex decision-making problems where the optimal strategy isn’t immediately apparent.</p>
</section>
<section id="core-concepts" class="level2">
<h2 class="anchored" data-anchor-id="core-concepts" id="core-concepts">Core Concepts</h2>
<section id="agent-and-environment" class="level3">
<h3 class="anchored" data-anchor-id="agent-and-environment" id="agent-and-environment">Agent and Environment</h3>
<p>The fundamental setup of RL involves two main components:</p>
<p><strong>Agent</strong>: The learner or decision-maker that takes actions in the environment. The agent’s goal is to learn a policy that maximizes expected cumulative reward.</p>
<p><strong>Environment</strong>: Everything the agent interacts with. It receives actions from the agent and returns observations (states) and rewards.</p>
</section>
<section id="key-elements" class="level3">
<h3 class="anchored" data-anchor-id="key-elements" id="key-elements">Key Elements</h3>
<p><strong>State (S)</strong>: A representation of the current situation in the environment. An environment can be fully observable (the agent sees the complete state) or partially observable (the agent receives only limited information about it).</p>
<p><strong>Action (A)</strong>: Choices available to the agent at any given state. Actions can be discrete (finite set of options) or continuous (infinite possibilities within a range).</p>
<p><strong>Reward (R)</strong>: Numerical feedback from the environment indicating the immediate value of the agent’s action. Rewards can be sparse (rarely nonzero, e.g. only at the end of an episode) or dense (informative at every step).</p>
<p><strong>Policy (π)</strong>: The agent’s strategy for choosing actions given states. A policy can be deterministic (always the same action for a given state) or stochastic (a probability distribution over actions).</p>
<p><strong>Value Function</strong>: Estimates the expected cumulative reward from a given state or state-action pair under a particular policy.</p>
</section>
<section id="the-rl-loop" class="level3">
<h3 class="anchored" data-anchor-id="the-rl-loop" id="the-rl-loop">The RL Loop</h3>
<ol type="1">
<li>Agent observes current state</li>
<li>Agent selects action based on current policy</li>
<li>Environment transitions to new state</li>
<li>Environment provides reward signal</li>
<li>Agent updates its knowledge/policy</li>
<li>Process repeats</li>
</ol>
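<p>The six steps above map directly onto a loop. This sketch uses a hypothetical one-dimensional <code>CorridorEnv</code> (reward 1 for reaching the right end) and a <code>RandomAgent</code> whose <code>update</code> only records transitions, so step 5 is a stub where a real learning rule would go. The <code>reset</code>/<code>step</code> interface is simplified for illustration and is not a specific library's API.</p>

```python
import random


class CorridorEnv:
    """Hypothetical toy environment: positions 0..length-1, goal at the right end."""

    def __init__(self, length=5):
        self.length = length
        self.pos = 0

    def reset(self):
        self.pos = 0
        return self.pos

    def step(self, action):
        # action is -1 (left) or +1 (right); the walls clamp the position
        self.pos = max(0, min(self.length - 1, self.pos + action))
        done = self.pos == self.length - 1
        reward = 1.0 if done else 0.0
        return self.pos, reward, done


class RandomAgent:
    """Chooses actions uniformly at random and records every transition."""

    def __init__(self, seed=0):
        self.rng = random.Random(seed)
        self.transitions = []

    def act(self, state):
        return self.rng.choice([-1, 1])

    def update(self, state, action, reward, next_state, done):
        # A learning agent would improve its policy here; this one only stores data.
        self.transitions.append((state, action, reward, next_state, done))


def run_episode(env, agent, max_steps=100):
    state = env.reset()                                     # 1. observe the state
    total_reward = 0.0
    for _ in range(max_steps):
        action = agent.act(state)                           # 2. select an action
        next_state, reward, done = env.step(action)         # 3-4. transition and reward
        agent.update(state, action, reward, next_state, done)  # 5. update
        total_reward += reward
        state = next_state                                  # 6. repeat
        if done:
            break
    return total_reward


agent = RandomAgent(seed=0)
print(run_episode(CorridorEnv(), agent), len(agent.transitions))
```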
</section>
<section id="exploration-vs-exploitation" class="level3">
<h3 class="anchored" data-anchor-id="exploration-vs-exploitation" id="exploration-vs-exploitation">Exploration vs Exploitation</h3>
<p>One of the central challenges in RL is balancing exploration (trying new actions to discover better strategies) with exploitation (using current knowledge to maximize immediate reward). This tradeoff is crucial because:</p>
<ul>
<li>Pure exploitation may miss better long-term strategies</li>
<li>Pure exploration wastes opportunities to use known good strategies</li>
<li>The optimal balance depends on the problem and learning phase</li>
</ul>
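<p>A common way to manage this tradeoff in value-based methods is ε-greedy action selection, often with ε annealed over time so the agent explores early and exploits later. A minimal sketch; the linear schedule and its constants (1.0 down to 0.05 over 10,000 steps) are illustrative assumptions, not recommended values.</p>

```python
import random


def epsilon_greedy(q_values, epsilon, rng=random):
    """Pick a random action with probability epsilon, else a greedy one.

    q_values: estimated action values for the current state.
    Ties among greedy actions are broken uniformly at random.
    """
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))               # explore
    best = max(q_values)
    greedy = [a for a, q in enumerate(q_values) if q == best]
    return rng.choice(greedy)                              # exploit


def decayed_epsilon(step, start=1.0, end=0.05, decay_steps=10_000):
    """Linearly anneal epsilon from start to end, then hold it at end."""
    frac = min(step / decay_steps, 1.0)
    return start + frac * (end - start)


rng = random.Random(0)
print(epsilon_greedy([0.1, 0.5, 0.2], epsilon=0.0, rng=rng))  # 1 (always greedy)
print(decayed_epsilon(5_000))  # halfway between 1.0 and 0.05
```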
</section>
</section>
<section id="mathematical-foundations" class="level2">
<h2 class="anchored" data-anchor-id="mathematical-foundations" id="mathematical-foundations">Mathematical Foundations</h2>
<section id="markov-decision-process-mdp" class="level3">
<h3 class="anchored" data-anchor-id="markov-decision-process-mdp" id="markov-decision-process-mdp">Markov Decision Process (MDP)</h3>
<p>Most RL problems are formalized as MDPs, defined by the tuple (S, A, P, R, γ):</p>
<ul>
<li>S: Set of states</li>
<li>A: Set of actions</li>
<li>P: State transition probabilities P(s’|s,a)</li>
<li>R: Reward function R(s,a,s’)</li>
<li>γ: Discount factor (0 ≤ γ ≤ 1)</li>
</ul>
<p>The Markov property states that the next state and reward depend only on the current state and action, not on the history of how the agent arrived there.</p>
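<p>The discount factor γ enters through the return, the quantity the agent tries to maximize: <span class="math inline">\(G_t = R_{t+1} + γR_{t+2} + γ^2 R_{t+3} + \dots = R_{t+1} + γG_{t+1}\)</span>. Computing it backwards over one recorded episode is a short sketch; the list-of-rewards episode format is an assumption.</p>

```python
def discounted_returns(rewards, gamma):
    """Return [G_0, G_1, ...] for one episode, using G_t = r_{t+1} + gamma * G_{t+1}.

    rewards[t] is the reward received after the action taken at step t.
    """
    returns = [0.0] * len(rewards)
    g = 0.0
    for t in reversed(range(len(rewards))):
        g = rewards[t] + gamma * g
        returns[t] = g
    return returns


print(discounted_returns([0.0, 0.0, 1.0], gamma=0.9))  # approximately [0.81, 0.9, 1.0]
```

<p>With γ close to 0 the agent is myopic; with γ close to 1, distant rewards count almost as much as immediate ones.</p>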
</section>
<section id="bellman-equations" class="level3">
<h3 class="anchored" data-anchor-id="bellman-equations" id="bellman-equations">Bellman Equations</h3>
<p>The Bellman equations provide the foundation for many RL algorithms:</p>
<p><strong>State Value Function</strong>: <span class="math display">\[
V^π(s) = \mathbb{E}[R_{t+1} + γV^π(S_{t+1}) | S_t = s]
\]</span></p>
<p><strong>Action Value Function (Q-function)</strong>: <span class="math display">\[
Q^π(s,a) = \mathbb{E}[R_{t+1} + γQ^π(S_{t+1}, A_{t+1}) | S_t = s, A_t = a]
\]</span></p>
<p><strong>Optimal Bellman Equations</strong>: <span class="math display">\[
V^*(s) = \max_a \sum_{s'} P(s'|s,a)[R(s,a,s') + γV^*(s')]
\]</span></p>
<p><span class="math display">\[
Q^*(s,a) = \sum_{s'} P(s'|s,a)[R(s,a,s') + γ \max_{a'} Q^*(s',a')]
\]</span></p>
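<p>The optimality equation for <span class="math inline">\(V^*\)</span> can be turned into an algorithm, value iteration, by applying its right-hand side as an update until the values stop changing. A minimal sketch, assuming a hypothetical three-state MDP whose model is stored as <code>P[s][a]</code> = list of <code>(probability, next_state, reward)</code> triples, so each term matches <span class="math inline">\(P(s'|s,a)[R(s,a,s') + γV^*(s')]\)</span>. Values are updated in place, a common variant that converges to the same fixed point for γ &lt; 1.</p>

```python
def value_iteration(P, gamma=0.9, tol=1e-10, max_iters=10_000):
    """Solve V*(s) = max_a sum_{s'} P(s'|s,a) [R(s,a,s') + gamma V*(s')].

    P[s][a] is a list of (probability, next_state, reward) triples.
    Returns the optimal state values and a greedy policy.
    """
    V = {s: 0.0 for s in P}

    def q(s, a):
        return sum(p * (r + gamma * V[s2]) for p, s2, r in P[s][a])

    for _ in range(max_iters):
        delta = 0.0
        for s in P:
            best = max(q(s, a) for a in P[s])
            delta = max(delta, abs(best - V[s]))
            V[s] = best
        if delta < tol:
            break
    policy = {s: max(P[s], key=lambda a: q(s, a)) for s in P}
    return V, policy


# Hypothetical MDP: from A, "go" reaches B with prob. 0.8 (else stays in A),
# "quit" ends the episode with reward 1; from B, "go" ends with reward 10.
P = {
    "A": {"go": [(0.8, "B", 0.0), (0.2, "A", 0.0)], "quit": [(1.0, "end", 1.0)]},
    "B": {"go": [(1.0, "end", 10.0)]},
    "end": {"stay": [(1.0, "end", 0.0)]},  # absorbing terminal state
}
V, policy = value_iteration(P)
print(policy)  # {'A': 'go', 'B': 'go', 'end': 'stay'}
```

<p>Solving by hand gives <span class="math inline">\(V^*(B) = 10\)</span> and <span class="math inline">\(V^*(A) = 0.8 \cdot 0.9 \cdot 10 + 0.2 \cdot 0.9 \cdot V^*(A)\)</span>, i.e. about 8.78, so going on beats quitting for a reward of 1.</p>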
</section>
<section id="convergence-and-optimality" class="level3">
<h3 class="anchored" data-anchor-id="convergence-and-optimality" id="convergence-and-optimality">Convergence and Optimality</h3>
<p>Under certain conditions (finite state and action spaces, γ &lt; 1, and, for sample-based methods such as Q-learning, sufficient exploration and suitably decaying step sizes), tabular RL algorithms are guaranteed to converge to optimal policies. The policy improvement theorem provides the theoretical backing for iterative policy improvement methods.</p>
</section>
</section>
<section id="key-algorithms" class="level2">
<h2 class="anchored" data-anchor-id="key-algorithms" id="key-algorithms">Key Algorithms</h2>
<section id="model-based-methods" class="level3">
<h3 class="anchored" data-anchor-id="model-based-methods" id="model-based-methods">Model-Based Methods</h3>
<p><strong>Dynamic Programming</strong></p>
<ul>
<li><strong>Policy Iteration</strong>: Alternates between policy evaluation and policy improvement</li>
<li><strong>Value Iteration</strong>: Directly computes optimal value function, then derives policy</li>
<li>Requires complete knowledge of environment dynamics</li>
<li>Guaranteed convergence but computationally expensive for large state spaces</li>
</ul>
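<p>Policy iteration can be sketched with a hypothetical model encoding: <code>P[s][a]</code> is a list of <code>(probability, next_state, reward)</code> triples. The full transition model is an input to both functions, which is exactly the requirement of complete knowledge of the environment dynamics listed above.</p>

```python
def policy_evaluation(P, policy, gamma=0.9, tol=1e-10):
    """Iteratively solve V(s) = sum_{s'} P(s'|s,pi(s)) [R + gamma V(s')] for a fixed policy."""
    V = {s: 0.0 for s in P}
    while True:
        delta = 0.0
        for s in P:
            v = sum(p * (r + gamma * V[s2]) for p, s2, r in P[s][policy[s]])
            delta = max(delta, abs(v - V[s]))
            V[s] = v
        if delta < tol:
            return V


def policy_iteration(P, gamma=0.9):
    """Alternate evaluation and greedy improvement until the policy is stable."""
    policy = {s: next(iter(P[s])) for s in P}  # arbitrary initial policy: first action
    while True:
        V = policy_evaluation(P, policy, gamma)
        stable = True
        for s in P:
            q = {a: sum(p * (r + gamma * V[s2]) for p, s2, r in P[s][a]) for a in P[s]}
            best = max(q, key=q.get)
            # Switch only if strictly better, so ties cannot cause endless cycling
            if q[best] > q[policy[s]] + 1e-12:
                policy[s] = best
                stable = False
        if stable:
            return policy, V


P = {
    "A": {"quit": [(1.0, "end", 1.0)], "go": [(0.8, "B", 0.0), (0.2, "A", 0.0)]},
    "B": {"go": [(1.0, "end", 10.0)]},
    "end": {"stay": [(1.0, "end", 0.0)]},
}
policy, V = policy_iteration(P)
print(policy)  # {'A': 'go', 'B': 'go', 'end': 'stay'}
```

<p>The initial policy quits in state A; one round of evaluation and improvement is enough to switch it to "go", after which the policy no longer changes.</p>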
</section>
<section id="model-free-methods" class="level3">
<h3 class="anchored" data-anchor-id="model-free-methods" id="model-free-methods">Model-Free Methods</h3>
<p><strong>Temporal Difference Learning</strong></p>
<ul>
<li><strong>Q-Learning</strong>: Off-policy method that learns optimal action values
<ul>
<li>Update rule: <span class="math inline">\(Q(s,a) \leftarrow Q(s,a) + α[r + γ \max_{a'} Q(s',a') - Q(s,a)]\)</span></li>
<li>Explores using ε-greedy or other exploration strategies</li>
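<li>In the tabular case, proven to converge to the optimal Q-function, provided every state-action pair keeps being visited and the learning rate decays appropriately</li>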
</ul></li>
<li><strong>SARSA (State-Action-Reward-State-Action)</strong>: On-policy method
<ul>
<li>Update rule: <span class="math inline">\(Q(s,a) \leftarrow Q(s,a) + α[r + γ Q(s',a') - Q(s,a)]\)</span></li>
<li>Uses actual next action taken by current policy</li>
<li>More conservative than Q-learning</li>
</ul></li>
</ul>
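<p>The two update rules differ only in the bootstrap term. A minimal tabular sketch with <code>Q</code> stored as a dictionary keyed by <code>(state, action)</code>; the <code>done</code> flag, which drops the bootstrap term at terminal transitions, is a standard detail that the update rules above leave implicit.</p>

```python
from collections import defaultdict


def q_learning_update(Q, s, a, r, s_next, actions, alpha, gamma, done):
    """Off-policy TD update: bootstrap from the best next action."""
    target = r if done else r + gamma * max(Q[(s_next, b)] for b in actions)
    Q[(s, a)] += alpha * (target - Q[(s, a)])


def sarsa_update(Q, s, a, r, s_next, a_next, alpha, gamma, done):
    """On-policy TD update: bootstrap from the action the policy actually took."""
    target = r if done else r + gamma * Q[(s_next, a_next)]
    Q[(s, a)] += alpha * (target - Q[(s, a)])


Q = defaultdict(float)  # unseen pairs start at 0.0
Q[("s1", "left")], Q[("s1", "right")] = 0.0, 2.0

# Q-learning uses max over s1's actions (2.0) regardless of what is taken next
q_learning_update(Q, "s0", "right", 1.0, "s1", ["left", "right"], alpha=0.5, gamma=0.9, done=False)
# SARSA uses the action actually taken in s1; here an exploratory "left" (0.0)
sarsa_update(Q, "s0", "left", 1.0, "s1", "left", alpha=0.5, gamma=0.9, done=False)
print(Q[("s0", "left")])  # 0.5
```

<p>Because SARSA's target includes the value of exploratory actions, it learns the value of the policy it actually follows, which is why it tends to be more conservative near risky states.</p>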
<p><strong>Policy Gradient Methods</strong></p>
<ul>
<li>Directly optimize policy parameters using gradient ascent</li>
<li><strong>REINFORCE</strong>: Basic policy gradient algorithm using Monte Carlo returns</li>
<li><strong>Actor-Critic</strong>: Combines value function estimation with policy optimization
<ul>
<li>Actor: Updates policy parameters</li>
<li>Critic: Estimates value function to reduce variance</li>
</ul></li>
<li>Better for continuous action spaces and stochastic policies</li>
</ul>
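<p>A minimal REINFORCE sketch on the simplest possible task, a hypothetical three-armed bandit where each episode is a single action, so the Monte Carlo return is just the sampled reward. It uses a softmax policy over preferences <span class="math inline">\(θ\)</span>, for which <span class="math inline">\(\partial \log π(a) / \partial θ_k = \mathbf{1}[k = a] - π(k)\)</span>, and subtracts a running-average baseline, a simple variance-reduction device that plays the role a learned critic plays in actor-critic methods. Arm means, learning rate, and episode count are illustrative assumptions.</p>

```python
import math
import random


def softmax(prefs):
    m = max(prefs)  # subtract the max for numerical stability
    exps = [math.exp(p - m) for p in prefs]
    total = sum(exps)
    return [e / total for e in exps]


def reinforce_bandit(true_means, episodes=2000, lr=0.1, seed=0):
    """REINFORCE with a softmax policy on a one-step task (a multi-armed bandit).

    Each episode: sample an arm a ~ pi_theta, observe a noisy return G, then take
    a gradient-ascent step theta += lr * (G - b) * grad log pi(a).
    """
    rng = random.Random(seed)
    theta = [0.0] * len(true_means)
    baseline = 0.0
    for episode in range(1, episodes + 1):
        probs = softmax(theta)
        a = rng.choices(range(len(theta)), weights=probs)[0]
        G = rng.gauss(true_means[a], 1.0)  # Monte Carlo return of this episode
        advantage = G - baseline
        for k in range(len(theta)):
            grad_log = (1.0 if k == a else 0.0) - probs[k]
            theta[k] += lr * advantage * grad_log
        baseline += (G - baseline) / episode  # running mean of past returns
    return softmax(theta)


probs = reinforce_bandit([0.0, 0.5, 2.0])
print([round(p, 3) for p in probs])  # most probability mass should end up on arm 2
```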
</section>
<section id="monte-carlo-methods" class="level3">
<h3 class="anchored" data-anchor-id="monte-carlo-methods" id="monte-carlo-methods">Monte Carlo Methods</h3>
<ul>
<li>Learn from complete episodes</li>
<li>No bootstrapping (unlike TD methods)</li>
<li>High variance but unbiased estimates</li>
<li>Suitable when the environment is episodic and episodes are short</li>
</ul>
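<p>First-visit Monte Carlo prediction shows these properties: each state's value is the plain average of the full returns observed after its first visit in each episode, with no bootstrapping from other estimates. The episode format, a list of <code>(state, reward)</code> pairs, is an assumption of this sketch.</p>

```python
from collections import defaultdict


def first_visit_mc(episodes, gamma=1.0):
    """Estimate V(s) as the average return following the first visit to s.

    Each episode is a list of (state, reward) pairs, where reward is the reward
    received after leaving that state. Targets are full returns, not bootstraps.
    """
    totals = defaultdict(float)
    counts = defaultdict(int)
    for episode in episodes:
        # Compute returns backwards: G_t = r_{t+1} + gamma * G_{t+1}
        G = 0.0
        returns = []
        for state, reward in reversed(episode):
            G = reward + gamma * G
            returns.append((state, G))
        returns.reverse()
        seen = set()
        for state, ret in returns:
            if state not in seen:  # only the first visit in each episode counts
                seen.add(state)
                totals[state] += ret
                counts[state] += 1
    return {s: totals[s] / counts[s] for s in totals}


episodes = [
    [("A", 0.0), ("B", 1.0)],
    [("A", 0.0), ("A", 0.0), ("B", 0.0)],
]
print(first_visit_mc(episodes))  # {'A': 0.5, 'B': 0.5}
```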
</section>
</section>
<section id="deep-reinforcement-learning" class="level2">
<h2 class="anchored" data-anchor-id="deep-reinforcement-learning" id="deep-reinforcement-learning">Deep Reinforcement Learning</h2>
<section id="deep-q-networks-dqn" class="level3">
<h3 class="anchored" data-anchor-id="deep-q-networks-dqn" id="deep-q-networks-dqn">Deep Q-Networks (DQN)</h3>
<p>Combines Q-learning with deep neural networks to handle high-dimensional state spaces:</p>
<p><strong>Key Innovations</strong>:</p>
<ul>
<li><strong>Experience Replay</strong>: Store and randomly sample past experiences to break correlation</li>
<li><strong>Target Network</strong>: Use separate network for computing targets to stabilize learning</li>
<li><strong>Function Approximation</strong>: Neural networks approximate Q-values for large state spaces</li>
</ul>
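<p>Experience replay needs little more than a bounded queue with uniform sampling. A minimal sketch; the <code>Transition</code> fields and the capacity are illustrative, and a production buffer would usually store arrays rather than Python tuples.</p>

```python
import random
from collections import deque, namedtuple

Transition = namedtuple("Transition", "state action reward next_state done")


class ReplayBuffer:
    """Fixed-size FIFO store of transitions with uniform random sampling."""

    def __init__(self, capacity, seed=None):
        self.buffer = deque(maxlen=capacity)  # oldest transitions drop off automatically
        self.rng = random.Random(seed)

    def push(self, *args):
        self.buffer.append(Transition(*args))

    def sample(self, batch_size):
        # Uniform sampling breaks the temporal correlation between consecutive
        # transitions that an online learner would otherwise train on.
        return self.rng.sample(list(self.buffer), batch_size)

    def __len__(self):
        return len(self.buffer)


buf = ReplayBuffer(capacity=3, seed=0)
for t in range(5):
    buf.push(t, 0, 0.0, t + 1, False)
print(len(buf), [tr.state for tr in buf.buffer])  # 3 [2, 3, 4]
```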
<p><strong>Improvements</strong>:</p>
<ul>
<li><strong>Double DQN</strong>: Addresses overestimation bias in Q-learning</li>
<li><strong>Dueling DQN</strong>: Separates state value and advantage estimation</li>
<li><strong>Prioritized Experience Replay</strong>: Sample important experiences more frequently</li>
<li><strong>Rainbow DQN</strong>: Combines six extensions (Double DQN, dueling networks, prioritized replay, multi-step returns, distributional RL, and noisy networks) into a single agent</li>
</ul>
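<p>The overestimation that Double DQN addresses comes from using the same max to both select and evaluate the next action. A sketch of the two target computations for a single transition, with next-state Q-values from the online and target networks given as plain lists (hypothetical numbers):</p>

```python
def dqn_target(reward, next_q_target, gamma, done):
    """Standard DQN: the target network both selects and evaluates a'."""
    return reward if done else reward + gamma * max(next_q_target)


def double_dqn_target(reward, next_q_online, next_q_target, gamma, done):
    """Double DQN: the online network selects a', the target network evaluates it."""
    if done:
        return reward
    best_action = max(range(len(next_q_online)), key=lambda a: next_q_online[a])
    return reward + gamma * next_q_target[best_action]


online = [1.0, 3.0, 2.0]  # online-network Q(s', .)
target = [4.0, 1.5, 2.5]  # target-network Q(s', .)
print(dqn_target(0.0, target, 0.5, False))                 # 2.0
print(double_dqn_target(0.0, online, target, 0.5, False))  # 0.75
```

<p>Here the target network's largest estimate (4.0) belongs to an action the online network does not rank first, so the Double DQN target is lower; if that 4.0 were an overestimate, decoupling selection from evaluation keeps it out of the target.</p>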
</section>
<section id="policy-gradient-methods" class="level3">
<h3 class="anchored" data-anchor-id="policy-gradient-methods" id="policy-gradient-methods">Policy Gradient Methods</h3>
<p><strong>Proximal Policy Optimization (PPO)</strong></p>
<ul>
<li>Clips policy updates to prevent destructive large changes</li>
<li>Simpler to implement than TRPO and typically more stable than vanilla policy gradient methods</li>
<li>Widely used in practice due to reliability</li>
</ul>
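<p>The clipping refers to PPO's clipped surrogate objective, computed per sample from the probability ratio <span class="math inline">\(r = π_{new}(a|s) / π_{old}(a|s)\)</span> and an advantage estimate <span class="math inline">\(A\)</span>. A minimal sketch; <span class="math inline">\(ε = 0.2\)</span> is a common default, used here as an assumption.</p>

```python
def ppo_clip_objective(ratio, advantage, epsilon=0.2):
    """Per-sample PPO-Clip surrogate: min(r * A, clip(r, 1 - eps, 1 + eps) * A).

    ratio = pi_new(a|s) / pi_old(a|s). This quantity is maximized (its negative
    is the loss); averaging it over a batch gives the training objective.
    """
    clipped_ratio = min(max(ratio, 1.0 - epsilon), 1.0 + epsilon)
    return min(ratio * advantage, clipped_ratio * advantage)


print(ppo_clip_objective(1.5, 2.0))   # A > 0, ratio above 1 + eps: clipped term 1.2 * 2.0 used
print(ppo_clip_objective(0.5, 2.0))   # A > 0, ratio below 1 - eps: unclipped term 0.5 * 2.0 used
print(ppo_clip_objective(0.5, -1.0))  # A < 0, ratio below 1 - eps: clipped term 0.8 * -1.0 used
```

<p>Taking the minimum makes the objective a pessimistic bound: once the ratio leaves the interval in the direction the advantage rewards, the clipped term is constant and contributes no gradient, which is what prevents destructively large updates.</p>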
<p><strong>Trust Region Policy Optimization (TRPO)</strong></p>
<ul>
<li>Constrains policy updates within trust region</li>
<li>Provides theoretical guarantees on policy improvement</li>
<li>More complex than PPO but stronger theoretical foundation</li>
</ul>
<p><strong>Actor-Critic Methods</strong></p>
<ul>
<li><strong>A3C (Asynchronous Advantage Actor-Critic)</strong>: Parallel training with multiple asynchronous actor-learners</li>
<li><strong>A2C (Advantage Actor-Critic)</strong>: Synchronous version of A3C</li>
<li><strong>SAC (Soft Actor-Critic)</strong>: Off-policy method with entropy regularization</li>
</ul>
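<p>In advantage actor-critic methods, each worker collects a short rollout, computes discounted returns bootstrapped from the critic’s value of the final state, and uses the advantage (return minus predicted value) to weight the policy gradient. A minimal NumPy sketch of that bookkeeping, with hypothetical function names and a <code>dones</code> array that marks steps after which an episode ended:</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import numpy as np


def n_step_returns(rewards, dones, bootstrap_value, gamma=0.99):
    """Discounted returns for one rollout, bootstrapped from V(s_T) at the end.

    rewards, dones: arrays of length T collected by one worker.
    bootstrap_value: the critic's estimate V(s_T) for the state after the rollout.
    """
    returns = np.zeros(len(rewards))
    running = bootstrap_value
    # Work backwards: R_t = r_t + gamma * R_(t+1), with no carry-over past an episode end
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running * (1.0 - dones[t])
        returns[t] = running
    return returns


def advantages(returns, values):
    """A(s_t, a_t) estimated as R_t - V(s_t); weights the actor's log-prob gradient."""
    return returns - values</code></pre></div>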
</section>
<section id="deep-deterministic-policy-gradient-ddpg" class="level3">
<h3 class="anchored" data-anchor-id="deep-deterministic-policy-gradient-ddpg" id="deep-deterministic-policy-gradient-ddpg">Deep Deterministic Policy Gradient (DDPG)</h3>
<ul>
<li>Extends DQN-style learning to continuous action spaces, where taking a maximum over all actions is impractical</li>
<li>Uses actor-critic architecture with deterministic policies</li>
<li>Employs target networks and experience replay like DQN</li>
</ul>
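<p>Unlike DQN’s periodic hard copy, DDPG updates its target networks slowly by Polyak averaging with a small coefficient (the original paper used τ = 0.001). A minimal sketch, assuming parameters are stored as a dictionary of NumPy arrays (a hypothetical layout chosen for illustration):</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import numpy as np


def soft_update(target_params, online_params, tau=0.001):
    """Polyak averaging: theta_target = tau * theta_online + (1 - tau) * theta_target.

    Both arguments map parameter names to NumPy arrays; targets are updated in place.
    """
    for name, online in online_params.items():
        target_params[name] *= (1.0 - tau)
        target_params[name] += tau * online</code></pre></div>
<p>With a small τ the target networks trail the online networks smoothly, which keeps the critic’s learning targets from shifting abruptly between updates.</p>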
</section>
</section>
<section id="advanced-topics" class="level2">
<h2 class="anchored" data-anchor-id="advanced-topics" id="advanced-topics">Advanced Topics</h2>
<section id="multi-agent-reinforcement-learning-marl" class="level3">
<h3 class="anchored" data-anchor-id="multi-agent-reinforcement-learning-marl" id="multi-agent-reinforcement-learning-marl">Multi-Agent Reinforcement Learning (MARL)</h3>
<p>When multiple agents interact in the same environment:</p>
<ul>
<li><strong>Cooperative</strong>: Agents share a common goal</li>
<li><strong>Competitive</strong>: Zero-sum or adversarial settings</li>
<li><strong>Mixed-Motive</strong>: A combination of cooperation and competition</li>
</ul>
<p>Challenges include non-stationarity (other agents are learning too), credit assignment, and communication.</p>
</section>
<section id="hierarchical-reinforcement-learning" class="level3">
<h3 class="anchored" data-anchor-id="hierarchical-reinforcement-learning" id="hierarchical-reinforcement-learning">Hierarchical Reinforcement Learning</h3>
<p>Structures learning across multiple temporal scales:</p>
<ul>
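<li><strong>Options Framework</strong>: Temporally extended actions (options), each with an initiation set, an internal policy, and a termination condition, formalized as a semi-Markov decision process</li>
<li><strong>Feudal Networks</strong>: A manager module sets goals that lower-level worker modules learn to achieve</li>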
<li><strong>HAM (Hierarchy of Abstract Machines)</strong>: Formal framework for hierarchical policies</li>
</ul>
<p>Benefits include faster learning, better exploration, and transferable skills.</p>
</section>
<section id="transfer-learning-and-meta-learning" class="level3">
<h3 class="anchored" data-anchor-id="transfer-learning-and-meta-learning" id="transfer-learning-and-meta-learning">Transfer Learning and Meta-Learning</h3>
<ul>
<li><strong>Transfer Learning</strong>: Apply knowledge from one task to related tasks</li>
<li><strong>Meta-Learning</strong>: Learn how to learn quickly on new tasks</li>
<li><strong>Few-Shot Learning</strong>: Quickly adapt to new tasks with minimal data</li>
</ul>
</section>
<section id="partial-observability" class="level3">
<h3 class="anchored" data-anchor-id="partial-observability" id="partial-observability">Partial Observability</h3>
<p>When agents cannot observe the complete state:</p>
<ul>
<li><strong>POMDPs (Partially Observable MDPs)</strong>: Formal framework with belief states</li>
<li><strong>Recurrent Networks</strong>: Use memory to maintain state estimates</li>
<li><strong>Attention Mechanisms</strong>: Focus on relevant parts of observation history</li>
</ul>
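<p>The belief state of a POMDP is a probability distribution over hidden states, updated after each action and observation with Bayes’ rule. The sketch below assumes a small, fully known discrete model with transition tensor <code>T</code> and observation tensor <code>O</code> (hypothetical names); recurrent networks can be viewed as learning an approximation of this update when no such model is available.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import numpy as np


def belief_update(belief, action, observation, T, O):
    """Discrete Bayes filter over hidden states.

    belief: b(s), shape (S,)
    T: transition probabilities, T[a, s, s_next] = P(s_next | s, a)
    O: observation probabilities, O[a, s_next, o] = P(o | s_next, a)
    """
    predicted = belief @ T[action]                        # predict: sum_s b(s) P(s'|s,a)
    unnormalized = predicted * O[action][:, observation]  # weight by P(o|s',a)
    return unnormalized / unnormalized.sum()              # renormalize to a distribution</code></pre></div>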
</section>
<section id="safety-and-robustness" class="level3">
<h3 class="anchored" data-anchor-id="safety-and-robustness" id="safety-and-robustness">Safety and Robustness</h3>
<p>Critical considerations for real-world deployment:</p>
<ul>
<li><strong>Safe Exploration</strong>: Avoid dangerous actions during learning</li>
<li><strong>Robust RL</strong>: Handle uncertainty and distribution shift</li>
<li><strong>Constrained RL</strong>: Satisfy safety constraints while optimizing rewards</li>
<li><strong>Interpretability</strong>: Understanding the agent’s decision-making process</li>
</ul>
</section>
</section>
<section id="applications" class="level2">
<h2 class="anchored" data-anchor-id="applications" id="applications">Applications</h2>
<section id="game-playing" class="level3">
<h3 class="anchored" data-anchor-id="game-playing" id="game-playing">Game Playing</h3>
<ul>
<li><strong>Board Games</strong>: Go (AlphaGo, AlphaZero), chess and shogi (AlphaZero), backgammon (TD-Gammon)</li>
<li><strong>Video Games</strong>: Atari games (DQN), StarCraft II (AlphaStar), Dota 2 (OpenAI Five)</li>
<li><strong>Card Games</strong>: Poker (Libratus, Pluribus, which combine self-play with counterfactual regret minimization)</li>
</ul>
</section>
<section id="robotics" class="level3">
<h3 class="anchored" data-anchor-id="robotics" id="robotics">Robotics</h3>
<ul>
<li><strong>Manipulation</strong>: Grasping, assembly, dexterous manipulation</li>
<li><strong>Navigation</strong>: Path planning, obstacle avoidance, SLAM</li>
<li><strong>Locomotion</strong>: Walking, running, jumping for legged robots</li>
<li><strong>Human-Robot Interaction</strong>: Social robots, collaborative robots</li>
</ul>
</section>
<section id="autonomous-systems" class="level3">
<h3 class="anchored" data-anchor-id="autonomous-systems" id="autonomous-systems">Autonomous Systems</h3>
<ul>
<li><strong>Self-Driving Cars</strong>: Path planning, decision making in traffic</li>
<li><strong>Drones</strong>: Navigation, surveillance, delivery</li>
<li><strong>Traffic Management</strong>: Optimizing traffic flow, signal control</li>
</ul>
</section>
<section id="finance-and-trading" class="level3">
<h3 class="anchored" data-anchor-id="finance-and-trading" id="finance-and-trading">Finance and Trading</h3>
<ul>
<li><strong>Algorithmic Trading</strong>: Portfolio management, execution strategies</li>
<li><strong>Risk Management</strong>: Dynamic hedging, capital allocation</li>
<li><strong>Market Making</strong>: Optimal bid-ask spread management</li>
</ul>
</section>
<section id="healthcare" class="level3">
<h3 class="anchored" data-anchor-id="healthcare" id="healthcare">Healthcare</h3>
<ul>
<li><strong>Treatment Planning</strong>: Personalized therapy recommendations</li>
<li><strong>Drug Discovery</strong>: Molecular design, clinical trial optimization</li>
<li><strong>Medical Imaging</strong>: Automated diagnosis, treatment planning</li>
</ul>
</section>
<section id="natural-language-processing" class="level3">
<h3 class="anchored" data-anchor-id="natural-language-processing" id="natural-language-processing">Natural Language Processing</h3>
<ul>
<li><strong>Dialogue Systems</strong>: Conversational AI, customer service bots</li>
<li><strong>Machine Translation</strong>: Directly optimizing sequence-level translation quality metrics such as BLEU</li>
<li><strong>Text Generation</strong>: Content creation, summarization</li>
</ul>
</section>
<section id="resource-management" class="level3">
<h3 class="anchored" data-anchor-id="resource-management" id="resource-management">Resource Management</h3>
<ul>
<li><strong>Cloud Computing</strong>: Resource allocation, auto-scaling</li>
<li><strong>Energy Systems</strong>: Smart grid management, battery optimization</li>
<li><strong>Supply Chain</strong>: Inventory management, logistics optimization</li>
</ul>
</section>
</section>
<section id="implementation-considerations" class="level2">
<h2 class="anchored" data-anchor-id="implementation-considerations" id="implementation-considerations">Implementation Considerations</h2>
<section id="environment-design" class="level3">
<h3 class="anchored" data-anchor-id="environment-design" id="environment-design">Environment Design</h3>
<ul>
<li><strong>Reward Engineering</strong>: Design rewards that incentivize desired behavior</li>
<li><strong>State Representation</strong>: Choose appropriate features and observations</li>
<li><strong>Action Space</strong>: Balance expressiveness with computational complexity</li>
<li><strong>Simulation Fidelity</strong>: Trade-off between realism and computational speed</li>
</ul>
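<p>The design choices above (state representation, action space, reward, and episode length) all appear explicitly even in a tiny environment. The following hypothetical corridor environment loosely follows the Gymnasium-style <code>reset</code>/<code>step</code> return convention but is written in plain Python so it runs without any RL library; its reward values are illustrative assumptions.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">class LineWorld:
    """Hypothetical 1-D corridor: start in cell 0, goal at the last cell.

    reset() returns (obs, info); step(action) returns
    (obs, reward, terminated, truncated, info), as in Gymnasium.
    """

    def __init__(self, length=5, max_steps=20):
        self.length = length        # state representation: integer cell index
        self.max_steps = max_steps  # time limit; reaching it truncates the episode
        self.n_actions = 2          # action space: 0 = left, 1 = right

    def reset(self, seed=None):
        # The environment is deterministic; seed is accepted only for interface parity
        self.pos, self.t = 0, 0
        return self.pos, {}

    def step(self, action):
        move = 1 if action == 1 else -1
        self.pos = min(max(self.pos + move, 0), self.length - 1)
        self.t += 1
        terminated = self.pos == self.length - 1
        # Reward engineering: +1 at the goal, small per-step cost to favor short paths
        reward = 1.0 if terminated else -0.01
        truncated = (not terminated) and self.t == self.max_steps
        return self.pos, reward, terminated, truncated, {}</code></pre></div>
<p>Separating <code>terminated</code> (the goal was reached) from <code>truncated</code> (the time limit was hit) matters for bootstrapping: a learner should not treat a time-limit cutoff as a true terminal state.</p>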
</section>
<section id="hyperparameter-tuning" class="level3">
<h3 class="anchored" data-anchor-id="hyperparameter-tuning" id="hyperparameter-tuning">Hyperparameter Tuning</h3>
<p>Critical parameters affecting performance:</p>
<ul>
<li><strong>Learning Rate</strong>: Too high causes instability, too low slows convergence</li>
<li><strong>Exploration Rate</strong>: Balance exploration and exploitation</li>
<li><strong>Discount Factor</strong>: Sets how strongly future rewards are weighted; with discount γ, rewards more than roughly 1/(1 - γ) steps ahead contribute little</li>
<li><strong>Network Architecture</strong>: Layer sizes, activation functions, regularization</li>
<li><strong>Batch Size</strong>: Affects stability and computational efficiency</li>
</ul>
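<p>Two of these hyperparameters are easy to reason about numerically. The sketch below shows a linear exploration schedule for epsilon-greedy agents and the effective-horizon rule of thumb for the discount factor; the function names and default values are illustrative assumptions, not recommendations.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">def linear_epsilon(step, start=1.0, end=0.05, decay_steps=10_000):
    """Linearly anneal the exploration rate from start to end, then hold it at end."""
    fraction = min(step / decay_steps, 1.0)
    return start + fraction * (end - start)


def effective_horizon(gamma):
    """Rule of thumb: rewards more than about 1 / (1 - gamma) steps away count little."""
    return 1.0 / (1.0 - gamma)</code></pre></div>
<p>For example, γ = 0.99 gives an effective horizon of about 100 steps, while γ = 0.9 gives about 10.</p>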
</section>
<section id="evaluation-and-testing" class="level3">
<h3 class="anchored" data-anchor-id="evaluation-and-testing" id="evaluation-and-testing">Evaluation and Testing</h3>
<ul>
<li><strong>Sample Efficiency</strong>: How much environment interaction is needed to learn an effective policy</li>
<li><strong>Final Performance</strong>: Quality of learned policy on test environments</li>
<li><strong>Robustness</strong>: Performance under distribution shift or adversarial conditions</li>
<li><strong>Safety</strong>: Avoiding dangerous or harmful actions</li>
</ul>
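<p>Because RL results vary substantially between random seeds, final performance is usually reported over several independent training runs rather than a single one. A minimal aggregation helper, assuming one final evaluation return per seed (the function name is hypothetical):</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import numpy as np


def summarize_runs(final_returns):
    """Aggregate one final evaluation return per independent training seed."""
    x = np.asarray(final_returns, dtype=float)
    mean = x.mean()
    std = x.std(ddof=1)             # sample standard deviation; needs at least two seeds
    stderr = std / np.sqrt(len(x))  # uncertainty of the mean itself
    return {"mean": mean, "std": std, "stderr": stderr, "n_seeds": len(x)}</code></pre></div>
<p>Reporting the number of seeds alongside the spread lets readers judge how much of a difference between algorithms could be seed-to-seed noise.</p>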
</section>
<section id="debugging-rl-systems" class="level3">
<h3 class="anchored" data-anchor-id="debugging-rl-systems" id="debugging-rl-systems">Debugging RL Systems</h3>
<p>Common issues and solutions:</p>
<ul>
<li><strong>Learning Instability</strong>: Use target networks, gradient clipping, proper initialization</li>
<li><strong>Poor Exploration</strong>: Adjust exploration strategies, use curiosity-driven methods</li>
<li><strong>Reward Hacking</strong>: Careful reward design, use auxiliary objectives</li>
<li><strong>Overfitting</strong>: Regularization, diverse training environments</li>
</ul>
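<p>Gradient clipping, one of the instability remedies listed above, rescales the gradients whenever their combined norm exceeds a threshold. The NumPy sketch below follows the same global-norm rule as PyTorch’s <code>torch.nn.utils.clip_grad_norm_</code>; in practice the framework utility would be applied directly to model parameters, and <code>clip_by_global_norm</code> is a hypothetical name.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import numpy as np


def clip_by_global_norm(grads, max_norm):
    """Rescale a list of gradient arrays so their combined L2 norm is at most max_norm.

    The update direction is preserved; only the overall magnitude is reduced.
    Returns the (possibly rescaled) gradients and the norm before clipping.
    """
    total_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    scale = min(1.0, max_norm / (total_norm + 1e-6))
    return [g * scale for g in grads], total_norm</code></pre></div>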
</section>
<section id="computational-considerations" class="level3">
<h3 class="anchored" data-anchor-id="computational-considerations" id="computational-considerations">Computational Considerations</h3>
<ul>
<li><strong>Parallel Training</strong>: Distributed computing, asynchronous updates</li>
<li><strong>Memory Requirements</strong>: Experience replay buffers, model storage</li>
<li><strong>Training Time</strong>: Sample efficiency vs wall-clock time trade-offs</li>
<li><strong>Hardware</strong>: GPUs for neural networks, CPUs for environment simulation</li>
</ul>
</section>
</section>
<section id="resources-and-tools" class="level2">
<h2 class="anchored" data-anchor-id="resources-and-tools" id="resources-and-tools">Resources and Tools</h2>
<section id="frameworks-and-libraries" class="level3">
<h3 class="anchored" data-anchor-id="frameworks-and-libraries" id="frameworks-and-libraries">Frameworks and Libraries</h3>
<ul>
<li><strong>Stable-Baselines3</strong>: High-quality implementations of RL algorithms</li>
<li><strong>Ray RLlib</strong>: Scalable reinforcement learning library</li>
<li><strong>OpenAI Gym / Gymnasium</strong>: Standard environment interface for RL research; Gymnasium is the maintained successor to Gym</li>
<li><strong>PyBullet</strong>: Physics simulation for robotics applications</li>
<li><strong>Unity ML-Agents</strong>: RL framework for Unity game engine</li>
<li><strong>TensorFlow Agents</strong>: RL library built on TensorFlow</li>
<li><strong>Dopamine</strong>: Research framework for fast prototyping</li>
</ul>
</section>
<section id="simulation-environments" class="level3">
<h3 class="anchored" data-anchor-id="simulation-environments" id="simulation-environments">Simulation Environments</h3>
<ul>
<li><strong>Atari</strong>: Classic video games for testing RL algorithms</li>
<li><strong>MuJoCo</strong>: Physics simulation for continuous control</li>
<li><strong>CarRacing</strong>: Top-down racing task learned from pixels</li>
<li><strong>Roboschool</strong>: Open-source physics simulation (now deprecated in favor of PyBullet)</li>
<li><strong>StarCraft II Learning Environment</strong>: Real-time strategy game</li>
<li><strong>Procgen</strong>: Procedurally generated environments for generalization</li>
</ul>
</section>
<section id="books-and-courses" class="level3">
<h3 class="anchored" data-anchor-id="books-and-courses" id="books-and-courses">Books and Courses</h3>
<ul>
<li>“Reinforcement Learning: An Introduction” by Sutton &amp; Barto</li>
<li>“Deep Reinforcement Learning” by Aske Plaat</li>
<li>CS 285 Deep Reinforcement Learning (UC Berkeley, formerly CS 294-112)</li>
<li>DeepMind &amp; UCL Reinforcement Learning Course</li>
<li>OpenAI Spinning Up in Deep RL</li>
</ul>
</section>
<section id="research-venues" class="level3">
<h3 class="anchored" data-anchor-id="research-venues" id="research-venues">Research Venues</h3>
<ul>
<li><strong>Conferences</strong>: ICML, NeurIPS, ICLR, AAAI, IJCAI</li>
<li><strong>Journals</strong>: JMLR, Machine Learning, Artificial Intelligence</li>
<li><strong>Workshops</strong>: Deep RL Workshop, Multi-Agent RL Workshop</li>
</ul>
</section>
<section id="best-practices" class="level3">
<h3 class="anchored" data-anchor-id="best-practices" id="best-practices">Best Practices</h3>
<ol type="1">
<li><strong>Start Simple</strong>: Begin with basic algorithms before moving to complex methods</li>
<li><strong>Understand the Environment</strong>: Analyze state/action spaces and reward structure</li>
<li><strong>Baseline Comparison</strong>: Compare against random and heuristic policies</li>
<li><strong>Ablation Studies</strong>: Test individual components to understand their contribution</li>
<li><strong>Reproducibility</strong>: Use seeds, version control, and detailed logging</li>
<li><strong>Incremental Development</strong>: Add complexity gradually while maintaining functionality</li>
<li><strong>Monitor Training</strong>: Track learning curves, exploration metrics, and environment statistics</li>
</ol>
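<p>For the reproducibility practice, a small helper can seed Python’s and NumPy’s global generators and return an explicit NumPy <code>Generator</code>. This is a sketch with a hypothetical name; frameworks such as PyTorch and each environment keep separate generators (for example <code>torch.manual_seed</code> and Gymnasium’s <code>env.reset(seed=...)</code>) that must be seeded as well.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import random

import numpy as np


def set_global_seeds(seed):
    """Seed Python's and NumPy's global RNGs and return a dedicated NumPy Generator."""
    random.seed(seed)
    np.random.seed(seed)
    # An explicit Generator is the preferred way to draw random numbers in new NumPy code
    return np.random.default_rng(seed)</code></pre></div>
<p>Logging the seed with every run, alongside the code version, makes it possible to rerun a surprising result exactly.</p>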
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>Reinforcement learning represents a powerful paradigm for solving complex sequential decision-making problems. While it presents unique challenges in terms of sample efficiency, exploration, and stability, the field continues to advance rapidly with new algorithms, applications, and theoretical insights. Success in RL requires careful consideration of problem formulation, algorithm selection, implementation details, and thorough evaluation practices.</p>


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Vision-Language Models: Bridging Visual and Textual Understanding]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/models/vision-language-models/vision-language-explained/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/models/vision-language-models/vision-language-explained/</guid>
      <pubDate>Sat, 02 Aug 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>beginner</category>
      <content:encoded><![CDATA[






<section id="vision-language-models-bridging-visual-and-textual-understanding" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/models/vision-language-models/vision-language-explained/vl.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>Vision-Language Models (VLMs) are one of the most active frontiers in artificial intelligence. They combine computer vision and natural language processing to create systems that can understand and reason about images and text together, changing how machines interpret the world around us.</p>
</section>
<section id="what-are-vision-language-models" class="level2">
<h2 class="anchored" data-anchor-id="what-are-vision-language-models" id="what-are-vision-language-models">What Are Vision-Language Models?</h2>
<p>Vision-Language Models are neural networks designed to process and understand both visual and textual information. Unlike traditional models that handle only one modality, VLMs can:</p>
<ul>
<li>Describe images in natural language</li>
<li>Answer questions about visual content</li>
<li>Generate images from text descriptions</li>
<li>Perform visual reasoning tasks</li>
<li>Extract and understand text within images</li>
</ul>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Note
</div>
</div>
<div class="callout-body-container callout-body">
<p>The key innovation lies in their ability to create shared representations that bridge the semantic gap between visual and linguistic information.</p>
</div>
</div>
</section>
<section id="architecture-deep-dive" class="level2">
<h2 class="anchored" data-anchor-id="architecture-deep-dive" id="architecture-deep-dive">Architecture Deep Dive</h2>
<section id="core-components" class="level3">
<h3 class="anchored" data-anchor-id="core-components" id="core-components">Core Components</h3>
<p>Most modern VLMs follow an encoder-decoder architecture with several key components:</p>
<div id="vlm-architecture" class="cell" data-execution_count="1">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb1"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> VisionLanguageModel(nn.Module):</span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, vision_encoder, text_encoder, cross_attention, decoder):</span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Components are passed in, e.g. a Vision Transformer, a text Transformer,</span></span>
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a>        <span class="co"># a cross-attention block, and a language decoder</span></span>
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.vision_encoder <span class="op">=</span> vision_encoder</span>
<span id="cb1-9"><a href="#cb1-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.text_encoder <span class="op">=</span> text_encoder</span>
<span id="cb1-10"><a href="#cb1-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.cross_attention <span class="op">=</span> cross_attention</span>
<span id="cb1-11"><a href="#cb1-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.decoder <span class="op">=</span> decoder</span>
<span id="cb1-12"><a href="#cb1-12" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb1-13"><a href="#cb1-13" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, image, text):</span>
<span id="cb1-14"><a href="#cb1-14" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Extract visual features</span></span>
<span id="cb1-15"><a href="#cb1-15" aria-hidden="true" tabindex="-1"></a>        visual_features <span class="op">=</span> <span class="va">self</span>.vision_encoder(image)</span>
<span id="cb1-16"><a href="#cb1-16" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-17"><a href="#cb1-17" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Extract textual features</span></span>
<span id="cb1-18"><a href="#cb1-18" aria-hidden="true" tabindex="-1"></a>        text_features <span class="op">=</span> <span class="va">self</span>.text_encoder(text)</span>
<span id="cb1-19"><a href="#cb1-19" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-20"><a href="#cb1-20" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Cross-modal attention</span></span>
<span id="cb1-21"><a href="#cb1-21" aria-hidden="true" tabindex="-1"></a>        fused_features <span class="op">=</span> <span class="va">self</span>.cross_attention(</span>
<span id="cb1-22"><a href="#cb1-22" aria-hidden="true" tabindex="-1"></a>            visual_features, text_features</span>
<span id="cb1-23"><a href="#cb1-23" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb1-24"><a href="#cb1-24" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-25"><a href="#cb1-25" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Generate output</span></span>
<span id="cb1-26"><a href="#cb1-26" aria-hidden="true" tabindex="-1"></a>        output <span class="op">=</span> <span class="va">self</span>.decoder(fused_features)</span>
<span id="cb1-27"><a href="#cb1-27" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> output</span></code></pre></div></div>
</div>
</section>
<section id="vision-encoder" class="level3">
<h3 class="anchored" data-anchor-id="vision-encoder" id="vision-encoder">Vision Encoder</h3>
<p>The vision component typically uses:</p>
<ul>
<li><strong>Vision Transformers (ViTs)</strong>: Split images into patches and process them as sequences</li>
<li><strong>Convolutional Neural Networks</strong>: Extract hierarchical visual features</li>
<li><strong>Region-based methods</strong>: Focus on specific image regions</li>
</ul>
<div id="patch-embedding" class="cell" data-execution_count="2">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> patch_embedding(image, patch_size<span class="op">=</span><span class="dv">16</span>):</span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Split a (B, C, H, W) image batch into flattened, non-overlapping patches"""</span></span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a>    B, C <span class="op">=</span> image.shape[:<span class="dv">2</span>]</span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Extract p x p windows: (B, C, H/p, W/p, p, p)</span></span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a>    patches <span class="op">=</span> image.unfold(<span class="dv">2</span>, patch_size, patch_size)</span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a>    patches <span class="op">=</span> patches.unfold(<span class="dv">3</span>, patch_size, patch_size)</span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Move channels next to the pixel dims so each row holds one whole patch</span></span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a>    patches <span class="op">=</span> patches.permute(<span class="dv">0</span>, <span class="dv">2</span>, <span class="dv">3</span>, <span class="dv">1</span>, <span class="dv">4</span>, <span class="dv">5</span>)</span>
<span id="cb2-10"><a href="#cb2-10" aria-hidden="true" tabindex="-1"></a>    <span class="co"># (B, num_patches, C * p * p); a ViT then applies a learned linear projection</span></span>
<span id="cb2-11"><a href="#cb2-11" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> patches.reshape(B, <span class="op">-</span><span class="dv">1</span>, C <span class="op">*</span> patch_size <span class="op">*</span> patch_size)</span></code></pre></div></div>
</div>
</section>
<section id="text-encoder" class="level3">
<h3 class="anchored" data-anchor-id="text-encoder" id="text-encoder">Text Encoder</h3>
<p>Text processing leverages transformer architectures:</p>
<ul>
<li><strong>BERT-style encoders</strong>: For understanding input text</li>
<li><strong>GPT-style decoders</strong>: For generating responses</li>
<li><strong>Tokenization</strong>: Converting text to numerical representations</li>
</ul>
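<p>Tokenization can be illustrated with a toy word-level tokenizer that adds BERT-style special tokens, pads to a fixed length, and returns an attention mask. Real VLMs use learned subword tokenizers (for example WordPiece in BERT or byte-pair encoding in GPT-2), so the whitespace splitting and the helper names <code>build_vocab</code> and <code>encode</code> are simplifying assumptions.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">def build_vocab(texts, specials=("[PAD]", "[UNK]", "[CLS]", "[SEP]")):
    """Map each distinct lowercase word to an integer id, after the special tokens."""
    vocab = {tok: i for i, tok in enumerate(specials)}
    for text in texts:
        for word in text.lower().split():
            vocab.setdefault(word, len(vocab))
    return vocab


def encode(text, vocab, max_len=8):
    """Convert text to a fixed-length list of token ids plus an attention mask."""
    words = text.lower().split()[: max_len - 2]  # leave room for [CLS] and [SEP]
    unk = vocab["[UNK]"]
    ids = [vocab["[CLS]"]] + [vocab.get(w, unk) for w in words] + [vocab["[SEP]"]]
    mask = [1] * len(ids)                        # 1 = real token, 0 = padding
    padding = max_len - len(ids)
    return ids + [vocab["[PAD]"]] * padding, mask + [0] * padding</code></pre></div>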
</section>
<section id="cross-modal-fusion" class="level3">
<h3 class="anchored" data-anchor-id="cross-modal-fusion" id="cross-modal-fusion">Cross-Modal Fusion</h3>
<p>The critical challenge is combining visual and textual information:</p>
<div id="cross-attention" class="cell" data-execution_count="3">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> CrossAttention(nn.Module):</span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, dim):</span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.attention <span class="op">=</span> nn.MultiheadAttention(dim, num_heads<span class="op">=</span><span class="dv">8</span>)</span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, visual_features, text_features):</span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Use text as query, vision as key and value</span></span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a>        attended_features, _ <span class="op">=</span> <span class="va">self</span>.attention(</span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a>            query<span class="op">=</span>text_features,</span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a>            key<span class="op">=</span>visual_features,</span>
<span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a>            value<span class="op">=</span>visual_features</span>
<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb3-15"><a href="#cb3-15" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> attended_features</span></code></pre></div></div>
</div>
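<p>To see what the attention layer computes, the same idea can be written out for a single head in NumPy: each text token forms a query, scores every image patch, and takes a softmax-weighted average of the patch values. The projection matrices <code>W_q</code>, <code>W_k</code>, and <code>W_v</code> are hypothetical stand-ins for learned weights; <code>nn.MultiheadAttention</code> additionally uses several heads, an output projection, and batching.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def cross_attention(text_tokens, image_patches, W_q, W_k, W_v):
    """Single-head scaled dot-product cross-attention.

    text_tokens: (T, d_text) queries; image_patches: (P, d_img) keys and values.
    Returns (T, d) fused features and the (T, P) attention weights.
    """
    Q = text_tokens @ W_q                    # (T, d)
    K = image_patches @ W_k                  # (P, d)
    V = image_patches @ W_v                  # (P, d)
    scores = Q @ K.T / np.sqrt(Q.shape[-1])  # (T, P): each token scores every patch
    weights = softmax(scores, axis=-1)       # each row sums to 1 over the patches
    return weights @ V, weights</code></pre></div>
<p>Because keys and values are projected separately, the text and image features may have different widths as long as the projections map them to a shared dimension; <code>nn.MultiheadAttention</code> exposes the same flexibility through its <code>kdim</code> and <code>vdim</code> arguments.</p>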
</section>
</section>
<section id="training-strategies" class="level2">
<h2 class="anchored" data-anchor-id="training-strategies" id="training-strategies">Training Strategies</h2>
<section id="contrastive-learning" class="level3">
<h3 class="anchored" data-anchor-id="contrastive-learning" id="contrastive-learning">Contrastive Learning</h3>
<p>Many VLMs use contrastive learning to align visual and textual representations:</p>
<div id="contrastive-loss" class="cell" data-execution_count="4">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn.functional <span class="im">as</span> F</span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> contrastive_loss(image_features, text_features, temperature<span class="op">=</span><span class="fl">0.07</span>):</span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""CLIP-style contrastive loss"""</span></span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Normalize features</span></span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a>    image_features <span class="op">=</span> F.normalize(image_features, dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a>    text_features <span class="op">=</span> F.normalize(text_features, dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Compute similarity matrix</span></span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a>    similarity <span class="op">=</span> torch.matmul(image_features, text_features.T) <span class="op">/</span> temperature</span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Create labels (diagonal should be positive pairs)</span></span>
<span id="cb4-14"><a href="#cb4-14" aria-hidden="true" tabindex="-1"></a>    labels <span class="op">=</span> torch.arange(<span class="bu">len</span>(image_features), device<span class="op">=</span>image_features.device)</span>
<span id="cb4-15"><a href="#cb4-15" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-16"><a href="#cb4-16" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Compute loss</span></span>
<span id="cb4-17"><a href="#cb4-17" aria-hidden="true" tabindex="-1"></a>    loss_i2t <span class="op">=</span> F.cross_entropy(similarity, labels)</span>
<span id="cb4-18"><a href="#cb4-18" aria-hidden="true" tabindex="-1"></a>    loss_t2i <span class="op">=</span> F.cross_entropy(similarity.T, labels)</span>
<span id="cb4-19"><a href="#cb4-19" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-20"><a href="#cb4-20" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> (loss_i2t <span class="op">+</span> loss_t2i) <span class="op">/</span> <span class="dv">2</span></span></code></pre></div></div>
</div>
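<p>Once image and text embeddings are aligned by a contrastive objective like this, a CLIP-style model can classify images without task-specific training: embed a text prompt for each class (CLIP used prompts such as “a photo of a {label}.”) and pick the most similar one. A minimal NumPy sketch, assuming the embeddings have already been produced by a trained model; <code>zero_shot_classify</code> is a hypothetical name.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import numpy as np


def zero_shot_classify(image_embedding, class_text_embeddings, temperature=0.07):
    """Return the index of the closest class prompt and the softmax over classes.

    image_embedding: (D,); class_text_embeddings: (K, D), one row per class prompt.
    """
    img = image_embedding / np.linalg.norm(image_embedding)
    txt = class_text_embeddings / np.linalg.norm(class_text_embeddings, axis=1, keepdims=True)
    logits = txt @ img / temperature        # scaled cosine similarity per class
    probs = np.exp(logits - logits.max())   # stable softmax
    probs /= probs.sum()
    return int(np.argmax(probs)), probs</code></pre></div>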
</section>
<section id="multi-task-learning" class="level3">
<h3 class="anchored" data-anchor-id="multi-task-learning" id="multi-task-learning">Multi-Task Learning</h3>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Training Objectives
</div>
</div>
<div class="callout-body-container callout-body">
<p>VLMs often train on multiple objectives simultaneously:</p>
<ul>
<li>Image-text matching</li>
<li>Masked language modeling</li>
<li>Image captioning</li>
<li>Visual question answering</li>
</ul>
</div>
</div>
</section>
<section id="data-requirements" class="level3">
<h3 class="anchored" data-anchor-id="data-requirements" id="data-requirements">Data Requirements</h3>
<p>Training requires very large datasets of paired images and captions:</p>
<div id="vlm-dataset" class="cell" data-execution_count="5">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.utils.data <span class="im">import</span> Dataset</span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torchvision <span class="im">import</span> transforms</span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> PIL <span class="im">import</span> Image</span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> VLMDataset(Dataset):</span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, image_paths, captions):</span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.image_paths <span class="op">=</span> image_paths</span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.captions <span class="op">=</span> captions</span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.transform <span class="op">=</span> transforms.Compose([</span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a>            transforms.Resize((<span class="dv">224</span>, <span class="dv">224</span>)),</span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a>            transforms.ToTensor(),</span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a>            transforms.Normalize(mean<span class="op">=</span>[<span class="fl">0.485</span>, <span class="fl">0.456</span>, <span class="fl">0.406</span>],</span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a>                               std<span class="op">=</span>[<span class="fl">0.229</span>, <span class="fl">0.224</span>, <span class="fl">0.225</span>])</span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a>        ])</span>
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-16"><a href="#cb5-16" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__getitem__</span>(<span class="va">self</span>, idx):</span>
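<span id="cb5-17"><a href="#cb5-17" aria-hidden="true" tabindex="-1"></a>        image <span class="op">=</span> Image.<span class="bu">open</span>(<span class="va">self</span>.image_paths[idx]).convert(<span class="st">"RGB"</span>)  <span class="co"># force 3 channels to match Normalize</span></span>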
<span id="cb5-18"><a href="#cb5-18" aria-hidden="true" tabindex="-1"></a>        image <span class="op">=</span> <span class="va">self</span>.transform(image)</span>
<span id="cb5-19"><a href="#cb5-19" aria-hidden="true" tabindex="-1"></a>        caption <span class="op">=</span> <span class="va">self</span>.captions[idx]</span>
<span id="cb5-20"><a href="#cb5-20" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-21"><a href="#cb5-21" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {</span>
<span id="cb5-22"><a href="#cb5-22" aria-hidden="true" tabindex="-1"></a>            <span class="st">'image'</span>: image,</span>
<span id="cb5-23"><a href="#cb5-23" aria-hidden="true" tabindex="-1"></a>            <span class="st">'caption'</span>: caption,</span>
<span id="cb5-24"><a href="#cb5-24" aria-hidden="true" tabindex="-1"></a>            <span class="st">'image_id'</span>: idx</span>
<span id="cb5-25"><a href="#cb5-25" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb5-26"><a href="#cb5-26" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-27"><a href="#cb5-27" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__len__</span>(<span class="va">self</span>):</span>
<span id="cb5-28"><a href="#cb5-28" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="bu">len</span>(<span class="va">self</span>.image_paths)</span></code></pre></div></div>
</div>
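<p>The dataset returns a raw caption string under <code>'caption'</code>, while the training loop below reads token ids from <code>batch['images']</code> and <code>batch['captions']</code>. A minimal bridging <code>collate_fn</code> could look like the sketch below. The name <code>make_collate_fn</code> is hypothetical, and the sketch assumes a Hugging Face-style tokenizer that accepts <code>padding</code>, <code>truncation</code>, <code>max_length</code> and <code>return_tensors</code>. GPT-2's tokenizer has no pad token by default, so one is commonly set with <code>tokenizer.pad_token = tokenizer.eos_token</code>.</p>

```python
import torch


def make_collate_fn(tokenizer, max_length=32):
    """Hypothetical helper: batch VLMDataset items into the dict the training loop reads.

    Assumes `tokenizer` behaves like a Hugging Face tokenizer with a pad token set.
    """
    def collate(items):
        # Stack the already-normalized image tensors: [B, 3, 224, 224]
        images = torch.stack([item["image"] for item in items])
        # Tokenize and pad every caption in the batch to the same length
        tokens = tokenizer(
            [item["caption"] for item in items],
            padding="max_length",
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )
        return {
            "images": images,
            "captions": tokens["input_ids"],            # LongTensor [B, max_length]
            "attention_mask": tokens["attention_mask"],  # 1 = real token, 0 = padding
            "image_ids": torch.tensor([item["image_id"] for item in items]),
        }

    return collate
```

<p>It would be passed to the loader as <code>DataLoader(dataset, batch_size=32, collate_fn=make_collate_fn(tokenizer))</code>.</p>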
</section>
</section>
<section id="popular-vlm-architectures" class="level2">
<h2 class="anchored" data-anchor-id="popular-vlm-architectures" id="popular-vlm-architectures">Popular VLM Architectures</h2>
<section id="clip-contrastive-language-image-pre-training" class="level3">
<h3 class="anchored" data-anchor-id="clip-contrastive-language-image-pre-training" id="clip-contrastive-language-image-pre-training">CLIP (Contrastive Language-Image Pre-training)</h3>
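<p>CLIP learns visual concepts from natural-language supervision. An image encoder and a text encoder map images and captions into a shared embedding space, and training pulls matching image-caption pairs together while pushing apart every mismatched pairing in the batch. The cosine similarities are multiplied by a learnable scale, the inverse of a temperature initialized to 0.07 and stored in log space as <code>logit_scale</code>:</p>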
<div id="clip-model" class="cell" data-execution_count="6">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> CLIP(nn.Module):</span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, vision_model, text_model):</span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.vision_model <span class="op">=</span> vision_model</span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.text_model <span class="op">=</span> text_model</span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.logit_scale <span class="op">=</span> nn.Parameter(torch.ones([]) <span class="op">*</span> np.log(<span class="dv">1</span><span class="op">/</span><span class="fl">0.07</span>))</span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, image, text):</span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a>        image_features <span class="op">=</span> <span class="va">self</span>.vision_model(image)</span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a>        text_features <span class="op">=</span> <span class="va">self</span>.text_model(text)</span>
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Normalize features</span></span>
<span id="cb6-15"><a href="#cb6-15" aria-hidden="true" tabindex="-1"></a>        image_features <span class="op">=</span> image_features <span class="op">/</span> image_features.norm(dim<span class="op">=-</span><span class="dv">1</span>, keepdim<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb6-16"><a href="#cb6-16" aria-hidden="true" tabindex="-1"></a>        text_features <span class="op">=</span> text_features <span class="op">/</span> text_features.norm(dim<span class="op">=-</span><span class="dv">1</span>, keepdim<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb6-17"><a href="#cb6-17" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-18"><a href="#cb6-18" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Compute similarities</span></span>
<span id="cb6-19"><a href="#cb6-19" aria-hidden="true" tabindex="-1"></a>        logit_scale <span class="op">=</span> <span class="va">self</span>.logit_scale.exp()</span>
<span id="cb6-20"><a href="#cb6-20" aria-hidden="true" tabindex="-1"></a>        logits_per_image <span class="op">=</span> logit_scale <span class="op">*</span> image_features <span class="op">@</span> text_features.t()</span>
<span id="cb6-21"><a href="#cb6-21" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-22"><a href="#cb6-22" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> logits_per_image</span></code></pre></div></div>
</div>
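<p>The class above returns only the image-to-text logits. CLIP's training objective is symmetric: a cross-entropy loss over each row (each image must pick out its own caption) is averaged with the same loss over each column (each caption must pick out its own image), where the matching pair for row <em>i</em> sits at column <em>i</em>. A NumPy sketch of that loss, assuming <code>logits_per_image</code> is a square <code>[N, N]</code> array for a batch of <em>N</em> matched pairs; the helper names are illustrative:</p>

```python
import numpy as np


def _cross_entropy_diagonal(logits):
    """Mean cross-entropy where the correct class for row i is column i."""
    # Subtract each row's max for numerical stability before exponentiating
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -np.mean(np.diag(log_probs))


def clip_contrastive_loss(logits_per_image):
    """Symmetric contrastive loss: average of image->text and text->image terms."""
    loss_images = _cross_entropy_diagonal(logits_per_image)   # rows are images
    loss_texts = _cross_entropy_diagonal(logits_per_image.T)  # rows are captions
    return (loss_images + loss_texts) / 2
```

<p>With all-equal logits the loss equals log(<em>N</em>), the chance level, and it approaches zero as the diagonal dominates.</p>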
</section>
<section id="blip-bootstrapping-language-image-pre-training" class="level3">
<h3 class="anchored" data-anchor-id="blip-bootstrapping-language-image-pre-training" id="blip-bootstrapping-language-image-pre-training">BLIP (Bootstrapping Language-Image Pre-training)</h3>
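<p>The “bootstrapping” in BLIP’s name refers to CapFilt: a captioner generates synthetic captions for web images, a filter removes noisy image-text pairs, and the cleaned data is used for pre-training. The model itself is a multimodal mixture of encoder-decoder (MED) that operates in three modes, each trained with its own objective:</p>
<ul>
<li>Unimodal encoders for images and text, trained with an image-text contrastive (ITC) loss</li>
<li>Image-grounded text encoder (bidirectional self-attention plus cross-attention to the image), trained with an image-text matching (ITM) loss for understanding tasks</li>
<li>Image-grounded text decoder (causal self-attention), trained with a language modeling (LM) loss for generation</li>
</ul>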
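<p>In BLIP, the practical difference between the text encoder and the text decoder is the self-attention mask: the encoder lets every token attend to every other token, while the decoder restricts each token to itself and earlier positions so it can generate left to right. A NumPy sketch of the two masks; the function name is illustrative, and <code>True</code> marks positions a query token may attend to:</p>

```python
import numpy as np


def self_attention_mask(seq_len, causal):
    """Boolean [seq_len, seq_len] mask: entry [i, j] is True if token i may attend to token j."""
    if causal:
        # Decoder mode: lower-triangular, so token i only sees tokens 0..i
        return np.tril(np.ones((seq_len, seq_len), dtype=bool))
    # Encoder mode: full bidirectional attention
    return np.ones((seq_len, seq_len), dtype=bool)


encoder_mask = self_attention_mask(4, causal=False)
decoder_mask = self_attention_mask(4, causal=True)
```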
</section>
<section id="flamingo" class="level3">
<h3 class="anchored" data-anchor-id="flamingo" id="flamingo">Flamingo</h3>
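<p>Flamingo handles few-shot tasks by conditioning a frozen, pretrained language model on interleaved images and text: new cross-attention layers inserted between the frozen language-model layers let text tokens attend to visual features. The simplified layer below shows that pattern; it relies on a <code>CrossAttention</code> module that is not defined in this snippet:</p>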
<div id="flamingo-layer" class="cell" data-execution_count="7">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> FeedForward(nn.Module):</span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, dim, hidden_dim<span class="op">=</span><span class="va">None</span>):</span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a>        hidden_dim <span class="op">=</span> hidden_dim <span class="kw">or</span> <span class="dv">4</span> <span class="op">*</span> dim</span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.net <span class="op">=</span> nn.Sequential(</span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a>            nn.Linear(dim, hidden_dim),</span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a>            nn.GELU(),</span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a>            nn.Linear(hidden_dim, dim)</span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb7-10"><a href="#cb7-10" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-11"><a href="#cb7-11" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb7-12"><a href="#cb7-12" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.net(x)</span>
<span id="cb7-13"><a href="#cb7-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-14"><a href="#cb7-14" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> FlamingoLayer(nn.Module):</span>
<span id="cb7-15"><a href="#cb7-15" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, dim):</span>
<span id="cb7-16"><a href="#cb7-16" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb7-17"><a href="#cb7-17" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.cross_attention <span class="op">=</span> CrossAttention(dim)</span>
<span id="cb7-18"><a href="#cb7-18" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.feed_forward <span class="op">=</span> FeedForward(dim)</span>
<span id="cb7-19"><a href="#cb7-19" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-20"><a href="#cb7-20" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, text_features, visual_features):</span>
<span id="cb7-21"><a href="#cb7-21" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Cross-attention between text and vision</span></span>
<span id="cb7-22"><a href="#cb7-22" aria-hidden="true" tabindex="-1"></a>        attended <span class="op">=</span> <span class="va">self</span>.cross_attention(text_features, visual_features)</span>
<span id="cb7-23"><a href="#cb7-23" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-24"><a href="#cb7-24" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Add residual connection</span></span>
<span id="cb7-25"><a href="#cb7-25" aria-hidden="true" tabindex="-1"></a>        text_features <span class="op">=</span> text_features <span class="op">+</span> attended</span>
<span id="cb7-26"><a href="#cb7-26" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-27"><a href="#cb7-27" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Feed forward, also wrapped in a residual connection</span></span>
<span id="cb7-28"><a href="#cb7-28" aria-hidden="true" tabindex="-1"></a>        output <span class="op">=</span> text_features <span class="op">+</span> <span class="va">self</span>.feed_forward(text_features)</span>
<span id="cb7-29"><a href="#cb7-29" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-30"><a href="#cb7-30" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> output</span></code></pre></div></div>
</div>
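<p>In the Flamingo paper these cross-attention blocks are gated: each residual branch is multiplied by <code>tanh(alpha)</code>, where <code>alpha</code> is a learnable scalar initialized to zero, so at the start of training the block passes the frozen language model’s activations through unchanged. A simplified sketch of that idea, using <code>nn.MultiheadAttention</code> in place of the <code>CrossAttention</code> module above; the class name, head count and pre-norm placement are assumptions, not the paper’s exact configuration:</p>

```python
import torch
import torch.nn as nn


class GatedCrossAttentionBlock(nn.Module):
    """Illustrative tanh-gated cross-attention + feed-forward block."""

    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.attn_norm = nn.LayerNorm(dim)
        self.cross_attention = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.ff_norm = nn.LayerNorm(dim)
        self.feed_forward = nn.Sequential(
            nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)
        )
        # Gates start at 0, so tanh(0) = 0 and both residual branches are switched off
        self.attn_gate = nn.Parameter(torch.zeros(1))
        self.ff_gate = nn.Parameter(torch.zeros(1))

    def forward(self, text_features, visual_features):
        # Text tokens (queries) attend to visual tokens (keys and values)
        attended, _ = self.cross_attention(
            self.attn_norm(text_features), visual_features, visual_features
        )
        text_features = text_features + torch.tanh(self.attn_gate) * attended
        # Gated residual feed-forward
        ff_out = self.feed_forward(self.ff_norm(text_features))
        return text_features + torch.tanh(self.ff_gate) * ff_out
```

<p>Because the gates start closed, the pretrained language model’s behavior is preserved at initialization, and visual conditioning is blended in only as the gates are learned.</p>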
</section>
</section>
<section id="implementation-example" class="level2">
<h2 class="anchored" data-anchor-id="implementation-example" id="implementation-example">Implementation Example</h2>
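<p>Here’s a simplified prefix-style VLM for image captioning: a ResNet-50 image embedding is projected to GPT-2’s hidden size (768) and prepended to the caption’s token embeddings as a single visual token:</p>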
<div id="simple-vlm" class="cell" data-execution_count="8">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb8"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> transformers <span class="im">import</span> GPT2LMHeadModel, GPT2Tokenizer</span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torchvision.models <span class="im">import</span> resnet50</span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> SimpleVLM(nn.Module):</span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, vocab_size<span class="op">=</span><span class="dv">50257</span>, hidden_dim<span class="op">=</span><span class="dv">768</span>):</span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Vision encoder</span></span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.vision_encoder <span class="op">=</span> resnet50(weights<span class="op">=</span><span class="st">"DEFAULT"</span>)  <span class="co"># replaces the deprecated pretrained=True</span></span>
<span id="cb8-12"><a href="#cb8-12" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.vision_encoder.fc <span class="op">=</span> nn.Linear(<span class="dv">2048</span>, hidden_dim)</span>
<span id="cb8-13"><a href="#cb8-13" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-14"><a href="#cb8-14" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Language model</span></span>
<span id="cb8-15"><a href="#cb8-15" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.language_model <span class="op">=</span> GPT2LMHeadModel.from_pretrained(<span class="st">'gpt2'</span>)</span>
<span id="cb8-16"><a href="#cb8-16" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-17"><a href="#cb8-17" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Projection layer</span></span>
<span id="cb8-18"><a href="#cb8-18" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.visual_projection <span class="op">=</span> nn.Linear(hidden_dim, hidden_dim)</span>
<span id="cb8-19"><a href="#cb8-19" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-20"><a href="#cb8-20" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, images, input_ids, attention_mask<span class="op">=</span><span class="va">None</span>):</span>
<span id="cb8-21"><a href="#cb8-21" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Extract visual features</span></span>
<span id="cb8-22"><a href="#cb8-22" aria-hidden="true" tabindex="-1"></a>        visual_features <span class="op">=</span> <span class="va">self</span>.vision_encoder(images)</span>
<span id="cb8-23"><a href="#cb8-23" aria-hidden="true" tabindex="-1"></a>        visual_features <span class="op">=</span> <span class="va">self</span>.visual_projection(visual_features)</span>
<span id="cb8-24"><a href="#cb8-24" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-25"><a href="#cb8-25" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Add visual features as prefix to text</span></span>
<span id="cb8-26"><a href="#cb8-26" aria-hidden="true" tabindex="-1"></a>        batch_size <span class="op">=</span> visual_features.size(<span class="dv">0</span>)</span>
<span id="cb8-27"><a href="#cb8-27" aria-hidden="true" tabindex="-1"></a>        visual_tokens <span class="op">=</span> visual_features.unsqueeze(<span class="dv">1</span>)  <span class="co"># [B, 1, H]</span></span>
<span id="cb8-28"><a href="#cb8-28" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-29"><a href="#cb8-29" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Get text embeddings</span></span>
<span id="cb8-30"><a href="#cb8-30" aria-hidden="true" tabindex="-1"></a>        text_embeddings <span class="op">=</span> <span class="va">self</span>.language_model.transformer.wte(input_ids)</span>
<span id="cb8-31"><a href="#cb8-31" aria-hidden="true" tabindex="-1"></a>        </span>
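<span id="cb8-32"><a href="#cb8-32" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Concatenate visual and text embeddings: [B, 1 + T, H]</span></span>
<span id="cb8-33"><a href="#cb8-33" aria-hidden="true" tabindex="-1"></a>        combined_embeddings <span class="op">=</span> torch.cat([visual_tokens, text_embeddings], dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb8-34"><a href="#cb8-34" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> attention_mask <span class="kw">is</span> <span class="kw">not</span> <span class="va">None</span>:  <span class="co"># extend the [B, T] mask to cover the visual token</span></span>
<span id="cb8-35"><a href="#cb8-35" aria-hidden="true" tabindex="-1"></a>            attention_mask <span class="op">=</span> torch.cat([attention_mask.new_ones((batch_size, <span class="dv">1</span>)), attention_mask], dim<span class="op">=</span><span class="dv">1</span>)</span>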
<span id="cb8-36"><a href="#cb8-36" aria-hidden="true" tabindex="-1"></a>        outputs <span class="op">=</span> <span class="va">self</span>.language_model(</span>
<span id="cb8-37"><a href="#cb8-37" aria-hidden="true" tabindex="-1"></a>            inputs_embeds<span class="op">=</span>combined_embeddings,</span>
<span id="cb8-38"><a href="#cb8-38" aria-hidden="true" tabindex="-1"></a>            attention_mask<span class="op">=</span>attention_mask</span>
<span id="cb8-39"><a href="#cb8-39" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb8-40"><a href="#cb8-40" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-41"><a href="#cb8-41" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> outputs</span></code></pre></div></div>
</div>
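<p>Prepending the visual token turns a length-<em>T</em> text sequence into a length-(<em>T</em> + 1) input, so any attention mask passed alongside it must gain one position too. A stand-alone sketch of that bookkeeping, with random tensors standing in for the ResNet and GPT-2 outputs; the helper name <code>build_prefix_inputs</code> is hypothetical:</p>

```python
import torch


def build_prefix_inputs(visual_features, text_embeddings, attention_mask=None):
    """Prepend one visual token to the text embeddings and extend the mask to match.

    visual_features: [B, H], text_embeddings: [B, T, H], attention_mask: [B, T] or None
    """
    visual_tokens = visual_features.unsqueeze(1)                    # [B, 1, H]
    combined = torch.cat([visual_tokens, text_embeddings], dim=1)   # [B, T + 1, H]
    if attention_mask is not None:
        # The visual token is never padding, so it always gets a 1
        prefix_mask = attention_mask.new_ones((attention_mask.size(0), 1))
        attention_mask = torch.cat([prefix_mask, attention_mask], dim=1)  # [B, T + 1]
    return combined, attention_mask
```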
<section id="training-loop" class="level3">
<h3 class="anchored" data-anchor-id="training-loop" id="training-loop">Training Loop</h3>
<div id="training-loop" class="cell" data-execution_count="9">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb9"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> train_vlm(model, dataloader, optimizer, device):</span>
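<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Train for one epoch. Assumes each batch is a dict with 'images' [B, 3, 224, 224] and 'captions' (token ids, LongTensor [B, L])."""</span></span>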
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a>    model.train()</span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a>    total_loss <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> batch <span class="kw">in</span> dataloader:</span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a>        images <span class="op">=</span> batch[<span class="st">'images'</span>].to(device)</span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a>        captions <span class="op">=</span> batch[<span class="st">'captions'</span>].to(device)</span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Forward pass</span></span>
<span id="cb9-11"><a href="#cb9-11" aria-hidden="true" tabindex="-1"></a>        outputs <span class="op">=</span> model(images, captions[:, :<span class="op">-</span><span class="dv">1</span>])</span>
<span id="cb9-12"><a href="#cb9-12" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-13"><a href="#cb9-13" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Compute loss; logits[:, 0] is predicted at the visual token, so drop it to align with captions[:, 1:]</span></span>
<span id="cb9-14"><a href="#cb9-14" aria-hidden="true" tabindex="-1"></a>        loss <span class="op">=</span> nn.CrossEntropyLoss()(</span>
<span id="cb9-15"><a href="#cb9-15" aria-hidden="true" tabindex="-1"></a>            outputs.logits[:, <span class="dv">1</span>:].reshape(<span class="op">-</span><span class="dv">1</span>, outputs.logits.size(<span class="op">-</span><span class="dv">1</span>)),</span>
<span id="cb9-16"><a href="#cb9-16" aria-hidden="true" tabindex="-1"></a>            captions[:, <span class="dv">1</span>:].reshape(<span class="op">-</span><span class="dv">1</span>)</span>
<span id="cb9-17"><a href="#cb9-17" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb9-18"><a href="#cb9-18" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-19"><a href="#cb9-19" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Backward pass</span></span>
<span id="cb9-20"><a href="#cb9-20" aria-hidden="true" tabindex="-1"></a>        optimizer.zero_grad()</span>
<span id="cb9-21"><a href="#cb9-21" aria-hidden="true" tabindex="-1"></a>        loss.backward()</span>
<span id="cb9-22"><a href="#cb9-22" aria-hidden="true" tabindex="-1"></a>        optimizer.step()</span>
<span id="cb9-23"><a href="#cb9-23" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-24"><a href="#cb9-24" aria-hidden="true" tabindex="-1"></a>        total_loss <span class="op">+=</span> loss.item()</span>
<span id="cb9-25"><a href="#cb9-25" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-26"><a href="#cb9-26" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> total_loss <span class="op">/</span> <span class="bu">len</span>(dataloader)</span></code></pre></div></div>
</div>
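<p>If captions are padded to a common length, the loop above also trains on the padding positions. <code>nn.CrossEntropyLoss</code> skips any target equal to its <code>ignore_index</code> (default <code>-100</code>), so one option is to overwrite padded target positions with that value. A sketch, assuming the batch also carries an <code>attention_mask</code> with 1 for real tokens and 0 for padding; the function name <code>caption_loss</code> is illustrative:</p>

```python
import torch
import torch.nn.functional as F


def caption_loss(logits, captions, attention_mask):
    """Next-token loss for a prefix VLM, ignoring padded caption positions.

    logits: [B, L, V] from model(images, captions[:, :-1]); position 0 is the visual token.
    captions, attention_mask: [B, L]
    """
    # Drop the visual-token position so logits line up with captions[:, 1:]
    logits = logits[:, 1:, :]
    targets = captions[:, 1:].clone()
    targets[attention_mask[:, 1:] == 0] = -100  # default ignore_index of cross_entropy
    return F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
```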
</section>
</section>
<section id="evaluation-metrics" class="level2">
<h2 class="anchored" data-anchor-id="evaluation-metrics" id="evaluation-metrics">Evaluation Metrics</h2>
<p>VLMs are evaluated using various metrics depending on the task:</p>
<section id="image-captioning-metrics" class="level3">
<h3 class="anchored" data-anchor-id="image-captioning-metrics" id="image-captioning-metrics">Image Captioning Metrics</h3>
<table class="caption-top table">
<thead>
<tr class="header">
<th>Metric</th>
<th>Description</th>
<th>Range</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><strong>BLEU</strong></td>
<td>Clipped n-gram precision (n = 1 to 4) against reference captions, with a brevity penalty</td>
<td>0-1</td>
</tr>
<tr class="even">
<td><strong>ROUGE-L</strong></td>
<td>Longest-common-subsequence F-measure against reference captions</td>
<td>0-1</td>
</tr>
<tr class="odd">
<td><strong>CIDEr</strong></td>
<td>TF-IDF-weighted n-gram agreement with the consensus of multiple reference captions</td>
<td>0-10</td>
</tr>
<tr class="even">
<td><strong>SPICE</strong></td>
<td>F-score over semantic propositions (objects, attributes, relations) parsed into scene graphs</td>
<td>0-1</td>
</tr>
</tbody>
</table>
<div id="bleu-score" class="cell" data-execution_count="10">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb10"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> compute_bleu_score(predictions, references):</span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Compute BLEU score for image captioning"""</span></span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a>    <span class="im">from</span> nltk.translate.bleu_score <span class="im">import</span> corpus_bleu, SmoothingFunction</span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Whitespace-tokenize; references holds a list of reference captions per prediction</span></span>
<span id="cb10-6"><a href="#cb10-6" aria-hidden="true" tabindex="-1"></a>    pred_tokens <span class="op">=</span> [pred.split() <span class="cf">for</span> pred <span class="kw">in</span> predictions]</span>
<span id="cb10-7"><a href="#cb10-7" aria-hidden="true" tabindex="-1"></a>    ref_tokens <span class="op">=</span> [[ref.split() <span class="cf">for</span> ref <span class="kw">in</span> refs] <span class="cf">for</span> refs <span class="kw">in</span> references]</span>
<span id="cb10-8"><a href="#cb10-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-9"><a href="#cb10-9" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Smoothing avoids a score of 0 when short captions share no 4-grams with the references</span></span>
<span id="cb10-10"><a href="#cb10-10" aria-hidden="true" tabindex="-1"></a>    bleu_score <span class="op">=</span> corpus_bleu(ref_tokens, pred_tokens, smoothing_function<span class="op">=</span>SmoothingFunction().method1)</span>
<span id="cb10-11"><a href="#cb10-11" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> bleu_score</span></code></pre></div></div>
</div>
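<p>At the core of BLEU is <em>clipped</em> n-gram precision: each candidate n-gram counts as a match at most as many times as it appears in any single reference, so repeating a correct word cannot inflate the score. BLEU then combines these precisions for n = 1 to 4 with a geometric mean and multiplies by a brevity penalty. A pure-Python sketch of the clipping step alone (the function name is illustrative, and this is not a full BLEU implementation):</p>

```python
from collections import Counter


def clipped_ngram_precision(candidate, references, n):
    """Modified n-gram precision of one tokenized candidate against tokenized references."""
    def ngrams(tokens):
        return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

    cand_counts = ngrams(candidate)
    if not cand_counts:
        return 0.0
    # For each n-gram, the most times it occurs in any single reference
    max_ref_counts = Counter()
    for ref in references:
        for gram, count in ngrams(ref).items():
            max_ref_counts[gram] = max(max_ref_counts[gram], count)
    # Clip each candidate count by that maximum, then divide by the candidate total
    clipped = sum(min(count, max_ref_counts[gram]) for gram, count in cand_counts.items())
    return clipped / sum(cand_counts.values())
```

<p>For the candidate “the the the the the the the” and the reference “the cat is on the mat”, unigram precision is 2/7 rather than 7/7. <code>corpus_bleu</code> sums these clipped counts over the whole test set before dividing, rather than averaging per-sentence scores.</p>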
</section>
<section id="visual-question-answering" class="level3">
<h3 class="anchored" data-anchor-id="visual-question-answering" id="visual-question-answering">Visual Question Answering</h3>
<ul>
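<li><strong>Accuracy</strong>: Exact match between the predicted and ground-truth answer. The VQA benchmark uses a soft version based on its ten human answers per question, min(#annotators who gave that answer / 3, 1), averaged in the official evaluation over subsets of nine annotators</li>
<li><strong>F1 Score</strong>: Harmonic mean of token-level precision and recall between predicted and reference answers, which gives partial credit to free-form answers</li>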
</ul>
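<p>Both metrics are short enough to sketch directly. The version below is simplified: it omits the answer normalization (punctuation, articles, number words) and the averaging over annotator subsets that the official VQA evaluation applies, and the function names are illustrative:</p>

```python
from collections import Counter


def vqa_soft_accuracy(prediction, human_answers):
    """Simplified VQA accuracy: min(#annotators who gave this answer / 3, 1)."""
    matches = sum(answer == prediction for answer in human_answers)
    return min(matches / 3, 1.0)


def token_f1(prediction, reference):
    """Harmonic mean of token precision and recall between two answer strings."""
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    # Multiset intersection counts each shared token at most min(count) times
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)
```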
</section>
<section id="image-text-retrieval" class="level3">
<h3 class="anchored" data-anchor-id="image-text-retrieval" id="image-text-retrieval">Image-Text Retrieval</h3>
<ul>
<li><strong>Recall@K</strong>: Fraction of queries for which a correct match appears among the top-K retrieved results</li>
<li><strong>Mean Reciprocal Rank (MRR)</strong>: Mean over queries of 1/rank, where rank is the position of the first correct result</li>
</ul>
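<p>Given a similarity matrix between queries and candidates, both metrics follow from the rank of each query’s correct match. A pure-Python sketch, assuming exactly one correct candidate per query located at the same index (as in a batch of paired images and captions); datasets with several captions per image would instead use the best-ranked correct caption. The function name is illustrative:</p>

```python
def retrieval_metrics(similarity, ks=(1, 5, 10)):
    """Recall@K and MRR for a square similarity matrix where query i's match is candidate i.

    similarity[i][j] is the score between query i (e.g. an image) and candidate j (e.g. a caption).
    """
    ranks = []
    for i, row in enumerate(similarity):
        # Rank = 1 + number of candidates scored strictly higher (ties count in the query's favor)
        ranks.append(1 + sum(score > row[i] for score in row))
    n = len(ranks)
    recall = {k: sum(r <= k for r in ranks) / n for k in ks}
    mrr = sum(1 / r for r in ranks) / n
    return recall, mrr
```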
</section>
</section>
<section id="applications-and-use-cases" class="level2">
<h2 class="anchored" data-anchor-id="applications-and-use-cases" id="applications-and-use-cases">Applications and Use Cases</h2>
<section id="content-generation" class="level3">
<h3 class="anchored" data-anchor-id="content-generation" id="content-generation">Content Generation</h3>
<div id="caption-generation" class="cell" data-execution_count="11">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb11"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> generate_caption(model, image, tokenizer, max_length<span class="op">=</span><span class="dv">50</span>):</span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Generate a caption with beam search; assumes model provides a Hugging Face-style generate()"""</span></span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a>    model.<span class="bu">eval</span>()</span>
<span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> torch.no_grad():</span>
<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Process image</span></span>
<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a>        image_tensor <span class="op">=</span> preprocess_image(image).to(<span class="bu">next</span>(model.parameters()).device)</span>
<span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Generate caption</span></span>
<span id="cb11-9"><a href="#cb11-9" aria-hidden="true" tabindex="-1"></a>        generated_ids <span class="op">=</span> model.generate(</span>
<span id="cb11-10"><a href="#cb11-10" aria-hidden="true" tabindex="-1"></a>            image_tensor,</span>
<span id="cb11-11"><a href="#cb11-11" aria-hidden="true" tabindex="-1"></a>            max_length<span class="op">=</span>max_length,</span>
<span id="cb11-12"><a href="#cb11-12" aria-hidden="true" tabindex="-1"></a>            num_beams<span class="op">=</span><span class="dv">5</span>,</span>
<span id="cb11-13"><a href="#cb11-13" aria-hidden="true" tabindex="-1"></a>            <span class="co"># temperature only affects sampling (do_sample=True), so it is omitted for beam search</span></span>
<span id="cb11-14"><a href="#cb11-14" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb11-15"><a href="#cb11-15" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-16"><a href="#cb11-16" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Decode caption</span></span>
<span id="cb11-17"><a href="#cb11-17" aria-hidden="true" tabindex="-1"></a>        caption <span class="op">=</span> tokenizer.decode(generated_ids[<span class="dv">0</span>], skip_special_tokens<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb11-18"><a href="#cb11-18" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> caption</span>
<span id="cb11-19"><a href="#cb11-19" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-20"><a href="#cb11-20" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> preprocess_image(image):</span>
<span id="cb11-21"><a href="#cb11-21" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Preprocess image for model input"""</span></span>
<span id="cb11-22"><a href="#cb11-22" aria-hidden="true" tabindex="-1"></a>    transform <span class="op">=</span> transforms.Compose([</span>
<span id="cb11-23"><a href="#cb11-23" aria-hidden="true" tabindex="-1"></a>        transforms.Resize((<span class="dv">224</span>, <span class="dv">224</span>)),</span>
<span id="cb11-24"><a href="#cb11-24" aria-hidden="true" tabindex="-1"></a>        transforms.ToTensor(),</span>
<span id="cb11-25"><a href="#cb11-25" aria-hidden="true" tabindex="-1"></a>        transforms.Normalize(mean<span class="op">=</span>[<span class="fl">0.485</span>, <span class="fl">0.456</span>, <span class="fl">0.406</span>],</span>
<span id="cb11-26"><a href="#cb11-26" aria-hidden="true" tabindex="-1"></a>                           std<span class="op">=</span>[<span class="fl">0.229</span>, <span class="fl">0.224</span>, <span class="fl">0.225</span>])</span>
<span id="cb11-27"><a href="#cb11-27" aria-hidden="true" tabindex="-1"></a>    ])</span>
<span id="cb11-28"><a href="#cb11-28" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> transform(image.convert(<span class="st">"RGB"</span>)).unsqueeze(<span class="dv">0</span>)  <span class="co"># add a batch dimension</span></span></code></pre></div></div>
</div>
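<p>The <code>SimpleVLM</code> class above defines no <code>generate()</code> method, so <code>generate_caption</code> cannot call it as written. A minimal greedy decoding loop that works with <code>SimpleVLM</code>’s <code>forward(images, input_ids)</code> signature might look like the sketch below. The function name and arguments are hypothetical; <code>start_ids</code> would typically hold the tokenizer’s <code>bos_token_id</code>. The loop re-runs the full sequence at every step instead of caching keys and values, so it is slow for long captions.</p>

```python
import torch


@torch.no_grad()
def greedy_decode(model, image_tensor, start_ids, max_new_tokens=20, eos_token_id=None):
    """Illustrative greedy decoding for a prefix model called as model(images, input_ids).

    Assumes the model returns an object with .logits of shape [B, seq_len, vocab].
    start_ids: LongTensor [1, T0] holding the initial token(s).
    """
    model.eval()
    input_ids = start_ids
    for _ in range(max_new_tokens):
        logits = model(image_tensor, input_ids).logits
        # The last position predicts the next token; take the most likely one
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # [1, 1]
        input_ids = torch.cat([input_ids, next_id], dim=1)
        if eos_token_id is not None and next_id.item() == eos_token_id:
            break
    return input_ids
```

<p>The returned ids can then be turned into text with <code>tokenizer.decode(ids[0], skip_special_tokens=True)</code>.</p>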
</section>
<section id="document-understanding" class="level3">
<h3 class="anchored" data-anchor-id="document-understanding" id="document-understanding">Document Understanding</h3>
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Important</span>Key Applications
</div>
</div>
<div class="callout-body-container callout-body">
<p>VLMs are well suited to documents that combine text with visual structure:</p>
<ul>
<li>Form understanding</li>
<li>Chart and graph interpretation</li>
<li>Layout analysis</li>
<li>OCR with context</li>
</ul>
</div>
</div>
</section>
<section id="other-applications" class="level3">
<h3 class="anchored" data-anchor-id="other-applications" id="other-applications">Other Applications</h3>
<ul>
<li><strong>Accessibility</strong>: Image description for visually impaired users</li>
<li><strong>E-commerce</strong>: Product description generation and visual search</li>
<li><strong>Navigation</strong>: Scene understanding and object recognition</li>
</ul>
</section>
</section>
<section id="challenges-and-limitations" class="level2">
<h2 class="anchored" data-anchor-id="challenges-and-limitations" id="challenges-and-limitations">Challenges and Limitations</h2>
<section id="computational-requirements" class="level3">
<h3 class="anchored" data-anchor-id="computational-requirements" id="computational-requirements">Computational Requirements</h3>
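<p>VLMs require significant computational resources. A back-of-the-envelope estimate of GPU memory adds up the input batch, the model weights, and the activations:</p>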
<div id="memory-estimation" class="cell" data-execution_count="12">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb12"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><a href="#cb12-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> estimate_memory_usage(batch_size, image_size, model_params):</span>
<span id="cb12-2"><a href="#cb12-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Rough GPU memory estimate in GB (fp32); excludes gradients and optimizer state"""</span></span>
<span id="cb12-3"><a href="#cb12-3" aria-hidden="true" tabindex="-1"></a>    image_memory <span class="op">=</span> batch_size <span class="op">*</span> <span class="dv">3</span> <span class="op">*</span> image_size <span class="op">*</span> image_size <span class="op">*</span> <span class="dv">4</span>  <span class="co"># fp32 input pixels, in bytes</span></span>
<span id="cb12-4"><a href="#cb12-4" aria-hidden="true" tabindex="-1"></a>    model_memory <span class="op">=</span> model_params <span class="op">*</span> <span class="dv">4</span>  <span class="co"># fp32 weights: 4 bytes per parameter</span></span>
<span id="cb12-5"><a href="#cb12-5" aria-hidden="true" tabindex="-1"></a>    activation_memory <span class="op">=</span> batch_size <span class="op">*</span> model_params <span class="op">*</span> <span class="fl">0.3</span>  <span class="co"># crude heuristic: ~0.3 bytes per parameter per sample</span></span>
<span id="cb12-6"><a href="#cb12-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-7"><a href="#cb12-7" aria-hidden="true" tabindex="-1"></a>    total_gb <span class="op">=</span> (image_memory <span class="op">+</span> model_memory <span class="op">+</span> activation_memory) <span class="op">/</span> (<span class="dv">1024</span><span class="op">**</span><span class="dv">3</span>)</span>
<span id="cb12-8"><a href="#cb12-8" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> total_gb</span>
<span id="cb12-9"><a href="#cb12-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-10"><a href="#cb12-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Example usage</span></span>
<span id="cb12-11"><a href="#cb12-11" aria-hidden="true" tabindex="-1"></a>memory_gb <span class="op">=</span> estimate_memory_usage(</span>
<span id="cb12-12"><a href="#cb12-12" aria-hidden="true" tabindex="-1"></a>    batch_size<span class="op">=</span><span class="dv">32</span>, </span>
<span id="cb12-13"><a href="#cb12-13" aria-hidden="true" tabindex="-1"></a>    image_size<span class="op">=</span><span class="dv">224</span>, </span>
<span id="cb12-14"><a href="#cb12-14" aria-hidden="true" tabindex="-1"></a>    model_params<span class="op">=</span><span class="dv">175_000_000</span>  <span class="co"># 175M parameters</span></span>
<span id="cb12-15"><a href="#cb12-15" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb12-16"><a href="#cb12-16" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Estimated memory usage: </span><span class="sc">{</span>memory_gb<span class="sc">:.2f}</span><span class="ss"> GB"</span>)</span></code></pre></div></div>
</div>
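<p>The function above is a rough estimate for a forward pass. Training needs considerably more memory, because gradients and optimizer state are also stored for every parameter. The sketch below tallies those per-parameter costs under an assumed setup of plain fp32 training with Adam, which keeps two moment estimates per parameter; mixed precision, 8-bit optimizers, or optimizer sharding would change the byte counts passed in.</p>

```python
def training_state_bytes(num_params, bytes_per_weight=4, bytes_per_grad=4,
                         optimizer_states=2, bytes_per_state=4):
    """Bytes for weights, gradients, and optimizer state (activations excluded).

    Defaults assume fp32 training with Adam, which keeps two moment
    estimates (m and v) per parameter.
    """
    weights = num_params * bytes_per_weight
    grads = num_params * bytes_per_grad
    optimizer = num_params * optimizer_states * bytes_per_state
    return {
        "weights": weights,
        "gradients": grads,
        "optimizer": optimizer,
        "total": weights + grads + optimizer,
    }


breakdown = training_state_bytes(175_000_000)
for name, num_bytes in breakdown.items():
    print(f"{name:10s} {num_bytes / 1024**3:6.2f} GiB")
```

<p>With the defaults this comes to 16 bytes per parameter, so a 175M-parameter model needs about 2.8 GB (2.61 GiB) for weights, gradients, and Adam state before any activations are counted.</p>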
</section>
<section id="bias-and-fairness" class="level3">
<h3 class="anchored" data-anchor-id="bias-and-fairness" id="bias-and-fairness">Bias and Fairness</h3>
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Warning</span>Bias Concerns
</div>
</div>
<div class="callout-body-container callout-body">
<p>VLMs can perpetuate biases present in training data:</p>
<ul>
<li>Gender and racial stereotypes</li>
<li>Cultural biases in image interpretation</li>
<li>Socioeconomic biases in scene understanding</li>
</ul>
</div>
</div>
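<p>A common first step in measuring these biases is disaggregated evaluation: compute the same metric separately for each group and compare the results. The sketch below assumes each evaluation record carries a group label (a hypothetical <code>group</code> field) and a boolean <code>correct</code> flag; the toy records are illustrative only, and how groups are defined and annotated is itself a sensitive design decision.</p>

```python
from collections import defaultdict


def accuracy_by_group(records):
    """Compute accuracy separately for each group label."""
    totals = defaultdict(int)
    hits = defaultdict(int)
    for rec in records:
        totals[rec["group"]] += 1
        hits[rec["group"]] += int(rec["correct"])
    return {g: hits[g] / totals[g] for g in totals}


def max_accuracy_gap(per_group):
    """Largest accuracy difference between any two groups."""
    return max(per_group.values()) - min(per_group.values())


# Toy records for illustration only
records = [
    {"group": "A", "correct": True},
    {"group": "A", "correct": True},
    {"group": "B", "correct": True},
    {"group": "B", "correct": False},
]
per_group = accuracy_by_group(records)
print(per_group)                    # {'A': 1.0, 'B': 0.5}
print(max_accuracy_gap(per_group))  # 0.5
```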
</section>
<section id="hallucination-detection" class="level3">
<h3 class="anchored" data-anchor-id="hallucination-detection" id="hallucination-detection">Hallucination Detection</h3>
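<p>Models may generate plausible but incorrect descriptions, for example by mentioning objects that are not in the image. A simple check compares the objects mentioned in a caption against a list of objects known to be present (for example, from annotations or an object detector):</p>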
<div id="hallucination-detection" class="cell" data-execution_count="13">
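<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb13"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> re</span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Illustrative stopword list; in practice, use noun-phrase extraction instead</span></span>
<span id="cb13-4"><a href="#cb13-4" aria-hidden="true" tabindex="-1"></a>STOPWORDS <span class="op">=</span> {<span class="st">"a"</span>, <span class="st">"an"</span>, <span class="st">"the"</span>, <span class="st">"is"</span>, <span class="st">"are"</span>, <span class="st">"of"</span>, <span class="st">"on"</span>, <span class="st">"in"</span>, <span class="st">"at"</span>, <span class="st">"with"</span>, <span class="st">"and"</span>, <span class="st">"to"</span>, <span class="st">"near"</span>, <span class="st">"some"</span>}</span>
<span id="cb13-5"><a href="#cb13-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-6"><a href="#cb13-6" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> detect_hallucination(caption, image_objects):</span>
<span id="cb13-7"><a href="#cb13-7" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Return objects mentioned in the caption but absent from image_objects"""</span></span>
<span id="cb13-8"><a href="#cb13-8" aria-hidden="true" tabindex="-1"></a>    mentioned_objects <span class="op">=</span> extract_objects_from_caption(caption)</span>
<span id="cb13-9"><a href="#cb13-9" aria-hidden="true" tabindex="-1"></a>    present <span class="op">=</span> {obj.lower() <span class="cf">for</span> obj <span class="kw">in</span> image_objects}  <span class="co"># set for fast, case-insensitive lookup</span></span>
<span id="cb13-10"><a href="#cb13-10" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-11"><a href="#cb13-11" aria-hidden="true" tabindex="-1"></a>    hallucinated_objects <span class="op">=</span> []</span>
<span id="cb13-12"><a href="#cb13-12" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> obj <span class="kw">in</span> mentioned_objects:</span>
<span id="cb13-13"><a href="#cb13-13" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> obj <span class="kw">not</span> <span class="kw">in</span> present:</span>
<span id="cb13-14"><a href="#cb13-14" aria-hidden="true" tabindex="-1"></a>            hallucinated_objects.append(obj)</span>
<span id="cb13-15"><a href="#cb13-15" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-16"><a href="#cb13-16" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> hallucinated_objects</span>
<span id="cb13-17"><a href="#cb13-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-18"><a href="#cb13-18" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> extract_objects_from_caption(caption):</span>
<span id="cb13-19"><a href="#cb13-19" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Extract candidate object words from a caption"""</span></span>
<span id="cb13-20"><a href="#cb13-20" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Simplified implementation: keep every word that is not a stopword.</span></span>
<span id="cb13-21"><a href="#cb13-21" aria-hidden="true" tabindex="-1"></a>    <span class="co"># In practice, use NLP techniques (e.g. noun-phrase extraction) to keep only objects</span></span>
<span id="cb13-22"><a href="#cb13-22" aria-hidden="true" tabindex="-1"></a>    words <span class="op">=</span> re.findall(<span class="vs">r'</span><span class="dv">\b</span><span class="pp">[a-z]</span><span class="op">+</span><span class="dv">\b</span><span class="vs">'</span>, caption.lower())</span>
<span id="cb13-23"><a href="#cb13-23" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> [w <span class="cf">for</span> w <span class="kw">in</span> words <span class="cf">if</span> w <span class="kw">not</span> <span class="kw">in</span> STOPWORDS]</span></code></pre></div></div>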
</div>
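<p>Word-level matching like the above still flags many words that are not objects at all. Object-hallucination metrics such as CHAIR instead restrict matching to a fixed object vocabulary and report two rates: the fraction of mentioned objects that are hallucinated (CHAIR<sub>i</sub>) and the fraction of captions containing at least one hallucinated object (CHAIR<sub>s</sub>). The sketch below is a simplified version of that idea; the five-word vocabulary is a placeholder, and the synonym and plural handling of the original metric is omitted.</p>

```python
import re


def chair_scores(captions, image_objects, vocabulary):
    """Simplified CHAIR_i and CHAIR_s.

    captions[k] describes image k, image_objects[k] is the set of objects actually
    in image k, and only words in `vocabulary` count as object mentions.
    """
    mentioned_total = 0
    hallucinated_total = 0
    captions_with_hallucination = 0
    for caption, present in zip(captions, image_objects):
        words = set(re.findall(r"[a-z]+", caption.lower()))
        mentioned = words.intersection(vocabulary)      # object mentions only
        hallucinated = mentioned.difference(present)    # mentioned but not present
        mentioned_total += len(mentioned)
        hallucinated_total += len(hallucinated)
        if hallucinated:
            captions_with_hallucination += 1
    chair_i = hallucinated_total / mentioned_total if mentioned_total else 0.0
    chair_s = captions_with_hallucination / len(captions) if captions else 0.0
    return chair_i, chair_s


vocab = {"dog", "cat", "frisbee", "bench", "bicycle"}   # placeholder vocabulary
captions = ["A dog catches a frisbee", "A cat sits on a bench next to a bicycle"]
objects = [{"dog", "frisbee"}, {"cat", "bench"}]
print(chair_scores(captions, objects, vocab))  # (0.2, 0.5)
```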
</section>
</section>
<section id="future-directions" class="level2">
<h2 class="anchored" data-anchor-id="future-directions" id="future-directions">Future Directions</h2>
<section id="advanced-capabilities" class="level3">
<h3 class="anchored" data-anchor-id="advanced-capabilities" id="advanced-capabilities">Advanced Capabilities</h3>
<p>Ongoing research aims to give VLMs more sophisticated reasoning abilities:</p>
<ul>
<li><strong>Temporal understanding</strong> in videos</li>
<li><strong>Spatial reasoning</strong> in 3D scenes</li>
<li><strong>Causal reasoning</strong> from visual evidence</li>
</ul>
</section>
<section id="efficiency-improvements" class="level3">
<h3 class="anchored" data-anchor-id="efficiency-improvements" id="efficiency-improvements">Efficiency Improvements</h3>
<p>Research focuses on making VLMs more efficient:</p>
<ul>
<li>Model compression and pruning</li>
<li>Knowledge distillation</li>
<li>Efficient attention mechanisms</li>
</ul>
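<p>Magnitude pruning is the simplest of these techniques to illustrate: the weights with the smallest absolute values are set to zero. The NumPy sketch below prunes a single random weight matrix to 50% sparsity. In practice pruning is applied layer by layer (PyTorch ships utilities for this in <code>torch.nn.utils.prune</code>) and is usually followed by further fine-tuning to recover accuracy; unstructured sparsity only reduces compute on kernels or hardware that exploit it.</p>

```python
import numpy as np


def magnitude_prune(weights, sparsity):
    """Zero out the `sparsity` fraction of entries with the smallest magnitude.

    Returns the pruned copy and the boolean mask of kept weights.
    """
    flat = np.abs(weights).ravel()
    k = int(sparsity * flat.size)            # number of weights to remove
    mask = np.ones(flat.size, dtype=bool)
    if k:
        # indices of the k smallest-magnitude entries (order among them is arbitrary)
        prune_idx = np.argpartition(flat, k - 1)[:k]
        mask[prune_idx] = False
    mask = mask.reshape(weights.shape)
    return weights * mask, mask


rng = np.random.default_rng(0)
w = rng.normal(size=(4, 8))
pruned, mask = magnitude_prune(w, sparsity=0.5)
print(f"kept {mask.sum()} of {mask.size} weights")  # kept 16 of 32 weights
```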
</section>
<section id="interactive-systems" class="level3">
<h3 class="anchored" data-anchor-id="interactive-systems" id="interactive-systems">Interactive Systems</h3>
<p>VLMs are also expected to power more interactive applications:</p>
<ul>
<li>Conversational visual AI</li>
<li>Real-time visual assistance</li>
<li>Collaborative human-AI systems</li>
</ul>
</section>
</section>
<section id="best-practices-for-implementation" class="level2">
<h2 class="anchored" data-anchor-id="best-practices-for-implementation" id="best-practices-for-implementation">Best Practices for Implementation</h2>
<section id="data-preparation" class="level3">
<h3 class="anchored" data-anchor-id="data-preparation" id="data-preparation">Data Preparation</h3>
<div id="data-preparation" class="cell" data-execution_count="14">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb14"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb14-1"><a href="#cb14-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> json</span>
<span id="cb14-2"><a href="#cb14-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> os</span>
<span id="cb14-3"><a href="#cb14-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-4"><a href="#cb14-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> prepare_vlm_dataset(image_dir, caption_file, min_caption_chars<span class="op">=</span><span class="dv">10</span>):</span>
<span id="cb14-5"><a href="#cb14-5" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Prepare dataset for VLM training from a JSON Lines caption file"""</span></span>
<span id="cb14-6"><a href="#cb14-6" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Each line is assumed to be a JSON object with 'image' and 'caption' keys</span></span>
<span id="cb14-7"><a href="#cb14-7" aria-hidden="true" tabindex="-1"></a>    dataset <span class="op">=</span> []</span>
<span id="cb14-8"><a href="#cb14-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb14-9"><a href="#cb14-9" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> <span class="bu">open</span>(caption_file, <span class="st">'r'</span>, encoding<span class="op">=</span><span class="st">'utf-8'</span>) <span class="im">as</span> f:</span>
<span id="cb14-10"><a href="#cb14-10" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> line <span class="kw">in</span> f:</span>
<span id="cb14-11"><a href="#cb14-11" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="kw">not</span> line.strip():</span>
<span id="cb14-12"><a href="#cb14-12" aria-hidden="true" tabindex="-1"></a>                <span class="cf">continue</span>  <span class="co"># skip blank lines</span></span>
<span id="cb14-13"><a href="#cb14-13" aria-hidden="true" tabindex="-1"></a>            data <span class="op">=</span> json.loads(line)</span>
<span id="cb14-14"><a href="#cb14-14" aria-hidden="true" tabindex="-1"></a>            image_path <span class="op">=</span> os.path.join(image_dir, data[<span class="st">'image'</span>])</span>
<span id="cb14-15"><a href="#cb14-15" aria-hidden="true" tabindex="-1"></a>            caption <span class="op">=</span> data.get(<span class="st">'caption'</span>, <span class="st">''</span>).strip()</span>
<span id="cb14-16"><a href="#cb14-16" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb14-17"><a href="#cb14-17" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Quality checks: the image file exists and the caption is not trivially short</span></span>
<span id="cb14-18"><a href="#cb14-18" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> os.path.exists(image_path) <span class="kw">and</span> <span class="bu">len</span>(caption) <span class="op">&gt;</span> min_caption_chars:</span>
<span id="cb14-19"><a href="#cb14-19" aria-hidden="true" tabindex="-1"></a>                dataset.append({</span>
<span id="cb14-20"><a href="#cb14-20" aria-hidden="true" tabindex="-1"></a>                    <span class="st">'image_path'</span>: image_path,</span>
<span id="cb14-21"><a href="#cb14-21" aria-hidden="true" tabindex="-1"></a>                    <span class="st">'caption'</span>: caption,</span>
<span id="cb14-22"><a href="#cb14-22" aria-hidden="true" tabindex="-1"></a>                    <span class="st">'metadata'</span>: data.get(<span class="st">'metadata'</span>, {})</span>
<span id="cb14-23"><a href="#cb14-23" aria-hidden="true" tabindex="-1"></a>                })</span>
<span id="cb14-24"><a href="#cb14-24" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb14-25"><a href="#cb14-25" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> dataset</span></code></pre></div></div>
</div>
</section>
<section id="model-optimization-tips" class="level3">
<h3 class="anchored" data-anchor-id="model-optimization-tips" id="model-optimization-tips">Model Optimization Tips</h3>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Optimization Strategies
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li>Use mixed precision training</li>
<li>Implement gradient checkpointing</li>
<li>Apply learning rate scheduling</li>
<li>Monitor for overfitting</li>
</ul>
</div>
</div>
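<p>As a concrete example of learning rate scheduling, linear warmup followed by cosine decay is a widely used choice when fine-tuning transformer-based models. The function below is a sketch of that schedule; the warmup length and minimum learning rate are illustrative defaults rather than recommendations. Mixed precision and gradient checkpointing are usually switched on through framework features instead, such as <code>torch.autocast</code> and <code>torch.utils.checkpoint</code> in PyTorch.</p>

```python
import math


def warmup_cosine_lr(step, total_steps, base_lr, warmup_steps=100, min_lr=0.0):
    """Linear warmup from ~0 to base_lr, then cosine decay from base_lr to min_lr."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    # Fraction of the decay phase completed, capped at 1.0
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(1.0, progress)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))


for step in (0, 99, 100, 550, 1000):
    print(step, warmup_cosine_lr(step, total_steps=1000, base_lr=1e-4))
```

<p>The returned value would be written into the optimizer's learning rate before each update step.</p>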
</section>
<section id="deployment-considerations" class="level3">
<h3 class="anchored" data-anchor-id="deployment-considerations" id="deployment-considerations">Deployment Considerations</h3>
<ul>
<li><strong>Model quantization</strong> for edge deployment</li>
<li><strong>Caching strategies</strong> for repeated queries</li>
<li><strong>Load balancing</strong> for high-traffic applications</li>
</ul>
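<p>To illustrate what quantization does to the weights, the NumPy sketch below implements symmetric per-tensor int8 quantization, the simplest variant: one scale factor maps the largest absolute weight to 127, and every weight is rounded to the nearest int8 code. Production toolchains, including PyTorch's quantization APIs, add per-channel scales, activation calibration, and int8 kernels, none of which is modeled here.</p>

```python
import numpy as np


def quantize_int8(weights):
    """Symmetric per-tensor int8 quantization: weights ≈ scale * q."""
    max_abs = float(np.abs(weights).max())
    scale = max_abs / 127 if max_abs else 1.0   # map the largest |weight| to 127
    q = np.clip(np.round(weights / scale), -127, 127).astype(np.int8)
    return q, scale


def dequantize(q, scale):
    """Map int8 codes back to approximate float32 weights."""
    return q.astype(np.float32) * scale


rng = np.random.default_rng(0)
w = rng.normal(size=(256, 256)).astype(np.float32)
q, scale = quantize_int8(w)
max_error = float(np.abs(dequantize(q, scale) - w).max())
print(f"{w.nbytes} bytes as float32, {q.nbytes} bytes as int8")
print(f"max absolute error: {max_error:.4f} (scale / 2 = {scale / 2:.4f})")
```

<p>Storage drops by a factor of four, and the rounding error per weight is bounded by roughly half the scale factor.</p>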
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>Vision-Language Models represent a paradigm shift toward more human-like AI systems that can understand and reason about the visual world through natural language. As these models continue to evolve, they promise to unlock new possibilities in human-computer interaction, accessibility, content creation, and automated understanding of our increasingly visual digital world.</p>
<p>The field continues to advance rapidly, with ongoing research addressing current limitations while pushing the boundaries of what’s possible when machines can truly see and understand the world around them. For developers and researchers, VLMs offer exciting opportunities to build applications that bridge the gap between human perception and machine understanding.</p>


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Fine-tuning Vision-Language Models: A Comprehensive Guide]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/models/vision-language-models/vision-language-finetuning/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/models/vision-language-models/vision-language-finetuning/</guid>
      <pubDate>Sat, 02 Aug 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>tutorial</category>
      <category>intermediate</category>
      <content:encoded><![CDATA[






<section id="fine-tuning-vision-language-models-a-comprehensive-guide" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/models/vision-language-models/vision-language-finetuning/vlfine.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>Vision-Language Models (VLMs) represent a significant advancement in artificial intelligence, combining computer vision and natural language processing to understand and generate content that bridges visual and textual modalities. Fine-tuning these models for specific tasks and domains has become crucial for achieving optimal performance in real-world applications.</p>
<p>This comprehensive guide explores the intricacies of fine-tuning VLMs, from theoretical foundations to practical implementation strategies. Whether you’re adapting models like CLIP, BLIP, or more recent architectures like GPT-4V or LLaVA, this article provides the knowledge needed to successfully customize these powerful models for your specific use cases.</p>
</section>
<section id="understanding-vision-language-models" class="level2">
<h2 class="anchored" data-anchor-id="understanding-vision-language-models" id="understanding-vision-language-models">Understanding Vision-Language Models</h2>
<section id="architecture-overview" class="level3">
<h3 class="anchored" data-anchor-id="architecture-overview" id="architecture-overview">Architecture Overview</h3>
<p>Vision-Language Models typically consist of three main components:</p>
<p><strong>Vision Encoder</strong>: Processes visual input (images, videos) and extracts meaningful features. Common architectures include:</p>
<ul>
<li>Vision Transformers (ViTs)</li>
<li>Convolutional Neural Networks (CNNs)</li>
<li>Hybrid architectures combining both approaches</li>
</ul>
<p><strong>Language Encoder/Decoder</strong>: Handles textual input and output generation. This component often leverages:</p>
<ul>
<li>Transformer-based architectures</li>
<li>Pre-trained language models (BERT, GPT variants)</li>
<li>Specialized language models designed for multimodal tasks</li>
</ul>
<p><strong>Cross-Modal Fusion</strong>: Integrates information from both modalities through:</p>
<ul>
<li>Attention mechanisms</li>
<li>Cross-modal transformers</li>
<li>Contrastive learning approaches</li>
<li>Multimodal fusion layers</li>
</ul>
</section>
<section id="popular-vlm-architectures" class="level3">
<h3 class="anchored" data-anchor-id="popular-vlm-architectures" id="popular-vlm-architectures">Popular VLM Architectures</h3>
<section id="clip-contrastive-language-image-pre-training" class="level4">
<h4 class="anchored" data-anchor-id="clip-contrastive-language-image-pre-training">CLIP (Contrastive Language-Image Pre-training)</h4>
<p>CLIP learns visual concepts from natural language supervision by training on image-text pairs using contrastive learning. It consists of separate image and text encoders that map inputs to a shared embedding space.</p>
</section>
<section id="blip-bootstrapping-language-image-pre-training" class="level4">
<h4 class="anchored" data-anchor-id="blip-bootstrapping-language-image-pre-training">BLIP (Bootstrapping Language-Image Pre-training)</h4>
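<p>BLIP introduces a multimodal mixture of encoder-decoder (MED) architecture that handles both understanding and generation tasks through unified pre-training objectives. It also bootstraps its training data by generating synthetic captions for web images and filtering out noisy ones.</p>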
</section>
<section id="llava-large-language-and-vision-assistant" class="level4">
<h4 class="anchored" data-anchor-id="llava-large-language-and-vision-assistant">LLaVA (Large Language and Vision Assistant)</h4>
<p>LLaVA connects a vision encoder with a large language model, enabling instruction-following capabilities for multimodal tasks.</p>
</section>
<section id="gpt-4v-and-similar-models" class="level4">
<h4 class="anchored" data-anchor-id="gpt-4v-and-similar-models">GPT-4V and Similar Models</h4>
<p>Recent large-scale models such as GPT-4V integrate vision capabilities directly into large language models, enabling sophisticated reasoning across modalities.</p>
</section>
</section>
</section>
<section id="types-of-fine-tuning" class="level2">
<h2 class="anchored" data-anchor-id="types-of-fine-tuning" id="types-of-fine-tuning">Types of Fine-tuning</h2>
<section id="full-fine-tuning" class="level3">
<h3 class="anchored" data-anchor-id="full-fine-tuning" id="full-fine-tuning">Full Fine-tuning</h3>
<p>Full fine-tuning updates every parameter in the model. This approach offers maximum flexibility but requires substantial computational resources and carefully curated datasets.</p>
<p><strong>Advantages</strong>:</p>
<ul>
<li>Maximum adaptation potential</li>
<li>Can learn complex task-specific patterns</li>
<li>Suitable for significantly different domains</li>
</ul>
<p><strong>Disadvantages</strong>:</p>
<ul>
<li>Computationally expensive</li>
<li>Risk of catastrophic forgetting</li>
<li>Requires large datasets</li>
</ul>
</section>
<section id="parameter-efficient-fine-tuning-peft" class="level3">
<h3 class="anchored" data-anchor-id="parameter-efficient-fine-tuning-peft" id="parameter-efficient-fine-tuning-peft">Parameter-Efficient Fine-tuning (PEFT)</h3>
<section id="low-rank-adaptation-lora" class="level4">
<h4 class="anchored" data-anchor-id="low-rank-adaptation-lora">Low-Rank Adaptation (LoRA)</h4>
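<p>LoRA freezes the pre-trained weights and trains a pair of low-rank matrices that approximate the weight update, sharply reducing the number of trainable parameters while often matching the quality of full fine-tuning.</p>
<p><strong>Implementation</strong>: Instead of updating a weight matrix W (shape d × k) directly, LoRA keeps W frozen and learns a low-rank update ΔW = BA, so the adapted weight is W + BA. Here B has shape d × r and A has shape r × k, with the rank r much smaller than d and k.</p>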
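<p>A minimal NumPy sketch of the LoRA forward pass for a single linear layer, assuming a frozen weight <code>W</code> of shape (d_out, d_in). Following the convention of the original LoRA paper, <code>A</code> is initialized randomly and <code>B</code> to zeros, so the adapted layer initially behaves exactly like the pre-trained one, and the update is scaled by alpha / r. Only <code>A</code> and <code>B</code> would receive gradients; the training loop is omitted.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
d_out, d_in, r, alpha = 512, 512, 8, 16

W = rng.normal(size=(d_out, d_in))           # frozen pre-trained weight
A = rng.normal(scale=0.01, size=(r, d_in))   # trainable, random init
B = np.zeros((d_out, r))                     # trainable, zero init


def lora_forward(x):
    """y = x W^T + (alpha / r) * x (BA)^T, computed without forming BA."""
    return x @ W.T + (alpha / r) * (x @ A.T) @ B.T


x = rng.normal(size=(4, d_in))
assert np.allclose(lora_forward(x), x @ W.T)  # B = 0, so output is unchanged at init

full = W.size
lora = A.size + B.size
print(f"trainable params: {lora} vs {full} ({lora / full:.1%})")
```

<p>Computing <code>(x @ A.T) @ B.T</code> avoids materializing the full d_out × d_in update; for inference, BA can instead be merged into <code>W</code> once, so the adapter adds no extra latency.</p>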
</section>
<section id="adapters" class="level4">
<h4 class="anchored" data-anchor-id="adapters">Adapters</h4>
<p>Small bottleneck modules inserted inside each transformer layer (typically after the attention and feed-forward sublayers), allowing task-specific adaptation while keeping the original model frozen.</p>
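<p>A bottleneck adapter in the style of Houlsby et al. projects the hidden state down to a small dimension, applies a nonlinearity, projects back up, and adds the result to its input. The NumPy sketch below assumes a hidden size of 768 and a bottleneck of 64 (illustrative values) and omits bias terms. Initializing the up-projection to zeros, a common choice assumed here, makes the adapter start as the identity so that inserting it does not disturb the pre-trained model.</p>

```python
import numpy as np


class BottleneckAdapter:
    """h + relu(h W_down) W_up; only W_down and W_up would be trained."""

    def __init__(self, hidden_dim=768, bottleneck_dim=64, seed=0):
        rng = np.random.default_rng(seed)
        self.W_down = rng.normal(scale=0.02, size=(hidden_dim, bottleneck_dim))
        # Zero init, so the adapter is the identity function before training
        self.W_up = np.zeros((bottleneck_dim, hidden_dim))

    def __call__(self, h):
        z = np.maximum(h @ self.W_down, 0.0)  # down-project + ReLU
        return h + z @ self.W_up              # up-project + residual connection


adapter = BottleneckAdapter()
h = np.random.default_rng(1).normal(size=(2, 10, 768))  # (batch, seq, hidden)
out = adapter(h)
print(out.shape, np.allclose(out, h))  # (2, 10, 768) True
print(adapter.W_down.size + adapter.W_up.size)  # 98304
```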
</section>
<section id="prompt-tuning" class="level4">
<h4 class="anchored" data-anchor-id="prompt-tuning">Prompt Tuning</h4>
<p>Learns a small set of continuous prompt embeddings, prepended to the input embeddings, that steer the frozen model’s behavior without modifying its parameters.</p>
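<p>Mechanically, prompt tuning concatenates a small matrix of trainable vectors in front of the frozen token embeddings before they enter the model. The NumPy sketch below shows only that step, with an illustrative 20 prompt vectors and hidden size 768; the attention mask is extended by the same number of positions so the model attends to the soft prompt.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
num_prompt_tokens, hidden = 20, 768

# The only trainable parameters: one vector per virtual prompt token
soft_prompt = rng.normal(scale=0.02, size=(num_prompt_tokens, hidden))


def prepend_soft_prompt(token_embeddings, attention_mask):
    """token_embeddings: (batch, seq, hidden); attention_mask: (batch, seq) of 0/1."""
    batch = token_embeddings.shape[0]
    prompts = np.broadcast_to(soft_prompt, (batch, num_prompt_tokens, hidden))
    embeds = np.concatenate([prompts, token_embeddings], axis=1)
    prompt_mask = np.ones((batch, num_prompt_tokens), dtype=attention_mask.dtype)
    mask = np.concatenate([prompt_mask, attention_mask], axis=1)
    return embeds, mask


emb = rng.normal(size=(2, 12, hidden))
mask = np.ones((2, 12), dtype=np.int64)
new_emb, new_mask = prepend_soft_prompt(emb, mask)
print(new_emb.shape, new_mask.shape)  # (2, 32, 768) (2, 32)
```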
</section>
<section id="prefix-tuning" class="level4">
<h4 class="anchored" data-anchor-id="prefix-tuning">Prefix Tuning</h4>
<p>Similar to prompt tuning, but instead of prepending learned vectors only at the input, prefix tuning prepends trainable vectors to the keys and values of the attention in every transformer layer.</p>
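<p>The difference from prompt tuning shows up inside attention. In the single-head NumPy sketch below, trainable prefix keys and values (<code>P_k</code> and <code>P_v</code>, illustrative names) are concatenated in front of the keys and values computed from the input, so every query can attend to them while the projection matrices stay frozen. A full implementation repeats this in every layer and, in the original method, reparameterizes the prefixes through a small MLP during training; both are omitted here.</p>

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attention_with_prefix(X, W_q, W_k, W_v, P_k, P_v):
    """Single-head attention with trainable prefixes prepended to keys and values.

    X: (seq, d) input; W_q, W_k, W_v: (d, d) frozen projections;
    P_k, P_v: (prefix_len, d) trainable prefix keys and values.
    """
    Q = X @ W_q
    K = np.concatenate([P_k, X @ W_k], axis=0)  # (prefix_len + seq, d)
    V = np.concatenate([P_v, X @ W_v], axis=0)
    scores = Q @ K.T / np.sqrt(Q.shape[-1])     # (seq, prefix_len + seq)
    return softmax(scores) @ V                  # (seq, d)


rng = np.random.default_rng(0)
seq, d, prefix_len = 6, 32, 4
X = rng.normal(size=(seq, d))
W_q, W_k, W_v = (rng.normal(scale=d ** -0.5, size=(d, d)) for _ in range(3))
P_k = rng.normal(size=(prefix_len, d))
P_v = rng.normal(size=(prefix_len, d))
print(attention_with_prefix(X, W_q, W_k, W_v, P_k, P_v).shape)  # (6, 32)
```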
</section>
</section>
<section id="layer-wise-fine-tuning" class="level3">
<h3 class="anchored" data-anchor-id="layer-wise-fine-tuning" id="layer-wise-fine-tuning">Layer-wise Fine-tuning</h3>
<p>Selective unfreezing and training of specific model layers, often starting from the top layers and gradually including lower layers.</p>
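<p>One simple way to express such a schedule is to unfreeze one more block of layers from the top at each epoch. The sketch below only computes which layer indices are trainable at a given epoch (the pace of two layers per epoch in the example is an arbitrary assumption); in PyTorch the result would translate into setting <code>requires_grad</code> on the parameters of those layers.</p>

```python
def trainable_layers(epoch, num_layers, layers_per_epoch=1):
    """Indices of trainable layers at `epoch` under top-down gradual unfreezing.

    Index num_layers - 1 is the top layer (closest to the output head). At epoch 0
    only the top `layers_per_epoch` layers train; each later epoch unfreezes the
    next block below, until the whole network is trainable.
    """
    n_unfrozen = min(num_layers, (epoch + 1) * layers_per_epoch)
    return list(range(num_layers - n_unfrozen, num_layers))


for epoch in range(4):
    print(epoch, trainable_layers(epoch, num_layers=12, layers_per_epoch=2))
```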
</section>
<section id="task-specific-head-fine-tuning" class="level3">
<h3 class="anchored" data-anchor-id="task-specific-head-fine-tuning" id="task-specific-head-fine-tuning">Task-specific Head Fine-tuning</h3>
<p>Adding and training new classification or regression heads while keeping the backbone frozen, suitable for discriminative tasks.</p>
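<p>For discriminative tasks, the head can be as simple as a linear softmax classifier trained on features from the frozen backbone (often called a linear probe). The NumPy sketch below trains such a head with full-batch gradient descent; the synthetic clusters are a stand-in for real encoder embeddings, and the learning rate and epoch count are illustrative.</p>

```python
import numpy as np


def train_linear_head(features, labels, num_classes, lr=0.1, epochs=200):
    """Train a softmax classifier on frozen features with full-batch gradient descent."""
    n, d = features.shape
    W = np.zeros((d, num_classes))
    b = np.zeros(num_classes)
    onehot = np.eye(num_classes)[labels]
    for _ in range(epochs):
        logits = features @ W + b
        logits -= logits.max(axis=1, keepdims=True)   # numerical stability
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)
        grad = (probs - onehot) / n                    # d(mean cross-entropy)/d(logits)
        W -= lr * features.T @ grad
        b -= lr * grad.sum(axis=0)
    return W, b


# Stand-in for embeddings from a frozen vision encoder: two separated clusters
rng = np.random.default_rng(0)
feats = np.vstack([rng.normal(-1, 0.5, size=(50, 16)),
                   rng.normal(1, 0.5, size=(50, 16))])
labels = np.array([0] * 50 + [1] * 50)
W, b = train_linear_head(feats, labels, num_classes=2)
accuracy = ((feats @ W + b).argmax(axis=1) == labels).mean()
print(f"training accuracy: {accuracy:.2f}")
```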
</section>
</section>
<section id="data-preparation" class="level2">
<h2 class="anchored" data-anchor-id="data-preparation" id="data-preparation">Data Preparation</h2>
<section id="dataset-requirements" class="level3">
<h3 class="anchored" data-anchor-id="dataset-requirements" id="dataset-requirements">Dataset Requirements</h3>
<p><strong>Quality over Quantity</strong>: High-quality, well-annotated data is more valuable than large volumes of noisy data. Each image-text pair should be:</p>
<ul>
<li>Semantically aligned</li>
<li>Descriptively accurate</li>
<li>Relevant to the target task</li>
</ul>
<p><strong>Data Diversity</strong>: Ensure representation across:</p>
<ul>
<li>Visual concepts and scenes</li>
<li>Linguistic patterns and styles</li>
<li>Cultural and demographic diversity</li>
<li>Various lighting conditions and viewpoints</li>
</ul>
</section>
<section id="data-formats-and-standards" class="level3">
<h3 class="anchored" data-anchor-id="data-formats-and-standards" id="data-formats-and-standards">Data Formats and Standards</h3>
<section id="image-text-pairs" class="level4">
<h4 class="anchored" data-anchor-id="image-text-pairs">Image-Text Pairs</h4>
<div id="63da37d9" class="cell" data-execution_count="1">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb1"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Example data structure for image-text pairs</span></span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> json</span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a>example_data <span class="op">=</span> {</span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a>    <span class="st">"image_path"</span>: <span class="st">"path/to/image.jpg"</span>,</span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a>    <span class="st">"caption"</span>: <span class="st">"A detailed description of the image"</span>,</span>
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a>    <span class="st">"metadata"</span>: {</span>
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a>        <span class="st">"source"</span>: <span class="st">"dataset_name"</span>,</span>
<span id="cb1-9"><a href="#cb1-9" aria-hidden="true" tabindex="-1"></a>        <span class="st">"quality_score"</span>: <span class="fl">0.95</span>,</span>
<span id="cb1-10"><a href="#cb1-10" aria-hidden="true" tabindex="-1"></a>        <span class="st">"language"</span>: <span class="st">"en"</span></span>
<span id="cb1-11"><a href="#cb1-11" aria-hidden="true" tabindex="-1"></a>    }</span>
<span id="cb1-12"><a href="#cb1-12" aria-hidden="true" tabindex="-1"></a>}</span>
<span id="cb1-13"><a href="#cb1-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-14"><a href="#cb1-14" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(json.dumps(example_data, indent<span class="op">=</span><span class="dv">2</span>))</span></code></pre></div></div>
</div>
</section>
<section id="instruction-following-format" class="level4">
<h4 class="anchored" data-anchor-id="instruction-following-format">Instruction-Following Format</h4>
<div id="b2957bc2" class="cell" data-execution_count="2">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Example instruction-following format</span></span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a>instruction_data <span class="op">=</span> {</span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a>    <span class="st">"image"</span>: <span class="st">"path/to/image.jpg"</span>,</span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a>    <span class="st">"conversations"</span>: [</span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a>        {</span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a>            <span class="st">"from"</span>: <span class="st">"human"</span>,</span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a>            <span class="st">"value"</span>: <span class="st">"What objects are visible in this image?"</span></span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a>        },</span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a>        {</span>
<span id="cb2-10"><a href="#cb2-10" aria-hidden="true" tabindex="-1"></a>            <span class="st">"from"</span>: <span class="st">"gpt"</span>,</span>
<span id="cb2-11"><a href="#cb2-11" aria-hidden="true" tabindex="-1"></a>            <span class="st">"value"</span>: <span class="st">"I can see a red bicycle, a wooden bench, and several trees in the background."</span></span>
<span id="cb2-12"><a href="#cb2-12" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb2-13"><a href="#cb2-13" aria-hidden="true" tabindex="-1"></a>    ]</span>
<span id="cb2-14"><a href="#cb2-14" aria-hidden="true" tabindex="-1"></a>}</span>
<span id="cb2-15"><a href="#cb2-15" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-16"><a href="#cb2-16" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(json.dumps(instruction_data, indent<span class="op">=</span><span class="dv">2</span>))</span></code></pre></div></div>
</div>
</section>
</section>
<section id="data-preprocessing" class="level3">
<h3 class="anchored" data-anchor-id="data-preprocessing" id="data-preprocessing">Data Preprocessing</h3>
<p><strong>Image Preprocessing</strong>:</p>
<ul>
<li>Normalization using pre-training statistics</li>
<li>Consistent resizing and aspect ratio handling</li>
<li>Data augmentation strategies (rotation, cropping, color jittering)</li>
<li>Format standardization (RGB, resolution)</li>
</ul>
<p><strong>Text Preprocessing</strong>:</p>
<ul>
<li>Tokenization using model-specific tokenizers</li>
<li>Length normalization and truncation</li>
<li>Special token handling</li>
<li>Consistent text encoding (e.g., UTF-8) across the dataset</li>
</ul>
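<p>Length normalization usually means truncating long sequences and padding short ones to a fixed length, together with an attention mask that tells the model which positions are padding. The sketch below shows that logic on raw token IDs; the pad ID of 0 and the token IDs are illustrative assumptions, since each tokenizer defines its own. With Hugging Face tokenizers, the same effect comes from passing <code>padding="max_length"</code>, <code>truncation=True</code>, and <code>max_length</code>, which also take care of special tokens that naive truncation like this can cut off.</p>

```python
def pad_or_truncate(token_ids, max_length, pad_id=0):
    """Truncate to max_length or right-pad with pad_id, and build the attention mask."""
    ids = list(token_ids[:max_length])
    mask = [1] * len(ids)          # 1 = real token
    padding = max_length - len(ids)
    ids += [pad_id] * padding
    mask += [0] * padding          # 0 = padding, ignored by attention
    return ids, mask


print(pad_or_truncate([101, 2023, 2003, 102], max_length=6))
# ([101, 2023, 2003, 102, 0, 0], [1, 1, 1, 1, 0, 0])
print(pad_or_truncate(list(range(10)), max_length=4))
# ([0, 1, 2, 3], [1, 1, 1, 1])
```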
</section>
<section id="data-augmentation-strategies" class="level3">
<h3 class="anchored" data-anchor-id="data-augmentation-strategies" id="data-augmentation-strategies">Data Augmentation Strategies</h3>
<p><strong>Visual Augmentations</strong>:</p>
<ul>
<li>Geometric transformations (rotation, scaling, flipping)</li>
<li>Color space modifications</li>
<li>Noise injection</li>
<li>Cutout and mixup techniques</li>
</ul>
<p><strong>Textual Augmentations</strong>:</p>
<ul>
<li>Paraphrasing using language models</li>
<li>Synonym replacement</li>
<li>Back-translation</li>
<li>Template-based generation</li>
</ul>
<p><strong>Cross-modal Augmentations</strong>:</p>
<ul>
<li>Hard negative mining</li>
<li>Curriculum learning approaches</li>
<li>Multi-view consistency training</li>
</ul>
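<p>Hard negative mining can be illustrated on a batch of paired image and text embeddings: for each image, the hardest negative is the non-matching caption it is most similar to. The NumPy sketch below assumes L2-normalized embeddings in which row i of each matrix forms a matched pair, as in CLIP-style contrastive training, and uses random vectors as stand-ins; how the mined negatives are then used (for example, in an extra loss term) varies by method.</p>

```python
import numpy as np


def hardest_negatives(image_emb, text_emb):
    """For each image i, return the index of the most similar non-matching caption.

    Assumes row i of image_emb and row i of text_emb are a matched pair and both
    are L2-normalized, so the dot product equals cosine similarity.
    """
    sim = image_emb @ text_emb.T        # (batch, batch) similarity matrix
    np.fill_diagonal(sim, -np.inf)      # exclude the positive pairs
    return sim.argmax(axis=1)


def l2_normalize(x):
    return x / np.linalg.norm(x, axis=1, keepdims=True)


rng = np.random.default_rng(0)
img = l2_normalize(rng.normal(size=(8, 64)))
txt = l2_normalize(img + 0.1 * rng.normal(size=(8, 64)))  # noisy stand-in captions
neg = hardest_negatives(img, txt)
print(neg)
```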
</section>
</section>
<section id="fine-tuning-strategies" class="level2">
<h2 class="anchored" data-anchor-id="fine-tuning-strategies" id="fine-tuning-strategies">Fine-tuning Strategies</h2>
<section id="curriculum-learning" class="level3">
<h3 class="anchored" data-anchor-id="curriculum-learning" id="curriculum-learning">Curriculum Learning</h3>
<p>Gradually increasing task complexity during training, starting with simpler examples and progressing to more challenging ones.</p>
<p><strong>Implementation Strategies</strong>:</p>
<ul>
<li>Easy-to-hard example ordering</li>
<li>Confidence-based sample selection</li>
<li>Multi-stage training protocols</li>
</ul>
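<p>Easy-to-hard ordering needs two ingredients: a difficulty score per example and a pacing function that decides how much of the sorted data is available at each stage. The sketch below assumes the difficulty score is given (caption word count in the example, an illustrative proxy; the loss of a reference model is another option) and uses a linear pacing function that starts from 30% of the data.</p>

```python
def curriculum_subset(examples, difficulty, epoch, num_epochs, start_frac=0.3):
    """Training pool for `epoch` under a linear pacing function.

    Examples are sorted easy-to-hard by `difficulty`; the pool grows linearly
    from `start_frac` of the data at epoch 0 to all of it at the last epoch.
    """
    ordered = sorted(examples, key=difficulty)
    progress = epoch / max(1, num_epochs - 1)
    frac = min(1.0, start_frac + (1.0 - start_frac) * progress)
    n = max(1, round(frac * len(ordered)))
    return ordered[:n]


captions = [
    "a dog",
    "a red bus on a wet street at night",
    "two cats",
    "a man riding a horse on a beach",
]
for epoch in range(3):
    pool = curriculum_subset(captions, difficulty=lambda c: len(c.split()),
                             epoch=epoch, num_epochs=3)
    print(epoch, pool)
```

<p>Within each epoch the returned pool would normally be shuffled before batching, so the ordering controls which examples are available rather than the order of batches.</p>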
</section>
<section id="multi-task-learning" class="level3">
<h3 class="anchored" data-anchor-id="multi-task-learning" id="multi-task-learning">Multi-task Learning</h3>
<p>Training on multiple related tasks simultaneously to improve generalization and transfer learning capabilities.</p>
<p><strong>Task Selection Criteria</strong>:</p>
<ul>
<li>Complementary skill requirements</li>
<li>Shared visual or linguistic patterns</li>
<li>Balanced computational requirements</li>
</ul>
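<p>When tasks have very different dataset sizes, one common balancing scheme samples each batch's task with probability proportional to its dataset size raised to the power 1/T. A temperature T of 1 reproduces size-proportional sampling, and larger values move toward uniform sampling. The sketch below assumes illustrative dataset sizes and T = 2.</p>

```python
import random


def task_sampling_probs(dataset_sizes, temperature=2.0):
    """p(task) proportional to size ** (1 / temperature); temperature=1 is size-proportional."""
    weights = {t: n ** (1.0 / temperature) for t, n in dataset_sizes.items()}
    total = sum(weights.values())
    return {t: w / total for t, w in weights.items()}


sizes = {"captioning": 100_000, "vqa": 40_000, "grounding": 10_000}  # illustrative
probs = task_sampling_probs(sizes)
print({t: round(p, 3) for t, p in probs.items()})

rng = random.Random(0)
schedule = rng.choices(list(probs), weights=list(probs.values()), k=10)
print(schedule)  # which task each of the next 10 batches is drawn from
```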
</section>
<section id="domain-adaptation-techniques" class="level3">
<h3 class="anchored" data-anchor-id="domain-adaptation-techniques" id="domain-adaptation-techniques">Domain Adaptation Techniques</h3>
<section id="adversarial-training" class="level4">
<h4 class="anchored" data-anchor-id="adversarial-training">Adversarial Training</h4>
<p>Using domain discriminators to learn domain-invariant features while maintaining task performance.</p>
</section>
<section id="gradual-domain-shift" class="level4">
<h4 class="anchored" data-anchor-id="gradual-domain-shift">Gradual Domain Shift</h4>
<p>Progressively adapting from source to target domain through intermediate domains or synthetic data.</p>
</section>
<section id="self-supervised-pre-training" class="level4">
<h4 class="anchored" data-anchor-id="self-supervised-pre-training">Self-supervised Pre-training</h4>
<p>Leveraging unlabeled data from the target domain through self-supervised objectives before fine-tuning.</p>
</section>
</section>
<section id="regularization-techniques" class="level3">
<h3 class="anchored" data-anchor-id="regularization-techniques" id="regularization-techniques">Regularization Techniques</h3>
<p><strong>Weight Decay and Dropout</strong>: Standard regularization methods to prevent overfitting.</p>
<p><strong>Knowledge Distillation</strong>: Using a larger teacher model to guide the training of a smaller student model.</p>
<p><strong>Elastic Weight Consolidation (EWC)</strong>: Preventing catastrophic forgetting by constraining important parameters based on Fisher information.</p>
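<p>Concretely, EWC adds the penalty (λ/2) Σ<sub>i</sub> F<sub>i</sub> (θ<sub>i</sub> − θ*<sub>i</sub>)² to the new task's loss, where θ* are the parameters after training on the original task and F<sub>i</sub> is the diagonal of the Fisher information, commonly estimated from squared gradients of the log-likelihood on original-task data. The NumPy sketch below computes both pieces, assuming per-example gradients are already available (obtaining them is framework-specific and omitted) and using made-up numbers.</p>

```python
import numpy as np


def diagonal_fisher(per_example_grads):
    """Estimate the diagonal Fisher information as the mean squared per-example gradient.

    per_example_grads: (num_examples, num_params) gradients of the log-likelihood
    on data from the original task.
    """
    return np.mean(per_example_grads ** 2, axis=0)


def ewc_penalty(theta, theta_star, fisher, lam=1.0):
    """lam / 2 * sum_i F_i * (theta_i - theta_star_i) ** 2"""
    return 0.5 * lam * np.sum(fisher * (theta - theta_star) ** 2)


theta_star = np.array([1.0, -2.0, 0.5])   # parameters after the original task
fisher = diagonal_fisher(np.array([[0.1, 2.0, 0.0],
                                   [0.3, 1.0, 0.0]]))
theta = np.array([1.5, -2.0, 3.0])        # parameters during new-task training
print(fisher)
print(ewc_penalty(theta, theta_star, fisher, lam=10.0))
```

<p>In the example, the third parameter moves the furthest but contributes nothing to the penalty because its Fisher value is zero, which is how EWC leaves unimportant parameters free to change.</p>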
</section>
</section>
<section id="technical-implementation" class="level2">
<h2 class="anchored" data-anchor-id="technical-implementation" id="technical-implementation">Technical Implementation</h2>
<section id="environment-setup" class="level3">
<h3 class="anchored" data-anchor-id="environment-setup" id="environment-setup">Environment Setup</h3>
<div id="3c5502a1" class="cell" data-execution_count="3">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Required libraries</span></span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> transformers</span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> transformers <span class="im">import</span> AutoProcessor, AutoModel</span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.utils.data <span class="im">import</span> DataLoader, Dataset</span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> pytorch_lightning <span class="im">as</span> pl</span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> PIL <span class="im">import</span> Image</span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> json</span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> matplotlib.pyplot <span class="im">as</span> plt</span></code></pre></div></div>
</div>
</section>
<section id="model-loading-and-configuration" class="level3">
<h3 class="anchored" data-anchor-id="model-loading-and-configuration" id="model-loading-and-configuration">Model Loading and Configuration</h3>
<div id="52ba3bd8" class="cell" data-execution_count="4">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> VLMFineTuner(pl.LightningModule):</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, model_name, learning_rate<span class="op">=</span><span class="fl">1e-4</span>, freeze_vision<span class="op">=</span><span class="va">False</span>):</span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> AutoModel.from_pretrained(model_name)</span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.processor <span class="op">=</span> AutoProcessor.from_pretrained(model_name)</span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.learning_rate <span class="op">=</span> learning_rate</span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Freeze vision encoder if specified</span></span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> freeze_vision:</span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> param <span class="kw">in</span> <span class="va">self</span>.model.vision_model.parameters():</span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a>                param.requires_grad <span class="op">=</span> <span class="va">False</span></span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> configure_optimizers(<span class="va">self</span>):</span>
<span id="cb4-14"><a href="#cb4-14" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> torch.optim.AdamW(</span>
<span id="cb4-15"><a href="#cb4-15" aria-hidden="true" tabindex="-1"></a>            <span class="bu">filter</span>(<span class="kw">lambda</span> p: p.requires_grad, <span class="va">self</span>.parameters()),</span>
<span id="cb4-16"><a href="#cb4-16" aria-hidden="true" tabindex="-1"></a>            lr<span class="op">=</span><span class="va">self</span>.learning_rate,</span>
<span id="cb4-17"><a href="#cb4-17" aria-hidden="true" tabindex="-1"></a>            weight_decay<span class="op">=</span><span class="fl">0.01</span></span>
<span id="cb4-18"><a href="#cb4-18" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb4-19"><a href="#cb4-19" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-20"><a href="#cb4-20" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> training_step(<span class="va">self</span>, batch, batch_idx):</span>
<span id="cb4-21"><a href="#cb4-21" aria-hidden="true" tabindex="-1"></a>        outputs <span class="op">=</span> <span class="va">self</span>.model(<span class="op">**</span>batch)</span>
<span id="cb4-22"><a href="#cb4-22" aria-hidden="true" tabindex="-1"></a>        loss <span class="op">=</span> outputs.loss</span>
<span id="cb4-23"><a href="#cb4-23" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.log(<span class="st">'train_loss'</span>, loss, prog_bar<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb4-24"><a href="#cb4-24" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> loss</span>
<span id="cb4-25"><a href="#cb4-25" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-26"><a href="#cb4-26" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> validation_step(<span class="va">self</span>, batch, batch_idx):</span>
<span id="cb4-27"><a href="#cb4-27" aria-hidden="true" tabindex="-1"></a>        outputs <span class="op">=</span> <span class="va">self</span>.model(<span class="op">**</span>batch)</span>
<span id="cb4-28"><a href="#cb4-28" aria-hidden="true" tabindex="-1"></a>        loss <span class="op">=</span> outputs.loss</span>
<span id="cb4-29"><a href="#cb4-29" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.log(<span class="st">'val_loss'</span>, loss, prog_bar<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb4-30"><a href="#cb4-30" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> loss</span></code></pre></div></div>
</div>
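<p>Counting trainable parameters is a quick way to check that <code>freeze_vision=True</code> actually removed the vision encoder from the optimizer. The helper below is a small sketch. The toy module stands in for a real VLM; only the attribute name <code>vision_model</code> mirrors the class above.</p>

```python
import torch.nn as nn


def count_parameters(model):
    """Return (trainable, total) parameter counts for any nn.Module."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total


# Hypothetical toy stand-in for a VLM: a "vision_model" plus a text head
toy = nn.Module()
toy.vision_model = nn.Linear(10, 10)  # 110 parameters
toy.text_head = nn.Linear(10, 5)      # 55 parameters
for param in toy.vision_model.parameters():
    param.requires_grad = False       # same freezing loop as freeze_vision=True

trainable, total = count_parameters(toy)
print(f"trainable: {trainable:,} / {total:,} ({100 * trainable / total:.1f}%)")
```

<p>With the real model, calling <code>count_parameters(model.model)</code> on a <code>VLMFineTuner</code> instance should report a smaller trainable count when <code>freeze_vision=True</code>.</p>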
</section>
<section id="custom-dataset-implementation" class="level3">
<h3 class="anchored" data-anchor-id="custom-dataset-implementation" id="custom-dataset-implementation">Custom Dataset Implementation</h3>
<div id="78fc4ed6" class="cell" data-execution_count="5">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> VisionLanguageDataset(Dataset):</span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, data_path, processor, max_length<span class="op">=</span><span class="dv">512</span>):</span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> <span class="bu">open</span>(data_path, <span class="st">'r'</span>) <span class="im">as</span> f:</span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.data <span class="op">=</span> json.load(f)</span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.processor <span class="op">=</span> processor</span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.max_length <span class="op">=</span> max_length</span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__len__</span>(<span class="va">self</span>):</span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="bu">len</span>(<span class="va">self</span>.data)</span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__getitem__</span>(<span class="va">self</span>, idx):</span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a>        item <span class="op">=</span> <span class="va">self</span>.data[idx]</span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a>        image <span class="op">=</span> Image.<span class="bu">open</span>(item[<span class="st">'image_path'</span>]).convert(<span class="st">'RGB'</span>)</span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a>        text <span class="op">=</span> item[<span class="st">'caption'</span>]</span>
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-16"><a href="#cb5-16" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Process inputs</span></span>
<span id="cb5-17"><a href="#cb5-17" aria-hidden="true" tabindex="-1"></a>        inputs <span class="op">=</span> <span class="va">self</span>.processor(</span>
<span id="cb5-18"><a href="#cb5-18" aria-hidden="true" tabindex="-1"></a>            images<span class="op">=</span>image,</span>
<span id="cb5-19"><a href="#cb5-19" aria-hidden="true" tabindex="-1"></a>            text<span class="op">=</span>text,</span>
<span id="cb5-20"><a href="#cb5-20" aria-hidden="true" tabindex="-1"></a>            return_tensors<span class="op">=</span><span class="st">"pt"</span>,</span>
<span id="cb5-21"><a href="#cb5-21" aria-hidden="true" tabindex="-1"></a>            padding<span class="op">=</span><span class="st">"max_length"</span>,  <span class="co"># fixed length so the default collate_fn can stack samples</span></span>
<span id="cb5-22"><a href="#cb5-22" aria-hidden="true" tabindex="-1"></a>            truncation<span class="op">=</span><span class="va">True</span>,</span>
<span id="cb5-23"><a href="#cb5-23" aria-hidden="true" tabindex="-1"></a>            max_length<span class="op">=</span><span class="va">self</span>.max_length</span>
<span id="cb5-24"><a href="#cb5-24" aria-hidden="true" tabindex="-1"></a>        )</span>
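<span id="cb5-25"><a href="#cb5-25" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-26"><a href="#cb5-26" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Drop the batch dimension added by return_tensors="pt"</span></span>
<span id="cb5-27"><a href="#cb5-27" aria-hidden="true" tabindex="-1"></a>        input_ids <span class="op">=</span> inputs[<span class="st">'input_ids'</span>].squeeze(<span class="dv">0</span>)</span>
<span id="cb5-28"><a href="#cb5-28" aria-hidden="true" tabindex="-1"></a>        attention_mask <span class="op">=</span> inputs[<span class="st">'attention_mask'</span>].squeeze(<span class="dv">0</span>)</span>
<span id="cb5-29"><a href="#cb5-29" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Mask padding positions with -100 so the loss ignores them</span></span>
<span id="cb5-30"><a href="#cb5-30" aria-hidden="true" tabindex="-1"></a>        labels <span class="op">=</span> input_ids.clone()</span>
<span id="cb5-31"><a href="#cb5-31" aria-hidden="true" tabindex="-1"></a>        labels[attention_mask <span class="op">==</span> <span class="dv">0</span>] <span class="op">=</span> <span class="op">-</span><span class="dv">100</span></span>
<span id="cb5-32"><a href="#cb5-32" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-33"><a href="#cb5-33" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {</span>
<span id="cb5-34"><a href="#cb5-34" aria-hidden="true" tabindex="-1"></a>            <span class="st">'pixel_values'</span>: inputs[<span class="st">'pixel_values'</span>].squeeze(<span class="dv">0</span>),</span>
<span id="cb5-35"><a href="#cb5-35" aria-hidden="true" tabindex="-1"></a>            <span class="st">'input_ids'</span>: input_ids,</span>
<span id="cb5-36"><a href="#cb5-36" aria-hidden="true" tabindex="-1"></a>            <span class="st">'attention_mask'</span>: attention_mask,</span>
<span id="cb5-37"><a href="#cb5-37" aria-hidden="true" tabindex="-1"></a>            <span class="st">'labels'</span>: labels</span>
<span id="cb5-38"><a href="#cb5-38" aria-hidden="true" tabindex="-1"></a>        }</span></code></pre></div></div>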
</div>
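<p><code>VisionLanguageDataset</code> expects a JSON file containing a list of records, each with an <code>image_path</code> key and a <code>caption</code> key. The sketch below writes a tiny placeholder dataset in that shape, which can be handy for smoke-testing the pipeline. The file names, image size, and captions are made up.</p>

```python
import json
import os
import tempfile

from PIL import Image


def write_toy_caption_dataset(out_dir, num_items=3):
    """Create placeholder images plus a JSON list in the format VisionLanguageDataset reads."""
    records = []
    for i in range(num_items):
        image_path = os.path.join(out_dir, f"image_{i}.png")
        Image.new("RGB", (64, 64), color=(40 * i, 100, 150)).save(image_path)
        records.append({"image_path": image_path, "caption": f"a placeholder caption {i}"})
    json_path = os.path.join(out_dir, "train_data.json")
    with open(json_path, "w") as f:
        json.dump(records, f, indent=2)
    return json_path


out_dir = tempfile.mkdtemp()
json_path = write_toy_caption_dataset(out_dir)
with open(json_path) as f:
    print(json.load(f)[0])
```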
</section>
<section id="lora-implementation" class="level3">
<h3 class="anchored" data-anchor-id="lora-implementation" id="lora-implementation">LoRA Implementation</h3>
<div id="302bb9c7" class="cell" data-execution_count="6">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> math</span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> LoRALayer(nn.Module):</span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Wraps an existing nn.Linear and adds a trainable low-rank update: W x + (B A x) * alpha / rank"""</span></span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, base_layer, rank<span class="op">=</span><span class="dv">16</span>, alpha<span class="op">=</span><span class="dv">16</span>):</span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.base_layer <span class="op">=</span> base_layer</span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.rank <span class="op">=</span> rank</span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.alpha <span class="op">=</span> alpha</span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.scaling <span class="op">=</span> <span class="va">self</span>.alpha <span class="op">/</span> <span class="va">self</span>.rank</span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a>        weight <span class="op">=</span> base_layer.weight</span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a>        <span class="co"># A gets a small random init and B starts at zero, so the wrapped</span></span>
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a>        <span class="co"># layer initially produces exactly the same output as the original</span></span>
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.lora_A <span class="op">=</span> nn.Parameter(torch.empty(rank, base_layer.in_features, device<span class="op">=</span>weight.device, dtype<span class="op">=</span>weight.dtype))</span>
<span id="cb6-15"><a href="#cb6-15" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.lora_B <span class="op">=</span> nn.Parameter(torch.zeros(base_layer.out_features, rank, device<span class="op">=</span>weight.device, dtype<span class="op">=</span>weight.dtype))</span>
<span id="cb6-16"><a href="#cb6-16" aria-hidden="true" tabindex="-1"></a>        nn.init.kaiming_uniform_(<span class="va">self</span>.lora_A, a<span class="op">=</span>math.sqrt(<span class="dv">5</span>))</span>
<span id="cb6-17"><a href="#cb6-17" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-18"><a href="#cb6-18" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb6-19"><a href="#cb6-19" aria-hidden="true" tabindex="-1"></a>        result <span class="op">=</span> <span class="va">self</span>.base_layer(x)</span>
<span id="cb6-20"><a href="#cb6-20" aria-hidden="true" tabindex="-1"></a>        lora_result <span class="op">=</span> (x <span class="op">@</span> <span class="va">self</span>.lora_A.T <span class="op">@</span> <span class="va">self</span>.lora_B.T) <span class="op">*</span> <span class="va">self</span>.scaling</span>
<span id="cb6-21"><a href="#cb6-21" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> result <span class="op">+</span> lora_result</span>
<span id="cb6-22"><a href="#cb6-22" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-23"><a href="#cb6-23" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> apply_lora_to_model(model, rank<span class="op">=</span><span class="dv">16</span>, alpha<span class="op">=</span><span class="dv">16</span>, target_modules<span class="op">=</span><span class="va">None</span>):</span>
<span id="cb6-24"><a href="#cb6-24" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Wrap matching nn.Linear layers with LoRA adapters and freeze all pre-trained weights"""</span></span>
<span id="cb6-25"><a href="#cb6-25" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> target_modules <span class="kw">is</span> <span class="va">None</span>:</span>
<span id="cb6-26"><a href="#cb6-26" aria-hidden="true" tabindex="-1"></a>        target_modules <span class="op">=</span> [<span class="st">'query'</span>, <span class="st">'key'</span>, <span class="st">'value'</span>, <span class="st">'dense'</span>]</span>
<span id="cb6-27"><a href="#cb6-27" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-28"><a href="#cb6-28" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Freeze the pre-trained weights; only the new LoRA matrices are trained</span></span>
<span id="cb6-29"><a href="#cb6-29" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> param <span class="kw">in</span> model.parameters():</span>
<span id="cb6-30"><a href="#cb6-30" aria-hidden="true" tabindex="-1"></a>        param.requires_grad <span class="op">=</span> <span class="va">False</span></span>
<span id="cb6-31"><a href="#cb6-31" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-32"><a href="#cb6-32" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Collect targets first: replacing modules while iterating named_modules() is unsafe</span></span>
<span id="cb6-33"><a href="#cb6-33" aria-hidden="true" tabindex="-1"></a>    targets <span class="op">=</span> [</span>
<span id="cb6-34"><a href="#cb6-34" aria-hidden="true" tabindex="-1"></a>        (name, module) <span class="cf">for</span> name, module <span class="kw">in</span> model.named_modules()</span>
<span id="cb6-35"><a href="#cb6-35" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="bu">isinstance</span>(module, nn.Linear)</span>
<span id="cb6-36"><a href="#cb6-36" aria-hidden="true" tabindex="-1"></a>        <span class="kw">and</span> <span class="bu">any</span>(target <span class="kw">in</span> name.split(<span class="st">'.'</span>)[<span class="op">-</span><span class="dv">1</span>] <span class="cf">for</span> target <span class="kw">in</span> target_modules)</span>
<span id="cb6-37"><a href="#cb6-37" aria-hidden="true" tabindex="-1"></a>    ]</span>
<span id="cb6-38"><a href="#cb6-38" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-39"><a href="#cb6-39" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> name, module <span class="kw">in</span> targets:</span>
<span id="cb6-40"><a href="#cb6-40" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Replace the module with a LoRA wrapper that keeps the original layer inside</span></span>
<span id="cb6-41"><a href="#cb6-41" aria-hidden="true" tabindex="-1"></a>        parent <span class="op">=</span> model</span>
<span id="cb6-42"><a href="#cb6-42" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> attr <span class="kw">in</span> name.split(<span class="st">'.'</span>)[:<span class="op">-</span><span class="dv">1</span>]:</span>
<span id="cb6-43"><a href="#cb6-43" aria-hidden="true" tabindex="-1"></a>            parent <span class="op">=</span> <span class="bu">getattr</span>(parent, attr)</span>
<span id="cb6-44"><a href="#cb6-44" aria-hidden="true" tabindex="-1"></a>        <span class="bu">setattr</span>(parent, name.split(<span class="st">'.'</span>)[<span class="op">-</span><span class="dv">1</span>], LoRALayer(module, rank, alpha))</span>
<span id="cb6-45"><a href="#cb6-45" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-46"><a href="#cb6-46" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> model</span></code></pre></div></div>
</div>
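<p>Two properties of the low-rank update are worth making concrete. First, with <code>lora_A</code> of shape (rank, in_features) and <code>lora_B</code> of shape (out_features, rank), a layer gains only rank * (in_features + out_features) trainable parameters. Second, because the update is linear, it can in principle be folded into the frozen weight as W + (alpha / rank) * B A after training. The sketch below checks both with plain tensors; the layer sizes are arbitrary examples.</p>

```python
import torch


def lora_param_ratio(in_features, out_features, rank):
    """Trainable LoRA parameters (A: rank x in, B: out x rank) relative to the frozen weight."""
    lora_params = rank * (in_features + out_features)
    return lora_params, lora_params / (in_features * out_features)


# roughly 24.6k parameters, about 4.2% of a 768 x 768 weight
print(lora_param_ratio(768, 768, rank=16))

# Merging for inference: W x + s * B A x == (W + s * B A) x
torch.manual_seed(0)
in_features, out_features, rank, alpha = 8, 6, 2, 4
W = torch.randn(out_features, in_features)
A = torch.randn(rank, in_features)
B = torch.randn(out_features, rank)
scaling = alpha / rank
x = torch.randn(3, in_features)

unmerged = x @ W.T + (x @ A.T @ B.T) * scaling
merged = x @ (W + scaling * B @ A).T
print(torch.allclose(unmerged, merged, atol=1e-5))
```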
</section>
<section id="training-loop" class="level3">
<h3 class="anchored" data-anchor-id="training-loop" id="training-loop">Training Loop</h3>
<div id="8cbd91ee" class="cell" data-execution_count="7">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> train_model(model, train_loader, val_loader, num_epochs<span class="op">=</span><span class="dv">5</span>):</span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Train the VLM with comprehensive monitoring and checkpointing"""</span></span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Setup callbacks</span></span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a>    callbacks <span class="op">=</span> [</span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a>        pl.callbacks.ModelCheckpoint(</span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a>            monitor<span class="op">=</span><span class="st">'val_loss'</span>,</span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a>            mode<span class="op">=</span><span class="st">'min'</span>,</span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a>            save_top_k<span class="op">=</span><span class="dv">3</span>,</span>
<span id="cb7-10"><a href="#cb7-10" aria-hidden="true" tabindex="-1"></a>            filename<span class="op">=</span><span class="st">'</span><span class="sc">{epoch}</span><span class="st">-</span><span class="sc">{val_loss:.2f}</span><span class="st">'</span></span>
<span id="cb7-11"><a href="#cb7-11" aria-hidden="true" tabindex="-1"></a>        ),</span>
<span id="cb7-12"><a href="#cb7-12" aria-hidden="true" tabindex="-1"></a>        pl.callbacks.EarlyStopping(</span>
<span id="cb7-13"><a href="#cb7-13" aria-hidden="true" tabindex="-1"></a>            monitor<span class="op">=</span><span class="st">'val_loss'</span>,</span>
<span id="cb7-14"><a href="#cb7-14" aria-hidden="true" tabindex="-1"></a>            patience<span class="op">=</span><span class="dv">3</span>,  <span class="co"># counted in validation runs (two per epoch with val_check_interval=0.5)</span></span>
<span id="cb7-15"><a href="#cb7-15" aria-hidden="true" tabindex="-1"></a>            mode<span class="op">=</span><span class="st">'min'</span></span>
<span id="cb7-16"><a href="#cb7-16" aria-hidden="true" tabindex="-1"></a>        ),</span>
<span id="cb7-17"><a href="#cb7-17" aria-hidden="true" tabindex="-1"></a>        pl.callbacks.LearningRateMonitor(logging_interval<span class="op">=</span><span class="st">'step'</span>)</span>
<span id="cb7-18"><a href="#cb7-18" aria-hidden="true" tabindex="-1"></a>    ]</span>
<span id="cb7-19"><a href="#cb7-19" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-20"><a href="#cb7-20" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Setup trainer</span></span>
<span id="cb7-21"><a href="#cb7-21" aria-hidden="true" tabindex="-1"></a>    trainer <span class="op">=</span> pl.Trainer(</span>
<span id="cb7-22"><a href="#cb7-22" aria-hidden="true" tabindex="-1"></a>        max_epochs<span class="op">=</span>num_epochs,</span>
<span id="cb7-23"><a href="#cb7-23" aria-hidden="true" tabindex="-1"></a>        accelerator<span class="op">=</span><span class="st">'gpu'</span> <span class="cf">if</span> torch.cuda.is_available() <span class="cf">else</span> <span class="st">'cpu'</span>,</span>
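<span id="cb7-24"><a href="#cb7-24" aria-hidden="true" tabindex="-1"></a>        precision<span class="op">=</span><span class="st">'16-mixed'</span> <span class="cf">if</span> torch.cuda.is_available() <span class="cf">else</span> <span class="dv">32</span>,  <span class="co"># fp16 mixed precision requires a GPU</span></span>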
<span id="cb7-25"><a href="#cb7-25" aria-hidden="true" tabindex="-1"></a>        gradient_clip_val<span class="op">=</span><span class="fl">1.0</span>,</span>
<span id="cb7-26"><a href="#cb7-26" aria-hidden="true" tabindex="-1"></a>        accumulate_grad_batches<span class="op">=</span><span class="dv">4</span>,</span>
<span id="cb7-27"><a href="#cb7-27" aria-hidden="true" tabindex="-1"></a>        val_check_interval<span class="op">=</span><span class="fl">0.5</span>,</span>
<span id="cb7-28"><a href="#cb7-28" aria-hidden="true" tabindex="-1"></a>        callbacks<span class="op">=</span>callbacks,</span>
<span id="cb7-29"><a href="#cb7-29" aria-hidden="true" tabindex="-1"></a>        logger<span class="op">=</span>pl.loggers.TensorBoardLogger(<span class="st">'logs/'</span>)</span>
<span id="cb7-30"><a href="#cb7-30" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb7-31"><a href="#cb7-31" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-32"><a href="#cb7-32" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Train the model</span></span>
<span id="cb7-33"><a href="#cb7-33" aria-hidden="true" tabindex="-1"></a>    trainer.fit(model, train_loader, val_loader)</span>
<span id="cb7-34"><a href="#cb7-34" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-35"><a href="#cb7-35" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> trainer</span>
<span id="cb7-36"><a href="#cb7-36" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-37"><a href="#cb7-37" aria-hidden="true" tabindex="-1"></a><span class="co"># Example usage</span></span>
<span id="cb7-38"><a href="#cb7-38" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> main():</span>
<span id="cb7-39"><a href="#cb7-39" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Initialize model</span></span>
<span id="cb7-40"><a href="#cb7-40" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> VLMFineTuner(</span>
<span id="cb7-41"><a href="#cb7-41" aria-hidden="true" tabindex="-1"></a>        model_name<span class="op">=</span><span class="st">"Salesforce/blip2-opt-2.7b"</span>,</span>
<span id="cb7-42"><a href="#cb7-42" aria-hidden="true" tabindex="-1"></a>        learning_rate<span class="op">=</span><span class="fl">1e-4</span>,</span>
<span id="cb7-43"><a href="#cb7-43" aria-hidden="true" tabindex="-1"></a>        freeze_vision<span class="op">=</span><span class="va">True</span></span>
<span id="cb7-44"><a href="#cb7-44" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb7-45"><a href="#cb7-45" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-46"><a href="#cb7-46" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Create datasets</span></span>
<span id="cb7-47"><a href="#cb7-47" aria-hidden="true" tabindex="-1"></a>    train_dataset <span class="op">=</span> VisionLanguageDataset(</span>
<span id="cb7-48"><a href="#cb7-48" aria-hidden="true" tabindex="-1"></a>        <span class="st">'train_data.json'</span>, </span>
<span id="cb7-49"><a href="#cb7-49" aria-hidden="true" tabindex="-1"></a>        model.processor</span>
<span id="cb7-50"><a href="#cb7-50" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb7-51"><a href="#cb7-51" aria-hidden="true" tabindex="-1"></a>    val_dataset <span class="op">=</span> VisionLanguageDataset(</span>
<span id="cb7-52"><a href="#cb7-52" aria-hidden="true" tabindex="-1"></a>        <span class="st">'val_data.json'</span>, </span>
<span id="cb7-53"><a href="#cb7-53" aria-hidden="true" tabindex="-1"></a>        model.processor</span>
<span id="cb7-54"><a href="#cb7-54" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb7-55"><a href="#cb7-55" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-56"><a href="#cb7-56" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Create data loaders</span></span>
<span id="cb7-57"><a href="#cb7-57" aria-hidden="true" tabindex="-1"></a>    train_loader <span class="op">=</span> DataLoader(</span>
<span id="cb7-58"><a href="#cb7-58" aria-hidden="true" tabindex="-1"></a>        train_dataset, </span>
<span id="cb7-59"><a href="#cb7-59" aria-hidden="true" tabindex="-1"></a>        batch_size<span class="op">=</span><span class="dv">8</span>, </span>
<span id="cb7-60"><a href="#cb7-60" aria-hidden="true" tabindex="-1"></a>        shuffle<span class="op">=</span><span class="va">True</span>, </span>
<span id="cb7-61"><a href="#cb7-61" aria-hidden="true" tabindex="-1"></a>        num_workers<span class="op">=</span><span class="dv">4</span></span>
<span id="cb7-62"><a href="#cb7-62" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb7-63"><a href="#cb7-63" aria-hidden="true" tabindex="-1"></a>    val_loader <span class="op">=</span> DataLoader(</span>
<span id="cb7-64"><a href="#cb7-64" aria-hidden="true" tabindex="-1"></a>        val_dataset, </span>
<span id="cb7-65"><a href="#cb7-65" aria-hidden="true" tabindex="-1"></a>        batch_size<span class="op">=</span><span class="dv">8</span>, </span>
<span id="cb7-66"><a href="#cb7-66" aria-hidden="true" tabindex="-1"></a>        shuffle<span class="op">=</span><span class="va">False</span>, </span>
<span id="cb7-67"><a href="#cb7-67" aria-hidden="true" tabindex="-1"></a>        num_workers<span class="op">=</span><span class="dv">4</span></span>
<span id="cb7-68"><a href="#cb7-68" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb7-69"><a href="#cb7-69" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-70"><a href="#cb7-70" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Train model</span></span>
<span id="cb7-71"><a href="#cb7-71" aria-hidden="true" tabindex="-1"></a>    trainer <span class="op">=</span> train_model(model, train_loader, val_loader, num_epochs<span class="op">=</span><span class="dv">10</span>)</span>
<span id="cb7-72"><a href="#cb7-72" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-73"><a href="#cb7-73" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb7-74"><a href="#cb7-74" aria-hidden="true" tabindex="-1"></a>    main()</span></code></pre></div></div>
</div>
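<p>The Trainer settings above interact with each other. Gradient accumulation multiplies the effective batch size, and <code>val_check_interval=0.5</code> runs validation twice per epoch, which also affects how quickly <code>EarlyStopping</code> can fire. The helper below is a rough back-of-the-envelope sketch. It assumes a single node and no dropped last batch, and <code>training_schedule</code> is a hypothetical name.</p>

```python
import math


def training_schedule(num_samples, batch_size=8, accumulate_grad_batches=4, num_devices=1,
                      max_epochs=10, val_check_interval=0.5):
    """Rough step counts for the Trainer settings used above."""
    batches_per_epoch = math.ceil(num_samples / (batch_size * num_devices))
    # assumes a final partial accumulation window still triggers an optimizer step
    steps_per_epoch = math.ceil(batches_per_epoch / accumulate_grad_batches)
    return {
        "effective_batch_size": batch_size * accumulate_grad_batches * num_devices,
        "batches_per_epoch": batches_per_epoch,
        "optimizer_steps_per_epoch": steps_per_epoch,
        "val_runs_per_epoch": round(1 / val_check_interval),
        "max_optimizer_steps": max_epochs * steps_per_epoch,
    }


print(training_schedule(num_samples=10_000))
```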
</section>
</section>
<section id="evaluation-and-metrics" class="level2">
<h2 class="anchored" data-anchor-id="evaluation-and-metrics" id="evaluation-and-metrics">Evaluation and Metrics</h2>
<section id="task-specific-metrics" class="level3">
<h3 class="anchored" data-anchor-id="task-specific-metrics" id="task-specific-metrics">Task-specific Metrics</h3>
<div id="7f6b54d5" class="cell" data-execution_count="8">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb8"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torchmetrics.text <span class="im">import</span> BLEUScore, ROUGEScore</span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torchmetrics.retrieval <span class="im">import</span> RetrievalRecall</span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> VLMEvaluator:</span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>):</span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.bleu_1, <span class="va">self</span>.bleu_4 <span class="op">=</span> BLEUScore(n_gram<span class="op">=</span><span class="dv">1</span>), BLEUScore(n_gram<span class="op">=</span><span class="dv">4</span>)</span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.rouge <span class="op">=</span> ROUGEScore(rouge_keys<span class="op">=</span>(<span class="st">'rougeL'</span>,))</span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.recall_at_k <span class="op">=</span> RetrievalRecall(top_k<span class="op">=</span><span class="dv">5</span>)</span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> evaluate_captioning(<span class="va">self</span>, predictions, references):</span>
<span id="cb8-12"><a href="#cb8-12" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Evaluate captioning: predictions is a list of strings, references a list of lists of reference strings"""</span></span>
<span id="cb8-13"><a href="#cb8-13" aria-hidden="true" tabindex="-1"></a>        metrics <span class="op">=</span> {}</span>
<span id="cb8-14"><a href="#cb8-14" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-15"><a href="#cb8-15" aria-hidden="true" tabindex="-1"></a>        <span class="co"># BLEU scores (the n-gram order is fixed when each metric is constructed)</span></span>
<span id="cb8-16"><a href="#cb8-16" aria-hidden="true" tabindex="-1"></a>        metrics[<span class="st">'bleu_1'</span>] <span class="op">=</span> <span class="va">self</span>.bleu_1(predictions, references)</span>
<span id="cb8-17"><a href="#cb8-17" aria-hidden="true" tabindex="-1"></a>        metrics[<span class="st">'bleu_4'</span>] <span class="op">=</span> <span class="va">self</span>.bleu_4(predictions, references)</span>
<span id="cb8-18"><a href="#cb8-18" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-19"><a href="#cb8-19" aria-hidden="true" tabindex="-1"></a>        <span class="co"># ROUGE-L F-measure (ROUGEScore returns a dict of scores)</span></span>
<span id="cb8-20"><a href="#cb8-20" aria-hidden="true" tabindex="-1"></a>        metrics[<span class="st">'rouge_l'</span>] <span class="op">=</span> <span class="va">self</span>.rouge(predictions, references)[<span class="st">'rougeL_fmeasure'</span>]</span>
<span id="cb8-21"><a href="#cb8-21" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-22"><a href="#cb8-22" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> metrics</span>
<span id="cb8-23"><a href="#cb8-23" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-24"><a href="#cb8-24" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> evaluate_retrieval(<span class="va">self</span>, query_embeddings, candidate_embeddings, relevance_labels):</span>
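<span id="cb8-25"><a href="#cb8-25" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Evaluate retrieval; relevance_labels is a (num_queries, num_candidates) 0/1 bool or long tensor"""</span></span>
<span id="cb8-26"><a href="#cb8-26" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Similarity matrix: one row per query, one column per candidate</span></span>
<span id="cb8-27"><a href="#cb8-27" aria-hidden="true" tabindex="-1"></a>        similarity_scores <span class="op">=</span> query_embeddings <span class="op">@</span> candidate_embeddings.T</span>
<span id="cb8-28"><a href="#cb8-28" aria-hidden="true" tabindex="-1"></a>        indexes <span class="op">=</span> torch.arange(similarity_scores.size(<span class="dv">0</span>), device<span class="op">=</span>similarity_scores.device).unsqueeze(<span class="dv">1</span>).expand_as(similarity_scores)</span>
<span id="cb8-29"><a href="#cb8-29" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Recall@5: RetrievalRecall groups the flattened scores by query index</span></span>
<span id="cb8-30"><a href="#cb8-30" aria-hidden="true" tabindex="-1"></a>        recall <span class="op">=</span> <span class="va">self</span>.recall_at_k(similarity_scores.flatten(), relevance_labels.flatten(), indexes<span class="op">=</span>indexes.flatten())</span>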
<span id="cb8-31"><a href="#cb8-31" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-32"><a href="#cb8-32" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {<span class="st">'recall_at_5'</span>: recall}</span>
<span id="cb8-33"><a href="#cb8-33" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-34"><a href="#cb8-34" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> evaluate_vqa(<span class="va">self</span>, predictions, ground_truth):</span>
<span id="cb8-35"><a href="#cb8-35" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Evaluate Visual Question Answering performance"""</span></span>
<span id="cb8-36"><a href="#cb8-36" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Simple accuracy for classification-style VQA</span></span>
<span id="cb8-37"><a href="#cb8-37" aria-hidden="true" tabindex="-1"></a>        correct <span class="op">=</span> <span class="bu">sum</span>(p.strip().lower() <span class="op">==</span> gt.strip().lower() </span>
<span id="cb8-38"><a href="#cb8-38" aria-hidden="true" tabindex="-1"></a>                     <span class="cf">for</span> p, gt <span class="kw">in</span> <span class="bu">zip</span>(predictions, ground_truth))</span>
<span id="cb8-39"><a href="#cb8-39" aria-hidden="true" tabindex="-1"></a>        accuracy <span class="op">=</span> correct <span class="op">/</span> <span class="bu">len</span>(predictions)</span>
<span id="cb8-40"><a href="#cb8-40" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-41"><a href="#cb8-41" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {<span class="st">'accuracy'</span>: accuracy}</span>
<span id="cb8-42"><a href="#cb8-42" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-43"><a href="#cb8-43" aria-hidden="true" tabindex="-1"></a><span class="co"># Example evaluation pipeline</span></span>
<span id="cb8-44"><a href="#cb8-44" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> run_evaluation(model, test_loader, evaluator):</span>
<span id="cb8-45"><a href="#cb8-45" aria-hidden="true" tabindex="-1"></a>    model.<span class="bu">eval</span>()</span>
<span id="cb8-46"><a href="#cb8-46" aria-hidden="true" tabindex="-1"></a>    all_predictions <span class="op">=</span> []</span>
<span id="cb8-47"><a href="#cb8-47" aria-hidden="true" tabindex="-1"></a>    all_references <span class="op">=</span> []</span>
<span id="cb8-48"><a href="#cb8-48" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-49"><a href="#cb8-49" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> torch.no_grad():</span>
<span id="cb8-50"><a href="#cb8-50" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> batch <span class="kw">in</span> test_loader:</span>
<span id="cb8-51"><a href="#cb8-51" aria-hidden="true" tabindex="-1"></a>            references <span class="op">=</span> batch.pop(<span class="st">'references'</span>)  <span class="co"># text targets are not generate() inputs</span></span>
<span id="cb8-52"><a href="#cb8-52" aria-hidden="true" tabindex="-1"></a>            outputs <span class="op">=</span> model.generate(<span class="op">**</span>batch)  <span class="co"># generation settings depend on the task</span></span>
<span id="cb8-53"><a href="#cb8-53" aria-hidden="true" tabindex="-1"></a>            predictions <span class="op">=</span> model.processor.batch_decode(</span>
<span id="cb8-54"><a href="#cb8-54" aria-hidden="true" tabindex="-1"></a>                outputs, skip_special_tokens<span class="op">=</span><span class="va">True</span></span>
<span id="cb8-55"><a href="#cb8-55" aria-hidden="true" tabindex="-1"></a>            )</span>
<span id="cb8-56"><a href="#cb8-56" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb8-57"><a href="#cb8-57" aria-hidden="true" tabindex="-1"></a>            all_predictions.extend(predictions)</span>
<span id="cb8-58"><a href="#cb8-58" aria-hidden="true" tabindex="-1"></a>            all_references.extend(references)</span>
<span id="cb8-59"><a href="#cb8-59" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-60"><a href="#cb8-60" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Evaluate performance</span></span>
<span id="cb8-61"><a href="#cb8-61" aria-hidden="true" tabindex="-1"></a>    metrics <span class="op">=</span> evaluator.evaluate_captioning(all_predictions, all_references)</span>
<span id="cb8-62"><a href="#cb8-62" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-63"><a href="#cb8-63" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> metrics</span></code></pre></div></div>
</div>
</section>
<section id="cross-modal-understanding-metrics" class="level3">
<h3 class="anchored" data-anchor-id="cross-modal-understanding-metrics" id="cross-modal-understanding-metrics">Cross-modal Understanding Metrics</h3>
<p><strong>Semantic Similarity</strong>: Measuring the alignment between visual and textual representations, typically as the cosine similarity between paired image and text embeddings or with another similarity measure.</p>
<p><strong>Cross-modal Retrieval Performance</strong>: Evaluating how well the model can retrieve relevant text given an image and vice versa.</p>
<p><strong>Compositional Understanding</strong>: Testing the model’s ability to understand complex scenes with multiple objects and relationships.</p>
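<p>The first two metrics can be computed together from a batch of paired embeddings. The sketch below is a minimal example, assuming row <code>i</code> of the image and text tensors is a matched pair (as in a dual-encoder validation batch); the function name <code>cross_modal_metrics</code> is hypothetical. It reports the mean cosine similarity of matched pairs and recall@k in both retrieval directions.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import torch
import torch.nn.functional as F


def cross_modal_metrics(image_emb, text_emb, k=5):
    """Alignment and retrieval metrics for paired embeddings (row i of each tensor is a matched pair)."""
    # L2-normalize so the dot product equals cosine similarity
    image_emb = F.normalize(image_emb, dim=-1)
    text_emb = F.normalize(text_emb, dim=-1)
    sim = image_emb @ text_emb.T  # (N, N); the diagonal holds the matched pairs

    n = sim.size(0)
    targets = torch.arange(n, device=sim.device)
    k = min(k, n)

    # Image-to-text: is the matching caption among the top-k texts for each image?
    i2t_topk = sim.topk(k, dim=1).indices
    i2t_recall = (i2t_topk == targets.unsqueeze(1)).any(dim=1).float().mean()

    # Text-to-image: the same check on the transposed similarity matrix
    t2i_topk = sim.T.topk(k, dim=1).indices
    t2i_recall = (t2i_topk == targets.unsqueeze(1)).any(dim=1).float().mean()

    return {
        "mean_pair_cosine": sim.diagonal().mean().item(),
        f"i2t_recall@{k}": i2t_recall.item(),
        f"t2i_recall@{k}": t2i_recall.item(),
    }


# Usage with random stand-in embeddings (text = slightly perturbed image vectors)
torch.manual_seed(0)
img = torch.randn(8, 32)
txt = img + 0.1 * torch.randn(8, 32)
print(cross_modal_metrics(img, txt, k=1))
</code></pre></div>
<p>Normalizing first makes the diagonal mean directly interpretable as a cosine similarity; for large candidate pools, the same top-k logic can be run in chunks rather than materializing the full similarity matrix.</p>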
</section>
<section id="evaluation-protocols" class="level3">
<h3 class="anchored" data-anchor-id="evaluation-protocols" id="evaluation-protocols">Evaluation Protocols</h3>
<p><strong>Zero-shot Evaluation</strong>: Testing on unseen categories or domains without additional training.</p>
<p><strong>Few-shot Learning</strong>: Evaluating adaptation capabilities with limited examples.</p>
<p><strong>Robustness Testing</strong>: Assessing performance under various conditions such as:</p>
<ul>
<li>Different lighting conditions</li>
<li>Occlusions and partial views</li>
<li>Adversarial examples</li>
<li>Out-of-distribution data</li>
</ul>
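<p>One simple way to run the lighting, occlusion and distribution-shift checks above is to re-evaluate the same model on perturbed copies of the test images and report the accuracy drop. The sketch below assumes images are float tensors in [0, 1] with shape (N, C, H, W) and that a hypothetical <code>predict_fn</code> maps a batch to class indices (for example, a classification-style VQA head). The three perturbations are illustrative stand-ins, not a standard benchmark, and adversarial examples are not covered because they need gradient-based attacks.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import torch


def darken(images, factor=0.5):
    # Simulate poor lighting by scaling intensities down
    return (images * factor).clamp(0.0, 1.0)


def occlude_center(images, frac=0.4):
    # Zero out a centered square covering `frac` of each side
    out = images.clone()
    h, w = images.shape[-2:]
    dh, dw = int(h * frac), int(w * frac)
    top, left = (h - dh) // 2, (w - dw) // 2
    out[..., top:top + dh, left:left + dw] = 0.0
    return out


def add_noise(images, std=0.1):
    # Additive Gaussian noise as a mild out-of-distribution shift
    return (images + std * torch.randn_like(images)).clamp(0.0, 1.0)


@torch.no_grad()
def robustness_report(predict_fn, images, labels):
    """Accuracy on clean inputs and the drop under each perturbation."""
    clean_acc = (predict_fn(images) == labels).float().mean().item()
    report = {"clean": clean_acc}
    perturbations = {"darken": darken, "occlude": occlude_center, "noise": add_noise}
    for name, perturb in perturbations.items():
        acc = (predict_fn(perturb(images)) == labels).float().mean().item()
        report[name] = acc
        report[name + "_drop"] = clean_acc - acc
    return report
</code></pre></div>
<p>Reporting the drop rather than the raw accuracy makes models with different clean performance easier to compare.</p>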
</section>
</section>
<section id="common-challenges-and-solutions" class="level2">
<h2 class="anchored" data-anchor-id="common-challenges-and-solutions" id="common-challenges-and-solutions">Common Challenges and Solutions</h2>
<section id="catastrophic-forgetting" class="level3">
<h3 class="anchored" data-anchor-id="catastrophic-forgetting" id="catastrophic-forgetting">Catastrophic Forgetting</h3>
<p><strong>Problem</strong>: Fine-tuning can cause models to forget previously learned knowledge.</p>
<p><strong>Solutions</strong>:</p>
<ul>
<li>Elastic Weight Consolidation (EWC)</li>
<li>Progressive neural networks</li>
<li>Memory replay techniques</li>
<li>Regularization-based approaches</li>
</ul>
<div id="b7bcaf0a" class="cell" data-execution_count="9">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb9"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> EWCLoss(nn.Module):</span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Elastic Weight Consolidation loss for preventing catastrophic forgetting"""</span></span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, model, dataset, importance<span class="op">=</span><span class="dv">1000</span>):</span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> model</span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.importance <span class="op">=</span> importance</span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.fisher_information <span class="op">=</span> <span class="va">self</span>._compute_fisher_information(dataset)</span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.optimal_params <span class="op">=</span> {name: param.detach().clone() </span>
<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a>                              <span class="cf">for</span> name, param <span class="kw">in</span> model.named_parameters()}  <span class="co"># detached, so the penalty gradient flows only to the live weights</span></span>
<span id="cb9-11"><a href="#cb9-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-12"><a href="#cb9-12" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _compute_fisher_information(<span class="va">self</span>, dataset):</span>
<span id="cb9-13"><a href="#cb9-13" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Compute Fisher Information Matrix"""</span></span>
<span id="cb9-14"><a href="#cb9-14" aria-hidden="true" tabindex="-1"></a>        fisher <span class="op">=</span> {}</span>
<span id="cb9-15"><a href="#cb9-15" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> name, param <span class="kw">in</span> <span class="va">self</span>.model.named_parameters():</span>
<span id="cb9-16"><a href="#cb9-16" aria-hidden="true" tabindex="-1"></a>            fisher[name] <span class="op">=</span> torch.zeros_like(param)</span>
<span id="cb9-17"><a href="#cb9-17" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-18"><a href="#cb9-18" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model.<span class="bu">eval</span>()</span>
<span id="cb9-19"><a href="#cb9-19" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> batch <span class="kw">in</span> dataset:</span>
<span id="cb9-20"><a href="#cb9-20" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.model.zero_grad()</span>
<span id="cb9-21"><a href="#cb9-21" aria-hidden="true" tabindex="-1"></a>            output <span class="op">=</span> <span class="va">self</span>.model(<span class="op">**</span>batch)</span>
<span id="cb9-22"><a href="#cb9-22" aria-hidden="true" tabindex="-1"></a>            loss <span class="op">=</span> output.loss</span>
<span id="cb9-23"><a href="#cb9-23" aria-hidden="true" tabindex="-1"></a>            loss.backward()</span>
<span id="cb9-24"><a href="#cb9-24" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb9-25"><a href="#cb9-25" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> name, param <span class="kw">in</span> <span class="va">self</span>.model.named_parameters():</span>
<span id="cb9-26"><a href="#cb9-26" aria-hidden="true" tabindex="-1"></a>                <span class="cf">if</span> param.grad <span class="kw">is</span> <span class="kw">not</span> <span class="va">None</span>:</span>
<span id="cb9-27"><a href="#cb9-27" aria-hidden="true" tabindex="-1"></a>                    fisher[name] <span class="op">+=</span> param.grad.<span class="bu">pow</span>(<span class="dv">2</span>)</span>
<span id="cb9-28"><a href="#cb9-28" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-29"><a href="#cb9-29" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Average over batches (len() of a DataLoader counts batches, not samples)</span></span>
<span id="cb9-30"><a href="#cb9-30" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> name <span class="kw">in</span> fisher:</span>
<span id="cb9-31"><a href="#cb9-31" aria-hidden="true" tabindex="-1"></a>            fisher[name] <span class="op">/=</span> <span class="bu">len</span>(dataset)</span>
<span id="cb9-32"><a href="#cb9-32" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-33"><a href="#cb9-33" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> fisher</span>
<span id="cb9-34"><a href="#cb9-34" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-35"><a href="#cb9-35" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, current_loss):</span>
<span id="cb9-36"><a href="#cb9-36" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Add EWC penalty to current loss"""</span></span>
<span id="cb9-37"><a href="#cb9-37" aria-hidden="true" tabindex="-1"></a>        ewc_loss <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb9-38"><a href="#cb9-38" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> name, param <span class="kw">in</span> <span class="va">self</span>.model.named_parameters():</span>
<span id="cb9-39"><a href="#cb9-39" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> name <span class="kw">in</span> <span class="va">self</span>.fisher_information:</span>
<span id="cb9-40"><a href="#cb9-40" aria-hidden="true" tabindex="-1"></a>                ewc_loss <span class="op">+=</span> (<span class="va">self</span>.fisher_information[name] <span class="op">*</span> </span>
<span id="cb9-41"><a href="#cb9-41" aria-hidden="true" tabindex="-1"></a>                           (param <span class="op">-</span> <span class="va">self</span>.optimal_params[name]).<span class="bu">pow</span>(<span class="dv">2</span>)).<span class="bu">sum</span>()</span>
<span id="cb9-42"><a href="#cb9-42" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-43"><a href="#cb9-43" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> current_loss <span class="op">+</span> <span class="va">self</span>.importance <span class="op">*</span> ewc_loss</span></code></pre></div></div>
</div>
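<p>Memory replay, another option from the list above, can be sketched as a small random subset of earlier-task examples that is mixed into each new-task batch. The <code>ReplayBuffer</code> class below is a hypothetical helper, not part of any library; it works on lists of raw examples before collation, and the default replay ratio of 0.25 is an illustrative assumption rather than a recommended value.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import random


class ReplayBuffer:
    """Hypothetical helper: a fixed random subset of earlier-task examples for rehearsal."""

    def __init__(self, old_task_examples, capacity=1000, seed=0):
        self.rng = random.Random(seed)
        old_task_examples = list(old_task_examples)
        k = min(capacity, len(old_task_examples))
        # Uniform random subset of the previous task's data
        self.samples = self.rng.sample(old_task_examples, k)

    def mix(self, new_batch, replay_ratio=0.25):
        """Return new_batch plus about replay_ratio * len(new_batch) replayed examples, shuffled."""
        n_replay = min(len(self.samples), round(len(new_batch) * replay_ratio))
        replayed = self.rng.sample(self.samples, n_replay)
        mixed = list(new_batch) + replayed
        self.rng.shuffle(mixed)
        return mixed


# Usage: integers stand in for (image, text) examples from the old and new tasks
buffer = ReplayBuffer(range(100), capacity=10)
print(buffer.mix([1000 + i for i in range(8)]))
</code></pre></div>
<p>Replay and EWC are complementary: replay keeps old-task gradients in the loss directly, while EWC only penalizes movement of parameters that mattered for the old task.</p>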
</section>
<section id="mode-collapse" class="level3">
<h3 class="anchored" data-anchor-id="mode-collapse" id="mode-collapse">Mode Collapse</h3>
<p><strong>Problem</strong>: The model becomes overly specialized and loses diversity in outputs.</p>
<p><strong>Solutions</strong>:</p>
<ul>
<li>Diverse training data</li>
<li>Regularization techniques</li>
<li>Multi-task training</li>
<li>Curriculum learning</li>
</ul>
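<p>Loss of output diversity can be spotted during validation with a distinct-n ratio: the share of unique n-grams across a set of generated captions, which falls as outputs start repeating. The function below is a minimal sketch that uses whitespace tokenization (an assumption; a real pipeline would more likely reuse the model's tokenizer).</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">def distinct_n(texts, n=2):
    """Fraction of unique n-grams among all n-grams in a set of generated texts."""
    ngrams = []
    for text in texts:
        tokens = text.lower().split()
        # zip over shifted token lists yields consecutive n-tuples
        ngrams.extend(zip(*(tokens[i:] for i in range(n))))
    return len(set(ngrams)) / len(ngrams) if ngrams else 0.0


collapsed = ["a dog on the grass"] * 4
diverse = [
    "a dog on the grass",
    "a puppy running in a park",
    "two dogs playing fetch",
    "a brown dog lying down",
]
print(distinct_n(collapsed), distinct_n(diverse))  # 0.25 1.0
</code></pre></div>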
</section>
<section id="data-efficiency" class="level3">
<h3 class="anchored" data-anchor-id="data-efficiency" id="data-efficiency">Data Efficiency</h3>
<p><strong>Problem</strong>: Limited labeled data for specific domains or tasks.</p>
<p><strong>Solutions</strong>:</p>
<ul>
<li>Few-shot learning techniques</li>
<li>Data augmentation strategies</li>
<li>Self-supervised pre-training</li>
<li>Transfer learning from related tasks</li>
</ul>
</section>
<section id="computational-constraints" class="level3">
<h3 class="anchored" data-anchor-id="computational-constraints" id="computational-constraints">Computational Constraints</h3>
<p><strong>Problem</strong>: Limited computational resources for training large VLMs.</p>
<p><strong>Solutions</strong>:</p>
<ul>
<li>Parameter-efficient fine-tuning (LoRA, adapters)</li>
<li>Gradient checkpointing</li>
<li>Mixed precision training</li>
<li>Model pruning and quantization</li>
</ul>
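<p>Gradient checkpointing and mixed precision can be combined in plain PyTorch. In the sketch below, a toy residual MLP (<code>Block</code> and <code>CheckpointedMLP</code> are hypothetical stand-ins for a VLM's transformer layers) recomputes each block's activations during the backward pass via <code>torch.utils.checkpoint.checkpoint</code> instead of storing them, and the forward pass runs under <code>torch.autocast</code>. It uses bfloat16 on CPU so it runs anywhere; on a GPU you would typically use <code>device_type="cuda"</code> and, with float16, add a gradient scaler.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        return x + self.net(x)


class CheckpointedMLP(nn.Module):
    def __init__(self, dim=64, depth=4, use_checkpointing=True):
        super().__init__()
        self.blocks = nn.ModuleList([Block(dim) for _ in range(depth)])
        self.head = nn.Linear(dim, 1)
        self.use_checkpointing = use_checkpointing

    def forward(self, x):
        for block in self.blocks:
            if self.use_checkpointing and self.training:
                # Activations inside the block are recomputed during backward
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return self.head(x)


def train_step(model, optimizer, x, y):
    model.train()
    optimizer.zero_grad()
    # Linear layers run in bfloat16; parameters and optimizer state stay float32
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        pred = model(x)
    loss = nn.functional.mse_loss(pred.float(), y)
    loss.backward()
    optimizer.step()
    return loss.item()
</code></pre></div>
<p>Checkpointing trades extra forward compute for lower activation memory, so it pays off mainly when activations, not weights, dominate memory use.</p>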
</section>
<section id="evaluation-challenges" class="level3">
<h3 class="anchored" data-anchor-id="evaluation-challenges" id="evaluation-challenges">Evaluation Challenges</h3>
<p><strong>Problem</strong>: Difficulty in comprehensively evaluating multimodal understanding.</p>
<p><strong>Solutions</strong>:</p>
<ul>
<li>Multi-faceted evaluation frameworks</li>
<li>Human evaluation protocols</li>
<li>Automated evaluation metrics</li>
<li>Benchmark development</li>
</ul>
</section>
</section>
<section id="best-practices" class="level2">
<h2 class="anchored" data-anchor-id="best-practices" id="best-practices">Best Practices</h2>
<section id="model-selection" class="level3">
<h3 class="anchored" data-anchor-id="model-selection" id="model-selection">Model Selection</h3>
<p>Choose the appropriate base model based on:</p>
<ul>
<li>Task requirements and complexity</li>
<li>Available computational resources</li>
<li>Target domain characteristics</li>
<li>Performance-efficiency trade-offs</li>
</ul>
</section>
<section id="hyperparameter-optimization" class="level3">
<h3 class="anchored" data-anchor-id="hyperparameter-optimization" id="hyperparameter-optimization">Hyperparameter Optimization</h3>
<div id="0a351b50" class="cell" data-execution_count="10">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb10"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> optuna</span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> optuna.integration <span class="im">import</span> PyTorchLightningPruningCallback  <span class="co"># newer Optuna releases may need the separate optuna-integration package</span></span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> objective(trial):</span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Optuna objective function for hyperparameter optimization"""</span></span>
<span id="cb10-6"><a href="#cb10-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-7"><a href="#cb10-7" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Suggest hyperparameters</span></span>
<span id="cb10-8"><a href="#cb10-8" aria-hidden="true" tabindex="-1"></a>    learning_rate <span class="op">=</span> trial.suggest_float(<span class="st">'learning_rate'</span>, <span class="fl">1e-5</span>, <span class="fl">1e-3</span>, log<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb10-9"><a href="#cb10-9" aria-hidden="true" tabindex="-1"></a>    batch_size <span class="op">=</span> trial.suggest_categorical(<span class="st">'batch_size'</span>, [<span class="dv">4</span>, <span class="dv">8</span>, <span class="dv">16</span>, <span class="dv">32</span>])</span>
<span id="cb10-10"><a href="#cb10-10" aria-hidden="true" tabindex="-1"></a>    rank <span class="op">=</span> trial.suggest_int(<span class="st">'lora_rank'</span>, <span class="dv">8</span>, <span class="dv">64</span>, step<span class="op">=</span><span class="dv">8</span>)</span>
<span id="cb10-11"><a href="#cb10-11" aria-hidden="true" tabindex="-1"></a>    alpha <span class="op">=</span> trial.suggest_int(<span class="st">'lora_alpha'</span>, <span class="dv">8</span>, <span class="dv">64</span>, step<span class="op">=</span><span class="dv">8</span>)</span>
<span id="cb10-12"><a href="#cb10-12" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-13"><a href="#cb10-13" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Create model with suggested hyperparameters</span></span>
<span id="cb10-14"><a href="#cb10-14" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> VLMFineTuner(</span>
<span id="cb10-15"><a href="#cb10-15" aria-hidden="true" tabindex="-1"></a>        model_name<span class="op">=</span><span class="st">"Salesforce/blip2-opt-2.7b"</span>,</span>
<span id="cb10-16"><a href="#cb10-16" aria-hidden="true" tabindex="-1"></a>        learning_rate<span class="op">=</span>learning_rate</span>
<span id="cb10-17"><a href="#cb10-17" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb10-18"><a href="#cb10-18" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-19"><a href="#cb10-19" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Apply LoRA with suggested parameters</span></span>
<span id="cb10-20"><a href="#cb10-20" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> apply_lora_to_model(model, rank<span class="op">=</span>rank, alpha<span class="op">=</span>alpha)</span>
<span id="cb10-21"><a href="#cb10-21" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-22"><a href="#cb10-22" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Create data loaders with suggested batch size</span></span>
<span id="cb10-23"><a href="#cb10-23" aria-hidden="true" tabindex="-1"></a>    train_loader <span class="op">=</span> DataLoader(train_dataset, batch_size<span class="op">=</span>batch_size, shuffle<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb10-24"><a href="#cb10-24" aria-hidden="true" tabindex="-1"></a>    val_loader <span class="op">=</span> DataLoader(val_dataset, batch_size<span class="op">=</span>batch_size, shuffle<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb10-25"><a href="#cb10-25" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-26"><a href="#cb10-26" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Setup trainer with pruning callback</span></span>
<span id="cb10-27"><a href="#cb10-27" aria-hidden="true" tabindex="-1"></a>    trainer <span class="op">=</span> pl.Trainer(</span>
<span id="cb10-28"><a href="#cb10-28" aria-hidden="true" tabindex="-1"></a>        max_epochs<span class="op">=</span><span class="dv">5</span>,</span>
<span id="cb10-29"><a href="#cb10-29" aria-hidden="true" tabindex="-1"></a>        callbacks<span class="op">=</span>[PyTorchLightningPruningCallback(trial, monitor<span class="op">=</span><span class="st">"val_loss"</span>)],</span>
<span id="cb10-30"><a href="#cb10-30" aria-hidden="true" tabindex="-1"></a>        logger<span class="op">=</span><span class="va">False</span>,</span>
<span id="cb10-31"><a href="#cb10-31" aria-hidden="true" tabindex="-1"></a>        enable_checkpointing<span class="op">=</span><span class="va">False</span></span>
<span id="cb10-32"><a href="#cb10-32" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb10-33"><a href="#cb10-33" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-34"><a href="#cb10-34" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Train and return validation loss</span></span>
<span id="cb10-35"><a href="#cb10-35" aria-hidden="true" tabindex="-1"></a>    trainer.fit(model, train_loader, val_loader)</span>
<span id="cb10-36"><a href="#cb10-36" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-37"><a href="#cb10-37" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> trainer.callback_metrics[<span class="st">"val_loss"</span>].item()</span>
<span id="cb10-38"><a href="#cb10-38" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-39"><a href="#cb10-39" aria-hidden="true" tabindex="-1"></a><span class="co"># Run optimization</span></span>
<span id="cb10-40"><a href="#cb10-40" aria-hidden="true" tabindex="-1"></a>study <span class="op">=</span> optuna.create_study(direction<span class="op">=</span><span class="st">'minimize'</span>)</span>
<span id="cb10-41"><a href="#cb10-41" aria-hidden="true" tabindex="-1"></a>study.optimize(objective, n_trials<span class="op">=</span><span class="dv">50</span>)</span>
<span id="cb10-42"><a href="#cb10-42" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-43"><a href="#cb10-43" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"Best hyperparameters:"</span>, study.best_params)</span></code></pre></div></div>
</div>
</section>
<section id="data-management" class="level3">
<h3 class="anchored" data-anchor-id="data-management" id="data-management">Data Management</h3>
<p>Implement robust data pipelines:</p>
<ul>
<li>Version control for datasets</li>
<li>Data quality validation</li>
<li>Efficient data loading and preprocessing</li>
<li>Balanced sampling strategies</li>
</ul>
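<p>For balanced sampling, PyTorch's <code>WeightedRandomSampler</code> can oversample rare classes or rare task types in a multi-task mix. The sketch below weights each example by the inverse frequency of its label; the helper name <code>make_balanced_loader</code> and the toy 90/10 label split are illustrative assumptions.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">from collections import Counter

import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler


def make_balanced_loader(dataset, labels, batch_size=32, seed=0):
    """DataLoader that draws each label with roughly equal probability."""
    counts = Counter(labels)
    # Inverse-frequency weight per example: rare labels are drawn more often
    weights = torch.tensor([1.0 / counts[label] for label in labels], dtype=torch.double)
    generator = torch.Generator().manual_seed(seed)
    sampler = WeightedRandomSampler(weights, num_samples=len(labels), replacement=True, generator=generator)
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)


# Usage: 90 examples with label 0 and 10 with label 1
labels = [0] * 90 + [1] * 10
dataset = TensorDataset(torch.arange(100), torch.tensor(labels))
loader = make_balanced_loader(dataset, labels)
drawn = torch.cat([y for _, y in loader])
print(drawn.float().mean())  # share of label-1 draws; expected near 0.5 rather than 0.1
</code></pre></div>
<p>Because sampling is with replacement, rare examples repeat within an epoch, so pairing this with data augmentation helps limit overfitting to them.</p>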
</section>
<section id="monitoring-and-debugging" class="level3">
<h3 class="anchored" data-anchor-id="monitoring-and-debugging" id="monitoring-and-debugging">Monitoring and Debugging</h3>
<div id="b0117574" class="cell" data-execution_count="11">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb11"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> wandb</span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> pytorch_lightning.loggers <span class="im">import</span> WandbLogger</span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> AdvancedVLMTrainer(VLMFineTuner):  <span class="co"># assumes VLMFineTuner is a LightningModule that sets self.model</span></span>
<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, <span class="op">*</span>args, <span class="op">**</span>kwargs):</span>
<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>(<span class="op">*</span>args, <span class="op">**</span>kwargs)</span>
<span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.validation_outputs <span class="op">=</span> []</span>
<span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-9"><a href="#cb11-9" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> training_step(<span class="va">self</span>, batch, batch_idx):</span>
<span id="cb11-10"><a href="#cb11-10" aria-hidden="true" tabindex="-1"></a>        outputs <span class="op">=</span> <span class="va">self</span>.model(<span class="op">**</span>batch)</span>
<span id="cb11-11"><a href="#cb11-11" aria-hidden="true" tabindex="-1"></a>        loss <span class="op">=</span> outputs.loss</span>
<span id="cb11-12"><a href="#cb11-12" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-13"><a href="#cb11-13" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Log detailed metrics</span></span>
<span id="cb11-14"><a href="#cb11-14" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.log(<span class="st">'train_loss'</span>, loss, prog_bar<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb11-15"><a href="#cb11-15" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.log(<span class="st">'learning_rate'</span>, <span class="va">self</span>.optimizers().param_groups[<span class="dv">0</span>][<span class="st">'lr'</span>])</span>
<span id="cb11-16"><a href="#cb11-16" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> loss</span>
<span id="cb11-17"><a href="#cb11-17" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-18"><a href="#cb11-18" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> on_before_optimizer_step(<span class="va">self</span>, optimizer):</span>
<span id="cb11-19"><a href="#cb11-19" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Gradients are populated here (after backward, before the step), unlike in training_step</span></span>
<span id="cb11-20"><a href="#cb11-20" aria-hidden="true" tabindex="-1"></a>        total_norm <span class="op">=</span> <span class="fl">0.0</span></span>
<span id="cb11-21"><a href="#cb11-21" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> p <span class="kw">in</span> <span class="va">self</span>.parameters():</span>
<span id="cb11-22"><a href="#cb11-22" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> p.grad <span class="kw">is</span> <span class="kw">not</span> <span class="va">None</span>:</span>
<span id="cb11-23"><a href="#cb11-23" aria-hidden="true" tabindex="-1"></a>                param_norm <span class="op">=</span> p.grad.detach().norm(<span class="dv">2</span>)</span>
<span id="cb11-24"><a href="#cb11-24" aria-hidden="true" tabindex="-1"></a>                total_norm <span class="op">+=</span> param_norm.item() <span class="op">**</span> <span class="dv">2</span></span>
<span id="cb11-25"><a href="#cb11-25" aria-hidden="true" tabindex="-1"></a>        total_norm <span class="op">=</span> total_norm <span class="op">**</span> <span class="fl">0.5</span></span>
<span id="cb11-26"><a href="#cb11-26" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.log(<span class="st">'gradient_norm'</span>, total_norm)</span>
<span id="cb11-27"><a href="#cb11-27" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-28"><a href="#cb11-28" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> validation_step(<span class="va">self</span>, batch, batch_idx):</span>
<span id="cb11-29"><a href="#cb11-29" aria-hidden="true" tabindex="-1"></a>        outputs <span class="op">=</span> <span class="va">self</span>.model(<span class="op">**</span>batch)</span>
<span id="cb11-30"><a href="#cb11-30" aria-hidden="true" tabindex="-1"></a>        loss <span class="op">=</span> outputs.loss</span>
<span id="cb11-31"><a href="#cb11-31" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.log(<span class="st">'val_loss'</span>, loss, prog_bar<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb11-32"><a href="#cb11-32" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Causal LM: logits at position t predict the label at t + 1; -100 marks ignored tokens.</span></span>
<span id="cb11-33"><a href="#cb11-33" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Assumes logits and labels have the same length (no prepended visual tokens).</span></span>
<span id="cb11-34"><a href="#cb11-34" aria-hidden="true" tabindex="-1"></a>        preds <span class="op">=</span> outputs.logits[:, :<span class="op">-</span><span class="dv">1</span>].argmax(dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb11-35"><a href="#cb11-35" aria-hidden="true" tabindex="-1"></a>        targets <span class="op">=</span> batch[<span class="st">'labels'</span>][:, <span class="dv">1</span>:]</span>
<span id="cb11-36"><a href="#cb11-36" aria-hidden="true" tabindex="-1"></a>        mask <span class="op">=</span> targets <span class="op">!=</span> <span class="op">-</span><span class="dv">100</span></span>
<span id="cb11-37"><a href="#cb11-37" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.validation_outputs.append({<span class="st">'loss'</span>: loss, <span class="st">'predictions'</span>: preds[mask], <span class="st">'targets'</span>: targets[mask]})</span>
<span id="cb11-38"><a href="#cb11-38" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-39"><a href="#cb11-39" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> loss</span>
<span id="cb11-40"><a href="#cb11-40" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-41"><a href="#cb11-41" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> on_validation_epoch_end(<span class="va">self</span>):</span>
<span id="cb11-42"><a href="#cb11-42" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Compute additional metrics</span></span>
<span id="cb11-43"><a href="#cb11-43" aria-hidden="true" tabindex="-1"></a>        all_preds <span class="op">=</span> torch.cat([x[<span class="st">'predictions'</span>] <span class="cf">for</span> x <span class="kw">in</span> <span class="va">self</span>.validation_outputs])</span>
<span id="cb11-44"><a href="#cb11-44" aria-hidden="true" tabindex="-1"></a>        all_targets <span class="op">=</span> torch.cat([x[<span class="st">'targets'</span>] <span class="cf">for</span> x <span class="kw">in</span> <span class="va">self</span>.validation_outputs])</span>
<span id="cb11-45"><a href="#cb11-45" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-46"><a href="#cb11-46" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Example: compute accuracy</span></span>
<span id="cb11-47"><a href="#cb11-47" aria-hidden="true" tabindex="-1"></a>        accuracy <span class="op">=</span> (all_preds <span class="op">==</span> all_targets).<span class="bu">float</span>().mean()</span>
<span id="cb11-48"><a href="#cb11-48" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.log(<span class="st">'val_accuracy'</span>, accuracy)</span>
<span id="cb11-49"><a href="#cb11-49" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-50"><a href="#cb11-50" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Clear validation outputs</span></span>
<span id="cb11-51"><a href="#cb11-51" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.validation_outputs.clear()</span>
<span id="cb11-52"><a href="#cb11-52" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-53"><a href="#cb11-53" aria-hidden="true" tabindex="-1"></a><span class="co"># Setup advanced logging</span></span>
<span id="cb11-54"><a href="#cb11-54" aria-hidden="true" tabindex="-1"></a>wandb_logger <span class="op">=</span> WandbLogger(project<span class="op">=</span><span class="st">"vlm-finetuning"</span>)</span>
<span id="cb11-55"><a href="#cb11-55" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-56"><a href="#cb11-56" aria-hidden="true" tabindex="-1"></a>trainer <span class="op">=</span> pl.Trainer(</span>
<span id="cb11-57"><a href="#cb11-57" aria-hidden="true" tabindex="-1"></a>    logger<span class="op">=</span>wandb_logger,</span>
<span id="cb11-58"><a href="#cb11-58" aria-hidden="true" tabindex="-1"></a>    callbacks<span class="op">=</span>[</span>
<span id="cb11-59"><a href="#cb11-59" aria-hidden="true" tabindex="-1"></a>        pl.callbacks.ModelCheckpoint(monitor<span class="op">=</span><span class="st">'val_loss'</span>),</span>
<span id="cb11-60"><a href="#cb11-60" aria-hidden="true" tabindex="-1"></a>        pl.callbacks.LearningRateMonitor()</span>
<span id="cb11-61"><a href="#cb11-61" aria-hidden="true" tabindex="-1"></a>    ]</span>
<span id="cb11-62"><a href="#cb11-62" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
</div>
</section>
<section id="reproducibility" class="level3">
<h3 class="anchored" data-anchor-id="reproducibility" id="reproducibility">Reproducibility</h3>
<p>Seed every source of randomness at the start of each run so that experiments can be reproduced:</p>
<div id="a6f76cff" class="cell" data-execution_count="12">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb12"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><a href="#cb12-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> os</span>
<span id="cb12-2"><a href="#cb12-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> random</span>
<span id="cb12-3"><a href="#cb12-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-4"><a href="#cb12-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb12-5"><a href="#cb12-5" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb12-6"><a href="#cb12-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-7"><a href="#cb12-7" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> set_seed(seed<span class="op">=</span><span class="dv">42</span>):</span>
<span id="cb12-8"><a href="#cb12-8" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Seed Python, NumPy and PyTorch RNGs for reproducible runs"""</span></span>
<span id="cb12-9"><a href="#cb12-9" aria-hidden="true" tabindex="-1"></a>    random.seed(seed)</span>
<span id="cb12-10"><a href="#cb12-10" aria-hidden="true" tabindex="-1"></a>    np.random.seed(seed)</span>
<span id="cb12-11"><a href="#cb12-11" aria-hidden="true" tabindex="-1"></a>    torch.manual_seed(seed)  <span class="co"># seeds the CPU and all CUDA devices</span></span>
<span id="cb12-12"><a href="#cb12-12" aria-hidden="true" tabindex="-1"></a>    torch.cuda.manual_seed_all(seed)  <span class="co"># explicit for multi-GPU; silently ignored without CUDA</span></span>
<span id="cb12-13"><a href="#cb12-13" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Trade cuDNN autotuning speed for deterministic kernel selection</span></span>
<span id="cb12-14"><a href="#cb12-14" aria-hidden="true" tabindex="-1"></a>    torch.backends.cudnn.deterministic <span class="op">=</span> <span class="va">True</span></span>
<span id="cb12-15"><a href="#cb12-15" aria-hidden="true" tabindex="-1"></a>    torch.backends.cudnn.benchmark <span class="op">=</span> <span class="va">False</span></span>
<span id="cb12-16"><a href="#cb12-16" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Only affects child processes started later; the running interpreter's</span></span>
<span id="cb12-17"><a href="#cb12-17" aria-hidden="true" tabindex="-1"></a>    <span class="co"># hash seed was already fixed at startup</span></span>
<span id="cb12-18"><a href="#cb12-18" aria-hidden="true" tabindex="-1"></a>    os.environ[<span class="st">'PYTHONHASHSEED'</span>] <span class="op">=</span> <span class="bu">str</span>(seed)</span>
<span id="cb12-19"><a href="#cb12-19" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-20"><a href="#cb12-20" aria-hidden="true" tabindex="-1"></a><span class="co"># Set seed at the beginning of experiments</span></span>
<span id="cb12-21"><a href="#cb12-21" aria-hidden="true" tabindex="-1"></a>set_seed(<span class="dv">42</span>)</span></code></pre></div></div>
</div>
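<p>PyTorch's reproducibility notes additionally recommend seeding DataLoader worker processes and giving the loader its own <code>torch.Generator</code>, since <code>set_seed</code> covers neither the per-worker RNG state nor the shuffle order explicitly. The sketch below follows that recipe; the toy dataset and the <code>make_loader</code> helper are hypothetical stand-ins for the real image-text pipeline. When training through Lightning, <code>pl.seed_everything(42, workers=True)</code> is intended to handle the worker-seeding part.</p>

```python
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def seed_worker(worker_id):
    # Each worker receives a distinct torch seed derived from the base seed;
    # reuse it for NumPy and Python's random module inside that worker.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def make_loader(dataset, seed=42, batch_size=4, num_workers=0):
    # A dedicated generator makes the shuffle order depend only on `seed`
    g = torch.Generator()
    g.manual_seed(seed)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        worker_init_fn=seed_worker,
        generator=g,
    )


# Hypothetical toy dataset standing in for real image-text pairs
toy = TensorDataset(torch.arange(12))
order_a = [x.tolist() for (x,) in make_loader(toy)]
order_b = [x.tolist() for (x,) in make_loader(toy)]
print(order_a == order_b)  # same seed, same shuffle order
```

<p>Two loaders built with the same seed should iterate the data in the same order, independently of any other code that consumed the global RNG in between.</p>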
</section>
</section>
<section id="case-studies" class="level2">
<h2 class="anchored" data-anchor-id="case-studies" id="case-studies">Case Studies</h2>
<section id="case-study-1-medical-image-analysis" class="level3">
<h3 class="anchored" data-anchor-id="case-study-1-medical-image-analysis" id="case-study-1-medical-image-analysis">Case Study 1: Medical Image Analysis</h3>
<p><strong>Objective</strong>: Fine-tune a VLM for radiology report generation.</p>
<p><strong>Approach</strong>:</p>
<ul>
<li>Base model: BLIP-2</li>
<li>Dataset: MIMIC-CXR with chest X-rays and reports</li>
<li>Fine-tuning strategy: LoRA with frozen vision encoder</li>
<li>Evaluation: BLEU, ROUGE, clinical accuracy metrics</li>
</ul>
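<p>To see why LoRA keeps this setup cheap, compare trainable parameters for a single weight matrix. LoRA freezes <code>W</code> (shape d_out × d_in) and learns a low-rank update <code>BA</code>, with <code>B</code> of shape d_out × r and <code>A</code> of shape r × d_in, i.e. r · (d_in + d_out) parameters instead of d_in · d_out. The 4096 × 4096 projection and rank 8 below are illustrative values, not figures reported for this case study.</p>

```python
def full_params(d_in, d_out):
    # Full fine-tuning updates every entry of the weight matrix
    return d_in * d_out


def lora_params(d_in, d_out, rank):
    # LoRA trains B (d_out x rank) and A (rank x d_in); W itself stays frozen
    return rank * (d_in + d_out)


d = 4096  # illustrative hidden size of one attention projection
full = full_params(d, d)
lora = lora_params(d, d, rank=8)
print(full, lora, f"{lora / full:.2%}")  # 16777216 65536 0.39%
```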
<div id="72d9a8ed" class="cell" data-execution_count="13">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb13"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Medical domain-specific preprocessing</span></span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MedicalImageProcessor:</span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, processor):</span>
<span id="cb13-4"><a href="#cb13-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.processor <span class="op">=</span> processor</span>
<span id="cb13-5"><a href="#cb13-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.medical_vocab <span class="op">=</span> <span class="va">self</span>._load_medical_vocabulary()</span>
<span id="cb13-6"><a href="#cb13-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-7"><a href="#cb13-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _load_medical_vocabulary(<span class="va">self</span>):</span>
<span id="cb13-8"><a href="#cb13-8" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Load medical terminology and abbreviations"""</span></span>
<span id="cb13-9"><a href="#cb13-9" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {</span>
<span id="cb13-10"><a href="#cb13-10" aria-hidden="true" tabindex="-1"></a>            <span class="st">'CXR'</span>: <span class="st">'chest X-ray'</span>,</span>
<span id="cb13-11"><a href="#cb13-11" aria-hidden="true" tabindex="-1"></a>            <span class="st">'AP'</span>: <span class="st">'anteroposterior'</span>,</span>
<span id="cb13-12"><a href="#cb13-12" aria-hidden="true" tabindex="-1"></a>            <span class="st">'PA'</span>: <span class="st">'posteroanterior'</span>,</span>
<span id="cb13-13"><a href="#cb13-13" aria-hidden="true" tabindex="-1"></a>            <span class="co"># ... more medical terms</span></span>
<span id="cb13-14"><a href="#cb13-14" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb13-15"><a href="#cb13-15" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-16"><a href="#cb13-16" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> preprocess_report(<span class="va">self</span>, report):</span>
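<span id="cb13-17"><a href="#cb13-17" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Expand medical abbreviations (whole-word matches only)"""</span></span>
<span id="cb13-17a"><a href="#cb13-17a" aria-hidden="true" tabindex="-1"></a>        <span class="im">import</span> re</span>
<span id="cb13-18"><a href="#cb13-18" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> abbrev, full_form <span class="kw">in</span> <span class="va">self</span>.medical_vocab.items():</span>
<span id="cb13-18a"><a href="#cb13-18a" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Word boundaries stop 'AP' from being rewritten inside words such as 'GAP'</span></span>
<span id="cb13-19"><a href="#cb13-19" aria-hidden="true" tabindex="-1"></a>            report <span class="op">=</span> re.sub(<span class="vs">rf'\b</span><span class="sc">{</span>re.escape(abbrev)<span class="sc">}</span><span class="vs">\b'</span>, full_form, report)</span>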
<span id="cb13-20"><a href="#cb13-20" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> report</span>
<span id="cb13-21"><a href="#cb13-21" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-22"><a href="#cb13-22" aria-hidden="true" tabindex="-1"></a><span class="co"># Specialized evaluation for medical domain</span></span>
<span id="cb13-23"><a href="#cb13-23" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MedicalEvaluator:</span>
<span id="cb13-24"><a href="#cb13-24" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>):</span>
<span id="cb13-25"><a href="#cb13-25" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.clinical_keywords <span class="op">=</span> [</span>
<span id="cb13-26"><a href="#cb13-26" aria-hidden="true" tabindex="-1"></a>            <span class="st">'pneumonia'</span>, <span class="st">'pneumothorax'</span>, <span class="st">'pleural effusion'</span>,</span>
<span id="cb13-27"><a href="#cb13-27" aria-hidden="true" tabindex="-1"></a>            <span class="st">'cardiomegaly'</span>, <span class="st">'atelectasis'</span></span>
<span id="cb13-28"><a href="#cb13-28" aria-hidden="true" tabindex="-1"></a>        ]</span>
<span id="cb13-29"><a href="#cb13-29" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-30"><a href="#cb13-30" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> evaluate_clinical_accuracy(<span class="va">self</span>, predictions, references):</span>
<span id="cb13-31"><a href="#cb13-31" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Evaluate clinical finding detection accuracy"""</span></span>
<span id="cb13-32"><a href="#cb13-32" aria-hidden="true" tabindex="-1"></a>        accuracy_scores <span class="op">=</span> {}</span>
<span id="cb13-33"><a href="#cb13-33" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-34"><a href="#cb13-34" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> keyword <span class="kw">in</span> <span class="va">self</span>.clinical_keywords:</span>
<span id="cb13-35"><a href="#cb13-35" aria-hidden="true" tabindex="-1"></a>            pred_positive <span class="op">=</span> [keyword.lower() <span class="kw">in</span> pred.lower() <span class="cf">for</span> pred <span class="kw">in</span> predictions]</span>
<span id="cb13-36"><a href="#cb13-36" aria-hidden="true" tabindex="-1"></a>            ref_positive <span class="op">=</span> [keyword.lower() <span class="kw">in</span> ref.lower() <span class="cf">for</span> ref <span class="kw">in</span> references]</span>
<span id="cb13-37"><a href="#cb13-37" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb13-38"><a href="#cb13-38" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Calculate precision, recall, F1</span></span>
<span id="cb13-39"><a href="#cb13-39" aria-hidden="true" tabindex="-1"></a>            tp <span class="op">=</span> <span class="bu">sum</span>(p <span class="kw">and</span> r <span class="cf">for</span> p, r <span class="kw">in</span> <span class="bu">zip</span>(pred_positive, ref_positive))</span>
<span id="cb13-40"><a href="#cb13-40" aria-hidden="true" tabindex="-1"></a>            fp <span class="op">=</span> <span class="bu">sum</span>(p <span class="kw">and</span> <span class="kw">not</span> r <span class="cf">for</span> p, r <span class="kw">in</span> <span class="bu">zip</span>(pred_positive, ref_positive))</span>
<span id="cb13-41"><a href="#cb13-41" aria-hidden="true" tabindex="-1"></a>            fn <span class="op">=</span> <span class="bu">sum</span>(<span class="kw">not</span> p <span class="kw">and</span> r <span class="cf">for</span> p, r <span class="kw">in</span> <span class="bu">zip</span>(pred_positive, ref_positive))</span>
<span id="cb13-42"><a href="#cb13-42" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb13-43"><a href="#cb13-43" aria-hidden="true" tabindex="-1"></a>            precision <span class="op">=</span> tp <span class="op">/</span> (tp <span class="op">+</span> fp) <span class="cf">if</span> (tp <span class="op">+</span> fp) <span class="op">&gt;</span> <span class="dv">0</span> <span class="cf">else</span> <span class="dv">0</span></span>
<span id="cb13-44"><a href="#cb13-44" aria-hidden="true" tabindex="-1"></a>            recall <span class="op">=</span> tp <span class="op">/</span> (tp <span class="op">+</span> fn) <span class="cf">if</span> (tp <span class="op">+</span> fn) <span class="op">&gt;</span> <span class="dv">0</span> <span class="cf">else</span> <span class="dv">0</span></span>
<span id="cb13-45"><a href="#cb13-45" aria-hidden="true" tabindex="-1"></a>            f1 <span class="op">=</span> <span class="dv">2</span> <span class="op">*</span> precision <span class="op">*</span> recall <span class="op">/</span> (precision <span class="op">+</span> recall) <span class="cf">if</span> (precision <span class="op">+</span> recall) <span class="op">&gt;</span> <span class="dv">0</span> <span class="cf">else</span> <span class="dv">0</span></span>
<span id="cb13-46"><a href="#cb13-46" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb13-47"><a href="#cb13-47" aria-hidden="true" tabindex="-1"></a>            accuracy_scores[keyword] <span class="op">=</span> {</span>
<span id="cb13-48"><a href="#cb13-48" aria-hidden="true" tabindex="-1"></a>                <span class="st">'precision'</span>: precision,</span>
<span id="cb13-49"><a href="#cb13-49" aria-hidden="true" tabindex="-1"></a>                <span class="st">'recall'</span>: recall,</span>
<span id="cb13-50"><a href="#cb13-50" aria-hidden="true" tabindex="-1"></a>                <span class="st">'f1'</span>: f1</span>
<span id="cb13-51"><a href="#cb13-51" aria-hidden="true" tabindex="-1"></a>            }</span>
<span id="cb13-52"><a href="#cb13-52" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-53"><a href="#cb13-53" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> accuracy_scores</span></code></pre></div></div>
</div>
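<p>The substring test in <code>evaluate_clinical_accuracy</code> counts a finding as present even when a report says "no pneumothorax", which is routine phrasing in radiology. A rough, NegEx-inspired sketch of a negation-aware check follows; the cue list and the five-word window are assumptions, and this is far from a validated clinical negation detector.</p>

```python
import re

# Hypothetical negation cues; real systems such as NegEx use much larger lists
NEGATION_CUES = ("no", "without", "negative for", "free of", "absence of")


def mentions_finding(text, keyword, window=5):
    """True if `keyword` appears in some sentence without a negation cue
    among the `window` words that precede it in that sentence."""
    for sentence in re.split(r"[.;\n]", text.lower()):
        idx = sentence.find(keyword.lower())
        if idx == -1:
            continue
        context = " ".join(sentence[:idx].split()[-window:])
        negated = any(
            re.search(rf"\b{re.escape(cue)}\b", context) for cue in NEGATION_CUES
        )
        if not negated:
            return True
    return False


print(mentions_finding("No pneumothorax or pleural effusion.", "pleural effusion"))
print(mentions_finding("No pneumonia. Mild cardiomegaly.", "cardiomegaly"))
```

<p>Swapping this in for <code>keyword.lower() in pred.lower()</code> would keep the precision/recall/F1 bookkeeping unchanged while no longer crediting negated mentions as positives.</p>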
<p><strong>Results</strong>: Achieved a 15% improvement in clinical accuracy while maintaining general language capabilities.</p>
<p><strong>Key Insights</strong>:</p>
<ul>
<li>Domain-specific vocabulary required careful handling</li>
<li>Multi-task training with classification improved performance</li>
<li>Regular validation with medical experts was crucial</li>
</ul>
</section>
<section id="case-study-2-e-commerce-product-description" class="level3">
<h3 class="anchored" data-anchor-id="case-study-2-e-commerce-product-description" id="case-study-2-e-commerce-product-description">Case Study 2: E-commerce Product Description</h3>
<p><strong>Objective</strong>: Automatically generate product descriptions from product images.</p>
<p><strong>Approach</strong>:</p>
<ul>
<li>Base model: LLaVA</li>
<li>Dataset: Custom e-commerce image-description pairs</li>
<li>Fine-tuning strategy: Full fine-tuning with curriculum learning</li>
<li>Evaluation: Human preference scores, conversion metrics</li>
</ul>
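<p>The case study does not say how its curriculum was defined. One common pattern, sketched here under the assumption that description length is a usable difficulty proxy, is to rank examples from easy to hard and widen the training pool in stages. The records and the <code>curriculum_stages</code> helper are hypothetical.</p>

```python
def curriculum_stages(examples, difficulty, num_stages=3):
    """Yield progressively larger training pools, easiest examples first.

    `difficulty` is any callable that scores an example; higher means harder.
    """
    ranked = sorted(examples, key=difficulty)
    for stage in range(1, num_stages + 1):
        # Stage k trains on the easiest k/num_stages fraction of the data
        cutoff = round(len(ranked) * stage / num_stages)
        yield ranked[:cutoff]


# Hypothetical records; word count of the description stands in for difficulty
data = [
    {"id": 1, "description": "Red cotton t-shirt with a crew neck and short sleeves."},
    {"id": 2, "description": "Blue mug."},
    {"id": 3, "description": "Wireless earbuds with noise cancelling."},
]
for pool in curriculum_stages(data, lambda ex: len(ex["description"].split())):
    print([ex["id"] for ex in pool])  # [2], then [2, 3], then [2, 3, 1]
```

<p>Each yielded pool would be trained on for a few epochs before moving to the next; the difficulty function is the part that would need domain-specific design.</p>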
<div id="0cd747af" class="cell" data-execution_count="14">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb14"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb14-1"><a href="#cb14-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> EcommerceDataProcessor:</span>
<span id="cb14-2"><a href="#cb14-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>):</span>
<span id="cb14-3"><a href="#cb14-3" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.category_templates <span class="op">=</span> {</span>
<span id="cb14-4"><a href="#cb14-4" aria-hidden="true" tabindex="-1"></a>            <span class="st">'clothing'</span>: <span class="st">"This </span><span class="sc">{color}</span><span class="st"> </span><span class="sc">{item_type}</span><span class="st"> features </span><span class="sc">{description}</span><span class="st">. Perfect for </span><span class="sc">{occasion}</span><span class="st">."</span>,</span>
<span id="cb14-5"><a href="#cb14-5" aria-hidden="true" tabindex="-1"></a>            <span class="st">'electronics'</span>: <span class="st">"The </span><span class="sc">{brand}</span><span class="st"> </span><span class="sc">{product_name}</span><span class="st"> offers </span><span class="sc">{features}</span><span class="st">. Ideal for </span><span class="sc">{use_case}</span><span class="st">."</span>,</span>
<span id="cb14-6"><a href="#cb14-6" aria-hidden="true" tabindex="-1"></a>            <span class="st">'home'</span>: <span class="st">"This </span><span class="sc">{material}</span><span class="st"> </span><span class="sc">{item_type}</span><span class="st"> brings </span><span class="sc">{style}</span><span class="st"> to your </span><span class="sc">{room}</span><span class="st">."</span></span>
<span id="cb14-7"><a href="#cb14-7" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb14-8"><a href="#cb14-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb14-9"><a href="#cb14-9" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> generate_template_augmentations(<span class="va">self</span>, product_data):</span>
<span id="cb14-10"><a href="#cb14-10" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Generate template-based augmentations for training data"""</span></span>
<span id="cb14-11"><a href="#cb14-11" aria-hidden="true" tabindex="-1"></a>        category <span class="op">=</span> product_data[<span class="st">'category'</span>]</span>
<span id="cb14-12"><a href="#cb14-12" aria-hidden="true" tabindex="-1"></a>        template <span class="op">=</span> <span class="va">self</span>.category_templates.get(category, <span class="st">""</span>)</span>
<span id="cb14-13"><a href="#cb14-13" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-14"><a href="#cb14-14" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> template:</span>
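<span id="cb14-14"><a href="#cb14-14" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> template:</span>
<span id="cb14-14a"><a href="#cb14-14a" aria-hidden="true" tabindex="-1"></a>            <span class="cf">try</span>:</span>
<span id="cb14-15"><a href="#cb14-15" aria-hidden="true" tabindex="-1"></a>                <span class="cf">return</span> template.<span class="bu">format</span>(<span class="op">**</span>product_data)</span>
<span id="cb14-15a"><a href="#cb14-15a" aria-hidden="true" tabindex="-1"></a>            <span class="cf">except</span> <span class="pp">KeyError</span>:</span>
<span id="cb14-15b"><a href="#cb14-15b" aria-hidden="true" tabindex="-1"></a>                <span class="co"># A placeholder is missing from product_data: fall back to the original text</span></span>
<span id="cb14-15c"><a href="#cb14-15c" aria-hidden="true" tabindex="-1"></a>                <span class="cf">pass</span></span>
<span id="cb14-16"><a href="#cb14-16" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> product_data[<span class="st">'original_description'</span>]</span>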
<span id="cb14-18"><a href="#cb14-18" aria-hidden="true" tabindex="-1"></a><span class="co"># A/B testing framework for real-world validation</span></span>
<span id="cb14-19"><a href="#cb14-19" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> ABTestingFramework:</span>
<span id="cb14-20"><a href="#cb14-20" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>):</span>
<span id="cb14-21"><a href="#cb14-21" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.test_groups <span class="op">=</span> {}</span>
<span id="cb14-22"><a href="#cb14-22" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.metrics <span class="op">=</span> {}</span>
<span id="cb14-23"><a href="#cb14-23" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb14-24"><a href="#cb14-24" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> assign_user_to_group(<span class="va">self</span>, user_id, test_name):</span>
<span id="cb14-25"><a href="#cb14-25" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Assign user to control or treatment group"""</span></span>
<span id="cb14-26"><a href="#cb14-26" aria-hidden="true" tabindex="-1"></a>        <span class="im">import</span> hashlib</span>
<span id="cb14-27"><a href="#cb14-27" aria-hidden="true" tabindex="-1"></a>        hash_val <span class="op">=</span> <span class="bu">int</span>(hashlib.md5(<span class="ss">f"</span><span class="sc">{</span>user_id<span class="sc">}</span><span class="ss">_</span><span class="sc">{</span>test_name<span class="sc">}</span><span class="ss">"</span>.encode()).hexdigest(), <span class="dv">16</span>)</span>
<span id="cb14-28"><a href="#cb14-28" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="st">"treatment"</span> <span class="cf">if</span> hash_val <span class="op">%</span> <span class="dv">2</span> <span class="op">==</span> <span class="dv">0</span> <span class="cf">else</span> <span class="st">"control"</span></span>
<span id="cb14-29"><a href="#cb14-29" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb14-30"><a href="#cb14-30" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> log_conversion(<span class="va">self</span>, user_id, test_name, converted):</span>
<span id="cb14-31"><a href="#cb14-31" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Log conversion event for analysis"""</span></span>
<span id="cb14-32"><a href="#cb14-32" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> test_name <span class="kw">not</span> <span class="kw">in</span> <span class="va">self</span>.metrics:</span>
<span id="cb14-33"><a href="#cb14-33" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.metrics[test_name] <span class="op">=</span> {<span class="st">'control'</span>: [], <span class="st">'treatment'</span>: []}</span>
<span id="cb14-34"><a href="#cb14-34" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-35"><a href="#cb14-35" aria-hidden="true" tabindex="-1"></a>        group <span class="op">=</span> <span class="va">self</span>.assign_user_to_group(user_id, test_name)</span>
<span id="cb14-36"><a href="#cb14-36" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.metrics[test_name][group].append(converted)</span>
<span id="cb14-37"><a href="#cb14-37" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb14-38"><a href="#cb14-38" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> analyze_results(<span class="va">self</span>, test_name):</span>
<span id="cb14-39"><a href="#cb14-39" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Analyze A/B test results"""</span></span>
<span id="cb14-40"><a href="#cb14-40" aria-hidden="true" tabindex="-1"></a>        control_conversions <span class="op">=</span> <span class="va">self</span>.metrics[test_name][<span class="st">'control'</span>]</span>
<span id="cb14-41"><a href="#cb14-41" aria-hidden="true" tabindex="-1"></a>        treatment_conversions <span class="op">=</span> <span class="va">self</span>.metrics[test_name][<span class="st">'treatment'</span>]</span>
<span id="cb14-42"><a href="#cb14-42" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-43"><a href="#cb14-43" aria-hidden="true" tabindex="-1"></a>        control_rate <span class="op">=</span> <span class="bu">sum</span>(control_conversions) <span class="op">/</span> <span class="bu">len</span>(control_conversions) <span class="cf">if</span> control_conversions <span class="cf">else</span> <span class="fl">0.0</span></span>
<span id="cb14-44"><a href="#cb14-44" aria-hidden="true" tabindex="-1"></a>        treatment_rate <span class="op">=</span> <span class="bu">sum</span>(treatment_conversions) <span class="op">/</span> <span class="bu">len</span>(treatment_conversions) <span class="cf">if</span> treatment_conversions <span class="cf">else</span> <span class="fl">0.0</span></span>
<span id="cb14-45"><a href="#cb14-45" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-46"><a href="#cb14-46" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {</span>
<span id="cb14-47"><a href="#cb14-47" aria-hidden="true" tabindex="-1"></a>            <span class="st">'control_rate'</span>: control_rate,</span>
<span id="cb14-48"><a href="#cb14-48" aria-hidden="true" tabindex="-1"></a>            <span class="st">'treatment_rate'</span>: treatment_rate,</span>
<span id="cb14-48a"><a href="#cb14-48a" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Relative lift in percent; undefined (NaN) when the control rate is zero</span></span>
<span id="cb14-49"><a href="#cb14-49" aria-hidden="true" tabindex="-1"></a>            <span class="st">'lift'</span>: (treatment_rate <span class="op">-</span> control_rate) <span class="op">/</span> control_rate <span class="op">*</span> <span class="dv">100</span> <span class="cf">if</span> control_rate <span class="cf">else</span> <span class="bu">float</span>(<span class="st">'nan'</span>)</span>
<span id="cb14-50"><a href="#cb14-50" aria-hidden="true" tabindex="-1"></a>        }</span></code></pre></div></div>
</div>
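<p><code>analyze_results</code> reports lift but not whether the difference could be noise. A standard check is the two-proportion z-test; the sketch below uses the pooled normal approximation (reasonable only with fairly large samples) and the standard library alone. Feeding it the 0/1 lists stored in <code>self.metrics</code> is an assumption about how it would be wired in, and the conversion counts are illustrative, not results from this case study.</p>

```python
import math


def two_proportion_z_test(control, treatment):
    """Two-sided z-test for a difference in conversion rates.

    `control` and `treatment` are lists of 0/1 (or bool) outcomes.
    Returns (z, p_value) using the pooled normal approximation.
    """
    n1, n2 = len(control), len(treatment)
    p1, p2 = sum(control) / n1, sum(treatment) / n2
    pooled = (sum(control) + sum(treatment)) / (n1 + n2)
    se = math.sqrt(pooled * (1 - pooled) * (1 / n1 + 1 / n2))
    if se == 0:
        # All outcomes identical across both groups: no evidence of a difference
        return 0.0, 1.0
    z = (p2 - p1) / se
    # Two-sided p-value: 2 * (1 - Phi(|z|)) == erfc(|z| / sqrt(2))
    p_value = math.erfc(abs(z) / math.sqrt(2))
    return z, p_value


# Illustrative numbers only: 100/1000 vs 130/1000 conversions
control = [1] * 100 + [0] * 900
treatment = [1] * 130 + [0] * 870
z, p = two_proportion_z_test(control, treatment)
print(f"z={z:.2f}, p={p:.4f}")
```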
<p><strong>Results</strong>: Generated descriptions led to a 12% increase in click-through rate.</p>
<p><strong>Key Insights</strong>:</p>
<ul>
<li>Brand-specific terminology required specialized training</li>
<li>A/B testing was essential for real-world validation</li>
<li>Template-based augmentation improved consistency</li>
</ul>
</section>
<section id="case-study-3-educational-content-creation" class="level3">
<h3 class="anchored" data-anchor-id="case-study-3-educational-content-creation" id="case-study-3-educational-content-creation">Case Study 3: Educational Content Creation</h3>
<p><strong>Objective</strong>: Create an assistant for generating educational materials from visual content.</p>
<p><strong>Approach</strong>:</p>
<ul>
<li>Base model: GPT-4V (via API fine-tuning)</li>
<li>Dataset: Educational images with detailed explanations</li>
<li>Fine-tuning strategy: Instruction tuning with reinforcement learning</li>
<li>Evaluation: Educational effectiveness metrics, user engagement</li>
</ul>
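<p>The case study does not define its "educational effectiveness metrics". One cheap, automatic proxy for whether an explanation matches its target level (an assumption here, not the authors' stated method) is the Flesch-Kincaid grade level, 0.39 × (words / sentences) + 11.8 × (syllables / words) - 15.59. The syllable counter in this sketch is a crude vowel-group heuristic, so scores are only roughly comparable to published readability tools.</p>

```python
import re


def count_syllables(word):
    # Heuristic: count groups of consecutive vowels, discounting a silent final 'e'
    word = word.lower()
    count = len(re.findall(r"[aeiouy]+", word))
    if word.endswith("e") and count > 1:
        count -= 1
    return max(count, 1)


def flesch_kincaid_grade(text):
    sentences = [s for s in re.split(r"[.!?]+", text) if s.strip()]
    words = re.findall(r"[A-Za-z']+", text)
    if not sentences or not words:
        return 0.0
    syllables = sum(count_syllables(w) for w in words)
    return 0.39 * len(words) / len(sentences) + 11.8 * syllables / len(words) - 15.59


simple = "Plants need light. They use it to make food."
technical = ("Photosynthesis converts electromagnetic radiation into chemical energy, "
             "synthesizing carbohydrates through enzymatically regulated biochemical pathways.")
print(round(flesch_kincaid_grade(simple), 1), round(flesch_kincaid_grade(technical), 1))
```

<p>Comparing the score of each generated explanation against the band expected for <code>elementary</code>, <code>middle</code>, <code>high_school</code> or <code>college</code> gives a quick automated sanity check to run alongside human review.</p>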
<div id="62cba6b4" class="cell" data-execution_count="15">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb15"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb15-1"><a href="#cb15-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> EducationalContentGenerator:</span>
<span id="cb15-2"><a href="#cb15-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, difficulty_levels<span class="op">=</span>(<span class="st">'elementary'</span>, <span class="st">'middle'</span>, <span class="st">'high_school'</span>, <span class="st">'college'</span>)):</span>
<span id="cb15-3"><a href="#cb15-3" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.difficulty_levels <span class="op">=</span> difficulty_levels</span>
<span id="cb15-4"><a href="#cb15-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.pedagogical_templates <span class="op">=</span> {</span>
<span id="cb15-5"><a href="#cb15-5" aria-hidden="true" tabindex="-1"></a>            <span class="st">'elementary'</span>: {</span>
<span id="cb15-6"><a href="#cb15-6" aria-hidden="true" tabindex="-1"></a>                <span class="st">'vocabulary'</span>: <span class="st">'simple'</span>,</span>
<span id="cb15-7"><a href="#cb15-7" aria-hidden="true" tabindex="-1"></a>                <span class="st">'sentence_length'</span>: <span class="st">'short'</span>,</span>
<span id="cb15-8"><a href="#cb15-8" aria-hidden="true" tabindex="-1"></a>                <span class="st">'examples'</span>: <span class="st">'concrete'</span>,</span>
<span id="cb15-9"><a href="#cb15-9" aria-hidden="true" tabindex="-1"></a>                <span class="st">'analogies'</span>: <span class="st">'familiar'</span></span>
<span id="cb15-10"><a href="#cb15-10" aria-hidden="true" tabindex="-1"></a>            },</span>
<span id="cb15-11"><a href="#cb15-11" aria-hidden="true" tabindex="-1"></a>            <span class="st">'middle'</span>: {</span>
<span id="cb15-12"><a href="#cb15-12" aria-hidden="true" tabindex="-1"></a>                <span class="st">'vocabulary'</span>: <span class="st">'intermediate'</span>, </span>
<span id="cb15-13"><a href="#cb15-13" aria-hidden="true" tabindex="-1"></a>                <span class="st">'sentence_length'</span>: <span class="st">'medium'</span>,</span>
<span id="cb15-14"><a href="#cb15-14" aria-hidden="true" tabindex="-1"></a>                <span class="st">'examples'</span>: <span class="st">'relatable'</span>,</span>
<span id="cb15-15"><a href="#cb15-15" aria-hidden="true" tabindex="-1"></a>                <span class="st">'analogies'</span>: <span class="st">'accessible'</span></span>
<span id="cb15-16"><a href="#cb15-16" aria-hidden="true" tabindex="-1"></a>            },</span>
<span id="cb15-17"><a href="#cb15-17" aria-hidden="true" tabindex="-1"></a>            <span class="st">'high_school'</span>: {</span>
<span id="cb15-18"><a href="#cb15-18" aria-hidden="true" tabindex="-1"></a>                <span class="st">'vocabulary'</span>: <span class="st">'advanced'</span>,</span>
<span id="cb15-19"><a href="#cb15-19" aria-hidden="true" tabindex="-1"></a>                <span class="st">'sentence_length'</span>: <span class="st">'varied'</span>,</span>
<span id="cb15-20"><a href="#cb15-20" aria-hidden="true" tabindex="-1"></a>                <span class="st">'examples'</span>: <span class="st">'detailed'</span>,</span>
<span id="cb15-21"><a href="#cb15-21" aria-hidden="true" tabindex="-1"></a>                <span class="st">'analogies'</span>: <span class="st">'sophisticated'</span></span>
<span id="cb15-22"><a href="#cb15-22" aria-hidden="true" tabindex="-1"></a>            },</span>
<span id="cb15-23"><a href="#cb15-23" aria-hidden="true" tabindex="-1"></a>            <span class="st">'college'</span>: {</span>
<span id="cb15-24"><a href="#cb15-24" aria-hidden="true" tabindex="-1"></a>                <span class="st">'vocabulary'</span>: <span class="st">'technical'</span>,</span>
<span id="cb15-25"><a href="#cb15-25" aria-hidden="true" tabindex="-1"></a>                <span class="st">'sentence_length'</span>: <span class="st">'complex'</span>,</span>
<span id="cb15-26"><a href="#cb15-26" aria-hidden="true" tabindex="-1"></a>                <span class="st">'examples'</span>: <span class="st">'comprehensive'</span>,</span>
<span id="cb15-27"><a href="#cb15-27" aria-hidden="true" tabindex="-1"></a>                <span class="st">'analogies'</span>: <span class="st">'abstract'</span></span>
<span id="cb15-28"><a href="#cb15-28" aria-hidden="true" tabindex="-1"></a>            }</span>
<span id="cb15-29"><a href="#cb15-29" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb15-30"><a href="#cb15-30" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-31"><a href="#cb15-31" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> adapt_content_difficulty(<span class="va">self</span>, content, target_level):</span>
<span id="cb15-32"><a href="#cb15-32" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Adapt educational content to target difficulty level"""</span></span>
<span id="cb15-33"><a href="#cb15-33" aria-hidden="true" tabindex="-1"></a>        template <span class="op">=</span> <span class="va">self</span>.pedagogical_templates[target_level]</span>
<span id="cb15-34"><a href="#cb15-34" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-35"><a href="#cb15-35" aria-hidden="true" tabindex="-1"></a>        <span class="co"># This would integrate with the VLM to generate level-appropriate content</span></span>
<span id="cb15-36"><a href="#cb15-36" aria-hidden="true" tabindex="-1"></a>        adapted_prompt <span class="op">=</span> <span class="ss">f"""</span></span>
<span id="cb15-37"><a href="#cb15-37" aria-hidden="true" tabindex="-1"></a><span class="ss">        Explain this concept for </span><span class="sc">{</span>target_level<span class="sc">}</span><span class="ss"> students using:</span></span>
<span id="cb15-38"><a href="#cb15-38" aria-hidden="true" tabindex="-1"></a><span class="ss">        - </span><span class="sc">{</span>template[<span class="st">'vocabulary'</span>]<span class="sc">}</span><span class="ss"> vocabulary</span></span>
<span id="cb15-39"><a href="#cb15-39" aria-hidden="true" tabindex="-1"></a><span class="ss">        - </span><span class="sc">{</span>template[<span class="st">'sentence_length'</span>]<span class="sc">}</span><span class="ss"> sentences</span></span>
<span id="cb15-40"><a href="#cb15-40" aria-hidden="true" tabindex="-1"></a><span class="ss">        - </span><span class="sc">{</span>template[<span class="st">'examples'</span>]<span class="sc">}</span><span class="ss"> examples</span></span>
<span id="cb15-41"><a href="#cb15-41" aria-hidden="true" tabindex="-1"></a><span class="ss">        - </span><span class="sc">{</span>template[<span class="st">'analogies'</span>]<span class="sc">}</span><span class="ss"> analogies</span></span>
<span id="cb15-42"><a href="#cb15-42" aria-hidden="true" tabindex="-1"></a><span class="ss">        </span></span>
<span id="cb15-43"><a href="#cb15-43" aria-hidden="true" tabindex="-1"></a><span class="ss">        Original content: </span><span class="sc">{</span>content<span class="sc">}</span></span>
<span id="cb15-44"><a href="#cb15-44" aria-hidden="true" tabindex="-1"></a><span class="ss">        """</span></span>
<span id="cb15-45"><a href="#cb15-45" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-46"><a href="#cb15-46" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> adapted_prompt</span>
<span id="cb15-47"><a href="#cb15-47" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-48"><a href="#cb15-48" aria-hidden="true" tabindex="-1"></a><span class="co"># Reinforcement Learning from Human Feedback (RLHF) implementation</span></span>
<span id="cb15-49"><a href="#cb15-49" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> EducationalRLHF:</span>
<span id="cb15-50"><a href="#cb15-50" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, model, reward_model):</span>
<span id="cb15-51"><a href="#cb15-51" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> model</span>
<span id="cb15-52"><a href="#cb15-52" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.reward_model <span class="op">=</span> reward_model</span>
<span id="cb15-53"><a href="#cb15-53" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.ppo_trainer <span class="op">=</span> <span class="va">None</span>  <span class="co"># Would initialize PPO trainer</span></span>
<span id="cb15-54"><a href="#cb15-54" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-55"><a href="#cb15-55" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> collect_human_feedback(<span class="va">self</span>, generated_content, images):</span>
<span id="cb15-56"><a href="#cb15-56" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Collect feedback from educators on generated content"""</span></span>
<span id="cb15-57"><a href="#cb15-57" aria-hidden="true" tabindex="-1"></a>        feedback_criteria <span class="op">=</span> [</span>
<span id="cb15-58"><a href="#cb15-58" aria-hidden="true" tabindex="-1"></a>            <span class="st">'accuracy'</span>,</span>
<span id="cb15-59"><a href="#cb15-59" aria-hidden="true" tabindex="-1"></a>            <span class="st">'clarity'</span>, </span>
<span id="cb15-60"><a href="#cb15-60" aria-hidden="true" tabindex="-1"></a>            <span class="st">'age_appropriateness'</span>,</span>
<span id="cb15-61"><a href="#cb15-61" aria-hidden="true" tabindex="-1"></a>            <span class="st">'engagement'</span>,</span>
<span id="cb15-62"><a href="#cb15-62" aria-hidden="true" tabindex="-1"></a>            <span class="st">'pedagogical_value'</span></span>
<span id="cb15-63"><a href="#cb15-63" aria-hidden="true" tabindex="-1"></a>        ]</span>
<span id="cb15-64"><a href="#cb15-64" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-65"><a href="#cb15-65" aria-hidden="true" tabindex="-1"></a>        <span class="co"># This would interface with human evaluators</span></span>
<span id="cb15-66"><a href="#cb15-66" aria-hidden="true" tabindex="-1"></a>        feedback <span class="op">=</span> {}</span>
<span id="cb15-67"><a href="#cb15-67" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> criterion <span class="kw">in</span> feedback_criteria:</span>
<span id="cb15-68"><a href="#cb15-68" aria-hidden="true" tabindex="-1"></a>            feedback[criterion] <span class="op">=</span> <span class="va">self</span>.get_human_rating(</span>
<span id="cb15-69"><a href="#cb15-69" aria-hidden="true" tabindex="-1"></a>                generated_content, images, criterion</span>
<span id="cb15-70"><a href="#cb15-70" aria-hidden="true" tabindex="-1"></a>            )</span>
<span id="cb15-71"><a href="#cb15-71" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-72"><a href="#cb15-72" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> feedback</span>
<span id="cb15-73"><a href="#cb15-73" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-74"><a href="#cb15-74" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> train_reward_model(<span class="va">self</span>, feedback_data):</span>
<span id="cb15-75"><a href="#cb15-75" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Train reward model from human feedback"""</span></span>
<span id="cb15-76"><a href="#cb15-76" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Implementation would train a model to predict human preferences</span></span>
<span id="cb15-77"><a href="#cb15-77" aria-hidden="true" tabindex="-1"></a>        <span class="cf">pass</span></span>
<span id="cb15-78"><a href="#cb15-78" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-79"><a href="#cb15-79" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> optimize_with_ppo(<span class="va">self</span>, training_data):</span>
<span id="cb15-80"><a href="#cb15-80" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Optimize model using PPO with learned reward model"""</span></span>
<span id="cb15-81"><a href="#cb15-81" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Implementation would use PPO to optimize policy</span></span>
<span id="cb15-82"><a href="#cb15-82" aria-hidden="true" tabindex="-1"></a>        <span class="cf">pass</span></span>
<span id="cb15-83"><a href="#cb15-83" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-84"><a href="#cb15-84" aria-hidden="true" tabindex="-1"></a><span class="co"># Educational effectiveness evaluation</span></span>
<span id="cb15-85"><a href="#cb15-85" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> EducationalEvaluator:</span>
<span id="cb15-86"><a href="#cb15-86" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>):</span>
<span id="cb15-87"><a href="#cb15-87" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.bloom_taxonomy_levels <span class="op">=</span> [</span>
<span id="cb15-88"><a href="#cb15-88" aria-hidden="true" tabindex="-1"></a>            <span class="st">'remember'</span>, <span class="st">'understand'</span>, <span class="st">'apply'</span>, </span>
<span id="cb15-89"><a href="#cb15-89" aria-hidden="true" tabindex="-1"></a>            <span class="st">'analyze'</span>, <span class="st">'evaluate'</span>, <span class="st">'create'</span></span>
<span id="cb15-90"><a href="#cb15-90" aria-hidden="true" tabindex="-1"></a>        ]</span>
<span id="cb15-91"><a href="#cb15-91" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-92"><a href="#cb15-92" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> assess_learning_objectives(<span class="va">self</span>, content, learning_objectives):</span>
<span id="cb15-93"><a href="#cb15-93" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Assess how well content meets learning objectives"""</span></span>
<span id="cb15-94"><a href="#cb15-94" aria-hidden="true" tabindex="-1"></a>        coverage_scores <span class="op">=</span> {}</span>
<span id="cb15-95"><a href="#cb15-95" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-96"><a href="#cb15-96" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> objective <span class="kw">in</span> learning_objectives:</span>
<span id="cb15-97"><a href="#cb15-97" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Use NLP techniques to measure objective coverage</span></span>
<span id="cb15-98"><a href="#cb15-98" aria-hidden="true" tabindex="-1"></a>            coverage_scores[objective] <span class="op">=</span> <span class="va">self</span>.calculate_coverage_score(</span>
<span id="cb15-99"><a href="#cb15-99" aria-hidden="true" tabindex="-1"></a>                content, objective</span>
<span id="cb15-100"><a href="#cb15-100" aria-hidden="true" tabindex="-1"></a>            )</span>
<span id="cb15-101"><a href="#cb15-101" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-102"><a href="#cb15-102" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> coverage_scores</span>
<span id="cb15-103"><a href="#cb15-103" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-104"><a href="#cb15-104" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> evaluate_cognitive_load(<span class="va">self</span>, content):</span>
<span id="cb15-105"><a href="#cb15-105" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Evaluate cognitive load of educational content"""</span></span>
<span id="cb15-106"><a href="#cb15-106" aria-hidden="true" tabindex="-1"></a>        metrics <span class="op">=</span> {</span>
<span id="cb15-107"><a href="#cb15-107" aria-hidden="true" tabindex="-1"></a>            <span class="st">'intrinsic_load'</span>: <span class="va">self</span>.measure_concept_complexity(content),</span>
<span id="cb15-108"><a href="#cb15-108" aria-hidden="true" tabindex="-1"></a>            <span class="st">'extraneous_load'</span>: <span class="va">self</span>.measure_irrelevant_information(content),</span>
<span id="cb15-109"><a href="#cb15-109" aria-hidden="true" tabindex="-1"></a>            <span class="st">'germane_load'</span>: <span class="va">self</span>.measure_schema_construction(content)</span>
<span id="cb15-110"><a href="#cb15-110" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb15-111"><a href="#cb15-111" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-112"><a href="#cb15-112" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> metrics</span>
<span id="cb15-113"><a href="#cb15-113" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-114"><a href="#cb15-114" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> measure_engagement_potential(<span class="va">self</span>, content, target_audience):</span>
<span id="cb15-115"><a href="#cb15-115" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Measure potential engagement level of content"""</span></span>
<span id="cb15-116"><a href="#cb15-116" aria-hidden="true" tabindex="-1"></a>        engagement_factors <span class="op">=</span> [</span>
<span id="cb15-117"><a href="#cb15-117" aria-hidden="true" tabindex="-1"></a>            <span class="st">'visual_appeal'</span>,</span>
<span id="cb15-118"><a href="#cb15-118" aria-hidden="true" tabindex="-1"></a>            <span class="st">'interactivity'</span>,</span>
<span id="cb15-119"><a href="#cb15-119" aria-hidden="true" tabindex="-1"></a>            <span class="st">'relevance'</span>,</span>
<span id="cb15-120"><a href="#cb15-120" aria-hidden="true" tabindex="-1"></a>            <span class="st">'challenge_level'</span>,</span>
<span id="cb15-121"><a href="#cb15-121" aria-hidden="true" tabindex="-1"></a>            <span class="st">'curiosity_gap'</span></span>
<span id="cb15-122"><a href="#cb15-122" aria-hidden="true" tabindex="-1"></a>        ]</span>
<span id="cb15-123"><a href="#cb15-123" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-124"><a href="#cb15-124" aria-hidden="true" tabindex="-1"></a>        scores <span class="op">=</span> {}</span>
<span id="cb15-125"><a href="#cb15-125" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> factor <span class="kw">in</span> engagement_factors:</span>
<span id="cb15-126"><a href="#cb15-126" aria-hidden="true" tabindex="-1"></a>            scores[factor] <span class="op">=</span> <span class="va">self</span>.score_engagement_factor(</span>
<span id="cb15-127"><a href="#cb15-127" aria-hidden="true" tabindex="-1"></a>                content, factor, target_audience</span>
<span id="cb15-128"><a href="#cb15-128" aria-hidden="true" tabindex="-1"></a>            )</span>
<span id="cb15-129"><a href="#cb15-129" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-130"><a href="#cb15-130" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> scores</span>
<span id="cb15-131"><a href="#cb15-131" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-132"><a href="#cb15-132" aria-hidden="true" tabindex="-1"></a><span class="co"># Comprehensive evaluation pipeline for educational VLM</span></span>
<span id="cb15-133"><a href="#cb15-133" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> run_educational_evaluation(model, test_dataset, evaluator):</span>
<span id="cb15-134"><a href="#cb15-134" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Run comprehensive evaluation for educational VLM"""</span></span>
<span id="cb15-135"><a href="#cb15-135" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-136"><a href="#cb15-136" aria-hidden="true" tabindex="-1"></a>    results <span class="op">=</span> {</span>
<span id="cb15-137"><a href="#cb15-137" aria-hidden="true" tabindex="-1"></a>        <span class="st">'content_quality'</span>: [],  <span class="co"># lists keep one score dict per batch</span></span>
<span id="cb15-138"><a href="#cb15-138" aria-hidden="true" tabindex="-1"></a>        <span class="st">'learning_effectiveness'</span>: [],</span>
<span id="cb15-139"><a href="#cb15-139" aria-hidden="true" tabindex="-1"></a>        <span class="st">'engagement_metrics'</span>: [],</span>
<span id="cb15-140"><a href="#cb15-140" aria-hidden="true" tabindex="-1"></a>        <span class="st">'accessibility'</span>: []</span>
<span id="cb15-141"><a href="#cb15-141" aria-hidden="true" tabindex="-1"></a>    }</span>
<span id="cb15-142"><a href="#cb15-142" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-143"><a href="#cb15-143" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> batch <span class="kw">in</span> test_dataset:</span>
<span id="cb15-144"><a href="#cb15-144" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Generate educational content</span></span>
<span id="cb15-145"><a href="#cb15-145" aria-hidden="true" tabindex="-1"></a>        generated_content <span class="op">=</span> model.generate_educational_content(</span>
<span id="cb15-146"><a href="#cb15-146" aria-hidden="true" tabindex="-1"></a>            batch[<span class="st">'images'</span>], </span>
<span id="cb15-147"><a href="#cb15-147" aria-hidden="true" tabindex="-1"></a>            batch[<span class="st">'learning_objectives'</span>],</span>
<span id="cb15-148"><a href="#cb15-148" aria-hidden="true" tabindex="-1"></a>            batch[<span class="st">'target_level'</span>]</span>
<span id="cb15-149"><a href="#cb15-149" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb15-150"><a href="#cb15-150" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-151"><a href="#cb15-151" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Evaluate content quality</span></span>
<span id="cb15-152"><a href="#cb15-152" aria-hidden="true" tabindex="-1"></a>        quality_scores <span class="op">=</span> evaluator.assess_learning_objectives(</span>
<span id="cb15-153"><a href="#cb15-153" aria-hidden="true" tabindex="-1"></a>            generated_content, </span>
<span id="cb15-154"><a href="#cb15-154" aria-hidden="true" tabindex="-1"></a>            batch[<span class="st">'learning_objectives'</span>]</span>
<span id="cb15-155"><a href="#cb15-155" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb15-156"><a href="#cb15-156" aria-hidden="true" tabindex="-1"></a>        results[<span class="st">'content_quality'</span>].append(quality_scores)</span>
<span id="cb15-157"><a href="#cb15-157" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-158"><a href="#cb15-158" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Evaluate cognitive load</span></span>
<span id="cb15-159"><a href="#cb15-159" aria-hidden="true" tabindex="-1"></a>        cognitive_load <span class="op">=</span> evaluator.evaluate_cognitive_load(generated_content)</span>
<span id="cb15-160"><a href="#cb15-160" aria-hidden="true" tabindex="-1"></a>        results[<span class="st">'learning_effectiveness'</span>].append(cognitive_load)</span>
<span id="cb15-161"><a href="#cb15-161" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-162"><a href="#cb15-162" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Evaluate engagement potential</span></span>
<span id="cb15-163"><a href="#cb15-163" aria-hidden="true" tabindex="-1"></a>        engagement <span class="op">=</span> evaluator.measure_engagement_potential(</span>
<span id="cb15-164"><a href="#cb15-164" aria-hidden="true" tabindex="-1"></a>            generated_content, </span>
<span id="cb15-165"><a href="#cb15-165" aria-hidden="true" tabindex="-1"></a>            batch[<span class="st">'target_audience'</span>]</span>
<span id="cb15-166"><a href="#cb15-166" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb15-167"><a href="#cb15-167" aria-hidden="true" tabindex="-1"></a>        results[<span class="st">'engagement_metrics'</span>].append(engagement)</span>
<span id="cb15-168"><a href="#cb15-168" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-169"><a href="#cb15-169" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> results</span></code></pre></div></div>
</div>
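<p>Because <code>run_educational_evaluation</code> scores each batch separately, reporting one figure per metric requires reducing those scores across batches. A minimal sketch, assuming the per-batch score dicts have been kept in a list and all values are numeric; <code>aggregate_batch_scores</code> is a hypothetical helper, not part of the original pipeline:</p>

```python
from statistics import mean

def aggregate_batch_scores(batch_scores):
    """Average a list of per-batch {metric: score} dicts into one summary dict."""
    collected = {}
    for scores in batch_scores:
        for metric, value in scores.items():
            collected.setdefault(metric, []).append(value)
    # A metric missing from some batches is averaged over the batches that report it
    return {metric: mean(values) for metric, values in collected.items()}

# Hypothetical cognitive-load scores from two batches
per_batch = [
    {'intrinsic_load': 0.4, 'extraneous_load': 0.2},
    {'intrinsic_load': 0.6, 'extraneous_load': 0.1, 'germane_load': 0.5},
]
print(aggregate_batch_scores(per_batch))
```

<p>Averaging per metric, rather than merging dicts, keeps every batch's contribution. If batches cover different learning objectives, the per-objective averages are based on different numbers of batches, which is worth reporting alongside the scores.</p>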
<p><strong>Results</strong>: Improved student comprehension scores by 18% in pilot studies.</p>
<p><strong>Key Insights</strong>:</p>
<ul>
<li>Pedagogical principles needed to be encoded in training</li>
<li>Multi-level difficulty adaptation was crucial</li>
<li>Continuous feedback incorporation improved outcomes</li>
</ul>
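<p>The <code>train_reward_model</code> stub leaves the training objective open. A common choice in RLHF is a pairwise (Bradley-Terry) loss: for two responses where educators preferred one, the reward model is pushed to score the preferred response higher. The framework-free sketch below assumes the per-criterion ratings from <code>collect_human_feedback</code> have already been turned into (chosen, rejected) reward pairs, for example by treating the higher-rated of two responses as chosen. The function names are hypothetical, and this is not necessarily the objective used in the case study.</p>

```python
import math

def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    if x > 0:
        return x + math.log1p(math.exp(-x))
    return math.log1p(math.exp(x))

def pairwise_preference_loss(reward_chosen, reward_rejected):
    """Bradley-Terry loss -log(sigmoid(r_chosen - r_rejected)) for one comparison."""
    # -log(sigmoid(m)) == log(1 + exp(-m)) == softplus(-m)
    return softplus(-(reward_chosen - reward_rejected))

def mean_preference_loss(pairs):
    """Average loss over a list of (reward_chosen, reward_rejected) pairs."""
    return sum(pairwise_preference_loss(c, r) for c, r in pairs) / len(pairs)

# Hypothetical reward-model scores for three educator comparisons
pairs = [(2.0, 0.5), (1.0, 1.0), (0.2, 1.5)]
print(mean_preference_loss(pairs))
```

<p>A tied pair costs log 2 (about 0.693). The loss falls toward 0 as the preferred response's reward pulls ahead, and grows roughly linearly when the reward model ranks a pair the wrong way round, so confidently wrong rankings dominate the gradient.</p>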
</section>
</section>
<section id="future-directions" class="level2">
<h2 class="anchored" data-anchor-id="future-directions" id="future-directions">Future Directions</h2>
<section id="emerging-architectures" class="level3">
<h3 class="anchored" data-anchor-id="emerging-architectures" id="emerging-architectures">Emerging Architectures</h3>
<p><strong>Unified Multimodal Models</strong>: Integrating vision, language, and potentially other modalities within a single architecture.</p>
<div id="6b7133ed" class="cell" data-execution_count="16">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb16"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb16-1"><a href="#cb16-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> UnifiedMultimodalModel(nn.Module):</span>
<span id="cb16-2"><a href="#cb16-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Conceptual unified model architecture"""</span></span>
<span id="cb16-3"><a href="#cb16-3" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-4"><a href="#cb16-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, modality_encoders, fusion_layer, decoder):</span>
<span id="cb16-5"><a href="#cb16-5" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb16-6"><a href="#cb16-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.modality_encoders <span class="op">=</span> nn.ModuleDict(modality_encoders)</span>
<span id="cb16-7"><a href="#cb16-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.fusion_layer <span class="op">=</span> fusion_layer</span>
<span id="cb16-8"><a href="#cb16-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.decoder <span class="op">=</span> decoder</span>
<span id="cb16-9"><a href="#cb16-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.modality_weights <span class="op">=</span> nn.Parameter(torch.ones(<span class="bu">len</span>(modality_encoders)))</span>
<span id="cb16-10"><a href="#cb16-10" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-11"><a href="#cb16-11" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, inputs):</span>
<span id="cb16-12"><a href="#cb16-12" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Encode each modality</span></span>
<span id="cb16-13"><a href="#cb16-13" aria-hidden="true" tabindex="-1"></a>        encoded_modalities <span class="op">=</span> {}</span>
<span id="cb16-14"><a href="#cb16-14" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> modality, data <span class="kw">in</span> inputs.items():</span>
<span id="cb16-15"><a href="#cb16-15" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> modality <span class="kw">in</span> <span class="va">self</span>.modality_encoders:</span>
<span id="cb16-16"><a href="#cb16-16" aria-hidden="true" tabindex="-1"></a>                encoded_modalities[modality] <span class="op">=</span> <span class="va">self</span>.modality_encoders[modality](data)</span>
<span id="cb16-17"><a href="#cb16-17" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb16-18"><a href="#cb16-18" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Weighted fusion: normalize once, then look weights up by encoder position</span></span>
<span id="cb16-19"><a href="#cb16-19" aria-hidden="true" tabindex="-1"></a>        weights <span class="op">=</span> torch.softmax(<span class="va">self</span>.modality_weights, dim<span class="op">=</span><span class="dv">0</span>)</span>
<span id="cb16-20"><a href="#cb16-20" aria-hidden="true" tabindex="-1"></a>        names <span class="op">=</span> <span class="bu">list</span>(<span class="va">self</span>.modality_encoders.keys())  <span class="co"># missing or reordered inputs must not shift weights</span></span>
<span id="cb16-21"><a href="#cb16-21" aria-hidden="true" tabindex="-1"></a>        weighted_features <span class="op">=</span> [weights[names.index(m)] <span class="op">*</span> feats</span>
<span id="cb16-22"><a href="#cb16-22" aria-hidden="true" tabindex="-1"></a>                             <span class="cf">for</span> m, feats <span class="kw">in</span> encoded_modalities.items()]</span>
<span id="cb16-23"><a href="#cb16-23" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb16-24"><a href="#cb16-24" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Fuse modalities</span></span>
<span id="cb16-25"><a href="#cb16-25" aria-hidden="true" tabindex="-1"></a>        fused_representation <span class="op">=</span> <span class="va">self</span>.fusion_layer(torch.stack(weighted_features))</span>
<span id="cb16-26"><a href="#cb16-26" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb16-27"><a href="#cb16-27" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Generate output</span></span>
<span id="cb16-28"><a href="#cb16-28" aria-hidden="true" tabindex="-1"></a>        output <span class="op">=</span> <span class="va">self</span>.decoder(fused_representation)</span>
<span id="cb16-29"><a href="#cb16-29" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb16-30"><a href="#cb16-30" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> output</span></code></pre></div></div>
</div>
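<p><code>forward</code> scales each encoder's output by a softmax over the learnable <code>modality_weights</code>. The plain-Python sketch below isolates that step, using lists instead of tensors and hypothetical helper names, so the numbers are easy to follow. Since <code>modality_weights</code> is initialized to ones, each of the M encoders starts with weight 1/M.</p>

```python
import math

def softmax(logits):
    """Softmax over a list of floats (max-subtracted for numerical stability)."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def weighted_fusion(features_by_modality, logits_by_modality):
    """Scale each present modality's features by its softmax weight.

    features_by_modality: {name: list of floats} for the modalities in this input
    logits_by_modality:   {name: float}, one learnable logit per encoder
    """
    names = list(logits_by_modality)
    weights = dict(zip(names, softmax([logits_by_modality[n] for n in names])))
    # Look weights up by name so absent or reordered modalities don't shift them
    return {name: [weights[name] * x for x in feats]
            for name, feats in features_by_modality.items()}

# Three encoders initialized like torch.ones(3), but only two modalities present
logits = {'vision': 1.0, 'text': 1.0, 'audio': 1.0}
features = {'vision': [3.0, 6.0], 'text': [0.0, 3.0]}
print(weighted_fusion(features, logits))
```

<p>Here each present modality is scaled by 1/3, so the weights actually applied sum to 2/3 rather than 1. Renormalizing the softmax over only the modalities present is an alternative design choice if inputs often arrive incomplete.</p>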
<p><strong>Efficient Architectures</strong>: Development of models optimized for mobile and edge deployment.</p>
<div id="c27c251a" class="cell" data-execution_count="17">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb17"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb17-1"><a href="#cb17-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> EfficientVLM(nn.Module):</span>
<span id="cb17-2"><a href="#cb17-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Efficient VLM architecture for edge deployment"""</span></span>
<span id="cb17-3"><a href="#cb17-3" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb17-4"><a href="#cb17-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, vocab_size, vision_backbone<span class="op">=</span><span class="st">'mobilenet'</span>, language_backbone<span class="op">=</span><span class="st">'distilbert'</span>):</span>
<span id="cb17-5"><a href="#cb17-5" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb17-6"><a href="#cb17-6" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb17-7"><a href="#cb17-7" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Lightweight vision encoder</span></span>
<span id="cb17-8"><a href="#cb17-8" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> vision_backbone <span class="op">==</span> <span class="st">'mobilenet'</span>:</span>
<span id="cb17-9"><a href="#cb17-9" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.vision_encoder <span class="op">=</span> <span class="va">self</span>._create_mobilenet_encoder()</span>
<span id="cb17-10"><a href="#cb17-10" aria-hidden="true" tabindex="-1"></a>        <span class="cf">elif</span> vision_backbone <span class="op">==</span> <span class="st">'efficientnet'</span>:</span>
<span id="cb17-11"><a href="#cb17-11" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.vision_encoder <span class="op">=</span> <span class="va">self</span>._create_efficientnet_encoder()</span>
<span id="cb17-12"><a href="#cb17-12" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb17-13"><a href="#cb17-13" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Efficient language encoder</span></span>
<span id="cb17-14"><a href="#cb17-14" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> language_backbone <span class="op">==</span> <span class="st">'distilbert'</span>:</span>
<span id="cb17-15"><a href="#cb17-15" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.language_encoder <span class="op">=</span> <span class="va">self</span>._create_distilbert_encoder()</span>
<span id="cb17-16"><a href="#cb17-16" aria-hidden="true" tabindex="-1"></a>        <span class="cf">elif</span> language_backbone <span class="op">==</span> <span class="st">'tinybert'</span>:</span>
<span id="cb17-17"><a href="#cb17-17" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.language_encoder <span class="op">=</span> <span class="va">self</span>._create_tinybert_encoder()</span>
<span id="cb17-18"><a href="#cb17-18" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb17-19"><a href="#cb17-19" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Lightweight fusion mechanism</span></span>
<span id="cb17-20"><a href="#cb17-20" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.fusion <span class="op">=</span> nn.MultiheadAttention(embed_dim<span class="op">=</span><span class="dv">256</span>, num_heads<span class="op">=</span><span class="dv">4</span>)</span>
<span id="cb17-21"><a href="#cb17-21" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb17-22"><a href="#cb17-22" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Quantization-friendly layers</span></span>
<span id="cb17-23"><a href="#cb17-23" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.output_projection <span class="op">=</span> nn.Linear(<span class="dv">256</span>, vocab_size)</span>
<span id="cb17-24"><a href="#cb17-24" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb17-25"><a href="#cb17-25" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, images, text):</span>
<span id="cb17-26"><a href="#cb17-26" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Process with quantization in mind</span></span>
<span id="cb17-27"><a href="#cb17-27" aria-hidden="true" tabindex="-1"></a>        vision_features <span class="op">=</span> <span class="va">self</span>.vision_encoder(images)</span>
<span id="cb17-28"><a href="#cb17-28" aria-hidden="true" tabindex="-1"></a>        language_features <span class="op">=</span> <span class="va">self</span>.language_encoder(text)</span>
<span id="cb17-29"><a href="#cb17-29" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb17-30"><a href="#cb17-30" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Efficient attention mechanism</span></span>
<span id="cb17-31"><a href="#cb17-31" aria-hidden="true" tabindex="-1"></a>        fused_features, _ <span class="op">=</span> <span class="va">self</span>.fusion(</span>
<span id="cb17-32"><a href="#cb17-32" aria-hidden="true" tabindex="-1"></a>            vision_features, language_features, language_features</span>
<span id="cb17-33"><a href="#cb17-33" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb17-34"><a href="#cb17-34" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb17-35"><a href="#cb17-35" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.output_projection(fused_features)</span>
<span id="cb17-36"><a href="#cb17-36" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb17-37"><a href="#cb17-37" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> quantize_model(<span class="va">self</span>):</span>
<span id="cb17-38"><a href="#cb17-38" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Return a dynamically quantized copy for deployment (CPU inference)"""</span></span>
<span id="cb17-39"><a href="#cb17-39" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> torch.quantization.quantize_dynamic(</span>
<span id="cb17-40"><a href="#cb17-40" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>, {nn.Linear}, dtype<span class="op">=</span>torch.qint8  <span class="co"># dynamic quantization covers nn.Linear; conv layers need static quantization</span></span>
<span id="cb17-41"><a href="#cb17-41" aria-hidden="true" tabindex="-1"></a>        )</span></code></pre></div></div>
</div>
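<p>Dynamic quantization converts weights to 8-bit integers ahead of time and quantizes activations on the fly at inference time. The NumPy sketch below illustrates the simplest per-tensor symmetric int8 scheme for one weight matrix. The helper names are hypothetical, and PyTorch's own observers pick scales slightly differently, so read it as a picture of the idea rather than PyTorch's exact arithmetic.</p>

```python
import numpy as np

# Illustrative sketch; helper names are hypothetical, not a PyTorch API.

def quantize_symmetric_int8(w):
    """Per-tensor symmetric int8 quantization: w is approximated by q * scale."""
    max_abs = float(np.max(np.abs(w)))
    # Map the largest magnitude to 127; guard against an all-zero tensor
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    """Recover an approximate float32 tensor from int8 values and the scale."""
    return q.astype(np.float32) * scale

rng = np.random.default_rng(0)
weights = rng.normal(size=(256, 64)).astype(np.float32)  # stand-in for an nn.Linear weight
q, scale = quantize_symmetric_int8(weights)
recovered = dequantize(q, scale)

print(q.nbytes / weights.nbytes)  # 0.25: int8 storage is 4x smaller than float32
# Rounding error per element is at most half a quantization step
print(np.max(np.abs(recovered - weights)), scale / 2)
```

<p>The 4x memory saving comes directly from the storage type, while accuracy depends on how well a single scale fits the weight distribution; per-channel scales are a common refinement when it does not.</p>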
<p><strong>Compositional Models</strong>: Better understanding and generation of complex visual scenes with multiple objects and relationships.</p>
</section>
<section id="advanced-training-techniques" class="level3">
<h3 class="anchored" data-anchor-id="advanced-training-techniques" id="advanced-training-techniques">Advanced Training Techniques</h3>
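<p><strong>Self-supervised Learning</strong>: Learning representations from large amounts of unlabeled or weakly paired image-text data, for example by combining image-conditioned masked language modeling with image-text contrastive alignment.</p>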
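<p>The image-text contrastive objective used in the class below is easiest to see without a model attached. Here is a NumPy sketch of the symmetric InfoNCE loss, assuming row <code>i</code> of the image and text embedding matrices is a matched pair. The function names are hypothetical; the temperature of 0.07 mirrors the value the class uses.</p>

```python
import numpy as np

# Illustrative sketch; function names are hypothetical.

def log_softmax(x, axis):
    """Numerically stable log-softmax along the given axis."""
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))

def clip_style_loss(image_emb, text_emb, temperature=0.07):
    """Symmetric InfoNCE loss; row i of both inputs is assumed to be a matched pair."""
    # L2-normalize so dot products are cosine similarities
    image_emb = image_emb / np.linalg.norm(image_emb, axis=-1, keepdims=True)
    text_emb = text_emb / np.linalg.norm(text_emb, axis=-1, keepdims=True)
    logits = image_emb @ text_emb.T / temperature  # shape (batch, batch)
    idx = np.arange(logits.shape[0])               # correct partner sits on the diagonal
    loss_i2t = -log_softmax(logits, axis=1)[idx, idx].mean()  # each image picks its caption
    loss_t2i = -log_softmax(logits, axis=0)[idx, idx].mean()  # each caption picks its image
    return (loss_i2t + loss_t2i) / 2

rng = np.random.default_rng(0)
img = rng.normal(size=(8, 32))
aligned = clip_style_loss(img, img + 0.01 * rng.normal(size=(8, 32)))  # near-duplicate pairs
shuffled = clip_style_loss(img, rng.normal(size=(8, 32)))              # unrelated pairs
print(aligned, shuffled)
```

<p>Matched pairs put large logits on the diagonal, so nearly identical embeddings give a far lower loss than unrelated ones; the low temperature sharpens that contrast.</p>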
<div id="2176026f" class="cell" data-execution_count="18">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb18"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb18-1"><a href="#cb18-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> SelfSupervisedVLM(pl.LightningModule):</span>
<span id="cb18-2"><a href="#cb18-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Self-supervised learning for VLMs"""</span></span>
<span id="cb18-3"><a href="#cb18-3" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-4"><a href="#cb18-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, base_model):</span>
<span id="cb18-5"><a href="#cb18-5" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb18-6"><a href="#cb18-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.base_model <span class="op">=</span> base_model</span>
<span id="cb18-7"><a href="#cb18-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.contrastive_temperature <span class="op">=</span> <span class="fl">0.07</span></span>
<span id="cb18-8"><a href="#cb18-8" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-9"><a href="#cb18-9" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> masked_language_modeling_loss(<span class="va">self</span>, text_inputs, image_context):</span>
<span id="cb18-10"><a href="#cb18-10" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""MLM loss with visual context"""</span></span>
<span id="cb18-11"><a href="#cb18-11" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Mask random tokens</span></span>
<span id="cb18-12"><a href="#cb18-12" aria-hidden="true" tabindex="-1"></a>        masked_inputs, labels <span class="op">=</span> <span class="va">self</span>.mask_tokens(text_inputs)</span>
<span id="cb18-13"><a href="#cb18-13" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-14"><a href="#cb18-14" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Predict masked tokens with visual context</span></span>
<span id="cb18-15"><a href="#cb18-15" aria-hidden="true" tabindex="-1"></a>        outputs <span class="op">=</span> <span class="va">self</span>.base_model(</span>
<span id="cb18-16"><a href="#cb18-16" aria-hidden="true" tabindex="-1"></a>            images<span class="op">=</span>image_context,</span>
<span id="cb18-17"><a href="#cb18-17" aria-hidden="true" tabindex="-1"></a>            input_ids<span class="op">=</span>masked_inputs</span>
<span id="cb18-18"><a href="#cb18-18" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb18-19"><a href="#cb18-19" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-20"><a href="#cb18-20" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Compute MLM loss</span></span>
<span id="cb18-21"><a href="#cb18-21" aria-hidden="true" tabindex="-1"></a>        mlm_loss <span class="op">=</span> nn.CrossEntropyLoss()(</span>
<span id="cb18-22"><a href="#cb18-22" aria-hidden="true" tabindex="-1"></a>            outputs.logits.view(<span class="op">-</span><span class="dv">1</span>, outputs.logits.size(<span class="op">-</span><span class="dv">1</span>)),</span>
<span id="cb18-23"><a href="#cb18-23" aria-hidden="true" tabindex="-1"></a>            labels.view(<span class="op">-</span><span class="dv">1</span>)</span>
<span id="cb18-24"><a href="#cb18-24" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb18-25"><a href="#cb18-25" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-26"><a href="#cb18-26" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> mlm_loss</span>
<span id="cb18-27"><a href="#cb18-27" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-28"><a href="#cb18-28" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> image_text_contrastive_loss(<span class="va">self</span>, images, texts):</span>
<span id="cb18-29"><a href="#cb18-29" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Contrastive loss for image-text alignment"""</span></span>
<span id="cb18-30"><a href="#cb18-30" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Get embeddings</span></span>
<span id="cb18-31"><a href="#cb18-31" aria-hidden="true" tabindex="-1"></a>        image_embeddings <span class="op">=</span> <span class="va">self</span>.base_model.get_image_features(images)</span>
<span id="cb18-32"><a href="#cb18-32" aria-hidden="true" tabindex="-1"></a>        text_embeddings <span class="op">=</span> <span class="va">self</span>.base_model.get_text_features(texts)</span>
<span id="cb18-33"><a href="#cb18-33" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-34"><a href="#cb18-34" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Normalize embeddings</span></span>
<span id="cb18-35"><a href="#cb18-35" aria-hidden="true" tabindex="-1"></a>        image_embeddings <span class="op">=</span> F.normalize(image_embeddings, dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb18-36"><a href="#cb18-36" aria-hidden="true" tabindex="-1"></a>        text_embeddings <span class="op">=</span> F.normalize(text_embeddings, dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb18-37"><a href="#cb18-37" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-38"><a href="#cb18-38" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Compute similarity matrix</span></span>
<span id="cb18-39"><a href="#cb18-39" aria-hidden="true" tabindex="-1"></a>        similarity_matrix <span class="op">=</span> torch.matmul(</span>
<span id="cb18-40"><a href="#cb18-40" aria-hidden="true" tabindex="-1"></a>            image_embeddings, text_embeddings.transpose(<span class="dv">0</span>, <span class="dv">1</span>)</span>
<span id="cb18-41"><a href="#cb18-41" aria-hidden="true" tabindex="-1"></a>        ) <span class="op">/</span> <span class="va">self</span>.contrastive_temperature</span>
<span id="cb18-42"><a href="#cb18-42" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-43"><a href="#cb18-43" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Targets: image i is paired with text i, so row i's correct class is i</span></span>
<span id="cb18-44"><a href="#cb18-44" aria-hidden="true" tabindex="-1"></a>        batch_size <span class="op">=</span> images.size(<span class="dv">0</span>)</span>
<span id="cb18-45"><a href="#cb18-45" aria-hidden="true" tabindex="-1"></a>        labels <span class="op">=</span> torch.arange(batch_size, device<span class="op">=</span><span class="va">self</span>.device)</span>
<span id="cb18-46"><a href="#cb18-46" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-47"><a href="#cb18-47" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Compute contrastive loss</span></span>
<span id="cb18-48"><a href="#cb18-48" aria-hidden="true" tabindex="-1"></a>        loss_i2t <span class="op">=</span> F.cross_entropy(similarity_matrix, labels)</span>
<span id="cb18-49"><a href="#cb18-49" aria-hidden="true" tabindex="-1"></a>        loss_t2i <span class="op">=</span> F.cross_entropy(similarity_matrix.transpose(<span class="dv">0</span>, <span class="dv">1</span>), labels)</span>
<span id="cb18-50"><a href="#cb18-50" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-51"><a href="#cb18-51" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> (loss_i2t <span class="op">+</span> loss_t2i) <span class="op">/</span> <span class="dv">2</span></span>
<span id="cb18-52"><a href="#cb18-52" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-53"><a href="#cb18-53" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> training_step(<span class="va">self</span>, batch, batch_idx):</span>
<span id="cb18-54"><a href="#cb18-54" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Combined self-supervised training step"""</span></span>
<span id="cb18-55"><a href="#cb18-55" aria-hidden="true" tabindex="-1"></a>        images <span class="op">=</span> batch[<span class="st">'images'</span>]</span>
<span id="cb18-56"><a href="#cb18-56" aria-hidden="true" tabindex="-1"></a>        texts <span class="op">=</span> batch[<span class="st">'texts'</span>]</span>
<span id="cb18-57"><a href="#cb18-57" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-58"><a href="#cb18-58" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Multiple self-supervised objectives</span></span>
<span id="cb18-59"><a href="#cb18-59" aria-hidden="true" tabindex="-1"></a>        mlm_loss <span class="op">=</span> <span class="va">self</span>.masked_language_modeling_loss(texts, images)</span>
<span id="cb18-60"><a href="#cb18-60" aria-hidden="true" tabindex="-1"></a>        contrastive_loss <span class="op">=</span> <span class="va">self</span>.image_text_contrastive_loss(images, texts)</span>
<span id="cb18-61"><a href="#cb18-61" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-62"><a href="#cb18-62" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Combined loss</span></span>
<span id="cb18-63"><a href="#cb18-63" aria-hidden="true" tabindex="-1"></a>        total_loss <span class="op">=</span> mlm_loss <span class="op">+</span> contrastive_loss</span>
<span id="cb18-64"><a href="#cb18-64" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-65"><a href="#cb18-65" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.log(<span class="st">'mlm_loss'</span>, mlm_loss)</span>
<span id="cb18-66"><a href="#cb18-66" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.log(<span class="st">'contrastive_loss'</span>, contrastive_loss) </span>
<span id="cb18-67"><a href="#cb18-67" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.log(<span class="st">'total_loss'</span>, total_loss)</span>
<span id="cb18-68"><a href="#cb18-68" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-69"><a href="#cb18-69" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> total_loss</span></code></pre></div></div>
</div>
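<p><code>masked_language_modeling_loss</code> calls <code>self.mask_tokens</code>, which the class does not define. One possible helper is sketched below in NumPy, under stated assumptions: <code>mask_token_id</code> and <code>mask_prob</code> are hypothetical parameters, and unlike BERT's 80/10/10 scheme every selected token is simply replaced with the mask id. Unmasked positions get the label <code>-100</code>, the default <code>ignore_index</code> of <code>nn.CrossEntropyLoss</code>, so only masked tokens contribute to the loss.</p>

```python
import numpy as np

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss

def mask_tokens(input_ids, mask_token_id, mask_prob=0.15, rng=None):
    """Simplified MLM masking (hypothetical helper): replace a random subset with [MASK].

    Returns (masked_ids, labels). labels hold the original id at masked
    positions and IGNORE_INDEX elsewhere, so the loss scores only masked tokens.
    """
    rng = np.random.default_rng() if rng is None else rng
    input_ids = np.asarray(input_ids)
    selected = mask_prob > rng.random(input_ids.shape)  # each token chosen with prob mask_prob
    labels = np.where(selected, input_ids, IGNORE_INDEX)
    masked_ids = np.where(selected, mask_token_id, input_ids)
    return masked_ids, labels

rng = np.random.default_rng(0)
ids = rng.integers(5, 1000, size=(4, 16))  # fake token ids for a batch of 4 captions
masked, labels = mask_tokens(ids, mask_token_id=103, rng=rng)  # 103 is an assumed [MASK] id
```

<p>A production version would also exclude padding and special tokens from masking, since predicting them teaches the model nothing useful.</p>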
<p><strong>Meta-learning</strong>: Enabling rapid adaptation to new tasks with minimal data.</p>
<div id="1bb34ea4" class="cell" data-execution_count="19">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb19"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb19-1"><a href="#cb19-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MAMLForVLM(nn.Module):</span>
<span id="cb19-2"><a href="#cb19-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Model-Agnostic Meta-Learning for VLMs"""</span></span>
<span id="cb19-3"><a href="#cb19-3" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb19-4"><a href="#cb19-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, base_model, meta_lr<span class="op">=</span><span class="fl">0.001</span>, inner_lr<span class="op">=</span><span class="fl">0.01</span>):</span>
<span id="cb19-5"><a href="#cb19-5" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb19-6"><a href="#cb19-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.base_model <span class="op">=</span> base_model</span>
<span id="cb19-7"><a href="#cb19-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.meta_lr <span class="op">=</span> meta_lr</span>
<span id="cb19-8"><a href="#cb19-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.inner_lr <span class="op">=</span> inner_lr</span>
<span id="cb19-9"><a href="#cb19-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.meta_optimizer <span class="op">=</span> torch.optim.Adam(</span>
<span id="cb19-10"><a href="#cb19-10" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.base_model.parameters(), </span>
<span id="cb19-11"><a href="#cb19-11" aria-hidden="true" tabindex="-1"></a>            lr<span class="op">=</span>meta_lr</span>
<span id="cb19-12"><a href="#cb19-12" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb19-13"><a href="#cb19-13" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb19-14"><a href="#cb19-14" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> inner_loop_update(<span class="va">self</span>, support_batch):</span>
<span id="cb19-15"><a href="#cb19-15" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Perform inner loop adaptation and return the adapted (fast) weights"""</span></span>
<span id="cb19-16"><a href="#cb19-16" aria-hidden="true" tabindex="-1"></a>        params <span class="op">=</span> <span class="bu">dict</span>(<span class="va">self</span>.base_model.named_parameters())</span>
<span id="cb19-17"><a href="#cb19-17" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb19-18"><a href="#cb19-18" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Compute loss on support set with the current meta-parameters</span></span>
<span id="cb19-19"><a href="#cb19-19" aria-hidden="true" tabindex="-1"></a>        support_loss <span class="op">=</span> <span class="va">self</span>.base_model(<span class="op">**</span>support_batch).loss</span>
<span id="cb19-20"><a href="#cb19-20" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb19-21"><a href="#cb19-21" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Compute gradients; create_graph=True keeps second-order terms</span></span>
<span id="cb19-22"><a href="#cb19-22" aria-hidden="true" tabindex="-1"></a>        grads <span class="op">=</span> torch.autograd.grad(</span>
<span id="cb19-23"><a href="#cb19-23" aria-hidden="true" tabindex="-1"></a>            support_loss, </span>
<span id="cb19-24"><a href="#cb19-24" aria-hidden="true" tabindex="-1"></a>            <span class="bu">list</span>(params.values()),</span>
<span id="cb19-25"><a href="#cb19-25" aria-hidden="true" tabindex="-1"></a>            create_graph<span class="op">=</span><span class="va">True</span></span>
<span id="cb19-26"><a href="#cb19-26" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb19-27"><a href="#cb19-27" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb19-28"><a href="#cb19-28" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Functional update: build new tensors instead of overwriting param.data,</span></span>
<span id="cb19-29"><a href="#cb19-29" aria-hidden="true" tabindex="-1"></a>        <span class="co"># so the meta-gradient can flow back to the original parameters</span></span>
<span id="cb19-30"><a href="#cb19-30" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {</span>
<span id="cb19-31"><a href="#cb19-31" aria-hidden="true" tabindex="-1"></a>            name: param <span class="op">-</span> <span class="va">self</span>.inner_lr <span class="op">*</span> grad</span>
<span id="cb19-32"><a href="#cb19-32" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> (name, param), grad <span class="kw">in</span> <span class="bu">zip</span>(params.items(), grads)</span>
<span id="cb19-33"><a href="#cb19-33" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb19-34"><a href="#cb19-34" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb19-35"><a href="#cb19-35" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> meta_update(<span class="va">self</span>, task_batch):</span>
<span id="cb19-36"><a href="#cb19-36" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Perform meta-learning update"""</span></span>
<span id="cb19-37"><a href="#cb19-37" aria-hidden="true" tabindex="-1"></a>        meta_loss <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb19-38"><a href="#cb19-38" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb19-39"><a href="#cb19-39" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> task <span class="kw">in</span> task_batch:</span>
<span id="cb19-40"><a href="#cb19-40" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Inner loop adaptation (returns a dict of fast weights)</span></span>
<span id="cb19-41"><a href="#cb19-41" aria-hidden="true" tabindex="-1"></a>            adapted_params <span class="op">=</span> <span class="va">self</span>.inner_loop_update(task[<span class="st">'support'</span>])</span>
<span id="cb19-42"><a href="#cb19-42" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb19-43"><a href="#cb19-43" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Query loss with the adapted weights swapped in (torch.func needs PyTorch 2.0+)</span></span>
<span id="cb19-44"><a href="#cb19-44" aria-hidden="true" tabindex="-1"></a>            query_loss <span class="op">=</span> torch.func.functional_call(</span>
<span id="cb19-45"><a href="#cb19-45" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.base_model, adapted_params, (), task[<span class="st">'query'</span>]</span>
<span id="cb19-46"><a href="#cb19-46" aria-hidden="true" tabindex="-1"></a>            ).loss</span>
<span id="cb19-47"><a href="#cb19-47" aria-hidden="true" tabindex="-1"></a>            meta_loss <span class="op">+=</span> query_loss</span>
<span id="cb19-48"><a href="#cb19-48" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb19-49"><a href="#cb19-49" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Meta gradient step</span></span>
<span id="cb19-50"><a href="#cb19-50" aria-hidden="true" tabindex="-1"></a>        meta_loss <span class="op">/=</span> <span class="bu">len</span>(task_batch)</span>
<span id="cb19-51"><a href="#cb19-51" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.meta_optimizer.zero_grad()</span>
<span id="cb19-52"><a href="#cb19-52" aria-hidden="true" tabindex="-1"></a>        meta_loss.backward()</span>
<span id="cb19-53"><a href="#cb19-53" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.meta_optimizer.step()</span>
<span id="cb19-54"><a href="#cb19-54" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb19-55"><a href="#cb19-55" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> meta_loss</span></code></pre></div></div>
</div>
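<p>The MAML meta-gradient can be checked by hand on a toy problem, independent of any VLM. Assume each task is the scalar quadratic loss (θ − c)² with its own target c, adapted with one inner gradient step of size α. The chain rule then gives a closed-form meta-gradient whose extra factor (1 − 2α) is exactly the second-order term that <code>create_graph=True</code> preserves. The task setup and function names are illustrative assumptions.</p>

```python
# Toy MAML on scalar quadratic tasks L(theta) = (theta - target) ** 2 (illustrative assumption).

def adapt(theta, target, inner_lr):
    """One inner gradient step on the task loss (its gradient is 2 * (theta - target))."""
    return theta - inner_lr * 2 * (theta - target)

def query_loss(theta, target, inner_lr):
    """Loss after adaptation; MAML differentiates this with respect to the initial theta."""
    return (adapt(theta, target, inner_lr) - target) ** 2

def maml_meta_gradient(theta, target, inner_lr):
    """Chain rule: 2 * (theta' - target) * d(theta')/d(theta), with d(theta')/d(theta) = 1 - 2 * inner_lr."""
    return 2 * (adapt(theta, target, inner_lr) - target) * (1 - 2 * inner_lr)

def maml_train(targets, theta=0.0, inner_lr=0.1, meta_lr=0.1, steps=200):
    """Outer loop: average the meta-gradient over tasks, then take a gradient step."""
    for _ in range(steps):
        grad = sum(maml_meta_gradient(theta, t, inner_lr) for t in targets) / len(targets)
        theta -= meta_lr * grad
    return theta

theta = maml_train([1.0, 2.0, 6.0])
print(theta)  # converges to the mean target
```

<p>For targets 1, 2, and 6 the meta-learned initialization converges to their mean, 3.0, the starting point from which one inner step serves all three tasks equally well. First-order MAML would drop the (1 − 2α) factor, which here only rescales the step but in general changes the update direction.</p>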
<p><strong>Continual Learning</strong>: Developing methods that keep learning from new tasks and data over time without catastrophically forgetting earlier ones.</p>
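<p>One common family of continual-learning methods is rehearsal: keep a small memory of earlier examples and mix them into new training batches so old tasks stay represented. The following is a minimal sketch of that memory, assuming reservoir sampling as the selection rule; the class name is hypothetical.</p>

```python
import random

class ReservoirReplayBuffer:
    """Fixed-size rehearsal memory filled by reservoir sampling (hypothetical class).

    After n items have been offered, each of them is held in the buffer with
    the same probability capacity / n, so early tasks are not crowded out.
    """

    def __init__(self, capacity, seed=None):
        self.capacity = capacity
        self.items = []
        self.num_seen = 0
        self.rng = random.Random(seed)

    def add(self, item):
        self.num_seen += 1
        if self.capacity > len(self.items):
            self.items.append(item)  # buffer not full yet: always keep
            return
        # Buffer full: draw a slot in [0, num_seen); overwrite only if it is a real slot,
        # which happens with probability capacity / num_seen
        slot = self.rng.randrange(self.num_seen)
        if self.capacity > slot:
            self.items[slot] = item

    def sample(self, k):
        """Draw up to k stored items without replacement for a rehearsal batch."""
        return self.rng.sample(self.items, min(k, len(self.items)))

buffer = ReservoirReplayBuffer(capacity=10, seed=0)
for example_id in range(1000):  # stand-in for a stream of examples from several tasks
    buffer.add(example_id)
replay_batch = buffer.sample(4)
```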
</section>
<section id="application-domains" class="level3">
<h3 class="anchored" data-anchor-id="application-domains" id="application-domains">Application Domains</h3>
<p><strong>Embodied AI</strong>: Integration with robotics for real-world interaction.</p>
<div id="1761ad86" class="cell" data-execution_count="20">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb20"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb20-1"><a href="#cb20-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> EmbodiedVLMAgent:</span>
<span id="cb20-2"><a href="#cb20-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""VLM agent for embodied AI applications"""</span></span>
<span id="cb20-3"><a href="#cb20-3" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb20-4"><a href="#cb20-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, vlm_model, action_decoder, environment_interface):</span>
<span id="cb20-5"><a href="#cb20-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.vlm_model <span class="op">=</span> vlm_model</span>
<span id="cb20-6"><a href="#cb20-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.action_decoder <span class="op">=</span> action_decoder</span>
<span id="cb20-7"><a href="#cb20-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.environment <span class="op">=</span> environment_interface</span>
<span id="cb20-8"><a href="#cb20-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.memory <span class="op">=</span> []</span>
<span id="cb20-9"><a href="#cb20-9" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb20-10"><a href="#cb20-10" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> perceive_and_act(<span class="va">self</span>, observation):</span>
<span id="cb20-11"><a href="#cb20-11" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Main perception-action loop"""</span></span>
<span id="cb20-12"><a href="#cb20-12" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Process visual observation</span></span>
<span id="cb20-13"><a href="#cb20-13" aria-hidden="true" tabindex="-1"></a>        visual_features <span class="op">=</span> <span class="va">self</span>.vlm_model.encode_image(observation[<span class="st">'image'</span>])</span>
<span id="cb20-14"><a href="#cb20-14" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb20-15"><a href="#cb20-15" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Process textual instruction</span></span>
<span id="cb20-16"><a href="#cb20-16" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="st">'instruction'</span> <span class="kw">in</span> observation:</span>
<span id="cb20-17"><a href="#cb20-17" aria-hidden="true" tabindex="-1"></a>            text_features <span class="op">=</span> <span class="va">self</span>.vlm_model.encode_text(observation[<span class="st">'instruction'</span>])</span>
<span id="cb20-18"><a href="#cb20-18" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb20-19"><a href="#cb20-19" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Fuse multimodal information</span></span>
<span id="cb20-20"><a href="#cb20-20" aria-hidden="true" tabindex="-1"></a>            fused_features <span class="op">=</span> <span class="va">self</span>.vlm_model.fuse_modalities(</span>
<span id="cb20-21"><a href="#cb20-21" aria-hidden="true" tabindex="-1"></a>                visual_features, text_features</span>
<span id="cb20-22"><a href="#cb20-22" aria-hidden="true" tabindex="-1"></a>            )</span>
<span id="cb20-23"><a href="#cb20-23" aria-hidden="true" tabindex="-1"></a>        <span class="cf">else</span>:</span>
<span id="cb20-24"><a href="#cb20-24" aria-hidden="true" tabindex="-1"></a>            fused_features <span class="op">=</span> visual_features</span>
<span id="cb20-25"><a href="#cb20-25" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb20-26"><a href="#cb20-26" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Generate action</span></span>
<span id="cb20-27"><a href="#cb20-27" aria-hidden="true" tabindex="-1"></a>        action_logits <span class="op">=</span> <span class="va">self</span>.action_decoder(fused_features)</span>
<span id="cb20-28"><a href="#cb20-28" aria-hidden="true" tabindex="-1"></a>        action <span class="op">=</span> torch.argmax(action_logits, dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb20-29"><a href="#cb20-29" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb20-30"><a href="#cb20-30" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Store in memory for future learning</span></span>
<span id="cb20-31"><a href="#cb20-31" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.memory.append({</span>
<span id="cb20-32"><a href="#cb20-32" aria-hidden="true" tabindex="-1"></a>            <span class="st">'observation'</span>: observation,</span>
<span id="cb20-33"><a href="#cb20-33" aria-hidden="true" tabindex="-1"></a>            <span class="st">'action'</span>: action,</span>
<span id="cb20-34"><a href="#cb20-34" aria-hidden="true" tabindex="-1"></a>            <span class="st">'features'</span>: fused_features</span>
<span id="cb20-35"><a href="#cb20-35" aria-hidden="true" tabindex="-1"></a>        })</span>
<span id="cb20-36"><a href="#cb20-36" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb20-37"><a href="#cb20-37" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> action</span>
<span id="cb20-38"><a href="#cb20-38" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb20-39"><a href="#cb20-39" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> learn_from_interaction(<span class="va">self</span>):</span>
<span id="cb20-40"><a href="#cb20-40" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Learn from stored interactions"""</span></span>
<span id="cb20-41"><a href="#cb20-41" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="bu">len</span>(<span class="va">self</span>.memory) <span class="op">&lt;</span> <span class="dv">100</span>:  <span class="co"># Wait until enough interactions are stored to sample from</span></span>
<span id="cb20-42"><a href="#cb20-42" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span></span>
<span id="cb20-43"><a href="#cb20-43" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb20-44"><a href="#cb20-44" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Sample batch from memory</span></span>
<span id="cb20-45"><a href="#cb20-45" aria-hidden="true" tabindex="-1"></a>        batch <span class="op">=</span> random.sample(<span class="va">self</span>.memory, <span class="dv">32</span>)</span>
<span id="cb20-46"><a href="#cb20-46" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb20-47"><a href="#cb20-47" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Implement learning algorithm (e.g., reinforcement learning)</span></span>
<span id="cb20-48"><a href="#cb20-48" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.update_policy(batch)</span>
<span id="cb20-49"><a href="#cb20-49" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb20-50"><a href="#cb20-50" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> update_policy(<span class="va">self</span>, batch):</span>
<span id="cb20-51"><a href="#cb20-51" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Update policy based on interaction data"""</span></span>
<span id="cb20-52"><a href="#cb20-52" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Implementation would depend on specific RL algorithm</span></span>
<span id="cb20-53"><a href="#cb20-53" aria-hidden="true" tabindex="-1"></a>        <span class="cf">pass</span></span></code></pre></div></div>
</div>
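<p>The <code>update_policy</code> stub names reinforcement learning without choosing an algorithm. As one possible choice (an assumption, not the article's method), REINFORCE weights the log-probability of each taken action by the discounted return that followed it. The NumPy sketch computes that objective for a single episode. The function names are hypothetical. In PyTorch the same expression, built from <code>action_logits</code>, would be minimized with <code>backward()</code>.</p>

```python
import numpy as np

# Illustrative REINFORCE sketch; function names are hypothetical.

def discounted_returns(rewards, gamma=0.99):
    """G_t = r_t + gamma * G_{t+1}, computed backwards over one episode."""
    returns = np.zeros(len(rewards))
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

def reinforce_loss(action_logits, actions, rewards, gamma=0.99):
    """Value of -sum_t log pi(a_t | s_t) * G_t, the quantity REINFORCE minimizes."""
    logits = action_logits - action_logits.max(axis=-1, keepdims=True)  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    chosen = log_probs[np.arange(len(actions)), actions]  # log-prob of each taken action
    return -(chosen * discounted_returns(rewards, gamma)).sum()

print(discounted_returns([1.0, 1.0, 1.0], gamma=0.5))
```

<p>Actions followed by high returns get their probabilities pushed up more strongly; in practice a baseline is usually subtracted from the returns to reduce variance.</p>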
<p><strong>Creative Applications</strong>: Advanced content generation for art, design, and entertainment.</p>
<p><strong>Scientific Discovery</strong>: Automated analysis and insight generation from scientific imagery.</p>
</section>
<section id="ethical-considerations" class="level3">
<h3 class="anchored" data-anchor-id="ethical-considerations" id="ethical-considerations">Ethical Considerations</h3>
<p><strong>Bias Mitigation</strong>: Developing techniques to reduce harmful biases in multimodal models.</p>
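<p>A performance-bias audit often reduces to comparing one metric across groups. A minimal sketch of the accuracy-parity check follows, assuming each test example already carries a group label. The function names are hypothetical and are not part of the framework below, and accuracy parity is only one of several fairness criteria.</p>

```python
# Illustrative sketch; function names are hypothetical.

def accuracy_by_group(predictions, labels, groups):
    """Accuracy per group; the three sequences are aligned by index."""
    correct, total = {}, {}
    for pred, label, group in zip(predictions, labels, groups):
        total[group] = total.get(group, 0) + 1
        correct[group] = correct.get(group, 0) + int(pred == label)
    return {g: correct[g] / total[g] for g in total}

def max_accuracy_gap(per_group):
    """Largest accuracy difference between any two groups (0.0 means parity)."""
    return max(per_group.values()) - min(per_group.values())

per_group = accuracy_by_group(
    predictions=[1, 0, 1, 1, 0, 0],
    labels=[1, 0, 1, 0, 1, 0],
    groups=['a', 'a', 'a', 'b', 'b', 'b'],
)
print(per_group, max_accuracy_gap(per_group))
```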
<div id="b290ffe8" class="cell" data-execution_count="21">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb21"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb21-1"><a href="#cb21-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> BiasAuditingFramework:</span>
<span id="cb21-2"><a href="#cb21-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Framework for auditing bias in VLMs"""</span></span>
<span id="cb21-3"><a href="#cb21-3" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb21-4"><a href="#cb21-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, protected_attributes<span class="op">=</span>(<span class="st">'gender'</span>, <span class="st">'race'</span>, <span class="st">'age'</span>)):</span>
<span id="cb21-5"><a href="#cb21-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.protected_attributes <span class="op">=</span> <span class="bu">list</span>(protected_attributes)  <span class="co"># copy avoids a shared mutable default</span></span>
<span id="cb21-6"><a href="#cb21-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.bias_metrics <span class="op">=</span> {}</span>
<span id="cb21-7"><a href="#cb21-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb21-8"><a href="#cb21-8" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> measure_representation_bias(<span class="va">self</span>, model, dataset):</span>
<span id="cb21-9"><a href="#cb21-9" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Measure bias in data representation"""</span></span>
<span id="cb21-10"><a href="#cb21-10" aria-hidden="true" tabindex="-1"></a>        attribute_counts <span class="op">=</span> {attr: {} <span class="cf">for</span> attr <span class="kw">in</span> <span class="va">self</span>.protected_attributes}</span>
<span id="cb21-11"><a href="#cb21-11" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb21-12"><a href="#cb21-12" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> batch <span class="kw">in</span> dataset:</span>
<span id="cb21-13"><a href="#cb21-13" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Analyze demographic representation in images</span></span>
<span id="cb21-14"><a href="#cb21-14" aria-hidden="true" tabindex="-1"></a>            detected_attributes <span class="op">=</span> <span class="va">self</span>.detect_demographic_attributes(</span>
<span id="cb21-15"><a href="#cb21-15" aria-hidden="true" tabindex="-1"></a>                batch[<span class="st">'images'</span>]</span>
<span id="cb21-16"><a href="#cb21-16" aria-hidden="true" tabindex="-1"></a>            )</span>
<span id="cb21-17"><a href="#cb21-17" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb21-18"><a href="#cb21-18" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> attr <span class="kw">in</span> <span class="va">self</span>.protected_attributes:</span>
<span id="cb21-19"><a href="#cb21-19" aria-hidden="true" tabindex="-1"></a>                <span class="cf">for</span> value <span class="kw">in</span> detected_attributes[attr]:</span>
<span id="cb21-20"><a href="#cb21-20" aria-hidden="true" tabindex="-1"></a>                    <span class="cf">if</span> value <span class="kw">not</span> <span class="kw">in</span> attribute_counts[attr]:</span>
<span id="cb21-21"><a href="#cb21-21" aria-hidden="true" tabindex="-1"></a>                        attribute_counts[attr][value] <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb21-22"><a href="#cb21-22" aria-hidden="true" tabindex="-1"></a>                    attribute_counts[attr][value] <span class="op">+=</span> <span class="dv">1</span></span>
<span id="cb21-23"><a href="#cb21-23" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb21-24"><a href="#cb21-24" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> attribute_counts</span>
<span id="cb21-25"><a href="#cb21-25" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb21-26"><a href="#cb21-26" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> measure_performance_bias(<span class="va">self</span>, model, test_sets_by_group):</span>
<span id="cb21-27"><a href="#cb21-27" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Measure performance differences across demographic groups"""</span></span>
<span id="cb21-28"><a href="#cb21-28" aria-hidden="true" tabindex="-1"></a>        performance_by_group <span class="op">=</span> {}</span>
<span id="cb21-29"><a href="#cb21-29" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb21-30"><a href="#cb21-30" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> group, test_set <span class="kw">in</span> test_sets_by_group.items():</span>
<span id="cb21-31"><a href="#cb21-31" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Evaluate model performance on each group</span></span>
<span id="cb21-32"><a href="#cb21-32" aria-hidden="true" tabindex="-1"></a>            metrics <span class="op">=</span> <span class="va">self</span>.evaluate_model_performance(model, test_set)</span>
<span id="cb21-33"><a href="#cb21-33" aria-hidden="true" tabindex="-1"></a>            performance_by_group[group] <span class="op">=</span> metrics</span>
<span id="cb21-34"><a href="#cb21-34" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb21-35"><a href="#cb21-35" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Calculate disparate impact</span></span>
<span id="cb21-36"><a href="#cb21-36" aria-hidden="true" tabindex="-1"></a>        disparate_impact <span class="op">=</span> <span class="va">self</span>.calculate_disparate_impact(performance_by_group)</span>
<span id="cb21-37"><a href="#cb21-37" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb21-38"><a href="#cb21-38" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> performance_by_group, disparate_impact</span>
<span id="cb21-39"><a href="#cb21-39" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb21-40"><a href="#cb21-40" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> detect_demographic_attributes(<span class="va">self</span>, images):</span>
<span id="cb21-41"><a href="#cb21-41" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Detect demographic attributes in images"""</span></span>
<span id="cb21-42"><a href="#cb21-42" aria-hidden="true" tabindex="-1"></a>        <span class="co"># This would use specialized models; expected to return {attribute: [values, ...]}</span></span>
<span id="cb21-43"><a href="#cb21-43" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Implementation should be careful about privacy and consent</span></span>
<span id="cb21-44"><a href="#cb21-44" aria-hidden="true" tabindex="-1"></a>        <span class="cf">raise</span> <span class="pp">NotImplementedError</span></span>
<span id="cb21-45"><a href="#cb21-45" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb21-46"><a href="#cb21-46" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> generate_bias_report(<span class="va">self</span>, model, datasets):</span>
<span id="cb21-47"><a href="#cb21-47" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Generate comprehensive bias audit report"""</span></span>
<span id="cb21-48"><a href="#cb21-48" aria-hidden="true" tabindex="-1"></a>        report <span class="op">=</span> {</span>
<span id="cb21-49"><a href="#cb21-49" aria-hidden="true" tabindex="-1"></a>            <span class="st">'representation_bias'</span>: <span class="va">self</span>.measure_representation_bias(</span>
<span id="cb21-50"><a href="#cb21-50" aria-hidden="true" tabindex="-1"></a>                model, datasets[<span class="st">'train'</span>]</span>
<span id="cb21-51"><a href="#cb21-51" aria-hidden="true" tabindex="-1"></a>            ),</span>
<span id="cb21-52"><a href="#cb21-52" aria-hidden="true" tabindex="-1"></a>            <span class="st">'performance_bias'</span>: <span class="va">self</span>.measure_performance_bias(</span>
<span id="cb21-53"><a href="#cb21-53" aria-hidden="true" tabindex="-1"></a>                model, datasets[<span class="st">'test_by_group'</span>]</span>
<span id="cb21-54"><a href="#cb21-54" aria-hidden="true" tabindex="-1"></a>            ),</span>
<span id="cb21-55"><a href="#cb21-55" aria-hidden="true" tabindex="-1"></a>            <span class="st">'recommendations'</span>: <span class="va">self</span>.generate_mitigation_recommendations()</span>
<span id="cb21-56"><a href="#cb21-56" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb21-57"><a href="#cb21-57" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb21-58"><a href="#cb21-58" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> report</span>
<span id="cb21-59"><a href="#cb21-59" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-60"><a href="#cb21-60" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> BiasMitigationTraining:</span>
<span id="cb21-61"><a href="#cb21-61" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Training framework with bias mitigation"""</span></span>
<span id="cb21-62"><a href="#cb21-62" aria-hidden="true" tabindex="-1"></a>    </span>
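<span id="cb21-63"><a href="#cb21-63" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, model, fairness_constraints, adversarial_classifier):</span>
<span id="cb21-64"><a href="#cb21-64" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> model</span>
<span id="cb21-65"><a href="#cb21-65" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.fairness_constraints <span class="op">=</span> fairness_constraints</span>
<span id="cb21-65a"><a href="#cb21-65a" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.adversarial_classifier <span class="op">=</span> adversarial_classifier  <span class="co"># e.g. nn.Linear(hidden_dim, num_groups)</span></span>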
<span id="cb21-66"><a href="#cb21-66" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb21-67"><a href="#cb21-67" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> adversarial_debiasing_loss(<span class="va">self</span>, outputs, protected_attributes):</span>
<span id="cb21-68"><a href="#cb21-68" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Adversarial loss for bias mitigation"""</span></span>
<span id="cb21-69"><a href="#cb21-69" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Adversary reads mean-pooled last-layer states (needs output_hidden_states=True)</span></span>
<span id="cb21-70"><a href="#cb21-70" aria-hidden="true" tabindex="-1"></a>        adversarial_logits <span class="op">=</span> <span class="va">self</span>.adversarial_classifier(outputs.hidden_states[<span class="op">-</span><span class="dv">1</span>].mean(dim<span class="op">=</span><span class="dv">1</span>))</span>
<span id="cb21-71"><a href="#cb21-71" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb21-72"><a href="#cb21-72" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Negated: rewards features that hide protected attributes (train the adversary separately on the positive loss)</span></span>
<span id="cb21-73"><a href="#cb21-73" aria-hidden="true" tabindex="-1"></a>        adversarial_loss <span class="op">=</span> <span class="op">-</span>F.cross_entropy(</span>
<span id="cb21-74"><a href="#cb21-74" aria-hidden="true" tabindex="-1"></a>            adversarial_logits, </span>
<span id="cb21-75"><a href="#cb21-75" aria-hidden="true" tabindex="-1"></a>            protected_attributes</span>
<span id="cb21-76"><a href="#cb21-76" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb21-77"><a href="#cb21-77" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb21-78"><a href="#cb21-78" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> adversarial_loss</span>
<span id="cb21-79"><a href="#cb21-79" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb21-80"><a href="#cb21-80" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> fairness_regularization_loss(<span class="va">self</span>, predictions, groups):</span>
<span id="cb21-81"><a href="#cb21-81" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Regularization term for fairness"""</span></span>
<span id="cb21-82"><a href="#cb21-82" aria-hidden="true" tabindex="-1"></a>        group_losses <span class="op">=</span> {}</span>
<span id="cb21-83"><a href="#cb21-83" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb21-84"><a href="#cb21-84" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> group <span class="kw">in</span> torch.unique(groups):</span>
<span id="cb21-85"><a href="#cb21-85" aria-hidden="true" tabindex="-1"></a>            group_mask <span class="op">=</span> (groups <span class="op">==</span> group)</span>
<span id="cb21-86"><a href="#cb21-86" aria-hidden="true" tabindex="-1"></a>            group_predictions <span class="op">=</span> predictions[group_mask]</span>
<span id="cb21-87"><a href="#cb21-87" aria-hidden="true" tabindex="-1"></a>            group_losses[group.item()] <span class="op">=</span> F.mse_loss(</span>
<span id="cb21-88"><a href="#cb21-88" aria-hidden="true" tabindex="-1"></a>                group_predictions, </span>
<span id="cb21-89"><a href="#cb21-89" aria-hidden="true" tabindex="-1"></a>                torch.ones_like(group_predictions) <span class="op">*</span> <span class="fl">0.5</span></span>
<span id="cb21-90"><a href="#cb21-90" aria-hidden="true" tabindex="-1"></a>            )</span>
<span id="cb21-91"><a href="#cb21-91" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb21-92"><a href="#cb21-92" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Minimize the spread of group losses (population variance: 0, not NaN, for a single group)</span></span>
<span id="cb21-93"><a href="#cb21-93" aria-hidden="true" tabindex="-1"></a>        loss_values <span class="op">=</span> <span class="bu">list</span>(group_losses.values())</span>
<span id="cb21-94"><a href="#cb21-94" aria-hidden="true" tabindex="-1"></a>        fairness_loss <span class="op">=</span> torch.var(torch.stack(loss_values), unbiased<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb21-95"><a href="#cb21-95" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb21-96"><a href="#cb21-96" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> fairness_loss</span>
<span id="cb21-97"><a href="#cb21-97" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb21-98"><a href="#cb21-98" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> training_step_with_fairness(<span class="va">self</span>, batch):</span>
<span id="cb21-99"><a href="#cb21-99" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Training step with fairness constraints"""</span></span>
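<span id="cb21-100"><a href="#cb21-100" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Standard forward pass; fairness labels are not model inputs</span></span>
<span id="cb21-100a"><a href="#cb21-100a" aria-hidden="true" tabindex="-1"></a>        model_inputs <span class="op">=</span> {k: v <span class="cf">for</span> k, v <span class="kw">in</span> batch.items() <span class="cf">if</span> k <span class="kw">not</span> <span class="kw">in</span> (<span class="st">'protected_attributes'</span>, <span class="st">'groups'</span>)}</span>
<span id="cb21-101"><a href="#cb21-101" aria-hidden="true" tabindex="-1"></a>        outputs <span class="op">=</span> <span class="va">self</span>.model(<span class="op">**</span>model_inputs, output_hidden_states<span class="op">=</span><span class="va">True</span>)</span>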
<span id="cb21-102"><a href="#cb21-102" aria-hidden="true" tabindex="-1"></a>        standard_loss <span class="op">=</span> outputs.loss</span>
<span id="cb21-103"><a href="#cb21-103" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb21-104"><a href="#cb21-104" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Fairness-aware losses</span></span>
<span id="cb21-105"><a href="#cb21-105" aria-hidden="true" tabindex="-1"></a>        adversarial_loss <span class="op">=</span> <span class="va">self</span>.adversarial_debiasing_loss(</span>
<span id="cb21-106"><a href="#cb21-106" aria-hidden="true" tabindex="-1"></a>            outputs, batch[<span class="st">'protected_attributes'</span>]</span>
<span id="cb21-107"><a href="#cb21-107" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb21-108"><a href="#cb21-108" aria-hidden="true" tabindex="-1"></a>        fairness_loss <span class="op">=</span> <span class="va">self</span>.fairness_regularization_loss(</span>
<span id="cb21-109"><a href="#cb21-109" aria-hidden="true" tabindex="-1"></a>            outputs.logits, batch[<span class="st">'groups'</span>]</span>
<span id="cb21-110"><a href="#cb21-110" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb21-111"><a href="#cb21-111" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb21-112"><a href="#cb21-112" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Combined loss</span></span>
<span id="cb21-113"><a href="#cb21-113" aria-hidden="true" tabindex="-1"></a>        total_loss <span class="op">=</span> (standard_loss <span class="op">+</span> </span>
<span id="cb21-114"><a href="#cb21-114" aria-hidden="true" tabindex="-1"></a>                     <span class="fl">0.1</span> <span class="op">*</span> adversarial_loss <span class="op">+</span> </span>
<span id="cb21-115"><a href="#cb21-115" aria-hidden="true" tabindex="-1"></a>                     <span class="fl">0.1</span> <span class="op">*</span> fairness_loss)</span>
<span id="cb21-116"><a href="#cb21-116" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb21-117"><a href="#cb21-117" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> total_loss</span></code></pre></div></div>
</div>
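<p>The negated cross-entropy in <code>adversarial_debiasing_loss</code> only has the intended effect if the adversarial classifier is trained in a separate step to minimize its own loss; if one optimizer updates the adversary on the negated loss, the adversary is simply pushed to get worse. A common alternative from domain-adversarial training is a gradient reversal layer: the adversary minimizes an ordinary cross-entropy, while the gradient flowing back into the encoder is negated. The sketch below is a minimal PyTorch version; <code>GradientReversal</code>, <code>AdversarialHead</code>, the pooled-feature input, and the <code>lambd</code> weight are illustrative assumptions, not part of the framework above.</p>

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GradientReversal(torch.autograd.Function):
    """Identity on the forward pass; multiplies gradients by -lambd on the backward pass."""

    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse and scale the gradient flowing back into the encoder;
        # lambd is a plain float, so it receives no gradient (None).
        return -ctx.lambd * grad_output, None


class AdversarialHead(nn.Module):
    """Hypothetical adversary predicting a protected attribute from pooled features."""

    def __init__(self, hidden_dim, num_groups, lambd=0.1):
        super().__init__()
        self.lambd = lambd
        self.classifier = nn.Linear(hidden_dim, num_groups)

    def forward(self, features, protected_labels):
        reversed_features = GradientReversal.apply(features, self.lambd)
        logits = self.classifier(reversed_features)
        # Positive cross-entropy: the head learns to predict the attribute,
        # while the reversed gradient pushes the encoder to hide it.
        return F.cross_entropy(logits, protected_labels)


# Toy check with random stand-in "encoder" features
torch.manual_seed(0)
features = torch.randn(8, 16, requires_grad=True)
labels = torch.randint(0, 2, (8,))
head = AdversarialHead(hidden_dim=16, num_groups=2, lambd=1.0)
loss = head(features, labels)
loss.backward()
```

Because the reversal happens inside autograd, the head's loss can be added to the task loss and optimized with a single optimizer.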
<p><strong>Fairness and Inclusivity</strong>: Ensuring equitable performance across different demographic groups.</p>
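<p>One simple way to quantify equitable performance is to compare a per-group metric such as accuracy: report the worst and best groups, their absolute gap, and the worst-to-best ratio. The ratio is in the spirit of the disparate impact ratio, which in its original form compares selection rates and is often checked against a 0.8 "four-fifths" threshold. The helper name and the example accuracies below are hypothetical.</p>

```python
def summarize_group_disparity(metric_by_group):
    """Summarize how a per-group metric (higher is better) varies across groups.

    metric_by_group: dict mapping group name to a metric value such as accuracy.
    The function name and output keys are illustrative, not a standard API.
    """
    if not metric_by_group:
        raise ValueError("need at least one group")
    worst_group = min(metric_by_group, key=metric_by_group.get)
    best_group = max(metric_by_group, key=metric_by_group.get)
    worst, best = metric_by_group[worst_group], metric_by_group[best_group]
    return {
        "worst_group": worst_group,
        "best_group": best_group,
        "max_gap": best - worst,  # absolute difference between best and worst
        # Worst-to-best ratio; 1.0 means parity
        "min_max_ratio": worst / best if best else float("nan"),
    }


# Hypothetical per-group VQA accuracies
report = summarize_group_disparity({"group_a": 0.82, "group_b": 0.74, "group_c": 0.80})
print(report["worst_group"], round(report["max_gap"], 2), round(report["min_max_ratio"], 3))
```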
<p><strong>Privacy and Security</strong>: Protecting sensitive information in multimodal datasets and models.</p>
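<p>Differentially private training with DP-SGD, which libraries such as Opacus implement, changes how each batch gradient is formed: every per-example gradient is clipped to an L2 norm of at most <code>max_grad_norm</code>, the clipped gradients are summed, and Gaussian noise with standard deviation <code>noise_multiplier * max_grad_norm</code> is added before averaging. The NumPy sketch below shows only that aggregation step on precomputed per-example gradients; it is an illustrative assumption and does not perform the privacy accounting that turns the noise level into an <span class="math inline">\(\varepsilon\)</span> guarantee.</p>

```python
import numpy as np


def dp_sgd_aggregate(per_example_grads, max_grad_norm, noise_multiplier, rng):
    """One DP-SGD gradient aggregation step (per-example clipping + Gaussian noise).

    per_example_grads: array of shape (batch_size, num_params).
    Returns the noisy average gradient. Epsilon accounting is not done here.
    """
    norms = np.linalg.norm(per_example_grads, axis=1, keepdims=True)
    # Scale each example's gradient so its L2 norm is at most max_grad_norm
    clip_factors = np.minimum(1.0, max_grad_norm / np.maximum(norms, 1e-12))
    clipped = per_example_grads * clip_factors
    # Add noise calibrated to the clipping bound, then average over the batch
    noise = rng.normal(0.0, noise_multiplier * max_grad_norm, size=clipped.shape[1])
    return (clipped.sum(axis=0) + noise) / per_example_grads.shape[0]


rng = np.random.default_rng(0)
grads = np.array([[3.0, 4.0], [0.3, 0.4]])  # L2 norms 5.0 and 0.5
# With zero noise: the first gradient is clipped to [0.6, 0.8], then averaged
noisy_grad = dp_sgd_aggregate(grads, max_grad_norm=1.0, noise_multiplier=0.0, rng=rng)
print(noisy_grad)
```

In a real setup the privacy engine computes the per-example gradients and chooses the noise multiplier that meets the target privacy budget.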
<div id="04697a52" class="cell" data-execution_count="22">
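<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb22"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb22-0a"><a href="#cb22-0a" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb22-0b"><a href="#cb22-0b" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb22-1"><a href="#cb22-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> PrivacyPreservingVLM:</span>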
<span id="cb22-2"><a href="#cb22-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Privacy-preserving techniques for VLMs"""</span></span>
<span id="cb22-3"><a href="#cb22-3" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb22-4"><a href="#cb22-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, model, privacy_budget<span class="op">=</span><span class="fl">1.0</span>):</span>
<span id="cb22-5"><a href="#cb22-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> model</span>
<span id="cb22-6"><a href="#cb22-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.privacy_budget <span class="op">=</span> privacy_budget</span>
<span id="cb22-7"><a href="#cb22-7" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Opacus derives the noise multiplier from privacy_budget in make_private_with_epsilon</span></span>
<span id="cb22-8"><a href="#cb22-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb22-9"><a href="#cb22-9" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> differential_private_training(<span class="va">self</span>, dataloader):</span>
<span id="cb22-10"><a href="#cb22-10" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Wrap model, optimizer, and data loader with Opacus for DP-SGD training"""</span></span>
<span id="cb22-11"><a href="#cb22-11" aria-hidden="true" tabindex="-1"></a>        <span class="im">from</span> opacus <span class="im">import</span> PrivacyEngine</span>
<span id="cb22-12"><a href="#cb22-12" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb22-13"><a href="#cb22-13" aria-hidden="true" tabindex="-1"></a>        privacy_engine <span class="op">=</span> PrivacyEngine()</span>
<span id="cb22-14"><a href="#cb22-14" aria-hidden="true" tabindex="-1"></a>        model, optimizer, dataloader <span class="op">=</span> privacy_engine.make_private_with_epsilon(</span>
<span id="cb22-15"><a href="#cb22-15" aria-hidden="true" tabindex="-1"></a>            module<span class="op">=</span><span class="va">self</span>.model,</span>
<span id="cb22-16"><a href="#cb22-16" aria-hidden="true" tabindex="-1"></a>            optimizer<span class="op">=</span>torch.optim.AdamW(<span class="va">self</span>.model.parameters()),</span>
<span id="cb22-17"><a href="#cb22-17" aria-hidden="true" tabindex="-1"></a>            data_loader<span class="op">=</span>dataloader,</span>
<span id="cb22-18"><a href="#cb22-18" aria-hidden="true" tabindex="-1"></a>            epochs<span class="op">=</span><span class="dv">10</span>,</span>
<span id="cb22-19"><a href="#cb22-19" aria-hidden="true" tabindex="-1"></a>            target_epsilon<span class="op">=</span><span class="va">self</span>.privacy_budget,</span>
<span id="cb22-20"><a href="#cb22-20" aria-hidden="true" tabindex="-1"></a>            target_delta<span class="op">=</span><span class="fl">1e-5</span>,</span>
<span id="cb22-21"><a href="#cb22-21" aria-hidden="true" tabindex="-1"></a>            max_grad_norm<span class="op">=</span><span class="fl">1.0</span>,</span>
<span id="cb22-22"><a href="#cb22-22" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb22-23"><a href="#cb22-23" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb22-24"><a href="#cb22-24" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> model, optimizer, dataloader</span>
<span id="cb22-25"><a href="#cb22-25" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb22-26"><a href="#cb22-26" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> federated_learning_setup(<span class="va">self</span>, client_data):</span>
<span id="cb22-27"><a href="#cb22-27" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Setup for federated learning"""</span></span>
<span id="cb22-28"><a href="#cb22-28" aria-hidden="true" tabindex="-1"></a>        <span class="im">import</span> flwr <span class="im">as</span> fl</span>
<span id="cb22-29"><a href="#cb22-29" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb22-30"><a href="#cb22-30" aria-hidden="true" tabindex="-1"></a>        <span class="kw">class</span> VLMClient(fl.client.NumPyClient):</span>
<span id="cb22-31"><a href="#cb22-31" aria-hidden="true" tabindex="-1"></a>            <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, model, trainloader, valloader):</span>
<span id="cb22-32"><a href="#cb22-32" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.model <span class="op">=</span> model</span>
<span id="cb22-33"><a href="#cb22-33" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.trainloader <span class="op">=</span> trainloader</span>
<span id="cb22-34"><a href="#cb22-34" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.valloader <span class="op">=</span> valloader</span>
<span id="cb22-35"><a href="#cb22-35" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb22-36"><a href="#cb22-36" aria-hidden="true" tabindex="-1"></a>            <span class="kw">def</span> get_parameters(<span class="va">self</span>, config):</span>
<span id="cb22-37"><a href="#cb22-37" aria-hidden="true" tabindex="-1"></a>                <span class="cf">return</span> [val.cpu().numpy() <span class="cf">for</span> _, val <span class="kw">in</span> <span class="va">self</span>.model.state_dict().items()]</span>
<span id="cb22-38"><a href="#cb22-38" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb22-39"><a href="#cb22-39" aria-hidden="true" tabindex="-1"></a>            <span class="kw">def</span> set_parameters(<span class="va">self</span>, parameters):</span>
<span id="cb22-40"><a href="#cb22-40" aria-hidden="true" tabindex="-1"></a>                params_dict <span class="op">=</span> <span class="bu">zip</span>(<span class="va">self</span>.model.state_dict().keys(), parameters)</span>
<span id="cb22-41"><a href="#cb22-41" aria-hidden="true" tabindex="-1"></a>                state_dict <span class="op">=</span> {k: torch.tensor(v) <span class="cf">for</span> k, v <span class="kw">in</span> params_dict}</span>
<span id="cb22-42"><a href="#cb22-42" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.model.load_state_dict(state_dict, strict<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb22-43"><a href="#cb22-43" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb22-44"><a href="#cb22-44" aria-hidden="true" tabindex="-1"></a>            <span class="kw">def</span> fit(<span class="va">self</span>, parameters, config):</span>
<span id="cb22-45"><a href="#cb22-45" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.set_parameters(parameters)</span>
<span id="cb22-46"><a href="#cb22-46" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Train locally; self.train() and self.test() are helpers you implement</span></span>
<span id="cb22-47"><a href="#cb22-47" aria-hidden="true" tabindex="-1"></a>                train_loss <span class="op">=</span> <span class="va">self</span>.train()</span>
<span id="cb22-48"><a href="#cb22-48" aria-hidden="true" tabindex="-1"></a>                <span class="cf">return</span> <span class="va">self</span>.get_parameters(config<span class="op">=</span>{}), <span class="bu">len</span>(<span class="va">self</span>.trainloader.dataset), {}</span>
<span id="cb22-49"><a href="#cb22-49" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb22-50"><a href="#cb22-50" aria-hidden="true" tabindex="-1"></a>            <span class="kw">def</span> evaluate(<span class="va">self</span>, parameters, config):</span>
<span id="cb22-51"><a href="#cb22-51" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.set_parameters(parameters)</span>
<span id="cb22-52"><a href="#cb22-52" aria-hidden="true" tabindex="-1"></a>                loss, accuracy <span class="op">=</span> <span class="va">self</span>.test()</span>
<span id="cb22-53"><a href="#cb22-53" aria-hidden="true" tabindex="-1"></a>                <span class="cf">return</span> loss, <span class="bu">len</span>(<span class="va">self</span>.valloader.dataset), {<span class="st">"accuracy"</span>: accuracy}</span>
<span id="cb22-54"><a href="#cb22-54" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb22-55"><a href="#cb22-55" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> VLMClient</span>
<span id="cb22-56"><a href="#cb22-56" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb22-57"><a href="#cb22-57" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> homomorphic_encryption_inference(<span class="va">self</span>, encrypted_input):</span>
<span id="cb22-58"><a href="#cb22-58" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Perform inference on encrypted data"""</span></span>
<span id="cb22-59"><a href="#cb22-59" aria-hidden="true" tabindex="-1"></a>        <span class="co"># This would require specialized libraries such as Microsoft SEAL or HElib</span></span>
<span id="cb22-60"><a href="#cb22-60" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Implementation would depend on the specific homomorphic encryption scheme</span></span>
<span id="cb22-61"><a href="#cb22-61" aria-hidden="true" tabindex="-1"></a>        <span class="cf">raise</span> <span class="pp">NotImplementedError</span></span>
<span id="cb22-62"><a href="#cb22-62" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb22-63"><a href="#cb22-63" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> secure_multiparty_computation(<span class="va">self</span>, distributed_inputs):</span>
<span id="cb22-64"><a href="#cb22-64" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Compute on distributed private inputs"""</span></span>
<span id="cb22-65"><a href="#cb22-65" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Implementation would use SMPC frameworks</span></span>
<span id="cb22-66"><a href="#cb22-66" aria-hidden="true" tabindex="-1"></a>        <span class="cf">raise</span> <span class="pp">NotImplementedError</span></span></code></pre></div></div>
</div>
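<p>In federated learning, each client's <code>fit</code> returns its updated parameters together with the number of local training examples so that the server can weight clients when aggregating. The sketch below shows a FedAvg-style weighted average in plain NumPy, assuming every client returns its parameters as a list of arrays in the same layer order; in Flower this aggregation is normally handled by a server-side strategy rather than written by hand.</p>

```python
import numpy as np


def federated_average(client_updates):
    """Weighted average of client parameters, as in FedAvg.

    client_updates: list of (parameters, num_examples) pairs, where parameters
    is a list of NumPy arrays (one per layer), mirroring the first two values
    a NumPyClient.fit call returns.
    """
    total_examples = sum(num_examples for _, num_examples in client_updates)
    num_layers = len(client_updates[0][0])
    averaged = []
    for layer_idx in range(num_layers):
        # Each client's layer is weighted by its share of the training examples
        weighted_sum = sum(
            params[layer_idx] * num_examples for params, num_examples in client_updates
        )
        averaged.append(weighted_sum / total_examples)
    return averaged


# Two hypothetical clients with different amounts of local data
client_a = ([np.array([1.0, 1.0]), np.array([0.0])], 100)
client_b = ([np.array([3.0, 5.0]), np.array([4.0])], 300)
global_params = federated_average([client_a, client_b])
print(global_params)
```

Client B holds three times as much data, so the averaged parameters land three quarters of the way toward its values.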
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>Fine-tuning Vision-Language Models represents a powerful approach to creating specialized AI systems that can understand and generate content across visual and textual modalities. Success in this domain requires careful consideration of architectural choices, data preparation strategies, training methodologies, and evaluation protocols.</p>
<p>The field continues to evolve rapidly, with new techniques for parameter-efficient training, improved architectures, and novel applications emerging regularly. By following the principles and practices outlined in this guide, researchers and practitioners can effectively leverage the power of VLMs for their specific use cases while contributing to the advancement of multimodal AI.</p>
<p>As we move forward, the integration of vision and language understanding will become increasingly sophisticated, opening new possibilities for human-AI interaction and automated reasoning across diverse domains. The techniques and insights presented here provide a foundation for navigating this exciting and rapidly evolving landscape.</p>
<p>Key takeaways from this comprehensive guide include:</p>
<ol type="1">
<li><strong>Choose the right fine-tuning approach</strong> based on your computational resources and task requirements</li>
<li><strong>Invest in high-quality data preparation</strong>: it is often more impactful than changes to the model architecture</li>
<li><strong>Use parameter-efficient methods</strong> like LoRA when full fine-tuning is not feasible</li>
<li><strong>Implement comprehensive evaluation frameworks</strong> that go beyond standard metrics</li>
<li><strong>Consider ethical implications</strong> and implement bias mitigation strategies</li>
<li><strong>Stay updated with emerging techniques</strong> in this rapidly evolving field</li>
</ol>
<p>The future of VLMs holds tremendous promise for advancing AI capabilities across numerous domains, from healthcare and education to creative applications and scientific discovery. By mastering the techniques presented in this guide, you’ll be well-equipped to contribute to this exciting frontier of artificial intelligence.</p>
</section>
<section id="references-and-further-reading" class="level2">
<h2 class="anchored" data-anchor-id="references-and-further-reading" id="references-and-further-reading">References and Further Reading</h2>
<p>For the most current research and developments in VLM fine-tuning, consider exploring:</p>
<ul>
<li>Recent papers on parameter-efficient fine-tuning methods</li>
<li>Benchmark datasets and evaluation frameworks</li>
<li>Open-source implementations and model repositories</li>
<li>Community forums and discussion groups</li>
<li>Academic conferences (NeurIPS, ICML, ICLR, CVPR, ACL)</li>
</ul>
<hr>
<p><em>This guide provides a comprehensive overview of VLM fine-tuning as of early 2025. Given the rapid pace of development in this field, readers are encouraged to stay updated with the latest research and best practices through academic publications and community resources.</em></p>


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[LoRA for Vision-Language Models: A Comprehensive Guide]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/models/vision-language-models/vision-language-lora/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/models/vision-language-models/vision-language-lora/</guid>
      <pubDate>Sat, 02 Aug 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>tutorial</category>
      <category>intermediate</category>
      <content:encoded><![CDATA[






<section id="lora-for-vision-language-models-a-comprehensive-guide" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/models/vision-language-models/vision-language-lora/lora.png" class="img-fluid"></p>
<section id="abstract" class="level2">
<h2 class="anchored" data-anchor-id="abstract" id="abstract">Abstract</h2>
<p>Low-Rank Adaptation (LoRA) has become a widely used technique for parameter-efficient fine-tuning of large language models, and it carries over naturally to Vision-Language Models (VLMs). This guide covers the theoretical foundations, practical implementation strategies, and production deployment techniques for LoRA in VLMs, from basic concepts to advanced optimization methods.</p>
</section>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>Vision-Language Models such as CLIP, BLIP, LLaVA, and GPT-4V range from hundreds of millions to billions of parameters, which makes full fine-tuning computationally expensive and memory-intensive. LoRA addresses these challenges by:</p>
<ul>
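<li><strong>Reducing memory requirements</strong>, since gradients and optimizer states are kept only for the small adapter matrices (the original LoRA paper reports about 3x lower GPU memory than full fine-tuning of GPT-3 175B)</li>
<li><strong>Speeding up training</strong>, because far fewer parameters need gradient updates; the actual gain depends on model size and hardware</li>
<li><strong>Maintaining model performance</strong>, often comparable to full fine-tuning, with minimal parameter overhead</li>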
<li><strong>Enabling modular adaptation</strong> for different tasks and domains</li>
</ul>
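<p>A back-of-the-envelope estimate shows where the memory savings come from. Assuming fp32 weights and the Adam optimizer, each trainable parameter needs storage for its weight, its gradient, and two moment estimates (about 16 bytes), whereas a frozen parameter needs only its weight (4 bytes). The sketch below ignores activations, mixed precision, and quantization, all of which shift the numbers considerably; the 7B-parameter model and 20M adapter parameters are hypothetical.</p>

```python
def training_memory_gb(total_params, trainable_params, bytes_per_weight=4):
    """Rough fp32 + Adam memory estimate covering weights, gradients, and optimizer state.

    Assumptions (illustrative only): every parameter is stored once at
    bytes_per_weight; each trainable parameter also needs a gradient and two
    Adam moment buffers of the same size. Activations are ignored.
    """
    weights = total_params * bytes_per_weight
    grads_and_adam = trainable_params * bytes_per_weight * 3
    return (weights + grads_and_adam) / 1e9


# Hypothetical 7B-parameter VLM, with and without a 20M-parameter LoRA adapter
full = training_memory_gb(total_params=7e9, trainable_params=7e9)
lora = training_memory_gb(total_params=7e9 + 20e6, trainable_params=20e6)
print(f"full: {full:.1f} GB, LoRA: {lora:.1f} GB")
```

Under these assumptions the estimate falls from 112 GB to about 28 GB, and what remains is dominated by the frozen weights.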
<section id="why-lora-for-vlms" class="level3">
<h3 class="anchored" data-anchor-id="why-lora-for-vlms" id="why-lora-for-vlms">Why LoRA for VLMs?</h3>
<div id="cell-fig-lora-benefits" class="cell" data-execution_count="1">
<div class="cell-output cell-output-display">
<div id="fig-lora-benefits" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-fig figure">
<div aria-describedby="fig-lora-benefits-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<img src="https://theja-vanka.github.io/blogs/posts/models/vision-language-models/vision-language-lora/fig-lora-benefits-output-1.png" width="950" height="566" class="figure-img">
</div>
<figcaption class="quarto-float-caption-bottom quarto-float-caption quarto-float-fig" id="fig-lora-benefits-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Figure&nbsp;1: LoRA Benefits Comparison
</figcaption>
</figure>
</div>
</div>
</div>
</section>
</section>
<section id="understanding-lora" class="level2">
<h2 class="anchored" data-anchor-id="understanding-lora" id="understanding-lora">Understanding LoRA</h2>
<section id="core-principles" class="level3">
<h3 class="anchored" data-anchor-id="core-principles" id="core-principles">Core Principles</h3>
<p>LoRA rests on the hypothesis that the weight updates learned during fine-tuning have a low intrinsic rank. Instead of updating every parameter, LoRA freezes the pretrained weights and represents the update as the product of two much smaller matrices:</p>
<p><span class="math display">\[\Delta W = BA\]</span></p>
<p>Where:</p>
<ul>
<li><span class="math inline">\(W_0\)</span> is the frozen pretrained weight matrix (<span class="math inline">\(d \times k\)</span>)</li>
<li><span class="math inline">\(B\)</span> is a learnable matrix (<span class="math inline">\(d \times r\)</span>)</li>
<li><span class="math inline">\(A\)</span> is a learnable matrix (<span class="math inline">\(r \times k\)</span>)</li>
<li><span class="math inline">\(r\)</span> is the rank, with <span class="math inline">\(r \ll \min(d, k)\)</span></li>
</ul>
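<p>A quick numerical check shows why the factorization is cheap. The sketch below builds random factors <span class="math inline">\(B\)</span> (768 × <span class="math inline">\(r\)</span>) and <span class="math inline">\(A\)</span> (<span class="math inline">\(r\)</span> × 768) for an assumed 768 × 768 layer with <span class="math inline">\(r = 8\)</span>, confirms that their product never exceeds rank <span class="math inline">\(r\)</span>, and compares how many values each representation stores. The sizes are illustrative assumptions, not taken from a specific model.</p>

```python
import torch

torch.manual_seed(0)

d_out, d_in, r = 768, 768, 8     # illustrative layer size and LoRA rank

B = torch.randn(d_out, r)        # d_out x r
A = torch.randn(r, d_in)         # r x d_in
delta_W = B @ A                  # d_out x d_in update built from the two factors

rank = torch.linalg.matrix_rank(delta_W).item()
dense_params = d_out * d_in      # entries in a full update matrix
lora_params = d_out * r + r * d_in  # entries in B plus entries in A

print(f"rank(BA) = {rank}")
print(f"dense update: {dense_params:,} values, LoRA factors: {lora_params:,} values")
```

<p>The update is still a full 768 × 768 matrix when applied, but it is parameterized by roughly 2% of the values, and its rank is capped at <span class="math inline">\(r\)</span> by construction.</p>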
</section>
<section id="mathematical-foundation" class="level3">
<h3 class="anchored" data-anchor-id="mathematical-foundation" id="mathematical-foundation">Mathematical Foundation</h3>
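<p>For a linear layer with weight matrix <span class="math inline">\(W_0\)</span>, the forward pass becomes:</p>
<p><span class="math display">\[h = W_0x + \frac{\alpha}{r}\Delta Wx = W_0x + \frac{\alpha}{r}BAx\]</span></p>
<p>The adapted weight matrix is: <span class="math display">\[W = W_0 + \frac{\alpha}{r} BA\]</span></p>
<p>Here <span class="math inline">\(\alpha\)</span> is a constant hyperparameter, and the update is scaled by <span class="math inline">\(\alpha / r\)</span> (the <code>scaling</code> attribute in the code below) so that changing the rank requires less retuning of other hyperparameters. Because <span class="math inline">\(B\)</span> is initialized to zero, <span class="math inline">\(\Delta W = 0\)</span> at the start of training and the adapted model begins exactly at the pretrained one.</p>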
<div id="lora-implementation" class="cell" data-execution_count="2">
<details class="code-fold">
<summary>Code</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb1"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn.functional <span class="im">as</span> F</span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> math</span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> LoRALayer(nn.Module):</span>
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, in_features, out_features, rank<span class="op">=</span><span class="dv">16</span>, alpha<span class="op">=</span><span class="dv">16</span>, dropout<span class="op">=</span><span class="fl">0.1</span>):</span>
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb1-9"><a href="#cb1-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.rank <span class="op">=</span> rank</span>
<span id="cb1-10"><a href="#cb1-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.alpha <span class="op">=</span> alpha</span>
<span id="cb1-11"><a href="#cb1-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.scaling <span class="op">=</span> alpha <span class="op">/</span> rank</span>
<span id="cb1-12"><a href="#cb1-12" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-13"><a href="#cb1-13" aria-hidden="true" tabindex="-1"></a>        <span class="co"># LoRA matrices</span></span>
<span id="cb1-14"><a href="#cb1-14" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.lora_A <span class="op">=</span> nn.Linear(in_features, rank, bias<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb1-15"><a href="#cb1-15" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.lora_B <span class="op">=</span> nn.Linear(rank, out_features, bias<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb1-16"><a href="#cb1-16" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.dropout <span class="op">=</span> nn.Dropout(dropout)</span>
<span id="cb1-17"><a href="#cb1-17" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-18"><a href="#cb1-18" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Initialize weights</span></span>
<span id="cb1-19"><a href="#cb1-19" aria-hidden="true" tabindex="-1"></a>        nn.init.kaiming_uniform_(<span class="va">self</span>.lora_A.weight, a<span class="op">=</span>math.sqrt(<span class="dv">5</span>))</span>
<span id="cb1-20"><a href="#cb1-20" aria-hidden="true" tabindex="-1"></a>        nn.init.zeros_(<span class="va">self</span>.lora_B.weight)</span>
<span id="cb1-21"><a href="#cb1-21" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb1-22"><a href="#cb1-22" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb1-23"><a href="#cb1-23" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Dropout is applied to the input, as in the reference LoRA implementation</span></span>
<span id="cb1-24"><a href="#cb1-24" aria-hidden="true" tabindex="-1"></a>        result <span class="op">=</span> <span class="va">self</span>.lora_A(<span class="va">self</span>.dropout(x))</span>
<span id="cb1-25"><a href="#cb1-25" aria-hidden="true" tabindex="-1"></a>        result <span class="op">=</span> <span class="va">self</span>.lora_B(result)</span>
<span id="cb1-26"><a href="#cb1-26" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> result <span class="op">*</span> <span class="va">self</span>.scaling</span>
<span id="cb1-27"><a href="#cb1-27" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-28"><a href="#cb1-28" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> LoRALinear(nn.Module):</span>
<span id="cb1-29"><a href="#cb1-29" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, original_layer, rank<span class="op">=</span><span class="dv">16</span>, alpha<span class="op">=</span><span class="dv">16</span>, dropout<span class="op">=</span><span class="fl">0.1</span>):</span>
<span id="cb1-30"><a href="#cb1-30" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb1-31"><a href="#cb1-31" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.original_layer <span class="op">=</span> original_layer</span>
<span id="cb1-32"><a href="#cb1-32" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.lora <span class="op">=</span> LoRALayer(</span>
<span id="cb1-33"><a href="#cb1-33" aria-hidden="true" tabindex="-1"></a>            original_layer.in_features,</span>
<span id="cb1-34"><a href="#cb1-34" aria-hidden="true" tabindex="-1"></a>            original_layer.out_features,</span>
<span id="cb1-35"><a href="#cb1-35" aria-hidden="true" tabindex="-1"></a>            rank, alpha, dropout</span>
<span id="cb1-36"><a href="#cb1-36" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb1-37"><a href="#cb1-37" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-38"><a href="#cb1-38" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Freeze original weights</span></span>
<span id="cb1-39"><a href="#cb1-39" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> param <span class="kw">in</span> <span class="va">self</span>.original_layer.parameters():</span>
<span id="cb1-40"><a href="#cb1-40" aria-hidden="true" tabindex="-1"></a>            param.requires_grad <span class="op">=</span> <span class="va">False</span></span>
<span id="cb1-41"><a href="#cb1-41" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb1-42"><a href="#cb1-42" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb1-43"><a href="#cb1-43" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.original_layer(x) <span class="op">+</span> <span class="va">self</span>.lora(x)</span>
<span id="cb1-44"><a href="#cb1-44" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-45"><a href="#cb1-45" aria-hidden="true" tabindex="-1"></a><span class="co"># Example usage</span></span>
<span id="cb1-46"><a href="#cb1-46" aria-hidden="true" tabindex="-1"></a>original_linear <span class="op">=</span> nn.Linear(<span class="dv">768</span>, <span class="dv">768</span>)</span>
<span id="cb1-47"><a href="#cb1-47" aria-hidden="true" tabindex="-1"></a>lora_linear <span class="op">=</span> LoRALinear(original_linear, rank<span class="op">=</span><span class="dv">16</span>, alpha<span class="op">=</span><span class="dv">16</span>)</span>
<span id="cb1-48"><a href="#cb1-48" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-49"><a href="#cb1-49" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Original parameters: </span><span class="sc">{</span><span class="bu">sum</span>(p.numel() <span class="cf">for</span> p <span class="kw">in</span> original_linear.parameters())<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb1-50"><a href="#cb1-50" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"LoRA parameters: </span><span class="sc">{</span><span class="bu">sum</span>(p.numel() <span class="cf">for</span> p <span class="kw">in</span> lora_linear.lora.parameters())<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb1-51"><a href="#cb1-51" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Parameter reduction: </span><span class="sc">{</span>(<span class="dv">1</span> <span class="op">-</span> <span class="bu">sum</span>(p.numel() <span class="cf">for</span> p <span class="kw">in</span> lora_linear.lora.parameters()) <span class="op">/</span> <span class="bu">sum</span>(p.numel() <span class="cf">for</span> p <span class="kw">in</span> original_linear.parameters())) <span class="op">*</span> <span class="dv">100</span><span class="sc">:.1f}</span><span class="ss">%"</span>)</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-stdout">
<pre><code>Original parameters: 590592
LoRA parameters: 24576
Parameter reduction: 95.8%</code></pre>
</div>
</div>
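<p>Two properties of this formulation are easy to verify numerically. Because <span class="math inline">\(B\)</span> starts at zero, the wrapped layer initially reproduces the frozen layer exactly. Because the update is linear, the scaled product <span class="math inline">\((\alpha/r)BA\)</span>, which is the scaling the code above applies, can be folded into <span class="math inline">\(W_0\)</span> after training so that inference needs no extra matrix multiplications. The sketch below is self-contained (it does not reuse the classes above) and uses assumed layer sizes.</p>

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

in_f, out_f, r, alpha = 64, 32, 4, 8
scaling = alpha / r

base = nn.Linear(in_f, out_f)              # frozen W_0 (and bias)
A = nn.Linear(in_f, r, bias=False)         # A: r x in_f
B = nn.Linear(r, out_f, bias=False)        # B: out_f x r
nn.init.zeros_(B.weight)                   # standard LoRA init: B = 0

x = torch.randn(10, in_f)

with torch.no_grad():
    # At initialization BA = 0, so the adapted layer equals the base layer
    init_gap = (base(x) + scaling * B(A(x)) - base(x)).abs().max().item()

    # Pretend training has moved B away from zero
    B.weight.normal_()
    unmerged = base(x) + scaling * B(A(x))

    # Merge: W = W_0 + (alpha / r) * B @ A, stored in a plain nn.Linear
    merged = nn.Linear(in_f, out_f)
    merged.weight.copy_(base.weight + scaling * B.weight @ A.weight)
    merged.bias.copy_(base.bias)
    merge_gap = (merged(x) - unmerged).abs().max().item()

print(f"max |difference| at init: {init_gap:.2e}")
print(f"max |difference| after merge: {merge_gap:.2e}")
```

<p>The merged layer differs from the unmerged one only by floating-point rounding. Keeping the adapter unmerged is what allows swapping adapters on a shared base model; merging trades that flexibility for zero inference overhead.</p>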
</section>
<section id="key-advantages" class="level3">
<h3 class="anchored" data-anchor-id="key-advantages" id="key-advantages">Key Advantages</h3>
<ol type="1">
<li><strong>Parameter Efficiency</strong>: Only the low-rank factors are trained, typically well under 1% of a large model's parameters</li>
<li><strong>Memory Efficiency</strong>: Gradients and optimizer states are stored only for the adapter weights, reducing GPU memory</li>
<li><strong>Modularity</strong>: Adapters are small, so one can be stored per task and swapped on a shared base model</li>
<li><strong>Preservation</strong>: The pretrained weights remain unchanged</li>
<li><strong>Composability</strong>: Multiple LoRAs can be combined</li>
</ol>
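<p>To make the parameter-efficiency claim concrete, here is back-of-the-envelope arithmetic for an assumed decoder-style language model with 32 layers, hidden size 4096, and about 7 billion parameters, with rank-16 LoRA applied to the query and value projections only. All sizes are illustrative assumptions, not measurements of a specific VLM.</p>

```python
def lora_param_count(num_layers, d_in, d_out, rank, matrices_per_layer):
    """Trainable parameters added by LoRA: each adapted d_out x d_in weight
    gets B (d_out x rank) and A (rank x d_in)."""
    per_matrix = rank * (d_in + d_out)
    return num_layers * matrices_per_layer * per_matrix


total_params = 7_000_000_000  # assumed size of the base model
trainable = lora_param_count(
    num_layers=32, d_in=4096, d_out=4096, rank=16,
    matrices_per_layer=2,     # q_proj and v_proj
)
fraction = trainable / total_params

print(f"LoRA trainable parameters: {trainable:,}")
print(f"Fraction of base model: {fraction:.3%}")
```

<p>Targeting more projections (for example the MLP layers) or raising the rank increases the count linearly. Optimizer state, such as Adam's two moment buffers per parameter, scales with this trainable count rather than with the full model, which is where much of the memory saving comes from.</p>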
</section>
</section>
<section id="vision-language-models-overview" class="level2">
<h2 class="anchored" data-anchor-id="vision-language-models-overview" id="vision-language-models-overview">Vision-Language Models Overview</h2>
<section id="architecture-components" class="level3">
<h3 class="anchored" data-anchor-id="architecture-components" id="architecture-components">Architecture Components</h3>
<p>Modern VLMs typically consist of:</p>
<ol type="1">
<li><strong>Vision Encoder</strong>: Processes visual inputs (e.g., Vision Transformer, ResNet)</li>
<li><strong>Text Encoder</strong>: Processes textual inputs (e.g., a BERT-style encoder, or a GPT-style decoder-only language model in generative VLMs)</li>
<li><strong>Multimodal Fusion</strong>: Combines visual and textual representations</li>
<li><strong>Output Head</strong>: Task-specific prediction layers</li>
</ol>
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">flowchart TD
    A[Image Input] --&gt; B[Vision&lt;br/&gt;Encoder]
    C[Text Input] --&gt; D[Text&lt;br/&gt;Encoder]
    B --&gt; E[Multimodal&lt;br/&gt;Fusion]
    D --&gt; E
    E --&gt; F[Output&lt;br/&gt;Head]
    F --&gt; G[Predictions]
    
    classDef input fill:#add8e6,stroke:#000,stroke-width:2px
    classDef encoder fill:#90ee90,stroke:#000,stroke-width:2px
    classDef fusion fill:#ffffe0,stroke:#000,stroke-width:2px
    classDef output fill:#f08080,stroke:#000,stroke-width:2px
    classDef prediction fill:#d3d3d3,stroke:#000,stroke-width:2px
    
    class A,C input
    class B,D encoder
    class E fusion
    class F output
    class G prediction
</pre>
</div>
<p></p></figure><p></p>
</div>
</div>
</div>
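<p>The four components map naturally onto four submodules. The toy model below is a hypothetical, heavily simplified dual-encoder with concatenation-based fusion; real VLMs use transformer encoders and richer fusion (cross-attention or token concatenation), but the module boundaries, and therefore the places where LoRA can attach, are the same.</p>

```python
import torch
import torch.nn as nn


class ToyVLM(nn.Module):
    """Hypothetical minimal VLM: vision encoder, text encoder, fusion, output head."""

    def __init__(self, patch_dim=48, vocab_size=1000, hidden=64, num_classes=10):
        super().__init__()
        self.vision_encoder = nn.Sequential(nn.Linear(patch_dim, hidden), nn.GELU())
        self.token_embedding = nn.Embedding(vocab_size, hidden)
        self.text_encoder = nn.Sequential(nn.Linear(hidden, hidden), nn.GELU())
        self.fusion = nn.Linear(2 * hidden, hidden)
        self.head = nn.Linear(hidden, num_classes)

    def forward(self, patches, token_ids):
        # patches: (batch, num_patches, patch_dim); token_ids: (batch, seq_len)
        img = self.vision_encoder(patches).mean(dim=1)                        # pool over patches
        txt = self.text_encoder(self.token_embedding(token_ids)).mean(dim=1)  # pool over tokens
        fused = torch.relu(self.fusion(torch.cat([img, txt], dim=-1)))        # multimodal fusion
        return self.head(fused)                                               # task predictions


model = ToyVLM()
logits = model(torch.randn(2, 16, 48), torch.randint(0, 1000, (2, 12)))
print(logits.shape)

# The nn.Linear modules are the candidates for LoRA adaptation
linear_names = [n for n, m in model.named_modules() if isinstance(m, nn.Linear)]
print(linear_names)
```

<p>Listing the linear layers by name like this is a practical first step when deciding which parts of a real VLM to adapt.</p>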
</section>
<section id="popular-vlm-architectures" class="level3">
<h3 class="anchored" data-anchor-id="popular-vlm-architectures" id="popular-vlm-architectures">Popular VLM Architectures</h3>
<section id="clip-contrastive-language-image-pre-training" class="level4">
<h4 class="anchored" data-anchor-id="clip-contrastive-language-image-pre-training">CLIP (Contrastive Language-Image Pre-training)</h4>
<ul>
<li>Dual-encoder architecture</li>
<li>Contrastive learning objective</li>
<li>Strong zero-shot capabilities</li>
</ul>
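<p>The contrastive objective fits in a few lines. The sketch below follows the symmetric cross-entropy formulation described for CLIP: normalize image and text embeddings, compute a temperature-scaled similarity matrix, and treat the matching pairs on the diagonal as the correct classes in both directions. The fixed temperature of 0.07 is an assumption here; CLIP itself learns this value during training.</p>

```python
import torch
import torch.nn.functional as F


def clip_contrastive_loss(image_emb, text_emb, temperature=0.07):
    """Symmetric InfoNCE loss over a batch of matching image/text pairs.

    image_emb, text_emb: (batch, dim); row i of each tensor is a matching pair.
    """
    image_emb = F.normalize(image_emb, dim=-1)
    text_emb = F.normalize(text_emb, dim=-1)
    logits = image_emb @ text_emb.t() / temperature               # (batch, batch) scaled cosine similarities
    targets = torch.arange(logits.size(0), device=logits.device)  # pair i matches pair i
    loss_i2t = F.cross_entropy(logits, targets)                   # image -> text direction
    loss_t2i = F.cross_entropy(logits.t(), targets)               # text -> image direction
    return (loss_i2t + loss_t2i) / 2


torch.manual_seed(0)
img = torch.randn(8, 32)
aligned = clip_contrastive_loss(img, img + 0.01 * torch.randn(8, 32))
unaligned = clip_contrastive_loss(img, torch.randn(8, 32))
print(f"aligned pairs: {aligned.item():.4f}, unaligned pairs: {unaligned.item():.4f}")
```

<p>Nearly identical pairs give a loss close to zero, while unrelated embeddings give a much larger loss, which is the signal that pulls matching images and captions together.</p>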
</section>
<section id="blip-bootstrapping-language-image-pre-training" class="level4">
<h4 class="anchored" data-anchor-id="blip-bootstrapping-language-image-pre-training">BLIP (Bootstrapping Language-Image Pre-training)</h4>
<ul>
<li>Multimodal mixture of encoder-decoder (MED) architecture</li>
<li>Unified vision-language understanding and generation</li>
<li>Bootstraps cleaner training captions from noisy web data with a captioner and a filter (CapFilt)</li>
</ul>
</section>
<section id="llava-large-language-and-vision-assistant" class="level4">
<h4 class="anchored" data-anchor-id="llava-large-language-and-vision-assistant">LLaVA (Large Language and Vision Assistant)</h4>
<ul>
<li>Connects a pretrained CLIP vision encoder to a large language model (Vicuna in the original release) through a learned projection</li>
<li>Visual instruction tuning for conversational abilities</li>
<li>Strong multimodal reasoning</li>
</ul>
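<p>The LLaVA-style connection can be sketched as a projector that maps vision-encoder patch features into the language model's token-embedding space, with the projected image tokens placed in front of the text embeddings. The two-layer GELU MLP resembles the LLaVA-1.5 projector (the first LLaVA release used a single linear layer); the scaled-down dimensions and random tensors are stand-ins, not real encoder outputs.</p>

```python
import torch
import torch.nn as nn


class VisionProjector(nn.Module):
    """Maps vision-encoder patch features into an LLM's token-embedding space."""

    def __init__(self, vision_dim, llm_dim):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(vision_dim, llm_dim),
            nn.GELU(),
            nn.Linear(llm_dim, llm_dim),
        )

    def forward(self, patch_features):
        return self.proj(patch_features)


vision_dim, llm_dim = 64, 128                    # scaled-down stand-ins for real sizes
patch_features = torch.randn(1, 16, vision_dim)  # (batch, patches, vision_dim) from a frozen encoder
text_embeds = torch.randn(1, 10, llm_dim)        # (batch, tokens, llm_dim) from the LLM embedding table

projector = VisionProjector(vision_dim, llm_dim)
image_tokens = projector(patch_features)                    # (batch, patches, llm_dim)
llm_inputs = torch.cat([image_tokens, text_embeds], dim=1)  # image tokens first, then text
print(llm_inputs.shape)
```

<p>From the language model's point of view the image becomes a prefix of extra tokens, which is why LoRA on the LLM's attention projections can adapt how it attends to visual content.</p>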
</section>
</section>
</section>
<section id="lora-architecture-for-vlms" class="level2">
<h2 class="anchored" data-anchor-id="lora-architecture-for-vlms" id="lora-architecture-for-vlms">LoRA Architecture for VLMs</h2>
<section id="component-wise-application" class="level3">
<h3 class="anchored" data-anchor-id="component-wise-application" id="component-wise-application">Component-wise Application</h3>
<p>LoRA can be applied to different components of VLMs:</p>
<div id="vlm-lora-adapter" class="cell" data-execution_count="3">
<details class="code-fold">
<summary>Code</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> VLMLoRAAdapter:</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, model, config):</span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> model</span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.config <span class="op">=</span> config</span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.lora_layers <span class="op">=</span> {}  <span class="co"># "module_name.attr" -&gt; LoRALinear</span></span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _wrap(<span class="va">self</span>, module_name, parent, attr):</span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Replace parent.attr with a LoRALinear and record it under module_name"""</span></span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="bu">hasattr</span>(parent, attr):</span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a>            wrapped <span class="op">=</span> LoRALinear(</span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a>                <span class="bu">getattr</span>(parent, attr),</span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a>                rank<span class="op">=</span><span class="va">self</span>.config.rank,</span>
<span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a>                alpha<span class="op">=</span><span class="va">self</span>.config.alpha,</span>
<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a>                dropout<span class="op">=</span><span class="va">self</span>.config.dropout</span>
<span id="cb3-15"><a href="#cb3-15" aria-hidden="true" tabindex="-1"></a>            )</span>
<span id="cb3-16"><a href="#cb3-16" aria-hidden="true" tabindex="-1"></a>            <span class="bu">setattr</span>(parent, attr, wrapped)</span>
<span id="cb3-17"><a href="#cb3-17" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.lora_layers[<span class="ss">f"</span><span class="sc">{</span>module_name<span class="sc">}</span><span class="ss">.</span><span class="sc">{</span>attr<span class="sc">}</span><span class="ss">"</span>] <span class="op">=</span> wrapped</span>
<span id="cb3-18"><a href="#cb3-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-19"><a href="#cb3-19" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> add_lora_to_attention(<span class="va">self</span>, module_name, attention_layer):</span>
<span id="cb3-20"><a href="#cb3-20" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Add LoRA to the query, key, and value projections"""</span></span>
<span id="cb3-21"><a href="#cb3-21" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> attr <span class="kw">in</span> (<span class="st">'q_proj'</span>, <span class="st">'k_proj'</span>, <span class="st">'v_proj'</span>):</span>
<span id="cb3-22"><a href="#cb3-22" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>._wrap(module_name, attention_layer, attr)</span>
<span id="cb3-23"><a href="#cb3-23" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-24"><a href="#cb3-24" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> add_lora_to_mlp(<span class="va">self</span>, module_name, mlp_layer):</span>
<span id="cb3-25"><a href="#cb3-25" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Add LoRA to the feed-forward layers"""</span></span>
<span id="cb3-26"><a href="#cb3-26" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> attr <span class="kw">in</span> (<span class="st">'fc1'</span>, <span class="st">'fc2'</span>):</span>
<span id="cb3-27"><a href="#cb3-27" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>._wrap(module_name, mlp_layer, attr)</span></code></pre></div></div>
</details>
</div>
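<p>The adapter above assumes you already hold a reference to each attention or MLP block. An alternative is to walk <code>named_modules()</code> and wrap every <code>nn.Linear</code> whose attribute name appears in a target list; Hugging Face PEFT takes a similar name-matching approach through its <code>target_modules</code> setting. The self-contained sketch below uses its own minimal wrapper, named <code>SimpleLoRALinear</code> to mark it as hypothetical, and a hypothetical <code>Block</code> module standing in for a transformer layer. It then freezes everything except the LoRA factors.</p>

```python
import torch
import torch.nn as nn


class SimpleLoRALinear(nn.Module):
    """Hypothetical minimal LoRA wrapper: frozen base layer plus a scaled B @ A update."""

    def __init__(self, base, rank=8, alpha=16):
        super().__init__()
        self.base = base
        self.lora_A = nn.Linear(base.in_features, rank, bias=False)
        self.lora_B = nn.Linear(rank, base.out_features, bias=False)
        nn.init.zeros_(self.lora_B.weight)  # start as a no-op update
        self.scaling = alpha / rank

    def forward(self, x):
        return self.base(x) + self.scaling * self.lora_B(self.lora_A(x))


def inject_lora(model, target_names, rank=8, alpha=16):
    """Wrap every nn.Linear whose attribute name is in target_names; return wrapped paths."""
    wrapped = []
    # Materialize the module list first because the tree is mutated while iterating
    for parent_name, parent in list(model.named_modules()):
        for child_name, child in list(parent.named_children()):
            if child_name in target_names and isinstance(child, nn.Linear):
                setattr(parent, child_name, SimpleLoRALinear(child, rank, alpha))
                wrapped.append(f"{parent_name}.{child_name}" if parent_name else child_name)
    # Train only the LoRA factors; everything else stays frozen
    for name, param in model.named_parameters():
        param.requires_grad = "lora_" in name
    return wrapped


class Block(nn.Module):
    """Hypothetical stand-in for a transformer block's linear layers."""

    def __init__(self, d=32):
        super().__init__()
        self.q_proj, self.k_proj, self.v_proj = nn.Linear(d, d), nn.Linear(d, d), nn.Linear(d, d)
        self.fc1 = nn.Linear(d, 4 * d)


model = nn.Sequential(Block(), Block())
wrapped = inject_lora(model, {"q_proj", "v_proj"}, rank=4)
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
out = model[0].q_proj(torch.randn(3, 32))
print(wrapped, trainable, out.shape)
```

<p>Matching on the attribute name keeps the injection independent of the model class, at the cost of relying on consistent naming; checking the returned list is a cheap way to confirm that the intended layers were actually wrapped.</p>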
</section>
<section id="layer-selection-strategy" class="level3">
<h3 class="anchored" data-anchor-id="layer-selection-strategy" id="layer-selection-strategy">Layer Selection Strategy</h3>
<p>Layers do not benefit equally from LoRA adaptation. The priorities below are common heuristics rather than fixed rules:</p>
<table class="caption-top table">
<thead>
<tr class="header">
<th>Priority</th>
<th>Layer Type</th>
<th>Reason</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>High</td>
<td>Final attention layers</td>
<td>Most task-specific representations</td>
</tr>
<tr class="even">
<td>High</td>
<td>Cross-modal attention</td>
<td>Critical for multimodal fusion</td>
</tr>
<tr class="odd">
<td>High</td>
<td>Task-specific output heads</td>
<td>Direct impact on outputs</td>
</tr>
<tr class="even">
<td>Medium</td>
<td>Middle transformer layers</td>
<td>Balanced feature extraction</td>
</tr>
<tr class="odd">
<td>Medium</td>
<td>Feed-forward networks</td>
<td>Non-linear transformations</td>
</tr>
<tr class="even">
<td>Low</td>
<td>Early encoder layers</td>
<td>Generic low-level features</td>
</tr>
<tr class="odd">
<td>Low</td>
<td>Embedding layers</td>
<td>Fixed vocabulary representations</td>
</tr>
</tbody>
</table>
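<p>One way to act on the table is to restrict LoRA to the attention projections of the last few transformer blocks. The sketch below filters module names with a regular expression; the naming pattern <code>encoder.layers.{i}.self_attn.q_proj</code> is an assumption, and real models name their submodules differently, so the pattern must be adjusted per architecture.</p>

```python
import re


def select_last_k_layers(module_names, num_layers, last_k, targets=("q_proj", "k_proj", "v_proj")):
    """Return names of target projections that live in the last `last_k` of `num_layers` blocks.

    Assumes names contain a `.layers.{index}.` segment, as in many transformer implementations.
    """
    pattern = re.compile(r"\.layers\.(\d+)\.")
    first_kept = num_layers - last_k
    selected = []
    for name in module_names:
        match = pattern.search(name)
        # Skip modules outside a numbered block or not in the target list
        if match is None or name.rsplit(".", 1)[-1] not in targets:
            continue
        if int(match.group(1)) >= first_kept:
            selected.append(name)
    return selected


# Hypothetical module names for a 12-block encoder
names = [f"encoder.layers.{i}.self_attn.{p}"
         for i in range(12) for p in ("q_proj", "k_proj", "v_proj", "out_proj")]
names += [f"encoder.layers.{i}.mlp.fc1" for i in range(12)]

chosen = select_last_k_layers(names, num_layers=12, last_k=4)
print(len(chosen), chosen[:3])
```

<p>The resulting list can be passed as explicit target names, which keeps early, generic layers frozen while the later, more task-specific layers are adapted.</p>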
</section>
<section id="rank-selection-guidelines" class="level3">
<h3 class="anchored" data-anchor-id="rank-selection-guidelines" id="rank-selection-guidelines">Rank Selection Guidelines</h3>
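<p>The rank <span class="math inline">\(r\)</span> controls the trade-off between adapter capacity and efficiency, since the number of trainable parameters grows linearly with <span class="math inline">\(r\)</span>:</p>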
<div id="cell-fig-rank-comparison" class="cell" data-execution_count="4">
<div class="cell-output cell-output-display">
<div id="fig-rank-comparison" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-fig figure">
<div aria-describedby="fig-rank-comparison-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<img src="https://theja-vanka.github.io/blogs/posts/models/vision-language-models/vision-language-lora/fig-rank-comparison-output-1.png" width="1430" height="564" class="figure-img">
</div>
<figcaption class="quarto-float-caption-bottom quarto-float-caption quarto-float-fig" id="fig-rank-comparison-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Figure&nbsp;2: LoRA Rank vs Performance Trade-off
</figcaption>
</figure>
</div>
</div>
</div>
<p><strong>Rank Selection Guidelines:</strong></p>
<ul>
<li><strong>r = 1-4</strong>: Minimal parameters, suitable for simple adaptations</li>
<li><strong>r = 8-16</strong>: Balanced efficiency and performance for most tasks</li>
<li><strong>r = 32-64</strong>: Higher capacity for complex domain adaptations</li>
<li><strong>r = 128+</strong>: Approaching full fine-tuning, rarely needed</li>
</ul>
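<p>Because each adapted <span class="math inline">\(d_{out} \times d_{in}\)</span> weight gains <span class="math inline">\(r(d_{in} + d_{out})\)</span> trainable values, the cost of each rank bucket is easy to tabulate. The sketch below does this for a single assumed 768 × 768 projection, the size used in the implementation example earlier.</p>

```python
d_in = d_out = 768                 # assumed projection size
dense = d_in * d_out               # weights in the full matrix (bias excluded)
ranks = (1, 4, 8, 16, 32, 64, 128)
counts = {r: r * (d_in + d_out) for r in ranks}  # values in B plus A

for r, n in counts.items():
    print(f"r={r:>3}: {n:>7,} trainable values, {n / dense:6.2%} of the dense weight")
```

<p>At <span class="math inline">\(r = 384\)</span> (half of 768) the two factors would hold as many values as the dense matrix itself, which is why very high ranks erase most of LoRA's savings for layers of this size.</p>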
</section>
</section>
<section id="configuration-management" class="level2">
<h2 class="anchored" data-anchor-id="configuration-management" id="configuration-management">Configuration Management</h2>
<div id="lora-config" class="cell" data-execution_count="5">
<details class="code-fold">
<summary>Code</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> dataclasses <span class="im">import</span> dataclass</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> typing <span class="im">import</span> List, Optional</span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a><span class="at">@dataclass</span></span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> LoRAConfig:</span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Basic LoRA parameters</span></span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a>    rank: <span class="bu">int</span> <span class="op">=</span> <span class="dv">16</span></span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a>    alpha: <span class="bu">int</span> <span class="op">=</span> <span class="dv">16</span></span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a>    dropout: <span class="bu">float</span> <span class="op">=</span> <span class="fl">0.1</span></span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Target modules</span></span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a>    target_modules: Optional[List[<span class="bu">str</span>]] <span class="op">=</span> <span class="va">None</span></span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a>    vision_target_modules: Optional[List[<span class="bu">str</span>]] <span class="op">=</span> <span class="va">None</span></span>
<span id="cb4-14"><a href="#cb4-14" aria-hidden="true" tabindex="-1"></a>    text_target_modules: Optional[List[<span class="bu">str</span>]] <span class="op">=</span> <span class="va">None</span></span>
<span id="cb4-15"><a href="#cb4-15" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-16"><a href="#cb4-16" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Training parameters</span></span>
<span id="cb4-17"><a href="#cb4-17" aria-hidden="true" tabindex="-1"></a>    learning_rate: <span class="bu">float</span> <span class="op">=</span> <span class="fl">1e-4</span></span>
<span id="cb4-18"><a href="#cb4-18" aria-hidden="true" tabindex="-1"></a>    weight_decay: <span class="bu">float</span> <span class="op">=</span> <span class="fl">0.01</span></span>
<span id="cb4-19"><a href="#cb4-19" aria-hidden="true" tabindex="-1"></a>    warmup_steps: <span class="bu">int</span> <span class="op">=</span> <span class="dv">500</span></span>
<span id="cb4-20"><a href="#cb4-20" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-21"><a href="#cb4-21" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Advanced options</span></span>
<span id="cb4-22"><a href="#cb4-22" aria-hidden="true" tabindex="-1"></a>    use_gradient_checkpointing: <span class="bu">bool</span> <span class="op">=</span> <span class="va">True</span></span>
<span id="cb4-23"><a href="#cb4-23" aria-hidden="true" tabindex="-1"></a>    mixed_precision: <span class="bu">bool</span> <span class="op">=</span> <span class="va">True</span></span>
<span id="cb4-24"><a href="#cb4-24" aria-hidden="true" tabindex="-1"></a>    task_type: <span class="bu">str</span> <span class="op">=</span> <span class="st">"multimodal_classification"</span></span>
<span id="cb4-25"><a href="#cb4-25" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-26"><a href="#cb4-26" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> __post_init__(<span class="va">self</span>):</span>
<span id="cb4-27"><a href="#cb4-27" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="va">self</span>.target_modules <span class="kw">is</span> <span class="va">None</span>:</span>
<span id="cb4-28"><a href="#cb4-28" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.target_modules <span class="op">=</span> [</span>
<span id="cb4-29"><a href="#cb4-29" aria-hidden="true" tabindex="-1"></a>                <span class="st">"q_proj"</span>, <span class="st">"k_proj"</span>, <span class="st">"v_proj"</span>, <span class="st">"o_proj"</span>,</span>
<span id="cb4-30"><a href="#cb4-30" aria-hidden="true" tabindex="-1"></a>                <span class="st">"gate_proj"</span>, <span class="st">"up_proj"</span>, <span class="st">"down_proj"</span></span>
<span id="cb4-31"><a href="#cb4-31" aria-hidden="true" tabindex="-1"></a>            ]</span>
<span id="cb4-32"><a href="#cb4-32" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-33"><a href="#cb4-33" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="va">self</span>.vision_target_modules <span class="kw">is</span> <span class="va">None</span>:</span>
<span id="cb4-34"><a href="#cb4-34" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.vision_target_modules <span class="op">=</span> [</span>
<span id="cb4-35"><a href="#cb4-35" aria-hidden="true" tabindex="-1"></a>                <span class="st">"qkv"</span>, <span class="st">"proj"</span>, <span class="st">"fc1"</span>, <span class="st">"fc2"</span></span>
<span id="cb4-36"><a href="#cb4-36" aria-hidden="true" tabindex="-1"></a>            ]</span>
<span id="cb4-37"><a href="#cb4-37" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb4-38"><a href="#cb4-38" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="va">self</span>.text_target_modules <span class="kw">is</span> <span class="va">None</span>:</span>
<span id="cb4-39"><a href="#cb4-39" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.text_target_modules <span class="op">=</span> [</span>
<span id="cb4-40"><a href="#cb4-40" aria-hidden="true" tabindex="-1"></a>                <span class="st">"q_proj"</span>, <span class="st">"k_proj"</span>, <span class="st">"v_proj"</span>, <span class="st">"dense"</span></span>
<span id="cb4-41"><a href="#cb4-41" aria-hidden="true" tabindex="-1"></a>            ]</span>
<span id="cb4-42"><a href="#cb4-42" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-43"><a href="#cb4-43" aria-hidden="true" tabindex="-1"></a><span class="co"># Example configurations for different tasks</span></span>
<span id="cb4-44"><a href="#cb4-44" aria-hidden="true" tabindex="-1"></a>task_configs <span class="op">=</span> {</span>
<span id="cb4-45"><a href="#cb4-45" aria-hidden="true" tabindex="-1"></a>    <span class="st">"image_captioning"</span>: LoRAConfig(</span>
<span id="cb4-46"><a href="#cb4-46" aria-hidden="true" tabindex="-1"></a>        rank<span class="op">=</span><span class="dv">32</span>,</span>
<span id="cb4-47"><a href="#cb4-47" aria-hidden="true" tabindex="-1"></a>        alpha<span class="op">=</span><span class="dv">32</span>,</span>
<span id="cb4-48"><a href="#cb4-48" aria-hidden="true" tabindex="-1"></a>        target_modules<span class="op">=</span>[<span class="st">"q_proj"</span>, <span class="st">"v_proj"</span>, <span class="st">"dense"</span>],</span>
<span id="cb4-49"><a href="#cb4-49" aria-hidden="true" tabindex="-1"></a>        task_type<span class="op">=</span><span class="st">"image_captioning"</span></span>
<span id="cb4-50"><a href="#cb4-50" aria-hidden="true" tabindex="-1"></a>    ),</span>
<span id="cb4-51"><a href="#cb4-51" aria-hidden="true" tabindex="-1"></a>    <span class="st">"visual_question_answering"</span>: LoRAConfig(</span>
<span id="cb4-52"><a href="#cb4-52" aria-hidden="true" tabindex="-1"></a>        rank<span class="op">=</span><span class="dv">16</span>,</span>
<span id="cb4-53"><a href="#cb4-53" aria-hidden="true" tabindex="-1"></a>        alpha<span class="op">=</span><span class="dv">16</span>,</span>
<span id="cb4-54"><a href="#cb4-54" aria-hidden="true" tabindex="-1"></a>        target_modules<span class="op">=</span>[<span class="st">"q_proj"</span>, <span class="st">"k_proj"</span>, <span class="st">"v_proj"</span>],</span>
<span id="cb4-55"><a href="#cb4-55" aria-hidden="true" tabindex="-1"></a>        task_type<span class="op">=</span><span class="st">"visual_question_answering"</span></span>
<span id="cb4-56"><a href="#cb4-56" aria-hidden="true" tabindex="-1"></a>    ),</span>
<span id="cb4-57"><a href="#cb4-57" aria-hidden="true" tabindex="-1"></a>    <span class="st">"image_classification"</span>: LoRAConfig(</span>
<span id="cb4-58"><a href="#cb4-58" aria-hidden="true" tabindex="-1"></a>        rank<span class="op">=</span><span class="dv">8</span>,</span>
<span id="cb4-59"><a href="#cb4-59" aria-hidden="true" tabindex="-1"></a>        alpha<span class="op">=</span><span class="dv">16</span>,</span>
<span id="cb4-60"><a href="#cb4-60" aria-hidden="true" tabindex="-1"></a>        target_modules<span class="op">=</span>[<span class="st">"qkv"</span>, <span class="st">"proj"</span>],</span>
<span id="cb4-61"><a href="#cb4-61" aria-hidden="true" tabindex="-1"></a>        task_type<span class="op">=</span><span class="st">"image_classification"</span></span>
<span id="cb4-62"><a href="#cb4-62" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb4-63"><a href="#cb4-63" aria-hidden="true" tabindex="-1"></a>}</span>
<span id="cb4-64"><a href="#cb4-64" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-65"><a href="#cb4-65" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"Available task configurations:"</span>)</span>
<span id="cb4-66"><a href="#cb4-66" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> task, config <span class="kw">in</span> task_configs.items():</span>
<span id="cb4-67"><a href="#cb4-67" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"- </span><span class="sc">{</span>task<span class="sc">}</span><span class="ss">: rank=</span><span class="sc">{</span>config<span class="sc">.</span>rank<span class="sc">}</span><span class="ss">, alpha=</span><span class="sc">{</span>config<span class="sc">.</span>alpha<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-stdout">
<pre><code>Available task configurations:
- image_captioning: rank=32, alpha=32
- visual_question_answering: rank=16, alpha=16
- image_classification: rank=8, alpha=16</code></pre>
</div>
</div>
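<p>In the standard LoRA formulation, the low-rank update is multiplied by <code>alpha / rank</code>. Adapting a <code>d_in × d_out</code> linear layer adds <code>rank × (d_in + d_out)</code> trainable parameters. The short sketch below computes both quantities for the three configurations above. It assumes a hypothetical 768 × 768 projection (the hidden size of ViT-Base), and the helper functions are illustrative rather than part of any library.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python"># Illustrative helpers for LoRA arithmetic (hypothetical, not a library API)


def lora_scaling(rank, alpha):
    """Factor applied to the low-rank update B @ A."""
    return alpha / rank


def lora_params(d_in, d_out, rank):
    """Trainable parameters LoRA adds to one d_in x d_out linear layer:
    A has shape (rank, d_in) and B has shape (d_out, rank)."""
    return rank * (d_in + d_out)


# (rank, alpha) pairs taken from task_configs above
configs = {
    "image_captioning": (32, 32),
    "visual_question_answering": (16, 16),
    "image_classification": (8, 16),
}

d = 768  # assumed hidden size of the adapted projection
full = d * d  # weights in the frozen base layer
for task, (rank, alpha) in configs.items():
    added = lora_params(d, d, rank)
    print(f"{task}: scaling={lora_scaling(rank, alpha):.1f}, "
          f"params/layer={added:,} ({100 * added / full:.2f}% of {full:,})")
</code></pre></div>
<p>For the classification setting (rank 8, alpha 16), the update is scaled by 2.0. It adds 12,288 parameters per adapted layer, roughly 2% of the 589,824 frozen weights.</p>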
</section>
<section id="training-strategies" class="level2">
<h2 class="anchored" data-anchor-id="training-strategies" id="training-strategies">Training Strategies</h2>
<section id="progressive-training" class="level3">
<h3 class="anchored" data-anchor-id="progressive-training" id="progressive-training">1. Progressive Training</h3>
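<p>Start with a small LoRA rank and increase it in stages, lowering the learning rate as capacity grows. When the rank is expanded, the factors learned so far are copied into the larger matrices, so the adapter does not start over at each stage:</p>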
<div id="progressive-training" class="cell" data-execution_count="6">
<details class="code-fold">
<summary>Code</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> ProgressiveLoRATrainer:</span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, model, initial_rank<span class="op">=</span><span class="dv">4</span>, max_rank<span class="op">=</span><span class="dv">32</span>):</span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> model</span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.current_rank <span class="op">=</span> initial_rank</span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.max_rank <span class="op">=</span> max_rank</span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> expand_rank(<span class="va">self</span>, new_rank):</span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Expand LoRA rank while preserving learned weights"""</span></span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> name, module <span class="kw">in</span> <span class="va">self</span>.model.named_modules():</span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="bu">isinstance</span>(module, LoRALinear):</span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a>                old_lora <span class="op">=</span> module.lora</span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Create new LoRA layer</span></span>
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a>                new_lora <span class="op">=</span> LoRALayer(</span>
<span id="cb6-15"><a href="#cb6-15" aria-hidden="true" tabindex="-1"></a>                    old_lora.lora_A.in_features,</span>
<span id="cb6-16"><a href="#cb6-16" aria-hidden="true" tabindex="-1"></a>                    old_lora.lora_B.out_features,</span>
<span id="cb6-17"><a href="#cb6-17" aria-hidden="true" tabindex="-1"></a>                    rank<span class="op">=</span>new_rank</span>
<span id="cb6-18"><a href="#cb6-18" aria-hidden="true" tabindex="-1"></a>                ).to(old_lora.lora_A.weight.device)  <span class="co"># same device as the layer it replaces</span></span>
<span id="cb6-19"><a href="#cb6-19" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb6-20"><a href="#cb6-20" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Copy existing weights</span></span>
<span id="cb6-21"><a href="#cb6-21" aria-hidden="true" tabindex="-1"></a>                <span class="cf">with</span> torch.no_grad():</span>
<span id="cb6-22"><a href="#cb6-22" aria-hidden="true" tabindex="-1"></a>                    new_lora.lora_A.weight[:old_lora.rank] <span class="op">=</span> old_lora.lora_A.weight</span>
<span id="cb6-23"><a href="#cb6-23" aria-hidden="true" tabindex="-1"></a>                    new_lora.lora_B.weight[:, :old_lora.rank] <span class="op">=</span> old_lora.lora_B.weight</span>
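<span id="cb6-24"><a href="#cb6-24" aria-hidden="true" tabindex="-1"></a>                module.lora <span class="op">=</span> new_lora</span>
<span id="cb6-25"><a href="#cb6-25" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.current_rank <span class="op">=</span> new_rank  <span class="co"># keep the trainer's bookkeeping in sync</span></span>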
<span id="cb6-26"><a href="#cb6-26" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-27"><a href="#cb6-27" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> progressive_training_schedule(<span class="va">self</span>, num_epochs):</span>
<span id="cb6-28"><a href="#cb6-28" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Generate progressive training schedule"""</span></span>
<span id="cb6-29"><a href="#cb6-29" aria-hidden="true" tabindex="-1"></a>        schedule <span class="op">=</span> []</span>
<span id="cb6-30"><a href="#cb6-30" aria-hidden="true" tabindex="-1"></a>        epochs_per_stage <span class="op">=</span> num_epochs <span class="op">//</span> <span class="dv">3</span></span>
<span id="cb6-31"><a href="#cb6-31" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-32"><a href="#cb6-32" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Stage 1: Small rank</span></span>
<span id="cb6-33"><a href="#cb6-33" aria-hidden="true" tabindex="-1"></a>        schedule.append({</span>
<span id="cb6-34"><a href="#cb6-34" aria-hidden="true" tabindex="-1"></a>            <span class="st">'epochs'</span>: epochs_per_stage,</span>
<span id="cb6-35"><a href="#cb6-35" aria-hidden="true" tabindex="-1"></a>            <span class="st">'rank'</span>: <span class="va">self</span>.current_rank,</span>
<span id="cb6-36"><a href="#cb6-36" aria-hidden="true" tabindex="-1"></a>            <span class="st">'lr'</span>: <span class="fl">1e-3</span>,</span>
<span id="cb6-37"><a href="#cb6-37" aria-hidden="true" tabindex="-1"></a>            <span class="st">'description'</span>: <span class="st">'Initial adaptation with small rank'</span></span>
<span id="cb6-38"><a href="#cb6-38" aria-hidden="true" tabindex="-1"></a>        })</span>
<span id="cb6-39"><a href="#cb6-39" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-40"><a href="#cb6-40" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Stage 2: Medium rank</span></span>
<span id="cb6-41"><a href="#cb6-41" aria-hidden="true" tabindex="-1"></a>        schedule.append({</span>
<span id="cb6-42"><a href="#cb6-42" aria-hidden="true" tabindex="-1"></a>            <span class="st">'epochs'</span>: epochs_per_stage,</span>
<span id="cb6-43"><a href="#cb6-43" aria-hidden="true" tabindex="-1"></a>            <span class="st">'rank'</span>: <span class="dv">16</span>,</span>
<span id="cb6-44"><a href="#cb6-44" aria-hidden="true" tabindex="-1"></a>            <span class="st">'lr'</span>: <span class="fl">5e-4</span>,</span>
<span id="cb6-45"><a href="#cb6-45" aria-hidden="true" tabindex="-1"></a>            <span class="st">'description'</span>: <span class="st">'Expand capacity with medium rank'</span></span>
<span id="cb6-46"><a href="#cb6-46" aria-hidden="true" tabindex="-1"></a>        })</span>
<span id="cb6-47"><a href="#cb6-47" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-48"><a href="#cb6-48" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Stage 3: Maximum LoRA rank</span></span>
<span id="cb6-49"><a href="#cb6-49" aria-hidden="true" tabindex="-1"></a>        schedule.append({</span>
<span id="cb6-50"><a href="#cb6-50" aria-hidden="true" tabindex="-1"></a>            <span class="st">'epochs'</span>: num_epochs <span class="op">-</span> <span class="dv">2</span> <span class="op">*</span> epochs_per_stage,</span>
<span id="cb6-51"><a href="#cb6-51" aria-hidden="true" tabindex="-1"></a>            <span class="st">'rank'</span>: <span class="va">self</span>.max_rank,</span>
<span id="cb6-52"><a href="#cb6-52" aria-hidden="true" tabindex="-1"></a>            <span class="st">'lr'</span>: <span class="fl">1e-4</span>,</span>
<span id="cb6-53"><a href="#cb6-53" aria-hidden="true" tabindex="-1"></a>            <span class="st">'description'</span>: <span class="st">'Fine-tune with full rank'</span></span>
<span id="cb6-54"><a href="#cb6-54" aria-hidden="true" tabindex="-1"></a>        })</span>
<span id="cb6-55"><a href="#cb6-55" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-56"><a href="#cb6-56" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> schedule</span>
<span id="cb6-57"><a href="#cb6-57" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-58"><a href="#cb6-58" aria-hidden="true" tabindex="-1"></a><span class="co"># Example usage</span></span>
<span id="cb6-59"><a href="#cb6-59" aria-hidden="true" tabindex="-1"></a>trainer <span class="op">=</span> ProgressiveLoRATrainer(<span class="va">None</span>)  <span class="co"># Would pass actual model</span></span>
<span id="cb6-60"><a href="#cb6-60" aria-hidden="true" tabindex="-1"></a>schedule <span class="op">=</span> trainer.progressive_training_schedule(<span class="dv">12</span>)</span>
<span id="cb6-61"><a href="#cb6-61" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-62"><a href="#cb6-62" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"Progressive Training Schedule:"</span>)</span>
<span id="cb6-63"><a href="#cb6-63" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> i, stage <span class="kw">in</span> <span class="bu">enumerate</span>(schedule, <span class="dv">1</span>):</span>
<span id="cb6-64"><a href="#cb6-64" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Stage </span><span class="sc">{</span>i<span class="sc">}</span><span class="ss">: </span><span class="sc">{</span>stage[<span class="st">'description'</span>]<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb6-65"><a href="#cb6-65" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"  - Epochs: </span><span class="sc">{</span>stage[<span class="st">'epochs'</span>]<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb6-66"><a href="#cb6-66" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"  - Rank: </span><span class="sc">{</span>stage[<span class="st">'rank'</span>]<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb6-67"><a href="#cb6-67" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"  - Learning Rate: </span><span class="sc">{</span>stage[<span class="st">'lr'</span>]<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb6-68"><a href="#cb6-68" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>()</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-stdout">
<pre><code>Progressive Training Schedule:
Stage 1: Initial adaptation with small rank
  - Epochs: 4
  - Rank: 4
  - Learning Rate: 0.001

Stage 2: Expand capacity with medium rank
  - Epochs: 4
  - Rank: 16
  - Learning Rate: 0.0005

Stage 3: Fine-tune with full rank
  - Epochs: 4
  - Rank: 32
  - Learning Rate: 0.0001
</code></pre>
</div>
</div>
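<p>Whether <code>expand_rank</code> really preserves what was learned depends on two details of <code>LoRALayer</code> that are not shown here. First, the new columns of <code>lora_B</code> must start at zero. Second, the layer's <code>alpha / rank</code> scaling (assuming it uses the usual formulation) must not silently shrink the update. The pure-Python sketch below checks the first point and illustrates the second. Matrices are lists of rows, and <code>expand_factors</code> is a hypothetical helper.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import random


def matmul(X, Y):
    """Plain matrix product X @ Y for lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*Y)] for row in X]


def expand_factors(A, B, new_rank, rng):
    """Hypothetical helper: grow LoRA factors A (r x d_in) and B (d_out x r)
    to new_rank. New rows of A are random, new columns of B are zero,
    so the product B @ A is unchanged."""
    r, d_in = len(A), len(A[0])
    extra = new_rank - r
    A_new = A + [[rng.uniform(-0.1, 0.1) for _ in range(d_in)] for _ in range(extra)]
    B_new = [row + [0.0] * extra for row in B]
    return A_new, B_new


rng = random.Random(0)
d_in, d_out, r, alpha = 3, 2, 2, 16
A = [[rng.uniform(-1, 1) for _ in range(d_in)] for _ in range(r)]
B = [[rng.uniform(-1, 1) for _ in range(r)] for _ in range(d_out)]

A4, B4 = expand_factors(A, B, 4, rng)
print(matmul(B, A) == matmul(B4, A4))  # True: the zero columns of B add exactly 0.0

# With a fixed alpha, the usual alpha / rank scaling still shrinks the update
print(alpha / r, alpha / 4)  # 8.0 4.0
</code></pre></div>
<p>Under that scaling, growing the rank from 4 to 32 with a fixed alpha shrinks the effective update by 8×. Scaling alpha by the same factor as the rank keeps <code>alpha / rank</code>, and therefore the layer's output, unchanged. The expanded factors are also new tensor objects, so the optimizer has to be rebuilt (or given the new parameters) after each expansion.</p>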
</section>
<section id="multi-stage-training" class="level3">
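<h3 class="anchored" data-anchor-id="multi-stage-training" id="multi-stage-training">2. Multi-Stage Training</h3>
<p>Alternatively, adapt one modality at a time. First train the text-side LoRA weights with the vision encoder frozen, then train the vision-side LoRA weights with the text encoder frozen, and finally train all LoRA weights jointly at a reduced learning rate:</p>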
<div id="multistage-training" class="cell" data-execution_count="7">
<details class="code-fold">
<summary>Code</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb8"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> multi_stage_training(model, train_loader, config):</span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a><span class="co">    Multi-stage training strategy:</span></span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a><span class="co">    1. Stage 1: Freeze vision encoder, train text components</span></span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a><span class="co">    2. Stage 2: Freeze text encoder, train vision components  </span></span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a><span class="co">    3. Stage 3: Joint training with reduced learning rate</span></span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="st">"Multi-Stage Training Strategy"</span>)</span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="st">"="</span> <span class="op">*</span> <span class="dv">40</span>)</span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-12"><a href="#cb8-12" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Stage 1: Text-only training</span></span>
<span id="cb8-13"><a href="#cb8-13" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="st">"Stage 1: Text-only training"</span>)</span>
<span id="cb8-14"><a href="#cb8-14" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="st">"- Freezing vision encoder"</span>)</span>
<span id="cb8-15"><a href="#cb8-15" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="st">"- Training text LoRA components"</span>)</span>
<span id="cb8-16"><a href="#cb8-16" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-17"><a href="#cb8-17" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> name, param <span class="kw">in</span> model.named_parameters():</span>
<span id="cb8-18"><a href="#cb8-18" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Train only text-side LoRA weights; the vision encoder and all</span></span>
<span id="cb8-19"><a href="#cb8-19" aria-hidden="true" tabindex="-1"></a>        <span class="co"># base (non-LoRA) weights are explicitly frozen</span></span>
<span id="cb8-20"><a href="#cb8-20" aria-hidden="true" tabindex="-1"></a>        is_text_lora <span class="op">=</span> <span class="st">'lora'</span> <span class="kw">in</span> name <span class="kw">and</span> <span class="st">'text'</span> <span class="kw">in</span> name</span>
<span id="cb8-21"><a href="#cb8-21" aria-hidden="true" tabindex="-1"></a>        param.requires_grad <span class="op">=</span> is_text_lora</span>
<span id="cb8-22"><a href="#cb8-22" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-23"><a href="#cb8-23" aria-hidden="true" tabindex="-1"></a>    trainable_params_stage1 <span class="op">=</span> <span class="bu">sum</span>(p.numel() <span class="cf">for</span> p <span class="kw">in</span> model.parameters() <span class="cf">if</span> p.requires_grad)</span>
<span id="cb8-24"><a href="#cb8-24" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"- Trainable parameters: </span><span class="sc">{</span>trainable_params_stage1<span class="sc">:,}</span><span class="ss">"</span>)</span>
<span id="cb8-25"><a href="#cb8-25" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-26"><a href="#cb8-26" aria-hidden="true" tabindex="-1"></a>    <span class="co"># train_stage(model, train_loader, epochs=config.stage1_epochs)</span></span>
<span id="cb8-27"><a href="#cb8-27" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-28"><a href="#cb8-28" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Stage 2: Vision-only training</span></span>
<span id="cb8-29"><a href="#cb8-29" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="st">"</span><span class="ch">\n</span><span class="st">Stage 2: Vision-only training"</span>)</span>
<span id="cb8-30"><a href="#cb8-30" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="st">"- Freezing text encoder"</span>)</span>
<span id="cb8-31"><a href="#cb8-31" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="st">"- Training vision LoRA components"</span>)</span>
<span id="cb8-32"><a href="#cb8-32" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-33"><a href="#cb8-33" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> name, param <span class="kw">in</span> model.named_parameters():</span>
<span id="cb8-34"><a href="#cb8-34" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Train only vision-side LoRA weights; the text encoder and all</span></span>
<span id="cb8-35"><a href="#cb8-35" aria-hidden="true" tabindex="-1"></a>        <span class="co"># base (non-LoRA) weights are explicitly frozen</span></span>
<span id="cb8-36"><a href="#cb8-36" aria-hidden="true" tabindex="-1"></a>        is_vision_lora <span class="op">=</span> <span class="st">'lora'</span> <span class="kw">in</span> name <span class="kw">and</span> <span class="st">'vision'</span> <span class="kw">in</span> name</span>
<span id="cb8-37"><a href="#cb8-37" aria-hidden="true" tabindex="-1"></a>        param.requires_grad <span class="op">=</span> is_vision_lora</span>
<span id="cb8-38"><a href="#cb8-38" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-39"><a href="#cb8-39" aria-hidden="true" tabindex="-1"></a>    trainable_params_stage2 <span class="op">=</span> <span class="bu">sum</span>(p.numel() <span class="cf">for</span> p <span class="kw">in</span> model.parameters() <span class="cf">if</span> p.requires_grad)</span>
<span id="cb8-40"><a href="#cb8-40" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"- Trainable parameters: </span><span class="sc">{</span>trainable_params_stage2<span class="sc">:,}</span><span class="ss">"</span>)</span>
<span id="cb8-41"><a href="#cb8-41" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-42"><a href="#cb8-42" aria-hidden="true" tabindex="-1"></a>    <span class="co"># train_stage(model, train_loader, epochs=config.stage2_epochs)</span></span>
<span id="cb8-43"><a href="#cb8-43" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-44"><a href="#cb8-44" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Stage 3: Joint training</span></span>
<span id="cb8-45"><a href="#cb8-45" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="st">"</span><span class="ch">\n</span><span class="st">Stage 3: Joint training"</span>)</span>
<span id="cb8-46"><a href="#cb8-46" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="st">"- Training all LoRA components"</span>)</span>
<span id="cb8-47"><a href="#cb8-47" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="st">"- Reduced learning rate for stability"</span>)</span>
<span id="cb8-48"><a href="#cb8-48" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-49"><a href="#cb8-49" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> name, param <span class="kw">in</span> model.named_parameters():</span>
<span id="cb8-50"><a href="#cb8-50" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Unfreeze every LoRA weight; base weights stay frozen</span></span>
<span id="cb8-51"><a href="#cb8-51" aria-hidden="true" tabindex="-1"></a>        param.requires_grad <span class="op">=</span> <span class="st">'lora'</span> <span class="kw">in</span> name</span>
<span id="cb8-52"><a href="#cb8-52" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-53"><a href="#cb8-53" aria-hidden="true" tabindex="-1"></a>    trainable_params_stage3 <span class="op">=</span> <span class="bu">sum</span>(p.numel() <span class="cf">for</span> p <span class="kw">in</span> model.parameters() <span class="cf">if</span> p.requires_grad)</span>
<span id="cb8-54"><a href="#cb8-54" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"- Trainable parameters: </span><span class="sc">{</span>trainable_params_stage3<span class="sc">:,}</span><span class="ss">"</span>)</span>
<span id="cb8-55"><a href="#cb8-55" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-56"><a href="#cb8-56" aria-hidden="true" tabindex="-1"></a>    <span class="co"># train_stage(model, train_loader, epochs=config.stage3_epochs, lr=config.lr * 0.1)</span></span>
<span id="cb8-57"><a href="#cb8-57" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-58"><a href="#cb8-58" aria-hidden="true" tabindex="-1"></a><span class="co"># Example configuration</span></span>
<span id="cb8-59"><a href="#cb8-59" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MultiStageConfig:</span>
<span id="cb8-60"><a href="#cb8-60" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>):</span>
<span id="cb8-61"><a href="#cb8-61" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.stage1_epochs <span class="op">=</span> <span class="dv">3</span></span>
<span id="cb8-62"><a href="#cb8-62" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.stage2_epochs <span class="op">=</span> <span class="dv">3</span></span>
<span id="cb8-63"><a href="#cb8-63" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.stage3_epochs <span class="op">=</span> <span class="dv">4</span></span>
<span id="cb8-64"><a href="#cb8-64" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.lr <span class="op">=</span> <span class="fl">1e-4</span></span>
<span id="cb8-65"><a href="#cb8-65" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-66"><a href="#cb8-66" aria-hidden="true" tabindex="-1"></a>config <span class="op">=</span> MultiStageConfig()</span>
<span id="cb8-67"><a href="#cb8-67" aria-hidden="true" tabindex="-1"></a><span class="co"># multi_stage_training(None, None, config)  # Would pass actual model and data</span></span></code></pre></div></div>
</details>
</div>
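<p>The freezing rules can be checked without loading a model by applying them to parameter names. The sketch below uses made-up names in a Hugging Face-like layout (an assumption: print <code>model.named_parameters()</code> to see the real ones). The hypothetical <code>is_trainable</code> helper matches on module prefixes rather than substrings, because a substring test for <code>'text'</code> also matches unrelated names such as <code>context</code>.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python"># Hypothetical parameter names in a Hugging Face-like layout; check
# model.named_parameters() for the real ones.
PARAM_NAMES = [
    "vision_model.encoder.layers.0.attn.qkv.weight",
    "vision_model.encoder.layers.0.attn.qkv.lora_A",
    "text_model.layers.0.self_attn.q_proj.weight",
    "text_model.layers.0.self_attn.q_proj.lora_A",
    "multi_modal_projector.weight",
]


def is_trainable(name, stage, text_prefix="text_model.", vision_prefix="vision_model."):
    """Stage-wise freezing rule: only LoRA weights ever train.
    Stage 1 trains text-side LoRA, stage 2 vision-side LoRA, stage 3 all LoRA."""
    if "lora" not in name:
        return False
    if stage == 1:
        return name.startswith(text_prefix)
    if stage == 2:
        return name.startswith(vision_prefix)
    return True


for stage in (1, 2, 3):
    trainable = [n for n in PARAM_NAMES if is_trainable(n, stage)]
    print(f"Stage {stage}: {trainable}")
</code></pre></div>
<p>With a real model, the same rule becomes <code>param.requires_grad = is_trainable(name, stage)</code> inside the <code>named_parameters()</code> loop. Modules that belong to neither prefix, such as the projector in this example, are never trained under this rule.</p>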
</section>
</section>
<section id="advanced-techniques" class="level2">
<h2 class="anchored" data-anchor-id="advanced-techniques" id="advanced-techniques">Advanced Techniques</h2>
<section id="adalora-adaptive-lora" class="level3">
<h3 class="anchored" data-anchor-id="adalora-adaptive-lora" id="adalora-adaptive-lora">1. AdaLoRA (Adaptive LoRA)</h3>
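<p>AdaLoRA allocates the rank budget adaptively instead of fixing one rank for every layer. Each rank-1 component of the update is scored, and the least important components are pruned during training. The original method parametrizes the update in an SVD-like form and estimates importance from gradient sensitivity. The simplified layer below uses a learnable importance score per component instead:</p>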
<div id="adalora" class="cell" data-execution_count="8">
<details class="code-fold">
<summary>Code</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb9"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> AdaLoRALayer(nn.Module):</span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, in_features, out_features, max_rank<span class="op">=</span><span class="dv">64</span>, init_rank<span class="op">=</span><span class="dv">16</span>):</span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.max_rank <span class="op">=</span> max_rank</span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.current_rank <span class="op">=</span> init_rank</span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Allocate factors at max_rank; only the first current_rank components are active</span></span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.lora_A <span class="op">=</span> nn.Parameter(torch.zeros(max_rank, in_features))</span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.lora_B <span class="op">=</span> nn.Parameter(torch.zeros(out_features, max_rank))</span>
<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-11"><a href="#cb9-11" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Importance scores</span></span>
<span id="cb9-12"><a href="#cb9-12" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.importance_scores <span class="op">=</span> nn.Parameter(torch.ones(max_rank))</span>
<span id="cb9-13"><a href="#cb9-13" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-14"><a href="#cb9-14" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Initialize only active components</span></span>
<span id="cb9-15"><a href="#cb9-15" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.reset_parameters()</span>
<span id="cb9-16"><a href="#cb9-16" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-17"><a href="#cb9-17" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> reset_parameters(<span class="va">self</span>):</span>
<span id="cb9-18"><a href="#cb9-18" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Initialize parameters"""</span></span>
<span id="cb9-19"><a href="#cb9-19" aria-hidden="true" tabindex="-1"></a>        nn.init.kaiming_uniform_(<span class="va">self</span>.lora_A[:<span class="va">self</span>.current_rank], a<span class="op">=</span>math.sqrt(<span class="dv">5</span>))</span>
<span id="cb9-20"><a href="#cb9-20" aria-hidden="true" tabindex="-1"></a>        nn.init.zeros_(<span class="va">self</span>.lora_B[:, :<span class="va">self</span>.current_rank])</span>
<span id="cb9-21"><a href="#cb9-21" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-22"><a href="#cb9-22" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb9-23"><a href="#cb9-23" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Scale each active rank-1 component by its importance score (applied once, via A)</span></span>
<span id="cb9-24"><a href="#cb9-24" aria-hidden="true" tabindex="-1"></a>        active_A <span class="op">=</span> <span class="va">self</span>.lora_A[:<span class="va">self</span>.current_rank] <span class="op">*</span> <span class="va">self</span>.importance_scores[:<span class="va">self</span>.current_rank, <span class="va">None</span>]</span>
<span id="cb9-25"><a href="#cb9-25" aria-hidden="true" tabindex="-1"></a>        active_B <span class="op">=</span> <span class="va">self</span>.lora_B[:, :<span class="va">self</span>.current_rank]</span>
<span id="cb9-26"><a href="#cb9-26" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-27"><a href="#cb9-27" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x <span class="op">@</span> active_A.T <span class="op">@</span> active_B.T</span>
<span id="cb9-28"><a href="#cb9-28" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-29"><a href="#cb9-29" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> update_rank(<span class="va">self</span>, budget_ratio<span class="op">=</span><span class="fl">0.7</span>):</span>
<span id="cb9-30"><a href="#cb9-30" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Set current_rank to the count of scores in the top budget_ratio fraction (simplified: forward() still uses the first current_rank slots)"""</span></span>
<span id="cb9-31"><a href="#cb9-31" aria-hidden="true" tabindex="-1"></a>        scores <span class="op">=</span> <span class="va">self</span>.importance_scores.<span class="bu">abs</span>()</span>
<span id="cb9-32"><a href="#cb9-32" aria-hidden="true" tabindex="-1"></a>        threshold <span class="op">=</span> torch.quantile(scores, <span class="dv">1</span> <span class="op">-</span> budget_ratio)</span>
<span id="cb9-33"><a href="#cb9-33" aria-hidden="true" tabindex="-1"></a>        new_rank <span class="op">=</span> (scores <span class="op">&gt;=</span> threshold).<span class="bu">sum</span>().item()</span>
<span id="cb9-34"><a href="#cb9-34" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-35"><a href="#cb9-35" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> new_rank <span class="op">!=</span> <span class="va">self</span>.current_rank:</span>
<span id="cb9-36"><a href="#cb9-36" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f"Rank updated: </span><span class="sc">{</span><span class="va">self</span><span class="sc">.</span>current_rank<span class="sc">}</span><span class="ss"> -&gt; </span><span class="sc">{</span>new_rank<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb9-37"><a href="#cb9-37" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.current_rank <span class="op">=</span> new_rank</span>
<span id="cb9-38"><a href="#cb9-38" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-39"><a href="#cb9-39" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> new_rank</span>
<span id="cb9-40"><a href="#cb9-40" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-41"><a href="#cb9-41" aria-hidden="true" tabindex="-1"></a><span class="co"># Demonstration of AdaLoRA rank adaptation</span></span>
<span id="cb9-42"><a href="#cb9-42" aria-hidden="true" tabindex="-1"></a>adalora_layer <span class="op">=</span> AdaLoRALayer(<span class="dv">768</span>, <span class="dv">768</span>, max_rank<span class="op">=</span><span class="dv">64</span>, init_rank<span class="op">=</span><span class="dv">16</span>)</span>
<span id="cb9-43"><a href="#cb9-43" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-44"><a href="#cb9-44" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"AdaLoRA Rank Adaptation Demo:"</span>)</span>
<span id="cb9-45"><a href="#cb9-45" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Initial rank: </span><span class="sc">{</span>adalora_layer<span class="sc">.</span>current_rank<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb9-46"><a href="#cb9-46" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-47"><a href="#cb9-47" aria-hidden="true" tabindex="-1"></a><span class="co"># Simulate importance score changes</span></span>
<span id="cb9-48"><a href="#cb9-48" aria-hidden="true" tabindex="-1"></a>adalora_layer.importance_scores.data <span class="op">=</span> torch.rand(<span class="dv">64</span>)  <span class="co"># Random importance scores</span></span>
<span id="cb9-49"><a href="#cb9-49" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-50"><a href="#cb9-50" aria-hidden="true" tabindex="-1"></a><span class="co"># Update rank based on importance</span></span>
<span id="cb9-51"><a href="#cb9-51" aria-hidden="true" tabindex="-1"></a>new_rank <span class="op">=</span> adalora_layer.update_rank(budget_ratio<span class="op">=</span><span class="fl">0.5</span>)</span>
<span id="cb9-52"><a href="#cb9-52" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"New rank after adaptation: </span><span class="sc">{</span>new_rank<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-stdout">
<pre><code>AdaLoRA Rank Adaptation Demo:
Initial rank: 16
Rank updated: 16 -&gt; 32
New rank after adaptation: 32</code></pre>
</div>
</div>
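<p>The <code>forward</code> method above always uses the first <code>current_rank</code> slots, so <code>update_rank</code> changes how many components are active but not which ones. A minimal plain-Python sketch of score-based selection follows, so the logic is easy to trace. The helper names <code>select_components</code> and <code>prune_lora</code> are hypothetical, and rows of A and columns of B are assumed to be stored as lists with one entry per rank-1 component.</p>

```python
def select_components(importance, budget_ratio):
    """Indices of the rank-1 components to keep, most important (by |score|) first."""
    # Keep at least one component; round the budget to the nearest integer count
    k = max(1, round(budget_ratio * len(importance)))
    order = sorted(range(len(importance)), key=lambda i: abs(importance[i]), reverse=True)
    return order[:k]


def prune_lora(a_rows, b_cols, importance, budget_ratio):
    """Keep the selected rows of A, the matching columns of B, and their scores."""
    keep = select_components(importance, budget_ratio)
    return (
        [a_rows[i] for i in keep],
        [b_cols[i] for i in keep],
        [importance[i] for i in keep],
    )


# Four rank-1 components for a toy layer with in_features=2, out_features=1
scores = [0.1, -0.9, 0.3, 0.7]
a_rows = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 0.0]]
b_cols = [[1.0], [2.0], [3.0], [4.0]]

a_kept, b_kept, s_kept = prune_lora(a_rows, b_cols, scores, budget_ratio=0.5)
print(s_kept)  # [-0.9, 0.7]
print(a_kept)  # [[0.0, 1.0], [2.0, 0.0]]
```

<p>Here the two kept components are the second and fourth, which a "first k slots" rule would have dropped. A full implementation would also reorder (or mask) the underlying parameter tensors so the kept entries remain trainable.</p>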
</section>
<section id="dora-weight-decomposed-lora" class="level3">
<h3 class="anchored" data-anchor-id="dora-weight-decomposed-lora" id="dora-weight-decomposed-lora">2. DoRA (Weight-Decomposed LoRA)</h3>
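<p>DoRA decomposes a pretrained weight into a magnitude vector <span class="math inline">\(m\)</span> and a direction, applies the low-rank update only to the direction, and learns <span class="math inline">\(m\)</span> separately: <span class="math inline">\(W' = m \cdot \frac{W_0 + BA}{\lVert W_0 + BA \rVert_c}\)</span>, where <span class="math inline">\(\lVert \cdot \rVert_c\)</span> is a vector-wise norm (column-wise in the paper's notation). The code below is a simplified illustration: it normalizes <code>original_weight</code> plus the LoRA <em>activations</em> row by row rather than normalizing the adapted weight matrix. At initialization the LoRA branch outputs zeros because <code>lora_B</code> starts at zero, while every row of the DoRA output has unit norm (the magnitude starts at 1), so its total norm is <span class="math inline">\(\sqrt{32} \approx 5.657\)</span>.</p>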
<div id="dora" class="cell" data-execution_count="9">
<details class="code-fold">
<summary>Code</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb11"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> DoRALayer(nn.Module):</span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, in_features, out_features, rank<span class="op">=</span><span class="dv">16</span>):</span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.rank <span class="op">=</span> rank</span>
<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Standard LoRA components</span></span>
<span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.lora_A <span class="op">=</span> nn.Linear(in_features, rank, bias<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.lora_B <span class="op">=</span> nn.Linear(rank, out_features, bias<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb11-9"><a href="#cb11-9" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-10"><a href="#cb11-10" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Magnitude component</span></span>
<span id="cb11-11"><a href="#cb11-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.magnitude <span class="op">=</span> nn.Parameter(torch.ones(out_features))</span>
<span id="cb11-12"><a href="#cb11-12" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-13"><a href="#cb11-13" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Initialize LoRA weights</span></span>
<span id="cb11-14"><a href="#cb11-14" aria-hidden="true" tabindex="-1"></a>        nn.init.kaiming_uniform_(<span class="va">self</span>.lora_A.weight, a<span class="op">=</span>math.sqrt(<span class="dv">5</span>))</span>
<span id="cb11-15"><a href="#cb11-15" aria-hidden="true" tabindex="-1"></a>        nn.init.zeros_(<span class="va">self</span>.lora_B.weight)</span>
<span id="cb11-16"><a href="#cb11-16" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-17"><a href="#cb11-17" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x, original_weight):</span>
<span id="cb11-18"><a href="#cb11-18" aria-hidden="true" tabindex="-1"></a>        <span class="co"># LoRA adaptation</span></span>
<span id="cb11-19"><a href="#cb11-19" aria-hidden="true" tabindex="-1"></a>        lora_result <span class="op">=</span> <span class="va">self</span>.lora_B(<span class="va">self</span>.lora_A(x))</span>
<span id="cb11-20"><a href="#cb11-20" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-21"><a href="#cb11-21" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Direction component (normalized)</span></span>
<span id="cb11-22"><a href="#cb11-22" aria-hidden="true" tabindex="-1"></a>        adapted_weight <span class="op">=</span> original_weight <span class="op">+</span> lora_result</span>
<span id="cb11-23"><a href="#cb11-23" aria-hidden="true" tabindex="-1"></a>        direction <span class="op">=</span> F.normalize(adapted_weight, dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb11-24"><a href="#cb11-24" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-25"><a href="#cb11-25" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Apply magnitude scaling</span></span>
<span id="cb11-26"><a href="#cb11-26" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> direction <span class="op">*</span> <span class="va">self</span>.magnitude.unsqueeze(<span class="dv">0</span>)</span>
<span id="cb11-27"><a href="#cb11-27" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-28"><a href="#cb11-28" aria-hidden="true" tabindex="-1"></a><span class="co"># Example: Compare LoRA vs DoRA</span></span>
<span id="cb11-29"><a href="#cb11-29" aria-hidden="true" tabindex="-1"></a>original_weight <span class="op">=</span> torch.randn(<span class="dv">32</span>, <span class="dv">768</span>)</span>
<span id="cb11-30"><a href="#cb11-30" aria-hidden="true" tabindex="-1"></a>x <span class="op">=</span> torch.randn(<span class="dv">32</span>, <span class="dv">768</span>)</span>
<span id="cb11-31"><a href="#cb11-31" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-32"><a href="#cb11-32" aria-hidden="true" tabindex="-1"></a><span class="co"># Standard LoRA</span></span>
<span id="cb11-33"><a href="#cb11-33" aria-hidden="true" tabindex="-1"></a>lora_layer <span class="op">=</span> LoRALayer(<span class="dv">768</span>, <span class="dv">768</span>, rank<span class="op">=</span><span class="dv">16</span>)</span>
<span id="cb11-34"><a href="#cb11-34" aria-hidden="true" tabindex="-1"></a>lora_output <span class="op">=</span> lora_layer(x)</span>
<span id="cb11-35"><a href="#cb11-35" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-36"><a href="#cb11-36" aria-hidden="true" tabindex="-1"></a><span class="co"># DoRA</span></span>
<span id="cb11-37"><a href="#cb11-37" aria-hidden="true" tabindex="-1"></a>dora_layer <span class="op">=</span> DoRALayer(<span class="dv">768</span>, <span class="dv">768</span>, rank<span class="op">=</span><span class="dv">16</span>)</span>
<span id="cb11-38"><a href="#cb11-38" aria-hidden="true" tabindex="-1"></a>dora_output <span class="op">=</span> dora_layer(x, original_weight)</span>
<span id="cb11-39"><a href="#cb11-39" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-40"><a href="#cb11-40" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"LoRA vs DoRA Comparison:"</span>)</span>
<span id="cb11-41"><a href="#cb11-41" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"LoRA output shape: </span><span class="sc">{</span>lora_output<span class="sc">.</span>shape<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb11-42"><a href="#cb11-42" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"DoRA output shape: </span><span class="sc">{</span>dora_output<span class="sc">.</span>shape<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb11-43"><a href="#cb11-43" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"LoRA output norm: </span><span class="sc">{</span>lora_output<span class="sc">.</span>norm()<span class="sc">:.4f}</span><span class="ss">"</span>)</span>
<span id="cb11-44"><a href="#cb11-44" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"DoRA output norm: </span><span class="sc">{</span>dora_output<span class="sc">.</span>norm()<span class="sc">:.4f}</span><span class="ss">"</span>)</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-stdout">
<pre><code>LoRA vs DoRA Comparison:
LoRA output shape: torch.Size([32, 768])
DoRA output shape: torch.Size([32, 768])
LoRA output norm: 0.0000
DoRA output norm: 5.6569</code></pre>
</div>
</div>
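<p>For comparison, here is a minimal weight-level sketch of DoRA's update <span class="math inline">\(W' = m \cdot (W_0 + BA) / \lVert W_0 + BA \rVert\)</span> in plain Python with small nested lists. It assumes the weight is stored as <code>(out_features, in_features)</code> and takes the norm over each output row, one common convention for this storage layout; <code>dora_weight</code>, <code>init_magnitude</code> and the toy matrices are illustrative and not part of any library.</p>

```python
import math


def matmul(x, y):
    """Plain-Python matrix product of nested lists."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*y)] for row in x]


def init_magnitude(w0):
    """Magnitude vector m initialized to the norm of each output row of W0."""
    return [math.sqrt(sum(v * v for v in row)) for row in w0]


def dora_weight(w0, b, a, m, scaling=1.0):
    """W' = m * (W0 + scaling * B @ A) / ||W0 + scaling * B @ A||, norm per output row."""
    delta = matmul(b, a)
    v = [[w + scaling * d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(w0, delta)]
    merged = []
    for m_i, v_row in zip(m, v):
        norm = math.sqrt(sum(x * x for x in v_row))  # assumes no all-zero rows
        # Direction is v_row / norm; the learned magnitude m_i sets its length
        merged.append([m_i * x / norm for x in v_row])
    return merged


w0 = [[3.0, 4.0], [0.0, 2.0]]  # frozen weight, shape (out=2, in=2)
m = init_magnitude(w0)         # [5.0, 2.0]
a = [[1.0, 1.0]]               # A: (rank=1, in=2)
b_zero = [[0.0], [0.0]]        # B: (out=2, rank=1), zero-initialized as in LoRA

print(dora_weight(w0, b_zero, a, m))  # [[3.0, 4.0], [0.0, 2.0]], identical to W0 at init

# Training only the magnitude rescales rows without changing their direction
print(dora_weight(w0, b_zero, a, [10.0, 2.0]))  # [[6.0, 8.0], [0.0, 2.0]]
```

<p>With B at zero the adapted weight equals W0, mirroring LoRA at initialization; once B is trained, the low-rank term only rotates each row, and its length is governed solely by <code>m</code>.</p>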
</section>
<section id="mixture-of-loras-molora" class="level3">
<h3 class="anchored" data-anchor-id="mixture-of-loras-molora" id="mixture-of-loras-molora">3. Mixture of LoRAs (MoLoRA)</h3>
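<p>MoLoRA trains several independent LoRA experts plus a small linear gating network; for each input, the gate produces softmax weights that blend the experts' outputs:</p>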
<div id="molora" class="cell" data-execution_count="10">
<details class="code-fold">
<summary>Code</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb13"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MoLoRALayer(nn.Module):</span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, in_features, out_features, num_experts<span class="op">=</span><span class="dv">4</span>, rank<span class="op">=</span><span class="dv">16</span>):</span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb13-4"><a href="#cb13-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.num_experts <span class="op">=</span> num_experts</span>
<span id="cb13-5"><a href="#cb13-5" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-6"><a href="#cb13-6" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Multiple LoRA experts</span></span>
<span id="cb13-7"><a href="#cb13-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.experts <span class="op">=</span> nn.ModuleList([</span>
<span id="cb13-8"><a href="#cb13-8" aria-hidden="true" tabindex="-1"></a>            LoRALayer(in_features, out_features, rank)</span>
<span id="cb13-9"><a href="#cb13-9" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(num_experts)</span>
<span id="cb13-10"><a href="#cb13-10" aria-hidden="true" tabindex="-1"></a>        ])</span>
<span id="cb13-11"><a href="#cb13-11" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-12"><a href="#cb13-12" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Gating network</span></span>
<span id="cb13-13"><a href="#cb13-13" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.gate <span class="op">=</span> nn.Linear(in_features, num_experts)</span>
<span id="cb13-14"><a href="#cb13-14" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-15"><a href="#cb13-15" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb13-16"><a href="#cb13-16" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Compute gating weights</span></span>
<span id="cb13-17"><a href="#cb13-17" aria-hidden="true" tabindex="-1"></a>        gate_input <span class="op">=</span> x.mean(dim<span class="op">=</span><span class="dv">1</span>) <span class="cf">if</span> x.dim() <span class="op">&gt;</span> <span class="dv">2</span> <span class="cf">else</span> x</span>
<span id="cb13-18"><a href="#cb13-18" aria-hidden="true" tabindex="-1"></a>        gate_weights <span class="op">=</span> F.softmax(<span class="va">self</span>.gate(gate_input), dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb13-19"><a href="#cb13-19" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-20"><a href="#cb13-20" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Combine expert outputs</span></span>
<span id="cb13-21"><a href="#cb13-21" aria-hidden="true" tabindex="-1"></a>        expert_outputs <span class="op">=</span> torch.stack([expert(x) <span class="cf">for</span> expert <span class="kw">in</span> <span class="va">self</span>.experts], dim<span class="op">=</span><span class="dv">0</span>)</span>
<span id="cb13-22"><a href="#cb13-22" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-23"><a href="#cb13-23" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Weighted combination</span></span>
<span id="cb13-24"><a href="#cb13-24" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Put the expert axis first: (num_experts, batch) for batched input, (num_experts,) for one vector</span></span>
<span id="cb13-25"><a href="#cb13-25" aria-hidden="true" tabindex="-1"></a>        gate_weights <span class="op">=</span> gate_weights.movedim(<span class="op">-</span><span class="dv">1</span>, <span class="dv">0</span>)</span>
<span id="cb13-26"><a href="#cb13-26" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Append singleton dims so the weights also broadcast over (batch, seq, features) outputs</span></span>
<span id="cb13-27"><a href="#cb13-27" aria-hidden="true" tabindex="-1"></a>        gate_weights <span class="op">=</span> gate_weights.reshape(gate_weights.shape <span class="op">+</span> (<span class="dv">1</span>,) <span class="op">*</span> (expert_outputs.dim() <span class="op">-</span> gate_weights.dim()))</span>
<span id="cb13-28"><a href="#cb13-28" aria-hidden="true" tabindex="-1"></a>        output <span class="op">=</span> torch.<span class="bu">sum</span>(gate_weights <span class="op">*</span> expert_outputs, dim<span class="op">=</span><span class="dv">0</span>)</span>
<span id="cb13-29"><a href="#cb13-29" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-30"><a href="#cb13-30" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> output</span>
<span id="cb13-31"><a href="#cb13-31" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-32"><a href="#cb13-32" aria-hidden="true" tabindex="-1"></a><span class="co"># Demonstration of MoLoRA</span></span>
<span id="cb13-33"><a href="#cb13-33" aria-hidden="true" tabindex="-1"></a>molora_layer <span class="op">=</span> MoLoRALayer(<span class="dv">768</span>, <span class="dv">768</span>, num_experts<span class="op">=</span><span class="dv">4</span>, rank<span class="op">=</span><span class="dv">16</span>)</span>
<span id="cb13-34"><a href="#cb13-34" aria-hidden="true" tabindex="-1"></a>x <span class="op">=</span> torch.randn(<span class="dv">32</span>, <span class="dv">768</span>)</span>
<span id="cb13-35"><a href="#cb13-35" aria-hidden="true" tabindex="-1"></a>output <span class="op">=</span> molora_layer(x)</span>
<span id="cb13-36"><a href="#cb13-36" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-37"><a href="#cb13-37" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"Mixture of LoRAs (MoLoRA) Demo:"</span>)</span>
<span id="cb13-38"><a href="#cb13-38" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Input shape: </span><span class="sc">{</span>x<span class="sc">.</span>shape<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb13-39"><a href="#cb13-39" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Output shape: </span><span class="sc">{</span>output<span class="sc">.</span>shape<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb13-40"><a href="#cb13-40" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Number of experts: </span><span class="sc">{</span>molora_layer<span class="sc">.</span>num_experts<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb13-41"><a href="#cb13-41" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-42"><a href="#cb13-42" aria-hidden="true" tabindex="-1"></a><span class="co"># Show expert utilization</span></span>
<span id="cb13-43"><a href="#cb13-43" aria-hidden="true" tabindex="-1"></a><span class="cf">with</span> torch.no_grad():</span>
<span id="cb13-44"><a href="#cb13-44" aria-hidden="true" tabindex="-1"></a>    gate_weights <span class="op">=</span> F.softmax(molora_layer.gate(x), dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb13-45"><a href="#cb13-45" aria-hidden="true" tabindex="-1"></a>    expert_utilization <span class="op">=</span> gate_weights.mean(dim<span class="op">=</span><span class="dv">0</span>)</span>
<span id="cb13-46"><a href="#cb13-46" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-47"><a href="#cb13-47" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"Expert utilization:"</span>)</span>
<span id="cb13-48"><a href="#cb13-48" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> i, util <span class="kw">in</span> <span class="bu">enumerate</span>(expert_utilization):</span>
<span id="cb13-49"><a href="#cb13-49" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"  Expert </span><span class="sc">{</span>i<span class="op">+</span><span class="dv">1</span><span class="sc">}</span><span class="ss">: </span><span class="sc">{</span>util<span class="sc">:.3f}</span><span class="ss">"</span>)</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-stdout">
<pre><code>Mixture of LoRAs (MoLoRA) Demo:
Input shape: torch.Size([32, 768])
Output shape: torch.Size([32, 768])
Number of experts: 4
Expert utilization:
  Expert 1: 0.260
  Expert 2: 0.252
  Expert 3: 0.245
  Expert 4: 0.243</code></pre>
</div>
</div>
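<p>The gating step amounts to a softmax over the gate's logits followed by a weighted sum of expert outputs. A minimal plain-Python sketch for a single input vector follows; <code>mix_experts</code> is a hypothetical helper, and the logits and expert outputs are toy values rather than outputs of the layer above.</p>

```python
import math


def softmax(logits):
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    top = max(logits)
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def mix_experts(gate_logits, expert_outputs):
    """Blend per-expert output vectors with softmax gate weights (one input vector)."""
    weights = softmax(gate_logits)
    dim = len(expert_outputs[0])
    mixed = [sum(w * out[j] for w, out in zip(weights, expert_outputs)) for j in range(dim)]
    return mixed, weights


# Two experts, each producing a 2-dimensional output for the same input
outputs = [[1.0, 3.0], [3.0, 5.0]]

mixed, weights = mix_experts([0.0, 0.0], outputs)
print(weights)  # [0.5, 0.5]: equal logits give a plain average
print(mixed)    # [2.0, 4.0]

mixed, weights = mix_experts([2.0, 0.0], outputs)
print(round(weights[0], 3))  # 0.881: the first expert now dominates
```

<p>The layer above does the same thing in batched form: gate weights of shape <code>(batch, num_experts)</code> are broadcast against the stacked expert outputs and summed over the expert axis.</p>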
</section>
</section>
<section id="performance-optimization" class="level2">
<h2 class="anchored" data-anchor-id="performance-optimization" id="performance-optimization">Performance Optimization</h2>
<section id="memory-optimization" class="level3">
<h3 class="anchored" data-anchor-id="memory-optimization" id="memory-optimization">Memory Optimization</h3>
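<p>Two pieces of arithmetic drive the helpers below: merging is exact because <span class="math inline">\(W_0x + s\,B(Ax) = (W_0 + s\,BA)\,x\)</span>, and a rank-<span class="math inline">\(r\)</span> adapter on a <span class="math inline">\(d_{out} \times d_{in}\)</span> layer adds <span class="math inline">\(r(d_{in} + d_{out})\)</span> trainable parameters. A plain-Python check of both follows, using toy matrices and hypothetical helper names (<code>matvec</code>, <code>lora_param_count</code>).</p>

```python
def matmul(x, y):
    """Plain-Python matrix product of nested lists."""
    return [[sum(p * q for p, q in zip(row, col)) for col in zip(*y)] for row in x]


def matvec(w, v):
    """Matrix-vector product W v."""
    return [sum(p * q for p, q in zip(row, v)) for row in w]


def lora_param_count(d_in, d_out, rank):
    """Trainable parameters of one adapter: A is (rank, d_in), B is (d_out, rank)."""
    return rank * d_in + d_out * rank


# Toy layer: W0 is (2, 2), rank-1 adapter, scaling s = 0.5
w0 = [[1.0, 2.0], [3.0, 4.0]]
a = [[1.0, -1.0]]   # (rank=1, d_in=2)
b = [[2.0], [4.0]]  # (d_out=2, rank=1)
s = 0.5
x = [1.0, 3.0]

# Unmerged: base output plus the scaled low-rank branch
unmerged = [base + s * lo for base, lo in zip(matvec(w0, x), matvec(b, matvec(a, x)))]

# Merged: fold s * B @ A into the weight once, then a single matrix-vector product
ba = matmul(b, a)
w_merged = [[w + s * d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(w0, ba)]
merged = matvec(w_merged, x)

print(unmerged, merged)  # [5.0, 11.0] [5.0, 11.0]

# Parameter budget for one 768 x 768 projection with rank 16
full = 768 * 768
extra = lora_param_count(768, 768, 16)
print(full, extra, f"{extra / full:.2%}")  # 589824 24576 4.17%
```

<p>This identity is why <code>merge_lora_weights</code> can drop the LoRA branch at inference: after folding, each adapted layer costs one ordinary matrix multiply.</p>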
<div id="memory-optimization" class="cell" data-execution_count="11">
<details class="code-fold">
<summary>Code</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb15"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb15-1"><a href="#cb15-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MemoryEfficientLoRA:</span>
<span id="cb15-2"><a href="#cb15-2" aria-hidden="true" tabindex="-1"></a>    <span class="at">@staticmethod</span></span>
<span id="cb15-3"><a href="#cb15-3" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> gradient_checkpointing_forward(module, <span class="op">*</span>args):</span>
<span id="cb15-4"><a href="#cb15-4" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Custom gradient checkpointing for LoRA layers"""</span></span>
<span id="cb15-5"><a href="#cb15-5" aria-hidden="true" tabindex="-1"></a>        <span class="kw">def</span> create_custom_forward(module):</span>
<span id="cb15-6"><a href="#cb15-6" aria-hidden="true" tabindex="-1"></a>            <span class="kw">def</span> custom_forward(<span class="op">*</span>inputs):</span>
<span id="cb15-7"><a href="#cb15-7" aria-hidden="true" tabindex="-1"></a>                <span class="cf">return</span> module(<span class="op">*</span>inputs)</span>
<span id="cb15-8"><a href="#cb15-8" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> custom_forward</span>
<span id="cb15-9"><a href="#cb15-9" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-10"><a href="#cb15-10" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> torch.utils.checkpoint.checkpoint(</span>
<span id="cb15-11"><a href="#cb15-11" aria-hidden="true" tabindex="-1"></a>            create_custom_forward(module), <span class="op">*</span>args, use_reentrant<span class="op">=</span><span class="va">False</span></span>
<span id="cb15-12"><a href="#cb15-12" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb15-13"><a href="#cb15-13" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-14"><a href="#cb15-14" aria-hidden="true" tabindex="-1"></a>    <span class="at">@staticmethod</span></span>
<span id="cb15-15"><a href="#cb15-15" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> merge_lora_weights(model):</span>
<span id="cb15-16"><a href="#cb15-16" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Merge LoRA weights into base model for inference"""</span></span>
<span id="cb15-17"><a href="#cb15-17" aria-hidden="true" tabindex="-1"></a>        merged_count <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb15-18"><a href="#cb15-18" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-19"><a href="#cb15-19" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> name, module <span class="kw">in</span> <span class="bu">list</span>(model.named_modules()):  <span class="co"># snapshot: modules are replaced below</span></span>
<span id="cb15-20"><a href="#cb15-20" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="bu">isinstance</span>(module, LoRALinear):</span>
<span id="cb15-21"><a href="#cb15-21" aria-hidden="true" tabindex="-1"></a>                base <span class="op">=</span> module.original_layer</span>
<span id="cb15-22"><a href="#cb15-22" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Compute merged weight W0 + scaling * (B @ A) outside autograd</span></span>
<span id="cb15-23"><a href="#cb15-23" aria-hidden="true" tabindex="-1"></a>                <span class="cf">with</span> torch.no_grad():</span>
<span id="cb15-24"><a href="#cb15-24" aria-hidden="true" tabindex="-1"></a>                    lora_weight <span class="op">=</span> module.lora.lora_B.weight <span class="op">@</span> module.lora.lora_A.weight</span>
<span id="cb15-25"><a href="#cb15-25" aria-hidden="true" tabindex="-1"></a>                    merged_weight <span class="op">=</span> base.weight <span class="op">+</span> lora_weight <span class="op">*</span> module.lora.scaling</span>
<span id="cb15-26"><a href="#cb15-26" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Create a plain nn.Linear holding the merged weight and the original bias</span></span>
<span id="cb15-27"><a href="#cb15-27" aria-hidden="true" tabindex="-1"></a>                merged_layer <span class="op">=</span> nn.Linear(base.in_features, base.out_features, bias<span class="op">=</span>base.bias <span class="kw">is</span> <span class="kw">not</span> <span class="va">None</span>)</span>
<span id="cb15-28"><a href="#cb15-28" aria-hidden="true" tabindex="-1"></a>                merged_layer.weight.data <span class="op">=</span> merged_weight</span>
<span id="cb15-29"><a href="#cb15-29" aria-hidden="true" tabindex="-1"></a>                <span class="cf">if</span> base.bias <span class="kw">is</span> <span class="kw">not</span> <span class="va">None</span>:</span>
<span id="cb15-30"><a href="#cb15-30" aria-hidden="true" tabindex="-1"></a>                    merged_layer.bias.data <span class="op">=</span> base.bias.detach().clone()</span>
<span id="cb15-31"><a href="#cb15-31" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Swap the merged layer in for the LoRALinear wrapper inside its parent module</span></span>
<span id="cb15-32"><a href="#cb15-32" aria-hidden="true" tabindex="-1"></a>                parent_name, _, child_name <span class="op">=</span> name.rpartition(<span class="st">'.'</span>)</span>
<span id="cb15-33"><a href="#cb15-33" aria-hidden="true" tabindex="-1"></a>                parent <span class="op">=</span> model.get_submodule(parent_name) <span class="cf">if</span> parent_name <span class="cf">else</span> model</span>
<span id="cb15-34"><a href="#cb15-34" aria-hidden="true" tabindex="-1"></a>                <span class="bu">setattr</span>(parent, child_name, merged_layer)</span>
<span id="cb15-35"><a href="#cb15-35" aria-hidden="true" tabindex="-1"></a>                merged_count <span class="op">+=</span> <span class="dv">1</span></span>
<span id="cb15-36"><a href="#cb15-36" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-37"><a href="#cb15-37" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> merged_count</span>
<span id="cb15-38"><a href="#cb15-38" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-39"><a href="#cb15-39" aria-hidden="true" tabindex="-1"></a>    <span class="at">@staticmethod</span></span>
<span id="cb15-40"><a href="#cb15-40" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> compute_memory_savings(model):</span>
<span id="cb15-41"><a href="#cb15-41" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Compute memory savings from LoRA"""</span></span>
<span id="cb15-42"><a href="#cb15-42" aria-hidden="true" tabindex="-1"></a>        total_params <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb15-43"><a href="#cb15-43" aria-hidden="true" tabindex="-1"></a>        lora_params <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb15-44"><a href="#cb15-44" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-45"><a href="#cb15-45" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> name, param <span class="kw">in</span> model.named_parameters():</span>
<span id="cb15-46"><a href="#cb15-46" aria-hidden="true" tabindex="-1"></a>            total_params <span class="op">+=</span> param.numel()</span>
<span id="cb15-47"><a href="#cb15-47" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="st">'lora'</span> <span class="kw">in</span> name:</span>
<span id="cb15-48"><a href="#cb15-48" aria-hidden="true" tabindex="-1"></a>                lora_params <span class="op">+=</span> param.numel()</span>
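<span id="cb15-49"><a href="#cb15-49" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Share of parameters that stay frozen (no gradients or optimizer state); the frozen weights still occupy memory</span></span>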
<span id="cb15-50"><a href="#cb15-50" aria-hidden="true" tabindex="-1"></a>        savings_ratio <span class="op">=</span> <span class="dv">1</span> <span class="op">-</span> (lora_params <span class="op">/</span> total_params)</span>
<span id="cb15-51"><a href="#cb15-51" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-52"><a href="#cb15-52" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {</span>
<span id="cb15-53"><a href="#cb15-53" aria-hidden="true" tabindex="-1"></a>            <span class="st">'total_parameters'</span>: total_params,</span>
<span id="cb15-54"><a href="#cb15-54" aria-hidden="true" tabindex="-1"></a>            <span class="st">'lora_parameters'</span>: lora_params,</span>
<span id="cb15-55"><a href="#cb15-55" aria-hidden="true" tabindex="-1"></a>            <span class="st">'base_parameters'</span>: total_params <span class="op">-</span> lora_params,</span>
<span id="cb15-56"><a href="#cb15-56" aria-hidden="true" tabindex="-1"></a>            <span class="st">'memory_savings'</span>: savings_ratio,</span>
<span id="cb15-57"><a href="#cb15-57" aria-hidden="true" tabindex="-1"></a>            <span class="st">'compression_ratio'</span>: total_params <span class="op">/</span> lora_params <span class="cf">if</span> lora_params <span class="op">&gt;</span> <span class="dv">0</span> <span class="cf">else</span> <span class="bu">float</span>(<span class="st">'inf'</span>)</span>
<span id="cb15-58"><a href="#cb15-58" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb15-59"><a href="#cb15-59" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-60"><a href="#cb15-60" aria-hidden="true" tabindex="-1"></a><span class="co"># Demonstrate memory optimization</span></span>
<span id="cb15-61"><a href="#cb15-61" aria-hidden="true" tabindex="-1"></a>memory_utils <span class="op">=</span> MemoryEfficientLoRA()  <span class="co"># holder of static helpers, not a torch optimizer</span></span>
<span id="cb15-62"><a href="#cb15-62" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-63"><a href="#cb15-63" aria-hidden="true" tabindex="-1"></a><span class="co"># Illustrative figures; MemoryEfficientLoRA.compute_memory_savings(model) returns a dict of this shape for a real model</span></span>
<span id="cb15-64"><a href="#cb15-64" aria-hidden="true" tabindex="-1"></a>example_stats <span class="op">=</span> {</span>
<span id="cb15-65"><a href="#cb15-65" aria-hidden="true" tabindex="-1"></a>    <span class="st">'total_parameters'</span>: <span class="dv">175_000_000</span>,</span>
<span id="cb15-66"><a href="#cb15-66" aria-hidden="true" tabindex="-1"></a>    <span class="st">'lora_parameters'</span>: <span class="dv">1_750_000</span>,</span>
<span id="cb15-67"><a href="#cb15-67" aria-hidden="true" tabindex="-1"></a>    <span class="st">'base_parameters'</span>: <span class="dv">173_250_000</span>,</span>
<span id="cb15-68"><a href="#cb15-68" aria-hidden="true" tabindex="-1"></a>    <span class="st">'memory_savings'</span>: <span class="fl">0.99</span>,</span>
<span id="cb15-69"><a href="#cb15-69" aria-hidden="true" tabindex="-1"></a>    <span class="st">'compression_ratio'</span>: <span class="dv">100</span></span>
<span id="cb15-70"><a href="#cb15-70" aria-hidden="true" tabindex="-1"></a>}</span>
<span id="cb15-71"><a href="#cb15-71" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-72"><a href="#cb15-72" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"Memory Optimization Analysis:"</span>)</span>
<span id="cb15-73"><a href="#cb15-73" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Total parameters: </span><span class="sc">{</span>example_stats[<span class="st">'total_parameters'</span>]<span class="sc">:,}</span><span class="ss">"</span>)</span>
<span id="cb15-74"><a href="#cb15-74" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"LoRA parameters: </span><span class="sc">{</span>example_stats[<span class="st">'lora_parameters'</span>]<span class="sc">:,}</span><span class="ss">"</span>)</span>
<span id="cb15-75"><a href="#cb15-75" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Memory savings: </span><span class="sc">{</span>example_stats[<span class="st">'memory_savings'</span>]<span class="sc">:.1%}</span><span class="ss">"</span>)</span>
<span id="cb15-76"><a href="#cb15-76" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Compression ratio: </span><span class="sc">{</span>example_stats[<span class="st">'compression_ratio'</span>]<span class="sc">:.1f}</span><span class="ss">x"</span>)</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-stdout">
<pre><code>Memory Optimization Analysis:
Total parameters: 175,000,000
LoRA parameters: 1,750,000
Memory savings: 99.0%
Compression ratio: 100.0x</code></pre>
</div>
</div>
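<p>The 1% ratio above is an illustrative figure. For a single frozen weight matrix of shape d<sub>out</sub>×d<sub>in</sub>, LoRA trains two factors, B (d<sub>out</sub>×r) and A (r×d<sub>in</sub>), so each adapted layer adds r(d<sub>out</sub> + d<sub>in</sub>) parameters. The sketch below assumes a hypothetical 768×768 attention projection (a ViT-Base-like width chosen only for illustration) and shows how the per-layer fraction grows with rank. The helper names are ours, not part of any library.</p>

```python
def lora_param_count(d_out: int, d_in: int, rank: int) -> int:
    """Trainable parameters LoRA adds to one d_out x d_in linear layer: B (d_out x r) + A (r x d_in)."""
    return rank * (d_out + d_in)


def lora_fraction(d_out: int, d_in: int, rank: int) -> float:
    """LoRA parameters as a fraction of the frozen weight's d_out * d_in parameters."""
    return lora_param_count(d_out, d_in, rank) / (d_out * d_in)


# Hypothetical example: one 768 x 768 attention projection (589,824 frozen weights)
for r in (4, 8, 16, 32):
    print(f"rank={r:2d}: {lora_param_count(768, 768, r):,} LoRA params "
          f"({lora_fraction(768, 768, r):.2%} of the frozen weight)")
```

<p>Because the count is linear in r, doubling the rank doubles the adapter size. The whole-model percentage is lower than the per-layer one whenever only some layers are targeted, since the untouched layers add frozen parameters but no LoRA parameters.</p>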
</section>
<section id="training-optimizations" class="level3">
<h3 class="anchored" data-anchor-id="training-optimizations" id="training-optimizations">Training Optimizations</h3>
<div id="training-optimization" class="cell" data-execution_count="12">
<details class="code-fold">
<summary>Code</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb17"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb17-1"><a href="#cb17-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> OptimizedLoRATrainer:</span>
<span id="cb17-2"><a href="#cb17-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, model, config):</span>
<span id="cb17-3"><a href="#cb17-3" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> model</span>
<span id="cb17-4"><a href="#cb17-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.config <span class="op">=</span> config</span>
<span id="cb17-5"><a href="#cb17-5" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb17-6"><a href="#cb17-6" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Separate parameter groups</span></span>
<span id="cb17-7"><a href="#cb17-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.setup_parameter_groups()</span>
<span id="cb17-8"><a href="#cb17-8" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb17-9"><a href="#cb17-9" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Mixed precision training</span></span>
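<span id="cb17-10"><a href="#cb17-10" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> torch.cuda.is_available() <span class="kw">and</span> <span class="bu">getattr</span>(config, <span class="st">'mixed_precision'</span>, <span class="va">True</span>):</span>
<span id="cb17-11"><a href="#cb17-11" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.scaler <span class="op">=</span> torch.amp.GradScaler(<span class="st">"cuda"</span>)</span>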
<span id="cb17-12"><a href="#cb17-12" aria-hidden="true" tabindex="-1"></a>        <span class="cf">else</span>:</span>
<span id="cb17-13"><a href="#cb17-13" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.scaler <span class="op">=</span> <span class="va">None</span></span>
<span id="cb17-14"><a href="#cb17-14" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb17-15"><a href="#cb17-15" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> setup_parameter_groups(<span class="va">self</span>):</span>
<span id="cb17-16"><a href="#cb17-16" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Separate LoRA and non-LoRA parameters"""</span></span>
<span id="cb17-17"><a href="#cb17-17" aria-hidden="true" tabindex="-1"></a>        lora_params <span class="op">=</span> []</span>
<span id="cb17-18"><a href="#cb17-18" aria-hidden="true" tabindex="-1"></a>        other_params <span class="op">=</span> []</span>
<span id="cb17-19"><a href="#cb17-19" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb17-20"><a href="#cb17-20" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> name, param <span class="kw">in</span> <span class="va">self</span>.model.named_parameters():</span>
<span id="cb17-21"><a href="#cb17-21" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> param.requires_grad:</span>
<span id="cb17-22"><a href="#cb17-22" aria-hidden="true" tabindex="-1"></a>                <span class="cf">if</span> <span class="st">'lora'</span> <span class="kw">in</span> name:</span>
<span id="cb17-23"><a href="#cb17-23" aria-hidden="true" tabindex="-1"></a>                    lora_params.append(param)</span>
<span id="cb17-24"><a href="#cb17-24" aria-hidden="true" tabindex="-1"></a>                <span class="cf">else</span>:</span>
<span id="cb17-25"><a href="#cb17-25" aria-hidden="true" tabindex="-1"></a>                    other_params.append(param)</span>
<span id="cb17-26"><a href="#cb17-26" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb17-27"><a href="#cb17-27" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.param_groups <span class="op">=</span> [</span>
<span id="cb17-28"><a href="#cb17-28" aria-hidden="true" tabindex="-1"></a>            {</span>
<span id="cb17-29"><a href="#cb17-29" aria-hidden="true" tabindex="-1"></a>                <span class="st">'params'</span>: lora_params, </span>
<span id="cb17-30"><a href="#cb17-30" aria-hidden="true" tabindex="-1"></a>                <span class="st">'lr'</span>: <span class="bu">getattr</span>(<span class="va">self</span>.config, <span class="st">'lora_lr'</span>, <span class="fl">1e-4</span>), </span>
<span id="cb17-31"><a href="#cb17-31" aria-hidden="true" tabindex="-1"></a>                <span class="st">'weight_decay'</span>: <span class="fl">0.01</span>,</span>
<span id="cb17-32"><a href="#cb17-32" aria-hidden="true" tabindex="-1"></a>                <span class="st">'name'</span>: <span class="st">'lora_params'</span></span>
<span id="cb17-33"><a href="#cb17-33" aria-hidden="true" tabindex="-1"></a>            },</span>
<span id="cb17-34"><a href="#cb17-34" aria-hidden="true" tabindex="-1"></a>            {</span>
<span id="cb17-35"><a href="#cb17-35" aria-hidden="true" tabindex="-1"></a>                <span class="st">'params'</span>: other_params, </span>
<span id="cb17-36"><a href="#cb17-36" aria-hidden="true" tabindex="-1"></a>                <span class="st">'lr'</span>: <span class="bu">getattr</span>(<span class="va">self</span>.config, <span class="st">'base_lr'</span>, <span class="fl">1e-5</span>), </span>
<span id="cb17-37"><a href="#cb17-37" aria-hidden="true" tabindex="-1"></a>                <span class="st">'weight_decay'</span>: <span class="fl">0.1</span>,</span>
<span id="cb17-38"><a href="#cb17-38" aria-hidden="true" tabindex="-1"></a>                <span class="st">'name'</span>: <span class="st">'base_params'</span></span>
<span id="cb17-39"><a href="#cb17-39" aria-hidden="true" tabindex="-1"></a>            }</span>
<span id="cb17-40"><a href="#cb17-40" aria-hidden="true" tabindex="-1"></a>        ]</span>
<span id="cb17-41"><a href="#cb17-41" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb17-42"><a href="#cb17-42" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="st">"Parameter Groups Setup:"</span>)</span>
<span id="cb17-43"><a href="#cb17-43" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> group <span class="kw">in</span> <span class="va">self</span>.param_groups:</span>
<span id="cb17-44"><a href="#cb17-44" aria-hidden="true" tabindex="-1"></a>            param_count <span class="op">=</span> <span class="bu">sum</span>(p.numel() <span class="cf">for</span> p <span class="kw">in</span> group[<span class="st">'params'</span>])</span>
<span id="cb17-45"><a href="#cb17-45" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f"  </span><span class="sc">{</span>group[<span class="st">'name'</span>]<span class="sc">}</span><span class="ss">: </span><span class="sc">{</span>param_count<span class="sc">:,}</span><span class="ss"> parameters, lr=</span><span class="sc">{</span>group[<span class="st">'lr'</span>]<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb17-46"><a href="#cb17-46" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb17-47"><a href="#cb17-47" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> training_step(<span class="va">self</span>, batch, optimizer):</span>
<span id="cb17-48"><a href="#cb17-48" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Optimized training step with mixed precision"""</span></span>
<span id="cb17-49"><a href="#cb17-49" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="va">self</span>.scaler <span class="kw">is</span> <span class="kw">not</span> <span class="va">None</span>:</span>
<span id="cb17-50"><a href="#cb17-50" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Mixed precision training</span></span>
<span id="cb17-51"><a href="#cb17-51" aria-hidden="true" tabindex="-1"></a>            <span class="cf">with</span> torch.autocast(device_type<span class="op">=</span><span class="st">"cuda"</span>):</span>
<span id="cb17-52"><a href="#cb17-52" aria-hidden="true" tabindex="-1"></a>                outputs <span class="op">=</span> <span class="va">self</span>.model(<span class="op">**</span>batch)</span>
<span id="cb17-53"><a href="#cb17-53" aria-hidden="true" tabindex="-1"></a>                loss <span class="op">=</span> outputs.loss <span class="cf">if</span> <span class="bu">hasattr</span>(outputs, <span class="st">'loss'</span>) <span class="cf">else</span> outputs</span>
<span id="cb17-54"><a href="#cb17-54" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb17-55"><a href="#cb17-55" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Scaled backward pass</span></span>
<span id="cb17-56"><a href="#cb17-56" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.scaler.scale(loss).backward()</span>
<span id="cb17-57"><a href="#cb17-57" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb17-58"><a href="#cb17-58" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Gradient clipping for LoRA parameters only</span></span>
<span id="cb17-59"><a href="#cb17-59" aria-hidden="true" tabindex="-1"></a>            lora_params <span class="op">=</span> [p <span class="cf">for</span> group <span class="kw">in</span> <span class="va">self</span>.param_groups </span>
<span id="cb17-60"><a href="#cb17-60" aria-hidden="true" tabindex="-1"></a>                          <span class="cf">for</span> p <span class="kw">in</span> group[<span class="st">'params'</span>] <span class="cf">if</span> group[<span class="st">'name'</span>] <span class="op">==</span> <span class="st">'lora_params'</span>]</span>
<span id="cb17-61"><a href="#cb17-61" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb17-62"><a href="#cb17-62" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.scaler.unscale_(optimizer)</span>
<span id="cb17-63"><a href="#cb17-63" aria-hidden="true" tabindex="-1"></a>            torch.nn.utils.clip_grad_norm_(lora_params, max_norm<span class="op">=</span><span class="fl">1.0</span>)</span>
<span id="cb17-64"><a href="#cb17-64" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb17-65"><a href="#cb17-65" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.scaler.step(optimizer)</span>
<span id="cb17-66"><a href="#cb17-66" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.scaler.update()</span>
<span id="cb17-67"><a href="#cb17-67" aria-hidden="true" tabindex="-1"></a>        <span class="cf">else</span>:</span>
<span id="cb17-68"><a href="#cb17-68" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Regular training</span></span>
<span id="cb17-69"><a href="#cb17-69" aria-hidden="true" tabindex="-1"></a>            outputs <span class="op">=</span> <span class="va">self</span>.model(<span class="op">**</span>batch)</span>
<span id="cb17-70"><a href="#cb17-70" aria-hidden="true" tabindex="-1"></a>            loss <span class="op">=</span> outputs.loss <span class="cf">if</span> <span class="bu">hasattr</span>(outputs, <span class="st">'loss'</span>) <span class="cf">else</span> outputs</span>
<span id="cb17-71"><a href="#cb17-71" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb17-72"><a href="#cb17-72" aria-hidden="true" tabindex="-1"></a>            loss.backward()</span>
<span id="cb17-73"><a href="#cb17-73" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb17-74"><a href="#cb17-74" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Gradient clipping</span></span>
<span id="cb17-75"><a href="#cb17-75" aria-hidden="true" tabindex="-1"></a>            lora_params <span class="op">=</span> [p <span class="cf">for</span> group <span class="kw">in</span> <span class="va">self</span>.param_groups </span>
<span id="cb17-76"><a href="#cb17-76" aria-hidden="true" tabindex="-1"></a>                          <span class="cf">for</span> p <span class="kw">in</span> group[<span class="st">'params'</span>] <span class="cf">if</span> group[<span class="st">'name'</span>] <span class="op">==</span> <span class="st">'lora_params'</span>]</span>
<span id="cb17-77"><a href="#cb17-77" aria-hidden="true" tabindex="-1"></a>            torch.nn.utils.clip_grad_norm_(lora_params, max_norm<span class="op">=</span><span class="fl">1.0</span>)</span>
<span id="cb17-78"><a href="#cb17-78" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb17-79"><a href="#cb17-79" aria-hidden="true" tabindex="-1"></a>            optimizer.step()</span>
<span id="cb17-80"><a href="#cb17-80" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb17-81"><a href="#cb17-81" aria-hidden="true" tabindex="-1"></a>        optimizer.zero_grad()</span>
<span id="cb17-82"><a href="#cb17-82" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> loss.item() <span class="cf">if</span> <span class="bu">hasattr</span>(loss, <span class="st">'item'</span>) <span class="cf">else</span> loss</span>
<span id="cb17-83"><a href="#cb17-83" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-84"><a href="#cb17-84" aria-hidden="true" tabindex="-1"></a><span class="co"># Example configuration</span></span>
<span id="cb17-85"><a href="#cb17-85" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> TrainingConfig:</span>
<span id="cb17-86"><a href="#cb17-86" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>):</span>
<span id="cb17-87"><a href="#cb17-87" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.lora_lr <span class="op">=</span> <span class="fl">1e-4</span></span>
<span id="cb17-88"><a href="#cb17-88" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.base_lr <span class="op">=</span> <span class="fl">1e-5</span></span>
<span id="cb17-89"><a href="#cb17-89" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.mixed_precision <span class="op">=</span> <span class="va">True</span></span>
<span id="cb17-90"><a href="#cb17-90" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-91"><a href="#cb17-91" aria-hidden="true" tabindex="-1"></a>config <span class="op">=</span> TrainingConfig()</span>
<span id="cb17-92"><a href="#cb17-92" aria-hidden="true" tabindex="-1"></a><span class="co"># trainer = OptimizedLoRATrainer(model, config)  # Would use real model</span></span></code></pre></div></div>
</details>
</div>
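<p>To see the parameter-group logic in action without a GPU, here is a minimal CPU-only sketch. It assumes a hypothetical <code>LoRALinear</code> layer (written here for illustration, not taken from a library) whose trainable factors are named <code>lora_A</code> and <code>lora_B</code>, so the <code>'lora' in name</code> check used by <code>setup_parameter_groups</code> picks them up, while a plain classification head lands in the second group. It runs one ordinary training step with LoRA-only gradient clipping, mirroring the non-AMP branch of <code>training_step</code>.</p>

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LoRALinear(nn.Module):
    """Hypothetical minimal LoRA layer: frozen base Linear plus a trainable low-rank update."""

    def __init__(self, d_in: int, d_out: int, rank: int = 4, alpha: float = 4.0):
        super().__init__()
        self.base = nn.Linear(d_in, d_out)
        self.base.requires_grad_(False)                              # freeze the "pretrained" weights
        self.lora_A = nn.Parameter(torch.randn(rank, d_in) * 0.01)  # r x d_in
        self.lora_B = nn.Parameter(torch.zeros(d_out, rank))        # zero init: update B @ A starts at 0
        self.scaling = alpha / rank

    def forward(self, x):
        return self.base(x) + (x @ self.lora_A.T @ self.lora_B.T) * self.scaling


torch.manual_seed(0)
# LoRA-adapted layer followed by a fully trainable classification head
model = nn.Sequential(LoRALinear(16, 16, rank=4), nn.ReLU(), nn.Linear(16, 2))

# Same split as setup_parameter_groups: trainable LoRA params vs. other trainable params
lora_params = [p for n, p in model.named_parameters() if p.requires_grad and "lora" in n]
other_params = [p for n, p in model.named_parameters() if p.requires_grad and "lora" not in n]

optimizer = torch.optim.AdamW([
    {"params": lora_params, "lr": 1e-4, "weight_decay": 0.01, "name": "lora_params"},
    {"params": other_params, "lr": 1e-5, "weight_decay": 0.1, "name": "base_params"},
])

# One plain (non-AMP) training step on random data
x, y = torch.randn(8, 16), torch.randint(0, 2, (8,))
loss = F.cross_entropy(model(x), y)
loss.backward()
torch.nn.utils.clip_grad_norm_(lora_params, max_norm=1.0)  # clip LoRA gradients only
optimizer.step()
optimizer.zero_grad()

for group in optimizer.param_groups:
    print(group["name"], sum(p.numel() for p in group["params"]), group["lr"])
```

<p>With this toy model the LoRA group holds 128 parameters (4×16 + 16×4) and the head group 34 (16×2 + 2); the frozen base weights never reach the optimizer because they fail the <code>requires_grad</code> check.</p>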
</section>
</section>
<section id="use-cases-and-applications" class="level2">
<h2 class="anchored" data-anchor-id="use-cases-and-applications" id="use-cases-and-applications">Use Cases and Applications</h2>
<section id="domain-adaptation" class="level3">
<h3 class="anchored" data-anchor-id="domain-adaptation" id="domain-adaptation">1. Domain Adaptation</h3>
<div class="tabset-margin-container"></div><div class="panel-tabset">
<ul class="nav nav-tabs" role="tablist"><li class="nav-item" role="presentation"><a class="nav-link active" id="tabset-4-1-tab" data-bs-toggle="tab" data-bs-target="#tabset-4-1" role="tab" aria-controls="tabset-4-1" aria-selected="true" href="">Medical Imaging</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-4-2-tab" data-bs-toggle="tab" data-bs-target="#tabset-4-2" role="tab" aria-controls="tabset-4-2" aria-selected="false" href="">Satellite Imagery</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-4-3-tab" data-bs-toggle="tab" data-bs-target="#tabset-4-3" role="tab" aria-controls="tabset-4-3" aria-selected="false" href="">Autonomous Driving</a></li></ul>
<div class="tab-content">
<div id="tabset-4-1" class="tab-pane active" role="tabpanel" aria-labelledby="tabset-4-1-tab">
<div class="callout callout-style-default callout-note no-icon callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon no-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Configuration Overview
</div>
</div>
<div class="callout-body-container callout-body">
<p><strong>Optimized for medical image analysis</strong></p>
<p><strong>Rank:</strong> 32 | <strong>Alpha:</strong> 32<br>
<strong>Target modules:</strong> q_proj, v_proj, fc1, fc2</p>
</div>
</div>
<div class="tabset-margin-container"></div><div class="panel-tabset" data-group="medical">
<ul class="nav nav-tabs" role="tablist"><li class="nav-item" role="presentation"><a class="nav-link active" id="tabset-1-1-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-1" role="tab" aria-controls="tabset-1-1" aria-selected="true" href="">Key Features</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-1-2-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-2" role="tab" aria-controls="tabset-1-2" aria-selected="false" href="">Technical Details</a></li></ul>
<div class="tab-content" data-group="medical">
<div id="tabset-1-1" class="tab-pane active" role="tabpanel" aria-labelledby="tabset-1-1-tab">
<div class="grid">
<section id="higher-rank" class="level4 g-col-4">
<h4 class="anchored" data-anchor-id="higher-rank">Higher Rank</h4>
<p>A higher rank gives the adapter more capacity to capture the complex, fine-grained patterns found in medical images</p>
</section>
<section id="attention-focus" class="level4 g-col-4">
<h4 class="anchored" data-anchor-id="attention-focus">Attention Focus</h4>
<p>Adapters on both attention (<code>q_proj</code>, <code>v_proj</code>) and MLP (<code>fc1</code>, <code>fc2</code>) layers for medical feature detection</p>
</section>
<section id="enhanced-extraction" class="level4 g-col-4">
<h4 class="anchored" data-anchor-id="enhanced-extraction">Enhanced Extraction</h4>
<p>Adapting the feed-forward layers as well as attention lets the model refine its feature representations for diagnostic imaging</p>
</section>
</div>
</div>
<div id="tabset-1-2" class="tab-pane" role="tabpanel" aria-labelledby="tabset-1-2-tab">
<table class="caption-top table">
<colgroup>
<col style="width: 39%">
<col style="width: 28%">
<col style="width: 32%">
</colgroup>
<thead>
<tr class="header">
<th>Parameter</th>
<th>Value</th>
<th>Purpose</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>Rank</td>
<td>32</td>
<td>Handle complex medical pattern recognition</td>
</tr>
<tr class="even">
<td>Alpha</td>
<td>32</td>
<td>Scaling factor α/r = 1.0 (alpha scales the LoRA update; it is not a learning rate)</td>
</tr>
<tr class="odd">
<td>Modules</td>
<td>q_proj, v_proj, fc1, fc2</td>
<td>Focus on attention and feed-forward layers</td>
</tr>
</tbody>
</table>
</div>
</div>
</div>
</div>
<div id="tabset-4-2" class="tab-pane" role="tabpanel" aria-labelledby="tabset-4-2-tab">
<div class="callout callout-style-default callout-note no-icon callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon no-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Configuration Overview
</div>
</div>
<div class="callout-body-container callout-body">
<p><strong>Adapted for satellite and aerial imagery</strong></p>
<p><strong>Rank:</strong> 16 | <strong>Alpha:</strong> 16<br>
<strong>Target modules:</strong> qkv, proj</p>
</div>
</div>
<div class="tabset-margin-container"></div><div class="panel-tabset" data-group="satellite">
<ul class="nav nav-tabs" role="tablist"><li class="nav-item" role="presentation"><a class="nav-link active" id="tabset-2-1-tab" data-bs-toggle="tab" data-bs-target="#tabset-2-1" role="tab" aria-controls="tabset-2-1" aria-selected="true" href="">Key Features</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-2-2-tab" data-bs-toggle="tab" data-bs-target="#tabset-2-2" role="tab" aria-controls="tabset-2-2" aria-selected="false" href="">Technical Details</a></li></ul>
<div class="tab-content" data-group="satellite">
<div id="tabset-2-1" class="tab-pane active" role="tabpanel" aria-labelledby="tabset-2-1-tab">
<div class="grid">
<section id="balanced-efficiency" class="level4 g-col-4">
<h4 class="anchored" data-anchor-id="balanced-efficiency">Balanced Efficiency</h4>
<p>A moderate rank keeps the adapter small and cheap to train while retaining enough capacity for the task</p>
</section>
<section id="vision-focused" class="level4 g-col-4">
<h4 class="anchored" data-anchor-id="vision-focused">Vision-Focused</h4>
<p>Targets the fused <code>qkv</code> and output <code>proj</code> layers used in ViT-style attention blocks</p>
</section>
<section id="spatial-modeling" class="level4 g-col-4">
<h4 class="anchored" data-anchor-id="spatial-modeling">Spatial Modeling</h4>
<p>Adapting attention lets the model re-weight relationships between image patches, which helps with the spatial layouts in geographic data</p>
</section>
</div>
</div>
<div id="tabset-2-2" class="tab-pane" role="tabpanel" aria-labelledby="tabset-2-2-tab">
<table class="caption-top table">
<thead>
<tr class="header">
<th>Parameter</th>
<th>Value</th>
<th>Purpose</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>Rank</td>
<td>16</td>
<td>Balance between performance and efficiency</td>
</tr>
<tr class="even">
<td>Alpha</td>
<td>16</td>
<td>Scaling factor α/r = 1.0</td>
</tr>
<tr class="odd">
<td>Modules</td>
<td>qkv, proj</td>
<td>Fused query/key/value projection and attention output projection</td>
</tr>
</tbody>
</table>
</div>
</div>
</div>
</div>
<div id="tabset-4-3" class="tab-pane" role="tabpanel" aria-labelledby="tabset-4-3-tab">
<div class="callout callout-style-default callout-note no-icon callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon no-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Configuration Overview
</div>
</div>
<div class="callout-body-container callout-body">
<p><strong>Designed for autonomous vehicle perception</strong></p>
<p><strong>Rank:</strong> 24 | <strong>Alpha:</strong> 24<br>
<strong>Target modules:</strong> q_proj, k_proj, v_proj, dense</p>
</div>
</div>
<div class="tabset-margin-container"></div><div class="panel-tabset" data-group="driving">
<ul class="nav nav-tabs" role="tablist"><li class="nav-item" role="presentation"><a class="nav-link active" id="tabset-3-1-tab" data-bs-toggle="tab" data-bs-target="#tabset-3-1" role="tab" aria-controls="tabset-3-1" aria-selected="true" href="">Key Features</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-3-2-tab" data-bs-toggle="tab" data-bs-target="#tabset-3-2" role="tab" aria-controls="tabset-3-2" aria-selected="false" href="">Technical Details</a></li></ul>
<div class="tab-content" data-group="driving">
<div id="tabset-3-1" class="tab-pane active" role="tabpanel" aria-labelledby="tabset-3-1-tab">
<div class="grid">
<section id="real-time-performance" class="level4 g-col-4">
<h4 class="anchored" data-anchor-id="real-time-performance">Real-Time Performance</h4>
<p>LoRA weights can be merged into the base weights before deployment, so the adapter adds no extra inference latency in vehicle systems</p>
</section>
<section id="multi-object-detection" class="level4 g-col-4">
<h4 class="anchored" data-anchor-id="multi-object-detection">Multi-Object Detection</h4>
<p>Specialized for detecting and tracking multiple objects simultaneously</p>
</section>
<section id="safety-critical" class="level4 g-col-4">
<h4 class="anchored" data-anchor-id="safety-critical">Safety-Critical</h4>
<p>Designed for safety-critical applications with high reliability standards</p>
</section>
</div>
</div>
<div id="tabset-3-2" class="tab-pane" role="tabpanel" aria-labelledby="tabset-3-2-tab">
<table class="caption-top table">
<colgroup>
<col style="width: 39%">
<col style="width: 28%">
<col style="width: 32%">
</colgroup>
<thead>
<tr class="header">
<th>Parameter</th>
<th>Value</th>
<th>Purpose</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>Rank</td>
<td>24</td>
<td>Higher capacity for safety-critical perception tasks</td>
</tr>
<tr class="even">
<td>Alpha</td>
<td>24</td>
<td>Scaling factor α/r = 1.0</td>
</tr>
<tr class="odd">
<td>Modules</td>
<td>q_proj, k_proj, v_proj, dense</td>
<td>Comprehensive attention and dense layer targeting</td>
</tr>
</tbody>
</table>
</div>
</div>
</div>
</div>
</div>
</div>
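<p>The real-time point for autonomous driving rests on a standard LoRA property: after training, the low-rank update can be folded into the frozen weight, W' = W + (α/r)·BA, so deployment needs a single dense matmul per layer. The minimal sketch below uses random tensors as stand-ins for trained weights (rank and alpha come from the driving preset; the 64×64 layer size is an arbitrary assumption) and checks that the merged layer reproduces the two-path output.</p>

```python
import torch

torch.manual_seed(0)
d_in, d_out, rank, alpha = 64, 64, 24, 24   # rank/alpha from the autonomous-driving preset
scaling = alpha / rank

W = torch.randn(d_out, d_in)                # frozen base weight
A = torch.randn(rank, d_in) * 0.01          # random stand-ins for trained LoRA factors
B = torch.randn(d_out, rank) * 0.01

x = torch.randn(8, d_in)
y_adapter = x @ W.T + (x @ A.T @ B.T) * scaling   # training-time forward: base path + LoRA path
W_merged = W + scaling * (B @ A)                   # fold the update into one dense weight
y_merged = x @ W_merged.T                          # inference-time forward: single matmul

print(torch.allclose(y_adapter, y_merged, atol=1e-4))  # True
```

<p>Merging removes the adapter's extra latency but fixes the layer to a single adapter; keeping A and B separate is what allows swapping per-domain adapters on one shared base model.</p>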
<section id="summary-comparison" class="level4">
<h4 class="anchored" data-anchor-id="summary-comparison">Summary Comparison</h4>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Quick Reference Table
</div>
</div>
<div class="callout-body-container callout-body">
<table class="caption-top table">
<colgroup>
<col style="width: 18%">
<col style="width: 11%">
<col style="width: 12%">
<col style="width: 27%">
<col style="width: 29%">
</colgroup>
<thead>
<tr class="header">
<th>Use Case</th>
<th>Rank</th>
<th>Alpha</th>
<th>Primary Focus</th>
<th>Target Modules</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>Medical Imaging</td>
<td>32</td>
<td>32</td>
<td>Complex pattern recognition</td>
<td>q_proj, v_proj, fc1, fc2</td>
</tr>
<tr class="even">
<td>Satellite Imagery</td>
<td>16</td>
<td>16</td>
<td>Efficient spatial analysis</td>
<td>qkv, proj</td>
</tr>
<tr class="odd">
<td>Autonomous Driving</td>
<td>24</td>
<td>24</td>
<td>Real-time multi-object detection</td>
<td>q_proj, k_proj, v_proj, dense</td>
</tr>
</tbody>
</table>
</div>
</div>
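<p>The target module names in these presets only take effect if they match the layer names of the backbone you load. As an illustration, the sketch below uses a simplified suffix match (similar in spirit to how libraries such as Hugging Face PEFT resolve <code>target_modules</code>, but written from scratch here) over two hypothetical name lists: timm-style ViT blocks with a fused <code>qkv</code> projection, and a split-projection layout with <code>q_proj</code>/<code>k_proj</code>/<code>v_proj</code>. Applying the Medical Imaging targets to the ViT-style names silently drops the attention layers.</p>

```python
# Hypothetical module names for one block of two backbone styles (illustrative, not exhaustive)
FUSED_QKV_NAMES = ["blocks.0.attn.qkv", "blocks.0.attn.proj", "blocks.0.mlp.fc1", "blocks.0.mlp.fc2"]
SPLIT_QKV_NAMES = ["layers.0.self_attn.q_proj", "layers.0.self_attn.k_proj",
                   "layers.0.self_attn.v_proj", "layers.0.mlp.fc1", "layers.0.mlp.fc2"]


def match_targets(module_names, target_modules):
    """Return names whose last dotted component equals one of the targets (simplified suffix match)."""
    targets = set(target_modules)
    return [name for name in module_names if name.rsplit(".", 1)[-1] in targets]


medical_targets = ["q_proj", "v_proj", "fc1", "fc2"]
print(match_targets(SPLIT_QKV_NAMES, medical_targets))  # attention + MLP layers matched
print(match_targets(FUSED_QKV_NAMES, medical_targets))  # only the MLP layers match
```

<p>Printing the matched names (or the trainable-parameter count) before training is a cheap way to confirm the adapter lands where the configuration intends.</p>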
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Configuration Guidelines
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li><strong>Higher ranks</strong> (24-32) for complex, safety-critical applications</li>
<li><strong>Moderate ranks</strong> (16-20) for balanced efficiency and performance</li>
<li><strong>Lower ranks</strong> (4-12) for lightweight adapters with small memory and storage footprints</li>
</ul>
</div>
</div>
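<p>Every preset above sets alpha equal to the rank. In the standard LoRA formulation the update BA is multiplied by α/r, so these presets all use a scale of 1.0, and alpha acts as a multiplier on the update rather than as a learning rate. A minimal sketch, with the presets transcribed into a hypothetical dictionary:</p>

```python
# Hypothetical presets transcribed from the Quick Reference Table
DOMAIN_PRESETS = {
    "medical_imaging": {"rank": 32, "alpha": 32, "target_modules": ["q_proj", "v_proj", "fc1", "fc2"]},
    "satellite_imagery": {"rank": 16, "alpha": 16, "target_modules": ["qkv", "proj"]},
    "autonomous_driving": {"rank": 24, "alpha": 24, "target_modules": ["q_proj", "k_proj", "v_proj", "dense"]},
}


def lora_scaling(alpha: float, rank: int) -> float:
    """Standard LoRA scaling: the low-rank update B @ A is multiplied by alpha / rank."""
    return alpha / rank


for name, cfg in DOMAIN_PRESETS.items():
    scale = lora_scaling(cfg["alpha"], cfg["rank"])
    print(f"{name}: r={cfg['rank']}, alpha={cfg['alpha']}, scale={scale:.2f}")

# Holding alpha fixed while raising the rank shrinks the scale applied to the update
print([lora_scaling(16, r) for r in (4, 8, 16, 32)])  # [4.0, 2.0, 1.0, 0.5]
```

<p>If you carry these values into Hugging Face PEFT, they correspond to <code>LoraConfig(r=..., lora_alpha=..., target_modules=[...])</code>. Tying alpha to the rank keeps the update scale constant when you sweep ranks, so a rank change is not confounded with a change in effective step size.</p>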
</section>
</section>
<section id="multi-lingual-vision-language" class="level3">
<h3 class="anchored" data-anchor-id="multi-lingual-vision-language" id="multi-lingual-vision-language">2. Multi-lingual Vision-Language</h3>
<div id="multilingual-lora" class="cell" data-execution_count="13">
<details class="code-fold">
<summary>Code</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb18"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb18-1"><a href="#cb18-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MultilingualLoRA:</span>
<span id="cb18-2"><a href="#cb18-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, base_model, languages):</span>
<span id="cb18-3"><a href="#cb18-3" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.base_model <span class="op">=</span> base_model</span>
<span id="cb18-4"><a href="#cb18-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.languages <span class="op">=</span> languages</span>
<span id="cb18-5"><a href="#cb18-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.language_adapters <span class="op">=</span> {}</span>
<span id="cb18-6"><a href="#cb18-6" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-7"><a href="#cb18-7" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> lang <span class="kw">in</span> languages:</span>
<span id="cb18-8"><a href="#cb18-8" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.language_adapters[lang] <span class="op">=</span> <span class="va">self</span>.create_language_adapter(lang)</span>
<span id="cb18-9"><a href="#cb18-9" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-10"><a href="#cb18-10" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> create_language_adapter(<span class="va">self</span>, language):</span>
<span id="cb18-11"><a href="#cb18-11" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Create language-specific LoRA adapter"""</span></span>
<span id="cb18-12"><a href="#cb18-12" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Language-specific configurations</span></span>
<span id="cb18-13"><a href="#cb18-13" aria-hidden="true" tabindex="-1"></a>        lang_configs <span class="op">=</span> {</span>
<span id="cb18-14"><a href="#cb18-14" aria-hidden="true" tabindex="-1"></a>            <span class="st">"english"</span>: {<span class="st">"rank"</span>: <span class="dv">16</span>, <span class="st">"alpha"</span>: <span class="dv">16</span>},</span>
<span id="cb18-15"><a href="#cb18-15" aria-hidden="true" tabindex="-1"></a>            <span class="st">"chinese"</span>: {<span class="st">"rank"</span>: <span class="dv">20</span>, <span class="st">"alpha"</span>: <span class="dv">20</span>},  <span class="co"># More complex script</span></span>
<span id="cb18-16"><a href="#cb18-16" aria-hidden="true" tabindex="-1"></a>            <span class="st">"arabic"</span>: {<span class="st">"rank"</span>: <span class="dv">18</span>, <span class="st">"alpha"</span>: <span class="dv">18</span>},   <span class="co"># RTL language</span></span>
<span id="cb18-17"><a href="#cb18-17" aria-hidden="true" tabindex="-1"></a>            <span class="st">"hindi"</span>: {<span class="st">"rank"</span>: <span class="dv">22</span>, <span class="st">"alpha"</span>: <span class="dv">22</span>},    <span class="co"># Complex script</span></span>
<span id="cb18-18"><a href="#cb18-18" aria-hidden="true" tabindex="-1"></a>            <span class="st">"spanish"</span>: {<span class="st">"rank"</span>: <span class="dv">14</span>, <span class="st">"alpha"</span>: <span class="dv">14</span>},  <span class="co"># Similar to English</span></span>
<span id="cb18-19"><a href="#cb18-19" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb18-20"><a href="#cb18-20" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-21"><a href="#cb18-21" aria-hidden="true" tabindex="-1"></a>        config <span class="op">=</span> lang_configs.get(language, {<span class="st">"rank"</span>: <span class="dv">16</span>, <span class="st">"alpha"</span>: <span class="dv">16</span>})</span>
<span id="cb18-22"><a href="#cb18-22" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-23"><a href="#cb18-23" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> LoRAConfig(</span>
<span id="cb18-24"><a href="#cb18-24" aria-hidden="true" tabindex="-1"></a>            rank<span class="op">=</span>config[<span class="st">"rank"</span>],</span>
<span id="cb18-25"><a href="#cb18-25" aria-hidden="true" tabindex="-1"></a>            alpha<span class="op">=</span>config[<span class="st">"alpha"</span>],</span>
<span id="cb18-26"><a href="#cb18-26" aria-hidden="true" tabindex="-1"></a>            target_modules<span class="op">=</span>[<span class="st">"q_proj"</span>, <span class="st">"k_proj"</span>, <span class="st">"v_proj"</span>],</span>
<span id="cb18-27"><a href="#cb18-27" aria-hidden="true" tabindex="-1"></a>            task_type<span class="op">=</span><span class="ss">f"vlm_</span><span class="sc">{</span>language<span class="sc">}</span><span class="ss">"</span></span>
<span id="cb18-28"><a href="#cb18-28" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb18-29"><a href="#cb18-29" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-30"><a href="#cb18-30" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> get_adapter_stats(<span class="va">self</span>):</span>
<span id="cb18-31"><a href="#cb18-31" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Get statistics about language adapters"""</span></span>
<span id="cb18-32"><a href="#cb18-32" aria-hidden="true" tabindex="-1"></a>        stats <span class="op">=</span> {}</span>
<span id="cb18-33"><a href="#cb18-33" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-34"><a href="#cb18-34" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> lang, adapter <span class="kw">in</span> <span class="va">self</span>.language_adapters.items():</span>
<span id="cb18-35"><a href="#cb18-35" aria-hidden="true" tabindex="-1"></a>            stats[lang] <span class="op">=</span> {</span>
<span id="cb18-36"><a href="#cb18-36" aria-hidden="true" tabindex="-1"></a>                <span class="st">"rank"</span>: adapter.rank,</span>
<span id="cb18-37"><a href="#cb18-37" aria-hidden="true" tabindex="-1"></a>                <span class="st">"alpha"</span>: adapter.alpha,</span>
<span id="cb18-38"><a href="#cb18-38" aria-hidden="true" tabindex="-1"></a>                <span class="st">"parameters"</span>: adapter.rank <span class="op">*</span> <span class="dv">768</span> <span class="op">*</span> <span class="dv">2</span>,  <span class="co"># Approximate: one 768x768 module in one layer</span></span>
<span id="cb18-39"><a href="#cb18-39" aria-hidden="true" tabindex="-1"></a>                <span class="st">"target_modules"</span>: <span class="bu">len</span>(adapter.target_modules)</span>
<span id="cb18-40"><a href="#cb18-40" aria-hidden="true" tabindex="-1"></a>            }</span>
<span id="cb18-41"><a href="#cb18-41" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-42"><a href="#cb18-42" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> stats</span>
<span id="cb18-43"><a href="#cb18-43" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-44"><a href="#cb18-44" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, images, texts, language):</span>
<span id="cb18-45"><a href="#cb18-45" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Forward pass with language-specific adapter"""</span></span>
<span id="cb18-46"><a href="#cb18-46" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> language <span class="kw">not</span> <span class="kw">in</span> <span class="va">self</span>.language_adapters:</span>
<span id="cb18-47"><a href="#cb18-47" aria-hidden="true" tabindex="-1"></a>            <span class="cf">raise</span> <span class="pp">ValueError</span>(<span class="ss">f"Language '</span><span class="sc">{</span>language<span class="sc">}</span><span class="ss">' not supported"</span>)</span>
<span id="cb18-48"><a href="#cb18-48" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-49"><a href="#cb18-49" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Would activate language-specific adapter</span></span>
<span id="cb18-50"><a href="#cb18-50" aria-hidden="true" tabindex="-1"></a>        adapter_config <span class="op">=</span> <span class="va">self</span>.language_adapters[language]</span>
<span id="cb18-51"><a href="#cb18-51" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-52"><a href="#cb18-52" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Return placeholder for demonstration</span></span>
<span id="cb18-53"><a href="#cb18-53" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {</span>
<span id="cb18-54"><a href="#cb18-54" aria-hidden="true" tabindex="-1"></a>            <span class="st">"language"</span>: language,</span>
<span id="cb18-55"><a href="#cb18-55" aria-hidden="true" tabindex="-1"></a>            <span class="st">"adapter_config"</span>: adapter_config,</span>
<span id="cb18-56"><a href="#cb18-56" aria-hidden="true" tabindex="-1"></a>            <span class="st">"message"</span>: <span class="ss">f"Processing with </span><span class="sc">{</span>language<span class="sc">}</span><span class="ss"> adapter"</span></span>
<span id="cb18-57"><a href="#cb18-57" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb18-58"><a href="#cb18-58" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-59"><a href="#cb18-59" aria-hidden="true" tabindex="-1"></a><span class="co"># Demonstration</span></span>
<span id="cb18-60"><a href="#cb18-60" aria-hidden="true" tabindex="-1"></a>languages <span class="op">=</span> [<span class="st">"english"</span>, <span class="st">"chinese"</span>, <span class="st">"arabic"</span>, <span class="st">"hindi"</span>, <span class="st">"spanish"</span>]</span>
<span id="cb18-61"><a href="#cb18-61" aria-hidden="true" tabindex="-1"></a>multilingual_model <span class="op">=</span> MultilingualLoRA(<span class="va">None</span>, languages)</span>
<span id="cb18-62"><a href="#cb18-62" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-63"><a href="#cb18-63" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"Multilingual LoRA Configuration:"</span>)</span>
<span id="cb18-64"><a href="#cb18-64" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"="</span> <span class="op">*</span> <span class="dv">40</span>)</span>
<span id="cb18-65"><a href="#cb18-65" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-66"><a href="#cb18-66" aria-hidden="true" tabindex="-1"></a>adapter_stats <span class="op">=</span> multilingual_model.get_adapter_stats()</span>
<span id="cb18-67"><a href="#cb18-67" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> lang, stats <span class="kw">in</span> adapter_stats.items():</span>
<span id="cb18-68"><a href="#cb18-68" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"</span><span class="ch">\n</span><span class="sc">{</span>lang<span class="sc">.</span>title()<span class="sc">}</span><span class="ss">:"</span>)</span>
<span id="cb18-69"><a href="#cb18-69" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"  Rank: </span><span class="sc">{</span>stats[<span class="st">'rank'</span>]<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb18-70"><a href="#cb18-70" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"  Alpha: </span><span class="sc">{</span>stats[<span class="st">'alpha'</span>]<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb18-71"><a href="#cb18-71" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"  Parameters: ~</span><span class="sc">{</span>stats[<span class="st">'parameters'</span>]<span class="sc">:,}</span><span class="ss">"</span>)</span>
<span id="cb18-72"><a href="#cb18-72" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"  Target modules: </span><span class="sc">{</span>stats[<span class="st">'target_modules'</span>]<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb18-73"><a href="#cb18-73" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-74"><a href="#cb18-74" aria-hidden="true" tabindex="-1"></a><span class="co"># Example usage</span></span>
<span id="cb18-75"><a href="#cb18-75" aria-hidden="true" tabindex="-1"></a>result <span class="op">=</span> multilingual_model.forward(<span class="va">None</span>, <span class="va">None</span>, <span class="st">"chinese"</span>)</span>
<span id="cb18-76"><a href="#cb18-76" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"</span><span class="ch">\n</span><span class="ss">Example usage: </span><span class="sc">{</span>result[<span class="st">'message'</span>]<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-stdout">
<pre><code>Multilingual LoRA Configuration:
========================================

English:
  Rank: 16
  Alpha: 16
  Parameters: ~24,576
  Target modules: 3

Chinese:
  Rank: 20
  Alpha: 20
  Parameters: ~30,720
  Target modules: 3

Arabic:
  Rank: 18
  Alpha: 18
  Parameters: ~27,648
  Target modules: 3

Hindi:
  Rank: 22
  Alpha: 22
  Parameters: ~33,792
  Target modules: 3

Spanish:
  Rank: 14
  Alpha: 14
  Parameters: ~21,504
  Target modules: 3

Example usage: Processing with chinese adapter</code></pre>
</div>
</div>
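<p>The <code>parameters</code> value printed above is a rough estimate. <code>rank * 768 * 2</code> counts the <em>A</em> (rank × 768) and <em>B</em> (768 × rank) matrices of a single 768 × 768 projection in a single layer, so it leaves out the three target modules and the model's depth. The sketch below spells out the fuller count. It assumes square 768-dimensional projections and a hypothetical 12-layer encoder; substitute your model's real shapes and depth.</p>

```python
def lora_params_per_module(rank, d_in, d_out):
    """Trainable parameters for one LoRA-wrapped linear layer.

    A has shape (rank, d_in) and B has shape (d_out, rank).
    """
    return rank * d_in + d_out * rank


def total_lora_params(rank, num_modules, num_layers, d_model=768):
    # Assumption: every targeted projection is square (d_model x d_model),
    # as is typical for q/k/v projections in standard multi-head attention.
    per_module = lora_params_per_module(rank, d_model, d_model)
    return per_module * num_modules * num_layers


# Ranks from the multilingual example; 12 layers is a hypothetical depth.
ranks = {"english": 16, "chinese": 20, "arabic": 18, "hindi": 22, "spanish": 14}
NUM_LAYERS = 12
NUM_MODULES = 3  # q_proj, k_proj, v_proj

for lang, rank in ranks.items():
    one_module = lora_params_per_module(rank, 768, 768)
    total = total_lora_params(rank, NUM_MODULES, NUM_LAYERS)
    print(f"{lang:8s} one module: {one_module:,}  all modules/layers: {total:,}")
```

<p>For English, this gives 24,576 parameters for one module, which matches the estimate above, and 884,736 parameters across 3 modules × 12 layers.</p>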
</section>
<section id="few-shot-learning" class="level3">
<h3 class="anchored" data-anchor-id="few-shot-learning" id="few-shot-learning">3. Few-Shot Learning</h3>
<div id="few-shot-learning" class="cell" data-execution_count="14">
<details class="code-fold">
<summary>Code</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb20"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb20-1"><a href="#cb20-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> FewShotLoRALearner:</span>
<span id="cb20-2"><a href="#cb20-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, base_model, config):</span>
<span id="cb20-3"><a href="#cb20-3" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.base_model <span class="op">=</span> base_model</span>
<span id="cb20-4"><a href="#cb20-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.config <span class="op">=</span> config</span>
<span id="cb20-5"><a href="#cb20-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.task_adapters <span class="op">=</span> {}</span>
<span id="cb20-6"><a href="#cb20-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb20-7"><a href="#cb20-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> create_task_adapter(<span class="va">self</span>, task_name, rank<span class="op">=</span><span class="dv">8</span>, alpha<span class="op">=</span><span class="dv">16</span>):</span>
<span id="cb20-8"><a href="#cb20-8" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Create a lightweight adapter for few-shot learning"""</span></span>
<span id="cb20-9"><a href="#cb20-9" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> LoRAConfig(</span>
<span id="cb20-10"><a href="#cb20-10" aria-hidden="true" tabindex="-1"></a>            rank<span class="op">=</span>rank,</span>
<span id="cb20-11"><a href="#cb20-11" aria-hidden="true" tabindex="-1"></a>            alpha<span class="op">=</span>alpha,</span>
<span id="cb20-12"><a href="#cb20-12" aria-hidden="true" tabindex="-1"></a>            target_modules<span class="op">=</span>[<span class="st">"q_proj"</span>, <span class="st">"v_proj"</span>],  <span class="co"># Minimal modules for efficiency</span></span>
<span id="cb20-13"><a href="#cb20-13" aria-hidden="true" tabindex="-1"></a>            task_type<span class="op">=</span><span class="ss">f"few_shot_</span><span class="sc">{</span>task_name<span class="sc">}</span><span class="ss">"</span>,</span>
<span id="cb20-14"><a href="#cb20-14" aria-hidden="true" tabindex="-1"></a>            learning_rate<span class="op">=</span><span class="fl">1e-3</span>,  <span class="co"># Higher LR for fast adaptation</span></span>
<span id="cb20-15"><a href="#cb20-15" aria-hidden="true" tabindex="-1"></a>            dropout<span class="op">=</span><span class="fl">0.0</span>  <span class="co"># No dropout for few-shot</span></span>
<span id="cb20-16"><a href="#cb20-16" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb20-17"><a href="#cb20-17" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb20-18"><a href="#cb20-18" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> adapt_to_task(<span class="va">self</span>, task_name, support_examples, num_steps<span class="op">=</span><span class="dv">100</span>):</span>
<span id="cb20-19"><a href="#cb20-19" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Quick adaptation using few examples"""</span></span>
<span id="cb20-20"><a href="#cb20-20" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Adapting to task: </span><span class="sc">{</span>task_name<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb20-21"><a href="#cb20-21" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Support examples: </span><span class="sc">{</span><span class="bu">len</span>(support_examples)<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb20-22"><a href="#cb20-22" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Adaptation steps: </span><span class="sc">{</span>num_steps<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb20-23"><a href="#cb20-23" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb20-24"><a href="#cb20-24" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Create task-specific adapter</span></span>
<span id="cb20-25"><a href="#cb20-25" aria-hidden="true" tabindex="-1"></a>        adapter_config <span class="op">=</span> <span class="va">self</span>.create_task_adapter(task_name)</span>
<span id="cb20-26"><a href="#cb20-26" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.task_adapters[task_name] <span class="op">=</span> adapter_config</span>
<span id="cb20-27"><a href="#cb20-27" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb20-28"><a href="#cb20-28" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Simulated curves: they ignore the support set, so every task prints the same values</span></span>
<span id="cb20-29"><a href="#cb20-29" aria-hidden="true" tabindex="-1"></a>        adaptation_progress <span class="op">=</span> []</span>
<span id="cb20-30"><a href="#cb20-30" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> step <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">0</span>, num_steps <span class="op">+</span> <span class="dv">1</span>, <span class="dv">20</span>):</span>
<span id="cb20-31"><a href="#cb20-31" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Simulate decreasing loss</span></span>
<span id="cb20-32"><a href="#cb20-32" aria-hidden="true" tabindex="-1"></a>            loss <span class="op">=</span> <span class="fl">2.0</span> <span class="op">*</span> np.exp(<span class="op">-</span>step <span class="op">/</span> <span class="dv">50</span>) <span class="op">+</span> <span class="fl">0.1</span></span>
<span id="cb20-33"><a href="#cb20-33" aria-hidden="true" tabindex="-1"></a>            accuracy <span class="op">=</span> <span class="bu">min</span>(<span class="fl">0.95</span>, <span class="fl">0.3</span> <span class="op">+</span> <span class="fl">0.65</span> <span class="op">*</span> (<span class="dv">1</span> <span class="op">-</span> np.exp(<span class="op">-</span>step <span class="op">/</span> <span class="dv">30</span>)))</span>
<span id="cb20-34"><a href="#cb20-34" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb20-35"><a href="#cb20-35" aria-hidden="true" tabindex="-1"></a>            adaptation_progress.append({</span>
<span id="cb20-36"><a href="#cb20-36" aria-hidden="true" tabindex="-1"></a>                <span class="st">'step'</span>: step,</span>
<span id="cb20-37"><a href="#cb20-37" aria-hidden="true" tabindex="-1"></a>                <span class="st">'loss'</span>: loss,</span>
<span id="cb20-38"><a href="#cb20-38" aria-hidden="true" tabindex="-1"></a>                <span class="st">'accuracy'</span>: accuracy</span>
<span id="cb20-39"><a href="#cb20-39" aria-hidden="true" tabindex="-1"></a>            })</span>
<span id="cb20-40"><a href="#cb20-40" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb20-41"><a href="#cb20-41" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> adaptation_progress</span>
<span id="cb20-42"><a href="#cb20-42" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb20-43"><a href="#cb20-43" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> evaluate_adaptation(<span class="va">self</span>, task_name, test_examples):</span>
<span id="cb20-44"><a href="#cb20-44" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Evaluate adapted model on test examples"""</span></span>
<span id="cb20-45"><a href="#cb20-45" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> task_name <span class="kw">not</span> <span class="kw">in</span> <span class="va">self</span>.task_adapters:</span>
<span id="cb20-46"><a href="#cb20-46" aria-hidden="true" tabindex="-1"></a>            <span class="cf">raise</span> <span class="pp">ValueError</span>(<span class="ss">f"No adapter found for task: </span><span class="sc">{</span>task_name<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb20-47"><a href="#cb20-47" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb20-48"><a href="#cb20-48" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Simulate evaluation results</span></span>
<span id="cb20-49"><a href="#cb20-49" aria-hidden="true" tabindex="-1"></a>        performance <span class="op">=</span> {</span>
<span id="cb20-50"><a href="#cb20-50" aria-hidden="true" tabindex="-1"></a>            <span class="st">'accuracy'</span>: <span class="fl">0.87</span>,</span>
<span id="cb20-51"><a href="#cb20-51" aria-hidden="true" tabindex="-1"></a>            <span class="st">'precision'</span>: <span class="fl">0.89</span>,</span>
<span id="cb20-52"><a href="#cb20-52" aria-hidden="true" tabindex="-1"></a>            <span class="st">'recall'</span>: <span class="fl">0.85</span>,</span>
<span id="cb20-53"><a href="#cb20-53" aria-hidden="true" tabindex="-1"></a>            <span class="st">'f1_score'</span>: <span class="fl">0.87</span>,</span>
<span id="cb20-54"><a href="#cb20-54" aria-hidden="true" tabindex="-1"></a>            <span class="st">'test_examples'</span>: <span class="bu">len</span>(test_examples)</span>
<span id="cb20-55"><a href="#cb20-55" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb20-56"><a href="#cb20-56" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb20-57"><a href="#cb20-57" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> performance</span>
<span id="cb20-58"><a href="#cb20-58" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb20-59"><a href="#cb20-59" aria-hidden="true" tabindex="-1"></a><span class="co"># Demonstration of few-shot learning</span></span>
<span id="cb20-60"><a href="#cb20-60" aria-hidden="true" tabindex="-1"></a>few_shot_learner <span class="op">=</span> FewShotLoRALearner(<span class="va">None</span>, <span class="va">None</span>)</span>
<span id="cb20-61"><a href="#cb20-61" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb20-62"><a href="#cb20-62" aria-hidden="true" tabindex="-1"></a><span class="co"># Simulate different tasks</span></span>
<span id="cb20-63"><a href="#cb20-63" aria-hidden="true" tabindex="-1"></a>tasks <span class="op">=</span> {</span>
<span id="cb20-64"><a href="#cb20-64" aria-hidden="true" tabindex="-1"></a>    <span class="st">"bird_classification"</span>: <span class="dv">16</span>,  <span class="co"># 16 support examples</span></span>
<span id="cb20-65"><a href="#cb20-65" aria-hidden="true" tabindex="-1"></a>    <span class="st">"medical_diagnosis"</span>: <span class="dv">8</span>,     <span class="co"># 8 support examples  </span></span>
<span id="cb20-66"><a href="#cb20-66" aria-hidden="true" tabindex="-1"></a>    <span class="st">"product_recognition"</span>: <span class="dv">32</span>   <span class="co"># 32 support examples</span></span>
<span id="cb20-67"><a href="#cb20-67" aria-hidden="true" tabindex="-1"></a>}</span>
<span id="cb20-68"><a href="#cb20-68" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb20-69"><a href="#cb20-69" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"Few-Shot Learning with LoRA:"</span>)</span>
<span id="cb20-70"><a href="#cb20-70" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"="</span> <span class="op">*</span> <span class="dv">35</span>)</span>
<span id="cb20-71"><a href="#cb20-71" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb20-72"><a href="#cb20-72" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> task_name, num_examples <span class="kw">in</span> tasks.items():</span>
<span id="cb20-73"><a href="#cb20-73" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"</span><span class="ch">\n</span><span class="ss">Task: </span><span class="sc">{</span>task_name<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb20-74"><a href="#cb20-74" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb20-75"><a href="#cb20-75" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Adapt to task</span></span>
<span id="cb20-76"><a href="#cb20-76" aria-hidden="true" tabindex="-1"></a>    support_examples <span class="op">=</span> <span class="bu">list</span>(<span class="bu">range</span>(num_examples))  <span class="co"># Mock examples</span></span>
<span id="cb20-77"><a href="#cb20-77" aria-hidden="true" tabindex="-1"></a>    progress <span class="op">=</span> few_shot_learner.adapt_to_task(task_name, support_examples)</span>
<span id="cb20-78"><a href="#cb20-78" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb20-79"><a href="#cb20-79" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Show adaptation progress</span></span>
<span id="cb20-80"><a href="#cb20-80" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="st">"Adaptation progress:"</span>)</span>
<span id="cb20-81"><a href="#cb20-81" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> point <span class="kw">in</span> progress[<span class="op">-</span><span class="dv">3</span>:]:  <span class="co"># Show last 3 points</span></span>
<span id="cb20-82"><a href="#cb20-82" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"  Step </span><span class="sc">{</span>point[<span class="st">'step'</span>]<span class="sc">:3d}</span><span class="ss">: Loss=</span><span class="sc">{</span>point[<span class="st">'loss'</span>]<span class="sc">:.3f}</span><span class="ss">, Acc=</span><span class="sc">{</span>point[<span class="st">'accuracy'</span>]<span class="sc">:.3f}</span><span class="ss">"</span>)</span>
<span id="cb20-83"><a href="#cb20-83" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb20-84"><a href="#cb20-84" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Evaluate</span></span>
<span id="cb20-85"><a href="#cb20-85" aria-hidden="true" tabindex="-1"></a>    test_examples <span class="op">=</span> <span class="bu">list</span>(<span class="bu">range</span>(<span class="dv">50</span>))  <span class="co"># Mock test set</span></span>
<span id="cb20-86"><a href="#cb20-86" aria-hidden="true" tabindex="-1"></a>    performance <span class="op">=</span> few_shot_learner.evaluate_adaptation(task_name, test_examples)</span>
<span id="cb20-87"><a href="#cb20-87" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Final performance: </span><span class="sc">{</span>performance[<span class="st">'accuracy'</span>]<span class="sc">:.3f}</span><span class="ss"> accuracy"</span>)</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-stdout">
<pre><code>Few-Shot Learning with LoRA:
===================================

Task: bird_classification
Adapting to task: bird_classification
Support examples: 16
Adaptation steps: 100
Adaptation progress:
  Step  60: Loss=0.702, Acc=0.862
  Step  80: Loss=0.504, Acc=0.905
  Step 100: Loss=0.371, Acc=0.927
Final performance: 0.870 accuracy

Task: medical_diagnosis
Adapting to task: medical_diagnosis
Support examples: 8
Adaptation steps: 100
Adaptation progress:
  Step  60: Loss=0.702, Acc=0.862
  Step  80: Loss=0.504, Acc=0.905
  Step 100: Loss=0.371, Acc=0.927
Final performance: 0.870 accuracy

Task: product_recognition
Adapting to task: product_recognition
Support examples: 32
Adaptation steps: 100
Adaptation progress:
  Step  60: Loss=0.702, Acc=0.862
  Step  80: Loss=0.504, Acc=0.905
  Step 100: Loss=0.371, Acc=0.927
Final performance: 0.870 accuracy</code></pre>
</div>
</div>
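<p><code>FewShotLoRALearner</code> above only simulates its loss and accuracy curves. The minimal, self-contained sketch below shows what adaptation actually updates. It uses plain gradient descent to train a rank-1 LoRA update on a frozen 2 × 3 linear layer from 8 support examples. The layer sizes, the target map, the learning rate, and the step count are arbitrary choices made for illustration. A real VLM would apply the same kind of update to its attention projections through a framework such as PyTorch.</p>

```python
import random

random.seed(0)

D_IN, D_OUT, RANK, ALPHA = 3, 2, 1, 2.0
SCALE = ALPHA / RANK  # LoRA scaling factor alpha / r

# Frozen base weight W (D_OUT x D_IN): never updated during adaptation.
W = [[1.0, 0.0, 0.0],
     [0.0, 1.0, 0.0]]

# Hypothetical "new task": the target map is W plus a rank-1 change.
DELTA = [[0.5, 0.5, 0.0],
         [-0.5, -0.5, 0.0]]
TARGET = [[W[i][j] + DELTA[i][j] for j in range(D_IN)] for i in range(D_OUT)]


def matvec(M, x):
    return [sum(m * xj for m, xj in zip(row, x)) for row in M]


# 8 support examples (the few-shot set) labelled by the target map.
support = [[random.uniform(-1, 1) for _ in range(D_IN)] for _ in range(8)]
labels = [matvec(TARGET, x) for x in support]

# LoRA factors: A random, B zero, so the adapted model starts equal to W.
A = [[random.uniform(-0.5, 0.5) for _ in range(D_IN)] for _ in range(RANK)]
B = [[0.0] * RANK for _ in range(D_OUT)]


def forward(x):
    h = matvec(A, x)          # down-projection, length RANK
    base = matvec(W, x)       # frozen path
    lora = matvec(B, h)       # up-projection, length D_OUT
    return [b + SCALE * l for b, l in zip(base, lora)], h


def loss():
    total = 0.0
    for x, t in zip(support, labels):
        y, _ = forward(x)
        total += 0.5 * sum((yi - ti) ** 2 for yi, ti in zip(y, t))
    return total / len(support)


initial_loss = loss()
LR = 0.05
for _ in range(2000):
    gA = [[0.0] * D_IN for _ in range(RANK)]
    gB = [[0.0] * RANK for _ in range(D_OUT)]
    for x, t in zip(support, labels):
        y, h = forward(x)
        err = [yi - ti for yi, ti in zip(y, t)]
        # dL/dB = SCALE * err h^T ;  dL/dA = SCALE * (B^T err) x^T
        for i in range(D_OUT):
            for k in range(RANK):
                gB[i][k] += SCALE * err[i] * h[k]
        bt_err = [sum(B[i][k] * err[i] for i in range(D_OUT)) for k in range(RANK)]
        for k in range(RANK):
            for j in range(D_IN):
                gA[k][j] += SCALE * bt_err[k] * x[j]
    # Only the LoRA factors are updated; W stays frozen.
    n = len(support)
    for i in range(D_OUT):
        for k in range(RANK):
            B[i][k] -= LR * gB[i][k] / n
    for k in range(RANK):
        for j in range(D_IN):
            A[k][j] -= LR * gA[k][j] / n

final_loss = loss()
print(f"loss before: {initial_loss:.4f}  after: {final_loss:.4f}")

# Merging: W' = W + SCALE * B A gives the same outputs in a single matmul.
merged = [[W[i][j] + SCALE * sum(B[i][k] * A[k][j] for k in range(RANK))
           for j in range(D_IN)] for i in range(D_OUT)]
```

<p>Initializing <em>B</em> to zero makes the adapted model start out identical to the frozen base. After training, the update can be merged into <em>W</em>, so serving the adapted model adds no extra cost.</p>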
</section>
</section>
<section id="best-practices" class="level2">
<h2 class="anchored" data-anchor-id="best-practices" id="best-practices">Best Practices</h2>
<section id="hyperparameter-selection" class="level3">
<h3 class="anchored" data-anchor-id="hyperparameter-selection" id="hyperparameter-selection">1. Hyperparameter Selection</h3>
<div class="tabset-margin-container"></div><div class="panel-tabset">
<ul class="nav nav-tabs" role="tablist"><li class="nav-item" role="presentation"><a class="nav-link active" id="tabset-5-1-tab" data-bs-toggle="tab" data-bs-target="#tabset-5-1" role="tab" aria-controls="tabset-5-1" aria-selected="true" href="">Simple Classification</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-5-2-tab" data-bs-toggle="tab" data-bs-target="#tabset-5-2" role="tab" aria-controls="tabset-5-2" aria-selected="false" href="">Medical VQA</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-5-3-tab" data-bs-toggle="tab" data-bs-target="#tabset-5-3" role="tab" aria-controls="tabset-5-3" aria-selected="false" href="">General Captioning</a></li></ul>
<div class="tab-content">
<div id="tabset-5-1" class="tab-pane active" role="tabpanel" aria-labelledby="tabset-5-1-tab">
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Recommended Settings
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li><strong>Rank</strong>: 4</li>
<li><strong>Alpha</strong>: 4</li>
<li><strong>LoRA Learning Rate</strong>: 1e-4</li>
<li><strong>Base Learning Rate</strong>: 1e-5</li>
</ul>
</div>
</div>
<p><strong>Reasoning</strong>: Rank 4 is chosen because the task is simple. It gives enough adaptation capacity for straightforward classification while keeping the number of trainable parameters small.</p>
</div>
<div id="tabset-5-2" class="tab-pane" role="tabpanel" aria-labelledby="tabset-5-2-tab">
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Recommended Settings
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li><strong>Rank</strong>: 64</li>
<li><strong>Alpha</strong>: 128</li>
<li><strong>LoRA Learning Rate</strong>: 1e-4</li>
<li><strong>Base Learning Rate</strong>: 1e-5</li>
</ul>
</div>
</div>
<p><strong>Reasoning</strong>: Rank 64 is chosen because the task is complex. Medical visual question answering needs more adaptation capacity to capture the intricate relationships between medical imagery and specialized domain knowledge.</p>
</div>
<div id="tabset-5-3" class="tab-pane" role="tabpanel" aria-labelledby="tabset-5-3-tab">
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Recommended Settings
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li><strong>Rank</strong>: 16</li>
<li><strong>Alpha</strong>: 24</li>
<li><strong>LoRA Learning Rate</strong>: 1e-4</li>
<li><strong>Base Learning Rate</strong>: 1e-5</li>
</ul>
</div>
</div>
<p><strong>Reasoning</strong>: Rank 16 is chosen for a task of medium complexity. General captioning sits between simple classification and highly specialized tasks, so it needs moderate adaptation capacity.</p>
</div>
</div>
</div>
<section id="summary-table" class="level4">
<h4 class="anchored" data-anchor-id="summary-table">Summary Table</h4>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Quick Reference Table
</div>
</div>
<div class="callout-body-container callout-body">
<table class="caption-top table">
<thead>
<tr class="header">
<th>Scenario</th>
<th>Rank</th>
<th>Alpha</th>
<th>LoRA LR</th>
<th>Base LR</th>
<th>Task Complexity</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>Simple Classification</td>
<td>4</td>
<td>4</td>
<td>1e-4</td>
<td>1e-5</td>
<td>Low</td>
</tr>
<tr class="even">
<td>Medical VQA</td>
<td>64</td>
<td>128</td>
<td>1e-4</td>
<td>1e-5</td>
<td>High</td>
</tr>
<tr class="odd">
<td>General Captioning</td>
<td>16</td>
<td>24</td>
<td>0.0001</td>
<td>1e-05</td>
<td>Medium</td>
</tr>
</tbody>
</table>
</div>
</div>
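<p>Rank and alpha interact through the scaling factor LoRA applies to the adapter update, which is commonly alpha / r (the convention in the original LoRA formulation and the Hugging Face PEFT default when rank-stabilized scaling is off). The sketch below assumes that convention and computes what the table's settings imply:</p>

```python
# Effective LoRA scaling for the scenarios in the table above.
# Assumes the standard convention: delta_W = (alpha / rank) * B @ A.
SCENARIOS = {
    "Simple Classification": {"rank": 4, "alpha": 4},
    "Medical VQA": {"rank": 64, "alpha": 128},
    "General Captioning": {"rank": 16, "alpha": 24},
}


def lora_scaling(rank: int, alpha: float) -> float:
    """Return the multiplier applied to B @ A in the LoRA forward pass."""
    if rank <= 0:
        raise ValueError("rank must be positive")
    return alpha / rank


for name, cfg in SCENARIOS.items():
    print(f"{name:<22} scaling = {lora_scaling(cfg['rank'], cfg['alpha']):.2f}")
```

<p>If alpha stays fixed while rank grows, the scaling shrinks. That is one reason alpha is often raised together with rank, as in the Medical VQA row.</p>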
</section>
</section>
<section id="module-selection-strategy" class="level3">
<h3 class="anchored" data-anchor-id="module-selection-strategy" id="module-selection-strategy">2. Module Selection Strategy</h3>
<div id="cell-fig-module-selection" class="cell" data-execution_count="15">
<div class="cell-output cell-output-display">
<div id="fig-module-selection" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-fig figure">
<div aria-describedby="fig-module-selection-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<img src="https://theja-vanka.github.io/blogs/posts/models/vision-language-models/vision-language-lora/fig-module-selection-output-1.png" width="1430" height="566" class="figure-img">
</div>
<figcaption class="quarto-float-caption-bottom quarto-float-caption quarto-float-fig" id="fig-module-selection-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Figure&nbsp;3: LoRA Module Selection Impact Analysis
</figcaption>
</figure>
</div>
</div>
</div>
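<p>One way to reason about module selection is by counting the trainable parameters each choice adds. For a linear layer with input size d_in and output size d_out, a rank-r adapter adds r × (d_in + d_out) parameters. The layer sizes below are hypothetical (roughly ViT-Base-like: hidden size 768, MLP size 3072, 12 blocks) and serve only as an illustration; substitute your model's actual dimensions.</p>

```python
# Trainable-parameter cost of different LoRA target-module choices.
# Layer sizes are hypothetical (ViT-Base-like); replace them with your model's.
HIDDEN = 768
MLP = 3072
NUM_LAYERS = 12

# (d_in, d_out) for each candidate module within one transformer block
MODULE_SHAPES = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, HIDDEN),
    "v_proj": (HIDDEN, HIDDEN),
    "out_proj": (HIDDEN, HIDDEN),
    "fc1": (HIDDEN, MLP),
    "fc2": (MLP, HIDDEN),
}


def lora_params(targets, rank, num_layers=NUM_LAYERS):
    """Each adapted layer adds A (rank x d_in) and B (d_out x rank)."""
    per_block = sum(rank * sum(MODULE_SHAPES[t]) for t in targets)
    return per_block * num_layers


selections = {
    "q,v only": ["q_proj", "v_proj"],
    "all attention": ["q_proj", "k_proj", "v_proj", "out_proj"],
    "attention + MLP": list(MODULE_SHAPES),
}
for label, targets in selections.items():
    print(f"{label:<16} rank=16 -> {lora_params(targets, rank=16):,} params")
```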
</section>
<section id="training-best-practices" class="level3">
<h3 class="anchored" data-anchor-id="training-best-practices" id="training-best-practices">3. Training Best Practices</h3>
<div class="tabset-margin-container"></div><div class="panel-tabset">
<ul class="nav nav-tabs" role="tablist"><li class="nav-item" role="presentation"><a class="nav-link active" id="tabset-6-1-tab" data-bs-toggle="tab" data-bs-target="#tabset-6-1" role="tab" aria-controls="tabset-6-1" aria-selected="true" href="">Setup Phase</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-6-2-tab" data-bs-toggle="tab" data-bs-target="#tabset-6-2" role="tab" aria-controls="tabset-6-2" aria-selected="false" href="">Monitoring Phase</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-6-3-tab" data-bs-toggle="tab" data-bs-target="#tabset-6-3" role="tab" aria-controls="tabset-6-3" aria-selected="false" href="">Checkpointing Phase</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-6-4-tab" data-bs-toggle="tab" data-bs-target="#tabset-6-4" role="tab" aria-controls="tabset-6-4" aria-selected="false" href="">Evaluation Phase</a></li></ul>
<div class="tab-content">
<div id="tabset-6-1" class="tab-pane active" role="tabpanel" aria-labelledby="tabset-6-1-tab">
<ul>
<li>Configure separate learning rates for LoRA and base parameters</li>
<li>Enable mixed precision training</li>
<li>Set up gradient accumulation</li>
<li>Configure gradient clipping</li>
</ul>
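<p>A minimal sketch of the first and third items, assuming LoRA parameters can be identified by the substring <code>lora_</code> in their names (the PEFT-style naming convention). The returned list uses the param-group format that PyTorch optimizers such as <code>torch.optim.AdamW</code> accept, with the learning rates recommended earlier as defaults. Mixed precision and clipping depend on the framework and are left out here.</p>

```python
# Split parameters into two optimizer groups with separate learning rates,
# and decide when to step under gradient accumulation.
# Assumption: LoRA parameters contain "lora_" in their names.
from typing import Any, Dict, Iterable, List, Tuple


def build_param_groups(
    named_params: Iterable[Tuple[str, Any]],
    lora_lr: float = 1e-4,
    base_lr: float = 1e-5,
) -> List[Dict[str, Any]]:
    """Return optimizer param groups: LoRA params at lora_lr, others at base_lr."""
    lora, base = [], []
    for name, param in named_params:
        # Skip frozen parameters when the object exposes requires_grad
        if not getattr(param, "requires_grad", True):
            continue
        (lora if "lora_" in name else base).append(param)
    groups = []
    if lora:
        groups.append({"params": lora, "lr": lora_lr})
    if base:
        groups.append({"params": base, "lr": base_lr})
    return groups


def should_step(micro_step: int, accum_steps: int) -> bool:
    """True on the (1-indexed) micro-batch that completes an accumulation window."""
    return micro_step % accum_steps == 0


# Usage with PyTorch (not executed here):
#   optimizer = torch.optim.AdamW(build_param_groups(model.named_parameters()))
```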
</div>
<div id="tabset-6-2" class="tab-pane" role="tabpanel" aria-labelledby="tabset-6-2-tab">
<ul>
<li>Track LoRA weight norms</li>
<li>Monitor validation metrics</li>
<li>Check for overfitting signs</li>
<li>Validate rank utilization</li>
</ul>
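<p>For the first and last items, one option is to compute the Frobenius norms of the LoRA factors and an effective rank of their product, counting singular values above a fraction of the largest one. The relative threshold of 0.01 and the <code>A</code> (r × d_in) and <code>B</code> (d_out × r) shapes are assumptions of this sketch, chosen to match the usual LoRA layout.</p>

```python
import numpy as np


def lora_health(A: np.ndarray, B: np.ndarray, threshold: float = 0.01) -> dict:
    """Norms and effective rank for one LoRA pair, where delta_W = B @ A.

    A has shape (r, d_in) and B has shape (d_out, r). Effective rank counts
    singular values of delta_W above threshold * (largest singular value).
    """
    rank = A.shape[0]
    sing = np.linalg.svd(B @ A, compute_uv=False)
    if sing.size == 0 or sing[0] == 0:
        effective = 0  # e.g. B is still all zeros right after initialization
    else:
        effective = int(np.sum(sing > threshold * sing[0]))
    return {
        "lora_A_norm": float(np.linalg.norm(A)),
        "lora_B_norm": float(np.linalg.norm(B)),
        "effective_rank": effective,
        "rank_utilization": effective / rank,
    }


rng = np.random.default_rng(0)
A = rng.normal(size=(16, 64))
B = np.zeros((64, 16))
B[:, :4] = rng.normal(size=(64, 4))  # only 4 of the 16 rank directions are used
print(lora_health(A, B))
```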
</div>
<div id="tabset-6-3" class="tab-pane" role="tabpanel" aria-labelledby="tabset-6-3-tab">
<ul>
<li>Save model at regular intervals</li>
<li>Keep best performing checkpoint</li>
<li>Save LoRA adapters separately</li>
<li>Document hyperparameters</li>
</ul>
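<p>Saving adapters separately usually means keeping only the adapter tensors from the model's state dict, which is far smaller than a full checkpoint. This sketch assumes adapter keys contain <code>lora_</code>; the JSON file name is hypothetical.</p>

```python
import json
from pathlib import Path
from typing import Any, Dict


def extract_lora_state(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Keep only adapter entries (keys containing 'lora_') from a full state dict."""
    return {k: v for k, v in state_dict.items() if "lora_" in k}


def save_hparams(hparams: Dict[str, Any], out_dir: str) -> Path:
    """Write hyperparameters next to the adapter weights for reproducibility."""
    path = Path(out_dir) / "adapter_hparams.json"  # hypothetical file name
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(hparams, indent=2, sort_keys=True))
    return path


# With PyTorch, the filtered dict could then be saved via
#   torch.save(extract_lora_state(model.state_dict()), "adapter.pt")
```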
</div>
<div id="tabset-6-4" class="tab-pane" role="tabpanel" aria-labelledby="tabset-6-4-tab">
<ul>
<li>Test on multiple datasets</li>
<li>Measure parameter efficiency</li>
<li>Check inference speed</li>
<li>Validate robustness</li>
</ul>
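<p>For the inference-speed item, a trained LoRA update can be folded into the base weight, W' = W + (alpha / r) B A, so the merged layer costs the same as the original layer with no extra adapter matmuls. The NumPy sketch below checks that the merged weight reproduces the unmerged forward pass; the shapes and values are arbitrary.</p>

```python
import numpy as np


def lora_forward(x, W, A, B, alpha):
    """Unmerged forward: base path plus scaled low-rank path; x has shape (n, d_in)."""
    r = A.shape[0]
    return x @ W.T + (alpha / r) * (x @ A.T) @ B.T


def merge_lora(W, A, B, alpha):
    """Fold the adapter into the base weight: W' = W + (alpha / r) * B @ A."""
    r = A.shape[0]
    return W + (alpha / r) * (B @ A)


rng = np.random.default_rng(42)
d_in, d_out, r, alpha = 32, 24, 4, 8
W = rng.normal(size=(d_out, d_in))
A = rng.normal(size=(r, d_in))
B = rng.normal(size=(d_out, r))
x = rng.normal(size=(5, d_in))

merged_out = x @ merge_lora(W, A, B, alpha).T
print(np.allclose(merged_out, lora_forward(x, W, A, B, alpha)))  # True
```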
</div>
</div>
</div>
<section id="configuration-validation" class="level4">
<h4 class="anchored" data-anchor-id="configuration-validation">Configuration Validation</h4>
<div class="tabset-margin-container"></div><div class="panel-tabset">
<ul class="nav nav-tabs" role="tablist"><li class="nav-item" role="presentation"><a class="nav-link active" id="tabset-7-1-tab" data-bs-toggle="tab" data-bs-target="#tabset-7-1" role="tab" aria-controls="tabset-7-1" aria-selected="true" href="">Good Config</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-7-2-tab" data-bs-toggle="tab" data-bs-target="#tabset-7-2" role="tab" aria-controls="tabset-7-2" aria-selected="false" href="">High Rank</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-7-3-tab" data-bs-toggle="tab" data-bs-target="#tabset-7-3" role="tab" aria-controls="tabset-7-3" aria-selected="false" href="">Low Alpha</a></li></ul>
<div class="tab-content">
<div id="tabset-7-1" class="tab-pane active" role="tabpanel" aria-labelledby="tabset-7-1-tab">
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Status: ✅ Valid
</div>
</div>
<div class="callout-body-container callout-body">
<p>Configuration is valid and ready to use.</p>
</div>
</div>
</div>
<div id="tabset-7-2" class="tab-pane" role="tabpanel" aria-labelledby="tabset-7-2-tab">
<div class="callout callout-style-default callout-tip callout-empty-content callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Status: ✅ Valid
</div>
</div>
<div class="callout-body-container callout-body">

</div>
</div>
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Warning</span>Warnings
</div>
</div>
<div class="callout-body-container callout-body">
<p>⚠️ Very high rank may reduce efficiency benefits</p>
</div>
</div>
</div>
<div id="tabset-7-3" class="tab-pane" role="tabpanel" aria-labelledby="tabset-7-3-tab">
<div class="callout callout-style-default callout-tip callout-empty-content callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Status: ✅ Valid
</div>
</div>
<div class="callout-body-container callout-body">

</div>
</div>
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Warning</span>Warnings
</div>
</div>
<div class="callout-body-container callout-body">
<p>⚠️ Very low alpha may limit adaptation strength</p>
</div>
</div>
</div>
</div>
</div>
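<p>The three cases above can be reproduced with a small validator. The numeric limits here (rank above 64 treated as very high, alpha below half the rank treated as very low) are illustrative assumptions, not the thresholds behind the output shown; tune them for your model.</p>

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ValidationResult:
    valid: bool
    errors: List[str] = field(default_factory=list)
    warnings: List[str] = field(default_factory=list)


def validate_lora_config(rank: int, alpha: float, max_rank: int = 64,
                         min_alpha_ratio: float = 0.5) -> ValidationResult:
    """Hard errors make a config invalid; soft issues only produce warnings.

    max_rank and min_alpha_ratio are illustrative defaults, not fixed rules.
    """
    errors, warnings = [], []
    if rank <= 0:
        errors.append("Rank must be a positive integer")
    if alpha <= 0:
        errors.append("Alpha must be positive")
    if rank > max_rank:
        warnings.append("Very high rank may reduce efficiency benefits")
    if rank > 0 and alpha / rank < min_alpha_ratio:
        warnings.append("Very low alpha may limit adaptation strength")
    return ValidationResult(valid=not errors, errors=errors, warnings=warnings)


for cfg in [(16, 32), (256, 512), (16, 2)]:  # good, high rank, low alpha
    result = validate_lora_config(*cfg)
    print(cfg, "valid" if result.valid else "invalid", result.warnings)
```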
</section>
</section>
</section>
<section id="troubleshooting" class="level2">
<h2 class="anchored" data-anchor-id="troubleshooting" id="troubleshooting">Troubleshooting</h2>
<section id="common-issues-and-solutions" class="level3">
<h3 class="anchored" data-anchor-id="common-issues-and-solutions" id="common-issues-and-solutions">Common Issues and Solutions</h3>
<div class="tabset-margin-container"></div><div class="panel-tabset">
<ul class="nav nav-tabs" role="tablist"><li class="nav-item" role="presentation"><a class="nav-link active" id="tabset-9-1-tab" data-bs-toggle="tab" data-bs-target="#tabset-9-1" role="tab" aria-controls="tabset-9-1" aria-selected="true" href="">Example Diagnosis</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-9-2-tab" data-bs-toggle="tab" data-bs-target="#tabset-9-2" role="tab" aria-controls="tabset-9-2" aria-selected="false" href="">Debugging Checklist</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-9-3-tab" data-bs-toggle="tab" data-bs-target="#tabset-9-3" role="tab" aria-controls="tabset-9-3" aria-selected="false" href="">Debugging Tools</a></li></ul>
<div class="tab-content">
<div id="tabset-9-1" class="tab-pane active" role="tabpanel" aria-labelledby="tabset-9-1-tab">
<div class="callout callout-style-default callout-warning no-icon callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon no-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Warning</span>Training Issue Analysis
</div>
</div>
<div class="callout-body-container callout-body">
<p><strong>Symptoms Observed:</strong></p>
<ul>
<li>Loss spikes during training</li>
<li>Gradient explosion detected</li>
<li>Poor convergence after many epochs</li>
</ul>
<p><strong>Diagnosis:</strong> Training Instability<br>
<strong>Confidence Level:</strong> 67%</p>
</div>
</div>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Recommended Solutions
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li><strong>Apply gradient clipping</strong> (max_norm=1.0)</li>
<li><strong>Use learning rate scheduling</strong></li>
<li><strong>Enable gradient accumulation</strong></li>
</ul>
</div>
</div>
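<p>Gradient clipping with <code>max_norm=1.0</code> usually means global-norm clipping: compute the L2 norm over all gradients together and, if it exceeds the limit, rescale every gradient by the same factor. PyTorch's <code>torch.nn.utils.clip_grad_norm_</code> works this way; the NumPy version below is a framework-free sketch of the same idea.</p>

```python
from typing import List, Tuple

import numpy as np


def clip_grad_global_norm(grads: List[np.ndarray], max_norm: float = 1.0,
                          eps: float = 1e-6) -> Tuple[List[np.ndarray], float]:
    """Scale all gradients so their combined L2 norm is at most max_norm.

    Returns the clipped gradients and the total norm measured before
    clipping, which is worth logging to spot explosions early.
    """
    total_norm = float(np.sqrt(sum(float(np.sum(g ** 2)) for g in grads)))
    scale = min(1.0, max_norm / (total_norm + eps))
    return [g * scale for g in grads], total_norm


grads = [np.array([3.0, 0.0]), np.array([0.0, 4.0])]  # global norm = 5
clipped, norm_before = clip_grad_global_norm(grads, max_norm=1.0)
print(norm_before)  # 5.0
```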
</div>
<div id="tabset-9-2" class="tab-pane" role="tabpanel" aria-labelledby="tabset-9-2-tab">
<div class="tabset-margin-container"></div><div class="panel-tabset" data-group="checklist">
<ul class="nav nav-tabs" role="tablist"><li class="nav-item" role="presentation"><a class="nav-link active" id="tabset-8-1-tab" data-bs-toggle="tab" data-bs-target="#tabset-8-1" role="tab" aria-controls="tabset-8-1" aria-selected="true" href="">📊 Data Quality</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-8-2-tab" data-bs-toggle="tab" data-bs-target="#tabset-8-2" role="tab" aria-controls="tabset-8-2" aria-selected="false" href="">🔧 Model Configuration</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-8-3-tab" data-bs-toggle="tab" data-bs-target="#tabset-8-3" role="tab" aria-controls="tabset-8-3" aria-selected="false" href="">📈 Training Metrics</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-8-4-tab" data-bs-toggle="tab" data-bs-target="#tabset-8-4" role="tab" aria-controls="tabset-8-4" aria-selected="false" href="">💾 System Resources</a></li></ul>
<div class="tab-content" data-group="checklist">
<div id="tabset-8-1" class="tab-pane active" role="tabpanel" aria-labelledby="tabset-8-1-tab">
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center" data-bs-toggle="collapse" data-bs-target=".callout-17-contents" aria-controls="callout-17" aria-expanded="true" aria-label="Toggle callout">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Data Validation Steps
</div>
<div class="callout-btn-toggle d-inline-block border-0 py-1 ps-1 pe-0 float-end"><i class="callout-toggle"></i></div>
</div>
<div id="callout-17" class="callout-17-contents callout-collapse collapse show">
<div class="callout-body-container callout-body">
<ul class="task-list">
<li><label><input type="checkbox"><strong>Validate input preprocessing</strong></label>
<ul>
<li>Check normalization parameters</li>
<li>Verify tokenization consistency</li>
</ul></li>
<li><label><input type="checkbox"><strong>Check label distribution</strong></label>
<ul>
<li>Examine class balance</li>
<li>Identify potential bias</li>
</ul></li>
<li><label><input type="checkbox"><strong>Verify data augmentation</strong></label>
<ul>
<li>Test augmentation pipeline</li>
<li>Ensure proper randomization</li>
</ul></li>
<li><label><input type="checkbox"><strong>Ensure proper batching</strong></label>
<ul>
<li>Validate batch size settings</li>
<li>Check data loader configuration</li>
</ul></li>
</ul>
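<p>For the label-distribution check, a quick imbalance summary is often enough to spot problems before training starts. In this sketch, the 10:1 majority-to-minority ratio used as the warning level is an arbitrary assumption, and the labels are made up.</p>

```python
from collections import Counter
from typing import Dict, Hashable, Iterable


def label_balance(labels: Iterable[Hashable], max_ratio: float = 10.0) -> Dict[str, object]:
    """Summarize class frequencies and flag heavy imbalance.

    max_ratio is an illustrative threshold on majority/minority counts.
    """
    counts = Counter(labels)
    if not counts:
        raise ValueError("no labels provided")
    ratio = max(counts.values()) / min(counts.values())
    total = sum(counts.values())
    return {
        "counts": dict(counts),
        "fractions": {k: v / total for k, v in counts.items()},
        "imbalance_ratio": ratio,
        "imbalanced": ratio > max_ratio,
    }


print(label_balance(["normal"] * 90 + ["pneumonia"] * 8 + ["fracture"] * 2))
```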
</div>
</div>
</div>
</div>
<div id="tabset-8-2" class="tab-pane" role="tabpanel" aria-labelledby="tabset-8-2-tab">
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center" data-bs-toggle="collapse" data-bs-target=".callout-18-contents" aria-controls="callout-18" aria-expanded="true" aria-label="Toggle callout">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Configuration Verification
</div>
<div class="callout-btn-toggle d-inline-block border-0 py-1 ps-1 pe-0 float-end"><i class="callout-toggle"></i></div>
</div>
<div id="callout-18" class="callout-18-contents callout-collapse collapse show">
<div class="callout-body-container callout-body">
<ul class="task-list">
<li><label><input type="checkbox"><strong>Confirm LoRA target modules</strong></label>
<ul>
<li>Verify layer selection</li>
<li>Check module naming consistency</li>
</ul></li>
<li><label><input type="checkbox"><strong>Check rank and alpha values</strong></label>
<ul>
<li>Validate rank appropriateness</li>
<li>Ensure alpha scaling is correct</li>
</ul></li>
<li><label><input type="checkbox"><strong>Validate learning rates</strong></label>
<ul>
<li>Test different LR values</li>
<li>Check optimizer settings</li>
</ul></li>
<li><label><input type="checkbox"><strong>Ensure proper initialization</strong></label>
<ul>
<li>Verify weight initialization</li>
<li>Check adapter placement</li>
</ul></li>
</ul>
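<p>To confirm that target modules actually match layers in the model, compare the configured names against the model's module names. This sketch assumes suffix matching on dotted names (a target matches a module whose name equals it or ends with <code>.target</code>), similar to how PEFT interprets a list of <code>target_modules</code>; the module names are hypothetical.</p>

```python
from typing import Dict, Iterable, List


def match_target_modules(module_names: Iterable[str],
                         targets: Iterable[str]) -> Dict[str, List[str]]:
    """Map each target to the module names it would adapt (suffix match on dotted names)."""
    names = list(module_names)
    return {t: [n for n in names if n == t or n.endswith("." + t)] for t in targets}


# Hypothetical names, e.g. from [name for name, _ in model.named_modules()]
modules = [
    "vision.blocks.0.attention.q_proj",
    "vision.blocks.0.attention.v_proj",
    "vision.blocks.0.mlp.fc1",
    "text.layers.0.self_attn.q_proj",
]
matches = match_target_modules(modules, ["q_proj", "v_proj", "query"])
unmatched = [t for t, hits in matches.items() if not hits]
print(unmatched)  # ['query']: likely a naming mismatch for this model
```

<p>The output also shows that a bare name such as <code>q_proj</code> can match layers in both the vision and the language towers, which may or may not be intended.</p>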
</div>
</div>
</div>
</div>
<div id="tabset-8-3" class="tab-pane" role="tabpanel" aria-labelledby="tabset-8-3-tab">
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center" data-bs-toggle="collapse" data-bs-target=".callout-19-contents" aria-controls="callout-19" aria-expanded="true" aria-label="Toggle callout">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Monitoring Guidelines
</div>
<div class="callout-btn-toggle d-inline-block border-0 py-1 ps-1 pe-0 float-end"><i class="callout-toggle"></i></div>
</div>
<div id="callout-19" class="callout-19-contents callout-collapse collapse show">
<div class="callout-body-container callout-body">
<ul class="task-list">
<li><label><input type="checkbox"><strong>Track loss curves</strong></label>
<ul>
<li>Monitor training/validation loss</li>
<li>Identify overfitting patterns</li>
</ul></li>
<li><label><input type="checkbox"><strong>Monitor gradient norms</strong></label>
<ul>
<li>Check for gradient explosion</li>
<li>Detect vanishing gradients</li>
</ul></li>
<li><label><input type="checkbox"><strong>Check weight magnitudes</strong></label>
<ul>
<li>Monitor parameter updates</li>
<li>Verify adapter weights</li>
</ul></li>
<li><label><input type="checkbox"><strong>Validate learning rate schedule</strong></label>
<ul>
<li>Confirm schedule implementation</li>
<li>Monitor LR decay patterns</li>
</ul></li>
</ul>
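<p>To validate a learning rate schedule, evaluate it at a few steps and compare the values against what you expect. The sketch assumes linear warmup followed by cosine decay to zero, a common choice but not necessarily the one used in this guide; the warmup and total step counts are placeholders.</p>

```python
import math


def lr_at(step: int, base_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Linear warmup up to base_lr, then cosine decay to 0 at total_steps."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


# Spot checks for a LoRA LR of 1e-4, 100 warmup steps, 1000 total steps
for s in [0, 99, 100, 550, 1000]:
    print(s, f"{lr_at(s, 1e-4, 100, 1000):.2e}")
```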
</div>
</div>
</div>
</div>
<div id="tabset-8-4" class="tab-pane" role="tabpanel" aria-labelledby="tabset-8-4-tab">
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center" data-bs-toggle="collapse" data-bs-target=".callout-20-contents" aria-controls="callout-20" aria-expanded="true" aria-label="Toggle callout">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Resource Monitoring
</div>
<div class="callout-btn-toggle d-inline-block border-0 py-1 ps-1 pe-0 float-end"><i class="callout-toggle"></i></div>
</div>
<div id="callout-20" class="callout-20-contents callout-collapse collapse show">
<div class="callout-body-container callout-body">
<ul class="task-list">
<li><label><input type="checkbox"><strong>Monitor GPU memory usage</strong></label>
<ul>
<li>Track memory consumption</li>
<li>Optimize memory allocation</li>
</ul></li>
<li><label><input type="checkbox"><strong>Check system RAM</strong></label>
<ul>
<li>Monitor system memory</li>
<li>Identify memory leaks</li>
</ul></li>
<li><label><input type="checkbox"><strong>Verify disk space</strong></label>
<ul>
<li>Check storage availability</li>
<li>Monitor checkpoint sizes</li>
</ul></li>
<li><label><input type="checkbox"><strong>Monitor temperature/throttling</strong></label>
<ul>
<li>Check GPU temperatures</li>
<li>Detect thermal throttling</li>
</ul></li>
</ul>
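<p>For the disk-space item, Python's standard library can report free space, and an adapter checkpoint's size can be estimated from its parameter count. In this sketch, the 2-bytes-per-parameter default assumes fp16/bf16 storage, and the safety margin and adapter size are arbitrary choices.</p>

```python
import shutil


def adapter_size_bytes(num_params: int, bytes_per_param: int = 2) -> int:
    """Approximate on-disk size of an adapter (2 bytes/param for fp16/bf16)."""
    return num_params * bytes_per_param


def enough_disk(path: str, needed_bytes: int, margin: float = 2.0) -> bool:
    """True if free space at path covers needed_bytes times a safety margin."""
    return shutil.disk_usage(path).free >= needed_bytes * margin


# e.g. keeping 5 checkpoints of a hypothetical 3M-parameter adapter
needed = 5 * adapter_size_bytes(3_000_000)
print(f"{needed / 1e6:.1f} MB needed, enough space: {enough_disk('.', needed)}")
```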
</div>
</div>
</div>
</div>
</div>
</div>
</div>
<div id="tabset-9-3" class="tab-pane" role="tabpanel" aria-labelledby="tabset-9-3-tab">
<section id="lora-debugging-analysis" class="level4">
<h4 class="anchored" data-anchor-id="lora-debugging-analysis">LoRA Debugging Analysis</h4>
<div class="grid">
<div class="g-col-6">
<p><strong>Adapter Information:</strong></p>
<ul>
<li><strong>Name:</strong> medical_vqa_adapter</li>
<li><strong>Health Status:</strong> 🟢 Healthy</li>
</ul>
</div>
<div class="g-col-6">
<p><strong>Rank Utilization Summary:</strong></p>
<ul>
<li><strong>Mean:</strong> 0.537</li>
<li><strong>Std Dev:</strong> 0.184</li>
<li><strong>Range:</strong> 0.250 - 0.812</li>
</ul>
</div>
</div>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>💡 Recommendation
</div>
</div>
<div class="callout-body-container callout-body">
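<p>No issues were flagged by the current checks: mean rank utilization (0.537) is above the 0.3 threshold, and its standard deviation (0.184) is below the 0.3 imbalance limit.</p>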
</div>
</div>
</section>
</div>
</div>
</div>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Quick Summary
</div>
</div>
<div class="callout-body-container callout-body">
<table class="caption-top table">
<colgroup>
<col style="width: 25%">
<col style="width: 37%">
<col style="width: 37%">
</colgroup>
<thead>
<tr class="header">
<th>Issue</th>
<th>Symptoms</th>
<th>Solution</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><strong>Gradient Explosion</strong></td>
<td>Loss spikes, NaN values</td>
<td>Apply gradient clipping</td>
</tr>
<tr class="even">
<td><strong>Slow Convergence</strong></td>
<td>Plateau in loss</td>
<td>Adjust learning rate</td>
</tr>
<tr class="odd">
<td><strong>Memory Issues</strong></td>
<td>OOM errors</td>
<td>Reduce batch size, use gradient accumulation</td>
</tr>
<tr class="even">
<td><strong>Overfitting</strong></td>
<td>Train/val loss divergence</td>
<td>Add regularization, reduce rank</td>
</tr>
<tr class="odd">
<td><strong>Poor Performance</strong></td>
<td>Low accuracy</td>
<td>Increase rank, check target modules</td>
</tr>
</tbody>
</table>
</div>
</div>
</section>
<section id="additional-resources" class="level3">
<h3 class="anchored" data-anchor-id="additional-resources" id="additional-resources">Additional Resources</h3>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center collapsed" data-bs-toggle="collapse" data-bs-target=".callout-23-contents" aria-controls="callout-23" aria-expanded="false" aria-label="Toggle callout">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Useful Commands
</div>
<div class="callout-btn-toggle d-inline-block border-0 py-1 ps-1 pe-0 float-end"><i class="callout-toggle"></i></div>
</div>
<div id="callout-23" class="callout-23-contents callout-collapse collapse">
<div class="callout-body-container callout-body">
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb22"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb22-1"><a href="#cb22-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Monitor GPU usage</span></span>
<span id="cb22-2"><a href="#cb22-2" aria-hidden="true" tabindex="-1"></a><span class="ex">nvidia-smi</span> <span class="at">-l</span> 1</span>
<span id="cb22-3"><a href="#cb22-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb22-4"><a href="#cb22-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Check disk space</span></span>
<span id="cb22-5"><a href="#cb22-5" aria-hidden="true" tabindex="-1"></a><span class="fu">df</span> <span class="at">-h</span></span>
<span id="cb22-6"><a href="#cb22-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb22-7"><a href="#cb22-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Monitor system resources</span></span>
<span id="cb22-8"><a href="#cb22-8" aria-hidden="true" tabindex="-1"></a><span class="ex">htop</span></span></code></pre></div></div>
</div>
</div>
</div>
</section>
<section id="debugging-tools-1" class="level3">
<h3 class="anchored" data-anchor-id="debugging-tools-1" id="debugging-tools-1">Debugging Tools</h3>
<div id="debugging-tools" class="cell" data-execution_count="16">
<details class="code-fold">
<summary>Code</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb23"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb23-1"><a href="#cb23-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> LoRADebugger:</span>
<span id="cb23-2"><a href="#cb23-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, model, adapter_name<span class="op">=</span><span class="st">"default"</span>):</span>
<span id="cb23-3"><a href="#cb23-3" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> model</span>
<span id="cb23-4"><a href="#cb23-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.adapter_name <span class="op">=</span> adapter_name</span>
<span id="cb23-5"><a href="#cb23-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.analysis_cache <span class="op">=</span> {}</span>
<span id="cb23-6"><a href="#cb23-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb23-7"><a href="#cb23-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> analyze_lora_weights(<span class="va">self</span>):</span>
<span id="cb23-8"><a href="#cb23-8" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Analyze LoRA weight distributions"""</span></span>
<span id="cb23-9"><a href="#cb23-9" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="st">'weight_analysis'</span> <span class="kw">in</span> <span class="va">self</span>.analysis_cache:</span>
<span id="cb23-10"><a href="#cb23-10" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> <span class="va">self</span>.analysis_cache[<span class="st">'weight_analysis'</span>]</span>
<span id="cb23-11"><a href="#cb23-11" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb23-12"><a href="#cb23-12" aria-hidden="true" tabindex="-1"></a>        stats <span class="op">=</span> {}</span>
<span id="cb23-13"><a href="#cb23-13" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb23-14"><a href="#cb23-14" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Simulate analysis for demonstration</span></span>
<span id="cb23-15"><a href="#cb23-15" aria-hidden="true" tabindex="-1"></a>        module_names <span class="op">=</span> [<span class="st">"attention.q_proj"</span>, <span class="st">"attention.k_proj"</span>, <span class="st">"attention.v_proj"</span>, </span>
<span id="cb23-16"><a href="#cb23-16" aria-hidden="true" tabindex="-1"></a>                       <span class="st">"mlp.fc1"</span>, <span class="st">"mlp.fc2"</span>]</span>
<span id="cb23-17"><a href="#cb23-17" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb23-18"><a href="#cb23-18" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> name <span class="kw">in</span> module_names:</span>
<span id="cb23-19"><a href="#cb23-19" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Simulate weight statistics</span></span>
<span id="cb23-20"><a href="#cb23-20" aria-hidden="true" tabindex="-1"></a>            lora_A_norm <span class="op">=</span> np.random.uniform(<span class="fl">0.1</span>, <span class="fl">2.0</span>)</span>
<span id="cb23-21"><a href="#cb23-21" aria-hidden="true" tabindex="-1"></a>            lora_B_norm <span class="op">=</span> np.random.uniform(<span class="fl">0.1</span>, <span class="fl">2.0</span>)</span>
<span id="cb23-22"><a href="#cb23-22" aria-hidden="true" tabindex="-1"></a>            effective_rank <span class="op">=</span> np.random.randint(<span class="dv">4</span>, <span class="dv">16</span>)</span>
<span id="cb23-23"><a href="#cb23-23" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb23-24"><a href="#cb23-24" aria-hidden="true" tabindex="-1"></a>            stats[name] <span class="op">=</span> {</span>
<span id="cb23-25"><a href="#cb23-25" aria-hidden="true" tabindex="-1"></a>                <span class="st">"lora_A_norm"</span>: lora_A_norm,</span>
<span id="cb23-26"><a href="#cb23-26" aria-hidden="true" tabindex="-1"></a>                <span class="st">"lora_B_norm"</span>: lora_B_norm,</span>
<span id="cb23-27"><a href="#cb23-27" aria-hidden="true" tabindex="-1"></a>                <span class="st">"effective_rank"</span>: effective_rank,</span>
<span id="cb23-28"><a href="#cb23-28" aria-hidden="true" tabindex="-1"></a>                <span class="st">"rank_utilization"</span>: effective_rank <span class="op">/</span> <span class="fl">16.0</span>  <span class="co"># simulation assumes a configured rank of 16</span></span>
<span id="cb23-29"><a href="#cb23-29" aria-hidden="true" tabindex="-1"></a>            }</span>
<span id="cb23-30"><a href="#cb23-30" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb23-31"><a href="#cb23-31" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.analysis_cache[<span class="st">'weight_analysis'</span>] <span class="op">=</span> stats</span>
<span id="cb23-32"><a href="#cb23-32" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> stats</span>
<span id="cb23-33"><a href="#cb23-33" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb23-34"><a href="#cb23-34" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> compute_rank_utilization(<span class="va">self</span>, threshold<span class="op">=</span><span class="fl">0.01</span>):</span>
<span id="cb23-35"><a href="#cb23-35" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Compute rank utilization across modules (threshold is unused in this simulation)"""</span></span>
<span id="cb23-36"><a href="#cb23-36" aria-hidden="true" tabindex="-1"></a>        weight_stats <span class="op">=</span> <span class="va">self</span>.analyze_lora_weights()</span>
<span id="cb23-37"><a href="#cb23-37" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb23-38"><a href="#cb23-38" aria-hidden="true" tabindex="-1"></a>        utilizations <span class="op">=</span> []</span>
<span id="cb23-39"><a href="#cb23-39" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> module_name, stats <span class="kw">in</span> weight_stats.items():</span>
<span id="cb23-40"><a href="#cb23-40" aria-hidden="true" tabindex="-1"></a>            utilizations.append(stats[<span class="st">"rank_utilization"</span>])</span>
<span id="cb23-41"><a href="#cb23-41" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb23-42"><a href="#cb23-42" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {</span>
<span id="cb23-43"><a href="#cb23-43" aria-hidden="true" tabindex="-1"></a>            <span class="st">"mean_utilization"</span>: np.mean(utilizations),</span>
<span id="cb23-44"><a href="#cb23-44" aria-hidden="true" tabindex="-1"></a>            <span class="st">"std_utilization"</span>: np.std(utilizations),</span>
<span id="cb23-45"><a href="#cb23-45" aria-hidden="true" tabindex="-1"></a>            <span class="st">"min_utilization"</span>: np.<span class="bu">min</span>(utilizations),</span>
<span id="cb23-46"><a href="#cb23-46" aria-hidden="true" tabindex="-1"></a>            <span class="st">"max_utilization"</span>: np.<span class="bu">max</span>(utilizations),</span>
<span id="cb23-47"><a href="#cb23-47" aria-hidden="true" tabindex="-1"></a>            <span class="st">"per_module"</span>: {name: stats[<span class="st">"rank_utilization"</span>] </span>
<span id="cb23-48"><a href="#cb23-48" aria-hidden="true" tabindex="-1"></a>                          <span class="cf">for</span> name, stats <span class="kw">in</span> weight_stats.items()}</span>
<span id="cb23-49"><a href="#cb23-49" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb23-50"><a href="#cb23-50" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb23-51"><a href="#cb23-51" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> generate_health_report(<span class="va">self</span>):</span>
<span id="cb23-52"><a href="#cb23-52" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Generate comprehensive health report"""</span></span>
<span id="cb23-53"><a href="#cb23-53" aria-hidden="true" tabindex="-1"></a>        weight_analysis <span class="op">=</span> <span class="va">self</span>.analyze_lora_weights()</span>
<span id="cb23-54"><a href="#cb23-54" aria-hidden="true" tabindex="-1"></a>        rank_utilization <span class="op">=</span> <span class="va">self</span>.compute_rank_utilization()</span>
<span id="cb23-55"><a href="#cb23-55" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb23-56"><a href="#cb23-56" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Identify potential issues</span></span>
<span id="cb23-57"><a href="#cb23-57" aria-hidden="true" tabindex="-1"></a>        issues <span class="op">=</span> []</span>
<span id="cb23-58"><a href="#cb23-58" aria-hidden="true" tabindex="-1"></a>        warnings <span class="op">=</span> []</span>
<span id="cb23-59"><a href="#cb23-59" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb23-60"><a href="#cb23-60" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Check for very low rank utilization</span></span>
<span id="cb23-61"><a href="#cb23-61" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> rank_utilization[<span class="st">"mean_utilization"</span>] <span class="op">&lt;</span> <span class="fl">0.3</span>:</span>
<span id="cb23-62"><a href="#cb23-62" aria-hidden="true" tabindex="-1"></a>            issues.append(<span class="st">"Low average rank utilization - consider reducing rank"</span>)</span>
<span id="cb23-63"><a href="#cb23-63" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb23-64"><a href="#cb23-64" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Check for very high weight norms</span></span>
<span id="cb23-65"><a href="#cb23-65" aria-hidden="true" tabindex="-1"></a>        high_norm_modules <span class="op">=</span> [name <span class="cf">for</span> name, stats <span class="kw">in</span> weight_analysis.items() </span>
<span id="cb23-66"><a href="#cb23-66" aria-hidden="true" tabindex="-1"></a>                           <span class="cf">if</span> stats[<span class="st">"lora_A_norm"</span>] <span class="op">&gt;</span> <span class="fl">5.0</span> <span class="kw">or</span> stats[<span class="st">"lora_B_norm"</span>] <span class="op">&gt;</span> <span class="fl">5.0</span>]</span>
<span id="cb23-67"><a href="#cb23-67" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> high_norm_modules:</span>
<span id="cb23-68"><a href="#cb23-68" aria-hidden="true" tabindex="-1"></a>            warnings.append(<span class="ss">f"High weight norms in modules: </span><span class="sc">{</span><span class="st">', '</span><span class="sc">.</span>join(high_norm_modules)<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb23-69"><a href="#cb23-69" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb23-70"><a href="#cb23-70" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Check for rank imbalance</span></span>
<span id="cb23-71"><a href="#cb23-71" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> rank_utilization[<span class="st">"std_utilization"</span>] <span class="op">&gt;</span> <span class="fl">0.3</span>:</span>
<span id="cb23-72"><a href="#cb23-72" aria-hidden="true" tabindex="-1"></a>            warnings.append(<span class="st">"High variance in rank utilization across modules"</span>)</span>
<span id="cb23-73"><a href="#cb23-73" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb23-74"><a href="#cb23-74" aria-hidden="true" tabindex="-1"></a>        report <span class="op">=</span> {</span>
<span id="cb23-75"><a href="#cb23-75" aria-hidden="true" tabindex="-1"></a>            <span class="st">"adapter_name"</span>: <span class="va">self</span>.adapter_name,</span>
<span id="cb23-76"><a href="#cb23-76" aria-hidden="true" tabindex="-1"></a>            <span class="st">"weight_analysis"</span>: weight_analysis,</span>
<span id="cb23-77"><a href="#cb23-77" aria-hidden="true" tabindex="-1"></a>            <span class="st">"rank_utilization"</span>: rank_utilization,</span>
<span id="cb23-78"><a href="#cb23-78" aria-hidden="true" tabindex="-1"></a>            <span class="st">"health_status"</span>: <span class="st">"healthy"</span> <span class="cf">if</span> <span class="kw">not</span> issues <span class="cf">else</span> <span class="st">"needs_attention"</span>,</span>
<span id="cb23-79"><a href="#cb23-79" aria-hidden="true" tabindex="-1"></a>            <span class="st">"issues"</span>: issues,</span>
<span id="cb23-80"><a href="#cb23-80" aria-hidden="true" tabindex="-1"></a>            <span class="st">"warnings"</span>: warnings,</span>
<span id="cb23-81"><a href="#cb23-81" aria-hidden="true" tabindex="-1"></a>            <span class="st">"recommendations"</span>: <span class="va">self</span>._generate_recommendations(issues, warnings)</span>
<span id="cb23-82"><a href="#cb23-82" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb23-83"><a href="#cb23-83" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb23-84"><a href="#cb23-84" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> report</span>
<span id="cb23-85"><a href="#cb23-85" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb23-86"><a href="#cb23-86" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _generate_recommendations(<span class="va">self</span>, issues, warnings):</span>
<span id="cb23-87"><a href="#cb23-87" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Generate recommendations based on analysis"""</span></span>
<span id="cb23-88"><a href="#cb23-88" aria-hidden="true" tabindex="-1"></a>        recommendations <span class="op">=</span> []</span>
<span id="cb23-89"><a href="#cb23-89" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb23-90"><a href="#cb23-90" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="bu">any</span>(<span class="st">"rank utilization"</span> <span class="kw">in</span> issue <span class="cf">for</span> issue <span class="kw">in</span> issues):</span>
<span id="cb23-91"><a href="#cb23-91" aria-hidden="true" tabindex="-1"></a>            recommendations.append(<span class="st">"Consider reducing LoRA rank to improve efficiency"</span>)</span>
<span id="cb23-92"><a href="#cb23-92" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb23-93"><a href="#cb23-93" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="bu">any</span>(<span class="st">"weight norms"</span> <span class="kw">in</span> warning <span class="cf">for</span> warning <span class="kw">in</span> warnings):</span>
<span id="cb23-94"><a href="#cb23-94" aria-hidden="true" tabindex="-1"></a>            recommendations.append(<span class="st">"Apply stronger weight regularization or gradient clipping"</span>)</span>
<span id="cb23-95"><a href="#cb23-95" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb23-96"><a href="#cb23-96" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="bu">any</span>(<span class="st">"variance"</span> <span class="kw">in</span> warning <span class="cf">for</span> warning <span class="kw">in</span> warnings):</span>
<span id="cb23-97"><a href="#cb23-97" aria-hidden="true" tabindex="-1"></a>            recommendations.append(<span class="st">"Use different ranks for different module types"</span>)</span>
<span id="cb23-98"><a href="#cb23-98" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb23-99"><a href="#cb23-99" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="kw">not</span> issues <span class="kw">and</span> <span class="kw">not</span> warnings:</span>
<span id="cb23-100"><a href="#cb23-100" aria-hidden="true" tabindex="-1"></a>            recommendations.append(<span class="st">"LoRA configuration appears optimal"</span>)</span>
<span id="cb23-101"><a href="#cb23-101" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb23-102"><a href="#cb23-102" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> recommendations</span>
<span id="cb23-103"><a href="#cb23-103" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb23-104"><a href="#cb23-104" aria-hidden="true" tabindex="-1"></a><span class="co"># Debugging demonstration</span></span>
<span id="cb23-105"><a href="#cb23-105" aria-hidden="true" tabindex="-1"></a>debugger <span class="op">=</span> LoRADebugger(<span class="va">None</span>, <span class="st">"medical_vqa_adapter"</span>)  <span class="co"># Placeholder model; pass the fine-tuned model in practice</span></span>
<span id="cb23-106"><a href="#cb23-106" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb23-107"><a href="#cb23-107" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"LoRA Debugging Analysis:"</span>)</span>
<span id="cb23-108"><a href="#cb23-108" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"="</span> <span class="op">*</span> <span class="dv">25</span>)</span>
<span id="cb23-109"><a href="#cb23-109" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb23-110"><a href="#cb23-110" aria-hidden="true" tabindex="-1"></a><span class="co"># Generate health report</span></span>
<span id="cb23-111"><a href="#cb23-111" aria-hidden="true" tabindex="-1"></a>health_report <span class="op">=</span> debugger.generate_health_report()</span>
<span id="cb23-112"><a href="#cb23-112" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb23-113"><a href="#cb23-113" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Adapter: </span><span class="sc">{</span>health_report[<span class="st">'adapter_name'</span>]<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb23-114"><a href="#cb23-114" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Health Status: </span><span class="sc">{</span>health_report[<span class="st">'health_status'</span>]<span class="sc">.</span>title()<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb23-115"><a href="#cb23-115" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb23-116"><a href="#cb23-116" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"</span><span class="ch">\n</span><span class="st">Rank Utilization Summary:"</span>)</span>
<span id="cb23-117"><a href="#cb23-117" aria-hidden="true" tabindex="-1"></a>rank_util <span class="op">=</span> health_report[<span class="st">'rank_utilization'</span>]</span>
<span id="cb23-118"><a href="#cb23-118" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"  Mean: </span><span class="sc">{</span>rank_util[<span class="st">'mean_utilization'</span>]<span class="sc">:.3f}</span><span class="ss">"</span>)</span>
<span id="cb23-119"><a href="#cb23-119" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"  Std:  </span><span class="sc">{</span>rank_util[<span class="st">'std_utilization'</span>]<span class="sc">:.3f}</span><span class="ss">"</span>)</span>
<span id="cb23-120"><a href="#cb23-120" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"  Range: </span><span class="sc">{</span>rank_util[<span class="st">'min_utilization'</span>]<span class="sc">:.3f}</span><span class="ss"> - </span><span class="sc">{</span>rank_util[<span class="st">'max_utilization'</span>]<span class="sc">:.3f}</span><span class="ss">"</span>)</span>
<span id="cb23-121"><a href="#cb23-121" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb23-122"><a href="#cb23-122" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> health_report[<span class="st">'issues'</span>]:</span>
<span id="cb23-123"><a href="#cb23-123" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="st">"</span><span class="ch">\n</span><span class="st">Issues Found:"</span>)</span>
<span id="cb23-124"><a href="#cb23-124" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> issue <span class="kw">in</span> health_report[<span class="st">'issues'</span>]:</span>
<span id="cb23-125"><a href="#cb23-125" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"  ❌ </span><span class="sc">{</span>issue<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb23-126"><a href="#cb23-126" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb23-127"><a href="#cb23-127" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> health_report[<span class="st">'warnings'</span>]:</span>
<span id="cb23-128"><a href="#cb23-128" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="st">"</span><span class="ch">\n</span><span class="st">Warnings:"</span>)</span>
<span id="cb23-129"><a href="#cb23-129" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> warning <span class="kw">in</span> health_report[<span class="st">'warnings'</span>]:</span>
<span id="cb23-130"><a href="#cb23-130" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"  ⚠️  </span><span class="sc">{</span>warning<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb23-131"><a href="#cb23-131" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb23-132"><a href="#cb23-132" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"</span><span class="ch">\n</span><span class="st">Recommendations:"</span>)</span>
<span id="cb23-133"><a href="#cb23-133" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> rec <span class="kw">in</span> health_report[<span class="st">'recommendations'</span>]:</span>
<span id="cb23-134"><a href="#cb23-134" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"  💡 </span><span class="sc">{</span>rec<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-stdout">
<pre><code>LoRA Debugging Analysis:
=========================
Adapter: medical_vqa_adapter
Health Status: Healthy

Rank Utilization Summary:
  Mean: 0.475
  Std:  0.211
  Range: 0.250 - 0.750

Recommendations:
  💡 LoRA configuration appears optimal</code></pre>
</div>
</div>
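<p>One way to make "rank utilization" concrete is to look at the singular values of the learned update ΔW = BA. The sketch below uses a hypothetical definition that is not necessarily what <code>compute_rank_utilization()</code> does: it counts the singular values above a relative tolerance (<code>rel_tol</code>, an assumed parameter) and divides that count by the configured rank r. The shapes follow the usual LoRA convention of A being r × d_in and B being d_out × r.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import numpy as np


def rank_utilization(lora_A, lora_B, rel_tol=1e-2):
    """Fraction of the configured LoRA rank that the update actually uses.

    Hypothetical definition for illustration: count singular values of
    delta_W = B @ A above rel_tol * (largest singular value), divided by r.
    Assumed shapes: lora_A is (r, d_in), lora_B is (d_out, r).
    """
    r = lora_A.shape[0]
    delta_w = lora_B @ lora_A                      # (d_out, d_in), rank at most r
    s = np.linalg.svd(delta_w, compute_uv=False)   # singular values, descending
    if s[0] == 0.0:                                # e.g. freshly initialized adapter (B = 0)
        return 0.0
    effective_rank = int(np.sum(s > rel_tol * s[0]))
    return effective_rank / r


# Toy example: a rank-8 adapter where only 2 of the 8 directions carry signal
rng = np.random.default_rng(0)
A = rng.standard_normal((8, 64))
B = np.zeros((64, 8))
B[:, :2] = rng.standard_normal((64, 2))
print(rank_utilization(A, B))   # 0.25</code></pre></div>
<p>A value near 1.0 means every rank direction carries signal; a value like the 0.25 above suggests a smaller rank would likely suffice, which is the situation the debugger's 0.3 threshold is meant to flag. The tolerance is a judgment call, and the measured utilization shifts with it.</p>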
</section>
</section>
<section id="production-deployment" class="level2">
<h2 class="anchored" data-anchor-id="production-deployment" id="production-deployment">Production Deployment</h2>
<section id="model-management-system" class="level3">
<h3 class="anchored" data-anchor-id="model-management-system" id="model-management-system">Model Management System</h3>
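<p>The <code>load_adapter</code> method below estimates an adapter's size as <code>rank * 768 * 2</code> per target module. That figure follows from the LoRA factorization: a d_out × d_in projection gains A (r × d_in) and B (d_out × r), i.e. r · (d_in + d_out) trainable parameters, which equals r × 768 × 2 when both dimensions are 768. The hidden size of 768 is an assumption baked into that code; this small sketch (with hypothetical helper names <code>lora_param_count</code> and <code>full_param_count</code>) makes the dimensions explicit.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">def lora_param_count(rank, d_in, d_out, num_modules=1):
    """Trainable parameters LoRA adds: A is (rank, d_in), B is (d_out, rank)."""
    return rank * (d_in + d_out) * num_modules


def full_param_count(d_in, d_out, num_modules=1):
    """Parameters in the frozen d_out x d_in projection weights (bias ignored)."""
    return d_in * d_out * num_modules


# Defaults from LoRAModelManager: rank 16 on q_proj, k_proj, v_proj; 768-dim projections assumed
lora = lora_param_count(16, 768, 768, num_modules=3)
full = full_param_count(768, 768, num_modules=3)
print(lora, full, f"{lora / full:.2%}")   # 73728 1769472 4.17%</code></pre></div>
<p>With rank 16 on three 768 × 768 projections, the adapter adds roughly 4% of those layers' parameters, and because the count is linear in r, doubling the rank doubles that share.</p>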
<div id="production-deployment" class="cell" data-execution_count="17">
<details class="code-fold">
<summary>Code</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb25"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb25-1"><a href="#cb25-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb25-2"><a href="#cb25-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> typing <span class="im">import</span> Dict, Any, Optional, Union</span>
<span id="cb25-3"><a href="#cb25-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> contextlib <span class="im">import</span> contextmanager</span>
<span id="cb25-4"><a href="#cb25-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> logging</span>
<span id="cb25-5"><a href="#cb25-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb25-6"><a href="#cb25-6" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> LoRAModelManager:</span>
<span id="cb25-7"><a href="#cb25-7" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Production-ready LoRA model management system"""</span></span>
<span id="cb25-8"><a href="#cb25-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb25-9"><a href="#cb25-9" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, base_model_path: <span class="bu">str</span>, device: <span class="bu">str</span> <span class="op">=</span> <span class="st">"auto"</span>):</span>
<span id="cb25-10"><a href="#cb25-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.base_model_path <span class="op">=</span> base_model_path</span>
<span id="cb25-11"><a href="#cb25-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.device <span class="op">=</span> <span class="va">self</span>._setup_device(device)</span>
<span id="cb25-12"><a href="#cb25-12" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.base_model <span class="op">=</span> <span class="va">None</span></span>
<span id="cb25-13"><a href="#cb25-13" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.active_adapters <span class="op">=</span> {}</span>
<span id="cb25-14"><a href="#cb25-14" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.adapter_configs <span class="op">=</span> {}</span>
<span id="cb25-15"><a href="#cb25-15" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb25-16"><a href="#cb25-16" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Performance monitoring</span></span>
<span id="cb25-17"><a href="#cb25-17" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.request_count <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb25-18"><a href="#cb25-18" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.total_inference_time <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb25-19"><a href="#cb25-19" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.error_count <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb25-20"><a href="#cb25-20" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb25-21"><a href="#cb25-21" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Setup logging</span></span>
<span id="cb25-22"><a href="#cb25-22" aria-hidden="true" tabindex="-1"></a>        logging.basicConfig(level<span class="op">=</span>logging.INFO)</span>
<span id="cb25-23"><a href="#cb25-23" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.logger <span class="op">=</span> logging.getLogger(<span class="va">__name__</span>)</span>
<span id="cb25-24"><a href="#cb25-24" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb25-25"><a href="#cb25-25" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="st">"LoRA Model Manager initialized"</span>)</span>
<span id="cb25-26"><a href="#cb25-26" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Device: </span><span class="sc">{</span><span class="va">self</span><span class="sc">.</span>device<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb25-27"><a href="#cb25-27" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb25-28"><a href="#cb25-28" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _setup_device(<span class="va">self</span>, device: <span class="bu">str</span>) <span class="op">-&gt;</span> <span class="bu">str</span>:</span>
<span id="cb25-29"><a href="#cb25-29" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Setup compute device"""</span></span>
<span id="cb25-30"><a href="#cb25-30" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> device <span class="op">==</span> <span class="st">"auto"</span>:</span>
<span id="cb25-31"><a href="#cb25-31" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> torch.cuda.is_available():</span>
<span id="cb25-32"><a href="#cb25-32" aria-hidden="true" tabindex="-1"></a>                <span class="cf">return</span> <span class="st">"cuda"</span></span>
<span id="cb25-33"><a href="#cb25-33" aria-hidden="true" tabindex="-1"></a>            <span class="cf">else</span>:</span>
<span id="cb25-34"><a href="#cb25-34" aria-hidden="true" tabindex="-1"></a>                <span class="cf">return</span> <span class="st">"cpu"</span></span>
<span id="cb25-35"><a href="#cb25-35" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> device</span>
<span id="cb25-36"><a href="#cb25-36" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb25-37"><a href="#cb25-37" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> load_adapter(<span class="va">self</span>, adapter_name: <span class="bu">str</span>, adapter_path: <span class="bu">str</span>, config: Optional[Dict] <span class="op">=</span> <span class="va">None</span>):</span>
<span id="cb25-38"><a href="#cb25-38" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Load a LoRA adapter"""</span></span>
<span id="cb25-39"><a href="#cb25-39" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.logger.info(<span class="ss">f"Loading adapter '</span><span class="sc">{</span>adapter_name<span class="sc">}</span><span class="ss">' from </span><span class="sc">{</span>adapter_path<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb25-40"><a href="#cb25-40" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb25-41"><a href="#cb25-41" aria-hidden="true" tabindex="-1"></a>        default_config <span class="op">=</span> {</span>
<span id="cb25-42"><a href="#cb25-42" aria-hidden="true" tabindex="-1"></a>            <span class="st">"rank"</span>: <span class="dv">16</span>,</span>
<span id="cb25-43"><a href="#cb25-43" aria-hidden="true" tabindex="-1"></a>            <span class="st">"alpha"</span>: <span class="dv">16</span>,</span>
<span id="cb25-44"><a href="#cb25-44" aria-hidden="true" tabindex="-1"></a>            <span class="st">"target_modules"</span>: [<span class="st">"q_proj"</span>, <span class="st">"k_proj"</span>, <span class="st">"v_proj"</span>],</span>
<span id="cb25-45"><a href="#cb25-45" aria-hidden="true" tabindex="-1"></a>            <span class="st">"task_type"</span>: <span class="st">"multimodal"</span></span>
<span id="cb25-46"><a href="#cb25-46" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb25-47"><a href="#cb25-47" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb25-48"><a href="#cb25-48" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Merge defaults with provided config</span></span>
<span id="cb25-49"><a href="#cb25-49" aria-hidden="true" tabindex="-1"></a>        adapter_config <span class="op">=</span> {<span class="op">**</span>default_config, <span class="op">**</span>(config <span class="kw">or</span> {})}</span>
<span id="cb25-50"><a href="#cb25-50" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb25-51"><a href="#cb25-51" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Store adapter metadata (a real implementation would load the actual weights)</span></span>
<span id="cb25-52"><a href="#cb25-52" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.active_adapters[adapter_name] <span class="op">=</span> {</span>
<span id="cb25-53"><a href="#cb25-53" aria-hidden="true" tabindex="-1"></a>            <span class="st">"path"</span>: adapter_path,</span>
<span id="cb25-54"><a href="#cb25-54" aria-hidden="true" tabindex="-1"></a>            <span class="st">"loaded_at"</span>: time.time(),</span>
<span id="cb25-55"><a href="#cb25-55" aria-hidden="true" tabindex="-1"></a>            <span class="st">"parameters"</span>: adapter_config[<span class="st">"rank"</span>] <span class="op">*</span> <span class="dv">768</span> <span class="op">*</span> <span class="dv">2</span> <span class="op">*</span> <span class="bu">len</span>(adapter_config[<span class="st">"target_modules"</span>])  <span class="co"># rank * (d_in + d_out) per module; assumes 768x768 projections</span></span>
<span id="cb25-56"><a href="#cb25-56" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb25-57"><a href="#cb25-57" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.adapter_configs[adapter_name] <span class="op">=</span> adapter_config</span>
<span id="cb25-58"><a href="#cb25-58" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb25-59"><a href="#cb25-59" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.logger.info(<span class="ss">f"Adapter '</span><span class="sc">{</span>adapter_name<span class="sc">}</span><span class="ss">' loaded successfully"</span>)</span>
<span id="cb25-60"><a href="#cb25-60" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">True</span></span>
<span id="cb25-61"><a href="#cb25-61" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb25-62"><a href="#cb25-62" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb25-63"><a href="#cb25-63" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> unload_adapter(<span class="va">self</span>, adapter_name: <span class="bu">str</span>):</span>
<span id="cb25-64"><a href="#cb25-64" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Unload a LoRA adapter to free memory"""</span></span>
<span id="cb25-65"><a href="#cb25-65" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> adapter_name <span class="kw">in</span> <span class="va">self</span>.active_adapters:</span>
<span id="cb25-66"><a href="#cb25-66" aria-hidden="true" tabindex="-1"></a>            <span class="kw">del</span> <span class="va">self</span>.active_adapters[adapter_name]</span>
<span id="cb25-67"><a href="#cb25-67" aria-hidden="true" tabindex="-1"></a>            <span class="kw">del</span> <span class="va">self</span>.adapter_configs[adapter_name]</span>
<span id="cb25-68"><a href="#cb25-68" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.logger.info(<span class="ss">f"Adapter '</span><span class="sc">{</span>adapter_name<span class="sc">}</span><span class="ss">' unloaded"</span>)</span>
<span id="cb25-69"><a href="#cb25-69" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> <span class="va">True</span></span>
<span id="cb25-70"><a href="#cb25-70" aria-hidden="true" tabindex="-1"></a>        <span class="cf">else</span>:</span>
<span id="cb25-71"><a href="#cb25-71" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.logger.warning(<span class="ss">f"Adapter '</span><span class="sc">{</span>adapter_name<span class="sc">}</span><span class="ss">' not found"</span>)</span>
<span id="cb25-72"><a href="#cb25-72" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> <span class="va">False</span></span>
<span id="cb25-73"><a href="#cb25-73" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb25-74"><a href="#cb25-74" aria-hidden="true" tabindex="-1"></a>    <span class="at">@contextmanager</span></span>
<span id="cb25-75"><a href="#cb25-75" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> use_adapter(<span class="va">self</span>, adapter_name: <span class="bu">str</span>):</span>
<span id="cb25-76"><a href="#cb25-76" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Context manager for temporarily using an adapter"""</span></span>
<span id="cb25-77"><a href="#cb25-77" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> adapter_name <span class="kw">not</span> <span class="kw">in</span> <span class="va">self</span>.active_adapters:</span>
<span id="cb25-78"><a href="#cb25-78" aria-hidden="true" tabindex="-1"></a>            <span class="cf">raise</span> <span class="pp">ValueError</span>(<span class="ss">f"Adapter '</span><span class="sc">{</span>adapter_name<span class="sc">}</span><span class="ss">' not loaded"</span>)</span>
<span id="cb25-79"><a href="#cb25-79" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb25-80"><a href="#cb25-80" aria-hidden="true" tabindex="-1"></a>        <span class="co"># A real implementation would apply the adapter weights here</span></span>
<span id="cb25-81"><a href="#cb25-81" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.logger.debug(<span class="ss">f"Applying adapter '</span><span class="sc">{</span>adapter_name<span class="sc">}</span><span class="ss">'"</span>)</span>
<span id="cb25-82"><a href="#cb25-82" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb25-83"><a href="#cb25-83" aria-hidden="true" tabindex="-1"></a>        <span class="cf">try</span>:</span>
<span id="cb25-84"><a href="#cb25-84" aria-hidden="true" tabindex="-1"></a>            <span class="cf">yield</span> adapter_name</span>
<span id="cb25-85"><a href="#cb25-85" aria-hidden="true" tabindex="-1"></a>        <span class="cf">finally</span>:</span>
<span id="cb25-86"><a href="#cb25-86" aria-hidden="true" tabindex="-1"></a>            <span class="co"># A real implementation would restore the base weights here</span></span>
<span id="cb25-87"><a href="#cb25-87" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.logger.debug(<span class="ss">f"Restored from adapter '</span><span class="sc">{</span>adapter_name<span class="sc">}</span><span class="ss">'"</span>)</span>
<span id="cb25-88"><a href="#cb25-88" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb25-89"><a href="#cb25-89" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> inference(<span class="va">self</span>, inputs: Dict[<span class="bu">str</span>, Any], adapter_name: Optional[<span class="bu">str</span>] <span class="op">=</span> <span class="va">None</span>) <span class="op">-&gt;</span> Dict[<span class="bu">str</span>, Any]:</span>
<span id="cb25-90"><a href="#cb25-90" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Perform inference with optional adapter"""</span></span>
<span id="cb25-91"><a href="#cb25-91" aria-hidden="true" tabindex="-1"></a>        start_time <span class="op">=</span> time.perf_counter()  <span class="co"># monotonic clock, suited to measuring durations</span></span>
<span id="cb25-92"><a href="#cb25-92" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb25-93"><a href="#cb25-93" aria-hidden="true" tabindex="-1"></a>        <span class="cf">try</span>:</span>
<span id="cb25-94"><a href="#cb25-94" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> adapter_name:</span>
<span id="cb25-95"><a href="#cb25-95" aria-hidden="true" tabindex="-1"></a>                <span class="cf">with</span> <span class="va">self</span>.use_adapter(adapter_name):</span>
<span id="cb25-96"><a href="#cb25-96" aria-hidden="true" tabindex="-1"></a>                    <span class="co"># Simulate inference with adapter</span></span>
<span id="cb25-97"><a href="#cb25-97" aria-hidden="true" tabindex="-1"></a>                    time.sleep(<span class="fl">0.01</span>)  <span class="co"># Simulate processing time</span></span>
<span id="cb25-98"><a href="#cb25-98" aria-hidden="true" tabindex="-1"></a>                    outputs <span class="op">=</span> {<span class="st">"prediction"</span>: <span class="st">"sample_output"</span>, <span class="st">"confidence"</span>: <span class="fl">0.95</span>}</span>
<span id="cb25-99"><a href="#cb25-99" aria-hidden="true" tabindex="-1"></a>            <span class="cf">else</span>:</span>
<span id="cb25-100"><a href="#cb25-100" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Simulate base model inference</span></span>
<span id="cb25-101"><a href="#cb25-101" aria-hidden="true" tabindex="-1"></a>                time.sleep(<span class="fl">0.008</span>)  <span class="co"># Slightly faster without adapter</span></span>
<span id="cb25-102"><a href="#cb25-102" aria-hidden="true" tabindex="-1"></a>                outputs <span class="op">=</span> {<span class="st">"prediction"</span>: <span class="st">"base_output"</span>, <span class="st">"confidence"</span>: <span class="fl">0.85</span>}</span>
<span id="cb25-103"><a href="#cb25-103" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb25-104"><a href="#cb25-104" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Update performance metrics</span></span>
<span id="cb25-105"><a href="#cb25-105" aria-hidden="true" tabindex="-1"></a>            inference_time <span class="op">=</span> time.perf_counter() <span class="op">-</span> start_time</span>
<span id="cb25-106"><a href="#cb25-106" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.request_count <span class="op">+=</span> <span class="dv">1</span></span>
<span id="cb25-107"><a href="#cb25-107" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.total_inference_time <span class="op">+=</span> inference_time</span>
<span id="cb25-108"><a href="#cb25-108" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb25-109"><a href="#cb25-109" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> {</span>
<span id="cb25-110"><a href="#cb25-110" aria-hidden="true" tabindex="-1"></a>                <span class="st">'outputs'</span>: outputs,</span>
<span id="cb25-111"><a href="#cb25-111" aria-hidden="true" tabindex="-1"></a>                <span class="st">'inference_time'</span>: inference_time,</span>
<span id="cb25-112"><a href="#cb25-112" aria-hidden="true" tabindex="-1"></a>                <span class="st">'adapter_used'</span>: adapter_name,</span>
<span id="cb25-113"><a href="#cb25-113" aria-hidden="true" tabindex="-1"></a>                <span class="st">'request_id'</span>: <span class="va">self</span>.request_count</span>
<span id="cb25-114"><a href="#cb25-114" aria-hidden="true" tabindex="-1"></a>            }</span>
<span id="cb25-115"><a href="#cb25-115" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb25-116"><a href="#cb25-116" aria-hidden="true" tabindex="-1"></a>        <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:</span>
<span id="cb25-117"><a href="#cb25-117" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.error_count <span class="op">+=</span> <span class="dv">1</span></span>
<span id="cb25-118"><a href="#cb25-118" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.logger.error(<span class="ss">f"Inference failed: </span><span class="sc">{</span>e<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb25-119"><a href="#cb25-119" aria-hidden="true" tabindex="-1"></a>            <span class="cf">raise</span></span>
<span id="cb25-120"><a href="#cb25-120" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb25-121"><a href="#cb25-121" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> get_performance_stats(<span class="va">self</span>) <span class="op">-&gt;</span> Dict[<span class="bu">str</span>, <span class="bu">float</span>]:</span>
<span id="cb25-122"><a href="#cb25-122" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Get performance statistics"""</span></span>
<span id="cb25-123"><a href="#cb25-123" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="va">self</span>.request_count <span class="op">+</span> <span class="va">self</span>.error_count <span class="op">==</span> <span class="dv">0</span>:</span>
<span id="cb25-124"><a href="#cb25-124" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> {<span class="st">'requests'</span>: <span class="dv">0</span>, <span class="st">'avg_time'</span>: <span class="dv">0</span>, <span class="st">'total_time'</span>: <span class="dv">0</span>, <span class="st">'requests_per_second'</span>: <span class="dv">0</span>, <span class="st">'error_rate'</span>: <span class="dv">0</span>, <span class="st">'error_count'</span>: <span class="dv">0</span>}</span>
<span id="cb25-125"><a href="#cb25-125" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb25-126"><a href="#cb25-126" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {</span>
<span id="cb25-127"><a href="#cb25-127" aria-hidden="true" tabindex="-1"></a>            <span class="st">'requests'</span>: <span class="va">self</span>.request_count,</span>
<span id="cb25-128"><a href="#cb25-128" aria-hidden="true" tabindex="-1"></a>            <span class="st">'avg_time'</span>: <span class="va">self</span>.total_inference_time <span class="op">/</span> <span class="va">self</span>.request_count <span class="cf">if</span> <span class="va">self</span>.request_count <span class="op">&gt;</span> <span class="dv">0</span> <span class="cf">else</span> <span class="dv">0</span>,</span>
<span id="cb25-129"><a href="#cb25-129" aria-hidden="true" tabindex="-1"></a>            <span class="st">'total_time'</span>: <span class="va">self</span>.total_inference_time,</span>
<span id="cb25-130"><a href="#cb25-130" aria-hidden="true" tabindex="-1"></a>            <span class="st">'requests_per_second'</span>: <span class="va">self</span>.request_count <span class="op">/</span> <span class="va">self</span>.total_inference_time <span class="cf">if</span> <span class="va">self</span>.total_inference_time <span class="op">&gt;</span> <span class="dv">0</span> <span class="cf">else</span> <span class="dv">0</span>,  <span class="co"># successes per second of summed inference time</span></span>
<span id="cb25-131"><a href="#cb25-131" aria-hidden="true" tabindex="-1"></a>            <span class="st">'error_rate'</span>: <span class="va">self</span>.error_count <span class="op">/</span> (<span class="va">self</span>.request_count <span class="op">+</span> <span class="va">self</span>.error_count),</span>
<span id="cb25-132"><a href="#cb25-132" aria-hidden="true" tabindex="-1"></a>            <span class="st">'error_count'</span>: <span class="va">self</span>.error_count</span>
<span id="cb25-133"><a href="#cb25-133" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb25-134"><a href="#cb25-134" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb25-135"><a href="#cb25-135" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> health_check(<span class="va">self</span>) <span class="op">-&gt;</span> Dict[<span class="bu">str</span>, Any]:</span>
<span id="cb25-136"><a href="#cb25-136" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Perform system health check"""</span></span>
<span id="cb25-137"><a href="#cb25-137" aria-hidden="true" tabindex="-1"></a>        health_status <span class="op">=</span> {</span>
<span id="cb25-138"><a href="#cb25-138" aria-hidden="true" tabindex="-1"></a>            <span class="st">'status'</span>: <span class="st">'healthy'</span>,</span>
<span id="cb25-139"><a href="#cb25-139" aria-hidden="true" tabindex="-1"></a>            <span class="st">'active_adapters'</span>: <span class="bu">list</span>(<span class="va">self</span>.active_adapters.keys()),</span>
<span id="cb25-140"><a href="#cb25-140" aria-hidden="true" tabindex="-1"></a>            <span class="st">'device'</span>: <span class="bu">str</span>(<span class="va">self</span>.device),</span>
<span id="cb25-141"><a href="#cb25-141" aria-hidden="true" tabindex="-1"></a>            <span class="st">'performance'</span>: <span class="va">self</span>.get_performance_stats(),</span>
<span id="cb25-142"><a href="#cb25-142" aria-hidden="true" tabindex="-1"></a>            <span class="st">'memory_usage'</span>: <span class="va">self</span>._get_memory_usage()</span>
<span id="cb25-143"><a href="#cb25-143" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb25-144"><a href="#cb25-144" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb25-145"><a href="#cb25-145" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Check for issues</span></span>
<span id="cb25-146"><a href="#cb25-146" aria-hidden="true" tabindex="-1"></a>        perf_stats <span class="op">=</span> health_status[<span class="st">'performance'</span>]</span>
<span id="cb25-147"><a href="#cb25-147" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> perf_stats[<span class="st">'error_rate'</span>] <span class="op">&gt;</span> <span class="fl">0.05</span>:  <span class="co"># 5% error threshold</span></span>
<span id="cb25-148"><a href="#cb25-148" aria-hidden="true" tabindex="-1"></a>            health_status[<span class="st">'status'</span>] <span class="op">=</span> <span class="st">'degraded'</span></span>
<span id="cb25-149"><a href="#cb25-149" aria-hidden="true" tabindex="-1"></a>            health_status.setdefault(<span class="st">'issues'</span>, []).append(<span class="st">'High error rate detected'</span>)</span>
<span id="cb25-150"><a href="#cb25-150" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb25-151"><a href="#cb25-151" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> perf_stats[<span class="st">'avg_time'</span>] <span class="op">&gt;</span> <span class="fl">1.0</span>:  <span class="co"># 1 second threshold</span></span>
<span id="cb25-152"><a href="#cb25-152" aria-hidden="true" tabindex="-1"></a>            health_status[<span class="st">'status'</span>] <span class="op">=</span> <span class="st">'degraded'</span></span>
<span id="cb25-153"><a href="#cb25-153" aria-hidden="true" tabindex="-1"></a>            health_status.setdefault(<span class="st">'issues'</span>, []).append(<span class="st">'High latency detected'</span>)</span>
<span id="cb25-154"><a href="#cb25-154" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb25-155"><a href="#cb25-155" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> health_status</span>
<span id="cb25-156"><a href="#cb25-156" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb25-157"><a href="#cb25-157" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _get_memory_usage(<span class="va">self</span>):</span>
<span id="cb25-158"><a href="#cb25-158" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Estimate adapter memory usage (placeholder, not a measurement)"""</span></span>
<span id="cb25-159"><a href="#cb25-159" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Placeholder: assumes a fixed memory cost per loaded adapter</span></span>
<span id="cb25-160"><a href="#cb25-160" aria-hidden="true" tabindex="-1"></a>        total_adapters <span class="op">=</span> <span class="bu">len</span>(<span class="va">self</span>.active_adapters)</span>
<span id="cb25-161"><a href="#cb25-161" aria-hidden="true" tabindex="-1"></a>        estimated_memory <span class="op">=</span> total_adapters <span class="op">*</span> <span class="fl">0.1</span>  <span class="co"># assumed ~0.1 GB per adapter</span></span>
<span id="cb25-162"><a href="#cb25-162" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb25-163"><a href="#cb25-163" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {</span>
<span id="cb25-164"><a href="#cb25-164" aria-hidden="true" tabindex="-1"></a>            <span class="st">'estimated_adapter_memory_gb'</span>: estimated_memory,</span>
<span id="cb25-165"><a href="#cb25-165" aria-hidden="true" tabindex="-1"></a>            <span class="st">'active_adapters'</span>: total_adapters</span>
<span id="cb25-166"><a href="#cb25-166" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb25-167"><a href="#cb25-167" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb25-168"><a href="#cb25-168" aria-hidden="true" tabindex="-1"></a><span class="co"># Production deployment demonstration</span></span>
<span id="cb25-169"><a href="#cb25-169" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"Production LoRA Deployment Demo:"</span>)</span>
<span id="cb25-170"><a href="#cb25-170" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"="</span> <span class="op">*</span> <span class="dv">35</span>)</span>
<span id="cb25-171"><a href="#cb25-171" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb25-172"><a href="#cb25-172" aria-hidden="true" tabindex="-1"></a><span class="co"># Initialize model manager</span></span>
<span id="cb25-173"><a href="#cb25-173" aria-hidden="true" tabindex="-1"></a>manager <span class="op">=</span> LoRAModelManager(<span class="st">"path/to/base/model"</span>, device<span class="op">=</span><span class="st">"cuda"</span>)</span>
<span id="cb25-174"><a href="#cb25-174" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb25-175"><a href="#cb25-175" aria-hidden="true" tabindex="-1"></a><span class="co"># Load multiple adapters</span></span>
<span id="cb25-176"><a href="#cb25-176" aria-hidden="true" tabindex="-1"></a>adapters_to_load <span class="op">=</span> [</span>
<span id="cb25-177"><a href="#cb25-177" aria-hidden="true" tabindex="-1"></a>    {<span class="st">"name"</span>: <span class="st">"medical_adapter"</span>, <span class="st">"path"</span>: <span class="st">"adapters/medical"</span>, <span class="st">"config"</span>: {<span class="st">"rank"</span>: <span class="dv">32</span>, <span class="st">"task"</span>: <span class="st">"medical_vqa"</span>}},</span>
<span id="cb25-178"><a href="#cb25-178" aria-hidden="true" tabindex="-1"></a>    {<span class="st">"name"</span>: <span class="st">"general_adapter"</span>, <span class="st">"path"</span>: <span class="st">"adapters/general"</span>, <span class="st">"config"</span>: {<span class="st">"rank"</span>: <span class="dv">16</span>, <span class="st">"task"</span>: <span class="st">"general_vqa"</span>}},</span>
<span id="cb25-179"><a href="#cb25-179" aria-hidden="true" tabindex="-1"></a>    {<span class="st">"name"</span>: <span class="st">"multilingual_adapter"</span>, <span class="st">"path"</span>: <span class="st">"adapters/multilingual"</span>, <span class="st">"config"</span>: {<span class="st">"rank"</span>: <span class="dv">24</span>, <span class="st">"task"</span>: <span class="st">"multilingual"</span>}}</span>
<span id="cb25-180"><a href="#cb25-180" aria-hidden="true" tabindex="-1"></a>]</span>
<span id="cb25-181"><a href="#cb25-181" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb25-182"><a href="#cb25-182" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> adapter <span class="kw">in</span> adapters_to_load:</span>
<span id="cb25-183"><a href="#cb25-183" aria-hidden="true" tabindex="-1"></a>    manager.load_adapter(adapter[<span class="st">"name"</span>], adapter[<span class="st">"path"</span>], adapter[<span class="st">"config"</span>])</span>
<span id="cb25-184"><a href="#cb25-184" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb25-185"><a href="#cb25-185" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"</span><span class="ch">\n</span><span class="ss">Loaded </span><span class="sc">{</span><span class="bu">len</span>(manager.active_adapters)<span class="sc">}</span><span class="ss"> adapters"</span>)</span>
<span id="cb25-186"><a href="#cb25-186" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb25-187"><a href="#cb25-187" aria-hidden="true" tabindex="-1"></a><span class="co"># Simulate inference requests</span></span>
<span id="cb25-188"><a href="#cb25-188" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"</span><span class="ch">\n</span><span class="st">Simulating inference requests..."</span>)</span>
<span id="cb25-189"><a href="#cb25-189" aria-hidden="true" tabindex="-1"></a>test_inputs <span class="op">=</span> {<span class="st">"image"</span>: <span class="st">"test_image.jpg"</span>, <span class="st">"text"</span>: <span class="st">"What is in this image?"</span>}</span>
<span id="cb25-190"><a href="#cb25-190" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb25-191"><a href="#cb25-191" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">5</span>):</span>
<span id="cb25-192"><a href="#cb25-192" aria-hidden="true" tabindex="-1"></a>    adapter <span class="op">=</span> [<span class="st">"medical_adapter"</span>, <span class="st">"general_adapter"</span>, <span class="va">None</span>][i <span class="op">%</span> <span class="dv">3</span>]</span>
<span id="cb25-193"><a href="#cb25-193" aria-hidden="true" tabindex="-1"></a>    result <span class="op">=</span> manager.inference(test_inputs, adapter)</span>
<span id="cb25-194"><a href="#cb25-194" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Request </span><span class="sc">{</span>result[<span class="st">'request_id'</span>]<span class="sc">}</span><span class="ss">: </span><span class="sc">{</span>result[<span class="st">'inference_time'</span>]<span class="sc">:.3f}</span><span class="ss">s (</span><span class="sc">{</span><span class="st">'with '</span> <span class="op">+</span> result[<span class="st">'adapter_used'</span>] <span class="cf">if</span> result[<span class="st">'adapter_used'</span>] <span class="cf">else</span> <span class="st">'base model'</span><span class="sc">}</span><span class="ss">)"</span>)</span>
<span id="cb25-195"><a href="#cb25-195" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb25-196"><a href="#cb25-196" aria-hidden="true" tabindex="-1"></a><span class="co"># Check system health</span></span>
<span id="cb25-197"><a href="#cb25-197" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"</span><span class="ch">\n</span><span class="st">System Health Check:"</span>)</span>
<span id="cb25-198"><a href="#cb25-198" aria-hidden="true" tabindex="-1"></a>health <span class="op">=</span> manager.health_check()</span>
<span id="cb25-199"><a href="#cb25-199" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Status: </span><span class="sc">{</span>health[<span class="st">'status'</span>]<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb25-200"><a href="#cb25-200" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Active adapters: </span><span class="sc">{</span><span class="bu">len</span>(health[<span class="st">'active_adapters'</span>])<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb25-201"><a href="#cb25-201" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Average latency: </span><span class="sc">{</span>health[<span class="st">'performance'</span>][<span class="st">'avg_time'</span>]<span class="sc">:.3f}</span><span class="ss">s"</span>)</span>
<span id="cb25-202"><a href="#cb25-202" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Error rate: </span><span class="sc">{</span>health[<span class="st">'performance'</span>][<span class="st">'error_rate'</span>]<span class="sc">:.1%}</span><span class="ss">"</span>)</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-stderr">
<pre><code>INFO:__main__:Loading adapter 'medical_adapter' from adapters/medical
INFO:__main__:Adapter 'medical_adapter' loaded successfully
INFO:__main__:Loading adapter 'general_adapter' from adapters/general
INFO:__main__:Adapter 'general_adapter' loaded successfully
INFO:__main__:Loading adapter 'multilingual_adapter' from adapters/multilingual
INFO:__main__:Adapter 'multilingual_adapter' loaded successfully</code></pre>
</div>
<div class="cell-output cell-output-stdout">
<pre><code>Production LoRA Deployment Demo:
===================================
LoRA Model Manager initialized
Device: cuda

Loaded 3 adapters

Simulating inference requests...
Request 1: 0.013s (with medical_adapter)
Request 2: 0.013s (with general_adapter)
Request 3: 0.010s (base model)
Request 4: 0.013s (with medical_adapter)
Request 5: 0.013s (with general_adapter)

System Health Check:
Status: healthy
Active adapters: 3
Average latency: 0.012s
Error rate: 0.0%</code></pre>
</div>
</div>
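<p>In the manager above, <code>request_count</code> is only incremented after a successful inference, while failures are counted separately in <code>error_count</code>, so it matters which counter goes into the error-rate denominator. The sketch below is a minimal, hypothetical alternative (the <code>RequestMetrics</code> class and its methods are illustrative names, not part of the original code). It guards its counters with a lock for multi-threaded servers, counts failures toward the total, and reports a nearest-rank p95 latency alongside the mean, since a single average can hide slow outliers.</p>

```python
import math
import threading
from typing import Dict, List


class RequestMetrics:
    """Hypothetical thread-safe tracker for latency and error statistics."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._latencies: List[float] = []  # seconds, successful requests only
        self._errors = 0

    def record_success(self, latency_s: float) -> None:
        with self._lock:
            self._latencies.append(latency_s)

    def record_error(self) -> None:
        with self._lock:
            self._errors += 1

    def snapshot(self) -> Dict[str, float]:
        # Copy state under the lock, then compute statistics outside it
        with self._lock:
            latencies = sorted(self._latencies)
            errors = self._errors
        successes = len(latencies)
        total = successes + errors
        if successes:
            # Nearest-rank p95: the value at position ceil(0.95 * n) in sorted order
            rank = math.ceil(successes * 95 / 100)
            p95 = latencies[rank - 1]
            avg = sum(latencies) / successes
        else:
            p95 = avg = 0.0
        return {
            "requests": float(total),
            "avg_time": avg,
            "p95_time": p95,
            # Failures count toward the denominator, so the rate stays within [0, 1]
            "error_rate": errors / total if total else 0.0,
        }


metrics = RequestMetrics()
for latency in [0.010, 0.012, 0.013, 0.013, 0.050]:
    metrics.record_success(latency)
metrics.record_error()
stats = metrics.snapshot()
print(stats)
```

<p>A percentile threshold such as p95 is often a steadier health signal than the mean used in <code>health_check</code>. Keeping every latency in a list is fine for a demo, but a long-running service would likely cap it or use a sliding window.</p>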
</section>
<section id="api-server-implementation" class="level3">
<h3 class="anchored" data-anchor-id="api-server-implementation" id="api-server-implementation">API Server Implementation</h3>
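<p>The endpoint methods in this section return plain dictionaries with a <code>"status"</code> field instead of HTTP responses, so an HTTP layer still has to choose a status code. A minimal sketch of that translation step is shown here, assuming a hypothetical <code>to_http_response</code> helper and an illustrative convention: messages containing "not found", as produced by the unload handler, map to 404, and any other error maps to 500.</p>

```python
from http import HTTPStatus
from typing import Any, Dict, Tuple


def to_http_response(result: Dict[str, Any]) -> Tuple[int, Dict[str, Any]]:
    """Hypothetical adapter from handler result dicts to (status_code, body)."""
    if result.get("status") == "success":
        return HTTPStatus.OK.value, result
    # Load/unload handlers report errors under "message"; inference uses "error"
    message = str(result.get("message", result.get("error", "")))
    # Assumed convention: unknown-adapter errors surface as 404, everything else as 500
    if "not found" in message.lower():
        return HTTPStatus.NOT_FOUND.value, result
    return HTTPStatus.INTERNAL_SERVER_ERROR.value, result


print(to_http_response({"status": "success", "outputs": [], "request_id": 1}))
print(to_http_response({"status": "error", "message": "Adapter 'x' not found"}))
print(to_http_response({"status": "error", "error": "CUDA out of memory", "request_id": None}))
```

<p>Matching on message text is brittle; having each handler return an explicit error code would be a sturdier contract if the API grows.</p>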
<div id="api-server" class="cell" data-execution_count="18">
<details class="code-fold">
<summary>Code</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb28"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb28-1"><a href="#cb28-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> LoRAAPIServer:</span>
<span id="cb28-2"><a href="#cb28-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Framework-agnostic request handlers for LoRA model serving, mirroring FastAPI-style routes"""</span></span>
<span id="cb28-3"><a href="#cb28-3" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb28-4"><a href="#cb28-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, model_manager: LoRAModelManager):</span>
<span id="cb28-5"><a href="#cb28-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model_manager <span class="op">=</span> model_manager</span>
<span id="cb28-6"><a href="#cb28-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.request_history <span class="op">=</span> []</span>
<span id="cb28-7"><a href="#cb28-7" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb28-8"><a href="#cb28-8" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="st">"LoRA API Server initialized"</span>)</span>
<span id="cb28-9"><a href="#cb28-9" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="st">"Available endpoints:"</span>)</span>
<span id="cb28-10"><a href="#cb28-10" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="st">"  POST /inference - Perform inference"</span>)</span>
<span id="cb28-11"><a href="#cb28-11" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="st">"  POST /load_adapter - Load new adapter"</span>)</span>
<span id="cb28-12"><a href="#cb28-12" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="st">"  DELETE /adapter/</span><span class="sc">{name}</span><span class="st"> - Unload adapter"</span>)</span>
<span id="cb28-13"><a href="#cb28-13" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="st">"  GET /health - Health check"</span>)</span>
<span id="cb28-14"><a href="#cb28-14" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="st">"  GET /adapters - List adapters"</span>)</span>
<span id="cb28-15"><a href="#cb28-15" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb28-16"><a href="#cb28-16" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> inference_endpoint(<span class="va">self</span>, request_data: Dict[<span class="bu">str</span>, Any]) <span class="op">-&gt;</span> Dict[<span class="bu">str</span>, Any]:</span>
<span id="cb28-17"><a href="#cb28-17" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Handle inference requests"""</span></span>
<span id="cb28-18"><a href="#cb28-18" aria-hidden="true" tabindex="-1"></a>        <span class="cf">try</span>:</span>
<span id="cb28-19"><a href="#cb28-19" aria-hidden="true" tabindex="-1"></a>            inputs <span class="op">=</span> request_data.get(<span class="st">"inputs"</span>, {})</span>
<span id="cb28-20"><a href="#cb28-20" aria-hidden="true" tabindex="-1"></a>            adapter_name <span class="op">=</span> request_data.get(<span class="st">"adapter_name"</span>)</span>
<span id="cb28-21"><a href="#cb28-21" aria-hidden="true" tabindex="-1"></a>            parameters <span class="op">=</span> request_data.get(<span class="st">"parameters"</span>, {})  <span class="co"># parsed but not yet forwarded to the model manager</span></span>
<span id="cb28-22"><a href="#cb28-22" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb28-23"><a href="#cb28-23" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Perform inference</span></span>
<span id="cb28-24"><a href="#cb28-24" aria-hidden="true" tabindex="-1"></a>            result <span class="op">=</span> <span class="va">self</span>.model_manager.inference(inputs, adapter_name)</span>
<span id="cb28-25"><a href="#cb28-25" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb28-26"><a href="#cb28-26" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Log request</span></span>
<span id="cb28-27"><a href="#cb28-27" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.request_history.append({</span>
<span id="cb28-28"><a href="#cb28-28" aria-hidden="true" tabindex="-1"></a>                <span class="st">"timestamp"</span>: time.time(),</span>
<span id="cb28-29"><a href="#cb28-29" aria-hidden="true" tabindex="-1"></a>                <span class="st">"adapter"</span>: adapter_name,</span>
<span id="cb28-30"><a href="#cb28-30" aria-hidden="true" tabindex="-1"></a>                <span class="st">"latency"</span>: result[<span class="st">"inference_time"</span>],</span>
<span id="cb28-31"><a href="#cb28-31" aria-hidden="true" tabindex="-1"></a>                <span class="st">"status"</span>: <span class="st">"success"</span></span>
<span id="cb28-32"><a href="#cb28-32" aria-hidden="true" tabindex="-1"></a>            })</span>
<span id="cb28-33"><a href="#cb28-33" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb28-34"><a href="#cb28-34" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> {</span>
<span id="cb28-35"><a href="#cb28-35" aria-hidden="true" tabindex="-1"></a>                <span class="st">"status"</span>: <span class="st">"success"</span>,</span>
<span id="cb28-36"><a href="#cb28-36" aria-hidden="true" tabindex="-1"></a>                <span class="st">"outputs"</span>: result[<span class="st">"outputs"</span>],</span>
<span id="cb28-37"><a href="#cb28-37" aria-hidden="true" tabindex="-1"></a>                <span class="st">"inference_time"</span>: result[<span class="st">"inference_time"</span>],</span>
<span id="cb28-38"><a href="#cb28-38" aria-hidden="true" tabindex="-1"></a>                <span class="st">"adapter_used"</span>: result[<span class="st">"adapter_used"</span>],</span>
<span id="cb28-39"><a href="#cb28-39" aria-hidden="true" tabindex="-1"></a>                <span class="st">"request_id"</span>: result[<span class="st">"request_id"</span>]</span>
<span id="cb28-40"><a href="#cb28-40" aria-hidden="true" tabindex="-1"></a>            }</span>
<span id="cb28-41"><a href="#cb28-41" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb28-42"><a href="#cb28-42" aria-hidden="true" tabindex="-1"></a>        <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:</span>
<span id="cb28-43"><a href="#cb28-43" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Log error</span></span>
<span id="cb28-44"><a href="#cb28-44" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.request_history.append({</span>
<span id="cb28-45"><a href="#cb28-45" aria-hidden="true" tabindex="-1"></a>                <span class="st">"timestamp"</span>: time.time(),</span>
<span id="cb28-46"><a href="#cb28-46" aria-hidden="true" tabindex="-1"></a>                <span class="st">"adapter"</span>: request_data.get(<span class="st">"adapter_name"</span>),</span>
<span id="cb28-47"><a href="#cb28-47" aria-hidden="true" tabindex="-1"></a>                <span class="st">"status"</span>: <span class="st">"error"</span>,</span>
<span id="cb28-48"><a href="#cb28-48" aria-hidden="true" tabindex="-1"></a>                <span class="st">"error"</span>: <span class="bu">str</span>(e)</span>
<span id="cb28-49"><a href="#cb28-49" aria-hidden="true" tabindex="-1"></a>            })</span>
<span id="cb28-50"><a href="#cb28-50" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb28-51"><a href="#cb28-51" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> {</span>
<span id="cb28-52"><a href="#cb28-52" aria-hidden="true" tabindex="-1"></a>                <span class="st">"status"</span>: <span class="st">"error"</span>,</span>
<span id="cb28-53"><a href="#cb28-53" aria-hidden="true" tabindex="-1"></a>                <span class="st">"error"</span>: <span class="bu">str</span>(e),</span>
<span id="cb28-54"><a href="#cb28-54" aria-hidden="true" tabindex="-1"></a>                <span class="st">"request_id"</span>: <span class="va">None</span></span>
<span id="cb28-55"><a href="#cb28-55" aria-hidden="true" tabindex="-1"></a>            }</span>
<span id="cb28-56"><a href="#cb28-56" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb28-57"><a href="#cb28-57" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> load_adapter_endpoint(<span class="va">self</span>, request_data: Dict[<span class="bu">str</span>, Any]) <span class="op">-&gt;</span> Dict[<span class="bu">str</span>, Any]:</span>
<span id="cb28-58"><a href="#cb28-58" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Handle adapter loading requests"""</span></span>
<span id="cb28-59"><a href="#cb28-59" aria-hidden="true" tabindex="-1"></a>        <span class="cf">try</span>:</span>
<span id="cb28-60"><a href="#cb28-60" aria-hidden="true" tabindex="-1"></a>            adapter_name <span class="op">=</span> request_data[<span class="st">"adapter_name"</span>]</span>
<span id="cb28-61"><a href="#cb28-61" aria-hidden="true" tabindex="-1"></a>            adapter_path <span class="op">=</span> request_data[<span class="st">"adapter_path"</span>]</span>
<span id="cb28-62"><a href="#cb28-62" aria-hidden="true" tabindex="-1"></a>            config <span class="op">=</span> request_data.get(<span class="st">"config"</span>)</span>
<span id="cb28-63"><a href="#cb28-63" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb28-64"><a href="#cb28-64" aria-hidden="true" tabindex="-1"></a>            success <span class="op">=</span> <span class="va">self</span>.model_manager.load_adapter(adapter_name, adapter_path, config)</span>
<span id="cb28-65"><a href="#cb28-65" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb28-66"><a href="#cb28-66" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> success:</span>
<span id="cb28-67"><a href="#cb28-67" aria-hidden="true" tabindex="-1"></a>                <span class="cf">return</span> {</span>
<span id="cb28-68"><a href="#cb28-68" aria-hidden="true" tabindex="-1"></a>                    <span class="st">"status"</span>: <span class="st">"success"</span>,</span>
<span id="cb28-69"><a href="#cb28-69" aria-hidden="true" tabindex="-1"></a>                    <span class="st">"message"</span>: <span class="ss">f"Adapter '</span><span class="sc">{</span>adapter_name<span class="sc">}</span><span class="ss">' loaded successfully"</span></span>
<span id="cb28-70"><a href="#cb28-70" aria-hidden="true" tabindex="-1"></a>                }</span>
<span id="cb28-71"><a href="#cb28-71" aria-hidden="true" tabindex="-1"></a>            <span class="cf">else</span>:</span>
<span id="cb28-72"><a href="#cb28-72" aria-hidden="true" tabindex="-1"></a>                <span class="cf">return</span> {</span>
<span id="cb28-73"><a href="#cb28-73" aria-hidden="true" tabindex="-1"></a>                    <span class="st">"status"</span>: <span class="st">"error"</span>,</span>
<span id="cb28-74"><a href="#cb28-74" aria-hidden="true" tabindex="-1"></a>                    <span class="st">"message"</span>: <span class="ss">f"Failed to load adapter '</span><span class="sc">{</span>adapter_name<span class="sc">}</span><span class="ss">'"</span></span>
<span id="cb28-75"><a href="#cb28-75" aria-hidden="true" tabindex="-1"></a>                }</span>
<span id="cb28-76"><a href="#cb28-76" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb28-77"><a href="#cb28-77" aria-hidden="true" tabindex="-1"></a>        <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:</span>
<span id="cb28-78"><a href="#cb28-78" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> {</span>
<span id="cb28-79"><a href="#cb28-79" aria-hidden="true" tabindex="-1"></a>                <span class="st">"status"</span>: <span class="st">"error"</span>,</span>
<span id="cb28-80"><a href="#cb28-80" aria-hidden="true" tabindex="-1"></a>                <span class="st">"message"</span>: <span class="bu">str</span>(e)</span>
<span id="cb28-81"><a href="#cb28-81" aria-hidden="true" tabindex="-1"></a>            }</span>
<span id="cb28-82"><a href="#cb28-82" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb28-83"><a href="#cb28-83" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> unload_adapter_endpoint(<span class="va">self</span>, adapter_name: <span class="bu">str</span>) <span class="op">-&gt;</span> Dict[<span class="bu">str</span>, Any]:</span>
<span id="cb28-84"><a href="#cb28-84" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Handle adapter unloading requests"""</span></span>
<span id="cb28-85"><a href="#cb28-85" aria-hidden="true" tabindex="-1"></a>        <span class="cf">try</span>:</span>
<span id="cb28-86"><a href="#cb28-86" aria-hidden="true" tabindex="-1"></a>            success <span class="op">=</span> <span class="va">self</span>.model_manager.unload_adapter(adapter_name)</span>
<span id="cb28-87"><a href="#cb28-87" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb28-88"><a href="#cb28-88" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> success:</span>
<span id="cb28-89"><a href="#cb28-89" aria-hidden="true" tabindex="-1"></a>                <span class="cf">return</span> {</span>
<span id="cb28-90"><a href="#cb28-90" aria-hidden="true" tabindex="-1"></a>                    <span class="st">"status"</span>: <span class="st">"success"</span>, </span>
<span id="cb28-91"><a href="#cb28-91" aria-hidden="true" tabindex="-1"></a>                    <span class="st">"message"</span>: <span class="ss">f"Adapter '</span><span class="sc">{</span>adapter_name<span class="sc">}</span><span class="ss">' unloaded successfully"</span></span>
<span id="cb28-92"><a href="#cb28-92" aria-hidden="true" tabindex="-1"></a>                }</span>
<span id="cb28-93"><a href="#cb28-93" aria-hidden="true" tabindex="-1"></a>            <span class="cf">else</span>:</span>
<span id="cb28-94"><a href="#cb28-94" aria-hidden="true" tabindex="-1"></a>                <span class="cf">return</span> {</span>
<span id="cb28-95"><a href="#cb28-95" aria-hidden="true" tabindex="-1"></a>                    <span class="st">"status"</span>: <span class="st">"error"</span>,</span>
<span id="cb28-96"><a href="#cb28-96" aria-hidden="true" tabindex="-1"></a>                    <span class="st">"message"</span>: <span class="ss">f"Adapter '</span><span class="sc">{</span>adapter_name<span class="sc">}</span><span class="ss">' not found"</span></span>
<span id="cb28-97"><a href="#cb28-97" aria-hidden="true" tabindex="-1"></a>                }</span>
<span id="cb28-98"><a href="#cb28-98" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb28-99"><a href="#cb28-99" aria-hidden="true" tabindex="-1"></a>        <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:</span>
<span id="cb28-100"><a href="#cb28-100" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> {</span>
<span id="cb28-101"><a href="#cb28-101" aria-hidden="true" tabindex="-1"></a>                <span class="st">"status"</span>: <span class="st">"error"</span>,</span>
<span id="cb28-102"><a href="#cb28-102" aria-hidden="true" tabindex="-1"></a>                <span class="st">"message"</span>: <span class="bu">str</span>(e)</span>
<span id="cb28-103"><a href="#cb28-103" aria-hidden="true" tabindex="-1"></a>            }</span>
<span id="cb28-104"><a href="#cb28-104" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb28-105"><a href="#cb28-105" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> health_endpoint(<span class="va">self</span>) <span class="op">-&gt;</span> Dict[<span class="bu">str</span>, Any]:</span>
<span id="cb28-106"><a href="#cb28-106" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Handle health check requests"""</span></span>
<span id="cb28-107"><a href="#cb28-107" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.model_manager.health_check()</span>
<span id="cb28-108"><a href="#cb28-108" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb28-109"><a href="#cb28-109" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> list_adapters_endpoint(<span class="va">self</span>) <span class="op">-&gt;</span> Dict[<span class="bu">str</span>, Any]:</span>
<span id="cb28-110"><a href="#cb28-110" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Handle adapter listing requests"""</span></span>
<span id="cb28-111"><a href="#cb28-111" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {</span>
<span id="cb28-112"><a href="#cb28-112" aria-hidden="true" tabindex="-1"></a>            <span class="st">"active_adapters"</span>: <span class="bu">list</span>(<span class="va">self</span>.model_manager.active_adapters.keys()),</span>
<span id="cb28-113"><a href="#cb28-113" aria-hidden="true" tabindex="-1"></a>            <span class="st">"adapter_configs"</span>: <span class="va">self</span>.model_manager.adapter_configs,</span>
<span id="cb28-114"><a href="#cb28-114" aria-hidden="true" tabindex="-1"></a>            <span class="st">"total_adapters"</span>: <span class="bu">len</span>(<span class="va">self</span>.model_manager.active_adapters)</span>
<span id="cb28-115"><a href="#cb28-115" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb28-116"><a href="#cb28-116" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb28-117"><a href="#cb28-117" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> get_metrics_endpoint(<span class="va">self</span>) <span class="op">-&gt;</span> Dict[<span class="bu">str</span>, Any]:</span>
<span id="cb28-118"><a href="#cb28-118" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Aggregate request metrics over the last hour"""</span></span>
<span id="cb28-119"><a href="#cb28-119" aria-hidden="true" tabindex="-1"></a>        now <span class="op">=</span> time.time()  <span class="co"># sample the clock once so every request is compared to the same cutoff</span></span>
<span id="cb28-120"><a href="#cb28-120" aria-hidden="true" tabindex="-1"></a>        recent_requests <span class="op">=</span> [req <span class="cf">for</span> req <span class="kw">in</span> <span class="va">self</span>.request_history <span class="cf">if</span> now <span class="op">-</span> req[<span class="st">"timestamp"</span>] <span class="op">&lt;</span> <span class="dv">3600</span>]  <span class="co"># Last hour</span></span>
<span id="cb28-121"><a href="#cb28-121" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb28-122"><a href="#cb28-122" aria-hidden="true" tabindex="-1"></a>        success_requests <span class="op">=</span> [req <span class="cf">for</span> req <span class="kw">in</span> recent_requests <span class="cf">if</span> req[<span class="st">"status"</span>] <span class="op">==</span> <span class="st">"success"</span>]</span>
<span id="cb28-123"><a href="#cb28-123" aria-hidden="true" tabindex="-1"></a>        error_requests <span class="op">=</span> [req <span class="cf">for</span> req <span class="kw">in</span> recent_requests <span class="cf">if</span> req[<span class="st">"status"</span>] <span class="op">==</span> <span class="st">"error"</span>]</span>
<span id="cb28-124"><a href="#cb28-124" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb28-125"><a href="#cb28-125" aria-hidden="true" tabindex="-1"></a>        metrics <span class="op">=</span> {</span>
<span id="cb28-126"><a href="#cb28-126" aria-hidden="true" tabindex="-1"></a>            <span class="st">"total_requests_last_hour"</span>: <span class="bu">len</span>(recent_requests),</span>
<span id="cb28-127"><a href="#cb28-127" aria-hidden="true" tabindex="-1"></a>            <span class="st">"successful_requests"</span>: <span class="bu">len</span>(success_requests),</span>
<span id="cb28-128"><a href="#cb28-128" aria-hidden="true" tabindex="-1"></a>            <span class="st">"failed_requests"</span>: <span class="bu">len</span>(error_requests),</span>
<span id="cb28-129"><a href="#cb28-129" aria-hidden="true" tabindex="-1"></a>            <span class="st">"success_rate"</span>: <span class="bu">len</span>(success_requests) <span class="op">/</span> <span class="bu">len</span>(recent_requests) <span class="cf">if</span> recent_requests <span class="cf">else</span> <span class="dv">0</span>,</span>
<span id="cb28-130"><a href="#cb28-130" aria-hidden="true" tabindex="-1"></a>            <span class="st">"average_latency"</span>: np.mean([req[<span class="st">"latency"</span>] <span class="cf">for</span> req <span class="kw">in</span> success_requests]) <span class="cf">if</span> success_requests <span class="cf">else</span> <span class="dv">0</span>,</span>
<span id="cb28-131"><a href="#cb28-131" aria-hidden="true" tabindex="-1"></a>            <span class="st">"adapter_usage"</span>: {}</span>
<span id="cb28-132"><a href="#cb28-132" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb28-133"><a href="#cb28-133" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb28-134"><a href="#cb28-134" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Adapter usage statistics</span></span>
<span id="cb28-135"><a href="#cb28-135" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> req <span class="kw">in</span> success_requests:</span>
<span id="cb28-136"><a href="#cb28-136" aria-hidden="true" tabindex="-1"></a>            adapter <span class="op">=</span> req.get(<span class="st">"adapter"</span>, <span class="st">"base_model"</span>)</span>
<span id="cb28-137"><a href="#cb28-137" aria-hidden="true" tabindex="-1"></a>            metrics[<span class="st">"adapter_usage"</span>][adapter] <span class="op">=</span> metrics[<span class="st">"adapter_usage"</span>].get(adapter, <span class="dv">0</span>) <span class="op">+</span> <span class="dv">1</span></span>
<span id="cb28-138"><a href="#cb28-138" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb28-139"><a href="#cb28-139" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> metrics</span>
<span id="cb28-140"><a href="#cb28-140" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb28-141"><a href="#cb28-141" aria-hidden="true" tabindex="-1"></a><span class="co"># API server demonstration</span></span>
<span id="cb28-142"><a href="#cb28-142" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"</span><span class="ch">\n</span><span class="st">API Server Demo:"</span>)</span>
<span id="cb28-143"><a href="#cb28-143" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"="</span> <span class="op">*</span> <span class="dv">20</span>)</span>
<span id="cb28-144"><a href="#cb28-144" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb28-145"><a href="#cb28-145" aria-hidden="true" tabindex="-1"></a><span class="co"># Initialize API server</span></span>
<span id="cb28-146"><a href="#cb28-146" aria-hidden="true" tabindex="-1"></a>api_server <span class="op">=</span> LoRAAPIServer(manager)</span>
<span id="cb28-147"><a href="#cb28-147" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb28-148"><a href="#cb28-148" aria-hidden="true" tabindex="-1"></a><span class="co"># Simulate API requests</span></span>
<span id="cb28-149"><a href="#cb28-149" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"</span><span class="ch">\n</span><span class="st">Simulating API requests..."</span>)</span>
<span id="cb28-150"><a href="#cb28-150" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb28-151"><a href="#cb28-151" aria-hidden="true" tabindex="-1"></a><span class="co"># 1. Inference request</span></span>
<span id="cb28-152"><a href="#cb28-152" aria-hidden="true" tabindex="-1"></a>inference_request <span class="op">=</span> {</span>
<span id="cb28-153"><a href="#cb28-153" aria-hidden="true" tabindex="-1"></a>    <span class="st">"inputs"</span>: {<span class="st">"image"</span>: <span class="st">"test.jpg"</span>, <span class="st">"text"</span>: <span class="st">"Describe this image"</span>},</span>
<span id="cb28-154"><a href="#cb28-154" aria-hidden="true" tabindex="-1"></a>    <span class="st">"adapter_name"</span>: <span class="st">"medical_adapter"</span></span>
<span id="cb28-155"><a href="#cb28-155" aria-hidden="true" tabindex="-1"></a>}</span>
<span id="cb28-156"><a href="#cb28-156" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb28-157"><a href="#cb28-157" aria-hidden="true" tabindex="-1"></a>response <span class="op">=</span> api_server.inference_endpoint(inference_request)</span>
<span id="cb28-158"><a href="#cb28-158" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Inference response: </span><span class="sc">{</span>response[<span class="st">'status'</span>]<span class="sc">}</span><span class="ss"> (took </span><span class="sc">{</span>response<span class="sc">.</span>get(<span class="st">'inference_time'</span>, <span class="dv">0</span>)<span class="sc">:.3f}</span><span class="ss">s)"</span>)</span>
<span id="cb28-159"><a href="#cb28-159" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb28-160"><a href="#cb28-160" aria-hidden="true" tabindex="-1"></a><span class="co"># 2. Load new adapter</span></span>
<span id="cb28-161"><a href="#cb28-161" aria-hidden="true" tabindex="-1"></a>load_request <span class="op">=</span> {</span>
<span id="cb28-162"><a href="#cb28-162" aria-hidden="true" tabindex="-1"></a>    <span class="st">"adapter_name"</span>: <span class="st">"custom_adapter"</span>,</span>
<span id="cb28-163"><a href="#cb28-163" aria-hidden="true" tabindex="-1"></a>    <span class="st">"adapter_path"</span>: <span class="st">"adapters/custom"</span>,</span>
<span id="cb28-164"><a href="#cb28-164" aria-hidden="true" tabindex="-1"></a>    <span class="st">"config"</span>: {<span class="st">"rank"</span>: <span class="dv">20</span>, <span class="st">"alpha"</span>: <span class="dv">20</span>}</span>
<span id="cb28-165"><a href="#cb28-165" aria-hidden="true" tabindex="-1"></a>}</span>
<span id="cb28-166"><a href="#cb28-166" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb28-167"><a href="#cb28-167" aria-hidden="true" tabindex="-1"></a>response <span class="op">=</span> api_server.load_adapter_endpoint(load_request)</span>
<span id="cb28-168"><a href="#cb28-168" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Load adapter response: </span><span class="sc">{</span>response[<span class="st">'status'</span>]<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb28-169"><a href="#cb28-169" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb28-170"><a href="#cb28-170" aria-hidden="true" tabindex="-1"></a><span class="co"># 3. Health check</span></span>
<span id="cb28-171"><a href="#cb28-171" aria-hidden="true" tabindex="-1"></a>health_response <span class="op">=</span> api_server.health_endpoint()</span>
<span id="cb28-172"><a href="#cb28-172" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Health status: </span><span class="sc">{</span>health_response[<span class="st">'status'</span>]<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb28-173"><a href="#cb28-173" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb28-174"><a href="#cb28-174" aria-hidden="true" tabindex="-1"></a><span class="co"># 4. List adapters</span></span>
<span id="cb28-175"><a href="#cb28-175" aria-hidden="true" tabindex="-1"></a>adapters_response <span class="op">=</span> api_server.list_adapters_endpoint()</span>
<span id="cb28-176"><a href="#cb28-176" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Active adapters: </span><span class="sc">{</span>adapters_response[<span class="st">'total_adapters'</span>]<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb28-177"><a href="#cb28-177" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb28-178"><a href="#cb28-178" aria-hidden="true" tabindex="-1"></a><span class="co"># 5. Get metrics</span></span>
<span id="cb28-179"><a href="#cb28-179" aria-hidden="true" tabindex="-1"></a>metrics_response <span class="op">=</span> api_server.get_metrics_endpoint()</span>
<span id="cb28-180"><a href="#cb28-180" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Success rate: </span><span class="sc">{</span>metrics_response[<span class="st">'success_rate'</span>]<span class="sc">:.1%}</span><span class="ss">"</span>)</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-stdout">
<pre><code>
API Server Demo:
====================
LoRA API Server initialized
Available endpoints:
  POST /inference - Perform inference
  POST /load_adapter - Load new adapter
  DELETE /adapter/{name} - Unload adapter
  GET /health - Health check
  GET /adapters - List adapters

Simulating API requests...</code></pre>
</div>
<div class="cell-output cell-output-stderr">
<pre><code>INFO:__main__:Loading adapter 'custom_adapter' from adapters/custom
INFO:__main__:Adapter 'custom_adapter' loaded successfully</code></pre>
</div>
<div class="cell-output cell-output-stdout">
<pre><code>Inference response: success (took 0.013s)
Load adapter response: success
Health status: healthy
Active adapters: 4
Success rate: 100.0%</code></pre>
</div>
</div>
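<p>The <code>get_metrics_endpoint</code> method above filters <code>request_history</code> to a trailing one-hour window, computes a success rate and mean latency, and counts successful requests per adapter. The sketch below pulls that aggregation out into a standalone function so it can be run without a model manager. It illustrates the idea under stated assumptions and is not the author's implementation. <code>summarize_requests</code> is a hypothetical name. The record schema (<code>timestamp</code>, <code>status</code>, <code>latency</code>, and an optional <code>adapter</code>) is assumed to match the dictionaries the endpoint reads. The clock can be passed in so the window can be tested deterministically.</p>

```python
import time
from collections import Counter
from statistics import mean
from typing import Any, Dict, Iterable, Optional


def summarize_requests(
    history: Iterable[Dict[str, Any]],
    window_s: float = 3600.0,
    now: Optional[float] = None,
) -> Dict[str, Any]:
    """Summarize request records that fall inside a trailing time window.

    Each record is assumed (hypothetical schema) to be a dict with
    "timestamp", "status", "latency" and an optional "adapter" key,
    mirroring the request_history entries used by the API server.
    """
    # Sample the clock once so every record is compared to the same cutoff
    now = time.time() if now is None else now
    recent = [r for r in history if now - r["timestamp"] < window_s]
    ok = [r for r in recent if r["status"] == "success"]
    failed = [r for r in recent if r["status"] == "error"]

    return {
        "total_requests": len(recent),
        "successful_requests": len(ok),
        "failed_requests": len(failed),
        # Guard against division by zero when the window is empty
        "success_rate": len(ok) / len(recent) if recent else 0.0,
        "average_latency": mean(r["latency"] for r in ok) if ok else 0.0,
        # Requests without an adapter are attributed to the base model
        "adapter_usage": dict(Counter(r.get("adapter", "base_model") for r in ok)),
    }


history = [
    {"timestamp": 1000.0, "status": "success", "latency": 0.02, "adapter": "medical_adapter"},
    {"timestamp": 1500.0, "status": "success", "latency": 0.04},
    {"timestamp": 1800.0, "status": "error", "latency": 0.0},
    {"timestamp": 0.0, "status": "success", "latency": 9.9},  # older than the window
]
print(summarize_requests(history, window_s=3600.0, now=3700.0))
```

<p>Passing <code>now</code> explicitly makes the one-hour cutoff reproducible in tests. In a running server you would omit it, and the function would read <code>time.time()</code> once per call.</p>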
</section>
</section>
<section id="monitoring-and-observability" class="level2">
<h2 class="anchored" data-anchor-id="monitoring-and-observability" id="monitoring-and-observability">Monitoring and Observability</h2>
<section id="performance-monitoring" class="level3">
<h3 class="anchored" data-anchor-id="performance-monitoring" id="performance-monitoring">Performance Monitoring</h3>
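<p>The <code>LoRAMonitor</code> class below stores its samples in <code>deque(maxlen=window_size)</code> buffers, so its statistics always describe the most recent <code>window_size</code> requests. It reports latency percentiles with <code>np.percentile</code>. The hypothetical <code>RollingLatency</code> helper is a minimal, dependency-free sketch of both ideas. It shows how a bounded deque evicts old samples and how the linear-interpolation percentile (NumPy's default method) is computed. It is an illustration only and is not part of the author's monitor.</p>

```python
from collections import deque
from typing import Deque, Dict, List


def percentile(sorted_vals: List[float], q: float) -> float:
    """Linear-interpolation percentile, the same rule as NumPy's default method."""
    if not sorted_vals:
        raise ValueError("percentile of empty data")
    # Fractional rank between 0 and n - 1
    pos = (len(sorted_vals) - 1) * q / 100.0
    lo = int(pos)
    hi = min(lo + 1, len(sorted_vals) - 1)
    frac = pos - lo
    return sorted_vals[lo] + (sorted_vals[hi] - sorted_vals[lo]) * frac


class RollingLatency:
    """Hypothetical helper that keeps only the most recent latencies."""

    def __init__(self, window_size: int = 1000):
        # Once full, deque(maxlen=...) drops the oldest sample on each append
        self.samples: Deque[float] = deque(maxlen=window_size)

    def add(self, seconds: float) -> None:
        self.samples.append(seconds)

    def summary(self) -> Dict[str, float]:
        vals = sorted(self.samples)
        return {name: percentile(vals, q) for name, q in (("p50", 50), ("p95", 95), ("p99", 99))}


monitor = RollingLatency(window_size=5)
for t in range(1, 11):
    monitor.add(float(t))  # only the last five values (6.0 to 10.0) stay in the window
print(list(monitor.samples), monitor.summary())
```

<p>For windows of a few thousand samples, sorting on every <code>summary()</code> call is cheap. Much larger windows might call for a streaming quantile estimator instead.</p>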
<div id="monitoring-system" class="cell" data-execution_count="19">
<details class="code-fold">
<summary>Code</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb32"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb32-1"><a href="#cb32-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> collections <span class="im">import</span> defaultdict, deque</span>
<span id="cb32-2"><a href="#cb32-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb32-3"><a href="#cb32-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb32-4"><a href="#cb32-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> typing <span class="im">import</span> Any, Dict, Optional</span>
<span id="cb32-5"><a href="#cb32-5" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> LoRAMonitor:</span>
<span id="cb32-6"><a href="#cb32-6" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Comprehensive monitoring for LoRA-adapted VLMs"""</span></span>
<span id="cb32-7"><a href="#cb32-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb32-8"><a href="#cb32-8" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, model, adapter_name: <span class="bu">str</span> <span class="op">=</span> <span class="st">"default"</span>, window_size: <span class="bu">int</span> <span class="op">=</span> <span class="dv">1000</span>):</span>
<span id="cb32-9"><a href="#cb32-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> model</span>
<span id="cb32-10"><a href="#cb32-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.adapter_name <span class="op">=</span> adapter_name</span>
<span id="cb32-11"><a href="#cb32-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.window_size <span class="op">=</span> window_size</span>
<span id="cb32-12"><a href="#cb32-12" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb32-13"><a href="#cb32-13" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Metrics storage</span></span>
<span id="cb32-14"><a href="#cb32-14" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.metrics <span class="op">=</span> {</span>
<span id="cb32-15"><a href="#cb32-15" aria-hidden="true" tabindex="-1"></a>            <span class="st">'inference_times'</span>: deque(maxlen<span class="op">=</span>window_size),</span>
<span id="cb32-16"><a href="#cb32-16" aria-hidden="true" tabindex="-1"></a>            <span class="st">'memory_usage'</span>: deque(maxlen<span class="op">=</span>window_size),</span>
<span id="cb32-17"><a href="#cb32-17" aria-hidden="true" tabindex="-1"></a>            <span class="st">'accuracy_scores'</span>: deque(maxlen<span class="op">=</span>window_size),</span>
<span id="cb32-18"><a href="#cb32-18" aria-hidden="true" tabindex="-1"></a>            <span class="st">'request_counts'</span>: defaultdict(<span class="bu">int</span>),</span>
<span id="cb32-19"><a href="#cb32-19" aria-hidden="true" tabindex="-1"></a>            <span class="st">'error_counts'</span>: defaultdict(<span class="bu">int</span>),</span>
<span id="cb32-20"><a href="#cb32-20" aria-hidden="true" tabindex="-1"></a>            <span class="st">'timestamps'</span>: deque(maxlen<span class="op">=</span>window_size)</span>
<span id="cb32-21"><a href="#cb32-21" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb32-22"><a href="#cb32-22" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb32-23"><a href="#cb32-23" aria-hidden="true" tabindex="-1"></a>        <span class="co"># LoRA-specific metrics</span></span>
<span id="cb32-24"><a href="#cb32-24" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.lora_metrics <span class="op">=</span> {</span>
<span id="cb32-25"><a href="#cb32-25" aria-hidden="true" tabindex="-1"></a>            <span class="st">'weight_norms'</span>: {},</span>
<span id="cb32-26"><a href="#cb32-26" aria-hidden="true" tabindex="-1"></a>            <span class="st">'rank_utilization'</span>: {},</span>
<span id="cb32-27"><a href="#cb32-27" aria-hidden="true" tabindex="-1"></a>            <span class="st">'adaptation_strength'</span>: {}</span>
<span id="cb32-28"><a href="#cb32-28" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb32-29"><a href="#cb32-29" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb32-30"><a href="#cb32-30" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Performance thresholds</span></span>
<span id="cb32-31"><a href="#cb32-31" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.thresholds <span class="op">=</span> {</span>
<span id="cb32-32"><a href="#cb32-32" aria-hidden="true" tabindex="-1"></a>            <span class="st">'max_inference_time'</span>: <span class="fl">2.0</span>,  <span class="co"># seconds</span></span>
<span id="cb32-33"><a href="#cb32-33" aria-hidden="true" tabindex="-1"></a>            <span class="st">'max_memory_usage'</span>: <span class="fl">4.0</span>,    <span class="co"># GB</span></span>
<span id="cb32-34"><a href="#cb32-34" aria-hidden="true" tabindex="-1"></a>            <span class="st">'min_accuracy'</span>: <span class="fl">0.8</span>,        <span class="co"># minimum acceptable accuracy</span></span>
<span id="cb32-35"><a href="#cb32-35" aria-hidden="true" tabindex="-1"></a>            <span class="st">'max_error_rate'</span>: <span class="fl">0.02</span>      <span class="co"># maximum error rate</span></span>
<span id="cb32-36"><a href="#cb32-36" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb32-37"><a href="#cb32-37" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb32-38"><a href="#cb32-38" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"LoRA Monitor initialized for adapter: </span><span class="sc">{</span>adapter_name<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb32-39"><a href="#cb32-39" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb32-40"><a href="#cb32-40" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> log_inference(<span class="va">self</span>, inference_time: <span class="bu">float</span>, memory_usage: <span class="bu">float</span>, </span>
<span id="cb32-41"><a href="#cb32-41" aria-hidden="true" tabindex="-1"></a>                     accuracy: Optional[<span class="bu">float</span>] <span class="op">=</span> <span class="va">None</span>):</span>
<span id="cb32-42"><a href="#cb32-42" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Log inference metrics"""</span></span>
<span id="cb32-43"><a href="#cb32-43" aria-hidden="true" tabindex="-1"></a>        current_time <span class="op">=</span> time.time()</span>
<span id="cb32-44"><a href="#cb32-44" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb32-45"><a href="#cb32-45" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.metrics[<span class="st">'inference_times'</span>].append(inference_time)</span>
<span id="cb32-46"><a href="#cb32-46" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.metrics[<span class="st">'memory_usage'</span>].append(memory_usage)</span>
<span id="cb32-47"><a href="#cb32-47" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.metrics[<span class="st">'timestamps'</span>].append(current_time)</span>
<span id="cb32-48"><a href="#cb32-48" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb32-49"><a href="#cb32-49" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> accuracy <span class="kw">is</span> <span class="kw">not</span> <span class="va">None</span>:</span>
<span id="cb32-50"><a href="#cb32-50" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.metrics[<span class="st">'accuracy_scores'</span>].append(accuracy)</span>
<span id="cb32-51"><a href="#cb32-51" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb32-52"><a href="#cb32-52" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Check thresholds and alert if necessary</span></span>
<span id="cb32-53"><a href="#cb32-53" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.check_thresholds(inference_time, memory_usage, accuracy)</span>
<span id="cb32-54"><a href="#cb32-54" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb32-55"><a href="#cb32-55" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> check_thresholds(<span class="va">self</span>, inference_time: <span class="bu">float</span>, memory_usage: <span class="bu">float</span>, </span>
<span id="cb32-56"><a href="#cb32-56" aria-hidden="true" tabindex="-1"></a>                        accuracy: Optional[<span class="bu">float</span>] <span class="op">=</span> <span class="va">None</span>):</span>
<span id="cb32-57"><a href="#cb32-57" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Check if metrics exceed defined thresholds"""</span></span>
<span id="cb32-58"><a href="#cb32-58" aria-hidden="true" tabindex="-1"></a>        alerts <span class="op">=</span> []</span>
<span id="cb32-59"><a href="#cb32-59" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb32-60"><a href="#cb32-60" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> inference_time <span class="op">&gt;</span> <span class="va">self</span>.thresholds[<span class="st">'max_inference_time'</span>]:</span>
<span id="cb32-61"><a href="#cb32-61" aria-hidden="true" tabindex="-1"></a>            alerts.append(<span class="ss">f"HIGH_LATENCY: </span><span class="sc">{</span>inference_time<span class="sc">:.3f}</span><span class="ss">s &gt; </span><span class="sc">{</span><span class="va">self</span><span class="sc">.</span>thresholds[<span class="st">'max_inference_time'</span>]<span class="sc">}</span><span class="ss">s"</span>)</span>
<span id="cb32-62"><a href="#cb32-62" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb32-63"><a href="#cb32-63" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> memory_usage <span class="op">&gt;</span> <span class="va">self</span>.thresholds[<span class="st">'max_memory_usage'</span>]:</span>
<span id="cb32-64"><a href="#cb32-64" aria-hidden="true" tabindex="-1"></a>            alerts.append(<span class="ss">f"HIGH_MEMORY: </span><span class="sc">{</span>memory_usage<span class="sc">:.2f}</span><span class="ss">GB &gt; </span><span class="sc">{</span><span class="va">self</span><span class="sc">.</span>thresholds[<span class="st">'max_memory_usage'</span>]<span class="sc">}</span><span class="ss">GB"</span>)</span>
<span id="cb32-65"><a href="#cb32-65" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb32-66"><a href="#cb32-66" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> accuracy <span class="kw">is</span> <span class="kw">not</span> <span class="va">None</span> <span class="kw">and</span> accuracy <span class="op">&lt;</span> <span class="va">self</span>.thresholds[<span class="st">'min_accuracy'</span>]:</span>
<span id="cb32-67"><a href="#cb32-67" aria-hidden="true" tabindex="-1"></a>            alerts.append(<span class="ss">f"LOW_ACCURACY: </span><span class="sc">{</span>accuracy<span class="sc">:.3f}</span><span class="ss"> &lt; </span><span class="sc">{</span><span class="va">self</span><span class="sc">.</span>thresholds[<span class="st">'min_accuracy'</span>]<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb32-68"><a href="#cb32-68" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb32-69"><a href="#cb32-69" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> alert <span class="kw">in</span> alerts:</span>
<span id="cb32-70"><a href="#cb32-70" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f"🚨 ALERT [</span><span class="sc">{</span><span class="va">self</span><span class="sc">.</span>adapter_name<span class="sc">}</span><span class="ss">]: </span><span class="sc">{</span>alert<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb32-71"><a href="#cb32-71" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb32-72"><a href="#cb32-72" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> compute_performance_stats(<span class="va">self</span>) <span class="op">-&gt;</span> Dict[<span class="bu">str</span>, Any]:</span>
<span id="cb32-73"><a href="#cb32-73" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Compute performance statistics from collected metrics"""</span></span>
<span id="cb32-74"><a href="#cb32-74" aria-hidden="true" tabindex="-1"></a>        stats <span class="op">=</span> {}</span>
<span id="cb32-75"><a href="#cb32-75" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb32-76"><a href="#cb32-76" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Inference time statistics</span></span>
<span id="cb32-77"><a href="#cb32-77" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="va">self</span>.metrics[<span class="st">'inference_times'</span>]:</span>
<span id="cb32-78"><a href="#cb32-78" aria-hidden="true" tabindex="-1"></a>            times <span class="op">=</span> <span class="bu">list</span>(<span class="va">self</span>.metrics[<span class="st">'inference_times'</span>])</span>
<span id="cb32-79"><a href="#cb32-79" aria-hidden="true" tabindex="-1"></a>            stats[<span class="st">'inference_time'</span>] <span class="op">=</span> {</span>
<span id="cb32-80"><a href="#cb32-80" aria-hidden="true" tabindex="-1"></a>                <span class="st">'mean'</span>: np.mean(times),</span>
<span id="cb32-81"><a href="#cb32-81" aria-hidden="true" tabindex="-1"></a>                <span class="st">'std'</span>: np.std(times),</span>
<span id="cb32-82"><a href="#cb32-82" aria-hidden="true" tabindex="-1"></a>                <span class="st">'p50'</span>: np.percentile(times, <span class="dv">50</span>),</span>
<span id="cb32-83"><a href="#cb32-83" aria-hidden="true" tabindex="-1"></a>                <span class="st">'p95'</span>: np.percentile(times, <span class="dv">95</span>),</span>
<span id="cb32-84"><a href="#cb32-84" aria-hidden="true" tabindex="-1"></a>                <span class="st">'p99'</span>: np.percentile(times, <span class="dv">99</span>),</span>
<span id="cb32-85"><a href="#cb32-85" aria-hidden="true" tabindex="-1"></a>                <span class="st">'min'</span>: np.<span class="bu">min</span>(times),</span>
<span id="cb32-86"><a href="#cb32-86" aria-hidden="true" tabindex="-1"></a>                <span class="st">'max'</span>: np.<span class="bu">max</span>(times)</span>
<span id="cb32-87"><a href="#cb32-87" aria-hidden="true" tabindex="-1"></a>            }</span>
<span id="cb32-88"><a href="#cb32-88" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb32-89"><a href="#cb32-89" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Memory usage statistics</span></span>
<span id="cb32-90"><a href="#cb32-90" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="va">self</span>.metrics[<span class="st">'memory_usage'</span>]:</span>
<span id="cb32-91"><a href="#cb32-91" aria-hidden="true" tabindex="-1"></a>            memory <span class="op">=</span> <span class="bu">list</span>(<span class="va">self</span>.metrics[<span class="st">'memory_usage'</span>])</span>
<span id="cb32-92"><a href="#cb32-92" aria-hidden="true" tabindex="-1"></a>            stats[<span class="st">'memory_usage'</span>] <span class="op">=</span> {</span>
<span id="cb32-93"><a href="#cb32-93" aria-hidden="true" tabindex="-1"></a>                <span class="st">'mean'</span>: np.mean(memory),</span>
<span id="cb32-94"><a href="#cb32-94" aria-hidden="true" tabindex="-1"></a>                <span class="st">'max'</span>: np.<span class="bu">max</span>(memory),</span>
<span id="cb32-95"><a href="#cb32-95" aria-hidden="true" tabindex="-1"></a>                <span class="st">'min'</span>: np.<span class="bu">min</span>(memory),</span>
<span id="cb32-96"><a href="#cb32-96" aria-hidden="true" tabindex="-1"></a>                <span class="st">'current'</span>: memory[<span class="op">-</span><span class="dv">1</span>] <span class="cf">if</span> memory <span class="cf">else</span> <span class="dv">0</span></span>
<span id="cb32-97"><a href="#cb32-97" aria-hidden="true" tabindex="-1"></a>            }</span>
<span id="cb32-98"><a href="#cb32-98" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb32-99"><a href="#cb32-99" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Accuracy statistics</span></span>
<span id="cb32-100"><a href="#cb32-100" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="va">self</span>.metrics[<span class="st">'accuracy_scores'</span>]:</span>
<span id="cb32-101"><a href="#cb32-101" aria-hidden="true" tabindex="-1"></a>            accuracy <span class="op">=</span> <span class="bu">list</span>(<span class="va">self</span>.metrics[<span class="st">'accuracy_scores'</span>])</span>
<span id="cb32-102"><a href="#cb32-102" aria-hidden="true" tabindex="-1"></a>            stats[<span class="st">'accuracy'</span>] <span class="op">=</span> {</span>
<span id="cb32-103"><a href="#cb32-103" aria-hidden="true" tabindex="-1"></a>                <span class="st">'mean'</span>: np.mean(accuracy),</span>
<span id="cb32-104"><a href="#cb32-104" aria-hidden="true" tabindex="-1"></a>                <span class="st">'std'</span>: np.std(accuracy),</span>
<span id="cb32-105"><a href="#cb32-105" aria-hidden="true" tabindex="-1"></a>                <span class="st">'min'</span>: np.<span class="bu">min</span>(accuracy),</span>
<span id="cb32-106"><a href="#cb32-106" aria-hidden="true" tabindex="-1"></a>                <span class="st">'max'</span>: np.<span class="bu">max</span>(accuracy),</span>
<span id="cb32-107"><a href="#cb32-107" aria-hidden="true" tabindex="-1"></a>                <span class="st">'recent'</span>: np.mean(accuracy[<span class="op">-</span><span class="dv">10</span>:])  <span class="co"># slicing already covers lists shorter than 10</span></span>
<span id="cb32-108"><a href="#cb32-108" aria-hidden="true" tabindex="-1"></a>            }</span>
<span id="cb32-109"><a href="#cb32-109" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb32-110"><a href="#cb32-110" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Throughput calculation</span></span>
<span id="cb32-111"><a href="#cb32-111" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="bu">len</span>(<span class="va">self</span>.metrics[<span class="st">'timestamps'</span>]) <span class="op">&gt;</span> <span class="dv">1</span>:</span>
<span id="cb32-112"><a href="#cb32-112" aria-hidden="true" tabindex="-1"></a>            time_span <span class="op">=</span> <span class="va">self</span>.metrics[<span class="st">'timestamps'</span>][<span class="op">-</span><span class="dv">1</span>] <span class="op">-</span> <span class="va">self</span>.metrics[<span class="st">'timestamps'</span>][<span class="dv">0</span>]</span>
<span id="cb32-113"><a href="#cb32-113" aria-hidden="true" tabindex="-1"></a>            stats[<span class="st">'throughput'</span>] <span class="op">=</span> {</span>
<span id="cb32-114"><a href="#cb32-114" aria-hidden="true" tabindex="-1"></a>                <span class="st">'requests_per_second'</span>: <span class="bu">len</span>(<span class="va">self</span>.metrics[<span class="st">'timestamps'</span>]) <span class="op">/</span> time_span <span class="cf">if</span> time_span <span class="op">&gt;</span> <span class="dv">0</span> <span class="cf">else</span> <span class="dv">0</span>,</span>
<span id="cb32-115"><a href="#cb32-115" aria-hidden="true" tabindex="-1"></a>                <span class="st">'time_span_minutes'</span>: time_span <span class="op">/</span> <span class="dv">60</span></span>
<span id="cb32-116"><a href="#cb32-116" aria-hidden="true" tabindex="-1"></a>            }</span>
<span id="cb32-117"><a href="#cb32-117" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb32-118"><a href="#cb32-118" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> stats</span>
<span id="cb32-119"><a href="#cb32-119" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb32-120"><a href="#cb32-120" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> analyze_trends(<span class="va">self</span>, window_minutes: <span class="bu">int</span> <span class="op">=</span> <span class="dv">30</span>) <span class="op">-&gt;</span> Dict[<span class="bu">str</span>, Any]:</span>
<span id="cb32-121"><a href="#cb32-121" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Analyze performance trends over time"""</span></span>
<span id="cb32-122"><a href="#cb32-122" aria-hidden="true" tabindex="-1"></a>        current_time <span class="op">=</span> time.time()</span>
<span id="cb32-123"><a href="#cb32-123" aria-hidden="true" tabindex="-1"></a>        cutoff_time <span class="op">=</span> current_time <span class="op">-</span> (window_minutes <span class="op">*</span> <span class="dv">60</span>)</span>
<span id="cb32-124"><a href="#cb32-124" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb32-125"><a href="#cb32-125" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Filter recent metrics</span></span>
<span id="cb32-126"><a href="#cb32-126" aria-hidden="true" tabindex="-1"></a>        recent_indices <span class="op">=</span> [i <span class="cf">for</span> i, t <span class="kw">in</span> <span class="bu">enumerate</span>(<span class="va">self</span>.metrics[<span class="st">'timestamps'</span>]) </span>
<span id="cb32-127"><a href="#cb32-127" aria-hidden="true" tabindex="-1"></a>                         <span class="cf">if</span> t <span class="op">&gt;=</span> cutoff_time]</span>
<span id="cb32-128"><a href="#cb32-128" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb32-129"><a href="#cb32-129" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="bu">len</span>(recent_indices) <span class="op">&lt;</span> <span class="dv">2</span>:</span>
<span id="cb32-130"><a href="#cb32-130" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> {<span class="st">"error"</span>: <span class="st">"Insufficient data for trend analysis"</span>}</span>
<span id="cb32-131"><a href="#cb32-131" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb32-132"><a href="#cb32-132" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Extract recent data</span></span>
<span id="cb32-133"><a href="#cb32-133" aria-hidden="true" tabindex="-1"></a>        recent_times <span class="op">=</span> [<span class="va">self</span>.metrics[<span class="st">'inference_times'</span>][i] <span class="cf">for</span> i <span class="kw">in</span> recent_indices]</span>
<span id="cb32-134"><a href="#cb32-134" aria-hidden="true" tabindex="-1"></a>        recent_memory <span class="op">=</span> [<span class="va">self</span>.metrics[<span class="st">'memory_usage'</span>][i] <span class="cf">for</span> i <span class="kw">in</span> recent_indices]</span>
<span id="cb32-135"><a href="#cb32-135" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb32-136"><a href="#cb32-136" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Calculate trends: least-squares slope per sample (x is the sample index, not time)</span></span>
<span id="cb32-137"><a href="#cb32-137" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> np.arange(<span class="bu">len</span>(recent_times))</span>
<span id="cb32-138"><a href="#cb32-138" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb32-139"><a href="#cb32-139" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Inference time trend</span></span>
<span id="cb32-140"><a href="#cb32-140" aria-hidden="true" tabindex="-1"></a>        time_slope <span class="op">=</span> np.polyfit(x, recent_times, <span class="dv">1</span>)[<span class="dv">0</span>] <span class="cf">if</span> <span class="bu">len</span>(recent_times) <span class="op">&gt;</span> <span class="dv">1</span> <span class="cf">else</span> <span class="dv">0</span></span>
<span id="cb32-141"><a href="#cb32-141" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb32-142"><a href="#cb32-142" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Memory usage trend  </span></span>
<span id="cb32-143"><a href="#cb32-143" aria-hidden="true" tabindex="-1"></a>        memory_slope <span class="op">=</span> np.polyfit(x, recent_memory, <span class="dv">1</span>)[<span class="dv">0</span>] <span class="cf">if</span> <span class="bu">len</span>(recent_memory) <span class="op">&gt;</span> <span class="dv">1</span> <span class="cf">else</span> <span class="dv">0</span></span>
<span id="cb32-144"><a href="#cb32-144" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb32-145"><a href="#cb32-145" aria-hidden="true" tabindex="-1"></a>        trends <span class="op">=</span> {</span>
<span id="cb32-146"><a href="#cb32-146" aria-hidden="true" tabindex="-1"></a>            <span class="st">'window_minutes'</span>: window_minutes,</span>
<span id="cb32-147"><a href="#cb32-147" aria-hidden="true" tabindex="-1"></a>            <span class="st">'data_points'</span>: <span class="bu">len</span>(recent_indices),</span>
<span id="cb32-148"><a href="#cb32-148" aria-hidden="true" tabindex="-1"></a>            <span class="st">'inference_time_trend'</span>: {</span>
<span id="cb32-149"><a href="#cb32-149" aria-hidden="true" tabindex="-1"></a>                <span class="st">'slope'</span>: time_slope,</span>
<span id="cb32-150"><a href="#cb32-150" aria-hidden="true" tabindex="-1"></a>                <span class="st">'direction'</span>: <span class="st">'increasing'</span> <span class="cf">if</span> time_slope <span class="op">&gt;</span> <span class="fl">0.001</span> <span class="cf">else</span> <span class="st">'decreasing'</span> <span class="cf">if</span> time_slope <span class="op">&lt;</span> <span class="op">-</span><span class="fl">0.001</span> <span class="cf">else</span> <span class="st">'stable'</span>,</span>
<span id="cb32-151"><a href="#cb32-151" aria-hidden="true" tabindex="-1"></a>                <span class="st">'severity'</span>: <span class="st">'high'</span> <span class="cf">if</span> <span class="bu">abs</span>(time_slope) <span class="op">&gt;</span> <span class="fl">0.01</span> <span class="cf">else</span> <span class="st">'medium'</span> <span class="cf">if</span> <span class="bu">abs</span>(time_slope) <span class="op">&gt;</span> <span class="fl">0.005</span> <span class="cf">else</span> <span class="st">'low'</span></span>
<span id="cb32-152"><a href="#cb32-152" aria-hidden="true" tabindex="-1"></a>            },</span>
<span id="cb32-153"><a href="#cb32-153" aria-hidden="true" tabindex="-1"></a>            <span class="st">'memory_usage_trend'</span>: {</span>
<span id="cb32-154"><a href="#cb32-154" aria-hidden="true" tabindex="-1"></a>                <span class="st">'slope'</span>: memory_slope,</span>
<span id="cb32-155"><a href="#cb32-155" aria-hidden="true" tabindex="-1"></a>                <span class="st">'direction'</span>: <span class="st">'increasing'</span> <span class="cf">if</span> memory_slope <span class="op">&gt;</span> <span class="fl">0.01</span> <span class="cf">else</span> <span class="st">'decreasing'</span> <span class="cf">if</span> memory_slope <span class="op">&lt;</span> <span class="op">-</span><span class="fl">0.01</span> <span class="cf">else</span> <span class="st">'stable'</span>,</span>
<span id="cb32-156"><a href="#cb32-156" aria-hidden="true" tabindex="-1"></a>                <span class="st">'severity'</span>: <span class="st">'high'</span> <span class="cf">if</span> <span class="bu">abs</span>(memory_slope) <span class="op">&gt;</span> <span class="fl">0.1</span> <span class="cf">else</span> <span class="st">'medium'</span> <span class="cf">if</span> <span class="bu">abs</span>(memory_slope) <span class="op">&gt;</span> <span class="fl">0.05</span> <span class="cf">else</span> <span class="st">'low'</span></span>
<span id="cb32-157"><a href="#cb32-157" aria-hidden="true" tabindex="-1"></a>            }</span>
<span id="cb32-158"><a href="#cb32-158" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb32-159"><a href="#cb32-159" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb32-160"><a href="#cb32-160" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> trends</span>
<span id="cb32-161"><a href="#cb32-161" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb32-162"><a href="#cb32-162" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> generate_monitoring_report(<span class="va">self</span>) <span class="op">-&gt;</span> Dict[<span class="bu">str</span>, Any]:</span>
<span id="cb32-163"><a href="#cb32-163" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Generate comprehensive monitoring report"""</span></span>
<span id="cb32-164"><a href="#cb32-164" aria-hidden="true" tabindex="-1"></a>        report <span class="op">=</span> {</span>
<span id="cb32-165"><a href="#cb32-165" aria-hidden="true" tabindex="-1"></a>            <span class="st">'adapter_name'</span>: <span class="va">self</span>.adapter_name,</span>
<span id="cb32-166"><a href="#cb32-166" aria-hidden="true" tabindex="-1"></a>            <span class="st">'report_timestamp'</span>: time.time(),</span>
<span id="cb32-167"><a href="#cb32-167" aria-hidden="true" tabindex="-1"></a>            <span class="st">'performance_stats'</span>: <span class="va">self</span>.compute_performance_stats(),</span>
<span id="cb32-168"><a href="#cb32-168" aria-hidden="true" tabindex="-1"></a>            <span class="st">'trends'</span>: <span class="va">self</span>.analyze_trends(),</span>
<span id="cb32-169"><a href="#cb32-169" aria-hidden="true" tabindex="-1"></a>            <span class="st">'thresholds'</span>: <span class="va">self</span>.thresholds,</span>
<span id="cb32-170"><a href="#cb32-170" aria-hidden="true" tabindex="-1"></a>            <span class="st">'health_status'</span>: <span class="va">self</span>._compute_health_status()</span>
<span id="cb32-171"><a href="#cb32-171" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb32-172"><a href="#cb32-172" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb32-173"><a href="#cb32-173" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> report</span>
<span id="cb32-174"><a href="#cb32-174" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb32-175"><a href="#cb32-175" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _compute_health_status(<span class="va">self</span>) <span class="op">-&gt;</span> <span class="bu">str</span>:</span>
<span id="cb32-176"><a href="#cb32-176" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Compute overall health status"""</span></span>
<span id="cb32-177"><a href="#cb32-177" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="kw">not</span> <span class="va">self</span>.metrics[<span class="st">'inference_times'</span>]:</span>
<span id="cb32-178"><a href="#cb32-178" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> <span class="st">'unknown'</span></span>
<span id="cb32-179"><a href="#cb32-179" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb32-180"><a href="#cb32-180" aria-hidden="true" tabindex="-1"></a>        recent_times <span class="op">=</span> <span class="bu">list</span>(<span class="va">self</span>.metrics[<span class="st">'inference_times'</span>])[<span class="op">-</span><span class="dv">10</span>:]</span>
<span id="cb32-181"><a href="#cb32-181" aria-hidden="true" tabindex="-1"></a>        recent_memory <span class="op">=</span> <span class="bu">list</span>(<span class="va">self</span>.metrics[<span class="st">'memory_usage'</span>])[<span class="op">-</span><span class="dv">10</span>:]</span>
<span id="cb32-182"><a href="#cb32-182" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb32-183"><a href="#cb32-183" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Check for threshold violations</span></span>
<span id="cb32-184"><a href="#cb32-184" aria-hidden="true" tabindex="-1"></a>        high_latency <span class="op">=</span> <span class="bu">any</span>(t <span class="op">&gt;</span> <span class="va">self</span>.thresholds[<span class="st">'max_inference_time'</span>] <span class="cf">for</span> t <span class="kw">in</span> recent_times)</span>
<span id="cb32-185"><a href="#cb32-185" aria-hidden="true" tabindex="-1"></a>        high_memory <span class="op">=</span> <span class="bu">any</span>(m <span class="op">&gt;</span> <span class="va">self</span>.thresholds[<span class="st">'max_memory_usage'</span>] <span class="cf">for</span> m <span class="kw">in</span> recent_memory)</span>
<span id="cb32-186"><a href="#cb32-186" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb32-187"><a href="#cb32-187" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> high_latency <span class="kw">or</span> high_memory:</span>
<span id="cb32-188"><a href="#cb32-188" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> <span class="st">'degraded'</span></span>
<span id="cb32-189"><a href="#cb32-189" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb32-190"><a href="#cb32-190" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Check for accuracy issues</span></span>
<span id="cb32-191"><a href="#cb32-191" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="va">self</span>.metrics[<span class="st">'accuracy_scores'</span>]:</span>
<span id="cb32-192"><a href="#cb32-192" aria-hidden="true" tabindex="-1"></a>            recent_accuracy <span class="op">=</span> <span class="bu">list</span>(<span class="va">self</span>.metrics[<span class="st">'accuracy_scores'</span>])[<span class="op">-</span><span class="dv">10</span>:]</span>
<span id="cb32-193"><a href="#cb32-193" aria-hidden="true" tabindex="-1"></a>            low_accuracy <span class="op">=</span> <span class="bu">any</span>(a <span class="op">&lt;</span> <span class="va">self</span>.thresholds[<span class="st">'min_accuracy'</span>] <span class="cf">for</span> a <span class="kw">in</span> recent_accuracy)</span>
<span id="cb32-194"><a href="#cb32-194" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> low_accuracy:</span>
<span id="cb32-195"><a href="#cb32-195" aria-hidden="true" tabindex="-1"></a>                <span class="cf">return</span> <span class="st">'degraded'</span></span>
<span id="cb32-196"><a href="#cb32-196" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb32-197"><a href="#cb32-197" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="st">'healthy'</span></span>
<span id="cb32-198"><a href="#cb32-198" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb32-199"><a href="#cb32-199" aria-hidden="true" tabindex="-1"></a><span class="co"># Monitoring demonstration</span></span>
<span id="cb32-200"><a href="#cb32-200" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"LoRA Monitoring System Demo:"</span>)</span>
<span id="cb32-201"><a href="#cb32-201" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"="</span> <span class="op">*</span> <span class="dv">30</span>)</span>
<span id="cb32-202"><a href="#cb32-202" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb32-203"><a href="#cb32-203" aria-hidden="true" tabindex="-1"></a><span class="co"># Initialize monitor</span></span>
<span id="cb32-204"><a href="#cb32-204" aria-hidden="true" tabindex="-1"></a>monitor <span class="op">=</span> LoRAMonitor(<span class="va">None</span>, <span class="st">"production_adapter"</span>)</span>
<span id="cb32-205"><a href="#cb32-205" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb32-206"><a href="#cb32-206" aria-hidden="true" tabindex="-1"></a><span class="co"># Simulate monitoring data</span></span>
<span id="cb32-207"><a href="#cb32-207" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"</span><span class="ch">\n</span><span class="st">Simulating monitoring data..."</span>)</span>
<span id="cb32-208"><a href="#cb32-208" aria-hidden="true" tabindex="-1"></a>np.random.seed(<span class="dv">42</span>)  <span class="co"># For reproducible results</span></span>
<span id="cb32-209"><a href="#cb32-209" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb32-210"><a href="#cb32-210" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">50</span>):</span>
<span id="cb32-211"><a href="#cb32-211" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Simulate varying performance</span></span>
<span id="cb32-212"><a href="#cb32-212" aria-hidden="true" tabindex="-1"></a>    base_latency <span class="op">=</span> <span class="fl">0.1</span></span>
<span id="cb32-213"><a href="#cb32-213" aria-hidden="true" tabindex="-1"></a>    latency_noise <span class="op">=</span> np.random.normal(<span class="dv">0</span>, <span class="fl">0.02</span>)</span>
<span id="cb32-214"><a href="#cb32-214" aria-hidden="true" tabindex="-1"></a>    memory_base <span class="op">=</span> <span class="fl">2.0</span></span>
<span id="cb32-215"><a href="#cb32-215" aria-hidden="true" tabindex="-1"></a>    memory_noise <span class="op">=</span> np.random.normal(<span class="dv">0</span>, <span class="fl">0.1</span>)</span>
<span id="cb32-216"><a href="#cb32-216" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb32-217"><a href="#cb32-217" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Add a slight latency drift over time (about 5% by the 50th sample)</span></span>
<span id="cb32-218"><a href="#cb32-218" aria-hidden="true" tabindex="-1"></a>    degradation_factor <span class="op">=</span> <span class="dv">1</span> <span class="op">+</span> (i <span class="op">/</span> <span class="dv">1000</span>)</span>
<span id="cb32-219"><a href="#cb32-219" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb32-220"><a href="#cb32-220" aria-hidden="true" tabindex="-1"></a>    inference_time <span class="op">=</span> base_latency <span class="op">*</span> degradation_factor <span class="op">+</span> latency_noise</span>
<span id="cb32-221"><a href="#cb32-221" aria-hidden="true" tabindex="-1"></a>    memory_usage <span class="op">=</span> memory_base <span class="op">+</span> memory_noise</span>
<span id="cb32-222"><a href="#cb32-222" aria-hidden="true" tabindex="-1"></a>    accuracy <span class="op">=</span> <span class="fl">0.92</span> <span class="op">+</span> np.random.normal(<span class="dv">0</span>, <span class="fl">0.03</span>)</span>
<span id="cb32-223"><a href="#cb32-223" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb32-224"><a href="#cb32-224" aria-hidden="true" tabindex="-1"></a>    monitor.log_inference(inference_time, memory_usage, accuracy)</span>
<span id="cb32-225"><a href="#cb32-225" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb32-226"><a href="#cb32-226" aria-hidden="true" tabindex="-1"></a><span class="co"># Generate performance report</span></span>
<span id="cb32-227"><a href="#cb32-227" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"</span><span class="ch">\n</span><span class="st">Generating performance report..."</span>)</span>
<span id="cb32-228"><a href="#cb32-228" aria-hidden="true" tabindex="-1"></a>report <span class="op">=</span> monitor.generate_monitoring_report()</span>
<span id="cb32-229"><a href="#cb32-229" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb32-230"><a href="#cb32-230" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Health Status: </span><span class="sc">{</span>report[<span class="st">'health_status'</span>]<span class="sc">.</span>upper()<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb32-231"><a href="#cb32-231" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb32-232"><a href="#cb32-232" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="st">'performance_stats'</span> <span class="kw">in</span> report:</span>
<span id="cb32-233"><a href="#cb32-233" aria-hidden="true" tabindex="-1"></a>    perf <span class="op">=</span> report[<span class="st">'performance_stats'</span>]</span>
<span id="cb32-234"><a href="#cb32-234" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb32-235"><a href="#cb32-235" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> <span class="st">'inference_time'</span> <span class="kw">in</span> perf:</span>
<span id="cb32-236"><a href="#cb32-236" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Inference Time - Mean: </span><span class="sc">{</span>perf[<span class="st">'inference_time'</span>][<span class="st">'mean'</span>]<span class="sc">:.3f}</span><span class="ss">s, P95: </span><span class="sc">{</span>perf[<span class="st">'inference_time'</span>][<span class="st">'p95'</span>]<span class="sc">:.3f}</span><span class="ss">s"</span>)</span>
<span id="cb32-237"><a href="#cb32-237" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb32-238"><a href="#cb32-238" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> <span class="st">'memory_usage'</span> <span class="kw">in</span> perf:</span>
<span id="cb32-239"><a href="#cb32-239" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Memory Usage - Mean: </span><span class="sc">{</span>perf[<span class="st">'memory_usage'</span>][<span class="st">'mean'</span>]<span class="sc">:.2f}</span><span class="ss">GB, Max: </span><span class="sc">{</span>perf[<span class="st">'memory_usage'</span>][<span class="st">'max'</span>]<span class="sc">:.2f}</span><span class="ss">GB"</span>)</span>
<span id="cb32-240"><a href="#cb32-240" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb32-241"><a href="#cb32-241" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> <span class="st">'accuracy'</span> <span class="kw">in</span> perf:</span>
<span id="cb32-242"><a href="#cb32-242" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Accuracy - Mean: </span><span class="sc">{</span>perf[<span class="st">'accuracy'</span>][<span class="st">'mean'</span>]<span class="sc">:.3f}</span><span class="ss">, Recent: </span><span class="sc">{</span>perf[<span class="st">'accuracy'</span>][<span class="st">'recent'</span>]<span class="sc">:.3f}</span><span class="ss">"</span>)</span>
<span id="cb32-243"><a href="#cb32-243" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb32-244"><a href="#cb32-244" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> <span class="st">'throughput'</span> <span class="kw">in</span> perf:</span>
<span id="cb32-245"><a href="#cb32-245" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Throughput: </span><span class="sc">{</span>perf[<span class="st">'throughput'</span>][<span class="st">'requests_per_second'</span>]<span class="sc">:.1f}</span><span class="ss"> req/s"</span>)</span>
<span id="cb32-246"><a href="#cb32-246" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb32-247"><a href="#cb32-247" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="st">'trends'</span> <span class="kw">in</span> report <span class="kw">and</span> <span class="st">'error'</span> <span class="kw">not</span> <span class="kw">in</span> report[<span class="st">'trends'</span>]:</span>
<span id="cb32-248"><a href="#cb32-248" aria-hidden="true" tabindex="-1"></a>    trends <span class="op">=</span> report[<span class="st">'trends'</span>]</span>
<span id="cb32-249"><a href="#cb32-249" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"</span><span class="ch">\n</span><span class="ss">Trend Analysis (</span><span class="sc">{</span>trends[<span class="st">'window_minutes'</span>]<span class="sc">}</span><span class="ss"> min window):"</span>)</span>
<span id="cb32-250"><a href="#cb32-250" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Latency trend: </span><span class="sc">{</span>trends[<span class="st">'inference_time_trend'</span>][<span class="st">'direction'</span>]<span class="sc">}</span><span class="ss"> (</span><span class="sc">{</span>trends[<span class="st">'inference_time_trend'</span>][<span class="st">'severity'</span>]<span class="sc">}</span><span class="ss"> severity)"</span>)</span>
<span id="cb32-251"><a href="#cb32-251" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Memory trend: </span><span class="sc">{</span>trends[<span class="st">'memory_usage_trend'</span>][<span class="st">'direction'</span>]<span class="sc">}</span><span class="ss"> (</span><span class="sc">{</span>trends[<span class="st">'memory_usage_trend'</span>][<span class="st">'severity'</span>]<span class="sc">}</span><span class="ss"> severity)"</span>)</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-stdout">
<pre><code>LoRA Monitoring System Demo:
==============================
LoRA Monitor initialized for adapter: production_adapter

Simulating monitoring data...

Generating performance report...
Health Status: HEALTHY
Inference Time - Mean: 0.102s, P95: 0.131s
Memory Usage - Mean: 1.99GB, Max: 2.19GB
Accuracy - Mean: 0.917, Recent: 0.926
Throughput: 543303.6 req/s

Trend Analysis (30 min window):
Latency trend: stable (low severity)
Memory trend: stable (low severity)</code></pre>
</div>
</div>
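<p>Two readings in this output come from the simulation rather than from the adapter. The throughput of 543303.6 req/s arises because the demo calls <code>log_inference</code> 50 times in a tight loop. Assuming <code>log_inference</code> stamps each call with <code>time.time()</code> (the method is defined earlier and not shown here, but the cutoff logic in <code>analyze_trends</code> suggests it), the timestamps span only about 0.09 ms. The "stable" latency trend follows from <code>analyze_trends</code> regressing against the sample index: the simulated drift adds 0.0001 s per sample, ten times below the 0.001 threshold for "increasing". The sketch below is a hypothetical helper, <code>window_summary</code>, that is not part of <code>LoRAMonitor</code>. It assumes timestamps in seconds, counts the n - 1 gaps between requests when computing throughput, and measures the latency slope per minute of wall-clock time, so the trend no longer depends on how fast requests arrive.</p>

```python
import numpy as np


def window_summary(timestamps, latencies):
    """Throughput and latency drift measured against wall-clock time.

    Hypothetical helper (not part of LoRAMonitor). Assumes `timestamps`
    are time.time()-style seconds and `latencies` are in seconds.
    """
    ts = np.asarray(timestamps, dtype=float)
    lat = np.asarray(latencies, dtype=float)
    if ts.size < 2 or ts[-1] <= ts[0]:
        raise ValueError("need at least two samples spanning a positive time")

    span = ts[-1] - ts[0]
    # n timestamps bound n - 1 inter-arrival gaps, so count gaps, not samples
    rps = (ts.size - 1) / span
    # Regress latency on elapsed seconds: the slope is in seconds per second
    slope_per_s = np.polyfit(ts - ts[0], lat, 1)[0]
    return {
        "requests_per_second": float(rps),
        "latency_slope_s_per_min": float(slope_per_s * 60),
    }


if __name__ == "__main__":
    # 50 requests spaced 0.1 s apart (arbitrary start time), with the demo's
    # drift of base_latency * i / 1000 and no noise
    ts = [1000.0 + 0.1 * i for i in range(50)]
    lat = [0.1 * (1 + i / 1000) for i in range(50)]

    summary = window_summary(ts, lat)
    print(f"{summary['requests_per_second']:.2f} req/s")  # 10.00 req/s
    print(f"{summary['latency_slope_s_per_min']:.3f} s/min")  # 0.060 s/min

    # The same drift measured per sample, as analyze_trends does
    per_sample = np.polyfit(np.arange(len(lat)), lat, 1)[0]
    print(f"{per_sample:.4f} s/sample")  # 0.0001 s/sample
```

<p>Expressing the trend threshold per unit of time, rather than per sample, keeps its meaning fixed whether the adapter serves 10 or 10,000 requests per second; a per-sample threshold silently tightens or loosens as traffic changes.</p>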
</section>
<section id="visualization-and-dashboards" class="level3">
<h3 class="anchored" data-anchor-id="visualization-and-dashboards" id="visualization-and-dashboards">Visualization and Dashboards</h3>
<div id="cell-fig-monitoring-dashboard" class="cell" data-execution_count="20">
<div class="cell-output cell-output-display">
<div id="fig-monitoring-dashboard" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-fig figure">
<div aria-describedby="fig-monitoring-dashboard-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<img src="https://theja-vanka.github.io/blogs/posts/models/vision-language-models/vision-language-lora/fig-monitoring-dashboard-output-1.png" width="1522" height="1136" class="figure-img">
</div>
<figcaption class="quarto-float-caption-bottom quarto-float-caption quarto-float-fig" id="fig-monitoring-dashboard-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Figure&nbsp;4: LoRA Monitoring Dashboard
</figcaption>
</figure>
</div>
</div>
</div>
</section>
</section>
<section id="future-directions" class="level2">
<h2 class="anchored" data-anchor-id="future-directions" id="future-directions">Future Directions</h2>
<div class="tabset-margin-container"></div><div class="panel-tabset">
<ul class="nav nav-tabs" role="tablist"><li class="nav-item" role="presentation"><a class="nav-link active" id="tabset-10-1-tab" data-bs-toggle="tab" data-bs-target="#tabset-10-1" role="tab" aria-controls="tabset-10-1" aria-selected="true" href="">Emerging Techniques</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-10-2-tab" data-bs-toggle="tab" data-bs-target="#tabset-10-2" role="tab" aria-controls="tabset-10-2" aria-selected="false" href="">Research Roadmap</a></li></ul>
<div class="tab-content">
<div id="tabset-10-1" class="tab-pane active" role="tabpanel" aria-labelledby="tabset-10-1-tab">
<section id="dynamic-lora" class="level3">
<h3 class="anchored" data-anchor-id="dynamic-lora" id="dynamic-lora">Dynamic LoRA</h3>
<ul>
<li><strong>Description</strong>: Adaptive rank and module selection during training</li>
<li><strong>Potential Impact</strong>: 30-50% efficiency improvement</li>
<li><strong>Maturity</strong>: Research phase</li>
<li><strong>Status</strong>: 🔬 Active Research</li>
</ul>
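<p>Here is a minimal sketch of the idea behind adaptive rank selection, loosely in the spirit of AdaLoRA (listed in the references). Every rank component across all adapted modules gets an importance score, and only the highest-scoring components are kept under one global rank budget. The module names, the importance scores, and the <code>allocate_ranks</code> helper are hypothetical. AdaLoRA itself uses an SVD-style parameterization with sensitivity-based importance estimates, not fixed numbers.</p>

```python
from typing import Dict, List


def allocate_ranks(importance: Dict[str, List[float]], budget: int) -> Dict[str, int]:
    """Keep the `budget` most important rank components across all modules.

    `importance[name][i]` is a (hypothetical) importance score for the i-th
    rank component of the LoRA adapter attached to module `name`.
    Returns how many components each module keeps.
    """
    # Flatten to (score, module) pairs so all modules compete for one shared budget.
    scored = [(s, name) for name, vals in importance.items() for s in vals]
    scored.sort(key=lambda pair: pair[0], reverse=True)

    ranks = {name: 0 for name in importance}
    for _, name in scored[:budget]:
        ranks[name] += 1
    return ranks


# Hypothetical scores for three adapted modules, each starting at rank 4.
scores = {
    "vision.attn.q_proj": [0.90, 0.70, 0.10, 0.05],
    "text.attn.v_proj": [0.80, 0.60, 0.50, 0.40],
    "fusion.cross_attn": [0.95, 0.20, 0.02, 0.01],
}
print(allocate_ranks(scores, budget=6))
# {'vision.attn.q_proj': 2, 'text.attn.v_proj': 3, 'fusion.cross_attn': 1}
```

<p>The total rank stays fixed at the budget. Modules whose components matter more receive a larger share of it, and modules that matter less are pruned toward rank 0.</p>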
</section>
<section id="hierarchical-lora" class="level3">
<h3 class="anchored" data-anchor-id="hierarchical-lora" id="hierarchical-lora">Hierarchical LoRA</h3>
<ul>
<li><strong>Description</strong>: Multi-level adaptation for different abstraction levels</li>
<li><strong>Potential Impact</strong>: Better transfer learning</li>
<li><strong>Maturity</strong>: Early development</li>
<li><strong>Status</strong>: 🌱 Early Development</li>
</ul>
</section>
<section id="conditional-lora" class="level3">
<h3 class="anchored" data-anchor-id="conditional-lora" id="conditional-lora">Conditional LoRA</h3>
<ul>
<li><strong>Description</strong>: Task-conditional parameter generation</li>
<li><strong>Potential Impact</strong>: Broad, on-demand task adaptation</li>
<li><strong>Maturity</strong>: Conceptual</li>
<li><strong>Status</strong>: 💡 Conceptual</li>
</ul>
</section>
<section id="federated-lora" class="level3">
<h3 class="anchored" data-anchor-id="federated-lora" id="federated-lora">Federated LoRA</h3>
<ul>
<li><strong>Description</strong>: Distributed learning with privacy preservation</li>
<li><strong>Potential Impact</strong>: Privacy-safe collaboration</li>
<li><strong>Maturity</strong>: Active research</li>
<li><strong>Status</strong>: 🔬 Active Research</li>
</ul>
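<p>A common baseline for federated LoRA is FedAvg-style aggregation. Each client trains only its adapter matrices <span class="math inline">\(A\)</span> and <span class="math inline">\(B\)</span> locally, and the server averages them, so only the small adapter is communicated: neither the base model nor the raw data leaves the clients. The sketch below uses plain nested lists and hypothetical client matrices. It also shows a known subtlety: averaging <span class="math inline">\(A\)</span> and <span class="math inline">\(B\)</span> separately does not give the average of the clients' actual weight updates <span class="math inline">\(BA\)</span>.</p>

```python
from typing import List, Tuple

Matrix = List[List[float]]


def matmul(x: Matrix, y: Matrix) -> Matrix:
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*y)] for row in x]


def average(mats: List[Matrix], weights: List[float]) -> Matrix:
    """Weighted element-wise average (weights could be client dataset sizes)."""
    total = sum(weights)
    rows, cols = len(mats[0]), len(mats[0][0])
    return [[sum(w * m[i][j] for m, w in zip(mats, weights)) / total
             for j in range(cols)] for i in range(rows)]


def fedavg_lora(a_list: List[Matrix], b_list: List[Matrix],
                weights: List[float]) -> Tuple[Matrix, Matrix]:
    """Server step: average the low-rank factors A (r x k) and B (d x r) separately."""
    return average(a_list, weights), average(b_list, weights)


# Two hypothetical clients with rank-1 adapters for a 2x2 weight matrix.
a_clients = [[[1.0, 0.0]], [[0.0, 1.0]]]      # each A is 1 x 2
b_clients = [[[1.0], [0.0]], [[0.0], [1.0]]]  # each B is 2 x 1
weights = [1.0, 1.0]

a_avg, b_avg = fedavg_lora(a_clients, b_clients, weights)
delta_from_factors = matmul(b_avg, a_avg)
delta_true_avg = average([matmul(b, a) for a, b in zip(a_clients, b_clients)], weights)
print(delta_from_factors)  # [[0.25, 0.25], [0.25, 0.25]]
print(delta_true_avg)      # [[0.5, 0.0], [0.0, 0.5]]
```

<p>In this sketch, privacy comes only from keeping raw data on the clients. Secure aggregation or differential privacy would be separate additions on top of it.</p>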
</section>
<section id="neural-architecture-lora" class="level3">
<h3 class="anchored" data-anchor-id="neural-architecture-lora" id="neural-architecture-lora">Neural Architecture LoRA</h3>
<ul>
<li><strong>Description</strong>: Architecture search for optimal LoRA configurations</li>
<li><strong>Potential Impact</strong>: Automatic discovery of optimal configurations</li>
<li><strong>Maturity</strong>: Research phase</li>
<li><strong>Status</strong>: 🔬 Research Phase</li>
</ul>
</section>
</div>
<div id="tabset-10-2" class="tab-pane" role="tabpanel" aria-labelledby="tabset-10-2-tab">
<section id="short-term-6-12-months" class="level3">
<h3 class="anchored" data-anchor-id="short-term-6-12-months" id="short-term-6-12-months">Short Term (6-12 months)</h3>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Focus Areas
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li>Improved rank selection algorithms</li>
<li>Better initialization strategies</li>
<li>Enhanced debugging tools</li>
<li>Standardized evaluation protocols</li>
</ul>
</div>
</div>
<p><strong>Expected Outcomes:</strong></p>
<ul>
<li>More stable training</li>
<li>Better out-of-box performance</li>
<li>Easier troubleshooting</li>
</ul>
</section>
<section id="medium-term-1-2-years" class="level3">
<h3 class="anchored" data-anchor-id="medium-term-1-2-years" id="medium-term-1-2-years">Medium Term (1-2 years)</h3>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Focus Areas
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li>Dynamic and adaptive LoRA</li>
<li>Multi-modal LoRA extensions</li>
<li>Automated hyperparameter optimization</li>
<li>Large-scale deployment frameworks</li>
</ul>
</div>
</div>
<p><strong>Expected Outcomes:</strong></p>
<ul>
<li>Self-optimizing systems</li>
<li>Audio-visual-text models</li>
<li>Production-ready pipelines</li>
</ul>
</section>
<section id="long-term-2-5-years" class="level3">
<h3 class="anchored" data-anchor-id="long-term-2-5-years" id="long-term-2-5-years">Long Term (2-5 years)</h3>
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Important</span>Focus Areas
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li>Theoretical understanding of adaptation</li>
<li>Novel mathematical frameworks</li>
<li>Integration with emerging architectures</li>
<li>Quantum-inspired adaptations</li>
</ul>
</div>
</div>
<p><strong>Expected Outcomes:</strong></p>
<ul>
<li>Principled design guidelines</li>
<li>Next-generation efficiency</li>
<li>Revolutionary capabilities</li>
</ul>
</section>
</div>
</div>
</div>
<section id="impact-analysis" class="level3">
<h3 class="anchored" data-anchor-id="impact-analysis" id="impact-analysis">Impact Analysis</h3>
<section id="dynamic-lora-case-study" class="level4">
<h4 class="anchored" data-anchor-id="dynamic-lora-case-study">Dynamic LoRA Case Study</h4>
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Warning</span>Predicted Impact Analysis
</div>
</div>
<div class="callout-body-container callout-body">
<p><strong>Technique</strong>: Dynamic LoRA<br>
<strong>Description</strong>: Adaptive rank and module selection during training</p>
</div>
</div>
<table class="caption-top table">
<thead>
<tr class="header">
<th>Metric</th>
<th>Value</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><strong>Efficiency Gain</strong></td>
<td>1.8x</td>
</tr>
<tr class="even">
<td><strong>Performance Improvement</strong></td>
<td>+3.0%</td>
</tr>
<tr class="odd">
<td><strong>Adoption Timeline</strong></td>
<td>6 months</td>
</tr>
<tr class="even">
<td><strong>Implementation Complexity</strong></td>
<td>Medium</td>
</tr>
<tr class="odd">
<td><strong>Research Interest Score</strong></td>
<td>0.94/1.00</td>
</tr>
</tbody>
</table>
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">gantt
    title LoRA Research Timeline
    dateFormat  YYYY-MM
    section Short Term
    Rank Selection     :active, st1, 2024-08, 6M
    Initialization     :active, st2, 2024-08, 6M
    Debugging Tools    :st3, after st1, 4M
    section Medium Term
    Dynamic LoRA       :mt1, 2025-02, 12M
    Multi-modal        :mt2, 2025-06, 18M
    Auto-optimization  :mt3, after mt1, 12M
    section Long Term
    Theory Framework   :lt1, 2026-01, 24M
    Next-gen Arch      :lt2, 2026-06, 30M
    Quantum Inspired   :lt3, 2027-01, 36M
</pre>
</div>
<p></p></figure><p></p>
</div>
</div>
</div>
</section>
<section id="summary" class="level4">
<h4 class="anchored" data-anchor-id="summary">Summary</h4>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Key Takeaways
</div>
</div>
<div class="callout-body-container callout-body">
<ol type="1">
<li><strong>Dynamic LoRA</strong> shows the most immediate promise, with a predicted 1.8x efficiency gain</li>
<li><strong>Short-term focus</strong> should be on stability and usability improvements</li>
<li><strong>Long-term vision</strong> includes theoretical breakthroughs and quantum-inspired adaptations</li>
<li><strong>Timeline</strong>: the roadmap spans horizons from 6 months to 5 years</li>
</ol>
</div>
</div>
</section>
</section>
<section id="research-opportunities" class="level3">
<h3 class="anchored" data-anchor-id="research-opportunities" id="research-opportunities">Research Opportunities</h3>
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Important</span>Key Research Domains
</div>
</div>
<div class="callout-body-container callout-body">
<p>Three primary areas have been identified for immediate investigation:</p>
</div>
</div>
<div class="quarto-layout-panel" data-layout-ncol="3">
<div class="quarto-layout-row">
<div class="card quarto-layout-cell" style="flex-basis: 33.3%;justify-content: flex-start;">
<p><strong>Theoretical Analysis</strong></p>
<ul>
<li>Better understanding of LoRA’s approximation capabilities</li>
<li>4 key research questions identified</li>
<li>Focus on mathematical foundations</li>
</ul>
</div>
<div class="card quarto-layout-cell" style="flex-basis: 33.3%;justify-content: flex-start;">
<p><strong>Architecture Specific</strong></p>
<ul>
<li>Optimized LoRA for different VLM architectures</li>
<li>4 key research questions identified</li>
<li>Vision-language model specialization</li>
</ul>
</div>
<div class="card quarto-layout-cell" style="flex-basis: 33.3%;justify-content: flex-start;">
<p><strong>Efficiency Optimization</strong></p>
<ul>
<li>Hardware-aware LoRA optimization</li>
<li>4 key research questions identified</li>
<li>Performance and resource utilization</li>
</ul>
</div>
</div>
</div>
</section>
<section id="detailed-proposals" class="level3">
<h3 class="anchored" data-anchor-id="detailed-proposals" id="detailed-proposals">Detailed Proposals</h3>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center" data-bs-toggle="collapse" data-bs-target=".callout-30-contents" aria-controls="callout-30" aria-expanded="true" aria-label="Toggle callout">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Research Proposal Details
</div>
<div class="callout-btn-toggle d-inline-block border-0 py-1 ps-1 pe-0 float-end"><i class="callout-toggle"></i></div>
</div>
<div id="callout-30" class="callout-30-contents callout-collapse collapse show">
<div class="callout-body-container callout-body">
<p><strong>Area:</strong> Theoretical Analysis<br>
<strong>Priority:</strong> HIGH<br>
<strong>Description:</strong> Better understanding of LoRA’s approximation capabilities</p>
<section id="proposal-1-theoretical-limits-investigation" class="level4">
<h4 class="anchored" data-anchor-id="proposal-1-theoretical-limits-investigation">Proposal 1: Theoretical Limits Investigation</h4>
<ul>
<li><strong>Objective:</strong> Determine the theoretical limits of low-rank approximation</li>
<li><strong>Methodology:</strong> Matrix perturbation theory</li>
<li><strong>Timeline:</strong> 12-18 months</li>
<li><strong>Expected Outcomes:</strong>
<ul>
<li>Mathematical bounds on approximation quality</li>
<li>Guidelines for rank selection</li>
<li>Theoretical framework for optimization</li>
</ul></li>
</ul>
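<p>A natural starting point for this question is the Eckart-Young-Mirsky theorem. Treat a full fine-tuning update <span class="math inline">\(\Delta W\)</span> as fixed, with singular values <span class="math inline">\(\sigma_1 \geq \sigma_2 \geq \cdots\)</span>. A LoRA update <span class="math inline">\(BA\)</span> with inner dimension <span class="math inline">\(r\)</span> has rank at most <span class="math inline">\(r\)</span>, so its approximation error is bounded below by the error of the truncated SVD:</p>
<p><span class="math display">\[
\lVert \Delta W - BA \rVert_F \;\geq\; \min_{\operatorname{rank}(M) \leq r} \lVert \Delta W - M \rVert_F \;=\; \sqrt{\sum_{i > r} \sigma_i^2}
\]</span></p>
<p>This bound only describes the best achievable rank-<span class="math inline">\(r\)</span> fit to a given update. Two questions remain open: whether LoRA trained by gradient descent actually reaches it, and how the tail singular values of real fine-tuning updates behave. Both fall within the scope of this proposal.</p>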
</section>
</div>
</div>
</div>
<section id="research-questions-framework" class="level4">
<h4 class="anchored" data-anchor-id="research-questions-framework">Research Questions Framework</h4>
<div class="tabset-margin-container"></div><div class="panel-tabset">
<ul class="nav nav-tabs" role="tablist"><li class="nav-item" role="presentation"><a class="nav-link active" id="tabset-11-1-tab" data-bs-toggle="tab" data-bs-target="#tabset-11-1" role="tab" aria-controls="tabset-11-1" aria-selected="true" href="">Theoretical</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-11-2-tab" data-bs-toggle="tab" data-bs-target="#tabset-11-2" role="tab" aria-controls="tabset-11-2" aria-selected="false" href="">Architectural</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-11-3-tab" data-bs-toggle="tab" data-bs-target="#tabset-11-3" role="tab" aria-controls="tabset-11-3" aria-selected="false" href="">Efficiency</a></li></ul>
<div class="tab-content">
<div id="tabset-11-1" class="tab-pane active" role="tabpanel" aria-labelledby="tabset-11-1-tab">
<ol type="1">
<li>What are the fundamental limits of low-rank approximation in neural networks?</li>
<li>How does rank selection impact convergence and generalization?</li>
<li>Can we establish theoretical guarantees for LoRA performance?</li>
<li>What is the relationship between rank and model capacity?</li>
</ol>
</div>
<div id="tabset-11-2" class="tab-pane" role="tabpanel" aria-labelledby="tabset-11-2-tab">
<ol type="1">
<li>How can LoRA be optimized for transformer architectures?</li>
<li>What are the best practices for multi-modal model adaptation?</li>
<li>How does LoRA performance vary across different layer types?</li>
<li>Can we develop architecture-specific rank selection strategies?</li>
</ol>
</div>
<div id="tabset-11-3" class="tab-pane" role="tabpanel" aria-labelledby="tabset-11-3-tab">
<ol type="1">
<li>What are the optimal hardware configurations for LoRA training?</li>
<li>How can we minimize memory overhead during adaptation?</li>
<li>What parallelization strategies work best for LoRA?</li>
<li>Can we develop real-time adaptation capabilities?</li>
</ol>
</div>
</div>
</div>
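<p>For the rank-versus-capacity question, simple arithmetic gives a useful baseline. A LoRA adapter on a <span class="math inline">\(d \times k\)</span> weight adds <span class="math inline">\(r(d + k)\)</span> trainable parameters (<span class="math inline">\(B\)</span> is <span class="math inline">\(d \times r\)</span>, <span class="math inline">\(A\)</span> is <span class="math inline">\(r \times k\)</span>) instead of <span class="math inline">\(dk\)</span>, and its update can have rank at most <span class="math inline">\(r\)</span>. The sketch below tabulates this for a hypothetical 4096 × 4096 projection.</p>

```python
def lora_params(d: int, k: int, r: int) -> int:
    """Trainable parameters of a rank-r LoRA adapter on a d x k weight (B: d x r, A: r x k)."""
    return r * (d + k)


def full_params(d: int, k: int) -> int:
    """Parameters updated by full fine-tuning of the same d x k weight."""
    return d * k


d = k = 4096  # hypothetical projection size
for r in (4, 8, 16, 64):
    n = lora_params(d, k, r)
    share = 100 * n / full_params(d, k)
    print(f"r={r:>3}: {n:>9,} trainable params ({share:.2f}% of full fine-tuning)")
```

<p>This count covers only adapter weights. The memory overhead during adaptation also includes optimizer state and activations, which scale differently.</p>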
</section>
</section>
<section id="impact-assessment" class="level3">
<h3 class="anchored" data-anchor-id="impact-assessment" id="impact-assessment">Impact Assessment</h3>
<div id="80219aa5" class="cell" data-execution_count="21">
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://theja-vanka.github.io/blogs/posts/models/vision-language-models/vision-language-lora/cell-22-output-1.png" width="942" height="758" class="figure-img"></p>
</figure>
</div>
</div>
</div>
</section>
<section id="impact-scores-summary" class="level3">
<h3 class="anchored" data-anchor-id="impact-scores-summary" id="impact-scores-summary">Impact Scores Summary</h3>
<table class="caption-top table">
<colgroup>
<col style="width: 17%">
<col style="width: 19%">
<col style="width: 22%">
<col style="width: 21%">
<col style="width: 19%">
</colgroup>
<thead>
<tr class="header">
<th>Research Area</th>
<th>Overall Impact</th>
<th>Scientific Impact</th>
<th>Practical Impact</th>
<th>Recommendation</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><strong>Multimodal Extensions</strong></td>
<td>0.75</td>
<td>0.79</td>
<td>0.79</td>
<td>MEDIUM PRIORITY</td>
</tr>
<tr class="even">
<td><strong>Continual Learning</strong></td>
<td>0.72</td>
<td>0.86</td>
<td>0.72</td>
<td>MEDIUM PRIORITY</td>
</tr>
<tr class="odd">
<td><strong>Architecture Specific</strong></td>
<td>0.65</td>
<td>0.84</td>
<td>0.66</td>
<td>MEDIUM PRIORITY</td>
</tr>
<tr class="even">
<td><strong>Theoretical Analysis</strong></td>
<td>0.64</td>
<td>0.75</td>
<td>0.53</td>
<td>MEDIUM PRIORITY</td>
</tr>
<tr class="odd">
<td><strong>Efficiency Optimization</strong></td>
<td>0.63</td>
<td>0.72</td>
<td>0.80</td>
<td>MEDIUM PRIORITY</td>
</tr>
</tbody>
</table>
</section>
</section>
<section id="summary-of-key-points" class="level2">
<h2 class="anchored" data-anchor-id="summary-of-key-points" id="summary-of-key-points">Summary of Key Points</h2>
<ol type="1">
<li><p><strong>Conservative Hyperparameter Initialization</strong></p>
<ul>
<li>Start with conservative hyperparameters (rank=16, alpha=16)</li>
<li>Gradually increase complexity based on validation performance</li>
<li>Avoid aggressive initial configurations, which make overfitting more likely</li>
</ul></li>
<li><p><strong>Strategic Module Selection</strong></p>
<ul>
<li>Focus on high-impact modules (attention layers, cross-modal fusion)</li>
<li>Prioritize modules that maximize efficiency gains</li>
<li>Consider computational cost vs.&nbsp;performance trade-offs</li>
</ul></li>
<li><p><strong>Comprehensive Monitoring</strong></p>
<ul>
<li>Monitor both performance and efficiency metrics throughout development</li>
<li>Track convergence patterns and training stability</li>
<li>Implement early stopping based on validation metrics</li>
</ul></li>
<li><p><strong>Debugging and Analysis Tools</strong></p>
<ul>
<li>Use appropriate debugging tools to understand adapter behavior</li>
<li>Analyze attention patterns and feature representations</li>
<li>Implement gradient flow monitoring for training diagnostics</li>
</ul></li>
<li><p><strong>Progressive Training Strategies</strong></p>
<ul>
<li>Implement progressive training strategies for stable convergence</li>
<li>Use curriculum learning approaches when appropriate</li>
<li>Consider staged training with increasing complexity</li>
</ul></li>
<li><p><strong>Memory Optimization</strong></p>
<ul>
<li>Apply memory optimization techniques for large-scale deployment</li>
<li>Implement gradient checkpointing and mixed precision training</li>
<li>Optimize batch sizes and sequence lengths</li>
</ul></li>
<li><p><strong>Production Monitoring</strong></p>
<ul>
<li>Establish comprehensive monitoring for production systems</li>
<li>Track model performance drift and adaptation effectiveness</li>
<li>Implement automated alerts for performance degradation</li>
</ul></li>
<li><p><strong>Continuous Learning</strong></p>
<ul>
<li>Stay updated with emerging techniques and research developments</li>
<li>Regularly evaluate new LoRA variants and improvements</li>
<li>Participate in community discussions and knowledge sharing</li>
</ul></li>
<li><p><strong>Task-Specific Optimization</strong></p>
<ul>
<li>Consider task-specific configurations for optimal performance</li>
<li>Adapt hyperparameters based on domain requirements</li>
<li>Fine-tune approaches for different VLM applications</li>
</ul></li>
<li><p><strong>Robust Troubleshooting</strong></p>
<ul>
<li>Implement robust troubleshooting procedures for common issues</li>
<li>Maintain comprehensive error handling and recovery mechanisms</li>
<li>Document solutions for recurring problems</li>
</ul></li>
</ol>
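<p>The rank/alpha pair in the first point interacts through the scaling factor in the original LoRA formulation, <span class="math inline">\(h = W_0 x + \frac{\alpha}{r} B A x\)</span>, where <span class="math inline">\(B\)</span> is initialized to zero so that training starts from the pretrained model. The pure-Python sketch below uses hypothetical tiny matrices, not a real model, to show both effects. At initialization the adapter is a no-op, and with rank=16 and alpha=16 the learned update is applied at scale 1.0.</p>

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(m: Matrix, x: Vector) -> Vector:
    return [sum(w * xi for w, xi in zip(row, x)) for row in m]


def lora_forward(w0: Matrix, a: Matrix, b: Matrix, x: Vector, alpha: float) -> Vector:
    """h = W0 x + (alpha / r) * B (A x), where r is the number of rows of A."""
    r = len(a)
    base = matvec(w0, x)                 # frozen pretrained path
    update = matvec(b, matvec(a, x))     # low-rank adapter path
    return [h + (alpha / r) * u for h, u in zip(base, update)]


d, k, r = 3, 2, 16
w0 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]                  # frozen weight, d x k
a = [[0.01 * (i + j) for j in range(k)] for i in range(r)]  # A: r x k (random in practice)
b_init = [[0.0] * r for _ in range(d)]                      # B: d x r, zero-initialized
x = [2.0, 3.0]

print(lora_forward(w0, a, b_init, x, alpha=16))  # [2.0, 3.0, 5.0], identical to W0 x
```

<p>Because the update is multiplied by <span class="math inline">\(\alpha / r\)</span>, setting alpha equal to the rank gives a scale of 1. Raising the rank without raising alpha shrinks the contribution of each component.</p>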
</section>
<section id="implementation-checklist" class="level2">
<h2 class="anchored" data-anchor-id="implementation-checklist" id="implementation-checklist">Implementation Checklist</h2>
<ul class="task-list">
<li><label><input type="checkbox">Initialize with conservative hyperparameters</label></li>
<li><label><input type="checkbox">Identify and target high-impact modules</label></li>
<li><label><input type="checkbox">Set up comprehensive monitoring systems</label></li>
<li><label><input type="checkbox">Configure debugging and analysis tools</label></li>
<li><label><input type="checkbox">Implement progressive training pipeline</label></li>
<li><label><input type="checkbox">Apply memory optimization techniques</label></li>
<li><label><input type="checkbox">Establish production monitoring</label></li>
<li><label><input type="checkbox">Create update and maintenance procedures</label></li>
<li><label><input type="checkbox">Customize for specific task requirements</label></li>
<li><label><input type="checkbox">Prepare troubleshooting documentation</label></li>
</ul>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Pro Tip
</div>
</div>
<div class="callout-body-container callout-body">
<p>Remember that successful LoRA implementation is an iterative process. Start simple, monitor carefully, and gradually optimize based on empirical results rather than theoretical assumptions.</p>
</div>
</div>
</section>
<section id="future-outlook" class="level2">
<h2 class="anchored" data-anchor-id="future-outlook" id="future-outlook">Future Outlook</h2>
<p>As the field continues to evolve, LoRA and its variants will likely become even more sophisticated, enabling more efficient and capable multimodal AI systems. The techniques and principles outlined in this guide provide a solid foundation for leveraging these advances in your own Vision-Language Model applications.</p>
</section>
<section id="resources-for-further-learning" class="level2">
<h2 class="anchored" data-anchor-id="resources-for-further-learning" id="resources-for-further-learning">Resources for Further Learning</h2>
<ul>
<li><strong>Hugging Face PEFT</strong>: Parameter-Efficient Fine-Tuning library</li>
<li><strong>LoRA Paper</strong>: “LoRA: Low-Rank Adaptation of Large Language Models” (Hu et al., 2021)</li>
<li><strong>CLIP Paper</strong>: “Learning Transferable Visual Models From Natural Language Supervision” (Radford et al., 2021)</li>
<li><strong>LLaVA Paper</strong>: “Visual Instruction Tuning” (Liu et al., 2023)</li>
<li><strong>AdaLoRA Paper</strong>: “Adaptive Budget Allocation for Parameter-Efficient Fine-Tuning” (Zhang et al., 2023)</li>
</ul>
</section>
<section id="references" class="level2">
<h2 class="anchored" data-anchor-id="references" id="references">References</h2>
<ol type="1">
<li><p>Hu, E. J., Shen, Y., Wallis, P., Allen-Zhu, Z., Li, Y., Wang, S., … &amp; Chen, W. (2021). LoRA: Low-Rank Adaptation of Large Language Models. <em>arXiv preprint arXiv:2106.09685</em>.</p></li>
<li><p>Radford, A., Kim, J. W., Hallacy, C., Ramesh, A., Goh, G., Agarwal, S., … &amp; Sutskever, I. (2021). Learning Transferable Visual Models From Natural Language Supervision. <em>International Conference on Machine Learning</em>.</p></li>
<li><p>Li, J., Li, D., Xiong, C., &amp; Hoi, S. (2022). BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation. <em>International Conference on Machine Learning</em>.</p></li>
<li><p>Liu, H., Li, C., Wu, Q., &amp; Lee, Y. J. (2023). Visual Instruction Tuning. <em>arXiv preprint arXiv:2304.08485</em>.</p></li>
<li><p>Zhang, Q., Chen, M., Bukharin, A., He, P., Cheng, Y., Chen, W., &amp; Zhao, T. (2023). AdaLoRA: Adaptive Budget Allocation for Parameter-Efficient Fine-Tuning. <em>International Conference on Learning Representations</em>.</p></li>
</ol>


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[ControlNet: Revolutionizing AI Image Generation with Precise Control]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/generative-ai/control-net/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/generative-ai/control-net/</guid>
      <pubDate>Tue, 22 Jul 2025 00:00:00 GMT</pubDate>
      
      <category>research</category>
      <category>intermediate</category>
      <content:encoded><![CDATA[






<section id="controlnet-revolutionizing-ai-image-generation-with-precise-control" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/generative-ai/control-net/cn.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>ControlNet represents a groundbreaking advancement in the field of AI-generated imagery, providing unprecedented control over the output of diffusion models like Stable Diffusion. Developed by researchers at Stanford University and released in early 2023, ControlNet has fundamentally changed how artists, designers, and developers approach AI image generation by enabling precise spatial control while maintaining the creative power of the underlying diffusion model.</p>
<p>Unlike traditional text-to-image generation where users rely solely on prompts and hope for desired compositions, ControlNet introduces conditional inputs that guide the generation process through various control mechanisms such as edge maps, depth maps, pose detection, and semantic segmentation. This innovation bridges the gap between creative intent and AI output, making AI image generation more predictable and professionally viable.</p>
</section>
<section id="technical-architecture" class="level2">
<h2 class="anchored" data-anchor-id="technical-architecture" id="technical-architecture">Technical Architecture</h2>
<section id="core-concept" class="level3">
<h3 class="anchored" data-anchor-id="core-concept" id="core-concept">Core Concept</h3>
<p>ControlNet operates as an additional neural network architecture that works alongside pre-trained diffusion models. Rather than modifying the original model weights, ControlNet creates a parallel pathway that processes control inputs and injects spatial guidance into the generation process. This approach preserves the original model’s capabilities while adding new functionality.</p>
<p>The architecture consists of two main components:</p>
<ol type="1">
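<li><strong>Trainable Copy</strong>: A trainable duplicate of the encoding layers of the original diffusion model (for Stable Diffusion, the U-Net's encoder and middle blocks), while the original weights stay locked</li>
<li><strong>Zero Convolution Layers</strong>: 1×1 convolution layers whose weights and biases are initialized to zero, so the control branch contributes nothing at first and gradually learns to inject control information during training</li>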
</ol>
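<p>The claim that zero initialization preserves the original model can be checked directly. A 1×1 convolution is just a per-pixel linear map across channels, so a zero-initialized one outputs all zeros. The pure-Python sketch below uses hypothetical feature sizes and nested lists instead of tensors. It adds a zero-convolved control branch to a frozen block's output and shows that, at initialization, the result is unchanged.</p>

```python
from typing import List, Tuple

# Feature maps as [channels][height][width] nested lists.
FeatureMap = List[List[List[float]]]


def conv1x1(x: FeatureMap, weight: List[List[float]], bias: List[float]) -> FeatureMap:
    """1x1 convolution: out[o][i][j] = bias[o] + sum_c weight[o][c] * x[c][i][j]."""
    h, w = len(x[0]), len(x[0][0])
    return [[[bias[o] + sum(weight[o][c] * x[c][i][j] for c in range(len(x)))
              for j in range(w)] for i in range(h)]
            for o in range(len(weight))]


def zero_conv(channels_in: int, channels_out: int) -> Tuple[List[List[float]], List[float]]:
    """Create a zero-initialized 1x1 conv (all weights and biases are zero)."""
    weight = [[0.0] * channels_in for _ in range(channels_out)]
    bias = [0.0] * channels_out
    return weight, bias


def add(x: FeatureMap, y: FeatureMap) -> FeatureMap:
    return [[[a + b for a, b in zip(ra, rb)] for ra, rb in zip(ca, cb)]
            for ca, cb in zip(x, y)]


# Hypothetical 2-channel, 2x2 features.
locked_out = [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]     # frozen block output
control_feat = [[[9.0, 9.0], [9.0, 9.0]], [[-3.0, 0.5], [2.0, 1.0]]]  # trainable-copy output

w, b = zero_conv(2, 2)
combined = add(locked_out, conv1x1(control_feat, w, b))
print(combined == locked_out)  # True: the control branch contributes nothing yet
```

<p>Training can still move these weights away from zero. The gradient of the output with respect to a zero-initialized weight is the layer's input feature, which is generally nonzero, so the layer learns even though it starts silent.</p>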
</section>
<section id="how-controlnet-works" class="level3">
<h3 class="anchored" data-anchor-id="how-controlnet-works" id="how-controlnet-works">How ControlNet Works</h3>
<p>The ControlNet process follows these key steps:</p>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>ControlNet Process Flow
</div>
</div>
<div class="callout-body-container callout-body">
<ol type="1">
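<li><strong>Control Input Processing</strong>: The control image (edge map, depth map, etc.) is encoded into a feature map at the model’s latent resolution, added through a zero convolution to the input of the trainable copy, and processed by that copy of the original model’s encoder</li>
<li><strong>Feature Integration</strong>: The trainable copy’s outputs pass through zero convolutions and are added to the skip connections and middle block of the locked model’s U-Net</li>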
<li><strong>Guided Generation</strong>: The combined features guide the denoising process, ensuring the output adheres to the spatial constraints while maintaining semantic coherence</li>
</ol>
</div>
</div>
<p>This design is particularly elegant because it allows the original model to retain its learned knowledge while gradually incorporating new control information through the zero-initialized layers.</p>
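<p>The original ControlNet paper expresses this design as a single equation. In it, <span class="math inline">\(\mathcal{F}(\cdot;\Theta)\)</span> is a locked network block and <span class="math inline">\(\Theta_c\)</span> holds the parameters of its trainable copy. <span class="math inline">\(\mathcal{Z}(\cdot;\cdot)\)</span> is a zero convolution with parameters <span class="math inline">\(\Theta_{z1}\)</span> or <span class="math inline">\(\Theta_{z2}\)</span>, <span class="math inline">\(x\)</span> is the input feature map, and <span class="math inline">\(c\)</span> is the condition:</p>
<p><span class="math display">\[
y_c = \mathcal{F}(x;\Theta) + \mathcal{Z}\big(\mathcal{F}(x + \mathcal{Z}(c;\Theta_{z1});\Theta_c);\Theta_{z2}\big)
\]</span></p>
<p>At initialization both zero convolutions output zero, so the second term vanishes and <span class="math inline">\(y_c = \mathcal{F}(x;\Theta)\)</span>. The combined model therefore starts out behaving exactly like the pretrained one.</p>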
</section>
</section>
<section id="types-of-controlnet-models" class="level2">
<h2 class="anchored" data-anchor-id="types-of-controlnet-models" id="types-of-controlnet-models">Types of ControlNet Models</h2>
<section id="canny-edge-detection" class="level3">
<h3 class="anchored" data-anchor-id="canny-edge-detection" id="canny-edge-detection">Canny Edge Detection</h3>
<p>The Canny ControlNet is one of the most popular and versatile control methods. It uses the Canny edge detection algorithm to create line drawings that preserve the structural composition of reference images.</p>
<div class="tabset-margin-container"></div><div class="panel-tabset">
<ul class="nav nav-tabs" role="tablist"><li class="nav-item" role="presentation"><a class="nav-link active" id="tabset-1-1-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-1" role="tab" aria-controls="tabset-1-1" aria-selected="true" href="">Use Cases</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-1-2-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-2" role="tab" aria-controls="tabset-1-2" aria-selected="false" href="">Technical Details</a></li></ul>
<div class="tab-content">
<div id="tabset-1-1" class="tab-pane active" role="tabpanel" aria-labelledby="tabset-1-1-tab">
<ul>
<li>Converting sketches to detailed artwork</li>
<li>Maintaining architectural layouts</li>
<li>Preserving character poses and proportions</li>
<li>Creating variations while keeping composition intact</li>
</ul>
</div>
<div id="tabset-1-2" class="tab-pane" role="tabpanel" aria-labelledby="tabset-1-2-tab">
<p>Canny edge detection identifies areas of rapid intensity change in images, creating clean line drawings that capture essential structural information without color or texture details. The ControlNet then uses these edges as spatial constraints during generation.</p>
</div>
</div>
</div>
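<p>To make the edge-extraction step concrete, here is a simplified pure-Python sketch of the core Canny stages: Sobel gradients, non-maximum suppression, and hysteresis thresholding with a low and a high threshold. It works on small grayscale images stored as nested lists. It skips the Gaussian smoothing step of the full algorithm and keeps ties during non-maximum suppression, so a sharp step edge comes out two pixels wide. In practice, ControlNet preprocessing typically calls a library implementation such as OpenCV's <code>cv2.Canny(image, low, high)</code>.</p>

```python
import math
from collections import deque
from typing import List, Tuple

Image = List[List[float]]


def sobel(img: Image) -> Tuple[Image, Image]:
    """Gradient magnitude and direction (degrees in [0, 180)) from 3x3 Sobel kernels."""
    h, w = len(img), len(img[0])
    mag = [[0.0] * w for _ in range(h)]
    ang = [[0.0] * w for _ in range(h)]
    for y in range(1, h - 1):  # border pixels are left at zero
        for x in range(1, w - 1):
            gx = (img[y - 1][x + 1] + 2 * img[y][x + 1] + img[y + 1][x + 1]
                  - img[y - 1][x - 1] - 2 * img[y][x - 1] - img[y + 1][x - 1])
            gy = (img[y + 1][x - 1] + 2 * img[y + 1][x] + img[y + 1][x + 1]
                  - img[y - 1][x - 1] - 2 * img[y - 1][x] - img[y - 1][x + 1])
            mag[y][x] = math.hypot(gx, gy)
            ang[y][x] = math.degrees(math.atan2(gy, gx)) % 180
    return mag, ang


def non_max_suppression(mag: Image, ang: Image) -> Image:
    """Keep only pixels that are local maxima along the gradient direction."""
    h, w = len(mag), len(mag[0])
    out = [[0.0] * w for _ in range(h)]
    for y in range(1, h - 1):
        for x in range(1, w - 1):
            a = ang[y][x]
            if a < 22.5 or a >= 157.5:   # horizontal gradient: compare left/right
                n1, n2 = mag[y][x - 1], mag[y][x + 1]
            elif a < 67.5:               # diagonal gradient (down-right in image coords)
                n1, n2 = mag[y - 1][x - 1], mag[y + 1][x + 1]
            elif a < 112.5:              # vertical gradient: compare up/down
                n1, n2 = mag[y - 1][x], mag[y + 1][x]
            else:                        # diagonal gradient (down-left)
                n1, n2 = mag[y - 1][x + 1], mag[y + 1][x - 1]
            if mag[y][x] >= n1 and mag[y][x] >= n2:
                out[y][x] = mag[y][x]
    return out


def hysteresis(nms: Image, low: float, high: float) -> List[List[int]]:
    """Strong pixels (>= high) seed edges; weak pixels (>= low) join if 8-connected to one."""
    h, w = len(nms), len(nms[0])
    edges = [[0] * w for _ in range(h)]
    queue = deque()
    for y in range(h):
        for x in range(w):
            if nms[y][x] >= high:
                edges[y][x] = 255
                queue.append((y, x))
    while queue:
        y, x = queue.popleft()
        for dy in (-1, 0, 1):
            for dx in (-1, 0, 1):
                ny, nx = y + dy, x + dx
                if (0 <= ny < h and 0 <= nx < w and edges[ny][nx] == 0
                        and nms[ny][nx] >= low):
                    edges[ny][nx] = 255
                    queue.append((ny, nx))
    return edges


def canny_like(img: Image, low: float, high: float) -> List[List[int]]:
    mag, ang = sobel(img)
    return hysteresis(non_max_suppression(mag, ang), low, high)


# A synthetic 8x8 image with a vertical step edge between columns 3 and 4.
step = [[0.0] * 4 + [100.0] * 4 for _ in range(8)]
for row in canny_like(step, low=100, high=200):
    print("".join("#" if v else "." for v in row))
```

<p>The two thresholds are the part users most often tune. Lowering them keeps more weak edges, which gives the ControlNet more (and noisier) structure to follow.</p>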
</section>
<section id="depth-map-control" class="level3">
<h3 class="anchored" data-anchor-id="depth-map-control" id="depth-map-control">Depth Map Control</h3>
<p>Depth ControlNet utilizes depth information to control the three-dimensional structure of generated images. This is particularly powerful for architectural visualization and scene composition.</p>
<p><strong>Applications:</strong></p>
<ul>
<li>Interior design visualization</li>
<li>Landscape generation with specific topography</li>
<li>Product placement in 3D space</li>
<li>Architectural rendering</li>
</ul>
<p><strong>Implementation:</strong> Depth maps are typically generated with monocular depth estimators such as MiDaS, or rendered from 3D software. The depth information is encoded as a grayscale image in which lighter pixels represent closer objects and darker pixels represent farther ones.</p>
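<p>The encoding step can be sketched as follows, assuming you start from metric depth (for example, rendered from 3D software). The sketch normalizes inverse depth to 8-bit grayscale so that nearer surfaces receive larger pixel values, mirroring the relative inverse depth that MiDaS-style estimators produce. The <code>depth_to_control_image</code> helper and the sample depths are hypothetical.</p>

```python
from typing import List


def depth_to_control_image(depth: List[List[float]], eps: float = 1e-6) -> List[List[int]]:
    """Convert metric depth (larger = farther) into an 8-bit grayscale control image.

    Inverse depth is min-max normalized to [0, 255], so the nearest surface
    becomes 255 (white) and the farthest becomes 0 (black).
    """
    inv = [[1.0 / max(d, eps) for d in row] for row in depth]
    lo = min(min(row) for row in inv)
    hi = max(max(row) for row in inv)
    span = (hi - lo) or 1.0  # avoid division by zero for a perfectly flat scene
    return [[round(255 * (v - lo) / span) for v in row] for row in inv]


# Hypothetical depths in meters for a 2x2 image.
depth_m = [[1.0, 2.0],
           [4.0, 8.0]]
print(depth_to_control_image(depth_m))  # [[255, 109], [36, 0]]
```

<p>Using inverse depth rather than raw depth spends more of the 256 gray levels on nearby geometry, where detail usually matters most for composition.</p>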
</section>
<section id="openpose-human-detection" class="level3">
<h3 class="anchored" data-anchor-id="openpose-human-detection" id="openpose-human-detection">OpenPose Human Detection</h3>
<p>The OpenPose ControlNet focuses specifically on human pose control, using skeletal keypoint detection to guide the generation of human figures in specific poses.</p>
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Important</span>OpenPose Features
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li>18-keypoint body skeleton detection</li>
<li>Hand and face pose estimation</li>
<li>Multi-person pose control</li>
<li>Precise gesture and posture control</li>
</ul>
</div>
</div>
<p><strong>Professional Applications:</strong></p>
<ul>
<li>Fashion photography concepts</li>
<li>Sports pose illustration</li>
<li>Dance and movement studies</li>
<li>Character design and animation pre-visualization</li>
</ul>
</section>
<section id="scribble-control" class="level3">
<h3 class="anchored" data-anchor-id="scribble-control" id="scribble-control">Scribble Control</h3>
<p>Scribble ControlNet allows users to provide rough sketches or scribbles as control input, making it highly accessible for quick concept development.</p>
<p><strong>Advantages:</strong></p>
<ul>
<li>No artistic skill required</li>
<li>Rapid prototyping</li>
<li>Intuitive control method</li>
<li>Compatible with touchscreen devices</li>
</ul>
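<p>Scribble ControlNets are usually fed white strokes on a black background, but people tend to draw dark lines on light paper. The sketch below assumes that convention. It thresholds an RGB drawing and inverts it. The threshold of 127 is an illustrative choice, and <code>scribble_to_control</code> is a hypothetical helper.</p>

```python
import numpy as np


def scribble_to_control(image_rgb: np.ndarray, threshold: int = 127) -> np.ndarray:
    """Dark strokes on a light background -> white strokes on black (uint8, HxW)."""
    control = np.zeros(image_rgb.shape[:2], dtype=np.uint8)
    # A pixel counts as a stroke if its darkest channel falls below the threshold.
    control[image_rgb.min(axis=2) < threshold] = 255
    return control


drawing = np.full((4, 4, 3), 255, dtype=np.uint8)  # white canvas
drawing[1, 1:3] = 0                                 # a short black stroke
print(scribble_to_control(drawing))
```
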
</section>
<section id="semantic-segmentation" class="level3">
<h3 class="anchored" data-anchor-id="semantic-segmentation" id="semantic-segmentation">Semantic Segmentation</h3>
<p>This ControlNet variant uses semantic segmentation maps in which each color represents an object category (sky, trees, buildings, etc.). The original ControlNet segmentation model follows the ADE20K color convention.</p>
<p><strong>Professional Use Cases:</strong></p>
<ul>
<li>Landscape composition planning</li>
<li>Urban planning visualization</li>
<li>Environmental concept art</li>
<li>Scene layout design</li>
</ul>
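<p>A segmentation control image is typically rendered from a map of class indices through a color lookup table. The sketch below shows that step. The palette is a placeholder, not an official one. Substitute the palette your segmentation ControlNet was trained with, for example ADE20K's. <code>colorize</code> is a hypothetical helper.</p>

```python
import numpy as np

# Placeholder palette (class index -> RGB). Illustrative colors only;
# replace with the palette your segmentation ControlNet expects.
PALETTE = {
    0: (0, 0, 0),        # unlabelled
    1: (70, 130, 180),   # sky
    2: (34, 139, 34),    # tree
    3: (178, 34, 34),    # building
}


def colorize(label_map: np.ndarray, palette: dict) -> np.ndarray:
    """Map an HxW array of class indices to an HxWx3 uint8 color image."""
    lut = np.zeros((max(palette) + 1, 3), dtype=np.uint8)
    for index, rgb in palette.items():
        lut[index] = rgb
    return lut[label_map]   # fancy indexing applies the lookup table per pixel


labels = np.array([[1, 1, 1], [2, 3, 3]])  # toy layout: sky above, tree and building below
print(colorize(labels, PALETTE).shape)
```
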
</section>
<section id="normal-map-control" class="level3">
<h3 class="anchored" data-anchor-id="normal-map-control" id="normal-map-control">Normal Map Control</h3>
<p>Normal maps encode per-pixel surface orientation, storing the x, y and z components of each surface normal in the red, green and blue channels. This allows precise control over lighting and surface detail in generated images.</p>
<p><strong>Applications:</strong></p>
<ul>
<li>Product visualization</li>
<li>Material design</li>
<li>Texture synthesis</li>
<li>3D rendering enhancement</li>
</ul>
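<p>When no 3D model is available, a normal map can be approximated from a depth map with finite differences and then packed into RGB. The sketch below assumes NumPy and a depth array. Axis and sign conventions, such as whether green points up and whether depth grows away from the camera, differ between tools, so treat the ones used here as assumptions.</p>

```python
import numpy as np


def normals_from_depth(depth: np.ndarray) -> np.ndarray:
    """Estimate unit surface normals (HxWx3) from a depth map via finite differences."""
    dz_dy, dz_dx = np.gradient(depth.astype(np.float64))  # axis 0 = rows (y), axis 1 = cols (x)
    normals = np.stack([-dz_dx, -dz_dy, np.ones_like(dz_dx)], axis=-1)
    normals /= np.linalg.norm(normals, axis=-1, keepdims=True)
    return normals


def encode_normals(normals: np.ndarray) -> np.ndarray:
    """Map components from [-1, 1] to [0, 255] so x, y, z become R, G, B."""
    return np.round((normals * 0.5 + 0.5) * 255.0).astype(np.uint8)


flat = encode_normals(normals_from_depth(np.zeros((3, 3))))
print(flat[0, 0])  # a flat surface facing the camera: [128 128 255]
```
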
</section>
<section id="line-art-lineart" class="level3">
<h3 class="anchored" data-anchor-id="line-art-lineart" id="line-art-lineart">Line Art (Lineart)</h3>
<p>Specialized for clean line drawings, this ControlNet turns line art into fully rendered illustrations. ControlNet 1.1 also provides a dedicated Lineart Anime variant for anime-style line art.</p>
<p><strong>Strengths:</strong></p>
<ul>
<li>Anime and manga artwork</li>
<li>Technical illustrations</li>
<li>Clean vector-style outputs</li>
<li>Precise line preservation</li>
</ul>
</section>
</section>
<section id="advanced-controlnet-techniques" class="level2">
<h2 class="anchored" data-anchor-id="advanced-controlnet-techniques" id="advanced-controlnet-techniques">Advanced ControlNet Techniques</h2>
<section id="multi-controlnet-workflows" class="level3">
<h3 class="anchored" data-anchor-id="multi-controlnet-workflows" id="multi-controlnet-workflows">Multi-ControlNet Workflows</h3>
<p>One of ControlNet’s most powerful features is the ability to combine multiple control types simultaneously. This enables complex, multi-layered control over the generation process.</p>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Common ControlNet Combinations
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li><strong>Canny + Depth</strong>: Structural control with 3D spatial awareness</li>
<li><strong>OpenPose + Canny</strong>: Human pose with environmental structure</li>
<li><strong>Depth + Semantic Segmentation</strong>: 3D layout with object placement control</li>
<li><strong>Normal Map + Canny</strong>: Surface detail with edge preservation</li>
</ul>
</div>
</div>
<p><strong>Implementation Considerations:</strong> When using multiple ControlNets, careful weight balancing is crucial. Each ControlNet has a weight parameter (typically 0.0 to 2.0) that determines its influence on the final output. Higher weights increase control strength but may reduce creative flexibility.</p>
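<p>Conceptually, each ControlNet produces residual feature maps that are added to the base U-Net's skip connections. With several ControlNets, each set of residuals is scaled by its weight and the results are summed. The NumPy sketch below models that arithmetic, with plain arrays standing in for feature maps. It is our simplified model of how libraries such as Diffusers combine ControlNets, not their code. Both helper names are hypothetical.</p>

```python
import numpy as np


def combine_control_residuals(residual_sets, weights):
    """Sum per-block residuals from several ControlNets, each scaled by its weight."""
    if len(residual_sets) != len(weights):
        raise ValueError("one weight per ControlNet is required")
    combined = None
    for residuals, weight in zip(residual_sets, weights):
        scaled = [weight * r for r in residuals]
        combined = scaled if combined is None else [c + s for c, s in zip(combined, scaled)]
    return combined


def apply_to_skips(skip_features, combined):
    """Add the combined control residuals to the U-Net's skip-connection features."""
    return [s + c for s, c in zip(skip_features, combined)]


# Two feature blocks per ControlNet (toy shapes).
canny = [np.ones((2, 2)), np.ones((4,))]
depth = [2 * np.ones((2, 2)), 2 * np.ones((4,))]
combined = combine_control_residuals([canny, depth], weights=[1.0, 0.5])
skips = apply_to_skips([np.zeros((2, 2)), np.zeros((4,))], combined)
print(skips[0])
```
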
</section>
<section id="controlnet-preprocessing" class="level3">
<h3 class="anchored" data-anchor-id="controlnet-preprocessing" id="controlnet-preprocessing">ControlNet Preprocessing</h3>
<p>Preprocessing is critical for optimal ControlNet performance. Each control type requires specific preprocessing to generate appropriate control images:</p>
<div id="tbl-preprocessing" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-tbl figure">
<figcaption class="quarto-float-caption-top quarto-float-caption quarto-float-tbl" id="tbl-preprocessing-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Table&nbsp;1: ControlNet Preprocessing Parameters
</figcaption>
<div aria-describedby="tbl-preprocessing-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<table class="caption-top table">
<colgroup>
<col style="width: 28%">
<col style="width: 55%">
<col style="width: 15%">
</colgroup>
<thead>
<tr class="header">
<th>Control Type</th>
<th>Preprocessing Parameters</th>
<th>Notes</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><strong>Canny</strong></td>
<td>Low threshold: 100<br>High threshold: 200<br>Gaussian blur: Optional</td>
<td>Captures fine details and strong edges</td>
</tr>
<tr class="even">
<td><strong>Depth</strong></td>
<td>Depth estimation model<br>Depth range normalization<br>Smoothing</td>
<td>MiDaS, DPT model selection</td>
</tr>
<tr class="odd">
<td><strong>OpenPose</strong></td>
<td>Model selection<br>Keypoint confidence<br>Hand/face detection</td>
<td>OpenPose, MediaPipe, DWPose</td>
</tr>
</tbody>
</table>
</div>
</figure>
</div>
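<p>The low and high Canny thresholds in the table control the hysteresis step of edge detection. In practice you would call OpenCV's <code>cv2.Canny(image, 100, 200)</code>. The sketch below isolates only that final step, to show what the two numbers do. Pixels at or above the high threshold are kept. Pixels between the thresholds are kept only if they connect to a kept pixel. It assumes gradient magnitudes are already computed and omits the blur, gradient and non-maximum suppression stages.</p>

```python
import numpy as np


def hysteresis_threshold(magnitude: np.ndarray, low: float = 100.0, high: float = 200.0) -> np.ndarray:
    """Keep strong edges (>= high) plus weak edges (>= low) 8-connected to them."""
    strong = magnitude >= high
    weak = (magnitude >= low) & ~strong
    edges = strong.copy()
    h, w = edges.shape
    while True:
        padded = np.pad(edges, 1)              # pads the boolean mask with False
        near_edge = np.zeros_like(edges)
        for dy in (-1, 0, 1):
            for dx in (-1, 0, 1):
                near_edge |= padded[1 + dy:1 + dy + h, 1 + dx:1 + dx + w]
        grown = edges | (weak & near_edge)
        if np.array_equal(grown, edges):       # no new pixels joined: done
            return edges
        edges = grown


row = np.array([[250.0, 150.0, 150.0, 50.0, 150.0]])
print(hysteresis_threshold(row))  # the isolated weak pixel at the end is dropped
```
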
</section>
<section id="regional-control-techniques" class="level3">
<h3 class="anchored" data-anchor-id="regional-control-techniques" id="regional-control-techniques">Regional Control Techniques</h3>
<p>Advanced users can implement regional control by masking different areas of the control input, allowing for varied control strength across different parts of the image.</p>
<p><strong>Methods:</strong></p>
<ul>
<li><strong>Masked ControlNet</strong>: Apply different control types to different regions</li>
<li><strong>Gradient Masks</strong>: Gradual transition between controlled and uncontrolled areas</li>
<li><strong>Layered Control</strong>: Stack multiple control influences with different regional masks</li>
</ul>
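<p>One simple way to realize a gradient mask is to multiply the control residual by a spatial weight map that fades from 1 to 0. The sketch below assumes the residual is a (C, H, W) array already at the mask's resolution. Real residuals come at several resolutions, so the mask would need resizing per block. Both helpers are hypothetical. Tools such as ComfyUI or the Automatic1111 extension implement regional control in their own ways.</p>

```python
import numpy as np


def horizontal_gradient_mask(height: int, width: int, start: float = 0.25, end: float = 0.75) -> np.ndarray:
    """1.0 left of `start`, 0.0 right of `end`, linear ramp in between (fractions of width)."""
    x = (np.arange(width) + 0.5) / width                 # pixel centres in [0, 1]
    ramp = np.clip((end - x) / (end - start), 0.0, 1.0)
    return np.tile(ramp, (height, 1))                    # shape (height, width)


def apply_regional_mask(residual: np.ndarray, mask: np.ndarray, weight: float = 1.0) -> np.ndarray:
    """Scale a (C, H, W) control residual per pixel; the mask broadcasts over channels."""
    return residual * weight * mask


mask = horizontal_gradient_mask(4, 8)
masked = apply_regional_mask(np.ones((3, 4, 8)), mask, weight=0.9)
print(mask[0].round(2))
```
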
</section>
</section>
<section id="professional-workflows-and-applications" class="level2">
<h2 class="anchored" data-anchor-id="professional-workflows-and-applications" id="professional-workflows-and-applications">Professional Workflows and Applications</h2>
<section id="concept-art-and-pre-visualization" class="level3">
<h3 class="anchored" data-anchor-id="concept-art-and-pre-visualization" id="concept-art-and-pre-visualization">Concept Art and Pre-visualization</h3>
<p>ControlNet has revolutionized concept art workflows by enabling rapid iteration and precise control over composition and lighting.</p>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Concept Art Workflow Example
</div>
</div>
<div class="callout-body-container callout-body">
<ol type="1">
<li>Create rough 3D blockout or sketch</li>
<li>Generate depth map and normal map</li>
<li>Use ControlNet to generate multiple style variations</li>
<li>Refine with additional ControlNet passes</li>
<li>Final polish with traditional digital painting techniques</li>
</ol>
</div>
</div>
</section>
<section id="architectural-visualization" class="level3">
<h3 class="anchored" data-anchor-id="architectural-visualization" id="architectural-visualization">Architectural Visualization</h3>
<p>Architects and designers use ControlNet to quickly generate photorealistic renderings from technical drawings and 3D models.</p>
<p><strong>Process:</strong></p>
<ol type="1">
<li>Export line drawings from CAD software</li>
<li>Create depth maps from 3D models</li>
<li>Generate semantic segmentation for material control</li>
<li>Use multi-ControlNet setup for comprehensive control</li>
<li>Iterate on lighting and atmosphere with prompt variations</li>
</ol>
</section>
<section id="fashion-and-product-design" class="level3">
<h3 class="anchored" data-anchor-id="fashion-and-product-design" id="fashion-and-product-design">Fashion and Product Design</h3>
<p>ControlNet enables precise product placement and modeling scenarios without expensive photoshoots.</p>
<p><strong>Applications:</strong></p>
<ul>
<li>Virtual try-on visualization</li>
<li>Product catalog generation</li>
<li>Fashion pose and styling exploration</li>
<li>Marketing material creation</li>
</ul>
</section>
<section id="film-and-animation-pre-production" class="level3">
<h3 class="anchored" data-anchor-id="film-and-animation-pre-production" id="film-and-animation-pre-production">Film and Animation Pre-production</h3>
<p>The film industry uses ControlNet for storyboarding, concept development, and pre-visualization.</p>
<p><strong>Benefits:</strong></p>
<ul>
<li>Rapid scene composition testing</li>
<li>Character pose and expression studies</li>
<li>Environment and set design exploration</li>
<li>Visual effects planning</li>
</ul>
</section>
</section>
<section id="technical-implementation" class="level2">
<h2 class="anchored" data-anchor-id="technical-implementation" id="technical-implementation">Technical Implementation</h2>
<section id="model-training-and-fine-tuning" class="level3">
<h3 class="anchored" data-anchor-id="model-training-and-fine-tuning" id="model-training-and-fine-tuning">Model Training and Fine-tuning</h3>
<p>Understanding ControlNet training helps users optimize their workflows and create custom control types.</p>
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Warning</span>Training Process Steps
</div>
</div>
<div class="callout-body-container callout-body">
<ol type="1">
<li><strong>Dataset Preparation</strong>: Paired images with corresponding control inputs</li>
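<li><strong>Architecture Setup</strong>: Create a trainable copy of the base U-Net’s encoder and middle blocks, while the original weights stay locked</li>
<li><strong>Zero Convolution Initialization</strong>: Initialize the 1×1 convolution layers that inject control features with zero weights and biases</li>
<li><strong>Training</strong>: Update only the trainable copy and the zero convolutions, so the locked base model retains its knowledge while the control signal’s influence grows from zero</li>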
<li><strong>Validation</strong>: Test on diverse control inputs and prompts</li>
</ol>
</div>
</div>
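<p>The zero-convolution trick makes the trainable branch start as an exact no-op while still receiving non-zero gradients, so training can move it away from zero. The minimal NumPy sketch below demonstrates both properties for a 1×1 convolution. <code>ZeroConv1x1</code> is a hypothetical class for illustration, not a framework layer.</p>

```python
import numpy as np


class ZeroConv1x1:
    """Minimal 1x1 convolution with weights and bias initialized to zero."""

    def __init__(self, channels: int):
        self.weight = np.zeros((channels, channels))  # (out_channels, in_channels)
        self.bias = np.zeros(channels)

    def __call__(self, x: np.ndarray) -> np.ndarray:
        # x has shape (C, H, W); a 1x1 conv is a per-pixel channel mixing.
        return np.einsum("oc,chw->ohw", self.weight, x) + self.bias[:, None, None]

    def weight_grad(self, x: np.ndarray, upstream: np.ndarray) -> np.ndarray:
        # dLoss/dW[o, c] = sum over pixels of upstream[o] * x[c]
        return np.einsum("ohw,chw->oc", upstream, x)


rng = np.random.default_rng(0)
base_features = rng.standard_normal((4, 8, 8))     # locked U-Net features
control_features = rng.standard_normal((4, 8, 8))  # trainable-copy features
zero_conv = ZeroConv1x1(channels=4)

# At initialization the control branch adds exactly zero...
assert np.array_equal(base_features + zero_conv(control_features), base_features)
# ...but the weight gradient is non-zero, so training can update it.
grad = zero_conv.weight_grad(control_features, upstream=np.ones((4, 8, 8)))
print(np.abs(grad).sum() > 0)
```
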
<p><strong>Custom ControlNet Training:</strong> Organizations can train custom ControlNets for specific use cases:</p>
<ul>
<li>Industry-specific control types</li>
<li>Style-specific guidance</li>
<li>Domain-adapted models</li>
</ul>
</section>
<section id="integration-with-existing-pipelines" class="level3">
<h3 class="anchored" data-anchor-id="integration-with-existing-pipelines" id="integration-with-existing-pipelines">Integration with Existing Pipelines</h3>
<p>ControlNet integrates with various AI art platforms and tools:</p>
<p><strong>Popular Integrations:</strong></p>
<ul>
<li><strong>Automatic1111 WebUI</strong>: Comprehensive ControlNet extension</li>
<li><strong>ComfyUI</strong>: Node-based workflow integration</li>
<li><strong>InvokeAI</strong>: Professional-grade implementation</li>
<li><strong>Diffusers Library</strong>: Python API integration</li>
<li><strong>Krita Plugin</strong>: Direct integration with digital painting software</li>
</ul>
</section>
<section id="hardware-and-performance-considerations" class="level3">
<h3 class="anchored" data-anchor-id="hardware-and-performance-considerations" id="hardware-and-performance-considerations">Hardware and Performance Considerations</h3>
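<p>ControlNet requires more computational resources than standard diffusion model inference. Each attached ControlNet runs its own copy of the U-Net encoder at every denoising step, which adds to both memory use and generation time.</p>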
<div class="tabset-margin-container"></div><div class="panel-tabset">
<ul class="nav nav-tabs" role="tablist"><li class="nav-item" role="presentation"><a class="nav-link active" id="tabset-2-1-tab" data-bs-toggle="tab" data-bs-target="#tabset-2-1" role="tab" aria-controls="tabset-2-1" aria-selected="true" href="">System Requirements</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-2-2-tab" data-bs-toggle="tab" data-bs-target="#tabset-2-2" role="tab" aria-controls="tabset-2-2" aria-selected="false" href="">Optimization Techniques</a></li></ul>
<div class="tab-content">
<div id="tabset-2-1" class="tab-pane active" role="tabpanel" aria-labelledby="tabset-2-1-tab">
<ul>
<li><strong>VRAM</strong>: 6-8GB minimum, 12GB+ recommended for multi-ControlNet</li>
<li><strong>Processing Power</strong>: Modern GPU with CUDA support</li>
<li><strong>Storage</strong>: Additional space for ControlNet model files (1.5-5GB each)</li>
</ul>
</div>
<div id="tabset-2-2" class="tab-pane" role="tabpanel" aria-labelledby="tabset-2-2-tab">
<ul>
<li>Model quantization for reduced VRAM usage</li>
<li>Attention slicing for memory efficiency</li>
<li>Batch processing for multiple generations</li>
<li>CPU offloading of idle models to reduce peak VRAM</li>
</ul>
</div>
</div>
</div>
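<p>A quick way to reason about the VRAM figures is to multiply parameter count by bytes per parameter. This counts weights only. Activations, the VAE and the text encoder add more on top. The parameter counts below are rough, hypothetical figures for illustration, so check the actual checkpoints you use.</p>

```python
BYTES_PER_PARAM = {"fp32": 4, "fp16": 2, "bf16": 2, "int8": 1}


def weight_memory_gb(num_params: float, dtype: str = "fp16") -> float:
    """Memory for model weights only (decimal GB); activations are not included."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


# Hypothetical parameter counts for illustration; check your actual checkpoints.
models = {"base U-Net": 860e6, "ControlNet A": 360e6, "ControlNet B": 360e6}
total = sum(weight_memory_gb(n, "fp16") for n in models.values())
print(f"{total:.2f} GB of fp16 weights")
```
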
</section>
</section>
<section id="best-practices-and-tips" class="level2">
<h2 class="anchored" data-anchor-id="best-practices-and-tips" id="best-practices-and-tips">Best Practices and Tips</h2>
<section id="control-weight-optimization" class="level3">
<h3 class="anchored" data-anchor-id="control-weight-optimization" id="control-weight-optimization">Control Weight Optimization</h3>
<p>Finding the right balance between control strength and creative freedom is crucial for professional results.</p>
<div id="tbl-weights" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-tbl figure">
<figcaption class="quarto-float-caption-top quarto-float-caption quarto-float-tbl" id="tbl-weights-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Table&nbsp;2: Control Weight Guidelines
</figcaption>
<div aria-describedby="tbl-weights-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<table class="caption-top table">
<colgroup>
<col style="width: 39%">
<col style="width: 30%">
<col style="width: 30%">
</colgroup>
<thead>
<tr class="header">
<th>Control Strength</th>
<th>Weight Range</th>
<th>Description</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><strong>High Control</strong></td>
<td>1.0-1.5</td>
<td>Precise reproduction, minimal deviation</td>
</tr>
<tr class="even">
<td><strong>Medium Control</strong></td>
<td>0.7-1.0</td>
<td>Good balance of control and creativity</td>
</tr>
<tr class="odd">
<td><strong>Low Control</strong></td>
<td>0.3-0.7</td>
<td>Loose guidance, high creativity</td>
</tr>
<tr class="even">
<td><strong>Subtle Control</strong></td>
<td>0.1-0.3</td>
<td>Gentle influence, maximum flexibility</td>
</tr>
</tbody>
</table>
</div>
</figure>
</div>
</section>
<section id="prompt-engineering-with-controlnet" class="level3">
<h3 class="anchored" data-anchor-id="prompt-engineering-with-controlnet" id="prompt-engineering-with-controlnet">Prompt Engineering with ControlNet</h3>
<p>Effective prompting becomes even more important when using ControlNet, as the prompt must work harmoniously with the control input.</p>
<p><strong>Strategies:</strong></p>
<ul>
<li><strong>Descriptive Consistency</strong>: Ensure prompts match control input content</li>
<li><strong>Style Specification</strong>: Clear artistic direction (photorealistic, artistic, etc.)</li>
<li><strong>Negative Prompting</strong>: Exclude unwanted elements that might conflict with control</li>
<li><strong>Weight Balancing</strong>: Balance prompt influence with control influence</li>
</ul>
</section>
<section id="quality-control-and-iteration" class="level3">
<h3 class="anchored" data-anchor-id="quality-control-and-iteration" id="quality-control-and-iteration">Quality Control and Iteration</h3>
<p>Professional workflows require consistent quality and the ability to iterate effectively.</p>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Quality Assurance Checklist
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li>Multiple generation passes with slight variations</li>
<li>A/B testing different control strengths</li>
<li>Systematic prompt variations</li>
<li>Post-processing integration planning</li>
</ul>
</div>
</div>
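<p>A/B testing control strengths is easier when every render in the comparison is planned up front with fixed seeds, so that only one factor changes between images. The sketch below builds such a grid. <code>sweep_plan</code> and the prompts are hypothetical, and each entry would be passed to whatever generation tool you use.</p>

```python
from itertools import product


def sweep_plan(prompts, weights, seeds):
    """Enumerate every (prompt, control weight, seed) combination to render."""
    return [
        {"prompt": p, "control_weight": w, "seed": s}
        for p, w, s in product(prompts, weights, seeds)
    ]


plan = sweep_plan(
    prompts=["modern kitchen, morning light", "modern kitchen, dusk"],
    weights=[0.5, 0.8, 1.1],   # one value from the low, medium and high bands
    seeds=[1234, 5678],        # fixed seeds so comparisons are reproducible
)
print(len(plan))  # 12
```
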
</section>
</section>
<section id="limitations-and-considerations" class="level2">
<h2 class="anchored" data-anchor-id="limitations-and-considerations" id="limitations-and-considerations">Limitations and Considerations</h2>
<section id="technical-limitations" class="level3">
<h3 class="anchored" data-anchor-id="technical-limitations" id="technical-limitations">Technical Limitations</h3>
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Warning</span>Key Limitations
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li><strong>Control Precision</strong>: Cannot guarantee pixel-perfect reproduction of control inputs</li>
<li><strong>Model Compatibility</strong>: Each ControlNet is trained for a specific base model family (for example, SD 1.5 ControlNets do not work with SDXL)</li>
<li><strong>Computational Overhead</strong>: Resource-intensive multi-ControlNet workflows</li>
</ul>
</div>
</div>
</section>
<section id="creative-limitations" class="level3">
<h3 class="anchored" data-anchor-id="creative-limitations" id="creative-limitations">Creative Limitations</h3>
<ul>
<li><strong>Over-reliance on Control</strong>: Excessive control can limit AI’s creative potential</li>
<li><strong>Control Conflicts</strong>: Multiple control inputs may conflict with each other</li>
<li><strong>Learning Curve</strong>: Requires understanding of preprocessing techniques and parameter tuning</li>
</ul>
</section>
</section>
<section id="future-developments-and-trends" class="level2">
<h2 class="anchored" data-anchor-id="future-developments-and-trends" id="future-developments-and-trends">Future Developments and Trends</h2>
<section id="emerging-control-types" class="level3">
<h3 class="anchored" data-anchor-id="emerging-control-types" id="emerging-control-types">Emerging Control Types</h3>
<p>Research continues to expand ControlNet capabilities with new control modalities:</p>
<ul>
<li><strong>Audio-to-Visual Control</strong>: Synchronizing image generation with audio inputs</li>
<li><strong>Temporal Control</strong>: Video generation with frame-to-frame consistency</li>
<li><strong>3D Scene Control</strong>: Full 3D scene understanding and control</li>
<li><strong>Style Transfer Control</strong>: Precise artistic style application</li>
</ul>
</section>
<section id="integration-advancements" class="level3">
<h3 class="anchored" data-anchor-id="integration-advancements" id="integration-advancements">Integration Advancements</h3>
<ul>
<li><strong>Real-time Processing</strong>: Optimization for real-time creative workflows</li>
<li><strong>VR/AR Integration</strong>: Spatial computing applications</li>
<li><strong>Cloud-based Solutions</strong>: Accessible high-performance processing</li>
<li><strong>Mobile Optimization</strong>: Smartphone and tablet compatibility</li>
</ul>
</section>
<section id="professional-adoption" class="level3">
<h3 class="anchored" data-anchor-id="professional-adoption" id="professional-adoption">Professional Adoption</h3>
<p>Industries are increasingly integrating ControlNet into professional pipelines:</p>
<ul>
<li><strong>Architecture and Construction</strong>: Automated rendering from technical drawings</li>
<li><strong>Entertainment Industry</strong>: Rapid concept art and pre-visualization</li>
<li><strong>Marketing and Advertising</strong>: Dynamic content creation</li>
<li><strong>Education and Training</strong>: Visual learning material generation</li>
</ul>
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>ControlNet represents a paradigm shift in AI image generation, transforming it from a creative experiment to a professional tool capable of precise, predictable outputs. Its ability to bridge the gap between human creative intent and AI capability has opened new possibilities across industries, from entertainment and architecture to fashion and marketing.</p>
<p>The technology’s modular design, allowing multiple control types to work in concert, provides unprecedented flexibility for creative professionals. As the ecosystem continues to evolve with new control modalities, better integration tools, and improved performance optimization, ControlNet is positioned to become an indispensable part of the modern creative workflow.</p>
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Important</span>Key Takeaway
</div>
</div>
<div class="callout-body-container callout-body">
<p>Success with ControlNet requires understanding both its technical capabilities and creative possibilities. By mastering the balance between control and creativity, understanding the strengths and limitations of different control types, and developing efficient workflows, users can harness ControlNet’s full potential to create compelling, professionally viable AI-generated imagery.</p>
</div>
</div>
<p>The future of AI-assisted creativity lies not in replacing human creativity but in augmenting it with precise, controllable tools like ControlNet. As these technologies continue to mature, they promise to democratize high-quality visual content creation while empowering professionals to achieve new levels of creative expression and productivity.</p>



</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Stable Diffusion: A Complete Guide to Text-to-Image Generation]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/generative-ai/stable-diffusion/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/generative-ai/stable-diffusion/</guid>
      <pubDate>Tue, 22 Jul 2025 00:00:00 GMT</pubDate>
      
      <category>research</category>
      <category>intermediate</category>
      <content:encoded><![CDATA[






<section id="stable-diffusion-a-complete-guide-to-text-to-image-generation" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/generative-ai/stable-diffusion/sd.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>Stable Diffusion represents a watershed moment in artificial intelligence and creative technology. Released in August 2022 by Stability AI in collaboration with the CompVis Group at Ludwig Maximilian University of Munich and Runway, this open-source text-to-image model democratized AI-powered image generation in unprecedented ways. Unlike its predecessors that required massive computational resources and were locked behind proprietary APIs, Stable Diffusion can run on consumer hardware while producing remarkable results.</p>
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Important</span>Key Innovation
</div>
</div>
<div class="callout-body-container callout-body">
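<p>Running the diffusion process in a compressed latent space rather than on full-resolution pixels is what allows Stable Diffusion to run on consumer hardware while still producing high-quality results. This marked a significant departure from earlier text-to-image models that required massive computational resources.</p>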
</div>
</div>
<p>The model’s impact extends far beyond technical achievements. It has sparked conversations about creativity, copyright, artistic authenticity, and the future of visual media. From independent artists experimenting with new forms of expression to major studios integrating AI into production pipelines, Stable Diffusion has become a foundational technology in the rapidly evolving landscape of generative AI.</p>
</section>
<section id="technical-foundation" class="level2">
<h2 class="anchored" data-anchor-id="technical-foundation" id="technical-foundation">Technical Foundation</h2>
<section id="sec-diffusion-process" class="level3">
<h3 class="anchored" data-anchor-id="sec-diffusion-process" id="sec-diffusion-process">The Diffusion Process</h3>
<p>At its core, Stable Diffusion employs a diffusion model architecture, a class of generative models that learns to reverse a gradual noising process. The fundamental concept involves two phases: a forward process that systematically adds noise to clean images until they become pure noise, and a reverse process that learns to denoise these images step by step.</p>
<p>The forward process follows a Markov chain where, at each timestep, Gaussian noise is added to the image according to a predefined noise schedule. This process is fixed rather than learned, since it has no trainable parameters, although it is stochastic. Each step can be expressed mathematically as:</p>
<p><span id="eq-forward-process"><span class="math display">\[q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1-\beta_t} x_{t-1}, \beta_t I) \tag{1}\]</span></span></p>
<p>Where <span class="math inline">\(\beta_t\)</span> represents the noise schedule, controlling how much noise is added at each step. The brilliance of diffusion models lies in the reverse process, where a neural network learns to predict and remove the noise that was added at each step.</p>
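<p>Because each step in Eq. 1 only rescales the image and adds Gaussian noise, the steps compose into a closed form, <span class="math inline">\(q(x_t \mid x_0) = \mathcal{N}(x_t; \sqrt{\bar\alpha_t}\, x_0, (1-\bar\alpha_t) I)\)</span> with <span class="math inline">\(\bar\alpha_t = \prod_{s=1}^{t}(1-\beta_s)\)</span>. This lets training jump straight to any timestep. The NumPy sketch below implements that closed form. The "scaled linear" schedule, with β running from 0.00085 to 0.012 over 1,000 steps, is the one we understand Stable Diffusion v1 to use, so treat those values as assumptions. Timesteps in the code are 0-based.</p>

```python
import numpy as np


def scaled_linear_betas(T: int = 1000, beta_start: float = 0.00085, beta_end: float = 0.012) -> np.ndarray:
    """Betas spaced linearly in sqrt-space, then squared (assumed SD v1 schedule)."""
    return np.linspace(beta_start ** 0.5, beta_end ** 0.5, T) ** 2


def add_noise(x0: np.ndarray, t: int, noise: np.ndarray, betas: np.ndarray) -> np.ndarray:
    """Sample x_t ~ q(x_t | x_0) in one shot (t is a 0-based timestep index)."""
    alpha_bar = np.cumprod(1.0 - betas)[t]
    return np.sqrt(alpha_bar) * x0 + np.sqrt(1.0 - alpha_bar) * noise


betas = scaled_linear_betas()
rng = np.random.default_rng(0)
x0 = rng.standard_normal((4, 64, 64))          # stand-in for a clean latent
x_t = add_noise(x0, t=500, noise=rng.standard_normal(x0.shape), betas=betas)

# Sanity check: applying Eq. 1 step by step with the noise switched off
# scales x0 by the same sqrt(alpha_bar) as the closed form.
x = x0.copy()
for beta in betas[:11]:
    x = np.sqrt(1.0 - beta) * x
assert np.allclose(x, add_noise(x0, t=10, noise=np.zeros_like(x0), betas=betas))
```
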
</section>
<section id="latent-space-innovation" class="level3">
<h3 class="anchored" data-anchor-id="latent-space-innovation" id="latent-space-innovation">Latent Space Innovation</h3>
<p>What sets Stable Diffusion apart from earlier diffusion models like DALL-E 2 is its operation in latent space rather than pixel space. Stable Diffusion is an implementation of the Latent Diffusion Model (LDM) architecture developed by the CompVis group, and this design provides several crucial advantages:</p>
<div class="tabset-margin-container"></div><div class="panel-tabset">
<ul class="nav nav-tabs" role="tablist"><li class="nav-item" role="presentation"><a class="nav-link active" id="tabset-1-1-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-1" role="tab" aria-controls="tabset-1-1" aria-selected="true" href="">Computational Efficiency</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-1-2-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-2" role="tab" aria-controls="tabset-1-2" aria-selected="false" href="">Semantic Coherence</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-1-3-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-3" role="tab" aria-controls="tabset-1-3" aria-selected="false" href="">Training Stability</a></li></ul>
<div class="tab-content">
<div id="tabset-1-1" class="tab-pane active" role="tabpanel" aria-labelledby="tabset-1-1-tab">
<p>By working in a compressed latent representation, the model processes far fewer values than pixel-space diffusion. A 512×512×3 image (786,432 values) becomes a 64×64×4 latent (16,384 values), 48 times fewer. This compression is achieved through a Variational Autoencoder (VAE) that downsamples each spatial dimension by a factor of 8 and maps images to and from the latent space.</p>
</div>
<div id="tabset-1-2" class="tab-pane" role="tabpanel" aria-labelledby="tabset-1-2-tab">
<p>The latent space captures high-level semantic features while abstracting away pixel-level details. This allows the model to focus on meaningful image composition rather than getting caught up in low-level noise patterns.</p>
</div>
<div id="tabset-1-3" class="tab-pane" role="tabpanel" aria-labelledby="tabset-1-3-tab">
<p>The reduced dimensionality and semantic organization of latent space leads to more stable training dynamics and better convergence properties.</p>
</div>
</div>
</div>
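<p>The compression figures above are simple arithmetic, sketched below under the assumption of an 8× spatial downsampling, 4-channel VAE as in Stable Diffusion v1. The ratio counts tensor elements. The actual compute savings depend on the U-Net architecture and do not follow this number exactly. <code>latent_shape</code> is a hypothetical helper.</p>

```python
def latent_shape(height: int, width: int, downsample: int = 8, channels: int = 4):
    """Latent (C, H, W) for an image, assuming an 8x-downsampling, 4-channel VAE."""
    return (channels, height // downsample, width // downsample)


pixel_values = 3 * 512 * 512
c, h, w = latent_shape(512, 512)
latent_values = c * h * w
print((c, h, w), pixel_values, latent_values, pixel_values / latent_values)
```
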
</section>
<section id="model-architecture-components" class="level3">
<h3 class="anchored" data-anchor-id="model-architecture-components" id="model-architecture-components">Model Architecture Components</h3>
<p>Stable Diffusion consists of three main components working in harmony:</p>
<p><strong>Text Encoder</strong>: The model uses CLIP’s text encoder (CLIP ViT-L/14 in Stable Diffusion v1) to transform textual prompts into high-dimensional embeddings. These embeddings capture semantic relationships between words and concepts, enabling the model to follow complex prompt instructions. The encoder has a fixed context of 77 tokens, including the start and end tokens, so longer prompts are truncated.</p>
<p><strong>U-Net Denoising Network</strong>: The heart of the diffusion process is a U-Net architecture that predicts noise to be removed at each denoising step. This network incorporates cross-attention mechanisms to condition the denoising process on the text embeddings, allowing for precise control over image generation based on textual descriptions.</p>
<p><strong>Variational Autoencoder (VAE)</strong>: The VAE handles the conversion between pixel space and latent space. The encoder compresses 512×512 pixel images into 64×64×4 latent representations (four channels at one-eighth of the spatial resolution), while the decoder reconstructs high-resolution images from these compressed representations.</p>
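<p>The cross-attention that links the U-Net to the text encoder can be sketched in a few lines. Each spatial position of the latent feature map forms a query, and the 77 text-token embeddings supply the keys and values. The shapes below are assumptions chosen to resemble Stable Diffusion v1: a 64×64 feature map with 320 channels and 77×768 text embeddings. The head dimension of 64 and the random projections are hypothetical, and real implementations use multiple heads.</p>

```python
import numpy as np


def cross_attention(latent_tokens, text_emb, w_q, w_k, w_v):
    """Single-head cross-attention: image tokens query the prompt embeddings."""
    q = latent_tokens @ w_q                          # (num_pixels, d)
    k = text_emb @ w_k                               # (num_text_tokens, d)
    v = text_emb @ w_v                               # (num_text_tokens, d)
    scores = q @ k.T / np.sqrt(q.shape[-1])          # (num_pixels, num_text_tokens)
    scores -= scores.max(axis=-1, keepdims=True)     # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)   # softmax over text tokens
    return weights @ v, weights


rng = np.random.default_rng(0)
latent_tokens = rng.standard_normal((64 * 64, 320))  # flattened 64x64 feature map
text_emb = rng.standard_normal((77, 768))            # stand-in CLIP text embeddings
d = 64                                               # hypothetical head dimension
w_q = rng.standard_normal((320, d)) / np.sqrt(320)
w_k = rng.standard_normal((768, d)) / np.sqrt(768)
w_v = rng.standard_normal((768, d)) / np.sqrt(768)

out, attn = cross_attention(latent_tokens, text_emb, w_q, w_k, w_v)
print(out.shape, attn.shape)
```
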
</section>
</section>
<section id="training-and-data" class="level2">
<h2 class="anchored" data-anchor-id="training-and-data" id="training-and-data">Training and Data</h2>
<section id="dataset-composition" class="level3">
<h3 class="anchored" data-anchor-id="dataset-composition" id="dataset-composition">Dataset Composition</h3>
<p>Stable Diffusion was trained on subsets of LAION-5B, a massive dataset of 5.85 billion image-text pairs scraped from the internet. Initial training used laion2B-en, the roughly 2.3 billion pairs with English captions. Later stages fine-tuned on smaller subsets filtered for resolution and aesthetic quality. This enormous scale allows the model to learn diverse visual concepts, artistic styles, and the relationships between textual descriptions and visual content.</p>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Dataset Scale
</div>
</div>
<div class="callout-body-container callout-body">
<p>The training dataset of 2.3 billion images from LAION-5B represents one of the largest collections of image-text pairs used for training generative models at the time.</p>
</div>
</div>
<p>The dataset’s diversity is both a strength and a source of ongoing discussion. It includes artwork, photographs, diagrams, memes, and virtually every category of visual content found online. This comprehensive coverage enables the model’s remarkable versatility but also raises questions about copyright, consent, and the ethics of training on web-scraped content.</p>
</section>
<section id="training-process" class="level3">
<h3 class="anchored" data-anchor-id="training-process" id="training-process">Training Process</h3>
<p>The training process involves several stages and techniques designed to produce a robust and capable model:</p>
<p><strong>Noise Scheduling</strong>: The model learns to denoise images across different noise levels, from heavily corrupted images to nearly clean ones. This teaches the network to handle various levels of corruption and enables the flexible sampling procedures used during inference.</p>
<p><strong>Classifier-Free Guidance</strong>: During training, the model learns to generate images both with and without text conditioning. This technique, known as classifier-free guidance, allows for better control over how closely the generated image follows the text prompt during inference.</p>
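<p>The sketch below shows both halves of classifier-free guidance in NumPy. At training time, the prompt is occasionally replaced with an empty string. The 10% rate follows the Stable Diffusion v1 model card as we understand it, so treat it as an assumption. At inference, the unconditional and text-conditioned noise predictions are extrapolated by a guidance scale. The 7.5 default matches the Diffusers pipeline default. In practice both predictions come from a single batched U-Net call. The helper names are hypothetical.</p>

```python
import numpy as np


def classifier_free_guidance(eps_uncond, eps_cond, guidance_scale=7.5):
    """Extrapolate from the unconditional prediction toward the text-conditioned one."""
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)


def maybe_drop_prompt(prompt, rng, drop_prob=0.1):
    """Training-time trick: replace the prompt with the empty string some of the time."""
    return "" if rng.random() < drop_prob else prompt


eps_uncond = np.zeros((4, 64, 64))      # stand-in U-Net output without the prompt
eps_cond = np.full((4, 64, 64), 0.1)    # stand-in U-Net output with the prompt
guided = classifier_free_guidance(eps_uncond, eps_cond, guidance_scale=7.5)
print(guided[0, 0, 0])
```

A scale of 1 reproduces the plain conditional prediction, and larger values push the sample harder toward the prompt.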
<p><strong>Progressive Training</strong>: Stable Diffusion v1 was first trained at 256×256 resolution and then continued at the full 512×512 resolution. This approach improves training efficiency and helps the model learn both coarse and fine-grained features.</p>
</section>
</section>
<section id="inference-and-generation-process" class="level2">
<h2 class="anchored" data-anchor-id="inference-and-generation-process" id="inference-and-generation-process">Inference and Generation Process</h2>
<section id="the-sampling-pipeline" class="level3">
<h3 class="anchored" data-anchor-id="the-sampling-pipeline" id="the-sampling-pipeline">The Sampling Pipeline</h3>
<p>Image generation in Stable Diffusion follows a carefully orchestrated sampling pipeline that transforms random noise into coherent images:</p>
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">flowchart TD
    A[Random Noise] --&gt; C[Iterative Denoising]
    B[Text Encoding] --&gt; C
    C --&gt; D[VAE Decoding]
    D --&gt; E[Final Image]
    
    B --&gt; F[CLIP Text Encoder]
    C --&gt; G[U-Net Denoising]
    D --&gt; H[VAE Decoder]
</pre>
</div>
<p></p></figure><p></p>
</div>
</div>
</div>
<ol type="1">
<li><p><strong>Initialization</strong>: The process begins with pure random noise in the latent space, typically sampled from a standard Gaussian distribution.</p></li>
<li><p><strong>Text Processing</strong>: The input prompt is tokenized and encoded using the CLIP text encoder, producing conditioning embeddings that guide the generation process.</p></li>
<li><p><strong>Iterative Denoising</strong>: Over multiple timesteps (typically 20-50), the U-Net predicts and removes noise from the latent representation. Each step brings the latent closer to representing a coherent image that matches the text prompt.</p></li>
<li><p><strong>Decoding</strong>: The denoised latent is passed through the VAE decoder, which upsamples it by a factor of 8 along each side (a 4×64×64 latent becomes a 512×512 RGB image).</p></li>
</ol>
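<p>The four steps can be sketched as a NumPy loop. This is an illustration under stated assumptions, not the diffusers implementation: it uses a deterministic DDIM update (one of several possible samplers), a scaled-linear noise schedule like Stable Diffusion v1's, and a hypothetical <code>predict_noise(x, t)</code> callable standing in for the text-conditioned U-Net. VAE decoding is left out because it requires the trained decoder.</p>

```python
import numpy as np

def ddim_sample(predict_noise, alpha_bars, shape, num_steps=50, seed=0):
    """Deterministic DDIM sampling (eta = 0) in latent space."""
    rng = np.random.default_rng(seed)
    T = len(alpha_bars)
    # Evenly spaced subset of the training timesteps, noisiest first
    timesteps = np.linspace(T - 1, 0, num_steps).round().astype(int)
    x = rng.standard_normal(shape)  # 1. initialization: pure Gaussian noise
    next_steps = list(timesteps[1:]) + [None]
    for t, t_next in zip(timesteps, next_steps):
        a_t = alpha_bars[t]
        a_next = 1.0 if t_next is None else alpha_bars[t_next]  # 1.0 means fully clean
        eps = predict_noise(x, t)  # 3. U-Net noise prediction (conditioned on the prompt)
        x0_hat = (x - np.sqrt(1.0 - a_t) * eps) / np.sqrt(a_t)  # current estimate of the clean latent
        x = np.sqrt(a_next) * x0_hat + np.sqrt(1.0 - a_next) * eps  # move to the next noise level
    return x  # 4. this latent would then be passed to the VAE decoder

betas = np.linspace(0.00085 ** 0.5, 0.012 ** 0.5, 1000) ** 2
alpha_bars = np.cumprod(1.0 - betas)
# The lambda is a placeholder that only exercises the loop; it is not a trained model
latent = ddim_sample(lambda x, t: 0.1 * x, alpha_bars, shape=(4, 64, 64), num_steps=25)
```

With a real U-Net in place of the lambda, the returned latent would be ready for VAE decoding. Because no fresh noise is drawn inside the loop, the result depends only on the seed, the prompt, and the settings.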
</section>
<section id="sampling-algorithms" class="level3">
<h3 class="anchored" data-anchor-id="sampling-algorithms" id="sampling-algorithms">Sampling Algorithms</h3>
<p>Various sampling algorithms can be employed during inference, each with different trade-offs between speed and quality:</p>
<div id="tbl-samplers" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-tbl figure">
<figcaption class="quarto-float-caption-top quarto-float-caption quarto-float-tbl" id="tbl-samplers-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Table&nbsp;1: Comparison of sampling algorithms
</figcaption>
<div aria-describedby="tbl-samplers-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<table class="caption-top table">
<thead>
<tr class="header">
<th>Algorithm</th>
<th>Speed</th>
<th>Quality</th>
<th>Deterministic (no per-step noise)</th>
<th>Best Use Case</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>DDPM</td>
<td>Slow</td>
<td>High</td>
<td>No</td>
<td>High-quality generation</td>
</tr>
<tr class="even">
<td>DDIM</td>
<td>Fast</td>
<td>High</td>
<td>Yes</td>
<td>Reproducible results</td>
</tr>
<tr class="odd">
<td>Euler</td>
<td>Medium</td>
<td>Good</td>
<td>Yes</td>
<td>Balanced approach</td>
</tr>
<tr class="even">
<td>DPM++</td>
<td>Fast</td>
<td>High</td>
<td>Yes (2M); No (SDE variants)</td>
<td>Production workflows</td>
</tr>
</tbody>
</table>
</div>
</figure>
</div>
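<p>The "Deterministic" column comes down to whether a sampler injects fresh random noise at each step. The NumPy sketch below contrasts one ancestral DDPM step, which adds new Gaussian noise (using the common variance choice σ_t² = β_t), with one DDIM step at η = 0, which adds none. Both take a noise prediction <code>eps</code> as input; the function names are illustrative rather than a library API, and the U-Net output is replaced by a placeholder.</p>

```python
import numpy as np

def ddpm_step(x_t, eps, t, betas, alpha_bars, rng):
    """Ancestral DDPM step from t to t-1: remove the predicted noise, then add fresh noise."""
    alpha_t = 1.0 - betas[t]
    mean = (x_t - betas[t] / np.sqrt(1.0 - alpha_bars[t]) * eps) / np.sqrt(alpha_t)
    if t == 0:
        return mean  # the final step adds no noise
    return mean + np.sqrt(betas[t]) * rng.standard_normal(x_t.shape)

def ddim_step(x_t, eps, t, t_prev, alpha_bars):
    """DDIM step with eta = 0: no random draw, so the output is fixed by the inputs."""
    a_t, a_prev = alpha_bars[t], alpha_bars[t_prev]
    x0_hat = (x_t - np.sqrt(1.0 - a_t) * eps) / np.sqrt(a_t)
    return np.sqrt(a_prev) * x0_hat + np.sqrt(1.0 - a_prev) * eps

betas = np.linspace(0.00085 ** 0.5, 0.012 ** 0.5, 1000) ** 2
alpha_bars = np.cumprod(1.0 - betas)
x_t = np.random.default_rng(0).standard_normal((4, 64, 64))
eps = 0.5 * x_t  # placeholder for a U-Net noise prediction

ddpm_a = ddpm_step(x_t, eps, 500, betas, alpha_bars, np.random.default_rng(1))
ddpm_b = ddpm_step(x_t, eps, 500, betas, alpha_bars, np.random.default_rng(2))
ddim_a = ddim_step(x_t, eps, 500, 480, alpha_bars)
ddim_b = ddim_step(x_t, eps, 500, 480, alpha_bars)
```

Here <code>ddpm_a</code> and <code>ddpm_b</code> differ because they draw different noise, while <code>ddim_a</code> and <code>ddim_b</code> are identical. With a fixed seed both kinds of sampler are reproducible; the column only describes whether extra randomness enters during the loop.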
</section>
<section id="guidance-and-control" class="level3">
<h3 class="anchored" data-anchor-id="guidance-and-control" id="guidance-and-control">Guidance and Control</h3>
<p><strong>Classifier-Free Guidance (CFG)</strong>: This technique allows users to control how closely the generated image follows the text prompt. Higher CFG values produce images that more strictly adhere to the prompt but may sacrifice diversity and naturalness.</p>
<p><strong>Negative Prompting</strong>: By specifying what should NOT appear in the image, users can steer generation away from unwanted elements or styles.</p>
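<p>At inference time, classifier-free guidance runs the U-Net twice per step and extrapolates from the unconditional prediction toward the prompt-conditioned one: ε̂ = ε_uncond + w·(ε_cond − ε_uncond), where w is the guidance scale. Negative prompting fits into the same formula: the negative prompt's embedding is used for the "unconditional" branch instead of the empty prompt, so guidance pushes the result away from it. A minimal NumPy sketch, with random arrays standing in for the two U-Net outputs; the diffusers pipelines expose w as <code>guidance_scale</code> (7.5 by default in <code>StableDiffusionPipeline</code>).</p>

```python
import numpy as np

def guided_noise(eps_uncond, eps_cond, guidance_scale):
    """Classifier-free guidance: extrapolate from the unconditional (or
    negative-prompt) prediction toward the prompt-conditioned prediction."""
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)

rng = np.random.default_rng(0)
eps_negative = rng.standard_normal((4, 64, 64))  # U-Net output for the negative / empty prompt
eps_prompt = rng.standard_normal((4, 64, 64))    # U-Net output for the user's prompt

for w in (1.0, 7.5, 15.0):
    eps = guided_noise(eps_negative, eps_prompt, w)
    # The distance from the unconditional prediction grows linearly with w
    print(w, float(np.linalg.norm(eps - eps_negative)))
```

A scale of 1 simply reproduces the conditional prediction; larger scales push further along the prompt direction, which is why very high values trade diversity and naturalness for prompt adherence.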
<p><strong>Seed Control</strong>: Fixing the random seed fixes the initial latent noise, so the same prompt and settings reproduce the same image, and small changes to the prompt or settings produce variations on the same basic composition.</p>
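<p>A seed is enough for reproducibility because, with a deterministic sampler, the only randomness is the starting latent. The NumPy illustration below assumes a 4×64×64 latent for a 512×512 image, and the helper name is hypothetical; in diffusers, the same role is played by passing a seeded <code>torch.Generator</code> through the pipeline's <code>generator</code> argument.</p>

```python
import numpy as np

def initial_latents(seed, batch_size=1, shape=(4, 64, 64)):
    """Starting noise for generation; fully determined by the seed."""
    rng = np.random.default_rng(seed)
    return rng.standard_normal((batch_size, *shape))

same_a = initial_latents(seed=42)
same_b = initial_latents(seed=42)  # identical to same_a
other = initial_latents(seed=43)   # a different starting point, so a different composition
```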
</section>
</section>
<section id="advanced-techniques-and-applications" class="level2">
<h2 class="anchored" data-anchor-id="advanced-techniques-and-applications" id="advanced-techniques-and-applications">Advanced Techniques and Applications</h2>
<section id="image-to-image-generation" class="level3">
<h3 class="anchored" data-anchor-id="image-to-image-generation" id="image-to-image-generation">Image-to-Image Generation</h3>
<p>Beyond text-to-image generation, Stable Diffusion supports image-to-image transformation, where an existing image serves as a starting point rather than random noise. This technique enables:</p>
<ul>
<li><strong>Style Transfer</strong>: Transforming images into different artistic styles while preserving content structure</li>
<li><strong>Image Editing</strong>: Making targeted modifications to existing images based on textual descriptions</li>
<li><strong>Variation Generation</strong>: Creating multiple variations of a base image with controlled differences</li>
</ul>
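<p>Image-to-image generation can be sketched as "partially noise the input, then denoise from there." A <code>strength</code> setting (the parameter name used by the diffusers img2img pipeline) decides how far into the noise schedule the encoded input is pushed: low values keep more of the original structure, while values near 1 behave almost like text-to-image. The NumPy sketch below assumes the input has already been encoded to a latent, uses a scaled-linear schedule like Stable Diffusion v1's, and uses a hypothetical helper name; it is a simplified illustration, not the pipeline's exact code.</p>

```python
import numpy as np

def img2img_start(x0_latent, strength, alpha_bars, num_inference_steps, rng):
    """Noise an encoded input image to the level implied by `strength` (0 to 1).

    Returns the noised latent, the timestep it corresponds to, and how many
    denoising steps remain to be run."""
    strength = float(np.clip(strength, 0.0, 1.0))
    steps_to_run = int(num_inference_steps * strength)
    if steps_to_run == 0:
        return x0_latent.copy(), None, 0  # no denoising: output equals the input
    T = len(alpha_bars)
    # Full inference schedule, noisiest first; skip its earliest (noisiest) steps
    timesteps = np.linspace(T - 1, 0, num_inference_steps).round().astype(int)
    t_start = int(timesteps[num_inference_steps - steps_to_run])
    a = alpha_bars[t_start]
    noise = rng.standard_normal(x0_latent.shape)
    x_t = np.sqrt(a) * x0_latent + np.sqrt(1.0 - a) * noise
    return x_t, t_start, steps_to_run

betas = np.linspace(0.00085 ** 0.5, 0.012 ** 0.5, 1000) ** 2
alpha_bars = np.cumprod(1.0 - betas)
rng = np.random.default_rng(0)
photo_latent = rng.standard_normal((4, 64, 64))  # stand-in for a VAE-encoded photo

for strength in (0.3, 0.75, 1.0):
    x_t, t, n = img2img_start(photo_latent, strength, alpha_bars, 50, rng)
    print(f"strength={strength}: start at t={t}, run {n} of 50 steps")
```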
</section>
<section id="inpainting-and-outpainting" class="level3">
<h3 class="anchored" data-anchor-id="inpainting-and-outpainting" id="inpainting-and-outpainting">Inpainting and Outpainting</h3>
<p>Specialized versions of Stable Diffusion can fill in masked regions of images (inpainting) or extend images beyond their original boundaries (outpainting). These capabilities enable sophisticated image editing workflows and creative applications.</p>
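<p>As one concrete example of such a specialized version, the Stable Diffusion inpainting checkpoints widen the U-Net input from 4 to 9 channels: the noisy latent (4), a mask downsampled to latent resolution (1), and the VAE latent of the masked image (4). Outpainting can reuse the same model by enlarging the canvas and marking the new border as the region to fill. The NumPy sketch below only assembles these inputs for a single image (no batch dimension); the helper names are hypothetical, the channel order is an assumption based on the diffusers inpainting pipeline, and the latents are random stand-ins.</p>

```python
import numpy as np

VAE_FACTOR = 8  # Stable Diffusion's VAE downsamples each spatial side by 8

def mask_to_latent(mask):
    """Downsample an H x W binary mask (1 = repaint) to latent resolution.
    A latent cell is marked if any pixel in its 8x8 block is marked."""
    h, w = mask.shape
    blocks = mask.reshape(h // VAE_FACTOR, VAE_FACTOR, w // VAE_FACTOR, VAE_FACTOR)
    return blocks.max(axis=(1, 3))

def inpaint_unet_input(noisy_latent, mask, masked_image_latent):
    """Stack the 9 input channels: 4 noisy latent + 1 mask + 4 masked-image latent."""
    m = mask_to_latent(mask)[None, :, :]  # shape (1, h, w)
    return np.concatenate([noisy_latent, m, masked_image_latent], axis=0)

def outpaint_canvas(image, pad):
    """Pad an H x W x 3 image on every side and build the mask for the new border."""
    canvas = np.pad(image, ((pad, pad), (pad, pad), (0, 0)))
    mask = np.ones(canvas.shape[:2])
    mask[pad:pad + image.shape[0], pad:pad + image.shape[1]] = 0.0  # keep the original pixels
    return canvas, mask

rng = np.random.default_rng(0)
image = rng.random((384, 384, 3))
canvas, mask = outpaint_canvas(image, pad=64)  # 512 x 512 canvas, border marked for filling
unet_in = inpaint_unet_input(
    rng.standard_normal((4, 64, 64)),  # noisy latent
    mask,
    rng.standard_normal((4, 64, 64)),  # stand-in for the VAE latent of the masked canvas
)
```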
</section>
<section id="controlnet-and-conditioning" class="level3">
<h3 class="anchored" data-anchor-id="controlnet-and-conditioning" id="controlnet-and-conditioning">ControlNet and Conditioning</h3>
<p>ControlNet represents a significant advancement in controllable generation, allowing users to guide image generation using structural inputs like edge maps, depth maps, pose information, or segmentation masks. This level of control bridges the gap between random generation and precise artistic intent.</p>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>ControlNet Applications
</div>
</div>
<div class="callout-body-container callout-body">
<p>ControlNet enables precise control over composition, pose, and structure while maintaining the creative power of text-to-image generation.</p>
</div>
</div>
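<p>ControlNet can be attached to a pretrained model without damaging it because of its "zero convolutions": a trainable copy of the U-Net encoder feeds its features back through 1×1 convolutions whose weights and biases start at zero, so at the beginning of training the combined model behaves exactly like the original Stable Diffusion. The NumPy sketch below shows this idea for a single feature map. A 1×1 convolution is written as a matrix product over channels, and the <code>scale</code> argument plays the role of diffusers' <code>controlnet_conditioning_scale</code>. Shapes and names are illustrative assumptions.</p>

```python
import numpy as np

class ZeroConv1x1:
    """1x1 convolution whose weights and bias are initialized to zero."""
    def __init__(self, channels):
        self.weight = np.zeros((channels, channels))  # (out_channels, in_channels)
        self.bias = np.zeros(channels)

    def __call__(self, features):  # features: (C, H, W)
        # A 1x1 convolution is a matrix product over the channel axis at every pixel
        out = np.tensordot(self.weight, features, axes=([1], [0]))
        return out + self.bias[:, None, None]

def inject_control(unet_features, control_features, zero_conv, scale=1.0):
    """Add the (scaled) ControlNet residual to a U-Net feature map."""
    return unet_features + scale * zero_conv(control_features)

rng = np.random.default_rng(0)
unet_feat = rng.standard_normal((320, 64, 64))     # a U-Net encoder feature map
control_feat = rng.standard_normal((320, 64, 64))  # matching feature map from the ControlNet copy
zc = ZeroConv1x1(320)

at_init = inject_control(unet_feat, control_feat, zc)  # unchanged: equals unet_feat
zc.weight = 0.01 * rng.standard_normal((320, 320))    # pretend training has updated the weights
after_training = inject_control(unet_feat, control_feat, zc, scale=0.8)
```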
</section>
<section id="fine-tuning-and-customization" class="level3">
<h3 class="anchored" data-anchor-id="fine-tuning-and-customization" id="fine-tuning-and-customization">Fine-tuning and Customization</h3>
<p>The open-source nature of Stable Diffusion has spawned numerous fine-tuning techniques:</p>
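<p><strong>DreamBooth</strong>: Fine-tunes the model to generate a specific subject or style from just a few example images (typically 3-5), binding the subject to a rare identifier token that can then be used in prompts.</p>
<p><strong>Textual Inversion</strong>: Keeps the model frozen and learns a new token embedding that represents a specific concept, style, or object not well represented in the original training data.</p>
<p><strong>LoRA (Low-Rank Adaptation)</strong>: An efficient fine-tuning method that freezes the original weights and trains small low-rank update matrices, typically in the attention layers. It needs far less memory and compute than full fine-tuning, and the resulting files are much smaller than a full checkpoint.</p>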
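<p>The core of LoRA fits in a few lines: the pretrained weight matrix W stays frozen and training learns a low-rank update, giving W + (α/r)·B·A, where A is r×k, B is d×r, and the rank r is much smaller than d and k. Because B starts at zero, the adapted layer initially matches the base layer exactly. The NumPy sketch below assumes a single 768×768 linear layer and rank 4; the class name and initialization scale are illustrative choices.</p>

```python
import numpy as np

class LoRALinear:
    """Frozen linear layer W plus a trainable low-rank update (alpha / r) * B @ A."""
    def __init__(self, weight, rank=4, alpha=4.0, seed=0):
        rng = np.random.default_rng(seed)
        d_out, d_in = weight.shape
        self.weight = weight                                         # frozen, never updated
        self.A = rng.standard_normal((rank, d_in)) / np.sqrt(d_in)  # trainable, random init
        self.B = np.zeros((d_out, rank))                             # trainable, starts at zero
        self.scale = alpha / rank

    def __call__(self, x):  # x: (..., d_in)
        return x @ self.weight.T + self.scale * (x @ self.A.T) @ self.B.T

    def merged_weight(self):
        """Fold the update into a single matrix for inference."""
        return self.weight + self.scale * self.B @ self.A

rng = np.random.default_rng(1)
W = rng.standard_normal((768, 768))
layer = LoRALinear(W, rank=4)
x = rng.standard_normal((2, 768))

full_params = W.size                       # 589,824 values in the frozen matrix
lora_params = layer.A.size + layer.B.size  # 2 * 4 * 768 = 6,144 trainable values
```

Here the trainable update is about 1% of the layer's parameters, which is why LoRA training fits on modest GPUs and the saved adapters are small. After training, <code>merged_weight()</code> folds the update back into W so inference costs nothing extra.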
</section>
</section>
<section id="performance-and-hardware-considerations" class="level2">
<h2 class="anchored" data-anchor-id="performance-and-hardware-considerations" id="performance-and-hardware-considerations">Performance and Hardware Considerations</h2>
<section id="system-requirements" class="level3">
<h3 class="anchored" data-anchor-id="system-requirements" id="system-requirements">System Requirements</h3>
<p>Stable Diffusion’s hardware requirements vary significantly based on the desired generation speed and image quality:</p>
<div class="tabset-margin-container"></div><div class="panel-tabset">
<ul class="nav nav-tabs" role="tablist"><li class="nav-item" role="presentation"><a class="nav-link active" id="tabset-2-1-tab" data-bs-toggle="tab" data-bs-target="#tabset-2-1" role="tab" aria-controls="tabset-2-1" aria-selected="true" href="">Minimum Requirements</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-2-2-tab" data-bs-toggle="tab" data-bs-target="#tabset-2-2" role="tab" aria-controls="tabset-2-2" aria-selected="false" href="">Recommended Specifications</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-2-3-tab" data-bs-toggle="tab" data-bs-target="#tabset-2-3" role="tab" aria-controls="tabset-2-3" aria-selected="false" href="">Optimization Strategies</a></li></ul>
<div class="tab-content">
<div id="tabset-2-1" class="tab-pane active" role="tabpanel" aria-labelledby="tabset-2-1-tab">
<ul>
<li>6GB VRAM (for basic 512×512 generation)</li>
<li>16GB system RAM</li>
<li>Modern CPU (any architecture from the last 5 years)</li>
</ul>
</div>
<div id="tabset-2-2" class="tab-pane" role="tabpanel" aria-labelledby="tabset-2-2-tab">
<ul>
<li>12GB+ VRAM (enables higher resolutions and faster generation)</li>
<li>32GB system RAM (for complex workflows and batch processing)</li>
<li>High-end GPU (RTX 3080/4070 or better)</li>
</ul>
</div>
<div id="tabset-2-3" class="tab-pane" role="tabpanel" aria-labelledby="tabset-2-3-tab">
<ul>
<li>Half-precision (FP16) inference roughly halves weight memory compared with FP32</li>
<li>Attention optimization techniques (xFormers, Flash Attention)</li>
<li>Model quantization for further memory reduction</li>
<li>Tiled VAE for generating images larger than native resolution</li>
</ul>
</div>
</div>
</div>
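<p>A rough back-of-the-envelope check on why lower precision matters: weight memory is simply parameter count × bytes per parameter. The counts below are approximate figures for Stable Diffusion v1 (the project describes an 860M-parameter U-Net and a 123M-parameter text encoder; roughly 84M for the VAE is an additional assumption) and should be treated as estimates. Activations, attention buffers, and framework overhead come on top of these numbers.</p>

```python
BYTES_PER_PARAM = {"fp32": 4, "fp16": 2, "int8": 1}

# Approximate Stable Diffusion v1 component sizes, in parameters (assumed)
COMPONENTS = {"unet": 860e6, "text_encoder": 123e6, "vae": 84e6}

def weight_memory_gb(dtype):
    """Memory for the model weights alone, in gigabytes (1 GB = 1e9 bytes)."""
    total_params = sum(COMPONENTS.values())
    return total_params * BYTES_PER_PARAM[dtype] / 1e9

for dtype in ("fp32", "fp16", "int8"):
    print(f"{dtype}: {weight_memory_gb(dtype):.2f} GB")
# prints fp32: 4.27 GB, fp16: 2.13 GB, int8: 1.07 GB (one per line)
```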
</section>
<section id="cloud-and-edge-deployment" class="level3">
<h3 class="anchored" data-anchor-id="cloud-and-edge-deployment" id="cloud-and-edge-deployment">Cloud and Edge Deployment</h3>
<p>The model’s relatively modest requirements have enabled deployment across various platforms:</p>
<p><strong>Cloud Platforms</strong>: Services like RunPod, Vast.ai, and Google Colab provide accessible cloud-based generation.</p>
<p><strong>Edge Deployment</strong>: Optimized versions can run on mobile devices and embedded systems, though with reduced capability.</p>
<p><strong>Web Interfaces</strong>: Numerous web-based interfaces democratize access without requiring technical setup.</p>
</section>
</section>
<section id="sec-ethics" class="level2">
<h2 class="anchored" data-anchor-id="sec-ethics" id="sec-ethics">Ethical Considerations and Societal Impact</h2>
<section id="copyright-and-intellectual-property" class="level3">
<h3 class="anchored" data-anchor-id="copyright-and-intellectual-property" id="copyright-and-intellectual-property">Copyright and Intellectual Property</h3>
<p>Stable Diffusion’s training on web-scraped imagery has sparked significant debate about copyright, fair use, and intellectual property rights. Key concerns include:</p>
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Warning</span>Copyright Concerns
</div>
</div>
<div class="callout-body-container callout-body">
<p>The use of copyrighted material in training data without explicit consent raises ongoing legal and ethical questions about fair use and artist rights.</p>
</div>
</div>
<p><strong>Artist Rights</strong>: Many artists’ works were included in training data without explicit consent, raising questions about compensation and attribution.</p>
<p><strong>Style Mimicry</strong>: The model’s ability to generate images “in the style of” specific artists has led to discussions about artistic authenticity and economic impact.</p>
<p><strong>Commercial Use</strong>: The boundaries between transformative use and copyright infringement remain legally unclear in many jurisdictions.</p>
</section>
<section id="bias-and-representation" class="level3">
<h3 class="anchored" data-anchor-id="bias-and-representation" id="bias-and-representation">Bias and Representation</h3>
<p>Like many AI systems trained on internet data, Stable Diffusion exhibits various biases:</p>
<ul>
<li><strong>Demographic Bias</strong>: Default representations often skew toward certain demographics, reflecting biases present in the training data</li>
<li><strong>Cultural Bias</strong>: The model’s understanding of concepts can be influenced by Western-centric perspectives prevalent in English-language internet content</li>
<li><strong>Historical Bias</strong>: Temporal biases in training data can lead to outdated or stereotypical representations</li>
</ul>
</section>
<section id="misuse-and-safety-concerns" class="level3">
<h3 class="anchored" data-anchor-id="misuse-and-safety-concerns" id="misuse-and-safety-concerns">Misuse and Safety Concerns</h3>
<p>The democratization of high-quality image generation raises several safety considerations:</p>
<p><strong>Deepfakes and Misinformation</strong>: While not specifically designed for photorealistic human faces, the technology contributes to broader concerns about synthetic media and misinformation.</p>
<p><strong>Harmful Content</strong>: Despite built-in safety filters, determined users may find ways to generate inappropriate or harmful content.</p>
<p><strong>Economic Disruption</strong>: The technology’s impact on creative industries continues to evolve, with both opportunities and challenges for traditional creative professions.</p>
</section>
</section>
<section id="the-open-source-ecosystem" class="level2">
<h2 class="anchored" data-anchor-id="the-open-source-ecosystem" id="the-open-source-ecosystem">The Open Source Ecosystem</h2>
<section id="community-contributions" class="level3">
<h3 class="anchored" data-anchor-id="community-contributions" id="community-contributions">Community Contributions</h3>
<p>The open-source release of Stable Diffusion catalyzed an unprecedented wave of community innovation:</p>
<p><strong>User Interfaces</strong>: Projects like AUTOMATIC1111’s WebUI, ComfyUI, and InvokeAI provide accessible interfaces for non-technical users.</p>
<p><strong>Extensions and Plugins</strong>: Thousands of community-developed extensions add functionality ranging from advanced sampling methods to integration with other AI models.</p>
<p><strong>Model Variants</strong>: The community has created countless fine-tuned versions optimized for specific use cases, artistic styles, or quality improvements.</p>
</section>
<section id="commercial-applications" class="level3">
<h3 class="anchored" data-anchor-id="commercial-applications" id="commercial-applications">Commercial Applications</h3>
<p>Because its license permits commercial use (subject to use-based restrictions), Stable Diffusion has enabled numerous commercial applications:</p>
<ul>
<li><strong>Creative Tools</strong>: Integration into professional creative software such as Photoshop and Blender (largely through plugins), as well as into specialized AI art platforms</li>
<li><strong>Marketing and Advertising</strong>: Rapid prototyping of visual concepts and personalized content generation</li>
<li><strong>Gaming and Entertainment</strong>: Asset generation for games, concept art creation, and virtual world building</li>
<li><strong>Education and Research</strong>: Teaching aids, scientific visualization, and research tool development</li>
</ul>
</section>
</section>
<section id="future-developments-and-research-directions" class="level2">
<h2 class="anchored" data-anchor-id="future-developments-and-research-directions" id="future-developments-and-research-directions">Future Developments and Research Directions</h2>
<section id="technical-improvements" class="level3">
<h3 class="anchored" data-anchor-id="technical-improvements" id="technical-improvements">Technical Improvements</h3>
<p>Active areas of research and development include:</p>
<p><strong>Higher Resolution Generation</strong>: Techniques for generating images at resolutions significantly higher than the training resolution of 512×512.</p>
<p><strong>Improved Consistency</strong>: Better temporal consistency for video generation and improved coherence across multiple images.</p>
<p><strong>Efficiency Optimizations</strong>: Faster sampling methods, more efficient architectures, and better hardware utilization.</p>
<p><strong>Multi-modal Integration</strong>: Better integration with other modalities like audio, 3D geometry, and temporal sequences.</p>
</section>
<section id="architectural-innovations" class="level3">
<h3 class="anchored" data-anchor-id="architectural-innovations" id="architectural-innovations">Architectural Innovations</h3>
<p><strong>Transformer-based Diffusion</strong>: Exploring alternatives to the U-Net architecture using transformer models for potentially better scalability and performance.</p>
<p><strong>Continuous Diffusion</strong>: Moving beyond discrete timesteps to continuous-time formulations that may offer theoretical and practical advantages.</p>
<p><strong>Hierarchical Generation</strong>: Multi-scale approaches that generate images at multiple resolutions simultaneously for better detail and consistency.</p>
</section>
<section id="emerging-applications" class="level3">
<h3 class="anchored" data-anchor-id="emerging-applications" id="emerging-applications">Emerging Applications</h3>
<p><strong>3D Generation</strong>: Extensions of diffusion models to 3D object and scene generation, opening new possibilities for content creation.</p>
<p><strong>Video Generation</strong>: Temporal extensions that enable consistent video generation from text descriptions.</p>
<p><strong>Interactive Generation</strong>: Real-time generation and editing capabilities that enable new forms of creative interaction.</p>
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>Stable Diffusion represents more than just a technical achievement; it embodies a paradigm shift in how we think about creativity, accessibility, and the democratization of advanced AI capabilities. By making high-quality text-to-image generation freely available and runnable on consumer hardware, it has lowered barriers to entry that previously restricted such capabilities to well-funded research labs and major technology companies.</p>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Impact Summary
</div>
</div>
<div class="callout-body-container callout-body">
<p>Stable Diffusion’s open-source approach has democratized access to advanced AI image generation, sparking innovation while raising important questions about creativity, copyright, and the future of visual media.</p>
</div>
</div>
<p>The model’s impact extends across multiple domains, from empowering individual creators with new tools for expression to enabling businesses to rapidly prototype visual concepts. It has accelerated research in generative AI, inspired countless derivative works and improvements, and sparked important conversations about the future of human creativity in an age of artificial intelligence.</p>
<p>However, this democratization also brings challenges. Questions about copyright, consent, bias, and the economic impact on creative industries remain largely unresolved. As the technology continues to evolve, balancing innovation with ethical considerations will be crucial for realizing its positive potential while mitigating potential harms.</p>
<p>Looking forward, Stable Diffusion has established a foundation that will likely influence AI development for years to come. Its open-source ethos has proven that powerful AI capabilities need not remain locked behind corporate walls, while its technical innovations continue to inspire new research directions and applications.</p>
<p>The story of Stable Diffusion is still being written, with each new fine-tuned model, innovative application, and community contribution adding new chapters to this remarkable technological narrative. As we stand at this inflection point in the history of AI and creativity, Stable Diffusion serves as both a powerful tool and a glimpse into a future where the boundaries between human and artificial creativity continue to blur and evolve.</p>
<p>Whether one views it as a revolutionary creative tool, a concerning disruption to traditional industries, or simply an impressive technical achievement, Stable Diffusion undeniably represents a significant milestone in the ongoing evolution of artificial intelligence and its integration into human creative processes. Its legacy will likely be measured not just in the images it generates, but in the broader conversations, innovations, and transformations it has catalyzed across society.</p>



</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Complete Guide to Stable Diffusion with ControlNet]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/generative-ai/stable-diffusion-with-control-net/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/generative-ai/stable-diffusion-with-control-net/</guid>
      <pubDate>Tue, 22 Jul 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>research</category>
      <category>intermediate</category>
      <content:encoded><![CDATA[






<section id="complete-guide-to-stable-diffusion-with-controlnet" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/generative-ai/stable-diffusion-with-control-net/sd-cn.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>ControlNet is a neural network architecture that allows you to control Stable Diffusion image generation with additional input conditions like edge maps, depth maps, poses, and more. It provides precise control over the composition, structure, and layout of generated images while maintaining the creative power of diffusion models.</p>
<section id="key-benefits" class="level3">
<h3 class="anchored" data-anchor-id="key-benefits" id="key-benefits">Key Benefits</h3>
<ul>
<li><strong>Precise Control</strong>: Direct influence over image structure and composition</li>
<li><strong>Consistency</strong>: Maintain specific poses, edges, or layouts across generations</li>
<li><strong>Flexibility</strong>: Multiple conditioning types for different use cases</li>
<li><strong>Quality</strong>: Enhanced output quality with structured guidance</li>
</ul>
</section>
</section>
<section id="installation-setup" class="level2">
<h2 class="anchored" data-anchor-id="installation-setup" id="installation-setup">Installation &amp; Setup</h2>
<section id="environment-setup" class="level3">
<h3 class="anchored" data-anchor-id="environment-setup" id="environment-setup">Environment Setup</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Create conda environment</span></span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="ex">conda</span> create <span class="at">-n</span> controlnet python=3.10</span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a><span class="ex">conda</span> activate controlnet</span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Install PyTorch with CUDA support</span></span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install torch torchvision torchaudio <span class="at">--index-url</span> https://download.pytorch.org/whl/cu118</span>
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a><span class="co"># Install core dependencies</span></span>
<span id="cb1-9"><a href="#cb1-9" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install diffusers transformers accelerate</span>
<span id="cb1-10"><a href="#cb1-10" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install controlnet-aux</span>
<span id="cb1-11"><a href="#cb1-11" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install opencv-python</span>
<span id="cb1-12"><a href="#cb1-12" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install xformers  <span class="co"># Optional but recommended for performance</span></span></code></pre></div></div>
</section>
<section id="required-libraries" class="level3">
<h3 class="anchored" data-anchor-id="required-libraries" id="required-libraries">Required Libraries</h3>
<div id="imports" class="cell" data-execution_count="1">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> cv2</span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> PIL <span class="im">import</span> Image</span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> diffusers <span class="im">import</span> (</span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a>    StableDiffusionControlNetPipeline,</span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a>    ControlNetModel,</span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a>    UniPCMultistepScheduler</span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb2-10"><a href="#cb2-10" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> controlnet_aux <span class="im">import</span> (</span>
<span id="cb2-11"><a href="#cb2-11" aria-hidden="true" tabindex="-1"></a>    CannyDetector,</span>
<span id="cb2-12"><a href="#cb2-12" aria-hidden="true" tabindex="-1"></a>    OpenposeDetector,</span>
<span id="cb2-13"><a href="#cb2-13" aria-hidden="true" tabindex="-1"></a>    MidasDetector,</span>
<span id="cb2-14"><a href="#cb2-14" aria-hidden="true" tabindex="-1"></a>    HEDdetector,</span>
<span id="cb2-15"><a href="#cb2-15" aria-hidden="true" tabindex="-1"></a>    MLSDdetector,</span>
<span id="cb2-16"><a href="#cb2-16" aria-hidden="true" tabindex="-1"></a>    LineartDetector,</span>
<span id="cb2-17"><a href="#cb2-17" aria-hidden="true" tabindex="-1"></a>    LineartAnimeDetector</span>
<span id="cb2-18"><a href="#cb2-18" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb2-19"><a href="#cb2-19" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> transformers <span class="im">import</span> pipeline</span></code></pre></div></div>
</div>
</section>
<section id="basic-setup-function" class="level3">
<h3 class="anchored" data-anchor-id="basic-setup-function" id="basic-setup-function">Basic Setup Function</h3>
<div id="setup-pipeline" class="cell" data-execution_count="2">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> setup_controlnet_pipeline(controlnet_type<span class="op">=</span><span class="st">"canny"</span>, model_id<span class="op">=</span><span class="st">"runwayml/stable-diffusion-v1-5"</span>):</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a><span class="co">    Set up a ControlNet pipeline for the given conditioning type and base model</span></span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a><span class="co">    </span></span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a><span class="co">    Args:</span></span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a><span class="co">        controlnet_type: Type of ControlNet ('canny', 'openpose', 'depth', etc.)</span></span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a><span class="co">        model_id: Base Stable Diffusion model to use</span></span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a><span class="co">    </span></span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a><span class="co">    Returns:</span></span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a><span class="co">        Configured pipeline</span></span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a>    <span class="co"># ControlNet model mapping</span></span>
<span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a>    controlnet_models <span class="op">=</span> {</span>
<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a>        <span class="st">"canny"</span>: <span class="st">"lllyasviel/sd-controlnet-canny"</span>,</span>
<span id="cb3-15"><a href="#cb3-15" aria-hidden="true" tabindex="-1"></a>        <span class="st">"openpose"</span>: <span class="st">"lllyasviel/sd-controlnet-openpose"</span>,</span>
<span id="cb3-16"><a href="#cb3-16" aria-hidden="true" tabindex="-1"></a>        <span class="st">"depth"</span>: <span class="st">"lllyasviel/sd-controlnet-depth"</span>,</span>
<span id="cb3-17"><a href="#cb3-17" aria-hidden="true" tabindex="-1"></a>        <span class="st">"hed"</span>: <span class="st">"lllyasviel/sd-controlnet-hed"</span>,</span>
<span id="cb3-18"><a href="#cb3-18" aria-hidden="true" tabindex="-1"></a>        <span class="st">"mlsd"</span>: <span class="st">"lllyasviel/sd-controlnet-mlsd"</span>,</span>
<span id="cb3-19"><a href="#cb3-19" aria-hidden="true" tabindex="-1"></a>        <span class="st">"normal"</span>: <span class="st">"lllyasviel/sd-controlnet-normal-map"</span>,</span>
<span id="cb3-20"><a href="#cb3-20" aria-hidden="true" tabindex="-1"></a>        <span class="st">"scribble"</span>: <span class="st">"lllyasviel/sd-controlnet-scribble"</span>,</span>
<span id="cb3-21"><a href="#cb3-21" aria-hidden="true" tabindex="-1"></a>        <span class="st">"seg"</span>: <span class="st">"lllyasviel/sd-controlnet-seg"</span></span>
<span id="cb3-22"><a href="#cb3-22" aria-hidden="true" tabindex="-1"></a>    }</span>
<span id="cb3-23"><a href="#cb3-23" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-24"><a href="#cb3-24" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Load ControlNet</span></span>
<span id="cb3-25"><a href="#cb3-25" aria-hidden="true" tabindex="-1"></a>    controlnet <span class="op">=</span> ControlNetModel.from_pretrained(</span>
<span id="cb3-26"><a href="#cb3-26" aria-hidden="true" tabindex="-1"></a>        controlnet_models[controlnet_type],</span>
<span id="cb3-27"><a href="#cb3-27" aria-hidden="true" tabindex="-1"></a>        torch_dtype<span class="op">=</span>torch.float16</span>
<span id="cb3-28"><a href="#cb3-28" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb3-29"><a href="#cb3-29" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-30"><a href="#cb3-30" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Create pipeline</span></span>
<span id="cb3-31"><a href="#cb3-31" aria-hidden="true" tabindex="-1"></a>    pipe <span class="op">=</span> StableDiffusionControlNetPipeline.from_pretrained(</span>
<span id="cb3-32"><a href="#cb3-32" aria-hidden="true" tabindex="-1"></a>        model_id,</span>
<span id="cb3-33"><a href="#cb3-33" aria-hidden="true" tabindex="-1"></a>        controlnet<span class="op">=</span>controlnet,</span>
<span id="cb3-34"><a href="#cb3-34" aria-hidden="true" tabindex="-1"></a>        torch_dtype<span class="op">=</span>torch.float16,</span>
<span id="cb3-35"><a href="#cb3-35" aria-hidden="true" tabindex="-1"></a>        safety_checker<span class="op">=</span><span class="va">None</span>,</span>
<span id="cb3-36"><a href="#cb3-36" aria-hidden="true" tabindex="-1"></a>        requires_safety_checker<span class="op">=</span><span class="va">False</span></span>
<span id="cb3-37"><a href="#cb3-37" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb3-38"><a href="#cb3-38" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-39"><a href="#cb3-39" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Swap in the UniPC multistep scheduler (the examples below use 20 steps)</span></span>
<span id="cb3-40"><a href="#cb3-40" aria-hidden="true" tabindex="-1"></a>    pipe.scheduler <span class="op">=</span> UniPCMultistepScheduler.from_config(pipe.scheduler.config)</span>
<span id="cb3-41"><a href="#cb3-41" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-42"><a href="#cb3-42" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Model CPU offload keeps sub-models on the CPU and moves each to the GPU only while needed,</span></span>
<span id="cb3-43"><a href="#cb3-43" aria-hidden="true" tabindex="-1"></a>    <span class="co"># so it replaces (and should not be combined with) pipe.to("cuda")</span></span>
<span id="cb3-44"><a href="#cb3-44" aria-hidden="true" tabindex="-1"></a>    pipe.enable_model_cpu_offload()</span>
<span id="cb3-45"><a href="#cb3-45" aria-hidden="true" tabindex="-1"></a>    pipe.enable_xformers_memory_efficient_attention()  <span class="co"># requires the xformers package</span></span>
<span id="cb3-46"><a href="#cb3-46" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-47"><a href="#cb3-47" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> pipe</span></code></pre></div></div>
</div>
</section>
</section>
<section id="understanding-controlnet" class="level2">
<h2 class="anchored" data-anchor-id="understanding-controlnet" id="understanding-controlnet">Understanding ControlNet</h2>
<section id="core-concept" class="level3">
<h3 class="anchored" data-anchor-id="core-concept" id="core-concept">Core Concept</h3>
<p>ControlNet attaches a trainable copy of the Stable Diffusion UNet's encoder and middle blocks. This copy processes a conditioning input (such as an edge map or a pose skeleton) and injects the result back into the original UNet. The original weights stay frozen; only the copy learns to translate conditioning inputs into guidance, and it is connected to the frozen network through zero-initialized convolutions so that training starts from the unmodified base model.</p>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Architecture Overview
</div>
</div>
<div class="callout-body-container callout-body">
<p>ControlNet maintains the original Stable Diffusion weights while adding trainable layers that process conditioning inputs and inject control signals at multiple resolution levels in the UNet architecture.</p>
</div>
</div>
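<p>The role of the zero convolutions is easiest to see in isolation. The sketch below is a minimal PyTorch illustration, not code from the ControlNet repository, and its channel counts and tensor shapes are arbitrary: a 1×1 convolution initialised to zero leaves the frozen features untouched at the start of training, yet its own weights still receive a non-zero gradient.</p>

```python
import torch
import torch.nn as nn

# A 1x1 "zero convolution": weight and bias both start at exactly zero
zero_conv = nn.Conv2d(8, 8, kernel_size=1)
nn.init.zeros_(zero_conv.weight)
nn.init.zeros_(zero_conv.bias)

torch.manual_seed(0)
frozen_features = torch.randn(1, 8, 16, 16)   # stand-in for a frozen UNet block output
control_features = torch.randn(1, 8, 16, 16)  # stand-in for the trainable copy's output

# Injection: the control signal is added to the frozen features
out = frozen_features + zero_conv(control_features)
print(torch.equal(out, frozen_features))  # True: the base model's output is unchanged

# The zero weights still get a gradient, because d(out)/d(weight) depends on the input
out.pow(2).mean().backward()
print(bool(zero_conv.weight.grad.abs().sum() > 0))  # True
```

<p>Note that the gradient flowing back into the trainable copy through a zero convolution is itself zero on the very first step, since it is multiplied by the zero weight. The zero convolution's weights therefore move first, and the copy starts receiving gradient right after.</p>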
</section>
<section id="architecture-overview-1" class="level3">
<h3 class="anchored" data-anchor-id="architecture-overview-1" id="architecture-overview-1">Architecture Overview</h3>
<div id="architecture-concept" class="cell" data-execution_count="3">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> ControlNetArchitecture:</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a><span class="co">    Conceptual overview of ControlNet architecture (illustrative; helper methods are not implemented)</span></span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>):</span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.encoder_layers <span class="op">=</span> []  <span class="co"># Process conditioning input</span></span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.zero_convolutions <span class="op">=</span> []  <span class="co"># Ensure training stability</span></span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.connection_layers <span class="op">=</span> []  <span class="co"># Connect to UNet blocks</span></span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x_noisy, timestep, conditioning_input):</span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Process conditioning input through encoder</span></span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a>        control_features <span class="op">=</span> <span class="va">self</span>.process_conditioning(conditioning_input)</span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-14"><a href="#cb4-14" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Apply zero convolutions for stable training</span></span>
<span id="cb4-15"><a href="#cb4-15" aria-hidden="true" tabindex="-1"></a>        control_features <span class="op">=</span> <span class="va">self</span>.apply_zero_convs(control_features)</span>
<span id="cb4-16"><a href="#cb4-16" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-17"><a href="#cb4-17" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Inject into UNet at multiple resolution levels</span></span>
<span id="cb4-18"><a href="#cb4-18" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.inject_control(x_noisy, timestep, control_features)</span></code></pre></div></div>
</div>
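<p>The conceptual class above leaves its helper methods undefined. For something you can run, here is a toy PyTorch sketch of the same data flow under simplifying assumptions: the "UNet encoder" is just two strided convolutions, timestep and text embeddings are omitted, and the control image is supplied at latent resolution (real ControlNets take a 512×512 image and downsample it with a small convolutional hint encoder). The names <code>ToyControlNet</code>, <code>encoder_block</code>, and <code>zero_conv</code> are illustrative, not part of diffusers.</p>

```python
import torch
import torch.nn as nn


def zero_conv(channels: int) -> nn.Conv2d:
    """1x1 convolution whose weight and bias start at zero."""
    conv = nn.Conv2d(channels, channels, kernel_size=1)
    nn.init.zeros_(conv.weight)
    nn.init.zeros_(conv.bias)
    return conv


def encoder_block(in_ch: int, out_ch: int) -> nn.Sequential:
    """Toy stand-in for a UNet down block: stride-2 conv + SiLU halves H and W."""
    return nn.Sequential(nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1), nn.SiLU())


class ToyControlNet(nn.Module):
    def __init__(self, channels=(4, 32, 64)):
        super().__init__()
        pairs = list(zip(channels[:-1], channels[1:]))
        # Hint encoder: maps the RGB control image to latent channels
        self.hint = nn.Conv2d(3, channels[0], kernel_size=3, padding=1)
        # Trainable copy of the (toy) UNet encoder
        self.blocks = nn.ModuleList([encoder_block(i, o) for i, o in pairs])
        # One zero convolution per resolution level
        self.zero_convs = nn.ModuleList([zero_conv(o) for _, o in pairs])

    def forward(self, x_noisy, control_image):
        h = x_noisy + self.hint(control_image)
        residuals = []
        for block, zc in zip(self.blocks, self.zero_convs):
            h = block(h)
            residuals.append(zc(h))  # one control residual per resolution level
        return residuals


channels = (4, 32, 64)
# Frozen base encoder with the same layout as the trainable copy
frozen_encoder = nn.ModuleList(
    [encoder_block(i, o) for i, o in zip(channels[:-1], channels[1:])]
)
frozen_encoder.requires_grad_(False)

controlnet = ToyControlNet(channels)
# Like ControlNet, start the trainable copy from the base model's weights
controlnet.blocks.load_state_dict(frozen_encoder.state_dict())

x_noisy = torch.randn(1, 4, 64, 64)        # noisy latent
control_image = torch.randn(1, 3, 64, 64)  # e.g. an edge map, at toy resolution

# The frozen encoder produces the skip features the decoder would consume
h, skips = x_noisy, []
for block in frozen_encoder:
    h = block(h)
    skips.append(h)

# Injection: add each control residual to the skip feature at the same resolution
residuals = controlnet(x_noisy, control_image)
controlled_skips = [s + r for s, r in zip(skips, residuals)]
print([tuple(r.shape) for r in residuals])  # [(1, 32, 32, 32), (1, 64, 16, 16)]
```

<p>Because every zero convolution starts at zero, <code>controlled_skips</code> equals <code>skips</code> before training; the residuals only start steering generation once the zero convolutions and the trainable copy have been optimised.</p>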
</section>
</section>
<section id="basic-implementation" class="level2">
<h2 class="anchored" data-anchor-id="basic-implementation" id="basic-implementation">Basic Implementation</h2>
<section id="canny-edge-control" class="level3">
<h3 class="anchored" data-anchor-id="canny-edge-control" id="canny-edge-control">Canny Edge Control</h3>
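<p>Canny edge detection extracts the outlines of a reference image. The Canny ControlNet then makes the generated image follow those edges, fixing the composition while the prompt determines content and style.</p>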
<div id="canny-generation" class="cell" data-execution_count="4">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> generate_with_canny(pipe, image_path, prompt, negative_prompt<span class="op">=</span><span class="st">""</span>, num_inference_steps<span class="op">=</span><span class="dv">20</span>):</span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a><span class="co">    Generate image using Canny edge control</span></span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Load and preprocess image</span></span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a>    original_image <span class="op">=</span> Image.<span class="bu">open</span>(image_path)</span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a>    original_image <span class="op">=</span> original_image.resize((<span class="dv">512</span>, <span class="dv">512</span>))</span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Create Canny detector</span></span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a>    canny_detector <span class="op">=</span> CannyDetector()</span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Generate Canny edge map</span></span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a>    canny_image <span class="op">=</span> canny_detector(original_image)</span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Generate image</span></span>
<span id="cb5-16"><a href="#cb5-16" aria-hidden="true" tabindex="-1"></a>    result <span class="op">=</span> pipe(</span>
<span id="cb5-17"><a href="#cb5-17" aria-hidden="true" tabindex="-1"></a>        prompt<span class="op">=</span>prompt,</span>
<span id="cb5-18"><a href="#cb5-18" aria-hidden="true" tabindex="-1"></a>        image<span class="op">=</span>canny_image,</span>
<span id="cb5-19"><a href="#cb5-19" aria-hidden="true" tabindex="-1"></a>        negative_prompt<span class="op">=</span>negative_prompt,</span>
<span id="cb5-20"><a href="#cb5-20" aria-hidden="true" tabindex="-1"></a>        num_inference_steps<span class="op">=</span>num_inference_steps,</span>
<span id="cb5-21"><a href="#cb5-21" aria-hidden="true" tabindex="-1"></a>        guidance_scale<span class="op">=</span><span class="fl">7.5</span>,</span>
<span id="cb5-22"><a href="#cb5-22" aria-hidden="true" tabindex="-1"></a>        controlnet_conditioning_scale<span class="op">=</span><span class="fl">1.0</span></span>
<span id="cb5-23"><a href="#cb5-23" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb5-24"><a href="#cb5-24" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-25"><a href="#cb5-25" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> result.images[<span class="dv">0</span>], canny_image</span>
<span id="cb5-26"><a href="#cb5-26" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-27"><a href="#cb5-27" aria-hidden="true" tabindex="-1"></a><span class="co"># Example usage</span></span>
<span id="cb5-28"><a href="#cb5-28" aria-hidden="true" tabindex="-1"></a>pipe <span class="op">=</span> setup_controlnet_pipeline(<span class="st">"canny"</span>)</span>
<span id="cb5-29"><a href="#cb5-29" aria-hidden="true" tabindex="-1"></a>prompt <span class="op">=</span> <span class="st">"a beautiful landscape painting, oil painting style, vibrant colors"</span></span>
<span id="cb5-30"><a href="#cb5-30" aria-hidden="true" tabindex="-1"></a>generated_image, control_image <span class="op">=</span> generate_with_canny(pipe, <span class="st">"input.jpg"</span>, prompt)</span></code></pre></div></div>
</div>
</section>
<section id="openpose-human-pose-control" class="level3">
<h3 class="anchored" data-anchor-id="openpose-human-pose-control" id="openpose-human-pose-control">OpenPose Human Pose Control</h3>
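<p>OpenPose detects body keypoints in a reference photo and renders them as a skeleton, so the generated figure adopts the same pose and body position whatever its appearance or style.</p>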
<div id="openpose-generation" class="cell" data-execution_count="5">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> generate_with_openpose(pipe, image_path, prompt, negative_prompt<span class="op">=</span><span class="st">""</span>):</span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a><span class="co">    Generate image using OpenPose control</span></span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Load image</span></span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a>    original_image <span class="op">=</span> Image.<span class="bu">open</span>(image_path)</span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a>    original_image <span class="op">=</span> original_image.resize((<span class="dv">512</span>, <span class="dv">512</span>))</span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Create OpenPose detector</span></span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a>    openpose_detector <span class="op">=</span> OpenposeDetector.from_pretrained(<span class="st">'lllyasviel/Annotators'</span>)</span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Generate pose keypoints</span></span>
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a>    pose_image <span class="op">=</span> openpose_detector(original_image)</span>
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-15"><a href="#cb6-15" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Generate image with pose control</span></span>
<span id="cb6-16"><a href="#cb6-16" aria-hidden="true" tabindex="-1"></a>    result <span class="op">=</span> pipe(</span>
<span id="cb6-17"><a href="#cb6-17" aria-hidden="true" tabindex="-1"></a>        prompt<span class="op">=</span>prompt,</span>
<span id="cb6-18"><a href="#cb6-18" aria-hidden="true" tabindex="-1"></a>        image<span class="op">=</span>pose_image,</span>
<span id="cb6-19"><a href="#cb6-19" aria-hidden="true" tabindex="-1"></a>        negative_prompt<span class="op">=</span>negative_prompt,</span>
<span id="cb6-20"><a href="#cb6-20" aria-hidden="true" tabindex="-1"></a>        num_inference_steps<span class="op">=</span><span class="dv">20</span>,</span>
<span id="cb6-21"><a href="#cb6-21" aria-hidden="true" tabindex="-1"></a>        guidance_scale<span class="op">=</span><span class="fl">7.5</span>,</span>
<span id="cb6-22"><a href="#cb6-22" aria-hidden="true" tabindex="-1"></a>        controlnet_conditioning_scale<span class="op">=</span><span class="fl">1.0</span></span>
<span id="cb6-23"><a href="#cb6-23" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb6-24"><a href="#cb6-24" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-25"><a href="#cb6-25" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> result.images[<span class="dv">0</span>], pose_image</span>
<span id="cb6-26"><a href="#cb6-26" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-27"><a href="#cb6-27" aria-hidden="true" tabindex="-1"></a><span class="co"># Example usage</span></span>
<span id="cb6-28"><a href="#cb6-28" aria-hidden="true" tabindex="-1"></a>pipe <span class="op">=</span> setup_controlnet_pipeline(<span class="st">"openpose"</span>)</span>
<span id="cb6-29"><a href="#cb6-29" aria-hidden="true" tabindex="-1"></a>prompt <span class="op">=</span> <span class="st">"a robot dancing, futuristic style, neon lighting"</span></span>
<span id="cb6-30"><a href="#cb6-30" aria-hidden="true" tabindex="-1"></a>generated_image, pose_image <span class="op">=</span> generate_with_openpose(pipe, <span class="st">"person_dancing.jpg"</span>, prompt)</span></code></pre></div></div>
</div>
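<p>These helpers squash every input to 512×512, which distorts non-square photos; for pose control that means stretched limbs and torsos. A minimal alternative, sketched below on the assumption that you then pass matching <code>height</code> and <code>width</code> to the pipeline call, scales the shorter side to about 512 pixels and snaps both sides to a multiple of 64. Stable Diffusion 1.5 requires dimensions divisible by 8, so 64 is a conservative choice. The helper name <code>resize_keep_aspect</code> is hypothetical.</p>

```python
from PIL import Image


def resize_keep_aspect(image: Image.Image, short_side: int = 512, multiple: int = 64) -> Image.Image:
    """Scale so the shorter side is ~short_side; snap both sides to `multiple`."""
    w, h = image.size
    scale = short_side / min(w, h)
    new_w = max(multiple, round(w * scale / multiple) * multiple)
    new_h = max(multiple, round(h * scale / multiple) * multiple)
    return image.resize((new_w, new_h), Image.LANCZOS)


photo = Image.new("RGB", (1920, 1080))  # placeholder 16:9 landscape photo
print(resize_keep_aspect(photo).size)   # (896, 512)
```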
</section>
<section id="depth-map-control" class="level3">
<h3 class="anchored" data-anchor-id="depth-map-control" id="depth-map-control">Depth Map Control</h3>
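<p>A depth map encodes how far each pixel is from the camera, so depth control preserves the scene's 3D layout (what sits in front of what, and where) while the prompt controls textures and style.</p>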
<div id="depth-generation" class="cell" data-execution_count="6">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> generate_with_depth(pipe, image_path, prompt, negative_prompt<span class="op">=</span><span class="st">""</span>):</span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a><span class="co">    Generate image using depth map control</span></span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Load image</span></span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a>    original_image <span class="op">=</span> Image.<span class="bu">open</span>(image_path)</span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a>    original_image <span class="op">=</span> original_image.resize((<span class="dv">512</span>, <span class="dv">512</span>))</span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Create depth estimator</span></span>
<span id="cb7-10"><a href="#cb7-10" aria-hidden="true" tabindex="-1"></a>    depth_estimator <span class="op">=</span> pipeline(<span class="st">'depth-estimation'</span>)</span>
<span id="cb7-11"><a href="#cb7-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-12"><a href="#cb7-12" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Generate depth map</span></span>
<span id="cb7-13"><a href="#cb7-13" aria-hidden="true" tabindex="-1"></a>    depth <span class="op">=</span> depth_estimator(original_image)[<span class="st">'depth'</span>]</span>
<span id="cb7-14"><a href="#cb7-14" aria-hidden="true" tabindex="-1"></a>    depth_image <span class="op">=</span> Image.fromarray(np.array(depth)).convert(<span class="st">'RGB'</span>)</span>
<span id="cb7-15"><a href="#cb7-15" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-16"><a href="#cb7-16" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Generate image with depth control</span></span>
<span id="cb7-17"><a href="#cb7-17" aria-hidden="true" tabindex="-1"></a>    result <span class="op">=</span> pipe(</span>
<span id="cb7-18"><a href="#cb7-18" aria-hidden="true" tabindex="-1"></a>        prompt<span class="op">=</span>prompt,</span>
<span id="cb7-19"><a href="#cb7-19" aria-hidden="true" tabindex="-1"></a>        image<span class="op">=</span>depth_image,</span>
<span id="cb7-20"><a href="#cb7-20" aria-hidden="true" tabindex="-1"></a>        negative_prompt<span class="op">=</span>negative_prompt,</span>
<span id="cb7-21"><a href="#cb7-21" aria-hidden="true" tabindex="-1"></a>        num_inference_steps<span class="op">=</span><span class="dv">20</span>,</span>
<span id="cb7-22"><a href="#cb7-22" aria-hidden="true" tabindex="-1"></a>        guidance_scale<span class="op">=</span><span class="fl">7.5</span>,</span>
<span id="cb7-23"><a href="#cb7-23" aria-hidden="true" tabindex="-1"></a>        controlnet_conditioning_scale<span class="op">=</span><span class="fl">1.0</span></span>
<span id="cb7-24"><a href="#cb7-24" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb7-25"><a href="#cb7-25" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-26"><a href="#cb7-26" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> result.images[<span class="dv">0</span>], depth_image</span></code></pre></div></div>
</div>
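<p>The transformers <code>depth-estimation</code> pipeline already returns an 8-bit image, but other estimators often hand back raw floating-point arrays. A small helper such as the hypothetical <code>depth_to_control_image</code> below (a sketch, not part of any library) rescales such an array to 0–255 and replicates it into three channels. It assumes the MiDaS-style convention the original depth ControlNet was trained with, where nearer surfaces are brighter; pass <code>invert=True</code> if your estimator returns metric distance, where larger values are farther away.</p>

```python
import numpy as np
from PIL import Image


def depth_to_control_image(depth: np.ndarray, invert: bool = False) -> Image.Image:
    """Min-max normalise a depth array to uint8 and replicate it to RGB."""
    d = depth.astype(np.float32)
    d = (d - d.min()) / max(float(d.max() - d.min()), 1e-8)  # rescale to [0, 1]
    if invert:
        d = 1.0 - d  # metric distance -> near-is-bright
    d8 = np.round(d * 255.0).astype(np.uint8)
    return Image.fromarray(np.stack([d8] * 3, axis=-1))


# Synthetic "metric" depth: distance grows from 1 m (left) to 5 m (right)
metric = np.tile(np.linspace(1.0, 5.0, 64, dtype=np.float32), (64, 1))
control = depth_to_control_image(metric, invert=True)
print(np.array(control)[0, 0], np.array(control)[0, -1])  # [255 255 255] [0 0 0]
```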
</section>
</section>
<section id="advanced-controlnet-types" class="level2">
<h2 class="anchored" data-anchor-id="advanced-controlnet-types" id="advanced-controlnet-types">Advanced ControlNet Types</h2>
<section id="line-art-control" class="level3">
<h3 class="anchored" data-anchor-id="line-art-control" id="line-art-control">Line Art Control</h3>
<p>Line art control suits anime-style generation and turning clean line drawings into finished images.</p>
<div id="lineart-setup" class="cell" data-execution_count="7">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb8"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> setup_lineart_pipeline():</span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a><span class="co">    Setup pipeline for line art control</span></span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a>    controlnet <span class="op">=</span> ControlNetModel.from_pretrained(</span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a>        <span class="st">"lllyasviel/control_v11p_sd15_lineart"</span>,  <span class="co"># ControlNet 1.1 line-art model</span></span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a>        torch_dtype<span class="op">=</span>torch.float16</span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a>    pipe <span class="op">=</span> StableDiffusionControlNetPipeline.from_pretrained(</span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a>        <span class="st">"runwayml/stable-diffusion-v1-5"</span>,</span>
<span id="cb8-12"><a href="#cb8-12" aria-hidden="true" tabindex="-1"></a>        controlnet<span class="op">=</span>controlnet,</span>
<span id="cb8-13"><a href="#cb8-13" aria-hidden="true" tabindex="-1"></a>        torch_dtype<span class="op">=</span>torch.float16</span>
<span id="cb8-14"><a href="#cb8-14" aria-hidden="true" tabindex="-1"></a>    ).to(<span class="st">"cuda"</span>)</span>
<span id="cb8-15"><a href="#cb8-15" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-16"><a href="#cb8-16" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> pipe</span>
<span id="cb8-17"><a href="#cb8-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-18"><a href="#cb8-18" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> generate_with_lineart(pipe, image_path, prompt, anime_style<span class="op">=</span><span class="va">False</span>):</span>
<span id="cb8-19"><a href="#cb8-19" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb8-20"><a href="#cb8-20" aria-hidden="true" tabindex="-1"></a><span class="co">    Generate using line art control</span></span>
<span id="cb8-21"><a href="#cb8-21" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb8-22"><a href="#cb8-22" aria-hidden="true" tabindex="-1"></a>    original_image <span class="op">=</span> Image.<span class="bu">open</span>(image_path).resize((<span class="dv">512</span>, <span class="dv">512</span>))</span>
<span id="cb8-23"><a href="#cb8-23" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-24"><a href="#cb8-24" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Choose detector based on style</span></span>
<span id="cb8-25"><a href="#cb8-25" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> anime_style:</span>
<span id="cb8-26"><a href="#cb8-26" aria-hidden="true" tabindex="-1"></a>        detector <span class="op">=</span> LineartAnimeDetector.from_pretrained(<span class="st">'lllyasviel/Annotators'</span>)</span>
<span id="cb8-27"><a href="#cb8-27" aria-hidden="true" tabindex="-1"></a>    <span class="cf">else</span>:</span>
<span id="cb8-28"><a href="#cb8-28" aria-hidden="true" tabindex="-1"></a>        detector <span class="op">=</span> LineartDetector.from_pretrained(<span class="st">'lllyasviel/Annotators'</span>)</span>
<span id="cb8-29"><a href="#cb8-29" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-30"><a href="#cb8-30" aria-hidden="true" tabindex="-1"></a>    lineart_image <span class="op">=</span> detector(original_image)</span>
<span id="cb8-31"><a href="#cb8-31" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-32"><a href="#cb8-32" aria-hidden="true" tabindex="-1"></a>    result <span class="op">=</span> pipe(</span>
<span id="cb8-33"><a href="#cb8-33" aria-hidden="true" tabindex="-1"></a>        prompt<span class="op">=</span>prompt,</span>
<span id="cb8-34"><a href="#cb8-34" aria-hidden="true" tabindex="-1"></a>        image<span class="op">=</span>lineart_image,</span>
<span id="cb8-35"><a href="#cb8-35" aria-hidden="true" tabindex="-1"></a>        num_inference_steps<span class="op">=</span><span class="dv">20</span>,</span>
<span id="cb8-36"><a href="#cb8-36" aria-hidden="true" tabindex="-1"></a>        guidance_scale<span class="op">=</span><span class="fl">7.5</span>,</span>
<span id="cb8-37"><a href="#cb8-37" aria-hidden="true" tabindex="-1"></a>        controlnet_conditioning_scale<span class="op">=</span><span class="fl">1.0</span></span>
<span id="cb8-38"><a href="#cb8-38" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb8-39"><a href="#cb8-39" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-40"><a href="#cb8-40" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> result.images[<span class="dv">0</span>], lineart_image</span></code></pre></div></div>
</div>
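<p>The line-art detectors produce white lines on a black background, which is the polarity the line-art ControlNet expects. If you already have clean line art, such as a scanned or digital drawing with dark lines on white paper, you can skip the detector and simply invert it. The sketch below assumes an 8-bit image and uses mean brightness as a rough heuristic for spotting a light background; the function name <code>lineart_from_drawing</code> and the 127 cutoff are illustrative choices, not library defaults.</p>

```python
import numpy as np
from PIL import Image, ImageOps


def lineart_from_drawing(drawing: Image.Image) -> Image.Image:
    """Return white-on-black line art, inverting dark-on-light drawings."""
    gray = drawing.convert("L")
    if np.asarray(gray).mean() > 127:  # mostly light pixels: assume a paper background
        gray = ImageOps.invert(gray)
    return gray.convert("RGB")


# Synthetic drawing: a black diagonal stroke on white paper
paper = np.full((64, 64), 255, dtype=np.uint8)
for i in range(64):
    paper[i, max(i - 1, 0):i + 2] = 0
control = lineart_from_drawing(Image.fromarray(paper))
print(np.asarray(control)[0, 63], np.asarray(control)[10, 10])  # [0 0 0] [255 255 255]
```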
</section>
<section id="scribble-control" class="level3">
<h3 class="anchored" data-anchor-id="scribble-control" id="scribble-control">Scribble Control</h3>
<p>Scribble control lets a rough, hand-drawn sketch guide the composition, with the prompt filling in the detail.</p>
<div id="scribble-control" class="cell" data-execution_count="8">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb9"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> create_scribble_from_sketch(sketch_path):</span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a><span class="co">    Process a rough sketch (dark strokes on light paper) for scribble control</span></span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a>    sketch <span class="op">=</span> cv2.imread(sketch_path, cv2.IMREAD_GRAYSCALE)</span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a>    sketch <span class="op">=</span> cv2.resize(sketch, (<span class="dv">512</span>, <span class="dv">512</span>))</span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Inverse threshold: dark strokes become white lines on a black background</span></span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a>    _, binary <span class="op">=</span> cv2.threshold(sketch, <span class="dv">127</span>, <span class="dv">255</span>, cv2.THRESH_BINARY_INV)</span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Convert to 3-channel RGB</span></span>
<span id="cb9-11"><a href="#cb9-11" aria-hidden="true" tabindex="-1"></a>    scribble <span class="op">=</span> cv2.cvtColor(binary, cv2.COLOR_GRAY2RGB)</span>
<span id="cb9-12"><a href="#cb9-12" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-13"><a href="#cb9-13" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> Image.fromarray(scribble)</span>
<span id="cb9-14"><a href="#cb9-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-15"><a href="#cb9-15" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> generate_with_scribble(pipe, scribble_image, prompt):</span>
<span id="cb9-16"><a href="#cb9-16" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb9-17"><a href="#cb9-17" aria-hidden="true" tabindex="-1"></a><span class="co">    Generate from scribble input</span></span>
<span id="cb9-18"><a href="#cb9-18" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb9-19"><a href="#cb9-19" aria-hidden="true" tabindex="-1"></a>    result <span class="op">=</span> pipe(</span>
<span id="cb9-20"><a href="#cb9-20" aria-hidden="true" tabindex="-1"></a>        prompt<span class="op">=</span>prompt,</span>
<span id="cb9-21"><a href="#cb9-21" aria-hidden="true" tabindex="-1"></a>        image<span class="op">=</span>scribble_image,</span>
<span id="cb9-22"><a href="#cb9-22" aria-hidden="true" tabindex="-1"></a>        num_inference_steps<span class="op">=</span><span class="dv">20</span>,</span>
<span id="cb9-23"><a href="#cb9-23" aria-hidden="true" tabindex="-1"></a>        guidance_scale<span class="op">=</span><span class="fl">7.5</span>,</span>
<span id="cb9-24"><a href="#cb9-24" aria-hidden="true" tabindex="-1"></a>        controlnet_conditioning_scale<span class="op">=</span><span class="fl">1.0</span></span>
<span id="cb9-25"><a href="#cb9-25" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb9-26"><a href="#cb9-26" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-27"><a href="#cb9-27" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> result.images[<span class="dv">0</span>]</span></code></pre></div></div>
</div>
</section>
<section id="normal-map-control" class="level3">
<h3 class="anchored" data-anchor-id="normal-map-control" id="normal-map-control">Normal Map Control</h3>
<p>A normal map encodes the orientation of each visible surface, giving the model fine-grained shape cues that help it produce plausible shading and lighting.</p>
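<p>The code that follows derives normals from depth gradients. The underlying idea is to treat the depth map as a height field z(x, y): the unnormalised surface normal is (−∂z/∂x, −∂z/∂y, 1), which is scaled to unit length and then typically mapped from [−1, 1] to [0, 255] for storage as an image. Here is that step in isolation, a NumPy sketch checked on a synthetic ramp rather than real MiDaS output; the function names are hypothetical, and because tools disagree on the sign of the y axis, treat that orientation as an assumption.</p>

```python
import numpy as np


def depth_to_normals(depth: np.ndarray, strength: float = 1.0) -> np.ndarray:
    """Unit surface normals (H, W, 3) from a depth/height map via finite differences."""
    dz_dy, dz_dx = np.gradient(depth.astype(np.float64))  # axis 0 = y, axis 1 = x
    n = np.dstack([-strength * dz_dx, -strength * dz_dy, np.ones_like(dz_dx)])
    return n / np.linalg.norm(n, axis=2, keepdims=True)


def normals_to_rgb(normals: np.ndarray) -> np.ndarray:
    """Map each component from [-1, 1] to [0, 255]."""
    return np.round((normals + 1.0) * 127.5).astype(np.uint8)


# Ramp rising one unit per pixel along x: every normal is (-1, 0, 1) / sqrt(2)
ramp = np.tile(np.arange(64, dtype=np.float64), (64, 1))
normals = depth_to_normals(ramp)
rgb = normals_to_rgb(normals)
```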
<div id="normal-map" class="cell" data-execution_count="9">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb10"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> generate_normal_map(image_path):</span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a><span class="co">    Generate normal map from image</span></span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Load depth estimator</span></span>
<span id="cb10-6"><a href="#cb10-6" aria-hidden="true" tabindex="-1"></a>    depth_estimator <span class="op">=</span> MidasDetector.from_pretrained(<span class="st">'lllyasviel/Annotators'</span>)</span>
<span id="cb10-7"><a href="#cb10-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-8"><a href="#cb10-8" aria-hidden="true" tabindex="-1"></a>    image <span class="op">=</span> Image.<span class="bu">open</span>(image_path).resize((<span class="dv">512</span>, <span class="dv">512</span>))</span>
<span id="cb10-9"><a href="#cb10-9" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-10"><a href="#cb10-10" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Generate depth map</span></span>
<span id="cb10-11"><a href="#cb10-11" aria-hidden="true" tabindex="-1"></a>    depth_map <span class="op">=</span> depth_estimator(image)</span>
<span id="cb10-12"><a href="#cb10-12" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-13"><a href="#cb10-13" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Convert depth to normal map (simplified)</span></span>
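<span id="cb10-14"><a href="#cb10-14" aria-hidden="true" tabindex="-1"></a>    depth_array <span class="op">=</span> np.array(depth_map.convert(<span class="st">"L"</span>), dtype<span class="op">=</span>np.float64)  <span class="co"># collapse the 3-channel detector output to one channel</span></span>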
<span id="cb10-15"><a href="#cb10-15" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-16"><a href="#cb10-16" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Calculate gradients</span></span>
<span id="cb10-17"><a href="#cb10-17" aria-hidden="true" tabindex="-1"></a>    grad_x <span class="op">=</span> cv2.Sobel(depth_array, cv2.CV_64F, <span class="dv">1</span>, <span class="dv">0</span>, ksize<span class="op">=</span><span class="dv">3</span>)</span>
<span id="cb10-18"><a href="#cb10-18" aria-hidden="true" tabindex="-1"></a>    grad_y <span class="op">=</span> cv2.Sobel(depth_array, cv2.CV_64F, <span class="dv">0</span>, <span class="dv">1</span>, ksize<span class="op">=</span><span class="dv">3</span>)</span>
<span id="cb10-19"><a href="#cb10-19" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-20"><a href="#cb10-20" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Create normal vectors</span></span>
<span id="cb10-21"><a href="#cb10-21" aria-hidden="true" tabindex="-1"></a>    normal_x <span class="op">=</span> <span class="op">-</span>grad_x <span class="op">/</span> <span class="fl">255.0</span></span>
<span id="cb10-22"><a href="#cb10-22" aria-hidden="true" tabindex="-1"></a>    normal_y <span class="op">=</span> <span class="op">-</span>grad_y <span class="op">/</span> <span class="fl">255.0</span></span>
<span id="cb10-23"><a href="#cb10-23" aria-hidden="true" tabindex="-1"></a>    normal_z <span class="op">=</span> np.ones_like(normal_x)</span>
<span id="cb10-24"><a href="#cb10-24" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-25"><a href="#cb10-25" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Normalize</span></span>
<span id="cb10-26"><a href="#cb10-26" aria-hidden="true" tabindex="-1"></a>    length <span class="op">=</span> np.sqrt(normal_x<span class="op">**</span><span class="dv">2</span> <span class="op">+</span> normal_y<span class="op">**</span><span class="dv">2</span> <span class="op">+</span> normal_z<span class="op">**</span><span class="dv">2</span>)</span>
<span id="cb10-27"><a href="#cb10-27" aria-hidden="true" tabindex="-1"></a>    normal_x <span class="op">/=</span> length</span>
<span id="cb10-28"><a href="#cb10-28" aria-hidden="true" tabindex="-1"></a>    normal_y <span class="op">/=</span> length</span>
<span id="cb10-29"><a href="#cb10-29" aria-hidden="true" tabindex="-1"></a>    normal_z <span class="op">/=</span> length</span>
<span id="cb10-30"><a href="#cb10-30" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-31"><a href="#cb10-31" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Convert to 0-255 range</span></span>
<span id="cb10-32"><a href="#cb10-32" aria-hidden="true" tabindex="-1"></a>    normal_map <span class="op">=</span> np.stack([</span>
<span id="cb10-33"><a href="#cb10-33" aria-hidden="true" tabindex="-1"></a>        ((normal_x <span class="op">+</span> <span class="dv">1</span>) <span class="op">*</span> <span class="fl">127.5</span>).astype(np.uint8),</span>
<span id="cb10-34"><a href="#cb10-34" aria-hidden="true" tabindex="-1"></a>        ((normal_y <span class="op">+</span> <span class="dv">1</span>) <span class="op">*</span> <span class="fl">127.5</span>).astype(np.uint8),</span>
<span id="cb10-35"><a href="#cb10-35" aria-hidden="true" tabindex="-1"></a>        ((normal_z <span class="op">+</span> <span class="dv">1</span>) <span class="op">*</span> <span class="fl">127.5</span>).astype(np.uint8)</span>
<span id="cb10-36"><a href="#cb10-36" aria-hidden="true" tabindex="-1"></a>    ], axis<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb10-37"><a href="#cb10-37" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-38"><a href="#cb10-38" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> Image.fromarray(normal_map)</span></code></pre></div></div>
</div>
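<p>You can write the same gradient-based construction with NumPy alone, which makes it easy to check on synthetic depth maps. This is a minimal sketch that rests on a few assumptions. <code>depth_to_normal_rgb</code> is a hypothetical helper. <code>np.gradient</code> stands in for <code>cv2.Sobel</code>, and because Sobel also smooths, the gradient magnitudes differ by a constant factor. The input must already be a single-channel <code>(H, W)</code> array. A particular normal-map ControlNet checkpoint may expect a different RGB channel convention (for example, a flipped axis), so compare the output against that model's reference inputs before relying on it.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb-normal-map-sketch"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import numpy as np


def depth_to_normal_rgb(depth, scale=1.0 / 255.0):
    """Approximate per-pixel surface normals from a single-channel depth map (hypothetical helper).

    depth: (H, W) array; scale plays the role of the 1/255 factor used above.
    Returns an (H, W, 3) uint8 image with each normal component mapped from [-1, 1] to [0, 255].
    """
    d = np.asarray(depth, dtype=np.float64)
    if d.ndim != 2:
        raise ValueError("depth must be a 2-D (H, W) array; convert RGB depth images to grayscale first")

    # np.gradient returns the derivative along rows (y) first, then along columns (x)
    grad_y, grad_x = np.gradient(d)

    # Unnormalised normal vectors: tilt against the slope, constant z component
    normals = np.stack([-grad_x * scale, -grad_y * scale, np.ones_like(d)], axis=-1)

    # Normalise every normal vector to unit length
    normals /= np.linalg.norm(normals, axis=-1, keepdims=True)

    # Map each component from [-1, 1] to [0, 255]
    return np.round((normals + 1.0) * 127.5).astype(np.uint8)</code></pre></div></div>
<p>A flat depth map yields the "straight up" normal <code>(0, 0, 1)</code>, which encodes to <code>(128, 128, 255)</code>. This is why flat regions of a normal map look uniformly bluish. Slopes shift the red and green channels away from 128.</p>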
</section>
</section>
<section id="combining-multiple-controlnets" class="level2">
<h2 class="anchored" data-anchor-id="combining-multiple-controlnets" id="combining-multiple-controlnets">Combining Multiple ControlNets</h2>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Multi-ControlNet Benefits
</div>
</div>
<div class="callout-body-container callout-body">
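<p>Combining multiple ControlNets lets you apply several kinds of conditioning at once, such as pose + depth or edges + normal maps. Each ControlNet receives its own control image and its own <code>controlnet_conditioning_scale</code>, so you can decide how strongly each signal shapes the result.</p>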
</div>
</div>
<section id="multi-controlnet-setup" class="level3">
<h3 class="anchored" data-anchor-id="multi-controlnet-setup" id="multi-controlnet-setup">Multi-ControlNet Setup</h3>
<div id="multi-controlnet-setup" class="cell" data-execution_count="10">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb11"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> setup_multi_controlnet_pipeline(controlnet_types):</span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a><span class="co">    Setup pipeline with multiple ControlNets</span></span>
<span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a>    <span class="co"># No MultiControlNetModel import needed: the pipeline also accepts a plain list of ControlNets</span></span>
<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a>    controlnet_models <span class="op">=</span> {</span>
<span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a>        <span class="st">"canny"</span>: <span class="st">"lllyasviel/sd-controlnet-canny"</span>,</span>
<span id="cb11-9"><a href="#cb11-9" aria-hidden="true" tabindex="-1"></a>        <span class="st">"openpose"</span>: <span class="st">"lllyasviel/sd-controlnet-openpose"</span>,</span>
<span id="cb11-10"><a href="#cb11-10" aria-hidden="true" tabindex="-1"></a>        <span class="st">"depth"</span>: <span class="st">"lllyasviel/sd-controlnet-depth"</span></span>
<span id="cb11-11"><a href="#cb11-11" aria-hidden="true" tabindex="-1"></a>    }</span>
<span id="cb11-12"><a href="#cb11-12" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-13"><a href="#cb11-13" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Load multiple ControlNets</span></span>
<span id="cb11-14"><a href="#cb11-14" aria-hidden="true" tabindex="-1"></a>    controlnets <span class="op">=</span> [</span>
<span id="cb11-15"><a href="#cb11-15" aria-hidden="true" tabindex="-1"></a>        ControlNetModel.from_pretrained(controlnet_models[ctype], torch_dtype<span class="op">=</span>torch.float16)</span>
<span id="cb11-16"><a href="#cb11-16" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> ctype <span class="kw">in</span> controlnet_types</span>
<span id="cb11-17"><a href="#cb11-17" aria-hidden="true" tabindex="-1"></a>    ]</span>
<span id="cb11-18"><a href="#cb11-18" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-19"><a href="#cb11-19" aria-hidden="true" tabindex="-1"></a>    <span class="co"># diffusers wraps a list of ControlNets in a MultiControlNetModel internally;</span></span>
<span id="cb11-20"><a href="#cb11-20" aria-hidden="true" tabindex="-1"></a>    <span class="co"># control images must later be passed in the same order as controlnet_types</span></span>
<span id="cb11-21"><a href="#cb11-21" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-22"><a href="#cb11-22" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Create pipeline</span></span>
<span id="cb11-23"><a href="#cb11-23" aria-hidden="true" tabindex="-1"></a>    pipe <span class="op">=</span> StableDiffusionControlNetPipeline.from_pretrained(</span>
<span id="cb11-24"><a href="#cb11-24" aria-hidden="true" tabindex="-1"></a>        <span class="st">"runwayml/stable-diffusion-v1-5"</span>,</span>
<span id="cb11-25"><a href="#cb11-25" aria-hidden="true" tabindex="-1"></a>        controlnet<span class="op">=</span>controlnets,</span>
<span id="cb11-26"><a href="#cb11-26" aria-hidden="true" tabindex="-1"></a>        torch_dtype<span class="op">=</span>torch.float16</span>
<span id="cb11-27"><a href="#cb11-27" aria-hidden="true" tabindex="-1"></a>    ).to(<span class="st">"cuda"</span>)</span>
<span id="cb11-28"><a href="#cb11-28" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-29"><a href="#cb11-29" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> pipe</span>
<span id="cb11-30"><a href="#cb11-30" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-31"><a href="#cb11-31" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> generate_with_multiple_controls(pipe, image_path, prompt):</span>
<span id="cb11-32"><a href="#cb11-32" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb11-33"><a href="#cb11-33" aria-hidden="true" tabindex="-1"></a><span class="co">    Generate using multiple control inputs</span></span>
<span id="cb11-34"><a href="#cb11-34" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb11-35"><a href="#cb11-35" aria-hidden="true" tabindex="-1"></a>    original_image <span class="op">=</span> Image.<span class="bu">open</span>(image_path).resize((<span class="dv">512</span>, <span class="dv">512</span>))</span>
<span id="cb11-36"><a href="#cb11-36" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-37"><a href="#cb11-37" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Build the control images in the same order as the pipeline's ControlNets (canny, then openpose)</span></span>
<span id="cb11-38"><a href="#cb11-38" aria-hidden="true" tabindex="-1"></a>    canny_detector <span class="op">=</span> CannyDetector()</span>
<span id="cb11-39"><a href="#cb11-39" aria-hidden="true" tabindex="-1"></a>    openpose_detector <span class="op">=</span> OpenposeDetector.from_pretrained(<span class="st">'lllyasviel/Annotators'</span>)</span>
<span id="cb11-40"><a href="#cb11-40" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-41"><a href="#cb11-41" aria-hidden="true" tabindex="-1"></a>    canny_image <span class="op">=</span> canny_detector(original_image)</span>
<span id="cb11-42"><a href="#cb11-42" aria-hidden="true" tabindex="-1"></a>    pose_image <span class="op">=</span> openpose_detector(original_image)</span>
<span id="cb11-43"><a href="#cb11-43" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-44"><a href="#cb11-44" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Generate with multiple controls</span></span>
<span id="cb11-45"><a href="#cb11-45" aria-hidden="true" tabindex="-1"></a>    result <span class="op">=</span> pipe(</span>
<span id="cb11-46"><a href="#cb11-46" aria-hidden="true" tabindex="-1"></a>        prompt<span class="op">=</span>prompt,</span>
<span id="cb11-47"><a href="#cb11-47" aria-hidden="true" tabindex="-1"></a>        image<span class="op">=</span>[canny_image, pose_image],</span>
<span id="cb11-48"><a href="#cb11-48" aria-hidden="true" tabindex="-1"></a>        num_inference_steps<span class="op">=</span><span class="dv">20</span>,</span>
<span id="cb11-49"><a href="#cb11-49" aria-hidden="true" tabindex="-1"></a>        guidance_scale<span class="op">=</span><span class="fl">7.5</span>,</span>
<span id="cb11-50"><a href="#cb11-50" aria-hidden="true" tabindex="-1"></a>        controlnet_conditioning_scale<span class="op">=</span>[<span class="fl">1.0</span>, <span class="fl">0.8</span>]  <span class="co"># Different weights for each control</span></span>
<span id="cb11-51"><a href="#cb11-51" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb11-52"><a href="#cb11-52" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-53"><a href="#cb11-53" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> result.images[<span class="dv">0</span>]</span>
<span id="cb11-54"><a href="#cb11-54" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-55"><a href="#cb11-55" aria-hidden="true" tabindex="-1"></a><span class="co"># Example usage</span></span>
<span id="cb11-56"><a href="#cb11-56" aria-hidden="true" tabindex="-1"></a>pipe <span class="op">=</span> setup_multi_controlnet_pipeline([<span class="st">"canny"</span>, <span class="st">"openpose"</span>])</span>
<span id="cb11-57"><a href="#cb11-57" aria-hidden="true" tabindex="-1"></a>result <span class="op">=</span> generate_with_multiple_controls(pipe, <span class="st">"input.jpg"</span>, <span class="st">"a cyberpunk warrior"</span>)</span></code></pre></div></div>
</div>
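<p>As a rough model of what <code>controlnet_conditioning_scale=[1.0, 0.8]</code> does: diffusers runs each ControlNet separately on its own control image, multiplies that ControlNet's residual feature maps by its scale, and sums the results before they are added to the UNet. The NumPy sketch below mimics only that weighting and summation, using toy arrays. <code>combine_controlnet_residuals</code> is a hypothetical name and not the diffusers implementation. The sketch also makes the ordering contract explicit: the i-th scale (and, in the real pipeline, the i-th control image) must line up with the i-th ControlNet.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb-multi-residual-sketch"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import numpy as np


def combine_controlnet_residuals(residuals_per_controlnet, conditioning_scales):
    """Scale each ControlNet's residuals by its weight and sum them feature map by feature map.

    residuals_per_controlnet: one list of feature maps per ControlNet (matching shapes across ControlNets).
    conditioning_scales: one float per ControlNet, like controlnet_conditioning_scale=[1.0, 0.8].
    """
    if len(residuals_per_controlnet) != len(conditioning_scales):
        raise ValueError("need exactly one conditioning scale per ControlNet")

    combined = None
    for residuals, scale in zip(residuals_per_controlnet, conditioning_scales):
        # Weight this ControlNet's contribution
        scaled = [scale * np.asarray(r, dtype=np.float64) for r in residuals]
        # Accumulate element-wise with the contributions seen so far
        combined = scaled if combined is None else [c + s for c, s in zip(combined, scaled)]
    return combined</code></pre></div></div>
<p>Because the contributions are summed, two strong controls that disagree (for example, edges from one pose and a skeleton from another) can fight each other. Lowering one of the scales is usually the first thing to try.</p>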
</section>
</section>
<section id="fine-tuning-parameters" class="level2">
<h2 class="anchored" data-anchor-id="fine-tuning-parameters" id="fine-tuning-parameters">Fine-tuning Parameters</h2>
<section id="control-strength-and-guidance" class="level3">
<h3 class="anchored" data-anchor-id="control-strength-and-guidance" id="control-strength-and-guidance">Control Strength and Guidance</h3>
<div id="advanced-control" class="cell" data-execution_count="11">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb12"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><a href="#cb12-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> advanced_generation_control(pipe, control_image, prompt, <span class="op">**</span>kwargs):</span>
<span id="cb12-2"><a href="#cb12-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb12-3"><a href="#cb12-3" aria-hidden="true" tabindex="-1"></a><span class="co">    Advanced parameter control for fine-tuning generation</span></span>
<span id="cb12-4"><a href="#cb12-4" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb12-5"><a href="#cb12-5" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Default parameters</span></span>
<span id="cb12-6"><a href="#cb12-6" aria-hidden="true" tabindex="-1"></a>    params <span class="op">=</span> {</span>
<span id="cb12-7"><a href="#cb12-7" aria-hidden="true" tabindex="-1"></a>        <span class="st">'prompt'</span>: prompt,</span>
<span id="cb12-8"><a href="#cb12-8" aria-hidden="true" tabindex="-1"></a>        <span class="st">'image'</span>: control_image,</span>
<span id="cb12-9"><a href="#cb12-9" aria-hidden="true" tabindex="-1"></a>        <span class="st">'num_inference_steps'</span>: <span class="dv">20</span>,</span>
<span id="cb12-10"><a href="#cb12-10" aria-hidden="true" tabindex="-1"></a>        <span class="st">'guidance_scale'</span>: <span class="fl">7.5</span>,</span>
<span id="cb12-11"><a href="#cb12-11" aria-hidden="true" tabindex="-1"></a>        <span class="st">'controlnet_conditioning_scale'</span>: <span class="fl">1.0</span>,</span>
<span id="cb12-12"><a href="#cb12-12" aria-hidden="true" tabindex="-1"></a>        <span class="st">'control_guidance_start'</span>: <span class="fl">0.0</span>,</span>
<span id="cb12-13"><a href="#cb12-13" aria-hidden="true" tabindex="-1"></a>        <span class="st">'control_guidance_end'</span>: <span class="fl">1.0</span>,</span>
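<span id="cb12-14"><a href="#cb12-14" aria-hidden="true" tabindex="-1"></a>        <span class="st">'eta'</span>: <span class="fl">0.0</span>,  <span class="co"># only used by DDIMScheduler; ignored by other schedulers</span></span>
<span id="cb12-15"><a href="#cb12-15" aria-hidden="true" tabindex="-1"></a>        <span class="st">'generator'</span>: torch.Generator(device<span class="op">=</span><span class="st">"cpu"</span>).manual_seed(<span class="dv">42</span>)  <span class="co"># fixed seed without reseeding the global RNG</span></span>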
<span id="cb12-16"><a href="#cb12-16" aria-hidden="true" tabindex="-1"></a>    }</span>
<span id="cb12-17"><a href="#cb12-17" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-18"><a href="#cb12-18" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Update with custom parameters</span></span>
<span id="cb12-19"><a href="#cb12-19" aria-hidden="true" tabindex="-1"></a>    params.update(kwargs)</span>
<span id="cb12-20"><a href="#cb12-20" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-21"><a href="#cb12-21" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Generate image</span></span>
<span id="cb12-22"><a href="#cb12-22" aria-hidden="true" tabindex="-1"></a>    result <span class="op">=</span> pipe(<span class="op">**</span>params)</span>
<span id="cb12-23"><a href="#cb12-23" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-24"><a href="#cb12-24" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> result.images[<span class="dv">0</span>]</span>
<span id="cb12-25"><a href="#cb12-25" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-26"><a href="#cb12-26" aria-hidden="true" tabindex="-1"></a><span class="co"># Examples of parameter variations</span></span>
<span id="cb12-27"><a href="#cb12-27" aria-hidden="true" tabindex="-1"></a>variations <span class="op">=</span> [</span>
<span id="cb12-28"><a href="#cb12-28" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Strong control throughout</span></span>
<span id="cb12-29"><a href="#cb12-29" aria-hidden="true" tabindex="-1"></a>    {<span class="st">'controlnet_conditioning_scale'</span>: <span class="fl">1.5</span>},</span>
<span id="cb12-30"><a href="#cb12-30" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-31"><a href="#cb12-31" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Weak control for more creativity</span></span>
<span id="cb12-32"><a href="#cb12-32" aria-hidden="true" tabindex="-1"></a>    {<span class="st">'controlnet_conditioning_scale'</span>: <span class="fl">0.5</span>},</span>
<span id="cb12-33"><a href="#cb12-33" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-34"><a href="#cb12-34" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Control only in early steps</span></span>
<span id="cb12-35"><a href="#cb12-35" aria-hidden="true" tabindex="-1"></a>    {<span class="st">'control_guidance_end'</span>: <span class="fl">0.5</span>},</span>
<span id="cb12-36"><a href="#cb12-36" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-37"><a href="#cb12-37" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Control only in later steps</span></span>
<span id="cb12-38"><a href="#cb12-38" aria-hidden="true" tabindex="-1"></a>    {<span class="st">'control_guidance_start'</span>: <span class="fl">0.5</span>},</span>
<span id="cb12-39"><a href="#cb12-39" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-40"><a href="#cb12-40" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Higher guidance for more prompt adherence</span></span>
<span id="cb12-41"><a href="#cb12-41" aria-hidden="true" tabindex="-1"></a>    {<span class="st">'guidance_scale'</span>: <span class="fl">12.0</span>},</span>
<span id="cb12-42"><a href="#cb12-42" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-43"><a href="#cb12-43" aria-hidden="true" tabindex="-1"></a>    <span class="co"># More inference steps for quality</span></span>
<span id="cb12-44"><a href="#cb12-44" aria-hidden="true" tabindex="-1"></a>    {<span class="st">'num_inference_steps'</span>: <span class="dv">50</span>}</span>
<span id="cb12-45"><a href="#cb12-45" aria-hidden="true" tabindex="-1"></a>]</span></code></pre></div></div>
</div>
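<p>The NumPy sketch below spells out what two of these knobs do. <code>apply_cfg</code> is the standard classifier-free guidance update. At every denoising step it is applied to the two noise predictions (without and with the prompt), so larger <code>guidance_scale</code> values push harder toward the prompt. <code>controlnet_active_steps</code> reproduces, as far as we understand it, the rule the diffusers ControlNet pipeline uses to decide on which steps the ControlNet residuals are kept for a given <code>control_guidance_start</code> / <code>control_guidance_end</code>. Both helper names are hypothetical, and the exact rule may change between diffusers versions.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb-guidance-sketch"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import numpy as np


def apply_cfg(noise_uncond, noise_text, guidance_scale):
    """Classifier-free guidance: extrapolate from the unconditional prediction toward the text-conditioned one."""
    noise_uncond = np.asarray(noise_uncond, dtype=np.float64)
    noise_text = np.asarray(noise_text, dtype=np.float64)
    return noise_uncond + guidance_scale * (noise_text - noise_uncond)


def controlnet_active_steps(num_steps, start=0.0, end=1.0):
    """Indices of denoising steps on which ControlNet residuals are applied.

    Step i is skipped if it begins before `start` or finishes after `end`,
    with both bounds given as fractions of the full schedule.
    """
    return [
        i
        for i in range(num_steps)
        if not (i / num_steps &lt; start or (i + 1) / num_steps &gt; end)
    ]</code></pre></div></div>
<p>With 20 steps, <code>control_guidance_end=0.5</code> keeps ControlNet active on steps 0 to 9, the early steps where the global layout forms. <code>control_guidance_start=0.5</code> keeps it active on steps 10 to 19 only. With <code>guidance_scale=1.0</code>, <code>apply_cfg</code> simply returns the text-conditioned prediction.</p>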
</section>
<section id="adaptive-control-strength" class="level3">
<h3 class="anchored" data-anchor-id="adaptive-control-strength" id="adaptive-control-strength">Adaptive Control Strength</h3>
<div id="adaptive-control" class="cell" data-execution_count="12">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb13"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> adaptive_control_strength(pipe, control_image, prompt, complexity_factor<span class="op">=</span><span class="fl">1.0</span>):</span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a><span class="co">    Automatically adjust control strength based on image complexity</span></span>
<span id="cb13-4"><a href="#cb13-4" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb13-5"><a href="#cb13-5" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Analyze control image complexity</span></span>
<span id="cb13-6"><a href="#cb13-6" aria-hidden="true" tabindex="-1"></a>    control_array <span class="op">=</span> np.array(control_image.convert(<span class="st">'L'</span>))</span>
<span id="cb13-7"><a href="#cb13-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-8"><a href="#cb13-8" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Calculate edge density as complexity measure</span></span>
<span id="cb13-9"><a href="#cb13-9" aria-hidden="true" tabindex="-1"></a>    edges <span class="op">=</span> cv2.Canny(control_array, <span class="dv">50</span>, <span class="dv">150</span>)</span>
<span id="cb13-10"><a href="#cb13-10" aria-hidden="true" tabindex="-1"></a>    edge_density <span class="op">=</span> np.<span class="bu">sum</span>(edges <span class="op">&gt;</span> <span class="dv">0</span>) <span class="op">/</span> edges.size</span>
<span id="cb13-11"><a href="#cb13-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-12"><a href="#cb13-12" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Adjust control strength based on complexity</span></span>
<span id="cb13-13"><a href="#cb13-13" aria-hidden="true" tabindex="-1"></a>    base_strength <span class="op">=</span> <span class="fl">1.0</span></span>
<span id="cb13-14"><a href="#cb13-14" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> edge_density <span class="op">&gt;</span> <span class="fl">0.1</span>:  <span class="co"># High detail</span></span>
<span id="cb13-15"><a href="#cb13-15" aria-hidden="true" tabindex="-1"></a>        control_strength <span class="op">=</span> base_strength <span class="op">*</span> <span class="fl">0.8</span> <span class="op">*</span> complexity_factor</span>
<span id="cb13-16"><a href="#cb13-16" aria-hidden="true" tabindex="-1"></a>    <span class="cf">elif</span> edge_density <span class="op">&lt;</span> <span class="fl">0.05</span>:  <span class="co"># Low detail</span></span>
<span id="cb13-17"><a href="#cb13-17" aria-hidden="true" tabindex="-1"></a>        control_strength <span class="op">=</span> base_strength <span class="op">*</span> <span class="fl">1.2</span> <span class="op">*</span> complexity_factor</span>
<span id="cb13-18"><a href="#cb13-18" aria-hidden="true" tabindex="-1"></a>    <span class="cf">else</span>:  <span class="co"># Medium detail</span></span>
<span id="cb13-19"><a href="#cb13-19" aria-hidden="true" tabindex="-1"></a>        control_strength <span class="op">=</span> base_strength <span class="op">*</span> complexity_factor</span>
<span id="cb13-20"><a href="#cb13-20" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-21"><a href="#cb13-21" aria-hidden="true" tabindex="-1"></a>    result <span class="op">=</span> pipe(</span>
<span id="cb13-22"><a href="#cb13-22" aria-hidden="true" tabindex="-1"></a>        prompt<span class="op">=</span>prompt,</span>
<span id="cb13-23"><a href="#cb13-23" aria-hidden="true" tabindex="-1"></a>        image<span class="op">=</span>control_image,</span>
<span id="cb13-24"><a href="#cb13-24" aria-hidden="true" tabindex="-1"></a>        controlnet_conditioning_scale<span class="op">=</span>control_strength,</span>
<span id="cb13-25"><a href="#cb13-25" aria-hidden="true" tabindex="-1"></a>        num_inference_steps<span class="op">=</span><span class="dv">20</span>,</span>
<span id="cb13-26"><a href="#cb13-26" aria-hidden="true" tabindex="-1"></a>        guidance_scale<span class="op">=</span><span class="fl">7.5</span></span>
<span id="cb13-27"><a href="#cb13-27" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb13-28"><a href="#cb13-28" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-29"><a href="#cb13-29" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> result.images[<span class="dv">0</span>], control_strength</span></code></pre></div></div>
</div>
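<p>Separating the heuristic from the pipeline call makes it easy to unit-test and re-tune without a GPU. The sketch below factors out the same thresholds (0.05 and 0.1) and multipliers (1.2, 1.0 and 0.8) used in <code>adaptive_control_strength</code>. The helper names <code>edge_density</code> and <code>strength_from_density</code> are hypothetical. If the control image is already a Canny edge map, you can count its non-zero pixels directly rather than running <code>cv2.Canny</code> on it a second time.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb-adaptive-sketch"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import numpy as np


def edge_density(edge_map):
    """Fraction of non-zero pixels in a 2-D edge map, e.g. the output of cv2.Canny."""
    edges = np.asarray(edge_map)
    return np.count_nonzero(edges) / edges.size


def strength_from_density(density, complexity_factor=1.0, low=0.05, high=0.1):
    """Map edge density to a controlnet_conditioning_scale with the thresholds used above."""
    if density &gt; high:      # dense detail: relax control slightly
        multiplier = 0.8
    elif density &lt; low:     # sparse detail: strengthen control
        multiplier = 1.2
    else:                     # medium detail: keep the base strength
        multiplier = 1.0
    return multiplier * complexity_factor</code></pre></div></div>
<p>The thresholds are heuristics rather than tuned values. Edge density depends heavily on the detector and its settings, so re-check them for each control type.</p>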
</section>
</section>
<section id="production-optimization" class="level2">
<h2 class="anchored" data-anchor-id="production-optimization" id="production-optimization">Production Optimization</h2>
<section id="memory-management" class="level3">
<h3 class="anchored" data-anchor-id="memory-management" id="memory-management">Memory Management</h3>
<div id="optimized-generator" class="cell" data-execution_count="13">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb14"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb14-1"><a href="#cb14-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> OptimizedControlNetGenerator:</span>
<span id="cb14-2"><a href="#cb14-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb14-3"><a href="#cb14-3" aria-hidden="true" tabindex="-1"></a><span class="co">    Production-ready ControlNet generator with optimization</span></span>
<span id="cb14-4"><a href="#cb14-4" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb14-5"><a href="#cb14-5" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb14-6"><a href="#cb14-6" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, controlnet_type<span class="op">=</span><span class="st">"canny"</span>, enable_cpu_offload<span class="op">=</span><span class="va">True</span>):</span>
<span id="cb14-7"><a href="#cb14-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.pipe <span class="op">=</span> setup_controlnet_pipeline(controlnet_type)</span>
<span id="cb14-8"><a href="#cb14-8" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-9"><a href="#cb14-9" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> enable_cpu_offload:</span>
<span id="cb14-10"><a href="#cb14-10" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.pipe.enable_model_cpu_offload()</span>
<span id="cb14-11"><a href="#cb14-11" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-12"><a href="#cb14-12" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Memory-efficient attention (requires the optional xformers package; raises if it is missing)</span></span>
<span id="cb14-13"><a href="#cb14-13" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.pipe.enable_xformers_memory_efficient_attention()</span>
<span id="cb14-14"><a href="#cb14-14" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-15"><a href="#cb14-15" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Compile model for faster inference (PyTorch 2.0+)</span></span>
<span id="cb14-16"><a href="#cb14-16" aria-hidden="true" tabindex="-1"></a>        <span class="cf">try</span>:</span>
<span id="cb14-17"><a href="#cb14-17" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.pipe.unet <span class="op">=</span> torch.<span class="bu">compile</span>(<span class="va">self</span>.pipe.unet, mode<span class="op">=</span><span class="st">"reduce-overhead"</span>, fullgraph<span class="op">=</span><span class="va">True</span>)</span>
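<span id="cb14-18"><a href="#cb14-18" aria-hidden="true" tabindex="-1"></a>        <span class="cf">except</span> Exception:  <span class="co"># compile is lazy, so most failures only surface on the first call</span></span>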
<span id="cb14-19"><a href="#cb14-19" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="st">"Torch compile not available, skipping optimization"</span>)</span>
<span id="cb14-20"><a href="#cb14-20" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb14-21"><a href="#cb14-21" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> generate_batch(<span class="va">self</span>, control_images, prompts, batch_size<span class="op">=</span><span class="dv">4</span>):</span>
<span id="cb14-22"><a href="#cb14-22" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""</span></span>
<span id="cb14-23"><a href="#cb14-23" aria-hidden="true" tabindex="-1"></a><span class="co">        Generate multiple images in batches for efficiency</span></span>
<span id="cb14-24"><a href="#cb14-24" aria-hidden="true" tabindex="-1"></a><span class="co">        """</span></span>
<span id="cb14-25"><a href="#cb14-25" aria-hidden="true" tabindex="-1"></a>        results <span class="op">=</span> []</span>
<span id="cb14-26"><a href="#cb14-26" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-27"><a href="#cb14-27" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">0</span>, <span class="bu">len</span>(prompts), batch_size):</span>
<span id="cb14-28"><a href="#cb14-28" aria-hidden="true" tabindex="-1"></a>            batch_prompts <span class="op">=</span> prompts[i:i<span class="op">+</span>batch_size]</span>
<span id="cb14-29"><a href="#cb14-29" aria-hidden="true" tabindex="-1"></a>            batch_images <span class="op">=</span> control_images[i:i<span class="op">+</span>batch_size]</span>
<span id="cb14-30"><a href="#cb14-30" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb14-31"><a href="#cb14-31" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Clear cache before batch</span></span>
<span id="cb14-32"><a href="#cb14-32" aria-hidden="true" tabindex="-1"></a>            torch.cuda.empty_cache()</span>
<span id="cb14-33"><a href="#cb14-33" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb14-34"><a href="#cb14-34" aria-hidden="true" tabindex="-1"></a>            batch_results <span class="op">=</span> <span class="va">self</span>.pipe(</span>
<span id="cb14-35"><a href="#cb14-35" aria-hidden="true" tabindex="-1"></a>                prompt<span class="op">=</span>batch_prompts,</span>
<span id="cb14-36"><a href="#cb14-36" aria-hidden="true" tabindex="-1"></a>                image<span class="op">=</span>batch_images,</span>
<span id="cb14-37"><a href="#cb14-37" aria-hidden="true" tabindex="-1"></a>                num_inference_steps<span class="op">=</span><span class="dv">20</span>,</span>
<span id="cb14-38"><a href="#cb14-38" aria-hidden="true" tabindex="-1"></a>                guidance_scale<span class="op">=</span><span class="fl">7.5</span></span>
<span id="cb14-39"><a href="#cb14-39" aria-hidden="true" tabindex="-1"></a>            )</span>
<span id="cb14-40"><a href="#cb14-40" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb14-41"><a href="#cb14-41" aria-hidden="true" tabindex="-1"></a>            results.extend(batch_results.images)</span>
<span id="cb14-42"><a href="#cb14-42" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-43"><a href="#cb14-43" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> results</span>
<span id="cb14-44"><a href="#cb14-44" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb14-45"><a href="#cb14-45" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> generate_with_callback(<span class="va">self</span>, control_image, prompt, callback<span class="op">=</span><span class="va">None</span>):</span>
<span id="cb14-46"><a href="#cb14-46" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""</span></span>
<span id="cb14-47"><a href="#cb14-47" aria-hidden="true" tabindex="-1"></a><span class="co">        Generate with a per-step progress callback via callback_on_step_end</span></span>
<span id="cb14-48"><a href="#cb14-48" aria-hidden="true" tabindex="-1"></a><span class="co">        """</span></span>
<span id="cb14-49"><a href="#cb14-49" aria-hidden="true" tabindex="-1"></a>        <span class="kw">def</span> progress_callback(pipe, step, timestep, callback_kwargs):</span>
<span id="cb14-50"><a href="#cb14-50" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> callback:</span>
<span id="cb14-51"><a href="#cb14-51" aria-hidden="true" tabindex="-1"></a>                callback(step, timestep)</span>
<span id="cb14-52"><a href="#cb14-52" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> callback_kwargs  <span class="co"># callback_on_step_end must return this dict</span></span>
<span id="cb14-53"><a href="#cb14-53" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-54"><a href="#cb14-54" aria-hidden="true" tabindex="-1"></a>        result <span class="op">=</span> <span class="va">self</span>.pipe(</span>
<span id="cb14-55"><a href="#cb14-55" aria-hidden="true" tabindex="-1"></a>            prompt<span class="op">=</span>prompt,</span>
<span id="cb14-56"><a href="#cb14-56" aria-hidden="true" tabindex="-1"></a>            image<span class="op">=</span>control_image,</span>
<span id="cb14-57"><a href="#cb14-57" aria-hidden="true" tabindex="-1"></a>            callback_on_step_end<span class="op">=</span>progress_callback</span>
<span id="cb14-58"><a href="#cb14-58" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb14-59"><a href="#cb14-59" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-60"><a href="#cb14-60" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> result.images[<span class="dv">0</span>]</span></code></pre></div></div>
</div>
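<p>The batch loop above slices <code>prompts</code> and <code>control_images</code> independently. A length mismatch is therefore caught, if at all, only when the pipeline reaches an uneven final batch. Below is a minimal sketch of a stricter batching helper. <code>iter_paired_batches</code> is a hypothetical name, not part of diffusers. It is plain Python, so it works with PIL images or any other objects.</p>

```python
from typing import Iterator, List, Sequence, Tuple, TypeVar

PromptT = TypeVar("PromptT")  # usually str
ImageT = TypeVar("ImageT")    # e.g. PIL.Image.Image


def iter_paired_batches(
    prompts: Sequence[PromptT], images: Sequence[ImageT], batch_size: int
) -> Iterator[Tuple[List[PromptT], List[ImageT]]]:
    """Yield aligned (prompts, images) batches, failing fast on bad input."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    if len(prompts) != len(images):
        # Every prompt needs its own control image; refuse to guess a pairing
        raise ValueError(
            f"got {len(prompts)} prompts but {len(images)} control images"
        )
    for start in range(0, len(prompts), batch_size):
        end = start + batch_size
        yield list(prompts[start:end]), list(images[start:end])
```

<p>Inside the batch method above, the <code>for</code> loop could iterate over <code>iter_paired_batches(prompts, control_images, batch_size)</code> instead of slicing by hand.</p>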
</section>
<section id="caching-and-preprocessing" class="level3">
<h3 class="anchored" data-anchor-id="caching-and-preprocessing" id="caching-and-preprocessing">Caching and Preprocessing</h3>
<div id="caching-system" class="cell" data-execution_count="14">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb15"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb15-1"><a href="#cb15-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> hashlib</span>
<span id="cb15-2"><a href="#cb15-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> os</span>
<span id="cb15-3"><a href="#cb15-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-4"><a href="#cb15-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> controlnet_aux <span class="im">import</span> CannyDetector, HEDdetector, MLSDdetector, OpenposeDetector</span>
<span id="cb15-5"><a href="#cb15-5" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> PIL <span class="im">import</span> Image</span>
<span id="cb15-6"><a href="#cb15-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-7"><a href="#cb15-7" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> ControlNetCache:</span>
<span id="cb15-8"><a href="#cb15-8" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb15-9"><a href="#cb15-9" aria-hidden="true" tabindex="-1"></a><span class="co">    Cache system for preprocessed control images</span></span>
<span id="cb15-10"><a href="#cb15-10" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb15-11"><a href="#cb15-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-12"><a href="#cb15-12" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, cache_dir<span class="op">=</span><span class="st">"./controlnet_cache"</span>):</span>
<span id="cb15-13"><a href="#cb15-13" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.cache_dir <span class="op">=</span> cache_dir</span>
<span id="cb15-14"><a href="#cb15-14" aria-hidden="true" tabindex="-1"></a>        os.makedirs(cache_dir, exist_ok<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb15-15"><a href="#cb15-15" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.detectors <span class="op">=</span> {}</span>
<span id="cb15-16"><a href="#cb15-16" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-17"><a href="#cb15-17" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> get_detector(<span class="va">self</span>, detector_type):</span>
<span id="cb15-18"><a href="#cb15-18" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""</span></span>
<span id="cb15-19"><a href="#cb15-19" aria-hidden="true" tabindex="-1"></a><span class="co">        Lazy load and cache detectors</span></span>
<span id="cb15-20"><a href="#cb15-20" aria-hidden="true" tabindex="-1"></a><span class="co">        """</span></span>
<span id="cb15-21"><a href="#cb15-21" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> detector_type <span class="kw">not</span> <span class="kw">in</span> <span class="va">self</span>.detectors:</span>
<span id="cb15-22"><a href="#cb15-22" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Map to factories so only the requested detector is instantiated</span></span>
<span id="cb15-23"><a href="#cb15-23" aria-hidden="true" tabindex="-1"></a>            detector_factories <span class="op">=</span> {</span>
<span id="cb15-24"><a href="#cb15-24" aria-hidden="true" tabindex="-1"></a>                <span class="st">'canny'</span>: <span class="kw">lambda</span>: CannyDetector(),</span>
<span id="cb15-25"><a href="#cb15-25" aria-hidden="true" tabindex="-1"></a>                <span class="st">'openpose'</span>: <span class="kw">lambda</span>: OpenposeDetector.from_pretrained(<span class="st">'lllyasviel/Annotators'</span>),</span>
<span id="cb15-26"><a href="#cb15-26" aria-hidden="true" tabindex="-1"></a>                <span class="st">'hed'</span>: <span class="kw">lambda</span>: HEDdetector.from_pretrained(<span class="st">'lllyasviel/Annotators'</span>),</span>
<span id="cb15-27"><a href="#cb15-27" aria-hidden="true" tabindex="-1"></a>                <span class="st">'mlsd'</span>: <span class="kw">lambda</span>: MLSDdetector.from_pretrained(<span class="st">'lllyasviel/Annotators'</span>)</span>
<span id="cb15-28"><a href="#cb15-28" aria-hidden="true" tabindex="-1"></a>            }</span>
<span id="cb15-29"><a href="#cb15-29" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.detectors[detector_type] <span class="op">=</span> detector_factories[detector_type]()</span>
<span id="cb15-30"><a href="#cb15-30" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-31"><a href="#cb15-31" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.detectors[detector_type]</span>
<span id="cb15-32"><a href="#cb15-32" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-33"><a href="#cb15-33" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> get_control_image(<span class="va">self</span>, image_path, control_type, force_refresh<span class="op">=</span><span class="va">False</span>):</span>
<span id="cb15-34"><a href="#cb15-34" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""</span></span>
<span id="cb15-35"><a href="#cb15-35" aria-hidden="true" tabindex="-1"></a><span class="co">        Get control image with caching</span></span>
<span id="cb15-36"><a href="#cb15-36" aria-hidden="true" tabindex="-1"></a><span class="co">        """</span></span>
<span id="cb15-37"><a href="#cb15-37" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Create cache key from the file contents and control type</span></span>
<span id="cb15-38"><a href="#cb15-38" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> <span class="bu">open</span>(image_path, <span class="st">'rb'</span>) <span class="im">as</span> f:</span>
<span id="cb15-39"><a href="#cb15-39" aria-hidden="true" tabindex="-1"></a>            image_hash <span class="op">=</span> hashlib.md5(f.read()).hexdigest()</span>
<span id="cb15-40"><a href="#cb15-40" aria-hidden="true" tabindex="-1"></a>        cache_path <span class="op">=</span> os.path.join(<span class="va">self</span>.cache_dir, <span class="ss">f"</span><span class="sc">{</span>image_hash<span class="sc">}</span><span class="ss">_</span><span class="sc">{</span>control_type<span class="sc">}</span><span class="ss">.png"</span>)</span>
<span id="cb15-41"><a href="#cb15-41" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-42"><a href="#cb15-42" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Check cache</span></span>
<span id="cb15-43"><a href="#cb15-43" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> os.path.exists(cache_path) <span class="kw">and</span> <span class="kw">not</span> force_refresh:</span>
<span id="cb15-44"><a href="#cb15-44" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> Image.<span class="bu">open</span>(cache_path).convert(<span class="st">'RGB'</span>)</span>
<span id="cb15-45"><a href="#cb15-45" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-46"><a href="#cb15-46" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Generate control image (convert to RGB before running the detector)</span></span>
<span id="cb15-47"><a href="#cb15-47" aria-hidden="true" tabindex="-1"></a>        original_image <span class="op">=</span> Image.<span class="bu">open</span>(image_path).convert(<span class="st">'RGB'</span>).resize((<span class="dv">512</span>, <span class="dv">512</span>))</span>
<span id="cb15-48"><a href="#cb15-48" aria-hidden="true" tabindex="-1"></a>        detector <span class="op">=</span> <span class="va">self</span>.get_detector(control_type)</span>
<span id="cb15-49"><a href="#cb15-49" aria-hidden="true" tabindex="-1"></a>        control_image <span class="op">=</span> detector(original_image)</span>
<span id="cb15-50"><a href="#cb15-50" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-51"><a href="#cb15-51" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Save to cache</span></span>
<span id="cb15-52"><a href="#cb15-52" aria-hidden="true" tabindex="-1"></a>        control_image.save(cache_path)</span>
<span id="cb15-53"><a href="#cb15-53" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-54"><a href="#cb15-54" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> control_image</span></code></pre></div></div>
</div>
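<p><code>get_control_image</code> builds its cache key from only the file contents and the control type. If the hard-coded 512x512 resize or the detector defaults change later, old entries are still returned. Below is a sketch of a more explicit key, assuming the target size is the only extra setting worth tracking. <code>control_cache_key</code> is a hypothetical helper. It substitutes SHA-256 for MD5 only as a common default; MD5 is adequate for a local cache.</p>

```python
import hashlib


def control_cache_key(image_path, control_type, size=(512, 512), chunk_size=1024 * 1024):
    """Build a cache key from the file contents plus every setting that affects the output."""
    digest = hashlib.sha256()
    with open(image_path, "rb") as f:
        # Stream the file in chunks instead of reading it into memory at once
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    width, height = size
    # Encoding control type and size means a change to either misses the cache
    return f"{digest.hexdigest()}_{control_type}_{width}x{height}"
```

<p>The method could then build <code>cache_path</code> as <code>os.path.join(self.cache_dir, control_cache_key(image_path, control_type) + ".png")</code>.</p>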
</section>
</section>
<section id="troubleshooting" class="level2">
<h2 class="anchored" data-anchor-id="troubleshooting" id="troubleshooting">Troubleshooting</h2>
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Warning</span>Common Issues
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
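<li><strong>GPU Memory</strong>: ControlNet pipelines need substantial GPU memory (8GB+ recommended); on smaller cards, enable CPU offloading or attention slicing</li>
<li><strong>Image Format</strong>: Ensure control images are in RGB mode and sized to the base model's resolution (512×512 for Stable Diffusion 1.5), with width and height divisible by 8</li>
<li><strong>Model Compatibility</strong>: Pair each ControlNet with the Stable Diffusion version it was trained for; for example, ControlNets trained for SD 1.5 do not work with SDXL base models</li>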
</ul>
</div>
</div>
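<p>Stable Diffusion's VAE works on latents at one eighth of the pixel resolution, so image width and height should be multiples of 8. Depending on the pipeline and library version, other sizes are either rejected or silently adjusted. Sometimes a control image cannot simply be resized to 512x512, for example when its aspect ratio must be preserved. In that case a small helper can snap its dimensions down to valid values. This is a sketch; <code>snap_to_multiple_of_8</code> is a hypothetical name.</p>

```python
def snap_to_multiple_of_8(width: int, height: int) -> tuple:
    """Round each side down to the nearest multiple of 8 (the VAE's downsampling factor)."""
    if width < 8 or height < 8:
        raise ValueError(f"image too small: {width}x{height}")
    # Integer division drops the remainder, so the result never exceeds the input
    return (width // 8) * 8, (height // 8) * 8
```

<p>With PIL, <code>control_image.convert("RGB").resize(snap_to_multiple_of_8(*control_image.size))</code> would give a valid RGB control image, since <code>Image.size</code> is <code>(width, height)</code>.</p>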
<section id="common-issues-and-solutions" class="level3">
<h3 class="anchored" data-anchor-id="common-issues-and-solutions" id="common-issues-and-solutions">Common Issues and Solutions</h3>
<div id="diagnostics" class="cell" data-execution_count="15">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb16"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb16-1"><a href="#cb16-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> diagnose_controlnet_issues(pipe, control_image, prompt):</span>
<span id="cb16-2"><a href="#cb16-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb16-3"><a href="#cb16-3" aria-hidden="true" tabindex="-1"></a><span class="co">    Diagnostic function for common ControlNet issues</span></span>
<span id="cb16-4"><a href="#cb16-4" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb16-5"><a href="#cb16-5" aria-hidden="true" tabindex="-1"></a>    issues <span class="op">=</span> []</span>
<span id="cb16-6"><a href="#cb16-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-7"><a href="#cb16-7" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Check control image format</span></span>
<span id="cb16-8"><a href="#cb16-8" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> control_image.mode <span class="op">!=</span> <span class="st">'RGB'</span>:</span>
<span id="cb16-9"><a href="#cb16-9" aria-hidden="true" tabindex="-1"></a>        issues.append(<span class="st">"Control image should be RGB format"</span>)</span>
<span id="cb16-10"><a href="#cb16-10" aria-hidden="true" tabindex="-1"></a>        control_image <span class="op">=</span> control_image.convert(<span class="st">'RGB'</span>)</span>
<span id="cb16-11"><a href="#cb16-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-12"><a href="#cb16-12" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Check image size</span></span>
<span id="cb16-13"><a href="#cb16-13" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> control_image.size <span class="op">!=</span> (<span class="dv">512</span>, <span class="dv">512</span>):</span>
<span id="cb16-14"><a href="#cb16-14" aria-hidden="true" tabindex="-1"></a>        issues.append(<span class="ss">f"Control image size </span><span class="sc">{</span>control_image<span class="sc">.</span>size<span class="sc">}</span><span class="ss"> != (512, 512)"</span>)</span>
<span id="cb16-15"><a href="#cb16-15" aria-hidden="true" tabindex="-1"></a>        control_image <span class="op">=</span> control_image.resize((<span class="dv">512</span>, <span class="dv">512</span>))</span>
<span id="cb16-16"><a href="#cb16-16" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-17"><a href="#cb16-17" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Check GPU memory</span></span>
<span id="cb16-18"><a href="#cb16-18" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> torch.cuda.is_available():</span>
<span id="cb16-19"><a href="#cb16-19" aria-hidden="true" tabindex="-1"></a>        memory_allocated <span class="op">=</span> torch.cuda.memory_allocated() <span class="op">/</span> <span class="fl">1e9</span></span>
<span id="cb16-20"><a href="#cb16-20" aria-hidden="true" tabindex="-1"></a>        memory_reserved <span class="op">=</span> torch.cuda.memory_reserved() <span class="op">/</span> <span class="fl">1e9</span></span>
<span id="cb16-21"><a href="#cb16-21" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb16-22"><a href="#cb16-22" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> memory_reserved <span class="op">&gt;</span> <span class="dv">10</span>:  <span class="co"># More than 10GB</span></span>
<span id="cb16-23"><a href="#cb16-23" aria-hidden="true" tabindex="-1"></a>            issues.append(<span class="ss">f"High GPU memory usage: </span><span class="sc">{</span>memory_reserved<span class="sc">:.1f}</span><span class="ss">GB"</span>)</span>
<span id="cb16-24"><a href="#cb16-24" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-25"><a href="#cb16-25" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Check prompt length in CLIP tokens, not words</span></span>
<span id="cb16-26"><a href="#cb16-26" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> <span class="bu">len</span>(pipe.tokenizer(prompt).input_ids) <span class="op">&gt;</span> pipe.tokenizer.model_max_length:</span>
<span id="cb16-27"><a href="#cb16-27" aria-hidden="true" tabindex="-1"></a>        issues.append(<span class="st">"Prompt exceeds the CLIP token limit; extra tokens will be truncated"</span>)</span>
<span id="cb16-28"><a href="#cb16-28" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-29"><a href="#cb16-29" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> issues:</span>
<span id="cb16-30"><a href="#cb16-30" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="st">"Detected issues:"</span>)</span>
<span id="cb16-31"><a href="#cb16-31" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> issue <span class="kw">in</span> issues:</span>
<span id="cb16-32"><a href="#cb16-32" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f"- </span><span class="sc">{</span>issue<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb16-33"><a href="#cb16-33" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-34"><a href="#cb16-34" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> control_image</span>
<span id="cb16-35"><a href="#cb16-35" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-36"><a href="#cb16-36" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> memory_cleanup():</span>
<span id="cb16-37"><a href="#cb16-37" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb16-38"><a href="#cb16-38" aria-hidden="true" tabindex="-1"></a><span class="co">    Clean up GPU memory</span></span>
<span id="cb16-39"><a href="#cb16-39" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb16-40"><a href="#cb16-40" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> torch.cuda.is_available():</span>
<span id="cb16-41"><a href="#cb16-41" aria-hidden="true" tabindex="-1"></a>        torch.cuda.empty_cache()</span>
<span id="cb16-42"><a href="#cb16-42" aria-hidden="true" tabindex="-1"></a>        torch.cuda.ipc_collect()</span>
<span id="cb16-43"><a href="#cb16-43" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-44"><a href="#cb16-44" aria-hidden="true" tabindex="-1"></a><span class="co"># Error handling wrapper</span></span>
<span id="cb16-45"><a href="#cb16-45" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> safe_generate(pipe, control_image, prompt, max_retries<span class="op">=</span><span class="dv">3</span>):</span>
<span id="cb16-46"><a href="#cb16-46" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb16-47"><a href="#cb16-47" aria-hidden="true" tabindex="-1"></a><span class="co">    Generate with error handling and retries</span></span>
<span id="cb16-48"><a href="#cb16-48" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb16-49"><a href="#cb16-49" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> attempt <span class="kw">in</span> <span class="bu">range</span>(max_retries):</span>
<span id="cb16-50"><a href="#cb16-50" aria-hidden="true" tabindex="-1"></a>        <span class="cf">try</span>:</span>
<span id="cb16-51"><a href="#cb16-51" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Diagnose issues</span></span>
<span id="cb16-52"><a href="#cb16-52" aria-hidden="true" tabindex="-1"></a>            control_image <span class="op">=</span> diagnose_controlnet_issues(pipe, control_image, prompt)</span>
<span id="cb16-53"><a href="#cb16-53" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb16-54"><a href="#cb16-54" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Generate</span></span>
<span id="cb16-55"><a href="#cb16-55" aria-hidden="true" tabindex="-1"></a>            result <span class="op">=</span> pipe(</span>
<span id="cb16-56"><a href="#cb16-56" aria-hidden="true" tabindex="-1"></a>                prompt<span class="op">=</span>prompt,</span>
<span id="cb16-57"><a href="#cb16-57" aria-hidden="true" tabindex="-1"></a>                image<span class="op">=</span>control_image,</span>
<span id="cb16-58"><a href="#cb16-58" aria-hidden="true" tabindex="-1"></a>                num_inference_steps<span class="op">=</span><span class="dv">20</span>,</span>
<span id="cb16-59"><a href="#cb16-59" aria-hidden="true" tabindex="-1"></a>                guidance_scale<span class="op">=</span><span class="fl">7.5</span></span>
<span id="cb16-60"><a href="#cb16-60" aria-hidden="true" tabindex="-1"></a>            )</span>
<span id="cb16-61"><a href="#cb16-61" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb16-62"><a href="#cb16-62" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> result.images[<span class="dv">0</span>]</span>
<span id="cb16-63"><a href="#cb16-63" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb16-64"><a href="#cb16-64" aria-hidden="true" tabindex="-1"></a>        <span class="cf">except</span> <span class="pp">RuntimeError</span> <span class="im">as</span> e:</span>
<span id="cb16-65"><a href="#cb16-65" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="st">"out of memory"</span> <span class="kw">in</span> <span class="bu">str</span>(e).lower():</span>
<span id="cb16-66"><a href="#cb16-66" aria-hidden="true" tabindex="-1"></a>                <span class="bu">print</span>(<span class="ss">f"GPU OOM on attempt </span><span class="sc">{</span>attempt <span class="op">+</span> <span class="dv">1</span><span class="sc">}</span><span class="ss">, cleaning memory..."</span>)</span>
<span id="cb16-67"><a href="#cb16-67" aria-hidden="true" tabindex="-1"></a>                memory_cleanup()</span>
<span id="cb16-68"><a href="#cb16-68" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb16-69"><a href="#cb16-69" aria-hidden="true" tabindex="-1"></a>                <span class="cf">if</span> attempt <span class="op">==</span> max_retries <span class="op">-</span> <span class="dv">1</span>:</span>
<span id="cb16-70"><a href="#cb16-70" aria-hidden="true" tabindex="-1"></a>                    <span class="cf">raise</span> e</span>
<span id="cb16-71"><a href="#cb16-71" aria-hidden="true" tabindex="-1"></a>            <span class="cf">else</span>:</span>
<span id="cb16-72"><a href="#cb16-72" aria-hidden="true" tabindex="-1"></a>                <span class="cf">raise</span> e</span>
<span id="cb16-73"><a href="#cb16-73" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb16-74"><a href="#cb16-74" aria-hidden="true" tabindex="-1"></a>        <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:</span>
<span id="cb16-75"><a href="#cb16-75" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f"Unexpected error on attempt </span><span class="sc">{</span>attempt <span class="op">+</span> <span class="dv">1</span><span class="sc">}</span><span class="ss">: </span><span class="sc">{</span>e<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb16-76"><a href="#cb16-76" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> attempt <span class="op">==</span> max_retries <span class="op">-</span> <span class="dv">1</span>:</span>
<span id="cb16-77"><a href="#cb16-77" aria-hidden="true" tabindex="-1"></a>                <span class="cf">raise</span> e</span>
<span id="cb16-78"><a href="#cb16-78" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-79"><a href="#cb16-79" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> <span class="va">None</span></span></code></pre></div></div>
</div>
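<p><code>safe_generate</code> retries on any exception. That includes errors such as a wrong argument, which will fail identically on every attempt. Below is a sketch of a narrower retry helper, assuming only out-of-memory errors are worth retrying. <code>retry_on</code>, <code>should_retry</code>, and <code>cleanup</code> are hypothetical names. The optional exponential backoff is an addition that is not present in the original.</p>

```python
import time
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


def retry_on(
    fn: Callable[[], T],
    should_retry: Callable[[Exception], bool],
    cleanup: Optional[Callable[[], None]] = None,
    max_retries: int = 3,
    backoff_s: float = 0.0,
) -> T:
    """Call fn, retrying only errors accepted by should_retry; re-raise all others."""
    if max_retries < 1:
        raise ValueError("max_retries must be at least 1")
    for attempt in range(max_retries):
        try:
            return fn()
        except Exception as exc:
            # Non-retryable errors and the final failed attempt propagate unchanged
            if not should_retry(exc) or attempt == max_retries - 1:
                raise
            if cleanup is not None:
                cleanup()  # e.g. free cached GPU memory before the next try
            time.sleep(backoff_s * 2 ** attempt)  # exponential backoff; 0 disables it
    raise AssertionError("unreachable: the loop always returns or raises")
```

<p>In this guide's setting, <code>fn</code> would wrap the <code>pipe(...)</code> call. <code>should_retry</code> would check the error message for <code>"out of memory"</code>, as <code>safe_generate</code> does, and <code>cleanup</code> would be <code>memory_cleanup</code>.</p>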
</section>
<section id="performance-benchmarking" class="level3">
<h3 class="anchored" data-anchor-id="performance-benchmarking" id="performance-benchmarking">Performance Benchmarking</h3>
<div id="benchmarking" class="cell" data-execution_count="16">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb17"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb17-1"><a href="#cb17-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb17-2"><a href="#cb17-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> contextlib <span class="im">import</span> contextmanager</span>
<span id="cb17-3"><a href="#cb17-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-4"><a href="#cb17-4" aria-hidden="true" tabindex="-1"></a><span class="at">@contextmanager</span></span>
<span id="cb17-5"><a href="#cb17-5" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> timer():</span>
<span id="cb17-6"><a href="#cb17-6" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb17-7"><a href="#cb17-7" aria-hidden="true" tabindex="-1"></a><span class="co">    Simple timing context manager</span></span>
<span id="cb17-8"><a href="#cb17-8" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb17-9"><a href="#cb17-9" aria-hidden="true" tabindex="-1"></a>    start <span class="op">=</span> time.perf_counter()</span>
<span id="cb17-10"><a href="#cb17-10" aria-hidden="true" tabindex="-1"></a>    <span class="cf">yield</span></span>
<span id="cb17-11"><a href="#cb17-11" aria-hidden="true" tabindex="-1"></a>    end <span class="op">=</span> time.perf_counter()</span>
<span id="cb17-12"><a href="#cb17-12" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Execution time: </span><span class="sc">{</span>end <span class="op">-</span> start<span class="sc">:.2f}</span><span class="ss"> seconds"</span>)</span>
<span id="cb17-13"><a href="#cb17-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-14"><a href="#cb17-14" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> benchmark_controlnet(pipe, control_image, prompt, runs<span class="op">=</span><span class="dv">5</span>):</span>
<span id="cb17-15"><a href="#cb17-15" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb17-16"><a href="#cb17-16" aria-hidden="true" tabindex="-1"></a><span class="co">    Benchmark ControlNet performance</span></span>
<span id="cb17-17"><a href="#cb17-17" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb17-18"><a href="#cb17-18" aria-hidden="true" tabindex="-1"></a>    times <span class="op">=</span> []</span>
<span id="cb17-19"><a href="#cb17-19" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb17-20"><a href="#cb17-20" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Warmup</span></span>
<span id="cb17-21"><a href="#cb17-21" aria-hidden="true" tabindex="-1"></a>    _ <span class="op">=</span> pipe(prompt<span class="op">=</span>prompt, image<span class="op">=</span>control_image, num_inference_steps<span class="op">=</span><span class="dv">5</span>)</span>
<span id="cb17-22"><a href="#cb17-22" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb17-23"><a href="#cb17-23" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Benchmark runs (perf_counter is monotonic, unlike time.time)</span></span>
<span id="cb17-24"><a href="#cb17-24" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(runs):</span>
<span id="cb17-25"><a href="#cb17-25" aria-hidden="true" tabindex="-1"></a>        start_time <span class="op">=</span> time.perf_counter()</span>
<span id="cb17-26"><a href="#cb17-26" aria-hidden="true" tabindex="-1"></a>        result <span class="op">=</span> pipe(</span>
<span id="cb17-27"><a href="#cb17-27" aria-hidden="true" tabindex="-1"></a>            prompt<span class="op">=</span>prompt,</span>
<span id="cb17-28"><a href="#cb17-28" aria-hidden="true" tabindex="-1"></a>            image<span class="op">=</span>control_image,</span>
<span id="cb17-29"><a href="#cb17-29" aria-hidden="true" tabindex="-1"></a>            num_inference_steps<span class="op">=</span><span class="dv">20</span>,</span>
<span id="cb17-30"><a href="#cb17-30" aria-hidden="true" tabindex="-1"></a>            guidance_scale<span class="op">=</span><span class="fl">7.5</span></span>
<span id="cb17-31"><a href="#cb17-31" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb17-32"><a href="#cb17-32" aria-hidden="true" tabindex="-1"></a>        end_time <span class="op">=</span> time.perf_counter()</span>
<span id="cb17-33"><a href="#cb17-33" aria-hidden="true" tabindex="-1"></a>        times.append(end_time <span class="op">-</span> start_time)</span>
<span id="cb17-34"><a href="#cb17-34" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb17-35"><a href="#cb17-35" aria-hidden="true" tabindex="-1"></a>    avg_time <span class="op">=</span> <span class="bu">sum</span>(times) <span class="op">/</span> <span class="bu">len</span>(times)</span>
<span id="cb17-36"><a href="#cb17-36" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Average generation time: </span><span class="sc">{</span>avg_time<span class="sc">:.2f}</span><span class="ss"> seconds"</span>)</span>
<span id="cb17-37"><a href="#cb17-37" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Images per minute: </span><span class="sc">{</span><span class="dv">60</span> <span class="op">/</span> avg_time<span class="sc">:.1f}</span><span class="ss">"</span>)</span>
<span id="cb17-38"><a href="#cb17-38" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb17-39"><a href="#cb17-39" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> result.images[<span class="dv">0</span>]</span></code></pre></div></div>
</div>
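<p><code>benchmark_controlnet</code> reports only the mean, which a single slow run (for example, one that triggers a fresh memory allocation) can skew. Below is a sketch of a fuller summary using Python's <code>statistics</code> module. <code>summarize_timings</code> is a hypothetical helper that takes the <code>times</code> list the benchmark collects, assuming one image per run.</p>

```python
import statistics
from typing import Dict, Sequence


def summarize_timings(times: Sequence[float]) -> Dict[str, float]:
    """Summarize per-image latencies (seconds) from repeated benchmark runs."""
    if not times:
        raise ValueError("need at least one timing")
    mean = statistics.mean(times)
    return {
        "mean_s": mean,
        # The median is less sensitive to one-off stalls than the mean
        "median_s": statistics.median(times),
        # stdev needs two or more samples; report 0.0 for a single run
        "stdev_s": statistics.stdev(times) if len(times) > 1 else 0.0,
        "images_per_min": 60.0 / mean,
    }
```

<p>A large gap between <code>mean_s</code> and <code>median_s</code> usually means some runs were disturbed and the benchmark should be repeated with more runs.</p>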
</section>
</section>
<section id="best-practices-summary" class="level2">
<h2 class="anchored" data-anchor-id="best-practices-summary" id="best-practices-summary">Best Practices Summary</h2>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center collapsed" data-bs-toggle="collapse" data-bs-target=".callout-4-contents" aria-controls="callout-4" aria-expanded="false" aria-label="Toggle callout">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Key Recommendations
</div>
<div class="callout-btn-toggle d-inline-block border-0 py-1 ps-1 pe-0 float-end"><i class="callout-toggle"></i></div>
</div>
<div id="callout-4" class="callout-4-contents callout-collapse collapse">
<div class="callout-body-container callout-body">
<ol type="1">
<li><strong>Memory Management</strong>: Use CPU offloading and memory-efficient attention for large models</li>
<li><strong>Preprocessing</strong>: Cache control images when generating multiple variations from the same source</li>
<li><strong>Parameter Tuning</strong>: Adjust <code>controlnet_conditioning_scale</code> to set how strictly the output follows the control image</li>
<li><strong>Quality vs Speed</strong>: Balance <code>num_inference_steps</code> against generation-time requirements</li>
<li><strong>Multi-Control</strong>: When combining multiple ControlNets, pass one conditioning scale per ControlNet so each signal can be weighted independently</li>
<li><strong>Error Handling</strong>: Implement robust error handling, such as retrying out-of-memory failures, for production systems</li>
<li><strong>Optimization</strong>: Use <code>torch.compile()</code> and xFormers for performance improvements; on PyTorch 2.x, diffusers already uses scaled dot-product attention by default, so gains from xFormers may be small</li>
</ol>
</div>
</div>
</div>
<section id="parameter-reference-table" class="level3">
<h3 class="anchored" data-anchor-id="parameter-reference-table" id="parameter-reference-table">Parameter Reference Table</h3>
<div id="tbl-parameters" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-tbl figure">
<figcaption class="quarto-float-caption-top quarto-float-caption quarto-float-tbl" id="tbl-parameters-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Table&nbsp;1: ControlNet Parameter Reference
</figcaption>
<div aria-describedby="tbl-parameters-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<table class="caption-top table">
<thead>
<tr class="header">
<th>Parameter</th>
<th>Typical Range</th>
<th>Effect</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><code>controlnet_conditioning_scale</code></td>
<td>0.5-1.5</td>
<td>Control strength</td>
</tr>
<tr class="even">
<td><code>guidance_scale</code></td>
<td>5.0-15.0</td>
<td>Prompt adherence</td>
</tr>
<tr class="odd">
<td><code>num_inference_steps</code></td>
<td>10-50</td>
<td>Quality vs speed</td>
</tr>
<tr class="even">
<td><code>control_guidance_start</code></td>
<td>0.0-0.5</td>
<td>When control starts</td>
</tr>
<tr class="odd">
<td><code>control_guidance_end</code></td>
<td>0.5-1.0</td>
<td>When control ends</td>
</tr>
</tbody>
</table>
</div>
</figure>
</div>
<p>This guide covers implementing Stable Diffusion with ControlNet, from basic usage to production-ready systems. Its modular structure makes it straightforward to customize and extend for specific requirements.</p>



</section>
</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[DenseNet: A Code Guide]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/models/dense-net/dense-net-code/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/models/dense-net/dense-net-code/</guid>
      <pubDate>Sat, 19 Jul 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>tutorial</category>
      <category>intermediate</category>
      <content:encoded><![CDATA[






<section id="densenet-a-code-guide" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/models/dense-net/dense-net-code/dncode.png" class="img-fluid" width="600"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>DenseNet (Densely Connected Convolutional Networks) introduced a connectivity pattern that changes how information flows through a convolutional network. Proposed by Gao Huang, Zhuang Liu, Laurens van der Maaten, and Kilian Weinberger at CVPR 2017, DenseNet departs from the strictly sequential layout of traditional convolutional neural networks: within each dense block, every layer is directly connected to all subsequent layers.</p>
<p>DenseNet is motivated by the vanishing-gradient problem that made very deep networks difficult to train. ResNet introduced skip connections to enable training of deeper networks; DenseNet takes this idea further, creating a densely connected topology that maximizes information flow and gradient propagation throughout the network.</p>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Key Innovation
</div>
</div>
<div class="callout-body-container callout-body">
<p>DenseNet’s core innovation is connecting each layer to every subsequent layer within a dense block, which maximizes information flow and feature reuse.</p>
</div>
</div>
</section>
<section id="theoretical-foundation" class="level2">
<h2 class="anchored" data-anchor-id="theoretical-foundation" id="theoretical-foundation">Theoretical Foundation</h2>
<section id="the-dense-connectivity-pattern" class="level3">
<h3 class="anchored" data-anchor-id="the-dense-connectivity-pattern" id="the-dense-connectivity-pattern">The Dense Connectivity Pattern</h3>
<p>The core innovation of DenseNet lies in its connectivity pattern. In traditional CNNs, each layer receives input only from the previous layer. ResNet added identity skip connections, so a layer's output is summed with its input (<span class="math inline">\(x_l = H_l(x_{l-1}) + x_{l-1}\)</span>) and earlier features reach later layers through these additions. DenseNet generalizes this idea by connecting each layer to every subsequent layer within a dense block, and it combines features by concatenation rather than summation.</p>
<p>Mathematically, in a network with <span class="math inline">\(L\)</span> layers, the <span class="math inline">\(l\)</span>-th layer receives the feature maps of all preceding layers:</p>
<p><span class="math display">\[
x_l = H_l([x_0, x_1, ..., x_{l-1}])
\]</span></p>
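<p>Here <span class="math inline">\([x_0, x_1, ..., x_{l-1}]\)</span> denotes the channel-wise concatenation of the feature maps produced by layers 0 through <span class="math inline">\(l-1\)</span> (with <span class="math inline">\(x_0\)</span> the input to the block), and <span class="math inline">\(H_l\)</span> denotes the composite function applied by the <span class="math inline">\(l\)</span>-th layer.</p>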
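<p>The recurrence above can be sketched in a few lines of PyTorch. This is a minimal illustration under assumed sizes (a block input with <code>k0 = 16</code> channels, growth rate <code>k = 12</code>, four layers), and for brevity each <span class="math inline">\(H_l\)</span> is a single 3×3 convolution rather than the full BN → ReLU → Conv composite function described later.</p>

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Assumed, illustrative sizes
k0, k, num_layers = 16, 12, 4

# H[l - 1] plays the role of H_l; it sees k0 + (l - 1) * k input channels.
# Simplification (assumption): H_l is a bare 3x3 conv, not BN -> ReLU -> Conv.
H = nn.ModuleList([
    nn.Conv2d(k0 + (l - 1) * k, k, kernel_size=3, padding=1)
    for l in range(1, num_layers + 1)
])

x0 = torch.randn(1, k0, 8, 8)   # block input x_0
outputs = [x0]                  # running list [x_0, x_1, ..., x_{l-1}]
for l in range(1, num_layers + 1):
    concat = torch.cat(outputs, dim=1)   # [x_0, ..., x_{l-1}] along channels
    outputs.append(H[l - 1](concat))     # x_l = H_l([x_0, ..., x_{l-1}])

block_output = torch.cat(outputs, dim=1)
print(tuple(block_output.shape))  # (1, 64, 8, 8): 16 + 4 * 12 channels
```

<p>Unlike a residual connection, nothing is summed here: earlier feature maps are carried forward unchanged, and each layer only appends <code>k</code> new channels.</p>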
<p>This dense connectivity pattern creates several theoretical advantages:</p>
<ol type="1">
<li><p><strong>Maximum Information Flow</strong>: Every layer has direct access to the gradients from the loss function and the original input signal, ensuring efficient gradient flow during backpropagation.</p></li>
<li><p><strong>Feature Reuse</strong>: Lower-level features are directly accessible to higher-level layers, promoting feature reuse and reducing the need for redundant feature learning.</p></li>
<li><p><strong>Implicit Deep Supervision</strong>: Individual layers receive additional supervision from the loss function through the shorter connections, an implicit form of deep supervision that improves learning efficiency.</p></li>
</ol>
</section>
<section id="growth-rate-and-feature-map-management" class="level3">
<h3 class="anchored" data-anchor-id="growth-rate-and-feature-map-management" id="growth-rate-and-feature-map-management">Growth Rate and Feature Map Management</h3>
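<p>A critical design parameter in DenseNet is the growth rate <span class="math inline">\(k\)</span>, which determines how many new feature maps each layer contributes to the block's collective feature pool. If each layer produces <span class="math inline">\(k\)</span> feature maps and the block input has <span class="math inline">\(k_0\)</span> channels, then the <span class="math inline">\(l\)</span>-th layer receives <span class="math inline">\(k_0 + k \times (l - 1)\)</span> input feature maps.</p>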
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Growth Rate Guidelines
</div>
</div>
<div class="callout-body-container callout-body">
<p>Typical values for k range from 12 to 32, which is significantly smaller than the hundreds of feature maps common in traditional architectures like VGG or ResNet.</p>
</div>
</div>
<p>This growth pattern means that while each individual layer remains narrow (small k), the input to each layer grows linearly with its position in the block. The growth rate acts as a global hyperparameter that regulates how much new information each layer adds to the collective state: a smaller growth rate forces the network to learn more compact representations, while a larger growth rate provides more representational capacity at a higher computational cost.</p>
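<p>The linear growth is easy to tabulate. The helpers below are hypothetical (not part of the implementation later in this guide); they list the input width of each layer in one dense block as <code>k0 + k * (l - 1)</code>, where <code>k0</code> is the block's input channel count. The example values <code>k0 = 64</code>, <code>k = 32</code>, six layers correspond to the first block of the default configuration in the code below.</p>

```python
from typing import List

def dense_block_input_channels(k0: int, k: int, num_layers: int) -> List[int]:
    """Input channels seen by layers l = 1..num_layers of one dense block."""
    return [k0 + k * (l - 1) for l in range(1, num_layers + 1)]

def dense_block_output_channels(k0: int, k: int, num_layers: int) -> int:
    """Channels leaving the block: the block input plus k new maps per layer."""
    return k0 + k * num_layers

print(dense_block_input_channels(64, 32, 6))   # [64, 96, 128, 160, 192, 224]
print(dense_block_output_channels(64, 32, 6))  # 256
```

<p>Each layer itself only ever outputs 32 maps; it is the concatenated input that widens as the block deepens.</p>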
</section>
</section>
<section id="architecture-components" class="level2">
<h2 class="anchored" data-anchor-id="architecture-components" id="architecture-components">Architecture Components</h2>
<section id="dense-blocks" class="level3">
<h3 class="anchored" data-anchor-id="dense-blocks" id="dense-blocks">Dense Blocks</h3>
<p>Dense blocks form the fundamental building units of DenseNet. Within each dense block, every layer is connected to every subsequent layer through concatenation. Concatenation requires all feature maps in a block to share the same spatial size, so layers inside a block never downsample; resolution is reduced only in the transition layers between blocks.</p>
<p>Each layer within a dense block typically consists of:</p>
<ul>
<li>Batch normalization</li>
<li>ReLU activation</li>
<li>3×3 convolution</li>
</ul>
<p>Adding a 1×1 convolution (bottleneck layer) before the 3×3 convolution reduces computational cost; this variant is called DenseNet-B. Combining bottleneck layers with compression in the transition layers gives the DenseNet-BC (Bottleneck-Compression) variant.</p>
</section>
<section id="transition-layers" class="level3">
<h3 class="anchored" data-anchor-id="transition-layers" id="transition-layers">Transition Layers</h3>
<p>Between dense blocks, transition layers serve multiple critical functions:</p>
<ol type="1">
<li><p><strong>Dimensionality Reduction</strong>: As feature maps accumulate through concatenation within dense blocks, transition layers reduce the number of feature maps to control model complexity and computational requirements.</p></li>
<li><p><strong>Spatial Downsampling</strong>: Transition layers typically include average pooling operations to reduce spatial dimensions, enabling the network to learn hierarchical representations at different scales.</p></li>
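<li><p><strong>Compression</strong>: The compression factor θ (0 &lt; θ ≤ 1), typically set to 0.5, determines how many feature maps are retained: if a dense block outputs <span class="math inline">\(m\)</span> feature maps, the following transition layer produces <span class="math inline">\(\lfloor \theta m \rfloor\)</span> of them. This keeps the model compact while retaining most of the accumulated information.</p></li>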
</ol>
<p>A typical transition layer consists of:</p>
<ul>
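<li>Batch normalization</li>
<li>ReLU activation (omitted in the paper's description, but included in common implementations such as torchvision's and in the code below)</li>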
<li>1×1 convolution (for compression)</li>
<li>2×2 average pooling</li>
</ul>
</section>
<section id="composite-functions" class="level3">
<h3 class="anchored" data-anchor-id="composite-functions" id="composite-functions">Composite Functions</h3>
<p>The composite function <span class="math inline">\(H_l\)</span> in DenseNet typically follows the pre-activation design pattern:</p>
<p><strong>Batch Normalization → ReLU → Convolution</strong></p>
<p>This ordering follows the pre-activation design introduced in follow-up work on ResNet. Placing normalization and activation before the convolution has been found to ease optimization and improve training stability in very deep networks.</p>
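<p>As a concrete illustration of <span class="math inline">\(H_l\)</span> without a bottleneck, here is a minimal sketch of a basic (non-B) dense layer that applies BN → ReLU → 3×3 convolution to the concatenated inputs. The class name <code>BasicDenseLayer</code> is hypothetical; the bottleneck version used in this guide appears in the implementation section.</p>

```python
from typing import List

import torch
import torch.nn as nn

class BasicDenseLayer(nn.Module):
    """H_l without a bottleneck: BN -> ReLU -> 3x3 conv, emitting growth_rate maps."""

    def __init__(self, in_channels: int, growth_rate: int):
        super().__init__()
        self.norm = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU(inplace=True)  # safe: acts on the fresh BN output
        self.conv = nn.Conv2d(in_channels, growth_rate, kernel_size=3,
                              padding=1, bias=False)  # padding=1 keeps H x W

    def forward(self, prev_features: List[torch.Tensor]) -> torch.Tensor:
        x = torch.cat(prev_features, dim=1)        # [x_0, ..., x_{l-1}]
        return self.conv(self.relu(self.norm(x)))  # pre-activation ordering

layer = BasicDenseLayer(in_channels=24, growth_rate=12)
x0, x1 = torch.randn(2, 16, 8, 8), torch.randn(2, 8, 8, 8)
out = layer([x0, x1])
print(tuple(out.shape))  # (2, 12, 8, 8)
```

<p>The convolution has no bias because the batch normalization in the next layer (or the transition layer) would absorb it anyway.</p>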
</section>
</section>
<section id="implementation-deep-dive" class="level2">
<h2 class="anchored" data-anchor-id="implementation-deep-dive" id="implementation-deep-dive">Implementation Deep Dive</h2>
<section id="memory-efficiency-considerations" class="level3">
<h3 class="anchored" data-anchor-id="memory-efficiency-considerations" id="memory-efficiency-considerations">Memory Efficiency Considerations</h3>
<p>One of the primary challenges in implementing DenseNet is its memory consumption. A naive implementation allocates a new tensor for every concatenation (and for the batch normalization output computed on it), and these intermediate activations must be kept for the backward pass. Because the concatenated input grows with every layer, memory usage grows quadratically with the number of layers in a block.</p>
<p>Several optimization strategies address these memory concerns:</p>
<ol type="1">
<li><p><strong>Shared Memory Allocation</strong>: Implementing efficient memory sharing for concatenation operations reduces the memory footprint by avoiding unnecessary copying of feature maps.</p></li>
<li><p><strong>Gradient Checkpointing</strong>: For very deep DenseNet models, gradient checkpointing can trade computation for memory by recomputing intermediate activations during the backward pass instead of storing them.</p></li>
<li><p><strong>Efficient Concatenation</strong>: Using in-place operations where it is safe (for example, <code>nn.ReLU(inplace=True)</code>) and avoiding redundant copies of concatenated tensors can noticeably reduce memory usage.</p></li>
</ol>
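<p>Gradient checkpointing fits naturally on the wide part of each dense layer: the concatenation, batch normalization, and 1×1 bottleneck convolution. The sketch below wraps those steps in <code>torch.utils.checkpoint.checkpoint</code>, so their intermediate tensors are recomputed during the backward pass instead of being stored. It is an illustration under assumptions, loosely modeled on the memory-efficient option in torchvision's DenseNet; the class name <code>CheckpointedBottleneck</code> is hypothetical. One caveat: because the segment runs again during backward, batch normalization running statistics can be updated a second time.</p>

```python
from typing import List

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedBottleneck(nn.Module):
    """Hypothetical helper: concat -> BN -> ReLU -> 1x1 conv, recomputed in backward."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.norm = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)

    def _segment(self, *features: torch.Tensor) -> torch.Tensor:
        x = torch.cat(features, dim=1)             # the wide concatenated tensor
        return self.conv(self.relu(self.norm(x)))

    def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
        needs_grad = any(f.requires_grad for f in features)
        if self.training and needs_grad:
            # Intermediate activations inside _segment (concat, BN, ReLU) are
            # not stored; they are recomputed when gradients are needed.
            return checkpoint(self._segment, *features, use_reentrant=False)
        return self._segment(*features)

layer = CheckpointedBottleneck(in_channels=8, out_channels=16)
feats = [torch.randn(2, 4, 5, 5, requires_grad=True) for _ in range(2)]
out = layer(feats)
out.sum().backward()
print(tuple(out.shape), feats[0].grad is not None)  # (2, 16, 5, 5) True
```

<p>The trade is extra compute (one more forward pass through the cheap BN and 1×1 convolution) in exchange for not keeping a wide concatenated tensor per layer.</p>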
</section>
<section id="implementation-variants" class="level3">
<h3 class="anchored" data-anchor-id="implementation-variants" id="implementation-variants">Implementation Variants</h3>
<section id="densenet-bc-bottleneck-compression" class="level4">
<h4 class="anchored" data-anchor-id="densenet-bc-bottleneck-compression">DenseNet-BC (Bottleneck-Compression)</h4>
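<p>DenseNet-BC combines two modifications. Inside each dense block, a bottleneck 1×1 convolution reduces the concatenated input to 4<span class="math inline">\(k\)</span> feature maps before the 3×3 convolution is applied; between blocks, the transition layers apply compression (θ &lt; 1). Together these significantly reduce computation and parameter count while maintaining representational capacity.</p>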
<p>The bottleneck design modifies the composite function to: <strong>BN → ReLU → 1×1 Conv → BN → ReLU → 3×3 Conv</strong></p>
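<p>Whether the bottleneck actually saves work depends on how wide the concatenated input is. The back-of-the-envelope sketch below counts convolution weights (ignoring batch normalization parameters) for one layer with <code>m</code> input maps, growth rate <code>k = 32</code>, and a 4<span class="math inline">\(k\)</span>-wide bottleneck; at a fixed spatial size, multiply-adds scale the same way. The helper names are hypothetical.</p>

```python
def conv_weights(c_in: int, c_out: int, kernel: int) -> int:
    """Weight count of a bias-free 2D convolution."""
    return c_in * c_out * kernel * kernel

def plain_layer_weights(m: int, k: int) -> int:
    # BN -> ReLU -> 3x3 conv straight from m input maps to k output maps
    return conv_weights(m, k, 3)

def bottleneck_layer_weights(m: int, k: int, width: int = 4) -> int:
    # BN -> ReLU -> 1x1 conv (m -> width*k) -> BN -> ReLU -> 3x3 conv (width*k -> k)
    return conv_weights(m, width * k, 1) + conv_weights(width * k, k, 3)

k = 32
for m in (64, 256, 512, 1000):
    print(m, plain_layer_weights(m, k), bottleneck_layer_weights(m, k))
# 64 18432 45056
# 256 73728 69632
# 512 147456 102400
# 1000 288000 164864
```

<p>With these settings the bottleneck pays off only once <span class="math inline">\(9mk &gt; 4mk + 36k^2\)</span>, i.e. <span class="math inline">\(m &gt; 7.2k\)</span> (about 230 maps here), which is why it matters most deep inside wide dense blocks.</p>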
</section>
<section id="densenet-c-compression-only" class="level4">
<h4 class="anchored" data-anchor-id="densenet-c-compression-only">DenseNet-C (Compression Only)</h4>
<p>This variant applies compression in transition layers without using bottleneck layers within dense blocks, providing a middle ground between computational efficiency and architectural simplicity.</p>
</section>
</section>
</section>
<section id="code-implementation" class="level2">
<h2 class="anchored" data-anchor-id="code-implementation" id="code-implementation">Code Implementation</h2>
<p>Here’s a PyTorch implementation of DenseNet-BC (every dense layer uses a 1×1 bottleneck, and compression defaults to 0.5):</p>
<div id="densenet-implementation" class="cell" data-execution_count="1">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb1"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn.functional <span class="im">as</span> F</span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> collections <span class="im">import</span> OrderedDict</span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> DenseLayer(nn.Module):</span>
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, in_channels, growth_rate, bottleneck_size<span class="op">=</span><span class="dv">4</span>, dropout_rate<span class="op">=</span><span class="fl">0.0</span>):</span>
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>(DenseLayer, <span class="va">self</span>).<span class="fu">__init__</span>()</span>
<span id="cb1-9"><a href="#cb1-9" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-10"><a href="#cb1-10" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Bottleneck layer (1x1 conv)</span></span>
<span id="cb1-11"><a href="#cb1-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.bottleneck <span class="op">=</span> nn.Sequential(</span>
<span id="cb1-12"><a href="#cb1-12" aria-hidden="true" tabindex="-1"></a>            nn.BatchNorm2d(in_channels),</span>
<span id="cb1-13"><a href="#cb1-13" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(inplace<span class="op">=</span><span class="va">True</span>),</span>
<span id="cb1-14"><a href="#cb1-14" aria-hidden="true" tabindex="-1"></a>            nn.Conv2d(in_channels, bottleneck_size <span class="op">*</span> growth_rate, </span>
<span id="cb1-15"><a href="#cb1-15" aria-hidden="true" tabindex="-1"></a>                     kernel_size<span class="op">=</span><span class="dv">1</span>, stride<span class="op">=</span><span class="dv">1</span>, bias<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb1-16"><a href="#cb1-16" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb1-17"><a href="#cb1-17" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-18"><a href="#cb1-18" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Main convolution layer (3x3 conv)</span></span>
<span id="cb1-19"><a href="#cb1-19" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.main_conv <span class="op">=</span> nn.Sequential(</span>
<span id="cb1-20"><a href="#cb1-20" aria-hidden="true" tabindex="-1"></a>            nn.BatchNorm2d(bottleneck_size <span class="op">*</span> growth_rate),</span>
<span id="cb1-21"><a href="#cb1-21" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(inplace<span class="op">=</span><span class="va">True</span>),</span>
<span id="cb1-22"><a href="#cb1-22" aria-hidden="true" tabindex="-1"></a>            nn.Conv2d(bottleneck_size <span class="op">*</span> growth_rate, growth_rate,</span>
<span id="cb1-23"><a href="#cb1-23" aria-hidden="true" tabindex="-1"></a>                     kernel_size<span class="op">=</span><span class="dv">3</span>, stride<span class="op">=</span><span class="dv">1</span>, padding<span class="op">=</span><span class="dv">1</span>, bias<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb1-24"><a href="#cb1-24" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb1-25"><a href="#cb1-25" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-26"><a href="#cb1-26" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.dropout <span class="op">=</span> nn.Dropout(dropout_rate) <span class="cf">if</span> dropout_rate <span class="op">&gt;</span> <span class="dv">0</span> <span class="cf">else</span> <span class="va">None</span></span>
<span id="cb1-27"><a href="#cb1-27" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb1-28"><a href="#cb1-28" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb1-29"><a href="#cb1-29" aria-hidden="true" tabindex="-1"></a>        <span class="co"># x can be a tensor or a list of tensors (from concatenation)</span></span>
<span id="cb1-30"><a href="#cb1-30" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="bu">isinstance</span>(x, torch.Tensor):</span>
<span id="cb1-31"><a href="#cb1-31" aria-hidden="true" tabindex="-1"></a>            concatenated_features <span class="op">=</span> x</span>
<span id="cb1-32"><a href="#cb1-32" aria-hidden="true" tabindex="-1"></a>        <span class="cf">else</span>:</span>
<span id="cb1-33"><a href="#cb1-33" aria-hidden="true" tabindex="-1"></a>            concatenated_features <span class="op">=</span> torch.cat(x, dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb1-34"><a href="#cb1-34" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-35"><a href="#cb1-35" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Apply bottleneck</span></span>
<span id="cb1-36"><a href="#cb1-36" aria-hidden="true" tabindex="-1"></a>        bottleneck_output <span class="op">=</span> <span class="va">self</span>.bottleneck(concatenated_features)</span>
<span id="cb1-37"><a href="#cb1-37" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-38"><a href="#cb1-38" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Apply main convolution</span></span>
<span id="cb1-39"><a href="#cb1-39" aria-hidden="true" tabindex="-1"></a>        new_features <span class="op">=</span> <span class="va">self</span>.main_conv(bottleneck_output)</span>
<span id="cb1-40"><a href="#cb1-40" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-41"><a href="#cb1-41" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Apply dropout if specified</span></span>
<span id="cb1-42"><a href="#cb1-42" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="va">self</span>.dropout <span class="kw">is</span> <span class="kw">not</span> <span class="va">None</span>:</span>
<span id="cb1-43"><a href="#cb1-43" aria-hidden="true" tabindex="-1"></a>            new_features <span class="op">=</span> <span class="va">self</span>.dropout(new_features)</span>
<span id="cb1-44"><a href="#cb1-44" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-45"><a href="#cb1-45" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> new_features</span>
<span id="cb1-46"><a href="#cb1-46" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-47"><a href="#cb1-47" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> DenseBlock(nn.Module):</span>
<span id="cb1-48"><a href="#cb1-48" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, num_layers, in_channels, growth_rate, </span>
<span id="cb1-49"><a href="#cb1-49" aria-hidden="true" tabindex="-1"></a>                 bottleneck_size<span class="op">=</span><span class="dv">4</span>, dropout_rate<span class="op">=</span><span class="fl">0.0</span>):</span>
<span id="cb1-50"><a href="#cb1-50" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>(DenseBlock, <span class="va">self</span>).<span class="fu">__init__</span>()</span>
<span id="cb1-51"><a href="#cb1-51" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-52"><a href="#cb1-52" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.layers <span class="op">=</span> nn.ModuleList()</span>
<span id="cb1-53"><a href="#cb1-53" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(num_layers):</span>
<span id="cb1-54"><a href="#cb1-54" aria-hidden="true" tabindex="-1"></a>            current_in_channels <span class="op">=</span> in_channels <span class="op">+</span> i <span class="op">*</span> growth_rate</span>
<span id="cb1-55"><a href="#cb1-55" aria-hidden="true" tabindex="-1"></a>            layer <span class="op">=</span> DenseLayer(</span>
<span id="cb1-56"><a href="#cb1-56" aria-hidden="true" tabindex="-1"></a>                current_in_channels, </span>
<span id="cb1-57"><a href="#cb1-57" aria-hidden="true" tabindex="-1"></a>                growth_rate, </span>
<span id="cb1-58"><a href="#cb1-58" aria-hidden="true" tabindex="-1"></a>                bottleneck_size, </span>
<span id="cb1-59"><a href="#cb1-59" aria-hidden="true" tabindex="-1"></a>                dropout_rate</span>
<span id="cb1-60"><a href="#cb1-60" aria-hidden="true" tabindex="-1"></a>            )</span>
<span id="cb1-61"><a href="#cb1-61" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.layers.append(layer)</span>
<span id="cb1-62"><a href="#cb1-62" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb1-63"><a href="#cb1-63" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb1-64"><a href="#cb1-64" aria-hidden="true" tabindex="-1"></a>        features <span class="op">=</span> [x]</span>
<span id="cb1-65"><a href="#cb1-65" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-66"><a href="#cb1-66" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> layer <span class="kw">in</span> <span class="va">self</span>.layers:</span>
<span id="cb1-67"><a href="#cb1-67" aria-hidden="true" tabindex="-1"></a>            new_features <span class="op">=</span> layer(features)</span>
<span id="cb1-68"><a href="#cb1-68" aria-hidden="true" tabindex="-1"></a>            features.append(new_features)</span>
<span id="cb1-69"><a href="#cb1-69" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-70"><a href="#cb1-70" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> torch.cat(features, dim<span class="op">=</span><span class="dv">1</span>)  <span class="co"># Include the block input: in_channels + num_layers * growth_rate channels</span></span>
<span id="cb1-71"><a href="#cb1-71" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-72"><a href="#cb1-72" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> TransitionLayer(nn.Module):</span>
<span id="cb1-73"><a href="#cb1-73" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, in_channels, compression_factor<span class="op">=</span><span class="fl">0.5</span>):</span>
<span id="cb1-74"><a href="#cb1-74" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>(TransitionLayer, <span class="va">self</span>).<span class="fu">__init__</span>()</span>
<span id="cb1-75"><a href="#cb1-75" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-76"><a href="#cb1-76" aria-hidden="true" tabindex="-1"></a>        out_channels <span class="op">=</span> <span class="bu">int</span>(in_channels <span class="op">*</span> compression_factor)</span>
<span id="cb1-77"><a href="#cb1-77" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-78"><a href="#cb1-78" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.transition <span class="op">=</span> nn.Sequential(</span>
<span id="cb1-79"><a href="#cb1-79" aria-hidden="true" tabindex="-1"></a>            nn.BatchNorm2d(in_channels),</span>
<span id="cb1-80"><a href="#cb1-80" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(inplace<span class="op">=</span><span class="va">True</span>),</span>
<span id="cb1-81"><a href="#cb1-81" aria-hidden="true" tabindex="-1"></a>            nn.Conv2d(in_channels, out_channels, kernel_size<span class="op">=</span><span class="dv">1</span>, bias<span class="op">=</span><span class="va">False</span>),</span>
<span id="cb1-82"><a href="#cb1-82" aria-hidden="true" tabindex="-1"></a>            nn.AvgPool2d(kernel_size<span class="op">=</span><span class="dv">2</span>, stride<span class="op">=</span><span class="dv">2</span>)</span>
<span id="cb1-83"><a href="#cb1-83" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb1-84"><a href="#cb1-84" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-85"><a href="#cb1-85" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.out_channels <span class="op">=</span> out_channels</span>
<span id="cb1-86"><a href="#cb1-86" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb1-87"><a href="#cb1-87" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb1-88"><a href="#cb1-88" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.transition(x)</span></code></pre></div></div>
</div>
<div id="densenet-main-class" class="cell" data-execution_count="2">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> DenseNet(nn.Module):</span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, growth_rate<span class="op">=</span><span class="dv">32</span>, block_config<span class="op">=</span>(<span class="dv">6</span>, <span class="dv">12</span>, <span class="dv">24</span>, <span class="dv">16</span>),</span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a>                 num_init_features<span class="op">=</span><span class="dv">64</span>, bottleneck_size<span class="op">=</span><span class="dv">4</span>, </span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a>                 compression_factor<span class="op">=</span><span class="fl">0.5</span>, dropout_rate<span class="op">=</span><span class="fl">0.0</span>, </span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a>                 num_classes<span class="op">=</span><span class="dv">1000</span>):</span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>(DenseNet, <span class="va">self</span>).<span class="fu">__init__</span>()</span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Initial convolution and pooling</span></span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.features <span class="op">=</span> nn.Sequential(OrderedDict([</span>
<span id="cb2-10"><a href="#cb2-10" aria-hidden="true" tabindex="-1"></a>            (<span class="st">'conv0'</span>, nn.Conv2d(<span class="dv">3</span>, num_init_features, </span>
<span id="cb2-11"><a href="#cb2-11" aria-hidden="true" tabindex="-1"></a>                               kernel_size<span class="op">=</span><span class="dv">7</span>, stride<span class="op">=</span><span class="dv">2</span>, padding<span class="op">=</span><span class="dv">3</span>, bias<span class="op">=</span><span class="va">False</span>)),</span>
<span id="cb2-12"><a href="#cb2-12" aria-hidden="true" tabindex="-1"></a>            (<span class="st">'norm0'</span>, nn.BatchNorm2d(num_init_features)),</span>
<span id="cb2-13"><a href="#cb2-13" aria-hidden="true" tabindex="-1"></a>            (<span class="st">'relu0'</span>, nn.ReLU(inplace<span class="op">=</span><span class="va">True</span>)),</span>
<span id="cb2-14"><a href="#cb2-14" aria-hidden="true" tabindex="-1"></a>            (<span class="st">'pool0'</span>, nn.MaxPool2d(kernel_size<span class="op">=</span><span class="dv">3</span>, stride<span class="op">=</span><span class="dv">2</span>, padding<span class="op">=</span><span class="dv">1</span>))</span>
<span id="cb2-15"><a href="#cb2-15" aria-hidden="true" tabindex="-1"></a>        ]))</span>
<span id="cb2-16"><a href="#cb2-16" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-17"><a href="#cb2-17" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Dense blocks and transition layers</span></span>
<span id="cb2-18"><a href="#cb2-18" aria-hidden="true" tabindex="-1"></a>        num_features <span class="op">=</span> num_init_features</span>
<span id="cb2-19"><a href="#cb2-19" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-20"><a href="#cb2-20" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> i, num_layers <span class="kw">in</span> <span class="bu">enumerate</span>(block_config):</span>
<span id="cb2-21"><a href="#cb2-21" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Add dense block</span></span>
<span id="cb2-22"><a href="#cb2-22" aria-hidden="true" tabindex="-1"></a>            block <span class="op">=</span> DenseBlock(</span>
<span id="cb2-23"><a href="#cb2-23" aria-hidden="true" tabindex="-1"></a>                num_layers<span class="op">=</span>num_layers,</span>
<span id="cb2-24"><a href="#cb2-24" aria-hidden="true" tabindex="-1"></a>                in_channels<span class="op">=</span>num_features,</span>
<span id="cb2-25"><a href="#cb2-25" aria-hidden="true" tabindex="-1"></a>                growth_rate<span class="op">=</span>growth_rate,</span>
<span id="cb2-26"><a href="#cb2-26" aria-hidden="true" tabindex="-1"></a>                bottleneck_size<span class="op">=</span>bottleneck_size,</span>
<span id="cb2-27"><a href="#cb2-27" aria-hidden="true" tabindex="-1"></a>                dropout_rate<span class="op">=</span>dropout_rate</span>
<span id="cb2-28"><a href="#cb2-28" aria-hidden="true" tabindex="-1"></a>            )</span>
<span id="cb2-29"><a href="#cb2-29" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.features.add_module(<span class="ss">f'denseblock</span><span class="sc">{</span>i<span class="op">+</span><span class="dv">1</span><span class="sc">}</span><span class="ss">'</span>, block)</span>
<span id="cb2-30"><a href="#cb2-30" aria-hidden="true" tabindex="-1"></a>            num_features <span class="op">+=</span> num_layers <span class="op">*</span> growth_rate</span>
<span id="cb2-31"><a href="#cb2-31" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb2-32"><a href="#cb2-32" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Add transition layer (except after the last dense block)</span></span>
<span id="cb2-33"><a href="#cb2-33" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> i <span class="op">!=</span> <span class="bu">len</span>(block_config) <span class="op">-</span> <span class="dv">1</span>:</span>
<span id="cb2-34"><a href="#cb2-34" aria-hidden="true" tabindex="-1"></a>                transition <span class="op">=</span> TransitionLayer(num_features, compression_factor)</span>
<span id="cb2-35"><a href="#cb2-35" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.features.add_module(<span class="ss">f'transition</span><span class="sc">{</span>i<span class="op">+</span><span class="dv">1</span><span class="sc">}</span><span class="ss">'</span>, transition)</span>
<span id="cb2-36"><a href="#cb2-36" aria-hidden="true" tabindex="-1"></a>                num_features <span class="op">=</span> transition.out_channels</span>
<span id="cb2-37"><a href="#cb2-37" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-38"><a href="#cb2-38" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Final batch normalization</span></span>
<span id="cb2-39"><a href="#cb2-39" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.features.add_module(<span class="st">'norm_final'</span>, nn.BatchNorm2d(num_features))</span>
<span id="cb2-40"><a href="#cb2-40" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-41"><a href="#cb2-41" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Classifier</span></span>
<span id="cb2-42"><a href="#cb2-42" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.classifier <span class="op">=</span> nn.Linear(num_features, num_classes)</span>
<span id="cb2-43"><a href="#cb2-43" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-44"><a href="#cb2-44" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Weight initialization</span></span>
<span id="cb2-45"><a href="#cb2-45" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>._initialize_weights()</span>
<span id="cb2-46"><a href="#cb2-46" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-47"><a href="#cb2-47" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _initialize_weights(<span class="va">self</span>):</span>
<span id="cb2-48"><a href="#cb2-48" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> m <span class="kw">in</span> <span class="va">self</span>.modules():</span>
<span id="cb2-49"><a href="#cb2-49" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="bu">isinstance</span>(m, nn.Conv2d):</span>
<span id="cb2-50"><a href="#cb2-50" aria-hidden="true" tabindex="-1"></a>                nn.init.kaiming_normal_(m.weight, mode<span class="op">=</span><span class="st">'fan_out'</span>, </span>
<span id="cb2-51"><a href="#cb2-51" aria-hidden="true" tabindex="-1"></a>                                      nonlinearity<span class="op">=</span><span class="st">'relu'</span>)</span>
<span id="cb2-52"><a href="#cb2-52" aria-hidden="true" tabindex="-1"></a>            <span class="cf">elif</span> <span class="bu">isinstance</span>(m, nn.BatchNorm2d):</span>
<span id="cb2-53"><a href="#cb2-53" aria-hidden="true" tabindex="-1"></a>                nn.init.constant_(m.weight, <span class="dv">1</span>)</span>
<span id="cb2-54"><a href="#cb2-54" aria-hidden="true" tabindex="-1"></a>                nn.init.constant_(m.bias, <span class="dv">0</span>)</span>
<span id="cb2-55"><a href="#cb2-55" aria-hidden="true" tabindex="-1"></a>            <span class="cf">elif</span> <span class="bu">isinstance</span>(m, nn.Linear):</span>
<span id="cb2-56"><a href="#cb2-56" aria-hidden="true" tabindex="-1"></a>                nn.init.normal_(m.weight, <span class="dv">0</span>, <span class="fl">0.01</span>)</span>
<span id="cb2-57"><a href="#cb2-57" aria-hidden="true" tabindex="-1"></a>                nn.init.constant_(m.bias, <span class="dv">0</span>)</span>
<span id="cb2-58"><a href="#cb2-58" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-59"><a href="#cb2-59" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb2-60"><a href="#cb2-60" aria-hidden="true" tabindex="-1"></a>        features <span class="op">=</span> <span class="va">self</span>.features(x)</span>
<span id="cb2-61"><a href="#cb2-61" aria-hidden="true" tabindex="-1"></a>        out <span class="op">=</span> F.relu(features, inplace<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb2-62"><a href="#cb2-62" aria-hidden="true" tabindex="-1"></a>        out <span class="op">=</span> F.adaptive_avg_pool2d(out, (<span class="dv">1</span>, <span class="dv">1</span>))</span>
<span id="cb2-63"><a href="#cb2-63" aria-hidden="true" tabindex="-1"></a>        out <span class="op">=</span> torch.flatten(out, <span class="dv">1</span>)</span>
<span id="cb2-64"><a href="#cb2-64" aria-hidden="true" tabindex="-1"></a>        out <span class="op">=</span> <span class="va">self</span>.classifier(out)</span>
<span id="cb2-65"><a href="#cb2-65" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> out</span></code></pre></div></div>
</div>
<div id="densenet-variants" class="cell" data-execution_count="3">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Factory functions for common DenseNet variants</span></span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> densenet121(num_classes<span class="op">=</span><span class="dv">1000</span>, <span class="op">**</span>kwargs):</span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> DenseNet(growth_rate<span class="op">=</span><span class="dv">32</span>, block_config<span class="op">=</span>(<span class="dv">6</span>, <span class="dv">12</span>, <span class="dv">24</span>, <span class="dv">16</span>), </span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a>                   num_classes<span class="op">=</span>num_classes, <span class="op">**</span>kwargs)</span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> densenet169(num_classes<span class="op">=</span><span class="dv">1000</span>, <span class="op">**</span>kwargs):</span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> DenseNet(growth_rate<span class="op">=</span><span class="dv">32</span>, block_config<span class="op">=</span>(<span class="dv">6</span>, <span class="dv">12</span>, <span class="dv">32</span>, <span class="dv">32</span>), </span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a>                   num_classes<span class="op">=</span>num_classes, <span class="op">**</span>kwargs)</span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> densenet201(num_classes<span class="op">=</span><span class="dv">1000</span>, <span class="op">**</span>kwargs):</span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> DenseNet(growth_rate<span class="op">=</span><span class="dv">32</span>, block_config<span class="op">=</span>(<span class="dv">6</span>, <span class="dv">12</span>, <span class="dv">48</span>, <span class="dv">32</span>), </span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a>                   num_classes<span class="op">=</span>num_classes, <span class="op">**</span>kwargs)</span>
<span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> densenet161(num_classes<span class="op">=</span><span class="dv">1000</span>, <span class="op">**</span>kwargs):</span>
<span id="cb3-15"><a href="#cb3-15" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> DenseNet(growth_rate<span class="op">=</span><span class="dv">48</span>, block_config<span class="op">=</span>(<span class="dv">6</span>, <span class="dv">12</span>, <span class="dv">36</span>, <span class="dv">24</span>), </span>
<span id="cb3-16"><a href="#cb3-16" aria-hidden="true" tabindex="-1"></a>                   num_init_features<span class="op">=</span><span class="dv">96</span>, num_classes<span class="op">=</span>num_classes, <span class="op">**</span>kwargs)</span>
<span id="cb3-17"><a href="#cb3-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-18"><a href="#cb3-18" aria-hidden="true" tabindex="-1"></a><span class="co"># Example: Create a DenseNet-121 model</span></span>
<span id="cb3-19"><a href="#cb3-19" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> densenet121(num_classes<span class="op">=</span><span class="dv">1000</span>)</span>
<span id="cb3-20"><a href="#cb3-20" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Model created with </span><span class="sc">{</span><span class="bu">sum</span>(p.numel() <span class="cf">for</span> p <span class="kw">in</span> model.parameters())<span class="sc">}</span><span class="ss"> parameters"</span>)</span></code></pre></div></div>
</div>
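<p>The <code>num_features</code> bookkeeping in the class above can be traced by hand. The pure-Python sketch below follows the same updates: each dense layer appends k feature maps, and each transition except the last one compresses the channel count. It assumes the transition computes its output width as <code>int(num_features * θ)</code> with θ = 0.5. The helper name <code>dense_widths</code> is illustrative and not part of any library.</p>

```python
def dense_widths(growth_rate, block_config, num_init_features=64, compression=0.5):
    """Channel count at the end of each dense block, mirroring num_features updates."""
    num_features = num_init_features
    widths = []
    for i, num_layers in enumerate(block_config):
        num_features += num_layers * growth_rate        # each layer appends k maps
        widths.append(num_features)
        if i != len(block_config) - 1:                  # no transition after last block
            num_features = int(num_features * compression)
    return widths

variants = {
    "densenet121": (32, (6, 12, 24, 16), 64),
    "densenet169": (32, (6, 12, 32, 32), 64),
    "densenet201": (32, (6, 12, 48, 32), 64),
    "densenet161": (48, (6, 12, 36, 24), 96),
}
for name, (k, blocks, init) in variants.items():
    print(name, dense_widths(k, blocks, init))
```

<p>The last entry is the input width of the final linear layer: 1024, 1664, 1920 and 2208 for DenseNet-121, -169, -201 and -161 respectively.</p>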
</section>
<section id="performance-analysis-and-benchmarks" class="level2">
<h2 class="anchored" data-anchor-id="performance-analysis-and-benchmarks" id="performance-analysis-and-benchmarks">Performance Analysis and Benchmarks</h2>
<section id="computational-complexity" class="level3">
<h3 class="anchored" data-anchor-id="computational-complexity" id="computational-complexity">Computational Complexity</h3>
<p>DenseNet’s computational complexity differs significantly from traditional architectures due to its unique connectivity pattern. While the number of parameters can be substantially lower than comparable ResNet models, the memory requirements during training are generally higher due to the concatenation operations.</p>
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Important</span>Key Complexity Characteristics
</div>
</div>
<div class="callout-body-container callout-body">
<ol type="1">
<li><p><strong>Parameter Efficiency</strong>: DenseNet typically requires fewer parameters than ResNet for comparable performance due to feature reuse and the narrow layer design.</p></li>
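<li><p><strong>Memory Complexity</strong>: In a naive implementation, where each layer allocates a new tensor for its concatenated input, memory usage grows quadratically with the number of layers in a dense block. Implementations that write all feature maps into a shared, preallocated buffer reduce this to linear growth.</p></li>
<li><p><strong>Computational Complexity</strong>: Each layer adds only k feature maps, but its input width grows linearly with its position in the block (k<sub>0</sub> + (ℓ - 1)k channels for layer ℓ). Later layers therefore pay more for their 1×1 bottleneck convolutions and for the concatenations themselves.</p></li>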
</ol>
</div>
</div>
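<p>The quadratic term comes from naive concatenation. If layer ℓ of a block materialises its own concatenated input of k<sub>0</sub> + (ℓ - 1)k channels, the channels stored over L layers sum to Lk<sub>0</sub> + kL(L - 1)/2. The sketch below compares that total with a single shared buffer that every layer writes into. It counts channels only; multiply by batch size, spatial size and bytes per element to get bytes. The function names are illustrative.</p>

```python
def layer_input_widths(in_channels, growth_rate, num_layers):
    """Input width seen by each layer of one dense block: k0 + l*k for l = 0..L-1."""
    return [in_channels + l * growth_rate for l in range(num_layers)]

def naive_concat_channels(in_channels, growth_rate, num_layers):
    """Channels stored when every layer allocates its own concatenated input."""
    return sum(layer_input_widths(in_channels, growth_rate, num_layers))

def shared_buffer_channels(in_channels, growth_rate, num_layers):
    """Channels stored when all layers read from and write into one buffer."""
    return in_channels + num_layers * growth_rate

# Third dense block of DenseNet-121 (256 input channels, k = 32), and a 2x deeper variant
for layers in (24, 48):
    print(layers,
          naive_concat_channels(256, 32, layers),
          shared_buffer_channels(256, 32, layers))
```

<p>This prints <code>24 14976 1024</code> and <code>48 48384 1792</code>. Doubling the block depth more than triples the naive total, while the shared buffer grows only linearly.</p>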
</section>
<section id="benchmark-results" class="level3">
<h3 class="anchored" data-anchor-id="benchmark-results" id="benchmark-results">Benchmark Results</h3>
<p>DenseNet has demonstrated strong performance across various computer vision tasks:</p>
<div id="tbl-imagenet-performance" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-tbl figure">
<figcaption class="quarto-float-caption-top quarto-float-caption quarto-float-tbl" id="tbl-imagenet-performance-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Table&nbsp;1: DenseNet Performance on ImageNet
</figcaption>
<div aria-describedby="tbl-imagenet-performance-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<table class="caption-top table">
<thead>
<tr class="header">
<th>Model</th>
<th>ImageNet Top-1 Error</th>
<th>Parameters</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>DenseNet-121</td>
<td>25.35%</td>
<td>8.0M</td>
</tr>
<tr class="even">
<td>DenseNet-169</td>
<td>24.00%</td>
<td>14.1M</td>
</tr>
<tr class="odd">
<td>DenseNet-201</td>
<td>22.80%</td>
<td>20.0M</td>
</tr>
</tbody>
</table>
</div>
</figure>
</div>
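<p>The parameter column can be checked with a short calculation. The sketch below assumes a torchvision-style layout: a bias-free 7×7 stem convolution with 64 filters followed by BatchNorm; BN-ReLU-Conv(1×1, 4k)-BN-ReLU-Conv(3×3, k) dense layers; bias-free convolutions; transitions (BN plus 1×1 convolution) that halve the channel count; a final BatchNorm; and a 1000-way linear classifier. It counts only learnable parameters, so BatchNorm running statistics are excluded. <code>densenet_params</code> is an illustrative helper, not a library function.</p>

```python
def densenet_params(growth_rate, block_config, num_init_features=64,
                    bn_size=4, num_classes=1000):
    """Count learnable parameters of a torchvision-style ImageNet DenseNet."""
    k, inter = growth_rate, bn_size * growth_rate
    params = 3 * num_init_features * 7 * 7 + 2 * num_init_features  # stem conv + BN
    c = num_init_features
    for i, num_layers in enumerate(block_config):
        for _ in range(num_layers):
            params += 2 * c + c * inter           # BN + 1x1 bottleneck conv
            params += 2 * inter + inter * k * 9   # BN + 3x3 conv producing k maps
            c += k
        if i != len(block_config) - 1:
            params += 2 * c + c * (c // 2)        # transition: BN + 1x1 conv
            c //= 2
    params += 2 * c                               # final BN
    params += c * num_classes + num_classes       # linear classifier with bias
    return params

for name, blocks in [("DenseNet-121", (6, 12, 24, 16)),
                     ("DenseNet-169", (6, 12, 32, 32)),
                     ("DenseNet-201", (6, 12, 48, 32))]:
    print(f"{name}: {densenet_params(32, blocks) / 1e6:.1f}M")
```

<p>Under these assumptions the exact counts are 7,978,856, 14,149,480 and 20,013,928, which round to the 8.0M, 14.1M and 20.0M in the table.</p>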
<p><strong>CIFAR Datasets</strong>:</p>
<ul>
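<li>CIFAR-10: 3.46% error with standard data augmentation, obtained by DenseNet-BC (L=190, k=40, 25.6M parameters)</li>
<li>CIFAR-100: 17.18% error with standard data augmentation for the same model; the much smaller DenseNet-BC (L=100, k=12, 0.8M parameters) is competitive with a 1001-layer pre-activation ResNet that has 10.2M parameters</li>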
</ul>
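<p>On CIFAR, DenseNet-BC uses three dense blocks rather than four. The depth L counts four layers for the initial convolution, the two transition convolutions and the classifier, plus two layers (1×1 and 3×3) for each bottleneck dense layer. A depth of L therefore gives (L - 4)/6 dense layers per block. The helper below applies that rule; its name is illustrative.</p>

```python
def bc_layers_per_block(depth, num_blocks=3):
    """Dense layers per block for a CIFAR DenseNet-BC of the given total depth."""
    convs_per_layer = 2  # 1x1 bottleneck + 3x3 convolution
    if (depth - 4) % (convs_per_layer * num_blocks) != 0:
        raise ValueError(f"depth must equal {convs_per_layer * num_blocks}n + 4")
    return (depth - 4) // (convs_per_layer * num_blocks)

print(bc_layers_per_block(100))  # DenseNet-BC, L=100
print(bc_layers_per_block(190))  # DenseNet-BC, L=190
```

<p>DenseNet-BC with L=100 has 16 dense layers per block; with L=190 it has 31.</p>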
</section>
<section id="memory-optimization-strategies" class="level3">
<h3 class="anchored" data-anchor-id="memory-optimization-strategies" id="memory-optimization-strategies">Memory Optimization Strategies</h3>
<p>Several strategies can be employed to optimize DenseNet’s memory usage:</p>
<div id="memory-optimization-example" class="cell" data-execution_count="4">
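<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a># Memory-efficient dense layer using gradient checkpointing</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a>import torch.nn as nn</span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a>import torch.nn.functional as F</span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a>import torch.utils.checkpoint as cp</span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a>class MemoryEfficientDenseLayer(nn.Module):</span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a>    """</span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a>    BN-ReLU-Conv(1x1)-BN-ReLU-Conv(3x3) dense layer whose bottleneck</span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a>    activations are recomputed in the backward pass instead of stored.</span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a>    """</span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a>    def __init__(self, in_channels, growth_rate, bottleneck_size=4):</span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a>        super().__init__()</span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a>        inter_channels = bottleneck_size * growth_rate</span>
<span id="cb4-14"><a href="#cb4-14" aria-hidden="true" tabindex="-1"></a>        self.norm1 = nn.BatchNorm2d(in_channels)</span>
<span id="cb4-15"><a href="#cb4-15" aria-hidden="true" tabindex="-1"></a>        self.conv1 = nn.Conv2d(in_channels, inter_channels, kernel_size=1, bias=False)</span>
<span id="cb4-16"><a href="#cb4-16" aria-hidden="true" tabindex="-1"></a>        self.norm2 = nn.BatchNorm2d(inter_channels)</span>
<span id="cb4-17"><a href="#cb4-17" aria-hidden="true" tabindex="-1"></a>        self.conv2 = nn.Conv2d(inter_channels, growth_rate, kernel_size=3,</span>
<span id="cb4-18"><a href="#cb4-18" aria-hidden="true" tabindex="-1"></a>                               padding=1, bias=False)</span>
<span id="cb4-19"><a href="#cb4-19" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-20"><a href="#cb4-20" aria-hidden="true" tabindex="-1"></a>    def _bottleneck(self, x):</span>
<span id="cb4-21"><a href="#cb4-21" aria-hidden="true" tabindex="-1"></a>        # BN-ReLU-Conv1x1 over the wide concatenated input (the memory-heavy part)</span>
<span id="cb4-22"><a href="#cb4-22" aria-hidden="true" tabindex="-1"></a>        return self.conv1(F.relu(self.norm1(x)))</span>
<span id="cb4-23"><a href="#cb4-23" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-24"><a href="#cb4-24" aria-hidden="true" tabindex="-1"></a>    def forward(self, x):</span>
<span id="cb4-25"><a href="#cb4-25" aria-hidden="true" tabindex="-1"></a>        # x: concatenation of all previous feature maps in the dense block</span>
<span id="cb4-26"><a href="#cb4-26" aria-hidden="true" tabindex="-1"></a>        if self.training and x.requires_grad:</span>
<span id="cb4-27"><a href="#cb4-27" aria-hidden="true" tabindex="-1"></a>            # Recompute the bottleneck during backward instead of storing it.</span>
<span id="cb4-28"><a href="#cb4-28" aria-hidden="true" tabindex="-1"></a>            # BatchNorm runs twice, so running stats are updated twice per step.</span>
<span id="cb4-29"><a href="#cb4-29" aria-hidden="true" tabindex="-1"></a>            out = cp.checkpoint(self._bottleneck, x, use_reentrant=False)</span>
<span id="cb4-30"><a href="#cb4-30" aria-hidden="true" tabindex="-1"></a>        else:</span>
<span id="cb4-31"><a href="#cb4-31" aria-hidden="true" tabindex="-1"></a>            out = self._bottleneck(x)</span>
<span id="cb4-32"><a href="#cb4-32" aria-hidden="true" tabindex="-1"></a>        # Return only the k new feature maps; the caller concatenates them</span>
<span id="cb4-33"><a href="#cb4-33" aria-hidden="true" tabindex="-1"></a>        return self.conv2(F.relu(self.norm2(out)))</span></code></pre></div></div>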
</div>
<ol type="1">
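<li><p><strong>Memory-Efficient Implementation</strong>: Pre-allocating shared storage for the concatenated features and batch-normalization outputs, so that each layer does not copy all previous feature maps into a new tensor.</p></li>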
<li><p><strong>Mixed Precision Training</strong>: Utilizing half-precision floating-point arithmetic where appropriate.</p></li>
<li><p><strong>Gradient Checkpointing</strong>: Trading computation for memory by recomputing intermediate activations.</p></li>
</ol>
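<p>To see what half precision saves, consider the feature buffer at the end of DenseNet-121's third dense block. For a 224×224 input it holds 1024 channels at 14×14 resolution. The sketch below estimates the size of that single tensor for an arbitrarily chosen batch of 64 in float32 and float16. It ignores every other activation, so it is an illustration rather than a full memory profile.</p>

```python
def activation_bytes(batch, channels, height, width, bytes_per_element):
    """Size in bytes of one N x C x H x W activation tensor."""
    return batch * channels * height * width * bytes_per_element

fp32 = activation_bytes(64, 1024, 14, 14, 4)  # float32: 4 bytes per element
fp16 = activation_bytes(64, 1024, 14, 14, 2)  # float16 / bfloat16: 2 bytes per element
print(f"float32: {fp32 / 2**20:.1f} MiB, float16: {fp16 / 2**20:.1f} MiB")
```

<p>This prints <code>float32: 49.0 MiB, float16: 24.5 MiB</code>. Every stored activation is halved in the same way, which compounds with the shared-buffer and checkpointing savings.</p>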
</section>
</section>
<section id="training-considerations" class="level2">
<h2 class="anchored" data-anchor-id="training-considerations" id="training-considerations">Training Considerations</h2>
<section id="hyperparameter-selection" class="level3">
<h3 class="anchored" data-anchor-id="hyperparameter-selection" id="hyperparameter-selection">Hyperparameter Selection</h3>
<p>Training DenseNet effectively requires careful attention to several hyperparameters:</p>
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Warning</span>Critical Hyperparameters
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li><strong>Growth Rate (k)</strong>: Typically ranges from 12 to 48. Smaller values promote parameter efficiency but may limit representational capacity.</li>
<li><strong>Compression Factor (θ)</strong>: Usually set to 0.5, balancing computational efficiency with information preservation.</li>
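<li><strong>Dropout Rate</strong>: The original paper adds dropout with rate 0.2 after each convolution (except the first) only when training without data augmentation; its augmented-data runs do not use dropout.</li>
<li><strong>Learning Rate Schedule</strong>: The standard step schedule used for ResNets works well. The original paper trains with SGD at an initial learning rate of 0.1, dividing it by 10 at 50% and 75% of the epochs on CIFAR and at epochs 30 and 60 (of 90) on ImageNet.</li>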
</ul>
</div>
</div>
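<p>The growth-rate trade-off can be made concrete by counting the parameters of a single bottleneck dense layer. Assuming the BN-ReLU-Conv(1×1, 4k)-BN-ReLU-Conv(3×3, k) layout without convolution biases, a layer with c input channels has c(2 + 4k) + 8k + 36k² parameters, so the 3×3 convolution's cost grows quadratically in k. <code>dense_layer_params</code> is an illustrative helper.</p>

```python
def dense_layer_params(in_channels, growth_rate, bn_size=4):
    """Learnable parameters of one BN-ReLU-Conv1x1-BN-ReLU-Conv3x3 dense layer."""
    inter = bn_size * growth_rate
    bn1 = 2 * in_channels                  # BatchNorm weight + bias
    conv1 = in_channels * inter            # 1x1 bottleneck conv, no bias
    bn2 = 2 * inter                        # BatchNorm weight + bias
    conv2 = inter * growth_rate * 3 * 3    # 3x3 conv producing k maps, no bias
    return bn1 + conv1 + bn2 + conv2

# Same 256-channel input, three growth rates from the typical 12-48 range
for k in (12, 32, 48):
    print(k, dense_layer_params(256, k))
```

<p>For a 256-channel input this gives 18,080, 70,400 and 132,992 parameters for k = 12, 32 and 48.</p>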
</section>
<section id="regularization-techniques" class="level3">
<h3 class="anchored" data-anchor-id="regularization-techniques" id="regularization-techniques">Regularization Techniques</h3>
<p>The original paper reports that DenseNet’s parameter efficiency makes it less prone to overfitting than comparable networks, but regularization still matters, especially on small datasets:</p>
<div id="training-example" class="cell" data-execution_count="5">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a>import torch.nn.functional as F</span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a>import torch.optim as optim</span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a>from torch.optim.lr_scheduler import MultiStepLR</span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a># Example training setup for DenseNet on CIFAR-10, following the original</span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a># paper's recipe: SGD with Nesterov momentum 0.9, weight decay 1e-4, and</span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a># the learning rate divided by 10 at 50% and 75% of training</span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a>num_epochs = 300</span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a>model = densenet121(num_classes=10)  # ImageNet-style layout with a 10-way head</span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a>optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9,</span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a>                      nesterov=True, weight_decay=1e-4)</span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a>scheduler = MultiStepLR(optimizer,</span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a>                        milestones=[num_epochs // 2, num_epochs * 3 // 4],</span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a>                        gamma=0.1)</span>
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-16"><a href="#cb5-16" aria-hidden="true" tabindex="-1"></a># train_loader is assumed to be a DataLoader over the (augmented) CIFAR-10</span>
<span id="cb5-17"><a href="#cb5-17" aria-hidden="true" tabindex="-1"></a># training set; the paper uses a batch size of 64</span>
<span id="cb5-18"><a href="#cb5-18" aria-hidden="true" tabindex="-1"></a>for epoch in range(num_epochs):</span>
<span id="cb5-19"><a href="#cb5-19" aria-hidden="true" tabindex="-1"></a>    model.train()  # dropout active, batch-norm statistics updated</span>
<span id="cb5-20"><a href="#cb5-20" aria-hidden="true" tabindex="-1"></a>    for data, target in train_loader:</span>
<span id="cb5-21"><a href="#cb5-21" aria-hidden="true" tabindex="-1"></a>        optimizer.zero_grad()</span>
<span id="cb5-22"><a href="#cb5-22" aria-hidden="true" tabindex="-1"></a>        output = model(data)</span>
<span id="cb5-23"><a href="#cb5-23" aria-hidden="true" tabindex="-1"></a>        loss = F.cross_entropy(output, target)</span>
<span id="cb5-24"><a href="#cb5-24" aria-hidden="true" tabindex="-1"></a>        loss.backward()</span>
<span id="cb5-25"><a href="#cb5-25" aria-hidden="true" tabindex="-1"></a>        optimizer.step()</span>
<span id="cb5-26"><a href="#cb5-26" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-27"><a href="#cb5-27" aria-hidden="true" tabindex="-1"></a>    scheduler.step()  # one learning-rate schedule step per epoch</span></code></pre></div></div>
</div>
<ol type="1">
<li><strong>Dropout</strong>: Applied within dense layers, particularly effective for preventing overfitting.</li>
<li><strong>Data Augmentation</strong>: Standard augmentation techniques remain highly effective.</li>
<li><strong>Weight Decay</strong>: The original paper uses a weight decay of 10<sup>-4</sup> with SGD; treat it as a starting point and tune it together with the learning rate.</li>
</ol>
</section>
</section>
<section id="applications-and-use-cases" class="level2">
<h2 class="anchored" data-anchor-id="applications-and-use-cases" id="applications-and-use-cases">Applications and Use Cases</h2>
<section id="computer-vision-tasks" class="level3">
<h3 class="anchored" data-anchor-id="computer-vision-tasks" id="computer-vision-tasks">Computer Vision Tasks</h3>
<p>DenseNet excels in various computer vision applications:</p>
<ul>
<li><strong>Image Classification</strong>: Strong performance on standard benchmarks with parameter efficiency</li>
<li><strong>Object Detection</strong>: When used as a backbone in detection frameworks like Faster R-CNN or YOLO</li>
<li><strong>Semantic Segmentation</strong>: The feature reuse properties make DenseNet particularly suitable for dense prediction tasks</li>
<li><strong>Medical Imaging</strong>: The parameter efficiency and strong representation learning make it popular for medical image analysis where data is often limited</li>
</ul>
</section>
<section id="transfer-learning" class="level3">
<h3 class="anchored" data-anchor-id="transfer-learning" id="transfer-learning">Transfer Learning</h3>
<p>DenseNet’s feature reuse properties make it particularly effective for transfer learning scenarios:</p>
<div id="transfer-learning-example" class="cell" data-execution_count="6">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a># Example: Transfer learning with a pre-trained DenseNet-121</span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a>import torch.nn as nn</span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a>import torch.optim as optim</span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a>import torchvision.models as models</span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a>num_classes_new_task = 5  # placeholder: set to the number of classes in your dataset</span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a># Load ImageNet weights (weights= API of torchvision 0.13 and later;</span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a># it replaces the deprecated pretrained=True argument)</span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a>weights = models.DenseNet121_Weights.IMAGENET1K_V1</span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a>model = models.densenet121(weights=weights)</span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a># Freeze feature extraction layers</span>
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a>for param in model.features.parameters():</span>
<span id="cb6-15"><a href="#cb6-15" aria-hidden="true" tabindex="-1"></a>    param.requires_grad = False</span>
<span id="cb6-16"><a href="#cb6-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-17"><a href="#cb6-17" aria-hidden="true" tabindex="-1"></a># Replace classifier for the new task (1024 input features for DenseNet-121)</span>
<span id="cb6-18"><a href="#cb6-18" aria-hidden="true" tabindex="-1"></a>num_features = model.classifier.in_features</span>
<span id="cb6-19"><a href="#cb6-19" aria-hidden="true" tabindex="-1"></a>model.classifier = nn.Linear(num_features, num_classes_new_task)</span>
<span id="cb6-20"><a href="#cb6-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-21"><a href="#cb6-21" aria-hidden="true" tabindex="-1"></a># Only classifier parameters are optimized; the frozen features stay fixed</span>
<span id="cb6-22"><a href="#cb6-22" aria-hidden="true" tabindex="-1"></a>optimizer = optim.Adam(model.classifier.parameters(), lr=0.001)</span>
<span id="cb6-23"><a href="#cb6-23" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-24"><a href="#cb6-24" aria-hidden="true" tabindex="-1"></a># Input preprocessing that matches the pre-trained weights</span>
<span id="cb6-25"><a href="#cb6-25" aria-hidden="true" tabindex="-1"></a>preprocess = weights.transforms()</span></code></pre></div></div>
</div>
</section>
</section>
<section id="comparison-with-other-architectures" class="level2">
<h2 class="anchored" data-anchor-id="comparison-with-other-architectures" id="comparison-with-other-architectures">Comparison with Other Architectures</h2>
<section id="densenet-vs-resnet" class="level3">
<h3 class="anchored" data-anchor-id="densenet-vs-resnet" id="densenet-vs-resnet">DenseNet vs ResNet</h3>
<div id="tbl-densenet-resnet" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-tbl figure">
<figcaption class="quarto-float-caption-top quarto-float-caption quarto-float-tbl" id="tbl-densenet-resnet-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Table&nbsp;2: DenseNet vs ResNet Comparison
</figcaption>
<div aria-describedby="tbl-densenet-resnet-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<table class="caption-top table">
<thead>
<tr class="header">
<th>Aspect</th>
<th>DenseNet</th>
<th>ResNet</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>Parameter Efficiency</td>
<td>✅ Better</td>
<td>❌ More parameters</td>
</tr>
<tr class="even">
<td>Gradient Flow</td>
<td>✅ Stronger</td>
<td>✅ Good</td>
</tr>
<tr class="odd">
<td>Memory Requirements</td>
<td>❌ Higher during training</td>
<td>✅ Lower</td>
</tr>
<tr class="even">
<td>Implementation</td>
<td>❌ More complex</td>
<td>✅ Simpler</td>
</tr>
<tr class="odd">
<td>Feature Reuse</td>
<td>✅ Excellent</td>
<td>❌ Limited</td>
</tr>
</tbody>
</table>
</div>
</figure>
</div>
</section>
<section id="densenet-vs-inception" class="level3">
<h3 class="anchored" data-anchor-id="densenet-vs-inception" id="densenet-vs-inception">DenseNet vs Inception</h3>
<p><strong>DenseNet Advantages</strong>:</p>
<ul>
<li>Simpler architectural design</li>
<li>More consistent performance across tasks</li>
<li>Better parameter efficiency</li>
</ul>
<p><strong>Inception Advantages</strong>:</p>
<ul>
<li>More flexible computational budget allocation</li>
<li>Better computational efficiency in some scenarios</li>
</ul>
</section>
</section>
<section id="recent-developments-and-variants" class="level2">
<h2 class="anchored" data-anchor-id="recent-developments-and-variants" id="recent-developments-and-variants">Recent Developments and Variants</h2>
<section id="densenet-extensions" class="level3">
<h3 class="anchored" data-anchor-id="densenet-extensions" id="densenet-extensions">DenseNet Extensions</h3>
<p>Several extensions and improvements to DenseNet have been proposed:</p>
<ul>
<li><strong>CondenseNet</strong>: Introduces learned sparse connectivity to improve computational efficiency while maintaining the benefits of dense connections</li>
<li><strong>PeleeNet</strong>: Optimizes DenseNet for mobile and embedded applications through architectural modifications and compression techniques</li>
<li><strong>DenseNet with Attention</strong>: Incorporates attention mechanisms to further improve feature selection and representation learning</li>
</ul>
</section>
<section id="integration-with-modern-techniques" class="level3">
<h3 class="anchored" data-anchor-id="integration-with-modern-techniques" id="integration-with-modern-techniques">Integration with Modern Techniques</h3>
<p>DenseNet continues to be relevant in modern deep learning through integration with contemporary techniques:</p>
<ol type="1">
<li><strong>Neural Architecture Search (NAS)</strong>: Concatenation-based connectivity similar to DenseNet’s appears in some NAS-discovered architectures</li>
<li><strong>Vision Transformers</strong>: Some hybrid approaches combine DenseNet-style connectivity with transformer architectures</li>
<li><strong>EfficientNet Integration</strong>: Combining DenseNet principles with compound scaling methods for improved efficiency</li>
</ol>
</section>
</section>
<section id="best-practices-and-recommendations" class="level2">
<h2 class="anchored" data-anchor-id="best-practices-and-recommendations" id="best-practices-and-recommendations">Best Practices and Recommendations</h2>
<section id="architecture-design" class="level3">
<h3 class="anchored" data-anchor-id="architecture-design" id="architecture-design">Architecture Design</h3>
<p>When designing DenseNet-based architectures:</p>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Design Guidelines
</div>
</div>
<div class="callout-body-container callout-body">
<ol type="1">
<li><strong>Growth Rate Selection</strong>: Start with k=32 for large-scale tasks, k=12 for smaller datasets or computational constraints</li>
<li><strong>Block Configuration</strong>: Use proven configurations ((6, 12, 24, 16) for DenseNet-121) as starting points, adjusting based on specific requirements</li>
<li><strong>Compression Strategy</strong>: Maintain θ=0.5 unless specific memory or computational constraints require adjustment</li>
</ol>
</div>
</div>
</section>
<section id="implementation-guidelines" class="level3">
<h3 class="anchored" data-anchor-id="implementation-guidelines" id="implementation-guidelines">Implementation Guidelines</h3>
<ol type="1">
<li><strong>Memory Management</strong>: Implement efficient concatenation operations and consider memory-efficient variants for resource-constrained environments</li>
<li><strong>Batch Normalization</strong>: Ensure proper batch normalization placement and initialization for optimal training dynamics</li>
<li><strong>Regularization</strong>: Apply dropout judiciously, particularly in deeper layers and for smaller datasets</li>
</ol>
</section>
<section id="training-optimization" class="level3">
<h3 class="anchored" data-anchor-id="training-optimization" id="training-optimization">Training Optimization</h3>
<ol type="1">
<li><strong>Learning Rate</strong>: Start from the standard SGD recipe (initial learning rate 0.1, momentum 0.9, step decay) and adjust only if the training curves call for it</li>
<li><strong>Batch Size</strong>: Use larger batch sizes when possible to leverage the batch normalization layers effectively</li>
<li><strong>Augmentation</strong>: Standard augmentation techniques remain highly effective and often crucial for preventing overfitting</li>
</ol>
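<p>As one concrete reading of “standard augmentation”, the sketch below implements the pad-and-random-crop plus horizontal-flip scheme commonly used for CIFAR training. The 4-pixel zero padding, the 0.5 flip probability, and the helper name <code>augment</code> are assumptions for illustration; in a PyTorch pipeline, torchvision’s <code>RandomCrop(32, padding=4)</code> and <code>RandomHorizontalFlip()</code> transforms cover the same ground.</p>

```python
import numpy as np


def augment(image, pad=4, rng=None):
    """Random crop after zero padding, then a random horizontal flip.

    image: array of shape (H, W, C); the output has the same shape.
    """
    rng = np.random.default_rng() if rng is None else rng
    h, w, _ = image.shape
    # Zero-pad height and width only, not the channel axis
    padded = np.pad(image, ((pad, pad), (pad, pad), (0, 0)), mode="constant")
    # Choose a random top-left corner for an H x W crop
    top = rng.integers(0, 2 * pad + 1)
    left = rng.integers(0, 2 * pad + 1)
    crop = padded[top:top + h, left:left + w]
    # Mirror left to right with probability 0.5
    if rng.integers(0, 2) == 1:
        crop = crop[:, ::-1]
    return crop


rng = np.random.default_rng(0)
image = rng.random((32, 32, 3))
print(augment(image, rng=rng).shape)  # (32, 32, 3)
```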
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>DenseNet represents a fundamental advancement in convolutional neural network design, demonstrating that architectural innovations can achieve better performance with fewer parameters through improved connectivity patterns. The dense connectivity paradigm offers several key advantages: enhanced gradient flow, feature reuse, parameter efficiency, and implicit deep supervision.</p>
<p>While DenseNet introduces some implementation complexity and memory considerations, these challenges are outweighed by its strong empirical performance and theoretical elegance. The architecture’s influence extends beyond its direct applications, inspiring subsequent architectural innovations and contributing to our understanding of effective connectivity patterns in deep networks.</p>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Key Takeaways
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li>DenseNet achieves better parameter efficiency through feature reuse</li>
<li>Dense connectivity ensures robust gradient flow and training stability</li>
<li>Memory optimization strategies are crucial for practical implementation</li>
<li>The architecture remains relevant through integration with modern techniques</li>
</ul>
</div>
</div>
<p>The continued relevance of DenseNet in modern deep learning, through extensions, variants, and integration with contemporary techniques, underscores its fundamental contribution to the field. For practitioners, DenseNet offers a compelling choice when parameter efficiency, strong performance, and architectural elegance are priorities.</p>
<p>As the field continues to evolve, the principles underlying DenseNet (maximizing information flow, promoting feature reuse, and enabling efficient gradient propagation) remain valuable guideposts for future architectural innovations. The dense connectivity pattern it pioneered continues to inform later architecture design, securing its lasting place in deep learning research and practice.</p>



</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[DenseNet: Densely Connected Convolutional Networks]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/models/dense-net/dense-net-summary/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/models/dense-net/dense-net-summary/</guid>
      <pubDate>Sat, 19 Jul 2025 00:00:00 GMT</pubDate>
      
      <category>research</category>
      <category>intermediate</category>
      <content:encoded><![CDATA[






<section id="densenet-densely-connected-convolutional-networks" class="level1 page-columns page-full">

<p><img src="https://theja-vanka.github.io/blogs/posts/models/dense-net/dense-net-summary/densenet.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>DenseNet (Densely Connected Convolutional Networks) represents a significant advancement in deep learning architecture design, introduced by Gao Huang, Zhuang Liu, Laurens van der Maaten, and Kilian Q. Weinberger in their 2017 paper “Densely Connected Convolutional Networks.” This architecture addresses fundamental challenges in training very deep neural networks while achieving remarkable efficiency and performance across various computer vision tasks.</p>
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Important</span>Key Innovation
</div>
</div>
<div class="callout-body-container callout-body">
<p>The core innovation of DenseNet lies in its dense connectivity pattern, where each layer receives feature maps from all preceding layers and passes its own feature maps to all subsequent layers.</p>
</div>
</div>
<p>This seemingly simple modification to traditional convolutional architectures yields profound improvements in gradient flow, feature reuse, and parameter efficiency.</p>
</section>
<section id="the-problem-with-traditional-deep-networks" class="level2">
<h2 class="anchored" data-anchor-id="the-problem-with-traditional-deep-networks" id="the-problem-with-traditional-deep-networks">The Problem with Traditional Deep Networks</h2>
<p>Before understanding DenseNet’s innovations, it’s crucial to recognize the challenges that deep convolutional networks face as they grow deeper. Traditional architectures like VGG and early versions of ResNet suffered from several key issues:</p>
<div class="tabset-margin-container"></div><div class="panel-tabset">
<ul class="nav nav-tabs" role="tablist"><li class="nav-item" role="presentation"><a class="nav-link active" id="tabset-1-1-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-1" role="tab" aria-controls="tabset-1-1" aria-selected="true" href="">Vanishing Gradients</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-1-2-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-2" role="tab" aria-controls="tabset-1-2" aria-selected="false" href="">Information Loss</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-1-3-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-3" role="tab" aria-controls="tabset-1-3" aria-selected="false" href="">Parameter Inefficiency</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-1-4-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-4" role="tab" aria-controls="tabset-1-4" aria-selected="false" href="">Feature Reuse Limitations</a></li></ul>
<div class="tab-content">
<div id="tabset-1-1" class="tab-pane active" role="tabpanel" aria-labelledby="tabset-1-1-tab">
<p>As networks become deeper, gradients can become exponentially smaller during backpropagation, making it difficult to train the early layers effectively. This leads to poor convergence and suboptimal performance.</p>
</div>
<div id="tabset-1-2" class="tab-pane" role="tabpanel" aria-labelledby="tabset-1-2-tab">
<p>In conventional feed-forward architectures, information flows linearly from input to output. As information passes through multiple layers, important details from earlier layers can be lost or diluted.</p>
</div>
<div id="tabset-1-3" class="tab-pane" role="tabpanel" aria-labelledby="tabset-1-3-tab">
<p>Many deep networks contain redundant parameters that don’t contribute meaningfully to the final prediction. This inefficiency leads to larger models without proportional performance gains.</p>
</div>
<div id="tabset-1-4" class="tab-pane" role="tabpanel" aria-labelledby="tabset-1-4-tab">
<p>Traditional architectures don’t effectively reuse features computed in earlier layers, leading to redundant computations and missed opportunities for feature combination.</p>
</div>
</div>
</div>
</section>
<section id="sec-architecture" class="level2 page-columns page-full">
<h2 class="anchored" data-anchor-id="sec-architecture" id="sec-architecture">DenseNet Architecture Overview</h2>
<p>DenseNet addresses these challenges through its distinctive dense connectivity pattern. The architecture is built around dense blocks, where each layer within a block receives inputs from all preceding layers in that block.</p>
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div id="fig-densenet-architecture" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-fig figure">
<div aria-describedby="fig-densenet-architecture-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<div>
<pre class="mermaid mermaid-js" data-label="fig-densenet-architecture">graph TD
    A[Input Image] --&gt; B[Initial Conv Layer]
    B --&gt; C[Dense Block 1]
    C --&gt; D[Transition Layer 1]
    D --&gt; E[Dense Block 2]
    E --&gt; F[Transition Layer 2]
    F --&gt; G[Dense Block 3]
    G --&gt; H[Transition Layer 3]
    H --&gt; I[Dense Block 4]
    I --&gt; J[Global Average Pooling]
    J --&gt; K[Classifier]
    K --&gt; L[Output]
    
    style C fill:#e1f5fe
    style E fill:#e1f5fe
    style G fill:#e1f5fe
    style I fill:#e1f5fe
</pre>
</div>
</div>
<figcaption class="quarto-float-caption-bottom quarto-float-caption quarto-float-fig" id="fig-densenet-architecture-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Figure&nbsp;1: DenseNet Architecture Overview showing dense blocks and transition layers
</figcaption>
</figure>
</div>
</div>
</div>
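<p>The figure leaves out feature-map sizes. Assuming the ImageNet configuration from the original paper (224×224 input, a 7×7 stride-2 convolution and a 3×3 stride-2 max pooling in the stem, and 2×2 stride-2 average pooling in each transition layer), the spatial resolution at each dense block can be traced with integer arithmetic. <code>spatial_plan</code> is a hypothetical helper.</p>

```python
def spatial_plan(input_size=224):
    """Feature-map side length at each dense block of an ImageNet DenseNet."""
    size = input_size // 2  # 7x7 convolution, stride 2
    size //= 2              # 3x3 max pooling, stride 2
    plan = []
    for block in range(1, 5):
        plan.append((f"dense_block_{block}", size))
        if block != 4:      # a transition layer follows blocks 1 to 3
            size //= 2      # 2x2 average pooling, stride 2
    return plan


print(spatial_plan())
# [('dense_block_1', 56), ('dense_block_2', 28), ('dense_block_3', 14), ('dense_block_4', 7)]
```

<p>Global average pooling then collapses the final 7×7 maps into one vector per image before the classifier. The exact halving assumes size-preserving padding before each stride, which holds for the standard ImageNet stem.</p>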
<section id="dense-blocks" class="level3">
<h3 class="anchored" data-anchor-id="dense-blocks" id="dense-blocks">Dense Blocks</h3>
<p>The fundamental building unit of DenseNet is the dense block. Within each dense block, the <span class="math inline">\(l\)</span>-th layer receives feature maps from all preceding layers <span class="math inline">\((x_0, x_1, ..., x_{l-1})\)</span> and produces <span class="math inline">\(k\)</span> feature maps as output.</p>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Composite Function
</div>
</div>
<div class="callout-body-container callout-body">
<p>The composite function <span class="math inline">\(H_l\)</span> typically consists of:</p>
<ol type="1">
<li><strong>Batch Normalization (BN)</strong></li>
<li><strong>ReLU activation</strong></li>
<li><strong>3×3 Convolution</strong></li>
</ol>
</div>
</div>
<p>The key equation governing dense connectivity is:</p>
<p><span class="math display">\[
x_l = H_l([x_0, x_1, ..., x_{l-1}])
\]</span></p>
<p>Where <span class="math inline">\([x_0, x_1, ..., x_{l-1}]\)</span> represents the concatenation of feature maps from layers 0, 1, …, <span class="math inline">\(l-1\)</span>.</p>
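<p>The equation translates almost directly into code. The NumPy sketch below is illustrative only: it uses a simplified batch normalization (per-channel standardization of a single example, with no learned scale or shift), random untrained weights, and a 1×1 convolution in place of the 3×3 convolution to keep the example short. The names <code>make_layer</code>, <code>build_dense_block</code>, and <code>dense_block_forward</code> are hypothetical.</p>

```python
import numpy as np

rng = np.random.default_rng(0)


def make_layer(in_channels, growth_rate):
    """Hypothetical H_l: simplified BN, then ReLU, then a 1x1 convolution."""
    weight = rng.standard_normal((growth_rate, in_channels)) * 0.1

    def h(x):  # x has shape (channels, height, width)
        mean = x.mean(axis=(1, 2), keepdims=True)
        std = x.std(axis=(1, 2), keepdims=True) + 1e-5
        x = np.maximum((x - mean) / std, 0.0)  # normalize, then ReLU
        # A 1x1 convolution is a matrix product over the channel axis
        return np.tensordot(weight, x, axes=([1], [0]))

    return h


def build_dense_block(k0, num_layers, growth_rate):
    # The layer at 0-based index i sees k0 + growth_rate * i input channels
    return [make_layer(k0 + growth_rate * i, growth_rate) for i in range(num_layers)]


def dense_block_forward(x0, layers):
    features = [x0]
    for h in layers:
        inputs = np.concatenate(features, axis=0)  # [x_0, x_1, ..., x_{l-1}]
        features.append(h(inputs))                 # x_l = H_l([x_0, ..., x_{l-1}])
    return np.concatenate(features, axis=0)


x0 = rng.standard_normal((16, 8, 8))  # k_0 = 16 channels on an 8x8 grid
layers = build_dense_block(16, num_layers=4, growth_rate=12)
out = dense_block_forward(x0, layers)
print(out.shape)  # (64, 8, 8): 16 + 4 * 12 channels
```

<p>The first 16 channels of <code>out</code> are <code>x0</code> itself, unchanged: concatenation carries every earlier feature map forward, which is what lets later layers reuse them.</p>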
</section>
<section id="growth-rate" class="level3">
<h3 class="anchored" data-anchor-id="growth-rate" id="growth-rate">Growth Rate</h3>
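<p>The growth rate (<span class="math inline">\(k\)</span>) is a hyperparameter that determines how many feature maps each layer adds to the “collective knowledge” of the network. Because each layer’s input is the concatenation of everything before it, the <span class="math inline">\(l\)</span>-th layer of a block receives <span class="math inline">\(k_0 + k(l-1)\)</span> input feature maps, where <span class="math inline">\(k_0\)</span> is the number of channels entering the block. Even with small growth rates (<span class="math inline">\(k=12\)</span> or <span class="math inline">\(k=32\)</span>), DenseNet achieves excellent performance because each layer has access to all preceding feature maps within the block.</p>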
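<p>Because each layer consumes the concatenation of all earlier outputs, the <span class="math inline">\(l\)</span>-th layer in a block sees <span class="math inline">\(k_0 + k(l-1)\)</span> input feature maps, where <span class="math inline">\(k_0\)</span> is the block’s input width. The hypothetical helper below tabulates this for one block, using the illustrative values <span class="math inline">\(k_0 = 24\)</span> and <span class="math inline">\(k = 12\)</span>.</p>

```python
def layer_input_channels(k0, growth_rate, num_layers):
    """Input feature maps seen by layers 1..num_layers of one dense block."""
    return [k0 + growth_rate * (l - 1) for l in range(1, num_layers + 1)]


widths = layer_input_channels(k0=24, growth_rate=12, num_layers=6)
print(widths)           # [24, 36, 48, 60, 72, 84]
print(widths[-1] + 12)  # 96 channels leave the block
```

<p>Each layer still emits only 12 maps, so individual layers stay narrow even though their inputs widen linearly through the block.</p>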
</section>
<section id="transition-layers" class="level3 page-columns page-full">
<h3 class="anchored" data-anchor-id="transition-layers" id="transition-layers">Transition Layers</h3>
<p>Between dense blocks, transition layers perform dimensionality reduction and spatial downsampling:</p>

<div class="no-row-height column-margin column-container"><div class="">
<p><strong>Transition Layer Components:</strong></p>
<ul>
<li>Batch Normalization</li>
<li>1×1 Convolution (channel reduction)</li>
<li>2×2 Average Pooling (spatial downsampling)</li>
</ul>
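</div></div><p>The compression factor <span class="math inline">\(\theta\)</span> (typically 0.5) controls how much the number of channels is reduced in transition layers: if a dense block outputs <span class="math inline">\(m\)</span> feature maps, the following transition layer produces <span class="math inline">\(\lfloor \theta m \rfloor\)</span> of them, which helps control model complexity.</p>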
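<p>A transition layer can be sketched in NumPy as follows. This illustration rests on simplifying assumptions: batch normalization is reduced to per-channel standardization of a single example, the 1×1 convolution uses random stand-in weights, and the ReLU that some implementations (torchvision, for example) place after batch normalization is omitted to match the component list. The helper name <code>transition</code> is hypothetical.</p>

```python
import math

import numpy as np


def transition(x, theta=0.5, rng=None):
    """Sketch of a transition layer; x has shape (m, H, W) with H and W even."""
    rng = np.random.default_rng(0) if rng is None else rng
    m, h, w = x.shape
    out_channels = math.floor(theta * m)  # compression: floor(theta * m) channels
    # Simplified batch norm: standardize each channel over its spatial positions
    x = (x - x.mean(axis=(1, 2), keepdims=True)) / (x.std(axis=(1, 2), keepdims=True) + 1e-5)
    # 1x1 convolution with random stand-in weights: (m, H, W) to (out_channels, H, W)
    weight = rng.standard_normal((out_channels, m)) * 0.1
    x = np.tensordot(weight, x, axes=([1], [0]))
    # 2x2 average pooling, stride 2: group pixels into 2x2 tiles and average each tile
    x = x.reshape(out_channels, h // 2, 2, w // 2, 2).mean(axis=(2, 4))
    return x


# First transition of DenseNet-121: 256 channels at 56x56 become 128 at 28x28
y = transition(np.random.default_rng(1).standard_normal((256, 56, 56)))
print(y.shape)  # (128, 28, 28)
```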
</section>
</section>
<section id="sec-innovations" class="level2">
<h2 class="anchored" data-anchor-id="sec-innovations" id="sec-innovations">Key Innovations and Benefits</h2>
<section id="enhanced-gradient-flow" class="level3">
<h3 class="anchored" data-anchor-id="enhanced-gradient-flow" id="enhanced-gradient-flow">Enhanced Gradient Flow</h3>
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div id="fig-gradient-flow" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-fig figure">
<div aria-describedby="fig-gradient-flow-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<div>
<pre class="mermaid mermaid-js" data-label="fig-gradient-flow">graph LR
    subgraph id2[Traditional Network]
        A1[Layer 1] --&gt; A2[Layer 2] --&gt; A3[Layer 3] --&gt; A4[Layer 4]
    end
    
    subgraph id1[DenseNet]
        B1[Layer 1] --&gt; B2[Layer 2]
        B1 --&gt; B3[Layer 3]
        B1 --&gt; B4[Layer 4]
        B2 --&gt; B3
        B2 --&gt; B4
        B3 --&gt; B4
    end

    style id1 fill:#ffffff
    style id2 fill:#ffffff
    style A1 fill:#c8e6c9
    style A2 fill:#c8e6c9
    style A3 fill:#c8e6c9
    style A4 fill:#c8e6c9
    style B1 fill:#c8e6c9
    style B2 fill:#c8e6c9
    style B3 fill:#c8e6c9
    style B4 fill:#c8e6c9
</pre>
</div>
</div>
<figcaption class="quarto-float-caption-bottom quarto-float-caption quarto-float-fig" id="fig-gradient-flow-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Figure&nbsp;2: Comparison of gradient flow in traditional networks vs DenseNet
</figcaption>
</figure>
</div>
</div>
</div>
<p>DenseNet’s dense connections create multiple short paths between any two layers, significantly improving gradient flow during backpropagation. This addresses the vanishing gradient problem that plagued earlier deep architectures.</p>
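<p>Counting the direct connections makes the difference concrete. The sketch below enumerates (source, target) pairs for a block of <span class="math inline">\(L\)</span> layers, treating the block input <span class="math inline">\(x_0\)</span> as node 0; a dense block then has <span class="math inline">\(L(L+1)/2\)</span> direct connections versus <span class="math inline">\(L\)</span> for a plain chain. Figure 2 does not draw the block input, which is why it shows 6 rather than 10 edges for four layers. The helper names are hypothetical.</p>

```python
def dense_connections(num_layers):
    """(source, target) pairs when layer l receives x_0, ..., x_{l-1}."""
    return [(src, dst) for dst in range(1, num_layers + 1) for src in range(dst)]


def chain_connections(num_layers):
    """(source, target) pairs when layer l receives only x_{l-1}."""
    return [(dst - 1, dst) for dst in range(1, num_layers + 1)]


for L in (4, 12, 24):
    print(L, len(chain_connections(L)), len(dense_connections(L)))
```

<p>This prints 4 vs 10, 12 vs 78, and 24 vs 300. Because the block output concatenates every <span class="math inline">\(x_l\)</span>, each layer also has a one-hop path to the block output, so gradients reach it without passing through the other layers of the block.</p>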
</section>
<section id="feature-reuse-and-efficiency" class="level3">
<h3 class="anchored" data-anchor-id="feature-reuse-and-efficiency" id="feature-reuse-and-efficiency">Feature Reuse and Efficiency</h3>
<p>The dense connectivity pattern maximizes information flow and feature reuse throughout the network. Later layers can directly access features from all earlier layers, so the network does not need to relearn redundant feature maps.</p>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Parameter Efficiency
</div>
</div>
<div class="callout-body-container callout-body">
<p>In the original paper, DenseNet-201 with roughly 20M parameters reaches about the same ImageNet validation error as ResNet-101 with more than 40M parameters.</p>
</div>
</div>
</section>
<section id="regularization-effect" class="level3">
<h3 class="anchored" data-anchor-id="regularization-effect" id="regularization-effect">Regularization Effect</h3>
<p>The dense connections inherently provide a regularization effect. Since each layer contributes to multiple subsequent layers’ inputs, the network is less likely to overfit to specific pathways.</p>
</section>
</section>
<section id="sec-variants" class="level2">
<h2 class="anchored" data-anchor-id="sec-variants" id="sec-variants">DenseNet Variants and Configurations</h2>
<section id="standard-densenet-architectures" class="level3">
<h3 class="anchored" data-anchor-id="standard-densenet-architectures" id="standard-densenet-architectures">Standard DenseNet Architectures</h3>
<div id="tbl-densenet-configs" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-tbl figure">
<figcaption class="quarto-float-caption-top quarto-float-caption quarto-float-tbl" id="tbl-densenet-configs-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Table&nbsp;1: DenseNet standard configurations
</figcaption>
<div aria-describedby="tbl-densenet-configs-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<table class="caption-top table">
<thead>
<tr class="header">
<th>Model</th>
<th>Dense Blocks</th>
<th>Layers per Block</th>
<th>Growth Rate</th>
<th>Parameters</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>DenseNet-121</td>
<td>4</td>
<td>[6, 12, 24, 16]</td>
<td>k=32</td>
<td>7.0M</td>
</tr>
<tr class="even">
<td>DenseNet-169</td>
<td>4</td>
<td>[6, 12, 32, 32]</td>
<td>k=32</td>
<td>12.6M</td>
</tr>
<tr class="odd">
<td>DenseNet-201</td>
<td>4</td>
<td>[6, 12, 48, 32]</td>
<td>k=32</td>
<td>18.3M</td>
</tr>
<tr class="even">
<td>DenseNet-264</td>
<td>4</td>
<td>[6, 12, 64, 48]</td>
<td>k=32</td>
<td>33.3M</td>
</tr>
</tbody>
</table>
</div>
</figure>
</div>
</section>
<section id="densenet-bc-bottleneck-and-compression" class="level3">
<h3 class="anchored" data-anchor-id="densenet-bc-bottleneck-and-compression" id="densenet-bc-bottleneck-and-compression">DenseNet-BC (Bottleneck and Compression)</h3>
<p>DenseNet-BC introduces two important modifications:</p>
<div class="columns">
<div class="column" style="width:50%;">
<p><strong>Bottleneck Layers</strong> Each 3×3 convolution is preceded by a 1×1 convolution that reduces the number of input channels to 4k, improving computational efficiency.</p>
</div><div class="column" style="width:50%;">
<p><strong>Compression</strong> Transition layers reduce the number of channels by a factor <span class="math inline">\(\theta &lt; 1\)</span>, typically 0.5, which helps control model size and computational cost.</p>
</div>
</div>
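<p>Whether the bottleneck actually saves computation depends on how wide a layer’s input has become. The sketch below counts convolution weights for one dense layer with and without the 1×1 bottleneck, ignoring batch-normalization parameters and biases; since every weight is applied once per output pixel, the same numbers are proportional to multiply-accumulates. The helper names are hypothetical.</p>

```python
def conv_params(in_ch, out_ch, kernel):
    # Weights of a bias-free kernel x kernel convolution
    return in_ch * out_ch * kernel * kernel


def plain_layer(in_ch, k):
    # BN-ReLU-Conv(3x3) straight from in_ch to k channels
    return conv_params(in_ch, k, 3)


def bottleneck_layer(in_ch, k):
    # 1x1 convolution down to 4k channels, then 3x3 convolution to k channels
    return conv_params(in_ch, 4 * k, 1) + conv_params(4 * k, k, 3)


for in_ch in (64, 256, 1024):
    print(in_ch, plain_layer(in_ch, 32), bottleneck_layer(in_ch, 32))
```

<p>With k = 32 this gives 18432 vs 45056 weights for 64 input channels, 73728 vs 69632 for 256, and 294912 vs 167936 for 1024. The plain cost 9ck exceeds the bottleneck cost 4kc + 36k² once c is larger than 7.2k, so the bottleneck pays off only after the concatenated input has grown past about 7.2 times the growth rate, which happens quickly inside a dense block.</p>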
</section>
</section>
<section id="implementation-details" class="level2">
<h2 class="anchored" data-anchor-id="implementation-details" id="implementation-details">Implementation Details</h2>
<section id="memory-optimization" class="level3">
<h3 class="anchored" data-anchor-id="memory-optimization" id="memory-optimization">Memory Optimization</h3>
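<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Memory-efficient dense block forward pass using gradient checkpointing</span></span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.utils.checkpoint <span class="im">import</span> checkpoint</span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> efficient_densenet_forward(x, layers):</span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Each layer maps its concatenated input to k new feature maps</span></span>
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a>    features <span class="op">=</span> [x]</span>
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> layer <span class="kw">in</span> layers:</span>
<span id="cb1-9"><a href="#cb1-9" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Concatenate inside the checkpointed function so the large concatenated</span></span>
<span id="cb1-10"><a href="#cb1-10" aria-hidden="true" tabindex="-1"></a>        <span class="co"># tensor is recomputed during backward instead of being stored</span></span>
<span id="cb1-11"><a href="#cb1-11" aria-hidden="true" tabindex="-1"></a>        <span class="kw">def</span> concat_and_apply(<span class="op">*</span>prev, layer<span class="op">=</span>layer):</span>
<span id="cb1-12"><a href="#cb1-12" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> layer(torch.cat(prev, dim<span class="op">=</span><span class="dv">1</span>))</span>
<span id="cb1-13"><a href="#cb1-13" aria-hidden="true" tabindex="-1"></a>        new_features <span class="op">=</span> checkpoint(concat_and_apply, <span class="op">*</span>features, use_reentrant<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb1-14"><a href="#cb1-14" aria-hidden="true" tabindex="-1"></a>        features.append(new_features)</span>
<span id="cb1-15"><a href="#cb1-15" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> torch.cat(features, dim<span class="op">=</span><span class="dv">1</span>)</span></code></pre></div></div>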
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Warning</span>Memory Considerations
</div>
</div>
<div class="callout-body-container callout-body">
<p>One challenge with DenseNet is memory consumption due to concatenating feature maps from all previous layers. Several optimization strategies address this:</p>
<ul>
<li><strong>Memory-Efficient Implementation</strong>: Using checkpointing and careful memory management</li>
<li><strong>Shared Memory Allocations</strong>: Reusing memory buffers for intermediate computations</li>
<li><strong>Gradient Checkpointing</strong>: Trading computation for memory</li>
</ul>
</div>
</div>
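<p>A rough count shows why memory is the main concern. If every layer materializes its own concatenated input, the stored channel maps in one block grow quadratically with depth, whereas writing all outputs into one shared pre-allocated buffer grows only linearly. The model below is idealized (counted per example and spatial position, ignoring batch-norm and ReLU intermediates), and its helper names are hypothetical.</p>

```python
def naive_concat_channels(k0, k, num_layers):
    # Every layer allocates a fresh concatenated input of width k0 + k * (l - 1)
    return sum(k0 + k * (l - 1) for l in range(1, num_layers + 1))


def shared_buffer_channels(k0, k, num_layers):
    # One pre-allocated buffer that every layer writes its k new maps into
    return k0 + k * num_layers


for L in (6, 12, 24, 48):
    print(L, naive_concat_channels(64, 32, L), shared_buffer_channels(64, 32, L))
```

<p>For a 64-channel input and k = 32, the naive count rises from 864 channel maps at 6 layers to 39168 at 48, while the shared buffer needs only 256 and 1600. That gap is what shared-memory and checkpointing strategies target.</p>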
</section>
<section id="training-considerations" class="level3">
<h3 class="anchored" data-anchor-id="training-considerations" id="training-considerations">Training Considerations</h3>
<p>Training DenseNet effectively requires attention to several factors:</p>
<ul>
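<li><strong>Learning Rate Schedule</strong>: The original paper uses SGD with step decay: an initial learning rate of 0.1, divided by 10 at 50% and 75% of the training epochs on CIFAR and at epochs 30 and 60 (of 90) on ImageNet</li>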
<li><strong>Batch Size</strong>: Due to memory requirements, smaller batch sizes are often necessary</li>
<li><strong>Data Augmentation</strong>: Standard techniques work well (random crops, horizontal flips, color jittering)</li>
</ul>
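<p>A step-decay schedule of this kind (start at 0.1 and divide by 10 at fixed milestones) fits in a few lines. The sketch below is a hypothetical helper, not a library API, and the milestone lists assume 300 CIFAR epochs and 90 ImageNet epochs; in PyTorch, <code>torch.optim.lr_scheduler.MultiStepLR</code> with the same milestones and <code>gamma=0.1</code> produces the same schedule.</p>

```python
import bisect


def step_lr(epoch, milestones, base_lr=0.1, gamma=0.1):
    """Learning rate for a 0-indexed epoch under step decay."""
    # bisect_right counts the milestones already reached at this epoch
    return base_lr * gamma ** bisect.bisect_right(milestones, epoch)


# Assumed CIFAR setting: 300 epochs, decays at 50% and 75% of training
cifar_milestones = [150, 225]
# Assumed ImageNet setting: 90 epochs, decays at epochs 30 and 60
imagenet_milestones = [30, 60]

for epoch in (0, 149, 150, 225, 299):
    print(epoch, f"{step_lr(epoch, cifar_milestones):.4g}")
```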
</section>
</section>
<section id="performance-and-benchmarks" class="level2">
<h2 class="anchored" data-anchor-id="performance-and-benchmarks" id="performance-and-benchmarks">Performance and Benchmarks</h2>
<section id="imagenet-classification" class="level3">
<h3 class="anchored" data-anchor-id="imagenet-classification" id="imagenet-classification">ImageNet Classification</h3>
<div id="cell-fig-imagenet-performance" class="cell" data-execution_count="1">
<details class="code-fold">
<summary>Code</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> matplotlib.pyplot <span class="im">as</span> plt</span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Data for different models</span></span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a>models <span class="op">=</span> [<span class="st">'DenseNet-121'</span>, <span class="st">'DenseNet-169'</span>, <span class="st">'DenseNet-201'</span>, <span class="st">'ResNet-50'</span>, <span class="st">'ResNet-101'</span>, <span class="st">'ResNet-152'</span>]</span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a>params <span class="op">=</span> [<span class="fl">7.0</span>, <span class="fl">12.6</span>, <span class="fl">18.3</span>, <span class="fl">25.6</span>, <span class="fl">44.5</span>, <span class="fl">60.2</span>]  <span class="co"># in millions</span></span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a>error_rates <span class="op">=</span> [<span class="fl">25.35</span>, <span class="fl">24.00</span>, <span class="fl">22.58</span>, <span class="fl">23.85</span>, <span class="fl">22.63</span>, <span class="fl">23.00</span>]</span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a><span class="co"># Create the plot</span></span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a>plt.figure(figsize<span class="op">=</span>(<span class="dv">10</span>, <span class="dv">6</span>))</span>
<span id="cb2-10"><a href="#cb2-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-11"><a href="#cb2-11" aria-hidden="true" tabindex="-1"></a><span class="co"># Plot each family separately so the legend shows one correctly colored entry per family</span></span>
<span id="cb2-12"><a href="#cb2-12" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> family, color <span class="kw">in</span> [(<span class="st">'DenseNet'</span>, <span class="st">'#2E8B57'</span>), (<span class="st">'ResNet'</span>, <span class="st">'#CD5C5C'</span>)]:</span>
<span id="cb2-13"><a href="#cb2-13" aria-hidden="true" tabindex="-1"></a>    idx <span class="op">=</span> [i <span class="cf">for</span> i, model <span class="kw">in</span> <span class="bu">enumerate</span>(models) <span class="cf">if</span> model.startswith(family)]</span>
<span id="cb2-14"><a href="#cb2-14" aria-hidden="true" tabindex="-1"></a>    plt.scatter([params[i] <span class="cf">for</span> i <span class="kw">in</span> idx], [error_rates[i] <span class="cf">for</span> i <span class="kw">in</span> idx],</span>
<span id="cb2-15"><a href="#cb2-15" aria-hidden="true" tabindex="-1"></a>                color<span class="op">=</span>color, s<span class="op">=</span><span class="dv">100</span>, alpha<span class="op">=</span><span class="fl">0.7</span>, label<span class="op">=</span>family)</span>
<span id="cb2-16"><a href="#cb2-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-17"><a href="#cb2-17" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> i, model <span class="kw">in</span> <span class="bu">enumerate</span>(models):</span>
<span id="cb2-18"><a href="#cb2-18" aria-hidden="true" tabindex="-1"></a>    plt.annotate(model, (params[i], error_rates[i]),</span>
<span id="cb2-19"><a href="#cb2-19" aria-hidden="true" tabindex="-1"></a>                xytext<span class="op">=</span>(<span class="dv">5</span>, <span class="dv">5</span>), textcoords<span class="op">=</span><span class="st">'offset points'</span>, fontsize<span class="op">=</span><span class="dv">9</span>)</span>
<span id="cb2-20"><a href="#cb2-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-21"><a href="#cb2-21" aria-hidden="true" tabindex="-1"></a>plt.xlabel(<span class="st">'Parameters (millions)'</span>)</span>
<span id="cb2-22"><a href="#cb2-22" aria-hidden="true" tabindex="-1"></a>plt.ylabel(<span class="st">'ImageNet Top-1 Error Rate (%)'</span>)</span>
<span id="cb2-23"><a href="#cb2-23" aria-hidden="true" tabindex="-1"></a>plt.title(<span class="st">'Parameter Efficiency Comparison'</span>)</span>
<span id="cb2-24"><a href="#cb2-24" aria-hidden="true" tabindex="-1"></a>plt.grid(<span class="va">True</span>, alpha<span class="op">=</span><span class="fl">0.3</span>)</span>
<span id="cb2-25"><a href="#cb2-25" aria-hidden="true" tabindex="-1"></a>plt.legend(loc<span class="op">=</span><span class="st">'upper right'</span>)</span>
<span id="cb2-26"><a href="#cb2-26" aria-hidden="true" tabindex="-1"></a>plt.show()</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-display">
<div id="fig-imagenet-performance" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-fig figure">
<div aria-describedby="fig-imagenet-performance-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<img src="https://theja-vanka.github.io/blogs/posts/models/dense-net/dense-net-summary/fig-imagenet-performance-output-1.png" width="863" height="523" class="figure-img">
</div>
<figcaption class="quarto-float-caption-bottom quarto-float-caption quarto-float-fig" id="fig-imagenet-performance-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Figure&nbsp;3: ImageNet top-1 error rates vs number of parameters for different architectures
</figcaption>
</figure>
</div>
</div>
</div>
</section>
<section id="cifar-datasets" class="level3">
<h3 class="anchored" data-anchor-id="cifar-datasets" id="cifar-datasets">CIFAR Datasets</h3>
<div class="quarto-layout-panel" data-layout-ncol="2">
<div class="quarto-layout-row">
<section id="cifar-10-results" class="level4 quarto-layout-cell" style="flex-basis: 50.0%;justify-content: flex-start;">
<h4 class="anchored" data-anchor-id="cifar-10-results">CIFAR-10 Results</h4>
<ul>
<li><strong>DenseNet-BC (L=190, k=40)</strong>: 3.46% error rate with standard data augmentation</li>
<li>The best augmented CIFAR-10 (C10+) result reported in the original paper</li>
</ul>
</section>
<section id="cifar-100-results" class="level4 quarto-layout-cell" style="flex-basis: 50.0%;justify-content: flex-start;">
<h4 class="anchored" data-anchor-id="cifar-100-results">CIFAR-100 Results</h4>
<ul>
<li><strong>DenseNet-BC (L=190, k=40)</strong>: 17.18% error rate with standard data augmentation</li>
<li>The best augmented CIFAR-100 (C100+) result reported in the original paper</li>
</ul>
</section>
</div>
</div>
</section>
</section>
<section id="applications-and-use-cases" class="level2">
<h2 class="anchored" data-anchor-id="applications-and-use-cases" id="applications-and-use-cases">Applications and Use Cases</h2>
<section id="computer-vision-tasks" class="level3">
<h3 class="anchored" data-anchor-id="computer-vision-tasks" id="computer-vision-tasks">Computer Vision Tasks</h3>
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div id="fig-applications" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-fig figure">
<div aria-describedby="fig-applications-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<div>
<pre class="mermaid mermaid-js" data-label="fig-applications">mindmap
  root((DenseNet Applications))
    Classification
      ImageNet
      Medical Imaging
      Remote Sensing
    Detection
      Object Detection
      Face Detection
      Autonomous Driving
    Segmentation
      Semantic Segmentation
      Medical Segmentation
      Industrial Inspection
    Transfer Learning
      Fine-grained Classification
      Domain Adaptation
      Few-shot Learning
</pre>
</div>
</div>
<figcaption class="quarto-float-caption-bottom quarto-float-caption quarto-float-fig" id="fig-applications-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Figure&nbsp;4: DenseNet applications across different computer vision tasks
</figcaption>
</figure>
</div>
</div>
</div>
</section>
<section id="domain-specific-adaptations" class="level3">
<h3 class="anchored" data-anchor-id="domain-specific-adaptations" id="domain-specific-adaptations">Domain-Specific Adaptations</h3>
<ul>
<li><strong>Medical Imaging</strong>: Parameter efficiency valuable when data is limited</li>
<li><strong>Remote Sensing</strong>: Multi-scale feature capture for satellite imagery</li>
<li><strong>Industrial Applications</strong>: Quality control and defect detection</li>
</ul>
</section>
</section>
<section id="advantages-and-limitations" class="level2">
<h2 class="anchored" data-anchor-id="advantages-and-limitations" id="advantages-and-limitations">Advantages and Limitations</h2>
<div class="grid">
<section id="advantages" class="level3 g-col-6">
<h3 class="anchored" data-anchor-id="advantages" id="advantages">✅ Advantages</h3>
<ul>
<li><strong>Parameter Efficiency</strong>: Better performance with fewer parameters</li>
<li><strong>Strong Gradient Flow</strong>: Robust gradient propagation</li>
<li><strong>Feature Reuse</strong>: Later layers reuse earlier features directly instead of relearning them</li>
<li><strong>Implicit Regularization</strong>: Natural overfitting resistance</li>
<li><strong>Transfer Learning</strong>: Features transfer well to new domains</li>
</ul>
</section>
<section id="limitations" class="level3 g-col-6">
<h3 class="anchored" data-anchor-id="limitations" id="limitations">⚠️ Limitations</h3>
<ul>
<li><strong>Memory Consumption</strong>: Higher memory usage due to concatenations</li>
<li><strong>Computational Overhead</strong>: Feature concatenation operations</li>
<li><strong>Training Complexity</strong>: Requires careful hyperparameter tuning</li>
<li><strong>Scalability</strong>: Memory constraints for very large inputs</li>
</ul>
</section>
</div>
</section>
<section id="comparison-with-other-architectures" class="level2">
<h2 class="anchored" data-anchor-id="comparison-with-other-architectures" id="comparison-with-other-architectures">Comparison with Other Architectures</h2>
<section id="densenet-vs.-resnet" class="level3">
<h3 class="anchored" data-anchor-id="densenet-vs.-resnet" id="densenet-vs.-resnet">DenseNet vs.&nbsp;ResNet</h3>
<div id="tbl-densenet-resnet" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-tbl figure">
<figcaption class="quarto-float-caption-top quarto-float-caption quarto-float-tbl" id="tbl-densenet-resnet-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Table&nbsp;2: Comparison between DenseNet and ResNet architectures
</figcaption>
<div aria-describedby="tbl-densenet-resnet-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<table class="caption-top table">
<thead>
<tr class="header">
<th>Aspect</th>
<th>DenseNet</th>
<th>ResNet</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><strong>Connections</strong></td>
<td>Feature concatenation</td>
<td>Element-wise addition</td>
</tr>
<tr class="even">
<td><strong>Parameter Efficiency</strong></td>
<td>Higher: fewer parameters for comparable accuracy</td>
<td>Lower: more parameters for comparable accuracy</td>
</tr>
<tr class="odd">
<td><strong>Memory</strong></td>
<td>Higher usage</td>
<td>Lower usage</td>
</tr>
<tr class="even">
<td><strong>Feature Reuse</strong></td>
<td>Explicit reuse</td>
<td>Limited reuse</td>
</tr>
</tbody>
</table>
</div>
</figure>
</div>
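<p>The first row of the table is the crux, and it is easy to see in array terms. In the NumPy sketch below, random arrays stand in for a block input and for the outputs of a residual branch and a dense layer; the shapes (with k = 32) are illustrative assumptions.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((64, 8, 8))    # block input with 64 channels
f_x = rng.standard_normal((64, 8, 8))  # stand-in for a residual branch output
g_x = rng.standard_normal((32, 8, 8))  # stand-in for a dense layer output, k = 32

resnet_out = x + f_x                             # addition: width stays 64
densenet_out = np.concatenate([x, g_x], axis=0)  # concatenation: width grows to 96

print(resnet_out.shape, densenet_out.shape)      # (64, 8, 8) (96, 8, 8)
print(np.array_equal(densenet_out[:64], x))      # True
```

<p>Addition keeps the width fixed but blends the input into the sum, while concatenation preserves the input verbatim at the cost of a wider tensor; this trade-off underlies the memory and feature-reuse rows.</p>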
</section>
</section>
<section id="recent-developments-and-extensions" class="level2">
<h2 class="anchored" data-anchor-id="recent-developments-and-extensions" id="recent-developments-and-extensions">Recent Developments and Extensions</h2>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Modern Extensions
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li><strong>3D DenseNet</strong>: For video analysis and 3D medical imaging</li>
<li><strong>Attention-enhanced DenseNet</strong>: Integration with self-attention mechanisms</li>
<li><strong>Mobile DenseNet</strong>: Lightweight variants for edge deployment</li>
<li><strong>NAS-discovered DenseNet</strong>: Architectures found through Neural Architecture Search</li>
</ul>
</div>
</div>
</section>
<section id="future-directions" class="level2">
<h2 class="anchored" data-anchor-id="future-directions" id="future-directions">Future Directions</h2>
<p>The dense connectivity principle continues to influence modern architecture design:</p>
<ol type="1">
<li><strong>Adaptive Connectivity</strong>: Learning optimal connection patterns</li>
<li><strong>Memory-Efficient Variants</strong>: Maintaining benefits while reducing memory</li>
<li><strong>Multi-Modal Applications</strong>: Extending to multi-modal learning</li>
<li><strong>Continual Learning</strong>: Leveraging dense connectivity for lifelong learning</li>
</ol>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>DenseNet represents a fundamental shift in how we think about information flow in deep neural networks. By connecting each layer to every subsequent layer within a dense block, DenseNet addresses key challenges in training very deep networks, such as vanishing gradients, while achieving remarkable parameter efficiency.</p>
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Important</span>Key Takeaways
</div>
</div>
<div class="callout-body-container callout-body">
<p>The architecture’s success stems from its ability to:</p>
<ul>
<li><strong>Maximize information flow</strong> and feature reuse</li>
<li><strong>Achieve stronger gradient flow</strong> and implicit regularization</li>
<li><strong>Create compact yet powerful</strong> models</li>
<li><strong>Provide excellent transferability</strong> across domains</li>
</ul>
</div>
</div>
<p>For practitioners, DenseNet offers an excellent balance of performance, efficiency, and transferability, making it a valuable tool in the deep learning toolkit. Its principles continue to inspire new developments in neural architecture design.</p>



</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Complete MobileNet Code Guide]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/models/mobile-net/mobile-net-code/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/models/mobile-net/mobile-net-code/</guid>
      <pubDate>Sat, 19 Jul 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>research</category>
      <category>intermediate</category>
      <content:encoded><![CDATA[






<section id="complete-mobilenet-code-guide" class="level1 page-columns page-full">

<p><img src="https://theja-vanka.github.io/blogs/posts/models/mobile-net/mobile-net-code/mobnet.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>MobileNet is a family of efficient neural network architectures designed specifically for mobile and embedded devices. The key innovation is the use of <strong>depthwise separable convolutions</strong>, which dramatically reduce the number of parameters and the computational cost while maintaining reasonable accuracy.</p>
</section>
<section id="prerequisites-and-setup" class="level2 page-columns page-full">
<h2 class="anchored" data-anchor-id="prerequisites-and-setup" id="prerequisites-and-setup">Prerequisites and Setup</h2>
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Important</span>Requirements
</div>
</div>
<div class="callout-body-container callout-body">
<p>Before diving into MobileNet implementation, ensure you have the following prerequisites:</p>
<p><strong>Software Requirements:</strong></p>
<ul>
<li>Python 3.8+</li>
<li>PyTorch 1.12+</li>
<li>torchvision 0.13+</li>
<li>CUDA (optional, for GPU acceleration)</li>
</ul>
<p><strong>Hardware Recommendations:</strong></p>
<ul>
<li>8GB+ RAM for training</li>
<li>NVIDIA GPU with 4GB+ VRAM (recommended)</li>
<li>SSD storage for faster data loading</li>
</ul>
</div>
</div>
<section id="installation" class="level3">
<h3 class="anchored" data-anchor-id="installation" id="installation">Installation</h3>
<div id="installation" class="cell" data-caption="Install required packages" data-execution_count="1">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb1"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Core dependencies</span></span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a>pip install torch torchvision torchaudio</span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a>pip install numpy matplotlib seaborn</span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a>pip install pillow opencv<span class="op">-</span>python</span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Optional dependencies for advanced features</span></span>
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a>pip install tensorboard  <span class="co"># For training visualization</span></span>
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a>pip install ptflops      <span class="co"># For FLOPs calculation  </span></span>
<span id="cb1-9"><a href="#cb1-9" aria-hidden="true" tabindex="-1"></a>pip install onnx onnxruntime  <span class="co"># For model export</span></span>
<span id="cb1-10"><a href="#cb1-10" aria-hidden="true" tabindex="-1"></a>pip install coremltools      <span class="co"># For iOS deployment</span></span>
<span id="cb1-11"><a href="#cb1-11" aria-hidden="true" tabindex="-1"></a>pip install tensorflow       <span class="co"># Includes the TFLite converter (tf.lite) for Android deployment</span></span>
<span id="cb1-12"><a href="#cb1-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-13"><a href="#cb1-13" aria-hidden="true" tabindex="-1"></a><span class="co"># Development tools</span></span>
<span id="cb1-14"><a href="#cb1-14" aria-hidden="true" tabindex="-1"></a>pip install jupyter notebook</span>
<span id="cb1-15"><a href="#cb1-15" aria-hidden="true" tabindex="-1"></a>pip install black isort      <span class="co"># Code formatting</span></span></code></pre></div></div>
</div>
</section>
<section id="quick-start-example" class="level3">
<h3 class="anchored" data-anchor-id="quick-start-example" id="quick-start-example">Quick Start Example</h3>
<p>Here’s a minimal example to get you started with MobileNet:</p>
<div id="quick-start" class="cell" data-caption="Quick start example with pre-trained MobileNet" data-execution_count="2">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torchvision.models <span class="im">as</span> models</span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torchvision <span class="im">import</span> transforms</span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Load ImageNet pre-trained weights (weights API, torchvision 0.13+)</span></span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> models.mobilenet_v2(weights<span class="op">=</span>models.MobileNet_V2_Weights.IMAGENET1K_V1)</span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a>model.<span class="bu">eval</span>()</span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-10"><a href="#cb2-10" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"✅ Model loaded successfully!"</span>)</span>
<span id="cb2-11"><a href="#cb2-11" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"📊 Total parameters: </span><span class="sc">{</span><span class="bu">sum</span>(p.numel() <span class="cf">for</span> p <span class="kw">in</span> model.parameters())<span class="sc">:,}</span><span class="ss">"</span>)</span>
<span id="cb2-12"><a href="#cb2-12" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"💾 Model size: </span><span class="sc">{</span><span class="bu">sum</span>(p.numel() <span class="op">*</span> p.element_size() <span class="cf">for</span> p <span class="kw">in</span> model.parameters()) <span class="op">/</span> <span class="dv">1024</span><span class="op">**</span><span class="dv">2</span><span class="sc">:.1f}</span><span class="ss"> MB"</span>)</span>
<span id="cb2-13"><a href="#cb2-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-14"><a href="#cb2-14" aria-hidden="true" tabindex="-1"></a><span class="co"># Sanity-check with a random input (the predicted class is meaningless for noise)</span></span>
<span id="cb2-15"><a href="#cb2-15" aria-hidden="true" tabindex="-1"></a>dummy_input <span class="op">=</span> torch.randn(<span class="dv">1</span>, <span class="dv">3</span>, <span class="dv">224</span>, <span class="dv">224</span>)</span>
<span id="cb2-16"><a href="#cb2-16" aria-hidden="true" tabindex="-1"></a><span class="cf">with</span> torch.no_grad():</span>
<span id="cb2-17"><a href="#cb2-17" aria-hidden="true" tabindex="-1"></a>    output <span class="op">=</span> model(dummy_input)</span>
<span id="cb2-18"><a href="#cb2-18" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-19"><a href="#cb2-19" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"🎯 Output shape: </span><span class="sc">{</span>output<span class="sc">.</span>shape<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb2-20"><a href="#cb2-20" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"🔥 Top prediction: Class </span><span class="sc">{</span>torch<span class="sc">.</span>argmax(output)<span class="sc">.</span>item()<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>✅ Model loaded successfully!
📊 Total parameters: 3,504,872
💾 Model size: 13.4 MB
🎯 Output shape: torch.Size([1, 1000])
🔥 Top prediction: Class 644</code></pre>
</div>
</div>
</section>
<section id="key-features" class="level3">
<h3 class="anchored" data-anchor-id="key-features" id="key-features">Key Features</h3>
<ul>
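<li><strong>Efficient</strong>: Roughly 15x fewer parameters than AlexNet (4.2M vs.&nbsp;61M for MobileNet-V1)</li>
<li><strong>Fast</strong>: Low compute budget aimed at mobile inference (569M multiply-adds for MobileNet-V1 at 224×224)</li>
<li><strong>Flexible</strong>: Width and resolution multipliers trade accuracy for size and speed</li>
<li><strong>Accurate</strong>: Competitive ImageNet accuracy (70.6% top-1 for MobileNet-V1, well above AlexNet’s 56.5%)</li>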
</ul>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>MobileNet Efficiency
</div>
</div>
<div class="callout-body-container callout-body">
<p>MobileNet achieves its efficiency through depthwise separable convolutions, which split standard convolutions into two operations: depthwise and pointwise convolutions.</p>
</div>
</div>
</section>
<section id="architecture-comparison-table" class="level3 page-columns page-full">
<h3 class="anchored" data-anchor-id="architecture-comparison-table" id="architecture-comparison-table">Architecture Comparison Table</h3>
<table class="caption-top table">
<colgroup>
<col style="width: 15%">
<col style="width: 17%">
<col style="width: 11%">
<col style="width: 20%">
<col style="width: 18%">
<col style="width: 16%">
</colgroup>
<thead>
<tr class="header">
<th>Architecture</th>
<th>Parameters (M)</th>
<th>FLOPs (M)</th>
<th>Top-1 Accuracy (%)</th>
<th>Model Size (MB)</th>
<th>Target Device</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>AlexNet</td>
<td>61.0</td>
<td>714</td>
<td>56.5</td>
<td>233</td>
<td>Desktop</td>
</tr>
<tr class="even">
<td>VGG-16</td>
<td>138.0</td>
<td>15500</td>
<td>71.5</td>
<td>528</td>
<td>Desktop</td>
</tr>
<tr class="odd">
<td>ResNet-50</td>
<td>25.6</td>
<td>4100</td>
<td>76.1</td>
<td>98</td>
<td>Server</td>
</tr>
<tr class="even">
<td>MobileNet-V1</td>
<td>4.2</td>
<td>569</td>
<td>70.6</td>
<td>16</td>
<td>Mobile</td>
</tr>
<tr class="odd">
<td>MobileNet-V2</td>
<td>3.4</td>
<td>300</td>
<td>72.0</td>
<td>14</td>
<td>Mobile</td>
</tr>
<tr class="even">
<td>EfficientNet-B0</td>
<td>5.3</td>
<td>390</td>
<td>77.3</td>
<td>21</td>
<td>Mobile</td>
</tr>
</tbody>
</table>

<div class="no-row-height column-margin column-container"><div class="">
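<p><strong>Note:</strong> Accuracy values are ImageNet top-1 classification accuracy. FLOPs are counted as multiply-add operations (the convention used in the MobileNet papers) for 224×224 input images, and model sizes assume 32-bit floating-point weights.</p>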
</div></div></section>
</section>
<section id="mobilenet-architecture" class="level2">
<h2 class="anchored" data-anchor-id="mobilenet-architecture" id="mobilenet-architecture">MobileNet Architecture</h2>
<p>The original MobileNet (V1) architecture consists of:</p>
<ol type="1">
<li><strong>Standard 3×3 convolution</strong> with stride 2 (first layer)</li>
<li><strong>13 depthwise separable convolution blocks</strong></li>
<li><strong>Global average pooling and a fully connected classifier</strong></li>
</ol>
<section id="architecture-overview" class="level3">
<h3 class="anchored" data-anchor-id="architecture-overview" id="architecture-overview">Architecture Overview</h3>
<div id="mobilenet-architecture" class="cell" data-caption="Complete MobileNet architecture implementation" data-execution_count="3">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn.functional <span class="im">as</span> F</span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MobileNet(nn.Module):</span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, num_classes<span class="op">=</span><span class="dv">1000</span>, width_mult<span class="op">=</span><span class="fl">1.0</span>):</span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>(MobileNet, <span class="va">self</span>).<span class="fu">__init__</span>()</span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a>        <span class="co"># First standard convolution</span></span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.conv1 <span class="op">=</span> nn.Conv2d(<span class="dv">3</span>, <span class="bu">int</span>(<span class="dv">32</span> <span class="op">*</span> width_mult), <span class="dv">3</span>, stride<span class="op">=</span><span class="dv">2</span>, padding<span class="op">=</span><span class="dv">1</span>, bias<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.bn1 <span class="op">=</span> nn.BatchNorm2d(<span class="bu">int</span>(<span class="dv">32</span> <span class="op">*</span> width_mult))</span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Depthwise separable convolution blocks</span></span>
<span id="cb4-14"><a href="#cb4-14" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.layers <span class="op">=</span> nn.ModuleList([</span>
<span id="cb4-15"><a href="#cb4-15" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>._make_layer(<span class="bu">int</span>(<span class="dv">32</span> <span class="op">*</span> width_mult), <span class="bu">int</span>(<span class="dv">64</span> <span class="op">*</span> width_mult), <span class="dv">1</span>),</span>
<span id="cb4-16"><a href="#cb4-16" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>._make_layer(<span class="bu">int</span>(<span class="dv">64</span> <span class="op">*</span> width_mult), <span class="bu">int</span>(<span class="dv">128</span> <span class="op">*</span> width_mult), <span class="dv">2</span>),</span>
<span id="cb4-17"><a href="#cb4-17" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>._make_layer(<span class="bu">int</span>(<span class="dv">128</span> <span class="op">*</span> width_mult), <span class="bu">int</span>(<span class="dv">128</span> <span class="op">*</span> width_mult), <span class="dv">1</span>),</span>
<span id="cb4-18"><a href="#cb4-18" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>._make_layer(<span class="bu">int</span>(<span class="dv">128</span> <span class="op">*</span> width_mult), <span class="bu">int</span>(<span class="dv">256</span> <span class="op">*</span> width_mult), <span class="dv">2</span>),</span>
<span id="cb4-19"><a href="#cb4-19" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>._make_layer(<span class="bu">int</span>(<span class="dv">256</span> <span class="op">*</span> width_mult), <span class="bu">int</span>(<span class="dv">256</span> <span class="op">*</span> width_mult), <span class="dv">1</span>),</span>
<span id="cb4-20"><a href="#cb4-20" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>._make_layer(<span class="bu">int</span>(<span class="dv">256</span> <span class="op">*</span> width_mult), <span class="bu">int</span>(<span class="dv">512</span> <span class="op">*</span> width_mult), <span class="dv">2</span>),</span>
<span id="cb4-21"><a href="#cb4-21" aria-hidden="true" tabindex="-1"></a>            <span class="co"># 5 layers with stride 1</span></span>
<span id="cb4-22"><a href="#cb4-22" aria-hidden="true" tabindex="-1"></a>            <span class="op">*</span>[<span class="va">self</span>._make_layer(<span class="bu">int</span>(<span class="dv">512</span> <span class="op">*</span> width_mult), <span class="bu">int</span>(<span class="dv">512</span> <span class="op">*</span> width_mult), <span class="dv">1</span>) <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">5</span>)],</span>
<span id="cb4-23"><a href="#cb4-23" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>._make_layer(<span class="bu">int</span>(<span class="dv">512</span> <span class="op">*</span> width_mult), <span class="bu">int</span>(<span class="dv">1024</span> <span class="op">*</span> width_mult), <span class="dv">2</span>),</span>
<span id="cb4-24"><a href="#cb4-24" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>._make_layer(<span class="bu">int</span>(<span class="dv">1024</span> <span class="op">*</span> width_mult), <span class="bu">int</span>(<span class="dv">1024</span> <span class="op">*</span> width_mult), <span class="dv">1</span>),</span>
<span id="cb4-25"><a href="#cb4-25" aria-hidden="true" tabindex="-1"></a>        ])</span>
<span id="cb4-26"><a href="#cb4-26" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-27"><a href="#cb4-27" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Global average pooling and classifier</span></span>
<span id="cb4-28"><a href="#cb4-28" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.avgpool <span class="op">=</span> nn.AdaptiveAvgPool2d(<span class="dv">1</span>)</span>
<span id="cb4-29"><a href="#cb4-29" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.classifier <span class="op">=</span> nn.Linear(<span class="bu">int</span>(<span class="dv">1024</span> <span class="op">*</span> width_mult), num_classes)</span>
<span id="cb4-30"><a href="#cb4-30" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-31"><a href="#cb4-31" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _make_layer(<span class="va">self</span>, in_channels, out_channels, stride):</span>
<span id="cb4-32"><a href="#cb4-32" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> DepthwiseSeparableConv(in_channels, out_channels, stride)  <span class="co"># defined in the Depthwise Separable Convolutions section</span></span>
<span id="cb4-33"><a href="#cb4-33" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-34"><a href="#cb4-34" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb4-35"><a href="#cb4-35" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> F.relu6(<span class="va">self</span>.bn1(<span class="va">self</span>.conv1(x)))</span>
<span id="cb4-36"><a href="#cb4-36" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-37"><a href="#cb4-37" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> layer <span class="kw">in</span> <span class="va">self</span>.layers:</span>
<span id="cb4-38"><a href="#cb4-38" aria-hidden="true" tabindex="-1"></a>            x <span class="op">=</span> layer(x)</span>
<span id="cb4-39"><a href="#cb4-39" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb4-40"><a href="#cb4-40" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.avgpool(x)</span>
<span id="cb4-41"><a href="#cb4-41" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> x.view(x.size(<span class="dv">0</span>), <span class="op">-</span><span class="dv">1</span>)</span>
<span id="cb4-42"><a href="#cb4-42" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.classifier(x)</span>
<span id="cb4-43"><a href="#cb4-43" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-44"><a href="#cb4-44" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x</span></code></pre></div></div>
</div>
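<p>Tracing the configuration by hand is a useful sanity check on the class above. This pure-Python sketch (no PyTorch needed; <code>BLOCKS</code> and <code>trace_mobilenet_v1</code> are hypothetical names) follows the same 13-block channel and stride schedule, applies the same <code>int(c * width_mult)</code> rounding, and counts convolution weights, BatchNorm scale/shift parameters, and the classifier. At a width multiplier of 1.0 and 224×224 input it lands on the roughly 4.2M parameters listed for MobileNet-V1 in the comparison table:</p>

```python
# (out_channels, stride) for the 13 depthwise separable blocks, mirroring the MobileNet class
BLOCKS = [(64, 1), (128, 2), (128, 1), (256, 2), (256, 1), (512, 2),
          *[(512, 1)] * 5, (1024, 2), (1024, 1)]


def trace_mobilenet_v1(input_size=224, width_mult=1.0, num_classes=1000):
    """Return (final_spatial_size, final_channels, total_trainable_params)."""
    def ch(c):
        return int(c * width_mult)  # same rounding as the MobileNet class

    # First standard 3x3 conv: stride 2, padding 1, no bias, followed by BN
    size = (input_size + 2 - 3) // 2 + 1
    in_c = ch(32)
    params = 3 * 3 * 3 * in_c + 2 * in_c          # conv weights + BN gamma/beta

    for out, stride in BLOCKS:
        out_c = ch(out)
        size = (size + 2 - 3) // stride + 1       # depthwise 3x3, padding 1
        params += 3 * 3 * in_c + 2 * in_c         # depthwise conv + BN
        params += in_c * out_c + 2 * out_c        # pointwise 1x1 conv + BN
        in_c = out_c

    params += in_c * num_classes + num_classes    # linear classifier (weights + bias)
    return size, in_c, params


print(trace_mobilenet_v1())                  # (7, 1024, 4231976)
print(trace_mobilenet_v1(width_mult=0.5))    # thinner network, same 7x7 output grid
```

<p>The final 7×7×1024 map is what the global average pooling layer collapses into a 1024-dimensional vector before the classifier. Note that the width multiplier shrinks every layer, while the resolution (<code>input_size</code>) changes only the spatial sizes and therefore compute, not the parameter count.</p>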
</section>
</section>
<section id="depthwise-separable-convolutions" class="level2">
<h2 class="anchored" data-anchor-id="depthwise-separable-convolutions" id="depthwise-separable-convolutions">Depthwise Separable Convolutions</h2>
<p>The core innovation of MobileNet is the depthwise separable convolution, which splits a standard convolution into two operations:</p>
<section id="depthwise-convolution" class="level3">
<h3 class="anchored" data-anchor-id="depthwise-convolution" id="depthwise-convolution">Depthwise Convolution</h3>
<p>Applies a single filter per input channel (spatial filtering):</p>
<div id="depthwise-conv" class="cell" data-caption="Depthwise convolution implementation" data-execution_count="4">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> DepthwiseConv(nn.Module):</span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, in_channels, kernel_size<span class="op">=</span><span class="dv">3</span>, stride<span class="op">=</span><span class="dv">1</span>, padding<span class="op">=</span><span class="dv">1</span>):</span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>(DepthwiseConv, <span class="va">self</span>).<span class="fu">__init__</span>()</span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.depthwise <span class="op">=</span> nn.Conv2d(</span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a>            in_channels, in_channels, </span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a>            kernel_size<span class="op">=</span>kernel_size, </span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a>            stride<span class="op">=</span>stride, </span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a>            padding<span class="op">=</span>padding, </span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a>            groups<span class="op">=</span>in_channels,  <span class="co"># Key: groups = in_channels</span></span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a>            bias<span class="op">=</span><span class="va">False</span></span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.bn <span class="op">=</span> nn.BatchNorm2d(in_channels)</span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> F.relu6(<span class="va">self</span>.bn(<span class="va">self</span>.depthwise(x)))</span></code></pre></div></div>
</div>
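<p>Setting <code>groups=in_channels</code> gives each channel its own filter, so no information crosses channels in this step. The sketch below (4 channels and an 8×8 input are arbitrary choices) inspects the weight shape and perturbs a single input channel to confirm that only the matching output channel changes:</p>

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Depthwise 3x3 conv over 4 channels: groups == in_channels
dw = nn.Conv2d(4, 4, kernel_size=3, padding=1, groups=4, bias=False)
print(dw.weight.shape)  # torch.Size([4, 1, 3, 3]): one 3x3 filter per channel

x = torch.randn(1, 4, 8, 8)
x_perturbed = x.clone()
x_perturbed[:, 1] += 10.0  # change only input channel 1

with torch.no_grad():
    y, y_perturbed = dw(x), dw(x_perturbed)

# Only output channel 1 reacts; the other filters never read input channel 1
changed = [not torch.allclose(y[:, c], y_perturbed[:, c]) for c in range(4)]
print(changed)  # [False, True, False, False]
```

<p>This channel isolation is exactly why a pointwise step is needed afterwards: on its own, a depthwise layer can filter spatially but cannot combine features from different channels.</p>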
</section>
<section id="pointwise-convolution" class="level3">
<h3 class="anchored" data-anchor-id="pointwise-convolution" id="pointwise-convolution">Pointwise Convolution</h3>
<p>Applies a 1×1 convolution that mixes information across channels at each spatial position (channel mixing):</p>
<div id="pointwise-conv" class="cell" data-caption="Pointwise convolution implementation" data-execution_count="5">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> PointwiseConv(nn.Module):</span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, in_channels, out_channels):</span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>(PointwiseConv, <span class="va">self</span>).<span class="fu">__init__</span>()</span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.pointwise <span class="op">=</span> nn.Conv2d(in_channels, out_channels, <span class="dv">1</span>, bias<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.bn <span class="op">=</span> nn.BatchNorm2d(out_channels)</span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> F.relu6(<span class="va">self</span>.bn(<span class="va">self</span>.pointwise(x)))</span></code></pre></div></div>
</div>
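<p>Because the kernel is 1×1, a pointwise convolution is simply a linear map applied independently at every pixel: each output channel is a weighted sum of the input channels at that position. A small check with arbitrary sizes (8 to 16 channels, 5×5 maps) compares <code>nn.Conv2d</code> against an explicit per-pixel matrix multiply:</p>

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
pw = nn.Conv2d(8, 16, kernel_size=1, bias=False)  # pointwise: 8 -> 16 channels
x = torch.randn(2, 8, 5, 5)

with torch.no_grad():
    y_conv = pw(x)
    # Same result as multiplying every pixel's 8-vector by a 16x8 matrix
    W = pw.weight.view(16, 8)  # drop the trailing 1x1 kernel dimensions
    y_matmul = torch.einsum("oc,nchw->nohw", W, x)

print(y_conv.shape)                                # torch.Size([2, 16, 5, 5])
print(torch.allclose(y_conv, y_matmul, atol=1e-5))  # True
```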
</section>
<section id="complete-depthwise-separable-block" class="level3">
<h3 class="anchored" data-anchor-id="complete-depthwise-separable-block" id="complete-depthwise-separable-block">Complete Depthwise Separable Block</h3>
<div id="depthwise-separable-block" class="cell" data-caption="Complete depthwise separable convolution block" data-execution_count="6">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> DepthwiseSeparableConv(nn.Module):</span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, in_channels, out_channels, stride<span class="op">=</span><span class="dv">1</span>):</span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>(DepthwiseSeparableConv, <span class="va">self</span>).<span class="fu">__init__</span>()</span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.depthwise <span class="op">=</span> DepthwiseConv(in_channels, stride<span class="op">=</span>stride)</span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.pointwise <span class="op">=</span> PointwiseConv(in_channels, out_channels)</span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.depthwise(x)</span>
<span id="cb7-10"><a href="#cb7-10" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.pointwise(x)</span>
<span id="cb7-11"><a href="#cb7-11" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x</span></code></pre></div></div>
</div>
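<p>The same block can also be written as a single <code>nn.Sequential</code>, a form many codebases use. This is an equivalent sketch, not a change to the classes above; <code>depthwise_separable</code> is a hypothetical helper name. With stride 2 it halves the spatial resolution, while the pointwise step sets the new channel count:</p>

```python
import torch
import torch.nn as nn


def depthwise_separable(in_ch: int, out_ch: int, stride: int = 1) -> nn.Sequential:
    """Hypothetical helper: depthwise -> pointwise block as one nn.Sequential."""
    return nn.Sequential(
        # Depthwise 3x3: one filter per input channel, optional spatial downsampling
        nn.Conv2d(in_ch, in_ch, 3, stride=stride, padding=1, groups=in_ch, bias=False),
        nn.BatchNorm2d(in_ch),
        nn.ReLU6(inplace=True),
        # Pointwise 1x1: mixes channels and sets the output width
        nn.Conv2d(in_ch, out_ch, 1, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU6(inplace=True),
    )


block = depthwise_separable(64, 128, stride=2)
out = block(torch.randn(1, 64, 56, 56))
print(out.shape)                                  # torch.Size([1, 128, 28, 28])
print(sum(p.numel() for p in block.parameters()))  # 9152
```

<p>The 9,152 trainable parameters break down as 576 (depthwise) + 8,192 (pointwise) + 384 (BatchNorm scale and shift), matching the parameter count of the class-based <code>DepthwiseSeparableConv(64, 128, stride=2)</code>.</p>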
</section>
<section id="computational-efficiency" class="level3">
<h3 class="anchored" data-anchor-id="computational-efficiency" id="computational-efficiency">Computational Efficiency</h3>
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Important</span>Efficiency Gains
</div>
</div>
<div class="callout-body-container callout-body">
<p><strong>Standard Convolution:</strong></p>
<ul>
<li>Parameters: <code>Dk × Dk × M × N</code></li>
<li>Computation: <code>Dk × Dk × M × N × Df × Df</code></li>
</ul>
<p><strong>Depthwise Separable Convolution:</strong></p>
<ul>
<li>Parameters: <code>Dk × Dk × M + M × N</code></li>
<li>Computation: <code>Dk × Dk × M × Df × Df + M × N × Df × Df</code></li>
</ul>
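<p><strong>Cost ratio (separable ÷ standard):</strong> <code>1/N + 1/Dk²</code>. The same ratio holds for both parameters and computation. With 3×3 kernels (<code>Dk = 3</code>) and typical channel counts, this means roughly 8 to 9× less of each.</p>
<p>Here <code>Dk</code> is the kernel size, <code>M</code> and <code>N</code> are the input and output channel counts, and <code>Df</code> is the spatial size of the feature map.</p>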
</div>
</div>
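<p>To make the formulas concrete, the sketch below plugs in one illustrative layer: a 3×3 convolution with 512 input and 512 output channels on a 14×14 feature map (values chosen for illustration; they match one of the middle MobileNetV1 layers). It counts parameters and multiply-adds for both variants and checks that the cost ratio equals <code>1/N + 1/Dk²</code>.</p>

```python
def conv_costs(dk: int, m: int, n: int, df: int):
    """Parameters and multiply-adds for a standard vs a depthwise separable conv.

    dk: kernel size, m: input channels, n: output channels,
    df: spatial size of the output feature map.
    """
    std_params = dk * dk * m * n
    std_madds = dk * dk * m * n * df * df
    # Depthwise (one dk x dk filter per input channel) + pointwise (1x1, m -> n)
    ds_params = dk * dk * m + m * n
    ds_madds = dk * dk * m * df * df + m * n * df * df
    return (std_params, std_madds), (ds_params, ds_madds)


(std_p, std_f), (ds_p, ds_f) = conv_costs(dk=3, m=512, n=512, df=14)
ratio = ds_f / std_f

print(f"standard : {std_p:,} params, {std_f:,} mult-adds")
print(f"separable: {ds_p:,} params, {ds_f:,} mult-adds")
print(f"ratio {ratio:.4f} vs 1/N + 1/Dk^2 = {1 / 512 + 1 / 9:.4f}")
print(f"reduction: {std_f / ds_f:.2f}x")
```

<p>For this layer the separable version needs roughly 8.8× fewer multiply-adds. Once <code>N</code> is large, the <code>1/Dk²</code> term dominates, which is why the savings level off just under 9× for 3×3 kernels.</p>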
</section>
<section id="efficiency-visualization" class="level3">
<h3 class="anchored" data-anchor-id="efficiency-visualization" id="efficiency-visualization">Efficiency Visualization</h3>
<p>Let’s visualize the efficiency gains of depthwise separable convolutions:</p>
<div id="cell-fig-efficiency-comparison" class="cell" data-fig-height="8" data-fig-width="12" data-execution_count="7">
<div class="cell-output cell-output-display">
<div id="fig-efficiency-comparison" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-fig figure">
<div aria-describedby="fig-efficiency-comparison-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<img src="https://theja-vanka.github.io/blogs/posts/models/mobile-net/mobile-net-code/fig-efficiency-comparison-output-1.png" class="img-fluid figure-img">
</div>
<figcaption class="quarto-float-caption-bottom quarto-float-caption quarto-float-fig" id="fig-efficiency-comparison-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Figure&nbsp;1: Computational cost comparison: Standard vs Depthwise Separable Convolutions
</figcaption>
</figure>
</div>
</div>
<div class="cell-output cell-output-stdout">
<pre><code>
📊 **Efficiency Summary:**
   • Total FLOPs - Standard: 19710.7M
   • Total FLOPs - Depthwise: 2218.1M
   • **Overall Reduction: 8.9×**</code></pre>
</div>
</div>
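<p>The summary above aggregates costs over the whole network. One rough way to produce such a tally is to apply the formulas layer by layer to MobileNetV1's 13 depthwise separable layers and compare each against a standard 3×3 convolution with the same channels. This sketch assumes a 224×224 input, so the stride-2 stem leaves a 112×112 map with 32 channels. It counts multiply-adds at each layer's output resolution and skips the stem and classifier. Its absolute totals therefore need not match the numbers printed above, which may follow a different counting convention; the ratio is the meaningful comparison.</p>

```python
# (output_channels, stride) for the 13 separable layers of MobileNetV1
CONFIGS = [(64, 1), (128, 2), (128, 1), (256, 2), (256, 1),
           (512, 2), (512, 1), (512, 1), (512, 1), (512, 1), (512, 1),
           (1024, 2), (1024, 1)]


def network_madds(in_channels: int = 32, feature_size: int = 112, dk: int = 3):
    """Sum multiply-adds over all layers for standard vs depthwise separable convs."""
    std_total, ds_total = 0, 0
    m, df = in_channels, feature_size
    for n, stride in CONFIGS:
        df //= stride  # stride-2 layers halve the spatial size
        std_total += dk * dk * m * n * df * df
        ds_total += dk * dk * m * df * df + m * n * df * df
        m = n  # this layer's output feeds the next layer
    return std_total, ds_total


std_total, ds_total = network_madds()
print(f"standard : {std_total / 1e6:.1f}M mult-adds")
print(f"separable: {ds_total / 1e6:.1f}M mult-adds")
print(f"reduction: {std_total / ds_total:.2f}x")
```

<p>Under these assumptions, the separable layers need about 8.7× fewer multiply-adds, which is in the same range as the 8.9× in the summary.</p>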
</section>
</section>
<section id="implementation-from-scratch" class="level2">
<h2 class="anchored" data-anchor-id="implementation-from-scratch" id="implementation-from-scratch">Implementation from Scratch</h2>
<p>Here is a complete MobileNetV1 implementation, with inline comments explaining each part:</p>
<div id="mobilenet-complete" class="cell" data-caption="Complete MobileNetV1 implementation with configurable parameters" data-execution_count="8">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb9"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn.functional <span class="im">as</span> F</span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> typing <span class="im">import</span> Optional</span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MobileNetV1(nn.Module):</span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a><span class="co">    MobileNetV1 implementation with a configurable width multiplier.</span></span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a><span class="co">    The resolution multiplier is applied by resizing inputs (e.g. to 192x192).</span></span>
<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a><span class="co">    </span></span>
<span id="cb9-11"><a href="#cb9-11" aria-hidden="true" tabindex="-1"></a><span class="co">    Args:</span></span>
<span id="cb9-12"><a href="#cb9-12" aria-hidden="true" tabindex="-1"></a><span class="co">        num_classes: Number of output classes</span></span>
<span id="cb9-13"><a href="#cb9-13" aria-hidden="true" tabindex="-1"></a><span class="co">        width_mult: Width multiplier for channels (0.25, 0.5, 0.75, 1.0)</span></span>
<span id="cb9-14"><a href="#cb9-14" aria-hidden="true" tabindex="-1"></a><span class="co">        dropout_rate: Dropout rate before classifier</span></span>
<span id="cb9-15"><a href="#cb9-15" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb9-16"><a href="#cb9-16" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-17"><a href="#cb9-17" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, </span>
<span id="cb9-18"><a href="#cb9-18" aria-hidden="true" tabindex="-1"></a>                 num_classes: <span class="bu">int</span> <span class="op">=</span> <span class="dv">1000</span>, </span>
<span id="cb9-19"><a href="#cb9-19" aria-hidden="true" tabindex="-1"></a>                 width_mult: <span class="bu">float</span> <span class="op">=</span> <span class="fl">1.0</span>,</span>
<span id="cb9-20"><a href="#cb9-20" aria-hidden="true" tabindex="-1"></a>                 dropout_rate: <span class="bu">float</span> <span class="op">=</span> <span class="fl">0.2</span>):</span>
<span id="cb9-21"><a href="#cb9-21" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>(MobileNetV1, <span class="va">self</span>).<span class="fu">__init__</span>()</span>
<span id="cb9-22"><a href="#cb9-22" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-23"><a href="#cb9-23" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.width_mult <span class="op">=</span> width_mult</span>
<span id="cb9-24"><a href="#cb9-24" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-25"><a href="#cb9-25" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Round channel counts to a multiple of 8, never losing more than 10%</span></span>
<span id="cb9-26"><a href="#cb9-26" aria-hidden="true" tabindex="-1"></a>        <span class="kw">def</span> _make_divisible(v, divisor<span class="op">=</span><span class="dv">8</span>):</span>
<span id="cb9-27"><a href="#cb9-27" aria-hidden="true" tabindex="-1"></a>            new_v <span class="op">=</span> <span class="bu">max</span>(divisor, <span class="bu">int</span>(v <span class="op">+</span> divisor <span class="op">/</span> <span class="dv">2</span>) <span class="op">//</span> divisor <span class="op">*</span> divisor)</span>
<span id="cb9-28"><a href="#cb9-28" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> new_v <span class="op">&lt;</span> <span class="fl">0.9</span> <span class="op">*</span> v:</span>
<span id="cb9-29"><a href="#cb9-29" aria-hidden="true" tabindex="-1"></a>                new_v <span class="op">+=</span> divisor</span>
<span id="cb9-30"><a href="#cb9-30" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> new_v</span>
<span id="cb9-31"><a href="#cb9-31" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-32"><a href="#cb9-32" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Define channel configurations</span></span>
<span id="cb9-33"><a href="#cb9-33" aria-hidden="true" tabindex="-1"></a>        input_channel <span class="op">=</span> _make_divisible(<span class="dv">32</span> <span class="op">*</span> width_mult)</span>
<span id="cb9-34"><a href="#cb9-34" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-35"><a href="#cb9-35" aria-hidden="true" tabindex="-1"></a>        <span class="co"># First standard convolution</span></span>
<span id="cb9-36"><a href="#cb9-36" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.conv1 <span class="op">=</span> nn.Sequential(</span>
<span id="cb9-37"><a href="#cb9-37" aria-hidden="true" tabindex="-1"></a>            nn.Conv2d(<span class="dv">3</span>, input_channel, <span class="dv">3</span>, <span class="dv">2</span>, <span class="dv">1</span>, bias<span class="op">=</span><span class="va">False</span>),</span>
<span id="cb9-38"><a href="#cb9-38" aria-hidden="true" tabindex="-1"></a>            nn.BatchNorm2d(input_channel),</span>
<span id="cb9-39"><a href="#cb9-39" aria-hidden="true" tabindex="-1"></a>            nn.ReLU6(inplace<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb9-40"><a href="#cb9-40" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb9-41"><a href="#cb9-41" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-42"><a href="#cb9-42" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Configuration: [output_channels, stride]</span></span>
<span id="cb9-43"><a href="#cb9-43" aria-hidden="true" tabindex="-1"></a>        configs <span class="op">=</span> [</span>
<span id="cb9-44"><a href="#cb9-44" aria-hidden="true" tabindex="-1"></a>            [<span class="dv">64</span>, <span class="dv">1</span>],</span>
<span id="cb9-45"><a href="#cb9-45" aria-hidden="true" tabindex="-1"></a>            [<span class="dv">128</span>, <span class="dv">2</span>], [<span class="dv">128</span>, <span class="dv">1</span>],</span>
<span id="cb9-46"><a href="#cb9-46" aria-hidden="true" tabindex="-1"></a>            [<span class="dv">256</span>, <span class="dv">2</span>], [<span class="dv">256</span>, <span class="dv">1</span>],</span>
<span id="cb9-47"><a href="#cb9-47" aria-hidden="true" tabindex="-1"></a>            [<span class="dv">512</span>, <span class="dv">2</span>], [<span class="dv">512</span>, <span class="dv">1</span>], [<span class="dv">512</span>, <span class="dv">1</span>], [<span class="dv">512</span>, <span class="dv">1</span>], [<span class="dv">512</span>, <span class="dv">1</span>], [<span class="dv">512</span>, <span class="dv">1</span>],</span>
<span id="cb9-48"><a href="#cb9-48" aria-hidden="true" tabindex="-1"></a>            [<span class="dv">1024</span>, <span class="dv">2</span>], [<span class="dv">1024</span>, <span class="dv">1</span>]</span>
<span id="cb9-49"><a href="#cb9-49" aria-hidden="true" tabindex="-1"></a>        ]</span>
<span id="cb9-50"><a href="#cb9-50" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-51"><a href="#cb9-51" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Build depthwise separable layers</span></span>
<span id="cb9-52"><a href="#cb9-52" aria-hidden="true" tabindex="-1"></a>        layers <span class="op">=</span> []</span>
<span id="cb9-53"><a href="#cb9-53" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> output_channel, stride <span class="kw">in</span> configs:</span>
<span id="cb9-54"><a href="#cb9-54" aria-hidden="true" tabindex="-1"></a>            output_channel <span class="op">=</span> _make_divisible(output_channel <span class="op">*</span> width_mult)</span>
<span id="cb9-55"><a href="#cb9-55" aria-hidden="true" tabindex="-1"></a>            layers.append(</span>
<span id="cb9-56"><a href="#cb9-56" aria-hidden="true" tabindex="-1"></a>                DepthwiseSeparableConv(input_channel, output_channel, stride)</span>
<span id="cb9-57"><a href="#cb9-57" aria-hidden="true" tabindex="-1"></a>            )</span>
<span id="cb9-58"><a href="#cb9-58" aria-hidden="true" tabindex="-1"></a>            input_channel <span class="op">=</span> output_channel</span>
<span id="cb9-59"><a href="#cb9-59" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-60"><a href="#cb9-60" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.features <span class="op">=</span> nn.Sequential(<span class="op">*</span>layers)</span>
<span id="cb9-61"><a href="#cb9-61" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-62"><a href="#cb9-62" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Classifier</span></span>
<span id="cb9-63"><a href="#cb9-63" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.avgpool <span class="op">=</span> nn.AdaptiveAvgPool2d((<span class="dv">1</span>, <span class="dv">1</span>))</span>
<span id="cb9-64"><a href="#cb9-64" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.dropout <span class="op">=</span> nn.Dropout(dropout_rate)</span>
<span id="cb9-65"><a href="#cb9-65" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.classifier <span class="op">=</span> nn.Linear(input_channel, num_classes)</span>
<span id="cb9-66"><a href="#cb9-66" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-67"><a href="#cb9-67" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Initialize weights</span></span>
<span id="cb9-68"><a href="#cb9-68" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>._initialize_weights()</span>
<span id="cb9-69"><a href="#cb9-69" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-70"><a href="#cb9-70" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _initialize_weights(<span class="va">self</span>):</span>
<span id="cb9-71"><a href="#cb9-71" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Initialize weights using He initialization for ReLU networks."""</span></span>
<span id="cb9-72"><a href="#cb9-72" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> m <span class="kw">in</span> <span class="va">self</span>.modules():</span>
<span id="cb9-73"><a href="#cb9-73" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="bu">isinstance</span>(m, nn.Conv2d):</span>
<span id="cb9-74"><a href="#cb9-74" aria-hidden="true" tabindex="-1"></a>                nn.init.kaiming_normal_(m.weight, mode<span class="op">=</span><span class="st">'fan_out'</span>, nonlinearity<span class="op">=</span><span class="st">'relu'</span>)</span>
<span id="cb9-75"><a href="#cb9-75" aria-hidden="true" tabindex="-1"></a>                <span class="cf">if</span> m.bias <span class="kw">is</span> <span class="kw">not</span> <span class="va">None</span>:</span>
<span id="cb9-76"><a href="#cb9-76" aria-hidden="true" tabindex="-1"></a>                    nn.init.zeros_(m.bias)</span>
<span id="cb9-77"><a href="#cb9-77" aria-hidden="true" tabindex="-1"></a>            <span class="cf">elif</span> <span class="bu">isinstance</span>(m, nn.BatchNorm2d):</span>
<span id="cb9-78"><a href="#cb9-78" aria-hidden="true" tabindex="-1"></a>                nn.init.ones_(m.weight)</span>
<span id="cb9-79"><a href="#cb9-79" aria-hidden="true" tabindex="-1"></a>                nn.init.zeros_(m.bias)</span>
<span id="cb9-80"><a href="#cb9-80" aria-hidden="true" tabindex="-1"></a>            <span class="cf">elif</span> <span class="bu">isinstance</span>(m, nn.Linear):</span>
<span id="cb9-81"><a href="#cb9-81" aria-hidden="true" tabindex="-1"></a>                nn.init.normal_(m.weight, <span class="dv">0</span>, <span class="fl">0.01</span>)</span>
<span id="cb9-82"><a href="#cb9-82" aria-hidden="true" tabindex="-1"></a>                nn.init.zeros_(m.bias)</span>
<span id="cb9-83"><a href="#cb9-83" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-84"><a href="#cb9-84" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb9-85"><a href="#cb9-85" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.conv1(x)</span>
<span id="cb9-86"><a href="#cb9-86" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.features(x)</span>
<span id="cb9-87"><a href="#cb9-87" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.avgpool(x)</span>
<span id="cb9-88"><a href="#cb9-88" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> torch.flatten(x, <span class="dv">1</span>)</span>
<span id="cb9-89"><a href="#cb9-89" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.dropout(x)</span>
<span id="cb9-90"><a href="#cb9-90" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.classifier(x)</span>
<span id="cb9-91"><a href="#cb9-91" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x</span>
<span id="cb9-92"><a href="#cb9-92" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-93"><a href="#cb9-93" aria-hidden="true" tabindex="-1"></a><span class="co"># Depthwise separable block with BatchNorm + ReLU6 after each conv (redefines the simpler block above)</span></span>
<span id="cb9-94"><a href="#cb9-94" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> DepthwiseSeparableConv(nn.Module):</span>
<span id="cb9-95"><a href="#cb9-95" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, in_channels, out_channels, stride<span class="op">=</span><span class="dv">1</span>):</span>
<span id="cb9-96"><a href="#cb9-96" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>(DepthwiseSeparableConv, <span class="va">self</span>).<span class="fu">__init__</span>()</span>
<span id="cb9-97"><a href="#cb9-97" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-98"><a href="#cb9-98" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.conv <span class="op">=</span> nn.Sequential(</span>
<span id="cb9-99"><a href="#cb9-99" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Depthwise convolution</span></span>
<span id="cb9-100"><a href="#cb9-100" aria-hidden="true" tabindex="-1"></a>            nn.Conv2d(in_channels, in_channels, <span class="dv">3</span>, stride, <span class="dv">1</span>, </span>
<span id="cb9-101"><a href="#cb9-101" aria-hidden="true" tabindex="-1"></a>                     groups<span class="op">=</span>in_channels, bias<span class="op">=</span><span class="va">False</span>),</span>
<span id="cb9-102"><a href="#cb9-102" aria-hidden="true" tabindex="-1"></a>            nn.BatchNorm2d(in_channels),</span>
<span id="cb9-103"><a href="#cb9-103" aria-hidden="true" tabindex="-1"></a>            nn.ReLU6(inplace<span class="op">=</span><span class="va">True</span>),</span>
<span id="cb9-104"><a href="#cb9-104" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb9-105"><a href="#cb9-105" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Pointwise convolution</span></span>
<span id="cb9-106"><a href="#cb9-106" aria-hidden="true" tabindex="-1"></a>            nn.Conv2d(in_channels, out_channels, <span class="dv">1</span>, <span class="dv">1</span>, <span class="dv">0</span>, bias<span class="op">=</span><span class="va">False</span>),</span>
<span id="cb9-107"><a href="#cb9-107" aria-hidden="true" tabindex="-1"></a>            nn.BatchNorm2d(out_channels),</span>
<span id="cb9-108"><a href="#cb9-108" aria-hidden="true" tabindex="-1"></a>            nn.ReLU6(inplace<span class="op">=</span><span class="va">True</span>),</span>
<span id="cb9-109"><a href="#cb9-109" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb9-110"><a href="#cb9-110" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-111"><a href="#cb9-111" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb9-112"><a href="#cb9-112" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.conv(x)</span>
<span id="cb9-113"><a href="#cb9-113" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-114"><a href="#cb9-114" aria-hidden="true" tabindex="-1"></a><span class="co"># Model factory function</span></span>
<span id="cb9-115"><a href="#cb9-115" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> mobilenet_v1(num_classes<span class="op">=</span><span class="dv">1000</span>, width_mult<span class="op">=</span><span class="fl">1.0</span>, pretrained<span class="op">=</span><span class="va">False</span>):</span>
<span id="cb9-116"><a href="#cb9-116" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb9-117"><a href="#cb9-117" aria-hidden="true" tabindex="-1"></a><span class="co">    Create MobileNetV1 model.</span></span>
<span id="cb9-118"><a href="#cb9-118" aria-hidden="true" tabindex="-1"></a><span class="co">    </span></span>
<span id="cb9-119"><a href="#cb9-119" aria-hidden="true" tabindex="-1"></a><span class="co">    Args:</span></span>
<span id="cb9-120"><a href="#cb9-120" aria-hidden="true" tabindex="-1"></a><span class="co">        num_classes: Number of classes for classification</span></span>
<span id="cb9-121"><a href="#cb9-121" aria-hidden="true" tabindex="-1"></a><span class="co">        width_mult: Width multiplier (0.25, 0.5, 0.75, 1.0)</span></span>
<span id="cb9-122"><a href="#cb9-122" aria-hidden="true" tabindex="-1"></a><span class="co">        pretrained: Placeholder flag; this implementation ships no weights</span></span>
<span id="cb9-123"><a href="#cb9-123" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb9-124"><a href="#cb9-124" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> MobileNetV1(num_classes<span class="op">=</span>num_classes, width_mult<span class="op">=</span>width_mult)</span>
<span id="cb9-125"><a href="#cb9-125" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-126"><a href="#cb9-126" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> pretrained:</span>
<span id="cb9-127"><a href="#cb9-127" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Placeholder: load a checkpoint yourself, e.g. model.load_state_dict(torch.load(path))</span></span>
<span id="cb9-128"><a href="#cb9-128" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"No bundled weights for width_mult=</span><span class="sc">{</span>width_mult<span class="sc">}</span><span class="ss">; load a state_dict manually"</span>)</span>
<span id="cb9-129"><a href="#cb9-129" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-130"><a href="#cb9-130" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> model</span></code></pre></div></div>
</div>
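<p>The <code>_make_divisible</code> helper is what turns a width multiplier into concrete channel counts. The standalone copy below (renamed <code>make_divisible</code> so it runs outside the class) prints the widths this implementation would use for several multipliers. The 0.35 multiplier is included purely for illustration: it exercises both the rounding to a multiple of 8 and the rule that bumps a width back up when rounding would remove more than 10% of the requested channels.</p>

```python
def make_divisible(v: float, divisor: int = 8) -> int:
    """Round v to the nearest multiple of divisor, never going below divisor."""
    new_v = max(divisor, int(v + divisor / 2) // divisor * divisor)
    # Don't let rounding remove more than 10% of the requested channels
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


BASE_CHANNELS = [32, 64, 128, 256, 512, 1024]  # stem + stage widths of MobileNetV1


def scaled_widths(width_mult: float) -> list[int]:
    """Channel counts after applying the width multiplier and rounding."""
    return [make_divisible(c * width_mult) for c in BASE_CHANNELS]


for alpha in (1.0, 0.75, 0.5, 0.35, 0.25):
    print(f"width_mult={alpha}: {scaled_widths(alpha)}")
```

<p>For example, 32 × 0.35 = 11.2 would round down to 8, losing almost 30% of the channels, so the helper bumps it to 16. For the standard multipliers (0.25 to 1.0), every scaled width is already a multiple of 8 and passes through unchanged.</p>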
</section>
<section id="using-pre-trained-mobilenet" class="level2">
<h2 class="anchored" data-anchor-id="using-pre-trained-mobilenet" id="using-pre-trained-mobilenet">Using Pre-trained MobileNet</h2>
<section id="with-pytorch-torchvision" class="level3">
<h3 class="anchored" data-anchor-id="with-pytorch-torchvision" id="with-pytorch-torchvision">With PyTorch (torchvision)</h3>
<div id="pretrained-usage" class="cell" data-caption="Using pre-trained MobileNet for inference" data-execution_count="9">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb10"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torchvision.models <span class="im">as</span> models</span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torchvision <span class="im">import</span> transforms</span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> PIL <span class="im">import</span> Image</span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-6"><a href="#cb10-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Load pre-trained MobileNetV2 (torchvision has no V1); torchvision &gt;= 0.13 uses weights= instead of the deprecated pretrained=</span></span>
<span id="cb10-7"><a href="#cb10-7" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> models.mobilenet_v2(weights<span class="op">=</span>models.MobileNet_V2_Weights.DEFAULT)</span>
<span id="cb10-8"><a href="#cb10-8" aria-hidden="true" tabindex="-1"></a>model.<span class="bu">eval</span>()</span>
<span id="cb10-9"><a href="#cb10-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-10"><a href="#cb10-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Preprocessing pipeline</span></span>
<span id="cb10-11"><a href="#cb10-11" aria-hidden="true" tabindex="-1"></a>preprocess <span class="op">=</span> transforms.Compose([</span>
<span id="cb10-12"><a href="#cb10-12" aria-hidden="true" tabindex="-1"></a>    transforms.Resize(<span class="dv">256</span>),</span>
<span id="cb10-13"><a href="#cb10-13" aria-hidden="true" tabindex="-1"></a>    transforms.CenterCrop(<span class="dv">224</span>),</span>
<span id="cb10-14"><a href="#cb10-14" aria-hidden="true" tabindex="-1"></a>    transforms.ToTensor(),</span>
<span id="cb10-15"><a href="#cb10-15" aria-hidden="true" tabindex="-1"></a>    transforms.Normalize(mean<span class="op">=</span>[<span class="fl">0.485</span>, <span class="fl">0.456</span>, <span class="fl">0.406</span>], </span>
<span id="cb10-16"><a href="#cb10-16" aria-hidden="true" tabindex="-1"></a>                        std<span class="op">=</span>[<span class="fl">0.229</span>, <span class="fl">0.224</span>, <span class="fl">0.225</span>]),</span>
<span id="cb10-17"><a href="#cb10-17" aria-hidden="true" tabindex="-1"></a>])</span>
<span id="cb10-18"><a href="#cb10-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-19"><a href="#cb10-19" aria-hidden="true" tabindex="-1"></a><span class="co"># Inference function</span></span>
<span id="cb10-20"><a href="#cb10-20" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> predict_image(image_path, model, preprocess, top_k<span class="op">=</span><span class="dv">5</span>):</span>
<span id="cb10-21"><a href="#cb10-21" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Predict top-k classes for an image."""</span></span>
<span id="cb10-22"><a href="#cb10-22" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-23"><a href="#cb10-23" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Load and preprocess image</span></span>
<span id="cb10-24"><a href="#cb10-24" aria-hidden="true" tabindex="-1"></a>    image <span class="op">=</span> Image.<span class="bu">open</span>(image_path).convert(<span class="st">'RGB'</span>)</span>
<span id="cb10-25"><a href="#cb10-25" aria-hidden="true" tabindex="-1"></a>    input_tensor <span class="op">=</span> preprocess(image)</span>
<span id="cb10-26"><a href="#cb10-26" aria-hidden="true" tabindex="-1"></a>    input_batch <span class="op">=</span> input_tensor.unsqueeze(<span class="dv">0</span>)  <span class="co"># Add batch dimension</span></span>
<span id="cb10-27"><a href="#cb10-27" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-28"><a href="#cb10-28" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Predict</span></span>
<span id="cb10-29"><a href="#cb10-29" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> torch.no_grad():</span>
<span id="cb10-30"><a href="#cb10-30" aria-hidden="true" tabindex="-1"></a>        output <span class="op">=</span> model(input_batch)</span>
<span id="cb10-31"><a href="#cb10-31" aria-hidden="true" tabindex="-1"></a>        probabilities <span class="op">=</span> torch.nn.functional.softmax(output[<span class="dv">0</span>], dim<span class="op">=</span><span class="dv">0</span>)</span>
<span id="cb10-32"><a href="#cb10-32" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-33"><a href="#cb10-33" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Get top-k predictions</span></span>
<span id="cb10-34"><a href="#cb10-34" aria-hidden="true" tabindex="-1"></a>    top_prob, top_indices <span class="op">=</span> torch.topk(probabilities, top_k)</span>
<span id="cb10-35"><a href="#cb10-35" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-36"><a href="#cb10-36" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> [(idx.item(), prob.item()) <span class="cf">for</span> idx, prob <span class="kw">in</span> <span class="bu">zip</span>(top_indices, top_prob)]</span>
<span id="cb10-37"><a href="#cb10-37" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-38"><a href="#cb10-38" aria-hidden="true" tabindex="-1"></a><span class="co"># Example usage</span></span>
<span id="cb10-39"><a href="#cb10-39" aria-hidden="true" tabindex="-1"></a><span class="co"># predictions = predict_image('cat.jpg', model, preprocess)</span></span>
<span id="cb10-40"><a href="#cb10-40" aria-hidden="true" tabindex="-1"></a><span class="co"># print(predictions)</span></span></code></pre></div></div>
</div>
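<p><code>predict_image</code> classifies one file at a time. When images are already preprocessed into a tensor, the same softmax-then-top-k logic can run over a whole batch in one forward pass. The sketch below assumes inputs shaped <code>(B, 3, H, W)</code>. Both <code>predict_batch</code> and the tiny 10-class stand-in model are hypothetical, used only so the example runs without downloading weights; in practice you would pass the MobileNetV2 loaded above.</p>

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def predict_batch(model: nn.Module, batch: torch.Tensor, top_k: int = 5):
    """Return one list of top-k (class_index, probability) pairs per image."""
    model.eval()
    with torch.no_grad():
        logits = model(batch)                  # (B, num_classes)
        probs = F.softmax(logits, dim=1)       # normalize each image's row separately
        top_p, top_i = torch.topk(probs, top_k, dim=1)  # sorted, highest first
    return [
        [(i.item(), p.item()) for i, p in zip(idx_row, prob_row)]
        for idx_row, prob_row in zip(top_i, top_p)
    ]


# Hypothetical stand-in classifier with 10 classes (replace with the real model)
torch.manual_seed(0)
toy_model = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(3, 10))
preds = predict_batch(toy_model, torch.randn(4, 3, 224, 224), top_k=3)
for n, row in enumerate(preds):
    print(n, row)
```

<p>Because softmax is applied along <code>dim=1</code>, each image's probabilities sum to 1 independently. The single-image function gets the same result by indexing <code>output[0]</code>.</p>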
</section>
<section id="fine-tuning-pre-trained-mobilenet" class="level3">
<h3 class="anchored" data-anchor-id="fine-tuning-pre-trained-mobilenet" id="fine-tuning-pre-trained-mobilenet">Fine-tuning Pre-trained MobileNet</h3>
<div id="fine-tuning" class="cell" data-caption="Fine-tuning MobileNet for custom classification tasks" data-execution_count="10">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb11"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.optim <span class="im">as</span> optim</span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> create_mobilenet_classifier(num_classes, pretrained<span class="op">=</span><span class="va">True</span>):</span>
<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Create MobileNet for custom classification task."""</span></span>
<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a>    </span>
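<span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Load pre-trained weights (torchvision &gt;= 0.13 API) or a randomly initialized model</span></span>
<span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> models.mobilenet_v2(weights<span class="op">=</span>models.MobileNet_V2_Weights.DEFAULT <span class="cf">if</span> pretrained <span class="cf">else</span> <span class="va">None</span>)</span>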
<span id="cb11-9"><a href="#cb11-9" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-10"><a href="#cb11-10" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Modify classifier for custom number of classes</span></span>
<span id="cb11-11"><a href="#cb11-11" aria-hidden="true" tabindex="-1"></a>    model.classifier <span class="op">=</span> nn.Sequential(</span>
<span id="cb11-12"><a href="#cb11-12" aria-hidden="true" tabindex="-1"></a>        nn.Dropout(<span class="fl">0.2</span>),</span>
<span id="cb11-13"><a href="#cb11-13" aria-hidden="true" tabindex="-1"></a>        nn.Linear(model.last_channel, num_classes),</span>
<span id="cb11-14"><a href="#cb11-14" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb11-15"><a href="#cb11-15" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-16"><a href="#cb11-16" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> model</span>
<span id="cb11-17"><a href="#cb11-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-18"><a href="#cb11-18" aria-hidden="true" tabindex="-1"></a><span class="co"># Training setup for fine-tuning</span></span>
<span id="cb11-19"><a href="#cb11-19" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> setup_training(model, num_classes, learning_rate<span class="op">=</span><span class="fl">0.001</span>):</span>
<span id="cb11-20"><a href="#cb11-20" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Set up the optimizer and loss function for head-only fine-tuning."""</span></span>
<span id="cb11-21"><a href="#cb11-21" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-22"><a href="#cb11-22" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Freeze the pretrained feature extractor (to fine-tune every layer, skip this loop and pass model.parameters() to the optimizer)</span></span>
<span id="cb11-23"><a href="#cb11-23" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> param <span class="kw">in</span> model.features.parameters():</span>
<span id="cb11-24"><a href="#cb11-24" aria-hidden="true" tabindex="-1"></a>        param.requires_grad <span class="op">=</span> <span class="va">False</span></span>
<span id="cb11-25"><a href="#cb11-25" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-26"><a href="#cb11-26" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Optimize only the classifier head's parameters</span></span>
<span id="cb11-27"><a href="#cb11-27" aria-hidden="true" tabindex="-1"></a>    optimizer <span class="op">=</span> optim.Adam(model.classifier.parameters(), lr<span class="op">=</span>learning_rate)</span>
<span id="cb11-28"><a href="#cb11-28" aria-hidden="true" tabindex="-1"></a>    criterion <span class="op">=</span> nn.CrossEntropyLoss()</span>
<span id="cb11-29"><a href="#cb11-29" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-30"><a href="#cb11-30" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> optimizer, criterion</span>
<span id="cb11-31"><a href="#cb11-31" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-32"><a href="#cb11-32" aria-hidden="true" tabindex="-1"></a><span class="co"># Training loop</span></span>
<span id="cb11-33"><a href="#cb11-33" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> train_epoch(model, dataloader, optimizer, criterion, device):</span>
<span id="cb11-34"><a href="#cb11-34" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Train model for one epoch."""</span></span>
<span id="cb11-35"><a href="#cb11-35" aria-hidden="true" tabindex="-1"></a>    model.train()</span>
<span id="cb11-36"><a href="#cb11-36" aria-hidden="true" tabindex="-1"></a>    running_loss <span class="op">=</span> <span class="fl">0.0</span></span>
<span id="cb11-37"><a href="#cb11-37" aria-hidden="true" tabindex="-1"></a>    correct <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb11-38"><a href="#cb11-38" aria-hidden="true" tabindex="-1"></a>    total <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb11-39"><a href="#cb11-39" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-40"><a href="#cb11-40" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> inputs, labels <span class="kw">in</span> dataloader:</span>
<span id="cb11-41"><a href="#cb11-41" aria-hidden="true" tabindex="-1"></a>        inputs, labels <span class="op">=</span> inputs.to(device), labels.to(device)</span>
<span id="cb11-42"><a href="#cb11-42" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-43"><a href="#cb11-43" aria-hidden="true" tabindex="-1"></a>        optimizer.zero_grad()</span>
<span id="cb11-44"><a href="#cb11-44" aria-hidden="true" tabindex="-1"></a>        outputs <span class="op">=</span> model(inputs)</span>
<span id="cb11-45"><a href="#cb11-45" aria-hidden="true" tabindex="-1"></a>        loss <span class="op">=</span> criterion(outputs, labels)</span>
<span id="cb11-46"><a href="#cb11-46" aria-hidden="true" tabindex="-1"></a>        loss.backward()</span>
<span id="cb11-47"><a href="#cb11-47" aria-hidden="true" tabindex="-1"></a>        optimizer.step()</span>
<span id="cb11-48"><a href="#cb11-48" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-49"><a href="#cb11-49" aria-hidden="true" tabindex="-1"></a>        running_loss <span class="op">+=</span> loss.item()</span>
<span id="cb11-50"><a href="#cb11-50" aria-hidden="true" tabindex="-1"></a>        _, predicted <span class="op">=</span> torch.<span class="bu">max</span>(outputs, <span class="dv">1</span>)</span>
<span id="cb11-51"><a href="#cb11-51" aria-hidden="true" tabindex="-1"></a>        total <span class="op">+=</span> labels.size(<span class="dv">0</span>)</span>
<span id="cb11-52"><a href="#cb11-52" aria-hidden="true" tabindex="-1"></a>        correct <span class="op">+=</span> (predicted <span class="op">==</span> labels).<span class="bu">sum</span>().item()</span>
<span id="cb11-53"><a href="#cb11-53" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-54"><a href="#cb11-54" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> running_loss <span class="op">/</span> <span class="bu">len</span>(dataloader), <span class="dv">100</span> <span class="op">*</span> correct <span class="op">/</span> total</span></code></pre></div></div>
</div>
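<p>Because <code>setup_training</code> sets <code>requires_grad = False</code> on <code>model.features</code> and gives Adam only <code>model.classifier.parameters()</code>, very few weights are actually learned. The plain-Python sketch below counts them for the <code>nn.Linear(model.last_channel, num_classes)</code> head (<code>nn.Dropout</code> has no parameters). It assumes <code>last_channel</code> is 1280, the value used by torchvision's MobileNetV2, and <code>linear_head_params</code> is a hypothetical helper written for illustration.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">def linear_head_params(in_features, num_classes):
    """Trainable parameters of nn.Linear(in_features, num_classes): weights plus biases."""
    return in_features * num_classes + num_classes


# Assumption: torchvision's MobileNetV2 reports model.last_channel == 1280.
LAST_CHANNEL = 1280

for num_classes in (2, 10, 100):
    # Prints 2562, 12810 and 128100 trainable parameters respectively.
    print(num_classes, linear_head_params(LAST_CHANNEL, num_classes))
</code></pre></div>
<p>In PyTorch the same figure can be read off a model with <code>sum(p.numel() for p in model.parameters() if p.requires_grad)</code>; next to the roughly 3.5 million parameters of the full network, it shows why head-only fine-tuning is fast. Note that <code>requires_grad = False</code> does not stop the BatchNorm layers inside <code>features</code> from updating their running statistics while <code>model.train()</code> is active; calling <code>model.features.eval()</code> after <code>model.train()</code> keeps the backbone completely fixed.</p>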
</section>
</section>
<section id="training-mobilenet" class="level2">
<h2 class="anchored" data-anchor-id="training-mobilenet" id="training-mobilenet">Training MobileNet</h2>
<section id="complete-training-pipeline" class="level3">
<h3 class="anchored" data-anchor-id="complete-training-pipeline" id="complete-training-pipeline">Complete Training Pipeline</h3>
<div id="training-pipeline" class="cell" data-caption="Complete training pipeline for MobileNet" data-execution_count="11">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb12"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><a href="#cb12-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb12-2"><a href="#cb12-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb12-3"><a href="#cb12-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.optim <span class="im">as</span> optim</span>
<span id="cb12-4"><a href="#cb12-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.utils.data <span class="im">import</span> DataLoader</span>
<span id="cb12-5"><a href="#cb12-5" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torchvision <span class="im">import</span> datasets, transforms</span>
<span id="cb12-6"><a href="#cb12-6" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb12-7"><a href="#cb12-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-8"><a href="#cb12-8" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MobileNetTrainer:</span>
<span id="cb12-9"><a href="#cb12-9" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, model, device<span class="op">=</span><span class="st">'cuda'</span>):</span>
<span id="cb12-10"><a href="#cb12-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> model.to(device)</span>
<span id="cb12-11"><a href="#cb12-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.device <span class="op">=</span> device</span>
<span id="cb12-12"><a href="#cb12-12" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.history <span class="op">=</span> {<span class="st">'train_loss'</span>: [], <span class="st">'train_acc'</span>: [], <span class="st">'val_loss'</span>: [], <span class="st">'val_acc'</span>: []}</span>
<span id="cb12-13"><a href="#cb12-13" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-14"><a href="#cb12-14" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> train(<span class="va">self</span>, train_loader, val_loader, epochs<span class="op">=</span><span class="dv">10</span>, lr<span class="op">=</span><span class="fl">0.001</span>):</span>
<span id="cb12-15"><a href="#cb12-15" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Complete training pipeline."""</span></span>
<span id="cb12-16"><a href="#cb12-16" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-17"><a href="#cb12-17" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Set up the optimizer, learning-rate scheduler and loss</span></span>
<span id="cb12-18"><a href="#cb12-18" aria-hidden="true" tabindex="-1"></a>        optimizer <span class="op">=</span> optim.RMSprop(<span class="va">self</span>.model.parameters(), lr<span class="op">=</span>lr, weight_decay<span class="op">=</span><span class="fl">1e-4</span>)</span>
<span id="cb12-19"><a href="#cb12-19" aria-hidden="true" tabindex="-1"></a>        scheduler <span class="op">=</span> optim.lr_scheduler.StepLR(optimizer, step_size<span class="op">=</span><span class="dv">7</span>, gamma<span class="op">=</span><span class="fl">0.1</span>)</span>
<span id="cb12-20"><a href="#cb12-20" aria-hidden="true" tabindex="-1"></a>        criterion <span class="op">=</span> nn.CrossEntropyLoss()</span>
<span id="cb12-21"><a href="#cb12-21" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-22"><a href="#cb12-22" aria-hidden="true" tabindex="-1"></a>        best_acc <span class="op">=</span> <span class="fl">0.0</span></span>
<span id="cb12-23"><a href="#cb12-23" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-24"><a href="#cb12-24" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(epochs):</span>
<span id="cb12-25"><a href="#cb12-25" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f'Epoch </span><span class="sc">{</span>epoch<span class="op">+</span><span class="dv">1</span><span class="sc">}</span><span class="ss">/</span><span class="sc">{</span>epochs<span class="sc">}</span><span class="ss">'</span>)</span>
<span id="cb12-26"><a href="#cb12-26" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="st">'-'</span> <span class="op">*</span> <span class="dv">10</span>)</span>
<span id="cb12-27"><a href="#cb12-27" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb12-28"><a href="#cb12-28" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Training phase</span></span>
<span id="cb12-29"><a href="#cb12-29" aria-hidden="true" tabindex="-1"></a>            train_loss, train_acc <span class="op">=</span> <span class="va">self</span>._train_epoch(train_loader, optimizer, criterion)</span>
<span id="cb12-30"><a href="#cb12-30" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb12-31"><a href="#cb12-31" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Validation phase</span></span>
<span id="cb12-32"><a href="#cb12-32" aria-hidden="true" tabindex="-1"></a>            val_loss, val_acc <span class="op">=</span> <span class="va">self</span>._validate_epoch(val_loader, criterion)</span>
<span id="cb12-33"><a href="#cb12-33" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb12-34"><a href="#cb12-34" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Update scheduler</span></span>
<span id="cb12-35"><a href="#cb12-35" aria-hidden="true" tabindex="-1"></a>            scheduler.step()</span>
<span id="cb12-36"><a href="#cb12-36" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb12-37"><a href="#cb12-37" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Save best model</span></span>
<span id="cb12-38"><a href="#cb12-38" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> val_acc <span class="op">&gt;</span> best_acc:</span>
<span id="cb12-39"><a href="#cb12-39" aria-hidden="true" tabindex="-1"></a>                best_acc <span class="op">=</span> val_acc</span>
<span id="cb12-40"><a href="#cb12-40" aria-hidden="true" tabindex="-1"></a>                torch.save(<span class="va">self</span>.model.state_dict(), <span class="st">'best_mobilenet.pth'</span>)</span>
<span id="cb12-41"><a href="#cb12-41" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb12-42"><a href="#cb12-42" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Update history</span></span>
<span id="cb12-43"><a href="#cb12-43" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.history[<span class="st">'train_loss'</span>].append(train_loss)</span>
<span id="cb12-44"><a href="#cb12-44" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.history[<span class="st">'train_acc'</span>].append(train_acc)</span>
<span id="cb12-45"><a href="#cb12-45" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.history[<span class="st">'val_loss'</span>].append(val_loss)</span>
<span id="cb12-46"><a href="#cb12-46" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.history[<span class="st">'val_acc'</span>].append(val_acc)</span>
<span id="cb12-47"><a href="#cb12-47" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb12-48"><a href="#cb12-48" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f'Train Loss: </span><span class="sc">{</span>train_loss<span class="sc">:.4f}</span><span class="ss">, Train Acc: </span><span class="sc">{</span>train_acc<span class="sc">:.2f}</span><span class="ss">%'</span>)</span>
<span id="cb12-49"><a href="#cb12-49" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f'Val Loss: </span><span class="sc">{</span>val_loss<span class="sc">:.4f}</span><span class="ss">, Val Acc: </span><span class="sc">{</span>val_acc<span class="sc">:.2f}</span><span class="ss">%'</span>)</span>
<span id="cb12-50"><a href="#cb12-50" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>()</span>
<span id="cb12-51"><a href="#cb12-51" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-52"><a href="#cb12-52" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _train_epoch(<span class="va">self</span>, dataloader, optimizer, criterion):</span>
<span id="cb12-53"><a href="#cb12-53" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Train for one epoch."""</span></span>
<span id="cb12-54"><a href="#cb12-54" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model.train()</span>
<span id="cb12-55"><a href="#cb12-55" aria-hidden="true" tabindex="-1"></a>        running_loss <span class="op">=</span> <span class="fl">0.0</span></span>
<span id="cb12-56"><a href="#cb12-56" aria-hidden="true" tabindex="-1"></a>        correct <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb12-57"><a href="#cb12-57" aria-hidden="true" tabindex="-1"></a>        total <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb12-58"><a href="#cb12-58" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-59"><a href="#cb12-59" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> inputs, labels <span class="kw">in</span> dataloader:</span>
<span id="cb12-60"><a href="#cb12-60" aria-hidden="true" tabindex="-1"></a>            inputs, labels <span class="op">=</span> inputs.to(<span class="va">self</span>.device), labels.to(<span class="va">self</span>.device)</span>
<span id="cb12-61"><a href="#cb12-61" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb12-62"><a href="#cb12-62" aria-hidden="true" tabindex="-1"></a>            optimizer.zero_grad()</span>
<span id="cb12-63"><a href="#cb12-63" aria-hidden="true" tabindex="-1"></a>            outputs <span class="op">=</span> <span class="va">self</span>.model(inputs)</span>
<span id="cb12-64"><a href="#cb12-64" aria-hidden="true" tabindex="-1"></a>            loss <span class="op">=</span> criterion(outputs, labels)</span>
<span id="cb12-65"><a href="#cb12-65" aria-hidden="true" tabindex="-1"></a>            loss.backward()</span>
<span id="cb12-66"><a href="#cb12-66" aria-hidden="true" tabindex="-1"></a>            optimizer.step()</span>
<span id="cb12-67"><a href="#cb12-67" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb12-68"><a href="#cb12-68" aria-hidden="true" tabindex="-1"></a>            running_loss <span class="op">+=</span> loss.item()</span>
<span id="cb12-69"><a href="#cb12-69" aria-hidden="true" tabindex="-1"></a>            _, predicted <span class="op">=</span> torch.<span class="bu">max</span>(outputs, <span class="dv">1</span>)</span>
<span id="cb12-70"><a href="#cb12-70" aria-hidden="true" tabindex="-1"></a>            total <span class="op">+=</span> labels.size(<span class="dv">0</span>)</span>
<span id="cb12-71"><a href="#cb12-71" aria-hidden="true" tabindex="-1"></a>            correct <span class="op">+=</span> (predicted <span class="op">==</span> labels).<span class="bu">sum</span>().item()</span>
<span id="cb12-72"><a href="#cb12-72" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-73"><a href="#cb12-73" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> running_loss <span class="op">/</span> <span class="bu">len</span>(dataloader), <span class="dv">100</span> <span class="op">*</span> correct <span class="op">/</span> total</span>
<span id="cb12-74"><a href="#cb12-74" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-75"><a href="#cb12-75" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _validate_epoch(<span class="va">self</span>, dataloader, criterion):</span>
<span id="cb12-76"><a href="#cb12-76" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Validate for one epoch."""</span></span>
<span id="cb12-77"><a href="#cb12-77" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model.<span class="bu">eval</span>()</span>
<span id="cb12-78"><a href="#cb12-78" aria-hidden="true" tabindex="-1"></a>        running_loss <span class="op">=</span> <span class="fl">0.0</span></span>
<span id="cb12-79"><a href="#cb12-79" aria-hidden="true" tabindex="-1"></a>        correct <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb12-80"><a href="#cb12-80" aria-hidden="true" tabindex="-1"></a>        total <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb12-81"><a href="#cb12-81" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-82"><a href="#cb12-82" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> torch.no_grad():</span>
<span id="cb12-83"><a href="#cb12-83" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> inputs, labels <span class="kw">in</span> dataloader:</span>
<span id="cb12-84"><a href="#cb12-84" aria-hidden="true" tabindex="-1"></a>                inputs, labels <span class="op">=</span> inputs.to(<span class="va">self</span>.device), labels.to(<span class="va">self</span>.device)</span>
<span id="cb12-85"><a href="#cb12-85" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb12-86"><a href="#cb12-86" aria-hidden="true" tabindex="-1"></a>                outputs <span class="op">=</span> <span class="va">self</span>.model(inputs)</span>
<span id="cb12-87"><a href="#cb12-87" aria-hidden="true" tabindex="-1"></a>                loss <span class="op">=</span> criterion(outputs, labels)</span>
<span id="cb12-88"><a href="#cb12-88" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb12-89"><a href="#cb12-89" aria-hidden="true" tabindex="-1"></a>                running_loss <span class="op">+=</span> loss.item()</span>
<span id="cb12-90"><a href="#cb12-90" aria-hidden="true" tabindex="-1"></a>                _, predicted <span class="op">=</span> torch.<span class="bu">max</span>(outputs, <span class="dv">1</span>)</span>
<span id="cb12-91"><a href="#cb12-91" aria-hidden="true" tabindex="-1"></a>                total <span class="op">+=</span> labels.size(<span class="dv">0</span>)</span>
<span id="cb12-92"><a href="#cb12-92" aria-hidden="true" tabindex="-1"></a>                correct <span class="op">+=</span> (predicted <span class="op">==</span> labels).<span class="bu">sum</span>().item()</span>
<span id="cb12-93"><a href="#cb12-93" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-94"><a href="#cb12-94" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> running_loss <span class="op">/</span> <span class="bu">len</span>(dataloader), <span class="dv">100</span> <span class="op">*</span> correct <span class="op">/</span> total</span></code></pre></div></div>
</div>
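<p><code>StepLR(optimizer, step_size=7, gamma=0.1)</code> multiplies the learning rate by 0.1 every 7 calls to <code>scheduler.step()</code>, and <code>train</code> makes one such call per epoch. The schedule therefore follows <code>lr = base_lr * gamma ** (epoch // step_size)</code> for a 0-indexed epoch. The sketch below reproduces that arithmetic in plain Python so the schedule can be checked without training anything; <code>step_lr</code> is a hypothetical helper, not a PyTorch function.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">def step_lr(base_lr, epoch, step_size=7, gamma=0.1):
    """Learning rate in effect during a 0-indexed epoch, assuming one scheduler.step() per epoch."""
    return base_lr * gamma ** (epoch // step_size)


# Defaults of MobileNetTrainer.train: lr=0.001, epochs=10.
for epoch in range(10):
    print(f"epoch {epoch + 1:2d}: lr = {step_lr(0.001, epoch):.0e}")
</code></pre></div>
<p>With the default <code>epochs=10</code>, epochs 1 to 7 run at 1e-03 and epochs 8 to 10 at 1e-04, so only one decay happens. The 20-epoch run in the example usage further down would decay twice, after epochs 7 and 14.</p>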
</section>
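<p>The epoch functions above (<code>train_epoch</code>, <code>_train_epoch</code> and <code>_validate_epoch</code>) all return <code>running_loss / len(dataloader)</code>, which is the mean of per-batch mean losses. When the dataset size is not a multiple of <code>batch_size</code>, the smaller final batch counts as much as a full one (the accuracy, computed from <code>correct / total</code>, is already a per-sample figure). The sketch below uses made-up loss values to contrast the two averages; both function names are hypothetical. In the PyTorch loops, accumulating <code>loss.item() * labels.size(0)</code> and dividing by <code>total</code> would give the per-sample version.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">def mean_of_batch_means(batch_losses):
    """What running_loss / len(dataloader) computes: every batch has equal weight."""
    return sum(batch_losses) / len(batch_losses)


def per_sample_mean(batch_losses, batch_sizes):
    """Weight each batch's mean loss by its size to get the average over samples."""
    total = sum(loss * n for loss, n in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)


# Made-up epoch: 100 samples with batch_size=32 gives batches of 32, 32, 32 and 4.
losses = [0.50, 0.50, 0.50, 2.00]
sizes = [32, 32, 32, 4]
print(mean_of_batch_means(losses))     # 0.875
print(per_sample_mean(losses, sizes))  # 0.56
</code></pre></div>
<p>With large datasets the two numbers are usually close; passing <code>drop_last=True</code> to the training <code>DataLoader</code> is another way to avoid a ragged final batch.</p>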
<section id="data-loading-and-augmentation" class="level3">
<h3 class="anchored" data-anchor-id="data-loading-and-augmentation" id="data-loading-and-augmentation">Data Loading and Augmentation</h3>
<div id="data-loading" class="cell" data-caption="Data loading and augmentation pipeline" data-execution_count="12">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb13"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Data loading and augmentation</span></span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> get_dataloaders(data_dir, batch_size<span class="op">=</span><span class="dv">32</span>, num_workers<span class="op">=</span><span class="dv">4</span>):</span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Create training and validation dataloaders."""</span></span>
<span id="cb13-4"><a href="#cb13-4" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-5"><a href="#cb13-5" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Data augmentation for training</span></span>
<span id="cb13-6"><a href="#cb13-6" aria-hidden="true" tabindex="-1"></a>    train_transforms <span class="op">=</span> transforms.Compose([</span>
<span id="cb13-7"><a href="#cb13-7" aria-hidden="true" tabindex="-1"></a>        transforms.RandomResizedCrop(<span class="dv">224</span>),</span>
<span id="cb13-8"><a href="#cb13-8" aria-hidden="true" tabindex="-1"></a>        transforms.RandomHorizontalFlip(),</span>
<span id="cb13-9"><a href="#cb13-9" aria-hidden="true" tabindex="-1"></a>        transforms.ColorJitter(brightness<span class="op">=</span><span class="fl">0.2</span>, contrast<span class="op">=</span><span class="fl">0.2</span>, saturation<span class="op">=</span><span class="fl">0.2</span>, hue<span class="op">=</span><span class="fl">0.1</span>),</span>
<span id="cb13-10"><a href="#cb13-10" aria-hidden="true" tabindex="-1"></a>        transforms.ToTensor(),</span>
<span id="cb13-11"><a href="#cb13-11" aria-hidden="true" tabindex="-1"></a>        transforms.Normalize([<span class="fl">0.485</span>, <span class="fl">0.456</span>, <span class="fl">0.406</span>], [<span class="fl">0.229</span>, <span class="fl">0.224</span>, <span class="fl">0.225</span>])</span>
<span id="cb13-12"><a href="#cb13-12" aria-hidden="true" tabindex="-1"></a>    ])</span>
<span id="cb13-13"><a href="#cb13-13" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-14"><a href="#cb13-14" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Validation transforms</span></span>
<span id="cb13-15"><a href="#cb13-15" aria-hidden="true" tabindex="-1"></a>    val_transforms <span class="op">=</span> transforms.Compose([</span>
<span id="cb13-16"><a href="#cb13-16" aria-hidden="true" tabindex="-1"></a>        transforms.Resize(<span class="dv">256</span>),</span>
<span id="cb13-17"><a href="#cb13-17" aria-hidden="true" tabindex="-1"></a>        transforms.CenterCrop(<span class="dv">224</span>),</span>
<span id="cb13-18"><a href="#cb13-18" aria-hidden="true" tabindex="-1"></a>        transforms.ToTensor(),</span>
<span id="cb13-19"><a href="#cb13-19" aria-hidden="true" tabindex="-1"></a>        transforms.Normalize([<span class="fl">0.485</span>, <span class="fl">0.456</span>, <span class="fl">0.406</span>], [<span class="fl">0.229</span>, <span class="fl">0.224</span>, <span class="fl">0.225</span>])</span>
<span id="cb13-20"><a href="#cb13-20" aria-hidden="true" tabindex="-1"></a>    ])</span>
<span id="cb13-21"><a href="#cb13-21" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-22"><a href="#cb13-22" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Create datasets</span></span>
<span id="cb13-23"><a href="#cb13-23" aria-hidden="true" tabindex="-1"></a>    train_dataset <span class="op">=</span> datasets.ImageFolder(<span class="ss">f'</span><span class="sc">{</span>data_dir<span class="sc">}</span><span class="ss">/train'</span>, train_transforms)</span>
<span id="cb13-24"><a href="#cb13-24" aria-hidden="true" tabindex="-1"></a>    val_dataset <span class="op">=</span> datasets.ImageFolder(<span class="ss">f'</span><span class="sc">{</span>data_dir<span class="sc">}</span><span class="ss">/val'</span>, val_transforms)</span>
<span id="cb13-25"><a href="#cb13-25" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-26"><a href="#cb13-26" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Create dataloaders</span></span>
<span id="cb13-27"><a href="#cb13-27" aria-hidden="true" tabindex="-1"></a>    train_loader <span class="op">=</span> DataLoader(train_dataset, batch_size<span class="op">=</span>batch_size, </span>
<span id="cb13-28"><a href="#cb13-28" aria-hidden="true" tabindex="-1"></a>                             shuffle<span class="op">=</span><span class="va">True</span>, num_workers<span class="op">=</span>num_workers)</span>
<span id="cb13-29"><a href="#cb13-29" aria-hidden="true" tabindex="-1"></a>    val_loader <span class="op">=</span> DataLoader(val_dataset, batch_size<span class="op">=</span>batch_size, </span>
<span id="cb13-30"><a href="#cb13-30" aria-hidden="true" tabindex="-1"></a>                           shuffle<span class="op">=</span><span class="va">False</span>, num_workers<span class="op">=</span>num_workers)</span>
<span id="cb13-31"><a href="#cb13-31" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-32"><a href="#cb13-32" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> train_loader, val_loader, <span class="bu">len</span>(train_dataset.classes)</span>
<span id="cb13-33"><a href="#cb13-33" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-34"><a href="#cb13-34" aria-hidden="true" tabindex="-1"></a><span class="co"># Example usage</span></span>
<span id="cb13-35"><a href="#cb13-35" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb13-36"><a href="#cb13-36" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Setup</span></span>
<span id="cb13-37"><a href="#cb13-37" aria-hidden="true" tabindex="-1"></a>    device <span class="op">=</span> torch.device(<span class="st">'cuda'</span> <span class="cf">if</span> torch.cuda.is_available() <span class="cf">else</span> <span class="st">'cpu'</span>)</span>
<span id="cb13-38"><a href="#cb13-38" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-39"><a href="#cb13-39" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Load data</span></span>
<span id="cb13-40"><a href="#cb13-40" aria-hidden="true" tabindex="-1"></a>    <span class="co"># train_loader, val_loader, num_classes = get_dataloaders('path/to/data')</span></span>
<span id="cb13-41"><a href="#cb13-41" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-42"><a href="#cb13-42" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Create model</span></span>
<span id="cb13-43"><a href="#cb13-43" aria-hidden="true" tabindex="-1"></a>    <span class="co"># model = mobilenet_v1(num_classes=num_classes, width_mult=1.0)</span></span>
<span id="cb13-44"><a href="#cb13-44" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-45"><a href="#cb13-45" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Train</span></span>
<span id="cb13-46"><a href="#cb13-46" aria-hidden="true" tabindex="-1"></a>    <span class="co"># trainer = MobileNetTrainer(model, device)</span></span>
<span id="cb13-47"><a href="#cb13-47" aria-hidden="true" tabindex="-1"></a>    <span class="co"># trainer.train(train_loader, val_loader, epochs=20)</span></span>
<span id="cb13-48"><a href="#cb13-48" aria-hidden="true" tabindex="-1"></a>    <span class="cf">pass</span></span></code></pre></div></div>
</div>
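<p>At validation time the pipeline is deterministic: <code>Resize(256)</code> scales the shorter side of the image to 256 pixels while preserving the aspect ratio, <code>CenterCrop(224)</code> keeps the central 224×224 window, and <code>Normalize</code> maps each channel to <code>(x - mean) / std</code> with the ImageNet statistics. The plain-Python sketch below traces that geometry for one image size. The helpers are hypothetical, and the rounding (truncating the long side, rounding the crop offset) is assumed to match torchvision's behavior, so treat exact pixel offsets as approximate.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def resize_shorter_side(width, height, size):
    """Approximate output (width, height) of transforms.Resize(size) for an int size:
    the shorter side becomes `size`, the longer side keeps the aspect ratio (truncated)."""
    short_side, long_side = min(width, height), max(width, height)
    new_long = int(size * long_side / short_side)
    if width == short_side:
        return size, new_long
    return new_long, size


def center_crop_box(width, height, crop):
    """(left, top, right, bottom) region kept by transforms.CenterCrop(crop)."""
    left = int(round((width - crop) / 2.0))
    top = int(round((height - crop) / 2.0))
    return left, top, left + crop, top + crop


def normalize_pixel(rgb):
    """Per-channel (x - mean) / std, as transforms.Normalize applies after ToTensor."""
    return tuple((v - m) / s for v, m, s in zip(rgb, MEAN, STD))


w, h = resize_shorter_side(640, 480, 256)   # a 640x480 landscape photo
print((w, h))                               # (341, 256)
print(center_crop_box(w, h, 224))           # (58, 16, 282, 240)
print(normalize_pixel(MEAN))                # (0.0, 0.0, 0.0)
</code></pre></div>
<p>Training uses <code>RandomResizedCrop(224)</code> instead, which samples a random region (by default 8% to 100% of the image area, with an aspect ratio between 3/4 and 4/3) and resizes it to 224×224, so the model sees varied framings while validation always sees the same centered view.</p>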
</section>
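<p><code>get_dataloaders</code> expects <code>data_dir</code> to contain <code>train/</code> and <code>val/</code> folders with one sub-folder per class (for example <code>train/cat/001.jpg</code>), and it returns <code>len(train_dataset.classes)</code> as the class count. The stdlib-only sketch below imitates how <code>ImageFolder</code> derives class names and label indices from that layout; <code>find_classes</code> here is a hypothetical stand-in written for illustration. Because labels come from the sorted folder names, <code>train/</code> and <code>val/</code> should contain the same set of class folders, otherwise one index can refer to different classes in the two splits.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import os
import tempfile


def find_classes(root):
    """ImageFolder-style class discovery: each sub-directory of root is a class,
    sorted by name and numbered from 0."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    return classes, {name: idx for idx, name in enumerate(classes)}


# Build a throwaway data_dir/train tree with two hypothetical classes.
with tempfile.TemporaryDirectory() as data_dir:
    for cls in ("dog", "cat"):
        os.makedirs(os.path.join(data_dir, "train", cls))
    # A loose file is ignored: only directories become classes.
    with open(os.path.join(data_dir, "train", "notes.txt"), "w") as fh:
        fh.write("not a class")
    classes, class_to_idx = find_classes(os.path.join(data_dir, "train"))
    print(classes)        # ['cat', 'dog']
    print(class_to_idx)   # {'cat': 0, 'dog': 1}
</code></pre></div>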
</section>
<section id="mobilenet-variants" class="level2">
<h2 class="anchored" data-anchor-id="mobilenet-variants" id="mobilenet-variants">MobileNet Variants</h2>
<section id="mobilenetv2-implementation" class="level3">
<h3 class="anchored" data-anchor-id="mobilenetv2-implementation" id="mobilenetv2-implementation">MobileNetV2 Implementation</h3>
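<p>The inverted residual block implemented below follows an expand, filter, project pattern: a 1×1 convolution widens the input by <code>expand_ratio</code>, a 3×3 depthwise convolution filters each channel separately, and a 1×1 linear projection maps back down without an activation. Counting parameters by hand shows why this is cheap. The plain-Python sketch assumes bias-free convolutions that are each followed by a <code>BatchNorm2d</code> (two learnable values per channel), as in the standard MobileNetV2 design; the function names are illustrative only.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">def inverted_residual_params(in_ch, out_ch, expand_ratio, kernel=3):
    """Learnable parameters of one inverted residual block (convs without bias,
    each followed by a BatchNorm2d with a weight and a bias per channel)."""
    hidden = int(round(in_ch * expand_ratio))
    total = 0
    if expand_ratio != 1:
        total += in_ch * hidden + 2 * hidden        # 1x1 expansion conv + BN
    total += hidden * kernel * kernel + 2 * hidden  # depthwise conv + BN
    total += hidden * out_ch + 2 * out_ch           # 1x1 linear projection + BN
    return total


def standard_conv_params(in_ch, out_ch, kernel=3):
    """A plain bias-free kernel x kernel convolution plus BatchNorm, for comparison."""
    return in_ch * out_ch * kernel * kernel + 2 * out_ch


print(inverted_residual_params(24, 24, 6))  # 8832
print(standard_conv_params(144, 144))       # 186912
print(inverted_residual_params(32, 16, 1))  # 896 (no expansion layer)
</code></pre></div>
<p>A 24-to-24 block with <code>expand_ratio=6</code> works on 144 channels internally with about 8.8k parameters, against roughly 187k for a single ordinary 3×3 convolution at 144 channels. If the module below matches these assumptions, <code>sum(p.numel() for p in block.parameters())</code> on an <code>InvertedResidual(24, 24, 1, 6)</code> instance should report the same count.</p>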
<div id="mobilenetv2" class="cell" data-caption="MobileNetV2 with inverted residual blocks" data-execution_count="13">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb14"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb14-1"><a href="#cb14-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> InvertedResidual(nn.Module):</span>
<span id="cb14-2"><a href="#cb14-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Inverted residual block for MobileNetV2."""</span></span>
<span id="cb14-3"><a href="#cb14-3" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb14-4"><a href="#cb14-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, in_channels, out_channels, stride, expand_ratio):</span>
<span id="cb14-5"><a href="#cb14-5" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb14-6"><a href="#cb14-6" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-7"><a href="#cb14-7" aria-hidden="true" tabindex="-1"></a>        hidden_dim <span class="op">=</span> <span class="bu">int</span>(<span class="bu">round</span>(in_channels <span class="op">*</span> expand_ratio))</span>
<span id="cb14-8"><a href="#cb14-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.use_residual <span class="op">=</span> stride <span class="op">==</span> <span class="dv">1</span> <span class="kw">and</span> in_channels <span class="op">==</span> out_channels</span>
<span id="cb14-9"><a href="#cb14-9" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-10"><a href="#cb14-10" aria-hidden="true" tabindex="-1"></a>        layers <span class="op">=</span> []</span>
<span id="cb14-11"><a href="#cb14-11" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-12"><a href="#cb14-12" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Expansion phase</span></span>
<span id="cb14-13"><a href="#cb14-13" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> expand_ratio <span class="op">!=</span> <span class="dv">1</span>:</span>
<span id="cb14-14"><a href="#cb14-14" aria-hidden="true" tabindex="-1"></a>            layers.extend([</span>
<span id="cb14-15"><a href="#cb14-15" aria-hidden="true" tabindex="-1"></a>                nn.Conv2d(in_channels, hidden_dim, <span class="dv">1</span>, bias<span class="op">=</span><span class="va">False</span>),</span>
<span id="cb14-16"><a href="#cb14-16" aria-hidden="true" tabindex="-1"></a>                nn.BatchNorm2d(hidden_dim),</span>
<span id="cb14-17"><a href="#cb14-17" aria-hidden="true" tabindex="-1"></a>                nn.ReLU6(inplace<span class="op">=</span><span class="va">True</span>),</span>
<span id="cb14-18"><a href="#cb14-18" aria-hidden="true" tabindex="-1"></a>            ])</span>
<span id="cb14-19"><a href="#cb14-19" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-20"><a href="#cb14-20" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Depthwise convolution</span></span>
<span id="cb14-21"><a href="#cb14-21" aria-hidden="true" tabindex="-1"></a>        layers.extend([</span>
<span id="cb14-22"><a href="#cb14-22" aria-hidden="true" tabindex="-1"></a>            nn.Conv2d(hidden_dim, hidden_dim, <span class="dv">3</span>, stride, <span class="dv">1</span>, </span>
<span id="cb14-23"><a href="#cb14-23" aria-hidden="true" tabindex="-1"></a>                     groups<span class="op">=</span>hidden_dim, bias<span class="op">=</span><span class="va">False</span>),</span>
<span id="cb14-24"><a href="#cb14-24" aria-hidden="true" tabindex="-1"></a>            nn.BatchNorm2d(hidden_dim),</span>
<span id="cb14-25"><a href="#cb14-25" aria-hidden="true" tabindex="-1"></a>            nn.ReLU6(inplace<span class="op">=</span><span class="va">True</span>),</span>
<span id="cb14-26"><a href="#cb14-26" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb14-27"><a href="#cb14-27" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Pointwise linear projection</span></span>
<span id="cb14-28"><a href="#cb14-28" aria-hidden="true" tabindex="-1"></a>            nn.Conv2d(hidden_dim, out_channels, <span class="dv">1</span>, bias<span class="op">=</span><span class="va">False</span>),</span>
<span id="cb14-29"><a href="#cb14-29" aria-hidden="true" tabindex="-1"></a>            nn.BatchNorm2d(out_channels),</span>
<span id="cb14-30"><a href="#cb14-30" aria-hidden="true" tabindex="-1"></a>        ])</span>
<span id="cb14-31"><a href="#cb14-31" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-32"><a href="#cb14-32" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.conv <span class="op">=</span> nn.Sequential(<span class="op">*</span>layers)</span>
<span id="cb14-33"><a href="#cb14-33" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb14-34"><a href="#cb14-34" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb14-35"><a href="#cb14-35" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="va">self</span>.use_residual:</span>
<span id="cb14-36"><a href="#cb14-36" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> x <span class="op">+</span> <span class="va">self</span>.conv(x)</span>
<span id="cb14-37"><a href="#cb14-37" aria-hidden="true" tabindex="-1"></a>        <span class="cf">else</span>:</span>
<span id="cb14-38"><a href="#cb14-38" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> <span class="va">self</span>.conv(x)</span>
<span id="cb14-39"><a href="#cb14-39" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-40"><a href="#cb14-40" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MobileNetV2(nn.Module):</span>
<span id="cb14-41"><a href="#cb14-41" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""MobileNetV2 with inverted residuals and linear bottlenecks."""</span></span>
<span id="cb14-42"><a href="#cb14-42" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb14-43"><a href="#cb14-43" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, num_classes<span class="op">=</span><span class="dv">1000</span>, width_mult<span class="op">=</span><span class="fl">1.0</span>):</span>
<span id="cb14-44"><a href="#cb14-44" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>(MobileNetV2, <span class="va">self</span>).<span class="fu">__init__</span>()</span>
<span id="cb14-45"><a href="#cb14-45" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-46"><a href="#cb14-46" aria-hidden="true" tabindex="-1"></a>        input_channel <span class="op">=</span> <span class="dv">32</span></span>
<span id="cb14-47"><a href="#cb14-47" aria-hidden="true" tabindex="-1"></a>        last_channel <span class="op">=</span> <span class="dv">1280</span></span>
<span id="cb14-48"><a href="#cb14-48" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-49"><a href="#cb14-49" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Inverted residual settings</span></span>
<span id="cb14-50"><a href="#cb14-50" aria-hidden="true" tabindex="-1"></a>        <span class="co"># t: expansion factor, c: output channels, n: number of blocks, s: stride</span></span>
<span id="cb14-51"><a href="#cb14-51" aria-hidden="true" tabindex="-1"></a>        inverted_residual_setting <span class="op">=</span> [</span>
<span id="cb14-52"><a href="#cb14-52" aria-hidden="true" tabindex="-1"></a>            <span class="co"># t, c, n, s</span></span>
<span id="cb14-53"><a href="#cb14-53" aria-hidden="true" tabindex="-1"></a>            [<span class="dv">1</span>, <span class="dv">16</span>, <span class="dv">1</span>, <span class="dv">1</span>],</span>
<span id="cb14-54"><a href="#cb14-54" aria-hidden="true" tabindex="-1"></a>            [<span class="dv">6</span>, <span class="dv">24</span>, <span class="dv">2</span>, <span class="dv">2</span>],</span>
<span id="cb14-55"><a href="#cb14-55" aria-hidden="true" tabindex="-1"></a>            [<span class="dv">6</span>, <span class="dv">32</span>, <span class="dv">3</span>, <span class="dv">2</span>],</span>
<span id="cb14-56"><a href="#cb14-56" aria-hidden="true" tabindex="-1"></a>            [<span class="dv">6</span>, <span class="dv">64</span>, <span class="dv">4</span>, <span class="dv">2</span>],</span>
<span id="cb14-57"><a href="#cb14-57" aria-hidden="true" tabindex="-1"></a>            [<span class="dv">6</span>, <span class="dv">96</span>, <span class="dv">3</span>, <span class="dv">1</span>],</span>
<span id="cb14-58"><a href="#cb14-58" aria-hidden="true" tabindex="-1"></a>            [<span class="dv">6</span>, <span class="dv">160</span>, <span class="dv">3</span>, <span class="dv">2</span>],</span>
<span id="cb14-59"><a href="#cb14-59" aria-hidden="true" tabindex="-1"></a>            [<span class="dv">6</span>, <span class="dv">320</span>, <span class="dv">1</span>, <span class="dv">1</span>],</span>
<span id="cb14-60"><a href="#cb14-60" aria-hidden="true" tabindex="-1"></a>        ]</span>
<span id="cb14-61"><a href="#cb14-61" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-62"><a href="#cb14-62" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Apply width multiplier</span></span>
<span id="cb14-63"><a href="#cb14-63" aria-hidden="true" tabindex="-1"></a>        input_channel <span class="op">=</span> <span class="bu">int</span>(input_channel <span class="op">*</span> width_mult)</span>
<span id="cb14-64"><a href="#cb14-64" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.last_channel <span class="op">=</span> <span class="bu">int</span>(last_channel <span class="op">*</span> <span class="bu">max</span>(<span class="fl">1.0</span>, width_mult))</span>
<span id="cb14-65"><a href="#cb14-65" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-66"><a href="#cb14-66" aria-hidden="true" tabindex="-1"></a>        <span class="co"># First convolution</span></span>
<span id="cb14-67"><a href="#cb14-67" aria-hidden="true" tabindex="-1"></a>        features <span class="op">=</span> [nn.Sequential(</span>
<span id="cb14-68"><a href="#cb14-68" aria-hidden="true" tabindex="-1"></a>            nn.Conv2d(<span class="dv">3</span>, input_channel, <span class="dv">3</span>, <span class="dv">2</span>, <span class="dv">1</span>, bias<span class="op">=</span><span class="va">False</span>),</span>
<span id="cb14-69"><a href="#cb14-69" aria-hidden="true" tabindex="-1"></a>            nn.BatchNorm2d(input_channel),</span>
<span id="cb14-70"><a href="#cb14-70" aria-hidden="true" tabindex="-1"></a>            nn.ReLU6(inplace<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb14-71"><a href="#cb14-71" aria-hidden="true" tabindex="-1"></a>        )]</span>
<span id="cb14-72"><a href="#cb14-72" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-73"><a href="#cb14-73" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Inverted residual blocks</span></span>
<span id="cb14-74"><a href="#cb14-74" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> t, c, n, s <span class="kw">in</span> inverted_residual_setting:</span>
<span id="cb14-75"><a href="#cb14-75" aria-hidden="true" tabindex="-1"></a>            output_channel <span class="op">=</span> <span class="bu">int</span>(c <span class="op">*</span> width_mult)</span>
<span id="cb14-76"><a href="#cb14-76" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(n):</span>
<span id="cb14-77"><a href="#cb14-77" aria-hidden="true" tabindex="-1"></a>                stride <span class="op">=</span> s <span class="cf">if</span> i <span class="op">==</span> <span class="dv">0</span> <span class="cf">else</span> <span class="dv">1</span></span>
<span id="cb14-78"><a href="#cb14-78" aria-hidden="true" tabindex="-1"></a>                features.append(InvertedResidual(input_channel, output_channel, </span>
<span id="cb14-79"><a href="#cb14-79" aria-hidden="true" tabindex="-1"></a>                                               stride, t))</span>
<span id="cb14-80"><a href="#cb14-80" aria-hidden="true" tabindex="-1"></a>                input_channel <span class="op">=</span> output_channel</span>
<span id="cb14-81"><a href="#cb14-81" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-82"><a href="#cb14-82" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Last convolution</span></span>
<span id="cb14-83"><a href="#cb14-83" aria-hidden="true" tabindex="-1"></a>        features.append(nn.Sequential(</span>
<span id="cb14-84"><a href="#cb14-84" aria-hidden="true" tabindex="-1"></a>            nn.Conv2d(input_channel, <span class="va">self</span>.last_channel, <span class="dv">1</span>, bias<span class="op">=</span><span class="va">False</span>),</span>
<span id="cb14-85"><a href="#cb14-85" aria-hidden="true" tabindex="-1"></a>            nn.BatchNorm2d(<span class="va">self</span>.last_channel),</span>
<span id="cb14-86"><a href="#cb14-86" aria-hidden="true" tabindex="-1"></a>            nn.ReLU6(inplace<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb14-87"><a href="#cb14-87" aria-hidden="true" tabindex="-1"></a>        ))</span>
<span id="cb14-88"><a href="#cb14-88" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-89"><a href="#cb14-89" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.features <span class="op">=</span> nn.Sequential(<span class="op">*</span>features)</span>
<span id="cb14-90"><a href="#cb14-90" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-91"><a href="#cb14-91" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Classifier</span></span>
<span id="cb14-92"><a href="#cb14-92" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.classifier <span class="op">=</span> nn.Sequential(</span>
<span id="cb14-93"><a href="#cb14-93" aria-hidden="true" tabindex="-1"></a>            nn.Dropout(<span class="fl">0.2</span>),</span>
<span id="cb14-94"><a href="#cb14-94" aria-hidden="true" tabindex="-1"></a>            nn.Linear(<span class="va">self</span>.last_channel, num_classes),</span>
<span id="cb14-95"><a href="#cb14-95" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb14-96"><a href="#cb14-96" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb14-97"><a href="#cb14-97" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb14-98"><a href="#cb14-98" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.features(x)</span>
<span id="cb14-99"><a href="#cb14-99" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> nn.functional.adaptive_avg_pool2d(x, (<span class="dv">1</span>, <span class="dv">1</span>))</span>
<span id="cb14-100"><a href="#cb14-100" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> torch.flatten(x, <span class="dv">1</span>)</span>
<span id="cb14-101"><a href="#cb14-101" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.classifier(x)</span>
<span id="cb14-102"><a href="#cb14-102" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x</span></code></pre></div></div>
</div>
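<p>The sketch below makes the savings from the depthwise-plus-pointwise factorization concrete. It counts the weights of a standard 3×3 convolution and of a depthwise 3×3 followed by a 1×1 convolution, using example sizes of 32 input and 64 output channels (arbitrary choices for illustration). It also multiplies the strides from the <code>inverted_residual_setting</code> table to check the network's overall downsampling factor.</p>

```python
import math

import torch.nn as nn


def count_params(module: nn.Module) -> int:
    """Total number of learnable parameters in a module."""
    return sum(p.numel() for p in module.parameters())


# Example sizes chosen for illustration only
c_in, c_out, k = 32, 64, 3

# Standard 3x3 convolution: every output channel sees every input channel
standard = nn.Conv2d(c_in, c_out, k, padding=1, bias=False)

# Depthwise 3x3 (one filter per input channel) + pointwise 1x1 channel mixing
separable = nn.Sequential(
    nn.Conv2d(c_in, c_in, k, padding=1, groups=c_in, bias=False),
    nn.Conv2d(c_in, c_out, 1, bias=False),
)

n_standard = count_params(standard)    # c_in * c_out * k * k
n_separable = count_params(separable)  # c_in * k * k + c_in * c_out
print(f"standard: {n_standard}, separable: {n_separable}, "
      f"ratio: {n_standard / n_separable:.1f}x")

# Output stride: the stem conv (stride 2) times every stage stride s
settings = [  # t, c, n, s (same values as the MobileNetV2 table above)
    [1, 16, 1, 1], [6, 24, 2, 2], [6, 32, 3, 2], [6, 64, 4, 2],
    [6, 96, 3, 1], [6, 160, 3, 2], [6, 320, 1, 1],
]
total_stride = 2 * math.prod(s for _, _, _, s in settings)
print(f"total stride: {total_stride}, so a 224x224 input gives a "
      f"{224 // total_stride}x{224 // total_stride} final feature map")
```

<p>The parameter ratio is c_out·k² / (k² + c_out), which approaches k² = 9 as the number of output channels grows. Multiply-adds per output position shrink by the same factor.</p>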
</section>
<section id="mobilenetv3-features" class="level3">
<h3 class="anchored" data-anchor-id="mobilenetv3-features" id="mobilenetv3-features">MobileNetV3 Features</h3>
<div id="mobilenetv3-components" class="cell" data-caption="Key components of MobileNetV3" data-execution_count="14">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb15"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb15-1"><a href="#cb15-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> SEBlock(nn.Module):</span>
<span id="cb15-2"><a href="#cb15-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Squeeze-and-Excitation block."""</span></span>
<span id="cb15-3"><a href="#cb15-3" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-4"><a href="#cb15-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, in_channels, reduction<span class="op">=</span><span class="dv">4</span>):</span>
<span id="cb15-5"><a href="#cb15-5" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>(SEBlock, <span class="va">self</span>).<span class="fu">__init__</span>()</span>
<span id="cb15-6"><a href="#cb15-6" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-7"><a href="#cb15-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.se <span class="op">=</span> nn.Sequential(</span>
<span id="cb15-8"><a href="#cb15-8" aria-hidden="true" tabindex="-1"></a>            nn.AdaptiveAvgPool2d(<span class="dv">1</span>),</span>
<span id="cb15-9"><a href="#cb15-9" aria-hidden="true" tabindex="-1"></a>            nn.Conv2d(in_channels, in_channels <span class="op">//</span> reduction, <span class="dv">1</span>),</span>
<span id="cb15-10"><a href="#cb15-10" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(inplace<span class="op">=</span><span class="va">True</span>),</span>
<span id="cb15-11"><a href="#cb15-11" aria-hidden="true" tabindex="-1"></a>            nn.Conv2d(in_channels <span class="op">//</span> reduction, in_channels, <span class="dv">1</span>),</span>
<span id="cb15-12"><a href="#cb15-12" aria-hidden="true" tabindex="-1"></a>            nn.Hardsigmoid(inplace<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb15-13"><a href="#cb15-13" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb15-14"><a href="#cb15-14" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-15"><a href="#cb15-15" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb15-16"><a href="#cb15-16" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x <span class="op">*</span> <span class="va">self</span>.se(x)</span>
<span id="cb15-17"><a href="#cb15-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-18"><a href="#cb15-18" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> HardSwish(nn.Module):</span>
<span id="cb15-19"><a href="#cb15-19" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Hard Swish: x * ReLU6(x + 3) / 6 (equivalent to the built-in nn.Hardswish)."""</span></span>
<span id="cb15-20"><a href="#cb15-20" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-21"><a href="#cb15-21" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb15-22"><a href="#cb15-22" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x <span class="op">*</span> F.hardsigmoid(x)</span>
<span id="cb15-23"><a href="#cb15-23" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-24"><a href="#cb15-24" aria-hidden="true" tabindex="-1"></a><span class="co"># MobileNetV3 would use these components along with:</span></span>
<span id="cb15-25"><a href="#cb15-25" aria-hidden="true" tabindex="-1"></a><span class="co"># - Neural Architecture Search (NAS) for optimal architecture</span></span>
<span id="cb15-26"><a href="#cb15-26" aria-hidden="true" tabindex="-1"></a><span class="co"># - Hard Swish activation instead of ReLU6</span></span>
<span id="cb15-27"><a href="#cb15-27" aria-hidden="true" tabindex="-1"></a><span class="co"># - Squeeze-and-Excitation blocks</span></span>
<span id="cb15-28"><a href="#cb15-28" aria-hidden="true" tabindex="-1"></a><span class="co"># - Optimized last layers</span></span></code></pre></div></div>
</div>
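<p>The <code>HardSwish</code> module above replaces the sigmoid in swish, x · sigmoid(x), with the piecewise-linear hard sigmoid ReLU6(x + 3) / 6, which avoids computing an exponential. The check below is a small sketch that assumes PyTorch 1.7 or later for the built-in <code>nn.Hardswish</code> and <code>nn.SiLU</code>. It confirms that the formula matches the built-in and measures the largest gap from the smooth swish on [-6, 6].</p>

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.linspace(-6.0, 6.0, steps=1201)

# Manual hard-swish: x * hard_sigmoid(x), with hard_sigmoid(x) = ReLU6(x + 3) / 6
manual = x * F.relu6(x + 3.0) / 6.0
builtin = nn.Hardswish()(x)
swish = nn.SiLU()(x)  # smooth swish: x * sigmoid(x)

matches_builtin = torch.allclose(manual, builtin, atol=1e-6)
max_gap = (manual - swish).abs().max().item()
print(f"matches nn.Hardswish: {matches_builtin}")
print(f"max |hard_swish - swish| on [-6, 6]: {max_gap:.3f}")
```

<p>The gap is largest near |x| = 3, where the hard sigmoid saturates at 0 or 1, and stays below 0.15 on this grid.</p>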
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>MobileNetV3 Improvements
</div>
</div>
<div class="callout-body-container callout-body">
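<p>MobileNetV3 combines several techniques on top of MobileNetV2's inverted residuals:</p>
<ul>
<li><strong>Neural Architecture Search</strong>: platform-aware NAS chooses the block structure, then NetAdapt tunes per-layer channel counts, both guided by on-device latency</li>
<li><strong>Squeeze-and-Excitation blocks</strong>: lightweight channel attention that rescales each feature map using global context</li>
<li><strong>Hard Swish activation</strong>: a piecewise-linear approximation of swish that is cheaper to compute on mobile hardware</li>
<li><strong>Redesigned head and tail</strong>: the first and last stages are restructured to reduce latency</li>
</ul>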
</ul>
</div>
</div>
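<p>The "attention" in a Squeeze-and-Excitation block is a per-channel gate broadcast over all spatial positions. The sketch below runs the same three steps as <code>SEBlock</code> on a random tensor with illustrative sizes: average pooling, two 1×1 convolutions with reduction 4, and a hard-sigmoid gate. Each step is written out separately so the shape at every stage is visible.</p>

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n, c, h, w, reduction = 2, 16, 14, 14, 4  # illustrative sizes
x = torch.randn(n, c, h, w)

# Squeeze: global average pooling gives one descriptor per channel
squeezed = x.mean(dim=(2, 3), keepdim=True)  # (n, c, 1, 1)

# Excitation: a bottleneck of two 1x1 convs, then a gate in [0, 1]
excite = nn.Sequential(
    nn.Conv2d(c, c // reduction, 1),
    nn.ReLU(),
    nn.Conv2d(c // reduction, c, 1),
    nn.Hardsigmoid(),
)
gates = excite(squeezed)  # (n, c, 1, 1)

# Scale: broadcasting multiplies every spatial position of a channel by its gate
out = x * gates  # (n, c, h, w)

print("squeezed:", tuple(squeezed.shape))
print("gates:   ", tuple(gates.shape), f"range [{gates.min():.2f}, {gates.max():.2f}]")
print("output:  ", tuple(out.shape))
```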
</section>
</section>
<section id="optimization-techniques" class="level2">
<h2 class="anchored" data-anchor-id="optimization-techniques" id="optimization-techniques">Optimization Techniques</h2>
<section id="quantization" class="level3">
<h3 class="anchored" data-anchor-id="quantization" id="quantization">Quantization</h3>
<div id="quantization" class="cell" data-caption="Post-training and dynamic quantization techniques" data-execution_count="15">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb16"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb16-1"><a href="#cb16-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> copy</span>
<span id="cb16-2"><a href="#cb16-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb16-3"><a href="#cb16-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb16-4"><a href="#cb16-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.quantization <span class="im">as</span> quantization</span>
<span id="cb16-5"><a href="#cb16-5" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.ao.quantization <span class="im">import</span> get_default_qconfig_mapping</span>
<span id="cb16-6"><a href="#cb16-6" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.ao.quantization.quantize_fx <span class="im">import</span> convert_fx, prepare_fx</span>
<span id="cb16-7"><a href="#cb16-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-8"><a href="#cb16-8" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> quantize_mobilenet(model, calibration_loader, backend<span class="op">=</span><span class="st">'qnnpack'</span>):</span>
<span id="cb16-9"><a href="#cb16-9" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Apply post-training static quantization to MobileNet (FX graph mode)."""</span></span>
<span id="cb16-10"><a href="#cb16-10" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-11"><a href="#cb16-11" aria-hidden="true" tabindex="-1"></a>    <span class="co"># FX graph mode inserts quant/dequant stubs and quantizes the residual</span></span>
<span id="cb16-12"><a href="#cb16-12" aria-hidden="true" tabindex="-1"></a>    <span class="co"># additions automatically; eager mode would need model edits</span></span>
<span id="cb16-13"><a href="#cb16-13" aria-hidden="true" tabindex="-1"></a>    <span class="co"># (QuantStub/DeQuantStub, plus FloatFunctional for x + self.conv(x))</span></span>
<span id="cb16-14"><a href="#cb16-14" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> copy.deepcopy(model).<span class="bu">eval</span>()</span>
<span id="cb16-15"><a href="#cb16-15" aria-hidden="true" tabindex="-1"></a>    torch.backends.quantized.engine <span class="op">=</span> backend  <span class="co"># 'qnnpack' targets ARM/mobile CPUs</span></span>
<span id="cb16-16"><a href="#cb16-16" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-17"><a href="#cb16-17" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Default qconfigs (observer settings) for the chosen backend</span></span>
<span id="cb16-18"><a href="#cb16-18" aria-hidden="true" tabindex="-1"></a>    qconfig_mapping <span class="op">=</span> get_default_qconfig_mapping(backend)</span>
<span id="cb16-19"><a href="#cb16-19" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-20"><a href="#cb16-20" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Trace the model, fuse Conv+BN, and insert observers</span></span>
<span id="cb16-21"><a href="#cb16-21" aria-hidden="true" tabindex="-1"></a>    example_inputs <span class="op">=</span> (<span class="bu">next</span>(<span class="bu">iter</span>(calibration_loader))[<span class="dv">0</span>],)</span>
<span id="cb16-22"><a href="#cb16-22" aria-hidden="true" tabindex="-1"></a>    model_prepared <span class="op">=</span> prepare_fx(model, qconfig_mapping, example_inputs)</span>
<span id="cb16-23"><a href="#cb16-23" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-24"><a href="#cb16-24" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Calibrate with representative data</span></span>
<span id="cb16-25"><a href="#cb16-25" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> torch.no_grad():</span>
<span id="cb16-26"><a href="#cb16-26" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> inputs, _ <span class="kw">in</span> calibration_loader:</span>
<span id="cb16-27"><a href="#cb16-27" aria-hidden="true" tabindex="-1"></a>            model_prepared(inputs)</span>
<span id="cb16-28"><a href="#cb16-28" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-29"><a href="#cb16-29" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Replace observed modules with int8 quantized kernels</span></span>
<span id="cb16-30"><a href="#cb16-30" aria-hidden="true" tabindex="-1"></a>    model_quantized <span class="op">=</span> convert_fx(model_prepared)</span>
<span id="cb16-31"><a href="#cb16-31" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-32"><a href="#cb16-32" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> model_quantized</span>
<span id="cb16-33"><a href="#cb16-33" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-34"><a href="#cb16-34" aria-hidden="true" tabindex="-1"></a><span class="co"># Dynamic quantization (no calibration needed, but limited layer coverage)</span></span>
<span id="cb16-35"><a href="#cb16-35" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> dynamic_quantize_mobilenet(model):</span>
<span id="cb16-36"><a href="#cb16-36" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Apply dynamic quantization."""</span></span>
<span id="cb16-37"><a href="#cb16-37" aria-hidden="true" tabindex="-1"></a>    <span class="co"># The default dynamic mappings cover nn.Linear and recurrent layers, not</span></span>
<span id="cb16-38"><a href="#cb16-38" aria-hidden="true" tabindex="-1"></a>    <span class="co"># nn.Conv2d, so for MobileNet only the classifier gets quantized</span></span>
<span id="cb16-39"><a href="#cb16-39" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> quantization.quantize_dynamic(</span>
<span id="cb16-40"><a href="#cb16-40" aria-hidden="true" tabindex="-1"></a>        model, </span>
<span id="cb16-41"><a href="#cb16-41" aria-hidden="true" tabindex="-1"></a>        {nn.Linear}, </span>
<span id="cb16-42"><a href="#cb16-42" aria-hidden="true" tabindex="-1"></a>        dtype<span class="op">=</span>torch.qint8</span>
<span id="cb16-43"><a href="#cb16-43" aria-hidden="true" tabindex="-1"></a>    )</span></code></pre></div></div>
</div>
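<p>Both functions above rely on affine 8-bit quantization. In the unsigned case (quint8, the default for activations in static quantization), a float x is stored as q = clamp(round(x / scale) + zero_point, 0, 255) and recovered as x_hat = (q - zero_point) · scale. The sketch below picks scale and zero point from the tensor's minimum and maximum, roughly what a min-max observer does during calibration. This is a simplification, since PyTorch's observers can differ in details. The sketch then compares the manual arithmetic with <code>torch.quantize_per_tensor</code>.</p>

```python
import torch

x = torch.linspace(-1.0, 3.0, steps=101)  # example float values to quantize

# Min-max calibration for unsigned 8-bit (quint8): map [x_min, x_max] onto [0, 255]
qmin, qmax = 0, 255
x_min = min(x.min().item(), 0.0)  # keep 0 inside the range so it is exactly representable
x_max = max(x.max().item(), 0.0)
scale = (x_max - x_min) / (qmax - qmin)
zero_point = int(min(max(round(qmin - x_min / scale), qmin), qmax))

# Manual affine quantize / dequantize
q_manual = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
x_hat_manual = (q_manual - zero_point) * scale

# Same scale and zero point through PyTorch's quantized tensor type
q_torch = torch.quantize_per_tensor(x, scale, zero_point, dtype=torch.quint8)
x_hat_torch = q_torch.dequantize()

err_manual = (x - x_hat_manual).abs().max().item()
max_err = (x - x_hat_torch).abs().max().item()
int_diff = (q_torch.int_repr().float() - q_manual).abs().max().item()
print(f"scale={scale:.5f}, zero_point={zero_point}")
print(f"max round-trip error: {max_err:.5f} (bound scale/2 = {scale / 2:.5f})")
print(f"max integer disagreement with the manual formula: {int_diff}")
```

<p>For values inside the calibrated range, the per-element round-trip error is bounded by scale / 2. Values outside that range are clamped and can incur much larger errors, which is why the calibration data should cover the real activation range.</p>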
</section>
<section id="pruning" class="level3">
<h3 class="anchored" data-anchor-id="pruning" id="pruning">Pruning</h3>
<div id="pruning" class="cell" data-caption="Magnitude-based and structured pruning techniques" data-execution_count="16">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb17"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb17-1"><a href="#cb17-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn.utils.prune <span class="im">as</span> prune</span>
<span id="cb17-2"><a href="#cb17-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-3"><a href="#cb17-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> prune_mobilenet(model, pruning_ratio<span class="op">=</span><span class="fl">0.2</span>):</span>
<span id="cb17-4"><a href="#cb17-4" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Apply magnitude-based pruning to MobileNet."""</span></span>
<span id="cb17-5"><a href="#cb17-5" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb17-6"><a href="#cb17-6" aria-hidden="true" tabindex="-1"></a>    parameters_to_prune <span class="op">=</span> []</span>
<span id="cb17-7"><a href="#cb17-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb17-8"><a href="#cb17-8" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Collect Conv2d and Linear layers for pruning</span></span>
<span id="cb17-9"><a href="#cb17-9" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> name, module <span class="kw">in</span> model.named_modules():</span>
<span id="cb17-10"><a href="#cb17-10" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="bu">isinstance</span>(module, (nn.Conv2d, nn.Linear)):</span>
<span id="cb17-11"><a href="#cb17-11" aria-hidden="true" tabindex="-1"></a>            parameters_to_prune.append((module, <span class="st">'weight'</span>))</span>
<span id="cb17-12"><a href="#cb17-12" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb17-13"><a href="#cb17-13" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Apply global magnitude pruning</span></span>
<span id="cb17-14"><a href="#cb17-14" aria-hidden="true" tabindex="-1"></a>    prune.global_unstructured(</span>
<span id="cb17-15"><a href="#cb17-15" aria-hidden="true" tabindex="-1"></a>        parameters_to_prune,</span>
<span id="cb17-16"><a href="#cb17-16" aria-hidden="true" tabindex="-1"></a>        pruning_method<span class="op">=</span>prune.L1Unstructured,</span>
<span id="cb17-17"><a href="#cb17-17" aria-hidden="true" tabindex="-1"></a>        amount<span class="op">=</span>pruning_ratio,</span>
<span id="cb17-18"><a href="#cb17-18" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb17-19"><a href="#cb17-19" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb17-20"><a href="#cb17-20" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Make pruning permanent</span></span>
<span id="cb17-21"><a href="#cb17-21" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> module, param_name <span class="kw">in</span> parameters_to_prune:</span>
<span id="cb17-22"><a href="#cb17-22" aria-hidden="true" tabindex="-1"></a>        prune.remove(module, param_name)</span>
<span id="cb17-23"><a href="#cb17-23" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb17-24"><a href="#cb17-24" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> model</span>
<span id="cb17-25"><a href="#cb17-25" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-26"><a href="#cb17-26" aria-hidden="true" tabindex="-1"></a><span class="co"># Structured pruning example</span></span>
<span id="cb17-27"><a href="#cb17-27" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> structured_prune_mobilenet(model, pruning_ratio<span class="op">=</span><span class="fl">0.2</span>):</span>
<span id="cb17-28"><a href="#cb17-28" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Apply structured channel pruning."""</span></span>
<span id="cb17-29"><a href="#cb17-29" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb17-30"><a href="#cb17-30" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> name, module <span class="kw">in</span> model.named_modules():</span>
<span id="cb17-31"><a href="#cb17-31" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="bu">isinstance</span>(module, nn.Conv2d) <span class="kw">and</span> module.groups <span class="op">==</span> <span class="dv">1</span>:  <span class="co"># Skip depthwise</span></span>
<span id="cb17-32"><a href="#cb17-32" aria-hidden="true" tabindex="-1"></a>            prune.ln_structured(</span>
<span id="cb17-33"><a href="#cb17-33" aria-hidden="true" tabindex="-1"></a>                module, </span>
<span id="cb17-34"><a href="#cb17-34" aria-hidden="true" tabindex="-1"></a>                name<span class="op">=</span><span class="st">'weight'</span>, </span>
<span id="cb17-35"><a href="#cb17-35" aria-hidden="true" tabindex="-1"></a>                amount<span class="op">=</span>pruning_ratio, </span>
<span id="cb17-36"><a href="#cb17-36" aria-hidden="true" tabindex="-1"></a>                n<span class="op">=</span><span class="dv">2</span>, </span>
<span id="cb17-37"><a href="#cb17-37" aria-hidden="true" tabindex="-1"></a>                dim<span class="op">=</span><span class="dv">0</span>  <span class="co"># Prune output channels</span></span>
<span id="cb17-38"><a href="#cb17-38" aria-hidden="true" tabindex="-1"></a>            )</span>
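<span id="cb17-39"><a href="#cb17-39" aria-hidden="true" tabindex="-1"></a>            prune.remove(module, <span class="st">'weight'</span>)  <span class="co"># Bake the mask into the weights</span></span>
<span id="cb17-40"><a href="#cb17-40" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb17-41"><a href="#cb17-41" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Pruned channels are zeroed, not removed: tensor shapes and FLOPs</span></span>
<span id="cb17-42"><a href="#cb17-42" aria-hidden="true" tabindex="-1"></a>    <span class="co"># stay the same until the layers are rebuilt with fewer channels</span></span>
<span id="cb17-43"><a href="#cb17-43" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> model</span></code></pre></div></div>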
</div>
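<p>Pruning only zeroes values, so counting zeros is a simple way to confirm what each call did. The sketch below applies the <code>torch.nn.utils.prune</code> calls used above to a small stand-in network whose layer sizes are arbitrary assumptions, not MobileNet's. It reports global weight sparsity after unstructured pruning and the number of fully zeroed output channels after structured pruning.</p>

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)


def weight_sparsity(model: nn.Module) -> float:
    """Fraction of zero entries across all Conv2d/Linear weight tensors."""
    weights = [m.weight for m in model.modules() if isinstance(m, (nn.Conv2d, nn.Linear))]
    zeros = sum(int((w == 0).sum()) for w in weights)
    total = sum(w.numel() for w in weights)
    return zeros / total


def zeroed_filters(conv: nn.Conv2d) -> int:
    """Number of output channels whose weights are all zero."""
    return int((conv.weight.flatten(1).abs().sum(dim=1) == 0).sum())


# Small stand-in network with arbitrary sizes (not MobileNet's)
model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 32, 1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(32, 10),
)

# Unstructured: zero the 20% smallest-magnitude weights, ranked across all layers
params = [(m, "weight") for m in model.modules() if isinstance(m, (nn.Conv2d, nn.Linear))]
prune.global_unstructured(params, pruning_method=prune.L1Unstructured, amount=0.2)
for module, name in params:
    prune.remove(module, name)
global_sparsity = weight_sparsity(model)
print(f"global weight sparsity: {global_sparsity:.3f}")

# Structured: zero the 25% of output channels with the smallest L2 norm
conv = model[2]
prune.ln_structured(conv, name="weight", amount=0.25, n=2, dim=0)
prune.remove(conv, "weight")
n_zeroed = zeroed_filters(conv)
print(f"zeroed filters in model[2]: {n_zeroed} / {conv.out_channels}")
```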
</section>
</section>
<section id="deployment-considerations" class="level2">
<h2 class="anchored" data-anchor-id="deployment-considerations" id="deployment-considerations">Deployment Considerations</h2>
<section id="onnx-export" class="level3">
<h3 class="anchored" data-anchor-id="onnx-export" id="onnx-export">ONNX Export</h3>
<div id="onnx-export" class="cell" data-caption="Exporting MobileNet to ONNX format for cross-platform deployment" data-execution_count="17">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb18"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb18-1"><a href="#cb18-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.onnx</span>
<span id="cb18-2"><a href="#cb18-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-3"><a href="#cb18-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> export_to_onnx(model, input_shape<span class="op">=</span>(<span class="dv">1</span>, <span class="dv">3</span>, <span class="dv">224</span>, <span class="dv">224</span>), onnx_path<span class="op">=</span><span class="st">"mobilenet.onnx"</span>):</span>
<span id="cb18-4"><a href="#cb18-4" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Export MobileNet to ONNX format."""</span></span>
<span id="cb18-5"><a href="#cb18-5" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-6"><a href="#cb18-6" aria-hidden="true" tabindex="-1"></a>    model.<span class="bu">eval</span>()</span>
<span id="cb18-7"><a href="#cb18-7" aria-hidden="true" tabindex="-1"></a>    dummy_input <span class="op">=</span> torch.randn(input_shape, device<span class="op">=</span><span class="bu">next</span>(model.parameters()).device)</span>
<span id="cb18-8"><a href="#cb18-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-9"><a href="#cb18-9" aria-hidden="true" tabindex="-1"></a>    torch.onnx.export(</span>
<span id="cb18-10"><a href="#cb18-10" aria-hidden="true" tabindex="-1"></a>        model,</span>
<span id="cb18-11"><a href="#cb18-11" aria-hidden="true" tabindex="-1"></a>        dummy_input,</span>
<span id="cb18-12"><a href="#cb18-12" aria-hidden="true" tabindex="-1"></a>        onnx_path,</span>
<span id="cb18-13"><a href="#cb18-13" aria-hidden="true" tabindex="-1"></a>        export_params<span class="op">=</span><span class="va">True</span>,</span>
<span id="cb18-14"><a href="#cb18-14" aria-hidden="true" tabindex="-1"></a>        opset_version<span class="op">=</span><span class="dv">11</span>,</span>
<span id="cb18-15"><a href="#cb18-15" aria-hidden="true" tabindex="-1"></a>        do_constant_folding<span class="op">=</span><span class="va">True</span>,</span>
<span id="cb18-16"><a href="#cb18-16" aria-hidden="true" tabindex="-1"></a>        input_names<span class="op">=</span>[<span class="st">'input'</span>],</span>
<span id="cb18-17"><a href="#cb18-17" aria-hidden="true" tabindex="-1"></a>        output_names<span class="op">=</span>[<span class="st">'output'</span>],</span>
<span id="cb18-18"><a href="#cb18-18" aria-hidden="true" tabindex="-1"></a>        dynamic_axes<span class="op">=</span>{</span>
<span id="cb18-19"><a href="#cb18-19" aria-hidden="true" tabindex="-1"></a>            <span class="st">'input'</span>: {<span class="dv">0</span>: <span class="st">'batch_size'</span>},</span>
<span id="cb18-20"><a href="#cb18-20" aria-hidden="true" tabindex="-1"></a>            <span class="st">'output'</span>: {<span class="dv">0</span>: <span class="st">'batch_size'</span>}</span>
<span id="cb18-21"><a href="#cb18-21" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb18-22"><a href="#cb18-22" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb18-23"><a href="#cb18-23" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-24"><a href="#cb18-24" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Model exported to </span><span class="sc">{</span>onnx_path<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb18-25"><a href="#cb18-25" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-26"><a href="#cb18-26" aria-hidden="true" tabindex="-1"></a><span class="co"># TensorRT optimization (requires tensorrt)</span></span>
<span id="cb18-27"><a href="#cb18-27" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> optimize_with_tensorrt(onnx_path):</span>
<span id="cb18-28"><a href="#cb18-28" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Optimize ONNX model with TensorRT."""</span></span>
<span id="cb18-29"><a href="#cb18-29" aria-hidden="true" tabindex="-1"></a>    <span class="cf">try</span>:</span>
<span id="cb18-30"><a href="#cb18-30" aria-hidden="true" tabindex="-1"></a>        <span class="im">import</span> tensorrt <span class="im">as</span> trt</span>
<span id="cb18-31"><a href="#cb18-31" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-32"><a href="#cb18-32" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Create TensorRT logger and builder</span></span>
<span id="cb18-33"><a href="#cb18-33" aria-hidden="true" tabindex="-1"></a>        logger <span class="op">=</span> trt.Logger(trt.Logger.WARNING)</span>
<span id="cb18-34"><a href="#cb18-34" aria-hidden="true" tabindex="-1"></a>        builder <span class="op">=</span> trt.Builder(logger)</span>
<span id="cb18-35"><a href="#cb18-35" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-36"><a href="#cb18-36" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Parse ONNX model</span></span>
<span id="cb18-37"><a href="#cb18-37" aria-hidden="true" tabindex="-1"></a>        network <span class="op">=</span> builder.create_network(<span class="dv">1</span> <span class="op">&lt;&lt;</span> <span class="bu">int</span>(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))</span>
<span id="cb18-38"><a href="#cb18-38" aria-hidden="true" tabindex="-1"></a>        parser <span class="op">=</span> trt.OnnxParser(network, logger)</span>
<span id="cb18-39"><a href="#cb18-39" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-40"><a href="#cb18-40" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> <span class="bu">open</span>(onnx_path, <span class="st">'rb'</span>) <span class="im">as</span> f:</span>
<span id="cb18-41"><a href="#cb18-41" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="kw">not</span> parser.parse(f.read()):</span>
<span id="cb18-42"><a href="#cb18-42" aria-hidden="true" tabindex="-1"></a>                <span class="cf">raise</span> <span class="pp">RuntimeError</span>(<span class="bu">str</span>(parser.get_error(<span class="dv">0</span>)))</span>
<span id="cb18-43"><a href="#cb18-43" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Build optimized engine</span></span>
<span id="cb18-44"><a href="#cb18-44" aria-hidden="true" tabindex="-1"></a>        config <span class="op">=</span> builder.create_builder_config()</span>
<span id="cb18-45"><a href="#cb18-45" aria-hidden="true" tabindex="-1"></a>        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, <span class="dv">1</span> <span class="op">&lt;&lt;</span> <span class="dv">28</span>)  <span class="co"># 256 MiB workspace</span></span>
<span id="cb18-46"><a href="#cb18-46" aria-hidden="true" tabindex="-1"></a>        config.set_flag(trt.BuilderFlag.FP16)  <span class="co"># Enable FP16 precision</span></span>
<span id="cb18-47"><a href="#cb18-47" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-48"><a href="#cb18-48" aria-hidden="true" tabindex="-1"></a>        serialized_engine <span class="op">=</span> builder.build_serialized_network(network, config)</span>
<span id="cb18-49"><a href="#cb18-49" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> serialized_engine <span class="kw">is</span> <span class="va">None</span>:</span>
<span id="cb18-50"><a href="#cb18-50" aria-hidden="true" tabindex="-1"></a>            <span class="cf">raise</span> <span class="pp">RuntimeError</span>(<span class="st">"TensorRT engine build failed"</span>)</span>
<span id="cb18-51"><a href="#cb18-51" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> <span class="bu">open</span>(<span class="st">"mobilenet.trt"</span>, <span class="st">"wb"</span>) <span class="im">as</span> f:  <span class="co"># Save the serialized engine</span></span>
<span id="cb18-52"><a href="#cb18-52" aria-hidden="true" tabindex="-1"></a>            f.write(serialized_engine)</span>
<span id="cb18-53"><a href="#cb18-53" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-54"><a href="#cb18-54" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> serialized_engine</span>
<span id="cb18-55"><a href="#cb18-55" aria-hidden="true" tabindex="-1"></a>    <span class="cf">except</span> <span class="pp">ImportError</span>:</span>
<span id="cb18-56"><a href="#cb18-56" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="st">"TensorRT not installed. Please install TensorRT for optimization."</span>)</span>
<span id="cb18-57"><a href="#cb18-57" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">None</span></span></code></pre></div></div>
</div>
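<p>Because the TensorRT engine above is built with the FP16 flag, its outputs will not match the PyTorch model bit-for-bit, so a tolerance-based comparison is the usual sanity check after export. The pure-Python sketch below uses hypothetical helpers (<code>to_fp16</code>, <code>max_abs_diff</code>, <code>outputs_close</code>) and made-up logits: it simulates FP16 rounding with the standard-library <code>struct</code> module's half-precision <code>'e'</code> format and applies the same form of check as <code>numpy.isclose</code>. In a real pipeline you would compare arrays from <code>model(x)</code> against ONNX Runtime's <code>InferenceSession.run</code> or the TensorRT engine, typically with <code>numpy.testing.assert_allclose</code>.</p>

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value.

    struct's 'e' format packs a float as binary16, so a pack/unpack round
    trip shows the rounding an FP16 engine applies to a single value.
    """
    return struct.unpack("<e", struct.pack("<e", x))[0]


def max_abs_diff(a, b) -> float:
    """Largest element-wise absolute difference between two flat sequences."""
    if len(a) != len(b):
        raise ValueError("outputs have different lengths")
    return max((abs(x - y) for x, y in zip(a, b)), default=0.0)


def outputs_close(actual, expected, rtol=1e-3, atol=1e-5) -> bool:
    """Element-wise |actual - expected| <= atol + rtol * |expected| (numpy.isclose form)."""
    if len(actual) != len(expected):
        return False
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(actual, expected))


if __name__ == "__main__":
    # Hypothetical FP32 logits standing in for the PyTorch reference output.
    reference = [2.3172, -0.8841, 0.0153, 5.9027, -3.3310]
    # Simulate what an FP16 engine would return for the same values.
    fp16_output = [to_fp16(v) for v in reference]

    print("max abs diff:", max_abs_diff(fp16_output, reference))
    print("close with rtol=1e-3:", outputs_close(fp16_output, reference))
    print("bit-exact:", fp16_output == reference)
```

<p>FP16 keeps 11 significant bits, so a single rounding step costs at most about 5e-4 in relative terms; errors accumulate through a deep network, so the tolerance for final logits may need to be looser than this single-value sketch suggests.</p>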
</section>
<section id="mobile-deployment" class="level3">
<h3 class="anchored" data-anchor-id="mobile-deployment" id="mobile-deployment">Mobile Deployment</h3>
<div id="mobile-deployment" class="cell" data-caption="Converting models for mobile deployment platforms" data-execution_count="18">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb19"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb19-1"><a href="#cb19-1" aria-hidden="true" tabindex="-1"></a><span class="co"># TensorFlow Lite conversion (if using TensorFlow)</span></span>
<span id="cb19-2"><a href="#cb19-2" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> convert_to_tflite(model_path, tflite_path<span class="op">=</span><span class="st">"mobilenet.tflite"</span>):</span>
<span id="cb19-3"><a href="#cb19-3" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Convert a TensorFlow SavedModel to TensorFlow Lite with FP16 weights."""</span></span>
<span id="cb19-4"><a href="#cb19-4" aria-hidden="true" tabindex="-1"></a>    <span class="cf">try</span>:</span>
<span id="cb19-5"><a href="#cb19-5" aria-hidden="true" tabindex="-1"></a>        <span class="im">import</span> tensorflow <span class="im">as</span> tf</span>
<span id="cb19-6"><a href="#cb19-6" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb19-7"><a href="#cb19-7" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Load from a TensorFlow SavedModel directory (not a PyTorch checkpoint)</span></span>
<span id="cb19-8"><a href="#cb19-8" aria-hidden="true" tabindex="-1"></a>        converter <span class="op">=</span> tf.lite.TFLiteConverter.from_saved_model(model_path)</span>
<span id="cb19-9"><a href="#cb19-9" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb19-10"><a href="#cb19-10" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Optimization settings</span></span>
<span id="cb19-11"><a href="#cb19-11" aria-hidden="true" tabindex="-1"></a>        converter.optimizations <span class="op">=</span> [tf.lite.Optimize.DEFAULT]</span>
<span id="cb19-12"><a href="#cb19-12" aria-hidden="true" tabindex="-1"></a>        converter.target_spec.supported_types <span class="op">=</span> [tf.float16]</span>
<span id="cb19-13"><a href="#cb19-13" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb19-14"><a href="#cb19-14" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Convert</span></span>
<span id="cb19-15"><a href="#cb19-15" aria-hidden="true" tabindex="-1"></a>        tflite_model <span class="op">=</span> converter.convert()</span>
<span id="cb19-16"><a href="#cb19-16" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb19-17"><a href="#cb19-17" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Save</span></span>
<span id="cb19-18"><a href="#cb19-18" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> <span class="bu">open</span>(tflite_path, <span class="st">'wb'</span>) <span class="im">as</span> f:</span>
<span id="cb19-19"><a href="#cb19-19" aria-hidden="true" tabindex="-1"></a>            f.write(tflite_model)</span>
<span id="cb19-20"><a href="#cb19-20" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb19-21"><a href="#cb19-21" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"TFLite model saved to </span><span class="sc">{</span>tflite_path<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb19-22"><a href="#cb19-22" aria-hidden="true" tabindex="-1"></a>    <span class="cf">except</span> <span class="pp">ImportError</span>:</span>
<span id="cb19-23"><a href="#cb19-23" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="st">"TensorFlow not installed. Install with: pip install tensorflow"</span>)</span>
<span id="cb19-24"><a href="#cb19-24" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-25"><a href="#cb19-25" aria-hidden="true" tabindex="-1"></a><span class="co"># CoreML conversion (for iOS)</span></span>
<span id="cb19-26"><a href="#cb19-26" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> convert_to_coreml(model, input_shape<span class="op">=</span>(<span class="dv">1</span>, <span class="dv">3</span>, <span class="dv">224</span>, <span class="dv">224</span>)):</span>
<span id="cb19-27"><a href="#cb19-27" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Convert PyTorch model to CoreML format."""</span></span>
<span id="cb19-28"><a href="#cb19-28" aria-hidden="true" tabindex="-1"></a>    <span class="cf">try</span>:</span>
<span id="cb19-29"><a href="#cb19-29" aria-hidden="true" tabindex="-1"></a>        <span class="im">import</span> coremltools <span class="im">as</span> ct</span>
<span id="cb19-30"><a href="#cb19-30" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb19-31"><a href="#cb19-31" aria-hidden="true" tabindex="-1"></a>        model.<span class="bu">eval</span>()</span>
<span id="cb19-32"><a href="#cb19-32" aria-hidden="true" tabindex="-1"></a>        example_input <span class="op">=</span> torch.rand(input_shape)</span>
<span id="cb19-33"><a href="#cb19-33" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb19-34"><a href="#cb19-34" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Trace the model</span></span>
<span id="cb19-35"><a href="#cb19-35" aria-hidden="true" tabindex="-1"></a>        traced_model <span class="op">=</span> torch.jit.trace(model, example_input)</span>
<span id="cb19-36"><a href="#cb19-36" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb19-37"><a href="#cb19-37" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Convert to CoreML; scale/bias approximate ImageNet mean/std normalization</span></span>
<span id="cb19-38"><a href="#cb19-38" aria-hidden="true" tabindex="-1"></a>        coreml_model <span class="op">=</span> ct.convert(</span>
<span id="cb19-39"><a href="#cb19-39" aria-hidden="true" tabindex="-1"></a>            traced_model,</span>
<span id="cb19-40"><a href="#cb19-40" aria-hidden="true" tabindex="-1"></a>            inputs<span class="op">=</span>[ct.ImageType(shape<span class="op">=</span>input_shape, scale<span class="op">=</span><span class="dv">1</span><span class="op">/</span>(<span class="fl">0.226</span> <span class="op">*</span> <span class="fl">255.0</span>), bias<span class="op">=</span>[<span class="op">-</span><span class="fl">0.485</span><span class="op">/</span><span class="fl">0.229</span>, <span class="op">-</span><span class="fl">0.456</span><span class="op">/</span><span class="fl">0.224</span>, <span class="op">-</span><span class="fl">0.406</span><span class="op">/</span><span class="fl">0.225</span>])]</span>
<span id="cb19-41"><a href="#cb19-41" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb19-42"><a href="#cb19-42" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb19-43"><a href="#cb19-43" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Save (ML Program models, the default since coremltools 7, require .mlpackage)</span></span>
<span id="cb19-44"><a href="#cb19-44" aria-hidden="true" tabindex="-1"></a>        coreml_model.save(<span class="st">"MobileNet.mlpackage"</span>)</span>
<span id="cb19-45"><a href="#cb19-45" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="st">"CoreML model saved successfully"</span>)</span>
<span id="cb19-46"><a href="#cb19-46" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb19-47"><a href="#cb19-47" aria-hidden="true" tabindex="-1"></a>    <span class="cf">except</span> <span class="pp">ImportError</span>:</span>
<span id="cb19-48"><a href="#cb19-48" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="st">"coremltools not installed. Install with: pip install coremltools"</span>)</span></code></pre></div></div>
</div>
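<p><code>ct.ImageType</code> normalizes each 0 to 255 input pixel as <code>scale * pixel + bias</code>, with one scalar <code>scale</code> shared by all channels and one <code>bias</code> per channel. The pure-Python sketch below (all function names are hypothetical) measures how far two common settings land from the ImageNet mean/std normalization that <code>transforms.Normalize</code> applies: the [-1, 1] mapping (<code>scale=1/127.5</code>, <code>bias=-1</code>) and an averaged-std approximation (<code>scale=1/(0.226*255)</code>, per-channel <code>bias=-mean/std</code>). It assumes the model was trained on ImageNet-normalized inputs, as the edge-deployment preprocessing in this post implies.</p>

```python
# ImageNet channel statistics (the values passed to transforms.Normalize).
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def exact_imagenet(pixel: int, c: int) -> float:
    """torchvision-style normalization of a 0-255 pixel value in channel c."""
    return (pixel / 255.0 - MEAN[c]) / STD[c]


def coreml_affine(pixel: int, scale: float, bias: float) -> float:
    """The per-channel transform ct.ImageType applies: scale * pixel + bias."""
    return scale * pixel + bias


def worst_error(scale: float, biases) -> float:
    """Max |affine - exact| over every 0-255 pixel value and all three channels."""
    return max(
        abs(coreml_affine(p, scale, biases[c]) - exact_imagenet(p, c))
        for c in range(3)
        for p in range(256)
    )


# Option A: map pixels to [-1, 1] (scale=1/127.5, bias=-1 on every channel).
err_minus1_1 = worst_error(1 / 127.5, (-1.0, -1.0, -1.0))

# Option B: approximate ImageNet normalization with one averaged std of 0.226,
# because ImageType accepts only a single scale shared across channels.
imagenet_bias = tuple(-m / s for m, s in zip(MEAN, STD))
err_avg_std = worst_error(1 / (0.226 * 255.0), imagenet_bias)

print(f"[-1, 1] mapping, worst error: {err_minus1_1:.3f}")
print(f"averaged-std mapping, worst error: {err_avg_std:.3f}")
```

<p>For an ImageNet-normalized model, the [-1, 1] mapping is off by more than one unit at the extremes, which can silently degrade accuracy, while the averaged-std approximation stays within about 0.06 because only the scale is approximated. A model trained on [-1, 1] inputs would need the first setting instead.</p>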
</section>
<section id="edge-deployment-with-optimization" class="level3">
<h3 class="anchored" data-anchor-id="edge-deployment-with-optimization" id="edge-deployment-with-optimization">Edge Deployment with Optimization</h3>
<div id="edge-deployment" class="cell" data-caption="Optimized inference class for edge deployment" data-execution_count="19">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb20"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb20-1"><a href="#cb20-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Edge deployment with optimization</span></span>
<span id="cb20-2"><a href="#cb20-2" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> OptimizedMobileNetInference:</span>
<span id="cb20-3"><a href="#cb20-3" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Optimized inference class for edge deployment."""</span></span>
<span id="cb20-4"><a href="#cb20-4" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb20-5"><a href="#cb20-5" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, model_path, device<span class="op">=</span><span class="st">'cpu'</span>):</span>
<span id="cb20-6"><a href="#cb20-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.device <span class="op">=</span> device</span>
<span id="cb20-7"><a href="#cb20-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> <span class="va">self</span>.load_optimized_model(model_path)</span>
<span id="cb20-8"><a href="#cb20-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.preprocess <span class="op">=</span> <span class="va">self</span>.get_preprocessing()</span>
<span id="cb20-9"><a href="#cb20-9" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb20-10"><a href="#cb20-10" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> load_optimized_model(<span class="va">self</span>, model_path):</span>
<span id="cb20-11"><a href="#cb20-11" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Load and optimize model for inference."""</span></span>
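<span id="cb20-12"><a href="#cb20-12" aria-hidden="true" tabindex="-1"></a>        model <span class="op">=</span> torch.load(model_path, map_location<span class="op">=</span><span class="va">self</span>.device, weights_only<span class="op">=</span><span class="va">False</span>)  <span class="co"># full pickled model: load trusted files only</span></span>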
<span id="cb20-13"><a href="#cb20-13" aria-hidden="true" tabindex="-1"></a>        model.<span class="bu">eval</span>()</span>
<span id="cb20-14"><a href="#cb20-14" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb20-15"><a href="#cb20-15" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Apply optimizations</span></span>
<span id="cb20-16"><a href="#cb20-16" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="va">self</span>.device <span class="op">==</span> <span class="st">'cpu'</span>:</span>
<span id="cb20-17"><a href="#cb20-17" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Optimize for CPU inference</span></span>
<span id="cb20-18"><a href="#cb20-18" aria-hidden="true" tabindex="-1"></a>            model <span class="op">=</span> torch.jit.optimize_for_inference(torch.jit.script(model))</span>
<span id="cb20-19"><a href="#cb20-19" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb20-20"><a href="#cb20-20" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> model</span>
<span id="cb20-21"><a href="#cb20-21" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb20-22"><a href="#cb20-22" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> get_preprocessing(<span class="va">self</span>):</span>
<span id="cb20-23"><a href="#cb20-23" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Get optimized preprocessing pipeline."""</span></span>
<span id="cb20-24"><a href="#cb20-24" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> transforms.Compose([</span>
<span id="cb20-25"><a href="#cb20-25" aria-hidden="true" tabindex="-1"></a>            transforms.Resize(<span class="dv">256</span>, interpolation<span class="op">=</span>transforms.InterpolationMode.BILINEAR),</span>
<span id="cb20-26"><a href="#cb20-26" aria-hidden="true" tabindex="-1"></a>            transforms.CenterCrop(<span class="dv">224</span>),</span>
<span id="cb20-27"><a href="#cb20-27" aria-hidden="true" tabindex="-1"></a>            transforms.ToTensor(),</span>
<span id="cb20-28"><a href="#cb20-28" aria-hidden="true" tabindex="-1"></a>            transforms.Normalize([<span class="fl">0.485</span>, <span class="fl">0.456</span>, <span class="fl">0.406</span>], [<span class="fl">0.229</span>, <span class="fl">0.224</span>, <span class="fl">0.225</span>])</span>
<span id="cb20-29"><a href="#cb20-29" aria-hidden="true" tabindex="-1"></a>        ])</span>
<span id="cb20-30"><a href="#cb20-30" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb20-31"><a href="#cb20-31" aria-hidden="true" tabindex="-1"></a>    <span class="at">@torch.no_grad</span>()</span>
<span id="cb20-32"><a href="#cb20-32" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> predict(<span class="va">self</span>, image):</span>
<span id="cb20-33"><a href="#cb20-33" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Inference on a single image (PIL image or file path)."""</span></span>
<span id="cb20-34"><a href="#cb20-34" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="bu">isinstance</span>(image, <span class="bu">str</span>):</span>
<span id="cb20-35"><a href="#cb20-35" aria-hidden="true" tabindex="-1"></a>            <span class="im">from</span> PIL <span class="im">import</span> Image</span>
<span id="cb20-36"><a href="#cb20-36" aria-hidden="true" tabindex="-1"></a>            image <span class="op">=</span> Image.<span class="bu">open</span>(image).convert(<span class="st">'RGB'</span>)</span>
<span id="cb20-37"><a href="#cb20-37" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb20-38"><a href="#cb20-38" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Preprocess</span></span>
<span id="cb20-39"><a href="#cb20-39" aria-hidden="true" tabindex="-1"></a>        input_tensor <span class="op">=</span> <span class="va">self</span>.preprocess(image).unsqueeze(<span class="dv">0</span>).to(<span class="va">self</span>.device)</span>
<span id="cb20-40"><a href="#cb20-40" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb20-41"><a href="#cb20-41" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Inference</span></span>
<span id="cb20-42"><a href="#cb20-42" aria-hidden="true" tabindex="-1"></a>        output <span class="op">=</span> <span class="va">self</span>.model(input_tensor)</span>
<span id="cb20-43"><a href="#cb20-43" aria-hidden="true" tabindex="-1"></a>        probabilities <span class="op">=</span> F.softmax(output[<span class="dv">0</span>], dim<span class="op">=</span><span class="dv">0</span>)</span>
<span id="cb20-44"><a href="#cb20-44" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> probabilities.cpu().numpy()</span>
<span id="cb20-45"><a href="#cb20-45" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb20-46"><a href="#cb20-46" aria-hidden="true" tabindex="-1"></a>    <span class="at">@torch.no_grad</span>()</span>
<span id="cb20-47"><a href="#cb20-47" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> batch_predict(<span class="va">self</span>, images, batch_size<span class="op">=</span><span class="dv">32</span>):</span>
<span id="cb20-48"><a href="#cb20-48" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Batch inference for a list of PIL images (no gradient tracking)."""</span></span>
<span id="cb20-49"><a href="#cb20-49" aria-hidden="true" tabindex="-1"></a>        results <span class="op">=</span> []</span>
<span id="cb20-50"><a href="#cb20-50" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb20-51"><a href="#cb20-51" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">0</span>, <span class="bu">len</span>(images), batch_size):</span>
<span id="cb20-52"><a href="#cb20-52" aria-hidden="true" tabindex="-1"></a>            batch <span class="op">=</span> images[i:i<span class="op">+</span>batch_size]</span>
<span id="cb20-53"><a href="#cb20-53" aria-hidden="true" tabindex="-1"></a>            batch_tensor <span class="op">=</span> torch.stack([</span>
<span id="cb20-54"><a href="#cb20-54" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.preprocess(img) <span class="cf">for</span> img <span class="kw">in</span> batch</span>
<span id="cb20-55"><a href="#cb20-55" aria-hidden="true" tabindex="-1"></a>            ]).to(<span class="va">self</span>.device)</span>
<span id="cb20-56"><a href="#cb20-56" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb20-57"><a href="#cb20-57" aria-hidden="true" tabindex="-1"></a>            outputs <span class="op">=</span> <span class="va">self</span>.model(batch_tensor)</span>
<span id="cb20-58"><a href="#cb20-58" aria-hidden="true" tabindex="-1"></a>            probabilities <span class="op">=</span> F.softmax(outputs, dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb20-59"><a href="#cb20-59" aria-hidden="true" tabindex="-1"></a>            results.extend(probabilities.cpu().numpy())</span>
<span id="cb20-60"><a href="#cb20-60" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb20-61"><a href="#cb20-61" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> results</span></code></pre></div></div>
</div>
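<p><code>predict</code> returns a full probability vector (1,000 entries for an ImageNet classifier), and <code>batch_predict</code> walks the input list in fixed-size slices. The pure-Python sketch below uses hypothetical helpers (<code>softmax</code>, <code>top_k</code>, <code>chunks</code>) and made-up logits to show the post-processing a caller typically adds: picking the top-k classes, a numerically stable softmax (subtracting the max logit first) for runtimes that return raw logits, and the same slicing pattern <code>batch_predict</code> uses, where the final batch may be shorter.</p>

```python
import heapq
import math
from typing import Iterator, List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(probs: Sequence[float], k: int = 5) -> List[Tuple[int, float]]:
    """Return the k (class_index, probability) pairs with the highest probability."""
    return heapq.nlargest(k, enumerate(probs), key=lambda pair: pair[1])


def chunks(items: Sequence, batch_size: int) -> Iterator[Sequence]:
    """Yield consecutive slices of at most batch_size items (the last may be shorter)."""
    for i in range(0, len(items), batch_size):
        yield items[i:i + batch_size]


if __name__ == "__main__":
    logits = [1.2, 3.4, 0.1, 3.3, -2.0]  # hypothetical model output
    probs = softmax(logits)
    for idx, p in top_k(probs, k=3):
        print(f"class {idx}: {p:.3f}")
    # 10 images with batch_size=4 produce batches of 4, 4 and 2
    print([len(b) for b in chunks(list(range(10)), 4)])
```

<p>Mapping class indices to names requires a label list; for torchvision weights, one source is the <code>meta["categories"]</code> entry of the weights enum, assuming the model was loaded through that API.</p>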
</section>
</section>
<section id="performance-analysis" class="level2">
<h2 class="anchored" data-anchor-id="performance-analysis" id="performance-analysis">Performance Analysis</h2>
<section id="benchmarking-tools" class="level3">
<h3 class="anchored" data-anchor-id="benchmarking-tools" id="benchmarking-tools">Benchmarking Tools</h3>
<div id="benchmarking" class="cell" data-caption="Comprehensive benchmarking suite for MobileNet" data-execution_count="20">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb21"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb21-1"><a href="#cb21-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb21-2"><a href="#cb21-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb21-3"><a href="#cb21-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> contextlib <span class="im">import</span> contextmanager</span>
<span id="cb21-4"><a href="#cb21-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-5"><a href="#cb21-5" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MobileNetBenchmark:</span>
<span id="cb21-6"><a href="#cb21-6" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Comprehensive benchmarking suite for MobileNet."""</span></span>
<span id="cb21-7"><a href="#cb21-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb21-8"><a href="#cb21-8" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, model, device<span class="op">=</span><span class="st">'cpu'</span>):</span>
<span id="cb21-9"><a href="#cb21-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> model.to(device)</span>
<span id="cb21-10"><a href="#cb21-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.device <span class="op">=</span> device</span>
<span id="cb21-11"><a href="#cb21-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model.<span class="bu">eval</span>()</span>
<span id="cb21-12"><a href="#cb21-12" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb21-13"><a href="#cb21-13" aria-hidden="true" tabindex="-1"></a>    <span class="at">@contextmanager</span></span>
<span id="cb21-14"><a href="#cb21-14" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> timer(<span class="va">self</span>):</span>
<span id="cb21-15"><a href="#cb21-15" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Time a block with a monotonic clock; on CUDA, wait for queued kernels."""</span></span>
<span id="cb21-16"><a href="#cb21-16" aria-hidden="true" tabindex="-1"></a>        start <span class="op">=</span> time.perf_counter()</span>
<span id="cb21-17"><a href="#cb21-17" aria-hidden="true" tabindex="-1"></a>        <span class="cf">yield</span></span>
<span id="cb21-18"><a href="#cb21-18" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="va">self</span>.device <span class="op">==</span> <span class="st">'cuda'</span>:</span>
<span id="cb21-19"><a href="#cb21-19" aria-hidden="true" tabindex="-1"></a>            torch.cuda.synchronize()  <span class="co"># GPU kernels run asynchronously</span></span>
<span id="cb21-20"><a href="#cb21-20" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.last_time <span class="op">=</span> time.perf_counter() <span class="op">-</span> start</span>
<span id="cb21-21"><a href="#cb21-21" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb21-22"><a href="#cb21-22" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> benchmark_inference(<span class="va">self</span>, input_shape<span class="op">=</span>(<span class="dv">1</span>, <span class="dv">3</span>, <span class="dv">224</span>, <span class="dv">224</span>), num_runs<span class="op">=</span><span class="dv">100</span>, warmup<span class="op">=</span><span class="dv">10</span>):</span>
<span id="cb21-23"><a href="#cb21-23" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Benchmark inference speed."""</span></span>
<span id="cb21-24"><a href="#cb21-24" aria-hidden="true" tabindex="-1"></a>        dummy_input <span class="op">=</span> torch.randn(input_shape).to(<span class="va">self</span>.device)</span>
<span id="cb21-25"><a href="#cb21-25" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Warmup</span></span>
<span id="cb21-26"><a href="#cb21-26" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> torch.no_grad():</span>
<span id="cb21-27"><a href="#cb21-27" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(warmup):</span>
<span id="cb21-28"><a href="#cb21-28" aria-hidden="true" tabindex="-1"></a>                _ <span class="op">=</span> <span class="va">self</span>.model(dummy_input)</span>
<span id="cb21-29"><a href="#cb21-29" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb21-30"><a href="#cb21-30" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Benchmark</span></span>
<span id="cb21-31"><a href="#cb21-31" aria-hidden="true" tabindex="-1"></a>        times <span class="op">=</span> []</span>
<span id="cb21-32"><a href="#cb21-32" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> torch.no_grad():</span>
<span id="cb21-33"><a href="#cb21-33" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(num_runs):</span>
<span id="cb21-34"><a href="#cb21-34" aria-hidden="true" tabindex="-1"></a>                <span class="cf">with</span> <span class="va">self</span>.timer():</span>
<span id="cb21-35"><a href="#cb21-35" aria-hidden="true" tabindex="-1"></a>                    _ <span class="op">=</span> <span class="va">self</span>.model(dummy_input)</span>
<span id="cb21-36"><a href="#cb21-36" aria-hidden="true" tabindex="-1"></a>                times.append(<span class="va">self</span>.last_time)</span>
<span id="cb21-37"><a href="#cb21-37" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb21-38"><a href="#cb21-38" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {</span>
<span id="cb21-39"><a href="#cb21-39" aria-hidden="true" tabindex="-1"></a>            <span class="st">'mean_time'</span>: np.mean(times),</span>
<span id="cb21-40"><a href="#cb21-40" aria-hidden="true" tabindex="-1"></a>            <span class="st">'std_time'</span>: np.std(times),</span>
<span id="cb21-41"><a href="#cb21-41" aria-hidden="true" tabindex="-1"></a>            <span class="st">'min_time'</span>: np.<span class="bu">min</span>(times),</span>
<span id="cb21-42"><a href="#cb21-42" aria-hidden="true" tabindex="-1"></a>            <span class="st">'max_time'</span>: np.<span class="bu">max</span>(times),</span>
<span id="cb21-43"><a href="#cb21-43" aria-hidden="true" tabindex="-1"></a>            <span class="st">'fps'</span>: input_shape[<span class="dv">0</span>] <span class="op">/</span> np.mean(times)  <span class="co"># images per second</span></span>
<span id="cb21-44"><a href="#cb21-44" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb21-45"><a href="#cb21-45" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb21-46"><a href="#cb21-46" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> benchmark_memory(<span class="va">self</span>, input_shape<span class="op">=</span>(<span class="dv">1</span>, <span class="dv">3</span>, <span class="dv">224</span>, <span class="dv">224</span>)):</span>
<span id="cb21-47"><a href="#cb21-47" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Benchmark memory usage."""</span></span>
<span id="cb21-48"><a href="#cb21-48" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="va">self</span>.device <span class="op">==</span> <span class="st">'cuda'</span>:</span>
<span id="cb21-49"><a href="#cb21-49" aria-hidden="true" tabindex="-1"></a>            torch.cuda.empty_cache()</span>
<span id="cb21-50"><a href="#cb21-50" aria-hidden="true" tabindex="-1"></a>            torch.cuda.reset_peak_memory_stats()</span>
<span id="cb21-51"><a href="#cb21-51" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb21-52"><a href="#cb21-52" aria-hidden="true" tabindex="-1"></a>            dummy_input <span class="op">=</span> torch.randn(input_shape).to(<span class="va">self</span>.device)</span>
<span id="cb21-53"><a href="#cb21-53" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb21-54"><a href="#cb21-54" aria-hidden="true" tabindex="-1"></a>            <span class="cf">with</span> torch.no_grad():</span>
<span id="cb21-55"><a href="#cb21-55" aria-hidden="true" tabindex="-1"></a>                _ <span class="op">=</span> <span class="va">self</span>.model(dummy_input)</span>
<span id="cb21-56"><a href="#cb21-56" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb21-57"><a href="#cb21-57" aria-hidden="true" tabindex="-1"></a>            memory_stats <span class="op">=</span> {</span>
<span id="cb21-58"><a href="#cb21-58" aria-hidden="true" tabindex="-1"></a>                <span class="st">'peak_memory_mb'</span>: torch.cuda.max_memory_allocated() <span class="op">/</span> <span class="dv">1024</span><span class="op">**</span><span class="dv">2</span>,</span>
<span id="cb21-59"><a href="#cb21-59" aria-hidden="true" tabindex="-1"></a>                <span class="st">'current_memory_mb'</span>: torch.cuda.memory_allocated() <span class="op">/</span> <span class="dv">1024</span><span class="op">**</span><span class="dv">2</span></span>
<span id="cb21-60"><a href="#cb21-60" aria-hidden="true" tabindex="-1"></a>            }</span>
<span id="cb21-61"><a href="#cb21-61" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb21-62"><a href="#cb21-62" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> memory_stats</span>
<span id="cb21-63"><a href="#cb21-63" aria-hidden="true" tabindex="-1"></a>        <span class="cf">else</span>:</span>
<span id="cb21-64"><a href="#cb21-64" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> {<span class="st">'message'</span>: <span class="st">'Memory benchmarking only available for CUDA'</span>}</span>
<span id="cb21-65"><a href="#cb21-65" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb21-66"><a href="#cb21-66" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> profile_layers(<span class="va">self</span>, input_shape<span class="op">=</span>(<span class="dv">1</span>, <span class="dv">3</span>, <span class="dv">224</span>, <span class="dv">224</span>)):</span>
<span id="cb21-67"><a href="#cb21-67" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Profile operator-level time and memory for one forward pass."""</span></span>
<span id="cb21-68"><a href="#cb21-68" aria-hidden="true" tabindex="-1"></a>        dummy_input <span class="op">=</span> torch.randn(input_shape).to(<span class="va">self</span>.device)</span>
<span id="cb21-69"><a href="#cb21-69" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb21-70"><a href="#cb21-70" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> torch.profiler.profile(</span>
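<span id="cb21-71"><a href="#cb21-71" aria-hidden="true" tabindex="-1"></a>            activities<span class="op">=</span>[torch.profiler.ProfilerActivity.CPU] <span class="op">+</span> (</span>
<span id="cb21-72"><a href="#cb21-72" aria-hidden="true" tabindex="-1"></a>                [torch.profiler.ProfilerActivity.CUDA] <span class="cf">if</span> torch.cuda.is_available() <span class="cf">else</span> []),  <span class="co"># trace CUDA only when a GPU exists</span></span>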
<span id="cb21-73"><a href="#cb21-73" aria-hidden="true" tabindex="-1"></a>            record_shapes<span class="op">=</span><span class="va">True</span>,</span>
<span id="cb21-74"><a href="#cb21-74" aria-hidden="true" tabindex="-1"></a>            profile_memory<span class="op">=</span><span class="va">True</span>,</span>
<span id="cb21-75"><a href="#cb21-75" aria-hidden="true" tabindex="-1"></a>            with_stack<span class="op">=</span><span class="va">True</span></span>
<span id="cb21-76"><a href="#cb21-76" aria-hidden="true" tabindex="-1"></a>        ) <span class="im">as</span> prof:</span>
<span id="cb21-77"><a href="#cb21-77" aria-hidden="true" tabindex="-1"></a>            <span class="cf">with</span> torch.no_grad():</span>
<span id="cb21-78"><a href="#cb21-78" aria-hidden="true" tabindex="-1"></a>                _ <span class="op">=</span> <span class="va">self</span>.model(dummy_input)</span>
<span id="cb21-79"><a href="#cb21-79" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb21-80"><a href="#cb21-80" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> prof</span>
<span id="cb21-81"><a href="#cb21-81" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb21-82"><a href="#cb21-82" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> compare_models(<span class="va">self</span>, models_dict, input_shape<span class="op">=</span>(<span class="dv">1</span>, <span class="dv">3</span>, <span class="dv">224</span>, <span class="dv">224</span>)):</span>
<span id="cb21-83"><a href="#cb21-83" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Compare multiple model variants."""</span></span>
<span id="cb21-84"><a href="#cb21-84" aria-hidden="true" tabindex="-1"></a>        results <span class="op">=</span> {}</span>
<span id="cb21-85"><a href="#cb21-85" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb21-86"><a href="#cb21-86" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> name, model <span class="kw">in</span> models_dict.items():</span>
<span id="cb21-87"><a href="#cb21-87" aria-hidden="true" tabindex="-1"></a>            benchmark <span class="op">=</span> MobileNetBenchmark(model, <span class="va">self</span>.device)</span>
<span id="cb21-88"><a href="#cb21-88" aria-hidden="true" tabindex="-1"></a>            results[name] <span class="op">=</span> {</span>
<span id="cb21-89"><a href="#cb21-89" aria-hidden="true" tabindex="-1"></a>                <span class="st">'inference'</span>: benchmark.benchmark_inference(input_shape),</span>
<span id="cb21-90"><a href="#cb21-90" aria-hidden="true" tabindex="-1"></a>                <span class="st">'memory'</span>: benchmark.benchmark_memory(input_shape),</span>
<span id="cb21-91"><a href="#cb21-91" aria-hidden="true" tabindex="-1"></a>                <span class="st">'parameters'</span>: <span class="bu">sum</span>(p.numel() <span class="cf">for</span> p <span class="kw">in</span> model.parameters()),</span>
<span id="cb21-92"><a href="#cb21-92" aria-hidden="true" tabindex="-1"></a>                <span class="st">'model_size_mb'</span>: <span class="bu">sum</span>(p.numel() <span class="op">*</span> p.element_size() </span>
<span id="cb21-93"><a href="#cb21-93" aria-hidden="true" tabindex="-1"></a>                                   <span class="cf">for</span> p <span class="kw">in</span> model.parameters()) <span class="op">/</span> <span class="dv">1024</span><span class="op">**</span><span class="dv">2</span></span>
<span id="cb21-94"><a href="#cb21-94" aria-hidden="true" tabindex="-1"></a>            }</span>
<span id="cb21-95"><a href="#cb21-95" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb21-96"><a href="#cb21-96" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> results</span></code></pre></div></div>
</div>
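<p><code>benchmark_inference</code> above depends on a <code>self.timer()</code> context manager and reports mean, standard deviation, min and max. As a hedged, framework-agnostic sketch of the same measurement, the helper below times any zero-argument callable with <code>time.perf_counter</code>, discards warmup runs, and also reports median and 95th-percentile latency, since tail latency is usually what a real-time frame budget is sensitive to. The name <code>benchmark_callable</code> is hypothetical. To time a PyTorch model you could pass <code>lambda: model(dummy_input)</code> from inside a <code>torch.no_grad()</code> block; on CUDA you would also need <code>torch.cuda.synchronize()</code> inside the timed callable, because GPU kernels launch asynchronously.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import statistics
import time


def benchmark_callable(fn, num_runs=50, warmup=10):
    """Time fn() with time.perf_counter, discarding warmup runs.

    Returns latency statistics in milliseconds plus throughput in FPS.
    """
    # Warmup: let caches, allocators and lazy initialisation settle first.
    for _ in range(warmup):
        fn()

    num_runs = max(num_runs, 2)  # stdev and quantiles need two or more samples
    times_ms = []
    for _ in range(num_runs):
        start = time.perf_counter()
        fn()
        times_ms.append((time.perf_counter() - start) * 1000.0)

    mean_ms = statistics.mean(times_ms)
    # quantiles(n=100) returns 99 cut points; index 94 is the 95th percentile.
    # The inclusive method never extrapolates beyond the observed samples.
    p95_ms = statistics.quantiles(times_ms, n=100, method="inclusive")[94]
    return {
        "mean_ms": mean_ms,
        "std_ms": statistics.stdev(times_ms),
        "p50_ms": statistics.median(times_ms),
        "p95_ms": p95_ms,
        "fps": 1000.0 / mean_ms,
    }


stats = benchmark_callable(lambda: sum(range(10_000)), num_runs=20, warmup=2)
print(stats)
</code></pre></div>
<p>Reporting p95 next to the mean makes occasional slow frames visible, which a mean-only FPS figure can hide.</p>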
</section>
<section id="model-comparison" class="level3">
<h3 class="anchored" data-anchor-id="model-comparison" id="model-comparison">Model Comparison</h3>
<div id="model-comparison" class="cell" data-caption="Comparing different MobileNet variants" data-execution_count="21">
<div class="cell-output cell-output-stdout">
<pre><code>Model           Params (M)   Size (MB)  FPS      Peak Mem (MB)  
----------------------------------------------------------------------
MobileNet_1.0   4.23         16.14      41.1     N/A            
MobileNet_0.75  2.59         9.86       53.5     N/A            
MobileNet_0.5   1.33         5.08       77.5     N/A            
MobileNet_0.25  0.47         1.79       143.6    N/A            </code></pre>
</div>
</div>
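<p>The <code>Size (MB)</code> column follows directly from the <code>model_size_mb</code> line in <code>compare_models</code>: parameter count times bytes per parameter, divided by 1024². For float32 weights (4 bytes each), 4.23 × 10⁶ × 4 / 1024² ≈ 16.14, which matches the MobileNet_1.0 row (strictly these are MiB). The small sketch below reproduces that arithmetic and shows how the same count scales under lower-precision storage; like <code>compare_models</code>, it ignores buffers such as BatchNorm running statistics, and it also ignores file-format overhead, so treat it as an estimate.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">def model_size_mb(num_params, bytes_per_param=4):
    """Weight storage in MiB: parameters x bytes per parameter / 1024**2.

    bytes_per_param: 4 for float32, 2 for float16, 1 for int8.
    Ignores buffers (e.g. BatchNorm running stats) and file-format overhead.
    """
    return num_params * bytes_per_param / 1024**2


# Parameter count of MobileNet_1.0 as rounded in the table above (4.23 M).
params = 4.23e6
for dtype, nbytes in [("float32", 4), ("float16", 2), ("int8", 1)]:
    print(f"{dtype:8s} {model_size_mb(params, nbytes):6.2f} MB")
</code></pre></div>
<p>This prints 16.14, 8.07 and 4.03 MB, which is why storing weights in float16 or int8 is a common lever when app size is the binding constraint.</p>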
</section>
<section id="accuracy-vs-efficiency-analysis" class="level3">
<h3 class="anchored" data-anchor-id="accuracy-vs-efficiency-analysis" id="accuracy-vs-efficiency-analysis">Accuracy vs Efficiency Analysis</h3>
<div id="cell-accuracy-efficiency" class="cell" data-caption="Analyzing trade-offs between accuracy and efficiency" data-execution_count="22">
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://theja-vanka.github.io/blogs/posts/models/mobile-net/mobile-net-code/accuracy-efficiency-output-1.png" id="accuracy-efficiency" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
</div>
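<p>A plot like the one above is usually read as a Pareto trade-off: a variant is only worth shipping if no other variant is both at least as accurate and at least as fast. A minimal sketch of that filter over <code>(name, accuracy, latency_ms)</code> tuples follows; the function name and the numbers in the usage lines are hypothetical placeholders, not values read off the plot.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">def pareto_front(points):
    """Return names of models not dominated on (accuracy up, latency down).

    points: iterable of (name, accuracy, latency_ms) tuples.
    A model is dominated if another is at least as accurate and at least as
    fast, and strictly better on one of the two. Exact ties keep the first.
    """
    front = []
    best_acc = float("-inf")
    # Sweep from fastest to slowest; among equal latencies, most accurate first.
    for name, acc, latency_ms in sorted(points, key=lambda p: (p[2], -p[1])):
        # A slower model only earns a place if it is strictly more accurate
        # than every faster model seen so far.
        if acc > best_acc:
            front.append(name)
            best_acc = acc
    return front


# Hypothetical (accuracy, latency) values, for illustration only.
candidates = [
    ("model_a", 0.70, 10.0),
    ("model_b", 0.65, 12.0),
    ("model_c", 0.75, 20.0),
    ("model_d", 0.60, 5.0),
]
print(pareto_front(candidates))  # ['model_d', 'model_a', 'model_c']
</code></pre></div>
<p>Sorting by latency first turns the check into a single sweep: <code>model_b</code> is dropped because <code>model_a</code> is both faster and more accurate.</p>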
</section>
<section id="real-world-deployment-simulation" class="level3">
<h3 class="anchored" data-anchor-id="real-world-deployment-simulation" id="real-world-deployment-simulation">Real-world Deployment Simulation</h3>
<div id="deployment-simulation" class="cell" data-caption="Simulating real-world deployment scenarios" data-execution_count="23">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb23"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb23-1"><a href="#cb23-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Real-world deployment simulation</span></span>
<span id="cb23-2"><a href="#cb23-2" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> DeploymentSimulator:</span>
<span id="cb23-3"><a href="#cb23-3" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Simulate real-world deployment scenarios."""</span></span>
<span id="cb23-4"><a href="#cb23-4" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb23-5"><a href="#cb23-5" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, model):</span>
<span id="cb23-6"><a href="#cb23-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> model</span>
<span id="cb23-7"><a href="#cb23-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb23-8"><a href="#cb23-8" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> simulate_mobile_inference(<span class="va">self</span>, num_images<span class="op">=</span><span class="dv">1000</span>, target_fps<span class="op">=</span><span class="dv">30</span>):</span>
<span id="cb23-9"><a href="#cb23-9" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Simulate mobile device inference."""</span></span>
<span id="cb23-10"><a href="#cb23-10" aria-hidden="true" tabindex="-1"></a>        device <span class="op">=</span> <span class="st">'cpu'</span>  <span class="co"># CPU-only run as a rough proxy for on-device inference</span></span>
<span id="cb23-11"><a href="#cb23-11" aria-hidden="true" tabindex="-1"></a>        model <span class="op">=</span> <span class="va">self</span>.model.to(device)</span>
<span id="cb23-12"><a href="#cb23-12" aria-hidden="true" tabindex="-1"></a>        model.<span class="bu">eval</span>()</span>
<span id="cb23-13"><a href="#cb23-13" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb23-14"><a href="#cb23-14" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Simulate various image sizes</span></span>
<span id="cb23-15"><a href="#cb23-15" aria-hidden="true" tabindex="-1"></a>        image_sizes <span class="op">=</span> [(<span class="dv">224</span>, <span class="dv">224</span>), (<span class="dv">320</span>, <span class="dv">320</span>), (<span class="dv">416</span>, <span class="dv">416</span>)]</span>
<span id="cb23-16"><a href="#cb23-16" aria-hidden="true" tabindex="-1"></a>        results <span class="op">=</span> {}</span>
<span id="cb23-17"><a href="#cb23-17" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb23-18"><a href="#cb23-18" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> size <span class="kw">in</span> image_sizes:</span>
<span id="cb23-19"><a href="#cb23-19" aria-hidden="true" tabindex="-1"></a>            input_tensor <span class="op">=</span> torch.randn(<span class="dv">1</span>, <span class="dv">3</span>, <span class="op">*</span>size)</span>
<span id="cb23-20"><a href="#cb23-20" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb23-21"><a href="#cb23-21" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Measure inference time</span></span>
<span id="cb23-22"><a href="#cb23-22" aria-hidden="true" tabindex="-1"></a>            times <span class="op">=</span> []</span>
<span id="cb23-23"><a href="#cb23-23" aria-hidden="true" tabindex="-1"></a>            <span class="cf">with</span> torch.no_grad():</span>
<span id="cb23-24"><a href="#cb23-24" aria-hidden="true" tabindex="-1"></a>                <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(num_images):  <span class="co"># First 10 runs are warmup, the rest are measured</span></span>
<span id="cb23-25"><a href="#cb23-25" aria-hidden="true" tabindex="-1"></a>                    start <span class="op">=</span> time.perf_counter()  <span class="co"># Monotonic, high-resolution clock</span></span>
<span id="cb23-26"><a href="#cb23-26" aria-hidden="true" tabindex="-1"></a>                    _ <span class="op">=</span> model(input_tensor)</span>
<span id="cb23-27"><a href="#cb23-27" aria-hidden="true" tabindex="-1"></a>                    times.append(time.perf_counter() <span class="op">-</span> start)</span>
<span id="cb23-28"><a href="#cb23-28" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb23-29"><a href="#cb23-29" aria-hidden="true" tabindex="-1"></a>            avg_time <span class="op">=</span> np.mean(times[<span class="dv">10</span>:])  <span class="co"># Skip first 10 for warmup</span></span>
<span id="cb23-30"><a href="#cb23-30" aria-hidden="true" tabindex="-1"></a>            fps <span class="op">=</span> <span class="fl">1.0</span> <span class="op">/</span> avg_time</span>
<span id="cb23-31"><a href="#cb23-31" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb23-32"><a href="#cb23-32" aria-hidden="true" tabindex="-1"></a>            results[<span class="ss">f'</span><span class="sc">{</span>size[<span class="dv">0</span>]<span class="sc">}</span><span class="ss">x</span><span class="sc">{</span>size[<span class="dv">1</span>]<span class="sc">}</span><span class="ss">'</span>] <span class="op">=</span> {</span>
<span id="cb23-33"><a href="#cb23-33" aria-hidden="true" tabindex="-1"></a>                <span class="st">'fps'</span>: fps,</span>
<span id="cb23-34"><a href="#cb23-34" aria-hidden="true" tabindex="-1"></a>                <span class="st">'meets_target'</span>: fps <span class="op">&gt;=</span> target_fps,</span>
<span id="cb23-35"><a href="#cb23-35" aria-hidden="true" tabindex="-1"></a>                <span class="st">'latency_ms'</span>: avg_time <span class="op">*</span> <span class="dv">1000</span></span>
<span id="cb23-36"><a href="#cb23-36" aria-hidden="true" tabindex="-1"></a>            }</span>
<span id="cb23-37"><a href="#cb23-37" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb23-38"><a href="#cb23-38" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> results</span>
<span id="cb23-39"><a href="#cb23-39" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb23-40"><a href="#cb23-40" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> battery_consumption_estimate(<span class="va">self</span>, inference_time_ms, device_type<span class="op">=</span><span class="st">'mobile'</span>):</span>
<span id="cb23-41"><a href="#cb23-41" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Estimate battery consumption per inference."""</span></span>
<span id="cb23-42"><a href="#cb23-42" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb23-43"><a href="#cb23-43" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Rough estimates based on device type</span></span>
<span id="cb23-44"><a href="#cb23-44" aria-hidden="true" tabindex="-1"></a>        power_consumption <span class="op">=</span> {</span>
<span id="cb23-45"><a href="#cb23-45" aria-hidden="true" tabindex="-1"></a>            <span class="st">'mobile'</span>: <span class="fl">2.0</span>,  <span class="co"># Watts during inference</span></span>
<span id="cb23-46"><a href="#cb23-46" aria-hidden="true" tabindex="-1"></a>            <span class="st">'edge'</span>: <span class="fl">5.0</span>,    <span class="co"># Edge devices</span></span>
<span id="cb23-47"><a href="#cb23-47" aria-hidden="true" tabindex="-1"></a>            <span class="st">'embedded'</span>: <span class="fl">0.5</span>  <span class="co"># Low-power embedded</span></span>
<span id="cb23-48"><a href="#cb23-48" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb23-49"><a href="#cb23-49" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb23-50"><a href="#cb23-50" aria-hidden="true" tabindex="-1"></a>        power_w <span class="op">=</span> power_consumption.get(device_type, <span class="fl">2.0</span>)</span>
<span id="cb23-51"><a href="#cb23-51" aria-hidden="true" tabindex="-1"></a>        energy_per_inference <span class="op">=</span> (inference_time_ms <span class="op">/</span> <span class="dv">1000</span>) <span class="op">*</span> power_w  <span class="co"># Joules</span></span>
<span id="cb23-52"><a href="#cb23-52" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb23-53"><a href="#cb23-53" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Convert to more meaningful metrics</span></span>
<span id="cb23-54"><a href="#cb23-54" aria-hidden="true" tabindex="-1"></a>        battery_capacity_wh <span class="op">=</span> <span class="dv">15</span>  <span class="co"># Typical smartphone battery ~15 Wh</span></span>
<span id="cb23-55"><a href="#cb23-55" aria-hidden="true" tabindex="-1"></a>        inferences_per_battery <span class="op">=</span> (battery_capacity_wh <span class="op">*</span> <span class="dv">3600</span>) <span class="op">/</span> energy_per_inference</span>
<span id="cb23-56"><a href="#cb23-56" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb23-57"><a href="#cb23-57" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {</span>
<span id="cb23-58"><a href="#cb23-58" aria-hidden="true" tabindex="-1"></a>            <span class="st">'energy_per_inference_j'</span>: energy_per_inference,</span>
<span id="cb23-59"><a href="#cb23-59" aria-hidden="true" tabindex="-1"></a>            <span class="st">'estimated_inferences_per_battery'</span>: <span class="bu">int</span>(inferences_per_battery),</span>
<span id="cb23-60"><a href="#cb23-60" aria-hidden="true" tabindex="-1"></a>            <span class="st">'power_consumption_w'</span>: power_w</span>
<span id="cb23-61"><a href="#cb23-61" aria-hidden="true" tabindex="-1"></a>        }</span></code></pre></div></div>
</div>
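<p>The units in <code>battery_consumption_estimate</code> are easy to check by hand: energy per inference is latency in seconds times power in watts, and a 15 Wh battery stores 15 × 3600 = 54,000 J. At 25 ms and the assumed 2 W, one inference costs 0.05 J, so a full charge covers about 1.08 million inferences. The sketch below reproduces that arithmetic and adds battery life at a fixed frame rate; the function name <code>energy_budget</code> is hypothetical, and the zero-idle-power model is a simplifying assumption rather than anything measured.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">def energy_budget(latency_ms, power_w=2.0, battery_wh=15.0, fps=None):
    """Back-of-envelope energy figures in the spirit of battery_consumption_estimate.

    Assumes the device draws a constant power_w while inference runs and
    nothing otherwise (no idle, screen, or camera power).
    """
    energy_j = (latency_ms / 1000.0) * power_w  # joules = seconds x watts
    battery_j = battery_wh * 3600.0             # 1 Wh = 3600 J
    result = {
        "energy_per_inference_j": energy_j,
        "inferences_per_battery": battery_j / energy_j,
    }
    if fps is not None:
        # Average draw at a fixed frame rate; it cannot exceed power_w, because
        # past a 100 % duty cycle the target fps is simply not reached.
        avg_power_w = min(energy_j * fps, power_w)
        result["hours_at_fps"] = battery_wh / avg_power_w
    return result


print(energy_budget(25.0))          # ~0.05 J per inference, ~1.08 million inferences
print(energy_budget(25.0, fps=30))  # adds hours_at_fps of about 10
</code></pre></div>
<p>At 30 FPS a 25 ms model is busy 75 % of the time, so the average draw is 1.5 W and 15 Wh lasts about 10 hours under these assumptions; a model too slow to reach 30 FPS is capped at the full 2 W. Real devices add idle, display and camera power, so these figures are upper bounds on battery life.</p>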
</section>
<section id="performance-metrics-dashboard" class="level3">
<h3 class="anchored" data-anchor-id="performance-metrics-dashboard" id="performance-metrics-dashboard">Performance Metrics Dashboard</h3>
<div id="cell-performance-dashboard" class="cell" data-caption="Comprehensive performance analysis dashboard" data-execution_count="24">
<div class="cell-output cell-output-stdout">
<pre><code>🔬 **Running Comprehensive MobileNet Analysis...**

🖥️  Using device: cpu</code></pre>
</div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://theja-vanka.github.io/blogs/posts/models/mobile-net/mobile-net-code/performance-dashboard-output-2.png" id="performance-dashboard" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
<div class="cell-output cell-output-stdout">
<pre><code>
🎯 **Deployment Recommendations:**

🚀 **Fastest Inference:** MobileNet_0.25 (135.3 FPS)
   ✅ Best for: Real-time applications, video processing
   📱 Recommended: High-end mobile devices, edge servers

💾 **Most Memory Efficient:** MobileNet_1.0 (inf MB peak)
   ✅ Best for: Memory-constrained devices
   📱 Recommended: Budget smartphones, IoT devices

📦 **Smallest Model:** MobileNet_0.25 (1.8 MB)
   ✅ Best for: App size constraints, OTA updates
   📱 Recommended: Mobile apps with size limits

⚡ **Fewest Parameters:** MobileNet_0.25 (0.5M params)
   ✅ Best for: Ultra-low power devices
   📱 Recommended: Microcontrollers, embedded systems

🏆 **Best Overall Balance:** MobileNet_0.25
   💡 Efficiency Score: 160.509
   ✅ Best for: General-purpose mobile AI applications
   📱 Recommended: Production deployments</code></pre>
</div>
</div>
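<p>One line in that output needs care: on a CPU-only run, <code>benchmark_memory</code> returns only a message rather than a peak-memory figure (hence the N/A column in the comparison table), so the "Most Memory Efficient" pick reports <code>inf</code> MB and is not based on any real measurement. The dashboard code is not shown here, but a selector that ignores missing or non-finite values avoids recommending a model on the strength of absent data. A minimal sketch, using a hypothetical <code>pick_best</code> helper over a flat <code>{name: {metric: value}}</code> dictionary:</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import math


def pick_best(results, metric, lower_is_better=True):
    """Return (name, value) of the best model for `metric`, or None.

    Entries whose value is missing, non-numeric, NaN or infinite are skipped,
    so an unavailable measurement can never "win".
    """
    valid = {}
    for name, stats in results.items():
        value = stats.get(metric)
        if isinstance(value, (int, float)) and math.isfinite(value):
            valid[name] = value
    if not valid:
        return None
    choose = min if lower_is_better else max
    best = choose(valid, key=valid.get)
    return best, valid[best]


# fps and size_mb come from the comparison table above; peak memory was N/A
# on CPU and is represented here as inf / None purely for illustration.
results = {
    "MobileNet_1.0": {"fps": 41.1, "size_mb": 16.14, "peak_mem_mb": float("inf")},
    "MobileNet_0.25": {"fps": 143.6, "size_mb": 1.79, "peak_mem_mb": None},
}
print(pick_best(results, "peak_mem_mb"))                  # None
print(pick_best(results, "fps", lower_is_better=False))  # ('MobileNet_0.25', 143.6)
</code></pre></div>
<p>Returning <code>None</code> lets the dashboard print "not measured" instead of naming a winner.</p>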
</section>
</section>
<section id="advanced-topics" class="level2">
<h2 class="anchored" data-anchor-id="advanced-topics" id="advanced-topics">Advanced Topics</h2>
<section id="neural-architecture-search-nas-for-mobilenet" class="level3">
<h3 class="anchored" data-anchor-id="neural-architecture-search-nas-for-mobilenet" id="neural-architecture-search-nas-for-mobilenet">Neural Architecture Search (NAS) for MobileNet</h3>
<div id="nas-mobilenet" class="cell" data-caption="Neural Architecture Search for MobileNet optimization" data-execution_count="25">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb26"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb26-1"><a href="#cb26-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MobileNetSearchSpace:</span>
<span id="cb26-2"><a href="#cb26-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Define search space for MobileNet architecture optimization."""</span></span>
<span id="cb26-3"><a href="#cb26-3" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb26-4"><a href="#cb26-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>):</span>
<span id="cb26-5"><a href="#cb26-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.width_multipliers <span class="op">=</span> [<span class="fl">0.25</span>, <span class="fl">0.35</span>, <span class="fl">0.5</span>, <span class="fl">0.75</span>, <span class="fl">1.0</span>, <span class="fl">1.4</span>]</span>
<span id="cb26-6"><a href="#cb26-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.depth_multipliers <span class="op">=</span> [<span class="fl">0.5</span>, <span class="fl">0.75</span>, <span class="fl">1.0</span>, <span class="fl">1.25</span>]</span>
<span id="cb26-7"><a href="#cb26-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.kernel_sizes <span class="op">=</span> [<span class="dv">3</span>, <span class="dv">5</span>, <span class="dv">7</span>]</span>
<span id="cb26-8"><a href="#cb26-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.activation_functions <span class="op">=</span> [<span class="st">'relu6'</span>, <span class="st">'swish'</span>, <span class="st">'hard_swish'</span>]</span>
<span id="cb26-9"><a href="#cb26-9" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb26-10"><a href="#cb26-10" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> sample_architecture(<span class="va">self</span>):</span>
<span id="cb26-11"><a href="#cb26-11" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Sample a random architecture from search space."""</span></span>
<span id="cb26-12"><a href="#cb26-12" aria-hidden="true" tabindex="-1"></a>        <span class="im">import</span> random</span>
<span id="cb26-13"><a href="#cb26-13" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb26-14"><a href="#cb26-14" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {</span>
<span id="cb26-15"><a href="#cb26-15" aria-hidden="true" tabindex="-1"></a>            <span class="st">'width_mult'</span>: random.choice(<span class="va">self</span>.width_multipliers),</span>
<span id="cb26-16"><a href="#cb26-16" aria-hidden="true" tabindex="-1"></a>            <span class="st">'depth_mult'</span>: random.choice(<span class="va">self</span>.depth_multipliers),</span>
<span id="cb26-17"><a href="#cb26-17" aria-hidden="true" tabindex="-1"></a>            <span class="st">'kernel_size'</span>: random.choice(<span class="va">self</span>.kernel_sizes),</span>
<span id="cb26-18"><a href="#cb26-18" aria-hidden="true" tabindex="-1"></a>            <span class="st">'activation'</span>: random.choice(<span class="va">self</span>.activation_functions)</span>
<span id="cb26-19"><a href="#cb26-19" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb26-20"><a href="#cb26-20" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb26-21"><a href="#cb26-21" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> evaluate_architecture(<span class="va">self</span>, arch_config, train_loader, val_loader):</span>
<span id="cb26-22"><a href="#cb26-22" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Evaluate a sampled architecture."""</span></span>
<span id="cb26-23"><a href="#cb26-23" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb26-24"><a href="#cb26-24" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Create model with sampled configuration</span></span>
<span id="cb26-25"><a href="#cb26-25" aria-hidden="true" tabindex="-1"></a>        model <span class="op">=</span> <span class="va">self</span>.create_model_from_config(arch_config)</span>
<span id="cb26-26"><a href="#cb26-26" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb26-27"><a href="#cb26-27" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Quick training (few epochs for NAS efficiency)</span></span>
<span id="cb26-28"><a href="#cb26-28" aria-hidden="true" tabindex="-1"></a>        trainer <span class="op">=</span> MobileNetTrainer(model)</span>
<span id="cb26-29"><a href="#cb26-29" aria-hidden="true" tabindex="-1"></a>        trainer.train(train_loader, val_loader, epochs<span class="op">=</span><span class="dv">5</span>)</span>
<span id="cb26-30"><a href="#cb26-30" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb26-31"><a href="#cb26-31" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Calculate efficiency metrics</span></span>
<span id="cb26-32"><a href="#cb26-32" aria-hidden="true" tabindex="-1"></a>        benchmark <span class="op">=</span> MobileNetBenchmark(model)</span>
<span id="cb26-33"><a href="#cb26-33" aria-hidden="true" tabindex="-1"></a>        perf_stats <span class="op">=</span> benchmark.benchmark_inference()</span>
<span id="cb26-34"><a href="#cb26-34" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb26-35"><a href="#cb26-35" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Return multi-objective score</span></span>
<span id="cb26-36"><a href="#cb26-36" aria-hidden="true" tabindex="-1"></a>        accuracy <span class="op">=</span> trainer.history[<span class="st">'val_acc'</span>][<span class="op">-</span><span class="dv">1</span>]</span>
<span id="cb26-37"><a href="#cb26-37" aria-hidden="true" tabindex="-1"></a>        latency <span class="op">=</span> perf_stats[<span class="st">'mean_time'</span>]</span>
<span id="cb26-38"><a href="#cb26-38" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb26-39"><a href="#cb26-39" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Scalarized trade-off score (a single ratio, not a true Pareto ranking)</span></span>
<span id="cb26-40"><a href="#cb26-40" aria-hidden="true" tabindex="-1"></a>        score <span class="op">=</span> accuracy <span class="op">/</span> (latency <span class="op">*</span> <span class="dv">1000</span>)  <span class="co"># Accuracy per ms</span></span>
<span id="cb26-41"><a href="#cb26-41" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb26-42"><a href="#cb26-42" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {</span>
<span id="cb26-43"><a href="#cb26-43" aria-hidden="true" tabindex="-1"></a>            <span class="st">'score'</span>: score,</span>
<span id="cb26-44"><a href="#cb26-44" aria-hidden="true" tabindex="-1"></a>            <span class="st">'accuracy'</span>: accuracy,</span>
<span id="cb26-45"><a href="#cb26-45" aria-hidden="true" tabindex="-1"></a>            <span class="st">'latency'</span>: latency,</span>
<span id="cb26-46"><a href="#cb26-46" aria-hidden="true" tabindex="-1"></a>            <span class="st">'config'</span>: arch_config</span>
<span id="cb26-47"><a href="#cb26-47" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb26-48"><a href="#cb26-48" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb26-49"><a href="#cb26-49" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> create_model_from_config(<span class="va">self</span>, config):</span>
<span id="cb26-50"><a href="#cb26-50" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Create MobileNet model from configuration."""</span></span>
<span id="cb26-51"><a href="#cb26-51" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Simplified: only width_mult is used; depth_mult, kernel_size and activation are ignored</span></span>
<span id="cb26-52"><a href="#cb26-52" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> mobilenet_v1(</span>
<span id="cb26-53"><a href="#cb26-53" aria-hidden="true" tabindex="-1"></a>            width_mult<span class="op">=</span>config[<span class="st">'width_mult'</span>],</span>
<span id="cb26-54"><a href="#cb26-54" aria-hidden="true" tabindex="-1"></a>            num_classes<span class="op">=</span><span class="dv">1000</span></span>
<span id="cb26-55"><a href="#cb26-55" aria-hidden="true" tabindex="-1"></a>        )</span></code></pre></div></div>
</div>
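<p>The search space above is small enough to count: 6 width × 4 depth × 3 kernel × 3 activation choices gives 216 configurations. Each <code>evaluate_architecture</code> call still trains for several epochs, though, which is why random search under a fixed budget is a common NAS baseline. A hedged sketch of that loop follows, using a seeded <code>random.Random</code> for reproducibility and a caller-supplied scoring function; the helper names and the toy objective are hypothetical stand-ins for the train-and-benchmark step.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import math
import random

# Mirrors the choices in MobileNetSearchSpace above.
SEARCH_SPACE = {
    "width_mult": [0.25, 0.35, 0.5, 0.75, 1.0, 1.4],
    "depth_mult": [0.5, 0.75, 1.0, 1.25],
    "kernel_size": [3, 5, 7],
    "activation": ["relu6", "swish", "hard_swish"],
}


def search_space_size(space):
    """Number of distinct configurations: the product of the choice counts."""
    return math.prod(len(choices) for choices in space.values())


def random_search(space, score_fn, budget=20, seed=0):
    """Draw `budget` random configs and return the best (score, config), or None.

    score_fn(config) -> float, higher is better. Repeated samples are skipped
    so an expensive evaluation is never run twice for the same architecture.
    """
    rng = random.Random(seed)  # seeded for reproducible searches
    seen = set()
    best = None
    for _ in range(budget):
        config = {key: rng.choice(choices) for key, choices in space.items()}
        signature = tuple(sorted(config.items()))
        if signature in seen:
            continue
        seen.add(signature)
        score = score_fn(config)
        if best is None or score > best[0]:
            best = (score, config)
    return best


# Hypothetical toy objective standing in for "train briefly, then benchmark":
# it simply prefers width 0.75 and depth 1.0 (best possible score is 1.0).
def toy_score(cfg):
    return 1.0 - abs(cfg["width_mult"] - 0.75) - abs(cfg["depth_mult"] - 1.0)


print(search_space_size(SEARCH_SPACE))  # 216
score, config = random_search(SEARCH_SPACE, toy_score, budget=500, seed=0)
print(score, config)
</code></pre></div>
<p>In a real search, <code>score_fn</code> would wrap something like <code>evaluate_architecture</code> and return its <code>'score'</code>, so the budget directly caps how many short training runs are paid for.</p>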
</section>
<section id="sec-knowledge-distillation" class="level3">
<h3 class="anchored" data-anchor-id="sec-knowledge-distillation" id="sec-knowledge-distillation">Knowledge Distillation for MobileNet</h3>
<div id="knowledge-distillation" class="cell" data-caption="Knowledge distillation to improve MobileNet performance" data-execution_count="26">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb27"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb27-1"><a href="#cb27-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Knowledge Distillation for MobileNet</span></span>
<span id="cb27-2"><a href="#cb27-2" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> KnowledgeDistillation:</span>
<span id="cb27-3"><a href="#cb27-3" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Knowledge distillation to improve MobileNet performance."""</span></span>
<span id="cb27-4"><a href="#cb27-4" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb27-5"><a href="#cb27-5" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, teacher_model, student_model, temperature<span class="op">=</span><span class="fl">4.0</span>, alpha<span class="op">=</span><span class="fl">0.3</span>):</span>
<span id="cb27-6"><a href="#cb27-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.teacher <span class="op">=</span> teacher_model</span>
<span id="cb27-7"><a href="#cb27-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.student <span class="op">=</span> student_model</span>
<span id="cb27-8"><a href="#cb27-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.temperature <span class="op">=</span> temperature</span>
<span id="cb27-9"><a href="#cb27-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.alpha <span class="op">=</span> alpha  <span class="co"># Weight for distillation loss</span></span>
<span id="cb27-10"><a href="#cb27-10" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb27-11"><a href="#cb27-11" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Freeze teacher model</span></span>
<span id="cb27-12"><a href="#cb27-12" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> param <span class="kw">in</span> <span class="va">self</span>.teacher.parameters():</span>
<span id="cb27-13"><a href="#cb27-13" aria-hidden="true" tabindex="-1"></a>            param.requires_grad <span class="op">=</span> <span class="va">False</span></span>
<span id="cb27-14"><a href="#cb27-14" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.teacher.<span class="bu">eval</span>()</span>
<span id="cb27-15"><a href="#cb27-15" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb27-16"><a href="#cb27-16" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> distillation_loss(<span class="va">self</span>, student_outputs, teacher_outputs, labels):</span>
<span id="cb27-17"><a href="#cb27-17" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Calculate knowledge distillation loss."""</span></span>
<span id="cb27-18"><a href="#cb27-18" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb27-19"><a href="#cb27-19" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Soft targets from teacher</span></span>
<span id="cb27-20"><a href="#cb27-20" aria-hidden="true" tabindex="-1"></a>        teacher_probs <span class="op">=</span> F.softmax(teacher_outputs <span class="op">/</span> <span class="va">self</span>.temperature, dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb27-21"><a href="#cb27-21" aria-hidden="true" tabindex="-1"></a>        student_log_probs <span class="op">=</span> F.log_softmax(student_outputs <span class="op">/</span> <span class="va">self</span>.temperature, dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb27-22"><a href="#cb27-22" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb27-23"><a href="#cb27-23" aria-hidden="true" tabindex="-1"></a>        <span class="co"># KL divergence loss</span></span>
<span id="cb27-24"><a href="#cb27-24" aria-hidden="true" tabindex="-1"></a>        distillation_loss <span class="op">=</span> F.kl_div(</span>
<span id="cb27-25"><a href="#cb27-25" aria-hidden="true" tabindex="-1"></a>            student_log_probs, teacher_probs, reduction<span class="op">=</span><span class="st">'batchmean'</span></span>
<span id="cb27-26"><a href="#cb27-26" aria-hidden="true" tabindex="-1"></a>        ) <span class="op">*</span> (<span class="va">self</span>.temperature <span class="op">**</span> <span class="dv">2</span>)</span>
<span id="cb27-27"><a href="#cb27-27" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb27-28"><a href="#cb27-28" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Standard cross-entropy loss</span></span>
<span id="cb27-29"><a href="#cb27-29" aria-hidden="true" tabindex="-1"></a>        student_loss <span class="op">=</span> F.cross_entropy(student_outputs, labels)</span>
<span id="cb27-30"><a href="#cb27-30" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb27-31"><a href="#cb27-31" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Combined loss</span></span>
<span id="cb27-32"><a href="#cb27-32" aria-hidden="true" tabindex="-1"></a>        total_loss <span class="op">=</span> (</span>
<span id="cb27-33"><a href="#cb27-33" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.alpha <span class="op">*</span> distillation_loss <span class="op">+</span> </span>
<span id="cb27-34"><a href="#cb27-34" aria-hidden="true" tabindex="-1"></a>            (<span class="dv">1</span> <span class="op">-</span> <span class="va">self</span>.alpha) <span class="op">*</span> student_loss</span>
<span id="cb27-35"><a href="#cb27-35" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb27-36"><a href="#cb27-36" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb27-37"><a href="#cb27-37" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> total_loss</span>
<span id="cb27-38"><a href="#cb27-38" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb27-39"><a href="#cb27-39" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> train_with_distillation(<span class="va">self</span>, train_loader, val_loader, epochs<span class="op">=</span><span class="dv">10</span>):</span>
<span id="cb27-40"><a href="#cb27-40" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Train student model with knowledge distillation."""</span></span>
<span id="cb27-41"><a href="#cb27-41" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb27-42"><a href="#cb27-42" aria-hidden="true" tabindex="-1"></a>        optimizer <span class="op">=</span> optim.Adam(<span class="va">self</span>.student.parameters(), lr<span class="op">=</span><span class="fl">0.001</span>)</span>
<span id="cb27-43"><a href="#cb27-43" aria-hidden="true" tabindex="-1"></a>        scheduler <span class="op">=</span> optim.lr_scheduler.StepLR(optimizer, step_size<span class="op">=</span><span class="dv">7</span>, gamma<span class="op">=</span><span class="fl">0.1</span>)</span>
<span id="cb27-44"><a href="#cb27-44" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb27-45"><a href="#cb27-45" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(epochs):</span>
<span id="cb27-46"><a href="#cb27-46" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.student.train()</span>
<span id="cb27-47"><a href="#cb27-47" aria-hidden="true" tabindex="-1"></a>            running_loss <span class="op">=</span> <span class="fl">0.0</span></span>
<span id="cb27-48"><a href="#cb27-48" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb27-49"><a href="#cb27-49" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> inputs, labels <span class="kw">in</span> train_loader:</span>
<span id="cb27-50"><a href="#cb27-50" aria-hidden="true" tabindex="-1"></a>                optimizer.zero_grad()</span>
<span id="cb27-51"><a href="#cb27-51" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb27-52"><a href="#cb27-52" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Get predictions from both models</span></span>
<span id="cb27-53"><a href="#cb27-53" aria-hidden="true" tabindex="-1"></a>                <span class="cf">with</span> torch.no_grad():</span>
<span id="cb27-54"><a href="#cb27-54" aria-hidden="true" tabindex="-1"></a>                    teacher_outputs <span class="op">=</span> <span class="va">self</span>.teacher(inputs)</span>
<span id="cb27-55"><a href="#cb27-55" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb27-56"><a href="#cb27-56" aria-hidden="true" tabindex="-1"></a>                student_outputs <span class="op">=</span> <span class="va">self</span>.student(inputs)</span>
<span id="cb27-57"><a href="#cb27-57" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb27-58"><a href="#cb27-58" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Calculate distillation loss</span></span>
<span id="cb27-59"><a href="#cb27-59" aria-hidden="true" tabindex="-1"></a>                loss <span class="op">=</span> <span class="va">self</span>.distillation_loss(student_outputs, teacher_outputs, labels)</span>
<span id="cb27-60"><a href="#cb27-60" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb27-61"><a href="#cb27-61" aria-hidden="true" tabindex="-1"></a>                loss.backward()</span>
<span id="cb27-62"><a href="#cb27-62" aria-hidden="true" tabindex="-1"></a>                optimizer.step()</span>
<span id="cb27-63"><a href="#cb27-63" aria-hidden="true" tabindex="-1"></a>                running_loss <span class="op">+=</span> loss.item()</span>
<span id="cb27-64"><a href="#cb27-64" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb27-65"><a href="#cb27-65" aria-hidden="true" tabindex="-1"></a>            scheduler.step()</span>
<span id="cb27-66"><a href="#cb27-66" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb27-67"><a href="#cb27-67" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Validation</span></span>
<span id="cb27-68"><a href="#cb27-68" aria-hidden="true" tabindex="-1"></a>            val_acc <span class="op">=</span> <span class="va">self</span>.validate(val_loader)</span>
<span id="cb27-69"><a href="#cb27-69" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f'Epoch </span><span class="sc">{</span>epoch<span class="op">+</span><span class="dv">1</span><span class="sc">}</span><span class="ss">/</span><span class="sc">{</span>epochs<span class="sc">}</span><span class="ss">, Loss: </span><span class="sc">{</span>running_loss<span class="op">/</span><span class="bu">len</span>(train_loader)<span class="sc">:.4f}</span><span class="ss">, Val Acc: </span><span class="sc">{</span>val_acc<span class="sc">:.2f}</span><span class="ss">%'</span>)</span>
<span id="cb27-70"><a href="#cb27-70" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb27-71"><a href="#cb27-71" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> validate(<span class="va">self</span>, val_loader):</span>
<span id="cb27-72"><a href="#cb27-72" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Validate student model."""</span></span>
<span id="cb27-73"><a href="#cb27-73" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.student.<span class="bu">eval</span>()</span>
<span id="cb27-74"><a href="#cb27-74" aria-hidden="true" tabindex="-1"></a>        correct <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb27-75"><a href="#cb27-75" aria-hidden="true" tabindex="-1"></a>        total <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb27-76"><a href="#cb27-76" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb27-77"><a href="#cb27-77" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> torch.no_grad():</span>
<span id="cb27-78"><a href="#cb27-78" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> inputs, labels <span class="kw">in</span> val_loader:</span>
<span id="cb27-79"><a href="#cb27-79" aria-hidden="true" tabindex="-1"></a>                outputs <span class="op">=</span> <span class="va">self</span>.student(inputs)</span>
<span id="cb27-80"><a href="#cb27-80" aria-hidden="true" tabindex="-1"></a>                _, predicted <span class="op">=</span> torch.<span class="bu">max</span>(outputs, <span class="dv">1</span>)</span>
<span id="cb27-81"><a href="#cb27-81" aria-hidden="true" tabindex="-1"></a>                total <span class="op">+=</span> labels.size(<span class="dv">0</span>)</span>
<span id="cb27-82"><a href="#cb27-82" aria-hidden="true" tabindex="-1"></a>                correct <span class="op">+=</span> (predicted <span class="op">==</span> labels).<span class="bu">sum</span>().item()</span>
<span id="cb27-83"><a href="#cb27-83" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb27-84"><a href="#cb27-84" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="dv">100</span> <span class="op">*</span> correct <span class="op">/</span> total</span></code></pre></div></div>
</div>
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>This guide has covered MobileNet from the core idea of depthwise separable convolutions through implementation details, optimization techniques, and deployment strategies. Together, these topics provide a practical foundation for building efficient mobile AI applications.</p>
<section id="key-takeaways" class="level3">
<h3 class="anchored" data-anchor-id="key-takeaways" id="key-takeaways">Key Takeaways</h3>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>🎯 <strong>Essential Insights</strong>
</div>
</div>
<div class="callout-body-container callout-body">
<p><strong>Architectural Innovation:</strong></p>
<ul>
<li>Depthwise separable convolutions reduce computation by 8-9× with minimal accuracy loss</li>
<li>Width multipliers provide flexible trade-offs between accuracy and efficiency</li>
<li>The architecture scales gracefully across different hardware constraints</li>
</ul>
<p><strong>Implementation Best Practices:</strong></p>
<ul>
<li>Always profile on target hardware before deployment</li>
<li>Use appropriate data augmentation for robust training</li>
<li>Consider knowledge distillation to improve the accuracy of the smaller student model</li>
<li>Apply quantization and pruning strategically based on deployment requirements</li>
</ul>
</div>
</div>
</section>
<section id="performance-summary" class="level3">
<h3 class="anchored" data-anchor-id="performance-summary" id="performance-summary">Performance Summary</h3>
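<p>The following configurations are suggested starting points. Accuracy figures are approximate ImageNet top-1 values, and FPS figures are indicative only: actual throughput depends heavily on the device, runtime, and input resolution.</p>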
<table class="caption-top table">
<colgroup>
<col style="width: 16%">
<col style="width: 22%">
<col style="width: 31%">
<col style="width: 28%">
</colgroup>
<thead>
<tr class="header">
<th><strong>Use Case</strong></th>
<th><strong>Configuration</strong></th>
<th><strong>Expected Performance</strong></th>
<th><strong>Deployment Target</strong></th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><strong>Real-time Video</strong></td>
<td>MobileNet-V2 1.0×</td>
<td>30+ FPS, 72% accuracy</td>
<td>High-end mobile devices</td>
</tr>
<tr class="even">
<td><strong>General Mobile AI</strong></td>
<td>MobileNet-V1 0.75×</td>
<td>45+ FPS, 68% accuracy</td>
<td>Mid-range smartphones</td>
</tr>
<tr class="odd">
<td><strong>Edge Computing</strong></td>
<td>MobileNet-V1 0.5×</td>
<td>60+ FPS, 64% accuracy</td>
<td>Edge servers, IoT hubs</td>
</tr>
<tr class="even">
<td><strong>Embedded Systems</strong></td>
<td>MobileNet-V1 0.25×</td>
<td>80+ FPS, 51% accuracy</td>
<td>Microcontrollers, sensors</td>
</tr>
</tbody>
</table>
</section>
<section id="deployment-recommendations" class="level3">
<h3 class="anchored" data-anchor-id="deployment-recommendations" id="deployment-recommendations">Deployment Recommendations</h3>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>🚀 <strong>Production Deployment Checklist</strong>
</div>
</div>
<div class="callout-body-container callout-body">
<p><strong>Pre-deployment:</strong></p>
<ul class="task-list">
<li><label><input type="checkbox">Benchmark on actual target hardware</label></li>
<li><label><input type="checkbox">Validate accuracy on representative test data</label></li>
<li><label><input type="checkbox">Measure memory usage under realistic conditions</label></li>
<li><label><input type="checkbox">Test battery consumption (for mobile devices)</label></li>
<li><label><input type="checkbox">Verify model export/conversion pipeline</label></li>
</ul>
<p><strong>Optimization Pipeline:</strong></p>
<ul class="task-list">
<li><label><input type="checkbox">Apply appropriate quantization (dynamic/static)</label></li>
<li><label><input type="checkbox">Consider structured pruning for further compression</label></li>
<li><label><input type="checkbox">Export to platform-specific formats (ONNX, TFLite, CoreML)</label></li>
<li><label><input type="checkbox">Implement efficient preprocessing pipelines</label></li>
<li><label><input type="checkbox">Add monitoring and performance tracking</label></li>
</ul>
<p><strong>Platform Integration:</strong></p>
<ul class="task-list">
<li><label><input type="checkbox">Handle model loading and initialization efficiently</label></li>
<li><label><input type="checkbox">Implement proper error handling and fallbacks</label></li>
<li><label><input type="checkbox">Use background threads for inference</label></li>
<li><label><input type="checkbox">Cache models and avoid repeated loading</label></li>
<li><label><input type="checkbox">Plan for model updates and versioning</label></li>
</ul>
</div>
</div>
</section>
<section id="common-pitfalls-and-solutions" class="level3">
<h3 class="anchored" data-anchor-id="common-pitfalls-and-solutions" id="common-pitfalls-and-solutions">Common Pitfalls and Solutions</h3>
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Warning</span>⚠️ <strong>Avoid These Mistakes</strong>
</div>
</div>
<div class="callout-body-container callout-body">
<p><strong>Performance Issues:</strong></p>
<ul>
<li><strong>Problem</strong>: Model runs slower on device than benchmarks suggest</li>
<li><strong>Solution</strong>: Always test with realistic input pipelines and preprocessing</li>
</ul>
<p><strong>Memory Problems:</strong></p>
<ul>
<li><strong>Problem</strong>: Out-of-memory errors during inference</li>
<li><strong>Solution</strong>: Monitor peak memory usage, not just model size</li>
</ul>
<p><strong>Accuracy Degradation:</strong></p>
<ul>
<li><strong>Problem</strong>: Significant accuracy drop after optimization</li>
<li><strong>Solution</strong>: Use quantization-aware training and gradual pruning</li>
</ul>
<p><strong>Integration Challenges:</strong></p>
<ul>
<li><strong>Problem</strong>: Model format incompatibility with deployment platform</li>
<li><strong>Solution</strong>: Test export pipeline early and validate outputs</li>
</ul>
</div>
</div>
</section>
<section id="future-directions" class="level3">
<h3 class="anchored" data-anchor-id="future-directions" id="future-directions">Future Directions</h3>
<p>The field of efficient neural networks continues to evolve rapidly:</p>
<p><strong>Next-Generation Architectures:</strong></p>
<ul>
<li><strong>EfficientNet</strong> and <strong>EfficientNetV2</strong>: Compound scaling that balances network depth, width, and input resolution</li>
<li><strong>MobileViT</strong>: Combining CNNs with Vision Transformers for mobile deployment</li>
<li><strong>Once-for-All Networks</strong>: Single networks supporting multiple deployment scenarios</li>
</ul>
<p><strong>Advanced Optimization Techniques:</strong></p>
<ul>
<li><strong>Neural Architecture Search (NAS)</strong>: Automated architecture optimization</li>
<li><strong>Differentiable Architecture Search</strong>: End-to-end learnable architectures</li>
<li><strong>Hardware-aware NAS</strong>: Optimizing specifically for target hardware</li>
</ul>
<p><strong>Deployment Innovations:</strong></p>
<ul>
<li><strong>Edge AI Accelerators</strong>: Custom silicon for mobile AI (Apple Neural Engine, Google Edge TPU)</li>
<li><strong>Federated Learning</strong>: Training models across distributed mobile devices</li>
<li><strong>Model Compression</strong>: Techniques beyond pruning and quantization, such as low-rank factorization and weight sharing</li>
</ul>
</section>
<section id="resources-and-further-reading" class="level3">
<h3 class="anchored" data-anchor-id="resources-and-further-reading" id="resources-and-further-reading">Resources and Further Reading</h3>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>📚 <strong>Additional Resources</strong>
</div>
</div>
<div class="callout-body-container callout-body">
<p><strong>Essential Papers:</strong></p>
<ul>
<li><a href="https://arxiv.org/abs/1704.04861">MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications</a></li>
<li><a href="https://arxiv.org/abs/1801.04381">MobileNetV2: Inverted Residuals and Linear Bottlenecks</a></li>
<li><a href="https://arxiv.org/abs/1905.02244">Searching for MobileNetV3</a></li>
</ul>
<p><strong>Implementation Resources:</strong></p>
<ul>
<li><a href="https://pytorch.org/mobile/home/">PyTorch Mobile Documentation</a></li>
<li><a href="https://www.tensorflow.org/lite">TensorFlow Lite Guide</a></li>
<li><a href="https://onnxruntime.ai/docs/tutorials/mobile/">ONNX Runtime Mobile</a></li>
</ul>
<p><strong>Community and Support:</strong></p>
<ul>
<li><a href="https://discuss.pytorch.org/c/mobile/19">PyTorch Forums - Mobile</a></li>
<li><a href="https://www.tensorflow.org/community">TensorFlow Community</a></li>
<li><a href="https://paperswithcode.com/task/mobile-ai">Papers With Code - Mobile AI</a></li>
</ul>
</div>
</div>
</section>
<section id="sec-final-thoughts" class="level3">
<h3 class="anchored" data-anchor-id="sec-final-thoughts" id="sec-final-thoughts">Final Thoughts</h3>
<p>MobileNet represents a paradigm shift in how we approach deep learning for resource-constrained environments. The techniques and principles covered in this guide extend beyond MobileNet itself – they form the foundation for understanding and implementing efficient AI systems across a wide range of applications.</p>
<p>As mobile and edge AI continues to grow, the ability to design, implement, and deploy efficient neural networks becomes increasingly valuable. Whether you’re building the next generation of mobile apps, edge computing solutions, or embedded AI systems, the concepts and code in this guide provide a solid foundation for success.</p>
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Important</span>🎯 <strong>Remember</strong>
</div>
</div>
<div class="callout-body-container callout-body">
<p>The best model is not necessarily the most accurate one, but the one that best serves your users within the constraints of your deployment environment. Always optimize for the complete user experience, not just benchmark metrics.</p>
</div>
</div>
</section>
</section>
<section id="references" class="level2">
<h2 class="anchored" data-anchor-id="references" id="references">References</h2>
<ul>
<li>Howard, A. G., et al.&nbsp;(2017). MobileNets: Efficient convolutional neural networks for mobile vision applications. arXiv preprint arXiv:1704.04861.</li>
<li>Sandler, M., et al.&nbsp;(2018). MobileNetV2: Inverted residuals and linear bottlenecks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (pp.&nbsp;4510-4520).</li>
<li>Howard, A., et al.&nbsp;(2019). Searching for MobileNetV3. In Proceedings of the IEEE/CVF International Conference on Computer Vision (pp.&nbsp;1314-1324).</li>
</ul>



</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[MobileNet: Efficient Neural Networks for Mobile Vision Applications]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/models/mobile-net/mobile-net-summary/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/models/mobile-net/mobile-net-summary/</guid>
      <pubDate>Sat, 19 Jul 2025 00:00:00 GMT</pubDate>
      
      <category>research</category>
      <category>intermediate</category>
      <content:encoded><![CDATA[






<section id="mobilenet-efficient-neural-networks-for-mobile-vision-applications" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/models/mobile-net/mobile-net-summary/mobnet.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>MobileNet represents a revolutionary approach to deep learning architecture design, specifically optimized for mobile and embedded vision applications. Introduced by Google researchers in 2017, MobileNet addresses one of the most pressing challenges in deploying deep neural networks: achieving high accuracy while maintaining computational efficiency on resource-constrained devices.</p>
<p>The traditional approach to neural network design focused primarily on accuracy, often at the expense of computational complexity. Networks like VGGNet, ResNet, and Inception achieved remarkable performance on image classification tasks but required substantial computational resources, making them impractical for many mobile deployments. MobileNet changed this paradigm by building its architecture almost entirely from depthwise separable convolutions, a factorization that dramatically reduces the number of parameters and computational operations while preserving much of the representational power of traditional convolutional neural networks.</p>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Key Innovation
</div>
</div>
<div class="callout-body-container callout-body">
<p>MobileNet’s primary contribution is a streamlined architecture built almost entirely from <strong>depthwise separable convolutions</strong>, which use roughly 8-9x less computation than standard 3×3 convolutions with only a small loss in accuracy.</p>
</div>
</div>
</section>
<section id="core-innovation-depthwise-separable-convolutions" class="level2">
<h2 class="anchored" data-anchor-id="core-innovation-depthwise-separable-convolutions" id="core-innovation-depthwise-separable-convolutions">Core Innovation: Depthwise Separable Convolutions</h2>
<section id="understanding-standard-convolutions" class="level3">
<h3 class="anchored" data-anchor-id="understanding-standard-convolutions" id="understanding-standard-convolutions">Understanding Standard Convolutions</h3>
<p>To appreciate MobileNet’s innovation, it’s essential to understand how standard convolutions work. A standard convolutional layer applies a set of filters across the input feature map. For an input feature map of size <span class="math inline">\(D_F \times D_F \times M\)</span> (height, width, channels), <span class="math inline">\(N\)</span> output channels, and kernel size <span class="math inline">\(D_K \times D_K\)</span>, and assuming stride 1 with padding so that the output is also <span class="math inline">\(D_F \times D_F\)</span> spatially, a standard convolution requires:</p>
<ul>
<li><strong>Parameters</strong>: <span class="math inline">\(D_K \times D_K \times M \times N\)</span></li>
<li><strong>Computational cost</strong>: <span class="math inline">\(D_K \times D_K \times M \times N \times D_F \times D_F\)</span></li>
</ul>
<p>This computational cost grows rapidly with the number of input and output channels, making standard convolutions expensive for mobile applications.</p>
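<p>As a quick sanity check on these formulas, the hypothetical helper below (a sketch, not code from the MobileNet paper) evaluates them for a single layer. It assumes stride 1, size-preserving padding, and no bias terms, and counts one multiply-add per weight per output position.</p>

```python
def standard_conv_cost(d_k, m, n, d_f):
    """Parameters and mult-adds of a standard D_K x D_K convolution.

    Assumes stride 1, padding that preserves the D_F x D_F spatial size,
    and no bias terms.
    """
    params = d_k * d_k * m * n
    # Each weight is used once for every one of the D_F * D_F output positions
    mult_adds = params * d_f * d_f
    return params, mult_adds


# Example: a 3x3 layer mapping 256 channels to 256 channels on a 14x14 map
params, mult_adds = standard_conv_cost(d_k=3, m=256, n=256, d_f=14)
print(f"{params:,} parameters, {mult_adds:,} mult-adds")
# 589,824 parameters, 115,605,504 mult-adds
```

<p>Over a hundred million mult-adds for one mid-sized layer illustrates why the channel product <span class="math inline">\(M \times N\)</span> is the term worth attacking.</p>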
</section>
<section id="depthwise-separable-convolutions" class="level3">
<h3 class="anchored" data-anchor-id="depthwise-separable-convolutions" id="depthwise-separable-convolutions">Depthwise Separable Convolutions</h3>
<p>MobileNet’s key innovation lies in factorizing standard convolutions into two separate operations:</p>
<ol type="1">
<li><strong>Depthwise Convolution</strong>: Applies a single filter to each input channel separately</li>
<li><strong>Pointwise Convolution</strong>: Uses 1×1 convolutions to combine the outputs of the depthwise convolution</li>
</ol>
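<p>To make the two steps concrete, here is a deliberately naive pure-Python sketch, not an optimized or official implementation. It assumes stride 1, zero "same" padding, no bias, and nested lists indexed as [channel][row][col]; the function names are illustrative. It also counts multiplications so the totals can be compared with the cost formulas in this section.</p>

```python
import random


def depthwise_conv(x, dw_kernels):
    """Depthwise step: one D_K x D_K filter per input channel, no channel mixing.

    x is indexed [channel][row][col]; assumes stride 1 and zero "same" padding.
    Returns the output and the number of multiplications counted.
    """
    m, d_f, d_k = len(x), len(x[0]), len(dw_kernels[0])
    pad = d_k // 2
    out = [[[0.0] * d_f for _ in range(d_f)] for _ in range(m)]
    mults = 0
    for c in range(m):
        for i in range(d_f):
            for j in range(d_f):
                acc = 0.0
                for u in range(d_k):
                    for v in range(d_k):
                        r, s = i + u - pad, j + v - pad
                        if r in range(d_f) and s in range(d_f):  # zero padding outside
                            acc += x[c][r][s] * dw_kernels[c][u][v]
                out[c][i][j] = acc
                mults += d_k * d_k  # count every tap, as the cost formula does
    return out, mults


def pointwise_conv(x, pw_weights):
    """Pointwise step: a 1x1 convolution mixing M input channels into N outputs."""
    n, m, d_f = len(pw_weights), len(x), len(x[0])
    out = [[[0.0] * d_f for _ in range(d_f)] for _ in range(n)]
    mults = 0
    for o in range(n):
        for i in range(d_f):
            for j in range(d_f):
                out[o][i][j] = sum(pw_weights[o][c] * x[c][i][j] for c in range(m))
                mults += m
    return out, mults


random.seed(0)
M, N, D_F, D_K = 3, 4, 5, 3
x = [[[random.random() for _ in range(D_F)] for _ in range(D_F)] for _ in range(M)]
dw = [[[random.random() for _ in range(D_K)] for _ in range(D_K)] for _ in range(M)]
pw = [[random.random() for _ in range(M)] for _ in range(N)]

y_dw, dw_mults = depthwise_conv(x, dw)
y, pw_mults = pointwise_conv(y_dw, pw)
print(len(y), len(y[0]), len(y[0][0]))  # 4 5 5
print(dw_mults, pw_mults)               # 675 300
```

<p>The counts match <span class="math inline">\(D_K^2 \times M \times D_F^2 = 675\)</span> and <span class="math inline">\(M \times N \times D_F^2 = 300\)</span>. A real implementation would use a framework primitive instead; in PyTorch, for example, the depthwise step is an <code>nn.Conv2d</code> whose <code>groups</code> argument equals the number of input channels.</p>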
<section id="depthwise-convolution" class="level4">
<h4 class="anchored" data-anchor-id="depthwise-convolution">Depthwise Convolution</h4>
<p>The depthwise convolution applies a single convolutional filter to each input channel. For <span class="math inline">\(M\)</span> input channels, this requires <span class="math inline">\(M\)</span> filters of size <span class="math inline">\(D_K \times D_K \times 1\)</span>. Its costs are:</p>
<ul>
<li><strong>Parameters</strong>: <span class="math inline">\(D_K \times D_K \times M\)</span></li>
<li><strong>Computational cost</strong>: <span class="math inline">\(D_K \times D_K \times M \times D_F \times D_F\)</span></li>
</ul>
</section>
<section id="pointwise-convolution" class="level4">
<h4 class="anchored" data-anchor-id="pointwise-convolution">Pointwise Convolution</h4>
<p>The pointwise convolution uses 1×1 convolutions to create new features by computing linear combinations of the input channels. This step requires:</p>
<ul>
<li><strong>Parameters</strong>: <span class="math inline">\(M \times N\)</span></li>
<li><strong>Computational cost</strong>: <span class="math inline">\(M \times N \times D_F \times D_F\)</span></li>
</ul>
</section>
</section>
<section id="efficiency-gains" class="level3">
<h3 class="anchored" data-anchor-id="efficiency-gains" id="efficiency-gains">Efficiency Gains</h3>
<p>The total cost of depthwise separable convolution is the sum of depthwise and pointwise convolutions:</p>
<ul>
<li><strong>Total parameters</strong>: <span class="math inline">\(D_K \times D_K \times M + M \times N\)</span></li>
<li><strong>Total computational cost</strong>: <span class="math inline">\((D_K \times D_K \times M \times D_F \times D_F) + (M \times N \times D_F \times D_F)\)</span></li>
</ul>
<p>Dividing the separable cost by the standard cost gives the fraction of computation that remains:</p>
<p><span class="math display">\[
\frac{\text{Cost}_{\text{separable}}}{\text{Cost}_{\text{standard}}} = \frac{D_K^2 \times M \times D_F^2 + M \times N \times D_F^2}{D_K^2 \times M \times N \times D_F^2} = \frac{1}{N} + \frac{1}{D_K^2}
\]</span></p>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Efficiency Example
</div>
</div>
<div class="callout-body-container callout-body">
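<p>For typical values (<span class="math inline">\(D_K = 3\)</span>, <span class="math inline">\(N = 256\)</span>), the ratio is <span class="math inline">\(1/256 + 1/9 \approx 0.115\)</span>, so the separable layer needs roughly <strong>8.7x fewer</strong> mult-adds. Because the <span class="math inline">\(1/D_K^2\)</span> term dominates for wide layers, the savings approach 9x as <span class="math inline">\(N\)</span> grows, which is where the commonly quoted <strong>8-9x reduction</strong> with minimal accuracy loss comes from.</p>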
</div>
</div>
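<p>The short script below, a sketch rather than code from the paper, evaluates both costs exactly with Python's <code>fractions</code> module for several layer widths. It assumes <span class="math inline">\(D_K = 3\)</span>, <span class="math inline">\(M = N\)</span>, and a 14×14 feature map (the ratio does not depend on <span class="math inline">\(M\)</span> or <span class="math inline">\(D_F\)</span>), and checks each result against <span class="math inline">\(1/N + 1/D_K^2\)</span>.</p>

```python
from fractions import Fraction


def separable_vs_standard(d_k, m, n, d_f):
    """Return (standard cost, separable cost, exact ratio) in mult-adds."""
    standard = d_k * d_k * m * n * d_f * d_f
    depthwise = d_k * d_k * m * d_f * d_f
    pointwise = m * n * d_f * d_f
    separable = depthwise + pointwise
    return standard, separable, Fraction(separable, standard)


for n in (64, 128, 256, 512, 1024):
    std, sep, ratio = separable_vs_standard(d_k=3, m=n, n=n, d_f=14)
    # The ratio depends only on N and D_K, matching 1/N + 1/D_K^2
    assert ratio == Fraction(1, n) + Fraction(1, 9)
    print(f"N={n:5d}: separable/standard = {float(ratio):.4f} "
          f"({std / sep:.2f}x fewer mult-adds)")
```

<p>Under these assumptions the savings grow from about 7.9x at <span class="math inline">\(N = 64\)</span> to about 8.9x at <span class="math inline">\(N = 1024\)</span>, approaching the 9x limit set by the <span class="math inline">\(1/D_K^2\)</span> term.</p>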
</section>
</section>
<section id="mobilenet-architecture" class="level2">
<h2 class="anchored" data-anchor-id="mobilenet-architecture" id="mobilenet-architecture">MobileNet Architecture</h2>
<section id="overall-structure" class="level3">
<h3 class="anchored" data-anchor-id="overall-structure" id="overall-structure">Overall Structure</h3>
<p>MobileNet follows a straightforward architecture based on depthwise separable convolutions. The network begins with a standard 3×3 convolution followed by 13 depthwise separable convolution layers. Both the depthwise and the pointwise convolution in each of these layers are followed by batch normalization and a ReLU nonlinearity.</p>
<p>The architecture progressively reduces spatial resolution while increasing the number of channels, following the general pattern established by successful CNN architectures. The network concludes with global average pooling, a fully connected layer, and softmax activation for classification.</p>
</section>
<section id="width-and-resolution-multipliers" class="level3">
<h3 class="anchored" data-anchor-id="width-and-resolution-multipliers" id="width-and-resolution-multipliers">Width and Resolution Multipliers</h3>
<p>MobileNet introduces two hyperparameters to provide additional control over the trade-off between accuracy and efficiency:</p>
<section id="width-multiplier-α" class="level4">
<h4 class="anchored" data-anchor-id="width-multiplier-α">Width Multiplier (α)</h4>
<p>The width multiplier <span class="math inline">\(\alpha \in (0,1]\)</span> uniformly thins every layer. With width multiplier <span class="math inline">\(\alpha\)</span>, the number of input channels <span class="math inline">\(M\)</span> becomes <span class="math inline">\(\alpha M\)</span> and the number of output channels <span class="math inline">\(N\)</span> becomes <span class="math inline">\(\alpha N\)</span>. Because the dominant pointwise layers cost <span class="math inline">\(\alpha M \times \alpha N \times D_F \times D_F\)</span>, both the computational cost and the number of parameters shrink roughly by a factor of <span class="math inline">\(\alpha^2\)</span>.</p>
<p>Common values for <span class="math inline">\(\alpha\)</span> include:</p>
<ul>
<li>1.0 (full model)</li>
<li>0.75</li>
<li>0.5</li>
<li>0.25</li>
</ul>
</section>
<section id="resolution-multiplier-ρ" class="level4">
<h4 class="anchored" data-anchor-id="resolution-multiplier-ρ">Resolution Multiplier (ρ)</h4>
<p>The resolution multiplier <span class="math inline">\(\rho \in (0,1]\)</span> reduces the input image resolution, and with it the size of every internal feature map: <span class="math inline">\(D_F \times D_F\)</span> becomes <span class="math inline">\(\rho D_F \times \rho D_F\)</span>. This reduces computational cost by approximately <span class="math inline">\(\rho^2\)</span> while leaving the number of parameters unchanged.</p>
<p>Typical values for <span class="math inline">\(\rho\)</span> correspond to common input resolutions: 224, 192, 160, and 128 pixels.</p>
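<p>To see how the layer pattern and the two multipliers interact, the following Python sketch counts multiply-adds for the whole network. The layer list is our transcription of the architecture described above (a stride-2 3×3 convolution with 32 filters followed by 13 depthwise separable layers). We assume "same" padding and a stride of 1 in the last depthwise layer, so the final feature map stays 7×7 at 224 px. As is usual for multiply-add counts, batch normalization, activations, and pooling are ignored. The function and constant names are hypothetical.</p>

```python
# MobileNet v1 body at width multiplier 1.0, as (layer kind, stride, output channels).
# "conv" = standard 3x3 convolution; "dws" = 3x3 depthwise + 1x1 pointwise.
MOBILENET_V1_LAYERS = (
    [("conv", 2, 32),
     ("dws", 1, 64),
     ("dws", 2, 128), ("dws", 1, 128),
     ("dws", 2, 256), ("dws", 1, 256),
     ("dws", 2, 512)]
    + [("dws", 1, 512)] * 5
    + [("dws", 2, 1024), ("dws", 1, 1024)]  # last stride assumed to be 1
)


def mobilenet_mult_adds(alpha=1.0, resolution=224, num_classes=1000, k=3):
    """Count multiply-adds for width multiplier alpha at a given input size."""
    size, channels, total = resolution, 3, 0
    for kind, stride, out_channels in MOBILENET_V1_LAYERS:
        out_channels = int(out_channels * alpha)
        size = -(-size // stride)  # ceil division ("same" padding)
        if kind == "conv":
            total += k * k * channels * out_channels * size * size
        else:
            total += k * k * channels * size * size         # depthwise
            total += channels * out_channels * size * size  # pointwise 1x1
        channels = out_channels
    # Global average pooling is not counted; the classifier is one dense layer.
    return total + channels * num_classes


for alpha, res in [(1.0, 224), (0.75, 224), (1.0, 192)]:
    print(f"alpha={alpha}, {res} px: {mobilenet_mult_adds(alpha, res) / 1e6:.0f}M mult-adds")
```

<p>With these assumptions, the count for <span class="math inline">\(\alpha = 1.0\)</span> at 224 px comes out at about 569M, matching the figure quoted for MobileNet-224 in the performance section. The <span class="math inline">\(\alpha = 0.75\)</span> model needs about 57% of that, slightly more than <span class="math inline">\(\alpha^2 \approx 0.56\)</span>, because the first convolution, the depthwise layers, and the classifier scale only linearly with <span class="math inline">\(\alpha\)</span>. Dropping to 192 px scales every convolution by <span class="math inline">\((192/224)^2 \approx 0.73\)</span> and leaves the parameter count unchanged.</p>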
</section>
</section>
</section>
<section id="training-and-implementation-details" class="level2">
<h2 class="anchored" data-anchor-id="training-and-implementation-details" id="training-and-implementation-details">Training and Implementation Details</h2>
<section id="training-procedure" class="level3">
<h3 class="anchored" data-anchor-id="training-procedure" id="training-procedure">Training Procedure</h3>
<p>MobileNet models are typically trained using standard techniques for image classification:</p>
<table class="caption-top table">
<colgroup>
<col style="width: 57%">
<col style="width: 42%">
</colgroup>
<thead>
<tr class="header">
<th>Parameter</th>
<th>Value</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><strong>Optimizer</strong></td>
<td>RMSprop with decay 0.9 and momentum 0.9</td>
</tr>
<tr class="even">
<td><strong>Learning Rate</strong></td>
<td>Initial rate of 0.045 with exponential decay every two epochs</td>
</tr>
<tr class="odd">
<td><strong>Weight Decay</strong></td>
<td>L2 regularization with weight decay of 4e-5</td>
</tr>
<tr class="even">
<td><strong>Batch Size</strong></td>
<td>Typically 96-128 depending on available memory</td>
</tr>
<tr class="odd">
<td><strong>Data Augmentation</strong></td>
<td>Random crops, horizontal flips, and color jittering</td>
</tr>
</tbody>
</table>
</section>
<section id="batch-normalization-and-activation" class="level3">
<h3 class="anchored" data-anchor-id="batch-normalization-and-activation" id="batch-normalization-and-activation">Batch Normalization and Activation</h3>
<p>Each convolutional layer in MobileNet is followed by batch normalization and a ReLU6 activation, <span class="math inline">\(\min(\max(x, 0), 6)\)</span>. ReLU6 is often preferred over standard ReLU because its bounded output range is more robust under low-precision arithmetic. This suits mobile deployment, where quantization is common.</p>
</section>
<section id="dropout-and-regularization" class="level3">
<h3 class="anchored" data-anchor-id="dropout-and-regularization" id="dropout-and-regularization">Dropout and Regularization</h3>
<p>MobileNet employs several regularization techniques:</p>
<ul>
<li>Batch normalization after each convolutional layer</li>
<li>Dropout with rate 0.001 before the final classification layer</li>
<li>L2 weight decay as mentioned above</li>
</ul>
</section>
</section>
<section id="performance-analysis" class="level2">
<h2 class="anchored" data-anchor-id="performance-analysis" id="performance-analysis">Performance Analysis</h2>
<section id="accuracy-vs.-efficiency-trade-offs" class="level3">
<h3 class="anchored" data-anchor-id="accuracy-vs.-efficiency-trade-offs" id="accuracy-vs.-efficiency-trade-offs">Accuracy vs.&nbsp;Efficiency Trade-offs</h3>
<p>MobileNet achieves remarkable efficiency gains while maintaining competitive accuracy. On ImageNet classification:</p>
<ul>
<li><strong>MobileNet-224</strong> (α=1.0): 70.6% top-1 accuracy with 569M multiply-adds</li>
<li><strong>VGG-16</strong>: 71.5% top-1 accuracy with 15.3B multiply-adds</li>
</ul>
<p>This represents a <strong>27x reduction</strong> in computational cost for a top-1 accuracy drop of only 0.9 percentage points.</p>
</section>
<section id="comparison-with-other-architectures" class="level3">
<h3 class="anchored" data-anchor-id="comparison-with-other-architectures" id="comparison-with-other-architectures">Comparison with Other Architectures</h3>
<p>MobileNet’s efficiency becomes particularly apparent when compared to other popular architectures:</p>
<div id="tbl-performance" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-tbl figure">
<figcaption class="quarto-float-caption-top quarto-float-caption quarto-float-tbl" id="tbl-performance-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Table&nbsp;1: Model Performance Comparison
</figcaption>
<div aria-describedby="tbl-performance-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<table class="caption-top table">
<thead>
<tr class="header">
<th>Model</th>
<th>Top-1 Accuracy</th>
<th>Million Parameters</th>
<th>Million Multiply-Adds</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>MobileNet</td>
<td>70.6%</td>
<td>4.2</td>
<td>569</td>
</tr>
<tr class="even">
<td>GoogleNet</td>
<td>69.8%</td>
<td>6.8</td>
<td>1550</td>
</tr>
<tr class="odd">
<td>VGG-16</td>
<td>71.5%</td>
<td>138</td>
<td>15300</td>
</tr>
<tr class="even">
<td>Inception V3</td>
<td>78.0%</td>
<td>23.8</td>
<td>5720</td>
</tr>
<tr class="odd">
<td>ResNet-50</td>
<td>76.0%</td>
<td>25.5</td>
<td>3800</td>
</tr>
</tbody>
</table>
</div>
</figure>
</div>
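<p>MobileNet achieves the best accuracy-to-computation ratio among these models, which makes it well suited to mobile deployment. Inception V3 and ResNet-50 are more accurate, but they need roughly 10x and 7x more multiply-adds, respectively.</p>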
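<p>The ratio claim can be made concrete with a crude score: top-1 accuracy points per billion multiply-adds, computed from Table 1. This metric is our own simplification (accuracy does not grow linearly with compute), so treat it only as a ranking aid.</p>

```python
# (top-1 accuracy %, million multiply-adds), copied from Table 1.
models = {
    "MobileNet": (70.6, 569),
    "GoogleNet": (69.8, 1550),
    "VGG-16": (71.5, 15300),
    "Inception V3": (78.0, 5720),
    "ResNet-50": (76.0, 3800),
}

# Accuracy points per billion multiply-adds: a crude efficiency score.
scores = {name: acc / (madds / 1000) for name, (acc, madds) in models.items()}
for name, score in sorted(scores.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{name:13s} {score:6.1f} top-1 points per billion mult-adds")
```

<p>MobileNet scores roughly 124 points per billion multiply-adds. That is well ahead of GoogleNet (about 45), about six times ResNet-50 (20), and about nine times Inception V3 (about 14).</p>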
</section>
<section id="ablation-studies" class="level3">
<h3 class="anchored" data-anchor-id="ablation-studies" id="ablation-studies">Ablation Studies</h3>
<p>Research has shown that various design choices in MobileNet contribute to its effectiveness:</p>
<ol type="1">
<li><strong>Depthwise vs.&nbsp;Standard Convolution</strong>: Depthwise separable convolutions provide 8-9x computational savings with minimal accuracy loss</li>
<li><strong>Width Multiplier Impact</strong>: Reducing the width multiplier from 1.0 to 0.75 cuts multiply-adds by about 43% (569M to 325M) for a 2.2-point drop in top-1 accuracy (70.6% to 68.4%)</li>
<li><strong>Resolution Multiplier Impact</strong>: Reducing the input resolution from 224 to 192 cuts multiply-adds by about 27% (569M to 418M) for a 1.5-point drop in top-1 accuracy (70.6% to 69.1%)</li>
</ol>
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Important</span>Key Finding
</div>
</div>
<div class="callout-body-container callout-body">
<p>The ablation studies demonstrate that MobileNet’s design choices are well-justified, with each component contributing meaningfully to the overall efficiency-accuracy trade-off.</p>
</div>
</div>
</section>
</section>
<section id="evolution-mobilenetv2-and-beyond" class="level2">
<h2 class="anchored" data-anchor-id="evolution-mobilenetv2-and-beyond" id="evolution-mobilenetv2-and-beyond">Evolution: MobileNetV2 and Beyond</h2>
<section id="mobilenetv2-improvements" class="level3">
<h3 class="anchored" data-anchor-id="mobilenetv2-improvements" id="mobilenetv2-improvements">MobileNetV2 Improvements</h3>
<p>MobileNetV2, introduced in 2018, built upon the original MobileNet with several key improvements:</p>
<section id="inverted-residuals" class="level4">
<h4 class="anchored" data-anchor-id="inverted-residuals">Inverted Residuals</h4>
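<p>MobileNetV2 introduces inverted residual blocks, which expand the number of channels with a 1×1 convolution before the depthwise convolution and then project back to a lower-dimensional space. Unlike classic residual blocks, the shortcut connections link the thin bottleneck layers rather than the wide ones. This design maintains representational capacity while keeping the tensors passed between blocks, and therefore memory usage, small.</p>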
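<p>The cost structure of an inverted residual block can be sketched in a few lines of Python. This is our own illustration, not code from the MobileNetV2 paper. It assumes an expansion factor <span class="math inline">\(t = 6\)</span> (the value used for most MobileNetV2 blocks), a 3×3 depthwise kernel, and "same" padding. It counts multiply-adds for the 1×1 expansion, the depthwise convolution on the expanded tensor, and the linear 1×1 projection.</p>

```python
def inverted_residual_cost(h, w, c_in, c_out, t=6, k=3, stride=1):
    """Multiply-adds of one inverted residual block (hypothetical helper).

    1x1 expansion (c_in -> t*c_in), k x k depthwise on the expanded tensor,
    then a linear 1x1 projection (t*c_in -> c_out) with no activation.
    """
    hidden = t * c_in
    h_out, w_out = -(-h // stride), -(-w // stride)  # "same" padding
    expand = h * w * c_in * hidden
    depthwise = h_out * w_out * k * k * hidden
    project = h_out * w_out * hidden * c_out
    return expand + depthwise + project


def uses_shortcut(c_in, c_out, stride):
    """The residual shortcut needs matching shapes: stride 1, same channel count."""
    return stride == 1 and c_in == c_out


cost = inverted_residual_cost(14, 14, 64, 64)
print(cost, uses_shortcut(64, 64, 1))
```

<p>For a stride-1 block, the total simplifies to <span class="math inline">\(h \cdot w \cdot M \cdot t\,(M + k^2 + N)\)</span> for <span class="math inline">\(M\)</span> input and <span class="math inline">\(N\)</span> output channels. The 6× wider expanded tensor exists only inside the block; the shortcut carries the thin bottleneck tensor.</p>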
</section>
<section id="linear-bottlenecks" class="level4">
<h4 class="anchored" data-anchor-id="linear-bottlenecks">Linear Bottlenecks</h4>
<p>The final layer of each inverted residual block uses linear activation instead of ReLU. This prevents the loss of information that can occur when ReLU is applied to low-dimensional representations.</p>
</section>
<section id="improved-performance" class="level4">
<h4 class="anchored" data-anchor-id="improved-performance">Improved Performance</h4>
<p>MobileNetV2 achieves better accuracy than the original MobileNet at lower cost. On ImageNet it reaches 72.0% top-1 accuracy with about 300M multiply-adds and 3.4M parameters, compared with 70.6%, 569M, and 4.2M for the original MobileNet.</p>
</section>
</section>
<section id="mobilenetv3" class="level3">
<h3 class="anchored" data-anchor-id="mobilenetv3" id="mobilenetv3">MobileNetV3</h3>
<p>MobileNetV3, released in 2019, incorporates several advanced techniques:</p>
<ul>
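<li><strong>Neural Architecture Search (NAS)</strong>: Platform-aware NAS searches the block-level structure under a latency target measured on mobile phones, and the NetAdapt algorithm then tunes the number of filters in each layer</li>
<li><strong>SE (Squeeze-and-Excitation) blocks</strong>: Lightweight channel-attention modules that reweight feature channels for better feature representation</li>
<li><strong>h-swish activation</strong>: A piecewise approximation of the swish nonlinearity, <span class="math inline">\(x \cdot \text{ReLU6}(x+3)/6\)</span>, that is cheaper to compute than swish on mobile hardware while keeping most of its accuracy benefit</li>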
</ul>
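<p>h-swish replaces the sigmoid inside swish, <span class="math inline">\(x \cdot \sigma(x)\)</span>, with a shifted and scaled ReLU6, so evaluating it needs only an addition, a clip, and multiplications. The plain-Python sketch below uses scalar functions with our own naming and compares the two activations on a grid of inputs.</p>

```python
import math


def relu6(x):
    """ReLU clipped to the range [0, 6]."""
    return min(max(x, 0.0), 6.0)


def h_swish(x):
    """Hard swish: x * ReLU6(x + 3) / 6, a piecewise approximation of swish."""
    return x * relu6(x + 3.0) / 6.0


def swish(x):
    """Swish (also called SiLU): x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))


grid = [i / 100 for i in range(-600, 601)]
max_gap = max(abs(h_swish(x) - swish(x)) for x in grid)
print(f"h_swish(1.0) = {h_swish(1.0):.4f}, swish(1.0) = {swish(1.0):.4f}")
print(f"max |h_swish - swish| on [-6, 6]: {max_gap:.3f}")
```

<p>On [-6, 6] the gap stays below about 0.15. It is largest near x = 3, where h-swish switches to the identity, and in exchange h-swish never evaluates an exponential.</p>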
</section>
</section>
<section id="applications-and-use-cases" class="level2">
<h2 class="anchored" data-anchor-id="applications-and-use-cases" id="applications-and-use-cases">Applications and Use Cases</h2>
<section id="image-classification" class="level3">
<h3 class="anchored" data-anchor-id="image-classification" id="image-classification">Image Classification</h3>
<p>MobileNet excels at image classification tasks on mobile devices. Its efficiency makes it suitable for real-time classification in mobile apps, enabling features like:</p>
<ul>
<li>Real-time object recognition in camera applications</li>
<li>Automatic photo tagging and organization</li>
<li>Visual search capabilities</li>
<li>Augmented reality applications</li>
</ul>
</section>
<section id="object-detection" class="level3">
<h3 class="anchored" data-anchor-id="object-detection" id="object-detection">Object Detection</h3>
<p>MobileNet serves as an excellent backbone for mobile object detection systems:</p>
<ul>
<li><strong>MobileNet-SSD</strong>: Combines MobileNet with Single Shot Detector for efficient object detection</li>
<li><strong>MobileNetV2-SSDLite</strong>: Further optimized for mobile deployment</li>
<li>Applications in autonomous vehicles, robotics, and surveillance systems</li>
</ul>
</section>
<section id="semantic-segmentation" class="level3">
<h3 class="anchored" data-anchor-id="semantic-segmentation" id="semantic-segmentation">Semantic Segmentation</h3>
<p>MobileNet has been adapted for semantic segmentation tasks:</p>
<ul>
<li><strong>DeepLabV3+</strong>: Uses MobileNet as encoder for efficient semantic segmentation</li>
<li>Applications in image editing, medical imaging, and autonomous navigation</li>
</ul>
</section>
<section id="transfer-learning" class="level3">
<h3 class="anchored" data-anchor-id="transfer-learning" id="transfer-learning">Transfer Learning</h3>
<p>MobileNet’s pre-trained weights serve as excellent starting points for transfer learning:</p>
<ul>
<li>Fine-tuning for specialized classification tasks</li>
<li>Feature extraction for custom applications</li>
<li>Domain adaptation for specific use cases</li>
</ul>
</section>
</section>
<section id="deployment-considerations" class="level2">
<h2 class="anchored" data-anchor-id="deployment-considerations" id="deployment-considerations">Deployment Considerations</h2>
<section id="quantization" class="level3">
<h3 class="anchored" data-anchor-id="quantization" id="quantization">Quantization</h3>
<p>MobileNet’s design makes it particularly amenable to quantization, a technique that reduces the precision of weights and activations to decrease memory usage and increase inference speed:</p>
<div class="tabset-margin-container"></div><div class="panel-tabset">
<ul class="nav nav-tabs" role="tablist"><li class="nav-item" role="presentation"><a class="nav-link active" id="tabset-1-1-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-1" role="tab" aria-controls="tabset-1-1" aria-selected="true" href="">8-bit Quantization</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-1-2-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-2" role="tab" aria-controls="tabset-1-2" aria-selected="false" href="">16-bit Quantization</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-1-3-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-3" role="tab" aria-controls="tabset-1-3" aria-selected="false" href="">Dynamic Quantization</a></li></ul>
<div class="tab-content">
<div id="tabset-1-1" class="tab-pane active" role="tabpanel" aria-labelledby="tabset-1-1-tab">
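<p>Stores weights (and usually activations) as 8-bit integers, shrinking the model by about 4x relative to 32-bit floats, usually with only a small accuracy loss</p>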
</div>
<div id="tabset-1-2" class="tab-pane" role="tabpanel" aria-labelledby="tabset-1-2-tab">
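<p>Stores weights as 16-bit floats, halving model size relative to 32-bit floats with typically negligible accuracy loss; a middle ground between compression and accuracy</p>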
</div>
<div id="tabset-1-3" class="tab-pane" role="tabpanel" aria-labelledby="tabset-1-3-tab">
<p>Stores weights as 8-bit integers ahead of time but quantizes activations on the fly at inference time, so no calibration dataset is required</p>
</div>
</div>
</div>
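<p>Post-training quantization in real toolchains involves much more than the arithmetic itself: per-channel scales, calibration data, and fused integer kernels. The sketch below shows only the core arithmetic, using our own hypothetical helpers. It applies an affine mapping from 32-bit floats to 8-bit integers with a per-tensor scale and zero point, which is where the size reduction comes from. The weight values are made up for illustration.</p>

```python
from array import array


def quantize_int8(values):
    """Asymmetric (affine) per-tensor quantization of floats to int8."""
    lo, hi = min(min(values), 0.0), max(max(values), 0.0)  # range must include 0
    scale = (hi - lo) / 255 or 1.0                         # guard against scale 0
    zero_point = round(-128 - lo / scale)                  # int8 code for real 0.0
    q = [max(-128, min(127, round(v / scale) + zero_point)) for v in values]
    return array("b", q), scale, zero_point


def dequantize(q, scale, zero_point):
    """Map int8 codes back to approximate float values."""
    return [(v - zero_point) * scale for v in q]


weights = array("f", [-0.42, -0.05, 0.0, 0.13, 0.31, 0.77])
q, scale, zp = quantize_int8(weights)
restored = dequantize(q, scale, zp)

print("bytes float32:", weights.itemsize * len(weights))
print("bytes int8   :", q.itemsize * len(q))
print("max abs error:", max(abs(a - b) for a, b in zip(weights, restored)))
```

<p>Storage drops from 24 to 6 bytes, and each restored weight is within half a quantization step (<code>scale / 2</code>) of the original. Zero is represented exactly, which matters for zero padding and for ReLU outputs.</p>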
</section>
<section id="hardware-optimization" class="level3">
<h3 class="anchored" data-anchor-id="hardware-optimization" id="hardware-optimization">Hardware Optimization</h3>
<p>MobileNet’s architecture aligns well with mobile hardware capabilities:</p>
<ul>
<li><strong>ARM processors</strong>: Efficient execution on mobile CPUs</li>
<li><strong>Neural processing units (NPUs)</strong>: Dedicated hardware acceleration</li>
<li><strong>GPU acceleration</strong>: Optimized implementations for mobile GPUs</li>
</ul>
</section>
<section id="framework-support" class="level3">
<h3 class="anchored" data-anchor-id="framework-support" id="framework-support">Framework Support</h3>
<p>MobileNet enjoys broad support across major deep learning frameworks:</p>
<ul>
<li><strong>TensorFlow Lite</strong>: Optimized for mobile deployment</li>
<li><strong>Core ML</strong>: Apple’s framework for iOS deployment</li>
<li><strong>ONNX</strong>: Cross-platform model representation</li>
<li><strong>PyTorch Mobile</strong>: Facebook’s mobile deployment solution</li>
</ul>
</section>
</section>
<section id="limitations-and-considerations" class="level2">
<h2 class="anchored" data-anchor-id="limitations-and-considerations" id="limitations-and-considerations">Limitations and Considerations</h2>
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Warning</span>Trade-offs to Consider
</div>
</div>
<div class="callout-body-container callout-body">
<p>While MobileNet achieves impressive efficiency, practitioners should be aware of inherent trade-offs and limitations.</p>
</div>
</div>
<section id="accuracy-trade-offs" class="level3">
<h3 class="anchored" data-anchor-id="accuracy-trade-offs" id="accuracy-trade-offs">Accuracy Trade-offs</h3>
<p>Compared with larger models, MobileNet gives up some accuracy in exchange for its efficiency:</p>
<ul>
<li>Lower accuracy compared to larger models on complex tasks</li>
<li>Reduced representational capacity may limit performance on fine-grained classification</li>
<li>Potential degradation in transfer learning performance for significantly different domains</li>
</ul>
</section>
<section id="architecture-constraints" class="level3">
<h3 class="anchored" data-anchor-id="architecture-constraints" id="architecture-constraints">Architecture Constraints</h3>
<p>MobileNet’s design imposes certain limitations:</p>
<ul>
<li>Fixed architecture pattern may not be optimal for all tasks</li>
<li>Limited flexibility compared to more modular architectures</li>
<li>Potential bottlenecks in very deep variants</li>
</ul>
</section>
<section id="training-considerations" class="level3">
<h3 class="anchored" data-anchor-id="training-considerations" id="training-considerations">Training Considerations</h3>
<p>Training MobileNet requires careful attention to:</p>
<ul>
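<li>Regularization: small models overfit less than large ones, so the original MobileNet work used less regularization and data augmentation than typical large-model recipes, with little or no weight decay on the depthwise filters</li>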
<li>Learning rate scheduling for stable convergence</li>
<li>Data augmentation strategies to improve generalization</li>
</ul>
</section>
</section>
<section id="future-directions-and-research" class="level2">
<h2 class="anchored" data-anchor-id="future-directions-and-research" id="future-directions-and-research">Future Directions and Research</h2>
<section id="architectural-innovations" class="level3">
<h3 class="anchored" data-anchor-id="architectural-innovations" id="architectural-innovations">Architectural Innovations</h3>
<p>Ongoing research continues to improve upon MobileNet’s design:</p>
<ul>
<li><strong>Attention mechanisms</strong>: Integration of self-attention for better feature representation</li>
<li><strong>Dynamic networks</strong>: Adaptive computation based on input complexity</li>
<li><strong>Multi-scale processing</strong>: Handling objects at different scales more effectively</li>
</ul>
</section>
<section id="hardware-software-co-design" class="level3">
<h3 class="anchored" data-anchor-id="hardware-software-co-design" id="hardware-software-co-design">Hardware-Software Co-design</h3>
<p>Future developments focus on closer integration between architecture and hardware:</p>
<ul>
<li><strong>Custom silicon</strong>: Processors designed specifically for efficient neural networks</li>
<li><strong>Edge computing</strong>: Distributed processing across multiple devices</li>
<li><strong>Federated learning</strong>: Training updates without centralized data collection</li>
</ul>
</section>
<section id="automated-architecture-design" class="level3">
<h3 class="anchored" data-anchor-id="automated-architecture-design" id="automated-architecture-design">Automated Architecture Design</h3>
<p>Neural Architecture Search continues to evolve:</p>
<ul>
<li><strong>Differentiable NAS</strong>: More efficient architecture search methods</li>
<li><strong>Progressive search</strong>: Incremental architecture refinement</li>
<li><strong>Multi-objective optimization</strong>: Balancing multiple performance metrics</li>
</ul>
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>MobileNet represents a paradigm shift in neural network design, demonstrating that significant efficiency gains are possible without sacrificing too much accuracy. By introducing depthwise separable convolutions and providing tunable parameters for accuracy-efficiency trade-offs, MobileNet has enabled the deployment of sophisticated computer vision capabilities on resource-constrained devices.</p>
<p>The impact of MobileNet extends beyond its immediate applications. It has influenced a generation of efficient neural network architectures and sparked renewed interest in the optimization of deep learning models for practical deployment. As mobile devices become increasingly powerful and AI capabilities more ubiquitous, MobileNet’s principles continue to guide the development of efficient, deployable neural networks.</p>
<p>The evolution from MobileNet to MobileNetV2 and V3 demonstrates the ongoing refinement of these principles, incorporating advances in neural architecture search, attention mechanisms, and hardware-aware optimization. As we look to the future, MobileNet’s legacy lies not just in its specific architectural contributions, but in its demonstration that efficiency and accuracy need not be mutually exclusive in deep learning system design.</p>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>For Practitioners
</div>
</div>
<div class="callout-body-container callout-body">
<p>For practitioners and researchers working on mobile AI applications, MobileNet provides both a practical solution and a blueprint for designing efficient neural networks. Its success underscores the importance of considering deployment constraints from the earliest stages of model design.</p>
</div>
</div>
<p>Deployment constraints work best as design inputs rather than as an afterthought. As the field continues to evolve, the principles pioneered by MobileNet are likely to keep influencing the development of efficient, practical AI systems.</p>



</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[GShard: Scaling Giant Neural Networks with Conditional Computation]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/model-training/g-shard/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/model-training/g-shard/</guid>
      <pubDate>Tue, 15 Jul 2025 00:00:00 GMT</pubDate>
      
      <category>research</category>
      <category>advanced</category>
      <content:encoded><![CDATA[






<section id="gshard-scaling-giant-neural-networks-with-conditional-computation" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/model-training/g-shard/gshard.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>GShard, introduced by Google Research in 2020, is a pivotal advance in neural network scaling. It addresses one of the most pressing challenges in deep learning: how to scale neural networks to unprecedented sizes while keeping computation tractable. GShard combines sparsely-gated mixture-of-experts (MoE) layers with compiler-based automatic sharding across accelerators. Using this conditional computation, it enabled training a multilingual translation Transformer with more than 600 billion parameters.</p>
<p>The significance of GShard extends beyond mere parameter scaling. It fundamentally changes how we think about model capacity, computational efficiency, and distributed training. Rather than activating all parameters for every input, GShard selectively activates only a subset of experts, allowing for massive models that remain computationally tractable during inference and training.</p>
</section>
<section id="background-and-motivation" class="level2">
<h2 class="anchored" data-anchor-id="background-and-motivation" id="background-and-motivation">Background and Motivation</h2>
<section id="the-scaling-challenge" class="level3">
<h3 class="anchored" data-anchor-id="the-scaling-challenge" id="the-scaling-challenge">The Scaling Challenge</h3>
<p>Traditional neural network scaling follows a straightforward principle: more parameters generally lead to better performance. However, this approach faces significant limitations as model sizes grow exponentially. Dense models require all parameters to be activated for every input, creating computational bottlenecks that become increasingly prohibitive as models scale to hundreds of billions or trillions of parameters.</p>
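<p>In dense architectures, the computational cost of training and inference scales linearly with model size. For a Transformer with <span class="math inline">\(N\)</span> parameters, each forward pass costs <span class="math inline">\(O(N)\)</span> operations per token, regardless of how complex the input is. In practice this is roughly <span class="math inline">\(2N\)</span> floating-point operations, one multiply and one add per weight. This relationship creates unsustainable resource requirements as models grow larger.</p>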
</section>
<section id="conditional-computation-as-a-solution" class="level3">
<h3 class="anchored" data-anchor-id="conditional-computation-as-a-solution" id="conditional-computation-as-a-solution">Conditional Computation as a Solution</h3>
<p>Conditional computation offers an elegant solution to this scaling challenge. Instead of activating all parameters for every input, conditional computation selectively activates only relevant portions of the network. This approach allows for models with massive parameter counts while maintaining reasonable computational costs.</p>
<p>The mixture-of-experts paradigm serves as the foundation for GShard’s conditional computation approach. By decomposing the model into specialized expert networks and learning to route inputs to appropriate experts, GShard achieves sub-linear scaling of computational cost with respect to model size.</p>
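<p>A quick way to see how MoE decouples parameter count from compute is to count the weights in a single feed-forward layer. In the sketch below, the layer sizes and function names are hypothetical. Active parameters stand in for per-token compute, since each weight is used in roughly one multiply-add per token. With top-2 routing, only two experts run for each token no matter how many exist.</p>

```python
def ffn_params(d_model, d_ff):
    """Weights of one two-layer feed-forward block (biases ignored)."""
    return 2 * d_model * d_ff


def moe_layer_stats(d_model, d_ff, num_experts, k=2):
    """Total vs. per-token active parameters of a top-k MoE feed-forward layer."""
    gate = d_model * num_experts                   # gating weights W_g
    total = num_experts * ffn_params(d_model, d_ff) + gate
    active = k * ffn_params(d_model, d_ff) + gate  # only k experts run per token
    return total, active


# Hypothetical layer sizes, chosen only for illustration.
d_model, d_ff = 1024, 4096
for experts in (2, 8, 128, 2048):
    total, active = moe_layer_stats(d_model, d_ff, experts)
    print(f"E={experts:5d}: total {total / 1e6:9.1f}M params, "
          f"active per token {active / 1e6:6.1f}M")
```

<p>Going from 2 to 2048 experts multiplies the layer's parameters by about 1000x. The per-token active parameters grow only by the gating matrix <span class="math inline">\(W_g\)</span>, whose size is linear in the number of experts but small.</p>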
</section>
</section>
<section id="gshard-architecture" class="level2">
<h2 class="anchored" data-anchor-id="gshard-architecture" id="gshard-architecture">GShard Architecture</h2>
<section id="core-components" class="level3">
<h3 class="anchored" data-anchor-id="core-components" id="core-components">Core Components</h3>
<p>GShard’s architecture centers around several key innovations that work together to enable efficient scaling:</p>
<p><strong>Sparsely-Gated Mixture-of-Experts (MoE)</strong>: The fundamental building block of GShard replaces dense feed-forward layers in transformer architectures with MoE layers. Each MoE layer consists of multiple expert networks and a gating network that determines which experts to activate for each input.</p>
<p><strong>Expert Networks</strong>: Individual expert networks are typically simple feed-forward networks, similar to the feed-forward layers in standard transformers. The key difference lies in their selective activation rather than their architecture. Each expert specializes in processing certain types of inputs, though this specialization emerges naturally during training rather than being explicitly programmed.</p>
<p><strong>Gating Network</strong>: The gating network serves as the routing mechanism, determining which experts should process each input token. This network learns to make routing decisions based on the input representation, typically selecting only a small subset of available experts for each token.</p>
</section>
<section id="mixture-of-experts-implementation" class="level3">
<h3 class="anchored" data-anchor-id="mixture-of-experts-implementation" id="mixture-of-experts-implementation">Mixture-of-Experts Implementation</h3>
<p>The MoE layer in GShard operates through a sophisticated gating mechanism that balances computational efficiency with model expressiveness. For each input token, the gating network computes a probability distribution over all available experts. Rather than using all experts, GShard selects only the top-k experts (typically k=2) for each token, significantly reducing computational requirements.</p>
<p>The gating function can be expressed mathematically as:</p>
<p><span class="math display">\[G(x) = \text{Softmax}(x \cdot W_g)\]</span></p>
<p>Where <span class="math inline">\(x\)</span> represents the input token embedding and <span class="math inline">\(W_g\)</span> the learned gating weights. Only the top-k entries of <span class="math inline">\(G(x)\)</span> are kept. With GShard's top-2 gating, the two selected gate values are renormalized to sum to one and used to weight the outputs of the two chosen experts.</p>
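<p>The gating computation can be sketched for a single token in plain Python. This is a simplified, deterministic version of top-2 gating in our own code. It computes <span class="math inline">\(G(x)\)</span>, keeps the two largest gate values, renormalizes them to sum to one, and combines the chosen experts' outputs with those weights. GShard's actual gate also batches tokens into groups, enforces expert capacity, and dispatches to the second expert stochastically; none of that is modeled here, and the toy experts are hypothetical.</p>

```python
import math
import random


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top2_gate(x, w_g):
    """Return [(expert, weight), (expert, weight)] for one token.

    x:   token embedding of length d_model
    w_g: gating weights as d_model rows of E entries each
    """
    num_experts = len(w_g[0])
    logits = [sum(x[i] * w_g[i][e] for i in range(len(x))) for e in range(num_experts)]
    probs = softmax(logits)  # G(x) = Softmax(x . W_g)
    e1, e2 = sorted(range(num_experts), key=lambda e: probs[e], reverse=True)[:2]
    norm = probs[e1] + probs[e2]
    return [(e1, probs[e1] / norm), (e2, probs[e2] / norm)]


random.seed(0)
d_model, num_experts = 8, 4
w_g = [[random.gauss(0, 1) for _ in range(num_experts)] for _ in range(d_model)]
x = [random.gauss(0, 1) for _ in range(d_model)]

# Toy experts: expert e just scales its input by (e + 1).
experts = [lambda v, s=e + 1: [s * vi for vi in v] for e in range(num_experts)]

# Layer output: gate-weighted sum of the two chosen experts' outputs.
routes = top2_gate(x, w_g)
y = [0.0] * d_model
for e, g in routes:
    y = [yi + g * oi for yi, oi in zip(y, experts[e](x))]
print(routes)
```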
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Load Balancing
</div>
</div>
<div class="callout-body-container callout-body">
<p>One critical challenge in MoE architectures is ensuring balanced load distribution across experts. Without proper load balancing, some experts may receive disproportionately more training examples, leading to underutilization of model capacity. GShard addresses this through auxiliary loss functions that encourage balanced expert utilization.</p>
</div>
</div>
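<p>A minimal sketch of one widely used form of such an auxiliary loss, in the spirit of GShard's, is shown below. For each expert, it multiplies the fraction of tokens whose top-1 choice is that expert by the mean gate probability assigned to it, then averages over experts. The routing fraction is not differentiable, so in a real model the gradient flows through the mean-gate term. The exact scaling constant is an assumption here, and implementations differ.</p>

```python
def load_balancing_loss(gates):
    """Auxiliary loss from per-token gate probabilities (hypothetical helper).

    gates: list of S rows, each a probability distribution over E experts.
    Loss = mean over experts of (fraction of tokens routed to e as top-1)
           * (mean gate probability of e).
    """
    num_tokens, num_experts = len(gates), len(gates[0])
    counts = [0] * num_experts
    for row in gates:
        counts[max(range(num_experts), key=lambda e: row[e])] += 1
    frac = [c / num_tokens for c in counts]
    mean_gate = [sum(row[e] for row in gates) / num_tokens for e in range(num_experts)]
    return sum(f * m for f, m in zip(frac, mean_gate)) / num_experts


balanced = [[0.9, 0.1], [0.2, 0.8]]  # each expert is top-1 for one token
skewed = [[0.9, 0.1], [0.8, 0.2]]    # expert 0 wins every token
print(load_balancing_loss(balanced), load_balancing_loss(skewed))
```

<p>The skewed routing scores higher (about 0.425 versus 0.25). Adding a small multiple of this term to the task loss therefore nudges the gate toward spreading tokens across experts.</p>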
<p><strong>Expert Capacity</strong>: To prevent memory overflow and keep computational costs predictable, GShard caps how many tokens each expert can process per batch. A token routed to an expert that is already full is not processed by that expert. It still reaches the next layer through the residual connection, and it may still be processed by its other selected expert.</p>
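<p>The capacity mechanism amounts to a bounded dispatch. The sketch below is our own simplification, assuming top-1 routing, a single group of tokens, and first-come-first-served filling. The capacity is computed as <code>ceil(capacity_factor * tokens / experts)</code>, and the routing vector and capacity factors are illustrative values.</p>

```python
import math


def dispatch_with_capacity(assignments, num_experts, capacity_factor=1.25):
    """Assign tokens to experts in order, dropping tokens once an expert is full.

    assignments: chosen expert index for each token (top-1 routing for simplicity).
    Returns (per-expert token lists, indices of dropped tokens, capacity).
    """
    capacity = math.ceil(capacity_factor * len(assignments) / num_experts)
    buckets = [[] for _ in range(num_experts)]
    dropped = []
    for token, expert in enumerate(assignments):
        if len(buckets[expert]) == capacity:
            dropped.append(token)  # overflow: token skips this expert
        else:
            buckets[expert].append(token)
    return buckets, dropped, capacity


# 8 tokens, 4 experts, routing skewed towards expert 0 (illustrative values).
routing = [0, 0, 0, 0, 1, 2, 0, 3]
for cf in (1.0, 1.25, 2.0):
    buckets, dropped, cap = dispatch_with_capacity(routing, 4, cf)
    print(f"capacity_factor={cf}: capacity={cap}, dropped tokens={dropped}")
```

<p>In this skewed example, raising the capacity factor from 1.0 to 2.0 cuts the dropped tokens from three to one. The price is reserving twice as many buffer slots per expert, which is the trade-off that capacity factor tuning has to balance.</p>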
</section>
<section id="parallelization-strategy" class="level3">
<h3 class="anchored" data-anchor-id="parallelization-strategy" id="parallelization-strategy">Parallelization Strategy</h3>
<p>GShard’s parallelization approach represents a significant departure from traditional data parallelism. The system employs a hybrid strategy that combines expert parallelism with data parallelism to efficiently distribute computation across multiple devices.</p>
<p><strong>Expert Parallelism</strong>: Different experts are placed on different devices, allowing for parallel processing of different expert computations. This approach scales naturally with the number of experts and available devices.</p>
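<p><strong>Data Parallelism</strong>: The remaining layers (attention, embeddings, and other non-MoE weights) are replicated on every device, and each device processes a different shard of the input batch. Only the MoE layers are sharded along the expert dimension. The model therefore combines data parallelism for most layers with expert parallelism for the MoE layers, which maximizes hardware utilization while keeping communication patterns efficient.</p>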
<p><strong>Communication Optimization</strong>: The routing of tokens to experts requires careful communication optimization. GShard implements efficient all-to-all communication patterns that minimize the overhead of token routing across devices.</p>
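<p>The token exchange at the heart of expert parallelism is an all-to-all: every device sends a different slice of its tokens to every other device. The pure-Python simulation below is our own illustration; a real system runs this as a collective communication operation on the accelerators. It shows that, once each device has grouped its tokens by destination, the exchange is just a transpose of the send buffers. The device count, expert placement, and token names are hypothetical.</p>

```python
def all_to_all(send_buffers):
    """Simulate an all-to-all exchange among D devices.

    send_buffers[src][dst] holds the tokens device `src` sends to device `dst`.
    After the exchange, recv_buffers[dst][src] holds exactly those tokens,
    so the whole operation is a transpose of the buffer matrix.
    """
    num_devices = len(send_buffers)
    return [[send_buffers[src][dst] for src in range(num_devices)]
            for dst in range(num_devices)]


def build_send_buffers(tokens_per_device, expert_of, experts_per_device):
    """Group each device's local tokens by the device hosting their expert."""
    num_devices = len(tokens_per_device)
    buffers = [[[] for _ in range(num_devices)] for _ in range(num_devices)]
    for src, tokens in enumerate(tokens_per_device):
        for tok in tokens:
            dst = expert_of[tok] // experts_per_device
            buffers[src][dst].append(tok)
    return buffers


# 2 devices, 1 expert each; tokens are labeled by name (illustrative).
tokens = [["a", "b", "c"], ["d", "e"]]
expert_of = {"a": 1, "b": 0, "c": 1, "d": 0, "e": 0}
recv = all_to_all(build_send_buffers(tokens, expert_of, experts_per_device=1))
print(recv)
```

<p>After the experts run, applying the same exchange to their outputs returns each token's result to the device it came from. An MoE layer therefore performs two all-to-all operations per forward pass.</p>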
</section>
</section>
<section id="training-methodology" class="level2">
<h2 class="anchored" data-anchor-id="training-methodology" id="training-methodology">Training Methodology</h2>
<section id="distributed-training-challenges" class="level3">
<h3 class="anchored" data-anchor-id="distributed-training-challenges" id="distributed-training-challenges">Distributed Training Challenges</h3>
<p>Training GShard models presents unique challenges compared to traditional dense models. The sparse activation patterns create irregular communication requirements, and the load balancing constraints require careful optimization to prevent training instabilities.</p>
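<p><strong>Gradient Synchronization</strong>: In a dense data-parallel model, every weight is replicated and its gradient is averaged across devices with an all-reduce. In GShard, the replicated (non-MoE) weights are still synchronized this way. Expert weights, however, are partitioned across devices rather than replicated, so the backward pass instead routes gradients back through the all-to-all token exchange. Experts that receive no tokens in a batch get no gradient from the task loss, which produces sparse gradient patterns that require efficient handling.</p>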
<p><strong>Load Balancing During Training</strong>: Maintaining balanced expert utilization during training is crucial for model performance. GShard employs auxiliary loss functions that penalize imbalanced expert usage, encouraging the gating network to distribute load evenly across all experts.</p>
<p><strong>Stability Considerations</strong>: The discrete routing decisions in MoE architectures can create training instabilities. GShard addresses these challenges through careful initialization strategies, gradient clipping, and regularization techniques that promote stable training dynamics.</p>
</section>
<section id="optimization-techniques" class="level3">
<h3 class="anchored" data-anchor-id="optimization-techniques" id="optimization-techniques">Optimization Techniques</h3>
<p>GShard incorporates several optimization techniques specifically designed for MoE architectures:</p>
<p><strong>Auxiliary Loss Functions</strong>: These loss functions encourage balanced expert utilization and prevent the collapse of expert diversity. The auxiliary loss is typically added to the main task loss with a small weighting factor.</p>
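<p>A minimal sketch of such an auxiliary loss, assuming top-1 dispatch. It multiplies the fraction of tokens sent to each expert, f<sub>e</sub>, by the mean router probability for that expert, P<sub>e</sub>, and uses the Switch Transformer scaling E·Σ f<sub>e</sub>·P<sub>e</sub>, so perfectly uniform routing scores 1.0; the formulation in the GShard paper should differ only by a constant scale factor. The helper names and the <code>aux_weight</code> default of 0.01 are illustrative assumptions.</p>

```python
import numpy as np


def load_balancing_loss(gates, expert_idx):
    """Auxiliary load-balancing loss (Switch-Transformer-style scaling).

    gates:      [S, E] softmax router probabilities
    expert_idx: [S] expert each token was dispatched to (top-1)
    Returns 1.0 for perfectly uniform routing and up to E when all
    tokens and all probability mass collapse onto one expert.
    """
    S, E = gates.shape
    # f_e: fraction of tokens dispatched to each expert. It comes from hard
    # assignments, so in an autodiff framework it carries no gradient.
    f = np.bincount(expert_idx, minlength=E) / S
    # P_e: mean router probability per expert, the differentiable term.
    P = gates.mean(axis=0)
    return E * float(np.sum(f * P))


def total_loss(task_loss, aux_loss, aux_weight=0.01):
    # The auxiliary term is added to the task loss with a small weight.
    return task_loss + aux_weight * aux_loss
```

<p>Because the product f<sub>e</sub>·P<sub>e</sub> is largest when the same experts receive both the most tokens and the most probability mass, minimizing it pushes the router toward spreading tokens evenly.</p>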
<p><strong>Random Routing</strong>: With top-2 gating, GShard always dispatches a token to its highest-scoring expert but sends it to the second-best expert only with a probability proportional to that expert’s gate weight. When the second weight is small, its contribution to the weighted output is minor, so skipping it conserves expert capacity for the tokens that need it.</p>
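<p>GShard’s top-2 gating dispatches the second-best expert stochastically, with probability proportional to its gate weight. A minimal NumPy sketch of that idea follows. It assumes the two top weights are renormalized to sum to 1 (so the second weight is at most 0.5) and that the second expert is kept with probability 2·g<sub>2</sub>; treat that constant as an assumption rather than a transcription of GShard’s code. <code>top2_random_routing</code> is a hypothetical helper.</p>

```python
import numpy as np


def top2_random_routing(gates, rng):
    """Choose a first and (stochastically) a second expert per token.

    gates: [S, E] softmax router probabilities
    rng:   numpy Generator
    Returns (e1, g1, e2, g2); e2 is -1 and g2 is 0.0 where the second
    expert was skipped.
    """
    order = np.argsort(-gates, axis=-1)
    e1, e2 = order[:, 0], order[:, 1]
    rows = np.arange(gates.shape[0])
    g1, g2 = gates[rows, e1], gates[rows, e2]
    # Renormalize the two weights so g1 + g2 == 1 (hence g2 <= 0.5).
    denom = g1 + g2
    g1, g2 = g1 / denom, g2 / denom
    # Assumption: keep the second expert with probability 2 * g2, in [0, 1].
    keep = rng.random(gates.shape[0]) < 2.0 * g2
    e2 = np.where(keep, e2, -1)
    g2 = np.where(keep, g2, 0.0)
    return e1, g1, e2, g2
```

<p>A token whose router is nearly certain about its first choice almost never consumes a slot in a second expert, while a token split evenly between two experts is almost always sent to both.</p>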
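<p><strong>Capacity Factor Tuning</strong>: The capacity factor sets the maximum number of tokens each expert can accept. Tokens routed to an expert that is already full overflow: the expert skips them, and their representations pass to the next layer through the residual connection. Higher capacity factors drop fewer tokens and allow more flexible routing, but they increase computation and memory because every expert buffer is sized for the larger capacity.</p>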
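<p>The relationship between capacity factor, capacity, and overflow can be sketched as follows. The formula <code>ceil(capacity_factor * tokens / experts)</code> is a common convention and an assumption here; GShard’s own implementation may differ in rounding and in how top-2 assignments count toward capacity. Both helpers are hypothetical.</p>

```python
import math

import numpy as np


def expert_capacity(num_tokens, num_experts, capacity_factor):
    # Even share of tokens per expert, scaled up by the capacity factor.
    return math.ceil(capacity_factor * num_tokens / num_experts)


def count_dropped(expert_idx, num_experts, capacity):
    """Number of tokens that overflow their chosen expert's capacity."""
    counts = np.bincount(expert_idx, minlength=num_experts)
    return int(np.maximum(counts - capacity, 0).sum())


# A skewed assignment: 16 tokens, 4 experts, expert 0 is overloaded.
assignment = np.array([0] * 8 + [1] * 4 + [2] * 2 + [3] * 2)
for cf in (1.0, 2.0):
    cap = expert_capacity(len(assignment), 4, cf)
    print(f"capacity_factor={cf}: capacity={cap}, "
          f"dropped={count_dropped(assignment, 4, cap)}")
```

<p>With a capacity factor of 1.0 each expert accepts 4 tokens and 4 of expert 0’s tokens overflow; doubling the factor removes the overflow but doubles every expert’s buffer, most of which sits empty for the lightly loaded experts.</p>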
</section>
</section>
<section id="performance-analysis" class="level2">
<h2 class="anchored" data-anchor-id="performance-analysis" id="performance-analysis">Performance Analysis</h2>
<section id="computational-efficiency" class="level3">
<h3 class="anchored" data-anchor-id="computational-efficiency" id="computational-efficiency">Computational Efficiency</h3>
<p>GShard’s primary advantage lies in its computational efficiency compared to dense models of equivalent parameter count. By activating only a subset of experts for each input, GShard achieves sub-linear scaling of computational cost with respect to model size.</p>
<p><strong>FLOPs Analysis</strong>: For a GShard layer with E experts and top-k routing, the expert computation per token is approximately k/E of what a dense layer holding all E experts’ parameters would cost, ignoring gating and communication overhead. The ratio applies only to the MoE layers; attention and other dense layers cost the same as in a dense model. The gain grows as E increases while k stays fixed.</p>
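<p>The k/E ratio can be checked with simple arithmetic. The sketch below assumes each expert is a two-matrix feed-forward network (<code>d_model</code> to <code>d_ff</code> and back), counts a multiply-add as two FLOPs, and ignores gating and communication; the dimensions in the usage line are illustrative, not GShard’s actual configuration.</p>

```python
def expert_ffn_flops_per_token(d_model, d_ff, num_experts, top_k):
    """Rough per-token FLOPs for the expert part of one MoE layer.

    Returns (sparse_flops, dense_equivalent_flops, ratio).
    """
    # Two matmuls (d_model x d_ff each), 2 FLOPs per multiply-add.
    per_expert = 2 * 2 * d_model * d_ff
    sparse = top_k * per_expert             # only k experts run per token
    dense_equiv = num_experts * per_expert  # a dense layer with all E experts' weights
    return sparse, dense_equiv, sparse / dense_equiv


sparse, dense, ratio = expert_ffn_flops_per_token(1024, 8192, num_experts=2048, top_k=2)
print(f"sparse={sparse:,} dense={dense:,} ratio={ratio:.6f}")
```

<p>With 2048 experts and top-2 routing the ratio is 2/2048, roughly 0.1% of the dense-equivalent expert cost, even though the parameter count is that of all 2048 experts.</p>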
<p><strong>Memory Efficiency</strong>: Sparse activation reduces computation, not storage. Because any token may be routed to any expert, all expert weights must stay resident in device memory, so total memory still scales with the full parameter count. What makes very large models deployable is that the experts are sharded across devices, so each device holds only its share of them.</p>
<p><strong>Scaling Behavior</strong>: Empirical results demonstrate that GShard models can achieve better performance than dense models while using fewer computational resources. This scaling behavior enables the training of models that would be computationally prohibitive as dense architectures.</p>
</section>
<section id="quality-and-capability" class="level3">
<h3 class="anchored" data-anchor-id="quality-and-capability" id="quality-and-capability">Quality and Capability</h3>
<p>GShard has demonstrated impressive performance across various natural language processing tasks, particularly in machine translation and language modeling. The model’s ability to scale to hundreds of billions of parameters (600 billion in the paper’s largest translation model) while maintaining computational efficiency has enabled breakthrough results in several domains.</p>
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Important</span>Key Performance Metrics
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li><strong>Translation Quality</strong>: GShard models have achieved state-of-the-art results on numerous machine translation benchmarks</li>
<li><strong>Language Modeling</strong>: Improved perplexity scores compared to dense models with equivalent computational budgets</li>
<li><strong>Generalization</strong>: Better generalization through expert specialization</li>
</ul>
</div>
</div>
<p><strong>Translation Quality</strong>: GShard models have achieved state-of-the-art results on numerous machine translation benchmarks, demonstrating that the MoE approach can effectively scale model capacity without sacrificing translation quality.</p>
<p><strong>Language Modeling</strong>: In language modeling tasks, GShard models have shown improved perplexity scores compared to dense models with equivalent computational budgets, indicating more efficient use of model capacity.</p>
<p><strong>Generalization</strong>: The sparse activation patterns in GShard models appear to promote better generalization, as different experts can specialize in different aspects of the input distribution.</p>
</section>
</section>
<section id="implementation-details" class="level2">
<h2 class="anchored" data-anchor-id="implementation-details" id="implementation-details">Implementation Details</h2>
<section id="technical-architecture" class="level3">
<h3 class="anchored" data-anchor-id="technical-architecture" id="technical-architecture">Technical Architecture</h3>
<p>GShard’s implementation requires careful consideration of several technical aspects:</p>
<p><strong>Framework Integration</strong>: GShard is not a standalone framework; it consists of a set of lightweight sharding-annotation APIs for TensorFlow and an extension to the XLA compiler. Users annotate how a few critical tensors should be partitioned, and the compiler’s SPMD partitioner automatically generates the communication required for expert routing and gradient synchronization.</p>
<p><strong>Device Placement</strong>: The placement of experts across devices requires careful planning to minimize communication overhead while maximizing computational efficiency. GShard employs sophisticated placement strategies that consider both computational load and communication patterns.</p>
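<p>As a point of reference, the simplest placement shards the expert dimension evenly, so each device holds a contiguous block of E/D experts. The hypothetical helpers below assume that layout and count how many tokens each device must send to every other device, which is the volume the all-to-all exchange has to carry; they do not attempt to reproduce GShard’s placement logic.</p>

```python
def place_experts(num_experts, num_devices):
    """Map each expert to a device by even, contiguous sharding."""
    if num_experts % num_devices != 0:
        raise ValueError("num_experts must be divisible by num_devices")
    per_device = num_experts // num_devices
    return {e: e // per_device for e in range(num_experts)}


def all_to_all_traffic(assignments, placement, num_devices):
    """assignments[src] lists the expert chosen by each token on device src.

    Returns traffic[src][dst]: tokens device src must send to device dst
    (the diagonal counts tokens whose expert is already local).
    """
    traffic = [[0] * num_devices for _ in range(num_devices)]
    for src, experts in enumerate(assignments):
        for e in experts:
            traffic[src][placement[e]] += 1
    return traffic


placement = place_experts(8, 4)  # experts 0-1 on device 0, 2-3 on device 1, ...
routes = [[0, 2, 7], [1, 1, 5], [3, 4, 6], [0, 0, 0]]
for row in all_to_all_traffic(routes, placement, 4):
    print(row)
```

<p>A skewed traffic matrix (such as device 3 sending every token to device 0 here) is exactly the imbalance that load-balancing losses and capacity limits try to prevent.</p>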
<p><strong>Memory Management</strong>: Managing memory efficiently across experts requires careful attention to buffer sizes, expert capacities, and gradient accumulation strategies. Rather than resizing buffers on the fly, GShard fixes a per-expert capacity so that every tensor has a static shape at compile time, as XLA requires; uneven load is absorbed by dropping overflow tokens instead of growing the buffers.</p>
</section>
<section id="hyperparameter-considerations" class="level3">
<h3 class="anchored" data-anchor-id="hyperparameter-considerations" id="hyperparameter-considerations">Hyperparameter Considerations</h3>
<p>Training GShard models requires careful tuning of several hyperparameters specific to MoE architectures:</p>
<div id="tbl-hyperparameters" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-tbl figure">
<figcaption class="quarto-float-caption-top quarto-float-caption quarto-float-tbl" id="tbl-hyperparameters-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Table&nbsp;1: Key hyperparameters for GShard training
</figcaption>
<div aria-describedby="tbl-hyperparameters-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<table class="caption-top table">
<colgroup>
<col style="width: 28%">
<col style="width: 38%">
<col style="width: 33%">
</colgroup>
<thead>
<tr class="header">
<th>Parameter</th>
<th>Typical Range</th>
<th>Description</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>Number of Experts</td>
<td>8-2048</td>
<td>Affects model capacity and computational efficiency</td>
</tr>
<tr class="even">
<td>Capacity Factor</td>
<td>1.0-2.0</td>
<td>Determines tokens per expert</td>
</tr>
<tr class="odd">
<td>Auxiliary Loss Weight</td>
<td>0.01-0.1</td>
<td>Balances task performance and expert utilization</td>
</tr>
</tbody>
</table>
</div>
</figure>
</div>
<p><strong>Number of Experts</strong>: The number of experts represents a fundamental design choice that affects both model capacity and computational efficiency. More experts provide greater capacity but require more sophisticated load balancing.</p>
<p><strong>Capacity Factor</strong>: This parameter determines how many tokens each expert can process and directly impacts both computational cost and model expressiveness. Typical values range from 1.0 to 2.0, with higher values allowing more flexible routing.</p>
<p><strong>Auxiliary Loss Weight</strong>: The weighting of auxiliary loss functions affects the balance between task performance and expert utilization. This parameter requires careful tuning to achieve optimal results.</p>
</section>
</section>
<section id="applications-and-use-cases" class="level2">
<h2 class="anchored" data-anchor-id="applications-and-use-cases" id="applications-and-use-cases">Applications and Use Cases</h2>
<section id="machine-translation" class="level3">
<h3 class="anchored" data-anchor-id="machine-translation" id="machine-translation">Machine Translation</h3>
<p>GShard has demonstrated particular success in machine translation applications, where the model’s ability to scale to massive parameter counts has enabled breakthrough performance on challenging translation tasks.</p>
<p><strong>Multilingual Translation</strong>: GShard’s expert architecture naturally lends itself to multilingual translation, where different experts can specialize in different language pairs or linguistic phenomena. This specialization enables more efficient processing of diverse linguistic inputs.</p>
<p><strong>Low-Resource Languages</strong>: The increased model capacity provided by GShard has proven particularly beneficial for low-resource language translation, where the additional parameters can compensate for limited training data.</p>
<p><strong>Domain Adaptation</strong>: Different experts can specialize in different domains, allowing GShard models to handle diverse translation contexts more effectively than dense models.</p>
</section>
<section id="language-modeling" class="level3">
<h3 class="anchored" data-anchor-id="language-modeling" id="language-modeling">Language Modeling</h3>
<p>GShard has also shown impressive results in language modeling tasks, where the model’s ability to scale efficiently has enabled training of extremely large language models.</p>
<p><strong>Text Generation</strong>: The sparse activation patterns in GShard models appear to promote more diverse and coherent text generation, as different experts can specialize in different aspects of language generation.</p>
<p><strong>Few-Shot Learning</strong>: The increased model capacity provided by GShard has improved few-shot learning performance, enabling better adaptation to new tasks with minimal examples.</p>
<p><strong>Reasoning Tasks</strong>: GShard models have demonstrated improved performance on reasoning tasks that require complex logical operations, suggesting that the expert specialization enables more sophisticated reasoning capabilities.</p>
</section>
</section>
<section id="comparison-with-other-approaches" class="level2">
<h2 class="anchored" data-anchor-id="comparison-with-other-approaches" id="comparison-with-other-approaches">Comparison with Other Approaches</h2>
<section id="dense-models" class="level3">
<h3 class="anchored" data-anchor-id="dense-models" id="dense-models">Dense Models</h3>
<p>Compared to traditional dense models, GShard involves the following trade-offs:</p>
<div class="tabset-margin-container"></div><div class="panel-tabset">
<ul class="nav nav-tabs" role="tablist"><li class="nav-item" role="presentation"><a class="nav-link active" id="tabset-1-1-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-1" role="tab" aria-controls="tabset-1-1" aria-selected="true" href="">Advantages</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-1-2-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-2" role="tab" aria-controls="tabset-1-2" aria-selected="false" href="">Disadvantages</a></li></ul>
<div class="tab-content">
<div id="tabset-1-1" class="tab-pane active" role="tabpanel" aria-labelledby="tabset-1-1-tab">
<ul>
<li><strong>Computational Efficiency</strong>: Better performance per FLOP than dense models</li>
<li><strong>Scalability</strong>: Sub-linear scaling of computational cost with model size</li>
<li><strong>Specialization</strong>: Natural expert specialization improves performance on diverse tasks</li>
</ul>
</div>
<div id="tabset-1-2" class="tab-pane" role="tabpanel" aria-labelledby="tabset-1-2-tab">
<ul>
<li><strong>Simplicity</strong>: Dense models are conceptually simpler and easier to implement</li>
<li><strong>Hardware Optimization</strong>: Existing optimizations are designed for dense computations</li>
<li><strong>Predictable Performance</strong>: Dense models have more predictable requirements</li>
</ul>
</div>
</div>
</div>
</section>
<section id="other-sparse-approaches" class="level3">
<h3 class="anchored" data-anchor-id="other-sparse-approaches" id="other-sparse-approaches">Other Sparse Approaches</h3>
<p>GShard represents one approach to sparse neural networks, but several alternative methods exist:</p>
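<p><strong>Magnitude-Based Pruning</strong>: Traditional pruning removes the smallest-magnitude weights, producing a fixed sparsity pattern that every input shares. GShard’s sparsity is conditional instead: all parameters are kept, and each token activates a different subset of experts. Pruning typically costs accuracy as sparsity increases, whereas conditional computation adds capacity without a proportional increase in per-token compute.</p>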
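<p>Unstructured magnitude pruning can be sketched in a few lines of NumPy. <code>magnitude_prune</code> is a hypothetical illustration, not a library API. Note that the mask depends only on the weights, so every input sees the same pruned matrix, whereas an MoE gate selects a different set of experts for each token.</p>

```python
import numpy as np


def magnitude_prune(weights, sparsity):
    """Zero out the `sparsity` fraction of weights with smallest magnitude.

    The mask is computed once from the weights alone, so it is static:
    every input is multiplied by the same pruned matrix.
    """
    k = int(round(sparsity * weights.size))
    if k == 0:
        return weights.copy()
    # k-th smallest absolute value over the flattened matrix.
    threshold = np.sort(np.abs(weights), axis=None)[k - 1]
    return np.where(np.abs(weights) > threshold, weights, 0.0)


rng = np.random.default_rng(0)
w = rng.normal(size=(16, 16))
pruned = magnitude_prune(w, 0.5)
print("fraction zeroed:", (pruned == 0).mean())
```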
<p><strong>Structured Sparsity</strong>: Other approaches enforce structured sparsity patterns that are more hardware-friendly but may be less flexible than GShard’s learned sparsity.</p>
<p><strong>Dynamic Sparsity</strong>: Some approaches learn to dynamically adjust sparsity patterns during training, offering different trade-offs between flexibility and efficiency.</p>
</section>
</section>
<section id="limitations-and-challenges" class="level2">
<h2 class="anchored" data-anchor-id="limitations-and-challenges" id="limitations-and-challenges">Limitations and Challenges</h2>
<section id="technical-limitations" class="level3">
<h3 class="anchored" data-anchor-id="technical-limitations" id="technical-limitations">Technical Limitations</h3>
<p>Despite its advantages, GShard faces several technical limitations:</p>
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Warning</span>Key Limitations
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li><strong>Communication Overhead</strong>: All-to-all communication can become a bottleneck</li>
<li><strong>Load Balancing Complexity</strong>: Requires sophisticated auxiliary loss functions</li>
<li><strong>Hardware Utilization</strong>: Irregular computation patterns may lead to suboptimal hardware use</li>
<li><strong>Debugging Complexity</strong>: Sparse activation patterns make analysis challenging</li>
</ul>
</div>
</div>
</section>
<section id="scaling-challenges" class="level3">
<h3 class="anchored" data-anchor-id="scaling-challenges" id="scaling-challenges">Scaling Challenges</h3>
<p>As GShard models scale to larger sizes, several challenges emerge:</p>
<p><strong>Expert Utilization</strong>: Ensuring efficient utilization of all experts becomes increasingly difficult as the number of experts grows.</p>
<p><strong>Communication Scaling</strong>: The communication requirements for expert routing may not scale favorably with very large numbers of experts.</p>
<p><strong>Memory Constraints</strong>: Although GShard spreads expert weights across devices, very large models still face memory limitations, especially during training, when optimizer state and activations must be stored alongside the parameters.</p>
</section>
</section>
<section id="future-directions" class="level2">
<h2 class="anchored" data-anchor-id="future-directions" id="future-directions">Future Directions</h2>
<section id="architectural-innovations" class="level3">
<h3 class="anchored" data-anchor-id="architectural-innovations" id="architectural-innovations">Architectural Innovations</h3>
<p>Several promising directions for future development include:</p>
<p><strong>Hierarchical Experts</strong>: Organizing experts in hierarchical structures could improve routing efficiency and enable more sophisticated specialization patterns.</p>
<p><strong>Dynamic Expert Creation</strong>: Allowing the model to dynamically create new experts during training could improve adaptability to new tasks and domains.</p>
<p><strong>Cross-Layer Expert Sharing</strong>: Sharing experts across different layers could reduce parameter counts while maintaining model expressiveness.</p>
</section>
<section id="optimization-improvements" class="level3">
<h3 class="anchored" data-anchor-id="optimization-improvements" id="optimization-improvements">Optimization Improvements</h3>
<p>Future work could focus on improving the optimization of GShard models:</p>
<p><strong>Better Load Balancing</strong>: Developing more sophisticated load balancing techniques could improve expert utilization and model performance.</p>
<p><strong>Adaptive Routing</strong>: Learning to adaptively adjust routing strategies based on input characteristics could improve efficiency.</p>
<p><strong>Hardware-Aware Design</strong>: Designing MoE architectures that are more compatible with existing hardware could improve practical deployment.</p>
</section>
<section id="applications-and-domains" class="level3">
<h3 class="anchored" data-anchor-id="applications-and-domains" id="applications-and-domains">Applications and Domains</h3>
<p>GShard’s approach could be extended to new domains and applications:</p>
<p><strong>Computer Vision</strong>: Adapting MoE architectures for computer vision tasks could enable more efficient processing of visual data.</p>
<p><strong>Multimodal Learning</strong>: Combining GShard with multimodal architectures could enable more efficient processing of diverse input types.</p>
<p><strong>Reinforcement Learning</strong>: Applying MoE principles to reinforcement learning could enable more efficient learning in complex environments.</p>
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>GShard represents a significant breakthrough in neural network scaling, demonstrating that it’s possible to train models with hundreds of billions of parameters (600 billion in its largest translation model) while maintaining computational efficiency. The combination of sparsely-gated mixture-of-experts with sophisticated parallelization strategies has opened new possibilities for model scaling that were previously computationally prohibitive.</p>
<p>The success of GShard has fundamental implications for the future of deep learning. It suggests that the path to more capable AI systems may lie not just in scaling model size, but in developing more efficient architectures that can leverage massive parameter counts through conditional computation.</p>
<p>While GShard faces certain limitations and challenges, its core innovations have established a new paradigm for neural network architecture design. The principles underlying GShard—sparse activation, expert specialization, and efficient parallelization—are likely to influence future developments in large-scale machine learning.</p>
<p>As the field continues to evolve, GShard’s contributions to our understanding of scalable neural architectures are likely to keep shaping the development of increasingly capable and efficient AI systems. Its demonstration that models with hundreds of billions of parameters can be both practical and effective has changed our perspective on what’s possible in neural network scaling.</p>
<p>The ongoing research building upon GShard’s foundations promises to unlock even greater capabilities in artificial intelligence, potentially leading to systems that can process and understand information at unprecedented scales while remaining computationally efficient. This balance between scale and efficiency represents a crucial step toward more practical and deployable AI systems that can benefit a broader range of applications and users.</p>



</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Mixture of Experts: A Deep Overview]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/model-training/mixture-of-experts/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/model-training/mixture-of-experts/</guid>
      <pubDate>Tue, 15 Jul 2025 00:00:00 GMT</pubDate>
      
      <category>research</category>
      <category>advanced</category>
      <content:encoded><![CDATA[






<section id="mixture-of-experts-a-deep-overview" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/model-training/mixture-of-experts/moe.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>Mixture of Experts (MoE) represents a fundamental paradigm shift in machine learning architecture design, offering a scalable approach to building models that can handle complex, heterogeneous tasks while maintaining computational efficiency. This architectural pattern has gained significant traction in recent years, particularly in the realm of large language models and neural networks, where the ability to scale model capacity without proportionally increasing computational costs has become paramount.</p>
<p>The core insight behind MoE lies in the principle of specialization: rather than training a single monolithic model to handle all aspects of a task, we can train multiple specialized “expert” models, each focusing on different aspects or subdomains of the problem space. A gating mechanism then learns to route inputs to the most appropriate experts, creating a system that can be both highly specialized and broadly capable.</p>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Key Insight
</div>
</div>
<div class="callout-body-container callout-body">
<p>The fundamental principle of MoE is <strong>specialization</strong>: multiple expert models focus on different aspects of a problem, coordinated by a learned gating mechanism.</p>
</div>
</div>
</section>
<section id="historical-context-and-evolution" class="level2">
<h2 class="anchored" data-anchor-id="historical-context-and-evolution" id="historical-context-and-evolution">Historical Context and Evolution</h2>
<p>The concept of mixture models has deep roots in statistics and machine learning, dating back to Karl Pearson’s 1894 work on mixtures of normal distributions. However, the specific formulation of Mixture of Experts as we understand it today emerged in the early 1990s through the pioneering work of Robert Jacobs, Michael Jordan, Steven Nowlan, and Geoffrey Hinton, whose 1991 paper “Adaptive Mixtures of Local Experts” introduced the architecture.</p>
<p>The original MoE framework was motivated by the observation that many learning problems naturally decompose into subproblems that might be better solved by different models. For instance, in a classification task involving multiple classes, different regions of the input space might benefit from different decision boundaries or feature representations. This led to the development of the classical MoE architecture, which combined multiple expert networks with a gating network that learned to weight their contributions.</p>
<section id="modern-resurgence" class="level3">
<h3 class="anchored" data-anchor-id="modern-resurgence" id="modern-resurgence">Modern Resurgence</h3>
<p>The resurgence of interest in MoE architectures in recent years can be attributed to several factors:</p>
<ul>
<li><strong>Model scaling challenges</strong>: The explosion in model sizes, particularly in NLP</li>
<li><strong>Computational efficiency</strong>: Need for sublinear scaling methods</li>
<li><strong>Hardware improvements</strong>: Better support for sparse computation</li>
<li><strong>Theoretical advances</strong>: Better understanding of training dynamics</li>
</ul>
</section>
</section>
<section id="fundamental-architecture" class="level2">
<h2 class="anchored" data-anchor-id="fundamental-architecture" id="fundamental-architecture">Fundamental Architecture</h2>
<section id="core-components" class="level3">
<h3 class="anchored" data-anchor-id="core-components" id="core-components">Core Components</h3>
<p>The MoE architecture consists of three fundamental components that work in concert to create a flexible and efficient learning system.</p>
<div class="tabset-margin-container"></div><div class="panel-tabset">
<ul class="nav nav-tabs" role="tablist"><li class="nav-item" role="presentation"><a class="nav-link active" id="tabset-1-1-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-1" role="tab" aria-controls="tabset-1-1" aria-selected="true" href="">Expert Networks</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-1-2-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-2" role="tab" aria-controls="tabset-1-2" aria-selected="false" href="">Gating Network</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-1-3-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-3" role="tab" aria-controls="tabset-1-3" aria-selected="false" href="">Combination Mechanism</a></li></ul>
<div class="tab-content">
<div id="tabset-1-1" class="tab-pane active" role="tabpanel" aria-labelledby="tabset-1-1-tab">
<p><strong>Expert Networks</strong> form the foundation of the MoE system. These are typically neural networks, though they can be any differentiable function approximator. Each expert is designed to become specialized in handling specific types of inputs or solving particular aspects of the overall task.</p>
<p>Key characteristics:</p>
<ul>
<li>Can be identical in architecture but differ in parameters</li>
<li>May have fundamentally different architectures</li>
<li>Optimize for different input patterns or computational requirements</li>
</ul>
</div>
<div id="tabset-1-2" class="tab-pane" role="tabpanel" aria-labelledby="tabset-1-2-tab">
<p><strong>Gating Network</strong> serves as the routing mechanism that determines which experts should be activated for a given input. This network learns to predict the probability distribution over experts, effectively learning which expert or combination of experts is most likely to produce the best output.</p>
<p>Objectives:</p>
<ul>
<li>Route inputs to appropriate experts</li>
<li>Balance computational load across experts</li>
<li>Maintain end-to-end trainability</li>
</ul>
</div>
<div id="tabset-1-3" class="tab-pane" role="tabpanel" aria-labelledby="tabset-1-3-tab">
<p><strong>Combination Mechanism</strong> determines how outputs from multiple experts are combined to produce the final prediction. The most common approach is a weighted combination, where the gating network’s output serves as the weights.</p>
<p>Approaches:</p>
<ul>
<li>Weighted combination (most common)</li>
<li>Attention-based mechanisms</li>
<li>Learned combination functions</li>
</ul>
</div>
</div>
</div>
</section>
<section id="mathematical-formulation" class="level3">
<h3 class="anchored" data-anchor-id="mathematical-formulation" id="mathematical-formulation">Mathematical Formulation</h3>
<p>The mathematical foundation of MoE can be expressed elegantly through probabilistic modeling. Given an input vector <span class="math inline">\(\mathbf{x}\)</span>, the MoE model computes its output as:</p>
<p><span class="math display">\[\mathbf{y} = \sum_{i=1}^{N} g_i(\mathbf{x}) \cdot E_i(\mathbf{x})\]</span></p>
<p>where:</p>
<ul>
<li><span class="math inline">\(N\)</span> is the number of experts</li>
<li><span class="math inline">\(g_i(\mathbf{x})\)</span> is the gating function’s output (the weight) for expert <span class="math inline">\(i\)</span></li>
<li><span class="math inline">\(E_i(\mathbf{x})\)</span> is the output of expert <span class="math inline">\(i\)</span></li>
</ul>
<p>The gating function typically uses a softmax activation:</p>
<p><span class="math display">\[g_i(\mathbf{x}) = \frac{\exp(\mathbf{W}_g \mathbf{x} + \mathbf{b}_g)_i}{\sum_{j=1}^{N} \exp(\mathbf{W}_g \mathbf{x} + \mathbf{b}_g)_j}\]</span></p>
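<p>The two formulas can be evaluated directly with dense (“soft”) gating: run every expert and weight its output by the softmax gate. A minimal NumPy sketch, assuming each expert is any callable that maps a batch of inputs to outputs and that <code>W_g</code> has shape [N, D]; unlike the top-k PyTorch implementation that follows, it skips no experts, so its cost grows linearly with N.</p>

```python
import numpy as np


def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def dense_moe(x, W_g, b_g, experts):
    """y = sum_i g_i(x) * E_i(x), evaluating every expert.

    x:       [B, D] batch of inputs
    W_g:     [N, D] gating weights; b_g: [N] gating bias
    experts: list of N callables mapping [B, D] -> [B, O]
    """
    g = softmax(x @ W_g.T + b_g)                      # [B, N] gate weights
    outs = np.stack([E(x) for E in experts], axis=1)  # [B, N, O] expert outputs
    return np.einsum("bn,bno->bo", g, outs)           # weighted sum over experts
```

<p>Two sanity checks follow from the formulas: with zero gating weights the gates are uniform and the output is the plain average of the experts, and if every expert is identical the output equals that expert’s output, because the gates sum to one.</p>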
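<p>The training objective typically combines several components:</p>
<p><span class="math display">\[\mathcal{L} = \mathcal{L}_{\text{prediction}} + \lambda \mathcal{L}_{\text{load balancing}} + \mu \mathcal{L}_{\text{expert regularization}}\]</span></p>
<p>where <span class="math inline">\(\lambda\)</span> and <span class="math inline">\(\mu\)</span> are small coefficients that weight the auxiliary load-balancing and expert-regularization terms against the main prediction loss.</p>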
<div id="b14423cf" class="cell" data-execution_count="1">
<details class="code-fold">
<summary>Example MoE Implementation</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb1"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn.functional <span class="im">as</span> F</span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MixtureOfExperts(nn.Module):</span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, input_dim, hidden_dim, output_dim, num_experts, top_k<span class="op">=</span><span class="dv">2</span>):</span>
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.num_experts <span class="op">=</span> num_experts</span>
<span id="cb1-9"><a href="#cb1-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.top_k <span class="op">=</span> top_k</span>
<span id="cb1-10"><a href="#cb1-10" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-11"><a href="#cb1-11" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Gating network</span></span>
<span id="cb1-12"><a href="#cb1-12" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.gate <span class="op">=</span> nn.Linear(input_dim, num_experts)</span>
<span id="cb1-13"><a href="#cb1-13" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-14"><a href="#cb1-14" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Expert networks</span></span>
<span id="cb1-15"><a href="#cb1-15" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.experts <span class="op">=</span> nn.ModuleList([</span>
<span id="cb1-16"><a href="#cb1-16" aria-hidden="true" tabindex="-1"></a>            nn.Sequential(</span>
<span id="cb1-17"><a href="#cb1-17" aria-hidden="true" tabindex="-1"></a>                nn.Linear(input_dim, hidden_dim),</span>
<span id="cb1-18"><a href="#cb1-18" aria-hidden="true" tabindex="-1"></a>                nn.ReLU(),</span>
<span id="cb1-19"><a href="#cb1-19" aria-hidden="true" tabindex="-1"></a>                nn.Linear(hidden_dim, output_dim)</span>
<span id="cb1-20"><a href="#cb1-20" aria-hidden="true" tabindex="-1"></a>            ) <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(num_experts)</span>
<span id="cb1-21"><a href="#cb1-21" aria-hidden="true" tabindex="-1"></a>        ])</span>
<span id="cb1-22"><a href="#cb1-22" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb1-23"><a href="#cb1-23" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb1-24"><a href="#cb1-24" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Gating scores</span></span>
<span id="cb1-25"><a href="#cb1-25" aria-hidden="true" tabindex="-1"></a>        gate_scores <span class="op">=</span> <span class="va">self</span>.gate(x)</span>
<span id="cb1-26"><a href="#cb1-26" aria-hidden="true" tabindex="-1"></a>        gate_probs <span class="op">=</span> F.softmax(gate_scores, dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb1-27"><a href="#cb1-27" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-28"><a href="#cb1-28" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Select top-k experts</span></span>
<span id="cb1-29"><a href="#cb1-29" aria-hidden="true" tabindex="-1"></a>        top_k_probs, top_k_indices <span class="op">=</span> torch.topk(gate_probs, <span class="va">self</span>.top_k, dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb1-30"><a href="#cb1-30" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-31"><a href="#cb1-31" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Normalize top-k probabilities</span></span>
<span id="cb1-32"><a href="#cb1-32" aria-hidden="true" tabindex="-1"></a>        top_k_probs <span class="op">=</span> top_k_probs <span class="op">/</span> top_k_probs.<span class="bu">sum</span>(dim<span class="op">=-</span><span class="dv">1</span>, keepdim<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb1-33"><a href="#cb1-33" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-34"><a href="#cb1-34" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Compute every expert's output (dense, for clarity; only the top-k are kept below)</span></span>
<span id="cb1-35"><a href="#cb1-35" aria-hidden="true" tabindex="-1"></a>        expert_outputs <span class="op">=</span> []</span>
<span id="cb1-36"><a href="#cb1-36" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> expert <span class="kw">in</span> <span class="va">self</span>.experts:</span>
<span id="cb1-37"><a href="#cb1-37" aria-hidden="true" tabindex="-1"></a>            expert_outputs.append(expert(x))</span>
<span id="cb1-38"><a href="#cb1-38" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-39"><a href="#cb1-39" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Combine outputs</span></span>
<span id="cb1-40"><a href="#cb1-40" aria-hidden="true" tabindex="-1"></a>        output <span class="op">=</span> torch.zeros_like(expert_outputs[<span class="dv">0</span>])</span>
<span id="cb1-41"><a href="#cb1-41" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="va">self</span>.top_k):</span>
<span id="cb1-42"><a href="#cb1-42" aria-hidden="true" tabindex="-1"></a>            expert_idx <span class="op">=</span> top_k_indices[:, i]</span>
<span id="cb1-43"><a href="#cb1-43" aria-hidden="true" tabindex="-1"></a>            weight <span class="op">=</span> top_k_probs[:, i].unsqueeze(<span class="op">-</span><span class="dv">1</span>)</span>
<span id="cb1-44"><a href="#cb1-44" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> j, expert_output <span class="kw">in</span> <span class="bu">enumerate</span>(expert_outputs):</span>
<span id="cb1-45"><a href="#cb1-45" aria-hidden="true" tabindex="-1"></a>                mask <span class="op">=</span> (expert_idx <span class="op">==</span> j).<span class="bu">float</span>().unsqueeze(<span class="op">-</span><span class="dv">1</span>)</span>
<span id="cb1-46"><a href="#cb1-46" aria-hidden="true" tabindex="-1"></a>                output <span class="op">+=</span> weight <span class="op">*</span> mask <span class="op">*</span> expert_output</span>
<span id="cb1-47"><a href="#cb1-47" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-48"><a href="#cb1-48" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> output</span></code></pre></div></div>
</details>
</div>
</section>
</section>
<section id="training-dynamics-and-optimization" class="level2">
<h2 class="anchored" data-anchor-id="training-dynamics-and-optimization" id="training-dynamics-and-optimization">Training Dynamics and Optimization</h2>
<p>Training MoE systems presents unique challenges that distinguish it from traditional neural network training. The primary challenge lies in the discrete nature of expert selection combined with the need for end-to-end differentiable training.</p>
<section id="gradient-flow-and-backpropagation" class="level3">
<h3 class="anchored" data-anchor-id="gradient-flow-and-backpropagation" id="gradient-flow-and-backpropagation">Gradient Flow and Backpropagation</h3>
<p>The gating mechanism creates a complex gradient flow pattern. When the gating network routes an input primarily to a subset of experts, gradients flow mainly through those active experts. Rarely selected experts therefore receive few training examples and may underfit, while frequently selected experts become overutilized, and this imbalance can destabilize training.</p>
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Warning</span>Training Challenge
</div>
</div>
<div class="callout-body-container callout-body">
<p>The soft gating approach helps mitigate gradient flow issues but increases computational overhead as multiple experts must be evaluated for each input.</p>
</div>
</div>
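<p>To make the gradient-flow point concrete, the sketch below runs one forward and backward pass through a toy top-1 router, assuming PyTorch; the router, the four linear "experts", and all sizes are hypothetical choices for illustration. Only experts that received at least one token end up with gradients, so an expert that is never selected gets no update in that step.</p>

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
d_model, num_experts = 8, 4

# Hypothetical toy setup: a linear router and four linear "experts"
router = nn.Linear(d_model, num_experts)
experts = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(num_experts)])

x = torch.randn(16, d_model)
probs = F.softmax(router(x), dim=-1)
gate, idx = probs.max(dim=-1)  # top-1 routing: one expert per token

y = torch.zeros_like(x)
for i, expert in enumerate(experts):
    mask = idx == i
    if mask.any():
        # Only the tokens routed to expert i pass through it
        y[mask] = gate[mask].unsqueeze(-1) * expert(x[mask])

y.sum().backward()

for i, expert in enumerate(experts):
    n_tokens = int((idx == i).sum())
    has_grad = expert.weight.grad is not None
    print(f"expert {i}: {n_tokens} tokens routed, has gradient: {has_grad}")
```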
</section>
<section id="load-balancing-and-expert-utilization" class="level3">
<h3 class="anchored" data-anchor-id="load-balancing-and-expert-utilization" id="load-balancing-and-expert-utilization">Load Balancing and Expert Utilization</h3>
<p>One of the most critical challenges in MoE training is ensuring balanced utilization of experts. Without proper load balancing, the router may collapse to using only a few experts, effectively reducing the model to a much smaller one.</p>
<p><strong>Solutions for load balancing:</strong></p>
<ol type="1">
<li><strong>Auxiliary losses</strong> that penalize uneven expert utilization</li>
<li><strong>Noise injection</strong> in the gating network to encourage exploration</li>
<li><strong>Curriculum learning</strong> approaches for gradual expert specialization</li>
</ol>
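<p>Items 1 and 2 can be sketched together. The code below is a minimal sketch, assuming PyTorch, of noisy top-k gating in the style of Shazeer et al. (2017), where Gaussian noise with a learned, input-dependent scale is added to the gate logits during training, combined with an "importance" auxiliary loss equal to the squared coefficient of variation of the total gate weight each expert receives. The class and function names and the 0.01 loss weight are illustrative assumptions, not values from this article.</p>

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class NoisyTopKGate(nn.Module):
    """Noisy top-k gate in the style of Shazeer et al. (2017); hypothetical name, illustrative sizes."""

    def __init__(self, input_dim, num_experts, top_k=2):
        super().__init__()
        self.top_k = top_k
        self.w_gate = nn.Linear(input_dim, num_experts, bias=False)
        self.w_noise = nn.Linear(input_dim, num_experts, bias=False)

    def forward(self, x):
        logits = self.w_gate(x)
        if self.training:
            # Input-dependent Gaussian noise (scale kept positive by softplus) encourages exploration
            noise_std = F.softplus(self.w_noise(x))
            logits = logits + torch.randn_like(logits) * noise_std
        # Keep the top-k logits; all other experts get probability exactly 0
        top_vals, top_idx = logits.topk(self.top_k, dim=-1)
        masked = torch.full_like(logits, float("-inf")).scatter(-1, top_idx, top_vals)
        return F.softmax(masked, dim=-1)  # (batch, num_experts), k non-zeros per row


def importance_loss(gates, weight=0.01):
    """Auxiliary loss: weight * CV(importance)^2, where importance = total gate mass per expert."""
    importance = gates.sum(dim=0)
    cv_squared = importance.var() / (importance.mean() ** 2 + 1e-10)
    return weight * cv_squared


torch.manual_seed(0)
gate = NoisyTopKGate(input_dim=16, num_experts=8, top_k=2)
gates = gate(torch.randn(32, 16))
loss = importance_loss(gates)  # add to the task loss before calling backward()
print(gates.shape, float(loss))
```

<p>The loss is zero only when every expert receives the same total gate weight, so adding it to the task loss pushes the router away from collapsing onto a few experts.</p>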
</section>
<section id="sparsity-and-efficiency" class="level3">
<h3 class="anchored" data-anchor-id="sparsity-and-efficiency" id="sparsity-and-efficiency">Sparsity and Efficiency</h3>
<p>A key advantage of MoE systems is their ability to maintain sparsity during inference. By activating only a subset of experts for each input, computational cost can be kept relatively low even as the total number of parameters increases.</p>
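<p>The efficiency argument depends on dispatching each token only to its selected experts, rather than evaluating every expert and masking the results. A minimal sketch of such sparse dispatch, assuming PyTorch, a linear gate, and experts that map <code>d</code> features to <code>d</code> features (the function name is hypothetical):</p>

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def sparse_moe_forward(x, gate, experts, top_k=2):
    """Hypothetical top-k MoE forward pass that runs each expert only on its routed tokens.

    x: (num_tokens, d); gate: nn.Linear(d, num_experts); experts: modules mapping d to d.
    Returns the combined output and the number of token-expert evaluations performed.
    """
    probs = F.softmax(gate(x), dim=-1)
    top_p, top_idx = probs.topk(top_k, dim=-1)
    top_p = top_p / top_p.sum(dim=-1, keepdim=True)  # renormalize over the chosen experts

    out = torch.zeros_like(x)
    evaluations = 0
    for e, expert in enumerate(experts):
        # (token, slot) pairs where expert e is among the token's top-k choices
        token_ids, slot = (top_idx == e).nonzero(as_tuple=True)
        if token_ids.numel() == 0:
            continue  # expert e does no work for this batch
        evaluations += token_ids.numel()
        weighted = top_p[token_ids, slot].unsqueeze(-1) * expert(x[token_ids])
        out.index_add_(0, token_ids, weighted)
    return out, evaluations


torch.manual_seed(0)
d, num_experts = 16, 8
gate = nn.Linear(d, num_experts)
experts = [nn.Linear(d, d) for _ in range(num_experts)]
x = torch.randn(64, d)
out, evaluations = sparse_moe_forward(x, gate, experts, top_k=2)
print(evaluations)  # 128 = 64 tokens x 2 experts, versus 512 for dense evaluation
```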
<p>The choice of <span class="math inline">\(k\)</span> in top-<span class="math inline">\(k\)</span> gating represents a fundamental trade-off:</p>
<table class="caption-top table">
<thead>
<tr class="header">
<th>Small <span class="math inline">\(k\)</span></th>
<th>Large <span class="math inline">\(k\)</span></th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>More efficient inference</td>
<td>Higher computational cost</td>
</tr>
<tr class="even">
<td>Limited expressiveness</td>
<td>Greater model capacity</td>
</tr>
<tr class="odd">
<td>Faster training</td>
<td>More complex optimization</td>
</tr>
</tbody>
</table>
</section>
</section>
<section id="applications-and-use-cases" class="level2">
<h2 class="anchored" data-anchor-id="applications-and-use-cases" id="applications-and-use-cases">Applications and Use Cases</h2>
<section id="natural-language-processing" class="level3">
<h3 class="anchored" data-anchor-id="natural-language-processing" id="natural-language-processing">Natural Language Processing</h3>
<p>MoE has found particularly strong application in natural language processing, where the heterogeneous nature of language tasks makes expert specialization highly beneficial. Large language models such as GShard, Switch Transformer, and GLaM use sparse MoE layers to scale to hundreds of billions or even trillions of parameters while keeping the computational cost per token reasonable.</p>
<p><strong>Expert specialization in NLP:</strong></p>
<ul>
<li>Syntactic constructions</li>
<li>Numerical information processing</li>
<li>Domain-specific terminology</li>
<li>Language-specific patterns (in multilingual models)</li>
</ul>
</section>
<section id="computer-vision" class="level3">
<h3 class="anchored" data-anchor-id="computer-vision" id="computer-vision">Computer Vision</h3>
<p>In computer vision, MoE architectures have been applied to tasks ranging from image classification to object detection and segmentation. The visual domain’s inherent structure makes it well-suited for expert specialization.</p>
<p><strong>Applications in vision:</strong></p>
<ul>
<li>Object detection with size/category-specific experts</li>
<li>Image segmentation with boundary/texture specialists</li>
<li>Vision transformers whose MLP blocks are replaced by sparse expert layers (e.g., V-MoE)</li>
</ul>
</section>
<section id="multimodal-learning" class="level3">
<h3 class="anchored" data-anchor-id="multimodal-learning" id="multimodal-learning">Multimodal Learning</h3>
<p>MoE architectures are particularly well-suited for multimodal learning tasks, where inputs might come from different modalities (text, images, audio, etc.). Different experts can specialize in processing different modalities or in handling the fusion of information across modalities.</p>
</section>
</section>
<section id="advanced-techniques-and-variants" class="level2">
<h2 class="anchored" data-anchor-id="advanced-techniques-and-variants" id="advanced-techniques-and-variants">Advanced Techniques and Variants</h2>
<section id="hierarchical-mixture-of-experts" class="level3">
<h3 class="anchored" data-anchor-id="hierarchical-mixture-of-experts" id="hierarchical-mixture-of-experts">Hierarchical Mixture of Experts</h3>
<p>Hierarchical MoE extends the basic MoE concept by organizing experts in a tree-like structure. This approach allows for more efficient routing and can capture hierarchical patterns in the data.</p>
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">graph LR
    A[Input] --&gt; B[Level 1 Gate]
    B --&gt; C[Expert Cluster 1]
    B --&gt; D[Expert Cluster 2]
    B --&gt; E[Expert Cluster 3]
    C --&gt; F[Expert 1.1]
    C --&gt; G[Expert 1.2]
    D --&gt; H[Expert 2.1]
    D --&gt; I[Expert 2.2]
    E --&gt; J[Expert 3.1]
    E --&gt; K[Expert 3.2]
</pre>
</div>
<p></p></figure><p></p>
</div>
</div>
</div>
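<p>A soft version of the two-level routing in the diagram can be sketched as follows, assuming PyTorch and linear experts; the class name and sizes are illustrative. The level-1 gate produces P(cluster | x), each cluster's own gate produces P(expert | x, cluster), and each expert's output is weighted by the product of the two. A sparse variant would instead keep only the top cluster and the top expert within it.</p>

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class HierarchicalMoE(nn.Module):
    """Hypothetical soft two-level MoE: 3 clusters x 2 experts, mirroring the diagram."""

    def __init__(self, dim, num_clusters=3, experts_per_cluster=2):
        super().__init__()
        self.top_gate = nn.Linear(dim, num_clusters)  # "Level 1 Gate"
        self.cluster_gates = nn.ModuleList(
            [nn.Linear(dim, experts_per_cluster) for _ in range(num_clusters)]
        )
        self.experts = nn.ModuleList(
            [nn.ModuleList([nn.Linear(dim, dim) for _ in range(experts_per_cluster)])
             for _ in range(num_clusters)]
        )

    def forward(self, x):
        p_cluster = F.softmax(self.top_gate(x), dim=-1)  # P(cluster | x)
        out = torch.zeros_like(x)
        for c, (gate, cluster) in enumerate(zip(self.cluster_gates, self.experts)):
            p_expert = F.softmax(gate(x), dim=-1)  # P(expert | x, cluster)
            for j, expert in enumerate(cluster):
                # Leaf weight = P(cluster | x) * P(expert | x, cluster); all leaf weights sum to 1
                weight = p_cluster[:, c:c + 1] * p_expert[:, j:j + 1]
                out = out + weight * expert(x)
        return out


torch.manual_seed(0)
model = HierarchicalMoE(dim=16)
y = model(torch.randn(4, 16))
print(y.shape)  # torch.Size([4, 16])
```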
</section>
<section id="sparse-mixture-of-experts" class="level3">
<h3 class="anchored" data-anchor-id="sparse-mixture-of-experts" id="sparse-mixture-of-experts">Sparse Mixture of Experts</h3>
<p>Sparse MoE focuses on maximizing the efficiency benefits of expert sparsity. These systems typically activate only a very small fraction of available experts for each input.</p>
<p><strong>Example: Switch Transformer</strong></p>
<ul>
<li>Activates only one expert per input</li>
<li>Enables very efficient scaling</li>
<li>Requires careful design for single-expert effectiveness</li>
</ul>
</section>
<section id="adaptive-mixture-of-experts" class="level3">
<h3 class="anchored" data-anchor-id="adaptive-mixture-of-experts" id="adaptive-mixture-of-experts">Adaptive Mixture of Experts</h3>
<p>Adaptive MoE systems dynamically adjust their architecture based on input or task requirements:</p>
<ul>
<li>Dynamic expert count adjustment</li>
<li>Architecture modification based on context</li>
<li>Computational resource adaptation</li>
</ul>
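<p>Dynamic expert count adjustment is described only at a high level above. One possible scheme, shown here as a hedged sketch assuming PyTorch (the function name and default threshold are hypothetical), activates the smallest set of experts whose combined gate probability reaches a threshold, so a confident input uses one expert and an ambiguous input uses several, up to a cap.</p>

```python
import torch
import torch.nn.functional as F

def threshold_routing(logits, p_threshold=0.6, max_k=4):
    """Hypothetical router: pick the fewest experts whose probabilities sum to at least p_threshold.

    Returns expert indices sorted by probability, their weights (zero for unselected
    experts), and the number of experts activated for each input.
    """
    probs = F.softmax(logits, dim=-1)
    sorted_p, sorted_idx = probs.sort(dim=-1, descending=True)
    mass_before = sorted_p.cumsum(dim=-1) - sorted_p
    # Keep an expert while the probability mass collected so far is below the threshold
    keep = p_threshold > mass_before
    keep[..., max_k:] = False  # never activate more than max_k experts
    weights = sorted_p * keep
    return sorted_idx, weights, keep.sum(dim=-1)


logits = torch.tensor([[10.0, 0.0, 0.0, 0.0],   # confident input
                       [0.0, 0.0, 0.0, 0.0]])   # maximally uncertain input
_, _, num_active = threshold_routing(logits)
print(num_active.tolist())  # [1, 3]
```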
</section>
</section>
<section id="challenges-and-limitations" class="level2">
<h2 class="anchored" data-anchor-id="challenges-and-limitations" id="challenges-and-limitations">Challenges and Limitations</h2>
<section id="training-stability" class="level3">
<h3 class="anchored" data-anchor-id="training-stability" id="training-stability">Training Stability</h3>
<p>Training MoE systems can be significantly more challenging than training traditional neural networks. The interaction between the gating network and expert networks creates a complex optimization landscape.</p>
<p><strong>Common issues:</strong></p>
<ul>
<li>Mode collapse (only a subset of experts is ever used)</li>
<li>Gradient flow problems</li>
<li>Training instabilities</li>
</ul>
</section>
<section id="computational-overhead" class="level3">
<h3 class="anchored" data-anchor-id="computational-overhead" id="computational-overhead">Computational Overhead</h3>
<p>While the per-token compute of an MoE system grows sublinearly with its total parameter count, its absolute computational cost is often still higher than that of a smaller dense model.</p>
<p><strong>Overhead sources:</strong></p>
<ul>
<li>Gating network computation</li>
<li>Multiple expert evaluation</li>
<li>Memory requirements for all expert parameters</li>
</ul>
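<p>The gap between stored parameters and per-token compute can be made concrete with a small calculation. The sketch below counts weights for one MoE feed-forward layer whose experts are standard two-matrix FFNs, with biases ignored; the function name and the layer sizes in the example are illustrative assumptions, not taken from a specific model.</p>

```python
def moe_ffn_params(d_model, d_ff, num_experts, top_k):
    """Hypothetical helper: weights stored vs. weights used per token in one MoE FFN layer.

    Assumes each expert has a d_model x d_ff and a d_ff x d_model matrix (no biases),
    plus a linear gate with d_model x num_experts weights.
    """
    per_expert = 2 * d_model * d_ff
    gate = d_model * num_experts
    total = num_experts * per_expert + gate  # must all be held in memory
    active = top_k * per_expert + gate       # actually used for each token
    return total, active


total, active = moe_ffn_params(d_model=1024, d_ff=4096, num_experts=64, top_k=2)
print(total, active)            # 536936448 16842752
print(f"{active / total:.1%}")  # 3.1%
```

<p>In this example each token touches only about 3% of the layer's weights, yet all of them must stay in memory, which is why memory for all expert parameters appears among the overhead sources.</p>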
</section>
<section id="expert-specialization-vs.-generalization" class="level3">
<h3 class="anchored" data-anchor-id="expert-specialization-vs.-generalization" id="expert-specialization-vs.-generalization">Expert Specialization vs.&nbsp;Generalization</h3>
<p>The balance between expert specialization and generalization represents a fundamental challenge in MoE design. This is particularly acute in dynamic environments where the input distribution may shift over time.</p>
</section>
</section>
<section id="recent-developments-and-state-of-the-art" class="level2">
<h2 class="anchored" data-anchor-id="recent-developments-and-state-of-the-art" id="recent-developments-and-state-of-the-art">Recent Developments and State-of-the-Art</h2>
<section id="large-scale-language-models" class="level3">
<h3 class="anchored" data-anchor-id="large-scale-language-models" id="large-scale-language-models">Large-Scale Language Models</h3>
<p>The most prominent recent application of MoE has been in large-scale language models:</p>
<ul>
<li><strong>GShard</strong>: a 600-billion-parameter MoE Transformer for multilingual translation</li>
<li><strong>GLaM</strong>: Generalist Language Model with efficient MoE</li>
<li><strong>Switch Transformer</strong>: routes each token to a single expert and scales to over a trillion parameters</li>
</ul>
</section>
<section id="efficient-training-methods" class="level3">
<h3 class="anchored" data-anchor-id="efficient-training-methods" id="efficient-training-methods">Efficient Training Methods</h3>
<p>Recent research has focused on developing more efficient training methods:</p>
<ul>
<li>Better load balancing techniques</li>
<li>More stable training procedures</li>
<li>Reduced gating mechanism overhead</li>
<li>Expert parallelism for distributed training</li>
</ul>
</section>
<section id="integration-with-other-techniques" class="level3">
<h3 class="anchored" data-anchor-id="integration-with-other-techniques" id="integration-with-other-techniques">Integration with Other Techniques</h3>
<p>MoE is increasingly being combined with other advanced techniques:</p>
<ul>
<li>Attention mechanisms</li>
<li>Normalization methods</li>
<li>Architectural innovations</li>
<li>Transformer architectures</li>
</ul>
</section>
</section>
<section id="future-directions-and-research-opportunities" class="level2">
<h2 class="anchored" data-anchor-id="future-directions-and-research-opportunities" id="future-directions-and-research-opportunities">Future Directions and Research Opportunities</h2>
<section id="automated-expert-design" class="level3">
<h3 class="anchored" data-anchor-id="automated-expert-design" id="automated-expert-design">Automated Expert Design</h3>
<p>Current MoE systems typically use manually designed expert architectures. Future research directions include:</p>
<ul>
<li>Neural architecture search for MoE</li>
<li>Task-specific expert design</li>
<li>Automated capacity allocation</li>
</ul>
</section>
<section id="dynamic-expert-creation" class="level3">
<h3 class="anchored" data-anchor-id="dynamic-expert-creation" id="dynamic-expert-creation">Dynamic Expert Creation</h3>
<p>Rather than having a fixed set of experts, future systems might:</p>
<ul>
<li>Dynamically create and remove experts</li>
<li>Adapt to evolving task requirements</li>
<li>Respond to changing data distributions</li>
</ul>
</section>
<section id="theoretical-understanding" class="level3">
<h3 class="anchored" data-anchor-id="theoretical-understanding" id="theoretical-understanding">Theoretical Understanding</h3>
<p>Despite practical success, theoretical understanding remains limited:</p>
<ul>
<li>When and why MoE systems work well</li>
<li>Optimal design principles</li>
<li>Convergence guarantees</li>
<li>Generalization bounds</li>
</ul>
</section>
<section id="hardware-co-design" class="level3">
<h3 class="anchored" data-anchor-id="hardware-co-design" id="hardware-co-design">Hardware Co-design</h3>
<p>The unique computational patterns of MoE systems suggest opportunities for specialized hardware:</p>
<ul>
<li>MoE-optimized processors</li>
<li>Efficient sparse computation</li>
<li>Memory hierarchy optimization</li>
<li>Distributed computing architectures</li>
</ul>
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>Mixture of Experts represents a powerful paradigm for building scalable and efficient machine learning systems. By leveraging the principle of specialization, MoE systems can achieve remarkable performance while maintaining computational efficiency.</p>
<p><strong>Key takeaways:</strong></p>
<ol type="1">
<li><strong>Scalability</strong>: MoE enables sublinear scaling of computational cost with model capacity</li>
<li><strong>Specialization</strong>: Expert networks can focus on specific aspects of complex tasks</li>
<li><strong>Efficiency</strong>: Sparse activation patterns reduce computational overhead</li>
<li><strong>Challenges</strong>: Training stability and load balancing remain significant hurdles</li>
<li><strong>Future potential</strong>: Continued innovation in architectures, training methods, and hardware</li>
</ol>
<p>The success of MoE in recent large-scale language models demonstrates its potential for enabling the next generation of AI systems. As our understanding deepens and techniques improve, MoE will likely play an increasingly important role in advanced AI system development across diverse domains.</p>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Looking Forward
</div>
</div>
<div class="callout-body-container callout-body">
<p>The combination of MoE with other advanced techniques and the development of specialized hardware will likely drive continued innovation in this space, making AI systems both more capable and more efficient.</p>
</div>
</div>
<hr>
<p><em>This document provides a comprehensive overview of Mixture of Experts architectures, from theoretical foundations to practical applications and future directions. For the latest developments in this rapidly evolving field, readers are encouraged to consult recent research publications and conference proceedings.</em></p>



</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Switch Transformer: Scaling Neural Networks with Sparsity]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/models/switch-transformer/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/models/switch-transformer/</guid>
      <pubDate>Tue, 15 Jul 2025 00:00:00 GMT</pubDate>
      
      <category>research</category>
      <category>advanced</category>
      <content:encoded><![CDATA[






<section id="switch-transformer-scaling-neural-networks-with-sparsity" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/models/switch-transformer/switch.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>The Switch Transformer, introduced by Google Research (Fedus, Zoph, and Shazeer) in 2021, addresses one of the most pressing challenges in deep learning: how to scale neural networks to very large sizes while keeping the computation per token manageable. By routing each token to a single expert, it decouples the parameter count from per-token compute; the authors report up to 7x faster pre-training than dense T5 baselines using the same computational resources.</p>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Key Insight
</div>
</div>
<div class="callout-body-container callout-body">
<p>Not all parts of a neural network need to be active for every input. Switch Transformer employs a sparse approach where only a subset of the model’s parameters are activated for each token.</p>
</div>
</div>
<p>Instead of running dense computation across the entire network, Switch Transformer activates only one expert feed-forward block per token in each Switch layer. Per-token compute therefore stays roughly constant as experts are added, which is what allows the model to scale to trillions of parameters.</p>
</section>
<section id="background-and-motivation" class="level2">
<h2 class="anchored" data-anchor-id="background-and-motivation" id="background-and-motivation">Background and Motivation</h2>
<section id="the-scaling-challenge" class="level3">
<h3 class="anchored" data-anchor-id="the-scaling-challenge" id="the-scaling-challenge">The Scaling Challenge</h3>
<p>Traditional (dense) transformer models face a fundamental trade-off between model capacity and computational efficiency. Larger models generally perform better, but because every parameter is used for every token, compute per token grows in proportion to the parameter count, and total training cost grows further with the number of training tokens. For instance, GPT-3, with 175 billion parameters, requires enormous computational power for both training and inference, putting it within reach only of organizations with substantial resources.</p>
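<p>A widely used rule of thumb from scaling-law analyses (an outside approximation, not a figure from this article) estimates training compute for a dense transformer at about 6 FLOPs per parameter per token: roughly 2 for the forward pass and 4 for the backward pass. Under that approximation, compute per token is linear in parameter count, as this short calculation shows:</p>

```python
def train_flops_per_token(active_params):
    """Approximate training FLOPs per token: ~6 per parameter used for that token."""
    return 6 * active_params


gpt3_params = 175_000_000_000  # dense: every parameter is used for every token
print(f"{train_flops_per_token(gpt3_params):.2e}")  # 1.05e+12
print(train_flops_per_token(2 * gpt3_params) / train_flops_per_token(gpt3_params))  # 2.0
```

<p>Doubling a dense model's parameters doubles this per-token cost. Sparse expert routing aims to grow total parameters while keeping the parameters used per token, and hence this cost, roughly fixed.</p>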
</section>
<section id="mixture-of-experts-moe-foundation" class="level3">
<h3 class="anchored" data-anchor-id="mixture-of-experts-moe-foundation" id="mixture-of-experts-moe-foundation">Mixture of Experts (MoE) Foundation</h3>
<p>Switch Transformer builds upon the Mixture of Experts (MoE) paradigm, which has been explored in various forms since the 1990s. The core idea is to have multiple specialized “expert” networks, with a gating mechanism that determines which experts should process each input. This approach allows for increased model capacity without proportionally increasing computational cost.</p>
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Warning</span>Previous MoE Challenges
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li>Complex routing algorithms</li>
<li>Training instability</li>
<li>Load balancing issues</li>
<li>Difficulty in scaling to very large numbers of experts</li>
</ul>
</div>
</div>
<p>Switch Transformer addresses these limitations with a simpler routing scheme and several training techniques, described in the sections that follow.</p>
</section>
</section>
<section id="architecture-overview" class="level2">
<h2 class="anchored" data-anchor-id="architecture-overview" id="architecture-overview">Architecture Overview</h2>
<section id="core-components" class="level3">
<h3 class="anchored" data-anchor-id="core-components" id="core-components">Core Components</h3>
<p>The Switch Transformer architecture consists of several key components that work together to achieve efficient sparse computation:</p>
<section id="switch-layer" class="level4">
<h4 class="anchored" data-anchor-id="switch-layer">Switch Layer</h4>
<p>The fundamental building block of Switch Transformer is the Switch Layer, which replaces the traditional feed-forward network (FFN) in transformer blocks. Each Switch Layer contains multiple expert networks, typically implemented as separate FFN modules.</p>
</section>
<section id="switch-routing" class="level4">
<h4 class="anchored" data-anchor-id="switch-routing">Switch Routing</h4>
<p>The routing mechanism is dramatically simplified compared to previous MoE approaches. Instead of complex routing algorithms, Switch Transformer uses a straightforward approach:</p>
<ul>
<li>Each token is routed to exactly one expert</li>
<li>The routing decision is made by a learned gating function</li>
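<li>Routing each token to a single expert (“hard routing”) reduces router computation and communication; load balancing is still required and is handled by an expert capacity limit and an auxiliary loss</li>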
</ul>
</section>
<section id="expert-networks" class="level4">
<h4 class="anchored" data-anchor-id="expert-networks">Expert Networks</h4>
<p>Expert networks are individual feed-forward networks that specialize in processing specific types of inputs. Each expert has the same architecture as a standard transformer FFN but develops specialized representations during training.</p>
</section>
</section>
<section id="mathematical-foundation" class="level3">
<h3 class="anchored" data-anchor-id="mathematical-foundation" id="mathematical-foundation">Mathematical Foundation</h3>
<p>The Switch Transformer routing can be expressed mathematically as:</p>
<div id="31bba9c7" class="cell" data-execution_count="1">
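<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb1"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn.functional <span class="im">as</span> F</span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Switch Transformer routing function (top-1 "switch" routing)</span></span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> switch_routing(x, experts, router):</span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a><span class="co">    Dense MoE: y = Σ(i=1 to N) p_i(x) * E_i(x),  with p(x) = softmax(router(x))</span></span>
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a><span class="co">    Switch:    y = p_k(x) * E_k(x),  with k = argmax_i p_i(x)</span></span>
<span id="cb1-9"><a href="#cb1-9" aria-hidden="true" tabindex="-1"></a><span class="co">    </span></span>
<span id="cb1-10"><a href="#cb1-10" aria-hidden="true" tabindex="-1"></a><span class="co">    Where:</span></span>
<span id="cb1-11"><a href="#cb1-11" aria-hidden="true" tabindex="-1"></a><span class="co">    - x is a batch of tokens, shape (num_tokens, d_model)</span></span>
<span id="cb1-12"><a href="#cb1-12" aria-hidden="true" tabindex="-1"></a><span class="co">    - N is the number of experts; router is a linear layer from d_model to N</span></span>
<span id="cb1-13"><a href="#cb1-13" aria-hidden="true" tabindex="-1"></a><span class="co">    - p_i(x) is the router probability for expert i</span></span>
<span id="cb1-14"><a href="#cb1-14" aria-hidden="true" tabindex="-1"></a><span class="co">    - E_i(x) is the output of expert i (assumed to map d_model to d_model)</span></span>
<span id="cb1-15"><a href="#cb1-15" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb1-16"><a href="#cb1-16" aria-hidden="true" tabindex="-1"></a>    probs <span class="op">=</span> F.softmax(router(x), dim<span class="op">=-</span><span class="dv">1</span>)  <span class="co"># (num_tokens, N)</span></span>
<span id="cb1-17"><a href="#cb1-17" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Key innovation: keep only the single most probable expert per token</span></span>
<span id="cb1-18"><a href="#cb1-18" aria-hidden="true" tabindex="-1"></a>    gate, selected_expert <span class="op">=</span> probs.<span class="bu">max</span>(dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb1-19"><a href="#cb1-19" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb1-20"><a href="#cb1-20" aria-hidden="true" tabindex="-1"></a>    y <span class="op">=</span> torch.zeros_like(x)</span>
<span id="cb1-21"><a href="#cb1-21" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> i, expert <span class="kw">in</span> <span class="bu">enumerate</span>(experts):</span>
<span id="cb1-22"><a href="#cb1-22" aria-hidden="true" tabindex="-1"></a>        mask <span class="op">=</span> selected_expert <span class="op">==</span> i</span>
<span id="cb1-23"><a href="#cb1-23" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> mask.<span class="bu">any</span>():</span>
<span id="cb1-24"><a href="#cb1-24" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Only routed tokens are processed; scaling by the gate keeps the router trainable</span></span>
<span id="cb1-25"><a href="#cb1-25" aria-hidden="true" tabindex="-1"></a>            y[mask] <span class="op">=</span> gate[mask].unsqueeze(<span class="op">-</span><span class="dv">1</span>) <span class="op">*</span> expert(x[mask])</span>
<span id="cb1-26"><a href="#cb1-26" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> y</span></code></pre></div></div>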
</div>
<p>The key innovation is that <code>G(x)</code> produces a sparse output where only one expert receives a non-zero weight, simplifying computation significantly.</p>
</section>
</section>
<section id="key-innovations" class="level2">
<h2 class="anchored" data-anchor-id="key-innovations" id="key-innovations">Key Innovations</h2>
<section id="simplified-routing-algorithm" class="level3">
<h3 class="anchored" data-anchor-id="simplified-routing-algorithm" id="simplified-routing-algorithm">Simplified Routing Algorithm</h3>
<p>Switch Transformer introduces a dramatically simplified routing mechanism:</p>
<div class="tabset-margin-container"></div><div class="panel-tabset">
<ul class="nav nav-tabs" role="tablist"><li class="nav-item" role="presentation"><a class="nav-link active" id="tabset-1-1-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-1" role="tab" aria-controls="tabset-1-1" aria-selected="true" href="">Traditional MoE Routing</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-1-2-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-2" role="tab" aria-controls="tabset-1-2" aria-selected="false" href="">Switch Routing</a></li></ul>
<div class="tab-content">
<div id="tabset-1-1" class="tab-pane active" role="tabpanel" aria-labelledby="tabset-1-1-tab">
<ul>
<li>Complex gating functions</li>
<li>Multiple experts per token</li>
<li>Soft routing with weighted combinations</li>
<li>Difficult load balancing</li>
</ul>
</div>
<div id="tabset-1-2" class="tab-pane" role="tabpanel" aria-labelledby="tabset-1-2-tab">
<ul>
<li>Single expert per token</li>
<li>Hard routing decisions</li>
<li>Simple argmax selection</li>
<li>Load balancing via capacity factor and auxiliary loss</li>
</ul>
</div>
</div>
</div>
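<p>This simplification reduces router computation, allows each expert’s capacity to be at least halved (every token goes to one expert instead of two or more), and lowers communication costs, while preserving the benefits of expert specialization.</p>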
</section>
<section id="expert-capacity-and-load-balancing" class="level3">
<h3 class="anchored" data-anchor-id="expert-capacity-and-load-balancing" id="expert-capacity-and-load-balancing">Expert Capacity and Load Balancing</h3>
<p>One of the most innovative aspects of Switch Transformer is its approach to load balancing:</p>
<section id="capacity-factor" class="level4">
<h4 class="anchored" data-anchor-id="capacity-factor">Capacity Factor</h4>
<p>The model uses a capacity factor to set the expert capacity, the maximum number of tokens each expert may process in a batch. It is calculated as:</p>
<div id="9ddd61fe" class="cell" data-execution_count="2">
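<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> math</span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> calculate_expert_capacity(tokens_per_batch, num_experts, capacity_factor):</span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a><span class="co">    Expert Capacity = (tokens_per_batch / num_experts) * capacity_factor</span></span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a><span class="co">    </span></span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a><span class="co">    Rounded up so each expert gets a whole number of token slots</span></span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a><span class="co">    (the rounding choice is an implementation detail).</span></span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb2-10"><a href="#cb2-10" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> math.ceil((tokens_per_batch <span class="op">/</span> num_experts) <span class="op">*</span> capacity_factor)</span></code></pre></div></div>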
</div>
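<p>To see what capacity means for a single batch, here is a minimal sketch that assumes top-1 routing has already chosen one expert per token and that tokens arriving after an expert is full are dropped (in the full model they skip the expert and pass through the residual connection). The function name <code>dispatch_with_capacity</code> and the rounding down with <code>int</code> are illustrative assumptions.</p>

```python
from collections import defaultdict

def dispatch_with_capacity(assignments, num_experts, capacity_factor):
    """Group token indices by expert, dropping tokens that exceed capacity.

    assignments[t] is the expert chosen for token t (top-1 routing).
    Returns (kept, dropped): kept maps expert -> token indices it processes,
    dropped lists the overflow token indices in arrival order.
    """
    tokens_per_batch = len(assignments)
    # Expert capacity, rounded down to a whole number of token slots (assumption)
    capacity = int(tokens_per_batch / num_experts * capacity_factor)

    kept = defaultdict(list)
    dropped = []
    for token, expert in enumerate(assignments):
        if len(kept[expert]) < capacity:
            kept[expert].append(token)  # a slot is still free on this expert
        else:
            dropped.append(token)       # expert is full: the token overflows
    return dict(kept), dropped

# 8 tokens, 4 experts, capacity factor 1.0: each expert takes at most 2 tokens
kept, dropped = dispatch_with_capacity([0, 0, 0, 1, 2, 2, 3, 0], num_experts=4, capacity_factor=1.0)
print(kept)     # {0: [0, 1], 1: [3], 2: [4, 5], 3: [6]}
print(dropped)  # [2, 7]
```

<p>With a capacity factor of 2.0 in the same example, each expert gets 4 slots and no token is dropped, at the cost of more unused (padded) slots on lightly loaded experts.</p>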
</section>
<section id="auxiliary-loss" class="level4">
<h4 class="anchored" data-anchor-id="auxiliary-loss">Auxiliary Loss</h4>
<p>To encourage balanced routing, Switch Transformer employs an auxiliary loss function that penalizes uneven distribution of tokens across experts:</p>
<div id="140facb3" class="cell" data-execution_count="3">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> auxiliary_loss(expert_frequencies, expert_probabilities, alpha<span class="op">=</span><span class="fl">0.01</span>):</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a><span class="co">    L_aux = α * N * Σ(i=1 to N) f_i * P_i</span></span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a><span class="co">    </span></span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a><span class="co">    Where N is the number of experts, f_i is the fraction of tokens routed</span></span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a><span class="co">    to expert i, and P_i is the mean router probability assigned to expert i.</span></span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a><span class="co">    The factor N keeps the loss equal to α under perfectly uniform routing.</span></span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a>    num_experts <span class="op">=</span> <span class="bu">len</span>(expert_frequencies)</span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> alpha <span class="op">*</span> num_experts <span class="op">*</span> <span class="bu">sum</span>(f <span class="op">*</span> p <span class="cf">for</span> f, p <span class="kw">in</span> <span class="bu">zip</span>(expert_frequencies, expert_probabilities))</span></code></pre></div></div>
</div>
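<p>In practice, f_i and P_i are computed from the router's outputs for a batch. The following is a minimal PyTorch sketch, assuming <code>router_probs</code> is a (num_tokens, num_experts) tensor of softmax outputs and using the α · N scaling from the paper; the function name <code>load_balancing_loss</code> is illustrative. Only P_i is differentiable, so the gradient flows through the mean router probabilities.</p>

```python
import torch

def load_balancing_loss(router_probs: torch.Tensor, alpha: float = 0.01) -> torch.Tensor:
    """Switch-style auxiliary loss: alpha * N * sum_i f_i * P_i."""
    num_tokens, num_experts = router_probs.shape
    # f_i: fraction of tokens whose top-1 (argmax) expert is i (no gradient)
    top1 = router_probs.argmax(dim=-1)
    f = torch.bincount(top1, minlength=num_experts).float() / num_tokens
    # P_i: mean router probability assigned to expert i (carries the gradient)
    p = router_probs.mean(dim=0)
    return alpha * num_experts * torch.sum(f * p)

# Balanced routing: 4 tokens, 4 experts, each token prefers a different expert
balanced = torch.full((4, 4), 0.1) + 0.6 * torch.eye(4)
print(load_balancing_loss(balanced))   # ≈ alpha = 0.01 when routing is balanced

# Collapsed routing: every token prefers expert 0
collapsed = torch.tensor([[0.7, 0.1, 0.1, 0.1]] * 4)
print(load_balancing_loss(collapsed))  # ≈ 0.01 * 4 * 0.7 = 0.028, penalizing the imbalance
```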
</section>
</section>
<section id="selective-precision-training" class="level3">
<h3 class="anchored" data-anchor-id="selective-precision-training" id="selective-precision-training">Selective Precision Training</h3>
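<p>Switch Transformer uses selective precision: most of the model trains in bfloat16, but the numerically sensitive router is computed in float32:</p>
<ul>
<li>The router input is cast to float32 locally, inside the router, so its softmax is computed in full precision</li>
<li>Everything else, including the expert computations and the all-to-all communication of tokens, stays in bfloat16</li>
<li>This keeps the speed and communication savings of bfloat16 while avoiding the instability observed with a fully bfloat16 router</li>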
</ul>
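<p>A minimal PyTorch sketch of the idea, assuming the rest of the layer runs in bfloat16: the router upcasts its input (and its own weights) to float32, computes the softmax there, and returns the gate values in the caller's dtype. The <code>Float32Router</code> class is illustrative, not the paper's original Mesh TensorFlow code.</p>

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Float32Router(nn.Module):
    """Top-1 router that computes in float32 inside a bfloat16 model."""

    def __init__(self, d_model: int, num_experts: int):
        super().__init__()
        self.proj = nn.Linear(d_model, num_experts)

    def forward(self, x: torch.Tensor):
        orig_dtype = x.dtype
        # Upcast input and weights so the logits and softmax run in float32,
        # even if the module itself was cast to bfloat16 elsewhere
        logits = F.linear(x.float(), self.proj.weight.float(), self.proj.bias.float())
        probs = torch.softmax(logits, dim=-1)
        gate, expert_index = probs.max(dim=-1)
        # Cast the gate values back so downstream expert math stays in bfloat16
        return gate.to(orig_dtype), expert_index

router = Float32Router(d_model=16, num_experts=4)
x = torch.randn(8, 16, dtype=torch.bfloat16)  # bfloat16 activations
gate, expert_index = router(x)
print(gate.dtype, expert_index.shape)  # torch.bfloat16 torch.Size([8])
```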
</section>
</section>
<section id="technical-implementation-details" class="level2">
<h2 class="anchored" data-anchor-id="technical-implementation-details" id="technical-implementation-details">Technical Implementation Details</h2>
<section id="training-considerations" class="level3">
<h3 class="anchored" data-anchor-id="training-considerations" id="training-considerations">Training Considerations</h3>
<p>Training Switch Transformer models requires careful consideration of several factors:</p>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Training Best Practices
</div>
</div>
<div class="callout-body-container callout-body">
<ol type="1">
<li><strong>Initialization Strategy</strong>
<ul>
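<li>Weights are drawn from a truncated normal distribution, and the paper recommends reducing the standard Transformer initialization scale by a factor of 10</li>
<li>With small weights, the router's logits start near zero, so its initial routing distribution is close to uniform</li>
<li>Proper initialization is crucial for stable training; the reduced scale also lowers the variance in quality across runs</li>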
</ul></li>
<li><strong>Regularization Techniques</strong>
<ul>
<li>Dropout is applied within expert networks</li>
<li>Auxiliary loss provides implicit regularization</li>
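<li>Expert dropout, a higher dropout rate inside the experts than in the rest of the model during fine-tuning, helps counter overfitting on small downstream tasks</li>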
</ul></li>
<li><strong>Distributed Training</strong>
<ul>
<li>Experts can be distributed across different machines</li>
<li>All-to-all communication patterns are used for token routing</li>
<li>Careful attention to communication efficiency is required</li>
</ul></li>
</ol>
</div>
</div>
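<p>The reduced-scale initialization can be sketched as follows, assuming weights drawn from a truncated normal with standard deviation sqrt(s / n), where n is the layer's fan-in and s is a scale hyperparameter (s = 1.0 being the usual default and s = 0.1 the reduced scale). The helper name <code>reduced_scale_init_</code> and the truncation at two standard deviations are assumptions of this sketch.</p>

```python
import math
import torch
import torch.nn as nn

def reduced_scale_init_(linear: nn.Linear, scale: float = 0.1) -> None:
    """Truncated-normal init with std = sqrt(scale / fan_in), applied in place."""
    fan_in = linear.in_features
    std = math.sqrt(scale / fan_in)
    # trunc_normal_ takes absolute bounds, so truncate at +/- 2 standard deviations
    nn.init.trunc_normal_(linear.weight, mean=0.0, std=std, a=-2 * std, b=2 * std)
    if linear.bias is not None:
        nn.init.zeros_(linear.bias)

expert_layer = nn.Linear(512, 2048)
reduced_scale_init_(expert_layer)
# About 0.012: sqrt(0.1 / 512) ≈ 0.014, narrowed by the 2-sigma truncation
print(expert_layer.weight.std().item())
```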
</section>
<section id="inference-optimization" class="level3">
<h3 class="anchored" data-anchor-id="inference-optimization" id="inference-optimization">Inference Optimization</h3>
<p>Inference with Switch Transformer models involves several optimizations:</p>
<section id="expert-caching" class="level4">
<h4 class="anchored" data-anchor-id="expert-caching">Expert Caching</h4>
<ul>
<li>Frequently used experts can be cached in fast memory</li>
<li>Dynamic expert loading based on input characteristics</li>
<li>Predictive expert prefetching</li>
</ul>
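<p>One simple way to realize expert caching is a least-recently-used (LRU) cache that keeps a fixed number of experts in fast memory and loads the rest on demand. The sketch below is hypothetical: <code>ExpertCache</code> and <code>load_fn</code> are illustrative names, and <code>load_fn</code> stands in for whatever actually fetches an expert's weights (from host memory, disk, or another device).</p>

```python
from collections import OrderedDict
from typing import Any, Callable

class ExpertCache:
    """Keep at most `capacity` experts resident, evicting the least recently used."""

    def __init__(self, capacity: int, load_fn: Callable[[int], Any]):
        self.capacity = capacity
        self.load_fn = load_fn       # hypothetical loader: expert_id -> expert
        self._cache = OrderedDict()  # expert_id -> expert, least recent first
        self.misses = 0

    def get(self, expert_id: int) -> Any:
        if expert_id in self._cache:
            # Hit: mark this expert as most recently used
            self._cache.move_to_end(expert_id)
            return self._cache[expert_id]
        # Miss: load the expert, then evict the least recently used if over capacity
        self.misses += 1
        expert = self.load_fn(expert_id)
        self._cache[expert_id] = expert
        if len(self._cache) > self.capacity:
            self._cache.popitem(last=False)
        return expert

    def resident(self):
        return list(self._cache)

cache = ExpertCache(capacity=2, load_fn=lambda i: f"expert-{i}")
for expert_id in [0, 1, 0, 2, 0, 1]:
    cache.get(expert_id)
print(cache.resident(), cache.misses)  # [0, 1] 4
```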
</section>
<section id="batching-strategies" class="level4">
<h4 class="anchored" data-anchor-id="batching-strategies">Batching Strategies</h4>
<ul>
<li>Tokens routed to the same expert are batched together</li>
<li>Dynamic batching based on routing decisions</li>
<li>Memory-efficient expert execution</li>
</ul>
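<p>Grouping tokens by expert can be sketched with a stable sort on the routing decisions, assuming top-1 expert indices are already available and a recent PyTorch version (for <code>stable=True</code> in <code>torch.argsort</code>). The helper <code>group_tokens_by_expert</code> is illustrative; it also returns the permutation so expert outputs can be scattered back into the original token order.</p>

```python
import torch

def group_tokens_by_expert(tokens: torch.Tensor, expert_index: torch.Tensor, num_experts: int):
    """Split tokens into one contiguous batch per expert.

    tokens: (num_tokens, d_model); expert_index: (num_tokens,) top-1 expert ids.
    Returns (batches, order), where order is the permutation applied to tokens.
    """
    order = torch.argsort(expert_index, stable=True)  # keeps arrival order within an expert
    counts = torch.bincount(expert_index, minlength=num_experts)
    batches = torch.split(tokens[order], counts.tolist())
    return batches, order

tokens = torch.arange(6, dtype=torch.float32).unsqueeze(-1)  # six 1-dimensional "tokens"
expert_index = torch.tensor([2, 0, 2, 1, 0, 2])
batches, order = group_tokens_by_expert(tokens, expert_index, num_experts=3)
for e, batch in enumerate(batches):
    print(e, batch.squeeze(-1).tolist())
# 0 [1.0, 4.0]
# 1 [3.0]
# 2 [0.0, 2.0, 5.0]

# Each expert would transform its batch here; then undo the permutation
outputs = torch.cat(batches)
restored = torch.empty_like(outputs)
restored[order] = outputs
print(torch.equal(restored, tokens))  # True
```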
</section>
</section>
</section>
<section id="performance-and-scalability" class="level2">
<h2 class="anchored" data-anchor-id="performance-and-scalability" id="performance-and-scalability">Performance and Scalability</h2>
<section id="empirical-results" class="level3">
<h3 class="anchored" data-anchor-id="empirical-results" id="empirical-results">Empirical Results</h3>
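<p>The original paper evaluates Switch Transformer on pre-training, fine-tuning, and multilingual benchmarks. The chart below uses hypothetical numbers to illustrate the central trade-off: far fewer FLOPs per token than a dense model with the same parameter count, without a loss in quality.</p>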
<div id="87645018" class="cell" data-execution_count="4">
<details class="code-fold">
<summary>Code</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> matplotlib.pyplot <span class="im">as</span> plt</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Hypothetical values for illustration only (not measured results)</span></span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a>models <span class="op">=</span> [<span class="st">'Dense Transformer'</span>, <span class="st">'Traditional MoE'</span>, <span class="st">'Switch Transformer'</span>]</span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a>flops_per_token <span class="op">=</span> [<span class="dv">100</span>, <span class="dv">80</span>, <span class="dv">30</span>]</span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a>performance_score <span class="op">=</span> [<span class="dv">85</span>, <span class="dv">88</span>, <span class="dv">92</span>]</span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a>fig, (ax1, ax2) <span class="op">=</span> plt.subplots(<span class="dv">1</span>, <span class="dv">2</span>, figsize<span class="op">=</span>(<span class="dv">12</span>, <span class="dv">5</span>))</span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a><span class="co"># FLOPs comparison</span></span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a>ax1.bar(models, flops_per_token, color<span class="op">=</span>[<span class="st">'#ff6b6b'</span>, <span class="st">'#4ecdc4'</span>, <span class="st">'#45b7d1'</span>])</span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a>ax1.set_ylabel(<span class="st">'FLOPs per Token (Relative)'</span>)</span>
<span id="cb4-14"><a href="#cb4-14" aria-hidden="true" tabindex="-1"></a>ax1.set_title(<span class="st">'Computational Efficiency'</span>)</span>
<span id="cb4-15"><a href="#cb4-15" aria-hidden="true" tabindex="-1"></a>ax1.tick_params(axis<span class="op">=</span><span class="st">'x'</span>, rotation<span class="op">=</span><span class="dv">45</span>)</span>
<span id="cb4-16"><a href="#cb4-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-17"><a href="#cb4-17" aria-hidden="true" tabindex="-1"></a><span class="co"># Performance comparison</span></span>
<span id="cb4-18"><a href="#cb4-18" aria-hidden="true" tabindex="-1"></a>ax2.bar(models, performance_score, color<span class="op">=</span>[<span class="st">'#ff6b6b'</span>, <span class="st">'#4ecdc4'</span>, <span class="st">'#45b7d1'</span>])</span>
<span id="cb4-19"><a href="#cb4-19" aria-hidden="true" tabindex="-1"></a>ax2.set_ylabel(<span class="st">'Performance Score'</span>)</span>
<span id="cb4-20"><a href="#cb4-20" aria-hidden="true" tabindex="-1"></a>ax2.set_title(<span class="st">'Model Performance'</span>)</span>
<span id="cb4-21"><a href="#cb4-21" aria-hidden="true" tabindex="-1"></a>ax2.tick_params(axis<span class="op">=</span><span class="st">'x'</span>, rotation<span class="op">=</span><span class="dv">45</span>)</span>
<span id="cb4-22"><a href="#cb4-22" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-23"><a href="#cb4-23" aria-hidden="true" tabindex="-1"></a>plt.tight_layout()</span>
<span id="cb4-24"><a href="#cb4-24" aria-hidden="true" tabindex="-1"></a>plt.show()</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-display">
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<p><img src="https://theja-vanka.github.io/blogs/posts/models/switch-transformer/cell-5-output-1.png" class="img-fluid figure-img"></p>
<figcaption>Illustrative comparison (hypothetical values) of per-token compute and model quality</figcaption>
</figure>
</div>
</div>
</div>
<section id="language-modeling" class="level4">
<h4 class="anchored" data-anchor-id="language-modeling">Language Modeling</h4>
<ul>
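<li>Up to a 7x pre-training speedup over FLOP-matched dense T5 baselines with the same computational resources</li>
<li>Lower perplexity than dense models that use the same FLOPs per token</li>
<li>Effective scaling to trillion-parameter models (Switch-C has about 1.6 trillion parameters)</li>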
</ul>
</section>
<section id="multi-task-learning" class="level4">
<h4 class="anchored" data-anchor-id="multi-task-learning">Multi-task Learning</h4>
<ul>
<li>Strong performance across diverse NLP tasks</li>
<li>Effective knowledge transfer between tasks</li>
<li>Improved sample efficiency</li>
</ul>
</section>
</section>
<section id="scaling-properties" class="level3">
<h3 class="anchored" data-anchor-id="scaling-properties" id="scaling-properties">Scaling Properties</h3>
<p>The scaling properties of Switch Transformer are particularly noteworthy:</p>
<div class="tabset-margin-container"></div><div class="panel-tabset">
<ul class="nav nav-tabs" role="tablist"><li class="nav-item" role="presentation"><a class="nav-link active" id="tabset-2-1-tab" data-bs-toggle="tab" data-bs-target="#tabset-2-1" role="tab" aria-controls="tabset-2-1" aria-selected="true" href="">Parameter Scaling</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-2-2-tab" data-bs-toggle="tab" data-bs-target="#tabset-2-2" role="tab" aria-controls="tabset-2-2" aria-selected="false" href="">Expert Specialization</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-2-3-tab" data-bs-toggle="tab" data-bs-target="#tabset-2-3" role="tab" aria-controls="tabset-2-3" aria-selected="false" href="">Computational Efficiency</a></li></ul>
<div class="tab-content">
<div id="tabset-2-1" class="tab-pane active" role="tabpanel" aria-labelledby="tabset-2-1-tab">
<ul>
<li>Linear increase in parameters with the number of experts</li>
<li>Nearly constant computation per token, since each token is still processed by only one expert</li>
<li>Quality improves as experts are added at a fixed per-token FLOP budget</li>
</ul>
</div>
<div id="tabset-2-2" class="tab-pane" role="tabpanel" aria-labelledby="tabset-2-2-tab">
<ul>
<li>Experts can develop distinct specializations during training</li>
<li>Some experts may specialize in linguistic patterns (e.g., syntax or semantics)</li>
<li>In multi-domain settings, experts may specialize by domain</li>
</ul>
</div>
<div id="tabset-2-3" class="tab-pane" role="tabpanel" aria-labelledby="tabset-2-3-tab">
<ul>
<li>Far fewer FLOPs per token than a dense model with the same parameter count</li>
<li>Improved throughput for large-scale applications</li>
<li>Better resource utilization in distributed settings</li>
</ul>
</div>
</div>
</div>
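<p>A quick back-of-the-envelope calculation makes the parameter/compute decoupling concrete. The sketch below assumes each expert is a two-matrix feed-forward block (d_model to d_ff and back, biases ignored) with T5-Base-like sizes (d_model = 768, d_ff = 3072), and counts multiply-accumulates (MACs, roughly half the FLOPs); <code>switch_ffn_cost</code> is an illustrative helper.</p>

```python
def switch_ffn_cost(d_model: int, d_ff: int, num_experts: int):
    """Parameter count and per-token MACs for one Switch feed-forward layer."""
    expert_params = 2 * d_model * d_ff       # two weight matrices per expert
    router_params = d_model * num_experts    # d_model x num_experts router projection
    total_params = num_experts * expert_params + router_params
    # Top-1 routing: each token runs through the router plus exactly one expert
    macs_per_token = expert_params + router_params
    return total_params, macs_per_token

for n in [1, 8, 64]:
    params, macs = switch_ffn_cost(d_model=768, d_ff=3072, num_experts=n)
    print(f"{n:>3} experts: {params / 1e6:7.1f}M params, {macs / 1e6:5.2f}M MACs/token")
```

<p>Going from 1 to 64 experts multiplies the layer's parameters by roughly 64, while per-token compute grows by only about 1%, entirely from the larger router.</p>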
</section>
</section>
<section id="advantages-and-limitations" class="level2">
<h2 class="anchored" data-anchor-id="advantages-and-limitations" id="advantages-and-limitations">Advantages and Limitations</h2>
<section id="advantages" class="level3">
<h3 class="anchored" data-anchor-id="advantages" id="advantages">Advantages</h3>
<ol type="1">
<li><strong>Computational Efficiency</strong>: Much lower computational cost per token than a dense model with the same parameter count</li>
<li><strong>Scalability</strong>: Ability to scale to trillions of parameters without a proportional increase in computation</li>
<li><strong>Specialization</strong>: Experts can develop specializations that help on diverse tasks</li>
<li><strong>Flexibility</strong>: Can be applied to various transformer architectures and tasks</li>
<li><strong>Resource Optimization</strong>: Better utilization of computational resources in distributed settings</li>
</ol>
</section>
<section id="limitations" class="level3">
<h3 class="anchored" data-anchor-id="limitations" id="limitations">Limitations</h3>
<div class="callout callout-style-default callout-caution callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Caution</span>Current Limitations
</div>
</div>
<div class="callout-body-container callout-body">
<ol type="1">
<li><strong>Memory Requirements</strong>: All experts' parameters must be stored even though each token uses only one, so memory grows with the number of experts</li>
<li><strong>Communication Overhead</strong>: Distributed training requires careful optimization of communication patterns</li>
<li><strong>Load Balancing</strong>: Achieving perfect load balance across experts remains challenging</li>
<li><strong>Complexity</strong>: Implementation complexity is higher than standard transformers</li>
<li><strong>Hardware Dependencies</strong>: Optimal performance requires specialized hardware configurations</li>
</ol>
</div>
</div>
</section>
</section>
<section id="applications-and-use-cases" class="level2">
<h2 class="anchored" data-anchor-id="applications-and-use-cases" id="applications-and-use-cases">Applications and Use Cases</h2>
<section id="natural-language-processing" class="level3">
<h3 class="anchored" data-anchor-id="natural-language-processing" id="natural-language-processing">Natural Language Processing</h3>
<p>Switch Transformer has shown particular strength in various NLP applications:</p>
<section id="language-modeling-1" class="level4">
<h4 class="anchored" data-anchor-id="language-modeling-1">Language Modeling</h4>
<ul>
<li>Large-scale language model pretraining</li>
<li>Improved efficiency for autoregressive generation</li>
<li>Better handling of diverse linguistic phenomena</li>
</ul>
</section>
<section id="machine-translation" class="level4">
<h4 class="anchored" data-anchor-id="machine-translation">Machine Translation</h4>
<ul>
<li>Multilingual translation systems</li>
<li>Language-specific expert development</li>
<li>Improved handling of low-resource languages</li>
</ul>
</section>
<section id="text-classification" class="level4">
<h4 class="anchored" data-anchor-id="text-classification">Text Classification</h4>
<ul>
<li>Multi-domain classification tasks</li>
<li>Efficient fine-tuning for specific domains</li>
<li>Robust performance across diverse text types</li>
</ul>
</section>
</section>
<section id="beyond-nlp" class="level3">
<h3 class="anchored" data-anchor-id="beyond-nlp" id="beyond-nlp">Beyond NLP</h3>
<p>While primarily developed for NLP, Switch Transformer principles can be applied to other domains:</p>
<section id="computer-vision" class="level4">
<h4 class="anchored" data-anchor-id="computer-vision">Computer Vision</h4>
<ul>
<li>Vision transformers with expert routing</li>
<li>Specialized processing for different visual patterns</li>
<li>Efficient scaling for large vision models</li>
</ul>
</section>
<section id="multimodal-learning" class="level4">
<h4 class="anchored" data-anchor-id="multimodal-learning">Multimodal Learning</h4>
<ul>
<li>Cross-modal expert specialization</li>
<li>Efficient processing of diverse input modalities</li>
<li>Improved scaling for multimodal models</li>
</ul>
</section>
</section>
</section>
<section id="implementation-considerations" class="level2">
<h2 class="anchored" data-anchor-id="implementation-considerations" id="implementation-considerations">Implementation Considerations</h2>
<section id="framework-support" class="level3">
<h3 class="anchored" data-anchor-id="framework-support" id="framework-support">Framework Support</h3>
<p>Switch Transformer implementations are available in several frameworks. The simplified PyTorch layer below illustrates the core top-1 routing logic:</p>
<div id="79589915" class="cell" data-execution_count="5">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Example implementation structure in PyTorch</span></span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> SwitchTransformerLayer(nn.Module):</span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, d_model, num_experts, expert_capacity):</span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.d_model <span class="op">=</span> d_model</span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.num_experts <span class="op">=</span> num_experts</span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.expert_capacity <span class="op">=</span> <span class="bu">int</span>(expert_capacity)  <span class="co"># max tokens per expert per batch</span></span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Router network</span></span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.router <span class="op">=</span> nn.Linear(d_model, num_experts)</span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Expert networks (independent feed-forward blocks)</span></span>
<span id="cb5-16"><a href="#cb5-16" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.experts <span class="op">=</span> nn.ModuleList([</span>
<span id="cb5-17"><a href="#cb5-17" aria-hidden="true" tabindex="-1"></a>            nn.Sequential(</span>
<span id="cb5-18"><a href="#cb5-18" aria-hidden="true" tabindex="-1"></a>                nn.Linear(d_model, d_model <span class="op">*</span> <span class="dv">4</span>),</span>
<span id="cb5-19"><a href="#cb5-19" aria-hidden="true" tabindex="-1"></a>                nn.ReLU(),</span>
<span id="cb5-20"><a href="#cb5-20" aria-hidden="true" tabindex="-1"></a>                nn.Linear(d_model <span class="op">*</span> <span class="dv">4</span>, d_model)</span>
<span id="cb5-21"><a href="#cb5-21" aria-hidden="true" tabindex="-1"></a>            ) <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(num_experts)</span>
<span id="cb5-22"><a href="#cb5-22" aria-hidden="true" tabindex="-1"></a>        ])</span>
<span id="cb5-23"><a href="#cb5-23" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-24"><a href="#cb5-24" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb5-25"><a href="#cb5-25" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Flatten (batch_size, seq_len, d_model) into one row per token</span></span>
<span id="cb5-26"><a href="#cb5-26" aria-hidden="true" tabindex="-1"></a>        tokens <span class="op">=</span> x.reshape(<span class="op">-</span><span class="dv">1</span>, <span class="va">self</span>.d_model)</span>
<span id="cb5-27"><a href="#cb5-27" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-28"><a href="#cb5-28" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Router decision: probability of each expert for each token</span></span>
<span id="cb5-29"><a href="#cb5-29" aria-hidden="true" tabindex="-1"></a>        router_logits <span class="op">=</span> <span class="va">self</span>.router(tokens)</span>
<span id="cb5-30"><a href="#cb5-30" aria-hidden="true" tabindex="-1"></a>        expert_weights <span class="op">=</span> torch.softmax(router_logits, dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb5-31"><a href="#cb5-31" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-32"><a href="#cb5-32" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Select one expert per token (hard top-1 routing) and keep its gate value</span></span>
<span id="cb5-33"><a href="#cb5-33" aria-hidden="true" tabindex="-1"></a>        gate, selected_expert <span class="op">=</span> expert_weights.<span class="bu">max</span>(dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb5-34"><a href="#cb5-34" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-35"><a href="#cb5-35" aria-hidden="true" tabindex="-1"></a>        output <span class="op">=</span> torch.zeros_like(tokens)</span>
<span id="cb5-36"><a href="#cb5-36" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="va">self</span>.num_experts):</span>
<span id="cb5-37"><a href="#cb5-37" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Tokens routed to expert i, truncated to the expert capacity;</span></span>
<span id="cb5-38"><a href="#cb5-38" aria-hidden="true" tabindex="-1"></a>            <span class="co"># overflow tokens keep a zero output here (the surrounding residual</span></span>
<span id="cb5-39"><a href="#cb5-39" aria-hidden="true" tabindex="-1"></a>            <span class="co"># connection carries them forward unchanged)</span></span>
<span id="cb5-40"><a href="#cb5-40" aria-hidden="true" tabindex="-1"></a>            idx <span class="op">=</span> torch.nonzero(selected_expert <span class="op">==</span> i, as_tuple<span class="op">=</span><span class="va">True</span>)[<span class="dv">0</span>]</span>
<span id="cb5-41"><a href="#cb5-41" aria-hidden="true" tabindex="-1"></a>            idx <span class="op">=</span> idx[: <span class="va">self</span>.expert_capacity]</span>
<span id="cb5-42"><a href="#cb5-42" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> idx.numel() <span class="op">&gt;</span> <span class="dv">0</span>:</span>
<span id="cb5-43"><a href="#cb5-43" aria-hidden="true" tabindex="-1"></a>                expert_output <span class="op">=</span> <span class="va">self</span>.experts[i](tokens[idx])</span>
<span id="cb5-44"><a href="#cb5-44" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Scale by the gate value so the router receives gradients</span></span>
<span id="cb5-45"><a href="#cb5-45" aria-hidden="true" tabindex="-1"></a>                output[idx] <span class="op">=</span> gate[idx].unsqueeze(<span class="op">-</span><span class="dv">1</span>) <span class="op">*</span> expert_output</span>
<span id="cb5-46"><a href="#cb5-46" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-47"><a href="#cb5-47" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> output.reshape(x.shape)</span></code></pre></div></div>
</div>
<section id="jaxflax" class="level4">
<h4 class="anchored" data-anchor-id="jaxflax">JAX/Flax</h4>
<ul>
<li>Google's T5X codebase (JAX) hosts released Switch Transformer checkpoints; the original paper's experiments used Mesh TensorFlow</li>
<li>Optimized for TPU training</li>
<li>Comprehensive distributed training support</li>
</ul>
</section>
<section id="pytorch" class="level4">
<h4 class="anchored" data-anchor-id="pytorch">PyTorch</h4>
<ul>
<li>Community implementations available</li>
<li>Integration with Hugging Face Transformers</li>
<li>Support for GPU training</li>
</ul>
</section>
<section id="tensorflow" class="level4">
<h4 class="anchored" data-anchor-id="tensorflow">TensorFlow</h4>
<ul>
<li>TensorFlow Model Garden implementations</li>
<li>Integration with TensorFlow Serving</li>
<li>Support for various deployment scenarios</li>
</ul>
</section>
</section>
<section id="deployment-strategies" class="level3">
<h3 class="anchored" data-anchor-id="deployment-strategies" id="deployment-strategies">Deployment Strategies</h3>
<p>Deploying Switch Transformer models requires careful consideration:</p>
<section id="inference-optimization-1" class="level4">
<h4 class="anchored" data-anchor-id="inference-optimization-1">Inference Optimization</h4>
<ul>
<li>Expert pruning for reduced model size</li>
<li>Dynamic expert loading</li>
<li>Efficient batching strategies</li>
</ul>
</section>
<section id="serving-infrastructure" class="level4">
<h4 class="anchored" data-anchor-id="serving-infrastructure">Serving Infrastructure</h4>
<ul>
<li>Distributed serving across multiple machines</li>
<li>Load balancing for expert utilization</li>
<li>Caching strategies for frequently used experts</li>
</ul>
</section>
</section>
</section>
<section id="future-directions-and-research" class="level2">
<h2 class="anchored" data-anchor-id="future-directions-and-research" id="future-directions-and-research">Future Directions and Research</h2>
<section id="ongoing-research-areas" class="level3">
<h3 class="anchored" data-anchor-id="ongoing-research-areas" id="ongoing-research-areas">Ongoing Research Areas</h3>
<p>Several areas of active research are extending Switch Transformer capabilities:</p>
<section id="improved-routing-algorithms" class="level4">
<h4 class="anchored" data-anchor-id="improved-routing-algorithms">Improved Routing Algorithms</h4>
<ul>
<li>More sophisticated routing mechanisms</li>
<li>Adaptive routing based on input characteristics</li>
<li>Learned routing policies</li>
</ul>
</section>
<section id="dynamic-expert-creation" class="level4">
<h4 class="anchored" data-anchor-id="dynamic-expert-creation">Dynamic Expert Creation</h4>
<ul>
<li>Automatic expert creation and pruning</li>
<li>Adaptive model capacity based on task requirements</li>
<li>Continual learning with expert specialization</li>
</ul>
</section>
<section id="cross-domain-applications" class="level4">
<h4 class="anchored" data-anchor-id="cross-domain-applications">Cross-domain Applications</h4>
<ul>
<li>Extension to other domains beyond NLP</li>
<li>Universal expert architectures</li>
<li>Multi-task expert sharing</li>
</ul>
</section>
</section>
<section id="emerging-variants" class="level3">
<h3 class="anchored" data-anchor-id="emerging-variants" id="emerging-variants">Emerging Variants</h3>
<p>Several related models extend the Switch Transformer approach or were developed alongside it:</p>
<div class="tabset-margin-container"></div><div class="panel-tabset">
<ul class="nav nav-tabs" role="tablist"><li class="nav-item" role="presentation"><a class="nav-link active" id="tabset-3-1-tab" data-bs-toggle="tab" data-bs-target="#tabset-3-1" role="tab" aria-controls="tabset-3-1" aria-selected="true" href="">GLaM (Generalist Language Model)</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-3-2-tab" data-bs-toggle="tab" data-bs-target="#tabset-3-2" role="tab" aria-controls="tabset-3-2" aria-selected="false" href="">PaLM (Pathways Language Model)</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-3-3-tab" data-bs-toggle="tab" data-bs-target="#tabset-3-3" role="tab" aria-controls="tabset-3-3" aria-selected="false" href="">ST-MoE (Stable and Transferable MoE)</a></li></ul>
<div class="tab-content">
<div id="tabset-3-1" class="tab-pane active" role="tabpanel" aria-labelledby="tabset-3-1-tab">
<ul>
<li>Sparse mixture-of-experts language model using top-2 routing</li>
<li>About 1.2 trillion total parameters, with only a fraction active per token</li>
<li>Reported better zero- and one-shot quality than GPT-3 at lower training energy</li>
</ul>
</div>
<div id="tabset-3-2" class="tab-pane" role="tabpanel" aria-labelledby="tabset-3-2-tab">
<ul>
<li>A dense (not mixture-of-experts) model trained with Google's Pathways system</li>
<li>Related to Switch Transformer through large-scale distributed training infrastructure rather than sparse routing</li>
<li>Efficient hardware utilization across multiple TPU pods</li>
</ul>
</div>
<div id="tabset-3-3" class="tab-pane" role="tabpanel" aria-labelledby="tabset-3-3-tab">
<ul>
<li>Follow-up work on sparse expert models from the Switch Transformer line of research</li>
<li>Router z-loss for better training stability</li>
<li>Focus on improving the fine-tuning quality of sparse models</li>
</ul>
</div>
</div>
</div>
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>Switch Transformer represents a significant advancement in neural network architecture, demonstrating that sparse computation can achieve remarkable efficiency gains while maintaining or improving model performance. By simplifying the routing mechanism and leveraging expert specialization, Switch Transformer has opened new possibilities for scaling neural networks to unprecedented sizes.</p>
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Important</span>Key Contributions
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li><strong>Simplified routing algorithm</strong> that maintains effectiveness while reducing complexity</li>
<li><strong>Efficient scaling</strong> to trillion-parameter models with roughly constant computation per token</li>
<li><strong>Demonstrated effectiveness</strong> across diverse NLP tasks</li>
<li><strong>Foundation</strong> for future sparse neural network architectures</li>
</ul>
</div>
</div>
<p>As the field continues to evolve, Switch Transformer’s principles of sparsity and expert routing will likely influence the development of future large-scale neural networks. The model’s success demonstrates that efficiency and scale are not mutually exclusive, opening new possibilities for democratizing access to large-scale AI systems.</p>
<p>The ongoing research and development in this area suggest that sparse neural networks will play an increasingly important role in the future of artificial intelligence, making powerful models more accessible and efficient for a broader range of applications and organizations.</p>
</section>
<section id="references-and-further-reading" class="level2">
<h2 class="anchored" data-anchor-id="references-and-further-reading" id="references-and-further-reading">References and Further Reading</h2>
<p>For those interested in diving deeper into Switch Transformer and related topics, the following resources provide comprehensive coverage:</p>
<ul>
<li>Original Switch Transformer paper: “Switch Transformer: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity”</li>
<li>Mixture of Experts literature for historical context</li>
<li>Pathways system architecture papers</li>
<li>JAX/Flax documentation for implementation details</li>
<li>Recent advances in sparse neural network research</li>
</ul>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Looking Forward
</div>
</div>
<div class="callout-body-container callout-body">
<p>The Switch Transformer represents not just a technical achievement but a paradigm shift toward more efficient and scalable neural network architectures, paving the way for the next generation of AI systems.</p>
</div>
</div>



</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Complete YOLO Object Detection Code Guide]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/models/you-only-look-once/yolo-code/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/models/you-only-look-once/yolo-code/</guid>
      <pubDate>Sat, 12 Jul 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>tutorial</category>
      <category>beginner</category>
      <content:encoded><![CDATA[






<section id="complete-yolo-object-detection-code-guide" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/models/you-only-look-once/yolo-code/yolo-code.png" class="img-fluid"></p>
<section id="introduction-to-yolo" class="level2">
<h2 class="anchored" data-anchor-id="introduction-to-yolo" id="introduction-to-yolo">1. Introduction to YOLO</h2>
<p>YOLO (You Only Look Once) is a family of real-time object detection models that reframed object detection as a single regression problem. Earlier pipelines run a classifier over many candidate regions of an image. YOLO instead processes the entire image in one forward pass and predicts bounding boxes and class probabilities directly.</p>
<section id="key-advantages" class="level3">
<h3 class="anchored" data-anchor-id="key-advantages" id="key-advantages">Key Advantages:</h3>
<ul>
<li><strong>Speed</strong>: Real-time detection (typically 30+ FPS on a GPU, depending on model size and hardware)</li>
<li><strong>Global Context</strong>: Sees the entire image during training and testing, which reduces background false positives</li>
<li><strong>Unified Architecture</strong>: A single neural network trained end-to-end</li>
<li><strong>Versatility</strong>: Works well across a wide range of object categories</li>
</ul>
</section>
<section id="yolo-evolution" class="level3">
<h3 class="anchored" data-anchor-id="yolo-evolution" id="yolo-evolution">YOLO Evolution:</h3>
<ul>
<li><strong>YOLOv1</strong> (2016): Original paper, 45 FPS</li>
<li><strong>YOLOv2/YOLO9000</strong> (2016): Better accuracy, 40+ FPS</li>
<li><strong>YOLOv3</strong> (2018): Multi-scale detection, Darknet-53</li>
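<li><strong>YOLOv4</strong> (2020): CSPDarknet53 backbone with SPP and PANet neck, improved accuracy and speed</li>
<li><strong>YOLOv5</strong> (2020): PyTorch implementation, user-friendly</li>
<li><strong>YOLOv8</strong> (2023): Ultralytics release with an anchor-free, decoupled head, used throughout this guide (newer versions such as YOLO11 have since followed)</li>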
</ul>
</section>
</section>
<section id="yolo-architecture-overview" class="level2">
<h2 class="anchored" data-anchor-id="yolo-architecture-overview" id="yolo-architecture-overview">2. YOLO Architecture Overview</h2>
<section id="core-concept" class="level3">
<h3 class="anchored" data-anchor-id="core-concept" id="core-concept">Core Concept</h3>
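<p>In the original formulation (YOLOv1), YOLO divides an image into an S×S grid. The cell that contains an object's center is responsible for detecting that object, and each grid cell predicts:</p>
<ul>
<li><strong>B bounding boxes</strong>, each as (x, y, width, height, confidence)</li>
<li><strong>C class probabilities</strong>, shared by all boxes in the cell</li>
</ul>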
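<p>To make the grid formulation concrete, the sketch below computes the size of a YOLOv1-style output tensor, S×S×(B·5 + C), and finds which grid cell is responsible for an object given its normalized center coordinates. The helper names are illustrative and not part of any library; S=7, B=2, C=20 are the values the original YOLOv1 paper used for PASCAL VOC.</p>

```python
def yolo_v1_output_shape(S: int, B: int, C: int) -> tuple:
    """Shape of a YOLOv1-style prediction tensor (illustrative helper).

    Each of the S*S cells predicts B boxes (x, y, w, h, confidence = 5 values)
    plus C class probabilities shared by the cell.
    """
    return (S, S, B * 5 + C)


def responsible_cell(x_center: float, y_center: float, S: int) -> tuple:
    """Return (row, col) of the grid cell containing a box center.

    Coordinates are normalized to [0, 1] relative to image width/height.
    min() keeps a center lying exactly on the right/bottom edge inside the grid.
    """
    col = min(int(x_center * S), S - 1)
    row = min(int(y_center * S), S - 1)
    return row, col


if __name__ == "__main__":
    print(yolo_v1_output_shape(S=7, B=2, C=20))  # (7, 7, 30)
    print(responsible_cell(0.52, 0.31, S=7))     # (2, 3)
```

<p>With these values each image yields 7×7 = 49 cells, and every cell carries 2·5 + 20 = 30 numbers.</p>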
</section>
<section id="network-architecture-yolov8" class="level3">
<h3 class="anchored" data-anchor-id="network-architecture-yolov8" id="network-architecture-yolov8">Network Architecture (YOLOv8)</h3>
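<pre><code>Input Image (640×640×3)
        ↓
Backbone (CSPDarknet-style, built from C2f blocks)
        ↓
Neck (PAN-FPN: top-down and bottom-up feature fusion)
        ↓
Head (decoupled, anchor-free detection head)
        ↓
Output (box and class predictions at strides 8, 16, 32)</code></pre>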
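<p>The YOLOv8 head predicts at three feature-map scales. Assuming the standard strides of 8, 16 and 32 and a square 640×640 input, the sketch below counts the grid positions per scale. For the 80-class COCO models this matches the familiar raw output shape (1, 84, 8400): 4 box values plus 80 class scores for each of 8400 positions. Both functions are hypothetical helpers written for illustration.</p>

```python
def prediction_grid_sizes(img_size: int = 640, strides=(8, 16, 32)) -> dict:
    """Map each stride to the (height, width) of its feature map.

    Assumes a square input whose side is divisible by every stride.
    """
    sizes = {}
    for s in strides:
        if img_size % s:
            raise ValueError(f"img_size {img_size} is not divisible by stride {s}")
        sizes[s] = (img_size // s, img_size // s)
    return sizes


def total_predictions(img_size: int = 640, strides=(8, 16, 32)) -> int:
    """Number of grid positions, one prediction each in an anchor-free head."""
    return sum(h * w for h, w in prediction_grid_sizes(img_size, strides).values())


if __name__ == "__main__":
    print(prediction_grid_sizes())  # {8: (80, 80), 16: (40, 40), 32: (20, 20)}
    print(total_predictions())      # 8400
    num_classes = 80                # COCO
    print((1, 4 + num_classes, total_predictions()))  # (1, 84, 8400)
```

<p>Halving the input to 320×320 quarters the count (2100 positions), which is one reason smaller input sizes run faster.</p>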
</section>
<section id="loss-function-components" class="level3">
<h3 class="anchored" data-anchor-id="loss-function-components" id="loss-function-components">Loss Function Components:</h3>
<ol type="1">
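<li><strong>Localization Loss</strong>: Bounding box coordinate errors (in YOLOv8, CIoU loss plus Distribution Focal Loss)</li>
<li><strong>Confidence Loss</strong>: Object presence confidence (an explicit objectness term from YOLOv1 to YOLOv5; YOLOv8 drops it)</li>
<li><strong>Classification Loss</strong>: Class prediction errors (binary cross-entropy in YOLOv8)</li>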
</ol>
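<p>Localization losses in the YOLO family are built on Intersection over Union (IoU) between predicted and ground-truth boxes; YOLOv8, for instance, uses the CIoU variant. The sketch below computes plain IoU for boxes in (x1, y1, x2, y2) pixel format, the same layout as <code>result.boxes.xyxy</code>. It is illustrative only and is not the Ultralytics implementation.</p>

```python
def box_iou(box_a, box_b) -> float:
    """IoU of two boxes given as (x1, y1, x2, y2) with x2 > x1 and y2 > y1."""
    # Corners of the intersection rectangle
    ix1 = max(box_a[0], box_b[0])
    iy1 = max(box_a[1], box_b[1])
    ix2 = min(box_a[2], box_b[2])
    iy2 = min(box_a[3], box_b[3])

    # Clamp to zero when the boxes do not overlap
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)

    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


if __name__ == "__main__":
    print(box_iou((0, 0, 10, 10), (5, 5, 15, 15)))    # 25 / 175, about 0.1429
    print(box_iou((0, 0, 10, 10), (20, 20, 30, 30)))  # 0.0
```

<p>An IoU-based loss is then typically <code>1 - IoU</code> (plus penalty terms in variants such as CIoU), so perfectly aligned boxes contribute zero loss.</p>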
</section>
</section>
<section id="setting-up-the-environment" class="level2">
<h2 class="anchored" data-anchor-id="setting-up-the-environment" id="setting-up-the-environment">3. Setting Up the Environment</h2>
<section id="prerequisites" class="level3">
<h3 class="anchored" data-anchor-id="prerequisites" id="prerequisites">Prerequisites</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb2"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Python 3.8+</span></span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a><span class="ex">python</span> <span class="at">--version</span></span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Create virtual environment</span></span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a><span class="ex">python</span> <span class="at">-m</span> venv yolo_env</span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a><span class="bu">source</span> yolo_env/bin/activate  <span class="co"># Linux/Mac</span></span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a><span class="co"># or</span></span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a><span class="ex">yolo_env\Scripts\activate</span>     <span class="co"># Windows</span></span></code></pre></div></div>
</section>
<section id="install-dependencies" class="level3">
<h3 class="anchored" data-anchor-id="install-dependencies" id="install-dependencies">Install Dependencies</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb3"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Install PyTorch (check pytorch.org for your system)</span></span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install torch torchvision torchaudio</span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Install Ultralytics YOLOv8</span></span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install ultralytics</span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Additional dependencies</span></span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install opencv-python pillow matplotlib numpy pandas</span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install jupyter notebook  <span class="co"># For interactive development</span></span></code></pre></div></div>
</section>
<section id="verify-installation" class="level3">
<h3 class="anchored" data-anchor-id="verify-installation" id="verify-installation">Verify Installation</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> ultralytics</span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> ultralytics <span class="im">import</span> YOLO</span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> cv2</span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"PyTorch version: </span><span class="sc">{</span>torch<span class="sc">.</span>__version__<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"CUDA available: </span><span class="sc">{</span>torch<span class="sc">.</span>cuda<span class="sc">.</span>is_available()<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Ultralytics version: </span><span class="sc">{</span>ultralytics<span class="sc">.</span>__version__<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
</section>
</section>
<section id="yolov8-implementation" class="level2">
<h2 class="anchored" data-anchor-id="yolov8-implementation" id="yolov8-implementation">4. YOLOv8 Implementation</h2>
<section id="basic-object-detection" class="level3">
<h3 class="anchored" data-anchor-id="basic-object-detection" id="basic-object-detection">Basic Object Detection</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> ultralytics <span class="im">import</span> YOLO</span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> cv2</span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> matplotlib.pyplot <span class="im">as</span> plt</span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Load pre-trained model</span></span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> YOLO(<span class="st">'yolov8n.pt'</span>)  <span class="co"># nano version for speed</span></span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a><span class="co"># model = YOLO('yolov8s.pt')  # small</span></span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a><span class="co"># model = YOLO('yolov8m.pt')  # medium</span></span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a><span class="co"># model = YOLO('yolov8l.pt')  # large</span></span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a><span class="co"># model = YOLO('yolov8x.pt')  # extra large</span></span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a><span class="co"># Single image inference</span></span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> detect_objects(image_path):</span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a><span class="co">    Detect objects in a single image</span></span>
<span id="cb5-16"><a href="#cb5-16" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb5-17"><a href="#cb5-17" aria-hidden="true" tabindex="-1"></a>    results <span class="op">=</span> model(image_path)</span>
<span id="cb5-18"><a href="#cb5-18" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-19"><a href="#cb5-19" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Process results</span></span>
<span id="cb5-20"><a href="#cb5-20" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> result <span class="kw">in</span> results:</span>
<span id="cb5-21"><a href="#cb5-21" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Get bounding boxes, confidence scores, and class IDs</span></span>
<span id="cb5-22"><a href="#cb5-22" aria-hidden="true" tabindex="-1"></a>        boxes <span class="op">=</span> result.boxes.xyxy.cpu().numpy()  <span class="co"># x1, y1, x2, y2</span></span>
<span id="cb5-23"><a href="#cb5-23" aria-hidden="true" tabindex="-1"></a>        confidences <span class="op">=</span> result.boxes.conf.cpu().numpy()</span>
<span id="cb5-24"><a href="#cb5-24" aria-hidden="true" tabindex="-1"></a>        class_ids <span class="op">=</span> result.boxes.cls.cpu().numpy()</span>
<span id="cb5-25"><a href="#cb5-25" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-26"><a href="#cb5-26" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Load image</span></span>
<span id="cb5-27"><a href="#cb5-27" aria-hidden="true" tabindex="-1"></a>        img <span class="op">=</span> cv2.imread(image_path)</span>
<span id="cb5-28"><a href="#cb5-28" aria-hidden="true" tabindex="-1"></a>        img_rgb <span class="op">=</span> cv2.cvtColor(img, cv2.COLOR_BGR2RGB)</span>
<span id="cb5-29"><a href="#cb5-29" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-30"><a href="#cb5-30" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Draw bounding boxes</span></span>
<span id="cb5-31"><a href="#cb5-31" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> box, conf, cls_id <span class="kw">in</span> <span class="bu">zip</span>(boxes, confidences, class_ids):</span>
<span id="cb5-32"><a href="#cb5-32" aria-hidden="true" tabindex="-1"></a>            x1, y1, x2, y2 <span class="op">=</span> <span class="bu">map</span>(<span class="bu">int</span>, box)</span>
<span id="cb5-33"><a href="#cb5-33" aria-hidden="true" tabindex="-1"></a>            class_name <span class="op">=</span> model.names[<span class="bu">int</span>(cls_id)]</span>
<span id="cb5-34"><a href="#cb5-34" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb5-35"><a href="#cb5-35" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Draw rectangle and label</span></span>
<span id="cb5-36"><a href="#cb5-36" aria-hidden="true" tabindex="-1"></a>            cv2.rectangle(img_rgb, (x1, y1), (x2, y2), (<span class="dv">0</span>, <span class="dv">255</span>, <span class="dv">0</span>), <span class="dv">2</span>)</span>
<span id="cb5-37"><a href="#cb5-37" aria-hidden="true" tabindex="-1"></a>            cv2.putText(img_rgb, <span class="ss">f'</span><span class="sc">{</span>class_name<span class="sc">}</span><span class="ss">: </span><span class="sc">{</span>conf<span class="sc">:.2f}</span><span class="ss">'</span>, </span>
<span id="cb5-38"><a href="#cb5-38" aria-hidden="true" tabindex="-1"></a>                       (x1, <span class="bu">max</span>(y1 <span class="op">-</span> <span class="dv">10</span>, <span class="dv">10</span>)), cv2.FONT_HERSHEY_SIMPLEX, <span class="fl">0.5</span>, (<span class="dv">0</span>, <span class="dv">255</span>, <span class="dv">0</span>), <span class="dv">2</span>)  <span class="co"># keep label inside image</span></span>
<span id="cb5-39"><a href="#cb5-39" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-40"><a href="#cb5-40" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> img_rgb, results</span>
<span id="cb5-41"><a href="#cb5-41" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-42"><a href="#cb5-42" aria-hidden="true" tabindex="-1"></a><span class="co"># Example usage</span></span>
<span id="cb5-43"><a href="#cb5-43" aria-hidden="true" tabindex="-1"></a>image_path <span class="op">=</span> <span class="st">'path/to/your/image.jpg'</span></span>
<span id="cb5-44"><a href="#cb5-44" aria-hidden="true" tabindex="-1"></a>detected_img, results <span class="op">=</span> detect_objects(image_path)</span>
<span id="cb5-45"><a href="#cb5-45" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-46"><a href="#cb5-46" aria-hidden="true" tabindex="-1"></a><span class="co"># Display results</span></span>
<span id="cb5-47"><a href="#cb5-47" aria-hidden="true" tabindex="-1"></a>plt.figure(figsize<span class="op">=</span>(<span class="dv">12</span>, <span class="dv">8</span>))</span>
<span id="cb5-48"><a href="#cb5-48" aria-hidden="true" tabindex="-1"></a>plt.imshow(detected_img)</span>
<span id="cb5-49"><a href="#cb5-49" aria-hidden="true" tabindex="-1"></a>plt.axis(<span class="st">'off'</span>)</span>
<span id="cb5-50"><a href="#cb5-50" aria-hidden="true" tabindex="-1"></a>plt.title(<span class="st">'YOLO Object Detection Results'</span>)</span>
<span id="cb5-51"><a href="#cb5-51" aria-hidden="true" tabindex="-1"></a>plt.show()</span></code></pre></div></div>
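<p>Once boxes, confidences and class IDs have been pulled out of a result as arrays, post-filtering is plain Python. The helper below is a hypothetical sketch that keeps detections above a confidence threshold and, optionally, only for selected class names. Ultralytics can also apply a threshold at inference time through the <code>conf</code> argument (for example <code>model(image_path, conf=0.5)</code>), which is usually simpler.</p>

```python
def filter_detections(boxes, confidences, class_ids, names,
                      min_conf=0.5, keep_classes=None):
    """Keep detections with confidence >= min_conf (and, optionally, in keep_classes).

    boxes:        iterable of (x1, y1, x2, y2)
    confidences:  iterable of floats
    class_ids:    iterable of numeric class indices (floats are fine)
    names:        mapping from class index to name (like model.names)
    Returns a list of (class_name, confidence, box) tuples.
    """
    kept = []
    for box, conf, cls_id in zip(boxes, confidences, class_ids):
        name = names[int(cls_id)]
        if conf < min_conf:
            continue
        if keep_classes is not None and name not in keep_classes:
            continue
        kept.append((name, float(conf), tuple(box)))
    return kept


if __name__ == "__main__":
    names = {0: "person", 2: "car"}
    boxes = [(10, 20, 110, 220), (50, 60, 90, 100), (0, 0, 30, 30)]
    confs = [0.91, 0.42, 0.77]
    cls_ids = [0.0, 2.0, 2.0]
    print(filter_detections(boxes, confs, cls_ids, names, min_conf=0.5))
    # [('person', 0.91, (10, 20, 110, 220)), ('car', 0.77, (0, 0, 30, 30))]
    print(filter_detections(boxes, confs, cls_ids, names, keep_classes={"car"}))
    # [('car', 0.77, (0, 0, 30, 30))]
```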
</section>
<section id="video-processing" class="level3">
<h3 class="anchored" data-anchor-id="video-processing" id="video-processing">Video Processing</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> process_video(video_path, output_path<span class="op">=</span><span class="va">None</span>):</span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a><span class="co">    Process video for object detection</span></span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a>    cap <span class="op">=</span> cv2.VideoCapture(video_path)</span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Get video properties</span></span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a>    fps <span class="op">=</span> <span class="bu">int</span>(cap.get(cv2.CAP_PROP_FPS))</span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a>    width <span class="op">=</span> <span class="bu">int</span>(cap.get(cv2.CAP_PROP_FRAME_WIDTH))</span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a>    height <span class="op">=</span> <span class="bu">int</span>(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))</span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Setup video writer if output path provided</span></span>
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> output_path:</span>
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a>        fourcc <span class="op">=</span> cv2.VideoWriter_fourcc(<span class="op">*</span><span class="st">'mp4v'</span>)</span>
<span id="cb6-15"><a href="#cb6-15" aria-hidden="true" tabindex="-1"></a>        out <span class="op">=</span> cv2.VideoWriter(output_path, fourcc, fps, (width, height))</span>
<span id="cb6-16"><a href="#cb6-16" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-17"><a href="#cb6-17" aria-hidden="true" tabindex="-1"></a>    <span class="cf">while</span> <span class="va">True</span>:</span>
<span id="cb6-18"><a href="#cb6-18" aria-hidden="true" tabindex="-1"></a>        ret, frame <span class="op">=</span> cap.read()</span>
<span id="cb6-19"><a href="#cb6-19" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="kw">not</span> ret:</span>
<span id="cb6-20"><a href="#cb6-20" aria-hidden="true" tabindex="-1"></a>            <span class="cf">break</span></span>
<span id="cb6-21"><a href="#cb6-21" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-22"><a href="#cb6-22" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Run YOLO detection</span></span>
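<span id="cb6-23"><a href="#cb6-23" aria-hidden="true" tabindex="-1"></a>        results <span class="op">=</span> model(frame, verbose<span class="op">=</span><span class="va">False</span>)  <span class="co"># suppress per-frame console logging</span></span>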
<span id="cb6-24"><a href="#cb6-24" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-25"><a href="#cb6-25" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Draw results on frame</span></span>
<span id="cb6-26"><a href="#cb6-26" aria-hidden="true" tabindex="-1"></a>        annotated_frame <span class="op">=</span> results[<span class="dv">0</span>].plot()</span>
<span id="cb6-27"><a href="#cb6-27" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-28"><a href="#cb6-28" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Save or display frame</span></span>
<span id="cb6-29"><a href="#cb6-29" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> output_path:</span>
<span id="cb6-30"><a href="#cb6-30" aria-hidden="true" tabindex="-1"></a>            out.write(annotated_frame)</span>
<span id="cb6-31"><a href="#cb6-31" aria-hidden="true" tabindex="-1"></a>        <span class="cf">else</span>:</span>
<span id="cb6-32"><a href="#cb6-32" aria-hidden="true" tabindex="-1"></a>            cv2.imshow(<span class="st">'YOLO Detection'</span>, annotated_frame)</span>
<span id="cb6-33"><a href="#cb6-33" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> cv2.waitKey(<span class="dv">1</span>) <span class="op">&amp;</span> <span class="bn">0xFF</span> <span class="op">==</span> <span class="bu">ord</span>(<span class="st">'q'</span>):</span>
<span id="cb6-34"><a href="#cb6-34" aria-hidden="true" tabindex="-1"></a>                <span class="cf">break</span></span>
<span id="cb6-35"><a href="#cb6-35" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-36"><a href="#cb6-36" aria-hidden="true" tabindex="-1"></a>    cap.release()</span>
<span id="cb6-37"><a href="#cb6-37" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> output_path:</span>
<span id="cb6-38"><a href="#cb6-38" aria-hidden="true" tabindex="-1"></a>        out.release()</span>
<span id="cb6-39"><a href="#cb6-39" aria-hidden="true" tabindex="-1"></a>    cv2.destroyAllWindows()</span>
<span id="cb6-40"><a href="#cb6-40" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-41"><a href="#cb6-41" aria-hidden="true" tabindex="-1"></a><span class="co"># Example usage</span></span>
<span id="cb6-42"><a href="#cb6-42" aria-hidden="true" tabindex="-1"></a>process_video(<span class="st">'input_video.mp4'</span>, <span class="st">'output_video.mp4'</span>)</span></code></pre></div></div>
</section>
<section id="real-time-webcam-detection" class="level3">
<h3 class="anchored" data-anchor-id="real-time-webcam-detection" id="real-time-webcam-detection">Real-time Webcam Detection</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> real_time_detection():</span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a><span class="co">    Real-time object detection from webcam</span></span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a>    cap <span class="op">=</span> cv2.VideoCapture(<span class="dv">0</span>)  <span class="co"># Use 0 for default camera</span></span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a>    <span class="cf">while</span> <span class="va">True</span>:</span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a>        ret, frame <span class="op">=</span> cap.read()</span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="kw">not</span> ret:</span>
<span id="cb7-10"><a href="#cb7-10" aria-hidden="true" tabindex="-1"></a>            <span class="cf">break</span></span>
<span id="cb7-11"><a href="#cb7-11" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-12"><a href="#cb7-12" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Run YOLO detection</span></span>
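<span id="cb7-13"><a href="#cb7-13" aria-hidden="true" tabindex="-1"></a>        results <span class="op">=</span> model(frame, verbose<span class="op">=</span><span class="va">False</span>)  <span class="co"># suppress per-frame console logging</span></span>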
<span id="cb7-14"><a href="#cb7-14" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-15"><a href="#cb7-15" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Draw results</span></span>
<span id="cb7-16"><a href="#cb7-16" aria-hidden="true" tabindex="-1"></a>        annotated_frame <span class="op">=</span> results[<span class="dv">0</span>].plot()</span>
<span id="cb7-17"><a href="#cb7-17" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-18"><a href="#cb7-18" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Display frame</span></span>
<span id="cb7-19"><a href="#cb7-19" aria-hidden="true" tabindex="-1"></a>        cv2.imshow(<span class="st">'Real-time YOLO Detection'</span>, annotated_frame)</span>
<span id="cb7-20"><a href="#cb7-20" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-21"><a href="#cb7-21" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Exit on 'q' key press</span></span>
<span id="cb7-22"><a href="#cb7-22" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> cv2.waitKey(<span class="dv">1</span>) <span class="op">&amp;</span> <span class="bn">0xFF</span> <span class="op">==</span> <span class="bu">ord</span>(<span class="st">'q'</span>):</span>
<span id="cb7-23"><a href="#cb7-23" aria-hidden="true" tabindex="-1"></a>            <span class="cf">break</span></span>
<span id="cb7-24"><a href="#cb7-24" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-25"><a href="#cb7-25" aria-hidden="true" tabindex="-1"></a>    cap.release()</span>
<span id="cb7-26"><a href="#cb7-26" aria-hidden="true" tabindex="-1"></a>    cv2.destroyAllWindows()</span>
<span id="cb7-27"><a href="#cb7-27" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-28"><a href="#cb7-28" aria-hidden="true" tabindex="-1"></a><span class="co"># Start real-time detection</span></span>
<span id="cb7-29"><a href="#cb7-29" aria-hidden="true" tabindex="-1"></a>real_time_detection()</span></code></pre></div></div>
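<p>To check whether the loop actually runs in real time on your hardware, you can time each iteration. The class below is a hypothetical, dependency-free FPS meter using an exponential moving average. You would call <code>tick()</code> once per frame inside the <code>while</code> loop and could draw the value with, for example, <code>cv2.putText(annotated_frame, f"FPS: {fps:.1f}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)</code>. The smoothing factor of 0.9 is an arbitrary choice.</p>

```python
import time


class FPSMeter:
    """Smoothed frames-per-second estimate (exponential moving average)."""

    def __init__(self, smoothing: float = 0.9):
        self.smoothing = smoothing
        self.fps = 0.0
        self._last = None

    def tick(self, now=None) -> float:
        """Record one frame and return the smoothed FPS.

        `now` can be injected (in seconds) for testing; otherwise a
        monotonic high-resolution clock is used.
        """
        now = time.perf_counter() if now is None else now
        if self._last is not None:
            dt = now - self._last
            if dt > 0:
                instant = 1.0 / dt
                # The first measurement seeds the average; later ones are blended in
                if self.fps == 0.0:
                    self.fps = instant
                else:
                    self.fps = self.smoothing * self.fps + (1 - self.smoothing) * instant
        self._last = now
        return self.fps


if __name__ == "__main__":
    meter = FPSMeter()
    # Simulated timestamps: one frame every 1/30 s
    for i in range(10):
        fps = meter.tick(now=i / 30)
    print(round(fps, 1))  # 30.0
```

<p>Smoothing avoids a jittery readout, since per-frame inference time varies with the number of detections and background system load.</p>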
</section>
</section>
<section id="custom-training" class="level2">
<h2 class="anchored" data-anchor-id="custom-training" id="custom-training">5. Custom Dataset Training</h2>
<section id="dataset-preparation" class="level3">
<h3 class="anchored" data-anchor-id="dataset-preparation" id="dataset-preparation">Dataset Preparation</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb8"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> os</span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> shutil</span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> pathlib <span class="im">import</span> Path</span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> create_dataset_structure(base_path):</span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a><span class="co">    Create YOLO dataset structure</span></span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a>    paths <span class="op">=</span> [</span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a>        <span class="st">'train/images'</span>,</span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a>        <span class="st">'train/labels'</span>,</span>
<span id="cb8-12"><a href="#cb8-12" aria-hidden="true" tabindex="-1"></a>        <span class="st">'val/images'</span>,</span>
<span id="cb8-13"><a href="#cb8-13" aria-hidden="true" tabindex="-1"></a>        <span class="st">'val/labels'</span>,</span>
<span id="cb8-14"><a href="#cb8-14" aria-hidden="true" tabindex="-1"></a>        <span class="st">'test/images'</span>,</span>
<span id="cb8-15"><a href="#cb8-15" aria-hidden="true" tabindex="-1"></a>        <span class="st">'test/labels'</span></span>
<span id="cb8-16"><a href="#cb8-16" aria-hidden="true" tabindex="-1"></a>    ]</span>
<span id="cb8-17"><a href="#cb8-17" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-18"><a href="#cb8-18" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> path <span class="kw">in</span> paths:</span>
<span id="cb8-19"><a href="#cb8-19" aria-hidden="true" tabindex="-1"></a>        Path(base_path <span class="op">/</span> path).mkdir(parents<span class="op">=</span><span class="va">True</span>, exist_ok<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb8-20"><a href="#cb8-20" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-21"><a href="#cb8-21" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Dataset structure created at </span><span class="sc">{</span>base_path<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb8-22"><a href="#cb8-22" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-23"><a href="#cb8-23" aria-hidden="true" tabindex="-1"></a><span class="co"># Create dataset structure</span></span>
<span id="cb8-24"><a href="#cb8-24" aria-hidden="true" tabindex="-1"></a>dataset_path <span class="op">=</span> Path(<span class="st">'custom_dataset'</span>)</span>
<span id="cb8-25"><a href="#cb8-25" aria-hidden="true" tabindex="-1"></a>create_dataset_structure(dataset_path)</span></code></pre></div></div>
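<p>Each image in an <code>images/</code> folder needs a same-named <code>.txt</code> file in the matching <code>labels/</code> folder. That file holds one line per object in the form <code>class_id x_center y_center width height</code>, with all coordinates normalized to [0, 1] by the image size. The sketch below converts a pixel-space (x1, y1, x2, y2) box into such a line; the function name is hypothetical.</p>

```python
def to_yolo_label(class_id: int, box, img_w: int, img_h: int) -> str:
    """Convert a pixel box (x1, y1, x2, y2) to a YOLO label line.

    Output: "class_id x_center y_center width height", normalized to [0, 1].
    """
    x1, y1, x2, y2 = box
    # Clip to image bounds so normalized values stay within [0, 1]
    x1, x2 = max(0, min(x1, img_w)), max(0, min(x2, img_w))
    y1, y2 = max(0, min(y1, img_h)), max(0, min(y2, img_h))

    x_center = (x1 + x2) / 2 / img_w
    y_center = (y1 + y2) / 2 / img_h
    width = (x2 - x1) / img_w
    height = (y2 - y1) / img_h
    return f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"


if __name__ == "__main__":
    # A 200x100 px box in a 640x480 image, class 1 ('car' in the data.yaml example)
    print(to_yolo_label(1, (100, 50, 300, 150), img_w=640, img_h=480))
    # 1 0.312500 0.208333 0.312500 0.208333
```

<p>For an image saved as <code>train/images/img_001.jpg</code>, these lines would go into <code>train/labels/img_001.txt</code>.</p>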
</section>
<section id="data-configuration-file" class="level3">
<h3 class="anchored" data-anchor-id="data-configuration-file" id="data-configuration-file">Data Configuration File</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb9"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="co"># data.yaml</span></span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a><span class="fu">train</span><span class="kw">:</span><span class="at"> ../custom_dataset/train/images</span></span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a><span class="fu">val</span><span class="kw">:</span><span class="at"> ../custom_dataset/val/images</span></span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a><span class="fu">test</span><span class="kw">:</span><span class="at"> ../custom_dataset/test/images</span></span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a><span class="fu">nc</span><span class="kw">:</span><span class="at"> </span><span class="dv">3</span><span class="co">  # number of classes</span></span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a><span class="fu">names</span><span class="kw">:</span><span class="at"> </span><span class="kw">[</span><span class="st">'person'</span><span class="kw">,</span><span class="at"> </span><span class="st">'car'</span><span class="kw">,</span><span class="at"> </span><span class="st">'bicycle'</span><span class="kw">]</span><span class="co">  # class names</span></span></code></pre></div></div>
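<p>Since <code>nc</code> must equal the number of entries in <code>names</code>, it can help to generate the file from a single list of class names. The sketch below writes the same layout using only the standard library (a YAML library such as PyYAML would work equally well). The function name, its default paths, and the assumption that class names contain no quote characters are all illustrative choices mirroring the example above.</p>

```python
from pathlib import Path


def write_data_yaml(class_names, out_path="data.yaml",
                    root="../custom_dataset") -> str:
    """Write a YOLO data config whose `nc` always matches `names`."""
    if len(set(class_names)) != len(class_names):
        raise ValueError("class names must be unique")
    # Assumes plain class names without quote characters
    quoted = ", ".join(f"'{name}'" for name in class_names)
    text = (
        f"train: {root}/train/images\n"
        f"val: {root}/val/images\n"
        f"test: {root}/test/images\n"
        "\n"
        f"nc: {len(class_names)}  # number of classes\n"
        f"names: [{quoted}]  # class names\n"
    )
    Path(out_path).write_text(text)
    return text


if __name__ == "__main__":
    print(write_data_yaml(["person", "car", "bicycle"]))
```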
</section>
<section id="training-script" class="level3">
<h3 class="anchored" data-anchor-id="training-script" id="training-script">Training Script</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb10"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> ultralytics <span class="im">import</span> YOLO</span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> train_custom_model():</span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb10-6"><a href="#cb10-6" aria-hidden="true" tabindex="-1"></a><span class="co">    Train YOLO model on custom dataset</span></span>
<span id="cb10-7"><a href="#cb10-7" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb10-8"><a href="#cb10-8" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Load a pre-trained model</span></span>
<span id="cb10-9"><a href="#cb10-9" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> YOLO(<span class="st">'yolov8n.pt'</span>)</span>
<span id="cb10-10"><a href="#cb10-10" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-11"><a href="#cb10-11" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Train the model</span></span>
<span id="cb10-12"><a href="#cb10-12" aria-hidden="true" tabindex="-1"></a>    results <span class="op">=</span> model.train(</span>
<span id="cb10-13"><a href="#cb10-13" aria-hidden="true" tabindex="-1"></a>        data<span class="op">=</span><span class="st">'data.yaml'</span>,           <span class="co"># dataset config file</span></span>
<span id="cb10-14"><a href="#cb10-14" aria-hidden="true" tabindex="-1"></a>        epochs<span class="op">=</span><span class="dv">100</span>,                 <span class="co"># number of training epochs</span></span>
<span id="cb10-15"><a href="#cb10-15" aria-hidden="true" tabindex="-1"></a>        imgsz<span class="op">=</span><span class="dv">640</span>,                  <span class="co"># image size</span></span>
<span id="cb10-16"><a href="#cb10-16" aria-hidden="true" tabindex="-1"></a>        batch<span class="op">=</span><span class="dv">16</span>,                   <span class="co"># batch size</span></span>
<span id="cb10-17"><a href="#cb10-17" aria-hidden="true" tabindex="-1"></a>        device<span class="op">=</span><span class="st">'cuda'</span> <span class="cf">if</span> torch.cuda.is_available() <span class="cf">else</span> <span class="st">'cpu'</span>,</span>
<span id="cb10-18"><a href="#cb10-18" aria-hidden="true" tabindex="-1"></a>        workers<span class="op">=</span><span class="dv">4</span>,                  <span class="co"># number of data loader workers</span></span>
<span id="cb10-19"><a href="#cb10-19" aria-hidden="true" tabindex="-1"></a>        project<span class="op">=</span><span class="st">'runs/train'</span>,       <span class="co"># project directory</span></span>
<span id="cb10-20"><a href="#cb10-20" aria-hidden="true" tabindex="-1"></a>        name<span class="op">=</span><span class="st">'custom_model'</span>,        <span class="co"># experiment name</span></span>
<span id="cb10-21"><a href="#cb10-21" aria-hidden="true" tabindex="-1"></a>        save<span class="op">=</span><span class="va">True</span>,                  <span class="co"># save model checkpoints</span></span>
<span id="cb10-22"><a href="#cb10-22" aria-hidden="true" tabindex="-1"></a>        save_period<span class="op">=</span><span class="dv">10</span>,             <span class="co"># save checkpoint every N epochs</span></span>
<span id="cb10-23"><a href="#cb10-23" aria-hidden="true" tabindex="-1"></a>        cache<span class="op">=</span><span class="va">True</span>,                 <span class="co"># cache images for faster training</span></span>
<span id="cb10-24"><a href="#cb10-24" aria-hidden="true" tabindex="-1"></a>        <span class="co"># training augmentation is on by default; tune it with hsv_h, mosaic, etc.</span></span>
<span id="cb10-25"><a href="#cb10-25" aria-hidden="true" tabindex="-1"></a>        lr0<span class="op">=</span><span class="fl">0.01</span>,                   <span class="co"># initial learning rate</span></span>
<span id="cb10-26"><a href="#cb10-26" aria-hidden="true" tabindex="-1"></a>        weight_decay<span class="op">=</span><span class="fl">0.0005</span>,        <span class="co"># weight decay</span></span>
<span id="cb10-27"><a href="#cb10-27" aria-hidden="true" tabindex="-1"></a>        warmup_epochs<span class="op">=</span><span class="dv">3</span>,            <span class="co"># warmup epochs</span></span>
<span id="cb10-28"><a href="#cb10-28" aria-hidden="true" tabindex="-1"></a>        patience<span class="op">=</span><span class="dv">50</span>,                <span class="co"># early stopping patience</span></span>
<span id="cb10-29"><a href="#cb10-29" aria-hidden="true" tabindex="-1"></a>        verbose<span class="op">=</span><span class="va">True</span>                <span class="co"># verbose output</span></span>
<span id="cb10-30"><a href="#cb10-30" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb10-31"><a href="#cb10-31" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-32"><a href="#cb10-32" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> results</span>
<span id="cb10-33"><a href="#cb10-33" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-34"><a href="#cb10-34" aria-hidden="true" tabindex="-1"></a><span class="co"># Start training</span></span>
<span id="cb10-35"><a href="#cb10-35" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb10-36"><a href="#cb10-36" aria-hidden="true" tabindex="-1"></a>    results <span class="op">=</span> train_custom_model()</span>
<span id="cb10-37"><a href="#cb10-37" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="st">"Training completed!"</span>)</span></code></pre></div></div>
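<p>The <code>patience=50</code> argument ends training early once the validation fitness has not improved for 50 consecutive epochs. The class below is a simplified, framework-independent sketch of that rule. The <code>PatienceStopper</code> name and its <code>step</code> method are hypothetical and are not the Ultralytics implementation.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">class PatienceStopper:
    """Stop when the monitored score has not improved for `patience` epochs."""

    def __init__(self, patience=50):
        self.patience = patience
        self.best_score = float("-inf")
        self.best_epoch = 0

    def step(self, epoch, score):
        """Record this epoch's score; return True if training should stop."""
        if score > self.best_score:      # higher is better, e.g. mAP
            self.best_score = score
            self.best_epoch = epoch
        return epoch - self.best_epoch >= self.patience


stopper = PatienceStopper(patience=3)
scores = [0.30, 0.42, 0.45, 0.44, 0.45, 0.43, 0.50]
for epoch, score in enumerate(scores):
    if stopper.step(epoch, score):
        print(f"Stopping at epoch {epoch}; best was epoch {stopper.best_epoch}")
        break</code></pre></div>
<p>Here the best score (0.45) is reached at epoch 2. A later equal score does not count as an improvement, so training stops at epoch 5 and never sees the 0.50 at epoch 6. This is the trade-off that a larger <code>patience</code> value softens.</p>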
</section>
<section id="data-augmentation" class="level3">
<h3 class="anchored" data-anchor-id="data-augmentation" id="data-augmentation">Data Augmentation</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb11"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Custom augmentation configuration</span></span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a>augmentation_config <span class="op">=</span> {</span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a>    <span class="st">'hsv_h'</span>: <span class="fl">0.015</span>,      <span class="co"># HSV-Hue augmentation</span></span>
<span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a>    <span class="st">'hsv_s'</span>: <span class="fl">0.7</span>,        <span class="co"># HSV-Saturation augmentation</span></span>
<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a>    <span class="st">'hsv_v'</span>: <span class="fl">0.4</span>,        <span class="co"># HSV-Value augmentation</span></span>
<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a>    <span class="st">'degrees'</span>: <span class="fl">10.0</span>,     <span class="co"># image rotation (+/- degrees)</span></span>
<span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a>    <span class="st">'translate'</span>: <span class="fl">0.1</span>,    <span class="co"># image translation (+/- fraction of image size)</span></span>
<span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a>    <span class="st">'scale'</span>: <span class="fl">0.5</span>,        <span class="co"># image scale (+/- gain)</span></span>
<span id="cb11-9"><a href="#cb11-9" aria-hidden="true" tabindex="-1"></a>    <span class="st">'shear'</span>: <span class="fl">2.0</span>,        <span class="co"># image shear (+/- degrees)</span></span>
<span id="cb11-10"><a href="#cb11-10" aria-hidden="true" tabindex="-1"></a>    <span class="st">'perspective'</span>: <span class="fl">0.0</span>,  <span class="co"># image perspective (+/- fraction); 0.0 disables it</span></span>
<span id="cb11-11"><a href="#cb11-11" aria-hidden="true" tabindex="-1"></a>    <span class="st">'flipud'</span>: <span class="fl">0.0</span>,       <span class="co"># flip up-down probability</span></span>
<span id="cb11-12"><a href="#cb11-12" aria-hidden="true" tabindex="-1"></a>    <span class="st">'fliplr'</span>: <span class="fl">0.5</span>,       <span class="co"># flip left-right probability</span></span>
<span id="cb11-13"><a href="#cb11-13" aria-hidden="true" tabindex="-1"></a>    <span class="st">'mosaic'</span>: <span class="fl">1.0</span>,       <span class="co"># mosaic probability</span></span>
<span id="cb11-14"><a href="#cb11-14" aria-hidden="true" tabindex="-1"></a>    <span class="st">'mixup'</span>: <span class="fl">0.1</span>,        <span class="co"># mixup probability</span></span>
<span id="cb11-15"><a href="#cb11-15" aria-hidden="true" tabindex="-1"></a>    <span class="st">'copy_paste'</span>: <span class="fl">0.1</span>    <span class="co"># copy-paste probability</span></span>
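<span id="cb11-16"><a href="#cb11-16" aria-hidden="true" tabindex="-1"></a>}</span>
<span id="cb11-17"><a href="#cb11-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-18"><a href="#cb11-18" aria-hidden="true" tabindex="-1"></a><span class="co"># These keys are regular training arguments, so they can be passed directly:</span></span>
<span id="cb11-19"><a href="#cb11-19" aria-hidden="true" tabindex="-1"></a><span class="co"># model.train(data='data.yaml', epochs=100, **augmentation_config)</span></span></code></pre></div></div>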
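<p>Geometric augmentations have to transform the labels together with the pixels. The following is a conceptual sketch of <code>fliplr</code>, not Ultralytics' own implementation. Mirroring an image left-right maps each normalized box center <code>x_center</code> to <code>1 - x_center</code>, while <code>y_center</code>, <code>width</code> and <code>height</code> stay the same. During training, such a flip would be applied at random with probability <code>fliplr</code>. The <code>flip_lr</code> helper is hypothetical.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import numpy as np


def flip_lr(image, labels):
    """Mirror an HxWxC image and its YOLO labels left-right.

    `labels` is an (N, 5) array of [class, x_center, y_center, width, height]
    with coordinates normalized to [0, 1]. Hypothetical helper.
    """
    flipped_image = image[:, ::-1].copy()              # reverse the width axis
    flipped_labels = labels.copy()
    flipped_labels[:, 1] = 1.0 - flipped_labels[:, 1]  # mirror x_center only
    return flipped_image, flipped_labels


rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(4, 6, 3), dtype=np.uint8)
labels = np.array([[1, 0.25, 0.5, 0.2, 0.4]])
new_image, new_labels = flip_lr(image, labels)
print(new_labels)  # x_center 0.25 becomes 0.75; the other columns are unchanged</code></pre></div>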
</section>
</section>
<section id="advanced-features" class="level2">
<h2 class="anchored" data-anchor-id="advanced-features" id="advanced-features">6. Advanced Features</h2>
<section id="model-validation-and-metrics" class="level3">
<h3 class="anchored" data-anchor-id="model-validation-and-metrics" id="model-validation-and-metrics">Model Validation and Metrics</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb12"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><a href="#cb12-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> validate_model(model_path, data_config):</span>
<span id="cb12-2"><a href="#cb12-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb12-3"><a href="#cb12-3" aria-hidden="true" tabindex="-1"></a><span class="co">    Validate trained model and get metrics</span></span>
<span id="cb12-4"><a href="#cb12-4" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb12-5"><a href="#cb12-5" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> YOLO(model_path)</span>
<span id="cb12-6"><a href="#cb12-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-7"><a href="#cb12-7" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Validate the model</span></span>
<span id="cb12-8"><a href="#cb12-8" aria-hidden="true" tabindex="-1"></a>    results <span class="op">=</span> model.val(</span>
<span id="cb12-9"><a href="#cb12-9" aria-hidden="true" tabindex="-1"></a>        data<span class="op">=</span>data_config,</span>
<span id="cb12-10"><a href="#cb12-10" aria-hidden="true" tabindex="-1"></a>        imgsz<span class="op">=</span><span class="dv">640</span>,</span>
<span id="cb12-11"><a href="#cb12-11" aria-hidden="true" tabindex="-1"></a>        batch<span class="op">=</span><span class="dv">16</span>,</span>
<span id="cb12-12"><a href="#cb12-12" aria-hidden="true" tabindex="-1"></a>        device<span class="op">=</span><span class="st">'cuda'</span> <span class="cf">if</span> torch.cuda.is_available() <span class="cf">else</span> <span class="st">'cpu'</span>,</span>
<span id="cb12-13"><a href="#cb12-13" aria-hidden="true" tabindex="-1"></a>        plots<span class="op">=</span><span class="va">True</span>,</span>
<span id="cb12-14"><a href="#cb12-14" aria-hidden="true" tabindex="-1"></a>        save_json<span class="op">=</span><span class="va">True</span></span>
<span id="cb12-15"><a href="#cb12-15" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb12-16"><a href="#cb12-16" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-17"><a href="#cb12-17" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Print metrics</span></span>
<span id="cb12-18"><a href="#cb12-18" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"mAP50: </span><span class="sc">{</span>results<span class="sc">.</span>box<span class="sc">.</span>map50<span class="sc">:.4f}</span><span class="ss">"</span>)</span>
<span id="cb12-19"><a href="#cb12-19" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"mAP50-95: </span><span class="sc">{</span>results<span class="sc">.</span>box<span class="sc">.</span><span class="bu">map</span><span class="sc">:.4f}</span><span class="ss">"</span>)</span>
<span id="cb12-20"><a href="#cb12-20" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Precision: </span><span class="sc">{</span>results<span class="sc">.</span>box<span class="sc">.</span>mp<span class="sc">:.4f}</span><span class="ss">"</span>)</span>
<span id="cb12-21"><a href="#cb12-21" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Recall: </span><span class="sc">{</span>results<span class="sc">.</span>box<span class="sc">.</span>mr<span class="sc">:.4f}</span><span class="ss">"</span>)</span>
<span id="cb12-22"><a href="#cb12-22" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-23"><a href="#cb12-23" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> results</span>
<span id="cb12-24"><a href="#cb12-24" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-25"><a href="#cb12-25" aria-hidden="true" tabindex="-1"></a><span class="co"># Validate model</span></span>
<span id="cb12-26"><a href="#cb12-26" aria-hidden="true" tabindex="-1"></a>validation_results <span class="op">=</span> validate_model(<span class="st">'runs/train/custom_model/weights/best.pt'</span>, <span class="st">'data.yaml'</span>)</span></code></pre></div></div>
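<p>These metrics rest on intersection over union (IoU). A prediction counts as a true positive at threshold <code>t</code> when its IoU with an unmatched ground-truth box of the same class is at least <code>t</code>. <code>mAP50</code> uses <code>t = 0.5</code>, while <code>mAP50-95</code> averages AP over the ten thresholds 0.50, 0.55, ..., 0.95 (the COCO convention). The sketch below computes IoU for <code>xyxy</code> boxes and shows which thresholds a single match passes. The <code>box_iou</code> helper is illustrative and is not the Ultralytics metrics code.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import numpy as np


def box_iou(a, b):
    """IoU of two boxes given as (x1, y1, x2, y2) in pixels."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)


# The ten IoU thresholds averaged by mAP50-95
thresholds = np.linspace(0.5, 0.95, 10)

pred = (50, 50, 150, 150)
truth = (60, 60, 160, 160)
iou = box_iou(pred, truth)                 # 8100 / 11900, about 0.68
passed = thresholds[iou >= thresholds]     # thresholds this match satisfies
print(f"IoU = {iou:.3f}; true positive at thresholds {np.round(passed, 2)}")</code></pre></div>
<p>A box that is only slightly offset still counts at the loose thresholds (0.50 to 0.65) but fails the strict ones. This is why <code>mAP50-95</code> is typically much lower than <code>mAP50</code>, and why it rewards precise localization.</p>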
</section>
<section id="model-export-and-optimization" class="level3">
<h3 class="anchored" data-anchor-id="model-export-and-optimization" id="model-export-and-optimization">Model Export and Optimization</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb13"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> export_model(model_path, export_format<span class="op">=</span><span class="st">'onnx'</span>):</span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a><span class="co">    Export model to different formats</span></span>
<span id="cb13-4"><a href="#cb13-4" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb13-5"><a href="#cb13-5" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> YOLO(model_path)</span>
<span id="cb13-6"><a href="#cb13-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-7"><a href="#cb13-7" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Map friendly names to Ultralytics export format strings</span></span>
<span id="cb13-8"><a href="#cb13-8" aria-hidden="true" tabindex="-1"></a>    export_formats <span class="op">=</span> {</span>
<span id="cb13-9"><a href="#cb13-9" aria-hidden="true" tabindex="-1"></a>        <span class="st">'onnx'</span>: <span class="st">'onnx'</span>,</span>
<span id="cb13-10"><a href="#cb13-10" aria-hidden="true" tabindex="-1"></a>        <span class="st">'torchscript'</span>: <span class="st">'torchscript'</span>,</span>
<span id="cb13-11"><a href="#cb13-11" aria-hidden="true" tabindex="-1"></a>        <span class="st">'tflite'</span>: <span class="st">'tflite'</span>,</span>
<span id="cb13-12"><a href="#cb13-12" aria-hidden="true" tabindex="-1"></a>        <span class="st">'tensorrt'</span>: <span class="st">'engine'</span>,  <span class="co"># TensorRT</span></span>
<span id="cb13-13"><a href="#cb13-13" aria-hidden="true" tabindex="-1"></a>        <span class="st">'openvino'</span>: <span class="st">'openvino'</span>,</span>
<span id="cb13-14"><a href="#cb13-14" aria-hidden="true" tabindex="-1"></a>        <span class="st">'coreml'</span>: <span class="st">'coreml'</span></span>
<span id="cb13-15"><a href="#cb13-15" aria-hidden="true" tabindex="-1"></a>    }</span>
<span id="cb13-16"><a href="#cb13-16" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Export only the requested format (several exporters need extra packages or hardware)</span></span>
<span id="cb13-17"><a href="#cb13-17" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> model.export(<span class="bu">format</span><span class="op">=</span>export_formats[export_format])</span>
<span id="cb13-18"><a href="#cb13-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-19"><a href="#cb13-19" aria-hidden="true" tabindex="-1"></a><span class="co"># Export to ONNX (returns the path of the exported file)</span></span>
<span id="cb13-20"><a href="#cb13-20" aria-hidden="true" tabindex="-1"></a>onnx_path <span class="op">=</span> export_model(<span class="st">'runs/train/custom_model/weights/best.pt'</span>, <span class="st">'onnx'</span>)</span></code></pre></div></div>
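<p>Exported models usually have a fixed input shape (640x640 here, matching <code>imgsz</code>). Images of other sizes therefore have to be resized with their aspect ratio preserved and then padded, a step often called letterboxing, and the predicted boxes must be mapped back to the original image afterwards. The sketch below computes only the scale, the padding and the inverse box mapping. The function names are hypothetical, and centered padding is an assumption about how your preprocessing is set up.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">def letterbox_params(img_w, img_h, target=640):
    """Scale factor and (left, top) padding that fit an image into target x target."""
    scale = min(target / img_w, target / img_h)      # preserve aspect ratio
    new_w, new_h = round(img_w * scale), round(img_h * scale)
    pad_x = (target - new_w) / 2                     # padding split evenly
    pad_y = (target - new_h) / 2
    return scale, pad_x, pad_y


def unletterbox_box(box, scale, pad_x, pad_y):
    """Map an (x1, y1, x2, y2) box from model input space back to the original image."""
    x1, y1, x2, y2 = box
    return ((x1 - pad_x) / scale, (y1 - pad_y) / scale,
            (x2 - pad_x) / scale, (y2 - pad_y) / scale)


# A 1280x720 frame is scaled by 0.5 to 640x360, then padded 140 px top and bottom
scale, pad_x, pad_y = letterbox_params(1280, 720)
print(scale, pad_x, pad_y)
print(unletterbox_box((100, 200, 300, 400), scale, pad_x, pad_y))</code></pre></div>
<p>Using the same <code>scale</code> and padding in both directions is what keeps boxes aligned with the original frame. If your preprocessing pads only one side, adjust <code>pad_x</code> and <code>pad_y</code> accordingly.</p>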
</section>
<section id="hyperparameter-tuning" class="level3">
<h3 class="anchored" data-anchor-id="hyperparameter-tuning" id="hyperparameter-tuning">Hyperparameter Tuning</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb14"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb14-1"><a href="#cb14-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> hyperparameter_tuning():</span>
<span id="cb14-2"><a href="#cb14-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb14-3"><a href="#cb14-3" aria-hidden="true" tabindex="-1"></a><span class="co">    Automated hyperparameter tuning</span></span>
<span id="cb14-4"><a href="#cb14-4" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb14-5"><a href="#cb14-5" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> YOLO(<span class="st">'yolov8n.pt'</span>)</span>
<span id="cb14-6"><a href="#cb14-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb14-7"><a href="#cb14-7" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Tune hyperparameters</span></span>
<span id="cb14-8"><a href="#cb14-8" aria-hidden="true" tabindex="-1"></a>    model.tune(</span>
<span id="cb14-9"><a href="#cb14-9" aria-hidden="true" tabindex="-1"></a>        data<span class="op">=</span><span class="st">'data.yaml'</span>,</span>
<span id="cb14-10"><a href="#cb14-10" aria-hidden="true" tabindex="-1"></a>        epochs<span class="op">=</span><span class="dv">30</span>,</span>
<span id="cb14-11"><a href="#cb14-11" aria-hidden="true" tabindex="-1"></a>        iterations<span class="op">=</span><span class="dv">300</span>,</span>
<span id="cb14-12"><a href="#cb14-12" aria-hidden="true" tabindex="-1"></a>        optimizer<span class="op">=</span><span class="st">'AdamW'</span>,</span>
<span id="cb14-13"><a href="#cb14-13" aria-hidden="true" tabindex="-1"></a>        plots<span class="op">=</span><span class="va">True</span>,</span>
<span id="cb14-14"><a href="#cb14-14" aria-hidden="true" tabindex="-1"></a>        save<span class="op">=</span><span class="va">True</span></span>
<span id="cb14-15"><a href="#cb14-15" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb14-16"><a href="#cb14-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-17"><a href="#cb14-17" aria-hidden="true" tabindex="-1"></a><span class="co"># Run hyperparameter tuning</span></span>
<span id="cb14-18"><a href="#cb14-18" aria-hidden="true" tabindex="-1"></a>hyperparameter_tuning()</span></code></pre></div></div>
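<p><code>model.tune</code> runs an evolutionary search. Roughly, it mutates the best hyperparameters found so far, trains for <code>epochs</code> epochs, scores the result, and repeats for <code>iterations</code> rounds. The snippet below is a simplified, framework-independent sketch of a single mutation step. <code>SEARCH_SPACE</code>, its bounds and <code>mutate</code> are hypothetical and do not reproduce the search space or mutation rules of the Ultralytics <code>Tuner</code>.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import random

# Hypothetical search space: (lower bound, upper bound) per hyperparameter
SEARCH_SPACE = {
    "lr0": (1e-5, 1e-1),
    "momentum": (0.6, 0.98),
    "weight_decay": (0.0, 0.001),
    "fliplr": (0.0, 1.0),
}


def mutate(params, sigma=0.2, rng=random):
    """Return a copy of params with each value scaled by Gaussian noise, then clipped."""
    child = {}
    for name, value in params.items():
        low, high = SEARCH_SPACE[name]
        factor = 1.0 + rng.gauss(0.0, sigma)                # multiplicative noise around 1
        child[name] = min(high, max(low, value * factor))   # keep inside bounds
    return child


best = {"lr0": 0.01, "momentum": 0.937, "weight_decay": 0.0005, "fliplr": 0.5}
rng = random.Random(42)
# In a real search, each candidate would be trained and scored, and the best kept
for _ in range(3):
    print(mutate(best, rng=rng))</code></pre></div>
<p>Multiplicative noise keeps each mutation proportional to the current value, which suits parameters such as <code>lr0</code> whose sensible range spans orders of magnitude.</p>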
</section>
</section>
<section id="performance-optimization" class="level2">
<h2 class="anchored" data-anchor-id="performance-optimization" id="performance-optimization">7. Performance Optimization</h2>
<section id="multi-gpu-training" class="level3">
<h3 class="anchored" data-anchor-id="multi-gpu-training" id="multi-gpu-training">Multi-GPU Training</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb15"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb15-1"><a href="#cb15-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> multi_gpu_training():</span>
<span id="cb15-2"><a href="#cb15-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb15-3"><a href="#cb15-3" aria-hidden="true" tabindex="-1"></a><span class="co">    Training with multiple GPUs</span></span>
<span id="cb15-4"><a href="#cb15-4" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb15-5"><a href="#cb15-5" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> torch.cuda.device_count() <span class="op">&gt;</span> <span class="dv">1</span>:</span>
<span id="cb15-6"><a href="#cb15-6" aria-hidden="true" tabindex="-1"></a>        model <span class="op">=</span> YOLO(<span class="st">'yolov8n.pt'</span>)</span>
<span id="cb15-7"><a href="#cb15-7" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-8"><a href="#cb15-8" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Multi-GPU training</span></span>
<span id="cb15-9"><a href="#cb15-9" aria-hidden="true" tabindex="-1"></a>        results <span class="op">=</span> model.train(</span>
<span id="cb15-10"><a href="#cb15-10" aria-hidden="true" tabindex="-1"></a>            data<span class="op">=</span><span class="st">'data.yaml'</span>,</span>
<span id="cb15-11"><a href="#cb15-11" aria-hidden="true" tabindex="-1"></a>            epochs<span class="op">=</span><span class="dv">100</span>,</span>
<span id="cb15-12"><a href="#cb15-12" aria-hidden="true" tabindex="-1"></a>            imgsz<span class="op">=</span><span class="dv">640</span>,</span>
<span id="cb15-13"><a href="#cb15-13" aria-hidden="true" tabindex="-1"></a>            batch<span class="op">=</span><span class="dv">32</span>,  <span class="co"># total batch size, shared across the GPUs</span></span>
<span id="cb15-14"><a href="#cb15-14" aria-hidden="true" tabindex="-1"></a>            device<span class="op">=</span><span class="bu">list</span>(<span class="bu">range</span>(torch.cuda.device_count())),  <span class="co"># every visible GPU, e.g. [0, 1, 2, 3]</span></span>
<span id="cb15-15"><a href="#cb15-15" aria-hidden="true" tabindex="-1"></a>            workers<span class="op">=</span><span class="dv">8</span></span>
<span id="cb15-16"><a href="#cb15-16" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb15-17"><a href="#cb15-17" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> results</span>
<span id="cb15-18"><a href="#cb15-18" aria-hidden="true" tabindex="-1"></a>    <span class="cf">else</span>:</span>
<span id="cb15-19"><a href="#cb15-19" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="st">"Multiple GPUs not available"</span>)</span>
<span id="cb15-20"><a href="#cb15-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-21"><a href="#cb15-21" aria-hidden="true" tabindex="-1"></a>multi_gpu_training()</span></code></pre></div></div>
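<p>With more than one GPU, Ultralytics trains with distributed data parallel (DDP). Our understanding is that <code>batch</code> is then the total across all GPUs, so each process loads about <code>batch // num_gpus</code> images per step. Verify this against your installed version. The hypothetical <code>plan_multi_gpu</code> helper below makes that arithmetic explicit and builds the <code>device</code> list. By design, it rejects splits that do not divide evenly.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">def plan_multi_gpu(total_batch, num_gpus):
    """Return (device list, per-GPU batch) for a DDP run (hypothetical helper)."""
    if num_gpus == 0:
        raise ValueError("no GPUs available")
    if total_batch % num_gpus:
        raise ValueError(f"batch={total_batch} is not divisible by {num_gpus} GPUs")
    return list(range(num_gpus)), total_batch // num_gpus


devices, per_gpu = plan_multi_gpu(32, 4)
print(devices, per_gpu)  # prints: [0, 1, 2, 3] 8</code></pre></div>
<p>With <code>batch=32</code> on four GPUs, each GPU then sees 8 images per step. Keeping that per-GPU number similar to your single-GPU batch is a reasonable starting point when scaling out.</p>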
</section>
<section id="inference-optimization" class="level3">
<h3 class="anchored" data-anchor-id="inference-optimization" id="inference-optimization">Inference Optimization</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb16"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb16-1"><a href="#cb16-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb16-2"><a href="#cb16-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb16-3"><a href="#cb16-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-4"><a href="#cb16-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> benchmark_model(model_path, test_images):</span>
<span id="cb16-5"><a href="#cb16-5" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb16-6"><a href="#cb16-6" aria-hidden="true" tabindex="-1"></a><span class="co">    Benchmark model performance</span></span>
<span id="cb16-7"><a href="#cb16-7" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb16-8"><a href="#cb16-8" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> YOLO(model_path)</span>
<span id="cb16-9"><a href="#cb16-9" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-10"><a href="#cb16-10" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Warm up</span></span>
<span id="cb16-11"><a href="#cb16-11" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">10</span>):</span>
<span id="cb16-12"><a href="#cb16-12" aria-hidden="true" tabindex="-1"></a>        model(<span class="st">'path/to/test/image.jpg'</span>)</span>
<span id="cb16-13"><a href="#cb16-13" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-14"><a href="#cb16-14" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Benchmark</span></span>
<span id="cb16-15"><a href="#cb16-15" aria-hidden="true" tabindex="-1"></a>    times <span class="op">=</span> []</span>
<span id="cb16-16"><a href="#cb16-16" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> image_path <span class="kw">in</span> test_images:</span>
<span id="cb16-17"><a href="#cb16-17" aria-hidden="true" tabindex="-1"></a>        start_time <span class="op">=</span> time.time()</span>
<span id="cb16-18"><a href="#cb16-18" aria-hidden="true" tabindex="-1"></a>        results <span class="op">=</span> model(image_path)</span>
<span id="cb16-19"><a href="#cb16-19" aria-hidden="true" tabindex="-1"></a>        end_time <span class="op">=</span> time.time()</span>
<span id="cb16-20"><a href="#cb16-20" aria-hidden="true" tabindex="-1"></a>        times.append(end_time <span class="op">-</span> start_time)</span>
<span id="cb16-21"><a href="#cb16-21" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-22"><a href="#cb16-22" aria-hidden="true" tabindex="-1"></a>    avg_time <span class="op">=</span> np.mean(times)</span>
<span id="cb16-23"><a href="#cb16-23" aria-hidden="true" tabindex="-1"></a>    fps <span class="op">=</span> <span class="dv">1</span> <span class="op">/</span> avg_time</span>
<span id="cb16-24"><a href="#cb16-24" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-25"><a href="#cb16-25" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Average inference time: </span><span class="sc">{</span>avg_time<span class="sc">:.4f}</span><span class="ss"> seconds"</span>)</span>
<span id="cb16-26"><a href="#cb16-26" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"FPS: </span><span class="sc">{</span>fps<span class="sc">:.2f}</span><span class="ss">"</span>)</span>
<span id="cb16-27"><a href="#cb16-27" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-28"><a href="#cb16-28" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> avg_time, fps</span>
<span id="cb16-29"><a href="#cb16-29" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-30"><a href="#cb16-30" aria-hidden="true" tabindex="-1"></a><span class="co"># Benchmark your model</span></span>
<span id="cb16-31"><a href="#cb16-31" aria-hidden="true" tabindex="-1"></a>test_images <span class="op">=</span> [<span class="st">'test1.jpg'</span>, <span class="st">'test2.jpg'</span>, <span class="st">'test3.jpg'</span>]</span>
<span id="cb16-32"><a href="#cb16-32" aria-hidden="true" tabindex="-1"></a>avg_time, fps <span class="op">=</span> benchmark_model(<span class="st">'yolov8n.pt'</span>, test_images)</span></code></pre></div></div>
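<p>An average hides tail latency, which often matters more for real-time use. The sketch below uses a hypothetical <code>summarize_latencies</code> helper built on NumPy. Given per-image timings in seconds, like those a benchmark loop such as the one above collects, it reports the median and 95th-percentile latency alongside the mean and the throughput implied by the mean.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import numpy as np


def summarize_latencies(times_s):
    """Summarize per-image latencies (seconds) as milliseconds and FPS."""
    times_ms = np.asarray(times_s, dtype=float) * 1000.0
    mean_ms = float(times_ms.mean())
    return {
        "mean_ms": mean_ms,
        "p50_ms": float(np.percentile(times_ms, 50)),
        "p95_ms": float(np.percentile(times_ms, 95)),
        "fps": 1000.0 / mean_ms,   # throughput implied by the mean latency
    }


timings = [0.010, 0.011, 0.012, 0.010, 0.050]   # one slow outlier
for key, value in summarize_latencies(timings).items():
    print(f"{key}: {value:.2f}")</code></pre></div>
<p>In this example, a single slow frame pulls the mean well above the median. The 95th percentile makes the outlier visible, while the average alone would not.</p>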
</section>
<section id="memory-optimization" class="level3">
<h3 class="anchored" data-anchor-id="memory-optimization" id="memory-optimization">Memory Optimization</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb17"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb17-1"><a href="#cb17-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> memory_efficient_inference(model_path, image_path):</span>
<span id="cb17-2"><a href="#cb17-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb17-3"><a href="#cb17-3" aria-hidden="true" tabindex="-1"></a><span class="co">    Memory-efficient inference for large images</span></span>
<span id="cb17-4"><a href="#cb17-4" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb17-5"><a href="#cb17-5" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> YOLO(model_path)</span>
<span id="cb17-6"><a href="#cb17-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb17-7"><a href="#cb17-7" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Process image in tiles for large images</span></span>
<span id="cb17-8"><a href="#cb17-8" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> process_large_image(image_path, tile_size<span class="op">=</span><span class="dv">640</span>, overlap<span class="op">=</span><span class="fl">0.1</span>):</span>
<span id="cb17-9"><a href="#cb17-9" aria-hidden="true" tabindex="-1"></a>        img <span class="op">=</span> cv2.imread(image_path)</span>
<span id="cb17-10"><a href="#cb17-10" aria-hidden="true" tabindex="-1"></a>        h, w <span class="op">=</span> img.shape[:<span class="dv">2</span>]</span>
<span id="cb17-11"><a href="#cb17-11" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb17-12"><a href="#cb17-12" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> h <span class="op">&lt;=</span> tile_size <span class="kw">and</span> w <span class="op">&lt;=</span> tile_size:</span>
<span id="cb17-13"><a href="#cb17-13" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Small image, process normally</span></span>
<span id="cb17-14"><a href="#cb17-14" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> model(img)</span>
<span id="cb17-15"><a href="#cb17-15" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb17-16"><a href="#cb17-16" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Split into tiles</span></span>
<span id="cb17-17"><a href="#cb17-17" aria-hidden="true" tabindex="-1"></a>        results <span class="op">=</span> []</span>
<span id="cb17-18"><a href="#cb17-18" aria-hidden="true" tabindex="-1"></a>        step <span class="op">=</span> <span class="bu">int</span>(tile_size <span class="op">*</span> (<span class="dv">1</span> <span class="op">-</span> overlap))</span>
<span id="cb17-19"><a href="#cb17-19" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb17-20"><a href="#cb17-20" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> y <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">0</span>, h, step):</span>
<span id="cb17-21"><a href="#cb17-21" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> x <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">0</span>, w, step):</span>
<span id="cb17-22"><a href="#cb17-22" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Extract tile</span></span>
<span id="cb17-23"><a href="#cb17-23" aria-hidden="true" tabindex="-1"></a>                tile <span class="op">=</span> img[y:y<span class="op">+</span>tile_size, x:x<span class="op">+</span>tile_size]</span>
<span id="cb17-24"><a href="#cb17-24" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb17-25"><a href="#cb17-25" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Process tile</span></span>
<span id="cb17-26"><a href="#cb17-26" aria-hidden="true" tabindex="-1"></a>                tile_results <span class="op">=</span> model(tile)</span>
<span id="cb17-27"><a href="#cb17-27" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb17-28"><a href="#cb17-28" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Shift boxes into full-image coordinates (clone first: predictor outputs are inference-mode tensors)</span></span>
<span id="cb17-29"><a href="#cb17-29" aria-hidden="true" tabindex="-1"></a>                <span class="cf">for</span> result <span class="kw">in</span> tile_results:</span>
<span id="cb17-30"><a href="#cb17-30" aria-hidden="true" tabindex="-1"></a>                    <span class="cf">if</span> result.boxes <span class="kw">is</span> <span class="kw">not</span> <span class="va">None</span>:</span>
<span id="cb17-31"><a href="#cb17-31" aria-hidden="true" tabindex="-1"></a>                        data <span class="op">=</span> result.boxes.data.clone()</span>
<span id="cb17-32"><a href="#cb17-32" aria-hidden="true" tabindex="-1"></a>                        data[:, :<span class="dv">4</span>] <span class="op">+=</span> data.new_tensor([x, y, x, y])</span>
<span id="cb17-33"><a href="#cb17-33" aria-hidden="true" tabindex="-1"></a>                        result.boxes.data <span class="op">=</span> data</span>
<span id="cb17-34"><a href="#cb17-34" aria-hidden="true" tabindex="-1"></a>                results.extend(tile_results)</span>
<span id="cb17-35"><a href="#cb17-35" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb17-36"><a href="#cb17-36" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> results</span>
<span id="cb17-37"><a href="#cb17-37" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb17-38"><a href="#cb17-38" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> process_large_image(image_path)</span>
<span id="cb17-39"><a href="#cb17-39" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-40"><a href="#cb17-40" aria-hidden="true" tabindex="-1"></a><span class="co"># Process large image</span></span>
<span id="cb17-41"><a href="#cb17-41" aria-hidden="true" tabindex="-1"></a>large_image_results <span class="op">=</span> memory_efficient_inference(<span class="st">'yolov8n.pt'</span>, <span class="st">'large_image.jpg'</span>)</span></code></pre></div></div>
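<p>The tiling loop above has two rough edges: <code>range(0, h, step)</code> can leave thin slivers along the right and bottom borders, and an object that falls in the overlap between two tiles is reported twice. The NumPy sketch below addresses both, assuming detections have already been shifted into full-image <code>[x1, y1, x2, y2]</code> pixel coordinates. The helper names <code>tile_origins</code> and <code>merge_tile_detections</code> are illustrative, and the plain greedy, class-aware NMS stands in for whatever merging strategy your pipeline prefers.</p>

```python
import numpy as np


def tile_origins(length, tile_size, overlap=0.1):
    """Start offsets along one axis so that every tile is full-size
    (when the image is large enough) and the last tile ends at the border."""
    if tile_size >= length:
        return [0]
    step = max(1, int(tile_size * (1 - overlap)))
    origins = list(range(0, length - tile_size, step))
    origins.append(length - tile_size)  # clamp the final tile to the border
    return origins


def iou(box, boxes):
    """IoU between one [x1, y1, x2, y2] box and an (N, 4) array of boxes."""
    x1 = np.maximum(box[0], boxes[:, 0])
    y1 = np.maximum(box[1], boxes[:, 1])
    x2 = np.minimum(box[2], boxes[:, 2])
    y2 = np.minimum(box[3], boxes[:, 3])
    inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / (area + areas - inter + 1e-9)


def merge_tile_detections(boxes, scores, classes, iou_thresh=0.5):
    """Class-aware greedy NMS over detections gathered from all tiles.
    Returns the indices of the detections to keep, highest score first."""
    order = np.argsort(-scores)
    keep = []
    while order.size:
        best = order[0]
        keep.append(int(best))
        rest = order[1:]
        same_class = classes[rest] == classes[best]
        overlaps = iou(boxes[best], boxes[rest]) > iou_thresh
        # Drop same-class boxes that overlap the kept box too much
        suppress = np.logical_and(same_class, overlaps)
        order = rest[~suppress]
    return keep
```

<p>In <code>process_large_image</code>, the loops would iterate over <code>tile_origins(h, tile_size, overlap)</code> and <code>tile_origins(w, tile_size, overlap)</code> instead of <code>range(0, h, step)</code> and <code>range(0, w, step)</code>; the shifted boxes, confidences and class ids from all tiles would then be concatenated and passed to <code>merge_tile_detections</code>.</p>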
</section>
</section>
<section id="real-world-applications" class="level2">
<h2 class="anchored" data-anchor-id="real-world-applications" id="real-world-applications">8. Real-world Applications</h2>
<section id="security-camera-system" class="level3">
<h3 class="anchored" data-anchor-id="security-camera-system" id="security-camera-system">Security Camera System</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb18"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb18-1"><a href="#cb18-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> SecuritySystem:</span>
<span id="cb18-2"><a href="#cb18-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, model_path, camera_sources):</span>
<span id="cb18-3"><a href="#cb18-3" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> YOLO(model_path)</span>
<span id="cb18-4"><a href="#cb18-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.cameras <span class="op">=</span> camera_sources</span>
<span id="cb18-5"><a href="#cb18-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.alerts <span class="op">=</span> []</span>
<span id="cb18-6"><a href="#cb18-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-7"><a href="#cb18-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> monitor_cameras(<span class="va">self</span>):</span>
<span id="cb18-8"><a href="#cb18-8" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""</span></span>
<span id="cb18-9"><a href="#cb18-9" aria-hidden="true" tabindex="-1"></a><span class="co">        Monitor multiple camera feeds, reading one frame from each camera in turn</span></span>
<span id="cb18-10"><a href="#cb18-10" aria-hidden="true" tabindex="-1"></a><span class="co">        """</span></span>
<span id="cb18-11"><a href="#cb18-11" aria-hidden="true" tabindex="-1"></a>        caps <span class="op">=</span> {cid: cv2.VideoCapture(src) <span class="cf">for</span> cid, src <span class="kw">in</span> <span class="va">self</span>.cameras.items()}</span>
<span id="cb18-12"><a href="#cb18-12" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-13"><a href="#cb18-13" aria-hidden="true" tabindex="-1"></a>        <span class="cf">while</span> caps:</span>
<span id="cb18-14"><a href="#cb18-14" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> camera_id, cap <span class="kw">in</span> <span class="bu">list</span>(caps.items()):</span>
<span id="cb18-15"><a href="#cb18-15" aria-hidden="true" tabindex="-1"></a>                ret, frame <span class="op">=</span> cap.read()</span>
<span id="cb18-16"><a href="#cb18-16" aria-hidden="true" tabindex="-1"></a>                <span class="cf">if</span> <span class="kw">not</span> ret:  <span class="co"># Stream ended or failed: stop polling this camera</span></span>
<span id="cb18-17"><a href="#cb18-17" aria-hidden="true" tabindex="-1"></a>                    cap.release()</span>
<span id="cb18-18"><a href="#cb18-18" aria-hidden="true" tabindex="-1"></a>                    <span class="kw">del</span> caps[camera_id]</span>
<span id="cb18-19"><a href="#cb18-19" aria-hidden="true" tabindex="-1"></a>                    <span class="cf">continue</span></span>
<span id="cb18-20"><a href="#cb18-20" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb18-21"><a href="#cb18-21" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Detect objects</span></span>
<span id="cb18-22"><a href="#cb18-22" aria-hidden="true" tabindex="-1"></a>                results <span class="op">=</span> <span class="va">self</span>.model(frame, verbose<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb18-23"><a href="#cb18-23" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb18-24"><a href="#cb18-24" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Check for specific objects (class 0 is 'person' in COCO)</span></span>
<span id="cb18-25"><a href="#cb18-25" aria-hidden="true" tabindex="-1"></a>                <span class="cf">for</span> result <span class="kw">in</span> results:</span>
<span id="cb18-26"><a href="#cb18-26" aria-hidden="true" tabindex="-1"></a>                    <span class="cf">if</span> result.boxes <span class="kw">is</span> <span class="kw">not</span> <span class="va">None</span>:</span>
<span id="cb18-27"><a href="#cb18-27" aria-hidden="true" tabindex="-1"></a>                        classes <span class="op">=</span> result.boxes.cls.cpu().numpy()</span>
<span id="cb18-28"><a href="#cb18-28" aria-hidden="true" tabindex="-1"></a>                        <span class="cf">if</span> <span class="dv">0</span> <span class="kw">in</span> classes:  <span class="co"># Person detected</span></span>
<span id="cb18-29"><a href="#cb18-29" aria-hidden="true" tabindex="-1"></a>                            <span class="va">self</span>.trigger_alert(camera_id, frame)</span>
<span id="cb18-30"><a href="#cb18-30" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb18-31"><a href="#cb18-31" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Display annotated frame</span></span>
<span id="cb18-32"><a href="#cb18-32" aria-hidden="true" tabindex="-1"></a>                cv2.imshow(<span class="ss">f'Camera </span><span class="sc">{</span>camera_id<span class="sc">}</span><span class="ss">'</span>, results[<span class="dv">0</span>].plot())</span>
<span id="cb18-33"><a href="#cb18-33" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> cv2.waitKey(<span class="dv">1</span>) <span class="op">&amp;</span> <span class="bn">0xFF</span> <span class="op">==</span> <span class="bu">ord</span>(<span class="st">'q'</span>):</span>
<span id="cb18-34"><a href="#cb18-34" aria-hidden="true" tabindex="-1"></a>                <span class="cf">break</span></span>
<span id="cb18-35"><a href="#cb18-35" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-36"><a href="#cb18-36" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> cap <span class="kw">in</span> caps.values():</span>
<span id="cb18-37"><a href="#cb18-37" aria-hidden="true" tabindex="-1"></a>            cap.release()</span>
<span id="cb18-38"><a href="#cb18-38" aria-hidden="true" tabindex="-1"></a>        cv2.destroyAllWindows()</span>
<span id="cb18-39"><a href="#cb18-39" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-40"><a href="#cb18-40" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> trigger_alert(<span class="va">self</span>, camera_id, frame):</span>
<span id="cb18-41"><a href="#cb18-41" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""</span></span>
<span id="cb18-42"><a href="#cb18-42" aria-hidden="true" tabindex="-1"></a><span class="co">        Trigger security alert</span></span>
<span id="cb18-43"><a href="#cb18-43" aria-hidden="true" tabindex="-1"></a><span class="co">        """</span></span>
<span id="cb18-44"><a href="#cb18-44" aria-hidden="true" tabindex="-1"></a>        timestamp <span class="op">=</span> time.strftime(<span class="st">"%Y-%m-</span><span class="sc">%d</span><span class="st"> %H:%M:%S"</span>)</span>
<span id="cb18-45"><a href="#cb18-45" aria-hidden="true" tabindex="-1"></a>        alert <span class="op">=</span> {</span>
<span id="cb18-46"><a href="#cb18-46" aria-hidden="true" tabindex="-1"></a>            <span class="st">'camera_id'</span>: camera_id,</span>
<span id="cb18-47"><a href="#cb18-47" aria-hidden="true" tabindex="-1"></a>            <span class="st">'timestamp'</span>: timestamp,</span>
<span id="cb18-48"><a href="#cb18-48" aria-hidden="true" tabindex="-1"></a>            <span class="st">'frame'</span>: frame</span>
<span id="cb18-49"><a href="#cb18-49" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb18-50"><a href="#cb18-50" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.alerts.append(alert)</span>
<span id="cb18-51"><a href="#cb18-51" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"ALERT: Person detected on Camera </span><span class="sc">{</span>camera_id<span class="sc">}</span><span class="ss"> at </span><span class="sc">{</span>timestamp<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb18-52"><a href="#cb18-52" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-53"><a href="#cb18-53" aria-hidden="true" tabindex="-1"></a><span class="co"># Setup security system</span></span>
<span id="cb18-54"><a href="#cb18-54" aria-hidden="true" tabindex="-1"></a>cameras <span class="op">=</span> {</span>
<span id="cb18-55"><a href="#cb18-55" aria-hidden="true" tabindex="-1"></a>    <span class="st">'cam1'</span>: <span class="dv">0</span>,  <span class="co"># Webcam</span></span>
<span id="cb18-56"><a href="#cb18-56" aria-hidden="true" tabindex="-1"></a>    <span class="st">'cam2'</span>: <span class="st">'rtsp://camera2/stream'</span>,  <span class="co"># IP camera</span></span>
<span id="cb18-57"><a href="#cb18-57" aria-hidden="true" tabindex="-1"></a>}</span>
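<span id="cb18-58"><a href="#cb18-58" aria-hidden="true" tabindex="-1"></a>security <span class="op">=</span> SecuritySystem(<span class="st">'yolov8n.pt'</span>, cameras)</span>
<span id="cb18-59"><a href="#cb18-59" aria-hidden="true" tabindex="-1"></a>security.monitor_cameras()</span></code></pre></div></div>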
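<p>As written, <code>trigger_alert</code> runs on every frame that contains a person, so someone standing in view for ten seconds of a 30 FPS stream produces around 300 alerts. A per-camera cooldown is a common remedy. The sketch below is a hypothetical <code>AlertThrottle</code> helper, not part of Ultralytics or OpenCV, and its 30-second default cooldown is an arbitrary assumption.</p>

```python
import time


class AlertThrottle:
    """Allow at most one alert per camera within a cooldown window (hypothetical helper)."""

    def __init__(self, cooldown_s=30.0, clock=time.monotonic):
        self.cooldown_s = cooldown_s
        self.clock = clock    # injectable so the behaviour can be tested without sleeping
        self.last_alert = {}  # camera_id -> time of the last alert that was allowed

    def should_alert(self, camera_id):
        """Return True (and record the time) if this camera is outside its cooldown."""
        now = self.clock()
        last = self.last_alert.get(camera_id)
        if last is None or now - last >= self.cooldown_s:
            self.last_alert[camera_id] = now
            return True
        return False
```

<p>Created once in <code>__init__</code> (for example <code>self.throttle = AlertThrottle()</code>), it would guard the alert call: <code>if self.throttle.should_alert(camera_id): self.trigger_alert(camera_id, frame)</code>. Injecting the clock keeps the class testable without real waiting.</p>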
</section>
<section id="traffic-monitoring" class="level3">
<h3 class="anchored" data-anchor-id="traffic-monitoring" id="traffic-monitoring">Traffic Monitoring</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb19"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb19-1"><a href="#cb19-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> TrafficMonitor:</span>
<span id="cb19-2"><a href="#cb19-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, model_path):</span>
<span id="cb19-3"><a href="#cb19-3" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> YOLO(model_path)</span>
<span id="cb19-4"><a href="#cb19-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.vehicle_count <span class="op">=</span> <span class="dv">0</span>  <span class="co"># Peak number of vehicles seen in a single frame</span></span>
<span id="cb19-5"><a href="#cb19-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.speed_violations <span class="op">=</span> []  <span class="co"># Placeholder: speed estimation is not implemented in this example</span></span>
<span id="cb19-6"><a href="#cb19-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb19-7"><a href="#cb19-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> analyze_traffic(<span class="va">self</span>, video_path):</span>
<span id="cb19-8"><a href="#cb19-8" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""</span></span>
<span id="cb19-9"><a href="#cb19-9" aria-hidden="true" tabindex="-1"></a><span class="co">        Analyze traffic from video feed</span></span>
<span id="cb19-10"><a href="#cb19-10" aria-hidden="true" tabindex="-1"></a><span class="co">        """</span></span>
<span id="cb19-11"><a href="#cb19-11" aria-hidden="true" tabindex="-1"></a>        cap <span class="op">=</span> cv2.VideoCapture(video_path)</span>
<span id="cb19-12"><a href="#cb19-12" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb19-13"><a href="#cb19-13" aria-hidden="true" tabindex="-1"></a>        <span class="cf">while</span> <span class="va">True</span>:</span>
<span id="cb19-14"><a href="#cb19-14" aria-hidden="true" tabindex="-1"></a>            ret, frame <span class="op">=</span> cap.read()</span>
<span id="cb19-15"><a href="#cb19-15" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="kw">not</span> ret:</span>
<span id="cb19-16"><a href="#cb19-16" aria-hidden="true" tabindex="-1"></a>                <span class="cf">break</span></span>
<span id="cb19-17"><a href="#cb19-17" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb19-18"><a href="#cb19-18" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Detect vehicles</span></span>
<span id="cb19-19"><a href="#cb19-19" aria-hidden="true" tabindex="-1"></a>            results <span class="op">=</span> <span class="va">self</span>.model(frame)</span>
<span id="cb19-20"><a href="#cb19-20" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb19-21"><a href="#cb19-21" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Count vehicles</span></span>
<span id="cb19-22"><a href="#cb19-22" aria-hidden="true" tabindex="-1"></a>            vehicle_classes <span class="op">=</span> [<span class="dv">2</span>, <span class="dv">3</span>, <span class="dv">5</span>, <span class="dv">7</span>]  <span class="co"># car, motorcycle, bus, truck</span></span>
<span id="cb19-23"><a href="#cb19-23" aria-hidden="true" tabindex="-1"></a>            current_vehicles <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb19-24"><a href="#cb19-24" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb19-25"><a href="#cb19-25" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> result <span class="kw">in</span> results:</span>
<span id="cb19-26"><a href="#cb19-26" aria-hidden="true" tabindex="-1"></a>                <span class="cf">if</span> result.boxes <span class="kw">is</span> <span class="kw">not</span> <span class="va">None</span>:</span>
<span id="cb19-27"><a href="#cb19-27" aria-hidden="true" tabindex="-1"></a>                    classes <span class="op">=</span> result.boxes.cls.cpu().numpy()</span>
<span id="cb19-28"><a href="#cb19-28" aria-hidden="true" tabindex="-1"></a>                    current_vehicles <span class="op">+=</span> <span class="bu">sum</span>(<span class="dv">1</span> <span class="cf">for</span> cls <span class="kw">in</span> classes <span class="cf">if</span> cls <span class="kw">in</span> vehicle_classes)</span>
<span id="cb19-29"><a href="#cb19-29" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb19-30"><a href="#cb19-30" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.vehicle_count <span class="op">=</span> <span class="bu">max</span>(<span class="va">self</span>.vehicle_count, current_vehicles)</span>
<span id="cb19-31"><a href="#cb19-31" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb19-32"><a href="#cb19-32" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Display results</span></span>
<span id="cb19-33"><a href="#cb19-33" aria-hidden="true" tabindex="-1"></a>            annotated_frame <span class="op">=</span> results[<span class="dv">0</span>].plot()</span>
<span id="cb19-34"><a href="#cb19-34" aria-hidden="true" tabindex="-1"></a>            cv2.putText(annotated_frame, <span class="ss">f'Vehicles: </span><span class="sc">{</span>current_vehicles<span class="sc">}</span><span class="ss">'</span>, </span>
<span id="cb19-35"><a href="#cb19-35" aria-hidden="true" tabindex="-1"></a>                       (<span class="dv">10</span>, <span class="dv">30</span>), cv2.FONT_HERSHEY_SIMPLEX, <span class="dv">1</span>, (<span class="dv">0</span>, <span class="dv">255</span>, <span class="dv">0</span>), <span class="dv">2</span>)</span>
<span id="cb19-36"><a href="#cb19-36" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb19-37"><a href="#cb19-37" aria-hidden="true" tabindex="-1"></a>            cv2.imshow(<span class="st">'Traffic Monitor'</span>, annotated_frame)</span>
<span id="cb19-38"><a href="#cb19-38" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb19-39"><a href="#cb19-39" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> cv2.waitKey(<span class="dv">1</span>) <span class="op">&amp;</span> <span class="bn">0xFF</span> <span class="op">==</span> <span class="bu">ord</span>(<span class="st">'q'</span>):</span>
<span id="cb19-40"><a href="#cb19-40" aria-hidden="true" tabindex="-1"></a>                <span class="cf">break</span></span>
<span id="cb19-41"><a href="#cb19-41" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb19-42"><a href="#cb19-42" aria-hidden="true" tabindex="-1"></a>        cap.release()</span>
<span id="cb19-43"><a href="#cb19-43" aria-hidden="true" tabindex="-1"></a>        cv2.destroyAllWindows()</span>
<span id="cb19-44"><a href="#cb19-44" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb19-45"><a href="#cb19-45" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Maximum vehicles detected: </span><span class="sc">{</span><span class="va">self</span><span class="sc">.</span>vehicle_count<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb19-46"><a href="#cb19-46" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-47"><a href="#cb19-47" aria-hidden="true" tabindex="-1"></a><span class="co"># Monitor traffic</span></span>
<span id="cb19-48"><a href="#cb19-48" aria-hidden="true" tabindex="-1"></a>traffic_monitor <span class="op">=</span> TrafficMonitor(<span class="st">'yolov8n.pt'</span>)</span>
<span id="cb19-49"><a href="#cb19-49" aria-hidden="true" tabindex="-1"></a>traffic_monitor.analyze_traffic(<span class="st">'traffic_video.mp4'</span>)</span></code></pre></div></div>
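<p><code>TrafficMonitor</code> reports the peak number of vehicles visible in a single frame, not the number of distinct vehicles that passed. Counting traffic flow usually needs per-object track IDs (for example from the tracking mode exposed by Ultralytics as <code>model.track</code>) and a virtual counting line. The sketch below assumes you already have, for each frame, a mapping from track ID to box center <code>(x, y)</code>; the function name and this input format are illustrative.</p>

```python
def count_line_crossings(frames, line_y):
    """Count distinct track IDs whose center crosses the horizontal line y = line_y.

    frames: iterable of dicts {track_id: (x, y)}, one dict per video frame
    (assumed input format). Each track is counted at most once, in either direction.
    """
    last_y = {}      # track_id -> center y in the previous frame it appeared in
    counted = set()  # track IDs that have already been counted
    for detections in frames:
        for track_id, (_, y) in detections.items():
            prev = last_y.get(track_id)
            if prev is not None and track_id not in counted:
                # The side of the line (above vs. at-or-below) changed between frames
                if (prev >= line_y) != (y >= line_y):
                    counted.add(track_id)
            last_y[track_id] = y
    return len(counted)
```

<p>Counting each track ID once means a vehicle hovering on the line is not counted repeatedly; choosing the line position and coping with tracker ID switches are left outside this sketch.</p>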
</section>
<section id="quality-control-system" class="level3">
<h3 class="anchored" data-anchor-id="quality-control-system" id="quality-control-system">Quality Control System</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb20"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb20-1"><a href="#cb20-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> QualityControl:</span>
<span id="cb20-2"><a href="#cb20-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, model_path):</span>
<span id="cb20-3"><a href="#cb20-3" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> YOLO(model_path)</span>
<span id="cb20-4"><a href="#cb20-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.defect_log <span class="op">=</span> []</span>
<span id="cb20-5"><a href="#cb20-5" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb20-6"><a href="#cb20-6" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> inspect_products(<span class="va">self</span>, image_paths):</span>
<span id="cb20-7"><a href="#cb20-7" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""</span></span>
<span id="cb20-8"><a href="#cb20-8" aria-hidden="true" tabindex="-1"></a><span class="co">        Inspect products for defects</span></span>
<span id="cb20-9"><a href="#cb20-9" aria-hidden="true" tabindex="-1"></a><span class="co">        """</span></span>
<span id="cb20-10"><a href="#cb20-10" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> image_path <span class="kw">in</span> image_paths:</span>
<span id="cb20-11"><a href="#cb20-11" aria-hidden="true" tabindex="-1"></a>            results <span class="op">=</span> <span class="va">self</span>.model(image_path)</span>
<span id="cb20-12"><a href="#cb20-12" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb20-13"><a href="#cb20-13" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Analyze results (assumes every class of the custom model is a defect type)</span></span>
<span id="cb20-14"><a href="#cb20-14" aria-hidden="true" tabindex="-1"></a>            defects_found <span class="op">=</span> []</span>
<span id="cb20-15"><a href="#cb20-15" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> result <span class="kw">in</span> results:</span>
<span id="cb20-16"><a href="#cb20-16" aria-hidden="true" tabindex="-1"></a>                <span class="cf">if</span> result.boxes <span class="kw">is</span> <span class="kw">not</span> <span class="va">None</span>:</span>
<span id="cb20-17"><a href="#cb20-17" aria-hidden="true" tabindex="-1"></a>                    classes <span class="op">=</span> result.boxes.cls.cpu().numpy()</span>
<span id="cb20-18"><a href="#cb20-18" aria-hidden="true" tabindex="-1"></a>                    confidences <span class="op">=</span> result.boxes.conf.cpu().numpy()</span>
<span id="cb20-19"><a href="#cb20-19" aria-hidden="true" tabindex="-1"></a>                    </span>
<span id="cb20-20"><a href="#cb20-20" aria-hidden="true" tabindex="-1"></a>                    <span class="cf">for</span> cls, conf <span class="kw">in</span> <span class="bu">zip</span>(classes, confidences):</span>
<span id="cb20-21"><a href="#cb20-21" aria-hidden="true" tabindex="-1"></a>                        <span class="cf">if</span> conf <span class="op">&gt;</span> <span class="fl">0.5</span>:  <span class="co"># Confidence threshold</span></span>
<span id="cb20-22"><a href="#cb20-22" aria-hidden="true" tabindex="-1"></a>                            defect_type <span class="op">=</span> <span class="va">self</span>.model.names[<span class="bu">int</span>(cls)]</span>
<span id="cb20-23"><a href="#cb20-23" aria-hidden="true" tabindex="-1"></a>                            defects_found.append(defect_type)</span>
<span id="cb20-24"><a href="#cb20-24" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb20-25"><a href="#cb20-25" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Log results</span></span>
<span id="cb20-26"><a href="#cb20-26" aria-hidden="true" tabindex="-1"></a>            inspection_result <span class="op">=</span> {</span>
<span id="cb20-27"><a href="#cb20-27" aria-hidden="true" tabindex="-1"></a>                <span class="st">'image'</span>: image_path,</span>
<span id="cb20-28"><a href="#cb20-28" aria-hidden="true" tabindex="-1"></a>                <span class="st">'defects'</span>: defects_found,</span>
<span id="cb20-29"><a href="#cb20-29" aria-hidden="true" tabindex="-1"></a>                <span class="st">'status'</span>: <span class="st">'FAIL'</span> <span class="cf">if</span> defects_found <span class="cf">else</span> <span class="st">'PASS'</span>,</span>
<span id="cb20-30"><a href="#cb20-30" aria-hidden="true" tabindex="-1"></a>                <span class="st">'timestamp'</span>: time.strftime(<span class="st">"%Y-%m-</span><span class="sc">%d</span><span class="st"> %H:%M:%S"</span>)</span>
<span id="cb20-31"><a href="#cb20-31" aria-hidden="true" tabindex="-1"></a>            }</span>
<span id="cb20-32"><a href="#cb20-32" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb20-33"><a href="#cb20-33" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.defect_log.append(inspection_result)</span>
<span id="cb20-34"><a href="#cb20-34" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f"Inspected </span><span class="sc">{</span>image_path<span class="sc">}</span><span class="ss">: </span><span class="sc">{</span>inspection_result[<span class="st">'status'</span>]<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb20-35"><a href="#cb20-35" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb20-36"><a href="#cb20-36" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.defect_log</span>
<span id="cb20-37"><a href="#cb20-37" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb20-38"><a href="#cb20-38" aria-hidden="true" tabindex="-1"></a><span class="co"># Quality control inspection</span></span>
<span id="cb20-39"><a href="#cb20-39" aria-hidden="true" tabindex="-1"></a>qc <span class="op">=</span> QualityControl(<span class="st">'custom_defect_model.pt'</span>)</span>
<span id="cb20-40"><a href="#cb20-40" aria-hidden="true" tabindex="-1"></a>product_images <span class="op">=</span> [<span class="st">'product1.jpg'</span>, <span class="st">'product2.jpg'</span>, <span class="st">'product3.jpg'</span>]</span>
<span id="cb20-41"><a href="#cb20-41" aria-hidden="true" tabindex="-1"></a>inspection_results <span class="op">=</span> qc.inspect_products(product_images)</span></code></pre></div></div>
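<p>Because <code>inspect_products</code> returns the full log as a list of dictionaries, batch-level statistics take only a few lines. This sketch assumes the log format produced above (<code>'status'</code> and <code>'defects'</code> keys); the <code>summarize_inspections</code> helper is illustrative.</p>

```python
from collections import Counter


def summarize_inspections(log):
    """Aggregate pass rate and defect frequencies from an inspection log."""
    total = len(log)
    passed = sum(1 for entry in log if entry['status'] == 'PASS')
    # Flatten every detected defect across all inspected images
    defect_counts = Counter(d for entry in log for d in entry['defects'])
    return {
        'total': total,
        'passed': passed,
        'failed': total - passed,
        'pass_rate': passed / total if total else 0.0,
        'defect_counts': dict(defect_counts),
    }
```

<p>Note that <code>defect_counts</code> counts detections, so a product with two scratches contributes two to its scratch count.</p>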
</section>
</section>
<section id="best-practices-and-tips" class="level2">
<h2 class="anchored" data-anchor-id="best-practices-and-tips" id="best-practices-and-tips">Best Practices and Tips</h2>
<section id="performance-tips" class="level3">
<h3 class="anchored" data-anchor-id="performance-tips" id="performance-tips">Performance Tips</h3>
<ol type="1">
<li><strong>Choose the right model size</strong>: Use YOLOv8n for speed, YOLOv8x for accuracy</li>
<li><strong>Optimize image size</strong>: Use 640x640 for balance, smaller for speed</li>
<li><strong>Use an appropriate batch size</strong>: Use the largest batch that fits in GPU memory to maximize GPU utilization</li>
<li><strong>Enable model compilation</strong>: Use TorchScript or TensorRT for production</li>
<li><strong>Implement model caching</strong>: Load models once and reuse</li>
</ol>
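<p>The model-caching tip is easy to get wrong when inference code is called from many places. One minimal pattern, sketched here, memoizes the loader with <code>functools.lru_cache</code> so each weights file is loaded once per process. <code>load_weights</code> is a hypothetical stand-in for your real loader (with Ultralytics this would typically be <code>YOLO(path)</code>).</p>

```python
from functools import lru_cache

LOAD_CALLS = []  # records real loads so the caching behaviour is observable


def load_weights(path):
    """Hypothetical expensive loader; replace with e.g. YOLO(path)."""
    LOAD_CALLS.append(path)
    return {'weights': path}


@lru_cache(maxsize=4)
def get_model(path):
    """Return a cached model per weights path, loading it only on first use."""
    return load_weights(path)
```

<p>Keying the cache on the weights path gives each model its own entry, and <code>maxsize</code> bounds how many models stay in memory. Because the cached object is shared, it should be treated as read-only; whether concurrent calls into one model instance are safe depends on the underlying library.</p>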
</section>
<section id="training-tips" class="level3">
<h3 class="anchored" data-anchor-id="training-tips" id="training-tips">Training Tips</h3>
<ol type="1">
<li><strong>Data quality over quantity</strong>: Focus on high-quality, diverse training data</li>
<li><strong>Proper data augmentation</strong>: Use appropriate augmentations for your domain</li>
<li><strong>Monitor training metrics</strong>: Watch for overfitting and adjust accordingly</li>
<li><strong>Use transfer learning</strong>: Start with pre-trained weights</li>
<li><strong>Regular validation</strong>: Validate on held-out data during training</li>
</ol>
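<p>For regular validation to be meaningful, the held-out set must stay fixed across experiments, otherwise scores from different runs are not comparable. A small sketch, assuming a flat list of image file names: hashing each name assigns it to a split deterministically, so adding new images later does not reshuffle existing ones. The 20% validation fraction and the <code>split_for</code> name are assumptions.</p>

```python
import hashlib


def split_for(filename, val_fraction=0.2):
    """Deterministically assign a file to 'train' or 'val' from a hash of its name."""
    digest = hashlib.sha256(filename.encode('utf-8')).hexdigest()
    bucket = int(digest[:8], 16) / 0xFFFFFFFF  # map the first 32 bits to [0, 1]
    return 'val' if bucket >= 1 - val_fraction else 'train'
```

<p>Because the split depends only on the file name, it is reproducible across machines and runs. Near-duplicate images, such as consecutive frames from one video, can still land in different splits, so grouping them before hashing may be necessary.</p>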
</section>
<section id="deployment-tips" class="level3">
<h3 class="anchored" data-anchor-id="deployment-tips" id="deployment-tips">Deployment Tips</h3>
<ol type="1">
<li><strong>Model versioning</strong>: Keep track of model versions and performance</li>
<li><strong>A/B testing</strong>: Test different models in production</li>
<li><strong>Monitoring</strong>: Track inference time and accuracy in production</li>
<li><strong>Fallback mechanisms</strong>: Have backup models for critical applications</li>
<li><strong>Documentation</strong>: Document model performance and limitations</li>
</ol>
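<p>The monitoring and fallback tips can be combined in a thin wrapper around any detector. This is a hypothetical sketch (the <code>FallbackDetector</code> class is not from any library): it times each call and, if the primary model raises an exception, retries with a backup model. Primary and backup are plain callables, so the same wrapper works for Ultralytics models or anything else.</p>

```python
import math
import time


class FallbackDetector:
    """Call a primary detector, fall back to a backup on failure, and record latencies."""

    def __init__(self, primary, backup):
        self.primary = primary
        self.backup = backup
        self.latencies_ms = []   # per-call wall-clock latency in milliseconds
        self.fallback_count = 0  # how often the backup had to be used

    def __call__(self, inputs):
        start = time.perf_counter()
        try:
            result = self.primary(inputs)
        except Exception:
            # Any primary failure routes the request to the backup model
            self.fallback_count += 1
            result = self.backup(inputs)
        self.latencies_ms.append((time.perf_counter() - start) * 1000)
        return result

    def p95_latency_ms(self):
        """95th-percentile latency using the nearest-rank method."""
        if not self.latencies_ms:
            return 0.0
        ordered = sorted(self.latencies_ms)
        rank = max(1, math.ceil(0.95 * len(ordered)))
        return ordered[rank - 1]
```

<p>If the backup also raises, the exception propagates to the caller, which is usually preferable to silently returning nothing. Catching <code>Exception</code> broadly is a deliberate simplification; a production system might fall back only on specific errors such as timeouts.</p>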
<p>This comprehensive guide covers the essential aspects of working with YOLO for object detection. Start with the basic implementations and gradually explore advanced features as your needs grow. Remember to always validate your models thoroughly before deploying them in production environments.</p>



</section>
</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[The Mathematics Behind YOLO: A Deep Dive into Object Detection]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/models/you-only-look-once/yolo-math/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/models/you-only-look-once/yolo-math/</guid>
      <pubDate>Sat, 12 Jul 2025 00:00:00 GMT</pubDate>
      
      <category>research</category>
      <category>beginner</category>
      <content:encoded><![CDATA[






<section id="the-mathematics-behind-yolo-a-deep-dive-into-object-detection" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/models/you-only-look-once/yolo-math/yolomath.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>You Only Look Once (YOLO) revolutionized object detection by treating it as a single regression problem, directly predicting bounding boxes and class probabilities from full images in one evaluation. Unlike earlier approaches that repeatedly apply a classifier to many regions of an image (sliding windows or region proposals), YOLO’s unified architecture enables real-time detection with competitive accuracy.</p>
</section>
<section id="core-mathematical-framework" class="level2">
<h2 class="anchored" data-anchor-id="core-mathematical-framework" id="core-mathematical-framework">Core Mathematical Framework</h2>
<section id="grid-based-detection-paradigm" class="level3">
<h3 class="anchored" data-anchor-id="grid-based-detection-paradigm" id="grid-based-detection-paradigm">Grid-Based Detection Paradigm</h3>
<p>YOLO divides an input image into an <span class="math inline">\(S \times S\)</span> grid. Each grid cell is responsible for detecting objects whose centers fall within that cell. This spatial decomposition transforms the object detection problem into a structured prediction task.</p>
<p>For an input image of dimensions <span class="math inline">\(W \times H\)</span>, each grid cell covers a region of size <span class="math inline">\((W/S) \times (H/S)\)</span>. The mathematical mapping from image coordinates to grid coordinates is:</p>
<p><span class="math display">\[
\begin{align}
\text{grid}_x &amp;= \lfloor x_{\text{center}} / (W/S) \rfloor \\
\text{grid}_y &amp;= \lfloor y_{\text{center}} / (H/S) \rfloor
\end{align}
\]</span></p>
<p>where <span class="math inline">\((x_{\text{center}}, y_{\text{center}})\)</span> represents the center coordinates of an object’s bounding box.</p>
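<p>Here is a minimal sketch of this mapping. The function name and the use of pixel coordinates are illustrative assumptions, not code from any YOLO release.</p>

```python
import math


def grid_cell(x_center, y_center, img_w, img_h, S):
    """Return (col, row) of the grid cell responsible for an object center.

    Coordinates are in pixels; S is the grid size.
    """
    cell_w = img_w / S
    cell_h = img_h / S
    col = math.floor(x_center / cell_w)
    row = math.floor(y_center / cell_h)
    # Clamp so a center lying exactly on the right/bottom border maps to the last cell
    col = min(col, S - 1)
    row = min(row, S - 1)
    return col, row


print(grid_cell(300.0, 100.0, 448, 448, 7))  # (4, 1): cells are 448 / 7 = 64 px wide
```

<p>The clamp is an added safeguard. With the raw floor formula, a center lying exactly on the right or bottom image border would map to index <span class="math inline">\(S\)</span>, which is one past the last cell.</p>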
</section>
<section id="output-tensor-structure" class="level3">
<h3 class="anchored" data-anchor-id="output-tensor-structure" id="output-tensor-structure">Output Tensor Structure</h3>
<p>The network outputs a tensor of shape <span class="math inline">\(S \times S \times (B \times 5 + C)\)</span>, where:</p>
<ul>
<li><span class="math inline">\(S\)</span> is the grid size</li>
<li><span class="math inline">\(B\)</span> is the number of bounding boxes per grid cell</li>
<li><span class="math inline">\(C\)</span> is the number of classes</li>
</ul>
<p>Each bounding box prediction contains 5 values: <span class="math inline">\((x, y, w, h, \text{confidence})\)</span>, and each grid cell predicts <span class="math inline">\(C\)</span> class probabilities.</p>
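<p>The original YOLO on PASCAL VOC used <span class="math inline">\(S = 7\)</span>, <span class="math inline">\(B = 2\)</span>, and <span class="math inline">\(C = 20\)</span>, which gives a 7×7×30 output tensor. The small helper below (a hypothetical name) just does that arithmetic.</p>

```python
def yolo_output_shape(S, B, C):
    """Output tensor shape S x S x (B * 5 + C) for the grid formulation above."""
    return (S, S, B * 5 + C)


# Original YOLO settings on PASCAL VOC: 7x7 grid, 2 boxes per cell, 20 classes
shape = yolo_output_shape(7, 2, 20)
num_values = shape[0] * shape[1] * shape[2]
print(shape, num_values)  # (7, 7, 30) 1470
```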
</section>
</section>
<section id="bounding-box-parameterizatioz" class="level2">
<h2 class="anchored" data-anchor-id="bounding-box-parameterizatioz" id="bounding-box-parameterizatioz">Bounding Box Parameterization</h2>
<section id="coordinate-encoding" class="level3">
<h3 class="anchored" data-anchor-id="coordinate-encoding" id="coordinate-encoding">Coordinate Encoding</h3>
<p>From YOLOv2 onward, YOLO encodes box coordinates so that predictions stay bounded and easy to interpret:</p>
<p><strong>Center Coordinates:</strong> <span class="math display">\[
\begin{align}
x &amp;= \sigma(t_x) + c_x \\
y &amp;= \sigma(t_y) + c_y
\end{align}
\]</span></p>
<p>where:</p>
<ul>
<li><span class="math inline">\(t_x, t_y\)</span> are the raw network outputs</li>
<li><span class="math inline">\(\sigma\)</span> is the sigmoid function</li>
<li><span class="math inline">\(c_x, c_y\)</span> are the integer column and row offsets of the grid cell <span class="math inline">\((0 \leq c_x, c_y &lt; S)\)</span></li>
</ul>
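<p>This formulation keeps each predicted center inside its responsible grid cell, because <span class="math inline">\(\sigma(t_x), \sigma(t_y) \in (0, 1)\)</span>. Here <span class="math inline">\(x\)</span> and <span class="math inline">\(y\)</span> are measured in grid-cell units.</p>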
<p><strong>Dimensions:</strong> <span class="math display">\[
\begin{align}
w &amp;= p_w \times \exp(t_w) \\
h &amp;= p_h \times \exp(t_h)
\end{align}
\]</span></p>
<p>where:</p>
<ul>
<li><span class="math inline">\(t_w, t_h\)</span> are the raw network outputs</li>
<li><span class="math inline">\(p_w, p_h\)</span> are anchor box dimensions (in YOLOv2+)</li>
</ul>
<p>The exponential ensures positive dimensions, while anchor boxes provide reasonable priors.</p>
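<p>The two encodings can be combined into a single decoding step. This is a minimal sketch with a hypothetical function name. It assumes all quantities, including the anchor size, are in grid-cell units; multiplying by the network stride (32 pixels in YOLOv2) would convert them to pixels.</p>

```python
import math


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def decode_box(tx, ty, tw, th, cx, cy, pw, ph):
    """Decode raw outputs into (x, y, w, h), all in grid-cell units.

    cx, cy: integer offsets of the grid cell; pw, ph: anchor (prior) size.
    """
    x = sigmoid(tx) + cx   # center stays inside cell (cx, cy)
    y = sigmoid(ty) + cy
    w = pw * math.exp(tw)  # always positive; tw = 0 reproduces the anchor width
    h = ph * math.exp(th)
    return x, y, w, h


# Zero raw outputs place the center mid-cell and keep the anchor's size
print(decode_box(0.0, 0.0, 0.0, 0.0, cx=3, cy=5, pw=1.5, ph=2.0))  # (3.5, 5.5, 1.5, 2.0)
```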
</section>
<section id="confidence-score-mathematics" class="level3">
<h3 class="anchored" data-anchor-id="confidence-score-mathematics" id="confidence-score-mathematics">Confidence Score Mathematics</h3>
<p>The confidence score of a box reflects two things: how likely the box is to contain an object, and how well it overlaps the ground truth box, measured by intersection over union (IoU):</p>
<p><span class="math display">\[
\text{Confidence} = \Pr(\text{Object}) \times \text{IoU}(\text{pred}, \text{truth})
\]</span></p>
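<p>At test time, the box confidence is multiplied by the conditional class probability. This gives a class-specific confidence score for each box: <span class="math display">\[
\Pr(\text{Class}_i \mid \text{Object}) \times \Pr(\text{Object}) \times \text{IoU}(\text{pred}, \text{truth}) = \Pr(\text{Class}_i) \times \text{IoU}(\text{pred}, \text{truth})
\]</span></p>
<p>This score captures both the likelihood that an object of class <span class="math inline">\(i\)</span> is present in the box and how accurately the box fits it.</p>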
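<p>Multiplying a box's confidence by each conditional class probability and keeping the scores above a cut-off is a one-liner. This is a sketch: the function names, the example probabilities, and the 0.25 threshold are illustrative choices, not values from the paper.</p>

```python
def class_specific_scores(box_confidence, class_probs):
    """Multiply a box's confidence, Pr(Object) * IoU, by each Pr(Class_i | Object)."""
    return [box_confidence * p for p in class_probs]


def detections_above(box_confidence, class_probs, threshold):
    """Return (class_index, score) pairs whose class-specific score exceeds threshold."""
    scores = class_specific_scores(box_confidence, class_probs)
    return [(i, s) for i, s in enumerate(scores) if s > threshold]


# Hypothetical box: confidence 0.8, three classes
print(detections_above(0.8, [0.1, 0.7, 0.2], threshold=0.25))
```

<p>Only the second class survives here: its score is 0.8 × 0.7 = 0.56, while the other two classes score 0.08 and 0.16.</p>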
</section>
</section>
<section id="loss-function-architecture" class="level2">
<h2 class="anchored" data-anchor-id="loss-function-architecture" id="loss-function-architecture">Loss Function Architecture</h2>
<p>YOLO’s loss function is a carefully designed multi-part objective that balances localization accuracy, confidence prediction, and classification performance.</p>
<section id="complete-loss-function" class="level3">
<h3 class="anchored" data-anchor-id="complete-loss-function" id="complete-loss-function">Complete Loss Function</h3>
<p><span class="math display">\[
\mathcal{L} = \lambda_{\text{coord}} \times \mathcal{L}_{\text{loc}} + \mathcal{L}_{\text{conf}} + \mathcal{L}_{\text{class}}
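\]</span></p>
<p>where <span class="math inline">\(\lambda_{\text{coord}}\)</span> (set to 5 in the original YOLO paper) increases the weight of localization errors relative to the confidence and classification terms.</p>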
</section>
<section id="localization-loss" class="level3">
<h3 class="anchored" data-anchor-id="localization-loss" id="localization-loss">Localization Loss</h3>
<p><span class="math display">\[
\begin{align}
\mathcal{L}_{\text{loc}} &amp;= \sum_{i=0}^{S^2} \sum_{j=0}^{B} \mathbb{1}_{ij}^{\text{obj}} [(x_i - \hat{x}_i)^2 + (y_i - \hat{y}_i)^2] \\
&amp;\quad + \sum_{i=0}^{S^2} \sum_{j=0}^{B} \mathbb{1}_{ij}^{\text{obj}} [(\sqrt{w_i} - \sqrt{\hat{w}_i})^2 + (\sqrt{h_i} - \sqrt{\hat{h}_i})^2]
\end{align}
\]</span></p>
<p>where:</p>
<ul>
<li><span class="math inline">\(\mathbb{1}_{ij}^{\text{obj}}\)</span> indicates if object appears in cell <span class="math inline">\(i\)</span> and predictor <span class="math inline">\(j\)</span> is responsible</li>
<li><span class="math inline">\((x_i, y_i, w_i, h_i)\)</span> are ground truth coordinates</li>
<li><span class="math inline">\((\hat{x}_i, \hat{y}_i, \hat{w}_i, \hat{h}_i)\)</span> are predicted coordinates</li>
</ul>
<p>Taking square roots of width and height partially addresses the scale sensitivity problem. The same absolute error costs less on a large box than on a small one, which reflects that a small deviation matters more for a small box.</p>
</section>
<section id="confidence-loss" class="level3">
<h3 class="anchored" data-anchor-id="confidence-loss" id="confidence-loss">Confidence Loss</h3>
<p><span class="math display">\[
\begin{align}
\mathcal{L}_{\text{conf}} &amp;= \sum_{i=0}^{S^2} \sum_{j=0}^{B} \mathbb{1}_{ij}^{\text{obj}} (C_i - \hat{C}_i)^2 \\
&amp;\quad + \lambda_{\text{noobj}} \sum_{i=0}^{S^2} \sum_{j=0}^{B} \mathbb{1}_{ij}^{\text{noobj}} (C_i - \hat{C}_i)^2
\end{align}
\]</span></p>
<p>where:</p>
<ul>
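<li><span class="math inline">\(C_i\)</span> is the target confidence: the IoU between the predicted and ground truth boxes for a responsible predictor, and 0 otherwise</li>
<li><span class="math inline">\(\hat{C}_i\)</span> is the predicted confidence</li>
<li><span class="math inline">\(\mathbb{1}_{ij}^{\text{noobj}}\)</span> indicates that predictor <span class="math inline">\(j\)</span> in cell <span class="math inline">\(i\)</span> is not responsible for any object</li>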
<li><span class="math inline">\(\lambda_{\text{noobj}}\)</span> (typically 0.5) weights down the loss from confidence predictions for boxes that don’t contain objects</li>
</ul>
</section>
<section id="classification-loss" class="level3">
<h3 class="anchored" data-anchor-id="classification-loss" id="classification-loss">Classification Loss</h3>
<p><span class="math display">\[
\mathcal{L}_{\text{class}} = \sum_{i=0}^{S^2} \mathbb{1}_{i}^{\text{obj}} \sum_{c \in \text{classes}} (p_i(c) - \hat{p}_i(c))^2
\]</span></p>
<p>where:</p>
<ul>
<li><span class="math inline">\(p_i(c)\)</span> is the target conditional class probability for class <span class="math inline">\(c\)</span> (1 for the true class, 0 otherwise)</li>
<li><span class="math inline">\(\hat{p}_i(c)\)</span> is the predicted class probability</li>
<li><span class="math inline">\(\mathbb{1}_{i}^{\text{obj}}\)</span> indicates if an object appears in cell <span class="math inline">\(i\)</span></li>
</ul>
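<p>The three terms can be put together in a minimal NumPy sketch. It assumes targets have already been assigned; the step that makes a predictor responsible (in YOLOv1, the one whose box has the highest IoU with the ground truth) is left out. <span class="math inline">\(\lambda_{\text{coord}} = 5\)</span> and <span class="math inline">\(\lambda_{\text{noobj}} = 0.5\)</span> follow the original YOLO paper. The function name and tensor layout are illustrative.</p>

```python
import numpy as np


def yolo_v1_loss(pred_boxes, true_boxes, pred_conf, true_conf, obj_mask,
                 pred_cls, true_cls, lambda_coord=5.0, lambda_noobj=0.5):
    """Sum-squared YOLO loss following the three terms above.

    Shapes (N = S*S cells, B boxes per cell, C classes):
      pred_boxes, true_boxes: (N, B, 4) as (x, y, w, h), with w, h >= 0
      pred_conf, true_conf:   (N, B); true_conf is the IoU for responsible predictors, else 0
      obj_mask:               (N, B); 1 where predictor j in cell i is responsible
      pred_cls, true_cls:     (N, C)
    """
    noobj_mask = 1.0 - obj_mask
    cell_has_obj = obj_mask.max(axis=1)  # 1_i^obj: some predictor in cell i is responsible

    # Localization: centers directly, widths/heights through square roots
    xy_err = ((true_boxes[..., :2] - pred_boxes[..., :2]) ** 2).sum(axis=-1)
    wh_err = ((np.sqrt(true_boxes[..., 2:]) - np.sqrt(pred_boxes[..., 2:])) ** 2).sum(axis=-1)
    loss_loc = (obj_mask * (xy_err + wh_err)).sum()

    # Confidence: full weight for responsible predictors, lambda_noobj for the rest
    conf_err = (true_conf - pred_conf) ** 2
    loss_conf = (obj_mask * conf_err).sum() + lambda_noobj * (noobj_mask * conf_err).sum()

    # Classification: only cells that contain an object
    cls_err = ((true_cls - pred_cls) ** 2).sum(axis=-1)
    loss_cls = (cell_has_obj * cls_err).sum()

    return lambda_coord * loss_loc + loss_conf + loss_cls


# One cell, two predictors, two classes; predictor 0 is responsible
pred_boxes = np.array([[[0.5, 0.5, 0.25, 0.25], [0.2, 0.2, 0.1, 0.1]]])
true_boxes = np.array([[[0.5, 0.5, 0.25, 0.25], [0.0, 0.0, 0.0, 0.0]]])
obj_mask = np.array([[1.0, 0.0]])
print(yolo_v1_loss(pred_boxes, true_boxes,
                   pred_conf=np.array([[0.9, 0.2]]), true_conf=np.array([[1.0, 0.0]]),
                   obj_mask=obj_mask,
                   pred_cls=np.array([[0.8, 0.2]]), true_cls=np.array([[1.0, 0.0]])))  # about 0.11
```

<p>In this example the boxes match exactly, so the localization term is zero. The loss comes from the responsible predictor's confidence (0.01), the down-weighted non-responsible predictor (0.5 × 0.04 = 0.02), and the class error (0.08).</p>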
</section>
</section>
<section id="intersection-over-union-iou-calculations" class="level2">
<h2 class="anchored" data-anchor-id="intersection-over-union-iou-calculations" id="intersection-over-union-iou-calculations">Intersection over Union (IoU) Calculations</h2>
<p>IoU is fundamental to YOLO’s operation, used in both training and inference:</p>
<p><span class="math display">\[
\text{IoU} = \frac{\text{Area}(\text{Intersection})}{\text{Area}(\text{Union})}
\]</span></p>
<p>For two boxes given by their top-left and bottom-right corners <span class="math inline">\((x_1,y_1,x_2,y_2)\)</span> and <span class="math inline">\((x_1',y_1',x_2',y_2')\)</span>:</p>
<p><span class="math display">\[
\begin{align}
\text{Intersection Area} &amp;= \max(0, \min(x_2,x_2') - \max(x_1,x_1')) \\
&amp;\quad \times \max(0, \min(y_2,y_2') - \max(y_1,y_1')) \\
\text{Union Area} &amp;= (x_2-x_1)(y_2-y_1) + (x_2'-x_1')(y_2'-y_1') \\
&amp;\quad - \text{Intersection Area}
\end{align}
\]</span></p>
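<p>The corner formula translates directly into code. This is a plain-Python sketch; returning 0 for a degenerate union is an added safeguard.</p>

```python
def iou(box_a, box_b):
    """IoU of two boxes given as (x1, y1, x2, y2) corner coordinates."""
    ax1, ay1, ax2, ay2 = box_a
    bx1, by1, bx2, by2 = box_b
    # Overlap along each axis, clipped at zero when the boxes do not intersect
    inter_w = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    inter_h = max(0.0, min(ay2, by2) - max(ay1, by1))
    inter = inter_w * inter_h
    union = (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - inter
    return inter / union if union > 0 else 0.0


print(iou((0, 0, 2, 2), (1, 1, 3, 3)))  # 0.142857... (= 1/7)
```

<p>In this example, the two 2×2 boxes share a 1×1 square. The intersection is 1 and the union is 4 + 4 − 1 = 7, so the IoU is 1/7.</p>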
</section>
<section id="non-maximum-suppression-nms" class="level2">
<h2 class="anchored" data-anchor-id="non-maximum-suppression-nms" id="non-maximum-suppression-nms">Non-Maximum Suppression (NMS)</h2>
<p>NMS eliminates redundant detections using IoU-based suppression, usually applied separately for each class:</p>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>NMS Algorithm
</div>
</div>
<div class="callout-body-container callout-body">
<ol type="1">
<li>Sort detections by confidence score (descending)</li>
<li>While detections remain:
<ol type="a">
<li>Remove the highest-confidence detection from the list and add it to the final output</li>
<li>Discard every remaining detection whose IoU with it exceeds the threshold</li>
</ol></li>
</ol>
</div>
</div>
<p>The mathematical condition for suppression: <span class="math display">\[
\text{Suppress if } \text{IoU}(\text{box}_i, \text{box}_j) &gt; \tau_{\text{NMS}} \text{ AND } \text{conf}(\text{box}_i) &lt; \text{conf}(\text{box}_j)
\]</span></p>
<p>where <span class="math inline">\(\tau_{\text{NMS}}\)</span> is the NMS threshold.</p>
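<p>Here is a plain-Python sketch of the greedy procedure. It repeats the corner-based IoU formula so it runs on its own. The 0.5 default threshold is an illustrative choice, not a value prescribed by YOLO.</p>

```python
def iou(a, b):
    """IoU of two (x1, y1, x2, y2) boxes."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def nms(boxes, scores, iou_threshold=0.5):
    """Greedy NMS; returns indices of kept boxes, highest score first."""
    # Step 1: sort indices by confidence, descending
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    while order:
        best = order.pop(0)  # highest-confidence remaining detection
        keep.append(best)
        # Keep only detections whose overlap with the selected box does not exceed the threshold
        order = [i for i in order if iou_threshold >= iou(boxes[best], boxes[i])]
    return keep


boxes = [(0, 0, 2, 2), (0.1, 0.1, 2.1, 2.1), (5, 5, 6, 6)]
scores = [0.9, 0.8, 0.7]
print(nms(boxes, scores))  # [0, 2]: box 1 overlaps box 0 with IoU of about 0.82
```

<p>In practice, frameworks provide vectorized versions, such as <code>torchvision.ops.nms</code> in PyTorch. These are run per class on the thresholded detections.</p>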
</section>
<section id="anchor-box-mathematics-yolov2" class="level2">
<h2 class="anchored" data-anchor-id="anchor-box-mathematics-yolov2" id="anchor-box-mathematics-yolov2">Anchor Box Mathematics (YOLOv2+)</h2>
<p>YOLOv2 introduced anchor boxes, which give each predictor a prior shape to refine and raised recall compared with YOLOv1:</p>
<section id="anchor-box-selection" class="level3">
<h3 class="anchored" data-anchor-id="anchor-box-selection" id="anchor-box-selection">Anchor Box Selection</h3>
<p>Anchor boxes are selected using K-means clustering on training set bounding boxes, with a custom distance metric:</p>
<p><span class="math display">\[
d(\text{box}, \text{centroid}) = 1 - \text{IoU}(\text{box}, \text{centroid})
\]</span></p>
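<p>Using <span class="math inline">\(1 - \text{IoU}\)</span> as the distance makes the clustering favor anchors that overlap ground truth boxes well. With Euclidean distance on box sizes, larger boxes would produce larger errors and dominate the clustering.</p>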
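<p>A minimal sketch of this clustering follows. Boxes are compared by width and height only, as if their centers were aligned. Two choices here are assumptions: the update step uses the cluster mean (some implementations use the median), and the helper names are illustrative. The YOLOv2 paper settled on <span class="math inline">\(k = 5\)</span>.</p>

```python
import random


def wh_iou(a, b):
    """IoU of two boxes given only as (w, h), assuming shared centers."""
    inter = min(a[0], b[0]) * min(a[1], b[1])
    return inter / (a[0] * a[1] + b[0] * b[1] - inter)


def kmeans_anchors(whs, k, iters=100, seed=0):
    """Cluster (w, h) pairs with distance d = 1 - IoU; returns k anchor sizes."""
    rng = random.Random(seed)
    centroids = rng.sample(whs, k)
    for _ in range(iters):
        # Assign each box to the centroid with the smallest 1 - IoU distance
        clusters = [[] for _ in range(k)]
        for wh in whs:
            best = min(range(k), key=lambda j: 1.0 - wh_iou(wh, centroids[j]))
            clusters[best].append(wh)
        # Move each centroid to the mean (w, h) of its cluster; keep it if the cluster is empty
        new_centroids = []
        for j, members in enumerate(clusters):
            if members:
                mean_w = sum(w for w, _ in members) / len(members)
                mean_h = sum(h for _, h in members) / len(members)
                new_centroids.append((mean_w, mean_h))
            else:
                new_centroids.append(centroids[j])
        if new_centroids == centroids:  # converged
            break
        centroids = new_centroids
    return sorted(centroids)


# Hypothetical box sizes: three small boxes and three large ones
whs = [(1.0, 1.0), (1.1, 1.0), (1.0, 1.1), (10.0, 10.0), (10.0, 11.0), (11.0, 10.0)]
print(kmeans_anchors(whs, k=2))
```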
</section>
<section id="prediction-with-anchors" class="level3">
<h3 class="anchored" data-anchor-id="prediction-with-anchors" id="prediction-with-anchors">Prediction with Anchors</h3>
<p>With anchor boxes, the prediction formulation becomes:</p>
<p><span class="math display">\[
\begin{align}
x &amp;= \sigma(t_x) + c_x \\
y &amp;= \sigma(t_y) + c_y \\
w &amp;= p_w \times \exp(t_w) \\
h &amp;= p_h \times \exp(t_h)
\end{align}
\]</span></p>
<p>where <span class="math inline">\(p_w\)</span> and <span class="math inline">\(p_h\)</span> are the anchor box dimensions.</p>
</section>
</section>
<section id="mathematical-optimizations" class="level2">
<h2 class="anchored" data-anchor-id="mathematical-optimizations" id="mathematical-optimizations">Mathematical Optimizations</h2>
<section id="gradient-flow-analysis" class="level3">
<h3 class="anchored" data-anchor-id="gradient-flow-analysis" id="gradient-flow-analysis">Gradient Flow Analysis</h3>
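<p>The sigmoid in the center prediction bounds the gradient that flows back to the raw output <span class="math inline">\(t_x\)</span>:</p>
<p><span class="math display">\[
\frac{\partial \mathcal{L}}{\partial t_x} = \frac{\partial \mathcal{L}}{\partial x} \times \frac{\partial x}{\partial t_x} = \frac{\partial \mathcal{L}}{\partial x} \times \sigma(t_x)(1-\sigma(t_x))
\]</span></p>
<p>Since <span class="math inline">\(\sigma(t)(1-\sigma(t)) \leq 1/4\)</span>, with the maximum at <span class="math inline">\(t = 0\)</span>, this factor can only shrink the upstream gradient, never amplify it. The trade-off is saturation: for large <span class="math inline">\(\lvert t_x \rvert\)</span> the factor approaches zero, so very extreme offsets are updated slowly.</p>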
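<p>A quick numerical check of the derivative compares the analytic factor <span class="math inline">\(\sigma(t)(1-\sigma(t))\)</span> with a central finite difference at a few points. The test points and step size are arbitrary illustrative choices.</p>

```python
import math


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def sigmoid_grad(t):
    """Analytic derivative: sigma(t) * (1 - sigma(t))."""
    s = sigmoid(t)
    return s * (1.0 - s)


eps = 1e-5
for t in (-6.0, -2.0, 0.0, 2.0, 6.0):
    # Central finite difference as an independent check of the formula
    numeric = (sigmoid(t + eps) - sigmoid(t - eps)) / (2 * eps)
    print(f"t={t:+.1f}  analytic={sigmoid_grad(t):.6f}  numeric={numeric:.6f}")
```

<p>The factor peaks at exactly 0.25 at <span class="math inline">\(t = 0\)</span> and falls to about 0.0025 at <span class="math inline">\(\lvert t \rvert = 6\)</span>. This shows both the bound and the saturation described above.</p>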
</section>
<section id="multi-scale-training-mathematics" class="level3">
<h3 class="anchored" data-anchor-id="multi-scale-training-mathematics" id="multi-scale-training-mathematics">Multi-Scale Training Mathematics</h3>
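<p>YOLOv2 employs multi-scale training. Every 10 batches it randomly picks a new input size from the multiples of 32 between 320 and 608:</p>
<p><span class="math display">\[
\text{Scale factor} = \frac{\text{random choice}([320, 352, 384, 416, 448, 480, 512, 544, 576, 608])}{416}
\]</span></p>
<p>The network downsamples by a factor of 32, so each of these sizes gives an integer output grid, from 10×10 to 19×19. Training across these resolutions makes a single model robust to different input sizes, so speed can be traded for accuracy at test time by changing the input resolution.</p>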
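<p>Here is a minimal sketch of the sampling schedule. The names are illustrative, and a real training loop would also resize each batch's images to the chosen size.</p>

```python
import random

STRIDE = 32       # total downsampling factor of the YOLOv2 network
BASE_SIZE = 416   # reference resolution used in the scale-factor formula above
SIZES = list(range(320, 608 + 1, STRIDE))  # 320, 352, ..., 608


def pick_training_size(rng):
    """Randomly choose a new square input resolution from SIZES."""
    return rng.choice(SIZES)


rng = random.Random(0)
for batch in range(30):
    if batch % 10 == 0:  # resample the input size every 10 batches
        size = pick_training_size(rng)
        grid = size // STRIDE
        print(f"batch {batch}: {size}x{size} input, {grid}x{grid} grid, "
              f"scale factor {size / BASE_SIZE:.3f}")
```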
</section>
</section>
<section id="computational-complexity-analysis" class="level2">
<h2 class="anchored" data-anchor-id="computational-complexity-analysis" id="computational-complexity-analysis">Computational Complexity Analysis</h2>
<section id="forward-pass-complexity" class="level3">
<h3 class="anchored" data-anchor-id="forward-pass-complexity" id="forward-pass-complexity">Forward Pass Complexity</h3>
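<p>Consider a network with <span class="math inline">\(L\)</span> convolutional layers. Layer <span class="math inline">\(l\)</span> maps <span class="math inline">\(C_{\text{in},l}\)</span> input channels to a <span class="math inline">\(W_l \times H_l\)</span> output with <span class="math inline">\(C_{\text{out},l}\)</span> channels, using <span class="math inline">\(K_l \times K_l\)</span> kernels. Because spatial size shrinks through the network, <span class="math inline">\(W_l\)</span> and <span class="math inline">\(H_l\)</span> belong inside the sum:</p>
<ul>
<li>Convolutional layers: <span class="math inline">\(O(W_l \times H_l \times C_{\text{in},l} \times C_{\text{out},l} \times K_l^2)\)</span> multiply-accumulates per layer</li>
<li>Total complexity: <span class="math inline">\(O\left(\sum_{l=1}^{L} W_l \times H_l \times C_{\text{in},l} \times C_{\text{out},l} \times K_l^2\right)\)</span></li>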
</ul>
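<p>The per-layer formula can be applied to a few layers to see where the cost goes. The layer shapes below are hypothetical, chosen only to illustrate the arithmetic. The conversion of about 2 FLOPs per multiply-accumulate is a common convention, not an exact count.</p>

```python
def conv_macs(out_h, out_w, c_in, c_out, k):
    """Multiply-accumulate count of one standard (ungrouped) k x k convolution."""
    return out_h * out_w * c_in * c_out * k * k


# Hypothetical layers: (output height, output width, C_in, C_out, kernel size)
layers = [
    (416, 416, 3, 32, 3),
    (208, 208, 32, 64, 3),
    (104, 104, 64, 128, 3),
]
per_layer = [conv_macs(*layer) for layer in layers]
total = sum(per_layer)
for layer, macs in zip(layers, per_layer):
    print(layer, f"{macs / 1e6:.1f} M MACs")
print(f"total: {total / 1e9:.2f} G MACs (about {2 * total / 1e9:.2f} GFLOPs)")
```

<p>The second and third layers cost exactly the same. Halving each spatial dimension divides <span class="math inline">\(W_l H_l\)</span> by 4, and doubling both <span class="math inline">\(C_{\text{in},l}\)</span> and <span class="math inline">\(C_{\text{out},l}\)</span> exactly offsets that reduction.</p>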
</section>
<section id="inference-speed-mathematics" class="level3">
<h3 class="anchored" data-anchor-id="inference-speed-mathematics" id="inference-speed-mathematics">Inference Speed Mathematics</h3>
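<p>YOLO’s single forward pass eliminates the separate region-proposal stage and per-region classification of R-CNN-style detectors:</p>
<ul>
<li>R-CNN-style methods that run the CNN on every proposal: <span class="math inline">\(O(N \times \text{Forward pass})\)</span>, where <span class="math inline">\(N\)</span> is the number of proposals (about 2,000 per image in the original R-CNN)</li>
<li>YOLO: <span class="math inline">\(O(1 \times \text{Forward pass})\)</span></li>
</ul>
<p>With thousands of proposals per image, this is a large saving. Later two-stage detectors such as Fast and Faster R-CNN reduce the per-proposal cost by sharing convolutional features across proposals.</p>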
</section>
</section>
<section id="advanced-mathematical-concepts" class="level2">
<h2 class="anchored" data-anchor-id="advanced-mathematical-concepts" id="advanced-mathematical-concepts">Advanced Mathematical Concepts</h2>
<section id="focal-loss-integration-yolov3" class="level3">
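<h3 class="anchored" data-anchor-id="focal-loss-integration-yolov3" id="focal-loss-integration-yolov3">Focal Loss Integration (Later Variants)</h3>
<p>Some later YOLO variants incorporate focal loss to counter the imbalance between the many easy background predictions and the few object predictions. The YOLOv3 authors themselves reported that focal loss lowered their mAP by about 2 points, so it is not part of YOLOv3:</p>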
<p><span class="math display">\[
\text{Focal Loss} = -\alpha(1-p_t)^\gamma \log(p_t)
\]</span></p>
<p>where:</p>
<ul>
<li><span class="math inline">\(p_t\)</span> is the predicted probability for the true class</li>
<li><span class="math inline">\(\alpha\)</span> is a weighting factor</li>
<li><span class="math inline">\(\gamma\)</span> is the focusing parameter</li>
</ul>
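<p>Here is a minimal sketch of the binary form. It uses the usual <span class="math inline">\(\alpha_t\)</span> convention: <span class="math inline">\(\alpha\)</span> for positives and <span class="math inline">\(1 - \alpha\)</span> for negatives. The defaults <span class="math inline">\(\alpha = 0.25\)</span> and <span class="math inline">\(\gamma = 2\)</span> are the values the focal loss (RetinaNet) paper found to work well, not values tied to any YOLO release. With <span class="math inline">\(\gamma = 0\)</span>, the loss reduces to <span class="math inline">\(\alpha\)</span>-weighted cross-entropy.</p>

```python
import math


def focal_loss(p, y, alpha=0.25, gamma=2.0):
    """Binary focal loss for a single prediction.

    p: predicted probability of the positive class; y: 1 (positive) or 0 (negative).
    Uses the alpha_t convention: alpha for positives, 1 - alpha for negatives.
    """
    p_t = p if y == 1 else 1.0 - p
    alpha_t = alpha if y == 1 else 1.0 - alpha
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


def cross_entropy(p, y):
    p_t = p if y == 1 else 1.0 - p
    return -math.log(p_t)


# An easy negative (p = 0.05) versus a hard negative (p = 0.7)
for p in (0.05, 0.7):
    print(f"p={p}: CE={cross_entropy(p, 0):.4f}  focal={focal_loss(p, 0):.4f}")
```

<p>Relative to cross-entropy, the easy negative's loss shrinks by a factor of about 500, while the hard negative's shrinks by less than 3. This is how the focusing parameter shifts training toward hard examples.</p>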
</section>
<section id="feature-pyramid-networks-mathematics" class="level3">
<h3 class="anchored" data-anchor-id="feature-pyramid-networks-mathematics" id="feature-pyramid-networks-mathematics">Feature Pyramid Networks Mathematics</h3>
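<p>YOLOv3 predicts boxes at three scales using an FPN-style design. A coarse feature map is upsampled by a factor of 2 and concatenated with a higher-resolution feature map from earlier in the network:</p>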
<p><span class="math display">\[
\begin{align}
\text{Upsampled feature} &amp;= \text{Interpolate}(\text{Lower resolution feature}, \text{scale factor}=2) \\
\text{Combined feature} &amp;= \text{Concat}(\text{Upsampled feature}, \text{Higher resolution feature})
\end{align}
\]</span></p>
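<p>Here is a minimal NumPy sketch of this fusion step. It assumes nearest-neighbour interpolation and a channels-first (C, H, W) layout. The 13×13 and 26×26 grids match YOLOv3's two coarsest scales for a 416×416 input, but the channel counts are illustrative.</p>

```python
import numpy as np


def upsample_nearest_2x(x):
    """Nearest-neighbour 2x upsampling of a (C, H, W) feature map."""
    return x.repeat(2, axis=1).repeat(2, axis=2)


def fuse(coarse, fine):
    """Upsample the coarse map and concatenate it with the finer map along channels."""
    up = upsample_nearest_2x(coarse)
    if up.shape[1:] != fine.shape[1:]:
        raise ValueError(f"spatial mismatch: {up.shape[1:]} vs {fine.shape[1:]}")
    return np.concatenate([up, fine], axis=0)


# Illustrative channel counts on the 13x13 and 26x26 grids
coarse = np.zeros((256, 13, 13), dtype=np.float32)
fine = np.zeros((512, 26, 26), dtype=np.float32)
print(fuse(coarse, fine).shape)  # (768, 26, 26)
```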
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>The mathematical foundation of YOLO demonstrates elegant solutions to complex computer vision problems. By formulating object detection as a single regression problem, YOLO achieves remarkable efficiency while maintaining accuracy. The careful design of the loss function, coordinate encoding, and architectural choices reflects deep mathematical insights into the nature of object detection.</p>
<p>Understanding these mathematical principles is crucial for practitioners seeking to modify, improve, or adapt YOLO for specific applications. The balance between localization accuracy, confidence prediction, and classification performance showcases how mathematical rigor can lead to practical breakthroughs in computer vision.</p>
<p>The evolution from YOLO to YOLOv8 and beyond continues to build upon these mathematical foundations, incorporating advances in deep learning theory while maintaining the core insight that object detection can be efficiently solved through direct prediction rather than complex multi-stage pipelines.</p>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Key Takeaways
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li>YOLO’s unified architecture treats object detection as a single regression problem</li>
<li>The grid-based approach, combined with sigmoid and exponential coordinate encoding, keeps predictions bounded</li>
<li>The multi-part loss function balances localization, confidence, and classification objectives</li>
<li>Mathematical optimizations like anchor boxes and multi-scale training improve performance</li>
<li>Understanding the mathematical foundations enables effective adaptation and improvement</li>
</ul>
</div>
</div>



</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[YOLO (You Only Look Once): A Comprehensive Beginner’s Guide]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/models/you-only-look-once/yolo-summary/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/models/you-only-look-once/yolo-summary/</guid>
      <pubDate>Sat, 12 Jul 2025 00:00:00 GMT</pubDate>
      
      <category>research</category>
      <category>beginner</category>
      <content:encoded><![CDATA[






<section id="yolo-you-only-look-once-a-comprehensive-beginners-guide" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/models/you-only-look-once/yolo-summary/yolo.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>In the rapidly evolving world of computer vision and artificial intelligence, few innovations have been as transformative as YOLO (You Only Look Once). This revolutionary object detection algorithm has fundamentally changed how computers “see” and understand images, making real-time object detection accessible to developers, researchers, and businesses worldwide.</p>
<p>YOLO represents a paradigm shift from traditional object detection methods, offering unprecedented speed without significantly compromising accuracy. Whether you’re a student exploring computer vision, a developer building AI applications, or simply curious about how machines can identify objects in images, this guide will take you through everything you need to know about YOLO.</p>
</section>
<section id="what-is-yolo" class="level2">
<h2 class="anchored" data-anchor-id="what-is-yolo" id="what-is-yolo">What is YOLO?</h2>
<p>YOLO, which stands for “You Only Look Once,” is a state-of-the-art object detection algorithm that can identify and locate multiple objects within an image in real-time. Unlike traditional methods that examine an image multiple times to detect objects, YOLO processes the entire image in a single forward pass through a neural network, hence the name “You Only Look Once.”</p>
<p>The algorithm doesn’t just identify what objects are present in an image; it also determines their precise locations by drawing bounding boxes around them. This dual capability of classification and localization makes YOLO incredibly powerful for a wide range of applications.</p>
</section>
<section id="the-problem-yolo-solves" class="level2">
<h2 class="anchored" data-anchor-id="the-problem-yolo-solves" id="the-problem-yolo-solves">The Problem YOLO Solves</h2>
<p>Before YOLO, object detection was a complex, multi-step process that was both computationally expensive and time-consuming. Traditional approaches like R-CNN (Region-based Convolutional Neural Networks) would:</p>
<ol type="1">
<li>Generate thousands of potential object regions in an image</li>
<li>Run a classifier on each region separately</li>
<li>Post-process the results to eliminate duplicates</li>
</ol>
<p>This approach, while accurate, was incredibly slow. Processing a single image could take several seconds, making real-time applications virtually impossible.</p>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Note
</div>
</div>
<div class="callout-body-container callout-body">
<p>YOLO revolutionized this by treating object detection as a single regression problem. Instead of looking at an image multiple times, YOLO divides the image into a grid and predicts bounding boxes and class probabilities for each grid cell simultaneously.</p>
</div>
</div>
</section>
<section id="how-yolo-works-the-core-concept" class="level2">
<h2 class="anchored" data-anchor-id="how-yolo-works-the-core-concept" id="how-yolo-works-the-core-concept">How YOLO Works: The Core Concept</h2>
<section id="grid-based-approach" class="level3">
<h3 class="anchored" data-anchor-id="grid-based-approach" id="grid-based-approach">Grid-Based Approach</h3>
<p>YOLO divides an input image into an S×S grid (7×7 in the original YOLO, 13×13 in YOLOv2 at a 416×416 input). Each grid cell is responsible for detecting objects whose centers fall within that cell. Every region of the image therefore belongs to exactly one cell, and all cells are predicted together in a single forward pass.</p>
</section>
<section id="bounding-box-prediction" class="level3">
<h3 class="anchored" data-anchor-id="bounding-box-prediction" id="bounding-box-prediction">Bounding Box Prediction</h3>
<p>For each grid cell, YOLO predicts:</p>
<ul>
<li><strong>B bounding boxes</strong> (2 per cell in YOLOv1; later versions use anchor-based predictors, such as 5 per cell in YOLOv2 and 3 per scale in YOLOv3)</li>
<li><strong>Confidence scores</strong> for each bounding box</li>
<li><strong>Class probabilities</strong> for each grid cell</li>
</ul>
<p>Each bounding box consists of five values:</p>
<ul>
<li><strong>x, y</strong>: Center coordinates of the box (relative to the grid cell)</li>
<li><strong>width, height</strong>: Dimensions of the box (relative to the entire image)</li>
<li><strong>Confidence score</strong>: How likely the box is to contain an object, combined with how well the box is expected to fit it (Pr(Object) × IoU)</li>
</ul>
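<p>To see how these relative values become a box on the image, consider this small sketch. It assumes a YOLOv1-style output, with the center offset relative to its cell and the size relative to the whole image, and it uses pixel coordinates. The function name is hypothetical.</p>

```python
def cell_to_image_box(x, y, w, h, col, row, S, img_w, img_h):
    """Convert a YOLOv1-style prediction to absolute pixel corners (x1, y1, x2, y2).

    x, y: box center offset within cell (col, row), each in [0, 1]
    w, h: box size as a fraction of the whole image
    """
    cx = (col + x) / S * img_w   # center in pixels
    cy = (row + y) / S * img_h
    bw = w * img_w               # size in pixels
    bh = h * img_h
    return (cx - bw / 2, cy - bh / 2, cx + bw / 2, cy + bh / 2)


# Center of cell (3, 3) in a 7x7 grid on a 448x448 image: the image center
print(cell_to_image_box(0.5, 0.5, 0.25, 0.5, col=3, row=3, S=7, img_w=448, img_h=448))
# (168.0, 112.0, 280.0, 336.0)
```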
</section>
<section id="class-prediction" class="level3">
<h3 class="anchored" data-anchor-id="class-prediction" id="class-prediction">Class Prediction</h3>
<p>Each grid cell also predicts the probability of each object class (person, car, dog, etc.), given that an object is present in that cell. Multiplying these class probabilities by each box’s confidence score yields class-specific scores. Each score says both what the box contains and how well it is localized.</p>
</section>
<section id="network-architecture" class="level3">
<h3 class="anchored" data-anchor-id="network-architecture" id="network-architecture">Network Architecture</h3>
<p>The YOLO network is a convolutional neural network (CNN). The original YOLO used 24 convolutional layers followed by 2 fully connected layers. Its design was inspired by GoogLeNet, using 1×1 reduction layers in place of inception modules. Subsequent versions have evolved toward more efficient designs.</p>
<p>The network consists of:</p>
<ul>
<li><strong>Convolutional layers</strong> for feature extraction</li>
<li><strong>Fully connected layers</strong> for prediction (in YOLOv1; from YOLOv2 onward, prediction is done with convolutional layers)</li>
<li><strong>Output layer</strong> that produces the final detection results</li>
</ul>
</section>
</section>
<section id="evolution-of-yolo-from-v1-to-v8" class="level2">
<h2 class="anchored" data-anchor-id="evolution-of-yolo-from-v1-to-v8" id="evolution-of-yolo-from-v1-to-v8">Evolution of YOLO: From v1 to v8</h2>
<section id="yolov1-2015" class="level3">
<h3 class="anchored" data-anchor-id="yolov1-2015" id="yolov1-2015">YOLOv1 (2015)</h3>
<p>The original YOLO introduced the revolutionary single-shot detection concept. While groundbreaking, it had limitations in detecting small objects and struggled with objects that were close together.</p>
</section>
<section id="yolov2-2016" class="level3">
<h3 class="anchored" data-anchor-id="yolov2-2016" id="yolov2-2016">YOLOv2 (2016)</h3>
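<p>YOLOv2 was introduced in the paper “YOLO9000: Better, Faster, Stronger”. In that paper, YOLO9000 is a variant trained jointly on detection and classification data to recognize over 9,000 object categories. YOLOv2 introduced:</p>
<ul>
<li>Batch normalization for improved training</li>
<li>Anchor boxes for better bounding box predictions</li>
<li>Fine-tuning of the classifier at a higher resolution (448×448) before detection training</li>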
<li>Multi-scale training for robustness</li>
</ul>
</section>
<section id="yolov3-2018" class="level3">
<h3 class="anchored" data-anchor-id="yolov3-2018" id="yolov3-2018">YOLOv3 (2018)</h3>
<p>Significant improvements included:</p>
<ul>
<li>Feature Pyramid Networks (FPN) for better multi-scale detection</li>
<li>Logistic regression for object confidence</li>
<li>Multi-label classification capability</li>
<li>Darknet-53 backbone for improved feature extraction</li>
</ul>
</section>
<section id="yolov4-2020" class="level3">
<h3 class="anchored" data-anchor-id="yolov4-2020" id="yolov4-2020">YOLOv4 (2020)</h3>
<p>Focused on optimization and practical improvements:</p>
<ul>
<li>CSPDarknet53 backbone</li>
<li>SPP (Spatial Pyramid Pooling) block</li>
<li>PANet path aggregation</li>
<li>Extensive use of data augmentation techniques</li>
</ul>
</section>
<section id="yolov5-2020" class="level3">
<h3 class="anchored" data-anchor-id="yolov5-2020" id="yolov5-2020">YOLOv5 (2020)</h3>
<p>Developed by Ultralytics, not the original authors:</p>
<ul>
<li>PyTorch implementation for easier use</li>
<li>Improved training procedures</li>
<li>Better model scaling</li>
<li>Enhanced user experience and documentation</li>
</ul>
</section>
<section id="yolov6-v8-2021-2023" class="level3">
<h3 class="anchored" data-anchor-id="yolov6-v8-2021-2023" id="yolov6-v8-2021-2023">YOLOv6-v8 (2022-2023)</h3>
<p>Continued refinements focusing on:</p>
<ul>
<li>Improved accuracy-speed trade-offs</li>
<li>Better mobile and edge device support</li>
<li>Enhanced training techniques</li>
<li>More robust architectures</li>
</ul>
</section>
</section>
<section id="key-advantages-of-yolo" class="level2">
<h2 class="anchored" data-anchor-id="key-advantages-of-yolo" id="key-advantages-of-yolo">Key Advantages of YOLO</h2>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Why Choose YOLO?
</div>
</div>
<div class="callout-body-container callout-body">
<p>YOLO’s main advantages make it ideal for real-time applications:</p>
<ul>
<li><strong>Speed</strong>: Single-pass approach enables real-time processing</li>
<li><strong>Global Context</strong>: Sees entire image for better understanding</li>
<li><strong>Simplicity</strong>: Unified architecture for easy implementation</li>
<li><strong>End-to-End Training</strong>: Optimizes entire pipeline jointly</li>
</ul>
</div>
</div>
<section id="speed" class="level3">
<h3 class="anchored" data-anchor-id="speed" id="speed">Speed</h3>
<p>YOLO’s single-pass approach makes it fast. The original YOLO ran at 45 frames per second on a Titan X GPU, and the smaller Fast YOLO ran at 155 fps. Modern versions also run in real time on GPUs, with small variants aimed at edge devices.</p>
</section>
<section id="global-context" class="level3">
<h3 class="anchored" data-anchor-id="global-context" id="global-context">Global Context</h3>
<p>Unlike sliding window approaches, YOLO sees the entire image during training and testing, allowing it to understand global context and make more informed predictions.</p>
</section>
<section id="generalization" class="level3">
<h3 class="anchored" data-anchor-id="generalization" id="generalization">Generalization</h3>
<p>YOLO learns generalizable representations of objects. In the original paper, YOLO was trained on natural images and tested on artwork. In that setting it outperformed detectors such as DPM and R-CNN by a wide margin.</p>
</section>
<section id="simplicity" class="level3">
<h3 class="anchored" data-anchor-id="simplicity" id="simplicity">Simplicity</h3>
<p>The unified architecture makes YOLO easier to understand, implement, and modify compared to multi-stage detection systems.</p>
</section>
<section id="end-to-end-training" class="level3">
<h3 class="anchored" data-anchor-id="end-to-end-training" id="end-to-end-training">End-to-End Training</h3>
<p>The entire detection pipeline can be optimized jointly, leading to better overall performance.</p>
</section>
</section>
<section id="common-applications" class="level2">
<h2 class="anchored" data-anchor-id="common-applications" id="common-applications">Common Applications</h2>
<section id="autonomous-vehicles" class="level3">
<h3 class="anchored" data-anchor-id="autonomous-vehicles" id="autonomous-vehicles">Autonomous Vehicles</h3>
<p>YOLO-style detectors are used in autonomous driving systems and research to detect pedestrians, other vehicles, traffic signs, and road obstacles in real time.</p>
</section>
<section id="security-and-surveillance" class="level3">
<h3 class="anchored" data-anchor-id="security-and-surveillance" id="security-and-surveillance">Security and Surveillance</h3>
<p>Security systems use YOLO to detect unauthorized persons, suspicious activities, or specific objects in video feeds.</p>
</section>
<section id="retail-and-inventory-management" class="level3">
<h3 class="anchored" data-anchor-id="retail-and-inventory-management" id="retail-and-inventory-management">Retail and Inventory Management</h3>
<p>Stores use YOLO for automated checkout systems, inventory tracking, and customer behavior analysis.</p>
</section>
<section id="sports-analytics" class="level3">
<h3 class="anchored" data-anchor-id="sports-analytics" id="sports-analytics">Sports Analytics</h3>
<p>YOLO tracks players, balls, and other objects in sports videos for performance analysis and automated highlighting.</p>
</section>
<section id="medical-imaging" class="level3">
<h3 class="anchored" data-anchor-id="medical-imaging" id="medical-imaging">Medical Imaging</h3>
<p>In healthcare, YOLO assists in detecting anomalies in medical images, though this requires specialized training and validation.</p>
</section>
<section id="industrial-automation" class="level3">
<h3 class="anchored" data-anchor-id="industrial-automation" id="industrial-automation">Industrial Automation</h3>
<p>Manufacturing uses YOLO for quality control, defect detection, and automated sorting systems.</p>
</section>
</section>
<section id="getting-started-with-yolo" class="level2">
<h2 class="anchored" data-anchor-id="getting-started-with-yolo" id="getting-started-with-yolo">Getting Started with YOLO</h2>
<section id="prerequisites" class="level3">
<h3 class="anchored" data-anchor-id="prerequisites" id="prerequisites">Prerequisites</h3>
<ul>
<li>Basic understanding of machine learning concepts</li>
<li>Familiarity with Python programming</li>
<li>Understanding of computer vision fundamentals</li>
<li>Knowledge of deep learning frameworks (PyTorch or TensorFlow)</li>
</ul>
</section>
<section id="installation" class="level3">
<h3 class="anchored" data-anchor-id="installation" id="installation">Installation</h3>
<p>The easiest way to get started is with YOLOv5 or YOLOv8 using Ultralytics:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install ultralytics</span></code></pre></div></div>
</section>
<section id="basic-usage-example" class="level3">
<h3 class="anchored" data-anchor-id="basic-usage-example" id="basic-usage-example">Basic Usage Example</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> ultralytics <span class="im">import</span> YOLO</span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Load a pre-trained model</span></span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> YOLO(<span class="st">'yolov8n.pt'</span>)</span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Run inference on an image</span></span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a>results <span class="op">=</span> model(<span class="st">'path/to/image.jpg'</span>)</span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a><span class="co"># Display results</span></span>
<span id="cb2-10"><a href="#cb2-10" aria-hidden="true" tabindex="-1"></a>results[<span class="dv">0</span>].show()</span></code></pre></div></div>
</section>
<section id="training-on-custom-data" class="level3">
<h3 class="anchored" data-anchor-id="training-on-custom-data" id="training-on-custom-data">Training on Custom Data</h3>
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Warning</span>Training Requirements
</div>
</div>
<div class="callout-body-container callout-body">
<p>Before training on custom data, ensure you have:</p>
<ol type="1">
<li><strong>Prepared dataset</strong> in YOLO format</li>
<li><strong>Configuration file</strong> specifying classes and paths</li>
<li><strong>Adequate computational resources</strong> for training</li>
<li><strong>Validation strategy</strong> for model evaluation</li>
</ol>
</div>
</div>
<ol type="1">
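<li><strong>Prepare your dataset</strong> in YOLO format: one <code>.txt</code> label file per image, with one line per object giving <code>class x_center y_center width height</code>, normalized to [0, 1]</li>
<li><strong>Create a configuration file</strong> (a YAML file in Ultralytics) listing the class names and the paths to the training and validation images</li>
<li><strong>Train the model</strong>, for example with <code>model.train(data="data.yaml", epochs=100, imgsz=640)</code> in Ultralytics</li>
<li><strong>Evaluate and fine-tune</strong> the model on the validation set</li>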
</ol>
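<p>To make the label format concrete, here is a minimal sketch of turning a pixel-space corner box into one YOLO label line. The helper name <code>to_yolo_line</code> is hypothetical, and it assumes boxes arrive as <code>(x1, y1, x2, y2)</code> in pixels:</p>

```python
from typing import Tuple


def to_yolo_line(class_id: int, box_xyxy: Tuple[float, float, float, float],
                 img_w: int, img_h: int) -> str:
    """Convert a pixel box (x1, y1, x2, y2) to 'class xc yc w h', normalized to [0, 1]."""
    x1, y1, x2, y2 = box_xyxy
    xc = (x1 + x2) / 2 / img_w  # box center, as a fraction of image width
    yc = (y1 + y2) / 2 / img_h  # box center, as a fraction of image height
    w = (x2 - x1) / img_w
    h = (y2 - y1) / img_h
    return f"{class_id} {xc:.6f} {yc:.6f} {w:.6f} {h:.6f}"


# A 200x100 px box in the top-left corner of a 640x480 image, class 0
print(to_yolo_line(0, (0, 0, 200, 100), 640, 480))
# 0 0.156250 0.104167 0.312500 0.208333
```

<p>Writing one such line per object into a <code>.txt</code> file that shares the image's base name gives the per-image label file described in step 1.</p>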
</section>
</section>
<section id="understanding-yolo-output" class="level2">
<h2 class="anchored" data-anchor-id="understanding-yolo-output" id="understanding-yolo-output">Understanding YOLO Output</h2>
<section id="bounding-boxes" class="level3">
<h3 class="anchored" data-anchor-id="bounding-boxes" id="bounding-boxes">Bounding Boxes</h3>
<p>Each detected object is represented by a bounding box, commonly given as (x, y, width, height) where (x, y) is the box center, together with a confidence score.</p>
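<p>Detectors and tools report boxes either in this center format or as corner coordinates (x1, y1, x2, y2), so converting between the two is a common first step. A minimal sketch, assuming pixel units and plain tuples (the function names are illustrative, not part of any library):</p>

```python
from typing import Tuple

Box = Tuple[float, float, float, float]


def xywh_to_xyxy(box: Box) -> Box:
    """Convert (x_center, y_center, width, height) to (x1, y1, x2, y2)."""
    x, y, w, h = box
    return (x - w / 2, y - h / 2, x + w / 2, y + h / 2)


def xyxy_to_xywh(box: Box) -> Box:
    """Convert (x1, y1, x2, y2) to (x_center, y_center, width, height)."""
    x1, y1, x2, y2 = box
    return ((x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1)


print(xywh_to_xyxy((50, 40, 20, 10)))  # (40.0, 35.0, 60.0, 45.0)
```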
</section>
<section id="class-predictions" class="level3">
<h3 class="anchored" data-anchor-id="class-predictions" id="class-predictions">Class Predictions</h3>
<p>Each bounding box includes class probabilities indicating what type of object was detected.</p>
</section>
<section id="confidence-scores" class="level3">
<h3 class="anchored" data-anchor-id="confidence-scores" id="confidence-scores">Confidence Scores</h3>
<p>These indicate how certain the model is about the detection. Higher scores mean more confident detections.</p>
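<p>Class probabilities and confidence scores are typically combined into a final label and score, and low-confidence boxes are dropped. The sketch below assumes each raw detection is a box plus a list of per-class probabilities and uses the best class probability as the score. That layout and the 0.25 default threshold are illustrative assumptions; exact scoring differs between YOLO versions (earlier versions, for instance, also multiply in an objectness score).</p>

```python
from typing import List, Sequence, Tuple

Box = Tuple[float, float, float, float]


def filter_detections(
    boxes: Sequence[Box],
    class_probs: Sequence[Sequence[float]],
    class_names: Sequence[str],
    conf_threshold: float = 0.25,
) -> List[Tuple[Box, str, float]]:
    """Keep detections whose best class probability reaches conf_threshold."""
    kept = []
    for box, probs in zip(boxes, class_probs):
        # The predicted class is the one with the highest probability
        class_id = max(range(len(probs)), key=lambda i: probs[i])
        score = probs[class_id]
        if score >= conf_threshold:
            kept.append((box, class_names[class_id], score))
    return kept


boxes = [(10, 10, 50, 50), (60, 60, 90, 90)]
probs = [[0.1, 0.8], [0.15, 0.2]]
print(filter_detections(boxes, probs, ["cat", "dog"]))
# [((10, 10, 50, 50), 'dog', 0.8)]
```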
</section>
</section>
<section id="common-challenges-and-solutions" class="level2">
<h2 class="anchored" data-anchor-id="common-challenges-and-solutions" id="common-challenges-and-solutions">Common Challenges and Solutions</h2>
<section id="small-object-detection" class="level3">
<h3 class="anchored" data-anchor-id="small-object-detection" id="small-object-detection">Small Object Detection</h3>
<ul>
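<li><strong>Challenge</strong>: YOLO has traditionally struggled with very small objects. In the original YOLO, each cell of a coarse 7×7 grid predicted only two boxes, so small, closely packed objects were easily missed.</li>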
<li><strong>Solution</strong>: Use higher resolution inputs, multi-scale training, and feature pyramid networks.</li>
</ul>
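<p>In multi-scale training, the input resolution changes from batch to batch so the network sees objects at many apparent sizes. A minimal sketch of choosing a resolution per batch. It assumes a base size of 640, a ±50% range, and a stride of 32 (the largest downsampling factor in typical YOLO backbones, so the input size must be divisible by it). These values are illustrative, not the defaults of any particular release:</p>

```python
import random
from typing import Tuple


def sample_train_size(base: int = 640, scale_range: Tuple[float, float] = (0.5, 1.5),
                      stride: int = 32) -> int:
    """Pick a random training resolution around `base`, rounded to a multiple of `stride`."""
    lo, hi = scale_range
    size = random.uniform(base * lo, base * hi)
    # Round to the nearest multiple of the stride so feature maps divide evenly
    return max(stride, int(round(size / stride)) * stride)


random.seed(0)
print([sample_train_size() for _ in range(5)])  # one size per batch
```

<p>Each batch is then resized to the sampled resolution before the forward pass.</p>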
</section>
<section id="overlapping-objects" class="level3">
<h3 class="anchored" data-anchor-id="overlapping-objects" id="overlapping-objects">Overlapping Objects</h3>
<ul>
<li><strong>Challenge</strong>: Objects that overlap significantly can be difficult to detect separately.</li>
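<li><strong>Solution</strong>: Tune the non-maximum suppression (NMS) IoU threshold so that duplicate boxes for one object are removed without suppressing genuinely distinct, overlapping objects. Improved anchor box strategies also help.</li>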
</ul>
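<p>Greedy NMS keeps the highest-scoring box and discards every remaining box whose IoU with it exceeds a threshold. Raising the threshold keeps more genuinely overlapping objects, at the cost of more duplicates. A minimal pure-Python sketch, class-agnostic, with boxes as <code>(x1, y1, x2, y2)</code>. Production code would use a vectorized routine such as <code>torchvision.ops.nms</code>.</p>

```python
from typing import List, Sequence, Tuple

Box = Tuple[float, float, float, float]


def iou(a: Box, b: Box) -> float:
    """Intersection over union of two (x1, y1, x2, y2) boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def nms(boxes: Sequence[Box], scores: Sequence[float], iou_threshold: float = 0.5) -> List[int]:
    """Return indices of kept boxes, highest score first."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep: List[int] = []
    while order:
        best = order.pop(0)
        keep.append(best)
        # Drop every remaining box that overlaps the kept box too much
        order = [i for i in order if iou(boxes[best], boxes[i]) <= iou_threshold]
    return keep


boxes = [(0, 0, 10, 10), (1, 1, 11, 11), (20, 20, 30, 30)]
scores = [0.9, 0.8, 0.7]
print(nms(boxes, scores))  # [0, 2]
```

<p>The first two boxes overlap with an IoU of about 0.68, so at a 0.5 threshold the lower-scoring one is removed. At 0.7 both would be kept, which is what you want if they are in fact two different objects.</p>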
</section>
<section id="class-imbalance" class="level3">
<h3 class="anchored" data-anchor-id="class-imbalance" id="class-imbalance">Class Imbalance</h3>
<ul>
<li><strong>Challenge</strong>: Some object classes may be underrepresented in training data.</li>
<li><strong>Solution</strong>: Use data augmentation, balanced sampling, and focal loss techniques.</li>
</ul>
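<p>Focal loss (introduced with RetinaNet) down-weights easy, well-classified examples so that rare or hard examples dominate training: FL(p<sub>t</sub>) = −α<sub>t</sub>(1 − p<sub>t</sub>)<sup>γ</sup> log(p<sub>t</sub>). A minimal sketch of the binary form for a single prediction, using the commonly cited defaults α = 0.25 and γ = 2. With γ = 0 and α<sub>t</sub> = 1 it reduces to ordinary cross-entropy.</p>

```python
import math


def focal_loss(p: float, target: int, alpha: float = 0.25, gamma: float = 2.0) -> float:
    """Binary focal loss for one prediction.

    p is the predicted probability of the positive class; target is 1 or 0.
    """
    eps = 1e-7
    p = min(max(p, eps), 1 - eps)  # clamp to avoid log(0)
    p_t = p if target == 1 else 1 - p  # probability assigned to the true label
    alpha_t = alpha if target == 1 else 1 - alpha
    # (1 - p_t) ** gamma shrinks the loss of well-classified (easy) examples
    return -alpha_t * (1 - p_t) ** gamma * math.log(p_t)


easy_negative = focal_loss(0.1, target=0)  # background the model already rejects
hard_positive = focal_loss(0.1, target=1)  # real object the model almost misses
print(f"easy negative: {easy_negative:.5f}, hard positive: {hard_positive:.5f}")
```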
</section>
<section id="domain-adaptation" class="level3">
<h3 class="anchored" data-anchor-id="domain-adaptation" id="domain-adaptation">Domain Adaptation</h3>
<ul>
<li><strong>Challenge</strong>: Models trained on one type of data may not work well on different domains.</li>
<li><strong>Solution</strong>: Transfer learning, domain adaptation techniques, and diverse training data.</li>
</ul>
</section>
</section>
<section id="best-practices-for-yolo-implementation" class="level2">
<h2 class="anchored" data-anchor-id="best-practices-for-yolo-implementation" id="best-practices-for-yolo-implementation">Best Practices for YOLO Implementation</h2>
<section id="data-preparation" class="level3">
<h3 class="anchored" data-anchor-id="data-preparation" id="data-preparation">Data Preparation</h3>
<ul>
<li>Ensure high-quality, diverse training data</li>
<li>Use proper annotation tools and formats</li>
<li>Implement data augmentation techniques</li>
<li>Maintain balanced class distributions</li>
</ul>
</section>
<section id="training-optimization" class="level3">
<h3 class="anchored" data-anchor-id="training-optimization" id="training-optimization">Training Optimization</h3>
<ul>
<li>Start with pre-trained weights</li>
<li>Use appropriate learning rates and schedules</li>
<li>Monitor training metrics carefully</li>
<li>Implement early stopping to prevent overfitting</li>
</ul>
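<p>Early stopping can be as simple as tracking the best validation metric and stopping after a fixed number of epochs without improvement. A minimal framework-agnostic sketch. The class name and the mAP values in the usage loop are made up for illustration; Ultralytics' trainer exposes a comparable <code>patience</code> argument.</p>

```python
from typing import Optional


class EarlyStopping:
    """Signal a stop when a validation metric has not improved for `patience` epochs."""

    def __init__(self, patience: int = 10, min_delta: float = 0.0, mode: str = "max"):
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode  # "max" for metrics like mAP, "min" for losses
        self.best: Optional[float] = None
        self.bad_epochs = 0

    def step(self, value: float) -> bool:
        """Record one epoch's metric; return True if training should stop."""
        if self.best is None:
            improved = True
        elif self.mode == "max":
            improved = value > self.best + self.min_delta
        else:
            improved = value < self.best - self.min_delta
        if improved:
            self.best = value
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Validation mAP per epoch (made-up numbers for illustration)
stopper = EarlyStopping(patience=2)
for epoch, val_map in enumerate([0.30, 0.35, 0.36, 0.355, 0.358]):
    if stopper.step(val_map):
        print(f"stopping at epoch {epoch}, best mAP {stopper.best}")
        break
```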
</section>
<section id="model-selection" class="level3">
<h3 class="anchored" data-anchor-id="model-selection" id="model-selection">Model Selection</h3>
<ul>
<li>Choose the right YOLO version for your speed-accuracy requirements</li>
<li>Consider model size constraints for deployment</li>
<li>Evaluate different backbone architectures</li>
</ul>
</section>
<section id="post-processing" class="level3">
<h3 class="anchored" data-anchor-id="post-processing" id="post-processing">Post-Processing</h3>
<ul>
<li>Tune non-maximum suppression parameters</li>
<li>Set appropriate confidence thresholds</li>
<li>Implement tracking for video applications</li>
</ul>
</section>
</section>
<section id="performance-metrics" class="level2">
<h2 class="anchored" data-anchor-id="performance-metrics" id="performance-metrics">Performance Metrics</h2>
<section id="mean-average-precision-map" class="level3">
<h3 class="anchored" data-anchor-id="mean-average-precision-map" id="mean-average-precision-map">Mean Average Precision (mAP)</h3>
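<p>The primary metric for evaluating object detection. Average precision (AP) is the area under a class's precision-recall curve, traced by sweeping the confidence threshold from high to low. mAP averages AP over all classes, either at a fixed IoU threshold (e.g. mAP@0.5) or averaged over several thresholds (COCO's mAP@0.5:0.95).</p>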
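<p>A simplified sketch of AP for one class using all-point interpolation. It assumes each detection has already been matched against ground truth at the chosen IoU threshold and marked as a true or false positive. Benchmark toolkits such as the COCO evaluator differ in details (e.g. 101-point recall sampling), so treat this as illustrative:</p>

```python
from typing import List, Sequence


def average_precision(scores: Sequence[float], is_tp: Sequence[bool], num_gt: int) -> float:
    """AP for one class with all-point interpolation.

    scores: confidence of each detection; is_tp: whether it matched a ground-truth box;
    num_gt: number of ground-truth boxes for this class (must be positive).
    """
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    tp = fp = 0
    recalls: List[float] = []
    precisions: List[float] = []
    for i in order:  # lower the confidence threshold one detection at a time
        if is_tp[i]:
            tp += 1
        else:
            fp += 1
        recalls.append(tp / num_gt)
        precisions.append(tp / (tp + fp))
    # Make precision non-increasing in recall (the interpolated "envelope")
    for k in range(len(precisions) - 2, -1, -1):
        precisions[k] = max(precisions[k], precisions[k + 1])
    # Area under the step-shaped precision-recall curve
    ap, prev_recall = 0.0, 0.0
    for r, p in zip(recalls, precisions):
        ap += (r - prev_recall) * p
        prev_recall = r
    return ap


def mean_average_precision(per_class_ap: Sequence[float]) -> float:
    """mAP is the unweighted mean of the per-class AP values."""
    return sum(per_class_ap) / len(per_class_ap)


print(average_precision([0.9, 0.8, 0.7], [True, False, True], num_gt=2))
```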
</section>
<section id="intersection-over-union-iou" class="level3">
<h3 class="anchored" data-anchor-id="intersection-over-union-iou" id="intersection-over-union-iou">Intersection over Union (IoU)</h3>
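<p>Measures the overlap between a predicted and a ground-truth bounding box as the area of their intersection divided by the area of their union, ranging from 0 (no overlap) to 1 (identical boxes). A detection usually counts as correct only when its IoU with a ground-truth box exceeds a threshold such as 0.5.</p>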
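<p>A minimal sketch of IoU for axis-aligned boxes given as <code>(x1, y1, x2, y2)</code> (the function name is illustrative):</p>

```python
from typing import Tuple

Box = Tuple[float, float, float, float]


def box_iou(a: Box, b: Box) -> float:
    """Area of intersection divided by area of union for two (x1, y1, x2, y2) boxes."""
    # Corners of the intersection rectangle (may be empty)
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


# Two 10x10 boxes offset by 5 px: intersection 25, union 175
print(box_iou((0, 0, 10, 10), (5, 5, 15, 15)))  # 1/7, about 0.143
```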
</section>
<section id="frames-per-second-fps" class="level3">
<h3 class="anchored" data-anchor-id="frames-per-second-fps" id="frames-per-second-fps">Frames Per Second (FPS)</h3>
<p>Measures the speed of the detection system, crucial for real-time applications.</p>
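<p>A minimal sketch of measuring FPS for any inference callable. It excludes a few warm-up runs from timing and averages over many iterations. The 10 ms <code>fake_model</code> is a stand-in for a real detector. When timing GPU inference with PyTorch, you would also call <code>torch.cuda.synchronize()</code> before reading the clock, because GPU work runs asynchronously.</p>

```python
import time
from typing import Any, Callable


def measure_fps(infer: Callable[[Any], Any], frame: Any, warmup: int = 5, iters: int = 50) -> float:
    """Average frames per second of `infer` on one repeated input."""
    for _ in range(warmup):  # warm-up runs (caches, lazy init) are not timed
        infer(frame)
    start = time.perf_counter()
    for _ in range(iters):
        infer(frame)
    elapsed = time.perf_counter() - start
    return iters / elapsed


def fake_model(frame: Any) -> None:
    """Stand-in for a detector that takes about 10 ms per frame."""
    time.sleep(0.01)


print(f"{measure_fps(fake_model, frame=None):.1f} FPS")
```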
</section>
<section id="model-size" class="level3">
<h3 class="anchored" data-anchor-id="model-size" id="model-size">Model Size</h3>
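<p>Usually reported as the parameter count or the size of the weight file. It matters for deployment on resource-constrained devices such as phones and embedded boards.</p>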
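<p>Weight storage scales with the parameter count times the bytes per parameter, which is why lower-precision formats shrink models. A back-of-the-envelope sketch for a hypothetical 3-million-parameter detector (in PyTorch, the real count is <code>sum(p.numel() for p in model.parameters())</code>). File-format overhead is ignored:</p>

```python
def model_size_mb(num_params: int, bytes_per_param: int) -> float:
    """Approximate weight storage in megabytes (1 MB = 1024**2 bytes)."""
    return num_params * bytes_per_param / 1024 ** 2


num_params = 3_000_000  # hypothetical small detector
for dtype, nbytes in [("float32", 4), ("float16", 2), ("int8", 1)]:
    print(f"{dtype}: {model_size_mb(num_params, nbytes):.1f} MB")
```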
</section>
</section>
<section id="future-trends-and-developments" class="level2">
<h2 class="anchored" data-anchor-id="future-trends-and-developments" id="future-trends-and-developments">Future Trends and Developments</h2>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Emerging Trends
</div>
</div>
<div class="callout-body-container callout-body">
<p>The future of YOLO and object detection includes:</p>
<ul>
<li><strong>Transformer-based architectures</strong> for improved attention mechanisms</li>
<li><strong>Mobile optimization</strong> for edge deployment</li>
<li><strong>Multi-modal detection</strong> combining visual and other data</li>
<li><strong>Self-supervised learning</strong> to reduce labeling requirements</li>
</ul>
</div>
</div>
<section id="transformer-based-architectures" class="level3">
<h3 class="anchored" data-anchor-id="transformer-based-architectures" id="transformer-based-architectures">Transformer-Based Architectures</h3>
<p>Integration of transformer models for improved feature extraction and attention mechanisms.</p>
</section>
<section id="mobile-and-edge-optimization" class="level3">
<h3 class="anchored" data-anchor-id="mobile-and-edge-optimization" id="mobile-and-edge-optimization">Mobile and Edge Optimization</h3>
<p>Continued focus on making YOLO more efficient for mobile and edge devices.</p>
</section>
<section id="multi-modal-detection" class="level3">
<h3 class="anchored" data-anchor-id="multi-modal-detection" id="multi-modal-detection">Multi-Modal Detection</h3>
<p>Combining visual information with other modalities like text or audio.</p>
</section>
<section id="improved-small-object-detection" class="level3">
<h3 class="anchored" data-anchor-id="improved-small-object-detection" id="improved-small-object-detection">Improved Small Object Detection</h3>
<p>Advanced techniques for detecting very small objects in high-resolution images.</p>
</section>
<section id="self-supervised-learning" class="level3">
<h3 class="anchored" data-anchor-id="self-supervised-learning" id="self-supervised-learning">Self-Supervised Learning</h3>
<p>Reducing dependence on labeled data through self-supervised training approaches.</p>
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>YOLO has democratized object detection by making it fast, accurate, and accessible to developers worldwide. Its evolution from the original 2015 paper to the latest versions demonstrates the rapid pace of innovation in computer vision.</p>
<p>Understanding YOLO opens doors to numerous applications across industries, from autonomous vehicles to retail analytics. The algorithm’s simplicity, combined with its powerful capabilities, makes it an essential tool in the modern AI toolkit.</p>
<p>As you begin your journey with YOLO, remember that practical experience is invaluable. Start with pre-trained models, experiment with different versions, and gradually work toward training custom models for your specific use cases. The computer vision community continues to push the boundaries of what’s possible with object detection, and YOLO remains at the forefront of these exciting developments.</p>
<p>Whether you’re building the next generation of smart cameras, developing autonomous systems, or simply exploring the fascinating world of computer vision, YOLO provides a solid foundation for understanding how machines can see and interpret the world around us.</p>
</section>
<section id="appendix-additional-resources" class="level2">
<h2 class="anchored" data-anchor-id="appendix-additional-resources" id="appendix-additional-resources">Appendix: Additional Resources</h2>
<section id="useful-links" class="level3">
<h3 class="anchored" data-anchor-id="useful-links" id="useful-links">Useful Links</h3>
<ul>
<li><a href="https://docs.ultralytics.com/">Ultralytics YOLO Documentation</a></li>
<li><a href="https://arxiv.org/abs/1506.02640">Original YOLO Paper</a></li>
<li><a href="https://github.com/ultralytics/ultralytics">YOLOv8 GitHub Repository</a></li>
</ul>
</section>
<section id="code-examples" class="level3">
<h3 class="anchored" data-anchor-id="code-examples" id="code-examples">Code Examples</h3>
<p>Additional code examples and tutorials are available in the Ultralytics documentation and GitHub repository linked above.</p>



</section>
</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Neural Architecture Search: Complete Code Guide]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/neural-architecture-search/nas-code/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/neural-architecture-search/nas-code/</guid>
      <pubDate>Fri, 11 Jul 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>tutorial</category>
      <category>advanced</category>
      <content:encoded><![CDATA[






<section id="neural-architecture-search-complete-code-guide" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/neural-architecture-search/nas-code/nas-code.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>Neural Architecture Search (NAS) is an automated approach to designing neural network architectures. Instead of manually crafting network designs, NAS algorithms explore a space of possible architectures to find high-performing configurations for specific tasks.</p>
<section id="why-nas-matters" class="level3">
<h3 class="anchored" data-anchor-id="why-nas-matters" id="why-nas-matters">Why NAS Matters</h3>
<ul>
<li><strong>Automation</strong>: Reduces human effort in architecture design</li>
<li><strong>Performance</strong>: Can discover architectures that outperform human-designed ones</li>
<li><strong>Efficiency</strong>: Optimizes for specific constraints (latency, memory, energy)</li>
<li><strong>Scalability</strong>: Adapts to different tasks and domains</li>
</ul>
</section>
</section>
<section id="theoretical-foundations" class="level2">
<h2 class="anchored" data-anchor-id="theoretical-foundations" id="theoretical-foundations">Theoretical Foundations</h2>
<section id="the-nas-framework" class="level3">
<h3 class="anchored" data-anchor-id="the-nas-framework" id="the-nas-framework">The NAS Framework</h3>
<p>NAS consists of three main components:</p>
<ol type="1">
<li><strong>Search Space</strong>: Defines the set of possible architectures</li>
<li><strong>Search Strategy</strong>: Determines how to explore the search space</li>
<li><strong>Performance Estimation</strong>: Evaluates architecture quality</li>
</ol>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> typing <span class="im">import</span> List, Dict, Tuple, Optional</span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> random</span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> collections <span class="im">import</span> defaultdict</span>
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> NASFramework:</span>
<span id="cb1-9"><a href="#cb1-9" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, search_space, search_strategy, performance_estimator):</span>
<span id="cb1-10"><a href="#cb1-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.search_space <span class="op">=</span> search_space</span>
<span id="cb1-11"><a href="#cb1-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.search_strategy <span class="op">=</span> search_strategy</span>
<span id="cb1-12"><a href="#cb1-12" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.performance_estimator <span class="op">=</span> performance_estimator</span>
<span id="cb1-13"><a href="#cb1-13" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.history <span class="op">=</span> []</span>
<span id="cb1-14"><a href="#cb1-14" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb1-15"><a href="#cb1-15" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> search(<span class="va">self</span>, num_iterations: <span class="bu">int</span>):</span>
<span id="cb1-16"><a href="#cb1-16" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Main NAS loop"""</span></span>
<span id="cb1-17"><a href="#cb1-17" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> iteration <span class="kw">in</span> <span class="bu">range</span>(num_iterations):</span>
<span id="cb1-18"><a href="#cb1-18" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Sample architecture from search space</span></span>
<span id="cb1-19"><a href="#cb1-19" aria-hidden="true" tabindex="-1"></a>            architecture <span class="op">=</span> <span class="va">self</span>.search_strategy.sample_architecture(</span>
<span id="cb1-20"><a href="#cb1-20" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.search_space, <span class="va">self</span>.history</span>
<span id="cb1-21"><a href="#cb1-21" aria-hidden="true" tabindex="-1"></a>            )</span>
<span id="cb1-22"><a href="#cb1-22" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb1-23"><a href="#cb1-23" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Evaluate architecture</span></span>
<span id="cb1-24"><a href="#cb1-24" aria-hidden="true" tabindex="-1"></a>            performance <span class="op">=</span> <span class="va">self</span>.performance_estimator.evaluate(architecture)</span>
<span id="cb1-25"><a href="#cb1-25" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb1-26"><a href="#cb1-26" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Update history</span></span>
<span id="cb1-27"><a href="#cb1-27" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.history.append({</span>
<span id="cb1-28"><a href="#cb1-28" aria-hidden="true" tabindex="-1"></a>                <span class="st">'architecture'</span>: architecture,</span>
<span id="cb1-29"><a href="#cb1-29" aria-hidden="true" tabindex="-1"></a>                <span class="st">'performance'</span>: performance,</span>
<span id="cb1-30"><a href="#cb1-30" aria-hidden="true" tabindex="-1"></a>                <span class="st">'iteration'</span>: iteration</span>
<span id="cb1-31"><a href="#cb1-31" aria-hidden="true" tabindex="-1"></a>            })</span>
<span id="cb1-32"><a href="#cb1-32" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb1-33"><a href="#cb1-33" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Update search strategy</span></span>
<span id="cb1-34"><a href="#cb1-34" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.search_strategy.update(architecture, performance)</span>
<span id="cb1-35"><a href="#cb1-35" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-36"><a href="#cb1-36" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.get_best_architecture()</span>
<span id="cb1-37"><a href="#cb1-37" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb1-38"><a href="#cb1-38" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> get_best_architecture(<span class="va">self</span>):</span>
<span id="cb1-39"><a href="#cb1-39" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="bu">max</span>(<span class="va">self</span>.history, key<span class="op">=</span><span class="kw">lambda</span> x: x[<span class="st">'performance'</span>])</span></code></pre></div></div>
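<p>To see the loop run end to end, here is a sketch of the simplest possible components: a random-search strategy and a stand-in estimator that scores architectures with a made-up formula instead of training them. The class names, the toy search space, and the scoring function are all hypothetical. In practice, <code>evaluate</code> would train and validate a network. Both classes use the method names <code>NASFramework</code> calls (<code>sample_architecture</code>, <code>update</code>, <code>evaluate</code>). The loop is repeated inline only so the block runs on its own.</p>

```python
import random
from typing import Dict, List


class RandomSearchStrategy:
    """Samples every architecture independently and ignores feedback."""

    def sample_architecture(self, search_space: Dict[str, List], history: List[Dict]) -> Dict:
        return {name: random.choice(choices) for name, choices in search_space.items()}

    def update(self, architecture: Dict, performance: float) -> None:
        pass  # random search does not learn from past results


class ToyEstimator:
    """Stand-in for 'train and validate': a made-up score peaking at 12 layers, width 128."""

    def evaluate(self, architecture: Dict) -> float:
        return -abs(architecture["num_layers"] - 12) - abs(architecture["width"] - 128) / 64


search_space = {"num_layers": [8, 12, 16, 20], "width": [32, 64, 128, 256]}
strategy, estimator, history = RandomSearchStrategy(), ToyEstimator(), []

random.seed(0)
for iteration in range(30):  # same steps as NASFramework.search
    arch = strategy.sample_architecture(search_space, history)
    perf = estimator.evaluate(arch)
    history.append({"architecture": arch, "performance": perf, "iteration": iteration})
    strategy.update(arch, perf)

best = max(history, key=lambda x: x["performance"])
print(best)
```

<p>With the framework class above in scope, <code>NASFramework(search_space, RandomSearchStrategy(), ToyEstimator()).search(30)</code> runs the same loop. Random search is a deliberately weak baseline. Smarter strategies differ only in how <code>sample_architecture</code> uses <code>history</code> and what <code>update</code> learns.</p>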
</section>
</section>
<section id="search-space-design" class="level2">
<h2 class="anchored" data-anchor-id="search-space-design" id="search-space-design">Search Space Design</h2>
<section id="macro-search-space" class="level3">
<h3 class="anchored" data-anchor-id="macro-search-space" id="macro-search-space">Macro Search Space</h3>
<p>Defines the overall structure of the network (number of layers, skip connections, etc.).</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MacroSearchSpace:</span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, max_layers: <span class="bu">int</span> <span class="op">=</span> <span class="dv">20</span>, operations: List[<span class="bu">str</span>] <span class="op">=</span> <span class="va">None</span>):</span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.max_layers <span class="op">=</span> max_layers</span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.operations <span class="op">=</span> operations <span class="kw">or</span> [</span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a>            <span class="st">'conv3x3'</span>, <span class="st">'conv5x5'</span>, <span class="st">'conv7x7'</span>, <span class="st">'maxpool3x3'</span>, </span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a>            <span class="st">'avgpool3x3'</span>, <span class="st">'identity'</span>, <span class="st">'zero'</span></span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a>        ]</span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> sample_architecture(<span class="va">self</span>) <span class="op">-&gt;</span> Dict:</span>
<span id="cb2-10"><a href="#cb2-10" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Sample a random architecture"""</span></span>
<span id="cb2-11"><a href="#cb2-11" aria-hidden="true" tabindex="-1"></a>        num_layers <span class="op">=</span> random.randint(<span class="dv">8</span>, <span class="va">self</span>.max_layers)</span>
<span id="cb2-12"><a href="#cb2-12" aria-hidden="true" tabindex="-1"></a>        architecture <span class="op">=</span> {</span>
<span id="cb2-13"><a href="#cb2-13" aria-hidden="true" tabindex="-1"></a>            <span class="st">'layers'</span>: [],</span>
<span id="cb2-14"><a href="#cb2-14" aria-hidden="true" tabindex="-1"></a>            <span class="st">'skip_connections'</span>: []</span>
<span id="cb2-15"><a href="#cb2-15" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb2-16"><a href="#cb2-16" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-17"><a href="#cb2-17" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(num_layers):</span>
<span id="cb2-18"><a href="#cb2-18" aria-hidden="true" tabindex="-1"></a>            op <span class="op">=</span> random.choice(<span class="va">self</span>.operations)</span>
<span id="cb2-19"><a href="#cb2-19" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Kernel size comes from the sampled op name ('conv5x5' gives 5); None for 'identity' and 'zero'</span></span>
<span id="cb2-20"><a href="#cb2-20" aria-hidden="true" tabindex="-1"></a>            kernel_size <span class="op">=</span> <span class="bu">int</span>(op[<span class="op">-</span><span class="dv">1</span>]) <span class="cf">if</span> op[<span class="op">-</span><span class="dv">1</span>].isdigit() <span class="cf">else</span> <span class="va">None</span></span>
<span id="cb2-21"><a href="#cb2-21" aria-hidden="true" tabindex="-1"></a>            layer <span class="op">=</span> {<span class="st">'operation'</span>: op, <span class="st">'kernel_size'</span>: kernel_size,</span>
<span id="cb2-22"><a href="#cb2-22" aria-hidden="true" tabindex="-1"></a>                     <span class="st">'filters'</span>: random.choice([<span class="dv">32</span>, <span class="dv">64</span>, <span class="dv">128</span>, <span class="dv">256</span>, <span class="dv">512</span>])}</span>
<span id="cb2-23"><a href="#cb2-23" aria-hidden="true" tabindex="-1"></a>            architecture[<span class="st">'layers'</span>].append(layer)</span>
<span id="cb2-24"><a href="#cb2-24" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-25"><a href="#cb2-25" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Add skip connections</span></span>
<span id="cb2-26"><a href="#cb2-26" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">1</span>, num_layers):</span>
<span id="cb2-27"><a href="#cb2-27" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> random.random() <span class="op">&lt;</span> <span class="fl">0.3</span>:  <span class="co"># 30% chance of skip connection</span></span>
<span id="cb2-28"><a href="#cb2-28" aria-hidden="true" tabindex="-1"></a>                source <span class="op">=</span> random.randint(<span class="dv">0</span>, i<span class="op">-</span><span class="dv">1</span>)</span>
<span id="cb2-29"><a href="#cb2-29" aria-hidden="true" tabindex="-1"></a>                architecture[<span class="st">'skip_connections'</span>].append((source, i))</span>
<span id="cb2-30"><a href="#cb2-30" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-31"><a href="#cb2-31" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> architecture</span></code></pre></div></div>
</section>
<section id="micro-search-space-cell-based" class="level3">
<h3 class="anchored" data-anchor-id="micro-search-space-cell-based" id="micro-search-space-cell-based">Micro Search Space (Cell-based)</h3>
<p>Focuses on designing building blocks (cells) that are repeated throughout the network.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> CellSearchSpace:</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, num_nodes: <span class="bu">int</span> <span class="op">=</span> <span class="dv">7</span>, num_ops: <span class="bu">int</span> <span class="op">=</span> <span class="dv">8</span>):</span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.num_nodes <span class="op">=</span> num_nodes</span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.operations <span class="op">=</span> [</span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a>            <span class="st">'none'</span>, <span class="st">'max_pool_3x3'</span>, <span class="st">'avg_pool_3x3'</span>, <span class="st">'skip_connect'</span>,</span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a>            <span class="st">'sep_conv_3x3'</span>, <span class="st">'sep_conv_5x5'</span>, <span class="st">'dil_conv_3x3'</span>, <span class="st">'dil_conv_5x5'</span></span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a>        ]</span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.num_ops <span class="op">=</span> <span class="bu">len</span>(<span class="va">self</span>.operations)</span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> sample_cell(<span class="va">self</span>) <span class="op">-&gt;</span> Dict:</span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Sample a normal cell and a reduction cell"""</span></span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a>        cell <span class="op">=</span> {</span>
<span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a>            <span class="st">'normal_cell'</span>: <span class="va">self</span>._sample_single_cell(),</span>
<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a>            <span class="st">'reduction_cell'</span>: <span class="va">self</span>._sample_single_cell()</span>
<span id="cb3-15"><a href="#cb3-15" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb3-16"><a href="#cb3-16" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> cell</span>
<span id="cb3-17"><a href="#cb3-17" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-18"><a href="#cb3-18" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _sample_single_cell(<span class="va">self</span>) <span class="op">-&gt;</span> List[Tuple]:</span>
<span id="cb3-19"><a href="#cb3-19" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Sample a single cell with intermediate nodes"""</span></span>
<span id="cb3-20"><a href="#cb3-20" aria-hidden="true" tabindex="-1"></a>        cell <span class="op">=</span> []</span>
<span id="cb3-21"><a href="#cb3-21" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">2</span>, <span class="va">self</span>.num_nodes <span class="op">+</span> <span class="dv">2</span>):  <span class="co"># Nodes 2 to num_nodes+1</span></span>
<span id="cb3-22"><a href="#cb3-22" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Each node has two inputs</span></span>
<span id="cb3-23"><a href="#cb3-23" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> j <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">2</span>):</span>
<span id="cb3-24"><a href="#cb3-24" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Sample input node (0 to i-1)</span></span>
<span id="cb3-25"><a href="#cb3-25" aria-hidden="true" tabindex="-1"></a>                input_node <span class="op">=</span> random.randint(<span class="dv">0</span>, i<span class="op">-</span><span class="dv">1</span>)</span>
<span id="cb3-26"><a href="#cb3-26" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Sample operation</span></span>
<span id="cb3-27"><a href="#cb3-27" aria-hidden="true" tabindex="-1"></a>                operation <span class="op">=</span> random.choice(<span class="va">self</span>.operations)</span>
<span id="cb3-28"><a href="#cb3-28" aria-hidden="true" tabindex="-1"></a>                cell.append((input_node, operation))</span>
<span id="cb3-29"><a href="#cb3-29" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> cell</span></code></pre></div></div>
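<p>Even this cell-based space is enormous. In the sampler above, intermediate node <em>i</em> (counting from 2) picks each of its two inputs from <em>i</em> earlier nodes and pairs it with one of 8 operations. A single cell therefore has the product over nodes of (8<em>i</em>)<sup>2</sup> configurations, and a normal/reduction pair squares that again. A quick count, treating the two input slots of a node as ordered, exactly as the sampler does (the function name is illustrative):</p>

```python
def cell_space_size(num_nodes: int, num_ops: int = 8) -> int:
    """Number of distinct cells _sample_single_cell can produce."""
    size = 1
    for i in range(2, num_nodes + 2):  # intermediate nodes 2 .. num_nodes + 1
        # Two input slots, each a (predecessor, operation) pair from i * num_ops options
        size *= (i * num_ops) ** 2
    return size


single = cell_space_size(num_nodes=7)
print(f"one cell: {single:.3e}, normal + reduction pair: {single ** 2:.3e}")
```

<p>Numbers like these are why NAS relies on search strategies and cheap performance estimates rather than exhaustive enumeration.</p>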
</section>
<section id="differentiable-search-space" class="level3">
<h3 class="anchored" data-anchor-id="differentiable-search-space" id="differentiable-search-space">Differentiable Search Space</h3>
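<p>A differentiable search space, as popularized by DARTS, replaces the discrete choice of operation with a softmax-weighted mixture of all candidate operations. The output is then a smooth function of the architecture parameters <code>alpha</code>, so they can be optimized by gradient descent. DARTS alternates between updating the network weights on training data and updating <code>alpha</code> on validation data. The final discrete architecture keeps the highest-weighted operation for each node.</p>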
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> DifferentiableSearchSpace(nn.Module):</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, operations: List[<span class="bu">str</span>], num_nodes: <span class="bu">int</span> <span class="op">=</span> <span class="dv">4</span>):</span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.operations <span class="op">=</span> operations</span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.num_nodes <span class="op">=</span> num_nodes</span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.num_ops <span class="op">=</span> <span class="bu">len</span>(operations)</span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Architecture parameters (alpha), initialized small so all ops start near-equal</span></span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.alpha <span class="op">=</span> nn.Parameter(<span class="fl">1e-3</span> <span class="op">*</span> torch.randn(num_nodes, <span class="va">self</span>.num_ops))</span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Operation modules</span></span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.ops <span class="op">=</span> nn.ModuleList([</span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>._get_operation(op) <span class="cf">for</span> op <span class="kw">in</span> operations</span>
<span id="cb4-14"><a href="#cb4-14" aria-hidden="true" tabindex="-1"></a>        ])</span>
<span id="cb4-15"><a href="#cb4-15" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-16"><a href="#cb4-16" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _get_operation(<span class="va">self</span>, op_name: <span class="bu">str</span>) <span class="op">-&gt;</span> nn.Module:</span>
<span id="cb4-17"><a href="#cb4-17" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Get operation module by name"""</span></span>
<span id="cb4-18"><a href="#cb4-18" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> op_name <span class="op">==</span> <span class="st">'conv3x3'</span>:</span>
<span id="cb4-19"><a href="#cb4-19" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> nn.Conv2d(<span class="dv">32</span>, <span class="dv">32</span>, <span class="dv">3</span>, padding<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb4-20"><a href="#cb4-20" aria-hidden="true" tabindex="-1"></a>        <span class="cf">elif</span> op_name <span class="op">==</span> <span class="st">'conv5x5'</span>:</span>
<span id="cb4-21"><a href="#cb4-21" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> nn.Conv2d(<span class="dv">32</span>, <span class="dv">32</span>, <span class="dv">5</span>, padding<span class="op">=</span><span class="dv">2</span>)</span>
<span id="cb4-22"><a href="#cb4-22" aria-hidden="true" tabindex="-1"></a>        <span class="cf">elif</span> op_name <span class="op">==</span> <span class="st">'maxpool3x3'</span>:</span>
<span id="cb4-23"><a href="#cb4-23" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> nn.MaxPool2d(<span class="dv">3</span>, stride<span class="op">=</span><span class="dv">1</span>, padding<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb4-24"><a href="#cb4-24" aria-hidden="true" tabindex="-1"></a>        <span class="cf">elif</span> op_name <span class="op">==</span> <span class="st">'avgpool3x3'</span>:</span>
<span id="cb4-25"><a href="#cb4-25" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> nn.AvgPool2d(<span class="dv">3</span>, stride<span class="op">=</span><span class="dv">1</span>, padding<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb4-26"><a href="#cb4-26" aria-hidden="true" tabindex="-1"></a>        <span class="cf">elif</span> op_name <span class="op">==</span> <span class="st">'identity'</span>:</span>
<span id="cb4-27"><a href="#cb4-27" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> nn.Identity()</span>
<span id="cb4-28"><a href="#cb4-28" aria-hidden="true" tabindex="-1"></a>        <span class="cf">elif</span> op_name <span class="op">==</span> <span class="st">'zero'</span>:</span>
<span id="cb4-29"><a href="#cb4-29" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> Zero()</span>
<span id="cb4-30"><a href="#cb4-30" aria-hidden="true" tabindex="-1"></a>        <span class="cf">else</span>:</span>
<span id="cb4-31"><a href="#cb4-31" aria-hidden="true" tabindex="-1"></a>            <span class="cf">raise</span> <span class="pp">ValueError</span>(<span class="ss">f"Unknown operation: </span><span class="sc">{</span>op_name<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb4-32"><a href="#cb4-32" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-33"><a href="#cb4-33" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb4-34"><a href="#cb4-34" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Softmax over operations</span></span>
<span id="cb4-35"><a href="#cb4-35" aria-hidden="true" tabindex="-1"></a>        weights <span class="op">=</span> torch.softmax(<span class="va">self</span>.alpha, dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb4-36"><a href="#cb4-36" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-37"><a href="#cb4-37" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Mixed operation</span></span>
<span id="cb4-38"><a href="#cb4-38" aria-hidden="true" tabindex="-1"></a>        output <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb4-39"><a href="#cb4-39" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> i, op <span class="kw">in</span> <span class="bu">enumerate</span>(<span class="va">self</span>.ops):</span>
<span id="cb4-40"><a href="#cb4-40" aria-hidden="true" tabindex="-1"></a>            output <span class="op">+=</span> weights[<span class="dv">0</span>, i] <span class="op">*</span> op(x)  <span class="co"># Simplified for single node</span></span>
<span id="cb4-41"><a href="#cb4-41" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-42"><a href="#cb4-42" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> output</span>
<span id="cb4-43"><a href="#cb4-43" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-44"><a href="#cb4-44" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> get_discrete_architecture(<span class="va">self</span>):</span>
<span id="cb4-45"><a href="#cb4-45" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Extract discrete architecture from continuous parameters"""</span></span>
<span id="cb4-46"><a href="#cb4-46" aria-hidden="true" tabindex="-1"></a>        arch <span class="op">=</span> []</span>
<span id="cb4-47"><a href="#cb4-47" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> node <span class="kw">in</span> <span class="bu">range</span>(<span class="va">self</span>.num_nodes):</span>
<span id="cb4-48"><a href="#cb4-48" aria-hidden="true" tabindex="-1"></a>            best_op_idx <span class="op">=</span> torch.argmax(<span class="va">self</span>.alpha[node]).item()</span>
<span id="cb4-49"><a href="#cb4-49" aria-hidden="true" tabindex="-1"></a>            arch.append(<span class="va">self</span>.operations[best_op_idx])</span>
<span id="cb4-50"><a href="#cb4-50" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> arch</span>
<span id="cb4-51"><a href="#cb4-51" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-52"><a href="#cb4-52" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> Zero(nn.Module):</span>
<span id="cb4-53"><a href="#cb4-53" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb4-54"><a href="#cb4-54" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> torch.zeros_like(x)</span></code></pre></div></div>
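<p>The continuous relaxation can be seen on a single mixed edge. The NumPy sketch below is a deliberately simplified stand-in for DARTS and rests on several assumptions:</p>
<ul>
<li>The candidate operations are hypothetical fixed scalar multipliers with no trainable weights.</li>
<li>Because there are no weights, the bilevel split between training data (for weights) and validation data (for <code>alpha</code>) is omitted.</li>
<li>The gradient with respect to <code>alpha</code> is written out by hand instead of using autograd.</li>
</ul>

```python
import numpy as np

# Hypothetical toy edge: every candidate op is a fixed scalar multiple of its input,
# and the data are generated by the 'double' op.
OPS = {"identity": 1.0, "zero": 0.0, "double": 2.0, "negate": -1.0}
names = list(OPS)
coef = np.array([OPS[name] for name in names])

rng = np.random.default_rng(0)
x = rng.normal(size=256)
target = 2.0 * x


def softmax(a: np.ndarray) -> np.ndarray:
    e = np.exp(a - a.max())
    return e / e.sum()


alpha = np.zeros(len(names))  # architecture parameters: all ops start with equal weight
lr = 0.5
for _ in range(200):
    p = softmax(alpha)
    y = (p @ coef) * x                                   # mixed op: sum_k p_k * o_k(x)
    dL_dy = 2.0 * (y - target) / x.size                  # gradient of the mean squared error
    g = np.array([(dL_dy * c * x).sum() for c in coef])  # dL/dp_k for each op
    alpha -= lr * p * (g - p @ g)                        # chain rule through the softmax

weights = softmax(alpha)
chosen = names[int(np.argmax(alpha))]  # discretize by argmax
print(chosen, weights.round(3))
```

<p>Gradient descent moves the softmax mass toward the operation that explains the data. Taking the argmax of <code>alpha</code> then recovers a discrete choice, mirroring what <code>get_discrete_architecture</code> does for each node.</p>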
</section>
</section>
<section id="search-strategies" class="level2">
<h2 class="anchored" data-anchor-id="search-strategies" id="search-strategies">Search Strategies</h2>
<section id="random-search" class="level3">
<h3 class="anchored" data-anchor-id="random-search" id="random-search">Random Search</h3>
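<p>Random search samples each architecture independently from the search space and ignores the results of earlier trials. It is the simplest strategy, yet it is often a surprisingly competitive baseline, and more sophisticated strategies should be compared against it.</p>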
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> RandomSearch:</span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>):</span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.history <span class="op">=</span> []</span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> sample_architecture(<span class="va">self</span>, search_space, history):</span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> search_space.sample_architecture()</span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> update(<span class="va">self</span>, architecture, performance):</span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.history.append((architecture, performance))</span></code></pre></div></div>
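<p>All of the strategies in this section share the same two-method interface: <code>sample_architecture(search_space, history)</code> and <code>update(architecture, performance)</code>. A single driver loop can therefore run any of them. The sketch below makes two assumptions. It uses a hypothetical chain-structured search space whose architectures are dicts with a <code>'layers'</code> list, which is the format the evolutionary and RL code in this section also manipulates. It also uses a hypothetical <code>toy_evaluate</code> proxy in place of actually training each network.</p>

```python
import random


class ToySearchSpace:
    """Hypothetical chain-structured space: each layer picks an operation and a width."""
    operations = ["conv3x3", "conv5x5", "maxpool3x3", "identity"]

    def __init__(self, num_layers: int = 4):
        self.num_layers = num_layers

    def sample_architecture(self):
        return {"layers": [
            {"operation": random.choice(self.operations),
             "filters": random.choice([32, 64, 128, 256])}
            for _ in range(self.num_layers)
        ]}


def toy_evaluate(architecture) -> float:
    """Hypothetical proxy score standing in for training and validating the network."""
    layers = architecture["layers"]
    return sum(layer["operation"].startswith("conv") for layer in layers) / len(layers)


class RandomSearch:
    """Same interface as the RandomSearch class above, repeated so this runs on its own."""

    def __init__(self):
        self.history = []

    def sample_architecture(self, search_space, history):
        return search_space.sample_architecture()

    def update(self, architecture, performance):
        self.history.append((architecture, performance))


def run_search(strategy, search_space, evaluate, num_trials: int = 100):
    """Generic sample -> evaluate -> update loop; returns the best (architecture, score)."""
    history = []
    for _ in range(num_trials):
        architecture = strategy.sample_architecture(search_space, history)
        score = evaluate(architecture)
        strategy.update(architecture, score)
        history.append((architecture, score))
    return max(history, key=lambda item: item[1])


random.seed(0)
strategy = RandomSearch()
best_arch, best_score = run_search(strategy, ToySearchSpace(), toy_evaluate)
print(best_score, [layer["operation"] for layer in best_arch["layers"]])
```

<p>Swapping in another strategy only changes the <code>strategy</code> object. The loop, the search space, and the evaluator stay the same.</p>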
</section>
<section id="evolutionary-search" class="level3">
<h3 class="anchored" data-anchor-id="evolutionary-search" id="evolutionary-search">Evolutionary Search</h3>
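<p>Evolutionary search applies a genetic algorithm to architectures. It maintains a population of evaluated architectures and selects promising parents, here by tournament selection. It then mutates those parents to produce children and replaces the weakest members of the population with better-scoring children.</p>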
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> EvolutionarySearch:</span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, population_size: <span class="bu">int</span> <span class="op">=</span> <span class="dv">50</span>, mutation_rate: <span class="bu">float</span> <span class="op">=</span> <span class="fl">0.1</span>):</span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.population_size <span class="op">=</span> population_size</span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.mutation_rate <span class="op">=</span> mutation_rate</span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.population <span class="op">=</span> []</span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.fitness_scores <span class="op">=</span> []</span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> sample_architecture(<span class="va">self</span>, search_space, history):</span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Warm-up: sample randomly until the population is full of evaluated architectures</span></span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="bu">len</span>(<span class="va">self</span>.population) <span class="op">&lt;</span> <span class="va">self</span>.population_size:</span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> search_space.sample_architecture()</span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Then pick a parent by tournament and return a mutated copy of it</span></span>
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a>        parent <span class="op">=</span> <span class="va">self</span>._tournament_selection()</span>
<span id="cb6-15"><a href="#cb6-15" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.mutate_architecture(parent, search_space)</span>
<span id="cb6-16"><a href="#cb6-16" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-17"><a href="#cb6-17" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _tournament_selection(<span class="va">self</span>, tournament_size: <span class="bu">int</span> <span class="op">=</span> <span class="dv">3</span>):</span>
<span id="cb6-18"><a href="#cb6-18" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Select parent via tournament selection"""</span></span>
<span id="cb6-19"><a href="#cb6-19" aria-hidden="true" tabindex="-1"></a>        tournament_indices <span class="op">=</span> random.sample(</span>
<span id="cb6-20"><a href="#cb6-20" aria-hidden="true" tabindex="-1"></a>            <span class="bu">range</span>(<span class="bu">len</span>(<span class="va">self</span>.population)),</span>
<span id="cb6-21"><a href="#cb6-21" aria-hidden="true" tabindex="-1"></a>            <span class="bu">min</span>(tournament_size, <span class="bu">len</span>(<span class="va">self</span>.population))</span>
<span id="cb6-22"><a href="#cb6-22" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb6-23"><a href="#cb6-23" aria-hidden="true" tabindex="-1"></a>        tournament_fitness <span class="op">=</span> [<span class="va">self</span>.fitness_scores[i] <span class="cf">for</span> i <span class="kw">in</span> tournament_indices]</span>
<span id="cb6-24"><a href="#cb6-24" aria-hidden="true" tabindex="-1"></a>        winner_idx <span class="op">=</span> tournament_indices[np.argmax(tournament_fitness)]</span>
<span id="cb6-25"><a href="#cb6-25" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.population[winner_idx]</span>
<span id="cb6-26"><a href="#cb6-26" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-27"><a href="#cb6-27" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> update(<span class="va">self</span>, architecture, performance):</span>
<span id="cb6-28"><a href="#cb6-28" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Add an evaluated architecture to the population"""</span></span>
<span id="cb6-29"><a href="#cb6-29" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="bu">len</span>(<span class="va">self</span>.population) <span class="op">&lt;</span> <span class="va">self</span>.population_size:</span>
<span id="cb6-30"><a href="#cb6-30" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.population.append(architecture)</span>
<span id="cb6-31"><a href="#cb6-31" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.fitness_scores.append(performance)</span>
<span id="cb6-32"><a href="#cb6-32" aria-hidden="true" tabindex="-1"></a>        <span class="cf">else</span>:</span>
<span id="cb6-33"><a href="#cb6-33" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Replace the worst architecture if the new one is better</span></span>
<span id="cb6-34"><a href="#cb6-34" aria-hidden="true" tabindex="-1"></a>            worst_idx <span class="op">=</span> np.argmin(<span class="va">self</span>.fitness_scores)</span>
<span id="cb6-35"><a href="#cb6-35" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> performance <span class="op">&gt;</span> <span class="va">self</span>.fitness_scores[worst_idx]:</span>
<span id="cb6-36"><a href="#cb6-36" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.population[worst_idx] <span class="op">=</span> architecture</span>
<span id="cb6-37"><a href="#cb6-37" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.fitness_scores[worst_idx] <span class="op">=</span> performance</span>
<span id="cb6-38"><a href="#cb6-38" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-39"><a href="#cb6-39" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> mutate_architecture(<span class="va">self</span>, architecture, search_space):</span>
<span id="cb6-40"><a href="#cb6-40" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Return a (possibly) mutated copy; the parent itself is never modified"""</span></span>
<span id="cb6-41"><a href="#cb6-41" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="st">'layers'</span> <span class="kw">not</span> <span class="kw">in</span> architecture <span class="kw">or</span> <span class="kw">not</span> architecture[<span class="st">'layers'</span>]:</span>
<span id="cb6-42"><a href="#cb6-42" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> architecture  <span class="co"># nothing to mutate in this simplified version</span></span>
<span id="cb6-43"><a href="#cb6-43" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Copy the layer dicts so that mutating the child cannot alter the parent</span></span>
<span id="cb6-44"><a href="#cb6-44" aria-hidden="true" tabindex="-1"></a>        child <span class="op">=</span> {<span class="op">**</span>architecture, <span class="st">'layers'</span>: [<span class="bu">dict</span>(layer) <span class="cf">for</span> layer <span class="kw">in</span> architecture[<span class="st">'layers'</span>]]}</span>
<span id="cb6-45"><a href="#cb6-45" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> random.random() <span class="op">&lt;</span> <span class="va">self</span>.mutation_rate:</span>
<span id="cb6-46"><a href="#cb6-46" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Simple mutation: change the operation of one random layer</span></span>
<span id="cb6-47"><a href="#cb6-47" aria-hidden="true" tabindex="-1"></a>            layer_idx <span class="op">=</span> random.randrange(<span class="bu">len</span>(child[<span class="st">'layers'</span>]))</span>
<span id="cb6-48"><a href="#cb6-48" aria-hidden="true" tabindex="-1"></a>            child[<span class="st">'layers'</span>][layer_idx][<span class="st">'operation'</span>] <span class="op">=</span> random.choice(search_space.operations)</span>
<span id="cb6-49"><a href="#cb6-49" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> child</span></code></pre></div></div>
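<p>The <code>tournament_size</code> argument of <code>_tournament_selection</code> controls selection pressure. With a size of 1 the parent is chosen uniformly at random, and larger tournaments increasingly favor the fittest architectures at the cost of diversity. The standalone sketch below measures this on a hypothetical population with evenly spaced fitness values. The helper <code>tournament_select</code> mirrors the method's logic but is written as a free function so it runs on its own.</p>

```python
import random


def tournament_select(population, fitness, tournament_size, rng):
    """Return the fittest of `tournament_size` members drawn without replacement."""
    contenders = rng.sample(range(len(population)), tournament_size)
    return population[max(contenders, key=lambda i: fitness[i])]


# Hypothetical population: 50 architecture ids with evenly spaced fitness in [0, 1]
population = [f"arch_{i}" for i in range(50)]
fitness = [i / 49 for i in range(50)]

rng = random.Random(0)
mean_parent_fitness = {}
for k in (1, 3, 10):
    winners = [tournament_select(population, fitness, k, rng) for _ in range(2000)]
    mean_parent_fitness[k] = sum(fitness[population.index(w)] for w in winners) / len(winners)
print(mean_parent_fitness)
```

<p>The average parent fitness rises with tournament size, to roughly k/(k+1) for this evenly spaced population. Small tournaments such as the default of 3 therefore still favor good architectures while keeping some diversity in the parents that get mutated.</p>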
</section>
<section id="reinforcement-learning-search" class="level3">
<h3 class="anchored" data-anchor-id="reinforcement-learning-search" id="reinforcement-learning-search">Reinforcement Learning Search</h3>
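<p>Reinforcement learning search trains a controller, here an LSTM, that emits an architecture as a sequence of tokens. Each sampled architecture is trained and evaluated, and its validation performance serves as the reward for a policy-gradient (REINFORCE) update. Over time the controller assigns higher probability to architectures that score well. This is the approach of the original NAS work by Zoph and Le.</p>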
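<p>The core of this approach is the REINFORCE update, which can be illustrated without an LSTM. The NumPy sketch below simplifies things in two ways. The controller is replaced by one independent categorical distribution per layer. A hypothetical <code>reward</code> function stands in for the validation accuracy of a trained child network. Sampled architectures that beat the moving-average baseline have their log-probability increased, and those that fall below it have it decreased.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
ops = ["conv3x3", "conv5x5", "maxpool3x3", "identity"]  # hypothetical operation list
num_layers = 3
logits = np.zeros((num_layers, len(ops)))  # one independent categorical policy per layer


def reward(choice) -> float:
    """Hypothetical stand-in for validation accuracy: fraction of layers using conv3x3."""
    return float(np.mean([ops[c] == "conv3x3" for c in choice]))


def softmax(z: np.ndarray) -> np.ndarray:
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


baseline, baseline_decay, lr = 0.0, 0.9, 0.1
for _ in range(2000):
    probs = softmax(logits)
    choice = np.array([rng.choice(len(ops), p=p) for p in probs])  # one op per layer
    r = reward(choice)
    advantage = r - baseline  # baseline from past rewards reduces gradient variance
    baseline = baseline_decay * baseline + (1 - baseline_decay) * r
    # For a softmax policy, d log pi(choice) / d logits = onehot(choice) - probs
    grad_log_prob = -probs
    grad_log_prob[np.arange(num_layers), choice] += 1.0
    logits += lr * advantage * grad_log_prob  # REINFORCE: gradient ascent on expected reward

final_probs = softmax(logits)
print([ops[i] for i in final_probs.argmax(axis=1)])
```

<p>The moving-average baseline plays the same role as <code>self.baseline</code> in <code>ReinforcementLearningSearch</code>. Because it depends only on past rewards, subtracting it leaves the expected gradient unchanged while reducing its variance.</p>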
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> RLController(nn.Module):</span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, vocab_size: <span class="bu">int</span>, hidden_size: <span class="bu">int</span> <span class="op">=</span> <span class="dv">64</span>):</span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.vocab_size <span class="op">=</span> vocab_size</span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.hidden_size <span class="op">=</span> hidden_size</span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.lstm <span class="op">=</span> nn.LSTM(vocab_size, hidden_size, batch_first<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.classifier <span class="op">=</span> nn.Linear(hidden_size, vocab_size)</span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-10"><a href="#cb7-10" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb7-11"><a href="#cb7-11" aria-hidden="true" tabindex="-1"></a>        lstm_out, _ <span class="op">=</span> <span class="va">self</span>.lstm(x)</span>
<span id="cb7-12"><a href="#cb7-12" aria-hidden="true" tabindex="-1"></a>        logits <span class="op">=</span> <span class="va">self</span>.classifier(lstm_out)</span>
<span id="cb7-13"><a href="#cb7-13" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> logits</span>
<span id="cb7-14"><a href="#cb7-14" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-15"><a href="#cb7-15" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> sample_architecture(<span class="va">self</span>, max_length: <span class="bu">int</span> <span class="op">=</span> <span class="dv">20</span>):</span>
<span id="cb7-16"><a href="#cb7-16" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Sample architecture using the controller"""</span></span>
<span id="cb7-17"><a href="#cb7-17" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.<span class="bu">eval</span>()</span>
<span id="cb7-18"><a href="#cb7-18" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> torch.no_grad():</span>
<span id="cb7-19"><a href="#cb7-19" aria-hidden="true" tabindex="-1"></a>            sequence <span class="op">=</span> []</span>
<span id="cb7-20"><a href="#cb7-20" aria-hidden="true" tabindex="-1"></a>            hidden <span class="op">=</span> <span class="va">None</span></span>
<span id="cb7-21"><a href="#cb7-21" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb7-22"><a href="#cb7-22" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Start token</span></span>
<span id="cb7-23"><a href="#cb7-23" aria-hidden="true" tabindex="-1"></a>            input_token <span class="op">=</span> torch.zeros(<span class="dv">1</span>, <span class="dv">1</span>, <span class="va">self</span>.vocab_size)</span>
<span id="cb7-24"><a href="#cb7-24" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb7-25"><a href="#cb7-25" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(max_length):</span>
<span id="cb7-26"><a href="#cb7-26" aria-hidden="true" tabindex="-1"></a>                lstm_out, hidden <span class="op">=</span> <span class="va">self</span>.lstm(input_token, hidden)</span>
<span id="cb7-27"><a href="#cb7-27" aria-hidden="true" tabindex="-1"></a>                logits <span class="op">=</span> <span class="va">self</span>.classifier(lstm_out)</span>
<span id="cb7-28"><a href="#cb7-28" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb7-29"><a href="#cb7-29" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Sample next token</span></span>
<span id="cb7-30"><a href="#cb7-30" aria-hidden="true" tabindex="-1"></a>                probs <span class="op">=</span> torch.softmax(logits.squeeze(), dim<span class="op">=</span><span class="dv">0</span>)</span>
<span id="cb7-31"><a href="#cb7-31" aria-hidden="true" tabindex="-1"></a>                next_token <span class="op">=</span> torch.multinomial(probs, <span class="dv">1</span>).item()</span>
<span id="cb7-32"><a href="#cb7-32" aria-hidden="true" tabindex="-1"></a>                sequence.append(next_token)</span>
<span id="cb7-33"><a href="#cb7-33" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb7-34"><a href="#cb7-34" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Prepare input for next step</span></span>
<span id="cb7-35"><a href="#cb7-35" aria-hidden="true" tabindex="-1"></a>                input_token <span class="op">=</span> torch.zeros(<span class="dv">1</span>, <span class="dv">1</span>, <span class="va">self</span>.vocab_size)</span>
<span id="cb7-36"><a href="#cb7-36" aria-hidden="true" tabindex="-1"></a>                input_token[<span class="dv">0</span>, <span class="dv">0</span>, next_token] <span class="op">=</span> <span class="dv">1</span></span>
<span id="cb7-37"><a href="#cb7-37" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-38"><a href="#cb7-38" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> sequence</span>
<span id="cb7-39"><a href="#cb7-39" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-40"><a href="#cb7-40" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> ReinforcementLearningSearch:</span>
<span id="cb7-41"><a href="#cb7-41" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, vocab_size: <span class="bu">int</span>, learning_rate: <span class="bu">float</span> <span class="op">=</span> <span class="fl">0.001</span>):</span>
<span id="cb7-42"><a href="#cb7-42" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.controller <span class="op">=</span> RLController(vocab_size)</span>
<span id="cb7-43"><a href="#cb7-43" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.optimizer <span class="op">=</span> torch.optim.Adam(</span>
<span id="cb7-44"><a href="#cb7-44" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.controller.parameters(), lr<span class="op">=</span>learning_rate</span>
<span id="cb7-45"><a href="#cb7-45" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb7-46"><a href="#cb7-46" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.baseline <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb7-47"><a href="#cb7-47" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.baseline_decay <span class="op">=</span> <span class="fl">0.99</span></span>
<span id="cb7-48"><a href="#cb7-48" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-49"><a href="#cb7-49" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> sample_architecture(<span class="va">self</span>, search_space, history):</span>
<span id="cb7-50"><a href="#cb7-50" aria-hidden="true" tabindex="-1"></a>        sequence <span class="op">=</span> <span class="va">self</span>.controller.sample_architecture()</span>
<span id="cb7-51"><a href="#cb7-51" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>._sequence_to_architecture(sequence, search_space)</span>
<span id="cb7-52"><a href="#cb7-52" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-53"><a href="#cb7-53" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _sequence_to_architecture(<span class="va">self</span>, sequence, search_space):</span>
<span id="cb7-54"><a href="#cb7-54" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Convert sequence to architecture"""</span></span>
<span id="cb7-55"><a href="#cb7-55" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Simplified conversion</span></span>
<span id="cb7-56"><a href="#cb7-56" aria-hidden="true" tabindex="-1"></a>        architecture <span class="op">=</span> {<span class="st">'layers'</span>: []}</span>
<span id="cb7-57"><a href="#cb7-57" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">0</span>, <span class="bu">len</span>(sequence), <span class="dv">2</span>):</span>
<span id="cb7-58"><a href="#cb7-58" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> i <span class="op">+</span> <span class="dv">1</span> <span class="op">&lt;</span> <span class="bu">len</span>(sequence):</span>
<span id="cb7-59"><a href="#cb7-59" aria-hidden="true" tabindex="-1"></a>                op_idx <span class="op">=</span> sequence[i] <span class="op">%</span> <span class="bu">len</span>(search_space.operations)</span>
<span id="cb7-60"><a href="#cb7-60" aria-hidden="true" tabindex="-1"></a>                filter_idx <span class="op">=</span> sequence[i <span class="op">+</span> <span class="dv">1</span>] <span class="op">%</span> <span class="dv">4</span></span>
<span id="cb7-61"><a href="#cb7-61" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb7-62"><a href="#cb7-62" aria-hidden="true" tabindex="-1"></a>                layer <span class="op">=</span> {</span>
<span id="cb7-63"><a href="#cb7-63" aria-hidden="true" tabindex="-1"></a>                    <span class="st">'operation'</span>: search_space.operations[op_idx],</span>
<span id="cb7-64"><a href="#cb7-64" aria-hidden="true" tabindex="-1"></a>                    <span class="st">'filters'</span>: [<span class="dv">32</span>, <span class="dv">64</span>, <span class="dv">128</span>, <span class="dv">256</span>][filter_idx]</span>
<span id="cb7-65"><a href="#cb7-65" aria-hidden="true" tabindex="-1"></a>                }</span>
<span id="cb7-66"><a href="#cb7-66" aria-hidden="true" tabindex="-1"></a>                architecture[<span class="st">'layers'</span>].append(layer)</span>
<span id="cb7-67"><a href="#cb7-67" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-68"><a href="#cb7-68" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> architecture</span>
<span id="cb7-69"><a href="#cb7-69" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-70"><a href="#cb7-70" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> update(<span class="va">self</span>, architecture, performance, log_prob<span class="op">=</span><span class="va">None</span>):</span>
<span id="cb7-71"><a href="#cb7-71" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Update controller using REINFORCE"""</span></span>
<span id="cb7-72"><a href="#cb7-72" aria-hidden="true" tabindex="-1"></a>        <span class="co"># log_prob (assumed input): summed log-probability, as a tensor with</span></span>
<span id="cb7-73"><a href="#cb7-73" aria-hidden="true" tabindex="-1"></a>        <span class="co"># grad, of the decisions that produced `architecture`. Without it,</span></span>
<span id="cb7-74"><a href="#cb7-74" aria-hidden="true" tabindex="-1"></a>        <span class="co"># only the baseline is updated.</span></span>
<span id="cb7-75"><a href="#cb7-75" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Advantage is measured against the baseline before it absorbs this reward</span></span>
<span id="cb7-76"><a href="#cb7-76" aria-hidden="true" tabindex="-1"></a>        advantage <span class="op">=</span> performance <span class="op">-</span> <span class="va">self</span>.baseline</span>
<span id="cb7-77"><a href="#cb7-77" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-78"><a href="#cb7-78" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Exponential moving-average baseline (reduces gradient variance)</span></span>
<span id="cb7-79"><a href="#cb7-79" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.baseline <span class="op">=</span> <span class="va">self</span>.baseline_decay <span class="op">*</span> <span class="va">self</span>.baseline <span class="op">+</span> <span class="op">\</span></span>
<span id="cb7-80"><a href="#cb7-80" aria-hidden="true" tabindex="-1"></a>                       (<span class="dv">1</span> <span class="op">-</span> <span class="va">self</span>.baseline_decay) <span class="op">*</span> performance</span>
<span id="cb7-81"><a href="#cb7-81" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-82"><a href="#cb7-82" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> log_prob <span class="kw">is</span> <span class="va">None</span>:</span>
<span id="cb7-83"><a href="#cb7-83" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span></span>
<span id="cb7-84"><a href="#cb7-84" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-85"><a href="#cb7-85" aria-hidden="true" tabindex="-1"></a>        <span class="co"># REINFORCE: maximize advantage * log_prob by minimizing its negative</span></span>
<span id="cb7-86"><a href="#cb7-86" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.optimizer.zero_grad()</span>
<span id="cb7-87"><a href="#cb7-87" aria-hidden="true" tabindex="-1"></a>        loss <span class="op">=</span> <span class="op">-</span>log_prob <span class="op">*</span> advantage</span>
<span id="cb7-88"><a href="#cb7-88" aria-hidden="true" tabindex="-1"></a>        loss.backward()</span>
<span id="cb7-89"><a href="#cb7-89" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.optimizer.step()</span></code></pre></div></div>
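<p>The REINFORCE update with a moving-average baseline can be seen in isolation in a toy setting. The plain-Python sketch below is not the controller above: it assumes a policy that makes a single categorical choice among three hypothetical operations, each with a fixed, made-up reward standing in for validation accuracy. For a softmax policy, the gradient of log p(choice) with respect to logit k is 1[k == choice] - p_k, so the update can be written out by hand.</p>

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def reinforce_toy(rewards, steps=2000, lr=0.1, baseline_decay=0.9, seed=0):
    """Learn a categorical policy over len(rewards) choices with REINFORCE.

    rewards: hypothetical reward per choice (a stand-in for accuracy).
    Returns the final choice probabilities.
    """
    rng = random.Random(seed)
    logits = [0.0] * len(rewards)
    baseline = 0.0
    for _ in range(steps):
        probs = softmax(logits)
        choice = rng.choices(range(len(rewards)), weights=probs)[0]
        reward = rewards[choice]
        # Advantage against the baseline *before* it absorbs this reward
        advantage = reward - baseline
        baseline = baseline_decay * baseline + (1 - baseline_decay) * reward
        # Gradient ascent on advantage * log p(choice)
        for k in range(len(logits)):
            grad_log_prob = (1.0 if k == choice else 0.0) - probs[k]
            logits[k] += lr * advantage * grad_log_prob
    return softmax(logits)


# Hypothetical rewards for three candidate operations
probs = reinforce_toy([0.60, 0.90, 0.50])
print([round(p, 3) for p in probs])
```

<p>The expected update for logit k is proportional to p_k (r_k - r̄), where r̄ is the policy's average reward, so probability mass drifts toward the highest-reward choice. Because the baseline is computed before the current reward is seen, it leaves that expectation unchanged and only reduces the variance of individual updates.</p>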
</section>
<section id="differentiable-architecture-search-darts" class="level3">
<h3 class="anchored" data-anchor-id="differentiable-architecture-search-darts" id="differentiable-architecture-search-darts">Differentiable Architecture Search (DARTS)</h3>
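<p>DARTS makes the search space continuous: instead of picking one operation per edge, each edge computes a softmax-weighted mixture of all candidate operations, with the mixing weights given by architecture parameters α (<code>alpha</code> in the code). α is optimized by gradient descent on the validation loss, alternating with ordinary gradient steps on the network weights using the training loss, and the final discrete architecture keeps the operations with the largest weights. The class below uses the first-order variant, which computes the architecture gradient at the current weights rather than after a look-ahead weight update.</p>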
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb8"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> DARTSSearch:</span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, model: DifferentiableSearchSpace, learning_rate: <span class="bu">float</span> <span class="op">=</span> <span class="fl">0.025</span>):</span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> model</span>
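<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Weight optimizer must exclude alpha, which arch_optimizer updates</span></span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a>        weights <span class="op">=</span> [p <span class="cf">for</span> p <span class="kw">in</span> <span class="va">self</span>.model.parameters() <span class="cf">if</span> p <span class="kw">is</span> <span class="kw">not</span> <span class="va">self</span>.model.alpha]</span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.optimizer <span class="op">=</span> torch.optim.SGD(weights, lr<span class="op">=</span>learning_rate, momentum<span class="op">=</span><span class="fl">0.9</span>)</span>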
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.arch_optimizer <span class="op">=</span> torch.optim.Adam(</span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a>            [<span class="va">self</span>.model.alpha], lr<span class="op">=</span><span class="fl">3e-4</span></span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> search_step(<span class="va">self</span>, train_data, val_data, criterion):</span>
<span id="cb8-12"><a href="#cb8-12" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""First-order DARTS step: update alpha on val loss, then weights on train loss"""</span></span>
<span id="cb8-13"><a href="#cb8-13" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Update architecture parameters</span></span>
<span id="cb8-14"><a href="#cb8-14" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.arch_optimizer.zero_grad()</span>
<span id="cb8-15"><a href="#cb8-15" aria-hidden="true" tabindex="-1"></a>        val_loss <span class="op">=</span> <span class="va">self</span>._compute_val_loss(val_data, criterion)</span>
<span id="cb8-16"><a href="#cb8-16" aria-hidden="true" tabindex="-1"></a>        val_loss.backward()</span>
<span id="cb8-17"><a href="#cb8-17" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.arch_optimizer.step()</span>
<span id="cb8-18"><a href="#cb8-18" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-19"><a href="#cb8-19" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Update model parameters</span></span>
<span id="cb8-20"><a href="#cb8-20" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.optimizer.zero_grad()</span>
<span id="cb8-21"><a href="#cb8-21" aria-hidden="true" tabindex="-1"></a>        train_loss <span class="op">=</span> <span class="va">self</span>._compute_train_loss(train_data, criterion)</span>
<span id="cb8-22"><a href="#cb8-22" aria-hidden="true" tabindex="-1"></a>        train_loss.backward()</span>
<span id="cb8-23"><a href="#cb8-23" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.optimizer.step()</span>
<span id="cb8-24"><a href="#cb8-24" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-25"><a href="#cb8-25" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> train_loss.item(), val_loss.item()</span>
<span id="cb8-26"><a href="#cb8-26" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-27"><a href="#cb8-27" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _compute_train_loss(<span class="va">self</span>, data, criterion):</span>
<span id="cb8-28"><a href="#cb8-28" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Compute training loss"""</span></span>
<span id="cb8-29"><a href="#cb8-29" aria-hidden="true" tabindex="-1"></a>        inputs, targets <span class="op">=</span> data</span>
<span id="cb8-30"><a href="#cb8-30" aria-hidden="true" tabindex="-1"></a>        outputs <span class="op">=</span> <span class="va">self</span>.model(inputs)</span>
<span id="cb8-31"><a href="#cb8-31" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> criterion(outputs, targets)</span>
<span id="cb8-32"><a href="#cb8-32" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-33"><a href="#cb8-33" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _compute_val_loss(<span class="va">self</span>, data, criterion):</span>
<span id="cb8-34"><a href="#cb8-34" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Compute validation loss"""</span></span>
<span id="cb8-35"><a href="#cb8-35" aria-hidden="true" tabindex="-1"></a>        inputs, targets <span class="op">=</span> data</span>
<span id="cb8-36"><a href="#cb8-36" aria-hidden="true" tabindex="-1"></a>        outputs <span class="op">=</span> <span class="va">self</span>.model(inputs)</span>
<span id="cb8-37"><a href="#cb8-37" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> criterion(outputs, targets)</span>
<span id="cb8-38"><a href="#cb8-38" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-39"><a href="#cb8-39" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> get_final_architecture(<span class="va">self</span>):</span>
<span id="cb8-40"><a href="#cb8-40" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Extract final discrete architecture"""</span></span>
<span id="cb8-41"><a href="#cb8-41" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.model.get_discrete_architecture()</span></code></pre></div></div>
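<p>The continuous relaxation itself can be demonstrated on a single edge without a neural network. In the plain-Python sketch below, three hypothetical scalar operations are mixed with weights softmax(α); α is fitted by gradient descent so that the mixed output matches targets generated by one of the operations, and the discrete choice is read off with argmax, analogous to deriving the final architecture in DARTS. The operations, data and learning rate are illustrative assumptions, and since there are no network weights here, the alternating weight update is omitted.</p>

```python
import math


def softmax(alpha):
    """Numerically stable softmax over a list of floats."""
    m = max(alpha)
    exps = [math.exp(a - m) for a in alpha]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical candidate operations on one edge (scalar stand-ins)
OPS = {
    "identity": lambda x: x,
    "double": lambda x: 2.0 * x,
    "zero": lambda x: 0.0,
}
NAMES = list(OPS)


def mixed_op(x, alpha):
    """Continuous relaxation: softmax(alpha)-weighted sum of all operations."""
    w = softmax(alpha)
    return sum(wk * OPS[name](x) for wk, name in zip(w, NAMES))


def fit_alpha(data, steps=1000, lr=0.5):
    """Gradient descent on alpha for the mean squared error of mixed_op."""
    alpha = [0.0] * len(NAMES)
    for _ in range(steps):
        w = softmax(alpha)
        grad = [0.0] * len(alpha)
        for x, y in data:
            outs = [OPS[name](x) for name in NAMES]
            mix = sum(wk * o for wk, o in zip(w, outs))
            dloss_dmix = 2.0 * (mix - y)
            # Softmax Jacobian: d mix / d alpha_k = w_k * (o_k - mix)
            for k in range(len(alpha)):
                grad[k] += dloss_dmix * w[k] * (outs[k] - mix)
        alpha = [a - lr * g / len(data) for a, g in zip(alpha, grad)]
    return alpha


# Targets generated by the "double" operation
data = [(x, 2.0 * x) for x in (-1.0, -0.5, 0.5, 1.0)]
alpha = fit_alpha(data)
# Discretize: keep the operation with the largest architecture weight
best_op = NAMES[max(range(len(alpha)), key=lambda k: alpha[k])]
print(best_op, round(mixed_op(1.0, alpha), 3))
```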
</section>
</section>
<section id="performance-estimation" class="level2">
<h2 class="anchored" data-anchor-id="performance-estimation" id="performance-estimation">Performance Estimation</h2>
<section id="full-training" class="level3">
<h3 class="anchored" data-anchor-id="full-training" id="full-training">Full Training</h3>
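<p>Training every candidate to convergence gives the most faithful estimate of its final performance, but the total cost is the number of candidates times the full training cost of each, which quickly becomes prohibitive.</p>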
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb9"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> FullTrainingEvaluator:</span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, dataset, num_epochs: <span class="bu">int</span> <span class="op">=</span> <span class="dv">100</span>):</span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.dataset <span class="op">=</span> dataset</span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.num_epochs <span class="op">=</span> num_epochs</span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> evaluate(<span class="va">self</span>, architecture) <span class="op">-&gt;</span> <span class="bu">float</span>:</span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Evaluate architecture by full training"""</span></span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a>        model <span class="op">=</span> <span class="va">self</span>._build_model(architecture)</span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Training loop</span></span>
<span id="cb9-11"><a href="#cb9-11" aria-hidden="true" tabindex="-1"></a>        optimizer <span class="op">=</span> torch.optim.SGD(model.parameters(), lr<span class="op">=</span><span class="fl">0.01</span>)</span>
<span id="cb9-12"><a href="#cb9-12" aria-hidden="true" tabindex="-1"></a>        criterion <span class="op">=</span> nn.CrossEntropyLoss()</span>
<span id="cb9-13"><a href="#cb9-13" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-14"><a href="#cb9-14" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(<span class="va">self</span>.num_epochs):</span>
<span id="cb9-15"><a href="#cb9-15" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> batch <span class="kw">in</span> <span class="va">self</span>.dataset:</span>
<span id="cb9-16"><a href="#cb9-16" aria-hidden="true" tabindex="-1"></a>                inputs, targets <span class="op">=</span> batch</span>
<span id="cb9-17"><a href="#cb9-17" aria-hidden="true" tabindex="-1"></a>                optimizer.zero_grad()</span>
<span id="cb9-18"><a href="#cb9-18" aria-hidden="true" tabindex="-1"></a>                outputs <span class="op">=</span> model(inputs)</span>
<span id="cb9-19"><a href="#cb9-19" aria-hidden="true" tabindex="-1"></a>                loss <span class="op">=</span> criterion(outputs, targets)</span>
<span id="cb9-20"><a href="#cb9-20" aria-hidden="true" tabindex="-1"></a>                loss.backward()</span>
<span id="cb9-21"><a href="#cb9-21" aria-hidden="true" tabindex="-1"></a>                optimizer.step()</span>
<span id="cb9-22"><a href="#cb9-22" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-23"><a href="#cb9-23" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Evaluate on validation set</span></span>
<span id="cb9-24"><a href="#cb9-24" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>._evaluate_model(model)</span>
<span id="cb9-25"><a href="#cb9-25" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-26"><a href="#cb9-26" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _build_model(<span class="va">self</span>, architecture):</span>
<span id="cb9-27"><a href="#cb9-27" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Build model from architecture description"""</span></span>
<span id="cb9-28"><a href="#cb9-28" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Implementation depends on architecture format</span></span>
<span id="cb9-29"><a href="#cb9-29" aria-hidden="true" tabindex="-1"></a>        <span class="cf">pass</span></span>
<span id="cb9-30"><a href="#cb9-30" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-31"><a href="#cb9-31" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _evaluate_model(<span class="va">self</span>, model):</span>
<span id="cb9-32"><a href="#cb9-32" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Evaluate model accuracy"""</span></span>
<span id="cb9-33"><a href="#cb9-33" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Implementation for model evaluation</span></span>
<span id="cb9-34"><a href="#cb9-34" aria-hidden="true" tabindex="-1"></a>        <span class="cf">pass</span></span></code></pre></div></div>
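<p>A back-of-the-envelope cost model shows why. The sketch below simply multiplies candidates, epochs and per-epoch cost; the numbers in the usage lines are hypothetical placeholders, not measurements (the epoch counts mirror the <code>num_epochs=100</code> default above and the <code>max_epochs=20</code> cap used for early stopping).</p>

```python
def search_cost_gpu_hours(num_candidates: int, epochs: int,
                          gpu_hours_per_epoch: float) -> float:
    """Total GPU-hours if every candidate is trained for `epochs` epochs."""
    return num_candidates * epochs * gpu_hours_per_epoch


# Hypothetical budget: 1,000 candidates at 0.05 GPU-hours per epoch
full = search_cost_gpu_hours(1_000, 100, 0.05)   # full training
capped = search_cost_gpu_hours(1_000, 20, 0.05)  # upper bound with a 20-epoch cap
print(f"full: {full:,.0f} GPU-hours; 20-epoch cap: {capped:,.0f} GPU-hours")
```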
</section>
<section id="early-stopping" class="level3">
<h3 class="anchored" data-anchor-id="early-stopping" id="early-stopping">Early Stopping</h3>
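<p>Trains each candidate for a small epoch budget and stops as soon as validation accuracy fails to improve for <code>patience</code> consecutive epochs. This cuts the cost of each evaluation substantially, on the assumption that architectures are ranked roughly the same way after a short run as after full training.</p>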
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb10"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> EarlyStoppingEvaluator(FullTrainingEvaluator):  <span class="co"># reuses _build_model</span></span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, dataset, max_epochs: <span class="bu">int</span> <span class="op">=</span> <span class="dv">20</span>, patience: <span class="bu">int</span> <span class="op">=</span> <span class="dv">5</span>):</span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.dataset <span class="op">=</span> dataset</span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.max_epochs <span class="op">=</span> max_epochs</span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.patience <span class="op">=</span> patience</span>
<span id="cb10-6"><a href="#cb10-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-7"><a href="#cb10-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> evaluate(<span class="va">self</span>, architecture) <span class="op">-&gt;</span> <span class="bu">float</span>:</span>
<span id="cb10-8"><a href="#cb10-8" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Evaluate with early stopping"""</span></span>
<span id="cb10-9"><a href="#cb10-9" aria-hidden="true" tabindex="-1"></a>        model <span class="op">=</span> <span class="va">self</span>._build_model(architecture)</span>
<span id="cb10-10"><a href="#cb10-10" aria-hidden="true" tabindex="-1"></a>        optimizer <span class="op">=</span> torch.optim.SGD(model.parameters(), lr<span class="op">=</span><span class="fl">0.01</span>)</span>
<span id="cb10-11"><a href="#cb10-11" aria-hidden="true" tabindex="-1"></a>        criterion <span class="op">=</span> nn.CrossEntropyLoss()</span>
<span id="cb10-12"><a href="#cb10-12" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb10-13"><a href="#cb10-13" aria-hidden="true" tabindex="-1"></a>        best_val_acc <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb10-14"><a href="#cb10-14" aria-hidden="true" tabindex="-1"></a>        patience_counter <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb10-15"><a href="#cb10-15" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb10-16"><a href="#cb10-16" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(<span class="va">self</span>.max_epochs):</span>
<span id="cb10-17"><a href="#cb10-17" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Training</span></span>
<span id="cb10-18"><a href="#cb10-18" aria-hidden="true" tabindex="-1"></a>            train_loss <span class="op">=</span> <span class="va">self</span>._train_epoch(model, optimizer, criterion)</span>
<span id="cb10-19"><a href="#cb10-19" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb10-20"><a href="#cb10-20" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Validation</span></span>
<span id="cb10-21"><a href="#cb10-21" aria-hidden="true" tabindex="-1"></a>            val_acc <span class="op">=</span> <span class="va">self</span>._validate_epoch(model)</span>
<span id="cb10-22"><a href="#cb10-22" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb10-23"><a href="#cb10-23" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Early stopping check</span></span>
<span id="cb10-24"><a href="#cb10-24" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> val_acc <span class="op">&gt;</span> best_val_acc:</span>
<span id="cb10-25"><a href="#cb10-25" aria-hidden="true" tabindex="-1"></a>                best_val_acc <span class="op">=</span> val_acc</span>
<span id="cb10-26"><a href="#cb10-26" aria-hidden="true" tabindex="-1"></a>                patience_counter <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb10-27"><a href="#cb10-27" aria-hidden="true" tabindex="-1"></a>            <span class="cf">else</span>:</span>
<span id="cb10-28"><a href="#cb10-28" aria-hidden="true" tabindex="-1"></a>                patience_counter <span class="op">+=</span> <span class="dv">1</span></span>
<span id="cb10-29"><a href="#cb10-29" aria-hidden="true" tabindex="-1"></a>                <span class="cf">if</span> patience_counter <span class="op">&gt;=</span> <span class="va">self</span>.patience:</span>
<span id="cb10-30"><a href="#cb10-30" aria-hidden="true" tabindex="-1"></a>                    <span class="cf">break</span></span>
<span id="cb10-31"><a href="#cb10-31" aria-hidden="true" tabindex="-1"></a>        </span>
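<span id="cb10-32"><a href="#cb10-32" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> best_val_acc</span>
<span id="cb10-33"><a href="#cb10-33" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-34"><a href="#cb10-34" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _train_epoch(<span class="va">self</span>, model, optimizer, criterion) <span class="op">-&gt;</span> <span class="bu">float</span>:</span>
<span id="cb10-35"><a href="#cb10-35" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""One pass over the training data; returns the mean batch loss"""</span></span>
<span id="cb10-36"><a href="#cb10-36" aria-hidden="true" tabindex="-1"></a>        model.train()</span>
<span id="cb10-37"><a href="#cb10-37" aria-hidden="true" tabindex="-1"></a>        total_loss, num_batches <span class="op">=</span> <span class="fl">0.0</span>, <span class="dv">0</span></span>
<span id="cb10-38"><a href="#cb10-38" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> inputs, targets <span class="kw">in</span> <span class="va">self</span>.dataset:</span>
<span id="cb10-39"><a href="#cb10-39" aria-hidden="true" tabindex="-1"></a>            optimizer.zero_grad()</span>
<span id="cb10-40"><a href="#cb10-40" aria-hidden="true" tabindex="-1"></a>            loss <span class="op">=</span> criterion(model(inputs), targets)</span>
<span id="cb10-41"><a href="#cb10-41" aria-hidden="true" tabindex="-1"></a>            loss.backward()</span>
<span id="cb10-42"><a href="#cb10-42" aria-hidden="true" tabindex="-1"></a>            optimizer.step()</span>
<span id="cb10-43"><a href="#cb10-43" aria-hidden="true" tabindex="-1"></a>            total_loss <span class="op">+=</span> loss.item()</span>
<span id="cb10-44"><a href="#cb10-44" aria-hidden="true" tabindex="-1"></a>            num_batches <span class="op">+=</span> <span class="dv">1</span></span>
<span id="cb10-45"><a href="#cb10-45" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> total_loss <span class="op">/</span> <span class="bu">max</span>(num_batches, <span class="dv">1</span>)</span>
<span id="cb10-46"><a href="#cb10-46" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-47"><a href="#cb10-47" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _validate_epoch(<span class="va">self</span>, model) <span class="op">-&gt;</span> <span class="bu">float</span>:</span>
<span id="cb10-48"><a href="#cb10-48" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Accuracy on self.dataset (use a held-out split in practice)"""</span></span>
<span id="cb10-49"><a href="#cb10-49" aria-hidden="true" tabindex="-1"></a>        model.<span class="bu">eval</span>()</span>
<span id="cb10-50"><a href="#cb10-50" aria-hidden="true" tabindex="-1"></a>        correct, total <span class="op">=</span> <span class="dv">0</span>, <span class="dv">0</span></span>
<span id="cb10-51"><a href="#cb10-51" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> torch.no_grad():</span>
<span id="cb10-52"><a href="#cb10-52" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> inputs, targets <span class="kw">in</span> <span class="va">self</span>.dataset:</span>
<span id="cb10-53"><a href="#cb10-53" aria-hidden="true" tabindex="-1"></a>                predicted <span class="op">=</span> model(inputs).argmax(dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb10-54"><a href="#cb10-54" aria-hidden="true" tabindex="-1"></a>                correct <span class="op">+=</span> (predicted <span class="op">==</span> targets).<span class="bu">sum</span>().item()</span>
<span id="cb10-55"><a href="#cb10-55" aria-hidden="true" tabindex="-1"></a>                total <span class="op">+=</span> targets.size(<span class="dv">0</span>)</span>
<span id="cb10-56"><a href="#cb10-56" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> correct <span class="op">/</span> <span class="bu">max</span>(total, <span class="dv">1</span>)</span></code></pre></div></div>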
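<p>Whether early-stopped scores rank architectures the same way full training does can be checked on a small sample of architectures trained both ways. The plain-Python sketch below computes Kendall's tau (the tau-a variant, which assumes no ties) between two score lists; the accuracies in the usage lines are made-up placeholders.</p>

```python
from itertools import combinations


def kendall_tau(xs, ys):
    """Kendall rank correlation (tau-a; assumes no ties) of two score lists."""
    if len(xs) != len(ys) or len(xs) < 2:
        raise ValueError("need two equal-length lists with at least 2 items")
    concordant = discordant = 0
    for i, j in combinations(range(len(xs)), 2):
        # Same sign of difference in both lists -> the pair is ordered alike
        s = (xs[i] - xs[j]) * (ys[i] - ys[j])
        if s > 0:
            concordant += 1
        elif s < 0:
            discordant += 1
    n_pairs = len(xs) * (len(xs) - 1) // 2
    return (concordant - discordant) / n_pairs


# Hypothetical accuracies for five architectures
short_run = [0.61, 0.58, 0.66, 0.55, 0.63]  # early-stopped estimates
full_run = [0.90, 0.88, 0.93, 0.86, 0.89]   # after full training
print(kendall_tau(short_run, full_run))  # 9 concordant, 1 discordant pair -> 0.8
```

<p>A tau close to 1 suggests the cheap estimate orders architectures almost as full training does; a low value suggests the epoch budget is too short to trust for ranking.</p>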
</section>
<section id="weight-sharing-one-shot" class="level3">
<h3 class="anchored" data-anchor-id="weight-sharing-one-shot" id="weight-sharing-one-shot">Weight Sharing (One-Shot)</h3>
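<p>Trains one over-parameterized supernet that contains every candidate architecture as a sub-network, sampling a random path through it at each step. After this one-time training, each candidate is evaluated with the weights it inherits from the supernet, without any training of its own, so an evaluation costs only one validation pass.</p>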
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb11"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> WeightSharingEvaluator:</span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, supernet: nn.Module, dataset):</span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.supernet <span class="op">=</span> supernet</span>
<span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.dataset <span class="op">=</span> dataset</span>
<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.trained <span class="op">=</span> <span class="va">False</span></span>
<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> train_supernet(<span class="va">self</span>):</span>
<span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Train the supernet once"""</span></span>
<span id="cb11-9"><a href="#cb11-9" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="va">self</span>.trained:</span>
<span id="cb11-10"><a href="#cb11-10" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span></span>
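<span id="cb11-11"><a href="#cb11-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.supernet.train()  <span class="co"># training mode for dropout/BatchNorm</span></span>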
<span id="cb11-12"><a href="#cb11-12" aria-hidden="true" tabindex="-1"></a>        optimizer <span class="op">=</span> torch.optim.SGD(<span class="va">self</span>.supernet.parameters(), lr<span class="op">=</span><span class="fl">0.01</span>)</span>
<span id="cb11-13"><a href="#cb11-13" aria-hidden="true" tabindex="-1"></a>        criterion <span class="op">=</span> nn.CrossEntropyLoss()</span>
<span id="cb11-14"><a href="#cb11-14" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-15"><a href="#cb11-15" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">50</span>):  <span class="co"># Train supernet</span></span>
<span id="cb11-16"><a href="#cb11-16" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> batch <span class="kw">in</span> <span class="va">self</span>.dataset:</span>
<span id="cb11-17"><a href="#cb11-17" aria-hidden="true" tabindex="-1"></a>                inputs, targets <span class="op">=</span> batch</span>
<span id="cb11-18"><a href="#cb11-18" aria-hidden="true" tabindex="-1"></a>                optimizer.zero_grad()</span>
<span id="cb11-19"><a href="#cb11-19" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb11-20"><a href="#cb11-20" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Sample random path through supernet</span></span>
<span id="cb11-21"><a href="#cb11-21" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.supernet.sample_active_subnet()</span>
<span id="cb11-22"><a href="#cb11-22" aria-hidden="true" tabindex="-1"></a>                outputs <span class="op">=</span> <span class="va">self</span>.supernet(inputs)</span>
<span id="cb11-23"><a href="#cb11-23" aria-hidden="true" tabindex="-1"></a>                loss <span class="op">=</span> criterion(outputs, targets)</span>
<span id="cb11-24"><a href="#cb11-24" aria-hidden="true" tabindex="-1"></a>                loss.backward()</span>
<span id="cb11-25"><a href="#cb11-25" aria-hidden="true" tabindex="-1"></a>                optimizer.step()</span>
<span id="cb11-26"><a href="#cb11-26" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-27"><a href="#cb11-27" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.trained <span class="op">=</span> <span class="va">True</span></span>
<span id="cb11-28"><a href="#cb11-28" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-29"><a href="#cb11-29" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> evaluate(<span class="va">self</span>, architecture) <span class="op">-&gt;</span> <span class="bu">float</span>:</span>
<span id="cb11-30"><a href="#cb11-30" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Evaluate architecture using trained supernet"""</span></span>
<span id="cb11-31"><a href="#cb11-31" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="kw">not</span> <span class="va">self</span>.trained:</span>
<span id="cb11-32"><a href="#cb11-32" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.train_supernet()</span>
<span id="cb11-33"><a href="#cb11-33" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-34"><a href="#cb11-34" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Configure supernet for specific architecture</span></span>
<span id="cb11-35"><a href="#cb11-35" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.supernet.set_active_subnet(architecture)</span>
<span id="cb11-36"><a href="#cb11-36" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-37"><a href="#cb11-37" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Score with inherited weights, no retraining (ideally on a held-out split)</span></span>
<span id="cb11-38"><a href="#cb11-38" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>._evaluate_subnet()</span>
<span id="cb11-39"><a href="#cb11-39" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-40"><a href="#cb11-40" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _evaluate_subnet(<span class="va">self</span>):</span>
<span id="cb11-41"><a href="#cb11-41" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Evaluate current subnet configuration"""</span></span>
<span id="cb11-42"><a href="#cb11-42" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.supernet.<span class="bu">eval</span>()</span>
<span id="cb11-43"><a href="#cb11-43" aria-hidden="true" tabindex="-1"></a>        correct <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb11-44"><a href="#cb11-44" aria-hidden="true" tabindex="-1"></a>        total <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb11-45"><a href="#cb11-45" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-46"><a href="#cb11-46" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> torch.no_grad():</span>
<span id="cb11-47"><a href="#cb11-47" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> batch <span class="kw">in</span> <span class="va">self</span>.dataset:</span>
<span id="cb11-48"><a href="#cb11-48" aria-hidden="true" tabindex="-1"></a>                inputs, targets <span class="op">=</span> batch</span>
<span id="cb11-49"><a href="#cb11-49" aria-hidden="true" tabindex="-1"></a>                outputs <span class="op">=</span> <span class="va">self</span>.supernet(inputs)</span>
<span id="cb11-50"><a href="#cb11-50" aria-hidden="true" tabindex="-1"></a>                predicted <span class="op">=</span> outputs.argmax(dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb11-51"><a href="#cb11-51" aria-hidden="true" tabindex="-1"></a>                total <span class="op">+=</span> targets.size(<span class="dv">0</span>)</span>
<span id="cb11-52"><a href="#cb11-52" aria-hidden="true" tabindex="-1"></a>                correct <span class="op">+=</span> (predicted <span class="op">==</span> targets).<span class="bu">sum</span>().item()</span>
<span id="cb11-53"><a href="#cb11-53" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-54"><a href="#cb11-54" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> correct <span class="op">/</span> total</span></code></pre></div></div>
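<p>The weight-inheritance idea can be illustrated without PyTorch. In the plain-Python sketch below, a hypothetical scalar "supernet" stores one weight per (layer, operation) pair, so all sub-networks that choose the same operation at the same layer share that weight. It is trained with uniform single-path sampling (one SGD step per sampled path) on a made-up target y = 2x, and then every sub-network is scored using only the inherited weights. The method names <code>sample_active_subnet</code> and <code>set_active_subnet</code> mirror the interface the evaluator above assumes; the operation names are just labels.</p>

```python
import itertools
import math
import random


class ToySupernet:
    """Hypothetical scalar supernet: output = x * product of active-op weights.

    Weights live in one table keyed by (layer, op), so every sub-network that
    picks the same op at the same layer shares (inherits) that weight.
    """

    def __init__(self, num_layers, ops, seed=0):
        self.num_layers = num_layers
        self.ops = list(ops)
        self.weights = {(l, op): 1.0 for l in range(num_layers) for op in self.ops}
        self.rng = random.Random(seed)
        self.active = [self.ops[0]] * num_layers

    def sample_active_subnet(self):
        """Uniform single-path sampling: one random op per layer."""
        self.active = [self.rng.choice(self.ops) for _ in range(self.num_layers)]

    def set_active_subnet(self, architecture):
        self.active = list(architecture)

    def _active_keys(self):
        return list(enumerate(self.active))  # [(layer, op), ...]

    def loss(self, data):
        """Mean squared error of the active sub-network (no training)."""
        gain = math.prod(self.weights[k] for k in self._active_keys())
        return sum((gain * x - y) ** 2 for x, y in data) / len(data)

    def train_step(self, data, lr=0.05):
        """One SGD step that updates only the weights on the active path."""
        keys = self._active_keys()
        gain = math.prod(self.weights[k] for k in keys)
        dloss_dgain = sum(2 * (gain * x - y) * x for x, y in data) / len(data)
        # d gain / d w_k is the product of the other active weights
        grads = {k: dloss_dgain * math.prod(self.weights[j] for j in keys if j != k)
                 for k in keys}
        for k, g in grads.items():
            self.weights[k] -= lr * g


# Hypothetical task: learn y = 2x with a 2-layer, 2-op search space
data = [(x, 2.0 * x) for x in (-1.0, -0.5, 0.5, 1.0)]
net = ToySupernet(num_layers=2, ops=["conv3x3", "conv5x5"])
for _ in range(2000):
    net.sample_active_subnet()
    net.train_step(data)

# Score every sub-network with inherited weights and no retraining
results = {}
for arch in itertools.product(net.ops, repeat=net.num_layers):
    net.set_active_subnet(arch)
    results[arch] = net.loss(data)
for arch, loss in results.items():
    print(arch, round(loss, 6))
```

<p>Here every path can fit the target, so all sub-networks score well. In a real search space the shared weights must compromise between paths, which is why inherited-weight scores approximate, rather than replace, stand-alone accuracy.</p>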
</section>
</section>
<section id="advanced-techniques" class="level2">
<h2 class="anchored" data-anchor-id="advanced-techniques" id="advanced-techniques">Advanced Techniques</h2>
<section id="progressive-search" class="level3">
<h3 class="anchored" data-anchor-id="progressive-search" id="progressive-search">Progressive Search</h3>
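<p>Runs the search in stages, raising the maximum allowed architecture complexity after each stage, so that small, cheap architectures are explored before larger ones and the best architecture found at each complexity level is reported.</p>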
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb12"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><a href="#cb12-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> ProgressiveSearch:</span>
<span id="cb12-2"><a href="#cb12-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, base_search_space, max_complexity: <span class="bu">int</span> <span class="op">=</span> <span class="dv">5</span>):</span>
<span id="cb12-3"><a href="#cb12-3" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.base_search_space <span class="op">=</span> base_search_space</span>
<span id="cb12-4"><a href="#cb12-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.max_complexity <span class="op">=</span> max_complexity</span>
<span id="cb12-5"><a href="#cb12-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.current_complexity <span class="op">=</span> <span class="dv">1</span></span>
<span id="cb12-6"><a href="#cb12-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.search_strategy <span class="op">=</span> EvolutionarySearch()</span>
<span id="cb12-7"><a href="#cb12-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-8"><a href="#cb12-8" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> search(<span class="va">self</span>, iterations_per_stage: <span class="bu">int</span> <span class="op">=</span> <span class="dv">100</span>):</span>
<span id="cb12-9"><a href="#cb12-9" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Progressive search with increasing complexity"""</span></span>
<span id="cb12-10"><a href="#cb12-10" aria-hidden="true" tabindex="-1"></a>        best_architectures <span class="op">=</span> []</span>
<span id="cb12-11"><a href="#cb12-11" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-12"><a href="#cb12-12" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> complexity <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">1</span>, <span class="va">self</span>.max_complexity <span class="op">+</span> <span class="dv">1</span>):</span>
<span id="cb12-13"><a href="#cb12-13" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.current_complexity <span class="op">=</span> complexity</span>
<span id="cb12-14"><a href="#cb12-14" aria-hidden="true" tabindex="-1"></a>            stage_results <span class="op">=</span> []  <span class="co"># (architecture, performance) pairs from this stage only</span></span>
<span id="cb12-15"><a href="#cb12-15" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Search at current complexity level</span></span>
<span id="cb12-16"><a href="#cb12-16" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(iterations_per_stage):</span>
<span id="cb12-17"><a href="#cb12-17" aria-hidden="true" tabindex="-1"></a>                architecture <span class="op">=</span> <span class="va">self</span>._sample_architecture_at_complexity()</span>
<span id="cb12-18"><a href="#cb12-18" aria-hidden="true" tabindex="-1"></a>                performance <span class="op">=</span> <span class="va">self</span>._evaluate_architecture(architecture)</span>
<span id="cb12-19"><a href="#cb12-19" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.search_strategy.update(architecture, performance)</span>
<span id="cb12-20"><a href="#cb12-20" aria-hidden="true" tabindex="-1"></a>                stage_results.append((architecture, performance))</span>
<span id="cb12-21"><a href="#cb12-21" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb12-22"><a href="#cb12-22" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Best architecture at this complexity (not the cumulative history)</span></span>
<span id="cb12-23"><a href="#cb12-23" aria-hidden="true" tabindex="-1"></a>            best_arch <span class="op">=</span> <span class="bu">max</span>(stage_results, key<span class="op">=</span><span class="kw">lambda</span> x: x[<span class="dv">1</span>])</span>
<span id="cb12-24"><a href="#cb12-24" aria-hidden="true" tabindex="-1"></a>            best_architectures.append(best_arch)</span>
<span id="cb12-25"><a href="#cb12-25" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-26"><a href="#cb12-26" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> best_architectures</span>
<span id="cb12-27"><a href="#cb12-27" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-28"><a href="#cb12-28" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _sample_architecture_at_complexity(<span class="va">self</span>):</span>
<span id="cb12-29"><a href="#cb12-29" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Sample architecture with limited complexity"""</span></span>
<span id="cb12-30"><a href="#cb12-30" aria-hidden="true" tabindex="-1"></a>        arch <span class="op">=</span> <span class="va">self</span>.base_search_space.sample_architecture()</span>
<span id="cb12-31"><a href="#cb12-31" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Limit architecture complexity</span></span>
<span id="cb12-32"><a href="#cb12-32" aria-hidden="true" tabindex="-1"></a>        arch[<span class="st">'layers'</span>] <span class="op">=</span> arch[<span class="st">'layers'</span>][:<span class="va">self</span>.current_complexity <span class="op">*</span> <span class="dv">3</span>]</span>
<span id="cb12-33"><a href="#cb12-33" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> arch</span></code></pre></div></div>
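<p>You can also try the staged idea without the <code>EvolutionarySearch</code> and search-space classes used above. The sketch below is a self-contained toy, not the class above. <code>LAYER_CHOICES</code> and <code>toy_score</code> are hypothetical stand-ins for a real search space and for training each candidate. Here, knowledge is carried between stages by growing the previous stage's best architecture, not through a shared search strategy.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import random
from typing import Dict, List, Tuple

# Hypothetical layer vocabulary and scoring function: stand-ins for a real
# search space and for training/evaluating each candidate network.
LAYER_CHOICES = ["conv3x3", "conv5x5", "sep_conv3x3", "max_pool", "identity"]


def toy_score(layers: List[str]) -&gt; float:
    """Made-up proxy: reward convolutions, penalize depth slightly."""
    return sum(1.0 for op in layers if "conv" in op) - 0.1 * len(layers)


def progressive_search(max_complexity: int = 3, iterations_per_stage: int = 20,
                       seed: int = 0) -&gt; List[Tuple[Dict, float]]:
    rng = random.Random(seed)
    best_per_stage = []
    prefix: List[str] = []  # best layers so far; empty before stage 1
    for complexity in range(1, max_complexity + 1):
        stage_results = []
        for _ in range(iterations_per_stage):
            # Grow the previous stage's best by 1-3 layers, so stage k never
            # exceeds complexity * 3 layers (the same cap as the class above)
            new_layers = [rng.choice(LAYER_CHOICES) for _ in range(rng.randint(1, 3))]
            layers = prefix + new_layers
            stage_results.append(({"layers": layers}, toy_score(layers)))
        # Best candidate at this complexity level only
        best_arch, best_score = max(stage_results, key=lambda x: x[1])
        best_per_stage.append((best_arch, best_score))
        prefix = best_arch["layers"]
    return best_per_stage


for stage, (arch, score) in enumerate(progressive_search(), start=1):
    print(stage, arch["layers"], round(score, 2))
</code></pre></div>
<p>Every stage starts from the previous winner, so the best architecture at stage k always extends the one from stage k-1. That dependency is what makes the search progressive rather than a series of independent searches.</p>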
</section>
<section id="multi-objective-nas" class="level3">
<h3 class="anchored" data-anchor-id="multi-objective-nas" id="multi-objective-nas">Multi-Objective NAS</h3>
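<p>Multi-objective NAS optimizes several competing objectives at once, such as accuracy, latency, and FLOPs. It does not return a single best model. Instead, it maintains a Pareto front: the set of architectures that no other candidate matches or beats on every objective while being strictly better on at least one.</p>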
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb13"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MultiObjectiveNAS:</span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, objectives: List[<span class="bu">str</span>]):</span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.objectives <span class="op">=</span> objectives  <span class="co"># e.g., ['accuracy', 'latency', 'flops']</span></span>
<span id="cb13-4"><a href="#cb13-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.pareto_front <span class="op">=</span> []</span>
<span id="cb13-5"><a href="#cb13-5" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-6"><a href="#cb13-6" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> evaluate_architecture(<span class="va">self</span>, architecture) <span class="op">-&gt;</span> Dict[<span class="bu">str</span>, <span class="bu">float</span>]:</span>
<span id="cb13-7"><a href="#cb13-7" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Evaluate architecture on multiple objectives"""</span></span>
<span id="cb13-8"><a href="#cb13-8" aria-hidden="true" tabindex="-1"></a>        results <span class="op">=</span> {}</span>
<span id="cb13-9"><a href="#cb13-9" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-10"><a href="#cb13-10" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="st">'accuracy'</span> <span class="kw">in</span> <span class="va">self</span>.objectives:</span>
<span id="cb13-11"><a href="#cb13-11" aria-hidden="true" tabindex="-1"></a>            results[<span class="st">'accuracy'</span>] <span class="op">=</span> <span class="va">self</span>._evaluate_accuracy(architecture)</span>
<span id="cb13-12"><a href="#cb13-12" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-13"><a href="#cb13-13" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="st">'latency'</span> <span class="kw">in</span> <span class="va">self</span>.objectives:</span>
<span id="cb13-14"><a href="#cb13-14" aria-hidden="true" tabindex="-1"></a>            results[<span class="st">'latency'</span>] <span class="op">=</span> <span class="va">self</span>._evaluate_latency(architecture)</span>
<span id="cb13-15"><a href="#cb13-15" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-16"><a href="#cb13-16" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="st">'flops'</span> <span class="kw">in</span> <span class="va">self</span>.objectives:</span>
<span id="cb13-17"><a href="#cb13-17" aria-hidden="true" tabindex="-1"></a>            results[<span class="st">'flops'</span>] <span class="op">=</span> <span class="va">self</span>._evaluate_flops(architecture)</span>
<span id="cb13-18"><a href="#cb13-18" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-19"><a href="#cb13-19" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> results</span>
<span id="cb13-20"><a href="#cb13-20" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-21"><a href="#cb13-21" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> update_pareto_front(<span class="va">self</span>, architecture, objectives):</span>
<span id="cb13-22"><a href="#cb13-22" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Update Pareto front with new architecture"""</span></span>
<span id="cb13-23"><a href="#cb13-23" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Check if architecture is dominated</span></span>
<span id="cb13-24"><a href="#cb13-24" aria-hidden="true" tabindex="-1"></a>        dominated <span class="op">=</span> <span class="va">False</span></span>
<span id="cb13-25"><a href="#cb13-25" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> pareto_arch, pareto_obj <span class="kw">in</span> <span class="va">self</span>.pareto_front:</span>
<span id="cb13-26"><a href="#cb13-26" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="va">self</span>._dominates(pareto_obj, objectives):</span>
<span id="cb13-27"><a href="#cb13-27" aria-hidden="true" tabindex="-1"></a>                dominated <span class="op">=</span> <span class="va">True</span></span>
<span id="cb13-28"><a href="#cb13-28" aria-hidden="true" tabindex="-1"></a>                <span class="cf">break</span></span>
<span id="cb13-29"><a href="#cb13-29" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-30"><a href="#cb13-30" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="kw">not</span> dominated:</span>
<span id="cb13-31"><a href="#cb13-31" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Remove dominated architectures</span></span>
<span id="cb13-32"><a href="#cb13-32" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.pareto_front <span class="op">=</span> [</span>
<span id="cb13-33"><a href="#cb13-33" aria-hidden="true" tabindex="-1"></a>                (arch, obj) <span class="cf">for</span> arch, obj <span class="kw">in</span> <span class="va">self</span>.pareto_front</span>
<span id="cb13-34"><a href="#cb13-34" aria-hidden="true" tabindex="-1"></a>                <span class="cf">if</span> <span class="kw">not</span> <span class="va">self</span>._dominates(objectives, obj)</span>
<span id="cb13-35"><a href="#cb13-35" aria-hidden="true" tabindex="-1"></a>            ]</span>
<span id="cb13-36"><a href="#cb13-36" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Add new architecture</span></span>
<span id="cb13-37"><a href="#cb13-37" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.pareto_front.append((architecture, objectives))</span>
<span id="cb13-38"><a href="#cb13-38" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-39"><a href="#cb13-39" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _dominates(<span class="va">self</span>, obj1: Dict, obj2: Dict) <span class="op">-&gt;</span> <span class="bu">bool</span>:</span>
<span id="cb13-40"><a href="#cb13-40" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Check if obj1 dominates obj2"""</span></span>
<span id="cb13-41"><a href="#cb13-41" aria-hidden="true" tabindex="-1"></a>        better_in_all <span class="op">=</span> <span class="va">True</span></span>
<span id="cb13-42"><a href="#cb13-42" aria-hidden="true" tabindex="-1"></a>        strictly_better_in_one <span class="op">=</span> <span class="va">False</span></span>
<span id="cb13-43"><a href="#cb13-43" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-44"><a href="#cb13-44" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> objective <span class="kw">in</span> <span class="va">self</span>.objectives:</span>
<span id="cb13-45"><a href="#cb13-45" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> objective <span class="kw">in</span> [<span class="st">'accuracy'</span>]:  <span class="co"># Higher is better</span></span>
<span id="cb13-46"><a href="#cb13-46" aria-hidden="true" tabindex="-1"></a>                <span class="cf">if</span> obj1[objective] <span class="op">&lt;</span> obj2[objective]:</span>
<span id="cb13-47"><a href="#cb13-47" aria-hidden="true" tabindex="-1"></a>                    better_in_all <span class="op">=</span> <span class="va">False</span></span>
<span id="cb13-48"><a href="#cb13-48" aria-hidden="true" tabindex="-1"></a>                <span class="cf">elif</span> obj1[objective] <span class="op">&gt;</span> obj2[objective]:</span>
<span id="cb13-49"><a href="#cb13-49" aria-hidden="true" tabindex="-1"></a>                    strictly_better_in_one <span class="op">=</span> <span class="va">True</span></span>
<span id="cb13-50"><a href="#cb13-50" aria-hidden="true" tabindex="-1"></a>            <span class="cf">else</span>:  <span class="co"># Lower is better (latency, flops)</span></span>
<span id="cb13-51"><a href="#cb13-51" aria-hidden="true" tabindex="-1"></a>                <span class="cf">if</span> obj1[objective] <span class="op">&gt;</span> obj2[objective]:</span>
<span id="cb13-52"><a href="#cb13-52" aria-hidden="true" tabindex="-1"></a>                    better_in_all <span class="op">=</span> <span class="va">False</span></span>
<span id="cb13-53"><a href="#cb13-53" aria-hidden="true" tabindex="-1"></a>                <span class="cf">elif</span> obj1[objective] <span class="op">&lt;</span> obj2[objective]:</span>
<span id="cb13-54"><a href="#cb13-54" aria-hidden="true" tabindex="-1"></a>                    strictly_better_in_one <span class="op">=</span> <span class="va">True</span></span>
<span id="cb13-55"><a href="#cb13-55" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-56"><a href="#cb13-56" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> better_in_all <span class="kw">and</span> strictly_better_in_one</span></code></pre></div></div>
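<p>You can check the dominance rule in <code>_dominates</code> in isolation. The sketch below uses the same convention: accuracy is maximized and every other objective is minimized. It differs from <code>update_pareto_front</code> in one way: it filters a whole candidate list at once rather than updating the front one architecture at a time. The architecture names and metric values are invented for illustration.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">from typing import Dict, List, Tuple

# Same convention as _dominates above: accuracy is maximized, all other
# objectives (latency, FLOPs, ...) are minimized.
MAXIMIZE = {"accuracy"}
Candidate = Tuple[str, Dict[str, float]]


def dominates(a: Dict[str, float], b: Dict[str, float]) -&gt; bool:
    """True if a is no worse than b everywhere and strictly better somewhere."""
    no_worse = all((a[k] &gt;= b[k]) if k in MAXIMIZE else (a[k] &lt;= b[k]) for k in a)
    strictly_better = any((a[k] &gt; b[k]) if k in MAXIMIZE else (a[k] &lt; b[k]) for k in a)
    return no_worse and strictly_better


def pareto_front(candidates: List[Candidate]) -&gt; List[Candidate]:
    """Keep candidates that no other candidate dominates (O(n^2) pairwise check)."""
    return [
        (name, obj) for name, obj in candidates
        if not any(dominates(other, obj) for _, other in candidates)
    ]


# Invented metrics (accuracy as a fraction, latency in ms) for illustration only
candidates = [
    ("arch_a", {"accuracy": 0.92, "latency": 30.0}),
    ("arch_b", {"accuracy": 0.90, "latency": 12.0}),
    ("arch_c", {"accuracy": 0.89, "latency": 25.0}),  # dominated by arch_b
    ("arch_d", {"accuracy": 0.95, "latency": 80.0}),
]
print([name for name, _ in pareto_front(candidates)])  # ['arch_a', 'arch_b', 'arch_d']
</code></pre></div>
<p>A candidate never dominates itself, because dominance requires a strict improvement on some objective. That is why the inner check does not need to skip the candidate being tested.</p>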
</section>
</section>
<section id="implementation-examples" class="level2">
<h2 class="anchored" data-anchor-id="implementation-examples" id="implementation-examples">Implementation Examples</h2>
<section id="complete-darts-implementation" class="level3">
<h3 class="anchored" data-anchor-id="complete-darts-implementation" id="complete-darts-implementation">Complete DARTS Implementation</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb14"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb14-1"><a href="#cb14-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> DARTSCell(nn.Module):</span>
<span id="cb14-2"><a href="#cb14-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, num_nodes: <span class="bu">int</span>, channels: <span class="bu">int</span>):</span>
<span id="cb14-3"><a href="#cb14-3" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb14-4"><a href="#cb14-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.num_nodes <span class="op">=</span> num_nodes</span>
<span id="cb14-5"><a href="#cb14-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.channels <span class="op">=</span> channels</span>
<span id="cb14-6"><a href="#cb14-6" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-7"><a href="#cb14-7" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Mixed operations for each edge</span></span>
<span id="cb14-8"><a href="#cb14-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.mixed_ops <span class="op">=</span> nn.ModuleList()</span>
<span id="cb14-9"><a href="#cb14-9" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(num_nodes):</span>
<span id="cb14-10"><a href="#cb14-10" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> j <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">2</span> <span class="op">+</span> i):  <span class="co"># Each node connects to all previous nodes</span></span>
<span id="cb14-11"><a href="#cb14-11" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.mixed_ops.append(MixedOp(channels))</span>
<span id="cb14-12"><a href="#cb14-12" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-13"><a href="#cb14-13" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Architecture parameters</span></span>
<span id="cb14-14"><a href="#cb14-14" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.alpha <span class="op">=</span> nn.Parameter(torch.randn(<span class="bu">len</span>(<span class="va">self</span>.mixed_ops), <span class="dv">8</span>))</span>
<span id="cb14-15"><a href="#cb14-15" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb14-16"><a href="#cb14-16" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, inputs):</span>
<span id="cb14-17"><a href="#cb14-17" aria-hidden="true" tabindex="-1"></a>        <span class="co"># inputs[0] and inputs[1] are the two input nodes</span></span>
<span id="cb14-18"><a href="#cb14-18" aria-hidden="true" tabindex="-1"></a>        states <span class="op">=</span> [inputs[<span class="dv">0</span>], inputs[<span class="dv">1</span>]]</span>
<span id="cb14-19"><a href="#cb14-19" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-20"><a href="#cb14-20" aria-hidden="true" tabindex="-1"></a>        offset <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb14-21"><a href="#cb14-21" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="va">self</span>.num_nodes):</span>
<span id="cb14-22"><a href="#cb14-22" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Collect inputs from all previous nodes</span></span>
<span id="cb14-23"><a href="#cb14-23" aria-hidden="true" tabindex="-1"></a>            node_inputs <span class="op">=</span> []</span>
<span id="cb14-24"><a href="#cb14-24" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> j <span class="kw">in</span> <span class="bu">range</span>(<span class="bu">len</span>(states)):</span>
<span id="cb14-25"><a href="#cb14-25" aria-hidden="true" tabindex="-1"></a>                op_idx <span class="op">=</span> offset <span class="op">+</span> j</span>
<span id="cb14-26"><a href="#cb14-26" aria-hidden="true" tabindex="-1"></a>                node_inputs.append(<span class="va">self</span>.mixed_ops[op_idx](states[j], <span class="va">self</span>.alpha[op_idx]))</span>
<span id="cb14-27"><a href="#cb14-27" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb14-28"><a href="#cb14-28" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Sum all inputs to this node</span></span>
<span id="cb14-29"><a href="#cb14-29" aria-hidden="true" tabindex="-1"></a>            state <span class="op">=</span> <span class="bu">sum</span>(node_inputs)</span>
<span id="cb14-30"><a href="#cb14-30" aria-hidden="true" tabindex="-1"></a>            states.append(state)</span>
<span id="cb14-31"><a href="#cb14-31" aria-hidden="true" tabindex="-1"></a>            offset <span class="op">+=</span> <span class="bu">len</span>(states) <span class="op">-</span> <span class="dv">1</span></span>
<span id="cb14-32"><a href="#cb14-32" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-33"><a href="#cb14-33" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Concatenate final nodes</span></span>
<span id="cb14-34"><a href="#cb14-34" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> torch.cat(states[<span class="op">-</span><span class="va">self</span>.num_nodes:], dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb14-35"><a href="#cb14-35" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-36"><a href="#cb14-36" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MixedOp(nn.Module):</span>
<span id="cb14-37"><a href="#cb14-37" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, channels: <span class="bu">int</span>):</span>
<span id="cb14-38"><a href="#cb14-38" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb14-39"><a href="#cb14-39" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.ops <span class="op">=</span> nn.ModuleList([</span>
<span id="cb14-40"><a href="#cb14-40" aria-hidden="true" tabindex="-1"></a>            SepConv(channels, channels, <span class="dv">3</span>, <span class="dv">1</span>, <span class="dv">1</span>),</span>
<span id="cb14-41"><a href="#cb14-41" aria-hidden="true" tabindex="-1"></a>            SepConv(channels, channels, <span class="dv">5</span>, <span class="dv">1</span>, <span class="dv">2</span>),</span>
<span id="cb14-42"><a href="#cb14-42" aria-hidden="true" tabindex="-1"></a>            DilConv(channels, channels, <span class="dv">3</span>, <span class="dv">1</span>, <span class="dv">2</span>, <span class="dv">2</span>),</span>
<span id="cb14-43"><a href="#cb14-43" aria-hidden="true" tabindex="-1"></a>            DilConv(channels, channels, <span class="dv">5</span>, <span class="dv">1</span>, <span class="dv">4</span>, <span class="dv">2</span>),</span>
<span id="cb14-44"><a href="#cb14-44" aria-hidden="true" tabindex="-1"></a>            nn.MaxPool2d(<span class="dv">3</span>, <span class="dv">1</span>, <span class="dv">1</span>),</span>
<span id="cb14-45"><a href="#cb14-45" aria-hidden="true" tabindex="-1"></a>            nn.AvgPool2d(<span class="dv">3</span>, <span class="dv">1</span>, <span class="dv">1</span>),</span>
<span id="cb14-46"><a href="#cb14-46" aria-hidden="true" tabindex="-1"></a>            Identity(),</span>
<span id="cb14-47"><a href="#cb14-47" aria-hidden="true" tabindex="-1"></a>            Zero()</span>
<span id="cb14-48"><a href="#cb14-48" aria-hidden="true" tabindex="-1"></a>        ])</span>
<span id="cb14-49"><a href="#cb14-49" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb14-50"><a href="#cb14-50" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x, alpha):</span>
<span id="cb14-51"><a href="#cb14-51" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Apply weighted sum of operations</span></span>
<span id="cb14-52"><a href="#cb14-52" aria-hidden="true" tabindex="-1"></a>        weights <span class="op">=</span> torch.softmax(alpha, dim<span class="op">=</span><span class="dv">0</span>)</span>
<span id="cb14-53"><a href="#cb14-53" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="bu">sum</span>(w <span class="op">*</span> op(x) <span class="cf">for</span> w, op <span class="kw">in</span> <span class="bu">zip</span>(weights, <span class="va">self</span>.ops))</span>
<span id="cb14-54"><a href="#cb14-54" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-55"><a href="#cb14-55" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> SepConv(nn.Module):</span>
<span id="cb14-56"><a href="#cb14-56" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, in_channels, out_channels, kernel_size, stride, padding):</span>
<span id="cb14-57"><a href="#cb14-57" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb14-58"><a href="#cb14-58" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.conv <span class="op">=</span> nn.Sequential(</span>
<span id="cb14-59"><a href="#cb14-59" aria-hidden="true" tabindex="-1"></a>            nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, </span>
<span id="cb14-60"><a href="#cb14-60" aria-hidden="true" tabindex="-1"></a>                     groups<span class="op">=</span>in_channels),</span>
<span id="cb14-61"><a href="#cb14-61" aria-hidden="true" tabindex="-1"></a>            nn.Conv2d(in_channels, out_channels, <span class="dv">1</span>),</span>
<span id="cb14-62"><a href="#cb14-62" aria-hidden="true" tabindex="-1"></a>            nn.BatchNorm2d(out_channels),</span>
<span id="cb14-63"><a href="#cb14-63" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(inplace<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb14-64"><a href="#cb14-64" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb14-65"><a href="#cb14-65" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb14-66"><a href="#cb14-66" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb14-67"><a href="#cb14-67" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.conv(x)</span>
<span id="cb14-68"><a href="#cb14-68" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-69"><a href="#cb14-69" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> DilConv(nn.Module):</span>
<span id="cb14-70"><a href="#cb14-70" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, in_channels, out_channels, kernel_size, stride, padding, dilation):</span>
<span id="cb14-71"><a href="#cb14-71" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb14-72"><a href="#cb14-72" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.conv <span class="op">=</span> nn.Sequential(</span>
<span id="cb14-73"><a href="#cb14-73" aria-hidden="true" tabindex="-1"></a>            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, </span>
<span id="cb14-74"><a href="#cb14-74" aria-hidden="true" tabindex="-1"></a>                     dilation<span class="op">=</span>dilation),</span>
<span id="cb14-75"><a href="#cb14-75" aria-hidden="true" tabindex="-1"></a>            nn.BatchNorm2d(out_channels),</span>
<span id="cb14-76"><a href="#cb14-76" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(inplace<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb14-77"><a href="#cb14-77" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb14-78"><a href="#cb14-78" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb14-79"><a href="#cb14-79" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb14-80"><a href="#cb14-80" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.conv(x)</span>
<span id="cb14-81"><a href="#cb14-81" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-82"><a href="#cb14-82" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> Identity(nn.Module):</span>
<span id="cb14-83"><a href="#cb14-83" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb14-84"><a href="#cb14-84" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x</span>
<span id="cb14-85"><a href="#cb14-85" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-86"><a href="#cb14-86" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> Zero(nn.Module):</span>
<span id="cb14-87"><a href="#cb14-87" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb14-88"><a href="#cb14-88" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> torch.zeros_like(x)</span>
<span id="cb14-89"><a href="#cb14-89" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-90"><a href="#cb14-90" aria-hidden="true" tabindex="-1"></a><span class="co"># Complete DARTS Network</span></span>
<span id="cb14-91"><a href="#cb14-91" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> DARTSNetwork(nn.Module):</span>
<span id="cb14-92"><a href="#cb14-92" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, num_classes: <span class="bu">int</span>, num_cells: <span class="bu">int</span> <span class="op">=</span> <span class="dv">8</span>, channels: <span class="bu">int</span> <span class="op">=</span> <span class="dv">36</span>):</span>
<span id="cb14-93"><a href="#cb14-93" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb14-94"><a href="#cb14-94" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.num_cells <span class="op">=</span> num_cells</span>
<span id="cb14-95"><a href="#cb14-95" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.channels <span class="op">=</span> channels</span>
<span id="cb14-96"><a href="#cb14-96" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-97"><a href="#cb14-97" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Stem</span></span>
<span id="cb14-98"><a href="#cb14-98" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.stem <span class="op">=</span> nn.Sequential(</span>
<span id="cb14-99"><a href="#cb14-99" aria-hidden="true" tabindex="-1"></a>            nn.Conv2d(<span class="dv">3</span>, channels, <span class="dv">3</span>, <span class="dv">1</span>, <span class="dv">1</span>),</span>
<span id="cb14-100"><a href="#cb14-100" aria-hidden="true" tabindex="-1"></a>            nn.BatchNorm2d(channels)</span>
<span id="cb14-101"><a href="#cb14-101" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb14-102"><a href="#cb14-102" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-103"><a href="#cb14-103" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Cells</span></span>
<span id="cb14-104"><a href="#cb14-104" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.cells <span class="op">=</span> nn.ModuleList()</span>
<span id="cb14-105"><a href="#cb14-105" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(num_cells):</span>
<span id="cb14-106"><a href="#cb14-106" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> i <span class="kw">in</span> [num_cells <span class="op">//</span> <span class="dv">3</span>, <span class="dv">2</span> <span class="op">*</span> num_cells <span class="op">//</span> <span class="dv">3</span>]:</span>
<span id="cb14-107"><a href="#cb14-107" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Reduction cell</span></span>
<span id="cb14-108"><a href="#cb14-108" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.cells.append(DARTSCell(<span class="dv">4</span>, channels))</span>
<span id="cb14-109"><a href="#cb14-109" aria-hidden="true" tabindex="-1"></a>                channels <span class="op">*=</span> <span class="dv">2</span></span>
<span id="cb14-110"><a href="#cb14-110" aria-hidden="true" tabindex="-1"></a>            <span class="cf">else</span>:</span>
<span id="cb14-111"><a href="#cb14-111" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Normal cell</span></span>
<span id="cb14-112"><a href="#cb14-112" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.cells.append(DARTSCell(<span class="dv">4</span>, channels))</span>
<span id="cb14-113"><a href="#cb14-113" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-114"><a href="#cb14-114" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Classifier</span></span>
<span id="cb14-115"><a href="#cb14-115" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.classifier <span class="op">=</span> nn.Linear(channels, num_classes)</span>
<span id="cb14-116"><a href="#cb14-116" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.global_pool <span class="op">=</span> nn.AdaptiveAvgPool2d(<span class="dv">1</span>)</span>
<span id="cb14-117"><a href="#cb14-117" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb14-118"><a href="#cb14-118" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb14-119"><a href="#cb14-119" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.stem(x)</span>
<span id="cb14-120"><a href="#cb14-120" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-121"><a href="#cb14-121" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> cell <span class="kw">in</span> <span class="va">self</span>.cells:</span>
<span id="cb14-122"><a href="#cb14-122" aria-hidden="true" tabindex="-1"></a>            x <span class="op">=</span> cell([x, x])  <span class="co"># Use same input for both inputs</span></span>
<span id="cb14-123"><a href="#cb14-123" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-124"><a href="#cb14-124" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.global_pool(x)</span>
<span id="cb14-125"><a href="#cb14-125" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> x.view(x.size(<span class="dv">0</span>), <span class="op">-</span><span class="dv">1</span>)</span>
<span id="cb14-126"><a href="#cb14-126" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.classifier(x)</span>
<span id="cb14-127"><a href="#cb14-127" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-128"><a href="#cb14-128" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x</span></code></pre></div></div>
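<p>DARTS cells like the ones above are built on a continuous relaxation of the search space. Instead of committing to one operation per edge, each edge computes a softmax-weighted sum of all candidate operations, weighted by learnable architecture parameters α. Only after search is each edge discretized to its strongest operation. The NumPy sketch below illustrates that idea under stated assumptions: the three toy operations, the <code>mixed_op</code>/<code>derive_op</code> helpers and the hand-set α values are hypothetical stand-ins, not the internals of this post's <code>DARTSCell</code>. In DARTS itself the candidates are convolutions, pooling layers, a skip connection and a zero op, and α is learned by gradient descent on the validation loss.</p>

```python
import numpy as np

# Hypothetical toy candidate operations on a feature vector; in DARTS these
# would be convolutions, pooling layers, a skip connection and a "zero" op.
CANDIDATE_OPS = {
    "identity": lambda x: x,
    "zero": lambda x: np.zeros_like(x),
    "double": lambda x: 2.0 * x,  # stand-in for a parametric op such as a conv
}


def softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over a 1-D vector of logits."""
    exp = np.exp(logits - logits.max())
    return exp / exp.sum()


def mixed_op(x: np.ndarray, alpha: np.ndarray) -> np.ndarray:
    """Continuous relaxation of one edge: softmax(alpha)-weighted sum of all ops."""
    weights = softmax(alpha)
    return sum(w * op(x) for w, op in zip(weights, CANDIDATE_OPS.values()))


def derive_op(alpha: np.ndarray) -> str:
    """Discretize one edge after search: keep the strongest op, ignoring "zero"."""
    names = list(CANDIDATE_OPS)
    candidates = [i for i, name in enumerate(names) if name != "zero"]
    return names[max(candidates, key=lambda i: alpha[i])]


x = np.array([1.0, -2.0, 3.0])
alpha = np.array([0.0, -1.0, 2.0])  # hand-set architecture parameters for one edge
print(mixed_op(x, alpha))  # a blend of all candidate ops
print(derive_op(alpha))    # double
```

<p>Because the weights come from a softmax, every candidate participates in the forward pass during search, which is what makes α differentiable. <code>derive_op</code> mimics the discretization step by dropping the zero op and keeping the largest α; the DARTS paper additionally retains only the two strongest incoming edges per intermediate node.</p>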
</section>
<section id="evolutionary-search-example" class="level3">
<h3 class="anchored" data-anchor-id="evolutionary-search-example" id="evolutionary-search-example">Evolutionary Search Example</h3>
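<p><code>EvolutionaryNAS</code> never defines its search space or evaluator. Judging from how the class below uses them, the search space must provide <code>sample_architecture()</code>, returning a dict whose <code>'layers'</code> entry is a list of dicts with <code>'operation'</code> and <code>'filters'</code> keys, plus an <code>operations</code> list; the evaluator must provide <code>evaluate(individual)</code>, returning a fitness to maximize. A minimal toy pair satisfying that inferred interface might look like the following. The class names, operation names and proxy fitness formula are all hypothetical; a real evaluator would train and validate each candidate network.</p>

```python
import random
from typing import Dict, List, Optional, Union

Layer = Dict[str, Union[str, int]]


class ToySearchSpace:
    """Hypothetical search space exposing the interface EvolutionaryNAS relies on."""

    # Illustrative operation names; _mutate picks replacements from this list
    operations = ["conv3x3", "conv5x5", "sep_conv3x3", "max_pool", "identity"]
    # Same filter choices that _mutate samples from
    filter_choices = [32, 64, 128, 256, 512]

    def __init__(self, min_layers: int = 2, max_layers: int = 6,
                 seed: Optional[int] = None):
        self.min_layers = min_layers
        self.max_layers = max_layers
        self.rng = random.Random(seed)  # private RNG for reproducible sampling

    def sample_architecture(self) -> Dict[str, List[Layer]]:
        """Draw a random depth, then a random operation and width per layer."""
        num_layers = self.rng.randint(self.min_layers, self.max_layers)
        return {
            "layers": [
                {
                    "operation": self.rng.choice(self.operations),
                    "filters": self.rng.choice(self.filter_choices),
                }
                for _ in range(num_layers)
            ]
        }


class ToyEvaluator:
    """Hypothetical proxy fitness; a real evaluator would train and validate."""

    def evaluate(self, individual: Dict[str, List[Layer]]) -> float:
        layers = individual["layers"]
        # Reward convolutional layers...
        num_conv = sum(
            str(layer["operation"]).startswith(("conv", "sep_conv")) for layer in layers
        )
        # ...and penalize total width as a crude stand-in for compute cost
        width_cost = sum(int(layer["filters"]) for layer in layers) / 512.0
        return num_conv - 0.1 * width_cost
```

<p>With these stand-ins, a call such as <code>EvolutionaryNAS(population_size=20, generations=10).run_search(ToySearchSpace(seed=0), ToyEvaluator())</code> should return the fittest architecture of the final generation together with its score. Keep the population size at least as large as the tournament size (3 by default), because <code>_tournament_selection</code> samples indices without replacement.</p>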
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb15"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb15-1"><a href="#cb15-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> EvolutionaryNAS:</span>
<span id="cb15-2"><a href="#cb15-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, population_size: <span class="bu">int</span> <span class="op">=</span> <span class="dv">50</span>, generations: <span class="bu">int</span> <span class="op">=</span> <span class="dv">100</span>):</span>
<span id="cb15-3"><a href="#cb15-3" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.population_size <span class="op">=</span> population_size</span>
<span id="cb15-4"><a href="#cb15-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.generations <span class="op">=</span> generations</span>
<span id="cb15-5"><a href="#cb15-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.population <span class="op">=</span> []</span>
<span id="cb15-6"><a href="#cb15-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.fitness_history <span class="op">=</span> []</span>
<span id="cb15-7"><a href="#cb15-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-8"><a href="#cb15-8" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> run_search(<span class="va">self</span>, search_space, evaluator):</span>
<span id="cb15-9"><a href="#cb15-9" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Run evolutionary search"""</span></span>
<span id="cb15-10"><a href="#cb15-10" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Initialize population</span></span>
<span id="cb15-11"><a href="#cb15-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.population <span class="op">=</span> [</span>
<span id="cb15-12"><a href="#cb15-12" aria-hidden="true" tabindex="-1"></a>            search_space.sample_architecture() </span>
<span id="cb15-13"><a href="#cb15-13" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(<span class="va">self</span>.population_size)</span>
<span id="cb15-14"><a href="#cb15-14" aria-hidden="true" tabindex="-1"></a>        ]</span>
<span id="cb15-15"><a href="#cb15-15" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-16"><a href="#cb15-16" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> generation <span class="kw">in</span> <span class="bu">range</span>(<span class="va">self</span>.generations):</span>
<span id="cb15-17"><a href="#cb15-17" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Evaluate population</span></span>
<span id="cb15-18"><a href="#cb15-18" aria-hidden="true" tabindex="-1"></a>            fitness_scores <span class="op">=</span> []</span>
<span id="cb15-19"><a href="#cb15-19" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> individual <span class="kw">in</span> <span class="va">self</span>.population:</span>
<span id="cb15-20"><a href="#cb15-20" aria-hidden="true" tabindex="-1"></a>                fitness <span class="op">=</span> evaluator.evaluate(individual)</span>
<span id="cb15-21"><a href="#cb15-21" aria-hidden="true" tabindex="-1"></a>                fitness_scores.append(fitness)</span>
<span id="cb15-22"><a href="#cb15-22" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb15-23"><a href="#cb15-23" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.fitness_history.append(<span class="bu">max</span>(fitness_scores))</span>
<span id="cb15-24"><a href="#cb15-24" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb15-25"><a href="#cb15-25" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Selection and reproduction</span></span>
<span id="cb15-26"><a href="#cb15-26" aria-hidden="true" tabindex="-1"></a>            new_population <span class="op">=</span> []</span>
<span id="cb15-27"><a href="#cb15-27" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(<span class="va">self</span>.population_size):</span>
<span id="cb15-28"><a href="#cb15-28" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Tournament selection</span></span>
<span id="cb15-29"><a href="#cb15-29" aria-hidden="true" tabindex="-1"></a>                parent1 <span class="op">=</span> <span class="va">self</span>._tournament_selection(fitness_scores)</span>
<span id="cb15-30"><a href="#cb15-30" aria-hidden="true" tabindex="-1"></a>                parent2 <span class="op">=</span> <span class="va">self</span>._tournament_selection(fitness_scores)</span>
<span id="cb15-31"><a href="#cb15-31" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb15-32"><a href="#cb15-32" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Crossover</span></span>
<span id="cb15-33"><a href="#cb15-33" aria-hidden="true" tabindex="-1"></a>                child <span class="op">=</span> <span class="va">self</span>._crossover(parent1, parent2)</span>
<span id="cb15-34"><a href="#cb15-34" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb15-35"><a href="#cb15-35" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Mutation</span></span>
<span id="cb15-36"><a href="#cb15-36" aria-hidden="true" tabindex="-1"></a>                child <span class="op">=</span> <span class="va">self</span>._mutate(child, search_space)</span>
<span id="cb15-37"><a href="#cb15-37" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb15-38"><a href="#cb15-38" aria-hidden="true" tabindex="-1"></a>                new_population.append(child)</span>
<span id="cb15-39"><a href="#cb15-39" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb15-40"><a href="#cb15-40" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.population <span class="op">=</span> new_population</span>
<span id="cb15-41"><a href="#cb15-41" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-42"><a href="#cb15-42" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Return best architecture</span></span>
<span id="cb15-43"><a href="#cb15-43" aria-hidden="true" tabindex="-1"></a>        final_fitness <span class="op">=</span> [evaluator.evaluate(ind) <span class="cf">for</span> ind <span class="kw">in</span> <span class="va">self</span>.population]</span>
<span id="cb15-44"><a href="#cb15-44" aria-hidden="true" tabindex="-1"></a>        best_idx <span class="op">=</span> np.argmax(final_fitness)</span>
<span id="cb15-45"><a href="#cb15-45" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.population[best_idx], final_fitness[best_idx]</span>
<span id="cb15-46"><a href="#cb15-46" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-47"><a href="#cb15-47" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _tournament_selection(<span class="va">self</span>, fitness_scores, tournament_size: <span class="bu">int</span> <span class="op">=</span> <span class="dv">3</span>):</span>
<span id="cb15-48"><a href="#cb15-48" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Return the fittest of tournament_size randomly drawn individuals"""</span></span>
<span id="cb15-49"><a href="#cb15-49" aria-hidden="true" tabindex="-1"></a>        tournament_indices <span class="op">=</span> random.sample(<span class="bu">range</span>(<span class="bu">len</span>(fitness_scores)), tournament_size)</span>
<span id="cb15-50"><a href="#cb15-50" aria-hidden="true" tabindex="-1"></a>        tournament_fitness <span class="op">=</span> [fitness_scores[i] <span class="cf">for</span> i <span class="kw">in</span> tournament_indices]</span>
<span id="cb15-51"><a href="#cb15-51" aria-hidden="true" tabindex="-1"></a>        winner_idx <span class="op">=</span> tournament_indices[np.argmax(tournament_fitness)]</span>
<span id="cb15-52"><a href="#cb15-52" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.population[winner_idx]</span>
<span id="cb15-53"><a href="#cb15-53" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-54"><a href="#cb15-54" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _crossover(<span class="va">self</span>, parent1, parent2):</span>
<span id="cb15-55"><a href="#cb15-55" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Single-point crossover"""</span></span>
<span id="cb15-56"><a href="#cb15-56" aria-hidden="true" tabindex="-1"></a>        child <span class="op">=</span> parent1.copy()</span>
<span id="cb15-57"><a href="#cb15-57" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-58"><a href="#cb15-58" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="st">'layers'</span> <span class="kw">in</span> parent1 <span class="kw">and</span> <span class="st">'layers'</span> <span class="kw">in</span> parent2:</span>
<span id="cb15-59"><a href="#cb15-59" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Crossover layers</span></span>
<span id="cb15-60"><a href="#cb15-60" aria-hidden="true" tabindex="-1"></a>            min_length <span class="op">=</span> <span class="bu">min</span>(<span class="bu">len</span>(parent1[<span class="st">'layers'</span>]), <span class="bu">len</span>(parent2[<span class="st">'layers'</span>]))</span>
<span id="cb15-61"><a href="#cb15-61" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> min_length <span class="op">&gt;</span> <span class="dv">1</span>:</span>
<span id="cb15-62"><a href="#cb15-62" aria-hidden="true" tabindex="-1"></a>                crossover_point <span class="op">=</span> random.randint(<span class="dv">1</span>, min_length <span class="op">-</span> <span class="dv">1</span>)</span>
<span id="cb15-63"><a href="#cb15-63" aria-hidden="true" tabindex="-1"></a>                child[<span class="st">'layers'</span>] <span class="op">=</span> (parent1[<span class="st">'layers'</span>][:crossover_point] <span class="op">+</span> </span>
<span id="cb15-64"><a href="#cb15-64" aria-hidden="true" tabindex="-1"></a>                                 parent2[<span class="st">'layers'</span>][crossover_point:])</span>
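<span id="cb15-65"><a href="#cb15-65" aria-hidden="true" tabindex="-1"></a>        </span>
<span>        <span class="cf">if</span> <span class="st">'layers'</span> <span class="kw">in</span> child:</span>
<span>            <span class="co"># parent1.copy() is shallow, so copy each layer dict; otherwise _mutate</span></span>
<span>            <span class="co"># would edit layers shared with the parents (and their other offspring)</span></span>
<span>            child[<span class="st">'layers'</span>] <span class="op">=</span> [layer.copy() <span class="cf">for</span> layer <span class="kw">in</span> child[<span class="st">'layers'</span>]]</span>
<span>        </span>
<span id="cb15-66"><a href="#cb15-66" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> child</span>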
<span id="cb15-67"><a href="#cb15-67" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-68"><a href="#cb15-68" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _mutate(<span class="va">self</span>, individual, search_space, mutation_rate: <span class="bu">float</span> <span class="op">=</span> <span class="fl">0.1</span>):</span>
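<span id="cb15-69"><a href="#cb15-69" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""With probability mutation_rate, pick one layer and re-sample its operation and filters, each with probability 0.5 (in place)"""</span></span>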
<span id="cb15-70"><a href="#cb15-70" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> random.random() <span class="op">&lt;</span> mutation_rate:</span>
<span id="cb15-71"><a href="#cb15-71" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="st">'layers'</span> <span class="kw">in</span> individual <span class="kw">and</span> individual[<span class="st">'layers'</span>]:</span>
<span id="cb15-72"><a href="#cb15-72" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Randomly mutate a layer</span></span>
<span id="cb15-73"><a href="#cb15-73" aria-hidden="true" tabindex="-1"></a>                layer_idx <span class="op">=</span> random.randint(<span class="dv">0</span>, <span class="bu">len</span>(individual[<span class="st">'layers'</span>]) <span class="op">-</span> <span class="dv">1</span>)</span>
<span id="cb15-74"><a href="#cb15-74" aria-hidden="true" tabindex="-1"></a>                layer <span class="op">=</span> individual[<span class="st">'layers'</span>][layer_idx]</span>
<span id="cb15-75"><a href="#cb15-75" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb15-76"><a href="#cb15-76" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Mutate operation</span></span>
<span id="cb15-77"><a href="#cb15-77" aria-hidden="true" tabindex="-1"></a>                <span class="cf">if</span> random.random() <span class="op">&lt;</span> <span class="fl">0.5</span>:</span>
<span id="cb15-78"><a href="#cb15-78" aria-hidden="true" tabindex="-1"></a>                    layer[<span class="st">'operation'</span>] <span class="op">=</span> random.choice(search_space.operations)</span>
<span id="cb15-79"><a href="#cb15-79" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb15-80"><a href="#cb15-80" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Mutate filters</span></span>
<span id="cb15-81"><a href="#cb15-81" aria-hidden="true" tabindex="-1"></a>                <span class="cf">if</span> random.random() <span class="op">&lt;</span> <span class="fl">0.5</span>:</span>
<span id="cb15-82"><a href="#cb15-82" aria-hidden="true" tabindex="-1"></a>                    layer[<span class="st">'filters'</span>] <span class="op">=</span> random.choice([<span class="dv">32</span>, <span class="dv">64</span>, <span class="dv">128</span>, <span class="dv">256</span>, <span class="dv">512</span>])</span>
<span id="cb15-83"><a href="#cb15-83" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-84"><a href="#cb15-84" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> individual</span></code></pre></div></div>
</section>
<section id="reinforcement-learning-nas-example" class="level3">
<h3 class="anchored" data-anchor-id="reinforcement-learning-nas-example" id="reinforcement-learning-nas-example">Reinforcement Learning NAS Example</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb15-90"><a href="#cb15-90" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> RLNASController(nn.Module):</span>
<span id="cb15-91"><a href="#cb15-91" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, num_layers: <span class="bu">int</span> <span class="op">=</span> <span class="dv">6</span>, lstm_size: <span class="bu">int</span> <span class="op">=</span> <span class="dv">32</span>, </span>
<span id="cb15-92"><a href="#cb15-92" aria-hidden="true" tabindex="-1"></a>                 num_branches: <span class="bu">int</span> <span class="op">=</span> <span class="dv">6</span>, out_filters: <span class="bu">int</span> <span class="op">=</span> <span class="dv">48</span>):</span>
<span id="cb15-93"><a href="#cb15-93" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb15-94"><a href="#cb15-94" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.num_layers <span class="op">=</span> num_layers</span>
<span id="cb15-95"><a href="#cb15-95" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.lstm_size <span class="op">=</span> lstm_size</span>
<span id="cb15-96"><a href="#cb15-96" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.num_branches <span class="op">=</span> num_branches</span>
<span id="cb15-97"><a href="#cb15-97" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.out_filters <span class="op">=</span> out_filters</span>
<span id="cb15-98"><a href="#cb15-98" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-99"><a href="#cb15-99" aria-hidden="true" tabindex="-1"></a>        <span class="co"># LSTM controller</span></span>
<span id="cb15-100"><a href="#cb15-100" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.lstm <span class="op">=</span> nn.LSTMCell(lstm_size, lstm_size)</span>
<span id="cb15-101"><a href="#cb15-101" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-102"><a href="#cb15-102" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Embedding layers for different architecture decisions</span></span>
<span id="cb15-103"><a href="#cb15-103" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.g_emb <span class="op">=</span> nn.Embedding(<span class="dv">1</span>, lstm_size)  <span class="co"># Learned "go" (start) input for the first LSTM step</span></span>
<span id="cb15-104"><a href="#cb15-104" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.encoder <span class="op">=</span> nn.Linear(lstm_size, lstm_size)</span>
<span id="cb15-105"><a href="#cb15-105" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-106"><a href="#cb15-106" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Decision heads</span></span>
<span id="cb15-107"><a href="#cb15-107" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.conv_op <span class="op">=</span> nn.Linear(lstm_size, <span class="bu">len</span>(CONV_OPS))</span>
<span id="cb15-108"><a href="#cb15-108" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.conv_ksize <span class="op">=</span> nn.Linear(lstm_size, <span class="bu">len</span>(CONV_KERNEL_SIZES))</span>
<span id="cb15-109"><a href="#cb15-109" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.conv_filters <span class="op">=</span> nn.Linear(lstm_size, <span class="bu">len</span>(CONV_FILTERS))</span>
<span id="cb15-110"><a href="#cb15-110" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.pooling_op <span class="op">=</span> nn.Linear(lstm_size, <span class="bu">len</span>(POOLING_OPS))</span>
<span id="cb15-111"><a href="#cb15-111" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.pooling_ksize <span class="op">=</span> nn.Linear(lstm_size, <span class="bu">len</span>(POOLING_KERNEL_SIZES))</span>
<span id="cb15-112"><a href="#cb15-112" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-113"><a href="#cb15-113" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Initialize parameters</span></span>
<span id="cb15-114"><a href="#cb15-114" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.reset_parameters()</span>
<span id="cb15-115"><a href="#cb15-115" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-116"><a href="#cb15-116" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> reset_parameters(<span class="va">self</span>):</span>
<span id="cb15-117"><a href="#cb15-117" aria-hidden="true" tabindex="-1"></a>        init_range <span class="op">=</span> <span class="fl">0.1</span></span>
<span id="cb15-118"><a href="#cb15-118" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> param <span class="kw">in</span> <span class="va">self</span>.parameters():</span>
<span id="cb15-119"><a href="#cb15-119" aria-hidden="true" tabindex="-1"></a>            param.data.uniform_(<span class="op">-</span>init_range, init_range)</span>
<span id="cb15-120"><a href="#cb15-120" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-121"><a href="#cb15-121" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, batch_size: <span class="bu">int</span> <span class="op">=</span> <span class="dv">1</span>):</span>
<span id="cb15-122"><a href="#cb15-122" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Sample architecture using the controller"""</span></span>
<span id="cb15-123"><a href="#cb15-123" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Initialize hidden state</span></span>
<span id="cb15-124"><a href="#cb15-124" aria-hidden="true" tabindex="-1"></a>        h <span class="op">=</span> torch.zeros(batch_size, <span class="va">self</span>.lstm_size)</span>
<span id="cb15-125"><a href="#cb15-125" aria-hidden="true" tabindex="-1"></a>        c <span class="op">=</span> torch.zeros(batch_size, <span class="va">self</span>.lstm_size)</span>
<span id="cb15-126"><a href="#cb15-126" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-127"><a href="#cb15-127" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Start with go embedding</span></span>
<span id="cb15-128"><a href="#cb15-128" aria-hidden="true" tabindex="-1"></a>        inputs <span class="op">=</span> <span class="va">self</span>.g_emb.weight.repeat(batch_size, <span class="dv">1</span>)</span>
<span id="cb15-129"><a href="#cb15-129" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-130"><a href="#cb15-130" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Store sampled architecture</span></span>
<span id="cb15-131"><a href="#cb15-131" aria-hidden="true" tabindex="-1"></a>        arc_seq <span class="op">=</span> []</span>
<span id="cb15-132"><a href="#cb15-132" aria-hidden="true" tabindex="-1"></a>        entropies <span class="op">=</span> []</span>
<span id="cb15-133"><a href="#cb15-133" aria-hidden="true" tabindex="-1"></a>        log_probs <span class="op">=</span> []</span>
<span id="cb15-134"><a href="#cb15-134" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-135"><a href="#cb15-135" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> layer_id <span class="kw">in</span> <span class="bu">range</span>(<span class="va">self</span>.num_layers):</span>
<span id="cb15-136"><a href="#cb15-136" aria-hidden="true" tabindex="-1"></a>            <span class="co"># LSTM step</span></span>
<span id="cb15-137"><a href="#cb15-137" aria-hidden="true" tabindex="-1"></a>            h, c <span class="op">=</span> <span class="va">self</span>.lstm(inputs, (h, c))</span>
<span id="cb15-138"><a href="#cb15-138" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb15-139"><a href="#cb15-139" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Sample convolution operation</span></span>
<span id="cb15-140"><a href="#cb15-140" aria-hidden="true" tabindex="-1"></a>            conv_op_logits <span class="op">=</span> <span class="va">self</span>.conv_op(h)</span>
<span id="cb15-141"><a href="#cb15-141" aria-hidden="true" tabindex="-1"></a>            conv_op_prob <span class="op">=</span> F.softmax(conv_op_logits, dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb15-142"><a href="#cb15-142" aria-hidden="true" tabindex="-1"></a>            conv_op_log_prob <span class="op">=</span> F.log_softmax(conv_op_logits, dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb15-143"><a href="#cb15-143" aria-hidden="true" tabindex="-1"></a>            conv_op_entropy <span class="op">=</span> <span class="op">-</span>(conv_op_log_prob <span class="op">*</span> conv_op_prob).<span class="bu">sum</span>(<span class="dv">1</span>, keepdim<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb15-144"><a href="#cb15-144" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb15-145"><a href="#cb15-145" aria-hidden="true" tabindex="-1"></a>            conv_op_sample <span class="op">=</span> torch.multinomial(conv_op_prob, <span class="dv">1</span>)</span>
<span id="cb15-146"><a href="#cb15-146" aria-hidden="true" tabindex="-1"></a>            conv_op_sample <span class="op">=</span> conv_op_sample.view(<span class="op">-</span><span class="dv">1</span>)</span>
<span id="cb15-147"><a href="#cb15-147" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb15-148"><a href="#cb15-148" aria-hidden="true" tabindex="-1"></a>            arc_seq.append(conv_op_sample)</span>
<span id="cb15-149"><a href="#cb15-149" aria-hidden="true" tabindex="-1"></a>            entropies.append(conv_op_entropy)</span>
<span id="cb15-150"><a href="#cb15-150" aria-hidden="true" tabindex="-1"></a>            log_probs.append(conv_op_log_prob.gather(<span class="dv">1</span>, conv_op_sample.unsqueeze(<span class="dv">1</span>)))</span>
<span id="cb15-151"><a href="#cb15-151" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb15-152"><a href="#cb15-152" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Sample kernel size</span></span>
<span id="cb15-153"><a href="#cb15-153" aria-hidden="true" tabindex="-1"></a>            conv_ksize_logits <span class="op">=</span> <span class="va">self</span>.conv_ksize(h)</span>
<span id="cb15-154"><a href="#cb15-154" aria-hidden="true" tabindex="-1"></a>            conv_ksize_prob <span class="op">=</span> F.softmax(conv_ksize_logits, dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb15-155"><a href="#cb15-155" aria-hidden="true" tabindex="-1"></a>            conv_ksize_log_prob <span class="op">=</span> F.log_softmax(conv_ksize_logits, dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb15-156"><a href="#cb15-156" aria-hidden="true" tabindex="-1"></a>            conv_ksize_entropy <span class="op">=</span> <span class="op">-</span>(conv_ksize_log_prob <span class="op">*</span> conv_ksize_prob).<span class="bu">sum</span>(<span class="dv">1</span>, keepdim<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb15-157"><a href="#cb15-157" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb15-158"><a href="#cb15-158" aria-hidden="true" tabindex="-1"></a>            conv_ksize_sample <span class="op">=</span> torch.multinomial(conv_ksize_prob, <span class="dv">1</span>)</span>
<span id="cb15-159"><a href="#cb15-159" aria-hidden="true" tabindex="-1"></a>            conv_ksize_sample <span class="op">=</span> conv_ksize_sample.view(<span class="op">-</span><span class="dv">1</span>)</span>
<span id="cb15-160"><a href="#cb15-160" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb15-161"><a href="#cb15-161" aria-hidden="true" tabindex="-1"></a>            arc_seq.append(conv_ksize_sample)</span>
<span id="cb15-162"><a href="#cb15-162" aria-hidden="true" tabindex="-1"></a>            entropies.append(conv_ksize_entropy)</span>
<span id="cb15-163"><a href="#cb15-163" aria-hidden="true" tabindex="-1"></a>            log_probs.append(conv_ksize_log_prob.gather(<span class="dv">1</span>, conv_ksize_sample.unsqueeze(<span class="dv">1</span>)))</span>
<span id="cb15-164"><a href="#cb15-164" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb15-165"><a href="#cb15-165" aria-hidden="true" tabindex="-1"></a>            <span class="co"># The remaining heads (conv_filters, pooling_op, pooling_ksize) would be sampled the same way; omitted here</span></span>
<span id="cb15-166"><a href="#cb15-166" aria-hidden="true" tabindex="-1"></a>            inputs <span class="op">=</span> h  <span class="co"># Use current hidden state as input for next step</span></span>
<span id="cb15-167"><a href="#cb15-167" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-168"><a href="#cb15-168" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> arc_seq, torch.cat(log_probs), torch.cat(entropies)</span>
<span id="cb15-169"><a href="#cb15-169" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-170"><a href="#cb15-170" aria-hidden="true" tabindex="-1"></a><span class="co"># Constants for architecture choices</span></span>
<span id="cb15-171"><a href="#cb15-171" aria-hidden="true" tabindex="-1"></a>CONV_OPS <span class="op">=</span> [<span class="st">'conv'</span>, <span class="st">'depthwise_conv'</span>, <span class="st">'separable_conv'</span>]</span>
<span id="cb15-172"><a href="#cb15-172" aria-hidden="true" tabindex="-1"></a>CONV_KERNEL_SIZES <span class="op">=</span> [<span class="dv">3</span>, <span class="dv">5</span>, <span class="dv">7</span>]</span>
<span id="cb15-173"><a href="#cb15-173" aria-hidden="true" tabindex="-1"></a>CONV_FILTERS <span class="op">=</span> [<span class="dv">24</span>, <span class="dv">36</span>, <span class="dv">48</span>, <span class="dv">64</span>]</span>
<span id="cb15-174"><a href="#cb15-174" aria-hidden="true" tabindex="-1"></a>POOLING_OPS <span class="op">=</span> [<span class="st">'max_pool'</span>, <span class="st">'avg_pool'</span>, <span class="st">'no_pool'</span>]</span>
<span id="cb15-175"><a href="#cb15-175" aria-hidden="true" tabindex="-1"></a>POOLING_KERNEL_SIZES <span class="op">=</span> [<span class="dv">2</span>, <span class="dv">3</span>]</span>
<span id="cb15-176"><a href="#cb15-176" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-177"><a href="#cb15-177" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> RLNASTrainer:</span>
<span id="cb15-178"><a href="#cb15-178" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, controller, child_model_builder, evaluator):</span>
<span id="cb15-179"><a href="#cb15-179" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.controller <span class="op">=</span> controller</span>
<span id="cb15-180"><a href="#cb15-180" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.child_model_builder <span class="op">=</span> child_model_builder</span>
<span id="cb15-181"><a href="#cb15-181" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.evaluator <span class="op">=</span> evaluator</span>
<span id="cb15-182"><a href="#cb15-182" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-183"><a href="#cb15-183" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Controller optimizer</span></span>
<span id="cb15-184"><a href="#cb15-184" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.controller_optimizer <span class="op">=</span> torch.optim.Adam(</span>
<span id="cb15-185"><a href="#cb15-185" aria-hidden="true" tabindex="-1"></a>            controller.parameters(), lr<span class="op">=</span><span class="fl">3.5e-4</span></span>
<span id="cb15-186"><a href="#cb15-186" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb15-187"><a href="#cb15-187" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-188"><a href="#cb15-188" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Baseline for variance reduction</span></span>
<span id="cb15-189"><a href="#cb15-189" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.baseline <span class="op">=</span> <span class="va">None</span></span>
<span id="cb15-190"><a href="#cb15-190" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.baseline_decay <span class="op">=</span> <span class="fl">0.99</span></span>
<span id="cb15-191"><a href="#cb15-191" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-192"><a href="#cb15-192" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> train_controller(<span class="va">self</span>, num_epochs: <span class="bu">int</span> <span class="op">=</span> <span class="dv">2000</span>):</span>
<span id="cb15-193"><a href="#cb15-193" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Train the controller using REINFORCE"""</span></span>
<span id="cb15-194"><a href="#cb15-194" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(num_epochs):</span>
<span id="cb15-195"><a href="#cb15-195" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Sample architectures</span></span>
<span id="cb15-196"><a href="#cb15-196" aria-hidden="true" tabindex="-1"></a>            arc_seq, log_probs, entropies <span class="op">=</span> <span class="va">self</span>.controller()</span>
<span id="cb15-197"><a href="#cb15-197" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb15-198"><a href="#cb15-198" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Build and evaluate child model</span></span>
<span id="cb15-199"><a href="#cb15-199" aria-hidden="true" tabindex="-1"></a>            child_model <span class="op">=</span> <span class="va">self</span>.child_model_builder.build(arc_seq)</span>
<span id="cb15-200"><a href="#cb15-200" aria-hidden="true" tabindex="-1"></a>            reward <span class="op">=</span> <span class="va">self</span>.evaluator.evaluate(child_model)</span>
<span id="cb15-201"><a href="#cb15-201" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb15-202"><a href="#cb15-202" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Update baseline</span></span>
<span id="cb15-203"><a href="#cb15-203" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="va">self</span>.baseline <span class="kw">is</span> <span class="va">None</span>:</span>
<span id="cb15-204"><a href="#cb15-204" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.baseline <span class="op">=</span> reward</span>
<span id="cb15-205"><a href="#cb15-205" aria-hidden="true" tabindex="-1"></a>            <span class="cf">else</span>:</span>
<span id="cb15-206"><a href="#cb15-206" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.baseline <span class="op">=</span> <span class="va">self</span>.baseline_decay <span class="op">*</span> <span class="va">self</span>.baseline <span class="op">+</span> <span class="op">\</span></span>
<span id="cb15-207"><a href="#cb15-207" aria-hidden="true" tabindex="-1"></a>                              (<span class="dv">1</span> <span class="op">-</span> <span class="va">self</span>.baseline_decay) <span class="op">*</span> reward</span>
<span id="cb15-208"><a href="#cb15-208" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb15-209"><a href="#cb15-209" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Compute advantage</span></span>
<span id="cb15-210"><a href="#cb15-210" aria-hidden="true" tabindex="-1"></a>            advantage <span class="op">=</span> reward <span class="op">-</span> <span class="va">self</span>.baseline</span>
<span id="cb15-211"><a href="#cb15-211" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb15-212"><a href="#cb15-212" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Controller loss (REINFORCE)</span></span>
<span id="cb15-213"><a href="#cb15-213" aria-hidden="true" tabindex="-1"></a>            controller_loss <span class="op">=</span> <span class="op">-</span>log_probs <span class="op">*</span> advantage</span>
<span id="cb15-214"><a href="#cb15-214" aria-hidden="true" tabindex="-1"></a>            controller_loss <span class="op">=</span> controller_loss.<span class="bu">sum</span>()</span>
<span id="cb15-215"><a href="#cb15-215" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb15-216"><a href="#cb15-216" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Entropy bonus: subtracting policy entropy from the loss encourages exploration</span></span>
<span id="cb15-217"><a href="#cb15-217" aria-hidden="true" tabindex="-1"></a>            entropy_term <span class="op">=</span> <span class="op">-</span>entropies.<span class="bu">sum</span>() <span class="op">*</span> <span class="fl">1e-4</span></span>
<span id="cb15-218"><a href="#cb15-218" aria-hidden="true" tabindex="-1"></a>            total_loss <span class="op">=</span> controller_loss <span class="op">+</span> entropy_term</span>
<span id="cb15-219"><a href="#cb15-219" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb15-220"><a href="#cb15-220" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Update controller</span></span>
<span id="cb15-221"><a href="#cb15-221" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.controller_optimizer.zero_grad()</span>
<span id="cb15-222"><a href="#cb15-222" aria-hidden="true" tabindex="-1"></a>            total_loss.backward()</span>
<span id="cb15-223"><a href="#cb15-223" aria-hidden="true" tabindex="-1"></a>            torch.nn.utils.clip_grad_norm_(<span class="va">self</span>.controller.parameters(), <span class="fl">5.0</span>)</span>
<span id="cb15-224"><a href="#cb15-224" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.controller_optimizer.step()</span>
<span id="cb15-225"><a href="#cb15-225" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb15-226"><a href="#cb15-226" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> epoch <span class="op">%</span> <span class="dv">100</span> <span class="op">==</span> <span class="dv">0</span>:</span>
<span id="cb15-227"><a href="#cb15-227" aria-hidden="true" tabindex="-1"></a>                <span class="bu">print</span>(<span class="ss">f'Epoch </span><span class="sc">{</span>epoch<span class="sc">}</span><span class="ss">, Reward: </span><span class="sc">{</span>reward<span class="sc">:.4f}</span><span class="ss">, '</span></span>
<span id="cb15-228"><a href="#cb15-228" aria-hidden="true" tabindex="-1"></a>                      <span class="ss">f'Baseline: </span><span class="sc">{</span><span class="va">self</span><span class="sc">.</span>baseline<span class="sc">:.4f}</span><span class="ss">, Loss: </span><span class="sc">{</span>total_loss<span class="sc">.</span>item()<span class="sc">:.4f}</span><span class="ss">'</span>)</span></code></pre></div></div>
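<p>To make the REINFORCE update above concrete without a GPU or an RNN controller, here is a minimal, dependency-free sketch. It assumes a toy search space with one decision over the <code>POOLING_OPS</code> values and one over <code>POOLING_KERNEL_SIZES</code>, uses an independent softmax over a logit vector per decision in place of the controller network, and relies on a hypothetical <code>toy_reward</code> function standing in for building and evaluating a child model. It keeps the moving-average baseline (updated before the advantage, as in <code>RLNASTrainer</code>) but omits the entropy term for brevity.</p>

```python
import math
import random

# Toy search space mirroring POOLING_OPS and POOLING_KERNEL_SIZES above.
DECISIONS = {
    "pool_op": ["max_pool", "avg_pool", "no_pool"],
    "kernel": [2, 3],
}


def toy_reward(arch):
    """Hypothetical stand-in for building and evaluating a child model."""
    op_score = {"max_pool": 0.8, "avg_pool": 0.4, "no_pool": 0.1}[arch["pool_op"]]
    kernel_score = {2: 0.0, 3: 0.2}[arch["kernel"]]
    return op_score + kernel_score


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def train_controller(steps=2000, lr=0.1, baseline_decay=0.95, seed=0):
    rng = random.Random(seed)
    # One independent logit vector (categorical policy) per decision.
    logits = {name: [0.0] * len(choices) for name, choices in DECISIONS.items()}
    baseline = None
    for _ in range(steps):
        # 1. Sample one choice index per decision from the current policy.
        picks = {}
        for name, choices in DECISIONS.items():
            probs = softmax(logits[name])
            picks[name] = rng.choices(range(len(choices)), weights=probs)[0]
        arch = {name: DECISIONS[name][i] for name, i in picks.items()}
        reward = toy_reward(arch)

        # 2. Exponential moving-average baseline, updated before the advantage
        #    is computed (same order as RLNASTrainer).
        if baseline is None:
            baseline = reward
        else:
            baseline = baseline_decay * baseline + (1 - baseline_decay) * reward
        advantage = reward - baseline

        # 3. REINFORCE ascent step: d log p[a] / d logit[j] = 1[j == a] - p[j].
        for name, a in picks.items():
            probs = softmax(logits[name])
            for j, p in enumerate(probs):
                indicator = 1.0 if j == a else 0.0
                logits[name][j] += lr * advantage * (indicator - p)

    # The most likely choice per decision is the controller's current pick.
    best = {
        name: DECISIONS[name][max(range(len(z)), key=z.__getitem__)]
        for name, z in logits.items()
    }
    return best, logits


best_arch, final_logits = train_controller()
print(best_arch)
```

<p>Because the gradient of a log-softmax probability with respect to its logits is the one-hot indicator minus the probabilities, the update can be written by hand here; with autograd, the <code>-log_probs * advantage</code> loss in <code>RLNASTrainer</code> yields the same gradient before the optimizer rescales it.</p>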
</section>
</section>
<section id="best-practices" class="level2">
<h2 class="anchored" data-anchor-id="best-practices" id="best-practices">Best Practices</h2>
<section id="search-space-design-guidelines" class="level3">
<h3 class="anchored" data-anchor-id="search-space-design-guidelines" id="search-space-design-guidelines">1. Search Space Design Guidelines</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb16"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb16-1"><a href="#cb16-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> SearchSpaceDesignPrinciples:</span>
<span id="cb16-2"><a href="#cb16-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb16-3"><a href="#cb16-3" aria-hidden="true" tabindex="-1"></a><span class="co">    Guidelines for designing effective search spaces</span></span>
<span id="cb16-4"><a href="#cb16-4" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb16-5"><a href="#cb16-5" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-6"><a href="#cb16-6" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>):</span>
<span id="cb16-7"><a href="#cb16-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.principles <span class="op">=</span> {</span>
<span id="cb16-8"><a href="#cb16-8" aria-hidden="true" tabindex="-1"></a>            <span class="st">'expressiveness'</span>: <span class="st">'Include diverse operations and connections'</span>,</span>
<span id="cb16-9"><a href="#cb16-9" aria-hidden="true" tabindex="-1"></a>            <span class="st">'efficiency'</span>: <span class="st">'Balance search space size with computational cost'</span>,</span>
<span id="cb16-10"><a href="#cb16-10" aria-hidden="true" tabindex="-1"></a>            <span class="st">'human_knowledge'</span>: <span class="st">'Incorporate domain-specific insights'</span>,</span>
<span id="cb16-11"><a href="#cb16-11" aria-hidden="true" tabindex="-1"></a>            <span class="st">'scalability'</span>: <span class="st">'Design for different input sizes and tasks'</span></span>
<span id="cb16-12"><a href="#cb16-12" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb16-13"><a href="#cb16-13" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-14"><a href="#cb16-14" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> design_macro_space(<span class="va">self</span>, task_type: <span class="bu">str</span>):</span>
<span id="cb16-15"><a href="#cb16-15" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Design macro search space based on task"""</span></span>
<span id="cb16-16"><a href="#cb16-16" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> task_type <span class="op">==</span> <span class="st">'image_classification'</span>:</span>
<span id="cb16-17"><a href="#cb16-17" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> {</span>
<span id="cb16-18"><a href="#cb16-18" aria-hidden="true" tabindex="-1"></a>                <span class="st">'operations'</span>: [<span class="st">'conv3x3'</span>, <span class="st">'conv5x5'</span>, <span class="st">'depthwise_conv'</span>, <span class="st">'pointwise_conv'</span>,</span>
<span id="cb16-19"><a href="#cb16-19" aria-hidden="true" tabindex="-1"></a>                              <span class="st">'max_pool'</span>, <span class="st">'avg_pool'</span>, <span class="st">'global_pool'</span>, <span class="st">'identity'</span>],</span>
<span id="cb16-20"><a href="#cb16-20" aria-hidden="true" tabindex="-1"></a>                <span class="st">'max_layers'</span>: <span class="dv">20</span>,</span>
<span id="cb16-21"><a href="#cb16-21" aria-hidden="true" tabindex="-1"></a>                <span class="st">'channels'</span>: [<span class="dv">16</span>, <span class="dv">32</span>, <span class="dv">64</span>, <span class="dv">128</span>, <span class="dv">256</span>, <span class="dv">512</span>],</span>
<span id="cb16-22"><a href="#cb16-22" aria-hidden="true" tabindex="-1"></a>                <span class="st">'skip_connections'</span>: <span class="va">True</span>,</span>
<span id="cb16-23"><a href="#cb16-23" aria-hidden="true" tabindex="-1"></a>                <span class="st">'batch_norm'</span>: <span class="va">True</span>,</span>
<span id="cb16-24"><a href="#cb16-24" aria-hidden="true" tabindex="-1"></a>                <span class="st">'activation'</span>: [<span class="st">'relu'</span>, <span class="st">'relu6'</span>, <span class="st">'swish'</span>]</span>
<span id="cb16-25"><a href="#cb16-25" aria-hidden="true" tabindex="-1"></a>            }</span>
<span id="cb16-26"><a href="#cb16-26" aria-hidden="true" tabindex="-1"></a>        <span class="cf">elif</span> task_type <span class="op">==</span> <span class="st">'object_detection'</span>:</span>
<span id="cb16-27"><a href="#cb16-27" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> {</span>
<span id="cb16-28"><a href="#cb16-28" aria-hidden="true" tabindex="-1"></a>                <span class="st">'operations'</span>: [<span class="st">'conv3x3'</span>, <span class="st">'conv5x5'</span>, <span class="st">'depthwise_conv'</span>, <span class="st">'atrous_conv'</span>,</span>
<span id="cb16-29"><a href="#cb16-29" aria-hidden="true" tabindex="-1"></a>                              <span class="st">'max_pool'</span>, <span class="st">'avg_pool'</span>, <span class="st">'identity'</span>],</span>
<span id="cb16-30"><a href="#cb16-30" aria-hidden="true" tabindex="-1"></a>                <span class="st">'max_layers'</span>: <span class="dv">30</span>,</span>
<span id="cb16-31"><a href="#cb16-31" aria-hidden="true" tabindex="-1"></a>                <span class="st">'channels'</span>: [<span class="dv">32</span>, <span class="dv">64</span>, <span class="dv">128</span>, <span class="dv">256</span>, <span class="dv">512</span>, <span class="dv">1024</span>],</span>
<span id="cb16-32"><a href="#cb16-32" aria-hidden="true" tabindex="-1"></a>                <span class="st">'skip_connections'</span>: <span class="va">True</span>,</span>
<span id="cb16-33"><a href="#cb16-33" aria-hidden="true" tabindex="-1"></a>                <span class="st">'fpn_layers'</span>: <span class="va">True</span>,</span>
<span id="cb16-34"><a href="#cb16-34" aria-hidden="true" tabindex="-1"></a>                <span class="st">'anchor_scales'</span>: [<span class="dv">32</span>, <span class="dv">64</span>, <span class="dv">128</span>, <span class="dv">256</span>, <span class="dv">512</span>]</span>
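<span id="cb16-35"><a href="#cb16-35" aria-hidden="true" tabindex="-1"></a>            }</span>
<span id="cb16-36"><a href="#cb16-36" aria-hidden="true" tabindex="-1"></a>        <span class="cf">raise</span> <span class="pp">ValueError</span>(<span class="ss">f"Unsupported task_type: </span><span class="sc">{</span>task_type<span class="sc">!r}</span><span class="ss">"</span>)</span>
<span id="cb16-36b"><a href="#cb16-36b" aria-hidden="true" tabindex="-1"></a>    </span>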
<span id="cb16-37"><a href="#cb16-37" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> validate_search_space(<span class="va">self</span>, search_space):</span>
<span id="cb16-38"><a href="#cb16-38" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Validate search space design"""</span></span>
<span id="cb16-39"><a href="#cb16-39" aria-hidden="true" tabindex="-1"></a>        issues <span class="op">=</span> []</span>
<span id="cb16-40"><a href="#cb16-40" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb16-41"><a href="#cb16-41" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Check for minimal viable operations</span></span>
<span id="cb16-42"><a href="#cb16-42" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="bu">len</span>(search_space.get(<span class="st">'operations'</span>, [])) <span class="op">&lt;</span> <span class="dv">3</span>:</span>
<span id="cb16-43"><a href="#cb16-43" aria-hidden="true" tabindex="-1"></a>            issues.append(<span class="st">"Too few operations - may limit expressiveness"</span>)</span>
<span id="cb16-44"><a href="#cb16-44" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb16-45"><a href="#cb16-45" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Check for identity operation</span></span>
<span id="cb16-46"><a href="#cb16-46" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="st">'identity'</span> <span class="kw">not</span> <span class="kw">in</span> search_space.get(<span class="st">'operations'</span>, []):</span>
<span id="cb16-47"><a href="#cb16-47" aria-hidden="true" tabindex="-1"></a>            issues.append(<span class="st">"Missing identity operation - may hurt skip connections"</span>)</span>
<span id="cb16-48"><a href="#cb16-48" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb16-49"><a href="#cb16-49" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Check channel progression</span></span>
<span id="cb16-50"><a href="#cb16-50" aria-hidden="true" tabindex="-1"></a>        channels <span class="op">=</span> search_space.get(<span class="st">'channels'</span>, [])</span>
<span id="cb16-51"><a href="#cb16-51" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> channels <span class="kw">and</span> <span class="kw">not</span> <span class="bu">all</span>(channels[i] <span class="op">&lt;=</span> channels[i<span class="op">+</span><span class="dv">1</span>] <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="bu">len</span>(channels)<span class="op">-</span><span class="dv">1</span>)):</span>
<span id="cb16-52"><a href="#cb16-52" aria-hidden="true" tabindex="-1"></a>            issues.append(<span class="st">"Non-monotonic channel progression"</span>)</span>
<span id="cb16-53"><a href="#cb16-53" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb16-54"><a href="#cb16-54" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> issues</span></code></pre></div></div>
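<p>The "efficiency" principle is easiest to appreciate by counting candidate architectures. The sketch below assumes a simple chain-structured encoding, which is not necessarily what <code>design_macro_space</code> implies: every layer independently picks an operation, a channel width, and an activation, <code>max_layers</code> is treated as a fixed depth, and, when skip connections are enabled, each layer may connect to any subset of the earlier layers, in the style of the ENAS macro space. Keys such as <code>fpn_layers</code> and <code>anchor_scales</code> are ignored. Both helper functions are hypothetical.</p>

```python
import math


def macro_space_size(num_ops, num_channels, num_activations, num_layers, skip_connections=True):
    """Number of distinct architectures under the assumed chain-structured encoding."""
    per_layer = num_ops * num_channels * num_activations
    size = per_layer ** num_layers
    if skip_connections:
        # Layer i (1-indexed) may link to any subset of the i - 1 layers before it,
        # giving 0 + 1 + ... + (L - 1) = L(L - 1)/2 binary choices overall.
        size *= 2 ** (num_layers * (num_layers - 1) // 2)
    return size


def space_size_from_dict(space):
    """Apply macro_space_size to a dict shaped like design_macro_space's output."""
    return macro_space_size(
        num_ops=len(space["operations"]),
        num_channels=len(space.get("channels", [None])),
        num_activations=len(space.get("activation", [None])),
        num_layers=space["max_layers"],
        skip_connections=space.get("skip_connections", False),
    )


image_space = {
    "operations": ["conv3x3", "conv5x5", "depthwise_conv", "pointwise_conv",
                   "max_pool", "avg_pool", "global_pool", "identity"],
    "max_layers": 20,
    "channels": [16, 32, 64, 128, 256, 512],
    "skip_connections": True,
    "activation": ["relu", "relu6", "swish"],
}
size = space_size_from_dict(image_space)
print(f"about 10^{math.log10(size):.0f} candidate architectures")
```

<p>For the image-classification space this gives 144<sup>20</sup> × 2<sup>190</sup>, on the order of 10<sup>100</sup> architectures, which is why search spaces of this shape are explored with weight sharing or learned controllers rather than by enumeration.</p>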
</section>
<section id="training-strategies" class="level3">
<h3 class="anchored" data-anchor-id="training-strategies" id="training-strategies">2. Training Strategies</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb17"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb17-1"><a href="#cb17-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> NASTrainingStrategies:</span>
<span id="cb17-2"><a href="#cb17-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Advanced training strategies for NAS"""</span></span>
<span id="cb17-3"><a href="#cb17-3" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb17-4"><a href="#cb17-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>):</span>
<span id="cb17-5"><a href="#cb17-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.strategies <span class="op">=</span> {}</span>
<span id="cb17-6"><a href="#cb17-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb17-7"><a href="#cb17-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> progressive_shrinking(<span class="va">self</span>, supernet, dataset, stages: <span class="bu">int</span> <span class="op">=</span> <span class="dv">4</span>):</span>
<span id="cb17-8"><a href="#cb17-8" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Progressive shrinking strategy"""</span></span>
<span id="cb17-9"><a href="#cb17-9" aria-hidden="true" tabindex="-1"></a>        max_channels <span class="op">=</span> supernet.max_channels</span>
<span id="cb17-10"><a href="#cb17-10" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb17-11"><a href="#cb17-11" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> stage <span class="kw">in</span> <span class="bu">range</span>(stages):</span>
<span id="cb17-12"><a href="#cb17-12" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Halve the channel cap each stage (stage 0 trains at full width)</span></span>
<span id="cb17-13"><a href="#cb17-13" aria-hidden="true" tabindex="-1"></a>            target_channels <span class="op">=</span> max_channels <span class="op">//</span> (<span class="dv">2</span> <span class="op">**</span> stage)</span>
<span id="cb17-14"><a href="#cb17-14" aria-hidden="true" tabindex="-1"></a>            supernet.set_channel_constraint(target_channels)</span>
<span id="cb17-15"><a href="#cb17-15" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb17-16"><a href="#cb17-16" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Train for this stage</span></span>
<span id="cb17-17"><a href="#cb17-17" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>._train_stage(supernet, dataset, epochs<span class="op">=</span><span class="dv">50</span>)</span>
<span id="cb17-18"><a href="#cb17-18" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb17-19"><a href="#cb17-19" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f"Stage </span><span class="sc">{</span>stage <span class="op">+</span> <span class="dv">1</span><span class="sc">}</span><span class="ss">: Max channels = </span><span class="sc">{</span>target_channels<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb17-20"><a href="#cb17-20" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb17-21"><a href="#cb17-21" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> sandwich_sampling(<span class="va">self</span>, supernet, dataset):</span>
<span id="cb17-22"><a href="#cb17-22" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Sandwich rule: each step trains the largest, smallest, and random subnets"""</span></span>
<span id="cb17-23"><a href="#cb17-23" aria-hidden="true" tabindex="-1"></a>        optimizer <span class="op">=</span> torch.optim.SGD(supernet.parameters(), lr<span class="op">=</span><span class="fl">0.01</span>)</span>
<span id="cb17-24"><a href="#cb17-24" aria-hidden="true" tabindex="-1"></a>        criterion <span class="op">=</span> nn.CrossEntropyLoss()</span>
<span id="cb17-25"><a href="#cb17-25" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb17-26"><a href="#cb17-26" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">100</span>):</span>
<span id="cb17-27"><a href="#cb17-27" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> batch <span class="kw">in</span> dataset:</span>
<span id="cb17-28"><a href="#cb17-28" aria-hidden="true" tabindex="-1"></a>                inputs, targets <span class="op">=</span> batch</span>
<span id="cb17-29"><a href="#cb17-29" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb17-30"><a href="#cb17-30" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Sample architectures: largest, smallest, and random</span></span>
<span id="cb17-31"><a href="#cb17-31" aria-hidden="true" tabindex="-1"></a>                architectures <span class="op">=</span> [</span>
<span id="cb17-32"><a href="#cb17-32" aria-hidden="true" tabindex="-1"></a>                    supernet.largest_architecture(),</span>
<span id="cb17-33"><a href="#cb17-33" aria-hidden="true" tabindex="-1"></a>                    supernet.smallest_architecture(),</span>
<span id="cb17-34"><a href="#cb17-34" aria-hidden="true" tabindex="-1"></a>                    supernet.random_architecture(),</span>
<span id="cb17-35"><a href="#cb17-35" aria-hidden="true" tabindex="-1"></a>                    supernet.random_architecture()</span>
<span id="cb17-36"><a href="#cb17-36" aria-hidden="true" tabindex="-1"></a>                ]</span>
<span id="cb17-37"><a href="#cb17-37" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb17-38"><a href="#cb17-38" aria-hidden="true" tabindex="-1"></a>                total_loss <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb17-39"><a href="#cb17-39" aria-hidden="true" tabindex="-1"></a>                optimizer.zero_grad()  <span class="co"># Zero once: gradients from all sampled subnets accumulate</span></span>
<span id="cb17-40"><a href="#cb17-40" aria-hidden="true" tabindex="-1"></a>                <span class="cf">for</span> arch <span class="kw">in</span> architectures:</span>
<span id="cb17-41"><a href="#cb17-41" aria-hidden="true" tabindex="-1"></a>                    supernet.set_active_subnet(arch)</span>
<span id="cb17-42"><a href="#cb17-42" aria-hidden="true" tabindex="-1"></a>                    outputs <span class="op">=</span> supernet(inputs)</span>
<span id="cb17-43"><a href="#cb17-43" aria-hidden="true" tabindex="-1"></a>                    loss <span class="op">=</span> criterion(outputs, targets)</span>
<span id="cb17-44"><a href="#cb17-44" aria-hidden="true" tabindex="-1"></a>                    loss.backward()</span>
<span id="cb17-45"><a href="#cb17-45" aria-hidden="true" tabindex="-1"></a>                    total_loss <span class="op">+=</span> loss.item()</span>
<span id="cb17-46"><a href="#cb17-46" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb17-47"><a href="#cb17-47" aria-hidden="true" tabindex="-1"></a>                optimizer.step()  <span class="co"># One update from the summed gradients</span></span>
<span id="cb17-48"><a href="#cb17-48" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb17-49"><a href="#cb17-49" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> epoch <span class="op">%</span> <span class="dv">10</span> <span class="op">==</span> <span class="dv">0</span>:</span>
<span id="cb17-50"><a href="#cb17-50" aria-hidden="true" tabindex="-1"></a>                <span class="bu">print</span>(<span class="ss">f"Epoch </span><span class="sc">{</span>epoch<span class="sc">}</span><span class="ss">, Last-batch loss: </span><span class="sc">{</span>total_loss<span class="op">/</span><span class="bu">len</span>(architectures)<span class="sc">:.4f}</span><span class="ss">"</span>)</span>
<span id="cb17-51"><a href="#cb17-51" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb17-52"><a href="#cb17-52" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> knowledge_distillation(<span class="va">self</span>, student_arch, teacher_model, dataset):</span>
<span id="cb17-53"><a href="#cb17-53" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Knowledge distillation for architecture evaluation"""</span></span>
<span id="cb17-54"><a href="#cb17-54" aria-hidden="true" tabindex="-1"></a>        student_model <span class="op">=</span> <span class="va">self</span>._build_model(student_arch)</span>
<span id="cb17-55"><a href="#cb17-55" aria-hidden="true" tabindex="-1"></a>        teacher_model.eval()  <span class="co"># Frozen teacher: inference-mode BatchNorm/dropout</span></span>
<span id="cb17-56"><a href="#cb17-56" aria-hidden="true" tabindex="-1"></a>        optimizer <span class="op">=</span> torch.optim.SGD(student_model.parameters(), lr<span class="op">=</span><span class="fl">0.01</span>)</span>
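<span id="cb17-57"><a href="#cb17-57" aria-hidden="true" tabindex="-1"></a>        kd_loss <span class="op">=</span> nn.KLDivLoss(reduction<span class="op">=</span><span class="st">'batchmean'</span>)  <span class="co"># 'batchmean' yields the true per-sample KL</span></span>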
<span id="cb17-58"><a href="#cb17-58" aria-hidden="true" tabindex="-1"></a>        ce_loss <span class="op">=</span> nn.CrossEntropyLoss()</span>
<span id="cb17-59"><a href="#cb17-59" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb17-60"><a href="#cb17-60" aria-hidden="true" tabindex="-1"></a>        alpha <span class="op">=</span> <span class="fl">0.7</span>  <span class="co"># Distillation weight</span></span>
<span id="cb17-61"><a href="#cb17-61" aria-hidden="true" tabindex="-1"></a>        temperature <span class="op">=</span> <span class="fl">4.0</span></span>
<span id="cb17-62"><a href="#cb17-62" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb17-63"><a href="#cb17-63" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">50</span>):</span>
<span id="cb17-64"><a href="#cb17-64" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> batch <span class="kw">in</span> dataset:</span>
<span id="cb17-65"><a href="#cb17-65" aria-hidden="true" tabindex="-1"></a>                inputs, targets <span class="op">=</span> batch</span>
<span id="cb17-66"><a href="#cb17-66" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb17-67"><a href="#cb17-67" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Teacher predictions</span></span>
<span id="cb17-68"><a href="#cb17-68" aria-hidden="true" tabindex="-1"></a>                <span class="cf">with</span> torch.no_grad():</span>
<span id="cb17-69"><a href="#cb17-69" aria-hidden="true" tabindex="-1"></a>                    teacher_outputs <span class="op">=</span> teacher_model(inputs)</span>
<span id="cb17-70"><a href="#cb17-70" aria-hidden="true" tabindex="-1"></a>                    teacher_probs <span class="op">=</span> F.softmax(teacher_outputs <span class="op">/</span> temperature, dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb17-71"><a href="#cb17-71" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb17-72"><a href="#cb17-72" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Student predictions</span></span>
<span id="cb17-73"><a href="#cb17-73" aria-hidden="true" tabindex="-1"></a>                student_outputs <span class="op">=</span> student_model(inputs)</span>
<span id="cb17-74"><a href="#cb17-74" aria-hidden="true" tabindex="-1"></a>                student_log_probs <span class="op">=</span> F.log_softmax(student_outputs <span class="op">/</span> temperature, dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb17-75"><a href="#cb17-75" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb17-76"><a href="#cb17-76" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Combined loss</span></span>
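<span id="cb17-77"><a href="#cb17-77" aria-hidden="true" tabindex="-1"></a>                distill_loss <span class="op">=</span> kd_loss(student_log_probs, teacher_probs) <span class="op">*</span> temperature <span class="op">**</span> <span class="dv">2</span>  <span class="co"># T^2 rescales soft-target gradients</span></span>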
<span id="cb17-78"><a href="#cb17-78" aria-hidden="true" tabindex="-1"></a>                hard_loss <span class="op">=</span> ce_loss(student_outputs, targets)</span>
<span id="cb17-79"><a href="#cb17-79" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb17-80"><a href="#cb17-80" aria-hidden="true" tabindex="-1"></a>                total_loss <span class="op">=</span> alpha <span class="op">*</span> distill_loss <span class="op">+</span> (<span class="dv">1</span> <span class="op">-</span> alpha) <span class="op">*</span> hard_loss</span>
<span id="cb17-81"><a href="#cb17-81" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb17-82"><a href="#cb17-82" aria-hidden="true" tabindex="-1"></a>                optimizer.zero_grad()</span>
<span id="cb17-83"><a href="#cb17-83" aria-hidden="true" tabindex="-1"></a>                total_loss.backward()</span>
<span id="cb17-84"><a href="#cb17-84" aria-hidden="true" tabindex="-1"></a>                optimizer.step()</span>
<span id="cb17-85"><a href="#cb17-85" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb17-86"><a href="#cb17-86" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>._evaluate_model(student_model)</span></code></pre></div></div>
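<p>The loss in <code>knowledge_distillation</code> mixes a temperature-softened teacher/student KL term with ordinary cross-entropy. The dependency-free sketch below spells that objective out on plain Python lists of logits so each term can be checked by hand. It follows the common convention from Hinton et al.'s distillation work of multiplying the soft-target term by <em>T</em>², which keeps its gradient magnitude comparable to the hard-label term as the temperature changes, and it averages per-example losses the way <code>nn.KLDivLoss(reduction='batchmean')</code> plus mean cross-entropy would. The function names are illustrative, not part of any library.</p>

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax of logits / temperature."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kd_loss_single(student_logits, teacher_logits, target, alpha=0.7, temperature=4.0):
    """alpha * T^2 * KL(teacher_T || student_T) + (1 - alpha) * CE(student, target)."""
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    # KL divergence from the softened teacher to the softened student.
    kl = sum(pt * (math.log(pt) - math.log(ps)) for pt, ps in zip(p_teacher, p_student))
    # Ordinary cross-entropy against the hard label, at temperature 1.
    hard = -math.log(softmax(student_logits)[target])
    return alpha * temperature ** 2 * kl + (1 - alpha) * hard


def kd_loss_batch(student_batch, teacher_batch, targets, **kwargs):
    """Mean of per-example losses (batchmean KL plus mean cross-entropy)."""
    losses = [kd_loss_single(s, t, y, **kwargs)
              for s, t, y in zip(student_batch, teacher_batch, targets)]
    return sum(losses) / len(losses)


student = [[2.0, 0.5, -1.0], [0.1, 0.2, 0.3]]
teacher = [[3.0, 0.0, -2.0], [0.0, 1.0, 0.0]]
print(kd_loss_batch(student, teacher, targets=[0, 1]))
```

<p>When the student's logits match the teacher's, the KL term vanishes and only the weighted cross-entropy remains, which is a quick sanity check for any framework implementation of the same objective.</p>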
</section>
<section id="evaluation-and-benchmarking" class="level3">
<h3 class="anchored" data-anchor-id="evaluation-and-benchmarking" id="evaluation-and-benchmarking">3. Evaluation and Benchmarking</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb18"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb18-1"><a href="#cb18-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> NASBenchmarking:</span>
<span id="cb18-2"><a href="#cb18-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Benchmarking and evaluation utilities"""</span></span>
<span id="cb18-3"><a href="#cb18-3" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-4"><a href="#cb18-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>):</span>
<span id="cb18-5"><a href="#cb18-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.metrics <span class="op">=</span> {}</span>
<span id="cb18-6"><a href="#cb18-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-7"><a href="#cb18-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> comprehensive_evaluation(<span class="va">self</span>, architecture, datasets):</span>
<span id="cb18-8"><a href="#cb18-8" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Evaluate accuracy, efficiency and robustness (the _build_model, _evaluate_* and _measure_memory helpers are task-specific and not shown)"""</span></span>
<span id="cb18-9"><a href="#cb18-9" aria-hidden="true" tabindex="-1"></a>        results <span class="op">=</span> {}</span>
<span id="cb18-10"><a href="#cb18-10" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-11"><a href="#cb18-11" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Build model</span></span>
<span id="cb18-12"><a href="#cb18-12" aria-hidden="true" tabindex="-1"></a>        model <span class="op">=</span> <span class="va">self</span>._build_model(architecture)</span>
<span id="cb18-13"><a href="#cb18-13" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-14"><a href="#cb18-14" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Accuracy metrics</span></span>
<span id="cb18-15"><a href="#cb18-15" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> dataset_name, dataset <span class="kw">in</span> datasets.items():</span>
<span id="cb18-16"><a href="#cb18-16" aria-hidden="true" tabindex="-1"></a>            accuracy <span class="op">=</span> <span class="va">self</span>._evaluate_accuracy(model, dataset)</span>
<span id="cb18-17"><a href="#cb18-17" aria-hidden="true" tabindex="-1"></a>            results[<span class="ss">f'</span><span class="sc">{</span>dataset_name<span class="sc">}</span><span class="ss">_accuracy'</span>] <span class="op">=</span> accuracy</span>
<span id="cb18-18"><a href="#cb18-18" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-19"><a href="#cb18-19" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Efficiency metrics</span></span>
<span id="cb18-20"><a href="#cb18-20" aria-hidden="true" tabindex="-1"></a>        results[<span class="st">'params'</span>] <span class="op">=</span> <span class="va">self</span>._count_parameters(model)</span>
<span id="cb18-21"><a href="#cb18-21" aria-hidden="true" tabindex="-1"></a>        results[<span class="st">'flops'</span>] <span class="op">=</span> <span class="va">self</span>._count_flops(model)</span>
<span id="cb18-22"><a href="#cb18-22" aria-hidden="true" tabindex="-1"></a>        results[<span class="st">'latency'</span>] <span class="op">=</span> <span class="va">self</span>._measure_latency(model)</span>
<span id="cb18-23"><a href="#cb18-23" aria-hidden="true" tabindex="-1"></a>        results[<span class="st">'memory'</span>] <span class="op">=</span> <span class="va">self</span>._measure_memory(model)</span>
<span id="cb18-24"><a href="#cb18-24" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-25"><a href="#cb18-25" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Robustness metrics</span></span>
<span id="cb18-26"><a href="#cb18-26" aria-hidden="true" tabindex="-1"></a>        results[<span class="st">'adversarial_robustness'</span>] <span class="op">=</span> <span class="va">self</span>._evaluate_adversarial_robustness(model)</span>
<span id="cb18-27"><a href="#cb18-27" aria-hidden="true" tabindex="-1"></a>        results[<span class="st">'noise_robustness'</span>] <span class="op">=</span> <span class="va">self</span>._evaluate_noise_robustness(model)</span>
<span id="cb18-28"><a href="#cb18-28" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-29"><a href="#cb18-29" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> results</span>
<span id="cb18-30"><a href="#cb18-30" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-31"><a href="#cb18-31" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _count_parameters(<span class="va">self</span>, model):</span>
<span id="cb18-32"><a href="#cb18-32" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Count trainable model parameters"""</span></span>
<span id="cb18-33"><a href="#cb18-33" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="bu">sum</span>(p.numel() <span class="cf">for</span> p <span class="kw">in</span> model.parameters() <span class="cf">if</span> p.requires_grad)</span>
<span id="cb18-34"><a href="#cb18-34" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-35"><a href="#cb18-35" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _count_flops(<span class="va">self</span>, model, input_size<span class="op">=</span>(<span class="dv">1</span>, <span class="dv">3</span>, <span class="dv">224</span>, <span class="dv">224</span>)):</span>
<span id="cb18-36"><a href="#cb18-36" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Estimate multiply-accumulates (MACs) of Conv2d and Linear layers, the quantity often reported as 'FLOPs'"""</span></span>
<span id="cb18-37"><a href="#cb18-37" aria-hidden="true" tabindex="-1"></a>        <span class="kw">def</span> flop_count_hook(module, inputs, output):</span>
<span id="cb18-38"><a href="#cb18-38" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="bu">isinstance</span>(module, nn.Conv2d):</span>
<span id="cb18-39"><a href="#cb18-39" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Conv2d: each output element costs (in_channels / groups) * kH * kW MACs;</span></span>
<span id="cb18-40"><a href="#cb18-40" aria-hidden="true" tabindex="-1"></a>                <span class="co"># dividing by groups keeps grouped and depthwise convolutions correct</span></span>
<span id="cb18-41"><a href="#cb18-41" aria-hidden="true" tabindex="-1"></a>                kernel_height, kernel_width <span class="op">=</span> module.kernel_size</span>
<span id="cb18-42"><a href="#cb18-42" aria-hidden="true" tabindex="-1"></a>                macs <span class="op">=</span> (output.numel() <span class="op">*</span> (module.in_channels <span class="op">//</span> module.groups)</span>
<span id="cb18-43"><a href="#cb18-43" aria-hidden="true" tabindex="-1"></a>                        <span class="op">*</span> kernel_height <span class="op">*</span> kernel_width)</span>
<span id="cb18-44"><a href="#cb18-44" aria-hidden="true" tabindex="-1"></a>            <span class="cf">elif</span> <span class="bu">isinstance</span>(module, nn.Linear):</span>
<span id="cb18-45"><a href="#cb18-45" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Linear: in_features MACs per output element (bias ignored)</span></span>
<span id="cb18-46"><a href="#cb18-46" aria-hidden="true" tabindex="-1"></a>                macs <span class="op">=</span> output.numel() <span class="op">*</span> module.in_features</span>
<span id="cb18-47"><a href="#cb18-47" aria-hidden="true" tabindex="-1"></a>            <span class="cf">else</span>:</span>
<span id="cb18-48"><a href="#cb18-48" aria-hidden="true" tabindex="-1"></a>                <span class="cf">return</span></span>
<span id="cb18-49"><a href="#cb18-49" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Accumulate in a dict local to this call instead of storing state on the module</span></span>
<span id="cb18-50"><a href="#cb18-50" aria-hidden="true" tabindex="-1"></a>            counts[module] <span class="op">=</span> counts.get(module, <span class="dv">0</span>) <span class="op">+</span> macs</span>
<span id="cb18-51"><a href="#cb18-51" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-52"><a href="#cb18-52" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Register hooks; counts starts empty, so repeated calls do not accumulate</span></span>
<span id="cb18-53"><a href="#cb18-53" aria-hidden="true" tabindex="-1"></a>        counts, hooks <span class="op">=</span> {}, []</span>
<span id="cb18-54"><a href="#cb18-54" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> module <span class="kw">in</span> model.modules():</span>
<span id="cb18-55"><a href="#cb18-55" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="bu">isinstance</span>(module, (nn.Conv2d, nn.Linear)):</span>
<span id="cb18-56"><a href="#cb18-56" aria-hidden="true" tabindex="-1"></a>                hooks.append(module.register_forward_hook(flop_count_hook))</span>
<span id="cb18-57"><a href="#cb18-57" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-58"><a href="#cb18-58" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Forward pass with a dummy input on the same device as the model</span></span>
<span id="cb18-59"><a href="#cb18-59" aria-hidden="true" tabindex="-1"></a>        model.<span class="bu">eval</span>()</span>
<span id="cb18-60"><a href="#cb18-60" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> torch.no_grad():</span>
<span id="cb18-61"><a href="#cb18-61" aria-hidden="true" tabindex="-1"></a>            dummy_input <span class="op">=</span> torch.randn(input_size, device<span class="op">=</span><span class="bu">next</span>(model.parameters()).device)</span>
<span id="cb18-62"><a href="#cb18-62" aria-hidden="true" tabindex="-1"></a>            model(dummy_input)</span>
<span id="cb18-63"><a href="#cb18-63" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-64"><a href="#cb18-64" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Collect the MACs recorded by the hooks</span></span>
<span id="cb18-65"><a href="#cb18-65" aria-hidden="true" tabindex="-1"></a>        total_macs <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb18-66"><a href="#cb18-66" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> module <span class="kw">in</span> model.modules():</span>
<span id="cb18-67"><a href="#cb18-67" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> module <span class="kw">in</span> counts:</span>
<span id="cb18-68"><a href="#cb18-68" aria-hidden="true" tabindex="-1"></a>                total_macs <span class="op">+=</span> counts[module]</span>
<span id="cb18-69"><a href="#cb18-69" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-70"><a href="#cb18-70" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Remove hooks</span></span>
<span id="cb18-71"><a href="#cb18-71" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> hook <span class="kw">in</span> hooks:</span>
<span id="cb18-72"><a href="#cb18-72" aria-hidden="true" tabindex="-1"></a>            hook.remove()</span>
<span id="cb18-73"><a href="#cb18-73" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-74"><a href="#cb18-74" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> total_macs</span>
<span id="cb18-75"><a href="#cb18-75" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-76"><a href="#cb18-76" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _measure_latency(<span class="va">self</span>, model, input_size<span class="op">=</span>(<span class="dv">1</span>, <span class="dv">3</span>, <span class="dv">224</span>, <span class="dv">224</span>), runs<span class="op">=</span><span class="dv">100</span>):</span>
<span id="cb18-77"><a href="#cb18-77" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Measure mean inference latency in seconds per forward pass"""</span></span>
<span id="cb18-78"><a href="#cb18-78" aria-hidden="true" tabindex="-1"></a>        model.<span class="bu">eval</span>()</span>
<span id="cb18-79"><a href="#cb18-79" aria-hidden="true" tabindex="-1"></a>        dummy_input <span class="op">=</span> torch.randn(input_size, device<span class="op">=</span><span class="bu">next</span>(model.parameters()).device)</span>
<span id="cb18-80"><a href="#cb18-80" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-81"><a href="#cb18-81" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Warmup</span></span>
<span id="cb18-82"><a href="#cb18-82" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">10</span>):</span>
<span id="cb18-83"><a href="#cb18-83" aria-hidden="true" tabindex="-1"></a>            <span class="cf">with</span> torch.no_grad():</span>
<span id="cb18-84"><a href="#cb18-84" aria-hidden="true" tabindex="-1"></a>                model(dummy_input)</span>
<span id="cb18-85"><a href="#cb18-85" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-86"><a href="#cb18-86" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Measure; synchronize so queued CUDA kernels fall inside the timed window</span></span>
<span id="cb18-87"><a href="#cb18-87" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> torch.cuda.is_available():</span>
<span id="cb18-88"><a href="#cb18-88" aria-hidden="true" tabindex="-1"></a>            torch.cuda.synchronize()</span>
<span id="cb18-89"><a href="#cb18-89" aria-hidden="true" tabindex="-1"></a>        start_time <span class="op">=</span> time.perf_counter()  <span class="co"># monotonic, high-resolution clock</span></span>
<span id="cb18-90"><a href="#cb18-90" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(runs):</span>
<span id="cb18-91"><a href="#cb18-91" aria-hidden="true" tabindex="-1"></a>            <span class="cf">with</span> torch.no_grad():</span>
<span id="cb18-92"><a href="#cb18-92" aria-hidden="true" tabindex="-1"></a>                model(dummy_input)</span>
<span id="cb18-93"><a href="#cb18-93" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-94"><a href="#cb18-94" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> torch.cuda.is_available():</span>
<span id="cb18-95"><a href="#cb18-95" aria-hidden="true" tabindex="-1"></a>            torch.cuda.synchronize()</span>
<span id="cb18-96"><a href="#cb18-96" aria-hidden="true" tabindex="-1"></a>        end_time <span class="op">=</span> time.perf_counter()</span>
<span id="cb18-97"><a href="#cb18-97" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> (end_time <span class="op">-</span> start_time) <span class="op">/</span> runs</span>
<span id="cb18-98"><a href="#cb18-98" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-99"><a href="#cb18-99" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> compare_search_methods(<span class="va">self</span>, methods, search_space, evaluator, runs<span class="op">=</span><span class="dv">5</span>):</span>
<span id="cb18-100"><a href="#cb18-100" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Run every search method once per seed and aggregate the results"""</span></span>
<span id="cb18-101"><a href="#cb18-101" aria-hidden="true" tabindex="-1"></a>        results <span class="op">=</span> {}</span>
<span id="cb18-102"><a href="#cb18-102" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-103"><a href="#cb18-103" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> method_name, method <span class="kw">in</span> methods.items():</span>
<span id="cb18-104"><a href="#cb18-104" aria-hidden="true" tabindex="-1"></a>            method_results <span class="op">=</span> []</span>
<span id="cb18-105"><a href="#cb18-105" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb18-106"><a href="#cb18-106" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> run <span class="kw">in</span> <span class="bu">range</span>(runs):</span>
<span id="cb18-107"><a href="#cb18-107" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Set random seed for reproducibility</span></span>
<span id="cb18-108"><a href="#cb18-108" aria-hidden="true" tabindex="-1"></a>                torch.manual_seed(run)</span>
<span id="cb18-109"><a href="#cb18-109" aria-hidden="true" tabindex="-1"></a>                random.seed(run)</span>
<span id="cb18-110"><a href="#cb18-110" aria-hidden="true" tabindex="-1"></a>                np.random.seed(run)</span>
<span id="cb18-111"><a href="#cb18-111" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb18-112"><a href="#cb18-112" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Run search</span></span>
<span id="cb18-113"><a href="#cb18-113" aria-hidden="true" tabindex="-1"></a>                best_arch, best_performance <span class="op">=</span> method.search(search_space, evaluator)</span>
<span id="cb18-114"><a href="#cb18-114" aria-hidden="true" tabindex="-1"></a>                method_results.append({</span>
<span id="cb18-115"><a href="#cb18-115" aria-hidden="true" tabindex="-1"></a>                    <span class="st">'architecture'</span>: best_arch,</span>
<span id="cb18-116"><a href="#cb18-116" aria-hidden="true" tabindex="-1"></a>                    <span class="st">'performance'</span>: best_performance,</span>
<span id="cb18-117"><a href="#cb18-117" aria-hidden="true" tabindex="-1"></a>                    <span class="st">'run'</span>: run</span>
<span id="cb18-118"><a href="#cb18-118" aria-hidden="true" tabindex="-1"></a>                })</span>
<span id="cb18-119"><a href="#cb18-119" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb18-120"><a href="#cb18-120" aria-hidden="true" tabindex="-1"></a>            results[method_name] <span class="op">=</span> method_results</span>
<span id="cb18-121"><a href="#cb18-121" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-122"><a href="#cb18-122" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>._analyze_comparison_results(results)</span>
<span id="cb18-123"><a href="#cb18-123" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-124"><a href="#cb18-124" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _analyze_comparison_results(<span class="va">self</span>, results):</span>
<span id="cb18-125"><a href="#cb18-125" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Analyze comparison results"""</span></span>
<span id="cb18-126"><a href="#cb18-126" aria-hidden="true" tabindex="-1"></a>        analysis <span class="op">=</span> {}</span>
<span id="cb18-127"><a href="#cb18-127" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-128"><a href="#cb18-128" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> method_name, method_results <span class="kw">in</span> results.items():</span>
<span id="cb18-129"><a href="#cb18-129" aria-hidden="true" tabindex="-1"></a>            performances <span class="op">=</span> [r[<span class="st">'performance'</span>] <span class="cf">for</span> r <span class="kw">in</span> method_results]</span>
<span id="cb18-130"><a href="#cb18-130" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb18-131"><a href="#cb18-131" aria-hidden="true" tabindex="-1"></a>            analysis[method_name] <span class="op">=</span> {</span>
<span id="cb18-132"><a href="#cb18-132" aria-hidden="true" tabindex="-1"></a>                <span class="st">'mean_performance'</span>: np.mean(performances),</span>
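<span id="cb18-133"><a href="#cb18-133" aria-hidden="true" tabindex="-1"></a>                <span class="st">'std_performance'</span>: np.std(performances, ddof<span class="op">=</span><span class="dv">1</span>),  <span class="co"># sample std across seeds; needs runs &gt;= 2</span></span>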
<span id="cb18-134"><a href="#cb18-134" aria-hidden="true" tabindex="-1"></a>                <span class="st">'best_performance'</span>: np.<span class="bu">max</span>(performances),</span>
<span id="cb18-135"><a href="#cb18-135" aria-hidden="true" tabindex="-1"></a>                <span class="st">'worst_performance'</span>: np.<span class="bu">min</span>(performances),</span>
<span id="cb18-136"><a href="#cb18-136" aria-hidden="true" tabindex="-1"></a>                <span class="st">'median_performance'</span>: np.median(performances)</span>
<span id="cb18-137"><a href="#cb18-137" aria-hidden="true" tabindex="-1"></a>            }</span>
<span id="cb18-138"><a href="#cb18-138" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-139"><a href="#cb18-139" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Rank methods</span></span>
<span id="cb18-140"><a href="#cb18-140" aria-hidden="true" tabindex="-1"></a>        ranked_methods <span class="op">=</span> <span class="bu">sorted</span>(analysis.items(), </span>
<span id="cb18-141"><a href="#cb18-141" aria-hidden="true" tabindex="-1"></a>                              key<span class="op">=</span><span class="kw">lambda</span> x: x[<span class="dv">1</span>][<span class="st">'mean_performance'</span>], </span>
<span id="cb18-142"><a href="#cb18-142" aria-hidden="true" tabindex="-1"></a>                              reverse<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb18-143"><a href="#cb18-143" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-144"><a href="#cb18-144" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {</span>
<span id="cb18-145"><a href="#cb18-145" aria-hidden="true" tabindex="-1"></a>            <span class="st">'detailed_results'</span>: analysis,</span>
<span id="cb18-146"><a href="#cb18-146" aria-hidden="true" tabindex="-1"></a>            <span class="st">'ranking'</span>: ranked_methods,</span>
<span id="cb18-147"><a href="#cb18-147" aria-hidden="true" tabindex="-1"></a>            <span class="st">'summary'</span>: <span class="va">self</span>._generate_summary(ranked_methods)</span>
<span id="cb18-148"><a href="#cb18-148" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb18-149"><a href="#cb18-149" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-150"><a href="#cb18-150" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _generate_summary(<span class="va">self</span>, ranked_methods):</span>
<span id="cb18-151"><a href="#cb18-151" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Generate summary of comparison"""</span></span>
<span id="cb18-152"><a href="#cb18-152" aria-hidden="true" tabindex="-1"></a>        summary <span class="op">=</span> []</span>
<span id="cb18-153"><a href="#cb18-153" aria-hidden="true" tabindex="-1"></a>        summary.append(<span class="st">"=== NAS Method Comparison Results ==="</span>)</span>
<span id="cb18-154"><a href="#cb18-154" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-155"><a href="#cb18-155" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> i, (method_name, stats) <span class="kw">in</span> <span class="bu">enumerate</span>(ranked_methods):</span>
<span id="cb18-156"><a href="#cb18-156" aria-hidden="true" tabindex="-1"></a>            summary.append(<span class="ss">f"</span><span class="sc">{</span>i<span class="op">+</span><span class="dv">1</span><span class="sc">}</span><span class="ss">. </span><span class="sc">{</span>method_name<span class="sc">}</span><span class="ss">:"</span>)</span>
<span id="cb18-157"><a href="#cb18-157" aria-hidden="true" tabindex="-1"></a>            summary.append(<span class="ss">f"   Mean: </span><span class="sc">{</span>stats[<span class="st">'mean_performance'</span>]<span class="sc">:.4f}</span><span class="ss">"</span>)</span>
<span id="cb18-158"><a href="#cb18-158" aria-hidden="true" tabindex="-1"></a>            summary.append(<span class="ss">f"   Std:  </span><span class="sc">{</span>stats[<span class="st">'std_performance'</span>]<span class="sc">:.4f}</span><span class="ss">"</span>)</span>
<span id="cb18-159"><a href="#cb18-159" aria-hidden="true" tabindex="-1"></a>            summary.append(<span class="ss">f"   Best: </span><span class="sc">{</span>stats[<span class="st">'best_performance'</span>]<span class="sc">:.4f}</span><span class="ss">"</span>)</span>
<span id="cb18-160"><a href="#cb18-160" aria-hidden="true" tabindex="-1"></a>            summary.append(<span class="st">""</span>)</span>
<span id="cb18-161"><a href="#cb18-161" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-162"><a href="#cb18-162" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="st">"</span><span class="ch">\n</span><span class="st">"</span>.join(summary)</span></code></pre></div></div>
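<p>Because the hook-based counter needs a forward pass, it helps to cross-check it against numbers computed by hand. For a Conv2d layer each of the C<sub>out</sub>·H<sub>out</sub>·W<sub>out</sub> output elements costs (C<sub>in</sub>/groups)·k<sub>h</sub>·k<sub>w</sub> multiply-accumulates, and the spatial size follows the usual floor((H + 2p - d(k - 1) - 1)/s) + 1 rule. The pure-Python sketch below uses hypothetical helper names, assumes square kernels and ignores bias terms; it compares a standard 3×3 convolution with a depthwise one on the same 64-channel, 56×56 input.</p>

```python
def conv2d_output_size(size, kernel, stride=1, padding=0, dilation=1):
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv2d_macs(c_in, c_out, h_in, w_in, kernel, stride=1, padding=0, groups=1):
    """Multiply-accumulates for one image through a square-kernel Conv2d (bias ignored)."""
    if c_in % groups or c_out % groups:
        raise ValueError("channels must be divisible by groups")
    h_out = conv2d_output_size(h_in, kernel, stride, padding)
    w_out = conv2d_output_size(w_in, kernel, stride, padding)
    # Each output element sums over (c_in / groups) channels and kernel * kernel taps
    macs_per_output = (c_in // groups) * kernel * kernel
    return c_out * h_out * w_out * macs_per_output


def linear_macs(in_features, out_features):
    """Multiply-accumulates for one sample through a Linear layer (bias ignored)."""
    return in_features * out_features


standard = conv2d_macs(64, 128, 56, 56, kernel=3, padding=1)
depthwise = conv2d_macs(64, 64, 56, 56, kernel=3, padding=1, groups=64)
print(f"standard 3x3: {standard:,} MACs")
print(f"depthwise 3x3: {depthwise:,} MACs")
print(f"Linear 4096 to 1000: {linear_macs(4096, 1000):,} MACs")
```

<p>Running <code>_count_flops</code> on a single-layer model with the same shapes and comparing it to these hand-computed values is a cheap sanity test for the hook logic.</p>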
</section>
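<p><code>_measure_latency</code> returns the mean over all runs, which a handful of slow iterations (garbage collection, background processes, frequency scaling) can skew. Timing every call separately and reporting the median and a tail percentile gives a more robust picture. The sketch below is framework-agnostic and times any zero-argument callable with <code>time.perf_counter</code>; <code>latency_profile</code> and its parameters are hypothetical, and for a GPU model the callable would also need to call <code>torch.cuda.synchronize()</code> before returning so that queued kernels are included.</p>

```python
import statistics
import time


def latency_profile(fn, warmup=10, runs=100):
    """Call fn() warmup times untimed, then time runs calls one by one.

    Returns mean, median, 95th percentile and max latency in milliseconds;
    runs should be at least 2 so the percentile is meaningful.
    """
    for _ in range(warmup):
        fn()  # warm caches, lazy initialisation and allocator pools
    samples_ms = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        samples_ms.append((time.perf_counter() - start) * 1e3)
    # n=100 yields the 99 cut points for percentiles 1..99; "inclusive"
    # interpolates between observed samples instead of extrapolating past them
    cuts = statistics.quantiles(samples_ms, n=100, method="inclusive")
    return {
        "mean_ms": statistics.fmean(samples_ms),
        "median_ms": statistics.median(samples_ms),
        "p95_ms": cuts[94],
        "max_ms": max(samples_ms),
    }


profile = latency_profile(lambda: sum(i * i for i in range(10_000)), runs=50)
print({name: round(value, 3) for name, value in profile.items()})
```

<p>For the benchmarking class above, <code>fn</code> could wrap <code>model(dummy_input)</code> inside <code>torch.no_grad()</code>; reporting p95 next to the median exposes tail latency that a single mean hides.</p>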
</section>
<section id="tools-and-frameworks" class="level2">
<h2 class="anchored" data-anchor-id="tools-and-frameworks" id="tools-and-frameworks">Tools and Frameworks</h2>
<section id="popular-nas-libraries" class="level3">
<h3 class="anchored" data-anchor-id="popular-nas-libraries" id="popular-nas-libraries">Popular NAS Libraries</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb19"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb19-1"><a href="#cb19-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> NASFrameworkGuide:</span>
<span id="cb19-2"><a href="#cb19-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Guide to popular NAS frameworks"""</span></span>
<span id="cb19-3"><a href="#cb19-3" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb19-4"><a href="#cb19-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>):</span>
<span id="cb19-5"><a href="#cb19-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.frameworks <span class="op">=</span> {</span>
<span id="cb19-6"><a href="#cb19-6" aria-hidden="true" tabindex="-1"></a>            <span class="st">'nni'</span>: {</span>
<span id="cb19-7"><a href="#cb19-7" aria-hidden="true" tabindex="-1"></a>                <span class="st">'description'</span>: <span class="st">'Microsoft</span><span class="ch">\'</span><span class="st">s Neural Network Intelligence toolkit (example uses the older nni.nas.pytorch API; recent releases changed it)'</span>,</span>
<span id="cb19-8"><a href="#cb19-8" aria-hidden="true" tabindex="-1"></a>                <span class="st">'strengths'</span>: [<span class="st">'Easy to use'</span>, <span class="st">'Multiple search algorithms'</span>, <span class="st">'Good documentation'</span>],</span>
<span id="cb19-9"><a href="#cb19-9" aria-hidden="true" tabindex="-1"></a>                <span class="st">'installation'</span>: <span class="st">'pip install nni'</span>,</span>
<span id="cb19-10"><a href="#cb19-10" aria-hidden="true" tabindex="-1"></a>                <span class="st">'example_usage'</span>: <span class="st">'''</span></span>
<span id="cb19-11"><a href="#cb19-11" aria-hidden="true" tabindex="-1"></a><span class="st">from nni.nas.pytorch import DartsTrainer</span></span>
<span id="cb19-12"><a href="#cb19-12" aria-hidden="true" tabindex="-1"></a><span class="st">from nni.nas.pytorch.search_space_zoo import ENASMacroSearchSpace</span></span>
<span id="cb19-13"><a href="#cb19-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-14"><a href="#cb19-14" aria-hidden="true" tabindex="-1"></a><span class="st"># Define search space</span></span>
<span id="cb19-15"><a href="#cb19-15" aria-hidden="true" tabindex="-1"></a><span class="st">search_space = ENASMacroSearchSpace()</span></span>
<span id="cb19-16"><a href="#cb19-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-17"><a href="#cb19-17" aria-hidden="true" tabindex="-1"></a><span class="st"># Create trainer (arguments abbreviated; check the NNI docs for your version)</span></span>
<span id="cb19-18"><a href="#cb19-18" aria-hidden="true" tabindex="-1"></a><span class="st">trainer = DartsTrainer(</span></span>
<span id="cb19-19"><a href="#cb19-19" aria-hidden="true" tabindex="-1"></a><span class="st">    model=search_space,</span></span>
<span id="cb19-20"><a href="#cb19-20" aria-hidden="true" tabindex="-1"></a><span class="st">    loss=nn.CrossEntropyLoss(),</span></span>
<span id="cb19-21"><a href="#cb19-21" aria-hidden="true" tabindex="-1"></a><span class="st">    optimizer=torch.optim.SGD(search_space.parameters(), lr=0.1)</span></span>
<span id="cb19-22"><a href="#cb19-22" aria-hidden="true" tabindex="-1"></a><span class="st">)</span></span>
<span id="cb19-23"><a href="#cb19-23" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-24"><a href="#cb19-24" aria-hidden="true" tabindex="-1"></a><span class="st"># Train</span></span>
<span id="cb19-25"><a href="#cb19-25" aria-hidden="true" tabindex="-1"></a><span class="st">trainer.train()</span></span>
<span id="cb19-26"><a href="#cb19-26" aria-hidden="true" tabindex="-1"></a><span class="st">'''</span></span>
<span id="cb19-27"><a href="#cb19-27" aria-hidden="true" tabindex="-1"></a>            },</span>
<span id="cb19-28"><a href="#cb19-28" aria-hidden="true" tabindex="-1"></a>            <span class="st">'automl'</span>: {</span>
<span id="cb19-29"><a href="#cb19-29" aria-hidden="true" tabindex="-1"></a>                <span class="st">'description'</span>: <span class="st">'Google</span><span class="ch">\'</span><span class="st">s AdaNet, a TensorFlow-based AutoML framework'</span>,</span>
<span id="cb19-30"><a href="#cb19-30" aria-hidden="true" tabindex="-1"></a>                <span class="st">'strengths'</span>: [<span class="st">'State-of-the-art methods'</span>, <span class="st">'Research-oriented'</span>],</span>
<span id="cb19-31"><a href="#cb19-31" aria-hidden="true" tabindex="-1"></a>                <span class="st">'installation'</span>: <span class="st">'pip install adanet'</span>,</span>
<span id="cb19-32"><a href="#cb19-32" aria-hidden="true" tabindex="-1"></a>                <span class="st">'example_usage'</span>: <span class="st">'''</span></span>
<span id="cb19-33"><a href="#cb19-33" aria-hidden="true" tabindex="-1"></a><span class="st"># Example for AdaNet</span></span>
<span id="cb19-34"><a href="#cb19-34" aria-hidden="true" tabindex="-1"></a><span class="st">import adanet</span></span>
<span id="cb19-35"><a href="#cb19-35" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-36"><a href="#cb19-36" aria-hidden="true" tabindex="-1"></a><span class="st"># head, generator (which defines the subnetwork search space) and train_input_fn are user-defined</span></span>
<span id="cb19-37"><a href="#cb19-37" aria-hidden="true" tabindex="-1"></a><span class="st">estimator = adanet.Estimator(</span></span>
<span id="cb19-38"><a href="#cb19-38" aria-hidden="true" tabindex="-1"></a><span class="st">    head=head,</span></span>
<span id="cb19-39"><a href="#cb19-39" aria-hidden="true" tabindex="-1"></a><span class="st">    subnetwork_generator=generator,</span></span>
<span id="cb19-40"><a href="#cb19-40" aria-hidden="true" tabindex="-1"></a><span class="st">    max_iteration_steps=1000</span></span>
<span id="cb19-41"><a href="#cb19-41" aria-hidden="true" tabindex="-1"></a><span class="st">)</span></span>
<span id="cb19-42"><a href="#cb19-42" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-43"><a href="#cb19-43" aria-hidden="true" tabindex="-1"></a><span class="st"># Train</span></span>
<span id="cb19-44"><a href="#cb19-44" aria-hidden="true" tabindex="-1"></a><span class="st">estimator.train(input_fn=train_input_fn)</span></span>
<span id="cb19-45"><a href="#cb19-45" aria-hidden="true" tabindex="-1"></a><span class="st">'''</span></span>
<span id="cb19-46"><a href="#cb19-46" aria-hidden="true" tabindex="-1"></a>            },</span>
<span id="cb19-47"><a href="#cb19-47" aria-hidden="true" tabindex="-1"></a>            <span class="st">'optuna'</span>: {</span>
<span id="cb19-48"><a href="#cb19-48" aria-hidden="true" tabindex="-1"></a>                <span class="st">'description'</span>: <span class="st">'Hyperparameter optimization framework'</span>,</span>
<span id="cb19-49"><a href="#cb19-49" aria-hidden="true" tabindex="-1"></a>                <span class="st">'strengths'</span>: [<span class="st">'Flexible'</span>, <span class="st">'Multiple optimization algorithms'</span>, <span class="st">'Good for hyperparameter tuning'</span>],</span>
<span id="cb19-50"><a href="#cb19-50" aria-hidden="true" tabindex="-1"></a>                <span class="st">'installation'</span>: <span class="st">'pip install optuna'</span>,</span>
<span id="cb19-51"><a href="#cb19-51" aria-hidden="true" tabindex="-1"></a>                <span class="st">'example_usage'</span>: <span class="st">'''</span></span>
<span id="cb19-52"><a href="#cb19-52" aria-hidden="true" tabindex="-1"></a><span class="st">import optuna</span></span>
<span id="cb19-53"><a href="#cb19-53" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-54"><a href="#cb19-54" aria-hidden="true" tabindex="-1"></a><span class="st">def objective(trial):</span></span>
<span id="cb19-55"><a href="#cb19-55" aria-hidden="true" tabindex="-1"></a><span class="st">    # Define architecture parameters</span></span>
<span id="cb19-56"><a href="#cb19-56" aria-hidden="true" tabindex="-1"></a><span class="st">    n_layers = trial.suggest_int('n_layers', 2, 8)</span></span>
<span id="cb19-57"><a href="#cb19-57" aria-hidden="true" tabindex="-1"></a><span class="st">    n_filters = trial.suggest_int('n_filters', 16, 128)</span></span>
<span id="cb19-58"><a href="#cb19-58" aria-hidden="true" tabindex="-1"></a><span class="st">    </span></span>
<span id="cb19-59"><a href="#cb19-59" aria-hidden="true" tabindex="-1"></a><span class="st">    # Build and train model</span></span>
<span id="cb19-60"><a href="#cb19-60" aria-hidden="true" tabindex="-1"></a><span class="st">    model = build_model(n_layers, n_filters)</span></span>
<span id="cb19-61"><a href="#cb19-61" aria-hidden="true" tabindex="-1"></a><span class="st">    accuracy = train_and_evaluate(model)</span></span>
<span id="cb19-62"><a href="#cb19-62" aria-hidden="true" tabindex="-1"></a><span class="st">    </span></span>
<span id="cb19-63"><a href="#cb19-63" aria-hidden="true" tabindex="-1"></a><span class="st">    return accuracy</span></span>
<span id="cb19-64"><a href="#cb19-64" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-65"><a href="#cb19-65" aria-hidden="true" tabindex="-1"></a><span class="st">study = optuna.create_study(direction='maximize')</span></span>
<span id="cb19-66"><a href="#cb19-66" aria-hidden="true" tabindex="-1"></a><span class="st">study.optimize(objective, n_trials=100)</span></span>
<span id="cb19-67"><a href="#cb19-67" aria-hidden="true" tabindex="-1"></a><span class="st">'''</span></span>
<span id="cb19-68"><a href="#cb19-68" aria-hidden="true" tabindex="-1"></a>            },</span>
<span id="cb19-69"><a href="#cb19-69" aria-hidden="true" tabindex="-1"></a>            <span class="st">'ray_tune'</span>: {</span>
<span id="cb19-70"><a href="#cb19-70" aria-hidden="true" tabindex="-1"></a>                <span class="st">'description'</span>: <span class="st">'Distributed hyperparameter tuning'</span>,</span>
<span id="cb19-71"><a href="#cb19-71" aria-hidden="true" tabindex="-1"></a>                <span class="st">'strengths'</span>: [<span class="st">'Scalable'</span>, <span class="st">'Multiple search algorithms'</span>, <span class="st">'Good for distributed training'</span>],</span>
<span id="cb19-72"><a href="#cb19-72" aria-hidden="true" tabindex="-1"></a>                <span class="st">'installation'</span>: <span class="st">'pip install ray[tune]'</span>,</span>
<span id="cb19-73"><a href="#cb19-73" aria-hidden="true" tabindex="-1"></a>                <span class="st">'example_usage'</span>: <span class="st">'''</span></span>
<span id="cb19-74"><a href="#cb19-74" aria-hidden="true" tabindex="-1"></a><span class="st">from ray import tune</span></span>
<span id="cb19-75"><a href="#cb19-75" aria-hidden="true" tabindex="-1"></a><span class="st">from ray.tune.schedulers import ASHAScheduler</span></span>
<span id="cb19-76"><a href="#cb19-76" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-77"><a href="#cb19-77" aria-hidden="true" tabindex="-1"></a><span class="st">def trainable(config):</span></span>
<span id="cb19-78"><a href="#cb19-78" aria-hidden="true" tabindex="-1"></a><span class="st">    model = build_model(config)</span></span>
<span id="cb19-79"><a href="#cb19-79" aria-hidden="true" tabindex="-1"></a><span class="st">    accuracy = train_model(model)</span></span>
<span id="cb19-80"><a href="#cb19-80" aria-hidden="true" tabindex="-1"></a><span class="st">    tune.report(accuracy=accuracy)</span></span>
<span id="cb19-81"><a href="#cb19-81" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-82"><a href="#cb19-82" aria-hidden="true" tabindex="-1"></a><span class="st">scheduler = ASHAScheduler(metric="accuracy", mode="max")</span></span>
<span id="cb19-83"><a href="#cb19-83" aria-hidden="true" tabindex="-1"></a><span class="st">result = tune.run(</span></span>
<span id="cb19-84"><a href="#cb19-84" aria-hidden="true" tabindex="-1"></a><span class="st">    trainable,</span></span>
<span id="cb19-85"><a href="#cb19-85" aria-hidden="true" tabindex="-1"></a><span class="st">    resources_per_trial={"cpu": 2, "gpu": 1},</span></span>
<span id="cb19-86"><a href="#cb19-86" aria-hidden="true" tabindex="-1"></a><span class="st">    config={</span></span>
<span id="cb19-87"><a href="#cb19-87" aria-hidden="true" tabindex="-1"></a><span class="st">        "n_layers": tune.choice([2, 4, 6, 8]),</span></span>
<span id="cb19-88"><a href="#cb19-88" aria-hidden="true" tabindex="-1"></a><span class="st">        "n_filters": tune.choice([16, 32, 64, 128])</span></span>
<span id="cb19-89"><a href="#cb19-89" aria-hidden="true" tabindex="-1"></a><span class="st">    },</span></span>
<span id="cb19-90"><a href="#cb19-90" aria-hidden="true" tabindex="-1"></a><span class="st">    scheduler=scheduler</span></span>
<span id="cb19-91"><a href="#cb19-91" aria-hidden="true" tabindex="-1"></a><span class="st">)</span></span>
<span id="cb19-92"><a href="#cb19-92" aria-hidden="true" tabindex="-1"></a><span class="st">'''</span></span>
<span id="cb19-93"><a href="#cb19-93" aria-hidden="true" tabindex="-1"></a>            }</span>
<span id="cb19-94"><a href="#cb19-94" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb19-95"><a href="#cb19-95" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb19-96"><a href="#cb19-96" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> get_framework_recommendation(<span class="va">self</span>, use_case: <span class="bu">str</span>):</span>
<span id="cb19-97"><a href="#cb19-97" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Get framework recommendation based on use case"""</span></span>
<span id="cb19-98"><a href="#cb19-98" aria-hidden="true" tabindex="-1"></a>        recommendations <span class="op">=</span> {</span>
<span id="cb19-99"><a href="#cb19-99" aria-hidden="true" tabindex="-1"></a>            <span class="st">'research'</span>: [<span class="st">'automl'</span>, <span class="st">'custom_implementation'</span>],</span>
<span id="cb19-100"><a href="#cb19-100" aria-hidden="true" tabindex="-1"></a>            <span class="st">'production'</span>: [<span class="st">'nni'</span>, <span class="st">'ray_tune'</span>],</span>
<span id="cb19-101"><a href="#cb19-101" aria-hidden="true" tabindex="-1"></a>            <span class="st">'hyperparameter_tuning'</span>: [<span class="st">'optuna'</span>, <span class="st">'ray_tune'</span>],</span>
<span id="cb19-102"><a href="#cb19-102" aria-hidden="true" tabindex="-1"></a>            <span class="st">'distributed_training'</span>: [<span class="st">'ray_tune'</span>],</span>
<span id="cb19-103"><a href="#cb19-103" aria-hidden="true" tabindex="-1"></a>            <span class="st">'beginner_friendly'</span>: [<span class="st">'nni'</span>, <span class="st">'optuna'</span>]</span>
<span id="cb19-104"><a href="#cb19-104" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb19-105"><a href="#cb19-105" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb19-106"><a href="#cb19-106" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> recommendations.get(use_case, [<span class="st">'nni'</span>])</span>
<span id="cb19-107"><a href="#cb19-107" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-108"><a href="#cb19-108" aria-hidden="true" tabindex="-1"></a><span class="co"># Example: Custom NAS implementation using PyTorch</span></span>
<span id="cb19-109"><a href="#cb19-109" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> CustomNASExample:</span>
<span id="cb19-110"><a href="#cb19-110" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>):</span>
<span id="cb19-111"><a href="#cb19-111" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.search_space <span class="op">=</span> <span class="va">None</span></span>
<span id="cb19-112"><a href="#cb19-112" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.search_strategy <span class="op">=</span> <span class="va">None</span></span>
<span id="cb19-113"><a href="#cb19-113" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.evaluator <span class="op">=</span> <span class="va">None</span></span>
<span id="cb19-114"><a href="#cb19-114" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb19-115"><a href="#cb19-115" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> setup_cifar10_nas(<span class="va">self</span>):</span>
<span id="cb19-116"><a href="#cb19-116" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Set up NAS for CIFAR-10"""</span></span>
<span id="cb19-117"><a href="#cb19-117" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Define search space</span></span>
<span id="cb19-118"><a href="#cb19-118" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.search_space <span class="op">=</span> CellSearchSpace(num_nodes<span class="op">=</span><span class="dv">4</span>)</span>
<span id="cb19-119"><a href="#cb19-119" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb19-120"><a href="#cb19-120" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Define search strategy</span></span>
<span id="cb19-121"><a href="#cb19-121" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.search_strategy <span class="op">=</span> EvolutionarySearch(population_size<span class="op">=</span><span class="dv">20</span>)</span>
<span id="cb19-122"><a href="#cb19-122" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb19-123"><a href="#cb19-123" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Define evaluator</span></span>
<span id="cb19-124"><a href="#cb19-124" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.evaluator <span class="op">=</span> EarlyStoppingEvaluator(</span>
<span id="cb19-125"><a href="#cb19-125" aria-hidden="true" tabindex="-1"></a>            dataset<span class="op">=</span><span class="va">self</span>._get_cifar10_dataset(),</span>
<span id="cb19-126"><a href="#cb19-126" aria-hidden="true" tabindex="-1"></a>            max_epochs<span class="op">=</span><span class="dv">10</span>,</span>
<span id="cb19-127"><a href="#cb19-127" aria-hidden="true" tabindex="-1"></a>            patience<span class="op">=</span><span class="dv">3</span></span>
<span id="cb19-128"><a href="#cb19-128" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb19-129"><a href="#cb19-129" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb19-130"><a href="#cb19-130" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> run_nas_experiment(<span class="va">self</span>):</span>
<span id="cb19-131"><a href="#cb19-131" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Run complete NAS experiment"""</span></span>
<span id="cb19-132"><a href="#cb19-132" aria-hidden="true" tabindex="-1"></a>        framework <span class="op">=</span> NASFramework(</span>
<span id="cb19-133"><a href="#cb19-133" aria-hidden="true" tabindex="-1"></a>            search_space<span class="op">=</span><span class="va">self</span>.search_space,</span>
<span id="cb19-134"><a href="#cb19-134" aria-hidden="true" tabindex="-1"></a>            search_strategy<span class="op">=</span><span class="va">self</span>.search_strategy,</span>
<span id="cb19-135"><a href="#cb19-135" aria-hidden="true" tabindex="-1"></a>            performance_estimator<span class="op">=</span><span class="va">self</span>.evaluator</span>
<span id="cb19-136"><a href="#cb19-136" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb19-137"><a href="#cb19-137" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb19-138"><a href="#cb19-138" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Run search</span></span>
<span id="cb19-139"><a href="#cb19-139" aria-hidden="true" tabindex="-1"></a>        best_result <span class="op">=</span> framework.search(num_iterations<span class="op">=</span><span class="dv">100</span>)</span>
<span id="cb19-140"><a href="#cb19-140" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb19-141"><a href="#cb19-141" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Analyze results</span></span>
<span id="cb19-142"><a href="#cb19-142" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Best architecture: </span><span class="sc">{</span>best_result[<span class="st">'architecture'</span>]<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb19-143"><a href="#cb19-143" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Best performance: </span><span class="sc">{</span>best_result[<span class="st">'performance'</span>]<span class="sc">:.4f}</span><span class="ss">"</span>)</span>
<span id="cb19-144"><a href="#cb19-144" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb19-145"><a href="#cb19-145" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> best_result</span>
<span id="cb19-146"><a href="#cb19-146" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb19-147"><a href="#cb19-147" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _get_cifar10_dataset(<span class="va">self</span>):</span>
<span id="cb19-148"><a href="#cb19-148" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Get CIFAR-10 dataset"""</span></span>
<span id="cb19-149"><a href="#cb19-149" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Implementation depends on your data loading setup</span></span>
<span id="cb19-150"><a href="#cb19-150" aria-hidden="true" tabindex="-1"></a>        <span class="cf">pass</span></span></code></pre></div></div>
<p>This guide has covered Neural Architecture Search from its theoretical foundations to practical implementations. The code examples give a starting point for understanding and implementing NAS algorithms, and the best practices and framework recommendations help when applying them to real problems.</p>
<p>The key takeaways from this guide are:</p>
<ol type="1">
<li><strong>NAS Framework</strong>: Understanding the three core components (search space, search strategy, performance estimation) is crucial</li>
<li><strong>Search Space Design</strong>: Careful design of search spaces balances expressiveness with computational efficiency</li>
<li><strong>Search Strategies</strong>: Different strategies have different trade-offs between exploration and exploitation</li>
<li><strong>Performance Estimation</strong>: Efficient evaluation methods are essential for practical NAS</li>
<li><strong>Implementation</strong>: Modern frameworks provide good starting points, but custom implementations offer more control</li>
<li><strong>Best Practices</strong>: Following established guidelines improves NAS effectiveness and reproducibility</li>
</ol>
<p>For beginners, I recommend starting with existing frameworks like NNI or Optuna, then gradually moving to custom implementations as your understanding deepens. For research applications, implementing methods from scratch using the patterns shown in this guide provides the most flexibility and insight into the algorithms.</p>


</section>
</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[The Mathematics Behind Neural Architecture Search]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/neural-architecture-search/nas-mathematics/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/neural-architecture-search/nas-mathematics/</guid>
      <pubDate>Fri, 11 Jul 2025 00:00:00 GMT</pubDate>
      
      <category>research</category>
      <category>advanced</category>
      <content:encoded><![CDATA[






<section id="the-mathematics-behind-neural-architecture-search" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/neural-architecture-search/nas-mathematics/nas-math.png" class="img-fluid"></p>
<p>Neural Architecture Search (NAS) represents one of the most sophisticated applications of automated machine learning, where algorithms autonomously design neural network architectures. This field combines optimization theory, probability, and deep learning to solve the fundamental question: what is the optimal neural network architecture for a given task?</p>
<section id="problem-formulation" class="level2">
<h2 class="anchored" data-anchor-id="problem-formulation" id="problem-formulation">Problem Formulation</h2>
<p>The core mathematical challenge in NAS can be formulated as a bilevel optimization problem. Given a dataset <span class="math inline">\(D = \{(x_i, y_i)\}_{i=1}^N\)</span>, we seek to find the optimal architecture <span class="math inline">\(\alpha^*\)</span> that minimizes the validation loss:</p>
<p><span class="math display">\[\alpha^* = \arg \min_\alpha L_{\text{val}}(w^*(\alpha), \alpha)\]</span></p>
<p>where <span class="math inline">\(w^*(\alpha)\)</span> is the optimal set of weights for architecture <span class="math inline">\(\alpha\)</span>, obtained by solving:</p>
<p><span class="math display">\[w^*(\alpha) = \arg \min_w L_{\text{train}}(w, \alpha)\]</span></p>
<p>This bilevel structure creates significant computational challenges, as evaluating each architecture requires full training to obtain <span class="math inline">\(w^*(\alpha)\)</span>.</p>
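<p>The bilevel structure can be made concrete with a deliberately tiny, hypothetical search space. In the sketch below each "architecture" is a fixed feature map and the model is <span class="math inline">\(\hat{y} = w \cdot \phi(x)\)</span> with a single scalar weight, so the inner problem has a closed-form least-squares solution and the outer problem can be solved by enumerating every candidate. The names <code>ARCHITECTURES</code>, <code>inner_train</code> and <code>bilevel_search</code> are illustrative, and real NAS replaces the closed form with full network training, which is exactly what makes the problem expensive.</p>

```python
import math

# Hypothetical search space: each "architecture" alpha is a fixed feature map phi,
# and the model is y_hat = w * phi(x) with one trainable weight w.
ARCHITECTURES = {
    "identity": lambda x: x,
    "square": lambda x: x * x,
    "sine": math.sin,
}

def inner_train(phi, xs, ys):
    """Inner problem: w*(alpha) = argmin_w L_train(w, alpha).
    For y_hat = w * phi(x) and squared loss this has a closed form."""
    num = sum(phi(x) * y for x, y in zip(xs, ys))
    den = sum(phi(x) ** 2 for x in xs)
    return num / den

def loss(w, phi, xs, ys):
    """Mean squared error of y_hat = w * phi(x)."""
    return sum((w * phi(x) - y) ** 2 for x, y in zip(xs, ys)) / len(xs)

def bilevel_search(train, val):
    """Outer problem: alpha* = argmin_alpha L_val(w*(alpha), alpha),
    solved here by exhaustive enumeration of the tiny search space."""
    scores = {}
    for name, phi in ARCHITECTURES.items():
        w_star = inner_train(phi, *train)       # "full training" of this candidate
        scores[name] = loss(w_star, phi, *val)  # validation loss of the trained model
    return min(scores, key=scores.get), scores

# Synthetic data generated by the "sine" architecture with w = 3.
train_x = [i / 10 for i in range(-30, 31)]
val_x = [i / 10 + 0.05 for i in range(-30, 30)]
train = (train_x, [3 * math.sin(x) for x in train_x])
val = (val_x, [3 * math.sin(x) for x in val_x])

best, scores = bilevel_search(train, val)
print(best, {k: round(v, 4) for k, v in scores.items()})
```

<p>Every outer evaluation calls the inner solver once, so the cost of the search is the number of candidates times the cost of training one of them.</p>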
</section>
<section id="search-space-representation" class="level2">
<h2 class="anchored" data-anchor-id="search-space-representation" id="search-space-representation">Search Space Representation</h2>
<section id="continuous-relaxation" class="level3">
<h3 class="anchored" data-anchor-id="continuous-relaxation" id="continuous-relaxation">Continuous Relaxation</h3>
<p>One of the key mathematical innovations in NAS is the continuous relaxation of the discrete architecture search space. Instead of searching over discrete architectural choices, we can represent the search space as a continuous optimization problem.</p>
<p>Consider a search space where each edge <span class="math inline">\((i,j)\)</span> in the network applies one of <span class="math inline">\(|\mathcal{O}|\)</span> candidate operations from a set <span class="math inline">\(\mathcal{O} = \{o^{(1)}, o^{(2)}, \ldots, o^{(|\mathcal{O}|)}\}\)</span>. The continuous relaxation introduces architecture parameters <span class="math inline">\(\alpha = \{\alpha_{i,j}\}_{i,j}\)</span> where <span class="math inline">\(\alpha_{i,j} \in \mathbb{R}^{|\mathcal{O}|}\)</span> holds one real-valued weight per candidate operation on edge <span class="math inline">\((i,j)\)</span>.</p>
<p>The mixed operation at edge <span class="math inline">\((i,j)\)</span> becomes:</p>
<p><span class="math display">\[o^{\text{mixed}}_{i,j}(x) = \sum_{k=1}^{|\mathcal{O}|} \frac{\exp(\alpha_{i,j}^{(k)})}{\sum_{l=1}^{|\mathcal{O}|} \exp(\alpha_{i,j}^{(l)})} \cdot o^{(k)}(x)\]</span></p>
<p>This softmax weighting allows gradient-based optimization while keeping the mixing weights non-negative and summing to 1.</p>
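<p>A minimal sketch of the mixed operation on one edge follows. The candidate operations are hypothetical scalar stand-ins (in a real search space they would be convolutions, pooling and skip connections acting on feature maps), but the weighting is the softmax formula above.</p>

```python
import math

def softmax(logits):
    """Numerically stable softmax (subtracting the max leaves the result unchanged)."""
    m = max(logits)
    exps = [math.exp(a - m) for a in logits]
    total = sum(exps)
    return [e / total for e in exps]

def mixed_op(x, alpha_ij, ops):
    """o_mixed_{i,j}(x) = sum_k softmax(alpha_ij)_k * o^(k)(x)."""
    return sum(p * op(x) for p, op in zip(softmax(alpha_ij), ops))

# Hypothetical scalar stand-ins for the candidate set O on one edge.
ops = [
    lambda x: x,             # identity / skip connection
    lambda x: 0.0,           # "zero" operation (edge effectively absent)
    lambda x: max(x, 0.0),   # ReLU, standing in for a parametric op such as a conv
]
alpha_ij = [2.0, -1.0, 0.5]
print([round(p, 3) for p in softmax(alpha_ij)], mixed_op(-1.5, alpha_ij, ops))
```

<p>With equal logits the mixed operation is simply the average of the candidates; as one logit grows, the output approaches that single operation.</p>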
</section>
<section id="graph-based-representations" class="level3">
<h3 class="anchored" data-anchor-id="graph-based-representations" id="graph-based-representations">Graph-Based Representations</h3>
<p>Neural architectures can be represented as directed acyclic graphs (DAGs) <span class="math inline">\(G = (V, E)\)</span> where:</p>
<ul>
<li><span class="math inline">\(V\)</span> represents computational nodes (layers, operations)</li>
<li><span class="math inline">\(E\)</span> represents data flow connections</li>
</ul>
<p>The adjacency matrix <span class="math inline">\(A \in \{0,1\}^{|V|\times|V|}\)</span> encodes the connectivity, where <span class="math inline">\(A_{i,j} = 1\)</span> indicates a connection from node <span class="math inline">\(i\)</span> to node <span class="math inline">\(j\)</span>.</p>
<p>For a node <span class="math inline">\(j\)</span> with incoming edges from nodes <span class="math inline">\(i_1, i_2, \ldots, i_k\)</span>, the output is:</p>
<p><span class="math display">\[h_j = f_j\left(\sum_{i \in \text{pred}(j)} A_{i,j} \cdot h_i\right)\]</span></p>
<p>where <span class="math inline">\(f_j\)</span> is the operation at node <span class="math inline">\(j\)</span> and <span class="math inline">\(\text{pred}(j)\)</span> denotes the predecessor nodes.</p>
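<p>The node update can be evaluated with a single pass over the nodes, provided they are numbered in topological order (so the adjacency matrix is strictly upper triangular). The sketch below assumes node 0 is the input and uses hypothetical scalar node operations <span class="math inline">\(f_j\)</span>.</p>

```python
import math

def dag_forward(x, A, node_ops):
    """Compute h_j = f_j(sum_{i in pred(j)} A[i][j] * h_i) for every node.
    Assumes nodes are numbered in topological order, so A must be strictly
    upper triangular, and node 0 is the input node with h_0 = x."""
    n = len(A)
    for i in range(n):
        for j in range(i + 1):
            if A[i][j]:
                raise ValueError("A must be strictly upper triangular (topological order)")
    h = [0.0] * n
    h[0] = x
    for j in range(1, n):
        agg = sum(A[i][j] * h[i] for i in range(j))  # only i < j can feed node j
        h[j] = node_ops[j](agg)
    return h

# Hypothetical 4-node cell with edges 0->1, 0->2, 1->3, 2->3.
A = [[0, 1, 1, 0],
     [0, 0, 0, 1],
     [0, 0, 0, 1],
     [0, 0, 0, 0]]
node_ops = [None, math.tanh, lambda z: 2.0 * z, lambda z: z]  # node 0 is the input
h = dag_forward(0.5, A, node_ops)
print(h)  # h[3] = tanh(0.5) + 2 * 0.5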
</section>
</section>
<section id="optimization-strategies" class="level2">
<h2 class="anchored" data-anchor-id="optimization-strategies" id="optimization-strategies">Optimization Strategies</h2>
<section id="gradient-based-methods-darts" class="level3">
<h3 class="anchored" data-anchor-id="gradient-based-methods-darts" id="gradient-based-methods-darts">Gradient-Based Methods (DARTS)</h3>
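<p>Differentiable Architecture Search (DARTS) transforms the discrete search into a continuous optimization problem. The architecture parameters <span class="math inline">\(\alpha\)</span> and network weights <span class="math inline">\(w\)</span> are optimized alternately. In the first-order approximation, which uses the current weights in place of <span class="math inline">\(w^*(\alpha)\)</span> instead of differentiating through the inner training problem, each iteration takes one gradient step on <span class="math inline">\(\alpha\)</span> against the validation loss and then one step on <span class="math inline">\(w\)</span> against the training loss, using the updated <span class="math inline">\(\alpha\)</span>:</p>
<p><span class="math display">\[\alpha_{t+1} = \alpha_t - \xi_\alpha \nabla_\alpha L_{\text{val}}(w_t, \alpha_t)\]</span> <span class="math display">\[w_{t+1} = w_t - \xi_w \nabla_w L_{\text{train}}(w_t, \alpha_{t+1})\]</span></p>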
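<p>The alternating updates can be sketched on a toy one-edge network <span class="math inline">\(\hat{y} = w \cdot o^{\text{mixed}}(x)\)</span>. Everything here is an assumption for illustration: the three candidate operations, the synthetic target <span class="math inline">\(y = 2\sin(x)\)</span>, the learning rates, and the use of central finite differences in place of automatic differentiation. Each step updates <span class="math inline">\(\alpha\)</span> on the validation split first and then <span class="math inline">\(w\)</span> on the training split with the new <span class="math inline">\(\alpha\)</span>.</p>

```python
import math

def softmax(logits):
    m = max(logits)
    e = [math.exp(a - m) for a in logits]
    s = sum(e)
    return [v / s for v in e]

# Hypothetical candidate operations on a scalar input (stand-ins for conv/pool/skip).
OPS = [lambda x: x, math.sin, lambda x: x * x]

def predict(w, alpha, x):
    """Toy one-edge network: y_hat = w * o_mixed(x)."""
    return w * sum(p * op(x) for p, op in zip(softmax(alpha), OPS))

def mse(w, alpha, data):
    xs, ys = data
    return sum((predict(w, alpha, x) - y) ** 2 for x, y in zip(xs, ys)) / len(xs)

def num_grad(f, params, eps=1e-6):
    """Central finite differences, standing in for autodiff in this sketch."""
    grad = []
    for k in range(len(params)):
        up, dn = list(params), list(params)
        up[k] += eps
        dn[k] -= eps
        grad.append((f(up) - f(dn)) / (2 * eps))
    return grad

def first_order_darts(train, val, steps=300, lr_alpha=0.1, lr_w=0.05):
    alpha, w = [0.0] * len(OPS), 1.0
    for _ in range(steps):
        # alpha step: descend the validation loss with the current w held fixed.
        g_alpha = num_grad(lambda a: mse(w, a, val), alpha)
        alpha = [a - lr_alpha * g for a, g in zip(alpha, g_alpha)]
        # w step: descend the training loss using the freshly updated alpha.
        g_w = num_grad(lambda p: mse(p[0], alpha, train), [w])[0]
        w -= lr_w * g_w
    return alpha, w

# Synthetic data from y = 2 sin(x); train and validation use different x grids.
train_x = [i / 10 for i in range(-20, 21)]
val_x = [i / 10 + 0.05 for i in range(-20, 20)]
train = (train_x, [2 * math.sin(x) for x in train_x])
val = (val_x, [2 * math.sin(x) for x in val_x])

alpha, w = first_order_darts(train, val)
print("op weights:", [round(p, 3) for p in softmax(alpha)], "w:", round(w, 3))
print("val loss:", mse(1.0, [0.0] * len(OPS), val), "->", mse(w, alpha, val))
```

<p>Keeping the two data splits separate is what distinguishes the architecture update from ordinary weight training: <span class="math inline">\(\alpha\)</span> never sees the training loss directly.</p>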
<p>The gradient with respect to architecture parameters is:</p>
<p><span class="math display">\[\nabla_\alpha L_{\text{val}} = \sum_{i,j} \nabla_\alpha o^{\text{mixed}}_{i,j} \cdot \nabla_{o^{\text{mixed}}_{i,j}} L_{\text{val}}\]</span></p>
<p>The chain rule application requires careful handling of the softmax operation:</p>
<p><span class="math display">\[\nabla_{\alpha_{i,j}^{(k)}} o^{\text{mixed}}_{i,j}(x) = \sum_{l=1}^{|\mathcal{O}|} (\delta_{k,l} - p_{i,j}^{(k)})\, p_{i,j}^{(l)}\, o^{(l)}(x) = p_{i,j}^{(k)} \left(o^{(k)}(x) - o^{\text{mixed}}_{i,j}(x)\right)\]</span></p>
<p>where <span class="math inline">\(p_{i,j}^{(k)} = \frac{\exp(\alpha_{i,j}^{(k)})}{\sum_{m} \exp(\alpha_{i,j}^{(m)})}\)</span> is the softmax weight of operation <span class="math inline">\(k\)</span> and <span class="math inline">\(\delta_{k,l}\)</span> is the Kronecker delta.</p>
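<p>Softmax derivatives are easy to get wrong, so a numerical check is worthwhile. The sketch below fixes a single input <span class="math inline">\(x\)</span>, takes hypothetical operation outputs <span class="math inline">\(o^{(l)}(x)\)</span> as plain numbers, computes <span class="math inline">\(\sum_{l} (\delta_{k,l} - p^{(k)})\, p^{(l)} o^{(l)}(x)\)</span> for every <span class="math inline">\(k\)</span>, and compares the result with central finite differences and with the simplified form <span class="math inline">\(p^{(k)}(o^{(k)}(x) - o^{\text{mixed}}(x))\)</span>.</p>

```python
import math

def softmax(logits):
    m = max(logits)
    e = [math.exp(a - m) for a in logits]
    s = sum(e)
    return [v / s for v in e]

def mixed(alpha, outs):
    """o_mixed = sum_l p^(l) * o^(l)(x); outs[l] holds o^(l)(x) at one fixed x."""
    return sum(p * o for p, o in zip(softmax(alpha), outs))

def analytic_grad(alpha, outs):
    """d o_mixed / d alpha^(k) = sum_l (delta_kl - p^(k)) * p^(l) * o^(l)(x)."""
    p = softmax(alpha)
    n = len(alpha)
    return [sum(((1.0 if k == l else 0.0) - p[k]) * p[l] * outs[l] for l in range(n))
            for k in range(n)]

def finite_diff_grad(alpha, outs, eps=1e-6):
    """Central differences of the mixed output with respect to each alpha^(k)."""
    grad = []
    for k in range(len(alpha)):
        up, dn = list(alpha), list(alpha)
        up[k] += eps
        dn[k] -= eps
        grad.append((mixed(up, outs) - mixed(dn, outs)) / (2 * eps))
    return grad

alpha = [0.3, -1.2, 2.0, 0.0]
outs = [1.5, -0.7, 0.2, 3.0]  # hypothetical o^(l)(x) values at one input x
print(analytic_grad(alpha, outs))
print(finite_diff_grad(alpha, outs))
```

<p>A useful sanity property falls out of the simplified form: the gradient components sum to zero, because adding the same constant to every logit leaves the softmax unchanged.</p>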
</section>
<section id="evolutionary-approaches" class="level3">
<h3 class="anchored" data-anchor-id="evolutionary-approaches" id="evolutionary-approaches">Evolutionary Approaches</h3>
<p>Evolutionary algorithms treat architecture search as a population-based optimization problem. Each architecture is represented as a genome <span class="math inline">\(g\)</span>, and the fitness function is typically the validation accuracy.</p>
<p>The mutation operator <span class="math inline">\(M: \mathcal{G} \to \mathcal{G}\)</span> modifies architectures:</p>
<ul>
<li><strong>Node mutations</strong>: Add/remove computational nodes</li>
<li><strong>Edge mutations</strong>: Add/remove connections</li>
<li><strong>Operation mutations</strong>: Change operation types</li>
</ul>
<p>The crossover operator <span class="math inline">\(C: \mathcal{G} \times \mathcal{G} \to \mathcal{G}\)</span> combines two parent architectures:</p>
<p><span class="math display">\[g_{\text{child}} = C(g_{\text{parent1}}, g_{\text{parent2}})\]</span></p>
<p>Common crossover strategies include:</p>
<ul>
<li><strong>Uniform crossover</strong>: Each gene is inherited from parent1 with probability <span class="math inline">\(p\)</span> and from parent2 otherwise</li>
<li><strong>Graph crossover</strong>: Combine subgraphs from both parents</li>
</ul>
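<p>The operators above can be combined into a small steady-state evolutionary loop. This is a sketch under several assumptions: a genome is one operation index per layer, only operation mutations are used, parents come from binary tournaments, and the child replaces the current worst individual. The fitness function is hypothetical (agreement with a reference genome), standing in for the validation accuracy of a trained network.</p>

```python
import random

# Hypothetical genome encoding: one operation index per layer.
OPS = ["conv3x3", "conv5x5", "maxpool", "skip"]

def mutate(genome, rng):
    """Operation mutation: replace the op at one random position with a different op."""
    child = list(genome)
    pos = rng.randrange(len(child))
    child[pos] = rng.choice([o for o in range(len(OPS)) if o != child[pos]])
    return child

def uniform_crossover(parent1, parent2, rng, p=0.5):
    """Each gene is copied from parent1 with probability p, otherwise from parent2."""
    return [a if rng.random() < p else b for a, b in zip(parent1, parent2)]

def evolve(fitness, n_layers=6, pop_size=10, generations=100, seed=0):
    """Steady-state evolution: tournament selection, crossover + mutation,
    and replacement of the current worst individual."""
    rng = random.Random(seed)
    population = [[rng.randrange(len(OPS)) for _ in range(n_layers)]
                  for _ in range(pop_size)]

    def select():
        a, b = rng.sample(population, 2)  # binary tournament
        return a if fitness(a) >= fitness(b) else b

    history = [max(map(fitness, population))]
    for _ in range(generations):
        child = mutate(uniform_crossover(select(), select(), rng), rng)
        worst = min(range(pop_size), key=lambda i: fitness(population[i]))
        population[worst] = child
        history.append(max(map(fitness, population)))
    return max(population, key=fitness), history

# Hypothetical fitness: number of layers matching a reference genome.
REFERENCE = [0, 0, 2, 1, 3, 0]
def fitness(genome):
    return sum(g == r for g, r in zip(genome, REFERENCE))

best, history = evolve(fitness)
print([OPS[g] for g in best], "fitness:", fitness(best))
```

<p>Because only the worst individual is ever replaced, the best fitness in the population can never decrease from one generation to the next.</p>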
</section>
<section id="reinforcement-learning-formulation" class="level3">
<h3 class="anchored" data-anchor-id="reinforcement-learning-formulation" id="reinforcement-learning-formulation">Reinforcement Learning Formulation</h3>
<p>NAS can be formulated as a sequential decision problem where an agent (controller) generates architectures. The state space <span class="math inline">\(\mathcal{S}\)</span> represents partial architectures, actions <span class="math inline">\(\mathcal{A}\)</span> represent architectural choices, and rewards <span class="math inline">\(\mathcal{R}\)</span> correspond to validation performance.</p>
<p>The policy <span class="math inline">\(\pi(a|s)\)</span> gives the probability of selecting action <span class="math inline">\(a\)</span> in state <span class="math inline">\(s\)</span>. The objective is to maximize expected reward:</p>
<p><span class="math display">\[J(\theta) = \mathbb{E}_{\pi_\theta}[R(\tau)]\]</span></p>
<p>where <span class="math inline">\(\tau\)</span> is a trajectory (sequence of architectural decisions) and <span class="math inline">\(\theta\)</span> are the controller parameters.</p>
<p>Using the REINFORCE algorithm, the gradient is:</p>
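<p><span class="math display">\[\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\left[\sum_{t=1}^{T} \nabla_\theta \log \pi_\theta(a_t|s_t) \cdot (R(\tau) - b)\right]\]</span></p>
<p>where <span class="math inline">\(T\)</span> is the number of architectural decisions in <span class="math inline">\(\tau\)</span>, <span class="math inline">\((s_t, a_t)\)</span> is the state and action at step <span class="math inline">\(t\)</span>, and <span class="math inline">\(b\)</span> is a baseline (for example, a moving average of past rewards) that reduces the variance of the estimate without biasing it.</p>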
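<p>A minimal REINFORCE sketch follows. To keep it self-contained, the controller is assumed to be tabular: one independent softmax policy per decision, so the state <span class="math inline">\(s_t\)</span> is just the decision index (RL-based NAS typically uses a recurrent controller instead). For softmax logits, <span class="math inline">\(\nabla_{\theta_t} \log \pi(a_t)\)</span> is the one-hot vector of <span class="math inline">\(a_t\)</span> minus the probability vector. The reward function is hypothetical, standing in for the validation accuracy of a trained child network.</p>

```python
import math
import random

def softmax(logits):
    m = max(logits)
    e = [math.exp(a - m) for a in logits]
    s = sum(e)
    return [v / s for v in e]

def reinforce_search(reward_fn, n_decisions, n_choices, steps=2000,
                     lr=0.2, baseline_decay=0.9, seed=0):
    """REINFORCE with a tabular controller: one independent softmax policy
    per architectural decision."""
    rng = random.Random(seed)
    theta = [[0.0] * n_choices for _ in range(n_decisions)]
    baseline = 0.0
    for _ in range(steps):
        probs = [softmax(row) for row in theta]
        # Sample a trajectory tau = (a_1, ..., a_T): one choice per decision.
        actions = [rng.choices(range(n_choices), weights=p)[0] for p in probs]
        reward = reward_fn(actions)
        advantage = reward - baseline
        # grad of log pi(a_t | s_t) w.r.t. theta_t is onehot(a_t) - p_t; the
        # trajectory gradient sums these over t and scales by R(tau) - b.
        for t, (a, p) in enumerate(zip(actions, probs)):
            for k in range(n_choices):
                theta[t][k] += lr * advantage * ((1.0 if k == a else 0.0) - p[k])
        # Exponential moving-average baseline, updated after it is used.
        baseline = baseline_decay * baseline + (1 - baseline_decay) * reward
    return theta

# Hypothetical reward: fraction of decisions matching a hidden "good" architecture.
TARGET = [2, 0, 1]
def reward_fn(actions):
    return sum(a == t for a, t in zip(actions, TARGET)) / len(TARGET)

theta = reinforce_search(reward_fn, n_decisions=3, n_choices=3)
print([max(range(3), key=lambda k: row[k]) for row in theta])
```

<p>The update is gradient <em>ascent</em> (<code>+=</code>) because <span class="math inline">\(J(\theta)\)</span> is an expected reward to be maximized.</p>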
</section>
</section>
<section id="probability-and-sampling" class="level2">
<h2 class="anchored" data-anchor-id="probability-and-sampling" id="probability-and-sampling">Probability and Sampling</h2>
<section id="architecture-sampling" class="level3">
<h3 class="anchored" data-anchor-id="architecture-sampling" id="architecture-sampling">Architecture Sampling</h3>
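<p>When using continuous relaxation, a discrete architecture must eventually be extracted from the continuous parameters, and some methods also sample discrete choices during the search itself. The Gumbel-Softmax trick provides a differentiable approximation to sampling an operation from the categorical distribution <span class="math inline">\(\text{softmax}(\alpha_{i,j})\)</span>:</p>
<p><span class="math display">\[\alpha_{\text{sampled}} = \text{softmax}\left(\frac{\alpha_{i,j} + g}{\tau}\right)\]</span></p>
<p>where <span class="math inline">\(g \in \mathbb{R}^{|\mathcal{O}|}\)</span> has independent <span class="math inline">\(\text{Gumbel}(0,1)\)</span> entries and <span class="math inline">\(\tau\)</span> is a temperature parameter controlling the sampling sharpness: as <span class="math inline">\(\tau \to 0\)</span> the sample approaches a one-hot vector, and for any <span class="math inline">\(\tau &gt; 0\)</span> its argmax is distributed exactly as <span class="math inline">\(\text{softmax}(\alpha_{i,j})\)</span>. Writing <span class="math inline">\(\log p_{i,j} + g\)</span> with <span class="math inline">\(p_{i,j} = \text{softmax}(\alpha_{i,j})\)</span> in place of <span class="math inline">\(\alpha_{i,j} + g\)</span> gives the same result, because the two differ only by a constant that the softmax cancels.</p>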
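<p>A small sketch of Gumbel-Softmax sampling for one edge, assuming hypothetical logits for three candidate operations. Gumbel noise is drawn through the inverse CDF <span class="math inline">\(g = -\log(-\log U)\)</span>, and the final loop checks empirically that the argmax of the relaxed sample follows <span class="math inline">\(\text{softmax}(\alpha)\)</span> (the Gumbel-max property).</p>

```python
import math
import random

def sample_gumbel(rng):
    """Standard Gumbel(0, 1) sample via the inverse CDF: g = -log(-log(U))."""
    u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)  # keep both logs finite
    return -math.log(-math.log(u))

def gumbel_softmax(alpha, tau, rng):
    """Relaxed one-hot sample softmax((alpha + g) / tau) with i.i.d. Gumbel g_k."""
    z = [(a + sample_gumbel(rng)) / tau for a in alpha]
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

# Hypothetical logits for three candidate operations on one edge.
alpha = [1.0, 0.0, -1.0]
rng = random.Random(0)
print("tau=1.0:", [round(v, 3) for v in gumbel_softmax(alpha, 1.0, rng)])
print("tau=0.1:", [round(v, 3) for v in gumbel_softmax(alpha, 0.1, rng)])

# Gumbel-max property: argmax frequencies approach softmax(alpha).
N = 20000
counts = [0] * len(alpha)
for _ in range(N):
    sample = gumbel_softmax(alpha, 0.5, rng)
    counts[sample.index(max(sample))] += 1
print("argmax freq:", [c / N for c in counts])
```

<p>Lower temperatures give samples closer to one-hot vectors but higher-variance gradients, which is why methods that use this trick often anneal <span class="math inline">\(\tau\)</span> during the search.</p>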
</section>
<section id="bayesian-optimization" class="level3">
<h3 class="anchored" data-anchor-id="bayesian-optimization" id="bayesian-optimization">Bayesian Optimization</h3>
<p>Some NAS methods model the architecture performance as a Gaussian process. Given observed architectures and performances <span class="math inline">\(\{(\alpha_i, y_i)\}_{i=1}^n\)</span>, we model:</p>
<p><span class="math display">\[f(\alpha) \sim \mathcal{GP}(\mu(\alpha), k(\alpha, \alpha'))\]</span></p>
<p>The acquisition function guides the search:</p>
<p><span class="math display">\[\alpha_{\text{next}} = \arg \max_\alpha a(\alpha|\{(\alpha_i, y_i)\}_{i=1}^n)\]</span></p>
<p>Common acquisition functions include:</p>
<ul>
<li><strong>Expected Improvement</strong>: <span class="math inline">\(\text{EI}(\alpha) = \mathbb{E}[\max(0, f(\alpha) - f(\alpha_{\text{best}}))]\)</span></li>
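<li><strong>Upper Confidence Bound</strong>: <span class="math inline">\(\text{UCB}(\alpha) = \mu(\alpha) + \beta \cdot \sigma(\alpha)\)</span>, where <span class="math inline">\(\mu\)</span> and <span class="math inline">\(\sigma\)</span> here denote the posterior mean and standard deviation of the GP and <span class="math inline">\(\beta &gt; 0\)</span> trades off exploration against exploitation</li>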
</ul>
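<p>The pieces fit together in a short pure-Python sketch. Several choices here are assumptions rather than part of any specific NAS method: architectures are encoded as one-hot vectors, the kernel is a squared-exponential (RBF) kernel, the prior mean <span class="math inline">\(\mu(\alpha)\)</span> is the constant mean of the observations, and the accuracies are illustrative numbers, not measured results. EI uses its standard closed form for maximization, <span class="math inline">\((\mu - f^*)\Phi(z) + \sigma\phi(z)\)</span> with <span class="math inline">\(z = (\mu - f^*)/\sigma\)</span>.</p>

```python
import math

def rbf(a, b, length_scale=1.0):
    """Squared-exponential kernel k(a, b) between two architecture encodings."""
    d2 = sum((x - y) ** 2 for x, y in zip(a, b))
    return math.exp(-0.5 * d2 / length_scale ** 2)

def cholesky(K):
    """Lower-triangular L with K = L L^T (K must be symmetric positive definite)."""
    n = len(K)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = K[i][j] - sum(L[i][k] * L[j][k] for k in range(j))
            L[i][j] = math.sqrt(s) if i == j else s / L[j][j]
    return L

def forward_sub(L, b):
    """Solve L x = b."""
    x = []
    for i in range(len(b)):
        x.append((b[i] - sum(L[i][k] * x[k] for k in range(i))) / L[i][i])
    return x

def backward_sub_transposed(L, b):
    """Solve L^T x = b."""
    n = len(b)
    x = [0.0] * n
    for i in reversed(range(n)):
        x[i] = (b[i] - sum(L[k][i] * x[k] for k in range(i + 1, n))) / L[i][i]
    return x

def gp_posterior(X, y, x_new, noise=1e-6):
    """GP posterior mean and std at x_new with a constant prior mean mean(y)."""
    prior = sum(y) / len(y)
    K = [[rbf(a, b) + (noise if i == j else 0.0) for j, b in enumerate(X)]
         for i, a in enumerate(X)]
    L = cholesky(K)
    weights = backward_sub_transposed(L, forward_sub(L, [v - prior for v in y]))
    k_star = [rbf(x_new, a) for a in X]
    mean = prior + sum(k * w for k, w in zip(k_star, weights))
    v = forward_sub(L, k_star)
    var = max(rbf(x_new, x_new) - sum(t * t for t in v), 0.0)
    return mean, math.sqrt(var)

def expected_improvement(mean, std, best):
    """Closed-form EI for maximization: (mu - best) Phi(z) + sigma phi(z)."""
    if std < 1e-12:
        return max(0.0, mean - best)
    z = (mean - best) / std
    cdf = 0.5 * math.erfc(-z / math.sqrt(2.0))
    pdf = math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)
    return max(0.0, (mean - best) * cdf + std * pdf)  # clip tiny negative round-off

def ucb(mean, std, beta=2.0):
    return mean + beta * std

# Hypothetical encodings: one-hot op choice on two edges, three ops per edge.
X = [[1, 0, 0, 0, 1, 0], [0, 1, 0, 0, 0, 1], [0, 0, 1, 1, 0, 0]]
y = [0.91, 0.88, 0.93]  # illustrative accuracies, not measured results
candidates = [[1, 0, 0, 1, 0, 0], [0, 1, 0, 0, 1, 0], [0, 0, 1, 0, 0, 1]]
ei = [expected_improvement(*gp_posterior(X, y, c), best=max(y)) for c in candidates]
print("EI per candidate:", [round(v, 5) for v in ei])
print("next architecture:", candidates[ei.index(max(ei))])
```

<p>The small noise term on the kernel diagonal keeps the Cholesky factorization numerically stable and makes the posterior pass almost exactly through the observed points.</p>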
</section>
</section>
<section id="weight-sharing-and-supernets" class="level2">
<h2 class="anchored" data-anchor-id="weight-sharing-and-supernets" id="weight-sharing-and-supernets">Weight Sharing and Supernets</h2>
<section id="one-shot-architecture-search" class="level3">
<h3 class="anchored" data-anchor-id="one-shot-architecture-search" id="one-shot-architecture-search">One-Shot Architecture Search</h3>
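<p>Weight sharing reduces computational cost by training a single “supernet” that contains every candidate architecture as a subgraph. The supernet stores one set of weights <span class="math inline">\(W_k\)</span> per candidate operation, and any sub-architecture that selects operation <span class="math inline">\(k\)</span> reuses those same weights rather than being trained from scratch.</p>
<p>For a mixed operation with normalized architecture weights <span class="math inline">\(\alpha_k\)</span> (for example, the softmax weights of the continuous relaxation), the effective computation is:</p>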
<p><span class="math display">\[\text{output} = \sum_k \alpha_k \cdot \text{op}_k(\text{input}, W_k)\]</span></p>
<p>The challenge is ensuring that shared weights <span class="math inline">\(W_k\)</span> generalize across different architectural contexts.</p>
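<p>The sharing mechanism can be sketched for a single edge. The scalar operations <span class="math inline">\(\text{op}_k(x, W_k) = W_k \cdot g_k(x)\)</span> are hypothetical stand-ins for real layers; the point is that the mixed output and every single-path sub-architecture read the same stored weights, so training one sampled path also changes what the supernet computes.</p>

```python
import math
import random

class SharedEdge:
    """A single supernet edge that stores one weight W_k per candidate op.
    Ops are hypothetical scalar stand-ins: op_k(x, W_k) = W_k * g_k(x)."""

    def __init__(self, nonlinearities, seed=0):
        rng = random.Random(seed)
        self.g = nonlinearities
        self.W = [rng.uniform(-1.0, 1.0) for _ in nonlinearities]  # shared weights

    def op(self, k, x):
        return self.W[k] * self.g[k](x)

    def mixed(self, x, alpha):
        """Supernet output: sum_k alpha_k * op_k(x, W_k), alpha already normalized."""
        return sum(a * self.op(k, x) for k, a in enumerate(alpha))

    def single_path(self, x, k):
        """A discrete sub-architecture that picks op k reuses the shared W_k."""
        return self.op(k, x)

    def train_single_path(self, x, y, k, lr=0.1):
        """One SGD step on (op_k(x) - y)^2 for a sampled path; only W_k changes."""
        err = self.op(k, x) - y
        self.W[k] -= lr * 2.0 * err * self.g[k](x)

edge = SharedEdge([lambda x: x, math.tanh, lambda x: max(x, 0.0)])
alpha, x = [0.2, 0.5, 0.3], 0.8
print(edge.mixed(x, alpha), [edge.single_path(x, k) for k in range(3)])
edge.train_single_path(x, 1.0, k=1)  # update the path that uses op 1
print(edge.mixed(x, alpha))          # the supernet output changes too
```

<p>This coupling is the source of the generalization concern above: a weight tuned while one path is active is reused unchanged by every other architecture that contains the same operation.</p>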
</section>
<section id="progressive-shrinking" class="level3">
<h3 class="anchored" data-anchor-id="progressive-shrinking" id="progressive-shrinking">Progressive Shrinking</h3>
<p>Progressive shrinking gradually reduces the search space by removing poorly-performing operations. The pruning criterion at iteration <span class="math inline">\(t\)</span> is:</p>
<p><span class="math display">\[\text{keep}_k = \begin{cases}
1 &amp; \text{if } \alpha_k^{(t)} &gt; \text{threshold}_t \\
0 &amp; \text{otherwise}
\end{cases}\]</span></p>
<p>Because an operation that has been pruned is never reintroduced, this creates a sequence of nested search spaces: <span class="math inline">\(\mathcal{S}_0 \supseteq \mathcal{S}_1 \supseteq \ldots \supseteq \mathcal{S}_T\)</span>.</p>
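<p>The pruning rule is short to write down. The sketch below assumes that an operation must both have survived the previous iteration and clear the current threshold, which is what keeps the spaces nested; the architecture weights and thresholds are hypothetical numbers chosen to show an operation that clears a later threshold but stays pruned.</p>

```python
def progressive_shrinking(alpha_history, thresholds):
    """alpha_history[t-1] holds alpha^(t) and thresholds[t-1] holds threshold_t.
    Op k survives iteration t only if it survived t-1 and alpha_k^(t) > threshold_t,
    so each search space is a subset of the previous one."""
    kept = set(range(len(alpha_history[0])))  # S_0: every candidate operation
    spaces = [frozenset(kept)]
    for alpha_t, threshold_t in zip(alpha_history, thresholds):
        kept = {k for k in kept if alpha_t[k] > threshold_t}
        spaces.append(frozenset(kept))
    return spaces

# Hypothetical softmax weights for four ops over three pruning iterations.
alpha_history = [
    [0.30, 0.05, 0.40, 0.25],
    [0.35, 0.20, 0.30, 0.15],  # op 1 clears the threshold here but is already pruned
    [0.50, 0.30, 0.15, 0.05],
]
spaces = progressive_shrinking(alpha_history, thresholds=[0.10, 0.18, 0.20])
print([sorted(s) for s in spaces])  # [[0, 1, 2, 3], [0, 2, 3], [0, 2], [0]]
```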
</section>
</section>
<section id="performance-prediction" class="level2">
<h2 class="anchored" data-anchor-id="performance-prediction" id="performance-prediction">Performance Prediction</h2>
<section id="learning-curves-and-extrapolation" class="level3">
<h3 class="anchored" data-anchor-id="learning-curves-and-extrapolation" id="learning-curves-and-extrapolation">Learning Curves and Extrapolation</h3>
<p>Early stopping strategies predict final performance from partial training curves. Common models include:</p>
<ul>
<li><strong>Power Law</strong>: <span class="math inline">\(f(x) = a \cdot x^b + c\)</span></li>
<li><strong>Exponential</strong>: <span class="math inline">\(f(x) = a \cdot e^{-bx} + c\)</span></li>
<li><strong>Logarithmic</strong>: <span class="math inline">\(f(x) = a \cdot \log(x) + b\)</span></li>
</ul>
<p>The parameters are fitted using least squares on early training data, then extrapolated to predict full training performance.</p>
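<p>One simple way to fit these models with least squares is sketched below. The logarithmic model is linear in its parameters and has a closed-form fit; for the power-law and exponential models this sketch grid-searches the nonlinear parameter <span class="math inline">\(b\)</span> and solves for <span class="math inline">\((a, c)\)</span> exactly at each grid point, which is one choice among several (a general nonlinear least-squares solver would also work). The accuracy curve is synthetic.</p>

```python
import math

def fit_linear_1d(phi, y):
    """Closed-form least squares for y ~= a * phi + c."""
    n = len(y)
    mp, my = sum(phi) / n, sum(y) / n
    var = sum((p - mp) ** 2 for p in phi)
    a = sum((p - mp) * (v - my) for p, v in zip(phi, y)) / var
    return a, my - a * mp

def fit_with_grid(x, y, feature, b_grid):
    """Fit f(x) = a * feature(x, b) + c: grid-search the nonlinear parameter b
    and solve the linear parameters (a, c) exactly for each candidate b."""
    best = None
    for b in b_grid:
        phi = [feature(xi, b) for xi in x]
        a, c = fit_linear_1d(phi, y)
        sse = sum((a * p + c - v) ** 2 for p, v in zip(phi, y))
        if best is None or sse < best[0]:
            best = (sse, a, b, c)
    _, a, b, c = best
    return lambda t: a * feature(t, b) + c

def power_law(x, b):    # f(x) = a * x**b + c
    return x ** b

def exponential(x, b):  # f(x) = a * exp(-b x) + c
    return math.exp(-b * x)

def fit_logarithmic(x, y):
    """f(x) = a * log(x) + b is linear in (a, b), so no grid is needed."""
    a, b = fit_linear_1d([math.log(xi) for xi in x], y)
    return lambda t: a * math.log(t) + b

# Synthetic validation-accuracy curve (illustrative only): 0.9 - 0.4 / sqrt(epoch).
epochs = list(range(1, 11))
acc = [0.9 - 0.4 * e ** -0.5 for e in epochs]

power = fit_with_grid(epochs, acc, power_law, [-0.05 * i for i in range(1, 41)])
expo = fit_with_grid(epochs, acc, exponential, [0.05 * i for i in range(1, 41)])
log_fit = fit_logarithmic(epochs, acc)
for name, f in [("power", power), ("exponential", expo), ("log", log_fit)]:
    print(f"{name:12s} predicted accuracy at epoch 50: {f(50):.4f}")
```

<p>The three models can disagree noticeably when extrapolated far beyond the observed epochs, which is why predictions from a single functional form should be treated with caution.</p>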
</section>
<section id="neural-predictors" class="level3">
<h3 class="anchored" data-anchor-id="neural-predictors" id="neural-predictors">Neural Predictors</h3>
<p>Neural networks can predict architecture performance from structural features. Given an architecture encoding <span class="math inline">\(\phi(\alpha)\)</span>, a predictor network estimates:</p>
<p><span class="math display">\[\hat{y} = f_\theta(\phi(\alpha))\]</span></p>
<p>where <span class="math inline">\(\phi(\alpha)\)</span> might include:</p>
<ul>
<li>Graph neural network embeddings</li>
<li>Handcrafted features (depth, width, parameter count)</li>
<li>Learned representations</li>
</ul>
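<p>A minimal predictor built on handcrafted features is sketched below: <span class="math inline">\(\phi(\alpha)\)</span> is depth, width and log parameter count scaled to roughly <span class="math inline">\([0, 1]\)</span>, and <span class="math inline">\(f_\theta\)</span> is a one-hidden-layer tanh network trained by SGD with hand-written backpropagation. The architecture dictionaries and the target formula are made up for illustration; a real predictor would be trained on measured accuracies.</p>

```python
import math
import random

def features(arch):
    """Handcrafted phi(alpha), scaled to roughly [0, 1]: depth, width, log params.
    `arch` is a hypothetical dict describing one architecture."""
    return [arch["depth"] / 20.0, arch["width"] / 256.0,
            math.log10(arch["params"]) / 8.0]

class MLPPredictor:
    """y_hat = f_theta(phi(alpha)) with one tanh hidden layer and a linear output."""

    def __init__(self, n_in, n_hidden=8, seed=0):
        rng = random.Random(seed)
        self.W1 = [[rng.gauss(0.0, 0.5) for _ in range(n_in)] for _ in range(n_hidden)]
        self.b1 = [0.0] * n_hidden
        self.w2 = [rng.gauss(0.0, 0.5) for _ in range(n_hidden)]
        self.b2 = 0.0

    def forward(self, x):
        h = [math.tanh(sum(w * xi for w, xi in zip(row, x)) + b)
             for row, b in zip(self.W1, self.b1)]
        return sum(w * hj for w, hj in zip(self.w2, h)) + self.b2, h

    def sgd_step(self, x, y, lr=0.05):
        """One SGD step on (y_hat - y)^2 with hand-written backpropagation."""
        y_hat, h = self.forward(x)
        d_out = 2.0 * (y_hat - y)
        for j, hj in enumerate(h):
            d_pre = d_out * self.w2[j] * (1.0 - hj * hj)  # back through tanh
            self.w2[j] -= lr * d_out * hj
            for i, xi in enumerate(x):
                self.W1[j][i] -= lr * d_pre * xi
            self.b1[j] -= lr * d_pre
        self.b2 -= lr * d_out

def mse(model, data):
    return sum((model.forward(x)[0] - y) ** 2 for x, y in data) / len(data)

# Synthetic training set: the target formula below is made up for illustration.
rng = random.Random(1)
archs = [{"depth": rng.randint(4, 20), "width": rng.choice([32, 64, 128, 256]),
          "params": rng.uniform(1e5, 1e7)} for _ in range(50)]
data = [(features(a), 0.6 + 0.2 * a["depth"] / 20 + 0.1 * a["width"] / 256)
        for a in archs]

model = MLPPredictor(n_in=3)
initial = mse(model, data)
for epoch in range(200):
    for x, y in data:
        model.sgd_step(x, y)
print(f"train MSE: {initial:.4f} -> {mse(model, data):.6f}")
```

<p>Scaling the features to similar ranges matters more than the network size here; without it, the raw parameter count would dominate the hidden-layer inputs.</p>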
</section>
</section>
<section id="multi-objective-optimization" class="level2">
<h2 class="anchored" data-anchor-id="multi-objective-optimization" id="multi-objective-optimization">Multi-Objective Optimization</h2>
<p>Real-world NAS often involves multiple objectives: accuracy, latency, energy consumption, and memory usage. Expressing accuracy as error, so that every objective is minimized, gives a multi-objective optimization problem:</p>
<p><span class="math display">\[\min F(\alpha) = (f_1(\alpha), f_2(\alpha), \ldots, f_m(\alpha))\]</span></p>
<section id="pareto-optimality" class="level3">
<h3 class="anchored" data-anchor-id="pareto-optimality" id="pareto-optimality">Pareto Optimality</h3>
<p>An architecture <span class="math inline">\(\alpha^*\)</span> is Pareto optimal if there exists no <span class="math inline">\(\alpha\)</span> such that:</p>
<ul>
<li><span class="math inline">\(f_i(\alpha) \leq f_i(\alpha^*)\)</span> for all <span class="math inline">\(i\)</span></li>
<li><span class="math inline">\(f_j(\alpha) &lt; f_j(\alpha^*)\)</span> for at least one <span class="math inline">\(j\)</span></li>
</ul>
<p>The Pareto front represents the set of all Pareto optimal solutions.</p>
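<p>The definition translates directly into a dominance test and a filter. The candidates below are hypothetical (error rate, latency) pairs, both to be minimized; the quadratic-time filter is fine for small candidate sets.</p>

```python
def dominates(fa, fb):
    """fa dominates fb (all objectives minimized): no worse in every objective
    and strictly better in at least one."""
    return (all(a <= b for a, b in zip(fa, fb))
            and any(a < b for a, b in zip(fa, fb)))

def pareto_front(candidates):
    """Names of candidates that no other candidate dominates."""
    return [name for name, f in candidates.items()
            if not any(dominates(g, f) for g in candidates.values())]

# Hypothetical (error rate, latency in ms) pairs, both to be minimized.
candidates = {
    "A": (0.08, 30.0),
    "B": (0.06, 45.0),
    "C": (0.09, 35.0),  # dominated by A
    "D": (0.05, 80.0),
    "E": (0.06, 50.0),  # dominated by B
}
print(pareto_front(candidates))  # ['A', 'B', 'D']
```

<p>No candidate dominates itself, because dominance requires a strict improvement in at least one objective, so the filter does not need to skip the candidate being tested.</p>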
</section>
<section id="scalarization-methods" class="level3">
<h3 class="anchored" data-anchor-id="scalarization-methods" id="scalarization-methods">Scalarization Methods</h3>
<ul>
<li><strong>Weighted Sum</strong>: <span class="math inline">\(\min_\alpha \sum_i w_i \cdot f_i(\alpha)\)</span></li>
<li><strong>ε-Constraint</strong>: <span class="math inline">\(\min_\alpha f_1(\alpha)\)</span> subject to <span class="math inline">\(f_i(\alpha) \leq \varepsilon_i\)</span> for <span class="math inline">\(i &gt; 1\)</span></li>
<li><strong>Chebyshev</strong>: <span class="math inline">\(\min_\alpha \max_i w_i \cdot |f_i(\alpha) - z_i^*|\)</span></li>
</ul>
<p>where <span class="math inline">\(z_i^*\)</span> is the ideal value for objective <span class="math inline">\(i\)</span>.</p>
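<p>A small sketch of the three scalarizations applied to a fixed list of candidates. It assumes two minimized objectives, top-1 error in % and latency in ms, with made-up values. The ideal point <span class="math inline">\(z^*\)</span> is taken as the per-objective minimum over the candidates, and the weights are arbitrary. The weighted sum depends on the objectives' relative scales, which is why latency is down-weighted here.</p>

```python
from typing import List, Optional, Sequence

def weighted_sum(f: Sequence[float], w: Sequence[float]) -> float:
    """sum_i w_i * f_i(alpha)."""
    return sum(wi * fi for wi, fi in zip(w, f))

def chebyshev(f: Sequence[float], w: Sequence[float], z_star: Sequence[float]) -> float:
    """max_i w_i * |f_i(alpha) - z_i*|."""
    return max(wi * abs(fi - zi) for wi, fi, zi in zip(w, f, z_star))

def eps_constraint(F: List[Sequence[float]], eps: Sequence[float]) -> Optional[int]:
    """argmin f_1 subject to f_i <= eps_i for i > 1; None if nothing is feasible."""
    feasible = [k for k, f in enumerate(F) if all(fi <= e for fi, e in zip(f[1:], eps))]
    return min(feasible, key=lambda k: F[k][0]) if feasible else None

# Hypothetical candidates: (top-1 error in %, latency in ms), both minimized.
F = [(8.0, 30.0), (10.0, 12.0), (7.0, 45.0), (11.0, 15.0)]
w = (1.0, 0.1)
z_star = [min(col) for col in zip(*F)]          # ideal point (7.0, 12.0)

best_ws = min(range(len(F)), key=lambda k: weighted_sum(F[k], w))
best_cheb = min(range(len(F)), key=lambda k: chebyshev(F[k], w, z_star))
best_eps = eps_constraint(F, eps=[20.0])        # latency <= 20 ms
print(best_ws, best_cheb, best_eps)             # 0 0 1
```

<p>With these illustrative numbers, the weighted-sum and Chebyshev forms pick the same compromise. The ε-constraint form returns the most accurate candidate under the 20 ms budget.</p>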
</section>
</section>
<section id="complexity-analysis" class="level2">
<h2 class="anchored" data-anchor-id="complexity-analysis" id="complexity-analysis">Complexity Analysis</h2>
<section id="search-space-size" class="level3">
<h3 class="anchored" data-anchor-id="search-space-size" id="search-space-size">Search Space Size</h3>
<p>The size of the discrete search space grows exponentially with the number of architectural decisions. For a search space with:</p>
<ul>
<li><span class="math inline">\(L\)</span> layers</li>
<li><span class="math inline">\(O\)</span> operations per layer</li>
<li><span class="math inline">\(C\)</span> connections per layer</li>
</ul>
<p>The total number of architectures is approximately <span class="math inline">\(O^L \cdot 2^{LC}\)</span>, making exhaustive search intractable for realistic problem sizes.</p>
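<p>Plugging numbers into <span class="math inline">\(O^L \cdot 2^{LC}\)</span> shows how quickly the count explodes. The values of <span class="math inline">\(L\)</span>, <span class="math inline">\(O\)</span>, and <span class="math inline">\(C\)</span> below are illustrative, and the function name is hypothetical.</p>

```python
def search_space_size(L: int, O: int, C: int) -> int:
    """O**L operation assignments times 2**(L*C) on/off connection patterns."""
    return O ** L * 2 ** (L * C)

print(search_space_size(10, 5, 0))   # 9765625 (operations only)
print(search_space_size(10, 5, 2))   # 10240000000000
```

<p>Allowing just two optional connections per layer multiplies the 5<sup>10</sup> operation choices by 2<sup>20</sup>, giving about 10<sup>13</sup> architectures.</p>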
</section>
<section id="computational-complexity" class="level3">
<h3 class="anchored" data-anchor-id="computational-complexity" id="computational-complexity">Computational Complexity</h3>
<p>Different NAS methods have varying computational requirements:</p>
<ul>
<li><strong>Exhaustive Search</strong>: <span class="math inline">\(\mathcal{O}(|\mathcal{S}| \cdot T)\)</span> where <span class="math inline">\(|\mathcal{S}|\)</span> is search space size and <span class="math inline">\(T\)</span> is training time</li>
<li><strong>Gradient-Based</strong>: <span class="math inline">\(\mathcal{O}(K \cdot T)\)</span> where <span class="math inline">\(K\)</span> is number of gradient steps</li>
<li><strong>Evolutionary</strong>: <span class="math inline">\(\mathcal{O}(P \cdot G \cdot T)\)</span> where <span class="math inline">\(P\)</span> is population size and <span class="math inline">\(G\)</span> is number of generations</li>
<li><strong>One-Shot</strong>: <span class="math inline">\(\mathcal{O}(T_{\text{supernet}} + |\mathcal{S}| \cdot T_{\text{eval}})\)</span> where <span class="math inline">\(T_{\text{eval}} \ll T\)</span></li>
</ul>
</section>
</section>
<section id="convergence-analysis" class="level2">
<h2 class="anchored" data-anchor-id="convergence-analysis" id="convergence-analysis">Convergence Analysis</h2>
<section id="darts-convergence" class="level3">
<h3 class="anchored" data-anchor-id="darts-convergence" id="darts-convergence">DARTS Convergence</h3>
<p>For DARTS, convergence depends on the interplay between architecture and weight optimization. The coupled dynamics can be analyzed using:</p>
<p><span class="math display">\[\alpha_{t+1} = \alpha_t - \xi_\alpha \nabla_\alpha L_{\text{val}}(w^*(\alpha_t), \alpha_t)\]</span> <span class="math display">\[w_{t+1} = w_t - \xi_w \nabla_w L_{\text{train}}(w_t, \alpha_t)\]</span></p>
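<p>Under certain conditions (convexity, smoothness), this alternating optimization converges to a stationary point. In practice the lower-level optimum <span class="math inline">\(w^*(\alpha_t)\)</span> is not computed exactly. First-order DARTS replaces it with the current weights <span class="math inline">\(w_t\)</span>, and second-order DARTS replaces it with a single unrolled gradient step. These approximations, together with the bilevel nature and non-convexity of neural networks, make theoretical guarantees challenging.</p>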
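<p>The coupled updates can be observed on a toy problem. This sketch assumes scalar stand-ins for the weights and the architecture parameter, with <span class="math inline">\(L_{\text{train}}(w,\alpha)=\tfrac12(w-\alpha)^2\)</span>, so that <span class="math inline">\(w^*(\alpha)=\alpha\)</span>, and <span class="math inline">\(L_{\text{val}}(w)=\tfrac12(w-2)^2\)</span>. Because <span class="math inline">\(dw^*/d\alpha = 1\)</span> is known analytically here, the <span class="math inline">\(\alpha\)</span>-step uses the exact hypergradient of this toy, evaluated at <span class="math inline">\(w_t\)</span> in place of <span class="math inline">\(w^*(\alpha_t)\)</span>. It illustrates the alternating scheme under convexity and smoothness; it is not the DARTS algorithm itself.</p>

```python
def toy_bilevel(steps: int = 200, xi_w: float = 0.5, xi_a: float = 0.1):
    """Alternating gradient updates on a convex toy bilevel problem (assumed losses).

    L_train(w, a) = 0.5 * (w - a)**2  -> lower-level optimum w*(a) = a
    L_val(w)      = 0.5 * (w - 2)**2  -> best architecture parameter a = 2
    """
    w, a = 0.0, 0.0
    for _ in range(steps):
        grad_a = (w - 2.0) * 1.0     # dL_val(w*(a))/da with w_t standing in for w*(a_t)
        grad_w = w - a               # dL_train/dw at (w_t, a_t)
        a, w = a - xi_a * grad_a, w - xi_w * grad_w   # both updates use step-t values
    return w, a

w, a = toy_bilevel()
print(round(w, 6), round(a, 6))      # 2.0 2.0
```

<p>With step sizes <span class="math inline">\(\xi_w = 0.5\)</span> and <span class="math inline">\(\xi_\alpha = 0.1\)</span>, the linear update map has eigenvalues of roughly 0.86 and 0.64, so both <span class="math inline">\(w\)</span> and <span class="math inline">\(\alpha\)</span> converge to 2. A larger <span class="math inline">\(\xi_\alpha\)</span> can make the same iteration oscillate or diverge.</p>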
</section>
<section id="evolutionary-algorithm-convergence" class="level3">
<h3 class="anchored" data-anchor-id="evolutionary-algorithm-convergence" id="evolutionary-algorithm-convergence">Evolutionary Algorithm Convergence</h3>
<p>For evolutionary NAS, convergence analysis involves studying the transition probabilities between population states. The probability of finding the optimal architecture depends on:</p>
<ul>
<li>Selection pressure</li>
<li>Mutation rates</li>
<li>Population diversity</li>
</ul>
<p>The expected hitting time to the optimum can be bounded using Markov chain analysis.</p>
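<p>A standard toy model makes the hitting-time idea concrete. This sketch assumes an architecture encoded as <span class="math inline">\(n\)</span> binary choices with a single optimum. Fitness is OneMax (the number of correct choices), and the optimizer is a (1+1) evolutionary algorithm with bit-flip probability <span class="math inline">\(1/n\)</span>. A fitness-level argument gives <span class="math inline">\(\mathbb{E}[T] \le e\,n\,H_n = \mathcal{O}(n \log n)\)</span>, because from a state with <span class="math inline">\(i\)</span> correct bits an improving step occurs with probability at least <span class="math inline">\((n-i)/(en)\)</span>. The simulation compares that bound with the empirical mean.</p>

```python
import math
import random

def one_plus_one_ea(n: int, rng: random.Random) -> int:
    """Generations until a (1+1) EA first reaches the all-ones optimum on OneMax."""
    x = [rng.randint(0, 1) for _ in range(n)]
    fitness = sum(x)
    t = 0
    while fitness < n:
        t += 1
        # Standard bit mutation: flip each bit independently with probability 1/n.
        y = [1 - b if rng.random() < 1.0 / n else b for b in x]
        fy = sum(y)
        if fy >= fitness:            # elitist selection: keep the offspring if not worse
            x, fitness = y, fy
    return t

n, trials = 20, 200
rng = random.Random(0)
mean_t = sum(one_plus_one_ea(n, rng) for _ in range(trials)) / trials
bound = math.e * n * sum(1.0 / k for k in range(1, n + 1))   # e * n * H_n
print(f"empirical mean {mean_t:.1f} vs fitness-level bound {bound:.1f}")
```
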
</section>
</section>
<section id="practical-considerations" class="level2">
<h2 class="anchored" data-anchor-id="practical-considerations" id="practical-considerations">Practical Considerations</h2>
<section id="regularization" class="level3">
<h3 class="anchored" data-anchor-id="regularization" id="regularization">Regularization</h3>
<p>Architecture search often requires regularization to prevent overfitting:</p>
<ul>
<li><strong>Dropout on Architecture Parameters</strong>: Randomly zero some <span class="math inline">\(\alpha\)</span> values during training</li>
<li><strong>Weight Decay</strong>: Add L2 penalty <span class="math inline">\(\lambda ||\alpha||^2\)</span> to the loss</li>
<li><strong>Early Stopping</strong>: Stop search when validation performance plateaus</li>
</ul>
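<p>A minimal NumPy sketch of the three regularizers, assuming <span class="math inline">\(\alpha\)</span> is stored as an edges × operations array. Two choices here are assumptions rather than a fixed recipe. The dropout mask zeroes entries without rescaling, a literal reading of the bullet above. Early stopping uses a simple patience counter on validation accuracy. The class and function names are hypothetical.</p>

```python
import numpy as np

def drop_alpha(alpha: np.ndarray, p: float, rng: np.random.Generator) -> np.ndarray:
    """Randomly zero each architecture parameter with probability p (no rescaling)."""
    mask = rng.random(alpha.shape) >= p
    return alpha * mask

def l2_penalty(alpha: np.ndarray, lam: float) -> float:
    """lambda * ||alpha||^2, added to the architecture loss."""
    return float(lam * np.sum(alpha ** 2))

class EarlyStopper:
    """Signal a stop after `patience` checks without improvement above min_delta."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0):
        self.patience, self.min_delta = patience, min_delta
        self.best, self.bad_checks = -np.inf, 0

    def step(self, val_acc: float) -> bool:
        if val_acc > self.best + self.min_delta:
            self.best, self.bad_checks = val_acc, 0
        else:
            self.bad_checks += 1
        return self.bad_checks >= self.patience

rng = np.random.default_rng(0)
alpha = np.array([[0.5, -1.0, 2.0], [1.5, 0.0, -0.5]])   # 2 edges x 3 operations
print(drop_alpha(alpha, p=0.3, rng=rng))
print(l2_penalty(alpha, lam=1e-3))                      # 1e-3 * 7.75
stopper = EarlyStopper(patience=2)
print([stopper.step(v) for v in [0.70, 0.74, 0.75, 0.75, 0.749]])
# [False, False, False, False, True]
```
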
</section>
<section id="search-space-design" class="level3">
<h3 class="anchored" data-anchor-id="search-space-design" id="search-space-design">Search Space Design</h3>
<p>The choice of search space significantly impacts results. Key considerations include:</p>
<ul>
<li><strong>Expressivity</strong>: Can the space represent effective architectures?</li>
<li><strong>Efficiency</strong>: Can the space be searched efficiently?</li>
<li><strong>Inductive Bias</strong>: Does the space encode useful architectural priors?</li>
</ul>
<p>Mathematical analysis of search spaces involves studying their geometric properties, connectivity, and the distribution of high-performing architectures.</p>
</section>
</section>
<section id="future-directions" class="level2">
<h2 class="anchored" data-anchor-id="future-directions" id="future-directions">Future Directions</h2>
<p>Neural Architecture Search continues to evolve, with emerging mathematical frameworks addressing:</p>
<ul>
<li><strong>Theoretical foundations</strong>: Convergence guarantees and optimality conditions</li>
<li><strong>Efficient search</strong>: Better approximation algorithms and search strategies</li>
<li><strong>Transferability</strong>: Mathematical models for cross-domain architecture transfer</li>
<li><strong>Interpretability</strong>: Understanding why certain architectures perform well</li>
</ul>
<p>The mathematical sophistication of NAS continues to grow, drawing from diverse fields including optimization theory, probability, graph theory, and control theory. As the field matures, we expect to see more principled approaches that combine theoretical rigor with practical effectiveness.</p>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>The intersection of discrete optimization, continuous relaxation, and deep learning in NAS represents one of the most mathematically rich areas in modern machine learning, with applications extending far beyond neural network design to general automated algorithm design and meta-learning.</p>



</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Neural Architecture Search: A Comprehensive Guide]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/neural-architecture-search/nas-summary/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/neural-architecture-search/nas-summary/</guid>
      <pubDate>Fri, 11 Jul 2025 00:00:00 GMT</pubDate>
      
      <category>research</category>
      <category>advanced</category>
      <content:encoded><![CDATA[






<section id="neural-architecture-search-a-comprehensive-guide" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/neural-architecture-search/nas-summary/nas.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>Neural Architecture Search (NAS) represents a paradigm shift in deep learning, moving from manual architecture design to automated discovery of optimal neural network structures. This field has emerged as one of the most promising areas in machine learning, addressing the fundamental challenge of designing neural networks that are both effective and efficient for specific tasks.</p>
<p>The traditional approach to neural network design relies heavily on human expertise, intuition, and extensive trial-and-error experimentation. Researchers and practitioners spend considerable time crafting architectures, tuning hyperparameters, and adapting existing designs to new domains. NAS automates this process, using computational methods to explore the vast space of possible architectures and identify designs that achieve superior performance with minimal human intervention.</p>
</section>
<section id="the-architecture-design-challenge" class="level2">
<h2 class="anchored" data-anchor-id="the-architecture-design-challenge" id="the-architecture-design-challenge">The Architecture Design Challenge</h2>
<p>Neural network architecture design involves making numerous interconnected decisions about layer types, connectivity patterns, activation functions, and structural components. The number of possible designs grows exponentially with network depth and with the variety of available operations. Even a simple 10-layer network with 5 possible layer types per position yields <span class="math inline">\(5^{10} = 9{,}765{,}625\)</span> possible architectures, nearly 10 million configurations.</p>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Note
</div>
</div>
<div class="callout-body-container callout-body">
<p>The manual design process typically follows established patterns and heuristics. Researchers begin with proven architectures like ResNet or VGG, then modify them based on domain knowledge and empirical results.</p>
</div>
</div>
<p>This approach has limitations: it’s time-consuming, potentially biased toward human preconceptions, and may miss novel architectural innovations that could significantly improve performance.</p>
</section>
<section id="core-concepts-and-definitions" class="level2">
<h2 class="anchored" data-anchor-id="core-concepts-and-definitions" id="core-concepts-and-definitions">Core Concepts and Definitions</h2>
<section id="search-space" class="level3">
<h3 class="anchored" data-anchor-id="search-space" id="search-space">Search Space</h3>
<p>The search space defines the set of all possible architectures that the NAS algorithm can explore. A well-designed search space balances expressiveness with computational tractability. Common search space formulations include:</p>
<ul>
<li><strong>Cell-based Search Spaces</strong>: These define repeatable computational cells that are stacked to form complete architectures. Each cell contains a directed acyclic graph of operations, with the final architecture determined by the cell structure and stacking pattern.</li>
<li><strong>Macro Search Spaces</strong>: These consider the overall network structure, including the number of layers, layer types, and connectivity patterns across the entire network.</li>
<li><strong>Hierarchical Search Spaces</strong>: These decompose the architecture search into multiple levels, such as searching for optimal cells at one level and optimal cell arrangements at another.</li>
</ul>
</section>
<section id="performance-estimation" class="level3">
<h3 class="anchored" data-anchor-id="performance-estimation" id="performance-estimation">Performance Estimation</h3>
<p>Evaluating architecture performance is computationally expensive, as it typically requires training each candidate architecture to convergence. NAS methods employ various strategies to reduce this computational burden:</p>
<div class="tabset-margin-container"></div><div class="panel-tabset">
<ul class="nav nav-tabs" role="tablist"><li class="nav-item" role="presentation"><a class="nav-link active" id="tabset-1-1-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-1" role="tab" aria-controls="tabset-1-1" aria-selected="true" href="">Proxy Tasks</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-1-2-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-2" role="tab" aria-controls="tabset-1-2" aria-selected="false" href="">Performance Prediction</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-1-3-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-3" role="tab" aria-controls="tabset-1-3" aria-selected="false" href="">Weight Sharing</a></li></ul>
<div class="tab-content">
<div id="tabset-1-1" class="tab-pane active" role="tabpanel" aria-labelledby="tabset-1-1-tab">
<p>Training on simplified versions of the target task, such as using fewer epochs, smaller datasets, or reduced model sizes.</p>
</div>
<div id="tabset-1-2" class="tab-pane" role="tabpanel" aria-labelledby="tabset-1-2-tab">
<p>Using machine learning models to predict architecture performance based on structural features without full training.</p>
</div>
<div id="tabset-1-3" class="tab-pane" role="tabpanel" aria-labelledby="tabset-1-3-tab">
<p>Sharing weights among similar architectural components to reduce training time.</p>
</div>
</div>
</div>
</section>
<section id="search-strategy" class="level3">
<h3 class="anchored" data-anchor-id="search-strategy" id="search-strategy">Search Strategy</h3>
<p>The search strategy determines how the NAS algorithm navigates the search space to find optimal architectures. Different strategies make different trade-offs between exploration and exploitation:</p>
<ul>
<li><strong>Random Search</strong>: Samples architectures uniformly from the search space. While simple, it can be surprisingly effective for well-designed search spaces.</li>
<li><strong>Evolutionary Algorithms</strong>: Use principles of natural selection to evolve populations of architectures over generations.</li>
<li><strong>Reinforcement Learning</strong>: Treats architecture search as a sequential decision-making problem, using RL agents to generate architectures.</li>
<li><strong>Gradient-based Methods</strong>: Relax the discrete search space into a continuous one, enabling gradient-based optimization.</li>
</ul>
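<p>Random search is short enough to show in full. The search space, the scoring function, and all names below are hypothetical. In a real run, <code>evaluate</code> would train the sampled architecture (or a proxy of it) and return validation accuracy.</p>

```python
import random
from typing import Callable, Dict, List

# Hypothetical search space: one categorical choice per decision.
SEARCH_SPACE: Dict[str, List] = {
    "depth": [8, 14, 20],
    "width": [32, 64, 128],
    "op": ["conv3x3", "conv5x5", "sep_conv3x3", "max_pool"],
    "skip": [True, False],
}

def sample_architecture(rng: random.Random) -> Dict:
    """Draw one value uniformly at random for every decision."""
    return {k: rng.choice(v) for k, v in SEARCH_SPACE.items()}

def random_search(evaluate: Callable[[Dict], float], budget: int, seed: int = 0):
    rng = random.Random(seed)
    best_arch, best_score = None, float("-inf")
    for _ in range(budget):
        arch = sample_architecture(rng)
        score = evaluate(arch)           # in practice: (proxy) training + validation
        if score > best_score:
            best_arch, best_score = arch, score
    return best_arch, best_score

def toy_evaluate(arch: Dict) -> float:
    """Made-up stand-in score, for illustration only."""
    return (arch["depth"] / 20 + arch["width"] / 128
            + (arch["op"] == "sep_conv3x3") + 0.5 * arch["skip"])

best, score = random_search(toy_evaluate, budget=50)
print(best, score)
```
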
</section>
</section>
<section id="historical-development" class="level2">
<h2 class="anchored" data-anchor-id="historical-development" id="historical-development">Historical Development</h2>
<p>Neural Architecture Search emerged from the broader fields of evolutionary computation and neuroevolution. Early work in the 1990s explored evolving neural network topologies using genetic algorithms, but computational limitations prevented widespread adoption.</p>
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Important
</div>
</div>
<div class="callout-body-container callout-body">
<p>The modern NAS era began with the 2017 paper “Neural Architecture Search with Reinforcement Learning” by Zoph and Le. This work demonstrated that reinforcement learning could automatically design architectures that matched or exceeded human-designed networks on image classification tasks.</p>
</div>
</div>
<p>Key milestones in NAS development include:</p>
<table class="caption-top table">
<colgroup>
<col style="width: 35%">
<col style="width: 64%">
</colgroup>
<thead>
<tr class="header">
<th>Year</th>
<th>Milestone</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>2017</td>
<td>Introduction of reinforcement learning-based NAS</td>
</tr>
<tr class="even">
<td>2018</td>
<td>Development of Efficient Neural Architecture Search (ENAS) with weight sharing</td>
</tr>
<tr class="odd">
<td>2019</td>
<td>Introduction of differentiable architecture search (DARTS)</td>
</tr>
<tr class="even">
<td>2020</td>
<td>Hardware-aware NAS and multi-objective optimization</td>
</tr>
<tr class="odd">
<td>2021</td>
<td>Zero-shot NAS and training-free performance estimation</td>
</tr>
<tr class="even">
<td>2022</td>
<td>Transformer architecture search and large-scale NAS</td>
</tr>
</tbody>
</table>
</section>
<section id="major-nas-methodologies" class="level2">
<h2 class="anchored" data-anchor-id="major-nas-methodologies" id="major-nas-methodologies">Major NAS Methodologies</h2>
<section id="reinforcement-learning-based-nas" class="level3">
<h3 class="anchored" data-anchor-id="reinforcement-learning-based-nas" id="reinforcement-learning-based-nas">Reinforcement Learning-Based NAS</h3>
<p>Reinforcement learning approaches model architecture search as a sequential decision-making problem. A controller (typically an RNN) generates architecture descriptions by making a sequence of decisions about layer types, connections, and hyperparameters. The controller is trained using reinforcement learning, with the validation accuracy of generated architectures serving as the reward signal.</p>
<p>The original NAS formulation used the REINFORCE algorithm to train the controller. The process involves:</p>
<ol type="1">
<li>The controller samples an architecture from the search space</li>
<li>The architecture is trained on the target task</li>
<li>The validation accuracy provides a reward signal</li>
<li>The controller parameters are updated using policy gradients</li>
</ol>
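<p>The four steps can be sketched with a controller reduced to independent softmax distributions, one per decision, in place of an RNN. A toy reward (the fraction of decisions matching a hidden target) stands in for validation accuracy. The moving-average baseline is a common variance-reduction choice, not necessarily the original configuration. All names and constants are assumptions.</p>

```python
import numpy as np

N_DECISIONS, N_OPTIONS = 3, 4
TARGET = np.array([2, 0, 3])        # hidden "best" architecture (toy assumption)

def reward(arch: np.ndarray) -> float:
    """Stand-in for validation accuracy: fraction of decisions matching TARGET."""
    return float(np.mean(arch == TARGET))

def softmax(z: np.ndarray) -> np.ndarray:
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def train_controller(iters: int = 3000, lr: float = 0.2, decay: float = 0.9, seed: int = 0):
    rng = np.random.default_rng(seed)
    logits = np.zeros((N_DECISIONS, N_OPTIONS))     # controller parameters
    baseline = 0.0
    for _ in range(iters):
        probs = softmax(logits)
        # 1. Sample an architecture: one categorical choice per decision.
        arch = np.array([rng.choice(N_OPTIONS, p=p) for p in probs])
        # 2-3. "Train" it and use the toy score as the reward.
        r = reward(arch)
        # 4. Policy gradient: d log pi(arch) / d logits = onehot(arch) - probs.
        logits += lr * (r - baseline) * (np.eye(N_OPTIONS)[arch] - probs)
        baseline = decay * baseline + (1.0 - decay) * r   # moving-average baseline
    return logits

logits = train_controller()
print(logits.argmax(axis=1))    # most likely architecture under the trained policy
```

<p>Replacing the independent logits with an RNN, so that each choice conditions on the previous ones, recovers the structure of the controller described above.</p>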
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Warning
</div>
</div>
<div class="callout-body-container callout-body">
<p>This approach achieved remarkable results. It discovered a convolutional architecture competitive with the best human-designed networks on CIFAR-10 and a recurrent cell that outperformed prior models on Penn Treebank language modeling. However, the computational cost was enormous: the search consumed roughly 22,400 GPU-days.</p>
</div>
</div>
</section>
<section id="evolutionary-approaches" class="level3">
<h3 class="anchored" data-anchor-id="evolutionary-approaches" id="evolutionary-approaches">Evolutionary Approaches</h3>
<p>Evolutionary methods maintain a population of candidate architectures and evolve them over generations using genetic operators like mutation and crossover. These methods are naturally suited to architecture search because they can handle discrete search spaces and don’t require gradient information.</p>
<p>The evolutionary process typically follows these steps:</p>
<ol type="1">
<li>Initialize a population of random architectures</li>
<li>Evaluate each architecture’s fitness (usually validation accuracy)</li>
<li>Select parents based on fitness scores</li>
<li>Generate offspring using crossover and mutation</li>
<li>Replace the least fit individuals with offspring</li>
<li>Repeat until convergence</li>
</ol>
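<p>The six steps map onto a short loop. This sketch makes several assumptions: the encoding is a fixed-length list of operation indices, parents are chosen by tournament selection, offspring come from single-point crossover, and mutation re-draws each layer at random. The fitness function is a made-up stand-in for validation accuracy, and a fixed generation budget replaces the convergence test.</p>

```python
import random
from typing import List

N_LAYERS, N_OPS = 8, 5
BEST = [1, 4, 0, 2, 3, 3, 0, 1]          # hidden optimum of the toy fitness (assumed)

def fitness(arch: List[int]) -> float:
    """Stand-in for validation accuracy: fraction of layers matching BEST."""
    return sum(a == b for a, b in zip(arch, BEST)) / N_LAYERS

def tournament(pop, scores, rng, k=3):
    """Return the fittest of k randomly chosen individuals."""
    idx = rng.sample(range(len(pop)), k)
    return pop[max(idx, key=lambda i: scores[i])]

def crossover(p1, p2, rng):
    cut = rng.randrange(1, N_LAYERS)      # single-point crossover
    return p1[:cut] + p2[cut:]

def mutate(arch, rng, rate=1.0 / N_LAYERS):
    """Re-draw each layer's operation with probability `rate`."""
    return [rng.randrange(N_OPS) if rng.random() < rate else op for op in arch]

def evolve(pop_size=20, n_offspring=10, generations=200, seed=0):
    rng = random.Random(seed)
    # 1. Initialize a population of random architectures.
    pop = [[rng.randrange(N_OPS) for _ in range(N_LAYERS)] for _ in range(pop_size)]
    for _ in range(generations):
        # 2. Evaluate fitness.
        scores = [fitness(a) for a in pop]
        # 3-4. Tournament-select parents, then apply crossover and mutation.
        children = [mutate(crossover(tournament(pop, scores, rng),
                                     tournament(pop, scores, rng), rng), rng)
                    for _ in range(n_offspring)]
        # 5. Replace the least fit individuals with the offspring.
        keep = sorted(range(pop_size), key=lambda i: scores[i], reverse=True)
        pop = [pop[i] for i in keep[:pop_size - n_offspring]] + children
    # 6. A fixed generation budget stands in for a convergence test.
    return max(pop, key=fitness)

best = evolve()
print(best, fitness(best))
```

<p>Because the top individuals always survive step 5, the best fitness in the population never decreases from one generation to the next.</p>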
<p>Evolutionary approaches offer several advantages: they’re robust to noisy fitness evaluations, can handle multi-objective optimization naturally, and are less likely to get stuck in local optima compared to gradient-based methods.</p>
</section>
<section id="differentiable-architecture-search-darts" class="level3">
<h3 class="anchored" data-anchor-id="differentiable-architecture-search-darts" id="differentiable-architecture-search-darts">Differentiable Architecture Search (DARTS)</h3>
<p>DARTS revolutionized NAS by making the search process differentiable, enabling gradient-based optimization. The key insight is to relax the discrete architecture search into a continuous optimization problem.</p>
<p>In DARTS, instead of selecting a single operation for each edge in the architecture graph, all possible operations are initially included with learnable weights. The architecture is represented as a weighted combination of all operations, with the weights learned through gradient descent.</p>
<p>The DARTS formulation involves:</p>
<ul>
<li><strong>Architecture Parameters</strong>: Weights that determine the importance of each operation</li>
<li><strong>Network Weights</strong>: Standard neural network parameters</li>
<li><strong>Bilevel Optimization</strong>: Alternating between optimizing network weights and architecture parameters</li>
</ul>
<p>After training, the final architecture is obtained by selecting the operation with the highest weight for each edge. This approach reduces search time from thousands of GPU-days to a few GPU-days.</p>
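<p>The continuous relaxation and the final argmax can be sketched for two edges of a cell. The sketch uses toy vector operations in place of convolutions and a hypothetical <code>alpha</code> array with one row of logits per edge. Each edge computes <span class="math inline">\(\bar{o}(x) = \sum_o \mathrm{softmax}(\alpha)_o \, o(x)\)</span>. The DARTS paper additionally excludes the zero operation when deriving the discrete cell, which this sketch does not do.</p>

```python
import numpy as np

# Assumed toy candidate operations on a feature vector.
OPS = {
    "zero": lambda x: np.zeros_like(x),
    "identity": lambda x: x,
    "relu": lambda x: np.maximum(x, 0.0),
    "scale_half": lambda x: 0.5 * x,
}
OP_NAMES = list(OPS)

def softmax(z: np.ndarray) -> np.ndarray:
    e = np.exp(z - z.max())
    return e / e.sum()

def mixed_op(x: np.ndarray, alpha_edge: np.ndarray) -> np.ndarray:
    """Continuous relaxation: softmax(alpha)-weighted sum of every candidate op."""
    weights = softmax(alpha_edge)
    return sum(w * OPS[name](x) for w, name in zip(weights, OP_NAMES))

def discretize(alpha: np.ndarray) -> list:
    """Keep the highest-weighted operation on each edge."""
    return [OP_NAMES[i] for i in np.argmax(alpha, axis=1)]

x = np.array([-1.0, 2.0])
alpha = np.array([[0.1, 2.0, 0.3, -1.0],    # edge 0 favors identity
                  [0.0, 0.2, 1.5, 0.1]])    # edge 1 favors relu
print(mixed_op(x, alpha[0]))
print(discretize(alpha))                    # ['identity', 'relu']
```
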
</section>
<section id="one-shot-architecture-search" class="level3">
<h3 class="anchored" data-anchor-id="one-shot-architecture-search" id="one-shot-architecture-search">One-Shot Architecture Search</h3>
<p>One-shot methods train a single “supernet” that contains all possible architectures in the search space as subnetworks. Once trained, different architectures can be evaluated by sampling subnetworks without additional training.</p>
<p>The supernet approach works by:</p>
<ol type="1">
<li><strong>Supernet Training</strong>: Training a large network that encompasses all candidate architectures</li>
<li><strong>Architecture Sampling</strong>: Evaluating specific architectures by activating corresponding subnetworks</li>
<li><strong>Performance Estimation</strong>: Using the sampled subnetwork’s performance as a proxy for the full architecture’s performance</li>
</ol>
<p>This method dramatically reduces computational cost since it requires training only once. However, it introduces challenges related to weight sharing and potential interference between different architectural paths.</p>
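<p>The key mechanism is that every subnetwork using the same operation at the same layer shares its weights, and this fits in a small class. The linear toy operations, sizes, and names are assumptions. The supernet training loop (for example, updating one uniformly sampled path per step) is omitted, so the ranking printed here reflects random weights only.</p>

```python
import numpy as np

class LinearSupernet:
    """Toy supernet: one shared weight matrix per (layer, parametric op) pair."""

    OPS = ("identity", "linear", "linear_relu")

    def __init__(self, n_layers: int = 3, dim: int = 4, seed: int = 0):
        rng = np.random.default_rng(seed)
        self.n_layers = n_layers
        # weights[layer][op] is shared by every path that picks `op` at `layer`.
        self.weights = [{op: rng.normal(0.0, 0.5, (dim, dim)) for op in self.OPS[1:]}
                        for _ in range(n_layers)]

    def forward(self, x: np.ndarray, path) -> np.ndarray:
        """Run the subnetwork selected by `path` (one op index per layer)."""
        for layer, op_idx in enumerate(path):
            op = self.OPS[op_idx]
            if op == "linear":
                x = x @ self.weights[layer]["linear"]
            elif op == "linear_relu":
                x = np.maximum(x @ self.weights[layer]["linear_relu"], 0.0)
            # "identity" leaves x unchanged.
        return x

    def sample_path(self, rng: np.random.Generator) -> tuple:
        return tuple(int(i) for i in rng.integers(0, len(self.OPS), size=self.n_layers))

def rank_paths(net: LinearSupernet, paths, x: np.ndarray, y: np.ndarray) -> list:
    """Rank subnetworks by MSE using inherited weights, with no per-path training."""
    losses = {p: float(np.mean((net.forward(x, p) - y) ** 2)) for p in paths}
    return sorted(losses, key=losses.get)

rng = np.random.default_rng(1)
net = LinearSupernet()
x, y = rng.normal(size=(16, 4)), rng.normal(size=(16, 4))
candidates = {net.sample_path(rng) for _ in range(10)}
print(rank_paths(net, candidates, x, y)[:3])
```
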
</section>
</section>
<section id="search-space-design" class="level2">
<h2 class="anchored" data-anchor-id="search-space-design" id="search-space-design">Search Space Design</h2>
<section id="cell-based-search-spaces" class="level3">
<h3 class="anchored" data-anchor-id="cell-based-search-spaces" id="cell-based-search-spaces">Cell-Based Search Spaces</h3>
<p>Cell-based search spaces focus on finding optimal computational cells that can be stacked to form complete architectures. This approach reduces the search space size while maintaining architectural diversity.</p>
<p>A typical cell contains:</p>
<ul>
<li><strong>Input Nodes</strong>: Receive inputs from previous cells or external sources</li>
<li><strong>Intermediate Nodes</strong>: Apply operations to transform inputs</li>
<li><strong>Output Nodes</strong>: Combine intermediate results to produce cell outputs</li>
</ul>
<p>The cell structure is defined by:</p>
<ul>
<li><strong>Operations</strong>: Convolutions, pooling, skip connections, etc.</li>
<li><strong>Connections</strong>: How nodes are connected within the cell</li>
<li><strong>Combination Functions</strong>: How multiple inputs to a node are combined</li>
</ul>
<p>Popular cell-based search spaces include:</p>
<ul>
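<li><strong>NASNet Search Space</strong>: Introduced by Zoph et al. (2018) in the follow-up to the original NAS paper; it searches for a normal cell and a reduction cell</li>
<li><strong>DARTS Search Space</strong>: Simplified version focusing on common operations</li>
<li><strong>PC-DARTS</strong>: Reuses the DARTS search space but samples only a subset of channels on each edge (partial channel connections) to reduce memory during search</li>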
</ul>
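<p>A cell genotype can be stored as a list of edges per intermediate node. The sketch below assumes a DARTS-style layout: two input nodes, intermediate nodes that each sum two transformed predecessors, and an output node that concatenates the intermediate nodes. The toy vector operations and names are hypothetical. The helper <code>num_cells</code> counts genotypes under the assumption that every intermediate node picks two distinct predecessors and one operation per edge.</p>

```python
import numpy as np
from typing import List, Tuple

# Assumed toy operations on vectors; real cells apply convolutions and pooling to feature maps.
OPS = {
    "identity": lambda h: h,
    "relu": lambda h: np.maximum(h, 0.0),
    "neg": lambda h: -h,
}

# For each intermediate node: two (input_node_index, op_name) edges.
# Nodes 0 and 1 are the cell inputs (outputs of the two previous cells).
Genotype = List[Tuple[Tuple[int, str], Tuple[int, str]]]

def cell_forward(genotype: Genotype, in0: np.ndarray, in1: np.ndarray) -> np.ndarray:
    states = [in0, in1]
    for (i, op_a), (j, op_b) in genotype:
        assert i < len(states) and j < len(states), "edges must point to earlier nodes"
        # Combination function: sum the two transformed inputs.
        states.append(OPS[op_a](states[i]) + OPS[op_b](states[j]))
    # Output node: concatenate all intermediate nodes.
    return np.concatenate(states[2:])

def num_cells(n_intermediate: int, n_ops: int) -> int:
    """Genotype count when node k picks 2 distinct predecessors and an op per edge."""
    total = 1
    for k in range(n_intermediate):
        n_pred = k + 2
        total *= (n_pred * (n_pred - 1) // 2) * n_ops ** 2
    return total

genotype = [((0, "identity"), (1, "relu")),
            ((2, "neg"), (0, "relu")),
            ((3, "identity"), (1, "identity"))]
out = cell_forward(genotype, np.array([1.0, -2.0]), np.array([-1.0, 3.0]))
print(out)                 # [ 1.  1.  0. -1. -1.  2.]
print(num_cells(4, 8))     # 3019898880
```

<p>Even this restricted layout gives about 3 × 10<sup>9</sup> distinct cells with 4 intermediate nodes and 8 candidate operations.</p>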
</section>
<section id="macro-search-spaces" class="level3">
<h3 class="anchored" data-anchor-id="macro-search-spaces" id="macro-search-spaces">Macro Search Spaces</h3>
<p>Macro search spaces consider the overall network structure, including decisions about:</p>
<ul>
<li><strong>Network Depth</strong>: Total number of layers</li>
<li><strong>Layer Types</strong>: Convolution, pooling, normalization, activation</li>
<li><strong>Channel Dimensions</strong>: Number of filters in each layer</li>
<li><strong>Skip Connections</strong>: Long-range connections between layers</li>
</ul>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Tip
</div>
</div>
<div class="callout-body-container callout-body">
<p>Macro search is more challenging than cell-based search because:</p>
<ul>
<li>The search space is typically much larger</li>
<li>Architectural decisions are more interdependent</li>
<li>Performance evaluation is more expensive</li>
</ul>
</div>
</div>
</section>
<section id="hierarchical-search-spaces" class="level3">
<h3 class="anchored" data-anchor-id="hierarchical-search-spaces" id="hierarchical-search-spaces">Hierarchical Search Spaces</h3>
<p>Hierarchical approaches decompose architecture search into multiple levels:</p>
<ul>
<li><strong>Level 1</strong>: Micro-architecture search (within cells)</li>
<li><strong>Level 2</strong>: Macro-architecture search (cell arrangement)</li>
<li><strong>Level 3</strong>: Network-level search (overall structure)</li>
</ul>
<p>This decomposition allows for:</p>
<ul>
<li>More efficient search by reducing complexity at each level</li>
<li>Better generalization across different tasks</li>
<li>Modular design that can be adapted to various domains</li>
</ul>
</section>
</section>
<section id="performance-estimation-strategies" class="level2">
<h2 class="anchored" data-anchor-id="performance-estimation-strategies" id="performance-estimation-strategies">Performance Estimation Strategies</h2>
<section id="proxy-tasks-1" class="level3">
<h3 class="anchored" data-anchor-id="proxy-tasks-1" id="proxy-tasks-1">Proxy Tasks</h3>
<p>Proxy tasks reduce evaluation cost by training on simplified versions of the target problem:</p>
<ul>
<li><strong>Reduced Epochs</strong>: Training for fewer iterations to get approximate performance</li>
<li><strong>Smaller Datasets</strong>: Using subsets of the training data</li>
<li><strong>Lower Resolution</strong>: Reducing image size or sequence length</li>
<li><strong>Fewer Channels</strong>: Using narrower networks during search</li>
</ul>
<p>The effectiveness of proxy tasks depends on:</p>
<ul>
<li><strong>Rank Correlation</strong>: How well proxy performance predicts full performance</li>
<li><strong>Computational Savings</strong>: The reduction in training time</li>
<li><strong>Task Similarity</strong>: How closely the proxy resembles the target task</li>
</ul>
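<p>Rank correlation between proxy and full-training scores is a direct way to quantify the first criterion. Below is a dependency-free sketch of Spearman's <span class="math inline">\(\rho\)</span> that assumes no tied scores. The accuracies are made up.</p>

```python
from typing import List, Sequence

def ranks(values: Sequence[float]) -> List[int]:
    """Rank positions (0 = smallest); assumes no ties for simplicity."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    r = [0] * len(values)
    for rank, i in enumerate(order):
        r[i] = rank
    return r

def spearman(proxy: Sequence[float], full: Sequence[float]) -> float:
    """Spearman rho = 1 - 6 * sum(d^2) / (n (n^2 - 1)), valid when there are no ties."""
    n = len(proxy)
    d2 = sum((a - b) ** 2 for a, b in zip(ranks(proxy), ranks(full)))
    return 1 - 6 * d2 / (n * (n ** 2 - 1))

# Hypothetical accuracies of five architectures: short proxy run vs full training.
proxy = [0.61, 0.58, 0.66, 0.55, 0.63]
full = [0.92, 0.90, 0.94, 0.91, 0.93]
print(spearman(proxy, full))   # 0.9
```

<p>Here the proxy swaps only the two weakest architectures, so ρ stays high. The proxy would still select the same best architecture.</p>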
</section>
<section id="weight-sharing-1" class="level3">
<h3 class="anchored" data-anchor-id="weight-sharing-1" id="weight-sharing-1">Weight Sharing</h3>
<p>Weight sharing reduces training time by reusing parameters across similar architectural components:</p>
<ul>
<li><strong>Parameter Inheritance</strong>: New architectures inherit weights from previously trained models</li>
<li><strong>Shared Backbones</strong>: Common layers share parameters across different architectures</li>
<li><strong>Progressive Training</strong>: Gradually building up architectures while sharing lower-level weights</li>
</ul>
<p>Challenges with weight sharing include:</p>
<ul>
<li><strong>Interference</strong>: Different architectures may require conflicting parameter values</li>
<li><strong>Bias</strong>: Shared weights may favor certain architectural patterns</li>
<li><strong>Optimization</strong>: Balancing individual architecture performance with shared efficiency</li>
</ul>
</section>
<section id="performance-prediction-1" class="level3">
<h3 class="anchored" data-anchor-id="performance-prediction-1" id="performance-prediction-1">Performance Prediction</h3>
<p>Machine learning models can predict architecture performance without full training:</p>
<ul>
<li><strong>Feature Engineering</strong>: Extracting architectural features (depth, width, connectivity)</li>
<li><strong>Graph Neural Networks</strong>: Using GNNs to encode architectural structure</li>
<li><strong>Surrogate Models</strong>: Training regression models on architecture-performance pairs</li>
</ul>
<p>Key considerations:</p>
<ul>
<li><strong>Training Data</strong>: Sufficient architecture-performance pairs for training</li>
<li><strong>Generalization</strong>: Ability to predict performance on unseen architectures</li>
<li><strong>Computational Cost</strong>: Prediction should be much faster than full training</li>
</ul>
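<p>The simplest surrogate is a linear model on architectural features. This sketch fits closed-form ridge regression on synthetic data; the features, coefficients, and noise level are invented. In practice the features would be depth, width, connectivity statistics, or graph embeddings of architectures that have already been evaluated.</p>

```python
import numpy as np

def fit_ridge(X: np.ndarray, y: np.ndarray, lam: float = 1e-2) -> np.ndarray:
    """Closed-form ridge regression with an unpenalized bias column."""
    Xb = np.hstack([X, np.ones((len(X), 1))])
    reg = lam * np.eye(Xb.shape[1])
    reg[-1, -1] = 0.0                                  # do not penalize the bias
    return np.linalg.solve(Xb.T @ Xb + reg, Xb.T @ y)

def predict(theta: np.ndarray, X: np.ndarray) -> np.ndarray:
    return np.hstack([X, np.ones((len(X), 1))]) @ theta

rng = np.random.default_rng(0)
X = rng.uniform(0.0, 1.0, size=(200, 3))               # hypothetical normalized features
true_w = np.array([0.05, 0.10, -0.03])
y = 0.70 + X @ true_w + rng.normal(0.0, 0.005, 200)    # synthetic accuracies
theta = fit_ridge(X[:150], y[:150])                    # train on 150 "evaluated" archs
test_mse = np.mean((predict(theta, X[150:]) - y[150:]) ** 2)
print(theta.round(3), test_mse)
```
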
</section>
</section>
<section id="hardware-aware-nas" class="level2">
<h2 class="anchored" data-anchor-id="hardware-aware-nas" id="hardware-aware-nas">Hardware-Aware NAS</h2>
<section id="motivation" class="level3">
<h3 class="anchored" data-anchor-id="motivation" id="motivation">Motivation</h3>
<p>Modern deployment scenarios require architectures that are not only accurate but also efficient in terms of:</p>
<ul>
<li><strong>Latency</strong>: Inference time on target hardware</li>
<li><strong>Energy Consumption</strong>: Power usage during operation</li>
<li><strong>Memory Usage</strong>: RAM and storage requirements</li>
<li><strong>Throughput</strong>: Number of samples processed per second</li>
</ul>
<p>Traditional NAS methods optimize primarily for accuracy, often producing architectures that are impractical for deployment. Hardware-aware NAS addresses this by incorporating efficiency metrics into the search process.</p>
</section>
<section id="multi-objective-optimization" class="level3">
<h3 class="anchored" data-anchor-id="multi-objective-optimization" id="multi-objective-optimization">Multi-Objective Optimization</h3>
<p>Hardware-aware NAS typically involves multiple, often conflicting objectives:</p>
<ul>
<li><strong>Accuracy</strong>: Model performance on the target task</li>
<li><strong>Efficiency</strong>: Hardware-specific metrics (latency, energy, memory)</li>
<li><strong>Size</strong>: Model parameter count and storage requirements</li>
</ul>
<p>Common approaches include:</p>
<ul>
<li><strong>Pareto-optimal Search</strong>: Finding architectures that represent optimal trade-offs</li>
<li><strong>Weighted Objectives</strong>: Combining multiple metrics into a single score</li>
<li><strong>Constraint-based Search</strong>: Searching within efficiency constraints</li>
</ul>
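<p>One well-known weighted objective is the soft latency-constrained reward from MnasNet, <span class="math inline">\(\text{ACC}(m) \cdot [\text{LAT}(m)/T]^{w}\)</span>. Here <span class="math inline">\(T\)</span> is the target latency, and <span class="math inline">\(w\)</span> takes one value when the latency is within target and another when it is over. The default exponents below follow the value of -0.07 reported in that paper, but treat them as tunable. The candidate numbers are made up.</p>

```python
def mnas_reward(acc: float, latency_ms: float, target_ms: float,
                alpha: float = -0.07, beta: float = -0.07) -> float:
    """Soft latency-constrained objective: ACC * (LAT / T) ** w."""
    w = alpha if latency_ms <= target_ms else beta
    return acc * (latency_ms / target_ms) ** w

candidates = [(0.76, 80.0), (0.75, 60.0), (0.77, 120.0)]   # (accuracy, latency ms)
scores = [mnas_reward(acc, lat, target_ms=75.0) for acc, lat in candidates]
print([round(s, 4) for s in scores])
```

<p>With a negative exponent, slower models are penalized smoothly instead of being rejected outright. In this example the 60 ms model scores highest despite its lower raw accuracy.</p>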
</section>
<section id="platform-specific-considerations" class="level3">
<h3 class="anchored" data-anchor-id="platform-specific-considerations" id="platform-specific-considerations">Platform-Specific Considerations</h3>
<p>Different hardware platforms have unique characteristics that affect architecture performance:</p>
<table class="caption-top table">
<colgroup>
<col style="width: 26%">
<col style="width: 42%">
<col style="width: 31%">
</colgroup>
<thead>
<tr class="header">
<th>Platform</th>
<th>Characteristics</th>
<th>Priorities</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>Mobile Devices</td>
<td>Limited memory and battery life</td>
<td>Efficiency</td>
</tr>
<tr class="even">
<td>Edge Devices</td>
<td>Extreme resource constraints</td>
<td>Real-time performance</td>
</tr>
<tr class="odd">
<td>Cloud GPUs</td>
<td>High throughput capabilities</td>
<td>Parallel processing</td>
</tr>
<tr class="even">
<td>Specialized Hardware</td>
<td>TPUs, FPGAs, custom accelerators</td>
<td>Optimized operations</td>
</tr>
</tbody>
</table>
</section>
<section id="latency-prediction" class="level3">
<h3 class="anchored" data-anchor-id="latency-prediction" id="latency-prediction">Latency Prediction</h3>
<p>Accurate latency prediction is crucial for hardware-aware NAS:</p>
<ul>
<li><strong>Direct Measurement</strong>: Running architectures on target hardware</li>
<li><strong>Analytical Models</strong>: Using theoretical models based on operation counts</li>
<li><strong>Learned Predictors</strong>: Training models to predict latency from architectural features</li>
</ul>
<p>Challenges include:</p>
<ul>
<li><strong>Hardware Variability</strong>: Different devices have different performance characteristics</li>
<li><strong>Optimization Effects</strong>: Compiler optimizations can significantly affect performance</li>
<li><strong>Batch Size Dependencies</strong>: Latency often varies with batch size</li>
</ul>
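<p>Two of these strategies fit in a few lines. The lookup table below holds hypothetical per-operator latencies and sums them. This additive approximation ignores operator fusion and compiler effects, which is the optimization-effects challenge listed above. The measurement helper times any callable after warm-up runs and reports the median to damp noise.</p>

```python
import statistics
import time
from typing import Callable, Dict, List

# Hypothetical per-operator latencies (ms), profiled once on the target device.
LATENCY_LUT: Dict[str, float] = {
    "conv3x3_c32": 0.42,
    "conv5x5_c32": 0.95,
    "sep_conv3x3_c32": 0.21,
    "max_pool": 0.05,
}

def lut_latency(ops: List[str]) -> float:
    """Additive estimate: network latency ~ sum of per-op latencies."""
    return sum(LATENCY_LUT[op] for op in ops)

def measure_latency_ms(fn: Callable[[], object], warmup: int = 10, repeats: int = 50) -> float:
    """Direct measurement: median wall-clock time of fn() after warm-up runs."""
    for _ in range(warmup):
        fn()
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        times.append((time.perf_counter() - start) * 1e3)
    return statistics.median(times)

arch = ["conv3x3_c32", "sep_conv3x3_c32", "max_pool", "conv5x5_c32"]
print(f"LUT estimate: {lut_latency(arch):.2f} ms")        # 1.63 ms
print(f"measured toy workload: {measure_latency_ms(lambda: sum(range(10_000))):.3f} ms")
```
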
</section>
</section>
<section id="applications-across-domains" class="level2">
<h2 class="anchored" data-anchor-id="applications-across-domains" id="applications-across-domains">Applications Across Domains</h2>
<section id="computer-vision" class="level3">
<h3 class="anchored" data-anchor-id="computer-vision" id="computer-vision">Computer Vision</h3>
<p>NAS has achieved remarkable success in computer vision tasks:</p>
<ul>
<li><strong>Image Classification</strong>: Discovering architectures that outperform ResNet and other human-designed networks</li>
<li><strong>Object Detection</strong>: Finding efficient architectures for real-time detection systems</li>
<li><strong>Semantic Segmentation</strong>: Optimizing architectures for dense prediction tasks</li>
<li><strong>Image Generation</strong>: Searching for GAN architectures with improved stability and quality</li>
</ul>
<p>Notable achievements:</p>
<ul>
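<li><strong>EfficientNet</strong>: Used NAS to find a compact baseline network (EfficientNet-B0), then scaled it up with a compound scaling rule, reaching state-of-the-art ImageNet accuracy at the time with far fewer parameters than comparable models</li>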
<li><strong>NAS-FPN</strong>: Improved object detection performance through architecture search</li>
<li><strong>Auto-DeepLab</strong>: Automated architecture search for semantic segmentation</li>
</ul>
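<p>EfficientNet's scaling step after the search can be written down concretely. Depth, width, and input resolution are scaled together as α<sup>φ</sup>, β<sup>φ</sup>, and γ<sup>φ</sup>, with α·β²·γ² ≈ 2, so FLOPs roughly double for each unit increase of φ. The sketch below uses the coefficients reported in the paper (α = 1.2, β = 1.1, γ = 1.15). It reproduces the idealized rule; the released B1 to B7 configurations use hand-adjusted multipliers and may not match it exactly.</p>

```python
# Scaling coefficients reported in the EfficientNet paper for depth, width and
# input resolution, found by a small grid search with phi fixed to 1.
ALPHA, BETA, GAMMA = 1.2, 1.1, 1.15


def compound_scale(phi):
    """Return (depth, width, resolution) multipliers for compound coefficient phi."""
    return ALPHA ** phi, BETA ** phi, GAMMA ** phi


def flops_multiplier(phi):
    """Approximate FLOPs growth: FLOPs scale with depth * width^2 * resolution^2."""
    depth, width, resolution = compound_scale(phi)
    return depth * width ** 2 * resolution ** 2


for phi in range(4):
    d, w, r = compound_scale(phi)
    print(f"phi={phi}: depth x{d:.2f}, width x{w:.2f}, "
          f"resolution x{r:.2f}, FLOPs x{flops_multiplier(phi):.2f}")
```

<p>With these coefficients, α·β²·γ² = 1.92, which is close to the target factor of 2.</p>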
</section>
<section id="natural-language-processing" class="level3">
<h3 class="anchored" data-anchor-id="natural-language-processing" id="natural-language-processing">Natural Language Processing</h3>
<p>NAS applications in NLP have focused on:</p>
<ul>
<li><strong>Language Modeling</strong>: Finding efficient architectures for sequence modeling</li>
<li><strong>Machine Translation</strong>: Optimizing encoder-decoder architectures</li>
<li><strong>Text Classification</strong>: Discovering architectures for various NLP tasks</li>
<li><strong>Question Answering</strong>: Searching for architectures that can effectively reason over text</li>
</ul>
<p>Key developments:</p>
<ul>
<li><strong>Evolved Transformer</strong>: Used evolutionary search to improve Transformer architectures</li>
<li><strong>NASH</strong>: Applied NAS to find efficient architectures for language understanding</li>
<li><strong>AutoML for NLP</strong>: Automated architecture search for various NLP tasks</li>
</ul>
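<p>The Evolved Transformer search used tournament-selection evolution. The sketch below shows a generic version of that family, regularized (aging) evolution as popularized by AmoebaNet, on a toy search space. The operation list and <code>toy_fitness</code> are hypothetical stand-ins. A real search would train and validate each candidate instead, and the Evolved Transformer added refinements such as progressive dynamic hurdles that are not shown here.</p>

```python
import collections
import random

# Hypothetical per-layer operation choices for a toy sequence-model search space.
OPS = ["conv1x1", "conv3x3", "sep_conv7x1", "self_attention", "gated_linear"]
NUM_LAYERS = 6


def random_architecture(rng):
    return tuple(rng.choice(OPS) for _ in range(NUM_LAYERS))


def mutate(arch, rng):
    """Replace the operation at one randomly chosen layer with a different one."""
    layer = rng.randrange(len(arch))
    new_op = rng.choice([op for op in OPS if op != arch[layer]])
    return arch[:layer] + (new_op,) + arch[layer + 1:]


def toy_fitness(arch):
    """Hypothetical proxy score; a real search would train and evaluate the model."""
    target = ("self_attention", "conv3x3", "gated_linear") * 2
    return sum(a == b for a, b in zip(arch, target))


def regularized_evolution(cycles=300, population_size=20, sample_size=5, seed=0):
    rng = random.Random(seed)
    population = collections.deque()
    for _ in range(population_size):
        arch = random_architecture(rng)
        population.append((arch, toy_fitness(arch)))
    history = list(population)
    for _ in range(cycles):
        # Tournament selection: the fittest of a random sample becomes the parent.
        candidates = rng.sample(list(population), sample_size)
        parent = max(candidates, key=lambda item: item[1])
        child = mutate(parent[0], rng)
        entry = (child, toy_fitness(child))
        population.append(entry)
        population.popleft()  # Aging: remove the oldest member, not the worst.
        history.append(entry)
    return max(history, key=lambda item: item[1])


best_arch, best_score = regularized_evolution()
print(best_score, best_arch)
```

<p>Removing the oldest individual rather than the worst gives every architecture a limited lifespan, so only architectures that keep producing good offspring survive in the population.</p>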
</section>
<section id="speech-recognition" class="level3">
<h3 class="anchored" data-anchor-id="speech-recognition" id="speech-recognition">Speech Recognition</h3>
<p>Speech recognition presents unique challenges for NAS:</p>
<ul>
<li><strong>Temporal Modeling</strong>: Architectures must effectively capture temporal dependencies</li>
<li><strong>Computational Constraints</strong>: Real-time processing requirements</li>
<li><strong>Robustness</strong>: Handling various acoustic conditions and speaking styles</li>
</ul>
<p>Applications include:</p>
<ul>
<li><strong>Automatic Speech Recognition</strong>: Finding efficient architectures for speech-to-text</li>
<li><strong>Speaker Recognition</strong>: Optimizing architectures for speaker identification</li>
<li><strong>Speech Enhancement</strong>: Searching for architectures that can improve audio quality</li>
</ul>
</section>
<section id="recommendation-systems" class="level3">
<h3 class="anchored" data-anchor-id="recommendation-systems" id="recommendation-systems">Recommendation Systems</h3>
<p>NAS has been applied to recommendation systems for:</p>
<ul>
<li><strong>Feature Interaction</strong>: Finding optimal ways to combine user and item features</li>
<li><strong>Embedding Architectures</strong>: Optimizing embedding dimensions and structures</li>
<li><strong>Multi-Task Learning</strong>: Balancing multiple recommendation objectives</li>
</ul>
<p>Challenges specific to recommendation systems:</p>
<ul>
<li><strong>Large-Scale Data</strong>: Handling massive user-item interaction datasets</li>
<li><strong>Cold Start</strong>: Dealing with new users and items</li>
<li><strong>Interpretability</strong>: Maintaining explainable recommendation decisions</li>
</ul>
</section>
</section>
<section id="challenges-and-limitations" class="level2">
<h2 class="anchored" data-anchor-id="challenges-and-limitations" id="challenges-and-limitations">Challenges and Limitations</h2>
<section id="computational-cost" class="level3">
<h3 class="anchored" data-anchor-id="computational-cost" id="computational-cost">Computational Cost</h3>
<p>Despite significant progress, NAS remains computationally expensive:</p>
<ul>
<li><strong>Search Time</strong>: Finding strong architectures can take days or weeks of compute</li>
<li><strong>Hardware Requirements</strong>: Large searches need substantial computational resources, often many GPUs running in parallel</li>
<li><strong>Energy Consumption</strong>: Extensive architecture searches consume considerable energy and carry a correspondingly high carbon footprint</li>
</ul>
<p>Mitigation strategies include:</p>
<ul>
<li><strong>Efficient Search Methods</strong>: Developing faster search algorithms</li>
<li><strong>Better Performance Estimation</strong>: Reducing evaluation cost</li>
<li><strong>Transfer Learning</strong>: Reusing search results across similar tasks</li>
</ul>
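<p>Successive halving is one widely used way to reduce evaluation cost. It trains many candidates briefly, keeps only the best fraction, and gives the survivors more budget. The sketch below assumes a higher-is-better score. <code>toy_evaluate</code> is a hypothetical learning curve standing in for partial training; the budget could be epochs, steps, or a fraction of the data.</p>

```python
import math
import random


def successive_halving(configs, evaluate, min_budget=1, eta=3):
    """Return the configuration that survives successive halving.

    Each round evaluates the survivors with the current budget, keeps the top
    1/eta of them (higher score is better) and multiplies the budget by eta.
    """
    survivors = list(configs)
    budget = min_budget
    while len(survivors) > 1:
        ranked = sorted(survivors, key=lambda cfg: evaluate(cfg, budget), reverse=True)
        survivors = ranked[: max(1, len(survivors) // eta)]
        budget *= eta
    return survivors[0]


def toy_evaluate(config, budget):
    """Hypothetical learning curve: accuracy approaches the config's quality as budget grows."""
    return config["quality"] * (1.0 - math.exp(-budget / 5.0))


rng = random.Random(0)
configs = [{"id": i, "quality": rng.random()} for i in range(27)]
best = successive_halving(configs, toy_evaluate)
print(best)
```

<p>With 27 configurations and eta = 3, the rounds use budgets 1, 3, and 9, for a total of 27 + 27 + 27 = 81 budget units. Training every configuration at the largest budget would cost 27 × 9 = 243 units. The savings depend on early rankings being predictive of final ones, which real learning curves do not always satisfy.</p>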
</section>
<section id="search-space-bias" class="level3">
<h3 class="anchored" data-anchor-id="search-space-bias" id="search-space-bias">Search Space Bias</h3>
<p>The design of search spaces introduces inherent biases:</p>
<ul>
<li><strong>Human Bias</strong>: Search spaces reflect human assumptions about good architectures</li>
<li><strong>Limited Diversity</strong>: Constrained search spaces may miss innovative designs</li>
<li><strong>Task Specificity</strong>: Search spaces designed for one task may not generalize</li>
</ul>
</section>
<section id="reproducibility" class="level3">
<h3 class="anchored" data-anchor-id="reproducibility" id="reproducibility">Reproducibility</h3>
<p>NAS research faces significant reproducibility challenges:</p>
<ul>
<li><strong>Computational Requirements</strong>: Not all researchers have access to required resources</li>
<li><strong>Implementation Details</strong>: Many important details are often omitted from papers</li>
<li><strong>Evaluation Protocols</strong>: Inconsistent evaluation methods across studies</li>
</ul>
</section>
<section id="generalization" class="level3">
<h3 class="anchored" data-anchor-id="generalization" id="generalization">Generalization</h3>
<p>Architectures found by NAS may not generalize well:</p>
<ul>
<li><strong>Task Transfer</strong>: Architectures optimized for one task may not work well on others</li>
<li><strong>Dataset Dependence</strong>: Performance may not transfer to different datasets</li>
<li><strong>Scale Sensitivity</strong>: Architectures may not scale to different problem sizes</li>
</ul>
</section>
</section>
<section id="recent-advances-and-future-directions" class="level2">
<h2 class="anchored" data-anchor-id="recent-advances-and-future-directions" id="recent-advances-and-future-directions">Recent Advances and Future Directions</h2>
<section id="zero-shot-nas" class="level3">
<h3 class="anchored" data-anchor-id="zero-shot-nas" id="zero-shot-nas">Zero-Shot NAS</h3>
<p>Zero-shot NAS aims to evaluate architectures without training:</p>
<ul>
<li><strong>Architecture Encoders</strong>: Using graph neural networks to encode architectural structure</li>
<li><strong>Performance Predictors</strong>: Training models to predict performance from structure alone</li>
<li><strong>Gradient-Based Metrics</strong>: Using gradient information to assess architecture quality</li>
</ul>
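<p>Because such proxies need at most a few forward and backward passes (or a single query to a trained predictor), this approach promises to remove most of the training cost from the search, making NAS accessible to researchers with limited computational resources. How well these proxies correlate with accuracy after full training, however, varies across search spaces and tasks.</p>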
</section>
<section id="automated-machine-learning-automl" class="level3">
<h3 class="anchored" data-anchor-id="automated-machine-learning-automl" id="automated-machine-learning-automl">Automated Machine Learning (AutoML)</h3>
<p>NAS is increasingly integrated into broader AutoML systems:</p>
<ul>
<li><strong>End-to-End Automation</strong>: Combining architecture search with hyperparameter optimization</li>
<li><strong>Data Preprocessing</strong>: Jointly optimizing data augmentation and architecture</li>
<li><strong>Model Selection</strong>: Automatically choosing between different model families</li>
</ul>
</section>
<section id="federated-nas" class="level3">
<h3 class="anchored" data-anchor-id="federated-nas" id="federated-nas">Federated NAS</h3>
<p>Federated learning scenarios present new challenges for NAS:</p>
<ul>
<li><strong>Heterogeneous Data</strong>: Different clients may have different data distributions</li>
<li><strong>Communication Constraints</strong>: Limited bandwidth for sharing architectural information</li>
<li><strong>Privacy Concerns</strong>: Protecting client data during architecture search</li>
</ul>
</section>
<section id="transformer-architecture-search" class="level3">
<h3 class="anchored" data-anchor-id="transformer-architecture-search" id="transformer-architecture-search">Transformer Architecture Search</h3>
<p>The success of Transformers has sparked interest in automated Transformer design:</p>
<ul>
<li><strong>Attention Mechanisms</strong>: Searching for optimal attention patterns</li>
<li><strong>Positional Encodings</strong>: Finding better ways to encode positional information</li>
<li><strong>Architecture Scaling</strong>: Optimizing Transformer architectures for different scales</li>
</ul>
</section>
<section id="multi-modal-nas" class="level3">
<h3 class="anchored" data-anchor-id="multi-modal-nas" id="multi-modal-nas">Multi-Modal NAS</h3>
<p>As AI systems become more multi-modal, NAS must handle:</p>
<ul>
<li><strong>Cross-Modal Interactions</strong>: Optimizing architectures for multiple input modalities</li>
<li><strong>Fusion Strategies</strong>: Finding optimal ways to combine different types of information</li>
<li><strong>Unified Architectures</strong>: Searching for architectures that can handle multiple tasks</li>
</ul>
</section>
</section>
<section id="best-practices-and-recommendations" class="level2">
<h2 class="anchored" data-anchor-id="best-practices-and-recommendations" id="best-practices-and-recommendations">Best Practices and Recommendations</h2>
<section id="search-space-design-1" class="level3">
<h3 class="anchored" data-anchor-id="search-space-design-1" id="search-space-design-1">Search Space Design</h3>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Best Practices
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li><strong>Start Simple</strong>: Begin with well-understood search spaces before exploring novel designs</li>
<li><strong>Validate Assumptions</strong>: Ensure that the search space can express effective architectures</li>
<li><strong>Consider Constraints</strong>: Incorporate deployment constraints into the search space design</li>
<li><strong>Enable Diversity</strong>: Allow for architectural diversity to avoid local optima</li>
</ul>
</div>
</div>
</section>
<section id="performance-estimation-1" class="level3">
<h3 class="anchored" data-anchor-id="performance-estimation-1" id="performance-estimation-1">Performance Estimation</h3>
<ul>
<li><strong>Validate Proxies</strong>: Ensure that proxy tasks correlate well with full performance</li>
<li><strong>Use Multiple Metrics</strong>: Consider multiple performance indicators beyond accuracy</li>
<li><strong>Account for Variance</strong>: Properly handle performance variability across runs</li>
<li><strong>Benchmark Carefully</strong>: Compare against appropriate baselines</li>
</ul>
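<p>Validating a proxy usually means checking that it ranks architectures the same way full training does. Rank correlation measures that directly. A minimal Kendall's tau sketch follows. The proxy scores and accuracies are hypothetical, and <code>scipy.stats.kendalltau</code> offers a full implementation with tie handling and p-values.</p>

```python
from itertools import combinations


def kendall_tau(xs, ys):
    """Kendall rank correlation (tau-a): +1 = identical ranking, -1 = reversed.

    This simple version assumes there are no tied scores.
    """
    concordant = discordant = 0
    for i, j in combinations(range(len(xs)), 2):
        agreement = (xs[i] - xs[j]) * (ys[i] - ys[j])
        if agreement > 0:
            concordant += 1
        elif agreement < 0:
            discordant += 1
    n_pairs = len(xs) * (len(xs) - 1) / 2
    return (concordant - discordant) / n_pairs


# Hypothetical data: cheap proxy scores vs. accuracy after full training
# for five candidate architectures.
proxy_scores = [0.10, 0.40, 0.35, 0.80, 0.60]
final_accuracies = [71.2, 73.1, 74.0, 76.5, 75.9]
print(kendall_tau(proxy_scores, final_accuracies))
```

<p>Here nine of the ten architecture pairs are ordered consistently and one is swapped, giving tau = (9 − 1) / 10 = 0.8.</p>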
</section>
<section id="implementation" class="level3">
<h3 class="anchored" data-anchor-id="implementation" id="implementation">Implementation</h3>
<ul>
<li><strong>Modular Code</strong>: Design systems that can easily incorporate new search methods</li>
<li><strong>Efficient Implementation</strong>: Optimize code for the specific computational constraints</li>
<li><strong>Careful Logging</strong>: Track all experiments and intermediate results</li>
<li><strong>Reproducible Setup</strong>: Document all implementation details and hyperparameters</li>
</ul>
</section>
<section id="evaluation" class="level3">
<h3 class="anchored" data-anchor-id="evaluation" id="evaluation">Evaluation</h3>
<ul>
<li><strong>Multiple Runs</strong>: Average results over multiple independent runs</li>
<li><strong>Statistical Significance</strong>: Use appropriate statistical tests for comparing methods</li>
<li><strong>Comprehensive Baselines</strong>: Compare against relevant human-designed architectures</li>
<li><strong>Transfer Evaluation</strong>: Test architectures on multiple tasks and datasets</li>
</ul>
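<p>When comparing two methods over several independent runs, Welch's t-test is a common choice because it does not assume equal variances. The sketch below computes the statistic and its Welch-Satterthwaite degrees of freedom with the standard library only. The accuracy values are hypothetical. To get a p-value, you would evaluate the t distribution, which <code>scipy.stats.ttest_ind(a, b, equal_var=False)</code> does in one call.</p>

```python
import math
import statistics


def welch_t_test(sample_a, sample_b):
    """Welch's t statistic and Welch-Satterthwaite degrees of freedom.

    Unlike Student's t-test, Welch's version does not assume equal variances.
    """
    n_a, n_b = len(sample_a), len(sample_b)
    se2_a = statistics.variance(sample_a) / n_a  # squared standard error of mean A
    se2_b = statistics.variance(sample_b) / n_b
    t = (statistics.mean(sample_a) - statistics.mean(sample_b)) / math.sqrt(se2_a + se2_b)
    dof = (se2_a + se2_b) ** 2 / (se2_a ** 2 / (n_a - 1) + se2_b ** 2 / (n_b - 1))
    return t, dof


# Hypothetical top-1 accuracies (%) over five independent seeds per method.
method_a = [76.1, 75.8, 76.4, 76.0, 75.9]
method_b = [75.2, 75.9, 75.0, 75.6, 75.3]
t, dof = welch_t_test(method_a, method_b)
print(f"mean A={statistics.mean(method_a):.2f}, mean B={statistics.mean(method_b):.2f}")
print(f"t={t:.2f}, dof={dof:.1f}")
```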
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>Neural Architecture Search represents a fundamental shift in how we approach neural network design, moving from manual crafting to automated discovery. The field has made remarkable progress in reducing computational costs, improving search efficiency, and expanding to new domains and applications.</p>
<p>Key achievements include the development of efficient search methods like DARTS, the integration of hardware constraints into the search process, and the successful application of NAS to diverse domains beyond computer vision. These advances have democratized access to high-quality architectures and enabled the discovery of designs that outperform human-crafted networks.</p>
<p>However, significant challenges remain. Computational costs, while reduced, are still substantial. Search space design continues to introduce biases that may limit architectural diversity. Reproducibility issues persist due to the computational requirements and implementation complexity. Generalization across tasks and datasets remains an active area of research.</p>
<p>The future of NAS looks promising, with emerging directions including zero-shot evaluation, federated learning integration, and multi-modal architecture search. As the field matures, we can expect to see more efficient methods, better theoretical understanding, and broader adoption in practical applications.</p>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Key Takeaways
</div>
</div>
<div class="callout-body-container callout-body">
<p>For practitioners looking to apply NAS, the key is to start with established methods and well-designed search spaces, carefully validate performance estimation strategies, and consider the specific constraints and requirements of their deployment scenarios.</p>
</div>
</div>
<p>The ultimate goal of NAS is not just to automate architecture design, but to discover fundamental principles of neural network structure that can inform future research and development. By understanding what makes architectures effective across different tasks and constraints, we can build more intelligent, efficient, and capable AI systems.</p>


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Optuna for Deep Learning and Neural Architecture Search: A Comprehensive Guide]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/neural-architecture-search/optuna-code/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/neural-architecture-search/optuna-code/</guid>
      <pubDate>Fri, 11 Jul 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>tutorial</category>
      <category>advanced</category>
      <content:encoded><![CDATA[






<section id="optuna-for-deep-learning-and-neural-architecture-search-a-comprehensive-guide" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/neural-architecture-search/optuna-code/optuna.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>Hyperparameter optimization is one of the most critical yet challenging aspects of deep learning. As models grow more complex and their hyperparameter search spaces grow with them, manual tuning becomes impractical. Optuna, an automatic hyperparameter optimization framework developed by Preferred Networks, addresses this with efficient search and pruning algorithms and an intuitive define-by-run API, in which the search space is declared inside the objective function as it runs.</p>
<p>This guide shows how to use Optuna in deep learning workflows, from basic hyperparameter tuning to neural architecture search (NAS), with practical implementations and optimization strategies.</p>
</section>
<section id="what-is-optuna" class="level2">
<h2 class="anchored" data-anchor-id="what-is-optuna" id="what-is-optuna">What is Optuna?</h2>
<p>Optuna is an open-source hyperparameter optimization framework designed for machine learning. It offers several key advantages:</p>
<ul>
<li><strong>Efficient Sampling</strong>: Uses Tree-structured Parzen Estimator (TPE) and other advanced algorithms</li>
<li><strong>Pruning</strong>: Automatically stops unpromising trials early</li>
<li><strong>Distributed Optimization</strong>: Supports parallel and distributed hyperparameter search</li>
<li><strong>Framework Agnostic</strong>: Works with PyTorch, TensorFlow, Keras, and other ML frameworks</li>
<li><strong>Visualization</strong>: Built-in plotting functions (<code>optuna.visualization</code>) and the separate Optuna Dashboard for monitoring optimization progress</li>
</ul>
</section>
<section id="core-concepts" class="level2">
<h2 class="anchored" data-anchor-id="core-concepts" id="core-concepts">Core Concepts</h2>
<section id="studies-and-trials" class="level3">
<h3 class="anchored" data-anchor-id="studies-and-trials" id="studies-and-trials">Studies and Trials</h3>
<p>In Optuna terminology:</p>
<ul>
<li><strong>Study</strong>: An optimization session that tries to find optimal hyperparameters</li>
<li><strong>Trial</strong>: A single execution of the objective function with specific hyperparameter values</li>
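<li><strong>Objective Function</strong>: A function that receives a trial, trains and evaluates a model with the values the trial suggests, and returns the metric to optimize (typically validation loss or accuracy)</li>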
</ul>
</section>
<section id="sampling-algorithms" class="level3">
<h3 class="anchored" data-anchor-id="sampling-algorithms" id="sampling-algorithms">Sampling Algorithms</h3>
<p>Optuna implements several sophisticated sampling strategies:</p>
<ol type="1">
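<li><strong>TPE (Tree-structured Parzen Estimator)</strong>: The default sampler for single-objective studies. It splits past trials into a good and a bad group, models the density of hyperparameter values in each, and proposes values that are likely under the good density relative to the bad one</li>
<li><strong>Random Sampling</strong>: Samples each hyperparameter independently at random; a simple baseline for comparison</li>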
<li><strong>Grid Search</strong>: Exhaustive search over specified parameter combinations</li>
<li><strong>CMA-ES</strong>: Covariance Matrix Adaptation Evolution Strategy for continuous optimization</li>
</ol>
</section>
<section id="pruning-algorithms" class="level3">
<h3 class="anchored" data-anchor-id="pruning-algorithms" id="pruning-algorithms">Pruning Algorithms</h3>
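<p>Pruning stops unpromising trials early, based on the intermediate values they report during training:</p>
<ul>
<li><strong>Median Pruner</strong>: Prunes a trial if its best intermediate value is worse than the median of previous trials' intermediate values at the same step</li>
<li><strong>Successive Halving</strong>: Allocates resources progressively, promoting only the best-performing fraction of trials to larger budgets (Optuna implements the asynchronous variant, ASHA)</li>
<li><strong>Hyperband</strong>: Runs several successive-halving brackets with different trade-offs between the number of trials and the minimum resource given to each</li>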
</ul>
</section>
</section>
<section id="installation-and-setup" class="level2">
<h2 class="anchored" data-anchor-id="installation-and-setup" id="installation-and-setup">Installation and Setup</h2>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install optuna</span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install optuna-dashboard  <span class="co"># Optional: for visualization</span></span></code></pre></div></div>
<p>For specific deep learning frameworks:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb2"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install torch torchvision  <span class="co"># PyTorch</span></span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install tensorflow  <span class="co"># TensorFlow</span></span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install optuna-integration  <span class="co"># Framework integrations (a separate package in recent Optuna releases)</span></span></code></pre></div></div>
</section>
<section id="basic-hyperparameter-optimization" class="level2">
<h2 class="anchored" data-anchor-id="basic-hyperparameter-optimization" id="basic-hyperparameter-optimization">Basic Hyperparameter Optimization</h2>
<section id="simple-pytorch-example" class="level3">
<h3 class="anchored" data-anchor-id="simple-pytorch-example" id="simple-pytorch-example">Simple PyTorch Example</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> optuna</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.optim <span class="im">as</span> optim</span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.utils.data <span class="im">import</span> DataLoader</span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torchvision <span class="im">import</span> datasets, transforms</span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> create_model(trial):</span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Suggest hyperparameters</span></span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a>    n_layers <span class="op">=</span> trial.suggest_int(<span class="st">'n_layers'</span>, <span class="dv">1</span>, <span class="dv">3</span>)</span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a>    n_units <span class="op">=</span> trial.suggest_int(<span class="st">'n_units'</span>, <span class="dv">64</span>, <span class="dv">512</span>)</span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a>    dropout_rate <span class="op">=</span> trial.suggest_float(<span class="st">'dropout_rate'</span>, <span class="fl">0.1</span>, <span class="fl">0.5</span>)</span>
<span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a>    layers <span class="op">=</span> []</span>
<span id="cb3-15"><a href="#cb3-15" aria-hidden="true" tabindex="-1"></a>    in_features <span class="op">=</span> <span class="dv">784</span>  <span class="co"># MNIST input size</span></span>
<span id="cb3-16"><a href="#cb3-16" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-17"><a href="#cb3-17" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(n_layers):</span>
<span id="cb3-18"><a href="#cb3-18" aria-hidden="true" tabindex="-1"></a>        layers.append(nn.Linear(in_features, n_units))</span>
<span id="cb3-19"><a href="#cb3-19" aria-hidden="true" tabindex="-1"></a>        layers.append(nn.ReLU())</span>
<span id="cb3-20"><a href="#cb3-20" aria-hidden="true" tabindex="-1"></a>        layers.append(nn.Dropout(dropout_rate))</span>
<span id="cb3-21"><a href="#cb3-21" aria-hidden="true" tabindex="-1"></a>        in_features <span class="op">=</span> n_units</span>
<span id="cb3-22"><a href="#cb3-22" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-23"><a href="#cb3-23" aria-hidden="true" tabindex="-1"></a>    layers.append(nn.Linear(in_features, <span class="dv">10</span>))  <span class="co"># Output layer</span></span>
<span id="cb3-24"><a href="#cb3-24" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-25"><a href="#cb3-25" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> nn.Sequential(<span class="op">*</span>layers)</span>
<span id="cb3-26"><a href="#cb3-26" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-27"><a href="#cb3-27" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> objective(trial):</span>
<span id="cb3-28"><a href="#cb3-28" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Model hyperparameters</span></span>
<span id="cb3-29"><a href="#cb3-29" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> create_model(trial)</span>
<span id="cb3-30"><a href="#cb3-30" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-31"><a href="#cb3-31" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Optimizer hyperparameters</span></span>
<span id="cb3-32"><a href="#cb3-32" aria-hidden="true" tabindex="-1"></a>    lr <span class="op">=</span> trial.suggest_float(<span class="st">'lr'</span>, <span class="fl">1e-5</span>, <span class="fl">1e-1</span>, log<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb3-33"><a href="#cb3-33" aria-hidden="true" tabindex="-1"></a>    optimizer_name <span class="op">=</span> trial.suggest_categorical(<span class="st">'optimizer'</span>, [<span class="st">'Adam'</span>, <span class="st">'SGD'</span>, <span class="st">'RMSprop'</span>])</span>
<span id="cb3-34"><a href="#cb3-34" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-35"><a href="#cb3-35" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> optimizer_name <span class="op">==</span> <span class="st">'Adam'</span>:</span>
<span id="cb3-36"><a href="#cb3-36" aria-hidden="true" tabindex="-1"></a>        optimizer <span class="op">=</span> optim.Adam(model.parameters(), lr<span class="op">=</span>lr)</span>
<span id="cb3-37"><a href="#cb3-37" aria-hidden="true" tabindex="-1"></a>    <span class="cf">elif</span> optimizer_name <span class="op">==</span> <span class="st">'SGD'</span>:</span>
<span id="cb3-38"><a href="#cb3-38" aria-hidden="true" tabindex="-1"></a>        momentum <span class="op">=</span> trial.suggest_float(<span class="st">'momentum'</span>, <span class="fl">0.0</span>, <span class="fl">0.99</span>)</span>
<span id="cb3-39"><a href="#cb3-39" aria-hidden="true" tabindex="-1"></a>        optimizer <span class="op">=</span> optim.SGD(model.parameters(), lr<span class="op">=</span>lr, momentum<span class="op">=</span>momentum)</span>
<span id="cb3-40"><a href="#cb3-40" aria-hidden="true" tabindex="-1"></a>    <span class="cf">else</span>:  <span class="co"># RMSprop</span></span>
<span id="cb3-41"><a href="#cb3-41" aria-hidden="true" tabindex="-1"></a>        optimizer <span class="op">=</span> optim.RMSprop(model.parameters(), lr<span class="op">=</span>lr)</span>
<span id="cb3-42"><a href="#cb3-42" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-43"><a href="#cb3-43" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Data loading</span></span>
<span id="cb3-44"><a href="#cb3-44" aria-hidden="true" tabindex="-1"></a>    transform <span class="op">=</span> transforms.Compose([</span>
<span id="cb3-45"><a href="#cb3-45" aria-hidden="true" tabindex="-1"></a>        transforms.ToTensor(),</span>
<span id="cb3-46"><a href="#cb3-46" aria-hidden="true" tabindex="-1"></a>        transforms.Normalize((<span class="fl">0.1307</span>,), (<span class="fl">0.3081</span>,))</span>
<span id="cb3-47"><a href="#cb3-47" aria-hidden="true" tabindex="-1"></a>    ])</span>
<span id="cb3-48"><a href="#cb3-48" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-49"><a href="#cb3-49" aria-hidden="true" tabindex="-1"></a>    train_dataset <span class="op">=</span> datasets.MNIST(<span class="st">'data'</span>, train<span class="op">=</span><span class="va">True</span>, download<span class="op">=</span><span class="va">True</span>, transform<span class="op">=</span>transform)</span>
<span id="cb3-50"><a href="#cb3-50" aria-hidden="true" tabindex="-1"></a>    train_loader <span class="op">=</span> DataLoader(train_dataset, batch_size<span class="op">=</span><span class="dv">128</span>, shuffle<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb3-51"><a href="#cb3-51" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-52"><a href="#cb3-52" aria-hidden="true" tabindex="-1"></a>    test_dataset <span class="op">=</span> datasets.MNIST(<span class="st">'data'</span>, train<span class="op">=</span><span class="va">False</span>, transform<span class="op">=</span>transform)</span>
<span id="cb3-53"><a href="#cb3-53" aria-hidden="true" tabindex="-1"></a>    test_loader <span class="op">=</span> DataLoader(test_dataset, batch_size<span class="op">=</span><span class="dv">128</span>)</span>
<span id="cb3-54"><a href="#cb3-54" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-55"><a href="#cb3-55" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Training</span></span>
<span id="cb3-56"><a href="#cb3-56" aria-hidden="true" tabindex="-1"></a>    criterion <span class="op">=</span> nn.CrossEntropyLoss()</span>
<span id="cb3-57"><a href="#cb3-57" aria-hidden="true" tabindex="-1"></a>    model.train()</span>
<span id="cb3-58"><a href="#cb3-58" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-59"><a href="#cb3-59" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">10</span>):</span>
<span id="cb3-60"><a href="#cb3-60" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> batch_idx, (data, target) <span class="kw">in</span> <span class="bu">enumerate</span>(train_loader):</span>
<span id="cb3-61"><a href="#cb3-61" aria-hidden="true" tabindex="-1"></a>            data <span class="op">=</span> data.view(<span class="op">-</span><span class="dv">1</span>, <span class="dv">784</span>)  <span class="co"># flatten 28x28 images into 784-dim vectors</span></span>
<span id="cb3-62"><a href="#cb3-62" aria-hidden="true" tabindex="-1"></a>            optimizer.zero_grad()</span>
<span id="cb3-63"><a href="#cb3-63" aria-hidden="true" tabindex="-1"></a>            output <span class="op">=</span> model(data)</span>
<span id="cb3-64"><a href="#cb3-64" aria-hidden="true" tabindex="-1"></a>            loss <span class="op">=</span> criterion(output, target)</span>
<span id="cb3-65"><a href="#cb3-65" aria-hidden="true" tabindex="-1"></a>            loss.backward()</span>
<span id="cb3-66"><a href="#cb3-66" aria-hidden="true" tabindex="-1"></a>            optimizer.step()</span>
<span id="cb3-67"><a href="#cb3-67" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb3-68"><a href="#cb3-68" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Optional: report intermediate values for pruning (pruners read them in the study's direction, so report a metric that matches it)</span></span>
<span id="cb3-69"><a href="#cb3-69" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> batch_idx <span class="op">%</span> <span class="dv">100</span> <span class="op">==</span> <span class="dv">0</span>:</span>
<span id="cb3-70"><a href="#cb3-70" aria-hidden="true" tabindex="-1"></a>                trial.report(loss.item(), epoch <span class="op">*</span> <span class="bu">len</span>(train_loader) <span class="op">+</span> batch_idx)</span>
<span id="cb3-71"><a href="#cb3-71" aria-hidden="true" tabindex="-1"></a>                <span class="cf">if</span> trial.should_prune():</span>
<span id="cb3-72"><a href="#cb3-72" aria-hidden="true" tabindex="-1"></a>                    <span class="cf">raise</span> optuna.exceptions.TrialPruned()</span>
<span id="cb3-73"><a href="#cb3-73" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-74"><a href="#cb3-74" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Validation</span></span>
<span id="cb3-75"><a href="#cb3-75" aria-hidden="true" tabindex="-1"></a>    model.<span class="bu">eval</span>()</span>
<span id="cb3-76"><a href="#cb3-76" aria-hidden="true" tabindex="-1"></a>    correct <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb3-77"><a href="#cb3-77" aria-hidden="true" tabindex="-1"></a>    total <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb3-78"><a href="#cb3-78" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-79"><a href="#cb3-79" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> torch.no_grad():</span>
<span id="cb3-80"><a href="#cb3-80" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> data, target <span class="kw">in</span> test_loader:</span>
<span id="cb3-81"><a href="#cb3-81" aria-hidden="true" tabindex="-1"></a>            data <span class="op">=</span> data.view(<span class="op">-</span><span class="dv">1</span>, <span class="dv">784</span>)  <span class="co"># flatten each image into a 784-element vector</span></span>
<span id="cb3-82"><a href="#cb3-82" aria-hidden="true" tabindex="-1"></a>            outputs <span class="op">=</span> model(data)</span>
<span id="cb3-83"><a href="#cb3-83" aria-hidden="true" tabindex="-1"></a>            predicted <span class="op">=</span> outputs.argmax(dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb3-84"><a href="#cb3-84" aria-hidden="true" tabindex="-1"></a>            total <span class="op">+=</span> target.size(<span class="dv">0</span>)</span>
<span id="cb3-85"><a href="#cb3-85" aria-hidden="true" tabindex="-1"></a>            correct <span class="op">+=</span> (predicted <span class="op">==</span> target).<span class="bu">sum</span>().item()</span>
<span id="cb3-86"><a href="#cb3-86" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-87"><a href="#cb3-87" aria-hidden="true" tabindex="-1"></a>    accuracy <span class="op">=</span> correct <span class="op">/</span> total</span>
<span id="cb3-88"><a href="#cb3-88" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> accuracy</span>
<span id="cb3-89"><a href="#cb3-89" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-90"><a href="#cb3-90" aria-hidden="true" tabindex="-1"></a><span class="co"># Create study and optimize</span></span>
<span id="cb3-91"><a href="#cb3-91" aria-hidden="true" tabindex="-1"></a>study <span class="op">=</span> optuna.create_study(direction<span class="op">=</span><span class="st">'maximize'</span>)</span>
<span id="cb3-92"><a href="#cb3-92" aria-hidden="true" tabindex="-1"></a>study.optimize(objective, n_trials<span class="op">=</span><span class="dv">100</span>)</span>
<span id="cb3-93"><a href="#cb3-93" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-94"><a href="#cb3-94" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Best accuracy: </span><span class="sc">{</span>study<span class="sc">.</span>best_trial<span class="sc">.</span>value<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb3-95"><a href="#cb3-95" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Best params: </span><span class="sc">{</span>study<span class="sc">.</span>best_params<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
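<p>The value passed to <code>trial.report</code> has to point in the same direction as the study. The simplified sketch below shows why. It models the median-stopping idea behind pruners such as Optuna's <code>MedianPruner</code>, but it is an illustrative stand-in rather than Optuna's implementation. The helper name <code>should_prune</code> is hypothetical, and the real pruner adds options such as <code>n_startup_trials</code> and <code>n_warmup_steps</code>.</p>

```python
from statistics import median


def should_prune(current_value, values_at_step, direction="maximize"):
    """Simplified median-stopping rule (hypothetical helper, not Optuna's code).

    current_value:  value the running trial reported at this step
    values_at_step: values earlier trials reported at the same step
    direction:      the study direction, "maximize" or "minimize"
    """
    if not values_at_step:
        return False  # nothing to compare against yet
    threshold = median(values_at_step)
    if direction == "maximize":
        return current_value < threshold  # below the median counts as worse
    return current_value > threshold  # above the median counts as worse


# Made-up training losses that earlier trials reported at the same step
earlier_losses = [0.9, 0.6, 0.4]

# Raw loss reported to a maximizing study: the best run (loss 0.3) gets pruned
pruned_raw = should_prune(0.3, earlier_losses, "maximize")

# Negated loss reported to the same study: the best run survives
pruned_negated = should_prune(-0.3, [-v for v in earlier_losses], "maximize")
```

<p>If you report the raw training loss to a <code>direction='maximize'</code> study, the rule is inverted and the lowest-loss trials look like the worst ones. There are two ways to avoid this. You can negate the loss, or you can report a validation metric such as accuracy, which already points in the study's direction.</p>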
</section>
</section>
<section id="advanced-hyperparameter-optimization" class="level2">
<h2 class="anchored" data-anchor-id="advanced-hyperparameter-optimization" id="advanced-hyperparameter-optimization">Advanced Hyperparameter Optimization</h2>
<section id="multi-objective-optimization" class="level3">
<h3 class="anchored" data-anchor-id="multi-objective-optimization" id="multi-objective-optimization">Multi-Objective Optimization</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> multi_objective_function(trial):</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Suggest hyperparameters</span></span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a>    n_layers <span class="op">=</span> trial.suggest_int(<span class="st">'n_layers'</span>, <span class="dv">1</span>, <span class="dv">5</span>)</span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a>    n_units <span class="op">=</span> trial.suggest_int(<span class="st">'n_units'</span>, <span class="dv">32</span>, <span class="dv">512</span>)</span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a>    dropout_rate <span class="op">=</span> trial.suggest_float(<span class="st">'dropout_rate'</span>, <span class="fl">0.1</span>, <span class="fl">0.5</span>)</span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Placeholder helpers (not shown): create_model is assumed to reuse the values suggested above</span></span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> create_model(trial)</span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a>    accuracy <span class="op">=</span> train_and_evaluate(model)</span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Calculate model complexity (number of parameters)</span></span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a>    model_size <span class="op">=</span> <span class="bu">sum</span>(p.numel() <span class="cf">for</span> p <span class="kw">in</span> model.parameters())</span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-14"><a href="#cb4-14" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Return multiple objectives</span></span>
<span id="cb4-15"><a href="#cb4-15" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> accuracy, model_size  <span class="co"># Maximize accuracy, minimize parameter count</span></span>
<span id="cb4-16"><a href="#cb4-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-17"><a href="#cb4-17" aria-hidden="true" tabindex="-1"></a><span class="co"># Multi-objective study: one direction per returned objective</span></span>
<span id="cb4-18"><a href="#cb4-18" aria-hidden="true" tabindex="-1"></a>study <span class="op">=</span> optuna.create_study(directions<span class="op">=</span>[<span class="st">'maximize'</span>, <span class="st">'minimize'</span>])</span>
<span id="cb4-19"><a href="#cb4-19" aria-hidden="true" tabindex="-1"></a>study.optimize(multi_objective_function, n_trials<span class="op">=</span><span class="dv">100</span>)</span>
<span id="cb4-20"><a href="#cb4-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-21"><a href="#cb4-21" aria-hidden="true" tabindex="-1"></a><span class="co"># Get Pareto front</span></span>
<span id="cb4-22"><a href="#cb4-22" aria-hidden="true" tabindex="-1"></a>pareto_front <span class="op">=</span> study.best_trials</span>
<span id="cb4-23"><a href="#cb4-23" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> trial <span class="kw">in</span> pareto_front:</span>
<span id="cb4-24"><a href="#cb4-24" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Trial </span><span class="sc">{</span>trial<span class="sc">.</span>number<span class="sc">}</span><span class="ss">: Accuracy=</span><span class="sc">{</span>trial<span class="sc">.</span>values[<span class="dv">0</span>]<span class="sc">:.3f}</span><span class="ss">, "</span></span>
<span id="cb4-25"><a href="#cb4-25" aria-hidden="true" tabindex="-1"></a>          <span class="ss">f"Model Size=</span><span class="sc">{</span>trial<span class="sc">.</span>values[<span class="dv">1</span>]<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
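<p>In a multi-objective study, <code>study.best_trials</code> returns the Pareto front: the trials that no other trial beats on every objective at once. The sketch below computes that front in plain Python for (accuracy, parameter count) pairs. The numbers are made up for illustration. The helpers <code>dominates</code> and <code>pareto_front</code> are hypothetical names and are not part of Optuna.</p>

```python
def dominates(a, b, directions):
    """True if point a is no worse than b on every objective and strictly better on one."""
    strictly_better = False
    for x, y, d in zip(a, b, directions):
        if d == "maximize":
            x, y = -x, -y  # flip signs so that smaller is always better below
        if x > y:
            return False  # a is worse than b on this objective
        if x < y:
            strictly_better = True
    return strictly_better


def pareto_front(points, directions):
    """Keep the points that no other point dominates, preserving their order."""
    return [p for p in points if not any(dominates(q, p, directions) for q in points)]


# Made-up (accuracy, parameter count) results from five trials
trials = [(0.95, 1_200_000), (0.93, 300_000), (0.90, 350_000),
          (0.97, 5_000_000), (0.93, 400_000)]

front = pareto_front(trials, ("maximize", "minimize"))

# Negating the size and maximizing both objectives selects the same trials
front_negated = pareto_front([(acc, -size) for acc, size in trials],
                             ("maximize", "maximize"))
```

<p>The point (0.93, 400,000) drops out because (0.93, 300,000) ties it on accuracy and is smaller. A tie on one objective plus a strict win on the other is enough to dominate.</p>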
</section>
<section id="conditional-hyperparameters" class="level3">
<h3 class="anchored" data-anchor-id="conditional-hyperparameters" id="conditional-hyperparameters">Conditional Hyperparameters</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> conditional_objective(trial):</span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Main architecture choice</span></span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a>    model_type <span class="op">=</span> trial.suggest_categorical(<span class="st">'model_type'</span>, [<span class="st">'CNN'</span>, <span class="st">'ResNet'</span>, <span class="st">'DenseNet'</span>])</span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> model_type <span class="op">==</span> <span class="st">'CNN'</span>:</span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a>        <span class="co"># CNN-specific parameters</span></span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a>        n_conv_layers <span class="op">=</span> trial.suggest_int(<span class="st">'n_conv_layers'</span>, <span class="dv">2</span>, <span class="dv">4</span>)</span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a>        kernel_size <span class="op">=</span> trial.suggest_categorical(<span class="st">'kernel_size'</span>, [<span class="dv">3</span>, <span class="dv">5</span>, <span class="dv">7</span>])</span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a>        n_filters <span class="op">=</span> trial.suggest_int(<span class="st">'n_filters'</span>, <span class="dv">32</span>, <span class="dv">128</span>)</span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a>        model <span class="op">=</span> create_cnn(n_conv_layers, kernel_size, n_filters)</span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a>    <span class="cf">elif</span> model_type <span class="op">==</span> <span class="st">'ResNet'</span>:</span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a>        <span class="co"># ResNet-specific parameters</span></span>
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a>        depth <span class="op">=</span> trial.suggest_categorical(<span class="st">'depth'</span>, [<span class="dv">18</span>, <span class="dv">34</span>, <span class="dv">50</span>])</span>
<span id="cb5-16"><a href="#cb5-16" aria-hidden="true" tabindex="-1"></a>        width_multiplier <span class="op">=</span> trial.suggest_float(<span class="st">'width_multiplier'</span>, <span class="fl">0.5</span>, <span class="fl">2.0</span>)</span>
<span id="cb5-17"><a href="#cb5-17" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-18"><a href="#cb5-18" aria-hidden="true" tabindex="-1"></a>        model <span class="op">=</span> create_resnet(depth, width_multiplier)</span>
<span id="cb5-19"><a href="#cb5-19" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-20"><a href="#cb5-20" aria-hidden="true" tabindex="-1"></a>    <span class="cf">else</span>:  <span class="co"># DenseNet</span></span>
<span id="cb5-21"><a href="#cb5-21" aria-hidden="true" tabindex="-1"></a>        <span class="co"># DenseNet-specific parameters (choices kept as strings: Optuna expects None/bool/int/float/str)</span></span>
<span id="cb5-22"><a href="#cb5-22" aria-hidden="true" tabindex="-1"></a>        growth_rate <span class="op">=</span> trial.suggest_int(<span class="st">'growth_rate'</span>, <span class="dv">12</span>, <span class="dv">48</span>)</span>
<span id="cb5-23"><a href="#cb5-23" aria-hidden="true" tabindex="-1"></a>        block_config_name <span class="op">=</span> trial.suggest_categorical(<span class="st">'block_config'</span>, [<span class="st">'densenet121'</span>, <span class="st">'densenet169'</span>])</span>
<span id="cb5-24"><a href="#cb5-24" aria-hidden="true" tabindex="-1"></a>        block_config <span class="op">=</span> {<span class="st">'densenet121'</span>: (<span class="dv">6</span>, <span class="dv">12</span>, <span class="dv">24</span>, <span class="dv">16</span>), <span class="st">'densenet169'</span>: (<span class="dv">6</span>, <span class="dv">12</span>, <span class="dv">32</span>, <span class="dv">32</span>)}[block_config_name]</span>
<span id="cb5-25"><a href="#cb5-25" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-26"><a href="#cb5-26" aria-hidden="true" tabindex="-1"></a>        model <span class="op">=</span> create_densenet(growth_rate, block_config)</span>
<span id="cb5-27"><a href="#cb5-27" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-28"><a href="#cb5-28" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Common hyperparameters</span></span>
<span id="cb5-29"><a href="#cb5-29" aria-hidden="true" tabindex="-1"></a>    lr <span class="op">=</span> trial.suggest_float(<span class="st">'lr'</span>, <span class="fl">1e-5</span>, <span class="fl">1e-1</span>, log<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb5-30"><a href="#cb5-30" aria-hidden="true" tabindex="-1"></a>    batch_size <span class="op">=</span> trial.suggest_categorical(<span class="st">'batch_size'</span>, [<span class="dv">16</span>, <span class="dv">32</span>, <span class="dv">64</span>, <span class="dv">128</span>])</span>
<span id="cb5-31"><a href="#cb5-31" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-32"><a href="#cb5-32" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> train_and_evaluate(model, lr, batch_size)</span></code></pre></div></div>
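<p>The branching above works because Optuna builds the search space while the objective runs. A parameter like <code>depth</code> only exists in trials that took the ResNet branch. A plain-Python random search over the same space makes this visible. This is an illustrative sketch that uses the standard <code>random</code> module instead of Optuna's samplers. <code>sample_config</code> and <code>BRANCH_PARAMS</code> are hypothetical names.</p>

```python
import random

# Parameters that only exist on each branch of the conditional space
BRANCH_PARAMS = {
    "CNN": {"n_conv_layers", "kernel_size", "n_filters"},
    "ResNet": {"depth", "width_multiplier"},
    "DenseNet": {"growth_rate", "block_config"},
}


def sample_config(rng):
    """Draw one configuration; branch-specific keys appear only on their own branch."""
    config = {"model_type": rng.choice(["CNN", "ResNet", "DenseNet"])}
    if config["model_type"] == "CNN":
        config["n_conv_layers"] = rng.randint(2, 4)  # inclusive bounds, like suggest_int
        config["kernel_size"] = rng.choice([3, 5, 7])
        config["n_filters"] = rng.randint(32, 128)
    elif config["model_type"] == "ResNet":
        config["depth"] = rng.choice([18, 34, 50])
        config["width_multiplier"] = rng.uniform(0.5, 2.0)
    else:  # DenseNet
        config["growth_rate"] = rng.randint(12, 48)
        config["block_config"] = rng.choice([(6, 12, 24, 16), (6, 12, 32, 32)])
    # Shared parameters are drawn on every branch
    config["lr"] = 10 ** rng.uniform(-5, -1)  # log-uniform between 1e-5 and 1e-1
    config["batch_size"] = rng.choice([16, 32, 64, 128])
    return config


rng = random.Random(0)  # fixed seed so the samples are reproducible
samples = [sample_config(rng) for _ in range(200)]
```

<p>Each configuration carries only the keys of its own branch. A ResNet sample therefore records no <code>kernel_size</code> at all, rather than a meaningless value.</p>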
</section>
</section>
<section id="neural-architecture-search-nas" class="level2">
<h2 class="anchored" data-anchor-id="neural-architecture-search-nas" id="neural-architecture-search-nas">Neural Architecture Search (NAS)</h2>
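<p>Even the small search space that <code>SearchableNet</code> defines in the next subsection is very large:</p>
<ul>
<li>Each block chooses one of 5 operations, one of 3 activations, and batch norm on or off, giving 30 options per block.</li>
<li>Each stage has 1 to 4 blocks.</li>
<li>The network has 3 to 5 stages.</li>
</ul>
<p>The arithmetic sketch below counts the distinct discrete configurations. It deliberately ignores <code>base_channels</code> and the continuous parameters (dropout, learning rate, weight decay). The function names are hypothetical.</p>

```python
# Choices per SearchableBlock: 5 operations x 3 activations x batch norm on/off
CHOICES_PER_BLOCK = 5 * 3 * 2


def configs_per_stage(min_blocks=1, max_blocks=4):
    """Configurations for one stage: sum over block counts n of CHOICES_PER_BLOCK ** n."""
    return sum(CHOICES_PER_BLOCK ** n for n in range(min_blocks, max_blocks + 1))


def total_configs(min_stages=3, max_stages=5):
    """Stages choose independently, so a net with s stages has per_stage ** s configurations."""
    per_stage = configs_per_stage()
    return sum(per_stage ** s for s in range(min_stages, max_stages + 1))


per_stage = configs_per_stage()  # 30 + 900 + 27,000 + 810,000
total = total_configs()
print(f"{per_stage:,} per stage, about {total:.2e} architectures in total")
```

<p>Counting the 97 integer values of <code>base_channels</code> (32 to 128) multiplies the total by 97 again. That is far too many configurations to evaluate exhaustively, so a study only ever trains a small sample of them.</p>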
<section id="basic-nas-implementation" class="level3">
<h3 class="anchored" data-anchor-id="basic-nas-implementation" id="basic-nas-implementation">Basic NAS Implementation</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> SearchableBlock(nn.Module):</span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, in_channels, out_channels, trial, block_id):</span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.block_id <span class="op">=</span> block_id</span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Searchable operations</span></span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a>        op_name <span class="op">=</span> trial.suggest_categorical(<span class="ss">f'op_</span><span class="sc">{</span>block_id<span class="sc">}</span><span class="ss">'</span>, [</span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a>            <span class="st">'conv3x3'</span>, <span class="st">'conv5x5'</span>, <span class="st">'conv7x7'</span>, <span class="st">'depthwise_conv'</span>, <span class="st">'skip_connect'</span></span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a>        ])</span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> op_name <span class="op">==</span> <span class="st">'conv3x3'</span>:</span>
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.op <span class="op">=</span> nn.Conv2d(in_channels, out_channels, <span class="dv">3</span>, padding<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb6-15"><a href="#cb6-15" aria-hidden="true" tabindex="-1"></a>        <span class="cf">elif</span> op_name <span class="op">==</span> <span class="st">'conv5x5'</span>:</span>
<span id="cb6-16"><a href="#cb6-16" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.op <span class="op">=</span> nn.Conv2d(in_channels, out_channels, <span class="dv">5</span>, padding<span class="op">=</span><span class="dv">2</span>)</span>
<span id="cb6-17"><a href="#cb6-17" aria-hidden="true" tabindex="-1"></a>        <span class="cf">elif</span> op_name <span class="op">==</span> <span class="st">'conv7x7'</span>:</span>
<span id="cb6-18"><a href="#cb6-18" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.op <span class="op">=</span> nn.Conv2d(in_channels, out_channels, <span class="dv">7</span>, padding<span class="op">=</span><span class="dv">3</span>)</span>
<span id="cb6-19"><a href="#cb6-19" aria-hidden="true" tabindex="-1"></a>        <span class="cf">elif</span> op_name <span class="op">==</span> <span class="st">'depthwise_conv'</span>:</span>
<span id="cb6-20"><a href="#cb6-20" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.op <span class="op">=</span> nn.Sequential(</span>
<span id="cb6-21"><a href="#cb6-21" aria-hidden="true" tabindex="-1"></a>                nn.Conv2d(in_channels, in_channels, <span class="dv">3</span>, padding<span class="op">=</span><span class="dv">1</span>, groups<span class="op">=</span>in_channels),</span>
<span id="cb6-22"><a href="#cb6-22" aria-hidden="true" tabindex="-1"></a>                nn.Conv2d(in_channels, out_channels, <span class="dv">1</span>)</span>
<span id="cb6-23"><a href="#cb6-23" aria-hidden="true" tabindex="-1"></a>            )</span>
<span id="cb6-24"><a href="#cb6-24" aria-hidden="true" tabindex="-1"></a>        <span class="cf">else</span>:  <span class="co"># skip_connect</span></span>
<span id="cb6-25"><a href="#cb6-25" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.op <span class="op">=</span> nn.Identity() <span class="cf">if</span> in_channels <span class="op">==</span> out_channels <span class="cf">else</span> nn.Conv2d(in_channels, out_channels, <span class="dv">1</span>)</span>
<span id="cb6-26"><a href="#cb6-26" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-27"><a href="#cb6-27" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Activation and normalization</span></span>
<span id="cb6-28"><a href="#cb6-28" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.activation <span class="op">=</span> trial.suggest_categorical(<span class="ss">f'activation_</span><span class="sc">{</span>block_id<span class="sc">}</span><span class="ss">'</span>, </span>
<span id="cb6-29"><a href="#cb6-29" aria-hidden="true" tabindex="-1"></a>                                                   [<span class="st">'relu'</span>, <span class="st">'gelu'</span>, <span class="st">'swish'</span>])</span>
<span id="cb6-30"><a href="#cb6-30" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.use_batch_norm <span class="op">=</span> trial.suggest_categorical(<span class="ss">f'batch_norm_</span><span class="sc">{</span>block_id<span class="sc">}</span><span class="ss">'</span>, [<span class="va">True</span>, <span class="va">False</span>])</span>
<span id="cb6-31"><a href="#cb6-31" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-32"><a href="#cb6-32" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="va">self</span>.use_batch_norm:</span>
<span id="cb6-33"><a href="#cb6-33" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.bn <span class="op">=</span> nn.BatchNorm2d(out_channels)</span>
<span id="cb6-34"><a href="#cb6-34" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-35"><a href="#cb6-35" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="va">self</span>.activation <span class="op">==</span> <span class="st">'relu'</span>:</span>
<span id="cb6-36"><a href="#cb6-36" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.act <span class="op">=</span> nn.ReLU()</span>
<span id="cb6-37"><a href="#cb6-37" aria-hidden="true" tabindex="-1"></a>        <span class="cf">elif</span> <span class="va">self</span>.activation <span class="op">==</span> <span class="st">'gelu'</span>:</span>
<span id="cb6-38"><a href="#cb6-38" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.act <span class="op">=</span> nn.GELU()</span>
<span id="cb6-39"><a href="#cb6-39" aria-hidden="true" tabindex="-1"></a>        <span class="cf">else</span>:  <span class="co"># swish</span></span>
<span id="cb6-40"><a href="#cb6-40" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.act <span class="op">=</span> nn.SiLU()</span>
<span id="cb6-41"><a href="#cb6-41" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-42"><a href="#cb6-42" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb6-43"><a href="#cb6-43" aria-hidden="true" tabindex="-1"></a>        out <span class="op">=</span> <span class="va">self</span>.op(x)</span>
<span id="cb6-44"><a href="#cb6-44" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="va">self</span>.use_batch_norm:</span>
<span id="cb6-45"><a href="#cb6-45" aria-hidden="true" tabindex="-1"></a>            out <span class="op">=</span> <span class="va">self</span>.bn(out)</span>
<span id="cb6-46"><a href="#cb6-46" aria-hidden="true" tabindex="-1"></a>        out <span class="op">=</span> <span class="va">self</span>.act(out)</span>
<span id="cb6-47"><a href="#cb6-47" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> out</span>
<span id="cb6-48"><a href="#cb6-48" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-49"><a href="#cb6-49" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> SearchableNet(nn.Module):</span>
<span id="cb6-50"><a href="#cb6-50" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, trial, num_classes<span class="op">=</span><span class="dv">10</span>):</span>
<span id="cb6-51"><a href="#cb6-51" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb6-52"><a href="#cb6-52" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-53"><a href="#cb6-53" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Search for overall architecture</span></span>
<span id="cb6-54"><a href="#cb6-54" aria-hidden="true" tabindex="-1"></a>        num_stages <span class="op">=</span> trial.suggest_int(<span class="st">'num_stages'</span>, <span class="dv">3</span>, <span class="dv">5</span>)</span>
<span id="cb6-55"><a href="#cb6-55" aria-hidden="true" tabindex="-1"></a>        base_channels <span class="op">=</span> trial.suggest_int(<span class="st">'base_channels'</span>, <span class="dv">32</span>, <span class="dv">128</span>)</span>
<span id="cb6-56"><a href="#cb6-56" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-57"><a href="#cb6-57" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Build searchable architecture</span></span>
<span id="cb6-58"><a href="#cb6-58" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.stages <span class="op">=</span> nn.ModuleList()</span>
<span id="cb6-59"><a href="#cb6-59" aria-hidden="true" tabindex="-1"></a>        in_channels <span class="op">=</span> <span class="dv">3</span></span>
<span id="cb6-60"><a href="#cb6-60" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-61"><a href="#cb6-61" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> stage <span class="kw">in</span> <span class="bu">range</span>(num_stages):</span>
<span id="cb6-62"><a href="#cb6-62" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Number of blocks in this stage</span></span>
<span id="cb6-63"><a href="#cb6-63" aria-hidden="true" tabindex="-1"></a>            num_blocks <span class="op">=</span> trial.suggest_int(<span class="ss">f'num_blocks_stage_</span><span class="sc">{</span>stage<span class="sc">}</span><span class="ss">'</span>, <span class="dv">1</span>, <span class="dv">4</span>)</span>
<span id="cb6-64"><a href="#cb6-64" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb6-65"><a href="#cb6-65" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Channel progression</span></span>
<span id="cb6-66"><a href="#cb6-66" aria-hidden="true" tabindex="-1"></a>            out_channels <span class="op">=</span> base_channels <span class="op">*</span> (<span class="dv">2</span> <span class="op">**</span> stage)</span>
<span id="cb6-67"><a href="#cb6-67" aria-hidden="true" tabindex="-1"></a>            stage_blocks <span class="op">=</span> nn.ModuleList()</span>
<span id="cb6-68"><a href="#cb6-68" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb6-69"><a href="#cb6-69" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> block <span class="kw">in</span> <span class="bu">range</span>(num_blocks):</span>
<span id="cb6-70"><a href="#cb6-70" aria-hidden="true" tabindex="-1"></a>                block_id <span class="op">=</span> <span class="ss">f'stage_</span><span class="sc">{</span>stage<span class="sc">}</span><span class="ss">_block_</span><span class="sc">{</span>block<span class="sc">}</span><span class="ss">'</span></span>
<span id="cb6-71"><a href="#cb6-71" aria-hidden="true" tabindex="-1"></a>                stage_blocks.append(SearchableBlock(in_channels, out_channels, trial, block_id))</span>
<span id="cb6-72"><a href="#cb6-72" aria-hidden="true" tabindex="-1"></a>                in_channels <span class="op">=</span> out_channels</span>
<span id="cb6-73"><a href="#cb6-73" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb6-74"><a href="#cb6-74" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.stages.append(stage_blocks)</span>
<span id="cb6-75"><a href="#cb6-75" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb6-76"><a href="#cb6-76" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Downsampling between stages</span></span>
<span id="cb6-77"><a href="#cb6-77" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> stage <span class="op">&lt;</span> num_stages <span class="op">-</span> <span class="dv">1</span>:</span>
<span id="cb6-78"><a href="#cb6-78" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.stages.append(nn.MaxPool2d(<span class="dv">2</span>))</span>
<span id="cb6-79"><a href="#cb6-79" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-80"><a href="#cb6-80" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Global pooling and classifier</span></span>
<span id="cb6-81"><a href="#cb6-81" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.global_pool <span class="op">=</span> nn.AdaptiveAvgPool2d(<span class="dv">1</span>)</span>
<span id="cb6-82"><a href="#cb6-82" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.classifier <span class="op">=</span> nn.Linear(in_channels, num_classes)</span>
<span id="cb6-83"><a href="#cb6-83" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-84"><a href="#cb6-84" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Dropout</span></span>
<span id="cb6-85"><a href="#cb6-85" aria-hidden="true" tabindex="-1"></a>        dropout_rate <span class="op">=</span> trial.suggest_float(<span class="st">'dropout_rate'</span>, <span class="fl">0.0</span>, <span class="fl">0.5</span>)</span>
<span id="cb6-86"><a href="#cb6-86" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.dropout <span class="op">=</span> nn.Dropout(dropout_rate)</span>
<span id="cb6-87"><a href="#cb6-87" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-88"><a href="#cb6-88" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb6-89"><a href="#cb6-89" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> stage <span class="kw">in</span> <span class="va">self</span>.stages:</span>
<span id="cb6-90"><a href="#cb6-90" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="bu">isinstance</span>(stage, nn.ModuleList):</span>
<span id="cb6-91"><a href="#cb6-91" aria-hidden="true" tabindex="-1"></a>                <span class="cf">for</span> block <span class="kw">in</span> stage:</span>
<span id="cb6-92"><a href="#cb6-92" aria-hidden="true" tabindex="-1"></a>                    x <span class="op">=</span> block(x)</span>
<span id="cb6-93"><a href="#cb6-93" aria-hidden="true" tabindex="-1"></a>            <span class="cf">else</span>:</span>
<span id="cb6-94"><a href="#cb6-94" aria-hidden="true" tabindex="-1"></a>                x <span class="op">=</span> stage(x)</span>
<span id="cb6-95"><a href="#cb6-95" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-96"><a href="#cb6-96" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.global_pool(x)</span>
<span id="cb6-97"><a href="#cb6-97" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> x.view(x.size(<span class="dv">0</span>), <span class="op">-</span><span class="dv">1</span>)</span>
<span id="cb6-98"><a href="#cb6-98" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.dropout(x)</span>
<span id="cb6-99"><a href="#cb6-99" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.classifier(x)</span>
<span id="cb6-100"><a href="#cb6-100" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x</span>
<span id="cb6-101"><a href="#cb6-101" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-102"><a href="#cb6-102" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> nas_objective(trial):</span>
<span id="cb6-103"><a href="#cb6-103" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Create searchable model</span></span>
<span id="cb6-104"><a href="#cb6-104" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> SearchableNet(trial)</span>
<span id="cb6-105"><a href="#cb6-105" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-106"><a href="#cb6-106" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Training hyperparameters</span></span>
<span id="cb6-107"><a href="#cb6-107" aria-hidden="true" tabindex="-1"></a>    lr <span class="op">=</span> trial.suggest_float(<span class="st">'lr'</span>, <span class="fl">1e-5</span>, <span class="fl">1e-1</span>, log<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb6-108"><a href="#cb6-108" aria-hidden="true" tabindex="-1"></a>    weight_decay <span class="op">=</span> trial.suggest_float(<span class="st">'weight_decay'</span>, <span class="fl">1e-6</span>, <span class="fl">1e-2</span>, log<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb6-109"><a href="#cb6-109" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-110"><a href="#cb6-110" aria-hidden="true" tabindex="-1"></a>    optimizer <span class="op">=</span> torch.optim.Adam(model.parameters(), lr<span class="op">=</span>lr, weight_decay<span class="op">=</span>weight_decay)</span>
<span id="cb6-111"><a href="#cb6-111" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-112"><a href="#cb6-112" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Data augmentation search</span></span>
<span id="cb6-113"><a href="#cb6-113" aria-hidden="true" tabindex="-1"></a>    use_cutmix <span class="op">=</span> trial.suggest_categorical(<span class="st">'use_cutmix'</span>, [<span class="va">True</span>, <span class="va">False</span>])</span>
<span id="cb6-114"><a href="#cb6-114" aria-hidden="true" tabindex="-1"></a>    use_mixup <span class="op">=</span> trial.suggest_categorical(<span class="st">'use_mixup'</span>, [<span class="va">True</span>, <span class="va">False</span>])</span>
<span id="cb6-115"><a href="#cb6-115" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-116"><a href="#cb6-116" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Train and evaluate; MedianPruner only acts if the helper calls trial.report() and trial.should_prune() each epoch</span></span>
<span id="cb6-117"><a href="#cb6-117" aria-hidden="true" tabindex="-1"></a>    accuracy <span class="op">=</span> train_model_with_augmentation(model, optimizer, use_cutmix, use_mixup, trial)</span>
<span id="cb6-118"><a href="#cb6-118" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-119"><a href="#cb6-119" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> accuracy</span>
<span id="cb6-120"><a href="#cb6-120" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-121"><a href="#cb6-121" aria-hidden="true" tabindex="-1"></a><span class="co"># Run NAS</span></span>
<span id="cb6-122"><a href="#cb6-122" aria-hidden="true" tabindex="-1"></a>study <span class="op">=</span> optuna.create_study(direction<span class="op">=</span><span class="st">'maximize'</span>, </span>
<span id="cb6-123"><a href="#cb6-123" aria-hidden="true" tabindex="-1"></a>                           pruner<span class="op">=</span>optuna.pruners.MedianPruner())</span>
<span id="cb6-124"><a href="#cb6-124" aria-hidden="true" tabindex="-1"></a>study.optimize(nas_objective, n_trials<span class="op">=</span><span class="dv">200</span>)</span></code></pre></div></div>
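<p>The <code>MedianPruner</code> attached to this study can only act on values that the objective reports with <code>trial.report(value, step)</code> and checks with <code>trial.should_prune()</code>. According to Optuna's documentation, <code>MedianPruner</code> prunes a trial when its best intermediate result is worse than the median of earlier trials' intermediate results at the same step. The standard-library sketch below re-implements that rule for a maximized metric to make it concrete. The function name, the list-of-lists history format, and the toy numbers are assumptions for illustration, not Optuna's internals.</p>
<div class="sourceCode" id="cb6-sketch"><pre class="sourceCode python"><code class="sourceCode python">from statistics import median


def median_rule_prunes(previous_trials, current_values, n_startup_trials=5):
    """Median stopping rule for a maximizing objective (illustrative sketch).

    previous_trials: one list per finished trial, holding that trial's
        intermediate values indexed by step.
    current_values: intermediate values the running trial has reported so far
        (at least one).
    """
    step = len(current_values) - 1
    # Values that earlier trials reported at this same step
    peers = [values[step] for values in previous_trials if len(values) > step]
    # Too little history: keep training rather than prune on noise
    if n_startup_trials > len(previous_trials) or not peers:
        return False
    # Prune when even this trial's best value so far trails the peers' median
    return median(peers) > max(current_values)


history = [
    [0.50, 0.60, 0.70],
    [0.40, 0.50, 0.60],
    [0.60, 0.70, 0.80],
    [0.30, 0.40, 0.50],
    [0.50, 0.55, 0.65],
]
print(median_rule_prunes(history, [0.20, 0.30]))  # True: lagging trial
print(median_rule_prunes(history, [0.20, 0.60]))  # False: competitive trial
</code></pre></div>
<p>In a real Optuna objective the equivalent is to call <code>trial.report(accuracy, epoch)</code> after every epoch and raise <code>optuna.TrialPruned()</code> whenever <code>trial.should_prune()</code> returns <code>True</code>; the pruner then applies its rule against the study's stored history.</p>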
</section>
<section id="advanced-nas-with-weight-sharing" class="level3">
<h3 class="anchored" data-anchor-id="advanced-nas-with-weight-sharing" id="advanced-nas-with-weight-sharing">Advanced NAS with Weight Sharing</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> SuperNet(nn.Module):</span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Supernet holding every candidate operation for each layer position"""</span></span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, num_classes<span class="op">=</span><span class="dv">10</span>, max_blocks<span class="op">=</span><span class="dv">8</span>):</span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a>        <span class="co"># One set of candidate operations per layer; sampled architectures share them</span></span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.operations <span class="op">=</span> nn.ModuleList([nn.ModuleDict({</span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a>            <span class="st">'conv3x3'</span>: nn.Conv2d(<span class="dv">64</span>, <span class="dv">64</span>, <span class="dv">3</span>, padding<span class="op">=</span><span class="dv">1</span>),</span>
<span id="cb7-10"><a href="#cb7-10" aria-hidden="true" tabindex="-1"></a>            <span class="st">'conv5x5'</span>: nn.Conv2d(<span class="dv">64</span>, <span class="dv">64</span>, <span class="dv">5</span>, padding<span class="op">=</span><span class="dv">2</span>),</span>
<span id="cb7-11"><a href="#cb7-11" aria-hidden="true" tabindex="-1"></a>            <span class="st">'conv7x7'</span>: nn.Conv2d(<span class="dv">64</span>, <span class="dv">64</span>, <span class="dv">7</span>, padding<span class="op">=</span><span class="dv">3</span>),</span>
<span id="cb7-12"><a href="#cb7-12" aria-hidden="true" tabindex="-1"></a>            <span class="st">'depthwise_conv'</span>: nn.Sequential(</span>
<span id="cb7-13"><a href="#cb7-13" aria-hidden="true" tabindex="-1"></a>                nn.Conv2d(<span class="dv">64</span>, <span class="dv">64</span>, <span class="dv">3</span>, padding<span class="op">=</span><span class="dv">1</span>, groups<span class="op">=</span><span class="dv">64</span>),</span>
<span id="cb7-14"><a href="#cb7-14" aria-hidden="true" tabindex="-1"></a>                nn.Conv2d(<span class="dv">64</span>, <span class="dv">64</span>, <span class="dv">1</span>)</span>
<span id="cb7-15"><a href="#cb7-15" aria-hidden="true" tabindex="-1"></a>            ),</span>
<span id="cb7-16"><a href="#cb7-16" aria-hidden="true" tabindex="-1"></a>            <span class="st">'skip_connect'</span>: nn.Identity()</span>
<span id="cb7-17"><a href="#cb7-17" aria-hidden="true" tabindex="-1"></a>        }) <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(max_blocks)])</span>
<span id="cb7-18"><a href="#cb7-18" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-19"><a href="#cb7-19" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.stem <span class="op">=</span> nn.Conv2d(<span class="dv">3</span>, <span class="dv">64</span>, <span class="dv">3</span>, padding<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb7-20"><a href="#cb7-20" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.classifier <span class="op">=</span> nn.Linear(<span class="dv">64</span>, num_classes)</span>
<span id="cb7-21"><a href="#cb7-21" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.global_pool <span class="op">=</span> nn.AdaptiveAvgPool2d(<span class="dv">1</span>)</span>
<span id="cb7-22"><a href="#cb7-22" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-23"><a href="#cb7-23" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x, architecture):</span>
<span id="cb7-24"><a href="#cb7-24" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Forward pass with specific architecture"""</span></span>
<span id="cb7-25"><a href="#cb7-25" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.stem(x)</span>
<span id="cb7-26"><a href="#cb7-26" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-27"><a href="#cb7-27" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> i, op_name <span class="kw">in</span> <span class="bu">enumerate</span>(architecture):</span>
<span id="cb7-28"><a href="#cb7-28" aria-hidden="true" tabindex="-1"></a>            x <span class="op">=</span> <span class="va">self</span>.operations[i][op_name](x)</span>
<span id="cb7-29"><a href="#cb7-29" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> i <span class="op">%</span> <span class="dv">2</span> <span class="op">==</span> <span class="dv">0</span>:  <span class="co"># Add downsampling periodically</span></span>
<span id="cb7-30"><a href="#cb7-30" aria-hidden="true" tabindex="-1"></a>                x <span class="op">=</span> F.max_pool2d(x, <span class="dv">2</span>)</span>
<span id="cb7-31"><a href="#cb7-31" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-32"><a href="#cb7-32" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.global_pool(x)</span>
<span id="cb7-33"><a href="#cb7-33" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> x.view(x.size(<span class="dv">0</span>), <span class="op">-</span><span class="dv">1</span>)</span>
<span id="cb7-34"><a href="#cb7-34" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.classifier(x)</span>
<span id="cb7-35"><a href="#cb7-35" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x</span>
<span id="cb7-36"><a href="#cb7-36" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-37"><a href="#cb7-37" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> progressive_nas_objective(trial):</span>
<span id="cb7-38"><a href="#cb7-38" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""One-shot NAS: score sampled architectures using shared supernet weights"""</span></span>
<span id="cb7-39"><a href="#cb7-39" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-40"><a href="#cb7-40" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Sample architecture</span></span>
<span id="cb7-41"><a href="#cb7-41" aria-hidden="true" tabindex="-1"></a>    num_blocks <span class="op">=</span> trial.suggest_int(<span class="st">'num_blocks'</span>, <span class="dv">4</span>, <span class="dv">8</span>)</span>
<span id="cb7-42"><a href="#cb7-42" aria-hidden="true" tabindex="-1"></a>    architecture <span class="op">=</span> []</span>
<span id="cb7-43"><a href="#cb7-43" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-44"><a href="#cb7-44" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(num_blocks):</span>
<span id="cb7-45"><a href="#cb7-45" aria-hidden="true" tabindex="-1"></a>        op <span class="op">=</span> trial.suggest_categorical(<span class="ss">f'op_</span><span class="sc">{</span>i<span class="sc">}</span><span class="ss">'</span>, [</span>
<span id="cb7-46"><a href="#cb7-46" aria-hidden="true" tabindex="-1"></a>            <span class="st">'conv3x3'</span>, <span class="st">'conv5x5'</span>, <span class="st">'conv7x7'</span>, <span class="st">'depthwise_conv'</span>, <span class="st">'skip_connect'</span></span>
<span id="cb7-47"><a href="#cb7-47" aria-hidden="true" tabindex="-1"></a>        ])</span>
<span id="cb7-48"><a href="#cb7-48" aria-hidden="true" tabindex="-1"></a>        architecture.append(op)</span>
<span id="cb7-49"><a href="#cb7-49" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-50"><a href="#cb7-50" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Create supernet (shared across trials)</span></span>
<span id="cb7-51"><a href="#cb7-51" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> <span class="kw">not</span> <span class="bu">hasattr</span>(progressive_nas_objective, <span class="st">'supernet'</span>):</span>
<span id="cb7-52"><a href="#cb7-52" aria-hidden="true" tabindex="-1"></a>        progressive_nas_objective.supernet <span class="op">=</span> SuperNet()</span>
<span id="cb7-53"><a href="#cb7-53" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-54"><a href="#cb7-54" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> progressive_nas_objective.supernet</span>
<span id="cb7-55"><a href="#cb7-55" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-56"><a href="#cb7-56" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Training with early stopping</span></span>
<span id="cb7-57"><a href="#cb7-57" aria-hidden="true" tabindex="-1"></a>    accuracy <span class="op">=</span> train_with_early_stopping(model, architecture, trial)</span>
<span id="cb7-58"><a href="#cb7-58" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-59"><a href="#cb7-59" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> accuracy</span></code></pre></div></div>
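<p>Weight sharing matters because the space searched by <code>progressive_nas_objective</code> grows exponentially with depth: <code>num_blocks</code> ranges from 4 to 8 and each block picks one of five operations, giving 5<sup>4</sup> + 5<sup>5</sup> + 5<sup>6</sup> + 5<sup>7</sup> + 5<sup>8</sup> candidate architectures. The short computation below checks that count; the helper name is illustrative. Training each candidate from scratch would be impractical, which is why a supernet amortizes training across all of them, at the cost that inherited weights typically only approximate each architecture's stand-alone accuracy.</p>
<div class="sourceCode" id="cb7-sketch"><pre class="sourceCode python"><code class="sourceCode python"># Five candidate operations, matching the list in progressive_nas_objective
CANDIDATE_OPS = ['conv3x3', 'conv5x5', 'conv7x7', 'depthwise_conv', 'skip_connect']


def search_space_size(num_ops, min_blocks, max_blocks):
    """Count distinct operation sequences over every allowed depth."""
    return sum(num_ops ** depth for depth in range(min_blocks, max_blocks + 1))


total = search_space_size(len(CANDIDATE_OPS), 4, 8)
print(f'{total:,} architectures')  # 488,125 architectures
</code></pre></div>
<p>Depth 8 alone contributes 390,625 of those sequences, so even a generous budget such as a few hundred Optuna trials samples only a tiny fraction of the space.</p>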
</section>
</section>
<section id="distributed-optimization" class="level2">
<h2 class="anchored" data-anchor-id="distributed-optimization" id="distributed-optimization">Distributed Optimization</h2>
<section id="multi-gpu-training-with-optuna" class="level3">
<h3 class="anchored" data-anchor-id="multi-gpu-training-with-optuna" id="multi-gpu-training-with-optuna">Multi-GPU Training with Optuna</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb8"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> optuna</span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> optuna.integration <span class="im">import</span> PyTorchLightningPruningCallback</span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> pytorch_lightning <span class="im">as</span> pl</span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> LightningModel(pl.LightningModule):</span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, trial):</span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.trial <span class="op">=</span> trial</span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Training hyperparameters</span></span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.lr <span class="op">=</span> trial.suggest_float(<span class="st">'lr'</span>, <span class="fl">1e-5</span>, <span class="fl">1e-1</span>, log<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb8-12"><a href="#cb8-12" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.batch_size <span class="op">=</span> trial.suggest_categorical(<span class="st">'batch_size'</span>, [<span class="dv">16</span>, <span class="dv">32</span>, <span class="dv">64</span>, <span class="dv">128</span>])</span>
<span id="cb8-13"><a href="#cb8-13" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-14"><a href="#cb8-14" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Model definition</span></span>
<span id="cb8-15"><a href="#cb8-15" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> <span class="va">self</span>.build_model(trial)</span>
<span id="cb8-16"><a href="#cb8-16" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.criterion <span class="op">=</span> nn.CrossEntropyLoss()</span>
<span id="cb8-17"><a href="#cb8-17" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-18"><a href="#cb8-18" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> build_model(<span class="va">self</span>, trial):</span>
<span id="cb8-19"><a href="#cb8-19" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Placeholder: build an nn.Module from trial.suggest_* calls and return it</span></span>
<span id="cb8-20"><a href="#cb8-20" aria-hidden="true" tabindex="-1"></a>        <span class="cf">raise</span> NotImplementedError</span>
<span id="cb8-21"><a href="#cb8-21" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-22"><a href="#cb8-22" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> training_step(<span class="va">self</span>, batch, batch_idx):</span>
<span id="cb8-23"><a href="#cb8-23" aria-hidden="true" tabindex="-1"></a>        x, y <span class="op">=</span> batch</span>
<span id="cb8-24"><a href="#cb8-24" aria-hidden="true" tabindex="-1"></a>        y_hat <span class="op">=</span> <span class="va">self</span>.model(x)</span>
<span id="cb8-25"><a href="#cb8-25" aria-hidden="true" tabindex="-1"></a>        loss <span class="op">=</span> <span class="va">self</span>.criterion(y_hat, y)</span>
<span id="cb8-26"><a href="#cb8-26" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.log(<span class="st">'train_loss'</span>, loss)</span>
<span id="cb8-27"><a href="#cb8-27" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> loss</span>
<span id="cb8-28"><a href="#cb8-28" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-29"><a href="#cb8-29" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> validation_step(<span class="va">self</span>, batch, batch_idx):</span>
<span id="cb8-30"><a href="#cb8-30" aria-hidden="true" tabindex="-1"></a>        x, y <span class="op">=</span> batch</span>
<span id="cb8-31"><a href="#cb8-31" aria-hidden="true" tabindex="-1"></a>        y_hat <span class="op">=</span> <span class="va">self</span>.model(x)</span>
<span id="cb8-32"><a href="#cb8-32" aria-hidden="true" tabindex="-1"></a>        loss <span class="op">=</span> <span class="va">self</span>.criterion(y_hat, y)</span>
<span id="cb8-33"><a href="#cb8-33" aria-hidden="true" tabindex="-1"></a>        acc <span class="op">=</span> (y_hat.argmax(dim<span class="op">=</span><span class="dv">1</span>) <span class="op">==</span> y).<span class="bu">float</span>().mean()</span>
<span id="cb8-34"><a href="#cb8-34" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.log(<span class="st">'val_loss'</span>, loss, sync_dist<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb8-35"><a href="#cb8-35" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.log(<span class="st">'val_acc'</span>, acc, sync_dist<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb8-36"><a href="#cb8-36" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> loss</span>
<span id="cb8-37"><a href="#cb8-37" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-38"><a href="#cb8-38" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> configure_optimizers(<span class="va">self</span>):</span>
<span id="cb8-39"><a href="#cb8-39" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> torch.optim.Adam(<span class="va">self</span>.parameters(), lr<span class="op">=</span><span class="va">self</span>.lr)</span>
<span id="cb8-40"><a href="#cb8-40" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-41"><a href="#cb8-41" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> distributed_objective(trial):</span>
<span id="cb8-42"><a href="#cb8-42" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> LightningModel(trial)</span>
<span id="cb8-43"><a href="#cb8-43" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-44"><a href="#cb8-44" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Pruning callback</span></span>
<span id="cb8-45"><a href="#cb8-45" aria-hidden="true" tabindex="-1"></a>    pruning_callback <span class="op">=</span> PyTorchLightningPruningCallback(trial, monitor<span class="op">=</span><span class="st">'val_acc'</span>)</span>
<span id="cb8-46"><a href="#cb8-46" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-47"><a href="#cb8-47" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Multi-GPU trainer</span></span>
<span id="cb8-48"><a href="#cb8-48" aria-hidden="true" tabindex="-1"></a>    trainer <span class="op">=</span> pl.Trainer(</span>
<span id="cb8-49"><a href="#cb8-49" aria-hidden="true" tabindex="-1"></a>        accelerator<span class="op">=</span><span class="st">'gpu'</span>, devices<span class="op">=</span><span class="dv">4</span>,</span>
<span id="cb8-50"><a href="#cb8-50" aria-hidden="true" tabindex="-1"></a>        strategy<span class="op">=</span><span class="st">'ddp'</span>,</span>
<span id="cb8-51"><a href="#cb8-51" aria-hidden="true" tabindex="-1"></a>        max_epochs<span class="op">=</span><span class="dv">50</span>,</span>
<span id="cb8-52"><a href="#cb8-52" aria-hidden="true" tabindex="-1"></a>        callbacks<span class="op">=</span>[pruning_callback],</span>
<span id="cb8-53"><a href="#cb8-53" aria-hidden="true" tabindex="-1"></a>        enable_checkpointing<span class="op">=</span><span class="va">False</span></span>
<span id="cb8-54"><a href="#cb8-54" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb8-55"><a href="#cb8-55" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-56"><a href="#cb8-56" aria-hidden="true" tabindex="-1"></a>    trainer.fit(model, train_dataloader, val_dataloader)</span>
<span id="cb8-57"><a href="#cb8-57" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-58"><a href="#cb8-58" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> trainer.callback_metrics[<span class="st">'val_acc'</span>].item()</span>
<span id="cb8-59"><a href="#cb8-59" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-60"><a href="#cb8-60" aria-hidden="true" tabindex="-1"></a><span class="co"># Distributed study</span></span>
<span id="cb8-61"><a href="#cb8-61" aria-hidden="true" tabindex="-1"></a>study <span class="op">=</span> optuna.create_study(</span>
<span id="cb8-62"><a href="#cb8-62" aria-hidden="true" tabindex="-1"></a>    direction<span class="op">=</span><span class="st">'maximize'</span>,</span>
<span id="cb8-63"><a href="#cb8-63" aria-hidden="true" tabindex="-1"></a>    storage<span class="op">=</span><span class="st">'sqlite:///distributed_study.db'</span>,  <span class="co"># Shared storage (prefer PostgreSQL/MySQL for many workers)</span></span>
<span id="cb8-64"><a href="#cb8-64" aria-hidden="true" tabindex="-1"></a>    study_name<span class="op">=</span><span class="st">'distributed_nas'</span>, load_if_exists<span class="op">=</span><span class="va">True</span>  <span class="co"># Every worker attaches to the same study</span></span>
<span id="cb8-65"><a href="#cb8-65" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb8-66"><a href="#cb8-66" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-67"><a href="#cb8-67" aria-hidden="true" tabindex="-1"></a>study.optimize(distributed_objective, n_trials<span class="op">=</span><span class="dv">500</span>)  <span class="co"># n_trials counts per worker process</span></span></code></pre></div></div>
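<p>Each machine that runs this script calls <code>study.optimize(..., n_trials=500)</code>, and Optuna's documentation defines <code>n_trials</code> as the number of trials for each process, so four workers could run up to 2,000 trials in total. Optuna offers <code>optuna.study.MaxTrialsCallback</code> to cap the study-wide count. The standard-library sketch below illustrates the underlying coordination idea: workers share one SQLite file and atomically claim trial numbers until a global budget is spent. The table layout and the <code>init_storage</code>, <code>claim_trial</code>, and <code>worker_loop</code> helpers are hypothetical and do not mirror Optuna's storage schema; threads stand in for separate worker processes or machines.</p>
<div class="sourceCode" id="cb8-sketch"><pre class="sourceCode python"><code class="sourceCode python">import os
import sqlite3
import tempfile
import threading


def init_storage(path):
    """Create the (hypothetical) shared table that records claimed trials."""
    conn = sqlite3.connect(path)
    conn.execute('CREATE TABLE IF NOT EXISTS trials (number INTEGER PRIMARY KEY, worker TEXT)')
    conn.commit()
    conn.close()


def claim_trial(path, worker, max_trials):
    """Atomically reserve the next trial number, or return None at the cap."""
    conn = sqlite3.connect(path, timeout=30, isolation_level=None)
    try:
        conn.execute('BEGIN IMMEDIATE')  # take the write lock before counting
        (count,) = conn.execute('SELECT COUNT(*) FROM trials').fetchone()
        if count >= max_trials:
            conn.execute('ROLLBACK')
            return None
        conn.execute('INSERT INTO trials (number, worker) VALUES (?, ?)', (count, worker))
        conn.execute('COMMIT')
        return count
    finally:
        conn.close()


def worker_loop(path, worker, max_trials, claimed):
    """Keep claiming trials until the study-wide budget is used up."""
    while True:
        number = claim_trial(path, worker, max_trials)
        if number is None:
            break
        claimed.append((worker, number))  # a real worker would train here


path = os.path.join(tempfile.mkdtemp(), 'study.db')
init_storage(path)
claimed = []
workers = [
    threading.Thread(target=worker_loop, args=(path, f'worker-{i}', 20, claimed))
    for i in range(4)
]
for t in workers:
    t.start()
for t in workers:
    t.join()
print(len(claimed))  # 20
</code></pre></div>
<p>Because SQLite serializes writers through a single file lock, contention grows with the number of workers; Optuna's documentation likewise advises a server database such as PostgreSQL or MySQL for large parallel runs.</p>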
</section>
</section>
<section id="advanced-techniques" class="level2">
<h2 class="anchored" data-anchor-id="advanced-techniques" id="advanced-techniques">Advanced Techniques</h2>
<section id="hyperband-integration" class="level3">
<h3 class="anchored" data-anchor-id="hyperband-integration" id="hyperband-integration">Hyperband Integration</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb9"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Optuna ships a Hyperband pruner, so no custom BasePruner subclass is needed</span></span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a>pruner <span class="op">=</span> optuna.pruners.HyperbandPruner(</span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a>    min_resource<span class="op">=</span><span class="dv">1</span>, max_resource<span class="op">=</span><span class="dv">81</span>, reduction_factor<span class="op">=</span><span class="dv">3</span></span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> hyperband_objective(trial):</span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Hyperband allocates the epoch budget itself by pruning, so the</span></span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a>    <span class="co"># number of epochs is not suggested as a hyperparameter</span></span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> create_model(trial)</span>
<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a>    accuracy <span class="op">=</span> <span class="fl">0.0</span></span>
<span id="cb9-11"><a href="#cb9-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-12"><a href="#cb9-12" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">81</span>):  <span class="co"># at most max_resource epochs</span></span>
<span id="cb9-13"><a href="#cb9-13" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Assumes train_model continues from the model's current weights</span></span>
<span id="cb9-14"><a href="#cb9-14" aria-hidden="true" tabindex="-1"></a>        accuracy <span class="op">=</span> train_model(model, epochs<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb9-15"><a href="#cb9-15" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-16"><a href="#cb9-16" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Report progress; the step index is the resource Hyperband allocates</span></span>
<span id="cb9-17"><a href="#cb9-17" aria-hidden="true" tabindex="-1"></a>        trial.report(accuracy, step<span class="op">=</span>epoch)</span>
<span id="cb9-18"><a href="#cb9-18" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> trial.should_prune():</span>
<span id="cb9-19"><a href="#cb9-19" aria-hidden="true" tabindex="-1"></a>            <span class="cf">raise</span> optuna.TrialPruned()</span>
<span id="cb9-20"><a href="#cb9-20" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-21"><a href="#cb9-21" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> accuracy</span>
<span id="cb9-22"><a href="#cb9-22" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-23"><a href="#cb9-23" aria-hidden="true" tabindex="-1"></a>study <span class="op">=</span> optuna.create_study(</span>
<span id="cb9-24"><a href="#cb9-24" aria-hidden="true" tabindex="-1"></a>    direction<span class="op">=</span><span class="st">'maximize'</span>,</span>
<span id="cb9-25"><a href="#cb9-25" aria-hidden="true" tabindex="-1"></a>    pruner<span class="op">=</span>pruner</span>
<span id="cb9-26"><a href="#cb9-26" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
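<p>The Hyperband algorithm hedges against committing to a single budget by running several brackets of successive halving. Each bracket starts some number of configurations on a small budget, keeps the top 1/<code>reduction_factor</code> at each rung, and multiplies the survivors' budget by <code>reduction_factor</code>. The function below computes the textbook bracket schedule for <code>max_resource=81</code> and <code>reduction_factor=3</code> (with <code>min_resource=1</code>, so one unit is one epoch). Optuna's <code>HyperbandPruner</code> builds its brackets from asynchronous successive halving, so treat this as the idealized schedule rather than Optuna's exact bookkeeping; the function name is illustrative.</p>
<div class="sourceCode" id="cb9-sketch"><pre class="sourceCode python"><code class="sourceCode python">def hyperband_schedule(max_resource=81, eta=3):
    """Return, per bracket, the (configurations, epochs) pair at each rung."""
    # s_max = floor(log_eta(max_resource)), using integers to avoid float error
    s_max = 0
    while max_resource >= eta ** (s_max + 1):
        s_max += 1

    brackets = []
    for s in range(s_max, -1, -1):
        # Start ceil((s_max + 1) / (s + 1) * eta**s) configurations...
        n = ((s_max + 1) * eta ** s + s) // (s + 1)
        # ...each trained for max_resource / eta**s epochs
        r = max_resource // eta ** s
        # Successive halving: keep 1/eta of them, give survivors eta times more
        brackets.append([(n // eta ** i, r * eta ** i) for i in range(s + 1)])
    return brackets


for rungs in hyperband_schedule():
    print(rungs)
# [(81, 1), (27, 3), (9, 9), (3, 27), (1, 81)]
# [(34, 3), (11, 9), (3, 27), (1, 81)]
# [(15, 9), (5, 27), (1, 81)]
# [(8, 27), (2, 81)]
# [(5, 81)]
</code></pre></div>
<p>The first bracket is the most aggressive (81 configurations, one epoch each) and the last is the most conservative (5 configurations trained for all 81 epochs); running every bracket protects against early results that are misleading for slow-starting configurations.</p>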
</section>
<section id="population-based-training" class="level3">
<h3 class="anchored" data-anchor-id="population-based-training" id="population-based-training">Population-Based Training</h3>
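<p>The evolutionary loop in this section calls a <code>mutate_hyperparameters</code> helper that the post does not define. One possible version is sketched below, assuming each population member is a dict of hyperparameters or an object exposing a <code>.params</code> dict. It scales the learning rate and weight decay by 0.8 or 1.2, the perturbation factors used in the original Population-Based Training paper, and occasionally resamples the batch size; the resample probability and keyword names are assumptions. Full PBT also copies the parent's model weights into the child ("exploit") before perturbing, which this hyperparameter-only sketch leaves out.</p>
<div class="sourceCode" id="cb10-sketch"><pre class="sourceCode python"><code class="sourceCode python">import random


def mutate_hyperparameters(parent, factors=(0.8, 1.2),
                           batch_sizes=(16, 32, 64, 128),
                           resample_prob=0.25, rng=random):
    """Return a perturbed copy of a parent's hyperparameters (hypothetical helper)."""
    # Accept a plain dict or an object exposing .params (e.g. an Optuna FrozenTrial)
    child = dict(getattr(parent, 'params', parent))
    # Explore: scale continuous values up or down by a random factor
    for key in ('lr', 'weight_decay'):
        if key in child:
            child[key] *= rng.choice(factors)
    # Occasionally resample the categorical batch size
    if 'batch_size' in child and resample_prob > rng.random():
        child['batch_size'] = rng.choice(batch_sizes)
    return child


rng = random.Random(0)
parent = {'lr': 0.01, 'batch_size': 32, 'weight_decay': 1e-4}
print(mutate_hyperparameters(parent, rng=rng))
</code></pre></div>
<p>Multiplicative perturbation keeps the learning rate and weight decay on the same log scale they were sampled on, and copying the parent dict leaves the surviving parent unchanged in the next generation.</p>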
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb10"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> population_based_optimization():</span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a>    population_size <span class="op">=</span> <span class="dv">20</span></span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a>    generations <span class="op">=</span> <span class="dv">10</span></span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Initialize population: each member is a dict of hyperparameters</span></span>
<span id="cb10-6"><a href="#cb10-6" aria-hidden="true" tabindex="-1"></a>    population <span class="op">=</span> []</span>
<span id="cb10-7"><a href="#cb10-7" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(population_size):</span>
<span id="cb10-8"><a href="#cb10-8" aria-hidden="true" tabindex="-1"></a>        params <span class="op">=</span> {</span>
<span id="cb10-9"><a href="#cb10-9" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Sample lr and weight decay log-uniformly, since their useful</span></span>
<span id="cb10-10"><a href="#cb10-10" aria-hidden="true" tabindex="-1"></a>            <span class="co"># values span several orders of magnitude</span></span>
<span id="cb10-11"><a href="#cb10-11" aria-hidden="true" tabindex="-1"></a>            <span class="st">'lr'</span>: <span class="dv">10</span> <span class="op">**</span> np.random.uniform(<span class="op">-</span><span class="dv">5</span>, <span class="op">-</span><span class="dv">1</span>),</span>
<span id="cb10-12"><a href="#cb10-12" aria-hidden="true" tabindex="-1"></a>            <span class="st">'batch_size'</span>: <span class="bu">int</span>(np.random.choice([<span class="dv">16</span>, <span class="dv">32</span>, <span class="dv">64</span>, <span class="dv">128</span>])),</span>
<span id="cb10-13"><a href="#cb10-13" aria-hidden="true" tabindex="-1"></a>            <span class="st">'weight_decay'</span>: <span class="dv">10</span> <span class="op">**</span> np.random.uniform(<span class="op">-</span><span class="dv">6</span>, <span class="op">-</span><span class="dv">2</span>)</span>
<span id="cb10-14"><a href="#cb10-14" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb10-15"><a href="#cb10-15" aria-hidden="true" tabindex="-1"></a>        population.append(params)</span>
<span id="cb10-16"><a href="#cb10-16" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-17"><a href="#cb10-17" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> generation <span class="kw">in</span> <span class="bu">range</span>(generations):</span>
<span id="cb10-18"><a href="#cb10-18" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Evaluate population</span></span>
<span id="cb10-19"><a href="#cb10-19" aria-hidden="true" tabindex="-1"></a>        fitness_scores <span class="op">=</span> []</span>
<span id="cb10-20"><a href="#cb10-20" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> params <span class="kw">in</span> population:</span>
<span id="cb10-21"><a href="#cb10-21" aria-hidden="true" tabindex="-1"></a>            model <span class="op">=</span> create_model(params)</span>
<span id="cb10-22"><a href="#cb10-22" aria-hidden="true" tabindex="-1"></a>            score <span class="op">=</span> train_and_evaluate(model, params)</span>
<span id="cb10-23"><a href="#cb10-23" aria-hidden="true" tabindex="-1"></a>            fitness_scores.append(score)</span>
<span id="cb10-24"><a href="#cb10-24" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb10-25"><a href="#cb10-25" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Select top performers</span></span>
<span id="cb10-26"><a href="#cb10-26" aria-hidden="true" tabindex="-1"></a>        top_indices <span class="op">=</span> np.argsort(fitness_scores)[<span class="op">-</span>population_size<span class="op">//</span><span class="dv">2</span>:]</span>
<span id="cb10-27"><a href="#cb10-27" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb10-28"><a href="#cb10-28" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Create new population</span></span>
<span id="cb10-29"><a href="#cb10-29" aria-hidden="true" tabindex="-1"></a>        new_population <span class="op">=</span> []</span>
<span id="cb10-30"><a href="#cb10-30" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> idx <span class="kw">in</span> top_indices:</span>
<span id="cb10-31"><a href="#cb10-31" aria-hidden="true" tabindex="-1"></a>            new_population.append(population[idx])</span>
<span id="cb10-32"><a href="#cb10-32" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb10-33"><a href="#cb10-33" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Mutate and add to population</span></span>
<span id="cb10-34"><a href="#cb10-34" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(population_size <span class="op">-</span> <span class="bu">len</span>(new_population)):</span>
<span id="cb10-35"><a href="#cb10-35" aria-hidden="true" tabindex="-1"></a>            parent <span class="op">=</span> new_population[np.random.randint(<span class="bu">len</span>(new_population))]</span>
<span id="cb10-36"><a href="#cb10-36" aria-hidden="true" tabindex="-1"></a>            child <span class="op">=</span> mutate_hyperparameters(parent)</span>
<span id="cb10-37"><a href="#cb10-37" aria-hidden="true" tabindex="-1"></a>            new_population.append(child)</span>
<span id="cb10-38"><a href="#cb10-38" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb10-39"><a href="#cb10-39" aria-hidden="true" tabindex="-1"></a>        population <span class="op">=</span> new_population</span>
<span id="cb10-40"><a href="#cb10-40" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-41"><a href="#cb10-41" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> population</span></code></pre></div></div>
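<p>The loop above relies on helpers that this block does not show (<code>create_model</code>, <code>train_and_evaluate</code>, <code>mutate_hyperparameters</code>). To see the select-and-mutate cycle run end to end, here is a minimal standard-library sketch. The toy fitness function, the two-parameter search space, and the log-space Gaussian mutation are assumptions made for illustration. In practice, <code>toy_fitness</code> would be replaced by real training and validation.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import math
import random

def toy_fitness(params):
    """Stand-in for train_and_evaluate: peaks at lr=1e-3, weight_decay=1e-4."""
    return -((math.log10(params['lr']) + 3) ** 2
             + (math.log10(params['weight_decay']) + 4) ** 2)

def random_params(rng):
    # Log-uniform sampling: each decade of the range is equally likely
    return {
        'lr': 10 ** rng.uniform(-5, -1),            # [1e-5, 1e-1]
        'weight_decay': 10 ** rng.uniform(-6, -2),  # [1e-6, 1e-2]
    }

def clip(x, low, high):
    return min(max(x, low), high)

def mutate(params, rng, scale=0.3):
    """Multiply each value by 10**N(0, scale), i.e. a Gaussian step in log10 space."""
    return {
        'lr': clip(params['lr'] * 10 ** rng.gauss(0, scale), 1e-5, 1e-1),
        'weight_decay': clip(params['weight_decay'] * 10 ** rng.gauss(0, scale), 1e-6, 1e-2),
    }

def evolve(fitness, population_size=20, generations=15, seed=0):
    rng = random.Random(seed)
    population = [random_params(rng) for _ in range(population_size)]
    for _ in range(generations):
        # Rank by fitness (higher is better) and keep the top half unchanged
        parents = sorted(population, key=fitness, reverse=True)[:population_size // 2]
        # Refill the population with mutated copies of randomly chosen parents
        children = [mutate(rng.choice(parents), rng)
                    for _ in range(population_size - len(parents))]
        population = parents + children
    return max(population, key=fitness)

best = evolve(toy_fitness)
print(best, toy_fitness(best))
</code></pre></div></div>
<p>Because the top half survives unchanged into the next generation (elitism), the best fitness in the population never decreases. Mutating in log space keeps step sizes proportional for parameters such as learning rate and weight decay, which span several orders of magnitude.</p>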
</section>
</section>
<section id="visualization-and-analysis" class="level2">
<h2 class="anchored" data-anchor-id="visualization-and-analysis" id="visualization-and-analysis">Visualization and Analysis</h2>
<section id="study-analysis" class="level3">
<h3 class="anchored" data-anchor-id="study-analysis" id="study-analysis">Study Analysis</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb11"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Basic study analysis</span></span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Number of finished trials: </span><span class="sc">{</span><span class="bu">len</span>(study.trials)<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Best trial: </span><span class="sc">{</span>study<span class="sc">.</span>best_trial<span class="sc">.</span>number<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Best value: </span><span class="sc">{</span>study<span class="sc">.</span>best_value<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Best parameters: </span><span class="sc">{</span>study<span class="sc">.</span>best_params<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Parameter importance</span></span>
<span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a>importance <span class="op">=</span> optuna.importance.get_param_importances(study)</span>
<span id="cb11-9"><a href="#cb11-9" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"Parameter importance:"</span>)</span>
<span id="cb11-10"><a href="#cb11-10" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> param, imp <span class="kw">in</span> importance.items():</span>
<span id="cb11-11"><a href="#cb11-11" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"  </span><span class="sc">{</span>param<span class="sc">}</span><span class="ss">: </span><span class="sc">{</span>imp<span class="sc">:.4f}</span><span class="ss">"</span>)</span>
<span id="cb11-12"><a href="#cb11-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-13"><a href="#cb11-13" aria-hidden="true" tabindex="-1"></a><span class="co"># Visualization (optuna.visualization requires plotly)</span></span>
<span id="cb11-14"><a href="#cb11-14" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> optuna.visualization <span class="im">as</span> vis</span>
<span id="cb11-15"><a href="#cb11-15" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-16"><a href="#cb11-16" aria-hidden="true" tabindex="-1"></a><span class="co"># Optimization history</span></span>
<span id="cb11-17"><a href="#cb11-17" aria-hidden="true" tabindex="-1"></a>fig <span class="op">=</span> vis.plot_optimization_history(study)</span>
<span id="cb11-18"><a href="#cb11-18" aria-hidden="true" tabindex="-1"></a>fig.show()</span>
<span id="cb11-19"><a href="#cb11-19" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-20"><a href="#cb11-20" aria-hidden="true" tabindex="-1"></a><span class="co"># Parameter importance plot</span></span>
<span id="cb11-21"><a href="#cb11-21" aria-hidden="true" tabindex="-1"></a>fig <span class="op">=</span> vis.plot_param_importances(study)</span>
<span id="cb11-22"><a href="#cb11-22" aria-hidden="true" tabindex="-1"></a>fig.show()</span>
<span id="cb11-23"><a href="#cb11-23" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-24"><a href="#cb11-24" aria-hidden="true" tabindex="-1"></a><span class="co"># Parameter relationships</span></span>
<span id="cb11-25"><a href="#cb11-25" aria-hidden="true" tabindex="-1"></a>fig <span class="op">=</span> vis.plot_parallel_coordinate(study)</span>
<span id="cb11-26"><a href="#cb11-26" aria-hidden="true" tabindex="-1"></a>fig.show()</span>
<span id="cb11-27"><a href="#cb11-27" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-28"><a href="#cb11-28" aria-hidden="true" tabindex="-1"></a><span class="co"># Hyperparameter slice plot</span></span>
<span id="cb11-29"><a href="#cb11-29" aria-hidden="true" tabindex="-1"></a>fig <span class="op">=</span> vis.plot_slice(study)</span>
<span id="cb11-30"><a href="#cb11-30" aria-hidden="true" tabindex="-1"></a>fig.show()</span></code></pre></div></div>
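<p><code>plot_optimization_history</code> plots each trial's objective value along with the best value found so far. When no browser is available for <code>fig.show()</code>, the same best-so-far curve can be computed directly. Here is a minimal sketch, assuming <code>values</code> holds the objective values of completed trials in trial order (for example <code>[t.value for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]</code>). The numbers below are illustrative.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">from itertools import accumulate

def best_so_far(values, direction='maximize'):
    """Running best of objective values, in trial order."""
    pick = max if direction == 'maximize' else min
    return list(accumulate(values, pick))

# Objective values from five completed trials (illustrative numbers)
values = [0.61, 0.70, 0.66, 0.74, 0.72]
print(best_so_far(values))              # [0.61, 0.7, 0.7, 0.74, 0.74]
print(best_so_far(values, 'minimize'))  # [0.61, 0.61, 0.61, 0.61, 0.61]
</code></pre></div></div>
<p>A curve that has been flat for many trials is a sign that further trials in the current search space are unlikely to help much.</p>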
</section>
<section id="custom-metrics-tracking" class="level3">
<h3 class="anchored" data-anchor-id="custom-metrics-tracking" id="custom-metrics-tracking">Custom Metrics Tracking</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb12"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><a href="#cb12-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> CustomCallback:</span>
<span id="cb12-2"><a href="#cb12-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>):</span>
<span id="cb12-3"><a href="#cb12-3" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.metrics <span class="op">=</span> {}</span>
<span id="cb12-4"><a href="#cb12-4" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-5"><a href="#cb12-5" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__call__</span>(<span class="va">self</span>, study, trial):</span>
<span id="cb12-6"><a href="#cb12-6" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Track custom metrics</span></span>
<span id="cb12-7"><a href="#cb12-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.metrics[trial.number] <span class="op">=</span> {</span>
<span id="cb12-8"><a href="#cb12-8" aria-hidden="true" tabindex="-1"></a>            <span class="st">'params'</span>: trial.params,</span>
<span id="cb12-9"><a href="#cb12-9" aria-hidden="true" tabindex="-1"></a>            <span class="st">'value'</span>: trial.value,</span>
<span id="cb12-10"><a href="#cb12-10" aria-hidden="true" tabindex="-1"></a>            <span class="st">'state'</span>: trial.state,</span>
<span id="cb12-11"><a href="#cb12-11" aria-hidden="true" tabindex="-1"></a>            <span class="st">'duration'</span>: trial.duration</span>
<span id="cb12-12"><a href="#cb12-12" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb12-13"><a href="#cb12-13" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-14"><a href="#cb12-14" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Custom analysis</span></span>
<span id="cb12-15"><a href="#cb12-15" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="bu">len</span>(study.trials) <span class="op">%</span> <span class="dv">10</span> <span class="op">==</span> <span class="dv">0</span>:</span>
<span id="cb12-16"><a href="#cb12-16" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.analyze_progress(study)</span>
<span id="cb12-17"><a href="#cb12-17" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-18"><a href="#cb12-18" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> analyze_progress(<span class="va">self</span>, study):</span>
<span id="cb12-19"><a href="#cb12-19" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Convergence analysis (assumes direction='maximize')</span></span>
<span id="cb12-20"><a href="#cb12-20" aria-hidden="true" tabindex="-1"></a>        values <span class="op">=</span> [t.value <span class="cf">for</span> t <span class="kw">in</span> study.trials <span class="cf">if</span> t.state <span class="op">==</span> optuna.trial.TrialState.COMPLETE]</span>
<span id="cb12-21"><a href="#cb12-21" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="bu">len</span>(values) <span class="op">&gt;</span> <span class="dv">10</span>:</span>
<span id="cb12-22"><a href="#cb12-22" aria-hidden="true" tabindex="-1"></a>            improvement <span class="op">=</span> <span class="bu">max</span>(values) <span class="op">-</span> <span class="bu">max</span>(values[:<span class="op">-</span><span class="dv">10</span>])  <span class="co"># best now vs. best 10 trials ago</span></span>
<span id="cb12-23"><a href="#cb12-23" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f"Best-value improvement over last 10 trials: </span><span class="sc">{</span>improvement<span class="sc">:.4f}</span><span class="ss">"</span>)</span>
<span id="cb12-24"><a href="#cb12-24" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-25"><a href="#cb12-25" aria-hidden="true" tabindex="-1"></a><span class="co"># Use custom callback</span></span>
<span id="cb12-26"><a href="#cb12-26" aria-hidden="true" tabindex="-1"></a>callback <span class="op">=</span> CustomCallback()</span>
<span id="cb12-27"><a href="#cb12-27" aria-hidden="true" tabindex="-1"></a>study.optimize(objective, n_trials<span class="op">=</span><span class="dv">100</span>, callbacks<span class="op">=</span>[callback])</span></code></pre></div></div>
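<p>A callback can also end a study early. The sketch below is an illustration, not a built-in Optuna class. It stops optimization once <code>patience</code> consecutive completed trials fail to beat the best value so far. It assumes <code>direction='maximize'</code> and relies on <code>Study.stop()</code>, which lets running trials finish before <code>study.optimize()</code> returns.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">class StopWhenStalled:
    """Hypothetical early-stopping callback for a study with direction='maximize'.

    Stops the study once `patience` consecutive completed trials have failed
    to beat the best value seen so far.
    """

    def __init__(self, patience=20):
        self.patience = patience
        self.values = []  # objective values of completed trials, in order

    def __call__(self, study, trial):
        # Only completed trials count; pruned and failed trials are ignored
        if trial.state.name != 'COMPLETE':
            return
        self.values.append(trial.value)
        # Completed trials since the current best value was first reached
        since_best = len(self.values) - 1 - self.values.index(max(self.values))
        if since_best == self.patience:
            study.stop()  # running trials finish, then study.optimize() returns


# Usage with the study and objective from the earlier examples:
# study.optimize(objective, n_trials=500, callbacks=[StopWhenStalled(patience=30)])
</code></pre></div></div>
<p>A tie with the current best does not reset the counter, so the study stops only when the search has genuinely stalled.</p>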
</section>
</section>
<section id="best-practices-and-tips" class="level2">
<h2 class="anchored" data-anchor-id="best-practices-and-tips" id="best-practices-and-tips">Best Practices and Tips</h2>
<section id="study-configuration" class="level3">
<h3 class="anchored" data-anchor-id="study-configuration" id="study-configuration">Study Configuration</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb13"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Example study configuration: a reasonable starting point, not a universal optimum</span></span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a>study <span class="op">=</span> optuna.create_study(</span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a>    direction<span class="op">=</span><span class="st">'maximize'</span>,</span>
<span id="cb13-4"><a href="#cb13-4" aria-hidden="true" tabindex="-1"></a>    sampler<span class="op">=</span>optuna.samplers.TPESampler(</span>
<span id="cb13-5"><a href="#cb13-5" aria-hidden="true" tabindex="-1"></a>        n_startup_trials<span class="op">=</span><span class="dv">20</span>,  <span class="co"># Random trials before TPE</span></span>
<span id="cb13-6"><a href="#cb13-6" aria-hidden="true" tabindex="-1"></a>        n_ei_candidates<span class="op">=</span><span class="dv">24</span>,   <span class="co"># Candidates for EI</span></span>
<span id="cb13-7"><a href="#cb13-7" aria-hidden="true" tabindex="-1"></a>        multivariate<span class="op">=</span><span class="va">True</span>,    <span class="co"># Consider parameter interactions</span></span>
<span id="cb13-8"><a href="#cb13-8" aria-hidden="true" tabindex="-1"></a>        seed<span class="op">=</span><span class="dv">42</span>              <span class="co"># Reproducibility</span></span>
<span id="cb13-9"><a href="#cb13-9" aria-hidden="true" tabindex="-1"></a>    ),</span>
<span id="cb13-10"><a href="#cb13-10" aria-hidden="true" tabindex="-1"></a>    pruner<span class="op">=</span>optuna.pruners.MedianPruner(</span>
<span id="cb13-11"><a href="#cb13-11" aria-hidden="true" tabindex="-1"></a>        n_startup_trials<span class="op">=</span><span class="dv">10</span>,  <span class="co"># Minimum trials before pruning</span></span>
<span id="cb13-12"><a href="#cb13-12" aria-hidden="true" tabindex="-1"></a>        n_warmup_steps<span class="op">=</span><span class="dv">5</span>,     <span class="co"># Steps before considering pruning</span></span>
<span id="cb13-13"><a href="#cb13-13" aria-hidden="true" tabindex="-1"></a>        interval_steps<span class="op">=</span><span class="dv">1</span>      <span class="co"># Frequency of pruning checks</span></span>
<span id="cb13-14"><a href="#cb13-14" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb13-15"><a href="#cb13-15" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
</section>
<section id="memory-management" class="level3">
<h3 class="anchored" data-anchor-id="memory-management" id="memory-management">Memory Management</h3>
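<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb14"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb14-1"><a href="#cb14-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb14-2"><a href="#cb14-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-3"><a href="#cb14-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> memory_efficient_objective(trial):</span>
<span id="cb14-4"><a href="#cb14-4" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Release cached GPU memory left over from the previous trial</span></span>
<span id="cb14-5"><a href="#cb14-5" aria-hidden="true" tabindex="-1"></a>    torch.cuda.empty_cache()</span>
<span id="cb14-6"><a href="#cb14-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb14-7"><a href="#cb14-7" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> create_model(trial).cuda()</span>
<span id="cb14-8"><a href="#cb14-8" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Gradient checkpointing: gradient_checkpointing_enable() is provided by Hugging Face</span></span>
<span id="cb14-9"><a href="#cb14-9" aria-hidden="true" tabindex="-1"></a>    <span class="co"># transformers models; for a plain nn.Module, use torch.utils.checkpoint instead</span></span>
<span id="cb14-10"><a href="#cb14-10" aria-hidden="true" tabindex="-1"></a>    model.gradient_checkpointing_enable()</span>
<span id="cb14-11"><a href="#cb14-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb14-12"><a href="#cb14-12" aria-hidden="true" tabindex="-1"></a>    optimizer <span class="op">=</span> torch.optim.AdamW(model.parameters(), lr<span class="op">=</span>trial.suggest_float(<span class="st">'lr'</span>, <span class="fl">1e-5</span>, <span class="fl">1e-2</span>, log<span class="op">=</span><span class="va">True</span>))</span>
<span id="cb14-13"><a href="#cb14-13" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb14-14"><a href="#cb14-14" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Mixed precision (torch.amp replaces the deprecated torch.cuda.amp API)</span></span>
<span id="cb14-15"><a href="#cb14-15" aria-hidden="true" tabindex="-1"></a>    scaler <span class="op">=</span> torch.amp.GradScaler(<span class="st">'cuda'</span>)</span>
<span id="cb14-16"><a href="#cb14-16" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb14-17"><a href="#cb14-17" aria-hidden="true" tabindex="-1"></a>    <span class="co"># train_loader, criterion and evaluate are placeholders for your own code</span></span>
<span id="cb14-18"><a href="#cb14-18" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> inputs, targets <span class="kw">in</span> train_loader:</span>
<span id="cb14-19"><a href="#cb14-19" aria-hidden="true" tabindex="-1"></a>        optimizer.zero_grad()</span>
<span id="cb14-20"><a href="#cb14-20" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Autocast wraps the forward pass only; backward runs outside it</span></span>
<span id="cb14-21"><a href="#cb14-21" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> torch.autocast(device_type<span class="op">=</span><span class="st">'cuda'</span>, dtype<span class="op">=</span>torch.float16):</span>
<span id="cb14-22"><a href="#cb14-22" aria-hidden="true" tabindex="-1"></a>            loss <span class="op">=</span> criterion(model(inputs.cuda()), targets.cuda())</span>
<span id="cb14-23"><a href="#cb14-23" aria-hidden="true" tabindex="-1"></a>        scaler.scale(loss).backward()</span>
<span id="cb14-24"><a href="#cb14-24" aria-hidden="true" tabindex="-1"></a>        scaler.step(optimizer)</span>
<span id="cb14-25"><a href="#cb14-25" aria-hidden="true" tabindex="-1"></a>        scaler.update()</span>
<span id="cb14-26"><a href="#cb14-26" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb14-27"><a href="#cb14-27" aria-hidden="true" tabindex="-1"></a>    accuracy <span class="op">=</span> evaluate(model)</span>
<span id="cb14-28"><a href="#cb14-28" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb14-29"><a href="#cb14-29" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Cleanup: drop references so cached GPU memory can be released</span></span>
<span id="cb14-30"><a href="#cb14-30" aria-hidden="true" tabindex="-1"></a>    <span class="kw">del</span> model, optimizer</span>
<span id="cb14-31"><a href="#cb14-31" aria-hidden="true" tabindex="-1"></a>    torch.cuda.empty_cache()</span>
<span id="cb14-32"><a href="#cb14-32" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb14-33"><a href="#cb14-33" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> accuracy</span></code></pre></div></div>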
</section>
<section id="reproducibility" class="level3">
<h3 class="anchored" data-anchor-id="reproducibility" id="reproducibility">Reproducibility</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb15"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb15-1"><a href="#cb15-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> random</span>
<span id="cb15-2"><a href="#cb15-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-3"><a href="#cb15-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb15-4"><a href="#cb15-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb15-5"><a href="#cb15-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-6"><a href="#cb15-6" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> set_seed(seed<span class="op">=</span><span class="dv">42</span>):</span>
<span id="cb15-7"><a href="#cb15-7" aria-hidden="true" tabindex="-1"></a>    random.seed(seed)</span>
<span id="cb15-8"><a href="#cb15-8" aria-hidden="true" tabindex="-1"></a>    np.random.seed(seed)</span>
<span id="cb15-9"><a href="#cb15-9" aria-hidden="true" tabindex="-1"></a>    torch.manual_seed(seed)</span>
<span id="cb15-10"><a href="#cb15-10" aria-hidden="true" tabindex="-1"></a>    torch.cuda.manual_seed_all(seed)</span>
<span id="cb15-11"><a href="#cb15-11" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Prefer deterministic cuDNN kernels over autotuned (faster) ones</span></span>
<span id="cb15-12"><a href="#cb15-12" aria-hidden="true" tabindex="-1"></a>    torch.backends.cudnn.deterministic <span class="op">=</span> <span class="va">True</span></span>
<span id="cb15-13"><a href="#cb15-13" aria-hidden="true" tabindex="-1"></a>    torch.backends.cudnn.benchmark <span class="op">=</span> <span class="va">False</span></span>
<span id="cb15-14"><a href="#cb15-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-15"><a href="#cb15-15" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> reproducible_objective(trial):</span>
<span id="cb15-16"><a href="#cb15-16" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Use one fixed seed for all trials; treating the seed as a tunable</span></span>
<span id="cb15-17"><a href="#cb15-17" aria-hidden="true" tabindex="-1"></a>    <span class="co"># hyperparameter would let the sampler optimize for lucky random draws</span></span>
<span id="cb15-18"><a href="#cb15-18" aria-hidden="true" tabindex="-1"></a>    set_seed(<span class="dv">42</span>)</span>
<span id="cb15-19"><a href="#cb15-19" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-20"><a href="#cb15-20" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Rest of objective function</span></span>
<span id="cb15-21"><a href="#cb15-21" aria-hidden="true" tabindex="-1"></a>    <span class="cf">pass</span></span></code></pre></div></div>
</section>
</section>
<section id="real-world-applications" class="level2">
<h2 class="anchored" data-anchor-id="real-world-applications" id="real-world-applications">Real-World Applications</h2>
<section id="computer-vision-nas" class="level3">
<h3 class="anchored" data-anchor-id="computer-vision-nas" id="computer-vision-nas">Computer Vision NAS</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb16"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb16-1"><a href="#cb16-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> vision_nas_objective(trial):</span>
<span id="cb16-2"><a href="#cb16-2" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Data augmentation search</span></span>
<span id="cb16-3"><a href="#cb16-3" aria-hidden="true" tabindex="-1"></a>    augmentation_policy <span class="op">=</span> {</span>
<span id="cb16-4"><a href="#cb16-4" aria-hidden="true" tabindex="-1"></a>        <span class="st">'rotation'</span>: trial.suggest_float(<span class="st">'rotation'</span>, <span class="dv">0</span>, <span class="dv">30</span>),</span>
<span id="cb16-5"><a href="#cb16-5" aria-hidden="true" tabindex="-1"></a>        <span class="st">'brightness'</span>: trial.suggest_float(<span class="st">'brightness'</span>, <span class="fl">0.8</span>, <span class="fl">1.2</span>),</span>
<span id="cb16-6"><a href="#cb16-6" aria-hidden="true" tabindex="-1"></a>        <span class="st">'contrast'</span>: trial.suggest_float(<span class="st">'contrast'</span>, <span class="fl">0.8</span>, <span class="fl">1.2</span>),</span>
<span id="cb16-7"><a href="#cb16-7" aria-hidden="true" tabindex="-1"></a>        <span class="st">'saturation'</span>: trial.suggest_float(<span class="st">'saturation'</span>, <span class="fl">0.8</span>, <span class="fl">1.2</span>),</span>
<span id="cb16-8"><a href="#cb16-8" aria-hidden="true" tabindex="-1"></a>        <span class="st">'hue'</span>: trial.suggest_float(<span class="st">'hue'</span>, <span class="op">-</span><span class="fl">0.1</span>, <span class="fl">0.1</span>)</span>
<span id="cb16-9"><a href="#cb16-9" aria-hidden="true" tabindex="-1"></a>    }</span>
<span id="cb16-10"><a href="#cb16-10" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-11"><a href="#cb16-11" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Architecture search</span></span>
<span id="cb16-12"><a href="#cb16-12" aria-hidden="true" tabindex="-1"></a>    backbone <span class="op">=</span> trial.suggest_categorical(<span class="st">'backbone'</span>, [<span class="st">'resnet'</span>, <span class="st">'efficientnet'</span>, <span class="st">'mobilenet'</span>])</span>
<span id="cb16-13"><a href="#cb16-13" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-14"><a href="#cb16-14" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> backbone <span class="op">==</span> <span class="st">'resnet'</span>:</span>
<span id="cb16-15"><a href="#cb16-15" aria-hidden="true" tabindex="-1"></a>        depth <span class="op">=</span> trial.suggest_categorical(<span class="st">'depth'</span>, [<span class="dv">18</span>, <span class="dv">34</span>, <span class="dv">50</span>, <span class="dv">101</span>])</span>
<span id="cb16-16"><a href="#cb16-16" aria-hidden="true" tabindex="-1"></a>        model <span class="op">=</span> create_resnet(depth)</span>
<span id="cb16-17"><a href="#cb16-17" aria-hidden="true" tabindex="-1"></a>    <span class="cf">elif</span> backbone <span class="op">==</span> <span class="st">'efficientnet'</span>:</span>
<span id="cb16-18"><a href="#cb16-18" aria-hidden="true" tabindex="-1"></a>        version <span class="op">=</span> trial.suggest_categorical(<span class="st">'version'</span>, [<span class="st">'b0'</span>, <span class="st">'b1'</span>, <span class="st">'b2'</span>, <span class="st">'b3'</span>])</span>
<span id="cb16-19"><a href="#cb16-19" aria-hidden="true" tabindex="-1"></a>        model <span class="op">=</span> create_efficientnet(version)</span>
<span id="cb16-20"><a href="#cb16-20" aria-hidden="true" tabindex="-1"></a>    <span class="cf">else</span>:</span>
<span id="cb16-21"><a href="#cb16-21" aria-hidden="true" tabindex="-1"></a>        width_mult <span class="op">=</span> trial.suggest_float(<span class="st">'width_mult'</span>, <span class="fl">0.25</span>, <span class="fl">2.0</span>)</span>
<span id="cb16-22"><a href="#cb16-22" aria-hidden="true" tabindex="-1"></a>        model <span class="op">=</span> create_mobilenet(width_mult)</span>
<span id="cb16-23"><a href="#cb16-23" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-24"><a href="#cb16-24" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Training strategy</span></span>
<span id="cb16-25"><a href="#cb16-25" aria-hidden="true" tabindex="-1"></a>    training_strategy <span class="op">=</span> {</span>
<span id="cb16-26"><a href="#cb16-26" aria-hidden="true" tabindex="-1"></a>        <span class="st">'optimizer'</span>: trial.suggest_categorical(<span class="st">'optimizer'</span>, [<span class="st">'adam'</span>, <span class="st">'sgd'</span>, <span class="st">'adamw'</span>]),</span>
<span id="cb16-27"><a href="#cb16-27" aria-hidden="true" tabindex="-1"></a>        <span class="st">'lr_schedule'</span>: trial.suggest_categorical(<span class="st">'lr_schedule'</span>, [<span class="st">'cosine'</span>, <span class="st">'step'</span>, <span class="st">'exponential'</span>]),</span>
<span id="cb16-28"><a href="#cb16-28" aria-hidden="true" tabindex="-1"></a>        <span class="st">'weight_decay'</span>: trial.suggest_float(<span class="st">'weight_decay'</span>, <span class="fl">1e-6</span>, <span class="fl">1e-2</span>, log<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb16-29"><a href="#cb16-29" aria-hidden="true" tabindex="-1"></a>    }</span>
<span id="cb16-30"><a href="#cb16-30" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-31"><a href="#cb16-31" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> train_vision_model(model, augmentation_policy, training_strategy)</span></code></pre></div></div>
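<p>Optuna builds the search space while the objective runs (define-by-run), so parameters such as <code>depth</code>, <code>version</code> and <code>width_mult</code> exist only in trials that take the matching branch. The sketch below makes that visible without Optuna installed. <code>RandomTrial</code> is a hypothetical stand-in that mimics the method names <code>suggest_categorical</code> and <code>suggest_float</code> and samples at random. <code>backbone_space</code> repeats only the backbone branching of <code>vision_nas_objective</code>.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import random

class RandomTrial:
    """Hypothetical stand-in for optuna.trial.Trial (not part of Optuna):
    samples uniformly at random and records each suggested value in params."""

    def __init__(self, seed=None):
        self.rng = random.Random(seed)
        self.params = {}

    def suggest_categorical(self, name, choices):
        self.params[name] = self.rng.choice(choices)
        return self.params[name]

    def suggest_float(self, name, low, high):
        self.params[name] = self.rng.uniform(low, high)
        return self.params[name]

def backbone_space(trial):
    # Same branching as vision_nas_objective: each backbone has its own knob
    backbone = trial.suggest_categorical('backbone', ['resnet', 'efficientnet', 'mobilenet'])
    if backbone == 'resnet':
        trial.suggest_categorical('depth', [18, 34, 50, 101])
    elif backbone == 'efficientnet':
        trial.suggest_categorical('version', ['b0', 'b1', 'b2', 'b3'])
    else:
        trial.suggest_float('width_mult', 0.25, 2.0)
    return trial.params

for seed in range(3):
    print(backbone_space(RandomTrial(seed)))
</code></pre></div></div>
<p>Each printed dictionary holds <code>backbone</code> plus exactly one backbone-specific key. A real Optuna study records parameters the same way: <code>trial.params</code> contains only the parameters that were actually suggested in that trial.</p>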
</section>
<section id="nlp-architecture-search" class="level3">
<h3 class="anchored" data-anchor-id="nlp-architecture-search" id="nlp-architecture-search">NLP Architecture Search</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb17"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb17-1"><a href="#cb17-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> nlp_nas_objective(trial):</span>
<span id="cb17-2"><a href="#cb17-2" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Transformer architecture search</span></span>
<span id="cb17-3"><a href="#cb17-3" aria-hidden="true" tabindex="-1"></a>    config <span class="op">=</span> {</span>
<span id="cb17-4"><a href="#cb17-4" aria-hidden="true" tabindex="-1"></a>        <span class="st">'num_layers'</span>: trial.suggest_int(<span class="st">'num_layers'</span>, <span class="dv">4</span>, <span class="dv">12</span>),</span>
<span id="cb17-5"><a href="#cb17-5" aria-hidden="true" tabindex="-1"></a>        <span class="st">'num_heads'</span>: trial.suggest_categorical(<span class="st">'num_heads'</span>, [<span class="dv">4</span>, <span class="dv">8</span>, <span class="dv">16</span>]),  <span class="co"># each divides every hidden_size below</span></span>
<span id="cb17-6"><a href="#cb17-6" aria-hidden="true" tabindex="-1"></a>        <span class="st">'hidden_size'</span>: trial.suggest_categorical(<span class="st">'hidden_size'</span>, [<span class="dv">256</span>, <span class="dv">512</span>, <span class="dv">768</span>, <span class="dv">1024</span>]),</span>
<span id="cb17-7"><a href="#cb17-7" aria-hidden="true" tabindex="-1"></a>        <span class="st">'ffn_size'</span>: trial.suggest_categorical(<span class="st">'ffn_size'</span>, [<span class="dv">1024</span>, <span class="dv">2048</span>, <span class="dv">3072</span>, <span class="dv">4096</span>]),</span>
<span id="cb17-8"><a href="#cb17-8" aria-hidden="true" tabindex="-1"></a>        <span class="st">'dropout'</span>: trial.suggest_float(<span class="st">'dropout'</span>, <span class="fl">0.0</span>, <span class="fl">0.3</span>),</span>
<span id="cb17-9"><a href="#cb17-9" aria-hidden="true" tabindex="-1"></a>        <span class="st">'attention_dropout'</span>: trial.suggest_float(<span class="st">'attention_dropout'</span>, <span class="fl">0.0</span>, <span class="fl">0.3</span>)</span>
<span id="cb17-10"><a href="#cb17-10" aria-hidden="true" tabindex="-1"></a>    }</span>
<span id="cb17-11"><a href="#cb17-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb17-12"><a href="#cb17-12" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Positional encoding</span></span>
<span id="cb17-13"><a href="#cb17-13" aria-hidden="true" tabindex="-1"></a>    pos_encoding <span class="op">=</span> trial.suggest_categorical(<span class="st">'pos_encoding'</span>, [<span class="st">'learned'</span>, <span class="st">'sinusoidal'</span>, <span class="st">'rotary'</span>])</span>
<span id="cb17-14"><a href="#cb17-14" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb17-15"><a href="#cb17-15" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Activation function</span></span>
<span id="cb17-16"><a href="#cb17-16" aria-hidden="true" tabindex="-1"></a>    activation <span class="op">=</span> trial.suggest_categorical(<span class="st">'activation'</span>, [<span class="st">'gelu'</span>, <span class="st">'relu'</span>, <span class="st">'swish'</span>])</span>
<span id="cb17-17"><a href="#cb17-17" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb17-18"><a href="#cb17-18" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> create_transformer(config, pos_encoding, activation)</span>
<span id="cb17-19"><a href="#cb17-19" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb17-20"><a href="#cb17-20" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Training hyperparameters</span></span>
<span id="cb17-21"><a href="#cb17-21" aria-hidden="true" tabindex="-1"></a>    lr <span class="op">=</span> trial.suggest_float(<span class="st">'lr'</span>, <span class="fl">1e-5</span>, <span class="fl">1e-3</span>, log<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb17-22"><a href="#cb17-22" aria-hidden="true" tabindex="-1"></a>    warmup_steps <span class="op">=</span> trial.suggest_int(<span class="st">'warmup_steps'</span>, <span class="dv">1000</span>, <span class="dv">10000</span>)</span>
<span id="cb17-23"><a href="#cb17-23" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb17-24"><a href="#cb17-24" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> train_nlp_model(model, lr, warmup_steps)</span></code></pre></div></div>
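<p>Standard multi-head attention (for example PyTorch's <code>nn.MultiheadAttention</code>) splits the hidden size evenly across heads, so <code>hidden_size</code> must be divisible by <code>num_heads</code>. For instance, 12 heads do not divide 256, 512 or 1024. The small check below lists invalid combinations for two candidate grids. It assumes <code>create_transformer</code> uses this standard split.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">def invalid_head_pairs(hidden_sizes, head_counts):
    """(hidden_size, num_heads) pairs whose hidden size does not split evenly across heads."""
    return [(h, n) for h in hidden_sizes for n in head_counts if h % n]

hidden_sizes = [256, 512, 768, 1024]
print(invalid_head_pairs(hidden_sizes, [4, 8, 12, 16]))  # [(256, 12), (512, 12), (1024, 12)]
print(invalid_head_pairs(hidden_sizes, [4, 8, 16]))      # []
</code></pre></div></div>
<p>One alternative that keeps every combination valid is to sample a per-head dimension and set <code>hidden_size = num_heads * head_dim</code>. Another is to skip an invalid pair inside the objective by raising <code>optuna.TrialPruned()</code>.</p>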
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>Optuna provides a flexible framework for hyperparameter optimization and neural architecture search in deep learning. Its define-by-run search spaces, samplers such as TPE, pruning of unpromising trials, and integrations with common ML libraries make it a practical tool for tuning deep learning models.</p>
<p>Key takeaways:</p>
<ol type="1">
<li><strong>Start Simple</strong>: Begin with basic hyperparameter optimization before moving to complex NAS</li>
<li><strong>Use Pruning</strong>: Implement pruning to save computational resources</li>
<li><strong>Leverage Distributed Computing</strong>: Scale optimization across multiple GPUs/nodes</li>
<li><strong>Monitor Progress</strong>: Use visualization tools to understand optimization dynamics</li>
<li><strong>Consider Multi-Objective</strong>: Balance multiple criteria like accuracy and efficiency</li>
<li><strong>Reproducibility</strong>: Set seeds and use consistent evaluation protocols</li>
</ol>
<p>Intelligent optimization frameworks like Optuna are likely to play a central role in automated ML, because they democratize access to state-of-the-art hyperparameter tuning and architecture search techniques. By mastering these tools, practitioners can focus on higher-level design decisions while the search algorithms handle the tedious work of parameter tuning.</p>
<p>Whether you’re working on computer vision, NLP, or other domains, Optuna’s flexibility and power make it an invaluable addition to your deep learning toolkit. Start with the basic examples provided here, then gradually incorporate more advanced techniques as your optimization needs grow in complexity.</p>


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Complete Guide to Python’s functools Module]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/python/python-functools/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/python/python-functools/</guid>
      <pubDate>Sun, 06 Jul 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>tutorial</category>
      <category>beginner</category>
      <content:encoded><![CDATA[






<section id="complete-guide-to-pythons-functools-module" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/python/python-functools/func.png" class="img-fluid"></p>
<p>The <code>functools</code> module in Python provides utilities for working with higher-order functions and operations on callable objects. It’s a powerful toolkit for functional programming patterns, performance optimization, and code organization.</p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>The <code>functools</code> module is part of Python’s standard library, so it requires no installation. Its utilities for wrapping, caching, and partially applying functions help you write code that is more efficient, reusable, and maintainable. It is particularly useful for:</p>
<ul>
<li>Creating decorators</li>
<li>Implementing caching mechanisms</li>
<li>Partial function application</li>
<li>Functional programming patterns</li>
<li>Performance optimization</li>
</ul>
<div id="91d580d4" class="cell" data-execution_count="1">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb1"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> functools</span></code></pre></div></div>
</div>
</section>
<section id="core-decorators" class="level2">
<h2 class="anchored" data-anchor-id="core-decorators" id="core-decorators">Core Decorators</h2>
<section id="functools.wraps" class="level3">
<h3 class="anchored" data-anchor-id="functools.wraps" id="functools.wraps"><span class="citation" data-cites="functools.wraps">@functools.wraps</span></h3>
<p>The <code>@functools.wraps</code> decorator is fundamental for writing well-behaved decorators. It copies metadata from the original function onto the wrapper, including <code>__name__</code>, <code>__qualname__</code>, <code>__doc__</code>, <code>__module__</code>, and <code>__annotations__</code>, and it adds a <code>__wrapped__</code> attribute that points back to the original function.</p>
<div id="465e879d" class="cell" data-execution_count="2">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> functools</span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> my_decorator(func):</span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a>    <span class="at">@functools.wraps</span>(func)</span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> wrapper(<span class="op">*</span>args, <span class="op">**</span>kwargs):</span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Calling </span><span class="sc">{</span>func<span class="sc">.</span><span class="va">__name__</span><span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> func(<span class="op">*</span>args, <span class="op">**</span>kwargs)</span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> wrapper</span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-10"><a href="#cb2-10" aria-hidden="true" tabindex="-1"></a><span class="at">@my_decorator</span></span>
<span id="cb2-11"><a href="#cb2-11" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> greet(name):</span>
<span id="cb2-12"><a href="#cb2-12" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Greet someone by name."""</span></span>
<span id="cb2-13"><a href="#cb2-13" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> <span class="ss">f"Hello, </span><span class="sc">{</span>name<span class="sc">}</span><span class="ss">!"</span></span>
<span id="cb2-14"><a href="#cb2-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-15"><a href="#cb2-15" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(greet.<span class="va">__name__</span>)  <span class="co"># Output: greet</span></span>
<span id="cb2-16"><a href="#cb2-16" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(greet.__doc__)   <span class="co"># Output: Greet someone by name.</span></span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>greet
Greet someone by name.</code></pre>
</div>
</div>
<p>Without <code>@functools.wraps</code>, the wrapper function would lose the original function’s metadata:</p>
<div id="598af1a4" class="cell" data-execution_count="3">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> bad_decorator(func):</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> wrapper(<span class="op">*</span>args, <span class="op">**</span>kwargs):</span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="st">"Calling function"</span>)</span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> func(<span class="op">*</span>args, <span class="op">**</span>kwargs)</span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> wrapper</span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a><span class="at">@bad_decorator</span></span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> say_hello(name):</span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Say hello to someone."""</span></span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> <span class="ss">f"Hello, </span><span class="sc">{</span>name<span class="sc">}</span><span class="ss">!"</span></span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(say_hello.<span class="va">__name__</span>)  <span class="co"># Output: wrapper (not say_hello!)</span></span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(say_hello.__doc__)   <span class="co"># Output: None</span></span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>wrapper
None</code></pre>
</div>
</div>
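<p>When a decorator is written as a class, there is no inner <code>def</code> to put <code>@functools.wraps</code> on. Since <code>@functools.wraps</code> is a convenience shorthand for <code>functools.update_wrapper</code>, you can call <code>update_wrapper</code> directly instead. The sketch below uses a hypothetical <code>CountCalls</code> decorator (not part of the original text) and also shows the <code>__wrapped__</code> attribute, which gives access to the undecorated function.</p>

```python
import functools


class CountCalls:
    """Hypothetical class-based decorator that counts calls to the wrapped function."""

    def __init__(self, func):
        self.func = func
        self.calls = 0
        # Copy __name__, __doc__, __module__, etc. onto this instance and set __wrapped__
        functools.update_wrapper(self, func)

    def __call__(self, *args, **kwargs):
        self.calls += 1
        return self.func(*args, **kwargs)


@CountCalls
def add(a, b):
    """Return the sum of a and b."""
    return a + b


print(add(2, 3))       # 5
print(add.__name__)    # add
print(add.__doc__)     # Return the sum of a and b.
print(add.calls)       # 1
# __wrapped__ calls the original function directly, bypassing the counter
print(add.__wrapped__(2, 3), add.calls)  # 5 1
```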
</section>
<section id="functools.lru_cache" class="level3">
<h3 class="anchored" data-anchor-id="functools.lru_cache" id="functools.lru_cache"><span class="citation" data-cites="functools.lru_cache">@functools.lru_cache</span></h3>
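<p>The <code>@functools.lru_cache</code> decorator memoizes function results in a Least Recently Used (LRU) cache: once <code>maxsize</code> results are stored, the entry that was used least recently is discarded to make room. It is well suited to recursive functions and to expensive computations that always return the same result for the same arguments.</p>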
<div id="dfa0f790" class="cell" data-execution_count="4">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> functools</span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a><span class="at">@functools.lru_cache</span>(maxsize<span class="op">=</span><span class="dv">128</span>)</span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> fibonacci(n):</span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Calculate Fibonacci number with memoization."""</span></span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> n <span class="op">&lt;</span> <span class="dv">2</span>:</span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> n</span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> fibonacci(n <span class="op">-</span> <span class="dv">1</span>) <span class="op">+</span> fibonacci(n <span class="op">-</span> <span class="dv">2</span>)</span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Performance comparison</span></span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> fibonacci_slow(n):</span>
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Fibonacci without caching."""</span></span>
<span id="cb6-15"><a href="#cb6-15" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> n <span class="op">&lt;</span> <span class="dv">2</span>:</span>
<span id="cb6-16"><a href="#cb6-16" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> n</span>
<span id="cb6-17"><a href="#cb6-17" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> fibonacci_slow(n <span class="op">-</span> <span class="dv">1</span>) <span class="op">+</span> fibonacci_slow(n <span class="op">-</span> <span class="dv">2</span>)</span>
<span id="cb6-18"><a href="#cb6-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-19"><a href="#cb6-19" aria-hidden="true" tabindex="-1"></a><span class="co"># Cached version: start from an empty cache so the timing is fair</span></span>
<span id="cb6-20"><a href="#cb6-20" aria-hidden="true" tabindex="-1"></a>fibonacci.cache_clear()</span>
<span id="cb6-21"><a href="#cb6-21" aria-hidden="true" tabindex="-1"></a>start <span class="op">=</span> time.perf_counter()</span>
<span id="cb6-22"><a href="#cb6-22" aria-hidden="true" tabindex="-1"></a>result_fast <span class="op">=</span> fibonacci(<span class="dv">35</span>)</span>
<span id="cb6-23"><a href="#cb6-23" aria-hidden="true" tabindex="-1"></a>fast_time <span class="op">=</span> time.perf_counter() <span class="op">-</span> start</span>
<span id="cb6-24"><a href="#cb6-24" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-25"><a href="#cb6-25" aria-hidden="true" tabindex="-1"></a><span class="co"># Uncached version for comparison</span></span>
<span id="cb6-26"><a href="#cb6-26" aria-hidden="true" tabindex="-1"></a>start <span class="op">=</span> time.perf_counter()</span>
<span id="cb6-27"><a href="#cb6-27" aria-hidden="true" tabindex="-1"></a>result_slow <span class="op">=</span> fibonacci_slow(<span class="dv">35</span>)</span>
<span id="cb6-28"><a href="#cb6-28" aria-hidden="true" tabindex="-1"></a>slow_time <span class="op">=</span> time.perf_counter() <span class="op">-</span> start</span>
<span id="cb6-29"><a href="#cb6-29" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-30"><a href="#cb6-30" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Cached result: </span><span class="sc">{</span>result_fast<span class="sc">}</span><span class="ss"> (Time: </span><span class="sc">{</span>fast_time<span class="sc">:.4f}</span><span class="ss">s)"</span>)</span>
<span id="cb6-31"><a href="#cb6-31" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Uncached result: </span><span class="sc">{</span>result_slow<span class="sc">}</span><span class="ss"> (Time: </span><span class="sc">{</span>slow_time<span class="sc">:.4f}</span><span class="ss">s)"</span>)</span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>Cached result: 9227465 (Time: 0.0000s)
Uncached result: 9227465 (Time: 0.6688s)</code></pre>
</div>
</div>
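<p>Because <code>lru_cache</code> uses the call’s arguments as the cache key, every argument must be hashable: passing a list raises <code>TypeError</code>, so convert such inputs to tuples first. The optional <code>typed=True</code> flag stores arguments of different types, such as <code>3</code> and <code>3.0</code>, as separate entries. The <code>total</code> and <code>describe</code> functions below are illustrative only.</p>

```python
import functools


@functools.lru_cache(maxsize=None)
def total(values):
    """Sum a sequence; the whole argument becomes part of the cache key."""
    return sum(values)


print(total((1, 2, 3)))  # 6 (tuples are hashable, so the result is cached)

try:
    total([1, 2, 3])  # lists are unhashable, so they cannot be cache keys
except TypeError as exc:
    print(f"TypeError: {exc}")


# typed=True keeps arguments of different types in separate cache entries
@functools.lru_cache(maxsize=None, typed=True)
def describe(x):
    return f"{x!r} ({type(x).__name__})"


describe(3)
describe(3.0)
print(describe.cache_info().currsize)  # 2: one entry for int 3, one for float 3.0
```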
<section id="cache-management" class="level4">
<h4 class="anchored" data-anchor-id="cache-management">Cache Management</h4>
<p>Functions decorated with <code>lru_cache</code> expose <code>cache_info()</code> for hit and miss statistics and <code>cache_clear()</code> for emptying the cache:</p>
<div id="5f7611b3" class="cell" data-execution_count="5">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb8"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="at">@functools.lru_cache</span>(maxsize<span class="op">=</span><span class="dv">128</span>)</span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> expensive_function(x, y):</span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Simulate an expensive computation."""</span></span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a>    time.sleep(<span class="fl">0.1</span>)  <span class="co"># Simulate work</span></span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> x <span class="op">*</span> y <span class="op">+</span> x <span class="op">**</span> y</span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Use the function</span></span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a>result1 <span class="op">=</span> expensive_function(<span class="dv">2</span>, <span class="dv">3</span>)</span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a>result2 <span class="op">=</span> expensive_function(<span class="dv">2</span>, <span class="dv">3</span>)  <span class="co"># Same arguments: returned from the cache</span></span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a><span class="co"># Check cache statistics</span></span>
<span id="cb8-12"><a href="#cb8-12" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(expensive_function.cache_info())</span>
<span id="cb8-13"><a href="#cb8-13" aria-hidden="true" tabindex="-1"></a><span class="co"># Output: CacheInfo(hits=1, misses=1, maxsize=128, currsize=1)</span></span>
<span id="cb8-14"><a href="#cb8-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-15"><a href="#cb8-15" aria-hidden="true" tabindex="-1"></a><span class="co"># Clear the cache</span></span>
<span id="cb8-16"><a href="#cb8-16" aria-hidden="true" tabindex="-1"></a>expensive_function.cache_clear()</span>
<span id="cb8-17"><a href="#cb8-17" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(expensive_function.cache_info())</span>
<span id="cb8-18"><a href="#cb8-18" aria-hidden="true" tabindex="-1"></a><span class="co"># Output: CacheInfo(hits=0, misses=0, maxsize=128, currsize=0)</span></span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>CacheInfo(hits=1, misses=1, maxsize=128, currsize=1)
CacheInfo(hits=0, misses=0, maxsize=128, currsize=0)</code></pre>
</div>
</div>
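<p>To see the least-recently-used policy in action, the following sketch shrinks the cache to two entries and records which calls actually execute the function body. The <code>square</code> function and the <code>calls</code> list are illustrative.</p>

```python
import functools

calls = []  # records every real (uncached) execution


@functools.lru_cache(maxsize=2)
def square(n):
    calls.append(n)
    return n * n


square(1)  # miss: cache holds 1
square(2)  # miss: cache holds 1, 2
square(1)  # hit: 1 becomes the most recently used entry
square(3)  # miss: cache is full, so the least recently used entry (2) is evicted
square(1)  # hit: 1 survived the eviction
square(2)  # miss: 2 was evicted and must be recomputed

print(calls)                # [1, 2, 3, 2]
print(square.cache_info())  # CacheInfo(hits=2, misses=4, maxsize=2, currsize=2)
```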
</section>
</section>
<section id="functools.cache-python-3.9" class="level3">
<h3 class="anchored" data-anchor-id="functools.cache-python-3.9" id="functools.cache-python-3.9"><span class="citation" data-cites="functools.cache">@functools.cache</span> (Python 3.9+)</h3>
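<p>The <code>@functools.cache</code> decorator is equivalent to <code>lru_cache(maxsize=None)</code>: a simpler cache with no size limit. Because it never evicts entries, it is smaller and faster than a size-limited <code>lru_cache</code>, but the cache can grow without bound, so it suits functions called with a limited set of distinct arguments:</p>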
<div id="25b86b63" class="cell" data-execution_count="6">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb10"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> functools</span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a><span class="at">@functools.cache</span></span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> factorial(n):</span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Calculate factorial with unlimited caching."""</span></span>
<span id="cb10-6"><a href="#cb10-6" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> n <span class="op">&lt;=</span> <span class="dv">1</span>:</span>
<span id="cb10-7"><a href="#cb10-7" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="dv">1</span></span>
<span id="cb10-8"><a href="#cb10-8" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> n <span class="op">*</span> factorial(n <span class="op">-</span> <span class="dv">1</span>)</span>
<span id="cb10-9"><a href="#cb10-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-10"><a href="#cb10-10" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(factorial(<span class="dv">10</span>))  <span class="co"># 3628800</span></span>
<span id="cb10-11"><a href="#cb10-11" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(factorial.cache_info())</span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>3628800
CacheInfo(hits=0, misses=10, maxsize=None, currsize=10)</code></pre>
</div>
</div>
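<p>The cache persists between calls, so later calls reuse earlier work. This sketch re-declares the <code>factorial</code> function from above so it runs on its own: computing <code>factorial(12)</code> after <code>factorial(10)</code> only evaluates the two new values before hitting the cached <code>factorial(10)</code>.</p>

```python
import functools


@functools.cache
def factorial(n):
    """Same memoized factorial as in the example above."""
    if n <= 1:
        return 1
    return n * factorial(n - 1)


factorial(10)                  # fills the cache with factorial(1) ... factorial(10)
print(factorial.cache_info())  # CacheInfo(hits=0, misses=10, maxsize=None, currsize=10)

factorial(12)                  # computes 12 and 11, then hits the cached factorial(10)
print(factorial.cache_info())  # CacheInfo(hits=1, misses=12, maxsize=None, currsize=12)
```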
</section>
<section id="functools.cached_property" class="level3">
<h3 class="anchored" data-anchor-id="functools.cached_property" id="functools.cached_property"><span class="citation" data-cites="functools.cached_property">@functools.cached_property</span></h3>
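<p>Transforms a method into a property whose value is computed once, on first access, and then stored as an ordinary attribute on the instance (available since Python 3.8). Later accesses return the stored value without calling the method again.</p>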
<div id="8ef02842" class="cell" data-execution_count="7">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb12"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><a href="#cb12-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> functools</span>
<span id="cb12-2"><a href="#cb12-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb12-3"><a href="#cb12-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-4"><a href="#cb12-4" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> DataProcessor:</span>
<span id="cb12-5"><a href="#cb12-5" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, data):</span>
<span id="cb12-6"><a href="#cb12-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.data <span class="op">=</span> data</span>
<span id="cb12-7"><a href="#cb12-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-8"><a href="#cb12-8" aria-hidden="true" tabindex="-1"></a>    <span class="at">@functools.cached_property</span></span>
<span id="cb12-9"><a href="#cb12-9" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> processed_data(<span class="va">self</span>):</span>
<span id="cb12-10"><a href="#cb12-10" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Expensive data processing that should only run once"""</span></span>
<span id="cb12-11"><a href="#cb12-11" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="st">"Processing data..."</span>)</span>
<span id="cb12-12"><a href="#cb12-12" aria-hidden="true" tabindex="-1"></a>        time.sleep(<span class="dv">1</span>)  <span class="co"># Simulate expensive operation</span></span>
<span id="cb12-13"><a href="#cb12-13" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> [x <span class="op">*</span> <span class="dv">2</span> <span class="cf">for</span> x <span class="kw">in</span> <span class="va">self</span>.data]</span>
<span id="cb12-14"><a href="#cb12-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-15"><a href="#cb12-15" aria-hidden="true" tabindex="-1"></a>processor <span class="op">=</span> DataProcessor([<span class="dv">1</span>, <span class="dv">2</span>, <span class="dv">3</span>, <span class="dv">4</span>, <span class="dv">5</span>])</span>
<span id="cb12-16"><a href="#cb12-16" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(processor.processed_data)  <span class="co"># Takes 1 second</span></span>
<span id="cb12-17"><a href="#cb12-17" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(processor.processed_data)  <span class="co"># Instant, uses cached result</span></span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>Processing data...
[2, 4, 6, 8, 10]
[2, 4, 6, 8, 10]</code></pre>
</div>
</div>
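<p><code>cached_property</code> stores the computed value in the instance’s <code>__dict__</code> under the property’s name, so the class must provide a <code>__dict__</code> (a class that defines only <code>__slots__</code> will not work). The cached value does not update when the underlying data changes; deleting the attribute clears it, and the next access recomputes it. The <code>Report</code> class below is a hypothetical example.</p>

```python
import functools


class Report:
    """Hypothetical class used to show how cached_property can be reset."""

    def __init__(self, rows):
        self.rows = rows
        self.computations = 0  # counts how often the property body actually runs

    @functools.cached_property
    def total(self):
        self.computations += 1
        return sum(self.rows)


report = Report([1, 2, 3])
print(report.total)                # 6, computed and stored in report.__dict__
print("total" in report.__dict__)  # True

report.rows.append(4)
print(report.total)                # still 6: the cached value does not track changes to rows

del report.total                   # remove the cached value from the instance
print(report.total)                # 10, recomputed on the next access
print(report.computations)         # 2
```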
</section>
</section>
<section id="partial-function-application" class="level2">
<h2 class="anchored" data-anchor-id="partial-function-application" id="partial-function-application">Partial Function Application</h2>
<section id="functools.partial" class="level3">
<h3 class="anchored" data-anchor-id="functools.partial" id="functools.partial">functools.partial</h3>
<p><code>functools.partial</code> returns a new callable with some of a function’s positional or keyword arguments already fixed, so callers only need to supply the remaining ones.</p>
<div id="b5e4bbf2" class="cell" data-execution_count="8">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb14"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb14-1"><a href="#cb14-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> functools</span>
<span id="cb14-2"><a href="#cb14-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-3"><a href="#cb14-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> multiply(x, y, z):</span>
<span id="cb14-4"><a href="#cb14-4" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Multiply three numbers."""</span></span>
<span id="cb14-5"><a href="#cb14-5" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> x <span class="op">*</span> y <span class="op">*</span> z</span>
<span id="cb14-6"><a href="#cb14-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-7"><a href="#cb14-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Create a partial function that always multiplies by 2 and 3</span></span>
<span id="cb14-8"><a href="#cb14-8" aria-hidden="true" tabindex="-1"></a>double_triple <span class="op">=</span> functools.partial(multiply, <span class="dv">2</span>, <span class="dv">3</span>)</span>
<span id="cb14-9"><a href="#cb14-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-10"><a href="#cb14-10" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(double_triple(<span class="dv">4</span>))  <span class="co"># Output: 24 (2 * 3 * 4)</span></span>
<span id="cb14-11"><a href="#cb14-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-12"><a href="#cb14-12" aria-hidden="true" tabindex="-1"></a><span class="co"># You can also fix keyword arguments</span></span>
<span id="cb14-13"><a href="#cb14-13" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> greet(greeting, name, punctuation<span class="op">=</span><span class="st">"!"</span>):</span>
<span id="cb14-14"><a href="#cb14-14" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> <span class="ss">f"</span><span class="sc">{</span>greeting<span class="sc">}</span><span class="ss">, </span><span class="sc">{</span>name<span class="sc">}{</span>punctuation<span class="sc">}</span><span class="ss">"</span></span>
<span id="cb14-15"><a href="#cb14-15" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-16"><a href="#cb14-16" aria-hidden="true" tabindex="-1"></a><span class="co"># Create a partial for casual greetings</span></span>
<span id="cb14-17"><a href="#cb14-17" aria-hidden="true" tabindex="-1"></a>casual_greet <span class="op">=</span> functools.partial(greet, <span class="st">"Hey"</span>, punctuation<span class="op">=</span><span class="st">"."</span>)</span>
<span id="cb14-18"><a href="#cb14-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-19"><a href="#cb14-19" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(casual_greet(<span class="st">"Alice"</span>))  <span class="co"># Output: Hey, Alice.</span></span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>24
Hey, Alice.</code></pre>
</div>
</div>
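<p>A <code>partial</code> object exposes the wrapped function and the frozen arguments through its <code>func</code>, <code>args</code>, and <code>keywords</code> attributes, which is handy when debugging. Frozen keyword arguments behave like defaults: a keyword passed at call time overrides them. The <code>power</code> function is illustrative.</p>

```python
import functools


def power(base, exponent):
    return base ** exponent


square = functools.partial(power, exponent=2)

# A partial object records what it wraps and which arguments are frozen
print(square.func is power)  # True
print(square.args)           # ()
print(square.keywords)       # {'exponent': 2}

print(square(5))             # 25
# Frozen keyword arguments can still be overridden at call time
print(square(5, exponent=3)) # 125
```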
</section>
<section id="practical-example-event-handling" class="level3">
<h3 class="anchored" data-anchor-id="practical-example-event-handling" id="practical-example-event-handling">Practical Example: Event Handling</h3>
<div id="e10771a6" class="cell" data-execution_count="9">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb16"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb16-1"><a href="#cb16-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> functools</span>
<span id="cb16-2"><a href="#cb16-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-3"><a href="#cb16-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> handle_event(event_type, handler_name, data):</span>
<span id="cb16-4"><a href="#cb16-4" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Generic event handler."""</span></span>
<span id="cb16-5"><a href="#cb16-5" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"[</span><span class="sc">{</span>event_type<span class="sc">}</span><span class="ss">] </span><span class="sc">{</span>handler_name<span class="sc">}</span><span class="ss">: </span><span class="sc">{</span>data<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb16-6"><a href="#cb16-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-7"><a href="#cb16-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Create specific event handlers</span></span>
<span id="cb16-8"><a href="#cb16-8" aria-hidden="true" tabindex="-1"></a>handle_click <span class="op">=</span> functools.partial(handle_event, <span class="st">"CLICK"</span>)</span>
<span id="cb16-9"><a href="#cb16-9" aria-hidden="true" tabindex="-1"></a>handle_keypress <span class="op">=</span> functools.partial(handle_event, <span class="st">"KEYPRESS"</span>)</span>
<span id="cb16-10"><a href="#cb16-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-11"><a href="#cb16-11" aria-hidden="true" tabindex="-1"></a><span class="co"># Nest partials to also fix the handler name</span></span>
<span id="cb16-12"><a href="#cb16-12" aria-hidden="true" tabindex="-1"></a>button_click <span class="op">=</span> functools.partial(handle_click, <span class="st">"button_handler"</span>)</span>
<span id="cb16-13"><a href="#cb16-13" aria-hidden="true" tabindex="-1"></a>input_keypress <span class="op">=</span> functools.partial(handle_keypress, <span class="st">"input_handler"</span>)</span>
<span id="cb16-14"><a href="#cb16-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-15"><a href="#cb16-15" aria-hidden="true" tabindex="-1"></a>button_click(<span class="st">"Button was clicked"</span>)</span>
<span id="cb16-16"><a href="#cb16-16" aria-hidden="true" tabindex="-1"></a>input_keypress(<span class="st">"Enter key pressed"</span>)</span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>[CLICK] button_handler: Button was clicked
[KEYPRESS] input_handler: Enter key pressed</code></pre>
</div>
</div>
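<p>One practical reason to prefer <code>partial</code> over a <code>lambda</code> when building handlers in a loop is binding time: a lambda looks up the loop variable when it is called, so every handler sees the last value, while <code>partial</code> captures the value when the handler is created. The sketch below uses a variant of <code>handle_event</code> that returns its message instead of printing it, so the difference is easy to check.</p>

```python
import functools


def handle_event(event_type, handler_name, data):
    """Variant of the handler above that returns the message instead of printing it."""
    return f"[{event_type}] {handler_name}: {data}"


event_types = ["CLICK", "KEYPRESS", "SCROLL"]

# Each lambda reads event_type only when it is called, after the loop has finished,
# so all three handlers see the final value "SCROLL"
lambda_handlers = [
    lambda name, data: handle_event(event_type, name, data)
    for event_type in event_types
]

# partial freezes the current value of event_type when each handler is created
partial_handlers = [functools.partial(handle_event, event_type) for event_type in event_types]

print(lambda_handlers[0]("button_handler", "clicked"))   # [SCROLL] button_handler: clicked
print(partial_handlers[0]("button_handler", "clicked"))  # [CLICK] button_handler: clicked
```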
</section>
<section id="functools.partialmethod" class="level3">
<h3 class="anchored" data-anchor-id="functools.partialmethod" id="functools.partialmethod">functools.partialmethod</h3>
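<p><code>functools.partialmethod</code> is the counterpart of <code>partial</code> for use inside a class body. A plain <code>partial</code> object is not bound to the instance when accessed as an attribute, whereas <code>partialmethod</code> is a descriptor that is, so <code>self</code> is passed first and the frozen arguments follow it:</p>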
<div id="d113076a" class="cell" data-execution_count="10">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb18"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb18-1"><a href="#cb18-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> functools</span>
<span id="cb18-2"><a href="#cb18-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-3"><a href="#cb18-3" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> Calculator:</span>
<span id="cb18-4"><a href="#cb18-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>):</span>
<span id="cb18-5"><a href="#cb18-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.result <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb18-6"><a href="#cb18-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-7"><a href="#cb18-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> operation(<span class="va">self</span>, op, value):</span>
<span id="cb18-8"><a href="#cb18-8" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> op <span class="op">==</span> <span class="st">"add"</span>:</span>
<span id="cb18-9"><a href="#cb18-9" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.result <span class="op">+=</span> value</span>
<span id="cb18-10"><a href="#cb18-10" aria-hidden="true" tabindex="-1"></a>        <span class="cf">elif</span> op <span class="op">==</span> <span class="st">"multiply"</span>:</span>
<span id="cb18-11"><a href="#cb18-11" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.result <span class="op">*=</span> value</span>
<span id="cb18-12"><a href="#cb18-12" aria-hidden="true" tabindex="-1"></a>        <span class="cf">elif</span> op <span class="op">==</span> <span class="st">"subtract"</span>:</span>
<span id="cb18-13"><a href="#cb18-13" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.result <span class="op">-=</span> value</span>
<span id="cb18-14"><a href="#cb18-14" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.result</span>
<span id="cb18-15"><a href="#cb18-15" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-16"><a href="#cb18-16" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Create partial methods</span></span>
<span id="cb18-17"><a href="#cb18-17" aria-hidden="true" tabindex="-1"></a>    add <span class="op">=</span> functools.partialmethod(operation, <span class="st">"add"</span>)</span>
<span id="cb18-18"><a href="#cb18-18" aria-hidden="true" tabindex="-1"></a>    multiply <span class="op">=</span> functools.partialmethod(operation, <span class="st">"multiply"</span>)</span>
<span id="cb18-19"><a href="#cb18-19" aria-hidden="true" tabindex="-1"></a>    subtract <span class="op">=</span> functools.partialmethod(operation, <span class="st">"subtract"</span>)</span>
<span id="cb18-20"><a href="#cb18-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-21"><a href="#cb18-21" aria-hidden="true" tabindex="-1"></a>calc <span class="op">=</span> Calculator()</span>
<span id="cb18-22"><a href="#cb18-22" aria-hidden="true" tabindex="-1"></a>calc.add(<span class="dv">5</span>)        <span class="co"># result = 5</span></span>
<span id="cb18-23"><a href="#cb18-23" aria-hidden="true" tabindex="-1"></a>calc.multiply(<span class="dv">3</span>)   <span class="co"># result = 15</span></span>
<span id="cb18-24"><a href="#cb18-24" aria-hidden="true" tabindex="-1"></a>calc.subtract(<span class="dv">2</span>)   <span class="co"># result = 13</span></span>
<span id="cb18-25"><a href="#cb18-25" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(calc.result) <span class="co"># Output: 13</span></span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>13</code></pre>
</div>
</div>
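<p>Arguments can also be pre-filled by keyword. The sketch below uses a hypothetical <code>Greeter</code> class. Keywords given at call time override the ones stored in the <code>partialmethod</code>, so presets behave like changeable defaults:</p>

```python
import functools

class Greeter:
    def greet(self, name, greeting="Hello", punctuation="!"):
        return f"{greeting}, {name}{punctuation}"

    # Pre-fill keyword arguments; the caller still supplies `name`.
    greet_formally = functools.partialmethod(greet, greeting="Good evening", punctuation=".")
    greet_casually = functools.partialmethod(greet, greeting="Hey")

g = Greeter()
print(g.greet_formally("Ada"))                   # Good evening, Ada.
print(g.greet_casually("Bob"))                   # Hey, Bob!
print(g.greet_casually("Bob", punctuation="?"))  # Hey, Bob?  (call-time keyword wins)
```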
</section>
</section>
<section id="comparison-and-ordering" class="level2">
<h2 class="anchored" data-anchor-id="comparison-and-ordering" id="comparison-and-ordering">Comparison and Ordering</h2>
<section id="functools.total_ordering" class="level3">
<h3 class="anchored" data-anchor-id="functools.total_ordering" id="functools.total_ordering">functools.total_ordering</h3>
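<p>The <code>@functools.total_ordering</code> class decorator fills in the missing rich comparison methods. The class must define one of <code>__lt__</code>, <code>__le__</code>, <code>__gt__</code>, or <code>__ge__</code>, and should also define <code>__eq__</code>; the decorator derives the remaining operators from them:</p>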
<div id="544302fb" class="cell" data-execution_count="11">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb20"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb20-1"><a href="#cb20-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> functools</span>
<span id="cb20-2"><a href="#cb20-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb20-3"><a href="#cb20-3" aria-hidden="true" tabindex="-1"></a><span class="at">@functools.total_ordering</span></span>
<span id="cb20-4"><a href="#cb20-4" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> Student:</span>
<span id="cb20-5"><a href="#cb20-5" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, name, grade):</span>
<span id="cb20-6"><a href="#cb20-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.name <span class="op">=</span> name</span>
<span id="cb20-7"><a href="#cb20-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.grade <span class="op">=</span> grade</span>
<span id="cb20-8"><a href="#cb20-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb20-9"><a href="#cb20-9" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__eq__</span>(<span class="va">self</span>, other):</span>
<span id="cb20-10"><a href="#cb20-10" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="kw">not</span> <span class="bu">isinstance</span>(other, Student):</span>
<span id="cb20-11"><a href="#cb20-11" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> <span class="va">NotImplemented</span></span>
<span id="cb20-12"><a href="#cb20-12" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.grade <span class="op">==</span> other.grade</span>
<span id="cb20-13"><a href="#cb20-13" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb20-14"><a href="#cb20-14" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__lt__</span>(<span class="va">self</span>, other):</span>
<span id="cb20-15"><a href="#cb20-15" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="kw">not</span> <span class="bu">isinstance</span>(other, Student):</span>
<span id="cb20-16"><a href="#cb20-16" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> <span class="va">NotImplemented</span></span>
<span id="cb20-17"><a href="#cb20-17" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.grade <span class="op">&lt;</span> other.grade</span>
<span id="cb20-18"><a href="#cb20-18" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb20-19"><a href="#cb20-19" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__repr__</span>(<span class="va">self</span>):</span>
<span id="cb20-20"><a href="#cb20-20" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="ss">f"Student('</span><span class="sc">{</span><span class="va">self</span><span class="sc">.</span>name<span class="sc">}</span><span class="ss">', </span><span class="sc">{</span><span class="va">self</span><span class="sc">.</span>grade<span class="sc">}</span><span class="ss">)"</span></span>
<span id="cb20-21"><a href="#cb20-21" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb20-22"><a href="#cb20-22" aria-hidden="true" tabindex="-1"></a><span class="co"># Now all comparison operators work</span></span>
<span id="cb20-23"><a href="#cb20-23" aria-hidden="true" tabindex="-1"></a>alice <span class="op">=</span> Student(<span class="st">"Alice"</span>, <span class="dv">85</span>)</span>
<span id="cb20-24"><a href="#cb20-24" aria-hidden="true" tabindex="-1"></a>bob <span class="op">=</span> Student(<span class="st">"Bob"</span>, <span class="dv">92</span>)</span>
<span id="cb20-25"><a href="#cb20-25" aria-hidden="true" tabindex="-1"></a>charlie <span class="op">=</span> Student(<span class="st">"Charlie"</span>, <span class="dv">85</span>)</span>
<span id="cb20-26"><a href="#cb20-26" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb20-27"><a href="#cb20-27" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(alice <span class="op">&lt;</span> bob)      <span class="co"># True</span></span>
<span id="cb20-28"><a href="#cb20-28" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(alice <span class="op">&gt;</span> bob)      <span class="co"># False</span></span>
<span id="cb20-29"><a href="#cb20-29" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(alice <span class="op">&lt;=</span> bob)     <span class="co"># True</span></span>
<span id="cb20-30"><a href="#cb20-30" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(alice <span class="op">&gt;=</span> bob)     <span class="co"># False</span></span>
<span id="cb20-31"><a href="#cb20-31" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(alice <span class="op">==</span> charlie) <span class="co"># True</span></span>
<span id="cb20-32"><a href="#cb20-32" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(alice <span class="op">!=</span> bob)     <span class="co"># True</span></span>
<span id="cb20-33"><a href="#cb20-33" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb20-34"><a href="#cb20-34" aria-hidden="true" tabindex="-1"></a><span class="co"># Sorting only needs __lt__; Alice stays ahead of Charlie because sort() is stable</span></span>
<span id="cb20-35"><a href="#cb20-35" aria-hidden="true" tabindex="-1"></a>students <span class="op">=</span> [bob, alice, charlie]</span>
<span id="cb20-36"><a href="#cb20-36" aria-hidden="true" tabindex="-1"></a>students.sort()</span>
<span id="cb20-37"><a href="#cb20-37" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(students)  <span class="co"># [Student('Alice', 85), Student('Charlie', 85), Student('Bob', 92)]</span></span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>True
False
True
False
True
True
[Student('Alice', 85), Student('Charlie', 85), Student('Bob', 92)]</code></pre>
</div>
</div>
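<p>To see roughly what the decorator derives, here is a simplified sketch that handles only the <code>__lt__</code> + <code>__eq__</code> case. The name <code>lt_based_ordering</code> and the <code>Version</code> class are hypothetical, and this is not the standard library's implementation, which supports all four root methods:</p>

```python
def lt_based_ordering(cls):
    """Simplified sketch: derive __le__, __gt__, __ge__ from __lt__ and __eq__."""
    def __le__(self, other):
        result = self.__lt__(other)
        if result is NotImplemented:
            return result
        return result or self == other          # a <= b  <=>  a < b or a == b

    def __gt__(self, other):
        result = self.__lt__(other)
        if result is NotImplemented:
            return result
        return not result and self != other    # a > b  <=>  not (a < b) and a != b

    def __ge__(self, other):
        result = self.__lt__(other)
        if result is NotImplemented:
            return result
        return not result                       # a >= b  <=>  not (a < b)

    for name, method in (("__le__", __le__), ("__gt__", __gt__), ("__ge__", __ge__)):
        if name not in cls.__dict__:            # keep methods the class defines itself
            setattr(cls, name, method)
    return cls


@lt_based_ordering
class Version:
    def __init__(self, major, minor):
        self.major, self.minor = major, minor

    def __eq__(self, other):
        if not isinstance(other, Version):
            return NotImplemented
        return (self.major, self.minor) == (other.major, other.minor)

    def __lt__(self, other):
        if not isinstance(other, Version):
            return NotImplemented
        return (self.major, self.minor) < (other.major, other.minor)


v1, v2 = Version(1, 2), Version(1, 10)
print(v1 < v2, v1 <= v2, v1 > v2, v1 >= v2)  # True True False False
```

<p>Each derived operator makes one or two extra method calls. The Python documentation notes that <code>total_ordering</code> costs slower execution and more complex stack traces for the derived methods, so if profiling shows comparisons to be a bottleneck, defining all six by hand may be worthwhile.</p>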
</section>
<section id="functools.cmp_to_key" class="level3">
<h3 class="anchored" data-anchor-id="functools.cmp_to_key" id="functools.cmp_to_key">functools.cmp_to_key</h3>
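<p><code>functools.cmp_to_key</code> converts an old-style comparison function into a key function. An old-style comparison function takes two arguments and returns a negative number, zero, or a positive number for less-than, equal, or greater-than. The resulting key function works with <code>sorted()</code>, <code>list.sort()</code>, <code>min()</code>, <code>max()</code>, and other tools that accept a <code>key</code>:</p>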
<div id="e8619a1a" class="cell" data-execution_count="12">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb22"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb22-1"><a href="#cb22-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> functools</span>
<span id="cb22-2"><a href="#cb22-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb22-3"><a href="#cb22-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> compare_strings(a, b):</span>
<span id="cb22-4"><a href="#cb22-4" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Old-style comparison function."""</span></span>
<span id="cb22-5"><a href="#cb22-5" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Compare by length first, then alphabetically</span></span>
<span id="cb22-6"><a href="#cb22-6" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> <span class="bu">len</span>(a) <span class="op">!=</span> <span class="bu">len</span>(b):</span>
<span id="cb22-7"><a href="#cb22-7" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="bu">len</span>(a) <span class="op">-</span> <span class="bu">len</span>(b)</span>
<span id="cb22-8"><a href="#cb22-8" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> a <span class="op">&lt;</span> b:</span>
<span id="cb22-9"><a href="#cb22-9" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="op">-</span><span class="dv">1</span></span>
<span id="cb22-10"><a href="#cb22-10" aria-hidden="true" tabindex="-1"></a>    <span class="cf">elif</span> a <span class="op">&gt;</span> b:</span>
<span id="cb22-11"><a href="#cb22-11" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="dv">1</span></span>
<span id="cb22-12"><a href="#cb22-12" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> <span class="dv">0</span></span>
<span id="cb22-13"><a href="#cb22-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb22-14"><a href="#cb22-14" aria-hidden="true" tabindex="-1"></a><span class="co"># Convert to key function</span></span>
<span id="cb22-15"><a href="#cb22-15" aria-hidden="true" tabindex="-1"></a>key_func <span class="op">=</span> functools.cmp_to_key(compare_strings)</span>
<span id="cb22-16"><a href="#cb22-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb22-17"><a href="#cb22-17" aria-hidden="true" tabindex="-1"></a>words <span class="op">=</span> [<span class="st">"apple"</span>, <span class="st">"pie"</span>, <span class="st">"banana"</span>, <span class="st">"cat"</span>, <span class="st">"elephant"</span>]</span>
<span id="cb22-18"><a href="#cb22-18" aria-hidden="true" tabindex="-1"></a>sorted_words <span class="op">=</span> <span class="bu">sorted</span>(words, key<span class="op">=</span>key_func)</span>
<span id="cb22-19"><a href="#cb22-19" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(sorted_words)  <span class="co"># ['cat', 'pie', 'apple', 'banana', 'elephant']</span></span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>['cat', 'pie', 'apple', 'banana', 'elephant']</code></pre>
</div>
</div>
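<p>Under the hood, the key function wraps each element in an object whose comparison methods call the old-style function. Below is a simplified sketch of that idea, under the hypothetical name <code>cmp_to_key_sketch</code>. When the comparison rule can be written as a tuple key, as <code>(len(s), s)</code> can here, the tuple key is usually simpler and avoids calling a Python function on every comparison:</p>

```python
import functools

def compare_strings(a, b):
    """Length first, then alphabetical (same rule as above)."""
    if len(a) != len(b):
        return len(a) - len(b)
    if a < b:
        return -1
    elif a > b:
        return 1
    return 0

def cmp_to_key_sketch(mycmp):
    """Simplified sketch of the idea behind functools.cmp_to_key."""
    class Key:
        __slots__ = ("obj",)

        def __init__(self, obj):
            self.obj = obj

        # Every rich comparison delegates to the old-style cmp function.
        def __lt__(self, other):
            return mycmp(self.obj, other.obj) < 0

        def __gt__(self, other):
            return mycmp(self.obj, other.obj) > 0

        def __eq__(self, other):
            return mycmp(self.obj, other.obj) == 0

        def __le__(self, other):
            return mycmp(self.obj, other.obj) <= 0

        def __ge__(self, other):
            return mycmp(self.obj, other.obj) >= 0

    return Key

words = ["apple", "pie", "banana", "cat", "elephant"]
by_sketch = sorted(words, key=cmp_to_key_sketch(compare_strings))
by_stdlib = sorted(words, key=functools.cmp_to_key(compare_strings))
by_tuple = sorted(words, key=lambda s: (len(s), s))
print(by_sketch == by_stdlib == by_tuple)  # True
```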
</section>
</section>
<section id="caching-and-memoization" class="level2">
<h2 class="anchored" data-anchor-id="caching-and-memoization" id="caching-and-memoization">Caching and Memoization</h2>
<section id="advanced-caching-strategies" class="level3">
<h3 class="anchored" data-anchor-id="advanced-caching-strategies" id="advanced-caching-strategies">Advanced Caching Strategies</h3>
<div id="83ecae05" class="cell" data-execution_count="13">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb24"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb24-1"><a href="#cb24-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> functools</span>
<span id="cb24-2"><a href="#cb24-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb24-3"><a href="#cb24-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> typing <span class="im">import</span> Any, Callable</span>
<span id="cb24-4"><a href="#cb24-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb24-5"><a href="#cb24-5" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> timed_cache(seconds: <span class="bu">int</span>):</span>
<span id="cb24-6"><a href="#cb24-6" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Custom decorator for time-based caching."""</span></span>
<span id="cb24-7"><a href="#cb24-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> decorator(func: Callable) <span class="op">-&gt;</span> Callable:</span>
<span id="cb24-8"><a href="#cb24-8" aria-hidden="true" tabindex="-1"></a>        cache <span class="op">=</span> {}</span>
<span id="cb24-9"><a href="#cb24-9" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb24-10"><a href="#cb24-10" aria-hidden="true" tabindex="-1"></a>        <span class="at">@functools.wraps</span>(func)</span>
<span id="cb24-11"><a href="#cb24-11" aria-hidden="true" tabindex="-1"></a>        <span class="kw">def</span> wrapper(<span class="op">*</span>args, <span class="op">**</span>kwargs):</span>
<span id="cb24-12"><a href="#cb24-12" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Build a string key from the arguments (objects with identical str() output share a key)</span></span>
<span id="cb24-13"><a href="#cb24-13" aria-hidden="true" tabindex="-1"></a>            key <span class="op">=</span> <span class="bu">str</span>(args) <span class="op">+</span> <span class="bu">str</span>(<span class="bu">sorted</span>(kwargs.items()))</span>
<span id="cb24-14"><a href="#cb24-14" aria-hidden="true" tabindex="-1"></a>            current_time <span class="op">=</span> time.time()</span>
<span id="cb24-15"><a href="#cb24-15" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb24-16"><a href="#cb24-16" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Check if result is cached and still valid</span></span>
<span id="cb24-17"><a href="#cb24-17" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> key <span class="kw">in</span> cache:</span>
<span id="cb24-18"><a href="#cb24-18" aria-hidden="true" tabindex="-1"></a>                result, timestamp <span class="op">=</span> cache[key]</span>
<span id="cb24-19"><a href="#cb24-19" aria-hidden="true" tabindex="-1"></a>                <span class="cf">if</span> current_time <span class="op">-</span> timestamp <span class="op">&lt;</span> seconds:</span>
<span id="cb24-20"><a href="#cb24-20" aria-hidden="true" tabindex="-1"></a>                    <span class="cf">return</span> result</span>
<span id="cb24-21"><a href="#cb24-21" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb24-22"><a href="#cb24-22" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Calculate new result and cache it</span></span>
<span id="cb24-23"><a href="#cb24-23" aria-hidden="true" tabindex="-1"></a>            result <span class="op">=</span> func(<span class="op">*</span>args, <span class="op">**</span>kwargs)</span>
<span id="cb24-24"><a href="#cb24-24" aria-hidden="true" tabindex="-1"></a>            cache[key] <span class="op">=</span> (result, current_time)</span>
<span id="cb24-25"><a href="#cb24-25" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> result</span>
<span id="cb24-26"><a href="#cb24-26" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb24-27"><a href="#cb24-27" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> wrapper</span>
<span id="cb24-28"><a href="#cb24-28" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> decorator</span>
<span id="cb24-29"><a href="#cb24-29" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb24-30"><a href="#cb24-30" aria-hidden="true" tabindex="-1"></a><span class="at">@timed_cache</span>(seconds<span class="op">=</span><span class="dv">5</span>)</span>
<span id="cb24-31"><a href="#cb24-31" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> get_current_time():</span>
<span id="cb24-32"><a href="#cb24-32" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Get current time (cached for 5 seconds)."""</span></span>
<span id="cb24-33"><a href="#cb24-33" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> time.time()</span>
<span id="cb24-34"><a href="#cb24-34" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb24-35"><a href="#cb24-35" aria-hidden="true" tabindex="-1"></a><span class="co"># Test the timed cache</span></span>
<span id="cb24-36"><a href="#cb24-36" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(get_current_time())  <span class="co"># Fresh calculation</span></span>
<span id="cb24-37"><a href="#cb24-37" aria-hidden="true" tabindex="-1"></a>time.sleep(<span class="dv">2</span>)</span>
<span id="cb24-38"><a href="#cb24-38" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(get_current_time())  <span class="co"># Cached result (same as above)</span></span>
<span id="cb24-39"><a href="#cb24-39" aria-hidden="true" tabindex="-1"></a>time.sleep(<span class="dv">4</span>)</span>
<span id="cb24-40"><a href="#cb24-40" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(get_current_time())  <span class="co"># Fresh calculation (cache expired)</span></span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>1751948377.506505
1751948377.506505
1751948383.515122</code></pre>
</div>
</div>
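<p>The decorator above never removes expired entries, builds its key with <code>str()</code>, and reads the wall clock with <code>time.time()</code>, which can jump if the system clock is adjusted. Below is a variant sketch with a few changes. It uses a tuple key, which requires hashable arguments just as <code>functools.lru_cache</code> does. It reads <code>time.monotonic()</code> by default. The names <code>ttl_cache</code>, <code>clock</code>, and <code>purge_expired</code> are hypothetical; the injectable <code>clock</code> lets the expiry logic be exercised without <code>time.sleep()</code>:</p>

```python
import functools
import time
from typing import Callable

def ttl_cache(seconds: float, clock: Callable[[], float] = time.monotonic):
    """Sketch of a time-based cache keyed on hashable arguments."""
    def decorator(func: Callable) -> Callable:
        cache = {}  # key -> (result, timestamp)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Tuple key: arguments must be hashable, as with functools.lru_cache.
            key = (args, tuple(sorted(kwargs.items())))
            now = clock()
            entry = cache.get(key)
            if entry is not None and now - entry[1] < seconds:
                return entry[0]  # still fresh
            result = func(*args, **kwargs)
            cache[key] = (result, now)
            return result

        def purge_expired():
            """Hypothetical helper: drop entries older than `seconds`."""
            now = clock()
            stale = [k for k, (_, ts) in cache.items() if now - ts >= seconds]
            for k in stale:
                del cache[k]

        wrapper.purge_expired = purge_expired
        wrapper.cache_clear = cache.clear
        return wrapper
    return decorator

# A fake clock makes the expiry behaviour testable without sleeping.
fake_now = [0.0]
calls = []

@ttl_cache(seconds=5, clock=lambda: fake_now[0])
def square(x):
    calls.append(x)
    return x * x

square(3)          # computed at t=0
fake_now[0] = 2.0
square(3)          # cache hit: entry is 2 s old
fake_now[0] = 6.0
square(3)          # recomputed: entry is 6 s old
square.purge_expired()
print(calls)       # [3, 3]
```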
</section>
<section id="cache-with-custom-key-function" class="level3">
<h3 class="anchored" data-anchor-id="cache-with-custom-key-function" id="cache-with-custom-key-function">Cache with Custom Key Function</h3>
<div id="6e3965d1" class="cell" data-execution_count="14">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb26"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb26-1"><a href="#cb26-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> functools</span>
<span id="cb26-2"><a href="#cb26-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb26-3"><a href="#cb26-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> custom_cache(key_func<span class="op">=</span><span class="va">None</span>):</span>
<span id="cb26-4"><a href="#cb26-4" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Cache decorator with custom key function."""</span></span>
<span id="cb26-5"><a href="#cb26-5" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> decorator(func):</span>
<span id="cb26-6"><a href="#cb26-6" aria-hidden="true" tabindex="-1"></a>        cache <span class="op">=</span> {}</span>
<span id="cb26-7"><a href="#cb26-7" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb26-8"><a href="#cb26-8" aria-hidden="true" tabindex="-1"></a>        <span class="at">@functools.wraps</span>(func)</span>
<span id="cb26-9"><a href="#cb26-9" aria-hidden="true" tabindex="-1"></a>        <span class="kw">def</span> wrapper(<span class="op">*</span>args, <span class="op">**</span>kwargs):</span>
<span id="cb26-10"><a href="#cb26-10" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> key_func:</span>
<span id="cb26-11"><a href="#cb26-11" aria-hidden="true" tabindex="-1"></a>                key <span class="op">=</span> key_func(<span class="op">*</span>args, <span class="op">**</span>kwargs)</span>
<span id="cb26-12"><a href="#cb26-12" aria-hidden="true" tabindex="-1"></a>            <span class="cf">else</span>:</span>
<span id="cb26-13"><a href="#cb26-13" aria-hidden="true" tabindex="-1"></a>                key <span class="op">=</span> <span class="bu">str</span>(args) <span class="op">+</span> <span class="bu">str</span>(<span class="bu">sorted</span>(kwargs.items()))</span>
<span id="cb26-14"><a href="#cb26-14" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb26-15"><a href="#cb26-15" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> key <span class="kw">in</span> cache:</span>
<span id="cb26-16"><a href="#cb26-16" aria-hidden="true" tabindex="-1"></a>                <span class="cf">return</span> cache[key]</span>
<span id="cb26-17"><a href="#cb26-17" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb26-18"><a href="#cb26-18" aria-hidden="true" tabindex="-1"></a>            result <span class="op">=</span> func(<span class="op">*</span>args, <span class="op">**</span>kwargs)</span>
<span id="cb26-19"><a href="#cb26-19" aria-hidden="true" tabindex="-1"></a>            cache[key] <span class="op">=</span> result</span>
<span id="cb26-20"><a href="#cb26-20" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> result</span>
<span id="cb26-21"><a href="#cb26-21" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb26-22"><a href="#cb26-22" aria-hidden="true" tabindex="-1"></a>        wrapper.cache_clear <span class="op">=</span> cache.clear</span>
<span id="cb26-23"><a href="#cb26-23" aria-hidden="true" tabindex="-1"></a>        wrapper.cache_info <span class="op">=</span> <span class="kw">lambda</span>: <span class="ss">f"Cache size: </span><span class="sc">{</span><span class="bu">len</span>(cache)<span class="sc">}</span><span class="ss">"</span></span>
<span id="cb26-24"><a href="#cb26-24" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> wrapper</span>
<span id="cb26-25"><a href="#cb26-25" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> decorator</span>
<span id="cb26-26"><a href="#cb26-26" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb26-27"><a href="#cb26-27" aria-hidden="true" tabindex="-1"></a><span class="co"># Example: Cache based on first argument only</span></span>
<span id="cb26-28"><a href="#cb26-28" aria-hidden="true" tabindex="-1"></a><span class="at">@custom_cache</span>(key_func<span class="op">=</span><span class="kw">lambda</span> x, y: x)</span>
<span id="cb26-29"><a href="#cb26-29" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> expensive_computation(x, y):</span>
<span id="cb26-30"><a href="#cb26-30" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Expensive computation cached by first argument only."""</span></span>
<span id="cb26-31"><a href="#cb26-31" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Computing for </span><span class="sc">{</span>x<span class="sc">}</span><span class="ss">, </span><span class="sc">{</span>y<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb26-32"><a href="#cb26-32" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> x <span class="op">**</span> y</span>
<span id="cb26-33"><a href="#cb26-33" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb26-34"><a href="#cb26-34" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(expensive_computation(<span class="dv">2</span>, <span class="dv">3</span>))  <span class="co"># Computing for 2, 3 -&gt; 8</span></span>
<span id="cb26-35"><a href="#cb26-35" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(expensive_computation(<span class="dv">2</span>, <span class="dv">5</span>))  <span class="co"># Cache hit on x=2 -&gt; 8, not 32, because the key ignores y</span></span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>Computing for 2, 3
8
8</code></pre>
</div>
</div>
</section>
</section>
<section id="function-composition" class="level2">
<h2 class="anchored" data-anchor-id="function-composition" id="function-composition">Function Composition</h2>
<section id="functools.reduce" class="level3">
<h3 class="anchored" data-anchor-id="functools.reduce" id="functools.reduce">functools.reduce</h3>
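<p><code>functools.reduce</code> applies a two-argument function cumulatively to the items of an iterable, from left to right, reducing them to a single value. An optional third argument supplies the starting value:</p>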
<div id="de96e327" class="cell" data-execution_count="15">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb28"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb28-1"><a href="#cb28-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> functools</span>
<span id="cb28-2"><a href="#cb28-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> operator</span>
<span id="cb28-3"><a href="#cb28-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb28-4"><a href="#cb28-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Sum all numbers</span></span>
<span id="cb28-5"><a href="#cb28-5" aria-hidden="true" tabindex="-1"></a>numbers <span class="op">=</span> [<span class="dv">1</span>, <span class="dv">2</span>, <span class="dv">3</span>, <span class="dv">4</span>, <span class="dv">5</span>]</span>
<span id="cb28-6"><a href="#cb28-6" aria-hidden="true" tabindex="-1"></a>total <span class="op">=</span> functools.<span class="bu">reduce</span>(operator.add, numbers)</span>
<span id="cb28-7"><a href="#cb28-7" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(total)  <span class="co"># Output: 15</span></span>
<span id="cb28-8"><a href="#cb28-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb28-9"><a href="#cb28-9" aria-hidden="true" tabindex="-1"></a><span class="co"># Find maximum</span></span>
<span id="cb28-10"><a href="#cb28-10" aria-hidden="true" tabindex="-1"></a>maximum <span class="op">=</span> functools.<span class="bu">reduce</span>(<span class="kw">lambda</span> x, y: x <span class="cf">if</span> x <span class="op">&gt;</span> y <span class="cf">else</span> y, numbers)</span>
<span id="cb28-11"><a href="#cb28-11" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(maximum)  <span class="co"># Output: 5</span></span>
<span id="cb28-12"><a href="#cb28-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb28-13"><a href="#cb28-13" aria-hidden="true" tabindex="-1"></a><span class="co"># Multiply all numbers</span></span>
<span id="cb28-14"><a href="#cb28-14" aria-hidden="true" tabindex="-1"></a>product <span class="op">=</span> functools.<span class="bu">reduce</span>(operator.mul, numbers)</span>
<span id="cb28-15"><a href="#cb28-15" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(product)  <span class="co"># Output: 120</span></span>
<span id="cb28-16"><a href="#cb28-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb28-17"><a href="#cb28-17" aria-hidden="true" tabindex="-1"></a><span class="co"># Flatten nested lists</span></span>
<span id="cb28-18"><a href="#cb28-18" aria-hidden="true" tabindex="-1"></a>nested_lists <span class="op">=</span> [[<span class="dv">1</span>, <span class="dv">2</span>], [<span class="dv">3</span>, <span class="dv">4</span>], [<span class="dv">5</span>, <span class="dv">6</span>]]</span>
<span id="cb28-19"><a href="#cb28-19" aria-hidden="true" tabindex="-1"></a>flattened <span class="op">=</span> functools.<span class="bu">reduce</span>(operator.add, nested_lists)</span>
<span id="cb28-20"><a href="#cb28-20" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(flattened)  <span class="co"># Output: [1, 2, 3, 4, 5, 6]</span></span>
<span id="cb28-21"><a href="#cb28-21" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb28-22"><a href="#cb28-22" aria-hidden="true" tabindex="-1"></a><span class="co"># With initial value</span></span>
<span id="cb28-23"><a href="#cb28-23" aria-hidden="true" tabindex="-1"></a>result <span class="op">=</span> functools.<span class="bu">reduce</span>(operator.add, numbers, <span class="dv">100</span>)</span>
<span id="cb28-24"><a href="#cb28-24" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(result)  <span class="co"># Output: 115 (100 + 15)</span></span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>15
5
120
[1, 2, 3, 4, 5, 6]
115</code></pre>
</div>
</div>
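<p><code>reduce</code> is a left fold: the accumulated value is always the function's first argument. The sketch below follows the pure-Python equivalent given in the Python documentation, under the hypothetical name <code>reduce_sketch</code>. It makes the fold order visible and shows why an empty iterable needs an initial value:</p>

```python
import functools
import operator

_MISSING = object()  # sentinel: distinguishes "no initial value" from None

def reduce_sketch(function, iterable, initial=_MISSING):
    """Left fold: ((x0 op x1) op x2) op ... (sketch of functools.reduce)."""
    it = iter(iterable)
    if initial is _MISSING:
        try:
            value = next(it)  # first item becomes the starting accumulator
        except StopIteration:
            raise TypeError("reduce_sketch() of empty iterable with no initial value") from None
    else:
        value = initial
    for element in it:
        value = function(value, element)
    return value

# The accumulator is always the left argument.
print(reduce_sketch(lambda acc, x: f"({acc}+{x})", [1, 2, 3]))  # ((1+2)+3)

# An empty iterable needs an initial value, in both versions.
print(reduce_sketch(operator.add, [], 0))     # 0
print(functools.reduce(operator.add, [], 0))  # 0
try:
    functools.reduce(operator.add, [])
except TypeError as exc:
    print("TypeError:", exc)
```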
</section>
<section id="building-complex-operations" class="level3">
<h3 class="anchored" data-anchor-id="building-complex-operations" id="building-complex-operations">Building Complex Operations</h3>
<div id="581ec335" class="cell" data-execution_count="16">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb30"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb30-1"><a href="#cb30-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> functools</span>
<span id="cb30-2"><a href="#cb30-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> operator</span>
<span id="cb30-3"><a href="#cb30-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb30-4"><a href="#cb30-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> compose(<span class="op">*</span>functions):</span>
<span id="cb30-5"><a href="#cb30-5" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Compose multiple functions into a single function."""</span></span>
<span id="cb30-6"><a href="#cb30-6" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> functools.<span class="bu">reduce</span>(<span class="kw">lambda</span> f, g: <span class="kw">lambda</span> x: f(g(x)), functions, <span class="kw">lambda</span> x: x)</span>
<span id="cb30-7"><a href="#cb30-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb30-8"><a href="#cb30-8" aria-hidden="true" tabindex="-1"></a><span class="co"># Example functions</span></span>
<span id="cb30-9"><a href="#cb30-9" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> add_one(x):</span>
<span id="cb30-10"><a href="#cb30-10" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> x <span class="op">+</span> <span class="dv">1</span></span>
<span id="cb30-11"><a href="#cb30-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb30-12"><a href="#cb30-12" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> multiply_by_two(x):</span>
<span id="cb30-13"><a href="#cb30-13" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> x <span class="op">*</span> <span class="dv">2</span></span>
<span id="cb30-14"><a href="#cb30-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb30-15"><a href="#cb30-15" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> square(x):</span>
<span id="cb30-16"><a href="#cb30-16" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> x <span class="op">**</span> <span class="dv">2</span></span>
<span id="cb30-17"><a href="#cb30-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb30-18"><a href="#cb30-18" aria-hidden="true" tabindex="-1"></a><span class="co"># Compose functions</span></span>
<span id="cb30-19"><a href="#cb30-19" aria-hidden="true" tabindex="-1"></a>composed <span class="op">=</span> compose(square, multiply_by_two, add_one)</span>
<span id="cb30-20"><a href="#cb30-20" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(composed(<span class="dv">3</span>))  <span class="co"># ((3 + 1) * 2) ** 2 = 64</span></span>
<span id="cb30-21"><a href="#cb30-21" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb30-22"><a href="#cb30-22" aria-hidden="true" tabindex="-1"></a><span class="co"># Dictionary operations with reduce</span></span>
<span id="cb30-23"><a href="#cb30-23" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> merge_dicts(<span class="op">*</span>dicts):</span>
<span id="cb30-24"><a href="#cb30-24" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Merge dictionaries left to right; later values win on duplicate keys."""</span></span>
<span id="cb30-25"><a href="#cb30-25" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> functools.<span class="bu">reduce</span>(</span>
<span id="cb30-26"><a href="#cb30-26" aria-hidden="true" tabindex="-1"></a>        <span class="kw">lambda</span> acc, d: {<span class="op">**</span>acc, <span class="op">**</span>d}, </span>
<span id="cb30-27"><a href="#cb30-27" aria-hidden="true" tabindex="-1"></a>        dicts, </span>
<span id="cb30-28"><a href="#cb30-28" aria-hidden="true" tabindex="-1"></a>        {}</span>
<span id="cb30-29"><a href="#cb30-29" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb30-30"><a href="#cb30-30" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb30-31"><a href="#cb30-31" aria-hidden="true" tabindex="-1"></a>dict1 <span class="op">=</span> {<span class="st">"a"</span>: <span class="dv">1</span>, <span class="st">"b"</span>: <span class="dv">2</span>}</span>
<span id="cb30-32"><a href="#cb30-32" aria-hidden="true" tabindex="-1"></a>dict2 <span class="op">=</span> {<span class="st">"c"</span>: <span class="dv">3</span>, <span class="st">"d"</span>: <span class="dv">4</span>}</span>
<span id="cb30-33"><a href="#cb30-33" aria-hidden="true" tabindex="-1"></a>dict3 <span class="op">=</span> {<span class="st">"e"</span>: <span class="dv">5</span>, <span class="st">"f"</span>: <span class="dv">6</span>}</span>
<span id="cb30-34"><a href="#cb30-34" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb30-35"><a href="#cb30-35" aria-hidden="true" tabindex="-1"></a>merged <span class="op">=</span> merge_dicts(dict1, dict2, dict3)</span>
<span id="cb30-36"><a href="#cb30-36" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(merged)  <span class="co"># {'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6}</span></span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>64
{'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6}</code></pre>
</div>
</div>
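<p>The cell above imports <code>operator</code> but never uses it. A common pairing is to pass <code>functools.reduce</code> a function from the <code>operator</code> module instead of a lambda. The sketch below does this for a product and for a dict merge. The hypothetical <code>merge_dicts_or</code> assumes Python 3.9+, where <code>operator.or_</code> on two dicts performs a merge. The second example also shows what happens on a key collision.</p>

```python
import functools
import operator

# Product of 1..5 without writing a lambda: ((((1*2)*3)*4)*5)
factorial_5 = functools.reduce(operator.mul, range(1, 6))
print(factorial_5)  # 120

def merge_dicts_or(*dicts):
    """Hypothetical helper: merge dicts left to right with | (Python 3.9+)."""
    return functools.reduce(operator.or_, dicts, {})

# On a key collision, the later dictionary overwrites the earlier value.
merged = merge_dicts_or({"a": 1, "b": 2}, {"b": 20, "c": 3})
print(merged)  # {'a': 1, 'b': 20, 'c': 3}
```

<p>The empty-dict initializer matters. Without it, <code>merge_dicts_or()</code> would raise <code>TypeError</code> because <code>reduce</code> has nothing to start from.</p>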
</section>
</section>
<section id="advanced-usage-patterns" class="level2">
<h2 class="anchored" data-anchor-id="advanced-usage-patterns" id="advanced-usage-patterns">Advanced Usage Patterns</h2>
<section id="decorator-factories" class="level3">
<h3 class="anchored" data-anchor-id="decorator-factories" id="decorator-factories">Decorator Factories</h3>
<div id="65b62df4" class="cell" data-execution_count="17">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb32"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb32-1"><a href="#cb32-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> functools</span>
<span id="cb32-2"><a href="#cb32-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb32-3"><a href="#cb32-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb32-4"><a href="#cb32-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> retry(max_attempts<span class="op">=</span><span class="dv">3</span>, delay<span class="op">=</span><span class="dv">1</span>):</span>
<span id="cb32-5"><a href="#cb32-5" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Decorator factory for retrying failed operations."""</span></span>
<span id="cb32-6"><a href="#cb32-6" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> decorator(func):</span>
<span id="cb32-7"><a href="#cb32-7" aria-hidden="true" tabindex="-1"></a>        <span class="at">@functools.wraps</span>(func)</span>
<span id="cb32-8"><a href="#cb32-8" aria-hidden="true" tabindex="-1"></a>        <span class="kw">def</span> wrapper(<span class="op">*</span>args, <span class="op">**</span>kwargs):</span>
<span id="cb32-9"><a href="#cb32-9" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> attempt <span class="kw">in</span> <span class="bu">range</span>(max_attempts):</span>
<span id="cb32-10"><a href="#cb32-10" aria-hidden="true" tabindex="-1"></a>                <span class="cf">try</span>:</span>
<span id="cb32-11"><a href="#cb32-11" aria-hidden="true" tabindex="-1"></a>                    <span class="cf">return</span> func(<span class="op">*</span>args, <span class="op">**</span>kwargs)</span>
<span id="cb32-12"><a href="#cb32-12" aria-hidden="true" tabindex="-1"></a>                <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:</span>
<span id="cb32-13"><a href="#cb32-13" aria-hidden="true" tabindex="-1"></a>                    <span class="cf">if</span> attempt <span class="op">==</span> max_attempts <span class="op">-</span> <span class="dv">1</span>:</span>
<span id="cb32-14"><a href="#cb32-14" aria-hidden="true" tabindex="-1"></a>                        <span class="cf">raise</span>  <span class="co"># re-raise the last failure with its original traceback</span></span>
<span id="cb32-15"><a href="#cb32-15" aria-hidden="true" tabindex="-1"></a>                    <span class="bu">print</span>(<span class="ss">f"Attempt </span><span class="sc">{</span>attempt <span class="op">+</span> <span class="dv">1</span><span class="sc">}</span><span class="ss"> failed: </span><span class="sc">{</span>e<span class="sc">}</span><span class="ss">. Retrying in </span><span class="sc">{</span>delay<span class="sc">}</span><span class="ss">s..."</span>)</span>
<span id="cb32-16"><a href="#cb32-16" aria-hidden="true" tabindex="-1"></a>                    time.sleep(delay)</span>
<span id="cb32-17"><a href="#cb32-17" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> <span class="va">None</span></span>
<span id="cb32-18"><a href="#cb32-18" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> wrapper</span>
<span id="cb32-19"><a href="#cb32-19" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> decorator</span>
<span id="cb32-20"><a href="#cb32-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb32-21"><a href="#cb32-21" aria-hidden="true" tabindex="-1"></a><span class="at">@retry</span>(max_attempts<span class="op">=</span><span class="dv">3</span>, delay<span class="op">=</span><span class="fl">0.5</span>)</span>
<span id="cb32-22"><a href="#cb32-22" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> unreliable_function():</span>
<span id="cb32-23"><a href="#cb32-23" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Function that fails randomly."""</span></span>
<span id="cb32-24"><a href="#cb32-24" aria-hidden="true" tabindex="-1"></a>    <span class="im">import</span> random</span>
<span id="cb32-25"><a href="#cb32-25" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> random.random() <span class="op">&lt;</span> <span class="fl">0.7</span>:</span>
<span id="cb32-26"><a href="#cb32-26" aria-hidden="true" tabindex="-1"></a>        <span class="cf">raise</span> <span class="pp">Exception</span>(<span class="st">"Random failure"</span>)</span>
<span id="cb32-27"><a href="#cb32-27" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> <span class="st">"Success!"</span></span>
<span id="cb32-28"><a href="#cb32-28" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb32-29"><a href="#cb32-29" aria-hidden="true" tabindex="-1"></a><span class="co"># Test the retry decorator</span></span>
<span id="cb32-30"><a href="#cb32-30" aria-hidden="true" tabindex="-1"></a><span class="co"># result = unreliable_function()  # Up to 3 attempts in total; re-raises if all fail</span></span></code></pre></div></div>
</div>
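<p><code>retry</code> above retries on any <code>Exception</code> and waits the same <code>delay</code> every time. A common refinement restricts retries to chosen exception types and multiplies the wait after each failure. The sketch below implements this as a hypothetical <code>retry_with_backoff</code> factory, which is not part of <code>functools</code>. The flaky function uses a counter instead of <code>random</code>, so its behavior is deterministic.</p>

```python
import functools
import time

def retry_with_backoff(max_attempts=3, delay=1.0, backoff=2.0, exceptions=(Exception,)):
    """Hypothetical factory: retry only `exceptions`, growing the wait by `backoff`."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            wait = delay
            for attempt in range(1, max_attempts + 1):
                try:
                    return func(*args, **kwargs)
                except exceptions:
                    if attempt == max_attempts:
                        raise  # out of attempts: re-raise the last failure
                    time.sleep(wait)
                    wait *= backoff  # e.g. 0.01s, then 0.02s, then 0.04s
        return wrapper
    return decorator

calls = {"count": 0}

@retry_with_backoff(max_attempts=4, delay=0.01, exceptions=(ConnectionError,))
def flaky_fetch():
    """Fails twice with ConnectionError, then succeeds."""
    calls["count"] += 1
    if calls["count"] < 3:
        raise ConnectionError("temporary outage")
    return "payload"

print(flaky_fetch(), calls["count"])  # payload 3
```

<p>Exceptions outside <code>exceptions</code> are not retried. For example, a <code>ValueError</code> raised by <code>flaky_fetch</code> would propagate on the first attempt.</p>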
</section>
<section id="method-decorators" class="level3">
<h3 class="anchored" data-anchor-id="method-decorators" id="method-decorators">Method Decorators</h3>
<div id="371d3637" class="cell" data-execution_count="18">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb33"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb33-1"><a href="#cb33-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> functools</span>
<span id="cb33-2"><a href="#cb33-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb33-3"><a href="#cb33-3" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> ValidationError(<span class="pp">Exception</span>):</span>
<span id="cb33-4"><a href="#cb33-4" aria-hidden="true" tabindex="-1"></a>    <span class="cf">pass</span></span>
<span id="cb33-5"><a href="#cb33-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb33-6"><a href="#cb33-6" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> validate_positive(func):</span>
<span id="cb33-7"><a href="#cb33-7" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Decorator that rejects non-positive numeric arguments (positional or keyword)."""</span></span>
<span id="cb33-8"><a href="#cb33-8" aria-hidden="true" tabindex="-1"></a>    <span class="at">@functools.wraps</span>(func)</span>
<span id="cb33-9"><a href="#cb33-9" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> wrapper(<span class="va">self</span>, <span class="op">*</span>args, <span class="op">**</span>kwargs):</span>
<span id="cb33-10"><a href="#cb33-10" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> arg <span class="kw">in</span> (<span class="op">*</span>args, <span class="op">*</span>kwargs.values()):</span>
<span id="cb33-11"><a href="#cb33-11" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="bu">isinstance</span>(arg, (<span class="bu">int</span>, <span class="bu">float</span>)) <span class="kw">and</span> arg <span class="op">&lt;=</span> <span class="dv">0</span>:</span>
<span id="cb33-12"><a href="#cb33-12" aria-hidden="true" tabindex="-1"></a>                <span class="cf">raise</span> ValidationError(<span class="ss">f"Argument </span><span class="sc">{</span>arg<span class="sc">}</span><span class="ss"> must be positive"</span>)</span>
<span id="cb33-13"><a href="#cb33-13" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> func(<span class="va">self</span>, <span class="op">*</span>args, <span class="op">**</span>kwargs)</span>
<span id="cb33-14"><a href="#cb33-14" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> wrapper</span>
<span id="cb33-15"><a href="#cb33-15" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb33-16"><a href="#cb33-16" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> Calculator:</span>
<span id="cb33-17"><a href="#cb33-17" aria-hidden="true" tabindex="-1"></a>    <span class="at">@validate_positive</span></span>
<span id="cb33-18"><a href="#cb33-18" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> divide(<span class="va">self</span>, a, b):</span>
<span id="cb33-19"><a href="#cb33-19" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Divide two positive numbers."""</span></span>
<span id="cb33-20"><a href="#cb33-20" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> a <span class="op">/</span> b</span>
<span id="cb33-21"><a href="#cb33-21" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb33-22"><a href="#cb33-22" aria-hidden="true" tabindex="-1"></a>    <span class="at">@validate_positive</span></span>
<span id="cb33-23"><a href="#cb33-23" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> sqrt(<span class="va">self</span>, x):</span>
<span id="cb33-24"><a href="#cb33-24" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Calculate square root of a positive number."""</span></span>
<span id="cb33-25"><a href="#cb33-25" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x <span class="op">**</span> <span class="fl">0.5</span></span>
<span id="cb33-26"><a href="#cb33-26" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb33-27"><a href="#cb33-27" aria-hidden="true" tabindex="-1"></a>calc <span class="op">=</span> Calculator()</span>
<span id="cb33-28"><a href="#cb33-28" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(calc.divide(<span class="dv">10</span>, <span class="dv">2</span>))  <span class="co"># 5.0</span></span>
<span id="cb33-29"><a href="#cb33-29" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(calc.sqrt(<span class="dv">16</span>))       <span class="co"># 4.0</span></span>
<span id="cb33-30"><a href="#cb33-30" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb33-31"><a href="#cb33-31" aria-hidden="true" tabindex="-1"></a><span class="co"># This will raise ValidationError</span></span>
<span id="cb33-32"><a href="#cb33-32" aria-hidden="true" tabindex="-1"></a><span class="co"># calc.divide(-5, 2)</span></span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>5.0
4.0</code></pre>
</div>
</div>
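<p>An alternative approach binds every argument, positional or keyword, to its parameter name with <code>inspect.signature(func).bind</code>. This lets the error message name the offending parameter, and <code>self</code> can be skipped by name rather than by position. The sketch below implements it as a hypothetical <code>validate_positive_bound</code> applied to a hypothetical <code>StrictCalculator</code>. It is one way to do it, not the only one.</p>

```python
import functools
import inspect

class ValidationError(Exception):
    pass

def validate_positive_bound(func):
    """Hypothetical variant: check every numeric argument by parameter name."""
    sig = inspect.signature(func)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # bind() maps positional and keyword arguments onto parameter names
        # and raises TypeError for calls that don't match the signature.
        bound = sig.bind(*args, **kwargs)
        for name, value in bound.arguments.items():
            if name == "self":
                continue  # skip the instance itself
            if isinstance(value, (int, float)) and value <= 0:
                raise ValidationError(f"Argument {name}={value} must be positive")
        return func(*args, **kwargs)
    return wrapper

class StrictCalculator:
    @validate_positive_bound
    def divide(self, a, b):
        return a / b

calc = StrictCalculator()
print(calc.divide(10, b=4))  # 2.5
try:
    calc.divide(a=-5, b=2)
except ValidationError as exc:
    print(exc)  # Argument a=-5 must be positive
```

<p>The signature is computed once, when the decorator is applied, rather than on every call.</p>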
</section>
<section id="contextual-decorators" class="level3">
<h3 class="anchored" data-anchor-id="contextual-decorators" id="contextual-decorators">Contextual Decorators</h3>
<div id="93d1b884" class="cell" data-execution_count="19">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb35"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb35-1"><a href="#cb35-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> functools</span>
<span id="cb35-2"><a href="#cb35-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> logging</span>
<span id="cb35-3"><a href="#cb35-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb35-4"><a href="#cb35-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> log_calls(logger<span class="op">=</span><span class="va">None</span>, level<span class="op">=</span>logging.INFO):</span>
<span id="cb35-5"><a href="#cb35-5" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Decorator to log function calls."""</span></span>
<span id="cb35-6"><a href="#cb35-6" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> logger <span class="kw">is</span> <span class="va">None</span>:</span>
<span id="cb35-7"><a href="#cb35-7" aria-hidden="true" tabindex="-1"></a>        logger <span class="op">=</span> logging.getLogger(<span class="va">__name__</span>)</span>
<span id="cb35-8"><a href="#cb35-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb35-9"><a href="#cb35-9" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> decorator(func):</span>
<span id="cb35-10"><a href="#cb35-10" aria-hidden="true" tabindex="-1"></a>        <span class="at">@functools.wraps</span>(func)</span>
<span id="cb35-11"><a href="#cb35-11" aria-hidden="true" tabindex="-1"></a>        <span class="kw">def</span> wrapper(<span class="op">*</span>args, <span class="op">**</span>kwargs):</span>
<span id="cb35-12"><a href="#cb35-12" aria-hidden="true" tabindex="-1"></a>            logger.log(level, <span class="ss">f"Calling </span><span class="sc">{</span>func<span class="sc">.</span><span class="va">__name__</span><span class="sc">}</span><span class="ss"> with args=</span><span class="sc">{</span>args<span class="sc">}</span><span class="ss">, kwargs=</span><span class="sc">{</span>kwargs<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb35-13"><a href="#cb35-13" aria-hidden="true" tabindex="-1"></a>            <span class="cf">try</span>:</span>
<span id="cb35-14"><a href="#cb35-14" aria-hidden="true" tabindex="-1"></a>                result <span class="op">=</span> func(<span class="op">*</span>args, <span class="op">**</span>kwargs)</span>
<span id="cb35-15"><a href="#cb35-15" aria-hidden="true" tabindex="-1"></a>                logger.log(level, <span class="ss">f"</span><span class="sc">{</span>func<span class="sc">.</span><span class="va">__name__</span><span class="sc">}</span><span class="ss"> returned </span><span class="sc">{</span>result<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb35-16"><a href="#cb35-16" aria-hidden="true" tabindex="-1"></a>                <span class="cf">return</span> result</span>
<span id="cb35-17"><a href="#cb35-17" aria-hidden="true" tabindex="-1"></a>            <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:</span>
<span id="cb35-18"><a href="#cb35-18" aria-hidden="true" tabindex="-1"></a>                logger.log(logging.ERROR, <span class="ss">f"</span><span class="sc">{</span>func<span class="sc">.</span><span class="va">__name__</span><span class="sc">}</span><span class="ss"> raised </span><span class="sc">{</span><span class="bu">type</span>(e)<span class="sc">.</span><span class="va">__name__</span><span class="sc">}</span><span class="ss">: </span><span class="sc">{</span>e<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb35-19"><a href="#cb35-19" aria-hidden="true" tabindex="-1"></a>                <span class="cf">raise</span></span>
<span id="cb35-20"><a href="#cb35-20" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> wrapper</span>
<span id="cb35-21"><a href="#cb35-21" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> decorator</span>
<span id="cb35-22"><a href="#cb35-22" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb35-23"><a href="#cb35-23" aria-hidden="true" tabindex="-1"></a><span class="co"># Setup logging</span></span>
<span id="cb35-24"><a href="#cb35-24" aria-hidden="true" tabindex="-1"></a>logging.basicConfig(level<span class="op">=</span>logging.INFO)</span>
<span id="cb35-25"><a href="#cb35-25" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb35-26"><a href="#cb35-26" aria-hidden="true" tabindex="-1"></a><span class="at">@log_calls</span>()</span>
<span id="cb35-27"><a href="#cb35-27" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> calculate_area(width, height):</span>
<span id="cb35-28"><a href="#cb35-28" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Calculate area of a rectangle."""</span></span>
<span id="cb35-29"><a href="#cb35-29" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> width <span class="op">*</span> height</span>
<span id="cb35-30"><a href="#cb35-30" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb35-31"><a href="#cb35-31" aria-hidden="true" tabindex="-1"></a><span class="at">@log_calls</span>(level<span class="op">=</span>logging.DEBUG)</span>
<span id="cb35-32"><a href="#cb35-32" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> divide_numbers(a, b):</span>
<span id="cb35-33"><a href="#cb35-33" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Divide two numbers."""</span></span>
<span id="cb35-34"><a href="#cb35-34" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> a <span class="op">/</span> b</span>
<span id="cb35-35"><a href="#cb35-35" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb35-36"><a href="#cb35-36" aria-hidden="true" tabindex="-1"></a><span class="co"># Test the logged functions</span></span>
<span id="cb35-37"><a href="#cb35-37" aria-hidden="true" tabindex="-1"></a>result <span class="op">=</span> calculate_area(<span class="dv">5</span>, <span class="dv">3</span>)</span>
<span id="cb35-38"><a href="#cb35-38" aria-hidden="true" tabindex="-1"></a><span class="co"># result = divide_numbers(10, 0)  # Logs the ZeroDivisionError at ERROR level, then re-raises it</span></span></code></pre></div></div>
<div class="cell-output cell-output-stderr">
<pre><code>INFO:__main__:Calling calculate_area with args=(5, 3), kwargs={}
INFO:__main__:calculate_area returned 15</code></pre>
</div>
</div>
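<p>Because <code>log_calls</code> is a decorator factory, it must be applied with parentheses (<code>@log_calls()</code>). A bare <code>@log_calls</code> would pass the function in as <code>logger</code>, and the first call would then fail with a <code>TypeError</code>. A common workaround accepts the function as an optional first argument and the settings as keyword-only arguments, so both spellings work. The sketch below implements this as a hypothetical <code>log_calls_flexible</code>. It also passes <code>%</code>-style arguments to <code>logger.log</code> instead of an f-string, so the message is only formatted when the level is enabled.</p>

```python
import functools
import logging

def log_calls_flexible(func=None, *, level=logging.INFO, logger=None):
    """Hypothetical decorator usable as @log_calls_flexible or @log_calls_flexible(...)."""
    if func is None:
        # Called with keyword arguments only: return a decorator still waiting for func.
        return functools.partial(log_calls_flexible, level=level, logger=logger)

    log = logger if logger is not None else logging.getLogger(func.__module__)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # %-style arguments are only interpolated if `level` is enabled for `log`.
        log.log(level, "Calling %s with args=%r, kwargs=%r", func.__name__, args, kwargs)
        return func(*args, **kwargs)
    return wrapper

@log_calls_flexible                       # bare form
def add(a, b):
    return a + b

@log_calls_flexible(level=logging.DEBUG)  # configured form
def subtract(a, b):
    return a - b

print(add(2, 3), subtract(5, 1))  # 5 4
```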
</section>
</section>
<section id="advanced-features" class="level2">
<h2 class="anchored" data-anchor-id="advanced-features" id="advanced-features">Advanced Features</h2>
<section id="functools.singledispatch" class="level3">
<h3 class="anchored" data-anchor-id="functools.singledispatch" id="functools.singledispatch">functools.singledispatch</h3>
<p>Creates generic functions that behave differently based on the type of their first argument.</p>
<div id="5e8c43a0" class="cell" data-execution_count="20">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb37"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb37-1"><a href="#cb37-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> functools</span>
<span id="cb37-2"><a href="#cb37-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb37-3"><a href="#cb37-3" aria-hidden="true" tabindex="-1"></a><span class="at">@functools.singledispatch</span></span>
<span id="cb37-4"><a href="#cb37-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> process_data(data):</span>
<span id="cb37-5"><a href="#cb37-5" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Default implementation for unknown types"""</span></span>
<span id="cb37-6"><a href="#cb37-6" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> <span class="ss">f"Processing unknown type: </span><span class="sc">{</span><span class="bu">type</span>(data)<span class="sc">}</span><span class="ss">"</span></span>
<span id="cb37-7"><a href="#cb37-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb37-8"><a href="#cb37-8" aria-hidden="true" tabindex="-1"></a><span class="at">@process_data.register</span>(<span class="bu">str</span>)</span>
<span id="cb37-9"><a href="#cb37-9" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> _(data):</span>
<span id="cb37-10"><a href="#cb37-10" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> <span class="ss">f"Processing string: '</span><span class="sc">{</span>data<span class="sc">}</span><span class="ss">'"</span></span>
<span id="cb37-11"><a href="#cb37-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb37-12"><a href="#cb37-12" aria-hidden="true" tabindex="-1"></a><span class="at">@process_data.register</span>(<span class="bu">list</span>)</span>
<span id="cb37-13"><a href="#cb37-13" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> _(data):</span>
<span id="cb37-14"><a href="#cb37-14" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> <span class="ss">f"Processing list of </span><span class="sc">{</span><span class="bu">len</span>(data)<span class="sc">}</span><span class="ss"> items"</span></span>
<span id="cb37-15"><a href="#cb37-15" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb37-16"><a href="#cb37-16" aria-hidden="true" tabindex="-1"></a><span class="at">@process_data.register</span>(<span class="bu">dict</span>)</span>
<span id="cb37-17"><a href="#cb37-17" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> _(data):</span>
<span id="cb37-18"><a href="#cb37-18" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> <span class="ss">f"Processing dict with keys: </span><span class="sc">{</span><span class="bu">list</span>(data.keys())<span class="sc">}</span><span class="ss">"</span></span>
<span id="cb37-19"><a href="#cb37-19" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb37-20"><a href="#cb37-20" aria-hidden="true" tabindex="-1"></a><span class="at">@process_data.register</span>(<span class="bu">int</span>)</span>
<span id="cb37-21"><a href="#cb37-21" aria-hidden="true" tabindex="-1"></a><span class="at">@process_data.register</span>(<span class="bu">float</span>)</span>
<span id="cb37-22"><a href="#cb37-22" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> _(data):</span>
<span id="cb37-23"><a href="#cb37-23" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> <span class="ss">f"Processing number: </span><span class="sc">{</span>data<span class="sc">}</span><span class="ss">"</span></span>
<span id="cb37-24"><a href="#cb37-24" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb37-25"><a href="#cb37-25" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage</span></span>
<span id="cb37-26"><a href="#cb37-26" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(process_data(<span class="st">"hello"</span>))           <span class="co"># Processing string: 'hello'</span></span>
<span id="cb37-27"><a href="#cb37-27" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(process_data([<span class="dv">1</span>, <span class="dv">2</span>, <span class="dv">3</span>]))         <span class="co"># Processing list of 3 items</span></span>
<span id="cb37-28"><a href="#cb37-28" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(process_data({<span class="st">"a"</span>: <span class="dv">1</span>, <span class="st">"b"</span>: <span class="dv">2</span>}))  <span class="co"># Processing dict with keys: ['a', 'b']</span></span>
<span id="cb37-29"><a href="#cb37-29" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(process_data(<span class="dv">42</span>))                <span class="co"># Processing number: 42</span></span>
<span id="cb37-30"><a href="#cb37-30" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(process_data(<span class="fl">3.14</span>))              <span class="co"># Processing number: 3.14</span></span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>Processing string: 'hello'
Processing list of 3 items
Processing dict with keys: ['a', 'b']
Processing number: 42
Processing number: 3.14</code></pre>
</div>
</div>
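<p>Dispatch follows the argument type's method resolution order. Subclasses of a registered type reuse its implementation, and anything unmatched falls back to the base function. The sketch below uses a hypothetical <code>describe</code> function with annotation-based registration (Python 3.7+). It also uses the <code>dispatch()</code> method and the <code>registry</code> mapping to inspect which implementation will run.</p>

```python
import functools
from collections import OrderedDict

@functools.singledispatch
def describe(obj):
    return "generic object"

@describe.register
def _(obj: int):   # type taken from the annotation
    return "integer"

@describe.register
def _(obj: dict):
    return "mapping"

print(describe(True))              # integer (bool is a subclass of int)
print(describe(OrderedDict(a=1)))  # mapping (OrderedDict is a subclass of dict)
print(describe(2.5))               # generic object (no float implementation)

# dispatch() returns the implementation chosen for a type; registry holds them all.
print(describe.dispatch(bool) is describe.registry[int])  # True
print(sorted(cls.__name__ for cls in describe.registry))  # ['dict', 'int', 'object']
```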
</section>
<section id="functools.singledispatchmethod" class="level3">
<h3 class="anchored" data-anchor-id="functools.singledispatchmethod" id="functools.singledispatchmethod">functools.singledispatchmethod</h3>
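<p>The method counterpart of <code>singledispatch</code> (added in Python 3.8). Dispatch is based on the type of the first argument after <code>self</code>.</p>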
<div id="7fba186c" class="cell" data-execution_count="21">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb39"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb39-1"><a href="#cb39-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> functools</span>
<span id="cb39-2"><a href="#cb39-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb39-3"><a href="#cb39-3" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> DataProcessor:</span>
<span id="cb39-4"><a href="#cb39-4" aria-hidden="true" tabindex="-1"></a>    <span class="at">@functools.singledispatchmethod</span></span>
<span id="cb39-5"><a href="#cb39-5" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> process(<span class="va">self</span>, data):</span>
<span id="cb39-6"><a href="#cb39-6" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="ss">f"Default processing for </span><span class="sc">{</span><span class="bu">type</span>(data)<span class="sc">}</span><span class="ss">"</span></span>
<span id="cb39-7"><a href="#cb39-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb39-8"><a href="#cb39-8" aria-hidden="true" tabindex="-1"></a>    <span class="at">@process.register</span></span>
<span id="cb39-9"><a href="#cb39-9" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _(<span class="va">self</span>, data: <span class="bu">str</span>):</span>
<span id="cb39-10"><a href="#cb39-10" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="ss">f"String processing: </span><span class="sc">{</span>data<span class="sc">.</span>upper()<span class="sc">}</span><span class="ss">"</span></span>
<span id="cb39-11"><a href="#cb39-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb39-12"><a href="#cb39-12" aria-hidden="true" tabindex="-1"></a>    <span class="at">@process.register</span></span>
<span id="cb39-13"><a href="#cb39-13" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _(<span class="va">self</span>, data: <span class="bu">list</span>):</span>
<span id="cb39-14"><a href="#cb39-14" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="ss">f"List processing: </span><span class="sc">{</span><span class="bu">sum</span>(data) <span class="cf">if</span> <span class="bu">all</span>(<span class="bu">isinstance</span>(x, (<span class="bu">int</span>, <span class="bu">float</span>)) <span class="cf">for</span> x <span class="kw">in</span> data) <span class="cf">else</span> <span class="st">'mixed types'</span><span class="sc">}</span><span class="ss">"</span></span>
<span id="cb39-15"><a href="#cb39-15" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb39-16"><a href="#cb39-16" aria-hidden="true" tabindex="-1"></a>    <span class="at">@process.register</span></span>
<span id="cb39-17"><a href="#cb39-17" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _(<span class="va">self</span>, data: <span class="bu">dict</span>):</span>
<span id="cb39-18"><a href="#cb39-18" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="ss">f"Dict processing: </span><span class="sc">{</span><span class="bu">len</span>(data)<span class="sc">}</span><span class="ss"> items"</span></span>
<span id="cb39-19"><a href="#cb39-19" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb39-20"><a href="#cb39-20" aria-hidden="true" tabindex="-1"></a>processor <span class="op">=</span> DataProcessor()</span>
<span id="cb39-21"><a href="#cb39-21" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(processor.process(<span class="st">"hello"</span>))      <span class="co"># String processing: HELLO</span></span>
<span id="cb39-22"><a href="#cb39-22" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(processor.process([<span class="dv">1</span>, <span class="dv">2</span>, <span class="dv">3</span>, <span class="dv">4</span>])) <span class="co"># List processing: 10</span></span>
<span id="cb39-23"><a href="#cb39-23" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(processor.process({<span class="st">"a"</span>: <span class="dv">1</span>}))     <span class="co"># Dict processing: 1 items</span></span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>String processing: HELLO
List processing: 10
Dict processing: 1 items</code></pre>
</div>
</div>
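<p>As with <code>singledispatch</code>, <code>register</code> on a <code>singledispatchmethod</code> also accepts the type as an explicit argument instead of reading it from an annotation. The hypothetical <code>Formatter</code> class below uses that form and falls back to <code>repr</code> for unregistered types.</p>

```python
import functools

class Formatter:
    @functools.singledispatchmethod
    def render(self, value):
        # Fallback for any type without a registered implementation.
        return repr(value)

    @render.register(int)    # explicit type, no annotation needed
    def _(self, value):
        return f"{value:,}"  # thousands separator

    @render.register(float)
    def _(self, value):
        return f"{value:.2f}"  # two decimal places

fmt = Formatter()
print(fmt.render(1234567))  # 1,234,567
print(fmt.render(3.14159))  # 3.14
print(fmt.render("hi"))     # 'hi'
```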
</section>
</section>
<section id="best-practices" class="level2">
<h2 class="anchored" data-anchor-id="best-practices" id="best-practices">Best Practices</h2>
<section id="use-functools.wraps-in-custom-decorators" class="level3">
<h3 class="anchored" data-anchor-id="use-functools.wraps-in-custom-decorators" id="use-functools.wraps-in-custom-decorators">1. Use <code>@functools.wraps</code> in Custom Decorators</h3>
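<p>Always use <code>@functools.wraps</code> when writing a decorator. It copies the wrapped function’s metadata (such as <code>__name__</code>, <code>__qualname__</code>, and <code>__doc__</code>) onto the wrapper and sets <code>__wrapped__</code>, so introspection, debugging, and documentation tools still see the original function:</p>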
<div id="8f2f4e49" class="cell" data-execution_count="22">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb41"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb41-1"><a href="#cb41-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> functools</span>
<span id="cb41-2"><a href="#cb41-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb41-3"><a href="#cb41-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Good</span></span>
<span id="cb41-4"><a href="#cb41-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> my_decorator(func):</span>
<span id="cb41-5"><a href="#cb41-5" aria-hidden="true" tabindex="-1"></a>    <span class="at">@functools.wraps</span>(func)</span>
<span id="cb41-6"><a href="#cb41-6" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> wrapper(<span class="op">*</span>args, <span class="op">**</span>kwargs):</span>
<span id="cb41-7"><a href="#cb41-7" aria-hidden="true" tabindex="-1"></a>        <span class="co"># decorator logic here</span></span>
<span id="cb41-8"><a href="#cb41-8" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> func(<span class="op">*</span>args, <span class="op">**</span>kwargs)</span>
<span id="cb41-9"><a href="#cb41-9" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> wrapper</span>
<span id="cb41-10"><a href="#cb41-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb41-11"><a href="#cb41-11" aria-hidden="true" tabindex="-1"></a><span class="co"># Bad - loses function metadata</span></span>
<span id="cb41-12"><a href="#cb41-12" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> bad_decorator(func):</span>
<span id="cb41-13"><a href="#cb41-13" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> wrapper(<span class="op">*</span>args, <span class="op">**</span>kwargs):</span>
<span id="cb41-14"><a href="#cb41-14" aria-hidden="true" tabindex="-1"></a>        <span class="co"># decorator logic here</span></span>
<span id="cb41-15"><a href="#cb41-15" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> func(<span class="op">*</span>args, <span class="op">**</span>kwargs)</span>
<span id="cb41-16"><a href="#cb41-16" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> wrapper</span></code></pre></div></div>
</div>
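<p>To see the difference, here is a minimal sketch that re-declares both decorators from the example above (so the snippet runs on its own), applies them to two small illustrative functions, <code>greet</code> and <code>farewell</code>, and inspects their metadata:</p>

```python
import functools

def my_decorator(func):
    @functools.wraps(func)  # copy __name__, __doc__, etc. and set __wrapped__
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper

def bad_decorator(func):
    def wrapper(*args, **kwargs):  # no wraps: the wrapper's own metadata is exposed
        return func(*args, **kwargs)
    return wrapper

@my_decorator
def greet(name):
    """Return a greeting."""
    return f"Hello, {name}"

@bad_decorator
def farewell(name):
    """Return a farewell."""
    return f"Goodbye, {name}"

print(greet.__name__, greet.__doc__)        # greet Return a greeting.
print(farewell.__name__, farewell.__doc__)  # wrapper None
print(greet.__wrapped__("Ada"))             # Hello, Ada (calls the undecorated function)
```

<p>With the bad version, tracebacks, logs, and <code>help()</code> all report a function called <code>wrapper</code> with no docstring.</p>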
</section>
<section id="choose-appropriate-cache-sizes" class="level3">
<h3 class="anchored" data-anchor-id="choose-appropriate-cache-sizes" id="choose-appropriate-cache-sizes">2. Choose Appropriate Cache Sizes</h3>
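<p>For <code>lru_cache</code>, <code>maxsize</code> bounds how many results are kept; once the cache is full, the least recently used entry is evicted. Size it to cover the set of arguments your program actually repeats, and remember that every cached result stays in memory until it is evicted or the cache is cleared:</p>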
<div id="fa5752ce" class="cell" data-execution_count="23">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb42"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb42-1"><a href="#cb42-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> functools</span>
<span id="cb42-2"><a href="#cb42-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb42-3"><a href="#cb42-3" aria-hidden="true" tabindex="-1"></a><span class="co"># For small, frequently accessed data</span></span>
<span id="cb42-4"><a href="#cb42-4" aria-hidden="true" tabindex="-1"></a><span class="at">@functools.lru_cache</span>(maxsize<span class="op">=</span><span class="dv">32</span>)</span>
<span id="cb42-5"><a href="#cb42-5" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> get_user_preferences(user_id):</span>
<span id="cb42-6"><a href="#cb42-6" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Small cache for user data</span></span>
<span id="cb42-7"><a href="#cb42-7" aria-hidden="true" tabindex="-1"></a>    <span class="cf">pass</span></span>
<span id="cb42-8"><a href="#cb42-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb42-9"><a href="#cb42-9" aria-hidden="true" tabindex="-1"></a><span class="co"># For larger datasets or expensive computations</span></span>
<span id="cb42-10"><a href="#cb42-10" aria-hidden="true" tabindex="-1"></a><span class="at">@functools.lru_cache</span>(maxsize<span class="op">=</span><span class="dv">1024</span>)</span>
<span id="cb42-11"><a href="#cb42-11" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> complex_calculation(x, y, z):</span>
<span id="cb42-12"><a href="#cb42-12" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Larger cache for expensive operations</span></span>
<span id="cb42-13"><a href="#cb42-13" aria-hidden="true" tabindex="-1"></a>    <span class="cf">pass</span></span>
<span id="cb42-14"><a href="#cb42-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb42-15"><a href="#cb42-15" aria-hidden="true" tabindex="-1"></a><span class="co"># For unlimited caching (use with caution)</span></span>
<span id="cb42-16"><a href="#cb42-16" aria-hidden="true" tabindex="-1"></a><span class="at">@functools.cache</span></span>
<span id="cb42-17"><a href="#cb42-17" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> constant_computation(x):</span>
<span id="cb42-18"><a href="#cb42-18" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Only for pure functions whose set of distinct inputs stays small</span></span>
<span id="cb42-19"><a href="#cb42-19" aria-hidden="true" tabindex="-1"></a>    <span class="cf">pass</span></span></code></pre></div></div>
</div>
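<p>A small sketch of eviction, assuming a toy <code>square</code> function and a deliberately tiny <code>maxsize=2</code> so the least-recently-used behavior is visible in <code>cache_info()</code>:</p>

```python
import functools

@functools.lru_cache(maxsize=2)
def square(n):
    return n * n

square(1)  # miss: cache holds {1}
square(2)  # miss: cache holds {1, 2}
square(1)  # hit: 1 becomes the most recently used entry
square(3)  # miss: cache is full, so 2 (least recently used) is evicted
square(2)  # miss again, because 2 was evicted

print(square.cache_info())
# CacheInfo(hits=1, misses=4, maxsize=2, currsize=2)
```

<p>If a cache this small sits in front of a workload that cycles through more distinct arguments than <code>maxsize</code>, nearly every call becomes a miss, so the hit count is the number to watch when tuning.</p>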
</section>
<section id="choose-the-right-caching-strategy" class="level3">
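<h3 class="anchored" data-anchor-id="choose-the-right-caching-strategy" id="choose-the-right-caching-strategy">3. Choose the Right Caching Strategy</h3>
<p>Pick the caching tool by what you are caching: <code>@functools.cache</code> (Python 3.9+) is an unbounded cache, <code>@functools.lru_cache</code> adds a size limit with least-recently-used eviction, and <code>@functools.cached_property</code> computes a per-instance attribute once and stores it on the instance:</p>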
<div id="281d0ad5" class="cell" data-execution_count="24">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb43"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb43-1"><a href="#cb43-1" aria-hidden="true" tabindex="-1"></a><span class="co"># For simple cases without arguments</span></span>
<span id="cb43-2"><a href="#cb43-2" aria-hidden="true" tabindex="-1"></a><span class="at">@functools.cache</span></span>
<span id="cb43-3"><a href="#cb43-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> simple_function():</span>
<span id="cb43-4"><a href="#cb43-4" aria-hidden="true" tabindex="-1"></a>    <span class="cf">pass</span></span>
<span id="cb43-5"><a href="#cb43-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb43-6"><a href="#cb43-6" aria-hidden="true" tabindex="-1"></a><span class="co"># For functions with arguments and limited cache size</span></span>
<span id="cb43-7"><a href="#cb43-7" aria-hidden="true" tabindex="-1"></a><span class="at">@functools.lru_cache</span>(maxsize<span class="op">=</span><span class="dv">128</span>)</span>
<span id="cb43-8"><a href="#cb43-8" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> complex_function(x, y):</span>
<span id="cb43-9"><a href="#cb43-9" aria-hidden="true" tabindex="-1"></a>    <span class="cf">pass</span></span>
<span id="cb43-10"><a href="#cb43-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb43-11"><a href="#cb43-11" aria-hidden="true" tabindex="-1"></a><span class="co"># For per-instance values computed once, on first access</span></span>
<span id="cb43-12"><a href="#cb43-12" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MyClass:</span>
<span id="cb43-13"><a href="#cb43-13" aria-hidden="true" tabindex="-1"></a>    <span class="at">@functools.cached_property</span></span>
<span id="cb43-14"><a href="#cb43-14" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> expensive_property(<span class="va">self</span>):</span>
<span id="cb43-15"><a href="#cb43-15" aria-hidden="true" tabindex="-1"></a>        <span class="cf">pass</span></span></code></pre></div></div>
</div>
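<p>As a sketch of how <code>cached_property</code> behaves, the hypothetical <code>Dataset</code> class below counts how often its property body runs. The value is computed on first access, stored in the instance <code>__dict__</code>, and can be reset by deleting the attribute:</p>

```python
import functools

class Dataset:
    def __init__(self, values):
        self.values = values
        self.computations = 0  # counts how often the property body actually runs

    @functools.cached_property
    def total(self):
        self.computations += 1
        return sum(self.values)

ds = Dataset([1, 2, 3])
print(ds.total, ds.total)  # 6 6
print(ds.computations)     # 1: the second access read the cached value

del ds.total               # deleting the attribute clears the cached value
ds.values.append(4)
print(ds.total)            # 10: recomputed from the updated list
print(ds.computations)     # 2
```

<p>Because the cached value is not refreshed automatically, <code>cached_property</code> fits best when the inputs do not change after first access, or when you control invalidation as shown here.</p>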
</section>
<section id="use-partial-functions-for-configuration" class="level3">
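<h3 class="anchored" data-anchor-id="use-partial-functions-for-configuration" id="use-partial-functions-for-configuration">4. Use Partial Functions for Configuration</h3>
<p><code>functools.partial</code> freezes some of a function’s arguments and returns a new callable, which makes it a lightweight way to build preconfigured variants of a general-purpose function:</p>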
<div id="ba907d4b" class="cell" data-execution_count="25">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb44"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb44-1"><a href="#cb44-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> functools</span>
<span id="cb44-2"><a href="#cb44-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> json</span>
<span id="cb44-3"><a href="#cb44-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb44-4"><a href="#cb44-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> make_api_call(base_url, endpoint, headers<span class="op">=</span><span class="va">None</span>, timeout<span class="op">=</span><span class="dv">30</span>):</span>
<span id="cb44-5"><a href="#cb44-5" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Make an API call with configurable parameters."""</span></span>
<span id="cb44-6"><a href="#cb44-6" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Implementation here</span></span>
<span id="cb44-7"><a href="#cb44-7" aria-hidden="true" tabindex="-1"></a>    <span class="cf">pass</span></span>
<span id="cb44-8"><a href="#cb44-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb44-9"><a href="#cb44-9" aria-hidden="true" tabindex="-1"></a><span class="co"># Create configured API callers</span></span>
<span id="cb44-10"><a href="#cb44-10" aria-hidden="true" tabindex="-1"></a>api_v1 <span class="op">=</span> functools.partial(</span>
<span id="cb44-11"><a href="#cb44-11" aria-hidden="true" tabindex="-1"></a>    make_api_call,</span>
<span id="cb44-12"><a href="#cb44-12" aria-hidden="true" tabindex="-1"></a>    base_url<span class="op">=</span><span class="st">"https://api.example.com/v1"</span>,</span>
<span id="cb44-13"><a href="#cb44-13" aria-hidden="true" tabindex="-1"></a>    headers<span class="op">=</span>{<span class="st">"Authorization"</span>: <span class="st">"Bearer token123"</span>}</span>
<span id="cb44-14"><a href="#cb44-14" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb44-15"><a href="#cb44-15" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb44-16"><a href="#cb44-16" aria-hidden="true" tabindex="-1"></a>api_v2 <span class="op">=</span> functools.partial(</span>
<span id="cb44-17"><a href="#cb44-17" aria-hidden="true" tabindex="-1"></a>    make_api_call,</span>
<span id="cb44-18"><a href="#cb44-18" aria-hidden="true" tabindex="-1"></a>    base_url<span class="op">=</span><span class="st">"https://api.example.com/v2"</span>,</span>
<span id="cb44-19"><a href="#cb44-19" aria-hidden="true" tabindex="-1"></a>    headers<span class="op">=</span>{<span class="st">"Authorization"</span>: <span class="st">"Bearer token456"</span>},</span>
<span id="cb44-20"><a href="#cb44-20" aria-hidden="true" tabindex="-1"></a>    timeout<span class="op">=</span><span class="dv">60</span></span>
<span id="cb44-21"><a href="#cb44-21" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb44-22"><a href="#cb44-22" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb44-23"><a href="#cb44-23" aria-hidden="true" tabindex="-1"></a><span class="co"># Use the configured functions; base_url was bound by keyword, so pass endpoint by keyword too</span></span>
<span id="cb44-24"><a href="#cb44-24" aria-hidden="true" tabindex="-1"></a><span class="co"># result1 = api_v1(endpoint="/users")</span></span>
<span id="cb44-25"><a href="#cb44-25" aria-hidden="true" tabindex="-1"></a><span class="co"># result2 = api_v2(endpoint="/products")</span></span></code></pre></div></div>
</div>
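<p>One subtlety: binding <code>base_url</code> by keyword means a later positional argument still fills the first positional slot, which is <code>base_url</code>, so <code>api_v1("/users")</code> raises <code>TypeError</code>. The sketch below uses a hypothetical stub of <code>make_api_call</code> that returns a string instead of making a request, to show the failure and two ways around it. It also inspects the <code>func</code>, <code>args</code>, and <code>keywords</code> attributes that every <code>partial</code> object carries:</p>

```python
import functools

def make_api_call(base_url, endpoint, headers=None, timeout=30):
    # Hypothetical stub: describes the request instead of sending it
    return f"GET {base_url}{endpoint} (timeout={timeout})"

api_v1 = functools.partial(make_api_call, base_url="https://api.example.com/v1")

# Works: endpoint passed by keyword
print(api_v1(endpoint="/users"))
# GET https://api.example.com/v1/users (timeout=30)

# Fails: "/users" is bound to base_url positionally, which is already set by keyword
try:
    api_v1("/users")
except TypeError as exc:
    print(f"TypeError: {exc}")

# Alternative: bind base_url positionally so endpoint can be positional as well
api_v1_pos = functools.partial(make_api_call, "https://api.example.com/v1")
print(api_v1_pos("/users", timeout=5))
# GET https://api.example.com/v1/users (timeout=5)

# partial objects expose what they have frozen
print(api_v1.func.__name__, api_v1.args, api_v1.keywords)
```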
</section>
<section id="performance-considerations" class="level3">
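<h3 class="anchored" data-anchor-id="performance-considerations" id="performance-considerations">5. Performance Considerations</h3>
<p>Measure before and after adding a cache. The <code>cache_info()</code> method reports hits, misses, and the current cache size, which shows whether your call pattern actually repeats arguments:</p>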
<div id="fea402f4" class="cell" data-execution_count="26">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb45"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb45-1"><a href="#cb45-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> functools</span>
<span id="cb45-2"><a href="#cb45-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb45-3"><a href="#cb45-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb45-4"><a href="#cb45-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Measure cache performance</span></span>
<span id="cb45-5"><a href="#cb45-5" aria-hidden="true" tabindex="-1"></a><span class="at">@functools.lru_cache</span>(maxsize<span class="op">=</span><span class="dv">1000</span>)</span>
<span id="cb45-6"><a href="#cb45-6" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> expensive_function(n):</span>
<span id="cb45-7"><a href="#cb45-7" aria-hidden="true" tabindex="-1"></a>    time.sleep(<span class="fl">0.01</span>)  <span class="co"># Simulate expensive operation</span></span>
<span id="cb45-8"><a href="#cb45-8" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> n <span class="op">**</span> <span class="dv">2</span></span>
<span id="cb45-9"><a href="#cb45-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb45-10"><a href="#cb45-10" aria-hidden="true" tabindex="-1"></a><span class="co"># 100 calls but only 10 distinct arguments: expect 10 misses and 90 hits</span></span>
<span id="cb45-11"><a href="#cb45-11" aria-hidden="true" tabindex="-1"></a>start <span class="op">=</span> time.perf_counter()</span>
<span id="cb45-12"><a href="#cb45-12" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">100</span>):</span>
<span id="cb45-13"><a href="#cb45-13" aria-hidden="true" tabindex="-1"></a>    expensive_function(i <span class="op">%</span> <span class="dv">10</span>)  <span class="co"># Only 10 unique values</span></span>
<span id="cb45-14"><a href="#cb45-14" aria-hidden="true" tabindex="-1"></a>end <span class="op">=</span> time.perf_counter()</span>
<span id="cb45-15"><a href="#cb45-15" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb45-16"><a href="#cb45-16" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Time taken: </span><span class="sc">{</span>end <span class="op">-</span> start<span class="sc">:.4f}</span><span class="ss"> seconds"</span>)</span>
<span id="cb45-17"><a href="#cb45-17" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Cache info: </span><span class="sc">{</span>expensive_function<span class="sc">.</span>cache_info()<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>Time taken: 0.1245 seconds
Cache info: CacheInfo(hits=90, misses=10, maxsize=1000, currsize=10)</code></pre>
</div>
</div>
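<p>The timing above covers only the cached run. Here is a sketch of a direct comparison, assuming the same toy function with a shorter sleep. <code>lru_cache</code> sets <code>__wrapped__</code>, which calls the undecorated function and bypasses the cache, while <code>cache_clear()</code> resets the cache and its statistics before the cached run. Absolute timings will vary by machine:</p>

```python
import functools
import time

@functools.lru_cache(maxsize=1000)
def expensive_function(n):
    time.sleep(0.001)  # simulate an expensive operation
    return n ** 2

def time_calls(func):
    """Call func 100 times with only 10 distinct arguments; return elapsed seconds."""
    start = time.perf_counter()
    for i in range(100):
        func(i % 10)
    return time.perf_counter() - start

uncached = time_calls(expensive_function.__wrapped__)  # 100 real calls, cache untouched
expensive_function.cache_clear()                       # empty cache, reset counters
cached = time_calls(expensive_function)                # 10 real calls, 90 cache hits

print(f"uncached: {uncached:.4f}s, cached: {cached:.4f}s")
print(expensive_function.cache_info())
# CacheInfo(hits=90, misses=10, maxsize=1000, currsize=10)
```

<p><code>time.perf_counter()</code> is used instead of <code>time.time()</code> because it is a monotonic, high-resolution clock intended for measuring short intervals.</p>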
</section>
<section id="combine-multiple-functools-features" class="level3">
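<h3 class="anchored" data-anchor-id="combine-multiple-functools-features" id="combine-multiple-functools-features">6. Combine Multiple functools Features</h3>
<p>The features compose naturally. Below, <code>lru_cache</code> turns the exponential recursive Fibonacci into a linear-time computation, and <code>total_ordering</code> derives <code>&lt;=</code>, <code>&gt;</code>, and <code>&gt;=</code> from the <code>__eq__</code> and <code>__lt__</code> methods the class defines:</p>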
<div id="4a5016d7" class="cell" data-execution_count="27">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb47"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb47-1"><a href="#cb47-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> functools</span>
<span id="cb47-2"><a href="#cb47-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb47-3"><a href="#cb47-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb47-4"><a href="#cb47-4" aria-hidden="true" tabindex="-1"></a><span class="at">@functools.lru_cache</span>(maxsize<span class="op">=</span><span class="dv">128</span>)</span>
<span id="cb47-5"><a href="#cb47-5" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> fibonacci_cached(n):</span>
<span id="cb47-6"><a href="#cb47-6" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Fibonacci with caching."""</span></span>
<span id="cb47-7"><a href="#cb47-7" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> n <span class="op">&lt;</span> <span class="dv">2</span>:</span>
<span id="cb47-8"><a href="#cb47-8" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> n</span>
<span id="cb47-9"><a href="#cb47-9" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> fibonacci_cached(n <span class="op">-</span> <span class="dv">1</span>) <span class="op">+</span> fibonacci_cached(n <span class="op">-</span> <span class="dv">2</span>)</span>
<span id="cb47-10"><a href="#cb47-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb47-11"><a href="#cb47-11" aria-hidden="true" tabindex="-1"></a><span class="co"># A partial with no frozen arguments just forwards every call; freeze arguments to specialize</span></span>
<span id="cb47-12"><a href="#cb47-12" aria-hidden="true" tabindex="-1"></a>fibonacci_alias <span class="op">=</span> functools.partial(fibonacci_cached)</span>
<span id="cb47-13"><a href="#cb47-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb47-14"><a href="#cb47-14" aria-hidden="true" tabindex="-1"></a><span class="co"># Use total_ordering for comparison</span></span>
<span id="cb47-15"><a href="#cb47-15" aria-hidden="true" tabindex="-1"></a><span class="at">@functools.total_ordering</span></span>
<span id="cb47-16"><a href="#cb47-16" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> FibonacciNumber:</span>
<span id="cb47-17"><a href="#cb47-17" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, n):</span>
<span id="cb47-18"><a href="#cb47-18" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.n <span class="op">=</span> n</span>
<span id="cb47-19"><a href="#cb47-19" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.value <span class="op">=</span> fibonacci_cached(n)</span>
<span id="cb47-20"><a href="#cb47-20" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb47-21"><a href="#cb47-21" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__eq__</span>(<span class="va">self</span>, other):</span>
<span id="cb47-22"><a href="#cb47-22" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.value <span class="op">==</span> other.value</span>
<span id="cb47-23"><a href="#cb47-23" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb47-24"><a href="#cb47-24" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__lt__</span>(<span class="va">self</span>, other):</span>
<span id="cb47-25"><a href="#cb47-25" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.value <span class="op">&lt;</span> other.value</span>
<span id="cb47-26"><a href="#cb47-26" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb47-27"><a href="#cb47-27" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__repr__</span>(<span class="va">self</span>):</span>
<span id="cb47-28"><a href="#cb47-28" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="ss">f"Fib(</span><span class="sc">{</span><span class="va">self</span><span class="sc">.</span>n<span class="sc">}</span><span class="ss">) = </span><span class="sc">{</span><span class="va">self</span><span class="sc">.</span>value<span class="sc">}</span><span class="ss">"</span></span>
<span id="cb47-29"><a href="#cb47-29" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb47-30"><a href="#cb47-30" aria-hidden="true" tabindex="-1"></a><span class="co"># Example usage</span></span>
<span id="cb47-31"><a href="#cb47-31" aria-hidden="true" tabindex="-1"></a>fib_numbers <span class="op">=</span> [FibonacciNumber(i) <span class="cf">for</span> i <span class="kw">in</span> [<span class="dv">8</span>, <span class="dv">5</span>, <span class="dv">10</span>, <span class="dv">3</span>]]</span>
<span id="cb47-32"><a href="#cb47-32" aria-hidden="true" tabindex="-1"></a>fib_numbers.sort()</span>
<span id="cb47-33"><a href="#cb47-33" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(fib_numbers)  <span class="co"># Sorted by Fibonacci value</span></span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>[Fib(3) = 2, Fib(5) = 5, Fib(8) = 21, Fib(10) = 55]</code></pre>
</div>
</div>
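<p>When writing comparison methods for <code>total_ordering</code>, it is good practice to return <code>NotImplemented</code> for unsupported operand types so Python can try the reflected operation or fall back to its default behavior. A sketch using a hypothetical <code>Version</code> class that compares dotted version strings numerically:</p>

```python
import functools

@functools.total_ordering
class Version:
    """Hypothetical example: compares dotted version strings part by part."""

    def __init__(self, text):
        self.parts = tuple(int(p) for p in text.split("."))

    def __eq__(self, other):
        if not isinstance(other, Version):
            return NotImplemented  # let Python try the other operand
        return self.parts == other.parts

    def __lt__(self, other):
        if not isinstance(other, Version):
            return NotImplemented
        return self.parts < other.parts

a, b = Version("1.10.0"), Version("1.9.3")
print(a > b, a >= b, a <= b)  # True True False: all three derived by total_ordering
print(max(a, b).parts)        # (1, 10, 0)
print(a == "1.10.0")          # False: unsupported types fall back to identity comparison
```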
</section>
<section id="error-handling-with-functools" class="level3">
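<h3 class="anchored" data-anchor-id="error-handling-with-functools" id="error-handling-with-functools">7. Error Handling with functools</h3>
<p>Decorators are a convenient place to centralize error handling. Because the wrapper below uses <code>@functools.wraps</code>, the decorated <code>calculate_ratio</code> keeps its own name and docstring. Keep in mind that converting an exception into a return value like <code>float('inf')</code> hides the failure from callers, so use this pattern only when a sentinel result is genuinely what they expect:</p>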
<div id="6e6b3b50" class="cell" data-execution_count="28">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb49"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb49-1"><a href="#cb49-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> functools</span>
<span id="cb49-2"><a href="#cb49-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb49-3"><a href="#cb49-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> safe_divide(func):</span>
<span id="cb49-4"><a href="#cb49-4" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Decorator to handle division by zero."""</span></span>
<span id="cb49-5"><a href="#cb49-5" aria-hidden="true" tabindex="-1"></a>    <span class="at">@functools.wraps</span>(func)</span>
<span id="cb49-6"><a href="#cb49-6" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> wrapper(<span class="op">*</span>args, <span class="op">**</span>kwargs):</span>
<span id="cb49-7"><a href="#cb49-7" aria-hidden="true" tabindex="-1"></a>        <span class="cf">try</span>:</span>
<span id="cb49-8"><a href="#cb49-8" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> func(<span class="op">*</span>args, <span class="op">**</span>kwargs)</span>
<span id="cb49-9"><a href="#cb49-9" aria-hidden="true" tabindex="-1"></a>        <span class="cf">except</span> <span class="pp">ZeroDivisionError</span>:</span>
<span id="cb49-10"><a href="#cb49-10" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f"Warning: Division by zero in </span><span class="sc">{</span>func<span class="sc">.</span><span class="va">__name__</span><span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb49-11"><a href="#cb49-11" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> <span class="bu">float</span>(<span class="st">'inf'</span>)</span>
<span id="cb49-12"><a href="#cb49-12" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> wrapper</span>
<span id="cb49-13"><a href="#cb49-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb49-14"><a href="#cb49-14" aria-hidden="true" tabindex="-1"></a><span class="at">@safe_divide</span></span>
<span id="cb49-15"><a href="#cb49-15" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> calculate_ratio(a, b):</span>
<span id="cb49-16"><a href="#cb49-16" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Calculate the ratio of two numbers."""</span></span>
<span id="cb49-17"><a href="#cb49-17" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> a <span class="op">/</span> b</span>
<span id="cb49-18"><a href="#cb49-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb49-19"><a href="#cb49-19" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(calculate_ratio(<span class="dv">10</span>, <span class="dv">2</span>))  <span class="co"># 5.0</span></span>
<span id="cb49-20"><a href="#cb49-20" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(calculate_ratio(<span class="dv">10</span>, <span class="dv">0</span>))  <span class="co"># inf (with warning)</span></span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>5.0
Warning: Division by zero in calculate_ratio
inf</code></pre>
</div>
</div>
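<p>A variation worth considering, sketched under the assumption that you want the fallback value to be configurable: the hypothetical decorator factory <code>on_zero_division</code> below takes the fallback as a parameter and reports the problem through the standard <code>warnings</code> module instead of <code>print</code>, so callers can filter the warning, log it, or escalate it to an error:</p>

```python
import functools
import warnings

def on_zero_division(default):
    """Hypothetical decorator factory: return `default` instead of raising ZeroDivisionError."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except ZeroDivisionError:
                # stacklevel=2 attributes the warning to the caller's line
                warnings.warn(f"Division by zero in {func.__name__}", RuntimeWarning, stacklevel=2)
                return default
        return wrapper
    return decorator

@on_zero_division(default=None)
def calculate_ratio(a, b):
    """Calculate the ratio of two numbers."""
    return a / b

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # record every warning, even repeats
    print(calculate_ratio(10, 2))    # 5.0
    print(calculate_ratio(10, 0))    # None

print(len(caught), caught[0].category.__name__)  # 1 RuntimeWarning
print(calculate_ratio.__name__)                   # calculate_ratio
```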
</section>
<section id="debugging-cached-functions" class="level3">
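<h3 class="anchored" data-anchor-id="debugging-cached-functions" id="debugging-cached-functions">8. Debugging Cached Functions</h3>
<p>Functions decorated with <code>lru_cache</code> expose <code>cache_info()</code>, which you can turn into a hit rate. A low hit rate suggests that the function rarely sees repeated arguments or that <code>maxsize</code> is too small for the working set:</p>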
<div id="59c133d2" class="cell" data-execution_count="29">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb51"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb51-1"><a href="#cb51-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> functools</span>
<span id="cb51-2"><a href="#cb51-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb51-3"><a href="#cb51-3" aria-hidden="true" tabindex="-1"></a><span class="at">@functools.lru_cache</span>(maxsize<span class="op">=</span><span class="dv">128</span>)</span>
<span id="cb51-4"><a href="#cb51-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> debug_function(x):</span>
<span id="cb51-5"><a href="#cb51-5" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Computing for </span><span class="sc">{</span>x<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb51-6"><a href="#cb51-6" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> x <span class="op">*</span> <span class="dv">2</span></span>
<span id="cb51-7"><a href="#cb51-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb51-8"><a href="#cb51-8" aria-hidden="true" tabindex="-1"></a><span class="co"># Monitor cache usage</span></span>
<span id="cb51-9"><a href="#cb51-9" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> print_cache_stats(func):</span>
<span id="cb51-10"><a href="#cb51-10" aria-hidden="true" tabindex="-1"></a>    info <span class="op">=</span> func.cache_info()</span>
<span id="cb51-11"><a href="#cb51-11" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Cache stats for </span><span class="sc">{</span>func<span class="sc">.</span><span class="va">__name__</span><span class="sc">}</span><span class="ss">: </span><span class="sc">{</span>info<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb51-12"><a href="#cb51-12" aria-hidden="true" tabindex="-1"></a>    hit_rate <span class="op">=</span> info.hits <span class="op">/</span> (info.hits <span class="op">+</span> info.misses) <span class="cf">if</span> (info.hits <span class="op">+</span> info.misses) <span class="op">&gt;</span> <span class="dv">0</span> <span class="cf">else</span> <span class="dv">0</span></span>
<span id="cb51-13"><a href="#cb51-13" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Hit rate: </span><span class="sc">{</span>hit_rate<span class="sc">:.2%}</span><span class="ss">"</span>)</span>
<span id="cb51-14"><a href="#cb51-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb51-15"><a href="#cb51-15" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage</span></span>
<span id="cb51-16"><a href="#cb51-16" aria-hidden="true" tabindex="-1"></a>debug_function(<span class="dv">5</span>)</span>
<span id="cb51-17"><a href="#cb51-17" aria-hidden="true" tabindex="-1"></a>debug_function(<span class="dv">5</span>)  <span class="co"># Uses cache</span></span>
<span id="cb51-18"><a href="#cb51-18" aria-hidden="true" tabindex="-1"></a>debug_function(<span class="dv">10</span>)</span>
<span id="cb51-19"><a href="#cb51-19" aria-hidden="true" tabindex="-1"></a>print_cache_stats(debug_function)</span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>Computing for 5
Computing for 10
Cache stats for debug_function: CacheInfo(hits=1, misses=2, maxsize=128, currsize=2)
Hit rate: 33.33%</code></pre>
</div>
</div>
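<p>A few more introspection hooks help when a cache misbehaves. This sketch (Python 3.9+ for <code>cache_parameters()</code>) checks the configuration, resets the cache with <code>cache_clear()</code>, and shows that arguments must be hashable because they become cache keys:</p>

```python
import functools

@functools.lru_cache(maxsize=128)
def debug_function(x):
    return x * 2

debug_function(5)
print(debug_function.cache_parameters())  # {'maxsize': 128, 'typed': False}

debug_function.cache_clear()        # drop all entries and reset the hit/miss counters
print(debug_function.cache_info())  # CacheInfo(hits=0, misses=0, maxsize=128, currsize=0)

try:
    debug_function([1, 2])  # a list is unhashable, so it cannot be a cache key
except TypeError as exc:
    print(f"TypeError: {exc}")

print(debug_function((1, 2)))  # a tuple is hashable: (1, 2, 1, 2)
```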
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>The <code>functools</code> module is an essential tool for Python developers who want to write more efficient, maintainable, and functional code. Key takeaways include:</p>
<ul>
<li>Use <code>@functools.wraps</code> in all custom decorators</li>
<li>Leverage <code>@functools.lru_cache</code> for expensive function calls</li>
<li>Apply <code>functools.partial</code> for function configuration and specialization</li>
<li>Utilize <code>@functools.total_ordering</code> to reduce boilerplate in comparison classes</li>
<li>Employ <code>functools.reduce</code> for complex data transformations</li>
<li>Combine multiple functools features for powerful programming patterns</li>
<li>Apply <code>@functools.cached_property</code> for expensive per-instance properties</li>
<li>Implement <code>@functools.singledispatch</code> (or <code>singledispatchmethod</code> for methods) for type-based function overloading</li>
</ul>
<p>By mastering these tools, you’ll be able to write more elegant and efficient Python code that follows functional programming principles while maintaining readability and performance.</p>


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Complete Guide to Python’s itertools Module]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/python/python-itertools/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/python/python-itertools/</guid>
      <pubDate>Sun, 06 Jul 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>tutorial</category>
      <category>beginner</category>
      <content:encoded><![CDATA[






<section id="complete-guide-to-pythons-itertools-module" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/python/python-itertools/iter.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>The <code>itertools</code> module is a standard-library collection of fast, memory-efficient iterator building blocks. Its functions produce values lazily, one at a time, and combine into efficient loops and data-processing pipelines.</p>
<div id="2ddb35bc" class="cell" data-execution_count="1">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb1"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> itertools</span></code></pre></div></div>
</div>
<section id="other-necessary-imports" class="level3">
<h3 class="anchored" data-anchor-id="other-necessary-imports" id="other-necessary-imports">Other necessary imports</h3>
<div id="23afdf0e" class="cell" data-execution_count="2">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> math</span></code></pre></div></div>
</div>
</section>
</section>
<section id="why-use-itertools" class="level2">
<h2 class="anchored" data-anchor-id="why-use-itertools" id="why-use-itertools">Why Use itertools?</h2>
<ul>
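<li><strong>Memory Efficient</strong>: Produces values lazily, one at a time, instead of building whole lists in memory</li>
<li><strong>Functional Programming</strong>: Supports a declarative style in which loops are assembled from small, reusable tools</li>
<li><strong>Performance</strong>: In CPython the module is implemented in C, so its iterators typically run faster than equivalent pure-Python loops</li>
<li><strong>Composability</strong>: The functions accept and return iterators, so they chain together naturally (for example, <code>islice(cycle(...), n)</code>)</li>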
</ul>
</section>
<section id="categories-of-itertools-functions" class="level2">
<h2 class="anchored" data-anchor-id="categories-of-itertools-functions" id="categories-of-itertools-functions">Categories of itertools Functions</h2>
<p>The itertools module is organized into three main categories:</p>
<ol type="1">
<li><strong>Infinite Iterators</strong>: Generate infinite sequences</li>
<li><strong>Finite Iterators</strong>: Stop once their input sequences are exhausted</li>
<li><strong>Combinatorial Iterators</strong>: Generate combinations and permutations</li>
</ol>
<hr>
</section>
<section id="infinite-iterators" class="level2">
<h2 class="anchored" data-anchor-id="infinite-iterators" id="infinite-iterators">1. Infinite Iterators</h2>
<section id="countstart0-step1" class="level3">
<h3 class="anchored" data-anchor-id="countstart0-step1" id="countstart0-step1">count(start=0, step=1)</h3>
<p>Creates an infinite arithmetic sequence starting from <code>start</code> with increments of <code>step</code>.</p>
<div id="161c4139" class="cell" data-execution_count="3">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> itertools</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Basic counting</span></span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a>counter <span class="op">=</span> itertools.count(<span class="dv">1</span>)</span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(itertools.islice(counter, <span class="dv">5</span>)))  <span class="co"># [1, 2, 3, 4, 5]</span></span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Counting with step</span></span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a>counter <span class="op">=</span> itertools.count(<span class="dv">0</span>, <span class="dv">2</span>)</span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(itertools.islice(counter, <span class="dv">5</span>)))  <span class="co"># [0, 2, 4, 6, 8]</span></span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a><span class="co"># Counting with floats</span></span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a>counter <span class="op">=</span> itertools.count(<span class="fl">0.5</span>, <span class="fl">0.1</span>)</span>
<span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(itertools.islice(counter, <span class="dv">3</span>)))  <span class="co"># [0.5, 0.6, 0.7]</span></span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>[1, 2, 3, 4, 5]
[0, 2, 4, 6, 8]
[0.5, 0.6, 0.7]</code></pre>
</div>
</div>
<p><strong>Use Case</strong>: Generating sequential IDs, page numbers for pagination, or any counter with no fixed upper bound.</p>
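<p>The sketch below illustrates the ID use case with <code>next()</code> and <code>zip()</code>, plus a caveat about float steps. The order names are hypothetical. The float part follows the advice in the official <code>itertools</code> documentation that computing <code>start + step * i</code> can be more accurate than repeated addition.</p>

```python
import itertools
import math

# Hand out sequential IDs on demand with next()
id_gen = itertools.count(1001)
first_id = next(id_gen)   # 1001
second_id = next(id_gen)  # 1002

# Number a finite list: zip() stops when the shorter input (orders) runs out
orders = ['order-a', 'order-b', 'order-c']  # hypothetical data
labeled = dict(zip(itertools.count(1), orders))
print(labeled)  # {1: 'order-a', 2: 'order-b', 3: 'order-c'}

# With float steps, count() adds step repeatedly, so rounding error can
# build up; computing start + step * i avoids that accumulation
start, step = 0.0, 0.1
added = list(itertools.islice(itertools.count(start, step), 10))
scaled = [start + step * i for i in range(10)]
print(math.isclose(added[-1], scaled[-1]))  # True
```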
</section>
<section id="cycleiterable" class="level3">
<h3 class="anchored" data-anchor-id="cycleiterable" id="cycleiterable">cycle(iterable)</h3>
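<p>Infinitely repeats the elements of an iterable. On the first pass, <code>cycle()</code> saves a copy of each element and then replays the saved copies, so it needs memory proportional to the length of the input.</p>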
<div id="593c4e10" class="cell" data-execution_count="4">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a>colors <span class="op">=</span> itertools.cycle([<span class="st">'red'</span>, <span class="st">'green'</span>, <span class="st">'blue'</span>])</span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(itertools.islice(colors, <span class="dv">8</span>)))</span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a><span class="co"># ['red', 'green', 'blue', 'red', 'green', 'blue', 'red', 'green']</span></span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Practical example: Round-robin assignment</span></span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a>tasks <span class="op">=</span> [<span class="st">'task1'</span>, <span class="st">'task2'</span>, <span class="st">'task3'</span>, <span class="st">'task4'</span>]</span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a>workers <span class="op">=</span> itertools.cycle([<span class="st">'Alice'</span>, <span class="st">'Bob'</span>, <span class="st">'Charlie'</span>])</span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a>assignments <span class="op">=</span> <span class="bu">list</span>(<span class="bu">zip</span>(tasks, workers))</span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(assignments)</span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a><span class="co"># [('task1', 'Alice'), ('task2', 'Bob'), ('task3', 'Charlie'), ('task4', 'Alice')]</span></span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>['red', 'green', 'blue', 'red', 'green', 'blue', 'red', 'green']
[('task1', 'Alice'), ('task2', 'Bob'), ('task3', 'Charlie'), ('task4', 'Alice')]</code></pre>
</div>
</div>
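<p>Because <code>cycle()</code> keeps its own saved copy, it can keep looping even over a one-shot generator that could not be iterated twice on its own. A small sketch, where <code>letters()</code> is a hypothetical generator:</p>

```python
import itertools

def letters():
    """Hypothetical one-shot generator: it can only be iterated once."""
    yield 'x'
    yield 'y'

# cycle() stores 'x' and 'y' during the first pass, then replays them
looped = itertools.cycle(letters())
first_five = list(itertools.islice(looped, 5))
print(first_five)  # ['x', 'y', 'x', 'y', 'x']
```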
</section>
<section id="repeatobject-timesnone" class="level3">
<h3 class="anchored" data-anchor-id="repeatobject-timesnone" id="repeatobject-timesnone">repeat(object, times=None)</h3>
<p>Repeats an object either infinitely or a specified number of times. The same object is yielded every time; no copies are made.</p>
<div id="4f16d7bc" class="cell" data-execution_count="5">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Infinite repeat</span></span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a>ones <span class="op">=</span> itertools.repeat(<span class="dv">1</span>)</span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(itertools.islice(ones, <span class="dv">5</span>)))  <span class="co"># [1, 1, 1, 1, 1]</span></span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Finite repeat</span></span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a>zeros <span class="op">=</span> itertools.repeat(<span class="dv">0</span>, <span class="dv">3</span>)</span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(zeros))  <span class="co"># [0, 0, 0]</span></span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a><span class="co"># Default values (caution: all 5 items are the same dict object, not copies)</span></span>
<span id="cb7-10"><a href="#cb7-10" aria-hidden="true" tabindex="-1"></a>default_config <span class="op">=</span> {<span class="st">'debug'</span>: <span class="va">False</span>, <span class="st">'timeout'</span>: <span class="dv">30</span>}</span>
<span id="cb7-11"><a href="#cb7-11" aria-hidden="true" tabindex="-1"></a>configs <span class="op">=</span> <span class="bu">list</span>(itertools.repeat(default_config, <span class="dv">5</span>))</span>
<span id="cb7-12"><a href="#cb7-12" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">len</span>(configs))  <span class="co"># 5</span></span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>[1, 1, 1, 1, 1]
[0, 0, 0]
5</code></pre>
</div>
</div>
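<p>The shared-object behavior matters whenever the repeated value is mutable. The sketch below shows the pitfall, one way to get independent (shallow) copies, and a common idiomatic use of <code>repeat()</code>: supplying a constant argument to <code>map()</code>.</p>

```python
import itertools

default_config = {'debug': False, 'timeout': 30}

# repeat() yields the same dict each time, so changing one entry changes all
shared = list(itertools.repeat(default_config, 3))
shared[0]['debug'] = True
print(all(cfg['debug'] for cfg in shared))  # True

# For independent entries, copy explicitly (dict() makes a shallow copy)
default_config = {'debug': False, 'timeout': 30}
independent = [dict(default_config) for _ in range(3)]
independent[0]['debug'] = True
print([cfg['debug'] for cfg in independent])  # [True, False, False]

# Idiomatic use: pass a constant second argument to map()
squares = list(map(pow, range(5), itertools.repeat(2)))
print(squares)  # [0, 1, 4, 9, 16]
```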
<hr>
</section>
</section>
<section id="finite-iterators" class="level2">
<h2 class="anchored" data-anchor-id="finite-iterators" id="finite-iterators">2. Finite Iterators</h2>
<section id="accumulateiterable-funcoperator.add-initialnone" class="level3">
<h3 class="anchored" data-anchor-id="accumulateiterable-funcoperator.add-initialnone" id="accumulateiterable-funcoperator.add-initialnone">accumulate(iterable, func=operator.add, initial=None)</h3>
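<p>Returns running totals by default, or the running results of any two-argument function passed as <code>func</code>. The optional <code>initial</code> value (keyword-only, Python 3.8+) is emitted first, so the output then has one more item than the input.</p>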
<div id="5fcd167d" class="cell" data-execution_count="6">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb9"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> operator</span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Running sum (default)</span></span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a>numbers <span class="op">=</span> [<span class="dv">1</span>, <span class="dv">2</span>, <span class="dv">3</span>, <span class="dv">4</span>, <span class="dv">5</span>]</span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(itertools.accumulate(numbers)))  <span class="co"># [1, 3, 6, 10, 15]</span></span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Running product</span></span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(itertools.accumulate(numbers, operator.mul)))  <span class="co"># [1, 2, 6, 24, 120]</span></span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Running maximum</span></span>
<span id="cb9-11"><a href="#cb9-11" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(itertools.accumulate([<span class="dv">3</span>, <span class="dv">1</span>, <span class="dv">4</span>, <span class="dv">1</span>, <span class="dv">5</span>], <span class="bu">max</span>)))  <span class="co"># [3, 3, 4, 4, 5]</span></span>
<span id="cb9-12"><a href="#cb9-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-13"><a href="#cb9-13" aria-hidden="true" tabindex="-1"></a><span class="co"># With initial value (Python 3.8+)</span></span>
<span id="cb9-14"><a href="#cb9-14" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(itertools.accumulate([<span class="dv">1</span>, <span class="dv">2</span>, <span class="dv">3</span>], initial<span class="op">=</span><span class="dv">100</span>)))  <span class="co"># [100, 101, 103, 106]</span></span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>[1, 3, 6, 10, 15]
[1, 2, 6, 24, 120]
[3, 3, 4, 4, 5]
[100, 101, 103, 106]</code></pre>
</div>
</div>
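<p>Two further patterns, sketched with made-up data: a running average built from running sums, and a custom <code>func</code>. The function receives the accumulated value so far and the next item; here a hypothetical account balance is floored at zero after each transaction.</p>

```python
import itertools

# Running average: divide each running sum by the number of items seen so far
readings = [2, 4, 6, 8]
running_avg = [
    total / n
    for n, total in enumerate(itertools.accumulate(readings), start=1)
]
print(running_avg)  # [2.0, 3.0, 4.0, 5.0]

# Custom func(accumulated, item): the first item passes through unchanged,
# then each new balance is clamped so it never drops below zero
transactions = [5, -10, 3]
balances = list(itertools.accumulate(transactions, lambda bal, t: max(0, bal + t)))
print(balances)  # [5, 0, 3]
```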
</section>
<section id="chainiterables" class="level3">
<h3 class="anchored" data-anchor-id="chainiterables" id="chainiterables">chain(*iterables)</h3>
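<p>Links several iterables end to end into a single iterator. The alternative constructor <code>chain.from_iterable()</code> does the same for an iterable of iterables, which flattens exactly one level of nesting.</p>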
<div id="1f576729" class="cell" data-execution_count="7">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb11"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Basic chaining</span></span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a>list1 <span class="op">=</span> [<span class="dv">1</span>, <span class="dv">2</span>, <span class="dv">3</span>]</span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a>list2 <span class="op">=</span> [<span class="dv">4</span>, <span class="dv">5</span>, <span class="dv">6</span>]</span>
<span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a>list3 <span class="op">=</span> [<span class="dv">7</span>, <span class="dv">8</span>, <span class="dv">9</span>]</span>
<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a>chained <span class="op">=</span> itertools.chain(list1, list2, list3)</span>
<span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(chained))  <span class="co"># [1, 2, 3, 4, 5, 6, 7, 8, 9]</span></span>
<span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-9"><a href="#cb11-9" aria-hidden="true" tabindex="-1"></a><span class="co"># Chain from iterable</span></span>
<span id="cb11-10"><a href="#cb11-10" aria-hidden="true" tabindex="-1"></a>nested_lists <span class="op">=</span> [[<span class="dv">1</span>, <span class="dv">2</span>], [<span class="dv">3</span>, <span class="dv">4</span>], [<span class="dv">5</span>, <span class="dv">6</span>]]</span>
<span id="cb11-11"><a href="#cb11-11" aria-hidden="true" tabindex="-1"></a>flattened <span class="op">=</span> itertools.chain.from_iterable(nested_lists)</span>
<span id="cb11-12"><a href="#cb11-12" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(flattened))  <span class="co"># [1, 2, 3, 4, 5, 6]</span></span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>[1, 2, 3, 4, 5, 6, 7, 8, 9]
[1, 2, 3, 4, 5, 6]</code></pre>
</div>
</div>
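<p>A short sketch of two details: <code>chain.from_iterable()</code> removes only one level of nesting, and <code>chain()</code> accepts any iterables, including generators and strings, reading each one lazily. The generator <code>evens()</code> is hypothetical.</p>

```python
import itertools

# Only the outer level is flattened; the inner [2, 3] stays a list
nested = [[1, [2, 3]], [4], []]
one_level = list(itertools.chain.from_iterable(nested))
print(one_level)  # [1, [2, 3], 4]

def evens():
    """Hypothetical generator; chain() pulls from it only when consumed."""
    yield from (0, 2, 4)

# A string is iterated character by character
combined = list(itertools.chain(evens(), 'ab'))
print(combined)  # [0, 2, 4, 'a', 'b']
```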
</section>
<section id="compressdata-selectors" class="level3">
<h3 class="anchored" data-anchor-id="compressdata-selectors" id="compressdata-selectors">compress(data, selectors)</h3>
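<p>Keeps only the elements of <code>data</code> whose corresponding value in <code>selectors</code> is truthy. It stops as soon as either <code>data</code> or <code>selectors</code> is exhausted.</p>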
<div id="6b403a23" class="cell" data-execution_count="8">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb13"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a>data <span class="op">=</span> [<span class="st">'A'</span>, <span class="st">'B'</span>, <span class="st">'C'</span>, <span class="st">'D'</span>, <span class="st">'E'</span>]</span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a>selectors <span class="op">=</span> [<span class="dv">1</span>, <span class="dv">0</span>, <span class="dv">1</span>, <span class="dv">0</span>, <span class="dv">1</span>]</span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-4"><a href="#cb13-4" aria-hidden="true" tabindex="-1"></a>filtered <span class="op">=</span> itertools.compress(data, selectors)</span>
<span id="cb13-5"><a href="#cb13-5" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(filtered))  <span class="co"># ['A', 'C', 'E']</span></span>
<span id="cb13-6"><a href="#cb13-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-7"><a href="#cb13-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Practical example: Filtering based on conditions</span></span>
<span id="cb13-8"><a href="#cb13-8" aria-hidden="true" tabindex="-1"></a>names <span class="op">=</span> [<span class="st">'Alice'</span>, <span class="st">'Bob'</span>, <span class="st">'Charlie'</span>, <span class="st">'David'</span>]</span>
<span id="cb13-9"><a href="#cb13-9" aria-hidden="true" tabindex="-1"></a>ages <span class="op">=</span> [<span class="dv">25</span>, <span class="dv">17</span>, <span class="dv">30</span>, <span class="dv">16</span>]</span>
<span id="cb13-10"><a href="#cb13-10" aria-hidden="true" tabindex="-1"></a>adults <span class="op">=</span> [age <span class="op">&gt;=</span> <span class="dv">18</span> <span class="cf">for</span> age <span class="kw">in</span> ages]</span>
<span id="cb13-11"><a href="#cb13-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-12"><a href="#cb13-12" aria-hidden="true" tabindex="-1"></a>adult_names <span class="op">=</span> itertools.compress(names, adults)</span>
<span id="cb13-13"><a href="#cb13-13" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(adult_names))  <span class="co"># ['Alice', 'Charlie']</span></span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>['A', 'C', 'E']
['Alice', 'Charlie']</code></pre>
</div>
</div>
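<p>The sketch below shows the early stop when <code>selectors</code> is shorter than <code>data</code>, and how pairing <code>compress()</code> with <code>cycle()</code> produces a repeating selection pattern.</p>

```python
import itertools

# Selectors run out after three values, so 'D', 'E', 'F' are never examined
short = list(itertools.compress('ABCDEF', [1, 0, 1]))
print(short)  # ['A', 'C']

# An infinite repeating mask keeps every third item
every_third = list(itertools.compress('ABCDEFG', itertools.cycle([1, 0, 0])))
print(every_third)  # ['A', 'D', 'G']
```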
</section>
<section id="dropwhilepredicate-iterable" class="level3">
<h3 class="anchored" data-anchor-id="dropwhilepredicate-iterable" id="dropwhilepredicate-iterable">dropwhile(predicate, iterable)</h3>
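<p>Skips elements from the start of the iterable as long as <code>predicate</code> is true. After the first element for which it is false, every remaining element is returned without checking the predicate again, which is why the trailing <code>'# inline comment'</code> survives below.</p>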
<div id="d406906a" class="cell" data-execution_count="9">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb15"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb15-1"><a href="#cb15-1" aria-hidden="true" tabindex="-1"></a>numbers <span class="op">=</span> [<span class="dv">1</span>, <span class="dv">3</span>, <span class="dv">5</span>, <span class="dv">8</span>, <span class="dv">9</span>, <span class="dv">10</span>, <span class="dv">12</span>]</span>
<span id="cb15-2"><a href="#cb15-2" aria-hidden="true" tabindex="-1"></a>result <span class="op">=</span> itertools.dropwhile(<span class="kw">lambda</span> x: x <span class="op">&lt;</span> <span class="dv">8</span>, numbers)</span>
<span id="cb15-3"><a href="#cb15-3" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(result))  <span class="co"># [8, 9, 10, 12]</span></span>
<span id="cb15-4"><a href="#cb15-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-5"><a href="#cb15-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Practical example: Skip header lines</span></span>
<span id="cb15-6"><a href="#cb15-6" aria-hidden="true" tabindex="-1"></a>lines <span class="op">=</span> [<span class="st">'# Comment'</span>, <span class="st">'# Another comment'</span>, <span class="st">'data1'</span>, <span class="st">'data2'</span>, <span class="st">'# inline comment'</span>]</span>
<span id="cb15-7"><a href="#cb15-7" aria-hidden="true" tabindex="-1"></a>data_lines <span class="op">=</span> itertools.dropwhile(<span class="kw">lambda</span> line: line.startswith(<span class="st">'#'</span>), lines)</span>
<span id="cb15-8"><a href="#cb15-8" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(data_lines))  <span class="co"># ['data1', 'data2', '# inline comment']</span></span>
<span id="cb15-9"><a href="#cb15-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-10"><a href="#cb15-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Practical example: Processing log entries</span></span>
<span id="cb15-11"><a href="#cb15-11" aria-hidden="true" tabindex="-1"></a>log_entries <span class="op">=</span> [</span>
<span id="cb15-12"><a href="#cb15-12" aria-hidden="true" tabindex="-1"></a>    <span class="st">"INFO: Starting application"</span>,</span>
<span id="cb15-13"><a href="#cb15-13" aria-hidden="true" tabindex="-1"></a>    <span class="st">"DEBUG: Loading config"</span>,</span>
<span id="cb15-14"><a href="#cb15-14" aria-hidden="true" tabindex="-1"></a>    <span class="st">"ERROR: Database connection failed"</span>,</span>
<span id="cb15-15"><a href="#cb15-15" aria-hidden="true" tabindex="-1"></a>    <span class="st">"INFO: Retrying connection"</span>,</span>
<span id="cb15-16"><a href="#cb15-16" aria-hidden="true" tabindex="-1"></a>    <span class="st">"INFO: Connection successful"</span></span>
<span id="cb15-17"><a href="#cb15-17" aria-hidden="true" tabindex="-1"></a>]</span>
<span id="cb15-18"><a href="#cb15-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-19"><a href="#cb15-19" aria-hidden="true" tabindex="-1"></a><span class="co"># Skip INFO messages at the beginning</span></span>
<span id="cb15-20"><a href="#cb15-20" aria-hidden="true" tabindex="-1"></a>important_logs <span class="op">=</span> itertools.dropwhile(</span>
<span id="cb15-21"><a href="#cb15-21" aria-hidden="true" tabindex="-1"></a>    <span class="kw">lambda</span> x: x.startswith(<span class="st">"INFO"</span>), log_entries</span>
<span id="cb15-22"><a href="#cb15-22" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb15-23"><a href="#cb15-23" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(important_logs))</span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>[8, 9, 10, 12]
['data1', 'data2', '# inline comment']
['DEBUG: Loading config', 'ERROR: Database connection failed', 'INFO: Retrying connection', 'INFO: Connection successful']</code></pre>
</div>
</div>
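<p>File objects iterate line by line, so <code>dropwhile()</code> can skip a comment preamble without reading the whole file first. In this sketch <code>io.StringIO</code> stands in for a real file, and its contents are made up.</p>

```python
import io
import itertools

# Hypothetical file contents: two comment lines, then CSV data
raw = io.StringIO("# exported by tool\n# units: cm\nname,height\nAlice,170\n")

# Skip leading comment lines, then keep everything that follows
rows = itertools.dropwhile(lambda line: line.startswith('#'), raw)
data_rows = [line.rstrip('\n') for line in rows]
print(data_rows)  # ['name,height', 'Alice,170']
```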
</section>
<section id="takewhilepredicate-iterable" class="level3">
<h3 class="anchored" data-anchor-id="takewhilepredicate-iterable" id="takewhilepredicate-iterable">takewhile(predicate, iterable)</h3>
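<p>Returns elements from the start of the iterable as long as <code>predicate</code> is true, and stops at the first element for which it is false. That element is consumed and discarded, which matters if you continue reading from the same iterator afterwards.</p>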
<div id="3e12b4c9" class="cell" data-execution_count="10">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb17"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb17-1"><a href="#cb17-1" aria-hidden="true" tabindex="-1"></a>numbers <span class="op">=</span> [<span class="dv">1</span>, <span class="dv">3</span>, <span class="dv">5</span>, <span class="dv">8</span>, <span class="dv">9</span>, <span class="dv">10</span>, <span class="dv">12</span>]</span>
<span id="cb17-2"><a href="#cb17-2" aria-hidden="true" tabindex="-1"></a>result <span class="op">=</span> itertools.takewhile(<span class="kw">lambda</span> x: x <span class="op">&lt;</span> <span class="dv">8</span>, numbers)</span>
<span id="cb17-3"><a href="#cb17-3" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(result))  <span class="co"># [1, 3, 5]</span></span>
<span id="cb17-4"><a href="#cb17-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-5"><a href="#cb17-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Practical example: Read until delimiter</span></span>
<span id="cb17-6"><a href="#cb17-6" aria-hidden="true" tabindex="-1"></a>data <span class="op">=</span> [<span class="st">'apple'</span>, <span class="st">'banana'</span>, <span class="st">'STOP'</span>, <span class="st">'cherry'</span>, <span class="st">'date'</span>]</span>
<span id="cb17-7"><a href="#cb17-7" aria-hidden="true" tabindex="-1"></a>before_stop <span class="op">=</span> itertools.takewhile(<span class="kw">lambda</span> x: x <span class="op">!=</span> <span class="st">'STOP'</span>, data)</span>
<span id="cb17-8"><a href="#cb17-8" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(before_stop))  <span class="co"># ['apple', 'banana']</span></span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>[1, 3, 5]
['apple', 'banana']</code></pre>
</div>
</div>
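<p>To see the consumed element, run <code>takewhile()</code> on a shared iterator and then read the rest of that iterator. The <code>'STOP'</code> marker is lost, because <code>takewhile()</code> had to read it to decide to stop.</p>

```python
import itertools

it = iter(['apple', 'banana', 'STOP', 'cherry', 'date'])

before = list(itertools.takewhile(lambda x: x != 'STOP', it))
after = list(it)  # continue with the same underlying iterator

print(before)  # ['apple', 'banana']
print(after)   # ['cherry', 'date']  ('STOP' was consumed by takewhile)
```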
</section>
<section id="filterfalsepredicate-iterable" class="level3">
<h3 class="anchored" data-anchor-id="filterfalsepredicate-iterable" id="filterfalsepredicate-iterable">filterfalse(predicate, iterable)</h3>
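<p>Returns the elements for which <code>predicate</code> is false, the complement of the built-in <code>filter()</code>. If <code>predicate</code> is <code>None</code>, it returns the elements that are themselves falsy.</p>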
<div id="b7b27b83" class="cell" data-execution_count="11">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb19"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb19-1"><a href="#cb19-1" aria-hidden="true" tabindex="-1"></a>numbers <span class="op">=</span> [<span class="dv">1</span>, <span class="dv">2</span>, <span class="dv">3</span>, <span class="dv">4</span>, <span class="dv">5</span>, <span class="dv">6</span>, <span class="dv">7</span>, <span class="dv">8</span>, <span class="dv">9</span>, <span class="dv">10</span>]</span>
<span id="cb19-2"><a href="#cb19-2" aria-hidden="true" tabindex="-1"></a>odds <span class="op">=</span> itertools.filterfalse(<span class="kw">lambda</span> x: x <span class="op">%</span> <span class="dv">2</span> <span class="op">==</span> <span class="dv">0</span>, numbers)</span>
<span id="cb19-3"><a href="#cb19-3" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(odds))  <span class="co"># [1, 3, 5, 7, 9]</span></span>
<span id="cb19-4"><a href="#cb19-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-5"><a href="#cb19-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Compare with regular filter</span></span>
<span id="cb19-6"><a href="#cb19-6" aria-hidden="true" tabindex="-1"></a>evens <span class="op">=</span> <span class="bu">filter</span>(<span class="kw">lambda</span> x: x <span class="op">%</span> <span class="dv">2</span> <span class="op">==</span> <span class="dv">0</span>, numbers)</span>
<span id="cb19-7"><a href="#cb19-7" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(evens))  <span class="co"># [2, 4, 6, 8, 10]</span></span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>[1, 3, 5, 7, 9]
[2, 4, 6, 8, 10]</code></pre>
</div>
</div>
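<p>Two sketches: <code>filterfalse()</code> with a <code>None</code> predicate, and a <code>partition()</code> helper modeled on the recipe of the same name in the <code>itertools</code> documentation. It uses <code>tee()</code> so that a single input can feed both <code>filterfalse()</code> and <code>filter()</code>.</p>

```python
import itertools

# With predicate None, only falsy items are kept
values = [0, 1, '', 'a', None, [2]]
falsy = list(itertools.filterfalse(None, values))
print(falsy)  # [0, '', None]

def partition(predicate, iterable):
    """Split iterable into (items where predicate is false, items where true)."""
    first, second = itertools.tee(iterable)
    return itertools.filterfalse(predicate, first), filter(predicate, second)

small, large = partition(lambda x: x >= 5, range(10))
small, large = list(small), list(large)
print(small, large)  # [0, 1, 2, 3, 4] [5, 6, 7, 8, 9]
```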
</section>
<section id="groupbyiterable-keynone" class="level3">
<h3 class="anchored" data-anchor-id="groupbyiterable-keynone" id="groupbyiterable-keynone">groupby(iterable, key=None)</h3>
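<p>Groups <em>consecutive</em> elements that share the same key. When <code>key</code> is <code>None</code>, each element is its own key. Because only adjacent items are grouped (note that <code>1</code> forms two separate groups in the first example), sort the data by the same key first if you want exactly one group per key. Each group shares the underlying source with the <code>groupby</code> object, so consume a group before advancing to the next one.</p>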
<div id="dabca415" class="cell" data-execution_count="12">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb21"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb21-1"><a href="#cb21-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Basic grouping</span></span>
<span id="cb21-2"><a href="#cb21-2" aria-hidden="true" tabindex="-1"></a>data <span class="op">=</span> [<span class="dv">1</span>, <span class="dv">1</span>, <span class="dv">2</span>, <span class="dv">2</span>, <span class="dv">2</span>, <span class="dv">3</span>, <span class="dv">1</span>, <span class="dv">1</span>]</span>
<span id="cb21-3"><a href="#cb21-3" aria-hidden="true" tabindex="-1"></a>grouped <span class="op">=</span> itertools.groupby(data)</span>
<span id="cb21-4"><a href="#cb21-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-5"><a href="#cb21-5" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> key, group <span class="kw">in</span> grouped:</span>
<span id="cb21-6"><a href="#cb21-6" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"</span><span class="sc">{</span>key<span class="sc">}</span><span class="ss">: </span><span class="sc">{</span><span class="bu">list</span>(group)<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb21-7"><a href="#cb21-7" aria-hidden="true" tabindex="-1"></a><span class="co"># 1: [1, 1]</span></span>
<span id="cb21-8"><a href="#cb21-8" aria-hidden="true" tabindex="-1"></a><span class="co"># 2: [2, 2, 2]</span></span>
<span id="cb21-9"><a href="#cb21-9" aria-hidden="true" tabindex="-1"></a><span class="co"># 3: [3]</span></span>
<span id="cb21-10"><a href="#cb21-10" aria-hidden="true" tabindex="-1"></a><span class="co"># 1: [1, 1]</span></span>
<span id="cb21-11"><a href="#cb21-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-12"><a href="#cb21-12" aria-hidden="true" tabindex="-1"></a><span class="co"># Grouping with key function</span></span>
<span id="cb21-13"><a href="#cb21-13" aria-hidden="true" tabindex="-1"></a>words <span class="op">=</span> [<span class="st">'apple'</span>, <span class="st">'banana'</span>, <span class="st">'apricot'</span>, <span class="st">'blueberry'</span>, <span class="st">'cherry'</span>]</span>
<span id="cb21-14"><a href="#cb21-14" aria-hidden="true" tabindex="-1"></a><span class="co"># First sort by first letter, then group</span></span>
<span id="cb21-15"><a href="#cb21-15" aria-hidden="true" tabindex="-1"></a>sorted_words <span class="op">=</span> <span class="bu">sorted</span>(words, key<span class="op">=</span><span class="kw">lambda</span> x: x[<span class="dv">0</span>])</span>
<span id="cb21-16"><a href="#cb21-16" aria-hidden="true" tabindex="-1"></a>grouped_words <span class="op">=</span> itertools.groupby(sorted_words, key<span class="op">=</span><span class="kw">lambda</span> x: x[<span class="dv">0</span>])</span>
<span id="cb21-17"><a href="#cb21-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-18"><a href="#cb21-18" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> letter, group <span class="kw">in</span> grouped_words:</span>
<span id="cb21-19"><a href="#cb21-19" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"</span><span class="sc">{</span>letter<span class="sc">}</span><span class="ss">: </span><span class="sc">{</span><span class="bu">list</span>(group)<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb21-20"><a href="#cb21-20" aria-hidden="true" tabindex="-1"></a><span class="co"># a: ['apple', 'apricot']</span></span>
<span id="cb21-21"><a href="#cb21-21" aria-hidden="true" tabindex="-1"></a><span class="co"># b: ['banana', 'blueberry']</span></span>
<span id="cb21-22"><a href="#cb21-22" aria-hidden="true" tabindex="-1"></a><span class="co"># c: ['cherry']</span></span>
<span id="cb21-23"><a href="#cb21-23" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-24"><a href="#cb21-24" aria-hidden="true" tabindex="-1"></a><span class="co"># Grouping records by a field (grade)</span></span>
<span id="cb21-25"><a href="#cb21-25" aria-hidden="true" tabindex="-1"></a>students <span class="op">=</span> [</span>
<span id="cb21-26"><a href="#cb21-26" aria-hidden="true" tabindex="-1"></a>    (<span class="st">'Alice'</span>, <span class="st">'A'</span>),</span>
<span id="cb21-27"><a href="#cb21-27" aria-hidden="true" tabindex="-1"></a>    (<span class="st">'Bob'</span>, <span class="st">'B'</span>),</span>
<span id="cb21-28"><a href="#cb21-28" aria-hidden="true" tabindex="-1"></a>    (<span class="st">'Charlie'</span>, <span class="st">'A'</span>),</span>
<span id="cb21-29"><a href="#cb21-29" aria-hidden="true" tabindex="-1"></a>    (<span class="st">'David'</span>, <span class="st">'B'</span>),</span>
<span id="cb21-30"><a href="#cb21-30" aria-hidden="true" tabindex="-1"></a>    (<span class="st">'Eve'</span>, <span class="st">'A'</span>)</span>
<span id="cb21-31"><a href="#cb21-31" aria-hidden="true" tabindex="-1"></a>]</span>
<span id="cb21-32"><a href="#cb21-32" aria-hidden="true" tabindex="-1"></a><span class="co"># Sort first, then group</span></span>
<span id="cb21-33"><a href="#cb21-33" aria-hidden="true" tabindex="-1"></a>students_sorted <span class="op">=</span> <span class="bu">sorted</span>(students, key<span class="op">=</span><span class="kw">lambda</span> x: x[<span class="dv">1</span>])</span>
<span id="cb21-34"><a href="#cb21-34" aria-hidden="true" tabindex="-1"></a>by_grade <span class="op">=</span> itertools.groupby(students_sorted, key<span class="op">=</span><span class="kw">lambda</span> x: x[<span class="dv">1</span>])</span>
<span id="cb21-35"><a href="#cb21-35" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> grade, group <span class="kw">in</span> by_grade:</span>
<span id="cb21-36"><a href="#cb21-36" aria-hidden="true" tabindex="-1"></a>    names <span class="op">=</span> [student[<span class="dv">0</span>] <span class="cf">for</span> student <span class="kw">in</span> group]</span>
<span id="cb21-37"><a href="#cb21-37" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Grade </span><span class="sc">{</span>grade<span class="sc">}</span><span class="ss">: </span><span class="sc">{</span>names<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>1: [1, 1]
2: [2, 2, 2]
3: [3]
1: [1, 1]
a: ['apple', 'apricot']
b: ['banana', 'blueberry']
c: ['cherry']
Grade A: ['Alice', 'Charlie', 'Eve']
Grade B: ['Bob', 'David']</code></pre>
</div>
</div>
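<p>Two behaviours of <code>groupby</code> are easy to trip over. First, it only groups <em>consecutive</em> items with equal keys. Second, each group shares the underlying iterator, so a group becomes empty once <code>groupby</code> advances past it. Here is a minimal sketch of both, using small made-up inputs (the variable names are illustrative):</p>

```python
import itertools

# Unsorted input: the two 'a' runs are not adjacent, so they form separate groups
unsorted_keys = [key for key, _ in itertools.groupby(['a', 'b', 'a'])]
print(unsorted_keys)  # ['a', 'b', 'a']

# Groups share the source iterator: consuming all of groupby first leaves them empty
stale = [list(group) for _, group in list(itertools.groupby('aabbc'))]
print(stale)  # [[], [], []]

# Safe pattern: materialize each group while it is still the current one
fresh = {key: list(group) for key, group in itertools.groupby('aabbc')}
print(fresh)  # {'a': ['a', 'a'], 'b': ['b', 'b'], 'c': ['c']}
```

<p>If the data cannot be sorted cheaply, a common alternative is to collect items into a <code>collections.defaultdict(list)</code>, which does not depend on equal keys being adjacent.</p>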
</section>
<section id="isliceiterable-start-stop-step" class="level3">
<h3 class="anchored" data-anchor-id="isliceiterable-start-stop-step" id="isliceiterable-start-stop-step">islice(iterable, start, stop, step)</h3>
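<p>Returns an iterator over selected elements of the iterable, similar to list slicing but usable on any iterator. Unlike slicing, <code>islice</code> does not accept negative values for <code>start</code>, <code>stop</code>, or <code>step</code>.</p>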
<div id="8191ea97" class="cell" data-execution_count="13">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb23"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb23-1"><a href="#cb23-1" aria-hidden="true" tabindex="-1"></a>numbers <span class="op">=</span> <span class="bu">range</span>(<span class="dv">20</span>)</span>
<span id="cb23-2"><a href="#cb23-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb23-3"><a href="#cb23-3" aria-hidden="true" tabindex="-1"></a><span class="co"># islice(iterable, stop)</span></span>
<span id="cb23-4"><a href="#cb23-4" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(itertools.islice(numbers, <span class="dv">5</span>)))  <span class="co"># [0, 1, 2, 3, 4]</span></span>
<span id="cb23-5"><a href="#cb23-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb23-6"><a href="#cb23-6" aria-hidden="true" tabindex="-1"></a><span class="co"># islice(iterable, start, stop)</span></span>
<span id="cb23-7"><a href="#cb23-7" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(itertools.islice(numbers, <span class="dv">5</span>, <span class="dv">10</span>)))  <span class="co"># [5, 6, 7, 8, 9]</span></span>
<span id="cb23-8"><a href="#cb23-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb23-9"><a href="#cb23-9" aria-hidden="true" tabindex="-1"></a><span class="co"># islice(iterable, start, stop, step)</span></span>
<span id="cb23-10"><a href="#cb23-10" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(itertools.islice(numbers, <span class="dv">0</span>, <span class="dv">10</span>, <span class="dv">2</span>)))  <span class="co"># [0, 2, 4, 6, 8]</span></span>
<span id="cb23-11"><a href="#cb23-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb23-12"><a href="#cb23-12" aria-hidden="true" tabindex="-1"></a><span class="co"># Practical example: Pagination</span></span>
<span id="cb23-13"><a href="#cb23-13" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> paginate(iterable, page_size):</span>
<span id="cb23-14"><a href="#cb23-14" aria-hidden="true" tabindex="-1"></a>    iterator <span class="op">=</span> <span class="bu">iter</span>(iterable)</span>
<span id="cb23-15"><a href="#cb23-15" aria-hidden="true" tabindex="-1"></a>    <span class="cf">while</span> <span class="va">True</span>:</span>
<span id="cb23-16"><a href="#cb23-16" aria-hidden="true" tabindex="-1"></a>        page <span class="op">=</span> <span class="bu">list</span>(itertools.islice(iterator, page_size))</span>
<span id="cb23-17"><a href="#cb23-17" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="kw">not</span> page:</span>
<span id="cb23-18"><a href="#cb23-18" aria-hidden="true" tabindex="-1"></a>            <span class="cf">break</span></span>
<span id="cb23-19"><a href="#cb23-19" aria-hidden="true" tabindex="-1"></a>        <span class="cf">yield</span> page</span>
<span id="cb23-20"><a href="#cb23-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb23-21"><a href="#cb23-21" aria-hidden="true" tabindex="-1"></a>data <span class="op">=</span> <span class="bu">range</span>(<span class="dv">25</span>)</span>
<span id="cb23-22"><a href="#cb23-22" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> page_num, page <span class="kw">in</span> <span class="bu">enumerate</span>(paginate(data, <span class="dv">10</span>), <span class="dv">1</span>):</span>
<span id="cb23-23"><a href="#cb23-23" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Page </span><span class="sc">{</span>page_num<span class="sc">}</span><span class="ss">: </span><span class="sc">{</span>page<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>[0, 1, 2, 3, 4]
[5, 6, 7, 8, 9]
[0, 2, 4, 6, 8]
Page 1: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Page 2: [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
Page 3: [20, 21, 22, 23, 24]</code></pre>
</div>
</div>
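<p><code>islice</code> works lazily on whatever iterator it is given. This means it can take a prefix of an infinite stream, and any items it skips are consumed from the source. The sketch below shows both behaviours using an illustrative generator of squares, and confirms that negative indices raise <code>ValueError</code>.</p>

```python
import itertools

# Take the first five squares from an infinite generator
squares = (n * n for n in itertools.count())
first_squares = list(itertools.islice(squares, 5))
print(first_squares)  # [0, 1, 4, 9, 16]

# islice advances the shared iterator: the skipped items 0 and 1 are gone
it = iter(range(10))
middle = list(itertools.islice(it, 2, 4))
remaining = list(it)
print(middle, remaining)  # [2, 3] [4, 5, 6, 7, 8, 9]

# Negative indices are not supported, unlike list slicing
try:
    itertools.islice(range(10), -3, None)
    negative_rejected = False
except ValueError:
    negative_rejected = True
print(negative_rejected)  # True
```

<p>The second case is exactly what the <code>paginate</code> helper above relies on: each call to <code>islice</code> resumes where the previous page ended.</p>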
</section>
<section id="starmapfunction-iterable" class="level3">
<h3 class="anchored" data-anchor-id="starmapfunction-iterable" id="starmapfunction-iterable">starmap(function, iterable)</h3>
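<p>Applies <code>function</code> to the arguments unpacked from each item of <code>iterable</code>. In other words, it computes <code>function(*item)</code> for every item.</p>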
<div id="e3a0b93a" class="cell" data-execution_count="14">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb25"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb25-1"><a href="#cb25-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Basic usage</span></span>
<span id="cb25-2"><a href="#cb25-2" aria-hidden="true" tabindex="-1"></a>points <span class="op">=</span> [(<span class="dv">1</span>, <span class="dv">2</span>), (<span class="dv">3</span>, <span class="dv">4</span>), (<span class="dv">5</span>, <span class="dv">6</span>)]</span>
<span id="cb25-3"><a href="#cb25-3" aria-hidden="true" tabindex="-1"></a>distances <span class="op">=</span> itertools.starmap(<span class="kw">lambda</span> x, y: (x<span class="op">**</span><span class="dv">2</span> <span class="op">+</span> y<span class="op">**</span><span class="dv">2</span>)<span class="op">**</span><span class="fl">0.5</span>, points)</span>
<span id="cb25-4"><a href="#cb25-4" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(distances))  <span class="co"># [2.236..., 5.0, 7.810...]</span></span>
<span id="cb25-5"><a href="#cb25-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb25-6"><a href="#cb25-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Practical example: Multiple argument functions</span></span>
<span id="cb25-7"><a href="#cb25-7" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> operator</span>
<span id="cb25-8"><a href="#cb25-8" aria-hidden="true" tabindex="-1"></a>pairs <span class="op">=</span> [(<span class="dv">2</span>, <span class="dv">3</span>), (<span class="dv">4</span>, <span class="dv">5</span>), (<span class="dv">6</span>, <span class="dv">7</span>)]</span>
<span id="cb25-9"><a href="#cb25-9" aria-hidden="true" tabindex="-1"></a>products <span class="op">=</span> itertools.starmap(operator.mul, pairs)</span>
<span id="cb25-10"><a href="#cb25-10" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(products))  <span class="co"># [6, 20, 42]</span></span>
<span id="cb25-11"><a href="#cb25-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb25-12"><a href="#cb25-12" aria-hidden="true" tabindex="-1"></a><span class="co"># Compare with map</span></span>
<span id="cb25-13"><a href="#cb25-13" aria-hidden="true" tabindex="-1"></a>regular_map <span class="op">=</span> <span class="bu">map</span>(operator.mul, [<span class="dv">2</span>, <span class="dv">4</span>, <span class="dv">6</span>], [<span class="dv">3</span>, <span class="dv">5</span>, <span class="dv">7</span>])</span>
<span id="cb25-14"><a href="#cb25-14" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(regular_map))  <span class="co"># [6, 20, 42]</span></span>
<span id="cb25-15"><a href="#cb25-15" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb25-16"><a href="#cb25-16" aria-hidden="true" tabindex="-1"></a><span class="co"># Why starmap? Given a list of tuples:</span></span>
<span id="cb25-17"><a href="#cb25-17" aria-hidden="true" tabindex="-1"></a><span class="co"># map passes each tuple as a single argument</span></span>
<span id="cb25-18"><a href="#cb25-18" aria-hidden="true" tabindex="-1"></a><span class="co"># starmap unpacks each tuple as separate arguments</span></span>
<span id="cb25-19"><a href="#cb25-19" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> add(x, y):</span>
<span id="cb25-20"><a href="#cb25-20" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> x <span class="op">+</span> y</span>
<span id="cb25-21"><a href="#cb25-21" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb25-22"><a href="#cb25-22" aria-hidden="true" tabindex="-1"></a>pairs <span class="op">=</span> [(<span class="dv">1</span>, <span class="dv">2</span>), (<span class="dv">3</span>, <span class="dv">4</span>), (<span class="dv">5</span>, <span class="dv">6</span>)]</span>
<span id="cb25-23"><a href="#cb25-23" aria-hidden="true" tabindex="-1"></a>result <span class="op">=</span> <span class="bu">list</span>(itertools.starmap(add, pairs))</span>
<span id="cb25-24"><a href="#cb25-24" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(result)  <span class="co"># [3, 7, 11]</span></span>
<span id="cb25-25"><a href="#cb25-25" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb25-26"><a href="#cb25-26" aria-hidden="true" tabindex="-1"></a><span class="co"># Practical example: Applying operations to coordinate pairs</span></span>
<span id="cb25-27"><a href="#cb25-27" aria-hidden="true" tabindex="-1"></a>coordinates <span class="op">=</span> [(<span class="dv">1</span>, <span class="dv">2</span>), (<span class="dv">3</span>, <span class="dv">4</span>), (<span class="dv">5</span>, <span class="dv">6</span>)]</span>
<span id="cb25-28"><a href="#cb25-28" aria-hidden="true" tabindex="-1"></a>distances_from_origin <span class="op">=</span> <span class="bu">list</span>(itertools.starmap(</span>
<span id="cb25-29"><a href="#cb25-29" aria-hidden="true" tabindex="-1"></a>    <span class="kw">lambda</span> x, y: math.sqrt(x<span class="op">**</span><span class="dv">2</span> <span class="op">+</span> y<span class="op">**</span><span class="dv">2</span>), coordinates</span>
<span id="cb25-30"><a href="#cb25-30" aria-hidden="true" tabindex="-1"></a>))</span>
<span id="cb25-31"><a href="#cb25-31" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(distances_from_origin)</span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>[2.23606797749979, 5.0, 7.810249675906654]
[6, 20, 42]
[6, 20, 42]
[3, 7, 11]
[2.23606797749979, 5.0, 7.810249675906654]</code></pre>
</div>
</div>
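<p>Conceptually, <code>starmap(f, iterable)</code> behaves like the generator <code>(f(*args) for args in iterable)</code>. The sketch below writes that equivalent out as a hypothetical helper, <code>starmap_equivalent</code>, and compares it with the built-in. It also pairs <code>starmap</code> with <code>dict.items()</code>, whose items are already <code>(key, value)</code> tuples. The data is made up for illustration.</p>

```python
import itertools

def starmap_equivalent(function, iterable):
    # Hypothetical helper: unpack each item into positional arguments
    for args in iterable:
        yield function(*args)

pairs = [(2, 5), (3, 2), (10, 3)]
builtin = list(itertools.starmap(pow, pairs))
handmade = list(starmap_equivalent(pow, pairs))
print(builtin, handmade)  # [32, 9, 1000] [32, 9, 1000]

# dict.items() yields (key, value) tuples, a natural fit for starmap
prices = {'apple': 1.5, 'pear': 2.0}
labels = list(itertools.starmap('{}: ${:.2f}'.format, prices.items()))
print(labels)  # ['apple: $1.50', 'pear: $2.00']
```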
</section>
<section id="teeiterable-n2" class="level3">
<h3 class="anchored" data-anchor-id="teeiterable-n2" id="teeiterable-n2">tee(iterable, n=2)</h3>
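<p>Splits one iterable into <code>n</code> independent iterators (two by default). After calling <code>tee</code>, do not use the original iterator, because advancing it directly is not seen by the copies. <code>tee</code> also buffers every item that one copy has consumed but another has not, so memory use grows if one copy runs far ahead of the others.</p>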
<div id="f7de7210" class="cell" data-execution_count="15">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb27"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb27-1"><a href="#cb27-1" aria-hidden="true" tabindex="-1"></a>data <span class="op">=</span> [<span class="dv">1</span>, <span class="dv">2</span>, <span class="dv">3</span>, <span class="dv">4</span>, <span class="dv">5</span>]</span>
<span id="cb27-2"><a href="#cb27-2" aria-hidden="true" tabindex="-1"></a>iter1, iter2 <span class="op">=</span> itertools.tee(data)</span>
<span id="cb27-3"><a href="#cb27-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb27-4"><a href="#cb27-4" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(iter1))  <span class="co"># [1, 2, 3, 4, 5]</span></span>
<span id="cb27-5"><a href="#cb27-5" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(iter2))  <span class="co"># [1, 2, 3, 4, 5]</span></span>
<span id="cb27-6"><a href="#cb27-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb27-7"><a href="#cb27-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Practical example: Processing data in multiple ways</span></span>
<span id="cb27-8"><a href="#cb27-8" aria-hidden="true" tabindex="-1"></a>numbers <span class="op">=</span> [<span class="dv">1</span>, <span class="dv">2</span>, <span class="dv">3</span>, <span class="dv">4</span>, <span class="dv">5</span>, <span class="dv">6</span>, <span class="dv">7</span>, <span class="dv">8</span>, <span class="dv">9</span>, <span class="dv">10</span>]</span>
<span id="cb27-9"><a href="#cb27-9" aria-hidden="true" tabindex="-1"></a>evens_iter, odds_iter <span class="op">=</span> itertools.tee(numbers)</span>
<span id="cb27-10"><a href="#cb27-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb27-11"><a href="#cb27-11" aria-hidden="true" tabindex="-1"></a>evens <span class="op">=</span> <span class="bu">filter</span>(<span class="kw">lambda</span> x: x <span class="op">%</span> <span class="dv">2</span> <span class="op">==</span> <span class="dv">0</span>, evens_iter)</span>
<span id="cb27-12"><a href="#cb27-12" aria-hidden="true" tabindex="-1"></a>odds <span class="op">=</span> <span class="bu">filter</span>(<span class="kw">lambda</span> x: x <span class="op">%</span> <span class="dv">2</span> <span class="op">==</span> <span class="dv">1</span>, odds_iter)</span>
<span id="cb27-13"><a href="#cb27-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb27-14"><a href="#cb27-14" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Evens: </span><span class="sc">{</span><span class="bu">list</span>(evens)<span class="sc">}</span><span class="ss">"</span>)  <span class="co"># [2, 4, 6, 8, 10]</span></span>
<span id="cb27-15"><a href="#cb27-15" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Odds: </span><span class="sc">{</span><span class="bu">list</span>(odds)<span class="sc">}</span><span class="ss">"</span>)    <span class="co"># [1, 3, 5, 7, 9]</span></span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>[1, 2, 3, 4, 5]
[1, 2, 3, 4, 5]
Evens: [2, 4, 6, 8, 10]
Odds: [1, 3, 5, 7, 9]</code></pre>
</div>
</div>
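<p>A common use of <code>tee</code> is to look at a one-shot stream, such as a generator, through two offset views. The sketch below is a version of the well-known <code>pairwise</code> recipe; Python 3.10+ also ships <code>itertools.pairwise</code>. The function name <code>pairwise_via_tee</code> and the sensor readings are illustrative.</p>

```python
import itertools

def pairwise_via_tee(iterable):
    # Two copies of the same stream; advance the second one by a single item
    first, second = itertools.tee(iterable)
    next(second, None)
    return zip(first, second)

readings = (r for r in [10, 12, 15, 11])  # a generator can only be iterated once
deltas = [curr - prev for prev, curr in pairwise_via_tee(readings)]
print(deltas)  # [2, 3, -4]
```

<p>The internal buffer stays small here because <code>zip</code> advances both copies in lockstep.</p>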
</section>
<section id="zip_longestiterables-fillvaluenone" class="level3">
<h3 class="anchored" data-anchor-id="zip_longestiterables-fillvaluenone" id="zip_longestiterables-fillvaluenone">zip_longest(*iterables, fillvalue=None)</h3>
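<p>Zips iterables like <code>zip</code>, but continues until the longest iterable is exhausted and fills missing positions with <code>fillvalue</code> (default <code>None</code>).</p>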
<div id="11b98166" class="cell" data-execution_count="16">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb29"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb29-1"><a href="#cb29-1" aria-hidden="true" tabindex="-1"></a>list1 <span class="op">=</span> [<span class="dv">1</span>, <span class="dv">2</span>, <span class="dv">3</span>]</span>
<span id="cb29-2"><a href="#cb29-2" aria-hidden="true" tabindex="-1"></a>list2 <span class="op">=</span> [<span class="st">'a'</span>, <span class="st">'b'</span>, <span class="st">'c'</span>, <span class="st">'d'</span>, <span class="st">'e'</span>]</span>
<span id="cb29-3"><a href="#cb29-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb29-4"><a href="#cb29-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Regular zip stops at shortest</span></span>
<span id="cb29-5"><a href="#cb29-5" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(<span class="bu">zip</span>(list1, list2)))  <span class="co"># [(1, 'a'), (2, 'b'), (3, 'c')]</span></span>
<span id="cb29-6"><a href="#cb29-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb29-7"><a href="#cb29-7" aria-hidden="true" tabindex="-1"></a><span class="co"># zip_longest continues to longest</span></span>
<span id="cb29-8"><a href="#cb29-8" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(itertools.zip_longest(list1, list2)))</span>
<span id="cb29-9"><a href="#cb29-9" aria-hidden="true" tabindex="-1"></a><span class="co"># [(1, 'a'), (2, 'b'), (3, 'c'), (None, 'd'), (None, 'e')]</span></span>
<span id="cb29-10"><a href="#cb29-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb29-11"><a href="#cb29-11" aria-hidden="true" tabindex="-1"></a><span class="co"># With custom fillvalue</span></span>
<span id="cb29-12"><a href="#cb29-12" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(itertools.zip_longest(list1, list2, fillvalue<span class="op">=</span><span class="st">'X'</span>)))</span>
<span id="cb29-13"><a href="#cb29-13" aria-hidden="true" tabindex="-1"></a><span class="co"># [(1, 'a'), (2, 'b'), (3, 'c'), ('X', 'd'), ('X', 'e')]</span></span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>[(1, 'a'), (2, 'b'), (3, 'c')]
[(1, 'a'), (2, 'b'), (3, 'c'), (None, 'd'), (None, 'e')]
[(1, 'a'), (2, 'b'), (3, 'c'), ('X', 'd'), ('X', 'e')]</code></pre>
</div>
</div>
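<p>A classic application of <code>zip_longest</code> is to split data into fixed-size chunks and pad the last one. The trick is to pass <code>n</code> references to the <em>same</em> iterator, so each output tuple pulls <code>n</code> consecutive items. This sketch is modeled on the <code>grouper</code> recipe from the <code>itertools</code> documentation. Note that <code>grouper</code> is not a function of the module itself.</p>

```python
import itertools

def grouper(iterable, n, fillvalue=None):
    # n references to one shared iterator: each tuple consumes n consecutive items
    args = [iter(iterable)] * n
    return itertools.zip_longest(*args, fillvalue=fillvalue)

rows = list(grouper('ABCDEFG', 3, fillvalue='-'))
print(rows)  # [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', '-', '-')]
```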
<hr>
</section>
</section>
<section id="combinatorial-iterators" class="level2">
<h2 class="anchored" data-anchor-id="combinatorial-iterators" id="combinatorial-iterators">3. Combinatorial Iterators</h2>
<section id="productiterables-repeat1" class="level3">
<h3 class="anchored" data-anchor-id="productiterables-repeat1" id="productiterables-repeat1">product(*iterables, repeat=1)</h3>
<p>Cartesian product of input iterables.</p>
<div id="b9e96a86" class="cell" data-execution_count="17">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb31"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb31-1"><a href="#cb31-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Basic product</span></span>
<span id="cb31-2"><a href="#cb31-2" aria-hidden="true" tabindex="-1"></a>colors <span class="op">=</span> [<span class="st">'red'</span>, <span class="st">'blue'</span>]</span>
<span id="cb31-3"><a href="#cb31-3" aria-hidden="true" tabindex="-1"></a>sizes <span class="op">=</span> [<span class="st">'S'</span>, <span class="st">'M'</span>, <span class="st">'L'</span>]</span>
<span id="cb31-4"><a href="#cb31-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb31-5"><a href="#cb31-5" aria-hidden="true" tabindex="-1"></a>combinations <span class="op">=</span> itertools.product(colors, sizes)</span>
<span id="cb31-6"><a href="#cb31-6" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(combinations))</span>
<span id="cb31-7"><a href="#cb31-7" aria-hidden="true" tabindex="-1"></a><span class="co"># [('red', 'S'), ('red', 'M'), ('red', 'L'), ('blue', 'S'), ('blue', 'M'), ('blue', 'L')]</span></span>
<span id="cb31-8"><a href="#cb31-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb31-9"><a href="#cb31-9" aria-hidden="true" tabindex="-1"></a><span class="co"># With repeat</span></span>
<span id="cb31-10"><a href="#cb31-10" aria-hidden="true" tabindex="-1"></a>dice_rolls <span class="op">=</span> itertools.product(<span class="bu">range</span>(<span class="dv">1</span>, <span class="dv">7</span>), repeat<span class="op">=</span><span class="dv">2</span>)</span>
<span id="cb31-11"><a href="#cb31-11" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(itertools.islice(dice_rolls, <span class="dv">10</span>)))</span>
<span id="cb31-12"><a href="#cb31-12" aria-hidden="true" tabindex="-1"></a><span class="co"># [(1, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (2, 1), (2, 2), (2, 3), (2, 4)]</span></span>
<span id="cb31-13"><a href="#cb31-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb31-14"><a href="#cb31-14" aria-hidden="true" tabindex="-1"></a><span class="co"># Practical example: Grid coordinates</span></span>
<span id="cb31-15"><a href="#cb31-15" aria-hidden="true" tabindex="-1"></a>grid <span class="op">=</span> itertools.product(<span class="bu">range</span>(<span class="dv">3</span>), <span class="bu">range</span>(<span class="dv">3</span>))</span>
<span id="cb31-16"><a href="#cb31-16" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(grid))</span>
<span id="cb31-17"><a href="#cb31-17" aria-hidden="true" tabindex="-1"></a><span class="co"># [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2)]</span></span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>[('red', 'S'), ('red', 'M'), ('red', 'L'), ('blue', 'S'), ('blue', 'M'), ('blue', 'L')]
[(1, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (2, 1), (2, 2), (2, 3), (2, 4)]
[(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2)]</code></pre>
</div>
</div>
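<p><code>product</code> yields tuples in the same order as nested for-loops, with the rightmost iterable varying fastest. The number of results is the product of the input lengths. The sketch checks this against a list comprehension and then uses <code>repeat</code> to enumerate all 3-bit binary strings.</p>

```python
import itertools

colors = ['red', 'blue']
sizes = ['S', 'M', 'L']

# Nested loops, with the last iterable in the innermost position
nested = [(c, s) for c in colors for s in sizes]
same_as_loops = list(itertools.product(colors, sizes)) == nested
print(same_as_loops, len(nested))  # True 6  (2 * 3 results)

# repeat=3 is shorthand for product('01', '01', '01')
bits = [''.join(p) for p in itertools.product('01', repeat=3)]
print(bits)  # ['000', '001', '010', '011', '100', '101', '110', '111']
```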
</section>
<section id="permutationsiterable-rnone" class="level3">
<h3 class="anchored" data-anchor-id="permutationsiterable-rnone" id="permutationsiterable-rnone">permutations(iterable, r=None)</h3>
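<p>Returns successive <code>r</code>-length permutations of elements. If <code>r</code> is omitted, it defaults to the length of the iterable. Order matters, so <code>('A', 'B')</code> and <code>('B', 'A')</code> are distinct results.</p>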
<div id="3b4b528f" class="cell" data-execution_count="18">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb33"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb33-1"><a href="#cb33-1" aria-hidden="true" tabindex="-1"></a><span class="co"># All permutations</span></span>
<span id="cb33-2"><a href="#cb33-2" aria-hidden="true" tabindex="-1"></a>letters <span class="op">=</span> [<span class="st">'A'</span>, <span class="st">'B'</span>, <span class="st">'C'</span>]</span>
<span id="cb33-3"><a href="#cb33-3" aria-hidden="true" tabindex="-1"></a>perms <span class="op">=</span> itertools.permutations(letters)</span>
<span id="cb33-4"><a href="#cb33-4" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(perms))</span>
<span id="cb33-5"><a href="#cb33-5" aria-hidden="true" tabindex="-1"></a><span class="co"># [('A', 'B', 'C'), ('A', 'C', 'B'), ('B', 'A', 'C'), ('B', 'C', 'A'), ('C', 'A', 'B'), ('C', 'B', 'A')]</span></span>
<span id="cb33-6"><a href="#cb33-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb33-7"><a href="#cb33-7" aria-hidden="true" tabindex="-1"></a><span class="co"># r-length permutations</span></span>
<span id="cb33-8"><a href="#cb33-8" aria-hidden="true" tabindex="-1"></a>perms_2 <span class="op">=</span> itertools.permutations(letters, <span class="dv">2</span>)</span>
<span id="cb33-9"><a href="#cb33-9" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(perms_2))</span>
<span id="cb33-10"><a href="#cb33-10" aria-hidden="true" tabindex="-1"></a><span class="co"># [('A', 'B'), ('A', 'C'), ('B', 'A'), ('B', 'C'), ('C', 'A'), ('C', 'B')]</span></span>
<span id="cb33-11"><a href="#cb33-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb33-12"><a href="#cb33-12" aria-hidden="true" tabindex="-1"></a><span class="co"># Practical example: Letter arrangements (anagram candidates)</span></span>
<span id="cb33-13"><a href="#cb33-13" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> find_anagrams(word, length<span class="op">=</span><span class="va">None</span>):</span>
<span id="cb33-14"><a href="#cb33-14" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> length <span class="kw">is</span> <span class="va">None</span>:</span>
<span id="cb33-15"><a href="#cb33-15" aria-hidden="true" tabindex="-1"></a>        length <span class="op">=</span> <span class="bu">len</span>(word)</span>
<span id="cb33-16"><a href="#cb33-16" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> [<span class="st">''</span>.join(p) <span class="cf">for</span> p <span class="kw">in</span> itertools.permutations(word, length)]</span>
<span id="cb33-17"><a href="#cb33-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb33-18"><a href="#cb33-18" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(find_anagrams(<span class="st">'CAT'</span>, <span class="dv">2</span>))  <span class="co"># ['CA', 'CT', 'AC', 'AT', 'TC', 'TA']</span></span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>[('A', 'B', 'C'), ('A', 'C', 'B'), ('B', 'A', 'C'), ('B', 'C', 'A'), ('C', 'A', 'B'), ('C', 'B', 'A')]
[('A', 'B'), ('A', 'C'), ('B', 'A'), ('B', 'C'), ('C', 'A'), ('C', 'B')]
['CA', 'CT', 'AC', 'AT', 'TC', 'TA']</code></pre>
</div>
</div>
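<p>The number of <code>r</code>-length permutations of <code>n</code> items is n! / (n - r)!, which <code>math.perm</code> (Python 3.8+) computes directly. <code>permutations</code> treats elements as distinct by <em>position</em>, not by value, so input with repeated letters yields repeated results. Wrapping the output in a <code>set</code> is one way to deduplicate them.</p>

```python
import itertools
import math

letters = 'ABCD'
perms = list(itertools.permutations(letters, 2))
# 4! / (4 - 2)! = 24 / 2 = 12
print(len(perms), math.perm(len(letters), 2))  # 12 12

# Repeated values produce repeated permutations
raw = [''.join(p) for p in itertools.permutations('AAB')]
print(raw)  # ['AAB', 'ABA', 'AAB', 'ABA', 'BAA', 'BAA']
unique = sorted(set(raw))
print(unique)  # ['AAB', 'ABA', 'BAA']
```

<p>The same applies to the <code>find_anagrams</code> example above: a word with repeated letters returns duplicate strings.</p>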
</section>
<section id="combinationsiterable-r" class="level3">
<h3 class="anchored" data-anchor-id="combinationsiterable-r" id="combinationsiterable-r">combinations(iterable, r)</h3>
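<p>Returns <code>r</code>-length combinations without replacement. Order does not matter, so <code>(1, 2)</code> and <code>(2, 1)</code> count as the same combination and only one is emitted.</p>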
<div id="32061c9c" class="cell" data-execution_count="19">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb35"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb35-1"><a href="#cb35-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Basic combinations</span></span>
<span id="cb35-2"><a href="#cb35-2" aria-hidden="true" tabindex="-1"></a>numbers <span class="op">=</span> [<span class="dv">1</span>, <span class="dv">2</span>, <span class="dv">3</span>, <span class="dv">4</span>]</span>
<span id="cb35-3"><a href="#cb35-3" aria-hidden="true" tabindex="-1"></a>combos <span class="op">=</span> itertools.combinations(numbers, <span class="dv">2</span>)</span>
<span id="cb35-4"><a href="#cb35-4" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(combos))</span>
<span id="cb35-5"><a href="#cb35-5" aria-hidden="true" tabindex="-1"></a><span class="co"># [(1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)]</span></span>
<span id="cb35-6"><a href="#cb35-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb35-7"><a href="#cb35-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Practical example: Team selection</span></span>
<span id="cb35-8"><a href="#cb35-8" aria-hidden="true" tabindex="-1"></a>players <span class="op">=</span> [<span class="st">'Alice'</span>, <span class="st">'Bob'</span>, <span class="st">'Charlie'</span>, <span class="st">'David'</span>, <span class="st">'Eve'</span>]</span>
<span id="cb35-9"><a href="#cb35-9" aria-hidden="true" tabindex="-1"></a>teams <span class="op">=</span> itertools.combinations(players, <span class="dv">3</span>)</span>
<span id="cb35-10"><a href="#cb35-10" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(itertools.islice(teams, <span class="dv">5</span>)))</span>
<span id="cb35-11"><a href="#cb35-11" aria-hidden="true" tabindex="-1"></a><span class="co"># [('Alice', 'Bob', 'Charlie'), ('Alice', 'Bob', 'David'), ('Alice', 'Bob', 'Eve'), ('Alice', 'Charlie', 'David'), ('Alice', 'Charlie', 'Eve')]</span></span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>[(1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)]
[('Alice', 'Bob', 'Charlie'), ('Alice', 'Bob', 'David'), ('Alice', 'Bob', 'Eve'), ('Alice', 'Charlie', 'David'), ('Alice', 'Charlie', 'Eve')]</code></pre>
</div>
</div>
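<p>The number of <code>r</code>-combinations of <code>n</code> items is C(n, r) = n! / (r! (n - r)!), available as <code>math.comb</code> in Python 3.8+. Each combination corresponds to r! orderings, which gives a quick consistency check against <code>permutations</code>. The sketch below runs both checks on the same player list.</p>

```python
import itertools
import math

players = ['Alice', 'Bob', 'Charlie', 'David', 'Eve']
n_teams = sum(1 for _ in itertools.combinations(players, 3))
print(n_teams, math.comb(5, 3))  # 10 10

# Every unordered team of 3 appears in 3! = 6 different orders
n_orderings = sum(1 for _ in itertools.permutations(players, 3))
print(n_orderings, n_teams * math.factorial(3))  # 60 60
```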
</section>
<section id="combinations_with_replacementiterable-r" class="level3">
<h3 class="anchored" data-anchor-id="combinations_with_replacementiterable-r" id="combinations_with_replacementiterable-r">combinations_with_replacement(iterable, r)</h3>
<p>Returns r-length combinations with replacement allowed.</p>
<div id="54d82c8d" class="cell" data-execution_count="20">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb37"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb37-1"><a href="#cb37-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Basic combinations with replacement</span></span>
<span id="cb37-2"><a href="#cb37-2" aria-hidden="true" tabindex="-1"></a>numbers <span class="op">=</span> [<span class="dv">1</span>, <span class="dv">2</span>, <span class="dv">3</span>]</span>
<span id="cb37-3"><a href="#cb37-3" aria-hidden="true" tabindex="-1"></a>combos <span class="op">=</span> itertools.combinations_with_replacement(numbers, <span class="dv">2</span>)</span>
<span id="cb37-4"><a href="#cb37-4" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(combos))</span>
<span id="cb37-5"><a href="#cb37-5" aria-hidden="true" tabindex="-1"></a><span class="co"># [(1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (3, 3)]</span></span>
<span id="cb37-6"><a href="#cb37-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb37-7"><a href="#cb37-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Practical example: Unordered outcomes of two coin flips (HT and TH count once)</span></span>
<span id="cb37-8"><a href="#cb37-8" aria-hidden="true" tabindex="-1"></a>outcomes <span class="op">=</span> [<span class="st">'H'</span>, <span class="st">'T'</span>]</span>
<span id="cb37-9"><a href="#cb37-9" aria-hidden="true" tabindex="-1"></a>two_flips <span class="op">=</span> itertools.combinations_with_replacement(outcomes, <span class="dv">2</span>)</span>
<span id="cb37-10"><a href="#cb37-10" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(two_flips))</span>
<span id="cb37-11"><a href="#cb37-11" aria-hidden="true" tabindex="-1"></a><span class="co"># [('H', 'H'), ('H', 'T'), ('T', 'T')]</span></span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>[(1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (3, 3)]
[('H', 'H'), ('H', 'T'), ('T', 'T')]</code></pre>
</div>
</div>
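<p>You can see the difference between unordered and ordered outcomes by comparing <code>combinations_with_replacement</code> with <code>product</code> on the same coin: the former merges HT and TH into one result. In general, the number of size-<code>r</code> selections with replacement from <code>n</code> items is C(n + r - 1, r).</p>

```python
import itertools
import math

outcomes = ['H', 'T']
unordered = list(itertools.combinations_with_replacement(outcomes, 2))
ordered = list(itertools.product(outcomes, repeat=2))
print(len(unordered), len(ordered))  # 3 4
print(ordered)  # [('H', 'H'), ('H', 'T'), ('T', 'H'), ('T', 'T')]

# Multisets of size r drawn from n items: C(n + r - 1, r)
n, r = 3, 2
count = sum(1 for _ in itertools.combinations_with_replacement(range(n), r))
print(count, math.comb(n + r - 1, r))  # 6 6
```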
<hr>
</section>
</section>
<section id="grouping-and-filtering" class="level2">
<h2 class="anchored" data-anchor-id="grouping-and-filtering" id="grouping-and-filtering">Grouping and Filtering</h2>
<section id="advanced-groupby-examples" class="level3">
<h3 class="anchored" data-anchor-id="advanced-groupby-examples" id="advanced-groupby-examples">Advanced groupby() Examples</h3>
<div id="f232dc9c" class="cell" data-execution_count="21">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb39"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb39-1"><a href="#cb39-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Group by multiple criteria</span></span>
<span id="cb39-2"><a href="#cb39-2" aria-hidden="true" tabindex="-1"></a>data <span class="op">=</span> [</span>
<span id="cb39-3"><a href="#cb39-3" aria-hidden="true" tabindex="-1"></a>    {<span class="st">'name'</span>: <span class="st">'Alice'</span>, <span class="st">'age'</span>: <span class="dv">25</span>, <span class="st">'city'</span>: <span class="st">'New York'</span>},</span>
<span id="cb39-4"><a href="#cb39-4" aria-hidden="true" tabindex="-1"></a>    {<span class="st">'name'</span>: <span class="st">'Bob'</span>, <span class="st">'age'</span>: <span class="dv">25</span>, <span class="st">'city'</span>: <span class="st">'New York'</span>},</span>
<span id="cb39-5"><a href="#cb39-5" aria-hidden="true" tabindex="-1"></a>    {<span class="st">'name'</span>: <span class="st">'Charlie'</span>, <span class="st">'age'</span>: <span class="dv">30</span>, <span class="st">'city'</span>: <span class="st">'Boston'</span>},</span>
<span id="cb39-6"><a href="#cb39-6" aria-hidden="true" tabindex="-1"></a>    {<span class="st">'name'</span>: <span class="st">'David'</span>, <span class="st">'age'</span>: <span class="dv">30</span>, <span class="st">'city'</span>: <span class="st">'Boston'</span>},</span>
<span id="cb39-7"><a href="#cb39-7" aria-hidden="true" tabindex="-1"></a>    {<span class="st">'name'</span>: <span class="st">'Eve'</span>, <span class="st">'age'</span>: <span class="dv">25</span>, <span class="st">'city'</span>: <span class="st">'Boston'</span>}</span>
<span id="cb39-8"><a href="#cb39-8" aria-hidden="true" tabindex="-1"></a>]</span>
<span id="cb39-9"><a href="#cb39-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb39-10"><a href="#cb39-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Group by age and city</span></span>
<span id="cb39-11"><a href="#cb39-11" aria-hidden="true" tabindex="-1"></a>key_func <span class="op">=</span> <span class="kw">lambda</span> x: (x[<span class="st">'age'</span>], x[<span class="st">'city'</span>])</span>
<span id="cb39-12"><a href="#cb39-12" aria-hidden="true" tabindex="-1"></a>sorted_data <span class="op">=</span> <span class="bu">sorted</span>(data, key<span class="op">=</span>key_func)</span>
<span id="cb39-13"><a href="#cb39-13" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> key, group <span class="kw">in</span> itertools.groupby(sorted_data, key<span class="op">=</span>key_func):</span>
<span id="cb39-14"><a href="#cb39-14" aria-hidden="true" tabindex="-1"></a>    age, city <span class="op">=</span> key</span>
<span id="cb39-15"><a href="#cb39-15" aria-hidden="true" tabindex="-1"></a>    names <span class="op">=</span> [person[<span class="st">'name'</span>] <span class="cf">for</span> person <span class="kw">in</span> group]</span>
<span id="cb39-16"><a href="#cb39-16" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Age </span><span class="sc">{</span>age<span class="sc">}</span><span class="ss">, City </span><span class="sc">{</span>city<span class="sc">}</span><span class="ss">: </span><span class="sc">{</span>names<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>Age 25, City Boston: ['Eve']
Age 25, City New York: ['Alice', 'Bob']
Age 30, City Boston: ['Charlie', 'David']</code></pre>
</div>
</div>
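<p><code>groupby()</code> only merges <em>consecutive</em> elements with equal keys. That is why the example above sorts by the same key before grouping. The sketch below uses a trimmed, reordered version of the sample records to show what happens when the sort is skipped. It also uses <code>operator.itemgetter</code> as an equivalent of a lambda key.</p>

```python
import itertools
import operator

people = [
    {'name': 'Alice', 'city': 'New York'},
    {'name': 'Charlie', 'city': 'Boston'},
    {'name': 'Bob', 'city': 'New York'},
]

by_city = operator.itemgetter('city')

# Without sorting, 'New York' shows up as two separate groups
unsorted_groups = [(city, [p['name'] for p in group])
                   for city, group in itertools.groupby(people, key=by_city)]
print(unsorted_groups)
# [('New York', ['Alice']), ('Boston', ['Charlie']), ('New York', ['Bob'])]

# Sorting by the same key first places equal keys next to each other
sorted_groups = [(city, [p['name'] for p in group])
                 for city, group in itertools.groupby(sorted(people, key=by_city), key=by_city)]
print(sorted_groups)
# [('Boston', ['Charlie']), ('New York', ['Alice', 'Bob'])]
```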
</section>
<section id="custom-filtering-patterns" class="level3">
<h3 class="anchored" data-anchor-id="custom-filtering-patterns" id="custom-filtering-patterns">Custom Filtering Patterns</h3>
<div id="d4f0691d" class="cell" data-execution_count="22">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb41"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb41-1"><a href="#cb41-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Filter consecutive duplicates</span></span>
<span id="cb41-2"><a href="#cb41-2" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> remove_consecutive_duplicates(iterable):</span>
<span id="cb41-3"><a href="#cb41-3" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> [key <span class="cf">for</span> key, _ <span class="kw">in</span> itertools.groupby(iterable)]</span>
<span id="cb41-4"><a href="#cb41-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb41-5"><a href="#cb41-5" aria-hidden="true" tabindex="-1"></a>data <span class="op">=</span> [<span class="dv">1</span>, <span class="dv">1</span>, <span class="dv">2</span>, <span class="dv">2</span>, <span class="dv">2</span>, <span class="dv">3</span>, <span class="dv">1</span>, <span class="dv">1</span>, <span class="dv">1</span>, <span class="dv">4</span>]</span>
<span id="cb41-6"><a href="#cb41-6" aria-hidden="true" tabindex="-1"></a>result <span class="op">=</span> remove_consecutive_duplicates(data)</span>
<span id="cb41-7"><a href="#cb41-7" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(result)  <span class="co"># [1, 2, 3, 1, 4]</span></span>
<span id="cb41-8"><a href="#cb41-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb41-9"><a href="#cb41-9" aria-hidden="true" tabindex="-1"></a><span class="co"># Filter with multiple conditions</span></span>
<span id="cb41-10"><a href="#cb41-10" aria-hidden="true" tabindex="-1"></a>numbers <span class="op">=</span> <span class="bu">range</span>(<span class="dv">1</span>, <span class="dv">21</span>)</span>
<span id="cb41-11"><a href="#cb41-11" aria-hidden="true" tabindex="-1"></a><span class="co"># Even numbers not divisible by 4</span></span>
<span id="cb41-12"><a href="#cb41-12" aria-hidden="true" tabindex="-1"></a>filtered <span class="op">=</span> itertools.filterfalse(</span>
<span id="cb41-13"><a href="#cb41-13" aria-hidden="true" tabindex="-1"></a>    <span class="kw">lambda</span> x: x <span class="op">%</span> <span class="dv">2</span> <span class="op">!=</span> <span class="dv">0</span> <span class="kw">or</span> x <span class="op">%</span> <span class="dv">4</span> <span class="op">==</span> <span class="dv">0</span>, numbers</span>
<span id="cb41-14"><a href="#cb41-14" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb41-15"><a href="#cb41-15" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(filtered))  <span class="co"># [2, 6, 10, 14, 18]</span></span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>[1, 2, 3, 1, 4]
[2, 6, 10, 14, 18]</code></pre>
</div>
</div>
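<p><code>groupby()</code> returns each run of equal items along with its key, so the same idea extends naturally to run-length encoding: keep the key and count the run. Here is a minimal sketch. The <code>run_length_encode</code> and <code>run_length_decode</code> names are illustrative.</p>

```python
import itertools


def run_length_encode(iterable):
    """Return (value, run_length) pairs for each run of equal consecutive items."""
    return [(key, sum(1 for _ in group)) for key, group in itertools.groupby(iterable)]


def run_length_decode(pairs):
    """Expand (value, run_length) pairs back into a flat list."""
    return list(itertools.chain.from_iterable(
        itertools.repeat(value, count) for value, count in pairs))


data = [1, 1, 2, 2, 2, 3, 1, 1, 1, 4]
encoded = run_length_encode(data)
print(encoded)  # [(1, 2), (2, 3), (3, 1), (1, 3), (4, 1)]
print(run_length_decode(encoded) == data)  # True
```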
</section>
</section>
<section id="advanced-patterns-and-recipes" class="level2">
<h2 class="anchored" data-anchor-id="advanced-patterns-and-recipes" id="advanced-patterns-and-recipes">Advanced Patterns and Recipes</h2>
<section id="recipe-flatten-nested-iterables" class="level3">
<h3 class="anchored" data-anchor-id="recipe-flatten-nested-iterables" id="recipe-flatten-nested-iterables">Recipe: Flatten Nested Iterables</h3>
<div id="a977f180" class="cell" data-execution_count="23">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb43"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb43-1"><a href="#cb43-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> flatten(nested_iterable):</span>
<span id="cb43-2"><a href="#cb43-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Flatten one level of nesting."""</span></span>
<span id="cb43-3"><a href="#cb43-3" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> itertools.chain.from_iterable(nested_iterable)</span>
<span id="cb43-4"><a href="#cb43-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb43-5"><a href="#cb43-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage</span></span>
<span id="cb43-6"><a href="#cb43-6" aria-hidden="true" tabindex="-1"></a>nested <span class="op">=</span> [[<span class="dv">1</span>, <span class="dv">2</span>], [<span class="dv">3</span>, <span class="dv">4</span>], [<span class="dv">5</span>, <span class="dv">6</span>]]</span>
<span id="cb43-7"><a href="#cb43-7" aria-hidden="true" tabindex="-1"></a>flat <span class="op">=</span> <span class="bu">list</span>(flatten(nested))</span>
<span id="cb43-8"><a href="#cb43-8" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(flat)  <span class="co"># [1, 2, 3, 4, 5, 6]</span></span>
<span id="cb43-9"><a href="#cb43-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb43-10"><a href="#cb43-10" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> deep_flatten(nested_iterable):</span>
<span id="cb43-11"><a href="#cb43-11" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Recursively flatten deeply nested iterables."""</span></span>
<span id="cb43-12"><a href="#cb43-12" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> item <span class="kw">in</span> nested_iterable:</span>
<span id="cb43-13"><a href="#cb43-13" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="bu">hasattr</span>(item, <span class="st">'__iter__'</span>) <span class="kw">and</span> <span class="kw">not</span> <span class="bu">isinstance</span>(item, (<span class="bu">str</span>, <span class="bu">bytes</span>)):</span>
<span id="cb43-14"><a href="#cb43-14" aria-hidden="true" tabindex="-1"></a>            <span class="cf">yield</span> <span class="cf">from</span> deep_flatten(item)</span>
<span id="cb43-15"><a href="#cb43-15" aria-hidden="true" tabindex="-1"></a>        <span class="cf">else</span>:</span>
<span id="cb43-16"><a href="#cb43-16" aria-hidden="true" tabindex="-1"></a>            <span class="cf">yield</span> item</span>
<span id="cb43-17"><a href="#cb43-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb43-18"><a href="#cb43-18" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage</span></span>
<span id="cb43-19"><a href="#cb43-19" aria-hidden="true" tabindex="-1"></a>deeply_nested <span class="op">=</span> [<span class="dv">1</span>, [<span class="dv">2</span>, [<span class="dv">3</span>, <span class="dv">4</span>]], <span class="dv">5</span>, [<span class="dv">6</span>, [<span class="dv">7</span>, [<span class="dv">8</span>, <span class="dv">9</span>]]]]</span>
<span id="cb43-20"><a href="#cb43-20" aria-hidden="true" tabindex="-1"></a>flat <span class="op">=</span> <span class="bu">list</span>(deep_flatten(deeply_nested))</span>
<span id="cb43-21"><a href="#cb43-21" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(flat)  <span class="co"># [1, 2, 3, 4, 5, 6, 7, 8, 9]</span></span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>[1, 2, 3, 4, 5, 6]
[1, 2, 3, 4, 5, 6, 7, 8, 9]</code></pre>
</div>
</div>
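<p><code>chain.from_iterable()</code> pulls inner iterables one at a time. As a result, <code>flatten</code> also works on a generator of sub-lists, even an unbounded one, provided you only consume a finite prefix. The sketch below pairs it with <code>itertools.count</code> and <code>islice</code>. The generator of pairs is illustrative data.</p>

```python
import itertools


def flatten(nested_iterable):
    """Flatten one level of nesting (same as the recipe above)."""
    return itertools.chain.from_iterable(nested_iterable)


# An infinite generator of pairs: [0, 0], [1, 1], [2, 2], ...
pairs = ([n, n] for n in itertools.count())

# Only the first six flattened items are ever computed
first_six = list(itertools.islice(flatten(pairs), 6))
print(first_six)  # [0, 0, 1, 1, 2, 2]
```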
</section>
<section id="recipe-sliding-window" class="level3">
<h3 class="anchored" data-anchor-id="recipe-sliding-window" id="recipe-sliding-window">Recipe: Sliding Window</h3>
<div id="9803599b" class="cell" data-execution_count="24">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb45"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb45-1"><a href="#cb45-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> sliding_window(iterable, n):</span>
<span id="cb45-2"><a href="#cb45-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Create a sliding window of size n."""</span></span>
<span id="cb45-3"><a href="#cb45-3" aria-hidden="true" tabindex="-1"></a>    iterators <span class="op">=</span> itertools.tee(iterable, n)</span>
<span id="cb45-4"><a href="#cb45-4" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> i, it <span class="kw">in</span> <span class="bu">enumerate</span>(iterators):</span>
<span id="cb45-5"><a href="#cb45-5" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Advance each iterator by i positions</span></span>
<span id="cb45-6"><a href="#cb45-6" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(i):</span>
<span id="cb45-7"><a href="#cb45-7" aria-hidden="true" tabindex="-1"></a>            <span class="bu">next</span>(it, <span class="va">None</span>)</span>
<span id="cb45-8"><a href="#cb45-8" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> <span class="bu">zip</span>(<span class="op">*</span>iterators)</span>
<span id="cb45-9"><a href="#cb45-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb45-10"><a href="#cb45-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage</span></span>
<span id="cb45-11"><a href="#cb45-11" aria-hidden="true" tabindex="-1"></a>data <span class="op">=</span> [<span class="dv">1</span>, <span class="dv">2</span>, <span class="dv">3</span>, <span class="dv">4</span>, <span class="dv">5</span>, <span class="dv">6</span>, <span class="dv">7</span>, <span class="dv">8</span>, <span class="dv">9</span>, <span class="dv">10</span>]</span>
<span id="cb45-12"><a href="#cb45-12" aria-hidden="true" tabindex="-1"></a>windows <span class="op">=</span> <span class="bu">list</span>(sliding_window(data, <span class="dv">3</span>))</span>
<span id="cb45-13"><a href="#cb45-13" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(windows)  <span class="co"># [(1, 2, 3), (2, 3, 4), (3, 4, 5), (4, 5, 6), (5, 6, 7), (6, 7, 8), (7, 8, 9), (8, 9, 10)]</span></span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>[(1, 2, 3), (2, 3, 4), (3, 4, 5), (4, 5, 6), (5, 6, 7), (6, 7, 8), (7, 8, 9), (8, 9, 10)]</code></pre>
</div>
</div>
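<p>The <code>tee()</code>-based version creates n iterators over the same source. An alternative, similar to the sliding-window recipe in the Python <code>itertools</code> documentation, uses a single iterator plus a <code>collections.deque</code> bounded with <code>maxlen=n</code>. Appending to a full deque automatically evicts its oldest item. A sketch follows; the <code>sliding_window_deque</code> name is illustrative.</p>

```python
import collections
import itertools


def sliding_window_deque(iterable, n):
    """Yield overlapping n-tuples using one iterator and a bounded deque."""
    it = iter(iterable)
    # Pre-fill with the first n - 1 items; maxlen makes appends drop the oldest item
    window = collections.deque(itertools.islice(it, n - 1), maxlen=n)
    for x in it:
        window.append(x)
        yield tuple(window)


print(list(sliding_window_deque([1, 2, 3, 4, 5], 3)))
# [(1, 2, 3), (2, 3, 4), (3, 4, 5)]
```

<p>If the input has fewer than n items, the loop body never runs, so no windows are produced, which matches the <code>zip</code>-based version.</p>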
</section>
<section id="recipe-roundrobin" class="level3">
<h3 class="anchored" data-anchor-id="recipe-roundrobin" id="recipe-roundrobin">Recipe: Roundrobin</h3>
<div id="c6f79066" class="cell" data-execution_count="25">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb47"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb47-1"><a href="#cb47-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> roundrobin(<span class="op">*</span>iterables):</span>
<span id="cb47-2"><a href="#cb47-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Take elements from iterables in round-robin fashion."""</span></span>
<span id="cb47-3"><a href="#cb47-3" aria-hidden="true" tabindex="-1"></a>    iterators <span class="op">=</span> [<span class="bu">iter</span>(it) <span class="cf">for</span> it <span class="kw">in</span> iterables]</span>
<span id="cb47-4"><a href="#cb47-4" aria-hidden="true" tabindex="-1"></a>    <span class="cf">while</span> iterators:</span>
<span id="cb47-5"><a href="#cb47-5" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> it <span class="kw">in</span> iterators[:]:</span>
<span id="cb47-6"><a href="#cb47-6" aria-hidden="true" tabindex="-1"></a>            <span class="cf">try</span>:</span>
<span id="cb47-7"><a href="#cb47-7" aria-hidden="true" tabindex="-1"></a>                <span class="cf">yield</span> <span class="bu">next</span>(it)</span>
<span id="cb47-8"><a href="#cb47-8" aria-hidden="true" tabindex="-1"></a>            <span class="cf">except</span> <span class="pp">StopIteration</span>:</span>
<span id="cb47-9"><a href="#cb47-9" aria-hidden="true" tabindex="-1"></a>                iterators.remove(it)</span>
<span id="cb47-10"><a href="#cb47-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb47-11"><a href="#cb47-11" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage</span></span>
<span id="cb47-12"><a href="#cb47-12" aria-hidden="true" tabindex="-1"></a>result <span class="op">=</span> <span class="bu">list</span>(roundrobin(<span class="st">'ABC'</span>, <span class="st">'12345'</span>, <span class="st">'xyz'</span>))</span>
<span id="cb47-13"><a href="#cb47-13" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(result)  <span class="co"># ['A', '1', 'x', 'B', '2', 'y', 'C', '3', 'z', '4', '5']</span></span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>['A', '1', 'x', 'B', '2', 'y', 'C', '3', 'z', '4', '5']</code></pre>
</div>
</div>
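<p>The version above copies and rescans the list of iterators on every pass. Another way to express the same interleaving is to pad the shorter inputs with <code>zip_longest()</code> and skip the padding. This sketch uses a private <code>object()</code> sentinel instead of <code>None</code> as the fill value, so genuine <code>None</code> items in the input are not dropped. <code>roundrobin_zip</code> is an illustrative name.</p>

```python
import itertools

# Unique marker that cannot collide with any real input value
_SENTINEL = object()


def roundrobin_zip(*iterables):
    """Interleave iterables by padding with a sentinel and skipping it."""
    for column in itertools.zip_longest(*iterables, fillvalue=_SENTINEL):
        for item in column:
            if item is not _SENTINEL:
                yield item


print(list(roundrobin_zip('ABC', '12345', 'xyz')))
# ['A', '1', 'x', 'B', '2', 'y', 'C', '3', 'z', '4', '5']
```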
</section>
<section id="recipe-unique-elements-preserving-order" class="level3">
<h3 class="anchored" data-anchor-id="recipe-unique-elements-preserving-order" id="recipe-unique-elements-preserving-order">Recipe: Unique Elements (Preserving Order)</h3>
<div id="0d19a2a9" class="cell" data-execution_count="26">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb49"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb49-1"><a href="#cb49-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> unique_everseen(iterable, key<span class="op">=</span><span class="va">None</span>):</span>
<span id="cb49-2"><a href="#cb49-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""List unique elements, preserving order."""</span></span>
<span id="cb49-3"><a href="#cb49-3" aria-hidden="true" tabindex="-1"></a>    seen <span class="op">=</span> <span class="bu">set</span>()</span>
<span id="cb49-4"><a href="#cb49-4" aria-hidden="true" tabindex="-1"></a>    seen_add <span class="op">=</span> seen.add</span>
<span id="cb49-5"><a href="#cb49-5" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> key <span class="kw">is</span> <span class="va">None</span>:</span>
<span id="cb49-6"><a href="#cb49-6" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> element <span class="kw">in</span> itertools.filterfalse(seen.<span class="fu">__contains__</span>, iterable):</span>
<span id="cb49-7"><a href="#cb49-7" aria-hidden="true" tabindex="-1"></a>            seen_add(element)</span>
<span id="cb49-8"><a href="#cb49-8" aria-hidden="true" tabindex="-1"></a>            <span class="cf">yield</span> element</span>
<span id="cb49-9"><a href="#cb49-9" aria-hidden="true" tabindex="-1"></a>    <span class="cf">else</span>:</span>
<span id="cb49-10"><a href="#cb49-10" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> element <span class="kw">in</span> iterable:</span>
<span id="cb49-11"><a href="#cb49-11" aria-hidden="true" tabindex="-1"></a>            k <span class="op">=</span> key(element)</span>
<span id="cb49-12"><a href="#cb49-12" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> k <span class="kw">not</span> <span class="kw">in</span> seen:</span>
<span id="cb49-13"><a href="#cb49-13" aria-hidden="true" tabindex="-1"></a>                seen_add(k)</span>
<span id="cb49-14"><a href="#cb49-14" aria-hidden="true" tabindex="-1"></a>                <span class="cf">yield</span> element</span>
<span id="cb49-15"><a href="#cb49-15" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb49-16"><a href="#cb49-16" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage</span></span>
<span id="cb49-17"><a href="#cb49-17" aria-hidden="true" tabindex="-1"></a>data <span class="op">=</span> [<span class="dv">1</span>, <span class="dv">2</span>, <span class="dv">3</span>, <span class="dv">2</span>, <span class="dv">4</span>, <span class="dv">1</span>, <span class="dv">5</span>, <span class="dv">3</span>, <span class="dv">6</span>]</span>
<span id="cb49-18"><a href="#cb49-18" aria-hidden="true" tabindex="-1"></a>unique <span class="op">=</span> <span class="bu">list</span>(unique_everseen(data))</span>
<span id="cb49-19"><a href="#cb49-19" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(unique)  <span class="co"># [1, 2, 3, 4, 5, 6]</span></span>
<span id="cb49-20"><a href="#cb49-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb49-21"><a href="#cb49-21" aria-hidden="true" tabindex="-1"></a><span class="co"># With key function</span></span>
<span id="cb49-22"><a href="#cb49-22" aria-hidden="true" tabindex="-1"></a>words <span class="op">=</span> [<span class="st">'apple'</span>, <span class="st">'Banana'</span>, <span class="st">'cherry'</span>, <span class="st">'Apple'</span>, <span class="st">'banana'</span>]</span>
<span id="cb49-23"><a href="#cb49-23" aria-hidden="true" tabindex="-1"></a>unique_words <span class="op">=</span> <span class="bu">list</span>(unique_everseen(words, key<span class="op">=</span><span class="bu">str</span>.lower))</span>
<span id="cb49-24"><a href="#cb49-24" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(unique_words)  <span class="co"># ['apple', 'Banana', 'cherry']</span></span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>[1, 2, 3, 4, 5, 6]
['apple', 'Banana', 'cherry']</code></pre>
</div>
</div>
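<p>When the items are hashable and a list is all you need, plain dictionaries give the same order-preserving result, because dict keys are unique and keep insertion order (guaranteed from Python 3.7). Unlike the <code>unique_everseen</code> generator, these versions are eager: they consume the whole input before returning.</p>

```python
data = [1, 2, 3, 2, 4, 1, 5, 3, 6]

# Keeps the first occurrence of each hashable item, in order
unique = list(dict.fromkeys(data))
print(unique)  # [1, 2, 3, 4, 5, 6]

# Keyed variant: remember the first element seen for each key
words = ['apple', 'Banana', 'cherry', 'Apple', 'banana']
first_by_key = {}
for word in words:
    first_by_key.setdefault(word.lower(), word)
print(list(first_by_key.values()))  # ['apple', 'Banana', 'cherry']
```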
<hr>
</section>
</section>
<section id="practical-examples-and-use-cases" class="level2">
<h2 class="anchored" data-anchor-id="practical-examples-and-use-cases" id="practical-examples-and-use-cases">Practical Examples and Use Cases</h2>
<section id="data-processing-pipeline" class="level3">
<h3 class="anchored" data-anchor-id="data-processing-pipeline" id="data-processing-pipeline">1. Data Processing Pipeline</h3>
<div id="2c21e6ed" class="cell" data-execution_count="27">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb51"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb51-1"><a href="#cb51-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> itertools</span>
<span id="cb51-2"><a href="#cb51-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> operator</span>
<span id="cb51-3"><a href="#cb51-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb51-4"><a href="#cb51-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Sample data</span></span>
<span id="cb51-5"><a href="#cb51-5" aria-hidden="true" tabindex="-1"></a>sales_data <span class="op">=</span> [</span>
<span id="cb51-6"><a href="#cb51-6" aria-hidden="true" tabindex="-1"></a>    (<span class="st">'Q1'</span>, <span class="st">'Product A'</span>, <span class="dv">100</span>),</span>
<span id="cb51-7"><a href="#cb51-7" aria-hidden="true" tabindex="-1"></a>    (<span class="st">'Q1'</span>, <span class="st">'Product B'</span>, <span class="dv">150</span>),</span>
<span id="cb51-8"><a href="#cb51-8" aria-hidden="true" tabindex="-1"></a>    (<span class="st">'Q2'</span>, <span class="st">'Product A'</span>, <span class="dv">120</span>),</span>
<span id="cb51-9"><a href="#cb51-9" aria-hidden="true" tabindex="-1"></a>    (<span class="st">'Q2'</span>, <span class="st">'Product B'</span>, <span class="dv">180</span>),</span>
<span id="cb51-10"><a href="#cb51-10" aria-hidden="true" tabindex="-1"></a>    (<span class="st">'Q3'</span>, <span class="st">'Product A'</span>, <span class="dv">110</span>),</span>
<span id="cb51-11"><a href="#cb51-11" aria-hidden="true" tabindex="-1"></a>    (<span class="st">'Q3'</span>, <span class="st">'Product B'</span>, <span class="dv">160</span>),</span>
<span id="cb51-12"><a href="#cb51-12" aria-hidden="true" tabindex="-1"></a>]</span>
<span id="cb51-13"><a href="#cb51-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb51-14"><a href="#cb51-14" aria-hidden="true" tabindex="-1"></a><span class="co"># Group by quarter and total the sales; groupby only merges adjacent items, so the data must already be ordered by quarter</span></span>
<span id="cb51-15"><a href="#cb51-15" aria-hidden="true" tabindex="-1"></a>sales_by_quarter <span class="op">=</span> itertools.groupby(sales_data, key<span class="op">=</span>operator.itemgetter(<span class="dv">0</span>))</span>
<span id="cb51-16"><a href="#cb51-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb51-17"><a href="#cb51-17" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> quarter, sales <span class="kw">in</span> sales_by_quarter:</span>
<span id="cb51-18"><a href="#cb51-18" aria-hidden="true" tabindex="-1"></a>    total <span class="op">=</span> <span class="bu">sum</span>(sale[<span class="dv">2</span>] <span class="cf">for</span> sale <span class="kw">in</span> sales)</span>
<span id="cb51-19"><a href="#cb51-19" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"</span><span class="sc">{</span>quarter<span class="sc">}</span><span class="ss">: </span><span class="sc">{</span>total<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>Q1: 250
Q2: 300
Q3: 270</code></pre>
</div>
</div>
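<p>The quarterly totals work without an explicit sort only because <code>sales_data</code> is already ordered by quarter. Grouping by product needs a sort first, because the two products alternate in the list. After that, <code>itertools.accumulate()</code> can turn the quarterly totals into a running total. A sketch reusing the same sample data:</p>

```python
import itertools
import operator

sales_data = [
    ('Q1', 'Product A', 100),
    ('Q1', 'Product B', 150),
    ('Q2', 'Product A', 120),
    ('Q2', 'Product B', 180),
    ('Q3', 'Product A', 110),
    ('Q3', 'Product B', 160),
]

by_product = operator.itemgetter(1)

# Products alternate, so sort by product before grouping
product_totals = {
    product: sum(amount for _, _, amount in rows)
    for product, rows in itertools.groupby(sorted(sales_data, key=by_product), key=by_product)
}
print(product_totals)  # {'Product A': 330, 'Product B': 490}

# Running (cumulative) total across quarters
quarter_totals = [sum(amount for _, _, amount in rows)
                  for _, rows in itertools.groupby(sales_data, key=operator.itemgetter(0))]
print(list(itertools.accumulate(quarter_totals)))  # [250, 550, 820]
```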
</section>
<section id="batch-processing" class="level3">
<h3 class="anchored" data-anchor-id="batch-processing" id="batch-processing">2. Batch Processing</h3>
<div id="31cd6867" class="cell" data-execution_count="28">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb53"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb53-1"><a href="#cb53-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> batch_process(iterable, batch_size):</span>
<span id="cb53-2"><a href="#cb53-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Yield successive lists of up to batch_size items."""</span></span>
<span id="cb53-3"><a href="#cb53-3" aria-hidden="true" tabindex="-1"></a>    iterator <span class="op">=</span> <span class="bu">iter</span>(iterable)</span>
<span id="cb53-4"><a href="#cb53-4" aria-hidden="true" tabindex="-1"></a>    <span class="cf">while</span> <span class="va">True</span>:</span>
<span id="cb53-5"><a href="#cb53-5" aria-hidden="true" tabindex="-1"></a>        batch <span class="op">=</span> <span class="bu">list</span>(itertools.islice(iterator, batch_size))</span>
<span id="cb53-6"><a href="#cb53-6" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="kw">not</span> batch:</span>
<span id="cb53-7"><a href="#cb53-7" aria-hidden="true" tabindex="-1"></a>            <span class="cf">break</span></span>
<span id="cb53-8"><a href="#cb53-8" aria-hidden="true" tabindex="-1"></a>        <span class="cf">yield</span> batch</span>
<span id="cb53-9"><a href="#cb53-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb53-10"><a href="#cb53-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Example usage</span></span>
<span id="cb53-11"><a href="#cb53-11" aria-hidden="true" tabindex="-1"></a>data <span class="op">=</span> <span class="bu">range</span>(<span class="dv">25</span>)</span>
<span id="cb53-12"><a href="#cb53-12" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> batch <span class="kw">in</span> batch_process(data, <span class="dv">10</span>):</span>
<span id="cb53-13"><a href="#cb53-13" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Processing batch: </span><span class="sc">{</span>batch<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>Processing batch: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Processing batch: [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
Processing batch: [20, 21, 22, 23, 24]</code></pre>
</div>
</div>
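<p>Python 3.12 added <code>itertools.batched()</code>, which does the same job as <code>batch_process</code> but yields tuples instead of lists. The sketch below is version-tolerant: it uses <code>batched()</code> when available and otherwise falls back to the <code>islice()</code> loop. The <code>batches</code> name is illustrative.</p>

```python
import itertools


def batches(iterable, size):
    """Yield tuples of up to `size` items, preferring itertools.batched if present."""
    if hasattr(itertools, "batched"):  # Python 3.12+
        yield from itertools.batched(iterable, size)
        return
    it = iter(iterable)
    # An empty tuple is falsy, which ends the loop once the input is exhausted
    while chunk := tuple(itertools.islice(it, size)):
        yield chunk


print(list(batches(range(7), 3)))  # [(0, 1, 2), (3, 4, 5), (6,)]
```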
</section>
<section id="round-robin-scheduler" class="level3">
<h3 class="anchored" data-anchor-id="round-robin-scheduler" id="round-robin-scheduler">3. Round-Robin Scheduler</h3>
<div id="94ee84d0" class="cell" data-execution_count="29">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb55"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb55-1"><a href="#cb55-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> round_robin_scheduler(tasks, workers):</span>
<span id="cb55-2"><a href="#cb55-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Distribute tasks among workers in round-robin fashion"""</span></span>
<span id="cb55-3"><a href="#cb55-3" aria-hidden="true" tabindex="-1"></a>    worker_cycle <span class="op">=</span> itertools.cycle(workers)</span>
<span id="cb55-4"><a href="#cb55-4" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> <span class="bu">list</span>(<span class="bu">zip</span>(tasks, worker_cycle))</span>
<span id="cb55-5"><a href="#cb55-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb55-6"><a href="#cb55-6" aria-hidden="true" tabindex="-1"></a>tasks <span class="op">=</span> [<span class="st">'task1'</span>, <span class="st">'task2'</span>, <span class="st">'task3'</span>, <span class="st">'task4'</span>, <span class="st">'task5'</span>]</span>
<span id="cb55-7"><a href="#cb55-7" aria-hidden="true" tabindex="-1"></a>workers <span class="op">=</span> [<span class="st">'Alice'</span>, <span class="st">'Bob'</span>, <span class="st">'Charlie'</span>]</span>
<span id="cb55-8"><a href="#cb55-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb55-9"><a href="#cb55-9" aria-hidden="true" tabindex="-1"></a>schedule <span class="op">=</span> round_robin_scheduler(tasks, workers)</span>
<span id="cb55-10"><a href="#cb55-10" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> task, worker <span class="kw">in</span> schedule:</span>
<span id="cb55-11"><a href="#cb55-11" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"</span><span class="sc">{</span>task<span class="sc">}</span><span class="ss"> -&gt; </span><span class="sc">{</span>worker<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>task1 -&gt; Alice
task2 -&gt; Bob
task3 -&gt; Charlie
task4 -&gt; Alice
task5 -&gt; Bob</code></pre>
</div>
</div>
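<p><code>zip()</code> stops as soon as its shortest input is exhausted. That is what keeps the infinite <code>cycle()</code> from running forever: the schedule ends when <code>tasks</code> runs out. To see each worker's load, you can collect the same pairs per worker. Here is a sketch with a <code>collections.defaultdict</code>, reusing the sample tasks and workers:</p>

```python
import collections
import itertools

tasks = ['task1', 'task2', 'task3', 'task4', 'task5']
workers = ['Alice', 'Bob', 'Charlie']

# zip ends with the finite task list, so the infinite cycle is safe here
assignments = collections.defaultdict(list)
for task, worker in zip(tasks, itertools.cycle(workers)):
    assignments[worker].append(task)

print(dict(assignments))
# {'Alice': ['task1', 'task4'], 'Bob': ['task2', 'task5'], 'Charlie': ['task3']}
```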
</section>
<section id="sliding-window" class="level3">
<h3 class="anchored" data-anchor-id="sliding-window" id="sliding-window">4. Sliding Window</h3>
<div id="fed1029e" class="cell" data-execution_count="30">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb57"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb57-1"><a href="#cb57-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> sliding_window(iterable, window_size):</span>
<span id="cb57-2"><a href="#cb57-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Yield overlapping tuples of window_size consecutive items."""</span></span>
<span id="cb57-3"><a href="#cb57-3" aria-hidden="true" tabindex="-1"></a>    iterators <span class="op">=</span> itertools.tee(iterable, window_size)</span>
<span id="cb57-4"><a href="#cb57-4" aria-hidden="true" tabindex="-1"></a>    iterators <span class="op">=</span> [itertools.islice(iterator, i, <span class="va">None</span>) </span>
<span id="cb57-5"><a href="#cb57-5" aria-hidden="true" tabindex="-1"></a>                <span class="cf">for</span> i, iterator <span class="kw">in</span> <span class="bu">enumerate</span>(iterators)]</span>
<span id="cb57-6"><a href="#cb57-6" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> <span class="bu">zip</span>(<span class="op">*</span>iterators)</span>
<span id="cb57-7"><a href="#cb57-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb57-8"><a href="#cb57-8" aria-hidden="true" tabindex="-1"></a><span class="co"># Example usage</span></span>
<span id="cb57-9"><a href="#cb57-9" aria-hidden="true" tabindex="-1"></a>data <span class="op">=</span> [<span class="dv">1</span>, <span class="dv">2</span>, <span class="dv">3</span>, <span class="dv">4</span>, <span class="dv">5</span>, <span class="dv">6</span>, <span class="dv">7</span>, <span class="dv">8</span>, <span class="dv">9</span>, <span class="dv">10</span>]</span>
<span id="cb57-10"><a href="#cb57-10" aria-hidden="true" tabindex="-1"></a>windows <span class="op">=</span> sliding_window(data, <span class="dv">3</span>)</span>
<span id="cb57-11"><a href="#cb57-11" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> window <span class="kw">in</span> windows:</span>
<span id="cb57-12"><a href="#cb57-12" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(window)</span>
<span id="cb57-13"><a href="#cb57-13" aria-hidden="true" tabindex="-1"></a><span class="co"># (1, 2, 3)</span></span>
<span id="cb57-14"><a href="#cb57-14" aria-hidden="true" tabindex="-1"></a><span class="co"># (2, 3, 4)</span></span>
<span id="cb57-15"><a href="#cb57-15" aria-hidden="true" tabindex="-1"></a><span class="co"># (3, 4, 5)</span></span>
<span id="cb57-16"><a href="#cb57-16" aria-hidden="true" tabindex="-1"></a><span class="co"># ...</span></span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>(1, 2, 3)
(2, 3, 4)
(3, 4, 5)
(4, 5, 6)
(5, 6, 7)
(6, 7, 8)
(7, 8, 9)
(8, 9, 10)</code></pre>
</div>
</div>
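<p>The <code>tee()</code>-based version above creates <code>window_size</code> separate iterators. A single-pass alternative, modelled on the <code>sliding_window</code> recipe in the official itertools documentation, keeps the current window in a bounded <code>collections.deque</code>. The function name <code>sliding_window_deque</code> is hypothetical, and this is a sketch rather than a drop-in replacement.</p>

```python
import collections
import itertools

def sliding_window_deque(iterable, window_size):
    """Yield overlapping tuples of window_size items in a single pass (hypothetical helper)."""
    iterator = iter(iterable)
    # Pre-fill with the first window_size - 1 items; maxlen makes the deque drop its oldest item.
    window = collections.deque(itertools.islice(iterator, window_size - 1), maxlen=window_size)
    for item in iterator:
        window.append(item)
        yield tuple(window)

print(list(sliding_window_deque([1, 2, 3, 4, 5], 3)))
# [(1, 2, 3), (2, 3, 4), (3, 4, 5)]
```

Like the <code>tee()</code> version, it yields nothing when the input is shorter than the window.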
</section>
<section id="pairwise-iteration" class="level3">
<h3 class="anchored" data-anchor-id="pairwise-iteration" id="pairwise-iteration">5. Pairwise Iteration</h3>
<div id="5d54c14c" class="cell" data-execution_count="31">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb59"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb59-1"><a href="#cb59-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> pairwise(iterable):</span>
<span id="cb59-2"><a href="#cb59-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Return successive overlapping pairs"""</span></span>
<span id="cb59-3"><a href="#cb59-3" aria-hidden="true" tabindex="-1"></a>    a, b <span class="op">=</span> itertools.tee(iterable)</span>
<span id="cb59-4"><a href="#cb59-4" aria-hidden="true" tabindex="-1"></a>    <span class="bu">next</span>(b, <span class="va">None</span>)</span>
<span id="cb59-5"><a href="#cb59-5" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> <span class="bu">zip</span>(a, b)</span>
<span id="cb59-6"><a href="#cb59-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb59-7"><a href="#cb59-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Example usage</span></span>
<span id="cb59-8"><a href="#cb59-8" aria-hidden="true" tabindex="-1"></a>numbers <span class="op">=</span> [<span class="dv">1</span>, <span class="dv">2</span>, <span class="dv">3</span>, <span class="dv">4</span>, <span class="dv">5</span>]</span>
<span id="cb59-9"><a href="#cb59-9" aria-hidden="true" tabindex="-1"></a>pairs <span class="op">=</span> pairwise(numbers)</span>
<span id="cb59-10"><a href="#cb59-10" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> pair <span class="kw">in</span> pairs:</span>
<span id="cb59-11"><a href="#cb59-11" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(pair)</span>
<span id="cb59-12"><a href="#cb59-12" aria-hidden="true" tabindex="-1"></a><span class="co"># (1, 2)</span></span>
<span id="cb59-13"><a href="#cb59-13" aria-hidden="true" tabindex="-1"></a><span class="co"># (2, 3)</span></span>
<span id="cb59-14"><a href="#cb59-14" aria-hidden="true" tabindex="-1"></a><span class="co"># (3, 4)</span></span>
<span id="cb59-15"><a href="#cb59-15" aria-hidden="true" tabindex="-1"></a><span class="co"># (4, 5)</span></span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>(1, 2)
(2, 3)
(3, 4)
(4, 5)</code></pre>
</div>
</div>
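<p>Since Python 3.10 this function ships in the standard library as <code>itertools.pairwise()</code>. The sketch below prefers the built-in when it exists and otherwise falls back to the <code>tee()</code>-based version shown above.</p>

```python
import itertools

try:
    from itertools import pairwise  # available in Python 3.10 and later
except ImportError:
    def pairwise(iterable):
        """Fallback for older Pythons: successive overlapping pairs via tee()."""
        a, b = itertools.tee(iterable)
        next(b, None)
        return zip(a, b)

print(list(pairwise("ABCD")))  # [('A', 'B'), ('B', 'C'), ('C', 'D')]
```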
<hr>
</section>
</section>
<section id="performance-tips" class="level2">
<h2 class="anchored" data-anchor-id="performance-tips" id="performance-tips">Performance Tips</h2>
<section id="memory-efficiency" class="level3">
<h3 class="anchored" data-anchor-id="memory-efficiency" id="memory-efficiency">1. Memory Efficiency</h3>
<div id="4741ecba" class="cell" data-execution_count="32">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb61"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb61-1"><a href="#cb61-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Bad: builds two full lists (the numbers and their squares) in memory</span></span>
<span id="cb61-2"><a href="#cb61-2" aria-hidden="true" tabindex="-1"></a>large_range <span class="op">=</span> <span class="bu">list</span>(<span class="bu">range</span>(<span class="dv">1000000</span>))</span>
<span id="cb61-3"><a href="#cb61-3" aria-hidden="true" tabindex="-1"></a>squared <span class="op">=</span> [x<span class="op">**</span><span class="dv">2</span> <span class="cf">for</span> x <span class="kw">in</span> large_range]</span>
<span id="cb61-4"><a href="#cb61-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb61-5"><a href="#cb61-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Good: range() and map() are lazy, so each square is computed only when requested</span></span>
<span id="cb61-6"><a href="#cb61-6" aria-hidden="true" tabindex="-1"></a>large_range <span class="op">=</span> <span class="bu">range</span>(<span class="dv">1000000</span>)</span>
<span id="cb61-7"><a href="#cb61-7" aria-hidden="true" tabindex="-1"></a>squared <span class="op">=</span> <span class="bu">map</span>(<span class="kw">lambda</span> x: x<span class="op">**</span><span class="dv">2</span>, large_range)</span></code></pre></div></div>
</div>
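<p>One way to see the difference is <code>sys.getsizeof()</code>, which reports the size of the container object itself (not of the integers it refers to). Exact byte counts vary by Python version and platform, so the sketch below only compares them. It also uses a generator expression, which is just as lazy as <code>map()</code> with a <code>lambda</code> and often easier to read.</p>

```python
import sys

n = 1_000_000
as_list = list(range(n))               # one million references stored up front
as_range = range(n)                    # stores only start, stop and step
squares = (x ** 2 for x in as_range)   # generator: nothing computed yet

print(sys.getsizeof(as_list) > sys.getsizeof(as_range))  # True
print(sys.getsizeof(squares) < sys.getsizeof(as_list))   # True

total = sum(squares)  # pulls one square at a time through the generator
print(total)
```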
</section>
<section id="lazy-evaluation" class="level3">
<h3 class="anchored" data-anchor-id="lazy-evaluation" id="lazy-evaluation">2. Lazy Evaluation</h3>
<div id="c161b33d" class="cell" data-execution_count="33">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb62"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb62-1"><a href="#cb62-1" aria-hidden="true" tabindex="-1"></a><span class="co"># itertools functions are lazy: they compute nothing until values are requested</span></span>
<span id="cb62-2"><a href="#cb62-2" aria-hidden="true" tabindex="-1"></a>data <span class="op">=</span> <span class="bu">range</span>(<span class="dv">1000000</span>)</span>
<span id="cb62-3"><a href="#cb62-3" aria-hidden="true" tabindex="-1"></a>filtered <span class="op">=</span> itertools.filterfalse(<span class="kw">lambda</span> x: x <span class="op">%</span> <span class="dv">2</span> <span class="op">==</span> <span class="dv">0</span>, data)</span>
<span id="cb62-4"><a href="#cb62-4" aria-hidden="true" tabindex="-1"></a><span class="co"># No computation happens here yet</span></span>
<span id="cb62-5"><a href="#cb62-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb62-6"><a href="#cb62-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Only compute what you need</span></span>
<span id="cb62-7"><a href="#cb62-7" aria-hidden="true" tabindex="-1"></a>first_10_odds <span class="op">=</span> <span class="bu">list</span>(itertools.islice(filtered, <span class="dv">10</span>))</span></code></pre></div></div>
</div>
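<p>To make the laziness visible, the sketch below replaces the <code>lambda</code> with a predicate (the hypothetical <code>is_even</code>) that records every value it is asked about. Creating the <code>filterfalse</code> object checks nothing, and taking ten odd numbers examines only the first twenty integers out of a million.</p>

```python
import itertools

checked = []  # every value the predicate has been called with

def is_even(x):
    checked.append(x)
    return x % 2 == 0

filtered = itertools.filterfalse(is_even, range(1_000_000))
print(len(checked))  # 0: creating the iterator evaluates nothing

first_10_odds = list(itertools.islice(filtered, 10))
print(first_10_odds)  # [1, 3, 5, 7, 9, 11, 13, 15, 17, 19]
print(len(checked))   # 20: only 0..19 were ever examined
```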
</section>
<section id="chaining-operations" class="level3">
<h3 class="anchored" data-anchor-id="chaining-operations" id="chaining-operations">3. Chaining Operations</h3>
<div id="46bdd5a4" class="cell" data-execution_count="34">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb63"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb63-1"><a href="#cb63-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Chain multiple itertools operations for complex processing</span></span>
<span id="cb63-2"><a href="#cb63-2" aria-hidden="true" tabindex="-1"></a>data <span class="op">=</span> <span class="bu">range</span>(<span class="dv">100</span>)</span>
<span id="cb63-3"><a href="#cb63-3" aria-hidden="true" tabindex="-1"></a>result <span class="op">=</span> itertools.takewhile(</span>
<span id="cb63-4"><a href="#cb63-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">lambda</span> x: x <span class="op">&lt;</span> <span class="dv">50</span>,</span>
<span id="cb63-5"><a href="#cb63-5" aria-hidden="true" tabindex="-1"></a>    itertools.filterfalse(</span>
<span id="cb63-6"><a href="#cb63-6" aria-hidden="true" tabindex="-1"></a>        <span class="kw">lambda</span> x: x <span class="op">%</span> <span class="dv">3</span> <span class="op">==</span> <span class="dv">0</span>,</span>
<span id="cb63-7"><a href="#cb63-7" aria-hidden="true" tabindex="-1"></a>        itertools.accumulate(data)</span>
<span id="cb63-8"><a href="#cb63-8" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb63-9"><a href="#cb63-9" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
</div>
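<p>The nested call reads inside-out, and <code>result</code> stays an unconsumed iterator until something like <code>list()</code> pulls values through it. Splitting the same pipeline into named stages (the stage names are illustrative) makes the order explicit: <code>accumulate()</code> produces running totals (0, 1, 3, 6, 10, ...), <code>filterfalse()</code> drops multiples of 3, and <code>takewhile()</code> stops at the first remaining total that is 50 or more.</p>

```python
import itertools

data = range(100)
running_totals = itertools.accumulate(data)  # 0, 1, 3, 6, 10, 15, 21, 28, ...
not_multiple_of_3 = itertools.filterfalse(lambda x: x % 3 == 0, running_totals)
below_50 = itertools.takewhile(lambda x: x < 50, not_multiple_of_3)

result = list(below_50)
print(result)  # [1, 10, 28]
```

Because <code>takewhile()</code> stops as soon as it sees 55, only the first eleven running totals are ever computed; the remaining values in <code>range(100)</code> are never touched.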
<hr>
</section>
</section>
<section id="common-patterns-and-recipes" class="level2">
<h2 class="anchored" data-anchor-id="common-patterns-and-recipes" id="common-patterns-and-recipes">Common Patterns and Recipes</h2>
<section id="flatten-nested-iterables" class="level3">
<h3 class="anchored" data-anchor-id="flatten-nested-iterables" id="flatten-nested-iterables">1. Flatten Nested Iterables</h3>
<div id="d7c7bf16" class="cell" data-execution_count="35">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb64"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb64-1"><a href="#cb64-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> flatten(nested_iterable):</span>
<span id="cb64-2"><a href="#cb64-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Completely flatten a nested iterable"""</span></span>
<span id="cb64-3"><a href="#cb64-3" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> item <span class="kw">in</span> nested_iterable:</span>
<span id="cb64-4"><a href="#cb64-4" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="bu">hasattr</span>(item, <span class="st">'__iter__'</span>) <span class="kw">and</span> <span class="kw">not</span> <span class="bu">isinstance</span>(item, (<span class="bu">str</span>, <span class="bu">bytes</span>)):</span>
<span id="cb64-5"><a href="#cb64-5" aria-hidden="true" tabindex="-1"></a>            <span class="cf">yield</span> <span class="cf">from</span> flatten(item)</span>
<span id="cb64-6"><a href="#cb64-6" aria-hidden="true" tabindex="-1"></a>        <span class="cf">else</span>:</span>
<span id="cb64-7"><a href="#cb64-7" aria-hidden="true" tabindex="-1"></a>            <span class="cf">yield</span> item</span>
<span id="cb64-8"><a href="#cb64-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb64-9"><a href="#cb64-9" aria-hidden="true" tabindex="-1"></a><span class="co"># Example</span></span>
<span id="cb64-10"><a href="#cb64-10" aria-hidden="true" tabindex="-1"></a>nested <span class="op">=</span> [<span class="dv">1</span>, [<span class="dv">2</span>, <span class="dv">3</span>], [<span class="dv">4</span>, [<span class="dv">5</span>, <span class="dv">6</span>]], <span class="dv">7</span>]</span>
<span id="cb64-11"><a href="#cb64-11" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(flatten(nested)))  <span class="co"># [1, 2, 3, 4, 5, 6, 7]</span></span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>[1, 2, 3, 4, 5, 6, 7]</code></pre>
</div>
</div>
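<p>When the nesting is only one level deep, <code>itertools.chain.from_iterable()</code> does the same job lazily and without recursion. Unlike <code>flatten()</code> above, it does not descend into deeper levels, as the second call shows.</p>

```python
import itertools

one_level = [[1, 2], [3, 4], [5]]
flat = list(itertools.chain.from_iterable(one_level))
print(flat)  # [1, 2, 3, 4, 5]

# Deeper lists are left intact: only the outer level is removed
partly_flat = list(itertools.chain.from_iterable([[1, [2, 3]], [4]]))
print(partly_flat)  # [1, [2, 3], 4]
```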
</section>
<section id="unique-elements-preserving-order" class="level3">
<h3 class="anchored" data-anchor-id="unique-elements-preserving-order" id="unique-elements-preserving-order">2. Unique Elements (Preserving Order)</h3>
<div id="8b64cfb1" class="cell" data-execution_count="36">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb66"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb66-1"><a href="#cb66-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> unique_everseen(iterable, key<span class="op">=</span><span class="va">None</span>):</span>
<span id="cb66-2"><a href="#cb66-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""List unique elements, preserving order"""</span></span>
<span id="cb66-3"><a href="#cb66-3" aria-hidden="true" tabindex="-1"></a>    seen <span class="op">=</span> <span class="bu">set</span>()</span>
<span id="cb66-4"><a href="#cb66-4" aria-hidden="true" tabindex="-1"></a>    seen_add <span class="op">=</span> seen.add</span>
<span id="cb66-5"><a href="#cb66-5" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> key <span class="kw">is</span> <span class="va">None</span>:</span>
<span id="cb66-6"><a href="#cb66-6" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> element <span class="kw">in</span> itertools.filterfalse(seen.<span class="fu">__contains__</span>, iterable):</span>
<span id="cb66-7"><a href="#cb66-7" aria-hidden="true" tabindex="-1"></a>            seen_add(element)</span>
<span id="cb66-8"><a href="#cb66-8" aria-hidden="true" tabindex="-1"></a>            <span class="cf">yield</span> element</span>
<span id="cb66-9"><a href="#cb66-9" aria-hidden="true" tabindex="-1"></a>    <span class="cf">else</span>:</span>
<span id="cb66-10"><a href="#cb66-10" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> element <span class="kw">in</span> iterable:</span>
<span id="cb66-11"><a href="#cb66-11" aria-hidden="true" tabindex="-1"></a>            k <span class="op">=</span> key(element)</span>
<span id="cb66-12"><a href="#cb66-12" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> k <span class="kw">not</span> <span class="kw">in</span> seen:</span>
<span id="cb66-13"><a href="#cb66-13" aria-hidden="true" tabindex="-1"></a>                seen_add(k)</span>
<span id="cb66-14"><a href="#cb66-14" aria-hidden="true" tabindex="-1"></a>                <span class="cf">yield</span> element</span>
<span id="cb66-15"><a href="#cb66-15" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb66-16"><a href="#cb66-16" aria-hidden="true" tabindex="-1"></a><span class="co"># Example</span></span>
<span id="cb66-17"><a href="#cb66-17" aria-hidden="true" tabindex="-1"></a>data <span class="op">=</span> [<span class="dv">1</span>, <span class="dv">2</span>, <span class="dv">3</span>, <span class="dv">2</span>, <span class="dv">1</span>, <span class="dv">4</span>, <span class="dv">3</span>, <span class="dv">5</span>]</span>
<span id="cb66-18"><a href="#cb66-18" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">list</span>(unique_everseen(data)))  <span class="co"># [1, 2, 3, 4, 5]</span></span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>[1, 2, 3, 4, 5]</code></pre>
</div>
</div>
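<p>For hashable elements and no <code>key</code> function, a common shortcut is <code>dict.fromkeys()</code>: dictionaries preserve insertion order (guaranteed since Python 3.7), so the keys come out in first-seen order. Unlike <code>unique_everseen()</code>, this builds the whole result eagerly instead of yielding lazily.</p>

```python
data = [1, 2, 3, 2, 1, 4, 3, 5]
unique = list(dict.fromkeys(data))  # duplicate keys collapse; first insertion order is kept
print(unique)  # [1, 2, 3, 4, 5]
```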
</section>
<section id="consume-iterator" class="level3">
<h3 class="anchored" data-anchor-id="consume-iterator" id="consume-iterator">3. Consume Iterator</h3>
<div id="3f2d211a" class="cell" data-execution_count="37">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb68"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb68-1"><a href="#cb68-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> consume(iterator, n<span class="op">=</span><span class="va">None</span>):</span>
<span id="cb68-2"><a href="#cb68-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Advance the iterator n steps ahead. If n is None, consume it entirely."""</span></span>
<span id="cb68-3"><a href="#cb68-3" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> n <span class="kw">is</span> <span class="va">None</span>:</span>
<span id="cb68-4"><a href="#cb68-4" aria-hidden="true" tabindex="-1"></a>        <span class="co"># feed the entire iterator into a zero-length deque</span></span>
<span id="cb68-5"><a href="#cb68-5" aria-hidden="true" tabindex="-1"></a>        collections.deque(iterator, maxlen<span class="op">=</span><span class="dv">0</span>)</span>
<span id="cb68-6"><a href="#cb68-6" aria-hidden="true" tabindex="-1"></a>    <span class="cf">else</span>:</span>
<span id="cb68-7"><a href="#cb68-7" aria-hidden="true" tabindex="-1"></a>        <span class="co"># advance to the empty slice starting at position n</span></span>
<span id="cb68-8"><a href="#cb68-8" aria-hidden="true" tabindex="-1"></a>        <span class="bu">next</span>(itertools.islice(iterator, n, n), <span class="va">None</span>)</span></code></pre></div></div>
</div>
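<p>A short usage sketch: skipping a fixed number of items and then exhausting the rest. The function is repeated (with its imports) so the block runs on its own.</p>

```python
import collections
import itertools

def consume(iterator, n=None):
    """Advance the iterator n steps ahead. If n is None, consume it entirely."""
    if n is None:
        collections.deque(iterator, maxlen=0)  # discards every item without storing it
    else:
        next(itertools.islice(iterator, n, n), None)  # reads exactly n items, yields none

it = iter(range(10))
consume(it, 3)            # skips 0, 1 and 2
after_skip = next(it)
print(after_skip)         # 3

consume(it)               # exhausts 4..9
print(next(it, "done"))   # done
```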
<hr>
</section>
</section>
<section id="real-world-examples" class="level2">
<h2 class="anchored" data-anchor-id="real-world-examples" id="real-world-examples">Real-World Examples</h2>
<section id="example-1-data-processing-pipeline" class="level3">
<h3 class="anchored" data-anchor-id="example-1-data-processing-pipeline" id="example-1-data-processing-pipeline">Example 1: Data Processing Pipeline</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb69"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb69-1"><a href="#cb69-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Processing CSV-like data</span></span>
<span id="cb69-2"><a href="#cb69-2" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> process_sales_data(data):</span>
<span id="cb69-3"><a href="#cb69-3" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Process sales data with itertools."""</span></span>
<span id="cb69-4"><a href="#cb69-4" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Filter out header and empty lines</span></span>
<span id="cb69-5"><a href="#cb69-5" aria-hidden="true" tabindex="-1"></a>    clean_data <span class="op">=</span> itertools.filterfalse(</span>
<span id="cb69-6"><a href="#cb69-6" aria-hidden="true" tabindex="-1"></a>        <span class="kw">lambda</span> x: x.startswith(<span class="st">'Date'</span>) <span class="kw">or</span> <span class="kw">not</span> x.strip(), </span>
<span id="cb69-7"><a href="#cb69-7" aria-hidden="true" tabindex="-1"></a>        data</span>
<span id="cb69-8"><a href="#cb69-8" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb69-9"><a href="#cb69-9" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb69-10"><a href="#cb69-10" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Parse each line</span></span>
<span id="cb69-11"><a href="#cb69-11" aria-hidden="true" tabindex="-1"></a>    parsed <span class="op">=</span> (line.split(<span class="st">','</span>) <span class="cf">for</span> line <span class="kw">in</span> clean_data)</span>
<span id="cb69-12"><a href="#cb69-12" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb69-13"><a href="#cb69-13" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Group by month</span></span>
<span id="cb69-14"><a href="#cb69-14" aria-hidden="true" tabindex="-1"></a>    by_month <span class="op">=</span> itertools.groupby(</span>
<span id="cb69-15"><a href="#cb69-15" aria-hidden="true" tabindex="-1"></a>        <span class="bu">sorted</span>(parsed, key<span class="op">=</span><span class="kw">lambda</span> x: x[<span class="dv">0</span>][:<span class="dv">7</span>]),  <span class="co"># Sort by year-month</span></span>
<span id="cb69-16"><a href="#cb69-16" aria-hidden="true" tabindex="-1"></a>        key<span class="op">=</span><span class="kw">lambda</span> x: x[<span class="dv">0</span>][:<span class="dv">7</span>]</span>
<span id="cb69-17"><a href="#cb69-17" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb69-18"><a href="#cb69-18" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb69-19"><a href="#cb69-19" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Calculate monthly totals</span></span>
<span id="cb69-20"><a href="#cb69-20" aria-hidden="true" tabindex="-1"></a>    monthly_totals <span class="op">=</span> {}</span>
<span id="cb69-21"><a href="#cb69-21" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> month, sales <span class="kw">in</span> by_month:</span>
<span id="cb69-22"><a href="#cb69-22" aria-hidden="true" tabindex="-1"></a>        total <span class="op">=</span> <span class="bu">sum</span>(<span class="bu">float</span>(sale[<span class="dv">2</span>]) <span class="cf">for</span> sale <span class="kw">in</span> sales)</span>
<span id="cb69-23"><a href="#cb69-23" aria-hidden="true" tabindex="-1"></a>        monthly_totals[month] <span class="op">=</span> total</span>
<span id="cb69-24"><a href="#cb69-24" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb69-25"><a href="#cb69-25" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> monthly_totals</span>
<span id="cb69-26"><a href="#cb69-26" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb69-27"><a href="#cb69-27" aria-hidden="true" tabindex="-1"></a><span class="co"># Sample data</span></span>
<span id="cb69-28"><a href="#cb69-28" aria-hidden="true" tabindex="-1"></a>sales_data <span class="op">=</span> [</span>
<span id="cb69-29"><a href="#cb69-29" aria-hidden="true" tabindex="-1"></a>    <span class="st">"Date,Product,Amount"</span>,</span>
<span id="cb69-30"><a href="#cb69-30" aria-hidden="true" tabindex="-1"></a>    <span class="st">"2023-01-15,Widget,100.50"</span>,</span>
<span id="cb69-31"><a href="#cb69-31" aria-hidden="true" tabindex="-1"></a>    <span class="st">"2023-01-20,Gadget,75.25"</span>,</span>
<span id="cb69-32"><a href="#cb69-32" aria-hidden="true" tabindex="-1"></a>    <span class="st">"2023-02-10,Widget,120.00"</span>,</span>
<span id="cb69-33"><a href="#cb69-33" aria-hidden="true" tabindex="-1"></a>    <span class="st">"2023-02-15,Gadget,85.75"</span>,</span>
<span id="cb69-34"><a href="#cb69-34" aria-hidden="true" tabindex="-1"></a>    <span class="st">""</span>,</span>
<span id="cb69-35"><a href="#cb69-35" aria-hidden="true" tabindex="-1"></a>    <span class="st">"2023-01-25,Widget,95.00"</span></span>
<span id="cb69-36"><a href="#cb69-36" aria-hidden="true" tabindex="-1"></a>]</span>
<span id="cb69-37"><a href="#cb69-37" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb69-38"><a href="#cb69-38" aria-hidden="true" tabindex="-1"></a>result <span class="op">=</span> process_sales_data(sales_data)</span>
<span id="cb69-39"><a href="#cb69-39" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(result)</span></code></pre></div></div>
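<p>Because <code>groupby()</code> only merges adjacent rows, the pipeline above has to call <code>sorted()</code>, which pulls every parsed row into memory. If all you need is per-month totals, a single pass with <code>collections.defaultdict</code> avoids the sort. This swaps in a plain-dictionary technique rather than an itertools one, and the function name is hypothetical; with the sample data, both approaches give <code>{'2023-01': 270.75, '2023-02': 205.75}</code>.</p>

```python
import collections

def monthly_totals_single_pass(lines):
    """Hypothetical alternative: sum sales per month without sorting the rows."""
    totals = collections.defaultdict(float)
    for line in lines:
        if line.startswith("Date") or not line.strip():
            continue  # skip the header and blank lines, as the original does
        date, _product, amount = line.split(",")
        totals[date[:7]] += float(amount)  # key on the "YYYY-MM" prefix
    return dict(totals)

sales_data = [
    "Date,Product,Amount",
    "2023-01-15,Widget,100.50",
    "2023-01-20,Gadget,75.25",
    "2023-02-10,Widget,120.00",
    "2023-02-15,Gadget,85.75",
    "",
    "2023-01-25,Widget,95.00",
]

totals = monthly_totals_single_pass(sales_data)
print(totals)  # {'2023-01': 270.75, '2023-02': 205.75}
```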
</section>
<section id="example-2-configuration-generator" class="level3">
<h3 class="anchored" data-anchor-id="example-2-configuration-generator" id="example-2-configuration-generator">Example 2: Configuration Generator</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb70"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb70-1"><a href="#cb70-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Generate all possible configurations</span></span>
<span id="cb70-2"><a href="#cb70-2" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> generate_configurations(options):</span>
<span id="cb70-3"><a href="#cb70-3" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Generate all possible configuration combinations."""</span></span>
<span id="cb70-4"><a href="#cb70-4" aria-hidden="true" tabindex="-1"></a>    keys <span class="op">=</span> <span class="bu">list</span>(options.keys())</span>
<span id="cb70-5"><a href="#cb70-5" aria-hidden="true" tabindex="-1"></a>    values <span class="op">=</span> <span class="bu">list</span>(options.values())</span>
<span id="cb70-6"><a href="#cb70-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb70-7"><a href="#cb70-7" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> combo <span class="kw">in</span> itertools.product(<span class="op">*</span>values):</span>
<span id="cb70-8"><a href="#cb70-8" aria-hidden="true" tabindex="-1"></a>        <span class="cf">yield</span> <span class="bu">dict</span>(<span class="bu">zip</span>(keys, combo))</span>
<span id="cb70-9"><a href="#cb70-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb70-10"><a href="#cb70-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage</span></span>
<span id="cb70-11"><a href="#cb70-11" aria-hidden="true" tabindex="-1"></a>server_options <span class="op">=</span> {</span>
<span id="cb70-12"><a href="#cb70-12" aria-hidden="true" tabindex="-1"></a>    <span class="st">'cpu'</span>: [<span class="st">'2-core'</span>, <span class="st">'4-core'</span>, <span class="st">'8-core'</span>],</span>
<span id="cb70-13"><a href="#cb70-13" aria-hidden="true" tabindex="-1"></a>    <span class="st">'memory'</span>: [<span class="st">'4GB'</span>, <span class="st">'8GB'</span>, <span class="st">'16GB'</span>],</span>
<span id="cb70-14"><a href="#cb70-14" aria-hidden="true" tabindex="-1"></a>    <span class="st">'storage'</span>: [<span class="st">'SSD'</span>, <span class="st">'HDD'</span>],</span>
<span id="cb70-15"><a href="#cb70-15" aria-hidden="true" tabindex="-1"></a>    <span class="st">'os'</span>: [<span class="st">'Linux'</span>, <span class="st">'Windows'</span>]</span>
<span id="cb70-16"><a href="#cb70-16" aria-hidden="true" tabindex="-1"></a>}</span>
<span id="cb70-17"><a href="#cb70-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb70-18"><a href="#cb70-18" aria-hidden="true" tabindex="-1"></a>configs <span class="op">=</span> <span class="bu">list</span>(generate_configurations(server_options))</span>
<span id="cb70-19"><a href="#cb70-19" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Total configurations: </span><span class="sc">{</span><span class="bu">len</span>(configs)<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb70-20"><a href="#cb70-20" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> config <span class="kw">in</span> configs[:<span class="dv">3</span>]:  <span class="co"># Show first 3</span></span>
<span id="cb70-21"><a href="#cb70-21" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(config)</span></code></pre></div></div>
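<p>The example materializes all 36 combinations (3 × 3 × 2 × 2) just to print three. Because the count grows multiplicatively with each option, larger option sets quickly become expensive to build in full. A sketch of the alternative: count the combinations with <code>math.prod()</code> and preview them lazily with <code>itertools.islice()</code>, reusing the same generator logic.</p>

```python
import itertools
import math

def generate_configurations(options):
    """Yield one dict per combination of option values (same logic as above)."""
    keys = list(options)
    for combo in itertools.product(*options.values()):
        yield dict(zip(keys, combo))

server_options = {
    'cpu': ['2-core', '4-core', '8-core'],
    'memory': ['4GB', '8GB', '16GB'],
    'storage': ['SSD', 'HDD'],
    'os': ['Linux', 'Windows'],
}

# Count without generating: multiply the number of choices per option
total = math.prod(len(choices) for choices in server_options.values())
print(f"Total configurations: {total}")  # Total configurations: 36

# Build only the first three combinations; product() varies the last option fastest
preview = list(itertools.islice(generate_configurations(server_options), 3))
for config in preview:
    print(config)
```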
</section>
<section id="example-3-batch-processing" class="level3">
<h3 class="anchored" data-anchor-id="example-3-batch-processing" id="example-3-batch-processing">Example 3: Batch Processing</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb71"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb71-1"><a href="#cb71-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> batch_process(items, batch_size, process_func):</span>
<span id="cb71-2"><a href="#cb71-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Process items in batches."""</span></span>
<span id="cb71-3"><a href="#cb71-3" aria-hidden="true" tabindex="-1"></a>    iterator <span class="op">=</span> <span class="bu">iter</span>(items)</span>
<span id="cb71-4"><a href="#cb71-4" aria-hidden="true" tabindex="-1"></a>    <span class="cf">while</span> <span class="va">True</span>:</span>
<span id="cb71-5"><a href="#cb71-5" aria-hidden="true" tabindex="-1"></a>        batch <span class="op">=</span> <span class="bu">list</span>(itertools.islice(iterator, batch_size))</span>
<span id="cb71-6"><a href="#cb71-6" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="kw">not</span> batch:</span>
<span id="cb71-7"><a href="#cb71-7" aria-hidden="true" tabindex="-1"></a>            <span class="cf">break</span></span>
<span id="cb71-8"><a href="#cb71-8" aria-hidden="true" tabindex="-1"></a>        <span class="cf">yield</span> process_func(batch)</span>
<span id="cb71-9"><a href="#cb71-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb71-10"><a href="#cb71-10" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> sum_batch(batch):</span>
<span id="cb71-11"><a href="#cb71-11" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> <span class="bu">sum</span>(batch)</span>
<span id="cb71-12"><a href="#cb71-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb71-13"><a href="#cb71-13" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage</span></span>
<span id="cb71-14"><a href="#cb71-14" aria-hidden="true" tabindex="-1"></a>large_numbers <span class="op">=</span> <span class="bu">range</span>(<span class="dv">1000</span>)</span>
<span id="cb71-15"><a href="#cb71-15" aria-hidden="true" tabindex="-1"></a>batch_sums <span class="op">=</span> <span class="bu">list</span>(batch_process(large_numbers, <span class="dv">100</span>, sum_batch))</span>
<span id="cb71-16"><a href="#cb71-16" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Batch sums: </span><span class="sc">{</span>batch_sums[:<span class="dv">5</span>]<span class="sc">}</span><span class="ss">..."</span>)  <span class="co"># Show first 5 batch sums</span></span></code></pre></div></div>
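<p>Python 3.12 added <code>itertools.batched()</code>, which yields tuples of up to <code>n</code> items and covers the <code>islice()</code> loop above. The sketch uses it when available and otherwise defines an equivalent fallback. Note that it yields tuples, whereas <code>batch_process()</code> passes lists to <code>process_func</code>.</p>

```python
import itertools
import sys

if sys.version_info >= (3, 12):
    from itertools import batched
else:
    def batched(iterable, n):
        """Fallback for Python before 3.12: tuples of length n (the last may be shorter)."""
        iterator = iter(iterable)
        while batch := tuple(itertools.islice(iterator, n)):
            yield batch

batch_sums = [sum(batch) for batch in batched(range(1000), 100)]
print(batch_sums[:5])  # [4950, 14950, 24950, 34950, 44950]
```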
<hr>
</section>
</section>
<section id="best-practices" class="level2">
<h2 class="anchored" data-anchor-id="best-practices" id="best-practices">Best Practices</h2>
<ol type="1">
<li><strong>Use itertools for memory-efficient processing</strong>: When working with large datasets, itertools can help avoid loading everything into memory.</li>
<li><strong>Combine with other functional programming tools</strong>: itertools works well with <code>map()</code>, <code>filter()</code>, and <code>functools.reduce()</code>.</li>
<li><strong>Remember lazy evaluation</strong>: Most itertools functions return iterators, not lists. Use <code>list()</code> when you need to materialize the results.</li>
<li><strong>Consider readability</strong>: Sometimes a simple loop is clearer than a complex itertools chain.</li>
<li><strong>Use type hints</strong>: When writing functions that use itertools, consider adding type hints for better code documentation.</li>
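<li><strong>Sort before grouping</strong>: <code>groupby()</code> only groups consecutive elements that share a key, so sort your data by that same key first if needed.</li>
<li><strong>Use <code>tee()</code> carefully</strong>: <code>tee()</code> must store every element that one of its iterators has consumed but another has not yet reached, which can use significant memory if the iterators advance at very different rates. Also avoid using the original iterator after calling <code>tee()</code>, because advancing it skips those elements in the tee'd iterators.</li>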
<li><strong>Profile your code</strong>: For performance-critical applications, measure whether itertools or other approaches (like NumPy) are faster for your specific use case.</li>
</ol>
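<p>The "sort before grouping" advice is easy to get wrong, so here is a small illustration with made-up data: grouping words by their first letter without sorting produces repeated keys, while sorting by the same key first gives one group per letter.</p>

```python
import itertools

def first_letter(word):
    return word[0]

words = ["apple", "avocado", "banana", "apricot", "blueberry"]

# Without sorting: groupby starts a new group every time the key changes
unsorted_groups = [(k, list(g)) for k, g in itertools.groupby(words, key=first_letter)]
print(unsorted_groups)
# [('a', ['apple', 'avocado']), ('b', ['banana']), ('a', ['apricot']), ('b', ['blueberry'])]

# Sorting by the same key first makes equal keys adjacent (sorted() is stable)
sorted_words = sorted(words, key=first_letter)
sorted_groups = [(k, list(g)) for k, g in itertools.groupby(sorted_words, key=first_letter)]
print(sorted_groups)
# [('a', ['apple', 'avocado', 'apricot']), ('b', ['banana', 'blueberry'])]
```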
<hr>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>The itertools module provides fast, memory-efficient building blocks for working with iterators. Because its functions are lazy, they can be chained into pipelines that process large (or even infinite) streams one element at a time instead of materializing intermediate lists. The key is knowing which function fits which iteration pattern: slicing, filtering, grouping, combining, or accumulating.</p>
<p>Many loops that filter, slice, or combine sequences can be replaced by shorter itertools chains, although a plain loop is sometimes the more readable choice. Practice with these examples and experiment with combining different itertools functions to solve your own problems.</p>


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Python Multiprocessing and Multithreading: A Comprehensive Guide]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/python/python-multi-star/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/python/python-multi-star/</guid>
      <pubDate>Sun, 06 Jul 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>tutorial</category>
      <category>beginner</category>
      <content:encoded><![CDATA[






<section id="python-multiprocessing-and-multithreading-a-comprehensive-guide" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/python/python-multi-star/multi-star.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>Python provides two primary approaches for concurrent execution: <strong>multithreading</strong> and <strong>multiprocessing</strong>. Understanding when and how to use each is crucial for writing efficient Python applications.</p>
<ul>
<li><strong>Multithreading</strong>: Multiple threads within a single process sharing memory space</li>
<li><strong>Multiprocessing</strong>: Multiple separate processes, each with its own memory space</li>
</ul>
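<p>The difference in memory model is easy to observe. In the sketch below, a thread and a child process both append to the same module-level list. The names <code>shared_items</code>, <code>append_item</code>, and <code>main</code> are illustrative. The thread's change is visible to the parent, while the child process only modifies its own copy.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import multiprocessing
import threading

# Module-level list used only for this illustration
shared_items = []


def append_item(label):
    # Appends to the list that lives in whichever process runs this function
    shared_items.append(label)


def main():
    shared_items.clear()

    # A thread runs inside the current process, so it mutates the same list
    t = threading.Thread(target=append_item, args=("thread",))
    t.start()
    t.join()

    # A child process works on its own copy of memory; the parent never sees its append
    p = multiprocessing.Process(target=append_item, args=("process",))
    p.start()
    p.join()

    print(shared_items)  # ['thread']: only the thread's append is visible here
    return list(shared_items)


if __name__ == "__main__":
    main()</code></pre></div>
<p>This holds for any <code>multiprocessing</code> start method: with <code>fork</code> the child gets a copy of the parent's memory, and with <code>spawn</code> it re-imports the module and starts from an empty list. Either way, sharing data between processes needs explicit tools such as queues or pipes.</p>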
</section>
<section id="understanding-concurrency-vs-parallelism" class="level2">
<h2 class="anchored" data-anchor-id="understanding-concurrency-vs-parallelism" id="understanding-concurrency-vs-parallelism">Understanding Concurrency vs Parallelism</h2>
<section id="concurrency" class="level3">
<h3 class="anchored" data-anchor-id="concurrency" id="concurrency">Concurrency</h3>
<p>Concurrency is about dealing with multiple tasks at once, but not necessarily executing them simultaneously. Tasks may be interleaved or switched between rapidly.</p>
</section>
<section id="parallelism" class="level3">
<h3 class="anchored" data-anchor-id="parallelism" id="parallelism">Parallelism</h3>
<p>Parallelism is about executing multiple tasks simultaneously, typically on multiple CPU cores.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Concurrent execution (may not be parallel)</span></span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> threading</span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> task(name):</span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">3</span>):</span>
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Task </span><span class="sc">{</span>name<span class="sc">}</span><span class="ss">: </span><span class="sc">{</span>i<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a>        time.sleep(<span class="fl">0.1</span>)</span>
<span id="cb1-9"><a href="#cb1-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-10"><a href="#cb1-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Create threads</span></span>
<span id="cb1-11"><a href="#cb1-11" aria-hidden="true" tabindex="-1"></a>t1 <span class="op">=</span> threading.Thread(target<span class="op">=</span>task, args<span class="op">=</span>(<span class="st">"A"</span>,))</span>
<span id="cb1-12"><a href="#cb1-12" aria-hidden="true" tabindex="-1"></a>t2 <span class="op">=</span> threading.Thread(target<span class="op">=</span>task, args<span class="op">=</span>(<span class="st">"B"</span>,))</span>
<span id="cb1-13"><a href="#cb1-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-14"><a href="#cb1-14" aria-hidden="true" tabindex="-1"></a><span class="co"># Start threads</span></span>
<span id="cb1-15"><a href="#cb1-15" aria-hidden="true" tabindex="-1"></a>t1.start()</span>
<span id="cb1-16"><a href="#cb1-16" aria-hidden="true" tabindex="-1"></a>t2.start()</span>
<span id="cb1-17"><a href="#cb1-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-18"><a href="#cb1-18" aria-hidden="true" tabindex="-1"></a><span class="co"># Wait for completion</span></span>
<span id="cb1-19"><a href="#cb1-19" aria-hidden="true" tabindex="-1"></a>t1.join()</span>
<span id="cb1-20"><a href="#cb1-20" aria-hidden="true" tabindex="-1"></a>t2.join()</span></code></pre></div></div>
</section>
</section>
<section id="the-global-interpreter-lock-gil" class="level2">
<h2 class="anchored" data-anchor-id="the-global-interpreter-lock-gil" id="the-global-interpreter-lock-gil">The Global Interpreter Lock (GIL)</h2>
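<p>In CPython, the reference interpreter, the GIL is a mutex that allows only one thread at a time to execute Python bytecode, which keeps reference counts and other interpreter internals consistent. (CPython 3.13 also offers an optional, experimental free-threaded build without a GIL.) This has important implications:</p>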
<section id="gil-impact" class="level3">
<h3 class="anchored" data-anchor-id="gil-impact" id="gil-impact">GIL Impact</h3>
<ul>
<li><strong>CPU-bound tasks</strong>: Multithreading provides little or no speedup, because only one thread can run Python bytecode at a time</li>
<li><strong>I/O-bound tasks</strong>: Multithreading can be effective, because the GIL is released while a thread waits on I/O</li>
<li><strong>Multiprocessing</strong>: Sidesteps the GIL, since each process has its own interpreter and its own GIL</li>
</ul>
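<p>A rough way to see the GIL's effect is to run the same pure-Python, CPU-bound function through a thread pool and a process pool. This is a sketch under assumptions: <code>count_up</code> and <code>timed_run</code> are illustrative helpers, and the timings depend heavily on the machine's core count and on process start-up cost. On a multi-core machine the process pool would typically finish noticeably faster, while the thread pool takes roughly as long as running the jobs one after another.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import time
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor


def count_up(n):
    # Pure-Python loop: the running thread holds the GIL while executing it
    total = 0
    for i in range(n):
        total += i
    return total


def timed_run(executor_cls, jobs):
    # Run the same CPU-bound jobs with the given executor class and time them
    start = time.perf_counter()
    with executor_cls(max_workers=4) as executor:
        results = list(executor.map(count_up, jobs))
    return results, time.perf_counter() - start


def main():
    jobs = [1_000_000] * 4
    thread_results, thread_time = timed_run(ThreadPoolExecutor, jobs)
    process_results, process_time = timed_run(ProcessPoolExecutor, jobs)
    print(f"Threads:   {thread_time:.2f} s")
    print(f"Processes: {process_time:.2f} s")
    return thread_results, process_results


if __name__ == "__main__":
    main()</code></pre></div>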
</section>
<section id="when-gil-is-released" class="level3">
<h3 class="anchored" data-anchor-id="when-gil-is-released" id="when-gil-is-released">When the GIL Is Released</h3>
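<p>The GIL is released around blocking calls and inside many C extensions while they work on data that is not a Python object, including:</p>
<ul>
<li>File I/O operations</li>
<li>Network I/O operations</li>
<li>Many Pillow (PIL) image-processing routines</li>
<li>Many NumPy operations on large arrays</li>
<li><code>time.sleep()</code> calls</li>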
</ul>
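<p>Because <code>time.sleep()</code> releases the GIL, threads that spend their time blocked overlap instead of queuing behind each other. The sketch below uses sleeps as a stand-in for blocking I/O; <code>blocking_call</code> and <code>run_in_threads</code> are illustrative helpers.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import threading
import time


def blocking_call():
    # time.sleep releases the GIL, just as blocking I/O calls do
    time.sleep(0.2)


def run_in_threads(count):
    threads = [threading.Thread(target=blocking_call) for _ in range(count)]
    start = time.perf_counter()
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return time.perf_counter() - start


elapsed = run_in_threads(5)
# The five 0.2 s sleeps overlap, so this is typically close to 0.2 s rather than 1.0 s
print(f"5 sleeping threads took {elapsed:.2f} s")</code></pre></div>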
</section>
</section>
<section id="multithreading-with-threading-module" class="level2">
<h2 class="anchored" data-anchor-id="multithreading-with-threading-module" id="multithreading-with-threading-module">Multithreading with threading Module</h2>
<section id="basic-thread-creation" class="level3">
<h3 class="anchored" data-anchor-id="basic-thread-creation" id="basic-thread-creation">Basic Thread Creation</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> threading</span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Method 1: Using Thread class directly</span></span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> worker_function(name, delay):</span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">5</span>):</span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Worker </span><span class="sc">{</span>name<span class="sc">}</span><span class="ss">: </span><span class="sc">{</span>i<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a>        time.sleep(delay)</span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-10"><a href="#cb2-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Create and start threads</span></span>
<span id="cb2-11"><a href="#cb2-11" aria-hidden="true" tabindex="-1"></a>thread1 <span class="op">=</span> threading.Thread(target<span class="op">=</span>worker_function, args<span class="op">=</span>(<span class="st">"A"</span>, <span class="fl">0.5</span>))</span>
<span id="cb2-12"><a href="#cb2-12" aria-hidden="true" tabindex="-1"></a>thread2 <span class="op">=</span> threading.Thread(target<span class="op">=</span>worker_function, args<span class="op">=</span>(<span class="st">"B"</span>, <span class="fl">0.3</span>))</span>
<span id="cb2-13"><a href="#cb2-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-14"><a href="#cb2-14" aria-hidden="true" tabindex="-1"></a>thread1.start()</span>
<span id="cb2-15"><a href="#cb2-15" aria-hidden="true" tabindex="-1"></a>thread2.start()</span>
<span id="cb2-16"><a href="#cb2-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-17"><a href="#cb2-17" aria-hidden="true" tabindex="-1"></a>thread1.join()</span>
<span id="cb2-18"><a href="#cb2-18" aria-hidden="true" tabindex="-1"></a>thread2.join()</span></code></pre></div></div>
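<p><code>Thread</code> discards whatever its <code>target</code> returns, so getting results out of plain threads needs a shared container. A minimal sketch, assuming each thread writes only to its own pre-allocated slot (the helper <code>square_into</code> is illustrative):</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import threading


def square_into(results, index, value):
    # Thread ignores the target's return value, so store the result in a shared slot
    results[index] = value * value


values = [1, 2, 3, 4]
results = [None] * len(values)  # one slot per thread, so no two threads write the same index

threads = [
    threading.Thread(target=square_into, args=(results, i, v))
    for i, v in enumerate(values)
]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(results)  # [1, 4, 9, 16]</code></pre></div>
<p>Because every index is written by exactly one thread, no lock is needed here. When you need return values, <code>ThreadPoolExecutor</code> is usually the more convenient tool.</p>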
</section>
<section id="thread-subclassing" class="level3">
<h3 class="anchored" data-anchor-id="thread-subclassing" id="thread-subclassing">Thread Subclassing</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> threading</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> WorkerThread(threading.Thread):</span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, name, delay):</span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.name <span class="op">=</span> name</span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.delay <span class="op">=</span> delay</span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> run(<span class="va">self</span>):</span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">5</span>):</span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f"Worker </span><span class="sc">{</span><span class="va">self</span><span class="sc">.</span>name<span class="sc">}</span><span class="ss">: </span><span class="sc">{</span>i<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a>            time.sleep(<span class="va">self</span>.delay)</span>
<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-15"><a href="#cb3-15" aria-hidden="true" tabindex="-1"></a><span class="co"># Create and start threads</span></span>
<span id="cb3-16"><a href="#cb3-16" aria-hidden="true" tabindex="-1"></a>worker1 <span class="op">=</span> WorkerThread(<span class="st">"A"</span>, <span class="fl">0.5</span>)</span>
<span id="cb3-17"><a href="#cb3-17" aria-hidden="true" tabindex="-1"></a>worker2 <span class="op">=</span> WorkerThread(<span class="st">"B"</span>, <span class="fl">0.3</span>)</span>
<span id="cb3-18"><a href="#cb3-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-19"><a href="#cb3-19" aria-hidden="true" tabindex="-1"></a>worker1.start()</span>
<span id="cb3-20"><a href="#cb3-20" aria-hidden="true" tabindex="-1"></a>worker2.start()</span>
<span id="cb3-21"><a href="#cb3-21" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-22"><a href="#cb3-22" aria-hidden="true" tabindex="-1"></a>worker1.join()</span>
<span id="cb3-23"><a href="#cb3-23" aria-hidden="true" tabindex="-1"></a>worker2.join()</span></code></pre></div></div>
</section>
<section id="thread-pool-executor" class="level3">
<h3 class="anchored" data-anchor-id="thread-pool-executor" id="thread-pool-executor">Thread Pool Executor</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> concurrent.futures <span class="im">import</span> ThreadPoolExecutor, as_completed</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> task(name, duration):</span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Starting task </span><span class="sc">{</span>name<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a>    time.sleep(duration)</span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> <span class="ss">f"Task </span><span class="sc">{</span>name<span class="sc">}</span><span class="ss"> completed"</span></span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a><span class="co"># Using ThreadPoolExecutor</span></span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a><span class="cf">with</span> ThreadPoolExecutor(max_workers<span class="op">=</span><span class="dv">3</span>) <span class="im">as</span> executor:</span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Submit tasks</span></span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a>    futures <span class="op">=</span> [</span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a>        executor.submit(task, <span class="st">"A"</span>, <span class="dv">2</span>),</span>
<span id="cb4-14"><a href="#cb4-14" aria-hidden="true" tabindex="-1"></a>        executor.submit(task, <span class="st">"B"</span>, <span class="dv">1</span>),</span>
<span id="cb4-15"><a href="#cb4-15" aria-hidden="true" tabindex="-1"></a>        executor.submit(task, <span class="st">"C"</span>, <span class="dv">3</span>)</span>
<span id="cb4-16"><a href="#cb4-16" aria-hidden="true" tabindex="-1"></a>    ]</span>
<span id="cb4-17"><a href="#cb4-17" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-18"><a href="#cb4-18" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Collect results as they complete</span></span>
<span id="cb4-19"><a href="#cb4-19" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> future <span class="kw">in</span> as_completed(futures):</span>
<span id="cb4-20"><a href="#cb4-20" aria-hidden="true" tabindex="-1"></a>        result <span class="op">=</span> future.result()</span>
<span id="cb4-21"><a href="#cb4-21" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(result)</span></code></pre></div></div>
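<p><code>as_completed</code> yields futures as they finish, so the order of the printed results depends on task durations. When results should line up with the inputs, <code>executor.map</code> returns them in input order instead. A sketch with a task function of the same shape and shorter, illustrative durations:</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import time
from concurrent.futures import ThreadPoolExecutor


def task(name, duration):
    time.sleep(duration)
    return f"Task {name} completed"


names = ["A", "B", "C"]
durations = [0.2, 0.1, 0.3]

with ThreadPoolExecutor(max_workers=3) as executor:
    # map() yields results in input order, even though B finishes first
    results = list(executor.map(task, names, durations))

print(results)  # ['Task A completed', 'Task B completed', 'Task C completed']</code></pre></div>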
</section>
<section id="thread-safe-operations" class="level3">
<h3 class="anchored" data-anchor-id="thread-safe-operations" id="thread-safe-operations">Thread-Safe Operations</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> threading</span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> ThreadSafeCounter:</span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>):</span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.value <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.lock <span class="op">=</span> threading.Lock()</span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> increment(<span class="va">self</span>):</span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> <span class="va">self</span>.lock:</span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a>            temp <span class="op">=</span> <span class="va">self</span>.value</span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a>            time.sleep(<span class="fl">0.001</span>)  <span class="co"># Simulate processing</span></span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.value <span class="op">=</span> temp <span class="op">+</span> <span class="dv">1</span></span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> get_value(<span class="va">self</span>):</span>
<span id="cb5-16"><a href="#cb5-16" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> <span class="va">self</span>.lock:</span>
<span id="cb5-17"><a href="#cb5-17" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> <span class="va">self</span>.value</span>
<span id="cb5-18"><a href="#cb5-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-19"><a href="#cb5-19" aria-hidden="true" tabindex="-1"></a><span class="co"># Demonstrate thread safety</span></span>
<span id="cb5-20"><a href="#cb5-20" aria-hidden="true" tabindex="-1"></a>counter <span class="op">=</span> ThreadSafeCounter()</span>
<span id="cb5-21"><a href="#cb5-21" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-22"><a href="#cb5-22" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> worker():</span>
<span id="cb5-23"><a href="#cb5-23" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">100</span>):</span>
<span id="cb5-24"><a href="#cb5-24" aria-hidden="true" tabindex="-1"></a>        counter.increment()</span>
<span id="cb5-25"><a href="#cb5-25" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-26"><a href="#cb5-26" aria-hidden="true" tabindex="-1"></a>threads <span class="op">=</span> []</span>
<span id="cb5-27"><a href="#cb5-27" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">10</span>):</span>
<span id="cb5-28"><a href="#cb5-28" aria-hidden="true" tabindex="-1"></a>    t <span class="op">=</span> threading.Thread(target<span class="op">=</span>worker)</span>
<span id="cb5-29"><a href="#cb5-29" aria-hidden="true" tabindex="-1"></a>    threads.append(t)</span>
<span id="cb5-30"><a href="#cb5-30" aria-hidden="true" tabindex="-1"></a>    t.start()</span>
<span id="cb5-31"><a href="#cb5-31" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-32"><a href="#cb5-32" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> t <span class="kw">in</span> threads:</span>
<span id="cb5-33"><a href="#cb5-33" aria-hidden="true" tabindex="-1"></a>    t.join()</span>
<span id="cb5-34"><a href="#cb5-34" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-35"><a href="#cb5-35" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Final counter value: </span><span class="sc">{</span>counter<span class="sc">.</span>get_value()<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
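<p>To see why the lock matters, here is the same read-sleep-write increment without it. In this sketch, <code>UnsafeCounter</code> and <code>run</code> are illustrative names and the sizes are smaller to keep it fast. Threads read the same stale value and overwrite each other's updates, so the final count usually falls well short of the expected 250.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import threading
import time


class UnsafeCounter:
    def __init__(self):
        self.value = 0

    def increment(self):
        # Read-modify-write with no lock: another thread can run between the read and the write
        temp = self.value
        time.sleep(0.001)  # widen the window so lost updates are easy to observe
        self.value = temp + 1


def run(counter, num_threads=5, per_thread=50):
    def worker():
        for _ in range(per_thread):
            counter.increment()

    threads = [threading.Thread(target=worker) for _ in range(num_threads)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return counter.value


result = run(UnsafeCounter())
print(f"Expected 250, got {result}")  # usually well below 250</code></pre></div>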
</section>
</section>
<section id="multiprocessing-with-multiprocessing-module" class="level2">
<h2 class="anchored" data-anchor-id="multiprocessing-with-multiprocessing-module" id="multiprocessing-with-multiprocessing-module">Multiprocessing with multiprocessing Module</h2>
<section id="basic-process-creation" class="level3">
<h3 class="anchored" data-anchor-id="basic-process-creation" id="basic-process-creation">Basic Process Creation</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> multiprocessing</span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> os</span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> worker_function(name, delay):</span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a>    process_id <span class="op">=</span> os.getpid()</span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">5</span>):</span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Worker </span><span class="sc">{</span>name<span class="sc">}</span><span class="ss"> (PID: </span><span class="sc">{</span>process_id<span class="sc">}</span><span class="ss">): </span><span class="sc">{</span>i<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a>        time.sleep(delay)</span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Create and start processes</span></span>
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a>    process1 <span class="op">=</span> multiprocessing.Process(target<span class="op">=</span>worker_function, args<span class="op">=</span>(<span class="st">"A"</span>, <span class="fl">0.5</span>))</span>
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a>    process2 <span class="op">=</span> multiprocessing.Process(target<span class="op">=</span>worker_function, args<span class="op">=</span>(<span class="st">"B"</span>, <span class="fl">0.3</span>))</span>
<span id="cb6-15"><a href="#cb6-15" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-16"><a href="#cb6-16" aria-hidden="true" tabindex="-1"></a>    process1.start()</span>
<span id="cb6-17"><a href="#cb6-17" aria-hidden="true" tabindex="-1"></a>    process2.start()</span>
<span id="cb6-18"><a href="#cb6-18" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-19"><a href="#cb6-19" aria-hidden="true" tabindex="-1"></a>    process1.join()</span>
<span id="cb6-20"><a href="#cb6-20" aria-hidden="true" tabindex="-1"></a>    process2.join()</span></code></pre></div></div>
</section>
<section id="process-pool" class="level3">
<h3 class="anchored" data-anchor-id="process-pool" id="process-pool">Process Pool</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> multiprocessing</span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> compute_square(n):</span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Trivial task, shown for contrast: too cheap to benefit from extra processes"""</span></span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> n <span class="op">*</span> n</span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> compute_with_delay(n):</span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Slower task: the sleep stands in for real per-item work"""</span></span>
<span id="cb7-10"><a href="#cb7-10" aria-hidden="true" tabindex="-1"></a>    time.sleep(<span class="fl">0.1</span>)</span>
<span id="cb7-11"><a href="#cb7-11" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> n <span class="op">*</span> n</span>
<span id="cb7-12"><a href="#cb7-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-13"><a href="#cb7-13" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb7-14"><a href="#cb7-14" aria-hidden="true" tabindex="-1"></a>    numbers <span class="op">=</span> <span class="bu">list</span>(<span class="bu">range</span>(<span class="dv">1</span>, <span class="dv">11</span>))</span>
<span id="cb7-15"><a href="#cb7-15" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-16"><a href="#cb7-16" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Sequential execution</span></span>
<span id="cb7-17"><a href="#cb7-17" aria-hidden="true" tabindex="-1"></a>    start_time <span class="op">=</span> time.time()</span>
<span id="cb7-18"><a href="#cb7-18" aria-hidden="true" tabindex="-1"></a>    sequential_results <span class="op">=</span> [compute_with_delay(n) <span class="cf">for</span> n <span class="kw">in</span> numbers]</span>
<span id="cb7-19"><a href="#cb7-19" aria-hidden="true" tabindex="-1"></a>    sequential_time <span class="op">=</span> time.time() <span class="op">-</span> start_time</span>
<span id="cb7-20"><a href="#cb7-20" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-21"><a href="#cb7-21" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Parallel execution</span></span>
<span id="cb7-22"><a href="#cb7-22" aria-hidden="true" tabindex="-1"></a>    start_time <span class="op">=</span> time.time()</span>
<span id="cb7-23"><a href="#cb7-23" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> multiprocessing.Pool(processes<span class="op">=</span><span class="dv">4</span>) <span class="im">as</span> pool:</span>
<span id="cb7-24"><a href="#cb7-24" aria-hidden="true" tabindex="-1"></a>        parallel_results <span class="op">=</span> pool.<span class="bu">map</span>(compute_with_delay, numbers)</span>
<span id="cb7-25"><a href="#cb7-25" aria-hidden="true" tabindex="-1"></a>    parallel_time <span class="op">=</span> time.time() <span class="op">-</span> start_time</span>
<span id="cb7-26"><a href="#cb7-26" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-27"><a href="#cb7-27" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Sequential time: </span><span class="sc">{</span>sequential_time<span class="sc">:.2f}</span><span class="ss"> seconds"</span>)</span>
<span id="cb7-28"><a href="#cb7-28" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Parallel time: </span><span class="sc">{</span>parallel_time<span class="sc">:.2f}</span><span class="ss"> seconds"</span>)</span>
<span id="cb7-29"><a href="#cb7-29" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Speedup: </span><span class="sc">{</span>sequential_time<span class="op">/</span>parallel_time<span class="sc">:.2f}</span><span class="ss">x"</span>)</span></code></pre></div></div>
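<p><code>pool.map</code> passes exactly one argument to each call. For a function that takes several arguments, <code>Pool.starmap</code> unpacks each tuple in the iterable into positional arguments. A minimal sketch with an illustrative <code>power</code> function:</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import multiprocessing


def power(base, exponent):
    return base ** exponent


def main():
    pairs = [(2, 3), (3, 2), (10, 0)]
    with multiprocessing.Pool(processes=2) as pool:
        # starmap calls power(2, 3), power(3, 2), power(10, 0) and keeps input order
        results = pool.starmap(power, pairs)
    print(results)  # [8, 9, 1]
    return results


if __name__ == "__main__":
    main()</code></pre></div>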
</section>
<section id="process-pool-executor" class="level3">
<h3 class="anchored" data-anchor-id="process-pool-executor" id="process-pool-executor">Process Pool Executor</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb8"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> concurrent.futures <span class="im">import</span> ProcessPoolExecutor, as_completed</span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> cpu_intensive_task(n):</span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Simulate CPU-intensive computation"""</span></span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a>    total <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(n <span class="op">*</span> <span class="dv">10_000</span>):</span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a>        total <span class="op">+=</span> i</span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> total</span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb8-12"><a href="#cb8-12" aria-hidden="true" tabindex="-1"></a>    tasks <span class="op">=</span> [<span class="dv">100</span>, <span class="dv">200</span>, <span class="dv">300</span>, <span class="dv">400</span>, <span class="dv">500</span>]</span>
<span id="cb8-13"><a href="#cb8-13" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-14"><a href="#cb8-14" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> ProcessPoolExecutor(max_workers<span class="op">=</span><span class="dv">4</span>) <span class="im">as</span> executor:</span>
<span id="cb8-15"><a href="#cb8-15" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Submit all tasks, mapping each future back to its input</span></span>
<span id="cb8-16"><a href="#cb8-16" aria-hidden="true" tabindex="-1"></a>        futures <span class="op">=</span> {executor.submit(cpu_intensive_task, task): task <span class="cf">for</span> task <span class="kw">in</span> tasks}</span>
<span id="cb8-17"><a href="#cb8-17" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-18"><a href="#cb8-18" aria-hidden="true" tabindex="-1"></a>        <span class="co"># as_completed yields futures in completion order, not submission order</span></span>
<span id="cb8-19"><a href="#cb8-19" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> future <span class="kw">in</span> as_completed(futures):</span>
<span id="cb8-20"><a href="#cb8-20" aria-hidden="true" tabindex="-1"></a>            task <span class="op">=</span> futures[future]</span>
<span id="cb8-21"><a href="#cb8-21" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f"Task n=</span><span class="sc">{</span>task<span class="sc">}</span><span class="ss"> completed with result: </span><span class="sc">{</span>future<span class="sc">.</span>result()<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
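<p>If a worker raises, the exception is sent back to the parent and re-raised when you call <code>future.result()</code>, so wrap that call if one failing task should not abort the whole batch. A sketch under assumptions: <code>root</code> and <code>collect_roots</code> are illustrative helpers, and <code>math.sqrt</code> supplies a real <code>ValueError</code> for negative input.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import math
from concurrent.futures import ProcessPoolExecutor, as_completed


def root(x):
    # math.sqrt raises ValueError for negative input; the error travels back to the parent
    return math.sqrt(x)


def collect_roots(values):
    results, errors = {}, {}
    with ProcessPoolExecutor(max_workers=2) as executor:
        futures = {executor.submit(root, v): v for v in values}
        for future in as_completed(futures):
            value = futures[future]
            try:
                # result() re-raises any exception raised inside the worker process
                results[value] = future.result()
            except ValueError as exc:
                errors[value] = str(exc)
    return results, errors


if __name__ == "__main__":
    ok, failed = collect_roots([4, 9, -1])
    print(ok)      # square roots keyed by input
    print(failed)  # error messages keyed by input</code></pre></div>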
</section>
</section>
<section id="communication-between-processesthreads" class="level2">
<h2 class="anchored" data-anchor-id="communication-between-processesthreads" id="communication-between-processesthreads">Communication Between Processes/Threads</h2>
<section id="queues" class="level3">
<h3 class="anchored" data-anchor-id="queues" id="queues">Queues</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb9"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> multiprocessing</span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> threading</span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Process Queue</span></span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> producer(queue, items):</span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> item <span class="kw">in</span> items:</span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a>        queue.put(item)</span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Produced: </span><span class="sc">{</span>item<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a>        time.sleep(<span class="fl">0.1</span>)</span>
<span id="cb9-11"><a href="#cb9-11" aria-hidden="true" tabindex="-1"></a>    queue.put(<span class="va">None</span>)  <span class="co"># Sentinel value</span></span>
<span id="cb9-12"><a href="#cb9-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-13"><a href="#cb9-13" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> consumer(queue):</span>
<span id="cb9-14"><a href="#cb9-14" aria-hidden="true" tabindex="-1"></a>    <span class="cf">while</span> <span class="va">True</span>:</span>
<span id="cb9-15"><a href="#cb9-15" aria-hidden="true" tabindex="-1"></a>        item <span class="op">=</span> queue.get()</span>
<span id="cb9-16"><a href="#cb9-16" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> item <span class="kw">is</span> <span class="va">None</span>:</span>
<span id="cb9-17"><a href="#cb9-17" aria-hidden="true" tabindex="-1"></a>            <span class="cf">break</span></span>
<span id="cb9-18"><a href="#cb9-18" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Consumed: </span><span class="sc">{</span>item<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb9-19"><a href="#cb9-19" aria-hidden="true" tabindex="-1"></a>        time.sleep(<span class="fl">0.2</span>)</span>
<span id="cb9-20"><a href="#cb9-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-21"><a href="#cb9-21" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb9-22"><a href="#cb9-22" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Process communication</span></span>
<span id="cb9-23"><a href="#cb9-23" aria-hidden="true" tabindex="-1"></a>    process_queue <span class="op">=</span> multiprocessing.Queue()</span>
<span id="cb9-24"><a href="#cb9-24" aria-hidden="true" tabindex="-1"></a>    items <span class="op">=</span> [<span class="st">'item1'</span>, <span class="st">'item2'</span>, <span class="st">'item3'</span>, <span class="st">'item4'</span>]</span>
<span id="cb9-25"><a href="#cb9-25" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-26"><a href="#cb9-26" aria-hidden="true" tabindex="-1"></a>    producer_process <span class="op">=</span> multiprocessing.Process(target<span class="op">=</span>producer, args<span class="op">=</span>(process_queue, items))</span>
<span id="cb9-27"><a href="#cb9-27" aria-hidden="true" tabindex="-1"></a>    consumer_process <span class="op">=</span> multiprocessing.Process(target<span class="op">=</span>consumer, args<span class="op">=</span>(process_queue,))</span>
<span id="cb9-28"><a href="#cb9-28" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-29"><a href="#cb9-29" aria-hidden="true" tabindex="-1"></a>    producer_process.start()</span>
<span id="cb9-30"><a href="#cb9-30" aria-hidden="true" tabindex="-1"></a>    consumer_process.start()</span>
<span id="cb9-31"><a href="#cb9-31" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-32"><a href="#cb9-32" aria-hidden="true" tabindex="-1"></a>    producer_process.join()</span>
<span id="cb9-33"><a href="#cb9-33" aria-hidden="true" tabindex="-1"></a>    consumer_process.join()</span></code></pre></div></div>
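<p>Threads in the same process do not need <code>multiprocessing.Queue</code>; the standard-library <code>queue.Queue</code> is already thread-safe. A minimal sketch of the same producer/consumer pattern with threads, assuming a bounded queue (<code>maxsize=2</code>, an illustrative value) and <code>task_done()</code>/<code>join()</code> for completion tracking:</p>

```python
import queue
import threading

SENTINEL = None  # End-of-stream marker, as in the process example


def producer(q, items):
    for item in items:
        q.put(item)  # Blocks while the queue is full
    q.put(SENTINEL)


def consumer(q, results):
    while True:
        item = q.get()
        try:
            if item is SENTINEL:
                break
            results.append(item.upper())  # Stand-in for real processing
        finally:
            q.task_done()  # One task_done() per get(), sentinel included


items = ["item1", "item2", "item3", "item4"]
q = queue.Queue(maxsize=2)
results = []

producer_thread = threading.Thread(target=producer, args=(q, items))
consumer_thread = threading.Thread(target=consumer, args=(q, results))
producer_thread.start()
consumer_thread.start()

producer_thread.join()  # Every item, plus the sentinel, has been put
q.join()                # ...and every one has been marked done
consumer_thread.join()
print(results)  # ['ITEM1', 'ITEM2', 'ITEM3', 'ITEM4']
```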
</section>
<section id="pipes" class="level3">
<h3 class="anchored" data-anchor-id="pipes" id="pipes">Pipes</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb10"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> multiprocessing</span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> sender(conn, messages):</span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> msg <span class="kw">in</span> messages:</span>
<span id="cb10-6"><a href="#cb10-6" aria-hidden="true" tabindex="-1"></a>        conn.send(msg)</span>
<span id="cb10-7"><a href="#cb10-7" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Sent: </span><span class="sc">{</span>msg<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb10-8"><a href="#cb10-8" aria-hidden="true" tabindex="-1"></a>        time.sleep(<span class="fl">0.1</span>)</span>
<span id="cb10-9"><a href="#cb10-9" aria-hidden="true" tabindex="-1"></a>    conn.send(<span class="va">None</span>)  <span class="co"># Sentinel value: tells the receiver to stop</span></span>
<span id="cb10-10"><a href="#cb10-10" aria-hidden="true" tabindex="-1"></a>    conn.close()</span>
<span id="cb10-11"><a href="#cb10-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-12"><a href="#cb10-12" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> receiver(conn):</span>
<span id="cb10-13"><a href="#cb10-13" aria-hidden="true" tabindex="-1"></a>    <span class="cf">while</span> <span class="va">True</span>:</span>
<span id="cb10-14"><a href="#cb10-14" aria-hidden="true" tabindex="-1"></a>        msg <span class="op">=</span> conn.recv()</span>
<span id="cb10-15"><a href="#cb10-15" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> msg <span class="kw">is</span> <span class="va">None</span>:  <span class="co"># Stop on the sentinel; waiting for EOFError would hang, since the parent still holds child_conn open</span></span>
<span id="cb10-16"><a href="#cb10-16" aria-hidden="true" tabindex="-1"></a>            <span class="cf">break</span></span>
<span id="cb10-17"><a href="#cb10-17" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Received: </span><span class="sc">{</span>msg<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb10-18"><a href="#cb10-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-19"><a href="#cb10-19" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb10-20"><a href="#cb10-20" aria-hidden="true" tabindex="-1"></a>    parent_conn, child_conn <span class="op">=</span> multiprocessing.Pipe()</span>
<span id="cb10-21"><a href="#cb10-21" aria-hidden="true" tabindex="-1"></a>    messages <span class="op">=</span> [<span class="st">'Hello'</span>, <span class="st">'World'</span>, <span class="st">'From'</span>, <span class="st">'Process'</span>]</span>
<span id="cb10-22"><a href="#cb10-22" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-23"><a href="#cb10-23" aria-hidden="true" tabindex="-1"></a>    sender_process <span class="op">=</span> multiprocessing.Process(target<span class="op">=</span>sender, args<span class="op">=</span>(child_conn, messages))</span>
<span id="cb10-24"><a href="#cb10-24" aria-hidden="true" tabindex="-1"></a>    receiver_process <span class="op">=</span> multiprocessing.Process(target<span class="op">=</span>receiver, args<span class="op">=</span>(parent_conn,))</span>
<span id="cb10-25"><a href="#cb10-25" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-26"><a href="#cb10-26" aria-hidden="true" tabindex="-1"></a>    sender_process.start()</span>
<span id="cb10-27"><a href="#cb10-27" aria-hidden="true" tabindex="-1"></a>    receiver_process.start()</span>
<span id="cb10-28"><a href="#cb10-28" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-29"><a href="#cb10-29" aria-hidden="true" tabindex="-1"></a>    sender_process.join()</span>
<span id="cb10-30"><a href="#cb10-30" aria-hidden="true" tabindex="-1"></a>    receiver_process.join()</span></code></pre></div></div>
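<p>The end-of-stream signal can also come from <code>EOFError</code>, but only once every copy of the sending end has been closed, including the one the parent process still holds. A sketch of that pattern, using a one-way pipe (<code>duplex=False</code>) and a hypothetical <code>collect</code> helper in which the parent does the receiving:</p>

```python
import multiprocessing


def sender(conn, messages):
    """Send every message, then close this process's copy of the write end."""
    for msg in messages:
        conn.send(msg)
    conn.close()


def collect(messages):
    # duplex=False: recv_conn can only receive, send_conn can only send
    recv_conn, send_conn = multiprocessing.Pipe(duplex=False)
    p = multiprocessing.Process(target=sender, args=(send_conn, messages))
    p.start()
    # Close the parent's copy of the write end; otherwise recv() never raises EOFError
    send_conn.close()

    received = []
    while True:
        try:
            received.append(recv_conn.recv())
        except EOFError:  # All write ends are closed and the pipe is drained
            break
    p.join()
    recv_conn.close()
    return received


if __name__ == "__main__":
    print(collect(["Hello", "World", "From", "Process"]))
```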
</section>
<section id="shared-memory" class="level3">
<h3 class="anchored" data-anchor-id="shared-memory" id="shared-memory">Shared Memory</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb11"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> multiprocessing</span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> worker(shared_list, shared_value, lock, worker_id):</span>
<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">5</span>):</span>
<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> lock:</span>
<span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a>            shared_value.value <span class="op">+=</span> <span class="dv">1</span></span>
<span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a>            shared_list[worker_id] <span class="op">=</span> shared_value.value</span>
<span id="cb11-9"><a href="#cb11-9" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f"Worker </span><span class="sc">{</span>worker_id<span class="sc">}</span><span class="ss">: Updated shared_value to </span><span class="sc">{</span>shared_value<span class="sc">.</span>value<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb11-10"><a href="#cb11-10" aria-hidden="true" tabindex="-1"></a>        time.sleep(<span class="fl">0.1</span>)</span>
<span id="cb11-11"><a href="#cb11-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-12"><a href="#cb11-12" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb11-13"><a href="#cb11-13" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Create shared objects</span></span>
<span id="cb11-14"><a href="#cb11-14" aria-hidden="true" tabindex="-1"></a>    shared_list <span class="op">=</span> multiprocessing.Array(<span class="st">'i'</span>, [<span class="dv">0</span>] <span class="op">*</span> <span class="dv">3</span>)  <span class="co"># Array of integers</span></span>
<span id="cb11-15"><a href="#cb11-15" aria-hidden="true" tabindex="-1"></a>    shared_value <span class="op">=</span> multiprocessing.Value(<span class="st">'i'</span>, <span class="dv">0</span>)       <span class="co"># Single integer</span></span>
<span id="cb11-16"><a href="#cb11-16" aria-hidden="true" tabindex="-1"></a>    lock <span class="op">=</span> multiprocessing.Lock()</span>
<span id="cb11-17"><a href="#cb11-17" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-18"><a href="#cb11-18" aria-hidden="true" tabindex="-1"></a>    processes <span class="op">=</span> []</span>
<span id="cb11-19"><a href="#cb11-19" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">3</span>):</span>
<span id="cb11-20"><a href="#cb11-20" aria-hidden="true" tabindex="-1"></a>        p <span class="op">=</span> multiprocessing.Process(target<span class="op">=</span>worker, args<span class="op">=</span>(shared_list, shared_value, lock, i))</span>
<span id="cb11-21"><a href="#cb11-21" aria-hidden="true" tabindex="-1"></a>        processes.append(p)</span>
<span id="cb11-22"><a href="#cb11-22" aria-hidden="true" tabindex="-1"></a>        p.start()</span>
<span id="cb11-23"><a href="#cb11-23" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-24"><a href="#cb11-24" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> p <span class="kw">in</span> processes:</span>
<span id="cb11-25"><a href="#cb11-25" aria-hidden="true" tabindex="-1"></a>        p.join()</span>
<span id="cb11-26"><a href="#cb11-26" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-27"><a href="#cb11-27" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Final shared_list: </span><span class="sc">{</span><span class="bu">list</span>(shared_list)<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb11-28"><a href="#cb11-28" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Final shared_value: </span><span class="sc">{</span>shared_value<span class="sc">.</span>value<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
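<p>Objects created with <code>multiprocessing.Value</code> (and <code>Array</code>) are wrapped with their own lock by default, reachable through <code>get_lock()</code>. The sketch below uses it instead of a separate <code>Lock</code>; <code>add_many</code> and <code>run</code> are hypothetical names. Note that <code>counter.value += 1</code> is a read followed by a write, so it still needs the lock.</p>

```python
import multiprocessing


def add_many(counter, n):
    for _ in range(n):
        # get_lock() returns the lock that Value() created alongside the integer
        with counter.get_lock():
            counter.value += 1


def run(num_workers=4, n=10_000):
    counter = multiprocessing.Value("i", 0)  # Synchronized wrapper (lock=True by default)
    workers = [
        multiprocessing.Process(target=add_many, args=(counter, n))
        for _ in range(num_workers)
    ]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
    return counter.value


if __name__ == "__main__":
    print(run())  # 40000
```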
</section>
</section>
<section id="synchronization-primitives" class="level2">
<h2 class="anchored" data-anchor-id="synchronization-primitives" id="synchronization-primitives">Synchronization Primitives</h2>
<section id="locks" class="level3">
<h3 class="anchored" data-anchor-id="locks" id="locks">Locks</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb12"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><a href="#cb12-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> threading</span>
<span id="cb12-2"><a href="#cb12-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb12-3"><a href="#cb12-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-4"><a href="#cb12-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Thread Lock</span></span>
<span id="cb12-5"><a href="#cb12-5" aria-hidden="true" tabindex="-1"></a>shared_resource <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb12-6"><a href="#cb12-6" aria-hidden="true" tabindex="-1"></a>lock <span class="op">=</span> threading.Lock()</span>
<span id="cb12-7"><a href="#cb12-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-8"><a href="#cb12-8" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> increment_with_lock():</span>
<span id="cb12-9"><a href="#cb12-9" aria-hidden="true" tabindex="-1"></a>    <span class="kw">global</span> shared_resource</span>
<span id="cb12-10"><a href="#cb12-10" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">100000</span>):</span>
<span id="cb12-11"><a href="#cb12-11" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> lock:</span>
<span id="cb12-12"><a href="#cb12-12" aria-hidden="true" tabindex="-1"></a>            shared_resource <span class="op">+=</span> <span class="dv">1</span></span>
<span id="cb12-13"><a href="#cb12-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-14"><a href="#cb12-14" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> increment_without_lock():</span>
<span id="cb12-15"><a href="#cb12-15" aria-hidden="true" tabindex="-1"></a>    <span class="kw">global</span> shared_resource</span>
<span id="cb12-16"><a href="#cb12-16" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">100000</span>):</span>
<span id="cb12-17"><a href="#cb12-17" aria-hidden="true" tabindex="-1"></a>        shared_resource <span class="op">+=</span> <span class="dv">1</span></span>
<span id="cb12-18"><a href="#cb12-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-19"><a href="#cb12-19" aria-hidden="true" tabindex="-1"></a><span class="co"># Demonstrate race condition (lost updates are possible, but how often they show up depends on the Python version)</span></span>
<span id="cb12-20"><a href="#cb12-20" aria-hidden="true" tabindex="-1"></a>shared_resource <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb12-21"><a href="#cb12-21" aria-hidden="true" tabindex="-1"></a>threads <span class="op">=</span> []</span>
<span id="cb12-22"><a href="#cb12-22" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">5</span>):</span>
<span id="cb12-23"><a href="#cb12-23" aria-hidden="true" tabindex="-1"></a>    t <span class="op">=</span> threading.Thread(target<span class="op">=</span>increment_without_lock)</span>
<span id="cb12-24"><a href="#cb12-24" aria-hidden="true" tabindex="-1"></a>    threads.append(t)</span>
<span id="cb12-25"><a href="#cb12-25" aria-hidden="true" tabindex="-1"></a>    t.start()</span>
<span id="cb12-26"><a href="#cb12-26" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-27"><a href="#cb12-27" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> t <span class="kw">in</span> threads:</span>
<span id="cb12-28"><a href="#cb12-28" aria-hidden="true" tabindex="-1"></a>    t.join()</span>
<span id="cb12-29"><a href="#cb12-29" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-30"><a href="#cb12-30" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Without lock: </span><span class="sc">{</span>shared_resource<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb12-31"><a href="#cb12-31" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-32"><a href="#cb12-32" aria-hidden="true" tabindex="-1"></a><span class="co"># With lock</span></span>
<span id="cb12-33"><a href="#cb12-33" aria-hidden="true" tabindex="-1"></a>shared_resource <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb12-34"><a href="#cb12-34" aria-hidden="true" tabindex="-1"></a>threads <span class="op">=</span> []</span>
<span id="cb12-35"><a href="#cb12-35" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">5</span>):</span>
<span id="cb12-36"><a href="#cb12-36" aria-hidden="true" tabindex="-1"></a>    t <span class="op">=</span> threading.Thread(target<span class="op">=</span>increment_with_lock)</span>
<span id="cb12-37"><a href="#cb12-37" aria-hidden="true" tabindex="-1"></a>    threads.append(t)</span>
<span id="cb12-38"><a href="#cb12-38" aria-hidden="true" tabindex="-1"></a>    t.start()</span>
<span id="cb12-39"><a href="#cb12-39" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-40"><a href="#cb12-40" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> t <span class="kw">in</span> threads:</span>
<span id="cb12-41"><a href="#cb12-41" aria-hidden="true" tabindex="-1"></a>    t.join()</span>
<span id="cb12-42"><a href="#cb12-42" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-43"><a href="#cb12-43" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"With lock: </span><span class="sc">{</span>shared_resource<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
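<p>A plain <code>Lock</code> deadlocks if the thread holding it tries to acquire it again, which happens easily when one locked method calls another. <code>threading.RLock</code> allows re-entry by the owning thread. A small sketch with a hypothetical <code>Account</code> class:</p>

```python
import threading


class Account:
    """Hypothetical example: locked methods that call each other."""

    def __init__(self, balance=0):
        self.balance = balance
        # An RLock may be re-acquired by the thread that already holds it;
        # a plain Lock here would deadlock inside deposit_twice()
        self._lock = threading.RLock()

    def deposit(self, amount):
        with self._lock:
            self.balance += amount

    def deposit_twice(self, amount):
        with self._lock:
            # Re-enters the same lock from the same thread
            self.deposit(amount)
            self.deposit(amount)


account = Account()
threads = [threading.Thread(target=account.deposit_twice, args=(10,)) for _ in range(5)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(account.balance)  # 100
```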
</section>
<section id="semaphores" class="level3">
<h3 class="anchored" data-anchor-id="semaphores" id="semaphores">Semaphores</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb13"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> threading</span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-4"><a href="#cb13-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Semaphore to limit concurrent access</span></span>
<span id="cb13-5"><a href="#cb13-5" aria-hidden="true" tabindex="-1"></a>semaphore <span class="op">=</span> threading.Semaphore(<span class="dv">2</span>)  <span class="co"># Allow 2 concurrent accesses</span></span>
<span id="cb13-6"><a href="#cb13-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-7"><a href="#cb13-7" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> access_resource(worker_id):</span>
<span id="cb13-8"><a href="#cb13-8" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> semaphore:</span>
<span id="cb13-9"><a href="#cb13-9" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Worker </span><span class="sc">{</span>worker_id<span class="sc">}</span><span class="ss"> accessing resource"</span>)</span>
<span id="cb13-10"><a href="#cb13-10" aria-hidden="true" tabindex="-1"></a>        time.sleep(<span class="dv">2</span>)  <span class="co"># Simulate work</span></span>
<span id="cb13-11"><a href="#cb13-11" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Worker </span><span class="sc">{</span>worker_id<span class="sc">}</span><span class="ss"> finished"</span>)</span>
<span id="cb13-12"><a href="#cb13-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-13"><a href="#cb13-13" aria-hidden="true" tabindex="-1"></a>threads <span class="op">=</span> []</span>
<span id="cb13-14"><a href="#cb13-14" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">5</span>):</span>
<span id="cb13-15"><a href="#cb13-15" aria-hidden="true" tabindex="-1"></a>    t <span class="op">=</span> threading.Thread(target<span class="op">=</span>access_resource, args<span class="op">=</span>(i,))</span>
<span id="cb13-16"><a href="#cb13-16" aria-hidden="true" tabindex="-1"></a>    threads.append(t)</span>
<span id="cb13-17"><a href="#cb13-17" aria-hidden="true" tabindex="-1"></a>    t.start()</span>
<span id="cb13-18"><a href="#cb13-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-19"><a href="#cb13-19" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> t <span class="kw">in</span> threads:</span>
<span id="cb13-20"><a href="#cb13-20" aria-hidden="true" tabindex="-1"></a>    t.join()</span></code></pre></div></div>
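<p>To check that the semaphore really caps concurrency, the sketch below records the peak number of workers inside the guarded block. It also swaps in <code>threading.BoundedSemaphore</code>, which raises <code>ValueError</code> when released more often than it was acquired, a cheap guard against mismatched <code>release()</code> calls. The counters <code>active</code> and <code>peak</code> are illustrative additions.</p>

```python
import threading
import time

MAX_CONCURRENT = 2
semaphore = threading.BoundedSemaphore(MAX_CONCURRENT)
counter_lock = threading.Lock()  # Protects the two counters below
active = 0  # Workers currently inside the semaphore
peak = 0    # Highest value `active` has reached


def access_resource():
    global active, peak
    with semaphore:
        with counter_lock:
            active += 1
            peak = max(peak, active)
        time.sleep(0.1)  # Simulate work while holding a slot
        with counter_lock:
            active -= 1


threads = [threading.Thread(target=access_resource) for _ in range(5)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(f"Peak concurrent workers: {peak}")  # Never more than MAX_CONCURRENT

# Every acquire has been matched by a release, so one extra release is an error
over_release_detected = False
try:
    semaphore.release()
except ValueError:
    over_release_detected = True
print(f"Over-release detected: {over_release_detected}")
```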
</section>
<section id="condition-variables" class="level3">
<h3 class="anchored" data-anchor-id="condition-variables" id="condition-variables">Condition Variables</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb14"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb14-1"><a href="#cb14-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> threading</span>
<span id="cb14-2"><a href="#cb14-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb14-3"><a href="#cb14-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> random</span>
<span id="cb14-4"><a href="#cb14-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-5"><a href="#cb14-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Producer-Consumer with Condition</span></span>
<span id="cb14-6"><a href="#cb14-6" aria-hidden="true" tabindex="-1"></a>condition <span class="op">=</span> threading.Condition()</span>
<span id="cb14-7"><a href="#cb14-7" aria-hidden="true" tabindex="-1"></a><span class="bu">buffer</span> <span class="op">=</span> []</span>
<span id="cb14-8"><a href="#cb14-8" aria-hidden="true" tabindex="-1"></a>MAX_SIZE <span class="op">=</span> <span class="dv">5</span></span>
<span id="cb14-9"><a href="#cb14-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-10"><a href="#cb14-10" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> producer():</span>
<span id="cb14-11"><a href="#cb14-11" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">10</span>):</span>
<span id="cb14-12"><a href="#cb14-12" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> condition:</span>
<span id="cb14-13"><a href="#cb14-13" aria-hidden="true" tabindex="-1"></a>            <span class="cf">while</span> <span class="bu">len</span>(<span class="bu">buffer</span>) <span class="op">&gt;=</span> MAX_SIZE:</span>
<span id="cb14-14"><a href="#cb14-14" aria-hidden="true" tabindex="-1"></a>                <span class="bu">print</span>(<span class="st">"Buffer full, producer waiting..."</span>)</span>
<span id="cb14-15"><a href="#cb14-15" aria-hidden="true" tabindex="-1"></a>                condition.wait()</span>
<span id="cb14-16"><a href="#cb14-16" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb14-17"><a href="#cb14-17" aria-hidden="true" tabindex="-1"></a>            item <span class="op">=</span> <span class="ss">f"item_</span><span class="sc">{</span>i<span class="sc">}</span><span class="ss">"</span></span>
<span id="cb14-18"><a href="#cb14-18" aria-hidden="true" tabindex="-1"></a>            <span class="bu">buffer</span>.append(item)</span>
<span id="cb14-19"><a href="#cb14-19" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f"Produced: </span><span class="sc">{</span>item<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb14-20"><a href="#cb14-20" aria-hidden="true" tabindex="-1"></a>            condition.notify_all()</span>
<span id="cb14-21"><a href="#cb14-21" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-22"><a href="#cb14-22" aria-hidden="true" tabindex="-1"></a>        time.sleep(random.uniform(<span class="fl">0.1</span>, <span class="fl">0.5</span>))</span>
<span id="cb14-23"><a href="#cb14-23" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-24"><a href="#cb14-24" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> consumer(consumer_id):</span>
<span id="cb14-25"><a href="#cb14-25" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">5</span>):</span>
<span id="cb14-26"><a href="#cb14-26" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> condition:</span>
<span id="cb14-27"><a href="#cb14-27" aria-hidden="true" tabindex="-1"></a>            <span class="cf">while</span> <span class="kw">not</span> <span class="bu">buffer</span>:</span>
<span id="cb14-28"><a href="#cb14-28" aria-hidden="true" tabindex="-1"></a>                <span class="bu">print</span>(<span class="ss">f"Consumer </span><span class="sc">{</span>consumer_id<span class="sc">}</span><span class="ss"> waiting..."</span>)</span>
<span id="cb14-29"><a href="#cb14-29" aria-hidden="true" tabindex="-1"></a>                condition.wait()</span>
<span id="cb14-30"><a href="#cb14-30" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb14-31"><a href="#cb14-31" aria-hidden="true" tabindex="-1"></a>            item <span class="op">=</span> <span class="bu">buffer</span>.pop(<span class="dv">0</span>)</span>
<span id="cb14-32"><a href="#cb14-32" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f"Consumer </span><span class="sc">{</span>consumer_id<span class="sc">}</span><span class="ss"> consumed: </span><span class="sc">{</span>item<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb14-33"><a href="#cb14-33" aria-hidden="true" tabindex="-1"></a>            condition.notify_all()</span>
<span id="cb14-34"><a href="#cb14-34" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-35"><a href="#cb14-35" aria-hidden="true" tabindex="-1"></a>        time.sleep(random.uniform(<span class="fl">0.1</span>, <span class="fl">0.5</span>))</span>
<span id="cb14-36"><a href="#cb14-36" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-37"><a href="#cb14-37" aria-hidden="true" tabindex="-1"></a><span class="co"># Start producer and consumers</span></span>
<span id="cb14-38"><a href="#cb14-38" aria-hidden="true" tabindex="-1"></a>producer_thread <span class="op">=</span> threading.Thread(target<span class="op">=</span>producer)</span>
<span id="cb14-39"><a href="#cb14-39" aria-hidden="true" tabindex="-1"></a>consumer_threads <span class="op">=</span> [threading.Thread(target<span class="op">=</span>consumer, args<span class="op">=</span>(i,)) <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">2</span>)]</span>
<span id="cb14-40"><a href="#cb14-40" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-41"><a href="#cb14-41" aria-hidden="true" tabindex="-1"></a>producer_thread.start()</span>
<span id="cb14-42"><a href="#cb14-42" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> t <span class="kw">in</span> consumer_threads:</span>
<span id="cb14-43"><a href="#cb14-43" aria-hidden="true" tabindex="-1"></a>    t.start()</span>
<span id="cb14-44"><a href="#cb14-44" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-45"><a href="#cb14-45" aria-hidden="true" tabindex="-1"></a>producer_thread.join()</span>
<span id="cb14-46"><a href="#cb14-46" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> t <span class="kw">in</span> consumer_threads:</span>
<span id="cb14-47"><a href="#cb14-47" aria-hidden="true" tabindex="-1"></a>    t.join()</span></code></pre></div></div>
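<p><code>Condition.wait_for(predicate)</code> wraps the <code>while ...: condition.wait()</code> loop, re-checking the predicate after each wakeup. A sketch of the same bounded buffer using it, with a <code>collections.deque</code> so that taking from the front is O(1) instead of the O(n) <code>list.pop(0)</code>; the <code>consumed</code> list is added only to make the result inspectable.</p>

```python
import threading
from collections import deque

condition = threading.Condition()
buffer = deque()
MAX_SIZE = 5
consumed = []


def producer(n_items):
    for i in range(n_items):
        with condition:
            # Blocks until there is room; the predicate is re-checked on every wakeup
            condition.wait_for(lambda: len(buffer) < MAX_SIZE)
            buffer.append(f"item_{i}")
            condition.notify_all()


def consumer(n_items):
    for _ in range(n_items):
        with condition:
            condition.wait_for(lambda: len(buffer) > 0)
            consumed.append(buffer.popleft())
            condition.notify_all()


producer_thread = threading.Thread(target=producer, args=(10,))
consumer_threads = [threading.Thread(target=consumer, args=(5,)) for _ in range(2)]
all_threads = [producer_thread, *consumer_threads]
for t in all_threads:
    t.start()
for t in all_threads:
    t.join()
print(consumed)
```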
</section>
</section>
<section id="performance-comparison" class="level2">
<h2 class="anchored" data-anchor-id="performance-comparison" id="performance-comparison">Performance Comparison</h2>
<section id="io-bound-tasks" class="level3">
<h3 class="anchored" data-anchor-id="io-bound-tasks" id="io-bound-tasks">I/O-Bound Tasks</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb15"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb15-1"><a href="#cb15-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb15-2"><a href="#cb15-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> requests</span>
<span id="cb15-3"><a href="#cb15-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> threading</span>
<span id="cb15-4"><a href="#cb15-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> multiprocessing</span>
<span id="cb15-5"><a href="#cb15-5" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> concurrent.futures <span class="im">import</span> ThreadPoolExecutor, ProcessPoolExecutor</span>
<span id="cb15-6"><a href="#cb15-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-7"><a href="#cb15-7" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> fetch_url(url):</span>
<span id="cb15-8"><a href="#cb15-8" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Simulate I/O-bound task"""</span></span>
<span id="cb15-9"><a href="#cb15-9" aria-hidden="true" tabindex="-1"></a>    <span class="cf">try</span>:</span>
<span id="cb15-10"><a href="#cb15-10" aria-hidden="true" tabindex="-1"></a>        response <span class="op">=</span> requests.get(url, timeout<span class="op">=</span><span class="dv">5</span>)</span>
<span id="cb15-11"><a href="#cb15-11" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="ss">f"Status: </span><span class="sc">{</span>response<span class="sc">.</span>status_code<span class="sc">}</span><span class="ss">"</span></span>
<span id="cb15-12"><a href="#cb15-12" aria-hidden="true" tabindex="-1"></a>    <span class="cf">except</span> requests.RequestException:  <span class="co"># timeouts, connection errors, etc.</span></span>
<span id="cb15-13"><a href="#cb15-13" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="st">"Error"</span></span>
<span id="cb15-14"><a href="#cb15-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-15"><a href="#cb15-15" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> time_execution(func, <span class="op">*</span>args, <span class="op">**</span>kwargs):</span>
<span id="cb15-16"><a href="#cb15-16" aria-hidden="true" tabindex="-1"></a>    start <span class="op">=</span> time.time()</span>
<span id="cb15-17"><a href="#cb15-17" aria-hidden="true" tabindex="-1"></a>    result <span class="op">=</span> func(<span class="op">*</span>args, <span class="op">**</span>kwargs)</span>
<span id="cb15-18"><a href="#cb15-18" aria-hidden="true" tabindex="-1"></a>    end <span class="op">=</span> time.time()</span>
<span id="cb15-19"><a href="#cb15-19" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> result, end <span class="op">-</span> start</span>
<span id="cb15-20"><a href="#cb15-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-21"><a href="#cb15-21" aria-hidden="true" tabindex="-1"></a><span class="co"># Sequential execution</span></span>
<span id="cb15-22"><a href="#cb15-22" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> sequential_fetch(urls):</span>
<span id="cb15-23"><a href="#cb15-23" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> [fetch_url(url) <span class="cf">for</span> url <span class="kw">in</span> urls]</span>
<span id="cb15-24"><a href="#cb15-24" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-25"><a href="#cb15-25" aria-hidden="true" tabindex="-1"></a><span class="co"># Threaded execution</span></span>
<span id="cb15-26"><a href="#cb15-26" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> threaded_fetch(urls):</span>
<span id="cb15-27"><a href="#cb15-27" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> ThreadPoolExecutor(max_workers<span class="op">=</span><span class="dv">10</span>) <span class="im">as</span> executor:</span>
<span id="cb15-28"><a href="#cb15-28" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="bu">list</span>(executor.<span class="bu">map</span>(fetch_url, urls))</span>
<span id="cb15-29"><a href="#cb15-29" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-30"><a href="#cb15-30" aria-hidden="true" tabindex="-1"></a><span class="co"># Process execution</span></span>
<span id="cb15-31"><a href="#cb15-31" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> process_fetch(urls):</span>
<span id="cb15-32"><a href="#cb15-32" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> ProcessPoolExecutor(max_workers<span class="op">=</span><span class="dv">10</span>) <span class="im">as</span> executor:</span>
<span id="cb15-33"><a href="#cb15-33" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="bu">list</span>(executor.<span class="bu">map</span>(fetch_url, urls))</span>
<span id="cb15-34"><a href="#cb15-34" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-35"><a href="#cb15-35" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb15-36"><a href="#cb15-36" aria-hidden="true" tabindex="-1"></a>    urls <span class="op">=</span> [<span class="st">'https://httpbin.org/delay/1'</span>] <span class="op">*</span> <span class="dv">10</span></span>
<span id="cb15-37"><a href="#cb15-37" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-38"><a href="#cb15-38" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Compare performance</span></span>
<span id="cb15-39"><a href="#cb15-39" aria-hidden="true" tabindex="-1"></a>    _, seq_time <span class="op">=</span> time_execution(sequential_fetch, urls)</span>
<span id="cb15-40"><a href="#cb15-40" aria-hidden="true" tabindex="-1"></a>    _, thread_time <span class="op">=</span> time_execution(threaded_fetch, urls)</span>
<span id="cb15-41"><a href="#cb15-41" aria-hidden="true" tabindex="-1"></a>    _, process_time <span class="op">=</span> time_execution(process_fetch, urls)</span>
<span id="cb15-42"><a href="#cb15-42" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-43"><a href="#cb15-43" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Sequential: </span><span class="sc">{</span>seq_time<span class="sc">:.2f}</span><span class="ss">s"</span>)</span>
<span id="cb15-44"><a href="#cb15-44" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Threading: </span><span class="sc">{</span>thread_time<span class="sc">:.2f}</span><span class="ss">s"</span>)</span>
<span id="cb15-45"><a href="#cb15-45" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Multiprocessing: </span><span class="sc">{</span>process_time<span class="sc">:.2f}</span><span class="ss">s"</span>)</span></code></pre></div></div>
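<p>Because the benchmark above calls <code>https://httpbin.org/delay/1</code>, its timings depend on network conditions and on that service being reachable. A minimal offline sketch of the same comparison makes the effect reproducible. It assumes <code>time.sleep</code> is an acceptable stand-in for an I/O wait (like a blocking socket read, it releases the GIL). <code>fake_io</code>, <code>timed</code>, <code>run_sequential</code> and <code>run_threaded</code> are illustrative names, not library functions.</p>

```python
import time
from concurrent.futures import ThreadPoolExecutor


def fake_io(delay):
    """Stand-in for a network call: sleeping releases the GIL like real I/O waits."""
    time.sleep(delay)
    return delay


def timed(func, *args):
    # perf_counter is a monotonic, high-resolution clock suited to interval timing
    start = time.perf_counter()
    result = func(*args)
    return result, time.perf_counter() - start


def run_sequential(delays):
    return [fake_io(d) for d in delays]


def run_threaded(delays, max_workers=10):
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # executor.map returns results in input order
        return list(executor.map(fake_io, delays))


if __name__ == "__main__":
    delays = [0.2] * 10
    seq_result, seq_time = timed(run_sequential, delays)
    thr_result, thr_time = timed(run_threaded, delays)
    assert seq_result == thr_result
    print(f"Sequential: {seq_time:.2f}s")
    print(f"Threading:  {thr_time:.2f}s")
```

<p>With ten 0.2-second waits, the sequential run cannot finish in under 2 seconds, while the threaded run should take roughly the time of a single wait because all ten sleeps overlap.</p>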
</section>
<section id="cpu-bound-tasks" class="level3">
<h3 class="anchored" data-anchor-id="cpu-bound-tasks" id="cpu-bound-tasks">CPU-Bound Tasks</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb16"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb16-1"><a href="#cb16-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb16-2"><a href="#cb16-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> multiprocessing</span>
<span id="cb16-3"><a href="#cb16-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> concurrent.futures <span class="im">import</span> ThreadPoolExecutor, ProcessPoolExecutor</span>
<span id="cb16-4"><a href="#cb16-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-5"><a href="#cb16-5" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> cpu_bound_task(n):</span>
<span id="cb16-6"><a href="#cb16-6" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""CPU-intensive computation"""</span></span>
<span id="cb16-7"><a href="#cb16-7" aria-hidden="true" tabindex="-1"></a>    total <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb16-8"><a href="#cb16-8" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(n):</span>
<span id="cb16-9"><a href="#cb16-9" aria-hidden="true" tabindex="-1"></a>        total <span class="op">+=</span> i <span class="op">**</span> <span class="dv">2</span></span>
<span id="cb16-10"><a href="#cb16-10" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> total</span>
<span id="cb16-11"><a href="#cb16-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-12"><a href="#cb16-12" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> compare_performance():</span>
<span id="cb16-13"><a href="#cb16-13" aria-hidden="true" tabindex="-1"></a>    numbers <span class="op">=</span> [<span class="dv">1000000</span>] <span class="op">*</span> <span class="dv">8</span></span>
<span id="cb16-14"><a href="#cb16-14" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-15"><a href="#cb16-15" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Sequential</span></span>
<span id="cb16-16"><a href="#cb16-16" aria-hidden="true" tabindex="-1"></a>    start <span class="op">=</span> time.time()</span>
<span id="cb16-17"><a href="#cb16-17" aria-hidden="true" tabindex="-1"></a>    sequential_results <span class="op">=</span> [cpu_bound_task(n) <span class="cf">for</span> n <span class="kw">in</span> numbers]</span>
<span id="cb16-18"><a href="#cb16-18" aria-hidden="true" tabindex="-1"></a>    sequential_time <span class="op">=</span> time.time() <span class="op">-</span> start</span>
<span id="cb16-19"><a href="#cb16-19" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-20"><a href="#cb16-20" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Threading</span></span>
<span id="cb16-21"><a href="#cb16-21" aria-hidden="true" tabindex="-1"></a>    start <span class="op">=</span> time.time()</span>
<span id="cb16-22"><a href="#cb16-22" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> ThreadPoolExecutor(max_workers<span class="op">=</span><span class="dv">8</span>) <span class="im">as</span> executor:</span>
<span id="cb16-23"><a href="#cb16-23" aria-hidden="true" tabindex="-1"></a>        thread_results <span class="op">=</span> <span class="bu">list</span>(executor.<span class="bu">map</span>(cpu_bound_task, numbers))</span>
<span id="cb16-24"><a href="#cb16-24" aria-hidden="true" tabindex="-1"></a>    thread_time <span class="op">=</span> time.time() <span class="op">-</span> start</span>
<span id="cb16-25"><a href="#cb16-25" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-26"><a href="#cb16-26" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Multiprocessing</span></span>
<span id="cb16-27"><a href="#cb16-27" aria-hidden="true" tabindex="-1"></a>    start <span class="op">=</span> time.time()</span>
<span id="cb16-28"><a href="#cb16-28" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> ProcessPoolExecutor(max_workers<span class="op">=</span><span class="dv">8</span>) <span class="im">as</span> executor:</span>
<span id="cb16-29"><a href="#cb16-29" aria-hidden="true" tabindex="-1"></a>        process_results <span class="op">=</span> <span class="bu">list</span>(executor.<span class="bu">map</span>(cpu_bound_task, numbers))</span>
<span id="cb16-30"><a href="#cb16-30" aria-hidden="true" tabindex="-1"></a>    process_time <span class="op">=</span> time.time() <span class="op">-</span> start</span>
<span id="cb16-31"><a href="#cb16-31" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-32"><a href="#cb16-32" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="st">"CPU-bound task comparison:"</span>)</span>
<span id="cb16-33"><a href="#cb16-33" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Sequential: </span><span class="sc">{</span>sequential_time<span class="sc">:.2f}</span><span class="ss">s"</span>)</span>
<span id="cb16-34"><a href="#cb16-34" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Threading: </span><span class="sc">{</span>thread_time<span class="sc">:.2f}</span><span class="ss">s"</span>)</span>
<span id="cb16-35"><a href="#cb16-35" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Multiprocessing: </span><span class="sc">{</span>process_time<span class="sc">:.2f}</span><span class="ss">s"</span>)</span>
<span id="cb16-36"><a href="#cb16-36" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Process speedup: </span><span class="sc">{</span>sequential_time<span class="op">/</span>process_time<span class="sc">:.2f}</span><span class="ss">x"</span>)</span>
<span id="cb16-37"><a href="#cb16-37" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-38"><a href="#cb16-38" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb16-39"><a href="#cb16-39" aria-hidden="true" tabindex="-1"></a>    compare_performance()</span></code></pre></div></div>
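<p>The script above hard-codes eight workers and never checks that the parallel results are correct. The sketch below, using illustrative helper names, leaves the worker count to <code>ProcessPoolExecutor</code> (which defaults to the machine's processor count when <code>max_workers</code> is <code>None</code>) and checks every result against the closed form 0^2 + 1^2 + ... + (n - 1)^2 = (n - 1) * n * (2n - 1) / 6.</p>

```python
from concurrent.futures import ProcessPoolExecutor


def sum_of_squares(n):
    """Same work as cpu_bound_task: the sum of i**2 for i in range(n)."""
    return sum(i * i for i in range(n))


def closed_form(n):
    """Exact value of 0**2 + 1**2 + ... + (n - 1)**2, i.e. (n - 1) * n * (2 * n - 1) / 6."""
    return (n - 1) * n * (2 * n - 1) // 6


def parallel_sum_of_squares(numbers, max_workers=None):
    # max_workers=None lets the executor default to the processor count
    # instead of a hard-coded 8
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        return list(executor.map(sum_of_squares, numbers))


if __name__ == "__main__":
    numbers = [1_000_000] * 8
    results = parallel_sum_of_squares(numbers)
    # A speedup only counts if the parallel answers are still correct
    assert results == [closed_form(n) for n in numbers]
    print("All parallel results match the closed form")
```

<p>Keep the <code>if __name__ == "__main__":</code> guard. On platforms that start worker processes with <code>spawn</code> (the default on Windows and macOS), each worker re-imports the main module, and unguarded pool creation would run again in every child.</p>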
</section>
</section>
<section id="best-practices" class="level2">
<h2 class="anchored" data-anchor-id="best-practices" id="best-practices">Best Practices</h2>
<section id="choose-the-right-approach" class="level3">
<h3 class="anchored" data-anchor-id="choose-the-right-approach" id="choose-the-right-approach">1. Choose the Right Approach</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb17"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb17-1"><a href="#cb17-1" aria-hidden="true" tabindex="-1"></a><span class="co"># For I/O-bound tasks: Use threading</span></span>
<span id="cb17-2"><a href="#cb17-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> threading</span>
<span id="cb17-3"><a href="#cb17-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> concurrent.futures <span class="im">import</span> ThreadPoolExecutor</span>
<span id="cb17-4"><a href="#cb17-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-5"><a href="#cb17-5" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> io_bound_work():</span>
<span id="cb17-6"><a href="#cb17-6" aria-hidden="true" tabindex="-1"></a>    <span class="co"># File operations, network requests, database queries</span></span>
<span id="cb17-7"><a href="#cb17-7" aria-hidden="true" tabindex="-1"></a>    <span class="cf">pass</span></span>
<span id="cb17-8"><a href="#cb17-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-9"><a href="#cb17-9" aria-hidden="true" tabindex="-1"></a><span class="co"># For CPU-bound tasks: Use multiprocessing</span></span>
<span id="cb17-10"><a href="#cb17-10" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> multiprocessing</span>
<span id="cb17-11"><a href="#cb17-11" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> concurrent.futures <span class="im">import</span> ProcessPoolExecutor</span>
<span id="cb17-12"><a href="#cb17-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-13"><a href="#cb17-13" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> cpu_bound_work():</span>
<span id="cb17-14"><a href="#cb17-14" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Mathematical computations, image processing, data analysis</span></span>
<span id="cb17-15"><a href="#cb17-15" aria-hidden="true" tabindex="-1"></a>    <span class="cf">pass</span></span></code></pre></div></div>
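<p>The rule of thumb above can be folded into a small helper. This is a sketch with a hypothetical <code>run_parallel</code> function and a made-up <code>kind</code> parameter. It simply picks <code>ThreadPoolExecutor</code> for I/O-bound work and <code>ProcessPoolExecutor</code> for CPU-bound work, relying on the fact that both implement the same <code>Executor</code> interface.</p>

```python
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor


def run_parallel(func, items, kind="io", max_workers=None):
    """Map func over items, choosing the pool type from the workload kind.

    kind="io":  threads, which are cheap and fine while they wait on I/O
    kind="cpu": processes, each with its own interpreter and GIL
    """
    if kind == "io":
        executor_cls = ThreadPoolExecutor
    elif kind == "cpu":
        executor_cls = ProcessPoolExecutor
    else:
        raise ValueError(f"unknown kind: {kind!r}")
    with executor_cls(max_workers=max_workers) as executor:
        return list(executor.map(func, items))


def square(x):
    return x * x


if __name__ == "__main__":
    print(run_parallel(square, range(5), kind="io"))   # [0, 1, 4, 9, 16]
    print(run_parallel(square, range(5), kind="cpu"))  # [0, 1, 4, 9, 16]
```

<p>Because both executors expose the same <code>map</code> and <code>submit</code> methods, switching strategies only means changing the class. When <code>kind="cpu"</code>, the function passed in must be defined at module level, since it has to be pickled to reach the worker processes.</p>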
</section>
<section id="resource-management" class="level3">
<h3 class="anchored" data-anchor-id="resource-management" id="resource-management">2. Resource Management</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb18"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb18-1"><a href="#cb18-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> concurrent.futures <span class="im">import</span> ThreadPoolExecutor, ProcessPoolExecutor</span>
<span id="cb18-2"><a href="#cb18-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> contextlib <span class="im">import</span> contextmanager</span>
<span id="cb18-3"><a href="#cb18-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-4"><a href="#cb18-4" aria-hidden="true" tabindex="-1"></a><span class="at">@contextmanager</span></span>
<span id="cb18-5"><a href="#cb18-5" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> managed_thread_pool(max_workers):</span>
<span id="cb18-6"><a href="#cb18-6" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> ThreadPoolExecutor(max_workers<span class="op">=</span>max_workers) <span class="im">as</span> executor:</span>
<span id="cb18-7"><a href="#cb18-7" aria-hidden="true" tabindex="-1"></a>        <span class="cf">yield</span> executor</span>
<span id="cb18-8"><a href="#cb18-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-9"><a href="#cb18-9" aria-hidden="true" tabindex="-1"></a><span class="at">@contextmanager</span></span>
<span id="cb18-10"><a href="#cb18-10" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> managed_process_pool(max_workers):</span>
<span id="cb18-11"><a href="#cb18-11" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> ProcessPoolExecutor(max_workers<span class="op">=</span>max_workers) <span class="im">as</span> executor:</span>
<span id="cb18-12"><a href="#cb18-12" aria-hidden="true" tabindex="-1"></a>        <span class="cf">yield</span> executor</span>
<span id="cb18-13"><a href="#cb18-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-14"><a href="#cb18-14" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage: some_function and args are placeholders for your own workload</span></span>
<span id="cb18-15"><a href="#cb18-15" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> some_function(x):</span>
<span id="cb18-16"><a href="#cb18-16" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> x <span class="op">*</span> x</span>
<span id="cb18-17"><a href="#cb18-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-18"><a href="#cb18-18" aria-hidden="true" tabindex="-1"></a>args <span class="op">=</span> <span class="bu">range</span>(<span class="dv">10</span>)</span>
<span id="cb18-19"><a href="#cb18-19" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-20"><a href="#cb18-20" aria-hidden="true" tabindex="-1"></a><span class="cf">with</span> managed_thread_pool(<span class="dv">4</span>) <span class="im">as</span> executor:</span>
<span id="cb18-21"><a href="#cb18-21" aria-hidden="true" tabindex="-1"></a>    futures <span class="op">=</span> [executor.submit(some_function, arg) <span class="cf">for</span> arg <span class="kw">in</span> args]</span>
<span id="cb18-22"><a href="#cb18-22" aria-hidden="true" tabindex="-1"></a>    results <span class="op">=</span> [future.result() <span class="cf">for</span> future <span class="kw">in</span> futures]</span></code></pre></div></div>
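<p>The wrapper context managers above add little on their own, because <code>ThreadPoolExecutor</code> and <code>ProcessPoolExecutor</code> are already context managers whose exit calls <code>shutdown(wait=True)</code>. When you need finer control, for example abandoning queued work once an early result is enough, you can call <code>shutdown</code> yourself. This is a sketch assuming Python 3.9+ for the <code>cancel_futures</code> argument; <code>slow_task</code> and <code>run_with_explicit_shutdown</code> are hypothetical names.</p>

```python
import time
from concurrent.futures import ThreadPoolExecutor


def slow_task(x):
    time.sleep(0.1)  # stand-in for real work
    return x


def run_with_explicit_shutdown(n_tasks=50):
    executor = ThreadPoolExecutor(max_workers=2)
    futures = [executor.submit(slow_task, i) for i in range(n_tasks)]
    try:
        # Suppose only the first result is needed and the rest can be dropped
        first = futures[0].result()
    finally:
        # wait=True blocks until already-running tasks finish;
        # cancel_futures=True (Python 3.9+) cancels tasks that have not started
        executor.shutdown(wait=True, cancel_futures=True)
    cancelled = sum(f.cancelled() for f in futures)
    return first, cancelled


if __name__ == "__main__":
    first, cancelled = run_with_explicit_shutdown()
    print(f"first result: {first}, cancelled before starting: {cancelled}")
```

<p>Tasks that were already running when <code>shutdown</code> was called still complete; only the ones still waiting in the queue are cancelled.</p>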
</section>
<section id="error-handling" class="level3">
<h3 class="anchored" data-anchor-id="error-handling" id="error-handling">3. Error Handling</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb19"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb19-1"><a href="#cb19-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> threading</span>
<span id="cb19-2"><a href="#cb19-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> multiprocessing</span>
<span id="cb19-3"><a href="#cb19-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> logging</span>
<span id="cb19-4"><a href="#cb19-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> concurrent.futures <span class="im">import</span> ThreadPoolExecutor, as_completed</span>
<span id="cb19-5"><a href="#cb19-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-6"><a href="#cb19-6" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> safe_worker(task_id):</span>
<span id="cb19-7"><a href="#cb19-7" aria-hidden="true" tabindex="-1"></a>    <span class="cf">try</span>:</span>
<span id="cb19-8"><a href="#cb19-8" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Your work here</span></span>
<span id="cb19-9"><a href="#cb19-9" aria-hidden="true" tabindex="-1"></a>        result <span class="op">=</span> <span class="ss">f"Task </span><span class="sc">{</span>task_id<span class="sc">}</span><span class="ss"> completed"</span></span>
<span id="cb19-10"><a href="#cb19-10" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> result</span>
<span id="cb19-11"><a href="#cb19-11" aria-hidden="true" tabindex="-1"></a>    <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:</span>
<span id="cb19-12"><a href="#cb19-12" aria-hidden="true" tabindex="-1"></a>        logging.error(<span class="ss">f"Task </span><span class="sc">{</span>task_id<span class="sc">}</span><span class="ss"> failed: </span><span class="sc">{</span>e<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb19-13"><a href="#cb19-13" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">None</span></span>
<span id="cb19-14"><a href="#cb19-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-15"><a href="#cb19-15" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> execute_with_error_handling():</span>
<span id="cb19-16"><a href="#cb19-16" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> ThreadPoolExecutor(max_workers<span class="op">=</span><span class="dv">4</span>) <span class="im">as</span> executor:</span>
<span id="cb19-17"><a href="#cb19-17" aria-hidden="true" tabindex="-1"></a>        futures <span class="op">=</span> [executor.submit(safe_worker, i) <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">10</span>)]</span>
<span id="cb19-18"><a href="#cb19-18" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb19-19"><a href="#cb19-19" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> future <span class="kw">in</span> as_completed(futures):</span>
<span id="cb19-20"><a href="#cb19-20" aria-hidden="true" tabindex="-1"></a>            <span class="cf">try</span>:</span>
<span id="cb19-21"><a href="#cb19-21" aria-hidden="true" tabindex="-1"></a>                result <span class="op">=</span> future.result()</span>
<span id="cb19-22"><a href="#cb19-22" aria-hidden="true" tabindex="-1"></a>                <span class="cf">if</span> result:</span>
<span id="cb19-23"><a href="#cb19-23" aria-hidden="true" tabindex="-1"></a>                    <span class="bu">print</span>(result)</span>
<span id="cb19-24"><a href="#cb19-24" aria-hidden="true" tabindex="-1"></a>            <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:</span>
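<span id="cb19-25"><a href="#cb19-25" aria-hidden="true" tabindex="-1"></a>                logging.error(<span class="ss">f"Future failed: </span><span class="sc">{</span>e<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb19-26"><a href="#cb19-26" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-27"><a href="#cb19-27" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb19-28"><a href="#cb19-28" aria-hidden="true" tabindex="-1"></a>    execute_with_error_handling()</span></code></pre></div></div>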
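<p>In the pattern above, <code>safe_worker</code> catches every exception itself, so the <code>except</code> around <code>future.result()</code> only fires for errors raised outside the worker body. An alternative, sketched here under the assumption that workers are allowed to raise, lets exceptions travel through the future: an exception raised in a worker is stored on its <code>Future</code>, re-raised by <code>result()</code>, and returned by <code>exception()</code>. <code>risky_worker</code> and <code>collect_results</code> are illustrative names, and the "every third task fails" rule is invented for the demo.</p>

```python
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed

logging.basicConfig(level=logging.INFO)


def risky_worker(task_id):
    # Hypothetical failure rule: every third task raises instead of returning
    if task_id % 3 == 0:
        raise ValueError(f"task {task_id} hit bad input")
    return task_id * 10


def collect_results(task_ids):
    results, errors = {}, {}
    with ThreadPoolExecutor(max_workers=4) as executor:
        # Map each future back to its input so failures can be attributed
        future_to_id = {executor.submit(risky_worker, t): t for t in task_ids}
        for future in as_completed(future_to_id):
            task_id = future_to_id[future]
            exc = future.exception()  # None if the call returned normally
            if exc is not None:
                logging.error("Task %s failed: %s", task_id, exc)
                errors[task_id] = exc
            else:
                results[task_id] = future.result()
    return results, errors


if __name__ == "__main__":
    results, errors = collect_results(range(6))
    print(sorted(results))  # [1, 2, 4, 5]
    print(sorted(errors))   # [0, 3]
```

<p>Keeping the future-to-input mapping means a failure report can name the exact input that caused it, which is often the first thing you need when debugging a batch job.</p>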
</section>
<section id="graceful-shutdown" class="level3">
<h3 class="anchored" data-anchor-id="graceful-shutdown" id="graceful-shutdown">4. Graceful Shutdown</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb20"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb20-1"><a href="#cb20-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> threading</span>
<span id="cb20-2"><a href="#cb20-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb20-3"><a href="#cb20-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> signal</span>
<span id="cb20-4"><a href="#cb20-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> sys</span>
<span id="cb20-5"><a href="#cb20-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb20-6"><a href="#cb20-6" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> GracefulWorker:</span>
<span id="cb20-7"><a href="#cb20-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>):</span>
<span id="cb20-8"><a href="#cb20-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.shutdown_event <span class="op">=</span> threading.Event()</span>
<span id="cb20-9"><a href="#cb20-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.threads <span class="op">=</span> []</span>
<span id="cb20-10"><a href="#cb20-10" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb20-11"><a href="#cb20-11" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> worker(<span class="va">self</span>, worker_id):</span>
<span id="cb20-12"><a href="#cb20-12" aria-hidden="true" tabindex="-1"></a>        <span class="cf">while</span> <span class="kw">not</span> <span class="va">self</span>.shutdown_event.is_set():</span>
<span id="cb20-13"><a href="#cb20-13" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f"Worker </span><span class="sc">{</span>worker_id<span class="sc">}</span><span class="ss"> working..."</span>)</span>
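<span id="cb20-14"><a href="#cb20-14" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.shutdown_event.wait(<span class="dv">1</span>)  <span class="co"># returns early once shutdown is requested</span></span>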
<span id="cb20-15"><a href="#cb20-15" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Worker </span><span class="sc">{</span>worker_id<span class="sc">}</span><span class="ss"> shutting down"</span>)</span>
<span id="cb20-16"><a href="#cb20-16" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb20-17"><a href="#cb20-17" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> start_workers(<span class="va">self</span>, num_workers):</span>
<span id="cb20-18"><a href="#cb20-18" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(num_workers):</span>
<span id="cb20-19"><a href="#cb20-19" aria-hidden="true" tabindex="-1"></a>            t <span class="op">=</span> threading.Thread(target<span class="op">=</span><span class="va">self</span>.worker, args<span class="op">=</span>(i,))</span>
<span id="cb20-20"><a href="#cb20-20" aria-hidden="true" tabindex="-1"></a>            t.start()</span>
<span id="cb20-21"><a href="#cb20-21" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.threads.append(t)</span>
<span id="cb20-22"><a href="#cb20-22" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb20-23"><a href="#cb20-23" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> shutdown(<span class="va">self</span>):</span>
<span id="cb20-24"><a href="#cb20-24" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="st">"Initiating graceful shutdown..."</span>)</span>
<span id="cb20-25"><a href="#cb20-25" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.shutdown_event.<span class="bu">set</span>()</span>
<span id="cb20-26"><a href="#cb20-26" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> t <span class="kw">in</span> <span class="va">self</span>.threads:</span>
<span id="cb20-27"><a href="#cb20-27" aria-hidden="true" tabindex="-1"></a>            t.join()</span>
<span id="cb20-28"><a href="#cb20-28" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="st">"All workers shut down"</span>)</span>
<span id="cb20-29"><a href="#cb20-29" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb20-30"><a href="#cb20-30" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage</span></span>
<span id="cb20-31"><a href="#cb20-31" aria-hidden="true" tabindex="-1"></a>worker_manager <span class="op">=</span> GracefulWorker()</span>
<span id="cb20-32"><a href="#cb20-32" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb20-33"><a href="#cb20-33" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> signal_handler(signum, frame):</span>
<span id="cb20-34"><a href="#cb20-34" aria-hidden="true" tabindex="-1"></a>    <span class="co"># A custom SIGINT handler replaces Python's default one, so Ctrl+C</span></span>
<span id="cb20-35"><a href="#cb20-35" aria-hidden="true" tabindex="-1"></a>    <span class="co"># no longer raises KeyboardInterrupt: shutdown must happen here</span></span>
<span id="cb20-36"><a href="#cb20-36" aria-hidden="true" tabindex="-1"></a>    worker_manager.shutdown()</span>
<span id="cb20-37"><a href="#cb20-37" aria-hidden="true" tabindex="-1"></a>    sys.exit(<span class="dv">0</span>)</span>
<span id="cb20-38"><a href="#cb20-38" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb20-39"><a href="#cb20-39" aria-hidden="true" tabindex="-1"></a>signal.signal(signal.SIGINT, signal_handler)</span>
<span id="cb20-40"><a href="#cb20-40" aria-hidden="true" tabindex="-1"></a>worker_manager.start_workers(<span class="dv">3</span>)</span>
<span id="cb20-41"><a href="#cb20-41" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb20-42"><a href="#cb20-42" aria-hidden="true" tabindex="-1"></a><span class="co"># Keep the main thread alive: Python runs signal handlers only in the main thread</span></span>
<span id="cb20-43"><a href="#cb20-43" aria-hidden="true" tabindex="-1"></a><span class="cf">while</span> <span class="va">True</span>:</span>
<span id="cb20-44"><a href="#cb20-44" aria-hidden="true" tabindex="-1"></a>    time.sleep(<span class="dv">1</span>)</span></code></pre></div></div>
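<p>Calling <code>time.sleep</code> inside a worker loop delays shutdown by up to one full sleep interval, because the loop only checks the event between sleeps. Waiting on the event itself with a timeout avoids that: <code>Event.wait(timeout)</code> returns <code>True</code> as soon as <code>set()</code> is called and <code>False</code> when the timeout expires. The sketch below leaves out signal handling so it runs anywhere; <code>interruptible_worker</code>, <code>run_workers_briefly</code> and the <code>ticks</code> counter are hypothetical names used only for illustration.</p>

```python
import threading
import time


def interruptible_worker(worker_id, stop_event, ticks):
    # Event.wait(timeout) returns True as soon as the event is set (False on
    # timeout), so the loop exits promptly instead of finishing a full sleep
    while not stop_event.wait(timeout=0.05):
        ticks[worker_id] += 1  # stand-in for one unit of real work


def run_workers_briefly(num_workers=3, run_for=0.2):
    stop_event = threading.Event()
    ticks = [0] * num_workers  # each worker writes only its own slot
    threads = [
        threading.Thread(target=interruptible_worker, args=(i, stop_event, ticks))
        for i in range(num_workers)
    ]
    for t in threads:
        t.start()
    time.sleep(run_for)  # stand-in for "until a shutdown signal arrives"
    stop_event.set()
    for t in threads:
        t.join(timeout=1)  # bounded join so a stuck worker cannot hang shutdown
    return threads, ticks


if __name__ == "__main__":
    threads, ticks = run_workers_briefly()
    print("alive after shutdown:", [t.is_alive() for t in threads])
    print("ticks per worker:", ticks)
```

<p>The bounded <code>join(timeout=1)</code> is a design choice: shutdown never blocks forever, and a thread that is still alive afterwards can be detected with <code>is_alive()</code> and logged.</p>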
</section>
</section>
<section id="advanced-topics" class="level2">
<h2 class="anchored" data-anchor-id="advanced-topics" id="advanced-topics">Advanced Topics</h2>
<section id="custom-thread-pool" class="level3">
<h3 class="anchored" data-anchor-id="custom-thread-pool" id="custom-thread-pool">1. Custom Thread Pool</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb21"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb21-1"><a href="#cb21-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> threading</span>
<span id="cb21-2"><a href="#cb21-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> queue</span>
<span id="cb21-3"><a href="#cb21-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb21-4"><a href="#cb21-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-5"><a href="#cb21-5" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> SimpleThreadPool:</span>
<span id="cb21-6"><a href="#cb21-6" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, num_workers):</span>
<span id="cb21-7"><a href="#cb21-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.task_queue <span class="op">=</span> queue.Queue()</span>
<span id="cb21-8"><a href="#cb21-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.workers <span class="op">=</span> []</span>
<span id="cb21-9"><a href="#cb21-9" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb21-10"><a href="#cb21-10" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(num_workers):</span>
<span id="cb21-11"><a href="#cb21-11" aria-hidden="true" tabindex="-1"></a>            worker <span class="op">=</span> threading.Thread(target<span class="op">=</span><span class="va">self</span>._worker)</span>
<span id="cb21-12"><a href="#cb21-12" aria-hidden="true" tabindex="-1"></a>            worker.start()</span>
<span id="cb21-13"><a href="#cb21-13" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.workers.append(worker)</span>
<span id="cb21-14"><a href="#cb21-14" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb21-15"><a href="#cb21-15" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _worker(<span class="va">self</span>):</span>
<span id="cb21-16"><a href="#cb21-16" aria-hidden="true" tabindex="-1"></a>        <span class="cf">while</span> <span class="va">True</span>:</span>
<span id="cb21-17"><a href="#cb21-17" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Block until a task (or the None sentinel) is available</span></span>
<span id="cb21-18"><a href="#cb21-18" aria-hidden="true" tabindex="-1"></a>            task, args, kwargs <span class="op">=</span> <span class="va">self</span>.task_queue.get()</span>
<span id="cb21-19"><a href="#cb21-19" aria-hidden="true" tabindex="-1"></a>            <span class="cf">try</span>:</span>
<span id="cb21-20"><a href="#cb21-20" aria-hidden="true" tabindex="-1"></a>                <span class="cf">if</span> task <span class="kw">is</span> <span class="va">None</span>:  <span class="co"># sentinel from close(): stop this worker</span></span>
<span id="cb21-21"><a href="#cb21-21" aria-hidden="true" tabindex="-1"></a>                    <span class="cf">break</span></span>
<span id="cb21-22"><a href="#cb21-22" aria-hidden="true" tabindex="-1"></a>                task(<span class="op">*</span>args, <span class="op">**</span>kwargs)</span>
<span id="cb21-23"><a href="#cb21-23" aria-hidden="true" tabindex="-1"></a>            <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:</span>
<span id="cb21-24"><a href="#cb21-24" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Keep the worker alive if a single task fails</span></span>
<span id="cb21-25"><a href="#cb21-25" aria-hidden="true" tabindex="-1"></a>                <span class="bu">print</span>(<span class="ss">f"Task failed: </span><span class="sc">{</span>e<span class="sc">!r}</span><span class="ss">"</span>)</span>
<span id="cb21-26"><a href="#cb21-26" aria-hidden="true" tabindex="-1"></a>            <span class="cf">finally</span>:</span>
<span id="cb21-27"><a href="#cb21-27" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Mark every item (including sentinels) as handled so task_queue.join() can return</span></span>
<span id="cb21-28"><a href="#cb21-28" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.task_queue.task_done()</span>
<span id="cb21-29"><a href="#cb21-29" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb21-30"><a href="#cb21-30" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> submit(<span class="va">self</span>, task, <span class="op">*</span>args, <span class="op">**</span>kwargs):</span>
<span id="cb21-31"><a href="#cb21-31" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.task_queue.put((task, args, kwargs))</span>
<span id="cb21-32"><a href="#cb21-32" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb21-33"><a href="#cb21-33" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> close(<span class="va">self</span>):</span>
<span id="cb21-34"><a href="#cb21-34" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Sentinels are queued after every pending task (FIFO order),</span></span>
<span id="cb21-35"><a href="#cb21-35" aria-hidden="true" tabindex="-1"></a>        <span class="co"># so all submitted tasks are picked up before any worker stops</span></span>
<span id="cb21-36"><a href="#cb21-36" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> _ <span class="kw">in</span> <span class="va">self</span>.workers:</span>
<span id="cb21-37"><a href="#cb21-37" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.task_queue.put((<span class="va">None</span>, (), {}))</span>
<span id="cb21-38"><a href="#cb21-38" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> worker <span class="kw">in</span> <span class="va">self</span>.workers:</span>
<span id="cb21-39"><a href="#cb21-39" aria-hidden="true" tabindex="-1"></a>            worker.join()</span>
<span id="cb21-40"><a href="#cb21-40" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-41"><a href="#cb21-41" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage</span></span>
<span id="cb21-42"><a href="#cb21-42" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> sample_task(name, delay):</span>
<span id="cb21-43"><a href="#cb21-43" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Task </span><span class="sc">{</span>name<span class="sc">}</span><span class="ss"> starting"</span>)</span>
<span id="cb21-44"><a href="#cb21-44" aria-hidden="true" tabindex="-1"></a>    time.sleep(delay)</span>
<span id="cb21-45"><a href="#cb21-45" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Task </span><span class="sc">{</span>name<span class="sc">}</span><span class="ss"> completed"</span>)</span>
<span id="cb21-46"><a href="#cb21-46" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-47"><a href="#cb21-47" aria-hidden="true" tabindex="-1"></a>pool <span class="op">=</span> SimpleThreadPool(<span class="dv">3</span>)</span>
<span id="cb21-48"><a href="#cb21-48" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">5</span>):</span>
<span id="cb21-49"><a href="#cb21-49" aria-hidden="true" tabindex="-1"></a>    pool.submit(sample_task, i, <span class="dv">1</span>)</span>
<span id="cb21-50"><a href="#cb21-50" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-51"><a href="#cb21-51" aria-hidden="true" tabindex="-1"></a>pool.close()  <span class="co"># waits for every submitted task, then stops the workers</span></span></code></pre></div></div>
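<p>For comparison, the standard library's <code>concurrent.futures.ThreadPoolExecutor</code> already provides the task queue, worker threads, and orderly shutdown that <code>SimpleThreadPool</code> builds by hand. The sketch below runs the same kind of workload; <code>run_with_executor</code> is a hypothetical helper name, and the task returns a string instead of printing so the results can be collected.</p>

```python
import time
from concurrent.futures import ThreadPoolExecutor


def sample_task(name, delay):
    # Same toy workload as above, but it returns a value instead of printing
    time.sleep(delay)
    return f"Task {name} completed"


def run_with_executor(num_workers=3, num_tasks=5, delay=0.1):
    # Leaving the with-block calls executor.shutdown(wait=True), so every
    # submitted task has finished before the function returns
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = [executor.submit(sample_task, i, delay) for i in range(num_tasks)]
        # result() blocks until that task is done and re-raises its exception, if any
        return [future.result() for future in futures]


if __name__ == "__main__":
    for line in run_with_executor():
        print(line)
```

<p>Because each <code>Future</code> carries its own result or exception, this version needs no sentinel values and no <code>task_done()</code> bookkeeping.</p>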
</section>
<section id="async-style-with-threading" class="level3">
<h3 class="anchored" data-anchor-id="async-style-with-threading" id="async-style-with-threading">2. Async-style with Threading</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb22"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb22-1"><a href="#cb22-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> threading</span>
<span id="cb22-2"><a href="#cb22-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb22-3"><a href="#cb22-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> concurrent.futures <span class="im">import</span> ThreadPoolExecutor</span>
<span id="cb22-4"><a href="#cb22-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb22-5"><a href="#cb22-5" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> AsyncResult:</span>
<span id="cb22-6"><a href="#cb22-6" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, future):</span>
<span id="cb22-7"><a href="#cb22-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.future <span class="op">=</span> future</span>
<span id="cb22-8"><a href="#cb22-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb22-9"><a href="#cb22-9" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> get(<span class="va">self</span>, timeout<span class="op">=</span><span class="va">None</span>):</span>
<span id="cb22-10"><a href="#cb22-10" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.future.result(timeout<span class="op">=</span>timeout)</span>
<span id="cb22-11"><a href="#cb22-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb22-12"><a href="#cb22-12" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> is_ready(<span class="va">self</span>):</span>
<span id="cb22-13"><a href="#cb22-13" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.future.done()</span>
<span id="cb22-14"><a href="#cb22-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb22-15"><a href="#cb22-15" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> AsyncExecutor:</span>
<span id="cb22-16"><a href="#cb22-16" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, max_workers<span class="op">=</span><span class="dv">4</span>):</span>
<span id="cb22-17"><a href="#cb22-17" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.executor <span class="op">=</span> ThreadPoolExecutor(max_workers<span class="op">=</span>max_workers)</span>
<span id="cb22-18"><a href="#cb22-18" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb22-19"><a href="#cb22-19" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> submit(<span class="va">self</span>, func, <span class="op">*</span>args, <span class="op">**</span>kwargs):</span>
<span id="cb22-20"><a href="#cb22-20" aria-hidden="true" tabindex="-1"></a>        future <span class="op">=</span> <span class="va">self</span>.executor.submit(func, <span class="op">*</span>args, <span class="op">**</span>kwargs)</span>
<span id="cb22-21"><a href="#cb22-21" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> AsyncResult(future)</span>
<span id="cb22-22"><a href="#cb22-22" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb22-23"><a href="#cb22-23" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="bu">map</span>(<span class="va">self</span>, func, iterable):</span>
<span id="cb22-24"><a href="#cb22-24" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> [<span class="va">self</span>.submit(func, item) <span class="cf">for</span> item <span class="kw">in</span> iterable]</span>
<span id="cb22-25"><a href="#cb22-25" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb22-26"><a href="#cb22-26" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> shutdown(<span class="va">self</span>):</span>
<span id="cb22-27"><a href="#cb22-27" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.executor.shutdown(wait<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb22-28"><a href="#cb22-28" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb22-29"><a href="#cb22-29" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage</span></span>
<span id="cb22-30"><a href="#cb22-30" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> long_running_task(n):</span>
<span id="cb22-31"><a href="#cb22-31" aria-hidden="true" tabindex="-1"></a>    time.sleep(n)</span>
<span id="cb22-32"><a href="#cb22-32" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> n <span class="op">*</span> n</span>
<span id="cb22-33"><a href="#cb22-33" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb22-34"><a href="#cb22-34" aria-hidden="true" tabindex="-1"></a>async_executor <span class="op">=</span> AsyncExecutor(max_workers<span class="op">=</span><span class="dv">3</span>)</span>
<span id="cb22-35"><a href="#cb22-35" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb22-36"><a href="#cb22-36" aria-hidden="true" tabindex="-1"></a><span class="co"># Submit tasks</span></span>
<span id="cb22-37"><a href="#cb22-37" aria-hidden="true" tabindex="-1"></a>results <span class="op">=</span> []</span>
<span id="cb22-38"><a href="#cb22-38" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">1</span>, <span class="dv">4</span>):</span>
<span id="cb22-39"><a href="#cb22-39" aria-hidden="true" tabindex="-1"></a>    result <span class="op">=</span> async_executor.submit(long_running_task, i)</span>
<span id="cb22-40"><a href="#cb22-40" aria-hidden="true" tabindex="-1"></a>    results.append(result)</span>
<span id="cb22-41"><a href="#cb22-41" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb22-42"><a href="#cb22-42" aria-hidden="true" tabindex="-1"></a><span class="co"># Wait for results</span></span>
<span id="cb22-43"><a href="#cb22-43" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> i, result <span class="kw">in</span> <span class="bu">enumerate</span>(results):</span>
<span id="cb22-44"><a href="#cb22-44" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Task </span><span class="sc">{</span>i<span class="op">+</span><span class="dv">1</span><span class="sc">}</span><span class="ss"> result: </span><span class="sc">{</span>result<span class="sc">.</span>get()<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb22-45"><a href="#cb22-45" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb22-46"><a href="#cb22-46" aria-hidden="true" tabindex="-1"></a>async_executor.shutdown()</span></code></pre></div></div>
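<p><code>AsyncResult.get()</code> blocks the caller until the result is ready. If you would rather be notified when a task finishes, a <code>Future</code> also accepts a callback through <code>add_done_callback()</code>. A minimal sketch, assuming the results may arrive in any order; <code>run_with_callbacks</code> is a hypothetical name and the delays are shortened for illustration:</p>

```python
import threading
import time
from concurrent.futures import ThreadPoolExecutor


def long_running_task(n):
    time.sleep(n * 0.01)  # shortened delay for the sketch
    return n * n


def run_with_callbacks(values, max_workers=3):
    collected = []
    lock = threading.Lock()  # callbacks may run on several worker threads

    def on_done(future):
        # Invoked once the future is finished, so result() does not block here
        with lock:
            collected.append(future.result())

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        for v in values:
            executor.submit(long_running_task, v).add_done_callback(on_done)
    # Leaving the with-block joins the workers, so all callbacks have run
    return collected


if __name__ == "__main__":
    print(run_with_callbacks([3, 1, 2]))  # order depends on completion time
```

<p>The callback normally runs on the worker thread that completed the task (or immediately in the calling thread if the future is already done), which is why the shared list is guarded by a lock.</p>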
</section>
<section id="process-pool-with-initialization" class="level3">
<h3 class="anchored" data-anchor-id="process-pool-with-initialization" id="process-pool-with-initialization">3. Process Pool with Initialization</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb23"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb23-1"><a href="#cb23-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> multiprocessing</span>
<span id="cb23-2"><a href="#cb23-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb23-3"><a href="#cb23-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb23-4"><a href="#cb23-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Module-level global; each worker process gets its own copy, set by init_process</span></span>
<span id="cb23-5"><a href="#cb23-5" aria-hidden="true" tabindex="-1"></a>process_data <span class="op">=</span> <span class="va">None</span></span>
<span id="cb23-6"><a href="#cb23-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb23-7"><a href="#cb23-7" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> init_process(shared_data):</span>
<span id="cb23-8"><a href="#cb23-8" aria-hidden="true" tabindex="-1"></a>    <span class="kw">global</span> process_data</span>
<span id="cb23-9"><a href="#cb23-9" aria-hidden="true" tabindex="-1"></a>    process_data <span class="op">=</span> shared_data</span>
<span id="cb23-10"><a href="#cb23-10" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Process </span><span class="sc">{</span>multiprocessing<span class="sc">.</span>current_process()<span class="sc">.</span>name<span class="sc">}</span><span class="ss"> initialized"</span>)</span>
<span id="cb23-11"><a href="#cb23-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb23-12"><a href="#cb23-12" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> worker_with_init(item):</span>
<span id="cb23-13"><a href="#cb23-13" aria-hidden="true" tabindex="-1"></a>    <span class="kw">global</span> process_data</span>
<span id="cb23-14"><a href="#cb23-14" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Use the initialized data</span></span>
<span id="cb23-15"><a href="#cb23-15" aria-hidden="true" tabindex="-1"></a>    result <span class="op">=</span> item <span class="op">*</span> process_data</span>
<span id="cb23-16"><a href="#cb23-16" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> result</span>
<span id="cb23-17"><a href="#cb23-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb23-18"><a href="#cb23-18" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb23-19"><a href="#cb23-19" aria-hidden="true" tabindex="-1"></a>    shared_value <span class="op">=</span> <span class="dv">10</span></span>
<span id="cb23-20"><a href="#cb23-20" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb23-21"><a href="#cb23-21" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> multiprocessing.Pool(</span>
<span id="cb23-22"><a href="#cb23-22" aria-hidden="true" tabindex="-1"></a>        processes<span class="op">=</span><span class="dv">4</span>,</span>
<span id="cb23-23"><a href="#cb23-23" aria-hidden="true" tabindex="-1"></a>        initializer<span class="op">=</span>init_process,</span>
<span id="cb23-24"><a href="#cb23-24" aria-hidden="true" tabindex="-1"></a>        initargs<span class="op">=</span>(shared_value,)</span>
<span id="cb23-25"><a href="#cb23-25" aria-hidden="true" tabindex="-1"></a>    ) <span class="im">as</span> pool:</span>
<span id="cb23-26"><a href="#cb23-26" aria-hidden="true" tabindex="-1"></a>        items <span class="op">=</span> [<span class="dv">1</span>, <span class="dv">2</span>, <span class="dv">3</span>, <span class="dv">4</span>, <span class="dv">5</span>]</span>
<span id="cb23-27"><a href="#cb23-27" aria-hidden="true" tabindex="-1"></a>        results <span class="op">=</span> pool.<span class="bu">map</span>(worker_with_init, items)</span>
<span id="cb23-28"><a href="#cb23-28" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Results: </span><span class="sc">{</span>results<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
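<p>The same per-process setup is available through <code>concurrent.futures.ProcessPoolExecutor</code> (Python 3.7 or later), whose <code>initializer</code> and <code>initargs</code> parameters mirror those of <code>multiprocessing.Pool</code>. A sketch under that assumption, with hypothetical names <code>init_worker</code>, <code>scale</code>, and <code>run_pool</code>:</p>

```python
from concurrent.futures import ProcessPoolExecutor

# Module-level global; each worker process gets its own copy
_factor = None


def init_worker(factor):
    # Runs once in every worker process before it handles any items
    global _factor
    _factor = factor


def scale(item):
    return item * _factor


def run_pool(items, factor, max_workers=4):
    with ProcessPoolExecutor(
        max_workers=max_workers,
        initializer=init_worker,
        initargs=(factor,),
    ) as executor:
        # map() yields results in the same order as the input items
        return list(executor.map(scale, items))
```

<p>Call <code>run_pool([1, 2, 3, 4, 5], factor=10)</code> from inside an <code>if __name__ == "__main__":</code> block, as in the <code>multiprocessing.Pool</code> example; on platforms that start workers with the spawn method, each child re-imports the main module, and the guard keeps it from starting pools of its own.</p>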
</section>
</section>
<section id="real-world-examples" class="level2">
<h2 class="anchored" data-anchor-id="real-world-examples" id="real-world-examples">Real-World Examples</h2>
<section id="web-scraper" class="level3">
<h3 class="anchored" data-anchor-id="web-scraper" id="web-scraper">1. Web Scraper</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb24"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb24-1"><a href="#cb24-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> requests</span>
<span id="cb24-2"><a href="#cb24-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> threading</span>
<span id="cb24-4"><a href="#cb24-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> concurrent.futures <span class="im">import</span> ThreadPoolExecutor, as_completed</span>
<span id="cb24-7"><a href="#cb24-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb24-8"><a href="#cb24-8" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> WebScraper:</span>
<span id="cb24-9"><a href="#cb24-9" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, max_workers<span class="op">=</span><span class="dv">10</span>):</span>
<span id="cb24-10"><a href="#cb24-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.max_workers <span class="op">=</span> max_workers</span>
<span id="cb24-11"><a href="#cb24-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.session <span class="op">=</span> requests.Session()</span>
<span id="cb24-12"><a href="#cb24-12" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.results <span class="op">=</span> []</span>
<span id="cb24-13"><a href="#cb24-13" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.lock <span class="op">=</span> threading.Lock()</span>
<span id="cb24-14"><a href="#cb24-14" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb24-15"><a href="#cb24-15" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> fetch_url(<span class="va">self</span>, url):</span>
<span id="cb24-16"><a href="#cb24-16" aria-hidden="true" tabindex="-1"></a>        <span class="cf">try</span>:</span>
<span id="cb24-17"><a href="#cb24-17" aria-hidden="true" tabindex="-1"></a>            response <span class="op">=</span> <span class="va">self</span>.session.get(url, timeout<span class="op">=</span><span class="dv">10</span>)</span>
<span id="cb24-18"><a href="#cb24-18" aria-hidden="true" tabindex="-1"></a>            response.raise_for_status()</span>
<span id="cb24-19"><a href="#cb24-19" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> {</span>
<span id="cb24-20"><a href="#cb24-20" aria-hidden="true" tabindex="-1"></a>                <span class="st">'url'</span>: url,</span>
<span id="cb24-21"><a href="#cb24-21" aria-hidden="true" tabindex="-1"></a>                <span class="st">'status'</span>: response.status_code,</span>
<span id="cb24-22"><a href="#cb24-22" aria-hidden="true" tabindex="-1"></a>                <span class="st">'content_length'</span>: <span class="bu">len</span>(response.content),</span>
<span id="cb24-23"><a href="#cb24-23" aria-hidden="true" tabindex="-1"></a>                <span class="st">'title'</span>: <span class="va">self</span>._extract_title(response.text)</span>
<span id="cb24-24"><a href="#cb24-24" aria-hidden="true" tabindex="-1"></a>            }</span>
<span id="cb24-25"><a href="#cb24-25" aria-hidden="true" tabindex="-1"></a>        <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:</span>
<span id="cb24-26"><a href="#cb24-26" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> {</span>
<span id="cb24-27"><a href="#cb24-27" aria-hidden="true" tabindex="-1"></a>                <span class="st">'url'</span>: url,</span>
<span id="cb24-28"><a href="#cb24-28" aria-hidden="true" tabindex="-1"></a>                <span class="st">'error'</span>: <span class="bu">str</span>(e)</span>
<span id="cb24-29"><a href="#cb24-29" aria-hidden="true" tabindex="-1"></a>            }</span>
<span id="cb24-30"><a href="#cb24-30" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb24-31"><a href="#cb24-31" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _extract_title(<span class="va">self</span>, html):</span>
<span id="cb24-32"><a href="#cb24-32" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Naive, case-sensitive &lt;title&gt; lookup; an HTML parser is more robust</span></span>
<span id="cb24-33"><a href="#cb24-33" aria-hidden="true" tabindex="-1"></a>        start <span class="op">=</span> html.find(<span class="st">'&lt;title&gt;'</span>)</span>
<span id="cb24-34"><a href="#cb24-34" aria-hidden="true" tabindex="-1"></a>        end <span class="op">=</span> html.find(<span class="st">'&lt;/title&gt;'</span>, start)</span>
<span id="cb24-35"><a href="#cb24-35" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> start <span class="op">==</span> <span class="op">-</span><span class="dv">1</span> <span class="kw">or</span> end <span class="op">==</span> <span class="op">-</span><span class="dv">1</span>:</span>
<span id="cb24-36"><a href="#cb24-36" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> <span class="st">"No title"</span></span>
<span id="cb24-37"><a href="#cb24-37" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Skip past the opening tag itself before slicing</span></span>
<span id="cb24-38"><a href="#cb24-38" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> html[start <span class="op">+</span> <span class="bu">len</span>(<span class="st">'&lt;title&gt;'</span>):end].strip()</span>
<span id="cb24-39"><a href="#cb24-39" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb24-40"><a href="#cb24-40" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> scrape_urls(<span class="va">self</span>, urls):</span>
<span id="cb24-41"><a href="#cb24-41" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> ThreadPoolExecutor(max_workers<span class="op">=</span><span class="va">self</span>.max_workers) <span class="im">as</span> executor:</span>
<span id="cb24-42"><a href="#cb24-42" aria-hidden="true" tabindex="-1"></a>            future_to_url <span class="op">=</span> {executor.submit(<span class="va">self</span>.fetch_url, url): url <span class="cf">for</span> url <span class="kw">in</span> urls}</span>
<span id="cb24-43"><a href="#cb24-43" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb24-44"><a href="#cb24-44" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> future <span class="kw">in</span> as_completed(future_to_url):</span>
<span id="cb24-45"><a href="#cb24-45" aria-hidden="true" tabindex="-1"></a>                result <span class="op">=</span> future.result()</span>
<span id="cb24-46"><a href="#cb24-46" aria-hidden="true" tabindex="-1"></a>                <span class="cf">with</span> <span class="va">self</span>.lock:</span>
<span id="cb24-47"><a href="#cb24-47" aria-hidden="true" tabindex="-1"></a>                    <span class="va">self</span>.results.append(result)</span>
<span id="cb24-48"><a href="#cb24-48" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb24-49"><a href="#cb24-49" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.results</span>
<span id="cb24-50"><a href="#cb24-50" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb24-51"><a href="#cb24-51" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage</span></span>
<span id="cb24-52"><a href="#cb24-52" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb24-53"><a href="#cb24-53" aria-hidden="true" tabindex="-1"></a>    urls <span class="op">=</span> [</span>
<span id="cb24-54"><a href="#cb24-54" aria-hidden="true" tabindex="-1"></a>        <span class="st">'https://httpbin.org/delay/1'</span>,</span>
<span id="cb24-55"><a href="#cb24-55" aria-hidden="true" tabindex="-1"></a>        <span class="st">'https://httpbin.org/delay/2'</span>,</span>
<span id="cb24-56"><a href="#cb24-56" aria-hidden="true" tabindex="-1"></a>        <span class="st">'https://httpbin.org/status/200'</span>,</span>
<span id="cb24-57"><a href="#cb24-57" aria-hidden="true" tabindex="-1"></a>        <span class="st">'https://httpbin.org/status/404'</span></span>
<span id="cb24-58"><a href="#cb24-58" aria-hidden="true" tabindex="-1"></a>    ]</span>
<span id="cb24-59"><a href="#cb24-59" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb24-60"><a href="#cb24-60" aria-hidden="true" tabindex="-1"></a>    scraper <span class="op">=</span> WebScraper(max_workers<span class="op">=</span><span class="dv">4</span>)</span>
<span id="cb24-61"><a href="#cb24-61" aria-hidden="true" tabindex="-1"></a>    results <span class="op">=</span> scraper.scrape_urls(urls)</span>
<span id="cb24-62"><a href="#cb24-62" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb24-63"><a href="#cb24-63" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> result <span class="kw">in</span> results:</span>
<span id="cb24-64"><a href="#cb24-64" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(result)</span></code></pre></div></div>
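<p>The string search in <code>_extract_title</code> misses titles written as <code>&lt;TITLE&gt;</code> or carrying attributes such as <code>lang</code>. A sturdier option, sketched here with the standard library's <code>html.parser.HTMLParser</code>, lets a real parser locate the element. <code>TitleParser</code> and <code>extract_title</code> are hypothetical names, and the sketch keeps only the first title it finds.</p>

```python
from html.parser import HTMLParser


class TitleParser(HTMLParser):
    # Collects the text of the first title element. HTMLParser reports tag
    # names in lowercase, so TITLE and title are treated the same.
    def __init__(self):
        super().__init__()
        self._in_title = False
        self._done = False
        self.parts = []

    def handle_starttag(self, tag, attrs):
        if tag == "title" and not self._done:
            self._in_title = True

    def handle_endtag(self, tag):
        if tag == "title" and self._in_title:
            self._in_title = False
            self._done = True

    def handle_data(self, data):
        if self._in_title:
            self.parts.append(data)


def extract_title(html):
    parser = TitleParser()
    parser.feed(html)
    parser.close()  # flush any text still buffered by the parser
    title = "".join(parser.parts).strip()
    return title or "No title"
```

<p>Inside <code>fetch_url</code>, the call <code>self._extract_title(response.text)</code> could be swapped for <code>extract_title(response.text)</code> without other changes to the scraper.</p>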
</section>
<section id="file-processing-pipeline" class="level3">
<h3 class="anchored" data-anchor-id="file-processing-pipeline" id="file-processing-pipeline">2. File Processing Pipeline</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb25"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb25-1"><a href="#cb25-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> os</span>
<span id="cb25-2"><a href="#cb25-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> threading</span>
<span id="cb25-3"><a href="#cb25-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> multiprocessing</span>
<span id="cb25-4"><a href="#cb25-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> concurrent.futures <span class="im">import</span> ProcessPoolExecutor, ThreadPoolExecutor</span>
<span id="cb25-5"><a href="#cb25-5" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> json</span>
<span id="cb25-6"><a href="#cb25-6" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb25-7"><a href="#cb25-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb25-8"><a href="#cb25-8" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> FileProcessor:</span>
<span id="cb25-9"><a href="#cb25-9" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, input_dir, output_dir, max_workers<span class="op">=</span><span class="dv">4</span>):</span>
<span id="cb25-10"><a href="#cb25-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.input_dir <span class="op">=</span> input_dir</span>
<span id="cb25-11"><a href="#cb25-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.output_dir <span class="op">=</span> output_dir</span>
<span id="cb25-12"><a href="#cb25-12" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.max_workers <span class="op">=</span> max_workers</span>
<span id="cb25-13"><a href="#cb25-13" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.processed_files <span class="op">=</span> []</span>
<span id="cb25-14"><a href="#cb25-14" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.lock <span class="op">=</span> threading.Lock()</span>
<span id="cb25-15"><a href="#cb25-15" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb25-16"><a href="#cb25-16" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> process_file(<span class="va">self</span>, filepath):</span>
<span id="cb25-17"><a href="#cb25-17" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Process a single file"""</span></span>
<span id="cb25-18"><a href="#cb25-18" aria-hidden="true" tabindex="-1"></a>        <span class="cf">try</span>:</span>
<span id="cb25-19"><a href="#cb25-19" aria-hidden="true" tabindex="-1"></a>            <span class="cf">with</span> <span class="bu">open</span>(filepath, <span class="st">'r'</span>) <span class="im">as</span> f:</span>
<span id="cb25-20"><a href="#cb25-20" aria-hidden="true" tabindex="-1"></a>                data <span class="op">=</span> json.load(f)</span>
<span id="cb25-21"><a href="#cb25-21" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb25-22"><a href="#cb25-22" aria-hidden="true" tabindex="-1"></a>            started <span class="op">=</span> time.perf_counter()</span>
<span id="cb25-23"><a href="#cb25-23" aria-hidden="true" tabindex="-1"></a>            time.sleep(<span class="fl">0.1</span>)  <span class="co"># Simulate processing time</span></span>
<span id="cb25-24"><a href="#cb25-24" aria-hidden="true" tabindex="-1"></a>            processed_data <span class="op">=</span> {</span>
<span id="cb25-25"><a href="#cb25-25" aria-hidden="true" tabindex="-1"></a>                <span class="st">'original_file'</span>: filepath,</span>
<span id="cb25-26"><a href="#cb25-26" aria-hidden="true" tabindex="-1"></a>                <span class="st">'processed_at'</span>: time.time(),</span>
<span id="cb25-27"><a href="#cb25-27" aria-hidden="true" tabindex="-1"></a>                <span class="st">'record_count'</span>: <span class="bu">len</span>(data) <span class="cf">if</span> <span class="bu">isinstance</span>(data, <span class="bu">list</span>) <span class="cf">else</span> <span class="dv">1</span>,</span>
<span id="cb25-28"><a href="#cb25-28" aria-hidden="true" tabindex="-1"></a>                <span class="st">'processing_time'</span>: time.perf_counter() <span class="op">-</span> started  <span class="co"># measured, not hard-coded</span></span>
<span id="cb25-29"><a href="#cb25-29" aria-hidden="true" tabindex="-1"></a>            }</span>
<span id="cb25-30"><a href="#cb25-30" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb25-31"><a href="#cb25-31" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb25-32"><a href="#cb25-32" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Write processed file</span></span>
<span id="cb25-33"><a href="#cb25-33" aria-hidden="true" tabindex="-1"></a>            output_filename <span class="op">=</span> <span class="ss">f"processed_</span><span class="sc">{</span>os<span class="sc">.</span>path<span class="sc">.</span>basename(filepath)<span class="sc">}</span><span class="ss">"</span></span>
<span id="cb25-34"><a href="#cb25-34" aria-hidden="true" tabindex="-1"></a>            output_path <span class="op">=</span> os.path.join(<span class="va">self</span>.output_dir, output_filename)</span>
<span id="cb25-35"><a href="#cb25-35" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb25-36"><a href="#cb25-36" aria-hidden="true" tabindex="-1"></a>            <span class="cf">with</span> <span class="bu">open</span>(output_path, <span class="st">'w'</span>) <span class="im">as</span> f:</span>
<span id="cb25-37"><a href="#cb25-37" aria-hidden="true" tabindex="-1"></a>                json.dump(processed_data, f, indent<span class="op">=</span><span class="dv">2</span>)</span>
<span id="cb25-38"><a href="#cb25-38" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb25-39"><a href="#cb25-39" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> {</span>
<span id="cb25-40"><a href="#cb25-40" aria-hidden="true" tabindex="-1"></a>                <span class="st">'input'</span>: filepath,</span>
<span id="cb25-41"><a href="#cb25-41" aria-hidden="true" tabindex="-1"></a>                <span class="st">'output'</span>: output_path,</span>
<span id="cb25-42"><a href="#cb25-42" aria-hidden="true" tabindex="-1"></a>                <span class="st">'status'</span>: <span class="st">'success'</span></span>
<span id="cb25-43"><a href="#cb25-43" aria-hidden="true" tabindex="-1"></a>            }</span>
<span id="cb25-44"><a href="#cb25-44" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb25-45"><a href="#cb25-45" aria-hidden="true" tabindex="-1"></a>        <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:</span>
<span id="cb25-46"><a href="#cb25-46" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> {</span>
<span id="cb25-47"><a href="#cb25-47" aria-hidden="true" tabindex="-1"></a>                <span class="st">'input'</span>: filepath,</span>
<span id="cb25-48"><a href="#cb25-48" aria-hidden="true" tabindex="-1"></a>                <span class="st">'error'</span>: <span class="bu">str</span>(e),</span>
<span id="cb25-49"><a href="#cb25-49" aria-hidden="true" tabindex="-1"></a>                <span class="st">'status'</span>: <span class="st">'failed'</span></span>
<span id="cb25-50"><a href="#cb25-50" aria-hidden="true" tabindex="-1"></a>            }</span>
<span id="cb25-51"><a href="#cb25-51" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb25-52"><a href="#cb25-52" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> process_directory(<span class="va">self</span>):</span>
<span id="cb25-53"><a href="#cb25-53" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Process all JSON files in the input directory"""</span></span>
<span id="cb25-54"><a href="#cb25-54" aria-hidden="true" tabindex="-1"></a>        json_files <span class="op">=</span> []</span>
<span id="cb25-55"><a href="#cb25-55" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> root, dirs, files <span class="kw">in</span> os.walk(<span class="va">self</span>.input_dir):</span>
<span id="cb25-56"><a href="#cb25-56" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> <span class="bu">file</span> <span class="kw">in</span> files:</span>
<span id="cb25-57"><a href="#cb25-57" aria-hidden="true" tabindex="-1"></a>                <span class="cf">if</span> <span class="bu">file</span>.endswith(<span class="st">'.json'</span>):</span>
<span id="cb25-58"><a href="#cb25-58" aria-hidden="true" tabindex="-1"></a>                    json_files.append(os.path.join(root, <span class="bu">file</span>))</span>
<span id="cb25-59"><a href="#cb25-59" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb25-60"><a href="#cb25-60" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Found </span><span class="sc">{</span><span class="bu">len</span>(json_files)<span class="sc">}</span><span class="ss"> JSON files to process"</span>)</span>
<span id="cb25-61"><a href="#cb25-61" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb25-62"><a href="#cb25-62" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Process files in parallel</span></span>
<span id="cb25-63"><a href="#cb25-63" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> ProcessPoolExecutor(max_workers<span class="op">=</span><span class="va">self</span>.max_workers) <span class="im">as</span> executor:</span>
<span id="cb25-64"><a href="#cb25-64" aria-hidden="true" tabindex="-1"></a>            results <span class="op">=</span> <span class="bu">list</span>(executor.<span class="bu">map</span>(<span class="va">self</span>.process_file, json_files))</span>
<span id="cb25-65"><a href="#cb25-65" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb25-66"><a href="#cb25-66" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> results</span>
<span id="cb25-67"><a href="#cb25-67" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb25-68"><a href="#cb25-68" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage example</span></span>
<span id="cb25-69"><a href="#cb25-69" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb25-70"><a href="#cb25-70" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Create sample data</span></span>
<span id="cb25-71"><a href="#cb25-71" aria-hidden="true" tabindex="-1"></a>    os.makedirs(<span class="st">'input_data'</span>, exist_ok<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb25-72"><a href="#cb25-72" aria-hidden="true" tabindex="-1"></a>    os.makedirs(<span class="st">'output_data'</span>, exist_ok<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb25-73"><a href="#cb25-73" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb25-74"><a href="#cb25-74" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Create sample JSON files</span></span>
<span id="cb25-75"><a href="#cb25-75" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">5</span>):</span>
<span id="cb25-76"><a href="#cb25-76" aria-hidden="true" tabindex="-1"></a>        sample_data <span class="op">=</span> [{<span class="st">'id'</span>: j, <span class="st">'value'</span>: j <span class="op">*</span> <span class="dv">10</span>} <span class="cf">for</span> j <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">100</span>)]</span>
<span id="cb25-77"><a href="#cb25-77" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> <span class="bu">open</span>(<span class="ss">f'input_data/sample_</span><span class="sc">{</span>i<span class="sc">}</span><span class="ss">.json'</span>, <span class="st">'w'</span>) <span class="im">as</span> f:</span>
<span id="cb25-78"><a href="#cb25-78" aria-hidden="true" tabindex="-1"></a>            json.dump(sample_data, f)</span>
<span id="cb25-79"><a href="#cb25-79" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb25-80"><a href="#cb25-80" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Process files</span></span>
<span id="cb25-81"><a href="#cb25-81" aria-hidden="true" tabindex="-1"></a>    processor <span class="op">=</span> FileProcessor(<span class="st">'input_data'</span>, <span class="st">'output_data'</span>, max_workers<span class="op">=</span><span class="dv">4</span>)</span>
<span id="cb25-82"><a href="#cb25-82" aria-hidden="true" tabindex="-1"></a>    results <span class="op">=</span> processor.process_directory()</span>
<span id="cb25-83"><a href="#cb25-83" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb25-84"><a href="#cb25-84" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Print results</span></span>
<span id="cb25-85"><a href="#cb25-85" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> result <span class="kw">in</span> results:</span>
<span id="cb25-86"><a href="#cb25-86" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(result)</span></code></pre></div></div>
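<p>The example above collects results with <code>executor.map</code>, which yields them in input order. If you would rather report each file as soon as it finishes, <code>concurrent.futures.as_completed</code> is one option. The sketch below is a simplified, hypothetical variant. <code>process_one</code> and <code>process_all</code> are illustrative names that are not part of the original code. It uses <code>ThreadPoolExecutor</code> on the assumption that the work is mostly file I/O. The same <code>submit</code>/<code>as_completed</code> pattern also works with <code>ProcessPoolExecutor</code>.</p>

```python
import json
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor, as_completed


def process_one(filepath, output_dir):
    """Hypothetical stand-in for FileProcessor.process_file: count records, write a summary."""
    try:
        with open(filepath) as f:
            data = json.load(f)
        output_path = os.path.join(output_dir, f"processed_{os.path.basename(filepath)}")
        with open(output_path, "w") as f:
            json.dump({"record_count": len(data) if isinstance(data, list) else 1}, f)
        return {"input": filepath, "output": output_path, "status": "success"}
    except Exception as e:  # Report the failure instead of letting it propagate
        return {"input": filepath, "error": str(e), "status": "failed"}


def process_all(filepaths, output_dir, max_workers=4):
    """Submit every file, then collect results in completion order."""
    results = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Map each future back to its input path for logging
        futures = {executor.submit(process_one, p, output_dir): p for p in filepaths}
        for future in as_completed(futures):
            result = future.result()
            print(f"{result['status']}: {futures[future]}")
            results.append(result)
    return results


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        paths = []
        for i in range(3):
            path = os.path.join(tmp, f"sample_{i}.json")
            with open(path, "w") as f:
                json.dump([{"id": j} for j in range(10)], f)
            paths.append(path)
        # One deliberately malformed file to exercise the error path
        bad = os.path.join(tmp, "broken.json")
        with open(bad, "w") as f:
            f.write("{not valid json")
        paths.append(bad)
        results = process_all(paths, tmp)
        print(sum(r["status"] == "success" for r in results), "succeeded")
```

<p>Because <code>as_completed</code> returns futures in completion order rather than submission order, the dictionary from future to path is what lets each log line name the file it belongs to.</p>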
</section>
<section id="real-time-data-processing" class="level3">
<h3 class="anchored" data-anchor-id="real-time-data-processing" id="real-time-data-processing">3. Real-time Data Processing</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb26"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb26-1"><a href="#cb26-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> threading</span>
<span id="cb26-2"><a href="#cb26-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> queue</span>
<span id="cb26-3"><a href="#cb26-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb26-4"><a href="#cb26-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> random</span>
<span id="cb26-5"><a href="#cb26-5" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> json</span>
<span id="cb26-6"><a href="#cb26-6" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> datetime <span class="im">import</span> datetime</span>
<span id="cb26-7"><a href="#cb26-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb26-8"><a href="#cb26-8" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> DataProcessor:</span>
<span id="cb26-9"><a href="#cb26-9" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, num_workers<span class="op">=</span><span class="dv">3</span>):</span>
<span id="cb26-10"><a href="#cb26-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.input_queue <span class="op">=</span> queue.Queue()</span>
<span id="cb26-11"><a href="#cb26-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.output_queue <span class="op">=</span> queue.Queue()</span>
<span id="cb26-12"><a href="#cb26-12" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.num_workers <span class="op">=</span> num_workers</span>
<span id="cb26-13"><a href="#cb26-13" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.workers <span class="op">=</span> []</span>
<span id="cb26-14"><a href="#cb26-14" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.running <span class="op">=</span> <span class="va">False</span></span>
<span id="cb26-15"><a href="#cb26-15" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.processed_count <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb26-16"><a href="#cb26-16" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.lock <span class="op">=</span> threading.Lock()</span>
<span id="cb26-17"><a href="#cb26-17" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb26-18"><a href="#cb26-18" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> worker(<span class="va">self</span>, worker_id):</span>
<span id="cb26-19"><a href="#cb26-19" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Process data items from the queue"""</span></span>
<span id="cb26-20"><a href="#cb26-20" aria-hidden="true" tabindex="-1"></a>        <span class="cf">while</span> <span class="va">self</span>.running:</span>
<span id="cb26-21"><a href="#cb26-21" aria-hidden="true" tabindex="-1"></a>            <span class="cf">try</span>:</span>
<span id="cb26-22"><a href="#cb26-22" aria-hidden="true" tabindex="-1"></a>                data <span class="op">=</span> <span class="va">self</span>.input_queue.get(timeout<span class="op">=</span><span class="dv">1</span>)</span>
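<span id="cb26-23"><a href="#cb26-23" aria-hidden="true" tabindex="-1"></a>                <span class="cf">if</span> data <span class="kw">is</span> <span class="va">None</span>:  <span class="co"># Sentinel from stop()</span></span>
<span id="cb26-23b"><a href="#cb26-23b" aria-hidden="true" tabindex="-1"></a>                    <span class="va">self</span>.input_queue.task_done()  <span class="co"># Mark the sentinel done so input_queue.join() cannot hang</span></span>
<span id="cb26-24"><a href="#cb26-24" aria-hidden="true" tabindex="-1"></a>                    <span class="cf">break</span></span>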
<span id="cb26-25"><a href="#cb26-25" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb26-26"><a href="#cb26-26" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Simulate processing</span></span>
<span id="cb26-27"><a href="#cb26-27" aria-hidden="true" tabindex="-1"></a>                processed_data <span class="op">=</span> <span class="va">self</span>.process_data(data, worker_id)</span>
<span id="cb26-28"><a href="#cb26-28" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.output_queue.put(processed_data)</span>
<span id="cb26-29"><a href="#cb26-29" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb26-30"><a href="#cb26-30" aria-hidden="true" tabindex="-1"></a>                <span class="cf">with</span> <span class="va">self</span>.lock:</span>
<span id="cb26-31"><a href="#cb26-31" aria-hidden="true" tabindex="-1"></a>                    <span class="va">self</span>.processed_count <span class="op">+=</span> <span class="dv">1</span></span>
<span id="cb26-32"><a href="#cb26-32" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb26-33"><a href="#cb26-33" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.input_queue.task_done()</span>
<span id="cb26-34"><a href="#cb26-34" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb26-35"><a href="#cb26-35" aria-hidden="true" tabindex="-1"></a>            <span class="cf">except</span> queue.Empty:</span>
<span id="cb26-36"><a href="#cb26-36" aria-hidden="true" tabindex="-1"></a>                <span class="cf">continue</span></span>
<span id="cb26-37"><a href="#cb26-37" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb26-38"><a href="#cb26-38" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> process_data(<span class="va">self</span>, data, worker_id):</span>
<span id="cb26-39"><a href="#cb26-39" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Process individual data item"""</span></span>
<span id="cb26-40"><a href="#cb26-40" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Simulate processing time</span></span>
<span id="cb26-41"><a href="#cb26-41" aria-hidden="true" tabindex="-1"></a>        time.sleep(random.uniform(<span class="fl">0.1</span>, <span class="fl">0.5</span>))</span>
<span id="cb26-42"><a href="#cb26-42" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb26-43"><a href="#cb26-43" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {</span>
<span id="cb26-44"><a href="#cb26-44" aria-hidden="true" tabindex="-1"></a>            <span class="st">'worker_id'</span>: worker_id,</span>
<span id="cb26-45"><a href="#cb26-45" aria-hidden="true" tabindex="-1"></a>            <span class="st">'original_data'</span>: data,</span>
<span id="cb26-46"><a href="#cb26-46" aria-hidden="true" tabindex="-1"></a>            <span class="st">'processed_at'</span>: datetime.now().isoformat(),</span>
<span id="cb26-47"><a href="#cb26-47" aria-hidden="true" tabindex="-1"></a>            <span class="st">'result'</span>: data[<span class="st">'value'</span>] <span class="op">*</span> <span class="dv">2</span> <span class="cf">if</span> <span class="st">'value'</span> <span class="kw">in</span> data <span class="cf">else</span> <span class="st">'processed'</span></span>
<span id="cb26-48"><a href="#cb26-48" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb26-49"><a href="#cb26-49" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb26-50"><a href="#cb26-50" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> start(<span class="va">self</span>):</span>
<span id="cb26-51"><a href="#cb26-51" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Start the worker threads"""</span></span>
<span id="cb26-52"><a href="#cb26-52" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.running <span class="op">=</span> <span class="va">True</span></span>
<span id="cb26-53"><a href="#cb26-53" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="va">self</span>.num_workers):</span>
<span id="cb26-54"><a href="#cb26-54" aria-hidden="true" tabindex="-1"></a>            worker <span class="op">=</span> threading.Thread(target<span class="op">=</span><span class="va">self</span>.worker, args<span class="op">=</span>(i,))</span>
<span id="cb26-55"><a href="#cb26-55" aria-hidden="true" tabindex="-1"></a>            worker.start()</span>
<span id="cb26-56"><a href="#cb26-56" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.workers.append(worker)</span>
<span id="cb26-57"><a href="#cb26-57" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb26-58"><a href="#cb26-58" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> stop(<span class="va">self</span>):</span>
<span id="cb26-59"><a href="#cb26-59" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Stop all worker threads"""</span></span>
<span id="cb26-60"><a href="#cb26-60" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.running <span class="op">=</span> <span class="va">False</span></span>
<span id="cb26-61"><a href="#cb26-61" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb26-62"><a href="#cb26-62" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Add sentinel values to wake up workers</span></span>
<span id="cb26-63"><a href="#cb26-63" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(<span class="va">self</span>.num_workers):</span>
<span id="cb26-64"><a href="#cb26-64" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.input_queue.put(<span class="va">None</span>)</span>
<span id="cb26-65"><a href="#cb26-65" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb26-66"><a href="#cb26-66" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Wait for workers to finish</span></span>
<span id="cb26-67"><a href="#cb26-67" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> worker <span class="kw">in</span> <span class="va">self</span>.workers:</span>
<span id="cb26-68"><a href="#cb26-68" aria-hidden="true" tabindex="-1"></a>            worker.join()</span>
<span id="cb26-69"><a href="#cb26-69" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb26-70"><a href="#cb26-70" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> add_data(<span class="va">self</span>, data):</span>
<span id="cb26-71"><a href="#cb26-71" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Add data to the processing queue"""</span></span>
<span id="cb26-72"><a href="#cb26-72" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.input_queue.put(data)</span>
<span id="cb26-73"><a href="#cb26-73" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb26-74"><a href="#cb26-74" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> get_result(<span class="va">self</span>, timeout<span class="op">=</span><span class="va">None</span>):</span>
<span id="cb26-75"><a href="#cb26-75" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Get processed result"""</span></span>
<span id="cb26-76"><a href="#cb26-76" aria-hidden="true" tabindex="-1"></a>        <span class="cf">try</span>:</span>
<span id="cb26-77"><a href="#cb26-77" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> <span class="va">self</span>.output_queue.get(timeout<span class="op">=</span>timeout)</span>
<span id="cb26-78"><a href="#cb26-78" aria-hidden="true" tabindex="-1"></a>        <span class="cf">except</span> queue.Empty:</span>
<span id="cb26-79"><a href="#cb26-79" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> <span class="va">None</span></span>
<span id="cb26-80"><a href="#cb26-80" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb26-81"><a href="#cb26-81" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> get_stats(<span class="va">self</span>):</span>
<span id="cb26-82"><a href="#cb26-82" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Get processing statistics"""</span></span>
<span id="cb26-83"><a href="#cb26-83" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {</span>
<span id="cb26-84"><a href="#cb26-84" aria-hidden="true" tabindex="-1"></a>            <span class="st">'input_queue_size'</span>: <span class="va">self</span>.input_queue.qsize(),</span>
<span id="cb26-85"><a href="#cb26-85" aria-hidden="true" tabindex="-1"></a>            <span class="st">'output_queue_size'</span>: <span class="va">self</span>.output_queue.qsize(),</span>
<span id="cb26-86"><a href="#cb26-86" aria-hidden="true" tabindex="-1"></a>            <span class="st">'processed_count'</span>: <span class="va">self</span>.processed_count,</span>
<span id="cb26-87"><a href="#cb26-87" aria-hidden="true" tabindex="-1"></a>            <span class="st">'active_workers'</span>: <span class="bu">len</span>([w <span class="cf">for</span> w <span class="kw">in</span> <span class="va">self</span>.workers <span class="cf">if</span> w.is_alive()])</span>
<span id="cb26-88"><a href="#cb26-88" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb26-89"><a href="#cb26-89" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb26-90"><a href="#cb26-90" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage example</span></span>
<span id="cb26-91"><a href="#cb26-91" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb26-92"><a href="#cb26-92" aria-hidden="true" tabindex="-1"></a>    processor <span class="op">=</span> DataProcessor(num_workers<span class="op">=</span><span class="dv">3</span>)</span>
<span id="cb26-93"><a href="#cb26-93" aria-hidden="true" tabindex="-1"></a>    processor.start()</span>
<span id="cb26-94"><a href="#cb26-94" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb26-95"><a href="#cb26-95" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Simulate data streaming</span></span>
<span id="cb26-96"><a href="#cb26-96" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> data_generator():</span>
<span id="cb26-97"><a href="#cb26-97" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">20</span>):</span>
<span id="cb26-98"><a href="#cb26-98" aria-hidden="true" tabindex="-1"></a>            <span class="cf">yield</span> {<span class="st">'id'</span>: i, <span class="st">'value'</span>: random.randint(<span class="dv">1</span>, <span class="dv">100</span>)}</span>
<span id="cb26-99"><a href="#cb26-99" aria-hidden="true" tabindex="-1"></a>            time.sleep(<span class="fl">0.1</span>)</span>
<span id="cb26-100"><a href="#cb26-100" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb26-101"><a href="#cb26-101" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Add data to processor</span></span>
<span id="cb26-102"><a href="#cb26-102" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> data <span class="kw">in</span> data_generator():</span>
<span id="cb26-103"><a href="#cb26-103" aria-hidden="true" tabindex="-1"></a>        processor.add_data(data)</span>
<span id="cb26-104"><a href="#cb26-104" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Added data: </span><span class="sc">{</span>data<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb26-105"><a href="#cb26-105" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb26-106"><a href="#cb26-106" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Collect results</span></span>
<span id="cb26-107"><a href="#cb26-107" aria-hidden="true" tabindex="-1"></a>    results <span class="op">=</span> []</span>
<span id="cb26-108"><a href="#cb26-108" aria-hidden="true" tabindex="-1"></a>    start_time <span class="op">=</span> time.time()</span>
<span id="cb26-109"><a href="#cb26-109" aria-hidden="true" tabindex="-1"></a>    <span class="cf">while</span> <span class="bu">len</span>(results) <span class="op">&lt;</span> <span class="dv">20</span> <span class="kw">and</span> time.time() <span class="op">-</span> start_time <span class="op">&lt;</span> <span class="dv">30</span>:</span>
<span id="cb26-110"><a href="#cb26-110" aria-hidden="true" tabindex="-1"></a>        result <span class="op">=</span> processor.get_result(timeout<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb26-111"><a href="#cb26-111" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> result:</span>
<span id="cb26-112"><a href="#cb26-112" aria-hidden="true" tabindex="-1"></a>            results.append(result)</span>
<span id="cb26-113"><a href="#cb26-113" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f"Got result: </span><span class="sc">{</span>result<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb26-114"><a href="#cb26-114" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb26-115"><a href="#cb26-115" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Print statistics</span></span>
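<span id="cb26-116"><a href="#cb26-116" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Final stats: </span><span class="sc">{</span>processor<span class="sc">.</span>get_stats()<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb26-116b"><a href="#cb26-116b" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb26-116c"><a href="#cb26-116c" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Shut down the workers; without this the non-daemon threads keep polling and the script never exits</span></span>
<span id="cb26-116d"><a href="#cb26-116d" aria-hidden="true" tabindex="-1"></a>    processor.stop()</span>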
</code></pre></div></div>
</section>
<h2 class="anchored" data-anchor-id="troubleshooting-common-issues" id="troubleshooting-common-issues">Troubleshooting Common Issues</h2>
<section id="race-conditions" class="level3">
<h3 class="anchored" data-anchor-id="race-conditions" id="race-conditions">1. Race Conditions</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb26b"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb26-124"><a href="#cb26-124" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> threading</span>
<span id="cb26-126"><a href="#cb26-126" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb26-127"><a href="#cb26-127" aria-hidden="true" tabindex="-1"></a><span class="co"># Problem: Race condition</span></span>
<span id="cb26-128"><a href="#cb26-128" aria-hidden="true" tabindex="-1"></a>shared_counter <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb26-129"><a href="#cb26-129" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb26-130"><a href="#cb26-130" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> unsafe_increment():</span>
<span id="cb26-131"><a href="#cb26-131" aria-hidden="true" tabindex="-1"></a>    <span class="kw">global</span> shared_counter</span>
<span id="cb26-132"><a href="#cb26-132" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">100000</span>):</span>
<span id="cb26-133"><a href="#cb26-133" aria-hidden="true" tabindex="-1"></a>        shared_counter <span class="op">+=</span> <span class="dv">1</span>  <span class="co"># Not atomic: the read, add, and write can interleave across threads</span></span>
<span id="cb26-134"><a href="#cb26-134" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb26-135"><a href="#cb26-135" aria-hidden="true" tabindex="-1"></a><span class="co"># Solution: Use locks</span></span>
<span id="cb26-136"><a href="#cb26-136" aria-hidden="true" tabindex="-1"></a>safe_counter <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb26-137"><a href="#cb26-137" aria-hidden="true" tabindex="-1"></a>counter_lock <span class="op">=</span> threading.Lock()</span>
<span id="cb26-138"><a href="#cb26-138" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb26-139"><a href="#cb26-139" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> safe_increment():</span>
<span id="cb26-140"><a href="#cb26-140" aria-hidden="true" tabindex="-1"></a>    <span class="kw">global</span> safe_counter</span>
<span id="cb26-141"><a href="#cb26-141" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">100000</span>):</span>
<span id="cb26-142"><a href="#cb26-142" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> counter_lock:</span>
<span id="cb26-143"><a href="#cb26-143" aria-hidden="true" tabindex="-1"></a>            safe_counter <span class="op">+=</span> <span class="dv">1</span></span>
<span id="cb26-144"><a href="#cb26-144" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb26-145"><a href="#cb26-145" aria-hidden="true" tabindex="-1"></a><span class="co"># Alternative: encapsulate the lock in a thread-safe counter class</span></span>
<span id="cb26-146"><a href="#cb26-146" aria-hidden="true" tabindex="-1"></a><span class="co"># (this is still lock-based, not a lock-free atomic operation)</span></span>
<span id="cb26-147"><a href="#cb26-147" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> threading <span class="im">import</span> Lock</span>
<span id="cb26-148"><a href="#cb26-148" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb26-149"><a href="#cb26-149" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> AtomicCounter:</span>
<span id="cb26-150"><a href="#cb26-150" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>):</span>
<span id="cb26-151"><a href="#cb26-151" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>._value <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb26-152"><a href="#cb26-152" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>._lock <span class="op">=</span> Lock()</span>
<span id="cb26-153"><a href="#cb26-153" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb26-154"><a href="#cb26-154" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> increment(<span class="va">self</span>):</span>
<span id="cb26-155"><a href="#cb26-155" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> <span class="va">self</span>._lock:</span>
<span id="cb26-156"><a href="#cb26-156" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>._value <span class="op">+=</span> <span class="dv">1</span></span>
<span id="cb26-157"><a href="#cb26-157" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb26-158"><a href="#cb26-158" aria-hidden="true" tabindex="-1"></a>    <span class="at">@property</span></span>
<span id="cb26-159"><a href="#cb26-159" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> value(<span class="va">self</span>):</span>
<span id="cb26-160"><a href="#cb26-160" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> <span class="va">self</span>._lock:</span>
<span id="cb26-161"><a href="#cb26-161" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> <span class="va">self</span>._value</span>
<span id="cb26-162"><a href="#cb26-162" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb26-163"><a href="#cb26-163" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage</span></span>
<span id="cb26-164"><a href="#cb26-164" aria-hidden="true" tabindex="-1"></a>atomic_counter <span class="op">=</span> AtomicCounter()</span>
<span id="cb26-165"><a href="#cb26-165" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb26-166"><a href="#cb26-166" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> worker():</span>
<span id="cb26-167"><a href="#cb26-167" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">100000</span>):</span>
<span id="cb26-168"><a href="#cb26-168" aria-hidden="true" tabindex="-1"></a>        atomic_counter.increment()</span>
<span id="cb26-169"><a href="#cb26-169" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb26-170"><a href="#cb26-170" aria-hidden="true" tabindex="-1"></a>threads <span class="op">=</span> [threading.Thread(target<span class="op">=</span>worker) <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">5</span>)]</span>
<span id="cb26-171"><a href="#cb26-171" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> t <span class="kw">in</span> threads:</span>
<span id="cb26-172"><a href="#cb26-172" aria-hidden="true" tabindex="-1"></a>    t.start()</span>
<span id="cb26-173"><a href="#cb26-173" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> t <span class="kw">in</span> threads:</span>
<span id="cb26-174"><a href="#cb26-174" aria-hidden="true" tabindex="-1"></a>    t.join()</span>
<span id="cb26-175"><a href="#cb26-175" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb26-176"><a href="#cb26-176" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Atomic counter final value: </span><span class="sc">{</span>atomic_counter<span class="sc">.</span>value<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
</section>
<section id="deadlocks" class="level3">
<h3 class="anchored" data-anchor-id="deadlocks" id="deadlocks">2. Deadlocks</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb27"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb27-1"><a href="#cb27-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> threading</span>
<span id="cb27-2"><a href="#cb27-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb27-3"><a href="#cb27-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb27-4"><a href="#cb27-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Problem: opposite lock order. Run concurrently, each task can hold one lock and wait forever for the other</span></span>
<span id="cb27-5"><a href="#cb27-5" aria-hidden="true" tabindex="-1"></a>lock1 <span class="op">=</span> threading.Lock()</span>
<span id="cb27-6"><a href="#cb27-6" aria-hidden="true" tabindex="-1"></a>lock2 <span class="op">=</span> threading.Lock()</span>
<span id="cb27-7"><a href="#cb27-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb27-8"><a href="#cb27-8" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> task1():</span>
<span id="cb27-9"><a href="#cb27-9" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> lock1:</span>
<span id="cb27-10"><a href="#cb27-10" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="st">"Task 1 acquired lock1"</span>)</span>
<span id="cb27-11"><a href="#cb27-11" aria-hidden="true" tabindex="-1"></a>        time.sleep(<span class="fl">0.1</span>)</span>
<span id="cb27-12"><a href="#cb27-12" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> lock2:</span>
<span id="cb27-13"><a href="#cb27-13" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="st">"Task 1 acquired lock2"</span>)</span>
<span id="cb27-14"><a href="#cb27-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb27-15"><a href="#cb27-15" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> task2():</span>
<span id="cb27-16"><a href="#cb27-16" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> lock2:</span>
<span id="cb27-17"><a href="#cb27-17" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="st">"Task 2 acquired lock2"</span>)</span>
<span id="cb27-18"><a href="#cb27-18" aria-hidden="true" tabindex="-1"></a>        time.sleep(<span class="fl">0.1</span>)</span>
<span id="cb27-19"><a href="#cb27-19" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> lock1:</span>
<span id="cb27-20"><a href="#cb27-20" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="st">"Task 2 acquired lock1"</span>)</span>
<span id="cb27-21"><a href="#cb27-21" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb27-22"><a href="#cb27-22" aria-hidden="true" tabindex="-1"></a><span class="co"># Solution: Always acquire locks in the same order</span></span>
<span id="cb27-23"><a href="#cb27-23" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> safe_task1():</span>
<span id="cb27-24"><a href="#cb27-24" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> lock1:</span>
<span id="cb27-25"><a href="#cb27-25" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="st">"Safe Task 1 acquired lock1"</span>)</span>
<span id="cb27-26"><a href="#cb27-26" aria-hidden="true" tabindex="-1"></a>        time.sleep(<span class="fl">0.1</span>)</span>
<span id="cb27-27"><a href="#cb27-27" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> lock2:</span>
<span id="cb27-28"><a href="#cb27-28" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="st">"Safe Task 1 acquired lock2"</span>)</span>
<span id="cb27-29"><a href="#cb27-29" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb27-30"><a href="#cb27-30" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> safe_task2():</span>
<span id="cb27-31"><a href="#cb27-31" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> lock1:  <span class="co"># Same order as safe_task1</span></span>
<span id="cb27-32"><a href="#cb27-32" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="st">"Safe Task 2 acquired lock1"</span>)</span>
<span id="cb27-33"><a href="#cb27-33" aria-hidden="true" tabindex="-1"></a>        time.sleep(<span class="fl">0.1</span>)</span>
<span id="cb27-34"><a href="#cb27-34" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> lock2:</span>
<span id="cb27-35"><a href="#cb27-35" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="st">"Safe Task 2 acquired lock2"</span>)</span>
<span id="cb27-36"><a href="#cb27-36" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb27-37"><a href="#cb27-37" aria-hidden="true" tabindex="-1"></a><span class="co"># Alternative: acquire with a timeout so a blocked thread backs off instead of waiting forever</span></span>
<span id="cb27-40"><a href="#cb27-40" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> task_with_timeout():</span>
<span id="cb27-41"><a href="#cb27-41" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> lock1.acquire(timeout<span class="op">=</span><span class="dv">1</span>):</span>
<span id="cb27-42"><a href="#cb27-42" aria-hidden="true" tabindex="-1"></a>        <span class="cf">try</span>:</span>
<span id="cb27-43"><a href="#cb27-43" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="st">"Acquired lock1"</span>)</span>
<span id="cb27-44"><a href="#cb27-44" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> lock2.acquire(timeout<span class="op">=</span><span class="dv">1</span>):</span>
<span id="cb27-45"><a href="#cb27-45" aria-hidden="true" tabindex="-1"></a>                <span class="cf">try</span>:</span>
<span id="cb27-46"><a href="#cb27-46" aria-hidden="true" tabindex="-1"></a>                    <span class="bu">print</span>(<span class="st">"Acquired lock2"</span>)</span>
<span id="cb27-47"><a href="#cb27-47" aria-hidden="true" tabindex="-1"></a>                    <span class="co"># Do work</span></span>
<span id="cb27-48"><a href="#cb27-48" aria-hidden="true" tabindex="-1"></a>                <span class="cf">finally</span>:</span>
<span id="cb27-49"><a href="#cb27-49" aria-hidden="true" tabindex="-1"></a>                    lock2.release()</span>
<span id="cb27-50"><a href="#cb27-50" aria-hidden="true" tabindex="-1"></a>            <span class="cf">else</span>:</span>
<span id="cb27-51"><a href="#cb27-51" aria-hidden="true" tabindex="-1"></a>                <span class="bu">print</span>(<span class="st">"Could not acquire lock2"</span>)</span>
<span id="cb27-52"><a href="#cb27-52" aria-hidden="true" tabindex="-1"></a>        <span class="cf">finally</span>:</span>
<span id="cb27-53"><a href="#cb27-53" aria-hidden="true" tabindex="-1"></a>            lock1.release()</span>
<span id="cb27-54"><a href="#cb27-54" aria-hidden="true" tabindex="-1"></a>    <span class="cf">else</span>:</span>
<span id="cb27-55"><a href="#cb27-55" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="st">"Could not acquire lock1"</span>)</span></code></pre></div></div>
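<p>The same-order rule above is easy to break once several functions each take several locks. One way to enforce it mechanically is a small helper that always acquires a group of locks in a single global order. The sketch below assumes that ordering by <code>id()</code> is acceptable (it stays stable for as long as the lock objects are alive); <code>acquire_in_order</code> and <code>transfer</code> are hypothetical names, and <code>contextlib.ExitStack</code> is used only so the locks are released in reverse order.</p>

```python
import threading
from contextlib import ExitStack, contextmanager


@contextmanager
def acquire_in_order(*locks):
    """Acquire every lock in one global order (here: by id()), release in reverse.

    Hypothetical helper. Assumes every code path that needs more than one of
    these locks goes through it, so no two threads can wait on each other in a cycle.
    Each lock should be passed at most once: a plain Lock is not reentrant.
    """
    with ExitStack() as stack:
        for lock in sorted(locks, key=id):
            stack.enter_context(lock)  # Lock supports the context-manager protocol
        yield  # ExitStack releases the locks in LIFO order on exit


lock_a = threading.Lock()
lock_b = threading.Lock()
completed = []


def transfer(first, second, name):
    # The caller's argument order no longer matters: the helper sorts the locks.
    for _ in range(1000):
        with acquire_in_order(first, second):
            completed.append(name)


# Opposite argument orders, the same shape as task1/task2 above.
t1 = threading.Thread(target=transfer, args=(lock_a, lock_b, "t1"))
t2 = threading.Thread(target=transfer, args=(lock_b, lock_a, "t2"))
for t in (t1, t2):
    t.start()
for t in (t1, t2):
    t.join(timeout=5)  # a timeout keeps the demo from hanging if ordering were broken

stuck = any(t.is_alive() for t in (t1, t2))
print("deadlocked" if stuck else f"finished {len(completed)} critical sections")
```

<p>Both threads name the locks in opposite orders, which is exactly the <code>task1</code>/<code>task2</code> pattern, yet both finish because the helper sorts the locks before acquiring any of them. Unlike the timeout approach, nothing has to be retried; the trade-off is that every multi-lock acquisition must go through the same helper for the guarantee to hold.</p>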
</section>
<section id="memory-leaks-in-multiprocessing" class="level3">
<h3 class="anchored" data-anchor-id="memory-leaks-in-multiprocessing" id="memory-leaks-in-multiprocessing">3. Memory Leaks in Multiprocessing</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb28"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb28-1"><a href="#cb28-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> multiprocessing</span>
<span id="cb28-2"><a href="#cb28-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb28-3"><a href="#cb28-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> os</span>
<span id="cb28-4"><a href="#cb28-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb28-5"><a href="#cb28-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Problem: Not properly cleaning up processes</span></span>
<span id="cb28-6"><a href="#cb28-6" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> memory_leak_example():</span>
<span id="cb28-7"><a href="#cb28-7" aria-hidden="true" tabindex="-1"></a>    processes <span class="op">=</span> []</span>
<span id="cb28-8"><a href="#cb28-8" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">10</span>):</span>
<span id="cb28-9"><a href="#cb28-9" aria-hidden="true" tabindex="-1"></a>        p <span class="op">=</span> multiprocessing.Process(target<span class="op">=</span>time.sleep, args<span class="op">=</span>(<span class="dv">10</span>,))  <span class="co"># no lambda: targets must be picklable under "spawn"</span></span>
<span id="cb28-10"><a href="#cb28-10" aria-hidden="true" tabindex="-1"></a>        p.start()</span>
<span id="cb28-11"><a href="#cb28-11" aria-hidden="true" tabindex="-1"></a>        processes.append(p)</span>
<span id="cb28-12"><a href="#cb28-12" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Forgetting to join processes can lead to zombie processes</span></span>
<span id="cb28-13"><a href="#cb28-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb28-14"><a href="#cb28-14" aria-hidden="true" tabindex="-1"></a><span class="co"># Solution: Proper cleanup</span></span>
<span id="cb28-15"><a href="#cb28-15" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> proper_process_management():</span>
<span id="cb28-16"><a href="#cb28-16" aria-hidden="true" tabindex="-1"></a>    processes <span class="op">=</span> []</span>
<span id="cb28-17"><a href="#cb28-17" aria-hidden="true" tabindex="-1"></a>    <span class="cf">try</span>:</span>
<span id="cb28-18"><a href="#cb28-18" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">10</span>):</span>
<span id="cb28-19"><a href="#cb28-19" aria-hidden="true" tabindex="-1"></a>            p <span class="op">=</span> multiprocessing.Process(target<span class="op">=</span>time.sleep, args<span class="op">=</span>(<span class="dv">1</span>,))</span>
<span id="cb28-20"><a href="#cb28-20" aria-hidden="true" tabindex="-1"></a>            p.start()</span>
<span id="cb28-21"><a href="#cb28-21" aria-hidden="true" tabindex="-1"></a>            processes.append(p)</span>
<span id="cb28-22"><a href="#cb28-22" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb28-23"><a href="#cb28-23" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Wait for all processes to complete</span></span>
<span id="cb28-24"><a href="#cb28-24" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> p <span class="kw">in</span> processes:</span>
<span id="cb28-25"><a href="#cb28-25" aria-hidden="true" tabindex="-1"></a>            p.join()</span>
<span id="cb28-26"><a href="#cb28-26" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb28-27"><a href="#cb28-27" aria-hidden="true" tabindex="-1"></a>    <span class="cf">except</span> <span class="pp">KeyboardInterrupt</span>:</span>
<span id="cb28-28"><a href="#cb28-28" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="st">"Interrupting processes..."</span>)</span>
<span id="cb28-29"><a href="#cb28-29" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> p <span class="kw">in</span> processes:</span>
<span id="cb28-30"><a href="#cb28-30" aria-hidden="true" tabindex="-1"></a>            p.terminate()</span>
<span id="cb28-31"><a href="#cb28-31" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> p <span class="kw">in</span> processes:</span>
<span id="cb28-32"><a href="#cb28-32" aria-hidden="true" tabindex="-1"></a>            p.join()</span>
<span id="cb28-33"><a href="#cb28-33" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb28-34"><a href="#cb28-34" aria-hidden="true" tabindex="-1"></a><span class="co"># Context manager approach</span></span>
<span id="cb28-35"><a href="#cb28-35" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> contextlib <span class="im">import</span> contextmanager</span>
<span id="cb28-36"><a href="#cb28-36" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb28-37"><a href="#cb28-37" aria-hidden="true" tabindex="-1"></a><span class="at">@contextmanager</span></span>
<span id="cb28-38"><a href="#cb28-38" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> managed_processes(target_func, num_processes):</span>
<span id="cb28-39"><a href="#cb28-39" aria-hidden="true" tabindex="-1"></a>    processes <span class="op">=</span> []</span>
<span id="cb28-40"><a href="#cb28-40" aria-hidden="true" tabindex="-1"></a>    <span class="cf">try</span>:</span>
<span id="cb28-41"><a href="#cb28-41" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(num_processes):</span>
<span id="cb28-42"><a href="#cb28-42" aria-hidden="true" tabindex="-1"></a>            p <span class="op">=</span> multiprocessing.Process(target<span class="op">=</span>target_func)</span>
<span id="cb28-43"><a href="#cb28-43" aria-hidden="true" tabindex="-1"></a>            p.start()</span>
<span id="cb28-44"><a href="#cb28-44" aria-hidden="true" tabindex="-1"></a>            processes.append(p)</span>
<span id="cb28-45"><a href="#cb28-45" aria-hidden="true" tabindex="-1"></a>        <span class="cf">yield</span> processes</span>
<span id="cb28-45a"><a href="#cb28-45a" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Normal exit: let the workers finish instead of terminating them</span></span>
<span id="cb28-45b"><a href="#cb28-45b" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> p <span class="kw">in</span> processes:</span>
<span id="cb28-45c"><a href="#cb28-45c" aria-hidden="true" tabindex="-1"></a>            p.join()</span>
<span id="cb28-46"><a href="#cb28-46" aria-hidden="true" tabindex="-1"></a>    <span class="cf">finally</span>:</span>
<span id="cb28-46a"><a href="#cb28-46a" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Error path: stop anything still running, then reap every child</span></span>
<span id="cb28-47"><a href="#cb28-47" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> p <span class="kw">in</span> processes:</span>
<span id="cb28-48"><a href="#cb28-48" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> p.is_alive():</span>
<span id="cb28-49"><a href="#cb28-49" aria-hidden="true" tabindex="-1"></a>                p.terminate()</span>
<span id="cb28-50"><a href="#cb28-50" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> p <span class="kw">in</span> processes:</span>
<span id="cb28-51"><a href="#cb28-51" aria-hidden="true" tabindex="-1"></a>            p.join()</span>
<span id="cb28-52"><a href="#cb28-52" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb28-53"><a href="#cb28-53" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage</span></span>
<span id="cb28-54"><a href="#cb28-54" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> worker_task():</span>
<span id="cb28-55"><a href="#cb28-55" aria-hidden="true" tabindex="-1"></a>    time.sleep(<span class="dv">1</span>)</span>
<span id="cb28-56"><a href="#cb28-56" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Worker </span><span class="sc">{</span>os<span class="sc">.</span>getpid()<span class="sc">}</span><span class="ss"> finished"</span>)</span>
<span id="cb28-57"><a href="#cb28-57" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb28-58"><a href="#cb28-58" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb28-59"><a href="#cb28-59" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> managed_processes(worker_task, <span class="dv">4</span>) <span class="im">as</span> processes:</span>
<span id="cb28-60"><a href="#cb28-60" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Started </span><span class="sc">{</span><span class="bu">len</span>(processes)<span class="sc">}</span><span class="ss"> processes"</span>)</span>
<span id="cb28-61"><a href="#cb28-61" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Leaving the block joins the workers; an exception terminates them first</span></span></code></pre></div></div>
</section>
<section id="pickle-errors-in-multiprocessing" class="level3">
<h3 class="anchored" data-anchor-id="pickle-errors-in-multiprocessing" id="pickle-errors-in-multiprocessing">4. Pickle Errors in Multiprocessing</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb29"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb29-1"><a href="#cb29-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> multiprocessing</span>
<span id="cb29-2"><a href="#cb29-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> pickle</span>
<span id="cb29-3"><a href="#cb29-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb29-4"><a href="#cb29-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Problem: Cannot pickle certain objects</span></span>
<span id="cb29-5"><a href="#cb29-5" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> UnpicklableClass:</span>
<span id="cb29-6"><a href="#cb29-6" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>):</span>
<span id="cb29-7"><a href="#cb29-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.lambda_func <span class="op">=</span> <span class="kw">lambda</span> x: x <span class="op">*</span> <span class="dv">2</span>  <span class="co"># Cannot pickle lambda</span></span>
<span id="cb29-8"><a href="#cb29-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.file_handle <span class="op">=</span> <span class="bu">open</span>(<span class="st">'temp.txt'</span>, <span class="st">'w'</span>)  <span class="co"># Cannot pickle file handles</span></span>
<span id="cb29-9"><a href="#cb29-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb29-10"><a href="#cb29-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Solution: Use picklable alternatives</span></span>
<span id="cb29-11"><a href="#cb29-11" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> PicklableClass:</span>
<span id="cb29-12"><a href="#cb29-12" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>):</span>
<span id="cb29-13"><a href="#cb29-13" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.multiplier <span class="op">=</span> <span class="dv">2</span></span>
<span id="cb29-14"><a href="#cb29-14" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb29-15"><a href="#cb29-15" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> multiply(<span class="va">self</span>, x):</span>
<span id="cb29-16"><a href="#cb29-16" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x <span class="op">*</span> <span class="va">self</span>.multiplier</span>
<span id="cb29-17"><a href="#cb29-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb29-18"><a href="#cb29-18" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> process_with_method(obj, value):</span>
<span id="cb29-19"><a href="#cb29-19" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> obj.multiply(value)</span>
<span id="cb29-20"><a href="#cb29-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb29-21"><a href="#cb29-21" aria-hidden="true" tabindex="-1"></a><span class="co"># Alternative: Use dill for advanced pickling</span></span>
<span id="cb29-22"><a href="#cb29-22" aria-hidden="true" tabindex="-1"></a><span class="cf">try</span>:</span>
<span id="cb29-23"><a href="#cb29-23" aria-hidden="true" tabindex="-1"></a>    <span class="im">import</span> dill</span>
<span id="cb29-24"><a href="#cb29-24" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb29-25"><a href="#cb29-25" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> advanced_pickle_function():</span>
<span id="cb29-26"><a href="#cb29-26" aria-hidden="true" tabindex="-1"></a>        func <span class="op">=</span> <span class="kw">lambda</span> x: x <span class="op">*</span> <span class="dv">2</span></span>
<span id="cb29-27"><a href="#cb29-27" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> dill.dumps(func)</span>
<span id="cb29-28"><a href="#cb29-28" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb29-29"><a href="#cb29-29" aria-hidden="true" tabindex="-1"></a><span class="cf">except</span> <span class="pp">ImportError</span>:</span>
<span id="cb29-30"><a href="#cb29-30" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="st">"dill not available"</span>)</span>
<span id="cb29-31"><a href="#cb29-31" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb29-32"><a href="#cb29-32" aria-hidden="true" tabindex="-1"></a><span class="co"># Using multiprocessing with proper pickling</span></span>
<span id="cb29-33"><a href="#cb29-33" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> safe_multiprocessing_example():</span>
<span id="cb29-34"><a href="#cb29-34" aria-hidden="true" tabindex="-1"></a>    obj <span class="op">=</span> PicklableClass()</span>
<span id="cb29-35"><a href="#cb29-35" aria-hidden="true" tabindex="-1"></a>    values <span class="op">=</span> [<span class="dv">1</span>, <span class="dv">2</span>, <span class="dv">3</span>, <span class="dv">4</span>, <span class="dv">5</span>]</span>
<span id="cb29-36"><a href="#cb29-36" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb29-37"><a href="#cb29-37" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> multiprocessing.Pool(processes<span class="op">=</span><span class="dv">4</span>) <span class="im">as</span> pool:</span>
<span id="cb29-38"><a href="#cb29-38" aria-hidden="true" tabindex="-1"></a>        results <span class="op">=</span> pool.starmap(process_with_method, [(obj, v) <span class="cf">for</span> v <span class="kw">in</span> values])</span>
<span id="cb29-39"><a href="#cb29-39" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb29-40"><a href="#cb29-40" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Results: </span><span class="sc">{</span>results<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb29-41"><a href="#cb29-41" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb29-42"><a href="#cb29-42" aria-hidden="true" tabindex="-1"></a><span class="co"># The guard belongs at module level: with the "spawn" start method, workers re-import this module</span></span>
<span id="cb29-43"><a href="#cb29-43" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb29-44"><a href="#cb29-44" aria-hidden="true" tabindex="-1"></a>    safe_multiprocessing_example()</span></code></pre></div></div>
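<p>The two unpicklable attributes in <code>UnpicklableClass</code> each have a standard-library workaround. The sketch below assumes a lambda can be replaced by <code>functools.partial</code> over a module-level function (which the standard <code>pickle</code> module serializes by reference), and that an open file can be left out of the pickled state and reopened on the receiving side through the documented <code>__getstate__</code>/<code>__setstate__</code> hooks. <code>LogWriter</code>, <code>is_picklable</code>, and <code>double</code> are hypothetical names used only for illustration.</p>

```python
import operator
import os
import pickle
import tempfile
from functools import partial

# Instead of a lambda attribute: a partial over a module-level (importable)
# function, which the standard pickle module can serialize by reference.
double = partial(operator.mul, 2)


class LogWriter:
    """Hypothetical class that holds an open file but pickles only its path."""

    def __init__(self, path):
        self.path = path
        self.file_handle = open(path, "a")

    def __getstate__(self):
        # Copy the instance dict and drop the unpicklable file object.
        state = self.__dict__.copy()
        del state["file_handle"]
        return state

    def __setstate__(self, state):
        # Restore the plain attributes, then reopen the file on the receiving side.
        self.__dict__.update(state)
        self.file_handle = open(self.path, "a")


def is_picklable(obj):
    """Return True if the standard pickle module can serialize obj."""
    try:
        pickle.dumps(obj)
    except Exception:  # PicklingError, AttributeError, TypeError, ...
        return False
    return True


path = os.path.join(tempfile.mkdtemp(), "worker.log")
writer = LogWriter(path)
clone = pickle.loads(pickle.dumps(writer))  # round trip, as a Pool would do

print(is_picklable(lambda x: x * 2))  # False
print(is_picklable(double), pickle.loads(pickle.dumps(double))(21))  # True 42
print(clone.path == writer.path, clone.file_handle is not writer.file_handle)  # True True

writer.file_handle.close()
clone.file_handle.close()
```

<p>Running a check like <code>is_picklable</code> on arguments before submitting them to a pool surfaces the error in the parent process, where the traceback is easier to read. Because <code>__setstate__</code> reopens the file, each worker ends up with its own handle, so appends from several workers to the same file may still interleave.</p>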
</section>
<section id="exception-handling-in-concurrent-code" class="level3">
<h3 class="anchored" data-anchor-id="exception-handling-in-concurrent-code" id="exception-handling-in-concurrent-code">5. Exception Handling in Concurrent Code</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb30"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb30-1"><a href="#cb30-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> threading</span>
<span id="cb30-2"><a href="#cb30-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> multiprocessing</span>
<span id="cb30-3"><a href="#cb30-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> logging</span>
<span id="cb30-4"><a href="#cb30-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> concurrent.futures <span class="im">import</span> ThreadPoolExecutor, ProcessPoolExecutor, as_completed</span>
<span id="cb30-5"><a href="#cb30-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb30-6"><a href="#cb30-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Setup logging</span></span>
<span id="cb30-7"><a href="#cb30-7" aria-hidden="true" tabindex="-1"></a>logging.basicConfig(level<span class="op">=</span>logging.INFO)</span>
<span id="cb30-8"><a href="#cb30-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb30-9"><a href="#cb30-9" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> risky_task(task_id):</span>
<span id="cb30-10"><a href="#cb30-10" aria-hidden="true" tabindex="-1"></a>    <span class="im">import</span> random</span>
<span id="cb30-11"><a href="#cb30-11" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> random.random() <span class="op">&lt;</span> <span class="fl">0.3</span>:  <span class="co"># 30% chance of failure</span></span>
<span id="cb30-12"><a href="#cb30-12" aria-hidden="true" tabindex="-1"></a>        <span class="cf">raise</span> <span class="pp">ValueError</span>(<span class="ss">f"Task </span><span class="sc">{</span>task_id<span class="sc">}</span><span class="ss"> failed"</span>)</span>
<span id="cb30-13"><a href="#cb30-13" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> <span class="ss">f"Task </span><span class="sc">{</span>task_id<span class="sc">}</span><span class="ss"> completed"</span></span>
<span id="cb30-14"><a href="#cb30-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb30-15"><a href="#cb30-15" aria-hidden="true" tabindex="-1"></a><span class="co"># Thread exception handling</span></span>
<span id="cb30-16"><a href="#cb30-16" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> handle_thread_exceptions():</span>
<span id="cb30-17"><a href="#cb30-17" aria-hidden="true" tabindex="-1"></a>    results <span class="op">=</span> []</span>
<span id="cb30-18"><a href="#cb30-18" aria-hidden="true" tabindex="-1"></a>    errors <span class="op">=</span> []</span>
<span id="cb30-19"><a href="#cb30-19" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb30-20"><a href="#cb30-20" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> ThreadPoolExecutor(max_workers<span class="op">=</span><span class="dv">4</span>) <span class="im">as</span> executor:</span>
<span id="cb30-21"><a href="#cb30-21" aria-hidden="true" tabindex="-1"></a>        futures <span class="op">=</span> [executor.submit(risky_task, i) <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">10</span>)]</span>
<span id="cb30-22"><a href="#cb30-22" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb30-23"><a href="#cb30-23" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> future <span class="kw">in</span> as_completed(futures):</span>
<span id="cb30-24"><a href="#cb30-24" aria-hidden="true" tabindex="-1"></a>            <span class="cf">try</span>:</span>
<span id="cb30-25"><a href="#cb30-25" aria-hidden="true" tabindex="-1"></a>                result <span class="op">=</span> future.result()</span>
<span id="cb30-26"><a href="#cb30-26" aria-hidden="true" tabindex="-1"></a>                results.append(result)</span>
<span id="cb30-27"><a href="#cb30-27" aria-hidden="true" tabindex="-1"></a>            <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:</span>
<span id="cb30-28"><a href="#cb30-28" aria-hidden="true" tabindex="-1"></a>                errors.append(<span class="bu">str</span>(e))</span>
<span id="cb30-29"><a href="#cb30-29" aria-hidden="true" tabindex="-1"></a>                logging.error(<span class="ss">f"Task failed: </span><span class="sc">{</span>e<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb30-30"><a href="#cb30-30" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb30-31"><a href="#cb30-31" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Completed: </span><span class="sc">{</span><span class="bu">len</span>(results)<span class="sc">}</span><span class="ss">, Failed: </span><span class="sc">{</span><span class="bu">len</span>(errors)<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb30-32"><a href="#cb30-32" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> results, errors</span>
<span id="cb30-33"><a href="#cb30-33" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb30-34"><a href="#cb30-34" aria-hidden="true" tabindex="-1"></a><span class="co"># Process exception handling</span></span>
<span id="cb30-35"><a href="#cb30-35" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> handle_process_exceptions():</span>
<span id="cb30-36"><a href="#cb30-36" aria-hidden="true" tabindex="-1"></a>    results <span class="op">=</span> []</span>
<span id="cb30-37"><a href="#cb30-37" aria-hidden="true" tabindex="-1"></a>    errors <span class="op">=</span> []</span>
<span id="cb30-38"><a href="#cb30-38" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb30-39"><a href="#cb30-39" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> ProcessPoolExecutor(max_workers<span class="op">=</span><span class="dv">4</span>) <span class="im">as</span> executor:</span>
<span id="cb30-40"><a href="#cb30-40" aria-hidden="true" tabindex="-1"></a>        futures <span class="op">=</span> [executor.submit(risky_task, i) <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">10</span>)]</span>
<span id="cb30-41"><a href="#cb30-41" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb30-42"><a href="#cb30-42" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> future <span class="kw">in</span> as_completed(futures):</span>
<span id="cb30-43"><a href="#cb30-43" aria-hidden="true" tabindex="-1"></a>            <span class="cf">try</span>:</span>
<span id="cb30-44"><a href="#cb30-44" aria-hidden="true" tabindex="-1"></a>                result <span class="op">=</span> future.result()</span>
<span id="cb30-45"><a href="#cb30-45" aria-hidden="true" tabindex="-1"></a>                results.append(result)</span>
<span id="cb30-46"><a href="#cb30-46" aria-hidden="true" tabindex="-1"></a>            <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:</span>
<span id="cb30-47"><a href="#cb30-47" aria-hidden="true" tabindex="-1"></a>                errors.append(<span class="bu">str</span>(e))</span>
<span id="cb30-48"><a href="#cb30-48" aria-hidden="true" tabindex="-1"></a>                logging.error(<span class="ss">f"Process task failed: </span><span class="sc">{</span>e<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb30-49"><a href="#cb30-49" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb30-50"><a href="#cb30-50" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Completed: </span><span class="sc">{</span><span class="bu">len</span>(results)<span class="sc">}</span><span class="ss">, Failed: </span><span class="sc">{</span><span class="bu">len</span>(errors)<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb30-51"><a href="#cb30-51" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> results, errors</span>
<span id="cb30-52"><a href="#cb30-52" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb30-53"><a href="#cb30-53" aria-hidden="true" tabindex="-1"></a><span class="co"># Custom exception handler</span></span>
<span id="cb30-54"><a href="#cb30-54" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> ExceptionHandler:</span>
<span id="cb30-55"><a href="#cb30-55" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>):</span>
<span id="cb30-56"><a href="#cb30-56" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.exceptions <span class="op">=</span> []</span>
<span id="cb30-57"><a href="#cb30-57" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.lock <span class="op">=</span> threading.Lock()</span>
<span id="cb30-58"><a href="#cb30-58" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb30-59"><a href="#cb30-59" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> handle_exception(<span class="va">self</span>, exception):</span>
<span id="cb30-60"><a href="#cb30-60" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> <span class="va">self</span>.lock:</span>
<span id="cb30-61"><a href="#cb30-61" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.exceptions.append(exception)</span>
<span id="cb30-62"><a href="#cb30-62" aria-hidden="true" tabindex="-1"></a>            logging.error(<span class="ss">f"Exception caught: </span><span class="sc">{</span>exception<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb30-63"><a href="#cb30-63" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb30-64"><a href="#cb30-64" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> task_with_exception_handler(task_id, exception_handler):</span>
<span id="cb30-65"><a href="#cb30-65" aria-hidden="true" tabindex="-1"></a>    <span class="cf">try</span>:</span>
<span id="cb30-66"><a href="#cb30-66" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> risky_task(task_id)</span>
<span id="cb30-67"><a href="#cb30-67" aria-hidden="true" tabindex="-1"></a>    <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:</span>
<span id="cb30-68"><a href="#cb30-68" aria-hidden="true" tabindex="-1"></a>        exception_handler.handle_exception(e)</span>
<span id="cb30-69"><a href="#cb30-69" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">None</span></span>
<span id="cb30-70"><a href="#cb30-70" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb30-71"><a href="#cb30-71" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage</span></span>
<span id="cb30-72"><a href="#cb30-72" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb30-73"><a href="#cb30-73" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="st">"Thread exception handling:"</span>)</span>
<span id="cb30-74"><a href="#cb30-74" aria-hidden="true" tabindex="-1"></a>    handle_thread_exceptions()</span>
<span id="cb30-75"><a href="#cb30-75" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb30-76"><a href="#cb30-76" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="st">"</span><span class="ch">\n</span><span class="st">Process exception handling:"</span>)</span>
<span id="cb30-77"><a href="#cb30-77" aria-hidden="true" tabindex="-1"></a>    handle_process_exceptions()</span></code></pre></div></div>
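<p>The <code>ExceptionHandler</code> class and <code>task_with_exception_handler</code> above are defined but never exercised by the <code>__main__</code> block. The following is a minimal sketch of one way to wire them into a thread pool. <code>risky_task</code> here is a hypothetical stand-in that fails for every third task id, and the class is repeated so the snippet runs on its own.</p>

```python
import logging
import threading
from concurrent.futures import ThreadPoolExecutor

def risky_task(task_id):
    # Hypothetical stand-in: fail on every third id, otherwise return a value
    if task_id % 3 == 0:
        raise ValueError(f"task {task_id} failed")
    return task_id * 2

class ExceptionHandler:
    def __init__(self):
        self.exceptions = []
        self.lock = threading.Lock()

    def handle_exception(self, exception):
        # The lock keeps appends from different worker threads from interleaving
        with self.lock:
            self.exceptions.append(exception)
            logging.error(f"Exception caught: {exception}")

def task_with_exception_handler(task_id, exception_handler):
    try:
        return risky_task(task_id)
    except Exception as e:
        exception_handler.handle_exception(e)
        return None

def run_with_shared_handler(n_tasks=10):
    # One handler instance is shared by every worker thread
    handler = ExceptionHandler()
    with ThreadPoolExecutor(max_workers=4) as executor:
        results = list(executor.map(
            lambda i: task_with_exception_handler(i, handler), range(n_tasks)))
    successes = [r for r in results if r is not None]
    return successes, handler.exceptions

if __name__ == "__main__":
    ok, failed = run_with_shared_handler()
    print(f"Succeeded: {len(ok)}, Failed: {len(failed)}")
```

<p>This pattern relies on threads sharing memory. A <code>threading.Lock</code> cannot be pickled, so passing the same handler to <code>ProcessPoolExecutor.submit</code> would fail; with processes, collect exceptions in the parent from <code>future.result()</code>, as <code>handle_process_exceptions</code> does.</p>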
</section>
<section id="performance-monitoring" class="level3">
<h3 class="anchored" data-anchor-id="performance-monitoring" id="performance-monitoring">6. Performance Monitoring</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb31"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb31-1"><a href="#cb31-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb31-2"><a href="#cb31-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> threading</span>
<span id="cb31-3"><a href="#cb31-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> multiprocessing</span>
<span id="cb31-4"><a href="#cb31-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> psutil  <span class="co"># third-party package: pip install psutil</span></span>
<span id="cb31-5"><a href="#cb31-5" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> concurrent.futures <span class="im">import</span> ThreadPoolExecutor, ProcessPoolExecutor</span>
<span id="cb31-6"><a href="#cb31-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb31-7"><a href="#cb31-7" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> PerformanceMonitor:</span>
<span id="cb31-8"><a href="#cb31-8" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>):</span>
<span id="cb31-9"><a href="#cb31-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.start_time <span class="op">=</span> <span class="va">None</span></span>
<span id="cb31-10"><a href="#cb31-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.end_time <span class="op">=</span> <span class="va">None</span></span>
<span id="cb31-11"><a href="#cb31-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.cpu_percent <span class="op">=</span> []</span>
<span id="cb31-12"><a href="#cb31-12" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.memory_percent <span class="op">=</span> []</span>
<span id="cb31-13"><a href="#cb31-13" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.monitoring <span class="op">=</span> <span class="va">False</span></span>
<span id="cb31-14"><a href="#cb31-14" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.monitor_thread <span class="op">=</span> <span class="va">None</span></span>
<span id="cb31-15"><a href="#cb31-15" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb31-16"><a href="#cb31-16" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> start_monitoring(<span class="va">self</span>):</span>
<span id="cb31-17"><a href="#cb31-17" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.start_time <span class="op">=</span> time.time()</span>
<span id="cb31-18"><a href="#cb31-18" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.monitoring <span class="op">=</span> <span class="va">True</span></span>
<span id="cb31-19"><a href="#cb31-19" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.monitor_thread <span class="op">=</span> threading.Thread(target<span class="op">=</span><span class="va">self</span>._monitor)</span>
<span id="cb31-20"><a href="#cb31-20" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.monitor_thread.start()</span>
<span id="cb31-21"><a href="#cb31-21" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb31-22"><a href="#cb31-22" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> stop_monitoring(<span class="va">self</span>):</span>
<span id="cb31-23"><a href="#cb31-23" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.end_time <span class="op">=</span> time.time()</span>
<span id="cb31-24"><a href="#cb31-24" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.monitoring <span class="op">=</span> <span class="va">False</span></span>
<span id="cb31-25"><a href="#cb31-25" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="va">self</span>.monitor_thread:</span>
<span id="cb31-26"><a href="#cb31-26" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.monitor_thread.join()</span>
<span id="cb31-27"><a href="#cb31-27" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb31-28"><a href="#cb31-28" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _monitor(<span class="va">self</span>):</span>
<span id="cb31-29"><a href="#cb31-29" aria-hidden="true" tabindex="-1"></a>        <span class="cf">while</span> <span class="va">self</span>.monitoring:</span>
<span id="cb31-30"><a href="#cb31-30" aria-hidden="true" tabindex="-1"></a>            <span class="co"># interval=0.1 blocks for 0.1 s and measures CPU usage over that window</span></span>
<span id="cb31-31"><a href="#cb31-31" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.cpu_percent.append(psutil.cpu_percent(interval<span class="op">=</span><span class="fl">0.1</span>))</span>
<span id="cb31-32"><a href="#cb31-32" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.memory_percent.append(psutil.virtual_memory().percent)</span>
<span id="cb31-33"><a href="#cb31-33" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb31-34"><a href="#cb31-34" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> get_stats(<span class="va">self</span>):</span>
<span id="cb31-35"><a href="#cb31-35" aria-hidden="true" tabindex="-1"></a>        duration <span class="op">=</span> <span class="va">self</span>.end_time <span class="op">-</span> <span class="va">self</span>.start_time <span class="cf">if</span> <span class="va">self</span>.end_time <span class="cf">else</span> <span class="dv">0</span></span>
<span id="cb31-36"><a href="#cb31-36" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {</span>
<span id="cb31-37"><a href="#cb31-37" aria-hidden="true" tabindex="-1"></a>            <span class="st">'duration'</span>: duration,</span>
<span id="cb31-38"><a href="#cb31-38" aria-hidden="true" tabindex="-1"></a>            <span class="st">'avg_cpu'</span>: <span class="bu">sum</span>(<span class="va">self</span>.cpu_percent) <span class="op">/</span> <span class="bu">len</span>(<span class="va">self</span>.cpu_percent) <span class="cf">if</span> <span class="va">self</span>.cpu_percent <span class="cf">else</span> <span class="dv">0</span>,</span>
<span id="cb31-39"><a href="#cb31-39" aria-hidden="true" tabindex="-1"></a>            <span class="st">'max_cpu'</span>: <span class="bu">max</span>(<span class="va">self</span>.cpu_percent) <span class="cf">if</span> <span class="va">self</span>.cpu_percent <span class="cf">else</span> <span class="dv">0</span>,</span>
<span id="cb31-40"><a href="#cb31-40" aria-hidden="true" tabindex="-1"></a>            <span class="st">'avg_memory'</span>: <span class="bu">sum</span>(<span class="va">self</span>.memory_percent) <span class="op">/</span> <span class="bu">len</span>(<span class="va">self</span>.memory_percent) <span class="cf">if</span> <span class="va">self</span>.memory_percent <span class="cf">else</span> <span class="dv">0</span>,</span>
<span id="cb31-41"><a href="#cb31-41" aria-hidden="true" tabindex="-1"></a>            <span class="st">'max_memory'</span>: <span class="bu">max</span>(<span class="va">self</span>.memory_percent) <span class="cf">if</span> <span class="va">self</span>.memory_percent <span class="cf">else</span> <span class="dv">0</span></span>
<span id="cb31-42"><a href="#cb31-42" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb31-43"><a href="#cb31-43" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb31-44"><a href="#cb31-44" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> cpu_intensive_task(n):</span>
<span id="cb31-45"><a href="#cb31-45" aria-hidden="true" tabindex="-1"></a>    total <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb31-46"><a href="#cb31-46" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(n <span class="op">*</span> <span class="dv">100000</span>):</span>
<span id="cb31-47"><a href="#cb31-47" aria-hidden="true" tabindex="-1"></a>        total <span class="op">+=</span> i</span>
<span id="cb31-48"><a href="#cb31-48" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> total</span>
<span id="cb31-49"><a href="#cb31-49" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb31-50"><a href="#cb31-50" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> benchmark_approaches():</span>
<span id="cb31-51"><a href="#cb31-51" aria-hidden="true" tabindex="-1"></a>    tasks <span class="op">=</span> [<span class="dv">1000</span>] <span class="op">*</span> <span class="dv">8</span></span>
<span id="cb31-52"><a href="#cb31-52" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb31-53"><a href="#cb31-53" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Sequential</span></span>
<span id="cb31-54"><a href="#cb31-54" aria-hidden="true" tabindex="-1"></a>    monitor <span class="op">=</span> PerformanceMonitor()</span>
<span id="cb31-55"><a href="#cb31-55" aria-hidden="true" tabindex="-1"></a>    monitor.start_monitoring()</span>
<span id="cb31-56"><a href="#cb31-56" aria-hidden="true" tabindex="-1"></a>    sequential_results <span class="op">=</span> [cpu_intensive_task(n) <span class="cf">for</span> n <span class="kw">in</span> tasks]</span>
<span id="cb31-57"><a href="#cb31-57" aria-hidden="true" tabindex="-1"></a>    monitor.stop_monitoring()</span>
<span id="cb31-58"><a href="#cb31-58" aria-hidden="true" tabindex="-1"></a>    sequential_stats <span class="op">=</span> monitor.get_stats()</span>
<span id="cb31-59"><a href="#cb31-59" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb31-60"><a href="#cb31-60" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Threading</span></span>
<span id="cb31-61"><a href="#cb31-61" aria-hidden="true" tabindex="-1"></a>    monitor <span class="op">=</span> PerformanceMonitor()</span>
<span id="cb31-62"><a href="#cb31-62" aria-hidden="true" tabindex="-1"></a>    monitor.start_monitoring()</span>
<span id="cb31-63"><a href="#cb31-63" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> ThreadPoolExecutor(max_workers<span class="op">=</span><span class="dv">4</span>) <span class="im">as</span> executor:</span>
<span id="cb31-64"><a href="#cb31-64" aria-hidden="true" tabindex="-1"></a>        thread_results <span class="op">=</span> <span class="bu">list</span>(executor.<span class="bu">map</span>(cpu_intensive_task, tasks))</span>
<span id="cb31-65"><a href="#cb31-65" aria-hidden="true" tabindex="-1"></a>    monitor.stop_monitoring()</span>
<span id="cb31-66"><a href="#cb31-66" aria-hidden="true" tabindex="-1"></a>    thread_stats <span class="op">=</span> monitor.get_stats()</span>
<span id="cb31-67"><a href="#cb31-67" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb31-68"><a href="#cb31-68" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Multiprocessing</span></span>
<span id="cb31-69"><a href="#cb31-69" aria-hidden="true" tabindex="-1"></a>    monitor <span class="op">=</span> PerformanceMonitor()</span>
<span id="cb31-70"><a href="#cb31-70" aria-hidden="true" tabindex="-1"></a>    monitor.start_monitoring()</span>
<span id="cb31-71"><a href="#cb31-71" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> ProcessPoolExecutor(max_workers<span class="op">=</span><span class="dv">4</span>) <span class="im">as</span> executor:</span>
<span id="cb31-72"><a href="#cb31-72" aria-hidden="true" tabindex="-1"></a>        process_results <span class="op">=</span> <span class="bu">list</span>(executor.<span class="bu">map</span>(cpu_intensive_task, tasks))</span>
<span id="cb31-73"><a href="#cb31-73" aria-hidden="true" tabindex="-1"></a>    monitor.stop_monitoring()</span>
<span id="cb31-74"><a href="#cb31-74" aria-hidden="true" tabindex="-1"></a>    process_stats <span class="op">=</span> monitor.get_stats()</span>
<span id="cb31-75"><a href="#cb31-75" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb31-76"><a href="#cb31-76" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="st">"Performance Comparison:"</span>)</span>
<span id="cb31-77"><a href="#cb31-77" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Sequential - Duration: </span><span class="sc">{</span>sequential_stats[<span class="st">'duration'</span>]<span class="sc">:.2f}</span><span class="ss">s, "</span></span>
<span id="cb31-78"><a href="#cb31-78" aria-hidden="true" tabindex="-1"></a>          <span class="ss">f"Avg CPU: </span><span class="sc">{</span>sequential_stats[<span class="st">'avg_cpu'</span>]<span class="sc">:.1f}</span><span class="ss">%, "</span></span>
<span id="cb31-79"><a href="#cb31-79" aria-hidden="true" tabindex="-1"></a>          <span class="ss">f"Max CPU: </span><span class="sc">{</span>sequential_stats[<span class="st">'max_cpu'</span>]<span class="sc">:.1f}</span><span class="ss">%"</span>)</span>
<span id="cb31-80"><a href="#cb31-80" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb31-81"><a href="#cb31-81" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Threading - Duration: </span><span class="sc">{</span>thread_stats[<span class="st">'duration'</span>]<span class="sc">:.2f}</span><span class="ss">s, "</span></span>
<span id="cb31-82"><a href="#cb31-82" aria-hidden="true" tabindex="-1"></a>          <span class="ss">f"Avg CPU: </span><span class="sc">{</span>thread_stats[<span class="st">'avg_cpu'</span>]<span class="sc">:.1f}</span><span class="ss">%, "</span></span>
<span id="cb31-83"><a href="#cb31-83" aria-hidden="true" tabindex="-1"></a>          <span class="ss">f"Max CPU: </span><span class="sc">{</span>thread_stats[<span class="st">'max_cpu'</span>]<span class="sc">:.1f}</span><span class="ss">%"</span>)</span>
<span id="cb31-84"><a href="#cb31-84" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb31-85"><a href="#cb31-85" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Multiprocessing - Duration: </span><span class="sc">{</span>process_stats[<span class="st">'duration'</span>]<span class="sc">:.2f}</span><span class="ss">s, "</span></span>
<span id="cb31-86"><a href="#cb31-86" aria-hidden="true" tabindex="-1"></a>          <span class="ss">f"Avg CPU: </span><span class="sc">{</span>process_stats[<span class="st">'avg_cpu'</span>]<span class="sc">:.1f}</span><span class="ss">%, "</span></span>
<span id="cb31-87"><a href="#cb31-87" aria-hidden="true" tabindex="-1"></a>          <span class="ss">f"Max CPU: </span><span class="sc">{</span>process_stats[<span class="st">'max_cpu'</span>]<span class="sc">:.1f}</span><span class="ss">%"</span>)</span>
<span id="cb31-88"><a href="#cb31-88" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb31-89"><a href="#cb31-89" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb31-90"><a href="#cb31-90" aria-hidden="true" tabindex="-1"></a>    benchmark_approaches()</span></code></pre></div></div>
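<p><code>PerformanceMonitor</code> times each run with <code>time.time()</code>, which follows the system clock and can jump if that clock is adjusted; the standard library's <code>time.perf_counter()</code> is intended for measuring short durations. The following is a minimal sketch, assuming elapsed time is all you need, that wraps <code>perf_counter</code> in a context manager and applies it to an I/O-bound workload, where threads are expected to overlap the waiting. <code>io_bound_task</code> is a hypothetical stand-in that simply sleeps.</p>

```python
import time
from contextlib import contextmanager
from concurrent.futures import ThreadPoolExecutor

@contextmanager
def timed(label, timings):
    # Store the elapsed wall-clock seconds of the with-block in timings[label]
    start = time.perf_counter()
    try:
        yield
    finally:
        timings[label] = time.perf_counter() - start

def io_bound_task(delay):
    # Hypothetical stand-in for a network or disk call: the thread just waits
    time.sleep(delay)
    return delay

def compare_io_bound(delays=(0.2, 0.2, 0.2, 0.2)):
    timings = {}
    with timed("sequential", timings):
        for d in delays:
            io_bound_task(d)
    with timed("threads", timings):
        with ThreadPoolExecutor(max_workers=len(delays)) as executor:
            list(executor.map(io_bound_task, delays))
    return timings

if __name__ == "__main__":
    for label, seconds in compare_io_bound().items():
        print(f"{label}: {seconds:.2f}s")
```

<p>Because each task only sleeps, the threaded run should take roughly as long as the longest single task rather than the sum of all of them.</p>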
</section>
</section>
<section id="key-takeaways" class="level2">
<h2 class="anchored" data-anchor-id="key-takeaways" id="key-takeaways">Key Takeaways</h2>
<section id="when-to-use-threading" class="level3">
<h3 class="anchored" data-anchor-id="when-to-use-threading" id="when-to-use-threading">When to Use Threading</h3>
<ul>
<li>I/O-bound operations (file reading, network requests, database queries)</li>
<li>Tasks that spend time waiting for external resources</li>
<li>When you need shared memory access</li>
<li>When you want lower overhead, since threads are lighter weight than processes</li>
</ul>
</section>
<section id="when-to-use-multiprocessing" class="level3">
<h3 class="anchored" data-anchor-id="when-to-use-multiprocessing" id="when-to-use-multiprocessing">When to Use Multiprocessing</h3>
<ul>
<li>CPU-intensive computations</li>
<li>Tasks that can be parallelized independently</li>
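<li>When you need to bypass the Global Interpreter Lock (GIL), which in standard CPython builds lets only one thread execute Python bytecode at a time</li>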
<li>When process isolation is important for stability</li>
</ul>
</section>
<section id="general-best-practices" class="level3">
<h3 class="anchored" data-anchor-id="general-best-practices" id="general-best-practices">General Best Practices</h3>
<ul>
<li>Always use context managers (<code>with</code> statements) for resource management</li>
<li>Handle exceptions properly in concurrent code</li>
<li>Use appropriate synchronization primitives to avoid race conditions</li>
<li>Monitor performance to ensure concurrency is actually helping</li>
<li>Consider using <code>concurrent.futures</code> for simpler concurrent programming</li>
<li>Be mindful of the overhead of creating threads/processes</li>
<li>Test concurrent code thoroughly as bugs can be hard to reproduce</li>
</ul>
</section>
<section id="common-pitfalls-to-avoid" class="level3">
<h3 class="anchored" data-anchor-id="common-pitfalls-to-avoid" id="common-pitfalls-to-avoid">Common Pitfalls to Avoid</h3>
<ul>
<li>Race conditions due to shared state</li>
<li>Deadlocks from improper lock ordering</li>
<li>Memory leaks from not properly cleaning up processes</li>
<li>Pickling errors when passing non-picklable objects (such as lambdas, open file handles, or locks) between processes</li>
<li>Not handling exceptions in concurrent tasks</li>
<li>Creating too many threads/processes (use pools instead)</li>
</ul>
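<p>The deadlock pitfall above is usually avoided by acquiring locks in one consistent global order. The following is a minimal sketch, assuming a hypothetical <code>Account</code> class whose <code>name</code> serves as the ordering key. Opposite transfers run concurrently, and because both always lock the accounts in the same order, neither can hold one lock while waiting for the other.</p>

```python
import threading
from concurrent.futures import ThreadPoolExecutor

class Account:
    def __init__(self, name, balance):
        self.name = name          # used as the global lock-ordering key
        self.balance = balance
        self.lock = threading.Lock()

def transfer(src, dst, amount):
    # src and dst must be different accounts: threading.Lock is not reentrant.
    # Acquire both locks in one global order (sorted by name), so a transfer a->b
    # and a concurrent transfer b->a never each hold one lock and wait for the other.
    first, second = sorted((src, dst), key=lambda acct: acct.name)
    with first.lock:
        with second.lock:
            src.balance -= amount
            dst.balance += amount

def run_transfers(n=500):
    a, b = Account("a", 1000), Account("b", 1000)
    with ThreadPoolExecutor(max_workers=8) as executor:
        futures = [executor.submit(transfer, a, b, 1) for _ in range(n)]
        futures += [executor.submit(transfer, b, a, 1) for _ in range(n)]
    for future in futures:
        future.result()  # re-raise any exception from a worker
    return a.balance, b.balance

if __name__ == "__main__":
    print(run_transfers())  # (1000, 1000)
```

<p>If each transfer instead locked <code>src</code> first and then <code>dst</code>, a transfer from a to b and a simultaneous transfer from b to a could each grab one lock and wait forever.</p>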
<p>This guide provides a solid foundation for understanding and implementing concurrent programming in Python. Remember that the choice between threading and multiprocessing depends on your specific use case, and sometimes a hybrid approach or alternative solutions like asyncio might be more appropriate.</p>


</section>
</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Convolutional Kolmogorov-Arnold Networks: A Deep Dive into Next-Generation Neural Architectures]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/models/convkan/ckan-guide/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/models/convkan/ckan-guide/</guid>
      <pubDate>Sat, 05 Jul 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>tutorial</category>
      <category>advanced</category>
      <content:encoded><![CDATA[






<section id="convolutional-kolmogorov-arnold-networks-a-deep-dive-into-next-generation-neural-architectures" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/models/convkan/ckan-guide/ckan.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>Convolutional Kolmogorov-Arnold Networks (CKANs) combine a result from classical approximation theory, the Kolmogorov-Arnold representation theorem, with convolutional neural networks. This yields a different approach to function approximation, one that questions the convention of fixed activation functions.</p>
<p>Traditional neural networks rely on fixed activation functions (ReLU, sigmoid, tanh) applied to linear transformations. In contrast, CKANs replace these fixed activations with learnable univariate functions, creating a more flexible and theoretically grounded architecture that can potentially achieve superior approximation capabilities with fewer parameters.</p>
</section>
<section id="theoretical-foundation-the-kolmogorov-arnold-representation-theorem" class="level2">
<h2 class="anchored" data-anchor-id="theoretical-foundation-the-kolmogorov-arnold-representation-theorem" id="theoretical-foundation-the-kolmogorov-arnold-representation-theorem">Theoretical Foundation: The Kolmogorov-Arnold Representation Theorem</h2>
<p>The Kolmogorov-Arnold representation theorem, proved by Andrey Kolmogorov in 1957 and building on work by Vladimir Arnold, states that any continuous function of several variables on a bounded domain can be written as a finite superposition of continuous functions of a single variable and addition. Formally, for any continuous function f: [0,1]^n → ℝ, there exist continuous univariate functions Φ_q: ℝ → ℝ and φ_{q,p}: [0,1] → ℝ such that:</p>
<p><span class="math display">\[
f(x_1, x_2, \ldots, x_n) = \sum_{q=0}^{2n} \Phi_q\left( \sum_{p=1}^{n} \phi_{q,p}(x_p) \right)
\]</span></p>
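<p>The theorem shows that, in principle, any continuous multivariate function can be decomposed into univariate components that are combined only by addition, and this result motivates the KAN design. In practice, the inner functions it guarantees can be highly non-smooth, so KANs treat the theorem as inspiration rather than a recipe. They learn smooth univariate functions (typically splines) and stack more layers than the two levels of the formula.</p>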
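<p>A small worked instance of the superposition form is multiplication on [0,1]²: x·y = ((x + y)² - (x - y)²)/4. It is chosen for illustration and is not produced by the theorem's constructive proof. It uses two outer functions, Φ_0(s) = s²/4 and Φ_1(s) = -s²/4, with inner functions that are the identity or negation; the remaining outer terms in the sum are taken to be zero. The following sketch checks the identity numerically.</p>

```python
import math
import random

# Inner functions phi_{q,p} for q = 0, 1 (p = 1 acts on x, p = 2 acts on y)
inner = [
    (lambda x: x, lambda y: y),    # q = 0: phi_{0,1}(x) = x, phi_{0,2}(y) = y
    (lambda x: x, lambda y: -y),   # q = 1: phi_{1,1}(x) = x, phi_{1,2}(y) = -y
]
# Outer functions Phi_q
outer = [
    lambda s: s * s / 4,           # Phi_0(s) = s^2 / 4
    lambda s: -s * s / 4,          # Phi_1(s) = -s^2 / 4
]

def ka_superposition(x, y):
    # sum_q Phi_q( phi_{q,1}(x) + phi_{q,2}(y) ); all other terms are zero
    return sum(Phi(phi_x(x) + phi_y(y)) for Phi, (phi_x, phi_y) in zip(outer, inner))

if __name__ == "__main__":
    for _ in range(5):
        x, y = random.random(), random.random()
        assert math.isclose(ka_superposition(x, y), x * y, abs_tol=1e-12)
    print(ka_superposition(0.5, 0.25))  # 0.125
```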
</section>
<section id="architecture-design" class="level2">
<h2 class="anchored" data-anchor-id="architecture-design" id="architecture-design">Architecture Design</h2>
<section id="core-components" class="level3">
<h3 class="anchored" data-anchor-id="core-components" id="core-components">Core Components</h3>
<p>CKANs maintain the spatial processing capabilities of CNNs while incorporating KAN principles. The key architectural components include:</p>
<ol type="1">
<li><strong>Learnable Activation Functions</strong>: Replace traditional fixed activations with parameterized univariate functions</li>
<li><strong>Convolutional KAN Layers</strong>: Adapt KAN principles to work with spatial data</li>
<li><strong>Spline-based Function Approximation</strong>: Use B-splines or other basis functions to represent learnable activations</li>
<li><strong>Hierarchical Feature Extraction</strong>: Preserve the CNN’s ability to learn hierarchical representations</li>
</ol>
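<p>The following sketch makes the spline-based approximation in component 3 concrete. It is not the implementation used later in this post. It evaluates B-spline basis functions with the Cox-de Boor recursion on a uniform knot grid extended by <code>order</code> knots on each side, and a learnable activation is then a weighted sum of these bases, φ(x) = Σ_i c_i B_i(x). The helper names <code>extended_grid</code> and <code>bspline_basis</code> are hypothetical.</p>

```python
import torch

def extended_grid(num_intervals, order, low=-1.0, high=1.0):
    # Uniform knots on [low, high], padded with `order` extra knots on each side
    h = (high - low) / num_intervals
    steps = torch.arange(-order, num_intervals + order + 1, dtype=torch.float32)
    return low + steps * h

def bspline_basis(x, grid, order):
    # x: (N,) points; returns (N, len(grid) - order - 1) basis values via Cox-de Boor
    x = x.unsqueeze(-1)
    # Degree 0: indicator of the knot interval [t_i, t_{i+1}) that contains x
    bases = torch.logical_and(x.ge(grid[:-1]), x.lt(grid[1:])).to(x.dtype)
    for k in range(1, order + 1):
        left = (x - grid[:-(k + 1)]) / (grid[k:-1] - grid[:-(k + 1)]) * bases[:, :-1]
        right = (grid[k + 1:] - x) / (grid[k + 1:] - grid[1:-k]) * bases[:, 1:]
        bases = left + right
    return bases

if __name__ == "__main__":
    order, intervals = 3, 4
    grid = extended_grid(intervals, order)
    x = torch.linspace(-0.99, 0.99, 5)
    b = bspline_basis(x, grid, order)
    coeffs = torch.randn(b.shape[1])  # would be an nn.Parameter in a real layer
    phi_x = b @ coeffs                # the learnable activation evaluated at x
    print(b.shape)                    # torch.Size([5, 7])
    print(phi_x.shape)                # torch.Size([5])
    print(b.sum(dim=1))               # partition of unity: each row sums to 1 inside [-1, 1)
```

<p>With <code>num_intervals</code> intervals and degree <code>order</code>, there are <code>num_intervals + order</code> basis functions, so that is the number of coefficients each learnable activation needs.</p>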
</section>
<section id="convolutional-kan-layer-structure" class="level3">
<h3 class="anchored" data-anchor-id="convolutional-kan-layer-structure" id="convolutional-kan-layer-structure">Convolutional KAN Layer Structure</h3>
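<p>Fuller formulations of a convolutional KAN layer apply a learnable univariate function to every element of each kernel window before summing. A simpler variant follows a bias-free convolution with learnable, per-channel spline activations:</p>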
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a><span class="co"># SplineActivation is defined in the next section</span></span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> ConvKANLayer(nn.Module):</span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, in_channels, out_channels, kernel_size, grid_size<span class="op">=</span><span class="dv">5</span>):</span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Bias-free convolution: linear mixing of each input window</span></span>
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.conv <span class="op">=</span> nn.Conv2d(in_channels, out_channels, kernel_size, bias<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb1-9"><a href="#cb1-9" aria-hidden="true" tabindex="-1"></a>        <span class="co"># One learnable univariate function per output channel</span></span>
<span id="cb1-10"><a href="#cb1-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.spline_functions <span class="op">=</span> SplineActivation(out_channels, grid_size)</span>
<span id="cb1-11"><a href="#cb1-11" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-12"><a href="#cb1-12" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb1-13"><a href="#cb1-13" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Apply convolution without bias</span></span>
<span id="cb1-14"><a href="#cb1-14" aria-hidden="true" tabindex="-1"></a>        conv_out <span class="op">=</span> <span class="va">self</span>.conv(x)</span>
<span id="cb1-15"><a href="#cb1-15" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Apply learnable spline activations</span></span>
<span id="cb1-16"><a href="#cb1-16" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.spline_functions(conv_out)</span></code></pre></div></div>
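<p>The <code>ConvKANLayer</code> above applies its learnable function after an ordinary linear convolution. To bring the Kolmogorov-Arnold idea inside the convolution, a layer can give every (output channel, kernel-window element) pair its own learnable univariate function and sum the results over the window. The following is a minimal sketch of that idea under stated assumptions. The class name <code>KANConv2dSketch</code> is hypothetical, Gaussian radial basis functions replace B-splines to keep the code short, stride is fixed at 1, and inputs are assumed to lie roughly in [-1, 1].</p>

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class KANConv2dSketch(nn.Module):
    # Hypothetical layer: one learnable function per (output channel, window element),
    # phi(x) = sum_k c_k * exp(-((x - mu_k) / width) ** 2), summed over the window.
    # Gaussian RBFs are used here in place of B-splines for brevity.
    def __init__(self, in_channels, out_channels, kernel_size, num_basis=8, padding=0):
        super().__init__()
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding
        fan_in = in_channels * kernel_size * kernel_size
        # Fixed RBF centres spread over [-1, 1]
        self.register_buffer("centers", torch.linspace(-1.0, 1.0, num_basis))
        self.rbf_width = 2.0 / (num_basis - 1)
        # Learnable coefficients: (out_channels, fan_in, num_basis)
        self.coef = nn.Parameter(0.1 * torch.randn(out_channels, fan_in, num_basis))

    def forward(self, x):
        batch, _, in_h, in_w = x.shape
        # Sliding windows: (batch, fan_in, L), L = number of output positions
        patches = F.unfold(x, self.kernel_size, padding=self.padding)
        # Every basis function at every window element: (batch, fan_in, L, num_basis)
        basis = torch.exp(-((patches.unsqueeze(-1) - self.centers) / self.rbf_width) ** 2)
        # Apply phi for each (out, in) pair and sum over window elements: (batch, out, L)
        out = torch.einsum("bilk,oik->bol", basis, self.coef)
        out_h = in_h + 2 * self.padding - self.kernel_size + 1
        out_w = in_w + 2 * self.padding - self.kernel_size + 1
        return out.reshape(batch, self.out_channels, out_h, out_w)

if __name__ == "__main__":
    layer = KANConv2dSketch(in_channels=3, out_channels=4, kernel_size=3, padding=1)
    y = layer(torch.randn(2, 3, 8, 8))
    print(y.shape)  # torch.Size([2, 4, 8, 8])
```

<p>The intermediate <code>basis</code> tensor has size batch × fan_in × L × num_basis, so this direct form is memory-hungry for large feature maps. It is meant to show the structure, not to be an efficient implementation.</p>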
</section>
</section>
<section id="implementation-details" class="level2">
<h2 class="anchored" data-anchor-id="implementation-details" id="implementation-details">Implementation Details</h2>
<section id="spline-based-activation-functions" class="level3">
<h3 class="anchored" data-anchor-id="spline-based-activation-functions" id="spline-based-activation-functions">Spline-based Activation Functions</h3>
<p>The learnable activation functions are typically implemented using B-splines or other basis function expansions:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn.functional <span class="im">as</span> F</span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> SplineActivation(nn.Module):</span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, num_channels, grid_size<span class="op">=</span><span class="dv">5</span>, spline_order<span class="op">=</span><span class="dv">3</span>):</span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.grid_size <span class="op">=</span> grid_size</span>
<span id="cb2-10"><a href="#cb2-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.spline_order <span class="op">=</span> spline_order</span>
<span id="cb2-11"><a href="#cb2-11" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-12"><a href="#cb2-12" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Initialize grid points</span></span>
<span id="cb2-13"><a href="#cb2-13" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.register_buffer(<span class="st">'grid'</span>, torch.linspace(<span class="op">-</span><span class="dv">1</span>, <span class="dv">1</span>, grid_size))</span>
<span id="cb2-14"><a href="#cb2-14" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-15"><a href="#cb2-15" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Learnable coefficients per channel (the linear evaluator below reads only the first grid_size)</span></span>
<span id="cb2-16"><a href="#cb2-16" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.coefficients <span class="op">=</span> nn.Parameter(</span>
<span id="cb2-17"><a href="#cb2-17" aria-hidden="true" tabindex="-1"></a>            torch.randn(num_channels, grid_size <span class="op">+</span> spline_order)</span>
<span id="cb2-18"><a href="#cb2-18" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb2-19"><a href="#cb2-19" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-20"><a href="#cb2-20" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb2-21"><a href="#cb2-21" aria-hidden="true" tabindex="-1"></a>        batch_size, channels, height, width <span class="op">=</span> x.shape</span>
<span id="cb2-22"><a href="#cb2-22" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-23"><a href="#cb2-23" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Reshape for spline evaluation</span></span>
<span id="cb2-24"><a href="#cb2-24" aria-hidden="true" tabindex="-1"></a>        x_flat <span class="op">=</span> x.reshape(batch_size, channels, <span class="op">-</span><span class="dv">1</span>)</span>
<span id="cb2-25"><a href="#cb2-25" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-26"><a href="#cb2-26" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Apply spline activation channel-wise</span></span>
<span id="cb2-27"><a href="#cb2-27" aria-hidden="true" tabindex="-1"></a>        activated <span class="op">=</span> torch.zeros_like(x_flat)</span>
<span id="cb2-28"><a href="#cb2-28" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-29"><a href="#cb2-29" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> c <span class="kw">in</span> <span class="bu">range</span>(channels):</span>
<span id="cb2-30"><a href="#cb2-30" aria-hidden="true" tabindex="-1"></a>            activated[:, c, :] <span class="op">=</span> <span class="va">self</span>.evaluate_spline(</span>
<span id="cb2-31"><a href="#cb2-31" aria-hidden="true" tabindex="-1"></a>                x_flat[:, c, :], <span class="va">self</span>.coefficients[c]</span>
<span id="cb2-32"><a href="#cb2-32" aria-hidden="true" tabindex="-1"></a>            )</span>
<span id="cb2-33"><a href="#cb2-33" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-34"><a href="#cb2-34" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> activated.view(batch_size, channels, height, width)</span>
<span id="cb2-35"><a href="#cb2-35" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-36"><a href="#cb2-36" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> evaluate_spline(<span class="va">self</span>, x, coeffs):</span>
<span id="cb2-37"><a href="#cb2-37" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Dispatch to the simplified spline evaluator below</span></span>
<span id="cb2-38"><a href="#cb2-38" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.de_boor_algorithm(x, coeffs)</span>
<span id="cb2-39"><a href="#cb2-39" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-40"><a href="#cb2-40" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> de_boor_algorithm(<span class="va">self</span>, x, coeffs):</span>
<span id="cb2-41"><a href="#cb2-41" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Simplified stand-in: despite the name, this is not de Boor's algorithm</span></span>
<span id="cb2-42"><a href="#cb2-42" aria-hidden="true" tabindex="-1"></a>        <span class="co"># but an order-1 (piecewise-linear) spline; use an optimized B-spline routine in practice</span></span>
<span id="cb2-43"><a href="#cb2-43" aria-hidden="true" tabindex="-1"></a>        x_clamped <span class="op">=</span> torch.clamp(x, <span class="op">-</span><span class="dv">1</span>, <span class="dv">1</span>)</span>
<span id="cb2-44"><a href="#cb2-44" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-45"><a href="#cb2-45" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Linear interpolation for simplicity (extend to higher orders)</span></span>
<span id="cb2-46"><a href="#cb2-46" aria-hidden="true" tabindex="-1"></a>        grid_indices <span class="op">=</span> torch.searchsorted(<span class="va">self</span>.grid, x_clamped)</span>
<span id="cb2-47"><a href="#cb2-47" aria-hidden="true" tabindex="-1"></a>        grid_indices <span class="op">=</span> torch.clamp(grid_indices, <span class="dv">1</span>, <span class="bu">len</span>(<span class="va">self</span>.grid) <span class="op">-</span> <span class="dv">1</span>)</span>
<span id="cb2-48"><a href="#cb2-48" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-49"><a href="#cb2-49" aria-hidden="true" tabindex="-1"></a>        x0 <span class="op">=</span> <span class="va">self</span>.grid[grid_indices <span class="op">-</span> <span class="dv">1</span>]</span>
<span id="cb2-50"><a href="#cb2-50" aria-hidden="true" tabindex="-1"></a>        x1 <span class="op">=</span> <span class="va">self</span>.grid[grid_indices]</span>
<span id="cb2-51"><a href="#cb2-51" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-52"><a href="#cb2-52" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Linear interpolation weights</span></span>
<span id="cb2-53"><a href="#cb2-53" aria-hidden="true" tabindex="-1"></a>        w1 <span class="op">=</span> (x_clamped <span class="op">-</span> x0) <span class="op">/</span> (x1 <span class="op">-</span> x0)</span>
<span id="cb2-54"><a href="#cb2-54" aria-hidden="true" tabindex="-1"></a>        w0 <span class="op">=</span> <span class="dv">1</span> <span class="op">-</span> w1</span>
<span id="cb2-55"><a href="#cb2-55" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-56"><a href="#cb2-56" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Interpolate coefficients</span></span>
<span id="cb2-57"><a href="#cb2-57" aria-hidden="true" tabindex="-1"></a>        y0 <span class="op">=</span> coeffs[grid_indices <span class="op">-</span> <span class="dv">1</span>]</span>
<span id="cb2-58"><a href="#cb2-58" aria-hidden="true" tabindex="-1"></a>        y1 <span class="op">=</span> coeffs[grid_indices]</span>
<span id="cb2-59"><a href="#cb2-59" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-60"><a href="#cb2-60" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> w0 <span class="op">*</span> y0 <span class="op">+</span> w1 <span class="op">*</span> y1</span></code></pre></div></div>
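<p>For reference, here is a minimal sketch of genuine order-<em>k</em> B-spline evaluation using the Cox-de Boor recursion, vectorized over channels instead of looping. The helper names (<code>extended_grid</code>, <code>bspline_basis</code>, <code>spline_activation</code>) are illustrative and do not come from any library. The sketch assumes a uniform grid of <code>grid_size</code> intervals on [-1, 1], padded with <em>k</em> extra knots on each side. That yields <code>grid_size + k</code> basis functions, which is the same count as the <code>grid_size + spline_order</code> coefficients allocated above if <code>grid_size</code> is read as a number of intervals.</p>

```python
import torch


def extended_grid(grid_size, k, lo=-1.0, hi=1.0):
    """Hypothetical helper: uniform knots with grid_size intervals on [lo, hi],
    padded with k extra knots on each side (grid_size + 2k + 1 knots total)."""
    h = (hi - lo) / grid_size
    return torch.arange(-k, grid_size + k + 1, dtype=torch.float32) * h + lo


def bspline_basis(x, grid, k):
    """Evaluate all order-k B-spline basis functions at x (Cox-de Boor recursion).

    Returns a tensor of shape x.shape + (len(grid) - k - 1,).
    """
    x = x.unsqueeze(-1)
    # Order 0: indicator of the half-open knot interval [t_i, t_{i+1})
    bases = ((x >= grid[:-1]) & (x < grid[1:])).to(x.dtype)
    for p in range(1, k + 1):
        # B_{i,p} = (x - t_i) / (t_{i+p} - t_i) * B_{i,p-1}
        #         + (t_{i+p+1} - x) / (t_{i+p+1} - t_{i+1}) * B_{i+1,p-1}
        left = (x - grid[: -(p + 1)]) / (grid[p:-1] - grid[: -(p + 1)])
        right = (grid[p + 1:] - x) / (grid[p + 1:] - grid[1:-p])
        bases = left * bases[..., :-1] + right * bases[..., 1:]
    return bases


def spline_activation(x, grid, coeffs, k):
    """Hypothetical per-channel spline: x is (batch, channels, ...) and
    coeffs is (channels, n_basis); each channel is a weighted sum of bases."""
    basis = bspline_basis(x, grid, k)  # (batch, channels, ..., n_basis)
    # Reshape coeffs to (channels, 1, ..., 1, n_basis) to broadcast over spatial dims
    c = coeffs.view(coeffs.shape[0], *([1] * (x.dim() - 2)), coeffs.shape[1])
    return (basis * c).sum(-1)


grid_size, k = 5, 3
grid = extended_grid(grid_size, k)                  # 12 knots
coeffs = torch.randn(4, grid_size + k, requires_grad=True)
x = torch.rand(2, 4, 8, 8) * 2 - 1                 # values in [-1, 1)
y = spline_activation(x, grid, coeffs, k)
print(y.shape)  # torch.Size([2, 4, 8, 8])
```

<p>Inputs are assumed to lie in [-1, 1), where the basis functions sum to one. Outside that range they no longer do, and they vanish entirely beyond the padded knots, so in practice inputs would be clamped or normalized first, as the simplified layer above does with <code>torch.clamp</code>.</p>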
</section>
<section id="complete-ckan-architecture" class="level3">
<h3 class="anchored" data-anchor-id="complete-ckan-architecture" id="complete-ckan-architecture">Complete CKAN Architecture</h3>
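<p>Here’s a compact implementation of a CKAN for image classification. It pairs standard convolutions with the per-channel <code>SplineActivation</code> defined above and uses a 1D variant, <code>SplineActivation1D</code>, in the classification head:</p>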
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> ConvKANBlock(nn.Module):</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, in_channels, out_channels, kernel_size<span class="op">=</span><span class="dv">3</span>, </span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a>                 stride<span class="op">=</span><span class="dv">1</span>, padding<span class="op">=</span><span class="dv">1</span>, grid_size<span class="op">=</span><span class="dv">5</span>):</span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.conv <span class="op">=</span> nn.Conv2d(</span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a>            in_channels, out_channels, kernel_size, </span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a>            stride<span class="op">=</span>stride, padding<span class="op">=</span>padding, bias<span class="op">=</span><span class="va">False</span></span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.spline_activation <span class="op">=</span> SplineActivation(out_channels, grid_size)</span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.batch_norm <span class="op">=</span> nn.BatchNorm2d(out_channels)</span>
<span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb3-15"><a href="#cb3-15" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.conv(x)</span>
<span id="cb3-16"><a href="#cb3-16" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.spline_activation(x)</span>
<span id="cb3-17"><a href="#cb3-17" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.batch_norm(x)</span>
<span id="cb3-18"><a href="#cb3-18" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x</span>
<span id="cb3-19"><a href="#cb3-19" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-20"><a href="#cb3-20" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> CKAN(nn.Module):</span>
<span id="cb3-21"><a href="#cb3-21" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, num_classes<span class="op">=</span><span class="dv">10</span>, grid_size<span class="op">=</span><span class="dv">5</span>):</span>
<span id="cb3-22"><a href="#cb3-22" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb3-23"><a href="#cb3-23" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-24"><a href="#cb3-24" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Feature extraction layers</span></span>
<span id="cb3-25"><a href="#cb3-25" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.conv1 <span class="op">=</span> ConvKANBlock(<span class="dv">3</span>, <span class="dv">64</span>, kernel_size<span class="op">=</span><span class="dv">7</span>, stride<span class="op">=</span><span class="dv">2</span>, padding<span class="op">=</span><span class="dv">3</span>)</span>
<span id="cb3-26"><a href="#cb3-26" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.pool1 <span class="op">=</span> nn.MaxPool2d(<span class="dv">2</span>, <span class="dv">2</span>)</span>
<span id="cb3-27"><a href="#cb3-27" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-28"><a href="#cb3-28" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.conv2 <span class="op">=</span> ConvKANBlock(<span class="dv">64</span>, <span class="dv">128</span>, kernel_size<span class="op">=</span><span class="dv">5</span>, padding<span class="op">=</span><span class="dv">2</span>)</span>
<span id="cb3-29"><a href="#cb3-29" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.pool2 <span class="op">=</span> nn.MaxPool2d(<span class="dv">2</span>, <span class="dv">2</span>)</span>
<span id="cb3-30"><a href="#cb3-30" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-31"><a href="#cb3-31" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.conv3 <span class="op">=</span> ConvKANBlock(<span class="dv">128</span>, <span class="dv">256</span>, kernel_size<span class="op">=</span><span class="dv">3</span>, padding<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb3-32"><a href="#cb3-32" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.conv4 <span class="op">=</span> ConvKANBlock(<span class="dv">256</span>, <span class="dv">256</span>, kernel_size<span class="op">=</span><span class="dv">3</span>, padding<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb3-33"><a href="#cb3-33" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.pool3 <span class="op">=</span> nn.MaxPool2d(<span class="dv">2</span>, <span class="dv">2</span>)</span>
<span id="cb3-34"><a href="#cb3-34" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-35"><a href="#cb3-35" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.conv5 <span class="op">=</span> ConvKANBlock(<span class="dv">256</span>, <span class="dv">512</span>, kernel_size<span class="op">=</span><span class="dv">3</span>, padding<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb3-36"><a href="#cb3-36" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.conv6 <span class="op">=</span> ConvKANBlock(<span class="dv">512</span>, <span class="dv">512</span>, kernel_size<span class="op">=</span><span class="dv">3</span>, padding<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb3-37"><a href="#cb3-37" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.pool4 <span class="op">=</span> nn.MaxPool2d(<span class="dv">2</span>, <span class="dv">2</span>)</span>
<span id="cb3-38"><a href="#cb3-38" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-39"><a href="#cb3-39" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Global average pooling</span></span>
<span id="cb3-40"><a href="#cb3-40" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.global_pool <span class="op">=</span> nn.AdaptiveAvgPool2d((<span class="dv">1</span>, <span class="dv">1</span>))</span>
<span id="cb3-41"><a href="#cb3-41" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-42"><a href="#cb3-42" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Classification head: linear layers with a learnable spline activation</span></span>
<span id="cb3-43"><a href="#cb3-43" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.classifier <span class="op">=</span> nn.Sequential(</span>
<span id="cb3-44"><a href="#cb3-44" aria-hidden="true" tabindex="-1"></a>            nn.Linear(<span class="dv">512</span>, <span class="dv">256</span>),</span>
<span id="cb3-45"><a href="#cb3-45" aria-hidden="true" tabindex="-1"></a>            SplineActivation1D(<span class="dv">256</span>, grid_size),</span>
<span id="cb3-46"><a href="#cb3-46" aria-hidden="true" tabindex="-1"></a>            nn.Dropout(<span class="fl">0.5</span>),</span>
<span id="cb3-47"><a href="#cb3-47" aria-hidden="true" tabindex="-1"></a>            nn.Linear(<span class="dv">256</span>, num_classes)</span>
<span id="cb3-48"><a href="#cb3-48" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb3-49"><a href="#cb3-49" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-50"><a href="#cb3-50" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb3-51"><a href="#cb3-51" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Feature extraction</span></span>
<span id="cb3-52"><a href="#cb3-52" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.conv1(x)</span>
<span id="cb3-53"><a href="#cb3-53" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.pool1(x)</span>
<span id="cb3-54"><a href="#cb3-54" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-55"><a href="#cb3-55" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.conv2(x)</span>
<span id="cb3-56"><a href="#cb3-56" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.pool2(x)</span>
<span id="cb3-57"><a href="#cb3-57" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-58"><a href="#cb3-58" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.conv3(x)</span>
<span id="cb3-59"><a href="#cb3-59" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.conv4(x)</span>
<span id="cb3-60"><a href="#cb3-60" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.pool3(x)</span>
<span id="cb3-61"><a href="#cb3-61" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-62"><a href="#cb3-62" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.conv5(x)</span>
<span id="cb3-63"><a href="#cb3-63" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.conv6(x)</span>
<span id="cb3-64"><a href="#cb3-64" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.pool4(x)</span>
<span id="cb3-65"><a href="#cb3-65" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-66"><a href="#cb3-66" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Global pooling and classification</span></span>
<span id="cb3-67"><a href="#cb3-67" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.global_pool(x)</span>
<span id="cb3-68"><a href="#cb3-68" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> x.view(x.size(<span class="dv">0</span>), <span class="op">-</span><span class="dv">1</span>)</span>
<span id="cb3-69"><a href="#cb3-69" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.classifier(x)</span>
<span id="cb3-70"><a href="#cb3-70" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-71"><a href="#cb3-71" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x</span>
<span id="cb3-72"><a href="#cb3-72" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-73"><a href="#cb3-73" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> SplineActivation1D(nn.Module):</span>
<span id="cb3-74"><a href="#cb3-74" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""1D version for fully connected layers"""</span></span>
<span id="cb3-75"><a href="#cb3-75" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, num_features, grid_size<span class="op">=</span><span class="dv">5</span>):</span>
<span id="cb3-76"><a href="#cb3-76" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb3-77"><a href="#cb3-77" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.grid_size <span class="op">=</span> grid_size</span>
<span id="cb3-78"><a href="#cb3-78" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.register_buffer(<span class="st">'grid'</span>, torch.linspace(<span class="op">-</span><span class="dv">2</span>, <span class="dv">2</span>, grid_size))</span>
<span id="cb3-79"><a href="#cb3-79" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.coefficients <span class="op">=</span> nn.Parameter(torch.randn(num_features, grid_size))</span>
<span id="cb3-80"><a href="#cb3-80" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-81"><a href="#cb3-81" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb3-82"><a href="#cb3-82" aria-hidden="true" tabindex="-1"></a>        batch_size, features <span class="op">=</span> x.shape</span>
<span id="cb3-83"><a href="#cb3-83" aria-hidden="true" tabindex="-1"></a>        activated <span class="op">=</span> torch.zeros_like(x)</span>
<span id="cb3-84"><a href="#cb3-84" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-85"><a href="#cb3-85" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> f <span class="kw">in</span> <span class="bu">range</span>(features):</span>
<span id="cb3-86"><a href="#cb3-86" aria-hidden="true" tabindex="-1"></a>            activated[:, f] <span class="op">=</span> <span class="va">self</span>.evaluate_spline_1d(x[:, f], <span class="va">self</span>.coefficients[f])</span>
<span id="cb3-87"><a href="#cb3-87" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-88"><a href="#cb3-88" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> activated</span>
<span id="cb3-89"><a href="#cb3-89" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-90"><a href="#cb3-90" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> evaluate_spline_1d(<span class="va">self</span>, x, coeffs):</span>
<span id="cb3-91"><a href="#cb3-91" aria-hidden="true" tabindex="-1"></a>        x_clamped <span class="op">=</span> torch.clamp(x, <span class="op">-</span><span class="dv">2</span>, <span class="dv">2</span>)</span>
<span id="cb3-92"><a href="#cb3-92" aria-hidden="true" tabindex="-1"></a>        grid_indices <span class="op">=</span> torch.searchsorted(<span class="va">self</span>.grid, x_clamped)</span>
<span id="cb3-93"><a href="#cb3-93" aria-hidden="true" tabindex="-1"></a>        grid_indices <span class="op">=</span> torch.clamp(grid_indices, <span class="dv">1</span>, <span class="bu">len</span>(<span class="va">self</span>.grid) <span class="op">-</span> <span class="dv">1</span>)</span>
<span id="cb3-94"><a href="#cb3-94" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-95"><a href="#cb3-95" aria-hidden="true" tabindex="-1"></a>        x0 <span class="op">=</span> <span class="va">self</span>.grid[grid_indices <span class="op">-</span> <span class="dv">1</span>]</span>
<span id="cb3-96"><a href="#cb3-96" aria-hidden="true" tabindex="-1"></a>        x1 <span class="op">=</span> <span class="va">self</span>.grid[grid_indices]</span>
<span id="cb3-97"><a href="#cb3-97" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-98"><a href="#cb3-98" aria-hidden="true" tabindex="-1"></a>        w1 <span class="op">=</span> (x_clamped <span class="op">-</span> x0) <span class="op">/</span> (x1 <span class="op">-</span> x0)</span>
<span id="cb3-99"><a href="#cb3-99" aria-hidden="true" tabindex="-1"></a>        w0 <span class="op">=</span> <span class="dv">1</span> <span class="op">-</span> w1</span>
<span id="cb3-100"><a href="#cb3-100" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-101"><a href="#cb3-101" aria-hidden="true" tabindex="-1"></a>        y0 <span class="op">=</span> coeffs[grid_indices <span class="op">-</span> <span class="dv">1</span>]</span>
<span id="cb3-102"><a href="#cb3-102" aria-hidden="true" tabindex="-1"></a>        y1 <span class="op">=</span> coeffs[grid_indices]</span>
<span id="cb3-103"><a href="#cb3-103" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-104"><a href="#cb3-104" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> w0 <span class="op">*</span> y0 <span class="op">+</span> w1 <span class="op">*</span> y1</span></code></pre></div></div>
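<p>Since <code>conv1</code> has stride 2 and the network applies four 2×2 max-pools, each spatial side shrinks by roughly a factor of 32 before global pooling. The small helper below (hypothetical, not part of the model) traces the feature-map size layer by layer. It applies the standard <code>Conv2d</code>/<code>MaxPool2d</code> output-size formula to the kernel, stride, and padding values from <code>CKAN.__init__</code>.</p>

```python
# Hypothetical helper data: (name, kernel, stride, padding) for each conv and
# pool layer in CKAN.forward, copied from CKAN.__init__ above
CKAN_LAYERS = [
    ("conv1", 7, 2, 3), ("pool1", 2, 2, 0),
    ("conv2", 5, 1, 2), ("pool2", 2, 2, 0),
    ("conv3", 3, 1, 1), ("conv4", 3, 1, 1), ("pool3", 2, 2, 0),
    ("conv5", 3, 1, 1), ("conv6", 3, 1, 1), ("pool4", 2, 2, 0),
]


def conv_out(size, kernel, stride, padding):
    """Output side length of Conv2d / MaxPool2d (dilation=1, ceil_mode=False)."""
    return (size + 2 * padding - kernel) // stride + 1


def trace_spatial(size, layers=CKAN_LAYERS):
    """Return [(layer_name, output_side), ...] for a square input of side `size`."""
    trace = []
    for name, kernel, stride, padding in layers:
        size = conv_out(size, kernel, stride, padding)
        if size < 1:
            raise ValueError(f"input too small: {name} would output {size}x{size}")
        trace.append((name, size))
    return trace


for name, side in trace_spatial(32):
    print(f"{name}: {side}x{side}")
```

<p>For 32×32 inputs such as CIFAR-10 images, the last pool produces a 1×1 map. For 224×224 inputs it produces 7×7, which <code>AdaptiveAvgPool2d((1, 1))</code> then averages down. Any input smaller than 31×31 shrinks to nothing at <code>pool4</code>, and PyTorch raises an error there.</p>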
</section>
</section>
<section id="training-considerations" class="level2">
<h2 class="anchored" data-anchor-id="training-considerations" id="training-considerations">Training Considerations</h2>
<section id="optimization-challenges" class="level3">
<h3 class="anchored" data-anchor-id="optimization-challenges" id="optimization-challenges">Optimization Challenges</h3>
<p>Training CKANs raises challenges that networks with fixed activations such as ReLU do not face:</p>
<ol type="1">
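<li><strong>Spline Coefficient Initialization</strong>: The initial coefficients set the starting shape of every activation. Drawing them from <code>torch.randn</code>, as the layers above do, starts each activation as a random, jagged curve.</li>
<li><strong>Learning Rate Scheduling</strong>: Spline coefficients may need a different learning rate from the convolution weights. The trainer below uses 1e-2 for the coefficients and 1e-3 for all other parameters, each with its own <code>StepLR</code> schedule.</li>
<li><strong>Regularization</strong>: Penalizing rough splines keeps the learned activations smooth and can help reduce overfitting.</li>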
</ol>
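<p>The first and third points can be sketched concretely under a few assumptions. The hypothetical helpers below assume one coefficient per grid point with piecewise-linear evaluation, as in <code>SplineActivation1D</code>. <code>identity_init_</code> sets each channel's coefficients to the grid values, so every activation starts as y = x on the grid range instead of as a random curve. <code>second_difference_penalty</code> is one common smoothness measure, the squared discrete second derivative of the coefficients. It is not necessarily the regularizer the trainer below computes.</p>

```python
import torch


def identity_init_(coefficients, grid):
    """Hypothetical helper: start each piecewise-linear spline as y = x.

    Assumes coefficients has shape (channels, len(grid)), one value per grid point.
    """
    with torch.no_grad():
        coefficients.copy_(grid.expand_as(coefficients))
    return coefficients


def second_difference_penalty(coefficients):
    """Hypothetical helper: mean squared second difference per channel.

    Zero for straight-line coefficient sequences, larger for jagged ones.
    """
    d2 = coefficients[:, 2:] - 2 * coefficients[:, 1:-1] + coefficients[:, :-2]
    return d2.pow(2).mean()


# Usage on a toy coefficient tensor (8 channels, 5 grid points on [-1, 1])
grid = torch.linspace(-1, 1, 5)
coefficients = torch.nn.Parameter(torch.randn(8, 5))
print(second_difference_penalty(coefficients).item())  # random init: positive
identity_init_(coefficients, grid)
print(second_difference_penalty(coefficients).item())  # identity init: 0.0
```

<p>The penalty is zero for any straight-line coefficient sequence, including the identity initialization, and grows as neighbouring coefficients zig-zag. It would be added to the task loss with a small weight, much like the <code>0.001 * spline_reg</code> term in <code>train_epoch</code>.</p>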
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> CKANTrainer:</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, model, device):</span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> model.to(device)</span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.device <span class="op">=</span> device</span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Separate optimizers for different parameter types</span></span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a>        conv_params <span class="op">=</span> []</span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a>        spline_params <span class="op">=</span> []</span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> name, param <span class="kw">in</span> model.named_parameters():</span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="st">'coefficients'</span> <span class="kw">in</span> name:</span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a>                spline_params.append(param)</span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a>            <span class="cf">else</span>:</span>
<span id="cb4-14"><a href="#cb4-14" aria-hidden="true" tabindex="-1"></a>                conv_params.append(param)</span>
<span id="cb4-15"><a href="#cb4-15" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-16"><a href="#cb4-16" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.conv_optimizer <span class="op">=</span> torch.optim.Adam(conv_params, lr<span class="op">=</span><span class="fl">1e-3</span>)</span>
<span id="cb4-17"><a href="#cb4-17" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.spline_optimizer <span class="op">=</span> torch.optim.Adam(spline_params, lr<span class="op">=</span><span class="fl">1e-2</span>)</span>
<span id="cb4-18"><a href="#cb4-18" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-19"><a href="#cb4-19" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.criterion <span class="op">=</span> nn.CrossEntropyLoss()</span>
<span id="cb4-20"><a href="#cb4-20" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.scheduler_conv <span class="op">=</span> torch.optim.lr_scheduler.StepLR(</span>
<span id="cb4-21"><a href="#cb4-21" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.conv_optimizer, step_size<span class="op">=</span><span class="dv">30</span>, gamma<span class="op">=</span><span class="fl">0.1</span></span>
<span id="cb4-22"><a href="#cb4-22" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb4-23"><a href="#cb4-23" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.scheduler_spline <span class="op">=</span> torch.optim.lr_scheduler.StepLR(</span>
<span id="cb4-24"><a href="#cb4-24" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.spline_optimizer, step_size<span class="op">=</span><span class="dv">30</span>, gamma<span class="op">=</span><span class="fl">0.1</span></span>
<span id="cb4-25"><a href="#cb4-25" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb4-26"><a href="#cb4-26" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-27"><a href="#cb4-27" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> train_epoch(<span class="va">self</span>, dataloader):</span>
<span id="cb4-28"><a href="#cb4-28" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model.train()</span>
<span id="cb4-29"><a href="#cb4-29" aria-hidden="true" tabindex="-1"></a>        total_loss <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb4-30"><a href="#cb4-30" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-31"><a href="#cb4-31" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> batch_idx, (data, target) <span class="kw">in</span> <span class="bu">enumerate</span>(dataloader):</span>
<span id="cb4-32"><a href="#cb4-32" aria-hidden="true" tabindex="-1"></a>            data, target <span class="op">=</span> data.to(<span class="va">self</span>.device), target.to(<span class="va">self</span>.device)</span>
<span id="cb4-33"><a href="#cb4-33" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb4-34"><a href="#cb4-34" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Zero gradients</span></span>
<span id="cb4-35"><a href="#cb4-35" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.conv_optimizer.zero_grad()</span>
<span id="cb4-36"><a href="#cb4-36" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.spline_optimizer.zero_grad()</span>
<span id="cb4-37"><a href="#cb4-37" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb4-38"><a href="#cb4-38" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Forward pass</span></span>
<span id="cb4-39"><a href="#cb4-39" aria-hidden="true" tabindex="-1"></a>            output <span class="op">=</span> <span class="va">self</span>.model(data)</span>
<span id="cb4-40"><a href="#cb4-40" aria-hidden="true" tabindex="-1"></a>            loss <span class="op">=</span> <span class="va">self</span>.criterion(output, target)</span>
<span id="cb4-41"><a href="#cb4-41" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb4-42"><a href="#cb4-42" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Add spline smoothness regularization</span></span>
<span id="cb4-43"><a href="#cb4-43" aria-hidden="true" tabindex="-1"></a>            spline_reg <span class="op">=</span> <span class="va">self</span>.compute_spline_regularization()</span>
<span id="cb4-44"><a href="#cb4-44" aria-hidden="true" tabindex="-1"></a>            total_loss_with_reg <span class="op">=</span> loss <span class="op">+</span> <span class="fl">0.001</span> <span class="op">*</span> spline_reg</span>
<span id="cb4-45"><a href="#cb4-45" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb4-46"><a href="#cb4-46" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Backward pass</span></span>
<span id="cb4-47"><a href="#cb4-47" aria-hidden="true" tabindex="-1"></a>            total_loss_with_reg.backward()</span>
<span id="cb4-48"><a href="#cb4-48" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb4-49"><a href="#cb4-49" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Optimize</span></span>
<span id="cb4-50"><a href="#cb4-50" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.conv_optimizer.step()</span>
<span id="cb4-51"><a href="#cb4-51" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.spline_optimizer.step()</span>
<span id="cb4-52"><a href="#cb4-52" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb4-53"><a href="#cb4-53" aria-hidden="true" tabindex="-1"></a>            total_loss <span class="op">+=</span> loss.item()</span>
<span id="cb4-54"><a href="#cb4-54" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-55"><a href="#cb4-55" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> total_loss <span class="op">/</span> <span class="bu">len</span>(dataloader)</span>
<span id="cb4-56"><a href="#cb4-56" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-57"><a href="#cb4-57" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> compute_spline_regularization(<span class="va">self</span>):</span>
<span id="cb4-58"><a href="#cb4-58" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Compute smoothness regularization for spline functions"""</span></span>
<span id="cb4-59"><a href="#cb4-59" aria-hidden="true" tabindex="-1"></a>        reg_loss <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb4-60"><a href="#cb4-60" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> module <span class="kw">in</span> <span class="va">self</span>.model.modules():</span>
<span id="cb4-61"><a href="#cb4-61" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="bu">isinstance</span>(module, (SplineActivation, SplineActivation1D)):</span>
<span id="cb4-62"><a href="#cb4-62" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Second derivative approximation for smoothness</span></span>
<span id="cb4-63"><a href="#cb4-63" aria-hidden="true" tabindex="-1"></a>                coeffs <span class="op">=</span> module.coefficients</span>
<span id="cb4-64"><a href="#cb4-64" aria-hidden="true" tabindex="-1"></a>                second_deriv <span class="op">=</span> coeffs[:, <span class="dv">2</span>:] <span class="op">-</span> <span class="dv">2</span> <span class="op">*</span> coeffs[:, <span class="dv">1</span>:<span class="op">-</span><span class="dv">1</span>] <span class="op">+</span> coeffs[:, :<span class="op">-</span><span class="dv">2</span>]</span>
<span id="cb4-65"><a href="#cb4-65" aria-hidden="true" tabindex="-1"></a>                reg_loss <span class="op">+=</span> torch.mean(second_deriv <span class="op">**</span> <span class="dv">2</span>)</span>
<span id="cb4-66"><a href="#cb4-66" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> reg_loss</span></code></pre></div></div>
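<p>The smoothness term in <code>compute_spline_regularization</code> penalizes the discrete second difference <code>c[i+2] - 2*c[i+1] + c[i]</code> of neighbouring coefficients. A standalone NumPy version shows how it behaves: it is zero when the coefficients lie on a straight line and grows as they oscillate. The name <code>second_difference_penalty</code> is hypothetical, and reading the term as a curvature penalty assumes evenly spaced grid points, which the slicing in the method implicitly does.</p>

```python
import numpy as np

def second_difference_penalty(coeffs):
    """Mean squared second difference along the last axis.

    Mirrors the slicing in compute_spline_regularization:
    c[i+2] - 2*c[i+1] + c[i] is a discrete second derivative.
    """
    coeffs = np.atleast_2d(np.asarray(coeffs, dtype=float))
    second_deriv = coeffs[:, 2:] - 2 * coeffs[:, 1:-1] + coeffs[:, :-2]
    return float(np.mean(second_deriv ** 2))

linear = np.array([[0.0, 1.0, 2.0, 3.0, 4.0]])  # coefficients on a straight line
wiggly = np.array([[0.0, 1.0, 0.0, 1.0, 0.0]])  # alternating coefficients

print(second_difference_penalty(linear))  # 0.0
print(second_difference_penalty(wiggly))  # 4.0
```

<p>The weight <code>0.001</code> in <code>train_epoch</code> trades this smoothness against the classification loss and is a tunable hyperparameter.</p>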
</section>
</section>
<section id="performance-analysis" class="level2">
<h2 class="anchored" data-anchor-id="performance-analysis" id="performance-analysis">Performance Analysis</h2>
<section id="theoretical-advantages" class="level3">
<h3 class="anchored" data-anchor-id="theoretical-advantages" id="theoretical-advantages">Theoretical Advantages</h3>
<p>CKANs offer several theoretical advantages:</p>
<ol type="1">
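<li><strong>Universal Approximation</strong>: The Kolmogorov-Arnold theorem guarantees that any continuous function on the unit cube can be represented as a finite superposition of continuous univariate functions combined by addition, which is the structure KAN layers are built from (the theorem is an existence result and does not say how easily such a representation can be learned)</li>
<li><strong>Parameter Efficiency</strong>: CKANs may need fewer parameters than traditional CNNs to reach a given accuracy, although each learnable edge function carries several spline coefficients rather than a single weight</li>
<li><strong>Interpretability</strong>: The learned activation functions can be plotted and inspected directly, giving some insight into the learned representations</li>
<li><strong>Adaptive Nonlinearity</strong>: The network learns task-specific nonlinear transformations instead of relying on a fixed activation such as ReLU</li>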
</ol>
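<p>The parameter-efficiency point is easy to misread, so a quick count helps. The sketch below compares one standard convolution with a convolutional KAN layer of the same shape. It assumes, as many KAN implementations do but with implementation-specific details, that each kernel edge stores <code>grid_size + spline_order</code> spline coefficients plus one extra scalar such as a base-activation weight. Both function names are hypothetical.</p>

```python
def conv_params(c_in, c_out, k, bias=True):
    """Parameters of a standard 2D convolution: one scalar weight per kernel edge."""
    return c_in * c_out * k * k + (c_out if bias else 0)

def conv_kan_params(c_in, c_out, k, grid_size=5, spline_order=3, extra_per_edge=1):
    """Assumed CKAN count: every kernel edge carries grid_size + spline_order
    spline coefficients plus `extra_per_edge` scalars (e.g. a base weight).
    Real implementations differ in these details."""
    edges = c_in * c_out * k * k
    return edges * (grid_size + spline_order + extra_per_edge)

print(conv_params(3, 64, 3))      # 1792
print(conv_kan_params(3, 64, 3))  # 15552
```

<p>At equal width, the CKAN layer in this sketch has almost nine times as many parameters, so any overall savings have to come from reaching the same accuracy with fewer channels or layers.</p>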
</section>
<section id="empirical-evaluation" class="level3">
<h3 class="anchored" data-anchor-id="empirical-evaluation" id="empirical-evaluation">Empirical Evaluation</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> evaluate_ckan_performance(train_loader, test_loader):</span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Train each model, then record test accuracy, size, and timing"""</span></span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Model comparison</span></span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a>    models <span class="op">=</span> {</span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a>        <span class="st">'CKAN'</span>: CKAN(num_classes<span class="op">=</span><span class="dv">10</span>, grid_size<span class="op">=</span><span class="dv">5</span>),</span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a>        <span class="st">'CNN'</span>: TraditionalCNN(num_classes<span class="op">=</span><span class="dv">10</span>),</span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a>        <span class="st">'ResNet'</span>: torchvision.models.resnet18(num_classes<span class="op">=</span><span class="dv">10</span>)</span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a>    }</span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Evaluation metrics</span></span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a>    metrics <span class="op">=</span> {</span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a>        <span class="st">'accuracy'</span>: [],</span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a>        <span class="st">'parameters'</span>: [],</span>
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a>        <span class="st">'training_time'</span>: [],</span>
<span id="cb5-16"><a href="#cb5-16" aria-hidden="true" tabindex="-1"></a>        <span class="st">'inference_time'</span>: []</span>
<span id="cb5-17"><a href="#cb5-17" aria-hidden="true" tabindex="-1"></a>    }</span>
<span id="cb5-18"><a href="#cb5-18" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-19"><a href="#cb5-19" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> name, model <span class="kw">in</span> models.items():</span>
<span id="cb5-20"><a href="#cb5-20" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Count parameters</span></span>
<span id="cb5-21"><a href="#cb5-21" aria-hidden="true" tabindex="-1"></a>        param_count <span class="op">=</span> <span class="bu">sum</span>(p.numel() <span class="cf">for</span> p <span class="kw">in</span> model.parameters())</span>
<span id="cb5-22"><a href="#cb5-22" aria-hidden="true" tabindex="-1"></a>        metrics[<span class="st">'parameters'</span>].append(param_count)</span>
<span id="cb5-23"><a href="#cb5-23" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-24"><a href="#cb5-24" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Training evaluation</span></span>
<span id="cb5-25"><a href="#cb5-25" aria-hidden="true" tabindex="-1"></a>        trainer <span class="op">=</span> CKANTrainer(model, device<span class="op">=</span><span class="st">'cuda'</span>)</span>
<span id="cb5-26"><a href="#cb5-26" aria-hidden="true" tabindex="-1"></a>        start_time <span class="op">=</span> time.time()</span>
<span id="cb5-27"><a href="#cb5-27" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-28"><a href="#cb5-28" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">50</span>):</span>
<span id="cb5-29"><a href="#cb5-29" aria-hidden="true" tabindex="-1"></a>            train_loss <span class="op">=</span> trainer.train_epoch(train_loader)</span>
<span id="cb5-30"><a href="#cb5-30" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-31"><a href="#cb5-31" aria-hidden="true" tabindex="-1"></a>        training_time <span class="op">=</span> time.time() <span class="op">-</span> start_time</span>
<span id="cb5-32"><a href="#cb5-32" aria-hidden="true" tabindex="-1"></a>        metrics[<span class="st">'training_time'</span>].append(training_time)</span>
<span id="cb5-33"><a href="#cb5-33" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-34"><a href="#cb5-34" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Accuracy evaluation</span></span>
<span id="cb5-35"><a href="#cb5-35" aria-hidden="true" tabindex="-1"></a>        accuracy <span class="op">=</span> evaluate_model(model, test_loader)</span>
<span id="cb5-36"><a href="#cb5-36" aria-hidden="true" tabindex="-1"></a>        metrics[<span class="st">'accuracy'</span>].append(accuracy)</span>
<span id="cb5-37"><a href="#cb5-37" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-38"><a href="#cb5-38" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Inference time</span></span>
<span id="cb5-39"><a href="#cb5-39" aria-hidden="true" tabindex="-1"></a>        inference_time <span class="op">=</span> measure_inference_time(model, test_loader)</span>
<span id="cb5-40"><a href="#cb5-40" aria-hidden="true" tabindex="-1"></a>        metrics[<span class="st">'inference_time'</span>].append(inference_time)</span>
<span id="cb5-41"><a href="#cb5-41" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-42"><a href="#cb5-42" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> metrics</span></code></pre></div></div>
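<p>The evaluation function calls <code>measure_inference_time</code> and <code>evaluate_model</code> without defining them in this section. One possible timing helper is sketched below; it is an assumption, not the author's code. It accepts any callable and an iterable of input batches, skips a few warm-up batches, and reports the median seconds per batch. GPU kernels run asynchronously, so for CUDA models pass <code>sync=torch.cuda.synchronize</code> so that the clock waits for the work to finish.</p>

```python
import statistics
import time

def measure_inference_time(predict, batches, warmup=2, sync=None):
    """Median wall-clock seconds per batch for a callable `predict` (hypothetical helper).

    predict: maps one input batch to an output (e.g. a model in eval mode)
    batches: iterable of input batches
    warmup:  number of initial batches excluded from the statistics
    sync:    optional zero-argument function called around the timed region,
             e.g. torch.cuda.synchronize for GPU models
    """
    timings = []
    for i, batch in enumerate(batches):
        if sync is not None:
            sync()  # make sure earlier work (e.g. data transfer) has finished
        start = time.perf_counter()
        predict(batch)
        if sync is not None:
            sync()  # wait for asynchronous GPU work before stopping the clock
        elapsed = time.perf_counter() - start
        if i >= warmup:
            timings.append(elapsed)
    if not timings:
        raise ValueError("need more batches than warmup iterations")
    return statistics.median(timings)

# Example with PyTorch (not executed here; assumes model and test_loader exist):
# model.eval()
# with torch.no_grad():
#     seconds = measure_inference_time(
#         model, (data.to("cuda") for data, _ in test_loader), sync=torch.cuda.synchronize)
```

<p>Using the median rather than the mean keeps one slow batch (for example, a garbage-collection pause) from dominating the result.</p>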
</section>
</section>
<section id="advanced-techniques" class="level2">
<h2 class="anchored" data-anchor-id="advanced-techniques" id="advanced-techniques">Advanced Techniques</h2>
<section id="adaptive-grid-refinement" class="level3">
<h3 class="anchored" data-anchor-id="adaptive-grid-refinement" id="adaptive-grid-refinement">Adaptive Grid Refinement</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> AdaptiveSplineActivation(SplineActivation):</span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, num_channels, initial_grid_size<span class="op">=</span><span class="dv">5</span>, max_grid_size<span class="op">=</span><span class="dv">20</span>):</span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>(num_channels, initial_grid_size)</span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.max_grid_size <span class="op">=</span> max_grid_size</span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.refinement_threshold <span class="op">=</span> <span class="fl">0.1</span></span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> refine_grid(<span class="va">self</span>, x):</span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Adaptively refine grid based on activation distribution"""</span></span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> torch.no_grad():</span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Analyze activation distribution</span></span>
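<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a>            x_flat <span class="op">=</span> x.reshape(<span class="op">-</span><span class="dv">1</span>).<span class="bu">float</span>().cpu()  <span class="co"># torch.histogram expects a CPU float tensor</span></span>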
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a>            hist, bin_edges <span class="op">=</span> torch.histogram(x_flat, bins<span class="op">=</span><span class="va">self</span>.grid_size)</span>
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Identify regions needing refinement</span></span>
<span id="cb6-15"><a href="#cb6-15" aria-hidden="true" tabindex="-1"></a>            high_density_regions <span class="op">=</span> hist <span class="op">&gt;</span> <span class="va">self</span>.refinement_threshold <span class="op">*</span> torch.<span class="bu">max</span>(hist)</span>
<span id="cb6-16"><a href="#cb6-16" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb6-17"><a href="#cb6-17" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> torch.<span class="bu">any</span>(high_density_regions) <span class="kw">and</span> <span class="bu">len</span>(<span class="va">self</span>.grid) <span class="op">&lt;</span> <span class="va">self</span>.max_grid_size:</span>
<span id="cb6-18"><a href="#cb6-18" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Add grid points in high-density regions</span></span>
<span id="cb6-19"><a href="#cb6-19" aria-hidden="true" tabindex="-1"></a>                new_grid_points <span class="op">=</span> []</span>
<span id="cb6-20"><a href="#cb6-20" aria-hidden="true" tabindex="-1"></a>                <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="bu">len</span>(high_density_regions)):</span>
<span id="cb6-21"><a href="#cb6-21" aria-hidden="true" tabindex="-1"></a>                    <span class="cf">if</span> high_density_regions[i]:</span>
<span id="cb6-22"><a href="#cb6-22" aria-hidden="true" tabindex="-1"></a>                        mid_point <span class="op">=</span> (bin_edges[i] <span class="op">+</span> bin_edges[i<span class="op">+</span><span class="dv">1</span>]) <span class="op">/</span> <span class="dv">2</span></span>
<span id="cb6-23"><a href="#cb6-23" aria-hidden="true" tabindex="-1"></a>                        new_grid_points.append(mid_point)</span>
<span id="cb6-24"><a href="#cb6-24" aria-hidden="true" tabindex="-1"></a>                </span>
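<span id="cb6-25"><a href="#cb6-25" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Never grow the grid past max_grid_size</span></span>
<span id="cb6-26"><a href="#cb6-26" aria-hidden="true" tabindex="-1"></a>                new_grid_points <span class="op">=</span> new_grid_points[:<span class="va">self</span>.max_grid_size <span class="op">-</span> <span class="bu">len</span>(<span class="va">self</span>.grid)]</span>
<span id="cb6-27"><a href="#cb6-27" aria-hidden="true" tabindex="-1"></a>                <span class="cf">if</span> new_grid_points:</span>
<span id="cb6-28"><a href="#cb6-28" aria-hidden="true" tabindex="-1"></a>                    <span class="co"># Stack the 0-d midpoint tensors and match the grid's dtype and device</span></span>
<span id="cb6-29"><a href="#cb6-29" aria-hidden="true" tabindex="-1"></a>                    new_points <span class="op">=</span> torch.stack(new_grid_points).to(<span class="va">self</span>.grid)</span>
<span id="cb6-30"><a href="#cb6-30" aria-hidden="true" tabindex="-1"></a>                    <span class="va">self</span>.grid <span class="op">=</span> torch.sort(torch.cat([<span class="va">self</span>.grid, new_points]))[<span class="dv">0</span>]</span>
<span id="cb6-31"><a href="#cb6-31" aria-hidden="true" tabindex="-1"></a>                    <span class="co"># Resize coefficient matrix to match the refined grid</span></span>
<span id="cb6-32"><a href="#cb6-32" aria-hidden="true" tabindex="-1"></a>                    <span class="va">self</span>.resize_coefficients()</span></code></pre></div></div>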
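<p>The class calls <code>self.resize_coefficients()</code>, whose implementation is not shown in this section. One way to fill that gap, assuming each channel's activation is piecewise linear and its coefficients are simply the function values at the grid points, is to resample every channel's curve onto the refined grid, so that the learned function is unchanged at the moment of refinement. The helper name <code>resample_coefficients</code> is hypothetical. For a genuine B-spline parameterization, the equivalent step is knot insertion or a least-squares refit of the coefficients.</p>

```python
import numpy as np

def resample_coefficients(old_grid, old_coeffs, new_grid):
    """Hypothetical resize helper for piecewise-linear activations.

    old_grid:   (G_old,) increasing grid positions
    old_coeffs: (C, G_old) activation value at each grid point, one row per channel
    new_grid:   (G_new,) increasing grid positions
    Returns a (C, G_new) array. np.interp clamps points outside old_grid
    to the end values.
    """
    old_grid = np.asarray(old_grid, dtype=float)
    new_grid = np.asarray(new_grid, dtype=float)
    old_coeffs = np.atleast_2d(np.asarray(old_coeffs, dtype=float))
    return np.stack([np.interp(new_grid, old_grid, row) for row in old_coeffs])

old_grid = np.array([-1.0, 0.0, 1.0])
old_coeffs = np.array([[0.0, 0.0, 1.0],    # ReLU-like channel
                       [1.0, 0.0, 1.0]])   # |x|-like channel
new_grid = np.sort(np.concatenate([old_grid, [0.5]]))  # refine near 0.5
new_coeffs = resample_coefficients(old_grid, old_coeffs, new_grid)
# Both channels keep their shape; the new point at 0.5 gets the value 0.5 in each.
```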
</section>
<section id="multi-scale-ckan-architecture" class="level3">
<h3 class="anchored" data-anchor-id="multi-scale-ckan-architecture" id="multi-scale-ckan-architecture">Multi-scale CKAN Architecture</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MultiScaleCKAN(nn.Module):</span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, num_classes<span class="op">=</span><span class="dv">10</span>):</span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Multi-scale feature extraction</span></span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.scale1 <span class="op">=</span> ConvKANBlock(<span class="dv">3</span>, <span class="dv">64</span>, kernel_size<span class="op">=</span><span class="dv">3</span>, padding<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.scale2 <span class="op">=</span> ConvKANBlock(<span class="dv">3</span>, <span class="dv">64</span>, kernel_size<span class="op">=</span><span class="dv">5</span>, padding<span class="op">=</span><span class="dv">2</span>)</span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.scale3 <span class="op">=</span> ConvKANBlock(<span class="dv">3</span>, <span class="dv">64</span>, kernel_size<span class="op">=</span><span class="dv">7</span>, padding<span class="op">=</span><span class="dv">3</span>)</span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-10"><a href="#cb7-10" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Feature fusion</span></span>
<span id="cb7-11"><a href="#cb7-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.fusion <span class="op">=</span> ConvKANBlock(<span class="dv">192</span>, <span class="dv">128</span>, kernel_size<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb7-12"><a href="#cb7-12" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-13"><a href="#cb7-13" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Subsequent layers</span></span>
<span id="cb7-14"><a href="#cb7-14" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.conv_blocks <span class="op">=</span> nn.Sequential(</span>
<span id="cb7-15"><a href="#cb7-15" aria-hidden="true" tabindex="-1"></a>            ConvKANBlock(<span class="dv">128</span>, <span class="dv">256</span>, kernel_size<span class="op">=</span><span class="dv">3</span>, padding<span class="op">=</span><span class="dv">1</span>),</span>
<span id="cb7-16"><a href="#cb7-16" aria-hidden="true" tabindex="-1"></a>            nn.MaxPool2d(<span class="dv">2</span>, <span class="dv">2</span>),</span>
<span id="cb7-17"><a href="#cb7-17" aria-hidden="true" tabindex="-1"></a>            ConvKANBlock(<span class="dv">256</span>, <span class="dv">512</span>, kernel_size<span class="op">=</span><span class="dv">3</span>, padding<span class="op">=</span><span class="dv">1</span>),</span>
<span id="cb7-18"><a href="#cb7-18" aria-hidden="true" tabindex="-1"></a>            nn.MaxPool2d(<span class="dv">2</span>, <span class="dv">2</span>),</span>
<span id="cb7-19"><a href="#cb7-19" aria-hidden="true" tabindex="-1"></a>            ConvKANBlock(<span class="dv">512</span>, <span class="dv">1024</span>, kernel_size<span class="op">=</span><span class="dv">3</span>, padding<span class="op">=</span><span class="dv">1</span>),</span>
<span id="cb7-20"><a href="#cb7-20" aria-hidden="true" tabindex="-1"></a>            nn.AdaptiveAvgPool2d((<span class="dv">1</span>, <span class="dv">1</span>))</span>
<span id="cb7-21"><a href="#cb7-21" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb7-22"><a href="#cb7-22" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-23"><a href="#cb7-23" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.classifier <span class="op">=</span> nn.Linear(<span class="dv">1024</span>, num_classes)</span>
<span id="cb7-24"><a href="#cb7-24" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-25"><a href="#cb7-25" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb7-26"><a href="#cb7-26" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Multi-scale feature extraction</span></span>
<span id="cb7-27"><a href="#cb7-27" aria-hidden="true" tabindex="-1"></a>        s1 <span class="op">=</span> <span class="va">self</span>.scale1(x)</span>
<span id="cb7-28"><a href="#cb7-28" aria-hidden="true" tabindex="-1"></a>        s2 <span class="op">=</span> <span class="va">self</span>.scale2(x)</span>
<span id="cb7-29"><a href="#cb7-29" aria-hidden="true" tabindex="-1"></a>        s3 <span class="op">=</span> <span class="va">self</span>.scale3(x)</span>
<span id="cb7-30"><a href="#cb7-30" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-31"><a href="#cb7-31" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Concatenate and fuse</span></span>
<span id="cb7-32"><a href="#cb7-32" aria-hidden="true" tabindex="-1"></a>        multi_scale <span class="op">=</span> torch.cat([s1, s2, s3], dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb7-33"><a href="#cb7-33" aria-hidden="true" tabindex="-1"></a>        fused <span class="op">=</span> <span class="va">self</span>.fusion(multi_scale)</span>
<span id="cb7-34"><a href="#cb7-34" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-35"><a href="#cb7-35" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Process through remaining layers</span></span>
<span id="cb7-36"><a href="#cb7-36" aria-hidden="true" tabindex="-1"></a>        features <span class="op">=</span> <span class="va">self</span>.conv_blocks(fused)</span>
<span id="cb7-37"><a href="#cb7-37" aria-hidden="true" tabindex="-1"></a>        features <span class="op">=</span> features.view(features.size(<span class="dv">0</span>), <span class="op">-</span><span class="dv">1</span>)</span>
<span id="cb7-38"><a href="#cb7-38" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-39"><a href="#cb7-39" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.classifier(features)</span></code></pre></div></div>
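<p>Because the three branches are concatenated along the channel axis, they must produce the same spatial size, and the fusion block must accept 3 × 64 = 192 channels. A small shape trace in plain Python checks this for a 32 × 32 input (an assumed CIFAR-sized image). It assumes each <code>ConvKANBlock</code> changes the spatial size exactly like a stride-1 <code>nn.Conv2d</code> with the given <code>kernel_size</code> and <code>padding</code> (padding 0 when omitted); the helper names are hypothetical.</p>

```python
def conv_out(size, kernel_size, padding=0, stride=1):
    """Output length along one spatial axis for a conv or pooling window."""
    return (size + 2 * padding - kernel_size) // stride + 1

def trace_multiscale_ckan(h=32, w=32):
    """List (stage, (C, H, W)) through MultiScaleCKAN, assuming each
    ConvKANBlock changes spatial size like a stride-1 nn.Conv2d."""
    trace = []
    # Three parallel branches with (kernel_size, padding) = (3, 1), (5, 2), (7, 3)
    branch_hw = {(conv_out(h, k, p), conv_out(w, k, p)) for k, p in [(3, 1), (5, 2), (7, 3)]}
    assert len(branch_hw) == 1, "branches must agree on H, W to be concatenated"
    h, w = branch_hw.pop()
    trace.append(("concat scales", (3 * 64, h, w)))
    h, w = conv_out(h, 1), conv_out(w, 1)  # 1x1 fusion block
    trace.append(("fusion", (128, h, w)))
    for c_out, pool in [(256, True), (512, True), (1024, False)]:
        h, w = conv_out(h, 3, 1), conv_out(w, 3, 1)  # 3x3 block, padding=1
        if pool:  # MaxPool2d(2, 2) follows the 256- and 512-channel blocks
            h, w = conv_out(h, 2, stride=2), conv_out(w, 2, stride=2)
        trace.append((f"block {c_out}", (c_out, h, w)))
    trace.append(("global avg pool", (1024, 1, 1)))  # AdaptiveAvgPool2d((1, 1))
    return trace

for stage, shape in trace_multiscale_ckan():
    print(stage, shape)
```

<p>Since <code>AdaptiveAvgPool2d((1, 1))</code> always produces a 1 × 1 map, the classifier's 1024 inputs do not depend on the input resolution; only the branch paddings, which follow <code>padding = (kernel_size - 1) // 2</code>, have to keep the scales aligned.</p>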
</section>
</section>
<section id="applications-and-future-directions" class="level2">
<h2 class="anchored" data-anchor-id="applications-and-future-directions" id="applications-and-future-directions">Applications and Future Directions</h2>
<section id="computer-vision-applications" class="level3">
<h3 class="anchored" data-anchor-id="computer-vision-applications" id="computer-vision-applications">Computer Vision Applications</h3>
<p>CKANs are being explored for several computer vision tasks, with the following potential benefits:</p>
<ol type="1">
<li><strong>Image Classification</strong>: Competitive accuracy with fewer parameters</li>
<li><strong>Object Detection</strong>: Improved feature representation for small objects</li>
<li><strong>Semantic Segmentation</strong>: Better boundary preservation through learnable activations</li>
<li><strong>Medical Imaging</strong>: Enhanced interpretability for diagnostic applications</li>
</ol>
</section>
<section id="research-directions" class="level3">
<h3 class="anchored" data-anchor-id="research-directions" id="research-directions">Research Directions</h3>
<p>Future research directions include:</p>
<ol type="1">
<li><strong>Theoretical Analysis</strong>: Deeper understanding of approximation capabilities</li>
<li><strong>Efficient Implementation</strong>: GPU-optimized spline evaluation algorithms</li>
<li><strong>Architecture Search</strong>: Automated design of CKAN architectures</li>
<li><strong>Transfer Learning</strong>: Pre-trained CKAN models for various domains</li>
<li><strong>Hybrid Architectures</strong>: Combining CKANs with attention mechanisms and transformers</li>
</ol>
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>Convolutional Kolmogorov-Arnold Networks are a promising development in neural network architecture design, offering a principled approach to function approximation that combines classical mathematical theory with modern deep learning techniques. While challenges remain in optimization and implementation, their theoretical foundations and early empirical work suggest that CKANs could become a powerful tool in the deep learning toolkit.</p>
<p>The key potential advantages of CKANs are their theoretical grounding, parameter efficiency, and interpretability. As the field evolves, we can expect further developments in optimization techniques, architectural innovations, and applications across diverse domains.</p>
<p>The implementation examples in this post illustrate the practical aspects of building and training CKANs, though real-world applications will require careful attention to computational efficiency, hyperparameter tuning, and domain-specific adaptations. The future of CKANs looks promising, with potential applications ranging from computer vision to scientific computing.</p>



</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[The Mathematics Behind Convolutional Kolmogorov-Arnold Networks]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/models/convkan/ckan-math/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/models/convkan/ckan-math/</guid>
      <pubDate>Sat, 05 Jul 2025 00:00:00 GMT</pubDate>
      
      <category>research</category>
      <category>advanced</category>
      <content:encoded><![CDATA[






<section id="the-mathematics-behind-convolutional-kolmogorov-arnold-networks" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/models/convkan/ckan-math/ckan-math.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>Convolutional Kolmogorov-Arnold Networks (CKANs) are a neural network architecture that combines the theoretical foundations of the Kolmogorov-Arnold representation theorem with the practical advantages of convolutional operations. Traditional Convolutional Neural Networks (CNNs) learn linear filters (scalar kernel weights) and pass the result through fixed nonlinear activations such as ReLU; CKANs instead replace each kernel weight with a learnable univariate function, offering a more flexible and theoretically motivated approach to function approximation.</p>
</section>
<section id="the-kolmogorov-arnold-representation-theorem" class="level2">
<h2 class="anchored" data-anchor-id="the-kolmogorov-arnold-representation-theorem" id="the-kolmogorov-arnold-representation-theorem">The Kolmogorov-Arnold Representation Theorem</h2>
<section id="theoretical-foundation" class="level3">
<h3 class="anchored" data-anchor-id="theoretical-foundation" id="theoretical-foundation">Theoretical Foundation</h3>
<p>The Kolmogorov-Arnold representation theorem, established by Andrey Kolmogorov in 1957 building on closely related work by Vladimir Arnold, states that any continuous function of several variables on the unit cube can be written as a superposition of continuous functions of a single variable, combined only through addition.</p>
<p><strong>Theorem (Kolmogorov-Arnold)</strong>: For any continuous function <span class="math inline">\(f: [0,1]^n \to \mathbb{R}\)</span>, there exist continuous functions <span class="math inline">\(\phi_{q,p}: [0,1] \to \mathbb{R}\)</span> and <span class="math inline">\(\Phi_q: \mathbb{R} \to \mathbb{R}\)</span> such that:</p>
<p><span class="math display">\[
f(x_1, x_2, \ldots, x_n) = \sum_{q=0}^{2n} \Phi_q\left(\sum_{p=1}^{n} \phi_{q,p}(x_p)\right)
\]</span></p>
<p>where:</p>
<ul>
<li><span class="math inline">\(q\)</span> ranges from <span class="math inline">\(0\)</span> to <span class="math inline">\(2n\)</span></li>
<li><span class="math inline">\(p\)</span> ranges from <span class="math inline">\(1\)</span> to <span class="math inline">\(n\)</span></li>
<li>The functions <span class="math inline">\(\phi_{q,p}\)</span> are universal (independent of <span class="math inline">\(f\)</span>)</li>
<li>Only the outer functions <span class="math inline">\(\Phi_q\)</span> depend on the specific function <span class="math inline">\(f\)</span></li>
</ul>
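<p>A toy example makes the structure concrete. Multiplication can be written as a sum of univariate outer functions applied to sums of univariate inner functions, since <span class="math inline">\(xy = \tfrac{1}{4}(x+y)^2 - \tfrac{1}{4}(x-y)^2\)</span>. This is not Kolmogorov's general construction: here the inner functions are hand-picked for <span class="math inline">\(f(x,y) = xy\)</span> rather than universal, and only two of the <span class="math inline">\(2n+1 = 5\)</span> outer terms are nonzero. The NumPy check below evaluates the representation term by term.</p>

```python
import numpy as np

# Toy Kolmogorov-Arnold-style representation of f(x, y) = x * y (n = 2).
# Inner functions phi_{q,p}, one pair per outer term q (hand-picked for this f).
inner = [
    (lambda x: x, lambda y: y),   # q = 0: inner sum is x + y
    (lambda x: x, lambda y: -y),  # q = 1: inner sum is x - y
]
# Outer functions Phi_q; the remaining terms up to q = 2n = 4 are taken as zero.
outer = [
    lambda s: s ** 2 / 4,
    lambda s: -(s ** 2) / 4,
]

def ka_product(x, y):
    """Evaluate sum_q Phi_q(phi_{q,1}(x) + phi_{q,2}(y))."""
    return sum(Phi(phi_x(x) + phi_y(y)) for Phi, (phi_x, phi_y) in zip(outer, inner))

rng = np.random.default_rng(0)
x, y = rng.random(1000), rng.random(1000)  # samples from [0, 1]^2
# Prints a value at the level of floating-point rounding error:
print(np.max(np.abs(ka_product(x, y) - x * y)))
```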
</section>
<section id="implications-for-neural-networks" class="level3">
<h3 class="anchored" data-anchor-id="implications-for-neural-networks" id="implications-for-neural-networks">Implications for Neural Networks</h3>
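<p>This theorem suggests that, instead of linear combinations followed by fixed activation functions, networks can be built from sums and compositions of learnable univariate functions. This idea forms the theoretical backbone of Kolmogorov-Arnold Networks (KANs). The theorem is an existence result, however: the inner functions it guarantees can be highly non-smooth, so practical KANs approximate them with smooth learnable functions (such as splines) and stack as many layers as needed rather than reproducing the construction literally.</p>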
</section>
</section>
<section id="from-kans-to-convolutional-kans" class="level2">
<h2 class="anchored" data-anchor-id="from-kans-to-convolutional-kans" id="from-kans-to-convolutional-kans">From KANs to Convolutional KANs</h2>
<section id="standard-kan-architecture" class="level3">
<h3 class="anchored" data-anchor-id="standard-kan-architecture" id="standard-kan-architecture">Standard KAN Architecture</h3>
<p>A standard KAN layer transforms input <span class="math inline">\(\mathbf{x} \in \mathbb{R}^{n_{in}}\)</span> to output <span class="math inline">\(\mathbf{y} \in \mathbb{R}^{n_{out}}\)</span> using:</p>
<p><span class="math display">\[
y_j = \sum_{i=1}^{n_{in}} \phi_{i,j}(x_i)
\]</span></p>
<p>where <span class="math inline">\(\phi_{i,j}: \mathbb{R} \to \mathbb{R}\)</span> are learnable univariate functions, typically parameterized using splines or other basis functions.</p>
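<p>A minimal NumPy sketch of this layer, under a simplifying assumption: each <span class="math inline">\(\phi_{i,j}\)</span> is piecewise linear and stored as its values at a shared set of grid points, instead of the B-splines described later (practical KANs also often add a fixed base activation). The class name is hypothetical. Setting <span class="math inline">\(\phi_{i,j}(x) = w_{i,j} x\)</span> recovers a bias-free linear layer, which shows that a KAN layer generalizes the linear part of an MLP layer.</p>

```python
import numpy as np

class PiecewiseLinearKANLayer:
    """Minimal KAN layer computing y_j = sum_i phi_{i,j}(x_i).

    Each phi_{i,j} is piecewise linear, stored as its values at shared grid
    points (a stand-in for splines). np.interp clamps inputs outside the grid.
    """

    def __init__(self, n_in, n_out, grid, rng=None):
        self.grid = np.asarray(grid, dtype=float)  # (G,)
        rng = np.random.default_rng() if rng is None else rng
        # values[i, j, g] = phi_{i,j}(grid[g])
        self.values = 0.1 * rng.standard_normal((n_in, n_out, len(self.grid)))

    def __call__(self, x):
        x = np.asarray(x, dtype=float)  # (n_in,)
        n_in, n_out, _ = self.values.shape
        y = np.zeros(n_out)
        for i in range(n_in):
            for j in range(n_out):
                y[j] += np.interp(x[i], self.grid, self.values[i, j])
        return y

grid = np.linspace(-2.0, 2.0, 9)
layer = PiecewiseLinearKANLayer(n_in=3, n_out=2, grid=grid, rng=np.random.default_rng(0))
W = np.array([[1.0, -0.5], [2.0, 0.0], [0.5, 3.0]])
layer.values = W[:, :, None] * grid[None, None, :]  # phi_{i,j}(x) = W[i, j] * x
x = np.array([0.3, -1.2, 0.7])
print(layer(x), x @ W)  # both are [-1.75, 1.95] up to rounding
```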
</section>
<section id="convolutional-extension" class="level3">
<h3 class="anchored" data-anchor-id="convolutional-extension" id="convolutional-extension">Convolutional Extension</h3>
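<p>The challenge in extending KANs to convolutional architectures lies in keeping the learnable functions univariate while incorporating spatial locality and weight sharing across positions (which makes the layer translation equivariant). CKANs address this by keeping the sliding-window structure of a convolution and replacing each scalar kernel weight with a learnable univariate function, as formalized in the next section.</p>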
</section>
</section>
<section id="mathematical-formulation-of-ckans" class="level2">
<h2 class="anchored" data-anchor-id="mathematical-formulation-of-ckans" id="mathematical-formulation-of-ckans">Mathematical Formulation of CKANs</h2>
<section id="convolutional-kan-layer" class="level3">
<h3 class="anchored" data-anchor-id="convolutional-kan-layer" id="convolutional-kan-layer">1. Convolutional KAN Layer</h3>
<p>For a CKAN layer with input feature map <span class="math inline">\(\mathbf{X} \in \mathbb{R}^{H \times W \times C_{in}}\)</span> and output <span class="math inline">\(\mathbf{Y} \in \mathbb{R}^{H' \times W' \times C_{out}}\)</span>, the convolution operation is defined as:</p>
<p><span class="math display">\[
Y_{i,j,k} = \sum_{c=1}^{C_{in}} \sum_{u=0}^{K-1} \sum_{v=0}^{K-1} \phi_{c,k,u,v}(X_{i+u,j+v,c})
\]</span></p>
<p>where:</p>
<ul>
<li><span class="math inline">\((i,j)\)</span> are spatial coordinates in the output feature map</li>
<li><span class="math inline">\(k\)</span> is the output channel index</li>
<li><span class="math inline">\(c\)</span> is the input channel index</li>
<li><span class="math inline">\(K\)</span> is the kernel size</li>
<li><span class="math inline">\(\phi_{c,k,u,v}\)</span> are learnable univariate functions specific to input channel <span class="math inline">\(c\)</span>, output channel <span class="math inline">\(k\)</span>, and kernel position <span class="math inline">\((u,v)\)</span></li>
</ul>
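<p>A direct, deliberately naive transcription of the CKAN equation makes the indexing concrete. It assumes stride 1 and no padding (matching the index ranges in the formula), uses 0-based channel indices instead of the formula's 1-based <span class="math inline">\(c\)</span>, and takes the functions as a hypothetical callable <code>phi(c, k, u, v, x)</code>; practical implementations vectorize these loops.</p>

```python
from typing import Callable, List

Tensor3 = List[List[List[float]]]  # indexed as [row][col][channel], i.e. H x W x C

def ckan_conv2d(X: Tensor3, phi: Callable[[int, int, int, int, float], float],
                c_out: int, K: int) -> Tensor3:
    """Naive CKAN convolution: stride 1, no padding, 0-indexed channels.

    phi(c, k, u, v, x) evaluates the univariate function phi_{c,k,u,v} at x.
    """
    H, W, c_in = len(X), len(X[0]), len(X[0][0])
    H_out, W_out = H - K + 1, W - K + 1
    Y = [[[0.0] * c_out for _ in range(W_out)] for _ in range(H_out)]
    for i in range(H_out):
        for j in range(W_out):
            for k in range(c_out):
                total = 0.0
                for c in range(c_in):
                    for u in range(K):
                        for v in range(K):
                            # One univariate function per (c, k, u, v), applied to one input value.
                            total += phi(c, k, u, v, X[i + u][j + v][c])
                Y[i][j][k] = total
    return Y

# Usage: with phi(c, k, u, v, x) = x (a linear "weight" of 1 everywhere) the CKAN
# reduces to an ordinary cross-correlation with an all-ones 2 x 2 kernel.
X = [[[1.0], [2.0], [3.0]], [[4.0], [5.0], [6.0]], [[7.0], [8.0], [9.0]]]
Y = ckan_conv2d(X, lambda c, k, u, v, x: x, c_out=1, K=2)
```

<p>With <code>phi(c, k, u, v, x) = w[k][c][u][v] * x</code> the same loop computes a standard convolution, which is one way to see that an ordinary convolutional layer is the special case of a CKAN with linear edge functions.</p>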
</section>
<section id="univariate-function-parameterization" class="level3">
<h3 class="anchored" data-anchor-id="univariate-function-parameterization" id="univariate-function-parameterization">2. Univariate Function Parameterization</h3>
<p>The univariate functions <span class="math inline">\(\phi\)</span> are typically parameterized using B-splines or other basis functions. For B-splines of degree <span class="math inline">\(d\)</span> with <span class="math inline">\(n\)</span> control points:</p>
<p><span class="math display">\[
\phi(x) = \sum_{i=0}^{n-1} c_i B_i^d(x)
\]</span></p>
<p>where <span class="math inline">\(c_i\)</span> are learnable coefficients and <span class="math inline">\(B_i^d(x)\)</span> are B-spline basis functions defined recursively:</p>
<p><span class="math display">\[
B_i^0(x) = \begin{cases} 1 &amp; \text{if } t_i \leq x &lt; t_{i+1} \\ 0 &amp; \text{otherwise} \end{cases}
\]</span></p>
<p><span class="math display">\[
B_i^d(x) = \frac{x - t_i}{t_{i+d} - t_i} B_i^{d-1}(x) + \frac{t_{i+d+1} - x}{t_{i+d+1} - t_{i+1}} B_{i+1}^{d-1}(x)
\]</span></p>
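<p>The Cox-de Boor recursion above translates almost line for line into code. This sketch is unoptimized (it recomputes lower-degree terms) and uses the usual convention that a term with a zero-width knot span contributes 0; <code>bspline_basis</code> and <code>spline</code> are hypothetical helper names.</p>

```python
from typing import List

def bspline_basis(i: int, d: int, t: List[float], x: float) -> float:
    """B_i^d(x) by the Cox-de Boor recursion on knot vector t."""
    if d == 0:
        return 1.0 if t[i] <= x < t[i + 1] else 0.0
    left = right = 0.0
    if t[i + d] != t[i]:                  # skip zero-width spans (0/0 taken as 0)
        left = (x - t[i]) / (t[i + d] - t[i]) * bspline_basis(i, d - 1, t, x)
    if t[i + d + 1] != t[i + 1]:
        right = (t[i + d + 1] - x) / (t[i + d + 1] - t[i + 1]) * bspline_basis(i + 1, d - 1, t, x)
    return left + right

def spline(coeffs: List[float], d: int, t: List[float], x: float) -> float:
    """phi(x) = sum_i c_i B_i^d(x); n coefficients need n + d + 1 knots."""
    if len(t) != len(coeffs) + d + 1:
        raise ValueError("expected len(t) == len(coeffs) + d + 1")
    return sum(c * bspline_basis(i, d, t, x) for i, c in enumerate(coeffs))

# Cubic spline (d = 3) with n = 8 coefficients on uniform knots -3, -2, ..., 8.
d, n = 3, 8
t = [float(j - d) for j in range(n + d + 1)]
# On [t_d, t_n) = [0, 5) the basis functions sum to 1, so all-ones coefficients give phi = 1.
print(spline([1.0] * n, d, t, 2.5))
```

<p>The final line relies on the partition-of-unity property: with knots <span class="math inline">\(t_0, \ldots, t_{n+d}\)</span>, the <span class="math inline">\(n\)</span> basis functions of degree <span class="math inline">\(d\)</span> sum to 1 on <span class="math inline">\([t_d, t_n)\)</span>.</p>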
</section>
<section id="shared-function-approach" class="level3">
<h3 class="anchored" data-anchor-id="shared-function-approach" id="shared-function-approach">3. Shared Function Approach</h3>
<p>To reduce the number of parameters, CKANs often employ parameter sharing strategies:</p>
<section id="spatial-sharing" class="level4">
<h4 class="anchored" data-anchor-id="spatial-sharing">Spatial Sharing</h4>
<p>Functions are shared across kernel positions, so each input/output channel pair uses one function: <span class="math display">\[
\phi_{c,k,u,v}(x) = \phi_{c,k}(x) \text{ for all positions } (u,v)
\]</span></p>
</section>
<section id="channel-grouping" class="level4">
<h4 class="anchored" data-anchor-id="channel-grouping">Channel Grouping</h4>
<p>Functions are shared within groups of <span class="math inline">\(G\)</span> consecutive input channels: <span class="math display">\[
\phi_{g,k}(x) \text{ where } g = \lfloor (c-1)/G \rfloor \text{ for } c = 1, \ldots, C_{in}
\]</span></p>
</section>
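<p>The savings from sharing can be quantified by counting distinct functions and multiplying by the <span class="math inline">\(n\)</span> coefficients each one carries. In this hypothetical helper the two strategies are treated as independent switches that can be combined (an assumption; the text presents them separately), and bias or base-activation weights are ignored.</p>

```python
import math

def ckan_param_count(c_in: int, c_out: int, K: int, n: int,
                     share_spatial: bool = False, group_size: int = 1) -> int:
    """Number of spline coefficients in one CKAN layer (hypothetical helper).

    share_spatial: one function per (group, output channel) instead of one per kernel position.
    group_size:    G consecutive input channels share a function (G = 1 means no grouping).
    """
    n_groups = math.ceil(c_in / group_size)   # g = floor((c - 1) / G) takes this many values
    positions = 1 if share_spatial else K * K
    return n_groups * c_out * positions * n

full = ckan_param_count(32, 64, 3, 8)                                     # 147,456
spatial = ckan_param_count(32, 64, 3, 8, share_spatial=True)              # 16,384
grouped = ckan_param_count(32, 64, 3, 8, group_size=4)                    # 36,864
both = ckan_param_count(32, 64, 3, 8, share_spatial=True, group_size=4)   # 4,096
```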
</section>
</section>
<section id="activation-functions-in-ckans" class="level2">
<h2 class="anchored" data-anchor-id="activation-functions-in-ckans" id="activation-functions-in-ckans">Activation Functions in CKANs</h2>
<section id="learnable-activation-functions" class="level3">
<h3 class="anchored" data-anchor-id="learnable-activation-functions" id="learnable-activation-functions">Learnable Activation Functions</h3>
<p>Unlike traditional CNNs with fixed activation functions (ReLU, sigmoid, etc.), CKANs use learnable activation functions. These can be viewed as univariate functions applied element-wise:</p>
<p><span class="math display">\[
\text{Activation}(x) = \psi(x)
\]</span></p>
<p>where <span class="math inline">\(\psi\)</span> is a learnable univariate function, often parameterized as:</p>
<p><span class="math display">\[
\psi(x) = \text{SiLU}(x) + \sum_{i=0}^{n-1} a_i B_i^d(x)
\]</span></p>
<p>The SiLU (Sigmoid Linear Unit) term provides a smooth base function, while the spline term adds a learnable, locally adjustable correction on top of it.</p>
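<p>To keep the code short, this sketch of the learnable activation uses the simplest B-splines, degree 1 (hat functions) on a uniform grid; higher degrees plug in the same way. The names <code>silu</code>, <code>hat</code>, and <code>psi</code> and the knot grid are illustrative choices, not taken from a specific implementation.</p>

```python
import math
from typing import List

def silu(x: float) -> float:
    return x / (1.0 + math.exp(-x))

def hat(i: int, t: List[float], x: float) -> float:
    """Degree-1 B-spline B_i^1: rises on [t_i, t_{i+1}), falls on [t_{i+1}, t_{i+2})."""
    if t[i] <= x < t[i + 1]:
        return (x - t[i]) / (t[i + 1] - t[i])
    if t[i + 1] <= x < t[i + 2]:
        return (t[i + 2] - x) / (t[i + 2] - t[i + 1])
    return 0.0

def psi(x: float, a: List[float], t: List[float]) -> float:
    """Learnable activation: SiLU base plus the spline correction sum_i a_i B_i^1(x)."""
    return silu(x) + sum(a_i * hat(i, t, x) for i, a_i in enumerate(a))

t = [-2.0, -1.0, 0.0, 1.0, 2.0]   # uniform knots -> len(t) - 2 = 3 hat functions
a_zero = [0.0, 0.0, 0.0]          # all-zero coefficients: the activation is exactly SiLU
a_bump = [0.0, 0.5, 0.0]          # one nonzero coefficient: a local bump on [-1, 1] peaking at 0
```

<p>Starting the spline coefficients at zero makes the activation begin as plain SiLU; each coefficient then reshapes the function only near its own knots.</p>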
</section>
</section>
<section id="training-dynamics-and-optimization" class="level2">
<h2 class="anchored" data-anchor-id="training-dynamics-and-optimization" id="training-dynamics-and-optimization">Training Dynamics and Optimization</h2>
<section id="gradient-computation" class="level3">
<h3 class="anchored" data-anchor-id="gradient-computation" id="gradient-computation">Gradient Computation</h3>
<p>Because <span class="math inline">\(\phi\)</span> is linear in its spline coefficients, the gradient of the loss with respect to a coefficient is the upstream gradient multiplied by the corresponding basis function:</p>
<p><span class="math display">\[
\frac{\partial L}{\partial c_i} = \frac{\partial L}{\partial \phi} \cdot B_i^d(x)
\]</span></p>
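<p>The gradient with respect to the input <span class="math inline">\(x\)</span>, needed to backpropagate to earlier layers, involves the derivatives of the basis functions: <span class="math display">\[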
\frac{\partial L}{\partial x} = \frac{\partial L}{\partial \phi} \cdot \sum_{i=0}^{n-1} c_i \frac{dB_i^d(x)}{dx}
\]</span></p>
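<p>Both gradients can be checked numerically. The sketch below uses the standard B-spline derivative identity <span class="math inline">\(\frac{dB_i^d}{dx} = \frac{d}{t_{i+d} - t_i} B_i^{d-1}(x) - \frac{d}{t_{i+d+1} - t_{i+1}} B_{i+1}^{d-1}(x)\)</span> and compares the input gradient against a central finite difference; the helper names and example coefficients are hypothetical.</p>

```python
from typing import List, Tuple

def basis(i: int, d: int, t: List[float], x: float) -> float:
    """Cox-de Boor recursion for B_i^d(x) (zero-width spans contribute 0)."""
    if d == 0:
        return 1.0 if t[i] <= x < t[i + 1] else 0.0
    out = 0.0
    if t[i + d] > t[i]:
        out += (x - t[i]) / (t[i + d] - t[i]) * basis(i, d - 1, t, x)
    if t[i + d + 1] > t[i + 1]:
        out += (t[i + d + 1] - x) / (t[i + d + 1] - t[i + 1]) * basis(i + 1, d - 1, t, x)
    return out

def dbasis(i: int, d: int, t: List[float], x: float) -> float:
    """dB_i^d/dx via the standard derivative identity (valid for d >= 1)."""
    out = 0.0
    if t[i + d] > t[i]:
        out += d / (t[i + d] - t[i]) * basis(i, d - 1, t, x)
    if t[i + d + 1] > t[i + 1]:
        out -= d / (t[i + d + 1] - t[i + 1]) * basis(i + 1, d - 1, t, x)
    return out

def spline_grads(c: List[float], d: int, t: List[float], x: float,
                 upstream: float) -> Tuple[float, List[float], float]:
    """Return phi(x), dL/dc_i for all i, and dL/dx, given upstream = dL/dphi."""
    phi = sum(ci * basis(i, d, t, x) for i, ci in enumerate(c))
    grad_c = [upstream * basis(i, d, t, x) for i in range(len(c))]   # phi is linear in c
    grad_x = upstream * sum(ci * dbasis(i, d, t, x) for i, ci in enumerate(c))
    return phi, grad_c, grad_x

d = 3
t = [float(j - d) for j in range(10)]   # knots -3..6 -> 10 - d - 1 = 6 basis functions
c = [0.3, -1.0, 2.0, 0.5, -0.7, 1.2]
x0, upstream = 1.3, 2.0
phi, grad_c, grad_x = spline_grads(c, d, t, x0, upstream)

def phi_at(x: float) -> float:
    return sum(ci * basis(i, d, t, x) for i, ci in enumerate(c))

# Central finite difference of L = upstream * phi(x) as an independent check on dL/dx.
eps = 1e-6
grad_x_fd = upstream * (phi_at(x0 + eps) - phi_at(x0 - eps)) / (2 * eps)
```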
</section>
<section id="regularization-techniques" class="level3">
<h3 class="anchored" data-anchor-id="regularization-techniques" id="regularization-techniques">Regularization Techniques</h3>
<p>Common regularizers for the learned univariate functions include:</p>
<section id="smoothness-regularization" class="level4">
<h4 class="anchored" data-anchor-id="smoothness-regularization">Smoothness Regularization</h4>
<p><span class="math display">\[
R_{\text{smooth}} = \sum_{i,j} \int \left(\frac{d^2\phi_{i,j}(x)}{dx^2}\right)^2 dx
\]</span></p>
</section>
<section id="sparsity-regularization" class="level4">
<h4 class="anchored" data-anchor-id="sparsity-regularization">Sparsity Regularization</h4>
<p><span class="math display">\[
R_{\text{sparse}} = \sum_{i,j} \int |\phi_{i,j}(x)| dx
\]</span></p>
</section>
<section id="total-variation-regularization" class="level4">
<h4 class="anchored" data-anchor-id="total-variation-regularization">Total Variation Regularization</h4>
<p><span class="math display">\[
R_{\text{TV}} = \sum_{i,j} \int \left|\frac{d\phi_{i,j}(x)}{dx}\right| dx
\]</span></p>
</section>
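<p>In practice these integrals are approximated numerically. This sketch evaluates one function on a uniform grid and uses finite differences for the derivatives; summing the results over all edge functions <span class="math inline">\(\phi_{i,j}\)</span> gives the totals above. The helper name and grid resolution are illustrative assumptions, and <span class="math inline">\(\phi(x) = x^2\)</span> on <span class="math inline">\([0,1]\)</span> is used because its exact values (4, 1/3, and 1) are easy to check.</p>

```python
from typing import Callable, Tuple

def regularizers(phi: Callable[[float], float], lo: float, hi: float,
                 m: int = 2000) -> Tuple[float, float, float]:
    """Approximate (R_smooth, R_sparse, R_TV) for one function on [lo, hi] with m intervals."""
    h = (hi - lo) / m
    ys = [phi(lo + k * h) for k in range(m + 1)]
    # Riemann sum over interior grid points of the squared second difference (approximates phi'').
    smooth = h * sum(((ys[k + 1] - 2 * ys[k] + ys[k - 1]) / h ** 2) ** 2 for k in range(1, m))
    # Trapezoid rule for the integral of |phi|.
    sparse = h * (sum(abs(y) for y in ys) - 0.5 * (abs(ys[0]) + abs(ys[-1])))
    # Sum of absolute increments, which approximates the integral of |phi'|.
    tv = sum(abs(ys[k + 1] - ys[k]) for k in range(m))
    return smooth, sparse, tv

smooth, sparse, tv = regularizers(lambda x: x * x, 0.0, 1.0)
```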
</section>
</section>
<section id="computational-complexity-analysis" class="level2">
<h2 class="anchored" data-anchor-id="computational-complexity-analysis" id="computational-complexity-analysis">Computational Complexity Analysis</h2>
<section id="parameter-count" class="level3">
<h3 class="anchored" data-anchor-id="parameter-count" id="parameter-count">Parameter Count</h3>
<p>For a CKAN layer with:</p>
<ul>
<li>Input channels: <span class="math inline">\(C_{in}\)</span></li>
<li>Output channels: <span class="math inline">\(C_{out}\)</span></li>
<li>Kernel size: <span class="math inline">\(K \times K\)</span></li>
<li>Spline degree: <span class="math inline">\(d\)</span></li>
<li>Control points per spline: <span class="math inline">\(n\)</span></li>
</ul>
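<p>Counting only spline coefficients (no bias or base-activation weights), the parameter count is: <span class="math display">\[
\text{Parameters} = C_{in} \times C_{out} \times K^2 \times n
\]</span></p>
<p>The spline degree <span class="math inline">\(d\)</span> enters only through <span class="math inline">\(n\)</span>; for example, <span class="math inline">\(M\)</span> uniform grid intervals extended by <span class="math inline">\(d\)</span> knots on each side give <span class="math inline">\(n = M + d\)</span>. Compare this to a traditional CNN (weights only, bias excluded), which has <span class="math inline">\(n\)</span> times fewer parameters: <span class="math display">\[
\text{Parameters}_{\text{CNN}} = C_{in} \times C_{out} \times K^2
\]</span></p>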
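<p>Plugging in illustrative sizes shows the gap. The sketch assumes a hypothetical configuration with 5 grid intervals and cubic splines, so that <span class="math inline">\(n = 8\)</span> under the extended-grid construction; the helper names are hypothetical.</p>

```python
def cnn_params(c_in: int, c_out: int, K: int) -> int:
    """Convolution weights, bias excluded."""
    return c_in * c_out * K * K

def ckan_params(c_in: int, c_out: int, K: int, n: int) -> int:
    """CKAN spline coefficients: one n-coefficient spline per (c_in, c_out, u, v)."""
    return c_in * c_out * K * K * n

grid_size, degree = 5, 3
n = grid_size + degree               # assumes the extended-grid construction: 8 coefficients
cnn = cnn_params(64, 128, 3)         # 73,728
ckan = ckan_params(64, 128, 3, n)    # 589,824, i.e. n = 8 times as many
```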
</section>
<section id="computational-complexity" class="level3">
<h3 class="anchored" data-anchor-id="computational-complexity" id="computational-complexity">Computational Complexity</h3>
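<p>The forward pass complexity for a single CKAN layer, evaluating each <span class="math inline">\(\phi\)</span> as a full sum over its <span class="math inline">\(n\)</span> basis functions, is: <span class="math display">\[
O(H' \times W' \times C_{out} \times C_{in} \times K^2 \times n)
\]</span></p>
<p>where <span class="math inline">\(H' \times W'\)</span> is the spatial size of the output feature map. Because a degree-<span class="math inline">\(d\)</span> B-spline basis has at most <span class="math inline">\(d+1\)</span> nonzero functions at any point, an implementation that sums only the nonzero terms needs <span class="math inline">\(d+1\)</span> multiply-adds per function instead of <span class="math inline">\(n\)</span>, plus the cost of evaluating the basis, which can be shared across output channels when all functions use the same knots.</p>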
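<p>Illustrative sizes give a rough multiply-add count per layer for a <span class="math inline">\(32 \times 32\)</span> output map. This is back-of-the-envelope arithmetic under stated assumptions: it ignores the cost of evaluating the basis functions themselves and treats the sparse case as summing only the <span class="math inline">\(d + 1 = 4\)</span> nonzero cubic B-spline terms per function. The helper name is hypothetical.</p>

```python
def conv_macs(h_out: int, w_out: int, c_in: int, c_out: int, K: int,
              terms_per_edge: int = 1) -> int:
    """Multiply-adds for one layer: 1 term per edge for a CNN, n (dense) or d + 1 (sparse) for a CKAN."""
    return h_out * w_out * c_out * c_in * K * K * terms_per_edge

cnn = conv_macs(32, 32, 64, 128, 3)                             # 75,497,472
ckan_dense = conv_macs(32, 32, 64, 128, 3, terms_per_edge=8)    # all n = 8 basis terms
ckan_sparse = conv_macs(32, 32, 64, 128, 3, terms_per_edge=4)   # only d + 1 = 4 nonzero terms
```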
</section>
</section>
<section id="architectural-variations" class="level2">
<h2 class="anchored" data-anchor-id="architectural-variations" id="architectural-variations">Architectural Variations</h2>
<section id="depthwise-separable-ckans" class="level3">
<h3 class="anchored" data-anchor-id="depthwise-separable-ckans" id="depthwise-separable-ckans">1. Depthwise Separable CKANs</h3>
<p>Inspired by depthwise separable convolutions, this variant separates the operation into:</p>
<p><strong>Depthwise Convolution</strong>: <span class="math display">\[
Y_{i,j,c} = \sum_{u=0}^{K-1} \sum_{v=0}^{K-1} \phi_{c,u,v}(X_{i+u,j+v,c})
\]</span></p>
<p><strong>Pointwise Convolution</strong>: <span class="math display">\[
Z_{i,j,k} = \sum_{c=1}^{C_{in}} \psi_{c,k}(Y_{i,j,c})
\]</span></p>
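<p>Counting coefficients for the two stages shows why this factorization pays off. The sketch assumes each depthwise function <span class="math inline">\(\phi_{c,u,v}\)</span> and each pointwise function <span class="math inline">\(\psi_{c,k}\)</span> is a spline with the same <span class="math inline">\(n\)</span> coefficients; the helper names are hypothetical.</p>

```python
def depthwise_separable_ckan_params(c_in: int, c_out: int, K: int, n: int) -> int:
    depthwise = c_in * K * K * n    # one phi_{c,u,v} per input channel and kernel position
    pointwise = c_in * c_out * n    # one psi_{c,k} per input/output channel pair
    return depthwise + pointwise

def standard_ckan_params(c_in: int, c_out: int, K: int, n: int) -> int:
    return c_in * c_out * K * K * n

sep = depthwise_separable_ckan_params(64, 128, 3, 8)   # 4,608 + 65,536 = 70,144
full = standard_ckan_params(64, 128, 3, 8)             # 589,824
```

<p>For these sizes the separable layer needs roughly 8.4 times fewer coefficients than the full CKAN layer.</p>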
</section>
<section id="dilated-ckans" class="level3">
<h3 class="anchored" data-anchor-id="dilated-ckans" id="dilated-ckans">2. Dilated CKANs</h3>
<p>Incorporating dilation for larger receptive fields: <span class="math display">\[
Y_{i,j,k} = \sum_{c=1}^{C_{in}} \sum_{u=0}^{K-1} \sum_{v=0}^{K-1} \phi_{c,k,u,v}(X_{i+r \cdot u,j+r \cdot v,c})
\]</span></p>
<p>where <span class="math inline">\(r\)</span> is the dilation factor (a distinct symbol from the spline degree <span class="math inline">\(d\)</span>).</p>
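<p>Dilation changes which input positions a kernel reads but not how many functions it has. A small sketch of the index arithmetic, assuming stride 1 and no padding as in the formula; <code>dilation</code> stands for the dilation factor, and the helper names are hypothetical.</p>

```python
from typing import List

def dilated_offsets(K: int, dilation: int) -> List[int]:
    """Input offsets read along one axis: dilation * u for u = 0..K-1."""
    return [dilation * u for u in range(K)]

def output_size(size: int, K: int, dilation: int) -> int:
    # Stride 1, no padding: the kernel spans dilation * (K - 1) + 1 input positions.
    span = dilation * (K - 1) + 1
    return size - span + 1

offsets = dilated_offsets(3, 2)   # [0, 2, 4]: receptive field of 5 per axis with only 3 functions
h_out = output_size(32, 3, 2)     # 28
```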
</section>
<section id="residual-ckans" class="level3">
<h3 class="anchored" data-anchor-id="residual-ckans" id="residual-ckans">3. Residual CKANs</h3>
<p>Combining residual connections with CKAN layers: <span class="math display">\[
Y = \text{CKAN}(X) + \alpha \cdot X
\]</span></p>
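<p>where <span class="math inline">\(\alpha\)</span> is a learnable scaling factor. The sum requires <span class="math inline">\(\text{CKAN}(X)\)</span> to have the same shape as <span class="math inline">\(X\)</span>: <span class="math inline">\(C_{out} = C_{in}\)</span>, with padding chosen so that the spatial size is preserved.</p>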
</section>
</section>
<section id="approximation-properties" class="level2">
<h2 class="anchored" data-anchor-id="approximation-properties" id="approximation-properties">Approximation Properties</h2>
<section id="universal-approximation" class="level3">
<h3 class="anchored" data-anchor-id="universal-approximation" id="universal-approximation">Universal Approximation</h3>
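<p>CKANs inherit the universal approximation properties of KANs: a CKAN whose kernels span the entire input reduces to a standard KAN over all pixels and channels. Hence, for any continuous function <span class="math inline">\(f: \Omega \to \mathbb{R}\)</span> on a compact domain <span class="math inline">\(\Omega \subset \mathbb{R}^n\)</span> and any <span class="math inline">\(\epsilon &gt; 0\)</span>, there exists a CKAN that approximates <span class="math inline">\(f\)</span> uniformly on <span class="math inline">\(\Omega\)</span> to within <span class="math inline">\(\epsilon\)</span>.</p>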
</section>
<section id="convergence-analysis" class="level3">
<h3 class="anchored" data-anchor-id="convergence-analysis" id="convergence-analysis">Convergence Analysis</h3>
<p>The convergence rate of CKANs depends on several factors:</p>
<ol type="1">
<li><strong>Smoothness of target function</strong>: Smoother functions converge faster</li>
<li><strong>Spline degree</strong>: Higher degree splines provide better approximation but may overfit</li>
<li><strong>Number of control points</strong>: More control points increase expressivity but also computational cost</li>
</ol>
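<p>For a univariate function <span class="math inline">\(f\)</span> with <span class="math inline">\(s\)</span>-th order smoothness, the best spline approximation <span class="math inline">\(\hat{f}_h\)</span> of degree <span class="math inline">\(d\)</span> on a grid with knot spacing <span class="math inline">\(h\)</span> satisfies: <span class="math display">\[
\|f - \hat{f}_h\|_\infty \leq C \cdot h^{\min(s,\, d+1)}
\]</span></p>
<p>where <span class="math inline">\(C\)</span> is a constant depending on <span class="math inline">\(f\)</span> and <span class="math inline">\(d\)</span>. The exponent is capped at <span class="math inline">\(d+1\)</span>, so raising the spline degree only helps up to the smoothness of the target. Error bounds for a full CKAN can be assembled from per-function estimates of this kind.</p>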
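<p>The rate can be observed numerically. For degree-<span class="math inline">\(d\)</span> splines the achievable order is capped at <span class="math inline">\(d+1\)</span>, so for a very smooth target such as <span class="math inline">\(\sin x\)</span> on <span class="math inline">\([0, \pi]\)</span> a degree-1 spline should give an error proportional to <span class="math inline">\(h^2\)</span>: halving the knot spacing should cut the maximum error by roughly 4. As a stand-in for a learned spline, this sketch uses piecewise-linear interpolation, which is not the best approximation but attains the same order; the helper name is hypothetical.</p>

```python
import math
from typing import Callable

def max_interp_error(f: Callable[[float], float], lo: float, hi: float, m: int,
                     samples: int = 20000) -> float:
    """Max error of piecewise-linear interpolation of f on m uniform intervals (knot spacing h)."""
    h = (hi - lo) / m
    knots = [lo + k * h for k in range(m + 1)]
    vals = [f(x) for x in knots]
    worst = 0.0
    for s in range(samples + 1):
        x = lo + (hi - lo) * s / samples
        k = min(int((x - lo) / h), m - 1)   # index of the interval containing x
        w = (x - knots[k]) / h              # relative position inside that interval
        approx = (1 - w) * vals[k] + w * vals[k + 1]
        worst = max(worst, abs(f(x) - approx))
    return worst

errors = [max_interp_error(math.sin, 0.0, math.pi, m) for m in (8, 16, 32, 64)]
ratios = [errors[j] / errors[j + 1] for j in range(len(errors) - 1)]   # each close to 4
```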
</section>
</section>
<section id="practical-implementation-considerations" class="level2">
<h2 class="anchored" data-anchor-id="practical-implementation-considerations" id="practical-implementation-considerations">Practical Implementation Considerations</h2>
<section id="numerical-stability" class="level3">
<h3 class="anchored" data-anchor-id="numerical-stability" id="numerical-stability">Numerical Stability</h3>
<p>CKANs require careful attention to numerical stability:</p>
<ol type="1">
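<li><strong>Spline knot placement</strong>: Uniform or adaptive (data-dependent) knot placement; because B-splines vanish outside their knot range, the grid must cover the range of values the layer actually receives</li>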
<li><strong>Coefficient initialization</strong>: Proper initialization of spline coefficients</li>
<li><strong>Gradient clipping</strong>: Preventing gradient explosion during backpropagation</li>
</ol>
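<p>One way to make knot placement adaptive is to blend a uniform grid with knots placed at empirical quantiles of the observed inputs, so that dense regions of the data receive more knots while the grid still spans the full range. This is a hypothetical helper illustrating the idea; the blending weight <code>eps</code> and its default are assumptions.</p>

```python
from typing import List

def adaptive_knots(samples: List[float], grid_size: int, eps: float = 0.02) -> List[float]:
    """Blend uniform and quantile-based knots (hypothetical helper).

    eps = 1 gives a purely uniform grid, eps = 0 a purely quantile-based one.
    Returns grid_size + 1 knots spanning the observed range of the samples.
    """
    xs = sorted(samples)
    lo, hi = xs[0], xs[-1]
    knots = []
    for j in range(grid_size + 1):
        uniform = lo + (hi - lo) * j / grid_size
        quantile = xs[round(j * (len(xs) - 1) / grid_size)]   # empirical j/grid_size quantile
        knots.append(eps * uniform + (1 - eps) * quantile)
    return knots

# Skewed activations: most values in [0, 1) plus two large outliers.
data = [0.01 * k for k in range(100)] + [10.0, 20.0]
knots = adaptive_knots(data, grid_size=4)   # interior knots cluster near the dense region
```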
</section>
<section id="memory-optimization" class="level3">
<h3 class="anchored" data-anchor-id="memory-optimization" id="memory-optimization">Memory Optimization</h3>
<p>Several techniques can reduce memory usage:</p>
<ol type="1">
<li><strong>Lazy evaluation</strong>: Computing spline values on-demand</li>
<li><strong>Coefficient sharing</strong>: Sharing coefficients across similar functions</li>
<li><strong>Quantization</strong>: Using lower precision for spline coefficients</li>
</ol>
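<p>A minimal sketch of the quantization idea, assuming symmetric per-tensor int8 quantization of one function's coefficients (one of several possible schemes): storage per coefficient drops from 4 bytes (float32) to 1 byte plus a shared scale, and the rounding error per coefficient is at most half the scale. The helper names are hypothetical.</p>

```python
from typing import List, Tuple

def quantize_int8(coeffs: List[float]) -> Tuple[List[int], float]:
    """Symmetric int8 quantization: q = round(c / scale), with scale = max|c| / 127."""
    scale = max(abs(c) for c in coeffs) / 127 or 1.0   # fall back to 1.0 for all-zero input
    return [max(-127, min(127, round(c / scale))) for c in coeffs], scale

def dequantize(q: List[int], scale: float) -> List[float]:
    return [v * scale for v in q]

coeffs = [0.83, -1.27, 0.05, 2.54, -0.31, 1.9, 0.0, -2.2]
q, scale = quantize_int8(coeffs)
restored = dequantize(q, scale)
max_err = max(abs(a - b) for a, b in zip(coeffs, restored))   # bounded by scale / 2
```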
</section>
</section>
<section id="comparison-with-traditional-cnns" class="level2">
<h2 class="anchored" data-anchor-id="comparison-with-traditional-cnns" id="comparison-with-traditional-cnns">Comparison with Traditional CNNs</h2>
<section id="expressivity" class="level3">
<h3 class="anchored" data-anchor-id="expressivity" id="expressivity">Expressivity</h3>
<p>Per connection, CKANs can be more expressive than CNNs because of:</p>
<ul>
<li>Learnable activation functions</li>
<li>Non-linear transformations in each connection</li>
<li>Adaptive function shapes based on data</li>
</ul>
</section>
<section id="interpretability" class="level3">
<h3 class="anchored" data-anchor-id="interpretability" id="interpretability">Interpretability</h3>
<p>The univariate nature of CKAN functions provides better interpretability:</p>
<ul>
<li>Each function can be visualized as a 1D curve</li>
<li>Function shapes reveal learned patterns</li>
<li>Easier to understand feature transformations</li>
</ul>
</section>
<section id="computational-trade-offs" class="level3">
<h3 class="anchored" data-anchor-id="computational-trade-offs" id="computational-trade-offs">Computational Trade-offs</h3>
<p><strong>Advantages</strong>:</p>
<ul>
<li>Potentially better function approximation with fewer layers</li>
<li>Interpretable learned functions</li>
<li>Approximation guarantees grounded in spline theory and the Kolmogorov-Arnold representation theorem</li>
</ul>
<p><strong>Disadvantages</strong>:</p>
<ul>
<li>Higher computational cost per layer</li>
<li>More parameters to optimize</li>
<li>Longer training times</li>
</ul>
</section>
</section>
<section id="future-directions-and-extensions" class="level2">
<h2 class="anchored" data-anchor-id="future-directions-and-extensions" id="future-directions-and-extensions">Future Directions and Extensions</h2>
<section id="theoretical-advances" class="level3">
<h3 class="anchored" data-anchor-id="theoretical-advances" id="theoretical-advances">Theoretical Advances</h3>
<ol type="1">
<li><strong>Convergence guarantees</strong>: Developing stronger theoretical guarantees for CKAN convergence</li>
<li><strong>Optimal architectures</strong>: Finding optimal CKAN architectures for specific tasks</li>
<li><strong>Generalization bounds</strong>: Establishing generalization bounds for CKANs</li>
</ol>
</section>
<section id="practical-improvements" class="level3">
<h3 class="anchored" data-anchor-id="practical-improvements" id="practical-improvements">Practical Improvements</h3>
<ol type="1">
<li><strong>Efficient implementations</strong>: Developing more efficient CUDA kernels for CKAN operations</li>
<li><strong>Automated architecture search</strong>: Using neural architecture search for CKAN design</li>
<li><strong>Hardware acceleration</strong>: Designing specialized hardware for CKAN computations</li>
</ol>
</section>
<section id="applications" class="level3">
<h3 class="anchored" data-anchor-id="applications" id="applications">Applications</h3>
<ol type="1">
<li><strong>Computer vision</strong>: Image classification, object detection, segmentation</li>
<li><strong>Scientific computing</strong>: Solving partial differential equations</li>
<li><strong>Signal processing</strong>: Audio and video processing applications</li>
</ol>
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>Convolutional Kolmogorov-Arnold Networks represent a significant advancement in neural network architectures, combining solid theoretical foundations with practical convolutional operations. While computationally more expensive than traditional CNNs, CKANs offer superior expressivity, interpretability, and theoretical guarantees. As the field continues to evolve, we can expect further optimizations and novel applications of this powerful architecture.</p>
<p>The mathematics behind CKANs reveals a rich interplay between approximation theory, spline functions, and deep learning, opening new avenues for both theoretical understanding and practical applications in machine learning.</p>



</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Convolutional Kolmogorov-Arnold Networks vs Convolutional Neural Networks: A Comprehensive Analysis]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/models/convkan/ckan-vs-cnn/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/models/convkan/ckan-vs-cnn/</guid>
      <pubDate>Sat, 05 Jul 2025 00:00:00 GMT</pubDate>
      
      <category>research</category>
      <category>advanced</category>
      <content:encoded><![CDATA[






<section id="convolutional-kolmogorov-arnold-networks-vs-convolutional-neural-networks-a-comprehensive-analysis" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/models/convkan/ckan-vs-cnn/ckan-vs-cnn.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>The landscape of deep learning has been revolutionized by Convolutional Neural Networks (CNNs), which have dominated computer vision tasks for over a decade. However, a new paradigm has emerged that challenges the fundamental assumptions of traditional neural architectures: Convolutional Kolmogorov-Arnold Networks (Convolutional KANs). This innovative approach represents a significant departure from conventional neural network design, offering enhanced parameter efficiency, interpretability, and expressive power.</p>
</section>
<section id="theoretical-foundation" class="level2">
<h2 class="anchored" data-anchor-id="theoretical-foundation" id="theoretical-foundation">Theoretical Foundation</h2>
<section id="the-kolmogorov-arnold-representation-theorem" class="level3">
<h3 class="anchored" data-anchor-id="the-kolmogorov-arnold-representation-theorem" id="the-kolmogorov-arnold-representation-theorem">The Kolmogorov-Arnold Representation Theorem</h3>
<p>The theoretical foundation of KANs lies in the Kolmogorov-Arnold representation theorem, which states that any continuous multivariate function on a compact domain (such as the unit cube <span class="math inline">\([0,1]^n\)</span>) can be represented as a finite composition of continuous functions of a single variable and the binary operation of addition. This result suggests an alternative to the traditional multi-layer perceptron (MLP) design and provides the basis for a new class of neural networks.</p>
<p>The theorem can be formally expressed as:</p>
<p><span class="math display">\[
f(x_1, x_2, \ldots, x_n) = \sum_{q=0}^{2n} \Phi_q\left( \sum_{p=1}^{n} \phi_{q,p}(x_p) \right)
\]</span></p>
<p>where <span class="math inline">\(\Phi_q\)</span> and <span class="math inline">\(\phi_{q,p}\)</span> are continuous univariate functions, and <span class="math inline">\(f\)</span> is the multivariate function being represented.</p>
</section>
<section id="architectural-differences" class="level3">
<h3 class="anchored" data-anchor-id="architectural-differences" id="architectural-differences">Architectural Differences</h3>
<p>The fundamental architectural difference between traditional neural networks and KANs lies in the placement and nature of activation functions:</p>
<ul>
<li><strong>Traditional MLPs/CNNs</strong>: Fixed activation functions on nodes (neurons), with linear weights on edges</li>
<li><strong>KANs</strong>: Learnable activation functions on edges, which replace the linear weights entirely; nodes simply sum their incoming values</li>
</ul>
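<p>The difference is easy to see at a single node. In this sketch (hypothetical names, with hand-picked functions standing in for learned ones), an MLP node applies a fixed non-linearity to a weighted sum, while a KAN node sums the outputs of per-edge univariate functions; choosing linear edge functions <span class="math inline">\(\phi_i(x) = w_i x\)</span> recovers the MLP node's pre-activation.</p>

```python
import math
from typing import Callable, List

def mlp_node(x: List[float], w: List[float],
             act: Callable[[float], float] = math.tanh) -> float:
    # Linear weights on the edges, a fixed non-linearity on the node.
    return act(sum(wi * xi for wi, xi in zip(w, x)))

def kan_node(x: List[float], phis: List[Callable[[float], float]]) -> float:
    # A learnable univariate function on every edge, plain summation on the node.
    return sum(phi(xi) for phi, xi in zip(phis, x))

x = [0.5, -1.0, 2.0]
w = [0.2, 0.4, -0.1]
linear_edges = [lambda t, wi=wi: wi * t for wi in w]   # phi_i(x) = w_i * x
curved_edges = [math.sin, lambda t: t * t, math.tanh]  # non-linear edges bend each input separately
curved = kan_node(x, curved_edges)
```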
</section>
</section>
<section id="convolutional-neural-networks-the-established-paradigm" class="level2">
<h2 class="anchored" data-anchor-id="convolutional-neural-networks-the-established-paradigm" id="convolutional-neural-networks-the-established-paradigm">Convolutional Neural Networks: The Established Paradigm</h2>
<section id="architecture-overview" class="level3">
<h3 class="anchored" data-anchor-id="architecture-overview" id="architecture-overview">Architecture Overview</h3>
<p>CNNs have been the backbone of computer vision applications since their breakthrough in the early 2010s. The typical CNN architecture consists of:</p>
<ol type="1">
<li><strong>Convolutional Layers</strong>: Slide learned kernels over the input, computing a linear weighted sum at each location</li>
<li><strong>Activation Functions</strong>: Non-linear functions (ReLU, sigmoid, tanh) applied to neurons</li>
<li><strong>Pooling Layers</strong>: Downsample feature maps to reduce computational complexity</li>
<li><strong>Fully Connected Layers</strong>: Dense layers for final classification or regression</li>
</ol>
</section>
<section id="key-characteristics" class="level3">
<h3 class="anchored" data-anchor-id="key-characteristics" id="key-characteristics">Key Characteristics</h3>
<ul>
<li><strong>Parameter Sharing</strong>: Convolutional kernels share weights across spatial locations</li>
<li><strong>Translation Equivariance</strong>: Features are detected wherever they appear in the input; pooling adds a degree of translation invariance</li>
<li><strong>Hierarchical Feature Learning</strong>: Progressive abstraction from low-level to high-level features</li>
<li><strong>Fixed Activation Functions</strong>: Predetermined non-linear functions that remain constant during training</li>
</ul>
</section>
<section id="limitations" class="level3">
<h3 class="anchored" data-anchor-id="limitations" id="limitations">Limitations</h3>
<p>Despite their success, CNNs face several inherent limitations:</p>
<ol type="1">
<li><strong>Parameter Inefficiency</strong>: Large numbers of parameters required for complex tasks</li>
<li><strong>Limited Interpretability</strong>: Black-box nature makes understanding difficult</li>
<li><strong>Fixed Representational Capacity</strong>: Predetermined activation functions limit adaptability</li>
<li><strong>Scaling Challenges</strong>: Performance improvements often require significantly larger models</li>
</ol>
</section>
</section>
<section id="convolutional-kolmogorov-arnold-networks-the-new-paradigm" class="level2">
<h2 class="anchored" data-anchor-id="convolutional-kolmogorov-arnold-networks-the-new-paradigm" id="convolutional-kolmogorov-arnold-networks-the-new-paradigm">Convolutional Kolmogorov-Arnold Networks: The New Paradigm</h2>
<section id="architecture-innovation" class="level3">
<h3 class="anchored" data-anchor-id="architecture-innovation" id="architecture-innovation">Architecture Innovation</h3>
<p>Convolutional KANs take a different approach to neural network design: they replace the scalar weights of convolutional kernels with learnable non-linear functions. The key innovations include:</p>
<ol type="1">
<li><strong>Spline-Based Convolutional Layers</strong>: Replace scalar linear weights with learnable spline functions</li>
<li><strong>Edge-Based Activation</strong>: Activation functions are learned on the connections between neurons</li>
<li><strong>Adaptive Kernel Functions</strong>: Convolutional operations with learnable, non-linear transformations</li>
<li><strong>Flexible Representational Power</strong>: Ability to adapt the network’s fundamental computational primitives</li>
</ol>
</section>
<section id="technical-implementation" class="level3">
<h3 class="anchored" data-anchor-id="technical-implementation" id="technical-implementation">Technical Implementation</h3>
<p>In Convolutional KANs, every weight parameter is replaced by a univariate function parametrized as a B-spline. The spline functions provide:</p>
<ul>
<li><strong>Continuous Differentiability</strong>: Smooth gradients for effective backpropagation</li>
<li><strong>Local Control</strong>: Ability to modify function behavior in specific regions</li>
<li><strong>Efficient Representation</strong>: Compact parametrization of complex functions</li>
<li><strong>Numerical Stability</strong>: B-spline basis functions are nonnegative and sum to one inside the grid, which keeps evaluation well-conditioned</li>
</ul>
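<p>Local control follows from the fact that the B-spline <span class="math inline">\(B_i^d\)</span> is nonzero only on <span class="math inline">\([t_i, t_{i+d+1})\)</span>, so changing coefficient <span class="math inline">\(c_i\)</span> leaves the function unchanged everywhere else. The sketch below checks this for a cubic spline on uniform knots; the helpers are a plain Cox-de Boor evaluation with hypothetical names, not a library API.</p>

```python
from typing import List

def basis(i: int, d: int, t: List[float], x: float) -> float:
    """Cox-de Boor recursion for B_i^d(x); the knots here are distinct, so no 0/0 cases arise."""
    if d == 0:
        return 1.0 if t[i] <= x < t[i + 1] else 0.0
    return ((x - t[i]) / (t[i + d] - t[i]) * basis(i, d - 1, t, x)
            + (t[i + d + 1] - x) / (t[i + d + 1] - t[i + 1]) * basis(i + 1, d - 1, t, x))

def spline(c: List[float], d: int, t: List[float], x: float) -> float:
    return sum(ci * basis(i, d, t, x) for i, ci in enumerate(c))

d = 3
t = [float(j) for j in range(12)]          # knots 0..11 -> 12 - d - 1 = 8 coefficients
c = [0.5, -0.2, 1.0, 0.3, -0.8, 0.6, 0.1, -0.4]
c_edit = list(c)
c_edit[2] += 1.0                           # change a single coefficient
xs = [0.25 * k for k in range(44)]         # sample points 0.0 .. 10.75
# B_2^3 is supported on [t_2, t_6) = [2, 6), so only samples in that interval may change.
changed = [x for x in xs if spline(c, d, t, x) != spline(c_edit, d, t, x)]
```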
</section>
<section id="architectural-flexibility" class="level3">
<h3 class="anchored" data-anchor-id="architectural-flexibility" id="architectural-flexibility">Architectural Flexibility</h3>
<p>The Convolutional KAN architecture allows for various configurations:</p>
<ul>
<li><strong>Hybrid Approaches</strong>: Combination of KAN convolutional layers with traditional MLPs</li>
<li><strong>Full KAN Networks</strong>: Complete replacement of traditional layers with KAN equivalents</li>
<li><strong>Scalable Design</strong>: Adaptable to different problem complexities and computational constraints</li>
</ul>
</section>
</section>
<section id="comparative-analysis" class="level2">
<h2 class="anchored" data-anchor-id="comparative-analysis" id="comparative-analysis">Comparative Analysis</h2>
<section id="parameter-efficiency" class="level3">
<h3 class="anchored" data-anchor-id="parameter-efficiency" id="parameter-efficiency">Parameter Efficiency</h3>
<p>One frequently cited advantage of Convolutional KANs is parameter efficiency: some studies report that Convolutional KANs reach accuracy comparable to traditional CNNs with fewer total parameters. This efficiency is attributed to:</p>
<ol type="1">
<li><strong>Learnable Function Approximation</strong>: Spline-based functions can represent complex transformations with fewer parameters</li>
<li><strong>Adaptive Representation</strong>: Network can learn optimal activation functions for specific tasks</li>
<li><strong>Fewer, More Expressive Connections</strong>: Because each connection is itself a learnable non-linear function, a network may need fewer channels or layers</li>
</ol>
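<p>Because each connection carries <span class="math inline">\(n\)</span> spline coefficients instead of one weight, any parameter savings must come from a smaller network rather than from cheaper connections. The arithmetic below compares a standard <span class="math inline">\(3 \times 3\)</span> convolution from 32 to 64 channels with a hypothetical, narrower CKAN layer from 8 to 16 channels using <span class="math inline">\(n = 8\)</span>; whether the narrower model matches the accuracy of the wider one is an empirical question this calculation does not answer. The helper names are hypothetical.</p>

```python
def conv_weights(c_in: int, c_out: int, K: int) -> int:
    return c_in * c_out * K * K          # weights only, bias excluded

def ckan_coeffs(c_in: int, c_out: int, K: int, n: int) -> int:
    return c_in * c_out * K * K * n      # one n-coefficient spline per connection

cnn = conv_weights(32, 64, 3)            # 18,432
ckan_same_width = ckan_coeffs(32, 64, 3, 8)   # 147,456: same width costs n times more
ckan_narrow = ckan_coeffs(8, 16, 3, 8)        # 9,216: fewer only because the layer is narrower
```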
</section>
<section id="expressive-power" class="level3">
<h3 class="anchored" data-anchor-id="expressive-power" id="expressive-power">Expressive Power</h3>
<p>Convolutional KANs can offer greater expressive power per connection through:</p>
<ol type="1">
<li><strong>Adaptive Activation Functions</strong>: Ability to learn task-specific non-linearities</li>
<li><strong>Enhanced Function Approximation</strong>: Theoretical foundation in the Kolmogorov-Arnold representation theorem</li>
<li><strong>Flexible Computational Primitives</strong>: Learnable spline functions provide greater representational capacity</li>
</ol>
</section>
<section id="interpretability" class="level3">
<h3 class="anchored" data-anchor-id="interpretability" id="interpretability">Interpretability</h3>
<p>KANs provide enhanced interpretability compared to traditional CNNs:</p>
<ol type="1">
<li><strong>Visualizable Functions</strong>: Learned spline functions can be directly visualized and analyzed</li>
<li><strong>Human Interaction</strong>: Easier to understand and modify network behavior</li>
<li><strong>Mathematical Transparency</strong>: Clear mathematical foundation enables better analysis</li>
</ol>
</section>
<section id="performance-characteristics" class="level3">
<h3 class="anchored" data-anchor-id="performance-characteristics" id="performance-characteristics">Performance Characteristics</h3>
<p>Some empirical evaluations report that Convolutional KANs can achieve:</p>
<ul>
<li><strong>Comparable or Superior Accuracy</strong>: Matching or exceeding CNN performance on some tasks</li>
<li><strong>Faster Neural Scaling Laws</strong>: More efficient scaling with model size, a result reported for KANs on function-fitting tasks and less established for convolutional variants</li>
<li><strong>Better Generalization</strong>: Improved performance on unseen data</li>
</ul>
</section>
</section>
<section id="practical-applications-and-limitations" class="level2">
<h2 class="anchored" data-anchor-id="practical-applications-and-limitations" id="practical-applications-and-limitations">Practical Applications and Limitations</h2>
<section id="suitable-applications" class="level3">
<h3 class="anchored" data-anchor-id="suitable-applications" id="suitable-applications">Suitable Applications</h3>
<p>Convolutional KANs are particularly well-suited for:</p>
<ol type="1">
<li><strong>Computer Vision Tasks</strong>: Image classification, object detection, segmentation</li>
<li><strong>Pattern Recognition</strong>: Complex pattern matching with adaptive feature extraction</li>
<li><strong>Scientific Computing</strong>: Problems requiring interpretable and efficient models</li>
<li><strong>Resource-Constrained Environments</strong>: Applications where a small parameter budget matters, keeping in mind that spline evaluation adds compute per connection</li>
</ol>
</section>
<section id="current-limitations" class="level3">
<h3 class="anchored" data-anchor-id="current-limitations" id="current-limitations">Current Limitations</h3>
<p>Despite their advantages, Convolutional KANs face certain challenges:</p>
<ol type="1">
<li><strong>Computational Complexity</strong>: Spline function evaluation may increase computational overhead</li>
<li><strong>Training Complexity</strong>: More complex optimization landscape due to learnable activation functions</li>
<li><strong>Limited Ecosystem</strong>: Fewer available tools and implementations compared to CNNs</li>
<li><strong>Scaling Challenges</strong>: Performance on very large-scale problems remains to be fully validated</li>
</ol>
</section>
</section>
<section id="implementation-considerations" class="level2">
<h2 class="anchored" data-anchor-id="implementation-considerations" id="implementation-considerations">Implementation Considerations</h2>
<section id="training-strategies" class="level3">
<h3 class="anchored" data-anchor-id="training-strategies" id="training-strategies">Training Strategies</h3>
<p>Effective training of Convolutional KANs requires:</p>
<ol type="1">
<li><strong>Careful Initialization</strong>: Proper initialization of spline parameters</li>
<li><strong>Adaptive Learning Rates</strong>: Different learning rates for different parameter types</li>
<li><strong>Regularization Techniques</strong>: Preventing overfitting in the learnable activation functions</li>
<li><strong>Numerical Stability</strong>: Ensuring stable spline function evaluation</li>
</ol>
</section>
<section id="hyperparameter-tuning" class="level3">
<h3 class="anchored" data-anchor-id="hyperparameter-tuning" id="hyperparameter-tuning">Hyperparameter Tuning</h3>
<p>Key hyperparameters include:</p>
<ul>
<li><strong>Spline Order</strong>: Degree of the B-spline basis functions</li>
<li><strong>Grid Size</strong>: Number of grid intervals, which together with the spline order sets the number of coefficients per spline function</li>
<li><strong>Regularization Strength</strong>: Balance between fitting and smoothness</li>
<li><strong>Learning Rate Schedules</strong>: Optimization strategy for different parameter types</li>
</ul>
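<p>The relationship between grid size, spline order, and coefficient count can be made concrete with a common extended-knot construction: <code>grid_size</code> uniform intervals on an input range, padded with <code>spline_order</code> extra knots on each side. This is a sketch under that assumption; the default range <span class="math inline">\([-1, 1]\)</span> and the helper names are illustrative.</p>

```python
from typing import List

def extended_grid(grid_size: int, spline_order: int,
                  lo: float = -1.0, hi: float = 1.0) -> List[float]:
    """Uniform knots on [lo, hi] (grid_size intervals) plus spline_order extra knots on each side."""
    h = (hi - lo) / grid_size
    return [lo + j * h for j in range(-spline_order, grid_size + spline_order + 1)]

def coefficients_per_function(grid_size: int, spline_order: int) -> int:
    knots = extended_grid(grid_size, spline_order)
    return len(knots) - spline_order - 1   # number of B-splines = knots - degree - 1

n5 = coefficients_per_function(grid_size=5, spline_order=3)     # 8
n10 = coefficients_per_function(grid_size=10, spline_order=3)   # 13
```

<p>Under this construction the coefficient count is simply grid size plus spline order, so doubling the grid adds coefficients linearly to every edge function in the network.</p>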
</section>
</section>
<section id="future-directions-and-research-opportunities" class="level2">
<h2 class="anchored" data-anchor-id="future-directions-and-research-opportunities" id="future-directions-and-research-opportunities">Future Directions and Research Opportunities</h2>
<section id="emerging-research-areas" class="level3">
<h3 class="anchored" data-anchor-id="emerging-research-areas" id="emerging-research-areas">Emerging Research Areas</h3>
<ol type="1">
<li><strong>Hybrid Architectures</strong>: Combining KANs with other neural network paradigms</li>
<li><strong>Specialized Applications</strong>: Domain-specific adaptations of Convolutional KANs</li>
<li><strong>Optimization Techniques</strong>: Novel training methods for improved efficiency</li>
<li><strong>Theoretical Analysis</strong>: Deeper understanding of KAN properties and capabilities</li>
</ol>
</section>
<section id="potential-developments" class="level3">
<h3 class="anchored" data-anchor-id="potential-developments" id="potential-developments">Potential Developments</h3>
<ol type="1">
<li><strong>Hardware Acceleration</strong>: Specialized hardware for efficient KAN computation</li>
<li><strong>AutoML Integration</strong>: Automated design and optimization of KAN architectures</li>
<li><strong>Large-Scale Applications</strong>: Scaling to very large datasets and complex problems</li>
<li><strong>Transfer Learning</strong>: Adapting pre-trained KAN models to new tasks</li>
</ol>
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>Convolutional Kolmogorov-Arnold Networks represent a paradigm shift in neural network design, with potential advantages over traditional CNNs in parameter efficiency, interpretability, and expressive power. CNNs have proven their worth over the past decade, but Convolutional KANs provide a compelling alternative that targets some of their limitations, most notably fixed activation functions and limited interpretability.</p>
<p>The key advantages of Convolutional KANs include their theoretical grounding in the Kolmogorov-Arnold representation theorem, potential parameter efficiency, improved interpretability, and adaptive representational capacity. However, challenges remain: higher computational cost, less mature training strategies, and limited validation at scale.</p>
<p>As research advances, Convolutional KANs may become increasingly important in the deep learning landscape, particularly for applications that require efficient, interpretable, high-performance models. The choice between CNNs and Convolutional KANs will ultimately depend on the application's requirements, its computational constraints, and how much interpretability matters in the given domain.</p>
<p>The continued development and adoption of Kolmogorov-Arnold Networks may well shape the future of computer vision and deep learning, opening a new chapter in the evolution of neural network architectures.</p>



</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Kolmogorov-Arnold Networks: Complete Implementation Guide]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/models/kan/kan-code/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/models/kan/kan-code/</guid>
      <pubDate>Wed, 02 Jul 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>tutorial</category>
      <category>advanced</category>
      <content:encoded><![CDATA[






<section id="kolmogorov-arnold-networks-complete-implementation-guide" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/models/kan/kan-code/kan.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>Kolmogorov-Arnold Networks (KANs) represent a paradigm shift in neural network architecture. They draw their inspiration from the Kolmogorov-Arnold representation theorem, which Andrey Kolmogorov and Vladimir Arnold established in the 1950s. Traditional Multi-Layer Perceptrons (MLPs) learn linear weights on the edges between neurons and apply fixed activation functions at the nodes. KANs instead place learnable univariate activation functions on the edges, which fundamentally changes how the network processes and learns from data.</p>
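<p>For reference, the theorem states that any continuous function <span class="math inline">\(f\)</span> on <span class="math inline">\([0,1]^n\)</span> can be written as a finite sum of compositions of continuous univariate functions:</p>
<p><span class="math display">\[f(x_1, \dots, x_n) = \sum_{q=0}^{2n} \Phi_q\!\left(\sum_{p=1}^{n} \phi_{q,p}(x_p)\right)\]</span></p>
<p>KANs generalize this fixed two-level structure to networks of arbitrary width and depth. Each univariate function in such a network is learned from data.</p>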
</section>
<section id="architecture-overview" class="level2">
<h2 class="anchored" data-anchor-id="architecture-overview" id="architecture-overview">Architecture Overview</h2>
<section id="traditional-mlps-vs-kans" class="level3">
<h3 class="anchored" data-anchor-id="traditional-mlps-vs-kans" id="traditional-mlps-vs-kans">Traditional MLPs vs KANs</h3>
<p><strong>Multi-Layer Perceptrons (MLPs):</strong></p>
<ul>
<li>Learnable parameters: weights on edges and biases on nodes</li>
<li>Fixed activation functions (ReLU, sigmoid, etc.) applied at the nodes</li>
<li>Linear transformations followed by pointwise nonlinearities</li>
</ul>
<p><strong>Kolmogorov-Arnold Networks (KANs):</strong></p>
<ul>
<li>Learnable parameters: activation functions on edges</li>
<li>No separate weight matrices in the basic formulation (many implementations, including the one below, add a residual linear term)</li>
<li>Each edge has its own learnable univariate function</li>
</ul>
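<p>Replacing a scalar edge weight with a whole learnable function changes the parameter budget. The sketch below counts the parameters of a single layer. It assumes that each edge spline has <code>grid_size + spline_order</code> coefficients, which is the count used by <code>BSplineActivation</code> in this post, and it ignores any extra scale or base weights. Both helper names are hypothetical.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import torch.nn as nn


def mlp_layer_params(in_dim, out_dim):
    # nn.Linear stores an (out_dim, in_dim) weight matrix plus out_dim biases
    return sum(p.numel() for p in nn.Linear(in_dim, out_dim).parameters())


def kan_layer_params(in_dim, out_dim, grid_size=5, spline_order=3):
    # Hypothetical count: one spline per edge, each with grid_size + spline_order
    # coefficients (extra per-edge scales or base weights are ignored)
    return in_dim * out_dim * (grid_size + spline_order)


print(mlp_layer_params(4, 8))  # 40
print(kan_layer_params(4, 8))  # 256</code></pre></div></div>
<p>At equal width, a KAN layer therefore carries several times more parameters than a linear layer. Any overall efficiency gain would have to come from needing fewer or narrower layers.</p>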
</section>
</section>
<section id="mathematical-formulation" class="level2">
<h2 class="anchored" data-anchor-id="mathematical-formulation" id="mathematical-formulation">Mathematical Formulation</h2>
<section id="layer-wise-computation" class="level3">
<h3 class="anchored" data-anchor-id="layer-wise-computation" id="layer-wise-computation">Layer-wise Computation</h3>
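<p>For a KAN with <span class="math inline">\(L\)</span> layers, each output of layer <span class="math inline">\(l\)</span> is a sum of learnable univariate functions, one per incoming edge: <span class="math inline">\(x_{l+1,j} = \sum_{i} \phi_{l,i,j}(x_{l,i})\)</span>. Here <span class="math inline">\(\phi_{l,i,j}\)</span> is the function on the edge from input <span class="math inline">\(i\)</span> to output <span class="math inline">\(j\)</span>. Written out directly, one layer computes:</p>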
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Pseudocode for KAN layer computation</span></span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> kan_layer_forward(x, phi_functions):</span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a><span class="co">    x: input tensor of shape (batch_size, input_dim)</span></span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a><span class="co">    phi_functions: learnable univariate functions for each edge</span></span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a>    output <span class="op">=</span> torch.zeros(batch_size, output_dim)</span>
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb1-9"><a href="#cb1-9" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(input_dim):</span>
<span id="cb1-10"><a href="#cb1-10" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> j <span class="kw">in</span> <span class="bu">range</span>(output_dim):</span>
<span id="cb1-11"><a href="#cb1-11" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Apply learnable activation function φ_{i,j} to input x_i</span></span>
<span id="cb1-12"><a href="#cb1-12" aria-hidden="true" tabindex="-1"></a>            output[:, j] <span class="op">+=</span> phi_functions[i][j](x[:, i])</span>
<span id="cb1-13"><a href="#cb1-13" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb1-14"><a href="#cb1-14" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> output</span></code></pre></div></div>
</section>
<section id="learnable-activation-functions" class="level3">
<h3 class="anchored" data-anchor-id="learnable-activation-functions" id="learnable-activation-functions">Learnable Activation Functions</h3>
<p>The core innovation of KANs lies in the learnable activation functions. These are typically implemented using:</p>
<ol type="1">
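<li><strong>B-splines</strong>: Piecewise polynomial functions that provide smooth, differentiable approximations. Each basis function is nonzero on only a few grid intervals, so any input activates at most <code>spline_order + 1</code> of them.</li>
<li><strong>Residual connections</strong>: Allow the network to learn both the spline component and a base function, such as a SiLU or linear term, that the spline refines</li>
<li><strong>Grid-based parameterization</strong>: The spline knots lie on a fixed grid over the expected input range, so each activation is a short vector of coefficients, one per basis function. Because the function is linear in these coefficients, evaluation is cheap and their gradients are straightforward.</li>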
</ol>
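<p>The sketch below combines the three ingredients. It evaluates all cubic B-spline basis functions on a uniform grid at once with the Cox-de Boor recursion. It then combines them in the residual form <span class="math inline">\(\phi(x) = w_b\,\mathrm{silu}(x) + w_s \sum_i c_i B_i(x)\)</span> described in the original KAN paper. The function names, the SiLU base function, and the default weights are illustrative assumptions rather than this post's exact implementation.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import torch
import torch.nn.functional as F


def bspline_basis(x, grid, k):
    """Evaluate all order-k B-spline basis functions at x (Cox-de Boor recursion).

    x:    (batch,) input values
    grid: (n_knots,) increasing knots, extended by k knots beyond each end
    returns: (batch, n_knots - k - 1)
    """
    x = x.unsqueeze(-1)
    # Order 0: indicator of the half-open knot interval [t_i, t_{i+1})
    b = torch.logical_and(x.ge(grid[:-1]), x.lt(grid[1:])).to(x.dtype)
    for p in range(1, k + 1):
        # B_{i,p} = w_left * B_{i,p-1} + w_right * B_{i+1,p-1}
        left = (x - grid[:-(p + 1)]) / (grid[p:-1] - grid[:-(p + 1)])
        right = (grid[p + 1:] - x) / (grid[p + 1:] - grid[1:-p])
        b = left * b[:, :-1] + right * b[:, 1:]
    return b


def kan_activation(x, grid, k, coeffs, w_base=1.0, w_spline=1.0):
    # Residual form: phi(x) = w_base * silu(x) + w_spline * sum_i coeffs[i] * B_i(x)
    return w_base * F.silu(x) + w_spline * (bspline_basis(x, grid, k) @ coeffs)


# Cubic splines (k = 3) on a uniform 5-interval grid over [-1, 1]
grid_size, k, lo, hi = 5, 3, -1.0, 1.0
h = (hi - lo) / grid_size
grid = lo + h * torch.arange(-k, grid_size + k + 1, dtype=torch.float64)

x = torch.linspace(-1.0, 0.99, 7, dtype=torch.float64)
B = bspline_basis(x, grid, k)
print(B.shape)  # torch.Size([7, 8]): grid_size + k basis functions
print(torch.allclose(B.sum(dim=1), torch.ones(7, dtype=torch.float64)))  # True

coeffs = 0.1 * torch.randn(grid_size + k, dtype=torch.float64)
print(kan_activation(x, grid, k, coeffs))</code></pre></div></div>
<p>Training only has to learn the coefficients and the two mixing weights. Inside the grid range the basis functions sum to one (a partition of unity). That makes the check above a cheap way to validate a knot layout.</p>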
</section>
</section>
<section id="implementation-details" class="level2">
<h2 class="anchored" data-anchor-id="implementation-details" id="implementation-details">Implementation Details</h2>
<section id="b-spline-based-activation-functions" class="level3">
<h3 class="anchored" data-anchor-id="b-spline-based-activation-functions" id="b-spline-based-activation-functions">B-spline Based Activation Functions</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn.functional <span class="im">as</span> F</span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> math</span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> BSplineActivation(nn.Module):</span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, grid_size<span class="op">=</span><span class="dv">5</span>, spline_order<span class="op">=</span><span class="dv">3</span>, grid_range<span class="op">=</span>(<span class="op">-</span><span class="dv">1</span>, <span class="dv">1</span>)):</span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.grid_size <span class="op">=</span> grid_size</span>
<span id="cb2-10"><a href="#cb2-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.spline_order <span class="op">=</span> spline_order</span>
<span id="cb2-11"><a href="#cb2-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.grid_range <span class="op">=</span> grid_range</span>
<span id="cb2-12"><a href="#cb2-12" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-13"><a href="#cb2-13" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Create uniform grid</span></span>
<span id="cb2-14"><a href="#cb2-14" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.register_buffer(<span class="st">'grid'</span>, torch.linspace(</span>
<span id="cb2-15"><a href="#cb2-15" aria-hidden="true" tabindex="-1"></a>            grid_range[<span class="dv">0</span>], grid_range[<span class="dv">1</span>], grid_size <span class="op">+</span> <span class="dv">1</span></span>
<span id="cb2-16"><a href="#cb2-16" aria-hidden="true" tabindex="-1"></a>        ))</span>
<span id="cb2-17"><a href="#cb2-17" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-18"><a href="#cb2-18" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Extend grid for B-spline computation</span></span>
<span id="cb2-19"><a href="#cb2-19" aria-hidden="true" tabindex="-1"></a>        h <span class="op">=</span> (grid_range[<span class="dv">1</span>] <span class="op">-</span> grid_range[<span class="dv">0</span>]) <span class="op">/</span> grid_size</span>
<span id="cb2-20"><a href="#cb2-20" aria-hidden="true" tabindex="-1"></a>        extended_grid <span class="op">=</span> torch.cat([</span>
<span id="cb2-21"><a href="#cb2-21" aria-hidden="true" tabindex="-1"></a>            torch.arange(grid_range[<span class="dv">0</span>] <span class="op">-</span> spline_order <span class="op">*</span> h, grid_range[<span class="dv">0</span>], h),</span>
<span id="cb2-22"><a href="#cb2-22" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.grid,</span>
<span id="cb2-23"><a href="#cb2-23" aria-hidden="true" tabindex="-1"></a>            torch.arange(grid_range[<span class="dv">1</span>] <span class="op">+</span> h, grid_range[<span class="dv">1</span>] <span class="op">+</span> (spline_order <span class="op">+</span> <span class="dv">1</span>) <span class="op">*</span> h, h)</span>
<span id="cb2-24"><a href="#cb2-24" aria-hidden="true" tabindex="-1"></a>        ])</span>
<span id="cb2-25"><a href="#cb2-25" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.register_buffer(<span class="st">'extended_grid'</span>, extended_grid)</span>
<span id="cb2-26"><a href="#cb2-26" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-27"><a href="#cb2-27" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Learnable coefficients for B-spline</span></span>
<span id="cb2-28"><a href="#cb2-28" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.coefficients <span class="op">=</span> nn.Parameter(</span>
<span id="cb2-29"><a href="#cb2-29" aria-hidden="true" tabindex="-1"></a>            torch.randn(grid_size <span class="op">+</span> spline_order)</span>
<span id="cb2-30"><a href="#cb2-30" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb2-31"><a href="#cb2-31" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-32"><a href="#cb2-32" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Scale parameter for the activation</span></span>
<span id="cb2-33"><a href="#cb2-33" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.scale <span class="op">=</span> nn.Parameter(torch.ones(<span class="dv">1</span>))</span>
<span id="cb2-34"><a href="#cb2-34" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-35"><a href="#cb2-35" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb2-36"><a href="#cb2-36" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Compute B-spline basis functions</span></span>
<span id="cb2-37"><a href="#cb2-37" aria-hidden="true" tabindex="-1"></a>        batch_size <span class="op">=</span> x.shape[<span class="dv">0</span>]</span>
<span id="cb2-38"><a href="#cb2-38" aria-hidden="true" tabindex="-1"></a>        x_expanded <span class="op">=</span> x.unsqueeze(<span class="op">-</span><span class="dv">1</span>)  <span class="co"># (batch_size, 1)</span></span>
<span id="cb2-39"><a href="#cb2-39" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-40"><a href="#cb2-40" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Compute B-spline values</span></span>
<span id="cb2-41"><a href="#cb2-41" aria-hidden="true" tabindex="-1"></a>        spline_values <span class="op">=</span> <span class="va">self</span>.compute_bspline(x_expanded)</span>
<span id="cb2-42"><a href="#cb2-42" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-43"><a href="#cb2-43" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Linear combination with learnable coefficients</span></span>
<span id="cb2-44"><a href="#cb2-44" aria-hidden="true" tabindex="-1"></a>        output <span class="op">=</span> torch.<span class="bu">sum</span>(spline_values <span class="op">*</span> <span class="va">self</span>.coefficients, dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb2-45"><a href="#cb2-45" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-46"><a href="#cb2-46" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.scale <span class="op">*</span> output</span>
<span id="cb2-47"><a href="#cb2-47" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-48"><a href="#cb2-48" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> compute_bspline(<span class="va">self</span>, x):</span>
<span id="cb2-49"><a href="#cb2-49" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Compute B-spline basis functions using Cox-de Boor recursion"""</span></span>
<span id="cb2-50"><a href="#cb2-50" aria-hidden="true" tabindex="-1"></a>        grid <span class="op">=</span> <span class="va">self</span>.extended_grid</span>
<span id="cb2-51"><a href="#cb2-51" aria-hidden="true" tabindex="-1"></a>        order <span class="op">=</span> <span class="va">self</span>.spline_order</span>
<span id="cb2-52"><a href="#cb2-52" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-53"><a href="#cb2-53" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Initialize basis functions</span></span>
<span id="cb2-54"><a href="#cb2-54" aria-hidden="true" tabindex="-1"></a>        basis <span class="op">=</span> torch.zeros(x.shape[<span class="dv">0</span>], <span class="bu">len</span>(grid) <span class="op">-</span> <span class="dv">1</span>, device<span class="op">=</span>x.device)</span>
<span id="cb2-55"><a href="#cb2-55" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-56"><a href="#cb2-56" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Find intervals</span></span>
<span id="cb2-57"><a href="#cb2-57" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="bu">len</span>(grid) <span class="op">-</span> <span class="dv">1</span>):</span>
<span id="cb2-58"><a href="#cb2-58" aria-hidden="true" tabindex="-1"></a>            mask <span class="op">=</span> (x.squeeze(<span class="op">-</span><span class="dv">1</span>) <span class="op">&gt;=</span> grid[i]) <span class="op">&amp;</span> (x.squeeze(<span class="op">-</span><span class="dv">1</span>) <span class="op">&lt;</span> grid[i <span class="op">+</span> <span class="dv">1</span>])</span>
<span id="cb2-59"><a href="#cb2-59" aria-hidden="true" tabindex="-1"></a>            basis[mask, i] <span class="op">=</span> <span class="fl">1.0</span></span>
<span id="cb2-60"><a href="#cb2-60" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-61"><a href="#cb2-61" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Cox-de Boor recursion</span></span>
<span id="cb2-62"><a href="#cb2-62" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> k <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">1</span>, order <span class="op">+</span> <span class="dv">1</span>):</span>
<span id="cb2-63"><a href="#cb2-63" aria-hidden="true" tabindex="-1"></a>            new_basis <span class="op">=</span> torch.zeros_like(basis)</span>
<span id="cb2-64"><a href="#cb2-64" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="bu">len</span>(grid) <span class="op">-</span> k <span class="op">-</span> <span class="dv">1</span>):</span>
<span id="cb2-65"><a href="#cb2-65" aria-hidden="true" tabindex="-1"></a>                <span class="cf">if</span> grid[i <span class="op">+</span> k] <span class="op">!=</span> grid[i]:</span>
<span id="cb2-66"><a href="#cb2-66" aria-hidden="true" tabindex="-1"></a>                    alpha1 <span class="op">=</span> (x.squeeze(<span class="op">-</span><span class="dv">1</span>) <span class="op">-</span> grid[i]) <span class="op">/</span> (grid[i <span class="op">+</span> k] <span class="op">-</span> grid[i])</span>
<span id="cb2-67"><a href="#cb2-67" aria-hidden="true" tabindex="-1"></a>                    new_basis[:, i] <span class="op">+=</span> alpha1 <span class="op">*</span> basis[:, i]</span>
<span id="cb2-68"><a href="#cb2-68" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb2-69"><a href="#cb2-69" aria-hidden="true" tabindex="-1"></a>                <span class="cf">if</span> grid[i <span class="op">+</span> k <span class="op">+</span> <span class="dv">1</span>] <span class="op">!=</span> grid[i <span class="op">+</span> <span class="dv">1</span>]:</span>
<span id="cb2-70"><a href="#cb2-70" aria-hidden="true" tabindex="-1"></a>                    alpha2 <span class="op">=</span> (grid[i <span class="op">+</span> k <span class="op">+</span> <span class="dv">1</span>] <span class="op">-</span> x.squeeze(<span class="op">-</span><span class="dv">1</span>)) <span class="op">/</span> (grid[i <span class="op">+</span> k <span class="op">+</span> <span class="dv">1</span>] <span class="op">-</span> grid[i <span class="op">+</span> <span class="dv">1</span>])</span>
<span id="cb2-71"><a href="#cb2-71" aria-hidden="true" tabindex="-1"></a>                    new_basis[:, i] <span class="op">+=</span> alpha2 <span class="op">*</span> basis[:, i <span class="op">+</span> <span class="dv">1</span>]</span>
<span id="cb2-72"><a href="#cb2-72" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb2-73"><a href="#cb2-73" aria-hidden="true" tabindex="-1"></a>            basis <span class="op">=</span> new_basis</span>
<span id="cb2-74"><a href="#cb2-74" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-75"><a href="#cb2-75" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> basis[:, :<span class="bu">len</span>(<span class="va">self</span>.coefficients)]</span>
<span id="cb2-76"><a href="#cb2-76" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-77"><a href="#cb2-77" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> KANLayer(nn.Module):</span>
<span id="cb2-78"><a href="#cb2-78" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, input_dim, output_dim, grid_size<span class="op">=</span><span class="dv">5</span>, spline_order<span class="op">=</span><span class="dv">3</span>):</span>
<span id="cb2-79"><a href="#cb2-79" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb2-80"><a href="#cb2-80" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.input_dim <span class="op">=</span> input_dim</span>
<span id="cb2-81"><a href="#cb2-81" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.output_dim <span class="op">=</span> output_dim</span>
<span id="cb2-82"><a href="#cb2-82" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-83"><a href="#cb2-83" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Create learnable activation functions for each edge</span></span>
<span id="cb2-84"><a href="#cb2-84" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.activations <span class="op">=</span> nn.ModuleList([</span>
<span id="cb2-85"><a href="#cb2-85" aria-hidden="true" tabindex="-1"></a>            nn.ModuleList([</span>
<span id="cb2-86"><a href="#cb2-86" aria-hidden="true" tabindex="-1"></a>                BSplineActivation(grid_size, spline_order) </span>
<span id="cb2-87"><a href="#cb2-87" aria-hidden="true" tabindex="-1"></a>                <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(output_dim)</span>
<span id="cb2-88"><a href="#cb2-88" aria-hidden="true" tabindex="-1"></a>            ]) <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(input_dim)</span>
<span id="cb2-89"><a href="#cb2-89" aria-hidden="true" tabindex="-1"></a>        ])</span>
<span id="cb2-90"><a href="#cb2-90" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-91"><a href="#cb2-91" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Base linear transformation (residual connection)</span></span>
<span id="cb2-92"><a href="#cb2-92" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.base_weight <span class="op">=</span> nn.Parameter(torch.randn(input_dim, output_dim) <span class="op">*</span> <span class="fl">0.1</span>)</span>
<span id="cb2-93"><a href="#cb2-93" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-94"><a href="#cb2-94" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb2-95"><a href="#cb2-95" aria-hidden="true" tabindex="-1"></a>        batch_size <span class="op">=</span> x.shape[<span class="dv">0</span>]</span>
<span id="cb2-96"><a href="#cb2-96" aria-hidden="true" tabindex="-1"></a>        output <span class="op">=</span> torch.zeros(batch_size, <span class="va">self</span>.output_dim, device<span class="op">=</span>x.device)</span>
<span id="cb2-97"><a href="#cb2-97" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-98"><a href="#cb2-98" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Apply learnable activations</span></span>
<span id="cb2-99"><a href="#cb2-99" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="va">self</span>.input_dim):</span>
<span id="cb2-100"><a href="#cb2-100" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> j <span class="kw">in</span> <span class="bu">range</span>(<span class="va">self</span>.output_dim):</span>
<span id="cb2-101"><a href="#cb2-101" aria-hidden="true" tabindex="-1"></a>                activated <span class="op">=</span> <span class="va">self</span>.activations[i][j](x[:, i])</span>
<span id="cb2-102"><a href="#cb2-102" aria-hidden="true" tabindex="-1"></a>                output[:, j] <span class="op">+=</span> activated</span>
<span id="cb2-103"><a href="#cb2-103" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-104"><a href="#cb2-104" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Add base linear transformation</span></span>
<span id="cb2-105"><a href="#cb2-105" aria-hidden="true" tabindex="-1"></a>        base_output <span class="op">=</span> torch.matmul(x, <span class="va">self</span>.base_weight)</span>
<span id="cb2-106"><a href="#cb2-106" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-107"><a href="#cb2-107" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> output <span class="op">+</span> base_output</span>
<span id="cb2-108"><a href="#cb2-108" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-109"><a href="#cb2-109" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> KolmogorovArnoldNetwork(nn.Module):</span>
<span id="cb2-110"><a href="#cb2-110" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, layer_dims, grid_size<span class="op">=</span><span class="dv">5</span>, spline_order<span class="op">=</span><span class="dv">3</span>):</span>
<span id="cb2-111"><a href="#cb2-111" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb2-112"><a href="#cb2-112" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.layers <span class="op">=</span> nn.ModuleList()</span>
<span id="cb2-113"><a href="#cb2-113" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-114"><a href="#cb2-114" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="bu">len</span>(layer_dims) <span class="op">-</span> <span class="dv">1</span>):</span>
<span id="cb2-115"><a href="#cb2-115" aria-hidden="true" tabindex="-1"></a>            layer <span class="op">=</span> KANLayer(</span>
<span id="cb2-116"><a href="#cb2-116" aria-hidden="true" tabindex="-1"></a>                layer_dims[i], </span>
<span id="cb2-117"><a href="#cb2-117" aria-hidden="true" tabindex="-1"></a>                layer_dims[i <span class="op">+</span> <span class="dv">1</span>], </span>
<span id="cb2-118"><a href="#cb2-118" aria-hidden="true" tabindex="-1"></a>                grid_size, </span>
<span id="cb2-119"><a href="#cb2-119" aria-hidden="true" tabindex="-1"></a>                spline_order</span>
<span id="cb2-120"><a href="#cb2-120" aria-hidden="true" tabindex="-1"></a>            )</span>
<span id="cb2-121"><a href="#cb2-121" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.layers.append(layer)</span>
<span id="cb2-122"><a href="#cb2-122" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-123"><a href="#cb2-123" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb2-124"><a href="#cb2-124" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> layer <span class="kw">in</span> <span class="va">self</span>.layers:</span>
<span id="cb2-125"><a href="#cb2-125" aria-hidden="true" tabindex="-1"></a>            x <span class="op">=</span> layer(x)</span>
<span id="cb2-126"><a href="#cb2-126" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x</span>
<span id="cb2-127"><a href="#cb2-127" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-128"><a href="#cb2-128" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> regularization_loss(<span class="va">self</span>, regularization_factor<span class="op">=</span><span class="fl">1e-4</span>):</span>
<span id="cb2-129"><a href="#cb2-129" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Compute regularization loss to encourage sparsity"""</span></span>
<span id="cb2-130"><a href="#cb2-130" aria-hidden="true" tabindex="-1"></a>        reg_loss <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb2-131"><a href="#cb2-131" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> layer <span class="kw">in</span> <span class="va">self</span>.layers:</span>
<span id="cb2-132"><a href="#cb2-132" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(layer.input_dim):</span>
<span id="cb2-133"><a href="#cb2-133" aria-hidden="true" tabindex="-1"></a>                <span class="cf">for</span> j <span class="kw">in</span> <span class="bu">range</span>(layer.output_dim):</span>
<span id="cb2-134"><a href="#cb2-134" aria-hidden="true" tabindex="-1"></a>                    <span class="co"># L1 regularization on activation function coefficients</span></span>
<span id="cb2-135"><a href="#cb2-135" aria-hidden="true" tabindex="-1"></a>                    reg_loss <span class="op">+=</span> torch.<span class="bu">sum</span>(torch.<span class="bu">abs</span>(layer.activations[i][j].coefficients))</span>
<span id="cb2-136"><a href="#cb2-136" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-137"><a href="#cb2-137" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> regularization_factor <span class="op">*</span> reg_loss</span></code></pre></div></div>
</section>
<section id="training-loop-implementation" class="level3">
<h3 class="anchored" data-anchor-id="training-loop-implementation" id="training-loop-implementation">Training Loop Implementation</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> train_kan(model, train_loader, val_loader, epochs<span class="op">=</span><span class="dv">100</span>, lr<span class="op">=</span><span class="fl">1e-3</span>):</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a>    optimizer <span class="op">=</span> torch.optim.AdamW(model.parameters(), lr<span class="op">=</span>lr, weight_decay<span class="op">=</span><span class="fl">1e-4</span>)</span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a>    scheduler <span class="op">=</span> torch.optim.lr_scheduler.ReduceLROnPlateau(</span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a>        optimizer, mode<span class="op">=</span><span class="st">'min'</span>, factor<span class="op">=</span><span class="fl">0.5</span>, patience<span class="op">=</span><span class="dv">10</span></span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a>    criterion <span class="op">=</span> nn.MSELoss()</span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a>    train_losses <span class="op">=</span> []</span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a>    val_losses <span class="op">=</span> []</span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(epochs):</span>
<span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Training phase</span></span>
<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a>        model.train()</span>
<span id="cb3-15"><a href="#cb3-15" aria-hidden="true" tabindex="-1"></a>        train_loss <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb3-16"><a href="#cb3-16" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> batch_idx, (data, target) <span class="kw">in</span> <span class="bu">enumerate</span>(train_loader):</span>
<span id="cb3-17"><a href="#cb3-17" aria-hidden="true" tabindex="-1"></a>            optimizer.zero_grad()</span>
<span id="cb3-18"><a href="#cb3-18" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb3-19"><a href="#cb3-19" aria-hidden="true" tabindex="-1"></a>            output <span class="op">=</span> model(data)</span>
<span id="cb3-20"><a href="#cb3-20" aria-hidden="true" tabindex="-1"></a>            loss <span class="op">=</span> criterion(output, target)</span>
<span id="cb3-21"><a href="#cb3-21" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb3-22"><a href="#cb3-22" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Add regularization</span></span>
<span id="cb3-23"><a href="#cb3-23" aria-hidden="true" tabindex="-1"></a>            reg_loss <span class="op">=</span> model.regularization_loss()</span>
<span id="cb3-24"><a href="#cb3-24" aria-hidden="true" tabindex="-1"></a>            total_loss <span class="op">=</span> loss <span class="op">+</span> reg_loss</span>
<span id="cb3-25"><a href="#cb3-25" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb3-26"><a href="#cb3-26" aria-hidden="true" tabindex="-1"></a>            total_loss.backward()</span>
<span id="cb3-27"><a href="#cb3-27" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb3-28"><a href="#cb3-28" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Gradient clipping for stability</span></span>
<span id="cb3-29"><a href="#cb3-29" aria-hidden="true" tabindex="-1"></a>            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm<span class="op">=</span><span class="fl">1.0</span>)</span>
<span id="cb3-30"><a href="#cb3-30" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb3-31"><a href="#cb3-31" aria-hidden="true" tabindex="-1"></a>            optimizer.step()</span>
<span id="cb3-32"><a href="#cb3-32" aria-hidden="true" tabindex="-1"></a>            train_loss <span class="op">+=</span> loss.item()  <span class="co"># data term only, so it is comparable to val_loss</span></span>
<span id="cb3-33"><a href="#cb3-33" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-34"><a href="#cb3-34" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Validation phase</span></span>
<span id="cb3-35"><a href="#cb3-35" aria-hidden="true" tabindex="-1"></a>        model.<span class="bu">eval</span>()</span>
<span id="cb3-36"><a href="#cb3-36" aria-hidden="true" tabindex="-1"></a>        val_loss <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb3-37"><a href="#cb3-37" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> torch.no_grad():</span>
<span id="cb3-38"><a href="#cb3-38" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> data, target <span class="kw">in</span> val_loader:</span>
<span id="cb3-39"><a href="#cb3-39" aria-hidden="true" tabindex="-1"></a>                output <span class="op">=</span> model(data)</span>
<span id="cb3-40"><a href="#cb3-40" aria-hidden="true" tabindex="-1"></a>                val_loss <span class="op">+=</span> criterion(output, target).item()</span>
<span id="cb3-41"><a href="#cb3-41" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-42"><a href="#cb3-42" aria-hidden="true" tabindex="-1"></a>        avg_train_loss <span class="op">=</span> train_loss <span class="op">/</span> <span class="bu">len</span>(train_loader)</span>
<span id="cb3-43"><a href="#cb3-43" aria-hidden="true" tabindex="-1"></a>        avg_val_loss <span class="op">=</span> val_loss <span class="op">/</span> <span class="bu">len</span>(val_loader)</span>
<span id="cb3-44"><a href="#cb3-44" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-45"><a href="#cb3-45" aria-hidden="true" tabindex="-1"></a>        train_losses.append(avg_train_loss)</span>
<span id="cb3-46"><a href="#cb3-46" aria-hidden="true" tabindex="-1"></a>        val_losses.append(avg_val_loss)</span>
<span id="cb3-47"><a href="#cb3-47" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-48"><a href="#cb3-48" aria-hidden="true" tabindex="-1"></a>        scheduler.step(avg_val_loss)</span>
<span id="cb3-49"><a href="#cb3-49" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-50"><a href="#cb3-50" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> epoch <span class="op">%</span> <span class="dv">10</span> <span class="op">==</span> <span class="dv">0</span>:</span>
<span id="cb3-51"><a href="#cb3-51" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f'Epoch </span><span class="sc">{</span>epoch<span class="sc">}</span><span class="ss">: Train Loss = </span><span class="sc">{</span>avg_train_loss<span class="sc">:.6f}</span><span class="ss">, '</span></span>
<span id="cb3-52"><a href="#cb3-52" aria-hidden="true" tabindex="-1"></a>                  <span class="ss">f'Val Loss = </span><span class="sc">{</span>avg_val_loss<span class="sc">:.6f}</span><span class="ss">'</span>)</span>
<span id="cb3-53"><a href="#cb3-53" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-54"><a href="#cb3-54" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> train_losses, val_losses</span></code></pre></div></div>
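<p>The loop above relies on <code>model.regularization_loss()</code>, which the model class is expected to provide. For reference, the original KAN paper describes a sparsity regularizer with two terms per layer: the sum of each edge's L1 norm (its mean absolute activation over a batch), plus the entropy of the normalized edge norms. The NumPy sketch below computes that penalty for a single layer. The name <code>kan_sparsity_penalty</code>, the weights <code>mu_l1</code> and <code>mu_entropy</code>, and the <code>(n_samples, in_dim, out_dim)</code> input layout are illustrative assumptions; this post's own <code>regularization_loss</code> may be defined differently, and during training the same arithmetic would be written in PyTorch so that it stays differentiable.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import numpy as np

def kan_sparsity_penalty(edge_outputs, mu_l1=1.0, mu_entropy=1.0, eps=1e-12):
    """L1 + entropy sparsity penalty for one KAN layer (illustrative sketch).

    edge_outputs: array of shape (n_samples, in_dim, out_dim) holding
    phi_ij(x_i) for every sample and every edge (i, j).
    """
    # |phi_ij|_1: mean absolute output of each edge over the batch
    edge_l1 = np.mean(np.abs(edge_outputs), axis=0)   # (in_dim, out_dim)
    layer_l1 = edge_l1.sum()
    # Entropy of the normalized edge norms: low when a few edges dominate
    p = edge_l1 / (layer_l1 + eps)
    entropy = -np.sum(p * np.log(p + eps))
    return mu_l1 * layer_l1 + mu_entropy * entropy

# Two layers with the same total L1 norm: one spreads it over all edges,
# the other concentrates it on a single edge
dense = np.ones((8, 2, 2))
sparse = np.zeros((8, 2, 2))
sparse[:, 0, 0] = 4.0

dense_penalty = kan_sparsity_penalty(dense)    # 4 + log(4)
sparse_penalty = kan_sparsity_penalty(sparse)  # about 4: entropy is near zero
print(dense_penalty, sparse_penalty)
</code></pre></div>
<p>Both layers carry the same total L1 norm, so the gap between the two penalties comes entirely from the entropy term. That term is what favors a few strong edges over many weak ones, which in turn tends to make magnitude-based pruning cleaner.</p>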
</section>
</section>
<section id="advanced-features-and-optimizations" class="level2">
<h2 class="anchored" data-anchor-id="advanced-features-and-optimizations" id="advanced-features-and-optimizations">Advanced Features and Optimizations</h2>
<section id="pruning-and-sparsification" class="level3">
<h3 class="anchored" data-anchor-id="pruning-and-sparsification" id="pruning-and-sparsification">1. Pruning and Sparsification</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> prune_kan(model, threshold<span class="op">=</span><span class="fl">1e-2</span>):</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Zero out edges whose spline-coefficient norm is below threshold"""</span></span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> torch.no_grad():</span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> layer <span class="kw">in</span> model.layers:</span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(layer.input_dim):</span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a>                <span class="cf">for</span> j <span class="kw">in</span> <span class="bu">range</span>(layer.output_dim):</span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a>                    activation <span class="op">=</span> layer.activations[i][j]</span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a>                    </span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a>                    <span class="co"># Compute magnitude of activation function</span></span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a>                    magnitude <span class="op">=</span> torch.norm(activation.coefficients)</span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a>                    </span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a>                    <span class="cf">if</span> magnitude <span class="op">&lt;</span> threshold:</span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a>                        <span class="co"># Zero out the activation function</span></span>
<span id="cb4-14"><a href="#cb4-14" aria-hidden="true" tabindex="-1"></a>                        activation.coefficients.fill_(<span class="dv">0</span>)</span>
<span id="cb4-15"><a href="#cb4-15" aria-hidden="true" tabindex="-1"></a>                        activation.scale.fill_(<span class="dv">0</span>)</span></code></pre></div></div>
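<p><code>prune_kan</code> scores an edge by the norm of its spline coefficients. An alternative, closer to the L1 norm used in the KAN paper, scores each edge by how large its output actually is on representative inputs, since a spline with sizeable coefficients can still contribute little where the data lies. The sketch below computes such scores and turns them into an explicit keep-mask. <code>edge_scores</code>, <code>prune_mask</code>, and the plain-callable activations are hypothetical stand-ins for the post's activation modules, written in NumPy for brevity.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import numpy as np

def edge_scores(activations, x_samples):
    """Mean |phi_ij(x_i)| for every edge; activations[i][j] is a 1-D callable."""
    in_dim, out_dim = len(activations), len(activations[0])
    scores = np.zeros((in_dim, out_dim))
    for i in range(in_dim):
        for j in range(out_dim):
            scores[i, j] = np.mean(np.abs(activations[i][j](x_samples[:, i])))
    return scores

def prune_mask(scores, threshold=1e-2):
    """True for edges to keep, False for edges to prune."""
    return scores >= threshold

# Toy 2-in / 2-out layer: two edges carry signal, two are (near) zero
activations = [
    [np.sin, lambda x: 1e-4 * x],
    [lambda x: 0.0 * x, np.square],
]
x_samples = np.stack([np.linspace(-1, 1, 101)] * 2, axis=1)  # shape (101, 2)

mask = prune_mask(edge_scores(activations, x_samples))
print(mask)
</code></pre></div>
<p>Keeping an explicit mask (for instance, multiplying each edge's output by its mask entry in the forward pass) makes the pruned structure easy to inspect and reapply. Whether parameters zeroed in place stay at zero during further training depends on how the activation combines its scale and coefficients, and a mask removes that uncertainty.</p>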
</section>
<section id="symbolic-regression-capabilities" class="level3">
<h3 class="anchored" data-anchor-id="symbolic-regression-capabilities" id="symbolic-regression-capabilities">2. Symbolic Regression Capabilities</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> symbolic_extraction(model, input_names, output_names):</span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Describe each KAN node as a sum of fitted 1-D functions of its inputs"""</span></span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a>    expressions <span class="op">=</span> []</span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a>    n_layers <span class="op">=</span> <span class="bu">len</span>(model.layers)</span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> layer_idx, layer <span class="kw">in</span> <span class="bu">enumerate</span>(model.layers):</span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a>        layer_expressions <span class="op">=</span> []</span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Layer 0 reads the named inputs; deeper layers read hidden nodes</span></span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a>        in_names <span class="op">=</span> input_names <span class="cf">if</span> layer_idx <span class="op">==</span> <span class="dv">0</span> <span class="cf">else</span> [<span class="ss">f"h</span><span class="sc">{</span>layer_idx<span class="sc">}</span><span class="ss">_</span><span class="sc">{</span>k<span class="sc">}</span><span class="ss">"</span> <span class="cf">for</span> k <span class="kw">in</span> <span class="bu">range</span>(layer.input_dim)]</span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> j <span class="kw">in</span> <span class="bu">range</span>(layer.output_dim):</span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a>            terms <span class="op">=</span> []</span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a>            out_name <span class="op">=</span> output_names[j] <span class="cf">if</span> layer_idx <span class="op">==</span> n_layers <span class="op">-</span> <span class="dv">1</span> <span class="cf">else</span> <span class="ss">f"h</span><span class="sc">{</span>layer_idx <span class="op">+</span> <span class="dv">1</span><span class="sc">}</span><span class="ss">_</span><span class="sc">{</span>j<span class="sc">}</span><span class="ss">"</span></span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(layer.input_dim):</span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a>                activation <span class="op">=</span> layer.activations[i][j]</span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Skip edges whose spline coefficients are negligible</span></span>
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a>                <span class="cf">if</span> torch.norm(activation.coefficients) <span class="op">&gt;</span> <span class="fl">1e-3</span>:</span>
<span id="cb5-16"><a href="#cb5-16" aria-hidden="true" tabindex="-1"></a>                    <span class="co"># Fit a simple closed-form candidate to this edge</span></span>
<span id="cb5-17"><a href="#cb5-17" aria-hidden="true" tabindex="-1"></a>                    func_type <span class="op">=</span> fit_symbolic_function(activation)</span>
<span id="cb5-18"><a href="#cb5-18" aria-hidden="true" tabindex="-1"></a>                    terms.append(<span class="ss">f"</span><span class="sc">{</span>func_type<span class="sc">}</span><span class="ss">(</span><span class="sc">{</span>in_names[i]<span class="sc">}</span><span class="ss">)"</span>)</span>
<span id="cb5-19"><a href="#cb5-19" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb5-20"><a href="#cb5-20" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> terms:</span>
<span id="cb5-21"><a href="#cb5-21" aria-hidden="true" tabindex="-1"></a>                expression <span class="op">=</span> <span class="ss">f"</span><span class="sc">{</span>out_name<span class="sc">}</span><span class="ss"> = "</span> <span class="op">+</span> <span class="st">" + "</span>.join(terms)</span>
<span id="cb5-22"><a href="#cb5-22" aria-hidden="true" tabindex="-1"></a>                layer_expressions.append(expression)</span>
<span id="cb5-23"><a href="#cb5-23" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-24"><a href="#cb5-24" aria-hidden="true" tabindex="-1"></a>        expressions.append(layer_expressions)</span>
<span id="cb5-25"><a href="#cb5-25" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-26"><a href="#cb5-26" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> expressions</span>
<span id="cb5-27"><a href="#cb5-27" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-28"><a href="#cb5-28" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> fit_symbolic_function(activation, x_range<span class="op">=</span>(<span class="op">-</span><span class="fl">1.0</span>, <span class="fl">1.0</span>), n_points<span class="op">=</span><span class="dv">100</span>, rel_tol<span class="op">=</span><span class="fl">0.05</span>):</span>
<span id="cb5-29"><a href="#cb5-29" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Return the first listed candidate whose fit error is within rel_tol of the best"""</span></span>
<span id="cb5-30"><a href="#cb5-30" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Sample the activation on a 1-D grid (assumes it accepts a 1-D tensor)</span></span>
<span id="cb5-31"><a href="#cb5-31" aria-hidden="true" tabindex="-1"></a>    x <span class="op">=</span> torch.linspace(x_range[<span class="dv">0</span>], x_range[<span class="dv">1</span>], n_points)</span>
<span id="cb5-32"><a href="#cb5-32" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> torch.no_grad():</span>
<span id="cb5-33"><a href="#cb5-33" aria-hidden="true" tabindex="-1"></a>        y <span class="op">=</span> activation(x).reshape(<span class="op">-</span><span class="dv">1</span>, <span class="dv">1</span>)</span>
<span id="cb5-34"><a href="#cb5-34" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-35"><a href="#cb5-35" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> lstsq_mse(basis):</span>
<span id="cb5-36"><a href="#cb5-36" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Least-squares fit of y as a linear combination of the basis columns</span></span>
<span id="cb5-37"><a href="#cb5-37" aria-hidden="true" tabindex="-1"></a>        A <span class="op">=</span> torch.stack(basis, dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb5-38"><a href="#cb5-38" aria-hidden="true" tabindex="-1"></a>        coef <span class="op">=</span> torch.linalg.lstsq(A, y).solution</span>
<span id="cb5-39"><a href="#cb5-39" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> torch.mean((A <span class="op">@</span> coef <span class="op">-</span> y)<span class="op">**</span><span class="dv">2</span>).item()</span>
<span id="cb5-40"><a href="#cb5-40" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-41"><a href="#cb5-41" aria-hidden="true" tabindex="-1"></a>    ones <span class="op">=</span> torch.ones_like(x)</span>
<span id="cb5-42"><a href="#cb5-42" aria-hidden="true" tabindex="-1"></a>    scales <span class="op">=</span> torch.linspace(<span class="fl">0.1</span>, <span class="fl">5.0</span>, <span class="dv">50</span>)  <span class="co"># coarse grid for the inner scale b</span></span>
<span id="cb5-43"><a href="#cb5-43" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-44"><a href="#cb5-44" aria-hidden="true" tabindex="-1"></a>    <span class="co"># For fixed b, every candidate is linear in its outer coefficients,</span></span>
<span id="cb5-45"><a href="#cb5-45" aria-hidden="true" tabindex="-1"></a>    <span class="co"># so b is grid-searched and the remaining parameters are solved exactly</span></span>
<span id="cb5-46"><a href="#cb5-46" aria-hidden="true" tabindex="-1"></a>    errors <span class="op">=</span> {</span>
<span id="cb5-47"><a href="#cb5-47" aria-hidden="true" tabindex="-1"></a>        <span class="st">'linear'</span>: lstsq_mse([x, ones]),  <span class="co"># a*x + b</span></span>
<span id="cb5-48"><a href="#cb5-48" aria-hidden="true" tabindex="-1"></a>        <span class="st">'quadratic'</span>: lstsq_mse([x<span class="op">**</span><span class="dv">2</span>, x, ones]),  <span class="co"># a*x**2 + b*x + c</span></span>
<span id="cb5-49"><a href="#cb5-49" aria-hidden="true" tabindex="-1"></a>        <span class="co"># a*sin(b*x + c) == a1*sin(b*x) + a2*cos(b*x)</span></span>
<span id="cb5-50"><a href="#cb5-50" aria-hidden="true" tabindex="-1"></a>        <span class="st">'sin'</span>: <span class="bu">min</span>(lstsq_mse([torch.sin(b <span class="op">*</span> x), torch.cos(b <span class="op">*</span> x)]) <span class="cf">for</span> b <span class="kw">in</span> scales),</span>
<span id="cb5-51"><a href="#cb5-51" aria-hidden="true" tabindex="-1"></a>        <span class="st">'exp'</span>: <span class="bu">min</span>(lstsq_mse([torch.exp(b <span class="op">*</span> x)]) <span class="cf">for</span> b <span class="kw">in</span> torch.cat([<span class="op">-</span>scales, scales])),  <span class="co"># a*exp(b*x)</span></span>
<span id="cb5-52"><a href="#cb5-52" aria-hidden="true" tabindex="-1"></a>        <span class="st">'tanh'</span>: <span class="bu">min</span>(lstsq_mse([torch.tanh(b <span class="op">*</span> x)]) <span class="cf">for</span> b <span class="kw">in</span> scales),  <span class="co"># a*tanh(b*x)</span></span>
<span id="cb5-53"><a href="#cb5-53" aria-hidden="true" tabindex="-1"></a>    }</span>
<span id="cb5-54"><a href="#cb5-54" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-55"><a href="#cb5-55" aria-hidden="true" tabindex="-1"></a>    <span class="co"># 'quadratic' always fits at least as well as 'linear', so take the first</span></span>
<span id="cb5-56"><a href="#cb5-56" aria-hidden="true" tabindex="-1"></a>    <span class="co"># candidate (in the order above) whose error is close to the best one</span></span>
<span id="cb5-57"><a href="#cb5-57" aria-hidden="true" tabindex="-1"></a>    best_error <span class="op">=</span> <span class="bu">min</span>(errors.values())</span>
<span id="cb5-58"><a href="#cb5-58" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> func_name, error <span class="kw">in</span> errors.items():</span>
<span id="cb5-59"><a href="#cb5-59" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> error <span class="op">&lt;=</span> best_error <span class="op">*</span> (<span class="dv">1</span> <span class="op">+</span> rel_tol) <span class="op">+</span> <span class="fl">1e-12</span>:</span>
<span id="cb5-60"><a href="#cb5-60" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> func_name</span>
<span id="cb5-61"><a href="#cb5-61" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-62"><a href="#cb5-62" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> <span class="st">'linear'</span>  <span class="co"># fallback, e.g. if every fit produced NaN</span></span></code></pre></div></div>
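<p><code>fit_symbolic_function</code> only reports which family fits best; turning a match into a formula with numbers also requires the parameters. For the sine family no general nonlinear optimizer is needed: for a fixed frequency <code>b</code>, the identity <code>a*sin(b*x + c) = a*cos(c)*sin(b*x) + a*sin(c)*cos(b*x)</code> makes the model linear in two coefficients, so a coarse-then-fine search over <code>b</code> combined with least squares is enough. The NumPy sketch below shows this; <code>fit_sine</code> and its search ranges are hypothetical choices, not part of the post's code.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import numpy as np

def fit_sine(x, y, b_min=0.1, b_max=5.0, n_coarse=50, n_fine=201):
    """Fit y ~ a*sin(b*x + c); returns (a, b, c) with a non-negative."""
    def residual_and_coef(b):
        # For fixed b the model is linear: a1*sin(b*x) + a2*cos(b*x)
        A = np.stack([np.sin(b * x), np.cos(b * x)], axis=1)
        coef, *_ = np.linalg.lstsq(A, y, rcond=None)
        return np.mean((A @ coef - y) ** 2), coef

    # Coarse search over b, then a finer search around the best coarse value
    coarse = np.linspace(b_min, b_max, n_coarse)
    b0 = min(coarse, key=lambda b: residual_and_coef(b)[0])
    step = coarse[1] - coarse[0]
    fine = np.linspace(b0 - step, b0 + step, n_fine)
    b = min(fine, key=lambda b: residual_and_coef(b)[0])

    a1, a2 = residual_and_coef(b)[1]
    # a1 = a*cos(c) and a2 = a*sin(c), so recover amplitude and phase
    return np.hypot(a1, a2), b, np.arctan2(a2, a1)

x = np.linspace(-2, 2, 400)
y = 2.0 * np.sin(2.37 * x + 0.5)
a, b, c = fit_sine(x, y)
print(round(a, 4), round(b, 4), round(c, 4))
</code></pre></div>
<p>The same fix-the-inner-scale, solve-linearly pattern applies to <code>a*tanh(b*x)</code> and <code>a*exp(b*x)</code>, which have a single outer coefficient.</p>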
</section>
<section id="grid-adaptation" class="level3">
<h3 class="anchored" data-anchor-id="grid-adaptation" id="grid-adaptation">3. Grid Adaptation</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> adaptive_grid_refinement(model, train_loader, refinement_factor<span class="op">=</span><span class="dv">2</span>, max_batches<span class="op">=</span><span class="dv">10</span>):</span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Return {(layer_idx, i, j): new_grid_size} for edges that need a finer grid"""</span></span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Rebuilding each spline on its new grid (interpolating the coefficients)</span></span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a>    <span class="co"># is left to the activation class</span></span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a>    model.<span class="bu">eval</span>()</span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> torch.no_grad():</span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a>        <span class="co"># layer_inputs[layer_idx][i] collects the values that reach input node i</span></span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a>        layer_inputs <span class="op">=</span> {</span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a>            idx: [[] <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(layer.input_dim)]</span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> idx, layer <span class="kw">in</span> <span class="bu">enumerate</span>(model.layers)</span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> batch_idx, (data, target) <span class="kw">in</span> <span class="bu">enumerate</span>(train_loader):</span>
<span id="cb6-15"><a href="#cb6-15" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> batch_idx <span class="op">&gt;=</span> max_batches:  <span class="co"># a few batches give rough statistics</span></span>
<span id="cb6-16"><a href="#cb6-16" aria-hidden="true" tabindex="-1"></a>                <span class="cf">break</span></span>
<span id="cb6-17"><a href="#cb6-17" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb6-18"><a href="#cb6-18" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Propagate layer by layer so deeper layers record their real inputs</span></span>
<span id="cb6-19"><a href="#cb6-19" aria-hidden="true" tabindex="-1"></a>            <span class="co"># (the hidden activations) rather than raw data columns</span></span>
<span id="cb6-20"><a href="#cb6-20" aria-hidden="true" tabindex="-1"></a>            h <span class="op">=</span> data</span>
<span id="cb6-21"><a href="#cb6-21" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> layer_idx, layer <span class="kw">in</span> <span class="bu">enumerate</span>(model.layers):</span>
<span id="cb6-22"><a href="#cb6-22" aria-hidden="true" tabindex="-1"></a>                <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(layer.input_dim):</span>
<span id="cb6-23"><a href="#cb6-23" aria-hidden="true" tabindex="-1"></a>                    layer_inputs[layer_idx][i].append(h[:, i].cpu())</span>
<span id="cb6-24"><a href="#cb6-24" aria-hidden="true" tabindex="-1"></a>                h <span class="op">=</span> layer(h)</span>
<span id="cb6-25"><a href="#cb6-25" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-26"><a href="#cb6-26" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Decide per edge whether its grid should be refined</span></span>
<span id="cb6-27"><a href="#cb6-27" aria-hidden="true" tabindex="-1"></a>        refinement_plan <span class="op">=</span> {}</span>
<span id="cb6-28"><a href="#cb6-28" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> layer_idx, layer <span class="kw">in</span> <span class="bu">enumerate</span>(model.layers):</span>
<span id="cb6-29"><a href="#cb6-29" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(layer.input_dim):</span>
<span id="cb6-30"><a href="#cb6-30" aria-hidden="true" tabindex="-1"></a>                <span class="cf">if</span> <span class="kw">not</span> layer_inputs[layer_idx][i]:</span>
<span id="cb6-31"><a href="#cb6-31" aria-hidden="true" tabindex="-1"></a>                    <span class="cf">continue</span></span>
<span id="cb6-32"><a href="#cb6-32" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Every edge (i, j) leaving input node i sees the same values</span></span>
<span id="cb6-33"><a href="#cb6-33" aria-hidden="true" tabindex="-1"></a>                values <span class="op">=</span> torch.cat(layer_inputs[layer_idx][i]).<span class="bu">float</span>()</span>
<span id="cb6-34"><a href="#cb6-34" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb6-35"><a href="#cb6-35" aria-hidden="true" tabindex="-1"></a>                <span class="cf">for</span> j <span class="kw">in</span> <span class="bu">range</span>(layer.output_dim):</span>
<span id="cb6-36"><a href="#cb6-36" aria-hidden="true" tabindex="-1"></a>                    activation <span class="op">=</span> layer.activations[i][j]</span>
<span id="cb6-37"><a href="#cb6-37" aria-hidden="true" tabindex="-1"></a>                    hist, _ <span class="op">=</span> torch.histogram(values, bins<span class="op">=</span>activation.grid_size)</span>
<span id="cb6-38"><a href="#cb6-38" aria-hidden="true" tabindex="-1"></a>                    </span>
<span id="cb6-39"><a href="#cb6-39" aria-hidden="true" tabindex="-1"></a>                    <span class="co"># Uneven usage (some bins above the mean count) triggers a finer grid</span></span>
<span id="cb6-40"><a href="#cb6-40" aria-hidden="true" tabindex="-1"></a>                    <span class="cf">if</span> (hist <span class="op">&gt;</span> hist.mean()).<span class="bu">any</span>():</span>
<span id="cb6-41"><a href="#cb6-41" aria-hidden="true" tabindex="-1"></a>                        refinement_plan[(layer_idx, i, j)] <span class="op">=</span> activation.grid_size <span class="op">*</span> refinement_factor</span>
<span id="cb6-42"><a href="#cb6-42" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-43"><a href="#cb6-43" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> refinement_plan</span></code></pre></div></div>
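<p>The function above stops at choosing a new grid size; the remaining step is to move each activation onto the finer grid without changing what it computes. One way to do this is by least squares: sample the current activation densely, then solve for the fine-grid coefficients that best reproduce those samples. The sketch below does this in NumPy with degree-1 (hat-function) B-splines to stay short. The activation class in this post may use a different spline order, and <code>hat_basis</code> and <code>refine_coefficients</code> are hypothetical helpers. Because the fine grid here contains every coarse knot, a piecewise-linear function on the coarse grid is exactly representable on the fine one, so the transfer should be lossless up to rounding.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import numpy as np

def hat_basis(x, grid):
    """Degree-1 B-spline (hat) basis on a uniform grid, shape (len(x), len(grid))."""
    h = grid[1] - grid[0]
    return np.maximum(0.0, 1.0 - np.abs(x[:, None] - grid[None, :]) / h)

def refine_coefficients(coarse_coef, coarse_grid, refinement_factor=2, n_samples=500):
    """Least-squares transfer of a hat-basis spline onto a grid with more intervals."""
    n_intervals = (len(coarse_grid) - 1) * refinement_factor
    fine_grid = np.linspace(coarse_grid[0], coarse_grid[-1], n_intervals + 1)
    # Sample the current activation densely on its domain
    x = np.linspace(coarse_grid[0], coarse_grid[-1], n_samples)
    y = hat_basis(x, coarse_grid) @ coarse_coef
    # Solve for fine-grid coefficients that reproduce those samples
    fine_coef, *_ = np.linalg.lstsq(hat_basis(x, fine_grid), y, rcond=None)
    return fine_grid, fine_coef

coarse_grid = np.linspace(-1.0, 1.0, 6)   # 5 intervals
coarse_coef = np.sin(3.0 * coarse_grid)    # stand-in for learned coefficients
fine_grid, fine_coef = refine_coefficients(coarse_coef, coarse_grid)

x_check = np.linspace(-1.0, 1.0, 101)
max_gap = np.max(np.abs(
    hat_basis(x_check, coarse_grid) @ coarse_coef
    - hat_basis(x_check, fine_grid) @ fine_coef
))
print(len(fine_grid), max_gap)
</code></pre></div>
<p>For higher-order B-splines the same recipe applies; only the construction of the basis matrix changes.</p>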
</section>
</section>
<section id="practical-applications-and-use-cases" class="level2">
<h2 class="anchored" data-anchor-id="practical-applications-and-use-cases" id="practical-applications-and-use-cases">Practical Applications and Use Cases</h2>
<section id="function-approximation" class="level3">
<h3 class="anchored" data-anchor-id="function-approximation" id="function-approximation">1. Function Approximation</h3>
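<p>The target used in this example is not a plain sum of one-variable functions, because of the product <span class="math inline">\(\sin x_0 \cos x_1\)</span>. A product-to-sum identity shows that it still has an exact Kolmogorov-Arnold form that fits inside the <code>[2, 5, 1]</code> architecture, using three of the five hidden nodes:</p>
<p><span class="math display">\[
f(x_0, x_1) = \sin x_0 \cos x_1 + \tfrac{1}{2} x_0^2
= \tfrac{1}{2}\sin(\underbrace{x_0 + x_1}_{h_1}) + \tfrac{1}{2}\sin(\underbrace{x_0 - x_1}_{h_2}) + \tfrac{1}{2}(\underbrace{x_0}_{h_3})^2
\]</span></p>
<p>Each hidden node <span class="math inline">\(h_k\)</span> is a sum of univariate functions of the inputs (here just <span class="math inline">\(\pm x_i\)</span> or zero), and the output sums univariate functions of the hidden nodes (<span class="math inline">\(\tfrac{1}{2}\sin\)</span> and <span class="math inline">\(\tfrac{1}{2}(\cdot)^2\)</span>). Training is not guaranteed to find this particular decomposition, but it indicates that the task is within the network's reach, up to how well the splines approximate these functions on the input range.</p>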
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Example: Approximating a complex mathematical function</span></span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> test_function_approximation():</span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Generate synthetic data</span></span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> target_function(x):</span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> torch.sin(x[:, <span class="dv">0</span>]) <span class="op">*</span> torch.cos(x[:, <span class="dv">1</span>]) <span class="op">+</span> <span class="fl">0.5</span> <span class="op">*</span> x[:, <span class="dv">0</span>]<span class="op">**</span><span class="dv">2</span></span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Create dataset</span></span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a>    n_samples <span class="op">=</span> <span class="dv">1000</span></span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a>    x <span class="op">=</span> torch.randn(n_samples, <span class="dv">2</span>)</span>
<span id="cb7-10"><a href="#cb7-10" aria-hidden="true" tabindex="-1"></a>    y <span class="op">=</span> target_function(x).unsqueeze(<span class="dv">1</span>)</span>
<span id="cb7-11"><a href="#cb7-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-12"><a href="#cb7-12" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Split data</span></span>
<span id="cb7-13"><a href="#cb7-13" aria-hidden="true" tabindex="-1"></a>    train_size <span class="op">=</span> <span class="bu">int</span>(<span class="fl">0.8</span> <span class="op">*</span> n_samples)</span>
<span id="cb7-14"><a href="#cb7-14" aria-hidden="true" tabindex="-1"></a>    train_x, test_x <span class="op">=</span> x[:train_size], x[train_size:]</span>
<span id="cb7-15"><a href="#cb7-15" aria-hidden="true" tabindex="-1"></a>    train_y, test_y <span class="op">=</span> y[:train_size], y[train_size:]</span>
<span id="cb7-16"><a href="#cb7-16" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-17"><a href="#cb7-17" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Create KAN model</span></span>
<span id="cb7-18"><a href="#cb7-18" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> KolmogorovArnoldNetwork([<span class="dv">2</span>, <span class="dv">5</span>, <span class="dv">1</span>], grid_size<span class="op">=</span><span class="dv">10</span>)</span>
<span id="cb7-19"><a href="#cb7-19" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-20"><a href="#cb7-20" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Train model</span></span>
<span id="cb7-21"><a href="#cb7-21" aria-hidden="true" tabindex="-1"></a>    train_dataset <span class="op">=</span> torch.utils.data.TensorDataset(train_x, train_y)</span>
<span id="cb7-22"><a href="#cb7-22" aria-hidden="true" tabindex="-1"></a>    train_loader <span class="op">=</span> torch.utils.data.DataLoader(train_dataset, batch_size<span class="op">=</span><span class="dv">32</span>)</span>
<span id="cb7-23"><a href="#cb7-23" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-24"><a href="#cb7-24" aria-hidden="true" tabindex="-1"></a>    val_dataset <span class="op">=</span> torch.utils.data.TensorDataset(test_x, test_y)</span>
<span id="cb7-25"><a href="#cb7-25" aria-hidden="true" tabindex="-1"></a>    val_loader <span class="op">=</span> torch.utils.data.DataLoader(val_dataset, batch_size<span class="op">=</span><span class="dv">32</span>)</span>
<span id="cb7-26"><a href="#cb7-26" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-27"><a href="#cb7-27" aria-hidden="true" tabindex="-1"></a>    train_losses, val_losses <span class="op">=</span> train_kan(model, train_loader, val_loader)</span>
<span id="cb7-28"><a href="#cb7-28" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-29"><a href="#cb7-29" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Evaluate</span></span>
<span id="cb7-30"><a href="#cb7-30" aria-hidden="true" tabindex="-1"></a>    model.<span class="bu">eval</span>()</span>
<span id="cb7-31"><a href="#cb7-31" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> torch.no_grad():</span>
<span id="cb7-32"><a href="#cb7-32" aria-hidden="true" tabindex="-1"></a>        predictions <span class="op">=</span> model(test_x)</span>
<span id="cb7-33"><a href="#cb7-33" aria-hidden="true" tabindex="-1"></a>        mse <span class="op">=</span> torch.mean((predictions <span class="op">-</span> test_y)<span class="op">**</span><span class="dv">2</span>)</span>
<span id="cb7-34"><a href="#cb7-34" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Test MSE: </span><span class="sc">{</span>mse<span class="sc">.</span>item()<span class="sc">:.6f}</span><span class="ss">"</span>)</span>
<span id="cb7-35"><a href="#cb7-35" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-36"><a href="#cb7-36" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> model, train_losses, val_losses</span></code></pre></div></div>
</section>
<section id="scientific-computing" class="level3">
<h3 class="anchored" data-anchor-id="scientific-computing" id="scientific-computing">2. Scientific Computing</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb8"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Example: Solving differential equations</span></span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> solve_pde_with_kan():</span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Use KAN to solve partial differential equations"""</span></span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a>    <span class="kw">class</span> PDESolver(nn.Module):</span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a>        <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>):</span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a>            <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.kan <span class="op">=</span> KolmogorovArnoldNetwork([<span class="dv">2</span>, <span class="dv">10</span>, <span class="dv">10</span>, <span class="dv">1</span>])</span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a>        <span class="kw">def</span> forward(<span class="va">self</span>, x, t):</span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a>            inputs <span class="op">=</span> torch.stack([x, t], dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb8-12"><a href="#cb8-12" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> <span class="va">self</span>.kan(inputs)</span>
<span id="cb8-13"><a href="#cb8-13" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-14"><a href="#cb8-14" aria-hidden="true" tabindex="-1"></a>        <span class="kw">def</span> physics_loss(<span class="va">self</span>, x, t):</span>
<span id="cb8-15"><a href="#cb8-15" aria-hidden="true" tabindex="-1"></a>            <span class="co">"""Compute physics-informed loss for PDE"""</span></span>
<span id="cb8-16"><a href="#cb8-16" aria-hidden="true" tabindex="-1"></a>            x.requires_grad_(<span class="va">True</span>)</span>
<span id="cb8-17"><a href="#cb8-17" aria-hidden="true" tabindex="-1"></a>            t.requires_grad_(<span class="va">True</span>)</span>
<span id="cb8-18"><a href="#cb8-18" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb8-19"><a href="#cb8-19" aria-hidden="true" tabindex="-1"></a>            u <span class="op">=</span> <span class="va">self</span>.forward(x, t)</span>
<span id="cb8-20"><a href="#cb8-20" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb8-21"><a href="#cb8-21" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Compute derivatives</span></span>
<span id="cb8-22"><a href="#cb8-22" aria-hidden="true" tabindex="-1"></a>            u_t <span class="op">=</span> torch.autograd.grad(</span>
<span id="cb8-23"><a href="#cb8-23" aria-hidden="true" tabindex="-1"></a>                u, t, torch.ones_like(u), create_graph<span class="op">=</span><span class="va">True</span></span>
<span id="cb8-24"><a href="#cb8-24" aria-hidden="true" tabindex="-1"></a>            )[<span class="dv">0</span>]</span>
<span id="cb8-25"><a href="#cb8-25" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb8-26"><a href="#cb8-26" aria-hidden="true" tabindex="-1"></a>            u_x <span class="op">=</span> torch.autograd.grad(</span>
<span id="cb8-27"><a href="#cb8-27" aria-hidden="true" tabindex="-1"></a>                u, x, torch.ones_like(u), create_graph<span class="op">=</span><span class="va">True</span></span>
<span id="cb8-28"><a href="#cb8-28" aria-hidden="true" tabindex="-1"></a>            )[<span class="dv">0</span>]</span>
<span id="cb8-29"><a href="#cb8-29" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb8-30"><a href="#cb8-30" aria-hidden="true" tabindex="-1"></a>            u_xx <span class="op">=</span> torch.autograd.grad(</span>
<span id="cb8-31"><a href="#cb8-31" aria-hidden="true" tabindex="-1"></a>                u_x, x, torch.ones_like(u_x), create_graph<span class="op">=</span><span class="va">True</span></span>
<span id="cb8-32"><a href="#cb8-32" aria-hidden="true" tabindex="-1"></a>            )[<span class="dv">0</span>]</span>
<span id="cb8-33"><a href="#cb8-33" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb8-34"><a href="#cb8-34" aria-hidden="true" tabindex="-1"></a>            <span class="co"># PDE residual: u_t - u_xx = 0 (heat equation)</span></span>
<span id="cb8-35"><a href="#cb8-35" aria-hidden="true" tabindex="-1"></a>            pde_residual <span class="op">=</span> u_t <span class="op">-</span> u_xx</span>
<span id="cb8-36"><a href="#cb8-36" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb8-37"><a href="#cb8-37" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> torch.mean(pde_residual<span class="op">**</span><span class="dv">2</span>)</span>
<span id="cb8-38"><a href="#cb8-38" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-39"><a href="#cb8-39" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Training would involve minimizing physics loss</span></span>
<span id="cb8-40"><a href="#cb8-40" aria-hidden="true" tabindex="-1"></a>    <span class="co"># along with boundary and initial conditions</span></span>
<span id="cb8-41"><a href="#cb8-41" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> PDESolver()</span></code></pre></div></div>
</section>
</section>
<section id="performance-analysis-and-comparisons" class="level2">
<h2 class="anchored" data-anchor-id="performance-analysis-and-comparisons" id="performance-analysis-and-comparisons">Performance Analysis and Comparisons</h2>
<section id="computational-complexity" class="level3">
<h3 class="anchored" data-anchor-id="computational-complexity" id="computational-complexity">Computational Complexity</h3>
<p><strong>Memory Complexity:</strong></p>
<ul>
<li>MLPs: O(Σ(n_i × n_{i+1})), where n_i is the number of neurons in layer i</li>
<li>KANs: O(Σ(n_i × n_{i+1} × G)), where G is the grid size of the B-splines</li>
</ul>
<p><strong>Time Complexity:</strong></p>
<ul>
<li>Forward pass: O(Σ(n_i × n_{i+1} × G × k)), where k is the spline order</li>
<li>Backward pass: the same order, plus the cost of differentiating through the B-spline basis</li>
</ul>
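<p>To make the memory bounds concrete, the sketch below counts parameters for both architectures. It assumes each KAN edge stores one degree-<em>k</em> B-spline with <code>grid_size + spline_order</code> coefficients and ignores any extra per-edge scale or base weights. The helper names are hypothetical, and the count is not necessarily how <code>KolmogorovArnoldNetwork</code> above lays out its parameters.</p>

```python
def mlp_param_count(widths):
    """Weights plus biases of a fully connected MLP with the given layer widths."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(widths[:-1], widths[1:]))


def kan_param_count(widths, grid_size, spline_order=3):
    """Spline coefficients of a KAN, assuming every edge carries its own
    B-spline with grid_size + spline_order coefficients (hypothetical layout;
    per-edge scale or base weights are ignored)."""
    coeffs_per_edge = grid_size + spline_order
    return sum(n_in * n_out * coeffs_per_edge for n_in, n_out in zip(widths[:-1], widths[1:]))


widths = [2, 5, 1]
print(mlp_param_count(widths))                # 21 = (2*5 + 5) + (5*1 + 1)
print(kan_param_count(widths, grid_size=10))  # 195 = 15 edges * 13 coefficients
```

<p>For the <code>[2, 5, 1]</code> network with <code>grid_size=10</code> used earlier, the KAN stores roughly nine times as many parameters as an MLP of the same shape, which is the factor-of-G overhead in the bound.</p>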
</section>
<section id="advantages-of-kans" class="level3">
<h3 class="anchored" data-anchor-id="advantages-of-kans" id="advantages-of-kans">Advantages of KANs</h3>
<ol type="1">
<li><strong>Interpretability</strong>: Learnable activation functions can be visualized and analyzed</li>
<li><strong>Expressiveness</strong>: Can represent complex functions with fewer parameters in some cases</li>
<li><strong>Scientific Computing</strong>: Natural fit for problems requiring symbolic regression</li>
<li><strong>Adaptive Capacity</strong>: Can learn specialized activation functions for different parts of the input space</li>
</ol>
</section>
<section id="limitations" class="level3">
<h3 class="anchored" data-anchor-id="limitations" id="limitations">Limitations</h3>
<ol type="1">
<li><strong>Computational Overhead</strong>: B-spline computation is more expensive than simple activations</li>
<li><strong>Memory Usage</strong>: Requires more memory due to grid-based parameterization</li>
<li><strong>Training Stability</strong>: Can be more sensitive to hyperparameter choices</li>
<li><strong>Limited Scale</strong>: Current implementations don’t scale to very large networks easily</li>
</ol>
</section>
</section>
<section id="best-practices-and-hyperparameter-tuning" class="level2">
<h2 class="anchored" data-anchor-id="best-practices-and-hyperparameter-tuning" id="best-practices-and-hyperparameter-tuning">Best Practices and Hyperparameter Tuning</h2>
<section id="grid-size-selection" class="level3">
<h3 class="anchored" data-anchor-id="grid-size-selection" id="grid-size-selection">Grid Size Selection</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb9"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> tune_grid_size(data_complexity, input_dim):</span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Heuristic for selecting appropriate grid size"""</span></span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a>    base_grid_size <span class="op">=</span> <span class="bu">max</span>(<span class="dv">5</span>, <span class="bu">min</span>(<span class="dv">20</span>, <span class="bu">int</span>(math.log(data_complexity) <span class="op">*</span> <span class="dv">2</span>)))</span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Adjust based on input dimensionality</span></span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> input_dim <span class="op">&gt;</span> <span class="dv">10</span>:</span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a>        base_grid_size <span class="op">=</span> <span class="bu">max</span>(<span class="dv">3</span>, base_grid_size <span class="op">-</span> <span class="dv">2</span>)</span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a>    <span class="cf">elif</span> input_dim <span class="op">&lt;</span> <span class="dv">3</span>:</span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a>        base_grid_size <span class="op">=</span> <span class="bu">min</span>(<span class="dv">25</span>, base_grid_size <span class="op">+</span> <span class="dv">3</span>)</span>
<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-11"><a href="#cb9-11" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> base_grid_size</span></code></pre></div></div>
</section>
<section id="regularization-strategies" class="level3">
<h3 class="anchored" data-anchor-id="regularization-strategies" id="regularization-strategies">Regularization Strategies</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb10"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> advanced_regularization(model, l1_factor<span class="op">=</span><span class="fl">1e-4</span>, smoothness_factor<span class="op">=</span><span class="fl">1e-3</span>):</span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Comprehensive regularization for KANs"""</span></span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a>    reg_loss <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> layer <span class="kw">in</span> model.layers:</span>
<span id="cb10-6"><a href="#cb10-6" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(layer.input_dim):</span>
<span id="cb10-7"><a href="#cb10-7" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> j <span class="kw">in</span> <span class="bu">range</span>(layer.output_dim):</span>
<span id="cb10-8"><a href="#cb10-8" aria-hidden="true" tabindex="-1"></a>                activation <span class="op">=</span> layer.activations[i][j]</span>
<span id="cb10-9"><a href="#cb10-9" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb10-10"><a href="#cb10-10" aria-hidden="true" tabindex="-1"></a>                <span class="co"># L1 regularization for sparsity</span></span>
<span id="cb10-11"><a href="#cb10-11" aria-hidden="true" tabindex="-1"></a>                l1_loss <span class="op">=</span> torch.<span class="bu">sum</span>(torch.<span class="bu">abs</span>(activation.coefficients))</span>
<span id="cb10-12"><a href="#cb10-12" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb10-13"><a href="#cb10-13" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Smoothness regularization</span></span>
<span id="cb10-14"><a href="#cb10-14" aria-hidden="true" tabindex="-1"></a>                <span class="cf">if</span> <span class="bu">len</span>(activation.coefficients) <span class="op">&gt;</span> <span class="dv">1</span>:</span>
<span id="cb10-15"><a href="#cb10-15" aria-hidden="true" tabindex="-1"></a>                    smoothness_loss <span class="op">=</span> torch.<span class="bu">sum</span>(</span>
<span id="cb10-16"><a href="#cb10-16" aria-hidden="true" tabindex="-1"></a>                        (activation.coefficients[<span class="dv">1</span>:] <span class="op">-</span> activation.coefficients[:<span class="op">-</span><span class="dv">1</span>])<span class="op">**</span><span class="dv">2</span></span>
<span id="cb10-17"><a href="#cb10-17" aria-hidden="true" tabindex="-1"></a>                    )</span>
<span id="cb10-18"><a href="#cb10-18" aria-hidden="true" tabindex="-1"></a>                <span class="cf">else</span>:</span>
<span id="cb10-19"><a href="#cb10-19" aria-hidden="true" tabindex="-1"></a>                    smoothness_loss <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb10-20"><a href="#cb10-20" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb10-21"><a href="#cb10-21" aria-hidden="true" tabindex="-1"></a>                reg_loss <span class="op">+=</span> l1_factor <span class="op">*</span> l1_loss <span class="op">+</span> smoothness_factor <span class="op">*</span> smoothness_loss</span>
<span id="cb10-22"><a href="#cb10-22" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-23"><a href="#cb10-23" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> reg_loss</span></code></pre></div></div>
</section>
</section>
<section id="future-directions-and-research-opportunities" class="level2">
<h2 class="anchored" data-anchor-id="future-directions-and-research-opportunities" id="future-directions-and-research-opportunities">Future Directions and Research Opportunities</h2>
<section id="scalability-improvements" class="level3">
<h3 class="anchored" data-anchor-id="scalability-improvements" id="scalability-improvements">1. Scalability Improvements</h3>
<ul>
<li>Efficient GPU implementations of B-spline computations</li>
<li>Sparse KAN architectures for high-dimensional problems</li>
<li>Distributed training strategies</li>
</ul>
</section>
<section id="theoretical-developments" class="level3">
<h3 class="anchored" data-anchor-id="theoretical-developments" id="theoretical-developments">2. Theoretical Developments</h3>
<ul>
<li>Approximation theory for KAN architectures</li>
<li>Convergence guarantees and optimization landscapes</li>
<li>Connections to other function approximation methods</li>
</ul>
</section>
<section id="application-domains" class="level3">
<h3 class="anchored" data-anchor-id="application-domains" id="application-domains">3. Application Domains</h3>
<ul>
<li>Scientific machine learning and physics-informed neural networks</li>
<li>Automated theorem proving and symbolic computation</li>
<li>Interpretable AI for critical applications</li>
</ul>
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>Kolmogorov-Arnold Networks represent a fundamental rethinking of neural network architecture, moving from node-based to edge-based learnable parameters. While they face challenges in terms of computational efficiency and scalability, their unique properties make them particularly well-suited for scientific computing, interpretable AI, and function approximation tasks.</p>
<p>The mathematical elegance of KANs, rooted in classical approximation theory, combined with their practical capabilities for symbolic regression and interpretable modeling, positions them as an important tool in the modern machine learning toolkit. As implementation techniques improve and computational bottlenecks are addressed, we can expect to see broader adoption of KAN-based approaches across various domains.</p>
<p>The code implementations provided here offer a foundation for experimenting with KANs, but ongoing research continues to refine these architectures and explore their full potential. Whether KANs will revolutionize neural network design remains to be seen, but they certainly offer a compelling alternative perspective on how neural networks can learn and represent complex functions.</p>


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[The Mathematics Behind Kolmogorov-Arnold Networks]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/models/kan/kan-math/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/models/kan/kan-math/</guid>
      <pubDate>Wed, 02 Jul 2025 00:00:00 GMT</pubDate>
      
      <category>research</category>
      <category>advanced</category>
      <content:encoded><![CDATA[






<section id="the-mathematics-behind-kolmogorov-arnold-networks" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/models/kan/kan-math/kan.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>Kolmogorov-Arnold Networks (KANs) take a different approach to neural network architecture: instead of applying fixed activation functions to learned linear combinations of inputs, they learn the activation functions themselves. The approach is grounded in the representation theorem of Andrey Kolmogorov and Vladimir Arnold, which provides the theoretical foundation for these networks.</p>
</section>
<section id="the-kolmogorov-arnold-representation-theorem" class="level2">
<h2 class="anchored" data-anchor-id="the-kolmogorov-arnold-representation-theorem" id="the-kolmogorov-arnold-representation-theorem">The Kolmogorov-Arnold Representation Theorem</h2>
<section id="historical-context-and-statement" class="level3">
<h3 class="anchored" data-anchor-id="historical-context-and-statement" id="historical-context-and-statement">Historical Context and Statement</h3>
<p>In 1957, Andrey Kolmogorov and his student Vladimir Arnold proved a remarkable theorem that fundamentally changed our understanding of multivariate function representation. The theorem states:</p>
<p><strong>Kolmogorov-Arnold Theorem</strong>: Every continuous multivariate function on a compact domain such as the unit cube <span class="math inline">\([0,1]^n\)</span> can be written as a finite composition of continuous functions of a single variable and addition.</p>
<p>Formally, for any continuous function <span class="math inline">\(f: [0,1]^n \to \mathbb{R}\)</span>, there exist continuous functions <span class="math inline">\(\phi_{q,p}: [0,1] \to \mathbb{R}\)</span> and <span class="math inline">\(\Phi_q: \mathbb{R} \to \mathbb{R}\)</span> such that:</p>
<p><span class="math display">\[
f(x_1, x_2, \ldots, x_n) = \sum_{q=0}^{2n} \Phi_q\left(\sum_{p=1}^{n} \phi_{q,p}(x_p)\right)
\]</span></p>
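<p>where the inner functions <span class="math inline">\(\phi_{q,p}\)</span> can be chosen independently of <span class="math inline">\(f\)</span> and depend only on the dimension <span class="math inline">\(n\)</span>; only the outer functions <span class="math inline">\(\Phi_q\)</span> depend on <span class="math inline">\(f\)</span>.</p>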
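<p>As a small illustration of the nested form (not Kolmogorov's general construction, which uses <span class="math inline">\(2n+1\)</span> outer terms and fixed inner functions), multiplication of two variables needs only two outer terms: <span class="math inline">\(xy = \tfrac{1}{4}(x+y)^2 - \tfrac{1}{4}(x-y)^2\)</span>. The sketch encodes inner and outer functions as plain Python callables chosen by hand for this example.</p>

```python
# Inner univariate functions phi_{q,p}: identity or negation (hand-picked for this example).
inner = [
    [lambda x: x, lambda y: y],    # q = 0: x + y
    [lambda x: x, lambda y: -y],   # q = 1: x - y
]
# Outer univariate functions Phi_q.
outer = [
    lambda s: s * s / 4.0,
    lambda s: -s * s / 4.0,
]


def ka_form(xs):
    """Evaluate sum_q Phi_q(sum_p phi_{q,p}(x_p))."""
    return sum(
        Phi(sum(phi(x) for phi, x in zip(phis, xs)))
        for Phi, phis in zip(outer, inner)
    )


print(ka_form([0.3, 0.7]))  # ≈ 0.21, i.e. 0.3 * 0.7 up to rounding
```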
</section>
<section id="mathematical-significance" class="level3">
<h3 class="anchored" data-anchor-id="mathematical-significance" id="mathematical-significance">Mathematical Significance</h3>
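<p>This theorem is remarkable because it shows that, in principle, any continuous multivariate function can be built exactly from univariate functions and addition. It does not by itself overcome the curse of dimensionality: in the classical construction the inner functions can be highly irregular (continuous but far from smooth), which long made the theorem seem of limited use for learning. The key insights are:</p>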
<ol type="1">
<li><strong>Universality</strong>: The inner functions <span class="math inline">\(\phi_{q,p}\)</span> are universal and independent of the target function <span class="math inline">\(f\)</span></li>
<li><strong>Compositionality</strong>: Complex multivariate functions can be decomposed into simpler univariate components</li>
<li><strong>Finite Width</strong>: Only <span class="math inline">\(2n+1\)</span> terms are needed in the outer sum</li>
</ol>
</section>
</section>
<section id="from-classical-theory-to-neural-networks" class="level2">
<h2 class="anchored" data-anchor-id="from-classical-theory-to-neural-networks" id="from-classical-theory-to-neural-networks">From Classical Theory to Neural Networks</h2>
<section id="traditional-neural-networks-vs.-kans" class="level3">
<h3 class="anchored" data-anchor-id="traditional-neural-networks-vs.-kans" id="traditional-neural-networks-vs.-kans">Traditional Neural Networks vs.&nbsp;KANs</h3>
<p>A traditional multilayer perceptron (MLP) with a single hidden layer of <span class="math inline">\(m\)</span> units, the setting of the universal approximation theorem, computes:</p>
<p><span class="math display">\[
f(x) = \sum_{i=1}^{m} w_i \sigma\left(\sum_{j=1}^{n} w_{ij} x_j + b_i\right)
\]</span></p>
<p>where <span class="math inline">\(\sigma\)</span> is a fixed activation function (e.g., ReLU, sigmoid, tanh).</p>
<p>KANs, inspired by the Kolmogorov-Arnold theorem, instead use:</p>
<p><span class="math display">\[
f(x) = \sum_{q=0}^{2n} \Phi_q\left(\sum_{p=1}^{n} \phi_{q,p}(x_p)\right)
\]</span></p>
<p>where both <span class="math inline">\(\phi_{q,p}\)</span> and <span class="math inline">\(\Phi_q\)</span> are learnable functions.</p>
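<p>The two formulas can be contrasted in a few lines of pure Python. This is a sketch: the function names are hypothetical, and fixed library functions stand in for the univariate functions that a KAN would actually learn.</p>

```python
import math


def mlp_shallow(x, W, b, w_out, sigma=math.tanh):
    """One hidden layer: sum_i w_out[i] * sigma(sum_j W[i][j] * x[j] + b[i]).
    Only W, b and w_out are learned; sigma is fixed."""
    return sum(
        w_i * sigma(sum(W_ij * x_j for W_ij, x_j in zip(W_i, x)) + b_i)
        for w_i, W_i, b_i in zip(w_out, W, b)
    )


def kan_shallow(x, inner, outer):
    """Kolmogorov-Arnold form: sum_q outer[q](sum_p inner[q][p](x[p])).
    In a KAN, these univariate functions are what gets learned."""
    return sum(
        Phi(sum(phi(x_p) for phi, x_p in zip(phis, x)))
        for Phi, phis in zip(outer, inner)
    )


x = [0.2, -0.4]
print(mlp_shallow(x, W=[[1.0, -1.0], [0.5, 2.0]], b=[0.0, 0.1], w_out=[1.0, -0.5]))
# Fixed callables stand in for learned splines (illustrative choice).
print(kan_shallow(x, inner=[[math.sin, math.cos], [math.exp, abs]],
                  outer=[math.tanh, lambda s: s * s]))
```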
</section>
<section id="the-fundamental-difference" class="level3">
<h3 class="anchored" data-anchor-id="the-fundamental-difference" id="the-fundamental-difference">The Fundamental Difference</h3>
<p>The crucial difference lies in <strong>where</strong> the nonlinearity is applied:</p>
<ul>
<li><strong>MLPs</strong>: Apply fixed nonlinear activations to linear combinations of inputs</li>
<li><strong>KANs</strong>: Learn the nonlinear functions themselves, applied to individual variables</li>
</ul>
</section>
</section>
<section id="mathematical-foundations-of-kan-architecture" class="level2">
<h2 class="anchored" data-anchor-id="mathematical-foundations-of-kan-architecture" id="mathematical-foundations-of-kan-architecture">Mathematical Foundations of KAN Architecture</h2>
<section id="function-parametrization" class="level3">
<h3 class="anchored" data-anchor-id="function-parametrization" id="function-parametrization">Function Parametrization</h3>
<p>In practical implementations, the learnable functions <span class="math inline">\(\phi\)</span> and <span class="math inline">\(\Phi\)</span> are typically parametrized using:</p>
<section id="b-splines" class="level4">
<h4 class="anchored" data-anchor-id="b-splines">B-Splines</h4>
<p>B-splines provide a flexible and numerically stable way to represent univariate functions:</p>
<p><span class="math display">\[\phi(x) = \sum_{i=0}^{G} c_i B_i^k(x)\]</span></p>
<p>where: - <span class="math inline">\(B_i^k(x)\)</span> are B-spline basis functions of degree <span class="math inline">\(k\)</span> - <span class="math inline">\(c_i\)</span> are learnable coefficients - <span class="math inline">\(G\)</span> is the number of control points</p>
</section>
<section id="advantages-of-b-splines" class="level4">
<h4 class="anchored" data-anchor-id="advantages-of-b-splines">Advantages of B-Splines:</h4>
<ul>
<li><strong>Local Support</strong>: Changes in coefficients affect only local regions</li>
<li><strong>Smoothness</strong>: Degree <span class="math inline">\(k\)</span> splines are <span class="math inline">\(C^{k-1}\)</span> continuous</li>
<li><strong>Numerical Stability</strong>: Well-conditioned basis functions</li>
<li><strong>Interpretability</strong>: Each coefficient shapes the function only near its own knots, so the learned curve can be plotted and read directly</li>
</ul>
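<p>Local support can be checked directly. The sketch below uses degree-1 B-splines (hat functions), the simplest case, on an assumed knot vector with 4 intervals on [0, 1]; bumping a single coefficient changes <span class="math inline">\(\phi\)</span> only on the support of the corresponding basis function.</p>

```python
def hat(i, x, knots):
    """Degree-1 B-spline basis function i, supported on [knots[i], knots[i + 2])."""
    a, b, c = knots[i], knots[i + 1], knots[i + 2]
    if a <= x < b:
        return (x - a) / (b - a)
    if b <= x < c:
        return (c - x) / (c - b)
    return 0.0


def phi(x, coeffs, knots):
    """phi(x) = sum_i c_i B_i(x) with degree-1 B-splines."""
    return sum(c * hat(i, x, knots) for i, c in enumerate(coeffs))


knots = [-0.25, 0.0, 0.25, 0.5, 0.75, 1.0, 1.25]  # 4 intervals on [0, 1], one extra knot per side
base = [0.0, 1.0, 0.0, -1.0, 0.0]                  # 5 = 4 + 1 coefficients
bumped = list(base)
bumped[1] += 0.5                                   # change a single coefficient
xs = [j / 100 for j in range(100)]
changed = [x for x in xs if abs(phi(x, bumped, knots) - phi(x, base, knots)) > 1e-15]
print(min(changed), max(changed))  # 0.01 0.49: only the support [0, 0.5) of B_1 moved
```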
</section>
</section>
<section id="layer-wise-composition" class="level3">
<h3 class="anchored" data-anchor-id="layer-wise-composition" id="layer-wise-composition">Layer-wise Composition</h3>
<p>A practical KAN extends the basic representation through multiple layers:</p>
<p><span class="math display">\[\text{KAN}(x) = \text{KAN}_L \circ \text{KAN}_{L-1} \circ \cdots \circ \text{KAN}_1(x)\]</span></p>
<p>where each layer <span class="math inline">\(\text{KAN}_\ell\)</span> transforms inputs through learnable univariate functions:</p>
<p><span class="math display">\[\text{KAN}_\ell(x^{(\ell-1)}) = \left(\sum_{j=1}^{n_{\ell-1}} \phi_{\ell,i,j}(x^{(\ell-1)}_j)\right)_{i=1}^{n_\ell}\]</span></p>
</section>
<section id="residual-connections" class="level3">
<h3 class="anchored" data-anchor-id="residual-connections" id="residual-connections">Residual Connections</h3>
<p>To enhance expressivity and training stability, KANs often include residual connections:</p>
<p><span class="math display">\[\phi_{\ell,i,j}(x) = w_{\ell,i,j} \cdot \text{spline}_{\ell,i,j}(x) + b_{\ell,i,j} \cdot x\]</span></p>
<p>where: - <span class="math inline">\(\text{spline}_{\ell,i,j}(x)\)</span> is the B-spline component - <span class="math inline">\(w_{\ell,i,j}\)</span> and <span class="math inline">\(b_{\ell,i,j}\)</span> are learnable parameters - The linear term <span class="math inline">\(b_{\ell,i,j} \cdot x\)</span> provides a residual connection</p>
</section>
</section>
<section id="optimization-and-training" class="level2">
<h2 class="anchored" data-anchor-id="optimization-and-training" id="optimization-and-training">Optimization and Training</h2>
<section id="loss-function" class="level3">
<h3 class="anchored" data-anchor-id="loss-function" id="loss-function">Loss Function</h3>
<p>The training objective for KANs typically includes both accuracy and regularization terms:</p>
<p><span class="math display">\[\mathcal{L} = \mathcal{L}_{\text{data}} + \lambda_1 \mathcal{L}_{\text{sparse}} + \lambda_2 \mathcal{L}_{\text{smooth}}\]</span></p>
<p>where: - <span class="math inline">\(\mathcal{L}_{\text{data}}\)</span> is the standard prediction loss (MSE, cross-entropy, etc.) - <span class="math inline">\(\mathcal{L}_{\text{sparse}}\)</span> encourages sparsity in the network - <span class="math inline">\(\mathcal{L}_{\text{smooth}}\)</span> promotes smooth activation functions</p>
</section>
<section id="sparsity-regularization" class="level3">
<h3 class="anchored" data-anchor-id="sparsity-regularization" id="sparsity-regularization">Sparsity Regularization</h3>
<p>To encourage interpretable networks, KANs use sparsity regularization:</p>
<p><span class="math display">\[
\mathcal{L}_{\text{sparse}} = \sum_{\ell,i,j} \left(|w_{\ell,i,j}| + |b_{\ell,i,j}|\right)
\]</span></p>
<p>This L1 penalty pushes many weights toward zero; connections whose weights end up negligible can then be pruned, leaving a sparse, interpretable network.</p>
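<p>The penalty and a threshold-based pruning step can be sketched over flattened per-edge weights. The threshold value and the rule of keeping an edge if either of its weights is large are illustrative assumptions, not a prescribed procedure.</p>

```python
def l1_penalty(w, b):
    """L_sparse = sum over edges of (|w| + |b|), given flat lists of per-edge weights."""
    return sum(abs(v) for v in w) + sum(abs(v) for v in b)


def prune_mask(w, b, threshold=1e-2):
    """Keep an edge only if either of its weights exceeds the threshold (hypothetical rule)."""
    return [abs(wi) > threshold or abs(bi) > threshold for wi, bi in zip(w, b)]


w = [0.8, -0.003, 0.0, 0.4]
b = [0.1, 0.001, -0.005, 0.0]
print(l1_penalty(w, b))   # ≈ 1.309
print(prune_mask(w, b))   # [True, False, False, True]
```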
</section>
<section id="smoothness-regularization" class="level3">
<h3 class="anchored" data-anchor-id="smoothness-regularization" id="smoothness-regularization">Smoothness Regularization</h3>
<p>To prevent overfitting and ensure smooth activation functions:</p>
<p><span class="math display">\[
\mathcal{L}_{\text{smooth}} = \sum_{\ell,i,j} \int \left(\frac{d^2}{dx^2} \phi_{\ell,i,j}(x)\right)^2 dx
\]</span></p>
<p>This penalizes high curvature in the learned functions, promoting smooth and generalizable representations.</p>
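<p>For a learned spline, this integral is usually approximated numerically. A minimal sketch using central second differences at the midpoints of a uniform partition (an assumption; one could instead differentiate the B-spline basis analytically, and a cheaper common surrogate penalizes squared differences of neighbouring coefficients):</p>

```python
import math


def smoothness_penalty(phi, lo=0.0, hi=1.0, n=1000):
    """Approximate the integral of phi''(x)**2 over [lo, hi].

    phi'' is estimated by a central second difference with step h at the
    midpoint of each of n sub-intervals (midpoint quadrature). phi is
    evaluated up to h/2 outside [lo, hi], so it must be defined there.
    """
    h = (hi - lo) / n
    total = 0.0
    for m in range(n):
        x = lo + (m + 0.5) * h
        d2 = (phi(x + h) - 2.0 * phi(x) + phi(x - h)) / (h * h)
        total += d2 * d2 * h
    return total


print(smoothness_penalty(lambda x: x * x))          # ≈ 4.0, since phi'' = 2 everywhere
print(smoothness_penalty(lambda x: 3.0 * x - 1.0))  # ≈ 0.0 for a straight line
print(smoothness_penalty(math.sin, 0.0, math.pi))   # ≈ pi / 2
```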
</section>
</section>
<section id="theoretical-properties" class="level2">
<h2 class="anchored" data-anchor-id="theoretical-properties" id="theoretical-properties">Theoretical Properties</h2>
<section id="universal-approximation" class="level3">
<h3 class="anchored" data-anchor-id="universal-approximation" id="universal-approximation">Universal Approximation</h3>
<p>KANs inherit universal approximation properties from the Kolmogorov-Arnold theorem:</p>
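<p><strong>Theorem</strong>: With sufficiently fine spline grids, even a two-layer KAN of shape <span class="math inline">\([n, 2n+1, 1]\)</span> can approximate any continuous function on <span class="math inline">\([0,1]^n\)</span> to arbitrary accuracy in the uniform norm.</p>
<p><strong>Proof Sketch</strong>: By the Kolmogorov-Arnold theorem, <span class="math inline">\(f\)</span> is represented exactly by continuous univariate functions <span class="math inline">\(\phi_{q,p}\)</span> and <span class="math inline">\(\Phi_q\)</span>. On a compact interval, splines on increasingly fine grids approximate any continuous univariate function uniformly. Replacing each <span class="math inline">\(\phi_{q,p}\)</span> by such a spline perturbs each inner sum by at most some small <span class="math inline">\(\delta\)</span>; because each <span class="math inline">\(\Phi_q\)</span> is uniformly continuous on a compact interval containing these sums, the outer values change by an amount that vanishes as <span class="math inline">\(\delta \to 0\)</span>, and approximating <span class="math inline">\(\Phi_q\)</span> itself by a spline adds one more uniformly small error. Summing over the <span class="math inline">\(2n+1\)</span> outer terms keeps the total error arbitrarily small. The argument guarantees existence only: the required grids can be very fine, since the exact inner functions may be highly irregular.</p>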
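<p>The approximation argument can be checked numerically on the multiplication example <span class="math inline">\(xy = \tfrac{1}{4}(x+y)^2 - \tfrac{1}{4}(x-y)^2\)</span>. In this sketch the inner functions are exact (identity and negation) and only the outer functions are replaced by piecewise-linear interpolants, a stand-in for learned splines; the helper names are hypothetical. The worst-case error on a test grid should shrink as the interpolation grid is refined.</p>

```python
def interp(fn, lo, hi, g):
    """Piecewise-linear interpolant of fn on g uniform intervals over [lo, hi]
    (a degree-1 spline standing in for a learned B-spline)."""
    h = (hi - lo) / g
    ys = [fn(lo + j * h) for j in range(g + 1)]

    def approx(s):
        j = min(max(int((s - lo) / h), 0), g - 1)  # clamp to a valid interval
        u = (s - lo) / h - j
        return (1 - u) * ys[j] + u * ys[j + 1]

    return approx


def max_error(g):
    """Max error on a test grid when both outer functions of
    x*y = (x+y)^2/4 - (x-y)^2/4 are replaced by g-interval interpolants."""
    outer_plus = interp(lambda s: s * s / 4, 0.0, 2.0, g)    # x + y lies in [0, 2]
    outer_minus = interp(lambda s: s * s / 4, -1.0, 1.0, g)  # x - y lies in [-1, 1]
    pts = [k / 20 for k in range(21)]
    return max(abs(outer_plus(x + y) - outer_minus(x - y) - x * y) for x in pts for y in pts)


for g in (4, 16, 64):
    print(g, max_error(g))
```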
</section>
<section id="expressivity-analysis" class="level3">
<h3 class="anchored" data-anchor-id="expressivity-analysis" id="expressivity-analysis">Expressivity Analysis</h3>
<p>The expressivity of KANs can be analyzed through several lenses:</p>
<section id="parameter-efficiency" class="level4">
<h4 class="anchored" data-anchor-id="parameter-efficiency">Parameter Efficiency</h4>
<p>For a function of <span class="math inline">\(n\)</span> variables requiring <span class="math inline">\(m\)</span> parameters in an MLP, a KAN might achieve similar approximation quality with fewer parameters due to its compositional structure.</p>
</section>
<section id="sample-complexity" class="level4">
<h4 class="anchored" data-anchor-id="sample-complexity">Sample Complexity</h4>
<p>The sample complexity of KANs is related to the intrinsic dimensionality of the target function rather than the ambient dimensionality, potentially providing advantages for high-dimensional problems with low-dimensional structure.</p>
</section>
</section>
<section id="approximation-rates" class="level3">
<h3 class="anchored" data-anchor-id="approximation-rates" id="approximation-rates">Approximation Rates</h3>
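<p>Under smoothness assumptions on the target function, KANs admit approximation rates governed by the spline grid rather than directly by the input dimension:</p>
<p><strong>Theorem</strong> (informal, following Liu et al., 2024): Suppose the target function admits a KAN representation whose univariate functions are <span class="math inline">\((k+1)\)</span>-times continuously differentiable. Then approximating each univariate function with order-<span class="math inline">\(k\)</span> B-splines on a grid of size <span class="math inline">\(G\)</span> yields a uniform approximation error of <span class="math inline">\(O(G^{-(k+1)})\)</span>, where the constant depends on the function and its representation.</p>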
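<p>Spline approximation rates are easy to observe numerically. The sketch below uses piecewise-linear interpolation (order-1 splines, an assumption chosen because NumPy's <code>np.interp</code> provides it directly) on a uniform grid of <span class="math inline">\(G\)</span> intervals. For a smooth one-dimensional target, the maximum error should shrink roughly as <span class="math inline">\(G^{-2}\)</span>, so doubling the grid should cut the error by about a factor of four.</p>

```python
import numpy as np

def interp_error(f, grid_size, lo=0.0, hi=np.pi, n_eval=10_001):
    """Max error of piecewise-linear interpolation of f on a uniform grid."""
    knots = np.linspace(lo, hi, grid_size + 1)
    x = np.linspace(lo, hi, n_eval)
    approx = np.interp(x, knots, f(knots))
    return float(np.max(np.abs(approx - f(x))))

errors = [interp_error(np.sin, g) for g in (8, 16, 32, 64)]
ratios = [errors[i] / errors[i + 1] for i in range(len(errors) - 1)]
print(ratios)  # each ratio should be close to 4, i.e. error ~ G^-2
```

<p>Higher-order splines follow the same pattern with a steeper rate, which is why grid refinement is an effective way to trade parameters for accuracy.</p>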
</section>
</section>
<section id="computational-complexity" class="level2">
<h2 class="anchored" data-anchor-id="computational-complexity" id="computational-complexity">Computational Complexity</h2>
<section id="forward-pass-complexity" class="level3">
<h3 class="anchored" data-anchor-id="forward-pass-complexity" id="forward-pass-complexity">Forward Pass Complexity</h3>
<p>For a KAN with <span class="math inline">\(L\)</span> layers of width <span class="math inline">\(n\)</span>, where <span class="math inline">\(G\)</span> is the number of B-spline coefficients per edge:</p>
<ul>
<li><strong>Time Complexity</strong>: <span class="math inline">\(O(L \cdot n^2 \cdot G)\)</span> per input sample, when every basis function is evaluated</li>
<li><strong>Space Complexity</strong>: <span class="math inline">\(O(L \cdot n^2 \cdot G)\)</span> for parameter storage</li>
</ul>
</section>
<section id="backward-pass-complexity" class="level3">
<h3 class="anchored" data-anchor-id="backward-pass-complexity" id="backward-pass-complexity">Backward Pass Complexity</h3>
<p>The gradient computation involves:</p>
<ul>
<li>Gradients with respect to B-spline coefficients</li>
<li>Gradients with respect to residual connection parameters</li>
<li>Chain rule application through the compositional structure</li>
</ul>
<p>The overall complexity remains <span class="math inline">\(O(L \cdot n^2 \cdot G)\)</span> for both forward and backward passes.</p>
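<p>The <span class="math inline">\(n^2 \cdot G\)</span> factor per layer comes from a dense contraction: if all <span class="math inline">\(G\)</span> basis functions are evaluated for every input, one layer is a contraction of a (batch, n_in, G) basis tensor with an (n_in, n_out, G) coefficient tensor. The random arrays below are hypothetical placeholders for real basis values and learned coefficients; implementations can exploit the local support of B-splines to touch only <span class="math inline">\(k + 1\)</span> nonzero basis functions per input.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
batch, n_in, n_out, G = 4, 3, 5, 8

# Placeholders: basis[b, i, g] = B_g(x[b, i]); coeffs[i, j, g] = c_{i,j,g}.
basis = rng.random((batch, n_in, G))
coeffs = rng.standard_normal((n_in, n_out, G))

# y[b, j] = sum_i sum_g c[i, j, g] * B_g(x[b, i]) = sum_i phi_{i,j}(x[b, i])
y = np.einsum("big,ijg->bj", basis, coeffs)

# One multiply-add per (input, output, basis function) for each sample.
macs_per_sample = n_in * n_out * G
print(y.shape, macs_per_sample)  # (4, 5) 120
```

<p>The backward pass through this contraction has the same shape, which is why its cost matches the forward pass up to a constant factor.</p>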
</section>
</section>
<section id="interpretability-and-symbolic-regression" class="level2">
<h2 class="anchored" data-anchor-id="interpretability-and-symbolic-regression" id="interpretability-and-symbolic-regression">Interpretability and Symbolic Regression</h2>
<section id="automatic-symbolification" class="level3">
<h3 class="anchored" data-anchor-id="automatic-symbolification" id="automatic-symbolification">Automatic Symbolification</h3>
<p>One of the most remarkable features of KANs is their ability to discover symbolic representations:</p>
<section id="pruning-process" class="level4">
<h4 class="anchored" data-anchor-id="pruning-process">Pruning Process</h4>
<ol type="1">
<li><strong>Training</strong>: Train the full KAN with sparsity regularization</li>
<li><strong>Pruning</strong>: Remove edges whose activations are small in magnitude over the training data</li>
<li><strong>Symbolification</strong>: Replace the remaining learned spline functions with matching symbolic functions</li>
</ol>
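<p>A minimal sketch of the pruning step, assuming an edge's importance is the mean absolute value of its activation over a batch of data (in the spirit of the L1-based scores in the original KAN paper). The threshold value and the helper name <code>prune_mask</code> are hypothetical.</p>

```python
import numpy as np

def prune_mask(edge_activations: np.ndarray, threshold: float = 1e-2) -> np.ndarray:
    """Return a boolean keep-mask of shape (n_in, n_out).

    edge_activations: shape (batch, n_in, n_out), holding phi_{i,j}(x_i)
    for each sample. Edges whose mean |activation| falls below the
    threshold are marked for removal.
    """
    importance = np.mean(np.abs(edge_activations), axis=0)
    return importance >= threshold

rng = np.random.default_rng(0)
acts = rng.standard_normal((256, 3, 2))
acts[:, 1, 0] *= 1e-4  # make the edge from input 1 to output 0 nearly inactive
keep = prune_mask(acts)
print(keep)  # every entry True except keep[1, 0]
```

<p>One simple way to apply the mask is to multiply each edge's output by it in later forward passes, which removes the edge without changing array shapes.</p>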
</section>
<section id="symbol-discovery" class="level4">
<h4 class="anchored" data-anchor-id="symbol-discovery">Symbol Discovery</h4>
<p>KANs can automatically discover that learned functions correspond to elementary functions:</p>
<ul>
<li>Polynomials: <span class="math inline">\(x^n\)</span></li>
<li>Exponentials: <span class="math inline">\(e^x\)</span></li>
<li>Trigonometric: <span class="math inline">\(\sin(x)\)</span>, <span class="math inline">\(\cos(x)\)</span></li>
<li>Logarithmic: <span class="math inline">\(\log(x)\)</span></li>
</ul>
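<p>Symbol discovery can be sketched as a small model-selection problem: sample a learned edge function, fit each candidate as <span class="math inline">\(a \cdot f(x) + b\)</span> by least squares, and keep the candidate with the highest <span class="math inline">\(R^2\)</span>. This is a simplified illustration under stated assumptions: the candidate library and the affine form are choices made here, a noisy sine stands in for a trained spline, and <span class="math inline">\(\log\)</span> is left out because the sample includes negative inputs. Richer fits also search over an inner affine transform of <span class="math inline">\(x\)</span>.</p>

```python
import numpy as np

CANDIDATES = {
    "x": lambda x: x,
    "x^2": lambda x: x ** 2,
    "x^3": lambda x: x ** 3,
    "exp": np.exp,
    "sin": np.sin,
    "cos": np.cos,
}

def best_symbolic_fit(x, y):
    """Fit y ~ a*f(x) + b for each candidate f and return (name, r2)."""
    results = {}
    for name, f in CANDIDATES.items():
        design = np.column_stack([f(x), np.ones_like(x)])
        coef, *_ = np.linalg.lstsq(design, y, rcond=None)
        resid = y - design @ coef
        results[name] = 1.0 - np.sum(resid ** 2) / np.sum((y - y.mean()) ** 2)
    name = max(results, key=results.get)
    return name, results[name]

rng = np.random.default_rng(0)
x = np.linspace(-3.0, 3.0, 200)
y = 2.0 * np.sin(x) + 0.5 + 0.01 * rng.standard_normal(x.size)  # stand-in for a learned edge
print(best_symbolic_fit(x, y))  # expected to select "sin"
```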
</section>
</section>
<section id="mathematical-insights" class="level3">
<h3 class="anchored" data-anchor-id="mathematical-insights" id="mathematical-insights">Mathematical Insights</h3>
<p>The learned functions often reveal mathematical structure:</p>
<p><span class="math display">\[
f(x_1, x_2) = \sin(x_1) + x_2^2
\]</span></p>
<p>might be discovered as:</p>
<p><span class="math display">\[
\text{KAN}(x_1, x_2) = \Phi_1(\phi_{1,1}(x_1)) + \Phi_2(\phi_{2,2}(x_2))
\]</span></p>
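<p>where <span class="math inline">\(\phi_{1,1} \approx \sin\)</span>, <span class="math inline">\(\phi_{2,2} \approx x^2\)</span>, and the outer functions <span class="math inline">\(\Phi_1\)</span> and <span class="math inline">\(\Phi_2\)</span> are approximately the identity.</p>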
</section>
</section>
<section id="advanced-mathematical-concepts" class="level2">
<h2 class="anchored" data-anchor-id="advanced-mathematical-concepts" id="advanced-mathematical-concepts">Advanced Mathematical Concepts</h2>
<section id="measure-theory-perspectives" class="level3">
<h3 class="anchored" data-anchor-id="measure-theory-perspectives" id="measure-theory-perspectives">Measure Theory Perspectives</h3>
<p>From a measure-theoretic viewpoint, the Kolmogorov-Arnold theorem can be understood as a statement about the existence of certain measurable functions that achieve the required representation.</p>
</section>
<section id="functional-analysis" class="level3">
<h3 class="anchored" data-anchor-id="functional-analysis" id="functional-analysis">Functional Analysis</h3>
<p>Taken over all widths and grid sizes, the set of functions representable by KANs is dense in <span class="math inline">\(C([0,1]^n)\)</span> under the uniform norm, providing a functional-analytic foundation for their approximation capabilities.</p>
</section>
<section id="information-theory" class="level3">
<h3 class="anchored" data-anchor-id="information-theory" id="information-theory">Information Theory</h3>
<p>The representational efficiency of KANs can be analyzed through the lens of information theory, where the learned functions encode essential information about the target function’s structure.</p>
</section>
</section>
<section id="limitations-and-extensions" class="level2">
<h2 class="anchored" data-anchor-id="limitations-and-extensions" id="limitations-and-extensions">Limitations and Extensions</h2>
<section id="theoretical-limitations" class="level3">
<h3 class="anchored" data-anchor-id="theoretical-limitations" id="theoretical-limitations">Theoretical Limitations</h3>
<ol type="1">
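<li><strong>Constructive vs.&nbsp;Practical</strong>: The Kolmogorov-Arnold theorem guarantees that suitable univariate functions exist, but they can be highly irregular and hard to learn; practical KANs approximate them with splines rather than reproducing them exactly</li>
<li><strong>Smoothness Requirements</strong>: The theorem only requires continuity, whereas gradient-based training needs differentiable, reasonably smooth parameterizations</li>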
<li><strong>Domain Restrictions</strong>: The theorem is stated for bounded domains; extensions to unbounded domains require careful treatment</li>
</ol>
</section>
<section id="recent-extensions" class="level3">
<h3 class="anchored" data-anchor-id="recent-extensions" id="recent-extensions">Recent Extensions</h3>
<section id="multidimensional-kans" class="level4">
<h4 class="anchored" data-anchor-id="multidimensional-kans">Multidimensional KANs</h4>
<p>Extensions to handle tensor-valued inputs and outputs:</p>
<p><span class="math display">\[
\text{Tensor-KAN}: \mathbb{R}^{n_1 \times n_2 \times \cdots} \to \mathbb{R}^{m_1 \times m_2 \times \cdots}
\]</span></p>
</section>
<section id="convolutional-kans" class="level4">
<h4 class="anchored" data-anchor-id="convolutional-kans">Convolutional KANs</h4>
<p>Incorporating spatial structure through learnable convolution-like operations:</p>
<p><span class="math display">\[
\text{Conv-KAN}(x) = \sum_{i,j} \phi_{i,j}(x * k_{i,j})
\]</span></p>
<p>where <span class="math inline">\(k_{i,j}\)</span> are learnable kernels and <span class="math inline">\(\phi_{i,j}\)</span> are learnable activation functions.</p>
</section>
</section>
</section>
<section id="future-directions" class="level2">
<h2 class="anchored" data-anchor-id="future-directions" id="future-directions">Future Directions</h2>
<section id="theoretical-developments" class="level3">
<h3 class="anchored" data-anchor-id="theoretical-developments" id="theoretical-developments">Theoretical Developments</h3>
<ol type="1">
<li><strong>Approximation Theory</strong>: Tighter bounds on approximation rates</li>
<li><strong>Optimization Theory</strong>: Convergence guarantees for KAN training</li>
<li><strong>Generalization Theory</strong>: Sample complexity bounds for KANs</li>
</ol>
</section>
<section id="practical-innovations" class="level3">
<h3 class="anchored" data-anchor-id="practical-innovations" id="practical-innovations">Practical Innovations</h3>
<ol type="1">
<li><strong>Efficient Implementations</strong>: GPU-optimized B-spline evaluations</li>
<li><strong>Architecture Search</strong>: Automated design of KAN topologies</li>
<li><strong>Hybrid Models</strong>: Combinations of KANs with other architectures</li>
</ol>
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>Kolmogorov-Arnold Networks represent a fundamental shift in neural network design, moving from fixed activation functions to learnable univariate functions. The mathematical foundations, rooted in the profound insights of Kolmogorov and Arnold, provide both theoretical guarantees and practical advantages. The ability to automatically discover symbolic representations while maintaining universal approximation capabilities makes KANs a powerful tool for both machine learning and mathematical discovery.</p>
<p>The interplay between classical approximation theory and modern deep learning exemplified by KANs suggests that there are still fundamental insights to be gained by revisiting classical mathematical results through the lens of contemporary computational capabilities. As we continue to develop and refine these networks, we can expect them to play an increasingly important role in both theoretical understanding and practical applications of neural computation.</p>
<p>The mathematical elegance of KANs lies not just in their theoretical foundations, but in their ability to bridge the gap between approximation theory and interpretable machine learning, offering a path toward more transparent and mathematically principled artificial intelligence systems.</p>



</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Kolmogorov-Arnold Networks: Revolutionizing Neural Architecture Design]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/models/kan/kans-guide/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/models/kan/kans-guide/</guid>
      <pubDate>Wed, 02 Jul 2025 00:00:00 GMT</pubDate>
      
      <category>research</category>
      <category>advanced</category>
      <content:encoded><![CDATA[






<section id="kolmogorov-arnold-networks-revolutionizing-neural-architecture-design" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/models/kan/kans-guide/kan.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>Kolmogorov-Arnold Networks (KANs) represent a paradigm shift in neural network architecture design, moving away from the traditional Multi-Layer Perceptron (MLP) approach that has dominated machine learning for decades. Named after mathematicians Andrey Kolmogorov and Vladimir Arnold, these networks are based on the Kolmogorov-Arnold representation theorem, which provides a mathematical foundation for representing multivariate continuous functions.</p>
<p>Unlike traditional neural networks that place fixed activation functions at nodes (neurons), KANs place learnable activation functions on edges (weights). This architectural change is reported, in the original paper's experiments, to bring several advantages, including better interpretability, higher accuracy with fewer parameters, and improved generalization.</p>
</section>
<section id="mathematical-foundation-the-kolmogorov-arnold-theorem" class="level2">
<h2 class="anchored" data-anchor-id="mathematical-foundation-the-kolmogorov-arnold-theorem" id="mathematical-foundation-the-kolmogorov-arnold-theorem">Mathematical Foundation: The Kolmogorov-Arnold Theorem</h2>
<p>The Kolmogorov-Arnold representation theorem, proven in 1957, states that every continuous multivariate function on a bounded domain can be represented as a superposition (sums and compositions) of continuous functions of a single variable. Mathematically, for any continuous function <span class="math inline">\(f: [0,1]^n \rightarrow \mathbb{R}\)</span>, there exist continuous univariate functions <span class="math inline">\(\Phi_q: \mathbb{R} \rightarrow \mathbb{R}\)</span> and <span class="math inline">\(\phi_{q,p}: [0,1] \rightarrow \mathbb{R}\)</span> such that:</p>
<p><span class="math display">\[
f(x_1, x_2, \ldots, x_n) = \sum_{q=0}^{2n} \Phi_q\left( \sum_{p=1}^{n} \phi_{q,p}(x_p) \right)
\]</span></p>
<p>This theorem provides the theoretical foundation for KANs, suggesting that complex multivariate functions can be decomposed into simpler univariate functions arranged in a specific hierarchical structure.</p>
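<p>A small worked example of this shape: for positive inputs, multiplication can be written as one outer function applied to a sum of univariate inner functions, <span class="math inline">\(x_1 x_2 = \exp(\ln x_1 + \ln x_2)\)</span>. This is a single term of the representation with <span class="math inline">\(\Phi = \exp\)</span> and <span class="math inline">\(\phi_p = \ln\)</span>. The domain is restricted to <span class="math inline">\((0, 1]\)</span> here, an assumption needed because <span class="math inline">\(\ln\)</span> is undefined at 0, so it illustrates the structure rather than the theorem's general construction.</p>

```python
import numpy as np

def phi(x):
    """Inner univariate function (chosen for this example): ln, valid for x > 0."""
    return np.log(x)

def Phi(s):
    """Outer univariate function (chosen for this example): exp."""
    return np.exp(s)

rng = np.random.default_rng(0)
x1 = rng.uniform(0.01, 1.0, size=1000)
x2 = rng.uniform(0.01, 1.0, size=1000)

# f(x1, x2) = x1 * x2 expressed as Phi(phi(x1) + phi(x2))
approx = Phi(phi(x1) + phi(x2))
print(np.max(np.abs(approx - x1 * x2)))  # floating-point round-off only
```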
</section>
<section id="architecture-overview" class="level2">
<h2 class="anchored" data-anchor-id="architecture-overview" id="architecture-overview">Architecture Overview</h2>
<section id="traditional-mlps-vs-kans" class="level3">
<h3 class="anchored" data-anchor-id="traditional-mlps-vs-kans" id="traditional-mlps-vs-kans">Traditional MLPs vs KANs</h3>
<p><strong>Traditional MLPs:</strong></p>
<ul>
<li>Fixed activation functions (ReLU, sigmoid, tanh) at nodes</li>
<li>Linear transformations on edges (weights and biases)</li>
<li>Learning occurs through weight optimization</li>
<li>Limited interpretability due to distributed representations</li>
</ul>
<p><strong>Kolmogorov-Arnold Networks:</strong></p>
<ul>
<li>Learnable activation functions on edges</li>
<li>No traditional linear weights</li>
<li>Each edge contains a univariate function (typically a B-spline)</li>
<li>Nodes perform simple summation operations</li>
<li>Enhanced interpretability through edge function visualization</li>
</ul>
</section>
<section id="kan-layer-structure" class="level3">
<h3 class="anchored" data-anchor-id="kan-layer-structure" id="kan-layer-structure">KAN Layer Structure</h3>
<p>A single KAN layer transforms an input vector of dimension <code>n_in</code> to an output vector of dimension <code>n_out</code>. Each connection between input and output nodes contains a learnable univariate function, typically parameterized using B-splines.</p>
<p>The transformation can be expressed as: <span class="math display">\[
y_j = \sum_{i=1}^{n_{\text{in}}} \phi_{i,j}(x_i)
\]</span></p>
<p>where <span class="math inline">\(\phi_{i,j}\)</span> is the learnable function on the edge connecting input <span class="math inline">\(i\)</span> to output <span class="math inline">\(j\)</span>.</p>
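<p>A minimal NumPy sketch of this layer, assuming each edge function is piecewise linear on a shared uniform grid (degree-1 B-splines, i.e. hat functions) so that <code>np.interp</code> can evaluate it. The class name and random initialization are hypothetical; practical implementations typically use cubic splines plus a residual term.</p>

```python
import numpy as np

class PiecewiseLinearKANLayer:
    """y_j = sum_i phi_{i,j}(x_i), with each phi_{i,j} piecewise linear."""

    def __init__(self, n_in, n_out, grid_size=5, lo=-1.0, hi=1.0, seed=0):
        rng = np.random.default_rng(seed)
        self.knots = np.linspace(lo, hi, grid_size + 1)
        # values[i, j, g] = phi_{i,j}(knots[g]); these are the learnable parameters.
        self.values = 0.1 * rng.standard_normal((n_in, n_out, grid_size + 1))

    def edge(self, i, j, x):
        """Evaluate phi_{i,j} at x (np.interp clamps outside [lo, hi])."""
        return np.interp(x, self.knots, self.values[i, j])

    def __call__(self, x):
        """Map x of shape (batch, n_in) to shape (batch, n_out)."""
        n_in, n_out, _ = self.values.shape
        y = np.zeros((x.shape[0], n_out))
        for i in range(n_in):
            for j in range(n_out):
                y[:, j] += self.edge(i, j, x[:, i])  # node j sums its incoming edges
        return y

layer = PiecewiseLinearKANLayer(n_in=3, n_out=2)
x = np.random.default_rng(1).uniform(-1.0, 1.0, size=(4, 3))
print(layer(x).shape)  # (4, 2)
```

<p>Stacking layers, for example <code>layer2(layer1(x))</code>, gives the deeper compositions; because <code>np.interp</code> clamps inputs outside the grid range, a real implementation would also need to keep intermediate values inside it.</p>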
</section>
</section>
<section id="key-components-and-implementation" class="level2">
<h2 class="anchored" data-anchor-id="key-components-and-implementation" id="key-components-and-implementation">Key Components and Implementation</h2>
<section id="b-spline-parameterization" class="level3">
<h3 class="anchored" data-anchor-id="b-spline-parameterization" id="b-spline-parameterization">B-Spline Parameterization</h3>
<p>KANs typically use B-splines to parameterize the learnable functions on edges. B-splines offer several advantages:</p>
<ul>
<li><strong>Smoothness</strong>: Provide continuous derivatives up to a specified order</li>
<li><strong>Local Support</strong>: Changes in one region don’t affect distant regions</li>
<li><strong>Flexibility</strong>: Can approximate a wide variety of functions</li>
<li><strong>Computational Efficiency</strong>: Enable efficient computation and differentiation</li>
</ul>
</section>
<section id="grid-structure" class="level3">
<h3 class="anchored" data-anchor-id="grid-structure" id="grid-structure">Grid Structure</h3>
<p>The B-splines are defined over a grid of knots, and each edge's learnable coefficients (control points) weight the resulting basis functions. Key parameters include:</p>
<ul>
<li><strong>Grid Size</strong>: Number of intervals in the spline grid</li>
<li><strong>Spline Order</strong>: Determines smoothness (typically cubic, k=3)</li>
<li><strong>Grid Range</strong>: Input domain coverage for the splines</li>
</ul>
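<p>To see how grid size, spline order, and grid range fit together, the sketch below builds the B-spline basis with the Cox-de Boor recursion. It assumes a uniform grid of <span class="math inline">\(G\)</span> intervals on <span class="math inline">\([\text{lo}, \text{hi}]\)</span>, extended by <span class="math inline">\(k\)</span> extra knots on each side, which yields <span class="math inline">\(G + k\)</span> basis functions of order <span class="math inline">\(k\)</span>. Inside the grid range they are nonnegative and sum to one.</p>

```python
import numpy as np

def bspline_basis(x, grid_size=5, order=3, lo=-1.0, hi=1.0):
    """Evaluate all B-spline basis functions of the given order at points x.

    Returns an array of shape (len(x), grid_size + order).
    """
    h = (hi - lo) / grid_size
    # Uniform knots, extended by `order` knots beyond each end of [lo, hi].
    t = lo + h * np.arange(-order, grid_size + order + 1)
    x = np.asarray(x, dtype=float)[:, None]
    # Order 0: indicator of each knot interval [t_i, t_{i+1}).
    B = ((x >= t[:-1]) & (x < t[1:])).astype(float)
    for p in range(1, order + 1):
        # Cox-de Boor: blend neighbouring order-(p-1) bases with linear weights.
        left = (x - t[: -(p + 1)]) / (t[p:-1] - t[: -(p + 1)]) * B[:, :-1]
        right = (t[p + 1 :] - x) / (t[p + 1 :] - t[1:-p]) * B[:, 1:]
        B = left + right
    return B

print(bspline_basis([0.0, 0.5]).shape)  # (2, 8): G + k = 5 + 3
```

<p>An edge function is then a weighted sum of these columns, <span class="math inline">\(\phi(x) = \sum_g c_g B_g(x)\)</span>, with the coefficients <span class="math inline">\(c_g\)</span> as its learnable parameters.</p>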
</section>
<section id="residual-connections" class="level3">
<h3 class="anchored" data-anchor-id="residual-connections" id="residual-connections">Residual Connections</h3>
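<p>Modern KAN implementations often include a residual-style term to improve training stability and enable deeper networks. Each edge function adds a simple base function to its spline, with both parts scaled by learnable weights:</p>
<p><span class="math display">\[
\phi_{i,j}(x) = w_b \, b(x) + w_s \, \text{spline}(x)
\]</span></p>
<p>where <span class="math inline">\(b(x)\)</span> is the base function (SiLU, <span class="math inline">\(b(x) = x / (1 + e^{-x})\)</span>, in the original KAN paper; some implementations use a linear term instead) and <span class="math inline">\(w_b, w_s\)</span> are learnable scalars.</p>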
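<p>A sketch of a single edge function with a residual-style term, assuming SiLU as the base function and a piecewise-linear interpolant as a stand-in for the cubic spline. Both are simplifications, and <code>edge_function</code> is a hypothetical helper.</p>

```python
import numpy as np

def silu(x):
    """SiLU (swish): x * sigmoid(x)."""
    return x / (1.0 + np.exp(-x))

def edge_function(x, knots, spline_values, w_b=1.0, w_s=1.0):
    """phi(x) = w_b * silu(x) + w_s * spline(x).

    The spline is approximated by piecewise-linear interpolation of
    `spline_values` at `knots` (a simplification of a cubic B-spline).
    """
    spline = np.interp(x, knots, spline_values)
    return w_b * silu(x) + w_s * spline

knots = np.linspace(-2.0, 2.0, 9)
x = np.linspace(-2.0, 2.0, 5)
# With an all-zero spline, the edge reduces to the smooth residual path.
print(edge_function(x, knots, np.zeros_like(knots), w_b=0.5))
```

<p>The residual path gives every edge a well-behaved nonzero gradient even when the spline is still close to zero, which is one reason such terms help early training.</p>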
</section>
</section>
<section id="training-process" class="level2">
<h2 class="anchored" data-anchor-id="training-process" id="training-process">Training Process</h2>
<section id="forward-pass" class="level3">
<h3 class="anchored" data-anchor-id="forward-pass" id="forward-pass">Forward Pass</h3>
<ol type="1">
<li><strong>Input Processing</strong>: Input features are fed to the first layer</li>
<li><strong>Edge Function Evaluation</strong>: Each edge computes its learnable function</li>
<li><strong>Node Summation</strong>: Output nodes sum contributions from all incoming edges</li>
<li><strong>Layer Propagation</strong>: Process repeats through subsequent layers</li>
</ol>
</section>
<section id="backward-pass" class="level3">
<h3 class="anchored" data-anchor-id="backward-pass" id="backward-pass">Backward Pass</h3>
<p>Training KANs requires computing gradients with respect to:</p>
<ul>
<li><strong>Spline Coefficients</strong>: Control points of B-spline functions</li>
<li><strong>Grid Points</strong>: Locations of spline knots (in adaptive variants)</li>
<li><strong>Scaling Parameters</strong>: Normalization factors for inputs/outputs</li>
</ul>
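<p>Because each edge function is a linear combination of fixed basis functions, <span class="math inline">\(\phi(x) = \sum_g c_g B_g(x)\)</span>, the derivative of an edge output with respect to a spline coefficient is simply the basis value, <span class="math inline">\(\partial \phi / \partial c_g = B_g(x)\)</span>. Backpropagation then chains this through node sums and later layers. The check below assumes degree-1 (hat) basis functions on a uniform grid, a simplification, and compares the analytic gradient with a central finite difference.</p>

```python
import numpy as np

knots = np.linspace(-1.0, 1.0, 6)

def hat_basis(x):
    """Degree-1 B-spline (hat) basis values at scalar x, one per knot."""
    eye = np.eye(len(knots))
    return np.array([np.interp(x, knots, eye[g]) for g in range(len(knots))])

def phi(x, coeffs):
    """Edge function phi(x) = sum_g c_g * B_g(x)."""
    return float(coeffs @ hat_basis(x))

coeffs = np.random.default_rng(0).standard_normal(len(knots))
x0 = 0.3
analytic = hat_basis(x0)  # d phi / d c_g = B_g(x0)

eps = 1e-6
numeric = np.array([
    (phi(x0, coeffs + eps * e) - phi(x0, coeffs - eps * e)) / (2 * eps)
    for e in np.eye(len(knots))
])
print(np.max(np.abs(analytic - numeric)))  # tiny, since phi is linear in coeffs
```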
</section>
<section id="optimization-challenges" class="level3">
<h3 class="anchored" data-anchor-id="optimization-challenges" id="optimization-challenges">Optimization Challenges</h3>
<ul>
<li><strong>Non-convexity</strong>: Multiple local minima in the loss landscape</li>
<li><strong>Grid Adaptation</strong>: Dynamically adjusting spline grids during training</li>
<li><strong>Regularization</strong>: Preventing overfitting in high-capacity edge functions</li>
</ul>
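<p>Grid adaptation often means refining the grid without changing the function already learned. For the piecewise-linear stand-in used here (an assumption; cubic splines need a least-squares refit instead), refinement is exact whenever the new grid contains the old knots: the new parameters are the old function sampled at the new knots. <code>refine_grid</code> is a hypothetical helper.</p>

```python
import numpy as np

def refine_grid(knots, values, factor=2):
    """Return a finer uniform grid and values reproducing the same
    piecewise-linear function; the new grid contains every old knot."""
    new_knots = np.linspace(knots[0], knots[-1], (len(knots) - 1) * factor + 1)
    new_values = np.interp(new_knots, knots, values)
    return new_knots, new_values

knots = np.linspace(-1.0, 1.0, 6)  # grid size 5
values = np.random.default_rng(0).standard_normal(knots.size)
fine_knots, fine_values = refine_grid(knots, values, factor=4)  # grid size 20

x = np.linspace(-1.0, 1.0, 1001)
before = np.interp(x, knots, values)
after = np.interp(x, fine_knots, fine_values)
print(fine_knots.size, np.max(np.abs(before - after)))  # 21, then ~0 (round-off)
```

<p>After refinement, training continues from an identical function but with more coefficients available to capture finer detail.</p>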
</section>
</section>
<section id="advantages-of-kans" class="level2">
<h2 class="anchored" data-anchor-id="advantages-of-kans" id="advantages-of-kans">Advantages of KANs</h2>
<section id="enhanced-interpretability" class="level3">
<h3 class="anchored" data-anchor-id="enhanced-interpretability" id="enhanced-interpretability">Enhanced Interpretability</h3>
<p>KANs offer superior interpretability compared to traditional MLPs:</p>
<ul>
<li><strong>Function Visualization</strong>: Edge functions can be plotted and analyzed</li>
<li><strong>Feature Attribution</strong>: Direct observation of how inputs transform through the network</li>
<li><strong>Symbolic Regression</strong>: Potential for discovering analytical expressions</li>
</ul>
</section>
<section id="parameter-efficiency" class="level3">
<h3 class="anchored" data-anchor-id="parameter-efficiency" id="parameter-efficiency">Parameter Efficiency</h3>
<p>Despite their flexibility, KANs have been reported to match or exceed MLP accuracy with fewer parameters on small-scale function-fitting tasks:</p>
<ul>
<li><strong>Targeted Learning</strong>: Functions are learned where needed (on edges)</li>
<li><strong>Shared Complexity</strong>: Similar transformations can be learned across different edges</li>
<li><strong>Adaptive Complexity</strong>: Grid refinement allows dynamic complexity adjustment</li>
</ul>
</section>
<section id="better-generalization" class="level3">
<h3 class="anchored" data-anchor-id="better-generalization" id="better-generalization">Better Generalization</h3>
<p>KANs may also generalize well, for several reasons:</p>
<ul>
<li><strong>Inductive Bias</strong>: Architecture naturally incorporates smooth function assumptions</li>
<li><strong>Regularization</strong>: B-spline smoothness acts as implicit regularization</li>
<li><strong>Feature Learning</strong>: Automatic discovery of relevant transformations</li>
</ul>
</section>
</section>
<section id="applications-and-use-cases" class="level2">
<h2 class="anchored" data-anchor-id="applications-and-use-cases" id="applications-and-use-cases">Applications and Use Cases</h2>
<section id="scientific-computing" class="level3">
<h3 class="anchored" data-anchor-id="scientific-computing" id="scientific-computing">Scientific Computing</h3>
<p>KANs are well suited to scientific applications where interpretability is crucial:</p>
<ul>
<li><strong>Physics Modeling</strong>: Discovering governing equations from data</li>
<li><strong>Material Science</strong>: Property prediction with interpretable relationships</li>
<li><strong>Climate Modeling</strong>: Understanding complex environmental interactions</li>
</ul>
</section>
<section id="function-approximation" class="level3">
<h3 class="anchored" data-anchor-id="function-approximation" id="function-approximation">Function Approximation</h3>
<p>KANs are a natural fit for problems requiring accurate function approximation:</p>
<ul>
<li><strong>Regression Tasks</strong>: Continuous function learning with high accuracy</li>
<li><strong>Time Series</strong>: Modeling temporal dependencies with interpretable components</li>
<li><strong>Control Systems</strong>: Learning control policies with explainable behavior</li>
</ul>
</section>
<section id="symbolic-regression" class="level3">
<h3 class="anchored" data-anchor-id="symbolic-regression" id="symbolic-regression">Symbolic Regression</h3>
<p>KANs can facilitate symbolic regression tasks:</p>
<ul>
<li><strong>Equation Discovery</strong>: Finding analytical expressions for data relationships</li>
<li><strong>Scientific Discovery</strong>: Uncovering natural laws from experimental data</li>
<li><strong>Feature Engineering</strong>: Automatic discovery of useful feature transformations</li>
</ul>
</section>
</section>
<section id="implementation-considerations" class="level2">
<h2 class="anchored" data-anchor-id="implementation-considerations" id="implementation-considerations">Implementation Considerations</h2>
<section id="computational-complexity" class="level3">
<h3 class="anchored" data-anchor-id="computational-complexity" id="computational-complexity">Computational Complexity</h3>
<p><strong>Memory Requirements:</strong></p>
<ul>
<li>B-spline coefficient storage</li>
<li>Grid point management</li>
<li>Intermediate activation storage</li>
</ul>
<p><strong>Computational Cost:</strong></p>
<ul>
<li>Spline evaluation overhead</li>
<li>Grid adaptation algorithms</li>
<li>Gradient computation complexity</li>
</ul>
</section>
<section id="hyperparameter-tuning" class="level3">
<h3 class="anchored" data-anchor-id="hyperparameter-tuning" id="hyperparameter-tuning">Hyperparameter Tuning</h3>
<p>Critical hyperparameters for KANs:</p>
<ul>
<li><strong>Grid Size</strong>: Balance between expressiveness and computational cost</li>
<li><strong>Spline Order</strong>: Trade-off between smoothness and flexibility</li>
<li><strong>Network Depth</strong>: Number of KAN layers</li>
<li><strong>Width</strong>: Number of nodes per layer</li>
</ul>
</section>
<section id="software-implementation" class="level3">
<h3 class="anchored" data-anchor-id="software-implementation" id="software-implementation">Software Implementation</h3>
<p>Popular KAN implementations:</p>
<ul>
<li><strong>PyKAN</strong>: Official implementation with comprehensive features</li>
<li><strong>TensorFlow/PyTorch</strong>: Custom implementations and third-party libraries</li>
<li><strong>JAX</strong>: High-performance implementations for research</li>
</ul>
</section>
</section>
<section id="current-limitations-and-challenges" class="level2">
<h2 class="anchored" data-anchor-id="current-limitations-and-challenges" id="current-limitations-and-challenges">Current Limitations and Challenges</h2>
<section id="scalability-issues" class="level3">
<h3 class="anchored" data-anchor-id="scalability-issues" id="scalability-issues">Scalability Issues</h3>
<ul>
<li><strong>Memory Overhead</strong>: Higher memory requirements compared to MLPs</li>
<li><strong>Training Time</strong>: Longer training due to complex function optimization</li>
<li><strong>Large-Scale Applications</strong>: Challenges in scaling to very large datasets</li>
</ul>
</section>
<section id="theoretical-gaps" class="level3">
<h3 class="anchored" data-anchor-id="theoretical-gaps" id="theoretical-gaps">Theoretical Gaps</h3>
<ul>
<li><strong>Approximation Theory</strong>: Limited theoretical understanding of approximation capabilities</li>
<li><strong>Optimization Landscape</strong>: Incomplete analysis of loss surface properties</li>
<li><strong>Generalization Bounds</strong>: Lack of theoretical generalization guarantees</li>
</ul>
</section>
<section id="practical-considerations" class="level3">
<h3 class="anchored" data-anchor-id="practical-considerations" id="practical-considerations">Practical Considerations</h3>
<ul>
<li><strong>Implementation Complexity</strong>: More complex to implement than standard MLPs</li>
<li><strong>Debugging Difficulty</strong>: Harder to diagnose training issues</li>
<li><strong>Limited Tooling</strong>: Fewer established best practices and tools</li>
</ul>
</section>
</section>
<section id="recent-developments-and-research-directions" class="level2">
<h2 class="anchored" data-anchor-id="recent-developments-and-research-directions" id="recent-developments-and-research-directions">Recent Developments and Research Directions</h2>
<section id="architectural-innovations" class="level3">
<h3 class="anchored" data-anchor-id="architectural-innovations" id="architectural-innovations">Architectural Innovations</h3>
<ul>
<li><strong>Multi-dimensional KANs</strong>: Extensions to handle tensor inputs directly</li>
<li><strong>Convolutional KANs</strong>: Integration with convolutional architectures</li>
<li><strong>Recurrent KANs</strong>: Application to sequential data processing</li>
</ul>
</section>
<section id="optimization-improvements" class="level3">
<h3 class="anchored" data-anchor-id="optimization-improvements" id="optimization-improvements">Optimization Improvements</h3>
<ul>
<li><strong>Adaptive Grids</strong>: Dynamic grid refinement during training</li>
<li><strong>Regularization Techniques</strong>: Novel approaches to prevent overfitting</li>
<li><strong>Training Algorithms</strong>: Specialized optimizers for KAN training</li>
</ul>
</section>
<section id="application-expansions" class="level3">
<h3 class="anchored" data-anchor-id="application-expansions" id="application-expansions">Application Expansions</h3>
<ul>
<li><strong>Computer Vision</strong>: Exploring KANs for image processing tasks</li>
<li><strong>Natural Language Processing</strong>: Investigating applications in text analysis</li>
<li><strong>Reinforcement Learning</strong>: Using KANs for policy and value function approximation</li>
</ul>
</section>
</section>
<section id="comparison-with-other-architectures" class="level2">
<h2 class="anchored" data-anchor-id="comparison-with-other-architectures" id="comparison-with-other-architectures">Comparison with Other Architectures</h2>
<section id="kans-vs-mlps" class="level3">
<h3 class="anchored" data-anchor-id="kans-vs-mlps" id="kans-vs-mlps">KANs vs MLPs</h3>
<table class="caption-top table">
<thead>
<tr class="header">
<th>Aspect</th>
<th>KANs</th>
<th>MLPs</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>Activation Location</td>
<td>Edges</td>
<td>Nodes</td>
</tr>
<tr class="even">
<td>Interpretability</td>
<td>High</td>
<td>Low</td>
</tr>
<tr class="odd">
<td>Parameter Efficiency</td>
<td>Often Better</td>
<td>Standard</td>
</tr>
<tr class="even">
<td>Training Complexity</td>
<td>Higher</td>
<td>Lower</td>
</tr>
<tr class="odd">
<td>Computational Cost</td>
<td>Higher</td>
<td>Lower</td>
</tr>
</tbody>
</table>
</section>
<section id="kans-vs-transformers" class="level3">
<h3 class="anchored" data-anchor-id="kans-vs-transformers" id="kans-vs-transformers">KANs vs Transformers</h3>
<p>While Transformers excel in sequence modeling, KANs offer advantages in:</p>
<ul>
<li><strong>Interpretability</strong>: Clear function visualization</li>
<li><strong>Scientific Applications</strong>: Natural fit for physics-based problems</li>
<li><strong>Small Data Regimes</strong>: Potentially better performance with limited training data</li>
</ul>
</section>
<section id="kans-vs-decision-trees" class="level3">
<h3 class="anchored" data-anchor-id="kans-vs-decision-trees" id="kans-vs-decision-trees">KANs vs Decision Trees</h3>
<p>Both offer interpretability, but differ in:</p>
<ul>
<li><strong>Function Types</strong>: Continuous vs.&nbsp;piecewise constant</li>
<li><strong>Expressiveness</strong>: Higher capacity in KANs</li>
<li><strong>Training</strong>: Gradient-based vs.&nbsp;greedy splitting</li>
</ul>
</section>
</section>
<section id="future-outlook" class="level2">
<h2 class="anchored" data-anchor-id="future-outlook" id="future-outlook">Future Outlook</h2>
<section id="emerging-trends" class="level3">
<h3 class="anchored" data-anchor-id="emerging-trends" id="emerging-trends">Emerging Trends</h3>
<ul>
<li><strong>Hybrid Architectures</strong>: Combining KANs with other neural network types</li>
<li><strong>Automated Design</strong>: Using neural architecture search for KAN optimization</li>
<li><strong>Hardware Acceleration</strong>: Specialized hardware for efficient KAN computation</li>
</ul>
</section>
<section id="research-opportunities" class="level3">
<h3 class="anchored" data-anchor-id="research-opportunities" id="research-opportunities">Research Opportunities</h3>
<ul>
<li><strong>Theoretical Foundations</strong>: Developing rigorous theoretical frameworks</li>
<li><strong>Scalability Solutions</strong>: Addressing computational and memory challenges</li>
<li><strong>Domain-Specific Variants</strong>: Tailoring KANs for specific application domains</li>
</ul>
</section>
<section id="industry-adoption" class="level3">
<h3 class="anchored" data-anchor-id="industry-adoption" id="industry-adoption">Industry Adoption</h3>
<ul>
<li><strong>Scientific Software</strong>: Integration into computational science tools</li>
<li><strong>Interpretable AI</strong>: Applications requiring explainable machine learning</li>
<li><strong>Edge Computing</strong>: Optimized implementations for resource-constrained environments</li>
</ul>
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>Kolmogorov-Arnold Networks represent a significant advancement in neural network architecture design, offering a compelling alternative to traditional MLPs. Their foundation in mathematical theory, combined with enhanced interpretability and parameter efficiency, makes them particularly valuable for scientific computing and applications requiring explainable AI.</p>
<p>While challenges remain in terms of computational complexity and scalability, ongoing research continues to address these limitations. As the field matures, KANs are likely to find increased adoption in domains where interpretability and mathematical rigor are paramount.</p>
<p>The future of KANs looks promising, with active research communities working on theoretical foundations, practical implementations, and novel applications. As our understanding of these networks deepens and computational tools improve, KANs may well become a standard tool in the machine learning practitioner’s toolkit.</p>
</section>
<section id="references-and-further-reading" class="level2">
<h2 class="anchored" data-anchor-id="references-and-further-reading" id="references-and-further-reading">References and Further Reading</h2>
<ul>
<li>Original KAN Paper: “KAN: Kolmogorov-Arnold Networks” (Liu et al., 2024)</li>
<li>Kolmogorov-Arnold Representation Theorem: Original mathematical foundations</li>
<li>B-Spline Theory: Mathematical background for function parameterization</li>
<li>Scientific Computing Applications: Domain-specific KAN implementations</li>
<li>Interpretable Machine Learning: Broader context for explainable AI methods</li>
</ul>
<hr>
<p><em>This article provides a comprehensive introduction to Kolmogorov-Arnold Networks. For the latest developments and implementations, readers are encouraged to follow recent research publications and open-source projects in the field.</em></p>



</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[PyTorch 2.x Compilation Pipeline: From FX to Hardware]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/deployment/pytorch-to-end/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/deployment/pytorch-to-end/</guid>
      <pubDate>Mon, 30 Jun 2025 00:00:00 GMT</pubDate>
      
      <category>mlops</category>
      <category>code</category>
      <category>advanced</category>
      <content:encoded><![CDATA[






<section id="pytorch-2.x-compilation-pipeline-from-fx-to-hardware" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/deployment/pytorch-to-end/pipeline-banner.png" class="img-fluid"></p>
<section id="overview" class="level2">
<h2 class="anchored" data-anchor-id="overview" id="overview">Overview</h2>
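<p>PyTorch 2.x introduced a compilation stack, exposed through <code>torch.compile()</code>, that captures Python model code as graphs and lowers them to optimized kernels. This guide walks through the pipeline: <strong>PyTorch → FX → Inductor → Backend (Triton/C++) → Hardware (GPU/CPU)</strong>. TorchDynamo captures FX graphs from Python bytecode (with AOTAutograd tracing the backward pass for training), TorchInductor optimizes them, and Inductor emits Triton kernels for GPUs or C++ code for CPUs. NvFuser, NVIDIA’s fusion compiler, is also covered as a related GPU option, although it is not one of Inductor’s code generators.</p>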
</section>
<section id="the-big-picture" class="level2">
<h2 class="anchored" data-anchor-id="the-big-picture" id="the-big-picture">The Big Picture</h2>
<p><img src="https://theja-vanka.github.io/blogs/posts/deployment/pytorch-to-end/pypipeend.png" class="img-fluid"></p>
<p>The compilation pipeline transforms dynamic Python code into static, optimized kernels that run directly on hardware.</p>
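<p>One way to watch the capture step is to hand <code>torch.compile</code> a custom backend that simply records the FX graphs it receives. The sketch below assumes a recent PyTorch 2.x release; <code>recording_backend</code> and <code>f</code> are hypothetical names, and <code>torch._dynamo.graph_break()</code> forces a split to stand in for Python code that cannot be captured. Each captured region arrives as a separate FX <code>GraphModule</code>, and the code between regions runs as ordinary Python.</p>

```python
import torch
import torch._dynamo
import torch.fx

captured_graphs = []

def recording_backend(gm: torch.fx.GraphModule, example_inputs):
    """Hypothetical backend: keep each captured FX graph and run it unoptimized."""
    captured_graphs.append(gm)
    return gm.forward

def f(x):
    y = torch.sin(x) + 1
    torch._dynamo.graph_break()  # stands in for an unsupported Python construct
    return torch.cos(y)

torch._dynamo.reset()  # start from an empty compilation cache
compiled_f = torch.compile(f, backend=recording_backend)

x = torch.randn(8)
out = compiled_f(x)

print(f"captured {len(captured_graphs)} graphs")
for gm in captured_graphs:
    print(gm.code)  # Python regenerated from each captured graph
```

<p>Swapping <code>recording_backend</code> for the default <code>"inductor"</code> backend is what turns these captured graphs into generated kernels.</p>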
</section>
<section id="pytorch-fx-graph-capture" class="level2">
<h2 class="anchored" data-anchor-id="pytorch-fx-graph-capture" id="pytorch-fx-graph-capture">PyTorch FX: Graph Capture</h2>
<section id="what-is-fx" class="level3">
<h3 class="anchored" data-anchor-id="what-is-fx" id="what-is-fx">What is FX?</h3>
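<p>FX (<code>torch.fx</code>) is PyTorch’s toolkit for capturing and transforming programs. It records a module’s computation as a <code>Graph</code> of nodes (such as <code>placeholder</code>, <code>call_function</code>, <code>call_module</code>, and <code>output</code>) and regenerates ordinary Python code from that graph, so a transformed model is still a regular <code>nn.Module</code>. This keeps the graph easy to inspect and edit from Python while giving compilers a whole-program view to optimize.</p>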
</section>
<section id="basic-fx-usage" class="level3">
<h3 class="anchored" data-anchor-id="basic-fx-usage" id="basic-fx-usage">Basic FX Usage</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.fx <span class="im">as</span> fx</span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> SimpleModel(torch.nn.Module):</span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>):</span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.linear <span class="op">=</span> torch.nn.Linear(<span class="dv">10</span>, <span class="dv">5</span>)</span>
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb1-9"><a href="#cb1-9" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb1-10"><a href="#cb1-10" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.linear(x)</span>
<span id="cb1-11"><a href="#cb1-11" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> torch.relu(x)</span>
<span id="cb1-12"><a href="#cb1-12" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x <span class="op">*</span> <span class="dv">2</span></span>
<span id="cb1-13"><a href="#cb1-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-14"><a href="#cb1-14" aria-hidden="true" tabindex="-1"></a><span class="co"># Create and trace the model</span></span>
<span id="cb1-15"><a href="#cb1-15" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> SimpleModel()</span>
<span id="cb1-16"><a href="#cb1-16" aria-hidden="true" tabindex="-1"></a>traced_model <span class="op">=</span> fx.symbolic_trace(model)</span>
<span id="cb1-17"><a href="#cb1-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-18"><a href="#cb1-18" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"FX Graph:"</span>)</span>
<span id="cb1-19"><a href="#cb1-19" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(traced_model.graph)</span></code></pre></div></div>
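<p>Printing the graph is useful for a quick look; to work with it programmatically, iterate over its nodes. Each node carries an opcode (<code>node.op</code>) and the thing it calls (<code>node.target</code>). The snippet below is a self-contained sketch that redefines a model shaped like <code>SimpleModel</code> under the hypothetical name <code>TinyModel</code>.</p>

```python
import torch
import torch.fx as fx

class TinyModel(torch.nn.Module):
    """Hypothetical stand-in for SimpleModel."""
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 5)

    def forward(self, x):
        return torch.relu(self.linear(x)) * 2

traced = fx.symbolic_trace(TinyModel())

# Each node records what kind of operation it is (op) and what it calls (target)
node_ops = []
for node in traced.graph.nodes:
    node_ops.append(node.op)
    print(f"{node.op:<14} {node.name:<8} {node.target}")
```

<p>The submodule call shows up as <code>call_module</code>, while <code>torch.relu</code> and the multiplication by 2 are recorded as <code>call_function</code> nodes.</p>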
</section>
<section id="fx-graph-representation" class="level3">
<h3 class="anchored" data-anchor-id="fx-graph-representation" id="fx-graph-representation">FX Graph Representation</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Approximate output of print(traced_model.code); exact formatting varies by version.</span></span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a><span class="co"># nn.Linear is a leaf module by default, so it stays one call_module node.</span></span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a>    linear <span class="op">=</span> <span class="va">self</span>.linear(x)</span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a>    relu <span class="op">=</span> torch.relu(linear)</span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a>    mul <span class="op">=</span> relu <span class="op">*</span> <span class="dv">2</span></span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> mul</span></code></pre></div></div>
</section>
<section id="manual-fx-transformations" class="level3">
<h3 class="anchored" data-anchor-id="manual-fx-transformations" id="manual-fx-transformations">Manual FX Transformations</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.fx <span class="im">as</span> fx</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> replace_relu_with_gelu(model: fx.GraphModule) <span class="op">-&gt;</span> fx.GraphModule:</span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Replace torch.relu calls with GELU (F.relu and nn.ReLU nodes are not matched)"""</span></span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> node <span class="kw">in</span> model.graph.nodes:</span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> node.op <span class="op">==</span> <span class="st">'call_function'</span> <span class="kw">and</span> node.target <span class="op">==</span> torch.relu:</span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a>            node.target <span class="op">=</span> torch.nn.functional.gelu</span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a>    model.recompile()  <span class="co"># regenerate forward() from the edited graph</span></span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> model</span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a><span class="co"># Apply the transformation (edits traced_model in place and returns it)</span></span>
<span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a>transformed_model <span class="op">=</span> replace_relu_with_gelu(traced_model)</span></code></pre></div></div>
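<p>Because a graph rewrite changes numerics, it is worth checking the transformed module against an eager reference. A self-contained sketch follows; <code>Net</code> and <code>relu_to_gelu</code> are hypothetical names. Unlike the function above, it also matches <code>torch.nn.functional.relu</code>, which FX typically records with an <code>inplace</code> keyword argument that has to be dropped because <code>gelu</code> does not accept it. <code>nn.ReLU</code> submodules would appear as <code>call_module</code> nodes and still need separate handling.</p>

```python
import torch
import torch.fx as fx
import torch.nn.functional as F

class Net(torch.nn.Module):
    """Hypothetical model that uses both functional ReLU spellings."""
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 5)

    def forward(self, x):
        return F.relu(torch.relu(self.linear(x)) - 0.5)

def relu_to_gelu(gm: fx.GraphModule) -> fx.GraphModule:
    """Swap functional ReLU calls for GELU in place, then regenerate the code."""
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target in (torch.relu, F.relu):
            node.target = F.gelu
            node.kwargs = {}  # drop F.relu's inplace= argument, which gelu does not accept
    gm.graph.lint()  # check graph invariants after editing
    gm.recompile()   # rebuild forward() from the edited graph
    return gm

model = Net()
gm = relu_to_gelu(fx.symbolic_trace(model))

x = torch.randn(4, 10)
expected = F.gelu(F.gelu(model.linear(x)) - 0.5)  # eager reference
print(torch.allclose(gm(x), expected))
```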
</section>
<section id="key-features-of-fx" class="level3">
<h3 class="anchored" data-anchor-id="key-features-of-fx" id="key-features-of-fx">Key Features of FX</h3>
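<p><strong>Symbolic Tracing</strong>: <code>fx.symbolic_trace</code> runs <code>forward</code> with proxy objects instead of real tensors and records each operation as a graph node. Python control flow that does not depend on tensor values is unrolled into straight-line code, but data-dependent branches cannot be traced this way. <code>torch.compile</code> instead captures FX graphs with TorchDynamo, which analyzes Python bytecode and falls back to eager execution (a “graph break”) wherever it cannot capture.</p>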
<p><strong>Operator-Level Granularity</strong>: The FX graph represents computations at the PyTorch operator level, providing a clean abstraction that’s both human-readable and machine-optimizable.</p>
<p><strong>Transformation Framework</strong>: FX provides a robust system for graph transformations, enabling optimizations like operator fusion, dead code elimination, and layout transformations.</p>
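<p>Dead code elimination is a simple example of a built-in transformation. In the sketch below (<code>Wasteful</code> is a hypothetical model), tracing records an <code>exp</code> whose result is never used, and <code>Graph.eliminate_dead_code()</code> removes it because the node has no users and no side effects.</p>

```python
import torch
import torch.fx as fx

class Wasteful(torch.nn.Module):
    """Hypothetical model that computes a value it never uses."""
    def forward(self, x):
        unused = torch.exp(x)  # traced into the graph, but nothing consumes it
        return x + 1

gm = fx.symbolic_trace(Wasteful())
nodes_before = len(gm.graph.nodes)

gm.graph.eliminate_dead_code()  # drops side-effect-free nodes with no users
gm.recompile()                  # regenerate forward() without the dead node
nodes_after = len(gm.graph.nodes)

print(nodes_before, "->", nodes_after)
print(gm.code)
```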
</section>
</section>
<section id="torchinductor-the-compiler" class="level2">
<h2 class="anchored" data-anchor-id="torchinductor-the-compiler" id="torchinductor-the-compiler">TorchInductor: The Compiler</h2>
<section id="understanding-inductor" class="level3">
<h3 class="anchored" data-anchor-id="understanding-inductor" id="understanding-inductor">Understanding Inductor</h3>
<p>TorchInductor is the default backend of <code>torch.compile</code>. It takes the FX graphs captured by TorchDynamo, lowers them to a loop-level intermediate representation, decides which operations to fuse and which to hand off to library kernels (such as cuBLAS or cuDNN), and generates the code that runs on the target device.</p>
</section>
<section id="core-optimization-strategies" class="level3">
<h3 class="anchored" data-anchor-id="core-optimization-strategies" id="core-optimization-strategies">Core Optimization Strategies</h3>
<p><strong>Operator Fusion</strong>: TorchInductor fuses chains of operators into single kernels so that intermediate results stay in registers or cache instead of making round trips through main memory, which reduces memory-bandwidth pressure. For example, in a <code>conv → batch_norm → relu</code> sequence, the convolution typically still runs as a library kernel (such as cuDNN), while the batch-norm and ReLU arithmetic can be fused into one generated kernel; for inference with freezing enabled, batch norm can also be folded into the convolution weights.</p>
<p><strong>Memory Layout Optimization</strong>: The compiler analyzes how tensors are accessed and can choose layouts that make those accesses contiguous. For example, Inductor may convert convolution tensors to channels-last (NHWC) layout when that speeds up the convolution kernels.</p>
<p><strong>Kernel Selection and Scheduling</strong>: TorchInductor makes intelligent decisions about which backend to use for each operation and how to schedule operations for optimal performance across the entire graph.</p>
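<p>A back-of-the-envelope calculation shows why fusion pays off for memory-bound element-wise work. The helper below is a simplified model assumed for illustration (it is not part of PyTorch): each unfused op is its own kernel that reads its input from memory and writes its output back, whereas a fused kernel reads the input once and writes the final result once. It ignores caches, so measured traffic will differ.</p>

```python
def elementwise_chain_traffic(n_elements: int, n_ops: int, fused: bool,
                              bytes_per_element: int = 4) -> int:
    """Idealized memory traffic (bytes) for a chain of unary element-wise ops.

    Unfused: each op is a separate kernel that reads its input and writes its output.
    Fused: one kernel reads the input once and writes the final result once.
    """
    kernels = 1 if fused else n_ops
    return kernels * 2 * n_elements * bytes_per_element

# One million float32 values through relu(x * 2 + 1), i.e. three element-wise ops
n = 1_000_000
unfused = elementwise_chain_traffic(n, n_ops=3, fused=False)
fused = elementwise_chain_traffic(n, n_ops=3, fused=True)
print(f"unfused: {unfused / 1e6:.0f} MB, fused: {fused / 1e6:.0f} MB")  # 24 MB vs 8 MB
```

<p>Under this model the saving grows linearly with the length of the chain, which is why long runs of element-wise ops are the prime fusion targets.</p>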
</section>
<section id="basic-compilation-with-torch.compile" class="level3">
<h3 class="anchored" data-anchor-id="basic-compilation-with-torch.compile" id="basic-compilation-with-torch.compile">Basic Compilation with torch.compile()</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Simple example</span></span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> simple_function(x, y):</span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> x.matmul(y) <span class="op">+</span> x.<span class="bu">sum</span>(dim<span class="op">=</span><span class="dv">1</span>, keepdim<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Compile the function</span></span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a>compiled_fn <span class="op">=</span> torch.<span class="bu">compile</span>(simple_function)</span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage</span></span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a>x <span class="op">=</span> torch.randn(<span class="dv">1000</span>, <span class="dv">1000</span>, device<span class="op">=</span><span class="st">'cuda'</span>)</span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a>y <span class="op">=</span> torch.randn(<span class="dv">1000</span>, <span class="dv">1000</span>, device<span class="op">=</span><span class="st">'cuda'</span>)</span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-14"><a href="#cb4-14" aria-hidden="true" tabindex="-1"></a><span class="co"># First call triggers compilation</span></span>
<span id="cb4-15"><a href="#cb4-15" aria-hidden="true" tabindex="-1"></a>result <span class="op">=</span> compiled_fn(x, y)</span></code></pre></div></div>
</section>
<section id="compilation-modes" class="level3">
<h3 class="anchored" data-anchor-id="compilation-modes" id="compilation-modes">Compilation Modes</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Different compilation modes</span></span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> torch.nn.Linear(<span class="dv">100</span>, <span class="dv">10</span>).cuda()</span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Default mode (balanced speed/compilation time)</span></span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a>compiled_model_default <span class="op">=</span> torch.<span class="bu">compile</span>(model)</span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a><span class="co"># reduce-overhead: CUDA graphs cut per-call launch overhead (does not shorten compilation)</span></span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a>compiled_model_reduce <span class="op">=</span> torch.<span class="bu">compile</span>(model, mode<span class="op">=</span><span class="st">"reduce-overhead"</span>)</span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a><span class="co"># max-autotune: benchmarks kernel configurations (slower compilation, often faster execution)</span></span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a>compiled_model_max <span class="op">=</span> torch.<span class="bu">compile</span>(model, mode<span class="op">=</span><span class="st">"max-autotune"</span>)</span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a><span class="co"># Testing performance</span></span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a>x <span class="op">=</span> torch.randn(<span class="dv">1000</span>, <span class="dv">100</span>, device<span class="op">=</span><span class="st">'cuda'</span>)</span>
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a></span>
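<span id="cb5-16"><a href="#cb5-16" aria-hidden="true" tabindex="-1"></a><span class="co"># Warmup: the first calls trigger compilation and autotuning</span></span>
<span id="cb5-17"><a href="#cb5-17" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">10</span>):</span>
<span id="cb5-18"><a href="#cb5-18" aria-hidden="true" tabindex="-1"></a>    _ <span class="op">=</span> compiled_model_max(x)</span>
<span id="cb5-19"><a href="#cb5-19" aria-hidden="true" tabindex="-1"></a>torch.cuda.synchronize()</span>
<span id="cb5-20"><a href="#cb5-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-21"><a href="#cb5-21" aria-hidden="true" tabindex="-1"></a><span class="co"># Time steady-state calls with CUDA events (kernel launches are asynchronous)</span></span>
<span id="cb5-22"><a href="#cb5-22" aria-hidden="true" tabindex="-1"></a>start <span class="op">=</span> torch.cuda.Event(enable_timing<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb5-23"><a href="#cb5-23" aria-hidden="true" tabindex="-1"></a>end <span class="op">=</span> torch.cuda.Event(enable_timing<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb5-24"><a href="#cb5-24" aria-hidden="true" tabindex="-1"></a>start.record()</span>
<span id="cb5-25"><a href="#cb5-25" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">100</span>):</span>
<span id="cb5-26"><a href="#cb5-26" aria-hidden="true" tabindex="-1"></a>    _ <span class="op">=</span> compiled_model_max(x)</span>
<span id="cb5-27"><a href="#cb5-27" aria-hidden="true" tabindex="-1"></a>end.record()</span>
<span id="cb5-28"><a href="#cb5-28" aria-hidden="true" tabindex="-1"></a>torch.cuda.synchronize()</span>
<span id="cb5-29"><a href="#cb5-29" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"max-autotune: </span><span class="sc">{</span>start.elapsed_time(end) <span class="op">/</span> <span class="dv">100</span><span class="sc">:.3f}</span><span class="ss"> ms/iter"</span>)</span></code></pre></div></div>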
</section>
<section id="inductor-configuration" class="level3">
<h3 class="anchored" data-anchor-id="inductor-configuration" id="inductor-configuration">Inductor Configuration</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch._inductor.config <span class="im">as</span> config</span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Global Inductor settings (internal and version-dependent; set them before compiling)</span></span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a>config.debug <span class="op">=</span> <span class="va">True</span>  <span class="co"># Enable debug output</span></span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a>config.triton.cudagraphs <span class="op">=</span> <span class="va">True</span>  <span class="co"># Wrap compiled GPU graphs in CUDA graphs</span></span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a>config.cpp_wrapper <span class="op">=</span> <span class="va">True</span>  <span class="co"># Emit the wrapper code in C++ instead of Python</span></span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a>config.freezing <span class="op">=</span> <span class="va">True</span>  <span class="co"># Inline weights as constants (inference only)</span></span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a><span class="co"># Custom optimization settings</span></span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a>config.max_autotune <span class="op">=</span> <span class="va">True</span></span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a>config.epilogue_fusion <span class="op">=</span> <span class="va">True</span></span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a>config.pattern_matcher <span class="op">=</span> <span class="va">True</span></span></code></pre></div></div>
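<p>Because these flags are global, a scoped override is often safer for experiments. The sketch below uses the config module’s <code>patch()</code> context manager, which restores the previous values on exit; per-call settings can alternatively be passed as <code>torch.compile(model, options={"max_autotune": True})</code>. Config attribute names are internal and may change between releases.</p>

```python
import torch._inductor.config as inductor_config

original = inductor_config.max_autotune

# patch() overrides settings inside the block and restores them afterwards
with inductor_config.patch(max_autotune=not original):
    inside = inductor_config.max_autotune
    # compile and run models here so the patched setting applies

after = inductor_config.max_autotune
print(original, inside, after)
```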
</section>
</section>
<section id="backend-targets" class="level2">
<h2 class="anchored" data-anchor-id="backend-targets" id="backend-targets">Backend Targets</h2>
<section id="triton-backend-gpu" class="level3">
<h3 class="anchored" data-anchor-id="triton-backend-gpu" id="triton-backend-gpu">Triton Backend (GPU)</h3>
<p>Triton is a Python-embedded language and compiler for writing GPU kernels. On GPUs, TorchInductor’s default code generator emits Triton kernels, which the Triton compiler lowers to device code (PTX on NVIDIA GPUs).</p>
<p><strong>Advantages of Triton</strong>:</p>
<ul>
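<li>A higher-level, block-based abstraction than raw CUDA, with performance that is often competitive with hand-written kernels</li>
<li>Compiler-managed memory coalescing and shared-memory usage within each block</li>
<li>Built-in support for blocked algorithms and tile-based computation</li>
<li>Kernels take PyTorch tensors directly as arguments; under <code>torch.compile</code>, AOTAutograd traces the backward pass ahead of time so backward kernels are generated too (standalone Triton kernels need a <code>torch.autograd.Function</code> wrapper to support gradients)</li>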
</ul>
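<p>Triton’s programming model is block-based: a kernel is launched over a grid of program instances, each handling one block of elements, with a mask guarding offsets past the end of the data. The pure-Python emulation below is only a teaching aid (it runs sequentially and needs no GPU); <code>cdiv</code> and <code>emulate_blocked_add</code> are hypothetical helpers that mirror the structure of a Triton vector-add kernel.</p>

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division, the same arithmetic triton.cdiv performs."""
    return (a + b - 1) // b

def emulate_blocked_add(x: list, y: list, block_size: int) -> list:
    """Sequential emulation of a masked, blocked vector-add kernel.

    Each outer iteration plays the role of one program instance (pid);
    the inner loop plays the role of the lanes in one block.
    """
    n = len(x)
    out = [0.0] * n
    for pid in range(cdiv(n, block_size)):   # the launch grid
        block_start = pid * block_size
        for lane in range(block_size):       # like tl.arange(0, BLOCK_SIZE)
            offset = block_start + lane
            if offset < n:                   # the mask skips the ragged tail
                out[offset] = x[offset] + y[offset]
    return out

x = [float(i) for i in range(10)]
y = [1.0] * 10
print(cdiv(10, 4))  # 3 program instances: blocks of 4, 4, and 2 active lanes
print(emulate_blocked_add(x, y, block_size=4))
```

<p>On a GPU, the program instances and the lanes within each block run in parallel rather than in nested loops.</p>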
<p><strong>Typical Triton workflow</strong>:</p>
<ol type="1">
<li>TorchInductor generates Triton kernel code based on the fused operations</li>
<li>Triton compiler optimizes the kernel for the target GPU architecture</li>
<li>Compiled kernels are cached on disk so that later runs can skip recompilation</li>
</ol>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="co"># A hand-written Triton kernel for element-wise addition (in the style of the Triton tutorials)</span></span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> triton</span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> triton.language <span class="im">as</span> tl</span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a><span class="at">@triton.jit</span></span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):</span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a>    pid <span class="op">=</span> tl.program_id(axis<span class="op">=</span><span class="dv">0</span>)</span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a>    block_start <span class="op">=</span> pid <span class="op">*</span> BLOCK_SIZE</span>
<span id="cb7-10"><a href="#cb7-10" aria-hidden="true" tabindex="-1"></a>    offsets <span class="op">=</span> block_start <span class="op">+</span> tl.arange(<span class="dv">0</span>, BLOCK_SIZE)</span>
<span id="cb7-11"><a href="#cb7-11" aria-hidden="true" tabindex="-1"></a>    mask <span class="op">=</span> offsets <span class="op">&lt;</span> n_elements</span>
<span id="cb7-12"><a href="#cb7-12" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-13"><a href="#cb7-13" aria-hidden="true" tabindex="-1"></a>    x <span class="op">=</span> tl.load(x_ptr <span class="op">+</span> offsets, mask<span class="op">=</span>mask)</span>
<span id="cb7-14"><a href="#cb7-14" aria-hidden="true" tabindex="-1"></a>    y <span class="op">=</span> tl.load(y_ptr <span class="op">+</span> offsets, mask<span class="op">=</span>mask)</span>
<span id="cb7-15"><a href="#cb7-15" aria-hidden="true" tabindex="-1"></a>    output <span class="op">=</span> x <span class="op">+</span> y</span>
<span id="cb7-16"><a href="#cb7-16" aria-hidden="true" tabindex="-1"></a>    tl.store(output_ptr <span class="op">+</span> offsets, output, mask<span class="op">=</span>mask)</span>
<span id="cb7-17"><a href="#cb7-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-18"><a href="#cb7-18" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> triton_add(x: torch.Tensor, y: torch.Tensor):</span>
<span id="cb7-19"><a href="#cb7-19" aria-hidden="true" tabindex="-1"></a>    output <span class="op">=</span> torch.empty_like(x)</span>
<span id="cb7-20"><a href="#cb7-20" aria-hidden="true" tabindex="-1"></a>    n_elements <span class="op">=</span> output.numel()</span>
<span id="cb7-21"><a href="#cb7-21" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-22"><a href="#cb7-22" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Launch kernel</span></span>
<span id="cb7-23"><a href="#cb7-23" aria-hidden="true" tabindex="-1"></a>    grid <span class="op">=</span> <span class="kw">lambda</span> meta: (triton.cdiv(n_elements, meta[<span class="st">'BLOCK_SIZE'</span>]),)</span>
<span id="cb7-24"><a href="#cb7-24" aria-hidden="true" tabindex="-1"></a>    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE<span class="op">=</span><span class="dv">1024</span>)</span>
<span id="cb7-25"><a href="#cb7-25" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> output</span>
<span id="cb7-26"><a href="#cb7-26" aria-hidden="true" tabindex="-1"></a></span>
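<span id="cb7-27"><a href="#cb7-27" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage: contiguous CUDA tensors of the same shape</span></span>
<span id="cb7-28"><a href="#cb7-28" aria-hidden="true" tabindex="-1"></a>x <span class="op">=</span> torch.randn(<span class="dv">98432</span>, device<span class="op">=</span><span class="st">'cuda'</span>)</span>
<span id="cb7-29"><a href="#cb7-29" aria-hidden="true" tabindex="-1"></a>y <span class="op">=</span> torch.randn(<span class="dv">98432</span>, device<span class="op">=</span><span class="st">'cuda'</span>)</span>
<span id="cb7-30"><a href="#cb7-30" aria-hidden="true" tabindex="-1"></a>torch.testing.assert_close(triton_add(x, y), x <span class="op">+</span> y)</span>
<span id="cb7-31"><a href="#cb7-31" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-32"><a href="#cb7-32" aria-hidden="true" tabindex="-1"></a><span class="co"># Inductor-generated kernels share this load/compute/store structure,</span></span>
<span id="cb7-33"><a href="#cb7-33" aria-hidden="true" tabindex="-1"></a><span class="co"># but the code it emits differs in naming, indexing, and tuning details</span></span></code></pre></div></div>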
</section>
<section id="nvfuser-nvidias-fusion-runtime" class="level3">
<h3 class="anchored" data-anchor-id="nvfuser-nvidias-fusion-runtime" id="nvfuser-nvidias-fusion-runtime">NvFuser: NVIDIA’s Fusion Runtime</h3>
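<p>NvFuser is NVIDIA’s fusion compiler for GPUs, specialized in optimizing element-wise operations and reductions. It is not one of TorchInductor’s code generators: it served as the default GPU fuser for TorchScript, was deprecated there in PyTorch 2.0, and continues as a separately developed NVIDIA project.</p>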
<p><strong>NvFuser Strengths</strong>:</p>
<ul>
<li>Deep integration with CUDA runtime and libraries</li>
<li>Sophisticated analysis for memory access patterns</li>
<li>Optimized handling of broadcasting and reduction operations</li>
<li>Advanced techniques like loop unrolling and vectorization</li>
</ul>
</section>
<section id="c-backend-cpu" class="level3">
<h3 class="anchored" data-anchor-id="c-backend-cpu" id="c-backend-cpu">C++ Backend (CPU)</h3>
<p>For CPU execution, TorchInductor generates optimized C++ code that leverages vectorization and multi-threading.</p>
<p><strong>CPU Optimization Features</strong>:</p>
<ul>
<li>SIMD vectorization using AVX, AVX2, and AVX-512 instructions</li>
<li>OpenMP parallelization for multi-core utilization</li>
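<li>Loop tiling to improve cache locality</li>
<li>Matrix multiplications and convolutions typically dispatched to optimized libraries (such as MKL, oneDNN, or OpenBLAS, depending on the build) rather than generated from scratch</li>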
</ul>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb8"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Example of CPU compilation</span></span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a><span class="at">@torch.compile</span></span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> cpu_intensive_function(x):</span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Element-wise ops that Inductor can fuse with the final sum into one C++ loop</span></span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a>    x <span class="op">=</span> torch.sin(x)</span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a>    x <span class="op">=</span> torch.cos(x)</span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a>    x <span class="op">=</span> torch.exp(x)</span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> x.<span class="bu">sum</span>()</span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a><span class="co"># CPU tensor</span></span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a>x_cpu <span class="op">=</span> torch.randn(<span class="dv">10000</span>, <span class="dv">10000</span>)</span>
<span id="cb8-12"><a href="#cb8-12" aria-hidden="true" tabindex="-1"></a>result <span class="op">=</span> cpu_intensive_function(x_cpu)</span></code></pre></div></div>
</section>
<section id="backend-selection" class="level3">
<h3 class="anchored" data-anchor-id="backend-selection" id="backend-selection">Backend Selection</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb9"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Select a backend explicitly ("inductor" is the default)</span></span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a><span class="co"># GPU: with a CUDA model and inputs, Inductor emits Triton kernels</span></span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a>compiled_gpu <span class="op">=</span> torch.<span class="bu">compile</span>(model, backend<span class="op">=</span><span class="st">"inductor"</span>)</span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a><span class="co"># CPU: the same backend emits C++/OpenMP code for a CPU model and inputs</span></span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a>compiled_cpu <span class="op">=</span> torch.<span class="bu">compile</span>(torch.nn.Linear(<span class="dv">100</span>, <span class="dv">10</span>), backend<span class="op">=</span><span class="st">"inductor"</span>)</span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Custom backend: takes an FX GraphModule plus example inputs, returns a callable</span></span>
<span id="cb9-11"><a href="#cb9-11" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> custom_backend(gm, example_inputs):</span>
<span id="cb9-12"><a href="#cb9-12" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Debugging backend: report the graph size, then run it unoptimized"""</span></span>
<span id="cb9-13"><a href="#cb9-13" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Compiling graph with </span><span class="sc">{</span><span class="bu">len</span>(gm.graph.nodes)<span class="sc">}</span><span class="ss"> nodes"</span>)</span>
<span id="cb9-14"><a href="#cb9-14" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> gm.forward  <span class="co"># run the captured graph as-is</span></span>
<span id="cb9-15"><a href="#cb9-15" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-16"><a href="#cb9-16" aria-hidden="true" tabindex="-1"></a>compiled_custom <span class="op">=</span> torch.<span class="bu">compile</span>(model, backend<span class="op">=</span>custom_backend)</span></code></pre></div></div>
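<p>Because a backend is just a callable from a captured <code>GraphModule</code> (plus example inputs) to something callable, FX transformations can run inside it. The sketch below is a hypothetical backend, <code>swap_relu_backend</code>, that rewrites <code>torch.relu</code> into <code>torch.sigmoid</code> and then runs the graph without further optimization. It assumes TorchDynamo records <code>torch.relu</code> as a <code>call_function</code> node targeting <code>torch.relu</code>, and it runs on CPU without needing a C++ compiler.</p>

```python
import torch
import torch._dynamo
import torch.fx

def swap_relu_backend(gm: torch.fx.GraphModule, example_inputs):
    """Hypothetical backend: rewrite the captured graph, then run it as-is."""
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target is torch.relu:
            node.target = torch.sigmoid
    gm.recompile()  # regenerate forward() from the edited graph
    return gm.forward

def g(x):
    return torch.relu(x) - 1

torch._dynamo.reset()  # start from an empty compilation cache
compiled_g = torch.compile(g, backend=swap_relu_backend)

x = torch.randn(6)
result = compiled_g(x)
print(torch.allclose(result, torch.sigmoid(x) - 1))  # the rewrite took effect
```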
</section>
</section>
<section id="hardware-execution" class="level2">
<h2 class="anchored" data-anchor-id="hardware-execution" id="hardware-execution">Hardware Execution</h2>
<section id="gpu-execution-pipeline" class="level3">
<h3 class="anchored" data-anchor-id="gpu-execution-pipeline" id="gpu-execution-pipeline">GPU Execution Pipeline</h3>
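<p>On GPUs, compiled kernels are launched asynchronously on the current CUDA stream, so the host can keep queuing work while the device computes. The surrounding PyTorch runtime handles:</p>
<ul>
<li><strong>Memory Management</strong>: Allocation and reuse of GPU memory through PyTorch’s caching allocator</li>
<li><strong>Stream Ordering</strong>: Kernels run in the order they were queued on their stream, and <code>reduce-overhead</code> mode can replay whole sequences of launches as CUDA graphs</li>
<li><strong>Synchronization</strong>: Managing dependencies between GPU operations</li>
<li><strong>Dynamic Shapes</strong>: By default, a change in input size triggers one recompilation, after which the changing dimension is treated as symbolic so later sizes reuse the same shape-generic code; <code>torch.compile(dynamic=True)</code> requests this behavior up front</li>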
</ul>
</section>
<section id="cpu-execution-optimization" class="level3">
<h3 class="anchored" data-anchor-id="cpu-execution-optimization" id="cpu-execution-optimization">CPU Execution Optimization</h3>
<p>CPU execution focuses on maximizing utilization of available cores and cache hierarchy:</p>
<ul>
<li><strong>Thread Pool Management</strong>: Efficient distribution of work across CPU cores</li>
<li><strong>NUMA Awareness</strong>: Optimizing memory access patterns for multi-socket systems</li>
<li><strong>Cache Optimization</strong>: Minimizing cache misses through intelligent data layout</li>
<li><strong>Vectorization</strong>: Leveraging SIMD instructions for parallel data processing</li>
</ul>
</section>
</section>
<section id="performance-impact-and-benefits" class="level2">
<h2 class="anchored" data-anchor-id="performance-impact-and-benefits" id="performance-impact-and-benefits">Performance Impact and Benefits</h2>
<section id="quantitative-improvements" class="level3">
<h3 class="anchored" data-anchor-id="quantitative-improvements" id="quantitative-improvements">Quantitative Improvements</h3>
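<p>Gains from the PyTorch 2.x compilation pipeline vary widely with the model, hardware, batch size, and numeric precision, so any single figure is only a rough guide:</p>
<ul>
<li><strong>Training</strong>: the PyTorch 2.0 release reported that, across 163 open-source models, compiled training ran about 43% faster on average on an NVIDIA A100; individual models vary considerably around that figure</li>
<li><strong>Inference</strong>: gains can be larger for small batches, where Python and kernel-launch overhead dominate and fusion plus CUDA graphs remove much of it</li>
<li><strong>Memory efficiency</strong>: fused kernels keep intermediate values in registers instead of writing them to global memory and reading them back</li>
<li><strong>Better hardware utilization</strong> across different architectures</li>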
</ul>
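<p>The memory saving from fusion can be counted by hand. Here is a back-of-the-envelope sketch for <code>y = relu(x + z)</code> on fp32 tensors of equal size. It ignores caches and assumes that each unfused op is a separate kernel that reads its inputs from global memory and writes its output back to it.</p>

```python
def bytes_moved(n: int, dtype_bytes: int = 4, fused: bool = False) -> int:
    """Global-memory traffic for y = relu(x + z), where x, z and y each hold n elements."""
    if fused:
        # One kernel: read x and z, write y.
        return 3 * n * dtype_bytes
    # Two kernels: add reads x and z and writes tmp; relu reads tmp and writes y.
    return 5 * n * dtype_bytes


n = 32 * 64 * 56 * 56  # e.g. a batch of 32 feature maps with 64 channels at 56x56
unfused = bytes_moved(n)
fused = bytes_moved(n, fused=True)
print(f"unfused: {unfused / 2**20:.1f} MiB, fused: {fused / 2**20:.1f} MiB, "
      f"saving {1 - fused / unfused:.0%}")
```

For memory-bound elementwise chains like this one, runtime tends to track bytes moved rather than arithmetic. This is why fusion helps even though the math itself is unchanged.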
</section>
<section id="qualitative-advantages" class="level3">
<h3 class="anchored" data-anchor-id="qualitative-advantages" id="qualitative-advantages">Qualitative Advantages</h3>
<p><strong>Ease of Use</strong>: Developers can often get these gains with minimal code changes, typically by wrapping a model or function with <code>torch.compile()</code> or applying it as a decorator.</p>
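<p><code>torch.compile</code> can either be called on a module or function or be used as a decorator. Below is a minimal sketch of both forms with a toy model. The layer sizes and the function name are arbitrary. Compilation is lazy, so nothing is compiled until the first call.</p>

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 10))

# Function form: wraps the module; the wrapper shares the original parameters.
compiled_model = torch.compile(model)


# Decorator form: the function is compiled on its first call.
@torch.compile
def scaled_gelu(x):
    return torch.nn.functional.gelu(x) * 0.5
```

Calling <code>compiled_model(torch.randn(8, 64))</code> then compiles once for that input signature. Later calls with compatible inputs reuse the compiled code.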
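<p><strong>Debugging Support</strong>: The pipeline stays inspectable. The <code>TORCH_LOGS</code> environment variable reports graph breaks, recompilations, and generated code, <code>TORCH_COMPILE_DEBUG=1</code> dumps intermediate graphs and kernels to disk, and compiled code can be profiled with <code>torch.profiler</code> like any other PyTorch code.</p>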
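<p><strong>Backward Compatibility</strong>: Compilation is opt-in, so existing PyTorch code runs unchanged. When Dynamo meets Python it cannot trace, it inserts a graph break and runs that part eagerly instead of failing (unless <code>fullgraph=True</code> is set), so compiled code keeps eager-mode semantics.</p>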
</section>
</section>
<section id="complete-example-walkthrough" class="level2">
<h2 class="anchored" data-anchor-id="complete-example-walkthrough" id="complete-example-walkthrough">Complete Example Walkthrough</h2>
<section id="resnet-block-compilation" class="level3">
<h3 class="anchored" data-anchor-id="resnet-block-compilation" id="resnet-block-compilation">ResNet Block Compilation</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb10"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> ResNetBlock(nn.Module):</span>
<span id="cb10-6"><a href="#cb10-6" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, in_channels, out_channels, stride<span class="op">=</span><span class="dv">1</span>):</span>
<span id="cb10-7"><a href="#cb10-7" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb10-8"><a href="#cb10-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.conv1 <span class="op">=</span> nn.Conv2d(in_channels, out_channels, <span class="dv">3</span>, stride, <span class="dv">1</span>, bias<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb10-9"><a href="#cb10-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.bn1 <span class="op">=</span> nn.BatchNorm2d(out_channels)</span>
<span id="cb10-10"><a href="#cb10-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.conv2 <span class="op">=</span> nn.Conv2d(out_channels, out_channels, <span class="dv">3</span>, <span class="dv">1</span>, <span class="dv">1</span>, bias<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb10-11"><a href="#cb10-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.bn2 <span class="op">=</span> nn.BatchNorm2d(out_channels)</span>
<span id="cb10-12"><a href="#cb10-12" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb10-13"><a href="#cb10-13" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.shortcut <span class="op">=</span> nn.Sequential()</span>
<span id="cb10-14"><a href="#cb10-14" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> stride <span class="op">!=</span> <span class="dv">1</span> <span class="kw">or</span> in_channels <span class="op">!=</span> out_channels:</span>
<span id="cb10-15"><a href="#cb10-15" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.shortcut <span class="op">=</span> nn.Sequential(</span>
<span id="cb10-16"><a href="#cb10-16" aria-hidden="true" tabindex="-1"></a>                nn.Conv2d(in_channels, out_channels, <span class="dv">1</span>, stride, bias<span class="op">=</span><span class="va">False</span>),</span>
<span id="cb10-17"><a href="#cb10-17" aria-hidden="true" tabindex="-1"></a>                nn.BatchNorm2d(out_channels)</span>
<span id="cb10-18"><a href="#cb10-18" aria-hidden="true" tabindex="-1"></a>            )</span>
<span id="cb10-19"><a href="#cb10-19" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-20"><a href="#cb10-20" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb10-21"><a href="#cb10-21" aria-hidden="true" tabindex="-1"></a>        out <span class="op">=</span> torch.relu(<span class="va">self</span>.bn1(<span class="va">self</span>.conv1(x)))</span>
<span id="cb10-22"><a href="#cb10-22" aria-hidden="true" tabindex="-1"></a>        out <span class="op">=</span> <span class="va">self</span>.bn2(<span class="va">self</span>.conv2(out))</span>
<span id="cb10-23"><a href="#cb10-23" aria-hidden="true" tabindex="-1"></a>        out <span class="op">+=</span> <span class="va">self</span>.shortcut(x)</span>
<span id="cb10-24"><a href="#cb10-24" aria-hidden="true" tabindex="-1"></a>        out <span class="op">=</span> torch.relu(out)</span>
<span id="cb10-25"><a href="#cb10-25" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> out</span>
<span id="cb10-26"><a href="#cb10-26" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-27"><a href="#cb10-27" aria-hidden="true" tabindex="-1"></a><span class="co"># Create model</span></span>
<span id="cb10-28"><a href="#cb10-28" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> ResNetBlock(<span class="dv">64</span>, <span class="dv">64</span>).cuda()</span>
<span id="cb10-29"><a href="#cb10-29" aria-hidden="true" tabindex="-1"></a>model.<span class="bu">eval</span>()</span>
<span id="cb10-30"><a href="#cb10-30" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-31"><a href="#cb10-31" aria-hidden="true" tabindex="-1"></a><span class="co"># Compile with max-autotune: slower to compile, but searches for faster kernels</span></span>
<span id="cb10-32"><a href="#cb10-32" aria-hidden="true" tabindex="-1"></a>model_compiled <span class="op">=</span> torch.<span class="bu">compile</span>(model, mode<span class="op">=</span><span class="st">"max-autotune"</span>)</span>
<span id="cb10-33"><a href="#cb10-33" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-34"><a href="#cb10-34" aria-hidden="true" tabindex="-1"></a><span class="at">@torch.no_grad</span>()  <span class="co"># benchmark inference only: skip autograd bookkeeping</span></span>
<span id="cb10-35"><a href="#cb10-35" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> benchmark_model(model, input_tensor, num_runs<span class="op">=</span><span class="dv">100</span>):</span>
<span id="cb10-36"><a href="#cb10-36" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Warmup</span></span>
<span id="cb10-37"><a href="#cb10-37" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">10</span>):</span>
<span id="cb10-38"><a href="#cb10-38" aria-hidden="true" tabindex="-1"></a>        _ <span class="op">=</span> model(input_tensor)</span>
<span id="cb10-39"><a href="#cb10-39" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-40"><a href="#cb10-40" aria-hidden="true" tabindex="-1"></a>    torch.cuda.synchronize()</span>
<span id="cb10-41"><a href="#cb10-41" aria-hidden="true" tabindex="-1"></a>    start_time <span class="op">=</span> time.perf_counter()  <span class="co"># monotonic, high-resolution clock</span></span>
<span id="cb10-42"><a href="#cb10-42" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-43"><a href="#cb10-43" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(num_runs):</span>
<span id="cb10-44"><a href="#cb10-44" aria-hidden="true" tabindex="-1"></a>        _ <span class="op">=</span> model(input_tensor)</span>
<span id="cb10-45"><a href="#cb10-45" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-46"><a href="#cb10-46" aria-hidden="true" tabindex="-1"></a>    torch.cuda.synchronize()  <span class="co"># wait for all queued kernels before stopping the clock</span></span>
<span id="cb10-47"><a href="#cb10-47" aria-hidden="true" tabindex="-1"></a>    end_time <span class="op">=</span> time.perf_counter()</span>
<span id="cb10-48"><a href="#cb10-48" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-49"><a href="#cb10-49" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> (end_time <span class="op">-</span> start_time) <span class="op">/</span> num_runs</span>
<span id="cb10-50"><a href="#cb10-50" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-51"><a href="#cb10-51" aria-hidden="true" tabindex="-1"></a><span class="co"># Test input</span></span>
<span id="cb10-52"><a href="#cb10-52" aria-hidden="true" tabindex="-1"></a>x <span class="op">=</span> torch.randn(<span class="dv">32</span>, <span class="dv">64</span>, <span class="dv">56</span>, <span class="dv">56</span>, device<span class="op">=</span><span class="st">'cuda'</span>)</span>
<span id="cb10-53"><a href="#cb10-53" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-54"><a href="#cb10-54" aria-hidden="true" tabindex="-1"></a><span class="co"># Benchmark both versions</span></span>
<span id="cb10-55"><a href="#cb10-55" aria-hidden="true" tabindex="-1"></a>eager_time <span class="op">=</span> benchmark_model(model, x)</span>
<span id="cb10-56"><a href="#cb10-56" aria-hidden="true" tabindex="-1"></a>compiled_time <span class="op">=</span> benchmark_model(model_compiled, x)</span>
<span id="cb10-57"><a href="#cb10-57" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-58"><a href="#cb10-58" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Eager mode: </span><span class="sc">{</span>eager_time<span class="op">*</span><span class="dv">1000</span><span class="sc">:.2f}</span><span class="ss">ms"</span>)</span>
<span id="cb10-59"><a href="#cb10-59" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Compiled mode: </span><span class="sc">{</span>compiled_time<span class="op">*</span><span class="dv">1000</span><span class="sc">:.2f}</span><span class="ss">ms"</span>)</span>
<span id="cb10-60"><a href="#cb10-60" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Speedup: </span><span class="sc">{</span>eager_time<span class="op">/</span>compiled_time<span class="sc">:.2f}</span><span class="ss">x"</span>)</span></code></pre></div></div>
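<p>Hand-rolled timing loops are easy to get subtly wrong, for example by missing a synchronization or by folding compilation time into the measurement. <code>torch.utils.benchmark.Timer</code> synchronizes CUDA and reports statistics over many runs. The following is a minimal sketch of a reusable helper. The name <code>measure</code> is hypothetical, and the explicit first call keeps compilation out of the timed region.</p>

```python
import torch
import torch.utils.benchmark as benchmark


@torch.no_grad()
def measure(fn, x, label: str, min_run_time: float = 1.0) -> float:
    """Hypothetical helper: mean seconds per call of fn(x), measured by torch.utils.benchmark."""
    fn(x)  # first call triggers compilation for compiled modules; keep it out of the timing
    timer = benchmark.Timer(stmt="fn(x)", globals={"fn": fn, "x": x}, label=label)
    return timer.blocked_autorange(min_run_time=min_run_time).mean
```

With the ResNet block above, <code>measure(model, x, "eager")</code> and <code>measure(model_compiled, x, "compiled")</code> give directly comparable numbers. For CPU workloads, note that <code>Timer</code> runs with a single thread unless <code>num_threads</code> is passed.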
</section>
<section id="attention-mechanism-optimization" class="level3">
<h3 class="anchored" data-anchor-id="attention-mechanism-optimization" id="attention-mechanism-optimization">Attention Mechanism Optimization</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb11"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> math</span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn.functional <span class="im">as</span> F</span>
<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MultiHeadAttention(nn.Module):</span>
<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, d_model, num_heads):</span>
<span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.d_model <span class="op">=</span> d_model</span>
<span id="cb11-9"><a href="#cb11-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.num_heads <span class="op">=</span> num_heads</span>
<span id="cb11-10"><a href="#cb11-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.d_k <span class="op">=</span> d_model <span class="op">//</span> num_heads</span>
<span id="cb11-11"><a href="#cb11-11" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-12"><a href="#cb11-12" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.W_q <span class="op">=</span> nn.Linear(d_model, d_model)</span>
<span id="cb11-13"><a href="#cb11-13" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.W_k <span class="op">=</span> nn.Linear(d_model, d_model)</span>
<span id="cb11-14"><a href="#cb11-14" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.W_v <span class="op">=</span> nn.Linear(d_model, d_model)</span>
<span id="cb11-15"><a href="#cb11-15" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.W_o <span class="op">=</span> nn.Linear(d_model, d_model)</span>
<span id="cb11-16"><a href="#cb11-16" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-17"><a href="#cb11-17" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, query, key, value, mask<span class="op">=</span><span class="va">None</span>):</span>
<span id="cb11-18"><a href="#cb11-18" aria-hidden="true" tabindex="-1"></a>        batch_size <span class="op">=</span> query.size(<span class="dv">0</span>)</span>
<span id="cb11-19"><a href="#cb11-19" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-20"><a href="#cb11-20" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Linear projections</span></span>
<span id="cb11-21"><a href="#cb11-21" aria-hidden="true" tabindex="-1"></a>        Q <span class="op">=</span> <span class="va">self</span>.W_q(query).view(batch_size, <span class="op">-</span><span class="dv">1</span>, <span class="va">self</span>.num_heads, <span class="va">self</span>.d_k).transpose(<span class="dv">1</span>, <span class="dv">2</span>)</span>
<span id="cb11-22"><a href="#cb11-22" aria-hidden="true" tabindex="-1"></a>        K <span class="op">=</span> <span class="va">self</span>.W_k(key).view(batch_size, <span class="op">-</span><span class="dv">1</span>, <span class="va">self</span>.num_heads, <span class="va">self</span>.d_k).transpose(<span class="dv">1</span>, <span class="dv">2</span>)</span>
<span id="cb11-23"><a href="#cb11-23" aria-hidden="true" tabindex="-1"></a>        V <span class="op">=</span> <span class="va">self</span>.W_v(value).view(batch_size, <span class="op">-</span><span class="dv">1</span>, <span class="va">self</span>.num_heads, <span class="va">self</span>.d_k).transpose(<span class="dv">1</span>, <span class="dv">2</span>)</span>
<span id="cb11-24"><a href="#cb11-24" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-25"><a href="#cb11-25" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Scaled dot-product attention</span></span>
<span id="cb11-26"><a href="#cb11-26" aria-hidden="true" tabindex="-1"></a>        scores <span class="op">=</span> torch.matmul(Q, K.transpose(<span class="op">-</span><span class="dv">2</span>, <span class="op">-</span><span class="dv">1</span>)) <span class="op">/</span> math.sqrt(<span class="va">self</span>.d_k)</span>
<span id="cb11-27"><a href="#cb11-27" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-28"><a href="#cb11-28" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> mask <span class="kw">is</span> <span class="kw">not</span> <span class="va">None</span>:</span>
<span id="cb11-29"><a href="#cb11-29" aria-hidden="true" tabindex="-1"></a>            scores <span class="op">=</span> scores.masked_fill(mask <span class="op">==</span> <span class="dv">0</span>, <span class="op">-</span><span class="fl">1e9</span>)</span>
<span id="cb11-30"><a href="#cb11-30" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-31"><a href="#cb11-31" aria-hidden="true" tabindex="-1"></a>        attention_weights <span class="op">=</span> F.softmax(scores, dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb11-32"><a href="#cb11-32" aria-hidden="true" tabindex="-1"></a>        attention_output <span class="op">=</span> torch.matmul(attention_weights, V)</span>
<span id="cb11-33"><a href="#cb11-33" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-34"><a href="#cb11-34" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Concatenate heads</span></span>
<span id="cb11-35"><a href="#cb11-35" aria-hidden="true" tabindex="-1"></a>        attention_output <span class="op">=</span> attention_output.transpose(<span class="dv">1</span>, <span class="dv">2</span>).contiguous().view(</span>
<span id="cb11-36"><a href="#cb11-36" aria-hidden="true" tabindex="-1"></a>            batch_size, <span class="op">-</span><span class="dv">1</span>, <span class="va">self</span>.d_model</span>
<span id="cb11-37"><a href="#cb11-37" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb11-38"><a href="#cb11-38" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-39"><a href="#cb11-39" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.W_o(attention_output)</span>
<span id="cb11-40"><a href="#cb11-40" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-41"><a href="#cb11-41" aria-hidden="true" tabindex="-1"></a><span class="co"># Compile attention</span></span>
<span id="cb11-42"><a href="#cb11-42" aria-hidden="true" tabindex="-1"></a>attention <span class="op">=</span> MultiHeadAttention(<span class="dv">512</span>, <span class="dv">8</span>).cuda()</span>
<span id="cb11-43"><a href="#cb11-43" aria-hidden="true" tabindex="-1"></a>compiled_attention <span class="op">=</span> torch.<span class="bu">compile</span>(attention, mode<span class="op">=</span><span class="st">"max-autotune"</span>)</span>
<span id="cb11-44"><a href="#cb11-44" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-45"><a href="#cb11-45" aria-hidden="true" tabindex="-1"></a><span class="co"># Test with transformer-like input</span></span>
<span id="cb11-46"><a href="#cb11-46" aria-hidden="true" tabindex="-1"></a>seq_len, batch_size, d_model <span class="op">=</span> <span class="dv">1024</span>, <span class="dv">32</span>, <span class="dv">512</span></span>
<span id="cb11-47"><a href="#cb11-47" aria-hidden="true" tabindex="-1"></a>x <span class="op">=</span> torch.randn(batch_size, seq_len, d_model, device<span class="op">=</span><span class="st">'cuda'</span>)</span>
<span id="cb11-48"><a href="#cb11-48" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-49"><a href="#cb11-49" aria-hidden="true" tabindex="-1"></a><span class="co"># Inductor can fuse the scaling and softmax around the matmuls; the exact kernels depend on the PyTorch version and GPU</span></span>
<span id="cb11-50"><a href="#cb11-50" aria-hidden="true" tabindex="-1"></a>output <span class="op">=</span> compiled_attention(x, x, x)</span></code></pre></div></div>
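<p>PyTorch 2.0 and later also ship a fused attention primitive, <code>torch.nn.functional.scaled_dot_product_attention</code>. On supported GPUs it can dispatch to FlashAttention-style or memory-efficient kernels. The small sketch below checks that it matches the manual computation in <code>MultiHeadAttention.forward</code> when no mask is given. A mask in the class's convention would be passed as <code>attn_mask=mask.bool()</code>, because a boolean <code>attn_mask</code> marks with <code>True</code> the positions that are allowed to attend.</p>

```python
import math

import torch
import torch.nn.functional as F


def manual_attention(q, k, v):
    """The same computation as MultiHeadAttention.forward, without masking."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    return F.softmax(scores, dim=-1) @ v


def fused_attention(q, k, v):
    # Default scale is 1 / sqrt(q.size(-1)), matching the manual version.
    return F.scaled_dot_product_attention(q, k, v)


# Shapes follow the class: (batch, heads, seq_len, d_k)
q, k, v = (torch.randn(2, 8, 128, 64) for _ in range(3))
print(torch.allclose(manual_attention(q, k, v), fused_attention(q, k, v), atol=1e-4))
```

Inside <code>forward</code>, the lines from <code>scores</code> through <code>attention_output</code> could be replaced by this single call.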
</section>
</section>
<section id="advanced-optimization-techniques" class="level2">
<h2 class="anchored" data-anchor-id="advanced-optimization-techniques" id="advanced-optimization-techniques">Advanced Optimization Techniques</h2>
<section id="custom-fusion-patterns" class="level3">
<h3 class="anchored" data-anchor-id="custom-fusion-patterns" id="custom-fusion-patterns">Custom Fusion Patterns</h3>
<p>TorchInductor automatically fuses common patterns, such as chains of pointwise operations, so everyday fusions need no custom code. Registering new rewrite rules is possible, but it relies on private, version-dependent internals:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb12"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><a href="#cb12-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb12-2"><a href="#cb12-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-3"><a href="#cb12-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Inductor already fuses add followed by ReLU into one generated kernel, so no</span></span>
<span id="cb12-4"><a href="#cb12-4" aria-hidden="true" tabindex="-1"></a><span class="co"># registration is needed here. Custom rewrite rules use the private module</span></span>
<span id="cb12-5"><a href="#cb12-5" aria-hidden="true" tabindex="-1"></a><span class="co"># torch._inductor.pattern_matcher (e.g. register_replacement), whose API</span></span>
<span id="cb12-6"><a href="#cb12-6" aria-hidden="true" tabindex="-1"></a><span class="co"># changes between releases.</span></span>
<span id="cb12-7"><a href="#cb12-7" aria-hidden="true" tabindex="-1"></a><span class="at">@torch.compile</span></span>
<span id="cb12-8"><a href="#cb12-8" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> fused_add_relu(x, y):</span>
<span id="cb12-9"><a href="#cb12-9" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> torch.relu(x <span class="op">+</span> y)  <span class="co"># one kernel: read x and y, write the result once</span></span>
<span id="cb12-10"><a href="#cb12-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-11"><a href="#cb12-11" aria-hidden="true" tabindex="-1"></a><span class="co"># Memory optimization</span></span>
<span id="cb12-12"><a href="#cb12-12" aria-hidden="true" tabindex="-1"></a><span class="at">@torch.compile</span></span>
<span id="cb12-13"><a href="#cb12-13" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> memory_efficient_function(x):</span>
<span id="cb12-14"><a href="#cb12-14" aria-hidden="true" tabindex="-1"></a>    <span class="co"># In-place ops are functionalized during tracing: Inductor decides buffer reuse</span></span>
<span id="cb12-15"><a href="#cb12-15" aria-hidden="true" tabindex="-1"></a>    <span class="co"># itself, and the mutation of x is still applied to the caller's tensor.</span></span>
<span id="cb12-16"><a href="#cb12-16" aria-hidden="true" tabindex="-1"></a>    x <span class="op">=</span> x.add_(<span class="fl">1.0</span>)  <span class="co"># In-place addition</span></span>
<span id="cb12-17"><a href="#cb12-17" aria-hidden="true" tabindex="-1"></a>    x <span class="op">=</span> x.mul_(<span class="fl">2.0</span>)  <span class="co"># In-place multiplication</span></span>
<span id="cb12-18"><a href="#cb12-18" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> x</span></code></pre></div></div>
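<p>To see the operator sequence that fusion works on, a custom backend can be passed to <code>torch.compile</code>. The backend receives the captured <code>torch.fx.GraphModule</code> and must return a callable. The minimal sketch below records the graph Dynamo captures for <code>relu(x + y)</code> and then runs it unchanged. The backend and list names are hypothetical. Inductor lowers such a graph further before it fuses anything.</p>

```python
import torch

captured_graphs = []


def inspecting_backend(gm: torch.fx.GraphModule, example_inputs):
    """Hypothetical backend: keep the FX graph for inspection, run it without optimization."""
    captured_graphs.append(gm)
    print(gm.code)  # Python source of the captured graph
    return gm.forward  # must return a callable with the same calling convention


@torch.compile(backend=inspecting_backend)
def add_relu(x, y):
    return torch.relu(x + y)


out = add_relu(torch.randn(4), torch.randn(4))
graph = captured_graphs[0].graph
print([node.target for node in graph.nodes if node.op == "call_function"])
```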
</section>
<section id="dynamic-shape-handling" class="level3">
<h3 class="anchored" data-anchor-id="dynamic-shape-handling" id="dynamic-shape-handling">Dynamic Shape Handling</h3>
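<p>By default, <code>torch.compile</code> specializes the compiled graph on the input shapes it first sees. When a later call changes a dimension, it recompiles once with that dimension marked as dynamic (symbolic), so subsequent sizes can usually reuse the same compiled code. Passing <code>dynamic=True</code> treats dimensions as dynamic from the first call, while <code>dynamic=False</code> recompiles for every new shape.</p>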
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb13"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Handling dynamic shapes</span></span>
<span id="cb13-4"><a href="#cb13-4" aria-hidden="true" tabindex="-1"></a><span class="at">@torch.compile</span>(dynamic<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb13-5"><a href="#cb13-5" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> dynamic_function(x):</span>
<span id="cb13-6"><a href="#cb13-6" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Dimensions are symbolic, so one compiled graph serves all the shapes below</span></span>
<span id="cb13-7"><a href="#cb13-7" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> x.<span class="bu">sum</span>(dim<span class="op">=-</span><span class="dv">1</span>, keepdim<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb13-8"><a href="#cb13-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-9"><a href="#cb13-9" aria-hidden="true" tabindex="-1"></a><span class="co"># Test with different shapes</span></span>
<span id="cb13-10"><a href="#cb13-10" aria-hidden="true" tabindex="-1"></a>shapes <span class="op">=</span> [(<span class="dv">100</span>, <span class="dv">50</span>), (<span class="dv">200</span>, <span class="dv">30</span>), (<span class="dv">150</span>, <span class="dv">80</span>)]</span>
<span id="cb13-11"><a href="#cb13-11" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> shape <span class="kw">in</span> shapes:</span>
<span id="cb13-12"><a href="#cb13-12" aria-hidden="true" tabindex="-1"></a>    x <span class="op">=</span> torch.randn(<span class="op">*</span>shape, device<span class="op">=</span><span class="st">'cuda'</span>)</span>
<span id="cb13-13"><a href="#cb13-13" aria-hidden="true" tabindex="-1"></a>    result <span class="op">=</span> dynamic_function(x)</span>
<span id="cb13-14"><a href="#cb13-14" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Shape </span><span class="sc">{</span>shape<span class="sc">}</span><span class="ss"> -&gt; </span><span class="sc">{</span>result<span class="sc">.</span>shape<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
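<p>The difference between the three <code>dynamic</code> settings shows up in the number of compilations. The sketch below uses a hypothetical counting backend on CPU tensors. On recent releases one would expect three compilations with <code>dynamic=False</code> and one with <code>dynamic=True</code>. The default should give two: a static graph first, then one dynamic recompile. The exact policy has changed across versions, though.</p>

```python
import torch
import torch._dynamo


def count_compiles(dynamic, shapes):
    """Hypothetical helper: how many graphs torch.compile builds for a sequence of input shapes."""
    torch._dynamo.reset()  # start from an empty compile cache
    compiles = 0

    def counting_backend(gm, example_inputs):
        nonlocal compiles
        compiles += 1
        return gm.forward  # run the captured graph without further optimization

    def row_sum(x):
        return x.sum(dim=-1, keepdim=True)

    fn = torch.compile(row_sum, backend=counting_backend, dynamic=dynamic)
    for shape in shapes:
        fn(torch.randn(*shape))
    return compiles


shapes = [(100, 50), (200, 30), (150, 80)]
for dynamic in (False, None, True):
    print(f"dynamic={dynamic}: {count_compiles(dynamic, shapes)} compile(s)")
```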
</section>
<section id="reduce-overhead-mode" class="level3">
<h3 class="anchored" data-anchor-id="reduce-overhead-mode" id="reduce-overhead-mode">Reduce Overhead Mode</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb14"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb14-1"><a href="#cb14-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb14-2"><a href="#cb14-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch._dynamo <span class="im">as</span> dynamo</span>
<span id="cb14-3"><a href="#cb14-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-4"><a href="#cb14-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Optional Dynamo settings (independent of the compile mode below)</span></span>
<span id="cb14-5"><a href="#cb14-5" aria-hidden="true" tabindex="-1"></a>dynamo.config.suppress_errors <span class="op">=</span> <span class="va">True</span>  <span class="co"># fall back to eager instead of raising on compile errors</span></span>
<span id="cb14-6"><a href="#cb14-6" aria-hidden="true" tabindex="-1"></a>dynamo.config.cache_size_limit <span class="op">=</span> <span class="dv">1000</span>  <span class="co"># allow more recompilations per function</span></span>
<span id="cb14-7"><a href="#cb14-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-8"><a href="#cb14-8" aria-hidden="true" tabindex="-1"></a><span class="at">@torch.compile</span>(mode<span class="op">=</span><span class="st">"reduce-overhead"</span>)</span>
<span id="cb14-9"><a href="#cb14-9" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> low_overhead_function(x):</span>
<span id="cb14-10"><a href="#cb14-10" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Captures the generated kernels in CUDA graphs and replays them, cutting</span></span>
<span id="cb14-11"><a href="#cb14-11" aria-hidden="true" tabindex="-1"></a>    <span class="co"># per-call CPU launch overhead at the cost of extra GPU memory</span></span>
<span id="cb14-12"><a href="#cb14-12" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> x.relu().<span class="bu">sum</span>()</span>
<span id="cb14-13"><a href="#cb14-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-14"><a href="#cb14-14" aria-hidden="true" tabindex="-1"></a><span class="co"># Best suited to small, frequently called GPU workloads with static input shapes</span></span></code></pre></div></div>
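<p>The core of what <code>reduce-overhead</code> automates can be reproduced by hand with PyTorch's CUDA Graphs API. The sketch follows the capture-then-replay pattern from the CUDA semantics documentation: record the kernels once on static buffers, then refill the input buffer and replay. It assumes a CUDA device, and the function and tensor names are illustrative.</p>

```python
import torch


def relu_sum(x):
    return x.relu().sum()


if torch.cuda.is_available():
    static_x = torch.randn(1 << 20, device="cuda")

    # Warm up on a side stream before capture, as the CUDA Graphs docs recommend.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        for _ in range(3):
            relu_sum(static_x)
    torch.cuda.current_stream().wait_stream(side)

    # Capture: kernels are recorded into the graph, not executed immediately.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_out = relu_sum(static_x)

    # Replay: copy new data into the captured input buffer, then relaunch all kernels at once.
    new_x = torch.randn(1 << 20, device="cuda")
    static_x.copy_(new_x)
    graph.replay()
    print(torch.allclose(static_out, relu_sum(new_x), rtol=1e-4))
```

The replay reuses the captured memory addresses. This is why inputs must be copied into <code>static_x</code> rather than passed as new tensors, and why static shapes suit this mode best.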
</section>
</section>
<section id="debugging-and-profiling" class="level2">
<h2 class="anchored" data-anchor-id="debugging-and-profiling" id="debugging-and-profiling">Debugging and Profiling</h2>
<section id="compilation-debugging" class="level3">
<h3 class="anchored" data-anchor-id="compilation-debugging" id="compilation-debugging">Compilation Debugging</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb15"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb15-1"><a href="#cb15-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb15-2"><a href="#cb15-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch._inductor.config <span class="im">as</span> config</span>
<span id="cb15-3"><a href="#cb15-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-4"><a href="#cb15-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Enable Inductor debug output</span></span>
<span id="cb15-5"><a href="#cb15-5" aria-hidden="true" tabindex="-1"></a>config.debug <span class="op">=</span> <span class="va">True</span></span>
<span id="cb15-6"><a href="#cb15-6" aria-hidden="true" tabindex="-1"></a>config.trace.enabled <span class="op">=</span> <span class="va">True</span>  <span class="co"># write intermediate graphs and generated code to disk</span></span>
<span id="cb15-7"><a href="#cb15-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-8"><a href="#cb15-8" aria-hidden="true" tabindex="-1"></a><span class="co"># Alternatively, set environment variables in the shell before starting Python</span></span>
<span id="cb15-9"><a href="#cb15-9" aria-hidden="true" tabindex="-1"></a><span class="co"># export TORCH_COMPILE_DEBUG=1     # dump graphs and kernels under ./torch_compile_debug</span></span>
<span id="cb15-10"><a href="#cb15-10" aria-hidden="true" tabindex="-1"></a><span class="co"># export TORCH_LOGS="output_code"  # print the generated Triton/C++ code</span></span>
<span id="cb15-11"><a href="#cb15-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-12"><a href="#cb15-12" aria-hidden="true" tabindex="-1"></a><span class="at">@torch.compile</span></span>
<span id="cb15-13"><a href="#cb15-13" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> debug_function(x):</span>
<span id="cb15-14"><a href="#cb15-14" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> torch.sin(x).<span class="bu">sum</span>()</span>
<span id="cb15-15"><a href="#cb15-15" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-16"><a href="#cb15-16" aria-hidden="true" tabindex="-1"></a><span class="co"># The first call triggers compilation and emits the debug output</span></span>
<span id="cb15-17"><a href="#cb15-17" aria-hidden="true" tabindex="-1"></a>x <span class="op">=</span> torch.randn(<span class="dv">1000</span>, device<span class="op">=</span><span class="st">'cuda'</span>)</span>
<span id="cb15-18"><a href="#cb15-18" aria-hidden="true" tabindex="-1"></a>result <span class="op">=</span> debug_function(x)</span></code></pre></div></div>
</section>
<section id="performance-profiling" class="level3">
<h3 class="anchored" data-anchor-id="performance-profiling" id="performance-profiling">Performance Profiling</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb16"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb16-1"><a href="#cb16-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.profiler</span>
<span id="cb16-2"><a href="#cb16-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-3"><a href="#cb16-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> profile_compilation():</span>
<span id="cb16-4"><a href="#cb16-4" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> torch.nn.Linear(<span class="dv">1000</span>, <span class="dv">1000</span>).cuda()</span>
<span id="cb16-5"><a href="#cb16-5" aria-hidden="true" tabindex="-1"></a>    compiled_model <span class="op">=</span> torch.<span class="bu">compile</span>(model)</span>
<span id="cb16-6"><a href="#cb16-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-7"><a href="#cb16-7" aria-hidden="true" tabindex="-1"></a>    x <span class="op">=</span> torch.randn(<span class="dv">1000</span>, <span class="dv">1000</span>, device<span class="op">=</span><span class="st">'cuda'</span>)</span>
<span id="cb16-8"><a href="#cb16-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-9"><a href="#cb16-9" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Warm up outside the profiler so the one-time compilation cost</span></span>
<span id="cb16-10"><a href="#cb16-10" aria-hidden="true" tabindex="-1"></a>    <span class="co"># does not dominate the measured kernel times</span></span>
<span id="cb16-11"><a href="#cb16-11" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">10</span>):</span>
<span id="cb16-12"><a href="#cb16-12" aria-hidden="true" tabindex="-1"></a>        _ <span class="op">=</span> compiled_model(x)</span>
<span id="cb16-13"><a href="#cb16-13" aria-hidden="true" tabindex="-1"></a>    torch.cuda.synchronize()</span>
<span id="cb16-14"><a href="#cb16-14" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-15"><a href="#cb16-15" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> torch.profiler.profile(</span>
<span id="cb16-16"><a href="#cb16-16" aria-hidden="true" tabindex="-1"></a>        activities<span class="op">=</span>[</span>
<span id="cb16-17"><a href="#cb16-17" aria-hidden="true" tabindex="-1"></a>            torch.profiler.ProfilerActivity.CPU,</span>
<span id="cb16-18"><a href="#cb16-18" aria-hidden="true" tabindex="-1"></a>            torch.profiler.ProfilerActivity.CUDA,</span>
<span id="cb16-19"><a href="#cb16-19" aria-hidden="true" tabindex="-1"></a>        ],</span>
<span id="cb16-20"><a href="#cb16-20" aria-hidden="true" tabindex="-1"></a>        record_shapes<span class="op">=</span><span class="va">True</span>,</span>
<span id="cb16-21"><a href="#cb16-21" aria-hidden="true" tabindex="-1"></a>        with_stack<span class="op">=</span><span class="va">True</span>,</span>
<span id="cb16-22"><a href="#cb16-22" aria-hidden="true" tabindex="-1"></a>    ) <span class="im">as</span> prof:</span>
<span id="cb16-23"><a href="#cb16-23" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Profile steady-state execution of the compiled kernels</span></span>
<span id="cb16-24"><a href="#cb16-24" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">100</span>):</span>
<span id="cb16-25"><a href="#cb16-25" aria-hidden="true" tabindex="-1"></a>            _ <span class="op">=</span> compiled_model(x)</span>
<span id="cb16-26"><a href="#cb16-26" aria-hidden="true" tabindex="-1"></a>        torch.cuda.synchronize()  <span class="co"># wait for queued GPU work before profiling stops</span></span>
<span id="cb16-27"><a href="#cb16-27" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-28"><a href="#cb16-28" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(prof.key_averages().table(sort_by<span class="op">=</span><span class="st">"cuda_time_total"</span>, row_limit<span class="op">=</span><span class="dv">10</span>))</span>
<span id="cb16-29"><a href="#cb16-29" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-30"><a href="#cb16-30" aria-hidden="true" tabindex="-1"></a>profile_compilation()</span></code></pre></div></div>
</section>
<section id="inspecting-generated-code" class="level3">
<h3 class="anchored" data-anchor-id="inspecting-generated-code" id="inspecting-generated-code">Inspecting Generated Code</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb17"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb17-1"><a href="#cb17-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb17-2"><a href="#cb17-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-3"><a href="#cb17-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Log the Triton/C++ source that Inductor generates</span></span>
<span id="cb17-4"><a href="#cb17-4" aria-hidden="true" tabindex="-1"></a><span class="co"># (equivalent to running with the environment variable TORCH_LOGS="output_code")</span></span>
<span id="cb17-5"><a href="#cb17-5" aria-hidden="true" tabindex="-1"></a>torch._logging.set_logs(output_code<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb17-6"><a href="#cb17-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-7"><a href="#cb17-7" aria-hidden="true" tabindex="-1"></a><span class="at">@torch.compile</span>(mode<span class="op">=</span><span class="st">"max-autotune"</span>)</span>
<span id="cb17-8"><a href="#cb17-8" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> inspectable_function(x, y):</span>
<span id="cb17-9"><a href="#cb17-9" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> torch.matmul(x, y) <span class="op">+</span> torch.sin(x)</span>
<span id="cb17-10"><a href="#cb17-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-11"><a href="#cb17-11" aria-hidden="true" tabindex="-1"></a><span class="co"># Compilation, and therefore the code logging, happens on the first call</span></span>
<span id="cb17-12"><a href="#cb17-12" aria-hidden="true" tabindex="-1"></a>x <span class="op">=</span> torch.randn(<span class="dv">1000</span>, <span class="dv">1000</span>, device<span class="op">=</span><span class="st">'cuda'</span>)</span>
<span id="cb17-13"><a href="#cb17-13" aria-hidden="true" tabindex="-1"></a>y <span class="op">=</span> torch.randn(<span class="dv">1000</span>, <span class="dv">1000</span>, device<span class="op">=</span><span class="st">'cuda'</span>)</span>
<span id="cb17-14"><a href="#cb17-14" aria-hidden="true" tabindex="-1"></a>result <span class="op">=</span> inspectable_function(x, y)</span></code></pre></div></div>
</section>
</section>
<section id="best-practices" class="level2">
<h2 class="anchored" data-anchor-id="best-practices" id="best-practices">Best Practices</h2>
<section id="model-preparation" class="level3">
<h3 class="anchored" data-anchor-id="model-preparation" id="model-preparation">1. Model Preparation</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb18"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb18-1"><a href="#cb18-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb18-2"><a href="#cb18-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb18-3"><a href="#cb18-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-4"><a href="#cb18-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Prepare your model for compilation</span></span>
<span id="cb18-5"><a href="#cb18-5" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> prepare_model_for_compilation(model, device<span class="op">=</span><span class="va">None</span>):</span>
<span id="cb18-6"><a href="#cb18-6" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Best practices for preparing a model for inference compilation"""</span></span>
<span id="cb18-7"><a href="#cb18-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-8"><a href="#cb18-8" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Set to eval mode for inference. This also puts every BatchNorm and</span></span>
<span id="cb18-9"><a href="#cb18-9" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Dropout submodule into inference behavior, so no separate loop is needed.</span></span>
<span id="cb18-10"><a href="#cb18-10" aria-hidden="true" tabindex="-1"></a>    model.<span class="bu">eval</span>()</span>
<span id="cb18-11"><a href="#cb18-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-12"><a href="#cb18-12" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Move to the target device before compiling</span></span>
<span id="cb18-13"><a href="#cb18-13" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> device <span class="kw">is</span> <span class="va">None</span>:</span>
<span id="cb18-14"><a href="#cb18-14" aria-hidden="true" tabindex="-1"></a>        device <span class="op">=</span> <span class="st">"cuda"</span> <span class="cf">if</span> torch.cuda.is_available() <span class="cf">else</span> <span class="st">"cpu"</span></span>
<span id="cb18-15"><a href="#cb18-15" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> model.to(device)</span>
<span id="cb18-16"><a href="#cb18-16" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-17"><a href="#cb18-17" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> model</span>
<span id="cb18-18"><a href="#cb18-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-19"><a href="#cb18-19" aria-hidden="true" tabindex="-1"></a><span class="co"># Example model; any nn.Module works here</span></span>
<span id="cb18-20"><a href="#cb18-20" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> nn.Sequential(nn.Conv2d(<span class="dv">3</span>, <span class="dv">16</span>, <span class="dv">3</span>), nn.BatchNorm2d(<span class="dv">16</span>), nn.ReLU())</span>
<span id="cb18-21"><a href="#cb18-21" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-22"><a href="#cb18-22" aria-hidden="true" tabindex="-1"></a><span class="co"># Compile with appropriate settings</span></span>
<span id="cb18-23"><a href="#cb18-23" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> prepare_model_for_compilation(model)</span>
<span id="cb18-24"><a href="#cb18-24" aria-hidden="true" tabindex="-1"></a>compiled_model <span class="op">=</span> torch.<span class="bu">compile</span>(model, mode<span class="op">=</span><span class="st">"max-autotune"</span>)</span></code></pre></div></div>
</section>
<section id="effective-warmup" class="level3">
<h3 class="anchored" data-anchor-id="effective-warmup" id="effective-warmup">2. Effective Warmup</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb19"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb19-1"><a href="#cb19-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> warmup_compiled_model(compiled_model, example_inputs, num_warmup<span class="op">=</span><span class="dv">10</span>):</span>
<span id="cb19-2"><a href="#cb19-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Proper warmup for compiled models"""</span></span>
<span id="cb19-3"><a href="#cb19-3" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb19-4"><a href="#cb19-4" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Warmup runs</span></span>
<span id="cb19-5"><a href="#cb19-5" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(num_warmup):</span>
<span id="cb19-6"><a href="#cb19-6" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> torch.no_grad():</span>
<span id="cb19-7"><a href="#cb19-7" aria-hidden="true" tabindex="-1"></a>            _ <span class="op">=</span> compiled_model(<span class="op">*</span>example_inputs)</span>
<span id="cb19-8"><a href="#cb19-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb19-9"><a href="#cb19-9" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Ensure GPU synchronization</span></span>
<span id="cb19-10"><a href="#cb19-10" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> torch.cuda.is_available():</span>
<span id="cb19-11"><a href="#cb19-11" aria-hidden="true" tabindex="-1"></a>        torch.cuda.synchronize()</span></code></pre></div></div>
</section>
<section id="memory-management" class="level3">
<h3 class="anchored" data-anchor-id="memory-management" id="memory-management">3. Memory Management</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb20"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb20-1"><a href="#cb20-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb20-2"><a href="#cb20-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb20-3"><a href="#cb20-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Compile the model (e.g. compiled_model = torch.compile(model)) and keep the step</span></span>
<span id="cb20-4"><a href="#cb20-4" aria-hidden="true" tabindex="-1"></a><span class="co"># itself eager: loss.backward() and .item() would cause graph breaks anyway</span></span>
<span id="cb20-5"><a href="#cb20-5" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> memory_efficient_training_step(compiled_model, optimizer, x, y, loss_fn):</span>
<span id="cb20-6"><a href="#cb20-6" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Memory-efficient training step"""</span></span>
<span id="cb20-7"><a href="#cb20-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb20-8"><a href="#cb20-8" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Forward pass in mixed precision; bfloat16 needs no GradScaler but assumes</span></span>
<span id="cb20-9"><a href="#cb20-9" aria-hidden="true" tabindex="-1"></a>    <span class="co"># a GPU that supports it (with float16, add a GradScaler)</span></span>
<span id="cb20-10"><a href="#cb20-10" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> torch.autocast(device_type<span class="op">=</span><span class="st">"cuda"</span>, dtype<span class="op">=</span>torch.bfloat16):</span>
<span id="cb20-11"><a href="#cb20-11" aria-hidden="true" tabindex="-1"></a>        output <span class="op">=</span> compiled_model(x)</span>
<span id="cb20-12"><a href="#cb20-12" aria-hidden="true" tabindex="-1"></a>        loss <span class="op">=</span> loss_fn(output, y)</span>
<span id="cb20-13"><a href="#cb20-13" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb20-14"><a href="#cb20-14" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Backward pass</span></span>
<span id="cb20-15"><a href="#cb20-15" aria-hidden="true" tabindex="-1"></a>    optimizer.zero_grad(set_to_none<span class="op">=</span><span class="va">True</span>)  <span class="co"># frees gradients instead of zero-filling them</span></span>
<span id="cb20-16"><a href="#cb20-16" aria-hidden="true" tabindex="-1"></a>    loss.backward()</span>
<span id="cb20-17"><a href="#cb20-17" aria-hidden="true" tabindex="-1"></a>    optimizer.step()</span>
<span id="cb20-18"><a href="#cb20-18" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb20-19"><a href="#cb20-19" aria-hidden="true" tabindex="-1"></a>    <span class="co"># .item() synchronizes with the GPU; call it only as often as you need to log</span></span>
<span id="cb20-20"><a href="#cb20-20" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> loss.item()</span></code></pre></div></div>
</section>
<section id="performance-tuning-tips" class="level3">
<h3 class="anchored" data-anchor-id="performance-tuning-tips" id="performance-tuning-tips">4. Performance Tuning Tips</h3>
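<p><strong>Warm-up Compilation</strong>: The first execution includes compilation overhead. For production deployments, run a few warm-up iterations with representative inputs so that kernels are compiled and cached before real traffic arrives. Inputs with new shapes or dtypes can trigger recompilation, so warm up with the shapes you expect to serve.</p>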
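<p><strong>Batch Size Considerations</strong>: Which optimizations pay off depends on batch size. At small batch sizes, runtime is often dominated by Python and kernel-launch overhead, which <code>mode="reduce-overhead"</code> targets with CUDA graphs; at larger batch sizes, the memory-traffic savings from operator fusion tend to matter more. Benchmark at the batch size you will actually deploy.</p>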
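<p><strong>Memory Layout Awareness</strong>: Tensor layout affects how efficiently generated kernels access memory. Prefer contiguous tensors, avoid needless <code>permute</code>/<code>transpose</code> chains, and for convolutional models on GPU consider the <code>channels_last</code> memory format (<code>model.to(memory_format=torch.channels_last)</code>), which often combines well with compilation.</p>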
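<p>The warm-up advice can be checked with a small timing harness that reports the first call separately from steady state. The sketch below is a hypothetical helper (<code>time_first_and_steady</code> is not a PyTorch API); it assumes only a callable and its arguments, and for GPU code you would pass <code>torch.cuda.synchronize</code> as <code>sync</code> so asynchronous kernel launches are included in the measured time. The demo uses a stand-in function that is slow only on its first call, mimicking compilation.</p>

```python
import statistics
import time


def time_first_and_steady(fn, *args, warmup=3, iters=20, sync=None):
    """Return (first_call_seconds, median_steady_state_seconds) for fn(*args).

    Hypothetical helper, not a PyTorch API. The first call of a
    torch.compile'd function includes compilation, so it is reported
    separately from the steady state. `sync` is an optional callable
    (e.g. torch.cuda.synchronize) that waits for queued asynchronous work.
    """
    def timed_call():
        if sync is not None:
            sync()  # drain earlier work so it is not charged to this call
        start = time.perf_counter()
        fn(*args)
        if sync is not None:
            sync()  # wait until this call's work has actually finished
        return time.perf_counter() - start

    first = timed_call()  # includes any one-time compilation cost
    for _ in range(warmup):  # extra warm-up calls, not measured
        timed_call()
    samples = [timed_call() for _ in range(iters)]
    return first, statistics.median(samples)


if __name__ == "__main__":
    calls = []

    def fake_compiled(x):
        # Stand-in for a compiled function: slow first call, fast afterwards
        if not calls:
            time.sleep(0.05)
        calls.append(x)
        return x * 2

    first, steady = time_first_and_steady(fake_compiled, 3)
    print(f"first call: {first * 1000:.1f} ms, steady state: {steady * 1000:.3f} ms")
```

<p>With a real model on a GPU, a call such as <code>time_first_and_steady(compiled_model, x, sync=torch.cuda.synchronize)</code> would show how much of the first-call latency is compilation; using the median rather than the mean keeps occasional outliers from skewing the steady-state figure.</p>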
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>The PyTorch 2.x compilation pipeline represents a significant advancement in deep learning optimization. By understanding the flow from FX graph capture through Inductor compilation to hardware-specific backends, you can:</p>
<ol type="1">
<li><strong>Achieve meaningful speedups</strong> with minimal code changes; the gain varies widely with model, batch size, and hardware, so measure it on your own workload</li>
<li><strong>Optimize memory usage</strong> through fusion and kernel optimization</li>
<li><strong>Handle dynamic workloads</strong> efficiently</li>
<li><strong>Debug performance issues</strong> at each compilation stage</li>
</ol>
<p>The journey from high-level Python code through FX graph representation, TorchInductor optimization, and backend-specific code generation demonstrates the sophisticated engineering required to make complex optimizations accessible to everyday users. As the ecosystem continues to evolve, we can expect even greater performance improvements and broader hardware support while maintaining PyTorch’s commitment to usability and research flexibility.</p>
<p>This compilation pipeline not only accelerates existing workloads but also enables new possibilities in model architecture design and deployment strategies, making it an essential tool for the modern deep learning practitioner.</p>
<p>The key to success is understanding when and how to apply compilation, proper model preparation, and effective debugging when issues arise. Start with simple <code>torch.compile()</code> calls and gradually explore advanced optimization techniques as needed.</p>
<section id="key-takeaways" class="level3">
<h3 class="anchored" data-anchor-id="key-takeaways" id="key-takeaways">Key Takeaways</h3>
<ul>
<li>Use <code>torch.compile()</code> for automatic optimization</li>
<li>Choose appropriate compilation modes based on your use case</li>
<li>Leverage FX for custom graph transformations</li>
<li>Monitor memory usage and compilation overhead</li>
<li>Profile and debug systematically</li>
</ul>
<p>This compilation stack makes PyTorch 2.x not just user-friendly but also performance-competitive with specialized frameworks, all while maintaining the flexibility and ease of use that PyTorch is known for.</p>


</section>
</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Matryoshka Transformer: Complete Implementation Guide]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/models/matryoshka/matryoshka-code/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/models/matryoshka/matryoshka-code/</guid>
      <pubDate>Sat, 28 Jun 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>tutorial</category>
      <category>advanced</category>
      <content:encoded><![CDATA[






<section id="matryoshka-transformer-complete-implementation-guide" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/models/matryoshka/matryoshka-code/matryo.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>Matryoshka Transformers are neural networks that support flexible computational budgets at inference time by allowing early exits at different layers. Named after Russian nesting dolls, these models contain multiple “nested” representations of decreasing complexity, allowing you to trade off accuracy for speed based on your computational constraints.</p>
</section>
<section id="key-concepts" class="level2">
<h2 class="anchored" data-anchor-id="key-concepts" id="key-concepts">Key Concepts</h2>
<section id="core-ideas" class="level3">
<h3 class="anchored" data-anchor-id="core-ideas" id="core-ideas">Core Ideas</h3>
<ul>
<li><strong>Nested Representations</strong>: Each layer can potentially serve as a final output</li>
<li><strong>Early Exits</strong>: Inference can stop at any intermediate layer</li>
<li><strong>Adaptive Computation</strong>: Different inputs may require different amounts of computation</li>
<li><strong>Training Efficiency</strong>: A single trained model serves multiple computational budgets</li>
</ul>
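<p>The training-efficiency idea is usually realized by optimizing every exit jointly. A common objective, assumed here rather than taken from this guide's implementation, is a weighted sum of the per-exit cross-entropy losses. The pure-Python sketch below shows that arithmetic for a single example; the names <code>cross_entropy</code> and <code>multi_exit_loss</code> and the example weights are illustrative.</p>

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy for one example: -log(softmax(logits)[target])."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def multi_exit_loss(exit_logits, target, weights=None):
    """Weighted sum of per-exit losses (an assumed training objective).

    exit_logits: one list of class logits per exit, ordered shallow to deep.
    weights: one non-negative weight per exit; defaults to equal weights.
    """
    if weights is None:
        weights = [1.0] * len(exit_logits)
    if len(weights) != len(exit_logits):
        raise ValueError("need exactly one weight per exit")
    return sum(w * cross_entropy(logits, target)
               for w, logits in zip(weights, exit_logits))


# Three exits on a 2-class problem; deeper exits happen to be more confident
exits = [[0.2, 0.1], [1.0, -0.5], [3.0, -2.0]]
print(multi_exit_loss(exits, target=0, weights=[0.3, 0.3, 1.0]))
```

<p>Giving the deepest exit a larger weight is one way to keep final-layer accuracy from being sacrificed for the shallow exits; the right weights are a tuning choice rather than a fixed rule.</p>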
</section>
<section id="architecture-overview" class="level3">
<h3 class="anchored" data-anchor-id="architecture-overview" id="architecture-overview">Architecture Overview</h3>
<pre><code>Input → Layer 1 → [Exit 1] → Layer 2 → [Exit 2] → ... → Layer N → [Final Exit]</code></pre>
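<p>At inference time, the diagram is walked left to right and computation stops at the first exit whose prediction is confident enough. The sketch below assumes a max-softmax-probability threshold as the stopping rule, which is one common choice and not necessarily the rule used later in this guide. <code>early_exit_predict</code> and <code>lazy_exits</code> are hypothetical names; exits are supplied as an iterable of logit lists, so when it is a generator the layers behind later exits are never run.</p>

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def early_exit_predict(exit_logits, threshold=0.9):
    """Return (predicted_class, exit_index, confidence).

    exit_logits: iterable yielding one list of logits per exit, shallow to deep.
    If it is a generator, exits after the chosen one are never computed.
    The last exit always answers, whatever its confidence.
    """
    result = None
    for exit_index, logits in enumerate(exit_logits):
        probs = softmax(logits)
        confidence = max(probs)
        result = (probs.index(confidence), exit_index, confidence)
        if confidence >= threshold:
            break  # confident enough: skip the remaining layers
    if result is None:
        raise ValueError("exit_logits must yield at least one exit")
    return result


def lazy_exits(all_logits, log):
    """Hypothetical stand-in for running layers: records which exits ran."""
    for i, logits in enumerate(all_logits):
        log.append(i)
        yield logits


ran = []
pred = early_exit_predict(lazy_exits([[0.1, 0.0], [4.0, 0.0], [9.0, 0.0]], ran))
print(pred, "exits computed:", ran)
```

<p>Raising <code>threshold</code> toward 1.0 sends more inputs to deeper exits, trading speed for accuracy; lowering it does the opposite.</p>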
</section>
</section>
<section id="implementation" class="level2">
<h2 class="anchored" data-anchor-id="implementation" id="implementation">Implementation</h2>
<section id="basic-matryoshka-transformer-block" class="level3">
<h3 class="anchored" data-anchor-id="basic-matryoshka-transformer-block" id="basic-matryoshka-transformer-block">1. Basic Matryoshka Transformer Block</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn.functional <span class="im">as</span> F</span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> typing <span class="im">import</span> List, Optional, Tuple</span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MatryoshkaTransformerBlock(nn.Module):</span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a><span class="co">    A single transformer block with optional early exit capability</span></span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb2-10"><a href="#cb2-10" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(</span>
<span id="cb2-11"><a href="#cb2-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>,</span>
<span id="cb2-12"><a href="#cb2-12" aria-hidden="true" tabindex="-1"></a>        d_model: <span class="bu">int</span>,</span>
<span id="cb2-13"><a href="#cb2-13" aria-hidden="true" tabindex="-1"></a>        n_heads: <span class="bu">int</span>,</span>
<span id="cb2-14"><a href="#cb2-14" aria-hidden="true" tabindex="-1"></a>        d_ff: <span class="bu">int</span>,</span>
<span id="cb2-15"><a href="#cb2-15" aria-hidden="true" tabindex="-1"></a>        dropout: <span class="bu">float</span> <span class="op">=</span> <span class="fl">0.1</span>,</span>
<span id="cb2-16"><a href="#cb2-16" aria-hidden="true" tabindex="-1"></a>        has_exit: <span class="bu">bool</span> <span class="op">=</span> <span class="va">False</span>,</span>
<span id="cb2-17"><a href="#cb2-17" aria-hidden="true" tabindex="-1"></a>        n_classes: Optional[<span class="bu">int</span>] <span class="op">=</span> <span class="va">None</span></span>
<span id="cb2-18"><a href="#cb2-18" aria-hidden="true" tabindex="-1"></a>    ):</span>
<span id="cb2-19"><a href="#cb2-19" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb2-20"><a href="#cb2-20" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-21"><a href="#cb2-21" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Standard transformer components</span></span>
<span id="cb2-22"><a href="#cb2-22" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.attention <span class="op">=</span> nn.MultiheadAttention(</span>
<span id="cb2-23"><a href="#cb2-23" aria-hidden="true" tabindex="-1"></a>            d_model, n_heads, dropout<span class="op">=</span>dropout, batch_first<span class="op">=</span><span class="va">True</span></span>
<span id="cb2-24"><a href="#cb2-24" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb2-25"><a href="#cb2-25" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.feed_forward <span class="op">=</span> nn.Sequential(</span>
<span id="cb2-26"><a href="#cb2-26" aria-hidden="true" tabindex="-1"></a>            nn.Linear(d_model, d_ff),</span>
<span id="cb2-27"><a href="#cb2-27" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(),</span>
<span id="cb2-28"><a href="#cb2-28" aria-hidden="true" tabindex="-1"></a>            nn.Dropout(dropout),</span>
<span id="cb2-29"><a href="#cb2-29" aria-hidden="true" tabindex="-1"></a>            nn.Linear(d_ff, d_model)</span>
<span id="cb2-30"><a href="#cb2-30" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb2-31"><a href="#cb2-31" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-32"><a href="#cb2-32" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.norm1 <span class="op">=</span> nn.LayerNorm(d_model)</span>
<span id="cb2-33"><a href="#cb2-33" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.norm2 <span class="op">=</span> nn.LayerNorm(d_model)</span>
<span id="cb2-34"><a href="#cb2-34" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.dropout <span class="op">=</span> nn.Dropout(dropout)</span>
<span id="cb2-35"><a href="#cb2-35" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-36"><a href="#cb2-36" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Early exit components</span></span>
<span id="cb2-37"><a href="#cb2-37" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.has_exit <span class="op">=</span> has_exit <span class="kw">and</span> n_classes <span class="kw">is</span> <span class="kw">not</span> <span class="va">None</span>  <span class="co"># an exit head needs n_classes</span></span>
<span id="cb2-38"><a href="#cb2-38" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="va">self</span>.has_exit:</span>
<span id="cb2-39"><a href="#cb2-39" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.exit_classifier <span class="op">=</span> nn.Sequential(</span>
<span id="cb2-40"><a href="#cb2-40" aria-hidden="true" tabindex="-1"></a>                nn.LayerNorm(d_model),</span>
<span id="cb2-41"><a href="#cb2-41" aria-hidden="true" tabindex="-1"></a>                nn.Linear(d_model, d_model <span class="op">//</span> <span class="dv">2</span>),</span>
<span id="cb2-42"><a href="#cb2-42" aria-hidden="true" tabindex="-1"></a>                nn.ReLU(),</span>
<span id="cb2-43"><a href="#cb2-43" aria-hidden="true" tabindex="-1"></a>                nn.Dropout(dropout),</span>
<span id="cb2-44"><a href="#cb2-44" aria-hidden="true" tabindex="-1"></a>                nn.Linear(d_model <span class="op">//</span> <span class="dv">2</span>, n_classes)</span>
<span id="cb2-45"><a href="#cb2-45" aria-hidden="true" tabindex="-1"></a>            )</span>
<span id="cb2-46"><a href="#cb2-46" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-47"><a href="#cb2-47" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(</span>
<span id="cb2-48"><a href="#cb2-48" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>, </span>
<span id="cb2-49"><a href="#cb2-49" aria-hidden="true" tabindex="-1"></a>        x: torch.Tensor, </span>
<span id="cb2-50"><a href="#cb2-50" aria-hidden="true" tabindex="-1"></a>        mask: Optional[torch.Tensor] <span class="op">=</span> <span class="va">None</span></span>
<span id="cb2-51"><a href="#cb2-51" aria-hidden="true" tabindex="-1"></a>    ) <span class="op">-&gt;</span> Tuple[torch.Tensor, Optional[torch.Tensor]]:</span>
<span id="cb2-52"><a href="#cb2-52" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""</span></span>
<span id="cb2-53"><a href="#cb2-53" aria-hidden="true" tabindex="-1"></a><span class="co">        Forward pass with optional early exit</span></span>
<span id="cb2-54"><a href="#cb2-54" aria-hidden="true" tabindex="-1"></a><span class="co">        </span></span>
<span id="cb2-55"><a href="#cb2-55" aria-hidden="true" tabindex="-1"></a><span class="co">        Returns:</span></span>
<span id="cb2-56"><a href="#cb2-56" aria-hidden="true" tabindex="-1"></a><span class="co">            x: Transformed input</span></span>
<span id="cb2-57"><a href="#cb2-57" aria-hidden="true" tabindex="-1"></a><span class="co">            exit_logits: Early exit predictions (if has_exit=True)</span></span>
<span id="cb2-58"><a href="#cb2-58" aria-hidden="true" tabindex="-1"></a><span class="co">        """</span></span>
<span id="cb2-59"><a href="#cb2-59" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Self-attention</span></span>
<span id="cb2-60"><a href="#cb2-60" aria-hidden="true" tabindex="-1"></a>        attn_out, _ <span class="op">=</span> <span class="va">self</span>.attention(x, x, x, attn_mask<span class="op">=</span>mask)</span>
<span id="cb2-61"><a href="#cb2-61" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.norm1(x <span class="op">+</span> <span class="va">self</span>.dropout(attn_out))</span>
<span id="cb2-62"><a href="#cb2-62" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-63"><a href="#cb2-63" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Feed-forward</span></span>
<span id="cb2-64"><a href="#cb2-64" aria-hidden="true" tabindex="-1"></a>        ff_out <span class="op">=</span> <span class="va">self</span>.feed_forward(x)</span>
<span id="cb2-65"><a href="#cb2-65" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.norm2(x <span class="op">+</span> <span class="va">self</span>.dropout(ff_out))</span>
<span id="cb2-66"><a href="#cb2-66" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-67"><a href="#cb2-67" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Early exit prediction</span></span>
<span id="cb2-68"><a href="#cb2-68" aria-hidden="true" tabindex="-1"></a>        exit_logits <span class="op">=</span> <span class="va">None</span></span>
<span id="cb2-69"><a href="#cb2-69" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="va">self</span>.has_exit:</span>
<span id="cb2-70"><a href="#cb2-70" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Use mean pooling for sequence classification</span></span>
<span id="cb2-71"><a href="#cb2-71" aria-hidden="true" tabindex="-1"></a>            pooled <span class="op">=</span> x.mean(dim<span class="op">=</span><span class="dv">1</span>)  <span class="co"># [batch_size, d_model]</span></span>
<span id="cb2-72"><a href="#cb2-72" aria-hidden="true" tabindex="-1"></a>            exit_logits <span class="op">=</span> <span class="va">self</span>.exit_classifier(pooled)</span>
<span id="cb2-73"><a href="#cb2-73" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-74"><a href="#cb2-74" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x, exit_logits</span></code></pre></div></div>
</section>
<section id="complete-matryoshka-transformer-model" class="level3">
<h3 class="anchored" data-anchor-id="complete-matryoshka-transformer-model" id="complete-matryoshka-transformer-model">2. Complete Matryoshka Transformer Model</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MatryoshkaTransformer(nn.Module):</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a><span class="co">    Complete Matryoshka Transformer with multiple exit points</span></span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(</span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>,</span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a>        vocab_size: <span class="bu">int</span>,</span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a>        d_model: <span class="bu">int</span> <span class="op">=</span> <span class="dv">512</span>,</span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a>        n_heads: <span class="bu">int</span> <span class="op">=</span> <span class="dv">8</span>,</span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a>        n_layers: <span class="bu">int</span> <span class="op">=</span> <span class="dv">6</span>,</span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a>        d_ff: <span class="bu">int</span> <span class="op">=</span> <span class="dv">2048</span>,</span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a>        max_seq_len: <span class="bu">int</span> <span class="op">=</span> <span class="dv">512</span>,</span>
<span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a>        n_classes: <span class="bu">int</span> <span class="op">=</span> <span class="dv">2</span>,</span>
<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a>        dropout: <span class="bu">float</span> <span class="op">=</span> <span class="fl">0.1</span>,</span>
<span id="cb3-15"><a href="#cb3-15" aria-hidden="true" tabindex="-1"></a>        exit_layers: Optional[List[<span class="bu">int</span>]] <span class="op">=</span> <span class="va">None</span>  <span class="co"># 0-based block indices that get early-exit heads</span></span>
<span id="cb3-16"><a href="#cb3-16" aria-hidden="true" tabindex="-1"></a>    ):</span>
<span id="cb3-17"><a href="#cb3-17" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb3-18"><a href="#cb3-18" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-19"><a href="#cb3-19" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.d_model <span class="op">=</span> d_model</span>
<span id="cb3-20"><a href="#cb3-20" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.n_layers <span class="op">=</span> n_layers</span>
<span id="cb3-21"><a href="#cb3-21" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-22"><a href="#cb3-22" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Default exit layers (every 2 layers + final)</span></span>
<span id="cb3-23"><a href="#cb3-23" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> exit_layers <span class="kw">is</span> <span class="va">None</span>:</span>
<span id="cb3-24"><a href="#cb3-24" aria-hidden="true" tabindex="-1"></a>            exit_layers <span class="op">=</span> <span class="bu">list</span>(<span class="bu">range</span>(<span class="dv">1</span>, n_layers, <span class="dv">2</span>)) <span class="op">+</span> [n_layers <span class="op">-</span> <span class="dv">1</span>]</span>
<span id="cb3-25"><a href="#cb3-25" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.exit_layers <span class="op">=</span> <span class="bu">set</span>(exit_layers)</span>
<span id="cb3-26"><a href="#cb3-26" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-27"><a href="#cb3-27" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Embeddings</span></span>
<span id="cb3-28"><a href="#cb3-28" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.token_embedding <span class="op">=</span> nn.Embedding(vocab_size, d_model)</span>
<span id="cb3-29"><a href="#cb3-29" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.position_embedding <span class="op">=</span> nn.Embedding(max_seq_len, d_model)</span>
<span id="cb3-30"><a href="#cb3-30" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.dropout <span class="op">=</span> nn.Dropout(dropout)</span>
<span id="cb3-31"><a href="#cb3-31" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-32"><a href="#cb3-32" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Transformer blocks</span></span>
<span id="cb3-33"><a href="#cb3-33" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.blocks <span class="op">=</span> nn.ModuleList([</span>
<span id="cb3-34"><a href="#cb3-34" aria-hidden="true" tabindex="-1"></a>            MatryoshkaTransformerBlock(</span>
<span id="cb3-35"><a href="#cb3-35" aria-hidden="true" tabindex="-1"></a>                d_model<span class="op">=</span>d_model,</span>
<span id="cb3-36"><a href="#cb3-36" aria-hidden="true" tabindex="-1"></a>                n_heads<span class="op">=</span>n_heads,</span>
<span id="cb3-37"><a href="#cb3-37" aria-hidden="true" tabindex="-1"></a>                d_ff<span class="op">=</span>d_ff,</span>
<span id="cb3-38"><a href="#cb3-38" aria-hidden="true" tabindex="-1"></a>                dropout<span class="op">=</span>dropout,</span>
<span id="cb3-39"><a href="#cb3-39" aria-hidden="true" tabindex="-1"></a>                has_exit<span class="op">=</span>(i <span class="kw">in</span> <span class="va">self</span>.exit_layers),</span>
<span id="cb3-40"><a href="#cb3-40" aria-hidden="true" tabindex="-1"></a>                n_classes<span class="op">=</span>n_classes</span>
<span id="cb3-41"><a href="#cb3-41" aria-hidden="true" tabindex="-1"></a>            )</span>
<span id="cb3-42"><a href="#cb3-42" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(n_layers)</span>
<span id="cb3-43"><a href="#cb3-43" aria-hidden="true" tabindex="-1"></a>        ])</span>
<span id="cb3-44"><a href="#cb3-44" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-45"><a href="#cb3-45" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Final classifier (always present)</span></span>
<span id="cb3-46"><a href="#cb3-46" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.final_classifier <span class="op">=</span> nn.Sequential(</span>
<span id="cb3-47"><a href="#cb3-47" aria-hidden="true" tabindex="-1"></a>            nn.LayerNorm(d_model),</span>
<span id="cb3-48"><a href="#cb3-48" aria-hidden="true" tabindex="-1"></a>            nn.Linear(d_model, n_classes)</span>
<span id="cb3-49"><a href="#cb3-49" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb3-50"><a href="#cb3-50" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-51"><a href="#cb3-51" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Per-exit confidence thresholds, stored as a buffer (fixed values, no gradients)</span></span>
<span id="cb3-52"><a href="#cb3-52" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.register_buffer(</span>
<span id="cb3-53"><a href="#cb3-53" aria-hidden="true" tabindex="-1"></a>            <span class="st">'confidence_thresholds'</span>, torch.full((<span class="bu">len</span>(<span class="va">self</span>.exit_layers),), <span class="fl">0.8</span>)</span>
<span id="cb3-54"><a href="#cb3-54" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb3-55"><a href="#cb3-55" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-56"><a href="#cb3-56" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(</span>
<span id="cb3-57"><a href="#cb3-57" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>,</span>
<span id="cb3-58"><a href="#cb3-58" aria-hidden="true" tabindex="-1"></a>        input_ids: torch.Tensor,</span>
<span id="cb3-59"><a href="#cb3-59" aria-hidden="true" tabindex="-1"></a>        attention_mask: Optional[torch.Tensor] <span class="op">=</span> <span class="va">None</span>,</span>
<span id="cb3-60"><a href="#cb3-60" aria-hidden="true" tabindex="-1"></a>        return_all_exits: <span class="bu">bool</span> <span class="op">=</span> <span class="va">False</span>,</span>
<span id="cb3-61"><a href="#cb3-61" aria-hidden="true" tabindex="-1"></a>        confidence_threshold: <span class="bu">float</span> <span class="op">=</span> <span class="fl">0.8</span>,</span>
<span id="cb3-62"><a href="#cb3-62" aria-hidden="true" tabindex="-1"></a>        max_exit_layer: Optional[<span class="bu">int</span>] <span class="op">=</span> <span class="va">None</span></span>
<span id="cb3-63"><a href="#cb3-63" aria-hidden="true" tabindex="-1"></a>    ) <span class="op">-&gt;</span> <span class="bu">dict</span>:</span>
<span id="cb3-64"><a href="#cb3-64" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""</span></span>
<span id="cb3-65"><a href="#cb3-65" aria-hidden="true" tabindex="-1"></a><span class="co">        Forward pass with adaptive early exiting</span></span>
<span id="cb3-66"><a href="#cb3-66" aria-hidden="true" tabindex="-1"></a><span class="co">        </span></span>
<span id="cb3-67"><a href="#cb3-67" aria-hidden="true" tabindex="-1"></a><span class="co">        Args:</span></span>
<span id="cb3-68"><a href="#cb3-68" aria-hidden="true" tabindex="-1"></a><span class="co">            input_ids: Input token IDs [batch_size, seq_len]</span></span>
<span id="cb3-69"><a href="#cb3-69" aria-hidden="true" tabindex="-1"></a><span class="co">            attention_mask: Attention mask [batch_size, seq_len]</span></span>
<span id="cb3-70"><a href="#cb3-70" aria-hidden="true" tabindex="-1"></a><span class="co">            return_all_exits: Whether to return predictions from all exit points</span></span>
<span id="cb3-71"><a href="#cb3-71" aria-hidden="true" tabindex="-1"></a><span class="co">            confidence_threshold: Minimum confidence for early exit</span></span>
<span id="cb3-72"><a href="#cb3-72" aria-hidden="true" tabindex="-1"></a><span class="co">            max_exit_layer: Maximum layer to exit at (for budget constraints)</span></span>
<span id="cb3-73"><a href="#cb3-73" aria-hidden="true" tabindex="-1"></a><span class="co">        </span></span>
<span id="cb3-74"><a href="#cb3-74" aria-hidden="true" tabindex="-1"></a><span class="co">        Returns:</span></span>
<span id="cb3-75"><a href="#cb3-75" aria-hidden="true" tabindex="-1"></a><span class="co">            Dictionary containing predictions and exit information</span></span>
<span id="cb3-76"><a href="#cb3-76" aria-hidden="true" tabindex="-1"></a><span class="co">        """</span></span>
<span id="cb3-77"><a href="#cb3-77" aria-hidden="true" tabindex="-1"></a>        batch_size, seq_len <span class="op">=</span> input_ids.shape</span>
<span id="cb3-78"><a href="#cb3-78" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-79"><a href="#cb3-79" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Embeddings</span></span>
<span id="cb3-80"><a href="#cb3-80" aria-hidden="true" tabindex="-1"></a>        positions <span class="op">=</span> torch.arange(seq_len, device<span class="op">=</span>input_ids.device).unsqueeze(<span class="dv">0</span>)</span>
<span id="cb3-81"><a href="#cb3-81" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.token_embedding(input_ids) <span class="op">+</span> <span class="va">self</span>.position_embedding(positions)</span>
<span id="cb3-82"><a href="#cb3-82" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.dropout(x)</span>
<span id="cb3-83"><a href="#cb3-83" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-84"><a href="#cb3-84" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Additive key-padding mask [batch_size, seq_len]: 0 for real tokens, -1e4 for padding</span></span>
<span id="cb3-85"><a href="#cb3-85" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> attention_mask <span class="kw">is</span> <span class="kw">not</span> <span class="va">None</span>:</span>
<span id="cb3-86"><a href="#cb3-86" aria-hidden="true" tabindex="-1"></a>            attn_mask <span class="op">=</span> (<span class="fl">1.0</span> <span class="op">-</span> attention_mask.<span class="bu">float</span>()) <span class="op">*</span> <span class="op">-</span><span class="fl">10000.0</span></span>
<span id="cb3-90"><a href="#cb3-90" aria-hidden="true" tabindex="-1"></a>        <span class="cf">else</span>:</span>
<span id="cb3-91"><a href="#cb3-91" aria-hidden="true" tabindex="-1"></a>            attn_mask <span class="op">=</span> <span class="va">None</span></span>
<span id="cb3-92"><a href="#cb3-92" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-93"><a href="#cb3-93" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Track exits</span></span>
<span id="cb3-94"><a href="#cb3-94" aria-hidden="true" tabindex="-1"></a>        exit_predictions <span class="op">=</span> []</span>
<span id="cb3-95"><a href="#cb3-95" aria-hidden="true" tabindex="-1"></a>        exit_confidences <span class="op">=</span> []</span>
<span id="cb3-96"><a href="#cb3-96" aria-hidden="true" tabindex="-1"></a>        exit_layer <span class="op">=</span> <span class="va">None</span></span>
<span id="cb3-97"><a href="#cb3-97" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-98"><a href="#cb3-98" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Forward through transformer blocks</span></span>
<span id="cb3-99"><a href="#cb3-99" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> i, block <span class="kw">in</span> <span class="bu">enumerate</span>(<span class="va">self</span>.blocks):</span>
<span id="cb3-100"><a href="#cb3-100" aria-hidden="true" tabindex="-1"></a>            x, exit_logits <span class="op">=</span> block(x, attn_mask)</span>
<span id="cb3-101"><a href="#cb3-101" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb3-102"><a href="#cb3-102" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Check for early exit</span></span>
<span id="cb3-103"><a href="#cb3-103" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> exit_logits <span class="kw">is</span> <span class="kw">not</span> <span class="va">None</span>:</span>
<span id="cb3-104"><a href="#cb3-104" aria-hidden="true" tabindex="-1"></a>                exit_probs <span class="op">=</span> F.softmax(exit_logits, dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb3-105"><a href="#cb3-105" aria-hidden="true" tabindex="-1"></a>                max_confidence <span class="op">=</span> torch.<span class="bu">max</span>(exit_probs, dim<span class="op">=-</span><span class="dv">1</span>)[<span class="dv">0</span>]</span>
<span id="cb3-106"><a href="#cb3-106" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb3-107"><a href="#cb3-107" aria-hidden="true" tabindex="-1"></a>                exit_predictions.append(exit_logits)</span>
<span id="cb3-108"><a href="#cb3-108" aria-hidden="true" tabindex="-1"></a>                exit_confidences.append(max_confidence)</span>
<span id="cb3-109"><a href="#cb3-109" aria-hidden="true" tabindex="-1"></a>                </span>
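<span id="cb3-110"><a href="#cb3-110" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Early exit decision (batch-level: mean confidence over the batch)</span></span>
<span id="cb3-111"><a href="#cb3-111" aria-hidden="true" tabindex="-1"></a>                <span class="cf">if</span> <span class="kw">not</span> return_all_exits:</span>
<span id="cb3-112"><a href="#cb3-112" aria-hidden="true" tabindex="-1"></a>                    confident <span class="op">=</span> max_confidence.mean().item() <span class="op">&gt;=</span> confidence_threshold</span>
<span id="cb3-113"><a href="#cb3-113" aria-hidden="true" tabindex="-1"></a>                    <span class="co"># Budget: stop here if the next exit head (or the last block) lies beyond max_exit_layer</span></span>
<span id="cb3-113b"><a href="#cb3-113b" aria-hidden="true" tabindex="-1"></a>                    next_exit <span class="op">=</span> <span class="bu">min</span>((j <span class="cf">for</span> j <span class="kw">in</span> <span class="va">self</span>.exit_layers <span class="cf">if</span> j <span class="op">&gt;</span> i), default<span class="op">=</span><span class="va">self</span>.n_layers <span class="op">-</span> <span class="dv">1</span>)</span>
<span id="cb3-113c"><a href="#cb3-113c" aria-hidden="true" tabindex="-1"></a>                    over_budget <span class="op">=</span> max_exit_layer <span class="kw">is</span> <span class="kw">not</span> <span class="va">None</span> <span class="kw">and</span> next_exit <span class="op">&gt;</span> max_exit_layer</span>
<span id="cb3-113d"><a href="#cb3-113d" aria-hidden="true" tabindex="-1"></a>                    <span class="cf">if</span> confident <span class="kw">or</span> over_budget:</span>
<span id="cb3-114"><a href="#cb3-114" aria-hidden="true" tabindex="-1"></a>                        exit_layer <span class="op">=</span> i</span>
<span id="cb3-115"><a href="#cb3-115" aria-hidden="true" tabindex="-1"></a>                        <span class="cf">break</span></span>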
<span id="cb3-116"><a href="#cb3-116" aria-hidden="true" tabindex="-1"></a>        </span>
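<span id="cb3-117"><a href="#cb3-117" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Final prediction: use the exit head that fired, else the final classifier</span></span>
<span id="cb3-118"><a href="#cb3-118" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> exit_layer <span class="kw">is</span> <span class="kw">not</span> <span class="va">None</span>:</span>
<span id="cb3-118b"><a href="#cb3-118b" aria-hidden="true" tabindex="-1"></a>            final_output <span class="op">=</span> exit_predictions[<span class="op">-</span><span class="dv">1</span>]</span>
<span id="cb3-118c"><a href="#cb3-118c" aria-hidden="true" tabindex="-1"></a>        <span class="cf">else</span>:</span>
<span id="cb3-118d"><a href="#cb3-118d" aria-hidden="true" tabindex="-1"></a>            final_output <span class="op">=</span> <span class="va">self</span>.final_classifier(x.mean(dim<span class="op">=</span><span class="dv">1</span>))</span>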
<span id="cb3-119"><a href="#cb3-119" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-120"><a href="#cb3-120" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {</span>
<span id="cb3-121"><a href="#cb3-121" aria-hidden="true" tabindex="-1"></a>            <span class="st">'logits'</span>: final_output,</span>
<span id="cb3-122"><a href="#cb3-122" aria-hidden="true" tabindex="-1"></a>            <span class="st">'exit_predictions'</span>: exit_predictions,</span>
<span id="cb3-123"><a href="#cb3-123" aria-hidden="true" tabindex="-1"></a>            <span class="st">'exit_confidences'</span>: exit_confidences,</span>
<span id="cb3-124"><a href="#cb3-124" aria-hidden="true" tabindex="-1"></a>            <span class="st">'exit_layer'</span>: exit_layer,</span>
<span id="cb3-125"><a href="#cb3-125" aria-hidden="true" tabindex="-1"></a>            <span class="st">'total_layers_used'</span>: (exit_layer <span class="op">+</span> <span class="dv">1</span>) <span class="cf">if</span> exit_layer <span class="kw">is</span> <span class="kw">not</span> <span class="va">None</span> <span class="cf">else</span> <span class="va">self</span>.n_layers</span>
<span id="cb3-126"><a href="#cb3-126" aria-hidden="true" tabindex="-1"></a>        }</span></code></pre></div></div>
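<p>To see the exit policy without tensors, here is a framework-free sketch in plain Python. It assumes each exit head's confidence has already been reduced to a single number per batch (the mean max-softmax probability, as in <code>forward</code>), and it assumes a compute budget is enforced by stopping at the last exit head that fits within <code>max_exit_layer</code>, with the final classifier at block <code>n_layers - 1</code> as the fallback. The function <code>choose_exit</code> and its argument names are hypothetical, chosen only for illustration.</p>

```python
from typing import Dict, Optional, Sequence


def choose_exit(
    confidences: Dict[int, float],  # exit block index -> batch-mean max-softmax confidence
    exit_layers: Sequence[int],
    n_layers: int,
    threshold: float = 0.8,
    max_exit_layer: Optional[int] = None,
) -> Dict[str, Optional[int]]:
    """Hypothetical helper: pick the block at which inference would stop."""
    exits = sorted(set(exit_layers))
    for i in exits:
        confident = confidences[i] >= threshold
        # Next place a prediction can come from: another exit head, else the last block
        next_exit = min((j for j in exits if j > i), default=n_layers - 1)
        # Assumed budget rule: stop now if the next prediction point is out of budget
        over_budget = max_exit_layer is not None and next_exit > max_exit_layer
        if confident or over_budget:
            return {"exit_layer": i, "layers_used": i + 1}
    # No exit fired: every block runs and the final classifier decides
    return {"exit_layer": None, "layers_used": n_layers}


conf = {1: 0.55, 3: 0.91, 5: 0.97}
print(choose_exit(conf, [1, 3, 5], n_layers=6))                    # -> {'exit_layer': 3, 'layers_used': 4}
print(choose_exit(conf, [1, 3, 5], n_layers=6, max_exit_layer=2))  # -> {'exit_layer': 1, 'layers_used': 2}
```

<p>With six blocks and exits at 1, 3 and 5, the confidence rule alone stops at block 3 (4 of 6 blocks). With <code>max_exit_layer=2</code>, which is what <code>predict_with_budget</code> computes for <code>flop_budget=0.5</code>, the sketch stops at block 1 because the next exit head (block 3) would exceed the budget.</p>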
</section>
<section id="training-strategy" class="level3">
<h3 class="anchored" data-anchor-id="training-strategy" id="training-strategy">3. Training Strategy</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MatryoshkaTrainer:</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a><span class="co">    Training strategy for Matryoshka Transformers</span></span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(</span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>,</span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a>        model: MatryoshkaTransformer,</span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a>        exit_loss_weights: Optional[List[<span class="bu">float</span>]] <span class="op">=</span> <span class="va">None</span>,</span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a>        distillation_weight: <span class="bu">float</span> <span class="op">=</span> <span class="fl">0.5</span></span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a>    ):</span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> model</span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.exit_loss_weights <span class="op">=</span> exit_loss_weights <span class="kw">or</span> [<span class="fl">0.3</span>, <span class="fl">0.3</span>, <span class="fl">1.0</span>]  <span class="co"># One weight per exit head (zip ignores extras); later exits weigh more</span></span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.distillation_weight <span class="op">=</span> distillation_weight</span>
<span id="cb4-14"><a href="#cb4-14" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-15"><a href="#cb4-15" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> compute_loss(</span>
<span id="cb4-16"><a href="#cb4-16" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>,</span>
<span id="cb4-17"><a href="#cb4-17" aria-hidden="true" tabindex="-1"></a>        outputs: <span class="bu">dict</span>,</span>
<span id="cb4-18"><a href="#cb4-18" aria-hidden="true" tabindex="-1"></a>        labels: torch.Tensor,</span>
<span id="cb4-19"><a href="#cb4-19" aria-hidden="true" tabindex="-1"></a>        temperature: <span class="bu">float</span> <span class="op">=</span> <span class="fl">3.0</span></span>
<span id="cb4-20"><a href="#cb4-20" aria-hidden="true" tabindex="-1"></a>    ) <span class="op">-&gt;</span> <span class="bu">dict</span>:</span>
<span id="cb4-21"><a href="#cb4-21" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""</span></span>
<span id="cb4-22"><a href="#cb4-22" aria-hidden="true" tabindex="-1"></a><span class="co">        Compute combined loss from all exit points</span></span>
<span id="cb4-23"><a href="#cb4-23" aria-hidden="true" tabindex="-1"></a><span class="co">        """</span></span>
<span id="cb4-24"><a href="#cb4-24" aria-hidden="true" tabindex="-1"></a>        losses <span class="op">=</span> {}</span>
<span id="cb4-25"><a href="#cb4-25" aria-hidden="true" tabindex="-1"></a>        total_loss <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb4-26"><a href="#cb4-26" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-27"><a href="#cb4-27" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Final layer loss</span></span>
<span id="cb4-28"><a href="#cb4-28" aria-hidden="true" tabindex="-1"></a>        final_loss <span class="op">=</span> F.cross_entropy(outputs[<span class="st">'logits'</span>], labels)</span>
<span id="cb4-29"><a href="#cb4-29" aria-hidden="true" tabindex="-1"></a>        losses[<span class="st">'final'</span>] <span class="op">=</span> final_loss</span>
<span id="cb4-30"><a href="#cb4-30" aria-hidden="true" tabindex="-1"></a>        total_loss <span class="op">+=</span> final_loss</span>
<span id="cb4-31"><a href="#cb4-31" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-32"><a href="#cb4-32" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Early exit losses</span></span>
<span id="cb4-33"><a href="#cb4-33" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> outputs[<span class="st">'exit_predictions'</span>]:</span>
<span id="cb4-34"><a href="#cb4-34" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> i, (exit_logits, weight) <span class="kw">in</span> <span class="bu">enumerate</span>(</span>
<span id="cb4-35"><a href="#cb4-35" aria-hidden="true" tabindex="-1"></a>                <span class="bu">zip</span>(outputs[<span class="st">'exit_predictions'</span>], <span class="va">self</span>.exit_loss_weights)</span>
<span id="cb4-36"><a href="#cb4-36" aria-hidden="true" tabindex="-1"></a>            ):</span>
<span id="cb4-37"><a href="#cb4-37" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Classification loss</span></span>
<span id="cb4-38"><a href="#cb4-38" aria-hidden="true" tabindex="-1"></a>                exit_loss <span class="op">=</span> F.cross_entropy(exit_logits, labels)</span>
<span id="cb4-39"><a href="#cb4-39" aria-hidden="true" tabindex="-1"></a>                losses[<span class="ss">f'exit_</span><span class="sc">{</span>i<span class="sc">}</span><span class="ss">'</span>] <span class="op">=</span> exit_loss</span>
<span id="cb4-40"><a href="#cb4-40" aria-hidden="true" tabindex="-1"></a>                total_loss <span class="op">+=</span> weight <span class="op">*</span> exit_loss</span>
<span id="cb4-41"><a href="#cb4-41" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb4-42"><a href="#cb4-42" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Knowledge distillation from final layer</span></span>
<span id="cb4-43"><a href="#cb4-43" aria-hidden="true" tabindex="-1"></a>                <span class="cf">if</span> <span class="va">self</span>.distillation_weight <span class="op">&gt;</span> <span class="dv">0</span>:</span>
<span id="cb4-44"><a href="#cb4-44" aria-hidden="true" tabindex="-1"></a>                    distill_loss <span class="op">=</span> F.kl_div(</span>
<span id="cb4-45"><a href="#cb4-45" aria-hidden="true" tabindex="-1"></a>                        F.log_softmax(exit_logits <span class="op">/</span> temperature, dim<span class="op">=-</span><span class="dv">1</span>),</span>
<span id="cb4-46"><a href="#cb4-46" aria-hidden="true" tabindex="-1"></a>                        F.softmax(outputs[<span class="st">'logits'</span>].detach() <span class="op">/</span> temperature, dim<span class="op">=-</span><span class="dv">1</span>),  <span class="co"># teacher: no gradient</span></span>
<span id="cb4-47"><a href="#cb4-47" aria-hidden="true" tabindex="-1"></a>                        reduction<span class="op">=</span><span class="st">'batchmean'</span></span>
<span id="cb4-48"><a href="#cb4-48" aria-hidden="true" tabindex="-1"></a>                    ) <span class="op">*</span> (temperature <span class="op">**</span> <span class="dv">2</span>)</span>
<span id="cb4-49"><a href="#cb4-49" aria-hidden="true" tabindex="-1"></a>                    </span>
<span id="cb4-50"><a href="#cb4-50" aria-hidden="true" tabindex="-1"></a>                    losses[<span class="ss">f'distill_</span><span class="sc">{</span>i<span class="sc">}</span><span class="ss">'</span>] <span class="op">=</span> distill_loss</span>
<span id="cb4-51"><a href="#cb4-51" aria-hidden="true" tabindex="-1"></a>                    total_loss <span class="op">+=</span> <span class="va">self</span>.distillation_weight <span class="op">*</span> weight <span class="op">*</span> distill_loss</span>
<span id="cb4-52"><a href="#cb4-52" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-53"><a href="#cb4-53" aria-hidden="true" tabindex="-1"></a>        losses[<span class="st">'total'</span>] <span class="op">=</span> total_loss</span>
<span id="cb4-54"><a href="#cb4-54" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> losses</span>
<span id="cb4-55"><a href="#cb4-55" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-56"><a href="#cb4-56" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> train_step(</span>
<span id="cb4-57"><a href="#cb4-57" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>,</span>
<span id="cb4-58"><a href="#cb4-58" aria-hidden="true" tabindex="-1"></a>        batch: <span class="bu">dict</span>,</span>
<span id="cb4-59"><a href="#cb4-59" aria-hidden="true" tabindex="-1"></a>        optimizer: torch.optim.Optimizer</span>
<span id="cb4-60"><a href="#cb4-60" aria-hidden="true" tabindex="-1"></a>    ) <span class="op">-&gt;</span> <span class="bu">dict</span>:</span>
<span id="cb4-61"><a href="#cb4-61" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""</span></span>
<span id="cb4-62"><a href="#cb4-62" aria-hidden="true" tabindex="-1"></a><span class="co">        Single training step</span></span>
<span id="cb4-63"><a href="#cb4-63" aria-hidden="true" tabindex="-1"></a><span class="co">        """</span></span>
<span id="cb4-64"><a href="#cb4-64" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model.train()</span>
<span id="cb4-65"><a href="#cb4-65" aria-hidden="true" tabindex="-1"></a>        optimizer.zero_grad()</span>
<span id="cb4-66"><a href="#cb4-66" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-67"><a href="#cb4-67" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Forward pass</span></span>
<span id="cb4-68"><a href="#cb4-68" aria-hidden="true" tabindex="-1"></a>        outputs <span class="op">=</span> <span class="va">self</span>.model(</span>
<span id="cb4-69"><a href="#cb4-69" aria-hidden="true" tabindex="-1"></a>            input_ids<span class="op">=</span>batch[<span class="st">'input_ids'</span>],</span>
<span id="cb4-70"><a href="#cb4-70" aria-hidden="true" tabindex="-1"></a>            attention_mask<span class="op">=</span>batch[<span class="st">'attention_mask'</span>],</span>
<span id="cb4-71"><a href="#cb4-71" aria-hidden="true" tabindex="-1"></a>            return_all_exits<span class="op">=</span><span class="va">True</span></span>
<span id="cb4-72"><a href="#cb4-72" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb4-73"><a href="#cb4-73" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-74"><a href="#cb4-74" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Compute loss</span></span>
<span id="cb4-75"><a href="#cb4-75" aria-hidden="true" tabindex="-1"></a>        losses <span class="op">=</span> <span class="va">self</span>.compute_loss(outputs, batch[<span class="st">'labels'</span>])</span>
<span id="cb4-76"><a href="#cb4-76" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-77"><a href="#cb4-77" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Backward pass</span></span>
<span id="cb4-78"><a href="#cb4-78" aria-hidden="true" tabindex="-1"></a>        losses[<span class="st">'total'</span>].backward()</span>
<span id="cb4-79"><a href="#cb4-79" aria-hidden="true" tabindex="-1"></a>        optimizer.step()</span>
<span id="cb4-80"><a href="#cb4-80" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-81"><a href="#cb4-81" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {k: v.item() <span class="cf">for</span> k, v <span class="kw">in</span> losses.items()}</span></code></pre></div></div>
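<p>Per example, the objective assembled in <code>compute_loss</code> is L = CE(z_final, y) + Σ_k w_k [CE(z_k, y) + λ T² KL(softmax(z_final / T) ‖ softmax(z_k / T))], where w_k are the <code>exit_loss_weights</code>, λ is <code>distillation_weight</code> and T is <code>temperature</code>. The plain-Python sketch below computes that quantity for a single example so the arithmetic can be checked by hand. It treats the final-layer distribution as a fixed teacher, and the helpers <code>softmax</code>, <code>cross_entropy</code>, <code>kl_divergence</code> and <code>matryoshka_loss</code> are hypothetical stand-ins for the PyTorch calls, not part of any library.</p>

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits: Sequence[float], label: int) -> float:
    return -math.log(softmax(logits)[label])


def kl_divergence(p: Sequence[float], q: Sequence[float]) -> float:
    # KL(p || q) = sum_c p_c * (log p_c - log q_c); terms with p_c == 0 contribute 0
    return sum(pc * (math.log(pc) - math.log(qc)) for pc, qc in zip(p, q) if pc > 0)


def matryoshka_loss(
    final_logits: Sequence[float],
    exit_logits: Sequence[Sequence[float]],
    label: int,
    weights: Sequence[float] = (0.3, 0.3, 1.0),
    distill_weight: float = 0.5,
    temperature: float = 3.0,
) -> float:
    """Single-example version of the combined exit + distillation loss (a sketch)."""
    total = cross_entropy(final_logits, label)
    teacher = softmax(final_logits, temperature)  # soft targets from the final classifier
    for logits, w in zip(exit_logits, weights):
        total += w * cross_entropy(logits, label)
        student = softmax(logits, temperature)
        total += distill_weight * w * temperature ** 2 * kl_divergence(teacher, student)
    return total


final = [2.0, 0.0]
# If every exit agrees with the final head, the KL terms vanish and
# the loss is CE_final * (1 + 0.3 + 0.3 + 1.0) = 2.6 * CE_final.
print(matryoshka_loss(final, [final, final, final], label=0))
```

<p>In an autograd framework the teacher distribution is usually detached, so the distillation gradient only updates the exit heads rather than pulling the final classifier toward the shallower exits.</p>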
</section>
<section id="inference-with-adaptive-computation" class="level3">
<h3 class="anchored" data-anchor-id="inference-with-adaptive-computation" id="inference-with-adaptive-computation">4. Inference with Adaptive Computation</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> AdaptiveInference:</span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a><span class="co">    Adaptive inference with configurable exit strategies</span></span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, model: MatryoshkaTransformer):</span>
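<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> model.<span class="bu">eval</span>()  <span class="co"># inference mode (disables dropout); eval() returns the module</span></span>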
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> predict_with_budget(</span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>,</span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a>        input_ids: torch.Tensor,</span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a>        attention_mask: Optional[torch.Tensor] <span class="op">=</span> <span class="va">None</span>,</span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a>        flop_budget: <span class="bu">float</span> <span class="op">=</span> <span class="fl">1.0</span>,  <span class="co"># Fraction of blocks allowed to run (a proxy for FLOPs)</span></span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a>        confidence_threshold: <span class="bu">float</span> <span class="op">=</span> <span class="fl">0.8</span></span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a>    ) <span class="op">-&gt;</span> <span class="bu">dict</span>:</span>
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""</span></span>
<span id="cb5-16"><a href="#cb5-16" aria-hidden="true" tabindex="-1"></a><span class="co">        Predict with computational budget constraint</span></span>
<span id="cb5-17"><a href="#cb5-17" aria-hidden="true" tabindex="-1"></a><span class="co">        """</span></span>
<span id="cb5-18"><a href="#cb5-18" aria-hidden="true" tabindex="-1"></a>        max_layer <span class="op">=</span> <span class="bu">max</span>(<span class="dv">0</span>, <span class="bu">int</span>(<span class="va">self</span>.model.n_layers <span class="op">*</span> flop_budget) <span class="op">-</span> <span class="dv">1</span>)  <span class="co"># never below the first layer</span></span>
<span id="cb5-19"><a href="#cb5-19" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-20"><a href="#cb5-20" aria-hidden="true" tabindex="-1"></a>        outputs <span class="op">=</span> <span class="va">self</span>.model(</span>
<span id="cb5-21"><a href="#cb5-21" aria-hidden="true" tabindex="-1"></a>            input_ids<span class="op">=</span>input_ids,</span>
<span id="cb5-22"><a href="#cb5-22" aria-hidden="true" tabindex="-1"></a>            attention_mask<span class="op">=</span>attention_mask,</span>
<span id="cb5-23"><a href="#cb5-23" aria-hidden="true" tabindex="-1"></a>            confidence_threshold<span class="op">=</span>confidence_threshold,</span>
<span id="cb5-24"><a href="#cb5-24" aria-hidden="true" tabindex="-1"></a>            max_exit_layer<span class="op">=</span>max_layer</span>
<span id="cb5-25"><a href="#cb5-25" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb5-26"><a href="#cb5-26" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-27"><a href="#cb5-27" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Fraction of layers actually executed, used as a proxy for FLOPs</span></span>
<span id="cb5-28"><a href="#cb5-28" aria-hidden="true" tabindex="-1"></a>        layers_used <span class="op">=</span> outputs[<span class="st">'total_layers_used'</span>]</span>
<span id="cb5-29"><a href="#cb5-29" aria-hidden="true" tabindex="-1"></a>        actual_budget <span class="op">=</span> layers_used <span class="op">/</span> <span class="va">self</span>.model.n_layers</span>
<span id="cb5-30"><a href="#cb5-30" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-31"><a href="#cb5-31" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {</span>
<span id="cb5-32"><a href="#cb5-32" aria-hidden="true" tabindex="-1"></a>            <span class="op">**</span>outputs,</span>
<span id="cb5-33"><a href="#cb5-33" aria-hidden="true" tabindex="-1"></a>            <span class="st">'computational_savings'</span>: <span class="fl">1.0</span> <span class="op">-</span> actual_budget,</span>
<span id="cb5-34"><a href="#cb5-34" aria-hidden="true" tabindex="-1"></a>            <span class="st">'flops_used'</span>: actual_budget</span>
<span id="cb5-35"><a href="#cb5-35" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb5-36"><a href="#cb5-36" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-37"><a href="#cb5-37" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> predict_with_latency_constraint(</span>
<span id="cb5-38"><a href="#cb5-38" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>,</span>
<span id="cb5-39"><a href="#cb5-39" aria-hidden="true" tabindex="-1"></a>        input_ids: torch.Tensor,</span>
<span id="cb5-40"><a href="#cb5-40" aria-hidden="true" tabindex="-1"></a>        attention_mask: Optional[torch.Tensor] <span class="op">=</span> <span class="va">None</span>,</span>
<span id="cb5-41"><a href="#cb5-41" aria-hidden="true" tabindex="-1"></a>        max_latency_ms: <span class="bu">float</span> <span class="op">=</span> <span class="fl">100.0</span></span>
<span id="cb5-42"><a href="#cb5-42" aria-hidden="true" tabindex="-1"></a>    ) <span class="op">-&gt;</span> <span class="bu">dict</span>:</span>
<span id="cb5-43"><a href="#cb5-43" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""</span></span>
<span id="cb5-44"><a href="#cb5-44" aria-hidden="true" tabindex="-1"></a><span class="co">        Predict with latency constraint (simplified)</span></span>
<span id="cb5-45"><a href="#cb5-45" aria-hidden="true" tabindex="-1"></a><span class="co">        """</span></span>
<span id="cb5-46"><a href="#cb5-46" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Simplified: in practice, profile the actual inference time</span></span>
<span id="cb5-47"><a href="#cb5-47" aria-hidden="true" tabindex="-1"></a>        <span class="co"># of each exit point instead of assuming a constant per-layer cost</span></span>
<span id="cb5-48"><a href="#cb5-48" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-49"><a href="#cb5-49" aria-hidden="true" tabindex="-1"></a>        estimated_time_per_layer <span class="op">=</span> <span class="fl">10.0</span>  <span class="co"># ms</span></span>
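<span id="cb5-50"><a href="#cb5-50" aria-hidden="true" tabindex="-1"></a>        max_layers <span class="op">=</span> <span class="bu">min</span>(<span class="va">self</span>.model.n_layers, <span class="bu">int</span>(max_latency_ms <span class="op">/</span> estimated_time_per_layer))  <span class="co"># cap at full depth</span></span>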
<span id="cb5-51"><a href="#cb5-51" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-52"><a href="#cb5-52" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.predict_with_budget(</span>
<span id="cb5-53"><a href="#cb5-53" aria-hidden="true" tabindex="-1"></a>            input_ids<span class="op">=</span>input_ids,</span>
<span id="cb5-54"><a href="#cb5-54" aria-hidden="true" tabindex="-1"></a>            attention_mask<span class="op">=</span>attention_mask,</span>
<span id="cb5-55"><a href="#cb5-55" aria-hidden="true" tabindex="-1"></a>            flop_budget<span class="op">=</span>max_layers <span class="op">/</span> <span class="va">self</span>.model.n_layers</span>
<span id="cb5-56"><a href="#cb5-56" aria-hidden="true" tabindex="-1"></a>        )</span></code></pre></div></div>
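<p>The latency method above turns a time budget into a layer count by assuming a fixed 10 ms per layer. The profiling step mentioned in its comment can be sketched with the standard library alone. The sketch times each layer (represented here by arbitrary zero-argument callables) and then counts how many leading layers fit inside the budget. Both helpers, <code>profile_layers</code> and <code>layers_within_latency</code>, are hypothetical and are not part of the classes above.</p>

```python
import time
from typing import Callable, List, Sequence


def profile_layers(layers: Sequence[Callable[[], None]], n_repeats: int = 10) -> List[float]:
    """Return the mean wall-clock time (ms) of each layer callable."""
    timings = []
    for layer in layers:
        start = time.perf_counter()
        for _ in range(n_repeats):
            layer()
        elapsed_ms = (time.perf_counter() - start) * 1000.0
        timings.append(elapsed_ms / n_repeats)
    return timings


def layers_within_latency(per_layer_ms: Sequence[float], max_latency_ms: float) -> int:
    """Count how many leading layers fit inside the latency budget."""
    total = 0.0
    for n_layers, cost in enumerate(per_layer_ms):
        total += cost
        if total > max_latency_ms:
            return n_layers  # adding this layer would exceed the budget
    return len(per_layer_ms)


# Illustrative per-layer latencies in milliseconds (not real measurements)
print(layers_within_latency([12.0, 9.5, 11.0, 10.5], max_latency_ms=25.0))
```

<p>Per-layer cost usually depends on sequence length, batch size, and hardware. Profiles are therefore best taken at representative input shapes. The resulting count could then replace the constant-time estimate.</p>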
</section>
<section id="usage-example" class="level3">
<h3 class="anchored" data-anchor-id="usage-example" id="usage-example">5. Usage Example</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Initialize model</span></span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> MatryoshkaTransformer(</span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a>    vocab_size<span class="op">=</span><span class="dv">30000</span>,</span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a>    d_model<span class="op">=</span><span class="dv">512</span>,</span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a>    n_heads<span class="op">=</span><span class="dv">8</span>,</span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a>    n_layers<span class="op">=</span><span class="dv">12</span>,</span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a>    n_classes<span class="op">=</span><span class="dv">2</span>,</span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a>    exit_layers<span class="op">=</span>[<span class="dv">2</span>, <span class="dv">5</span>, <span class="dv">8</span>, <span class="dv">11</span>]  <span class="co"># Exit points</span></span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a><span class="co"># Training setup</span></span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a>trainer <span class="op">=</span> MatryoshkaTrainer(model)</span>
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a>optimizer <span class="op">=</span> torch.optim.Adam(model.parameters(), lr<span class="op">=</span><span class="fl">1e-4</span>)</span>
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-15"><a href="#cb6-15" aria-hidden="true" tabindex="-1"></a><span class="co"># Training loop (simplified)</span></span>
<span id="cb6-16"><a href="#cb6-16" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> batch <span class="kw">in</span> dataloader:  <span class="co"># dataloader is assumed to be defined elsewhere</span></span>
<span id="cb6-17"><a href="#cb6-17" aria-hidden="true" tabindex="-1"></a>    losses <span class="op">=</span> trainer.train_step(batch, optimizer)</span>
<span id="cb6-18"><a href="#cb6-18" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Total loss: </span><span class="sc">{</span>losses[<span class="st">'total'</span>]<span class="sc">:.4f}</span><span class="ss">"</span>)</span>
<span id="cb6-19"><a href="#cb6-19" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-20"><a href="#cb6-20" aria-hidden="true" tabindex="-1"></a><span class="co"># Inference (in practice, call model.eval() and run under torch.no_grad())</span></span>
<span id="cb6-21"><a href="#cb6-21" aria-hidden="true" tabindex="-1"></a>inference_engine <span class="op">=</span> AdaptiveInference(model)</span>
<span id="cb6-22"><a href="#cb6-22" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-23"><a href="#cb6-23" aria-hidden="true" tabindex="-1"></a><span class="co"># Example: Predict with 50% computational budget</span></span>
<span id="cb6-24"><a href="#cb6-24" aria-hidden="true" tabindex="-1"></a>result <span class="op">=</span> inference_engine.predict_with_budget(</span>
<span id="cb6-25"><a href="#cb6-25" aria-hidden="true" tabindex="-1"></a>    input_ids<span class="op">=</span>sample_input,  <span class="co"># token IDs of shape (batch, seq_len), defined elsewhere</span></span>
<span id="cb6-26"><a href="#cb6-26" aria-hidden="true" tabindex="-1"></a>    flop_budget<span class="op">=</span><span class="fl">0.5</span>,</span>
<span id="cb6-27"><a href="#cb6-27" aria-hidden="true" tabindex="-1"></a>    confidence_threshold<span class="op">=</span><span class="fl">0.85</span></span>
<span id="cb6-28"><a href="#cb6-28" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb6-29"><a href="#cb6-29" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-30"><a href="#cb6-30" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Prediction: </span><span class="sc">{</span>result[<span class="st">'logits'</span>]<span class="sc">.</span>argmax(<span class="op">-</span><span class="dv">1</span>)<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb6-31"><a href="#cb6-31" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Computational savings: </span><span class="sc">{</span>result[<span class="st">'computational_savings'</span>]<span class="sc">:.2%}</span><span class="ss">"</span>)</span>
<span id="cb6-32"><a href="#cb6-32" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Exited at layer: </span><span class="sc">{</span>result[<span class="st">'exit_layer'</span>]<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
</section>
</section>
<section id="advanced-features" class="level2">
<h2 class="anchored" data-anchor-id="advanced-features" id="advanced-features">Advanced Features</h2>
<section id="dynamic-confidence-thresholds" class="level3">
<h3 class="anchored" data-anchor-id="dynamic-confidence-thresholds" id="dynamic-confidence-thresholds">1. Dynamic Confidence Thresholds</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> DynamicThresholdStrategy:</span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a><span class="co">    Dynamically adjust confidence thresholds based on input characteristics</span></span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, base_threshold: <span class="bu">float</span> <span class="op">=</span> <span class="fl">0.8</span>):</span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.base_threshold <span class="op">=</span> base_threshold</span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> get_threshold(<span class="va">self</span>, input_ids: torch.Tensor, layer: <span class="bu">int</span>) <span class="op">-&gt;</span> <span class="bu">float</span>:</span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""</span></span>
<span id="cb7-10"><a href="#cb7-10" aria-hidden="true" tabindex="-1"></a><span class="co">        Compute dynamic threshold based on input and layer</span></span>
<span id="cb7-11"><a href="#cb7-11" aria-hidden="true" tabindex="-1"></a><span class="co">        """</span></span>
<span id="cb7-12"><a href="#cb7-12" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Example: Lower threshold for longer sequences</span></span>
<span id="cb7-13"><a href="#cb7-13" aria-hidden="true" tabindex="-1"></a>        seq_len <span class="op">=</span> input_ids.shape[<span class="dv">1</span>]</span>
<span id="cb7-14"><a href="#cb7-14" aria-hidden="true" tabindex="-1"></a>        length_factor <span class="op">=</span> <span class="fl">1.0</span> <span class="op">-</span> (seq_len <span class="op">-</span> <span class="dv">50</span>) <span class="op">/</span> <span class="dv">500</span>  <span class="co"># Adjust based on length</span></span>
<span id="cb7-15"><a href="#cb7-15" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-16"><a href="#cb7-16" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Example: Higher threshold for earlier layers</span></span>
<span id="cb7-17"><a href="#cb7-17" aria-hidden="true" tabindex="-1"></a>        layer_factor <span class="op">=</span> <span class="fl">1.0</span> <span class="op">+</span> (<span class="fl">0.1</span> <span class="op">*</span> (<span class="dv">6</span> <span class="op">-</span> layer))  <span class="co"># Stricter before layer 6, looser after</span></span>
<span id="cb7-18"><a href="#cb7-18" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-19"><a href="#cb7-19" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="bu">min</span>(<span class="bu">max</span>(<span class="va">self</span>.base_threshold <span class="op">*</span> length_factor <span class="op">*</span> layer_factor, <span class="fl">0.0</span>), <span class="fl">1.0</span>)  <span class="co"># keep it a valid probability</span></span></code></pre></div></div>
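<p>Before any clamping, the product of the two factors can fall outside the valid probability range. With <code>base_threshold=0.8</code> and a 50-token input, layer 0 gets <span class="math inline">\(0.8 \times 1.6 = 1.28\)</span>, a threshold that no softmax confidence can reach. Inputs longer than 550 tokens make the length factor negative. The small pure-Python mirror of the formula below clamps to [0, 1], which is an added assumption. It makes it easy to tabulate the schedule before relying on it.</p>

```python
def dynamic_threshold(seq_len: int, layer: int, base_threshold: float = 0.8,
                      clamp: bool = True) -> float:
    """Pure-Python mirror of DynamicThresholdStrategy.get_threshold."""
    length_factor = 1.0 - (seq_len - 50) / 500   # longer inputs -> lower threshold
    layer_factor = 1.0 + 0.1 * (6 - layer)        # earlier layers -> higher threshold
    threshold = base_threshold * length_factor * layer_factor
    if clamp:
        # Assumption: restrict to [0, 1] so the value stays a valid probability
        threshold = min(max(threshold, 0.0), 1.0)
    return threshold


for seq_len in (50, 300):
    for layer in (0, 3, 6, 9):
        raw = dynamic_threshold(seq_len, layer, clamp=False)
        clamped = dynamic_threshold(seq_len, layer)
        print(f"seq_len={seq_len} layer={layer}: raw={raw:.3f} clamped={clamped:.3f}")
```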
</section>
<section id="ensemble-early-exits" class="level3">
<h3 class="anchored" data-anchor-id="ensemble-early-exits" id="ensemble-early-exits">2. Ensemble Early Exits</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb8"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> EnsembleMatryoshka(nn.Module):</span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a><span class="co">    Ensemble multiple exit predictions for better accuracy</span></span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, base_model: MatryoshkaTransformer):</span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.base_model <span class="op">=</span> base_model</span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.ensemble_weights <span class="op">=</span> nn.Parameter(torch.ones(<span class="bu">len</span>(base_model.exit_layers) <span class="op">+</span> <span class="dv">1</span>))</span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, input_ids: torch.Tensor, <span class="op">**</span>kwargs) <span class="op">-&gt;</span> <span class="bu">dict</span>:</span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a>        outputs <span class="op">=</span> <span class="va">self</span>.base_model(input_ids, return_all_exits<span class="op">=</span><span class="va">True</span>, <span class="op">**</span>kwargs)</span>
<span id="cb8-12"><a href="#cb8-12" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-13"><a href="#cb8-13" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Ensemble all available predictions</span></span>
<span id="cb8-14"><a href="#cb8-14" aria-hidden="true" tabindex="-1"></a>        all_logits <span class="op">=</span> outputs[<span class="st">'exit_predictions'</span>] <span class="op">+</span> [outputs[<span class="st">'logits'</span>]]</span>
<span id="cb8-15"><a href="#cb8-15" aria-hidden="true" tabindex="-1"></a>        weights <span class="op">=</span> F.softmax(<span class="va">self</span>.ensemble_weights, dim<span class="op">=</span><span class="dv">0</span>)</span>
<span id="cb8-16"><a href="#cb8-16" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-17"><a href="#cb8-17" aria-hidden="true" tabindex="-1"></a>        ensemble_logits <span class="op">=</span> <span class="bu">sum</span>(w <span class="op">*</span> logits <span class="cf">for</span> w, logits <span class="kw">in</span> <span class="bu">zip</span>(weights, all_logits))</span>
<span id="cb8-18"><a href="#cb8-18" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-19"><a href="#cb8-19" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {</span>
<span id="cb8-20"><a href="#cb8-20" aria-hidden="true" tabindex="-1"></a>            <span class="op">**</span>outputs,</span>
<span id="cb8-21"><a href="#cb8-21" aria-hidden="true" tabindex="-1"></a>            <span class="st">'ensemble_logits'</span>: ensemble_logits</span>
<span id="cb8-22"><a href="#cb8-22" aria-hidden="true" tabindex="-1"></a>        }</span></code></pre></div></div>
</section>
</section>
<section id="performance-optimization-tips" class="level2">
<h2 class="anchored" data-anchor-id="performance-optimization-tips" id="performance-optimization-tips">Performance Optimization Tips</h2>
<ol type="1">
<li><strong>Layer Selection</strong>: Choose exit layers strategically. Every exit head adds a loss term that competes for the shared parameters, so too many exits can hurt training</li>
<li><strong>Loss Weighting</strong>: Start with lower weights for early exits and increase them gradually over training</li>
<li><strong>Confidence Calibration</strong>: Use temperature scaling to calibrate exit confidences</li>
<li><strong>Batch Processing</strong>: Group samples of similar complexity, since a batch often has to keep running until its hardest sample is confident enough to exit</li>
<li><strong>Caching</strong>: Cache intermediate representations for multiple exit strategies</li>
</ol>
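<p>Temperature scaling divides one exit head's logits by a single scalar <span class="math inline">\(T\)</span>, fitted on held-out data to minimize negative log-likelihood. It changes the confidences but never the argmax. The minimal NumPy sketch below assumes you have collected validation logits for one exit head. It uses a simple grid search for readability; an optimizer such as L-BFGS over <span class="math inline">\(T\)</span> could replace the grid.</p>

```python
import numpy as np


def log_softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable log-softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def nll(logits: np.ndarray, labels: np.ndarray, temperature: float) -> float:
    """Mean negative log-likelihood after dividing the logits by the temperature."""
    log_probs = log_softmax(logits / temperature)
    return float(-log_probs[np.arange(len(labels)), labels].mean())


def fit_temperature(logits: np.ndarray, labels: np.ndarray, grid=None) -> float:
    """Return the grid temperature with the lowest validation NLL for one exit head."""
    if grid is None:
        grid = np.linspace(0.5, 5.0, 91)
    losses = [nll(logits, labels, t) for t in grid]
    return float(grid[int(np.argmin(losses))])


# Toy validation set for one exit head: right about 75% of the time, yet very confident
rng = np.random.default_rng(0)
labels = rng.integers(0, 2, size=400)
predicted = np.where(rng.random(400) < 0.75, labels, 1 - labels)
logits = np.zeros((400, 2))
logits[np.arange(400), predicted] = 4.0   # softmax confidence of about 0.98 per sample

T = fit_temperature(logits, labels)
calibrated = np.exp(log_softmax(logits / T)).max(axis=-1).mean()
print(f"T = {T:.2f}, mean confidence after scaling = {calibrated:.2f}")
```

<p>At inference time, each exit head's logits would be divided by that head's own fitted <span class="math inline">\(T\)</span> before the softmax confidence is compared against the exit threshold.</p>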
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>Matryoshka Transformers offer a powerful way to build efficient models that can adapt their computational cost at inference time. The key to success is careful tuning of exit strategies, loss weights, and confidence thresholds for your specific use case.</p>
<p>This implementation provides a solid foundation that you can extend with additional features like cascaded exits, uncertainty estimation, or task-specific adaptations.</p>


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[The Mathematics Behind Matryoshka Transformers]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/models/matryoshka/matryoshka-math/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/models/matryoshka/matryoshka-math/</guid>
      <pubDate>Sat, 28 Jun 2025 00:00:00 GMT</pubDate>
      
      <category>research</category>
      <category>advanced</category>
      <content:encoded><![CDATA[






<section id="the-mathematics-behind-matryoshka-transformers" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/models/matryoshka/matryoshka-math/matmath.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>Matryoshka Transformers represent a significant advancement in adaptive neural network architectures. Inspired by Russian nesting (Matryoshka) dolls, they nest smaller models within larger ones. This enables dynamic inference with variable computational cost while maintaining high performance across different resource constraints.</p>
</section>
<section id="core-mathematical-framework" class="level2">
<h2 class="anchored" data-anchor-id="core-mathematical-framework" id="core-mathematical-framework">Core Mathematical Framework</h2>
<section id="nested-representation-learning" class="level3">
<h3 class="anchored" data-anchor-id="nested-representation-learning" id="nested-representation-learning">1. Nested Representation Learning</h3>
<p>The fundamental principle of Matryoshka Transformers lies in learning nested representations where smaller models are subsets of larger ones. Given a transformer with hidden dimension <span class="math inline">\(d\)</span>, we define a sequence of nested dimensions:</p>
<p><span class="math display">\[
d_1 &lt; d_2 &lt; d_3 &lt; \ldots &lt; d_k = d
\]</span></p>
<p>For each layer <span class="math inline">\(l\)</span> and nesting level <span class="math inline">\(i\)</span>, the hidden state <span class="math inline">\(h^{(l,i)}\)</span> is defined as:</p>
<p><span class="math display">\[
h^{(l,i)} = h^{(l)}[:d_i]
\]</span></p>
<p>where <span class="math inline">\(h^{(l)}[:d_i]\)</span> represents the first <span class="math inline">\(d_i\)</span> dimensions of the full hidden state <span class="math inline">\(h^{(l)}\)</span>.</p>
</section>
<section id="multi-scale-attention-mechanism" class="level3">
<h3 class="anchored" data-anchor-id="multi-scale-attention-mechanism" id="multi-scale-attention-mechanism">2. Multi-Scale Attention Mechanism</h3>
<p>The attention mechanism is modified to operate across multiple scales simultaneously. For a given layer, the multi-scale attention is computed as:</p>
<p><span class="math display">\[
\text{MultiScaleAttention}(Q, K, V) = \text{Concat}(\text{head}_1, \text{head}_2, \ldots, \text{head}_k)
\]</span></p>
<p>where each head <span class="math inline">\(\text{head}_i\)</span> operates on the nested representation of dimension <span class="math inline">\(d_i\)</span>:</p>
<p><span class="math display">\[
\text{head}_i = \text{Attention}(Q[:d_i], K[:d_i], V[:d_i])
\]</span></p>
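<p>Each head applies the scaled dot-product mechanism to its slice. The scaling factor is the slice width <span class="math inline">\(d_i\)</span>, not the largest nesting dimension <span class="math inline">\(d_k = d\)</span>:</p>
<p><span class="math display">\[
\text{Attention}(Q[:d_i], K[:d_i], V[:d_i]) = \text{softmax}\left(\frac{Q[:d_i]\,K[:d_i]^\top}{\sqrt{d_i}}\right)V[:d_i]
\]</span></p>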
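<p>The per-head computation can be sketched in NumPy. Here <span class="math inline">\(Q\)</span>, <span class="math inline">\(K\)</span>, and <span class="math inline">\(V\)</span> are assumed to be already-projected matrices of shape <code>(seq_len, d)</code>. The slice <span class="math inline">\(Q[:d_i]\)</span> is therefore read as the first <span class="math inline">\(d_i\)</span> feature columns, and each head scales by the width of its own slice. Because the heads have different widths, the concatenation has <span class="math inline">\(d_1 + \cdots + d_k\)</span> columns. A model would typically follow it with an output projection back to <span class="math inline">\(d\)</span>, which is not shown.</p>

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    shifted = x - x.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


def attention(q: np.ndarray, k: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Scaled dot-product attention, scaled by the key width of this slice."""
    scores = q @ k.T / np.sqrt(k.shape[-1])
    return softmax(scores) @ v


def multi_scale_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray, dims: list) -> np.ndarray:
    """head_i = Attention(Q[:, :d_i], K[:, :d_i], V[:, :d_i]); heads are concatenated."""
    heads = [attention(Q[:, :d_i], K[:, :d_i], V[:, :d_i]) for d_i in dims]
    return np.concatenate(heads, axis=-1)


rng = np.random.default_rng(0)
seq_len, d = 5, 16
Q, K, V = (rng.normal(size=(seq_len, d)) for _ in range(3))
out = multi_scale_attention(Q, K, V, dims=[4, 8, 16])
print(out.shape)
```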
</section>
<section id="nested-loss-function" class="level3">
<h3 class="anchored" data-anchor-id="nested-loss-function" id="nested-loss-function">3. Nested Loss Function</h3>
<p>The training objective incorporates losses at multiple scales to ensure that smaller nested models perform well independently. The total loss is:</p>
<p><span class="math display">\[
\mathcal{L}_{\text{total}} = \sum_{i=1}^k \alpha_i \cdot \mathcal{L}(f_i(x), y)
\]</span></p>
<p>where:</p>
<ul>
<li><span class="math inline">\(f_i(x)\)</span> is the prediction using the first <span class="math inline">\(d_i\)</span> dimensions</li>
<li><span class="math inline">\(\mathcal{L}(f_i(x), y)\)</span> is the task-specific loss (e.g., cross-entropy)</li>
<li><span class="math inline">\(\alpha_i\)</span> are weighting coefficients that balance the importance of different scales</li>
</ul>
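<p>Below is a minimal NumPy sketch of this objective for classification. It assumes each <span class="math inline">\(f_i\)</span> is a linear head that reads the first <span class="math inline">\(d_i\)</span> dimensions of a shared representation. Each head also reuses the first <span class="math inline">\(d_i\)</span> rows of one shared weight matrix. That weight sharing is an illustrative choice; separate heads per level would plug in the same way.</p>

```python
import numpy as np


def cross_entropy(logits: np.ndarray, labels: np.ndarray) -> float:
    """Mean cross-entropy from raw logits (numerically stable)."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return float(-log_probs[np.arange(len(labels)), labels].mean())


def nested_loss(h: np.ndarray, W: np.ndarray, labels: np.ndarray,
                dims: list, alphas: list) -> tuple:
    """Return (total, per-level losses) for L_total = sum_i alpha_i * L(f_i(x), y)."""
    per_level = []
    for d_i in dims:
        logits_i = h[:, :d_i] @ W[:d_i]   # f_i(x) reads only the first d_i dimensions
        per_level.append(cross_entropy(logits_i, labels))
    total = sum(alpha * loss for alpha, loss in zip(alphas, per_level))
    return total, per_level


rng = np.random.default_rng(0)
h = rng.normal(size=(32, 64))          # batch of 32 hidden states, d = 64
W = rng.normal(size=(64, 10)) * 0.1    # shared classifier weights, 10 classes
y = rng.integers(0, 10, size=32)

total, levels = nested_loss(h, W, y, dims=[8, 16, 32, 64], alphas=[1.0, 1.0, 1.0, 1.0])
print(round(total, 3), [round(loss, 3) for loss in levels])
```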
</section>
<section id="progressive-training-strategy" class="level3">
<h3 class="anchored" data-anchor-id="progressive-training-strategy" id="progressive-training-strategy">4. Progressive Training Strategy</h3>
<p>The training process follows a progressive strategy where smaller models are trained first, and larger models build upon them. The parameter update rule is:</p>
<p><span class="math display">\[
\theta_i^{(t+1)} = \theta_i^{(t)} - \eta \cdot \nabla_{\theta_i} \left[ \sum_{j=i}^k \alpha_j \cdot \mathcal{L}(f_j(x), y) \right]
\]</span></p>
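<p>Here <span class="math inline">\(\theta_i\)</span> denotes the parameters first introduced at level <span class="math inline">\(i\)</span>, i.e. those acting on dimensions <span class="math inline">\(d_{i-1}+1\)</span> through <span class="math inline">\(d_i\)</span> (with <span class="math inline">\(d_0 = 0\)</span>). Every model <span class="math inline">\(f_j\)</span> with <span class="math inline">\(j \geq i\)</span> uses them, so the sum runs from <span class="math inline">\(j = i\)</span>. As a result, parameters contributing to smaller models receive gradients from all larger models that contain them.</p>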
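<p>The gradient-sharing argument can be checked on a toy model. The sketch below uses a hypothetical linear regressor <span class="math inline">\(f_j(x) = w[:d_j]^\top x[:d_j]\)</span> with squared loss, chosen only because its gradient has a closed form. Coordinate <span class="math inline">\(m\)</span> (0-indexed) of <span class="math inline">\(w\)</span> receives one term from every level with <span class="math inline">\(d_j &gt; m\)</span>. A finite-difference check confirms the closed form.</p>

```python
import numpy as np


def nested_grad(w: np.ndarray, x: np.ndarray, y: float, dims: list, alphas: list) -> np.ndarray:
    """Gradient of sum_j alpha_j * (w[:d_j] @ x[:d_j] - y)**2 with respect to w."""
    grad = np.zeros_like(w)
    for d_j, alpha_j in zip(dims, alphas):
        residual = w[:d_j] @ x[:d_j] - y
        grad[:d_j] += alpha_j * 2.0 * residual * x[:d_j]   # only the first d_j coords are touched
    return grad


def total_loss(w: np.ndarray, x: np.ndarray, y: float, dims: list, alphas: list) -> float:
    return sum(a * (w[:d] @ x[:d] - y) ** 2 for d, a in zip(dims, alphas))


rng = np.random.default_rng(0)
w, x = rng.normal(size=6), rng.normal(size=6)
dims, alphas = [2, 4, 6], [1.0, 0.5, 0.25]
g = nested_grad(w, x, y=1.0, dims=dims, alphas=alphas)

# Central finite differences, one coordinate at a time
eps = 1e-6
fd = np.array([(total_loss(w + eps * e, x, 1.0, dims, alphas)
                - total_loss(w - eps * e, x, 1.0, dims, alphas)) / (2 * eps)
               for e in np.eye(6)])
print(np.allclose(g, fd, atol=1e-5))
```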
</section>
</section>
<section id="mathematical-properties" class="level2">
<h2 class="anchored" data-anchor-id="mathematical-properties" id="mathematical-properties">Mathematical Properties</h2>
<section id="representation-efficiency" class="level3">
<h3 class="anchored" data-anchor-id="representation-efficiency" id="representation-efficiency">1. Representation Efficiency</h3>
<p>The nested structure provides computational efficiency with a complexity reduction factor. For a model with <span class="math inline">\(n\)</span> parameters and nesting levels with dimensions <span class="math inline">\([d_1, d_2, \ldots, d_k]\)</span>, the computational complexity for the smallest model is:</p>
<p><span class="math display">\[
O\left(n \cdot \frac{d_1}{d}\right) \quad \text{compared to} \quad O(n) \quad \text{for the full model}
\]</span></p>
</section>
<section id="information-preservation" class="level3">
<h3 class="anchored" data-anchor-id="information-preservation" id="information-preservation">2. Information Preservation</h3>
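<p>Information preservation follows from the nesting itself. For <span class="math inline">\(i &lt; j\)</span>, <span class="math inline">\(h^{(l,i)}\)</span> is obtained from <span class="math inline">\(h^{(l,j)}\)</span> by keeping its first <span class="math inline">\(d_i\)</span> coordinates, which is a deterministic function. The data processing inequality therefore gives:</p>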
<p><span class="math display">\[
I(Y; h^{(l,i)}) \leq I(Y; h^{(l,j)}) \quad \text{for } i &lt; j
\]</span></p>
<p>where <span class="math inline">\(I(\cdot\,;\,\cdot)\)</span> denotes mutual information between the representation and target <span class="math inline">\(Y\)</span>.</p>
</section>
<section id="gradient-flow-analysis" class="level3">
<h3 class="anchored" data-anchor-id="gradient-flow-analysis" id="gradient-flow-analysis">3. Gradient Flow Analysis</h3>
<p>The gradient flow through nested structures follows a hierarchical pattern. For parameters <span class="math inline">\(\theta_i\)</span> contributing to representation dimension <span class="math inline">\(d_i\)</span>, the gradient magnitude satisfies:</p>
<p><span class="math display">\[
\|\nabla_{\theta_i} \mathcal{L}_{\text{total}}\|_2 \geq \alpha_i \cdot \|\nabla_{\theta_i} \mathcal{L}(f_i(x), y)\|_2
\]</span></p>
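<p>Strictly, this bound requires that the gradients from the larger levels do not oppose the level-<span class="math inline">\(i\)</span> gradient. Write <span class="math inline">\(g_i = \alpha_i \nabla_{\theta_i} \mathcal{L}(f_i(x), y)\)</span> and let <span class="math inline">\(g_{&gt;i}\)</span> be the sum of the remaining terms. The bound then holds whenever <span class="math inline">\(\langle g_i, g_{&gt;i} \rangle \geq 0\)</span>, because <span class="math inline">\(\|g_i + g_{&gt;i}\|_2^2 = \|g_i\|_2^2 + 2\langle g_i, g_{&gt;i} \rangle + \|g_{&gt;i}\|_2^2\)</span>. Under that condition, smaller models receive at least the gradient signal of their own loss during training.</p>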
</section>
</section>
<section id="layer-wise-mathematical-operations" class="level2">
<h2 class="anchored" data-anchor-id="layer-wise-mathematical-operations" id="layer-wise-mathematical-operations">Layer-wise Mathematical Operations</h2>
<section id="nested-feed-forward-networks" class="level3">
<h3 class="anchored" data-anchor-id="nested-feed-forward-networks" id="nested-feed-forward-networks">1. Nested Feed-Forward Networks</h3>
<p>The feed-forward network in each transformer layer is modified to support nested computation:</p>
<p><span class="math display">\[
\text{FFN}^{(i)}(x) = \max(0,\ x W_1^{(i)} + b_1^{(i)}) W_2^{(i)} + b_2^{(i)}
\]</span></p>
<p>where <span class="math inline">\(W_1^{(i)} \in \mathbb{R}^{d_i \times d_{\text{mid}}}\)</span> and <span class="math inline">\(W_2^{(i)} \in \mathbb{R}^{d_{\text{mid}} \times d_i}\)</span> are the weight matrices for the <span class="math inline">\(i\)</span>-th nesting level.</p>
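<p>One natural way to realize the per-level weights is to take them as slices of a single full-size FFN. This is assumed here rather than stated above: <span class="math inline">\(W_1^{(i)} = W_1[:d_i, :]\)</span>, <span class="math inline">\(W_2^{(i)} = W_2[:, :d_i]\)</span>, the full <span class="math inline">\(b_1\)</span>, and <span class="math inline">\(b_2[:d_i]\)</span>. The hidden width <span class="math inline">\(d_{\text{mid}}\)</span> stays fixed, so the FFN parameter count grows linearly in <span class="math inline">\(d_i\)</span>.</p>

```python
import numpy as np


def nested_ffn(x: np.ndarray, W1: np.ndarray, b1: np.ndarray,
               W2: np.ndarray, b2: np.ndarray, d_i: int) -> np.ndarray:
    """FFN^{(i)}(x) = max(0, x W1^{(i)} + b1) W2^{(i)} + b2^{(i)} using slices of the full weights."""
    hidden = np.maximum(0.0, x[..., :d_i] @ W1[:d_i, :] + b1)   # (..., d_mid); d_mid is shared
    return hidden @ W2[:, :d_i] + b2[:d_i]                        # (..., d_i)


rng = np.random.default_rng(0)
d, d_mid = 16, 64
W1, b1 = rng.normal(size=(d, d_mid)), rng.normal(size=d_mid)
W2, b2 = rng.normal(size=(d_mid, d)), rng.normal(size=d)
x = rng.normal(size=(3, d))

for d_i in (4, 8, 16):
    out = nested_ffn(x, W1, b1, W2, b2, d_i)
    n_params = d_i * d_mid + d_mid + d_mid * d_i + d_i   # W1 slice + b1 + W2 slice + b2 slice
    print(d_i, out.shape, n_params)
```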
</section>
<section id="layer-normalization-adaptation" class="level3">
<h3 class="anchored" data-anchor-id="layer-normalization-adaptation" id="layer-normalization-adaptation">2. Layer Normalization Adaptation</h3>
<p>Layer normalization is applied independently at each nesting level:</p>
<p><span class="math display">\[
\text{LayerNorm}^{(i)}(x) = \gamma_i \cdot \frac{x - \mu_i}{\sigma_i} + \beta_i
\]</span></p>
<p>where <span class="math inline">\(\mu_i\)</span> and <span class="math inline">\(\sigma_i\)</span> are computed over the first <span class="math inline">\(d_i\)</span> dimensions.</p>
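<p>Below is a NumPy sketch of the level-<span class="math inline">\(i\)</span> normalization. It assumes <span class="math inline">\(\gamma_i\)</span> and <span class="math inline">\(\beta_i\)</span> are the first <span class="math inline">\(d_i\)</span> entries of shared full-size vectors, and it adds the usual small <span class="math inline">\(\epsilon\)</span> inside the square root. Because the statistics are recomputed over the slice, the result generally differs from slicing the output of the full-width LayerNorm.</p>

```python
import numpy as np


def nested_layer_norm(x: np.ndarray, gamma: np.ndarray, beta: np.ndarray,
                      d_i: int, eps: float = 1e-5) -> np.ndarray:
    """LayerNorm^{(i)}: statistics and affine parameters use only the first d_i dims."""
    x_i = x[..., :d_i]
    mu = x_i.mean(axis=-1, keepdims=True)
    sigma = np.sqrt(x_i.var(axis=-1, keepdims=True) + eps)
    return gamma[:d_i] * (x_i - mu) / sigma + beta[:d_i]


rng = np.random.default_rng(0)
d = 16
x = rng.normal(loc=2.0, size=(4, d))
gamma, beta = np.ones(d), np.zeros(d)

level = nested_layer_norm(x, gamma, beta, d_i=8)
sliced_full = nested_layer_norm(x, gamma, beta, d_i=d)[:, :8]
print(level.mean(axis=-1).round(6), np.allclose(level, sliced_full))
```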
</section>
<section id="positional-encoding" class="level3">
<h3 class="anchored" data-anchor-id="positional-encoding" id="positional-encoding">3. Positional Encoding</h3>
<p>Positional encodings are extended to support nested dimensions:</p>
<p><span class="math display">\[
\text{PE}^{(i)}(\text{pos}, 2j) = \sin\left(\frac{\text{pos}}{10000^{\frac{2j}{d_i}}}\right)
\]</span> <span class="math display">\[
\text{PE}^{(i)}(\text{pos}, 2j+1) = \cos\left(\frac{\text{pos}}{10000^{\frac{2j}{d_i}}}\right)
\]</span></p>
<p>for <span class="math inline">\(j \in [0, \frac{d_i}{2})\)</span></p>
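<p>A quick NumPy check of these formulas follows. The exponent divides by <span class="math inline">\(d_i\)</span> rather than <span class="math inline">\(d\)</span>, so <span class="math inline">\(\text{PE}^{(i)}\)</span> is not the first <span class="math inline">\(d_i\)</span> columns of the full encoding. Only the <span class="math inline">\(j = 0\)</span> pair coincides, because its frequency is 1 at every level. An implementation that wants strictly nested encodings would instead slice a single encoding computed with <span class="math inline">\(d\)</span>.</p>

```python
import numpy as np


def sinusoidal_pe(n_positions: int, d_i: int) -> np.ndarray:
    """PE^{(i)}(pos, 2j) = sin(pos / 10000^(2j/d_i)), PE^{(i)}(pos, 2j+1) = cos(...); d_i even."""
    pos = np.arange(n_positions)[:, None]      # (n_positions, 1)
    j = np.arange(d_i // 2)[None, :]            # (1, d_i / 2)
    angles = pos / np.power(10000.0, 2 * j / d_i)
    pe = np.zeros((n_positions, d_i))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe


d, d_i = 64, 16
pe_full, pe_small = sinusoidal_pe(50, d), sinusoidal_pe(50, d_i)
print(np.allclose(pe_small[:, :2], pe_full[:, :2]))   # the j = 0 pair matches
print(np.allclose(pe_small, pe_full[:, :d_i]))        # the remaining columns do not
```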
</section>
</section>
<section id="optimization-considerations" class="level2">
<h2 class="anchored" data-anchor-id="optimization-considerations" id="optimization-considerations">Optimization Considerations</h2>
<section id="learning-rate-scheduling" class="level3">
<h3 class="anchored" data-anchor-id="learning-rate-scheduling" id="learning-rate-scheduling">1. Learning Rate Scheduling</h3>
<p>Different nesting levels may require different learning rates. The adaptive learning rate is:</p>
<p><span class="math display">\[
\eta_i = \eta_0 \cdot \sqrt{\frac{d}{d_i}} \cdot \lambda_i
\]</span></p>
<p>where <span class="math inline">\(\lambda_i\)</span> is a level-specific scaling factor.</p>
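<p>As a worked instance, take <span class="math inline">\(d = 512\)</span> and <span class="math inline">\(d_i = 128\)</span>, with the illustrative values <span class="math inline">\(\eta_0 = 10^{-4}\)</span> and <span class="math inline">\(\lambda_i = 1\)</span>:</p>
<p><span class="math display">\[
\eta_i = 10^{-4} \cdot \sqrt{\frac{512}{128}} \cdot 1 = 2 \times 10^{-4}
\]</span></p>
<p>For equal <span class="math inline">\(\lambda_i\)</span>, this rule therefore gives the smallest levels the largest steps.</p>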
</section>
<section id="regularization" class="level3">
<h3 class="anchored" data-anchor-id="regularization" id="regularization">2. Regularization</h3>
<p>Regularization is applied to encourage similarity between nested representations:</p>
<p><span class="math display">\[
\mathcal{L}_{\text{reg}} = \sum_{i=1}^{k-1} \beta \cdot \| h^{(l,i+1)}[:d_i] - h^{(l,i)} \|_2^2
\]</span></p>
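<p>This term encourages consistency across scales. It is nonzero only when <span class="math inline">\(h^{(l,i)}\)</span> is produced by running the level-<span class="math inline">\(i\)</span> sub-network on its own, for instance with LayerNorm statistics computed over <span class="math inline">\(d_i\)</span> dimensions. Under the strict truncation definition <span class="math inline">\(h^{(l,i)} = h^{(l)}[:d_i]\)</span>, the two states coincide and the penalty vanishes.</p>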
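<p>Below is a NumPy sketch of the penalty for a single layer <span class="math inline">\(l\)</span>. The input format, a list holding one hidden-state array per nesting level, is hypothetical. The squared norm is summed over the batch.</p>

```python
import numpy as np


def nested_consistency_loss(level_states: list, beta: float = 0.1) -> float:
    """L_reg = sum_i beta * ||h^{(l,i+1)}[:d_i] - h^{(l,i)}||_2^2 for one layer l.

    level_states[i] has shape (..., d_i) with d_1 < d_2 < ... < d_k.
    """
    total = 0.0
    for smaller, larger in zip(level_states[:-1], level_states[1:]):
        d_i = smaller.shape[-1]
        diff = larger[..., :d_i] - smaller
        total += beta * float(np.sum(diff ** 2))   # summed over batch and features
    return total


rng = np.random.default_rng(0)
full = rng.normal(size=(2, 16))
prefix_states = [full[:, :4], full[:, :8], full]   # exact prefixes of one state
noisy_states = [s + 0.01 * rng.normal(size=s.shape) for s in prefix_states]
print(nested_consistency_loss(prefix_states), nested_consistency_loss(noisy_states))
```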
</section>
</section>
<section id="theoretical-analysis" class="level2">
<h2 class="anchored" data-anchor-id="theoretical-analysis" id="theoretical-analysis">Theoretical Analysis</h2>
<section id="approximation-theory" class="level3">
<h3 class="anchored" data-anchor-id="approximation-theory" id="approximation-theory">1. Approximation Theory</h3>
<p>The approximation error for a nested model of dimension <span class="math inline">\(d_i\)</span> is bounded by:</p>
<p><span class="math display">\[
|f(x) - f_i(x)| \leq C \cdot \sqrt{\frac{d - d_i}{d}} \cdot \|x\|_2
\]</span></p>
<p>where <span class="math inline">\(C\)</span> is a problem-dependent constant.</p>
</section>
<section id="generalization-bounds" class="level3">
<h3 class="anchored" data-anchor-id="generalization-bounds" id="generalization-bounds">2. Generalization Bounds</h3>
<p>The generalization bound for nested models follows:</p>
<p><span class="math display">\[
P\left(|R(f_i) - \hat{R}(f_i)| &gt; \varepsilon\right) \leq 2 \exp\left(-\frac{2n \varepsilon^2}{d_i/d}\right)
\]</span></p>
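<p>where <span class="math inline">\(R(f_i)\)</span> is the true risk, <span class="math inline">\(\hat{R}(f_i)\)</span> is the empirical risk, and <span class="math inline">\(n\)</span> here denotes the number of training samples, not the parameter count used in the efficiency analysis above.</p>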
</section>
</section>
<section id="implementation-considerations" class="level2">
<h2 class="anchored" data-anchor-id="implementation-considerations" id="implementation-considerations">Implementation Considerations</h2>
<section id="memory-efficiency" class="level3">
<h3 class="anchored" data-anchor-id="memory-efficiency" id="memory-efficiency">1. Memory Efficiency</h3>
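<p>Because every nested model reuses part of the largest model's weights, only the largest model needs to be stored, yet inference remains possible at multiple scales. Each layer holds weight matrices whose size is proportional to <span class="math inline">\(d \times d\)</span>, so the parameter memory is:</p>
<p><span class="math display">\[
\text{Memory} = O(d^2 \cdot L) \quad \text{where } L \text{ is the number of layers}
\]</span></p>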
</section>
<section id="computational-flexibility" class="level3">
<h3 class="anchored" data-anchor-id="computational-flexibility" id="computational-flexibility">2. Computational Flexibility</h3>
<p>The inference cost can be dynamically adjusted based on computational budget:</p>
<p><span class="math display">\[
\text{FLOPs}^{(i)} = O(d_i^2 \cdot L \cdot N)
\]</span></p>
<p>where <span class="math inline">\(N\)</span> is the sequence length.</p>
</section>
</section>
<section id="applications-and-extensions" class="level2">
<h2 class="anchored" data-anchor-id="applications-and-extensions" id="applications-and-extensions">Applications and Extensions</h2>
<section id="adaptive-inference" class="level3">
<h3 class="anchored" data-anchor-id="adaptive-inference" id="adaptive-inference">1. Adaptive Inference</h3>
<p>The mathematical framework enables adaptive inference where the model can exit early based on confidence measures:</p>
<p><span class="math display">\[
\text{Exit\_Condition} = P(\hat{y}_i \mid x) &gt; \tau_i
\]</span></p>
<p>where <span class="math inline">\(\tau_i\)</span> is a confidence threshold for level <span class="math inline">\(i\)</span>.</p>
</section>
<section id="distillation-integration" class="level3">
<h3 class="anchored" data-anchor-id="distillation-integration" id="distillation-integration">2. Distillation Integration</h3>
<p>Knowledge distillation can be integrated into the nested framework:</p>
<p><span class="math display">\[
\mathcal{L}_{\text{distill}} = \sum_{i=1}^{k-1} \gamma \cdot \text{KL}\left(\text{softmax}\left(\frac{z_k}{T}\right),\ \text{softmax}\left(\frac{z_i}{T}\right)\right)
\]</span></p>
<p>where <span class="math inline">\(z_i\)</span> are the logits from the <span class="math inline">\(i\)</span>-th level and <span class="math inline">\(T\)</span> is the temperature parameter.</p>
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>Matryoshka Transformers provide a mathematical framework for building adaptive neural networks with nested computational capabilities. The formulation is designed to support efficient training and flexible inference, and the bounds above indicate how performance is expected to vary across scales. This architecture is a step toward more efficient and adaptable transformer models for real-world applications with varying computational constraints.</p>
</section>
<section id="further-reading" class="level2">
<h2 class="anchored" data-anchor-id="further-reading" id="further-reading">Further Reading</h2>
<ul>
<li>Progressive Neural Architecture Search</li>
<li>Adaptive Neural Networks</li>
<li>Multi-Scale Deep Learning</li>
<li>Efficient Transformer Architectures</li>
</ul>



</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Matryoshka Transformer for Vision Language Models]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/models/matryoshka/matryoshka-transformer/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/models/matryoshka/matryoshka-transformer/</guid>
      <pubDate>Sat, 28 Jun 2025 00:00:00 GMT</pubDate>
      
      <category>research</category>
      <category>advanced</category>
      <content:encoded><![CDATA[






<section id="matryoshka-transformer-for-vision-language-models" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/models/matryoshka/matryoshka-transformer/matryoshka.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>The Matryoshka Transformer is an architecture for vision language models (VLMs) built around nested representations, named after the Russian Matryoshka dolls in which each doll contains a smaller one. It targets one of the fundamental challenges in multimodal AI: processing and integrating visual and textual information efficiently at multiple scales and resolutions.</p>
<p>Just as each doll sits inside a larger one, the Matryoshka Transformer organizes its computation as a nested hierarchy, which allows flexible and adaptive processing of multimodal inputs. As a result, a single model can serve different computational budgets while maintaining competitive performance across tasks.</p>
</section>
<section id="core-architecture" class="level2">
<h2 class="anchored" data-anchor-id="core-architecture" id="core-architecture">Core Architecture</h2>
<section id="nested-representation-learning" class="level3">
<h3 class="anchored" data-anchor-id="nested-representation-learning" id="nested-representation-learning">Nested Representation Learning</h3>
<p>The Matryoshka Transformer’s primary innovation lies in its ability to learn nested representations at multiple granularities simultaneously. Unlike traditional transformers that process information at a fixed resolution, this architecture creates a hierarchy of representations where each level contains increasingly detailed information.</p>
<p>The model operates on the principle that useful representations can be extracted at various levels of detail. A coarse representation might capture global semantic information about an image and its associated text, while finer representations preserve local details and nuanced relationships between visual and textual elements.</p>
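<p>The article does not fix how the nested levels are realized. One common approach, assumed here, is to train an embedding so that every prefix of the vector is itself a usable representation: the first few coordinates carry coarse information and later coordinates add detail. The sketch below only shows the mechanics of reading out the nested views; the embeddings are random, not trained.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
full = rng.normal(size=(3, 512))   # hypothetical: 3 items with 512-d nested embeddings
dims = [64, 128, 256, 512]         # hypothetical nesting granularities

def nested_views(emb, dims):
    """Coarse-to-fine views: level i keeps the first dims[i] coordinates, L2-normalized."""
    views = []
    for d_i in dims:
        v = emb[..., :d_i]
        views.append(v / np.linalg.norm(v, axis=-1, keepdims=True))
    return views

for d_i, v in zip(dims, nested_views(full, dims)):
    sim = v[0] @ v[1]  # cosine similarity of items 0 and 1 at this granularity
    print(d_i, v.shape, round(float(sim), 3))
```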
</section>
<section id="multi-scale-processing" class="level3">
<h3 class="anchored" data-anchor-id="multi-scale-processing" id="multi-scale-processing">Multi-Scale Processing</h3>
<p>The architecture implements multi-scale processing through a series of nested attention mechanisms. Each “doll” in the Matryoshka structure corresponds to a different scale of processing:</p>
<ul>
<li><strong>Outer layers</strong> handle global context and high-level semantic relationships</li>
<li><strong>Middle layers</strong> process regional features and cross-modal alignments</li>
<li><strong>Inner layers</strong> focus on fine-grained details and local feature interactions</li>
</ul>
<p>This hierarchical approach allows the model to adaptively allocate computational resources based on the complexity of the input and the requirements of the downstream task.</p>
</section>
<section id="adaptive-computation" class="level3">
<h3 class="anchored" data-anchor-id="adaptive-computation" id="adaptive-computation">Adaptive Computation</h3>
<p>One of the key advantages of the Matryoshka Transformer is its support for adaptive computation. The nested structure enables early exit strategies where simpler inputs can be processed using only the outer layers, while complex multimodal scenarios can leverage the full depth of the nested architecture.</p>
<p>This adaptive capability is particularly valuable in real-world applications where computational resources may be limited or where different levels of accuracy are acceptable for different types of queries.</p>
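<p>When the budget is known ahead of time (for example, per device), the level can be chosen statically instead of per input. A minimal sketch, assuming each nested level is a narrower sub-network whose cost grows roughly with the square of its width; the widths are hypothetical.</p>

```python
def pick_level(dims, budget_fraction):
    """Largest nested width whose relative cost (d_i / d)^2 fits within the budget.

    Falls back to the smallest level if nothing fits.
    """
    d = dims[-1]
    feasible = [d_i for d_i in dims if (d_i / d) ** 2 <= budget_fraction]
    return max(feasible) if feasible else dims[0]

dims = [128, 256, 512, 1024]
print(pick_level(dims, 0.3))   # → 512
print(pick_level(dims, 1.0))   # → 1024
print(pick_level(dims, 0.01))  # → 128
```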
</section>
</section>
<section id="vision-language-integration" class="level2">
<h2 class="anchored" data-anchor-id="vision-language-integration" id="vision-language-integration">Vision-Language Integration</h2>
<section id="cross-modal-attention-mechanisms" class="level3">
<h3 class="anchored" data-anchor-id="cross-modal-attention-mechanisms" id="cross-modal-attention-mechanisms">Cross-Modal Attention Mechanisms</h3>
<p>The Matryoshka Transformer employs sophisticated cross-modal attention mechanisms that operate at each level of the nested hierarchy. These mechanisms enable the model to establish correspondences between visual and textual elements at multiple scales:</p>
<ul>
<li><strong>Global attention</strong> links high-level concepts between images and text</li>
<li><strong>Regional attention</strong> connects specific image regions with relevant text segments</li>
<li><strong>Local attention</strong> establishes fine-grained correspondences between visual features and individual words or phrases</li>
</ul>
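<p>The article does not specify how the three attention scales are implemented. As one possible building block, the sketch below runs cross-attention from text tokens to image patches at a chosen nested width <span class="math inline">\(d_i\)</span>, using only the first <span class="math inline">\(d_i\)</span> feature dimensions. Learned query, key, and value projections are omitted for brevity, and all shapes are hypothetical.</p>

```python
import numpy as np

def cross_attention(text, image, d_i):
    """Text tokens attend over image patches using only the first d_i feature dims.

    text: (T, d) token features, image: (P, d) patch features.
    """
    q, k, v = text[:, :d_i], image[:, :d_i], image[:, :d_i]
    scores = q @ k.T / np.sqrt(d_i)                  # (T, P) text-to-patch affinities
    scores -= scores.max(axis=-1, keepdims=True)     # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)   # each token's weights sum to 1
    return weights @ v, weights                      # (T, d_i) fused features

rng = np.random.default_rng(0)
text, image = rng.normal(size=(5, 256)), rng.normal(size=(49, 256))
for d_i in (64, 256):
    out, w = cross_attention(text, image, d_i)
    print(d_i, out.shape, w.shape)
```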
</section>
<section id="hierarchical-feature-fusion" class="level3">
<h3 class="anchored" data-anchor-id="hierarchical-feature-fusion" id="hierarchical-feature-fusion">Hierarchical Feature Fusion</h3>
<p>Feature fusion in the Matryoshka Transformer occurs hierarchically, with information flowing both within and between the nested levels. This design enables the model to build rich, multi-scale representations that capture both global context and local details.</p>
<p>The hierarchical fusion process ensures that global context informs local processing while local details can influence global understanding, creating a more coherent and comprehensive multimodal representation.</p>
</section>
</section>
<section id="training-methodology" class="level2">
<h2 class="anchored" data-anchor-id="training-methodology" id="training-methodology">Training Methodology</h2>
<section id="multi-objective-learning" class="level3">
<h3 class="anchored" data-anchor-id="multi-objective-learning" id="multi-objective-learning">Multi-Objective Learning</h3>
<p>Training a Matryoshka Transformer involves optimizing multiple objectives simultaneously across different levels of the nested hierarchy. This multi-objective approach ensures that each level of the architecture learns meaningful representations appropriate to its scale.</p>
<p>The training process typically involves:</p>
<ul>
<li><strong>Reconstruction objectives</strong> at each level to ensure information preservation</li>
<li><strong>Cross-modal alignment objectives</strong> to maintain correspondence between vision and language</li>
<li><strong>Task-specific objectives</strong> for downstream applications</li>
<li><strong>Efficiency objectives</strong> to encourage effective use of computational resources</li>
</ul>
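<p>In the simplest case, these objectives are combined into one scalar loss by a weighted sum over levels and objectives. The sketch below assumes that form; the objective names, per-level values, and weights are all hypothetical and would be tuned in practice.</p>

```python
def total_loss(level_losses, weights):
    """Weighted sum of per-level objectives.

    level_losses: one dict per nested level, mapping objective name to value.
    weights: objective name -> scalar weight.
    """
    return sum(weights[name] * value
               for losses in level_losses
               for name, value in losses.items())

levels = [
    {"recon": 0.8, "align": 1.2, "task": 2.0, "efficiency": 0.1},   # coarse level
    {"recon": 0.5, "align": 0.9, "task": 1.5, "efficiency": 0.3},   # fine level
]
weights = {"recon": 0.5, "align": 1.0, "task": 1.0, "efficiency": 0.1}
print(round(total_loss(levels, weights), 4))  # → 6.29
```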
</section>
<section id="progressive-training-strategies" class="level3">
<h3 class="anchored" data-anchor-id="progressive-training-strategies" id="progressive-training-strategies">Progressive Training Strategies</h3>
<p>Many implementations employ progressive training strategies where the model is initially trained on simpler, coarser representations before gradually incorporating finer details. This approach helps stabilize training and ensures that the hierarchical structure develops properly.</p>
<p>The progressive training typically follows a curriculum where:</p>
<ol type="1">
<li>Initial training focuses on global semantic alignment</li>
<li>Intermediate stages introduce regional correspondences</li>
<li>Final stages refine local feature interactions</li>
</ol>
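<p>A curriculum like this can be expressed as a schedule that maps the training epoch to the set of active objectives. The sketch assumes hypothetical stage boundaries and that each stage adds to, rather than replaces, the objectives of earlier stages.</p>

```python
def active_stages(epoch, stage_starts=(0, 10, 20)):
    """Objectives enabled at a given epoch under a three-stage curriculum."""
    stages = ["global_alignment", "regional_correspondence", "local_refinement"]
    return [name for name, start in zip(stages, stage_starts) if epoch >= start]

for epoch in (0, 12, 25):
    print(epoch, active_stages(epoch))
```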
</section>
</section>
<section id="applications-and-use-cases" class="level2">
<h2 class="anchored" data-anchor-id="applications-and-use-cases" id="applications-and-use-cases">Applications and Use Cases</h2>
<section id="image-captioning" class="level3">
<h3 class="anchored" data-anchor-id="image-captioning" id="image-captioning">Image Captioning</h3>
<p>In image captioning tasks, the Matryoshka Transformer can generate descriptions at varying levels of detail. The outer layers might produce general descriptions, while inner layers can add specific details about objects, relationships, and attributes visible in the image.</p>
</section>
<section id="visual-question-answering" class="level3">
<h3 class="anchored" data-anchor-id="visual-question-answering" id="visual-question-answering">Visual Question Answering</h3>
<p>For visual question answering, the nested structure allows the model to adaptively allocate attention based on question complexity. Simple questions about global image properties can be answered using outer layers, while detailed questions requiring fine-grained visual analysis can leverage the full nested hierarchy.</p>
</section>
<section id="multimodal-retrieval" class="level3">
<h3 class="anchored" data-anchor-id="multimodal-retrieval" id="multimodal-retrieval">Multimodal Retrieval</h3>
<p>The hierarchical representations learned by the Matryoshka Transformer are particularly well-suited for multimodal retrieval tasks. The model can perform coarse-grained retrieval using global representations and then refine results using more detailed features as needed.</p>
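<p>A sketch of the coarse-to-fine idea, assuming nested embeddings whose leading coordinates form the coarse representation: a shortlist is retrieved cheaply with the first few dimensions, and only that shortlist is re-ranked at full width. The corpus is random and the parameter values are illustrative.</p>

```python
import numpy as np

def normalize(x):
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

def coarse_to_fine_search(query, corpus, coarse_dim=64, shortlist=10, top_k=3):
    """Shortlist with the first coarse_dim dims, then re-rank the shortlist at full width."""
    coarse_scores = normalize(corpus[:, :coarse_dim]) @ normalize(query[:coarse_dim])
    candidates = np.argsort(-coarse_scores)[:shortlist]
    fine_scores = normalize(corpus[candidates]) @ normalize(query)
    order = np.argsort(-fine_scores)[:top_k]
    return candidates[order]

rng = np.random.default_rng(0)
corpus = rng.normal(size=(1000, 512))              # hypothetical image embeddings
query = corpus[42] + 0.1 * rng.normal(size=512)    # a noisy copy of item 42
print(coarse_to_fine_search(query, corpus))        # item 42 should rank first
```

<p>The coarse pass touches 64 of 512 dimensions for every corpus item, while the full-width pass touches only the 10 shortlisted items.</p>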
</section>
<section id="real-time-applications" class="level3">
<h3 class="anchored" data-anchor-id="real-time-applications" id="real-time-applications">Real-Time Applications</h3>
<p>The adaptive computation capabilities make the Matryoshka Transformer well suited to real-time applications where processing speed is critical. Its computational depth can be adjusted to the available resources and accuracy requirements, either fixed ahead of time for a given deployment or selected per input through an early-exit criterion.</p>
</section>
</section>
<section id="advantages-and-benefits" class="level2">
<h2 class="anchored" data-anchor-id="advantages-and-benefits" id="advantages-and-benefits">Advantages and Benefits</h2>
<section id="computational-efficiency" class="level3">
<h3 class="anchored" data-anchor-id="computational-efficiency" id="computational-efficiency">Computational Efficiency</h3>
<p>The nested structure enables significant computational savings by allowing early termination for simpler inputs. In favorable settings, this adaptive processing is estimated to reduce average inference time by 30-50% while keeping accuracy comparable to full-depth processing; the actual savings depend on the input distribution and on how the exit thresholds are set.</p>
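<p>How much early exit saves can be estimated with simple arithmetic. The sketch assumes a cascade in which an input that exits at level <span class="math inline">\(j\)</span> has paid for every level up to <span class="math inline">\(j\)</span>; the exit fractions and per-level costs are hypothetical.</p>

```python
def expected_relative_cost(exit_fractions, level_costs):
    """Average cost per input relative to running only the full model.

    Cascade assumption: an input exiting at level j paid for levels 0..j.
    """
    cumulative, total = 0.0, 0.0
    for frac, cost in zip(exit_fractions, level_costs):
        cumulative += cost
        total += frac * cumulative
    return total

fractions = [0.5, 0.3, 0.2]   # hypothetical share of inputs exiting at each level
costs = [1 / 16, 1 / 4, 1.0]  # hypothetical per-level costs (widths d/4, d/2, d)
print(expected_relative_cost(fractions, costs))  # about 0.3875
```

<p>Inputs that reach the last level cost more than the full model alone (1.3125 here), because the earlier levels were also evaluated, so the savings depend on most inputs exiting early. Designs that reuse computation across levels avoid this overhead.</p>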
</section>
<section id="scalability" class="level3">
<h3 class="anchored" data-anchor-id="scalability" id="scalability">Scalability</h3>
<p>The hierarchical design naturally scales to different computational budgets and hardware constraints. The same model can be deployed across various platforms, from mobile devices to high-performance servers, simply by adjusting the depth of processing.</p>
</section>
<section id="robustness" class="level3">
<h3 class="anchored" data-anchor-id="robustness" id="robustness">Robustness</h3>
<p>The multi-scale representations provide increased robustness to variations in input quality, resolution, and complexity. The model can gracefully degrade performance rather than failing catastrophically when faced with challenging inputs.</p>
</section>
<section id="interpretability" class="level3">
<h3 class="anchored" data-anchor-id="interpretability" id="interpretability">Interpretability</h3>
<p>The nested structure offers improved interpretability by providing insights into the model’s decision-making process at different scales. Researchers and practitioners can examine how global context influences local processing and vice versa.</p>
</section>
</section>
<section id="challenges-and-limitations" class="level2">
<h2 class="anchored" data-anchor-id="challenges-and-limitations" id="challenges-and-limitations">Challenges and Limitations</h2>
<section id="training-complexity" class="level3">
<h3 class="anchored" data-anchor-id="training-complexity" id="training-complexity">Training Complexity</h3>
<p>Training Matryoshka Transformers is more complex than traditional architectures due to the need to optimize multiple objectives across different scales simultaneously. This complexity can lead to training instability and requires careful hyperparameter tuning.</p>
</section>
<section id="memory-requirements" class="level3">
<h3 class="anchored" data-anchor-id="memory-requirements" id="memory-requirements">Memory Requirements</h3>
<p>While the model offers computational efficiency during inference, training requires maintaining gradients and activations across all nested levels, potentially increasing memory requirements during the training phase.</p>
</section>
<section id="architecture-design" class="level3">
<h3 class="anchored" data-anchor-id="architecture-design" id="architecture-design">Architecture Design</h3>
<p>Determining the optimal number of nested levels and their respective capacities requires extensive experimentation and domain expertise. The architecture choices significantly impact both performance and efficiency.</p>
</section>
</section>
<section id="recent-developments-and-research" class="level2">
<h2 class="anchored" data-anchor-id="recent-developments-and-research" id="recent-developments-and-research">Recent Developments and Research</h2>
<section id="architectural-variants" class="level3">
<h3 class="anchored" data-anchor-id="architectural-variants" id="architectural-variants">Architectural Variants</h3>
<p>Recent research has explored various architectural variants of the Matryoshka Transformer, including:</p>
<ul>
<li><strong>Sparse Matryoshka models</strong> that use sparse attention patterns to further reduce computational costs</li>
<li><strong>Dynamic Matryoshka architectures</strong> that can adjust their structure based on input characteristics</li>
<li><strong>Hybrid approaches</strong> that combine Matryoshka principles with other efficient architectures</li>
</ul>
</section>
<section id="performance-improvements" class="level3">
<h3 class="anchored" data-anchor-id="performance-improvements" id="performance-improvements">Performance Improvements</h3>
<p>Ongoing research focuses on improving the performance of Matryoshka Transformers through:</p>
<ul>
<li>Better training strategies and curriculum design</li>
<li>Novel attention mechanisms optimized for nested processing</li>
<li>Advanced feature fusion techniques</li>
<li>Integration with other efficiency-focused innovations</li>
</ul>
</section>
<section id="domain-specific-adaptations" class="level3">
<h3 class="anchored" data-anchor-id="domain-specific-adaptations" id="domain-specific-adaptations">Domain-Specific Adaptations</h3>
<p>Researchers are developing domain-specific adaptations of the Matryoshka Transformer for applications such as:</p>
<ul>
<li>Medical imaging and diagnostic tasks</li>
<li>Autonomous driving and robotics</li>
<li>Scientific image analysis</li>
<li>Creative content generation</li>
</ul>
</section>
</section>
<section id="implementation-considerations" class="level2">
<h2 class="anchored" data-anchor-id="implementation-considerations" id="implementation-considerations">Implementation Considerations</h2>
<section id="framework-support" class="level3">
<h3 class="anchored" data-anchor-id="framework-support" id="framework-support">Framework Support</h3>
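<p>Matryoshka Transformers can be built in standard deep learning frameworks such as PyTorch or JAX, because nesting mostly amounts to slicing weight matrices and embeddings to a prefix of their dimensions. Some libraries also ship ready-made components, for example a Matryoshka loss for embedding training in sentence-transformers.</p>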
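<p>A minimal NumPy sketch of that slicing idea, assuming nesting over the hidden width: a hypothetical <code>NestedLinear</code> layer stores one weight matrix, and the level with width <span class="math inline">\(d_i\)</span> uses its top-left <span class="math inline">\(d_i \times d_i\)</span> block. This is an illustration, not any library's API.</p>

```python
import numpy as np

class NestedLinear:
    """Dense layer whose nested sub-layers are prefixes of one weight matrix.

    Level width d_i uses W[:d_i, :d_i] and b[:d_i], so every nested model
    shares storage with the full one.
    """

    def __init__(self, d, seed=0):
        rng = np.random.default_rng(seed)
        self.W = rng.normal(scale=d ** -0.5, size=(d, d))  # (out, in)
        self.b = np.zeros(d)

    def __call__(self, x, d_i):
        # x has shape (..., d_i); only the top-left d_i x d_i block is used
        return x @ self.W[:d_i, :d_i].T + self.b[:d_i]

layer = NestedLinear(512)
x = np.ones((2, 128))
print(layer(x, 128).shape)  # (2, 128)
```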
</section>
<section id="hardware-optimization" class="level3">
<h3 class="anchored" data-anchor-id="hardware-optimization" id="hardware-optimization">Hardware Optimization</h3>
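<p>The nested sub-models are ordinary dense computations at reduced size, so they run on standard GPUs and TPUs without special hardware support. In practice, choosing nested widths that are multiples of the accelerator's preferred matrix tile size (for example, multiples of 8 or 16 for tensor cores) helps keep the smaller levels efficient.</p>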
</section>
<section id="deployment-strategies" class="level3">
<h3 class="anchored" data-anchor-id="deployment-strategies" id="deployment-strategies">Deployment Strategies</h3>
<p>Successful deployment of Matryoshka Transformers requires careful consideration of:</p>
<ul>
<li>Dynamic batching strategies for variable-depth processing</li>
<li>Memory management across nested levels</li>
<li>Load balancing for adaptive computation</li>
<li>Monitoring and profiling tools for performance optimization</li>
</ul>
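<p>Mixing levels in a single batch forces the whole batch to run at the largest requested size, so one simple batching strategy (an assumption, not from the original) is to queue requests per level and batch within each queue. The request IDs and level assignments below are hypothetical; the level choice is assumed to happen upstream.</p>

```python
from collections import defaultdict

def batch_by_level(requests, max_batch=4):
    """Group requests that will run at the same nested level into fixed-size batches.

    requests: list of (request_id, level) pairs.
    """
    queues = defaultdict(list)
    for req_id, level in requests:
        queues[level].append(req_id)
    batches = []
    for level in sorted(queues):
        ids = queues[level]
        for start in range(0, len(ids), max_batch):
            batches.append((level, ids[start:start + max_batch]))
    return batches

reqs = [("a", 0), ("b", 2), ("c", 0), ("d", 1), ("e", 0), ("f", 0), ("g", 0)]
for level, ids in batch_by_level(reqs):
    print(level, ids)
```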
</section>
</section>
<section id="future-directions" class="level2">
<h2 class="anchored" data-anchor-id="future-directions" id="future-directions">Future Directions</h2>
<section id="integration-with-large-language-models" class="level3">
<h3 class="anchored" data-anchor-id="integration-with-large-language-models" id="integration-with-large-language-models">Integration with Large Language Models</h3>
<p>Future research directions include integrating Matryoshka principles with large language models to create more efficient and capable multimodal AI systems. This integration could enable better handling of complex reasoning tasks that require both visual and textual understanding.</p>
</section>
<section id="automated-architecture-search" class="level3">
<h3 class="anchored" data-anchor-id="automated-architecture-search" id="automated-architecture-search">Automated Architecture Search</h3>
<p>Automated neural architecture search techniques are being developed to optimize Matryoshka Transformer designs for specific tasks and computational constraints, reducing the manual effort required for architecture design.</p>
</section>
<section id="continual-learning" class="level3">
<h3 class="anchored" data-anchor-id="continual-learning" id="continual-learning">Continual Learning</h3>
<p>The nested structure of Matryoshka Transformers shows promise for continual learning scenarios where models need to adapt to new tasks while preserving previously learned capabilities.</p>
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>The Matryoshka Transformer represents a significant step forward in the development of efficient and scalable vision language models. By embracing the principle of nested, hierarchical processing, this architecture addresses many of the computational and scalability challenges facing modern multimodal AI systems.</p>
<p>The ability to adaptively allocate computational resources while maintaining high performance across diverse tasks makes the Matryoshka Transformer particularly valuable for real-world applications. As research continues to refine and extend this architectural approach, we can expect to see even more sophisticated and efficient multimodal AI systems that can handle the growing complexity and scale of vision-language tasks.</p>
<p>The nested doll metaphor that inspired this architecture serves as a powerful reminder that effective AI systems often benefit from hierarchical organization that mirrors the multi-scale nature of human perception and understanding. As we continue to push the boundaries of what’s possible with vision language models, the Matryoshka Transformer provides a compelling framework for building more efficient, scalable, and capable multimodal AI systems.</p>


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Attention Mechanisms: Transformers vs Convolutional Neural Networks]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/attention-mechanisms/attention-article/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/attention-mechanisms/attention-article/</guid>
      <pubDate>Fri, 27 Jun 2025 00:00:00 GMT</pubDate>
      
      <category>research</category>
      <category>intermediate</category>
      <content:encoded><![CDATA[






<section id="attention-mechanisms-transformers-vs-convolutional-neural-networks" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/attention-mechanisms/attention-article/attention.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>Attention mechanisms have revolutionized deep learning by enabling models to focus on relevant parts of the input data. While originally popularized in Transformers, attention has also been successfully integrated into Convolutional Neural Networks (CNNs). This article explores the fundamental differences, applications, and trade-offs between attention mechanisms in these two architectural paradigms.</p>
</section>
<section id="attention-in-transformers" class="level2">
<h2 class="anchored" data-anchor-id="attention-in-transformers" id="attention-in-transformers">Attention in Transformers</h2>
<section id="core-concept" class="level3">
<h3 class="anchored" data-anchor-id="core-concept" id="core-concept">Core Concept</h3>
<p>The attention mechanism in Transformers is built on <strong>scaled dot-product attention</strong>, most often used as <strong>self-attention</strong>. The fundamental idea is to allow each position in a sequence to attend to every position in the same sequence; in encoder-decoder models, <strong>cross-attention</strong> additionally lets each decoder position attend to all positions of the encoded input.</p>
</section>
<section id="mathematical-foundation" class="level3">
<h3 class="anchored" data-anchor-id="mathematical-foundation" id="mathematical-foundation">Mathematical Foundation</h3>
<p>The attention mechanism in Transformers computes attention weights using three key components:</p>
<ul>
<li><strong>Query (Q)</strong>: What information we’re looking for</li>
<li><strong>Key (K)</strong>: What information is available</li>
<li><strong>Value (V)</strong>: The actual information content</li>
</ul>
<p>The attention output is calculated as:</p>
<p><span class="math display">\[
\text{Attention}(Q, K, V) = \text{softmax}\left( \frac{QK^T}{\sqrt{d_k}} \right)V
\]</span></p>
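<p>Where <span class="math inline">\(d_k\)</span> is the dimension of the key vectors. Dividing by <span class="math inline">\(\sqrt{d_k}\)</span> keeps the dot products from growing with <span class="math inline">\(d_k\)</span>; without it, large scores push the softmax into saturated regions where its gradients are extremely small.</p>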
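<p>The formula translates almost line by line into NumPy. This is a sketch for 2D inputs without learned projections; the optional boolean mask (an addition for illustration) blocks positions where it is <code>False</code>, as in causal decoding.</p>

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V, mask=None):
    """softmax(Q K^T / sqrt(d_k)) V for Q: (n_q, d_k), K: (n_k, d_k), V: (n_k, d_v)."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                       # (n_q, n_k)
    if mask is not None:
        scores = np.where(mask, scores, -np.inf)          # mask=False blocks a position
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))                              # 4 tokens, d_k = d_v = 8
out, attn = scaled_dot_product_attention(x, x, x)        # self-attention: Q = K = V = x
print(out.shape, attn.sum(axis=-1))
```

<p>Each row of <code>attn</code> sums to 1, so every output is a weighted average of the value vectors. Every row of the mask should allow at least one position; otherwise the softmax is undefined.</p>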
</section>
<section id="multi-head-attention" class="level3">
<h3 class="anchored" data-anchor-id="multi-head-attention" id="multi-head-attention">Multi-Head Attention</h3>
<p>Transformers employ <strong>multi-head attention</strong>, which runs multiple attention mechanisms in parallel:</p>
<p><span class="math display">\[
\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W^O
\]</span></p>
<p>Where each <span class="math inline">\(\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)\)</span></p>
<p>This allows the model to attend to information from different representation subspaces simultaneously.</p>
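<p>A compact NumPy sketch of multi-head self-attention. Computing one large projection and splitting its columns into <span class="math inline">\(h\)</span> slices is equivalent to using separate per-head matrices <span class="math inline">\(W_i^Q, W_i^K, W_i^V\)</span>. The weights here are random and the sizes are illustrative.</p>

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, W_q, W_k, W_v, W_o, h):
    """Concat(head_1, ..., head_h) W^O for self-attention on X: (n, d_model).

    W_q, W_k, W_v, W_o: (d_model, d_model); head i uses a d_model / h column slice.
    """
    n, d_model = X.shape
    d_head = d_model // h

    def split(M):                       # (n, d_model) -> (h, n, d_head)
        return M.reshape(n, h, d_head).transpose(1, 0, 2)

    Q, K, V = split(X @ W_q), split(X @ W_k), split(X @ W_v)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_head)   # (h, n, n)
    heads = softmax(scores) @ V                           # (h, n, d_head)
    concat = heads.transpose(1, 0, 2).reshape(n, d_model) # Concat(head_1..head_h)
    return concat @ W_o

rng = np.random.default_rng(0)
d_model, h, n = 16, 4, 5
W = [rng.normal(scale=d_model ** -0.5, size=(d_model, d_model)) for _ in range(4)]
X = rng.normal(size=(n, d_model))
print(multi_head_attention(X, *W, h=h).shape)  # (5, 16)
```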
</section>
<section id="key-characteristics" class="level3">
<h3 class="anchored" data-anchor-id="key-characteristics" id="key-characteristics">Key Characteristics</h3>
<ol type="1">
<li><strong>Global Context</strong>: Every token can attend to every other token in the sequence</li>
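<li><strong>Position Agnostic</strong>: Self-attention is permutation-equivariant (reordering the input tokens simply reorders the outputs), so order information must be supplied through positional encodings</li>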
<li><strong>Parallel Processing</strong>: All attention computations can be performed simultaneously</li>
<li><strong>Quadratic Complexity</strong>: O(n²) memory and computational complexity with sequence length</li>
<li><strong>Dynamic Weights</strong>: Attention weights are computed dynamically based on input content</li>
</ol>
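<p>The position-agnostic property can be checked directly. The sketch uses projection-free self-attention, which is enough to show the symmetry: permuting the tokens permutes the outputs in the same way. Adding a position-dependent signal (random here, standing in for a positional encoding) breaks the symmetry, so the second check should print <code>False</code>.</p>

```python
import numpy as np

def self_attention(X):
    """Projection-free self-attention."""
    scores = X @ X.T / np.sqrt(X.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)
    w = np.exp(scores)
    w /= w.sum(axis=-1, keepdims=True)
    return w @ X

rng = np.random.default_rng(0)
X = rng.normal(size=(6, 8))
perm = np.array([3, 1, 5, 0, 2, 4])

# Shuffling the tokens only shuffles the outputs: attention alone cannot see order.
print(np.allclose(self_attention(X[perm]), self_attention(X)[perm]))  # True

# With a fixed per-position signal P added, order now affects the result.
P = rng.normal(size=(6, 8))
print(np.allclose(self_attention(X[perm] + P), self_attention(X + P)[perm]))
```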
</section>
<section id="applications" class="level3">
<h3 class="anchored" data-anchor-id="applications" id="applications">Applications</h3>
<ul>
<li>Natural Language Processing (BERT, GPT, T5)</li>
<li>Computer Vision (Vision Transformer - ViT)</li>
<li>Multimodal tasks (CLIP, DALL-E)</li>
<li>Time series analysis</li>
<li>Graph neural networks</li>
</ul>
</section>
</section>
<section id="attention-in-convolutional-neural-networks" class="level2">
<h2 class="anchored" data-anchor-id="attention-in-convolutional-neural-networks" id="attention-in-convolutional-neural-networks">Attention in Convolutional Neural Networks</h2>
<section id="core-concept-1" class="level3">
<h3 class="anchored" data-anchor-id="core-concept-1" id="core-concept-1">Core Concept</h3>
<p>Attention in CNNs is typically implemented as <strong>channel attention</strong> or <strong>spatial attention</strong> mechanisms that help the network focus on important features or spatial locations. Unlike Transformers, CNN attention is usually applied to feature maps rather than sequence elements.</p>
</section>
<section id="types-of-cnn-attention" class="level3">
<h3 class="anchored" data-anchor-id="types-of-cnn-attention" id="types-of-cnn-attention">Types of CNN Attention</h3>
<section id="channel-attention-se-net-eca-net" class="level4">
<h4 class="anchored" data-anchor-id="channel-attention-se-net-eca-net"><strong>1. Channel Attention (SE-Net, ECA-Net)</strong></h4>
<p>Channel attention mechanisms adaptively recalibrate channel-wise feature responses by modeling interdependencies between channels.</p>
<p><strong>Squeeze-and-Excitation (SE) Block</strong>:</p>
<ol type="1">
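<li>Squeeze (global average pooling): <span class="math inline">\(z_c = \frac{1}{H \times W} \sum_{i=1}^{H} \sum_{j=1}^{W} u_c(i,j)\)</span></li>
<li>Excitation: <span class="math inline">\(s = \sigma(W_2 \, \delta(W_1 z))\)</span>, where <span class="math inline">\(\delta\)</span> is the ReLU, <span class="math inline">\(\sigma\)</span> is the sigmoid, and <span class="math inline">\(W_1 \in \mathbb{R}^{\frac{C}{r} \times C}\)</span>, <span class="math inline">\(W_2 \in \mathbb{R}^{C \times \frac{C}{r}}\)</span> form a bottleneck with reduction ratio <span class="math inline">\(r\)</span></li>
<li>Scale: <span class="math inline">\(\tilde{x}_c = s_c \times u_c\)</span></li>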
</ol>
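<p>The three SE steps map directly onto NumPy. This sketch processes a single feature map of shape (C, H, W), omits the biases of the two fully connected layers, and uses random weights with illustrative sizes.</p>

```python
import numpy as np

def se_block(U, W1, W2):
    """Squeeze-and-Excitation on one feature map U of shape (C, H, W).

    W1: (C // r, C) reduction weights, W2: (C, C // r) expansion weights.
    """
    z = U.mean(axis=(1, 2))                           # squeeze: global average pool -> (C,)
    hidden = np.maximum(W1 @ z, 0.0)                  # delta = ReLU
    s = 1.0 / (1.0 + np.exp(-(W2 @ hidden)))          # sigma = sigmoid -> (C,) in (0, 1)
    return U * s[:, None, None]                       # scale: reweight each channel by s_c

rng = np.random.default_rng(0)
C, height, width, r = 32, 8, 8, 8
U = rng.normal(size=(C, height, width))
W1 = rng.normal(scale=C ** -0.5, size=(C // r, C))
W2 = rng.normal(scale=(C // r) ** -0.5, size=(C, C // r))
out = se_block(U, W1, W2)
print(out.shape)  # (32, 8, 8)
```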
</section>
<section id="spatial-attention-cbam-sam" class="level4">
<h4 class="anchored" data-anchor-id="spatial-attention-cbam-sam"><strong>2. Spatial Attention (CBAM, SAM)</strong></h4>
<p>Spatial attention focuses on “where” informative parts are located in the feature map.</p>
<p><strong>Spatial Attention Module</strong>:</p>
<ol type="1">
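<li>Channel-wise statistics: average-pool and max-pool <span class="math inline">\(F\)</span> along the channel axis to obtain two 2D maps <span class="math inline">\(F_{\text{avg}},\ F_{\text{max}}\)</span></li>
<li>Convolution: <span class="math inline">\(M_s = \sigma(\text{conv}([F_{\text{avg}}; F_{\text{max}}]))\)</span>, where <span class="math inline">\(\sigma\)</span> is the sigmoid and CBAM uses a <span class="math inline">\(7 \times 7\)</span> kernel</li>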
<li>Element-wise multiplication: <span class="math inline">\(F' = M_s \otimes F\)</span></li>
</ol>
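<p>A NumPy sketch of the spatial attention module, using a naive "same"-padded convolution (cross-correlation, as in deep learning frameworks) so that it stays dependency-free. The kernel weights are random and there is no bias; in CBAM the kernel is learned.</p>

```python
import numpy as np

def spatial_attention(F, kernel):
    """CBAM-style spatial attention on F of shape (C, H, W).

    kernel: (2, k, k) weights of a single-output conv over the stacked [avg; max] maps.
    """
    f_avg = F.mean(axis=0)                 # (H, W) average over channels
    f_max = F.max(axis=0)                  # (H, W) max over channels
    stacked = np.stack([f_avg, f_max])     # (2, H, W)
    k = kernel.shape[-1]
    pad = k // 2
    padded = np.pad(stacked, ((0, 0), (pad, pad), (pad, pad)))  # zero "same" padding
    H, W = f_avg.shape
    logits = np.empty((H, W))
    for i in range(H):
        for j in range(W):
            logits[i, j] = np.sum(padded[:, i:i + k, j:j + k] * kernel)
    M_s = 1.0 / (1.0 + np.exp(-logits))   # sigmoid -> (H, W) mask in (0, 1)
    return F * M_s[None, :, :]             # broadcast the mask over all channels

rng = np.random.default_rng(0)
F = rng.normal(size=(16, 10, 10))
kernel = rng.normal(scale=0.1, size=(2, 7, 7))
out = spatial_attention(F, kernel)
print(out.shape)  # (16, 10, 10)
```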
</section>
<section id="self-attention-in-cnns-non-local-networks" class="level4">
<h4 class="anchored" data-anchor-id="self-attention-in-cnns-non-local-networks"><strong>3. Self-Attention in CNNs (Non-Local Networks)</strong></h4>
<p>Some CNNs incorporate self-attention mechanisms similar to Transformers but adapted for spatial data:</p>
<p><span class="math display">\[
y_i = \frac{1}{C(x)} \sum_j f(x_i, x_j) \, g(x_j)
\]</span></p>
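<p>Where <span class="math inline">\(f\)</span> computes the affinity between positions <span class="math inline">\(i\)</span> and <span class="math inline">\(j\)</span>, <span class="math inline">\(g\)</span> computes a representation of the input at position <span class="math inline">\(j\)</span>, and <span class="math inline">\(C(x)\)</span> is a normalization factor. Unlike a convolution, the sum runs over all positions <span class="math inline">\(j\)</span>, so every output location can draw on the entire feature map.</p>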
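<p>With the common embedded-Gaussian choice <span class="math inline">\(f(x_i, x_j) = \exp(\theta(x_i)^\top \phi(x_j))\)</span> and <span class="math inline">\(C(x) = \sum_j f(x_i, x_j)\)</span>, the normalization becomes a row-wise softmax. The sketch below uses linear maps for <span class="math inline">\(\theta\)</span>, <span class="math inline">\(\phi\)</span>, and <span class="math inline">\(g\)</span> on a flattened feature map; the residual connection that non-local blocks usually add around this operation is omitted.</p>

```python
import numpy as np

def non_local(x, W_theta, W_phi, W_g):
    """Embedded-Gaussian non-local operation on N positions.

    x: (N, C) features (a flattened H*W map); W_theta, W_phi, W_g: (C, C').
    """
    theta, phi, g = x @ W_theta, x @ W_phi, x @ W_g
    scores = theta @ phi.T                        # (N, N) pairwise affinities
    # Subtracting the row max rescales f by a per-row constant that cancels in C(x).
    scores -= scores.max(axis=1, keepdims=True)
    f = np.exp(scores)
    return (f / f.sum(axis=1, keepdims=True)) @ g # y_i = (1 / C(x)) sum_j f(x_i, x_j) g(x_j)

rng = np.random.default_rng(0)
N, C, C_out = 12, 8, 4                            # e.g. a 3 x 4 map with 8 channels
x = rng.normal(size=(N, C))
W_theta, W_phi, W_g = (rng.normal(scale=0.3, size=(C, C_out)) for _ in range(3))
y = non_local(x, W_theta, W_phi, W_g)
print(y.shape)  # (12, 4)
```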
</section>
</section>
<section id="key-characteristics-1" class="level3">
<h3 class="anchored" data-anchor-id="key-characteristics-1" id="key-characteristics-1">Key Characteristics</h3>
<ol type="1">
<li><strong>Local and Global Context</strong>: Can focus on both local patterns and global dependencies</li>
<li><strong>Spatial Awareness</strong>: Naturally preserves spatial relationships in 2D/3D data</li>
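<li><strong>Efficient Computation</strong>: Channel and spatial attention modules are generally much cheaper than Transformer attention (non-local blocks are the exception, with cost quadratic in the number of positions)</li>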
<li><strong>Feature Enhancement</strong>: Primarily used to enhance existing convolutional features</li>
<li><strong>Lightweight</strong>: Usually adds minimal parameters to the base model</li>
</ol>
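<p>For a sense of scale, the parameters added by one SE block can be compared with those of a single 3×3 convolution at the same width. The sketch uses <span class="math inline">\(r = 16\)</span>, the default reduction ratio in the SE paper; the channel count is illustrative.</p>

```python
def se_params(C, r=16, bias=True):
    """Parameters added by one SE block: two FC layers C -> C/r -> C."""
    hidden = C // r
    return C * hidden + hidden * C + ((hidden + C) if bias else 0)

def conv_params(c_in, c_out, k):
    """Weights of one k x k convolution (no bias), for comparison."""
    return c_in * c_out * k * k

C = 256
print(se_params(C, bias=False), conv_params(C, C, 3))  # 8192 589824
```

<p>Here the SE block adds about 1.4% of the weights of one 3×3 convolution.</p>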
</section>
<section id="applications-1" class="level3">
<h3 class="anchored" data-anchor-id="applications-1" id="applications-1">Applications</h3>
<ul>
<li>Image classification (ResNet + SE, EfficientNet)</li>
<li>Object detection (Feature Pyramid Networks with attention)</li>
<li>Semantic segmentation (attention-based skip connections)</li>
<li>Medical image analysis</li>
<li>Video understanding</li>
</ul>
</section>
</section>
<section id="comparative-analysis" class="level2">
<h2 class="anchored" data-anchor-id="comparative-analysis" id="comparative-analysis">Comparative Analysis</h2>
<section id="computational-complexity" class="level3">
<h3 class="anchored" data-anchor-id="computational-complexity" id="computational-complexity">Computational Complexity</h3>
<table class="caption-top table">
<colgroup>
<col style="width: 18%">
<col style="width: 47%">
<col style="width: 34%">
</colgroup>
<thead>
<tr class="header">
<th>Aspect</th>
<th>Transformer Attention</th>
<th>CNN Attention</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><strong>Time Complexity</strong></td>
<td>O(n²d) for sequence length n and embedding width d</td>
<td>O(HWd) for SE/CBAM-style modules on an H×W map with d channels; O((HW)²d) for non-local blocks</td>
</tr>
<tr class="even">
<td><strong>Space Complexity</strong></td>
<td>O(n²) attention matrix per head</td>
<td>O(d) channel weights or O(HW) spatial map; O((HW)²) for non-local blocks</td>
</tr>
<tr class="odd">
<td><strong>Scalability</strong></td>
<td>Challenging for long sequences</td>
<td>SE/CBAM-style modules scale linearly with image resolution</td>
</tr>
</tbody>
</table>
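<p>A back-of-the-envelope comparison makes the table concrete. The plain-Python sketch below counts the entries of a full self-attention matrix versus the maps produced by channel and spatial attention; the 224×224 input, 16×16 patches, and 56×56×256 feature map are illustrative assumptions.</p>

```python
def self_attention_matrix_entries(num_tokens):
    # One score per (query, key) pair: n^2 entries per head
    return num_tokens ** 2


# ViT-style tokens: a 224x224 image cut into 16x16 patches
patch_tokens = (224 // 16) ** 2   # 196 tokens
# Pixel-level tokens on a 56x56 feature map
pixel_tokens = 56 * 56            # 3136 positions

print(self_attention_matrix_entries(patch_tokens))  # 38416
print(self_attention_matrix_entries(pixel_tokens))  # 9834496

# CNN attention on the same 56x56 map with 256 channels
channel_attention_entries = 256   # one weight per channel (SE-style)
spatial_attention_entries = 56 * 56  # one weight per position (CBAM-style)
print(channel_attention_entries, spatial_attention_entries)  # 256 3136
```

<p>Full self-attention over the 3136 pixel positions stores 3136 times more entries than a spatial attention map at the same resolution, which is why pixel-level attention is usually applied only on downsampled features or coarse patches.</p>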
</section>
<section id="architectural-differences" class="level3">
<h3 class="anchored" data-anchor-id="architectural-differences" id="architectural-differences">Architectural Differences</h3>
<section id="information-flow" class="level4">
<h4 class="anchored" data-anchor-id="information-flow">Information Flow</h4>
<ul>
<li><strong>Transformers</strong>: Global information exchange from the start</li>
<li><strong>CNNs</strong>: Hierarchical feature learning with attention refinement</li>
</ul>
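<p>The difference in information flow can be quantified with receptive fields. Assuming stride-1 3×3 convolutions without dilation (a simplification; real backbones also downsample, which grows the receptive field faster), each layer widens the receptive field by 2 pixels, whereas a single self-attention layer already connects every position to every other. The helper names are illustrative.</p>

```python
def stacked_conv_receptive_field(num_layers, kernel_size=3):
    # Each stride-1 conv adds (kernel_size - 1) pixels to the receptive field
    return 1 + num_layers * (kernel_size - 1)


def layers_to_cover(width, kernel_size=3):
    # Smallest number of stride-1 convs whose receptive field spans `width`
    layers = 0
    while stacked_conv_receptive_field(layers, kernel_size) < width:
        layers += 1
    return layers


print(stacked_conv_receptive_field(5))  # 11
print(layers_to_cover(56))              # 28
```

<p>Spanning a 56-pixel-wide feature map takes 28 such convolutions, while one self-attention layer covers it immediately.</p>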
</section>
<section id="inductive-biases" class="level4">
<h4 class="anchored" data-anchor-id="inductive-biases">Inductive Biases</h4>
<ul>
<li><strong>Transformers</strong>: Minimal inductive bias, relies on data and scale</li>
<li><strong>CNNs</strong>: Strong spatial inductive bias through convolution operations</li>
</ul>
</section>
<section id="interpretability" class="level4">
<h4 class="anchored" data-anchor-id="interpretability">Interpretability</h4>
<ul>
<li><strong>Transformers</strong>: Attention weights can be visualized as focus patterns, although they are not guaranteed to be faithful explanations of a prediction</li>
<li><strong>CNNs</strong>: Channel/spatial attention maps show feature importance</li>
</ul>
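<p>Both kinds of maps can be read out directly. The sketch below uses PyTorch's built-in <code>nn.MultiheadAttention</code>, which in recent versions returns weights averaged over heads by default, and a CBAM-style spatial map computed from channel-pooled statistics. The dimensions and the randomly initialized 7×7 convolution are illustrative assumptions, so the maps carry no meaning until the layers are trained.</p>

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Transformer side: one attention distribution per query token
mha = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)
tokens = torch.randn(1, 10, 64)                 # [B, seq_len, d_model]
_, attn = mha(tokens, tokens, tokens, need_weights=True)
print(attn.shape)            # torch.Size([1, 10, 10])
print(attn[0, 0].argmax())   # index of the key the first query attends to most

# CNN side: one importance score per spatial position
feats = torch.randn(1, 32, 14, 14)              # [B, C, H, W]
pooled = torch.cat([feats.mean(1, keepdim=True),
                    feats.amax(1, keepdim=True)], dim=1)  # [B, 2, H, W]
spatial_conv = nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False)
spatial_map = torch.sigmoid(spatial_conv(pooled))         # [B, 1, H, W]
print(spatial_map.shape)     # torch.Size([1, 1, 14, 14])
```

<p>Each row of <code>attn</code> is a probability distribution over keys, while <code>spatial_map</code> holds independent per-position gates in (0, 1); the two are not directly comparable as explanations.</p>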
</section>
</section>
<section id="performance-characteristics" class="level3">
<h3 class="anchored" data-anchor-id="performance-characteristics" id="performance-characteristics">Performance Characteristics</h3>
<section id="data-efficiency" class="level4">
<h4 class="anchored" data-anchor-id="data-efficiency">Data Efficiency</h4>
<ul>
<li><strong>Transformers</strong>: Typically need large datasets (or strong augmentation and regularization) to match CNNs</li>
<li><strong>CNNs</strong>: More data-efficient due to built-in inductive biases</li>
</ul>
</section>
<section id="generalization" class="level4">
<h4 class="anchored" data-anchor-id="generalization">Generalization</h4>
<ul>
<li><strong>Transformers</strong>: Excel at capturing long-range dependencies</li>
<li><strong>CNNs</strong>: Better at learning local patterns and spatial hierarchies</li>
</ul>
</section>
<section id="training-stability" class="level4">
<h4 class="anchored" data-anchor-id="training-stability">Training Stability</h4>
<ul>
<li><strong>Transformers</strong>: Can be unstable, often requiring learning-rate warmup and careful initialization</li>
<li><strong>CNNs</strong>: Generally more stable training dynamics</li>
</ul>
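<p>One common remedy for Transformer training instability is learning-rate warmup. A minimal sketch with PyTorch's <code>LambdaLR</code> scheduler, assuming a linear warmup over 4 steps (a toy value; real runs use thousands) followed by inverse-square-root decay, the shape of the schedule used in the original Transformer paper. The tiny linear model stands in for a real network.</p>

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 8)  # stand-in for a Transformer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
warmup_steps = 4


def warmup_then_decay(step):
    # LambdaLR passes a 0-based step count; shift so the first update uses s = 1
    s = step + 1
    # Linear warmup up to a factor of 1.0, then inverse-square-root decay
    return min(s / warmup_steps, (warmup_steps / s) ** 0.5)


scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_then_decay)

lrs = []
for _ in range(8):
    lrs.append(optimizer.param_groups[0]["lr"])
    optimizer.step()   # a real loop would compute a loss and call backward() first
    scheduler.step()

print([round(lr, 6) for lr in lrs])
```

<p>The learning rate ramps up to its peak at step 4 and decays afterwards, which keeps early updates small while the attention weights and LayerNorm statistics are still poorly conditioned.</p>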
</section>
</section>
</section>
<section id="hybrid-approaches" class="level2">
<h2 class="anchored" data-anchor-id="hybrid-approaches" id="hybrid-approaches">Hybrid Approaches</h2>
<p>Recent research has explored combining convolution with self-attention, or transferring design ideas between the two paradigms:</p>
<section id="convnets-with-transformer-blocks" class="level3">
<h3 class="anchored" data-anchor-id="convnets-with-transformer-blocks" id="convnets-with-transformer-blocks">ConvNets with Transformer Blocks</h3>
<ul>
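<li><strong>ConvNeXt</strong>: A pure CNN (no attention) modernized with Transformer-inspired design choices such as large depthwise kernels, LayerNorm, and inverted-bottleneck MLPs</li>
<li><strong>CoAtNet</strong>: Stacks depthwise-convolution (MBConv) stages followed by self-attention stages in a unified architecture</li>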
</ul>
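<p>A simplified ConvNeXt-style block shows what "Transformer design principles" look like in convolutional form: a 7×7 depthwise convolution mixes information spatially within each channel, then LayerNorm and an inverted-bottleneck MLP with GELU act per position, much like a Transformer feed-forward layer. This sketch omits the layer scale and stochastic depth of the published design, and the class name is illustrative.</p>

```python
import torch
import torch.nn as nn


class ConvNeXtStyleBlock(nn.Module):
    """Simplified ConvNeXt-style block (no layer scale / drop path)."""

    def __init__(self, dim, expansion=4):
        super().__init__()
        # Depthwise 7x7 conv: spatial mixing, one filter per channel
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.norm = nn.LayerNorm(dim)
        # Inverted-bottleneck MLP applied at every position (channels-last)
        self.mlp = nn.Sequential(
            nn.Linear(dim, expansion * dim),
            nn.GELU(),
            nn.Linear(expansion * dim, dim),
        )

    def forward(self, x):                 # x: [B, C, H, W]
        residual = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)         # -> [B, H, W, C] for LayerNorm/Linear
        x = self.mlp(self.norm(x))
        x = x.permute(0, 3, 1, 2)         # -> [B, C, H, W]
        return residual + x


block = ConvNeXtStyleBlock(dim=96)
y = block(torch.randn(2, 96, 28, 28))
print(y.shape)  # torch.Size([2, 96, 28, 28])
```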
</section>
<section id="vision-transformers-with-convolutional-elements" class="level3">
<h3 class="anchored" data-anchor-id="vision-transformers-with-convolutional-elements" id="vision-transformers-with-convolutional-elements">Vision Transformers with Convolutional Elements</h3>
<ul>
<li><strong>CvT</strong>: Convolutional Vision Transformer with convolutional token embedding</li>
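<li><strong>CeiT</strong>: Convolution-enhanced image Transformer, which adds a convolutional image-to-tokens stem and depthwise convolutions in the feed-forward layers to bring convolutional inductive bias into ViTs</li>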
</ul>
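<p>The convolutional token embedding idea behind CvT can be sketched as a strided convolution whose output positions become tokens, followed by LayerNorm. The 7×7 kernel, stride 4, and 64 output channels are assumptions chosen to resemble a first-stage embedding. Unlike ViT's non-overlapping 16×16 patches, these kernels overlap, so neighboring tokens share local context.</p>

```python
import torch
import torch.nn as nn


class ConvTokenEmbedding(nn.Module):
    """Overlapping convolutional patch embedding (CvT-style sketch)."""

    def __init__(self, in_channels=3, embed_dim=64, kernel_size=7, stride=4):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size,
                              stride=stride, padding=kernel_size // 2)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):                      # x: [B, C, H, W]
        x = self.proj(x)                       # [B, D, H', W']
        _, _, h, w = x.shape
        tokens = x.flatten(2).transpose(1, 2)  # [B, H'*W', D]
        return self.norm(tokens), (h, w)


embed = ConvTokenEmbedding()
tokens, (h, w) = embed(torch.randn(2, 3, 224, 224))
print(tokens.shape, h, w)  # torch.Size([2, 3136, 64]) 56 56
```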
</section>
<section id="advantages-of-hybrid-models" class="level3">
<h3 class="anchored" data-anchor-id="advantages-of-hybrid-models" id="advantages-of-hybrid-models">Advantages of Hybrid Models</h3>
<ol type="1">
<li><strong>Best of Both Worlds</strong>: Local pattern recognition + global context modeling</li>
<li><strong>Improved Efficiency</strong>: Convolutional stems and downsampling reduce the number of tokens before attention, cutting its quadratic cost while maintaining accuracy</li>
<li><strong>Better Inductive Bias</strong>: Combines spatial awareness with flexible attention</li>
</ol>
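<p>A toy hybrid illustrates the first two advantages above: a convolutional stem extracts local features and downsamples by 8×, so the Transformer layer attends over 64× fewer tokens than pixel-level attention would. The class name and layer sizes are illustrative assumptions, not taken from any published model.</p>

```python
import torch
import torch.nn as nn


class TinyHybrid(nn.Module):
    def __init__(self, dim=64, num_heads=4, num_classes=10):
        super().__init__()
        # Convolutional stem: local patterns + 8x downsampling (three stride-2 convs)
        self.stem = nn.Sequential(
            nn.Conv2d(3, dim // 2, 3, stride=2, padding=1), nn.BatchNorm2d(dim // 2), nn.ReLU(),
            nn.Conv2d(dim // 2, dim, 3, stride=2, padding=1), nn.BatchNorm2d(dim), nn.ReLU(),
            nn.Conv2d(dim, dim, 3, stride=2, padding=1),
        )
        # Transformer layer: global context over the downsampled tokens
        self.encoder = nn.TransformerEncoderLayer(
            d_model=dim, nhead=num_heads, dim_feedforward=4 * dim, batch_first=True)
        self.head = nn.Linear(dim, num_classes)

    def forward(self, x):                            # x: [B, 3, H, W]
        feats = self.stem(x)                         # [B, dim, H/8, W/8]
        tokens = feats.flatten(2).transpose(1, 2)    # [B, (H/8)*(W/8), dim]
        tokens = self.encoder(tokens)
        return self.head(tokens.mean(dim=1))         # mean-pool tokens, classify


model = TinyHybrid()
images = torch.randn(2, 3, 64, 64)
num_tokens = model.stem(images).flatten(2).shape[-1]
logits = model(images)
print(logits.shape, num_tokens)  # torch.Size([2, 10]) 64
```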
</section>
</section>
<section id="use-case-recommendations" class="level2">
<h2 class="anchored" data-anchor-id="use-case-recommendations" id="use-case-recommendations">Use Case Recommendations</h2>
<section id="choose-transformer-attention-when" class="level3">
<h3 class="anchored" data-anchor-id="choose-transformer-attention-when" id="choose-transformer-attention-when">Choose Transformer Attention When:</h3>
<ul>
<li>Working with sequential data (NLP, time series)</li>
<li>Long-range dependencies need to be modeled</li>
<li>Large datasets are available</li>
<li>Computational resources are abundant</li>
<li>Interpretability of attention patterns is important</li>
</ul>
</section>
<section id="choose-cnn-attention-when" class="level3">
<h3 class="anchored" data-anchor-id="choose-cnn-attention-when" id="choose-cnn-attention-when">Choose CNN Attention When:</h3>
<ul>
<li>Working with spatial data (images, videos)</li>
<li>Computational resources are limited</li>
<li>Only smaller datasets are available</li>
<li>Fast inference is required</li>
<li>Spatial relationships are crucial for the task</li>
</ul>
</section>
<section id="consider-hybrid-approaches-when" class="level3">
<h3 class="anchored" data-anchor-id="consider-hybrid-approaches-when" id="consider-hybrid-approaches-when">Consider Hybrid Approaches When:</h3>
<ul>
<li>Working with complex visual tasks requiring both local and global understanding</li>
<li>Performance and efficiency need to be balanced</li>
<li>Computational resources are moderate</li>
<li>The benefits of both paradigms are desired</li>
</ul>
</section>
</section>
<section id="future-directions" class="level2">
<h2 class="anchored" data-anchor-id="future-directions" id="future-directions">Future Directions</h2>
<p>The field continues to evolve with several promising directions:</p>
<ol type="1">
<li><strong>Efficient Attention</strong>: Linear-complexity attention mechanisms that avoid the O(n²) cost of standard Transformer attention</li>
<li><strong>Dynamic Attention</strong>: Adaptive attention mechanisms that adjust based on input complexity</li>
<li><strong>Cross-Modal Attention</strong>: Attention mechanisms that work across different data modalities</li>
<li><strong>Learnable Attention Patterns</strong>: Meta-learning approaches for attention mechanism design</li>
<li><strong>Hardware-Optimized Attention</strong>: Attention mechanisms designed for specific hardware accelerators</li>
</ol>
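<p>As an example of the first direction, the sketch below implements non-causal kernelized linear attention in the style of the "linear Transformer" of Katharopoulos et al. (2020), assuming the feature map φ(x) = elu(x) + 1. Computing φ(K)ᵀV first costs O(n·d·d<sub>v</sub>) instead of O(n²·d), and the final check confirms the result matches the same attention computed with an explicit n×n matrix. Function names are illustrative.</p>

```python
import torch
import torch.nn.functional as F


def feature_map(x):
    # Positive feature map phi(x) = elu(x) + 1 (an assumed, common choice)
    return F.elu(x) + 1


def linear_attention(q, k, v, eps=1e-6):
    """Non-causal linear attention. q, k: [B, n, d]; v: [B, n, d_v]."""
    q, k = feature_map(q), feature_map(k)
    kv = torch.einsum("bnd,bne->bde", k, v)           # phi(K)^T V: [B, d, d_v]
    k_sum = k.sum(dim=1)                               # [B, d]
    normalizer = torch.einsum("bnd,bd->bn", q, k_sum)  # [B, n]
    out = torch.einsum("bnd,bde->bne", q, kv)          # [B, n, d_v]
    return out / (normalizer.unsqueeze(-1) + eps)


def quadratic_reference(q, k, v):
    # Same attention, but materializing the n x n similarity matrix
    sim = feature_map(q) @ feature_map(k).transpose(1, 2)  # [B, n, n]
    weights = sim / sim.sum(dim=-1, keepdim=True)
    return weights @ v


torch.manual_seed(0)
q, k, v = (torch.randn(2, 128, 16) for _ in range(3))
fast = linear_attention(q, k, v)
slow = quadratic_reference(q, k, v)
print(fast.shape)                             # torch.Size([2, 128, 16])
print(torch.allclose(fast, slow, atol=1e-4))  # True
```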
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>Both Transformer and CNN attention mechanisms serve distinct but complementary purposes in modern deep learning. Transformer attention excels at modeling global dependencies and complex relationships in sequential data, while CNN attention provides efficient feature enhancement for spatial data. The choice between them depends on specific use case requirements, available resources, and the nature of the data being processed.</p>
<p>The ongoing convergence of these approaches through hybrid architectures suggests that the future of attention mechanisms lies not in choosing one over the other, but in thoughtfully combining their strengths to create more powerful and efficient models. As the field continues to advance, we can expect to see more sophisticated attention mechanisms that bridge the gap between these two paradigms while addressing their respective limitations.</p>



</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Attention Mechanisms: Transformers vs CNNs - Complete Code Guide]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/attention-mechanisms/attention-code/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/attention-mechanisms/attention-code/</guid>
      <pubDate>Fri, 27 Jun 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>tutorial</category>
      <category>intermediate</category>
      <content:encoded><![CDATA[






<section id="attention-mechanisms-transformers-vs-cnns---complete-code-guide" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/attention-mechanisms/attention-code/attention.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>Attention mechanisms have revolutionized deep learning by allowing models to focus on relevant parts of input data. While Transformers use self-attention as their core mechanism, CNNs incorporate attention as an enhancement to their convolutional operations.</p>
</section>
<section id="transformer-attention" class="level2">
<h2 class="anchored" data-anchor-id="transformer-attention" id="transformer-attention">Transformer Attention</h2>
<section id="multi-head-self-attention-implementation" class="level3">
<h3 class="anchored" data-anchor-id="multi-head-self-attention-implementation" id="multi-head-self-attention-implementation">Multi-Head Self-Attention Implementation</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn.functional <span class="im">as</span> F</span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> math</span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MultiHeadAttention(nn.Module):</span>
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, d_model, num_heads, dropout<span class="op">=</span><span class="fl">0.1</span>):</span>
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb1-9"><a href="#cb1-9" aria-hidden="true" tabindex="-1"></a>        <span class="cf">assert</span> d_model <span class="op">%</span> num_heads <span class="op">==</span> <span class="dv">0</span></span>
<span id="cb1-10"><a href="#cb1-10" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-11"><a href="#cb1-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.d_model <span class="op">=</span> d_model</span>
<span id="cb1-12"><a href="#cb1-12" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.num_heads <span class="op">=</span> num_heads</span>
<span id="cb1-13"><a href="#cb1-13" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.d_k <span class="op">=</span> d_model <span class="op">//</span> num_heads</span>
<span id="cb1-14"><a href="#cb1-14" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-15"><a href="#cb1-15" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Linear projections for Q, K, V</span></span>
<span id="cb1-16"><a href="#cb1-16" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.w_q <span class="op">=</span> nn.Linear(d_model, d_model)</span>
<span id="cb1-17"><a href="#cb1-17" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.w_k <span class="op">=</span> nn.Linear(d_model, d_model)</span>
<span id="cb1-18"><a href="#cb1-18" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.w_v <span class="op">=</span> nn.Linear(d_model, d_model)</span>
<span id="cb1-19"><a href="#cb1-19" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.w_o <span class="op">=</span> nn.Linear(d_model, d_model)</span>
<span id="cb1-20"><a href="#cb1-20" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-21"><a href="#cb1-21" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.dropout <span class="op">=</span> nn.Dropout(dropout)</span>
<span id="cb1-22"><a href="#cb1-22" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-23"><a href="#cb1-23" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> scaled_dot_product_attention(<span class="va">self</span>, Q, K, V, mask<span class="op">=</span><span class="va">None</span>):</span>
<span id="cb1-24"><a href="#cb1-24" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""</span></span>
<span id="cb1-25"><a href="#cb1-25" aria-hidden="true" tabindex="-1"></a><span class="co">        Compute scaled dot-product attention</span></span>
<span id="cb1-26"><a href="#cb1-26" aria-hidden="true" tabindex="-1"></a><span class="co">        Args:</span></span>
<span id="cb1-27"><a href="#cb1-27" aria-hidden="true" tabindex="-1"></a><span class="co">            Q: Query matrix [batch_size, num_heads, seq_len_q, d_k]</span></span>
<span id="cb1-28"><a href="#cb1-28" aria-hidden="true" tabindex="-1"></a><span class="co">            K: Key matrix [batch_size, num_heads, seq_len_k, d_k]</span></span>
<span id="cb1-29"><a href="#cb1-29" aria-hidden="true" tabindex="-1"></a><span class="co">            V: Value matrix [batch_size, num_heads, seq_len_k, d_k]</span></span>
<span id="cb1-30"><a href="#cb1-30" aria-hidden="true" tabindex="-1"></a><span class="co">            mask: Optional mask [batch_size, 1, seq_len_q, seq_len_k]; 0 = blocked</span></span>
<span id="cb1-31"><a href="#cb1-31" aria-hidden="true" tabindex="-1"></a><span class="co">        """</span></span>
<span id="cb1-32"><a href="#cb1-32" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Calculate attention scores</span></span>
<span id="cb1-33"><a href="#cb1-33" aria-hidden="true" tabindex="-1"></a>        scores <span class="op">=</span> torch.matmul(Q, K.transpose(<span class="op">-</span><span class="dv">2</span>, <span class="op">-</span><span class="dv">1</span>)) <span class="op">/</span> math.sqrt(<span class="va">self</span>.d_k)</span>
<span id="cb1-34"><a href="#cb1-34" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-35"><a href="#cb1-35" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Apply mask if provided</span></span>
<span id="cb1-36"><a href="#cb1-36" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> mask <span class="kw">is</span> <span class="kw">not</span> <span class="va">None</span>:</span>
<span id="cb1-37"><a href="#cb1-37" aria-hidden="true" tabindex="-1"></a>            scores <span class="op">=</span> scores.masked_fill(mask <span class="op">==</span> <span class="dv">0</span>, <span class="op">-</span><span class="fl">1e9</span>)</span>
<span id="cb1-38"><a href="#cb1-38" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-39"><a href="#cb1-39" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Softmax normalization</span></span>
<span id="cb1-40"><a href="#cb1-40" aria-hidden="true" tabindex="-1"></a>        attention_weights <span class="op">=</span> F.softmax(scores, dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb1-41"><a href="#cb1-41" aria-hidden="true" tabindex="-1"></a>        attention_weights <span class="op">=</span> <span class="va">self</span>.dropout(attention_weights)</span>
<span id="cb1-42"><a href="#cb1-42" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-43"><a href="#cb1-43" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Apply attention to values</span></span>
<span id="cb1-44"><a href="#cb1-44" aria-hidden="true" tabindex="-1"></a>        output <span class="op">=</span> torch.matmul(attention_weights, V)</span>
<span id="cb1-45"><a href="#cb1-45" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-46"><a href="#cb1-46" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> output, attention_weights</span>
<span id="cb1-47"><a href="#cb1-47" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb1-48"><a href="#cb1-48" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, query, key, value, mask<span class="op">=</span><span class="va">None</span>):</span>
<span id="cb1-49"><a href="#cb1-49" aria-hidden="true" tabindex="-1"></a>        batch_size, seq_len <span class="op">=</span> query.size(<span class="dv">0</span>), query.size(<span class="dv">1</span>)</span>
<span id="cb1-50"><a href="#cb1-50" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-51"><a href="#cb1-51" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Linear projections and reshape for multi-head</span></span>
<span id="cb1-52"><a href="#cb1-52" aria-hidden="true" tabindex="-1"></a>        Q <span class="op">=</span> <span class="va">self</span>.w_q(query).view(batch_size, seq_len, <span class="va">self</span>.num_heads, <span class="va">self</span>.d_k).transpose(<span class="dv">1</span>, <span class="dv">2</span>)</span>
<span id="cb1-53"><a href="#cb1-53" aria-hidden="true" tabindex="-1"></a>        K <span class="op">=</span> <span class="va">self</span>.w_k(key).view(batch_size, <span class="op">-</span><span class="dv">1</span>, <span class="va">self</span>.num_heads, <span class="va">self</span>.d_k).transpose(<span class="dv">1</span>, <span class="dv">2</span>)  <span class="co"># -1: key length may differ from query length</span></span>
<span id="cb1-54"><a href="#cb1-54" aria-hidden="true" tabindex="-1"></a>        V <span class="op">=</span> <span class="va">self</span>.w_v(value).view(batch_size, <span class="op">-</span><span class="dv">1</span>, <span class="va">self</span>.num_heads, <span class="va">self</span>.d_k).transpose(<span class="dv">1</span>, <span class="dv">2</span>)</span>
<span id="cb1-55"><a href="#cb1-55" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-56"><a href="#cb1-56" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Apply attention</span></span>
<span id="cb1-57"><a href="#cb1-57" aria-hidden="true" tabindex="-1"></a>        attention_output, attention_weights <span class="op">=</span> <span class="va">self</span>.scaled_dot_product_attention(Q, K, V, mask)</span>
<span id="cb1-58"><a href="#cb1-58" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-59"><a href="#cb1-59" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Concatenate heads</span></span>
<span id="cb1-60"><a href="#cb1-60" aria-hidden="true" tabindex="-1"></a>        attention_output <span class="op">=</span> attention_output.transpose(<span class="dv">1</span>, <span class="dv">2</span>).contiguous().view(</span>
<span id="cb1-61"><a href="#cb1-61" aria-hidden="true" tabindex="-1"></a>            batch_size, seq_len, <span class="va">self</span>.d_model</span>
<span id="cb1-62"><a href="#cb1-62" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb1-63"><a href="#cb1-63" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-64"><a href="#cb1-64" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Final linear projection</span></span>
<span id="cb1-65"><a href="#cb1-65" aria-hidden="true" tabindex="-1"></a>        output <span class="op">=</span> <span class="va">self</span>.w_o(attention_output)</span>
<span id="cb1-66"><a href="#cb1-66" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-67"><a href="#cb1-67" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> output, attention_weights</span>
<span id="cb1-68"><a href="#cb1-68" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-69"><a href="#cb1-69" aria-hidden="true" tabindex="-1"></a><span class="co"># Complete Transformer Block</span></span>
<span id="cb1-70"><a href="#cb1-70" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> TransformerBlock(nn.Module):</span>
<span id="cb1-71"><a href="#cb1-71" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, d_model, num_heads, d_ff, dropout<span class="op">=</span><span class="fl">0.1</span>):</span>
<span id="cb1-72"><a href="#cb1-72" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb1-73"><a href="#cb1-73" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.attention <span class="op">=</span> MultiHeadAttention(d_model, num_heads, dropout)</span>
<span id="cb1-74"><a href="#cb1-74" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.norm1 <span class="op">=</span> nn.LayerNorm(d_model)</span>
<span id="cb1-75"><a href="#cb1-75" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.norm2 <span class="op">=</span> nn.LayerNorm(d_model)</span>
<span id="cb1-76"><a href="#cb1-76" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-77"><a href="#cb1-77" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.feed_forward <span class="op">=</span> nn.Sequential(</span>
<span id="cb1-78"><a href="#cb1-78" aria-hidden="true" tabindex="-1"></a>            nn.Linear(d_model, d_ff),</span>
<span id="cb1-79"><a href="#cb1-79" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(),</span>
<span id="cb1-80"><a href="#cb1-80" aria-hidden="true" tabindex="-1"></a>            nn.Dropout(dropout),</span>
<span id="cb1-81"><a href="#cb1-81" aria-hidden="true" tabindex="-1"></a>            nn.Linear(d_ff, d_model)</span>
<span id="cb1-82"><a href="#cb1-82" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb1-83"><a href="#cb1-83" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-84"><a href="#cb1-84" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.dropout <span class="op">=</span> nn.Dropout(dropout)</span>
<span id="cb1-85"><a href="#cb1-85" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb1-86"><a href="#cb1-86" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x, mask<span class="op">=</span><span class="va">None</span>):</span>
<span id="cb1-87"><a href="#cb1-87" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Self-attention with residual connection</span></span>
<span id="cb1-88"><a href="#cb1-88" aria-hidden="true" tabindex="-1"></a>        attn_output, attn_weights <span class="op">=</span> <span class="va">self</span>.attention(x, x, x, mask)</span>
<span id="cb1-89"><a href="#cb1-89" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.norm1(x <span class="op">+</span> <span class="va">self</span>.dropout(attn_output))</span>
<span id="cb1-90"><a href="#cb1-90" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-91"><a href="#cb1-91" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Feed-forward with residual connection</span></span>
<span id="cb1-92"><a href="#cb1-92" aria-hidden="true" tabindex="-1"></a>        ff_output <span class="op">=</span> <span class="va">self</span>.feed_forward(x)</span>
<span id="cb1-93"><a href="#cb1-93" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.norm2(x <span class="op">+</span> <span class="va">self</span>.dropout(ff_output))</span>
<span id="cb1-94"><a href="#cb1-94" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-95"><a href="#cb1-95" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x, attn_weights</span>
<span id="cb1-96"><a href="#cb1-96" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-97"><a href="#cb1-97" aria-hidden="true" tabindex="-1"></a><span class="co"># Example usage</span></span>
<span id="cb1-98"><a href="#cb1-98" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> transformer_example():</span>
<span id="cb1-99"><a href="#cb1-99" aria-hidden="true" tabindex="-1"></a>    batch_size, seq_len, d_model <span class="op">=</span> <span class="dv">2</span>, <span class="dv">10</span>, <span class="dv">512</span></span>
<span id="cb1-100"><a href="#cb1-100" aria-hidden="true" tabindex="-1"></a>    num_heads, d_ff <span class="op">=</span> <span class="dv">8</span>, <span class="dv">2048</span></span>
<span id="cb1-101"><a href="#cb1-101" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb1-102"><a href="#cb1-102" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Create input</span></span>
<span id="cb1-103"><a href="#cb1-103" aria-hidden="true" tabindex="-1"></a>    x <span class="op">=</span> torch.randn(batch_size, seq_len, d_model)</span>
<span id="cb1-104"><a href="#cb1-104" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb1-105"><a href="#cb1-105" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Create transformer block</span></span>
<span id="cb1-106"><a href="#cb1-106" aria-hidden="true" tabindex="-1"></a>    transformer <span class="op">=</span> TransformerBlock(d_model, num_heads, d_ff)</span>
<span id="cb1-107"><a href="#cb1-107" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb1-108"><a href="#cb1-108" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Forward pass</span></span>
<span id="cb1-109"><a href="#cb1-109" aria-hidden="true" tabindex="-1"></a>    output, attention_weights <span class="op">=</span> transformer(x)</span>
<span id="cb1-110"><a href="#cb1-110" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb1-111"><a href="#cb1-111" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Input shape: </span><span class="sc">{</span>x<span class="sc">.</span>shape<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb1-112"><a href="#cb1-112" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Output shape: </span><span class="sc">{</span>output<span class="sc">.</span>shape<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb1-113"><a href="#cb1-113" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Attention weights shape: </span><span class="sc">{</span>attention_weights<span class="sc">.</span>shape<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb1-114"><a href="#cb1-114" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb1-115"><a href="#cb1-115" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> output, attention_weights</span></code></pre></div></div>
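<p>The <code>mask</code> argument in the code above blocks positions where <code>mask == 0</code>. A self-contained sketch of building a causal (look-ahead) mask with that convention, shaped <code>[1, 1, seq_len, seq_len]</code> so it broadcasts over the batch and head dimensions; <code>causal_mask</code> is a hypothetical helper name and the scores are random stand-ins for <code>QKᵀ/√d_k</code>.</p>

```python
import torch
import torch.nn.functional as F


def causal_mask(seq_len):
    # 1 = may attend, 0 = blocked; lower-triangular so token t sees tokens <= t
    return torch.tril(torch.ones(seq_len, seq_len)).view(1, 1, seq_len, seq_len)


torch.manual_seed(0)
scores = torch.randn(2, 8, 5, 5)   # [batch, heads, seq_len, seq_len]
mask = causal_mask(5)
masked = scores.masked_fill(mask == 0, -1e9)
weights = F.softmax(masked, dim=-1)

# Query position 0 can only attend to key position 0
print(weights[0, 0, 0])  # tensor([1., 0., 0., 0., 0.])
```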
</section>
<section id="positional-encoding-for-transformers" class="level3">
<h3 class="anchored" data-anchor-id="positional-encoding-for-transformers" id="positional-encoding-for-transformers">Positional Encoding for Transformers</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> PositionalEncoding(nn.Module):</span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, d_model, max_len<span class="op">=</span><span class="dv">5000</span>):</span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a>        pe <span class="op">=</span> torch.zeros(max_len, d_model)</span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a>        position <span class="op">=</span> torch.arange(<span class="dv">0</span>, max_len).unsqueeze(<span class="dv">1</span>).<span class="bu">float</span>()</span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a>        div_term <span class="op">=</span> torch.exp(torch.arange(<span class="dv">0</span>, d_model, <span class="dv">2</span>).<span class="bu">float</span>() <span class="op">*</span> </span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a>                           <span class="op">-</span>(math.log(<span class="fl">10000.0</span>) <span class="op">/</span> d_model))</span>
<span id="cb2-10"><a href="#cb2-10" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-11"><a href="#cb2-11" aria-hidden="true" tabindex="-1"></a>        pe[:, <span class="dv">0</span>::<span class="dv">2</span>] <span class="op">=</span> torch.sin(position <span class="op">*</span> div_term)</span>
<span id="cb2-12"><a href="#cb2-12" aria-hidden="true" tabindex="-1"></a>        pe[:, <span class="dv">1</span>::<span class="dv">2</span>] <span class="op">=</span> torch.cos(position <span class="op">*</span> div_term)</span>
<span id="cb2-13"><a href="#cb2-13" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-14"><a href="#cb2-14" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.register_buffer(<span class="st">'pe'</span>, pe.unsqueeze(<span class="dv">0</span>))</span>
<span id="cb2-15"><a href="#cb2-15" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-16"><a href="#cb2-16" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb2-17"><a href="#cb2-17" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x <span class="op">+</span> <span class="va">self</span>.pe[:, :x.size(<span class="dv">1</span>)]</span></code></pre></div></div>
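<p>A quick check of why sinusoidal encodings help with relative positions: because sin(a)sin(b) + cos(a)cos(b) = cos(a − b), the dot product between the encodings of positions p and p + k depends only on the offset k. The standalone <code>sinusoidal_table</code> helper is a hypothetical re-implementation of the same formula (in float64 for a precise comparison), so the check runs without the class above.</p>

```python
import math
import torch


def sinusoidal_table(max_len, d_model):
    # Same formula as PositionalEncoding: sin on even dims, cos on odd dims
    position = torch.arange(max_len, dtype=torch.float64).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float64)
                         * -(math.log(10000.0) / d_model))
    pe = torch.zeros(max_len, d_model, dtype=torch.float64)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe


pe = sinusoidal_table(max_len=100, d_model=64)
offset = 5
# Same offset at different absolute positions: both lines print the same
# value, up to floating-point rounding
print(torch.dot(pe[3], pe[3 + offset]).item())
print(torch.dot(pe[40], pe[40 + offset]).item())
```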
</section>
</section>
<section id="cnn-attention" class="level2">
<h2 class="anchored" data-anchor-id="cnn-attention" id="cnn-attention">CNN Attention</h2>
<section id="spatial-attention-mechanism" class="level3">
<h3 class="anchored" data-anchor-id="spatial-attention-mechanism" id="spatial-attention-mechanism">Spatial Attention Mechanism</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> SpatialAttention(nn.Module):</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, kernel_size<span class="op">=</span><span class="dv">7</span>):</span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.conv <span class="op">=</span> nn.Conv2d(<span class="dv">2</span>, <span class="dv">1</span>, kernel_size, padding<span class="op">=</span>kernel_size<span class="op">//</span><span class="dv">2</span>, bias<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.sigmoid <span class="op">=</span> nn.Sigmoid()</span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Compute spatial statistics</span></span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a>        avg_pool <span class="op">=</span> torch.mean(x, dim<span class="op">=</span><span class="dv">1</span>, keepdim<span class="op">=</span><span class="va">True</span>)  <span class="co"># [B, 1, H, W]</span></span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a>        max_pool, _ <span class="op">=</span> torch.<span class="bu">max</span>(x, dim<span class="op">=</span><span class="dv">1</span>, keepdim<span class="op">=</span><span class="va">True</span>)  <span class="co"># [B, 1, H, W]</span></span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Concatenate along channel dimension</span></span>
<span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a>        spatial_info <span class="op">=</span> torch.cat([avg_pool, max_pool], dim<span class="op">=</span><span class="dv">1</span>)  <span class="co"># [B, 2, H, W]</span></span>
<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-15"><a href="#cb3-15" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Generate attention map</span></span>
<span id="cb3-16"><a href="#cb3-16" aria-hidden="true" tabindex="-1"></a>        attention_map <span class="op">=</span> <span class="va">self</span>.conv(spatial_info)  <span class="co"># [B, 1, H, W]</span></span>
<span id="cb3-17"><a href="#cb3-17" aria-hidden="true" tabindex="-1"></a>        attention_map <span class="op">=</span> <span class="va">self</span>.sigmoid(attention_map)</span>
<span id="cb3-18"><a href="#cb3-18" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-19"><a href="#cb3-19" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Apply attention</span></span>
<span id="cb3-20"><a href="#cb3-20" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x <span class="op">*</span> attention_map</span>
<span id="cb3-21"><a href="#cb3-21" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-22"><a href="#cb3-22" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> ChannelAttention(nn.Module):</span>
<span id="cb3-23"><a href="#cb3-23" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, in_channels, reduction_ratio<span class="op">=</span><span class="dv">16</span>):</span>
<span id="cb3-24"><a href="#cb3-24" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb3-25"><a href="#cb3-25" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.avg_pool <span class="op">=</span> nn.AdaptiveAvgPool2d(<span class="dv">1</span>)</span>
<span id="cb3-26"><a href="#cb3-26" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.max_pool <span class="op">=</span> nn.AdaptiveMaxPool2d(<span class="dv">1</span>)</span>
<span id="cb3-27"><a href="#cb3-27" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-28"><a href="#cb3-28" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.fc <span class="op">=</span> nn.Sequential(</span>
<span id="cb3-29"><a href="#cb3-29" aria-hidden="true" tabindex="-1"></a>            nn.Linear(in_channels, in_channels <span class="op">//</span> reduction_ratio, bias<span class="op">=</span><span class="va">False</span>),</span>
<span id="cb3-30"><a href="#cb3-30" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(),</span>
<span id="cb3-31"><a href="#cb3-31" aria-hidden="true" tabindex="-1"></a>            nn.Linear(in_channels <span class="op">//</span> reduction_ratio, in_channels, bias<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb3-32"><a href="#cb3-32" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb3-33"><a href="#cb3-33" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-34"><a href="#cb3-34" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.sigmoid <span class="op">=</span> nn.Sigmoid()</span>
<span id="cb3-35"><a href="#cb3-35" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-36"><a href="#cb3-36" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb3-37"><a href="#cb3-37" aria-hidden="true" tabindex="-1"></a>        b, c, h, w <span class="op">=</span> x.size()</span>
<span id="cb3-38"><a href="#cb3-38" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-39"><a href="#cb3-39" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Global average pooling and max pooling</span></span>
<span id="cb3-40"><a href="#cb3-40" aria-hidden="true" tabindex="-1"></a>        avg_pool <span class="op">=</span> <span class="va">self</span>.avg_pool(x).view(b, c)</span>
<span id="cb3-41"><a href="#cb3-41" aria-hidden="true" tabindex="-1"></a>        max_pool <span class="op">=</span> <span class="va">self</span>.max_pool(x).view(b, c)</span>
<span id="cb3-42"><a href="#cb3-42" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-43"><a href="#cb3-43" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Channel attention</span></span>
<span id="cb3-44"><a href="#cb3-44" aria-hidden="true" tabindex="-1"></a>        avg_out <span class="op">=</span> <span class="va">self</span>.fc(avg_pool)</span>
<span id="cb3-45"><a href="#cb3-45" aria-hidden="true" tabindex="-1"></a>        max_out <span class="op">=</span> <span class="va">self</span>.fc(max_pool)</span>
<span id="cb3-46"><a href="#cb3-46" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-47"><a href="#cb3-47" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Combine and apply sigmoid</span></span>
<span id="cb3-48"><a href="#cb3-48" aria-hidden="true" tabindex="-1"></a>        channel_attention <span class="op">=</span> <span class="va">self</span>.sigmoid(avg_out <span class="op">+</span> max_out).view(b, c, <span class="dv">1</span>, <span class="dv">1</span>)</span>
<span id="cb3-49"><a href="#cb3-49" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-50"><a href="#cb3-50" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x <span class="op">*</span> channel_attention</span>
<span id="cb3-51"><a href="#cb3-51" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-52"><a href="#cb3-52" aria-hidden="true" tabindex="-1"></a><span class="co"># CBAM (Convolutional Block Attention Module)</span></span>
<span id="cb3-53"><a href="#cb3-53" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> CBAM(nn.Module):</span>
<span id="cb3-54"><a href="#cb3-54" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, in_channels, reduction_ratio<span class="op">=</span><span class="dv">16</span>, kernel_size<span class="op">=</span><span class="dv">7</span>):</span>
<span id="cb3-55"><a href="#cb3-55" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb3-56"><a href="#cb3-56" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.channel_attention <span class="op">=</span> ChannelAttention(in_channels, reduction_ratio)</span>
<span id="cb3-57"><a href="#cb3-57" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.spatial_attention <span class="op">=</span> SpatialAttention(kernel_size)</span>
<span id="cb3-58"><a href="#cb3-58" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-59"><a href="#cb3-59" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb3-60"><a href="#cb3-60" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Apply channel attention first</span></span>
<span id="cb3-61"><a href="#cb3-61" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.channel_attention(x)</span>
<span id="cb3-62"><a href="#cb3-62" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Then apply spatial attention</span></span>
<span id="cb3-63"><a href="#cb3-63" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.spatial_attention(x)</span>
<span id="cb3-64"><a href="#cb3-64" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x</span>
<span id="cb3-65"><a href="#cb3-65" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-66"><a href="#cb3-66" aria-hidden="true" tabindex="-1"></a><span class="co"># Self-Attention for CNNs</span></span>
<span id="cb3-67"><a href="#cb3-67" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> SelfAttention2D(nn.Module):</span>
<span id="cb3-68"><a href="#cb3-68" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, in_channels):</span>
<span id="cb3-69"><a href="#cb3-69" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb3-70"><a href="#cb3-70" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.in_channels <span class="op">=</span> in_channels</span>
<span id="cb3-71"><a href="#cb3-71" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-72"><a href="#cb3-72" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.query_conv <span class="op">=</span> nn.Conv2d(in_channels, in_channels <span class="op">//</span> <span class="dv">8</span>, <span class="dv">1</span>)</span>
<span id="cb3-73"><a href="#cb3-73" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.key_conv <span class="op">=</span> nn.Conv2d(in_channels, in_channels <span class="op">//</span> <span class="dv">8</span>, <span class="dv">1</span>)</span>
<span id="cb3-74"><a href="#cb3-74" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.value_conv <span class="op">=</span> nn.Conv2d(in_channels, in_channels, <span class="dv">1</span>)</span>
<span id="cb3-75"><a href="#cb3-75" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-76"><a href="#cb3-76" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.gamma <span class="op">=</span> nn.Parameter(torch.zeros(<span class="dv">1</span>))</span>
<span id="cb3-77"><a href="#cb3-77" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.softmax <span class="op">=</span> nn.Softmax(dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb3-78"><a href="#cb3-78" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-79"><a href="#cb3-79" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb3-80"><a href="#cb3-80" aria-hidden="true" tabindex="-1"></a>        batch_size, channels, height, width <span class="op">=</span> x.size()</span>
<span id="cb3-81"><a href="#cb3-81" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-82"><a href="#cb3-82" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Generate Q, K, V</span></span>
<span id="cb3-83"><a href="#cb3-83" aria-hidden="true" tabindex="-1"></a>        proj_query <span class="op">=</span> <span class="va">self</span>.query_conv(x).view(batch_size, <span class="op">-</span><span class="dv">1</span>, width <span class="op">*</span> height).permute(<span class="dv">0</span>, <span class="dv">2</span>, <span class="dv">1</span>)</span>
<span id="cb3-84"><a href="#cb3-84" aria-hidden="true" tabindex="-1"></a>        proj_key <span class="op">=</span> <span class="va">self</span>.key_conv(x).view(batch_size, <span class="op">-</span><span class="dv">1</span>, width <span class="op">*</span> height)</span>
<span id="cb3-85"><a href="#cb3-85" aria-hidden="true" tabindex="-1"></a>        proj_value <span class="op">=</span> <span class="va">self</span>.value_conv(x).view(batch_size, <span class="op">-</span><span class="dv">1</span>, width <span class="op">*</span> height)</span>
<span id="cb3-86"><a href="#cb3-86" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-87"><a href="#cb3-87" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Compute attention</span></span>
<span id="cb3-88"><a href="#cb3-88" aria-hidden="true" tabindex="-1"></a>        energy <span class="op">=</span> torch.bmm(proj_query, proj_key)</span>
<span id="cb3-89"><a href="#cb3-89" aria-hidden="true" tabindex="-1"></a>        attention <span class="op">=</span> <span class="va">self</span>.softmax(energy)</span>
<span id="cb3-90"><a href="#cb3-90" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-91"><a href="#cb3-91" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Apply attention to values</span></span>
<span id="cb3-92"><a href="#cb3-92" aria-hidden="true" tabindex="-1"></a>        out <span class="op">=</span> torch.bmm(proj_value, attention.permute(<span class="dv">0</span>, <span class="dv">2</span>, <span class="dv">1</span>))</span>
<span id="cb3-93"><a href="#cb3-93" aria-hidden="true" tabindex="-1"></a>        out <span class="op">=</span> out.view(batch_size, channels, height, width)</span>
<span id="cb3-94"><a href="#cb3-94" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-95"><a href="#cb3-95" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Residual connection with learnable weight</span></span>
<span id="cb3-96"><a href="#cb3-96" aria-hidden="true" tabindex="-1"></a>        out <span class="op">=</span> <span class="va">self</span>.gamma <span class="op">*</span> out <span class="op">+</span> x</span>
<span id="cb3-97"><a href="#cb3-97" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-98"><a href="#cb3-98" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> out</span>
<span id="cb3-99"><a href="#cb3-99" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-100"><a href="#cb3-100" aria-hidden="true" tabindex="-1"></a><span class="co"># CNN with Attention</span></span>
<span id="cb3-101"><a href="#cb3-101" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> AttentionCNN(nn.Module):</span>
<span id="cb3-102"><a href="#cb3-102" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, num_classes<span class="op">=</span><span class="dv">10</span>):</span>
<span id="cb3-103"><a href="#cb3-103" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb3-104"><a href="#cb3-104" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-105"><a href="#cb3-105" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.conv1 <span class="op">=</span> nn.Conv2d(<span class="dv">3</span>, <span class="dv">64</span>, <span class="dv">3</span>, padding<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb3-106"><a href="#cb3-106" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.cbam1 <span class="op">=</span> CBAM(<span class="dv">64</span>)</span>
<span id="cb3-107"><a href="#cb3-107" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-108"><a href="#cb3-108" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.conv2 <span class="op">=</span> nn.Conv2d(<span class="dv">64</span>, <span class="dv">128</span>, <span class="dv">3</span>, padding<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb3-109"><a href="#cb3-109" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.cbam2 <span class="op">=</span> CBAM(<span class="dv">128</span>)</span>
<span id="cb3-110"><a href="#cb3-110" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-111"><a href="#cb3-111" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.conv3 <span class="op">=</span> nn.Conv2d(<span class="dv">128</span>, <span class="dv">256</span>, <span class="dv">3</span>, padding<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb3-112"><a href="#cb3-112" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.self_attention <span class="op">=</span> SelfAttention2D(<span class="dv">256</span>)</span>
<span id="cb3-113"><a href="#cb3-113" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-114"><a href="#cb3-114" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.pool <span class="op">=</span> nn.AdaptiveAvgPool2d(<span class="dv">1</span>)</span>
<span id="cb3-115"><a href="#cb3-115" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.classifier <span class="op">=</span> nn.Linear(<span class="dv">256</span>, num_classes)</span>
<span id="cb3-116"><a href="#cb3-116" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-117"><a href="#cb3-117" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb3-118"><a href="#cb3-118" aria-hidden="true" tabindex="-1"></a>        <span class="co"># First block</span></span>
<span id="cb3-119"><a href="#cb3-119" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> F.relu(<span class="va">self</span>.conv1(x))</span>
<span id="cb3-120"><a href="#cb3-120" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.cbam1(x)</span>
<span id="cb3-121"><a href="#cb3-121" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> F.max_pool2d(x, <span class="dv">2</span>)</span>
<span id="cb3-122"><a href="#cb3-122" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-123"><a href="#cb3-123" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Second block</span></span>
<span id="cb3-124"><a href="#cb3-124" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> F.relu(<span class="va">self</span>.conv2(x))</span>
<span id="cb3-125"><a href="#cb3-125" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.cbam2(x)</span>
<span id="cb3-126"><a href="#cb3-126" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> F.max_pool2d(x, <span class="dv">2</span>)</span>
<span id="cb3-127"><a href="#cb3-127" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-128"><a href="#cb3-128" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Third block with self-attention</span></span>
<span id="cb3-129"><a href="#cb3-129" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> F.relu(<span class="va">self</span>.conv3(x))</span>
<span id="cb3-130"><a href="#cb3-130" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.self_attention(x)</span>
<span id="cb3-131"><a href="#cb3-131" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> F.max_pool2d(x, <span class="dv">2</span>)</span>
<span id="cb3-132"><a href="#cb3-132" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-133"><a href="#cb3-133" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Classification</span></span>
<span id="cb3-134"><a href="#cb3-134" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.pool(x)</span>
<span id="cb3-135"><a href="#cb3-135" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> x.view(x.size(<span class="dv">0</span>), <span class="op">-</span><span class="dv">1</span>)</span>
<span id="cb3-136"><a href="#cb3-136" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.classifier(x)</span>
<span id="cb3-137"><a href="#cb3-137" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-138"><a href="#cb3-138" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x</span>
<span id="cb3-139"><a href="#cb3-139" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-140"><a href="#cb3-140" aria-hidden="true" tabindex="-1"></a><span class="co"># Example usage</span></span>
<span id="cb3-141"><a href="#cb3-141" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> cnn_attention_example():</span>
<span id="cb3-142"><a href="#cb3-142" aria-hidden="true" tabindex="-1"></a>    batch_size <span class="op">=</span> <span class="dv">4</span></span>
<span id="cb3-143"><a href="#cb3-143" aria-hidden="true" tabindex="-1"></a>    x <span class="op">=</span> torch.randn(batch_size, <span class="dv">3</span>, <span class="dv">224</span>, <span class="dv">224</span>)</span>
<span id="cb3-144"><a href="#cb3-144" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-145"><a href="#cb3-145" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> AttentionCNN(num_classes<span class="op">=</span><span class="dv">1000</span>)</span>
<span id="cb3-146"><a href="#cb3-146" aria-hidden="true" tabindex="-1"></a>    output <span class="op">=</span> model(x)</span>
<span id="cb3-147"><a href="#cb3-147" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-148"><a href="#cb3-148" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Input shape: </span><span class="sc">{</span>x<span class="sc">.</span>shape<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb3-149"><a href="#cb3-149" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Output shape: </span><span class="sc">{</span>output<span class="sc">.</span>shape<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb3-150"><a href="#cb3-150" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-151"><a href="#cb3-151" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> output</span></code></pre></div></div>
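<p>As a rough sanity check on <code>cnn_attention_example</code>, the sketch below traces the spatial size through <code>AttentionCNN</code> in plain Python, so no PyTorch is needed. It assumes, as in the code above, that each 3×3 convolution uses <code>padding=1</code> (size-preserving) and that each <code>F.max_pool2d(x, 2)</code> halves height and width. The helper name <code>trace_attention_cnn</code> is hypothetical. The trace shows that <code>SelfAttention2D</code> operates on a 56×56 map, so its <code>energy</code> matrix has 3136×3136 entries per image.</p>

```python
def trace_attention_cnn(height, width, batch_size=4, bytes_per_float=4):
    """Hypothetical helper: trace spatial sizes through AttentionCNN.

    Assumes 3x3 convs with padding=1 keep H and W unchanged and each
    F.max_pool2d(x, 2) halves them (floor division, as PyTorch does).
    """
    h, w = height, width
    h, w = h // 2, w // 2  # block 1: conv1 + CBAM(64), then max-pool
    h, w = h // 2, w // 2  # block 2: conv2 + CBAM(128), then max-pool
    # block 3: conv3 keeps the size, so SelfAttention2D sees an h x w map
    num_positions = h * w
    energy_entries = num_positions * num_positions  # one [N, N] matrix per image
    energy_bytes = batch_size * energy_entries * bytes_per_float
    return {
        "self_attention_map": (h, w),
        "num_positions": num_positions,
        "energy_entries_per_image": energy_entries,
        "energy_mib_per_batch": energy_bytes / 2**20,
    }


stats = trace_attention_cnn(224, 224)
print(stats)
```

<p>For the batch of 4 float32 images used in the example, the <code>energy</code> tensor alone takes about 150 MiB, and during the forward pass the softmax output is a second tensor of the same size. Because the cost depends on (H×W)², doubling the input resolution multiplies these figures by 16.</p>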
</section>
</section>
<section id="key-differences" class="level2">
<h2 class="anchored" data-anchor-id="key-differences" id="key-differences">Key Differences</h2>
<section id="computational-complexity" class="level3">
<h3 class="anchored" data-anchor-id="computational-complexity" id="computational-complexity">1. Computational Complexity</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> attention_complexity_comparison():</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a><span class="co">    Compare computational complexity of different attention mechanisms</span></span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Transformer Self-Attention: O(n²d) where n=sequence length, d=dimension</span></span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> transformer_complexity(seq_len, d_model):</span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> seq_len <span class="op">*</span> seq_len <span class="op">*</span> d_model</span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a>    <span class="co"># CNN Spatial Attention (CBAM): O(HW(C + k²)) for C channels and a k x k conv; linear in H*W</span></span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> spatial_attention_complexity(height, width, channels, kernel_size<span class="op">=</span><span class="dv">7</span>):</span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> height <span class="op">*</span> width <span class="op">*</span> (channels <span class="op">+</span> kernel_size <span class="op">*</span> kernel_size)</span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-14"><a href="#cb4-14" aria-hidden="true" tabindex="-1"></a>    <span class="co"># CNN Channel Attention (CBAM): O(CHW + C²/r), global pooling plus a bottleneck MLP with reduction r</span></span>
<span id="cb4-15"><a href="#cb4-15" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> channel_attention_complexity(channels, height, width, reduction_ratio<span class="op">=</span><span class="dv">16</span>):</span>
<span id="cb4-16"><a href="#cb4-16" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> channels <span class="op">*</span> height <span class="op">*</span> width <span class="op">+</span> channels <span class="op">*</span> channels <span class="op">//</span> reduction_ratio</span>
<span id="cb4-17"><a href="#cb4-17" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-18"><a href="#cb4-18" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Example calculations</span></span>
<span id="cb4-19"><a href="#cb4-19" aria-hidden="true" tabindex="-1"></a>    seq_len, d_model <span class="op">=</span> <span class="dv">512</span>, <span class="dv">512</span></span>
<span id="cb4-20"><a href="#cb4-20" aria-hidden="true" tabindex="-1"></a>    height, width, channels <span class="op">=</span> <span class="dv">224</span>, <span class="dv">224</span>, <span class="dv">256</span></span>
<span id="cb4-21"><a href="#cb4-21" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-22"><a href="#cb4-22" aria-hidden="true" tabindex="-1"></a>    transformer_ops <span class="op">=</span> transformer_complexity(seq_len, d_model)</span>
<span id="cb4-23"><a href="#cb4-23" aria-hidden="true" tabindex="-1"></a>    spatial_ops <span class="op">=</span> spatial_attention_complexity(height, width, channels)</span>
<span id="cb4-24"><a href="#cb4-24" aria-hidden="true" tabindex="-1"></a>    channel_ops <span class="op">=</span> channel_attention_complexity(channels, height, width)</span>
<span id="cb4-25"><a href="#cb4-25" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-26"><a href="#cb4-26" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Transformer attention operations: </span><span class="sc">{</span>transformer_ops<span class="sc">:,}</span><span class="ss">"</span>)</span>
<span id="cb4-27"><a href="#cb4-27" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"CNN spatial attention operations: </span><span class="sc">{</span>spatial_ops<span class="sc">:,}</span><span class="ss">"</span>)</span>
<span id="cb4-28"><a href="#cb4-28" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"CNN channel attention operations: </span><span class="sc">{</span>channel_ops<span class="sc">:,}</span><span class="ss">"</span>)</span>
<span id="cb4-29"><a href="#cb4-29" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-30"><a href="#cb4-30" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> {</span>
<span id="cb4-31"><a href="#cb4-31" aria-hidden="true" tabindex="-1"></a>        <span class="st">'transformer'</span>: transformer_ops,</span>
<span id="cb4-32"><a href="#cb4-32" aria-hidden="true" tabindex="-1"></a>        <span class="st">'spatial'</span>: spatial_ops,</span>
<span id="cb4-33"><a href="#cb4-33" aria-hidden="true" tabindex="-1"></a>        <span class="st">'channel'</span>: channel_ops</span>
<span id="cb4-34"><a href="#cb4-34" aria-hidden="true" tabindex="-1"></a>    }</span></code></pre></div></div>
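<p>The counts above treat CBAM-style attention as linear in the number of pixels. The <code>SelfAttention2D</code> block from the previous section behaves differently. It compares every spatial position with every other one, so, like Transformer self-attention, its cost is quadratic in H×W. The sketch below gives a rough multiply-accumulate count. It assumes a query/key width of C/8 (as in <code>SelfAttention2D</code>), a 7×7 spatial conv, and a reduction ratio of 16 for CBAM. It ignores the 1×1 projections, biases, and the final element-wise products. All function names are hypothetical. The point is how each cost grows when the resolution doubles.</p>

```python
def cbam_macs(h, w, c, kernel_size=7, reduction=16):
    """Rough cost of CBAM channel + spatial attention (linear in h*w)."""
    hw = h * w
    channel = 2 * c * hw + 4 * c * (c // reduction)  # avg+max pooling, shared MLP on 2 branches
    spatial = 2 * c * hw + 2 * kernel_size**2 * hw   # mean+max over channels, 2->1 conv
    return channel + spatial


def non_local_macs(h, w, c):
    """Rough cost of SelfAttention2D: energy (N x N x C/8) plus output (C x N x N)."""
    n = h * w
    return n * n * (c // 8) + c * n * n


def growth_when_doubled(cost_fn, h, w, c):
    """Ratio of the cost at (2h, 2w) to the cost at (h, w)."""
    return cost_fn(2 * h, 2 * w, c) / cost_fn(h, w, c)


for name, fn in [("CBAM", cbam_macs), ("SelfAttention2D", non_local_macs)]:
    ratio = growth_when_doubled(fn, 56, 56, 256)
    print(f"{name}: x{ratio:.2f} cost when H and W double")
```

<p>On a 56×56×256 feature map, CBAM grows by slightly under 4× per doubling because it is linear in H×W. The non-local block grows by exactly 16×. This is one reason to place such blocks on downsampled feature maps, as <code>AttentionCNN</code> does in its third block.</p>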
</section>
<section id="information-flow-patterns" class="level3">
<h3 class="anchored" data-anchor-id="information-flow-patterns" id="information-flow-patterns">2. Information Flow Patterns</h3>
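<p>One statistic computed in the analysis code here is the entropy of each attention row, −Σ p log p. Two reference points help when reading it. A row that spreads its weight evenly over n positions has entropy ln n, and a row that puts all of its weight on a single position has entropy 0. The plain-Python helper below (<code>row_entropy</code>, a hypothetical name) mirrors that formula, including the <code>1e-8</code> stabilizer, and checks both cases.</p>

```python
import math


def row_entropy(probs, eps=1e-8):
    """Entropy -sum(p * log(p + eps)) of one attention row, mirroring analyze_transformer_attention."""
    return -sum(p * math.log(p + eps) for p in probs)


n = 8
uniform = [1.0 / n] * n            # attention spread evenly over 8 positions
one_hot = [1.0] + [0.0] * (n - 1)  # all attention on a single position

print(row_entropy(uniform), math.log(n))  # both about 2.079
print(row_entropy(one_hot))               # about 0 (roughly -1e-8 because of eps)
```

<p>Lower entropy therefore indicates more focused attention. Comparing the value against ln(seq_len) shows how close a head is to attending uniformly.</p>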
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> AttentionAnalysis:</span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a>    <span class="at">@staticmethod</span></span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> analyze_transformer_attention(attention_weights):</span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""</span></span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a><span class="co">        Analyze attention patterns in Transformers</span></span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a><span class="co">        Args:</span></span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a><span class="co">            attention_weights: [batch_size, num_heads, seq_len, seq_len]</span></span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a><span class="co">        """</span></span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a>        batch_size, num_heads, seq_len, _ <span class="op">=</span> attention_weights.shape</span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Average attention across heads</span></span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a>        avg_attention <span class="op">=</span> attention_weights.mean(dim<span class="op">=</span><span class="dv">1</span>)  <span class="co"># [batch_size, seq_len, seq_len]</span></span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Compute attention statistics</span></span>
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a>        max_attention <span class="op">=</span> avg_attention.<span class="bu">max</span>(dim<span class="op">=-</span><span class="dv">1</span>)[<span class="dv">0</span>]  <span class="co"># Max attention per position</span></span>
<span id="cb5-16"><a href="#cb5-16" aria-hidden="true" tabindex="-1"></a>        attention_entropy <span class="op">=</span> <span class="op">-</span>torch.<span class="bu">sum</span>(avg_attention <span class="op">*</span> torch.log(avg_attention <span class="op">+</span> <span class="fl">1e-8</span>), dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb5-17"><a href="#cb5-17" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-18"><a href="#cb5-18" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {</span>
<span id="cb5-19"><a href="#cb5-19" aria-hidden="true" tabindex="-1"></a>            <span class="st">'max_attention'</span>: max_attention,</span>
<span id="cb5-20"><a href="#cb5-20" aria-hidden="true" tabindex="-1"></a>            <span class="st">'attention_entropy'</span>: attention_entropy,</span>
<span id="cb5-21"><a href="#cb5-21" aria-hidden="true" tabindex="-1"></a>            <span class="st">'global_connectivity'</span>: <span class="va">True</span>,  <span class="co"># All positions can attend to all others</span></span>
<span id="cb5-22"><a href="#cb5-22" aria-hidden="true" tabindex="-1"></a>            <span class="st">'pattern'</span>: <span class="st">'sequence-to-sequence'</span></span>
<span id="cb5-23"><a href="#cb5-23" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb5-24"><a href="#cb5-24" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-25"><a href="#cb5-25" aria-hidden="true" tabindex="-1"></a>    <span class="at">@staticmethod</span></span>
<span id="cb5-26"><a href="#cb5-26" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> analyze_cnn_attention(feature_map, attention_map):</span>
<span id="cb5-27"><a href="#cb5-27" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""</span></span>
<span id="cb5-28"><a href="#cb5-28" aria-hidden="true" tabindex="-1"></a><span class="co">        Analyze attention patterns in CNNs</span></span>
<span id="cb5-29"><a href="#cb5-29" aria-hidden="true" tabindex="-1"></a><span class="co">        Args:</span></span>
<span id="cb5-30"><a href="#cb5-30" aria-hidden="true" tabindex="-1"></a><span class="co">            feature_map: [batch_size, channels, height, width]</span></span>
<span id="cb5-31"><a href="#cb5-31" aria-hidden="true" tabindex="-1"></a><span class="co">            attention_map: [batch_size, 1, height, width] or [batch_size, channels, 1, 1]</span></span>
<span id="cb5-32"><a href="#cb5-32" aria-hidden="true" tabindex="-1"></a><span class="co">        """</span></span>
<span id="cb5-33"><a href="#cb5-33" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> attention_map.dim() <span class="op">==</span> <span class="dv">4</span> <span class="kw">and</span> attention_map.shape[<span class="op">-</span><span class="dv">2</span>:] <span class="op">==</span> (<span class="dv">1</span>, <span class="dv">1</span>):  <span class="co"># [B, C, 1, 1]</span></span>
<span id="cb5-34"><a href="#cb5-34" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Channel attention</span></span>
<span id="cb5-35"><a href="#cb5-35" aria-hidden="true" tabindex="-1"></a>            attention_type <span class="op">=</span> <span class="st">'channel'</span></span>
<span id="cb5-36"><a href="#cb5-36" aria-hidden="true" tabindex="-1"></a>            local_connectivity <span class="op">=</span> <span class="va">False</span></span>
<span id="cb5-37"><a href="#cb5-37" aria-hidden="true" tabindex="-1"></a>        <span class="cf">else</span>:</span>
<span id="cb5-38"><a href="#cb5-38" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Spatial attention</span></span>
<span id="cb5-39"><a href="#cb5-39" aria-hidden="true" tabindex="-1"></a>            attention_type <span class="op">=</span> <span class="st">'spatial'</span></span>
<span id="cb5-40"><a href="#cb5-40" aria-hidden="true" tabindex="-1"></a>            local_connectivity <span class="op">=</span> <span class="va">True</span></span>
<span id="cb5-41"><a href="#cb5-41" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-42"><a href="#cb5-42" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {</span>
<span id="cb5-43"><a href="#cb5-43" aria-hidden="true" tabindex="-1"></a>            <span class="st">'attention_type'</span>: attention_type,</span>
<span id="cb5-44"><a href="#cb5-44" aria-hidden="true" tabindex="-1"></a>            <span class="st">'local_connectivity'</span>: local_connectivity,</span>
<span id="cb5-45"><a href="#cb5-45" aria-hidden="true" tabindex="-1"></a>            <span class="st">'pattern'</span>: <span class="st">'spatial-hierarchy'</span> <span class="cf">if</span> attention_type <span class="op">==</span> <span class="st">'spatial'</span> <span class="cf">else</span> <span class="st">'channel-selection'</span></span>
<span id="cb5-46"><a href="#cb5-46" aria-hidden="true" tabindex="-1"></a>        }</span></code></pre></div></div>
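<p>The <code>attention_entropy</code> statistic is easier to interpret on a single row. The pure-Python sketch below (no PyTorch needed) applies the same formula used above, -Σ p·log(p + 1e-8), to one attention distribution. A uniform row over four positions reaches the maximum possible entropy, ln 4, while a row concentrated on one position drops close to zero. The helper name <code>row_entropy</code> is illustrative and not part of the code above.</p>

```python
import math


def row_entropy(row, eps=1e-8):
    """Entropy (in nats) of one attention distribution.

    Mirrors -sum(p * log(p + eps)) from analyze_transformer_attention,
    applied to a single row instead of a whole tensor.
    """
    return -sum(p * math.log(p + eps) for p in row)


uniform = [0.25, 0.25, 0.25, 0.25]  # attends equally to all four positions
peaked = [0.97, 0.01, 0.01, 0.01]   # attends almost entirely to position 0

print(f"uniform: max={max(uniform):.2f}, entropy={row_entropy(uniform):.4f}")  # entropy ≈ ln 4 ≈ 1.3863
print(f"peaked:  max={max(peaked):.2f}, entropy={row_entropy(peaked):.4f}")    # entropy ≈ 0.1677
```

<p>Lower entropy means sharper, more selective attention, and <code>max_attention</code> moves in the opposite direction (0.25 for the uniform row, 0.97 for the peaked one).</p>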
</section>
</section>
<section id="performance-comparison" class="level2">
<h2 class="anchored" data-anchor-id="performance-comparison" id="performance-comparison">Performance Comparison</h2>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn.functional <span class="im">as</span> F</span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> PerformanceBenchmark:</span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>):</span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.device <span class="op">=</span> torch.device(<span class="st">'cuda'</span> <span class="cf">if</span> torch.cuda.is_available() <span class="cf">else</span> <span class="st">'cpu'</span>)</span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> benchmark_transformer_attention(<span class="va">self</span>, batch_size<span class="op">=</span><span class="dv">32</span>, seq_len<span class="op">=</span><span class="dv">512</span>, d_model<span class="op">=</span><span class="dv">512</span>, num_heads<span class="op">=</span><span class="dv">8</span>):</span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Return mean seconds per forward pass of Transformer multi-head attention (inference mode)"""</span></span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a>        model <span class="op">=</span> MultiHeadAttention(d_model, num_heads).to(<span class="va">self</span>.device).<span class="bu">eval</span>()</span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> torch.randn(batch_size, seq_len, d_model, device<span class="op">=</span><span class="va">self</span>.device)</span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> torch.no_grad():  <span class="co"># inference only: no autograd graph is built</span></span>
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Warmup</span></span>
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">10</span>):</span>
<span id="cb6-15"><a href="#cb6-15" aria-hidden="true" tabindex="-1"></a>                _ <span class="op">=</span> model(x, x, x)</span>
<span id="cb6-16"><a href="#cb6-16" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb6-17"><a href="#cb6-17" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Benchmark</span></span>
<span id="cb6-18"><a href="#cb6-18" aria-hidden="true" tabindex="-1"></a>            torch.cuda.synchronize() <span class="cf">if</span> torch.cuda.is_available() <span class="cf">else</span> <span class="va">None</span></span>
<span id="cb6-19"><a href="#cb6-19" aria-hidden="true" tabindex="-1"></a>            start_time <span class="op">=</span> time.perf_counter()</span>
<span id="cb6-20"><a href="#cb6-20" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb6-21"><a href="#cb6-21" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">100</span>):</span>
<span id="cb6-22"><a href="#cb6-22" aria-hidden="true" tabindex="-1"></a>                output, _ <span class="op">=</span> model(x, x, x)</span>
<span id="cb6-23"><a href="#cb6-23" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb6-24"><a href="#cb6-24" aria-hidden="true" tabindex="-1"></a>            torch.cuda.synchronize() <span class="cf">if</span> torch.cuda.is_available() <span class="cf">else</span> <span class="va">None</span></span>
<span id="cb6-25"><a href="#cb6-25" aria-hidden="true" tabindex="-1"></a>            end_time <span class="op">=</span> time.perf_counter()</span>
<span id="cb6-26"><a href="#cb6-26" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-27"><a href="#cb6-27" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> (end_time <span class="op">-</span> start_time) <span class="op">/</span> <span class="dv">100</span></span>
<span id="cb6-28"><a href="#cb6-28" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-29"><a href="#cb6-29" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> benchmark_cnn_attention(<span class="va">self</span>, batch_size<span class="op">=</span><span class="dv">32</span>, channels<span class="op">=</span><span class="dv">256</span>, height<span class="op">=</span><span class="dv">56</span>, width<span class="op">=</span><span class="dv">56</span>):</span>
<span id="cb6-30"><a href="#cb6-30" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Return mean seconds per forward pass of the CBAM module (inference mode)"""</span></span>
<span id="cb6-31"><a href="#cb6-31" aria-hidden="true" tabindex="-1"></a>        model <span class="op">=</span> CBAM(channels).to(<span class="va">self</span>.device).<span class="bu">eval</span>()</span>
<span id="cb6-32"><a href="#cb6-32" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> torch.randn(batch_size, channels, height, width, device<span class="op">=</span><span class="va">self</span>.device)</span>
<span id="cb6-33"><a href="#cb6-33" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> torch.no_grad():  <span class="co"># inference only: no autograd graph is built</span></span>
<span id="cb6-34"><a href="#cb6-34" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Warmup</span></span>
<span id="cb6-35"><a href="#cb6-35" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">10</span>):</span>
<span id="cb6-36"><a href="#cb6-36" aria-hidden="true" tabindex="-1"></a>                _ <span class="op">=</span> model(x)</span>
<span id="cb6-37"><a href="#cb6-37" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb6-38"><a href="#cb6-38" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Benchmark</span></span>
<span id="cb6-39"><a href="#cb6-39" aria-hidden="true" tabindex="-1"></a>            torch.cuda.synchronize() <span class="cf">if</span> torch.cuda.is_available() <span class="cf">else</span> <span class="va">None</span></span>
<span id="cb6-40"><a href="#cb6-40" aria-hidden="true" tabindex="-1"></a>            start_time <span class="op">=</span> time.perf_counter()</span>
<span id="cb6-41"><a href="#cb6-41" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb6-42"><a href="#cb6-42" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">100</span>):</span>
<span id="cb6-43"><a href="#cb6-43" aria-hidden="true" tabindex="-1"></a>                output <span class="op">=</span> model(x)</span>
<span id="cb6-44"><a href="#cb6-44" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb6-45"><a href="#cb6-45" aria-hidden="true" tabindex="-1"></a>            torch.cuda.synchronize() <span class="cf">if</span> torch.cuda.is_available() <span class="cf">else</span> <span class="va">None</span></span>
<span id="cb6-46"><a href="#cb6-46" aria-hidden="true" tabindex="-1"></a>            end_time <span class="op">=</span> time.perf_counter()</span>
<span id="cb6-47"><a href="#cb6-47" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-48"><a href="#cb6-48" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> (end_time <span class="op">-</span> start_time) <span class="op">/</span> <span class="dv">100</span></span>
<span id="cb6-49"><a href="#cb6-49" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-50"><a href="#cb6-50" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> run_comparison(<span class="va">self</span>):</span>
<span id="cb6-51"><a href="#cb6-51" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Run performance comparison"""</span></span>
<span id="cb6-52"><a href="#cb6-52" aria-hidden="true" tabindex="-1"></a>        transformer_time <span class="op">=</span> <span class="va">self</span>.benchmark_transformer_attention()</span>
<span id="cb6-53"><a href="#cb6-53" aria-hidden="true" tabindex="-1"></a>        cnn_time <span class="op">=</span> <span class="va">self</span>.benchmark_cnn_attention()</span>
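<span id="cb6-54"><a href="#cb6-54" aria-hidden="true" tabindex="-1"></a>        <span class="co"># The two benchmarks use different input shapes, so the ratio is indicative, not like-for-like</span></span>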
<span id="cb6-55"><a href="#cb6-55" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Transformer attention time: </span><span class="sc">{</span>transformer_time<span class="sc">:.4f}</span><span class="ss">s"</span>)</span>
<span id="cb6-56"><a href="#cb6-56" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"CNN attention time: </span><span class="sc">{</span>cnn_time<span class="sc">:.4f}</span><span class="ss">s"</span>)</span>
<span id="cb6-57"><a href="#cb6-57" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"CNN attention speedup over Transformer: </span><span class="sc">{</span>transformer_time<span class="op">/</span>cnn_time<span class="sc">:.2f}</span><span class="ss">x"</span>)</span>
<span id="cb6-58"><a href="#cb6-58" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-59"><a href="#cb6-59" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {</span>
<span id="cb6-60"><a href="#cb6-60" aria-hidden="true" tabindex="-1"></a>            <span class="st">'transformer_time'</span>: transformer_time,</span>
<span id="cb6-61"><a href="#cb6-61" aria-hidden="true" tabindex="-1"></a>            <span class="st">'cnn_time'</span>: cnn_time,</span>
<span id="cb6-62"><a href="#cb6-62" aria-hidden="true" tabindex="-1"></a>            <span class="st">'speedup'</span>: transformer_time<span class="op">/</span>cnn_time</span>
<span id="cb6-63"><a href="#cb6-63" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb6-64"><a href="#cb6-64" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-65"><a href="#cb6-65" aria-hidden="true" tabindex="-1"></a><span class="co"># Memory usage comparison (CUDA only): reports memory still allocated after one forward pass, not peak usage</span></span>
<span id="cb6-66"><a href="#cb6-66" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> memory_comparison():</span>
<span id="cb6-67"><a href="#cb6-67" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Compare memory usage of different attention mechanisms"""</span></span>
<span id="cb6-68"><a href="#cb6-68" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-69"><a href="#cb6-69" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> get_memory_usage():</span>
<span id="cb6-70"><a href="#cb6-70" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> torch.cuda.is_available():</span>
<span id="cb6-71"><a href="#cb6-71" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> torch.cuda.memory_allocated() <span class="op">/</span> <span class="dv">1024</span><span class="op">**</span><span class="dv">2</span>  <span class="co"># MB</span></span>
<span id="cb6-72"><a href="#cb6-72" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="dv">0</span></span>
<span id="cb6-73"><a href="#cb6-73" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-74"><a href="#cb6-74" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Clear memory</span></span>
<span id="cb6-75"><a href="#cb6-75" aria-hidden="true" tabindex="-1"></a>    torch.cuda.empty_cache() <span class="cf">if</span> torch.cuda.is_available() <span class="cf">else</span> <span class="va">None</span></span>
<span id="cb6-76"><a href="#cb6-76" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-77"><a href="#cb6-77" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Transformer attention</span></span>
<span id="cb6-78"><a href="#cb6-78" aria-hidden="true" tabindex="-1"></a>    transformer_model <span class="op">=</span> MultiHeadAttention(<span class="dv">512</span>, <span class="dv">8</span>)</span>
<span id="cb6-79"><a href="#cb6-79" aria-hidden="true" tabindex="-1"></a>    x <span class="op">=</span> torch.randn(<span class="dv">32</span>, <span class="dv">512</span>, <span class="dv">512</span>)</span>
<span id="cb6-80"><a href="#cb6-80" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-81"><a href="#cb6-81" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> torch.cuda.is_available():</span>
<span id="cb6-82"><a href="#cb6-82" aria-hidden="true" tabindex="-1"></a>        transformer_model <span class="op">=</span> transformer_model.cuda()</span>
<span id="cb6-83"><a href="#cb6-83" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> x.cuda()</span>
<span id="cb6-84"><a href="#cb6-84" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-85"><a href="#cb6-85" aria-hidden="true" tabindex="-1"></a>    transformer_memory <span class="op">=</span> get_memory_usage()</span>
<span id="cb6-86"><a href="#cb6-86" aria-hidden="true" tabindex="-1"></a>    _, _ <span class="op">=</span> transformer_model(x, x, x)</span>
<span id="cb6-87"><a href="#cb6-87" aria-hidden="true" tabindex="-1"></a>    transformer_memory <span class="op">=</span> get_memory_usage() <span class="op">-</span> transformer_memory</span>
<span id="cb6-88"><a href="#cb6-88" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-89"><a href="#cb6-89" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Clear memory</span></span>
<span id="cb6-90"><a href="#cb6-90" aria-hidden="true" tabindex="-1"></a>    <span class="kw">del</span> transformer_model, x</span>
<span id="cb6-91"><a href="#cb6-91" aria-hidden="true" tabindex="-1"></a>    torch.cuda.empty_cache() <span class="cf">if</span> torch.cuda.is_available() <span class="cf">else</span> <span class="va">None</span></span>
<span id="cb6-92"><a href="#cb6-92" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-93"><a href="#cb6-93" aria-hidden="true" tabindex="-1"></a>    <span class="co"># CNN attention</span></span>
<span id="cb6-94"><a href="#cb6-94" aria-hidden="true" tabindex="-1"></a>    cnn_model <span class="op">=</span> CBAM(<span class="dv">256</span>)</span>
<span id="cb6-95"><a href="#cb6-95" aria-hidden="true" tabindex="-1"></a>    x <span class="op">=</span> torch.randn(<span class="dv">32</span>, <span class="dv">256</span>, <span class="dv">56</span>, <span class="dv">56</span>)</span>
<span id="cb6-96"><a href="#cb6-96" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-97"><a href="#cb6-97" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> torch.cuda.is_available():</span>
<span id="cb6-98"><a href="#cb6-98" aria-hidden="true" tabindex="-1"></a>        cnn_model <span class="op">=</span> cnn_model.cuda()</span>
<span id="cb6-99"><a href="#cb6-99" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> x.cuda()</span>
<span id="cb6-100"><a href="#cb6-100" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-101"><a href="#cb6-101" aria-hidden="true" tabindex="-1"></a>    cnn_memory <span class="op">=</span> get_memory_usage()</span>
<span id="cb6-102"><a href="#cb6-102" aria-hidden="true" tabindex="-1"></a>    _ <span class="op">=</span> cnn_model(x)</span>
<span id="cb6-103"><a href="#cb6-103" aria-hidden="true" tabindex="-1"></a>    cnn_memory <span class="op">=</span> get_memory_usage() <span class="op">-</span> cnn_memory</span>
<span id="cb6-104"><a href="#cb6-104" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-105"><a href="#cb6-105" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Transformer attention memory: </span><span class="sc">{</span>transformer_memory<span class="sc">:.2f}</span><span class="ss"> MB"</span>)</span>
<span id="cb6-106"><a href="#cb6-106" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"CNN attention memory: </span><span class="sc">{</span>cnn_memory<span class="sc">:.2f}</span><span class="ss"> MB"</span>)</span>
<span id="cb6-107"><a href="#cb6-107" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-108"><a href="#cb6-108" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> {</span>
<span id="cb6-109"><a href="#cb6-109" aria-hidden="true" tabindex="-1"></a>        <span class="st">'transformer_memory'</span>: transformer_memory,</span>
<span id="cb6-110"><a href="#cb6-110" aria-hidden="true" tabindex="-1"></a>        <span class="st">'cnn_memory'</span>: cnn_memory</span>
<span id="cb6-111"><a href="#cb6-111" aria-hidden="true" tabindex="-1"></a>    }</span></code></pre></div></div>
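<p>The measured numbers depend on the GPU and on how <code>MultiHeadAttention</code> and <code>CBAM</code> are implemented, but a back-of-the-envelope count of the attention tensors alone shows where the gap comes from. This sketch assumes float32 values, that the Transformer materializes a full <code>[batch, heads, seq_len, seq_len]</code> weight tensor (as its returned attention weights imply), and that CBAM produces a <code>[batch, channels, 1, 1]</code> channel map plus a <code>[batch, 1, height, width]</code> spatial map. It ignores inputs, outputs and every other intermediate, so it is a rough lower bound on the attention-specific cost rather than a prediction of what <code>memory_comparison()</code> prints. The function names are hypothetical.</p>

```python
BYTES_PER_FLOAT32 = 4
MIB = 1024 ** 2


def transformer_weight_bytes(batch_size, num_heads, seq_len):
    """Bytes for one float32 attention-weight tensor of shape [B, H, N, N]."""
    return batch_size * num_heads * seq_len * seq_len * BYTES_PER_FLOAT32


def cbam_map_bytes(batch_size, channels, height, width):
    """Bytes for CBAM's channel map [B, C, 1, 1] plus spatial map [B, 1, H, W]."""
    channel_map = batch_size * channels
    spatial_map = batch_size * height * width
    return (channel_map + spatial_map) * BYTES_PER_FLOAT32


# Same shapes as the benchmark defaults above
t_bytes = transformer_weight_bytes(32, 8, 512)
c_bytes = cbam_map_bytes(32, 256, 56, 56)
print(f"Transformer attention weights: {t_bytes / MIB:.1f} MiB")  # 256.0 MiB
print(f"CBAM attention maps:           {c_bytes / MIB:.2f} MiB")  # 0.41 MiB

# Doubling the sequence length quadruples the Transformer term (O(N^2)),
# whereas the CBAM spatial map grows only linearly with H * W.
print(transformer_weight_bytes(32, 8, 1024) // t_bytes)  # 4
```

<p>The quadratic term is the reason sequence length, rather than model width, usually decides whether full self-attention fits in memory.</p>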
</section>
<section id="when-to-use-each" class="level2">
<h2 class="anchored" data-anchor-id="when-to-use-each" id="when-to-use-each">When to Use Each</h2>
<section id="decision-framework" class="level3">
<h3 class="anchored" data-anchor-id="decision-framework" id="decision-framework">Decision Framework</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> AttentionSelector:</span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a>    <span class="at">@staticmethod</span></span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> recommend_attention_type(data_type, sequence_length<span class="op">=</span><span class="va">None</span>, spatial_dims<span class="op">=</span><span class="va">None</span>, </span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a>                                computational_budget<span class="op">=</span><span class="st">'medium'</span>, task_type<span class="op">=</span><span class="st">'classification'</span>):</span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""</span></span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a><span class="co">        Recommend attention mechanism based on requirements</span></span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a><span class="co">        </span></span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a><span class="co">        Args:</span></span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a><span class="co">            data_type: 'sequential', 'spatial', 'mixed'</span></span>
<span id="cb7-10"><a href="#cb7-10" aria-hidden="true" tabindex="-1"></a><span class="co">            sequence_length: Length of sequences (for sequential data)</span></span>
<span id="cb7-11"><a href="#cb7-11" aria-hidden="true" tabindex="-1"></a><span class="co">            spatial_dims: (height, width) for spatial data</span></span>
<span id="cb7-12"><a href="#cb7-12" aria-hidden="true" tabindex="-1"></a><span class="co">            computational_budget: 'low', 'medium', 'high'</span></span>
<span id="cb7-13"><a href="#cb7-13" aria-hidden="true" tabindex="-1"></a><span class="co">            task_type: 'classification', 'generation', 'detection' (not yet used by the rules below)</span></span>
<span id="cb7-14"><a href="#cb7-14" aria-hidden="true" tabindex="-1"></a><span class="co">        """</span></span>
<span id="cb7-15"><a href="#cb7-15" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-16"><a href="#cb7-16" aria-hidden="true" tabindex="-1"></a>        recommendations <span class="op">=</span> []</span>
<span id="cb7-17"><a href="#cb7-17" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-18"><a href="#cb7-18" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Sequential data</span></span>
<span id="cb7-19"><a href="#cb7-19" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> data_type <span class="op">==</span> <span class="st">'sequential'</span>:</span>
<span id="cb7-20"><a href="#cb7-20" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> sequence_length <span class="kw">and</span> sequence_length <span class="op">&gt;</span> <span class="dv">1000</span> <span class="kw">and</span> computational_budget <span class="op">==</span> <span class="st">'low'</span>:</span>
<span id="cb7-21"><a href="#cb7-21" aria-hidden="true" tabindex="-1"></a>                recommendations.append({</span>
<span id="cb7-22"><a href="#cb7-22" aria-hidden="true" tabindex="-1"></a>                    <span class="st">'type'</span>: <span class="st">'Local Attention'</span>,</span>
<span id="cb7-23"><a href="#cb7-23" aria-hidden="true" tabindex="-1"></a>                    <span class="st">'reason'</span>: <span class="st">'Long sequences with limited compute'</span>,</span>
<span id="cb7-24"><a href="#cb7-24" aria-hidden="true" tabindex="-1"></a>                    <span class="st">'implementation'</span>: <span class="st">'sliding_window_attention'</span></span>
<span id="cb7-25"><a href="#cb7-25" aria-hidden="true" tabindex="-1"></a>                })</span>
<span id="cb7-26"><a href="#cb7-26" aria-hidden="true" tabindex="-1"></a>            <span class="cf">else</span>:</span>
<span id="cb7-27"><a href="#cb7-27" aria-hidden="true" tabindex="-1"></a>                recommendations.append({</span>
<span id="cb7-28"><a href="#cb7-28" aria-hidden="true" tabindex="-1"></a>                    <span class="st">'type'</span>: <span class="st">'Transformer Self-Attention'</span>,</span>
<span id="cb7-29"><a href="#cb7-29" aria-hidden="true" tabindex="-1"></a>                    <span class="st">'reason'</span>: <span class="st">'Global context modeling for sequences'</span>,</span>
<span id="cb7-30"><a href="#cb7-30" aria-hidden="true" tabindex="-1"></a>                    <span class="st">'implementation'</span>: <span class="st">'MultiHeadAttention'</span></span>
<span id="cb7-31"><a href="#cb7-31" aria-hidden="true" tabindex="-1"></a>                })</span>
<span id="cb7-32"><a href="#cb7-32" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-33"><a href="#cb7-33" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Spatial data</span></span>
<span id="cb7-34"><a href="#cb7-34" aria-hidden="true" tabindex="-1"></a>        <span class="cf">elif</span> data_type <span class="op">==</span> <span class="st">'spatial'</span>:</span>
<span id="cb7-35"><a href="#cb7-35" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> spatial_dims <span class="kw">and</span> spatial_dims[<span class="dv">0</span>] <span class="op">*</span> spatial_dims[<span class="dv">1</span>] <span class="op">&gt;</span> <span class="dv">224</span> <span class="op">*</span> <span class="dv">224</span>:</span>
<span id="cb7-36"><a href="#cb7-36" aria-hidden="true" tabindex="-1"></a>                recommendations.append({</span>
<span id="cb7-37"><a href="#cb7-37" aria-hidden="true" tabindex="-1"></a>                    <span class="st">'type'</span>: <span class="st">'CNN Spatial + Channel Attention'</span>,</span>
<span id="cb7-38"><a href="#cb7-38" aria-hidden="true" tabindex="-1"></a>                    <span class="st">'reason'</span>: <span class="st">'High-resolution spatial data'</span>,</span>
<span id="cb7-39"><a href="#cb7-39" aria-hidden="true" tabindex="-1"></a>                    <span class="st">'implementation'</span>: <span class="st">'CBAM'</span></span>
<span id="cb7-40"><a href="#cb7-40" aria-hidden="true" tabindex="-1"></a>                })</span>
<span id="cb7-41"><a href="#cb7-41" aria-hidden="true" tabindex="-1"></a>            <span class="cf">else</span>:</span>
<span id="cb7-42"><a href="#cb7-42" aria-hidden="true" tabindex="-1"></a>                recommendations.append({</span>
<span id="cb7-43"><a href="#cb7-43" aria-hidden="true" tabindex="-1"></a>                    <span class="st">'type'</span>: <span class="st">'CNN Self-Attention'</span>,</span>
<span id="cb7-44"><a href="#cb7-44" aria-hidden="true" tabindex="-1"></a>                    <span class="st">'reason'</span>: <span class="st">'Moderate resolution with global context'</span>,</span>
<span id="cb7-45"><a href="#cb7-45" aria-hidden="true" tabindex="-1"></a>                    <span class="st">'implementation'</span>: <span class="st">'SelfAttention2D'</span></span>
<span id="cb7-46"><a href="#cb7-46" aria-hidden="true" tabindex="-1"></a>                })</span>
<span id="cb7-47"><a href="#cb7-47" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-48"><a href="#cb7-48" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Mixed data</span></span>
<span id="cb7-49"><a href="#cb7-49" aria-hidden="true" tabindex="-1"></a>        <span class="cf">elif</span> data_type <span class="op">==</span> <span class="st">'mixed'</span>:</span>
<span id="cb7-50"><a href="#cb7-50" aria-hidden="true" tabindex="-1"></a>            recommendations.append({</span>
<span id="cb7-51"><a href="#cb7-51" aria-hidden="true" tabindex="-1"></a>                <span class="st">'type'</span>: <span class="st">'Hybrid Attention'</span>,</span>
<span id="cb7-52"><a href="#cb7-52" aria-hidden="true" tabindex="-1"></a>                <span class="st">'reason'</span>: <span class="st">'Combined sequential and spatial processing'</span>,</span>
<span id="cb7-53"><a href="#cb7-53" aria-hidden="true" tabindex="-1"></a>                <span class="st">'implementation'</span>: <span class="st">'transformer_cnn_hybrid'</span></span>
<span id="cb7-54"><a href="#cb7-54" aria-hidden="true" tabindex="-1"></a>            })</span>
<span id="cb7-55"><a href="#cb7-55" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-56"><a href="#cb7-56" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> recommendations</span>
<span id="cb7-57"><a href="#cb7-57" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-58"><a href="#cb7-58" aria-hidden="true" tabindex="-1"></a>    <span class="at">@staticmethod</span></span>
<span id="cb7-59"><a href="#cb7-59" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> create_hybrid_model(input_shape, num_classes):</span>
<span id="cb7-60"><a href="#cb7-60" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Create a hybrid model combining both attention types"""</span></span>
<span id="cb7-61"><a href="#cb7-61" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-62"><a href="#cb7-62" aria-hidden="true" tabindex="-1"></a>        <span class="kw">class</span> HybridAttentionModel(nn.Module):</span>
<span id="cb7-63"><a href="#cb7-63" aria-hidden="true" tabindex="-1"></a>            <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, input_shape, num_classes):</span>
<span id="cb7-64"><a href="#cb7-64" aria-hidden="true" tabindex="-1"></a>                <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb7-65"><a href="#cb7-65" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb7-66"><a href="#cb7-66" aria-hidden="true" tabindex="-1"></a>                <span class="co"># CNN backbone with attention</span></span>
<span id="cb7-67"><a href="#cb7-67" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.cnn_backbone <span class="op">=</span> nn.Sequential(</span>
<span id="cb7-68"><a href="#cb7-68" aria-hidden="true" tabindex="-1"></a>                    nn.Conv2d(input_shape[<span class="dv">0</span>], <span class="dv">64</span>, <span class="dv">3</span>, padding<span class="op">=</span><span class="dv">1</span>),</span>
<span id="cb7-69"><a href="#cb7-69" aria-hidden="true" tabindex="-1"></a>                    nn.ReLU(),</span>
<span id="cb7-70"><a href="#cb7-70" aria-hidden="true" tabindex="-1"></a>                    CBAM(<span class="dv">64</span>),</span>
<span id="cb7-71"><a href="#cb7-71" aria-hidden="true" tabindex="-1"></a>                    nn.MaxPool2d(<span class="dv">2</span>),</span>
<span id="cb7-72"><a href="#cb7-72" aria-hidden="true" tabindex="-1"></a>                    </span>
<span id="cb7-73"><a href="#cb7-73" aria-hidden="true" tabindex="-1"></a>                    nn.Conv2d(<span class="dv">64</span>, <span class="dv">128</span>, <span class="dv">3</span>, padding<span class="op">=</span><span class="dv">1</span>),</span>
<span id="cb7-74"><a href="#cb7-74" aria-hidden="true" tabindex="-1"></a>                    nn.ReLU(),</span>
<span id="cb7-75"><a href="#cb7-75" aria-hidden="true" tabindex="-1"></a>                    CBAM(<span class="dv">128</span>),</span>
<span id="cb7-76"><a href="#cb7-76" aria-hidden="true" tabindex="-1"></a>                    nn.MaxPool2d(<span class="dv">2</span>),</span>
<span id="cb7-77"><a href="#cb7-77" aria-hidden="true" tabindex="-1"></a>                    </span>
<span id="cb7-78"><a href="#cb7-78" aria-hidden="true" tabindex="-1"></a>                    nn.Conv2d(<span class="dv">128</span>, <span class="dv">256</span>, <span class="dv">3</span>, padding<span class="op">=</span><span class="dv">1</span>),</span>
<span id="cb7-79"><a href="#cb7-79" aria-hidden="true" tabindex="-1"></a>                    nn.ReLU(),</span>
<span id="cb7-80"><a href="#cb7-80" aria-hidden="true" tabindex="-1"></a>                    SelfAttention2D(<span class="dv">256</span>)</span>
<span id="cb7-81"><a href="#cb7-81" aria-hidden="true" tabindex="-1"></a>                )</span>
<span id="cb7-82"><a href="#cb7-82" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb7-83"><a href="#cb7-83" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Flatten and prepare for transformer</span></span>
<span id="cb7-84"><a href="#cb7-84" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.flatten <span class="op">=</span> nn.AdaptiveAvgPool2d(<span class="dv">8</span>)  <span class="co"># 8x8 spatial grid</span></span>
<span id="cb7-85"><a href="#cb7-85" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.embed_dim <span class="op">=</span> <span class="dv">256</span></span>
<span id="cb7-86"><a href="#cb7-86" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb7-87"><a href="#cb7-87" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Transformer layers</span></span>
<span id="cb7-88"><a href="#cb7-88" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.transformer <span class="op">=</span> nn.ModuleList([  <span class="co"># iterated in forward(): each block returns a tuple</span></span>
<span id="cb7-89"><a href="#cb7-89" aria-hidden="true" tabindex="-1"></a>                    TransformerBlock(<span class="va">self</span>.embed_dim, <span class="dv">8</span>, <span class="dv">1024</span>) <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">3</span>)</span>
<span id="cb7-90"><a href="#cb7-90" aria-hidden="true" tabindex="-1"></a>                ])</span>
<span id="cb7-91"><a href="#cb7-91" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb7-92"><a href="#cb7-92" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Classification head</span></span>
<span id="cb7-93"><a href="#cb7-93" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.classifier <span class="op">=</span> nn.Linear(<span class="va">self</span>.embed_dim, num_classes)</span>
<span id="cb7-94"><a href="#cb7-94" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb7-95"><a href="#cb7-95" aria-hidden="true" tabindex="-1"></a>            <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb7-96"><a href="#cb7-96" aria-hidden="true" tabindex="-1"></a>                <span class="co"># CNN processing</span></span>
<span id="cb7-97"><a href="#cb7-97" aria-hidden="true" tabindex="-1"></a>                x <span class="op">=</span> <span class="va">self</span>.cnn_backbone(x)</span>
<span id="cb7-98"><a href="#cb7-98" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb7-99"><a href="#cb7-99" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Reshape for transformer</span></span>
<span id="cb7-100"><a href="#cb7-100" aria-hidden="true" tabindex="-1"></a>                batch_size <span class="op">=</span> x.size(<span class="dv">0</span>)</span>
<span id="cb7-101"><a href="#cb7-101" aria-hidden="true" tabindex="-1"></a>                x <span class="op">=</span> <span class="va">self</span>.flatten(x)  <span class="co"># [B, 256, 8, 8]</span></span>
<span id="cb7-102"><a href="#cb7-102" aria-hidden="true" tabindex="-1"></a>                x <span class="op">=</span> x.flatten(<span class="dv">2</span>).transpose(<span class="dv">1</span>, <span class="dv">2</span>)  <span class="co"># [B, 64, 256]</span></span>
<span id="cb7-103"><a href="#cb7-103" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb7-104"><a href="#cb7-104" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Transformer processing</span></span>
<span id="cb7-105"><a href="#cb7-105" aria-hidden="true" tabindex="-1"></a>                <span class="cf">for</span> transformer_block <span class="kw">in</span> <span class="va">self</span>.transformer:</span>
<span id="cb7-106"><a href="#cb7-106" aria-hidden="true" tabindex="-1"></a>                    x, _ <span class="op">=</span> transformer_block(x)</span>
<span id="cb7-107"><a href="#cb7-107" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb7-108"><a href="#cb7-108" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Global average pooling and classification</span></span>
<span id="cb7-109"><a href="#cb7-109" aria-hidden="true" tabindex="-1"></a>                x <span class="op">=</span> x.mean(dim<span class="op">=</span><span class="dv">1</span>)  <span class="co"># [B, 256]</span></span>
<span id="cb7-110"><a href="#cb7-110" aria-hidden="true" tabindex="-1"></a>                x <span class="op">=</span> <span class="va">self</span>.classifier(x)</span>
<span id="cb7-111"><a href="#cb7-111" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb7-112"><a href="#cb7-112" aria-hidden="true" tabindex="-1"></a>                <span class="cf">return</span> x</span>
<span id="cb7-113"><a href="#cb7-113" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-114"><a href="#cb7-114" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> HybridAttentionModel(input_shape, num_classes)</span>
<span id="cb7-115"><a href="#cb7-115" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-116"><a href="#cb7-116" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage examples</span></span>
<span id="cb7-117"><a href="#cb7-117" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> usage_examples():</span>
<span id="cb7-118"><a href="#cb7-118" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Demonstrate when to use each attention type"""</span></span>
<span id="cb7-119"><a href="#cb7-119" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-120"><a href="#cb7-120" aria-hidden="true" tabindex="-1"></a>    selector <span class="op">=</span> AttentionSelector()</span>
<span id="cb7-121"><a href="#cb7-121" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-122"><a href="#cb7-122" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Example 1: NLP task</span></span>
<span id="cb7-123"><a href="#cb7-123" aria-hidden="true" tabindex="-1"></a>    nlp_rec <span class="op">=</span> selector.recommend_attention_type(</span>
<span id="cb7-124"><a href="#cb7-124" aria-hidden="true" tabindex="-1"></a>        data_type<span class="op">=</span><span class="st">'sequential'</span>,</span>
<span id="cb7-125"><a href="#cb7-125" aria-hidden="true" tabindex="-1"></a>        sequence_length<span class="op">=</span><span class="dv">512</span>,</span>
<span id="cb7-126"><a href="#cb7-126" aria-hidden="true" tabindex="-1"></a>        computational_budget<span class="op">=</span><span class="st">'high'</span>,</span>
<span id="cb7-127"><a href="#cb7-127" aria-hidden="true" tabindex="-1"></a>        task_type<span class="op">=</span><span class="st">'generation'</span></span>
<span id="cb7-128"><a href="#cb7-128" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb7-129"><a href="#cb7-129" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-130"><a href="#cb7-130" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Example 2: Computer Vision task</span></span>
<span id="cb7-131"><a href="#cb7-131" aria-hidden="true" tabindex="-1"></a>    cv_rec <span class="op">=</span> selector.recommend_attention_type(</span>
<span id="cb7-132"><a href="#cb7-132" aria-hidden="true" tabindex="-1"></a>        data_type<span class="op">=</span><span class="st">'spatial'</span>,</span>
<span id="cb7-133"><a href="#cb7-133" aria-hidden="true" tabindex="-1"></a>        spatial_dims<span class="op">=</span>(<span class="dv">224</span>, <span class="dv">224</span>),</span>
<span id="cb7-134"><a href="#cb7-134" aria-hidden="true" tabindex="-1"></a>        computational_budget<span class="op">=</span><span class="st">'medium'</span>,</span>
<span id="cb7-135"><a href="#cb7-135" aria-hidden="true" tabindex="-1"></a>        task_type<span class="op">=</span><span class="st">'classification'</span></span>
<span id="cb7-136"><a href="#cb7-136" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb7-137"><a href="#cb7-137" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-138"><a href="#cb7-138" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Example 3: Video analysis</span></span>
<span id="cb7-139"><a href="#cb7-139" aria-hidden="true" tabindex="-1"></a>    video_rec <span class="op">=</span> selector.recommend_attention_type(</span>
<span id="cb7-140"><a href="#cb7-140" aria-hidden="true" tabindex="-1"></a>        data_type<span class="op">=</span><span class="st">'mixed'</span>,</span>
<span id="cb7-141"><a href="#cb7-141" aria-hidden="true" tabindex="-1"></a>        sequence_length<span class="op">=</span><span class="dv">30</span>,</span>
<span id="cb7-142"><a href="#cb7-142" aria-hidden="true" tabindex="-1"></a>        spatial_dims<span class="op">=</span>(<span class="dv">112</span>, <span class="dv">112</span>),</span>
<span id="cb7-143"><a href="#cb7-143" aria-hidden="true" tabindex="-1"></a>        computational_budget<span class="op">=</span><span class="st">'high'</span>,</span>
<span id="cb7-144"><a href="#cb7-144" aria-hidden="true" tabindex="-1"></a>        task_type<span class="op">=</span><span class="st">'detection'</span></span>
<span id="cb7-145"><a href="#cb7-145" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb7-146"><a href="#cb7-146" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-147"><a href="#cb7-147" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="st">"NLP Recommendation:"</span>, nlp_rec)</span>
<span id="cb7-148"><a href="#cb7-148" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="st">"Computer Vision Recommendation:"</span>, cv_rec)</span>
<span id="cb7-149"><a href="#cb7-149" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="st">"Video Analysis Recommendation:"</span>, video_rec)</span>
<span id="cb7-150"><a href="#cb7-150" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-151"><a href="#cb7-151" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> nlp_rec, cv_rec, video_rec</span></code></pre></div></div>
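<p>To check the shapes annotated in <code>forward()</code>, the pure-Python sketch below traces the spatial size through the backbone. Because their definitions are not shown here, it assumes that <code>CBAM</code> and <code>SelfAttention2D</code> return tensors shaped like their input. The helper names are hypothetical.</p>

```python
def conv3x3_same(h, w):
    """Conv2d(kernel_size=3, padding=1, stride=1) keeps the spatial size."""
    return h, w


def maxpool2(h, w):
    """MaxPool2d(2): kernel 2, stride 2, rounding odd sizes down."""
    return h // 2, w // 2


def trace_hybrid_shapes(height, width, grid=8, embed_dim=256):
    """Trace per-example shapes through HybridAttentionModel (hypothetical helper).

    Assumption: CBAM and SelfAttention2D preserve the shape of their input.
    """
    h, w = conv3x3_same(height, width)  # Conv2d(in, 64) + ReLU + CBAM(64)
    h, w = maxpool2(h, w)
    h, w = conv3x3_same(h, w)           # Conv2d(64, 128) + ReLU + CBAM(128)
    h, w = maxpool2(h, w)
    h, w = conv3x3_same(h, w)           # Conv2d(128, 256) + ReLU + SelfAttention2D(256)
    n_tokens = grid * grid              # AdaptiveAvgPool2d(8) always yields an 8x8 grid
    return {
        "backbone_out": (embed_dim, h, w),
        "tokens": (n_tokens, embed_dim),
        "scores_per_head": n_tokens * n_tokens,  # n x n attention matrix per block
    }


for size in (224, 512):
    print(size, trace_hybrid_shapes(size, size))
```

<p>Whatever the input resolution, adaptive pooling fixes the transformer input at 64 tokens. Each transformer block's attention matrix therefore stays 64 × 64, while the convolutional stages (and <code>SelfAttention2D</code>) operate on the full-resolution feature map.</p>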
</section>
</section>
<section id="summary" class="level2">
<h2 class="anchored" data-anchor-id="summary" id="summary">Summary</h2>
<table class="caption-top table">
<colgroup>
<col style="width: 17%">
<col style="width: 48%">
<col style="width: 33%">
</colgroup>
<thead>
<tr class="header">
<th>Aspect</th>
<th>Transformer Attention</th>
<th>CNN Attention</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><strong>Scope</strong></td>
<td>Global, all-to-all</td>
<td>Local, spatial/channel-wise</td>
</tr>
<tr class="even">
<td><strong>Complexity</strong></td>
<td>O(n²) in the number of tokens n</td>
<td>O(HW) for spatial maps or O(C) for channel weights</td>
</tr>
<tr class="odd">
<td><strong>Best For</strong></td>
<td>Sequential data, language</td>
<td>Spatial data, images</td>
</tr>
<tr class="even">
<td><strong>Memory</strong></td>
<td>High: the n × n attention matrix grows quadratically</td>
<td>Moderate: attention maps grow linearly with HW or C</td>
</tr>
<tr class="odd">
<td><strong>Parallelization</strong></td>
<td>Fully parallel across positions; quadratic cost limits sequence length</td>
<td>Highly parallelizable</td>
</tr>
<tr class="even">
<td><strong>Interpretability</strong></td>
<td>Attention weights show relationships</td>
<td>Spatial/channel importance maps</td>
</tr>
</tbody>
</table>
<p>Choose Transformer attention for tasks that require global context modeling, and CNN attention for spatially structured data where local relationships dominate. Consider hybrid approaches for complex multi-modal tasks.</p>
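<p>To put numbers on the Complexity and Memory rows, the sketch below counts attention weights per example and per layer for an illustrative 56 × 56 feature map with 256 channels. These sizes are chosen for illustration and are not taken from a benchmark. The count ignores projections and the work needed to compute each weight, and the function names are hypothetical.</p>

```python
def self_attention_entries(n_tokens, heads=1):
    """Global self-attention: one score for every (query, key) pair."""
    return heads * n_tokens * n_tokens


def spatial_attention_entries(height, width):
    """Spatial attention map (as in CBAM): one weight per location."""
    return height * width


def channel_attention_entries(channels):
    """Channel attention (as in SE or CBAM): one weight per channel."""
    return channels


H, W, C = 56, 56, 256
print(f"Self-attention over all {H * W} positions: {self_attention_entries(H * W):,}")
print(f"Spatial attention map: {spatial_attention_entries(H, W):,}")
print(f"Channel attention vector: {channel_attention_entries(C):,}")
```

<p>Treating every position as a token gives 9,834,496 scores, compared with 3,136 spatial weights and 256 channel weights. Doubling the side length of the map multiplies the self-attention count by 16 but the spatial-map count by only 4. In practice, this is the difference between O(n²) and O(HW).</p>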


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[CUDA Python: Accelerating Python Applications with GPU Computing]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/model-training/cuda-python/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/model-training/cuda-python/</guid>
      <pubDate>Sun, 22 Jun 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>tutorial</category>
      <category>advanced</category>
      <content:encoded><![CDATA[






<section id="cuda-python-accelerating-python-applications-with-gpu-computing" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/model-training/cuda-python/cuda.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>CUDA Python brings NVIDIA’s CUDA platform to Python developers, enabling massively parallel GPU computing without leaving the Python ecosystem. This guide shows how to use GPU acceleration for computationally intensive tasks, from basic vector operations to complex machine learning algorithms.</p>
</section>
<section id="what-is-cuda-python" class="level2">
<h2 class="anchored" data-anchor-id="what-is-cuda-python" id="what-is-cuda-python">What is CUDA Python?</h2>
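<p>In this guide, “CUDA Python” refers broadly to the ecosystem of Python packages that give Python code access to CUDA. NVIDIA also publishes a package named <code>cuda-python</code>, which provides low-level bindings to the CUDA driver and runtime APIs. The main libraries are:</p>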
<ul>
<li><strong>CuPy</strong>: NumPy-compatible library for GPU arrays</li>
<li><strong>Numba</strong>: Just-in-time compiler with CUDA support</li>
<li><strong>PyCUDA</strong>: Low-level Python wrapper for CUDA</li>
<li><strong>cuDF</strong>: GPU-accelerated DataFrame library (part of NVIDIA RAPIDS)</li>
<li><strong>cuML</strong>: GPU-accelerated machine learning library (part of NVIDIA RAPIDS)</li>
</ul>
</section>
<section id="setting-up-your-environment" class="level2">
<h2 class="anchored" data-anchor-id="setting-up-your-environment" id="setting-up-your-environment">Setting Up Your Environment</h2>
<section id="prerequisites" class="level3">
<h3 class="anchored" data-anchor-id="prerequisites" id="prerequisites">Prerequisites</h3>
<p>Before diving into CUDA Python, ensure you have:</p>
<ol type="1">
<li>An NVIDIA GPU with CUDA Compute Capability 3.5 or higher</li>
<li>NVIDIA drivers installed</li>
<li>CUDA Toolkit (version 11.0 or later recommended)</li>
<li>Python 3.8 or later</li>
</ol>
</section>
<section id="installation" class="level3">
<h3 class="anchored" data-anchor-id="installation" id="installation">Installation</h3>
<p>The easiest way to get started is with conda:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Create a new environment</span></span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="ex">conda</span> create <span class="at">-n</span> cuda-python python=3.9</span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a><span class="ex">conda</span> activate cuda-python</span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Install CUDA Python packages</span></span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a><span class="ex">conda</span> install <span class="at">-c</span> conda-forge cupy</span>
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a><span class="ex">conda</span> install <span class="at">-c</span> conda-forge numba</span>
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a><span class="ex">conda</span> install <span class="at">-c</span> rapidsai <span class="at">-c</span> conda-forge <span class="at">-c</span> nvidia cudf cuml</span>
<span id="cb1-9"><a href="#cb1-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-10"><a href="#cb1-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Alternative: pip installation</span></span>
<span id="cb1-11"><a href="#cb1-11" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install cupy-cuda11x  <span class="co"># For CUDA 11.x; use cupy-cuda12x for CUDA 12.x</span></span>
<span id="cb1-12"><a href="#cb1-12" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install numba</span></code></pre></div></div>
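<p>After installation, a quick sanity check can confirm which packages import and whether they can see a GPU. This minimal sketch uses only <code>cupy.cuda.runtime.getDeviceCount()</code>, <code>cupy.cuda.Device(...).compute_capability</code>, and <code>numba.cuda.is_available()</code>. It still runs on a machine where a package or driver is missing. The function name <code>check_gpu_stack</code> is hypothetical.</p>

```python
import importlib


def check_gpu_stack():
    """Report which GPU libraries import and what they can see (hypothetical helper)."""
    report = {}

    try:
        cp = importlib.import_module("cupy")
        info = {"version": cp.__version__}
        try:
            count = cp.cuda.runtime.getDeviceCount()
            info["devices"] = count
            if count > 0:
                # Compute capability as a string, e.g. "86" for 8.6
                info["compute_capability"] = cp.cuda.Device(0).compute_capability
        except Exception as exc:  # no driver, no visible GPU, etc.
            info["error"] = str(exc)
        report["cupy"] = info
    except ImportError:
        report["cupy"] = None

    try:
        numba_cuda = importlib.import_module("numba.cuda")
        report["numba"] = {"cuda_available": numba_cuda.is_available()}
    except ImportError:
        report["numba"] = None

    return report


if __name__ == "__main__":
    for name, info in check_gpu_stack().items():
        print(f"{name}: {info if info is not None else 'not installed'}")
```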
</section>
</section>
<section id="getting-started-with-cupy" class="level2">
<h2 class="anchored" data-anchor-id="getting-started-with-cupy" id="getting-started-with-cupy">Getting Started with CuPy</h2>
<p>CuPy provides a NumPy-like interface for GPU computing, making it the most accessible entry point for CUDA Python.</p>
<section id="basic-array-operations" class="level3">
<h3 class="anchored" data-anchor-id="basic-array-operations" id="basic-array-operations">Basic Array Operations</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> cupy <span class="im">as</span> cp</span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Create arrays on GPU</span></span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a>gpu_array <span class="op">=</span> cp.array([<span class="dv">1</span>, <span class="dv">2</span>, <span class="dv">3</span>, <span class="dv">4</span>, <span class="dv">5</span>])</span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"GPU Array: </span><span class="sc">{</span>gpu_array<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Device: </span><span class="sc">{</span>gpu_array<span class="sc">.</span>device<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-10"><a href="#cb2-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Convert between CPU and GPU</span></span>
<span id="cb2-11"><a href="#cb2-11" aria-hidden="true" tabindex="-1"></a>cpu_array <span class="op">=</span> np.array([<span class="dv">1</span>, <span class="dv">2</span>, <span class="dv">3</span>, <span class="dv">4</span>, <span class="dv">5</span>])</span>
<span id="cb2-12"><a href="#cb2-12" aria-hidden="true" tabindex="-1"></a>gpu_from_cpu <span class="op">=</span> cp.asarray(cpu_array)</span>
<span id="cb2-13"><a href="#cb2-13" aria-hidden="true" tabindex="-1"></a>cpu_from_gpu <span class="op">=</span> cp.asnumpy(gpu_array)</span></code></pre></div></div>
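<p>Because CuPy mirrors the NumPy API, the same function can often serve both devices. The sketch below uses <code>cupy.get_array_module</code>, which returns the <code>cupy</code> module for CuPy arrays and <code>numpy</code> otherwise. It falls back to NumPy when CuPy is not installed. The helper names <code>get_xp</code> and <code>magnitude</code> are hypothetical.</p>

```python
import numpy as np

try:
    import cupy as cp
except ImportError:  # CPU-only environment: everything runs on NumPy
    cp = None


def get_xp(*arrays):
    """Return the array module (cupy or numpy) that matches the inputs."""
    if cp is not None:
        return cp.get_array_module(*arrays)
    return np


def magnitude(a, b):
    """Element-wise sqrt(a**2 + b**2), computed on the device that holds a and b."""
    xp = get_xp(a, b)
    return xp.sqrt(a ** 2 + b ** 2)


cpu_result = magnitude(np.array([3.0, 5.0]), np.array([4.0, 12.0]))
print(type(cpu_result).__module__, cpu_result)
```

<p>Passing CuPy arrays instead, for example <code>magnitude(cp.array([3.0, 5.0]), cp.array([4.0, 12.0]))</code>, runs the same code on the GPU and returns a CuPy array. That array can be copied back to the host with <code>cp.asnumpy</code>.</p>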
</section>
<section id="performance-comparison" class="level3">
<h3 class="anchored" data-anchor-id="performance-comparison" id="performance-comparison">Performance Comparison</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> benchmark_operations():</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a>    size <span class="op">=</span> <span class="dv">10</span><span class="op">**</span><span class="dv">7</span></span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a>    <span class="co"># CPU computation with NumPy</span></span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a>    cpu_a <span class="op">=</span> np.random.random(size)</span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a>    cpu_b <span class="op">=</span> np.random.random(size)</span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a>    start <span class="op">=</span> time.perf_counter()</span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a>    cpu_result <span class="op">=</span> np.sqrt(cpu_a<span class="op">**</span><span class="dv">2</span> <span class="op">+</span> cpu_b<span class="op">**</span><span class="dv">2</span>)</span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a>    cpu_time <span class="op">=</span> time.perf_counter() <span class="op">-</span> start</span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a>    <span class="co"># GPU computation with CuPy</span></span>
<span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a>    gpu_a <span class="op">=</span> cp.random.random(size)</span>
<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a>    gpu_b <span class="op">=</span> cp.random.random(size)</span>
<span id="cb3-15"><a href="#cb3-15" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-16"><a href="#cb3-16" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Warm-up: the first call of each operation compiles (or loads cached) kernels</span></span>
<span id="cb3-17"><a href="#cb3-17" aria-hidden="true" tabindex="-1"></a>    cp.sqrt(gpu_a<span class="op">**</span><span class="dv">2</span> <span class="op">+</span> gpu_b<span class="op">**</span><span class="dv">2</span>)</span>
<span id="cb3-18"><a href="#cb3-18" aria-hidden="true" tabindex="-1"></a>    cp.cuda.Stream.null.synchronize()  <span class="co"># Finish random generation and warm-up before timing</span></span>
<span id="cb3-19"><a href="#cb3-19" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-20"><a href="#cb3-20" aria-hidden="true" tabindex="-1"></a>    start <span class="op">=</span> time.perf_counter()</span>
<span id="cb3-21"><a href="#cb3-21" aria-hidden="true" tabindex="-1"></a>    gpu_result <span class="op">=</span> cp.sqrt(gpu_a<span class="op">**</span><span class="dv">2</span> <span class="op">+</span> gpu_b<span class="op">**</span><span class="dv">2</span>)</span>
<span id="cb3-22"><a href="#cb3-22" aria-hidden="true" tabindex="-1"></a>    cp.cuda.Stream.null.synchronize()  <span class="co"># Kernel launches are asynchronous, so wait for the GPU</span></span>
<span id="cb3-23"><a href="#cb3-23" aria-hidden="true" tabindex="-1"></a>    gpu_time <span class="op">=</span> time.perf_counter() <span class="op">-</span> start</span>
<span id="cb3-24"><a href="#cb3-24" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-25"><a href="#cb3-25" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"CPU time: </span><span class="sc">{</span>cpu_time<span class="sc">:.4f}</span><span class="ss"> seconds"</span>)</span>
<span id="cb3-26"><a href="#cb3-26" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"GPU time: </span><span class="sc">{</span>gpu_time<span class="sc">:.4f}</span><span class="ss"> seconds"</span>)</span>
<span id="cb3-27"><a href="#cb3-27" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Speedup: </span><span class="sc">{</span>cpu_time<span class="op">/</span>gpu_time<span class="sc">:.2f}</span><span class="ss">x"</span>)</span>
<span id="cb3-28"><a href="#cb3-28" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-29"><a href="#cb3-29" aria-hidden="true" tabindex="-1"></a>benchmark_operations()</span></code></pre></div></div>
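<p>Single-shot timings are noisy. The sketch below shows a more careful pattern: it runs a warm-up call, repeats the measurement, and reports the median. It also accepts an optional <code>sync</code> callable so that GPU work finishes before the clock stops. The <code>time_median</code> helper is hypothetical; CuPy also provides <code>cupyx.profiler.benchmark</code> for the same purpose.</p>

```python
import statistics
import time


def time_median(fn, *args, sync=None, warmup=1, repeat=5):
    """Median wall-clock seconds of fn(*args) over `repeat` timed runs.

    sync: optional zero-argument callable that blocks until queued device work
    has finished, e.g. cupy.cuda.Stream.null.synchronize for CuPy code.
    """
    for _ in range(warmup):  # first calls may compile kernels or fill caches
        fn(*args)
    if sync is not None:
        sync()

    timings = []
    for _ in range(repeat):
        start = time.perf_counter()
        fn(*args)
        if sync is not None:
            sync()  # without this, an asynchronous GPU launch looks nearly free
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)


# CPU-only usage example
print(f"{time_median(lambda n: sum(i * i for i in range(n)), 100_000):.6f} s")
```

<p>For the GPU case, the call would be <code>time_median(lambda a, b: cp.sqrt(a**2 + b**2), gpu_a, gpu_b, sync=cp.cuda.Stream.null.synchronize)</code>. This excludes host-to-device copies. If the real workload starts from NumPy data, include the copy in the timed function.</p>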
</section>
</section>
<section id="advanced-cupy-custom-kernels" class="level2">
<h2 class="anchored" data-anchor-id="advanced-cupy-custom-kernels" id="advanced-cupy-custom-kernels">Advanced CuPy: Custom Kernels</h2>
<p>When the built-in array operations are not enough, you can write a custom CUDA C++ kernel and launch it from Python with <code>cp.RawKernel</code>:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> cupy <span class="im">as</span> cp</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Define a custom kernel</span></span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a>vector_add_kernel <span class="op">=</span> cp.RawKernel(<span class="vs">r'''</span></span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a><span class="vs">extern "C" __global__</span></span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a><span class="vs">void vector_add</span><span class="kw">(</span><span class="vs">const float</span><span class="op">*</span><span class="vs"> a, const float</span><span class="op">*</span><span class="vs"> b, float</span><span class="op">*</span><span class="vs"> c, int n</span><span class="kw">)</span><span class="vs"> {</span></span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a><span class="vs">    int idx = blockIdx</span><span class="dv">.</span><span class="vs">x </span><span class="op">*</span><span class="vs"> blockDim</span><span class="dv">.</span><span class="vs">x </span><span class="op">+</span><span class="vs"> threadIdx</span><span class="dv">.</span><span class="vs">x;</span></span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a><span class="vs">    if </span><span class="kw">(</span><span class="vs">idx &lt; n</span><span class="kw">)</span><span class="vs"> {</span></span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a><span class="vs">        c</span><span class="pp">[idx]</span><span class="vs"> = a</span><span class="pp">[idx]</span><span class="vs"> </span><span class="op">+</span><span class="vs"> b</span><span class="pp">[idx]</span><span class="vs">;</span></span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a><span class="vs">    }</span></span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a><span class="vs">}</span></span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a><span class="vs">'''</span>, <span class="st">'vector_add'</span>)</span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-14"><a href="#cb4-14" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> custom_vector_add(a, b):</span>
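<span id="cb4-15"><a href="#cb4-15" aria-hidden="true" tabindex="-1"></a>    <span class="cf">assert</span> a.shape <span class="op">==</span> b.shape <span class="kw">and</span> a.dtype <span class="op">==</span> b.dtype <span class="op">==</span> cp.float32  <span class="co"># kernel is written for float*</span></span>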
<span id="cb4-16"><a href="#cb4-16" aria-hidden="true" tabindex="-1"></a>    c <span class="op">=</span> cp.empty_like(a)</span>
<span id="cb4-17"><a href="#cb4-17" aria-hidden="true" tabindex="-1"></a>    n <span class="op">=</span> a.size</span>
<span id="cb4-18"><a href="#cb4-18" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-19"><a href="#cb4-19" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Launch kernel</span></span>
<span id="cb4-20"><a href="#cb4-20" aria-hidden="true" tabindex="-1"></a>    threads_per_block <span class="op">=</span> <span class="dv">256</span></span>
<span id="cb4-21"><a href="#cb4-21" aria-hidden="true" tabindex="-1"></a>    blocks_per_grid <span class="op">=</span> (n <span class="op">+</span> threads_per_block <span class="op">-</span> <span class="dv">1</span>) <span class="op">//</span> threads_per_block</span>
<span id="cb4-22"><a href="#cb4-22" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-23"><a href="#cb4-23" aria-hidden="true" tabindex="-1"></a>    vector_add_kernel((blocks_per_grid,), (threads_per_block,), </span>
<span id="cb4-24"><a href="#cb4-24" aria-hidden="true" tabindex="-1"></a>                     (a, b, c, cp.int32(n)))  <span class="co"># kernel expects a 32-bit int; a Python int is passed as 64-bit</span></span>
<span id="cb4-25"><a href="#cb4-25" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> c</span>
<span id="cb4-26"><a href="#cb4-26" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-27"><a href="#cb4-27" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage</span></span>
<span id="cb4-28"><a href="#cb4-28" aria-hidden="true" tabindex="-1"></a>a <span class="op">=</span> cp.random.random(<span class="dv">1000000</span>).astype(cp.float32)</span>
<span id="cb4-29"><a href="#cb4-29" aria-hidden="true" tabindex="-1"></a>b <span class="op">=</span> cp.random.random(<span class="dv">1000000</span>).astype(cp.float32)</span>
<span id="cb4-30"><a href="#cb4-30" aria-hidden="true" tabindex="-1"></a>result <span class="op">=</span> custom_vector_add(a, b)</span></code></pre></div></div>
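<p>The launch configuration in <code>custom_vector_add</code> combines ceiling division with the <code>if (idx &lt; n)</code> guard inside the kernel. The pure-Python sketch below replays that per-thread logic sequentially on the CPU to show why both are needed. It is a teaching model, not a description of how the GPU schedules threads. The function names are hypothetical.</p>

```python
def launch_config(n, threads_per_block=256):
    """Smallest block count whose threads cover all n elements (ceiling division)."""
    blocks_per_grid = (n + threads_per_block - 1) // threads_per_block
    return blocks_per_grid, threads_per_block


def simulate_vector_add(a, b, threads_per_block=256):
    """Run the vector_add kernel body once per (block, thread) pair on the CPU."""
    n = len(a)
    c = [0.0] * n
    grid_dim, block_dim = launch_config(n, threads_per_block)
    for block_idx in range(grid_dim):          # blockIdx.x
        for thread_idx in range(block_dim):    # threadIdx.x
            idx = block_idx * block_dim + thread_idx
            if idx < n:                        # surplus threads in the last block do nothing
                c[idx] = a[idx] + b[idx]
    return c


blocks, threads = launch_config(1_000_000)
print(blocks, blocks * threads - 1_000_000)  # blocks launched, idle threads
print(simulate_vector_add([1.0, 2.0, 3.0], [10.0, 20.0, 30.0], threads_per_block=2))
```

<p>For the 1,000,000-element example, this gives 3,907 blocks of 256 threads, which is 192 more threads than elements. The guard stops those extra threads from writing past the end of <code>c</code>.</p>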
</section>
<section id="numba-cuda-python-to-cuda-jit-compilation" class="level2">
<h2 class="anchored" data-anchor-id="numba-cuda-python-to-cuda-jit-compilation" id="numba-cuda-python-to-cuda-jit-compilation">Numba CUDA: Python-to-CUDA JIT Compilation</h2>
<p>Numba allows you to write CUDA kernels in Python syntax:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> numba <span class="im">import</span> cuda</span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> math</span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a><span class="at">@cuda.jit</span></span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> matrix_multiply_kernel(A, B, C):</span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a>    row, col <span class="op">=</span> cuda.grid(<span class="dv">2</span>)</span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> row <span class="op">&lt;</span> C.shape[<span class="dv">0</span>] <span class="kw">and</span> col <span class="op">&lt;</span> C.shape[<span class="dv">1</span>]:</span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a>        temp <span class="op">=</span> <span class="fl">0.0</span></span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> k <span class="kw">in</span> <span class="bu">range</span>(A.shape[<span class="dv">1</span>]):</span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a>            temp <span class="op">+=</span> A[row, k] <span class="op">*</span> B[k, col]</span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a>        C[row, col] <span class="op">=</span> temp</span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> gpu_matrix_multiply(A, B):</span>
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Copy the inputs to the GPU and allocate the output there</span></span>
<span id="cb5-16"><a href="#cb5-16" aria-hidden="true" tabindex="-1"></a>    A_gpu <span class="op">=</span> cuda.to_device(A)</span>
<span id="cb5-17"><a href="#cb5-17" aria-hidden="true" tabindex="-1"></a>    B_gpu <span class="op">=</span> cuda.to_device(B)</span>
<span id="cb5-18"><a href="#cb5-18" aria-hidden="true" tabindex="-1"></a>    C_gpu <span class="op">=</span> cuda.device_array((A.shape[<span class="dv">0</span>], B.shape[<span class="dv">1</span>]), dtype<span class="op">=</span>A.dtype)</span>
<span id="cb5-19"><a href="#cb5-19" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-20"><a href="#cb5-20" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Configure grid and block dimensions</span></span>
<span id="cb5-21"><a href="#cb5-21" aria-hidden="true" tabindex="-1"></a>    threads_per_block <span class="op">=</span> (<span class="dv">16</span>, <span class="dv">16</span>)</span>
<span id="cb5-22"><a href="#cb5-22" aria-hidden="true" tabindex="-1"></a>    blocks_per_grid_x <span class="op">=</span> math.ceil(A.shape[<span class="dv">0</span>] <span class="op">/</span> threads_per_block[<span class="dv">0</span>])</span>
<span id="cb5-23"><a href="#cb5-23" aria-hidden="true" tabindex="-1"></a>    blocks_per_grid_y <span class="op">=</span> math.ceil(B.shape[<span class="dv">1</span>] <span class="op">/</span> threads_per_block[<span class="dv">1</span>])</span>
<span id="cb5-24"><a href="#cb5-24" aria-hidden="true" tabindex="-1"></a>    blocks_per_grid <span class="op">=</span> (blocks_per_grid_x, blocks_per_grid_y)</span>
<span id="cb5-25"><a href="#cb5-25" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-26"><a href="#cb5-26" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Launch kernel</span></span>
<span id="cb5-27"><a href="#cb5-27" aria-hidden="true" tabindex="-1"></a>    matrix_multiply_kernel[blocks_per_grid, threads_per_block](A_gpu, B_gpu, C_gpu)</span>
<span id="cb5-28"><a href="#cb5-28" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-29"><a href="#cb5-29" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Copy result back to host</span></span>
<span id="cb5-30"><a href="#cb5-30" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> C_gpu.copy_to_host()</span>
<span id="cb5-31"><a href="#cb5-31" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-32"><a href="#cb5-32" aria-hidden="true" tabindex="-1"></a><span class="co"># Example usage</span></span>
<span id="cb5-33"><a href="#cb5-33" aria-hidden="true" tabindex="-1"></a>A <span class="op">=</span> np.random.random((<span class="dv">1000</span>, <span class="dv">1000</span>)).astype(np.float32)</span>
<span id="cb5-34"><a href="#cb5-34" aria-hidden="true" tabindex="-1"></a>B <span class="op">=</span> np.random.random((<span class="dv">1000</span>, <span class="dv">1000</span>)).astype(np.float32)</span>
<span id="cb5-35"><a href="#cb5-35" aria-hidden="true" tabindex="-1"></a>C <span class="op">=</span> gpu_matrix_multiply(A, B)</span></code></pre></div></div>
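<p>Because <code>math.ceil</code> rounds the grid size up, a 1000 x 1000 matrix launched with 16 x 16 blocks gets more threads than it has elements. That is why the kernel checks <code>row &lt; C.shape[0]</code> before writing. The arithmetic can be sketched in plain Python. The helper names here are hypothetical and are not part of Numba:</p>

```python
def blocks_per_grid(shape, threads_per_block):
    """Blocks needed along each axis so that every element gets a thread.

    (n + t - 1) // t is integer ceiling division: it equals math.ceil(n / t)
    for positive integers, without a round-trip through floating point.
    """
    return tuple((n + t - 1) // t for n, t in zip(shape, threads_per_block))


def launched_threads(shape, threads_per_block):
    """Total threads launched along each axis (never less than the array extent)."""
    grid = blocks_per_grid(shape, threads_per_block)
    return tuple(g * t for g, t in zip(grid, threads_per_block))


# The 1000 x 1000 example with (16, 16) threads per block
print(blocks_per_grid((1000, 1000), (16, 16)))   # (63, 63)
print(launched_threads((1000, 1000), (16, 16)))  # (1008, 1008)
```

<p>This configuration launches 8 extra rows and 8 extra columns of threads. Without the bounds check, those threads would index past the end of <code>C</code>.</p>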
</section>
<section id="memory-management-best-practices" class="level2">
<h2 class="anchored" data-anchor-id="memory-management-best-practices" id="memory-management-best-practices">Memory Management Best Practices</h2>
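<p>CuPy routes device allocations through a memory pool. When an array is no longer referenced, its memory is cached for reuse rather than returned to the driver immediately. Monitoring and releasing this pool keeps GPU memory usage under control:</p>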
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> cupy <span class="im">as</span> cp</span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Memory pool for efficient allocation</span></span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a>mempool <span class="op">=</span> cp.get_default_memory_pool()</span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a>pinned_mempool <span class="op">=</span> cp.get_default_pinned_memory_pool()</span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> efficient_gpu_computation():</span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Select GPU 0 for the operations in this block (this does not free memory on exit)</span></span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> cp.cuda.Device(<span class="dv">0</span>):  <span class="co"># Use GPU 0</span></span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Pre-allocate memory</span></span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a>        data <span class="op">=</span> cp.zeros((<span class="dv">10000</span>, <span class="dv">10000</span>), dtype<span class="op">=</span>cp.float32)</span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Perform computations</span></span>
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a>        result <span class="op">=</span> cp.fft.fft2(data)</span>
<span id="cb6-15"><a href="#cb6-15" aria-hidden="true" tabindex="-1"></a>        result <span class="op">=</span> cp.<span class="bu">abs</span>(result)</span>
<span id="cb6-16"><a href="#cb6-16" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-17"><a href="#cb6-17" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Memory info</span></span>
<span id="cb6-18"><a href="#cb6-18" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Memory used: </span><span class="sc">{</span>mempool<span class="sc">.</span>used_bytes() <span class="op">/</span> <span class="dv">1024</span><span class="op">**</span><span class="dv">2</span><span class="sc">:.1f}</span><span class="ss"> MB"</span>)</span>
<span id="cb6-19"><a href="#cb6-19" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Memory total: </span><span class="sc">{</span>mempool<span class="sc">.</span>total_bytes() <span class="op">/</span> <span class="dv">1024</span><span class="op">**</span><span class="dv">2</span><span class="sc">:.1f}</span><span class="ss"> MB"</span>)</span>
<span id="cb6-20"><a href="#cb6-20" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-21"><a href="#cb6-21" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> cp.asnumpy(result)</span>
<span id="cb6-22"><a href="#cb6-22" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-23"><a href="#cb6-23" aria-hidden="true" tabindex="-1"></a><span class="co"># Release cached blocks that no live array is using</span></span>
<span id="cb6-24"><a href="#cb6-24" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> cleanup_gpu_memory():</span>
<span id="cb6-25"><a href="#cb6-25" aria-hidden="true" tabindex="-1"></a>    mempool.free_all_blocks()</span>
<span id="cb6-26"><a href="#cb6-26" aria-hidden="true" tabindex="-1"></a>    pinned_mempool.free_all_blocks()</span></code></pre></div></div>
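<p>The pool must at least cover the arrays that are alive at any moment. Estimating each array’s size (element count times bytes per element) before running a computation therefore tells you whether it will fit. This is a minimal sketch, assuming 4 bytes per float32 element and a complex64 (8 bytes per element) FFT result. <code>array_nbytes</code> is a hypothetical helper:</p>

```python
import math

MiB = 1024 ** 2  # same divisor as the "MB" figures printed by efficient_gpu_computation()


def array_nbytes(shape, itemsize):
    """Bytes for a dense array: number of elements times bytes per element."""
    return math.prod(shape) * itemsize


# Input of efficient_gpu_computation(): 10000 x 10000 float32 (4 bytes/element)
data_bytes = array_nbytes((10000, 10000), 4)
# Its 2-D FFT, assumed to be complex64 (8 bytes/element)
fft_bytes = array_nbytes((10000, 10000), 8)

print(f"input: {data_bytes / MiB:.1f} MB")  # input: 381.5 MB
print(f"fft:   {fft_bytes / MiB:.1f} MB")   # fft:   762.9 MB
```

<p>The input and the FFT result are alive at the same time, so under these assumptions the pool already needs more than 1,100 MB at that point.</p>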
</section>
<section id="real-world-applications" class="level2">
<h2 class="anchored" data-anchor-id="real-world-applications" id="real-world-applications">Real-World Applications</h2>
<section id="image-processing-pipeline" class="level3">
<h3 class="anchored" data-anchor-id="image-processing-pipeline" id="image-processing-pipeline">Image Processing Pipeline</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> cupy <span class="im">as</span> cp</span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> cupyx.scipy <span class="im">import</span> ndimage</span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> gpu_image_processing(image):</span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""GPU-accelerated image processing pipeline"""</span></span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Copy to the GPU as float32 so blurring and Sobel gradients of integer images are not truncated or wrapped</span></span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a>    gpu_image <span class="op">=</span> cp.asarray(image, dtype<span class="op">=</span>cp.float32)</span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Apply Gaussian blur</span></span>
<span id="cb7-10"><a href="#cb7-10" aria-hidden="true" tabindex="-1"></a>    blurred <span class="op">=</span> ndimage.gaussian_filter(gpu_image, sigma<span class="op">=</span><span class="fl">2.0</span>)</span>
<span id="cb7-11"><a href="#cb7-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-12"><a href="#cb7-12" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Edge detection (Sobel filter)</span></span>
<span id="cb7-13"><a href="#cb7-13" aria-hidden="true" tabindex="-1"></a>    sobel_x <span class="op">=</span> ndimage.sobel(blurred, axis<span class="op">=</span><span class="dv">0</span>)</span>
<span id="cb7-14"><a href="#cb7-14" aria-hidden="true" tabindex="-1"></a>    sobel_y <span class="op">=</span> ndimage.sobel(blurred, axis<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb7-15"><a href="#cb7-15" aria-hidden="true" tabindex="-1"></a>    edges <span class="op">=</span> cp.sqrt(sobel_x<span class="op">**</span><span class="dv">2</span> <span class="op">+</span> sobel_y<span class="op">**</span><span class="dv">2</span>)</span>
<span id="cb7-16"><a href="#cb7-16" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-17"><a href="#cb7-17" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Threshold</span></span>
<span id="cb7-18"><a href="#cb7-18" aria-hidden="true" tabindex="-1"></a>    threshold <span class="op">=</span> cp.percentile(edges, <span class="dv">90</span>)</span>
<span id="cb7-19"><a href="#cb7-19" aria-hidden="true" tabindex="-1"></a>    binary <span class="op">=</span> edges <span class="op">&gt;</span> threshold</span>
<span id="cb7-20"><a href="#cb7-20" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-21"><a href="#cb7-21" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> cp.asnumpy(binary)</span></code></pre></div></div>
</section>
<section id="monte-carlo-simulation" class="level3">
<h3 class="anchored" data-anchor-id="monte-carlo-simulation" id="monte-carlo-simulation">Monte Carlo Simulation</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb8"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> numba <span class="im">import</span> cuda</span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> numba.cuda.random <span class="im">import</span> create_xoroshiro128p_states, xoroshiro128p_uniform_float32</span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a><span class="at">@cuda.jit</span></span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> monte_carlo_pi_kernel(rng_states, n_samples, results):</span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a>    idx <span class="op">=</span> cuda.grid(<span class="dv">1</span>)</span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> idx <span class="op">&lt;</span> rng_states.shape[<span class="dv">0</span>]:</span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a>        count <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(n_samples):</span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Each thread advances only its own RNG state</span></span>
<span id="cb8-12"><a href="#cb8-12" aria-hidden="true" tabindex="-1"></a>            x <span class="op">=</span> xoroshiro128p_uniform_float32(rng_states, idx)</span>
<span id="cb8-13"><a href="#cb8-13" aria-hidden="true" tabindex="-1"></a>            y <span class="op">=</span> xoroshiro128p_uniform_float32(rng_states, idx)</span>
<span id="cb8-14"><a href="#cb8-14" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> x<span class="op">*</span>x <span class="op">+</span> y<span class="op">*</span>y <span class="op">&lt;=</span> <span class="fl">1.0</span>:</span>
<span id="cb8-15"><a href="#cb8-15" aria-hidden="true" tabindex="-1"></a>                count <span class="op">+=</span> <span class="dv">1</span></span>
<span id="cb8-16"><a href="#cb8-16" aria-hidden="true" tabindex="-1"></a>        results[idx] <span class="op">=</span> count</span>
<span id="cb8-17"><a href="#cb8-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-18"><a href="#cb8-18" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> estimate_pi_gpu(n_threads<span class="op">=</span><span class="dv">1024</span>, n_samples_per_thread<span class="op">=</span><span class="dv">10000</span>):</span>
<span id="cb8-19"><a href="#cb8-19" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Initialize one random number generator state per thread</span></span>
<span id="cb8-20"><a href="#cb8-20" aria-hidden="true" tabindex="-1"></a>    rng_states <span class="op">=</span> create_xoroshiro128p_states(n_threads, seed<span class="op">=</span><span class="dv">42</span>)</span>
<span id="cb8-21"><a href="#cb8-21" aria-hidden="true" tabindex="-1"></a>    results <span class="op">=</span> cuda.device_array(n_threads, dtype<span class="op">=</span>np.int32)</span>
<span id="cb8-22"><a href="#cb8-22" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-23"><a href="#cb8-23" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Launch kernel</span></span>
<span id="cb8-24"><a href="#cb8-24" aria-hidden="true" tabindex="-1"></a>    threads_per_block <span class="op">=</span> <span class="dv">256</span></span>
<span id="cb8-25"><a href="#cb8-25" aria-hidden="true" tabindex="-1"></a>    blocks_per_grid <span class="op">=</span> (n_threads <span class="op">+</span> threads_per_block <span class="op">-</span> <span class="dv">1</span>) <span class="op">//</span> threads_per_block</span>
<span id="cb8-26"><a href="#cb8-26" aria-hidden="true" tabindex="-1"></a>    monte_carlo_pi_kernel[blocks_per_grid, threads_per_block](</span>
<span id="cb8-27"><a href="#cb8-27" aria-hidden="true" tabindex="-1"></a>        rng_states, n_samples_per_thread, results)</span>
<span id="cb8-28"><a href="#cb8-28" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-29"><a href="#cb8-29" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Copy the per-thread counts to the host and sum them with NumPy</span></span>
<span id="cb8-30"><a href="#cb8-30" aria-hidden="true" tabindex="-1"></a>    total_inside <span class="op">=</span> results.copy_to_host().<span class="bu">sum</span>()</span>
<span id="cb8-31"><a href="#cb8-31" aria-hidden="true" tabindex="-1"></a>    total_samples <span class="op">=</span> n_threads <span class="op">*</span> n_samples_per_thread</span>
<span id="cb8-32"><a href="#cb8-32" aria-hidden="true" tabindex="-1"></a>    pi_estimate <span class="op">=</span> <span class="fl">4.0</span> <span class="op">*</span> total_inside <span class="op">/</span> total_samples</span>
<span id="cb8-33"><a href="#cb8-33" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-34"><a href="#cb8-34" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> pi_estimate</span>
<span id="cb8-35"><a href="#cb8-35" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-36"><a href="#cb8-36" aria-hidden="true" tabindex="-1"></a>pi_gpu <span class="op">=</span> estimate_pi_gpu()</span>
<span id="cb8-37"><a href="#cb8-37" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"GPU Pi estimate: </span><span class="sc">{</span>pi_gpu<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
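<p>To sanity-check the GPU result, you can emulate the same per-thread structure on the CPU. Each simulated thread owns a generator and a partial count, and the partial counts are then summed. This sketch uses Python’s <code>random</code> module with one seeded generator per simulated thread, which is a simplification of Numba’s per-thread xoroshiro128+ states. <code>estimate_pi_cpu</code> is a hypothetical name:</p>

```python
import random


def estimate_pi_cpu(n_threads=8, n_samples_per_thread=20000, seed=42):
    """CPU reference with the same shape as estimate_pi_gpu: per-thread
    partial counts followed by a final reduction (sum)."""
    counts = []
    for t in range(n_threads):
        rng = random.Random(seed + t)  # one generator per simulated thread
        inside = 0
        for _ in range(n_samples_per_thread):
            x, y = rng.random(), rng.random()
            if x * x + y * y <= 1.0:
                inside += 1
        counts.append(inside)  # what results[idx] holds on the GPU
    total_samples = n_threads * n_samples_per_thread
    return 4.0 * sum(counts) / total_samples


print(f"CPU Pi estimate: {estimate_pi_cpu()}")
```

<p>With N samples, the standard error of the estimate is about 4 * sqrt(p * (1 - p) / N), where p = pi / 4. The GPU version’s 1024 x 10,000 samples should therefore land within roughly 0.0005 of pi. This smaller CPU run, with 160,000 samples, is good to about 0.004.</p>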
</section>
</section>
<section id="performance-optimization-tips" class="level2">
<h2 class="anchored" data-anchor-id="performance-optimization-tips" id="performance-optimization-tips">Performance Optimization Tips</h2>
<section id="memory-access-patterns" class="level3">
<h3 class="anchored" data-anchor-id="memory-access-patterns" id="memory-access-patterns">1. Memory Access Patterns</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb9"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numba</span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> numba <span class="im">import</span> cuda</span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Bad: consecutive threads (consecutive i) read down a column of A, so the reads are non-coalesced</span></span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a><span class="at">@cuda.jit</span></span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> bad_transpose(A, A_T):</span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a>    i, j <span class="op">=</span> cuda.grid(<span class="dv">2</span>)</span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> i <span class="op">&lt;</span> A.shape[<span class="dv">0</span>] <span class="kw">and</span> j <span class="op">&lt;</span> A.shape[<span class="dv">1</span>]:</span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a>        A_T[j, i] <span class="op">=</span> A[i, j]  <span class="co"># Strided read: neighboring threads are a full row apart</span></span>
<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-11"><a href="#cb9-11" aria-hidden="true" tabindex="-1"></a><span class="co"># Good: stage a 16 x 16 tile in shared memory so both the read and the write are coalesced</span></span>
<span id="cb9-12"><a href="#cb9-12" aria-hidden="true" tabindex="-1"></a><span class="co"># Assumes float32 arrays and a launch with threads_per_block=(16, 16)</span></span>
<span id="cb9-13"><a href="#cb9-13" aria-hidden="true" tabindex="-1"></a><span class="at">@cuda.jit</span></span>
<span id="cb9-14"><a href="#cb9-14" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> good_transpose(A, A_T):</span>
<span id="cb9-15"><a href="#cb9-15" aria-hidden="true" tabindex="-1"></a>    <span class="co"># The extra column of padding avoids shared-memory bank conflicts when reading by column</span></span>
<span id="cb9-16"><a href="#cb9-16" aria-hidden="true" tabindex="-1"></a>    tile <span class="op">=</span> cuda.shared.array((<span class="dv">16</span>, <span class="dv">17</span>), dtype<span class="op">=</span>numba.float32)</span>
<span id="cb9-17"><a href="#cb9-17" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-18"><a href="#cb9-18" aria-hidden="true" tabindex="-1"></a>    tx <span class="op">=</span> cuda.threadIdx.x</span>
<span id="cb9-19"><a href="#cb9-19" aria-hidden="true" tabindex="-1"></a>    ty <span class="op">=</span> cuda.threadIdx.y</span>
<span id="cb9-20"><a href="#cb9-20" aria-hidden="true" tabindex="-1"></a>    bx <span class="op">=</span> cuda.blockIdx.x <span class="op">*</span> <span class="dv">16</span></span>
<span id="cb9-21"><a href="#cb9-21" aria-hidden="true" tabindex="-1"></a>    by <span class="op">=</span> cuda.blockIdx.y <span class="op">*</span> <span class="dv">16</span></span>
<span id="cb9-22"><a href="#cb9-22" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-23"><a href="#cb9-23" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Coalesced read: consecutive tx load consecutive columns of A</span></span>
<span id="cb9-24"><a href="#cb9-24" aria-hidden="true" tabindex="-1"></a>    x <span class="op">=</span> bx <span class="op">+</span> tx</span>
<span id="cb9-25"><a href="#cb9-25" aria-hidden="true" tabindex="-1"></a>    y <span class="op">=</span> by <span class="op">+</span> ty</span>
<span id="cb9-26"><a href="#cb9-26" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> x <span class="op">&lt;</span> A.shape[<span class="dv">1</span>] <span class="kw">and</span> y <span class="op">&lt;</span> A.shape[<span class="dv">0</span>]:</span>
<span id="cb9-27"><a href="#cb9-27" aria-hidden="true" tabindex="-1"></a>        tile[ty, tx] <span class="op">=</span> A[y, x]</span>
<span id="cb9-28"><a href="#cb9-28" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-29"><a href="#cb9-29" aria-hidden="true" tabindex="-1"></a>    cuda.syncthreads()  <span class="co"># Wait until the whole tile is loaded</span></span>
<span id="cb9-30"><a href="#cb9-30" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-31"><a href="#cb9-31" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Swap the block offsets so consecutive tx write consecutive columns of A_T</span></span>
<span id="cb9-32"><a href="#cb9-32" aria-hidden="true" tabindex="-1"></a>    x <span class="op">=</span> by <span class="op">+</span> tx</span>
<span id="cb9-33"><a href="#cb9-33" aria-hidden="true" tabindex="-1"></a>    y <span class="op">=</span> bx <span class="op">+</span> ty</span>
<span id="cb9-34"><a href="#cb9-34" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-35"><a href="#cb9-35" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> x <span class="op">&lt;</span> A_T.shape[<span class="dv">1</span>] <span class="kw">and</span> y <span class="op">&lt;</span> A_T.shape[<span class="dv">0</span>]:</span>
<span id="cb9-36"><a href="#cb9-36" aria-hidden="true" tabindex="-1"></a>        A_T[y, x] <span class="op">=</span> tile[tx, ty]</span></code></pre></div></div>
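<p>The index arithmetic of a tiled transpose is easy to get wrong, and a pure-Python emulation is a cheap way to check it. The sketch below follows the standard scheme. Each thread first reads from a tile at block offset (bx, by). It then writes back with the offsets swapped to (by, bx) and the roles of tx and ty exchanged. It runs on nested lists rather than on the GPU, and the function name is hypothetical:</p>

```python
def tiled_transpose_reference(A, tile=16):
    """Emulate the two phases of a shared-memory tiled transpose on a list of rows."""
    rows, cols = len(A), len(A[0])
    A_T = [[None] * rows for _ in range(cols)]
    for by in range(0, rows, tile):          # blockIdx.y * tile
        for bx in range(0, cols, tile):      # blockIdx.x * tile
            shared = [[None] * tile for _ in range(tile)]
            # Phase 1: thread (tx, ty) loads A[by + ty][bx + tx] into shared[ty][tx]
            for ty in range(tile):
                for tx in range(tile):
                    y, x = by + ty, bx + tx
                    if x < cols and y < rows:
                        shared[ty][tx] = A[y][x]
            # (cuda.syncthreads() separates the two phases on the GPU)
            # Phase 2: offsets swapped; thread (tx, ty) writes A_T[bx + ty][by + tx]
            for ty in range(tile):
                for tx in range(tile):
                    x, y = by + tx, bx + ty
                    if x < rows and y < cols:
                        A_T[y][x] = shared[tx][ty]
    return A_T


# A non-square matrix whose sides are not multiples of 16
A = [[r * 100 + c for c in range(37)] for r in range(21)]
print(tiled_transpose_reference(A) == [list(col) for col in zip(*A)])  # True
```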
</section>
<section id="stream-processing" class="level3">
<h3 class="anchored" data-anchor-id="stream-processing" id="stream-processing">2. Stream Processing</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb10"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> cupy <span class="im">as</span> cp</span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> async_processing():</span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Create two streams so the two FFTs can run concurrently (overlap is possible, not guaranteed)</span></span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a>    stream1 <span class="op">=</span> cp.cuda.Stream()</span>
<span id="cb10-6"><a href="#cb10-6" aria-hidden="true" tabindex="-1"></a>    stream2 <span class="op">=</span> cp.cuda.Stream()</span>
<span id="cb10-7"><a href="#cb10-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-8"><a href="#cb10-8" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Process data in chunks</span></span>
<span id="cb10-9"><a href="#cb10-9" aria-hidden="true" tabindex="-1"></a>    chunk_size <span class="op">=</span> <span class="dv">1000000</span></span>
<span id="cb10-10"><a href="#cb10-10" aria-hidden="true" tabindex="-1"></a>    data1 <span class="op">=</span> cp.random.random(chunk_size)</span>
<span id="cb10-11"><a href="#cb10-11" aria-hidden="true" tabindex="-1"></a>    data2 <span class="op">=</span> cp.random.random(chunk_size)</span>
<span id="cb10-12"><a href="#cb10-12" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-13"><a href="#cb10-13" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> stream1:</span>
<span id="cb10-14"><a href="#cb10-14" aria-hidden="true" tabindex="-1"></a>        result1 <span class="op">=</span> cp.fft.fft(data1)</span>
<span id="cb10-15"><a href="#cb10-15" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-16"><a href="#cb10-16" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> stream2:</span>
<span id="cb10-17"><a href="#cb10-17" aria-hidden="true" tabindex="-1"></a>        result2 <span class="op">=</span> cp.fft.fft(data2)</span>
<span id="cb10-18"><a href="#cb10-18" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-19"><a href="#cb10-19" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Synchronize streams</span></span>
<span id="cb10-20"><a href="#cb10-20" aria-hidden="true" tabindex="-1"></a>    stream1.synchronize()</span>
<span id="cb10-21"><a href="#cb10-21" aria-hidden="true" tabindex="-1"></a>    stream2.synchronize()</span>
<span id="cb10-22"><a href="#cb10-22" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-23"><a href="#cb10-23" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> result1, result2</span></code></pre></div></div>
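<p>Streams pay off mainly when a large array is split into chunks and each chunk’s host-to-device copy, computation, and copy back are issued on its own stream. Transfers for one chunk can then overlap work on another, although truly asynchronous copies require pinned host memory. The scheduling part can be sketched in plain Python. <code>round_robin_chunks</code> is a hypothetical helper; in CuPy, each chunk would then be processed inside <code>with streams[s]:</code>.</p>

```python
def round_robin_chunks(n_items, chunk_size, n_streams):
    """Split range(n_items) into (stream_index, start, stop) triples,
    handing chunks to the streams in round-robin order."""
    chunks = []
    for i, start in enumerate(range(0, n_items, chunk_size)):
        stop = min(start + chunk_size, n_items)  # the last chunk may be shorter
        chunks.append((i % n_streams, start, stop))
    return chunks


print(round_robin_chunks(10, 4, 2))  # [(0, 0, 4), (1, 4, 8), (0, 8, 10)]
```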
</section>
</section>
<section id="debugging-and-profiling" class="level2">
<h2 class="anchored" data-anchor-id="debugging-and-profiling" id="debugging-and-profiling">Debugging and Profiling</h2>
<section id="error-handling" class="level3">
<h3 class="anchored" data-anchor-id="error-handling" id="error-handling">Error Handling</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb11"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> cupy <span class="im">as</span> cp</span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> safe_gpu_computation():</span>
<span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a>    <span class="cf">try</span>:</span>
<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a>        <span class="co"># GPU computation that might fail</span></span>
<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a>        large_array <span class="op">=</span> cp.zeros((<span class="dv">50000</span>, <span class="dv">50000</span>), dtype<span class="op">=</span>cp.float64)</span>
<span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a>        result <span class="op">=</span> cp.linalg.svd(large_array)</span>
<span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> result</span>
<span id="cb11-9"><a href="#cb11-9" aria-hidden="true" tabindex="-1"></a>    <span class="cf">except</span> cp.cuda.memory.OutOfMemoryError:</span>
<span id="cb11-10"><a href="#cb11-10" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="st">"GPU out of memory. Try reducing array size."</span>)</span>
<span id="cb11-11"><a href="#cb11-11" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">None</span></span>
<span id="cb11-12"><a href="#cb11-12" aria-hidden="true" tabindex="-1"></a>    <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:</span>
<span id="cb11-13"><a href="#cb11-13" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"GPU computation failed: </span><span class="sc">{</span>e<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb11-14"><a href="#cb11-14" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">None</span></span></code></pre></div></div>
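<p>Instead of only printing “Try reducing array size”, you can automate the retry: process the data in batches and halve the batch size after each out-of-memory error. This is a minimal sketch in plain Python. <code>run_in_batches</code> and the <code>work(start, stop)</code> callback are hypothetical. With CuPy, you would pass <code>oom_error=cp.cuda.memory.OutOfMemoryError</code>, and you might call the memory pool’s <code>free_all_blocks()</code> before retrying.</p>

```python
def run_in_batches(work, n_items, batch_size, min_batch=1, oom_error=MemoryError):
    """Call work(start, stop) over [0, n_items) in batches, halving the batch
    size whenever work raises oom_error. Returns (results, final_batch_size)."""
    results = []
    start = 0
    while start < n_items:
        stop = min(start + batch_size, n_items)
        try:
            results.append(work(start, stop))
        except oom_error:
            if batch_size <= min_batch:
                raise  # cannot shrink any further
            batch_size = max(min_batch, batch_size // 2)
            continue  # retry the same start with a smaller batch
        start = stop
    return results, batch_size


def fake_work(start, stop):
    """Stand-in for a GPU computation that only fits 300 items at once."""
    if stop - start > 300:
        raise MemoryError("pretend the GPU ran out of memory")
    return stop - start


sizes, final_batch = run_in_batches(fake_work, 1000, batch_size=1000)
print(sizes, final_batch)  # [250, 250, 250, 250] 250
```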
</section>
<section id="profiling-with-cupy" class="level3">
<h3 class="anchored" data-anchor-id="profiling-with-cupy" id="profiling-with-cupy">Profiling with CuPy</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb12"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><a href="#cb12-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> cupy <span class="im">as</span> cp</span>
<span id="cb12-2"><a href="#cb12-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-3"><a href="#cb12-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Mark the start of the region an external profiler should capture</span></span>
<span id="cb12-4"><a href="#cb12-4" aria-hidden="true" tabindex="-1"></a>cp.cuda.profiler.start()</span>
<span id="cb12-5"><a href="#cb12-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-6"><a href="#cb12-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Your GPU code here</span></span>
<span id="cb12-7"><a href="#cb12-7" aria-hidden="true" tabindex="-1"></a>data <span class="op">=</span> cp.random.random((<span class="dv">5000</span>, <span class="dv">5000</span>))</span>
<span id="cb12-8"><a href="#cb12-8" aria-hidden="true" tabindex="-1"></a>result <span class="op">=</span> cp.linalg.eigh(data <span class="op">+</span> data.T)  <span class="co"># eigh expects a symmetric matrix</span></span>
<span id="cb12-9"><a href="#cb12-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-10"><a href="#cb12-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Stop profiling</span></span>
<span id="cb12-11"><a href="#cb12-11" aria-hidden="true" tabindex="-1"></a>cp.cuda.profiler.stop()</span>
<span id="cb12-12"><a href="#cb12-12" aria-hidden="true" tabindex="-1"></a></span>
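<span id="cb12-13"><a href="#cb12-13" aria-hidden="true" tabindex="-1"></a><span class="co"># Capture this range with Nsight Systems, e.g.: nsys profile --capture-range=cudaProfilerApi python script.py</span></span>
<span id="cb12-14"><a href="#cb12-14" aria-hidden="true" tabindex="-1"></a><span class="co"># (nvprof is not supported on GPUs with compute capability 8.0 and higher)</span></span></code></pre></div></div>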
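<p>Kernel launches return before the GPU finishes. Timing GPU code with a plain wall clock therefore measures little more than launch overhead unless you synchronize before reading the clock. Below is a minimal sketch of a timing helper that takes a synchronization callback. The function name is hypothetical. For CuPy, you might pass <code>sync=cp.cuda.Device().synchronize</code>, and CuPy also ships <code>cupyx.profiler.benchmark</code> for this purpose.</p>

```python
import statistics
import time


def time_call(fn, *args, sync=lambda: None, warmup=1, repeat=5):
    """Return min and median wall-clock seconds for fn(*args).

    `sync` must block until all queued device work has finished; the default
    no-op is only appropriate for CPU code.
    """
    for _ in range(warmup):  # first calls may include JIT compilation or plan setup
        fn(*args)
    sync()
    times = []
    for _ in range(repeat):
        start = time.perf_counter()
        fn(*args)
        sync()  # make sure the device work is inside the measured interval
        times.append(time.perf_counter() - start)
    return {"min": min(times), "median": statistics.median(times)}


# CPU usage example; timing a GPU call would also pass sync=...
print(time_call(sorted, list(range(100_000, 0, -1))))
```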
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>CUDA Python opens up powerful GPU acceleration capabilities for Python developers. Whether you’re processing large datasets, running complex simulations, or implementing machine learning algorithms, combining Python’s ease of use with CUDA’s parallel computing power can deliver large speedups. This holds when the workload has enough parallelism to outweigh the cost of moving data between the CPU and the GPU.</p>
<p>Key takeaways:</p>
<ul>
<li>Start with CuPy for NumPy-like GPU operations</li>
<li>Use Numba for custom CUDA kernels in Python</li>
<li>Pay attention to memory management and access patterns</li>
<li>Profile your code to identify bottlenecks</li>
<li>Consider the data transfer overhead between CPU and GPU</li>
</ul>
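<p>A back-of-the-envelope estimate makes the last point concrete: transfer time is roughly the number of bytes moved divided by host-device bandwidth. The sketch below assumes an effective bandwidth of 12 GB/s purely for illustration, so measure your own system. <code>transfer_seconds</code> is a hypothetical helper:</p>

```python
def transfer_seconds(n_bytes, bandwidth_gb_per_s):
    """Idealized time to move n_bytes over a link (1 GB = 1e9 bytes)."""
    return n_bytes / (bandwidth_gb_per_s * 1e9)


# gpu_matrix_multiply() copies two 1000 x 1000 float32 inputs to the GPU
# and one 1000 x 1000 float32 result back: 3 * 1000 * 1000 * 4 bytes
bytes_moved = 3 * 1000 * 1000 * 4

# ASSUMPTION: 12 GB/s effective bandwidth, chosen only for illustration
print(f"{transfer_seconds(bytes_moved, 12) * 1e3:.2f} ms")  # 1.00 ms
```

<p>If the kernel itself finishes in a comparable time, the copies dominate. For this reason, keeping data resident on the GPU across several operations usually matters more than speeding up any single kernel.</p>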
<p>As GPU computing continues to evolve, CUDA Python remains an essential tool for high-performance computing in the Python ecosystem. The examples and techniques covered in this article provide a solid foundation for building GPU-accelerated applications that can handle the computational demands of modern data science and scientific computing.</p>


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Python Decorators: A Complete Guide with Useful Examples]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/python/python-decorators/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/python/python-decorators/</guid>
      <pubDate>Sun, 22 Jun 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>beginner</category>
      <content:encoded><![CDATA[






<section id="python-decorators-a-complete-guide-with-useful-examples" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/python/python-decorators/deco.png" class="img-fluid"></p>
<p>Python decorators are one of the most powerful and elegant features of the language. They allow you to modify or enhance the behavior of functions, methods, or classes without permanently altering their structure. This article explores decorators from the ground up and presents several useful decorators you can implement in your projects.</p>
<section id="understanding-decorators" class="level2">
<h2 class="anchored" data-anchor-id="understanding-decorators" id="understanding-decorators">Understanding Decorators</h2>
<p>At its core, a decorator is a callable that takes a function as an argument and returns a new function, usually a wrapper that adds behavior before or after calling the original. Decorators rely on Python’s first-class functions: functions can be assigned to variables, passed as arguments, and returned from other functions.</p>
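<p>All three first-class-function properties fit in a few lines of plain Python. The function names below are made up for illustration:</p>

```python
def shout(text):
    """Return the text in upper case."""
    return text.upper()


# 1. Assign a function to a variable (no parentheses, so no call)
yell = shout


# 2. Pass a function as an argument to another function
def apply(func, value):
    return func(value)


# 3. Return a function from another function (a closure)
def make_greeter(greeting):
    def greeter(name):
        # 'greeting' is remembered from the enclosing scope
        return f"{greeting}, {name}!"
    return greeter


print(yell("hello"))               # HELLO
print(apply(shout, "decorators"))  # DECORATORS
say_hi = make_greeter("Hi")
print(say_hi("Ada"))               # Hi, Ada!
```

<p>Step 3 is the mechanism every decorator in this article builds on: the outer function receives something, and the inner function it returns keeps access to it.</p>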
<section id="basic-decorator-structure" class="level3">
<h3 class="anchored" data-anchor-id="basic-decorator-structure" id="basic-decorator-structure">Basic Decorator Structure</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> my_decorator(func):</span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> wrapper(<span class="op">*</span>args, <span class="op">**</span>kwargs):</span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Code to execute before the original function</span></span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a>        result <span class="op">=</span> func(<span class="op">*</span>args, <span class="op">**</span>kwargs)</span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Code to execute after the original function</span></span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> result</span>
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> wrapper</span>
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-9"><a href="#cb1-9" aria-hidden="true" tabindex="-1"></a><span class="co"># Using the decorator</span></span>
<span id="cb1-10"><a href="#cb1-10" aria-hidden="true" tabindex="-1"></a><span class="at">@my_decorator</span></span>
<span id="cb1-11"><a href="#cb1-11" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> my_function():</span>
<span id="cb1-12"><a href="#cb1-12" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="st">"Hello, World!"</span>)</span></code></pre></div></div>
<p>The <code>@my_decorator</code> syntax is equivalent to writing <code>my_function = my_decorator(my_function)</code>.</p>
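<p>Because the <code>@</code> line is shorthand for that assignment, the decorator runs once, when the function is defined, and the name is then bound to <code>wrapper</code>. The sketch below (function names are illustrative) shows a side effect of this: without <code>functools.wraps</code>, the decorated function reports the wrapper’s name and loses its docstring. That is why the remaining examples in this article apply <code>@functools.wraps(func)</code> to their wrappers.</p>

```python
import functools


def plain(func):
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper


def preserving(func):
    @functools.wraps(func)  # copies __name__, __doc__, etc. onto wrapper
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper


@plain
def greet():
    """Say hello."""
    return "Hello, World!"


@preserving
def greet_wrapped():
    """Say hello."""
    return "Hello, World!"


# Applying a decorator by hand is equivalent to the @ syntax
def farewell():
    return "Goodbye!"


farewell = plain(farewell)

print(greet.__name__, greet.__doc__)                  # wrapper None
print(greet_wrapped.__name__, greet_wrapped.__doc__)  # greet_wrapped Say hello.
print(farewell.__name__)                              # wrapper
```

<p><code>functools.wraps</code> also stores the undecorated function as <code>__wrapped__</code>, which is handy for testing or introspection.</p>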
</section>
</section>
<section id="essential-decorator-patterns" class="level2">
<h2 class="anchored" data-anchor-id="essential-decorator-patterns" id="essential-decorator-patterns">Essential Decorator Patterns</h2>
<section id="timing-decorator" class="level3">
<h3 class="anchored" data-anchor-id="timing-decorator" id="timing-decorator">1. Timing Decorator</h3>
<p>This decorator measures how long a function takes to execute, which is handy for quick performance checks during development.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> functools</span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> timer(func):</span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a>    <span class="at">@functools.wraps</span>(func)</span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> wrapper(<span class="op">*</span>args, <span class="op">**</span>kwargs):</span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a>        start_time <span class="op">=</span> time.perf_counter()  <span class="co"># monotonic, high-resolution clock for timing</span></span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a>        result <span class="op">=</span> func(<span class="op">*</span>args, <span class="op">**</span>kwargs)</span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a>        end_time <span class="op">=</span> time.perf_counter()</span>
<span id="cb2-10"><a href="#cb2-10" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"</span><span class="sc">{</span>func<span class="sc">.</span><span class="va">__name__</span><span class="sc">}</span><span class="ss"> took </span><span class="sc">{</span>end_time <span class="op">-</span> start_time<span class="sc">:.4f}</span><span class="ss"> seconds"</span>)</span>
<span id="cb2-11"><a href="#cb2-11" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> result</span>
<span id="cb2-12"><a href="#cb2-12" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> wrapper</span>
<span id="cb2-13"><a href="#cb2-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-14"><a href="#cb2-14" aria-hidden="true" tabindex="-1"></a><span class="at">@timer</span></span>
<span id="cb2-15"><a href="#cb2-15" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> slow_function():</span>
<span id="cb2-16"><a href="#cb2-16" aria-hidden="true" tabindex="-1"></a>    time.sleep(<span class="dv">1</span>)</span>
<span id="cb2-17"><a href="#cb2-17" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> <span class="st">"Done!"</span></span>
<span id="cb2-18"><a href="#cb2-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-19"><a href="#cb2-19" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage</span></span>
<span id="cb2-20"><a href="#cb2-20" aria-hidden="true" tabindex="-1"></a>slow_function()  <span class="co"># Output: slow_function took 1.0041 seconds</span></span></code></pre></div></div>
</section>
<section id="retry-decorator" class="level3">
<h3 class="anchored" data-anchor-id="retry-decorator" id="retry-decorator">2. Retry Decorator</h3>
<p>Automatically retries a function if it fails, useful for network requests or unreliable operations.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> functools</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> random</span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> retry(max_attempts<span class="op">=</span><span class="dv">3</span>, delay<span class="op">=</span><span class="dv">1</span>):</span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> decorator(func):</span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a>        <span class="at">@functools.wraps</span>(func)</span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a>        <span class="kw">def</span> wrapper(<span class="op">*</span>args, <span class="op">**</span>kwargs):</span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> attempt <span class="kw">in</span> <span class="bu">range</span>(max_attempts):</span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a>                <span class="cf">try</span>:</span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a>                    <span class="cf">return</span> func(<span class="op">*</span>args, <span class="op">**</span>kwargs)</span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a>                <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:</span>
<span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a>                    <span class="cf">if</span> attempt <span class="op">==</span> max_attempts <span class="op">-</span> <span class="dv">1</span>:</span>
<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a>                        <span class="cf">raise</span>  <span class="co"># re-raise the last exception with its traceback</span></span>
<span id="cb3-15"><a href="#cb3-15" aria-hidden="true" tabindex="-1"></a>                    <span class="bu">print</span>(<span class="ss">f"Attempt </span><span class="sc">{</span>attempt <span class="op">+</span> <span class="dv">1</span><span class="sc">}</span><span class="ss"> failed: </span><span class="sc">{</span>e<span class="sc">}</span><span class="ss">. Retrying in </span><span class="sc">{</span>delay<span class="sc">}</span><span class="ss"> seconds..."</span>)</span>
<span id="cb3-16"><a href="#cb3-16" aria-hidden="true" tabindex="-1"></a>                    time.sleep(delay)</span>
<span id="cb3-17"><a href="#cb3-17" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> wrapper</span>
<span id="cb3-18"><a href="#cb3-18" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> decorator</span>
<span id="cb3-19"><a href="#cb3-19" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-20"><a href="#cb3-20" aria-hidden="true" tabindex="-1"></a><span class="at">@retry</span>(max_attempts<span class="op">=</span><span class="dv">3</span>, delay<span class="op">=</span><span class="fl">0.5</span>)</span>
<span id="cb3-21"><a href="#cb3-21" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> unreliable_function():</span>
<span id="cb3-22"><a href="#cb3-22" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> random.random() <span class="op">&lt;</span> <span class="fl">0.7</span>:  <span class="co"># 70% chance of failure</span></span>
<span id="cb3-23"><a href="#cb3-23" aria-hidden="true" tabindex="-1"></a>        <span class="cf">raise</span> <span class="pp">Exception</span>(<span class="st">"Random failure"</span>)</span>
<span id="cb3-24"><a href="#cb3-24" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> <span class="st">"Success!"</span></span>
<span id="cb3-25"><a href="#cb3-25" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-26"><a href="#cb3-26" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage</span></span>
<span id="cb3-27"><a href="#cb3-27" aria-hidden="true" tabindex="-1"></a>result <span class="op">=</span> unreliable_function()  <span class="co"># still raises if all 3 attempts fail (about a 34% chance)</span></span>
<span id="cb3-28"><a href="#cb3-28" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(result)</span></code></pre></div></div>
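<p>The version above retries on <em>every</em> exception, including programming errors such as <code>TypeError</code>, and waits the same amount of time between attempts. A common refinement, sketched here under the assumption that only the listed exception types are worth retrying, restricts retries to those types and grows the delay exponentially. The <code>exceptions</code> and <code>backoff</code> parameters and the <code>flaky_fetch</code> function are illustrative names, not part of any library.</p>

```python
import functools
import time


def retry(max_attempts=3, delay=1.0, backoff=2.0, exceptions=(Exception,)):
    """Retry on the given exception types, multiplying the delay by `backoff` each time."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            wait = delay
            for attempt in range(1, max_attempts + 1):
                try:
                    return func(*args, **kwargs)
                except exceptions as e:
                    if attempt == max_attempts:
                        raise  # out of attempts: re-raise the last error
                    print(f"Attempt {attempt} failed: {e}. Retrying in {wait} seconds...")
                    time.sleep(wait)
                    wait *= backoff  # with delay=0.5: 0.5s, 1s, 2s, ...
        return wrapper
    return decorator


calls = {"count": 0}


@retry(max_attempts=4, delay=0.01, exceptions=(ConnectionError,))
def flaky_fetch():
    # Hypothetical network call that fails twice, then succeeds
    calls["count"] += 1
    if calls["count"] < 3:
        raise ConnectionError("temporary network glitch")
    return "payload"


print(flaky_fetch())  # two retry messages, then "payload"
```

<p>Any exception not listed in <code>exceptions</code> (a <code>ValueError</code> caused by a bug, say) propagates immediately instead of being retried.</p>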
</section>
<section id="cachememoization-decorator" class="level3">
<h3 class="anchored" data-anchor-id="cachememoization-decorator" id="cachememoization-decorator">3. Cache/Memoization Decorator</h3>
<p>Caches function results to avoid expensive recalculations for the same inputs.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> functools</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> memoize(func):</span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a>    cache <span class="op">=</span> {}</span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a>    <span class="at">@functools.wraps</span>(func)</span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> wrapper(<span class="op">*</span>args, <span class="op">**</span>kwargs):</span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Build a hashable key from the arguments (all arguments must be hashable)</span></span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a>        key <span class="op">=</span> (args, <span class="bu">frozenset</span>(kwargs.items()))</span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> key <span class="kw">not</span> <span class="kw">in</span> cache:</span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a>            cache[key] <span class="op">=</span> func(<span class="op">*</span>args, <span class="op">**</span>kwargs)</span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f"Cached result for </span><span class="sc">{</span>func<span class="sc">.</span><span class="va">__name__</span><span class="sc">}{</span>args<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb4-14"><a href="#cb4-14" aria-hidden="true" tabindex="-1"></a>        <span class="cf">else</span>:</span>
<span id="cb4-15"><a href="#cb4-15" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f"Retrieved from cache for </span><span class="sc">{</span>func<span class="sc">.</span><span class="va">__name__</span><span class="sc">}{</span>args<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb4-16"><a href="#cb4-16" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-17"><a href="#cb4-17" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> cache[key]</span>
<span id="cb4-18"><a href="#cb4-18" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> wrapper</span>
<span id="cb4-19"><a href="#cb4-19" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-20"><a href="#cb4-20" aria-hidden="true" tabindex="-1"></a><span class="at">@memoize</span></span>
<span id="cb4-21"><a href="#cb4-21" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> fibonacci(n):</span>
<span id="cb4-22"><a href="#cb4-22" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> n <span class="op">&lt;</span> <span class="dv">2</span>:</span>
<span id="cb4-23"><a href="#cb4-23" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> n</span>
<span id="cb4-24"><a href="#cb4-24" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> fibonacci(n<span class="op">-</span><span class="dv">1</span>) <span class="op">+</span> fibonacci(n<span class="op">-</span><span class="dv">2</span>)</span>
<span id="cb4-25"><a href="#cb4-25" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-26"><a href="#cb4-26" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage</span></span>
<span id="cb4-27"><a href="#cb4-27" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(fibonacci(<span class="dv">10</span>))  <span class="co"># Calculates and caches intermediate results</span></span>
<span id="cb4-28"><a href="#cb4-28" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(fibonacci(<span class="dv">10</span>))  <span class="co"># Retrieved from cache</span></span></code></pre></div></div>
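<p>The standard library ships a production-ready version of this pattern. <code>functools.lru_cache</code> (and, since Python 3.9, <code>functools.cache</code>, which is equivalent to <code>lru_cache(maxsize=None)</code>) keys the cache on the call arguments, which must therefore be hashable, and adds hit/miss statistics and a way to clear the cache. A minimal equivalent of the example above:</p>

```python
import functools


@functools.lru_cache(maxsize=None)  # maxsize=None: the cache never evicts entries
def fibonacci(n):
    if n < 2:
        return n
    return fibonacci(n - 1) + fibonacci(n - 2)


print(fibonacci(10))           # 55
print(fibonacci.cache_info())  # hits, misses, maxsize and current size
fibonacci.cache_clear()        # empty the cache (and reset the statistics)
```

<p>A finite <code>maxsize</code> (the default is 128) bounds memory use by evicting the least recently used entries, something the hand-written <code>memoize</code> above never does.</p>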
</section>
<section id="logging-decorator" class="level3">
<h3 class="anchored" data-anchor-id="logging-decorator" id="logging-decorator">4. Logging Decorator</h3>
<p>Automatically logs function calls with their arguments and return values.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> functools</span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> logging</span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Configure logging</span></span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a>logging.basicConfig(level<span class="op">=</span>logging.INFO, <span class="bu">format</span><span class="op">=</span><span class="st">'</span><span class="sc">%(asctime)s</span><span class="st"> - </span><span class="sc">%(message)s</span><span class="st">'</span>)</span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> log_calls(func):</span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a>    <span class="at">@functools.wraps</span>(func)</span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> wrapper(<span class="op">*</span>args, <span class="op">**</span>kwargs):</span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a>        args_str <span class="op">=</span> <span class="st">', '</span>.join([<span class="bu">repr</span>(arg) <span class="cf">for</span> arg <span class="kw">in</span> args])</span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a>        kwargs_str <span class="op">=</span> <span class="st">', '</span>.join([<span class="ss">f"</span><span class="sc">{</span>k<span class="sc">}</span><span class="ss">=</span><span class="sc">{</span>v<span class="sc">!r}</span><span class="ss">"</span> <span class="cf">for</span> k, v <span class="kw">in</span> kwargs.items()])</span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a>        all_args <span class="op">=</span> <span class="st">', '</span>.join(<span class="bu">filter</span>(<span class="va">None</span>, [args_str, kwargs_str]))</span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a>        logging.info(<span class="ss">f"Calling </span><span class="sc">{</span>func<span class="sc">.</span><span class="va">__name__</span><span class="sc">}</span><span class="ss">(</span><span class="sc">{</span>all_args<span class="sc">}</span><span class="ss">)"</span>)</span>
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-16"><a href="#cb5-16" aria-hidden="true" tabindex="-1"></a>        <span class="cf">try</span>:</span>
<span id="cb5-17"><a href="#cb5-17" aria-hidden="true" tabindex="-1"></a>            result <span class="op">=</span> func(<span class="op">*</span>args, <span class="op">**</span>kwargs)</span>
<span id="cb5-18"><a href="#cb5-18" aria-hidden="true" tabindex="-1"></a>            logging.info(<span class="ss">f"</span><span class="sc">{</span>func<span class="sc">.</span><span class="va">__name__</span><span class="sc">}</span><span class="ss"> returned </span><span class="sc">{</span>result<span class="sc">!r}</span><span class="ss">"</span>)</span>
<span id="cb5-19"><a href="#cb5-19" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> result</span>
<span id="cb5-20"><a href="#cb5-20" aria-hidden="true" tabindex="-1"></a>        <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:</span>
<span id="cb5-21"><a href="#cb5-21" aria-hidden="true" tabindex="-1"></a>            logging.error(<span class="ss">f"</span><span class="sc">{</span>func<span class="sc">.</span><span class="va">__name__</span><span class="sc">}</span><span class="ss"> raised </span><span class="sc">{</span><span class="bu">type</span>(e)<span class="sc">.</span><span class="va">__name__</span><span class="sc">}</span><span class="ss">: </span><span class="sc">{</span>e<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb5-22"><a href="#cb5-22" aria-hidden="true" tabindex="-1"></a>            <span class="cf">raise</span></span>
<span id="cb5-23"><a href="#cb5-23" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> wrapper</span>
<span id="cb5-24"><a href="#cb5-24" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-25"><a href="#cb5-25" aria-hidden="true" tabindex="-1"></a><span class="at">@log_calls</span></span>
<span id="cb5-26"><a href="#cb5-26" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> divide(a, b):</span>
<span id="cb5-27"><a href="#cb5-27" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> a <span class="op">/</span> b</span>
<span id="cb5-28"><a href="#cb5-28" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-29"><a href="#cb5-29" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage</span></span>
<span id="cb5-30"><a href="#cb5-30" aria-hidden="true" tabindex="-1"></a>divide(<span class="dv">10</span>, <span class="dv">2</span>)    <span class="co"># Logs the call and result</span></span>
<span id="cb5-31"><a href="#cb5-31" aria-hidden="true" tabindex="-1"></a>divide(<span class="dv">10</span>, <span class="dv">0</span>)    <span class="co"># Logs the call and exception</span></span></code></pre></div></div>
</section>
<section id="rate-limiting-decorator" class="level3">
<h3 class="anchored" data-anchor-id="rate-limiting-decorator" id="rate-limiting-decorator">5. Rate Limiting Decorator</h3>
<p>Prevents a function from being called too frequently, useful for API rate limiting.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> functools</span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> collections <span class="im">import</span> defaultdict</span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> rate_limit(max_calls<span class="op">=</span><span class="dv">5</span>, window<span class="op">=</span><span class="dv">60</span>):</span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a>    call_times <span class="op">=</span> defaultdict(<span class="bu">list</span>)</span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> decorator(func):</span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a>        <span class="at">@functools.wraps</span>(func)</span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a>        <span class="kw">def</span> wrapper(<span class="op">*</span>args, <span class="op">**</span>kwargs):</span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a>            now <span class="op">=</span> time.monotonic()  <span class="co"># unaffected by system clock changes</span></span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a>            func_name <span class="op">=</span> func.<span class="va">__name__</span></span>
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Remove old calls outside the window</span></span>
<span id="cb6-15"><a href="#cb6-15" aria-hidden="true" tabindex="-1"></a>            call_times[func_name] <span class="op">=</span> [</span>
<span id="cb6-16"><a href="#cb6-16" aria-hidden="true" tabindex="-1"></a>                call_time <span class="cf">for</span> call_time <span class="kw">in</span> call_times[func_name]</span>
<span id="cb6-17"><a href="#cb6-17" aria-hidden="true" tabindex="-1"></a>                <span class="cf">if</span> now <span class="op">-</span> call_time <span class="op">&lt;</span> window</span>
<span id="cb6-18"><a href="#cb6-18" aria-hidden="true" tabindex="-1"></a>            ]</span>
<span id="cb6-19"><a href="#cb6-19" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb6-20"><a href="#cb6-20" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="bu">len</span>(call_times[func_name]) <span class="op">&gt;=</span> max_calls:</span>
<span id="cb6-21"><a href="#cb6-21" aria-hidden="true" tabindex="-1"></a>                <span class="cf">raise</span> <span class="pp">Exception</span>(<span class="ss">f"Rate limit exceeded for </span><span class="sc">{</span>func_name<span class="sc">}</span><span class="ss">. Max </span><span class="sc">{</span>max_calls<span class="sc">}</span><span class="ss"> calls per </span><span class="sc">{</span>window<span class="sc">}</span><span class="ss"> seconds."</span>)</span>
<span id="cb6-22"><a href="#cb6-22" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb6-23"><a href="#cb6-23" aria-hidden="true" tabindex="-1"></a>            call_times[func_name].append(now)</span>
<span id="cb6-24"><a href="#cb6-24" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> func(<span class="op">*</span>args, <span class="op">**</span>kwargs)</span>
<span id="cb6-25"><a href="#cb6-25" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> wrapper</span>
<span id="cb6-26"><a href="#cb6-26" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> decorator</span>
<span id="cb6-27"><a href="#cb6-27" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-28"><a href="#cb6-28" aria-hidden="true" tabindex="-1"></a><span class="at">@rate_limit</span>(max_calls<span class="op">=</span><span class="dv">3</span>, window<span class="op">=</span><span class="dv">10</span>)</span>
<span id="cb6-29"><a href="#cb6-29" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> api_call():</span>
<span id="cb6-30"><a href="#cb6-30" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> <span class="st">"API response"</span></span>
<span id="cb6-31"><a href="#cb6-31" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-32"><a href="#cb6-32" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage</span></span>
<span id="cb6-33"><a href="#cb6-33" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">5</span>):</span>
<span id="cb6-34"><a href="#cb6-34" aria-hidden="true" tabindex="-1"></a>    <span class="cf">try</span>:</span>
<span id="cb6-35"><a href="#cb6-35" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(api_call())</span>
<span id="cb6-36"><a href="#cb6-36" aria-hidden="true" tabindex="-1"></a>        time.sleep(<span class="dv">2</span>)</span>
<span id="cb6-37"><a href="#cb6-37" aria-hidden="true" tabindex="-1"></a>    <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:</span>
<span id="cb6-38"><a href="#cb6-38" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Error: </span><span class="sc">{</span>e<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
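<p>The list-based limiter above is fine for single-threaded scripts, but its shared list is not protected if the decorated function is called from several threads at once. A sketch of a variant, assuming thread safety matters for your use case, keeps timestamps in a <code>collections.deque</code> (oldest first, so expired entries can be popped from the left), guards them with a <code>threading.Lock</code>, and reads a monotonic clock. The <code>RateLimitExceeded</code> class is a hypothetical addition so callers can catch rate-limit errors specifically.</p>

```python
import collections
import functools
import threading
import time


class RateLimitExceeded(Exception):
    """Hypothetical exception raised when a call would exceed the allowed rate."""


def rate_limit(max_calls=5, window=60.0):
    def decorator(func):
        calls = collections.deque()  # timestamps of recent calls, oldest first
        lock = threading.Lock()

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            with lock:
                now = time.monotonic()
                # Drop timestamps that have slid out of the window
                while calls and now - calls[0] >= window:
                    calls.popleft()
                if len(calls) >= max_calls:
                    raise RateLimitExceeded(
                        f"{func.__name__}: max {max_calls} calls per {window} seconds"
                    )
                calls.append(now)
            # Call the function outside the lock so slow calls don't block other threads
            return func(*args, **kwargs)
        return wrapper
    return decorator


@rate_limit(max_calls=3, window=10)
def api_call():
    return "API response"


results = []
for _ in range(5):
    try:
        results.append(api_call())
    except RateLimitExceeded as e:
        results.append(f"Error: {e}")

print(results)  # first three succeed, the last two are rejected
```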
</section>
<section id="validation-decorator" class="level3">
<h3 class="anchored" data-anchor-id="validation-decorator" id="validation-decorator">6. Validation Decorator</h3>
<p>Validates function arguments before execution.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> functools</span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> inspect</span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> validate_types(<span class="op">**</span>expected_types):</span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> decorator(func):</span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a>        sig <span class="op">=</span> inspect.signature(func)  <span class="co"># inspect the signature once, at decoration time</span></span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a>        <span class="at">@functools.wraps</span>(func)</span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a>        <span class="kw">def</span> wrapper(<span class="op">*</span>args, <span class="op">**</span>kwargs):</span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Map positional and keyword arguments to parameter names</span></span>
<span id="cb7-10"><a href="#cb7-10" aria-hidden="true" tabindex="-1"></a>            bound_args <span class="op">=</span> sig.bind(<span class="op">*</span>args, <span class="op">**</span>kwargs)</span>
<span id="cb7-11"><a href="#cb7-11" aria-hidden="true" tabindex="-1"></a>            bound_args.apply_defaults()</span>
<span id="cb7-12"><a href="#cb7-12" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb7-13"><a href="#cb7-13" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Validate types</span></span>
<span id="cb7-14"><a href="#cb7-14" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> param_name, expected_type <span class="kw">in</span> expected_types.items():</span>
<span id="cb7-15"><a href="#cb7-15" aria-hidden="true" tabindex="-1"></a>                <span class="cf">if</span> param_name <span class="kw">in</span> bound_args.arguments:</span>
<span id="cb7-16"><a href="#cb7-16" aria-hidden="true" tabindex="-1"></a>                    value <span class="op">=</span> bound_args.arguments[param_name]</span>
<span id="cb7-17"><a href="#cb7-17" aria-hidden="true" tabindex="-1"></a>                    <span class="cf">if</span> <span class="kw">not</span> <span class="bu">isinstance</span>(value, expected_type):</span>
<span id="cb7-18"><a href="#cb7-18" aria-hidden="true" tabindex="-1"></a>                        <span class="cf">raise</span> <span class="pp">TypeError</span>(</span>
<span id="cb7-19"><a href="#cb7-19" aria-hidden="true" tabindex="-1"></a>                            <span class="ss">f"Parameter '</span><span class="sc">{</span>param_name<span class="sc">}</span><span class="ss">' must be of type </span><span class="sc">{</span>expected_type<span class="sc">.</span><span class="va">__name__</span><span class="sc">}</span><span class="ss">, "</span></span>
<span id="cb7-20"><a href="#cb7-20" aria-hidden="true" tabindex="-1"></a>                            <span class="ss">f"got </span><span class="sc">{</span><span class="bu">type</span>(value)<span class="sc">.</span><span class="va">__name__</span><span class="sc">}</span><span class="ss">"</span></span>
<span id="cb7-21"><a href="#cb7-21" aria-hidden="true" tabindex="-1"></a>                        )</span>
<span id="cb7-22"><a href="#cb7-22" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb7-23"><a href="#cb7-23" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> func(<span class="op">*</span>args, <span class="op">**</span>kwargs)</span>
<span id="cb7-24"><a href="#cb7-24" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> wrapper</span>
<span id="cb7-25"><a href="#cb7-25" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> decorator</span>
<span id="cb7-26"><a href="#cb7-26" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-27"><a href="#cb7-27" aria-hidden="true" tabindex="-1"></a><span class="at">@validate_types</span>(name<span class="op">=</span><span class="bu">str</span>, age<span class="op">=</span><span class="bu">int</span>, height<span class="op">=</span><span class="bu">float</span>)</span>
<span id="cb7-28"><a href="#cb7-28" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> create_person(name, age, height<span class="op">=</span><span class="fl">0.0</span>):</span>
<span id="cb7-29"><a href="#cb7-29" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> <span class="ss">f"Person: </span><span class="sc">{</span>name<span class="sc">}</span><span class="ss">, </span><span class="sc">{</span>age<span class="sc">}</span><span class="ss"> years old, </span><span class="sc">{</span>height<span class="sc">}</span><span class="ss">m tall"</span></span>
<span id="cb7-30"><a href="#cb7-30" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-31"><a href="#cb7-31" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage</span></span>
<span id="cb7-32"><a href="#cb7-32" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(create_person(<span class="st">"Alice"</span>, <span class="dv">30</span>, <span class="fl">1.75</span>))  <span class="co"># Works fine</span></span>
<span id="cb7-33"><a href="#cb7-33" aria-hidden="true" tabindex="-1"></a><span class="cf">try</span>:</span>
<span id="cb7-34"><a href="#cb7-34" aria-hidden="true" tabindex="-1"></a>    create_person(<span class="st">"Bob"</span>, <span class="st">"thirty"</span>, <span class="fl">1.80</span>)  <span class="co"># Raises TypeError</span></span>
<span id="cb7-35"><a href="#cb7-35" aria-hidden="true" tabindex="-1"></a><span class="cf">except</span> <span class="pp">TypeError</span> <span class="im">as</span> e:</span>
<span id="cb7-36"><a href="#cb7-36" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Validation error: </span><span class="sc">{</span>e<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
</section>
<section id="deprecated-decorator" class="level3">
<h3 class="anchored" data-anchor-id="deprecated-decorator" id="deprecated-decorator">7. Deprecated Decorator</h3>
<p>Warns users when they call deprecated functions.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb8"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> functools</span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> warnings</span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> deprecated(reason<span class="op">=</span><span class="st">"This function is deprecated"</span>):</span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> decorator(func):</span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a>        <span class="at">@functools.wraps</span>(func)</span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a>        <span class="kw">def</span> wrapper(<span class="op">*</span>args, <span class="op">**</span>kwargs):</span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a>            warnings.warn(</span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a>                <span class="ss">f"</span><span class="sc">{</span>func<span class="sc">.</span><span class="va">__name__</span><span class="sc">}</span><span class="ss"> is deprecated: </span><span class="sc">{</span>reason<span class="sc">}</span><span class="ss">"</span>,</span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a>                category<span class="op">=</span><span class="pp">DeprecationWarning</span>,</span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a>                stacklevel<span class="op">=</span><span class="dv">2</span></span>
<span id="cb8-12"><a href="#cb8-12" aria-hidden="true" tabindex="-1"></a>            )</span>
<span id="cb8-13"><a href="#cb8-13" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> func(<span class="op">*</span>args, <span class="op">**</span>kwargs)</span>
<span id="cb8-14"><a href="#cb8-14" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> wrapper</span>
<span id="cb8-15"><a href="#cb8-15" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> decorator</span>
<span id="cb8-16"><a href="#cb8-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-17"><a href="#cb8-17" aria-hidden="true" tabindex="-1"></a><span class="at">@deprecated</span>(<span class="st">"Use new_function() instead"</span>)</span>
<span id="cb8-18"><a href="#cb8-18" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> old_function():</span>
<span id="cb8-19"><a href="#cb8-19" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> <span class="st">"This is the old way"</span></span>
<span id="cb8-20"><a href="#cb8-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-21"><a href="#cb8-21" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> new_function():</span>
<span id="cb8-22"><a href="#cb8-22" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> <span class="st">"This is the new way"</span></span>
<span id="cb8-23"><a href="#cb8-23" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-24"><a href="#cb8-24" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage</span></span>
<span id="cb8-25"><a href="#cb8-25" aria-hidden="true" tabindex="-1"></a>result <span class="op">=</span> old_function()  <span class="co"># Emits a DeprecationWarning (shown by default only when called from __main__)</span></span></code></pre></div></div>
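<p>Python's default warning filters hide <code>DeprecationWarning</code> unless it is triggered by code running in <code>__main__</code>, so callers in other modules may never see it. Below is a minimal sketch that uses only the standard <code>warnings</code> module to show how a test can confirm the decorator fires. It repeats the decorator so the snippet runs on its own.</p>

```python
import functools
import warnings


def deprecated(reason="This function is deprecated"):
    """Same decorator as above: emit a DeprecationWarning on every call."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            warnings.warn(
                f"{func.__name__} is deprecated: {reason}",
                category=DeprecationWarning,
                stacklevel=2,  # attribute the warning to the caller's line
            )
            return func(*args, **kwargs)
        return wrapper
    return decorator


@deprecated("Use new_function() instead")
def old_function():
    return "This is the old way"


# Record warnings instead of printing them, and disable filtering so none are dropped
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = old_function()

print(result)                       # This is the old way
print(caught[0].category.__name__)  # DeprecationWarning
print(caught[0].message)            # old_function is deprecated: Use new_function() instead
```

<p>The <code>catch_warnings</code> context manager restores the previous filter state on exit, so the <code>"always"</code> filter does not leak into the rest of the program.</p>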
</section>
</section>
<section id="advanced-decorator-concepts" class="level2">
<h2 class="anchored" data-anchor-id="advanced-decorator-concepts" id="advanced-decorator-concepts">Advanced Decorator Concepts</h2>
<section id="class-based-decorators" class="level3">
<h3 class="anchored" data-anchor-id="class-based-decorators" id="class-based-decorators">Class-Based Decorators</h3>
<p>You can also create decorators using classes by implementing the <code>__call__</code> method:</p>
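<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb9"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-0"><a href="#cb9-0" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> functools</span>
<span id="cb9-0a"><a href="#cb9-0a" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> CountCalls:</span>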
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, func):</span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.func <span class="op">=</span> func</span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.count <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a>        functools.update_wrapper(<span class="va">self</span>, func)</span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__call__</span>(<span class="va">self</span>, <span class="op">*</span>args, <span class="op">**</span>kwargs):</span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.count <span class="op">+=</span> <span class="dv">1</span></span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"</span><span class="sc">{</span><span class="va">self</span><span class="sc">.</span>func<span class="sc">.</span><span class="va">__name__</span><span class="sc">}</span><span class="ss"> has been called </span><span class="sc">{</span><span class="va">self</span><span class="sc">.</span>count<span class="sc">}</span><span class="ss"> times"</span>)</span>
<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.func(<span class="op">*</span>args, <span class="op">**</span>kwargs)</span>
<span id="cb9-11"><a href="#cb9-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-12"><a href="#cb9-12" aria-hidden="true" tabindex="-1"></a><span class="at">@CountCalls</span></span>
<span id="cb9-13"><a href="#cb9-13" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> say_hello():</span>
<span id="cb9-14"><a href="#cb9-14" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="st">"Hello!"</span>)</span>
<span id="cb9-15"><a href="#cb9-15" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-16"><a href="#cb9-16" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage</span></span>
<span id="cb9-17"><a href="#cb9-17" aria-hidden="true" tabindex="-1"></a>say_hello()  <span class="co"># say_hello has been called 1 times</span></span>
<span id="cb9-18"><a href="#cb9-18" aria-hidden="true" tabindex="-1"></a>say_hello()  <span class="co"># say_hello has been called 2 times</span></span></code></pre></div></div>
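<p>Because the decorator is an object, its state lives on the instance and can be read directly, for example <code>say_hello.count</code>. One caveat is that a plain class instance is not a descriptor, so applying <code>CountCalls</code> to a method would not bind <code>self</code>. Below is a sketch of a variant that adds <code>__get__</code> so it also works on methods. It returns results instead of printing, and the <code>Greeter</code> class is purely illustrative.</p>

```python
import functools
import types


class CountCalls:
    """Count how many times the wrapped function (or method) is called."""

    def __init__(self, func):
        functools.update_wrapper(self, func)  # copy __name__, __doc__, etc. onto the instance
        self.func = func
        self.count = 0

    def __call__(self, *args, **kwargs):
        self.count += 1
        return self.func(*args, **kwargs)

    def __get__(self, instance, owner):
        # Accessed on the class: return the decorator itself (exposes .count)
        if instance is None:
            return self
        # Accessed on an instance: bind it, so the instance arrives as the first argument
        return types.MethodType(self, instance)


@CountCalls
def say_hello():
    return "Hello!"


class Greeter:
    @CountCalls
    def greet(self, name):
        return f"Hi, {name}"


say_hello()
say_hello()
print(say_hello.count)      # 2
print(say_hello.__name__)   # say_hello

g = Greeter()
print(g.greet("Ada"))       # Hi, Ada
print(Greeter.greet.count)  # 1
```

<p>Note that the counter is shared by every instance of <code>Greeter</code>, because there is a single decorator object per decorated method.</p>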
</section>
<section id="stacking-decorators" class="level3">
<h3 class="anchored" data-anchor-id="stacking-decorators" id="stacking-decorators">Stacking Decorators</h3>
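<p>Multiple decorators can be applied to a single function. This example assumes the <code>timer</code>, <code>log_calls</code>, and <code>retry</code> decorators from the earlier sections are in scope:</p>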
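<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb10"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-0"><a href="#cb10-0" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> random</span>
<span id="cb10-0a"><a href="#cb10-0a" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="at">@timer</span></span>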
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a><span class="at">@log_calls</span></span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a><span class="at">@retry</span>(max_attempts<span class="op">=</span><span class="dv">2</span>)</span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> complex_function(x, y):</span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> random.random() <span class="op">&lt;</span> <span class="fl">0.5</span>:</span>
<span id="cb10-6"><a href="#cb10-6" aria-hidden="true" tabindex="-1"></a>        <span class="cf">raise</span> <span class="pp">Exception</span>(<span class="st">"Random failure"</span>)</span>
<span id="cb10-7"><a href="#cb10-7" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> x <span class="op">+</span> y</span>
<span id="cb10-8"><a href="#cb10-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-9"><a href="#cb10-9" aria-hidden="true" tabindex="-1"></a><span class="co"># The decorators are applied from bottom to top, so the stack is equivalent to:</span></span>
<span id="cb10-10"><a href="#cb10-10" aria-hidden="true" tabindex="-1"></a><span class="co"># complex_function = timer(log_calls(retry(max_attempts=2)(complex_function)))</span></span></code></pre></div></div>
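<p>Application order and execution order are mirror images. The bottom decorator wraps the function first, but at call time the top wrapper runs first and finishes last. In the example above, that means <code>timer</code> measures the whole call, including every attempt <code>retry</code> makes. Below is a self-contained sketch that records the order using a hypothetical <code>tag</code> decorator factory.</p>

```python
import functools

trace = []  # records the order in which wrapper code runs


def tag(label):
    """Hypothetical decorator factory that logs entry and exit around the call."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            trace.append(f"enter {label}")
            result = func(*args, **kwargs)
            trace.append(f"exit {label}")
            return result
        return wrapper
    return decorator


@tag("outer")  # applied second, runs first
@tag("inner")  # applied first, runs closest to the function body
def add(x, y):
    trace.append("body")
    return x + y


print(add(2, 3))  # 5
print(trace)
# ['enter outer', 'enter inner', 'body', 'exit inner', 'exit outer']
```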
</section>
</section>
<section id="best-practices" class="level2">
<h2 class="anchored" data-anchor-id="best-practices" id="best-practices">Best Practices</h2>
<ol type="1">
<li><p><strong>Always use <code>functools.wraps</code></strong>: This preserves the original function’s metadata (name, docstring, etc.).</p></li>
<li><p><strong>Handle arguments properly</strong>: Use <code>*args</code> and <code>**kwargs</code> to ensure your decorator works with any function signature.</p></li>
<li><p><strong>Consider performance</strong>: Every decorated call pays for at least one extra function call plus whatever work the wrapper does, so be mindful of that overhead in tight loops and other performance-critical code.</p></li>
<li><p><strong>Make decorators configurable</strong>: Use decorator factories (functions that take configuration arguments and return a decorator) to make them more flexible.</p></li>
<li><p><strong>Document your decorators</strong>: Clear documentation helps other developers understand what your decorators do and how to use them.</p></li>
</ol>
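<p>To see what <code>functools.wraps</code> buys you, compare two otherwise identical decorators. This is a minimal illustration, and the function names are arbitrary.</p>

```python
import functools
import inspect


def without_wraps(func):
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper


def with_wraps(func):
    @functools.wraps(func)  # copies __name__, __doc__, __module__, ... and sets __wrapped__
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper


@without_wraps
def greet_a(name):
    """Return a greeting."""
    return f"Hello, {name}"


@with_wraps
def greet_b(name):
    """Return a greeting."""
    return f"Hello, {name}"


print(greet_a.__name__, greet_a.__doc__, inspect.signature(greet_a))
# wrapper None (*args, **kwargs)
print(greet_b.__name__, greet_b.__doc__, inspect.signature(greet_b))
# greet_b Return a greeting. (name)
```

<p><code>inspect.signature</code> follows the <code>__wrapped__</code> attribute that <code>functools.wraps</code> sets. As a result, debuggers, documentation tools, and IDEs report the original signature rather than <code>(*args, **kwargs)</code>.</p>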
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>Decorators are a powerful tool for writing clean, maintainable code. They allow you to separate concerns, reduce code duplication, and add functionality to existing functions without modifying their core logic. The decorators presented in this article provide a solid foundation for common programming tasks like logging, caching, validation, and error handling.</p>
<p>Start by incorporating simple decorators like the timer and logging decorators into your projects, then gradually explore more advanced patterns as your needs grow. Remember that the key to effective decorator use is keeping them focused on a single responsibility and making them as reusable as possible.</p>


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[DeepSpeed with PyTorch: Complete Code Guide]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/deployment/deepspeed/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/deployment/deepspeed/</guid>
      <pubDate>Sat, 21 Jun 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>tutorial</category>
      <category>intermediate</category>
      <content:encoded><![CDATA[






<section id="deepspeed-with-pytorch-complete-code-guide" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/deployment/deepspeed/DeepSpeed.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
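<p>DeepSpeed is an open-source deep learning optimization library from Microsoft that makes distributed PyTorch training faster and more memory-efficient. Its best-known component, ZeRO (Zero Redundancy Optimizer), removes the redundant copies of model state that standard data parallelism keeps on every GPU. This is what makes it possible to train models with hundreds of billions or even trillions of parameters.</p>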
<p>Key benefits:</p>
<ul>
<li><strong>Memory Efficiency</strong>: ZeRO reduces memory consumption by partitioning optimizer states, gradients, and model parameters</li>
<li><strong>Speed</strong>: Achieves high training throughput through optimized kernels and communication</li>
<li><strong>Scale</strong>: Enables training of models with billions or even trillions of parameters</li>
<li><strong>Ease of Use</strong>: Simple integration with existing PyTorch code</li>
</ul>
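<p>The memory savings from partitioning can be estimated with simple arithmetic. Below is a sketch of the model-state accounting used in the ZeRO paper. It assumes mixed-precision Adam: 2 bytes per parameter for fp16 weights, 2 for fp16 gradients, and 12 for the fp32 optimizer state. It ignores activations, temporary buffers, and fragmentation, so real usage will be higher. Stages 1, 2, and 3 successively partition optimizer states, gradients, and parameters across the data-parallel GPUs.</p>

```python
def zero_model_state_gb(num_params: float, num_gpus: int, stage: int, k: int = 12) -> float:
    """Approximate per-GPU memory (GB) for model states under ZeRO stages 0-3.

    Assumed mixed-precision Adam accounting:
      - fp16 parameters: 2 bytes per parameter
      - fp16 gradients:  2 bytes per parameter
      - optimizer state: k bytes per parameter (fp32 params + momentum + variance, k = 12)
    Activations, temporary buffers, and fragmentation are ignored.
    """
    params = 2 * num_params
    grads = 2 * num_params
    optim = k * num_params
    if stage >= 1:
        optim /= num_gpus   # Stage 1: partition optimizer states across GPUs
    if stage >= 2:
        grads /= num_gpus   # Stage 2: also partition gradients
    if stage >= 3:
        params /= num_gpus  # Stage 3: also partition the parameters themselves
    return (params + grads + optim) / 1e9


# Example: a 7.5B-parameter model trained data-parallel on 64 GPUs
for stage in range(4):
    print(f"ZeRO stage {stage}: {zero_model_state_gb(7.5e9, 64, stage):.1f} GB per GPU")
# ZeRO stage 0: 120.0 GB per GPU
# ZeRO stage 1: 31.4 GB per GPU
# ZeRO stage 2: 16.6 GB per GPU
# ZeRO stage 3: 1.9 GB per GPU
```

<p>The optimizer state dominates (12 of the 16 bytes per parameter), which is why partitioning it alone in stage 1 already cuts model-state memory by roughly 4x in this example.</p>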
</section>
<section id="installation" class="level2">
<h2 class="anchored" data-anchor-id="installation" id="installation">Installation</h2>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="co"># Install DeepSpeed</span></span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install deepspeed</span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Or install from source for latest features</span></span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a><span class="fu">git</span> clone https://github.com/microsoft/DeepSpeed.git</span>
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a><span class="bu">cd</span> DeepSpeed</span>
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install .</span>
<span id="cb1-9"><a href="#cb1-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-10"><a href="#cb1-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Verify installation</span></span>
<span id="cb1-11"><a href="#cb1-11" aria-hidden="true" tabindex="-1"></a><span class="ex">ds_report</span></span></code></pre></div></div>
</section>
<section id="basic-setup" class="level2">
<h2 class="anchored" data-anchor-id="basic-setup" id="basic-setup">Basic Setup</h2>
<section id="simple-model-training" class="level3">
<h3 class="anchored" data-anchor-id="simple-model-training" id="simple-model-training">Simple Model Training</h3>
<div id="ded34062" class="cell" data-execution_count="1">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> deepspeed</span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.utils.data <span class="im">import</span> DataLoader, Dataset</span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Define a simple model</span></span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> SimpleModel(nn.Module):</span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, input_size<span class="op">=</span><span class="dv">1000</span>, hidden_size<span class="op">=</span><span class="dv">2000</span>, output_size<span class="op">=</span><span class="dv">1000</span>):</span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb2-10"><a href="#cb2-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.layers <span class="op">=</span> nn.Sequential(</span>
<span id="cb2-11"><a href="#cb2-11" aria-hidden="true" tabindex="-1"></a>            nn.Linear(input_size, hidden_size),</span>
<span id="cb2-12"><a href="#cb2-12" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(),</span>
<span id="cb2-13"><a href="#cb2-13" aria-hidden="true" tabindex="-1"></a>            nn.Linear(hidden_size, hidden_size),</span>
<span id="cb2-14"><a href="#cb2-14" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(),</span>
<span id="cb2-15"><a href="#cb2-15" aria-hidden="true" tabindex="-1"></a>            nn.Linear(hidden_size, output_size)</span>
<span id="cb2-16"><a href="#cb2-16" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb2-17"><a href="#cb2-17" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-18"><a href="#cb2-18" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb2-19"><a href="#cb2-19" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.layers(x)</span>
<span id="cb2-20"><a href="#cb2-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-21"><a href="#cb2-21" aria-hidden="true" tabindex="-1"></a><span class="co"># Dummy dataset</span></span>
<span id="cb2-22"><a href="#cb2-22" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> DummyDataset(Dataset):</span>
<span id="cb2-23"><a href="#cb2-23" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, size<span class="op">=</span><span class="dv">1000</span>):</span>
<span id="cb2-24"><a href="#cb2-24" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.size <span class="op">=</span> size</span>
<span id="cb2-25"><a href="#cb2-25" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-26"><a href="#cb2-26" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__len__</span>(<span class="va">self</span>):</span>
<span id="cb2-27"><a href="#cb2-27" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.size</span>
<span id="cb2-28"><a href="#cb2-28" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-29"><a href="#cb2-29" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__getitem__</span>(<span class="va">self</span>, idx):</span>
<span id="cb2-30"><a href="#cb2-30" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> torch.randn(<span class="dv">1000</span>), torch.randn(<span class="dv">1000</span>)</span>
<span id="cb2-31"><a href="#cb2-31" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-32"><a href="#cb2-32" aria-hidden="true" tabindex="-1"></a><span class="co"># Initialize model and data</span></span>
<span id="cb2-33"><a href="#cb2-33" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> SimpleModel()</span>
<span id="cb2-34"><a href="#cb2-34" aria-hidden="true" tabindex="-1"></a>dataset <span class="op">=</span> DummyDataset()</span>
<span id="cb2-35"><a href="#cb2-35" aria-hidden="true" tabindex="-1"></a>dataloader <span class="op">=</span> DataLoader(dataset, batch_size<span class="op">=</span><span class="dv">32</span>, shuffle<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb2-36"><a href="#cb2-36" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-37"><a href="#cb2-37" aria-hidden="true" tabindex="-1"></a><span class="co"># Initialize DeepSpeed</span></span>
<span id="cb2-38"><a href="#cb2-38" aria-hidden="true" tabindex="-1"></a>model_engine, optimizer, _, _ <span class="op">=</span> deepspeed.initialize(</span>
<span id="cb2-39"><a href="#cb2-39" aria-hidden="true" tabindex="-1"></a>    model<span class="op">=</span>model,</span>
<span id="cb2-40"><a href="#cb2-40" aria-hidden="true" tabindex="-1"></a>    model_parameters<span class="op">=</span>model.parameters(),</span>
<span id="cb2-41"><a href="#cb2-41" aria-hidden="true" tabindex="-1"></a>    config<span class="op">=</span>{</span>
<span id="cb2-42"><a href="#cb2-42" aria-hidden="true" tabindex="-1"></a>        <span class="st">"train_batch_size"</span>: <span class="dv">32</span>,</span>
<span id="cb2-43"><a href="#cb2-43" aria-hidden="true" tabindex="-1"></a>        <span class="st">"optimizer"</span>: {</span>
<span id="cb2-44"><a href="#cb2-44" aria-hidden="true" tabindex="-1"></a>            <span class="st">"type"</span>: <span class="st">"Adam"</span>,</span>
<span id="cb2-45"><a href="#cb2-45" aria-hidden="true" tabindex="-1"></a>            <span class="st">"params"</span>: {<span class="st">"lr"</span>: <span class="fl">0.001</span>}</span>
<span id="cb2-46"><a href="#cb2-46" aria-hidden="true" tabindex="-1"></a>        },</span>
<span id="cb2-47"><a href="#cb2-47" aria-hidden="true" tabindex="-1"></a>        <span class="st">"fp16"</span>: {<span class="st">"enabled"</span>: <span class="va">True</span>}</span>
<span id="cb2-48"><a href="#cb2-48" aria-hidden="true" tabindex="-1"></a>    }</span>
<span id="cb2-49"><a href="#cb2-49" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb2-50"><a href="#cb2-50" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-51"><a href="#cb2-51" aria-hidden="true" tabindex="-1"></a><span class="co"># Training loop</span></span>
<span id="cb2-52"><a href="#cb2-52" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">10</span>):</span>
<span id="cb2-53"><a href="#cb2-53" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> batch_idx, (data, target) <span class="kw">in</span> <span class="bu">enumerate</span>(dataloader):</span>
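<span id="cb2-54"><a href="#cb2-54" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Move the batch to this rank's GPU; fp16 is enabled, so inputs must be half precision</span></span>
<span id="cb2-54a"><a href="#cb2-54a" aria-hidden="true" tabindex="-1"></a>        data <span class="op">=</span> data.to(model_engine.local_rank).half()</span>
<span id="cb2-54b"><a href="#cb2-54b" aria-hidden="true" tabindex="-1"></a>        target <span class="op">=</span> target.to(model_engine.local_rank).half()</span>
<span id="cb2-54c"><a href="#cb2-54c" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-54d"><a href="#cb2-54d" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Forward pass</span></span>
<span id="cb2-55"><a href="#cb2-55" aria-hidden="true" tabindex="-1"></a>        outputs <span class="op">=</span> model_engine(data)</span>
<span id="cb2-56"><a href="#cb2-56" aria-hidden="true" tabindex="-1"></a>        loss <span class="op">=</span> nn.MSELoss()(outputs, target)</span>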
<span id="cb2-57"><a href="#cb2-57" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-58"><a href="#cb2-58" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Backward pass</span></span>
<span id="cb2-59"><a href="#cb2-59" aria-hidden="true" tabindex="-1"></a>        model_engine.backward(loss)</span>
<span id="cb2-60"><a href="#cb2-60" aria-hidden="true" tabindex="-1"></a>        model_engine.step()</span>
<span id="cb2-61"><a href="#cb2-61" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-62"><a href="#cb2-62" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> batch_idx <span class="op">%</span> <span class="dv">10</span> <span class="op">==</span> <span class="dv">0</span>:</span>
<span id="cb2-63"><a href="#cb2-63" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f'Epoch: </span><span class="sc">{</span>epoch<span class="sc">}</span><span class="ss">, Batch: </span><span class="sc">{</span>batch_idx<span class="sc">}</span><span class="ss">, Loss: </span><span class="sc">{</span>loss<span class="sc">.</span>item()<span class="sc">:.4f}</span><span class="ss">'</span>)</span></code></pre></div></div>
</div>
</section>
<section id="configuration-file-approach" class="level3">
<h3 class="anchored" data-anchor-id="configuration-file-approach" id="configuration-file-approach">Configuration File Approach</h3>
<div id="d608dd81" class="cell" data-execution_count="2">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> deepspeed</span>
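<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> argparse</span>
<span id="cb3-2a"><a href="#cb3-2a" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb3-2b"><a href="#cb3-2b" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>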
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> main():</span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a>    parser <span class="op">=</span> argparse.ArgumentParser()</span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a>    parser.add_argument(<span class="st">'--local_rank'</span>, <span class="bu">type</span><span class="op">=</span><span class="bu">int</span>, default<span class="op">=-</span><span class="dv">1</span>,</span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a>                       <span class="bu">help</span><span class="op">=</span><span class="st">'local rank passed from distributed launcher'</span>)</span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a>    parser.add_argument(<span class="st">'--deepspeed_config'</span>, <span class="bu">type</span><span class="op">=</span><span class="bu">str</span>, default<span class="op">=</span><span class="st">'ds_config.json'</span>,</span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a>                       <span class="bu">help</span><span class="op">=</span><span class="st">'deepspeed config file'</span>)</span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a>    args <span class="op">=</span> parser.parse_args()</span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Initialize distributed training</span></span>
<span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a>    deepspeed.init_distributed()</span>
<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a>    </span>
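<span id="cb3-15"><a href="#cb3-15" aria-hidden="true" tabindex="-1"></a>    <span class="co"># SimpleModel and DummyDataset are the classes from the previous example</span></span>
<span id="cb3-15a"><a href="#cb3-15a" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> SimpleModel()</span>
<span id="cb3-15b"><a href="#cb3-15b" aria-hidden="true" tabindex="-1"></a>    dataset <span class="op">=</span> DummyDataset()</span>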
<span id="cb3-16"><a href="#cb3-16" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-17"><a href="#cb3-17" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Initialize with config file</span></span>
<span id="cb3-18"><a href="#cb3-18" aria-hidden="true" tabindex="-1"></a>    model_engine, optimizer, trainloader, _ <span class="op">=</span> deepspeed.initialize(</span>
<span id="cb3-19"><a href="#cb3-19" aria-hidden="true" tabindex="-1"></a>        args<span class="op">=</span>args,</span>
<span id="cb3-20"><a href="#cb3-20" aria-hidden="true" tabindex="-1"></a>        model<span class="op">=</span>model,</span>
<span id="cb3-21"><a href="#cb3-21" aria-hidden="true" tabindex="-1"></a>        model_parameters<span class="op">=</span>model.parameters(),</span>
<span id="cb3-22"><a href="#cb3-22" aria-hidden="true" tabindex="-1"></a>        training_data<span class="op">=</span>dataset</span>
<span id="cb3-23"><a href="#cb3-23" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb3-24"><a href="#cb3-24" aria-hidden="true" tabindex="-1"></a>    </span>
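<span id="cb3-25"><a href="#cb3-25" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Inputs must match the engine's precision (fp16/bf16, as set in ds_config.json)</span></span>
<span id="cb3-25a"><a href="#cb3-25a" aria-hidden="true" tabindex="-1"></a>    target_dtype <span class="op">=</span> <span class="va">None</span></span>
<span id="cb3-25b"><a href="#cb3-25b" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> model_engine.bfloat16_enabled():</span>
<span id="cb3-25c"><a href="#cb3-25c" aria-hidden="true" tabindex="-1"></a>        target_dtype <span class="op">=</span> torch.bfloat16</span>
<span id="cb3-25d"><a href="#cb3-25d" aria-hidden="true" tabindex="-1"></a>    <span class="cf">elif</span> model_engine.fp16_enabled():</span>
<span id="cb3-25e"><a href="#cb3-25e" aria-hidden="true" tabindex="-1"></a>        target_dtype <span class="op">=</span> torch.half</span>
<span id="cb3-25f"><a href="#cb3-25f" aria-hidden="true" tabindex="-1"></a>    loss_fn <span class="op">=</span> nn.MSELoss()</span>
<span id="cb3-25g"><a href="#cb3-25g" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-25h"><a href="#cb3-25h" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Training loop: the DeepSpeed dataloader yields (data, target) micro-batches</span></span>
<span id="cb3-26"><a href="#cb3-26" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> step, (data, target) <span class="kw">in</span> <span class="bu">enumerate</span>(trainloader):</span>
<span id="cb3-27"><a href="#cb3-27" aria-hidden="true" tabindex="-1"></a>        data <span class="op">=</span> data.to(model_engine.local_rank)</span>
<span id="cb3-27a"><a href="#cb3-27a" aria-hidden="true" tabindex="-1"></a>        target <span class="op">=</span> target.to(model_engine.local_rank)</span>
<span id="cb3-27b"><a href="#cb3-27b" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> target_dtype <span class="kw">is</span> <span class="kw">not</span> <span class="va">None</span>:</span>
<span id="cb3-27c"><a href="#cb3-27c" aria-hidden="true" tabindex="-1"></a>            data, target <span class="op">=</span> data.to(target_dtype), target.to(target_dtype)</span>
<span id="cb3-27d"><a href="#cb3-27d" aria-hidden="true" tabindex="-1"></a>        outputs <span class="op">=</span> model_engine(data)</span>
<span id="cb3-27e"><a href="#cb3-27e" aria-hidden="true" tabindex="-1"></a>        loss <span class="op">=</span> loss_fn(outputs, target)</span>
<span id="cb3-28"><a href="#cb3-28" aria-hidden="true" tabindex="-1"></a>        model_engine.backward(loss)</span>
<span id="cb3-29"><a href="#cb3-29" aria-hidden="true" tabindex="-1"></a>        model_engine.step()</span>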
<span id="cb3-30"><a href="#cb3-30" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-31"><a href="#cb3-31" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">'__main__'</span>:</span>
<span id="cb3-32"><a href="#cb3-32" aria-hidden="true" tabindex="-1"></a>    main()</span></code></pre></div></div>
</div>
</section>
</section>
<section id="configuration-files" class="level2">
<h2 class="anchored" data-anchor-id="configuration-files" id="configuration-files">Configuration Files</h2>
<section id="basic-configuration" class="level3">
<h3 class="anchored" data-anchor-id="basic-configuration" id="basic-configuration">Basic Configuration</h3>
<p>Create a file called <code>ds_config.json</code>:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb4"><pre class="sourceCode json code-with-copy"><code class="sourceCode json"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="fu">{</span></span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a>  <span class="dt">"train_batch_size"</span><span class="fu">:</span> <span class="dv">64</span><span class="fu">,</span></span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a>  <span class="dt">"train_micro_batch_size_per_gpu"</span><span class="fu">:</span> <span class="dv">16</span><span class="fu">,</span></span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a>  <span class="dt">"gradient_accumulation_steps"</span><span class="fu">:</span> <span class="dv">1</span><span class="fu">,</span></span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a>  <span class="dt">"optimizer"</span><span class="fu">:</span> <span class="fu">{</span></span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"type"</span><span class="fu">:</span> <span class="st">"Adam"</span><span class="fu">,</span></span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"params"</span><span class="fu">:</span> <span class="fu">{</span></span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a>      <span class="dt">"lr"</span><span class="fu">:</span> <span class="dv">3e-5</span><span class="fu">,</span></span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a>      <span class="dt">"betas"</span><span class="fu">:</span> <span class="ot">[</span><span class="fl">0.8</span><span class="ot">,</span> <span class="fl">0.999</span><span class="ot">]</span><span class="fu">,</span></span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a>      <span class="dt">"eps"</span><span class="fu">:</span> <span class="dv">1e-8</span><span class="fu">,</span></span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a>      <span class="dt">"weight_decay"</span><span class="fu">:</span> <span class="dv">3e-7</span></span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a>    <span class="fu">}</span></span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a>  <span class="fu">},</span></span>
<span id="cb4-14"><a href="#cb4-14" aria-hidden="true" tabindex="-1"></a>  <span class="dt">"scheduler"</span><span class="fu">:</span> <span class="fu">{</span></span>
<span id="cb4-15"><a href="#cb4-15" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"type"</span><span class="fu">:</span> <span class="st">"WarmupLR"</span><span class="fu">,</span></span>
<span id="cb4-16"><a href="#cb4-16" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"params"</span><span class="fu">:</span> <span class="fu">{</span></span>
<span id="cb4-17"><a href="#cb4-17" aria-hidden="true" tabindex="-1"></a>      <span class="dt">"warmup_min_lr"</span><span class="fu">:</span> <span class="dv">0</span><span class="fu">,</span></span>
<span id="cb4-18"><a href="#cb4-18" aria-hidden="true" tabindex="-1"></a>      <span class="dt">"warmup_max_lr"</span><span class="fu">:</span> <span class="dv">3e-5</span><span class="fu">,</span></span>
<span id="cb4-19"><a href="#cb4-19" aria-hidden="true" tabindex="-1"></a>      <span class="dt">"warmup_num_steps"</span><span class="fu">:</span> <span class="dv">1000</span></span>
<span id="cb4-20"><a href="#cb4-20" aria-hidden="true" tabindex="-1"></a>    <span class="fu">}</span></span>
<span id="cb4-21"><a href="#cb4-21" aria-hidden="true" tabindex="-1"></a>  <span class="fu">},</span></span>
<span id="cb4-22"><a href="#cb4-22" aria-hidden="true" tabindex="-1"></a>  <span class="dt">"fp16"</span><span class="fu">:</span> <span class="fu">{</span></span>
<span id="cb4-23"><a href="#cb4-23" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"enabled"</span><span class="fu">:</span> <span class="kw">true</span><span class="fu">,</span></span>
<span id="cb4-24"><a href="#cb4-24" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"auto_cast"</span><span class="fu">:</span> <span class="kw">false</span><span class="fu">,</span></span>
<span id="cb4-25"><a href="#cb4-25" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"loss_scale"</span><span class="fu">:</span> <span class="dv">0</span><span class="fu">,</span></span>
<span id="cb4-26"><a href="#cb4-26" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"initial_scale_power"</span><span class="fu">:</span> <span class="dv">16</span><span class="fu">,</span></span>
<span id="cb4-27"><a href="#cb4-27" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"loss_scale_window"</span><span class="fu">:</span> <span class="dv">1000</span><span class="fu">,</span></span>
<span id="cb4-28"><a href="#cb4-28" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"hysteresis"</span><span class="fu">:</span> <span class="dv">2</span><span class="fu">,</span></span>
<span id="cb4-29"><a href="#cb4-29" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"min_loss_scale"</span><span class="fu">:</span> <span class="dv">1</span></span>
<span id="cb4-30"><a href="#cb4-30" aria-hidden="true" tabindex="-1"></a>  <span class="fu">},</span></span>
<span id="cb4-31"><a href="#cb4-31" aria-hidden="true" tabindex="-1"></a>  <span class="dt">"zero_optimization"</span><span class="fu">:</span> <span class="fu">{</span></span>
<span id="cb4-32"><a href="#cb4-32" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"stage"</span><span class="fu">:</span> <span class="dv">2</span><span class="fu">,</span></span>
<span id="cb4-33"><a href="#cb4-33" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"allgather_partitions"</span><span class="fu">:</span> <span class="kw">true</span><span class="fu">,</span></span>
<span id="cb4-34"><a href="#cb4-34" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"allgather_bucket_size"</span><span class="fu">:</span> <span class="dv">2e8</span><span class="fu">,</span></span>
<span id="cb4-35"><a href="#cb4-35" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"overlap_comm"</span><span class="fu">:</span> <span class="kw">true</span><span class="fu">,</span></span>
<span id="cb4-36"><a href="#cb4-36" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"reduce_scatter"</span><span class="fu">:</span> <span class="kw">true</span><span class="fu">,</span></span>
<span id="cb4-37"><a href="#cb4-37" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"reduce_bucket_size"</span><span class="fu">:</span> <span class="dv">2e8</span><span class="fu">,</span></span>
<span id="cb4-38"><a href="#cb4-38" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"contiguous_gradients"</span><span class="fu">:</span> <span class="kw">true</span></span>
<span id="cb4-39"><a href="#cb4-39" aria-hidden="true" tabindex="-1"></a>  <span class="fu">},</span></span>
<span id="cb4-40"><a href="#cb4-40" aria-hidden="true" tabindex="-1"></a>  <span class="dt">"gradient_clipping"</span><span class="fu">:</span> <span class="fl">1.0</span><span class="fu">,</span></span>
<span id="cb4-41"><a href="#cb4-41" aria-hidden="true" tabindex="-1"></a>  <span class="dt">"wall_clock_breakdown"</span><span class="fu">:</span> <span class="kw">false</span></span>
<span id="cb4-42"><a href="#cb4-42" aria-hidden="true" tabindex="-1"></a><span class="fu">}</span></span></code></pre></div></div>
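<p>DeepSpeed requires <code>train_batch_size</code> to equal <code>train_micro_batch_size_per_gpu</code> × <code>gradient_accumulation_steps</code> × the number of GPUs. The values above (64 = 16 × 1 × 4) therefore assume four GPUs.</p>
<p>The helper below is a minimal sketch of that arithmetic. It reports how many GPUs a config expects before you launch. The name <code>implied_world_size</code> is our own and is not part of DeepSpeed.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">def implied_world_size(config):
    """Return the number of GPUs implied by a DeepSpeed batch configuration.

    Assumes all three batch keys are present. DeepSpeed can infer one missing
    key from the other two, which this sketch does not attempt.
    """
    total = config["train_batch_size"]
    micro = config["train_micro_batch_size_per_gpu"]
    accum = config["gradient_accumulation_steps"]
    per_gpu = micro * accum  # samples each GPU contributes per optimizer step
    if total % per_gpu != 0:
        raise ValueError(
            f"train_batch_size={total} is not a multiple of "
            f"micro_batch * grad_accum = {per_gpu}"
        )
    return total // per_gpu


# Batch settings copied from ds_config.json above
basic = {
    "train_batch_size": 64,
    "train_micro_batch_size_per_gpu": 16,
    "gradient_accumulation_steps": 1,
}
print(implied_world_size(basic))  # 4

# Doubling the accumulation steps halves the number of GPUs needed
print(implied_world_size(dict(basic, gradient_accumulation_steps=2)))  # 2
</code></pre></div>
<p>In practice you would build the dictionary with <code>json.load</code> from <code>ds_config.json</code>. If you change the GPU count, adjust one of the three values so that the product still matches.</p>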
</section>
<section id="advanced-configuration-with-zero-stage-3" class="level3">
<h3 class="anchored" data-anchor-id="advanced-configuration-with-zero-stage-3" id="advanced-configuration-with-zero-stage-3">Advanced Configuration with ZeRO Stage 3</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb5"><pre class="sourceCode json code-with-copy"><code class="sourceCode json"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="fu">{</span></span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a>  <span class="dt">"train_batch_size"</span><span class="fu">:</span> <span class="dv">64</span><span class="fu">,</span></span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a>  <span class="dt">"train_micro_batch_size_per_gpu"</span><span class="fu">:</span> <span class="dv">4</span><span class="fu">,</span></span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a>  <span class="dt">"gradient_accumulation_steps"</span><span class="fu">:</span> <span class="dv">4</span><span class="fu">,</span></span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a>  <span class="dt">"optimizer"</span><span class="fu">:</span> <span class="fu">{</span></span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"type"</span><span class="fu">:</span> <span class="st">"AdamW"</span><span class="fu">,</span></span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"params"</span><span class="fu">:</span> <span class="fu">{</span></span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a>      <span class="dt">"lr"</span><span class="fu">:</span> <span class="dv">3e-4</span><span class="fu">,</span></span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a>      <span class="dt">"betas"</span><span class="fu">:</span> <span class="ot">[</span><span class="fl">0.9</span><span class="ot">,</span> <span class="fl">0.95</span><span class="ot">]</span><span class="fu">,</span></span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a>      <span class="dt">"eps"</span><span class="fu">:</span> <span class="dv">1e-8</span><span class="fu">,</span></span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a>      <span class="dt">"weight_decay"</span><span class="fu">:</span> <span class="fl">0.1</span></span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a>    <span class="fu">}</span></span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a>  <span class="fu">},</span></span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a>  <span class="dt">"fp16"</span><span class="fu">:</span> <span class="fu">{</span></span>
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"enabled"</span><span class="fu">:</span> <span class="kw">true</span><span class="fu">,</span></span>
<span id="cb5-16"><a href="#cb5-16" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"auto_cast"</span><span class="fu">:</span> <span class="kw">false</span><span class="fu">,</span></span>
<span id="cb5-17"><a href="#cb5-17" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"loss_scale"</span><span class="fu">:</span> <span class="dv">0</span><span class="fu">,</span></span>
<span id="cb5-18"><a href="#cb5-18" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"initial_scale_power"</span><span class="fu">:</span> <span class="dv">16</span><span class="fu">,</span></span>
<span id="cb5-19"><a href="#cb5-19" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"loss_scale_window"</span><span class="fu">:</span> <span class="dv">1000</span><span class="fu">,</span></span>
<span id="cb5-20"><a href="#cb5-20" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"hysteresis"</span><span class="fu">:</span> <span class="dv">2</span><span class="fu">,</span></span>
<span id="cb5-21"><a href="#cb5-21" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"min_loss_scale"</span><span class="fu">:</span> <span class="dv">1</span></span>
<span id="cb5-22"><a href="#cb5-22" aria-hidden="true" tabindex="-1"></a>  <span class="fu">},</span></span>
<span id="cb5-23"><a href="#cb5-23" aria-hidden="true" tabindex="-1"></a>  <span class="dt">"zero_optimization"</span><span class="fu">:</span> <span class="fu">{</span></span>
<span id="cb5-24"><a href="#cb5-24" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"stage"</span><span class="fu">:</span> <span class="dv">3</span><span class="fu">,</span></span>
<span id="cb5-25"><a href="#cb5-25" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"offload_optimizer"</span><span class="fu">:</span> <span class="fu">{</span></span>
<span id="cb5-26"><a href="#cb5-26" aria-hidden="true" tabindex="-1"></a>      <span class="dt">"device"</span><span class="fu">:</span> <span class="st">"cpu"</span><span class="fu">,</span></span>
<span id="cb5-27"><a href="#cb5-27" aria-hidden="true" tabindex="-1"></a>      <span class="dt">"pin_memory"</span><span class="fu">:</span> <span class="kw">true</span></span>
<span id="cb5-28"><a href="#cb5-28" aria-hidden="true" tabindex="-1"></a>    <span class="fu">},</span></span>
<span id="cb5-29"><a href="#cb5-29" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"offload_param"</span><span class="fu">:</span> <span class="fu">{</span></span>
<span id="cb5-30"><a href="#cb5-30" aria-hidden="true" tabindex="-1"></a>      <span class="dt">"device"</span><span class="fu">:</span> <span class="st">"cpu"</span><span class="fu">,</span></span>
<span id="cb5-31"><a href="#cb5-31" aria-hidden="true" tabindex="-1"></a>      <span class="dt">"pin_memory"</span><span class="fu">:</span> <span class="kw">true</span></span>
<span id="cb5-32"><a href="#cb5-32" aria-hidden="true" tabindex="-1"></a>    <span class="fu">},</span></span>
<span id="cb5-33"><a href="#cb5-33" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"overlap_comm"</span><span class="fu">:</span> <span class="kw">true</span><span class="fu">,</span></span>
<span id="cb5-34"><a href="#cb5-34" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"contiguous_gradients"</span><span class="fu">:</span> <span class="kw">true</span><span class="fu">,</span></span>
<span id="cb5-35"><a href="#cb5-35" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"sub_group_size"</span><span class="fu">:</span> <span class="dv">1e9</span><span class="fu">,</span></span>
<span id="cb5-36"><a href="#cb5-36" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"reduce_bucket_size"</span><span class="fu">:</span> <span class="st">"auto"</span><span class="fu">,</span></span>
<span id="cb5-37"><a href="#cb5-37" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"stage3_prefetch_bucket_size"</span><span class="fu">:</span> <span class="st">"auto"</span><span class="fu">,</span></span>
<span id="cb5-38"><a href="#cb5-38" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"stage3_param_persistence_threshold"</span><span class="fu">:</span> <span class="st">"auto"</span><span class="fu">,</span></span>
<span id="cb5-39"><a href="#cb5-39" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"stage3_max_live_parameters"</span><span class="fu">:</span> <span class="dv">1e9</span><span class="fu">,</span></span>
<span id="cb5-40"><a href="#cb5-40" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"stage3_max_reuse_distance"</span><span class="fu">:</span> <span class="dv">1e9</span><span class="fu">,</span></span>
<span id="cb5-41"><a href="#cb5-41" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"stage3_gather_16bit_weights_on_model_save"</span><span class="fu">:</span> <span class="kw">true</span></span>
<span id="cb5-42"><a href="#cb5-42" aria-hidden="true" tabindex="-1"></a>  <span class="fu">},</span></span>
<span id="cb5-43"><a href="#cb5-43" aria-hidden="true" tabindex="-1"></a>  <span class="dt">"activation_checkpointing"</span><span class="fu">:</span> <span class="fu">{</span></span>
<span id="cb5-44"><a href="#cb5-44" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"partition_activations"</span><span class="fu">:</span> <span class="kw">false</span><span class="fu">,</span></span>
<span id="cb5-45"><a href="#cb5-45" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"cpu_checkpointing"</span><span class="fu">:</span> <span class="kw">true</span><span class="fu">,</span></span>
<span id="cb5-46"><a href="#cb5-46" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"contiguous_memory_optimization"</span><span class="fu">:</span> <span class="kw">false</span><span class="fu">,</span></span>
<span id="cb5-47"><a href="#cb5-47" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"number_checkpoints"</span><span class="fu">:</span> <span class="kw">null</span><span class="fu">,</span></span>
<span id="cb5-48"><a href="#cb5-48" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"synchronize_checkpoint_boundary"</span><span class="fu">:</span> <span class="kw">false</span><span class="fu">,</span></span>
<span id="cb5-49"><a href="#cb5-49" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"profile"</span><span class="fu">:</span> <span class="kw">false</span></span>
<span id="cb5-50"><a href="#cb5-50" aria-hidden="true" tabindex="-1"></a>  <span class="fu">}</span></span>
<span id="cb5-51"><a href="#cb5-51" aria-hidden="true" tabindex="-1"></a><span class="fu">}</span></span></code></pre></div></div>
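<p>Several entries in this ZeRO-3 config are set to <code>"auto"</code>. As far as we know, these placeholders are filled in by the Hugging Face Transformers Trainer integration, while a plain <code>deepspeed.initialize</code> call expects numbers.</p>
<p>The sketch below shows one way to resolve them yourself. The helper <code>resolve_auto</code> and its <code>hidden_size</code> argument are hypothetical. The sizing rules follow the heuristic we understand the Transformers integration to use:</p>
<ul>
<li>hidden_size² for the reduce bucket</li>
<li>0.9 × hidden_size² for the prefetch bucket</li>
<li>10 × hidden_size for the persistence threshold</li>
</ul>
<p>Treat these values as starting points to tune, not requirements.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import copy

# Sizing rules for the "auto" keys in the ZeRO-3 example, as functions of the
# model's hidden size h (a heuristic, not values DeepSpeed requires)
AUTO_RULES = {
    "reduce_bucket_size": lambda h: h * h,
    "stage3_prefetch_bucket_size": lambda h: int(0.9 * h * h),
    "stage3_param_persistence_threshold": lambda h: 10 * h,
}


def resolve_auto(config, hidden_size):
    """Return a copy of config with "auto" ZeRO values replaced by numbers.

    Raises KeyError if an "auto" value has no rule in AUTO_RULES.
    """
    resolved = copy.deepcopy(config)  # leave the caller's dict untouched
    zero = resolved.get("zero_optimization", {})
    for key, value in zero.items():
        if value == "auto":
            if key not in AUTO_RULES:
                raise KeyError(f"no rule to resolve 'auto' for {key!r}")
            zero[key] = AUTO_RULES[key](hidden_size)
    return resolved


# Trimmed-down version of the zero_optimization section above
config = {
    "zero_optimization": {
        "stage": 3,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
    }
}
# With hidden_size=1024, reduce_bucket_size becomes 1048576 (1024 * 1024)
print(resolve_auto(config, hidden_size=1024)["zero_optimization"])
</code></pre></div>
<p>Only values that are exactly <code>"auto"</code> are touched, so nested sections such as <code>offload_optimizer</code> pass through unchanged.</p>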
</section>
</section>
<section id="zero-optimizer-states" class="level2">
<h2 class="anchored" data-anchor-id="zero-optimizer-states" id="zero-optimizer-states">ZeRO Optimizer States</h2>
<section id="zero-stage-1-optimizer-state-partitioning" class="level3">
<h3 class="anchored" data-anchor-id="zero-stage-1-optimizer-state-partitioning" id="zero-stage-1-optimizer-state-partitioning">ZeRO Stage 1: Optimizer State Partitioning</h3>
<div id="d9c3302d" class="cell" data-execution_count="3">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Configuration for ZeRO Stage 1</span></span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a>zero_stage1_config <span class="op">=</span> {</span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a>    <span class="st">"train_batch_size"</span>: <span class="dv">64</span>,</span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a>    <span class="st">"optimizer"</span>: {</span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a>        <span class="st">"type"</span>: <span class="st">"Adam"</span>,</span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a>        <span class="st">"params"</span>: {<span class="st">"lr"</span>: <span class="fl">0.001</span>}</span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a>    },</span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a>    <span class="st">"zero_optimization"</span>: {</span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a>        <span class="st">"stage"</span>: <span class="dv">1</span>,</span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a>        <span class="st">"reduce_bucket_size"</span>: <span class="fl">5e8</span></span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a>    },</span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a>    <span class="st">"fp16"</span>: {<span class="st">"enabled"</span>: <span class="va">True</span>}</span>
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a>}</span>
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-15"><a href="#cb6-15" aria-hidden="true" tabindex="-1"></a>model_engine, optimizer, _, _ <span class="op">=</span> deepspeed.initialize(</span>
<span id="cb6-16"><a href="#cb6-16" aria-hidden="true" tabindex="-1"></a>    model<span class="op">=</span>model,</span>
<span id="cb6-17"><a href="#cb6-17" aria-hidden="true" tabindex="-1"></a>    model_parameters<span class="op">=</span>model.parameters(),</span>
<span id="cb6-18"><a href="#cb6-18" aria-hidden="true" tabindex="-1"></a>    config<span class="op">=</span>zero_stage1_config</span>
<span id="cb6-19"><a href="#cb6-19" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
</div>
</section>
<section id="zero-stage-2-gradient-optimizer-state-partitioning" class="level3">
<h3 class="anchored" data-anchor-id="zero-stage-2-gradient-optimizer-state-partitioning" id="zero-stage-2-gradient-optimizer-state-partitioning">ZeRO Stage 2: Gradient + Optimizer State Partitioning</h3>
<div id="941fda56" class="cell" data-execution_count="4">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Configuration for ZeRO Stage 2</span></span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a>zero_stage2_config <span class="op">=</span> {</span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a>    <span class="st">"train_batch_size"</span>: <span class="dv">64</span>,</span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a>    <span class="st">"optimizer"</span>: {</span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a>        <span class="st">"type"</span>: <span class="st">"Adam"</span>,</span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a>        <span class="st">"params"</span>: {<span class="st">"lr"</span>: <span class="fl">0.001</span>}</span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a>    },</span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a>    <span class="st">"zero_optimization"</span>: {</span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a>        <span class="st">"stage"</span>: <span class="dv">2</span>,</span>
<span id="cb7-10"><a href="#cb7-10" aria-hidden="true" tabindex="-1"></a>        <span class="st">"allgather_partitions"</span>: <span class="va">True</span>,</span>
<span id="cb7-11"><a href="#cb7-11" aria-hidden="true" tabindex="-1"></a>        <span class="st">"allgather_bucket_size"</span>: <span class="fl">2e8</span>,</span>
<span id="cb7-12"><a href="#cb7-12" aria-hidden="true" tabindex="-1"></a>        <span class="st">"overlap_comm"</span>: <span class="va">True</span>,</span>
<span id="cb7-13"><a href="#cb7-13" aria-hidden="true" tabindex="-1"></a>        <span class="st">"reduce_scatter"</span>: <span class="va">True</span>,</span>
<span id="cb7-14"><a href="#cb7-14" aria-hidden="true" tabindex="-1"></a>        <span class="st">"reduce_bucket_size"</span>: <span class="fl">2e8</span>,</span>
<span id="cb7-15"><a href="#cb7-15" aria-hidden="true" tabindex="-1"></a>        <span class="st">"contiguous_gradients"</span>: <span class="va">True</span></span>
<span id="cb7-16"><a href="#cb7-16" aria-hidden="true" tabindex="-1"></a>    },</span>
<span id="cb7-17"><a href="#cb7-17" aria-hidden="true" tabindex="-1"></a>    <span class="st">"fp16"</span>: {<span class="st">"enabled"</span>: <span class="va">True</span>}</span>
<span id="cb7-18"><a href="#cb7-18" aria-hidden="true" tabindex="-1"></a>}</span></code></pre></div></div>
</div>
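<p>The bucket sizes in these ZeRO configs are counted in elements, not bytes, so their memory cost depends on the gradient dtype. The back-of-the-envelope conversion below assumes 2-byte fp16 elements, as in these configs. The helper name <code>bucket_mb</code> is our own.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">def bucket_mb(num_elements, bytes_per_element=2):
    """Approximate size in MB (1e6 bytes) of one communication bucket.

    Defaults to fp16 elements. Actual buffer usage can be higher, for example
    when overlap_comm keeps more than one bucket in flight.
    """
    return num_elements * bytes_per_element / 1e6


print(bucket_mb(2e8))                       # 400.0  (Stage 2 buckets above)
print(bucket_mb(5e8))                       # 1000.0 (Stage 1 reduce bucket)
print(bucket_mb(2e8, bytes_per_element=4))  # 800.0  (same bucket in fp32)
</code></pre></div>
<p>Smaller buckets save memory but split communication into more, smaller collectives. That trade-off is why these sizes are exposed as tunable settings.</p>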
</section>
<section id="zero-stage-3-full-parameter-partitioning" class="level3">
<h3 class="anchored" data-anchor-id="zero-stage-3-full-parameter-partitioning" id="zero-stage-3-full-parameter-partitioning">ZeRO Stage 3: Full Parameter Partitioning</h3>
<div id="57eb499e" class="cell" data-execution_count="5">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb8"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Configuration for ZeRO Stage 3</span></span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a>zero_stage3_config <span class="op">=</span> {</span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a>    <span class="st">"train_batch_size"</span>: <span class="dv">32</span>,</span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a>    <span class="st">"train_micro_batch_size_per_gpu"</span>: <span class="dv">8</span>,</span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a>    <span class="st">"optimizer"</span>: {</span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a>        <span class="st">"type"</span>: <span class="st">"Adam"</span>,</span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a>        <span class="st">"params"</span>: {<span class="st">"lr"</span>: <span class="fl">0.001</span>}</span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a>    },</span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a>    <span class="st">"zero_optimization"</span>: {</span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a>        <span class="st">"stage"</span>: <span class="dv">3</span>,</span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a>        <span class="st">"overlap_comm"</span>: <span class="va">True</span>,</span>
<span id="cb8-12"><a href="#cb8-12" aria-hidden="true" tabindex="-1"></a>        <span class="st">"contiguous_gradients"</span>: <span class="va">True</span>,</span>
<span id="cb8-13"><a href="#cb8-13" aria-hidden="true" tabindex="-1"></a>        <span class="st">"sub_group_size"</span>: <span class="fl">1e9</span>,</span>
        <span class="co"># Explicit sizes: the "auto" placeholder is only filled in by the Hugging Face Trainer integration</span>
<span id="cb8-14"><a href="#cb8-14" aria-hidden="true" tabindex="-1"></a>        <span class="st">"reduce_bucket_size"</span>: <span class="fl">5e8</span>,</span>
<span id="cb8-15"><a href="#cb8-15" aria-hidden="true" tabindex="-1"></a>        <span class="st">"stage3_prefetch_bucket_size"</span>: <span class="fl">5e7</span>,</span>
<span id="cb8-16"><a href="#cb8-16" aria-hidden="true" tabindex="-1"></a>        <span class="st">"stage3_param_persistence_threshold"</span>: <span class="fl">1e5</span>,</span>
<span id="cb8-17"><a href="#cb8-17" aria-hidden="true" tabindex="-1"></a>        <span class="st">"stage3_max_live_parameters"</span>: <span class="fl">1e9</span>,</span>
<span id="cb8-18"><a href="#cb8-18" aria-hidden="true" tabindex="-1"></a>        <span class="st">"stage3_max_reuse_distance"</span>: <span class="fl">1e9</span></span>
<span id="cb8-19"><a href="#cb8-19" aria-hidden="true" tabindex="-1"></a>    },</span>
<span id="cb8-20"><a href="#cb8-20" aria-hidden="true" tabindex="-1"></a>    <span class="st">"fp16"</span>: {<span class="st">"enabled"</span>: <span class="va">True</span>}</span>
<span id="cb8-21"><a href="#cb8-21" aria-hidden="true" tabindex="-1"></a>}</span>
<span id="cb8-22"><a href="#cb8-22" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-23"><a href="#cb8-23" aria-hidden="true" tabindex="-1"></a><span class="co"># Special handling for ZeRO Stage 3</span></span>
<span id="cb8-24"><a href="#cb8-24" aria-hidden="true" tabindex="-1"></a><span class="cf">with</span> deepspeed.zero.Init(config_dict_or_path<span class="op">=</span>zero_stage3_config):</span>
<span id="cb8-25"><a href="#cb8-25" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> SimpleModel()</span>
<span id="cb8-26"><a href="#cb8-26" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-27"><a href="#cb8-27" aria-hidden="true" tabindex="-1"></a>model_engine, optimizer, _, _ <span class="op">=</span> deepspeed.initialize(</span>
<span id="cb8-28"><a href="#cb8-28" aria-hidden="true" tabindex="-1"></a>    model<span class="op">=</span>model,</span>
<span id="cb8-29"><a href="#cb8-29" aria-hidden="true" tabindex="-1"></a>    config<span class="op">=</span>zero_stage3_config</span>
<span id="cb8-30"><a href="#cb8-30" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
</div>
</section>
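<p>The ZeRO paper's memory accounting shows why each stage helps. For mixed-precision Adam training, it estimates per-GPU memory for model states as:</p>
<ul>
<li>2 bytes per parameter for fp16 weights</li>
<li>2 bytes per parameter for fp16 gradients</li>
<li>12 bytes per parameter for fp32 optimizer states</li>
</ul>
<p>Each successive stage divides more of these terms by the number of data-parallel GPUs.</p>
<p>The function below is a sketch of that accounting. It ignores activations, communication buffers, and fragmentation. The name <code>zero_model_state_gb</code> is ours.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">def zero_model_state_gb(num_params, num_gpus, stage, k=12):
    """Per-GPU memory in GB (1e9 bytes) for model states under ZeRO.

    Mixed-precision Adam accounting from the ZeRO paper: 2 bytes/param of
    fp16 weights, 2 bytes/param of fp16 gradients and k = 12 bytes/param of
    fp32 optimizer state (master weights, momentum, variance).
    """
    p, n = num_params, num_gpus
    if stage == 0:    # plain data parallelism: everything replicated
        total = (2 + 2 + k) * p
    elif stage == 1:  # optimizer states partitioned
        total = 2 * p + 2 * p + k * p / n
    elif stage == 2:  # optimizer states + gradients partitioned
        total = 2 * p + (2 + k) * p / n
    elif stage == 3:  # optimizer states + gradients + parameters partitioned
        total = (2 + 2 + k) * p / n
    else:
        raise ValueError(f"unknown ZeRO stage: {stage}")
    return total / 1e9


# 7.5B parameters on 64 GPUs, the example used in the ZeRO paper
for stage in range(4):
    print(stage, zero_model_state_gb(7.5e9, 64, stage))
</code></pre></div>
<p>This prints 120.0, 31.40625, 16.640625 and 1.875 GB for stages 0 through 3. These match the 120, 31.4, 16.6 and 1.9 GB reported in the paper. Stage 3 is the only stage whose savings keep growing with the GPU count for every term, which is why it is the usual choice for models that do not fit on a single device.</p>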
</section>
<section id="model-parallelism" class="level2">
<h2 class="anchored" data-anchor-id="model-parallelism" id="model-parallelism">Model Parallelism</h2>
<section id="pipeline-parallelism" class="level3">
<h3 class="anchored" data-anchor-id="pipeline-parallelism" id="pipeline-parallelism">Pipeline Parallelism</h3>
<div id="eb6d12a9" class="cell" data-execution_count="6">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb9"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span class="im">import</span> torch
<span class="im">import</span> torch.nn <span class="im">as</span> nn
<span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> deepspeed</span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> deepspeed.pipe <span class="im">import</span> PipelineModule</span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> PipelineModel(nn.Module):</span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, layers_per_stage<span class="op">=</span><span class="dv">2</span>):</span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.layers <span class="op">=</span> nn.ModuleList([</span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a>            nn.Linear(<span class="dv">1000</span>, <span class="dv">1000</span>) <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">8</span>)</span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a>        ])</span>
<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-11"><a href="#cb9-11" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb9-12"><a href="#cb9-12" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> layer <span class="kw">in</span> <span class="va">self</span>.layers:</span>
<span id="cb9-13"><a href="#cb9-13" aria-hidden="true" tabindex="-1"></a>            x <span class="op">=</span> torch.relu(layer(x))</span>
<span id="cb9-14"><a href="#cb9-14" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x</span>
<span id="cb9-15"><a href="#cb9-15" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-16"><a href="#cb9-16" aria-hidden="true" tabindex="-1"></a><span class="co"># Express the same network as a flat list of layers for PipelineModule</span></span>
<span id="cb9-17"><a href="#cb9-17" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> partition_layers():</span>
<span id="cb9-18"><a href="#cb9-18" aria-hidden="true" tabindex="-1"></a>    layers <span class="op">=</span> []</span>
<span id="cb9-19"><a href="#cb9-19" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">8</span>):</span>
<span id="cb9-20"><a href="#cb9-20" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Keep Linear and ReLU as separate list entries so that</span></span>
<span id="cb9-21"><a href="#cb9-21" aria-hidden="true" tabindex="-1"></a>        <span class="co"># partition_method='type:Linear' can match each Linear layer by class name</span></span>
<span id="cb9-22"><a href="#cb9-22" aria-hidden="true" tabindex="-1"></a>        layers.append(nn.Linear(<span class="dv">1000</span>, <span class="dv">1000</span>))</span>
<span id="cb9-23"><a href="#cb9-23" aria-hidden="true" tabindex="-1"></a>        layers.append(nn.ReLU())</span>
<span id="cb9-24"><a href="#cb9-24" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> layers</span>
<span id="cb9-25"><a href="#cb9-25" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-26"><a href="#cb9-26" aria-hidden="true" tabindex="-1"></a><span class="co"># Create pipeline</span></span>
<span id="cb9-27"><a href="#cb9-27" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> PipelineModule(</span>
<span id="cb9-28"><a href="#cb9-28" aria-hidden="true" tabindex="-1"></a>    layers<span class="op">=</span>partition_layers(),</span>
<span id="cb9-29"><a href="#cb9-29" aria-hidden="true" tabindex="-1"></a>    num_stages<span class="op">=</span><span class="dv">4</span>,  <span class="co"># Number of pipeline stages</span></span>
<span id="cb9-30"><a href="#cb9-30" aria-hidden="true" tabindex="-1"></a>    loss_fn<span class="op">=</span>nn.MSELoss(),  <span class="co"># Placeholder loss; the last stage applies it to outputs and labels</span></span>
<span id="cb9-31"><a href="#cb9-31" aria-hidden="true" tabindex="-1"></a>    partition_method<span class="op">=</span><span class="st">'parameters'</span>  <span class="co"># Balance stages by parameter count</span></span>
<span id="cb9-32"><a href="#cb9-32" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb9-33"><a href="#cb9-33" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-34"><a href="#cb9-34" aria-hidden="true" tabindex="-1"></a><span class="co"># Pipeline-specific config (assumes 4 GPUs: 4 stages x 1 data-parallel replica)</span></span>
<span id="cb9-35"><a href="#cb9-35" aria-hidden="true" tabindex="-1"></a>pipeline_config <span class="op">=</span> {</span>
<span id="cb9-36"><a href="#cb9-36" aria-hidden="true" tabindex="-1"></a>    <span class="st">"train_batch_size"</span>: <span class="dv">64</span>,</span>
<span id="cb9-37"><a href="#cb9-37" aria-hidden="true" tabindex="-1"></a>    <span class="st">"train_micro_batch_size_per_gpu"</span>: <span class="dv">16</span>,</span>
<span id="cb9-38"><a href="#cb9-38" aria-hidden="true" tabindex="-1"></a>    <span class="st">"gradient_accumulation_steps"</span>: <span class="dv">4</span>,  <span class="co"># 64 = 16 per micro-batch x 4 micro-batches x 1 replica</span></span>
<span id="cb9-39"><a href="#cb9-39" aria-hidden="true" tabindex="-1"></a>    <span class="st">"optimizer"</span>: {</span>
<span id="cb9-40"><a href="#cb9-40" aria-hidden="true" tabindex="-1"></a>        <span class="st">"type"</span>: <span class="st">"Adam"</span>,</span>
<span id="cb9-41"><a href="#cb9-41" aria-hidden="true" tabindex="-1"></a>        <span class="st">"params"</span>: {<span class="st">"lr"</span>: <span class="fl">0.001</span>}</span>
<span id="cb9-42"><a href="#cb9-42" aria-hidden="true" tabindex="-1"></a>    },</span>
<span id="cb9-43"><a href="#cb9-43" aria-hidden="true" tabindex="-1"></a>    <span class="st">"pipeline"</span>: {</span>
<span id="cb9-44"><a href="#cb9-44" aria-hidden="true" tabindex="-1"></a>        <span class="st">"stages"</span>: <span class="st">"auto"</span>,</span>
<span id="cb9-45"><a href="#cb9-45" aria-hidden="true" tabindex="-1"></a>        <span class="st">"partition"</span>: <span class="st">"balanced"</span></span>
<span id="cb9-46"><a href="#cb9-46" aria-hidden="true" tabindex="-1"></a>    },</span>
<span id="cb9-47"><a href="#cb9-47" aria-hidden="true" tabindex="-1"></a>    <span class="st">"fp16"</span>: {<span class="st">"enabled"</span>: <span class="va">True</span>}</span>
<span id="cb9-48"><a href="#cb9-48" aria-hidden="true" tabindex="-1"></a>}</span>
<span id="cb9-49"><a href="#cb9-49" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-50"><a href="#cb9-50" aria-hidden="true" tabindex="-1"></a>engine, _, _, _ <span class="op">=</span> deepspeed.initialize(</span>
<span id="cb9-51"><a href="#cb9-51" aria-hidden="true" tabindex="-1"></a>    model<span class="op">=</span>model,</span>
<span id="cb9-52"><a href="#cb9-52" aria-hidden="true" tabindex="-1"></a>    model_parameters<span class="op">=</span>model.parameters(),</span>
<span id="cb9-53"><a href="#cb9-53" aria-hidden="true" tabindex="-1"></a>    config<span class="op">=</span>pipeline_config</span>
<span id="cb9-54"><a href="#cb9-54" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
</div>
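<p>DeepSpeed requires <code>train_batch_size</code> to equal <code>train_micro_batch_size_per_gpu</code> × <code>gradient_accumulation_steps</code> × the data-parallel size, and with pipeline parallelism the data-parallel size is the world size divided by the number of stages. The helper below is a small, hypothetical sketch (not a DeepSpeed API) that derives the accumulation steps, which for a pipeline is also the number of micro-batches pushed through the stages per training step.</p>

```python
def pipeline_batch_plan(world_size, num_stages, train_batch_size, micro_batch_size):
    """Hypothetical helper: data-parallel size and gradient accumulation steps."""
    if world_size % num_stages != 0:
        raise ValueError("world_size must be divisible by num_stages")
    dp_size = world_size // num_stages  # number of full pipeline replicas
    samples_per_micro_step = micro_batch_size * dp_size
    if train_batch_size % samples_per_micro_step != 0:
        raise ValueError("train_batch_size must be a multiple of micro_batch_size * dp_size")
    grad_accum = train_batch_size // samples_per_micro_step  # micro-batches per step
    return {"dp_size": dp_size, "gradient_accumulation_steps": grad_accum}

# 4 GPUs and 4 stages: a single pipeline replica runs 4 micro-batches per step
print(pipeline_batch_plan(world_size=4, num_stages=4, train_batch_size=64, micro_batch_size=16))
# 8 GPUs and 4 stages: two replicas split the batch, 2 micro-batches each
print(pipeline_batch_plan(world_size=8, num_stages=4, train_batch_size=64, micro_batch_size=16))
```

<p>More micro-batches per step shrink the idle "bubble" at the start and end of each pipeline step, which is why pipeline configs usually set the accumulation count well above 1.</p>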
</section>
<section id="tensor-parallelism-with-megatron" class="level3">
<h3 class="anchored" data-anchor-id="tensor-parallelism-with-megatron" id="tensor-parallelism-with-megatron">Tensor Parallelism (with Megatron)</h3>
<div id="08527e7e" class="cell" data-execution_count="7">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb10"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Minimal sketch of a Megatron-style column-parallel linear layer using plain torch.distributed</span></span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.distributed <span class="im">as</span> dist</span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-6"><a href="#cb10-6" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> TensorParallelLinear(nn.Module):</span>
<span id="cb10-7"><a href="#cb10-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, input_size, output_size, world_size):</span>
<span id="cb10-8"><a href="#cb10-8" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb10-9"><a href="#cb10-9" aria-hidden="true" tabindex="-1"></a>        <span class="cf">assert</span> output_size <span class="op">%</span> world_size <span class="op">==</span> <span class="dv">0</span>, <span class="st">"output_size must be divisible by world_size"</span></span>
<span id="cb10-10"><a href="#cb10-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.world_size <span class="op">=</span> world_size</span>
<span id="cb10-11"><a href="#cb10-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.rank <span class="op">=</span> dist.get_rank()</span>
<span id="cb10-12"><a href="#cb10-12" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb10-13"><a href="#cb10-13" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Split output dimension across ranks: each rank holds one column shard</span></span>
<span id="cb10-14"><a href="#cb10-14" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.output_size_per_partition <span class="op">=</span> output_size <span class="op">//</span> world_size</span>
<span id="cb10-15"><a href="#cb10-15" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.weight <span class="op">=</span> nn.Parameter(</span>
<span id="cb10-16"><a href="#cb10-16" aria-hidden="true" tabindex="-1"></a>            torch.randn(input_size, <span class="va">self</span>.output_size_per_partition) <span class="op">/</span> input_size <span class="op">**</span> <span class="fl">0.5</span></span>
<span id="cb10-17"><a href="#cb10-17" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb10-18"><a href="#cb10-18" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb10-19"><a href="#cb10-19" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb10-20"><a href="#cb10-20" aria-hidden="true" tabindex="-1"></a>        output <span class="op">=</span> torch.matmul(x, <span class="va">self</span>.weight)</span>
<span id="cb10-21"><a href="#cb10-21" aria-hidden="true" tabindex="-1"></a>        <span class="co"># All-gather outputs from all partitions</span></span>
<span id="cb10-22"><a href="#cb10-22" aria-hidden="true" tabindex="-1"></a>        gathered <span class="op">=</span> [torch.zeros_like(output) <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(<span class="va">self</span>.world_size)]</span>
<span id="cb10-23"><a href="#cb10-23" aria-hidden="true" tabindex="-1"></a>        dist.all_gather(gathered, output)</span>
<span id="cb10-24"><a href="#cb10-24" aria-hidden="true" tabindex="-1"></a>        <span class="co"># all_gather is not autograd-aware: put the local shard back so that</span></span>
<span id="cb10-25"><a href="#cb10-25" aria-hidden="true" tabindex="-1"></a>        <span class="co"># gradients flow to this rank's weight slice</span></span>
<span id="cb10-26"><a href="#cb10-26" aria-hidden="true" tabindex="-1"></a>        gathered[<span class="va">self</span>.rank] <span class="op">=</span> output</span>
<span id="cb10-27"><a href="#cb10-27" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> torch.cat(gathered, dim<span class="op">=-</span><span class="dv">1</span>)</span></code></pre></div></div>
</div>
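<p>The column split works because each output column depends only on the matching column of the weight matrix. The pure-Python sketch below simulates the ranks with list slices (no GPUs or <code>torch.distributed</code>; all function names are illustrative) and checks that concatenating the per-shard products reproduces the full product, which is what the all-gather in <code>forward</code> reassembles.</p>

```python
def matmul(x, w):
    """Multiply an (n x k) matrix by a (k x m) matrix, both given as nested lists."""
    return [[sum(x[i][p] * w[p][j] for p in range(len(w))) for j in range(len(w[0]))]
            for i in range(len(x))]

def column_shards(w, world_size):
    """Split the columns of w into world_size equal shards, one per simulated rank."""
    cols = len(w[0])
    assert cols % world_size == 0, "output size must divide evenly across ranks"
    step = cols // world_size
    return [[row[r * step:(r + 1) * step] for row in w] for r in range(world_size)]

def column_parallel_matmul(x, w, world_size):
    # Each simulated rank multiplies the full input by its own column shard
    partial = [matmul(x, shard) for shard in column_shards(w, world_size)]
    # "All-gather": concatenate every rank's output columns in rank order
    return [sum((p[i] for p in partial), []) for i in range(len(x))]

x = [[1, 2], [3, 4]]
w = [[1, 0, 2, -1], [0, 1, 1, 3]]
print(column_parallel_matmul(x, w, world_size=2))  # [[1, 2, 4, 5], [3, 4, 10, 9]]
print(matmul(x, w))                                # same result without splitting
```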
</section>
</section>
<section id="mixed-precision-training" class="level2">
<h2 class="anchored" data-anchor-id="mixed-precision-training" id="mixed-precision-training">Mixed Precision Training</h2>
<section id="fp16-configuration" class="level3">
<h3 class="anchored" data-anchor-id="fp16-configuration" id="fp16-configuration">FP16 Configuration</h3>
<div id="cf1d913b" class="cell" data-execution_count="8">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb11"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a>fp16_config <span class="op">=</span> {</span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a>    <span class="st">"train_batch_size"</span>: <span class="dv">64</span>,</span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a>    <span class="st">"optimizer"</span>: {</span>
<span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a>        <span class="st">"type"</span>: <span class="st">"Adam"</span>,</span>
<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a>        <span class="st">"params"</span>: {<span class="st">"lr"</span>: <span class="fl">0.001</span>}</span>
<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a>    },</span>
<span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a>    <span class="st">"fp16"</span>: {</span>
<span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a>        <span class="st">"enabled"</span>: <span class="va">True</span>,</span>
<span id="cb11-9"><a href="#cb11-9" aria-hidden="true" tabindex="-1"></a>        <span class="st">"auto_cast"</span>: <span class="va">False</span>,</span>
<span id="cb11-10"><a href="#cb11-10" aria-hidden="true" tabindex="-1"></a>        <span class="st">"loss_scale"</span>: <span class="dv">0</span>,  <span class="co"># 0 selects dynamic loss scaling</span></span>
<span id="cb11-11"><a href="#cb11-11" aria-hidden="true" tabindex="-1"></a>        <span class="st">"initial_scale_power"</span>: <span class="dv">16</span>,  <span class="co"># Dynamic scaling starts at 2**16</span></span>
<span id="cb11-12"><a href="#cb11-12" aria-hidden="true" tabindex="-1"></a>        <span class="st">"loss_scale_window"</span>: <span class="dv">1000</span>,  <span class="co"># Overflow-free steps before the scale is raised</span></span>
<span id="cb11-13"><a href="#cb11-13" aria-hidden="true" tabindex="-1"></a>        <span class="st">"hysteresis"</span>: <span class="dv">2</span>,  <span class="co"># Delay shift before the scale is lowered</span></span>
<span id="cb11-14"><a href="#cb11-14" aria-hidden="true" tabindex="-1"></a>        <span class="st">"min_loss_scale"</span>: <span class="dv">1</span>  <span class="co"># Floor for the dynamic scale</span></span>
<span id="cb11-15"><a href="#cb11-15" aria-hidden="true" tabindex="-1"></a>    }</span>
<span id="cb11-16"><a href="#cb11-16" aria-hidden="true" tabindex="-1"></a>}</span>
<span id="cb11-17"><a href="#cb11-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-18"><a href="#cb11-18" aria-hidden="true" tabindex="-1"></a><span class="co"># Training loop: no manual loss scaling is needed</span></span>
<span id="cb11-19"><a href="#cb11-19" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> train_with_custom_scaling(model_engine, dataloader):</span>
<span id="cb11-20"><a href="#cb11-20" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> batch <span class="kw">in</span> dataloader:</span>
<span id="cb11-21"><a href="#cb11-21" aria-hidden="true" tabindex="-1"></a>        outputs <span class="op">=</span> model_engine(batch)</span>
<span id="cb11-22"><a href="#cb11-22" aria-hidden="true" tabindex="-1"></a>        loss <span class="op">=</span> compute_loss(outputs, batch)  <span class="co"># compute_loss: your own loss function</span></span>
<span id="cb11-23"><a href="#cb11-23" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-24"><a href="#cb11-24" aria-hidden="true" tabindex="-1"></a>        <span class="co"># backward() scales the loss; step() unscales the gradients and skips the update on overflow</span></span>
<span id="cb11-25"><a href="#cb11-25" aria-hidden="true" tabindex="-1"></a>        model_engine.backward(loss)</span>
<span id="cb11-26"><a href="#cb11-26" aria-hidden="true" tabindex="-1"></a>        model_engine.step()</span></code></pre></div></div>
</div>
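<p>With <code>"loss_scale": 0</code>, DeepSpeed adjusts the scale on its own. The class below is a simplified, hypothetical model of that policy, not DeepSpeed's own loss scaler, and it omits some details: a step whose gradients overflow is skipped, the scale is only halved once <code>hysteresis</code> overflows have accumulated (never going below <code>min_loss_scale</code>), and it doubles after <code>loss_scale_window</code> consecutive clean steps.</p>

```python
class SimpleDynamicLossScaler:
    """Simplified model of dynamic loss scaling (hypothetical, not DeepSpeed's class)."""

    def __init__(self, initial_scale_power=16, loss_scale_window=1000,
                 hysteresis=2, min_loss_scale=1.0):
        self.scale = 2.0 ** initial_scale_power
        self.window = loss_scale_window
        self.hysteresis = hysteresis
        self.min_scale = min_loss_scale
        self._hysteresis_left = hysteresis  # overflows still tolerated at this scale
        self._clean_steps = 0               # consecutive steps without overflow

    def update(self, overflow):
        """Record one step and return True if the optimizer update should run."""
        if overflow:
            self._clean_steps = 0
            if self._hysteresis_left > 1:
                self._hysteresis_left -= 1  # tolerate this overflow, keep the scale
            else:
                self.scale = max(self.scale / 2, self.min_scale)  # back off
            return False  # gradients hold inf/NaN, so the update is skipped
        self._clean_steps += 1
        if self._clean_steps % self.window == 0:
            self.scale *= 2  # a full window without overflow: try a larger scale
            self._hysteresis_left = self.hysteresis
        return True

# Small window and scale so the behaviour is visible in a few steps
scaler = SimpleDynamicLossScaler(initial_scale_power=4, loss_scale_window=3, hysteresis=2)
for overflow in [True, True, False, False, False]:
    applied = scaler.update(overflow)
    print(f"overflow={overflow} applied={applied} scale={scaler.scale}")
```

<p>With these toy settings the first overflow is absorbed by hysteresis, the second halves the scale from 16 to 8, and three clean steps double it back to 16.</p>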
</section>
<section id="bf16-configuration" class="level3">
<h3 class="anchored" data-anchor-id="bf16-configuration" id="bf16-configuration">BF16 Configuration</h3>
<div id="32fcb0a1" class="cell" data-execution_count="9">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb12"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><a href="#cb12-1" aria-hidden="true" tabindex="-1"></a>bf16_config <span class="op">=</span> {</span>
<span id="cb12-2"><a href="#cb12-2" aria-hidden="true" tabindex="-1"></a>    <span class="st">"train_batch_size"</span>: <span class="dv">64</span>,</span>
<span id="cb12-3"><a href="#cb12-3" aria-hidden="true" tabindex="-1"></a>    <span class="st">"optimizer"</span>: {</span>
<span id="cb12-4"><a href="#cb12-4" aria-hidden="true" tabindex="-1"></a>        <span class="st">"type"</span>: <span class="st">"Adam"</span>,</span>
<span id="cb12-5"><a href="#cb12-5" aria-hidden="true" tabindex="-1"></a>        <span class="st">"params"</span>: {<span class="st">"lr"</span>: <span class="fl">0.001</span>}</span>
<span id="cb12-6"><a href="#cb12-6" aria-hidden="true" tabindex="-1"></a>    },</span>
<span id="cb12-7"><a href="#cb12-7" aria-hidden="true" tabindex="-1"></a>    <span class="st">"bf16"</span>: {</span>
<span id="cb12-8"><a href="#cb12-8" aria-hidden="true" tabindex="-1"></a>        <span class="st">"enabled"</span>: <span class="va">True</span></span>
<span id="cb12-9"><a href="#cb12-9" aria-hidden="true" tabindex="-1"></a>    }</span>
<span id="cb12-10"><a href="#cb12-10" aria-hidden="true" tabindex="-1"></a>}</span></code></pre></div></div>
</div>
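<p>The BF16 block has no loss-scaling settings because BF16 keeps FP32's 8 exponent bits and gives up mantissa precision instead, whereas FP16 has only 5 exponent bits. The sketch below derives each format's largest finite value and smallest normal value from its bit layout, then uses Python's <code>struct</code> half-precision code <code>'e'</code> to show a tiny gradient underflowing in FP16 unless it is scaled first.</p>

```python
import struct

def float_limits(exponent_bits, mantissa_bits):
    """Largest finite and smallest normal value of an IEEE-style binary format."""
    bias = 2 ** (exponent_bits - 1) - 1
    largest = (2 - 2.0 ** -mantissa_bits) * 2.0 ** bias
    smallest_normal = 2.0 ** (1 - bias)
    return largest, smallest_normal

for name, e, m in [("fp16", 5, 10), ("bf16", 8, 7), ("fp32", 8, 23)]:
    largest, tiny = float_limits(e, m)
    print(f"{name}: max={largest:.4g}  min_normal={tiny:.4g}")

def to_fp16(value):
    """Round-trip a Python float through IEEE half precision."""
    return struct.unpack("<e", struct.pack("<e", value))[0]

grad = 1e-8
print(to_fp16(grad))                # underflows to 0.0 in fp16
print(to_fp16(grad * 2 ** 16) > 0)  # survives once multiplied by a 2**16 loss scale
```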
</section>
</section>
<section id="advanced-features" class="level2">
<h2 class="anchored" data-anchor-id="advanced-features" id="advanced-features">Advanced Features</h2>
<section id="activation-checkpointing" class="level3">
<h3 class="anchored" data-anchor-id="activation-checkpointing" id="activation-checkpointing">Activation Checkpointing</h3>
<div id="7d263481" class="cell" data-execution_count="10">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb13"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a>activation_checkpointing_config <span class="op">=</span> {</span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a>    <span class="st">"train_batch_size"</span>: <span class="dv">64</span>,</span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a>    <span class="st">"optimizer"</span>: {</span>
<span id="cb13-4"><a href="#cb13-4" aria-hidden="true" tabindex="-1"></a>        <span class="st">"type"</span>: <span class="st">"Adam"</span>,</span>
<span id="cb13-5"><a href="#cb13-5" aria-hidden="true" tabindex="-1"></a>        <span class="st">"params"</span>: {<span class="st">"lr"</span>: <span class="fl">0.001</span>}</span>
<span id="cb13-6"><a href="#cb13-6" aria-hidden="true" tabindex="-1"></a>    },</span>
<span id="cb13-7"><a href="#cb13-7" aria-hidden="true" tabindex="-1"></a>    <span class="st">"activation_checkpointing"</span>: {</span>
<span id="cb13-8"><a href="#cb13-8" aria-hidden="true" tabindex="-1"></a>        <span class="st">"partition_activations"</span>: <span class="va">False</span>,</span>
<span id="cb13-9"><a href="#cb13-9" aria-hidden="true" tabindex="-1"></a>        <span class="st">"cpu_checkpointing"</span>: <span class="va">True</span>,  <span class="co"># Only takes effect when partition_activations is True</span></span>
<span id="cb13-10"><a href="#cb13-10" aria-hidden="true" tabindex="-1"></a>        <span class="st">"contiguous_memory_optimization"</span>: <span class="va">False</span>,</span>
<span id="cb13-11"><a href="#cb13-11" aria-hidden="true" tabindex="-1"></a>        <span class="st">"number_checkpoints"</span>: <span class="va">None</span>,</span>
<span id="cb13-12"><a href="#cb13-12" aria-hidden="true" tabindex="-1"></a>        <span class="st">"synchronize_checkpoint_boundary"</span>: <span class="va">False</span>,</span>
<span id="cb13-13"><a href="#cb13-13" aria-hidden="true" tabindex="-1"></a>        <span class="st">"profile"</span>: <span class="va">False</span></span>
<span id="cb13-14"><a href="#cb13-14" aria-hidden="true" tabindex="-1"></a>    },</span>
<span id="cb13-15"><a href="#cb13-15" aria-hidden="true" tabindex="-1"></a>    <span class="st">"fp16"</span>: {<span class="st">"enabled"</span>: <span class="va">True</span>}</span>
<span id="cb13-16"><a href="#cb13-16" aria-hidden="true" tabindex="-1"></a>}</span></code></pre></div></div>
</div>
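<p>This config section only tunes DeepSpeed's activation checkpointing; the model's layers still have to be wrapped (for example with <code>deepspeed.checkpointing.checkpoint</code>) for anything to be recomputed. To see why checkpointing saves memory, here is a back-of-the-envelope model, assuming every layer's activations cost one unit: keeping a checkpoint every <code>interval</code> layers stores about <code>n / interval</code> checkpoints plus one recomputed segment of <code>interval</code> layers, which is smallest near <code>interval ≈ sqrt(n)</code>. It is an illustrative estimate, not DeepSpeed's memory accounting.</p>

```python
import math

def activation_units(n_layers, interval=None):
    """Peak activation units for n_layers; interval=None means no checkpointing."""
    if interval is None:
        return n_layers  # every layer's activations are kept for backward
    checkpoints = math.ceil(n_layers / interval)  # saved segment inputs
    return checkpoints + interval  # plus one segment recomputed during backward

def best_interval(n_layers):
    """Interval that minimises the estimate above."""
    return min(range(1, n_layers + 1), key=lambda k: activation_units(n_layers, k))

n = 16
print(activation_units(n))                    # 16 without checkpointing
print(best_interval(n))                       # 4, i.e. sqrt(16)
print(activation_units(n, best_interval(n)))  # 8 with checkpointing
```

<p>The price is roughly one extra forward pass per step, since each segment is recomputed during the backward pass.</p>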
</section>
<section id="cpu-offloading" class="level3">
<h3 class="anchored" data-anchor-id="cpu-offloading" id="cpu-offloading">CPU Offloading</h3>
<div id="e3210b4d" class="cell" data-execution_count="11">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb14"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb14-1"><a href="#cb14-1" aria-hidden="true" tabindex="-1"></a>cpu_offload_config <span class="op">=</span> {</span>
<span id="cb14-2"><a href="#cb14-2" aria-hidden="true" tabindex="-1"></a>    <span class="st">"train_batch_size"</span>: <span class="dv">32</span>,</span>
<span id="cb14-3"><a href="#cb14-3" aria-hidden="true" tabindex="-1"></a>    <span class="st">"optimizer"</span>: {</span>
<span id="cb14-4"><a href="#cb14-4" aria-hidden="true" tabindex="-1"></a>        <span class="st">"type"</span>: <span class="st">"Adam"</span>,</span>
<span id="cb14-5"><a href="#cb14-5" aria-hidden="true" tabindex="-1"></a>        <span class="st">"params"</span>: {<span class="st">"lr"</span>: <span class="fl">0.001</span>}</span>
<span id="cb14-6"><a href="#cb14-6" aria-hidden="true" tabindex="-1"></a>    },</span>
<span id="cb14-7"><a href="#cb14-7" aria-hidden="true" tabindex="-1"></a>    <span class="st">"zero_optimization"</span>: {</span>
<span id="cb14-8"><a href="#cb14-8" aria-hidden="true" tabindex="-1"></a>        <span class="st">"stage"</span>: <span class="dv">3</span>,</span>
<span id="cb14-9"><a href="#cb14-9" aria-hidden="true" tabindex="-1"></a>        <span class="st">"offload_optimizer"</span>: {</span>
<span id="cb14-10"><a href="#cb14-10" aria-hidden="true" tabindex="-1"></a>            <span class="st">"device"</span>: <span class="st">"cpu"</span>,</span>
<span id="cb14-11"><a href="#cb14-11" aria-hidden="true" tabindex="-1"></a>            <span class="st">"pin_memory"</span>: <span class="va">True</span></span>
<span id="cb14-12"><a href="#cb14-12" aria-hidden="true" tabindex="-1"></a>        },</span>
<span id="cb14-13"><a href="#cb14-13" aria-hidden="true" tabindex="-1"></a>        <span class="st">"offload_param"</span>: {</span>
<span id="cb14-14"><a href="#cb14-14" aria-hidden="true" tabindex="-1"></a>            <span class="st">"device"</span>: <span class="st">"cpu"</span>,</span>
<span id="cb14-15"><a href="#cb14-15" aria-hidden="true" tabindex="-1"></a>            <span class="st">"pin_memory"</span>: <span class="va">True</span></span>
<span id="cb14-16"><a href="#cb14-16" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb14-17"><a href="#cb14-17" aria-hidden="true" tabindex="-1"></a>    },</span>
<span id="cb14-18"><a href="#cb14-18" aria-hidden="true" tabindex="-1"></a>    <span class="st">"fp16"</span>: {<span class="st">"enabled"</span>: <span class="va">True</span>}</span>
<span id="cb14-19"><a href="#cb14-19" aria-hidden="true" tabindex="-1"></a>}</span></code></pre></div></div>
</div>
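<p>To get a feel for what each ZeRO stage and optimizer offload save, the sketch below follows the ZeRO paper's accounting for mixed-precision Adam: 2 bytes per parameter for FP16 weights, 2 for FP16 gradients and 12 for FP32 optimizer state. It is a rough, hypothetical estimate that ignores activations, buffers, fragmentation and parameter offload (<code>offload_param</code>); only optimizer offload is modeled.</p>

```python
def model_state_gb_per_gpu(num_params, num_gpus, stage, offload_optimizer=False):
    """Rough per-GPU memory (GB) for model states under mixed-precision Adam.

    Counts 2 bytes/param of fp16 weights, 2 bytes/param of fp16 gradients and
    12 bytes/param of fp32 optimizer state (master weights, momentum, variance).
    """
    if stage not in (0, 1, 2, 3):
        raise ValueError("stage must be 0, 1, 2 or 3")
    params, grads, optim = 2 * num_params, 2 * num_params, 12 * num_params
    if stage >= 1:
        optim /= num_gpus   # ZeRO-1: partition optimizer states
    if stage >= 2:
        grads /= num_gpus   # ZeRO-2: also partition gradients
    if stage >= 3:
        params /= num_gpus  # ZeRO-3: also partition parameters
    if offload_optimizer:
        if stage == 0:
            raise ValueError("this sketch only models offload with ZeRO stages 1-3")
        optim = 0           # optimizer states are held in CPU memory instead
    return (params + grads + optim) / 1e9

# 7.5B parameters on 64 GPUs
for stage in range(4):
    print(f"ZeRO-{stage}: {model_state_gb_per_gpu(7.5e9, 64, stage):.1f} GB")
print(f"ZeRO-3 + optimizer offload: "
      f"{model_state_gb_per_gpu(7.5e9, 64, 3, offload_optimizer=True):.2f} GB")
```

<p>For 7.5B parameters on 64 GPUs this reproduces the ZeRO paper's worked example of roughly 120, 31.4, 16.6 and 1.9 GB per GPU for stages 0 to 3.</p>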
</section>
<section id="mixture-of-experts-moe" class="level3">
<h3 class="anchored" data-anchor-id="mixture-of-experts-moe" id="mixture-of-experts-moe">Mixture of Experts (MoE)</h3>
<div id="a9a6bdbc" class="cell" data-execution_count="12">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb15"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb15-1"><a href="#cb15-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> deepspeed.moe <span class="im">import</span> MoE</span>
<span id="cb15-2"><a href="#cb15-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-3"><a href="#cb15-3" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MoEModel(nn.Module):</span>
<span id="cb15-4"><a href="#cb15-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>):</span>
<span id="cb15-5"><a href="#cb15-5" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb15-6"><a href="#cb15-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.embedding <span class="op">=</span> nn.Embedding(<span class="dv">1000</span>, <span class="dv">512</span>)</span>
<span id="cb15-7"><a href="#cb15-7" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-8"><a href="#cb15-8" aria-hidden="true" tabindex="-1"></a>        <span class="co"># MoE layer</span></span>
<span id="cb15-9"><a href="#cb15-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.moe_layer <span class="op">=</span> MoE(</span>
<span id="cb15-10"><a href="#cb15-10" aria-hidden="true" tabindex="-1"></a>            hidden_size<span class="op">=</span><span class="dv">512</span>,</span>
<span id="cb15-11"><a href="#cb15-11" aria-hidden="true" tabindex="-1"></a>            expert<span class="op">=</span>nn.Sequential(</span>
<span id="cb15-12"><a href="#cb15-12" aria-hidden="true" tabindex="-1"></a>                nn.Linear(<span class="dv">512</span>, <span class="dv">2048</span>),</span>
<span id="cb15-13"><a href="#cb15-13" aria-hidden="true" tabindex="-1"></a>                nn.ReLU(),</span>
<span id="cb15-14"><a href="#cb15-14" aria-hidden="true" tabindex="-1"></a>                nn.Linear(<span class="dv">2048</span>, <span class="dv">512</span>)</span>
<span id="cb15-15"><a href="#cb15-15" aria-hidden="true" tabindex="-1"></a>            ),</span>
<span id="cb15-16"><a href="#cb15-16" aria-hidden="true" tabindex="-1"></a>            num_experts<span class="op">=</span><span class="dv">8</span>,</span>
<span id="cb15-17"><a href="#cb15-17" aria-hidden="true" tabindex="-1"></a>            k<span class="op">=</span><span class="dv">2</span>  <span class="co"># Top-k routing</span></span>
<span id="cb15-18"><a href="#cb15-18" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb15-19"><a href="#cb15-19" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-20"><a href="#cb15-20" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.output <span class="op">=</span> nn.Linear(<span class="dv">512</span>, <span class="dv">1000</span>)</span>
<span id="cb15-21"><a href="#cb15-21" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-22"><a href="#cb15-22" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb15-23"><a href="#cb15-23" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.embedding(x)</span>
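<span id="cb15-24"><a href="#cb15-24" aria-hidden="true" tabindex="-1"></a>        x, _, _ <span class="op">=</span> <span class="va">self</span>.moe_layer(x)  <span class="co"># Returns (output, l_aux, exp_counts); add l_aux to the loss to balance experts</span></span>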
<span id="cb15-25"><a href="#cb15-25" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.output(x)</span>
<span id="cb15-26"><a href="#cb15-26" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-27"><a href="#cb15-27" aria-hidden="true" tabindex="-1"></a>moe_config <span class="op">=</span> {</span>
<span id="cb15-28"><a href="#cb15-28" aria-hidden="true" tabindex="-1"></a>    <span class="st">"train_batch_size"</span>: <span class="dv">64</span>,</span>
<span id="cb15-29"><a href="#cb15-29" aria-hidden="true" tabindex="-1"></a>    <span class="st">"optimizer"</span>: {</span>
<span id="cb15-30"><a href="#cb15-30" aria-hidden="true" tabindex="-1"></a>        <span class="st">"type"</span>: <span class="st">"Adam"</span>,</span>
<span id="cb15-31"><a href="#cb15-31" aria-hidden="true" tabindex="-1"></a>        <span class="st">"params"</span>: {<span class="st">"lr"</span>: <span class="fl">0.001</span>}</span>
<span id="cb15-32"><a href="#cb15-32" aria-hidden="true" tabindex="-1"></a>    },</span>
<span id="cb15-33"><a href="#cb15-33" aria-hidden="true" tabindex="-1"></a>    <span class="st">"fp16"</span>: {<span class="st">"enabled"</span>: <span class="va">True</span>}</span>
<span id="cb15-34"><a href="#cb15-34" aria-hidden="true" tabindex="-1"></a>}</span></code></pre></div></div>
</div>
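<p>The <code>k=2</code> argument selects top-k routing: each token is sent to the two experts its gate scores highest, and their outputs are mixed by the renormalized gate probabilities. The sketch below shows only that routing step for a single token, with hypothetical function names and scalar "expert outputs"; DeepSpeed's gate additionally enforces expert capacity, may drop tokens, and produces the auxiliary load-balancing loss.</p>

```python
import math

def top_k_gating(logits, k=2):
    """Route one token: softmax over expert logits, keep the top-k, renormalise."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]  # numerically stable softmax
    total = sum(exps)
    probs = [e / total for e in exps]
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    kept = sum(probs[i] for i in top)
    return {i: probs[i] / kept for i in top}  # expert index -> combine weight

def combine(expert_outputs, weights):
    """Weighted sum of the selected experts' outputs (scalars here for brevity)."""
    return sum(w * expert_outputs[i] for i, w in weights.items())

# Expert 2's logit is ln(3) higher than expert 0's, so it gets ~3x the weight
weights = top_k_gating([2.0, 0.5, 2.0 + math.log(3), -1.0], k=2)
print(weights)                                 # experts 2 and 0, about 0.75 and 0.25
print(combine([10.0, 0.0, 20.0, 0.0], weights))  # about 0.75 * 20 + 0.25 * 10 = 17.5
```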
</section>
<section id="model-saving-and-loading" class="level3">
<h3 class="anchored" data-anchor-id="model-saving-and-loading" id="model-saving-and-loading">Model Saving and Loading</h3>
<div id="de017b4c" class="cell" data-execution_count="13">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb16"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb16-1"><a href="#cb16-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Saving model</span></span>
<span id="cb16-2"><a href="#cb16-2" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> save_model(model_engine, checkpoint_dir, client_state<span class="op">=</span><span class="va">None</span>):</span>
<span id="cb16-3"><a href="#cb16-3" aria-hidden="true" tabindex="-1"></a>    model_engine.save_checkpoint(checkpoint_dir, client_state<span class="op">=</span>client_state <span class="kw">or</span> {})  <span class="co"># Call on every rank, not only rank 0</span></span>
<span id="cb16-4"><a href="#cb16-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-5"><a href="#cb16-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Loading model</span></span>
<span id="cb16-6"><a href="#cb16-6" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> load_model(model_engine, checkpoint_dir):</span>
<span id="cb16-7"><a href="#cb16-7" aria-hidden="true" tabindex="-1"></a>    _, client_states <span class="op">=</span> model_engine.load_checkpoint(checkpoint_dir)  <span class="co"># Loads the tag recorded in the 'latest' file</span></span>
<span id="cb16-8"><a href="#cb16-8" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> client_states</span>
<span id="cb16-9"><a href="#cb16-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-10"><a href="#cb16-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage</span></span>
<span id="cb16-11"><a href="#cb16-11" aria-hidden="true" tabindex="-1"></a>checkpoint_dir <span class="op">=</span> <span class="st">"./checkpoints"</span></span>
<span id="cb16-12"><a href="#cb16-12" aria-hidden="true" tabindex="-1"></a>save_model(model_engine, checkpoint_dir)</span>
<span id="cb16-13"><a href="#cb16-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-14"><a href="#cb16-14" aria-hidden="true" tabindex="-1"></a><span class="co"># Later, load the model</span></span>
<span id="cb16-15"><a href="#cb16-15" aria-hidden="true" tabindex="-1"></a>client_states <span class="op">=</span> load_model(model_engine, checkpoint_dir)</span></code></pre></div></div>
</div>
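<p><code>save_checkpoint</code> writes each checkpoint into a tag subdirectory (by default named after the global step, such as <code>global_step1000</code>) and records the most recent tag in a <code>latest</code> file, which <code>load_checkpoint</code> reads when no <code>tag</code> is passed. The stdlib-only sketch below mimics just that bookkeeping with hypothetical helper names (<code>fake_save</code>, <code>resolve_tag</code>) so the layout is easy to see; it writes no model state.</p>

```python
import os
import tempfile

def fake_save(save_dir, global_step, tag=None):
    """Mimic the layout: <save_dir>/<tag>/ plus a 'latest' file naming the tag."""
    tag = tag or f"global_step{global_step}"
    os.makedirs(os.path.join(save_dir, tag), exist_ok=True)
    with open(os.path.join(save_dir, "latest"), "w") as fd:
        fd.write(tag)
    return tag

def resolve_tag(load_dir, tag=None):
    """Return the tag to load: the explicit one, else the contents of 'latest'."""
    if tag is not None:
        return tag
    latest = os.path.join(load_dir, "latest")
    if not os.path.isfile(latest):
        return None  # nothing has been saved yet
    with open(latest) as fd:
        return fd.read().strip()

with tempfile.TemporaryDirectory() as d:
    print(resolve_tag(d))         # None: no checkpoint yet
    fake_save(d, global_step=500)
    fake_save(d, global_step=1000)
    print(resolve_tag(d))         # global_step1000
    print(sorted(os.listdir(d)))  # ['global_step1000', 'global_step500', 'latest']
```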
</section>
</section>
<section id="troubleshooting" class="level2">
<h2 class="anchored" data-anchor-id="troubleshooting" id="troubleshooting">Troubleshooting</h2>
<section id="common-issues-and-solutions" class="level3">
<h3 class="anchored" data-anchor-id="common-issues-and-solutions" id="common-issues-and-solutions">Common Issues and Solutions</h3>
<div id="600ce879" class="cell" data-execution_count="14">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb17"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb17-1"><a href="#cb17-1" aria-hidden="true" tabindex="-1"></a><span class="co"># 1. Memory issues</span></span>
<span id="cb17-2"><a href="#cb17-2" aria-hidden="true" tabindex="-1"></a><span class="co"># Solution: reduce the micro-batch size, enable activation checkpointing, or use ZeRO with CPU offloading</span></span>
<span id="cb17-3"><a href="#cb17-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-4"><a href="#cb17-4" aria-hidden="true" tabindex="-1"></a><span class="co"># 2. Slow training</span></span>
<span id="cb17-5"><a href="#cb17-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Check communication overlap settings</span></span>
<span id="cb17-6"><a href="#cb17-6" aria-hidden="true" tabindex="-1"></a>overlap_config <span class="op">=</span> {</span>
<span id="cb17-7"><a href="#cb17-7" aria-hidden="true" tabindex="-1"></a>    <span class="st">"zero_optimization"</span>: {</span>
<span id="cb17-8"><a href="#cb17-8" aria-hidden="true" tabindex="-1"></a>        <span class="st">"stage"</span>: <span class="dv">2</span>,</span>
<span id="cb17-9"><a href="#cb17-9" aria-hidden="true" tabindex="-1"></a>        <span class="st">"overlap_comm"</span>: <span class="va">True</span>,</span>
<span id="cb17-10"><a href="#cb17-10" aria-hidden="true" tabindex="-1"></a>        <span class="st">"contiguous_gradients"</span>: <span class="va">True</span></span>
<span id="cb17-11"><a href="#cb17-11" aria-hidden="true" tabindex="-1"></a>    }</span>
<span id="cb17-12"><a href="#cb17-12" aria-hidden="true" tabindex="-1"></a>}</span>
<span id="cb17-13"><a href="#cb17-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-14"><a href="#cb17-14" aria-hidden="true" tabindex="-1"></a><span class="co"># 3. Gradient explosion</span></span>
<span id="cb17-15"><a href="#cb17-15" aria-hidden="true" tabindex="-1"></a><span class="co"># Enable gradient clipping</span></span>
<span id="cb17-16"><a href="#cb17-16" aria-hidden="true" tabindex="-1"></a>gradient_clip_config <span class="op">=</span> {</span>
<span id="cb17-17"><a href="#cb17-17" aria-hidden="true" tabindex="-1"></a>    <span class="st">"gradient_clipping"</span>: <span class="fl">1.0</span></span>
<span id="cb17-18"><a href="#cb17-18" aria-hidden="true" tabindex="-1"></a>}</span>
<span id="cb17-19"><a href="#cb17-19" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-20"><a href="#cb17-20" aria-hidden="true" tabindex="-1"></a><span class="co"># 4. Loss scaling issues with FP16</span></span>
<span id="cb17-21"><a href="#cb17-21" aria-hidden="true" tabindex="-1"></a><span class="co"># Use automatic loss scaling</span></span>
<span id="cb17-22"><a href="#cb17-22" aria-hidden="true" tabindex="-1"></a>auto_loss_scale_config <span class="op">=</span> {</span>
<span id="cb17-23"><a href="#cb17-23" aria-hidden="true" tabindex="-1"></a>    <span class="st">"fp16"</span>: {</span>
<span id="cb17-24"><a href="#cb17-24" aria-hidden="true" tabindex="-1"></a>        <span class="st">"enabled"</span>: <span class="va">True</span>,</span>
<span id="cb17-25"><a href="#cb17-25" aria-hidden="true" tabindex="-1"></a>        <span class="st">"loss_scale"</span>: <span class="dv">0</span>,  <span class="co"># 0 means automatic</span></span>
<span id="cb17-26"><a href="#cb17-26" aria-hidden="true" tabindex="-1"></a>        <span class="st">"initial_scale_power"</span>: <span class="dv">16</span></span>
<span id="cb17-27"><a href="#cb17-27" aria-hidden="true" tabindex="-1"></a>    }</span>
<span id="cb17-28"><a href="#cb17-28" aria-hidden="true" tabindex="-1"></a>}</span></code></pre></div></div>
</div>
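<p>The snippets above are partial configs meant to be folded into a full one. A small recursive merge, a hypothetical helper rather than a DeepSpeed utility, combines them so that nested sections such as <code>zero_optimization</code> and <code>fp16</code> are merged key by key instead of being overwritten wholesale.</p>

```python
import copy

def merge_configs(base, *overrides):
    """Recursively merge dict configs; later overrides win on conflicting leaves."""
    merged = copy.deepcopy(base)
    for override in overrides:
        for key, value in override.items():
            if isinstance(value, dict) and isinstance(merged.get(key), dict):
                merged[key] = merge_configs(merged[key], value)  # merge nested sections
            else:
                merged[key] = copy.deepcopy(value)  # plain values replace
    return merged

base = {"train_batch_size": 64, "fp16": {"enabled": True}, "zero_optimization": {"stage": 1}}
overlap = {"zero_optimization": {"stage": 2, "overlap_comm": True, "contiguous_gradients": True}}
clip = {"gradient_clipping": 1.0}
loss_scale = {"fp16": {"enabled": True, "loss_scale": 0, "initial_scale_power": 16}}

config = merge_configs(base, overlap, clip, loss_scale)
print(config)
```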
</section>
<section id="debugging-tools" class="level3">
<h3 class="anchored" data-anchor-id="debugging-tools" id="debugging-tools">Debugging Tools</h3>
<div id="dde2b0fd" class="cell" data-execution_count="15">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb18"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb18-1"><a href="#cb18-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb18-2"><a href="#cb18-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-3"><a href="#cb18-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb18-4"><a href="#cb18-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.distributed</span>
<span id="cb18-5"><a href="#cb18-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-6"><a href="#cb18-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Enable DeepSpeed's wall-clock and memory breakdowns (merge these keys into your DeepSpeed config)</span></span>
<span id="cb18-7"><a href="#cb18-7" aria-hidden="true" tabindex="-1"></a>profiling_config <span class="op">=</span> {</span>
<span id="cb18-8"><a href="#cb18-8" aria-hidden="true" tabindex="-1"></a>    <span class="st">"wall_clock_breakdown"</span>: <span class="va">True</span>,</span>
<span id="cb18-9"><a href="#cb18-9" aria-hidden="true" tabindex="-1"></a>    <span class="st">"memory_breakdown"</span>: <span class="va">True</span></span>
<span id="cb18-10"><a href="#cb18-10" aria-hidden="true" tabindex="-1"></a>}</span>
<span id="cb18-11"><a href="#cb18-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-12"><a href="#cb18-12" aria-hidden="true" tabindex="-1"></a><span class="co"># Memory monitoring</span></span>
<span id="cb18-13"><a href="#cb18-13" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> monitor_memory():</span>
<span id="cb18-14"><a href="#cb18-14" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> torch.cuda.is_available():</span>
<span id="cb18-15"><a href="#cb18-15" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"GPU Memory Allocated: </span><span class="sc">{</span>torch<span class="sc">.</span>cuda<span class="sc">.</span>memory_allocated() <span class="op">/</span> <span class="fl">1e9</span><span class="sc">:.2f}</span><span class="ss">GB"</span>)</span>
<span id="cb18-16"><a href="#cb18-16" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"GPU Memory Reserved: </span><span class="sc">{</span>torch<span class="sc">.</span>cuda<span class="sc">.</span>memory_reserved() <span class="op">/</span> <span class="fl">1e9</span><span class="sc">:.2f}</span><span class="ss">GB"</span>)</span>
<span id="cb18-17"><a href="#cb18-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-18"><a href="#cb18-18" aria-hidden="true" tabindex="-1"></a><span class="co"># Communication profiling: the barriers make the measured time include waiting for the</span></span>
<span id="cb18-19"><a href="#cb18-19" aria-hidden="true" tabindex="-1"></a><span class="co"># slowest rank, so a gap versus single-GPU step time points to communication or imbalance</span></span>
<span id="cb18-20"><a href="#cb18-20" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> profile_communication(step_fn):</span>
<span id="cb18-21"><a href="#cb18-21" aria-hidden="true" tabindex="-1"></a>    torch.distributed.barrier()  <span class="co"># Synchronize all processes</span></span>
<span id="cb18-22"><a href="#cb18-22" aria-hidden="true" tabindex="-1"></a>    start_time <span class="op">=</span> time.perf_counter()</span>
<span id="cb18-23"><a href="#cb18-23" aria-hidden="true" tabindex="-1"></a>    step_fn()  <span class="co"># Your training step here</span></span>
<span id="cb18-24"><a href="#cb18-24" aria-hidden="true" tabindex="-1"></a>    torch.cuda.synchronize()  <span class="co"># Wait for this rank's queued GPU work</span></span>
<span id="cb18-25"><a href="#cb18-25" aria-hidden="true" tabindex="-1"></a>    torch.distributed.barrier()</span>
<span id="cb18-26"><a href="#cb18-26" aria-hidden="true" tabindex="-1"></a>    end_time <span class="op">=</span> time.perf_counter()</span>
<span id="cb18-27"><a href="#cb18-27" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Step time: </span><span class="sc">{</span>end_time <span class="op">-</span> start_time<span class="sc">:.4f}</span><span class="ss">s"</span>)</span></code></pre></div></div>
</div>
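<p>A single timed step is noisy because the first iterations include one-off costs such as CUDA context creation and allocator warm-up. The helper below is a minimal, framework-agnostic sketch (the name <code>time_steps</code> is hypothetical) that discards warm-up iterations and reports the mean and population standard deviation of the remaining runs. On GPUs, <code>step_fn</code> should block until its CUDA work finishes, for example by calling <code>torch.cuda.synchronize()</code>, otherwise the timer mostly measures kernel launches.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import statistics
import time


def time_steps(step_fn, warmup=3, iters=10):
    """Run step_fn warmup + iters times; return (mean, pstdev) of the timed runs in seconds."""
    for _ in range(warmup):
        step_fn()  # warm-up iterations are not recorded
    durations = []
    for _ in range(iters):
        start = time.perf_counter()
        step_fn()
        durations.append(time.perf_counter() - start)
    return statistics.mean(durations), statistics.pstdev(durations)


# A CPU-only stand-in for a training step; replace it with your own step function
mean_s, stdev_s = time_steps(lambda: sum(range(100_000)), warmup=2, iters=5)
print(f"Step time: {mean_s:.4f}s +/- {stdev_s:.4f}s")
</code></pre></div></div>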
</section>
</section>
<section id="best-practices" class="level2">
<h2 class="anchored" data-anchor-id="best-practices" id="best-practices">Best Practices</h2>
<section id="batch-size-tuning" class="level3">
<h3 class="anchored" data-anchor-id="batch-size-tuning" id="batch-size-tuning">1. Batch Size Tuning</h3>
<div id="821aa955" class="cell" data-execution_count="16">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb19"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb19-1"><a href="#cb19-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> deepspeed</span>
<span id="cb19-2"><a href="#cb19-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb19-3"><a href="#cb19-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-4"><a href="#cb19-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Find the largest micro-batch that fits in GPU memory by doubling until OOM</span></span>
<span id="cb19-5"><a href="#cb19-5" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> find_optimal_batch_size(model, start_batch_size<span class="op">=</span><span class="dv">16</span>, input_dim<span class="op">=</span><span class="dv">1000</span>, target_batch<span class="op">=</span><span class="dv">64</span>):</span>
<span id="cb19-6"><a href="#cb19-6" aria-hidden="true" tabindex="-1"></a>    batch_size <span class="op">=</span> start_batch_size</span>
<span id="cb19-7"><a href="#cb19-7" aria-hidden="true" tabindex="-1"></a>    <span class="cf">while</span> <span class="va">True</span>:</span>
<span id="cb19-8"><a href="#cb19-8" aria-hidden="true" tabindex="-1"></a>        <span class="cf">try</span>:</span>
<span id="cb19-9"><a href="#cb19-9" aria-hidden="true" tabindex="-1"></a>            config <span class="op">=</span> {</span>
<span id="cb19-10"><a href="#cb19-10" aria-hidden="true" tabindex="-1"></a>                <span class="st">"train_micro_batch_size_per_gpu"</span>: batch_size,</span>
<span id="cb19-11"><a href="#cb19-11" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Keep the effective batch near target_batch; accumulation steps must be at least 1</span></span>
<span id="cb19-12"><a href="#cb19-12" aria-hidden="true" tabindex="-1"></a>                <span class="st">"gradient_accumulation_steps"</span>: <span class="bu">max</span>(<span class="dv">1</span>, target_batch <span class="op">//</span> batch_size),</span>
<span id="cb19-13"><a href="#cb19-13" aria-hidden="true" tabindex="-1"></a>                <span class="st">"optimizer"</span>: {<span class="st">"type"</span>: <span class="st">"Adam"</span>, <span class="st">"params"</span>: {<span class="st">"lr"</span>: <span class="fl">0.001</span>}},</span>
<span id="cb19-14"><a href="#cb19-14" aria-hidden="true" tabindex="-1"></a>                <span class="st">"fp16"</span>: {<span class="st">"enabled"</span>: <span class="va">True</span>}</span>
<span id="cb19-15"><a href="#cb19-15" aria-hidden="true" tabindex="-1"></a>            }</span>
<span id="cb19-16"><a href="#cb19-16" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb19-17"><a href="#cb19-17" aria-hidden="true" tabindex="-1"></a>            model_engine, _, _, _ <span class="op">=</span> deepspeed.initialize(</span>
<span id="cb19-18"><a href="#cb19-18" aria-hidden="true" tabindex="-1"></a>                model<span class="op">=</span>model, model_parameters<span class="op">=</span>model.parameters(), config<span class="op">=</span>config</span>
<span id="cb19-19"><a href="#cb19-19" aria-hidden="true" tabindex="-1"></a>            )</span>
<span id="cb19-20"><a href="#cb19-20" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb19-21"><a href="#cb19-21" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Test with dummy data (the fp16 engine expects half-precision inputs)</span></span>
<span id="cb19-22"><a href="#cb19-22" aria-hidden="true" tabindex="-1"></a>            dummy_input <span class="op">=</span> torch.randn(batch_size, input_dim, device<span class="op">=</span>model_engine.device).half()</span>
<span id="cb19-23"><a href="#cb19-23" aria-hidden="true" tabindex="-1"></a>            output <span class="op">=</span> model_engine(dummy_input)</span>
<span id="cb19-24"><a href="#cb19-24" aria-hidden="true" tabindex="-1"></a>            loss <span class="op">=</span> output.<span class="bu">float</span>().<span class="bu">sum</span>()</span>
<span id="cb19-25"><a href="#cb19-25" aria-hidden="true" tabindex="-1"></a>            model_engine.backward(loss)</span>
<span id="cb19-26"><a href="#cb19-26" aria-hidden="true" tabindex="-1"></a>            model_engine.step()</span>
<span id="cb19-27"><a href="#cb19-27" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb19-28"><a href="#cb19-28" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f"Batch size </span><span class="sc">{</span>batch_size<span class="sc">}</span><span class="ss"> works!"</span>)</span>
<span id="cb19-29"><a href="#cb19-29" aria-hidden="true" tabindex="-1"></a>            batch_size <span class="op">*=</span> <span class="dv">2</span></span>
<span id="cb19-30"><a href="#cb19-30" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb19-31"><a href="#cb19-31" aria-hidden="true" tabindex="-1"></a>        <span class="cf">except</span> <span class="pp">RuntimeError</span> <span class="im">as</span> e:</span>
<span id="cb19-32"><a href="#cb19-32" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="st">"out of memory"</span> <span class="kw">in</span> <span class="bu">str</span>(e):</span>
<span id="cb19-33"><a href="#cb19-33" aria-hidden="true" tabindex="-1"></a>                torch.cuda.empty_cache()  <span class="co"># Release cached blocks left by the failed attempt</span></span>
<span id="cb19-34"><a href="#cb19-34" aria-hidden="true" tabindex="-1"></a>                max_batch_size <span class="op">=</span> batch_size <span class="op">//</span> <span class="dv">2</span></span>
<span id="cb19-35"><a href="#cb19-35" aria-hidden="true" tabindex="-1"></a>                <span class="bu">print</span>(<span class="ss">f"Max batch size: </span><span class="sc">{</span>max_batch_size<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb19-36"><a href="#cb19-36" aria-hidden="true" tabindex="-1"></a>                <span class="cf">return</span> max_batch_size</span>
<span id="cb19-37"><a href="#cb19-37" aria-hidden="true" tabindex="-1"></a>            <span class="cf">raise</span></span></code></pre></div></div>
</div>
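<p>The search above only varies the micro-batch; the optimizer sees the global batch. DeepSpeed's configuration documentation defines <code>train_batch_size</code> as <code>train_micro_batch_size_per_gpu * gradient_accumulation_steps * world_size</code>, and initialization is expected to fail when the three disagree. The helper below is a hypothetical, plain-Python sketch (it does not call DeepSpeed) that derives consistent values for a target global batch.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">def make_batch_config(global_batch, micro_batch, world_size):
    """Build consistent DeepSpeed batch-size keys (hypothetical helper).

    Invariant: train_batch_size == micro_batch * grad_accum * world_size
    """
    per_pass = micro_batch * world_size  # samples processed by one forward/backward pass
    if global_batch % per_pass != 0:
        raise ValueError(
            f"global batch {global_batch} is not divisible by "
            f"micro_batch * world_size = {per_pass}"
        )
    return {
        "train_batch_size": global_batch,
        "train_micro_batch_size_per_gpu": micro_batch,
        "gradient_accumulation_steps": global_batch // per_pass,
    }


# 8 GPUs, 4 samples per GPU per pass, 256 samples per optimizer step
cfg = make_batch_config(global_batch=256, micro_batch=4, world_size=8)
print(cfg)
</code></pre></div></div>
<p>Here each optimizer step accumulates gradients over 8 passes of 32 samples. Failing fast on an indivisible global batch is a deliberate choice: silently rounding would change the effective batch size, and with it the learning-rate scaling discussed next.</p>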
</section>
<section id="learning-rate-scaling" class="level3">
<h3 class="anchored" data-anchor-id="learning-rate-scaling" id="learning-rate-scaling">2. Learning Rate Scaling</h3>
<div id="34ec293f" class="cell" data-execution_count="17">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb20"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb20-1"><a href="#cb20-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Linear scaling rule: grow the learning rate in proportion to the batch size.</span></span>
<span id="cb20-2"><a href="#cb20-2" aria-hidden="true" tabindex="-1"></a><span class="co"># It is usually paired with LR warmup; for Adam-style optimizers, square-root</span></span>
<span id="cb20-3"><a href="#cb20-3" aria-hidden="true" tabindex="-1"></a><span class="co"># scaling is a more conservative alternative.</span></span>
<span id="cb20-4"><a href="#cb20-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> scale_learning_rate(base_lr, base_batch_size, actual_batch_size):</span>
<span id="cb20-5"><a href="#cb20-5" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> base_lr <span class="op">*</span> (actual_batch_size <span class="op">/</span> base_batch_size)</span>
<span id="cb20-6"><a href="#cb20-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb20-7"><a href="#cb20-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Example: an LR tuned at batch size 64, reused at batch size 1024</span></span>
<span id="cb20-8"><a href="#cb20-8" aria-hidden="true" tabindex="-1"></a>base_config <span class="op">=</span> {</span>
<span id="cb20-9"><a href="#cb20-9" aria-hidden="true" tabindex="-1"></a>    <span class="st">"train_batch_size"</span>: <span class="dv">1024</span>,</span>
<span id="cb20-10"><a href="#cb20-10" aria-hidden="true" tabindex="-1"></a>    <span class="st">"optimizer"</span>: {</span>
<span id="cb20-11"><a href="#cb20-11" aria-hidden="true" tabindex="-1"></a>        <span class="st">"type"</span>: <span class="st">"Adam"</span>,</span>
<span id="cb20-12"><a href="#cb20-12" aria-hidden="true" tabindex="-1"></a>        <span class="st">"params"</span>: {</span>
<span id="cb20-13"><a href="#cb20-13" aria-hidden="true" tabindex="-1"></a>            <span class="st">"lr"</span>: scale_learning_rate(<span class="fl">3e-4</span>, <span class="dv">64</span>, <span class="dv">1024</span>)  <span class="co"># 3e-4 * 16 = 4.8e-3</span></span>
<span id="cb20-14"><a href="#cb20-14" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb20-15"><a href="#cb20-15" aria-hidden="true" tabindex="-1"></a>    }</span>
<span id="cb20-16"><a href="#cb20-16" aria-hidden="true" tabindex="-1"></a>}</span></code></pre></div></div>
</div>
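<p>Linear scaling can produce learning rates large enough to destabilize early training, which is why it is commonly combined with a warmup period. The sketch below compares linear and square-root scaling and adds a simple linear warmup; the function names are hypothetical and the schedule is deliberately minimal. DeepSpeed also offers a <code>WarmupLR</code> scheduler type in the <code>scheduler</code> section of its config; check its documentation for the exact parameters before relying on it.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import math


def scaled_lr(base_lr, base_batch_size, batch_size, rule="linear"):
    """Rescale a learning rate tuned at base_batch_size for a new batch_size."""
    ratio = batch_size / base_batch_size
    if rule == "linear":
        return base_lr * ratio
    if rule == "sqrt":
        return base_lr * math.sqrt(ratio)
    raise ValueError(f"unknown scaling rule: {rule}")


def warmup_lr(step, target_lr, warmup_steps):
    """Linear warmup: ramp to target_lr over warmup_steps (0-indexed steps), then hold."""
    return target_lr * min(step + 1, warmup_steps) / warmup_steps


lr = scaled_lr(3e-4, 64, 1024, rule="sqrt")
print(f"sqrt-scaled LR: {lr:.1e}")
print([warmup_lr(s, lr, warmup_steps=4) for s in range(6)])
</code></pre></div></div>
<p>With the same batch sizes as above, square-root scaling multiplies the base LR by 4 instead of 16, giving 1.2e-3 rather than 4.8e-3.</p>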
</section>
<section id="efficient-data-loading" class="level3">
<h3 class="anchored" data-anchor-id="efficient-data-loading" id="efficient-data-loading">3. Efficient Data Loading</h3>
<div id="69d502cc" class="cell" data-execution_count="18">
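<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb21"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb21-1"><a href="#cb21-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.distributed <span class="im">as</span> dist</span>
<span id="cb21-2"><a href="#cb21-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.utils.data <span class="im">import</span> DataLoader, DistributedSampler</span>
<span id="cb21-3"><a href="#cb21-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-4"><a href="#cb21-4" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> EfficientDataLoader:</span>
<span id="cb21-5"><a href="#cb21-5" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, dataset, batch_size, num_workers<span class="op">=</span><span class="dv">4</span>):</span>
<span id="cb21-6"><a href="#cb21-6" aria-hidden="true" tabindex="-1"></a>        <span class="co"># In distributed runs, give each rank a disjoint shard of the dataset</span></span>
<span id="cb21-7"><a href="#cb21-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.sampler <span class="op">=</span> DistributedSampler(dataset) <span class="cf">if</span> dist.is_available() <span class="kw">and</span> dist.is_initialized() <span class="cf">else</span> <span class="va">None</span></span>
<span id="cb21-8"><a href="#cb21-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.dataloader <span class="op">=</span> DataLoader(</span>
<span id="cb21-9"><a href="#cb21-9" aria-hidden="true" tabindex="-1"></a>            dataset,</span>
<span id="cb21-10"><a href="#cb21-10" aria-hidden="true" tabindex="-1"></a>            batch_size<span class="op">=</span>batch_size,</span>
<span id="cb21-11"><a href="#cb21-11" aria-hidden="true" tabindex="-1"></a>            shuffle<span class="op">=</span><span class="va">self</span>.sampler <span class="kw">is</span> <span class="va">None</span>,  <span class="co"># the sampler shuffles when present</span></span>
<span id="cb21-12"><a href="#cb21-12" aria-hidden="true" tabindex="-1"></a>            sampler<span class="op">=</span><span class="va">self</span>.sampler,</span>
<span id="cb21-13"><a href="#cb21-13" aria-hidden="true" tabindex="-1"></a>            num_workers<span class="op">=</span>num_workers,</span>
<span id="cb21-14"><a href="#cb21-14" aria-hidden="true" tabindex="-1"></a>            pin_memory<span class="op">=</span><span class="va">True</span>,  <span class="co"># page-locked memory enables async host-to-GPU copies</span></span>
<span id="cb21-15"><a href="#cb21-15" aria-hidden="true" tabindex="-1"></a>            persistent_workers<span class="op">=</span>num_workers <span class="op">&gt;</span> <span class="dv">0</span>  <span class="co"># keep workers alive across epochs (requires num_workers &gt; 0)</span></span>
<span id="cb21-16"><a href="#cb21-16" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb21-17"><a href="#cb21-17" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb21-18"><a href="#cb21-18" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> set_epoch(<span class="va">self</span>, epoch):</span>
<span id="cb21-19"><a href="#cb21-19" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Call once per epoch so DistributedSampler reshuffles differently each epoch</span></span>
<span id="cb21-20"><a href="#cb21-20" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="va">self</span>.sampler <span class="kw">is</span> <span class="kw">not</span> <span class="va">None</span>:</span>
<span id="cb21-21"><a href="#cb21-21" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.sampler.set_epoch(epoch)</span>
<span id="cb21-22"><a href="#cb21-22" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb21-23"><a href="#cb21-23" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__len__</span>(<span class="va">self</span>):</span>
<span id="cb21-24"><a href="#cb21-24" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="bu">len</span>(<span class="va">self</span>.dataloader)</span>
<span id="cb21-25"><a href="#cb21-25" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb21-26"><a href="#cb21-26" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__iter__</span>(<span class="va">self</span>):</span>
<span id="cb21-27"><a href="#cb21-27" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> batch <span class="kw">in</span> <span class="va">self</span>.dataloader:</span>
<span id="cb21-28"><a href="#cb21-28" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Move to GPU asynchronously (effective because pin_memory=True)</span></span>
<span id="cb21-29"><a href="#cb21-29" aria-hidden="true" tabindex="-1"></a>            batch <span class="op">=</span> [x.cuda(non_blocking<span class="op">=</span><span class="va">True</span>) <span class="cf">for</span> x <span class="kw">in</span> batch]</span>
<span id="cb21-30"><a href="#cb21-30" aria-hidden="true" tabindex="-1"></a>            <span class="cf">yield</span> batch</span></code></pre></div></div>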
</div>
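<p>When training on several GPUs, each rank should read a different slice of the data, which is what PyTorch's <code>DistributedSampler</code> provides. The function below is a simplified, pure-Python sketch of how it assigns indices when <code>shuffle=False</code> and <code>drop_last=False</code>: pad the index list by wrapping around until it divides evenly, then give rank <code>r</code> every <code>world_size</code>-th index starting at <code>r</code>. The name <code>shard_indices</code> is hypothetical, and the real sampler additionally shuffles with a per-epoch seed.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import math


def shard_indices(dataset_len, rank, world_size):
    """Indices `rank` would read, imitating DistributedSampler(shuffle=False, drop_last=False)."""
    num_samples = math.ceil(dataset_len / world_size)  # samples per rank
    total_size = num_samples * world_size
    indices = list(range(dataset_len))
    # Pad by wrapping around so every rank gets exactly num_samples indices
    indices = (indices * math.ceil(total_size / dataset_len))[:total_size]
    # Interleave: rank r takes positions r, r + world_size, r + 2 * world_size, ...
    return indices[rank:total_size:world_size]


for r in range(4):
    print(f"rank {r}: {shard_indices(10, rank=r, world_size=4)}")
</code></pre></div></div>
<p>With 10 samples on 4 ranks, each rank gets 3 indices, so samples 0 and 1 are read twice per epoch. That padding is harmless for training but worth remembering when aggregating evaluation metrics across ranks.</p>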
</section>
<section id="model-architecture-tips" class="level3">
<h3 class="anchored" data-anchor-id="model-architecture-tips" id="model-architecture-tips">4. Model Architecture Tips</h3>
<div id="4e94347b" class="cell" data-execution_count="19">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb22"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb22-1"><a href="#cb22-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb22-2"><a href="#cb22-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb22-3"><a href="#cb22-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.utils.checkpoint  <span class="co"># makes torch.utils.checkpoint available</span></span>
<span id="cb22-4"><a href="#cb22-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb22-5"><a href="#cb22-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Use activation checkpointing for large models: only each segment's input is stored;</span></span>
<span id="cb22-6"><a href="#cb22-6" aria-hidden="true" tabindex="-1"></a><span class="co"># activations inside a segment are recomputed during the backward pass</span></span>
<span id="cb22-7"><a href="#cb22-7" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> CheckpointedModel(nn.Module):</span>
<span id="cb22-8"><a href="#cb22-8" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>):</span>
<span id="cb22-9"><a href="#cb22-9" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb22-10"><a href="#cb22-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.layers <span class="op">=</span> nn.ModuleList([</span>
<span id="cb22-11"><a href="#cb22-11" aria-hidden="true" tabindex="-1"></a>            nn.Linear(<span class="dv">1000</span>, <span class="dv">1000</span>) <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">100</span>)</span>
<span id="cb22-12"><a href="#cb22-12" aria-hidden="true" tabindex="-1"></a>        ])</span>
<span id="cb22-13"><a href="#cb22-13" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb22-14"><a href="#cb22-14" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb22-15"><a href="#cb22-15" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Checkpoint every 10 layers</span></span>
<span id="cb22-16"><a href="#cb22-16" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">0</span>, <span class="bu">len</span>(<span class="va">self</span>.layers), <span class="dv">10</span>):</span>
<span id="cb22-17"><a href="#cb22-17" aria-hidden="true" tabindex="-1"></a>            <span class="kw">def</span> create_forward(start_idx):</span>
<span id="cb22-18"><a href="#cb22-18" aria-hidden="true" tabindex="-1"></a>                <span class="kw">def</span> forward_chunk(x):</span>
<span id="cb22-19"><a href="#cb22-19" aria-hidden="true" tabindex="-1"></a>                    <span class="cf">for</span> j <span class="kw">in</span> <span class="bu">range</span>(start_idx, <span class="bu">min</span>(start_idx <span class="op">+</span> <span class="dv">10</span>, <span class="bu">len</span>(<span class="va">self</span>.layers))):</span>
<span id="cb22-20"><a href="#cb22-20" aria-hidden="true" tabindex="-1"></a>                        x <span class="op">=</span> torch.relu(<span class="va">self</span>.layers[j](x))</span>
<span id="cb22-21"><a href="#cb22-21" aria-hidden="true" tabindex="-1"></a>                    <span class="cf">return</span> x</span>
<span id="cb22-22"><a href="#cb22-22" aria-hidden="true" tabindex="-1"></a>                <span class="cf">return</span> forward_chunk</span>
<span id="cb22-23"><a href="#cb22-23" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb22-24"><a href="#cb22-24" aria-hidden="true" tabindex="-1"></a>            <span class="co"># use_reentrant=False (recommended by PyTorch) keeps weight gradients even when x has no grad</span></span>
<span id="cb22-25"><a href="#cb22-25" aria-hidden="true" tabindex="-1"></a>            x <span class="op">=</span> torch.utils.checkpoint.checkpoint(create_forward(i), x, use_reentrant<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb22-26"><a href="#cb22-26" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x</span></code></pre></div></div>
</div>
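<p>Why checkpoint every 10 of 100 layers? A rough model helps: with segments of <code>k</code> layers, the forward pass keeps one stored input per segment (about <code>n / k</code> tensors), and the backward pass rematerializes one segment (about <code>k</code> tensors) at a time. The sketch below assumes every layer output has the same size and ignores parameters, gradients, and optimizer state; under those assumptions the total <code>n / k + k</code> is smallest near <code>k = sqrt(n)</code>, which is 10 for this model, at the cost of roughly one extra forward pass. The function name is hypothetical.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import math


def activation_memory_units(num_layers, segment_size):
    """Rough peak count of layer activations held with segment checkpointing.

    Assumes one unit per layer output: one stored input per segment, plus the
    activations of the single segment being recomputed during backward.
    """
    num_segments = math.ceil(num_layers / segment_size)
    return num_segments + segment_size


n = 100
for k in (1, 5, 10, 20, 50):
    print(f"segment size {k:3d}: ~{activation_memory_units(n, k)} activation units")
print("heuristic optimum k ~ sqrt(n) =", math.isqrt(n))
</code></pre></div></div>
<p>For <code>n = 100</code> this estimate gives 20 units at <code>k = 10</code>, versus 101 at <code>k = 1</code>, where checkpointing every layer saves nothing.</p>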
</section>
<section id="multi-node-training-script" class="level3">
<h3 class="anchored" data-anchor-id="multi-node-training-script" id="multi-node-training-script">5. Multi-Node Training Script</h3>
<div id="ce834fd7" class="cell" data-execution_count="20">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb23"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb23-1"><a href="#cb23-1" aria-hidden="true" tabindex="-1"></a><span class="co"># launch_script.py</span></span>
<span id="cb23-2"><a href="#cb23-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> subprocess</span>
<span id="cb23-3"><a href="#cb23-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> sys</span>
<span id="cb23-4"><a href="#cb23-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb23-5"><a href="#cb23-5" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> launch_distributed_training():</span>
<span id="cb23-6"><a href="#cb23-6" aria-hidden="true" tabindex="-1"></a>    cmd <span class="op">=</span> [</span>
<span id="cb23-7"><a href="#cb23-7" aria-hidden="true" tabindex="-1"></a>        <span class="st">"deepspeed"</span>,</span>
<span id="cb23-8"><a href="#cb23-8" aria-hidden="true" tabindex="-1"></a>        <span class="st">"--hostfile=hostfile"</span>,  <span class="co"># one "hostname slots=N" line per node</span></span>
<span id="cb23-9"><a href="#cb23-9" aria-hidden="true" tabindex="-1"></a>        <span class="st">"--num_gpus=8"</span>,  <span class="co"># GPUs per node</span></span>
<span id="cb23-10"><a href="#cb23-10" aria-hidden="true" tabindex="-1"></a>        <span class="st">"--num_nodes=4"</span>,</span>
<span id="cb23-11"><a href="#cb23-11" aria-hidden="true" tabindex="-1"></a>        <span class="st">"--master_addr=your_master_node"</span>,</span>
<span id="cb23-12"><a href="#cb23-12" aria-hidden="true" tabindex="-1"></a>        <span class="st">"--master_port=29500"</span>,</span>
<span id="cb23-13"><a href="#cb23-13" aria-hidden="true" tabindex="-1"></a>        <span class="st">"train.py"</span>,</span>
<span id="cb23-14"><a href="#cb23-14" aria-hidden="true" tabindex="-1"></a>        <span class="st">"--deepspeed_config=ds_config.json"</span>  <span class="co"># forwarded to train.py, which must parse it</span></span>
<span id="cb23-15"><a href="#cb23-15" aria-hidden="true" tabindex="-1"></a>    ]</span>
<span id="cb23-16"><a href="#cb23-16" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb23-17"><a href="#cb23-17" aria-hidden="true" tabindex="-1"></a>    result <span class="op">=</span> subprocess.run(cmd)</span>
<span id="cb23-18"><a href="#cb23-18" aria-hidden="true" tabindex="-1"></a>    sys.exit(result.returncode)  <span class="co"># propagate the launcher's exit status</span></span>
<span id="cb23-19"><a href="#cb23-19" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb23-20"><a href="#cb23-20" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb23-21"><a href="#cb23-21" aria-hidden="true" tabindex="-1"></a>    launch_distributed_training()</span></code></pre></div></div>
</div>
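<p>For multi-node runs, the DeepSpeed launcher reads a hostfile with one <code>hostname slots=N</code> line per node, where <code>N</code> is the number of GPUs on that node, and starts remote processes over passwordless SSH (through pdsh by default). The snippet below is a small sketch that writes such a file; the helper name and the <code>worker-*</code> hostnames are hypothetical placeholders for your own machines.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">def write_hostfile(path, hosts, slots_per_host):
    """Write one 'hostname slots=N' line per node and return the lines."""
    lines = [f"{host} slots={slots_per_host}" for host in hosts]
    with open(path, "w") as f:
        f.write("\n".join(lines) + "\n")
    return lines


# Hypothetical node names: each must be reachable from the launch node via passwordless SSH
hosts = ["worker-0", "worker-1", "worker-2", "worker-3"]
lines = write_hostfile("hostfile", hosts, slots_per_host=8)
print("\n".join(lines))
</code></pre></div></div>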
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>This guide covers the essential aspects of using DeepSpeed with PyTorch. Remember to experiment with different configurations based on your specific model architecture and hardware setup. Start with simpler configurations (ZeRO Stage 1-2) and gradually move to more advanced features (ZeRO Stage 3, CPU offloading) as needed.</p>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Getting Started
</div>
</div>
<div class="callout-body-container callout-body">
<ol type="1">
<li>Start with ZeRO Stage 1 or 2 for your first DeepSpeed experiments</li>
<li>Use FP16 mixed precision to reduce memory usage</li>
<li>Tune batch sizes to maximize GPU utilization</li>
<li>Monitor memory usage and communication overhead</li>
<li>Scale learning rates appropriately with batch size changes</li>
</ol>
</div>
</div>
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Warning</span>Common Pitfalls
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
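<li>Not setting <code>gradient_accumulation_steps</code> consistently with <code>train_batch_size</code>, which must equal the micro-batch size per GPU times the accumulation steps times the number of GPUs</li>
<li>Using batch sizes too large for GPU memory, which causes out-of-memory (OOM) errors</li>
<li>Not enabling communication overlap (<code>overlap_comm</code> in the <code>zero_optimization</code> section), which leaves GPUs idle while gradients are exchanged</li>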
<li>Forgetting to scale learning rates when changing batch sizes</li>
</ul>
</div>
</div>


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[PyTorch Model Deployment on Edge Devices - Complete Code Guide]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/deployment/edge-device-deployment/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/deployment/edge-device-deployment/</guid>
      <pubDate>Sat, 21 Jun 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>tutorial</category>
      <category>intermediate</category>
      <content:encoded><![CDATA[






<section id="pytorch-model-deployment-on-edge-devices---complete-code-guide" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/deployment/edge-device-deployment/edge.png" class="img-fluid"></p>
<section id="prerequisites" class="level2">
<h2 class="anchored" data-anchor-id="prerequisites" id="prerequisites">Prerequisites</h2>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Install required packages</span></span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install torch torchvision</span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install torch-model-archiver</span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install onnx onnxruntime</span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install tensorflow  <span class="co"># for TensorFlow Lite conversion</span></span></code></pre></div></div>
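<p>After installing, it can help to confirm which versions actually landed in the environment, since quantization and export behavior differ between PyTorch releases. A minimal sketch using the standard library's <code>importlib.metadata</code>; the helper name is hypothetical and the package list simply mirrors the pip commands above.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">from importlib import metadata


def installed_versions(packages):
    """Map each pip distribution name to its installed version, or None if missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


required = ["torch", "torchvision", "torch-model-archiver", "onnx", "onnxruntime", "tensorflow"]
for name, version in installed_versions(required).items():
    print(f"{name:22s} {version or 'NOT INSTALLED'}")
</code></pre></div></div>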
</section>
<section id="model-optimization" class="level2">
<h2 class="anchored" data-anchor-id="model-optimization" id="model-optimization">1. Model Optimization</h2>
<section id="quantization" class="level3">
<h3 class="anchored" data-anchor-id="quantization" id="quantization">Quantization</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> copy</span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.quantization</span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torchvision.models <span class="im">import</span> ResNet18_Weights</span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torchvision.models.quantization <span class="im">import</span> resnet18 <span class="im">as</span> quantizable_resnet18</span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a><span class="co"># Load a quantization-ready ResNet-18: it wraps the network in QuantStub/DeQuantStub and</span></span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a><span class="co"># uses quantization-friendly residual additions, which eager-mode static quantization needs</span></span>
<span id="cb2-10"><a href="#cb2-10" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> quantizable_resnet18(weights<span class="op">=</span>ResNet18_Weights.DEFAULT, quantize<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb2-11"><a href="#cb2-11" aria-hidden="true" tabindex="-1"></a>model.<span class="bu">eval</span>()</span>
<span id="cb2-12"><a href="#cb2-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-13"><a href="#cb2-13" aria-hidden="true" tabindex="-1"></a><span class="co"># Post-training static quantization (easiest method)</span></span>
<span id="cb2-14"><a href="#cb2-14" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> post_training_quantization(model, sample_data, backend<span class="op">=</span><span class="st">'fbgemm'</span>):</span>
<span id="cb2-15"><a href="#cb2-15" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb2-16"><a href="#cb2-16" aria-hidden="true" tabindex="-1"></a><span class="co">    Apply post-training static quantization to reduce model size.</span></span>
<span id="cb2-17"><a href="#cb2-17" aria-hidden="true" tabindex="-1"></a><span class="co">    Use backend='fbgemm' for x86 CPUs and backend='qnnpack' for ARM CPUs.</span></span>
<span id="cb2-18"><a href="#cb2-18" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb2-19"><a href="#cb2-19" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Work on a copy so the caller's float model is left untouched</span></span>
<span id="cb2-20"><a href="#cb2-20" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> copy.deepcopy(model)</span>
<span id="cb2-21"><a href="#cb2-21" aria-hidden="true" tabindex="-1"></a>    model.<span class="bu">eval</span>()</span>
<span id="cb2-22"><a href="#cb2-22" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-23"><a href="#cb2-23" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Fuse conv + bn + relu in every block (provided by torchvision's quantizable models)</span></span>
<span id="cb2-24"><a href="#cb2-24" aria-hidden="true" tabindex="-1"></a>    model.fuse_model()</span>
<span id="cb2-25"><a href="#cb2-25" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-26"><a href="#cb2-26" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Specify quantization configuration for the target backend</span></span>
<span id="cb2-27"><a href="#cb2-27" aria-hidden="true" tabindex="-1"></a>    torch.backends.quantized.engine <span class="op">=</span> backend</span>
<span id="cb2-28"><a href="#cb2-28" aria-hidden="true" tabindex="-1"></a>    model.qconfig <span class="op">=</span> torch.quantization.get_default_qconfig(backend)</span>
<span id="cb2-29"><a href="#cb2-29" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-30"><a href="#cb2-30" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Prepare the model for quantization (inserts observers that record value ranges)</span></span>
<span id="cb2-31"><a href="#cb2-31" aria-hidden="true" tabindex="-1"></a>    model_prepared <span class="op">=</span> torch.quantization.prepare(model)</span>
<span id="cb2-32"><a href="#cb2-32" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-33"><a href="#cb2-33" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Calibrate with sample data</span></span>
<span id="cb2-34"><a href="#cb2-34" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> torch.no_grad():</span>
<span id="cb2-35"><a href="#cb2-35" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> data <span class="kw">in</span> sample_data:</span>
<span id="cb2-36"><a href="#cb2-36" aria-hidden="true" tabindex="-1"></a>            model_prepared(data)</span>
<span id="cb2-37"><a href="#cb2-37" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-38"><a href="#cb2-38" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Convert to quantized model</span></span>
<span id="cb2-39"><a href="#cb2-39" aria-hidden="true" tabindex="-1"></a>    quantized_model <span class="op">=</span> torch.quantization.convert(model_prepared)</span>
<span id="cb2-40"><a href="#cb2-40" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-41"><a href="#cb2-41" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> quantized_model</span>
<span id="cb2-42"><a href="#cb2-42" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-43"><a href="#cb2-43" aria-hidden="true" tabindex="-1"></a><span class="co"># Example usage: random tensors only exercise the pipeline; calibrate with</span></span>
<span id="cb2-44"><a href="#cb2-44" aria-hidden="true" tabindex="-1"></a><span class="co"># representative, preprocessed real images to get meaningful activation ranges</span></span>
<span id="cb2-45"><a href="#cb2-45" aria-hidden="true" tabindex="-1"></a>sample_data <span class="op">=</span> [torch.randn(<span class="dv">1</span>, <span class="dv">3</span>, <span class="dv">224</span>, <span class="dv">224</span>) <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">100</span>)]</span>
<span id="cb2-46"><a href="#cb2-46" aria-hidden="true" tabindex="-1"></a>quantized_model <span class="op">=</span> post_training_quantization(model, sample_data)</span></code></pre></div></div>
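<p>Under the hood, static quantization maps each float value <code>x</code> to an 8-bit integer <code>q = round(x / scale) + zero_point</code>, clamped to the integer range, and maps it back with <code>(q - zero_point) * scale</code>; calibration exists to pick <code>scale</code> and <code>zero_point</code> from observed value ranges. The pure-Python sketch below shows that affine arithmetic for an unsigned 8-bit range using a plain min/max rule. PyTorch's observers choose ranges more carefully (the <code>fbgemm</code> default qconfig uses a histogram-based observer for activations, for example), so treat this as an illustration of the math only; the function names are hypothetical.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">def quantize_params(x_min, x_max, qmin=0, qmax=255):
    """Scale and zero point for asymmetric (affine) 8-bit quantization of [x_min, x_max]."""
    # Widen the range to include 0 so that zero is exactly representable
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
    scale = (x_max - x_min) / (qmax - qmin) or 1.0  # avoid division by zero for an all-zero range
    zero_point = int(round(qmin - x_min / scale))
    return scale, max(qmin, min(qmax, zero_point))


def quantize(x, scale, zero_point, qmin=0, qmax=255):
    q = int(round(x / scale)) + zero_point
    return max(qmin, min(qmax, q))  # clamp to the integer range


def dequantize(q, scale, zero_point):
    return (q - zero_point) * scale


# e.g. an activation range observed during calibration
scale, zp = quantize_params(-1.0, 3.0)
print(f"scale={scale:.6f}, zero_point={zp}")
for x in (-1.0, 0.0, 0.5, 3.0):
    q = quantize(x, scale, zp)
    print(x, q, round(dequantize(q, scale, zp), 4))
</code></pre></div></div>
<p>Every in-range value round-trips to within half a quantization step (<code>scale / 2</code>, about 0.008 here); that small error is the price of storing weights in 8 bits instead of 32.</p>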
</section>
<section id="pruning" class="level3">
<h3 class="anchored" data-anchor-id="pruning" id="pruning">Pruning</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> copy</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn.utils.prune <span class="im">as</span> prune</span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> prune_model(model, pruning_amount<span class="op">=</span><span class="fl">0.3</span>):</span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a><span class="co">    Apply global L1 (magnitude-based) unstructured pruning to Conv2d and Linear weights.</span></span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a><span class="co">    Pruned weights are set to zero but tensors stay dense, so file-size and speed gains</span></span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a><span class="co">    require sparse-aware kernels or compression of the saved weights.</span></span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a>    parameters_to_prune <span class="op">=</span> []</span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Collect all conv and linear layers</span></span>
<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> _, module <span class="kw">in</span> model.named_modules():</span>
<span id="cb3-15"><a href="#cb3-15" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="bu">isinstance</span>(module, (torch.nn.Conv2d, torch.nn.Linear)):</span>
<span id="cb3-16"><a href="#cb3-16" aria-hidden="true" tabindex="-1"></a>            parameters_to_prune.append((module, <span class="st">'weight'</span>))</span>
<span id="cb3-17"><a href="#cb3-17" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-18"><a href="#cb3-18" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Apply global magnitude pruning: one threshold is shared across all collected layers</span></span>
<span id="cb3-19"><a href="#cb3-19" aria-hidden="true" tabindex="-1"></a>    prune.global_unstructured(</span>
<span id="cb3-20"><a href="#cb3-20" aria-hidden="true" tabindex="-1"></a>        parameters_to_prune,</span>
<span id="cb3-21"><a href="#cb3-21" aria-hidden="true" tabindex="-1"></a>        pruning_method<span class="op">=</span>prune.L1Unstructured,</span>
<span id="cb3-22"><a href="#cb3-22" aria-hidden="true" tabindex="-1"></a>        amount<span class="op">=</span>pruning_amount,</span>
<span id="cb3-23"><a href="#cb3-23" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb3-24"><a href="#cb3-24" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-25"><a href="#cb3-25" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Remove pruning reparameterization to make pruning permanent</span></span>
<span id="cb3-26"><a href="#cb3-26" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> module, param <span class="kw">in</span> parameters_to_prune:</span>
<span id="cb3-27"><a href="#cb3-27" aria-hidden="true" tabindex="-1"></a>        prune.remove(module, param)</span>
<span id="cb3-28"><a href="#cb3-28" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-29"><a href="#cb3-29" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> model</span>
<span id="cb3-30"><a href="#cb3-30" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-31"><a href="#cb3-31" aria-hidden="true" tabindex="-1"></a><span class="co"># Apply pruning to a deep copy so the original model is left untouched</span></span>
<span id="cb3-32"><a href="#cb3-32" aria-hidden="true" tabindex="-1"></a>pruned_model <span class="op">=</span> prune_model(copy.deepcopy(model), pruning_amount<span class="op">=</span><span class="fl">0.3</span>)</span></code></pre></div></div>
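<p>Because <code>L1Unstructured</code> pruning only zeroes weights, a quick sanity check is to count the zeros afterwards. The sketch below applies the same <code>prune.global_unstructured</code> call to a small hypothetical stand-in model (<code>toy</code>, not AlexNet) and reports overall and per-layer sparsity; <code>weight_sparsity</code> is an illustrative helper, not a PyTorch API.</p>

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

def weight_sparsity(model):
    """Return (overall, per_layer) fraction of exactly-zero Conv2d/Linear weights."""
    per_layer, zeros, total = {}, 0, 0
    for name, module in model.named_modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            w = module.weight.detach()
            n_zero = int((w == 0).sum())
            per_layer[name] = n_zero / w.numel()
            zeros += n_zero
            total += w.numel()
    return zeros / total, per_layer

# Hypothetical small CNN used only to illustrate the measurement
torch.manual_seed(0)
toy = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 6 * 6, 10),
)

# Same global L1 pruning as prune_model, then make it permanent
params = [(m, "weight") for m in toy.modules() if isinstance(m, (nn.Conv2d, nn.Linear))]
prune.global_unstructured(params, pruning_method=prune.L1Unstructured, amount=0.3)
for m, p in params:
    prune.remove(m, p)

overall, per_layer = weight_sparsity(toy)
print(f"overall sparsity: {overall:.3f}")
for name, frac in per_layer.items():
    print(f"  layer {name}: {frac:.3f}")
```

<p>Global pruning ranks weights across all collected layers with a single magnitude threshold, so the overall sparsity lands on the 30% target while individual layers can end up more or less sparse than that.</p>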
</section>
</section>
<section id="model-conversion" class="level2">
<h2 class="anchored" data-anchor-id="model-conversion" id="model-conversion">2. Model Conversion</h2>
<section id="convert-to-torchscript" class="level3">
<h3 class="anchored" data-anchor-id="convert-to-torchscript" id="convert-to-torchscript">Convert to TorchScript</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> convert_to_torchscript(model, sample_input, save_path):</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a><span class="co">    Convert PyTorch model to TorchScript for deployment</span></span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a>    model.<span class="bu">eval</span>()</span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Method 1: Tracing records the ops run on sample_input (suits models without data-dependent control flow)</span></span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a>    <span class="cf">try</span>:</span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a>        traced_model <span class="op">=</span> torch.jit.trace(model, sample_input)</span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a>        traced_model.save(save_path)</span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Model traced and saved to </span><span class="sc">{</span>save_path<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> traced_model</span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a>    <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:</span>
<span id="cb4-14"><a href="#cb4-14" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Tracing failed: </span><span class="sc">{</span>e<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb4-15"><a href="#cb4-15" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-16"><a href="#cb4-16" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Method 2: Scripting (for models with control flow)</span></span>
<span id="cb4-17"><a href="#cb4-17" aria-hidden="true" tabindex="-1"></a>        <span class="cf">try</span>:</span>
<span id="cb4-18"><a href="#cb4-18" aria-hidden="true" tabindex="-1"></a>            scripted_model <span class="op">=</span> torch.jit.script(model)</span>
<span id="cb4-19"><a href="#cb4-19" aria-hidden="true" tabindex="-1"></a>            scripted_model.save(save_path)</span>
<span id="cb4-20"><a href="#cb4-20" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f"Model scripted and saved to </span><span class="sc">{</span>save_path<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb4-21"><a href="#cb4-21" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> scripted_model</span>
<span id="cb4-22"><a href="#cb4-22" aria-hidden="true" tabindex="-1"></a>        <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:</span>
<span id="cb4-23"><a href="#cb4-23" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f"Scripting also failed: </span><span class="sc">{</span>e<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb4-24"><a href="#cb4-24" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> <span class="va">None</span></span>
<span id="cb4-25"><a href="#cb4-25" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-26"><a href="#cb4-26" aria-hidden="true" tabindex="-1"></a><span class="co"># Example usage</span></span>
<span id="cb4-27"><a href="#cb4-27" aria-hidden="true" tabindex="-1"></a>sample_input <span class="op">=</span> torch.randn(<span class="dv">1</span>, <span class="dv">3</span>, <span class="dv">224</span>, <span class="dv">224</span>)</span>
<span id="cb4-28"><a href="#cb4-28" aria-hidden="true" tabindex="-1"></a>torchscript_model <span class="op">=</span> convert_to_torchscript(model, sample_input, <span class="st">"model.pt"</span>)</span></code></pre></div></div>
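<p>The tracing/scripting distinction matters only when <code>forward</code> branches on tensor values; a plain feed-forward CNN such as AlexNet has no such branches, so tracing is sufficient. For a model that does branch, tracing records just the path taken for the example input and typically emits a <code>TracerWarning</code> rather than raising, so the <code>except</code> fallback above may never run. A minimal illustration with a hypothetical <code>SignGate</code> module:</p>

```python
import warnings

import torch
import torch.nn as nn

class SignGate(nn.Module):
    """Hypothetical module with data-dependent control flow."""
    def forward(self, x):
        if x.sum() > 0:
            return x * 2
        return -x

gate = SignGate().eval()
positive = torch.ones(1, 3)
negative = -torch.ones(1, 3)

with warnings.catch_warnings():
    # Tracing warns here because the `if` converts a tensor to a Python bool
    warnings.simplefilter("ignore")
    traced = torch.jit.trace(gate, positive)  # records only the `x * 2` path

scripted = torch.jit.script(gate)  # compiles the Python source, keeping both branches

print("eager:   ", gate(negative))
print("traced:  ", traced(negative))    # still multiplies by 2 (wrong branch)
print("scripted:", scripted(negative))  # takes the `-x` branch, like eager
```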
</section>
<section id="convert-to-onnx" class="level3">
<h3 class="anchored" data-anchor-id="convert-to-onnx" id="convert-to-onnx">Convert to ONNX</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> onnx</span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> onnxruntime <span class="im">as</span> ort</span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> convert_to_onnx(model, sample_input, onnx_path):</span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a><span class="co">    Convert PyTorch model to ONNX format</span></span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a>    model.<span class="bu">eval</span>()</span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a>    torch.onnx.export(</span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a>        model,                      <span class="co"># model being run</span></span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a>        sample_input,               <span class="co"># model input</span></span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a>        onnx_path,                  <span class="co"># where to save the model</span></span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a>        export_params<span class="op">=</span><span class="va">True</span>,         <span class="co"># store the trained parameter weights</span></span>
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a>        opset_version<span class="op">=</span><span class="dv">11</span>,           <span class="co"># ONNX operator set (opset) version to target</span></span>
<span id="cb5-16"><a href="#cb5-16" aria-hidden="true" tabindex="-1"></a>        do_constant_folding<span class="op">=</span><span class="va">True</span>,   <span class="co"># precompute constant subexpressions at export time</span></span>
<span id="cb5-17"><a href="#cb5-17" aria-hidden="true" tabindex="-1"></a>        input_names<span class="op">=</span>[<span class="st">'input'</span>],      <span class="co"># model's input names</span></span>
<span id="cb5-18"><a href="#cb5-18" aria-hidden="true" tabindex="-1"></a>        output_names<span class="op">=</span>[<span class="st">'output'</span>],    <span class="co"># model's output names</span></span>
<span id="cb5-19"><a href="#cb5-19" aria-hidden="true" tabindex="-1"></a>        dynamic_axes<span class="op">=</span>{</span>
<span id="cb5-20"><a href="#cb5-20" aria-hidden="true" tabindex="-1"></a>            <span class="st">'input'</span>: {<span class="dv">0</span>: <span class="st">'batch_size'</span>},</span>
<span id="cb5-21"><a href="#cb5-21" aria-hidden="true" tabindex="-1"></a>            <span class="st">'output'</span>: {<span class="dv">0</span>: <span class="st">'batch_size'</span>}</span>
<span id="cb5-22"><a href="#cb5-22" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb5-23"><a href="#cb5-23" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb5-24"><a href="#cb5-24" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-25"><a href="#cb5-25" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Verify the ONNX model</span></span>
<span id="cb5-26"><a href="#cb5-26" aria-hidden="true" tabindex="-1"></a>    onnx_model <span class="op">=</span> onnx.load(onnx_path)</span>
<span id="cb5-27"><a href="#cb5-27" aria-hidden="true" tabindex="-1"></a>    onnx.checker.check_model(onnx_model)</span>
<span id="cb5-28"><a href="#cb5-28" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"ONNX model saved and verified at </span><span class="sc">{</span>onnx_path<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb5-29"><a href="#cb5-29" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-30"><a href="#cb5-30" aria-hidden="true" tabindex="-1"></a><span class="co"># Convert to ONNX</span></span>
<span id="cb5-31"><a href="#cb5-31" aria-hidden="true" tabindex="-1"></a>convert_to_onnx(model, sample_input, <span class="st">"model.onnx"</span>)</span>
<span id="cb5-32"><a href="#cb5-32" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-33"><a href="#cb5-33" aria-hidden="true" tabindex="-1"></a><span class="co"># Test ONNX Runtime inference</span></span>
<span id="cb5-34"><a href="#cb5-34" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> test_onnx_inference(onnx_path, sample_input):</span>
<span id="cb5-35"><a href="#cb5-35" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Test ONNX model inference"""</span></span>
<span id="cb5-36"><a href="#cb5-36" aria-hidden="true" tabindex="-1"></a>    ort_session <span class="op">=</span> ort.InferenceSession(onnx_path)</span>
<span id="cb5-37"><a href="#cb5-37" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-38"><a href="#cb5-38" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Convert input to numpy</span></span>
<span id="cb5-39"><a href="#cb5-39" aria-hidden="true" tabindex="-1"></a>    input_np <span class="op">=</span> sample_input.numpy()</span>
<span id="cb5-40"><a href="#cb5-40" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-41"><a href="#cb5-41" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Run inference</span></span>
<span id="cb5-42"><a href="#cb5-42" aria-hidden="true" tabindex="-1"></a>    outputs <span class="op">=</span> ort_session.run(<span class="va">None</span>, {<span class="st">'input'</span>: input_np})</span>
<span id="cb5-43"><a href="#cb5-43" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> outputs[<span class="dv">0</span>]</span>
<span id="cb5-44"><a href="#cb5-44" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-45"><a href="#cb5-45" aria-hidden="true" tabindex="-1"></a><span class="co"># Test the converted model</span></span>
<span id="cb5-46"><a href="#cb5-46" aria-hidden="true" tabindex="-1"></a>onnx_output <span class="op">=</span> test_onnx_inference(<span class="st">"model.onnx"</span>, sample_input)</span></code></pre></div></div>
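<p><code>test_onnx_inference</code> returns the ONNX Runtime output but does not check it against PyTorch. A minimal parity check, assuming both outputs are <code>(batch, num_classes)</code> logit arrays; <code>compare_outputs</code> and its tolerances are illustrative choices, not values from a specification.</p>

```python
import numpy as np

def compare_outputs(reference, candidate, rtol=1e-3, atol=1e-5):
    """Compare two logit arrays of shape (batch, num_classes).

    Returns (max_abs_diff, within_tolerance, top1_agreement), where
    top1_agreement is the fraction of rows whose argmax class matches.
    """
    reference = np.asarray(reference, dtype=np.float32)
    candidate = np.asarray(candidate, dtype=np.float32)
    if reference.shape != candidate.shape:
        raise ValueError(f"shape mismatch: {reference.shape} vs {candidate.shape}")
    max_abs_diff = float(np.max(np.abs(reference - candidate)))
    within_tol = bool(np.allclose(candidate, reference, rtol=rtol, atol=atol))
    top1 = float(np.mean(reference.argmax(axis=1) == candidate.argmax(axis=1)))
    return max_abs_diff, within_tol, top1

# Usage with the objects defined above (not executed here):
# with torch.no_grad():
#     torch_output = model(sample_input).numpy()
# max_diff, ok, top1 = compare_outputs(torch_output, onnx_output)
```

<p>Checking top-1 agreement alongside the element-wise tolerance helps separate harmless floating-point drift from conversion errors that actually change predictions.</p>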
</section>
<section id="convert-to-tensorflow-lite" class="level3">
<h3 class="anchored" data-anchor-id="convert-to-tensorflow-lite" id="convert-to-tensorflow-lite">Convert to TensorFlow Lite</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> tensorflow <span class="im">as</span> tf</span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> pytorch_to_tflite(onnx_path, tflite_path):</span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a><span class="co">    Convert ONNX model to TensorFlow Lite</span></span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Convert ONNX to a TensorFlow SavedModel (requires the onnx-tf package)</span></span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a>    <span class="im">from</span> onnx_tf.backend <span class="im">import</span> prepare</span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a>    <span class="im">import</span> onnx</span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a>    onnx_model <span class="op">=</span> onnx.load(onnx_path)</span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a>    tf_rep <span class="op">=</span> prepare(onnx_model)</span>
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a>    tf_rep.export_graph(<span class="st">"temp_tf_model"</span>)</span>
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-15"><a href="#cb6-15" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Convert to TensorFlow Lite</span></span>
<span id="cb6-16"><a href="#cb6-16" aria-hidden="true" tabindex="-1"></a>    converter <span class="op">=</span> tf.lite.TFLiteConverter.from_saved_model(<span class="st">"temp_tf_model"</span>)</span>
<span id="cb6-17"><a href="#cb6-17" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-18"><a href="#cb6-18" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Default optimizations: without a representative dataset this applies dynamic-range quantization (int8 weights)</span></span>
<span id="cb6-19"><a href="#cb6-19" aria-hidden="true" tabindex="-1"></a>    converter.optimizations <span class="op">=</span> [tf.lite.Optimize.DEFAULT]</span>
<span id="cb6-20"><a href="#cb6-20" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-21"><a href="#cb6-21" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Convert model</span></span>
<span id="cb6-22"><a href="#cb6-22" aria-hidden="true" tabindex="-1"></a>    tflite_model <span class="op">=</span> converter.convert()</span>
<span id="cb6-23"><a href="#cb6-23" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-24"><a href="#cb6-24" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Save the model</span></span>
<span id="cb6-25"><a href="#cb6-25" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> <span class="bu">open</span>(tflite_path, <span class="st">'wb'</span>) <span class="im">as</span> f:</span>
<span id="cb6-26"><a href="#cb6-26" aria-hidden="true" tabindex="-1"></a>        f.write(tflite_model)</span>
<span id="cb6-27"><a href="#cb6-27" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-28"><a href="#cb6-28" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"TensorFlow Lite model saved to </span><span class="sc">{</span>tflite_path<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb6-29"><a href="#cb6-29" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-30"><a href="#cb6-30" aria-hidden="true" tabindex="-1"></a><span class="co"># Convert to TensorFlow Lite</span></span>
<span id="cb6-31"><a href="#cb6-31" aria-hidden="true" tabindex="-1"></a>pytorch_to_tflite(<span class="st">"model.onnx"</span>, <span class="st">"model.tflite"</span>)</span></code></pre></div></div>
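<p>With <code>tf.lite.Optimize.DEFAULT</code> and no representative dataset, the converter stores weights as 8-bit integers (dynamic-range quantization). The numpy sketch below shows the underlying arithmetic using a simplified per-tensor symmetric scheme; this is an assumption for illustration, and the TFLite converter's exact scheme (for example, per-channel scales for some layers) may differ.</p>

```python
import numpy as np

def quantize_symmetric_int8(weights):
    """Per-tensor symmetric int8 quantization (simplified illustration)."""
    max_abs = float(np.max(np.abs(weights)))
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    q = np.clip(np.round(weights / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    """Map int8 codes back to approximate float32 weights."""
    return q.astype(np.float32) * scale

rng = np.random.default_rng(0)
w = rng.normal(0.0, 0.05, size=(256, 128)).astype(np.float32)  # hypothetical layer weights

q, scale = quantize_symmetric_int8(w)
w_hat = dequantize(q, scale)

print(f"scale: {scale:.6f}")
print(f"max abs error: {np.max(np.abs(w - w_hat)):.6f} (bound: scale / 2 = {scale / 2:.6f})")
print(f"storage: {w.nbytes} bytes as float32 vs {q.nbytes} bytes as int8")
```

<p>Each weight is recovered to within half a quantization step, and the int8 copy takes a quarter of the float32 storage, which is where most of the file-size reduction from this optimization comes from.</p>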
</section>
</section>
<section id="mobile-deployment" class="level2">
<h2 class="anchored" data-anchor-id="mobile-deployment" id="mobile-deployment">3. Mobile Deployment</h2>
<section id="android-deployment" class="level3">
<h3 class="anchored" data-anchor-id="android-deployment" id="android-deployment">Android Deployment</h3>
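<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb7"><pre class="sourceCode java code-with-copy"><code class="sourceCode java"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="co">// Android Java code for PyTorch Mobile (lite interpreter)</span></span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> <span class="im">android.graphics.Bitmap</span><span class="op">;</span></span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> <span class="im">org.pytorch.IValue</span><span class="op">;</span></span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> <span class="im">org.pytorch.LiteModuleLoader</span><span class="op">;</span></span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> <span class="im">org.pytorch.Module</span><span class="op">;</span></span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> <span class="im">org.pytorch.Tensor</span><span class="op">;</span></span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> <span class="im">org.pytorch.torchvision.TensorImageUtils</span><span class="op">;</span></span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-10"><a href="#cb7-10" aria-hidden="true" tabindex="-1"></a><span class="kw">public</span> <span class="kw">class</span> ModelInference <span class="op">{</span></span>
<span id="cb7-11"><a href="#cb7-11" aria-hidden="true" tabindex="-1"></a>    <span class="kw">private</span> <span class="kw">final</span> Module model<span class="op">;</span></span>
<span id="cb7-12"><a href="#cb7-12" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-13"><a href="#cb7-13" aria-hidden="true" tabindex="-1"></a>    <span class="co">// modelPath must be an absolute file path to a .ptl file saved with</span></span>
<span id="cb7-14"><a href="#cb7-14" aria-hidden="true" tabindex="-1"></a>    <span class="co">// _save_for_lite_interpreter (e.g. copied from the APK assets to internal storage)</span></span>
<span id="cb7-15"><a href="#cb7-15" aria-hidden="true" tabindex="-1"></a>    <span class="kw">public</span> <span class="fu">ModelInference</span><span class="op">(</span><span class="bu">String</span> modelPath<span class="op">)</span> <span class="op">{</span></span>
<span id="cb7-16"><a href="#cb7-16" aria-hidden="true" tabindex="-1"></a>        model <span class="op">=</span> LiteModuleLoader<span class="op">.</span><span class="fu">load</span><span class="op">(</span>modelPath<span class="op">);</span></span>
<span id="cb7-17"><a href="#cb7-17" aria-hidden="true" tabindex="-1"></a>    <span class="op">}</span></span>
<span id="cb7-18"><a href="#cb7-18" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-19"><a href="#cb7-19" aria-hidden="true" tabindex="-1"></a>    <span class="kw">public</span> <span class="dt">float</span><span class="op">[]</span> <span class="fu">predict</span><span class="op">(</span>Bitmap bitmap<span class="op">)</span> <span class="op">{</span></span>
<span id="cb7-20"><a href="#cb7-20" aria-hidden="true" tabindex="-1"></a>        <span class="co">// Resize to the 224x224 input size used when the model was traced</span></span>
<span id="cb7-21"><a href="#cb7-21" aria-hidden="true" tabindex="-1"></a>        Bitmap resized <span class="op">=</span> Bitmap<span class="op">.</span><span class="fu">createScaledBitmap</span><span class="op">(</span>bitmap<span class="op">,</span> <span class="dv">224</span><span class="op">,</span> <span class="dv">224</span><span class="op">,</span> <span class="kw">true</span><span class="op">);</span></span>
<span id="cb7-22"><a href="#cb7-22" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-23"><a href="#cb7-23" aria-hidden="true" tabindex="-1"></a>        <span class="co">// Convert to a 1x3x224x224 float tensor normalized with ImageNet mean/std</span></span>
<span id="cb7-24"><a href="#cb7-24" aria-hidden="true" tabindex="-1"></a>        Tensor inputTensor <span class="op">=</span> TensorImageUtils<span class="op">.</span><span class="fu">bitmapToFloat32Tensor</span><span class="op">(</span></span>
<span id="cb7-25"><a href="#cb7-25" aria-hidden="true" tabindex="-1"></a>            resized<span class="op">,</span></span>
<span id="cb7-26"><a href="#cb7-26" aria-hidden="true" tabindex="-1"></a>            TensorImageUtils<span class="op">.</span><span class="fu">TORCHVISION_NORM_MEAN_RGB</span><span class="op">,</span></span>
<span id="cb7-27"><a href="#cb7-27" aria-hidden="true" tabindex="-1"></a>            TensorImageUtils<span class="op">.</span><span class="fu">TORCHVISION_NORM_STD_RGB</span></span>
<span id="cb7-28"><a href="#cb7-28" aria-hidden="true" tabindex="-1"></a>        <span class="op">);</span></span>
<span id="cb7-29"><a href="#cb7-29" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-30"><a href="#cb7-30" aria-hidden="true" tabindex="-1"></a>        <span class="co">// Run inference</span></span>
<span id="cb7-31"><a href="#cb7-31" aria-hidden="true" tabindex="-1"></a>        Tensor outputTensor <span class="op">=</span> model<span class="op">.</span><span class="fu">forward</span><span class="op">(</span>IValue<span class="op">.</span><span class="fu">from</span><span class="op">(</span>inputTensor<span class="op">)).</span><span class="fu">toTensor</span><span class="op">();</span></span>
<span id="cb7-32"><a href="#cb7-32" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-33"><a href="#cb7-33" aria-hidden="true" tabindex="-1"></a>        <span class="co">// Get results (raw outputs, i.e. logits for a standard classifier head)</span></span>
<span id="cb7-34"><a href="#cb7-34" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> outputTensor<span class="op">.</span><span class="fu">getDataAsFloatArray</span><span class="op">();</span></span>
<span id="cb7-35"><a href="#cb7-35" aria-hidden="true" tabindex="-1"></a>    <span class="op">}</span></span>
<span id="cb7-36"><a href="#cb7-36" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span></code></pre></div></div>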
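<p>Mismatched preprocessing is a common source of accuracy loss on device. <code>TORCHVISION_NORM_MEAN_RGB</code> and <code>TORCHVISION_NORM_STD_RGB</code> hold the standard ImageNet statistics, the same values passed to <code>transforms.Normalize</code> in the Raspberry Pi script. A numpy sketch of the per-pixel arithmetic, assuming an 8-bit RGB image already resized to the model's input size (<code>normalize_rgb_uint8</code> is a hypothetical helper), gives reference values to compare against the tensor produced on Android:</p>

```python
import numpy as np

# ImageNet statistics used by torchvision and by PyTorch Android's TORCHVISION_NORM_* constants
MEAN_RGB = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD_RGB = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def normalize_rgb_uint8(image_hwc):
    """Map an (H, W, 3) uint8 RGB image to a (1, 3, H, W) float32 array.

    Mirrors ToTensor() + Normalize(mean, std): scale to [0, 1], subtract the
    per-channel mean, divide by the per-channel std, and move channels first.
    """
    x = image_hwc.astype(np.float32) / 255.0
    x = (x - MEAN_RGB) / STD_RGB  # broadcasts over the last (channel) axis
    return np.transpose(x, (2, 0, 1))[np.newaxis, ...]

# A 2x2 all-white image: every channel maps to (1 - mean) / std
white = np.full((2, 2, 3), 255, dtype=np.uint8)
out = normalize_rgb_uint8(white)
print(out.shape)
print(out[0, :, 0, 0])
```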
</section>
<section id="ios-deployment-swift" class="level3">
<h3 class="anchored" data-anchor-id="ios-deployment-swift" id="ios-deployment-swift">iOS Deployment (Swift)</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb8"><pre class="sourceCode swift code-with-copy"><code class="sourceCode swift"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="co">// iOS Swift code for PyTorch Mobile (LibTorch-Lite)</span></span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a><span class="co">// Assumption: as in PyTorch's iOS HelloWorld demo, TorchModule is an Objective-C++</span></span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a><span class="co">// wrapper class written in the app and exposed through a bridging header; it is not</span></span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a><span class="co">// part of LibTorch. It is assumed to provide init?(fileAtPath:) and predict(image:) -&gt; [NSNumber]?</span></span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a><span class="kw">import</span> <span class="im">UIKit</span></span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> ModelInference <span class="op">{</span></span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a>    <span class="kw">private</span> <span class="kw">let</span> <span class="va">model</span><span class="op">:</span> TorchModule</span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a>    <span class="kw">init</span><span class="op">?(</span>modelPath<span class="op">:</span> String<span class="op">)</span> <span class="op">{</span></span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a>        <span class="cf">guard</span> <span class="kw">let</span> <span class="va">module</span> <span class="op">=</span> TorchModule<span class="op">(</span>fileAtPath<span class="op">:</span> modelPath<span class="op">)</span> <span class="cf">else</span> <span class="op">{</span> <span class="kw">return</span> <span class="kw">nil</span> <span class="op">}</span></span>
<span id="cb8-12"><a href="#cb8-12" aria-hidden="true" tabindex="-1"></a>        model <span class="op">=</span> module</span>
<span id="cb8-13"><a href="#cb8-13" aria-hidden="true" tabindex="-1"></a>    <span class="op">}</span></span>
<span id="cb8-14"><a href="#cb8-14" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-15"><a href="#cb8-15" aria-hidden="true" tabindex="-1"></a>    <span class="kw">func</span> <span class="fu">predict</span><span class="op">(</span><span class="va">image</span><span class="op">:</span> <span class="dt">UIImage</span><span class="op">)</span> -&gt; [<span class="fu">Float</span>] <span class="op">{</span></span>
<span id="cb8-16"><a href="#cb8-16" aria-hidden="true" tabindex="-1"></a>        <span class="co">// Preprocess: resized(to:) and normalized() are app-defined UIImage helpers</span></span>
<span id="cb8-17"><a href="#cb8-17" aria-hidden="true" tabindex="-1"></a>        <span class="co">// (assumed), returning 224x224 CHW floats normalized with ImageNet mean/std</span></span>
<span id="cb8-18"><a href="#cb8-18" aria-hidden="true" tabindex="-1"></a>        <span class="kw">let</span> <span class="va">resized</span> <span class="op">=</span> image<span class="op">.</span>resized<span class="op">(</span>to<span class="op">:</span> CGSize<span class="op">(</span>width<span class="op">:</span> <span class="dv">224</span><span class="op">,</span> height<span class="op">:</span> <span class="dv">224</span><span class="op">))</span></span>
<span id="cb8-19"><a href="#cb8-19" aria-hidden="true" tabindex="-1"></a>        <span class="cf">guard</span> <span class="kw">var</span> <span class="va">pixelBuffer</span> <span class="op">=</span> resized<span class="op">.</span>normalized<span class="op">()</span> <span class="cf">else</span> <span class="op">{</span> <span class="kw">return</span> <span class="op">[]</span> <span class="op">}</span></span>
<span id="cb8-20"><a href="#cb8-20" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-21"><a href="#cb8-21" aria-hidden="true" tabindex="-1"></a>        <span class="co">// Run inference, passing the float buffer to the wrapper as a raw pointer</span></span>
<span id="cb8-22"><a href="#cb8-22" aria-hidden="true" tabindex="-1"></a>        <span class="kw">let</span> <span class="va">outputs</span> <span class="op">=</span> pixelBuffer<span class="op">.</span>withUnsafeMutableBytes <span class="op">{</span> buffer <span class="kw">in</span></span>
<span id="cb8-23"><a href="#cb8-23" aria-hidden="true" tabindex="-1"></a>            model<span class="op">.</span>predict<span class="op">(</span>image<span class="op">:</span> buffer<span class="op">.</span>baseAddress<span class="op">!)</span></span>
<span id="cb8-24"><a href="#cb8-24" aria-hidden="true" tabindex="-1"></a>        <span class="op">}</span></span>
<span id="cb8-25"><a href="#cb8-25" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-26"><a href="#cb8-26" aria-hidden="true" tabindex="-1"></a>        <span class="co">// Get results</span></span>
<span id="cb8-27"><a href="#cb8-27" aria-hidden="true" tabindex="-1"></a>        <span class="kw">return</span> outputs<span class="op">?.</span>map <span class="op">{</span> <span class="va">$0</span><span class="op">.</span>floatValue <span class="op">}</span> <span class="op">??</span> <span class="op">[]</span></span>
<span id="cb8-28"><a href="#cb8-28" aria-hidden="true" tabindex="-1"></a>    <span class="op">}</span></span>
<span id="cb8-29"><a href="#cb8-29" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span></code></pre></div></div>
</section>
<section id="python-mobile-preprocessing" class="level3">
<h3 class="anchored" data-anchor-id="python-mobile-preprocessing" id="python-mobile-preprocessing">Python Mobile Preprocessing</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb9"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.utils.mobile_optimizer <span class="im">import</span> optimize_for_mobile</span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> create_mobile_model(model, sample_input):</span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a><span class="co">    Create optimized model for mobile deployment</span></span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a>    model.<span class="bu">eval</span>()</span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Convert to TorchScript</span></span>
<span id="cb9-11"><a href="#cb9-11" aria-hidden="true" tabindex="-1"></a>    traced_model <span class="op">=</span> torch.jit.trace(model, sample_input)</span>
<span id="cb9-12"><a href="#cb9-12" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-13"><a href="#cb9-13" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Optimize for mobile (passes include Conv-BatchNorm fusion and dropout removal)</span></span>
<span id="cb9-14"><a href="#cb9-14" aria-hidden="true" tabindex="-1"></a>    optimized_model <span class="op">=</span> optimize_for_mobile(traced_model)</span>
<span id="cb9-15"><a href="#cb9-15" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-16"><a href="#cb9-16" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Save in the lite-interpreter (.ptl) format loaded by LiteModuleLoader on Android</span></span>
<span id="cb9-17"><a href="#cb9-17" aria-hidden="true" tabindex="-1"></a>    optimized_model._save_for_lite_interpreter(<span class="st">"mobile_model.ptl"</span>)</span>
<span id="cb9-18"><a href="#cb9-18" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-19"><a href="#cb9-19" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> optimized_model</span>
<span id="cb9-20"><a href="#cb9-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-21"><a href="#cb9-21" aria-hidden="true" tabindex="-1"></a><span class="co"># Create mobile model</span></span>
<span id="cb9-22"><a href="#cb9-22" aria-hidden="true" tabindex="-1"></a>mobile_model <span class="op">=</span> create_mobile_model(model, sample_input)</span></code></pre></div></div>
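<p>Before shipping <code>mobile_model.ptl</code>, it is worth confirming that the mobile passes did not change the numerics. The sketch below uses a small hypothetical Conv-BatchNorm-ReLU model (<code>toy</code>, not AlexNet), where the Conv-BatchNorm fusion in <code>optimize_for_mobile</code> applies, and compares eager and optimized outputs; the <code>1e-4</code> tolerance is an illustrative choice.</p>

```python
import torch
import torch.nn as nn
from torch.utils.mobile_optimizer import optimize_for_mobile

# Hypothetical small model; BatchNorm in eval mode can be folded into the Conv
torch.manual_seed(0)
toy = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 4),
).eval()

example = torch.randn(1, 3, 32, 32)
optimized = optimize_for_mobile(torch.jit.trace(toy, example))

with torch.no_grad():
    reference = toy(example)
    candidate = optimized(example)

max_diff = (reference - candidate).abs().max().item()
print(f"max abs difference: {max_diff:.2e}")
print("outputs match:", torch.allclose(reference, candidate, atol=1e-4))
```

<p>Fusion changes the order of floating-point operations, so small differences are expected; a large gap usually points to a model that was not in <code>eval()</code> mode when it was traced.</p>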
</section>
</section>
<section id="raspberry-pi-deployment" class="level2">
<h2 class="anchored" data-anchor-id="raspberry-pi-deployment" id="raspberry-pi-deployment">4. Raspberry Pi Deployment</h2>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb10"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Raspberry Pi deployment script</span></span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torchvision.transforms <span class="im">as</span> transforms</span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> PIL <span class="im">import</span> Image</span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb10-6"><a href="#cb10-6" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> psutil</span>
<span id="cb10-7"><a href="#cb10-7" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> threading</span>
<span id="cb10-8"><a href="#cb10-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-9"><a href="#cb10-9" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> RaspberryPiInference:</span>
<span id="cb10-10"><a href="#cb10-10" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, model_path, device<span class="op">=</span><span class="st">'cpu'</span>):</span>
<span id="cb10-11"><a href="#cb10-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.device <span class="op">=</span> torch.device(device)</span>
<span id="cb10-12"><a href="#cb10-12" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> torch.jit.load(model_path, map_location<span class="op">=</span><span class="va">self</span>.device)</span>
<span id="cb10-13"><a href="#cb10-13" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model.<span class="bu">eval</span>()</span>
<span id="cb10-14"><a href="#cb10-14" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb10-15"><a href="#cb10-15" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Define preprocessing transforms</span></span>
<span id="cb10-16"><a href="#cb10-16" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.transform <span class="op">=</span> transforms.Compose([</span>
<span id="cb10-17"><a href="#cb10-17" aria-hidden="true" tabindex="-1"></a>            transforms.Resize((<span class="dv">224</span>, <span class="dv">224</span>)),</span>
<span id="cb10-18"><a href="#cb10-18" aria-hidden="true" tabindex="-1"></a>            transforms.ToTensor(),</span>
<span id="cb10-19"><a href="#cb10-19" aria-hidden="true" tabindex="-1"></a>            transforms.Normalize(mean<span class="op">=</span>[<span class="fl">0.485</span>, <span class="fl">0.456</span>, <span class="fl">0.406</span>], </span>
<span id="cb10-20"><a href="#cb10-20" aria-hidden="true" tabindex="-1"></a>                               std<span class="op">=</span>[<span class="fl">0.229</span>, <span class="fl">0.224</span>, <span class="fl">0.225</span>])</span>
<span id="cb10-21"><a href="#cb10-21" aria-hidden="true" tabindex="-1"></a>        ])</span>
<span id="cb10-22"><a href="#cb10-22" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb10-23"><a href="#cb10-23" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Performance monitoring</span></span>
<span id="cb10-24"><a href="#cb10-24" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.inference_times <span class="op">=</span> []</span>
<span id="cb10-25"><a href="#cb10-25" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb10-26"><a href="#cb10-26" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> preprocess_image(<span class="va">self</span>, image_path):</span>
<span id="cb10-27"><a href="#cb10-27" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Preprocess image for inference"""</span></span>
<span id="cb10-28"><a href="#cb10-28" aria-hidden="true" tabindex="-1"></a>        image <span class="op">=</span> Image.<span class="bu">open</span>(image_path).convert(<span class="st">'RGB'</span>)</span>
<span id="cb10-29"><a href="#cb10-29" aria-hidden="true" tabindex="-1"></a>        input_tensor <span class="op">=</span> <span class="va">self</span>.transform(image).unsqueeze(<span class="dv">0</span>)</span>
<span id="cb10-30"><a href="#cb10-30" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> input_tensor.to(<span class="va">self</span>.device)</span>
<span id="cb10-31"><a href="#cb10-31" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-32"><a href="#cb10-32" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> inference(<span class="va">self</span>, image_path):</span>
<span id="cb10-33"><a href="#cb10-33" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Run inference on image"""</span></span>
<span id="cb10-34"><a href="#cb10-34" aria-hidden="true" tabindex="-1"></a>        start_time <span class="op">=</span> time.perf_counter()  <span class="co"># times image load + preprocessing + forward pass</span></span>
<span id="cb10-35"><a href="#cb10-35" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb10-36"><a href="#cb10-36" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Preprocess</span></span>
<span id="cb10-37"><a href="#cb10-37" aria-hidden="true" tabindex="-1"></a>        input_tensor <span class="op">=</span> <span class="va">self</span>.preprocess_image(image_path)</span>
<span id="cb10-38"><a href="#cb10-38" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb10-39"><a href="#cb10-39" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Inference</span></span>
<span id="cb10-40"><a href="#cb10-40" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> torch.no_grad():</span>
<span id="cb10-41"><a href="#cb10-41" aria-hidden="true" tabindex="-1"></a>            outputs <span class="op">=</span> <span class="va">self</span>.model(input_tensor)</span>
<span id="cb10-42"><a href="#cb10-42" aria-hidden="true" tabindex="-1"></a>            predictions <span class="op">=</span> torch.nn.functional.softmax(outputs[<span class="dv">0</span>], dim<span class="op">=</span><span class="dv">0</span>)</span>
<span id="cb10-43"><a href="#cb10-43" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb10-44"><a href="#cb10-44" aria-hidden="true" tabindex="-1"></a>        inference_time <span class="op">=</span> time.perf_counter() <span class="op">-</span> start_time</span>
<span id="cb10-45"><a href="#cb10-45" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.inference_times.append(inference_time)</span>
<span id="cb10-46"><a href="#cb10-46" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb10-47"><a href="#cb10-47" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> predictions.cpu().numpy(), inference_time</span>
<span id="cb10-48"><a href="#cb10-48" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-49"><a href="#cb10-49" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> get_system_stats(<span class="va">self</span>):</span>
<span id="cb10-50"><a href="#cb10-50" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Get system performance statistics"""</span></span>
<span id="cb10-51"><a href="#cb10-51" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {</span>
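<span id="cb10-52"><a href="#cb10-52" aria-hidden="true" tabindex="-1"></a>            <span class="st">'cpu_percent'</span>: psutil.cpu_percent(interval<span class="op">=</span><span class="fl">0.1</span>),  <span class="co"># sample over 100 ms; a non-blocking first call is meaningless</span></span>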
<span id="cb10-53"><a href="#cb10-53" aria-hidden="true" tabindex="-1"></a>            <span class="st">'memory_percent'</span>: psutil.virtual_memory().percent,</span>
<span id="cb10-54"><a href="#cb10-54" aria-hidden="true" tabindex="-1"></a>            <span class="st">'temperature'</span>: <span class="va">self</span>.get_cpu_temperature()</span>
<span id="cb10-55"><a href="#cb10-55" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb10-56"><a href="#cb10-56" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-57"><a href="#cb10-57" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> get_cpu_temperature(<span class="va">self</span>):</span>
<span id="cb10-58"><a href="#cb10-58" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Get CPU temperature (Raspberry Pi specific)"""</span></span>
<span id="cb10-59"><a href="#cb10-59" aria-hidden="true" tabindex="-1"></a>        <span class="cf">try</span>:</span>
<span id="cb10-60"><a href="#cb10-60" aria-hidden="true" tabindex="-1"></a>            <span class="cf">with</span> <span class="bu">open</span>(<span class="st">'/sys/class/thermal/thermal_zone0/temp'</span>, <span class="st">'r'</span>) <span class="im">as</span> f:</span>
<span id="cb10-61"><a href="#cb10-61" aria-hidden="true" tabindex="-1"></a>                temp <span class="op">=</span> <span class="bu">float</span>(f.read()) <span class="op">/</span> <span class="fl">1000.0</span></span>
<span id="cb10-62"><a href="#cb10-62" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> temp</span>
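<span id="cb10-63"><a href="#cb10-63" aria-hidden="true" tabindex="-1"></a>        <span class="cf">except</span> (<span class="bu">OSError</span>, <span class="bu">ValueError</span>):  <span class="co"># sensor file missing (non-Pi host) or unparsable</span></span>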
<span id="cb10-64"><a href="#cb10-64" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> <span class="va">None</span></span>
<span id="cb10-65"><a href="#cb10-65" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-66"><a href="#cb10-66" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage example</span></span>
<span id="cb10-67"><a href="#cb10-67" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb10-68"><a href="#cb10-68" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Initialize inference engine</span></span>
<span id="cb10-69"><a href="#cb10-69" aria-hidden="true" tabindex="-1"></a>    inference_engine <span class="op">=</span> RaspberryPiInference(<span class="st">"model.pt"</span>)</span>
<span id="cb10-70"><a href="#cb10-70" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-71"><a href="#cb10-71" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Run inference</span></span>
<span id="cb10-72"><a href="#cb10-72" aria-hidden="true" tabindex="-1"></a>    predictions, inference_time <span class="op">=</span> inference_engine.inference(<span class="st">"test_image.jpg"</span>)</span>
<span id="cb10-73"><a href="#cb10-73" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-74"><a href="#cb10-74" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Inference time: </span><span class="sc">{</span>inference_time<span class="sc">:.3f}</span><span class="ss"> seconds"</span>)</span>
<span id="cb10-75"><a href="#cb10-75" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Top prediction: class </span><span class="sc">{</span>predictions.argmax()<span class="sc">}</span><span class="ss"> (p=</span><span class="sc">{</span>predictions.<span class="bu">max</span>()<span class="sc">:.3f}</span><span class="ss">)"</span>)</span>
<span id="cb10-76"><a href="#cb10-76" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"System stats: </span><span class="sc">{</span>inference_engine<span class="sc">.</span>get_system_stats()<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
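<p><code>RaspberryPiInference</code> appends every latency to <code>self.inference_times</code> but never reports them. A minimal sketch of summarizing that list with the standard library, assuming mean and tail percentiles are what you want to track; <code>summarize_latencies</code> and its <code>skip_first</code> parameter are hypothetical names, not part of the class above:</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import math
import statistics


def summarize_latencies(times, skip_first=1):
    """Summarize per-call latencies given in seconds, reported in milliseconds.

    The first `skip_first` calls are dropped because they usually include
    one-off costs (lazy initialization, cold caches) rather than steady state.
    """
    steady = sorted(list(times)[skip_first:])
    if not steady:
        raise ValueError("need more timings than skip_first")

    def percentile(p):
        # Nearest-rank method: the smallest sample with at least p% of samples at or below it
        rank = max(1, math.ceil(p * len(steady) / 100))
        return steady[rank - 1]

    return {
        "count": len(steady),
        "mean_ms": 1000 * statistics.fmean(steady),
        "p50_ms": 1000 * percentile(50),
        "p95_ms": 1000 * percentile(95),
        "max_ms": 1000 * steady[-1],
    }</code></pre></div>
<p>After a batch of calls, <code>summarize_latencies(inference_engine.inference_times)</code> returns the summary dict. On a Raspberry Pi the p95 figure is often more telling than the mean, since thermal throttling can slow later calls.</p>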
</section>
<section id="nvidia-jetson-deployment" class="level2">
<h2 class="anchored" data-anchor-id="nvidia-jetson-deployment" id="nvidia-jetson-deployment">5. NVIDIA Jetson Deployment</h2>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb11"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a><span class="co"># NVIDIA Jetson optimized deployment</span></span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> os</span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> tensorrt <span class="im">as</span> trt</span>
<span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> pycuda.driver <span class="im">as</span> cuda</span>
<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> pycuda.autoinit  <span class="co"># noqa: F401 (imported for its side effect: creates the CUDA context)</span></span>
<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> JetsonTensorRTInference:</span>
<span id="cb11-9"><a href="#cb11-9" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, onnx_model_path, trt_engine_path<span class="op">=</span><span class="va">None</span>):</span>
<span id="cb11-10"><a href="#cb11-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.onnx_path <span class="op">=</span> onnx_model_path</span>
<span id="cb11-11"><a href="#cb11-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.engine_path <span class="op">=</span> trt_engine_path <span class="kw">or</span> onnx_model_path.replace(<span class="st">'.onnx'</span>, <span class="st">'.trt'</span>)</span>
<span id="cb11-12"><a href="#cb11-12" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-13"><a href="#cb11-13" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Build or load TensorRT engine</span></span>
<span id="cb11-14"><a href="#cb11-14" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="kw">not</span> os.path.exists(<span class="va">self</span>.engine_path):</span>
<span id="cb11-15"><a href="#cb11-15" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.build_engine()</span>
<span id="cb11-16"><a href="#cb11-16" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-17"><a href="#cb11-17" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.engine <span class="op">=</span> <span class="va">self</span>.load_engine()</span>
<span id="cb11-18"><a href="#cb11-18" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.context <span class="op">=</span> <span class="va">self</span>.engine.create_execution_context()</span>
<span id="cb11-19"><a href="#cb11-19" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-20"><a href="#cb11-20" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Allocate GPU memory</span></span>
<span id="cb11-21"><a href="#cb11-21" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.allocate_buffers()</span>
<span id="cb11-22"><a href="#cb11-22" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-23"><a href="#cb11-23" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> build_engine(<span class="va">self</span>):</span>
<span id="cb11-24"><a href="#cb11-24" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Build TensorRT engine from ONNX model"""</span></span>
<span id="cb11-25"><a href="#cb11-25" aria-hidden="true" tabindex="-1"></a>        logger <span class="op">=</span> trt.Logger(trt.Logger.WARNING)</span>
<span id="cb11-26"><a href="#cb11-26" aria-hidden="true" tabindex="-1"></a>        builder <span class="op">=</span> trt.Builder(logger)</span>
<span id="cb11-27"><a href="#cb11-27" aria-hidden="true" tabindex="-1"></a>        network <span class="op">=</span> builder.create_network(<span class="dv">1</span> <span class="op">&lt;&lt;</span> <span class="bu">int</span>(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))</span>
<span id="cb11-28"><a href="#cb11-28" aria-hidden="true" tabindex="-1"></a>        parser <span class="op">=</span> trt.OnnxParser(network, logger)</span>
<span id="cb11-29"><a href="#cb11-29" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-30"><a href="#cb11-30" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Parse ONNX model</span></span>
<span id="cb11-31"><a href="#cb11-31" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> <span class="bu">open</span>(<span class="va">self</span>.onnx_path, <span class="st">'rb'</span>) <span class="im">as</span> model:</span>
<span id="cb11-32"><a href="#cb11-32" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="kw">not</span> parser.parse(model.read()):</span>
<span id="cb11-33"><a href="#cb11-33" aria-hidden="true" tabindex="-1"></a>                <span class="cf">for</span> error <span class="kw">in</span> <span class="bu">range</span>(parser.num_errors):</span>
<span id="cb11-34"><a href="#cb11-34" aria-hidden="true" tabindex="-1"></a>                    <span class="bu">print</span>(parser.get_error(error))</span>
<span id="cb11-35"><a href="#cb11-35" aria-hidden="true" tabindex="-1"></a>                <span class="cf">raise</span> <span class="bu">RuntimeError</span>(<span class="st">"Failed to parse ONNX model"</span>)</span>
<span id="cb11-36"><a href="#cb11-36" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-37"><a href="#cb11-37" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Build engine</span></span>
<span id="cb11-38"><a href="#cb11-38" aria-hidden="true" tabindex="-1"></a>        config <span class="op">=</span> builder.create_builder_config()</span>
<span id="cb11-39"><a href="#cb11-39" aria-hidden="true" tabindex="-1"></a>        config.max_workspace_size <span class="op">=</span> <span class="dv">1</span> <span class="op">&lt;&lt;</span> <span class="dv">28</span>  <span class="co"># 256 MB; TensorRT 8.4+ prefers config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, ...)</span></span>
<span id="cb11-40"><a href="#cb11-40" aria-hidden="true" tabindex="-1"></a>        config.set_flag(trt.BuilderFlag.FP16)  <span class="co"># Enable FP16 precision</span></span>
<span id="cb11-41"><a href="#cb11-41" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-42"><a href="#cb11-42" aria-hidden="true" tabindex="-1"></a>        serialized_engine <span class="op">=</span> builder.build_serialized_network(network, config)</span>
<span id="cb11-43"><a href="#cb11-43" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> serialized_engine <span class="kw">is</span> <span class="va">None</span>:</span>
<span id="cb11-44"><a href="#cb11-44" aria-hidden="true" tabindex="-1"></a>            <span class="cf">raise</span> <span class="bu">RuntimeError</span>(<span class="st">"TensorRT engine build failed"</span>)</span>
<span id="cb11-45"><a href="#cb11-45" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> <span class="bu">open</span>(<span class="va">self</span>.engine_path, <span class="st">'wb'</span>) <span class="im">as</span> f:  <span class="co"># save engine so later runs skip the build</span></span>
<span id="cb11-46"><a href="#cb11-46" aria-hidden="true" tabindex="-1"></a>            f.write(serialized_engine)</span>
<span id="cb11-47"><a href="#cb11-47" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-48"><a href="#cb11-48" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> serialized_engine</span>
<span id="cb11-49"><a href="#cb11-49" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-50"><a href="#cb11-50" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> load_engine(<span class="va">self</span>):</span>
<span id="cb11-51"><a href="#cb11-51" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Load TensorRT engine"""</span></span>
<span id="cb11-52"><a href="#cb11-52" aria-hidden="true" tabindex="-1"></a>        runtime <span class="op">=</span> trt.Runtime(trt.Logger(trt.Logger.WARNING))</span>
<span id="cb11-53"><a href="#cb11-53" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> <span class="bu">open</span>(<span class="va">self</span>.engine_path, <span class="st">'rb'</span>) <span class="im">as</span> f:</span>
<span id="cb11-54"><a href="#cb11-54" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> runtime.deserialize_cuda_engine(f.read())</span>
<span id="cb11-55"><a href="#cb11-55" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-56"><a href="#cb11-56" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> allocate_buffers(<span class="va">self</span>):</span>
<span id="cb11-57"><a href="#cb11-57" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Allocate GPU memory buffers"""</span></span>
<span id="cb11-58"><a href="#cb11-58" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.bindings <span class="op">=</span> []</span>
<span id="cb11-59"><a href="#cb11-59" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.inputs <span class="op">=</span> []</span>
<span id="cb11-60"><a href="#cb11-60" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.outputs <span class="op">=</span> []</span>
<span id="cb11-61"><a href="#cb11-61" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-62"><a href="#cb11-62" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> binding <span class="kw">in</span> <span class="va">self</span>.engine:</span>
<span id="cb11-63"><a href="#cb11-63" aria-hidden="true" tabindex="-1"></a>            shape <span class="op">=</span> <span class="va">self</span>.engine.get_binding_shape(binding)</span>
<span id="cb11-64"><a href="#cb11-64" aria-hidden="true" tabindex="-1"></a>            size <span class="op">=</span> trt.volume(shape)  <span class="co"># explicit-batch engine: shape already includes the batch dimension</span></span>
<span id="cb11-65"><a href="#cb11-65" aria-hidden="true" tabindex="-1"></a>            dtype <span class="op">=</span> trt.nptype(<span class="va">self</span>.engine.get_binding_dtype(binding))</span>
<span id="cb11-66"><a href="#cb11-66" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb11-67"><a href="#cb11-67" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Allocate host and device buffers</span></span>
<span id="cb11-68"><a href="#cb11-68" aria-hidden="true" tabindex="-1"></a>            host_mem <span class="op">=</span> cuda.pagelocked_empty(size, dtype)</span>
<span id="cb11-69"><a href="#cb11-69" aria-hidden="true" tabindex="-1"></a>            device_mem <span class="op">=</span> cuda.mem_alloc(host_mem.nbytes)</span>
<span id="cb11-70"><a href="#cb11-70" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb11-71"><a href="#cb11-71" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.bindings.append(<span class="bu">int</span>(device_mem))</span>
<span id="cb11-72"><a href="#cb11-72" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb11-73"><a href="#cb11-73" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="va">self</span>.engine.binding_is_input(binding):</span>
<span id="cb11-74"><a href="#cb11-74" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.inputs.append({<span class="st">'host'</span>: host_mem, <span class="st">'device'</span>: device_mem})</span>
<span id="cb11-75"><a href="#cb11-75" aria-hidden="true" tabindex="-1"></a>            <span class="cf">else</span>:</span>
<span id="cb11-76"><a href="#cb11-76" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.outputs.append({<span class="st">'host'</span>: host_mem, <span class="st">'device'</span>: device_mem})</span>
<span id="cb11-77"><a href="#cb11-77" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-78"><a href="#cb11-78" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> inference(<span class="va">self</span>, input_data):</span>
<span id="cb11-79"><a href="#cb11-79" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Run TensorRT inference"""</span></span>
<span id="cb11-80"><a href="#cb11-80" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Copy input data to GPU</span></span>
<span id="cb11-81"><a href="#cb11-81" aria-hidden="true" tabindex="-1"></a>        np.copyto(<span class="va">self</span>.inputs[<span class="dv">0</span>][<span class="st">'host'</span>], input_data.ravel())</span>
<span id="cb11-82"><a href="#cb11-82" aria-hidden="true" tabindex="-1"></a>        cuda.memcpy_htod(<span class="va">self</span>.inputs[<span class="dv">0</span>][<span class="st">'device'</span>], <span class="va">self</span>.inputs[<span class="dv">0</span>][<span class="st">'host'</span>])</span>
<span id="cb11-83"><a href="#cb11-83" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-84"><a href="#cb11-84" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Run inference</span></span>
<span id="cb11-85"><a href="#cb11-85" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.context.execute_v2(bindings<span class="op">=</span><span class="va">self</span>.bindings)</span>
<span id="cb11-86"><a href="#cb11-86" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-87"><a href="#cb11-87" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Copy output data from GPU</span></span>
<span id="cb11-88"><a href="#cb11-88" aria-hidden="true" tabindex="-1"></a>        cuda.memcpy_dtoh(<span class="va">self</span>.outputs[<span class="dv">0</span>][<span class="st">'host'</span>], <span class="va">self</span>.outputs[<span class="dv">0</span>][<span class="st">'device'</span>])</span>
<span id="cb11-89"><a href="#cb11-89" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-90"><a href="#cb11-90" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.outputs[<span class="dv">0</span>][<span class="st">'host'</span>].copy()  <span class="co"># the page-locked buffer is overwritten on the next call</span></span>
<span id="cb11-91"><a href="#cb11-91" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-92"><a href="#cb11-92" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage for Jetson</span></span>
<span id="cb11-93"><a href="#cb11-93" aria-hidden="true" tabindex="-1"></a>jetson_inference <span class="op">=</span> JetsonTensorRTInference(<span class="st">"model.onnx"</span>)</span></code></pre></div></div>
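<p><code>JetsonTensorRTInference.inference</code> takes a ready-made NumPy array and returns the raw output values, whereas the Raspberry Pi path uses torchvision transforms and applies a softmax. A minimal NumPy-only sketch of both ends, assuming the ONNX model expects a float32 1×3×224×224 input normalized with the same ImageNet statistics and outputs unnormalized logits; <code>preprocess</code> and <code>top_k</code> are hypothetical helpers:</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import numpy as np

# ImageNet mean/std, the same values used by the Raspberry Pi Normalize() transform
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def preprocess(rgb_uint8):
    """Turn an HWC uint8 RGB image (already resized to 224x224) into a 1x3x224x224 float32 batch."""
    x = rgb_uint8.astype(np.float32) / 255.0    # scale to [0, 1], like ToTensor()
    x = (x - MEAN) / STD                        # per-channel normalization
    x = x.transpose(2, 0, 1)                    # HWC to CHW
    return np.ascontiguousarray(x[np.newaxis])  # add batch dim; contiguous for the host copy


def top_k(logits, k=5):
    """Softmax over raw logits, returning the k most likely (class_index, probability) pairs."""
    z = np.asarray(logits, dtype=np.float64).ravel()
    e = np.exp(z - z.max())                     # subtract the max for numerical stability
    probs = e / e.sum()
    best = np.argsort(probs)[::-1][:k]
    return [(int(i), float(probs[i])) for i in best]</code></pre></div>
<p>For example, with Pillow: <code>image = np.asarray(Image.open("test_image.jpg").convert("RGB").resize((224, 224)))</code>, then <code>top_k(jetson_inference.inference(preprocess(image)))</code>. Resizing straight to 224×224 mirrors the <code>Resize((224, 224))</code> transform used on the Pi, so both devices see identically prepared inputs.</p>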
</section>
<section id="performance-optimization" class="level2">
<h2 class="anchored" data-anchor-id="performance-optimization" id="performance-optimization">6. Performance Optimization</h2>
<section id="benchmarking-script" class="level3">
<h3 class="anchored" data-anchor-id="benchmarking-script" id="benchmarking-script">Benchmarking Script</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb12"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><a href="#cb12-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb12-2"><a href="#cb12-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb12-3"><a href="#cb12-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb12-4"><a href="#cb12-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> psutil</span>
<span id="cb12-5"><a href="#cb12-5" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> contextlib <span class="im">import</span> contextmanager</span>
<span id="cb12-6"><a href="#cb12-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-7"><a href="#cb12-7" aria-hidden="true" tabindex="-1"></a><span class="at">@contextmanager</span></span>
<span id="cb12-8"><a href="#cb12-8" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> timer():</span>
<span id="cb12-9"><a href="#cb12-9" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Context manager for timing code execution"""</span></span>
<span id="cb12-10"><a href="#cb12-10" aria-hidden="true" tabindex="-1"></a>    start <span class="op">=</span> time.perf_counter()</span>
<span id="cb12-11"><a href="#cb12-11" aria-hidden="true" tabindex="-1"></a>    <span class="cf">yield</span></span>
<span id="cb12-12"><a href="#cb12-12" aria-hidden="true" tabindex="-1"></a>    end <span class="op">=</span> time.perf_counter()</span>
<span id="cb12-13"><a href="#cb12-13" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Execution time: </span><span class="sc">{</span>end <span class="op">-</span> start<span class="sc">:.4f}</span><span class="ss"> seconds"</span>)</span>
<span id="cb12-14"><a href="#cb12-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-15"><a href="#cb12-15" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> ModelBenchmark:</span>
<span id="cb12-16"><a href="#cb12-16" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, model, input_shape, device<span class="op">=</span><span class="st">'cpu'</span>):</span>
<span id="cb12-17"><a href="#cb12-17" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> model.to(device)</span>
<span id="cb12-18"><a href="#cb12-18" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.device <span class="op">=</span> device</span>
<span id="cb12-19"><a href="#cb12-19" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.input_shape <span class="op">=</span> input_shape</span>
<span id="cb12-20"><a href="#cb12-20" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-21"><a href="#cb12-21" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> benchmark_inference(<span class="va">self</span>, num_runs<span class="op">=</span><span class="dv">100</span>, warmup_runs<span class="op">=</span><span class="dv">10</span>):</span>
<span id="cb12-22"><a href="#cb12-22" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Benchmark model inference performance"""</span></span>
<span id="cb12-23"><a href="#cb12-23" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Generate random input</span></span>
<span id="cb12-24"><a href="#cb12-24" aria-hidden="true" tabindex="-1"></a>        dummy_input <span class="op">=</span> torch.randn(<span class="va">self</span>.input_shape).to(<span class="va">self</span>.device)</span>
<span id="cb12-25"><a href="#cb12-25" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-26"><a href="#cb12-26" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Warmup runs</span></span>
<span id="cb12-27"><a href="#cb12-27" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model.<span class="bu">eval</span>()</span>
<span id="cb12-28"><a href="#cb12-28" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> torch.no_grad():</span>
<span id="cb12-29"><a href="#cb12-29" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(warmup_runs):</span>
<span id="cb12-30"><a href="#cb12-30" aria-hidden="true" tabindex="-1"></a>                _ <span class="op">=</span> <span class="va">self</span>.model(dummy_input)</span>
<span id="cb12-31"><a href="#cb12-31" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-32"><a href="#cb12-32" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Benchmark runs</span></span>
<span id="cb12-33"><a href="#cb12-33" aria-hidden="true" tabindex="-1"></a>        inference_times <span class="op">=</span> []</span>
<span id="cb12-34"><a href="#cb12-34" aria-hidden="true" tabindex="-1"></a>        memory_usage <span class="op">=</span> []</span>
<span id="cb12-35"><a href="#cb12-35" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-36"><a href="#cb12-36" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(num_runs):</span>
<span id="cb12-37"><a href="#cb12-37" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Monitor memory before inference</span></span>
<span id="cb12-38"><a href="#cb12-38" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="va">self</span>.device <span class="op">==</span> <span class="st">'cuda'</span>:</span>
<span id="cb12-40"><a href="#cb12-40" aria-hidden="true" tabindex="-1"></a>                memory_before <span class="op">=</span> torch.cuda.memory_allocated()</span>
<span id="cb12-41"><a href="#cb12-41" aria-hidden="true" tabindex="-1"></a>            <span class="cf">else</span>:</span>
<span id="cb12-42"><a href="#cb12-42" aria-hidden="true" tabindex="-1"></a>                memory_before <span class="op">=</span> psutil.Process().memory_info().rss</span>
<span id="cb12-43"><a href="#cb12-43" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb12-44"><a href="#cb12-44" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Time inference</span></span>
<span id="cb12-45"><a href="#cb12-45" aria-hidden="true" tabindex="-1"></a>            start_time <span class="op">=</span> time.perf_counter()</span>
<span id="cb12-46"><a href="#cb12-46" aria-hidden="true" tabindex="-1"></a>            <span class="cf">with</span> torch.no_grad():</span>
<span id="cb12-47"><a href="#cb12-47" aria-hidden="true" tabindex="-1"></a>                output <span class="op">=</span> <span class="va">self</span>.model(dummy_input)</span>
<span id="cb12-48"><a href="#cb12-48" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb12-49"><a href="#cb12-49" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="va">self</span>.device <span class="op">==</span> <span class="st">'cuda'</span>:</span>
<span id="cb12-50"><a href="#cb12-50" aria-hidden="true" tabindex="-1"></a>                torch.cuda.synchronize()</span>
<span id="cb12-51"><a href="#cb12-51" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb12-52"><a href="#cb12-52" aria-hidden="true" tabindex="-1"></a>            end_time <span class="op">=</span> time.perf_counter()</span>
<span id="cb12-53"><a href="#cb12-53" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb12-54"><a href="#cb12-54" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Monitor memory after inference</span></span>
<span id="cb12-55"><a href="#cb12-55" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="va">self</span>.device <span class="op">==</span> <span class="st">'cuda'</span>:</span>
<span id="cb12-56"><a href="#cb12-56" aria-hidden="true" tabindex="-1"></a>                memory_after <span class="op">=</span> torch.cuda.memory_allocated()</span>
<span id="cb12-57"><a href="#cb12-57" aria-hidden="true" tabindex="-1"></a>            <span class="cf">else</span>:</span>
<span id="cb12-58"><a href="#cb12-58" aria-hidden="true" tabindex="-1"></a>                memory_after <span class="op">=</span> psutil.Process().memory_info().rss</span>
<span id="cb12-59"><a href="#cb12-59" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb12-60"><a href="#cb12-60" aria-hidden="true" tabindex="-1"></a>            inference_times.append(end_time <span class="op">-</span> start_time)</span>
<span id="cb12-61"><a href="#cb12-61" aria-hidden="true" tabindex="-1"></a>            memory_usage.append(memory_after <span class="op">-</span> memory_before)</span>
<span id="cb12-62"><a href="#cb12-62" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-63"><a href="#cb12-63" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Calculate statistics</span></span>
<span id="cb12-64"><a href="#cb12-64" aria-hidden="true" tabindex="-1"></a>        stats <span class="op">=</span> {</span>
<span id="cb12-65"><a href="#cb12-65" aria-hidden="true" tabindex="-1"></a>            <span class="st">'mean_time'</span>: np.mean(inference_times),</span>
<span id="cb12-66"><a href="#cb12-66" aria-hidden="true" tabindex="-1"></a>            <span class="st">'std_time'</span>: np.std(inference_times),</span>
<span id="cb12-67"><a href="#cb12-67" aria-hidden="true" tabindex="-1"></a>            <span class="st">'min_time'</span>: np.<span class="bu">min</span>(inference_times),</span>
<span id="cb12-68"><a href="#cb12-68" aria-hidden="true" tabindex="-1"></a>            <span class="st">'max_time'</span>: np.<span class="bu">max</span>(inference_times),</span>
<span id="cb12-69"><a href="#cb12-69" aria-hidden="true" tabindex="-1"></a>            <span class="st">'fps'</span>: <span class="fl">1.0</span> <span class="op">/</span> np.mean(inference_times),</span>
<span id="cb12-70"><a href="#cb12-70" aria-hidden="true" tabindex="-1"></a>            <span class="st">'mean_memory'</span>: np.mean(memory_usage),</span>
<span id="cb12-71"><a href="#cb12-71" aria-hidden="true" tabindex="-1"></a>            <span class="st">'max_memory'</span>: np.<span class="bu">max</span>(memory_usage)</span>
<span id="cb12-72"><a href="#cb12-72" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb12-73"><a href="#cb12-73" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-74"><a href="#cb12-74" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> stats</span>
<span id="cb12-75"><a href="#cb12-75" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-76"><a href="#cb12-76" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> profile_model(<span class="va">self</span>):</span>
<span id="cb12-77"><a href="#cb12-77" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Profile model to identify bottlenecks"""</span></span>
<span id="cb12-78"><a href="#cb12-78" aria-hidden="true" tabindex="-1"></a>        dummy_input <span class="op">=</span> torch.randn(<span class="va">self</span>.input_shape).to(<span class="va">self</span>.device)</span>
<span id="cb12-79"><a href="#cb12-79" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-80"><a href="#cb12-80" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> torch.profiler.profile(</span>
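<span id="cb12-81"><a href="#cb12-81" aria-hidden="true" tabindex="-1"></a>            activities<span class="op">=</span>[torch.profiler.ProfilerActivity.CPU] <span class="op">+</span> ([torch.profiler.ProfilerActivity.CUDA] <span class="cf">if</span> <span class="va">self</span>.device <span class="op">==</span> <span class="st">'cuda'</span> <span class="cf">else</span> []),</span>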
<span id="cb12-82"><a href="#cb12-82" aria-hidden="true" tabindex="-1"></a>            record_shapes<span class="op">=</span><span class="va">True</span>,</span>
<span id="cb12-83"><a href="#cb12-83" aria-hidden="true" tabindex="-1"></a>            profile_memory<span class="op">=</span><span class="va">True</span>,</span>
<span id="cb12-84"><a href="#cb12-84" aria-hidden="true" tabindex="-1"></a>            with_stack<span class="op">=</span><span class="va">True</span></span>
<span id="cb12-85"><a href="#cb12-85" aria-hidden="true" tabindex="-1"></a>        ) <span class="im">as</span> profiler:</span>
<span id="cb12-86"><a href="#cb12-86" aria-hidden="true" tabindex="-1"></a>            <span class="cf">with</span> torch.no_grad():</span>
<span id="cb12-87"><a href="#cb12-87" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.model(dummy_input)</span>
<span id="cb12-88"><a href="#cb12-88" aria-hidden="true" tabindex="-1"></a>        </span>
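<span id="cb12-89"><a href="#cb12-89" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Print the 10 most expensive ops, sorted by CUDA time on GPU and by CPU time otherwise</span></span>
<span id="cb12-90"><a href="#cb12-90" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(profiler.key_averages().table(sort_by<span class="op">=</span><span class="st">"cuda_time_total"</span> <span class="cf">if</span> <span class="va">self</span>.device <span class="op">==</span> <span class="st">'cuda'</span> <span class="cf">else</span> <span class="st">"cpu_time_total"</span>, row_limit<span class="op">=</span><span class="dv">10</span>))</span>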
<span id="cb12-91"><a href="#cb12-91" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-92"><a href="#cb12-92" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> profiler</span>
<span id="cb12-93"><a href="#cb12-93" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-94"><a href="#cb12-94" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage example</span></span>
<span id="cb12-95"><a href="#cb12-95" aria-hidden="true" tabindex="-1"></a>benchmark <span class="op">=</span> ModelBenchmark(model, (<span class="dv">1</span>, <span class="dv">3</span>, <span class="dv">224</span>, <span class="dv">224</span>), device<span class="op">=</span><span class="st">'cpu'</span>)</span>
<span id="cb12-96"><a href="#cb12-96" aria-hidden="true" tabindex="-1"></a>stats <span class="op">=</span> benchmark.benchmark_inference()</span>
<span id="cb12-97"><a href="#cb12-97" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Average inference time: </span><span class="sc">{</span>stats[<span class="st">'mean_time'</span>]<span class="sc">:.4f}</span><span class="ss">s"</span>)</span>
<span id="cb12-98"><a href="#cb12-98" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"FPS: </span><span class="sc">{</span>stats[<span class="st">'fps'</span>]<span class="sc">:.2f}</span><span class="ss">"</span>)</span></code></pre></div></div>
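<p>The mean/min/max summary in <code>benchmark_inference</code> can hide tail latency: a few slow runs (garbage collection, cache misses, thermal throttling) barely move the mean. A minimal stdlib-only sketch that summarizes a list of per-run timings, such as <code>inference_times</code>, with percentiles; <code>summarize_latencies</code> is a hypothetical helper, not part of the benchmark class above:</p>

```python
import statistics


def summarize_latencies(times_s):
    """Summarize per-run latencies (in seconds) with percentiles.

    Needs at least two samples for the percentiles to be meaningful.
    """
    # n=100 yields the 99 cut points p1..p99; "inclusive" interpolates
    # linearly between samples, like numpy.percentile's default method
    cuts = statistics.quantiles(times_s, n=100, method="inclusive")
    mean_s = statistics.fmean(times_s)
    return {
        "mean_ms": mean_s * 1000,
        "p50_ms": cuts[49] * 1000,
        "p95_ms": cuts[94] * 1000,
        "p99_ms": cuts[98] * 1000,
        "runs_per_s": 1.0 / mean_s,
    }


# Usage with a few made-up timings in seconds (illustrative, not measured)
stats = summarize_latencies([0.010, 0.011, 0.012, 0.010, 0.045])
print(f"p50: {stats['p50_ms']:.1f} ms, p95: {stats['p95_ms']:.1f} ms")
```

<p>For serving, p95 or p99 latency is usually the number a latency budget is written against, so it can be worth reporting alongside the existing <code>stats</code> dictionary.</p>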
</section>
<section id="memory-optimization" class="level3">
<h3 class="anchored" data-anchor-id="memory-optimization" id="memory-optimization">Memory Optimization</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb13"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> optimize_memory_usage(model):</span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Apply memory optimization techniques"""</span></span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-4"><a href="#cb13-4" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Enable memory efficient attention (for transformers)</span></span>
<span id="cb13-5"><a href="#cb13-5" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> <span class="bu">hasattr</span>(model, <span class="st">'enable_memory_efficient_attention'</span>):</span>
<span id="cb13-6"><a href="#cb13-6" aria-hidden="true" tabindex="-1"></a>        model.enable_memory_efficient_attention()</span>
<span id="cb13-7"><a href="#cb13-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-8"><a href="#cb13-8" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Use gradient checkpointing during training</span></span>
<span id="cb13-9"><a href="#cb13-9" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> <span class="bu">hasattr</span>(model, <span class="st">'gradient_checkpointing_enable'</span>):</span>
<span id="cb13-10"><a href="#cb13-10" aria-hidden="true" tabindex="-1"></a>        model.gradient_checkpointing_enable()</span>
<span id="cb13-11"><a href="#cb13-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-12"><a href="#cb13-12" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Script, freeze, and fuse ops (e.g. Conv+BN) for inference; freezing requires eval mode</span></span>
<span id="cb13-13"><a href="#cb13-13" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> torch.jit.optimize_for_inference(torch.jit.script(model.<span class="bu">eval</span>()))</span>
<span id="cb13-14"><a href="#cb13-14" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-15"><a href="#cb13-15" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> model</span>
<span id="cb13-16"><a href="#cb13-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-17"><a href="#cb13-17" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> batch_inference(model, data_loader, batch_size<span class="op">=</span><span class="dv">1</span>):</span>
<span id="cb13-18"><a href="#cb13-18" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Run inference on a loader that yields input tensors, splitting batches larger than batch_size into chunks"""</span></span>
<span id="cb13-19"><a href="#cb13-19" aria-hidden="true" tabindex="-1"></a>    model.<span class="bu">eval</span>()</span>
<span id="cb13-20"><a href="#cb13-20" aria-hidden="true" tabindex="-1"></a>    results <span class="op">=</span> []</span>
<span id="cb13-21"><a href="#cb13-21" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-22"><a href="#cb13-22" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> torch.no_grad():</span>
<span id="cb13-23"><a href="#cb13-23" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> batch <span class="kw">in</span> data_loader:</span>
<span id="cb13-24"><a href="#cb13-24" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Process in smaller chunks if needed</span></span>
<span id="cb13-25"><a href="#cb13-25" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> batch.size(<span class="dv">0</span>) <span class="op">&gt;</span> batch_size:</span>
<span id="cb13-26"><a href="#cb13-26" aria-hidden="true" tabindex="-1"></a>                <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">0</span>, batch.size(<span class="dv">0</span>), batch_size):</span>
<span id="cb13-27"><a href="#cb13-27" aria-hidden="true" tabindex="-1"></a>                    chunk <span class="op">=</span> batch[i:i<span class="op">+</span>batch_size]</span>
<span id="cb13-28"><a href="#cb13-28" aria-hidden="true" tabindex="-1"></a>                    output <span class="op">=</span> model(chunk)</span>
<span id="cb13-29"><a href="#cb13-29" aria-hidden="true" tabindex="-1"></a>                    results.append(output.cpu())</span>
<span id="cb13-30"><a href="#cb13-30" aria-hidden="true" tabindex="-1"></a>                    </span>
<span id="cb13-31"><a href="#cb13-31" aria-hidden="true" tabindex="-1"></a>                    <span class="co"># Optionally release cached GPU blocks (lowers reserved memory but slows later allocations)</span></span>
<span id="cb13-32"><a href="#cb13-32" aria-hidden="true" tabindex="-1"></a>                    <span class="cf">if</span> torch.cuda.is_available():</span>
<span id="cb13-33"><a href="#cb13-33" aria-hidden="true" tabindex="-1"></a>                        torch.cuda.empty_cache()</span>
<span id="cb13-34"><a href="#cb13-34" aria-hidden="true" tabindex="-1"></a>            <span class="cf">else</span>:</span>
<span id="cb13-35"><a href="#cb13-35" aria-hidden="true" tabindex="-1"></a>                output <span class="op">=</span> model(batch)</span>
<span id="cb13-36"><a href="#cb13-36" aria-hidden="true" tabindex="-1"></a>                results.append(output.cpu())</span>
<span id="cb13-37"><a href="#cb13-37" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-38"><a href="#cb13-38" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> torch.cat(results, dim<span class="op">=</span><span class="dv">0</span>)</span></code></pre></div></div>
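<p>Which of these optimizations pays off depends on where a model's memory actually goes. For AlexNet, a back-of-the-envelope count of weight storage makes this concrete. The sketch below assumes torchvision's AlexNet layer configuration (the original two-GPU model from the paper uses somewhat different channel counts); the helper names are hypothetical, and only weights are counted, not activations:</p>

```python
# Layer configuration of torchvision's AlexNet (assumption; the paper's
# two-GPU model uses different channel counts)
CONV_LAYERS = [  # (in_channels, out_channels, kernel_size)
    (3, 64, 11), (64, 192, 5), (192, 384, 3), (384, 256, 3), (256, 256, 3),
]
LINEAR_LAYERS = [  # (in_features, out_features); 256 * 6 * 6 = 9216 inputs
    (256 * 6 * 6, 4096), (4096, 4096), (4096, 1000),
]


def conv_params(c_in, c_out, k):
    return c_in * c_out * k * k + c_out  # square-kernel weights + bias


def linear_params(n_in, n_out):
    return n_in * n_out + n_out  # weight matrix + bias


CONV_TOTAL = sum(conv_params(*shape) for shape in CONV_LAYERS)
LINEAR_TOTAL = sum(linear_params(*shape) for shape in LINEAR_LAYERS)


def weight_mib(conv_bytes, linear_bytes):
    """Approximate weight storage in MiB for the given bytes per parameter."""
    return (CONV_TOTAL * conv_bytes + LINEAR_TOTAL * linear_bytes) / 2**20


total = CONV_TOTAL + LINEAR_TOTAL
print(f"parameters: {total:,} ({LINEAR_TOTAL / total:.0%} in Linear layers)")
print(f"fp32: {weight_mib(4, 4):.1f} MiB, fp16: {weight_mib(2, 2):.1f} MiB")
# int8 Linear weights with fp32 convs, roughly what dynamic quantization of
# nn.Linear produces (ignores fp32 biases and quantization scale/zero-point)
print(f"int8 Linear + fp32 conv: {weight_mib(4, 1):.1f} MiB")
```

<p>Because roughly 96% of AlexNet's 61.1 million parameters sit in the three fully connected layers, techniques that shrink only those layers (for example PyTorch's dynamic quantization, which targets <code>nn.Linear</code> rather than convolutions) remove most of the weight storage on their own.</p>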
</section>
</section>
<section id="best-practices" class="level2">
<h2 class="anchored" data-anchor-id="best-practices" id="best-practices">7. Best Practices</h2>
<section id="model-deployment-checklist" class="level3">
<h3 class="anchored" data-anchor-id="model-deployment-checklist" id="model-deployment-checklist">Model Deployment Checklist</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb14"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb14-1"><a href="#cb14-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> DeploymentValidator:</span>
<span id="cb14-2"><a href="#cb14-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, original_model, optimized_model, test_input):</span>
<span id="cb14-3"><a href="#cb14-3" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.original_model <span class="op">=</span> original_model</span>
<span id="cb14-4"><a href="#cb14-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.optimized_model <span class="op">=</span> optimized_model</span>
<span id="cb14-5"><a href="#cb14-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.test_input <span class="op">=</span> test_input</span>
<span id="cb14-6"><a href="#cb14-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb14-7"><a href="#cb14-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> validate_accuracy(<span class="va">self</span>, tolerance<span class="op">=</span><span class="fl">1e-3</span>):</span>
<span id="cb14-8"><a href="#cb14-8" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Check that the optimized model's outputs match the original within an absolute tolerance"""</span></span>
<span id="cb14-9"><a href="#cb14-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.original_model.<span class="bu">eval</span>()</span>
<span id="cb14-10"><a href="#cb14-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.optimized_model.<span class="bu">eval</span>()</span>
<span id="cb14-11"><a href="#cb14-11" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-12"><a href="#cb14-12" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> torch.no_grad():</span>
<span id="cb14-13"><a href="#cb14-13" aria-hidden="true" tabindex="-1"></a>            original_output <span class="op">=</span> <span class="va">self</span>.original_model(<span class="va">self</span>.test_input)</span>
<span id="cb14-14"><a href="#cb14-14" aria-hidden="true" tabindex="-1"></a>            optimized_output <span class="op">=</span> <span class="va">self</span>.optimized_model(<span class="va">self</span>.test_input)</span>
<span id="cb14-15"><a href="#cb14-15" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-16"><a href="#cb14-16" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Check if outputs are close</span></span>
<span id="cb14-17"><a href="#cb14-17" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> torch.allclose(original_output, optimized_output, atol<span class="op">=</span>tolerance):</span>
<span id="cb14-18"><a href="#cb14-18" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="st">"✓ Accuracy validation passed"</span>)</span>
<span id="cb14-19"><a href="#cb14-19" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> <span class="va">True</span></span>
<span id="cb14-20"><a href="#cb14-20" aria-hidden="true" tabindex="-1"></a>        <span class="cf">else</span>:</span>
<span id="cb14-21"><a href="#cb14-21" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="st">"✗ Accuracy validation failed"</span>)</span>
<span id="cb14-22"><a href="#cb14-22" aria-hidden="true" tabindex="-1"></a>            diff <span class="op">=</span> torch.<span class="bu">abs</span>(original_output <span class="op">-</span> optimized_output).<span class="bu">max</span>().item()</span>
<span id="cb14-23"><a href="#cb14-23" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f"Maximum difference: </span><span class="sc">{</span>diff<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb14-24"><a href="#cb14-24" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> <span class="va">False</span></span>
<span id="cb14-25"><a href="#cb14-25" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb14-26"><a href="#cb14-26" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> validate_performance(<span class="va">self</span>):</span>
<span id="cb14-27"><a href="#cb14-27" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Compare performance metrics"""</span></span>
<span id="cb14-28"><a href="#cb14-28" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Benchmark both models</span></span>
<span id="cb14-29"><a href="#cb14-29" aria-hidden="true" tabindex="-1"></a>        original_benchmark <span class="op">=</span> ModelBenchmark(<span class="va">self</span>.original_model, <span class="va">self</span>.test_input.shape)</span>
<span id="cb14-30"><a href="#cb14-30" aria-hidden="true" tabindex="-1"></a>        optimized_benchmark <span class="op">=</span> ModelBenchmark(<span class="va">self</span>.optimized_model, <span class="va">self</span>.test_input.shape)</span>
<span id="cb14-31"><a href="#cb14-31" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-32"><a href="#cb14-32" aria-hidden="true" tabindex="-1"></a>        original_stats <span class="op">=</span> original_benchmark.benchmark_inference(num_runs<span class="op">=</span><span class="dv">50</span>)</span>
<span id="cb14-33"><a href="#cb14-33" aria-hidden="true" tabindex="-1"></a>        optimized_stats <span class="op">=</span> optimized_benchmark.benchmark_inference(num_runs<span class="op">=</span><span class="dv">50</span>)</span>
<span id="cb14-34"><a href="#cb14-34" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-35"><a href="#cb14-35" aria-hidden="true" tabindex="-1"></a>        speedup <span class="op">=</span> original_stats[<span class="st">'mean_time'</span>] <span class="op">/</span> optimized_stats[<span class="st">'mean_time'</span>]</span>
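<span id="cb14-36"><a href="#cb14-36" aria-hidden="true" tabindex="-1"></a>        memory_reduction <span class="op">=</span> ((original_stats[<span class="st">'mean_memory'</span>] <span class="op">-</span> optimized_stats[<span class="st">'mean_memory'</span>]) <span class="op">/</span> original_stats[<span class="st">'mean_memory'</span>] <span class="op">*</span> <span class="dv">100</span> <span class="cf">if</span> original_stats[<span class="st">'mean_memory'</span>] <span class="cf">else</span> <span class="bu">float</span>(<span class="st">'nan'</span>))  <span class="co"># avoid dividing by a zero memory delta</span></span>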
<span id="cb14-37"><a href="#cb14-37" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-38"><a href="#cb14-38" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Performance improvement: </span><span class="sc">{</span>speedup<span class="sc">:.2f}</span><span class="ss">x speedup"</span>)</span>
<span id="cb14-39"><a href="#cb14-39" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Memory reduction: </span><span class="sc">{</span>memory_reduction<span class="sc">:.1f}</span><span class="ss">%"</span>)</span>
<span id="cb14-40"><a href="#cb14-40" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-41"><a href="#cb14-41" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {</span>
<span id="cb14-42"><a href="#cb14-42" aria-hidden="true" tabindex="-1"></a>            <span class="st">'speedup'</span>: speedup,</span>
<span id="cb14-43"><a href="#cb14-43" aria-hidden="true" tabindex="-1"></a>            <span class="st">'memory_reduction'</span>: memory_reduction,</span>
<span id="cb14-44"><a href="#cb14-44" aria-hidden="true" tabindex="-1"></a>            <span class="st">'original_fps'</span>: original_stats[<span class="st">'fps'</span>],</span>
<span id="cb14-45"><a href="#cb14-45" aria-hidden="true" tabindex="-1"></a>            <span class="st">'optimized_fps'</span>: optimized_stats[<span class="st">'fps'</span>]</span>
<span id="cb14-46"><a href="#cb14-46" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb14-47"><a href="#cb14-47" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb14-48"><a href="#cb14-48" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> check_model_size(<span class="va">self</span>):</span>
<span id="cb14-49"><a href="#cb14-49" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Compare model file sizes"""</span></span>
<span id="cb14-50"><a href="#cb14-50" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Save both models temporarily (the TorchScript archive also stores code, so the comparison is approximate)</span></span>
<span id="cb14-51"><a href="#cb14-51" aria-hidden="true" tabindex="-1"></a>        torch.save(<span class="va">self</span>.original_model.state_dict(), <span class="st">'temp_original.pth'</span>)</span>
<span id="cb14-52"><a href="#cb14-52" aria-hidden="true" tabindex="-1"></a>        torch.jit.save(torch.jit.script(<span class="va">self</span>.optimized_model), <span class="st">'temp_optimized.pt'</span>)</span>
<span id="cb14-53"><a href="#cb14-53" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-54"><a href="#cb14-54" aria-hidden="true" tabindex="-1"></a>        <span class="im">import</span> os</span>
<span id="cb14-55"><a href="#cb14-55" aria-hidden="true" tabindex="-1"></a>        original_size <span class="op">=</span> os.path.getsize(<span class="st">'temp_original.pth'</span>)</span>
<span id="cb14-56"><a href="#cb14-56" aria-hidden="true" tabindex="-1"></a>        optimized_size <span class="op">=</span> os.path.getsize(<span class="st">'temp_optimized.pt'</span>)</span>
<span id="cb14-57"><a href="#cb14-57" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-58"><a href="#cb14-58" aria-hidden="true" tabindex="-1"></a>        size_reduction <span class="op">=</span> (original_size <span class="op">-</span> optimized_size) <span class="op">/</span> original_size <span class="op">*</span> <span class="dv">100</span></span>
<span id="cb14-59"><a href="#cb14-59" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-60"><a href="#cb14-60" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Original model size: </span><span class="sc">{</span>original_size <span class="op">/</span> <span class="dv">1024</span> <span class="op">/</span> <span class="dv">1024</span><span class="sc">:.2f}</span><span class="ss"> MB"</span>)</span>
<span id="cb14-61"><a href="#cb14-61" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Optimized model size: </span><span class="sc">{</span>optimized_size <span class="op">/</span> <span class="dv">1024</span> <span class="op">/</span> <span class="dv">1024</span><span class="sc">:.2f}</span><span class="ss"> MB"</span>)</span>
<span id="cb14-62"><a href="#cb14-62" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Size reduction: </span><span class="sc">{</span>size_reduction<span class="sc">:.1f}</span><span class="ss">%"</span>)</span>
<span id="cb14-63"><a href="#cb14-63" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-64"><a href="#cb14-64" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Clean up temporary files</span></span>
<span id="cb14-65"><a href="#cb14-65" aria-hidden="true" tabindex="-1"></a>        os.remove(<span class="st">'temp_original.pth'</span>)</span>
<span id="cb14-66"><a href="#cb14-66" aria-hidden="true" tabindex="-1"></a>        os.remove(<span class="st">'temp_optimized.pt'</span>)</span>
<span id="cb14-67"><a href="#cb14-67" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-68"><a href="#cb14-68" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> size_reduction</span>
<span id="cb14-69"><a href="#cb14-69" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-70"><a href="#cb14-70" aria-hidden="true" tabindex="-1"></a><span class="co"># Example usage</span></span>
<span id="cb14-71"><a href="#cb14-71" aria-hidden="true" tabindex="-1"></a>validator <span class="op">=</span> DeploymentValidator(model, quantized_model, sample_input)</span>
<span id="cb14-72"><a href="#cb14-72" aria-hidden="true" tabindex="-1"></a>validator.validate_accuracy()</span>
<span id="cb14-73"><a href="#cb14-73" aria-hidden="true" tabindex="-1"></a>performance_metrics <span class="op">=</span> validator.validate_performance()</span>
<span id="cb14-74"><a href="#cb14-74" aria-hidden="true" tabindex="-1"></a>size_reduction <span class="op">=</span> validator.check_model_size()</span></code></pre></div></div>
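<p><code>validate_accuracy</code> compares raw outputs with <code>torch.allclose</code>, which is a strict test for quantized models: int8 weights can move logits by more than <code>1e-3</code> while leaving the predicted class unchanged. A complementary check is top-1 agreement between the two models. A minimal sketch on plain Python lists of logits (for example from <code>output.tolist()</code>); <code>compare_predictions</code> is a hypothetical helper, and its <code>within_atol</code> flag uses only an absolute tolerance, unlike <code>torch.allclose</code>, which also applies a relative one:</p>

```python
def compare_predictions(reference_logits, candidate_logits, atol=1e-3):
    """Compare two models' logits given as lists of rows (one row per sample)."""
    if not reference_logits or len(reference_logits) != len(candidate_logits):
        raise ValueError("outputs must be non-empty and have the same number of samples")
    max_abs_diff = 0.0
    agreements = 0
    for ref_row, cand_row in zip(reference_logits, candidate_logits):
        # Largest element-wise deviation seen so far
        row_diff = max(abs(a - b) for a, b in zip(ref_row, cand_row))
        max_abs_diff = max(max_abs_diff, row_diff)
        # Top-1 class = index of the largest logit in each row
        ref_top1 = max(range(len(ref_row)), key=ref_row.__getitem__)
        cand_top1 = max(range(len(cand_row)), key=cand_row.__getitem__)
        agreements += ref_top1 == cand_top1
    return {
        "max_abs_diff": max_abs_diff,
        "within_atol": max_abs_diff <= atol,
        "top1_agreement": agreements / len(reference_logits),
    }


# Made-up logits for two samples and three classes (illustrative only)
original = [[2.0, 1.0, 0.1], [0.2, 0.3, 3.0]]
quantized = [[1.9, 1.05, 0.1], [0.25, 0.3, 2.9]]
print(compare_predictions(original, quantized))
```

<p>Here the outputs fail the absolute-tolerance check yet agree on every predicted class. With tensors, the same agreement rate can be computed as <code>(original_output.argmax(dim=1) == optimized_output.argmax(dim=1)).float().mean()</code>.</p>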
</section>
<section id="error-handling-and-logging" class="level3">
<h3 class="anchored" data-anchor-id="error-handling-and-logging" id="error-handling-and-logging">Error Handling and Logging</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb15"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb15-1"><a href="#cb15-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> logging</span>
<span id="cb15-2"><a href="#cb15-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> functools <span class="im">import</span> wraps</span>
<span id="cb15-3"><a href="#cb15-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-4"><a href="#cb15-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> setup_logging():</span>
<span id="cb15-5"><a href="#cb15-5" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Setup logging for deployment"""</span></span>
<span id="cb15-6"><a href="#cb15-6" aria-hidden="true" tabindex="-1"></a>    logging.basicConfig(</span>
<span id="cb15-7"><a href="#cb15-7" aria-hidden="true" tabindex="-1"></a>        level<span class="op">=</span>logging.INFO,</span>
<span id="cb15-8"><a href="#cb15-8" aria-hidden="true" tabindex="-1"></a>        <span class="bu">format</span><span class="op">=</span><span class="st">'</span><span class="sc">%(asctime)s</span><span class="st"> - </span><span class="sc">%(name)s</span><span class="st"> - </span><span class="sc">%(levelname)s</span><span class="st"> - </span><span class="sc">%(message)s</span><span class="st">'</span>,</span>
<span id="cb15-9"><a href="#cb15-9" aria-hidden="true" tabindex="-1"></a>        handlers<span class="op">=</span>[</span>
<span id="cb15-10"><a href="#cb15-10" aria-hidden="true" tabindex="-1"></a>            logging.FileHandler(<span class="st">'model_deployment.log'</span>),</span>
<span id="cb15-11"><a href="#cb15-11" aria-hidden="true" tabindex="-1"></a>            logging.StreamHandler()</span>
<span id="cb15-12"><a href="#cb15-12" aria-hidden="true" tabindex="-1"></a>        ]</span>
<span id="cb15-13"><a href="#cb15-13" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb15-14"><a href="#cb15-14" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> logging.getLogger(<span class="va">__name__</span>)</span>
<span id="cb15-15"><a href="#cb15-15" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-16"><a href="#cb15-16" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> handle_inference_errors(func):</span>
<span id="cb15-17"><a href="#cb15-17" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Decorator for handling inference errors"""</span></span>
<span id="cb15-18"><a href="#cb15-18" aria-hidden="true" tabindex="-1"></a>    <span class="at">@wraps</span>(func)</span>
<span id="cb15-19"><a href="#cb15-19" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> wrapper(<span class="op">*</span>args, <span class="op">**</span>kwargs):</span>
<span id="cb15-20"><a href="#cb15-20" aria-hidden="true" tabindex="-1"></a>        <span class="cf">try</span>:</span>
<span id="cb15-21"><a href="#cb15-21" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> func(<span class="op">*</span>args, <span class="op">**</span>kwargs)</span>
<span id="cb15-22"><a href="#cb15-22" aria-hidden="true" tabindex="-1"></a>        <span class="cf">except</span> torch.cuda.OutOfMemoryError:</span>
<span id="cb15-23"><a href="#cb15-23" aria-hidden="true" tabindex="-1"></a>            logging.error(<span class="st">"CUDA out of memory. Try reducing batch size."</span>)</span>
<span id="cb15-24"><a href="#cb15-24" aria-hidden="true" tabindex="-1"></a>            torch.cuda.empty_cache()</span>
<span id="cb15-25"><a href="#cb15-25" aria-hidden="true" tabindex="-1"></a>            <span class="cf">raise</span></span>
<span id="cb15-26"><a href="#cb15-26" aria-hidden="true" tabindex="-1"></a>        <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:</span>
<span id="cb15-27"><a href="#cb15-27" aria-hidden="true" tabindex="-1"></a>            logging.exception(<span class="st">"Inference error: </span><span class="sc">%s</span><span class="st">"</span>, e)  <span class="co"># also logs the traceback</span></span>
<span id="cb15-28"><a href="#cb15-28" aria-hidden="true" tabindex="-1"></a>            <span class="cf">raise</span></span>
<span id="cb15-29"><a href="#cb15-29" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> wrapper</span>
<span id="cb15-30"><a href="#cb15-30" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-31"><a href="#cb15-31" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> RobustInference:</span>
<span id="cb15-32"><a href="#cb15-32" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, model_path, device<span class="op">=</span><span class="st">'cpu'</span>):</span>
<span id="cb15-33"><a href="#cb15-33" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.logger <span class="op">=</span> setup_logging()</span>
<span id="cb15-34"><a href="#cb15-34" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.device <span class="op">=</span> torch.device(device)</span>
<span id="cb15-35"><a href="#cb15-35" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-36"><a href="#cb15-36" aria-hidden="true" tabindex="-1"></a>        <span class="cf">try</span>:</span>
<span id="cb15-37"><a href="#cb15-37" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.model <span class="op">=</span> torch.jit.load(model_path, map_location<span class="op">=</span><span class="va">self</span>.device)</span>
<span id="cb15-38"><a href="#cb15-38" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.model.<span class="bu">eval</span>()</span>
<span id="cb15-39"><a href="#cb15-39" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.logger.info(<span class="ss">f"Model loaded successfully on </span><span class="sc">{</span>device<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb15-40"><a href="#cb15-40" aria-hidden="true" tabindex="-1"></a>        <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:</span>
<span id="cb15-41"><a href="#cb15-41" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.logger.error(<span class="ss">f"Failed to load model: </span><span class="sc">{</span>e<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb15-42"><a href="#cb15-42" aria-hidden="true" tabindex="-1"></a>            <span class="cf">raise</span></span>
<span id="cb15-43"><a href="#cb15-43" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-44"><a href="#cb15-44" aria-hidden="true" tabindex="-1"></a>    <span class="at">@handle_inference_errors</span></span>
<span id="cb15-45"><a href="#cb15-45" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> inference(<span class="va">self</span>, input_data):</span>
<span id="cb15-46"><a href="#cb15-46" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Robust inference with error handling"""</span></span>
<span id="cb15-47"><a href="#cb15-47" aria-hidden="true" tabindex="-1"></a>        start_time <span class="op">=</span> time.perf_counter()  <span class="co"># monotonic, high-resolution timer</span></span>
<span id="cb15-48"><a href="#cb15-48" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-49"><a href="#cb15-49" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> torch.no_grad():</span>
<span id="cb15-50"><a href="#cb15-50" aria-hidden="true" tabindex="-1"></a>            output <span class="op">=</span> <span class="va">self</span>.model(input_data)</span>
<span id="cb15-51"><a href="#cb15-51" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="va">self</span>.device.<span class="bu">type</span> <span class="op">==</span> <span class="st">'cuda'</span>: torch.cuda.synchronize()  <span class="co"># CUDA kernels run asynchronously; wait before reading the clock</span></span>
<span id="cb15-52"><a href="#cb15-52" aria-hidden="true" tabindex="-1"></a>        inference_time <span class="op">=</span> time.perf_counter() <span class="op">-</span> start_time</span>
<span id="cb15-53"><a href="#cb15-53" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.logger.info(<span class="ss">f"Inference completed in </span><span class="sc">{</span>inference_time<span class="sc">:.3f}</span><span class="ss">s"</span>)</span>
<span id="cb15-54"><a href="#cb15-54" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-55"><a href="#cb15-55" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> output</span></code></pre></div></div>
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>This guide provides a comprehensive approach to deploying PyTorch models on edge devices. Key takeaways:</p>
<ol type="1">
<li><strong>Model Optimization</strong>: Consider quantizing and pruning models before deployment, and re-validate accuracy after each step</li>
<li><strong>Format Selection</strong>: Choose the right format (TorchScript, ONNX, TensorRT) based on your target device</li>
<li><strong>Performance Monitoring</strong>: Continuously monitor inference time, memory usage, and accuracy</li>
<li><strong>Device-Specific Optimization</strong>: Leverage device-specific optimizations (TensorRT for NVIDIA, Core ML for iOS)</li>
<li><strong>Robust Deployment</strong>: Implement proper error handling and logging for production systems</li>
</ol>
<p>Remember to validate your optimized models thoroughly before deployment and monitor their performance in production environments.</p>


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Distributed Training with PyTorch - Complete Code Guide]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/distributed/distributed-pytorch-training/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/distributed/distributed-pytorch-training/</guid>
      <pubDate>Sat, 21 Jun 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>tutorial</category>
      <category>intermediate</category>
      <content:encoded><![CDATA[






<section id="distributed-training-with-pytorch---complete-code-guide" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/distributed/distributed-pytorch-training/distributed.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>Distributed training allows you to scale PyTorch models across multiple GPUs and machines, dramatically reducing training time for large models and datasets. This guide covers practical implementation patterns from basic data parallelism to advanced distributed strategies.</p>
</section>
<section id="core-concepts" class="level2">
<h2 class="anchored" data-anchor-id="core-concepts" id="core-concepts">Core Concepts</h2>
<section id="key-terminology" class="level3">
<h3 class="anchored" data-anchor-id="key-terminology" id="key-terminology">Key Terminology</h3>
<ul>
<li><strong>World Size</strong>: Total number of processes participating in training</li>
<li><strong>Rank</strong>: Unique identifier for each process (0 to world_size-1)</li>
<li><strong>Local Rank</strong>: Index of a process within its own node (0 to nproc_per_node-1 on each node); usually used to select the GPU</li>
<li><strong>Process Group</strong>: Collection of processes that can communicate with each other</li>
<li><strong>Backend</strong>: Communication library that implements the collective operations (typically NCCL for GPU training, Gloo for CPU)</li>
</ul>
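<p>The relationship between these identifiers is plain arithmetic. The sketch below assumes a homogeneous job in which every node runs the same number of processes (what <code>torchrun</code> calls <code>nproc_per_node</code>) and ranks are numbered node by node, which is the usual convention. The helper names <code>global_rank</code> and <code>rank_table</code> are hypothetical and not part of PyTorch.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">def global_rank(node_rank, local_rank, nproc_per_node):
    """Global rank, assuming every node runs the same number of processes."""
    if local_rank not in range(nproc_per_node):
        raise ValueError("local_rank must be in range(nproc_per_node)")
    return node_rank * nproc_per_node + local_rank


def rank_table(num_nodes, nproc_per_node):
    """(node_rank, local_rank, global_rank) for every process in the job."""
    return [
        (node, local, global_rank(node, local, nproc_per_node))
        for node in range(num_nodes)
        for local in range(nproc_per_node)
    ]


# Example: 2 nodes with 4 GPUs each, so world_size = 2 * 4 = 8
table = rank_table(num_nodes=2, nproc_per_node=4)
world_size = len(table)
print(world_size)  # 8
print(table[6])    # (1, 2, 6): node 1, local rank 2, global rank 6
</code></pre></div>
<p>Local ranks repeat on every node while global ranks are unique across the whole job, which is why the local rank is the one used to choose a GPU.</p>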
</section>
<section id="communication-patterns" class="level3">
<h3 class="anchored" data-anchor-id="communication-patterns" id="communication-patterns">Communication Patterns</h3>
<ul>
<li><strong>All-Reduce</strong>: Combine values from all processes (for example, by summing) and give every process the result</li>
<li><strong>Broadcast</strong>: Send the same data from one source process to all others</li>
<li><strong>Gather</strong>: Collect data from all processes onto a single destination process</li>
<li><strong>Scatter</strong>: Split data held by one process into chunks and send a different chunk to each process</li>
</ul>
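<p>The four patterns differ mainly in which process ends up holding what. The following plain-Python simulation (no <code>torch.distributed</code> involved) models each rank's data as one list entry so the semantics are easy to compare. The function names are illustrative only; in real code the corresponding calls are <code>dist.all_reduce</code>, <code>dist.broadcast</code>, <code>dist.gather</code> and <code>dist.scatter</code>, which write their results into tensors you pass in.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">def all_reduce_sum(per_rank):
    """After all_reduce(SUM), every rank holds the total of all inputs."""
    total = sum(per_rank)
    return [total for _ in per_rank]


def broadcast(per_rank, src):
    """After broadcast, every rank holds a copy of rank src's value."""
    return [per_rank[src] for _ in per_rank]


def gather(per_rank, dst):
    """After gather, only rank dst holds the list of all values."""
    return [list(per_rank) if rank == dst else None for rank in range(len(per_rank))]


def scatter(chunks_on_src, world_size):
    """After scatter, rank i holds chunks_on_src[i] (one distinct piece each)."""
    if len(chunks_on_src) != world_size:
        raise ValueError("need exactly one chunk per rank")
    return list(chunks_on_src)


values = [1, 2, 3, 4]                    # one value per rank, world_size = 4
print(all_reduce_sum(values))            # [10, 10, 10, 10]
print(broadcast(values, src=0))          # [1, 1, 1, 1]
print(gather(values, dst=0))             # [[1, 2, 3, 4], None, None, None]
print(scatter(["a", "b", "c", "d"], 4))  # ['a', 'b', 'c', 'd']: rank 2 gets 'c'
</code></pre></div>
<p>All-reduce is the workhorse of data-parallel training: gradient averaging amounts to an all-reduce sum followed by division by the world size.</p>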
</section>
</section>
<section id="setup-and-initialization" class="level2">
<h2 class="anchored" data-anchor-id="setup-and-initialization" id="setup-and-initialization">Setup and Initialization</h2>
<section id="basic-environment-setup" class="level3">
<h3 class="anchored" data-anchor-id="basic-environment-setup" id="basic-environment-setup">Basic Environment Setup</h3>
<div id="06b7ab26" class="cell" data-execution_count="1">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb1"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> os</span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.distributed <span class="im">as</span> dist</span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.multiprocessing <span class="im">as</span> mp</span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.nn.parallel <span class="im">import</span> DistributedDataParallel <span class="im">as</span> DDP</span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.utils.data.distributed <span class="im">import</span> DistributedSampler</span>
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> setup_distributed(rank, world_size, backend<span class="op">=</span><span class="st">'nccl'</span>):</span>
<span id="cb1-9"><a href="#cb1-9" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Initialize distributed training environment"""</span></span>
<span id="cb1-10"><a href="#cb1-10" aria-hidden="true" tabindex="-1"></a>    os.environ[<span class="st">'MASTER_ADDR'</span>] <span class="op">=</span> <span class="st">'localhost'</span></span>
<span id="cb1-11"><a href="#cb1-11" aria-hidden="true" tabindex="-1"></a>    os.environ[<span class="st">'MASTER_PORT'</span>] <span class="op">=</span> <span class="st">'12355'</span></span>
<span id="cb1-12"><a href="#cb1-12" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb1-13"><a href="#cb1-13" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Initialize process group</span></span>
<span id="cb1-14"><a href="#cb1-14" aria-hidden="true" tabindex="-1"></a>    dist.init_process_group(</span>
<span id="cb1-15"><a href="#cb1-15" aria-hidden="true" tabindex="-1"></a>        backend<span class="op">=</span>backend,</span>
<span id="cb1-16"><a href="#cb1-16" aria-hidden="true" tabindex="-1"></a>        rank<span class="op">=</span>rank,</span>
<span id="cb1-17"><a href="#cb1-17" aria-hidden="true" tabindex="-1"></a>        world_size<span class="op">=</span>world_size</span>
<span id="cb1-18"><a href="#cb1-18" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb1-19"><a href="#cb1-19" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb1-20"><a href="#cb1-20" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Set device for current process</span></span>
<span id="cb1-21"><a href="#cb1-21" aria-hidden="true" tabindex="-1"></a>    torch.cuda.set_device(rank)</span>
<span id="cb1-22"><a href="#cb1-22" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-23"><a href="#cb1-23" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> cleanup_distributed():</span>
<span id="cb1-24"><a href="#cb1-24" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Clean up distributed training"""</span></span>
<span id="cb1-25"><a href="#cb1-25" aria-hidden="true" tabindex="-1"></a>    dist.destroy_process_group()</span></code></pre></div></div>
</div>
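<p>The hard-coded <code>MASTER_PORT</code> of <code>'12355'</code> fails if another job on the same machine already uses that port. One common workaround, sketched here under the assumption of a single-node run, is to ask the operating system for a free port in the parent process before calling <code>mp.spawn</code> and have <code>setup_distributed</code> use that value instead of the literal (for example, through an extra argument). <code>find_free_port</code> is a hypothetical helper, and there is a small race window between releasing the port and the workers binding it.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import socket


def find_free_port():
    """Ask the OS for a TCP port that is free right now on this machine."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("localhost", 0))   # port 0 means "let the OS choose"
        return sock.getsockname()[1]  # the port the OS actually assigned


port = find_free_port()
print(f"MASTER_PORT={port}")
</code></pre></div>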
</section>
<section id="multi-node-setup" class="level3">
<h3 class="anchored" data-anchor-id="multi-node-setup" id="multi-node-setup">Multi-Node Setup</h3>
<div id="6e6f4fee" class="cell" data-execution_count="2">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> setup_multinode(rank, world_size, master_addr, master_port):</span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Setup for multi-node distributed training"""</span></span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a>    os.environ[<span class="st">'MASTER_ADDR'</span>] <span class="op">=</span> master_addr</span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a>    os.environ[<span class="st">'MASTER_PORT'</span>] <span class="op">=</span> <span class="bu">str</span>(master_port)</span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a>    os.environ[<span class="st">'RANK'</span>] <span class="op">=</span> <span class="bu">str</span>(rank)</span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a>    os.environ[<span class="st">'WORLD_SIZE'</span>] <span class="op">=</span> <span class="bu">str</span>(world_size)</span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a>    dist.init_process_group(</span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a>        backend<span class="op">=</span><span class="st">'nccl'</span>,</span>
<span id="cb2-10"><a href="#cb2-10" aria-hidden="true" tabindex="-1"></a>        init_method<span class="op">=</span><span class="st">'env://'</span>,</span>
<span id="cb2-11"><a href="#cb2-11" aria-hidden="true" tabindex="-1"></a>        rank<span class="op">=</span>rank,</span>
<span id="cb2-12"><a href="#cb2-12" aria-hidden="true" tabindex="-1"></a>        world_size<span class="op">=</span>world_size</span>
<span id="cb2-13"><a href="#cb2-13" aria-hidden="true" tabindex="-1"></a>    )</span></code></pre></div></div>
</div>
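<p>In practice, multi-node jobs are usually started with <code>torchrun</code>, which exports <code>RANK</code>, <code>LOCAL_RANK</code>, <code>WORLD_SIZE</code>, <code>MASTER_ADDR</code> and <code>MASTER_PORT</code> for every worker, so each process can read its identity from the environment instead of receiving it as arguments. Below is a minimal sketch that assumes the script is launched by <code>torchrun</code>; the function name <code>setup_from_torchrun_env</code> is hypothetical. Unlike <code>setup_multinode</code> above, it also pins the GPU, and it does so with <code>LOCAL_RANK</code> because on a multi-node job the global rank can exceed the number of GPUs on a node.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import os

import torch
import torch.distributed as dist


def setup_from_torchrun_env():
    """Initialize the default process group from variables exported by torchrun."""
    rank = int(os.environ["RANK"])              # global rank across all nodes
    local_rank = int(os.environ["LOCAL_RANK"])  # rank within this node
    world_size = int(os.environ["WORLD_SIZE"])  # total number of processes

    if torch.cuda.is_available():
        # Pin the GPU by local rank; the global rank can exceed the GPUs per node
        torch.cuda.set_device(local_rank)
        backend = "nccl"
    else:
        backend = "gloo"  # CPU fallback, handy for debugging

    # env:// reads MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE from the environment
    dist.init_process_group(backend=backend, init_method="env://")
    return rank, local_rank, world_size
</code></pre></div>
<p>Each node would then run the same command with its own <code>--node_rank</code>, for example <code>torchrun --nnodes=2 --nproc_per_node=4 --node_rank=0 --master_addr=10.0.0.1 --master_port=29500 train.py</code>, where the address, port and script name are placeholders.</p>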
</section>
</section>
<section id="data-parallel-training" class="level2">
<h2 class="anchored" data-anchor-id="data-parallel-training" id="data-parallel-training">Data Parallel Training</h2>
<section id="simple-dataparallel-single-node" class="level3">
<h3 class="anchored" data-anchor-id="simple-dataparallel-single-node" id="simple-dataparallel-single-node">Simple DataParallel (Single Node)</h3>
<div id="2bd6d1d8" class="cell" data-execution_count="3">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> SimpleModel(nn.Module):</span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, input_size, hidden_size, num_classes):</span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.fc1 <span class="op">=</span> nn.Linear(input_size, hidden_size)</span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.relu <span class="op">=</span> nn.ReLU()</span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.fc2 <span class="op">=</span> nn.Linear(hidden_size, num_classes)</span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.fc1(x)</span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.relu(x)</span>
<span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.fc2(x)</span>
<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x</span>
<span id="cb3-15"><a href="#cb3-15" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-16"><a href="#cb3-16" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> train_dataparallel(train_loader, num_epochs<span class="op">=</span><span class="dv">10</span>):</span>
<span id="cb3-17"><a href="#cb3-17" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Basic DataParallel training"""</span></span>
<span id="cb3-18"><a href="#cb3-18" aria-hidden="true" tabindex="-1"></a>    device <span class="op">=</span> torch.device(<span class="st">"cuda"</span> <span class="cf">if</span> torch.cuda.is_available() <span class="cf">else</span> <span class="st">"cpu"</span>)</span>
<span id="cb3-19"><a href="#cb3-19" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-20"><a href="#cb3-20" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Create model and wrap with DataParallel</span></span>
<span id="cb3-21"><a href="#cb3-21" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> SimpleModel(<span class="dv">784</span>, <span class="dv">256</span>, <span class="dv">10</span>)</span>
<span id="cb3-22"><a href="#cb3-22" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> torch.cuda.device_count() <span class="op">&gt;</span> <span class="dv">1</span>:</span>
<span id="cb3-23"><a href="#cb3-23" aria-hidden="true" tabindex="-1"></a>        model <span class="op">=</span> nn.DataParallel(model)</span>
<span id="cb3-24"><a href="#cb3-24" aria-hidden="true" tabindex="-1"></a>    model.to(device)</span>
<span id="cb3-25"><a href="#cb3-25" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-26"><a href="#cb3-26" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Setup optimizer and loss</span></span>
<span id="cb3-27"><a href="#cb3-27" aria-hidden="true" tabindex="-1"></a>    optimizer <span class="op">=</span> torch.optim.Adam(model.parameters(), lr<span class="op">=</span><span class="fl">0.001</span>)</span>
<span id="cb3-28"><a href="#cb3-28" aria-hidden="true" tabindex="-1"></a>    criterion <span class="op">=</span> nn.CrossEntropyLoss()</span>
<span id="cb3-29"><a href="#cb3-29" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-30"><a href="#cb3-30" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Training loop</span></span>
<span id="cb3-31"><a href="#cb3-31" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(num_epochs):</span>
<span id="cb3-32"><a href="#cb3-32" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> batch_idx, (data, target) <span class="kw">in</span> <span class="bu">enumerate</span>(train_loader):</span>
<span id="cb3-33"><a href="#cb3-33" aria-hidden="true" tabindex="-1"></a>            data, target <span class="op">=</span> data.to(device), target.to(device)</span>
<span id="cb3-34"><a href="#cb3-34" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb3-35"><a href="#cb3-35" aria-hidden="true" tabindex="-1"></a>            optimizer.zero_grad()</span>
<span id="cb3-36"><a href="#cb3-36" aria-hidden="true" tabindex="-1"></a>            output <span class="op">=</span> model(data)</span>
<span id="cb3-37"><a href="#cb3-37" aria-hidden="true" tabindex="-1"></a>            loss <span class="op">=</span> criterion(output, target)</span>
<span id="cb3-38"><a href="#cb3-38" aria-hidden="true" tabindex="-1"></a>            loss.backward()</span>
<span id="cb3-39"><a href="#cb3-39" aria-hidden="true" tabindex="-1"></a>            optimizer.step()</span></code></pre></div></div>
</div>
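<p><code>nn.DataParallel</code> runs in a single process: on every forward pass it splits the input batch along dimension 0 across the visible GPUs, runs a model replica on each chunk, and gathers the outputs on the first device. The split appears to follow <code>torch.chunk</code> semantics (an assumption worth checking against your PyTorch version), and the plain-Python helper below reproduces that rule; <code>chunk_sizes</code> is a hypothetical name. It shows why batch sizes divisible by the GPU count keep every device equally busy.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import math


def chunk_sizes(batch_size, num_gpus):
    """Per-replica sizes when a batch is split like torch.chunk(batch, num_gpus, dim=0)."""
    if not (batch_size > 0 and num_gpus > 0):
        raise ValueError("batch_size and num_gpus must be positive")
    step = math.ceil(batch_size / num_gpus)  # every chunk except the last has this size
    return [min(step, batch_size - start) for start in range(0, batch_size, step)]


print(chunk_sizes(64, 4))  # [16, 16, 16, 16]
print(chunk_sizes(10, 4))  # [3, 3, 3, 1]
print(chunk_sizes(5, 4))   # [2, 2, 1]: only three of the four GPUs get work
</code></pre></div>
<p>PyTorch's own documentation recommends <code>DistributedDataParallel</code> over <code>DataParallel</code> even on a single machine, which is what the next section uses.</p>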
</section>
</section>
<section id="distributed-data-parallel-ddp" class="level2">
<h2 class="anchored" data-anchor-id="distributed-data-parallel-ddp" id="distributed-data-parallel-ddp">Distributed Data Parallel (DDP)</h2>
<section id="basic-ddp-implementation" class="level3">
<h3 class="anchored" data-anchor-id="basic-ddp-implementation" id="basic-ddp-implementation">Basic DDP Implementation</h3>
<div id="298ae1a6" class="cell" data-execution_count="4">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> train_ddp(rank, world_size):</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Distributed Data Parallel training function"""</span></span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Setup distributed environment</span></span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a>    setup_distributed(rank, world_size)</span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Create model and move to GPU</span></span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> SimpleModel(<span class="dv">784</span>, <span class="dv">256</span>, <span class="dv">10</span>).to(rank)</span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Wrap model with DDP</span></span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a>    ddp_model <span class="op">=</span> DDP(model, device_ids<span class="op">=</span>[rank])</span>
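<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a>    train_dataset <span class="op">=</span> torch.utils.data.TensorDataset(torch.randn(<span class="dv">1024</span>, <span class="dv">784</span>), torch.randint(<span class="dv">0</span>, <span class="dv">10</span>, (<span class="dv">1024</span>,)))  <span class="co"># synthetic stand-in data; replace with your own Dataset</span></span>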
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Setup distributed sampler</span></span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a>    train_sampler <span class="op">=</span> DistributedSampler(</span>
<span id="cb4-14"><a href="#cb4-14" aria-hidden="true" tabindex="-1"></a>        train_dataset,</span>
<span id="cb4-15"><a href="#cb4-15" aria-hidden="true" tabindex="-1"></a>        num_replicas<span class="op">=</span>world_size,</span>
<span id="cb4-16"><a href="#cb4-16" aria-hidden="true" tabindex="-1"></a>        rank<span class="op">=</span>rank,</span>
<span id="cb4-17"><a href="#cb4-17" aria-hidden="true" tabindex="-1"></a>        shuffle<span class="op">=</span><span class="va">True</span></span>
<span id="cb4-18"><a href="#cb4-18" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb4-19"><a href="#cb4-19" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-20"><a href="#cb4-20" aria-hidden="true" tabindex="-1"></a>    train_loader <span class="op">=</span> torch.utils.data.DataLoader(</span>
<span id="cb4-21"><a href="#cb4-21" aria-hidden="true" tabindex="-1"></a>        train_dataset,</span>
<span id="cb4-22"><a href="#cb4-22" aria-hidden="true" tabindex="-1"></a>        batch_size<span class="op">=</span>batch_size,</span>
<span id="cb4-23"><a href="#cb4-23" aria-hidden="true" tabindex="-1"></a>        sampler<span class="op">=</span>train_sampler,</span>
<span id="cb4-24"><a href="#cb4-24" aria-hidden="true" tabindex="-1"></a>        num_workers<span class="op">=</span><span class="dv">4</span>,</span>
<span id="cb4-25"><a href="#cb4-25" aria-hidden="true" tabindex="-1"></a>        pin_memory<span class="op">=</span><span class="va">True</span></span>
<span id="cb4-26"><a href="#cb4-26" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb4-27"><a href="#cb4-27" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-28"><a href="#cb4-28" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Setup optimizer and loss</span></span>
<span id="cb4-29"><a href="#cb4-29" aria-hidden="true" tabindex="-1"></a>    optimizer <span class="op">=</span> torch.optim.Adam(ddp_model.parameters(), lr<span class="op">=</span><span class="fl">0.001</span>)</span>
<span id="cb4-30"><a href="#cb4-30" aria-hidden="true" tabindex="-1"></a>    criterion <span class="op">=</span> nn.CrossEntropyLoss()</span>
<span id="cb4-31"><a href="#cb4-31" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-32"><a href="#cb4-32" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Training loop</span></span>
<span id="cb4-33"><a href="#cb4-33" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(num_epochs):</span>
<span id="cb4-34"><a href="#cb4-34" aria-hidden="true" tabindex="-1"></a>        train_sampler.set_epoch(epoch)  <span class="co"># Reseeds the shuffle so each epoch sees a different order</span></span>
<span id="cb4-35"><a href="#cb4-35" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-36"><a href="#cb4-36" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> batch_idx, (data, target) <span class="kw">in</span> <span class="bu">enumerate</span>(train_loader):</span>
<span id="cb4-37"><a href="#cb4-37" aria-hidden="true" tabindex="-1"></a>            data, target <span class="op">=</span> data.to(rank), target.to(rank)</span>
<span id="cb4-38"><a href="#cb4-38" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb4-39"><a href="#cb4-39" aria-hidden="true" tabindex="-1"></a>            optimizer.zero_grad()</span>
<span id="cb4-40"><a href="#cb4-40" aria-hidden="true" tabindex="-1"></a>            output <span class="op">=</span> ddp_model(data)</span>
<span id="cb4-41"><a href="#cb4-41" aria-hidden="true" tabindex="-1"></a>            loss <span class="op">=</span> criterion(output, target)</span>
<span id="cb4-42"><a href="#cb4-42" aria-hidden="true" tabindex="-1"></a>            loss.backward()</span>
<span id="cb4-43"><a href="#cb4-43" aria-hidden="true" tabindex="-1"></a>            optimizer.step()</span>
<span id="cb4-44"><a href="#cb4-44" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb4-45"><a href="#cb4-45" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> rank <span class="op">==</span> <span class="dv">0</span> <span class="kw">and</span> batch_idx <span class="op">%</span> <span class="dv">100</span> <span class="op">==</span> <span class="dv">0</span>:</span>
<span id="cb4-46"><a href="#cb4-46" aria-hidden="true" tabindex="-1"></a>                <span class="bu">print</span>(<span class="ss">f'Epoch: </span><span class="sc">{</span>epoch<span class="sc">}</span><span class="ss">, Batch: </span><span class="sc">{</span>batch_idx<span class="sc">}</span><span class="ss">, Loss: </span><span class="sc">{</span>loss<span class="sc">.</span>item()<span class="sc">:.4f}</span><span class="ss">'</span>)</span>
<span id="cb4-47"><a href="#cb4-47" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-48"><a href="#cb4-48" aria-hidden="true" tabindex="-1"></a>    cleanup_distributed()</span>
<span id="cb4-49"><a href="#cb4-49" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-50"><a href="#cb4-50" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> main():</span>
<span id="cb4-51"><a href="#cb4-51" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Main function to spawn distributed processes"""</span></span>
<span id="cb4-52"><a href="#cb4-52" aria-hidden="true" tabindex="-1"></a>    world_size <span class="op">=</span> torch.cuda.device_count()</span>
<span id="cb4-53"><a href="#cb4-53" aria-hidden="true" tabindex="-1"></a>    mp.spawn(train_ddp, args<span class="op">=</span>(world_size,), nprocs<span class="op">=</span>world_size, join<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb4-54"><a href="#cb4-54" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-55"><a href="#cb4-55" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb4-56"><a href="#cb4-56" aria-hidden="true" tabindex="-1"></a>    main()</span></code></pre></div></div>
</div>
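<p><code>DistributedSampler</code> is what keeps the ranks from training on the same samples. The sketch below approximates its indexing with <code>drop_last=False</code>: it optionally shuffles with a seed of <code>seed + epoch</code>, pads the index list by repeating indices until its length is divisible by the number of replicas, and hands each rank every <code>num_replicas</code>-th index starting at its own rank. This is a plain-Python model, not the PyTorch implementation: it uses <code>random.Random</code> instead of a <code>torch.Generator</code>, so the concrete shuffled order differs. It also shows why <code>set_epoch</code> matters: without it the epoch stays 0 and every epoch replays the same order.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import math
import random


def distributed_indices(dataset_len, num_replicas, rank, epoch=0, seed=0, shuffle=True):
    """Approximate the indices DistributedSampler yields for one rank (drop_last=False)."""
    indices = list(range(dataset_len))
    if shuffle:
        # The real sampler seeds a torch.Generator with seed + epoch; random.Random
        # is a stand-in here, so the concrete order differs from PyTorch's.
        random.Random(seed + epoch).shuffle(indices)
    # Pad by repeating indices so every rank gets the same number of samples
    total_size = math.ceil(dataset_len / num_replicas) * num_replicas
    while len(indices) != total_size:
        indices += indices[: total_size - len(indices)]
    # Rank r takes every num_replicas-th index, starting at offset r
    return indices[rank:total_size:num_replicas]


for r in range(4):
    print(r, distributed_indices(10, num_replicas=4, rank=r, shuffle=False))
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6, 0]
# 3 [3, 7, 1]

# Without set_epoch(), epoch stays 0 and the shuffled order never changes
same = distributed_indices(100, 4, 0, epoch=0) == distributed_indices(100, 4, 0, epoch=0)
changed = distributed_indices(100, 4, 0, epoch=0) != distributed_indices(100, 4, 0, epoch=1)
print(same, changed)
</code></pre></div>
<p>Note the padding in the unshuffled example: ranks 2 and 3 each receive one index (0 and 1) that rank 0 or 1 also sees, so with <code>drop_last=False</code> a few samples can be counted twice per epoch.</p>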
</section>
<section id="complete-training-script-with-validation" class="level3">
<h3 class="anchored" data-anchor-id="complete-training-script-with-validation" id="complete-training-script-with-validation">Complete Training Script with Validation</h3>
<div id="1bc3c20c" class="cell" data-execution_count="5">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.utils.tensorboard <span class="im">import</span> SummaryWriter</span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> DistributedTrainer:</span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, model, rank, world_size, train_loader, val_loader<span class="op">=</span><span class="va">None</span>):</span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> model</span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.rank <span class="op">=</span> rank</span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.world_size <span class="op">=</span> world_size</span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.train_loader <span class="op">=</span> train_loader</span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.val_loader <span class="op">=</span> val_loader</span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Setup DDP</span></span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.ddp_model <span class="op">=</span> DDP(model, device_ids<span class="op">=</span>[rank])</span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Setup optimizer and scheduler</span></span>
<span id="cb5-16"><a href="#cb5-16" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.optimizer <span class="op">=</span> torch.optim.AdamW(</span>
<span id="cb5-17"><a href="#cb5-17" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.ddp_model.parameters(),</span>
<span id="cb5-18"><a href="#cb5-18" aria-hidden="true" tabindex="-1"></a>            lr<span class="op">=</span><span class="fl">0.001</span>,</span>
<span id="cb5-19"><a href="#cb5-19" aria-hidden="true" tabindex="-1"></a>            weight_decay<span class="op">=</span><span class="fl">0.01</span></span>
<span id="cb5-20"><a href="#cb5-20" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb5-21"><a href="#cb5-21" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.scheduler <span class="op">=</span> torch.optim.lr_scheduler.CosineAnnealingLR(</span>
<span id="cb5-22"><a href="#cb5-22" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.optimizer, T_max<span class="op">=</span><span class="dv">100</span></span>
<span id="cb5-23"><a href="#cb5-23" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb5-24"><a href="#cb5-24" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.criterion <span class="op">=</span> nn.CrossEntropyLoss()</span>
<span id="cb5-25"><a href="#cb5-25" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-26"><a href="#cb5-26" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Logging (only on rank 0)</span></span>
<span id="cb5-27"><a href="#cb5-27" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> rank <span class="op">==</span> <span class="dv">0</span>:</span>
<span id="cb5-28"><a href="#cb5-28" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.writer <span class="op">=</span> SummaryWriter(<span class="st">'runs/distributed_training'</span>)</span>
<span id="cb5-29"><a href="#cb5-29" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-30"><a href="#cb5-30" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> train_epoch(<span class="va">self</span>, epoch):</span>
<span id="cb5-31"><a href="#cb5-31" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Train for one epoch"""</span></span>
<span id="cb5-32"><a href="#cb5-32" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.ddp_model.train()</span>
<span id="cb5-33"><a href="#cb5-33" aria-hidden="true" tabindex="-1"></a>        total_loss <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb5-34"><a href="#cb5-34" aria-hidden="true" tabindex="-1"></a>        num_batches <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb5-35"><a href="#cb5-35" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-36"><a href="#cb5-36" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> batch_idx, (data, target) <span class="kw">in</span> <span class="bu">enumerate</span>(<span class="va">self</span>.train_loader):</span>
<span id="cb5-37"><a href="#cb5-37" aria-hidden="true" tabindex="-1"></a>            data, target <span class="op">=</span> data.to(<span class="va">self</span>.rank), target.to(<span class="va">self</span>.rank)</span>
<span id="cb5-38"><a href="#cb5-38" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb5-39"><a href="#cb5-39" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.optimizer.zero_grad()</span>
<span id="cb5-40"><a href="#cb5-40" aria-hidden="true" tabindex="-1"></a>            output <span class="op">=</span> <span class="va">self</span>.ddp_model(data)</span>
<span id="cb5-41"><a href="#cb5-41" aria-hidden="true" tabindex="-1"></a>            loss <span class="op">=</span> <span class="va">self</span>.criterion(output, target)</span>
<span id="cb5-42"><a href="#cb5-42" aria-hidden="true" tabindex="-1"></a>            loss.backward()</span>
<span id="cb5-43"><a href="#cb5-43" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb5-44"><a href="#cb5-44" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Gradient clipping</span></span>
<span id="cb5-45"><a href="#cb5-45" aria-hidden="true" tabindex="-1"></a>            torch.nn.utils.clip_grad_norm_(<span class="va">self</span>.ddp_model.parameters(), max_norm<span class="op">=</span><span class="fl">1.0</span>)</span>
<span id="cb5-46"><a href="#cb5-46" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb5-47"><a href="#cb5-47" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.optimizer.step()</span>
<span id="cb5-48"><a href="#cb5-48" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb5-49"><a href="#cb5-49" aria-hidden="true" tabindex="-1"></a>            total_loss <span class="op">+=</span> loss.item()</span>
<span id="cb5-50"><a href="#cb5-50" aria-hidden="true" tabindex="-1"></a>            num_batches <span class="op">+=</span> <span class="dv">1</span></span>
<span id="cb5-51"><a href="#cb5-51" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb5-52"><a href="#cb5-52" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="va">self</span>.rank <span class="op">==</span> <span class="dv">0</span> <span class="kw">and</span> batch_idx <span class="op">%</span> <span class="dv">100</span> <span class="op">==</span> <span class="dv">0</span>:</span>
<span id="cb5-53"><a href="#cb5-53" aria-hidden="true" tabindex="-1"></a>                <span class="bu">print</span>(<span class="ss">f'Epoch: </span><span class="sc">{</span>epoch<span class="sc">}</span><span class="ss">, Batch: </span><span class="sc">{</span>batch_idx<span class="sc">}</span><span class="ss">, Loss: </span><span class="sc">{</span>loss<span class="sc">.</span>item()<span class="sc">:.4f}</span><span class="ss">'</span>)</span>
<span id="cb5-54"><a href="#cb5-54" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-55"><a href="#cb5-55" aria-hidden="true" tabindex="-1"></a>        avg_loss <span class="op">=</span> total_loss <span class="op">/</span> num_batches</span>
<span id="cb5-56"><a href="#cb5-56" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> avg_loss</span>
<span id="cb5-57"><a href="#cb5-57" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-58"><a href="#cb5-58" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> validate(<span class="va">self</span>):</span>
<span id="cb5-59"><a href="#cb5-59" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Validate the model"""</span></span>
<span id="cb5-60"><a href="#cb5-60" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="va">self</span>.val_loader <span class="kw">is</span> <span class="va">None</span>:</span>
<span id="cb5-61"><a href="#cb5-61" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> <span class="va">None</span></span>
<span id="cb5-62"><a href="#cb5-62" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb5-63"><a href="#cb5-63" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.ddp_model.<span class="bu">eval</span>()</span>
<span id="cb5-64"><a href="#cb5-64" aria-hidden="true" tabindex="-1"></a>        total_loss <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb5-65"><a href="#cb5-65" aria-hidden="true" tabindex="-1"></a>        correct <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb5-66"><a href="#cb5-66" aria-hidden="true" tabindex="-1"></a>        total <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb5-67"><a href="#cb5-67" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-68"><a href="#cb5-68" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> torch.no_grad():</span>
<span id="cb5-69"><a href="#cb5-69" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> data, target <span class="kw">in</span> <span class="va">self</span>.val_loader:</span>
<span id="cb5-70"><a href="#cb5-70" aria-hidden="true" tabindex="-1"></a>                data, target <span class="op">=</span> data.to(<span class="va">self</span>.rank), target.to(<span class="va">self</span>.rank)</span>
<span id="cb5-71"><a href="#cb5-71" aria-hidden="true" tabindex="-1"></a>                output <span class="op">=</span> <span class="va">self</span>.ddp_model(data)</span>
<span id="cb5-72"><a href="#cb5-72" aria-hidden="true" tabindex="-1"></a>                loss <span class="op">=</span> <span class="va">self</span>.criterion(output, target)</span>
<span id="cb5-73"><a href="#cb5-73" aria-hidden="true" tabindex="-1"></a>                <span class="co"># criterion returns the batch mean, so weight it by the batch size</span></span>
<span id="cb5-74"><a href="#cb5-74" aria-hidden="true" tabindex="-1"></a>                total_loss <span class="op">+=</span> loss.item() <span class="op">*</span> target.size(<span class="dv">0</span>)</span>
<span id="cb5-75"><a href="#cb5-75" aria-hidden="true" tabindex="-1"></a>                pred <span class="op">=</span> output.argmax(dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb5-76"><a href="#cb5-76" aria-hidden="true" tabindex="-1"></a>                correct <span class="op">+=</span> pred.eq(target).<span class="bu">sum</span>().item()</span>
<span id="cb5-77"><a href="#cb5-77" aria-hidden="true" tabindex="-1"></a>                total <span class="op">+=</span> target.size(<span class="dv">0</span>)</span>
<span id="cb5-78"><a href="#cb5-78" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-79"><a href="#cb5-79" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Sum raw totals across all processes, then divide once</span></span>
<span id="cb5-80"><a href="#cb5-80" aria-hidden="true" tabindex="-1"></a>        total_loss_tensor <span class="op">=</span> torch.tensor(total_loss).to(<span class="va">self</span>.rank)</span>
<span id="cb5-81"><a href="#cb5-81" aria-hidden="true" tabindex="-1"></a>        correct_tensor <span class="op">=</span> torch.tensor(correct).to(<span class="va">self</span>.rank)</span>
<span id="cb5-82"><a href="#cb5-82" aria-hidden="true" tabindex="-1"></a>        total_tensor <span class="op">=</span> torch.tensor(total).to(<span class="va">self</span>.rank)</span>
<span id="cb5-83"><a href="#cb5-83" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-84"><a href="#cb5-84" aria-hidden="true" tabindex="-1"></a>        dist.all_reduce(total_loss_tensor, op<span class="op">=</span>dist.ReduceOp.SUM)</span>
<span id="cb5-85"><a href="#cb5-85" aria-hidden="true" tabindex="-1"></a>        dist.all_reduce(correct_tensor, op<span class="op">=</span>dist.ReduceOp.SUM)</span>
<span id="cb5-86"><a href="#cb5-86" aria-hidden="true" tabindex="-1"></a>        dist.all_reduce(total_tensor, op<span class="op">=</span>dist.ReduceOp.SUM)</span>
<span id="cb5-87"><a href="#cb5-87" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-88"><a href="#cb5-88" aria-hidden="true" tabindex="-1"></a>        avg_loss <span class="op">=</span> total_loss_tensor.item() <span class="op">/</span> total_tensor.item()</span>
<span id="cb5-89"><a href="#cb5-89" aria-hidden="true" tabindex="-1"></a>        accuracy <span class="op">=</span> correct_tensor.item() <span class="op">/</span> total_tensor.item()</span>
<span id="cb5-90"><a href="#cb5-90" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-91"><a href="#cb5-91" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> avg_loss, accuracy</span>
<span id="cb5-92"><a href="#cb5-92" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-93"><a href="#cb5-93" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> save_checkpoint(<span class="va">self</span>, epoch, loss):</span>
<span id="cb5-94"><a href="#cb5-94" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Save model checkpoint (only on rank 0)"""</span></span>
<span id="cb5-95"><a href="#cb5-95" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="va">self</span>.rank <span class="op">==</span> <span class="dv">0</span>:</span>
<span id="cb5-96"><a href="#cb5-96" aria-hidden="true" tabindex="-1"></a>            checkpoint <span class="op">=</span> {</span>
<span id="cb5-97"><a href="#cb5-97" aria-hidden="true" tabindex="-1"></a>                <span class="st">'epoch'</span>: epoch,</span>
<span id="cb5-98"><a href="#cb5-98" aria-hidden="true" tabindex="-1"></a>                <span class="st">'model_state_dict'</span>: <span class="va">self</span>.ddp_model.module.state_dict(),</span>
<span id="cb5-99"><a href="#cb5-99" aria-hidden="true" tabindex="-1"></a>                <span class="st">'optimizer_state_dict'</span>: <span class="va">self</span>.optimizer.state_dict(),</span>
<span id="cb5-100"><a href="#cb5-100" aria-hidden="true" tabindex="-1"></a>                <span class="st">'scheduler_state_dict'</span>: <span class="va">self</span>.scheduler.state_dict(),</span>
<span id="cb5-101"><a href="#cb5-101" aria-hidden="true" tabindex="-1"></a>                <span class="st">'loss'</span>: loss,</span>
<span id="cb5-102"><a href="#cb5-102" aria-hidden="true" tabindex="-1"></a>            }</span>
<span id="cb5-103"><a href="#cb5-103" aria-hidden="true" tabindex="-1"></a>            torch.save(checkpoint, <span class="ss">f'checkpoint_epoch_</span><span class="sc">{</span>epoch<span class="sc">}</span><span class="ss">.pth'</span>)</span>
<span id="cb5-104"><a href="#cb5-104" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-105"><a href="#cb5-105" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> train(<span class="va">self</span>, num_epochs):</span>
<span id="cb5-106"><a href="#cb5-106" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Complete training loop"""</span></span>
<span id="cb5-107"><a href="#cb5-107" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(num_epochs):</span>
<span id="cb5-108"><a href="#cb5-108" aria-hidden="true" tabindex="-1"></a>            start_time <span class="op">=</span> time.time()</span>
<span id="cb5-109"><a href="#cb5-109" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb5-110"><a href="#cb5-110" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Set epoch for distributed sampler</span></span>
<span id="cb5-111"><a href="#cb5-111" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="bu">hasattr</span>(<span class="va">self</span>.train_loader.sampler, <span class="st">'set_epoch'</span>):</span>
<span id="cb5-112"><a href="#cb5-112" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.train_loader.sampler.set_epoch(epoch)</span>
<span id="cb5-113"><a href="#cb5-113" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb5-114"><a href="#cb5-114" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Train</span></span>
<span id="cb5-115"><a href="#cb5-115" aria-hidden="true" tabindex="-1"></a>            train_loss <span class="op">=</span> <span class="va">self</span>.train_epoch(epoch)</span>
<span id="cb5-116"><a href="#cb5-116" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb5-117"><a href="#cb5-117" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Validate</span></span>
<span id="cb5-118"><a href="#cb5-118" aria-hidden="true" tabindex="-1"></a>            val_metrics <span class="op">=</span> <span class="va">self</span>.validate()</span>
<span id="cb5-119"><a href="#cb5-119" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb5-120"><a href="#cb5-120" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Step scheduler</span></span>
<span id="cb5-121"><a href="#cb5-121" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.scheduler.step()</span>
<span id="cb5-122"><a href="#cb5-122" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb5-123"><a href="#cb5-123" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Logging and checkpointing (rank 0 only)</span></span>
<span id="cb5-124"><a href="#cb5-124" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="va">self</span>.rank <span class="op">==</span> <span class="dv">0</span>:</span>
<span id="cb5-125"><a href="#cb5-125" aria-hidden="true" tabindex="-1"></a>                epoch_time <span class="op">=</span> time.time() <span class="op">-</span> start_time</span>
<span id="cb5-126"><a href="#cb5-126" aria-hidden="true" tabindex="-1"></a>                <span class="bu">print</span>(<span class="ss">f'Epoch </span><span class="sc">{</span>epoch<span class="sc">}</span><span class="ss">: Train Loss: </span><span class="sc">{</span>train_loss<span class="sc">:.4f}</span><span class="ss">, '</span></span>
<span id="cb5-127"><a href="#cb5-127" aria-hidden="true" tabindex="-1"></a>                      <span class="ss">f'Time: </span><span class="sc">{</span>epoch_time<span class="sc">:.2f}</span><span class="ss">s'</span>)</span>
<span id="cb5-128"><a href="#cb5-128" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb5-129"><a href="#cb5-129" aria-hidden="true" tabindex="-1"></a>                <span class="co"># TensorBoard logging</span></span>
<span id="cb5-130"><a href="#cb5-130" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.writer.add_scalar(<span class="st">'Loss/Train'</span>, train_loss, epoch)</span>
<span id="cb5-131"><a href="#cb5-131" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb5-132"><a href="#cb5-132" aria-hidden="true" tabindex="-1"></a>                <span class="cf">if</span> val_metrics:</span>
<span id="cb5-133"><a href="#cb5-133" aria-hidden="true" tabindex="-1"></a>                    val_loss, val_acc <span class="op">=</span> val_metrics</span>
<span id="cb5-134"><a href="#cb5-134" aria-hidden="true" tabindex="-1"></a>                    <span class="bu">print</span>(<span class="ss">f'Val Loss: </span><span class="sc">{</span>val_loss<span class="sc">:.4f}</span><span class="ss">, Val Acc: </span><span class="sc">{</span>val_acc<span class="sc">:.4f}</span><span class="ss">'</span>)</span>
<span id="cb5-135"><a href="#cb5-135" aria-hidden="true" tabindex="-1"></a>                    <span class="va">self</span>.writer.add_scalar(<span class="st">'Loss/Val'</span>, val_loss, epoch)</span>
<span id="cb5-136"><a href="#cb5-136" aria-hidden="true" tabindex="-1"></a>                    <span class="va">self</span>.writer.add_scalar(<span class="st">'Accuracy/Val'</span>, val_acc, epoch)</span>
<span id="cb5-137"><a href="#cb5-137" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb5-138"><a href="#cb5-138" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Save checkpoint</span></span>
<span id="cb5-139"><a href="#cb5-139" aria-hidden="true" tabindex="-1"></a>                <span class="cf">if</span> epoch <span class="op">%</span> <span class="dv">10</span> <span class="op">==</span> <span class="dv">0</span>:</span>
<span id="cb5-140"><a href="#cb5-140" aria-hidden="true" tabindex="-1"></a>                    <span class="va">self</span>.save_checkpoint(epoch, train_loss)</span></code></pre></div></div>
</div>
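<p>The validation step aggregates metrics by summing raw counters across ranks with <code>all_reduce</code> and dividing at the end. The pure-Python sketch below imitates that pattern without a process group, under the assumption that each <code>RankMetrics</code> record stands in for one rank's local totals and that the built-in <code>sum</code> plays the role of <code>dist.all_reduce(..., op=dist.ReduceOp.SUM)</code>. <code>RankMetrics</code>, <code>all_reduce_sum</code> and <code>global_metrics</code> are hypothetical names used only for illustration, and the numbers are made up.</p>

```python
from dataclasses import dataclass


@dataclass
class RankMetrics:
    """Local validation statistics for one simulated rank."""
    loss_sum: float  # sum of per-sample losses seen by this rank
    correct: int     # correctly classified samples on this rank
    total: int       # number of samples seen by this rank


def all_reduce_sum(per_rank):
    """Stand-in for dist.all_reduce(op=SUM): combine every rank's counters."""
    return RankMetrics(
        loss_sum=sum(m.loss_sum for m in per_rank),
        correct=sum(m.correct for m in per_rank),
        total=sum(m.total for m in per_rank),
    )


def global_metrics(per_rank):
    """Mean per-sample loss and accuracy over all ranks, dividing only once."""
    reduced = all_reduce_sum(per_rank)
    return reduced.loss_sum / reduced.total, reduced.correct / reduced.total


# Made-up numbers with deliberately uneven shard sizes
ranks = [RankMetrics(loss_sum=60.0, correct=90, total=100),
         RankMetrics(loss_sum=10.0, correct=40, total=50)]
avg_loss, accuracy = global_metrics(ranks)
print(f"loss={avg_loss:.4f} acc={accuracy:.4f}")  # loss=0.4667 acc=0.8667
```

<p>Dividing the summed loss by the world size instead (70 / 2 = 35 here) does not give a per-sample mean; dividing the summed per-sample loss by the global sample count does.</p>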
</section>
</section>
<section id="advanced-patterns" class="level2">
<h2 class="anchored" data-anchor-id="advanced-patterns" id="advanced-patterns">Advanced Patterns</h2>
<section id="mixed-precision-training" class="level3">
<h3 class="anchored" data-anchor-id="mixed-precision-training" id="mixed-precision-training">Mixed Precision Training</h3>
<div id="ba181961" class="cell" data-execution_count="6">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.amp <span class="im">import</span> GradScaler, autocast  <span class="co"># device-generic AMP API; torch.cuda.amp is deprecated</span></span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MixedPrecisionTrainer(DistributedTrainer):</span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, <span class="op">*</span>args, <span class="op">**</span>kwargs):</span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>(<span class="op">*</span>args, <span class="op">**</span>kwargs)</span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.scaler <span class="op">=</span> GradScaler(<span class="st">'cuda'</span>)</span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> train_epoch(<span class="va">self</span>, epoch):</span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Train with mixed precision"""</span></span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.ddp_model.train()</span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a>        total_loss <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a>        num_batches <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> batch_idx, (data, target) <span class="kw">in</span> <span class="bu">enumerate</span>(<span class="va">self</span>.train_loader):</span>
<span id="cb6-15"><a href="#cb6-15" aria-hidden="true" tabindex="-1"></a>            data, target <span class="op">=</span> data.to(<span class="va">self</span>.rank), target.to(<span class="va">self</span>.rank)</span>
<span id="cb6-16"><a href="#cb6-16" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb6-17"><a href="#cb6-17" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.optimizer.zero_grad()</span>
<span id="cb6-18"><a href="#cb6-18" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb6-19"><a href="#cb6-19" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Forward pass with autocast</span></span>
<span id="cb6-20"><a href="#cb6-20" aria-hidden="true" tabindex="-1"></a>            <span class="cf">with</span> autocast(<span class="st">'cuda'</span>):</span>
<span id="cb6-21"><a href="#cb6-21" aria-hidden="true" tabindex="-1"></a>                output <span class="op">=</span> <span class="va">self</span>.ddp_model(data)</span>
<span id="cb6-22"><a href="#cb6-22" aria-hidden="true" tabindex="-1"></a>                loss <span class="op">=</span> <span class="va">self</span>.criterion(output, target)</span>
<span id="cb6-23"><a href="#cb6-23" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb6-24"><a href="#cb6-24" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Backward pass with scaled gradients</span></span>
<span id="cb6-25"><a href="#cb6-25" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.scaler.scale(loss).backward()</span>
<span id="cb6-26"><a href="#cb6-26" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb6-27"><a href="#cb6-27" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Unscale first so clipping applies to the true gradient norm</span></span>
<span id="cb6-28"><a href="#cb6-28" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.scaler.unscale_(<span class="va">self</span>.optimizer)</span>
<span id="cb6-29"><a href="#cb6-29" aria-hidden="true" tabindex="-1"></a>            torch.nn.utils.clip_grad_norm_(<span class="va">self</span>.ddp_model.parameters(), max_norm<span class="op">=</span><span class="fl">1.0</span>)</span>
<span id="cb6-30"><a href="#cb6-30" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb6-31"><a href="#cb6-31" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Optimizer step with scaler</span></span>
<span id="cb6-32"><a href="#cb6-32" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.scaler.step(<span class="va">self</span>.optimizer)</span>
<span id="cb6-33"><a href="#cb6-33" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.scaler.update()</span>
<span id="cb6-34"><a href="#cb6-34" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb6-35"><a href="#cb6-35" aria-hidden="true" tabindex="-1"></a>            total_loss <span class="op">+=</span> loss.item()</span>
<span id="cb6-36"><a href="#cb6-36" aria-hidden="true" tabindex="-1"></a>            num_batches <span class="op">+=</span> <span class="dv">1</span></span>
<span id="cb6-37"><a href="#cb6-37" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-38"><a href="#cb6-38" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> total_loss <span class="op">/</span> num_batches</span></code></pre></div></div>
</div>
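<p>The <code>GradScaler</code> calls above implement dynamic loss scaling: the loss is multiplied by a scale factor before <code>backward()</code> so that small float16 gradients do not underflow, steps whose gradients contain inf/NaN are skipped, and the factor is adjusted over time. The pure-Python class below is a simplified sketch of that policy, not PyTorch's implementation. It borrows <code>GradScaler</code>'s documented defaults (initial scale 2<sup>16</sup>, growth factor 2.0, backoff factor 0.5, growth interval 2000); <code>ToyLossScaler</code> is a hypothetical name, and plain floats stand in for gradient tensors.</p>

```python
import math


class ToyLossScaler:
    """Simplified dynamic loss scaling; an illustration, not torch.amp.GradScaler."""

    def __init__(self, init_scale=2.0**16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._clean_steps = 0  # consecutive steps without inf/NaN gradients

    def step(self, scaled_grads):
        """Unscale gradients; return them, or None if the step must be skipped."""
        grads = [g / self.scale for g in scaled_grads]
        found_inf = any(math.isinf(g) or math.isnan(g) for g in grads)
        self._update(found_inf)
        return None if found_inf else grads

    def _update(self, found_inf):
        if found_inf:
            self.scale *= self.backoff_factor  # overflow: shrink the scale
            self._clean_steps = 0
        else:
            self._clean_steps += 1
            if self._clean_steps == self.growth_interval:
                self.scale *= self.growth_factor  # long clean run: try a larger scale
                self._clean_steps = 0


scaler = ToyLossScaler(growth_interval=3)  # short interval just for the demo
print(scaler.step([float("inf")]), scaler.scale)  # None 32768.0
for _ in range(3):
    grads = scaler.step([scaler.scale * 0.5])  # scaled gradient whose true value is 0.5
print(grads, scaler.scale)  # [0.5] 65536.0
```

<p>Backing off quickly and growing slowly keeps the scale near the largest value that does not overflow, which is why a few skipped steps early in training are expected rather than a sign of a bug.</p>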
</section>
<section id="model-sharding-with-fsdp" class="level3">
<h3 class="anchored" data-anchor-id="model-sharding-with-fsdp" id="model-sharding-with-fsdp">Model Sharding with FSDP</h3>
<div id="82bb1345" class="cell" data-execution_count="7">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> functools</span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.distributed.fsdp <span class="im">import</span> FullyShardedDataParallel <span class="im">as</span> FSDP</span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.distributed.fsdp <span class="im">import</span> MixedPrecision, ShardingStrategy</span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.distributed.fsdp.wrap <span class="im">import</span> size_based_auto_wrap_policy</span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> create_fsdp_model(model, rank):</span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Create FSDP wrapped model"""</span></span>
<span id="cb7-10"><a href="#cb7-10" aria-hidden="true" tabindex="-1"></a>    <span class="co"># FSDP calls the policy itself with (module, recurse, nonwrapped_numel), so bind the threshold with partial()</span></span>
<span id="cb7-11"><a href="#cb7-11" aria-hidden="true" tabindex="-1"></a>    wrap_policy <span class="op">=</span> functools.partial(size_based_auto_wrap_policy, min_num_params<span class="op">=</span><span class="dv">100_000</span>)</span>
<span id="cb7-12"><a href="#cb7-12" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-13"><a href="#cb7-13" aria-hidden="true" tabindex="-1"></a>    fsdp_model <span class="op">=</span> FSDP(</span>
<span id="cb7-14"><a href="#cb7-14" aria-hidden="true" tabindex="-1"></a>        model,</span>
<span id="cb7-15"><a href="#cb7-15" aria-hidden="true" tabindex="-1"></a>        auto_wrap_policy<span class="op">=</span>wrap_policy,</span>
<span id="cb7-16"><a href="#cb7-16" aria-hidden="true" tabindex="-1"></a>        <span class="co"># float16 needs loss scaling: use ShardedGradScaler from</span></span>
<span id="cb7-17"><a href="#cb7-17" aria-hidden="true" tabindex="-1"></a>        <span class="co"># torch.distributed.fsdp.sharded_grad_scaler (bfloat16 usually needs no scaler)</span></span>
<span id="cb7-18"><a href="#cb7-18" aria-hidden="true" tabindex="-1"></a>        mixed_precision<span class="op">=</span>MixedPrecision(</span>
<span id="cb7-19"><a href="#cb7-19" aria-hidden="true" tabindex="-1"></a>            param_dtype<span class="op">=</span>torch.float16,</span>
<span id="cb7-20"><a href="#cb7-20" aria-hidden="true" tabindex="-1"></a>            reduce_dtype<span class="op">=</span>torch.float16,</span>
<span id="cb7-21"><a href="#cb7-21" aria-hidden="true" tabindex="-1"></a>            buffer_dtype<span class="op">=</span>torch.float16</span>
<span id="cb7-22"><a href="#cb7-22" aria-hidden="true" tabindex="-1"></a>        ),</span>
<span id="cb7-23"><a href="#cb7-23" aria-hidden="true" tabindex="-1"></a>        device_id<span class="op">=</span>rank,  <span class="co"># local GPU index; pass LOCAL_RANK on multi-node jobs</span></span>
<span id="cb7-24"><a href="#cb7-24" aria-hidden="true" tabindex="-1"></a>        sync_module_states<span class="op">=</span><span class="va">True</span>,  <span class="co"># broadcast rank 0's parameters and buffers to all ranks</span></span>
<span id="cb7-25"><a href="#cb7-25" aria-hidden="true" tabindex="-1"></a>        sharding_strategy<span class="op">=</span>ShardingStrategy.FULL_SHARD</span>
<span id="cb7-26"><a href="#cb7-26" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb7-27"><a href="#cb7-27" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-28"><a href="#cb7-28" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> fsdp_model</span></code></pre></div></div>
</div>
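<p>A back-of-the-envelope count shows what <code>FULL_SHARD</code> buys. The sketch below assumes the accounting popularised by the ZeRO paper for mixed-precision Adam, roughly 16 bytes of model state per parameter, and assumes full sharding divides all of it evenly across <code>world_size</code> ranks. It ignores activations, the temporary all-gather buffers FSDP allocates during forward and backward, and allocator overhead, so treat its numbers as lower bounds; <code>model_state_gib</code> is a hypothetical helper.</p>

```python
def model_state_gib(num_params, world_size=1, sharded=False, bytes_per_param=16):
    """Approximate per-GPU memory (GiB) for parameters, gradients and optimizer state.

    The default of 16 bytes/parameter is the mixed-precision Adam accounting:
    2 (fp16 params) + 2 (fp16 grads) + 12 (fp32 master params, momentum, variance).
    Activations and communication buffers are ignored.
    """
    total_bytes = num_params * bytes_per_param
    if sharded:
        total_bytes /= world_size  # FULL_SHARD spreads all three across ranks
    return total_bytes / 2**30


n = 1_000_000_000  # a hypothetical 1B-parameter model
print(f"no sharding:          {model_state_gib(n):.1f} GiB per GPU")  # 14.9
print(f"FULL_SHARD on 8 GPUs: {model_state_gib(n, 8, sharded=True):.1f} GiB per GPU")  # 1.9
```

<p>The saving scales with the number of ranks, which is what makes models that do not fit on a single GPU trainable, at the price of extra all-gather and reduce-scatter traffic on every step.</p>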
</section>
<section id="pipeline-parallelism" class="level3">
<h3 class="anchored" data-anchor-id="pipeline-parallelism" id="pipeline-parallelism">Pipeline Parallelism</h3>
<div id="eae3f417" class="cell" data-execution_count="8">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb8"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> os</span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.distributed.rpc <span class="im">as</span> rpc</span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Deprecated; removed in PyTorch 2.4 in favor of torch.distributed.pipelining</span></span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.distributed.pipeline.sync <span class="im">import</span> Pipe</span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> init_pipeline_rpc():</span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Call once before building PipelineModel: Pipe relies on the RPC framework"""</span></span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a>    os.environ.setdefault(<span class="st">'MASTER_ADDR'</span>, <span class="st">'localhost'</span>)</span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a>    os.environ.setdefault(<span class="st">'MASTER_PORT'</span>, <span class="st">'29500'</span>)</span>
<span id="cb8-12"><a href="#cb8-12" aria-hidden="true" tabindex="-1"></a>    rpc.init_rpc(<span class="st">'worker'</span>, rank<span class="op">=</span><span class="dv">0</span>, world_size<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb8-13"><a href="#cb8-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-14"><a href="#cb8-14" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> PipelineModel(nn.Module):</span>
<span id="cb8-15"><a href="#cb8-15" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, chunks<span class="op">=</span><span class="dv">8</span>):</span>
<span id="cb8-16"><a href="#cb8-16" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb8-17"><a href="#cb8-17" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-18"><a href="#cb8-18" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Pipe infers partitions from device placement: each child of the</span></span>
<span id="cb8-19"><a href="#cb8-19" aria-hidden="true" tabindex="-1"></a>        <span class="co"># outer nn.Sequential must keep all of its parameters on one device</span></span>
<span id="cb8-20"><a href="#cb8-20" aria-hidden="true" tabindex="-1"></a>        stage0 <span class="op">=</span> nn.Sequential(</span>
<span id="cb8-21"><a href="#cb8-21" aria-hidden="true" tabindex="-1"></a>            nn.Linear(<span class="dv">784</span>, <span class="dv">512</span>), nn.ReLU(),</span>
<span id="cb8-22"><a href="#cb8-22" aria-hidden="true" tabindex="-1"></a>            nn.Linear(<span class="dv">512</span>, <span class="dv">256</span>), nn.ReLU(),</span>
<span id="cb8-23"><a href="#cb8-23" aria-hidden="true" tabindex="-1"></a>        ).to(<span class="st">'cuda:0'</span>)</span>
<span id="cb8-24"><a href="#cb8-24" aria-hidden="true" tabindex="-1"></a>        stage1 <span class="op">=</span> nn.Sequential(</span>
<span id="cb8-25"><a href="#cb8-25" aria-hidden="true" tabindex="-1"></a>            nn.Linear(<span class="dv">256</span>, <span class="dv">128</span>), nn.ReLU(),</span>
<span id="cb8-26"><a href="#cb8-26" aria-hidden="true" tabindex="-1"></a>            nn.Linear(<span class="dv">128</span>, <span class="dv">10</span>),</span>
<span id="cb8-27"><a href="#cb8-27" aria-hidden="true" tabindex="-1"></a>        ).to(<span class="st">'cuda:1'</span>)</span>
<span id="cb8-28"><a href="#cb8-28" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-29"><a href="#cb8-29" aria-hidden="true" tabindex="-1"></a>        <span class="co"># chunks = number of micro-batches each input batch is split into</span></span>
<span id="cb8-30"><a href="#cb8-30" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.pipe <span class="op">=</span> Pipe(nn.Sequential(stage0, stage1), chunks<span class="op">=</span>chunks)</span>
<span id="cb8-31"><a href="#cb8-31" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-32"><a href="#cb8-32" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb8-33"><a href="#cb8-33" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Input must start on the first stage's device; Pipe returns an RRef</span></span>
<span id="cb8-34"><a href="#cb8-34" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.pipe(x.to(<span class="st">'cuda:0'</span>)).local_value()</span></code></pre></div></div>
</div>
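<p>The <code>chunks</code> argument sets how many micro-batches each batch is split into. That number determines how long the devices sit idle while the pipeline fills and drains. The pure-Python model below is a minimal sketch of an idealized GPipe-style forward schedule. It assumes every stage takes the same compute time and that inter-device transfers are free. <code>gpipe_forward_schedule</code> and <code>bubble_fraction</code> are illustrative helpers, not part of PyTorch.</p>

```python
def gpipe_forward_schedule(num_stages, num_microbatches):
    """Idealized fill-drain (GPipe-style) forward schedule.

    Returns one dict per time step, mapping stage index -> micro-batch index.
    Assumes every stage spends the same time on every micro-batch and that
    moving activations between devices is free.
    """
    total_steps = num_microbatches + num_stages - 1
    schedule = []
    for t in range(total_steps):
        # Stage s works on micro-batch t - s once that micro-batch has reached it
        active = {
            s: t - s
            for s in range(num_stages)
            if 0 <= t - s < num_microbatches
        }
        schedule.append(active)
    return schedule


def bubble_fraction(num_stages, num_microbatches):
    """Share of stage-time slots left idle while the pipeline fills and drains."""
    return (num_stages - 1) / (num_microbatches + num_stages - 1)


if __name__ == "__main__":
    # Two stages (devices=[0, 1]) and chunks=8, matching the example above
    for t, active in enumerate(gpipe_forward_schedule(2, 8)):
        print(f"t={t}: {active}")
    print(f"idle fraction: {bubble_fraction(2, 8):.3f}")
```

<p>With two stages and <code>chunks=8</code>, each device is idle for 1 of 9 time steps (about 11%). Raising <code>chunks</code> to 32 lowers that to 1/33. The trade-off is smaller per-micro-batch kernels, which may use the GPU less efficiently.</p>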
</section>
</section>
<section id="monitoring-and-debugging" class="level2">
<h2 class="anchored" data-anchor-id="monitoring-and-debugging" id="monitoring-and-debugging">Monitoring and Debugging</h2>
<section id="performance-profiling" class="level3">
<h3 class="anchored" data-anchor-id="performance-profiling" id="performance-profiling">Performance Profiling</h3>
<div id="3300c895" class="cell" data-execution_count="9">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb9"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.profiler</span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> profile_training(trainer, num_steps<span class="op">=</span><span class="dv">100</span>):</span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Profile distributed training performance"""</span></span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> torch.profiler.profile(</span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a>        activities<span class="op">=</span>[</span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a>            torch.profiler.ProfilerActivity.CPU,</span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a>            torch.profiler.ProfilerActivity.CUDA,</span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a>        ],</span>
<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a>        schedule<span class="op">=</span>torch.profiler.schedule(wait<span class="op">=</span><span class="dv">1</span>, warmup<span class="op">=</span><span class="dv">1</span>, active<span class="op">=</span><span class="dv">3</span>, repeat<span class="op">=</span><span class="dv">2</span>),</span>
<span id="cb9-11"><a href="#cb9-11" aria-hidden="true" tabindex="-1"></a>        on_trace_ready<span class="op">=</span>torch.profiler.tensorboard_trace_handler(<span class="st">'./log/profiler'</span>),</span>
<span id="cb9-12"><a href="#cb9-12" aria-hidden="true" tabindex="-1"></a>        record_shapes<span class="op">=</span><span class="va">True</span>,</span>
<span id="cb9-13"><a href="#cb9-13" aria-hidden="true" tabindex="-1"></a>        profile_memory<span class="op">=</span><span class="va">True</span>,</span>
<span id="cb9-14"><a href="#cb9-14" aria-hidden="true" tabindex="-1"></a>        with_stack<span class="op">=</span><span class="va">True</span></span>
<span id="cb9-15"><a href="#cb9-15" aria-hidden="true" tabindex="-1"></a>    ) <span class="im">as</span> prof:</span>
<span id="cb9-16"><a href="#cb9-16" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> step, (data, target) <span class="kw">in</span> <span class="bu">enumerate</span>(trainer.train_loader):</span>
<span id="cb9-17"><a href="#cb9-17" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> step <span class="op">&gt;=</span> num_steps:</span>
<span id="cb9-18"><a href="#cb9-18" aria-hidden="true" tabindex="-1"></a>                <span class="cf">break</span></span>
<span id="cb9-19"><a href="#cb9-19" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb9-20"><a href="#cb9-20" aria-hidden="true" tabindex="-1"></a>            data, target <span class="op">=</span> data.to(trainer.rank), target.to(trainer.rank)</span>
<span id="cb9-21"><a href="#cb9-21" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb9-22"><a href="#cb9-22" aria-hidden="true" tabindex="-1"></a>            trainer.optimizer.zero_grad()</span>
<span id="cb9-23"><a href="#cb9-23" aria-hidden="true" tabindex="-1"></a>            output <span class="op">=</span> trainer.ddp_model(data)</span>
<span id="cb9-24"><a href="#cb9-24" aria-hidden="true" tabindex="-1"></a>            loss <span class="op">=</span> trainer.criterion(output, target)</span>
<span id="cb9-25"><a href="#cb9-25" aria-hidden="true" tabindex="-1"></a>            loss.backward()</span>
<span id="cb9-26"><a href="#cb9-26" aria-hidden="true" tabindex="-1"></a>            trainer.optimizer.step()</span>
<span id="cb9-27"><a href="#cb9-27" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb9-28"><a href="#cb9-28" aria-hidden="true" tabindex="-1"></a>            prof.step()</span></code></pre></div></div>
</div>
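<p>The profiler's <code>schedule</code> decides which loop iterations are actually traced. With <code>wait=1, warmup=1, active=3, repeat=2</code>, that is only a small fraction of the 100 steps. The function below is a pure-Python re-implementation of the step-to-action mapping as described in the <code>torch.profiler.schedule</code> documentation. It is an illustration rather than the library's code, and the returned strings stand in for <code>torch.profiler.ProfilerAction</code> members.</p>

```python
def profiler_phase(step, wait=1, warmup=1, active=3, repeat=2, skip_first=0):
    """Illustrative mirror of the torch.profiler.schedule step -> action mapping.

    Returns "NONE", "WARMUP", "RECORD" or "RECORD_AND_SAVE" for a step index.
    """
    if step < skip_first:
        return "NONE"
    step -= skip_first
    cycle = wait + warmup + active
    if repeat > 0 and step // cycle >= repeat:
        return "NONE"  # all requested cycles are finished
    pos = step % cycle
    if pos < wait:
        return "NONE"
    if pos < wait + warmup:
        return "WARMUP"
    # The last active step of each cycle also hands the trace to on_trace_ready
    return "RECORD" if pos < cycle - 1 else "RECORD_AND_SAVE"


if __name__ == "__main__":
    recorded = [s for s in range(100) if profiler_phase(s).startswith("RECORD")]
    print(recorded)  # [2, 3, 4, 7, 8, 9]
```

<p>Only steps 2 to 4 and 7 to 9 are recorded, so the remaining 90 iterations of <code>profile_training</code> run without tracing. Setting <code>num_steps=10</code> would produce the same traces with far fewer iterations.</p>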
</section>
<section id="communication-debugging" class="level3">
<h3 class="anchored" data-anchor-id="communication-debugging" id="communication-debugging">Communication Debugging</h3>
<div id="7e0b0347" class="cell" data-execution_count="10">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb10"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span class="im">import</span> torch
<span class="im">import</span> torch.distributed <span class="im">as</span> dist

<span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> debug_communication():</span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Debug distributed communication (assumes the process group is initialized and torch.cuda.set_device was called for this rank)"""</span></span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a>    rank <span class="op">=</span> dist.get_rank()</span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a>    world_size <span class="op">=</span> dist.get_world_size()</span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-6"><a href="#cb10-6" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Test all-reduce</span></span>
<span id="cb10-7"><a href="#cb10-7" aria-hidden="true" tabindex="-1"></a>    tensor <span class="op">=</span> torch.randn(<span class="dv">10</span>).cuda()</span>
<span id="cb10-8"><a href="#cb10-8" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Rank </span><span class="sc">{</span>rank<span class="sc">}</span><span class="ss">: Before all-reduce: </span><span class="sc">{</span>tensor<span class="sc">.</span><span class="bu">sum</span>()<span class="sc">.</span>item()<span class="sc">:.4f}</span><span class="ss">"</span>)</span>
<span id="cb10-9"><a href="#cb10-9" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-10"><a href="#cb10-10" aria-hidden="true" tabindex="-1"></a>    dist.all_reduce(tensor, op<span class="op">=</span>dist.ReduceOp.SUM)</span>
<span id="cb10-11"><a href="#cb10-11" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Rank </span><span class="sc">{</span>rank<span class="sc">}</span><span class="ss">: After all-reduce: </span><span class="sc">{</span>tensor<span class="sc">.</span><span class="bu">sum</span>()<span class="sc">.</span>item()<span class="sc">:.4f}</span><span class="ss">"</span>)</span>
<span id="cb10-12"><a href="#cb10-12" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-13"><a href="#cb10-13" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Test broadcast</span></span>
<span id="cb10-14"><a href="#cb10-14" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> rank <span class="op">==</span> <span class="dv">0</span>:</span>
<span id="cb10-15"><a href="#cb10-15" aria-hidden="true" tabindex="-1"></a>        broadcast_tensor <span class="op">=</span> torch.randn(<span class="dv">5</span>).cuda()</span>
<span id="cb10-16"><a href="#cb10-16" aria-hidden="true" tabindex="-1"></a>    <span class="cf">else</span>:</span>
<span id="cb10-17"><a href="#cb10-17" aria-hidden="true" tabindex="-1"></a>        broadcast_tensor <span class="op">=</span> torch.zeros(<span class="dv">5</span>).cuda()</span>
<span id="cb10-18"><a href="#cb10-18" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-19"><a href="#cb10-19" aria-hidden="true" tabindex="-1"></a>    dist.broadcast(broadcast_tensor, src<span class="op">=</span><span class="dv">0</span>)</span>
<span id="cb10-20"><a href="#cb10-20" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Rank </span><span class="sc">{</span>rank<span class="sc">}</span><span class="ss">: Broadcast result: </span><span class="sc">{</span>broadcast_tensor<span class="sc">.</span><span class="bu">sum</span>()<span class="sc">.</span>item()<span class="sc">:.4f}</span><span class="ss">"</span>)</span></code></pre></div></div>
</div>
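<p>When <code>debug_communication()</code> behaves correctly, its output obeys two invariants. After the SUM all-reduce, every rank prints the same value, equal to the sum of all the "Before all-reduce" values. After the broadcast, every rank prints rank 0's value. The single-process simulation below spells these out, using plain Python lists as stand-ins for per-rank tensors. It does not use <code>torch.distributed</code>, and the helper names are hypothetical.</p>

```python
def simulate_all_reduce_sum(per_rank_values):
    """After all_reduce(op=SUM), every rank holds the element-wise sum over ranks."""
    total = [sum(column) for column in zip(*per_rank_values)]
    return [list(total) for _ in per_rank_values]


def simulate_broadcast(per_rank_values, src=0):
    """After broadcast(src=...), every rank holds a copy of the src rank's values."""
    return [list(per_rank_values[src]) for _ in per_rank_values]


if __name__ == "__main__":
    # Three hypothetical ranks, each holding a two-element "tensor"
    before = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
    for rank, values in enumerate(simulate_all_reduce_sum(before)):
        print(f"Rank {rank}: After all-reduce: {sum(values):.4f}")
    for rank, values in enumerate(simulate_broadcast(before)):
        print(f"Rank {rank}: Broadcast result: {sum(values):.4f}")
```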
</section>
</section>
<section id="best-practices" class="level2">
<h2 class="anchored" data-anchor-id="best-practices" id="best-practices">Best Practices</h2>
<section id="data-loading-optimization" class="level3">
<h3 class="anchored" data-anchor-id="data-loading-optimization" id="data-loading-optimization">Data Loading Optimization</h3>
<div id="ae810d87" class="cell" data-execution_count="11">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb11"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span class="im">import</span> torch
<span class="im">from</span> torch.utils.data.distributed <span class="im">import</span> DistributedSampler

<span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> create_efficient_dataloader(dataset, batch_size, world_size, rank):</span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Create an optimized distributed data loader; call loader.sampler.set_epoch(epoch) at the start of every epoch so the shuffle changes"""</span></span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a>    sampler <span class="op">=</span> DistributedSampler(</span>
<span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a>        dataset,</span>
<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a>        num_replicas<span class="op">=</span>world_size,</span>
<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a>        rank<span class="op">=</span>rank,</span>
<span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a>        shuffle<span class="op">=</span><span class="va">True</span>,</span>
<span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a>        drop_last<span class="op">=</span><span class="va">True</span>  <span class="co"># Drop the tail so every rank gets the same number of samples</span></span>
<span id="cb11-9"><a href="#cb11-9" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb11-10"><a href="#cb11-10" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-11"><a href="#cb11-11" aria-hidden="true" tabindex="-1"></a>    loader <span class="op">=</span> torch.utils.data.DataLoader(</span>
<span id="cb11-12"><a href="#cb11-12" aria-hidden="true" tabindex="-1"></a>        dataset,</span>
<span id="cb11-13"><a href="#cb11-13" aria-hidden="true" tabindex="-1"></a>        batch_size<span class="op">=</span>batch_size,</span>
<span id="cb11-14"><a href="#cb11-14" aria-hidden="true" tabindex="-1"></a>        sampler<span class="op">=</span>sampler,</span>
<span id="cb11-15"><a href="#cb11-15" aria-hidden="true" tabindex="-1"></a>        num_workers<span class="op">=</span><span class="dv">4</span>,  <span class="co"># Adjust based on system</span></span>
<span id="cb11-16"><a href="#cb11-16" aria-hidden="true" tabindex="-1"></a>        pin_memory<span class="op">=</span><span class="va">True</span>,</span>
<span id="cb11-17"><a href="#cb11-17" aria-hidden="true" tabindex="-1"></a>        persistent_workers<span class="op">=</span><span class="va">True</span>,  <span class="co"># Reuse worker processes</span></span>
<span id="cb11-18"><a href="#cb11-18" aria-hidden="true" tabindex="-1"></a>        prefetch_factor<span class="op">=</span><span class="dv">2</span></span>
<span id="cb11-19"><a href="#cb11-19" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb11-20"><a href="#cb11-20" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-21"><a href="#cb11-21" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> loader</span></code></pre></div></div>
</div>
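<p><code>DistributedSampler</code> gives each rank a disjoint slice of the dataset, and it reshuffles only when <code>set_epoch(epoch)</code> changes the seed. The pure-Python sketch below mimics those partitioning rules, including the <code>drop_last</code> arithmetic, so you can reason about what each rank sees. It uses Python's <code>random</code> module instead of PyTorch's generator, so the shuffled order itself will not match PyTorch's. <code>distributed_indices</code> is a hypothetical helper.</p>

```python
import math
import random


def distributed_indices(dataset_len, num_replicas, rank, epoch=0, seed=0,
                        shuffle=True, drop_last=True):
    """Indices one rank would see, following DistributedSampler-style splitting.

    The shuffle is seeded with seed + epoch via Python's random module, so the
    order differs from PyTorch's even though the partitioning rules match.
    """
    indices = list(range(dataset_len))
    if shuffle:
        random.Random(seed + epoch).shuffle(indices)

    if drop_last and dataset_len % num_replicas != 0:
        # Drop the tail so the data splits evenly across ranks
        num_samples = math.ceil((dataset_len - num_replicas) / num_replicas)
    else:
        num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas

    if drop_last:
        indices = indices[:total_size]
    else:
        # Pad by repeating indices so every rank gets num_samples items
        padding = total_size - len(indices)
        if padding > 0:
            indices += (indices * math.ceil(padding / len(indices)))[:padding]

    # Rank r takes every num_replicas-th index, starting at position r
    return indices[rank:total_size:num_replicas]


if __name__ == "__main__":
    for r in range(4):
        print(r, distributed_indices(10, num_replicas=4, rank=r, epoch=0))
```

<p>Because the permutation depends only on <code>seed + epoch</code>, forgetting <code>sampler.set_epoch(epoch)</code> in the epoch loop makes every epoch visit the samples in the same order.</p>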
</section>
<section id="gradient-synchronization" class="level3">
<h3 class="anchored" data-anchor-id="gradient-synchronization" id="gradient-synchronization">Gradient Synchronization</h3>
<div id="19e33665" class="cell" data-execution_count="12">
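<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb12"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><a href="#cb12-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> contextlib</span>
<span id="cb12-2"><a href="#cb12-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-3"><a href="#cb12-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> train_with_gradient_accumulation(model, optimizer, criterion, data_loader,</span>
<span id="cb12-4"><a href="#cb12-4" aria-hidden="true" tabindex="-1"></a>                                     accumulation_steps<span class="op">=</span><span class="dv">4</span>):</span>
<span id="cb12-5"><a href="#cb12-5" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Training with gradient accumulation; model is typically DDP-wrapped"""</span></span>
<span id="cb12-6"><a href="#cb12-6" aria-hidden="true" tabindex="-1"></a>    model.train()</span>
<span id="cb12-7"><a href="#cb12-7" aria-hidden="true" tabindex="-1"></a>    optimizer.zero_grad()  <span class="co"># Start from clean gradients</span></span>
<span id="cb12-8"><a href="#cb12-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-9"><a href="#cb12-9" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> batch_idx, (data, target) <span class="kw">in</span> <span class="bu">enumerate</span>(data_loader):</span>
<span id="cb12-10"><a href="#cb12-10" aria-hidden="true" tabindex="-1"></a>        data, target <span class="op">=</span> data.cuda(), target.cuda()</span>
<span id="cb12-11"><a href="#cb12-11" aria-hidden="true" tabindex="-1"></a>        is_update_step <span class="op">=</span> (batch_idx <span class="op">+</span> <span class="dv">1</span>) <span class="op">%</span> accumulation_steps <span class="op">==</span> <span class="dv">0</span></span>
<span id="cb12-12"><a href="#cb12-12" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-13"><a href="#cb12-13" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Under DDP, skip the gradient all-reduce on non-update steps;</span></span>
<span id="cb12-14"><a href="#cb12-14" aria-hidden="true" tabindex="-1"></a>        <span class="co"># gradients accumulate locally and are synchronized on the update step</span></span>
<span id="cb12-15"><a href="#cb12-15" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="kw">not</span> is_update_step <span class="kw">and</span> <span class="bu">hasattr</span>(model, <span class="st">"no_sync"</span>):</span>
<span id="cb12-16"><a href="#cb12-16" aria-hidden="true" tabindex="-1"></a>            sync_context <span class="op">=</span> model.no_sync()</span>
<span id="cb12-17"><a href="#cb12-17" aria-hidden="true" tabindex="-1"></a>        <span class="cf">else</span>:</span>
<span id="cb12-18"><a href="#cb12-18" aria-hidden="true" tabindex="-1"></a>            sync_context <span class="op">=</span> contextlib.nullcontext()</span>
<span id="cb12-19"><a href="#cb12-19" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-20"><a href="#cb12-20" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> sync_context:</span>
<span id="cb12-21"><a href="#cb12-21" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Scale the loss so the accumulated gradient is an average</span></span>
<span id="cb12-22"><a href="#cb12-22" aria-hidden="true" tabindex="-1"></a>            output <span class="op">=</span> model(data)</span>
<span id="cb12-23"><a href="#cb12-23" aria-hidden="true" tabindex="-1"></a>            loss <span class="op">=</span> criterion(output, target) <span class="op">/</span> accumulation_steps</span>
<span id="cb12-24"><a href="#cb12-24" aria-hidden="true" tabindex="-1"></a>            loss.backward()</span>
<span id="cb12-25"><a href="#cb12-25" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-26"><a href="#cb12-26" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Update parameters every accumulation_steps batches</span></span>
<span id="cb12-27"><a href="#cb12-27" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> is_update_step:</span>
<span id="cb12-28"><a href="#cb12-28" aria-hidden="true" tabindex="-1"></a>            optimizer.step()</span>
<span id="cb12-29"><a href="#cb12-29" aria-hidden="true" tabindex="-1"></a>            optimizer.zero_grad()</span></code></pre></div></div>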
</div>
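<p>Dividing each micro-batch loss by <code>accumulation_steps</code> is what makes accumulation equivalent to one large batch. This holds provided the loss uses mean reduction and the micro-batches are equally sized. The toy check below confirms that the two gradients match. It uses a one-parameter linear model with a hand-derived mean-squared-error gradient and exact <code>Fraction</code> arithmetic, with no autograd involved. The helper names are illustrative.</p>

```python
from fractions import Fraction


def mse_grad(w, batch):
    """Gradient w.r.t. w of mean((w * x - y) ** 2) over (x, y) pairs."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_grad(w, batch, accumulation_steps):
    """Accumulate micro-batch gradients, each scaled by 1 / accumulation_steps."""
    if len(batch) % accumulation_steps:
        raise ValueError("batch must split into equally sized micro-batches")
    size = len(batch) // accumulation_steps
    grad = 0
    for start in range(0, len(batch), size):
        # Same effect as loss / accumulation_steps before loss.backward()
        grad += mse_grad(w, batch[start:start + size]) / accumulation_steps
    return grad


if __name__ == "__main__":
    w = Fraction(1, 2)
    batch = [(Fraction(x), Fraction(2 * x + 1)) for x in range(8)]
    print(mse_grad(w, batch), accumulated_grad(w, batch, 4))
```

<p>Under DDP, gradients are also averaged across processes, so the effective batch size is <code>batch_size * accumulation_steps * world_size</code>.</p>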
</section>
<section id="dynamic-loss-scaling" class="level3">
<h3 class="anchored" data-anchor-id="dynamic-loss-scaling" id="dynamic-loss-scaling">Dynamic Loss Scaling</h3>
<div id="23e45ce5" class="cell" data-execution_count="13">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb13"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> DynamicLossScaler:  <span class="co"># Shrink the scale on overflow; grow it after scale_window overflow-free steps</span></span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, init_scale<span class="op">=</span><span class="fl">2.</span><span class="op">**</span><span class="dv">16</span>, scale_factor<span class="op">=</span><span class="fl">2.</span>, scale_window<span class="op">=</span><span class="dv">2000</span>):</span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.scale <span class="op">=</span> init_scale</span>
<span id="cb13-4"><a href="#cb13-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.scale_factor <span class="op">=</span> scale_factor</span>
<span id="cb13-5"><a href="#cb13-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.scale_window <span class="op">=</span> scale_window</span>
<span id="cb13-6"><a href="#cb13-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.counter <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb13-7"><a href="#cb13-7" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-8"><a href="#cb13-8" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> update(<span class="va">self</span>, overflow):  <span class="co"># overflow: True if any gradient was inf/NaN this step</span></span>
<span id="cb13-9"><a href="#cb13-9" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> overflow:</span>
<span id="cb13-10"><a href="#cb13-10" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.scale <span class="op">/=</span> <span class="va">self</span>.scale_factor</span>
<span id="cb13-11"><a href="#cb13-11" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.counter <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb13-12"><a href="#cb13-12" aria-hidden="true" tabindex="-1"></a>        <span class="cf">else</span>:</span>
<span id="cb13-13"><a href="#cb13-13" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.counter <span class="op">+=</span> <span class="dv">1</span></span>
<span id="cb13-14"><a href="#cb13-14" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="va">self</span>.counter <span class="op">&gt;=</span> <span class="va">self</span>.scale_window:</span>
<span id="cb13-15"><a href="#cb13-15" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.scale <span class="op">*=</span> <span class="va">self</span>.scale_factor</span>
<span id="cb13-16"><a href="#cb13-16" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.counter <span class="op">=</span> <span class="dv">0</span></span></code></pre></div></div>
</div>
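<p><code>DynamicLossScaler</code> only tracks the scale. The training step still has to unscale the gradients, detect inf/NaN values, and skip the update when they appear. The sketch below shows that step with plain Python floats standing in for gradient tensors. <code>scaled_sgd_step</code> is a hypothetical helper, and the class is repeated so the snippet runs on its own. In practice, PyTorch's built-in <code>GradScaler</code> (<code>torch.cuda.amp.GradScaler</code>, also exposed as <code>torch.amp.GradScaler</code> in recent releases) performs this whole loop for you.</p>

```python
import math


class DynamicLossScaler:
    """Same update rule as the class above, repeated so this runs standalone."""

    def __init__(self, init_scale=2.**16, scale_factor=2., scale_window=2000):
        self.scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self.counter = 0

    def update(self, overflow):
        if overflow:
            self.scale /= self.scale_factor
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.scale_window:
                self.scale *= self.scale_factor
                self.counter = 0


def scaled_sgd_step(scaler, params, scaled_grads, lr=0.1):
    """Hypothetical helper: unscale gradients and skip the update on inf/NaN.

    scaled_grads are the gradients produced by backward() on loss * scale.
    """
    grads = [g / scaler.scale for g in scaled_grads]
    overflow = not all(math.isfinite(g) for g in grads)
    if not overflow:
        params = [p - lr * g for p, g in zip(params, grads)]
    scaler.update(overflow)
    return params, overflow


if __name__ == "__main__":
    scaler = DynamicLossScaler(init_scale=8.0, scale_window=2)
    # An infinite gradient: the update is skipped and the scale is halved
    params, overflow = scaled_sgd_step(scaler, [1.0], [float("inf")])
    print(params, overflow, scaler.scale)
    # Two clean steps: parameters move and the scale grows back
    for _ in range(2):
        params, overflow = scaled_sgd_step(scaler, params, [4.0])
    print(params, scaler.scale)
```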
</section>
<section id="launch-script-example" class="level3">
<h3 class="anchored" data-anchor-id="launch-script-example" id="launch-script-example">Launch Script Example</h3>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>launch_distributed.sh</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb14" data-filename="launch_distributed.sh"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb14-1"><a href="#cb14-1" aria-hidden="true" tabindex="-1"></a><span class="co">#!/bin/bash</span></span>
<span id="cb14-2"><a href="#cb14-2" aria-hidden="true" tabindex="-1"></a><span class="co"># launch_distributed.sh</span></span>
<span id="cb14-3"><a href="#cb14-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-4"><a href="#cb14-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Single node, multiple GPUs (torchrun supersedes the deprecated torch.distributed.launch)</span></span>
<span id="cb14-5"><a href="#cb14-5" aria-hidden="true" tabindex="-1"></a><span class="ex">torchrun</span> <span class="dt">\</span></span>
<span id="cb14-6"><a href="#cb14-6" aria-hidden="true" tabindex="-1"></a>    <span class="at">--nproc_per_node</span><span class="op">=</span>4 <span class="dt">\</span></span>
<span id="cb14-7"><a href="#cb14-7" aria-hidden="true" tabindex="-1"></a>    <span class="at">--nnodes</span><span class="op">=</span>1 <span class="dt">\</span></span>
<span id="cb14-8"><a href="#cb14-8" aria-hidden="true" tabindex="-1"></a>    <span class="at">--node_rank</span><span class="op">=</span>0 <span class="dt">\</span></span>
<span id="cb14-9"><a href="#cb14-9" aria-hidden="true" tabindex="-1"></a>    <span class="at">--master_addr</span><span class="op">=</span><span class="st">"localhost"</span> <span class="dt">\</span></span>
<span id="cb14-10"><a href="#cb14-10" aria-hidden="true" tabindex="-1"></a>    <span class="at">--master_port</span><span class="op">=</span>12345 <span class="dt">\</span></span>
<span id="cb14-11"><a href="#cb14-11" aria-hidden="true" tabindex="-1"></a>    train_distributed.py</span>
<span id="cb14-12"><a href="#cb14-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-13"><a href="#cb14-13" aria-hidden="true" tabindex="-1"></a><span class="co"># Multi-node setup</span></span>
<span id="cb14-14"><a href="#cb14-14" aria-hidden="true" tabindex="-1"></a><span class="co"># Node 0:</span></span>
<span id="cb14-15"><a href="#cb14-15" aria-hidden="true" tabindex="-1"></a><span class="ex">torchrun</span> <span class="dt">\</span></span>
<span id="cb14-16"><a href="#cb14-16" aria-hidden="true" tabindex="-1"></a>    <span class="at">--nproc_per_node</span><span class="op">=</span>4 <span class="dt">\</span></span>
<span id="cb14-17"><a href="#cb14-17" aria-hidden="true" tabindex="-1"></a>    <span class="at">--nnodes</span><span class="op">=</span>2 <span class="dt">\</span></span>
<span id="cb14-18"><a href="#cb14-18" aria-hidden="true" tabindex="-1"></a>    <span class="at">--node_rank</span><span class="op">=</span>0 <span class="dt">\</span></span>
<span id="cb14-19"><a href="#cb14-19" aria-hidden="true" tabindex="-1"></a>    <span class="at">--master_addr</span><span class="op">=</span><span class="st">"192.168.1.100"</span> <span class="dt">\</span></span>
<span id="cb14-20"><a href="#cb14-20" aria-hidden="true" tabindex="-1"></a>    <span class="at">--master_port</span><span class="op">=</span>12345 <span class="dt">\</span></span>
<span id="cb14-21"><a href="#cb14-21" aria-hidden="true" tabindex="-1"></a>    train_distributed.py</span>
<span id="cb14-22"><a href="#cb14-22" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-23"><a href="#cb14-23" aria-hidden="true" tabindex="-1"></a><span class="co"># Node 1:</span></span>
<span id="cb14-24"><a href="#cb14-24" aria-hidden="true" tabindex="-1"></a><span class="ex">torchrun</span> <span class="dt">\</span></span>
<span id="cb14-25"><a href="#cb14-25" aria-hidden="true" tabindex="-1"></a>    <span class="at">--nproc_per_node</span><span class="op">=</span>4 <span class="dt">\</span></span>
<span id="cb14-26"><a href="#cb14-26" aria-hidden="true" tabindex="-1"></a>    <span class="at">--nnodes</span><span class="op">=</span>2 <span class="dt">\</span></span>
<span id="cb14-27"><a href="#cb14-27" aria-hidden="true" tabindex="-1"></a>    <span class="at">--node_rank</span><span class="op">=</span>1 <span class="dt">\</span></span>
<span id="cb14-28"><a href="#cb14-28" aria-hidden="true" tabindex="-1"></a>    <span class="at">--master_addr</span><span class="op">=</span><span class="st">"192.168.1.100"</span> <span class="dt">\</span></span>
<span id="cb14-29"><a href="#cb14-29" aria-hidden="true" tabindex="-1"></a>    <span class="at">--master_port</span><span class="op">=</span>12345 <span class="dt">\</span></span>
<span id="cb14-30"><a href="#cb14-30" aria-hidden="true" tabindex="-1"></a>    train_distributed.py</span></code></pre></div></div>
</div>
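<p>Each worker started by a launcher learns its place in the job from environment variables. torchrun, the successor to <code>torch.distributed.launch</code>, exports <code>RANK</code>, <code>LOCAL_RANK</code>, <code>WORLD_SIZE</code>, <code>MASTER_ADDR</code> and <code>MASTER_PORT</code>. <code>dist.init_process_group</code> reads these when it uses its default <code>env://</code> initialization. The older <code>torch.distributed.launch</code> passes the local rank as a command-line argument unless <code>--use_env</code> is given. The sketch below reads those variables into a small dataclass. <code>DistEnv</code>, <code>read_dist_env</code> and <code>expected_global_rank</code> are hypothetical helpers. The rank layout assumes the static setup above with identical nodes.</p>

```python
import os
from dataclasses import dataclass


@dataclass
class DistEnv:
    rank: int          # global rank across all nodes
    local_rank: int    # rank within this node, i.e. the GPU index to use
    world_size: int    # total number of processes
    master_addr: str
    master_port: int


def read_dist_env(env=None):
    """Read the variables a torchrun-style launcher exports to each worker."""
    env = os.environ if env is None else env
    return DistEnv(
        rank=int(env["RANK"]),
        local_rank=int(env["LOCAL_RANK"]),
        world_size=int(env["WORLD_SIZE"]),
        master_addr=env["MASTER_ADDR"],
        master_port=int(env["MASTER_PORT"]),
    )


def expected_global_rank(node_rank, nproc_per_node, local_rank):
    """Global rank under a static layout where every node runs nproc_per_node workers."""
    return node_rank * nproc_per_node + local_rank


if __name__ == "__main__":
    # What the third worker on node 1 would see in the 2-node, 4-GPU example
    example = {"RANK": "6", "LOCAL_RANK": "2", "WORLD_SIZE": "8",
               "MASTER_ADDR": "192.168.1.100", "MASTER_PORT": "12345"}
    print(read_dist_env(example))
```

<p>Inside <code>train_distributed.py</code>, <code>local_rank</code> is the value to pass to <code>torch.cuda.set_device</code>. <code>rank</code> identifies the process globally, for example to restrict checkpointing and logging to rank 0.</p>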
</section>
<section id="error-handling-and-recovery" class="level3">
<h3 class="anchored" data-anchor-id="error-handling-and-recovery" id="error-handling-and-recovery">Error Handling and Recovery</h3>
<div id="e84cf6db" class="cell" data-execution_count="14">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb15"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span class="im">import</span> torch

<span id="cb15-1"><a href="#cb15-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> robust_train_loop(trainer, num_epochs, checkpoint_dir):</span>
<span id="cb15-2"><a href="#cb15-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Training loop with error handling and recovery (find_latest_checkpoint, load_checkpoint and save_checkpoint are user-supplied helpers)"""</span></span>
<span id="cb15-3"><a href="#cb15-3" aria-hidden="true" tabindex="-1"></a>    start_epoch <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb15-4"><a href="#cb15-4" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-5"><a href="#cb15-5" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Load checkpoint if exists</span></span>
<span id="cb15-6"><a href="#cb15-6" aria-hidden="true" tabindex="-1"></a>    latest_checkpoint <span class="op">=</span> find_latest_checkpoint(checkpoint_dir)</span>
<span id="cb15-7"><a href="#cb15-7" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> latest_checkpoint:</span>
<span id="cb15-8"><a href="#cb15-8" aria-hidden="true" tabindex="-1"></a>        start_epoch <span class="op">=</span> load_checkpoint(trainer, latest_checkpoint)</span>
<span id="cb15-9"><a href="#cb15-9" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-10"><a href="#cb15-10" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(start_epoch, num_epochs):</span>
<span id="cb15-11"><a href="#cb15-11" aria-hidden="true" tabindex="-1"></a>        <span class="cf">try</span>:</span>
<span id="cb15-12"><a href="#cb15-12" aria-hidden="true" tabindex="-1"></a>            trainer.train_epoch(epoch)</span>
<span id="cb15-13"><a href="#cb15-13" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb15-14"><a href="#cb15-14" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Save checkpoint</span></span>
<span id="cb15-15"><a href="#cb15-15" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> epoch <span class="op">%</span> <span class="dv">5</span> <span class="op">==</span> <span class="dv">0</span>:</span>
<span id="cb15-16"><a href="#cb15-16" aria-hidden="true" tabindex="-1"></a>                save_checkpoint(trainer, epoch, checkpoint_dir)</span>
<span id="cb15-17"><a href="#cb15-17" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb15-18"><a href="#cb15-18" aria-hidden="true" tabindex="-1"></a>        <span class="cf">except</span> <span class="pp">RuntimeError</span> <span class="im">as</span> e:</span>
<span id="cb15-19"><a href="#cb15-19" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="st">"out of memory"</span> <span class="kw">in</span> <span class="bu">str</span>(e):</span>
<span id="cb15-20"><a href="#cb15-20" aria-hidden="true" tabindex="-1"></a>                <span class="bu">print</span>(<span class="ss">f"OOM error at epoch </span><span class="sc">{</span>epoch<span class="sc">}</span><span class="ss">, reducing batch size"</span>)</span>
<span id="cb15-21"><a href="#cb15-21" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Implement batch size reduction logic</span></span>
<span id="cb15-22"><a href="#cb15-22" aria-hidden="true" tabindex="-1"></a>                torch.cuda.empty_cache()</span>
<span id="cb15-23"><a href="#cb15-23" aria-hidden="true" tabindex="-1"></a>                <span class="cf">continue</span></span>
<span id="cb15-24"><a href="#cb15-24" aria-hidden="true" tabindex="-1"></a>            <span class="cf">else</span>:</span>
<span id="cb15-25"><a href="#cb15-25" aria-hidden="true" tabindex="-1"></a>                <span class="cf">raise</span> e</span>
<span id="cb15-26"><a href="#cb15-26" aria-hidden="true" tabindex="-1"></a>        <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:</span>
<span id="cb15-27"><a href="#cb15-27" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f"Error at epoch </span><span class="sc">{</span>epoch<span class="sc">}</span><span class="ss">: </span><span class="sc">{</span>e<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb15-28"><a href="#cb15-28" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Save emergency checkpoint</span></span>
<span id="cb15-29"><a href="#cb15-29" aria-hidden="true" tabindex="-1"></a>            save_checkpoint(trainer, epoch, checkpoint_dir, emergency<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb15-30"><a href="#cb15-30" aria-hidden="true" tabindex="-1"></a>            <span class="cf">raise</span> e</span></code></pre></div></div>
</div>
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>This guide provides a comprehensive foundation for implementing distributed training with PyTorch. Start with basic DDP for single-node multi-GPU setups, then progress to more advanced techniques like FSDP and pipeline parallelism as your models and datasets grow larger.</p>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Key Takeaways
</div>
</div>
<div class="callout-body-container callout-body">
<ol type="1">
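<li><strong>Start Simple</strong>: Begin with DDP on a single node with multiple GPUs; PyTorch recommends it over DataParallel even on one machine</li>
<li><strong>Scale Gradually</strong>: Extend DDP to multiple nodes, then adopt FSDP or pipeline parallelism once the model no longer fits on a single GPU</li>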
<li><strong>Monitor Performance</strong>: Use profiling tools to identify bottlenecks</li>
<li><strong>Handle Errors</strong>: Implement robust error handling and checkpointing</li>
<li><strong>Optimize Data Loading</strong>: Use efficient data loaders and samplers</li>
</ol>
</div>
</div>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Additional Resources
</div>
</div>
<div class="callout-body-container callout-body">
<p>For more advanced topics and the latest updates, refer to:</p>
<ul>
<li><a href="https://pytorch.org/docs/stable/distributed.html">PyTorch Distributed Documentation</a></li>
<li><a href="https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html">FSDP Tutorial</a></li>
<li><a href="https://pytorch.org/docs/stable/pipeline.html">Pipeline Parallelism Guide</a></li>
</ul>
</div>
</div>


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Python Package Development with Rust - Complete Guide]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/python/rust-py-package/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/python/rust-py-package/</guid>
      <pubDate>Sun, 15 Jun 2025 00:00:00 GMT</pubDate>
      
      <category>tutorial</category>
      <category>advanced</category>
      <content:encoded><![CDATA[






<section id="python-package-development-with-rust---complete-guide" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/python/rust-py-package/rust.png" class="img-fluid"></p>
<section id="overview" class="level2">
<h2 class="anchored" data-anchor-id="overview" id="overview">Overview</h2>
<p>This guide covers creating Python packages with Rust backends using PyO3 (Rust bindings for the Python interpreter) and maturin (the build and packaging tool). The approach pairs Rust’s performance and memory safety with the reach of Python’s ecosystem.</p>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Why Rust + Python?
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li><strong>Performance</strong>: Rust provides near C-level performance</li>
<li><strong>Safety</strong>: Memory safety without garbage collection</li>
<li><strong>Ecosystem</strong>: Access to Python’s vast library ecosystem</li>
<li><strong>Maintainability</strong>: Rust’s type system catches many bugs at compile time</li>
</ul>
</div>
</div>
</section>
<section id="prerequisites" class="level2">
<h2 class="anchored" data-anchor-id="prerequisites" id="prerequisites">Prerequisites</h2>
<p>Before starting, ensure you have:</p>
<ul>
<li>Python 3.7+ installed</li>
<li>Rust toolchain installed (rustup recommended)</li>
<li>Basic knowledge of both Python and Rust</li>
</ul>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Note
</div>
</div>
<div class="callout-body-container callout-body">
<p>You can install Rust from <a href="https://rustup.rs/">rustup.rs</a> if you haven’t already.</p>
</div>
</div>
</section>
<section id="installation" class="level2">
<h2 class="anchored" data-anchor-id="installation" id="installation">Installation</h2>
<p>First, install the required tools:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Install maturin (build tool for Rust-based Python extensions)</span></span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install maturin</span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Install PyO3 CLI (optional but helpful)</span></span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install pyo3-pack</span></code></pre></div></div>
</section>
<section id="project-setup" class="level2">
<h2 class="anchored" data-anchor-id="project-setup" id="project-setup">Project Setup</h2>
<section id="initialize-the-project" class="level3">
<h3 class="anchored" data-anchor-id="initialize-the-project" id="initialize-the-project">Initialize the Project</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb2"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Create a new directory</span></span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a><span class="fu">mkdir</span> my-rust-python-package</span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a><span class="bu">cd</span> my-rust-python-package</span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Initialize with maturin</span></span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a><span class="ex">maturin</span> init <span class="at">--bindings</span> pyo3</span></code></pre></div></div>
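<p>This creates the basic structure below. The <code>python/</code> directory belongs to maturin’s mixed Rust/Python layout (the <code>--mixed</code> option); a pure Rust extension only needs <code>Cargo.toml</code>, <code>pyproject.toml</code>, and <code>src/lib.rs</code>:</p>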
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>Project Structure</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb3" data-filename="Project Structure"><pre class="sourceCode default code-with-copy"><code class="sourceCode default"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a>my-rust-python-package/</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a>├── Cargo.toml</span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a>├── pyproject.toml</span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a>├── src/</span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a>│   └── lib.rs</span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a>└── python/</span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a>    └── my_rust_python_package/</span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a>        └── __init__.py</span></code></pre></div></div>
</div>
</section>
<section id="configure-cargo.toml" class="level3">
<h3 class="anchored" data-anchor-id="configure-cargo.toml" id="configure-cargo.toml">Configure Cargo.toml</h3>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>Cargo.toml</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb4" data-filename="Cargo.toml"><pre class="sourceCode toml code-with-copy"><code class="sourceCode toml"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="kw">[package]</span></span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a><span class="dt">name</span> <span class="op">=</span> <span class="st">"my-rust-python-package"</span></span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a><span class="dt">version</span> <span class="op">=</span> <span class="st">"0.1.0"</span></span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a><span class="dt">edition</span> <span class="op">=</span> <span class="st">"2021"</span></span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a><span class="kw">[lib]</span></span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a><span class="dt">name</span> <span class="op">=</span> <span class="st">"my_rust_python_package"</span></span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a><span class="dt">crate-type</span> <span class="op">=</span> <span class="op">[</span><span class="st">"cdylib"</span><span class="op">]</span></span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a><span class="kw">[dependencies]</span></span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a><span class="dt">pyo3</span> <span class="op">=</span> <span class="op">{ </span><span class="dt">version</span><span class="op"> =</span> <span class="st">"0.20"</span><span class="op">, </span><span class="dt">features</span><span class="op"> =</span> <span class="op">[</span><span class="st">"extension-module"</span><span class="op">] }</span></span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a><span class="kw">[build-system]</span></span>
<span id="cb4-14"><a href="#cb4-14" aria-hidden="true" tabindex="-1"></a><span class="dt">requires</span> <span class="op">=</span> <span class="op">[</span><span class="st">"maturin&gt;=1.0,&lt;2.0"</span><span class="op">]</span></span>
<span id="cb4-15"><a href="#cb4-15" aria-hidden="true" tabindex="-1"></a><span class="dt">build-backend</span> <span class="op">=</span> <span class="st">"maturin"</span></span>
<span id="cb4-16"><a href="#cb4-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-17"><a href="#cb4-17" aria-hidden="true" tabindex="-1"></a><span class="kw">[project]</span></span>
<span id="cb4-18"><a href="#cb4-18" aria-hidden="true" tabindex="-1"></a><span class="dt">name</span> <span class="op">=</span> <span class="st">"my-rust-python-package"</span></span>
<span id="cb4-19"><a href="#cb4-19" aria-hidden="true" tabindex="-1"></a><span class="dt">requires-python</span> <span class="op">=</span> <span class="st">"&gt;=3.7"</span></span>
<span id="cb4-20"><a href="#cb4-20" aria-hidden="true" tabindex="-1"></a><span class="dt">classifiers</span> <span class="op">=</span> <span class="op">[</span></span>
<span id="cb4-21"><a href="#cb4-21" aria-hidden="true" tabindex="-1"></a>    <span class="st">"Programming Language :: Rust"</span><span class="op">,</span></span>
<span id="cb4-22"><a href="#cb4-22" aria-hidden="true" tabindex="-1"></a>    <span class="st">"Programming Language :: Python :: Implementation :: CPython"</span><span class="op">,</span></span>
<span id="cb4-23"><a href="#cb4-23" aria-hidden="true" tabindex="-1"></a>    <span class="st">"Programming Language :: Python :: Implementation :: PyPy"</span><span class="op">,</span></span>
<span id="cb4-24"><a href="#cb4-24" aria-hidden="true" tabindex="-1"></a><span class="op">]</span></span></code></pre></div></div>
</div>
</section>
<section id="configure-pyproject.toml" class="level3">
<h3 class="anchored" data-anchor-id="configure-pyproject.toml" id="configure-pyproject.toml">Configure pyproject.toml</h3>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>pyproject.toml</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb5" data-filename="pyproject.toml"><pre class="sourceCode toml code-with-copy"><code class="sourceCode toml"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="kw">[build-system]</span></span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a><span class="dt">requires</span> <span class="op">=</span> <span class="op">[</span><span class="st">"maturin&gt;=1.0,&lt;2.0"</span><span class="op">]</span></span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a><span class="dt">build-backend</span> <span class="op">=</span> <span class="st">"maturin"</span></span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a><span class="kw">[project]</span></span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a><span class="dt">name</span> <span class="op">=</span> <span class="st">"my-rust-python-package"</span></span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a><span class="dt">version</span> <span class="op">=</span> <span class="st">"0.1.0"</span></span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a><span class="dt">description</span> <span class="op">=</span> <span class="st">"A Python package written in Rust"</span></span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a><span class="dt">authors</span> <span class="op">=</span> <span class="op">[</span><span class="st">"Your Name &lt;your.email@example.com&gt;"</span><span class="op">]</span></span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a><span class="dt">requires-python</span> <span class="op">=</span> <span class="st">"&gt;=3.7"</span></span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a><span class="dt">classifiers</span> <span class="op">=</span> <span class="op">[</span></span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a>    <span class="st">"Development Status :: 4 - Beta"</span><span class="op">,</span></span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a>    <span class="st">"Intended Audience :: Developers"</span><span class="op">,</span></span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a>    <span class="st">"License :: OSI Approved :: MIT License"</span><span class="op">,</span></span>
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a>    <span class="st">"Programming Language :: Python :: 3"</span><span class="op">,</span></span>
<span id="cb5-16"><a href="#cb5-16" aria-hidden="true" tabindex="-1"></a>    <span class="st">"Programming Language :: Rust"</span><span class="op">,</span></span>
<span id="cb5-17"><a href="#cb5-17" aria-hidden="true" tabindex="-1"></a><span class="op">]</span></span>
<span id="cb5-18"><a href="#cb5-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-19"><a href="#cb5-19" aria-hidden="true" tabindex="-1"></a><span class="kw">[tool.maturin]</span></span>
<span id="cb5-20"><a href="#cb5-20" aria-hidden="true" tabindex="-1"></a><span class="dt">features</span> <span class="op">=</span> <span class="op">[</span><span class="st">"pyo3/extension-module"</span><span class="op">]</span></span></code></pre></div></div>
</div>
</section>
</section>
<section id="writing-rust-code" class="level2">
<h2 class="anchored" data-anchor-id="writing-rust-code" id="writing-rust-code">Writing Rust Code</h2>
<section id="basic-function-example" class="level3">
<h3 class="anchored" data-anchor-id="basic-function-example" id="basic-function-example">Basic Function Example</h3>
<p>Edit <code>src/lib.rs</code>:</p>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>src/lib.rs</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb6" data-filename="src/lib.rs"><pre class="sourceCode rust code-with-copy"><code class="sourceCode rust"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="kw">use</span> <span class="pp">pyo3::prelude::</span><span class="op">*;</span></span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a><span class="co">/// Formats the sum of two numbers as string.</span></span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a><span class="at">#[</span>pyfunction<span class="at">]</span></span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a><span class="kw">fn</span> sum_as_string(a<span class="op">:</span> <span class="dt">usize</span><span class="op">,</span> b<span class="op">:</span> <span class="dt">usize</span>) <span class="op">-&gt;</span> PyResult<span class="op">&lt;</span><span class="dt">String</span><span class="op">&gt;</span> <span class="op">{</span></span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a>    <span class="cn">Ok</span>((a <span class="op">+</span> b)<span class="op">.</span>to_string())</span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a><span class="co">/// A simple example function that multiplies two numbers</span></span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a><span class="at">#[</span>pyfunction<span class="at">]</span></span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a><span class="kw">fn</span> multiply(a<span class="op">:</span> <span class="dt">f64</span><span class="op">,</span> b<span class="op">:</span> <span class="dt">f64</span>) <span class="op">-&gt;</span> <span class="dt">f64</span> <span class="op">{</span></span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a>    a <span class="op">*</span> b</span>
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span>
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-15"><a href="#cb6-15" aria-hidden="true" tabindex="-1"></a><span class="co">/// Fast Fibonacci calculation</span></span>
<span id="cb6-16"><a href="#cb6-16" aria-hidden="true" tabindex="-1"></a><span class="at">#[</span>pyfunction<span class="at">]</span></span>
<span id="cb6-17"><a href="#cb6-17" aria-hidden="true" tabindex="-1"></a><span class="kw">fn</span> fibonacci(n<span class="op">:</span> <span class="dt">u64</span>) <span class="op">-&gt;</span> <span class="dt">u64</span> <span class="op">{</span></span>
<span id="cb6-18"><a href="#cb6-18" aria-hidden="true" tabindex="-1"></a>    <span class="cf">match</span> n <span class="op">{</span></span>
<span id="cb6-19"><a href="#cb6-19" aria-hidden="true" tabindex="-1"></a>        <span class="dv">0</span> <span class="op">=&gt;</span> <span class="dv">0</span><span class="op">,</span></span>
<span id="cb6-20"><a href="#cb6-20" aria-hidden="true" tabindex="-1"></a>        <span class="dv">1</span> <span class="op">=&gt;</span> <span class="dv">1</span><span class="op">,</span></span>
<span id="cb6-21"><a href="#cb6-21" aria-hidden="true" tabindex="-1"></a>        _ <span class="op">=&gt;</span> <span class="op">{</span></span>
<span id="cb6-22"><a href="#cb6-22" aria-hidden="true" tabindex="-1"></a>            <span class="kw">let</span> <span class="kw">mut</span> a <span class="op">=</span> <span class="dv">0</span><span class="op">;</span></span>
<span id="cb6-23"><a href="#cb6-23" aria-hidden="true" tabindex="-1"></a>            <span class="kw">let</span> <span class="kw">mut</span> b <span class="op">=</span> <span class="dv">1</span><span class="op">;</span></span>
<span id="cb6-24"><a href="#cb6-24" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> _ <span class="kw">in</span> <span class="dv">2</span><span class="op">..=</span>n <span class="op">{</span></span>
<span id="cb6-25"><a href="#cb6-25" aria-hidden="true" tabindex="-1"></a>                <span class="kw">let</span> temp <span class="op">=</span> a <span class="op">+</span> b<span class="op">;</span></span>
<span id="cb6-26"><a href="#cb6-26" aria-hidden="true" tabindex="-1"></a>                a <span class="op">=</span> b<span class="op">;</span></span>
<span id="cb6-27"><a href="#cb6-27" aria-hidden="true" tabindex="-1"></a>                b <span class="op">=</span> temp<span class="op">;</span></span>
<span id="cb6-28"><a href="#cb6-28" aria-hidden="true" tabindex="-1"></a>            <span class="op">}</span></span>
<span id="cb6-29"><a href="#cb6-29" aria-hidden="true" tabindex="-1"></a>            b</span>
<span id="cb6-30"><a href="#cb6-30" aria-hidden="true" tabindex="-1"></a>        <span class="op">}</span></span>
<span id="cb6-31"><a href="#cb6-31" aria-hidden="true" tabindex="-1"></a>    <span class="op">}</span></span>
<span id="cb6-32"><a href="#cb6-32" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span>
<span id="cb6-33"><a href="#cb6-33" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-34"><a href="#cb6-34" aria-hidden="true" tabindex="-1"></a><span class="co">/// A Python module implemented in Rust.</span></span>
<span id="cb6-35"><a href="#cb6-35" aria-hidden="true" tabindex="-1"></a><span class="at">#[</span>pymodule<span class="at">]</span></span>
<span id="cb6-36"><a href="#cb6-36" aria-hidden="true" tabindex="-1"></a><span class="kw">fn</span> my_rust_python_package(_py<span class="op">:</span> Python<span class="op">,</span> m<span class="op">:</span> <span class="op">&amp;</span>PyModule) <span class="op">-&gt;</span> PyResult<span class="op">&lt;</span>()<span class="op">&gt;</span> <span class="op">{</span></span>
<span id="cb6-37"><a href="#cb6-37" aria-hidden="true" tabindex="-1"></a>    m<span class="op">.</span>add_function(<span class="pp">wrap_pyfunction!</span>(sum_as_string<span class="op">,</span> m)<span class="op">?</span>)<span class="op">?;</span></span>
<span id="cb6-38"><a href="#cb6-38" aria-hidden="true" tabindex="-1"></a>    m<span class="op">.</span>add_function(<span class="pp">wrap_pyfunction!</span>(multiply<span class="op">,</span> m)<span class="op">?</span>)<span class="op">?;</span></span>
<span id="cb6-39"><a href="#cb6-39" aria-hidden="true" tabindex="-1"></a>    m<span class="op">.</span>add_function(<span class="pp">wrap_pyfunction!</span>(fibonacci<span class="op">,</span> m)<span class="op">?</span>)<span class="op">?;</span></span>
<span id="cb6-40"><a href="#cb6-40" aria-hidden="true" tabindex="-1"></a>    <span class="cn">Ok</span>(())</span>
<span id="cb6-41"><a href="#cb6-41" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span></code></pre></div></div>
</div>
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Important</span>PyO3 Attributes
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li><code>#[pyfunction]</code>: Exposes a Rust function to Python</li>
<li><code>#[pymodule]</code>: Creates a Python module from Rust code</li>
<li><code>PyResult&lt;T&gt;</code>: Standard return type for functions that can fail</li>
</ul>
</div>
</div>
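<p>Because <code>#[pyfunction]</code> only wraps an ordinary Rust function, one option (a design suggestion, not something PyO3 or maturin requires) is to keep the computation in plain Rust functions that use no PyO3 types, and let the exported functions handle only type conversion and error mapping. The sketch below uses hypothetical helper names, <code>fibonacci_checked</code> and <code>checked_sum_as_string</code>; returning <code>Option</code> leaves the choice of Python exception to the wrapper.</p>

```rust
/// Plain-Rust core for `fibonacci` (hypothetical helper): returns None
/// instead of overflowing. F(93) is the largest Fibonacci number in a u64.
pub fn fibonacci_checked(n: u64) -> Option<u64> {
    if n < 2 {
        return Some(n);
    }
    let (mut a, mut b) = (0u64, 1u64);
    for _ in 2..=n {
        // `?` on an Option returns None early if the addition overflows
        let next = a.checked_add(b)?;
        a = b;
        b = next;
    }
    Some(b)
}

/// Plain-Rust core for `sum_as_string` (hypothetical helper):
/// returns None if a + b overflows usize.
pub fn checked_sum_as_string(a: usize, b: usize) -> Option<String> {
    a.checked_add(b).map(|sum| sum.to_string())
}

// A #[pyfunction] wrapper would then only map None to a Python exception, e.g.
//     fibonacci_checked(n).ok_or_else(|| PyOverflowError::new_err("overflows u64"))

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn small_values() {
        assert_eq!(fibonacci_checked(0), Some(0));
        assert_eq!(fibonacci_checked(10), Some(55));
    }

    #[test]
    fn overflow_is_reported() {
        assert_eq!(fibonacci_checked(93), Some(12_200_160_415_121_876_738));
        assert_eq!(fibonacci_checked(94), None);
    }
}
```

<p>Plain functions like these can be exercised with <code>cargo test</code> without a Python interpreter. One caveat, which depends on your setup: when the <code>extension-module</code> feature is enabled unconditionally in <code>Cargo.toml</code>, test binaries that also contain PyO3 wrappers can fail to link, so some projects move the core into a separate crate or enable that feature only through maturin.</p>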
</section>
<section id="working-with-python-objects" class="level3">
<h3 class="anchored" data-anchor-id="working-with-python-objects" id="working-with-python-objects">Working with Python Objects</h3>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>Working with Python Objects</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb7" data-filename="Working with Python Objects"><pre class="sourceCode rust code-with-copy"><code class="sourceCode rust"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="kw">use</span> <span class="pp">pyo3::prelude::</span><span class="op">*;</span></span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a><span class="kw">use</span> <span class="pp">pyo3::types::</span><span class="op">{</span>PyDict<span class="op">,</span> PyList<span class="op">};</span></span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a><span class="co">/// Process a Python list of numbers</span></span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a><span class="at">#[</span>pyfunction<span class="at">]</span></span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a><span class="kw">fn</span> process_list(py<span class="op">:</span> Python<span class="op">,</span> list<span class="op">:</span> <span class="op">&amp;</span>PyList) <span class="op">-&gt;</span> PyResult<span class="op">&lt;</span><span class="dt">Vec</span><span class="op">&lt;</span><span class="dt">f64</span><span class="op">&gt;&gt;</span> <span class="op">{</span></span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">let</span> <span class="kw">mut</span> result <span class="op">=</span> <span class="dt">Vec</span><span class="pp">::</span>new()<span class="op">;</span></span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> item <span class="kw">in</span> list <span class="op">{</span></span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a>        <span class="kw">let</span> num<span class="op">:</span> <span class="dt">f64</span> <span class="op">=</span> item<span class="op">.</span>extract()<span class="op">?;</span></span>
<span id="cb7-10"><a href="#cb7-10" aria-hidden="true" tabindex="-1"></a>        result<span class="op">.</span>push(num <span class="op">*</span> <span class="dv">2.0</span>)<span class="op">;</span></span>
<span id="cb7-11"><a href="#cb7-11" aria-hidden="true" tabindex="-1"></a>    <span class="op">}</span></span>
<span id="cb7-12"><a href="#cb7-12" aria-hidden="true" tabindex="-1"></a>    <span class="cn">Ok</span>(result)</span>
<span id="cb7-13"><a href="#cb7-13" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span>
<span id="cb7-14"><a href="#cb7-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-15"><a href="#cb7-15" aria-hidden="true" tabindex="-1"></a><span class="co">/// Work with Python dictionaries</span></span>
<span id="cb7-16"><a href="#cb7-16" aria-hidden="true" tabindex="-1"></a><span class="at">#[</span>pyfunction<span class="at">]</span></span>
<span id="cb7-17"><a href="#cb7-17" aria-hidden="true" tabindex="-1"></a><span class="kw">fn</span> process_dict(dict<span class="op">:</span> <span class="op">&amp;</span>PyDict) <span class="op">-&gt;</span> PyResult<span class="op">&lt;</span><span class="dt">f64</span><span class="op">&gt;</span> <span class="op">{</span></span>
<span id="cb7-18"><a href="#cb7-18" aria-hidden="true" tabindex="-1"></a>    <span class="kw">let</span> <span class="kw">mut</span> sum <span class="op">=</span> <span class="dv">0.0</span><span class="op">;</span></span>
<span id="cb7-19"><a href="#cb7-19" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> (key<span class="op">,</span> value) <span class="kw">in</span> dict <span class="op">{</span></span>
<span id="cb7-20"><a href="#cb7-20" aria-hidden="true" tabindex="-1"></a>        <span class="kw">let</span> key_str<span class="op">:</span> <span class="dt">String</span> <span class="op">=</span> key<span class="op">.</span>extract()<span class="op">?;</span></span>
<span id="cb7-21"><a href="#cb7-21" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> key_str<span class="op">.</span>starts_with(<span class="st">"num_"</span>) <span class="op">{</span></span>
<span id="cb7-22"><a href="#cb7-22" aria-hidden="true" tabindex="-1"></a>            <span class="kw">let</span> val<span class="op">:</span> <span class="dt">f64</span> <span class="op">=</span> value<span class="op">.</span>extract()<span class="op">?;</span></span>
<span id="cb7-23"><a href="#cb7-23" aria-hidden="true" tabindex="-1"></a>            sum <span class="op">+=</span> val<span class="op">;</span></span>
<span id="cb7-24"><a href="#cb7-24" aria-hidden="true" tabindex="-1"></a>        <span class="op">}</span></span>
<span id="cb7-25"><a href="#cb7-25" aria-hidden="true" tabindex="-1"></a>    <span class="op">}</span></span>
<span id="cb7-26"><a href="#cb7-26" aria-hidden="true" tabindex="-1"></a>    <span class="cn">Ok</span>(sum)</span>
<span id="cb7-27"><a href="#cb7-27" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span></code></pre></div></div>
</div>
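<p>PyO3 can also convert Python lists and dicts into Rust collections as they are passed in, for example to a <code>Vec&lt;f64&gt;</code> or a <code>HashMap&lt;String, f64&gt;</code> parameter, which keeps the core logic free of PyO3 types. One difference to be aware of: extracting a whole <code>HashMap&lt;String, f64&gt;</code> fails if <em>any</em> value is not a number, whereas <code>process_dict</code> above only extracts the <code>num_</code> entries. The helper names below (<code>double_all</code>, <code>sum_num_prefixed</code>) are hypothetical; this is a plain-Rust sketch of the same two operations.</p>

```rust
use std::collections::HashMap;

/// Plain-Rust core of `process_list`: double every value.
fn double_all(values: &[f64]) -> Vec<f64> {
    values.iter().map(|v| v * 2.0).collect()
}

/// Plain-Rust core of `process_dict`: sum values whose key starts with "num_".
/// Unlike `process_dict`, this assumes every value is already an f64.
fn sum_num_prefixed(map: &HashMap<String, f64>) -> f64 {
    map.iter()
        .filter(|(key, _)| key.starts_with("num_"))
        .map(|(_, value)| *value)
        .sum()
}
```

<p>With these helpers, the exported function could shrink to a thin wrapper such as <code>fn process_list(values: Vec&lt;f64&gt;) -&gt; Vec&lt;f64&gt; { double_all(&amp;values) }</code> under <code>#[pyfunction]</code>, letting PyO3 report an error when the argument cannot be converted.</p>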
</section>
<section id="creating-python-classes" class="level3">
<h3 class="anchored" data-anchor-id="creating-python-classes" id="creating-python-classes">Creating Python Classes</h3>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>Python Classes in Rust</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb8" data-filename="Python Classes in Rust"><pre class="sourceCode rust code-with-copy"><code class="sourceCode rust"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="kw">use</span> <span class="pp">pyo3::prelude::</span><span class="op">*;</span></span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a><span class="at">#[</span>pyclass<span class="at">]</span></span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a><span class="kw">struct</span> Counter <span class="op">{</span></span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a>    value<span class="op">:</span> <span class="dt">i64</span><span class="op">,</span></span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a><span class="at">#[</span>pymethods<span class="at">]</span></span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a><span class="kw">impl</span> Counter <span class="op">{</span></span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a>    <span class="at">#[</span>new<span class="at">]</span></span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a>    <span class="kw">fn</span> new(initial_value<span class="op">:</span> <span class="dt">Option</span><span class="op">&lt;</span><span class="dt">i64</span><span class="op">&gt;</span>) <span class="op">-&gt;</span> <span class="dt">Self</span> <span class="op">{</span></span>
<span id="cb8-12"><a href="#cb8-12" aria-hidden="true" tabindex="-1"></a>        Counter <span class="op">{</span></span>
<span id="cb8-13"><a href="#cb8-13" aria-hidden="true" tabindex="-1"></a>            value<span class="op">:</span> initial_value<span class="op">.</span>unwrap_or(<span class="dv">0</span>)<span class="op">,</span></span>
<span id="cb8-14"><a href="#cb8-14" aria-hidden="true" tabindex="-1"></a>        <span class="op">}</span></span>
<span id="cb8-15"><a href="#cb8-15" aria-hidden="true" tabindex="-1"></a>    <span class="op">}</span></span>
<span id="cb8-16"><a href="#cb8-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-17"><a href="#cb8-17" aria-hidden="true" tabindex="-1"></a>    <span class="kw">fn</span> increment(<span class="op">&amp;</span><span class="kw">mut</span> <span class="kw">self</span>) <span class="op">{</span></span>
<span id="cb8-18"><a href="#cb8-18" aria-hidden="true" tabindex="-1"></a>        <span class="kw">self</span><span class="op">.</span>value <span class="op">+=</span> <span class="dv">1</span><span class="op">;</span></span>
<span id="cb8-19"><a href="#cb8-19" aria-hidden="true" tabindex="-1"></a>    <span class="op">}</span></span>
<span id="cb8-20"><a href="#cb8-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-21"><a href="#cb8-21" aria-hidden="true" tabindex="-1"></a>    <span class="kw">fn</span> decrement(<span class="op">&amp;</span><span class="kw">mut</span> <span class="kw">self</span>) <span class="op">{</span></span>
<span id="cb8-22"><a href="#cb8-22" aria-hidden="true" tabindex="-1"></a>        <span class="kw">self</span><span class="op">.</span>value <span class="op">-=</span> <span class="dv">1</span><span class="op">;</span></span>
<span id="cb8-23"><a href="#cb8-23" aria-hidden="true" tabindex="-1"></a>    <span class="op">}</span></span>
<span id="cb8-24"><a href="#cb8-24" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-25"><a href="#cb8-25" aria-hidden="true" tabindex="-1"></a>    <span class="at">#[</span>getter<span class="at">]</span></span>
<span id="cb8-26"><a href="#cb8-26" aria-hidden="true" tabindex="-1"></a>    <span class="kw">fn</span> value(<span class="op">&amp;</span><span class="kw">self</span>) <span class="op">-&gt;</span> <span class="dt">i64</span> <span class="op">{</span></span>
<span id="cb8-27"><a href="#cb8-27" aria-hidden="true" tabindex="-1"></a>        <span class="kw">self</span><span class="op">.</span>value</span>
<span id="cb8-28"><a href="#cb8-28" aria-hidden="true" tabindex="-1"></a>    <span class="op">}</span></span>
<span id="cb8-29"><a href="#cb8-29" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-30"><a href="#cb8-30" aria-hidden="true" tabindex="-1"></a>    <span class="at">#[</span>setter<span class="at">]</span></span>
<span id="cb8-31"><a href="#cb8-31" aria-hidden="true" tabindex="-1"></a>    <span class="kw">fn</span> set_value(<span class="op">&amp;</span><span class="kw">mut</span> <span class="kw">self</span><span class="op">,</span> value<span class="op">:</span> <span class="dt">i64</span>) <span class="op">{</span></span>
<span id="cb8-32"><a href="#cb8-32" aria-hidden="true" tabindex="-1"></a>        <span class="kw">self</span><span class="op">.</span>value <span class="op">=</span> value<span class="op">;</span></span>
<span id="cb8-33"><a href="#cb8-33" aria-hidden="true" tabindex="-1"></a>    <span class="op">}</span></span>
<span id="cb8-34"><a href="#cb8-34" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-35"><a href="#cb8-35" aria-hidden="true" tabindex="-1"></a>    <span class="kw">fn</span> __str__(<span class="op">&amp;</span><span class="kw">self</span>) <span class="op">-&gt;</span> <span class="dt">String</span> <span class="op">{</span></span>
<span id="cb8-36"><a href="#cb8-36" aria-hidden="true" tabindex="-1"></a>        <span class="pp">format!</span>(<span class="st">"Counter({})"</span><span class="op">,</span> <span class="kw">self</span><span class="op">.</span>value)</span>
<span id="cb8-37"><a href="#cb8-37" aria-hidden="true" tabindex="-1"></a>    <span class="op">}</span></span>
<span id="cb8-38"><a href="#cb8-38" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span>
<span id="cb8-39"><a href="#cb8-39" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-40"><a href="#cb8-40" aria-hidden="true" tabindex="-1"></a><span class="co">// Add to your module function:</span></span>
<span id="cb8-41"><a href="#cb8-41" aria-hidden="true" tabindex="-1"></a><span class="co">// m.add_class::&lt;Counter&gt;()?;</span></span></code></pre></div></div>
</div>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Class Attributes
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li><code>#[pyclass]</code>: Makes a Rust struct available as a Python class</li>
<li><code>#[pymethods]</code>: Groups methods for a Python class</li>
<li><code>#[new]</code>: Constructor method</li>
<li><code>#[getter]</code>/<code>#[setter]</code>: Property accessors</li>
</ul>
</div>
</div>
</section>
<section id="error-handling" class="level3">
<h3 class="anchored" data-anchor-id="error-handling" id="error-handling">Error Handling</h3>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>Error Handling</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb9" data-filename="Error Handling"><pre class="sourceCode rust code-with-copy"><code class="sourceCode rust"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="kw">use</span> <span class="pp">pyo3::prelude::</span><span class="op">*;</span></span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a><span class="kw">use</span> <span class="pp">pyo3::exceptions::</span>PyValueError<span class="op">;</span></span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a><span class="at">#[</span>pyfunction<span class="at">]</span></span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a><span class="kw">fn</span> divide(a<span class="op">:</span> <span class="dt">f64</span><span class="op">,</span> b<span class="op">:</span> <span class="dt">f64</span>) <span class="op">-&gt;</span> PyResult<span class="op">&lt;</span><span class="dt">f64</span><span class="op">&gt;</span> <span class="op">{</span></span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> b <span class="op">==</span> <span class="dv">0.0</span> <span class="op">{</span></span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a>        <span class="cn">Err</span>(<span class="pp">PyValueError::</span>new_err(<span class="st">"Cannot divide by zero"</span>))</span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a>    <span class="op">}</span> <span class="cf">else</span> <span class="op">{</span></span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a>        <span class="cn">Ok</span>(a <span class="op">/</span> b)</span>
<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a>    <span class="op">}</span></span>
<span id="cb9-11"><a href="#cb9-11" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span>
<span id="cb9-12"><a href="#cb9-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-13"><a href="#cb9-13" aria-hidden="true" tabindex="-1"></a><span class="co">// Custom exception</span></span>
<span id="cb9-14"><a href="#cb9-14" aria-hidden="true" tabindex="-1"></a><span class="kw">use</span> <span class="pp">pyo3::</span>create_exception<span class="op">;</span></span>
<span id="cb9-15"><a href="#cb9-15" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-16"><a href="#cb9-16" aria-hidden="true" tabindex="-1"></a><span class="pp">create_exception!</span>(my_rust_python_package<span class="op">,</span> CustomError<span class="op">,</span> <span class="pp">pyo3::exceptions::</span>PyException)<span class="op">;</span></span>
<span id="cb9-17"><a href="#cb9-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-18"><a href="#cb9-18" aria-hidden="true" tabindex="-1"></a><span class="at">#[</span>pyfunction<span class="at">]</span></span>
<span id="cb9-19"><a href="#cb9-19" aria-hidden="true" tabindex="-1"></a><span class="kw">fn</span> might_fail(should_fail<span class="op">:</span> <span class="dt">bool</span>) <span class="op">-&gt;</span> PyResult<span class="op">&lt;</span><span class="dt">String</span><span class="op">&gt;</span> <span class="op">{</span></span>
<span id="cb9-20"><a href="#cb9-20" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> should_fail <span class="op">{</span></span>
<span id="cb9-21"><a href="#cb9-21" aria-hidden="true" tabindex="-1"></a>        <span class="cn">Err</span>(<span class="pp">CustomError::</span>new_err(<span class="st">"Something went wrong!"</span>))</span>
<span id="cb9-22"><a href="#cb9-22" aria-hidden="true" tabindex="-1"></a>    <span class="op">}</span> <span class="cf">else</span> <span class="op">{</span></span>
<span id="cb9-23"><a href="#cb9-23" aria-hidden="true" tabindex="-1"></a>        <span class="cn">Ok</span>(<span class="st">"Success!"</span><span class="op">.</span>to_string())</span>
<span id="cb9-24"><a href="#cb9-24" aria-hidden="true" tabindex="-1"></a>    <span class="op">}</span></span>
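<span id="cb9-25"><a href="#cb9-25" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span>
<span id="cb9-26"><a href="#cb9-26" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-27"><a href="#cb9-27" aria-hidden="true" tabindex="-1"></a><span class="co">// Register the exception in your module function so Python code can import and catch it</span></span>
<span id="cb9-28"><a href="#cb9-28" aria-hidden="true" tabindex="-1"></a><span class="co">// (PyO3 0.20-style API; later releases changed how the type object is obtained):</span></span>
<span id="cb9-29"><a href="#cb9-29" aria-hidden="true" tabindex="-1"></a><span class="co">// m.add("CustomError", py.get_type::&lt;CustomError&gt;())?;</span></span></code></pre></div></div>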
</div>
</section>
</section>
<section id="building-and-testing" class="level2">
<h2 class="anchored" data-anchor-id="building-and-testing" id="building-and-testing">Building and Testing</h2>
<section id="development-build" class="level3">
<h3 class="anchored" data-anchor-id="development-build" id="development-build">Development Build</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb10"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Build and install the package into the active virtualenv or conda environment (unoptimized debug build)</span></span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a><span class="ex">maturin</span> develop</span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Or build an optimized (release profile) binary</span></span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a><span class="ex">maturin</span> develop <span class="at">--release</span></span></code></pre></div></div>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Development vs Release
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li>Debug builds (the default for <code>maturin develop</code>) are faster to compile but slower to run</li>
<li>Release builds (<code>--release</code>) are optimized for performance</li>
<li>Use debug builds while iterating and release builds for benchmarking</li>
</ul>
</div>
</div>
</section>
<section id="production-build" class="level3">
<h3 class="anchored" data-anchor-id="production-build" id="production-build">Production Build</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb11"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Build wheel for current platform</span></span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a><span class="ex">maturin</span> build <span class="at">--release</span></span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Build for a specific target triple (cross-compiling needs that Rust target installed and a suitable linker)</span></span>
<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a><span class="ex">maturin</span> build <span class="at">--release</span> <span class="at">--target</span> x86_64-unknown-linux-gnu</span></code></pre></div></div>
</section>
<section id="testing-the-package" class="level3">
<h3 class="anchored" data-anchor-id="testing-the-package" id="testing-the-package">Testing the Package</h3>
<p>Create a test script <code>test_package.py</code>:</p>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>test_package.py</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb12" data-filename="test_package.py"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><a href="#cb12-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> my_rust_python_package <span class="im">as</span> pkg</span>
<span id="cb12-2"><a href="#cb12-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-3"><a href="#cb12-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Test basic functions</span></span>
<span id="cb12-4"><a href="#cb12-4" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(pkg.sum_as_string(<span class="dv">5</span>, <span class="dv">20</span>))  <span class="co"># "25"</span></span>
<span id="cb12-5"><a href="#cb12-5" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(pkg.multiply(<span class="fl">3.5</span>, <span class="fl">2.0</span>))    <span class="co"># 7.0</span></span>
<span id="cb12-6"><a href="#cb12-6" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(pkg.fibonacci(<span class="dv">10</span>))         <span class="co"># 55</span></span>
<span id="cb12-7"><a href="#cb12-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-8"><a href="#cb12-8" aria-hidden="true" tabindex="-1"></a><span class="co"># Test class</span></span>
<span id="cb12-9"><a href="#cb12-9" aria-hidden="true" tabindex="-1"></a>counter <span class="op">=</span> pkg.Counter(<span class="dv">10</span>)</span>
<span id="cb12-10"><a href="#cb12-10" aria-hidden="true" tabindex="-1"></a>counter.increment()</span>
<span id="cb12-11"><a href="#cb12-11" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(counter.value)  <span class="co"># 11</span></span>
<span id="cb12-12"><a href="#cb12-12" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="bu">str</span>(counter))   <span class="co"># "Counter(11)"</span></span>
<span id="cb12-13"><a href="#cb12-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-14"><a href="#cb12-14" aria-hidden="true" tabindex="-1"></a><span class="co"># Test error handling</span></span>
<span id="cb12-15"><a href="#cb12-15" aria-hidden="true" tabindex="-1"></a><span class="cf">try</span>:</span>
<span id="cb12-16"><a href="#cb12-16" aria-hidden="true" tabindex="-1"></a>    pkg.divide(<span class="dv">10</span>, <span class="dv">0</span>)</span>
<span id="cb12-17"><a href="#cb12-17" aria-hidden="true" tabindex="-1"></a><span class="cf">except</span> <span class="pp">ValueError</span> <span class="im">as</span> e:</span>
<span id="cb12-18"><a href="#cb12-18" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Caught error: </span><span class="sc">{</span>e<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
</div>
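<p>The script above only prints results, so a regression would go unnoticed unless someone reads the output. One way to tighten it is to compare the Rust functions against simple pure-Python reference implementations. The helpers below are a hypothetical sketch, not part of the package: <code>fib_reference</code> assumes the same convention as the printed example (<code>fibonacci(10) == 55</code>), and <code>check_against_reference</code> returns every argument tuple on which the two implementations disagree.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">from typing import Callable, Iterable, List, Tuple


def fib_reference(n: int) -> int:
    """Iterative Fibonacci with fib(0) == 0 and fib(1) == 1 (so fib(10) == 55)."""
    a, b = 0, 1
    for _ in range(n):
        a, b = b, a + b
    return a


def sum_as_string_reference(a: int, b: int) -> str:
    """Pure-Python model of sum_as_string."""
    return str(a + b)


def check_against_reference(
    candidate: Callable, reference: Callable, cases: Iterable[Tuple]
) -> List[Tuple]:
    """Return the argument tuples where candidate(*args) != reference(*args)."""
    return [args for args in cases if candidate(*args) != reference(*args)]


# Usage against the compiled extension (assumes it was installed with `maturin develop`):
# import my_rust_python_package as pkg
# assert check_against_reference(pkg.fibonacci, fib_reference, [(n,) for n in range(30)]) == []
# assert check_against_reference(pkg.sum_as_string, sum_as_string_reference, [(5, 20), (0, 0)]) == []
</code></pre></div></div>
<p>An empty list means the two implementations agreed on every case; a non-empty one points straight at the failing inputs.</p>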
</section>
</section>
<section id="python-integration" class="level2">
<h2 class="anchored" data-anchor-id="python-integration" id="python-integration">Python Integration</h2>
<section id="package-initialization" class="level3">
<h3 class="anchored" data-anchor-id="package-initialization" id="package-initialization">Package Initialization</h3>
<p>Edit <code>python/my_rust_python_package/__init__.py</code>:</p>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>python/my_rust_python_package/__init__.py</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb13" data-filename="python/my_rust_python_package/__init__.py"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> .my_rust_python_package <span class="im">import</span> <span class="op">*</span></span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a>__version__ <span class="op">=</span> <span class="st">"0.1.0"</span></span>
<span id="cb13-4"><a href="#cb13-4" aria-hidden="true" tabindex="-1"></a>__author__ <span class="op">=</span> <span class="st">"Your Name"</span></span>
<span id="cb13-5"><a href="#cb13-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-6"><a href="#cb13-6" aria-hidden="true" tabindex="-1"></a><span class="co"># You can add pure Python code here too</span></span>
<span id="cb13-7"><a href="#cb13-7" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> python_helper_function(data):</span>
<span id="cb13-8"><a href="#cb13-8" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""A helper function written in Python."""</span></span>
<span id="cb13-9"><a href="#cb13-9" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> [fibonacci(x) <span class="cf">for</span> x <span class="kw">in</span> data <span class="cf">if</span> x <span class="op">&gt;</span> <span class="dv">0</span>]</span></code></pre></div></div>
</div>
</section>
<section id="type-hints" class="level3">
<h3 class="anchored" data-anchor-id="type-hints" id="type-hints">Type Hints</h3>
<p>Create <code>python/my_rust_python_package/__init__.pyi</code>:</p>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>python/my_rust_python_package/__init__.pyi</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb14" data-filename="python/my_rust_python_package/__init__.pyi"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb14-1"><a href="#cb14-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> typing <span class="im">import</span> List, Dict, Any, Optional</span>
<span id="cb14-2"><a href="#cb14-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-3"><a href="#cb14-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> sum_as_string(a: <span class="bu">int</span>, b: <span class="bu">int</span>) <span class="op">-&gt;</span> <span class="bu">str</span>: ...</span>
<span id="cb14-4"><a href="#cb14-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> multiply(a: <span class="bu">float</span>, b: <span class="bu">float</span>) <span class="op">-&gt;</span> <span class="bu">float</span>: ...</span>
<span id="cb14-5"><a href="#cb14-5" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> fibonacci(n: <span class="bu">int</span>) <span class="op">-&gt;</span> <span class="bu">int</span>: ...</span>
<span id="cb14-6"><a href="#cb14-6" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> process_list(lst: List[<span class="bu">float</span>]) <span class="op">-&gt;</span> List[<span class="bu">float</span>]: ...</span>
<span id="cb14-7"><a href="#cb14-7" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> process_dict(d: Dict[<span class="bu">str</span>, Any]) <span class="op">-&gt;</span> <span class="bu">float</span>: ...</span>
<span id="cb14-8"><a href="#cb14-8" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> divide(a: <span class="bu">float</span>, b: <span class="bu">float</span>) <span class="op">-&gt;</span> <span class="bu">float</span>: ...</span>
<span id="cb14-9"><a href="#cb14-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-10"><a href="#cb14-10" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> Counter:</span>
<span id="cb14-11"><a href="#cb14-11" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, initial_value: Optional[<span class="bu">int</span>] <span class="op">=</span> <span class="va">None</span>) <span class="op">-&gt;</span> <span class="va">None</span>: ...</span>
<span id="cb14-12"><a href="#cb14-12" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> increment(<span class="va">self</span>) <span class="op">-&gt;</span> <span class="va">None</span>: ...</span>
<span id="cb14-13"><a href="#cb14-13" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> decrement(<span class="va">self</span>) <span class="op">-&gt;</span> <span class="va">None</span>: ...</span>
<span id="cb14-14"><a href="#cb14-14" aria-hidden="true" tabindex="-1"></a>    <span class="at">@property</span></span>
<span id="cb14-15"><a href="#cb14-15" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> value(<span class="va">self</span>) <span class="op">-&gt;</span> <span class="bu">int</span>: ...</span>
<span id="cb14-16"><a href="#cb14-16" aria-hidden="true" tabindex="-1"></a>    <span class="at">@value.setter</span></span>
<span id="cb14-17"><a href="#cb14-17" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> value(<span class="va">self</span>, value: <span class="bu">int</span>) <span class="op">-&gt;</span> <span class="va">None</span>: ...</span>
<span id="cb14-18"><a href="#cb14-18" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__str__</span>(<span class="va">self</span>) <span class="op">-&gt;</span> <span class="bu">str</span>: ...</span>
<span id="cb14-19"><a href="#cb14-19" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-20"><a href="#cb14-20" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> CustomError(<span class="pp">Exception</span>): ...</span></code></pre></div></div>
</div>
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Important</span>Type Stub Files
</div>
</div>
<div class="callout-body-container callout-body">
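<p>Type stub files (<code>.pyi</code>) provide type information to Python tooling such as mypy, IDEs, and other static analyzers, which otherwise see little or no signature information for a compiled extension module. To make type checkers use the stubs from the installed package, also ship an empty <code>py.typed</code> marker file in the package directory (PEP 561).</p>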
</div>
</div>
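<p>Because these stubs are written by hand, they can drift from what the compiled module actually exports. A minimal drift check, sketched here with the standard-library <code>ast</code> module (the helper names are illustrative), lists the top-level names declared in the stub that the imported module does not provide. It does not compare signatures; mypy's <code>stubtest</code> tool (<code>python -m mypy.stubtest my_rust_python_package</code>) performs a more thorough comparison.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import ast
import types
from typing import Set


def stub_names(stub_source: str) -> Set[str]:
    """Top-level function and class names declared in .pyi source text."""
    tree = ast.parse(stub_source)
    return {
        node.name
        for node in tree.body
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef))
    }


def missing_from_module(module: types.ModuleType, stub_source: str) -> Set[str]:
    """Names promised by the stub that the runtime module does not define."""
    return {name for name in stub_names(stub_source) if not hasattr(module, name)}


# Usage (assumes the package is installed and this runs from the project root):
# import pathlib
# import my_rust_python_package as pkg
# stub = pathlib.Path("python/my_rust_python_package/__init__.pyi").read_text()
# print(missing_from_module(pkg, stub))  # an empty set means every stubbed name exists
</code></pre></div></div>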
</section>
</section>
<section id="performance-optimization" class="level2">
<h2 class="anchored" data-anchor-id="performance-optimization" id="performance-optimization">Performance Optimization</h2>
<section id="using-rusts-parallel-processing" class="level3">
<h3 class="anchored" data-anchor-id="using-rusts-parallel-processing" id="using-rusts-parallel-processing">Using Rust’s Parallel Processing</h3>
<p>Add to <code>Cargo.toml</code>:</p>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>Cargo.toml - Add Rayon</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb15" data-filename="Cargo.toml - Add Rayon"><pre class="sourceCode toml code-with-copy"><code class="sourceCode toml"><span id="cb15-1"><a href="#cb15-1" aria-hidden="true" tabindex="-1"></a><span class="kw">[dependencies]</span></span>
<span id="cb15-2"><a href="#cb15-2" aria-hidden="true" tabindex="-1"></a><span class="dt">rayon</span> <span class="op">=</span> <span class="st">"1.7"</span></span></code></pre></div></div>
</div>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>Parallel Processing</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb16" data-filename="Parallel Processing"><pre class="sourceCode rust code-with-copy"><code class="sourceCode rust"><span id="cb16-1"><a href="#cb16-1" aria-hidden="true" tabindex="-1"></a><span class="kw">use</span> <span class="pp">pyo3::prelude::</span><span class="op">*;</span></span>
<span id="cb16-2"><a href="#cb16-2" aria-hidden="true" tabindex="-1"></a><span class="kw">use</span> <span class="pp">rayon::prelude::</span><span class="op">*;</span></span>
<span id="cb16-3"><a href="#cb16-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-4"><a href="#cb16-4" aria-hidden="true" tabindex="-1"></a><span class="at">#[</span>pyfunction<span class="at">]</span></span>
<span id="cb16-5"><a href="#cb16-5" aria-hidden="true" tabindex="-1"></a><span class="kw">fn</span> parallel_sum(numbers<span class="op">:</span> <span class="dt">Vec</span><span class="op">&lt;</span><span class="dt">f64</span><span class="op">&gt;</span>) <span class="op">-&gt;</span> <span class="dt">f64</span> <span class="op">{</span></span>
<span id="cb16-6"><a href="#cb16-6" aria-hidden="true" tabindex="-1"></a>    <span class="co">// Rayon splits the work across its global thread pool</span></span>
<span id="cb16-7"><a href="#cb16-7" aria-hidden="true" tabindex="-1"></a>    numbers<span class="op">.</span>par_iter()<span class="op">.</span>sum()</span>
<span id="cb16-8"><a href="#cb16-8" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span>
<span id="cb16-9"><a href="#cb16-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-10"><a href="#cb16-10" aria-hidden="true" tabindex="-1"></a><span class="at">#[</span>pyfunction<span class="at">]</span></span>
<span id="cb16-11"><a href="#cb16-11" aria-hidden="true" tabindex="-1"></a><span class="kw">fn</span> parallel_fibonacci(numbers<span class="op">:</span> <span class="dt">Vec</span><span class="op">&lt;</span><span class="dt">u64</span><span class="op">&gt;</span>) <span class="op">-&gt;</span> <span class="dt">Vec</span><span class="op">&lt;</span><span class="dt">u64</span><span class="op">&gt;</span> <span class="op">{</span></span>
<span id="cb16-12"><a href="#cb16-12" aria-hidden="true" tabindex="-1"></a>    <span class="co">// Calls the `fibonacci` #[pyfunction] defined earlier (assumed u64 -&gt; u64); it is still an ordinary Rust fn</span></span>
<span id="cb16-13"><a href="#cb16-13" aria-hidden="true" tabindex="-1"></a>    numbers<span class="op">.</span>par_iter()<span class="op">.</span>map(<span class="op">|&amp;</span>n<span class="op">|</span> fibonacci(n))<span class="op">.</span>collect()</span>
<span id="cb16-14"><a href="#cb16-14" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span></code></pre></div></div>
</div>
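<p>Whether <code>parallel_sum</code> actually beats Python's built-in <code>sum</code> depends on input size: the Python list is first converted into a <code>Vec&lt;f64&gt;</code>, which copies every element, and splitting a small workload across threads can cost more than it saves. A small stdlib-only timing helper makes this easy to measure; <code>best_time</code> and <code>speedup</code> are illustrative names, not part of the package.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import timeit
from typing import Callable, Sequence


def best_time(fn: Callable, *args, repeat: int = 5, number: int = 10) -> float:
    """Best average wall-clock seconds per call of fn(*args) across `repeat` rounds."""
    rounds = timeit.repeat(lambda: fn(*args), repeat=repeat, number=number)
    return min(rounds) / number


def speedup(baseline: Callable, candidate: Callable, data: Sequence) -> float:
    """How many times faster candidate(data) runs than baseline(data); above 1.0 means faster."""
    return best_time(baseline, data) / best_time(candidate, data)


# Usage (assumes the extension was built with `maturin develop --release`):
# import random
# import my_rust_python_package as pkg
# for size in (1_000, 100_000, 10_000_000):
#     data = [random.random() for _ in range(size)]
#     print(size, speedup(sum, pkg.parallel_sum, data))
</code></pre></div></div>
<p>Taking the minimum over several rounds reduces noise from other processes, which matters when the differences being measured are small.</p>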
</section>
<section id="memory-efficient-operations" class="level3">
<h3 class="anchored" data-anchor-id="memory-efficient-operations" id="memory-efficient-operations">Memory-Efficient Operations</h3>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>NumPy Integration</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb17" data-filename="NumPy Integration"><pre class="sourceCode rust code-with-copy"><code class="sourceCode rust"><span id="cb17-1"><a href="#cb17-1" aria-hidden="true" tabindex="-1"></a><span class="kw">use</span> <span class="pp">pyo3::prelude::</span><span class="op">*;</span></span>
<span id="cb17-2"><a href="#cb17-2" aria-hidden="true" tabindex="-1"></a><span class="kw">use</span> <span class="pp">numpy::</span><span class="op">{</span>PyArray1<span class="op">,</span> PyReadonlyArray1<span class="op">};</span></span>
<span id="cb17-3"><a href="#cb17-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-4"><a href="#cb17-4" aria-hidden="true" tabindex="-1"></a><span class="co">// Add numpy to Cargo.toml: numpy = "0.20" (the numpy crate version must match your pyo3 version; 0.20 pairs with pyo3 0.20)</span></span>
<span id="cb17-5"><a href="#cb17-5" aria-hidden="true" tabindex="-1"></a><span class="at">#[</span>pyfunction<span class="at">]</span></span>
<span id="cb17-6"><a href="#cb17-6" aria-hidden="true" tabindex="-1"></a><span class="kw">fn</span> numpy_operation<span class="op">&lt;</span><span class="ot">'py</span><span class="op">&gt;</span>(</span>
<span id="cb17-7"><a href="#cb17-7" aria-hidden="true" tabindex="-1"></a>    py<span class="op">:</span> Python<span class="op">&lt;</span><span class="ot">'py</span><span class="op">&gt;,</span></span>
<span id="cb17-8"><a href="#cb17-8" aria-hidden="true" tabindex="-1"></a>    array<span class="op">:</span> PyReadonlyArray1<span class="op">&lt;</span><span class="ot">'py</span><span class="op">,</span> <span class="dt">f64</span><span class="op">&gt;,</span></span>
<span id="cb17-9"><a href="#cb17-9" aria-hidden="true" tabindex="-1"></a>) <span class="op">-&gt;</span> <span class="op">&amp;</span><span class="ot">'py</span> PyArray1<span class="op">&lt;</span><span class="dt">f64</span><span class="op">&gt;</span> <span class="op">{</span></span>
<span id="cb17-10"><a href="#cb17-10" aria-hidden="true" tabindex="-1"></a>    <span class="kw">let</span> input <span class="op">=</span> array<span class="op">.</span>as_array()<span class="op">;</span> <span class="co">// read-only view of the NumPy buffer, no copy</span></span>
<span id="cb17-11"><a href="#cb17-11" aria-hidden="true" tabindex="-1"></a>    <span class="kw">let</span> result<span class="op">:</span> <span class="dt">Vec</span><span class="op">&lt;</span><span class="dt">f64</span><span class="op">&gt;</span> <span class="op">=</span> input<span class="op">.</span>iter()<span class="op">.</span>map(<span class="op">|&amp;</span>x<span class="op">|</span> x <span class="op">*</span> x)<span class="op">.</span>collect()<span class="op">;</span></span>
<span id="cb17-12"><a href="#cb17-12" aria-hidden="true" tabindex="-1"></a>    <span class="pp">PyArray1::</span>from_vec(py<span class="op">,</span> result)</span>
<span id="cb17-13"><a href="#cb17-13" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span></code></pre></div></div>
</div>
</section>
</section>
<section id="distribution-and-publishing" class="level2">
<h2 class="anchored" data-anchor-id="distribution-and-publishing" id="distribution-and-publishing">Distribution and Publishing</h2>
<section id="building-wheels" class="level3">
<h3 class="anchored" data-anchor-id="building-wheels" id="building-wheels">Building Wheels</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb18"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb18-1"><a href="#cb18-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Build for current platform</span></span>
<span id="cb18-2"><a href="#cb18-2" aria-hidden="true" tabindex="-1"></a><span class="ex">maturin</span> build <span class="at">--release</span></span>
<span id="cb18-3"><a href="#cb18-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-4"><a href="#cb18-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Build for multiple platforms using cibuildwheel (Linux builds run in Docker, and a Rust toolchain must be installed in the build environment)</span></span>
<span id="cb18-5"><a href="#cb18-5" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install cibuildwheel</span>
<span id="cb18-6"><a href="#cb18-6" aria-hidden="true" tabindex="-1"></a><span class="ex">cibuildwheel</span> <span class="at">--platform</span> linux</span></code></pre></div></div>
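<p>Both commands produce <code>.whl</code> files whose names encode the interpreter, ABI, and platform they target, following the wheel naming convention <code>{name}-{version}(-{build})?-{python}-{abi}-{platform}.whl</code> (PEP 427). To confirm which platforms a build actually covered, a small stdlib-only helper can read those tags. The function names are hypothetical, and the default directory assumes maturin's <code>target/wheels</code> output (cibuildwheel writes to <code>wheelhouse</code> by default).</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import pathlib
from typing import Dict, List


def wheel_tags(filename: str) -> Dict[str, str]:
    """Split a wheel filename into its PEP 427 fields (the optional build tag is ignored)."""
    if not filename.endswith(".whl"):
        raise ValueError(f"not a wheel filename: {filename}")
    parts = filename[: -len(".whl")].split("-")
    if len(parts) not in (5, 6):
        raise ValueError(f"not a wheel filename: {filename}")
    return {
        "name": parts[0],
        "version": parts[1],
        "python": parts[-3],
        "abi": parts[-2],
        "platform": parts[-1],
    }


def built_platforms(wheel_dir: str = "target/wheels") -> List[str]:
    """Sorted, de-duplicated platform tags of every wheel found in wheel_dir."""
    return sorted({wheel_tags(p.name)["platform"] for p in pathlib.Path(wheel_dir).glob("*.whl")})


# Usage after `maturin build --release` (or pass "wheelhouse" for cibuildwheel output):
# print(built_platforms())
</code></pre></div></div>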
</section>
<section id="github-actions-cicd" class="level3">
<h3 class="anchored" data-anchor-id="github-actions-cicd" id="github-actions-cicd">GitHub Actions CI/CD</h3>
<p>Create <code>.github/workflows/ci.yml</code>:</p>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>.github/workflows/ci.yml</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb19" data-filename=".github/workflows/ci.yml"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb19-1"><a href="#cb19-1" aria-hidden="true" tabindex="-1"></a><span class="fu">name</span><span class="kw">:</span><span class="at"> CI</span></span>
<span id="cb19-2"><a href="#cb19-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-3"><a href="#cb19-3" aria-hidden="true" tabindex="-1"></a><span class="fu">on</span><span class="kw">:</span></span>
<span id="cb19-4"><a href="#cb19-4" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">push</span><span class="kw">:</span></span>
<span id="cb19-5"><a href="#cb19-5" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">pull_request</span><span class="kw">:</span></span>
<span id="cb19-6"><a href="#cb19-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-7"><a href="#cb19-7" aria-hidden="true" tabindex="-1"></a><span class="fu">jobs</span><span class="kw">:</span></span>
<span id="cb19-8"><a href="#cb19-8" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">test</span><span class="kw">:</span></span>
<span id="cb19-9"><a href="#cb19-9" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">runs-on</span><span class="kw">:</span><span class="at"> ${{ matrix.os }}</span></span>
<span id="cb19-10"><a href="#cb19-10" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">strategy</span><span class="kw">:</span></span>
<span id="cb19-11"><a href="#cb19-11" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">matrix</span><span class="kw">:</span></span>
<span id="cb19-12"><a href="#cb19-12" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">os</span><span class="kw">:</span><span class="at"> </span><span class="kw">[</span><span class="at">ubuntu-latest</span><span class="kw">,</span><span class="at"> windows-latest</span><span class="kw">,</span><span class="at"> macos-latest</span><span class="kw">]</span></span>
<span id="cb19-13"><a href="#cb19-13" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">python-version</span><span class="kw">:</span><span class="at"> </span><span class="kw">[</span><span class="st">'3.8'</span><span class="kw">,</span><span class="at"> </span><span class="st">'3.9'</span><span class="kw">,</span><span class="at"> </span><span class="st">'3.10'</span><span class="kw">,</span><span class="at"> </span><span class="st">'3.11'</span><span class="kw">]</span></span>
<span id="cb19-14"><a href="#cb19-14" aria-hidden="true" tabindex="-1"></a><span class="at">    </span></span>
<span id="cb19-15"><a href="#cb19-15" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">steps</span><span class="kw">:</span></span>
<span id="cb19-16"><a href="#cb19-16" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="kw">-</span><span class="at"> </span><span class="fu">uses</span><span class="kw">:</span><span class="at"> actions/checkout@v4</span></span>
<span id="cb19-17"><a href="#cb19-17" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="kw">-</span><span class="at"> </span><span class="fu">uses</span><span class="kw">:</span><span class="at"> actions/setup-python@v5</span></span>
<span id="cb19-18"><a href="#cb19-18" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">with</span><span class="kw">:</span></span>
<span id="cb19-19"><a href="#cb19-19" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">python-version</span><span class="kw">:</span><span class="at"> ${{ matrix.python-version }}</span></span>
<span id="cb19-20"><a href="#cb19-20" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="kw">-</span><span class="at"> </span><span class="fu">uses</span><span class="kw">:</span><span class="at"> dtolnay/rust-toolchain@stable</span></span>
<span id="cb19-21"><a href="#cb19-21" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> Install maturin</span></span>
<span id="cb19-22"><a href="#cb19-22" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">run</span><span class="kw">:</span><span class="at"> pip install maturin pytest</span></span>
<span id="cb19-23"><a href="#cb19-23" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> Build and test</span></span>
<span id="cb19-24"><a href="#cb19-24" aria-hidden="true" tabindex="-1"></a><span class="fu">      run</span><span class="kw">: </span><span class="ch">|</span></span>
<span id="cb19-25"><a href="#cb19-25" aria-hidden="true" tabindex="-1"></a>        pip install .</span>
<span id="cb19-26"><a href="#cb19-26" aria-hidden="true" tabindex="-1"></a>        pytest tests/</span>
<span id="cb19-27"><a href="#cb19-27" aria-hidden="true" tabindex="-1"></a><span class="at">  </span></span>
<span id="cb19-28"><a href="#cb19-28" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">build</span><span class="kw">:</span></span>
<span id="cb19-29"><a href="#cb19-29" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">runs-on</span><span class="kw">:</span><span class="at"> ${{ matrix.os }}</span></span>
<span id="cb19-30"><a href="#cb19-30" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">strategy</span><span class="kw">:</span></span>
<span id="cb19-31"><a href="#cb19-31" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">matrix</span><span class="kw">:</span></span>
<span id="cb19-32"><a href="#cb19-32" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">os</span><span class="kw">:</span><span class="at"> </span><span class="kw">[</span><span class="at">ubuntu-latest</span><span class="kw">,</span><span class="at"> windows-latest</span><span class="kw">,</span><span class="at"> macos-latest</span><span class="kw">]</span></span>
<span id="cb19-33"><a href="#cb19-33" aria-hidden="true" tabindex="-1"></a><span class="at">    </span></span>
<span id="cb19-34"><a href="#cb19-34" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">steps</span><span class="kw">:</span></span>
<span id="cb19-35"><a href="#cb19-35" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="kw">-</span><span class="at"> </span><span class="fu">uses</span><span class="kw">:</span><span class="at"> actions/checkout@v4</span></span>
<span id="cb19-36"><a href="#cb19-36" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="kw">-</span><span class="at"> </span><span class="fu">uses</span><span class="kw">:</span><span class="at"> dtolnay/rust-toolchain@stable</span></span>
<span id="cb19-37"><a href="#cb19-37" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="kw">-</span><span class="at"> </span><span class="fu">uses</span><span class="kw">:</span><span class="at"> actions/setup-python@v5</span></span>
<span id="cb19-38"><a href="#cb19-38" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">with</span><span class="kw">:</span></span>
<span id="cb19-39"><a href="#cb19-39" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">python-version</span><span class="kw">:</span><span class="at"> </span><span class="st">'3.x'</span></span>
<span id="cb19-40"><a href="#cb19-40" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> Build wheels</span></span>
<span id="cb19-41"><a href="#cb19-41" aria-hidden="true" tabindex="-1"></a><span class="fu">      run</span><span class="kw">: </span><span class="ch">|</span></span>
<span id="cb19-42"><a href="#cb19-42" aria-hidden="true" tabindex="-1"></a>        pip install maturin</span>
<span id="cb19-43"><a href="#cb19-43" aria-hidden="true" tabindex="-1"></a>        maturin build --release</span>
<span id="cb19-44"><a href="#cb19-44" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="kw">-</span><span class="at"> </span><span class="fu">uses</span><span class="kw">:</span><span class="at"> actions/upload-artifact@v4</span></span>
<span id="cb19-45"><a href="#cb19-45" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">with</span><span class="kw">:</span></span>
<span id="cb19-46"><a href="#cb19-46" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">name</span><span class="kw">:</span><span class="at"> wheels-${{ matrix.os }}</span></span>
<span id="cb19-47"><a href="#cb19-47" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">path</span><span class="kw">:</span><span class="at"> target/wheels</span></span></code></pre></div></div>
</div>
</section>
<section id="publishing-to-pypi" class="level3">
<h3 class="anchored" data-anchor-id="publishing-to-pypi" id="publishing-to-pypi">Publishing to PyPI</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb20"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb20-1"><a href="#cb20-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Install twine</span></span>
<span id="cb20-2"><a href="#cb20-2" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install twine</span>
<span id="cb20-3"><a href="#cb20-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb20-4"><a href="#cb20-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Build a release wheel (for the current platform only)</span></span>
<span id="cb20-5"><a href="#cb20-5" aria-hidden="true" tabindex="-1"></a><span class="ex">maturin</span> build <span class="at">--release</span></span>
<span id="cb20-6"><a href="#cb20-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb20-7"><a href="#cb20-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Upload to PyPI</span></span>
<span id="cb20-8"><a href="#cb20-8" aria-hidden="true" tabindex="-1"></a><span class="ex">twine</span> upload target/wheels/<span class="pp">*</span></span></code></pre></div></div>
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Warning</span>Publishing Checklist
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li>Test your package thoroughly before publishing</li>
<li>Use semantic versioning</li>
<li>Include comprehensive documentation</li>
<li>Test installation on clean environments</li>
</ul>
</div>
</div>
</section>
</section>
<section id="best-practices" class="level2">
<h2 class="anchored" data-anchor-id="best-practices" id="best-practices">Best Practices</h2>
<section id="error-handling-1" class="level3">
<h3 class="anchored" data-anchor-id="error-handling-1" id="error-handling-1">1. Error Handling</h3>
<ul>
<li>Always use <code>PyResult&lt;T&gt;</code> for functions that might fail</li>
<li>Create custom exceptions for domain-specific errors</li>
<li>Provide clear error messages</li>
</ul>
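<p>Here is a minimal sketch of a domain-specific error type in plain Rust. It assumes a hypothetical <code>parse_ratio</code> function. Each failure mode gets its own enum variant with a clear <code>Display</code> message, so the core logic can be tested without Python. In a PyO3 extension you would then convert these errors into Python exceptions at the boundary, for example by implementing <code>From&lt;RatioError&gt; for PyErr</code>. That conversion is omitted here because it needs the <code>pyo3</code> crate.</p>

```rust
use std::fmt;

/// Hypothetical domain errors for a ratio parser.
#[derive(Debug, PartialEq)]
pub enum RatioError {
    /// The input was not of the form "a/b".
    Malformed(String),
    /// The denominator was zero.
    ZeroDenominator,
}

impl fmt::Display for RatioError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            RatioError::Malformed(input) => {
                write!(f, "expected a ratio like '3/4', got {:?}", input)
            }
            RatioError::ZeroDenominator => write!(f, "denominator must not be zero"),
        }
    }
}

impl std::error::Error for RatioError {}

/// Parse a string such as "3/4" into a floating-point value.
pub fn parse_ratio(input: &str) -> Result<f64, RatioError> {
    // Closure that builds the "malformed input" error (it is Copy, so it can be reused).
    let malformed = || RatioError::Malformed(input.to_string());
    let (num, den) = input.split_once('/').ok_or_else(malformed)?;
    let num = num.trim().parse::<f64>().map_err(|_| malformed())?;
    let den = den.trim().parse::<f64>().map_err(|_| malformed())?;
    if den == 0.0 {
        return Err(RatioError::ZeroDenominator);
    }
    Ok(num / den)
}

fn main() {
    for input in ["3/4", "3:4", "1/0"] {
        match parse_ratio(input) {
            Ok(value) => println!("{} -> {}", input, value),
            Err(err) => println!("{} -> error: {}", input, err),
        }
    }
}
```

<p>Keeping each error as a distinct variant lets the Python layer map them to different exception types later, such as <code>ValueError</code> for both cases or a custom exception for the zero-denominator case.</p>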
</section>
<section id="memory-management" class="level3">
<h3 class="anchored" data-anchor-id="memory-management" id="memory-management">2. Memory Management</h3>
<ul>
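<li>Leverage Rust’s ownership system: borrow data (<code>&amp;[T]</code>, <code>&amp;str</code>) instead of cloning it when a function only needs to read it</li>
<li>Use <code>PyReadonlyArray</code> (from the <code>numpy</code> crate) to read NumPy arrays without copying them</li>
<li>Be mindful of the GIL (Global Interpreter Lock): release it around long-running, pure-Rust computations so other Python threads can keep running</li>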
</ul>
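<p>This small pure-Rust example shows how to work with ownership instead of against it. It uses hypothetical <code>mean</code> and <code>normalize_in_place</code> helpers. A function that only reads data borrows a slice (<code>&amp;[f64]</code>). A function that edits data borrows it mutably (<code>&amp;mut [f64]</code>). In both cases the buffer is never copied and the caller keeps ownership. The same idea applies when a read-only NumPy array (<code>PyReadonlyArray</code>) is handed to Rust code as a borrowed view.</p>

```rust
/// Read-only access: borrowing a slice avoids copying or taking ownership.
pub fn mean(data: &[f64]) -> Option<f64> {
    if data.is_empty() {
        return None;
    }
    Some(data.iter().sum::<f64>() / data.len() as f64)
}

/// Mutable access: the caller keeps ownership, the function edits in place.
/// Scales values so the largest absolute value becomes 1.0.
pub fn normalize_in_place(data: &mut [f64]) {
    let max = data.iter().cloned().fold(0.0_f64, |acc, x| acc.max(x.abs()));
    if max > 0.0 {
        for x in data.iter_mut() {
            *x /= max;
        }
    }
}

fn main() {
    let mut readings = vec![2.0, -4.0, 1.0];
    // Both calls borrow `readings`; the underlying buffer is never copied.
    println!("mean = {:?}", mean(&readings));
    normalize_in_place(&mut readings);
    println!("normalized = {:?}", readings);
    // `readings` is still owned here and is dropped at the end of `main`.
}
```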
</section>
<section id="api-design" class="level3">
<h3 class="anchored" data-anchor-id="api-design" id="api-design">3. API Design</h3>
<ul>
<li>Keep the Rust/Python boundary simple</li>
<li>Use appropriate Python types (lists, dicts, etc.)</li>
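<li>Provide comprehensive type hints, for example in a <code>.pyi</code> stub file shipped with the package, since type checkers cannot inspect a compiled extension</li>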
</ul>
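<p>Here is one way to keep the boundary simple. It assumes a hypothetical word-counting feature. The core logic is a plain Rust function over standard types. When such values are returned from a <code>#[pyfunction]</code>, PyO3 converts <code>Vec&lt;T&gt;</code> to a Python <code>list</code>, tuples to <code>tuple</code>, and <code>HashMap&lt;K, V&gt;</code> to <code>dict</code>. The <code>#[pyfunction]</code> wrappers themselves are omitted so the example compiles without the <code>pyo3</code> crate.</p>

```rust
use std::collections::HashMap;

/// Core logic in plain Rust: count how often each lowercase word appears.
/// A HashMap return value maps naturally to a Python dict.
pub fn word_counts(text: &str) -> HashMap<String, usize> {
    let mut counts = HashMap::new();
    for word in text.split_whitespace() {
        // Strip surrounding punctuation and normalize case.
        let word = word
            .trim_matches(|c: char| !c.is_alphanumeric())
            .to_lowercase();
        if !word.is_empty() {
            *counts.entry(word).or_insert(0) += 1;
        }
    }
    counts
}

/// The `n` most frequent words, sorted by descending count (ties broken
/// alphabetically). A Vec of tuples maps to a Python list of tuples.
pub fn top_words(text: &str, n: usize) -> Vec<(String, usize)> {
    let mut pairs: Vec<(String, usize)> = word_counts(text).into_iter().collect();
    pairs.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
    pairs.truncate(n);
    pairs
}

fn main() {
    let text = "The cat and the hat. The end!";
    println!("{:?}", top_words(text, 2));
}
```

<p>With the sample text, <code>top_words</code> returns <code>[("the", 3), ("and", 1)]</code>. The word "the" appears three times, and "and" sorts first among the words that appear once.</p>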
</section>
<section id="testing" class="level3">
<h3 class="anchored" data-anchor-id="testing" id="testing">4. Testing</h3>
<ul>
<li>Write tests for both Rust and Python code</li>
<li>Use property-based testing with Hypothesis for the Python-facing API</li>
<li>Test error conditions thoroughly</li>
</ul>
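<p>This is a minimal sketch of the Rust half of a test suite. It uses Rust’s built-in <code>#[test]</code> attribute, run with <code>cargo test</code>, on a hypothetical <code>to_percent</code> helper. The error cases get their own test. The Python-facing behavior would be covered separately with <code>pytest</code> and Hypothesis.</p>

```rust
/// Hypothetical helper: convert a fraction in [0, 1] to a whole percentage.
pub fn to_percent(fraction: f64) -> Result<u8, String> {
    // `contains` is false for NaN, so NaN is rejected here too.
    if !(0.0..=1.0).contains(&fraction) {
        return Err(format!("fraction must be between 0 and 1, got {}", fraction));
    }
    Ok((fraction * 100.0).round() as u8)
}

fn main() {
    println!("{:?}", to_percent(0.256));
}

// Unit tests live next to the code and run with `cargo test`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn converts_valid_fractions() {
        assert_eq!(to_percent(0.0), Ok(0));
        assert_eq!(to_percent(0.5), Ok(50));
        assert_eq!(to_percent(1.0), Ok(100));
    }

    #[test]
    fn rejects_out_of_range_values() {
        // Error conditions deserve their own tests.
        assert!(to_percent(-0.1).is_err());
        assert!(to_percent(1.5).is_err());
        assert!(to_percent(f64::NAN).is_err());
    }
}
```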
</section>
<section id="documentation" class="level3">
<h3 class="anchored" data-anchor-id="documentation" id="documentation">5. Documentation</h3>
<ul>
<li>Document all public functions and classes</li>
<li>Provide usage examples</li>
<li>Include performance benchmarks when relevant</li>
</ul>
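<p>Here is a short sketch of documenting a public Rust function with <code>///</code> doc comments, using a hypothetical <code>celsius_to_fahrenheit</code> function. <code>cargo test</code> compiles and runs the code in the <code># Examples</code> section as a doc test, so the usage example cannot silently go stale. The crate name <code>temperature</code> in the doc test is also hypothetical. PyO3 additionally exposes a <code>#[pyfunction]</code>’s doc comment as the Python function’s <code>__doc__</code>.</p>

```rust
/// Converts a temperature from degrees Celsius to degrees Fahrenheit.
///
/// # Examples
///
/// ```
/// // `temperature` is a hypothetical crate name used for illustration.
/// use temperature::celsius_to_fahrenheit;
///
/// assert_eq!(celsius_to_fahrenheit(100.0), 212.0);
/// ```
pub fn celsius_to_fahrenheit(celsius: f64) -> f64 {
    celsius * 9.0 / 5.0 + 32.0
}

fn main() {
    // -40 is the same temperature on both scales.
    println!("{}", celsius_to_fahrenheit(-40.0));
}
```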
</section>
</section>
<section id="troubleshooting" class="level2">
<h2 class="anchored" data-anchor-id="troubleshooting" id="troubleshooting">Troubleshooting</h2>
<section id="common-issues" class="level3">
<h3 class="anchored" data-anchor-id="common-issues" id="common-issues">Common Issues</h3>
<ol type="1">
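<li><strong>Import Errors</strong>: Ensure the library name in <code>Cargo.toml</code> (the <code>name</code> under <code>[lib]</code>) matches the name of your <code>#[pymodule]</code> function; a mismatch typically surfaces as <code>ImportError: dynamic module does not define module export function</code></li>
<li><strong>Build Failures</strong>: Check that a Rust toolchain is installed and that every crate you use, including <code>pyo3</code>, is listed under <code>[dependencies]</code> in <code>Cargo.toml</code> with a version that supports your Python interpreter</li>
<li><strong>Type Conversion Errors</strong>: Declare parameters and return values with types PyO3 can convert (for example <code>Vec&lt;T&gt;</code> for lists and <code>HashMap&lt;K, V&gt;</code> for dicts); a <code>TypeError</code> at call time usually means an argument could not be converted to the declared Rust type</li>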
<li><strong>Performance Issues</strong>: Profile both Rust and Python code to identify bottlenecks</li>
</ol>
</section>
<section id="debugging" class="level3">
<h3 class="anchored" data-anchor-id="debugging" id="debugging">Debugging</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb21"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb21-1"><a href="#cb21-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Build with debug symbols</span></span>
<span id="cb21-2"><a href="#cb21-2" aria-hidden="true" tabindex="-1"></a><span class="ex">maturin</span> develop</span>
<span id="cb21-3"><a href="#cb21-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-4"><a href="#cb21-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Use Python debugger</span></span>
<span id="cb21-5"><a href="#cb21-5" aria-hidden="true" tabindex="-1"></a><span class="ex">python</span> <span class="at">-m</span> pdb your_test_script.py</span>
<span id="cb21-6"><a href="#cb21-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-7"><a href="#cb21-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Print a Rust backtrace if the extension panics</span></span>
<span id="cb21-8"><a href="#cb21-8" aria-hidden="true" tabindex="-1"></a><span class="va">RUST_BACKTRACE</span><span class="op">=</span>1 <span class="ex">python</span> your_test_script.py</span></code></pre></div></div>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Debugging Tips
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
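<li>Use <code>println!</code>, <code>eprintln!</code>, or <code>dbg!</code> in Rust for simple debugging</li>
<li>Python’s <code>breakpoint()</code> lets you inspect values on the Python side before and after calls into the extension, but it cannot step into Rust code</li>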
<li>Consider using <code>gdb</code> or <code>lldb</code> for complex debugging scenarios</li>
</ul>
</div>
</div>
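<p>Here is a minimal sketch of print-style debugging in Rust, using a hypothetical <code>checksum</code> function. <code>eprintln!</code> writes to standard error, which keeps debug output separate from regular stdout output. <code>dbg!</code> prints the file, line, expression, and value, then returns the value, so it can wrap an expression in place. A compiled extension writes to the process’s standard streams directly rather than through Python’s <code>sys.stdout</code>. As a result, this output may not appear where <code>sys.stdout</code> is redirected, such as in Jupyter notebook cells.</p>

```rust
/// Hypothetical function under investigation.
pub fn checksum(data: &[u8]) -> u32 {
    let mut total: u32 = 0;
    for (i, byte) in data.iter().enumerate() {
        // eprintln! goes to stderr, so it does not mix with regular stdout output.
        eprintln!("step {}: byte = {}", i, byte);
        // Wrapping arithmetic avoids overflow panics on long inputs.
        total = total.wrapping_mul(31).wrapping_add(u32::from(*byte));
    }
    // dbg! prints file, line, the expression, and its value to stderr,
    // then returns the value unchanged.
    dbg!(total)
}

fn main() {
    let result = checksum(b"abc");
    println!("checksum = {}", result);
}
```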
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>This guide provides a solid foundation for creating Python packages with Rust backends. The combination offers excellent performance while maintaining Python’s ease of use and ecosystem compatibility.</p>
<p>Key takeaways:</p>
<ul>
<li><strong>Setup</strong>: Use maturin for seamless Rust-Python integration</li>
<li><strong>Development</strong>: Leverage PyO3’s powerful binding capabilities</li>
<li><strong>Performance</strong>: Utilize Rust’s speed and Python’s ecosystem</li>
<li><strong>Distribution</strong>: Standard Python packaging tools work seamlessly</li>
</ul>
<p>The Rust-Python ecosystem continues to evolve rapidly, and it is already a strong choice for performance-critical Python applications.</p>
<hr>
</section>
<section id="further-reading" class="level2">
<h2 class="anchored" data-anchor-id="further-reading" id="further-reading">Further Reading</h2>
<ul>
<li><a href="https://pyo3.rs/">PyO3 User Guide</a></li>
<li><a href="https://github.com/PyO3/maturin">Maturin Documentation</a></li>
<li><a href="https://doc.rust-lang.org/book/">Rust Book</a></li>
<li><a href="https://packaging.python.org/">Python Packaging User Guide</a></li>
</ul>


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Getting Started with Rust: A Complete Guide]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/python/rust-getting-started/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/python/rust-getting-started/</guid>
      <pubDate>Sat, 14 Jun 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>tutorial</category>
      <category>intermediate</category>
      <content:encoded><![CDATA[






<section id="getting-started-with-rust-a-complete-guide" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/python/rust-getting-started/rust.png" class="img-fluid"></p>
<section id="what-is-rust" class="level2">
<h2 class="anchored" data-anchor-id="what-is-rust" id="what-is-rust">What is Rust?</h2>
<p>Rust is a systems programming language that focuses on safety, speed, and concurrency. Its type system and borrow checker rule out whole classes of bugs at compile time, such as null pointer dereferences, use-after-free, and data races. Indexing is bounds-checked, so an out-of-range access stops the program with a panic instead of overflowing a buffer. Rust delivers performance comparable to C and C++, which makes it a good fit for systems programming, web backends, command-line tools, and network services. It suits any project that needs both performance and reliability.</p>
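<p>Here is a small sketch of what these guarantees look like in practice, using a hypothetical <code>find_user</code> lookup. A value that may be absent is an <code>Option</code> rather than a null pointer, so the compiler makes you handle the missing case. Slice access through <code>get</code> returns <code>None</code> instead of reading past the end of the buffer.</p>

```rust
/// Hypothetical lookup: returns `None` instead of a null pointer when absent.
pub fn find_user(ids: &[u32], wanted: u32) -> Option<usize> {
    ids.iter().position(|&id| id == wanted)
}

fn main() {
    let ids: [u32; 3] = [10, 20, 30];

    // The compiler forces both cases of the Option to be handled.
    match find_user(&ids, 20) {
        Some(index) => println!("found at index {}", index),
        None => println!("not found"),
    }

    // `get` checks bounds and returns None instead of reading out of range.
    println!("{:?}", ids.get(7));
    // Plain indexing with an out-of-range runtime index is also checked:
    // it panics with a clear message rather than corrupting memory.
}
```

<p>This program prints <code>found at index 1</code> followed by <code>None</code>.</p>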
</section>
<section id="installation" class="level2">
<h2 class="anchored" data-anchor-id="installation" id="installation">Installation</h2>
<section id="installing-rust-via-rustup" class="level3">
<h3 class="anchored" data-anchor-id="installing-rust-via-rustup" id="installing-rust-via-rustup">Installing Rust via Rustup</h3>
<p>The easiest way to install Rust is through <code>rustup</code>, the official Rust installer and version manager:</p>
<p><strong>On Linux/macOS:</strong></p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="ex">curl</span> <span class="at">--proto</span> <span class="st">'=https'</span> <span class="at">--tlsv1.2</span> <span class="at">-sSf</span> https://sh.rustup.rs <span class="kw">|</span> <span class="fu">sh</span></span></code></pre></div></div>
<p><strong>On Windows:</strong> Download and run the installer from <a href="https://rustup.rs/">rustup.rs</a></p>
<p>After installation, restart your terminal and verify the installation:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb2"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="ex">rustc</span> <span class="at">--version</span></span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a><span class="ex">cargo</span> <span class="at">--version</span></span></code></pre></div></div>
</section>
<section id="what-gets-installed" class="level3">
<h3 class="anchored" data-anchor-id="what-gets-installed" id="what-gets-installed">What Gets Installed</h3>
<ul>
<li><code>rustc</code>: The Rust compiler</li>
<li><code>cargo</code>: Rust’s package manager and build tool</li>
<li><code>rustup</code>: Tool for managing Rust versions</li>
<li>Standard library documentation</li>
</ul>
</section>
</section>
<section id="your-first-rust-program" class="level2">
<h2 class="anchored" data-anchor-id="your-first-rust-program" id="your-first-rust-program">Your First Rust Program</h2>
<section id="hello-world" class="level3">
<h3 class="anchored" data-anchor-id="hello-world" id="hello-world">Hello World</h3>
<p>Create a new file called <code>main.rs</code>:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb3"><pre class="sourceCode rust code-with-copy"><code class="sourceCode rust"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="kw">fn</span> main() <span class="op">{</span></span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a>    <span class="pp">println!</span>(<span class="st">"Hello, world!"</span>)<span class="op">;</span></span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span></code></pre></div></div>
<p>Compile and run:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb4"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="ex">rustc</span> main.rs</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a><span class="ex">./main</span>  <span class="co"># On Windows: main.exe</span></span></code></pre></div></div>
</section>
<section id="using-cargo-recommended" class="level3">
<h3 class="anchored" data-anchor-id="using-cargo-recommended" id="using-cargo-recommended">Using Cargo (Recommended)</h3>
<p>Cargo is Rust’s build system and package manager. Create a new project:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb5"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="ex">cargo</span> new hello_rust</span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a><span class="bu">cd</span> hello_rust</span></code></pre></div></div>
<p>This creates a project structure:</p>
<pre><code>hello_rust/
├── Cargo.toml
└── src/
    └── main.rs</code></pre>
<p>Run your project:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb7"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="ex">cargo</span> run</span></code></pre></div></div>
<p>Build without running:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb8"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="ex">cargo</span> build</span></code></pre></div></div>
</section>
</section>
<section id="core-concepts" class="level2">
<h2 class="anchored" data-anchor-id="core-concepts" id="core-concepts">Core Concepts</h2>
<section id="variables-and-mutability" class="level3">
<h3 class="anchored" data-anchor-id="variables-and-mutability" id="variables-and-mutability">Variables and Mutability</h3>
<p>Variables are immutable by default in Rust:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb9"><pre class="sourceCode rust code-with-copy"><code class="sourceCode rust"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="kw">fn</span> main() <span class="op">{</span></span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">let</span> x <span class="op">=</span> <span class="dv">5</span><span class="op">;</span></span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a>    <span class="co">// x = 6; // This would cause a compile error</span></span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a>    <span class="kw">let</span> <span class="kw">mut</span> y <span class="op">=</span> <span class="dv">5</span><span class="op">;</span></span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a>    y <span class="op">=</span> <span class="dv">6</span><span class="op">;</span> <span class="co">// This is fine because y is mutable</span></span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a>    <span class="pp">println!</span>(<span class="st">"x = {}, y = {}"</span><span class="op">,</span> x<span class="op">,</span> y)<span class="op">;</span></span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span></code></pre></div></div>
</section>
<section id="data-types" class="level3">
<h3 class="anchored" data-anchor-id="data-types" id="data-types">Data Types</h3>
<p>Rust has several built-in data types:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb10"><pre class="sourceCode rust code-with-copy"><code class="sourceCode rust"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="kw">fn</span> main() <span class="op">{</span></span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">// Integers</span></span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a>    <span class="kw">let</span> integer<span class="op">:</span> <span class="dt">i32</span> <span class="op">=</span> <span class="dv">42</span><span class="op">;</span></span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">let</span> unsigned<span class="op">:</span> <span class="dt">u32</span> <span class="op">=</span> <span class="dv">42</span><span class="op">;</span></span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-6"><a href="#cb10-6" aria-hidden="true" tabindex="-1"></a>    <span class="co">// Floating point</span></span>
<span id="cb10-7"><a href="#cb10-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">let</span> float<span class="op">:</span> <span class="dt">f64</span> <span class="op">=</span> <span class="dv">3.14</span><span class="op">;</span></span>
<span id="cb10-8"><a href="#cb10-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-9"><a href="#cb10-9" aria-hidden="true" tabindex="-1"></a>    <span class="co">// Boolean</span></span>
<span id="cb10-10"><a href="#cb10-10" aria-hidden="true" tabindex="-1"></a>    <span class="kw">let</span> is_rust_fun<span class="op">:</span> <span class="dt">bool</span> <span class="op">=</span> <span class="cn">true</span><span class="op">;</span></span>
<span id="cb10-11"><a href="#cb10-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-12"><a href="#cb10-12" aria-hidden="true" tabindex="-1"></a>    <span class="co">// Character</span></span>
<span id="cb10-13"><a href="#cb10-13" aria-hidden="true" tabindex="-1"></a>    <span class="kw">let</span> letter<span class="op">:</span> <span class="dt">char</span> <span class="op">=</span> <span class="ch">'R'</span><span class="op">;</span></span>
<span id="cb10-14"><a href="#cb10-14" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-15"><a href="#cb10-15" aria-hidden="true" tabindex="-1"></a>    <span class="co">// String</span></span>
<span id="cb10-16"><a href="#cb10-16" aria-hidden="true" tabindex="-1"></a>    <span class="kw">let</span> greeting<span class="op">:</span> <span class="dt">String</span> <span class="op">=</span> <span class="dt">String</span><span class="pp">::</span>from(<span class="st">"Hello"</span>)<span class="op">;</span></span>
<span id="cb10-17"><a href="#cb10-17" aria-hidden="true" tabindex="-1"></a>    <span class="kw">let</span> string_slice<span class="op">:</span> <span class="op">&amp;</span><span class="dt">str</span> <span class="op">=</span> <span class="st">"World"</span><span class="op">;</span></span>
<span id="cb10-18"><a href="#cb10-18" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-19"><a href="#cb10-19" aria-hidden="true" tabindex="-1"></a>    <span class="pp">println!</span>(<span class="st">"{} {} from Rust!"</span><span class="op">,</span> greeting<span class="op">,</span> string_slice)<span class="op">;</span></span>
<span id="cb10-20"><a href="#cb10-20" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span></code></pre></div></div>
</section>
<section id="functions" class="level3">
<h3 class="anchored" data-anchor-id="functions" id="functions">Functions</h3>
<p>Functions are declared with the <code>fn</code> keyword:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb11"><pre class="sourceCode rust code-with-copy"><code class="sourceCode rust"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a><span class="kw">fn</span> main() <span class="op">{</span></span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">let</span> result <span class="op">=</span> add_numbers(<span class="dv">5</span><span class="op">,</span> <span class="dv">3</span>)<span class="op">;</span></span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a>    <span class="pp">println!</span>(<span class="st">"5 + 3 = {}"</span><span class="op">,</span> result)<span class="op">;</span></span>
<span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span>
<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a><span class="kw">fn</span> add_numbers(a<span class="op">:</span> <span class="dt">i32</span><span class="op">,</span> b<span class="op">:</span> <span class="dt">i32</span>) <span class="op">-&gt;</span> <span class="dt">i32</span> <span class="op">{</span></span>
<span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a>    a <span class="op">+</span> b <span class="co">// No semicolon means this is the return value</span></span>
<span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span></code></pre></div></div>
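<p>This sketch makes the semicolon rule concrete with a hypothetical <code>classify</code> function. <code>return</code> exits early. The final <code>if</code>/<code>else</code> has no trailing semicolon, so its value becomes the function’s return value. A semicolon after that last expression would turn it into a statement evaluating to <code>()</code>. The function would then fail to compile, because its signature promises a <code>&amp;'static str</code>.</p>

```rust
/// Hypothetical example combining an early `return` with an implicit return.
fn classify(n: i32) -> &'static str {
    if n == 0 {
        return "zero"; // early exit with an explicit `return`
    }
    // The whole if/else is an expression; with no trailing semicolon,
    // its value is returned from the function.
    if n > 0 { "positive" } else { "negative" }
}

fn main() {
    for n in [-2, 0, 7] {
        println!("{} is {}", n, classify(n));
    }
}
```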
</section>
<section id="control-flow" class="level3">
<h3 class="anchored" data-anchor-id="control-flow" id="control-flow">Control Flow</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb12"><pre class="sourceCode rust code-with-copy"><code class="sourceCode rust"><span id="cb12-1"><a href="#cb12-1" aria-hidden="true" tabindex="-1"></a><span class="kw">fn</span> main() <span class="op">{</span></span>
<span id="cb12-2"><a href="#cb12-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">let</span> number <span class="op">=</span> <span class="dv">6</span><span class="op">;</span></span>
<span id="cb12-3"><a href="#cb12-3" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-4"><a href="#cb12-4" aria-hidden="true" tabindex="-1"></a>    <span class="co">// If expressions</span></span>
<span id="cb12-5"><a href="#cb12-5" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> number <span class="op">%</span> <span class="dv">2</span> <span class="op">==</span> <span class="dv">0</span> <span class="op">{</span></span>
<span id="cb12-6"><a href="#cb12-6" aria-hidden="true" tabindex="-1"></a>        <span class="pp">println!</span>(<span class="st">"{} is even"</span><span class="op">,</span> number)<span class="op">;</span></span>
<span id="cb12-7"><a href="#cb12-7" aria-hidden="true" tabindex="-1"></a>    <span class="op">}</span> <span class="cf">else</span> <span class="op">{</span></span>
<span id="cb12-8"><a href="#cb12-8" aria-hidden="true" tabindex="-1"></a>        <span class="pp">println!</span>(<span class="st">"{} is odd"</span><span class="op">,</span> number)<span class="op">;</span></span>
<span id="cb12-9"><a href="#cb12-9" aria-hidden="true" tabindex="-1"></a>    <span class="op">}</span></span>
<span id="cb12-10"><a href="#cb12-10" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-11"><a href="#cb12-11" aria-hidden="true" tabindex="-1"></a>    <span class="co">// Loops</span></span>
<span id="cb12-12"><a href="#cb12-12" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> i <span class="kw">in</span> <span class="dv">1</span><span class="op">..=</span><span class="dv">5</span> <span class="op">{</span></span>
<span id="cb12-13"><a href="#cb12-13" aria-hidden="true" tabindex="-1"></a>        <span class="pp">println!</span>(<span class="st">"Count: {}"</span><span class="op">,</span> i)<span class="op">;</span></span>
<span id="cb12-14"><a href="#cb12-14" aria-hidden="true" tabindex="-1"></a>    <span class="op">}</span></span>
<span id="cb12-15"><a href="#cb12-15" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-16"><a href="#cb12-16" aria-hidden="true" tabindex="-1"></a>    <span class="kw">let</span> <span class="kw">mut</span> counter <span class="op">=</span> <span class="dv">0</span><span class="op">;</span></span>
<span id="cb12-17"><a href="#cb12-17" aria-hidden="true" tabindex="-1"></a>    <span class="cf">while</span> counter <span class="op">&lt;</span> <span class="dv">3</span> <span class="op">{</span></span>
<span id="cb12-18"><a href="#cb12-18" aria-hidden="true" tabindex="-1"></a>        <span class="pp">println!</span>(<span class="st">"Counter: {}"</span><span class="op">,</span> counter)<span class="op">;</span></span>
<span id="cb12-19"><a href="#cb12-19" aria-hidden="true" tabindex="-1"></a>        counter <span class="op">+=</span> <span class="dv">1</span><span class="op">;</span></span>
<span id="cb12-20"><a href="#cb12-20" aria-hidden="true" tabindex="-1"></a>    <span class="op">}</span></span>
<span id="cb12-21"><a href="#cb12-21" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-22"><a href="#cb12-22" aria-hidden="true" tabindex="-1"></a>    <span class="co">// Infinite loop with break</span></span>
<span id="cb12-23"><a href="#cb12-23" aria-hidden="true" tabindex="-1"></a>    <span class="cf">loop</span> <span class="op">{</span></span>
<span id="cb12-24"><a href="#cb12-24" aria-hidden="true" tabindex="-1"></a>        <span class="pp">println!</span>(<span class="st">"This runs once"</span>)<span class="op">;</span></span>
<span id="cb12-25"><a href="#cb12-25" aria-hidden="true" tabindex="-1"></a>        <span class="cf">break</span><span class="op">;</span></span>
<span id="cb12-26"><a href="#cb12-26" aria-hidden="true" tabindex="-1"></a>    <span class="op">}</span></span>
<span id="cb12-27"><a href="#cb12-27" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span></code></pre></div></div>
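<p>In Rust, <code>if</code> and <code>loop</code> are expressions, so they can also produce values. An <code>if</code>/<code>else</code> can initialize a variable directly, as long as both branches have the same type. <code>break</code> with a value makes a <code>loop</code> evaluate to that value. The sketch below uses a hypothetical <code>first_power_of_two_at_least</code> helper.</p>

```rust
/// Hypothetical helper: smallest power of two that is >= n.
/// Valid for n up to 2^31; larger inputs would overflow `candidate`.
fn first_power_of_two_at_least(n: u32) -> u32 {
    let mut candidate: u32 = 1;
    // `loop` is an expression: `break candidate` makes the loop evaluate
    // to `candidate`, and as the tail expression it is also the return value.
    loop {
        if candidate >= n {
            break candidate;
        }
        candidate *= 2;
    }
}

fn main() {
    let number = 6;
    // `if` used as an expression: both branches must have the same type.
    let parity = if number % 2 == 0 { "even" } else { "odd" };
    println!("{} is {}", number, parity);
    println!("{}", first_power_of_two_at_least(number));
}
```

<p>The standard library already provides <code>u32::next_power_of_two</code>, so this loop is for illustration only.</p>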
</section>
</section>
<section id="ownership-system" class="level2">
<h2 class="anchored" data-anchor-id="ownership-system" id="ownership-system">Ownership System</h2>
<p>Rust’s ownership system is what makes it memory-safe without a garbage collector:</p>
<section id="basic-ownership-rules" class="level3">
<h3 class="anchored" data-anchor-id="basic-ownership-rules" id="basic-ownership-rules">Basic Ownership Rules</h3>
<ol type="1">
<li>Each value has an owner</li>
<li>There can only be one owner at a time</li>
<li>When the owner goes out of scope, the value is dropped</li>
</ol>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb13"><pre class="sourceCode rust code-with-copy"><code class="sourceCode rust"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a><span class="kw">fn</span> main() <span class="op">{</span></span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">let</span> s1 <span class="op">=</span> <span class="dt">String</span><span class="pp">::</span>from(<span class="st">"hello"</span>)<span class="op">;</span></span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a>    <span class="kw">let</span> s2 <span class="op">=</span> s1<span class="op">;</span> <span class="co">// s1 is moved to s2, s1 is no longer valid</span></span>
<span id="cb13-4"><a href="#cb13-4" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-5"><a href="#cb13-5" aria-hidden="true" tabindex="-1"></a>    <span class="co">// println!("{}", s1); // This would cause a compile error</span></span>
<span id="cb13-6"><a href="#cb13-6" aria-hidden="true" tabindex="-1"></a>    <span class="pp">println!</span>(<span class="st">"{}"</span><span class="op">,</span> s2)<span class="op">;</span> <span class="co">// This works</span></span>
<span id="cb13-7"><a href="#cb13-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-8"><a href="#cb13-8" aria-hidden="true" tabindex="-1"></a>    <span class="kw">let</span> s3 <span class="op">=</span> s2<span class="op">.</span>clone()<span class="op">;</span> <span class="co">// Explicitly clone the data</span></span>
<span id="cb13-9"><a href="#cb13-9" aria-hidden="true" tabindex="-1"></a>    <span class="pp">println!</span>(<span class="st">"{} and {}"</span><span class="op">,</span> s2<span class="op">,</span> s3)<span class="op">;</span> <span class="co">// Both work now</span></span>
<span id="cb13-10"><a href="#cb13-10" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span></code></pre></div></div>
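<p>You can observe the third rule by implementing the standard <code>Drop</code> trait, whose <code>drop</code> method runs automatically when a value's owner goes out of scope. The sketch below is illustrative only. <code>Tracked</code> and <code>ownership_events</code> are hypothetical names, and the shared event log uses <code>Rc&lt;RefCell&lt;Vec&lt;String&gt;&gt;&gt;</code> only so the example can record the order in which things happen.</p>

```rust
use std::cell::RefCell;
use std::rc::Rc;

// Hypothetical type used only to make drops visible.
struct Tracked {
    name: String,
    log: Rc<RefCell<Vec<String>>>,
}

impl Drop for Tracked {
    // Runs automatically when the owner of a Tracked value goes out of scope.
    fn drop(&mut self) {
        self.log.borrow_mut().push(format!("dropped {}", self.name));
    }
}

// Returns the events recorded while values move in and out of scope.
fn ownership_events() -> Vec<String> {
    let log = Rc::new(RefCell::new(Vec::new()));
    {
        let a = Tracked { name: String::from("a"), log: Rc::clone(&log) };
        let b = a; // ownership moves from `a` to `b`; nothing is dropped yet
        log.borrow_mut().push(format!("moved {}", b.name));
        {
            let _inner = Tracked { name: String::from("inner"), log: Rc::clone(&log) };
        } // `_inner` goes out of scope here and is dropped
    } // `b` goes out of scope here, so the value created as `a` is dropped exactly once
    // Bind to a local so the borrow guard is released before `log` itself is dropped.
    let events = log.borrow().clone();
    events
}
```

<p>Calling <code>ownership_events()</code> yields <code>["moved a", "dropped inner", "dropped a"]</code>: the move from <code>a</code> to <code>b</code> does not trigger a drop, and the value is dropped only once, when its final owner leaves scope.</p>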
</section>
<section id="borrowing-and-references" class="level3">
<h3 class="anchored" data-anchor-id="borrowing-and-references" id="borrowing-and-references">Borrowing and References</h3>
<p>Instead of moving ownership, you can <em>borrow</em> a value by passing a reference to it (<code>&amp;s</code>):</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb14"><pre class="sourceCode rust code-with-copy"><code class="sourceCode rust"><span id="cb14-1"><a href="#cb14-1" aria-hidden="true" tabindex="-1"></a><span class="kw">fn</span> main() <span class="op">{</span></span>
<span id="cb14-2"><a href="#cb14-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">let</span> s <span class="op">=</span> <span class="dt">String</span><span class="pp">::</span>from(<span class="st">"hello"</span>)<span class="op">;</span></span>
<span id="cb14-3"><a href="#cb14-3" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb14-4"><a href="#cb14-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">let</span> len <span class="op">=</span> calculate_length(<span class="op">&amp;</span>s)<span class="op">;</span> <span class="co">// Borrow s</span></span>
<span id="cb14-5"><a href="#cb14-5" aria-hidden="true" tabindex="-1"></a>    <span class="pp">println!</span>(<span class="st">"Length of '{}' is {}"</span><span class="op">,</span> s<span class="op">,</span> len)<span class="op">;</span> <span class="co">// s is still valid</span></span>
<span id="cb14-6"><a href="#cb14-6" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span>
<span id="cb14-7"><a href="#cb14-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-8"><a href="#cb14-8" aria-hidden="true" tabindex="-1"></a><span class="kw">fn</span> calculate_length(s<span class="op">:</span> <span class="op">&amp;</span><span class="dt">String</span>) <span class="op">-&gt;</span> <span class="dt">usize</span> <span class="op">{</span></span>
<span id="cb14-9"><a href="#cb14-9" aria-hidden="true" tabindex="-1"></a>    s<span class="op">.</span>len()</span>
<span id="cb14-10"><a href="#cb14-10" aria-hidden="true" tabindex="-1"></a><span class="op">}</span> <span class="co">// s goes out of scope but doesn't drop the data (it doesn't own it)</span></span></code></pre></div></div>
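<p>References are immutable by default. To let a function modify a borrowed value, pass a mutable reference with <code>&amp;mut</code>; the borrow checker then allows only one mutable reference, and no shared ones, to that value at a time. A minimal sketch, where <code>append_world</code>, <code>first_word</code>, and <code>borrow_demo</code> are hypothetical helper names:</p>

```rust
// Takes a mutable borrow: the caller keeps ownership, but this function may modify the String.
fn append_world(s: &mut String) {
    s.push_str(" world");
}

// Takes a shared borrow and returns a slice that points into the caller's data.
fn first_word(s: &str) -> &str {
    // Everything up to the first space, or the whole string if there is no space.
    s.split(' ').next().unwrap_or("")
}

fn borrow_demo() -> (String, usize) {
    let mut s = String::from("hello");
    append_world(&mut s); // the exclusive mutable borrow ends when the call returns
    let word_len = first_word(&s).len(); // shared borrow; &String coerces to &str
    (s, word_len) // `s` was never moved, so it can still be returned here
}
```

<p>Here <code>borrow_demo()</code> returns <code>("hello world", 5)</code>. Taking <code>&amp;str</code> instead of <code>&amp;String</code> in <code>first_word</code> is a common design choice: it accepts both <code>String</code> references and string literals.</p>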
</section>
</section>
<section id="error-handling" class="level2">
<h2 class="anchored" data-anchor-id="error-handling" id="error-handling">Error Handling</h2>
<p>Instead of exceptions, Rust uses two enums: <code>Option&lt;T&gt;</code> for a value that may be absent, and <code>Result&lt;T, E&gt;</code> for an operation that may fail with an error:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb15"><pre class="sourceCode rust code-with-copy"><code class="sourceCode rust"><span id="cb15-1"><a href="#cb15-1" aria-hidden="true" tabindex="-1"></a><span class="kw">use</span> <span class="pp">std::fs::</span>File<span class="op">;</span></span>
<span id="cb15-2"><a href="#cb15-2" aria-hidden="true" tabindex="-1"></a><span class="kw">use</span> <span class="pp">std::io::</span>ErrorKind<span class="op">;</span></span>
<span id="cb15-3"><a href="#cb15-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-4"><a href="#cb15-4" aria-hidden="true" tabindex="-1"></a><span class="kw">fn</span> main() <span class="op">{</span></span>
<span id="cb15-5"><a href="#cb15-5" aria-hidden="true" tabindex="-1"></a>    <span class="co">// Option example</span></span>
<span id="cb15-6"><a href="#cb15-6" aria-hidden="true" tabindex="-1"></a>    <span class="kw">let</span> numbers <span class="op">=</span> <span class="pp">vec!</span>[<span class="dv">1</span><span class="op">,</span> <span class="dv">2</span><span class="op">,</span> <span class="dv">3</span><span class="op">,</span> <span class="dv">4</span><span class="op">,</span> <span class="dv">5</span>]<span class="op">;</span></span>
<span id="cb15-7"><a href="#cb15-7" aria-hidden="true" tabindex="-1"></a>    <span class="cf">match</span> numbers<span class="op">.</span>get(<span class="dv">10</span>) <span class="op">{</span></span>
<span id="cb15-8"><a href="#cb15-8" aria-hidden="true" tabindex="-1"></a>        <span class="cn">Some</span>(value) <span class="op">=&gt;</span> <span class="pp">println!</span>(<span class="st">"Found: {}"</span><span class="op">,</span> value)<span class="op">,</span></span>
<span id="cb15-9"><a href="#cb15-9" aria-hidden="true" tabindex="-1"></a>        <span class="cn">None</span> <span class="op">=&gt;</span> <span class="pp">println!</span>(<span class="st">"No value at index 10"</span>)<span class="op">,</span></span>
<span id="cb15-10"><a href="#cb15-10" aria-hidden="true" tabindex="-1"></a>    <span class="op">}</span></span>
<span id="cb15-11"><a href="#cb15-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-12"><a href="#cb15-12" aria-hidden="true" tabindex="-1"></a>    <span class="co">// Result example</span></span>
<span id="cb15-13"><a href="#cb15-13" aria-hidden="true" tabindex="-1"></a>    <span class="kw">let</span> file_result <span class="op">=</span> <span class="pp">File::</span>open(<span class="st">"hello.txt"</span>)<span class="op">;</span></span>
<span id="cb15-14"><a href="#cb15-14" aria-hidden="true" tabindex="-1"></a>    <span class="cf">match</span> file_result <span class="op">{</span></span>
<span id="cb15-15"><a href="#cb15-15" aria-hidden="true" tabindex="-1"></a>        <span class="cn">Ok</span>(_file) <span class="op">=&gt;</span> <span class="pp">println!</span>(<span class="st">"File opened successfully"</span>)<span class="op">,</span> <span class="co">// `_file`: the handle is unused here</span></span>
<span id="cb15-16"><a href="#cb15-16" aria-hidden="true" tabindex="-1"></a>        <span class="cn">Err</span>(error) <span class="op">=&gt;</span> <span class="cf">match</span> error<span class="op">.</span>kind() <span class="op">{</span></span>
<span id="cb15-17"><a href="#cb15-17" aria-hidden="true" tabindex="-1"></a>            <span class="pp">ErrorKind::</span>NotFound <span class="op">=&gt;</span> <span class="pp">println!</span>(<span class="st">"File not found"</span>)<span class="op">,</span></span>
<span id="cb15-18"><a href="#cb15-18" aria-hidden="true" tabindex="-1"></a>            _ <span class="op">=&gt;</span> <span class="pp">println!</span>(<span class="st">"Error opening file: {:?}"</span><span class="op">,</span> error)<span class="op">,</span></span>
<span id="cb15-19"><a href="#cb15-19" aria-hidden="true" tabindex="-1"></a>        <span class="op">},</span></span>
<span id="cb15-20"><a href="#cb15-20" aria-hidden="true" tabindex="-1"></a>    <span class="op">}</span></span>
<span id="cb15-21"><a href="#cb15-21" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span></code></pre></div></div>
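<p>Matching on every <code>Result</code> quickly becomes verbose. Inside a function that itself returns <code>Result</code> (or <code>Option</code>), the <code>?</code> operator returns early with the error (or <code>None</code>) and otherwise unwraps the success value. A minimal sketch under an assumed input format: <code>sum_csv_numbers</code> and <code>first_even_doubled</code> are hypothetical helpers.</p>

```rust
use std::num::ParseIntError;

// Parses a comma-separated list such as "1, 2, 3" and returns the sum.
// Any token that is not a valid i32 makes the whole function return Err.
fn sum_csv_numbers(input: &str) -> Result<i32, ParseIntError> {
    let mut total = 0;
    for token in input.split(',') {
        let n: i32 = token.trim().parse()?; // `?` returns the ParseIntError early
        total += n;
    }
    Ok(total)
}

// `?` works the same way on Option: it returns None early.
fn first_even_doubled(numbers: &[i32]) -> Option<i32> {
    let first_even = numbers.iter().find(|&&n| n % 2 == 0)?;
    Some(first_even * 2)
}
```

<p>For example, <code>sum_csv_numbers("1, 2, 3")</code> returns <code>Ok(6)</code>, while <code>sum_csv_numbers("1, x")</code> returns an <code>Err</code> without the caller having to write a <code>match</code>.</p>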
</section>
<section id="working-with-collections" class="level2">
<h2 class="anchored" data-anchor-id="working-with-collections" id="working-with-collections">Working with Collections</h2>
<section id="vectors" class="level3">
<h3 class="anchored" data-anchor-id="vectors" id="vectors">Vectors</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb16"><pre class="sourceCode rust code-with-copy"><code class="sourceCode rust"><span id="cb16-1"><a href="#cb16-1" aria-hidden="true" tabindex="-1"></a><span class="kw">fn</span> main() <span class="op">{</span></span>
<span id="cb16-2"><a href="#cb16-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">let</span> <span class="kw">mut</span> numbers <span class="op">=</span> <span class="pp">vec!</span>[<span class="dv">1</span><span class="op">,</span> <span class="dv">2</span><span class="op">,</span> <span class="dv">3</span>]<span class="op">;</span></span>
<span id="cb16-3"><a href="#cb16-3" aria-hidden="true" tabindex="-1"></a>    numbers<span class="op">.</span>push(<span class="dv">4</span>)<span class="op">;</span></span>
<span id="cb16-4"><a href="#cb16-4" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-5"><a href="#cb16-5" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> number <span class="kw">in</span> <span class="op">&amp;</span>numbers <span class="op">{</span></span>
<span id="cb16-6"><a href="#cb16-6" aria-hidden="true" tabindex="-1"></a>        <span class="pp">println!</span>(<span class="st">"{}"</span><span class="op">,</span> number)<span class="op">;</span></span>
<span id="cb16-7"><a href="#cb16-7" aria-hidden="true" tabindex="-1"></a>    <span class="op">}</span></span>
<span id="cb16-8"><a href="#cb16-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-9"><a href="#cb16-9" aria-hidden="true" tabindex="-1"></a>    <span class="pp">println!</span>(<span class="st">"Third element: {}"</span><span class="op">,</span> numbers[<span class="dv">2</span>])<span class="op">;</span></span>
<span id="cb16-10"><a href="#cb16-10" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span></code></pre></div></div>
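<p>Indexing with <code>numbers[i]</code> panics if <code>i</code> is out of bounds, whereas <code>get</code> returns an <code>Option</code> you can handle. Vectors also work with iterator adapters such as <code>filter</code>, <code>map</code>, and <code>sum</code>. A small sketch; the function names are illustrative, not part of any library:</p>

```rust
// Returns the element at `index`, or `default` instead of panicking like `v[index]` would.
fn get_or_default(v: &[i32], index: usize, default: i32) -> i32 {
    v.get(index).copied().unwrap_or(default)
}

// Squares every even element and collects the results into a new Vec.
fn squares_of_evens(v: &[i32]) -> Vec<i32> {
    v.iter().filter(|&&n| n % 2 == 0).map(|&n| n * n).collect()
}

// Builds a vector, then sums it with an iterator instead of a manual loop.
fn vector_demo() -> (Vec<i32>, i32) {
    let mut numbers = vec![1, 2, 3];
    numbers.push(4);
    let total: i32 = numbers.iter().sum();
    (numbers, total)
}
```

<p>Taking <code>&amp;[i32]</code> rather than <code>&amp;Vec&lt;i32&gt;</code> lets these helpers accept vectors, arrays, and slices alike.</p>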
</section>
<section id="hashmaps" class="level3">
<h3 class="anchored" data-anchor-id="hashmaps" id="hashmaps">HashMaps</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb17"><pre class="sourceCode rust code-with-copy"><code class="sourceCode rust"><span id="cb17-1"><a href="#cb17-1" aria-hidden="true" tabindex="-1"></a><span class="kw">use</span> <span class="pp">std::collections::</span>HashMap<span class="op">;</span></span>
<span id="cb17-2"><a href="#cb17-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-3"><a href="#cb17-3" aria-hidden="true" tabindex="-1"></a><span class="kw">fn</span> main() <span class="op">{</span></span>
<span id="cb17-4"><a href="#cb17-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">let</span> <span class="kw">mut</span> scores <span class="op">=</span> <span class="pp">HashMap::</span>new()<span class="op">;</span></span>
<span id="cb17-5"><a href="#cb17-5" aria-hidden="true" tabindex="-1"></a>    scores<span class="op">.</span>insert(<span class="st">"Blue"</span><span class="op">,</span> <span class="dv">10</span>)<span class="op">;</span></span>
<span id="cb17-6"><a href="#cb17-6" aria-hidden="true" tabindex="-1"></a>    scores<span class="op">.</span>insert(<span class="st">"Red"</span><span class="op">,</span> <span class="dv">50</span>)<span class="op">;</span></span>
<span id="cb17-7"><a href="#cb17-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb17-8"><a href="#cb17-8" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> (team<span class="op">,</span> score) <span class="kw">in</span> <span class="op">&amp;</span>scores <span class="op">{</span></span>
<span id="cb17-9"><a href="#cb17-9" aria-hidden="true" tabindex="-1"></a>        <span class="pp">println!</span>(<span class="st">"{}: {}"</span><span class="op">,</span> team<span class="op">,</span> score)<span class="op">;</span></span>
<span id="cb17-10"><a href="#cb17-10" aria-hidden="true" tabindex="-1"></a>    <span class="op">}</span></span>
<span id="cb17-11"><a href="#cb17-11" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span></code></pre></div></div>
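<p>Note that <code>HashMap</code> iteration order is unspecified, so the loop above may print the teams in either order. A common pattern is updating a value in place through the <code>entry</code> API. The word-count sketch below is illustrative; <code>word_counts</code> and <code>sorted_counts</code> are hypothetical names.</p>

```rust
use std::collections::HashMap;

// Counts how often each whitespace-separated word occurs.
fn word_counts(text: &str) -> HashMap<&str, usize> {
    let mut counts = HashMap::new();
    for word in text.split_whitespace() {
        // entry() returns the slot for `word`; or_insert(0) creates it if it is missing.
        *counts.entry(word).or_insert(0) += 1;
    }
    counts
}

// Iteration order is unspecified, so sort the pairs when a stable order is needed.
fn sorted_counts(text: &str) -> Vec<(&str, usize)> {
    let mut pairs: Vec<(&str, usize)> = word_counts(text).into_iter().collect();
    pairs.sort();
    pairs
}
```

<p>For example, <code>sorted_counts("b a b")</code> returns <code>[("a", 1), ("b", 2)]</code> regardless of the map's internal order.</p>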
</section>
</section>
<section id="structs-and-enums" class="level2">
<h2 class="anchored" data-anchor-id="structs-and-enums" id="structs-and-enums">Structs and Enums</h2>
<section id="structs" class="level3">
<h3 class="anchored" data-anchor-id="structs" id="structs">Structs</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb18"><pre class="sourceCode rust code-with-copy"><code class="sourceCode rust"><span id="cb18-1"><a href="#cb18-1" aria-hidden="true" tabindex="-1"></a><span class="kw">struct</span> Person <span class="op">{</span></span>
<span id="cb18-2"><a href="#cb18-2" aria-hidden="true" tabindex="-1"></a>    name<span class="op">:</span> <span class="dt">String</span><span class="op">,</span></span>
<span id="cb18-3"><a href="#cb18-3" aria-hidden="true" tabindex="-1"></a>    age<span class="op">:</span> <span class="dt">u32</span><span class="op">,</span></span>
<span id="cb18-4"><a href="#cb18-4" aria-hidden="true" tabindex="-1"></a>    email<span class="op">:</span> <span class="dt">String</span><span class="op">,</span></span>
<span id="cb18-5"><a href="#cb18-5" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span>
<span id="cb18-6"><a href="#cb18-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-7"><a href="#cb18-7" aria-hidden="true" tabindex="-1"></a><span class="kw">impl</span> Person <span class="op">{</span></span>
<span id="cb18-8"><a href="#cb18-8" aria-hidden="true" tabindex="-1"></a>    <span class="kw">fn</span> new(name<span class="op">:</span> <span class="dt">String</span><span class="op">,</span> age<span class="op">:</span> <span class="dt">u32</span><span class="op">,</span> email<span class="op">:</span> <span class="dt">String</span>) <span class="op">-&gt;</span> Person <span class="op">{</span></span>
<span id="cb18-9"><a href="#cb18-9" aria-hidden="true" tabindex="-1"></a>        Person <span class="op">{</span> name<span class="op">,</span> age<span class="op">,</span> email <span class="op">}</span></span>
<span id="cb18-10"><a href="#cb18-10" aria-hidden="true" tabindex="-1"></a>    <span class="op">}</span></span>
<span id="cb18-11"><a href="#cb18-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-12"><a href="#cb18-12" aria-hidden="true" tabindex="-1"></a>    <span class="kw">fn</span> greet(<span class="op">&amp;</span><span class="kw">self</span>) <span class="op">{</span></span>
<span id="cb18-13"><a href="#cb18-13" aria-hidden="true" tabindex="-1"></a>        <span class="pp">println!</span>(<span class="st">"Hello, my name is {}"</span><span class="op">,</span> <span class="kw">self</span><span class="op">.</span>name)<span class="op">;</span></span>
<span id="cb18-14"><a href="#cb18-14" aria-hidden="true" tabindex="-1"></a>    <span class="op">}</span></span>
<span id="cb18-15"><a href="#cb18-15" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span>
<span id="cb18-16"><a href="#cb18-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-17"><a href="#cb18-17" aria-hidden="true" tabindex="-1"></a><span class="kw">fn</span> main() <span class="op">{</span></span>
<span id="cb18-18"><a href="#cb18-18" aria-hidden="true" tabindex="-1"></a>    <span class="kw">let</span> person <span class="op">=</span> <span class="pp">Person::</span>new(</span>
<span id="cb18-19"><a href="#cb18-19" aria-hidden="true" tabindex="-1"></a>        <span class="dt">String</span><span class="pp">::</span>from(<span class="st">"Alice"</span>)<span class="op">,</span></span>
<span id="cb18-20"><a href="#cb18-20" aria-hidden="true" tabindex="-1"></a>        <span class="dv">30</span><span class="op">,</span></span>
<span id="cb18-21"><a href="#cb18-21" aria-hidden="true" tabindex="-1"></a>        <span class="dt">String</span><span class="pp">::</span>from(<span class="st">"alice@example.com"</span>)</span>
<span id="cb18-22"><a href="#cb18-22" aria-hidden="true" tabindex="-1"></a>    )<span class="op">;</span></span>
<span id="cb18-23"><a href="#cb18-23" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-24"><a href="#cb18-24" aria-hidden="true" tabindex="-1"></a>    person<span class="op">.</span>greet()<span class="op">;</span></span>
<span id="cb18-25"><a href="#cb18-25" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span></code></pre></div></div>
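<p>Methods can also take <code>&amp;mut self</code> to modify an instance, and common traits can be derived instead of written by hand. The sketch below is a variation on the <code>Person</code> example, not a replacement for it: <code>have_birthday</code> and <code>describe</code> are hypothetical methods, the <code>email</code> field is omitted for brevity, and <code>new</code> takes <code>&amp;str</code> so callers can pass string literals directly.</p>

```rust
// #[derive] generates Debug (for {:?}), Clone, and PartialEq (for ==) implementations.
#[derive(Debug, Clone, PartialEq)]
struct Person {
    name: String,
    age: u32,
}

impl Person {
    // Associated function used as a constructor; `Self` is an alias for `Person`.
    fn new(name: &str, age: u32) -> Self {
        Person { name: name.to_string(), age }
    }

    // &mut self: requires a mutable binding and may change fields.
    fn have_birthday(&mut self) {
        self.age += 1;
    }

    // &self: read-only access; returns a newly allocated String.
    fn describe(&self) -> String {
        format!("{} is {} years old", self.name, self.age)
    }
}
```

<p>Calling <code>have_birthday</code> requires the binding to be declared with <code>let mut</code>; otherwise the compiler rejects the mutable borrow of <code>self</code>.</p>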
</section>
<section id="enums" class="level3">
<h3 class="anchored" data-anchor-id="enums" id="enums">Enums</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb19"><pre class="sourceCode rust code-with-copy"><code class="sourceCode rust"><span id="cb19-1"><a href="#cb19-1" aria-hidden="true" tabindex="-1"></a><span class="kw">enum</span> Message <span class="op">{</span></span>
<span id="cb19-2"><a href="#cb19-2" aria-hidden="true" tabindex="-1"></a>    Quit<span class="op">,</span></span>
<span id="cb19-3"><a href="#cb19-3" aria-hidden="true" tabindex="-1"></a>    Move <span class="op">{</span> x<span class="op">:</span> <span class="dt">i32</span><span class="op">,</span> y<span class="op">:</span> <span class="dt">i32</span> <span class="op">},</span></span>
<span id="cb19-4"><a href="#cb19-4" aria-hidden="true" tabindex="-1"></a>    <span class="bu">Write</span>(<span class="dt">String</span>)<span class="op">,</span></span>
<span id="cb19-5"><a href="#cb19-5" aria-hidden="true" tabindex="-1"></a>    ChangeColor(<span class="dt">i32</span><span class="op">,</span> <span class="dt">i32</span><span class="op">,</span> <span class="dt">i32</span>)<span class="op">,</span></span>
<span id="cb19-6"><a href="#cb19-6" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span>
<span id="cb19-7"><a href="#cb19-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-8"><a href="#cb19-8" aria-hidden="true" tabindex="-1"></a><span class="kw">impl</span> Message <span class="op">{</span></span>
<span id="cb19-9"><a href="#cb19-9" aria-hidden="true" tabindex="-1"></a>    <span class="kw">fn</span> call(<span class="op">&amp;</span><span class="kw">self</span>) <span class="op">{</span></span>
<span id="cb19-10"><a href="#cb19-10" aria-hidden="true" tabindex="-1"></a>        <span class="cf">match</span> <span class="kw">self</span> <span class="op">{</span></span>
<span id="cb19-11"><a href="#cb19-11" aria-hidden="true" tabindex="-1"></a>            <span class="pp">Message::</span>Quit <span class="op">=&gt;</span> <span class="pp">println!</span>(<span class="st">"Quitting"</span>)<span class="op">,</span></span>
<span id="cb19-12"><a href="#cb19-12" aria-hidden="true" tabindex="-1"></a>            <span class="pp">Message::</span>Move <span class="op">{</span> x<span class="op">,</span> y <span class="op">}</span> <span class="op">=&gt;</span> <span class="pp">println!</span>(<span class="st">"Moving to ({}, {})"</span><span class="op">,</span> x<span class="op">,</span> y)<span class="op">,</span></span>
<span id="cb19-13"><a href="#cb19-13" aria-hidden="true" tabindex="-1"></a>            <span class="pp">Message::</span><span class="bu">Write</span>(text) <span class="op">=&gt;</span> <span class="pp">println!</span>(<span class="st">"Writing: {}"</span><span class="op">,</span> text)<span class="op">,</span></span>
<span id="cb19-14"><a href="#cb19-14" aria-hidden="true" tabindex="-1"></a>            <span class="pp">Message::</span>ChangeColor(r<span class="op">,</span> g<span class="op">,</span> b) <span class="op">=&gt;</span> <span class="pp">println!</span>(<span class="st">"Changing color to ({}, {}, {})"</span><span class="op">,</span> r<span class="op">,</span> g<span class="op">,</span> b)<span class="op">,</span></span>
<span id="cb19-15"><a href="#cb19-15" aria-hidden="true" tabindex="-1"></a>        <span class="op">}</span></span>
<span id="cb19-16"><a href="#cb19-16" aria-hidden="true" tabindex="-1"></a>    <span class="op">}</span></span>
<span id="cb19-17"><a href="#cb19-17" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span>
<span id="cb19-18"><a href="#cb19-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-19"><a href="#cb19-19" aria-hidden="true" tabindex="-1"></a><span class="kw">fn</span> main() <span class="op">{</span></span>
<span id="cb19-20"><a href="#cb19-20" aria-hidden="true" tabindex="-1"></a>    <span class="kw">let</span> msg <span class="op">=</span> <span class="pp">Message::</span><span class="bu">Write</span>(<span class="dt">String</span><span class="pp">::</span>from(<span class="st">"Hello"</span>))<span class="op">;</span></span>
<span id="cb19-21"><a href="#cb19-21" aria-hidden="true" tabindex="-1"></a>    msg<span class="op">.</span>call()<span class="op">;</span></span>
<span id="cb19-22"><a href="#cb19-22" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span></code></pre></div></div>
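<p>A <code>match</code> on an enum can also produce a value, and the compiler checks that every variant is covered. When only one variant matters, <code>if let</code> is shorter. The standard library's <code>Option&lt;T&gt;</code> is itself an ordinary enum with <code>Some</code> and <code>None</code> variants. A minimal sketch; <code>Shape</code>, <code>area</code>, and <code>circle_radius</code> are hypothetical examples:</p>

```rust
// Each variant can carry different data.
enum Shape {
    Circle { radius: f64 },
    Rectangle { width: f64, height: f64 },
    Triangle(f64, f64, f64), // side lengths
}

// Every arm returns an f64; a missing variant would be a compile error.
fn area(shape: &Shape) -> f64 {
    match shape {
        Shape::Circle { radius } => std::f64::consts::PI * radius * radius,
        Shape::Rectangle { width, height } => width * height,
        Shape::Triangle(a, b, c) => {
            // Heron's formula using the semi-perimeter `s`.
            let s = (a + b + c) / 2.0;
            (s * (s - a) * (s - b) * (s - c)).sqrt()
        }
    }
}

// `if let` handles a single variant and ignores the rest.
fn circle_radius(shape: &Shape) -> Option<f64> {
    if let Shape::Circle { radius } = shape {
        Some(*radius)
    } else {
        None
    }
}
```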
</section>
</section>
<section id="package-management-with-cargo" class="level2">
<h2 class="anchored" data-anchor-id="package-management-with-cargo" id="package-management-with-cargo">Package Management with Cargo</h2>
<section id="adding-dependencies" class="level3">
<h3 class="anchored" data-anchor-id="adding-dependencies" id="adding-dependencies">Adding Dependencies</h3>
<p>Edit your <code>Cargo.toml</code> file to add dependencies:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb20"><pre class="sourceCode toml code-with-copy"><code class="sourceCode toml"><span id="cb20-1"><a href="#cb20-1" aria-hidden="true" tabindex="-1"></a><span class="kw">[dependencies]</span></span>
<span id="cb20-2"><a href="#cb20-2" aria-hidden="true" tabindex="-1"></a><span class="dt">serde</span> <span class="op">=</span> <span class="st">"1.0"</span></span>
<span id="cb20-3"><a href="#cb20-3" aria-hidden="true" tabindex="-1"></a><span class="dt">tokio</span> <span class="op">=</span> <span class="op">{ </span><span class="dt">version</span><span class="op"> =</span> <span class="st">"1.0"</span><span class="op">, </span><span class="dt">features</span><span class="op"> =</span> <span class="op">[</span><span class="st">"full"</span><span class="op">] }</span></span></code></pre></div></div>
<p>Then build the project; Cargo downloads and compiles the new dependencies automatically:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb21"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb21-1"><a href="#cb21-1" aria-hidden="true" tabindex="-1"></a><span class="ex">cargo</span> build</span></code></pre></div></div>
</section>
<section id="common-cargo-commands" class="level3">
<h3 class="anchored" data-anchor-id="common-cargo-commands" id="common-cargo-commands">Common Cargo Commands</h3>
<ul>
<li><code>cargo new project_name</code> - Create a new project</li>
<li><code>cargo build</code> - Compile the project</li>
<li><code>cargo run</code> - Compile and run the project</li>
<li><code>cargo test</code> - Run tests</li>
<li><code>cargo doc --open</code> - Generate and open documentation</li>
<li><code>cargo update</code> - Update dependencies in <code>Cargo.lock</code> to the latest versions allowed by <code>Cargo.toml</code></li>
<li><code>cargo clean</code> - Remove build artifacts</li>
</ul>
</section>
</section>
<section id="development-tools" class="level2">
<h2 class="anchored" data-anchor-id="development-tools" id="development-tools">Development Tools</h2>
<section id="formatting" class="level3">
<h3 class="anchored" data-anchor-id="formatting" id="formatting">Formatting</h3>
<p>Format your code automatically:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb22"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb22-1"><a href="#cb22-1" aria-hidden="true" tabindex="-1"></a><span class="ex">cargo</span> fmt</span></code></pre></div></div>
</section>
<section id="linting" class="level3">
<h3 class="anchored" data-anchor-id="linting" id="linting">Linting</h3>
<p>Check for common mistakes and style issues:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb23"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb23-1"><a href="#cb23-1" aria-hidden="true" tabindex="-1"></a><span class="ex">cargo</span> clippy</span></code></pre></div></div>
</section>
<section id="testing" class="level3">
<h3 class="anchored" data-anchor-id="testing" id="testing">Testing</h3>
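<p>Write unit tests in a <code>#[cfg(test)]</code> module in the same file as the code under test, or put integration tests in a separate <code>tests/</code> directory:</p>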
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb24"><pre class="sourceCode rust code-with-copy"><code class="sourceCode rust"><span id="cb24-1"><a href="#cb24-1" aria-hidden="true" tabindex="-1"></a><span class="kw">fn</span> add(a<span class="op">:</span> <span class="dt">i32</span><span class="op">,</span> b<span class="op">:</span> <span class="dt">i32</span>) <span class="op">-&gt;</span> <span class="dt">i32</span> <span class="op">{</span></span>
<span id="cb24-2"><a href="#cb24-2" aria-hidden="true" tabindex="-1"></a>    a <span class="op">+</span> b</span>
<span id="cb24-3"><a href="#cb24-3" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span>
<span id="cb24-4"><a href="#cb24-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb24-5"><a href="#cb24-5" aria-hidden="true" tabindex="-1"></a><span class="at">#[</span>cfg<span class="at">(</span>test<span class="at">)]</span></span>
<span id="cb24-6"><a href="#cb24-6" aria-hidden="true" tabindex="-1"></a><span class="kw">mod</span> tests <span class="op">{</span></span>
<span id="cb24-7"><a href="#cb24-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">use</span> <span class="kw">super</span><span class="pp">::</span><span class="op">*;</span></span>
<span id="cb24-8"><a href="#cb24-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb24-9"><a href="#cb24-9" aria-hidden="true" tabindex="-1"></a>    <span class="at">#[</span>test<span class="at">]</span></span>
<span id="cb24-10"><a href="#cb24-10" aria-hidden="true" tabindex="-1"></a>    <span class="kw">fn</span> test_add() <span class="op">{</span></span>
<span id="cb24-11"><a href="#cb24-11" aria-hidden="true" tabindex="-1"></a>        <span class="pp">assert_eq!</span>(add(<span class="dv">2</span><span class="op">,</span> <span class="dv">3</span>)<span class="op">,</span> <span class="dv">5</span>)<span class="op">;</span></span>
<span id="cb24-12"><a href="#cb24-12" aria-hidden="true" tabindex="-1"></a>    <span class="op">}</span></span>
<span id="cb24-13"><a href="#cb24-13" aria-hidden="true" tabindex="-1"></a><span class="op">}</span></span></code></pre></div></div>
<p>Run tests with:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb25"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb25-1"><a href="#cb25-1" aria-hidden="true" tabindex="-1"></a><span class="ex">cargo</span> test</span></code></pre></div></div>
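<p>Two other built-in test features are worth knowing: <code>#[should_panic]</code> for code that is expected to panic, and test functions that return <code>Result</code> so they can use <code>?</code>. The sketch below assumes two hypothetical functions under test, <code>divide</code> and <code>checked_divide</code>; <code>cargo test</code> runs both tests in the module.</p>

```rust
// Hypothetical function under test: integer division that panics on a zero divisor.
fn divide(a: i32, b: i32) -> i32 {
    if b == 0 {
        panic!("division by zero");
    }
    a / b
}

// A fallible variant that returns Result instead of panicking.
fn checked_divide(a: i32, b: i32) -> Result<i32, String> {
    if b == 0 {
        Err(String::from("division by zero"))
    } else {
        Ok(a / b)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // Passes only if the call panics with a message containing the expected text.
    #[test]
    #[should_panic(expected = "division by zero")]
    fn divide_by_zero_panics() {
        divide(1, 0);
    }

    // A test may return Result; an Err returned via `?` fails the test.
    #[test]
    fn checked_divide_works() -> Result<(), String> {
        assert_eq!(checked_divide(10, 2)?, 5);
        Ok(())
    }
}
```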
</section>
</section>
<section id="next-steps" class="level2">
<h2 class="anchored" data-anchor-id="next-steps" id="next-steps">Next Steps</h2>
<section id="learning-resources" class="level3">
<h3 class="anchored" data-anchor-id="learning-resources" id="learning-resources">Learning Resources</h3>
<ol type="1">
<li><strong>The Rust Programming Language Book</strong> - The official book, available online for free</li>
<li><strong>Rust by Example</strong> - Learn Rust through practical examples</li>
<li><strong>Rustlings</strong> - Small exercises to get you used to Rust syntax</li>
<li><strong>The Rust Reference</strong> - Detailed language reference</li>
<li><strong>Rust Standard Library Documentation</strong> - Comprehensive API documentation</li>
</ol>
</section>
<section id="practice-projects" class="level3">
<h3 class="anchored" data-anchor-id="practice-projects" id="practice-projects">Practice Projects</h3>
<ol type="1">
<li><strong>Command-line calculator</strong> - Practice basic syntax and user input</li>
<li><strong>File organizer</strong> - Learn file I/O and error handling</li>
<li><strong>Web scraper</strong> - Work with HTTP requests and HTML parsing</li>
<li><strong>Simple web server</strong> - Understand concurrency and networking</li>
<li><strong>Game of Life</strong> - Practice with 2D arrays and algorithms</li>
</ol>
</section>
<section id="join-the-community" class="level3">
<h3 class="anchored" data-anchor-id="join-the-community" id="join-the-community">Join the Community</h3>
<ul>
<li><strong>Rust Users Forum</strong> - Ask questions and share knowledge</li>
<li><strong>Reddit r/rust</strong> - Community discussions and news</li>
<li><strong>Discord and Zulip</strong> - Real-time chat with other Rust developers</li>
<li><strong>Local Rust meetups</strong> - Find Rust developers in your area</li>
</ul>
</section>
</section>
<section id="tips-for-success" class="level2">
<h2 class="anchored" data-anchor-id="tips-for-success" id="tips-for-success">Tips for Success</h2>
<ol type="1">
<li><strong>Embrace the compiler</strong> - Rust’s compiler provides excellent error messages. Read them carefully</li>
<li><strong>Start small</strong> - Begin with simple programs and gradually increase complexity</li>
<li><strong>Practice ownership</strong> - Few mainstream languages have anything like the ownership system, so it takes time to internalize</li>
<li><strong>Use the standard library</strong> - Rust has a rich standard library with excellent documentation</li>
<li><strong>Don’t fight the borrow checker</strong> - Learn to work with Rust’s safety guarantees rather than against them</li>
</ol>
<p>The Rust compiler is your friend and will help you write safe, fast code. Take time to understand the error messages, and don’t hesitate to refer to the official documentation when you’re stuck. Happy coding!</p>


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Hugging Face Accelerate vs PyTorch Lightning Fabric: A Deep Dive Comparison]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/distributed/accelerate-vs-fabric/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/distributed/accelerate-vs-fabric/</guid>
      <pubDate>Tue, 03 Jun 2025 00:00:00 GMT</pubDate>
      
      <category>research</category>
      <category>beginner</category>
      <content:encoded><![CDATA[






<section id="hugging-face-accelerate-vs-pytorch-lightning-fabric-a-deep-dive-comparison" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/distributed/accelerate-vs-fabric/accvfab.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>When you’re working with deep learning models that need to scale across multiple GPUs or even multiple machines, you’ll quickly encounter the complexity of distributed training. Two libraries have emerged as popular solutions to simplify this challenge: <strong>Hugging Face Accelerate</strong> and <strong>PyTorch Lightning Fabric</strong>. While both aim to make distributed training more accessible, they take fundamentally different approaches to solving the problem.</p>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Key Insight
</div>
</div>
<div class="callout-body-container callout-body">
<p>Think of these libraries as two different philosophies for handling the complexity of scaling machine learning workloads. Accelerate acts like a careful translator, taking your existing PyTorch code and automatically adapting it for distributed environments with minimal changes. Lightning Fabric, on the other hand, functions more like a structured framework that provides you with powerful tools and patterns, but asks you to organize your code in specific ways to unlock its full potential.</p>
</div>
</div>
</section>
<section id="understanding-the-core-philosophy" class="level2">
<h2 class="anchored" data-anchor-id="understanding-the-core-philosophy" id="understanding-the-core-philosophy">Understanding the Core Philosophy</h2>
<section id="hugging-face-accelerate-minimal-disruption" class="level3">
<h3 class="anchored" data-anchor-id="hugging-face-accelerate-minimal-disruption" id="hugging-face-accelerate-minimal-disruption">Hugging Face Accelerate: Minimal Disruption</h3>
<p>Hugging Face Accelerate was born from a simple but powerful idea: most researchers and practitioners already have working PyTorch code, and they shouldn’t need to rewrite everything just to scale it up. The library’s design philosophy centers around <strong>minimal code changes</strong>. You can take a training loop that works on a single GPU and, with just a few additional lines, make it work across multiple GPUs, TPUs, or even different machines.</p>
<p>The beauty of Accelerate lies in its transparency. When you pass your model, optimizer, and data loader to the <code>accelerator.prepare</code> method, the library handles the orchestration of distributed training behind the scenes: it places objects on the right device, wraps the model for data-parallel training when several processes are running, and shards the data loader across those processes. Your core training logic remains largely unchanged, which means you can focus on your model architecture and training strategies rather than wrestling with distributed computing concepts.</p>
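<p>One concrete job <code>prepare</code> takes on for a data loader is giving each process a different slice of the data. The plain-Python sketch below models interleaved sharding. It assumes the non-shuffled, <code>drop_last=False</code> scheme used by PyTorch's <code>DistributedSampler</code>. Accelerate's own data loader wrapper typically shards at the batch level instead, so read this as an illustration of the idea rather than Accelerate's implementation. <code>shard_indices</code> is a hypothetical helper.</p>

```python
import math


def shard_indices(num_samples: int, num_replicas: int, rank: int) -> list[int]:
    """Indices one process sees under interleaved (strided) sharding.

    Hypothetical helper for illustration; not an Accelerate API. Modeled on
    DistributedSampler without shuffling: pad by wrapping around to the start
    so every rank gets the same number of samples, then take every
    num_replicas-th index starting at this rank.
    """
    per_replica = math.ceil(num_samples / num_replicas)
    total_size = per_replica * num_replicas
    indices = list(range(num_samples))
    padding = total_size - num_samples
    # Repeat indices from the start until the total divides evenly
    indices += (indices * math.ceil(padding / num_samples))[:padding]
    return indices[rank:total_size:num_replicas]


if __name__ == "__main__":
    for r in range(4):
        print(r, shard_indices(10, 4, r))
```

<p>With 10 samples and 4 processes, each rank receives 3 indices. Ranks 2 and 3 get the wrapped-around indices 0 and 1, which keeps every process running the same number of steps.</p>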
</section>
<section id="lightning-fabric-structured-flexibility" class="level3">
<h3 class="anchored" data-anchor-id="lightning-fabric-structured-flexibility" id="lightning-fabric-structured-flexibility">Lightning Fabric: Structured Flexibility</h3>
<p>Lightning Fabric approaches the problem from a different angle. Rather than trying to be invisible, Fabric provides you with a set of powerful abstractions and tools that make distributed training not just possible, but elegant. It’s part of the broader PyTorch Lightning ecosystem, which has always emphasized best practices and reproducible research. Fabric gives you fine-grained control over the training process while still handling the low-level distributed computing details.</p>
</section>
</section>
<section id="code-integration-and-learning-curve" class="level2">
<h2 class="anchored" data-anchor-id="code-integration-and-learning-curve" id="code-integration-and-learning-curve">Code Integration and Learning Curve</h2>
<div class="tabset-margin-container"></div><div class="panel-tabset">
<ul class="nav nav-tabs" role="tablist"><li class="nav-item" role="presentation"><a class="nav-link active" id="tabset-1-1-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-1" role="tab" aria-controls="tabset-1-1" aria-selected="true" href="">Accelerate Approach</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-1-2-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-2" role="tab" aria-controls="tabset-1-2" aria-selected="false" href="">Fabric Approach</a></li></ul>
<div class="tab-content">
<div id="tabset-1-1" class="tab-pane active" role="tabpanel" aria-labelledby="tabset-1-1-tab">
<p>When you’re starting with Accelerate, the learning curve feels remarkably gentle. To make standard PyTorch code work with Accelerate, you typically need to make just a few key changes:</p>
<ul>
<li>Initialize an <code>Accelerator</code> object</li>
<li>Pass your model, optimizer, and data loaders (and any learning-rate scheduler) to the <code>prepare</code> method</li>
<li>Replace your <code>loss.backward()</code> call with <code>accelerator.backward(loss)</code></li>
<li>Drop manual <code>.to(device)</code> calls, since <code>prepare</code> handles device placement; the rest of your code can usually stay as it was</li>
</ul>
<p>This approach has profound implications for how teams adopt distributed training. Junior developers can start using distributed training without needing to understand concepts like gradient synchronization, device placement, or communication backends.</p>
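<p>Gradient synchronization is the concept Accelerate hides most completely. As a rough, framework-free model (an illustration, not Accelerate's code): under data parallelism each process computes gradients on its own shard, and an all-reduce then averages them so every replica applies the same update. PyTorch's <code>DistributedDataParallel</code>, which Accelerate uses for multi-GPU training, does this on real tensors and overlaps the communication with the backward pass. <code>all_reduce_mean</code> is a hypothetical name.</p>

```python
def all_reduce_mean(local_grads: list[list[float]]) -> list[list[float]]:
    """Simulate DDP-style gradient averaging across ranks.

    Hypothetical helper for illustration. local_grads[r] is the gradient
    vector computed on rank r from its own data shard.
    """
    world_size = len(local_grads)
    # All-reduce: element-wise sum across ranks
    summed = [sum(values) for values in zip(*local_grads)]
    # Divide by world size to get the average gradient
    averaged = [s / world_size for s in summed]
    # Every rank ends up holding an identical copy of the averaged gradient
    return [list(averaged) for _ in range(world_size)]


if __name__ == "__main__":
    print(all_reduce_mean([[1.0, 2.0], [3.0, 6.0]]))
```

<p>Because every replica ends up with the same averaged gradient, the model copies stay identical after each optimizer step, which is what makes the setup behave like one large-batch training run.</p>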
</div>
<div id="tabset-1-2" class="tab-pane" role="tabpanel" aria-labelledby="tabset-1-2-tab">
<p>Lightning Fabric requires a bit more upfront learning, but this investment pays dividends in terms of flexibility and control. Fabric encourages you to structure your code using its abstractions, which might feel unfamiliar at first but lead to more maintainable and scalable codebases. You’ll work with:</p>
<ul>
<li>Fabric’s strategy system for distributed training</li>
<li>Device management for handling different hardware</li>
<li>Logging integrations for experiment tracking</li>
</ul>
<p>The key insight is that Fabric’s slightly steeper learning curve comes with corresponding benefits. Once you understand Fabric’s patterns, you’ll find it easier to implement complex training scenarios, debug distributed issues, and maintain consistency across different experiments.</p>
</div>
</div>
</div>
</section>
<section id="performance-and-optimization-capabilities" class="level2">
<h2 class="anchored" data-anchor-id="performance-and-optimization-capabilities" id="performance-and-optimization-capabilities">Performance and Optimization Capabilities</h2>
<p>Both libraries are built on top of PyTorch’s native distributed training capabilities, so their fundamental performance characteristics are quite similar. However, they differ in how they expose optimization opportunities to you as a developer.</p>
<section id="accelerates-automatic-optimizations" class="level3">
<h3 class="anchored" data-anchor-id="accelerates-automatic-optimizations" id="accelerates-automatic-optimizations">Accelerate’s Automatic Optimizations</h3>
<p>Accelerate shines in its simplicity for standard use cases. The library automatically handles many optimization decisions for you, such as:</p>
<ul>
<li>Choosing appropriate communication backends</li>
<li>Managing memory efficiently across devices</li>
<li>Implementing gradient accumulation strategies</li>
</ul>
<p>For many common scenarios, particularly when training transformer models, Accelerate's automatic choices work well out of the box.</p>
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Warning</span>Limitation
</div>
</div>
<div class="callout-body-container callout-body">
<p>This automation can sometimes work against you when you need fine-grained control. If you’re implementing custom gradient accumulation strategies, working with unusual model architectures, or need to optimize communication patterns for your specific hardware setup, Accelerate’s abstractions might feel limiting.</p>
</div>
</div>
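<p>Whichever library handles it, multi-process training and gradient accumulation both change how many samples contribute to each optimizer step. That matters when you reuse hyperparameters tuned on a single GPU. A back-of-the-envelope sketch, with hypothetical helper names and ignoring padding and <code>drop_last</code> details:</p>

```python
import math


def effective_batch_size(per_device_batch: int, num_processes: int,
                         grad_accum_steps: int = 1) -> int:
    """Samples behind each optimizer step in data-parallel training.

    Hypothetical helper for illustration.
    """
    return per_device_batch * num_processes * grad_accum_steps


def optimizer_steps_per_epoch(num_samples: int, per_device_batch: int,
                              num_processes: int, grad_accum_steps: int = 1) -> int:
    """Approximate optimizer steps per epoch (a final partial step counts as one)."""
    effective = effective_batch_size(per_device_batch, num_processes, grad_accum_steps)
    return math.ceil(num_samples / effective)


if __name__ == "__main__":
    # 32 samples per GPU, 4 GPUs, accumulate over 2 micro-batches
    print(effective_batch_size(32, 4, 2))
    print(optimizer_steps_per_epoch(10_000, 32, 4, 2))
```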
</section>
<section id="fabrics-explicit-control" class="level3">
<h3 class="anchored" data-anchor-id="fabrics-explicit-control" id="fabrics-explicit-control">Fabric’s Explicit Control</h3>
<p>Lightning Fabric provides more explicit control over optimization decisions. You can:</p>
<ul>
<li>Choose specific distributed strategies</li>
<li>Customize how gradients are synchronized</li>
<li>Implement sophisticated mixed-precision training schemes</li>
</ul>
<p>This control comes at the cost of needing to understand what these choices mean, but it lets you tune training closely to your model and hardware.</p>
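<p>Customizing gradient synchronization usually comes up with gradient accumulation. In Fabric you write the accumulation loop yourself and can skip the all-reduce on intermediate micro-batches with <code>fabric.no_backward_sync(model, enabled=...)</code>. The pure-Python sketch below models only the arithmetic, with scalar gradients and a hypothetical function name. Each micro-batch contributes <code>grad / accum_steps</code>, and synchronization plus the optimizer step happen only on the last micro-batch of each window.</p>

```python
def accumulate_gradients(micro_grads: list[float], accum_steps: int):
    """Return (applied_grads, num_syncs) for a stream of micro-batch gradients.

    Hypothetical helper for illustration. Mirrors dividing the loss by
    accum_steps; a trailing partial window is left unapplied for simplicity.
    """
    applied, running, num_syncs = [], 0.0, 0
    for i, g in enumerate(micro_grads):
        is_accumulating = (i + 1) % accum_steps != 0
        running += g / accum_steps
        if not is_accumulating:
            # In a Fabric loop, this is the step where gradients are
            # all-reduced and optimizer.step() / zero_grad() are called
            applied.append(running)
            running = 0.0
            num_syncs += 1
    return applied, num_syncs


if __name__ == "__main__":
    print(accumulate_gradients([1.0, 3.0, 2.0, 4.0], accum_steps=2))
```

<p>Skipping synchronization on the accumulating steps means one all-reduce per optimizer step instead of one per micro-batch, which is where the communication savings come from.</p>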
</section>
</section>
<section id="code-examples-and-practical-implementation" class="level2">
<h2 class="anchored" data-anchor-id="code-examples-and-practical-implementation" id="code-examples-and-practical-implementation">Code Examples and Practical Implementation</h2>
<section id="hugging-face-accelerate-example" class="level3">
<h3 class="anchored" data-anchor-id="hugging-face-accelerate-example" id="hugging-face-accelerate-example">Hugging Face Accelerate Example</h3>
<div id="accelerate-example" class="cell" data-execution_count="1">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb1"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> accelerate <span class="im">import</span> Accelerator</span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.utils.data <span class="im">import</span> DataLoader</span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Initialize accelerator - handles device placement and distributed setup.</span></span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a><span class="co"># log_with selects a tracker; without one, accelerator.log() records nothing.</span></span>
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a>accelerator <span class="op">=</span> Accelerator(log_with<span class="op">=</span><span class="st">"tensorboard"</span>, project_dir<span class="op">=</span><span class="st">"runs"</span>)</span>
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a>accelerator.init_trackers(<span class="st">"accelerate_demo"</span>)</span>
<span id="cb1-9"><a href="#cb1-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-10"><a href="#cb1-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Your existing model, optimizer, and data loader</span></span>
<span id="cb1-11"><a href="#cb1-11" aria-hidden="true" tabindex="-1"></a><span class="co"># (YourModel and dataset are placeholders for your own code)</span></span>
<span id="cb1-12"><a href="#cb1-12" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> YourModel()</span>
<span id="cb1-13"><a href="#cb1-13" aria-hidden="true" tabindex="-1"></a>optimizer <span class="op">=</span> torch.optim.AdamW(model.parameters())</span>
<span id="cb1-14"><a href="#cb1-14" aria-hidden="true" tabindex="-1"></a>train_dataloader <span class="op">=</span> DataLoader(dataset, batch_size<span class="op">=</span><span class="dv">32</span>)</span>
<span id="cb1-15"><a href="#cb1-15" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-16"><a href="#cb1-16" aria-hidden="true" tabindex="-1"></a><span class="co"># Prepare everything for distributed training - this is the key step</span></span>
<span id="cb1-17"><a href="#cb1-17" aria-hidden="true" tabindex="-1"></a>model, optimizer, train_dataloader <span class="op">=</span> accelerator.prepare(</span>
<span id="cb1-18"><a href="#cb1-18" aria-hidden="true" tabindex="-1"></a>    model, optimizer, train_dataloader</span>
<span id="cb1-19"><a href="#cb1-19" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb1-20"><a href="#cb1-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-21"><a href="#cb1-21" aria-hidden="true" tabindex="-1"></a><span class="co"># Your training loop stays almost identical; no manual .to(device) calls needed</span></span>
<span id="cb1-22"><a href="#cb1-22" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> step, batch <span class="kw">in</span> <span class="bu">enumerate</span>(train_dataloader):</span>
<span id="cb1-23"><a href="#cb1-23" aria-hidden="true" tabindex="-1"></a>    optimizer.zero_grad()</span>
<span id="cb1-24"><a href="#cb1-24" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb1-25"><a href="#cb1-25" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Forward pass works exactly as before (assumes a model that returns .loss,</span></span>
<span id="cb1-26"><a href="#cb1-26" aria-hidden="true" tabindex="-1"></a>    <span class="co"># such as a Hugging Face transformers model)</span></span>
<span id="cb1-27"><a href="#cb1-27" aria-hidden="true" tabindex="-1"></a>    outputs <span class="op">=</span> model(<span class="op">**</span>batch)</span>
<span id="cb1-28"><a href="#cb1-28" aria-hidden="true" tabindex="-1"></a>    loss <span class="op">=</span> outputs.loss</span>
<span id="cb1-29"><a href="#cb1-29" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb1-30"><a href="#cb1-30" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Use accelerator.backward instead of loss.backward()</span></span>
<span id="cb1-31"><a href="#cb1-31" aria-hidden="true" tabindex="-1"></a>    accelerator.backward(loss)</span>
<span id="cb1-32"><a href="#cb1-32" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb1-33"><a href="#cb1-33" aria-hidden="true" tabindex="-1"></a>    optimizer.step()</span>
<span id="cb1-34"><a href="#cb1-34" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb1-35"><a href="#cb1-35" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Sent to the configured tracker from the main process only</span></span>
<span id="cb1-36"><a href="#cb1-36" aria-hidden="true" tabindex="-1"></a>    accelerator.log({<span class="st">"loss"</span>: loss.item()}, step<span class="op">=</span>step)</span>
<span id="cb1-37"><a href="#cb1-37" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-38"><a href="#cb1-38" aria-hidden="true" tabindex="-1"></a><span class="co"># Flush and close the trackers</span></span>
<span id="cb1-39"><a href="#cb1-39" aria-hidden="true" tabindex="-1"></a>accelerator.end_training()</span></code></pre></div></div>
</div>
</section>
<section id="lightning-fabric-example" class="level3">
<h3 class="anchored" data-anchor-id="lightning-fabric-example" id="lightning-fabric-example">Lightning Fabric Example</h3>
<div id="fabric-example" class="cell" data-execution_count="2">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> lightning.fabric <span class="im">import</span> Fabric</span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> lightning.fabric.loggers <span class="im">import</span> TensorBoardLogger</span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.utils.data <span class="im">import</span> DataLoader</span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Initialize Fabric with explicit strategy choices; fabric.log() needs a logger</span></span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a>fabric <span class="op">=</span> Fabric(</span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a>    accelerator<span class="op">=</span><span class="st">"gpu"</span>, devices<span class="op">=</span><span class="dv">4</span>, strategy<span class="op">=</span><span class="st">"ddp"</span>,</span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a>    loggers<span class="op">=</span>TensorBoardLogger(root_dir<span class="op">=</span><span class="st">"logs"</span>),</span>
<span id="cb2-10"><a href="#cb2-10" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb2-11"><a href="#cb2-11" aria-hidden="true" tabindex="-1"></a>fabric.launch()</span>
<span id="cb2-12"><a href="#cb2-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-13"><a href="#cb2-13" aria-hidden="true" tabindex="-1"></a><span class="co"># Setup model and optimizer (YourModel and dataset are placeholders)</span></span>
<span id="cb2-14"><a href="#cb2-14" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> YourModel()</span>
<span id="cb2-15"><a href="#cb2-15" aria-hidden="true" tabindex="-1"></a>optimizer <span class="op">=</span> torch.optim.AdamW(model.parameters())</span>
<span id="cb2-16"><a href="#cb2-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-17"><a href="#cb2-17" aria-hidden="true" tabindex="-1"></a><span class="co"># Setup for distributed training - more explicit control</span></span>
<span id="cb2-18"><a href="#cb2-18" aria-hidden="true" tabindex="-1"></a>model, optimizer <span class="op">=</span> fabric.setup(model, optimizer)</span>
<span id="cb2-19"><a href="#cb2-19" aria-hidden="true" tabindex="-1"></a>train_dataloader <span class="op">=</span> fabric.setup_dataloaders(DataLoader(dataset, batch_size<span class="op">=</span><span class="dv">32</span>))</span>
<span id="cb2-20"><a href="#cb2-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-21"><a href="#cb2-21" aria-hidden="true" tabindex="-1"></a><span class="co"># Training loop with explicit fabric calls</span></span>
<span id="cb2-22"><a href="#cb2-22" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> step, batch <span class="kw">in</span> <span class="bu">enumerate</span>(train_dataloader):</span>
<span id="cb2-23"><a href="#cb2-23" aria-hidden="true" tabindex="-1"></a>    optimizer.zero_grad()</span>
<span id="cb2-24"><a href="#cb2-24" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-25"><a href="#cb2-25" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Forward pass (assumes a model whose output has a .loss attribute)</span></span>
<span id="cb2-26"><a href="#cb2-26" aria-hidden="true" tabindex="-1"></a>    outputs <span class="op">=</span> model(<span class="op">**</span>batch)</span>
<span id="cb2-27"><a href="#cb2-27" aria-hidden="true" tabindex="-1"></a>    loss <span class="op">=</span> outputs.loss</span>
<span id="cb2-28"><a href="#cb2-28" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-29"><a href="#cb2-29" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Backward pass with fabric</span></span>
<span id="cb2-30"><a href="#cb2-30" aria-hidden="true" tabindex="-1"></a>    fabric.backward(loss)</span>
<span id="cb2-31"><a href="#cb2-31" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-32"><a href="#cb2-32" aria-hidden="true" tabindex="-1"></a>    optimizer.step()</span>
<span id="cb2-33"><a href="#cb2-33" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-34"><a href="#cb2-34" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Explicit logging with fabric; sent to the loggers passed to Fabric()</span></span>
<span id="cb2-35"><a href="#cb2-35" aria-hidden="true" tabindex="-1"></a>    fabric.log(<span class="st">"loss"</span>, loss.item(), step<span class="op">=</span>step)</span></code></pre></div></div>
</div>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Key Difference
</div>
</div>
<div class="callout-body-container callout-body">
<p>The code examples illustrate a fundamental distinction: Accelerate aims to make your existing code work with minimal changes, while Fabric provides more explicit control over the distributed training process.</p>
</div>
</div>
</section>
</section>
<section id="ecosystem-integration-and-tooling" class="level2">
<h2 class="anchored" data-anchor-id="ecosystem-integration-and-tooling" id="ecosystem-integration-and-tooling">Ecosystem Integration and Tooling</h2>
<section id="hugging-face-accelerate-ecosystem" class="level3">
<h3 class="anchored" data-anchor-id="hugging-face-accelerate-ecosystem" id="hugging-face-accelerate-ecosystem">Hugging Face Accelerate Ecosystem</h3>
<p>The ecosystem story reveals another important distinction between these libraries. Hugging Face Accelerate benefits from tight integration with the broader Hugging Face ecosystem, including:</p>
<ul>
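<li>Seamless interoperability with the <code>transformers</code> and <code>datasets</code> libraries (the <code>transformers</code> <code>Trainer</code> itself uses Accelerate under the hood)</li>
<li>Built-in integrations with experiment tracking tools such as TensorBoard, Weights &amp; Biases, and MLflow</li>
<li>Support for CPUs, single and multiple GPUs, TPUs, and Apple Silicon (MPS) from the same script, plus DeepSpeed and FSDP integrations</li>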
</ul>
</section>
<section id="lightning-fabric-ecosystem" class="level3">
<h3 class="anchored" data-anchor-id="lightning-fabric-ecosystem" id="lightning-fabric-ecosystem">Lightning Fabric Ecosystem</h3>
<p>Lightning Fabric is part of the comprehensive PyTorch Lightning ecosystem, which includes:</p>
<ul>
<li>Distributed training strategies such as DDP, FSDP, and DeepSpeed</li>
<li>Experiment logging (for example, TensorBoard and CSV loggers)</li>
<li>Hyperparameter optimization utilities</li>
<li>Deployment tools</li>
</ul>
<p>This ecosystem approach means that once you invest in learning Fabric, you gain access to a complete toolkit for machine learning research and production.</p>
</section>
</section>
<section id="advanced-features-and-customization" class="level2">
<h2 class="anchored" data-anchor-id="advanced-features-and-customization" id="advanced-features-and-customization">Advanced Features and Customization</h2>
<section id="memory-management-and-optimization" class="level3">
<h3 class="anchored" data-anchor-id="memory-management-and-optimization" id="memory-management-and-optimization">Memory Management and Optimization</h3>
<div class="tabset-margin-container"></div><div class="panel-tabset">
<ul class="nav nav-tabs" role="tablist"><li class="nav-item" role="presentation"><a class="nav-link active" id="tabset-2-1-tab" data-bs-toggle="tab" data-bs-target="#tabset-2-1" role="tab" aria-controls="tabset-2-1" aria-selected="true" href="">Accelerate</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-2-2-tab" data-bs-toggle="tab" data-bs-target="#tabset-2-2" role="tab" aria-controls="tabset-2-2" aria-selected="false" href="">Fabric</a></li></ul>
<div class="tab-content">
<div id="tabset-2-1" class="tab-pane active" role="tabpanel" aria-labelledby="tabset-2-1-tab">
<p>Accelerate provides automatic memory management features that work well for most use cases:</p>
<ul>
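<li>Gradient accumulation, configured with <code>Accelerator(gradient_accumulation_steps=...)</code> and the <code>accelerator.accumulate(model)</code> context manager</li>
<li>Mixed precision training (<code>mixed_precision="fp16"</code> or <code>"bf16"</code>)</li>
<li>Integrations with DeepSpeed and FSDP for sharding model state, which can be combined with gradient checkpointing</li>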
</ul>
<p>These features work transparently, requiring minimal configuration from the user.</p>
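<p>Gradient (activation) checkpointing is the least self-explanatory technique in this list. It keeps only some activations during the forward pass and recomputes the rest during backward, trading compute for memory. The sketch below uses a deliberately simplified cost model, assumed for illustration: uniform layers, one saved activation per segment boundary, and one segment recomputed at a time. Under that model, segment sizes near the square root of the layer count minimize peak memory. Both helper names are hypothetical.</p>

```python
import math


def peak_stored_activations(num_layers: int, segment_size: int) -> int:
    """Activations alive at peak under a simplified checkpointing model.

    Hypothetical helper for illustration: one checkpoint per segment, plus
    the full activations of the single segment being recomputed in backward.
    """
    num_segments = math.ceil(num_layers / segment_size)
    return num_segments + segment_size


def best_segment_size(num_layers: int) -> int:
    """Segment size minimizing peak memory in this model (ties go to the smallest)."""
    return min(range(1, num_layers + 1),
               key=lambda k: peak_stored_activations(num_layers, k))


if __name__ == "__main__":
    # Without checkpointing, all 16 layer activations would be kept
    k = best_segment_size(16)
    print(k, peak_stored_activations(16, k))
```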
</div>
<div id="tabset-2-2" class="tab-pane" role="tabpanel" aria-labelledby="tabset-2-2-tab">
<p>Lightning Fabric offers more granular control over memory management:</p>
<ul>
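<li>Custom gradient accumulation, for example skipping gradient synchronization on intermediate steps with <code>fabric.no_backward_sync(model, enabled=...)</code></li>
<li>Fine-grained precision settings through the <code>precision</code> argument (for example <code>"16-mixed"</code>, <code>"bf16-mixed"</code>, or <code>"bf16-true"</code>)</li>
<li>Sharded strategies such as FSDP and DeepSpeed for reducing per-device memory</li>
<li>Precise control over activation checkpointing (for example through the FSDP strategy's <code>activation_checkpointing_policy</code>)</li>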
</ul>
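<p>Mixed precision in fp16 needs one mechanism that bf16 usually does not: loss scaling, which keeps small gradients from underflowing. Fabric's <code>"16-mixed"</code> setting relies on PyTorch's <code>GradScaler</code> for this. The class below is a simplified, hypothetical re-implementation of the dynamic scaling rule that uses <code>GradScaler</code>'s default constants. It is meant to show what the scaler does; it is not PyTorch's code.</p>

```python
class DynamicLossScaler:
    """Simplified dynamic loss scaling for fp16 mixed precision.

    Hypothetical class for illustration. Defaults mirror torch GradScaler:
    init_scale=2**16, growth_factor=2.0, backoff_factor=0.5,
    growth_interval=2000.
    """

    def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0

    def update(self, found_inf: bool) -> bool:
        """Adjust the scale after a step; return True if the optimizer step should run."""
        if found_inf:
            # Overflow in the scaled gradients: skip this step and shrink the scale
            self.scale *= self.backoff_factor
            self._good_steps = 0
            return False
        self._good_steps += 1
        if self._good_steps == self.growth_interval:
            # A long run of finite gradients: try a larger scale
            self.scale *= self.growth_factor
            self._good_steps = 0
        return True


if __name__ == "__main__":
    scaler = DynamicLossScaler(growth_interval=2)
    print(scaler.update(found_inf=True), scaler.scale)
    scaler.update(found_inf=False)
    print(scaler.update(found_inf=False), scaler.scale)
```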
</div>
</div>
</div>
</section>
<section id="hardware-support-and-scalability" class="level3">
<h3 class="anchored" data-anchor-id="hardware-support-and-scalability" id="hardware-support-and-scalability">Hardware Support and Scalability</h3>
<p>Both libraries support a wide range of hardware configurations, from single GPUs to multi-node clusters:</p>
<ul>
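<li><strong>Accelerate</strong>: Detects the available hardware and reads its launch configuration (from <code>accelerate config</code> or <code>accelerate launch</code> flags), then configures itself accordingly</li>
<li><strong>Fabric</strong>: Provides explicit configuration options for different hardware setups, such as the <code>accelerator</code>, <code>devices</code>, <code>num_nodes</code>, and <code>strategy</code> arguments</li>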
</ul>
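<p>Whether Accelerate detects the topology or you spell it out with Fabric's <code>num_nodes</code> and <code>devices</code> arguments, a typical one-process-per-device setup numbers the processes the same way. A minimal sketch, assuming homogeneous nodes and using hypothetical helper names:</p>

```python
def world_size(num_nodes: int, devices_per_node: int) -> int:
    """Total number of training processes (one per device).

    Hypothetical helper for illustration; assumes every node has the same
    number of devices.
    """
    return num_nodes * devices_per_node


def global_rank(node_rank: int, local_rank: int, devices_per_node: int) -> int:
    """Unique process index across the cluster for a given node and local device."""
    return node_rank * devices_per_node + local_rank


if __name__ == "__main__":
    # Two nodes with four GPUs each
    for node in range(2):
        print(node, [global_rank(node, local, 4) for local in range(4)])
    print(world_size(2, 4))
```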
</section>
</section>
<section id="debugging-and-development-experience" class="level2">
<h2 class="anchored" data-anchor-id="debugging-and-development-experience" id="debugging-and-development-experience">Debugging and Development Experience</h2>
<div id="tbl-debugging" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-tbl figure">
<figcaption class="quarto-float-caption-top quarto-float-caption quarto-float-tbl" id="tbl-debugging-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Table&nbsp;1: Debugging Experience Comparison
</figcaption>
<div aria-describedby="tbl-debugging-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<table class="caption-top table">
<colgroup>
<col style="width: 27%">
<col style="width: 41%">
<col style="width: 31%">
</colgroup>
<thead>
<tr class="header">
<th>Aspect</th>
<th>Accelerate</th>
<th>Fabric</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><strong>Debugging Feel</strong></td>
<td>Similar to single-GPU debugging</td>
<td>More explicit debugging tools</td>
</tr>
<tr class="even">
<td><strong>Error Messages</strong></td>
<td>Standard PyTorch errors</td>
<td>Enhanced distributed training errors</td>
</tr>
<tr class="odd">
<td><strong>Problem Isolation</strong></td>
<td>Transparent issues</td>
<td>Structured error handling</td>
</tr>
<tr class="even">
<td><strong>Learning Curve</strong></td>
<td>Gentle, gradual</td>
<td>Steeper but more comprehensive</td>
</tr>
</tbody>
</table>
</div>
</figure>
</div>
</section>
<section id="performance-benchmarks-and-real-world-usage" class="level2">
<h2 class="anchored" data-anchor-id="performance-benchmarks-and-real-world-usage" id="performance-benchmarks-and-real-world-usage">Performance Benchmarks and Real-World Usage</h2>
<p>In practice, both libraries perform similarly for most common use cases, since they’re both built on PyTorch’s native distributed training capabilities. The performance differences typically come from how well each library’s abstractions match your specific use case.</p>
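<p>Because both libraries sit on the same PyTorch primitives, the most reliable way to compare them for your workload is to measure throughput yourself. A common summary is scaling efficiency: measured multi-device throughput divided by ideal linear throughput. The helpers below are hypothetical, and the numbers in the usage lines are made up for illustration.</p>

```python
def speedup(single_device_throughput: float, multi_device_throughput: float) -> float:
    """How many times faster the multi-device run is (hypothetical helper)."""
    return multi_device_throughput / single_device_throughput


def scaling_efficiency(single_device_throughput: float,
                       multi_device_throughput: float,
                       num_devices: int) -> float:
    """Fraction of ideal linear speedup achieved; 1.0 means perfect scaling."""
    ideal = single_device_throughput * num_devices
    return multi_device_throughput / ideal


if __name__ == "__main__":
    # Illustrative numbers only: samples/second on 1 GPU vs. 4 GPUs
    print(speedup(100.0, 360.0))
    print(scaling_efficiency(100.0, 360.0, 4))
```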
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Important</span>Performance Considerations
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li><strong>Accelerate</strong>: Well suited to transformer models and other common architectures, where its defaults tend to fit</li>
<li><strong>Fabric</strong>: Can pull ahead on custom architectures when you apply targeted optimizations through its explicit controls</li>
</ul>
</div>
</div>
</section>
<section id="migration-and-adoption-strategies" class="level2">
<h2 class="anchored" data-anchor-id="migration-and-adoption-strategies" id="migration-and-adoption-strategies">Migration and Adoption Strategies</h2>
<section id="choosing-accelerate-when" class="level3">
<h3 class="anchored" data-anchor-id="choosing-accelerate-when" id="choosing-accelerate-when">Choosing Accelerate When:</h3>
<ul>
<li>You need to scale existing code quickly</li>
<li>Your team is new to distributed training</li>
<li>You’re working primarily with transformer models</li>
<li>You need rapid prototyping and iteration</li>
</ul>
</section>
<section id="choosing-fabric-when" class="level3">
<h3 class="anchored" data-anchor-id="choosing-fabric-when" id="choosing-fabric-when">Choosing Fabric When:</h3>
<ul>
<li>You need fine-grained control over training procedures</li>
<li>You’re implementing custom training algorithms</li>
<li>You want a comprehensive framework for multiple projects</li>
<li>You’re building production ML systems</li>
</ul>
</section>
</section>
<section id="future-considerations" class="level2">
<h2 class="anchored" data-anchor-id="future-considerations" id="future-considerations">Future Considerations</h2>
<p>Both libraries continue to evolve rapidly:</p>
<ul>
<li><strong>Accelerate</strong>: Development tied to Hugging Face ecosystem advances</li>
<li><strong>Fabric</strong>: Focuses on cutting-edge distributed training capabilities</li>
</ul>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>Hugging Face Accelerate and PyTorch Lightning Fabric represent two excellent but philosophically different approaches to distributed training:</p>
<ul>
<li><strong>Accelerate</strong>: Prioritizes simplicity and ease of adoption</li>
<li><strong>Fabric</strong>: Emphasizes flexibility and control</li>
</ul>
<p>Neither choice is inherently better than the other. The right choice depends on your specific needs, team expertise, and project requirements. Both libraries will successfully help you move beyond single-GPU limitations and unlock the full potential of distributed computing for machine learning.</p>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Final Recommendation
</div>
</div>
<div class="callout-body-container callout-body">
<p>The most important step is to start experimenting with distributed training, regardless of which library you choose. Both Accelerate and Fabric provide excellent foundations for learning distributed training concepts and scaling your machine learning workloads effectively.</p>
</div>
</div>
</section>
<section id="references" class="level2">
<h2 class="anchored" data-anchor-id="references" id="references">References</h2>
<ul>
<li><a href="https://huggingface.co/docs/accelerate">Hugging Face Accelerate Documentation</a></li>
<li><a href="https://lightning.ai/docs/fabric">PyTorch Lightning Fabric Documentation</a></li>
<li><a href="https://pytorch.org/tutorials/beginner/dist_overview.html">PyTorch Distributed Training Guide</a></li>
</ul>


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Hugging Face Accelerate Code Guide]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/distributed/hugging-face-accelerate/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/distributed/hugging-face-accelerate/</guid>
      <pubDate>Tue, 03 Jun 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>tutorial</category>
      <category>beginner</category>
      <content:encoded><![CDATA[






<section id="hugging-face-accelerate-code-guide" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/distributed/hugging-face-accelerate/accelerate.png" class="img-fluid"></p>
<section id="overview" class="level2">
<h2 class="anchored" data-anchor-id="overview" id="overview">Overview</h2>
<p>This comprehensive code guide covers everything you need to know about Hugging Face Accelerate, from basic setup to advanced features like DeepSpeed integration. Accelerate simplifies distributed training and mixed precision training across multiple GPUs and nodes.</p>
</section>
<section id="installation-and-setup" class="level2">
<h2 class="anchored" data-anchor-id="installation-and-setup" id="installation-and-setup">Installation and Setup</h2>
<section id="installation" class="level3">
<h3 class="anchored" data-anchor-id="installation" id="installation">Installation</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install accelerate</span></code></pre></div></div>
</section>
<section id="configuration" class="level3">
<h3 class="anchored" data-anchor-id="configuration" id="configuration">Configuration</h3>
<p>Run the configuration wizard to set up your training environment:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb2"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="ex">accelerate</span> config</span></code></pre></div></div>
<p>Or create a config file programmatically:</p>
<div id="a271a62a" class="cell" data-execution_count="1">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> accelerate.utils <span class="im">import</span> write_basic_config</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Writes a basic default config file (skipped if one already exists)</span></span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a>write_basic_config(mixed_precision<span class="op">=</span><span class="st">"fp16"</span>)  <span class="co"># or "bf16", "no"</span></span></code></pre></div></div>
</div>
</section>
</section>
<section id="basic-concepts" class="level2">
<h2 class="anchored" data-anchor-id="basic-concepts" id="basic-concepts">Basic Concepts</h2>
<section id="the-accelerator-object" class="level3">
<h3 class="anchored" data-anchor-id="the-accelerator-object" id="the-accelerator-object">The Accelerator Object</h3>
<p>The <code>Accelerator</code> is the main class that handles device placement, gradient synchronization, and other distributed training concerns.</p>
<div id="27e8f2ac" class="cell" data-execution_count="2">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> accelerate <span class="im">import</span> Accelerator</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Initialize accelerator</span></span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a>accelerator <span class="op">=</span> Accelerator()</span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Key properties</span></span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a>device <span class="op">=</span> accelerator.device</span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a>is_main_process <span class="op">=</span> accelerator.is_main_process</span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a>num_processes <span class="op">=</span> accelerator.num_processes</span></code></pre></div></div>
</div>
</section>
<section id="device-placement" class="level3">
<h3 class="anchored" data-anchor-id="device-placement" id="device-placement">Device Placement</h3>
<p>Accelerate automatically handles device placement:</p>
<div id="24e1bd68" class="cell" data-execution_count="3">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Manual device placement (old way)</span></span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> model.to(device)</span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a>batch <span class="op">=</span> {k: v.to(device) <span class="cf">for</span> k, v <span class="kw">in</span> batch.items()}</span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Accelerate way (automatic)</span></span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a>model, optimizer, dataloader <span class="op">=</span> accelerator.prepare(model, optimizer, dataloader)</span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a><span class="co"># No need to move batch to device - accelerate handles it</span></span></code></pre></div></div>
</div>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Key Benefit
</div>
</div>
<div class="callout-body-container callout-body">
<p>With Accelerate, you don’t need to handle device placement manually. <code>prepare()</code> moves your model to the right device, wraps the optimizer, and wraps the dataloader so that each process receives its own shard of the data with every batch already on that process’s device.</p>
</div>
</div>
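<p>To make the sharding part concrete, here is a pure-Python sketch of how batches might be dealt out to processes. It assumes a simple round-robin scheme (batch <em>i</em> goes to process <em>i</em> mod <code>num_processes</code>) and a dataset that divides evenly, so no padding is needed; <code>shard_batches</code> is a hypothetical helper for illustration, not part of Accelerate’s API.</p>

```python
def shard_batches(num_samples, batch_size, num_processes, process_index):
    """Return the batches of sample indices that one process would iterate over.

    Illustrative sketch only: assumes batches are dealt out round-robin and that
    num_samples divides evenly, so no padding is needed.
    """
    indices = list(range(num_samples))
    # Cut the (unshuffled) index list into consecutive batches.
    batches = [indices[i:i + batch_size] for i in range(0, num_samples, batch_size)]
    # Process p gets batches p, p + n, p + 2n, ... where n = num_processes.
    return batches[process_index::num_processes]


if __name__ == "__main__":
    for rank in range(2):
        print(f"process {rank}:", shard_batches(16, 4, num_processes=2, process_index=rank))
```

<p>With 16 samples, a batch size of 4 and two processes, process 0 would see <code>[[0, 1, 2, 3], [8, 9, 10, 11]]</code> and process 1 <code>[[4, 5, 6, 7], [12, 13, 14, 15]]</code>: disjoint halves that together cover every sample exactly once per epoch.</p>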
</section>
</section>
<section id="simple-training-loop" class="level2">
<h2 class="anchored" data-anchor-id="simple-training-loop" id="simple-training-loop">Simple Training Loop</h2>
<section id="basic-example" class="level3">
<h3 class="anchored" data-anchor-id="basic-example" id="basic-example">Basic Example</h3>
<div id="1bbdf928" class="cell" data-execution_count="4">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.utils.data <span class="im">import</span> DataLoader</span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> accelerate <span class="im">import</span> Accelerator</span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> transformers <span class="im">import</span> AutoModelForSequenceClassification, AutoTokenizer</span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> train_model():</span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Initialize accelerator</span></span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a>    accelerator <span class="op">=</span> Accelerator()</span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Load a model with a classification head (it returns outputs.loss when labels are passed) and tokenizer</span></span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> AutoModelForSequenceClassification.from_pretrained(<span class="st">"bert-base-uncased"</span>, num_labels<span class="op">=</span><span class="dv">2</span>)</span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a>    tokenizer <span class="op">=</span> AutoTokenizer.from_pretrained(<span class="st">"bert-base-uncased"</span>)</span>
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Create optimizer (PyTorch's AdamW; transformers.AdamW is deprecated)</span></span>
<span id="cb6-15"><a href="#cb6-15" aria-hidden="true" tabindex="-1"></a>    optimizer <span class="op">=</span> torch.optim.AdamW(model.parameters(), lr<span class="op">=</span><span class="fl">5e-5</span>)</span>
<span id="cb6-16"><a href="#cb6-16" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-17"><a href="#cb6-17" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Create dataloader; train_dataset is a placeholder for your tokenized dataset (dicts with input_ids, attention_mask, labels)</span></span>
<span id="cb6-18"><a href="#cb6-18" aria-hidden="true" tabindex="-1"></a>    train_dataloader <span class="op">=</span> DataLoader(train_dataset, batch_size<span class="op">=</span><span class="dv">16</span>, shuffle<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb6-19"><a href="#cb6-19" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-20"><a href="#cb6-20" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Prepare everything with accelerator</span></span>
<span id="cb6-21"><a href="#cb6-21" aria-hidden="true" tabindex="-1"></a>    model, optimizer, train_dataloader <span class="op">=</span> accelerator.prepare(</span>
<span id="cb6-22"><a href="#cb6-22" aria-hidden="true" tabindex="-1"></a>        model, optimizer, train_dataloader</span>
<span id="cb6-23"><a href="#cb6-23" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb6-24"><a href="#cb6-24" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-25"><a href="#cb6-25" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Training loop</span></span>
<span id="cb6-26"><a href="#cb6-26" aria-hidden="true" tabindex="-1"></a>    model.train()</span>
<span id="cb6-27"><a href="#cb6-27" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">3</span>):</span>
<span id="cb6-28"><a href="#cb6-28" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> batch <span class="kw">in</span> train_dataloader:</span>
<span id="cb6-29"><a href="#cb6-29" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Forward pass</span></span>
<span id="cb6-30"><a href="#cb6-30" aria-hidden="true" tabindex="-1"></a>            outputs <span class="op">=</span> model(<span class="op">**</span>batch)</span>
<span id="cb6-31"><a href="#cb6-31" aria-hidden="true" tabindex="-1"></a>            loss <span class="op">=</span> outputs.loss</span>
<span id="cb6-32"><a href="#cb6-32" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb6-33"><a href="#cb6-33" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Backward pass</span></span>
<span id="cb6-34"><a href="#cb6-34" aria-hidden="true" tabindex="-1"></a>            accelerator.backward(loss)</span>
<span id="cb6-35"><a href="#cb6-35" aria-hidden="true" tabindex="-1"></a>            optimizer.step()</span>
<span id="cb6-36"><a href="#cb6-36" aria-hidden="true" tabindex="-1"></a>            optimizer.zero_grad()</span>
<span id="cb6-37"><a href="#cb6-37" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb6-38"><a href="#cb6-38" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Print loss (only on main process)</span></span>
<span id="cb6-39"><a href="#cb6-39" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> accelerator.is_main_process:</span>
<span id="cb6-40"><a href="#cb6-40" aria-hidden="true" tabindex="-1"></a>                <span class="bu">print</span>(<span class="ss">f"Loss: </span><span class="sc">{</span>loss<span class="sc">.</span>item()<span class="sc">:.4f}</span><span class="ss">"</span>)</span>
<span id="cb6-41"><a href="#cb6-41" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-42"><a href="#cb6-42" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb6-43"><a href="#cb6-43" aria-hidden="true" tabindex="-1"></a>    train_model()</span></code></pre></div></div>
</div>
</section>
<section id="running-the-training" class="level3">
<h3 class="anchored" data-anchor-id="running-the-training" id="running-the-training">Running the Training</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb7"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Single GPU</span></span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a><span class="ex">python</span> train.py</span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Multiple GPUs</span></span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a><span class="ex">accelerate</span> launch <span class="at">--num_processes</span><span class="op">=</span>2 train.py</span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a><span class="co"># With config file</span></span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a><span class="ex">accelerate</span> launch <span class="at">--config_file</span> config.yaml train.py</span></code></pre></div></div>
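<p>Launching with more processes changes how many samples go into each optimizer step. A small sketch of that arithmetic, assuming every process uses the same per-device batch size and that a final partial batch still triggers a step; both helper names are illustrative, not Accelerate functions.</p>

```python
import math


def effective_batch_size(per_device_batch_size, num_processes, gradient_accumulation_steps=1):
    """Number of samples that contribute to a single optimizer step."""
    return per_device_batch_size * num_processes * gradient_accumulation_steps


def optimizer_steps_per_epoch(num_samples, per_device_batch_size, num_processes,
                              gradient_accumulation_steps=1):
    """Approximate number of optimizer steps per epoch on each process."""
    total_batches = math.ceil(num_samples / per_device_batch_size)
    # Batches are split across processes, so each process sees about 1/num_processes of them.
    batches_per_process = math.ceil(total_batches / num_processes)
    return math.ceil(batches_per_process / gradient_accumulation_steps)


if __name__ == "__main__":
    # batch_size=16 as in the example above, launched with --num_processes=2
    print(effective_batch_size(16, 2))               # 32
    print(optimizer_steps_per_epoch(10_000, 16, 2))  # 313
```

<p>So moving from <code>python train.py</code> to two processes doubles the effective batch size from 16 to 32 and roughly halves the number of steps per epoch, which is worth keeping in mind when choosing a learning rate or schedule length.</p>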
</section>
</section>
<section id="advanced-features" class="level2">
<h2 class="anchored" data-anchor-id="advanced-features" id="advanced-features">Advanced Features</h2>
<section id="logging-and-tracking" class="level3">
<h3 class="anchored" data-anchor-id="logging-and-tracking" id="logging-and-tracking">Logging and Tracking</h3>
<div id="6d4db596" class="cell" data-execution_count="5">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb8"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> accelerate <span class="im">import</span> Accelerator</span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> accelerate.logging <span class="im">import</span> get_logger</span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Initialize with logging</span></span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a>accelerator <span class="op">=</span> Accelerator(log_with<span class="op">=</span><span class="st">"tensorboard"</span>, project_dir<span class="op">=</span><span class="st">"./logs"</span>)</span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Get logger</span></span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a>logger <span class="op">=</span> get_logger(<span class="va">__name__</span>)</span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Start tracking</span></span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a>accelerator.init_trackers(<span class="st">"my_experiment"</span>)</span>
<span id="cb8-12"><a href="#cb8-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-13"><a href="#cb8-13" aria-hidden="true" tabindex="-1"></a><span class="co"># Log metrics (inside your training loop, where loss and epoch are defined)</span></span>
<span id="cb8-14"><a href="#cb8-14" aria-hidden="true" tabindex="-1"></a>accelerator.log({<span class="st">"train_loss"</span>: loss.item(), <span class="st">"epoch"</span>: epoch})</span>
<span id="cb8-15"><a href="#cb8-15" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-16"><a href="#cb8-16" aria-hidden="true" tabindex="-1"></a><span class="co"># End tracking</span></span>
<span id="cb8-17"><a href="#cb8-17" aria-hidden="true" tabindex="-1"></a>accelerator.end_training()</span></code></pre></div></div>
</div>
</section>
<section id="saving-and-loading-models" class="level3">
<h3 class="anchored" data-anchor-id="saving-and-loading-models" id="saving-and-loading-models">Saving and Loading Models</h3>
<div id="961e44d1" class="cell" data-execution_count="6">
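<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb9"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Make sure every process has finished training before saving</span></span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a>accelerator.wait_for_everyone()</span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Save model weights (save_model unwraps the prepared model for you)</span></span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a>accelerator.save_model(model, <span class="st">"path/to/save"</span>)</span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a><span class="co"># Or save a plain state dict of the unwrapped model</span></span>
<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a>unwrapped_model <span class="op">=</span> accelerator.unwrap_model(model)</span>
<span id="cb9-11"><a href="#cb9-11" aria-hidden="true" tabindex="-1"></a>accelerator.save(unwrapped_model.state_dict(), <span class="st">"model.pt"</span>)</span>
<span id="cb9-12"><a href="#cb9-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-13"><a href="#cb9-13" aria-hidden="true" tabindex="-1"></a><span class="co"># Load that state dict back into the unwrapped model (load_state is only for save_state checkpoints)</span></span>
<span id="cb9-14"><a href="#cb9-14" aria-hidden="true" tabindex="-1"></a>unwrapped_model.load_state_dict(torch.load(<span class="st">"model.pt"</span>, map_location<span class="op">=</span><span class="st">"cpu"</span>))</span>
<span id="cb9-15"><a href="#cb9-15" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-16"><a href="#cb9-16" aria-hidden="true" tabindex="-1"></a><span class="co"># Save complete training state (model, optimizer, RNG states, ...)</span></span>
<span id="cb9-17"><a href="#cb9-17" aria-hidden="true" tabindex="-1"></a>accelerator.save_state(<span class="st">"checkpoint_dir"</span>)</span>
<span id="cb9-18"><a href="#cb9-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-19"><a href="#cb9-19" aria-hidden="true" tabindex="-1"></a><span class="co"># Load complete training state</span></span>
<span id="cb9-20"><a href="#cb9-20" aria-hidden="true" tabindex="-1"></a>accelerator.load_state(<span class="st">"checkpoint_dir"</span>)</span></code></pre></div></div>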
</div>
</section>
<section id="evaluation-loop" class="level3">
<h3 class="anchored" data-anchor-id="evaluation-loop" id="evaluation-loop">Evaluation Loop</h3>
<div id="b363330b" class="cell" data-execution_count="7">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb10"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> evaluate_model(model, eval_dataloader, accelerator):</span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a>    model.<span class="bu">eval</span>()</span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a>    losses <span class="op">=</span> []</span>
<span id="cb10-6"><a href="#cb10-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-7"><a href="#cb10-7" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> torch.no_grad():</span>
<span id="cb10-8"><a href="#cb10-8" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> batch <span class="kw">in</span> eval_dataloader:</span>
<span id="cb10-9"><a href="#cb10-9" aria-hidden="true" tabindex="-1"></a>            outputs <span class="op">=</span> model(<span class="op">**</span>batch)</span>
<span id="cb10-10"><a href="#cb10-10" aria-hidden="true" tabindex="-1"></a>            loss <span class="op">=</span> outputs.loss  <span class="co"># mean loss over this process's batch (a 0-d tensor)</span></span>
<span id="cb10-11"><a href="#cb10-11" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb10-12"><a href="#cb10-12" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Repeat the batch mean once per sample, then gather from all processes;</span></span>
<span id="cb10-13"><a href="#cb10-13" aria-hidden="true" tabindex="-1"></a>            <span class="co"># gather_for_metrics drops the duplicates used to pad the final batch</span></span>
<span id="cb10-14"><a href="#cb10-14" aria-hidden="true" tabindex="-1"></a>            batch_size <span class="op">=</span> batch[<span class="st">"labels"</span>].shape[<span class="dv">0</span>]</span>
<span id="cb10-15"><a href="#cb10-15" aria-hidden="true" tabindex="-1"></a>            losses.append(accelerator.gather_for_metrics(loss.repeat(batch_size)))</span>
<span id="cb10-16"><a href="#cb10-16" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-17"><a href="#cb10-17" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Average over all evaluation samples</span></span>
<span id="cb10-18"><a href="#cb10-18" aria-hidden="true" tabindex="-1"></a>    avg_loss <span class="op">=</span> torch.cat(losses).mean().item()</span>
<span id="cb10-19"><a href="#cb10-19" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> avg_loss</span></code></pre></div></div>
</div>
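<p>Distributed evaluation has a subtle pitfall: when the evaluation set does not divide evenly across processes, the last step is padded with duplicate samples so every process gets a full batch, and those duplicates bias any average unless they are dropped. The sketch below simulates two processes evaluating five samples; it assumes the duplicates sit at the end of the gathered results, which is the situation <code>gather_for_metrics</code> is meant to handle. <code>drop_padding</code> and <code>mean</code> are hypothetical helpers.</p>

```python
def drop_padding(gathered_steps, dataset_size):
    """Concatenate per-step gathered values and keep only the first dataset_size.

    Sketch of the de-duplication idea: padding duplicates are assumed to be at the end.
    """
    flat = [value for step in gathered_steps for value in step]
    return flat[:dataset_size]


def mean(values):
    return sum(values) / len(values)


if __name__ == "__main__":
    # Per-sample losses for a 5-sample eval set, 2 processes x batch size 2.
    # Step 1 gathers samples 0-3; step 2 gathers sample 4 plus padding copies of samples 0-2.
    step1 = [1.0, 2.0, 3.0, 4.0]
    step2 = [5.0, 1.0, 2.0, 3.0]
    print(mean(step1 + step2))                    # 2.625, biased by the duplicates
    print(mean(drop_padding([step1, step2], 5)))  # 3.0, the true per-sample mean
```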
</section>
</section>
<section id="multi-gpu-training" class="level2">
<h2 class="anchored" data-anchor-id="multi-gpu-training" id="multi-gpu-training">Multi-GPU Training</h2>
<section id="data-parallel-training" class="level3">
<h3 class="anchored" data-anchor-id="data-parallel-training" id="data-parallel-training">Data Parallel Training</h3>
<div id="13503e82" class="cell" data-execution_count="8">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb11"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> accelerate <span class="im">import</span> Accelerator</span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> train_multi_gpu(model, train_dataloader):</span>
<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a>    <span class="co"># model: your torch.nn.Module whose outputs have a .loss; train_dataloader: your DataLoader</span></span>
<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a>    accelerator <span class="op">=</span> Accelerator()</span>
<span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a>    optimizer <span class="op">=</span> torch.optim.Adam(model.parameters())</span>
<span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-9"><a href="#cb11-9" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Prepare for multi-GPU: each process gets its own model replica (wrapped in</span></span>
<span id="cb11-10"><a href="#cb11-10" aria-hidden="true" tabindex="-1"></a>    <span class="co"># DistributedDataParallel) and its own shard of the data</span></span>
<span id="cb11-11"><a href="#cb11-11" aria-hidden="true" tabindex="-1"></a>    model, optimizer, train_dataloader <span class="op">=</span> accelerator.prepare(</span>
<span id="cb11-12"><a href="#cb11-12" aria-hidden="true" tabindex="-1"></a>        model, optimizer, train_dataloader</span>
<span id="cb11-13"><a href="#cb11-13" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb11-14"><a href="#cb11-14" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-15"><a href="#cb11-15" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Training loop remains the same</span></span>
<span id="cb11-16"><a href="#cb11-16" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> batch <span class="kw">in</span> train_dataloader:</span>
<span id="cb11-17"><a href="#cb11-17" aria-hidden="true" tabindex="-1"></a>        outputs <span class="op">=</span> model(<span class="op">**</span>batch)</span>
<span id="cb11-18"><a href="#cb11-18" aria-hidden="true" tabindex="-1"></a>        loss <span class="op">=</span> outputs.loss</span>
<span id="cb11-19"><a href="#cb11-19" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-20"><a href="#cb11-20" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Gradients are averaged across processes during backward</span></span>
<span id="cb11-21"><a href="#cb11-21" aria-hidden="true" tabindex="-1"></a>        accelerator.backward(loss)</span>
<span id="cb11-22"><a href="#cb11-22" aria-hidden="true" tabindex="-1"></a>        optimizer.step()</span>
<span id="cb11-23"><a href="#cb11-23" aria-hidden="true" tabindex="-1"></a>        optimizer.zero_grad()</span></code></pre></div></div>
</div>
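<p>In data parallel training, synchronizing gradients means that after each backward pass the gradients of all replicas are averaged (an all-reduce), so every replica applies the same update and the copies never drift apart. The pure-Python sketch below uses a one-parameter linear model with a mean-squared-error loss (an illustrative choice) to show that, with equally sized shards, the averaged gradient equals the gradient over the full batch. <code>all_reduce_mean</code> is a stand-in for the real collective, not a library call.</p>

```python
def grad_mse(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def all_reduce_mean(values):
    """Stand-in for the all-reduce collective: every replica receives the average."""
    avg = sum(values) / len(values)
    return [avg] * len(values)


def data_parallel_step(weights, shards, lr):
    """One synchronized SGD step; weights[i] is replica i's copy of the parameter."""
    local_grads = [grad_mse(w, xs, ys) for w, (xs, ys) in zip(weights, shards)]
    synced_grads = all_reduce_mean(local_grads)
    return [w - lr * g for w, g in zip(weights, synced_grads)]


if __name__ == "__main__":
    xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
    ys = [3.0 * x for x in xs]                     # target: w = 3
    shards = [(xs[:4], ys[:4]), (xs[4:], ys[4:])]  # one shard per replica

    print(grad_mse(0.0, xs, ys))                                 # full batch: -153.0
    print(all_reduce_mean([grad_mse(0.0, *s) for s in shards]))  # [-153.0, -153.0]

    weights = [0.0, 0.0]
    for _ in range(50):
        weights = data_parallel_step(weights, shards, lr=0.01)
    print(weights)  # both replicas hold the same value, very close to 3.0
```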
</section>
<section id="launch-commands" class="level3">
<h3 class="anchored" data-anchor-id="launch-commands" id="launch-commands">Launch Commands</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb12"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb12-1"><a href="#cb12-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Launch on 4 GPUs</span></span>
<span id="cb12-2"><a href="#cb12-2" aria-hidden="true" tabindex="-1"></a><span class="ex">accelerate</span> launch <span class="at">--num_processes</span><span class="op">=</span>4 <span class="at">--multi_gpu</span> train.py</span>
<span id="cb12-3"><a href="#cb12-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-4"><a href="#cb12-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Launch with specific GPUs</span></span>
<span id="cb12-5"><a href="#cb12-5" aria-hidden="true" tabindex="-1"></a><span class="va">CUDA_VISIBLE_DEVICES</span><span class="op">=</span>0,1,3 <span class="ex">accelerate</span> launch <span class="at">--num_processes</span><span class="op">=</span>3 train.py</span>
<span id="cb12-6"><a href="#cb12-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-7"><a href="#cb12-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Launch on 2 nodes x 4 GPUs: run on every node, with --machine_rank=1 on the second node</span></span>
<span id="cb12-8"><a href="#cb12-8" aria-hidden="true" tabindex="-1"></a><span class="ex">accelerate</span> launch <span class="at">--multi_gpu</span> <span class="at">--num_processes</span><span class="op">=</span>8 <span class="at">--num_machines</span><span class="op">=</span>2 <span class="at">--machine_rank</span><span class="op">=</span>0 <span class="at">--main_process_ip</span><span class="op">=</span>192.168.1.1 <span class="at">--main_process_port</span><span class="op">=</span>29500 train.py</span></code></pre></div></div>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Multi-Node Training
</div>
</div>
<div class="callout-body-container callout-body">
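<p>For multi-node training, run the same <code>accelerate launch</code> command on every node, changing only <code>--machine_rank</code> (0 on the main node). <code>--num_processes</code> is the total across all nodes, and every node must be able to reach the main process at the given <code>--main_process_ip</code> and <code>--main_process_port</code>.</p>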
</div>
</div>
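<p>Each process ends up with a global rank derived from its node’s machine rank (the <code>--machine_rank</code> flag of <code>accelerate launch</code>) and its local index on that node. The sketch below shows the conventional <code>torch.distributed</code> layout, assuming every node runs the same number of processes; <code>process_layout</code> is a hypothetical helper, not an Accelerate function.</p>

```python
def process_layout(num_processes, num_machines):
    """Map each (machine_rank, local_rank) pair to a global rank.

    num_processes is the total across all machines, as with `accelerate launch`.
    """
    if num_processes % num_machines:
        raise ValueError("num_processes must be divisible by num_machines")
    per_machine = num_processes // num_machines
    return {
        (machine_rank, local_rank): machine_rank * per_machine + local_rank
        for machine_rank in range(num_machines)
        for local_rank in range(per_machine)
    }


if __name__ == "__main__":
    layout = process_layout(num_processes=8, num_machines=2)
    for (machine_rank, local_rank), global_rank in layout.items():
        main = " (main process)" if global_rank == 0 else ""
        print(f"node {machine_rank}, GPU {local_rank}: rank {global_rank}{main}")
```

<p>Only global rank 0, which lives on the node with machine rank 0, is the main process; that is why that node’s address is the one passed as the main process IP.</p>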
</section>
</section>
<section id="mixed-precision-training" class="level2">
<h2 class="anchored" data-anchor-id="mixed-precision-training" id="mixed-precision-training">Mixed Precision Training</h2>
<section id="automatic-mixed-precision" class="level3">
<h3 class="anchored" data-anchor-id="automatic-mixed-precision" id="automatic-mixed-precision">Automatic Mixed Precision</h3>
<div id="a4dc7fa7" class="cell" data-execution_count="9">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb13"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Enable mixed precision in config or during initialization</span></span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a>accelerator <span class="op">=</span> Accelerator(mixed_precision<span class="op">=</span><span class="st">"fp16"</span>)  <span class="co"># or "bf16"</span></span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-4"><a href="#cb13-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Training loop remains exactly the same</span></span>
<span id="cb13-5"><a href="#cb13-5" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> batch <span class="kw">in</span> train_dataloader:</span>
<span id="cb13-6"><a href="#cb13-6" aria-hidden="true" tabindex="-1"></a>    outputs <span class="op">=</span> model(<span class="op">**</span>batch)</span>
<span id="cb13-7"><a href="#cb13-7" aria-hidden="true" tabindex="-1"></a>    loss <span class="op">=</span> outputs.loss</span>
<span id="cb13-8"><a href="#cb13-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-9"><a href="#cb13-9" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Accelerate handles scaling automatically</span></span>
<span id="cb13-10"><a href="#cb13-10" aria-hidden="true" tabindex="-1"></a>    accelerator.backward(loss)</span>
<span id="cb13-11"><a href="#cb13-11" aria-hidden="true" tabindex="-1"></a>    optimizer.step()</span>
<span id="cb13-12"><a href="#cb13-12" aria-hidden="true" tabindex="-1"></a>    optimizer.zero_grad()</span></code></pre></div></div>
</div>
</section>
<section id="manual-mixed-precision-control" class="level3">
<h3 class="anchored" data-anchor-id="manual-mixed-precision-control" id="manual-mixed-precision-control">Manual Mixed Precision Control</h3>
<div id="4ddf06e4" class="cell" data-execution_count="10">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb14"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb14-1"><a href="#cb14-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Access the GradScaler Accelerate created (None when fp16 loss scaling is not in use)</span></span>
<span id="cb14-2"><a href="#cb14-2" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> accelerator.mixed_precision <span class="op">==</span> <span class="st">"fp16"</span> <span class="kw">and</span> accelerator.scaler <span class="kw">is</span> <span class="kw">not</span> <span class="va">None</span>:</span>
<span id="cb14-3"><a href="#cb14-3" aria-hidden="true" tabindex="-1"></a>    scaler <span class="op">=</span> accelerator.scaler</span>
<span id="cb14-4"><a href="#cb14-4" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Current loss scale: </span><span class="sc">{</span>scaler<span class="sc">.</span>get_scale()<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb14-5"><a href="#cb14-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-6"><a href="#cb14-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Manual scaling is usually not needed: accelerator.backward(loss) already</span></span>
<span id="cb14-7"><a href="#cb14-7" aria-hidden="true" tabindex="-1"></a><span class="co"># scales the loss, and the prepared optimizer's step() calls scaler.step()</span></span>
<span id="cb14-8"><a href="#cb14-8" aria-hidden="true" tabindex="-1"></a><span class="co"># and scaler.update(). Calling them yourself as well would duplicate that work.</span></span></code></pre></div></div>
</div>
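<p>What the scaler does behind <code>accelerator.backward()</code> can be sketched in a few lines. This is a simplified, pure-Python model of dynamic loss scaling; its defaults mirror PyTorch’s <code>GradScaler</code> (initial scale 2<sup>16</sup>, growth factor 2, backoff factor 0.5, growth interval 2000), but the class itself is illustrative, not PyTorch’s implementation.</p>

```python
import math


class DynamicLossScaler:
    """Simplified dynamic loss scaling, in the spirit of PyTorch's GradScaler."""

    def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0

    def scale_loss(self, loss):
        # Multiply the loss so small fp16 gradients don't underflow to zero.
        return loss * self.scale

    def step(self, scaled_grads, apply_update):
        """Unscale gradients, skip the update on inf/NaN, then adjust the scale.

        Returns True if apply_update was called with the unscaled gradients.
        """
        grads = [g / self.scale for g in scaled_grads]
        if not all(math.isfinite(g) for g in grads):
            # Overflow: skip this update and back off the scale.
            self.scale *= self.backoff_factor
            self._good_steps = 0
            return False
        apply_update(grads)
        self._good_steps += 1
        if self._good_steps == self.growth_interval:
            # A long run of finite steps: try a larger scale.
            self.scale *= self.growth_factor
            self._good_steps = 0
        return True
```

<p>A skipped update leaves the weights unchanged, so a few skipped steps early in fp16 training are expected while the scale settles; Accelerate reports this through <code>accelerator.optimizer_step_was_skipped</code>.</p>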
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Important</span>Precision Types
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
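<li><strong>fp16</strong>: Widely supported on NVIDIA GPUs and usually a large speedup, but its narrow range requires loss scaling (handled by Accelerate)</li>
<li><strong>bf16</strong>: Same exponent range as fp32, so no loss scaling is needed and training is more numerically stable; requires hardware support (e.g. NVIDIA Ampere or newer, TPUs)</li>
<li><strong>no</strong>: Full fp32 precision; slowest and most memory-hungry, but the most numerically stable</li>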
</ul>
</div>
</div>
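<p>The range difference behind these trade-offs is easy to see with Python’s <code>struct</code> module, which can encode IEEE half precision (format <code>"e"</code>). Python has no built-in bfloat16, so this sketch approximates it by keeping the top 16 bits of a float32 (truncation; hardware normally rounds to nearest instead).</p>

```python
import struct


def to_fp16(x):
    """Round-trip x through IEEE fp16; raises OverflowError when x is out of range."""
    return struct.unpack("=e", struct.pack("=e", x))[0]


def to_bf16(x):
    """Approximate bfloat16 by truncating a float32 to its top 16 bits."""
    bits = struct.unpack("=I", struct.pack("=f", x))[0]
    bits -= bits % 0x10000  # zero the low 16 bits (the tail of the mantissa)
    return struct.unpack("=f", struct.pack("=I", bits))[0]


if __name__ == "__main__":
    print(to_fp16(65504.0))   # 65504.0, the largest finite fp16 value
    try:
        to_fp16(70000.0)
    except OverflowError:
        print("70000.0 overflows fp16")
    print(to_bf16(70000.0))   # 69632.0: in range, but with coarse precision
    print(to_fp16(1e-8))      # 0.0: tiny gradients underflow in fp16
    print(to_bf16(1e-8))      # close to 1e-8
```

<p>fp16 loses both very large and very small values, which is why it needs loss scaling; bf16 keeps the range and gives up precision instead.</p>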
</section>
</section>
<section id="gradient-accumulation" class="level2">
<h2 class="anchored" data-anchor-id="gradient-accumulation" id="gradient-accumulation">Gradient Accumulation</h2>
<section id="basic-gradient-accumulation" class="level3">
<h3 class="anchored" data-anchor-id="basic-gradient-accumulation" id="basic-gradient-accumulation">Basic Gradient Accumulation</h3>
<div id="c6eb015e" class="cell" data-execution_count="11">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb15"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb15-1"><a href="#cb15-1" aria-hidden="true" tabindex="-1"></a>accelerator <span class="op">=</span> Accelerator(gradient_accumulation_steps<span class="op">=</span><span class="dv">4</span>)</span>
<span id="cb15-2"><a href="#cb15-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-3"><a href="#cb15-3" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> batch <span class="kw">in</span> train_dataloader:</span>
<span id="cb15-4"><a href="#cb15-4" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Use accumulate context manager</span></span>
<span id="cb15-5"><a href="#cb15-5" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> accelerator.accumulate(model):</span>
<span id="cb15-6"><a href="#cb15-6" aria-hidden="true" tabindex="-1"></a>        outputs <span class="op">=</span> model(<span class="op">**</span>batch)</span>
<span id="cb15-7"><a href="#cb15-7" aria-hidden="true" tabindex="-1"></a>        loss <span class="op">=</span> outputs.loss</span>
<span id="cb15-8"><a href="#cb15-8" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-9"><a href="#cb15-9" aria-hidden="true" tabindex="-1"></a>        accelerator.backward(loss)</span>
<span id="cb15-10"><a href="#cb15-10" aria-hidden="true" tabindex="-1"></a>        optimizer.step()</span>
<span id="cb15-11"><a href="#cb15-11" aria-hidden="true" tabindex="-1"></a>        optimizer.zero_grad()</span></code></pre></div></div>
</div>
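<p>Inside <code>accumulate()</code>, Accelerate decides on each batch whether it is a sync step. On the other steps it skips the cross-process gradient sync and the prepared optimizer’s <code>step()</code> and <code>zero_grad()</code> do nothing, which is why the loop can call them on every batch; <code>accelerator.backward()</code> also divides the loss by <code>gradient_accumulation_steps</code>. A sketch of that schedule, assuming a sync every <code>gradient_accumulation_steps</code> batches plus one at the end of the dataloader so trailing gradients are not lost (<code>sync_schedule</code> is a hypothetical helper):</p>

```python
def sync_schedule(num_batches, gradient_accumulation_steps):
    """Return the 1-based batch numbers on which gradients would be synced and applied."""
    sync_steps = []
    for step in range(1, num_batches + 1):
        end_of_dataloader = step == num_batches
        if step % gradient_accumulation_steps == 0 or end_of_dataloader:
            sync_steps.append(step)
    return sync_steps


if __name__ == "__main__":
    print(sync_schedule(10, 4))  # [4, 8, 10]: the last update uses only 2 batches
```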
</section>
<section id="dynamic-gradient-accumulation" class="level3">
<h3 class="anchored" data-anchor-id="dynamic-gradient-accumulation" id="dynamic-gradient-accumulation">Manual Gradient Accumulation</h3>
<div id="eb654bca" class="cell" data-execution_count="12">
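<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb16"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb16-1"><a href="#cb16-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> contextlib</span>
<span id="cb16-2"><a href="#cb16-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-3"><a href="#cb16-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> train_with_manual_accumulation(model, optimizer, train_dataloader, accelerator, accumulation_steps<span class="op">=</span><span class="dv">2</span>):</span>
<span id="cb16-4"><a href="#cb16-4" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Assumes accelerator was created without gradient_accumulation_steps,</span></span>
<span id="cb16-5"><a href="#cb16-5" aria-hidden="true" tabindex="-1"></a>    <span class="co"># so accelerator.backward() does not rescale the loss a second time</span></span>
<span id="cb16-6"><a href="#cb16-6" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> i, batch <span class="kw">in</span> <span class="bu">enumerate</span>(train_dataloader):</span>
<span id="cb16-7"><a href="#cb16-7" aria-hidden="true" tabindex="-1"></a>        is_update_step <span class="op">=</span> (i <span class="op">+</span> <span class="dv">1</span>) <span class="op">%</span> accumulation_steps <span class="op">==</span> <span class="dv">0</span></span>
<span id="cb16-8"><a href="#cb16-8" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Skip the cross-process gradient sync on steps that don't update the weights</span></span>
<span id="cb16-9"><a href="#cb16-9" aria-hidden="true" tabindex="-1"></a>        sync_context <span class="op">=</span> contextlib.nullcontext() <span class="cf">if</span> is_update_step <span class="cf">else</span> accelerator.no_sync(model)</span>
<span id="cb16-10"><a href="#cb16-10" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb16-11"><a href="#cb16-11" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> sync_context:</span>
<span id="cb16-12"><a href="#cb16-12" aria-hidden="true" tabindex="-1"></a>            outputs <span class="op">=</span> model(<span class="op">**</span>batch)</span>
<span id="cb16-13"><a href="#cb16-13" aria-hidden="true" tabindex="-1"></a>            loss <span class="op">=</span> outputs.loss <span class="op">/</span> accumulation_steps  <span class="co"># Scale loss so the accumulated gradient is an average</span></span>
<span id="cb16-14"><a href="#cb16-14" aria-hidden="true" tabindex="-1"></a>            accelerator.backward(loss)</span>
<span id="cb16-15"><a href="#cb16-15" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb16-16"><a href="#cb16-16" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> is_update_step:</span>
<span id="cb16-17"><a href="#cb16-17" aria-hidden="true" tabindex="-1"></a>            optimizer.step()</span>
<span id="cb16-18"><a href="#cb16-18" aria-hidden="true" tabindex="-1"></a>            optimizer.zero_grad()</span></code></pre></div></div>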
</div>
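<p>Why dividing the loss by <code>accumulation_steps</code> works can be checked without a GPU. The sketch below is plain Python rather than Accelerate code, and the helper names (<code>mse_grad</code>, <code>accumulated_grad</code>) and the data are hypothetical: for a one-weight linear model with a mean-squared-error loss, summing the gradients of equally sized micro-batches, each scaled by <code>1 / accumulation_steps</code>, reproduces the gradient of the full batch.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import math


def mse_grad(w, xs, ys):
    # Gradient of mean((w * x - y) ** 2) with respect to the scalar weight w
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, xs, ys, accumulation_steps):
    # Split the batch into equal micro-batches, scale each micro-batch
    # gradient by 1 / accumulation_steps (the effect of dividing the loss),
    # and sum the results, as repeated backward() calls would
    size = len(xs) // accumulation_steps
    total = 0.0
    for step in range(accumulation_steps):
        start = step * size
        micro_x = xs[start:start + size]
        micro_y = ys[start:start + size]
        total += mse_grad(w, micro_x, micro_y) / accumulation_steps
    return total


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.1, 3.9, 6.2, 7.8, 10.1, 12.2, 13.8, 16.1]
full = mse_grad(0.5, xs, ys)
accumulated = accumulated_grad(0.5, xs, ys, accumulation_steps=4)
print(full, accumulated, math.isclose(full, accumulated))
</code></pre></div></div>
<p>The equality assumes equally sized micro-batches; a shorter final micro-batch at the end of an epoch weights its samples slightly differently.</p>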
</section>
</section>
<section id="deepspeed-integration" class="level2">
<h2 class="anchored" data-anchor-id="deepspeed-integration" id="deepspeed-integration">DeepSpeed Integration</h2>
<section id="deepspeed-configuration" class="level3">
<h3 class="anchored" data-anchor-id="deepspeed-configuration" id="deepspeed-configuration">DeepSpeed Configuration</h3>
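<p>Create a DeepSpeed config file (<code>ds_config.json</code>). Because this file defines an <code>optimizer</code> block, Accelerate expects the training script to pass a placeholder <code>accelerate.utils.DummyOptim</code> to <code>prepare()</code> instead of a regular PyTorch optimizer; alternatively, remove the <code>optimizer</code> block and keep the optimizer in your code:</p>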
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb17"><pre class="sourceCode json code-with-copy"><code class="sourceCode json"><span id="cb17-1"><a href="#cb17-1" aria-hidden="true" tabindex="-1"></a><span class="fu">{</span></span>
<span id="cb17-2"><a href="#cb17-2" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"train_batch_size"</span><span class="fu">:</span> <span class="dv">32</span><span class="fu">,</span></span>
<span id="cb17-3"><a href="#cb17-3" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"gradient_accumulation_steps"</span><span class="fu">:</span> <span class="dv">1</span><span class="fu">,</span></span>
<span id="cb17-4"><a href="#cb17-4" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"optimizer"</span><span class="fu">:</span> <span class="fu">{</span></span>
<span id="cb17-5"><a href="#cb17-5" aria-hidden="true" tabindex="-1"></a>        <span class="dt">"type"</span><span class="fu">:</span> <span class="st">"Adam"</span><span class="fu">,</span></span>
<span id="cb17-6"><a href="#cb17-6" aria-hidden="true" tabindex="-1"></a>        <span class="dt">"params"</span><span class="fu">:</span> <span class="fu">{</span></span>
<span id="cb17-7"><a href="#cb17-7" aria-hidden="true" tabindex="-1"></a>            <span class="dt">"lr"</span><span class="fu">:</span> <span class="dv">5e-5</span></span>
<span id="cb17-8"><a href="#cb17-8" aria-hidden="true" tabindex="-1"></a>        <span class="fu">}</span></span>
<span id="cb17-9"><a href="#cb17-9" aria-hidden="true" tabindex="-1"></a>    <span class="fu">},</span></span>
<span id="cb17-10"><a href="#cb17-10" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"fp16"</span><span class="fu">:</span> <span class="fu">{</span></span>
<span id="cb17-11"><a href="#cb17-11" aria-hidden="true" tabindex="-1"></a>        <span class="dt">"enabled"</span><span class="fu">:</span> <span class="kw">true</span></span>
<span id="cb17-12"><a href="#cb17-12" aria-hidden="true" tabindex="-1"></a>    <span class="fu">},</span></span>
<span id="cb17-13"><a href="#cb17-13" aria-hidden="true" tabindex="-1"></a>    <span class="dt">"zero_optimization"</span><span class="fu">:</span> <span class="fu">{</span></span>
<span id="cb17-14"><a href="#cb17-14" aria-hidden="true" tabindex="-1"></a>        <span class="dt">"stage"</span><span class="fu">:</span> <span class="dv">2</span></span>
<span id="cb17-15"><a href="#cb17-15" aria-hidden="true" tabindex="-1"></a>    <span class="fu">}</span></span>
<span id="cb17-16"><a href="#cb17-16" aria-hidden="true" tabindex="-1"></a><span class="fu">}</span></span></code></pre></div></div>
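<p>DeepSpeed requires the batch-size fields to agree: <code>train_batch_size</code> must equal <code>train_micro_batch_size_per_gpu * gradient_accumulation_steps * world_size</code>. The helper below is a hypothetical sanity check (not part of DeepSpeed or Accelerate) that derives the per-GPU micro-batch size implied by a config dict like the one above and rejects combinations that do not divide evenly.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">def micro_batch_per_gpu(config, world_size):
    # DeepSpeed's constraint:
    # train_batch_size == train_micro_batch_size_per_gpu
    #                     * gradient_accumulation_steps * world_size
    total = config["train_batch_size"]
    accumulation = config.get("gradient_accumulation_steps", 1)
    micro_batches_per_step = accumulation * world_size
    if total % micro_batches_per_step != 0:
        raise ValueError(
            f"train_batch_size={total} is not divisible by "
            f"gradient_accumulation_steps * world_size = {micro_batches_per_step}"
        )
    micro = total // micro_batches_per_step
    # If the config also pins the micro-batch size, it must match
    pinned = config.get("train_micro_batch_size_per_gpu", micro)
    if pinned != micro:
        raise ValueError(f"train_micro_batch_size_per_gpu={pinned}, expected {micro}")
    return micro


# Mirrors the batch-size fields of ds_config.json above
config = {"train_batch_size": 32, "gradient_accumulation_steps": 1}
print(micro_batch_per_gpu(config, world_size=4))
</code></pre></div></div>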
</section>
<section id="using-deepspeed" class="level3">
<h3 class="anchored" data-anchor-id="using-deepspeed" id="using-deepspeed">Using DeepSpeed</h3>
<div id="46dace9e" class="cell" data-execution_count="13">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb18"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb18-1"><a href="#cb18-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> accelerate <span class="im">import</span> Accelerator, DeepSpeedPlugin</span>
<span id="cb18-2"><a href="#cb18-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-3"><a href="#cb18-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Option 1: point the plugin at the DeepSpeed JSON config file</span></span>
<span id="cb18-4"><a href="#cb18-4" aria-hidden="true" tabindex="-1"></a>ds_plugin <span class="op">=</span> DeepSpeedPlugin(hf_ds_config<span class="op">=</span><span class="st">"ds_config.json"</span>)</span>
<span id="cb18-5"><a href="#cb18-5" aria-hidden="true" tabindex="-1"></a>accelerator <span class="op">=</span> Accelerator(deepspeed_plugin<span class="op">=</span>ds_plugin)</span>
<span id="cb18-6"><a href="#cb18-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-7"><a href="#cb18-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Option 2: configure DeepSpeed programmatically</span></span>
<span id="cb18-8"><a href="#cb18-8" aria-hidden="true" tabindex="-1"></a><span class="co"># (use one option or the other, not both)</span></span>
<span id="cb18-9"><a href="#cb18-9" aria-hidden="true" tabindex="-1"></a>ds_plugin <span class="op">=</span> DeepSpeedPlugin(</span>
<span id="cb18-10"><a href="#cb18-10" aria-hidden="true" tabindex="-1"></a>    gradient_accumulation_steps<span class="op">=</span><span class="dv">4</span>,</span>
<span id="cb18-11"><a href="#cb18-11" aria-hidden="true" tabindex="-1"></a>    zero_stage<span class="op">=</span><span class="dv">2</span>,</span>
<span id="cb18-12"><a href="#cb18-12" aria-hidden="true" tabindex="-1"></a>    offload_optimizer_device<span class="op">=</span><span class="st">"cpu"</span></span>
<span id="cb18-13"><a href="#cb18-13" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb18-14"><a href="#cb18-14" aria-hidden="true" tabindex="-1"></a>accelerator <span class="op">=</span> Accelerator(deepspeed_plugin<span class="op">=</span>ds_plugin)</span>
<span id="cb18-15"><a href="#cb18-15" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-16"><a href="#cb18-16" aria-hidden="true" tabindex="-1"></a><span class="co"># Training code remains the same; passing the dataloader lets Accelerate</span></span>
<span id="cb18-17"><a href="#cb18-17" aria-hidden="true" tabindex="-1"></a><span class="co"># infer DeepSpeed's per-GPU micro-batch size</span></span>
<span id="cb18-18"><a href="#cb18-18" aria-hidden="true" tabindex="-1"></a>model, optimizer, train_dataloader <span class="op">=</span> accelerator.prepare(model, optimizer, train_dataloader)</span></code></pre></div></div>
</div>
</section>
<section id="launch-with-deepspeed" class="level3">
<h3 class="anchored" data-anchor-id="launch-with-deepspeed" id="launch-with-deepspeed">Launch with DeepSpeed</h3>
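<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb19"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb19-1"><a href="#cb19-1" aria-hidden="true" tabindex="-1"></a><span class="ex">accelerate</span> launch <span class="at">--config_file</span> ds_config.yaml train.py</span></code></pre></div></div>
<p>Here <code>ds_config.yaml</code> is an Accelerate launcher config (for example, one generated by running <code>accelerate config</code> and selecting DeepSpeed), not the DeepSpeed JSON file shown above. To use the JSON file directly, pass <code>--use_deepspeed --deepspeed_config_file ds_config.json</code> to <code>accelerate launch</code> instead.</p>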
</section>
</section>
<section id="troubleshooting" class="level2">
<h2 class="anchored" data-anchor-id="troubleshooting" id="troubleshooting">Troubleshooting</h2>
<section id="common-issues-and-solutions" class="level3">
<h3 class="anchored" data-anchor-id="common-issues-and-solutions" id="common-issues-and-solutions">Common Issues and Solutions</h3>
<section id="memory-issues" class="level4">
<h4 class="anchored" data-anchor-id="memory-issues">Memory Issues</h4>
<div id="f8fb08fb" class="cell" data-execution_count="14">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb20"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb20-1"><a href="#cb20-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb20-2"><a href="#cb20-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb20-3"><a href="#cb20-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Return cached but unused GPU memory to the driver (live tensors are not freed).</span></span>
<span id="cb20-4"><a href="#cb20-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Call it on every process: each process manages its own GPU.</span></span>
<span id="cb20-5"><a href="#cb20-5" aria-hidden="true" tabindex="-1"></a>torch.cuda.empty_cache()</span>
<span id="cb20-6"><a href="#cb20-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb20-7"><a href="#cb20-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Use gradient checkpointing: recompute activations during the backward pass</span></span>
<span id="cb20-8"><a href="#cb20-8" aria-hidden="true" tabindex="-1"></a><span class="co"># to save memory (a Hugging Face Transformers model method)</span></span>
<span id="cb20-9"><a href="#cb20-9" aria-hidden="true" tabindex="-1"></a>model.gradient_checkpointing_enable()</span>
<span id="cb20-10"><a href="#cb20-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb20-11"><a href="#cb20-11" aria-hidden="true" tabindex="-1"></a><span class="co"># Reduce batch size or increase gradient accumulation</span></span></code></pre></div></div>
</div>
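<p>When lowering the per-device batch size to save memory, gradient accumulation can keep the effective batch size per optimizer step unchanged. With Accelerate's default (non-split) batching, each optimizer step sees <code>per_device_batch * num_processes * gradient_accumulation_steps</code> samples. The helper below is a hypothetical sketch of that arithmetic, not an Accelerate API.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import math


def accumulation_steps_for(target_batch, per_device_batch, num_processes):
    # Effective batch per optimizer step =
    #   per_device_batch * num_processes * gradient_accumulation_steps.
    # Round up so the effective batch is at least the target.
    samples_per_micro_step = per_device_batch * num_processes
    return math.ceil(target_batch / samples_per_micro_step)


# Halving the per-device batch from 32 to 16 on 2 GPUs keeps an
# effective batch of 256 by doubling accumulation from 4 to 8
print(accumulation_steps_for(256, 32, 2))
print(accumulation_steps_for(256, 16, 2))
</code></pre></div></div>
<p>The returned value is what you would pass as <code>gradient_accumulation_steps</code> when creating the <code>Accelerator</code>.</p>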
</section>
<section id="synchronization-issues" class="level4">
<h4 class="anchored" data-anchor-id="synchronization-issues">Synchronization Issues</h4>
<div id="7b499b25" class="cell" data-execution_count="15">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb21"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb21-1"><a href="#cb21-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Wait for all processes</span></span>
<span id="cb21-2"><a href="#cb21-2" aria-hidden="true" tabindex="-1"></a>accelerator.wait_for_everyone()</span>
<span id="cb21-3"><a href="#cb21-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-4"><a href="#cb21-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Gather data from all processes</span></span>
<span id="cb21-5"><a href="#cb21-5" aria-hidden="true" tabindex="-1"></a>all_losses <span class="op">=</span> accelerator.gather(loss)</span>
<span id="cb21-6"><a href="#cb21-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-7"><a href="#cb21-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Reduce across processes</span></span>
<span id="cb21-8"><a href="#cb21-8" aria-hidden="true" tabindex="-1"></a>avg_loss <span class="op">=</span> accelerator.<span class="bu">reduce</span>(loss, reduction<span class="op">=</span><span class="st">"mean"</span>)</span></code></pre></div></div>
</div>
</section>
<section id="debugging" class="level4">
<h4 class="anchored" data-anchor-id="debugging">Debugging</h4>
<div id="3f1eed30" class="cell" data-execution_count="16">
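<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb22"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb22-1"><a href="#cb22-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> os</span>
<span id="cb22-2"><a href="#cb22-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb22-3"><a href="#cb22-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> accelerate <span class="im">import</span> Accelerator</span>
<span id="cb22-4"><a href="#cb22-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> accelerate.utils <span class="im">import</span> DistributedType</span>
<span id="cb22-5"><a href="#cb22-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb22-6"><a href="#cb22-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Enable debug mode (verifies tensor shapes before collective operations):</span></span>
<span id="cb22-7"><a href="#cb22-7" aria-hidden="true" tabindex="-1"></a><span class="co"># run `accelerate launch --debug train.py`, or set this environment variable</span></span>
<span id="cb22-8"><a href="#cb22-8" aria-hidden="true" tabindex="-1"></a><span class="co"># before the first Accelerator is created</span></span>
<span id="cb22-9"><a href="#cb22-9" aria-hidden="true" tabindex="-1"></a>os.environ[<span class="st">"ACCELERATE_DEBUG_MODE"</span>] <span class="op">=</span> <span class="st">"1"</span></span>
<span id="cb22-10"><a href="#cb22-10" aria-hidden="true" tabindex="-1"></a>accelerator <span class="op">=</span> Accelerator()</span>
<span id="cb22-11"><a href="#cb22-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb22-12"><a href="#cb22-12" aria-hidden="true" tabindex="-1"></a><span class="co"># Check if running in distributed mode</span></span>
<span id="cb22-13"><a href="#cb22-13" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> accelerator.distributed_type <span class="op">!=</span> DistributedType.NO:</span>
<span id="cb22-14"><a href="#cb22-14" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Running on </span><span class="sc">{</span>accelerator<span class="sc">.</span>num_processes<span class="sc">}</span><span class="ss"> processes"</span>)</span>
<span id="cb22-15"><a href="#cb22-15" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb22-16"><a href="#cb22-16" aria-hidden="true" tabindex="-1"></a><span class="co"># Print only on main process</span></span>
<span id="cb22-17"><a href="#cb22-17" aria-hidden="true" tabindex="-1"></a>accelerator.<span class="bu">print</span>(<span class="st">"This will only print once"</span>)</span></code></pre></div></div>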
</div>
</section>
</section>
<section id="performance-tips" class="level3">
<h3 class="anchored" data-anchor-id="performance-tips" id="performance-tips">Performance Tips</h3>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Optimization Strategies
</div>
</div>
<div class="callout-body-container callout-body">
<ol type="1">
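<li><strong>Use appropriate batch sizes</strong>: Larger batch sizes generally improve GPU utilization, up to the limit of available memory</li>
<li><strong>Enable mixed precision</strong>: Use fp16 or bf16 for faster training and lower memory use (bf16 requires hardware support, such as NVIDIA Ampere or newer GPUs)</li>
<li><strong>Gradient accumulation</strong>: Simulate a larger effective batch size without holding the activations of that larger batch in memory at once</li>
<li><strong>DataLoader optimization</strong>: Set <code>num_workers</code> to load batches in background worker processes, and <code>pin_memory=True</code> to speed up host-to-GPU copies</li>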
<li><strong>Compile models</strong>: Use <code>torch.compile()</code> for PyTorch 2.0+</li>
</ol>
</div>
</div>
<div id="e9e3b453" class="cell" data-execution_count="17">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb23"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb23-1"><a href="#cb23-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb23-2"><a href="#cb23-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.utils.data <span class="im">import</span> DataLoader</span>
<span id="cb23-3"><a href="#cb23-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> accelerate <span class="im">import</span> Accelerator</span>
<span id="cb23-4"><a href="#cb23-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb23-5"><a href="#cb23-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Optimized setup</span></span>
<span id="cb23-6"><a href="#cb23-6" aria-hidden="true" tabindex="-1"></a>accelerator <span class="op">=</span> Accelerator(</span>
<span id="cb23-7"><a href="#cb23-7" aria-hidden="true" tabindex="-1"></a>    mixed_precision<span class="op">=</span><span class="st">"bf16"</span>,</span>
<span id="cb23-8"><a href="#cb23-8" aria-hidden="true" tabindex="-1"></a>    gradient_accumulation_steps<span class="op">=</span><span class="dv">4</span></span>
<span id="cb23-9"><a href="#cb23-9" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb23-10"><a href="#cb23-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb23-11"><a href="#cb23-11" aria-hidden="true" tabindex="-1"></a><span class="co"># Compile model (PyTorch 2.0+)</span></span>
<span id="cb23-12"><a href="#cb23-12" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> torch.<span class="bu">compile</span>(model)</span>
<span id="cb23-13"><a href="#cb23-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb23-14"><a href="#cb23-14" aria-hidden="true" tabindex="-1"></a><span class="co"># Optimized DataLoader (`dataset` is your torch Dataset)</span></span>
<span id="cb23-15"><a href="#cb23-15" aria-hidden="true" tabindex="-1"></a>train_dataloader <span class="op">=</span> DataLoader(</span>
<span id="cb23-16"><a href="#cb23-16" aria-hidden="true" tabindex="-1"></a>    dataset,</span>
<span id="cb23-17"><a href="#cb23-17" aria-hidden="true" tabindex="-1"></a>    batch_size<span class="op">=</span><span class="dv">32</span>,</span>
<span id="cb23-18"><a href="#cb23-18" aria-hidden="true" tabindex="-1"></a>    num_workers<span class="op">=</span><span class="dv">4</span>,</span>
<span id="cb23-19"><a href="#cb23-19" aria-hidden="true" tabindex="-1"></a>    pin_memory<span class="op">=</span><span class="va">True</span>,</span>
<span id="cb23-20"><a href="#cb23-20" aria-hidden="true" tabindex="-1"></a>    shuffle<span class="op">=</span><span class="va">True</span></span>
<span id="cb23-21"><a href="#cb23-21" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
</div>
</section>
</section>
<section id="complete-example-bert-fine-tuning" class="level2">
<h2 class="anchored" data-anchor-id="complete-example-bert-fine-tuning" id="complete-example-bert-fine-tuning">Complete Example: BERT Fine-tuning</h2>
<p>Here’s a complete example that fine-tunes BERT for binary sentiment classification on the IMDb dataset:</p>
<div id="c638064d" class="cell" data-execution_count="18">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb24"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb24-1"><a href="#cb24-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb24-2"><a href="#cb24-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.utils.data <span class="im">import</span> DataLoader</span>
<span id="cb24-3"><a href="#cb24-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> transformers <span class="im">import</span> AutoTokenizer, AutoModelForSequenceClassification</span>
<span id="cb24-4"><a href="#cb24-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> accelerate <span class="im">import</span> Accelerator</span>
<span id="cb24-5"><a href="#cb24-5" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> datasets <span class="im">import</span> load_dataset</span>
<span id="cb24-6"><a href="#cb24-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb24-7"><a href="#cb24-7" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> main():</span>
<span id="cb24-8"><a href="#cb24-8" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Initialize</span></span>
<span id="cb24-9"><a href="#cb24-9" aria-hidden="true" tabindex="-1"></a>    accelerator <span class="op">=</span> Accelerator(mixed_precision<span class="op">=</span><span class="st">"fp16"</span>)</span>
<span id="cb24-10"><a href="#cb24-10" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb24-11"><a href="#cb24-11" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Load data</span></span>
<span id="cb24-12"><a href="#cb24-12" aria-hidden="true" tabindex="-1"></a>    dataset <span class="op">=</span> load_dataset(<span class="st">"imdb"</span>)</span>
<span id="cb24-13"><a href="#cb24-13" aria-hidden="true" tabindex="-1"></a>    tokenizer <span class="op">=</span> AutoTokenizer.from_pretrained(<span class="st">"bert-base-uncased"</span>)</span>
<span id="cb24-14"><a href="#cb24-14" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb24-15"><a href="#cb24-15" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> tokenize_function(examples):</span>
<span id="cb24-16"><a href="#cb24-16" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> tokenizer(examples[<span class="st">"text"</span>], truncation<span class="op">=</span><span class="va">True</span>, padding<span class="op">=</span><span class="st">"max_length"</span>)</span>
<span id="cb24-17"><a href="#cb24-17" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb24-18"><a href="#cb24-18" aria-hidden="true" tabindex="-1"></a>    tokenized_datasets <span class="op">=</span> dataset.<span class="bu">map</span>(tokenize_function, batched<span class="op">=</span><span class="va">True</span>, remove_columns<span class="op">=</span>[<span class="st">"text"</span>])  <span class="co"># drop raw text so model(**batch) only gets tensors</span></span>
<span id="cb24-19"><a href="#cb24-19" aria-hidden="true" tabindex="-1"></a>    train_dataset <span class="op">=</span> tokenized_datasets[<span class="st">"train"</span>].rename_column(<span class="st">"label"</span>, <span class="st">"labels"</span>).with_format(<span class="st">"torch"</span>)</span>
<span id="cb24-20"><a href="#cb24-20" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb24-21"><a href="#cb24-21" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Model and optimizer</span></span>
<span id="cb24-22"><a href="#cb24-22" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> AutoModelForSequenceClassification.from_pretrained(</span>
<span id="cb24-23"><a href="#cb24-23" aria-hidden="true" tabindex="-1"></a>        <span class="st">"bert-base-uncased"</span>, num_labels<span class="op">=</span><span class="dv">2</span></span>
<span id="cb24-24"><a href="#cb24-24" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb24-25"><a href="#cb24-25" aria-hidden="true" tabindex="-1"></a>    optimizer <span class="op">=</span> torch.optim.AdamW(model.parameters(), lr<span class="op">=</span><span class="fl">5e-5</span>)</span>
<span id="cb24-26"><a href="#cb24-26" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb24-27"><a href="#cb24-27" aria-hidden="true" tabindex="-1"></a>    <span class="co"># DataLoader</span></span>
<span id="cb24-28"><a href="#cb24-28" aria-hidden="true" tabindex="-1"></a>    train_dataloader <span class="op">=</span> DataLoader(train_dataset, batch_size<span class="op">=</span><span class="dv">16</span>, shuffle<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb24-29"><a href="#cb24-29" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb24-30"><a href="#cb24-30" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Prepare everything</span></span>
<span id="cb24-31"><a href="#cb24-31" aria-hidden="true" tabindex="-1"></a>    model, optimizer, train_dataloader <span class="op">=</span> accelerator.prepare(</span>
<span id="cb24-32"><a href="#cb24-32" aria-hidden="true" tabindex="-1"></a>        model, optimizer, train_dataloader</span>
<span id="cb24-33"><a href="#cb24-33" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb24-34"><a href="#cb24-34" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb24-35"><a href="#cb24-35" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Training loop</span></span>
<span id="cb24-36"><a href="#cb24-36" aria-hidden="true" tabindex="-1"></a>    num_epochs <span class="op">=</span> <span class="dv">3</span></span>
<span id="cb24-37"><a href="#cb24-37" aria-hidden="true" tabindex="-1"></a>    model.train()</span>
<span id="cb24-38"><a href="#cb24-38" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb24-39"><a href="#cb24-39" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(num_epochs):</span>
<span id="cb24-40"><a href="#cb24-40" aria-hidden="true" tabindex="-1"></a>        total_loss <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb24-41"><a href="#cb24-41" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> step, batch <span class="kw">in</span> <span class="bu">enumerate</span>(train_dataloader):</span>
<span id="cb24-42"><a href="#cb24-42" aria-hidden="true" tabindex="-1"></a>            outputs <span class="op">=</span> model(<span class="op">**</span>batch)</span>
<span id="cb24-43"><a href="#cb24-43" aria-hidden="true" tabindex="-1"></a>            loss <span class="op">=</span> outputs.loss</span>
<span id="cb24-44"><a href="#cb24-44" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb24-45"><a href="#cb24-45" aria-hidden="true" tabindex="-1"></a>            accelerator.backward(loss)</span>
<span id="cb24-46"><a href="#cb24-46" aria-hidden="true" tabindex="-1"></a>            optimizer.step()</span>
<span id="cb24-47"><a href="#cb24-47" aria-hidden="true" tabindex="-1"></a>            optimizer.zero_grad()</span>
<span id="cb24-48"><a href="#cb24-48" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb24-49"><a href="#cb24-49" aria-hidden="true" tabindex="-1"></a>            total_loss <span class="op">+=</span> loss.item()</span>
<span id="cb24-50"><a href="#cb24-50" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb24-51"><a href="#cb24-51" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> step <span class="op">%</span> <span class="dv">100</span> <span class="op">==</span> <span class="dv">0</span> <span class="kw">and</span> accelerator.is_main_process:</span>
<span id="cb24-52"><a href="#cb24-52" aria-hidden="true" tabindex="-1"></a>                <span class="bu">print</span>(<span class="ss">f"Epoch </span><span class="sc">{</span>epoch<span class="sc">}</span><span class="ss">, Step </span><span class="sc">{</span>step<span class="sc">}</span><span class="ss">, Loss: </span><span class="sc">{</span>loss<span class="sc">.</span>item()<span class="sc">:.4f}</span><span class="ss">"</span>)</span>
<span id="cb24-53"><a href="#cb24-53" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb24-54"><a href="#cb24-54" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> accelerator.is_main_process:</span>
<span id="cb24-55"><a href="#cb24-55" aria-hidden="true" tabindex="-1"></a>            avg_loss <span class="op">=</span> total_loss <span class="op">/</span> <span class="bu">len</span>(train_dataloader)</span>
<span id="cb24-56"><a href="#cb24-56" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f"Epoch </span><span class="sc">{</span>epoch<span class="sc">}</span><span class="ss"> completed. Average loss: </span><span class="sc">{</span>avg_loss<span class="sc">:.4f}</span><span class="ss">"</span>)</span>
<span id="cb24-57"><a href="#cb24-57" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb24-58"><a href="#cb24-58" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Save model</span></span>
<span id="cb24-59"><a href="#cb24-59" aria-hidden="true" tabindex="-1"></a>    accelerator.wait_for_everyone()</span>
<span id="cb24-60"><a href="#cb24-60" aria-hidden="true" tabindex="-1"></a>    unwrapped_model <span class="op">=</span> accelerator.unwrap_model(model)</span>
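<span id="cb24-61"><a href="#cb24-61" aria-hidden="true" tabindex="-1"></a>    unwrapped_model.save_pretrained(<span class="st">"./fine_tuned_bert"</span>, is_main_process<span class="op">=</span>accelerator.is_main_process, save_function<span class="op">=</span>accelerator.save)</span>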
<span id="cb24-62"><a href="#cb24-62" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb24-63"><a href="#cb24-63" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> accelerator.is_main_process:</span>
<span id="cb24-64"><a href="#cb24-64" aria-hidden="true" tabindex="-1"></a>        tokenizer.save_pretrained(<span class="st">"./fine_tuned_bert"</span>)</span>
<span id="cb24-65"><a href="#cb24-65" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb24-66"><a href="#cb24-66" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb24-67"><a href="#cb24-67" aria-hidden="true" tabindex="-1"></a>    main()</span></code></pre></div></div>
</div>
</section>
<section id="summary" class="level2">
<h2 class="anchored" data-anchor-id="summary" id="summary">Summary</h2>
<p>This guide covers the essential aspects of using Hugging Face Accelerate for distributed training. The library abstracts away much of the complexity while providing fine-grained control when needed. Key takeaways:</p>
<ul>
<li><strong>Simplicity</strong>: Minimal code changes required for distributed training</li>
<li><strong>Flexibility</strong>: Works with any PyTorch model and training loop</li>
<li><strong>Performance</strong>: Built-in support for mixed precision and gradient accumulation</li>
<li><strong>Scalability</strong>: Easy scaling from single GPU to multi-node training</li>
<li><strong>Integration</strong>: Seamless integration with popular frameworks like DeepSpeed</li>
</ul>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Next Steps
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li>Explore the <a href="https://huggingface.co/docs/accelerate/">official Accelerate documentation</a></li>
<li>Try the examples with your own models and datasets</li>
<li>Experiment with different optimization strategies</li>
<li>Consider advanced features like FSDP for very large models</li>
</ul>
</div>
</div>


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[PyTorch Lightning Fabric Code Guide]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/distributed/pytorch-fabric/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/distributed/pytorch-fabric/</guid>
      <pubDate>Tue, 03 Jun 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>tutorial</category>
      <category>beginner</category>
      <content:encoded><![CDATA[






<section id="pytorch-lightning-fabric-code-guide" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/distributed/pytorch-fabric/fabric.png" class="img-fluid"></p>
<p>This guide covers PyTorch Lightning Fabric, from basic setup to advanced distributed training features.</p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>Lightning Fabric is a lightweight PyTorch wrapper that provides essential training utilities without the overhead of the full Lightning framework. It’s perfect when you want more control over your training loop while still benefiting from distributed training, mixed precision, and other optimizations.</p>
</section>
<section id="installation" class="level2">
<h2 class="anchored" data-anchor-id="installation" id="installation">Installation</h2>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install lightning</span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="co"># or, for Fabric alone (imported as lightning_fabric instead of lightning.fabric)</span></span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install lightning-fabric</span></code></pre></div></div>
</section>
<section id="basic-setup" class="level2">
<h2 class="anchored" data-anchor-id="basic-setup" id="basic-setup">Basic Setup</h2>
<section id="minimal-example" class="level3">
<h3 class="anchored" data-anchor-id="minimal-example" id="minimal-example">Minimal Example</h3>
<div id="64471d1a" class="cell" data-execution_count="1">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.utils.data <span class="im">import</span> DataLoader</span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> lightning.fabric <span class="im">import</span> Fabric</span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Initialize Fabric; launch() starts the processes when several devices are used</span></span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a>fabric <span class="op">=</span> Fabric()</span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a>fabric.launch()</span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-10"><a href="#cb2-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Your model</span></span>
<span id="cb2-11"><a href="#cb2-11" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> nn.Linear(<span class="dv">10</span>, <span class="dv">1</span>)</span>
<span id="cb2-12"><a href="#cb2-12" aria-hidden="true" tabindex="-1"></a>optimizer <span class="op">=</span> torch.optim.SGD(model.parameters(), lr<span class="op">=</span><span class="fl">0.01</span>)</span>
<span id="cb2-13"><a href="#cb2-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-14"><a href="#cb2-14" aria-hidden="true" tabindex="-1"></a><span class="co"># Setup model and optimizer with Fabric</span></span>
<span id="cb2-15"><a href="#cb2-15" aria-hidden="true" tabindex="-1"></a>model, optimizer <span class="op">=</span> fabric.setup(model, optimizer)</span>
<span id="cb2-16"><a href="#cb2-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-17"><a href="#cb2-17" aria-hidden="true" tabindex="-1"></a><span class="co"># Toy data (64 random 10-feature samples); setup_dataloaders moves batches to the device</span></span>
<span id="cb2-18"><a href="#cb2-18" aria-hidden="true" tabindex="-1"></a>dataloader <span class="op">=</span> DataLoader(torch.randn(<span class="dv">64</span>, <span class="dv">10</span>), batch_size<span class="op">=</span><span class="dv">8</span>)</span>
<span id="cb2-19"><a href="#cb2-19" aria-hidden="true" tabindex="-1"></a>dataloader <span class="op">=</span> fabric.setup_dataloaders(dataloader)</span>
<span id="cb2-20"><a href="#cb2-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-21"><a href="#cb2-21" aria-hidden="true" tabindex="-1"></a><span class="co"># Training step (the mean output serves as a toy loss)</span></span>
<span id="cb2-22"><a href="#cb2-22" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> batch <span class="kw">in</span> dataloader:</span>
<span id="cb2-23"><a href="#cb2-23" aria-hidden="true" tabindex="-1"></a>    optimizer.zero_grad()</span>
<span id="cb2-24"><a href="#cb2-24" aria-hidden="true" tabindex="-1"></a>    loss <span class="op">=</span> model(batch).mean()</span>
<span id="cb2-25"><a href="#cb2-25" aria-hidden="true" tabindex="-1"></a>    fabric.backward(loss)  <span class="co"># use instead of loss.backward()</span></span>
<span id="cb2-26"><a href="#cb2-26" aria-hidden="true" tabindex="-1"></a>    optimizer.step()</span></code></pre></div></div>
</div>
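<p>To make the three calls in the loop concrete, here is a framework-free sketch of what one iteration computes for this particular toy setup: a linear layer, the loss <code>mean(w·x + b)</code>, and plain SGD. It is an illustration under those assumptions only; Fabric itself adds device placement, precision handling, and gradient synchronization on top of PyTorch's autograd, none of which is modeled here. The helper name <code>sgd_step</code> is hypothetical.</p>

```python
import random

def sgd_step(w, b, batch, lr=0.01):
    """One SGD step on loss = mean_i(w . x_i + b) (hypothetical helper)."""
    n = len(batch)
    # optimizer.zero_grad(): start from fresh gradients
    grad_w = [0.0] * len(w)
    grad_b = 0.0
    # backward: d(loss)/dw_j = mean_i x_ij and d(loss)/db = 1
    for x in batch:
        for j, xj in enumerate(x):
            grad_w[j] += xj / n
        grad_b += 1.0 / n
    # optimizer.step(): plain SGD update p <- p - lr * grad
    new_w = [wj - lr * gj for wj, gj in zip(w, grad_w)]
    new_b = b - lr * grad_b
    return new_w, new_b

random.seed(0)
w, b = [0.0] * 10, 0.0
batch = [[random.gauss(0.0, 1.0) for _ in range(10)] for _ in range(8)]
w, b = sgd_step(w, b, batch)
print(b)  # -0.01: the bias gradient is exactly 1 for this loss
```

<p>Note that this toy loss is unbounded below, so repeated steps simply keep decreasing it; the example only shows the mechanics of one update.</p>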
</section>
</section>
<section id="core-components" class="level2">
<h2 class="anchored" data-anchor-id="core-components" id="core-components">Core Components</h2>
<section id="fabric-initialization" class="level3">
<h3 class="anchored" data-anchor-id="fabric-initialization" id="fabric-initialization">Fabric Initialization</h3>
<div id="8b533a09" class="cell" data-execution_count="2">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> lightning.fabric <span class="im">import</span> Fabric</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Basic initialization</span></span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a>fabric <span class="op">=</span> Fabric()</span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a><span class="co"># With specific configuration</span></span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a>fabric <span class="op">=</span> Fabric(</span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a>    accelerator<span class="op">=</span><span class="st">"gpu"</span>,           <span class="co"># "cpu", "gpu", "tpu", "auto"</span></span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a>    strategy<span class="op">=</span><span class="st">"ddp"</span>,              <span class="co"># "ddp", "fsdp", "deepspeed", etc.</span></span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a>    devices<span class="op">=</span><span class="dv">2</span>,                   <span class="co"># Number of devices</span></span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a>    precision<span class="op">=</span><span class="st">"16-mixed"</span>,        <span class="co"># "32", "16-mixed", "bf16-mixed"</span></span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a>    plugins<span class="op">=</span>[],                  <span class="co"># Custom plugins</span></span>
<span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-15"><a href="#cb3-15" aria-hidden="true" tabindex="-1"></a><span class="co"># Launch the fabric</span></span>
<span id="cb3-16"><a href="#cb3-16" aria-hidden="true" tabindex="-1"></a>fabric.launch()</span></code></pre></div></div>
</div>
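<p>With <code>strategy="ddp"</code> and <code>devices=2</code>, every process runs the same loop on its own slice of the data, so the number of optimizer steps per epoch and the effective (global) batch size both depend on the device count. The arithmetic below is a sketch that assumes the default distributed sampling, which pads the dataset so it splits evenly and keeps partial batches; <code>distributed_batch_stats</code> is a hypothetical helper, and the inputs match the 1000-sample, batch-size-32 dataset used later in this guide.</p>

```python
import math

def distributed_batch_stats(num_samples, batch_size, devices, num_nodes=1):
    """Rough per-epoch numbers for data-parallel training (hypothetical helper)."""
    world_size = devices * num_nodes
    # Each process gets an equal, padded shard of the dataset
    samples_per_rank = math.ceil(num_samples / world_size)
    # The last partial batch on each rank still counts as a step
    steps_per_epoch = math.ceil(samples_per_rank / batch_size)
    # One optimizer step consumes one batch on every process
    global_batch_size = batch_size * world_size
    return world_size, steps_per_epoch, global_batch_size

print(distributed_batch_stats(1000, 32, devices=2))  # (2, 16, 64)
```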
</section>
<section id="model-and-optimizer-setup" class="level3">
<h3 class="anchored" data-anchor-id="model-and-optimizer-setup" id="model-and-optimizer-setup">Model and Optimizer Setup</h3>
<div id="d0a27358" class="cell" data-execution_count="3">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> SimpleModel(nn.Module):</span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, input_size, hidden_size, output_size):</span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.fc1 <span class="op">=</span> nn.Linear(input_size, hidden_size)</span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.fc2 <span class="op">=</span> nn.Linear(hidden_size, output_size)</span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.relu <span class="op">=</span> nn.ReLU()</span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.dropout <span class="op">=</span> nn.Dropout(<span class="fl">0.1</span>)</span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.relu(<span class="va">self</span>.fc1(x))</span>
<span id="cb4-14"><a href="#cb4-14" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.dropout(x)</span>
<span id="cb4-15"><a href="#cb4-15" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.fc2(x)</span>
<span id="cb4-16"><a href="#cb4-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-17"><a href="#cb4-17" aria-hidden="true" tabindex="-1"></a><span class="co"># Create model and optimizer</span></span>
<span id="cb4-18"><a href="#cb4-18" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> SimpleModel(<span class="dv">784</span>, <span class="dv">128</span>, <span class="dv">10</span>)</span>
<span id="cb4-19"><a href="#cb4-19" aria-hidden="true" tabindex="-1"></a>optimizer <span class="op">=</span> torch.optim.Adam(model.parameters(), lr<span class="op">=</span><span class="fl">1e-3</span>)</span>
<span id="cb4-20"><a href="#cb4-20" aria-hidden="true" tabindex="-1"></a>scheduler <span class="op">=</span> torch.optim.lr_scheduler.StepLR(optimizer, step_size<span class="op">=</span><span class="dv">10</span>, gamma<span class="op">=</span><span class="fl">0.1</span>)</span>
<span id="cb4-21"><a href="#cb4-21" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-22"><a href="#cb4-22" aria-hidden="true" tabindex="-1"></a><span class="co"># Setup with Fabric</span></span>
<span id="cb4-23"><a href="#cb4-23" aria-hidden="true" tabindex="-1"></a>model, optimizer <span class="op">=</span> fabric.setup(model, optimizer)</span>
<span id="cb4-24"><a href="#cb4-24" aria-hidden="true" tabindex="-1"></a><span class="co"># The LR scheduler is not passed to fabric.setup(); it keeps working on the wrapped optimizer</span></span></code></pre></div></div>
</div>
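<p><code>StepLR</code> multiplies the learning rate by <code>gamma</code> once every <code>step_size</code> calls to <code>scheduler.step()</code>. Its closed form is <code>base_lr * gamma ** (epoch // step_size)</code>; the helper below reimplements that formula in plain Python (<code>step_lr</code> is a hypothetical name) so the schedule can be checked without building an optimizer.</p>

```python
def step_lr(base_lr, epoch, step_size=10, gamma=0.1):
    """Learning rate StepLR gives after `epoch` calls to scheduler.step()."""
    return base_lr * gamma ** (epoch // step_size)

for epoch in (0, 9, 10, 25):
    print(epoch, step_lr(1e-3, epoch))
```

<p>Because the training loops in this guide call <code>scheduler.step()</code> once per epoch for 10 epochs, epochs 0 through 9 all train at <code>1e-3</code>; the first decay to <code>1e-4</code> only takes effect after the tenth call.</p>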
</section>
<section id="dataloader-setup" class="level3">
<h3 class="anchored" data-anchor-id="dataloader-setup" id="dataloader-setup">DataLoader Setup</h3>
<div id="156c27d3" class="cell" data-execution_count="4">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.utils.data <span class="im">import</span> DataLoader, TensorDataset</span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Create your dataset</span></span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a>dataset <span class="op">=</span> TensorDataset(torch.randn(<span class="dv">1000</span>, <span class="dv">784</span>), torch.randint(<span class="dv">0</span>, <span class="dv">10</span>, (<span class="dv">1000</span>,)))</span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a>dataloader <span class="op">=</span> DataLoader(dataset, batch_size<span class="op">=</span><span class="dv">32</span>, shuffle<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Setup with Fabric</span></span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a>dataloader <span class="op">=</span> fabric.setup_dataloaders(dataloader)</span></code></pre></div></div>
</div>
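<p>When Fabric runs with a distributed strategy, <code>setup_dataloaders</code> by default swaps in PyTorch's <code>DistributedSampler</code> so each process sees its own shard of the dataset. The sketch below mirrors that sampler's index logic for <code>shuffle=False</code> and <code>drop_last=False</code>: pad the index list by wrapping around until it divides evenly, then give rank <code>r</code> every <code>num_replicas</code>-th index starting at <code>r</code>. <code>shard_indices</code> is a hypothetical helper, and it only handles the common case where the padding is shorter than the dataset.</p>

```python
import math

def shard_indices(dataset_len, num_replicas, rank):
    """Indices one rank visits (DistributedSampler-style, no shuffling, no drop_last)."""
    total_size = math.ceil(dataset_len / num_replicas) * num_replicas
    indices = list(range(dataset_len))
    # Pad by repeating the first few indices so every rank gets the same count
    indices += indices[: total_size - dataset_len]
    # Interleave: rank r takes positions r, r + num_replicas, r + 2 * num_replicas, ...
    return indices[rank:total_size:num_replicas]

for rank in range(4):
    print(rank, shard_indices(10, num_replicas=4, rank=rank))
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6, 0]
# 3 [3, 7, 1]
```

<p>With 10 samples on 4 ranks, indices 0 and 1 are visited twice per epoch; this padding keeps every process in lockstep but slightly skews epoch-level metrics.</p>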
</section>
</section>
<section id="training-loop" class="level2">
<h2 class="anchored" data-anchor-id="training-loop" id="training-loop">Training Loop</h2>
<section id="basic-training-loop" class="level3">
<h3 class="anchored" data-anchor-id="basic-training-loop" id="basic-training-loop">Basic Training Loop</h3>
<div id="a261c731" class="cell" data-execution_count="5">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> train_epoch(fabric, model, optimizer, dataloader, criterion):</span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a>    model.train()</span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a>    total_loss <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> batch_idx, (data, target) <span class="kw">in</span> <span class="bu">enumerate</span>(dataloader):</span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Zero gradients</span></span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a>        optimizer.zero_grad()</span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Forward pass</span></span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a>        output <span class="op">=</span> model(data)</span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a>        loss <span class="op">=</span> criterion(output, target)</span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Backward pass with Fabric</span></span>
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a>        fabric.backward(loss)</span>
<span id="cb6-15"><a href="#cb6-15" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-16"><a href="#cb6-16" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Optimizer step</span></span>
<span id="cb6-17"><a href="#cb6-17" aria-hidden="true" tabindex="-1"></a>        optimizer.step()</span>
<span id="cb6-18"><a href="#cb6-18" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-19"><a href="#cb6-19" aria-hidden="true" tabindex="-1"></a>        total_loss <span class="op">+=</span> loss.item()</span>
<span id="cb6-20"><a href="#cb6-20" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-21"><a href="#cb6-21" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Log every 100 batches</span></span>
<span id="cb6-22"><a href="#cb6-22" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> batch_idx <span class="op">%</span> <span class="dv">100</span> <span class="op">==</span> <span class="dv">0</span>:</span>
<span id="cb6-23"><a href="#cb6-23" aria-hidden="true" tabindex="-1"></a>            fabric.<span class="bu">print</span>(<span class="ss">f'Batch </span><span class="sc">{</span>batch_idx<span class="sc">}</span><span class="ss">, Loss: </span><span class="sc">{</span>loss<span class="sc">.</span>item()<span class="sc">:.4f}</span><span class="ss">'</span>)</span>
<span id="cb6-24"><a href="#cb6-24" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-25"><a href="#cb6-25" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> total_loss <span class="op">/</span> <span class="bu">len</span>(dataloader)</span>
<span id="cb6-26"><a href="#cb6-26" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-27"><a href="#cb6-27" aria-hidden="true" tabindex="-1"></a><span class="co"># Training loop</span></span>
<span id="cb6-28"><a href="#cb6-28" aria-hidden="true" tabindex="-1"></a>criterion <span class="op">=</span> nn.CrossEntropyLoss()</span>
<span id="cb6-29"><a href="#cb6-29" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-30"><a href="#cb6-30" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">10</span>):</span>
<span id="cb6-31"><a href="#cb6-31" aria-hidden="true" tabindex="-1"></a>    avg_loss <span class="op">=</span> train_epoch(fabric, model, optimizer, dataloader, criterion)</span>
<span id="cb6-32"><a href="#cb6-32" aria-hidden="true" tabindex="-1"></a>    scheduler.step()</span>
<span id="cb6-33"><a href="#cb6-33" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-34"><a href="#cb6-34" aria-hidden="true" tabindex="-1"></a>    fabric.<span class="bu">print</span>(<span class="ss">f'Epoch </span><span class="sc">{</span>epoch<span class="sc">}</span><span class="ss">: Average Loss = </span><span class="sc">{</span>avg_loss<span class="sc">:.4f}</span><span class="ss">'</span>)</span></code></pre></div></div>
</div>
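<p><code>train_epoch</code> returns <code>total_loss / len(dataloader)</code>, which is the mean of per-batch mean losses. With the 1000-sample dataset and <code>batch_size=32</code>, the last batch holds only 8 samples but gets the same weight as a full batch, so the result can differ slightly from the per-sample average. The helper below (hypothetical, with made-up loss values) compares the two.</p>

```python
def epoch_means(batch_losses, batch_sizes):
    """Return (mean of per-batch mean losses, per-sample mean loss)."""
    mean_of_batch_means = sum(batch_losses) / len(batch_losses)
    per_sample_mean = sum(l * n for l, n in zip(batch_losses, batch_sizes)) / sum(batch_sizes)
    return mean_of_batch_means, per_sample_mean

# 1000 samples with batch_size=32: 31 full batches and a final batch of 8
sizes = [32] * 31 + [8]
losses = [1.0] * 31 + [3.0]  # hypothetical: the small last batch happens to be harder
print(epoch_means(losses, sizes))  # (1.0625, 1.016)
```

<p>The gap is usually small; weight each batch loss by its batch size if an exact per-sample average matters.</p>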
</section>
<section id="training-with-validation" class="level3">
<h3 class="anchored" data-anchor-id="training-with-validation" id="training-with-validation">Training with Validation</h3>
<div id="62dda45c" class="cell" data-execution_count="6">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> validate(fabric, model, val_dataloader, criterion):</span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a>    model.<span class="bu">eval</span>()</span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a>    total_loss <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a>    correct <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a>    total <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> torch.no_grad():</span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> data, target <span class="kw">in</span> val_dataloader:</span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a>            output <span class="op">=</span> model(data)</span>
<span id="cb7-10"><a href="#cb7-10" aria-hidden="true" tabindex="-1"></a>            loss <span class="op">=</span> criterion(output, target)</span>
<span id="cb7-11"><a href="#cb7-11" aria-hidden="true" tabindex="-1"></a>            total_loss <span class="op">+=</span> loss.item()</span>
<span id="cb7-12"><a href="#cb7-12" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb7-13"><a href="#cb7-13" aria-hidden="true" tabindex="-1"></a>            pred <span class="op">=</span> output.argmax(dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb7-14"><a href="#cb7-14" aria-hidden="true" tabindex="-1"></a>            correct <span class="op">+=</span> (pred <span class="op">==</span> target).<span class="bu">sum</span>().item()</span>
<span id="cb7-15"><a href="#cb7-15" aria-hidden="true" tabindex="-1"></a>            total <span class="op">+=</span> target.size(<span class="dv">0</span>)</span>
<span id="cb7-16"><a href="#cb7-16" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-17"><a href="#cb7-17" aria-hidden="true" tabindex="-1"></a>    accuracy <span class="op">=</span> correct <span class="op">/</span> total</span>
<span id="cb7-18"><a href="#cb7-18" aria-hidden="true" tabindex="-1"></a>    avg_loss <span class="op">=</span> total_loss <span class="op">/</span> <span class="bu">len</span>(val_dataloader)</span>
<span id="cb7-19"><a href="#cb7-19" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-20"><a href="#cb7-20" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> avg_loss, accuracy</span>
<span id="cb7-21"><a href="#cb7-21" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-22"><a href="#cb7-22" aria-hidden="true" tabindex="-1"></a><span class="co"># Complete training with validation (toy random data for illustration)</span></span>
<span id="cb7-23"><a href="#cb7-23" aria-hidden="true" tabindex="-1"></a>train_dataloader <span class="op">=</span> DataLoader(TensorDataset(torch.randn(<span class="dv">800</span>, <span class="dv">784</span>), torch.randint(<span class="dv">0</span>, <span class="dv">10</span>, (<span class="dv">800</span>,))), batch_size<span class="op">=</span><span class="dv">32</span>, shuffle<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb7-24"><a href="#cb7-24" aria-hidden="true" tabindex="-1"></a>val_dataloader <span class="op">=</span> DataLoader(TensorDataset(torch.randn(<span class="dv">200</span>, <span class="dv">784</span>), torch.randint(<span class="dv">0</span>, <span class="dv">10</span>, (<span class="dv">200</span>,))), batch_size<span class="op">=</span><span class="dv">32</span>)</span>
<span id="cb7-25"><a href="#cb7-25" aria-hidden="true" tabindex="-1"></a>num_epochs <span class="op">=</span> <span class="dv">10</span></span>
<span id="cb7-26"><a href="#cb7-26" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-27"><a href="#cb7-27" aria-hidden="true" tabindex="-1"></a>train_loader <span class="op">=</span> fabric.setup_dataloaders(train_dataloader)</span>
<span id="cb7-28"><a href="#cb7-28" aria-hidden="true" tabindex="-1"></a>val_loader <span class="op">=</span> fabric.setup_dataloaders(val_dataloader)</span>
<span id="cb7-29"><a href="#cb7-29" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-30"><a href="#cb7-30" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(num_epochs):</span>
<span id="cb7-31"><a href="#cb7-31" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Training</span></span>
<span id="cb7-32"><a href="#cb7-32" aria-hidden="true" tabindex="-1"></a>    train_loss <span class="op">=</span> train_epoch(fabric, model, optimizer, train_loader, criterion)</span>
<span id="cb7-33"><a href="#cb7-33" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-34"><a href="#cb7-34" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Validation</span></span>
<span id="cb7-35"><a href="#cb7-35" aria-hidden="true" tabindex="-1"></a>    val_loss, val_acc <span class="op">=</span> validate(fabric, model, val_loader, criterion)</span>
<span id="cb7-36"><a href="#cb7-36" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-37"><a href="#cb7-37" aria-hidden="true" tabindex="-1"></a>    fabric.<span class="bu">print</span>(<span class="ss">f'Epoch </span><span class="sc">{</span>epoch<span class="sc">}</span><span class="ss">:'</span>)</span>
<span id="cb7-38"><a href="#cb7-38" aria-hidden="true" tabindex="-1"></a>    fabric.<span class="bu">print</span>(<span class="ss">f'  Train Loss: </span><span class="sc">{</span>train_loss<span class="sc">:.4f}</span><span class="ss">'</span>)</span>
<span id="cb7-39"><a href="#cb7-39" aria-hidden="true" tabindex="-1"></a>    fabric.<span class="bu">print</span>(<span class="ss">f'  Val Loss: </span><span class="sc">{</span>val_loss<span class="sc">:.4f}</span><span class="ss">'</span>)</span>
<span id="cb7-40"><a href="#cb7-40" aria-hidden="true" tabindex="-1"></a>    fabric.<span class="bu">print</span>(<span class="ss">f'  Val Acc: </span><span class="sc">{</span>val_acc<span class="sc">:.4f}</span><span class="ss">'</span>)</span>
<span id="cb7-41"><a href="#cb7-41" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-42"><a href="#cb7-42" aria-hidden="true" tabindex="-1"></a>    scheduler.step()</span></code></pre></div></div>
</div>
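<p>In a multi-process run, <code>validate()</code> returns numbers computed only from the current process's shard. One way to get a global accuracy is to sum the <code>correct</code> and <code>total</code> counts across processes (for example with <code>fabric.all_reduce(..., reduce_op="sum")</code>) and divide afterwards. Averaging per-rank accuracies instead weights every rank equally, even when the ranks evaluated different numbers of samples. The plain-Python sketch below simulates both approaches with hypothetical counts from two ranks.</p>

```python
def global_accuracy(per_rank_counts):
    """Sum counts across ranks first, as an all_reduce(sum) on (correct, total) would."""
    correct = sum(c for c, _ in per_rank_counts)
    total = sum(t for _, t in per_rank_counts)
    return correct / total

def mean_of_rank_accuracies(per_rank_counts):
    """Average each rank's local accuracy (equal weight per rank)."""
    return sum(c / t for c, t in per_rank_counts) / len(per_rank_counts)

counts = [(90, 100), (40, 50)]  # hypothetical (correct, total) from two ranks
print(global_accuracy(counts))          # 130 / 150
print(mean_of_rank_accuracies(counts))  # (0.9 + 0.8) / 2
```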
</section>
</section>
<section id="multi-gpu-training" class="level2">
<h2 class="anchored" data-anchor-id="multi-gpu-training" id="multi-gpu-training">Multi-GPU Training</h2>
<section id="distributed-data-parallel-ddp" class="level3">
<h3 class="anchored" data-anchor-id="distributed-data-parallel-ddp" id="distributed-data-parallel-ddp">Distributed Data Parallel (DDP)</h3>
<div id="a1445d98" class="cell" data-execution_count="7">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb8"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Initialize Fabric for multi-GPU</span></span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a>fabric <span class="op">=</span> Fabric(</span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a>    accelerator<span class="op">=</span><span class="st">"gpu"</span>,</span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a>    strategy<span class="op">=</span><span class="st">"ddp"</span>,</span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a>    devices<span class="op">=</span><span class="dv">4</span>,  <span class="co"># Use 4 GPUs</span></span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a>fabric.launch()</span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a><span class="co"># All-reduce for metrics across processes</span></span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> all_reduce_mean(fabric, tensor):</span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Average tensor across all processes"""</span></span>
<span id="cb8-12"><a href="#cb8-12" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Use the returned tensor rather than relying on an in-place update</span></span>
<span id="cb8-13"><a href="#cb8-13" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> fabric.all_reduce(tensor, reduce_op<span class="op">=</span><span class="st">"mean"</span>)</span>
<span id="cb8-14"><a href="#cb8-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-15"><a href="#cb8-15" aria-hidden="true" tabindex="-1"></a><span class="co"># Training with distributed metrics</span></span>
<span id="cb8-16"><a href="#cb8-16" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> train_distributed(fabric, model, optimizer, dataloader, criterion):</span>
<span id="cb8-17"><a href="#cb8-17" aria-hidden="true" tabindex="-1"></a>    model.train()</span>
<span id="cb8-18"><a href="#cb8-18" aria-hidden="true" tabindex="-1"></a>    total_loss <span class="op">=</span> torch.tensor(<span class="fl">0.0</span>, device<span class="op">=</span>fabric.device)</span>
<span id="cb8-19"><a href="#cb8-19" aria-hidden="true" tabindex="-1"></a>    num_batches <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb8-20"><a href="#cb8-20" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-21"><a href="#cb8-21" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> data, target <span class="kw">in</span> dataloader:</span>
<span id="cb8-22"><a href="#cb8-22" aria-hidden="true" tabindex="-1"></a>        optimizer.zero_grad()</span>
<span id="cb8-23"><a href="#cb8-23" aria-hidden="true" tabindex="-1"></a>        output <span class="op">=</span> model(data)</span>
<span id="cb8-24"><a href="#cb8-24" aria-hidden="true" tabindex="-1"></a>        loss <span class="op">=</span> criterion(output, target)</span>
<span id="cb8-25"><a href="#cb8-25" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-26"><a href="#cb8-26" aria-hidden="true" tabindex="-1"></a>        fabric.backward(loss)</span>
<span id="cb8-27"><a href="#cb8-27" aria-hidden="true" tabindex="-1"></a>        optimizer.step()</span>
<span id="cb8-28"><a href="#cb8-28" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-29"><a href="#cb8-29" aria-hidden="true" tabindex="-1"></a>        total_loss <span class="op">+=</span> loss.detach()</span>
<span id="cb8-30"><a href="#cb8-30" aria-hidden="true" tabindex="-1"></a>        num_batches <span class="op">+=</span> <span class="dv">1</span></span>
<span id="cb8-31"><a href="#cb8-31" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-32"><a href="#cb8-32" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Average loss across all processes</span></span>
<span id="cb8-33"><a href="#cb8-33" aria-hidden="true" tabindex="-1"></a>    avg_loss <span class="op">=</span> total_loss <span class="op">/</span> num_batches</span>
<span id="cb8-34"><a href="#cb8-34" aria-hidden="true" tabindex="-1"></a>    avg_loss <span class="op">=</span> all_reduce_mean(fabric, avg_loss)</span>
<span id="cb8-35"><a href="#cb8-35" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-36"><a href="#cb8-36" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> avg_loss.item()</span></code></pre></div></div>
</div>
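<p>The same averaging idea is what makes DDP training match single-device training. During <code>fabric.backward</code>, DDP averages gradients across processes; if every process has the same batch size and the loss is a mean over the batch, the averaged gradient equals the gradient over the combined global batch. The sketch below checks this numerically for a one-parameter linear model with a mean-squared-error loss (<code>grad_mse</code> and the data are illustrative assumptions, not part of the Fabric API).</p>

```python
def grad_mse(w, xs, ys):
    """d/dw of mean((w * x - y) ** 2) for a one-parameter linear model."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 5.0, 9.0]
w = 0.5

# Gradient over the full global batch on a single device
full = grad_mse(w, xs, ys)

# Two simulated ranks, each holding half of the global batch
rank_grads = [grad_mse(w, xs[:2], ys[:2]), grad_mse(w, xs[2:], ys[2:])]
all_reduced = sum(rank_grads) / len(rank_grads)  # what an all-reduce "mean" produces

print(full, all_reduced)  # -23.0 -23.0
```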
</section>
<section id="fully-sharded-data-parallel-fsdp" class="level3">
<h3 class="anchored" data-anchor-id="fully-sharded-data-parallel-fsdp" id="fully-sharded-data-parallel-fsdp">Fully Sharded Data Parallel (FSDP)</h3>
<div id="5b05811e" class="cell" data-execution_count="8">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb9"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="co"># For very large models</span></span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a>fabric <span class="op">=</span> Fabric(</span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a>    accelerator<span class="op">=</span><span class="st">"gpu"</span>,</span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a>    strategy<span class="op">=</span><span class="st">"fsdp"</span>,</span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a>    devices<span class="op">=</span><span class="dv">8</span>,</span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a>    precision<span class="op">=</span><span class="st">"bf16-mixed"</span></span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a>fabric.launch()</span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a></span>
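<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a><span class="co"># FSDP shards the parameters when the module is set up, so create the</span></span>
<span id="cb9-11"><a href="#cb9-11" aria-hidden="true" tabindex="-1"></a><span class="co"># optimizer afterwards from the sharded parameters</span></span>
<span id="cb9-12"><a href="#cb9-12" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> fabric.setup_module(model)</span>
<span id="cb9-13"><a href="#cb9-13" aria-hidden="true" tabindex="-1"></a>optimizer <span class="op">=</span> fabric.setup_optimizers(torch.optim.Adam(model.parameters(), lr<span class="op">=</span><span class="fl">1e-3</span>))</span></code></pre></div></div>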
</div>
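<p>A rough way to see why sharding helps is to count the memory taken by model state. A common back-of-the-envelope figure for Adam training with fp32 weights is 16 bytes per parameter: 4 for the weights, 4 for the gradients, and 8 for Adam's two moment buffers. DDP keeps a full copy on every GPU, while fully sharded FSDP splits it across <code>world_size</code> GPUs. The helper below is a hypothetical estimate only: it ignores activations, buffers, and the layers FSDP temporarily gathers during forward and backward, and the exact bytes per parameter depend on the precision settings.</p>

```python
def per_gpu_state_gb(num_params, world_size, sharded, bytes_per_param=16):
    """Approximate per-GPU memory (GB) for weights + gradients + Adam states."""
    total_bytes = num_params * bytes_per_param
    if sharded:
        # Fully sharded: each GPU holds roughly 1 / world_size of the state
        total_bytes /= world_size
    return total_bytes / 1e9

params = 1_000_000_000  # a hypothetical 1B-parameter model
print(per_gpu_state_gb(params, world_size=8, sharded=False))  # DDP-style replication
print(per_gpu_state_gb(params, world_size=8, sharded=True))   # FSDP-style sharding
```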
</section>
</section>
<section id="mixed-precision" class="level2">
<h2 class="anchored" data-anchor-id="mixed-precision" id="mixed-precision">Mixed Precision</h2>
<section id="automatic-mixed-precision" class="level3">
<h3 class="anchored" data-anchor-id="automatic-mixed-precision" id="automatic-mixed-precision">Automatic Mixed Precision</h3>
<div id="c3843136" class="cell" data-execution_count="9">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb10"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Enable mixed precision</span></span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a>fabric <span class="op">=</span> Fabric(precision<span class="op">=</span><span class="st">"16-mixed"</span>)  <span class="co"># or "bf16-mixed"</span></span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a>fabric.launch()</span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Training remains the same - Fabric handles precision automatically</span></span>
<span id="cb10-6"><a href="#cb10-6" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> train_with_amp(fabric, model, optimizer, dataloader, criterion):</span>
<span id="cb10-7"><a href="#cb10-7" aria-hidden="true" tabindex="-1"></a>    model.train()</span>
<span id="cb10-8"><a href="#cb10-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-9"><a href="#cb10-9" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> data, target <span class="kw">in</span> dataloader:</span>
<span id="cb10-10"><a href="#cb10-10" aria-hidden="true" tabindex="-1"></a>        optimizer.zero_grad()</span>
<span id="cb10-11"><a href="#cb10-11" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb10-12"><a href="#cb10-12" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Forward pass (automatically uses mixed precision)</span></span>
<span id="cb10-13"><a href="#cb10-13" aria-hidden="true" tabindex="-1"></a>        output <span class="op">=</span> model(data)</span>
<span id="cb10-14"><a href="#cb10-14" aria-hidden="true" tabindex="-1"></a>        loss <span class="op">=</span> criterion(output, target)</span>
<span id="cb10-15"><a href="#cb10-15" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb10-16"><a href="#cb10-16" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Backward pass (Fabric applies gradient scaling when precision="16-mixed")</span></span>
<span id="cb10-17"><a href="#cb10-17" aria-hidden="true" tabindex="-1"></a>        fabric.backward(loss)</span>
<span id="cb10-18"><a href="#cb10-18" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb10-19"><a href="#cb10-19" aria-hidden="true" tabindex="-1"></a>        optimizer.step()</span></code></pre></div></div>
</div>
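<p><code>"16-mixed"</code> needs gradient scaling because fp16 has limited range: gradient values below roughly 6e-8 round to zero. Scaling the loss by a large factor also scales every gradient (by the chain rule), which keeps small gradients representable. Dividing by the same factor before the optimizer step then restores their magnitude. The standard-library sketch below emulates fp16 rounding with <code>struct</code>'s <code>"e"</code> (half-precision) format. The value 1e-8 and the fixed scale of 2**16 are illustrative. PyTorch's <code>GradScaler</code> adjusts the scale dynamically: when it detects inf/NaN gradients, it lowers the scale and skips that step.</p>

```python
import struct

def to_fp16(x):
    """Round a Python float to IEEE 754 half precision and back."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

grad = 1e-8                      # a tiny gradient component
print(to_fp16(grad))             # 0.0: underflows in fp16

scale = 2.0 ** 16                # fixed loss scale for illustration
scaled = to_fp16(grad * scale)   # about 6.55e-4, comfortably representable in fp16
recovered = scaled / scale       # unscale in full precision before optimizer.step()
print(recovered)                 # close to 1e-8 (within fp16 rounding error)
```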
</section>
<section id="manual-precision-control" class="level3">
<h3 class="anchored" data-anchor-id="manual-precision-control" id="manual-precision-control">Manual Precision Control</h3>
<div id="cbbd81cc" class="cell" data-execution_count="10">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb11"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> lightning.fabric.utilities <span class="im">import</span> rank_zero_only</span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a><span class="at">@rank_zero_only</span></span>
<span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> log_model_precision(model):</span>
<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Log model parameter precisions (only on rank 0)"""</span></span>
<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> name, param <span class="kw">in</span> model.named_parameters():</span>
<span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"</span><span class="sc">{</span>name<span class="sc">}</span><span class="ss">: </span><span class="sc">{</span>param<span class="sc">.</span>dtype<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-9"><a href="#cb11-9" aria-hidden="true" tabindex="-1"></a><span class="co"># Check parameter dtypes after setup (with "16-mixed" they stay float32; autocast casts the ops)</span></span>
<span id="cb11-10"><a href="#cb11-10" aria-hidden="true" tabindex="-1"></a>model, optimizer <span class="op">=</span> fabric.setup(model, optimizer)</span>
<span id="cb11-11"><a href="#cb11-11" aria-hidden="true" tabindex="-1"></a>log_model_precision(model)</span></code></pre></div></div>
</div>
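<p>With <code>"16-mixed"</code>, the parameters printed above remain <code>torch.float32</code>, and only operations run in float16 under autocast. One reason to keep these float32 master weights is that small updates can vanish when they are added to a float16 weight. The standard-library sketch below illustrates this; it is not Fabric code. It applies an assumed update of <code>1e-4</code> to a weight of <code>1.0</code> in both formats.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import struct


def round_to(fmt, x):
    """Round a Python float through a struct format: "e" is float16, "f" is float32."""
    return struct.unpack(fmt, struct.pack(fmt, x))[0]


weight = 1.0
update = 1e-4  # e.g. learning rate 1e-3 times gradient 0.1 (illustrative values)

# Round both operands, add, then round the result back to the target format
fp16_weight = round_to("e", round_to("e", weight) + round_to("e", update))
fp32_weight = round_to("f", round_to("f", weight) + round_to("f", update))

print(f"float16: {fp16_weight}")  # spacing between fp16 values near 1.0 is about 0.001
print(f"float32: {fp32_weight}")
</code></pre></div>
<p>In float16 the update is rounded away entirely, while float32 keeps it. This is why mixed precision applies optimizer updates to full-precision weights.</p>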
</section>
</section>
<section id="logging-and-checkpointing" class="level2">
<h2 class="anchored" data-anchor-id="logging-and-checkpointing" id="logging-and-checkpointing">Logging and Checkpointing</h2>
<section id="checkpointing" class="level3">
<h3 class="anchored" data-anchor-id="checkpointing" id="checkpointing">Checkpointing</h3>
<div id="d7b1ba1b" class="cell" data-execution_count="11">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb12"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><a href="#cb12-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> os</span>
<span id="cb12-2"><a href="#cb12-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-3"><a href="#cb12-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> save_checkpoint(fabric, model, optimizer, epoch, loss, path):</span>
<span id="cb12-4"><a href="#cb12-4" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Save model checkpoint"""</span></span>
<span id="cb12-5"><a href="#cb12-5" aria-hidden="true" tabindex="-1"></a>    checkpoint <span class="op">=</span> {</span>
<span id="cb12-6"><a href="#cb12-6" aria-hidden="true" tabindex="-1"></a>        <span class="st">"model"</span>: model,</span>
<span id="cb12-7"><a href="#cb12-7" aria-hidden="true" tabindex="-1"></a>        <span class="st">"optimizer"</span>: optimizer,</span>
<span id="cb12-8"><a href="#cb12-8" aria-hidden="true" tabindex="-1"></a>        <span class="st">"epoch"</span>: epoch,</span>
<span id="cb12-9"><a href="#cb12-9" aria-hidden="true" tabindex="-1"></a>        <span class="st">"loss"</span>: loss</span>
<span id="cb12-10"><a href="#cb12-10" aria-hidden="true" tabindex="-1"></a>    }</span>
<span id="cb12-11"><a href="#cb12-11" aria-hidden="true" tabindex="-1"></a>    fabric.save(path, checkpoint)</span>
<span id="cb12-12"><a href="#cb12-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-13"><a href="#cb12-13" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> load_checkpoint(fabric, model, optimizer, path):</span>
<span id="cb12-14"><a href="#cb12-14" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Restore model and optimizer in place; return the full state dict"""</span></span>
<span id="cb12-15"><a href="#cb12-15" aria-hidden="true" tabindex="-1"></a>    state <span class="op">=</span> {<span class="st">"model"</span>: model, <span class="st">"optimizer"</span>: optimizer, <span class="st">"epoch"</span>: <span class="dv">0</span>, <span class="st">"loss"</span>: <span class="va">None</span>}</span>
<span id="cb12-16"><a href="#cb12-16" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Modules/optimizers are loaded in place; plain values are written back into `state`</span></span>
<span id="cb12-17"><a href="#cb12-17" aria-hidden="true" tabindex="-1"></a>    fabric.load(path, state)</span>
<span id="cb12-18"><a href="#cb12-18" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> state</span>
<span id="cb12-19"><a href="#cb12-19" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-20"><a href="#cb12-20" aria-hidden="true" tabindex="-1"></a><span class="co"># Save checkpoint</span></span>
<span id="cb12-21"><a href="#cb12-21" aria-hidden="true" tabindex="-1"></a>checkpoint_path <span class="op">=</span> <span class="ss">f"checkpoint_epoch_</span><span class="sc">{</span>epoch<span class="sc">}</span><span class="ss">.ckpt"</span></span>
<span id="cb12-22"><a href="#cb12-22" aria-hidden="true" tabindex="-1"></a>save_checkpoint(fabric, model, optimizer, epoch, train_loss, checkpoint_path)</span>
<span id="cb12-23"><a href="#cb12-23" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-24"><a href="#cb12-24" aria-hidden="true" tabindex="-1"></a><span class="co"># Load checkpoint</span></span>
<span id="cb12-25"><a href="#cb12-25" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> os.path.exists(<span class="st">"checkpoint_epoch_5.ckpt"</span>):</span>
<span id="cb12-26"><a href="#cb12-26" aria-hidden="true" tabindex="-1"></a>    state <span class="op">=</span> load_checkpoint(fabric, model, optimizer, <span class="st">"checkpoint_epoch_5.ckpt"</span>)</span>
<span id="cb12-27"><a href="#cb12-27" aria-hidden="true" tabindex="-1"></a>    start_epoch <span class="op">=</span> state[<span class="st">"epoch"</span>] <span class="op">+</span> <span class="dv">1</span>  <span class="co"># model and optimizer are already restored</span></span></code></pre></div></div>
</div>
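<p>The load example hard-codes <code>checkpoint_epoch_5.ckpt</code>. To resume from whichever checkpoint is newest, a small helper can scan the directory for the naming scheme used above. The function name <code>find_latest_checkpoint</code> is hypothetical. The sketch assumes that every checkpoint follows the <code>checkpoint_epoch_{N}.ckpt</code> pattern and that all checkpoints live in one directory visible to every rank.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import os
import re

# Assumed naming scheme, matching the save example: checkpoint_epoch_{N}.ckpt
CKPT_PATTERN = re.compile(r"checkpoint_epoch_(\d+)\.ckpt")


def find_latest_checkpoint(directory="."):
    """Return the path of the highest-epoch checkpoint in `directory`, or None."""
    found = []
    for name in os.listdir(directory):
        match = CKPT_PATTERN.fullmatch(name)
        if match:
            # Compare epochs numerically so that epoch 10 beats epoch 9
            found.append((int(match.group(1)), name))
    if not found:
        return None
    _, latest = max(found)
    return os.path.join(directory, latest)
</code></pre></div>
<p>In the resume branch, <code>path = find_latest_checkpoint()</code> followed by <code>if path is not None:</code> could replace the fixed filename.</p>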
</section>
<section id="logging-with-external-loggers" class="level3">
<h3 class="anchored" data-anchor-id="logging-with-external-loggers" id="logging-with-external-loggers">Logging with External Loggers</h3>
<div id="204ccfa0" class="cell" data-execution_count="12">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb13"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> lightning.fabric.loggers <span class="im">import</span> TensorBoardLogger, CSVLogger</span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Initialize a logger (CSVLogger accepts the same root_dir and name arguments)</span></span>
<span id="cb13-4"><a href="#cb13-4" aria-hidden="true" tabindex="-1"></a>logger <span class="op">=</span> TensorBoardLogger(<span class="st">"logs"</span>, name<span class="op">=</span><span class="st">"my_experiment"</span>)</span>
<span id="cb13-5"><a href="#cb13-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-6"><a href="#cb13-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Setup Fabric with logger</span></span>
<span id="cb13-7"><a href="#cb13-7" aria-hidden="true" tabindex="-1"></a>fabric <span class="op">=</span> Fabric(loggers<span class="op">=</span>[logger])</span>
<span id="cb13-8"><a href="#cb13-8" aria-hidden="true" tabindex="-1"></a>fabric.launch()</span>
<span id="cb13-9"><a href="#cb13-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-10"><a href="#cb13-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Log metrics</span></span>
<span id="cb13-11"><a href="#cb13-11" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> log_metrics(fabric, metrics, step):</span>
<span id="cb13-12"><a href="#cb13-12" aria-hidden="true" tabindex="-1"></a>    <span class="co"># fabric.log_dict forwards the metrics to every logger passed to Fabric(loggers=...)</span></span>
<span id="cb13-13"><a href="#cb13-13" aria-hidden="true" tabindex="-1"></a>    fabric.log_dict(metrics, step<span class="op">=</span>step)</span>
<span id="cb13-14"><a href="#cb13-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-15"><a href="#cb13-15" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage in training loop</span></span>
<span id="cb13-16"><a href="#cb13-16" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(num_epochs):</span>
<span id="cb13-17"><a href="#cb13-17" aria-hidden="true" tabindex="-1"></a>    train_loss <span class="op">=</span> train_epoch(fabric, model, optimizer, train_loader, criterion)</span>
<span id="cb13-18"><a href="#cb13-18" aria-hidden="true" tabindex="-1"></a>    val_loss, val_acc <span class="op">=</span> validate(fabric, model, val_loader, criterion)</span>
<span id="cb13-19"><a href="#cb13-19" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-20"><a href="#cb13-20" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Log metrics</span></span>
<span id="cb13-21"><a href="#cb13-21" aria-hidden="true" tabindex="-1"></a>    metrics <span class="op">=</span> {</span>
<span id="cb13-22"><a href="#cb13-22" aria-hidden="true" tabindex="-1"></a>        <span class="st">"train_loss"</span>: train_loss,</span>
<span id="cb13-23"><a href="#cb13-23" aria-hidden="true" tabindex="-1"></a>        <span class="st">"val_loss"</span>: val_loss,</span>
<span id="cb13-24"><a href="#cb13-24" aria-hidden="true" tabindex="-1"></a>        <span class="st">"val_accuracy"</span>: val_acc,</span>
<span id="cb13-25"><a href="#cb13-25" aria-hidden="true" tabindex="-1"></a>        <span class="st">"learning_rate"</span>: optimizer.param_groups[<span class="dv">0</span>][<span class="st">'lr'</span>]</span>
<span id="cb13-26"><a href="#cb13-26" aria-hidden="true" tabindex="-1"></a>    }</span>
<span id="cb13-27"><a href="#cb13-27" aria-hidden="true" tabindex="-1"></a>    log_metrics(fabric, metrics, epoch)</span></code></pre></div></div>
</div>
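<p>The metrics logged above are epoch-level values. If a training function averages per-batch means, a smaller final batch counts as much as a full one. A sample-weighted running mean avoids that bias. The <code>AverageMeter</code> class below is a hypothetical helper, not part of Fabric. In a multi-GPU run each rank only sees its own shard. The sums and counts would therefore still need to be combined before logging, for example with <code>fabric.all_reduce(..., reduce_op="sum")</code>.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">class AverageMeter:
    """Sample-weighted running mean for epoch-level metrics (hypothetical helper)."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value, n=1):
        # `value` is a batch mean, so weight it by the batch size `n`
        self.total += float(value) * n
        self.count += n

    @property
    def avg(self):
        return self.total / self.count if self.count else 0.0


# Two full batches of 4 samples and a final batch of 2 (illustrative numbers)
meter = AverageMeter()
for batch_mean, batch_size in [(0.5, 4), (0.5, 4), (1.1, 2)]:
    meter.update(batch_mean, batch_size)

naive = (0.5 + 0.5 + 1.1) / 3
print(f"weighted: {meter.avg:.3f}, naive: {naive:.3f}")
</code></pre></div>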
</section>
</section>
<section id="advanced-features" class="level2">
<h2 class="anchored" data-anchor-id="advanced-features" id="advanced-features">Advanced Features</h2>
<section id="custom-precision-plugin" class="level3">
<h3 class="anchored" data-anchor-id="custom-precision-plugin" id="custom-precision-plugin">Custom Precision Plugin</h3>
<div id="ea29933c" class="cell" data-execution_count="13">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb14"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb14-1"><a href="#cb14-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb14-2"><a href="#cb14-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> lightning.fabric <span class="im">import</span> Fabric</span>
<span id="cb14-3"><a href="#cb14-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> lightning.fabric.plugins <span class="im">import</span> MixedPrecision</span>
<span id="cb14-4"><a href="#cb14-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-5"><a href="#cb14-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Custom precision configuration: supply your own GradScaler to control loss scaling</span></span>
<span id="cb14-6"><a href="#cb14-6" aria-hidden="true" tabindex="-1"></a>precision_plugin <span class="op">=</span> MixedPrecision(</span>
<span id="cb14-7"><a href="#cb14-7" aria-hidden="true" tabindex="-1"></a>    precision<span class="op">=</span><span class="st">"16-mixed"</span>,</span>
<span id="cb14-8"><a href="#cb14-8" aria-hidden="true" tabindex="-1"></a>    device<span class="op">=</span><span class="st">"cuda"</span>,</span>
<span id="cb14-9"><a href="#cb14-9" aria-hidden="true" tabindex="-1"></a>    scaler<span class="op">=</span>torch.cuda.amp.GradScaler(init_scale<span class="op">=</span><span class="dv">2</span><span class="op">**</span><span class="dv">16</span>),</span>
<span id="cb14-10"><a href="#cb14-10" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb14-11"><a href="#cb14-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-12"><a href="#cb14-12" aria-hidden="true" tabindex="-1"></a>fabric <span class="op">=</span> Fabric(plugins<span class="op">=</span>[precision_plugin])</span></code></pre></div></div>
</div>
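<p>The <code>init_scale</code> value in the configuration above is only a starting point, because the scale is adjusted during training. The pure-Python sketch below models the update policy documented for PyTorch's <code>GradScaler</code>, using its defaults of <code>growth_factor=2.0</code>, <code>backoff_factor=0.5</code> and <code>growth_interval=2000</code>:</p>
<ul>
<li>When any gradient is inf or NaN, halve the scale and skip the optimizer step.</li>
<li>After a run of consecutive finite steps, double the scale.</li>
</ul>
<p>The class name is hypothetical. In real training the scaler does this for you, and it checks the unscaled gradients inside <code>scaler.step</code>.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import math


class DynamicLossScaler:
    """Sketch of dynamic loss scaling (hypothetical; follows GradScaler's policy)."""

    def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._finite_steps = 0

    def unscale(self, grads):
        """Divide scaled gradients by the current scale."""
        return [g / self.scale for g in grads]

    def update(self, grads):
        """Adjust the scale after a backward pass; return True if the step may run."""
        if not all(math.isfinite(g) for g in grads):
            self.scale *= self.backoff_factor  # overflow: shrink the scale, skip this step
            self._finite_steps = 0
            return False
        self._finite_steps += 1
        if self._finite_steps == self.growth_interval:
            self.scale *= self.growth_factor  # long stable run: try a larger scale
            self._finite_steps = 0
        return True


scaler = DynamicLossScaler(init_scale=8.0, growth_interval=3)
for grads in [[1.0, float("inf")], [0.5], [0.5], [0.5]]:
    ok = scaler.update(grads)
    print(f"step ok={ok}, scale={scaler.scale}")
</code></pre></div>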
</section>
<section id="gradient-clipping" class="level3">
<h3 class="anchored" data-anchor-id="gradient-clipping" id="gradient-clipping">Gradient Clipping</h3>
<div id="e81d8805" class="cell" data-execution_count="14">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb15"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb15-1"><a href="#cb15-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> train_with_grad_clipping(fabric, model, optimizer, dataloader, criterion, max_norm<span class="op">=</span><span class="fl">1.0</span>):</span>
<span id="cb15-2"><a href="#cb15-2" aria-hidden="true" tabindex="-1"></a>    model.train()</span>
<span id="cb15-3"><a href="#cb15-3" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-4"><a href="#cb15-4" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> data, target <span class="kw">in</span> dataloader:</span>
<span id="cb15-5"><a href="#cb15-5" aria-hidden="true" tabindex="-1"></a>        optimizer.zero_grad()</span>
<span id="cb15-6"><a href="#cb15-6" aria-hidden="true" tabindex="-1"></a>        output <span class="op">=</span> model(data)</span>
<span id="cb15-7"><a href="#cb15-7" aria-hidden="true" tabindex="-1"></a>        loss <span class="op">=</span> criterion(output, target)</span>
<span id="cb15-8"><a href="#cb15-8" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-9"><a href="#cb15-9" aria-hidden="true" tabindex="-1"></a>        fabric.backward(loss)</span>
<span id="cb15-10"><a href="#cb15-10" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-11"><a href="#cb15-11" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Gradient clipping</span></span>
<span id="cb15-12"><a href="#cb15-12" aria-hidden="true" tabindex="-1"></a>        fabric.clip_gradients(model, optimizer, max_norm<span class="op">=</span>max_norm)</span>
<span id="cb15-13"><a href="#cb15-13" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-14"><a href="#cb15-14" aria-hidden="true" tabindex="-1"></a>        optimizer.step()</span></code></pre></div></div>
</div>
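<p>With <code>max_norm</code> set, the gradients are clipped by their global L2 norm across all parameters. This is the rule used by PyTorch's <code>torch.nn.utils.clip_grad_norm_</code>. The plain-Python sketch below reproduces that rule for <code>norm_type=2</code>, using nested lists as stand-ins for parameter gradients. The function name is hypothetical.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import math


def clip_grad_norm(grads, max_norm, eps=1e-6):
    """Rescale gradients in place so their global L2 norm is at most `max_norm`.

    `grads` is a list of per-parameter gradient lists (a stand-in for tensors).
    Returns the total norm measured before clipping, as PyTorch does.
    """
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    # Same coefficient as clip_grad_norm_: gradients are only ever scaled down
    coef = min(1.0, max_norm / (total_norm + eps))
    for grad in grads:
        for i, g in enumerate(grad):
            grad[i] = g * coef
    return total_norm


grads = [[3.0], [4.0]]  # global norm 5.0
norm_before = clip_grad_norm(grads, max_norm=1.0)
print(norm_before, grads)
</code></pre></div>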
</section>
<section id="early-stopping" class="level3">
<h3 class="anchored" data-anchor-id="early-stopping" id="early-stopping">Early Stopping</h3>
<div id="1f3952ad" class="cell" data-execution_count="15">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb16"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb16-1"><a href="#cb16-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> EarlyStopping:</span>
<span id="cb16-2"><a href="#cb16-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, patience<span class="op">=</span><span class="dv">10</span>, min_delta<span class="op">=</span><span class="fl">0.001</span>):</span>
<span id="cb16-3"><a href="#cb16-3" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.patience <span class="op">=</span> patience</span>
<span id="cb16-4"><a href="#cb16-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.min_delta <span class="op">=</span> min_delta</span>
<span id="cb16-5"><a href="#cb16-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.counter <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb16-6"><a href="#cb16-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.best_loss <span class="op">=</span> <span class="bu">float</span>(<span class="st">'inf'</span>)</span>
<span id="cb16-7"><a href="#cb16-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-8"><a href="#cb16-8" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__call__</span>(<span class="va">self</span>, val_loss):</span>
<span id="cb16-9"><a href="#cb16-9" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> val_loss <span class="op">&lt;</span> <span class="va">self</span>.best_loss <span class="op">-</span> <span class="va">self</span>.min_delta:</span>
<span id="cb16-10"><a href="#cb16-10" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.best_loss <span class="op">=</span> val_loss</span>
<span id="cb16-11"><a href="#cb16-11" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.counter <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb16-12"><a href="#cb16-12" aria-hidden="true" tabindex="-1"></a>        <span class="cf">else</span>:</span>
<span id="cb16-13"><a href="#cb16-13" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.counter <span class="op">+=</span> <span class="dv">1</span></span>
<span id="cb16-14"><a href="#cb16-14" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb16-15"><a href="#cb16-15" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.counter <span class="op">&gt;=</span> <span class="va">self</span>.patience</span>
<span id="cb16-16"><a href="#cb16-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-17"><a href="#cb16-17" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage</span></span>
<span id="cb16-18"><a href="#cb16-18" aria-hidden="true" tabindex="-1"></a>early_stopping <span class="op">=</span> EarlyStopping(patience<span class="op">=</span><span class="dv">5</span>)</span>
<span id="cb16-19"><a href="#cb16-19" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-20"><a href="#cb16-20" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(num_epochs):</span>
<span id="cb16-21"><a href="#cb16-21" aria-hidden="true" tabindex="-1"></a>    train_loss <span class="op">=</span> train_epoch(fabric, model, optimizer, train_loader, criterion)</span>
<span id="cb16-22"><a href="#cb16-22" aria-hidden="true" tabindex="-1"></a>    val_loss, val_acc <span class="op">=</span> validate(fabric, model, val_loader, criterion)</span>
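<span id="cb16-23"><a href="#cb16-23" aria-hidden="true" tabindex="-1"></a>    <span class="co"># val_loss must match on every rank (e.g. reduce it with fabric.all_reduce) so all ranks stop together</span></span>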
<span id="cb16-24"><a href="#cb16-24" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> early_stopping(val_loss):</span>
<span id="cb16-25"><a href="#cb16-25" aria-hidden="true" tabindex="-1"></a>        fabric.<span class="bu">print</span>(<span class="ss">f"Early stopping at epoch </span><span class="sc">{</span>epoch<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb16-26"><a href="#cb16-26" aria-hidden="true" tabindex="-1"></a>        <span class="cf">break</span></span></code></pre></div></div>
</div>
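<p>The class above only handles metrics where lower is better. To stop on a metric such as validation accuracy, one option is a <code>mode</code> argument that flips the sign of the monitored value. The variant below is a sketch that keeps the same <code>patience</code> and <code>min_delta</code> semantics. Its <code>mode</code> parameter is an addition that the original class does not have.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import math


class EarlyStoppingWithMode:
    """Early stopping for metrics to minimize ("min") or maximize ("max")."""

    def __init__(self, patience=10, min_delta=0.001, mode="min"):
        if mode not in ("min", "max"):
            raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
        self.patience = patience
        self.min_delta = min_delta
        # Internally always minimize: negate metrics where higher is better
        self.sign = 1.0 if mode == "min" else -1.0
        self.best = math.inf
        self.counter = 0

    def __call__(self, value):
        score = self.sign * float(value)
        if self.best - score > self.min_delta:
            self.best = score
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience


stopper = EarlyStoppingWithMode(patience=2, mode="max")
stopped_at = None
for epoch, val_acc in enumerate([0.80, 0.85, 0.8505, 0.84]):
    if stopper(val_acc):
        stopped_at = epoch
        print(f"Early stopping at epoch {epoch}")
        break
</code></pre></div>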
</section>
</section>
<section id="best-practices" class="level2">
<h2 class="anchored" data-anchor-id="best-practices" id="best-practices">Best Practices</h2>
<section id="proper-fabric-launch" class="level3">
<h3 class="anchored" data-anchor-id="proper-fabric-launch" id="proper-fabric-launch">Proper Fabric Launch</h3>
<div id="a0954c1d" class="cell" data-execution_count="16">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb17"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb17-1"><a href="#cb17-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Always use fabric.launch() for proper initialization</span></span>
<span id="cb17-2"><a href="#cb17-2" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> main():</span>
<span id="cb17-3"><a href="#cb17-3" aria-hidden="true" tabindex="-1"></a>    fabric <span class="op">=</span> Fabric(accelerator<span class="op">=</span><span class="st">"gpu"</span>, devices<span class="op">=</span><span class="dv">2</span>)</span>
<span id="cb17-4"><a href="#cb17-4" aria-hidden="true" tabindex="-1"></a>    fabric.launch()</span>
<span id="cb17-5"><a href="#cb17-5" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb17-6"><a href="#cb17-6" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Your training code here</span></span>
<span id="cb17-7"><a href="#cb17-7" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> create_model()</span>
<span id="cb17-8"><a href="#cb17-8" aria-hidden="true" tabindex="-1"></a>    <span class="co"># ... rest of training</span></span>
<span id="cb17-9"><a href="#cb17-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-10"><a href="#cb17-10" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb17-11"><a href="#cb17-11" aria-hidden="true" tabindex="-1"></a>    main()</span></code></pre></div></div>
</div>
</section>
<section id="rank-specific-operations" class="level3">
<h3 class="anchored" data-anchor-id="rank-specific-operations" id="rank-specific-operations">Rank-specific Operations</h3>
<div id="4d7003d7" class="cell" data-execution_count="17">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb18"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb18-1"><a href="#cb18-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb18-2"><a href="#cb18-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> lightning.fabric.utilities <span class="im">import</span> rank_zero_only</span>
<span id="cb18-3"><a href="#cb18-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-4"><a href="#cb18-4" aria-hidden="true" tabindex="-1"></a><span class="at">@rank_zero_only</span></span>
<span id="cb18-5"><a href="#cb18-5" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> save_model_artifacts(model, path):</span>
<span id="cb18-6"><a href="#cb18-6" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Only save on rank 0 to avoid conflicts"""</span></span>
<span id="cb18-7"><a href="#cb18-7" aria-hidden="true" tabindex="-1"></a>    torch.save(model.state_dict(), path)</span>
<span id="cb18-8"><a href="#cb18-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-9"><a href="#cb18-9" aria-hidden="true" tabindex="-1"></a><span class="at">@rank_zero_only</span></span>
<span id="cb18-10"><a href="#cb18-10" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> print_training_info(epoch, loss):</span>
<span id="cb18-11"><a href="#cb18-11" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Only print on rank 0 to avoid duplicate outputs"""</span></span>
<span id="cb18-12"><a href="#cb18-12" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Epoch </span><span class="sc">{</span>epoch<span class="sc">}</span><span class="ss">, Loss: </span><span class="sc">{</span>loss<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
</div>
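<p>The decorator used above checks the process rank and calls the wrapped function only on rank 0, returning <code>None</code> on every other rank. The simplified sketch below shows the idea. It reads the <code>RANK</code> environment variable, which launchers such as <code>torchrun</code> set. Relying on that single variable is an assumption, since Lightning's own implementation determines the rank from several sources. Other ranks skip the body entirely, so a rank-zero-only function must never contain collective operations such as <code>fabric.barrier()</code> or <code>fabric.all_reduce()</code>. If it does, rank 0 will block waiting for ranks that never join.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import functools
import os


def rank_zero_only_sketch(fn):
    """Run `fn` only when this process is global rank 0 (simplified sketch)."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # Assumption: the launcher exports RANK; single-process runs default to 0
        rank = int(os.environ.get("RANK", "0"))
        if rank == 0:
            return fn(*args, **kwargs)
        return None

    return wrapper


@rank_zero_only_sketch
def describe(epoch, loss):
    return f"Epoch {epoch}, Loss: {loss:.4f}"
</code></pre></div>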
</section>
<section id="proper-device-management" class="level3">
<h3 class="anchored" data-anchor-id="proper-device-management" id="proper-device-management">Proper Device Management</h3>
<div id="0fe6dfc7" class="cell" data-execution_count="18">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb19"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb19-1"><a href="#cb19-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Let Fabric handle device placement</span></span>
<span id="cb19-2"><a href="#cb19-2" aria-hidden="true" tabindex="-1"></a>fabric <span class="op">=</span> Fabric()</span>
<span id="cb19-3"><a href="#cb19-3" aria-hidden="true" tabindex="-1"></a>fabric.launch()</span>
<span id="cb19-4"><a href="#cb19-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-5"><a href="#cb19-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Don't move the model or batches manually: fabric.setup() and fabric.setup_dataloaders() do it</span></span>
<span id="cb19-6"><a href="#cb19-6" aria-hidden="true" tabindex="-1"></a><span class="co"># BAD:  model.to(device), data.to(device)</span></span>
<span id="cb19-7"><a href="#cb19-7" aria-hidden="true" tabindex="-1"></a><span class="co"># GOOD: let Fabric place them; for tensors you create yourself, pass device=fabric.device</span></span>
<span id="cb19-8"><a href="#cb19-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-9"><a href="#cb19-9" aria-hidden="true" tabindex="-1"></a>model, optimizer <span class="op">=</span> fabric.setup(model, optimizer)</span>
<span id="cb19-10"><a href="#cb19-10" aria-hidden="true" tabindex="-1"></a>dataloader <span class="op">=</span> fabric.setup_dataloaders(dataloader)</span></code></pre></div></div>
</div>
</section>
<section id="memory-efficient-training" class="level3">
<h3 class="anchored" data-anchor-id="memory-efficient-training" id="memory-efficient-training">Memory Efficient Training</h3>
<div id="e64a0501" class="cell" data-execution_count="19">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb20"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb20-1"><a href="#cb20-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb20-2"><a href="#cb20-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb20-3"><a href="#cb20-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> memory_efficient_training(fabric, model, optimizer, dataloader, criterion):</span>
<span id="cb20-4"><a href="#cb20-4" aria-hidden="true" tabindex="-1"></a>    model.train()</span>
<span id="cb20-5"><a href="#cb20-5" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb20-6"><a href="#cb20-6" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Enable gradient checkpointing once, before the loop (e.g. Hugging Face models expose this method)</span></span>
<span id="cb20-7"><a href="#cb20-7" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> <span class="bu">hasattr</span>(model, <span class="st">'gradient_checkpointing_enable'</span>):</span>
<span id="cb20-8"><a href="#cb20-8" aria-hidden="true" tabindex="-1"></a>        model.gradient_checkpointing_enable()</span>
<span id="cb20-9"><a href="#cb20-9" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb20-10"><a href="#cb20-10" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> batch_idx, (data, target) <span class="kw">in</span> <span class="bu">enumerate</span>(dataloader):</span>
<span id="cb20-11"><a href="#cb20-11" aria-hidden="true" tabindex="-1"></a>        optimizer.zero_grad(set_to_none<span class="op">=</span><span class="va">True</span>)  <span class="co"># set grads to None rather than zero-filling them</span></span>
<span id="cb20-12"><a href="#cb20-12" aria-hidden="true" tabindex="-1"></a>        output <span class="op">=</span> model(data)</span>
<span id="cb20-13"><a href="#cb20-13" aria-hidden="true" tabindex="-1"></a>        loss <span class="op">=</span> criterion(output, target)</span>
<span id="cb20-14"><a href="#cb20-14" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb20-15"><a href="#cb20-15" aria-hidden="true" tabindex="-1"></a>        fabric.backward(loss)</span>
<span id="cb20-16"><a href="#cb20-16" aria-hidden="true" tabindex="-1"></a>        optimizer.step()</span>
<span id="cb20-17"><a href="#cb20-17" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb20-18"><a href="#cb20-18" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Optional: return cached blocks to the driver. This does not lower peak</span></span>
<span id="cb20-19"><a href="#cb20-19" aria-hidden="true" tabindex="-1"></a>        <span class="co"># memory and adds overhead, so call it sparingly</span></span>
<span id="cb20-20"><a href="#cb20-20" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> torch.cuda.is_available() <span class="kw">and</span> batch_idx <span class="op">%</span> <span class="dv">100</span> <span class="op">==</span> <span class="dv">0</span>:</span>
<span id="cb20-21"><a href="#cb20-21" aria-hidden="true" tabindex="-1"></a>            torch.cuda.empty_cache()</span></code></pre></div></div>
</div>
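<p>Gradient checkpointing trades compute for memory. The forward pass keeps only some activations, and the backward pass recomputes the missing ones from the nearest saved one. PyTorch provides this through <code>torch.utils.checkpoint.checkpoint</code>, and Hugging Face models provide it through the <code>gradient_checkpointing_enable()</code> method referenced above. The framework-free sketch below applies the idea to a toy chain of scalar layers <code>y = sin(w * x)</code> with made-up weights. It then checks that the recomputed gradient matches ordinary backpropagation. The layer form, weights and function names are all illustrative assumptions.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import math

WEIGHTS = [0.5, 1.2, 0.8, 1.5, 0.9, 1.1]  # made-up weights for a 6-layer toy chain


def layer(w, x):
    return math.sin(w * x)


def layer_grad(w, x):
    # d/dx sin(w * x) = w * cos(w * x)
    return w * math.cos(w * x)


def grad_full(x):
    """Ordinary backprop: store every layer input, then apply the chain rule."""
    inputs = []
    a = x
    for w in WEIGHTS:
        inputs.append(a)
        a = layer(w, a)
    grad = 1.0
    for w, inp in zip(reversed(WEIGHTS), reversed(inputs)):
        grad *= layer_grad(w, inp)
    return a, grad


def grad_checkpointed(x, segment=2):
    """Backprop keeping only every `segment`-th activation, recomputing the rest."""
    n = len(WEIGHTS)
    saved = {0: x}  # layer index mapped to the input activation kept after forward
    a = x
    for i, w in enumerate(WEIGHTS):
        a = layer(w, a)
        if (i + 1) % segment == 0:
            saved[i + 1] = a
    out = a
    grad = 1.0
    for start in reversed(range(0, n, segment)):
        end = min(start + segment, n)
        # Recompute this segment's layer inputs from its saved checkpoint
        inputs = [saved[start]]
        for w in WEIGHTS[start:end - 1]:
            inputs.append(layer(w, inputs[-1]))
        # Chain rule through the segment, last layer first
        for w, inp in zip(reversed(WEIGHTS[start:end]), reversed(inputs)):
            grad *= layer_grad(w, inp)
    return out, grad


print(grad_full(0.7))
print(grad_checkpointed(0.7, segment=3))
</code></pre></div>
<p>With <code>segment=3</code>, only the activations at layers 0, 3 and 6 survive the forward pass. Each segment is recomputed once during backward, which is the extra compute paid for the memory saved.</p>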
</section>
<section id="complete-training-script-template" class="level3">
<h3 class="anchored" data-anchor-id="complete-training-script-template" id="complete-training-script-template">Complete Training Script Template</h3>
<div id="128353d2" class="cell" data-execution_count="20">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb21"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb21-1"><a href="#cb21-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb21-2"><a href="#cb21-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb21-3"><a href="#cb21-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.utils.data <span class="im">import</span> DataLoader</span>
<span id="cb21-4"><a href="#cb21-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> lightning.fabric <span class="im">import</span> Fabric</span>
<span id="cb21-5"><a href="#cb21-5" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> lightning.fabric.utilities <span class="im">import</span> rank_zero_only</span>
<span id="cb21-6"><a href="#cb21-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-7"><a href="#cb21-7" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> create_model():</span>
<span id="cb21-8"><a href="#cb21-8" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> nn.Sequential(</span>
<span id="cb21-9"><a href="#cb21-9" aria-hidden="true" tabindex="-1"></a>        nn.Linear(<span class="dv">784</span>, <span class="dv">256</span>),</span>
<span id="cb21-10"><a href="#cb21-10" aria-hidden="true" tabindex="-1"></a>        nn.ReLU(),</span>
<span id="cb21-11"><a href="#cb21-11" aria-hidden="true" tabindex="-1"></a>        nn.Linear(<span class="dv">256</span>, <span class="dv">128</span>),</span>
<span id="cb21-12"><a href="#cb21-12" aria-hidden="true" tabindex="-1"></a>        nn.ReLU(),</span>
<span id="cb21-13"><a href="#cb21-13" aria-hidden="true" tabindex="-1"></a>        nn.Linear(<span class="dv">128</span>, <span class="dv">10</span>)</span>
<span id="cb21-14"><a href="#cb21-14" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb21-15"><a href="#cb21-15" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-16"><a href="#cb21-16" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> train_epoch(fabric, model, optimizer, dataloader, criterion):</span>
<span id="cb21-17"><a href="#cb21-17" aria-hidden="true" tabindex="-1"></a>    model.train()</span>
<span id="cb21-18"><a href="#cb21-18" aria-hidden="true" tabindex="-1"></a>    total_loss <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb21-19"><a href="#cb21-19" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb21-20"><a href="#cb21-20" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> data, target <span class="kw">in</span> dataloader:</span>
<span id="cb21-21"><a href="#cb21-21" aria-hidden="true" tabindex="-1"></a>        optimizer.zero_grad()</span>
<span id="cb21-22"><a href="#cb21-22" aria-hidden="true" tabindex="-1"></a>        output <span class="op">=</span> model(data)</span>
<span id="cb21-23"><a href="#cb21-23" aria-hidden="true" tabindex="-1"></a>        loss <span class="op">=</span> criterion(output, target)</span>
<span id="cb21-24"><a href="#cb21-24" aria-hidden="true" tabindex="-1"></a>        fabric.backward(loss)</span>
<span id="cb21-25"><a href="#cb21-25" aria-hidden="true" tabindex="-1"></a>        optimizer.step()</span>
<span id="cb21-26"><a href="#cb21-26" aria-hidden="true" tabindex="-1"></a>        total_loss <span class="op">+=</span> loss.item()</span>
<span id="cb21-27"><a href="#cb21-27" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb21-28"><a href="#cb21-28" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> total_loss <span class="op">/</span> <span class="bu">len</span>(dataloader)</span>
<span id="cb21-29"><a href="#cb21-29" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-30"><a href="#cb21-30" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> main():</span>
<span id="cb21-31"><a href="#cb21-31" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Initialize Fabric</span></span>
<span id="cb21-32"><a href="#cb21-32" aria-hidden="true" tabindex="-1"></a>    fabric <span class="op">=</span> Fabric(</span>
<span id="cb21-33"><a href="#cb21-33" aria-hidden="true" tabindex="-1"></a>        accelerator<span class="op">=</span><span class="st">"auto"</span>,</span>
<span id="cb21-34"><a href="#cb21-34" aria-hidden="true" tabindex="-1"></a>        strategy<span class="op">=</span><span class="st">"auto"</span>,</span>
<span id="cb21-35"><a href="#cb21-35" aria-hidden="true" tabindex="-1"></a>        devices<span class="op">=</span><span class="st">"auto"</span>,</span>
<span id="cb21-36"><a href="#cb21-36" aria-hidden="true" tabindex="-1"></a>        precision<span class="op">=</span><span class="st">"16-mixed"</span></span>
<span id="cb21-37"><a href="#cb21-37" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb21-38"><a href="#cb21-38" aria-hidden="true" tabindex="-1"></a>    fabric.launch()</span>
<span id="cb21-39"><a href="#cb21-39" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb21-40"><a href="#cb21-40" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Create model and optimizer (assumes `dataloader` is defined elsewhere)</span></span>
<span id="cb21-41"><a href="#cb21-41" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> create_model()</span>
<span id="cb21-42"><a href="#cb21-42" aria-hidden="true" tabindex="-1"></a>    optimizer <span class="op">=</span> torch.optim.Adam(model.parameters(), lr<span class="op">=</span><span class="fl">1e-3</span>)</span>
<span id="cb21-43"><a href="#cb21-43" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb21-44"><a href="#cb21-44" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Setup with Fabric</span></span>
<span id="cb21-45"><a href="#cb21-45" aria-hidden="true" tabindex="-1"></a>    model, optimizer <span class="op">=</span> fabric.setup(model, optimizer)</span>
<span id="cb21-46"><a href="#cb21-46" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb21-47"><a href="#cb21-47" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Training loop</span></span>
<span id="cb21-48"><a href="#cb21-48" aria-hidden="true" tabindex="-1"></a>    criterion <span class="op">=</span> nn.CrossEntropyLoss()</span>
<span id="cb21-49"><a href="#cb21-49" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">10</span>):</span>
<span id="cb21-50"><a href="#cb21-50" aria-hidden="true" tabindex="-1"></a>        avg_loss <span class="op">=</span> train_epoch(fabric, model, optimizer, dataloader, criterion)</span>
<span id="cb21-51"><a href="#cb21-51" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb21-52"><a href="#cb21-52" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> fabric.is_global_zero:</span>
<span id="cb21-53"><a href="#cb21-53" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f"Epoch </span><span class="sc">{</span>epoch<span class="sc">}</span><span class="ss">: Loss = </span><span class="sc">{</span>avg_loss<span class="sc">:.4f}</span><span class="ss">"</span>)</span>
<span id="cb21-54"><a href="#cb21-54" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-55"><a href="#cb21-55" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb21-56"><a href="#cb21-56" aria-hidden="true" tabindex="-1"></a>    main()</span></code></pre></div></div>
</div>
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>This guide covered the essential aspects of using Lightning Fabric for efficient PyTorch training. Fabric strikes a balance between control and convenience: it handles device placement, distributed strategies, and mixed precision, while leaving the training loop in your hands. That makes it a good fit for researchers and practitioners who want distributed training without giving up flexibility in their training loops.</p>
</section>
<section id="key-takeaways" class="level2">
<h2 class="anchored" data-anchor-id="key-takeaways" id="key-takeaways">Key Takeaways</h2>
<ul>
<li><strong>Lightweight</strong>: Fabric adds minimal overhead to your PyTorch code</li>
<li><strong>Flexible</strong>: Maintain full control over your training loop</li>
<li><strong>Scalable</strong>: Easy distributed training with DDP, FSDP, and other strategies</li>
<li><strong>Efficient</strong>: Built-in mixed precision (e.g. <code>precision="16-mixed"</code>) and other performance options configured on the <code>Fabric</code> object</li>
<li><strong>Compatible</strong>: Works with existing PyTorch code with minimal changes</li>
</ul>
<p>For more advanced use cases and the latest features, refer to the <a href="https://lightning.ai/docs/fabric/stable/">official Lightning Fabric documentation</a>.</p>


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[PyTorch Collate Function Speed-Up Guide]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/model-training/pytorch-collate-gains/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/model-training/pytorch-collate-gains/</guid>
      <pubDate>Sun, 01 Jun 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>tutorial</category>
      <category>advanced</category>
      <content:encoded><![CDATA[






<section id="pytorch-collate-function-speed-up-guide" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/model-training/pytorch-collate-gains/pytorch.jpg" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
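<p>In PyTorch, the collate function (the <code>collate_fn</code> argument of <code>DataLoader</code>) determines how the individual samples returned by a <code>Dataset</code> are combined into a batch. The default implementation handles arbitrary nested structures of tensors, numbers, and sequences generically; a custom function written for your specific data can cut per-batch Python overhead and avoid unnecessary copies, which speeds up training when data loading is the bottleneck.</p>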
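<p>To make the contract concrete: the <code>DataLoader</code> gathers <code>batch_size</code> samples from the <code>Dataset</code> and passes them to <code>collate_fn</code> as a Python list. The sketch below calls PyTorch's <code>default_collate</code> (importable from <code>torch.utils.data</code> in recent releases) directly on a small hand-made list of <code>(image, label)</code> tuples; the tensor sizes and label values are arbitrary illustrative choices.</p>

```python
import torch
from torch.utils.data import default_collate

# What collate_fn receives: a list of per-sample outputs of
# Dataset.__getitem__, here two (image, label) tuples.
samples = [(torch.randn(3, 4, 4), 7), (torch.randn(3, 4, 4), 2)]

images, labels = default_collate(samples)

# Tensors are stacked along a new leading batch dimension...
print(images.shape)  # torch.Size([2, 3, 4, 4])
# ...and Python ints are turned into a 1-D int64 tensor.
print(labels)        # tensor([7, 2])
```

<p>A custom <code>collate_fn</code> has to honor the same contract: take that list and return whatever batch structure your training loop expects.</p>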
</section>
<section id="default-vs-custom-collate-functions" class="level2">
<h2 class="anchored" data-anchor-id="default-vs-custom-collate-functions" id="default-vs-custom-collate-functions">Default vs Custom Collate Functions</h2>
<section id="default-behavior" class="level3">
<h3 class="anchored" data-anchor-id="default-behavior" id="default-behavior">Default Behavior</h3>
<div id="747c5105" class="cell" data-execution_count="1">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb1"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.utils.data <span class="im">import</span> DataLoader, Dataset</span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> SimpleDataset(Dataset):</span>
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, size<span class="op">=</span><span class="dv">1000</span>):</span>
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.data <span class="op">=</span> [torch.randn(<span class="dv">3</span>, <span class="dv">224</span>, <span class="dv">224</span>) <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(size)]</span>
<span id="cb1-9"><a href="#cb1-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.labels <span class="op">=</span> [torch.randint(<span class="dv">0</span>, <span class="dv">10</span>, (<span class="dv">1</span>,)).item() <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(size)]</span>
<span id="cb1-10"><a href="#cb1-10" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb1-11"><a href="#cb1-11" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__len__</span>(<span class="va">self</span>):</span>
<span id="cb1-12"><a href="#cb1-12" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="bu">len</span>(<span class="va">self</span>.data)</span>
<span id="cb1-13"><a href="#cb1-13" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb1-14"><a href="#cb1-14" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__getitem__</span>(<span class="va">self</span>, idx):</span>
<span id="cb1-15"><a href="#cb1-15" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.data[idx], <span class="va">self</span>.labels[idx]</span>
<span id="cb1-16"><a href="#cb1-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-17"><a href="#cb1-17" aria-hidden="true" tabindex="-1"></a><span class="co"># Using default collate function</span></span>
<span id="cb1-18"><a href="#cb1-18" aria-hidden="true" tabindex="-1"></a>dataset <span class="op">=</span> SimpleDataset(<span class="dv">1000</span>)</span>
<span id="cb1-19"><a href="#cb1-19" aria-hidden="true" tabindex="-1"></a>default_loader <span class="op">=</span> DataLoader(dataset, batch_size<span class="op">=</span><span class="dv">32</span>, shuffle<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb1-20"><a href="#cb1-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-21"><a href="#cb1-21" aria-hidden="true" tabindex="-1"></a><span class="co"># Timing default collate</span></span>
<span id="cb1-22"><a href="#cb1-22" aria-hidden="true" tabindex="-1"></a>start_time <span class="op">=</span> time.time()</span>
<span id="cb1-23"><a href="#cb1-23" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> batch_idx, (data, labels) <span class="kw">in</span> <span class="bu">enumerate</span>(default_loader):</span>
<span id="cb1-24"><a href="#cb1-24" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> batch_idx <span class="op">&gt;=</span> <span class="dv">10</span>:  <span class="co"># Test first 10 batches</span></span>
<span id="cb1-25"><a href="#cb1-25" aria-hidden="true" tabindex="-1"></a>        <span class="cf">break</span></span>
<span id="cb1-26"><a href="#cb1-26" aria-hidden="true" tabindex="-1"></a>default_time <span class="op">=</span> time.time() <span class="op">-</span> start_time</span>
<span id="cb1-27"><a href="#cb1-27" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Default collate time: </span><span class="sc">{</span>default_time<span class="sc">:.4f}</span><span class="ss"> seconds"</span>)</span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>Default collate time: 0.0136 seconds</code></pre>
</div>
</div>
</section>
<section id="custom-collate-function---basic-optimization" class="level3">
<h3 class="anchored" data-anchor-id="custom-collate-function---basic-optimization" id="custom-collate-function---basic-optimization">Custom Collate Function - Basic Optimization</h3>
<div id="883adc2d" class="cell" data-execution_count="2">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> fast_collate(batch):</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Optimized collate function for image data"""</span></span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Separate data and labels</span></span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a>    data, labels <span class="op">=</span> <span class="bu">zip</span>(<span class="op">*</span>batch)</span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Stack tensors directly (default_collate also uses torch.stack; this skips its type dispatch)</span></span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a>    data_tensor <span class="op">=</span> torch.stack(data, dim<span class="op">=</span><span class="dv">0</span>)</span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a>    labels_tensor <span class="op">=</span> torch.tensor(labels, dtype<span class="op">=</span>torch.<span class="bu">long</span>)</span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> data_tensor, labels_tensor</span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a><span class="co"># Using custom collate function</span></span>
<span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a>custom_loader <span class="op">=</span> DataLoader(dataset, batch_size<span class="op">=</span><span class="dv">32</span>, shuffle<span class="op">=</span><span class="va">True</span>, collate_fn<span class="op">=</span>fast_collate)</span>
<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-15"><a href="#cb3-15" aria-hidden="true" tabindex="-1"></a><span class="co"># Timing custom collate</span></span>
<span id="cb3-16"><a href="#cb3-16" aria-hidden="true" tabindex="-1"></a>start_time <span class="op">=</span> time.time()</span>
<span id="cb3-17"><a href="#cb3-17" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> batch_idx, (data, labels) <span class="kw">in</span> <span class="bu">enumerate</span>(custom_loader):</span>
<span id="cb3-18"><a href="#cb3-18" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> batch_idx <span class="op">&gt;=</span> <span class="dv">10</span>:</span>
<span id="cb3-19"><a href="#cb3-19" aria-hidden="true" tabindex="-1"></a>        <span class="cf">break</span></span>
<span id="cb3-20"><a href="#cb3-20" aria-hidden="true" tabindex="-1"></a>custom_time <span class="op">=</span> time.time() <span class="op">-</span> start_time</span>
<span id="cb3-21"><a href="#cb3-21" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Custom collate time: </span><span class="sc">{</span>custom_time<span class="sc">:.4f}</span><span class="ss"> seconds"</span>)</span>
<span id="cb3-22"><a href="#cb3-22" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Speed improvement: </span><span class="sc">{</span>(default_time<span class="op">/</span>custom_time <span class="op">-</span> <span class="dv">1</span>) <span class="op">*</span> <span class="dv">100</span><span class="sc">:.1f}</span><span class="ss">%"</span>)</span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>Custom collate time: 0.0101 seconds
Speed improvement: 35.0%</code></pre>
</div>
</div>
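<p>The loop-based timings above come from a single pass over roughly ten batches and also include sampling and indexing, so differences of a few milliseconds can easily be noise. Because <code>default_collate</code> itself uses <code>torch.stack</code> for tensor samples, any gain from <code>fast_collate</code> is likely to come from skipping generic type dispatch. A fairer comparison, sketched below under the assumption that both functions receive the same pre-built batch, calls each collate function directly and keeps the best of several repeated <code>timeit</code> runs. Absolute numbers will depend on your hardware and PyTorch version.</p>

```python
import timeit
import torch
from torch.utils.data import default_collate

def fast_collate(batch):
    # Same logic as fast_collate above: unzip, stack images, tensorize labels
    data, labels = zip(*batch)
    return torch.stack(data, dim=0), torch.tensor(labels, dtype=torch.long)

# One fixed batch of 32 (image, label) samples, reused for every call so that
# both functions do exactly the same work (no sampling or indexing involved).
torch.manual_seed(0)
batch = [(torch.randn(3, 224, 224), i % 10) for i in range(32)]

# Sanity check: both functions must build identical batches.
d_images, d_labels = default_collate(batch)
f_images, f_labels = fast_collate(batch)
assert torch.equal(d_images, f_images) and torch.equal(d_labels, f_labels)

# Time many calls per run, repeat the run, and keep the fastest to reduce noise.
number, repeat = 20, 5
t_default = min(timeit.repeat(lambda: default_collate(batch), number=number, repeat=repeat)) / number
t_fast = min(timeit.repeat(lambda: fast_collate(batch), number=number, repeat=repeat)) / number

print(f"default_collate: {t_default * 1e3:.3f} ms per batch")
print(f"fast_collate:    {t_fast * 1e3:.3f} ms per batch")
```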
</section>
</section>
<section id="advanced-optimizations" class="level2">
<h2 class="anchored" data-anchor-id="advanced-optimizations" id="advanced-optimizations">Advanced Optimizations</h2>
<section id="memory-efficient-collate-for-variable-length-sequences" class="level3">
<h3 class="anchored" data-anchor-id="memory-efficient-collate-for-variable-length-sequences" id="memory-efficient-collate-for-variable-length-sequences">Memory-Efficient Collate for Variable-Length Sequences</h3>
<div id="381946d8" class="cell" data-execution_count="3">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn.utils.rnn <span class="im">as</span> rnn_utils</span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> VariableLengthDataset(Dataset):</span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, size<span class="op">=</span><span class="dv">1000</span>):</span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Simulate variable-length sequences</span></span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.data <span class="op">=</span> [torch.randn(np.random.randint(<span class="dv">10</span>, <span class="dv">100</span>), <span class="dv">128</span>) <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(size)]</span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.labels <span class="op">=</span> [torch.randint(<span class="dv">0</span>, <span class="dv">5</span>, (<span class="dv">1</span>,)).item() <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(size)]</span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__len__</span>(<span class="va">self</span>):</span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="bu">len</span>(<span class="va">self</span>.data)</span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__getitem__</span>(<span class="va">self</span>, idx):</span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.data[idx], <span class="va">self</span>.labels[idx]</span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> efficient_variable_collate(batch):</span>
<span id="cb5-16"><a href="#cb5-16" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Efficient collate for variable-length sequences"""</span></span>
<span id="cb5-17"><a href="#cb5-17" aria-hidden="true" tabindex="-1"></a>    data, labels <span class="op">=</span> <span class="bu">zip</span>(<span class="op">*</span>batch)</span>
<span id="cb5-18"><a href="#cb5-18" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-19"><a href="#cb5-19" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Get sequence lengths for efficient packing</span></span>
<span id="cb5-20"><a href="#cb5-20" aria-hidden="true" tabindex="-1"></a>    lengths <span class="op">=</span> torch.tensor([<span class="bu">len</span>(seq) <span class="cf">for</span> seq <span class="kw">in</span> data])</span>
<span id="cb5-21"><a href="#cb5-21" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-22"><a href="#cb5-22" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Pad sequences efficiently</span></span>
<span id="cb5-23"><a href="#cb5-23" aria-hidden="true" tabindex="-1"></a>    padded_data <span class="op">=</span> rnn_utils.pad_sequence(data, batch_first<span class="op">=</span><span class="va">True</span>, padding_value<span class="op">=</span><span class="dv">0</span>)</span>
<span id="cb5-24"><a href="#cb5-24" aria-hidden="true" tabindex="-1"></a>    labels_tensor <span class="op">=</span> torch.tensor(labels, dtype<span class="op">=</span>torch.<span class="bu">long</span>)</span>
<span id="cb5-25"><a href="#cb5-25" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-26"><a href="#cb5-26" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> padded_data, labels_tensor, lengths</span>
<span id="cb5-27"><a href="#cb5-27" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-28"><a href="#cb5-28" aria-hidden="true" tabindex="-1"></a><span class="co"># Performance comparison</span></span>
<span id="cb5-29"><a href="#cb5-29" aria-hidden="true" tabindex="-1"></a>var_dataset <span class="op">=</span> VariableLengthDataset(<span class="dv">500</span>)</span>
<span id="cb5-30"><a href="#cb5-30" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-31"><a href="#cb5-31" aria-hidden="true" tabindex="-1"></a><span class="co"># default_collate fails on unequal lengths (torch.stack needs equal sizes), so use naive manual padding as the baseline</span></span>
<span id="cb5-32"><a href="#cb5-32" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> naive_variable_collate(batch):</span>
<span id="cb5-33"><a href="#cb5-33" aria-hidden="true" tabindex="-1"></a>    data, labels <span class="op">=</span> <span class="bu">zip</span>(<span class="op">*</span>batch)</span>
<span id="cb5-34"><a href="#cb5-34" aria-hidden="true" tabindex="-1"></a>    max_len <span class="op">=</span> <span class="bu">max</span>(<span class="bu">len</span>(seq) <span class="cf">for</span> seq <span class="kw">in</span> data)</span>
<span id="cb5-35"><a href="#cb5-35" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-36"><a href="#cb5-36" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Inefficient padding</span></span>
<span id="cb5-37"><a href="#cb5-37" aria-hidden="true" tabindex="-1"></a>    padded_data <span class="op">=</span> []</span>
<span id="cb5-38"><a href="#cb5-38" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> seq <span class="kw">in</span> data:</span>
<span id="cb5-39"><a href="#cb5-39" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="bu">len</span>(seq) <span class="op">&lt;</span> max_len:</span>
<span id="cb5-40"><a href="#cb5-40" aria-hidden="true" tabindex="-1"></a>            padded_seq <span class="op">=</span> torch.cat([seq, torch.zeros(max_len <span class="op">-</span> <span class="bu">len</span>(seq), seq.size(<span class="dv">1</span>))])</span>
<span id="cb5-41"><a href="#cb5-41" aria-hidden="true" tabindex="-1"></a>        <span class="cf">else</span>:</span>
<span id="cb5-42"><a href="#cb5-42" aria-hidden="true" tabindex="-1"></a>            padded_seq <span class="op">=</span> seq</span>
<span id="cb5-43"><a href="#cb5-43" aria-hidden="true" tabindex="-1"></a>        padded_data.append(padded_seq)</span>
<span id="cb5-44"><a href="#cb5-44" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-45"><a href="#cb5-45" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> torch.stack(padded_data), torch.tensor(labels, dtype<span class="op">=</span>torch.<span class="bu">long</span>)</span>
<span id="cb5-46"><a href="#cb5-46" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-47"><a href="#cb5-47" aria-hidden="true" tabindex="-1"></a><span class="co"># Timing comparison</span></span>
<span id="cb5-48"><a href="#cb5-48" aria-hidden="true" tabindex="-1"></a>naive_loader <span class="op">=</span> DataLoader(var_dataset, batch_size<span class="op">=</span><span class="dv">16</span>, collate_fn<span class="op">=</span>naive_variable_collate)</span>
<span id="cb5-49"><a href="#cb5-49" aria-hidden="true" tabindex="-1"></a>efficient_loader <span class="op">=</span> DataLoader(var_dataset, batch_size<span class="op">=</span><span class="dv">16</span>, collate_fn<span class="op">=</span>efficient_variable_collate)</span>
<span id="cb5-50"><a href="#cb5-50" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-51"><a href="#cb5-51" aria-hidden="true" tabindex="-1"></a><span class="co"># Naive approach timing</span></span>
<span id="cb5-52"><a href="#cb5-52" aria-hidden="true" tabindex="-1"></a>start_time <span class="op">=</span> time.time()</span>
<span id="cb5-53"><a href="#cb5-53" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> batch_idx, batch <span class="kw">in</span> <span class="bu">enumerate</span>(naive_loader):</span>
<span id="cb5-54"><a href="#cb5-54" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> batch_idx <span class="op">&gt;=</span> <span class="dv">10</span>:</span>
<span id="cb5-55"><a href="#cb5-55" aria-hidden="true" tabindex="-1"></a>        <span class="cf">break</span></span>
<span id="cb5-56"><a href="#cb5-56" aria-hidden="true" tabindex="-1"></a>naive_time <span class="op">=</span> time.time() <span class="op">-</span> start_time</span>
<span id="cb5-57"><a href="#cb5-57" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-58"><a href="#cb5-58" aria-hidden="true" tabindex="-1"></a><span class="co"># Efficient approach timing</span></span>
<span id="cb5-59"><a href="#cb5-59" aria-hidden="true" tabindex="-1"></a>start_time <span class="op">=</span> time.time()</span>
<span id="cb5-60"><a href="#cb5-60" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> batch_idx, batch <span class="kw">in</span> <span class="bu">enumerate</span>(efficient_loader):</span>
<span id="cb5-61"><a href="#cb5-61" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> batch_idx <span class="op">&gt;=</span> <span class="dv">10</span>:</span>
<span id="cb5-62"><a href="#cb5-62" aria-hidden="true" tabindex="-1"></a>        <span class="cf">break</span></span>
<span id="cb5-63"><a href="#cb5-63" aria-hidden="true" tabindex="-1"></a>efficient_time <span class="op">=</span> time.time() <span class="op">-</span> start_time</span>
<span id="cb5-64"><a href="#cb5-64" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-65"><a href="#cb5-65" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Naive variable collate time: </span><span class="sc">{</span>naive_time<span class="sc">:.4f}</span><span class="ss"> seconds"</span>)</span>
<span id="cb5-66"><a href="#cb5-66" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Efficient variable collate time: </span><span class="sc">{</span>efficient_time<span class="sc">:.4f}</span><span class="ss"> seconds"</span>)</span>
<span id="cb5-67"><a href="#cb5-67" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Speed improvement: </span><span class="sc">{</span>(naive_time<span class="op">/</span>efficient_time <span class="op">-</span> <span class="dv">1</span>) <span class="op">*</span> <span class="dv">100</span><span class="sc">:.1f}</span><span class="ss">%"</span>)</span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>Naive variable collate time: 0.0031 seconds
Efficient variable collate time: 0.0024 seconds
Speed improvement: 26.3%</code></pre>
</div>
</div>
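<p>The <code>lengths</code> tensor returned by <code>efficient_variable_collate</code> is what makes packing possible downstream. A minimal sketch, assuming an LSTM consumer (the toy batch and layer sizes are hypothetical): <code>pack_padded_sequence</code> uses the true lengths so the recurrent layer skips padded time steps, and <code>enforce_sorted=False</code> avoids having to sort the batch by length inside the collate function.</p>

```python
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils

def efficient_variable_collate(batch):
    # Same logic as above: pad to the longest sequence, keep the true lengths
    data, labels = zip(*batch)
    lengths = torch.tensor([len(seq) for seq in data])
    padded = rnn_utils.pad_sequence(data, batch_first=True, padding_value=0)
    return padded, torch.tensor(labels, dtype=torch.long), lengths

# Hypothetical toy batch: three sequences of 128-dim features
batch = [(torch.randn(n, 128), i) for i, n in enumerate([12, 30, 7])]
padded, labels, lengths = efficient_variable_collate(batch)

lstm = nn.LSTM(input_size=128, hidden_size=64, batch_first=True)

# Packing tells the LSTM to ignore padded time steps; enforce_sorted=False
# keeps the batch in its original (unsorted) order.
packed = rnn_utils.pack_padded_sequence(padded, lengths, batch_first=True, enforce_sorted=False)
_, (h_n, _) = lstm(packed)

# h_n[-1] holds each sequence's hidden state at its last real time step
print(padded.shape)   # torch.Size([3, 30, 128])
print(h_n[-1].shape)  # torch.Size([3, 64])
```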
</section>
<section id="gpu-accelerated-collate-function" class="level3">
<h3 class="anchored" data-anchor-id="gpu-accelerated-collate-function" id="gpu-accelerated-collate-function">GPU-Accelerated Collate Function</h3>
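<p>Moving tensors to the GPU inside <code>collate_fn</code> is only safe when batching happens in the main process (<code>num_workers=0</code>): PyTorch's documentation advises against returning CUDA tensors from multi-process data loading, and CUDA cannot be re-initialized in a forked worker once the parent process has initialized it. For comparison, the conventional pattern keeps collation on the CPU, enables <code>pin_memory=True</code>, and issues a <code>non_blocking=True</code> copy in the training loop. The sketch below uses a hypothetical <code>TensorDataset</code> and falls back to the CPU when no GPU is present.</p>

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical dataset: 256 small images with integer class labels
dataset = TensorDataset(torch.randn(256, 3, 32, 32), torch.randint(0, 10, (256,)))

loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    # Collation stays on the CPU; in a real script you would typically raise
    # num_workers (inside an `if __name__ == "__main__":` guard).
    num_workers=0,
    # Page-locked (pinned) host memory is what lets non_blocking copies overlap
    pin_memory=torch.cuda.is_available(),
)

n_batches = 0
for data, labels in loader:
    # Asynchronous host-to-device copy when the batch is pinned
    data = data.to(device, non_blocking=True)
    labels = labels.to(device, non_blocking=True)
    n_batches += 1

print(n_batches)  # 8
```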
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> gpu_accelerated_collate(batch, device<span class="op">=</span><span class="st">'cuda'</span>):</span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Stack a batch and move it to the device (only safe with num_workers=0)"""</span></span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> <span class="kw">not</span> torch.cuda.is_available():</span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a>        device <span class="op">=</span> <span class="st">'cpu'</span></span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a>    data, labels <span class="op">=</span> <span class="bu">zip</span>(<span class="op">*</span>batch)</span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Stack and move to GPU in one operation</span></span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a>    data_tensor <span class="op">=</span> torch.stack(data, dim<span class="op">=</span><span class="dv">0</span>).to(device, non_blocking<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb7-10"><a href="#cb7-10" aria-hidden="true" tabindex="-1"></a>    labels_tensor <span class="op">=</span> torch.tensor(labels, dtype<span class="op">=</span>torch.<span class="bu">long</span>).to(device, non_blocking<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb7-11"><a href="#cb7-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-12"><a href="#cb7-12" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> data_tensor, labels_tensor</span>
<span id="cb7-13"><a href="#cb7-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-14"><a href="#cb7-14" aria-hidden="true" tabindex="-1"></a><span class="co"># Performance comparison with GPU transfer</span></span>
<span id="cb7-15"><a href="#cb7-15" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> torch.cuda.is_available():</span>
<span id="cb7-16"><a href="#cb7-16" aria-hidden="true" tabindex="-1"></a>    device <span class="op">=</span> <span class="st">'cuda'</span></span>
<span id="cb7-17"><a href="#cb7-17" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-18"><a href="#cb7-18" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Standard approach: CPU collate + GPU transfer</span></span>
<span id="cb7-19"><a href="#cb7-19" aria-hidden="true" tabindex="-1"></a>    standard_loader <span class="op">=</span> DataLoader(dataset, batch_size<span class="op">=</span><span class="dv">32</span>, shuffle<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb7-20"><a href="#cb7-20" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-21"><a href="#cb7-21" aria-hidden="true" tabindex="-1"></a>    <span class="co"># GPU-accelerated collate; keep num_workers=0 (the default), because CUDA</span></span>
<span id="cb7-22"><a href="#cb7-22" aria-hidden="true" tabindex="-1"></a>    <span class="co"># cannot be re-initialized inside forked DataLoader worker processes</span></span>
<span id="cb7-23"><a href="#cb7-23" aria-hidden="true" tabindex="-1"></a>    gpu_loader <span class="op">=</span> DataLoader(dataset, batch_size<span class="op">=</span><span class="dv">32</span>, shuffle<span class="op">=</span><span class="va">True</span>,</span>
<span id="cb7-24"><a href="#cb7-24" aria-hidden="true" tabindex="-1"></a>                           collate_fn<span class="op">=</span><span class="kw">lambda</span> batch: gpu_accelerated_collate(batch, device))</span>
<span id="cb7-25"><a href="#cb7-25" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-26"><a href="#cb7-26" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Timing standard approach. CUDA work is asynchronous, so synchronize</span></span>
<span id="cb7-27"><a href="#cb7-27" aria-hidden="true" tabindex="-1"></a>    <span class="co"># before starting and before stopping the clock</span></span>
<span id="cb7-28"><a href="#cb7-28" aria-hidden="true" tabindex="-1"></a>    torch.cuda.synchronize()</span>
<span id="cb7-29"><a href="#cb7-29" aria-hidden="true" tabindex="-1"></a>    start_time <span class="op">=</span> time.perf_counter()</span>
<span id="cb7-30"><a href="#cb7-30" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> batch_idx, (data, labels) <span class="kw">in</span> <span class="bu">enumerate</span>(standard_loader):</span>
<span id="cb7-31"><a href="#cb7-31" aria-hidden="true" tabindex="-1"></a>        data, labels <span class="op">=</span> data.to(device), labels.to(device)</span>
<span id="cb7-32"><a href="#cb7-32" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> batch_idx <span class="op">&gt;=</span> <span class="dv">10</span>:</span>
<span id="cb7-33"><a href="#cb7-33" aria-hidden="true" tabindex="-1"></a>            <span class="cf">break</span></span>
<span id="cb7-34"><a href="#cb7-34" aria-hidden="true" tabindex="-1"></a>    torch.cuda.synchronize()</span>
<span id="cb7-35"><a href="#cb7-35" aria-hidden="true" tabindex="-1"></a>    standard_gpu_time <span class="op">=</span> time.perf_counter() <span class="op">-</span> start_time</span>
<span id="cb7-36"><a href="#cb7-36" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-37"><a href="#cb7-37" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Timing GPU-accelerated collate</span></span>
<span id="cb7-38"><a href="#cb7-38" aria-hidden="true" tabindex="-1"></a>    torch.cuda.synchronize()</span>
<span id="cb7-39"><a href="#cb7-39" aria-hidden="true" tabindex="-1"></a>    start_time <span class="op">=</span> time.perf_counter()</span>
<span id="cb7-40"><a href="#cb7-40" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> batch_idx, (data, labels) <span class="kw">in</span> <span class="bu">enumerate</span>(gpu_loader):</span>
<span id="cb7-41"><a href="#cb7-41" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> batch_idx <span class="op">&gt;=</span> <span class="dv">10</span>:</span>
<span id="cb7-42"><a href="#cb7-42" aria-hidden="true" tabindex="-1"></a>            <span class="cf">break</span></span>
<span id="cb7-43"><a href="#cb7-43" aria-hidden="true" tabindex="-1"></a>    torch.cuda.synchronize()</span>
<span id="cb7-44"><a href="#cb7-44" aria-hidden="true" tabindex="-1"></a>    gpu_collate_time <span class="op">=</span> time.perf_counter() <span class="op">-</span> start_time</span>
<span id="cb7-45"><a href="#cb7-45" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-46"><a href="#cb7-46" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Standard CPU-&gt;GPU time: </span><span class="sc">{</span>standard_gpu_time<span class="sc">:.4f}</span><span class="ss"> seconds"</span>)</span>
<span id="cb7-47"><a href="#cb7-47" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"GPU-accelerated collate time: </span><span class="sc">{</span>gpu_collate_time<span class="sc">:.4f}</span><span class="ss"> seconds"</span>)</span>
<span id="cb7-48"><a href="#cb7-48" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Speed improvement: </span><span class="sc">{</span>(standard_gpu_time<span class="op">/</span>gpu_collate_time <span class="op">-</span> <span class="dv">1</span>) <span class="op">*</span> <span class="dv">100</span><span class="sc">:.1f}</span><span class="ss">%"</span>)</span></code></pre></div></div>
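<p>A single pass like this one is sensitive to one-off costs such as CUDA context creation. The sketch below shows one way to make the measurement more repeatable. <code>time_loader</code>, <code>demo_dataset</code>, and the tensor sizes are illustrative assumptions and do not come from the original code. The helper consumes a few untimed warm-up batches and synchronizes around the timed region. It also uses <code>DataLoader(pin_memory=True)</code> with <code>non_blocking=True</code> copies, which is a common CPU-collate baseline that a GPU-side collate should be measured against.</p>

```python
import time

import torch
from torch.utils.data import DataLoader, TensorDataset


def time_loader(loader, device, num_batches=10, warmup=2):
    """Time `num_batches` batches, including the host-to-device copy.

    The first `warmup` batches are consumed untimed to exclude one-off costs.
    """
    batches = iter(loader)
    for _ in range(warmup):
        next(batches)
    if device.type == 'cuda':
        torch.cuda.synchronize(device)
    start = time.perf_counter()
    for _ in range(num_batches):
        data, labels = next(batches)
        # non_blocking lets the copy overlap other work only when the source is pinned
        data = data.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
    if device.type == 'cuda':
        torch.cuda.synchronize(device)  # wait for queued copies before stopping the clock
    return time.perf_counter() - start


# Hypothetical synthetic data, only for illustration
demo_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
demo_dataset = TensorDataset(torch.randn(512, 3, 64, 64), torch.randint(0, 10, (512,)))
pinned_loader = DataLoader(demo_dataset, batch_size=32, shuffle=True,
                           pin_memory=(demo_device.type == 'cuda'))
elapsed = time_loader(pinned_loader, demo_device)
print(f"10 batches incl. transfer: {elapsed:.4f} s")
```

<p>On a machine without CUDA, the helper falls back to the CPU. There, <code>.to(device)</code> is a no-op, so the number reflects only loading and collation.</p>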
</section>
<section id="memory-mapped-file-collate-for-large-datasets" class="level3">
<h3 class="anchored" data-anchor-id="memory-mapped-file-collate-for-large-datasets" id="memory-mapped-file-collate-for-large-datasets">Memory-Mapped File Collate for Large Datasets</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb8"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> os</span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> tempfile</span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MemoryMappedDataset(Dataset):</span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Dataset backed by memory-mapped arrays: samples are read from disk on access"""</span></span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, data_array, labels_array):</span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.data <span class="op">=</span> data_array  <span class="co"># e.g. np.load(path, mmap_mode='r')</span></span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.labels <span class="op">=</span> labels_array</span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-12"><a href="#cb8-12" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__len__</span>(<span class="va">self</span>):</span>
<span id="cb8-13"><a href="#cb8-13" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="bu">len</span>(<span class="va">self</span>.labels)</span>
<span id="cb8-14"><a href="#cb8-14" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-15"><a href="#cb8-15" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__getitem__</span>(<span class="va">self</span>, idx):</span>
<span id="cb8-16"><a href="#cb8-16" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Return a read-only view into the memory map; nothing is copied yet</span></span>
<span id="cb8-17"><a href="#cb8-17" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.data[idx], <span class="va">self</span>.labels[idx]</span>
<span id="cb8-18"><a href="#cb8-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-19"><a href="#cb8-19" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> zero_copy_collate(batch):</span>
<span id="cb8-20"><a href="#cb8-20" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Collate numpy samples with a single copy into one contiguous batch"""</span></span>
<span id="cb8-21"><a href="#cb8-21" aria-hidden="true" tabindex="-1"></a>    data, labels <span class="op">=</span> <span class="bu">zip</span>(<span class="op">*</span>batch)</span>
<span id="cb8-22"><a href="#cb8-22" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-23"><a href="#cb8-23" aria-hidden="true" tabindex="-1"></a>    <span class="co"># np.stack performs the only copy (this is where pages are read from disk);</span></span>
<span id="cb8-24"><a href="#cb8-24" aria-hidden="true" tabindex="-1"></a>    <span class="co"># torch.from_numpy then wraps the result without copying it again</span></span>
<span id="cb8-25"><a href="#cb8-25" aria-hidden="true" tabindex="-1"></a>    data_tensor <span class="op">=</span> torch.from_numpy(np.ascontiguousarray(np.stack(data)))</span>
<span id="cb8-26"><a href="#cb8-26" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-27"><a href="#cb8-27" aria-hidden="true" tabindex="-1"></a>    labels_tensor <span class="op">=</span> torch.tensor(labels, dtype<span class="op">=</span>torch.<span class="bu">long</span>)</span>
<span id="cb8-28"><a href="#cb8-28" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> data_tensor, labels_tensor</span>
<span id="cb8-29"><a href="#cb8-29" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-30"><a href="#cb8-30" aria-hidden="true" tabindex="-1"></a><span class="co"># Create sample data on disk, then reopen it as a read-only memory map</span></span>
<span id="cb8-31"><a href="#cb8-31" aria-hidden="true" tabindex="-1"></a>data_path <span class="op">=</span> os.path.join(tempfile.mkdtemp(), <span class="st">'sample_data.npy'</span>)</span>
<span id="cb8-32"><a href="#cb8-32" aria-hidden="true" tabindex="-1"></a>np.save(data_path, np.random.randn(<span class="dv">1000</span>, <span class="dv">3</span>, <span class="dv">224</span>, <span class="dv">224</span>).astype(np.float32))</span>
<span id="cb8-33"><a href="#cb8-33" aria-hidden="true" tabindex="-1"></a>sample_data <span class="op">=</span> np.load(data_path, mmap_mode<span class="op">=</span><span class="st">'r'</span>)  <span class="co"># no pixel data is read into RAM yet</span></span>
<span id="cb8-34"><a href="#cb8-34" aria-hidden="true" tabindex="-1"></a>sample_labels <span class="op">=</span> np.random.randint(<span class="dv">0</span>, <span class="dv">10</span>, <span class="dv">1000</span>)</span>
<span id="cb8-35"><a href="#cb8-35" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-36"><a href="#cb8-36" aria-hidden="true" tabindex="-1"></a>mmap_dataset <span class="op">=</span> MemoryMappedDataset(sample_data, sample_labels)</span>
<span id="cb8-37"><a href="#cb8-37" aria-hidden="true" tabindex="-1"></a>mmap_loader <span class="op">=</span> DataLoader(mmap_dataset, batch_size<span class="op">=</span><span class="dv">32</span>, collate_fn<span class="op">=</span>zero_copy_collate)</span></code></pre></div></div>
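<p>Using <code>num_workers &gt; 0</code> with memory-mapped data has one caveat. When workers are not forked (for example under the "spawn" start method, the default on Windows and macOS), the dataset object is pickled and sent to every worker. Pickling a NumPy memory map serializes the array's contents, not a reference to the file. A common workaround is to store only the file path and let each process open its own map on first access. The sketch below shows this as a hypothetical <code>LazyMemmapDataset</code>. The demo file and its shape are illustrative assumptions.</p>

```python
import os
import pickle
import tempfile

import numpy as np
from torch.utils.data import Dataset


class LazyMemmapDataset(Dataset):
    """Stores only the file path; each process opens its own memory map on first access."""

    def __init__(self, data_path, labels):
        self.data_path = data_path
        self.labels = labels
        self._data = None  # opened lazily, once per process

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        if self._data is None:
            self._data = np.load(self.data_path, mmap_mode='r')
        return self._data[idx], self.labels[idx]

    def __getstate__(self):
        # Drop the open map when pickled, so workers never receive the array contents
        state = self.__dict__.copy()
        state['_data'] = None
        return state


# Tiny demo file (shape chosen only for illustration)
demo_path = os.path.join(tempfile.mkdtemp(), 'demo.npy')
np.save(demo_path, np.arange(40, dtype=np.float32).reshape(10, 4))

lazy_ds = LazyMemmapDataset(demo_path, np.arange(10))
x, y = lazy_ds[3]
print(x, y)  # [12. 13. 14. 15.] 3
restored = pickle.loads(pickle.dumps(lazy_ds))
print(restored._data is None)  # True
```

<p>When workers are created by forking, they inherit the parent's mapping, so <code>MemoryMappedDataset</code> already shares pages. The lazy variant simply behaves the same way regardless of the start method.</p>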
</section>
</section>
<section id="specialized-collate-functions" class="level2">
<h2 class="anchored" data-anchor-id="specialized-collate-functions" id="specialized-collate-functions">Specialized Collate Functions</h2>
<section id="multi-modal-data-collate" class="level3">
<h3 class="anchored" data-anchor-id="multi-modal-data-collate" id="multi-modal-data-collate">Multi-Modal Data Collate</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb9"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MultiModalDataset(Dataset):</span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, size<span class="op">=</span><span class="dv">1000</span>):</span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.images <span class="op">=</span> [torch.randn(<span class="dv">3</span>, <span class="dv">224</span>, <span class="dv">224</span>) <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(size)]</span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.text <span class="op">=</span> [torch.randint(<span class="dv">1</span>, <span class="dv">1000</span>, (np.random.randint(<span class="dv">5</span>, <span class="dv">50</span>),)) <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(size)]  <span class="co"># id 0 is reserved for padding</span></span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.labels <span class="op">=</span> [torch.randint(<span class="dv">0</span>, <span class="dv">10</span>, (<span class="dv">1</span>,)).item() <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(size)]</span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__len__</span>(<span class="va">self</span>):</span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="bu">len</span>(<span class="va">self</span>.labels)</span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__getitem__</span>(<span class="va">self</span>, idx):</span>
<span id="cb9-11"><a href="#cb9-11" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {</span>
<span id="cb9-12"><a href="#cb9-12" aria-hidden="true" tabindex="-1"></a>            <span class="st">'image'</span>: <span class="va">self</span>.images[idx],</span>
<span id="cb9-13"><a href="#cb9-13" aria-hidden="true" tabindex="-1"></a>            <span class="st">'text'</span>: <span class="va">self</span>.text[idx],</span>
<span id="cb9-14"><a href="#cb9-14" aria-hidden="true" tabindex="-1"></a>            <span class="st">'label'</span>: <span class="va">self</span>.labels[idx]</span>
<span id="cb9-15"><a href="#cb9-15" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb9-16"><a href="#cb9-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-17"><a href="#cb9-17" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> multimodal_collate(batch):</span>
<span id="cb9-18"><a href="#cb9-18" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Efficient collate for multi-modal data"""</span></span>
<span id="cb9-19"><a href="#cb9-19" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Separate different modalities</span></span>
<span id="cb9-20"><a href="#cb9-20" aria-hidden="true" tabindex="-1"></a>    images <span class="op">=</span> [item[<span class="st">'image'</span>] <span class="cf">for</span> item <span class="kw">in</span> batch]</span>
<span id="cb9-21"><a href="#cb9-21" aria-hidden="true" tabindex="-1"></a>    texts <span class="op">=</span> [item[<span class="st">'text'</span>] <span class="cf">for</span> item <span class="kw">in</span> batch]</span>
<span id="cb9-22"><a href="#cb9-22" aria-hidden="true" tabindex="-1"></a>    labels <span class="op">=</span> [item[<span class="st">'label'</span>] <span class="cf">for</span> item <span class="kw">in</span> batch]</span>
<span id="cb9-23"><a href="#cb9-23" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-24"><a href="#cb9-24" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Batch images</span></span>
<span id="cb9-25"><a href="#cb9-25" aria-hidden="true" tabindex="-1"></a>    image_batch <span class="op">=</span> torch.stack(images)</span>
<span id="cb9-26"><a href="#cb9-26" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-27"><a href="#cb9-27" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Batch variable-length text with padding</span></span>
<span id="cb9-28"><a href="#cb9-28" aria-hidden="true" tabindex="-1"></a>    text_lengths <span class="op">=</span> torch.tensor([<span class="bu">len</span>(text) <span class="cf">for</span> text <span class="kw">in</span> texts])</span>
<span id="cb9-29"><a href="#cb9-29" aria-hidden="true" tabindex="-1"></a>    text_batch <span class="op">=</span> rnn_utils.pad_sequence(texts, batch_first<span class="op">=</span><span class="va">True</span>, padding_value<span class="op">=</span><span class="dv">0</span>)</span>
<span id="cb9-30"><a href="#cb9-30" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-31"><a href="#cb9-31" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Batch labels</span></span>
<span id="cb9-32"><a href="#cb9-32" aria-hidden="true" tabindex="-1"></a>    label_batch <span class="op">=</span> torch.tensor(labels, dtype<span class="op">=</span>torch.<span class="bu">long</span>)</span>
<span id="cb9-33"><a href="#cb9-33" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-34"><a href="#cb9-34" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> {</span>
<span id="cb9-35"><a href="#cb9-35" aria-hidden="true" tabindex="-1"></a>        <span class="st">'images'</span>: image_batch,</span>
<span id="cb9-36"><a href="#cb9-36" aria-hidden="true" tabindex="-1"></a>        <span class="st">'texts'</span>: text_batch,</span>
<span id="cb9-37"><a href="#cb9-37" aria-hidden="true" tabindex="-1"></a>        <span class="st">'text_lengths'</span>: text_lengths,</span>
<span id="cb9-38"><a href="#cb9-38" aria-hidden="true" tabindex="-1"></a>        <span class="st">'labels'</span>: label_batch</span>
<span id="cb9-39"><a href="#cb9-39" aria-hidden="true" tabindex="-1"></a>    }</span>
<span id="cb9-40"><a href="#cb9-40" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-41"><a href="#cb9-41" aria-hidden="true" tabindex="-1"></a>multimodal_dataset <span class="op">=</span> MultiModalDataset(<span class="dv">500</span>)</span>
<span id="cb9-42"><a href="#cb9-42" aria-hidden="true" tabindex="-1"></a>multimodal_loader <span class="op">=</span> DataLoader(multimodal_dataset, batch_size<span class="op">=</span><span class="dv">16</span>, collate_fn<span class="op">=</span>multimodal_collate)</span>
<span id="cb9-43"><a href="#cb9-43" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-44"><a href="#cb9-44" aria-hidden="true" tabindex="-1"></a><span class="co"># Test the multimodal loader</span></span>
<span id="cb9-45"><a href="#cb9-45" aria-hidden="true" tabindex="-1"></a>sample_batch <span class="op">=</span> <span class="bu">next</span>(<span class="bu">iter</span>(multimodal_loader))</span>
<span id="cb9-46"><a href="#cb9-46" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"Multimodal batch shapes:"</span>)</span>
<span id="cb9-47"><a href="#cb9-47" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> key, value <span class="kw">in</span> sample_batch.items():</span>
<span id="cb9-48"><a href="#cb9-48" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"  </span><span class="sc">{</span>key<span class="sc">}</span><span class="ss">: </span><span class="sc">{</span>value<span class="sc">.</span>shape<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
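<p>The <code>text_lengths</code> entry lets a model separate real tokens from padding. The minimal sketch below uses a hypothetical <code>padding_mask</code> helper. It turns lengths into a boolean mask by broadcasting a position index against each sequence's length. The toy token ids are illustrative.</p>

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def padding_mask(lengths, max_len=None):
    """Boolean mask of shape (batch, max_len): True for real tokens, False for padding."""
    if max_len is None:
        max_len = int(lengths.max())
    positions = torch.arange(max_len, device=lengths.device)  # (max_len,)
    return positions.unsqueeze(0) < lengths.unsqueeze(1)      # broadcasts to (batch, max_len)


demo_texts = [torch.tensor([5, 6, 7]), torch.tensor([8]), torch.tensor([9, 10])]
demo_lengths = torch.tensor([len(t) for t in demo_texts])
padded = pad_sequence(demo_texts, batch_first=True, padding_value=0)
mask = padding_mask(demo_lengths, padded.size(1))
print(padded.tolist())  # [[5, 6, 7], [8, 0, 0], [9, 10, 0]]
print(mask.tolist())    # [[True, True, True], [True, False, False], [True, True, False]]
```

<p>Conventions differ between consumers. <code>nn.MultiheadAttention</code> treats <code>True</code> in <code>key_padding_mask</code> as "ignore this position", so it would receive <code>~mask</code>. <code>pack_padded_sequence</code> takes the lengths directly and expects them on the CPU.</p>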
</section>
<section id="augmentation-aware-collate" class="level3">
<h3 class="anchored" data-anchor-id="augmentation-aware-collate" id="augmentation-aware-collate">Augmentation-Aware Collate</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb10"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> functools <span class="im">import</span> partial</span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torchvision.transforms <span class="im">as</span> transforms</span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> augmentation_collate(batch, transform<span class="op">=</span><span class="va">None</span>):</span>
<span id="cb10-6"><a href="#cb10-6" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Collate function that applies augmentations during batching"""</span></span>
<span id="cb10-7"><a href="#cb10-7" aria-hidden="true" tabindex="-1"></a>    data, labels <span class="op">=</span> <span class="bu">zip</span>(<span class="op">*</span>batch)</span>
<span id="cb10-8"><a href="#cb10-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-9"><a href="#cb10-9" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> transform <span class="kw">is</span> <span class="kw">not</span> <span class="va">None</span>:</span>
<span id="cb10-10"><a href="#cb10-10" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Augment each sample; with num_workers &gt; 0 this runs inside the worker processes</span></span>
<span id="cb10-11"><a href="#cb10-11" aria-hidden="true" tabindex="-1"></a>        augmented_data <span class="op">=</span> [transform(img) <span class="cf">for</span> img <span class="kw">in</span> data]</span>
<span id="cb10-12"><a href="#cb10-12" aria-hidden="true" tabindex="-1"></a>        data_tensor <span class="op">=</span> torch.stack(augmented_data)</span>
<span id="cb10-13"><a href="#cb10-13" aria-hidden="true" tabindex="-1"></a>    <span class="cf">else</span>:</span>
<span id="cb10-14"><a href="#cb10-14" aria-hidden="true" tabindex="-1"></a>        data_tensor <span class="op">=</span> torch.stack(data)</span>
<span id="cb10-15"><a href="#cb10-15" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-16"><a href="#cb10-16" aria-hidden="true" tabindex="-1"></a>    labels_tensor <span class="op">=</span> torch.tensor(labels, dtype<span class="op">=</span>torch.<span class="bu">long</span>)</span>
<span id="cb10-17"><a href="#cb10-17" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> data_tensor, labels_tensor</span>
<span id="cb10-18"><a href="#cb10-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-19"><a href="#cb10-19" aria-hidden="true" tabindex="-1"></a><span class="co"># Define augmentation pipeline; these transforms also accept image tensors of</span></span>
<span id="cb10-20"><a href="#cb10-20" aria-hidden="true" tabindex="-1"></a><span class="co"># shape (C, H, W), and ColorJitter works on 1- or 3-channel images</span></span>
<span id="cb10-21"><a href="#cb10-21" aria-hidden="true" tabindex="-1"></a>augment_transform <span class="op">=</span> transforms.Compose([</span>
<span id="cb10-22"><a href="#cb10-22" aria-hidden="true" tabindex="-1"></a>    transforms.RandomHorizontalFlip(p<span class="op">=</span><span class="fl">0.5</span>),</span>
<span id="cb10-23"><a href="#cb10-23" aria-hidden="true" tabindex="-1"></a>    transforms.RandomRotation(<span class="dv">10</span>),</span>
<span id="cb10-24"><a href="#cb10-24" aria-hidden="true" tabindex="-1"></a>    transforms.ColorJitter(brightness<span class="op">=</span><span class="fl">0.2</span>, contrast<span class="op">=</span><span class="fl">0.2</span>)</span>
<span id="cb10-25"><a href="#cb10-25" aria-hidden="true" tabindex="-1"></a>])</span>
<span id="cb10-26"><a href="#cb10-26" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-27"><a href="#cb10-27" aria-hidden="true" tabindex="-1"></a><span class="co"># Bind the transform with functools.partial instead of a lambda: a partial of a</span></span>
<span id="cb10-28"><a href="#cb10-28" aria-hidden="true" tabindex="-1"></a><span class="co"># module-level function is picklable, which DataLoader workers need when they are</span></span>
<span id="cb10-29"><a href="#cb10-29" aria-hidden="true" tabindex="-1"></a><span class="co"># not forked (e.g. the "spawn" start method, the default on Windows and macOS)</span></span>
<span id="cb10-30"><a href="#cb10-30" aria-hidden="true" tabindex="-1"></a>aug_collate_fn <span class="op">=</span> partial(augmentation_collate, transform<span class="op">=</span>augment_transform)</span>
<span id="cb10-31"><a href="#cb10-31" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-32"><a href="#cb10-32" aria-hidden="true" tabindex="-1"></a>aug_loader <span class="op">=</span> DataLoader(dataset, batch_size<span class="op">=</span><span class="dv">32</span>, collate_fn<span class="op">=</span>aug_collate_fn)</span></code></pre></div></div>
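<p>Calling the transform once per sample keeps the collate function simple, but every image pays for a separate Python call. For augmentations that vectorize easily, there is an alternative that the code above does not use: stack the samples first, then augment the whole batch with one tensor operation. The sketch below covers horizontal flips only. <code>batched_random_hflip</code> and <code>batched_flip_collate</code> are hypothetical helpers, and rotation or color jitter would each need their own batched implementation.</p>

```python
import torch


def batched_random_hflip(images, p=0.5):
    """Flip each image of a (B, C, H, W) batch horizontally with probability p,
    using one masked operation instead of a per-sample Python loop."""
    flip = (torch.rand(images.size(0)) < p).to(images.device)  # one coin flip per sample
    out = images.clone()
    out[flip] = images[flip].flip(-1)  # reverse the width dimension
    return out


def batched_flip_collate(batch):
    """Stack first, then augment the whole batch at once."""
    data, labels = zip(*batch)
    images = batched_random_hflip(torch.stack(data))
    return images, torch.tensor(labels, dtype=torch.long)


demo_batch = [(torch.arange(6.0).reshape(1, 2, 3), 0), (torch.ones(1, 2, 3), 1)]
flip_images, flip_labels = batched_flip_collate(demo_batch)
print(flip_images.shape, flip_labels.tolist())  # torch.Size([2, 1, 2, 3]) [0, 1]
```

<p>The helper uses only device-agnostic tensor operations, so the same flip could also run after the batch has been moved to the GPU.</p>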
</section>
</section>
<section id="performance-tips-and-best-practices" class="level2">
<h2 class="anchored" data-anchor-id="performance-tips-and-best-practices" id="performance-tips-and-best-practices">Performance Tips and Best Practices</h2>
<section id="minimize-data-copying" class="level3">
<h3 class="anchored" data-anchor-id="minimize-data-copying" id="minimize-data-copying">1. Minimize Data Copying</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb11"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> efficient_collate_tips(batch):</span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Demonstrates efficient collate practices"""</span></span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a>    data, labels <span class="op">=</span> <span class="bu">zip</span>(<span class="op">*</span>batch)</span>
<span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a>    <span class="co"># TIP 1: Prefer torch.stack over torch.cat of unsqueezed tensors.</span></span>
<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Both copy the data into the batch once; stack is simpler and skips</span></span>
<span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a>    <span class="co"># building a Python list of unsqueezed views</span></span>
<span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a>    data_tensor <span class="op">=</span> torch.stack(data, dim<span class="op">=</span><span class="dv">0</span>)  <span class="co"># Preferred</span></span>
<span id="cb11-9"><a href="#cb11-9" aria-hidden="true" tabindex="-1"></a>    <span class="co"># data_tensor = torch.cat([d.unsqueeze(0) for d in data], dim=0)  # Equivalent, more verbose</span></span>
<span id="cb11-10"><a href="#cb11-10" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-11"><a href="#cb11-11" aria-hidden="true" tabindex="-1"></a>    <span class="co"># TIP 2: Choose dtypes deliberately. Class-index targets for nn.CrossEntropyLoss</span></span>
<span id="cb11-12"><a href="#cb11-12" aria-hidden="true" tabindex="-1"></a>    <span class="co"># must be int64 (torch.long); memory savings come from the much larger data</span></span>
<span id="cb11-13"><a href="#cb11-13" aria-hidden="true" tabindex="-1"></a>    <span class="co"># tensor instead, e.g. by storing images as uint8</span></span>
<span id="cb11-14"><a href="#cb11-14" aria-hidden="true" tabindex="-1"></a>    labels_tensor <span class="op">=</span> torch.tensor(labels, dtype<span class="op">=</span>torch.<span class="bu">long</span>)</span>
<span id="cb11-15"><a href="#cb11-15" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-16"><a href="#cb11-16" aria-hidden="true" tabindex="-1"></a>    <span class="co"># TIP 3: Pre-allocate only when it helps. torch.stack already allocates the</span></span>
<span id="cb11-17"><a href="#cb11-17" aria-hidden="true" tabindex="-1"></a>    <span class="co"># output once; manual pre-allocation pays off when samples must be written</span></span>
<span id="cb11-18"><a href="#cb11-18" aria-hidden="true" tabindex="-1"></a>    <span class="co"># into the batch one at a time (e.g. after per-sample processing)</span></span>
<span id="cb11-19"><a href="#cb11-19" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-20"><a href="#cb11-20" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> data_tensor, labels_tensor</span></code></pre></div></div>
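<p>Another way to avoid copies is to let <code>torch.utils.data.default_collate</code> (exported from <code>torch.utils.data</code> in recent PyTorch releases) batch the fields it already handles. When it runs in a worker process, it stacks tensors directly into shared memory. A batch built with a plain <code>torch.stack</code> pays an extra copy when it is sent back to the main process, and this avoids that copy. The sketch below assumes samples of the hypothetical form <code>(image, tokens, label)</code> and pads only the variable-length field by hand.</p>

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import default_collate


def hybrid_collate(batch):
    """Collate (image, tokens, label) samples: fixed-size fields go through
    default_collate; only the variable-length tokens are padded by hand."""
    images, tokens, labels = zip(*batch)
    images, labels = default_collate(list(zip(images, labels)))
    padded_tokens = pad_sequence(list(tokens), batch_first=True, padding_value=0)
    return images, padded_tokens, labels


mixed_batch = [
    (torch.ones(3, 4, 4), torch.tensor([1, 2, 3]), 0),
    (torch.zeros(3, 4, 4), torch.tensor([4]), 1),
]
hy_images, hy_tokens, hy_labels = hybrid_collate(mixed_batch)
print(hy_images.shape, hy_tokens.tolist(), hy_labels.tolist())
# torch.Size([2, 3, 4, 4]) [[1, 2, 3], [4, 0, 0]] [0, 1]
```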
</section>
<section id="memory-usage-optimization" class="level3">
<h3 class="anchored" data-anchor-id="memory-usage-optimization" id="memory-usage-optimization">2. Memory Usage Optimization</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb12"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><a href="#cb12-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> memory_efficient_collate(batch):</span>
<span id="cb12-2"><a href="#cb12-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Memory-efficient collate function"""</span></span>
<span id="cb12-3"><a href="#cb12-3" aria-hidden="true" tabindex="-1"></a>    data, labels <span class="op">=</span> <span class="bu">zip</span>(<span class="op">*</span>batch)</span>
<span id="cb12-4"><a href="#cb12-4" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-5"><a href="#cb12-5" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Pre-allocate the batch once (torch.stack also does this); the explicit loop helps with per-sample work</span></span>
<span id="cb12-6"><a href="#cb12-6" aria-hidden="true" tabindex="-1"></a>    batch_size <span class="op">=</span> <span class="bu">len</span>(data)</span>
<span id="cb12-7"><a href="#cb12-7" aria-hidden="true" tabindex="-1"></a>    data_shape <span class="op">=</span> data[<span class="dv">0</span>].shape</span>
<span id="cb12-8"><a href="#cb12-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-9"><a href="#cb12-9" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Allocate output tensor once</span></span>
<span id="cb12-10"><a href="#cb12-10" aria-hidden="true" tabindex="-1"></a>    output_tensor <span class="op">=</span> torch.empty((batch_size,) <span class="op">+</span> data_shape, dtype<span class="op">=</span>data[<span class="dv">0</span>].dtype)</span>
<span id="cb12-11"><a href="#cb12-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-12"><a href="#cb12-12" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Copy each sample into its slot; torch.stack(data, out=output_tensor) is the vectorized equivalent</span></span>
<span id="cb12-13"><a href="#cb12-13" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> i, tensor <span class="kw">in</span> <span class="bu">enumerate</span>(data):</span>
<span id="cb12-14"><a href="#cb12-14" aria-hidden="true" tabindex="-1"></a>        output_tensor[i] <span class="op">=</span> tensor</span>
<span id="cb12-15"><a href="#cb12-15" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-16"><a href="#cb12-16" aria-hidden="true" tabindex="-1"></a>    labels_tensor <span class="op">=</span> torch.tensor(labels, dtype<span class="op">=</span>torch.<span class="bu">long</span>)</span>
<span id="cb12-17"><a href="#cb12-17" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> output_tensor, labels_tensor</span></code></pre></div></div>
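<p>Pre-allocation saves little when <code>torch.stack</code> already allocates the batch once. The dtype of the data itself is a larger lever: <code>uint8</code> images use a quarter of the memory of <code>float32</code>, both in the batch and during the host-to-device copy. The minimal sketch below assumes the dataset yields <code>uint8</code> image tensors; the names are illustrative. It keeps batches in <code>uint8</code> and converts to float only at the end.</p>

```python
import torch


def uint8_collate(batch):
    """Batch uint8 images as-is; float conversion is deferred to the consumer."""
    data, labels = zip(*batch)
    return torch.stack(data), torch.tensor(labels, dtype=torch.long)


# Hypothetical uint8 images, e.g. decoded pixels that were never converted to float
u8_images = [torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8) for _ in range(4)]
x_u8, y_u8 = uint8_collate(list(zip(u8_images, range(4))))

# Convert at the end; on a GPU: x_u8.to('cuda', non_blocking=True).float().div_(255)
x_f32 = x_u8.float().div_(255)
print(x_u8.element_size(), x_f32.element_size())  # 1 4
```

<p>On a GPU, the conversion would typically come after <code>.to(device)</code>, so that only the smaller <code>uint8</code> tensor is transferred.</p>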
</section>
<section id="benchmarking-your-collate-functions" class="level3">
<h3 class="anchored" data-anchor-id="benchmarking-your-collate-functions" id="benchmarking-your-collate-functions">3. Benchmarking Your Collate Functions</h3>
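<p>Timing through a <code>DataLoader</code> measures sampling, dataset indexing, and collation together, so differences between collate functions can be drowned out. One way to isolate the collate step is to time each function on a fixed, pre-built batch with <code>torch.utils.benchmark.Timer</code>. That timer performs warm-up runs and synchronizes CUDA when needed. The two functions below are simplified stand-ins for collate functions in this guide, not the originals.</p>

```python
import torch
from torch.utils import benchmark


def stack_collate_demo(batch):
    data, labels = zip(*batch)
    return torch.stack(data), torch.tensor(labels, dtype=torch.long)


def loop_collate_demo(batch):
    data, labels = zip(*batch)
    out = torch.empty((len(data),) + tuple(data[0].shape), dtype=data[0].dtype)
    for i, t in enumerate(data):
        out[i] = t  # copy each sample into its pre-allocated slot
    return out, torch.tensor(labels, dtype=torch.long)


# One fixed, pre-built batch, so only the collate step is measured
fixed_batch = [(torch.randn(3, 64, 64), i % 10) for i in range(32)]

collate_timings = {}
for name, fn in [('stack', stack_collate_demo), ('loop', loop_collate_demo)]:
    timer = benchmark.Timer(stmt='fn(batch)', globals={'fn': fn, 'batch': fixed_batch})
    collate_timings[name] = timer.timeit(50).median  # seconds per call
    print(f"{name}: {collate_timings[name] * 1e6:.1f} us per batch")
```

<p><code>Timer.blocked_autorange()</code> can also choose the number of runs automatically when a fixed count like 50 is too few for stable numbers.</p>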
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb13"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> benchmark_collate_functions():</span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Wall-clock comparison of collate approaches (one run each, no warm-up)"""</span></span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a>    dataset <span class="op">=</span> SimpleDataset(<span class="dv">1000</span>)</span>
<span id="cb13-4"><a href="#cb13-4" aria-hidden="true" tabindex="-1"></a>    batch_size <span class="op">=</span> <span class="dv">32</span></span>
<span id="cb13-5"><a href="#cb13-5" aria-hidden="true" tabindex="-1"></a>    num_batches <span class="op">=</span> <span class="dv">20</span></span>
<span id="cb13-6"><a href="#cb13-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-7"><a href="#cb13-7" aria-hidden="true" tabindex="-1"></a>    collate_functions <span class="op">=</span> {</span>
<span id="cb13-8"><a href="#cb13-8" aria-hidden="true" tabindex="-1"></a>        <span class="st">'default'</span>: <span class="va">None</span>,</span>
<span id="cb13-9"><a href="#cb13-9" aria-hidden="true" tabindex="-1"></a>        <span class="st">'fast_collate'</span>: fast_collate,</span>
<span id="cb13-10"><a href="#cb13-10" aria-hidden="true" tabindex="-1"></a>        <span class="st">'efficient_tips'</span>: efficient_collate_tips,</span>
<span id="cb13-11"><a href="#cb13-11" aria-hidden="true" tabindex="-1"></a>        <span class="st">'memory_efficient'</span>: memory_efficient_collate</span>
<span id="cb13-12"><a href="#cb13-12" aria-hidden="true" tabindex="-1"></a>    }</span>
<span id="cb13-13"><a href="#cb13-13" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-14"><a href="#cb13-14" aria-hidden="true" tabindex="-1"></a>    results <span class="op">=</span> {}</span>
<span id="cb13-15"><a href="#cb13-15" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-16"><a href="#cb13-16" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> name, collate_fn <span class="kw">in</span> collate_functions.items():</span>
<span id="cb13-17"><a href="#cb13-17" aria-hidden="true" tabindex="-1"></a>        loader <span class="op">=</span> DataLoader(dataset, batch_size<span class="op">=</span>batch_size, </span>
<span id="cb13-18"><a href="#cb13-18" aria-hidden="true" tabindex="-1"></a>                          collate_fn<span class="op">=</span>collate_fn, shuffle<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb13-19"><a href="#cb13-19" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-20"><a href="#cb13-20" aria-hidden="true" tabindex="-1"></a>        start_time <span class="op">=</span> time.perf_counter()  <span class="co"># monotonic, high-resolution clock</span></span>
<span id="cb13-21"><a href="#cb13-21" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> batch_idx, (data, labels) <span class="kw">in</span> <span class="bu">enumerate</span>(loader):</span>
<span id="cb13-22"><a href="#cb13-22" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> batch_idx <span class="op">&gt;=</span> num_batches:</span>
<span id="cb13-23"><a href="#cb13-23" aria-hidden="true" tabindex="-1"></a>                <span class="cf">break</span></span>
<span id="cb13-24"><a href="#cb13-24" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-25"><a href="#cb13-25" aria-hidden="true" tabindex="-1"></a>        elapsed_time <span class="op">=</span> time.perf_counter() <span class="op">-</span> start_time</span>
<span id="cb13-26"><a href="#cb13-26" aria-hidden="true" tabindex="-1"></a>        results[name] <span class="op">=</span> elapsed_time</span>
<span id="cb13-27"><a href="#cb13-27" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"</span><span class="sc">{</span>name<span class="sc">}</span><span class="ss">: </span><span class="sc">{</span>elapsed_time<span class="sc">:.4f}</span><span class="ss"> seconds"</span>)</span>
<span id="cb13-28"><a href="#cb13-28" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-29"><a href="#cb13-29" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Calculate improvements</span></span>
<span id="cb13-30"><a href="#cb13-30" aria-hidden="true" tabindex="-1"></a>    baseline <span class="op">=</span> results[<span class="st">'default'</span>]</span>
<span id="cb13-31"><a href="#cb13-31" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> name, time_taken <span class="kw">in</span> results.items():</span>
<span id="cb13-32"><a href="#cb13-32" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> name <span class="op">!=</span> <span class="st">'default'</span>:</span>
<span id="cb13-33"><a href="#cb13-33" aria-hidden="true" tabindex="-1"></a>            improvement <span class="op">=</span> (baseline <span class="op">/</span> time_taken <span class="op">-</span> <span class="dv">1</span>) <span class="op">*</span> <span class="dv">100</span></span>
<span id="cb13-34"><a href="#cb13-34" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f"</span><span class="sc">{</span>name<span class="sc">}</span><span class="ss"> improvement: </span><span class="sc">{</span>improvement<span class="sc">:.1f}</span><span class="ss">%"</span>)</span>
<span id="cb13-35"><a href="#cb13-35" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-36"><a href="#cb13-36" aria-hidden="true" tabindex="-1"></a><span class="co"># Run the benchmark</span></span>
<span id="cb13-37"><a href="#cb13-37" aria-hidden="true" tabindex="-1"></a>benchmark_collate_functions()</span></code></pre></div></div>
</section>
</section>
<section id="key-takeaways" class="level2">
<h2 class="anchored" data-anchor-id="key-takeaways" id="key-takeaways">Key Takeaways</h2>
<ol type="1">
<li><strong>Use <code>torch.stack()</code> instead of <code>torch.cat()</code></strong> for same-sized tensors</li>
<li><strong>Minimize data copying</strong> by working with tensor views when possible</li>
<li><strong>Pre-allocate tensors</strong> when batch sizes and shapes are known</li>
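<li><strong>Consider GPU transfer</strong> as part of the pipeline: collating directly onto the GPU only suits single-process loading (<code>num_workers=0</code>), because PyTorch advises against returning CUDA tensors from worker processes; with workers, prefer <code>pin_memory=True</code> plus <code>non_blocking=True</code> copies</li>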
<li><strong>Use appropriate data types</strong> to optimize memory usage</li>
<li><strong>Profile your specific use case</strong> as optimal strategies vary by data type and size</li>
<li><strong>Leverage specialized functions</strong> like <code>pad_sequence</code> for variable-length data</li>
</ol>
<p>Custom collate functions can measurably reduce data-loading time, especially for large datasets or complex data structures. The key is to minimize unnecessary data operations and memory allocations while taking advantage of PyTorch’s optimized tensor operations, and to confirm the gain with a benchmark on your own data.</p>
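<p>To illustrate the <code>pad_sequence</code> takeaway, here is a minimal collate function for variable-length data. It assumes each dataset item is a <code>(sequence, label)</code> pair of tensors; the name <code>pad_collate</code> and the toy batch are hypothetical and not part of the benchmark above.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import torch
from torch.nn.utils.rnn import pad_sequence


def pad_collate(batch):
    # batch: list of (sequence_tensor, label) pairs; sequences differ in length
    sequences, labels = zip(*batch)
    lengths = torch.tensor([seq.size(0) for seq in sequences])
    # Pad every sequence to the longest one in this batch: shape (batch, max_len)
    padded = pad_sequence(list(sequences), batch_first=True, padding_value=0)
    # Labels share one shape, so torch.stack builds the label batch in one call
    return padded, torch.stack(labels), lengths


# Toy batch of three variable-length sequences (illustrative data)
batch = [
    (torch.tensor([1, 2, 3]), torch.tensor(0)),
    (torch.tensor([4, 5]), torch.tensor(1)),
    (torch.tensor([6]), torch.tensor(0)),
]
padded, labels, lengths = pad_collate(batch)
print(padded)
print(lengths)</code></pre></div></div>
<p>Passing <code>collate_fn=pad_collate</code> to a <code>DataLoader</code> would apply it to every batch. Returning <code>lengths</code> alongside the padded tensor lets downstream code (masking, or <code>pack_padded_sequence</code>) ignore the padding.</p>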


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[PyTorch Training and Inference Optimization Guide]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/model-training/pytorch-optimizations/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/model-training/pytorch-optimizations/</guid>
      <pubDate>Sun, 01 Jun 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>tutorial</category>
      <category>intermediate</category>
      <content:encoded><![CDATA[






<section id="pytorch-training-and-inference-optimization-guide" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/model-training/pytorch-optimizations/pytorch_workflow.png" class="img-fluid"></p>
<p>This guide collects practical PyTorch code examples for speeding up training and inference, along with best practices and common pitfalls to avoid. The snippets assume a model, dataset, and training loop already exist, so adapt the names to your project. Each section builds on the previous ones, so you can adopt these optimizations incrementally based on your workload and performance requirements.</p>
<section id="general-optimization-principles" class="level2">
<h2 class="anchored" data-anchor-id="general-optimization-principles" id="general-optimization-principles">General Optimization Principles</h2>
<section id="use-the-right-data-types" class="level3">
<h3 class="anchored" data-anchor-id="use-the-right-data-types" id="use-the-right-data-types">1. Use the Right Data Types</h3>
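<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> torch.nn.Linear(<span class="dv">128</span>, <span class="dv">10</span>)  <span class="co"># Placeholder; substitute your own model</span></span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a><span class="co"># float16 halves memory and speeds up matmuls on GPUs with fast half-precision</span></span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a><span class="co"># math; converting the whole model like this is best suited to inference</span></span>
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> torch.cuda.is_available():</span>
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> model.cuda().half()  <span class="co"># Convert parameters and buffers to float16</span></span>
<span id="cb1-9"><a href="#cb1-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-10"><a href="#cb1-10" aria-hidden="true" tabindex="-1"></a><span class="co"># For training, prefer mixed precision (autocast + GradScaler); older releases import these from torch.cuda.amp</span></span>
<span id="cb1-11"><a href="#cb1-11" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.amp <span class="im">import</span> autocast, GradScaler</span>
<span id="cb1-12"><a href="#cb1-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-13"><a href="#cb1-13" aria-hidden="true" tabindex="-1"></a><span class="co"># Be explicit about dtypes when creating tensors</span></span>
<span id="cb1-14"><a href="#cb1-14" aria-hidden="true" tabindex="-1"></a>data <span class="op">=</span> [[<span class="fl">0.5</span>, <span class="fl">1.0</span>], [<span class="fl">1.5</span>, <span class="fl">2.0</span>]]  <span class="co"># Example values</span></span>
<span id="cb1-15"><a href="#cb1-15" aria-hidden="true" tabindex="-1"></a>x <span class="op">=</span> torch.tensor(data, dtype<span class="op">=</span>torch.float32)</span></code></pre></div></div>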
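<p>To make the memory trade-off concrete, this CPU-only sketch prints the per-element size of three floating-point dtypes and shows one practical difference between the two 16-bit formats: float16's largest finite value is 65504, whereas bfloat16 keeps float32's 8-bit exponent (and therefore its range) at the cost of fewer mantissa bits. The value 70000 is just an illustrative out-of-range number.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import torch

sizes = {}
for dtype in (torch.float32, torch.float16, torch.bfloat16):
    t = torch.empty(1024, 1024, dtype=dtype)
    sizes[dtype] = t.element_size()  # bytes per element
    # nbytes = numel() * element_size()
    print(f"{dtype}: {t.element_size()} bytes/element, {t.nbytes / 2**20:.1f} MiB")

# float16 tops out at 65504, so 70000 overflows; bfloat16 stays finite
big = torch.tensor(70000.0)
print(big.to(torch.float16))   # inf
print(big.to(torch.bfloat16))  # rounded to the nearest bfloat16 value, 70144</code></pre></div></div>
<p>This range difference is why float16 training is normally paired with loss scaling, while bfloat16 autocast usually runs without it.</p>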
</section>
<section id="optimize-data-loading" class="level3">
<h3 class="anchored" data-anchor-id="optimize-data-loading" id="optimize-data-loading">2. Optimize Data Loading</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.utils.data <span class="im">import</span> DataLoader</span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a>device <span class="op">=</span> torch.device(<span class="st">"cuda"</span> <span class="cf">if</span> torch.cuda.is_available() <span class="cf">else</span> <span class="st">"cpu"</span>)</span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Optimize DataLoader (dataset is your Dataset instance)</span></span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a>train_loader <span class="op">=</span> DataLoader(</span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a>    dataset,</span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a>    batch_size<span class="op">=</span><span class="dv">32</span>,</span>
<span id="cb2-10"><a href="#cb2-10" aria-hidden="true" tabindex="-1"></a>    shuffle<span class="op">=</span><span class="va">True</span>,</span>
<span id="cb2-11"><a href="#cb2-11" aria-hidden="true" tabindex="-1"></a>    num_workers<span class="op">=</span><span class="dv">4</span>,  <span class="co"># Load batches in parallel worker processes</span></span>
<span id="cb2-12"><a href="#cb2-12" aria-hidden="true" tabindex="-1"></a>    pin_memory<span class="op">=</span><span class="va">True</span>,  <span class="co"># Page-locked host memory enables asynchronous GPU copies</span></span>
<span id="cb2-13"><a href="#cb2-13" aria-hidden="true" tabindex="-1"></a>    persistent_workers<span class="op">=</span><span class="va">True</span>,  <span class="co"># Keep workers alive between epochs (requires num_workers &gt; 0)</span></span>
<span id="cb2-14"><a href="#cb2-14" aria-hidden="true" tabindex="-1"></a>    prefetch_factor<span class="op">=</span><span class="dv">2</span>  <span class="co"># Batches loaded in advance by each worker</span></span>
<span id="cb2-15"><a href="#cb2-15" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb2-16"><a href="#cb2-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-17"><a href="#cb2-17" aria-hidden="true" tabindex="-1"></a><span class="co"># non_blocking=True overlaps the host-to-GPU copy with computation</span></span>
<span id="cb2-18"><a href="#cb2-18" aria-hidden="true" tabindex="-1"></a><span class="co"># only when the source tensor is in pinned memory</span></span>
<span id="cb2-19"><a href="#cb2-19" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> batch <span class="kw">in</span> train_loader:</span>
<span id="cb2-20"><a href="#cb2-20" aria-hidden="true" tabindex="-1"></a>    data <span class="op">=</span> batch[<span class="dv">0</span>].to(device, non_blocking<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb2-21"><a href="#cb2-21" aria-hidden="true" tabindex="-1"></a>    target <span class="op">=</span> batch[<span class="dv">1</span>].to(device, non_blocking<span class="op">=</span><span class="va">True</span>)</span></code></pre></div></div>
</section>
<section id="tensor-operations-best-practices" class="level3">
<h3 class="anchored" data-anchor-id="tensor-operations-best-practices" id="tensor-operations-best-practices">3. Tensor Operations Best Practices</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Avoid unnecessary CPU-GPU transfers</span></span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a>x <span class="op">=</span> torch.randn(<span class="dv">1000</span>, <span class="dv">1000</span>, device<span class="op">=</span><span class="st">'cuda'</span>)  <span class="co"># Create directly on GPU</span></span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a>y <span class="op">=</span> torch.randn_like(x)  <span class="co"># Inherits x's device and dtype</span></span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a><span class="co"># In-place operations avoid allocating a new tensor, but do not apply them</span></span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a><span class="co"># to tensors that autograd still needs for the backward pass</span></span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a>x.add_(y)  <span class="co"># Instead of x = x + y</span></span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a>x.mul_(<span class="dv">2</span>)  <span class="co"># Instead of x = x * 2</span></span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a><span class="co"># Batch operations instead of loops</span></span>
<span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a><span class="co"># Bad: one forward call per sample</span></span>
<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(batch_size):</span>
<span id="cb3-15"><a href="#cb3-15" aria-hidden="true" tabindex="-1"></a>    result[i] <span class="op">=</span> model(x[i])</span>
<span id="cb3-16"><a href="#cb3-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-17"><a href="#cb3-17" aria-hidden="true" tabindex="-1"></a><span class="co"># Good: one forward call for the whole batch</span></span>
<span id="cb3-18"><a href="#cb3-18" aria-hidden="true" tabindex="-1"></a>result <span class="op">=</span> model(x)  <span class="co"># Process entire batch</span></span></code></pre></div></div>
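<p>Here is a small CPU sketch of the loop-versus-batch comparison, using a stand-in <code>nn.Linear</code> model (illustrative only). The per-sample loop needs <code>unsqueeze(0)</code> to restore the batch dimension that indexing removes; the batched call needs no such bookkeeping.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(8, 3)  # stand-in model for illustration
x = torch.randn(16, 8)   # a batch of 16 samples

with torch.no_grad():
    # Loop: one forward call per sample (unsqueeze(0) restores the batch dim)
    looped = torch.cat([model(x[i].unsqueeze(0)) for i in range(x.size(0))])
    # Batched: a single forward call over the whole batch
    batched = model(x)

# Same values up to floating-point rounding
print(torch.allclose(looped, batched, atol=1e-6))</code></pre></div></div>
<p>Both paths agree, but the batched version issues a single matrix multiply instead of sixteen, which is where the speedup comes from on real workloads.</p>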
</section>
</section>
<section id="training-optimizations" class="level2">
<h2 class="anchored" data-anchor-id="training-optimizations" id="training-optimizations">Training Optimizations</h2>
<section id="mixed-precision-training" class="level3">
<h3 class="anchored" data-anchor-id="mixed-precision-training" id="mixed-precision-training">1. Mixed Precision Training</h3>
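<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a>device <span class="op">=</span> torch.device(<span class="st">"cuda"</span>)</span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> MyModel().to(device)  <span class="co"># MyModel stands in for your nn.Module</span></span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a>criterion <span class="op">=</span> nn.CrossEntropyLoss()</span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a>optimizer <span class="op">=</span> torch.optim.Adam(model.parameters())</span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a>scaler <span class="op">=</span> torch.amp.GradScaler(<span class="st">"cuda"</span>)  <span class="co"># Older releases: torch.cuda.amp.GradScaler()</span></span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(num_epochs):</span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> inputs, targets <span class="kw">in</span> train_loader:</span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a>        inputs <span class="op">=</span> inputs.to(device, non_blocking<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a>        targets <span class="op">=</span> targets.to(device, non_blocking<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb4-14"><a href="#cb4-14" aria-hidden="true" tabindex="-1"></a>        optimizer.zero_grad(set_to_none<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb4-15"><a href="#cb4-15" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-16"><a href="#cb4-16" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Forward pass with autocast: eligible ops run in float16</span></span>
<span id="cb4-17"><a href="#cb4-17" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> torch.autocast(device_type<span class="op">=</span><span class="st">"cuda"</span>, dtype<span class="op">=</span>torch.float16):</span>
<span id="cb4-18"><a href="#cb4-18" aria-hidden="true" tabindex="-1"></a>            outputs <span class="op">=</span> model(inputs)</span>
<span id="cb4-19"><a href="#cb4-19" aria-hidden="true" tabindex="-1"></a>            loss <span class="op">=</span> criterion(outputs, targets)</span>
<span id="cb4-20"><a href="#cb4-20" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-21"><a href="#cb4-21" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Scale the loss so small float16 gradients do not underflow to zero</span></span>
<span id="cb4-22"><a href="#cb4-22" aria-hidden="true" tabindex="-1"></a>        scaler.scale(loss).backward()</span>
<span id="cb4-23"><a href="#cb4-23" aria-hidden="true" tabindex="-1"></a>        scaler.step(optimizer)  <span class="co"># Unscales gradients; skips the step if they contain inf/NaN</span></span>
<span id="cb4-24"><a href="#cb4-24" aria-hidden="true" tabindex="-1"></a>        scaler.update()  <span class="co"># Adjusts the scale factor for the next iteration</span></span></code></pre></div></div>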
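<p>The loop above assumes a CUDA device. To see what autocast changes without a GPU, here is a sketch of one training step using CPU autocast with bfloat16 on a toy model; the model, data, and optimizer settings are all illustrative.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 4))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
inputs = torch.randn(8, 32)
targets = torch.randint(0, 4, (8,))

optimizer.zero_grad()
# CPU autocast runs eligible ops such as linear layers in bfloat16
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    outputs = model(inputs)
    loss = criterion(outputs, targets)

# bfloat16 shares float32's exponent range, so no GradScaler is used here
loss.backward()
optimizer.step()

print(outputs.dtype)               # dtype of the autocast output
print(model[0].weight.dtype)       # parameters keep their original dtype
print(model[0].weight.grad.dtype)  # gradients match the parameter dtype</code></pre></div></div>
<p>Only the outputs of autocast-eligible operations drop to lower precision; the parameters and their gradients stay in float32. Because bfloat16 has the same exponent range as float32, this sketch omits the gradient scaler that the float16 path relies on.</p>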
</section>
<section id="gradient-accumulation" class="level3">
<h3 class="anchored" data-anchor-id="gradient-accumulation" id="gradient-accumulation">2. Gradient Accumulation</h3>
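<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a>accumulation_steps <span class="op">=</span> <span class="dv">4</span>  <span class="co"># Effective batch size = batch_size * accumulation_steps</span></span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a>optimizer.zero_grad(set_to_none<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> i, (inputs, targets) <span class="kw">in</span> <span class="bu">enumerate</span>(train_loader):</span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a>    inputs, targets <span class="op">=</span> inputs.to(device), targets.to(device)</span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> torch.autocast(device_type<span class="op">=</span><span class="st">"cuda"</span>, dtype<span class="op">=</span>torch.float16):</span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a>        outputs <span class="op">=</span> model(inputs)</span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Average over the effective batch, not just this micro-batch</span></span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a>        loss <span class="op">=</span> criterion(outputs, targets) <span class="op">/</span> accumulation_steps</span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a>    scaler.scale(loss).backward()  <span class="co"># Gradients add up in .grad across iterations</span></span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Update every accumulation_steps batches and on the final, possibly partial, group</span></span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> (i <span class="op">+</span> <span class="dv">1</span>) <span class="op">%</span> accumulation_steps <span class="op">==</span> <span class="dv">0</span> <span class="kw">or</span> (i <span class="op">+</span> <span class="dv">1</span>) <span class="op">==</span> <span class="bu">len</span>(train_loader):</span>
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a>        scaler.step(optimizer)</span>
<span id="cb5-16"><a href="#cb5-16" aria-hidden="true" tabindex="-1"></a>        scaler.update()</span>
<span id="cb5-17"><a href="#cb5-17" aria-hidden="true" tabindex="-1"></a>        optimizer.zero_grad(set_to_none<span class="op">=</span><span class="va">True</span>)</span></code></pre></div></div>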
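<p>Why divide the loss by <code>accumulation_steps</code>? This sketch (toy model and data, no mixed precision) checks that four accumulated micro-batches of 4 samples produce the same gradient as a single batch of 16.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(10, 1)
criterion = nn.MSELoss()  # mean reduction
inputs, targets = torch.randn(16, 10), torch.randn(16, 1)
accumulation_steps = 4

# Reference: gradient of the full batch of 16 in one backward pass
model.zero_grad()
criterion(model(inputs), targets).backward()
full_grad = model.weight.grad.clone()

# Accumulated: 4 micro-batches of 4, each loss divided by accumulation_steps
model.zero_grad()
for x_mb, y_mb in zip(inputs.chunk(accumulation_steps), targets.chunk(accumulation_steps)):
    loss = criterion(model(x_mb), y_mb) / accumulation_steps
    loss.backward()  # gradients add up in .grad across calls
accum_grad = model.weight.grad.clone()

print(torch.allclose(full_grad, accum_grad, atol=1e-6))</code></pre></div></div>
<p>The equivalence assumes a mean-reduced loss and equal-sized micro-batches. Layers whose behavior depends on batch statistics, such as <code>BatchNorm</code>, still see only the smaller micro-batch.</p>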
</section>
<section id="efficient-learning-rate-scheduling" class="level3">
<h3 class="anchored" data-anchor-id="efficient-learning-rate-scheduling" id="efficient-learning-rate-scheduling">3. Efficient Learning Rate Scheduling</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.optim.lr_scheduler <span class="im">import</span> OneCycleLR</span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a>optimizer <span class="op">=</span> torch.optim.AdamW(model.parameters(), lr<span class="op">=</span><span class="fl">0.001</span>)  <span class="co"># This lr is overridden by OneCycleLR</span></span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a>scheduler <span class="op">=</span> OneCycleLR(</span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a>    optimizer,</span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a>    max_lr<span class="op">=</span><span class="fl">0.01</span>,</span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a>    epochs<span class="op">=</span>num_epochs,</span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a>    steps_per_epoch<span class="op">=</span><span class="bu">len</span>(train_loader)</span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a><span class="co"># OneCycleLR is stepped after every batch, not once per epoch</span></span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> batch <span class="kw">in</span> train_loader:</span>
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a>    <span class="co"># ... forward, backward, optimizer.step() ...</span></span>
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a>    scheduler.step()</span></code></pre></div></div>
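<p>To see what the schedule actually does, this sketch records the learning rate at every step for a throwaway model, assuming 5 epochs of 20 batches (100 steps in total). With the defaults <code>div_factor=25</code>, <code>final_div_factor=1e4</code>, and <code>pct_start=0.3</code>, it should start at <code>max_lr / 25</code>, peak at <code>max_lr</code> about 30% of the way through, and end at <code>max_lr / 25 / 1e4</code>.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import torch
from torch.optim.lr_scheduler import OneCycleLR

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
# 5 epochs x 20 batches = 100 scheduler steps in total (illustrative numbers)
scheduler = OneCycleLR(optimizer, max_lr=0.01, epochs=5, steps_per_epoch=20)

lrs = []
for step in range(100):
    lrs.append(optimizer.param_groups[0]["lr"])
    optimizer.step()   # normally preceded by forward and backward passes
    scheduler.step()

print(f"start: {lrs[0]:.6f}")
print(f"peak:  {max(lrs):.6f} at step {lrs.index(max(lrs))}")
print(f"end:   {lrs[-1]:.2e}")</code></pre></div></div>
<p>Because <code>OneCycleLR</code> sets the optimizer's learning rate itself, the first recorded value is <code>max_lr / 25</code> rather than the 0.001 passed to <code>AdamW</code>. Stepping the scheduler more than <code>epochs * steps_per_epoch</code> times raises an error, so keep those arguments in sync with the real loader length.</p>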
</section>
<section id="model-compilation-pytorch-2.0" class="level3">
<h3 class="anchored" data-anchor-id="model-compilation-pytorch-2.0" id="model-compilation-pytorch-2.0">4. Model Compilation (PyTorch 2.0+)</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Compile the model (the default mode balances compile time and speedup)</span></span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> torch.<span class="bu">compile</span>(model)</span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Alternative modes; use one of these instead of the default call above:</span></span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a><span class="co"># "reduce-overhead" uses CUDA graphs to cut per-step launch overhead,</span></span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a><span class="co"># which helps most with small models or small batches</span></span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a><span class="co"># model = torch.compile(model, mode="reduce-overhead")</span></span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a><span class="co"># "max-autotune" compiles longest, autotuning kernels for maximum performance</span></span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a><span class="co"># model = torch.compile(model, mode="max-autotune")</span></span></code></pre></div></div>
</section>
<section id="checkpoint-and-resume-training" class="level3">
<h3 class="anchored" data-anchor-id="checkpoint-and-resume-training" id="checkpoint-and-resume-training">5. Checkpoint and Resume Training</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb8"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> save_checkpoint(model, optimizer, epoch, loss, filename):</span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a>    torch.save({</span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a>        <span class="st">'epoch'</span>: epoch,</span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a>        <span class="st">'model_state_dict'</span>: model.state_dict(),</span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a>        <span class="st">'optimizer_state_dict'</span>: optimizer.state_dict(),</span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a>        <span class="st">'loss'</span>: loss,  <span class="co"># Pass a Python float (e.g. loss.item()), not a graph-attached tensor</span></span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a>    }, filename)</span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> load_checkpoint(model, optimizer, filename, device<span class="op">=</span><span class="st">'cpu'</span>):</span>
<span id="cb8-12"><a href="#cb8-12" aria-hidden="true" tabindex="-1"></a>    <span class="co"># map_location lets a checkpoint saved on GPU load on another device</span></span>
<span id="cb8-13"><a href="#cb8-13" aria-hidden="true" tabindex="-1"></a>    checkpoint <span class="op">=</span> torch.load(filename, map_location<span class="op">=</span>device)</span>
<span id="cb8-14"><a href="#cb8-14" aria-hidden="true" tabindex="-1"></a>    model.load_state_dict(checkpoint[<span class="st">'model_state_dict'</span>])</span>
<span id="cb8-15"><a href="#cb8-15" aria-hidden="true" tabindex="-1"></a>    optimizer.load_state_dict(checkpoint[<span class="st">'optimizer_state_dict'</span>])</span>
<span id="cb8-16"><a href="#cb8-16" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> checkpoint[<span class="st">'epoch'</span>], checkpoint[<span class="st">'loss'</span>]</span></code></pre></div></div>
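<p>A resumed run should restore more than the weights. The sketch below is a self-contained round trip with a toy model: it writes a checkpoint to an in-memory buffer (a stand-in for a file), loads it into fresh objects, and checks that Adam's running moment estimates come back intact. It inlines the save and load calls rather than reusing the helpers above so it runs on its own.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import io
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# A few training steps so Adam has non-trivial state (step, exp_avg, exp_avg_sq)
for _ in range(3):
    optimizer.zero_grad()
    model(torch.randn(8, 4)).pow(2).mean().backward()
    optimizer.step()

buffer = io.BytesIO()  # in-memory stand-in for a checkpoint file
torch.save({"epoch": 3,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict()}, buffer)
buffer.seek(0)

# Resume into fresh objects, as a new process would
new_model = nn.Linear(4, 2)
new_optimizer = torch.optim.Adam(new_model.parameters(), lr=1e-3)
checkpoint = torch.load(buffer, map_location="cpu")
new_model.load_state_dict(checkpoint["model_state_dict"])
new_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

old_state = optimizer.state_dict()["state"][0]
new_state = new_optimizer.state_dict()["state"][0]
print(checkpoint["epoch"], torch.equal(old_state["exp_avg"], new_state["exp_avg"]))</code></pre></div></div>
<p>If you also use a learning-rate scheduler or a <code>GradScaler</code>, both expose <code>state_dict()</code> and <code>load_state_dict()</code> and can be stored in the same dictionary.</p>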
</section>
</section>
<section id="inference-optimizations" class="level2">
<h2 class="anchored" data-anchor-id="inference-optimizations" id="inference-optimizations">Inference Optimizations</h2>
<section id="model-optimization-for-inference" class="level3">
<h3 class="anchored" data-anchor-id="model-optimization-for-inference" id="model-optimization-for-inference">1. Model Optimization for Inference</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb9"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Set model to evaluation mode</span></span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a>model.<span class="bu">eval</span>()</span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Disable gradient computation</span></span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a><span class="cf">with</span> torch.no_grad():</span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a>    outputs <span class="op">=</span> model(inputs)</span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a><span class="co"># torch.inference_mode() also skips view tracking and version counters; its outputs cannot be used in autograd later</span></span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a><span class="cf">with</span> torch.inference_mode():</span>
<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a>    outputs <span class="op">=</span> model(inputs)</span></code></pre></div></div>
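<p>The two context managers differ in what they produce. In this sketch with a toy layer, both outputs skip gradient tracking, but only the <code>inference_mode()</code> output is an <em>inference tensor</em>, and such tensors cannot later be saved for backward in an autograd graph.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import torch
import torch.nn as nn

model = nn.Linear(4, 2).eval()
x = torch.randn(3, 4)

with torch.no_grad():
    out_no_grad = model(x)
with torch.inference_mode():
    out_inference = model(x)

# Neither output tracks gradients
print(out_no_grad.requires_grad, out_inference.requires_grad)
# Only the inference_mode() output is an inference tensor
print(out_no_grad.is_inference(), out_inference.is_inference())

# Using an inference tensor where autograd must save it raises an error...
w = torch.ones(2, requires_grad=True)
raised = False
try:
    (out_inference * w).sum().backward()
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)

# ...while a clone made outside inference_mode() is an ordinary tensor
(out_inference.clone() * w).sum().backward()
print(w.grad)</code></pre></div></div>
<p>In practice, <code>inference_mode()</code> suits pure serving code, while <code>no_grad()</code> (or cloning the results) is the safer choice when outputs may feed back into training.</p>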
</section>
<section id="torchscript-optimization" class="level3">
<h3 class="anchored" data-anchor-id="torchscript-optimization" id="torchscript-optimization">2. TorchScript Optimization</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb10"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a><span class="co"># TorchScript is no longer actively developed; torch.export and torch.compile are the newer paths</span></span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a>model.<span class="bu">eval</span>()  <span class="co"># Freezing requires eval mode (dropout off, BatchNorm running stats)</span></span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-6"><a href="#cb10-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Trace the model: records the operations executed for this example input</span></span>
<span id="cb10-7"><a href="#cb10-7" aria-hidden="true" tabindex="-1"></a>example_input <span class="op">=</span> torch.randn(<span class="dv">1</span>, <span class="dv">3</span>, <span class="dv">224</span>, <span class="dv">224</span>)</span>
<span id="cb10-8"><a href="#cb10-8" aria-hidden="true" tabindex="-1"></a>traced_model <span class="op">=</span> torch.jit.trace(model, example_input)</span>
<span id="cb10-9"><a href="#cb10-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-10"><a href="#cb10-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Or script the model: also captures Python control flow</span></span>
<span id="cb10-11"><a href="#cb10-11" aria-hidden="true" tabindex="-1"></a>scripted_model <span class="op">=</span> torch.jit.script(model)</span>
<span id="cb10-12"><a href="#cb10-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-13"><a href="#cb10-13" aria-hidden="true" tabindex="-1"></a><span class="co"># Freeze (if not already frozen) and apply inference-oriented graph optimizations</span></span>
<span id="cb10-14"><a href="#cb10-14" aria-hidden="true" tabindex="-1"></a>optimized_model <span class="op">=</span> torch.jit.optimize_for_inference(scripted_model)</span>
<span id="cb10-15"><a href="#cb10-15" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-16"><a href="#cb10-16" aria-hidden="true" tabindex="-1"></a><span class="co"># Save and load</span></span>
<span id="cb10-17"><a href="#cb10-17" aria-hidden="true" tabindex="-1"></a>torch.jit.save(optimized_model, <span class="st">"optimized_model.pt"</span>)</span>
<span id="cb10-18"><a href="#cb10-18" aria-hidden="true" tabindex="-1"></a>loaded_model <span class="op">=</span> torch.jit.load(<span class="st">"optimized_model.pt"</span>)</span></code></pre></div></div>
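<p>Because tracing records only the operations executed for the example input, it is worth checking the exported model against eager mode. This sketch uses a small illustrative CNN and an in-memory buffer instead of a file, then compares outputs on a fresh input with a different batch size.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import io
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10),
).eval()
example_input = torch.randn(1, 3, 64, 64)

with torch.no_grad():
    traced_model = torch.jit.trace(model, example_input)

    buffer = io.BytesIO()  # in-memory stand-in for "optimized_model.pt"
    torch.jit.save(traced_model, buffer)
    buffer.seek(0)
    loaded_model = torch.jit.load(buffer)

    # Compare on a fresh input whose batch size differs from the example
    test_input = torch.randn(2, 3, 64, 64)
    eager_out = model(test_input)
    loaded_out = loaded_model(test_input)

print(torch.allclose(eager_out, loaded_out, atol=1e-5))</code></pre></div></div>
<p>Tracing silently bakes in a single path through any data-dependent <code>if</code> or loop; for such models, <code>torch.jit.script</code> is the safer choice.</p>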
</section>
<section id="quantization" class="level3">
<h3 class="anchored" data-anchor-id="quantization" id="quantization">3. Quantization</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb11"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.ao.quantization <span class="im">as</span> quant  <span class="co"># torch.quantization is the older alias</span></span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Post-training dynamic quantization: Linear weights are stored as int8 and</span></span>
<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a><span class="co"># activations are quantized on the fly at inference time (CPU backends)</span></span>
<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a>model.<span class="bu">eval</span>()</span>
<span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a>quantized_model <span class="op">=</span> quant.quantize_dynamic(</span>
<span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a>    model, {torch.nn.Linear}, dtype<span class="op">=</span>torch.qint8</span>
<span id="cb11-9"><a href="#cb11-9" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb11-10"><a href="#cb11-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-11"><a href="#cb11-11" aria-hidden="true" tabindex="-1"></a><span class="co"># Quantization-aware training (eager mode): the model must wrap its input and</span></span>
<span id="cb11-12"><a href="#cb11-12" aria-hidden="true" tabindex="-1"></a><span class="co"># output with QuantStub/DeQuantStub for convert() to produce int8 kernels</span></span>
<span id="cb11-13"><a href="#cb11-13" aria-hidden="true" tabindex="-1"></a>model.train()</span>
<span id="cb11-14"><a href="#cb11-14" aria-hidden="true" tabindex="-1"></a>model.qconfig <span class="op">=</span> quant.get_default_qat_qconfig(<span class="st">'fbgemm'</span>)  <span class="co"># 'fbgemm' targets x86 CPUs</span></span>
<span id="cb11-15"><a href="#cb11-15" aria-hidden="true" tabindex="-1"></a>quant.prepare_qat(model, inplace<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb11-16"><a href="#cb11-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-17"><a href="#cb11-17" aria-hidden="true" tabindex="-1"></a><span class="co"># Train the model...</span></span>
<span id="cb11-18"><a href="#cb11-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-19"><a href="#cb11-19" aria-hidden="true" tabindex="-1"></a><span class="co"># Convert to a quantized model (switch to eval mode first)</span></span>
<span id="cb11-20"><a href="#cb11-20" aria-hidden="true" tabindex="-1"></a>model.<span class="bu">eval</span>()</span>
<span id="cb11-21"><a href="#cb11-21" aria-hidden="true" tabindex="-1"></a>quantized_model <span class="op">=</span> quant.convert(model, inplace<span class="op">=</span><span class="va">False</span>)</span></code></pre></div></div>
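<p>Here is a runnable sketch of the dynamic-quantization path on a toy MLP, using the <code>torch.ao.quantization</code> namespace (the newer home of <code>torch.quantization</code>). It assumes a CPU build with a quantized engine such as fbgemm (x86) or qnnpack (ARM) available; the layer sizes are arbitrary.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import torch
import torch.nn as nn
from torch.ao.quantization import quantize_dynamic

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10)).eval()

# Replace every nn.Linear with an int8 dynamically quantized version (returns a copy)
quantized_model = quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
print(quantized_model)

x = torch.randn(4, 64)
with torch.inference_mode():
    diff = (model(x) - quantized_model(x)).abs().max().item()
print(f"max abs difference vs float model: {diff:.4f}")</code></pre></div></div>
<p>Printing the model shows each <code>nn.Linear</code> replaced by a dynamically quantized version, while the original float model is left untouched because <code>quantize_dynamic</code> returns a copy by default.</p>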
</section>
<section id="batch-processing-for-inference" class="level3">
<h3 class="anchored" data-anchor-id="batch-processing-for-inference" id="batch-processing-for-inference">4. Batch Processing for Inference</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb12"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><a href="#cb12-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> batch_inference(model, data_loader, device):</span>
<span id="cb12-2"><a href="#cb12-2" aria-hidden="true" tabindex="-1"></a>    model.<span class="bu">eval</span>()</span>
<span id="cb12-3"><a href="#cb12-3" aria-hidden="true" tabindex="-1"></a>    results <span class="op">=</span> []</span>
<span id="cb12-4"><a href="#cb12-4" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-5"><a href="#cb12-5" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> torch.inference_mode():</span>
<span id="cb12-6"><a href="#cb12-6" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> batch <span class="kw">in</span> data_loader:</span>
<span id="cb12-7"><a href="#cb12-7" aria-hidden="true" tabindex="-1"></a>            inputs <span class="op">=</span> batch.to(device, non_blocking<span class="op">=</span><span class="va">True</span>)  <span class="co"># Assumes the loader yields input tensors only</span></span>
<span id="cb12-8"><a href="#cb12-8" aria-hidden="true" tabindex="-1"></a>            outputs <span class="op">=</span> model(inputs)</span>
<span id="cb12-9"><a href="#cb12-9" aria-hidden="true" tabindex="-1"></a>            results.append(outputs.cpu())</span>
<span id="cb12-10"><a href="#cb12-10" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-11"><a href="#cb12-11" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> torch.cat(results, dim<span class="op">=</span><span class="dv">0</span>)</span></code></pre></div></div>
</section>
</section>
<section id="memory-management" class="level2">
<h2 class="anchored" data-anchor-id="memory-management" id="memory-management">Memory Management</h2>
<section id="memory-efficient-training" class="level3">
<h3 class="anchored" data-anchor-id="memory-efficient-training" id="memory-efficient-training">1. Memory Efficient Training</h3>
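<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb13"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.utils.checkpoint <span class="im">import</span> checkpoint</span>
<span id="cb13-4"><a href="#cb13-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-5"><a href="#cb13-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Drop references you no longer need so the allocator can reuse that memory</span></span>
<span id="cb13-6"><a href="#cb13-6" aria-hidden="true" tabindex="-1"></a>intermediate_results <span class="op">=</span> torch.randn(<span class="dv">1024</span>, <span class="dv">1024</span>)  <span class="co"># placeholder for a large temporary</span></span>
<span id="cb13-7"><a href="#cb13-7" aria-hidden="true" tabindex="-1"></a><span class="kw">del</span> intermediate_results</span>
<span id="cb13-8"><a href="#cb13-8" aria-hidden="true" tabindex="-1"></a><span class="co"># Release cached, unused GPU blocks so other processes can use them; this does</span></span>
<span id="cb13-9"><a href="#cb13-9" aria-hidden="true" tabindex="-1"></a><span class="co"># not free memory that live tensors still hold</span></span>
<span id="cb13-10"><a href="#cb13-10" aria-hidden="true" tabindex="-1"></a>torch.cuda.empty_cache()</span>
<span id="cb13-11"><a href="#cb13-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-12"><a href="#cb13-12" aria-hidden="true" tabindex="-1"></a><span class="co"># Use gradient checkpointing for large models</span></span>
<span id="cb13-13"><a href="#cb13-13" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MyModel(nn.Module):</span>
<span id="cb13-14"><a href="#cb13-14" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>):</span>
<span id="cb13-15"><a href="#cb13-15" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb13-16"><a href="#cb13-16" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.expensive_layer <span class="op">=</span> nn.Sequential(  <span class="co"># hypothetical memory-heavy block</span></span>
<span id="cb13-17"><a href="#cb13-17" aria-hidden="true" tabindex="-1"></a>            nn.Linear(<span class="dv">1024</span>, <span class="dv">4096</span>), nn.ReLU(), nn.Linear(<span class="dv">4096</span>, <span class="dv">1024</span>)</span>
<span id="cb13-18"><a href="#cb13-18" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb13-19"><a href="#cb13-19" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-20"><a href="#cb13-20" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb13-21"><a href="#cb13-21" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Activations inside expensive_layer are recomputed during backward</span></span>
<span id="cb13-22"><a href="#cb13-22" aria-hidden="true" tabindex="-1"></a>        <span class="co"># instead of stored, trading extra compute for lower memory</span></span>
<span id="cb13-23"><a href="#cb13-23" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> checkpoint(<span class="va">self</span>.expensive_layer, x, use_reentrant<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb13-24"><a href="#cb13-24" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x</span></code></pre></div></div>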
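<p>Checkpointing should change only memory use, not the gradients. As a quick sanity check, the sketch below runs <code>torch.utils.checkpoint.checkpoint_sequential</code> on a hypothetical stack of small linear blocks (the sizes are arbitrary) and compares the input gradients with those from a regular backward pass.</p>

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

torch.manual_seed(0)

# Hypothetical stack of 8 identical blocks standing in for a deep model
blocks = nn.Sequential(*[nn.Sequential(nn.Linear(64, 64), nn.ReLU()) for _ in range(8)])
x = torch.randn(16, 64, requires_grad=True)

# Baseline: every intermediate activation is stored for the backward pass
blocks(x).sum().backward()
grad_plain = x.grad.clone()
x.grad = None

# Checkpointed: split into 2 segments; activations inside the checkpointed
# segment are recomputed during backward instead of being stored
out = checkpoint_sequential(blocks, 2, x, use_reentrant=False)
out.sum().backward()
grad_ckpt = x.grad.clone()

print("Gradients match:", torch.allclose(grad_plain, grad_ckpt))
```

<p>On a real model the benefit appears as a lower peak memory. The cost is the extra forward computation performed during backward.</p>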
</section>
<section id="monitor-memory-usage" class="level3">
<h3 class="anchored" data-anchor-id="monitor-memory-usage" id="monitor-memory-usage">2. Monitor Memory Usage</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb14"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb14-1"><a href="#cb14-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb14-2"><a href="#cb14-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-3"><a href="#cb14-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> print_memory_usage():</span>
<span id="cb14-4"><a href="#cb14-4" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> torch.cuda.is_available():</span>
<span id="cb14-5"><a href="#cb14-5" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Memory currently occupied by live tensors</span></span>
<span id="cb14-6"><a href="#cb14-6" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"GPU memory allocated: </span><span class="sc">{</span>torch<span class="sc">.</span>cuda<span class="sc">.</span>memory_allocated() <span class="op">/</span> <span class="fl">1e9</span><span class="sc">:.2f}</span><span class="ss"> GB"</span>)</span>
<span id="cb14-7"><a href="#cb14-7" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Memory held by the caching allocator (allocated plus cached free blocks)</span></span>
<span id="cb14-8"><a href="#cb14-8" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"GPU memory reserved: </span><span class="sc">{</span>torch<span class="sc">.</span>cuda<span class="sc">.</span>memory_reserved() <span class="op">/</span> <span class="fl">1e9</span><span class="sc">:.2f}</span><span class="ss"> GB"</span>)</span>
<span id="cb14-9"><a href="#cb14-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-10"><a href="#cb14-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Monitor during training</span></span>
<span id="cb14-11"><a href="#cb14-11" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(num_epochs):</span>
<span id="cb14-12"><a href="#cb14-12" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> batch_idx, batch <span class="kw">in</span> <span class="bu">enumerate</span>(train_loader):</span>
<span id="cb14-13"><a href="#cb14-13" aria-hidden="true" tabindex="-1"></a>        <span class="co"># ... training code ...</span></span>
<span id="cb14-14"><a href="#cb14-14" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> batch_idx <span class="op">%</span> <span class="dv">100</span> <span class="op">==</span> <span class="dv">0</span>:</span>
<span id="cb14-15"><a href="#cb14-15" aria-hidden="true" tabindex="-1"></a>            print_memory_usage()</span></code></pre></div></div>
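<p>The counters above give a snapshot. PyTorch also tracks the peak allocation, which is what decides whether a step runs out of memory. The following is a minimal sketch. It assumes a CUDA device when one is present, and the helper name <code>measure_peak_memory</code> is our own.</p>

```python
import torch

def measure_peak_memory(fn, *args, **kwargs):
    """Run fn(*args, **kwargs) and return (result, peak CUDA bytes allocated).

    The peak is None on machines without CUDA.
    """
    if not torch.cuda.is_available():
        return fn(*args, **kwargs), None
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()  # start a fresh peak measurement
    result = fn(*args, **kwargs)
    torch.cuda.synchronize()
    return result, torch.cuda.max_memory_allocated()

# Usage with a toy layer (hypothetical sizes)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(1024, 1024).to(device)
result, peak = measure_peak_memory(model, torch.randn(32, 1024, device=device))
if peak is not None:
    print(f"Peak allocated during the call: {peak / 1e9:.3f} GB")
```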
</section>
<section id="memory-efficient-data-loading" class="level3">
<h3 class="anchored" data-anchor-id="memory-efficient-data-loading" id="memory-efficient-data-loading">3. Memory-Efficient Data Loading</h3>
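<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb15"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb15-1"><a href="#cb15-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb15-2"><a href="#cb15-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-3"><a href="#cb15-3" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MemoryEfficientDataset(torch.utils.data.Dataset):</span>
<span id="cb15-4"><a href="#cb15-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, data_paths):</span>
<span id="cb15-5"><a href="#cb15-5" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Keep only file paths in memory; samples stay on disk until requested</span></span>
<span id="cb15-6"><a href="#cb15-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.data_paths <span class="op">=</span> data_paths</span>
<span id="cb15-7"><a href="#cb15-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-8"><a href="#cb15-8" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> load_data(<span class="va">self</span>, path):</span>
<span id="cb15-9"><a href="#cb15-9" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Assumption: each file holds one tensor saved with torch.save</span></span>
<span id="cb15-10"><a href="#cb15-10" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> torch.load(path)</span>
<span id="cb15-11"><a href="#cb15-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-12"><a href="#cb15-12" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__getitem__</span>(<span class="va">self</span>, idx):</span>
<span id="cb15-13"><a href="#cb15-13" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Load data on demand instead of keeping the whole dataset in memory</span></span>
<span id="cb15-14"><a href="#cb15-14" aria-hidden="true" tabindex="-1"></a>        data <span class="op">=</span> <span class="va">self</span>.load_data(<span class="va">self</span>.data_paths[idx])</span>
<span id="cb15-15"><a href="#cb15-15" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> data</span>
<span id="cb15-16"><a href="#cb15-16" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-17"><a href="#cb15-17" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__len__</span>(<span class="va">self</span>):</span>
<span id="cb15-18"><a href="#cb15-18" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="bu">len</span>(<span class="va">self</span>.data_paths)</span></code></pre></div></div>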
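<p>When the data lives in one large array rather than in many files, a memory-mapped NumPy file gives the same on-demand behavior. The sketch below shows one possible approach. The file name, the shapes, and the <code>MemmapDataset</code> class are illustrative assumptions. The memory map is opened lazily so that each <code>DataLoader</code> worker creates its own.</p>

```python
import os
import tempfile

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

class MemmapDataset(Dataset):
    """Reads rows lazily from a .npy file (assumed layout: one sample per row)."""

    def __init__(self, path):
        self.path = path
        self.data = None  # opened on first access, once per worker process
        self.length = np.load(path, mmap_mode="r").shape[0]

    def __getitem__(self, idx):
        if self.data is None:
            self.data = np.load(self.path, mmap_mode="r")
        # np.array(...) copies only this row from disk into RAM
        return torch.from_numpy(np.array(self.data[idx], dtype=np.float32))

    def __len__(self):
        return self.length

# Usage: write a small demo file, then stream it in batches
path = os.path.join(tempfile.mkdtemp(), "features.npy")
np.save(path, np.random.rand(100, 16).astype(np.float32))

loader = DataLoader(MemmapDataset(path), batch_size=32, num_workers=0,
                    pin_memory=torch.cuda.is_available())
batches = [batch for batch in loader]
print(len(batches), batches[0].shape)
```

<p><code>num_workers=0</code> keeps the example portable. In real training you would typically raise it and keep <code>pin_memory=True</code> when batches are copied to a GPU.</p>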
</section>
</section>
<section id="hardware-specific-optimizations" class="level2">
<h2 class="anchored" data-anchor-id="hardware-specific-optimizations" id="hardware-specific-optimizations">Hardware-Specific Optimizations</h2>
<section id="gpu-optimizations" class="level3">
<h3 class="anchored" data-anchor-id="gpu-optimizations" id="gpu-optimizations">1. GPU Optimizations</h3>
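<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb16"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb16-1"><a href="#cb16-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> os</span>
<span id="cb16-2"><a href="#cb16-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb16-3"><a href="#cb16-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.distributed <span class="im">as</span> dist</span>
<span id="cb16-4"><a href="#cb16-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb16-5"><a href="#cb16-5" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.nn.parallel <span class="im">import</span> DistributedDataParallel <span class="im">as</span> DDP</span>
<span id="cb16-6"><a href="#cb16-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-7"><a href="#cb16-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Let cuDNN benchmark and cache the fastest algorithms (best for fixed input sizes)</span></span>
<span id="cb16-8"><a href="#cb16-8" aria-hidden="true" tabindex="-1"></a>torch.backends.cudnn.benchmark <span class="op">=</span> <span class="va">True</span></span>
<span id="cb16-9"><a href="#cb16-9" aria-hidden="true" tabindex="-1"></a><span class="co"># Set to True (and benchmark to False) for reproducible results, at some speed cost</span></span>
<span id="cb16-10"><a href="#cb16-10" aria-hidden="true" tabindex="-1"></a>torch.backends.cudnn.deterministic <span class="op">=</span> <span class="va">False</span></span>
<span id="cb16-11"><a href="#cb16-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-12"><a href="#cb16-12" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="st">"LOCAL_RANK"</span> <span class="kw">in</span> os.environ:</span>
<span id="cb16-13"><a href="#cb16-13" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Preferred: DistributedDataParallel with one process per GPU,</span></span>
<span id="cb16-14"><a href="#cb16-14" aria-hidden="true" tabindex="-1"></a>    <span class="co"># launched via torchrun (which sets LOCAL_RANK)</span></span>
<span id="cb16-15"><a href="#cb16-15" aria-hidden="true" tabindex="-1"></a>    dist.init_process_group(backend<span class="op">=</span><span class="st">"nccl"</span>)</span>
<span id="cb16-16"><a href="#cb16-16" aria-hidden="true" tabindex="-1"></a>    local_rank <span class="op">=</span> <span class="bu">int</span>(os.environ[<span class="st">"LOCAL_RANK"</span>])</span>
<span id="cb16-17"><a href="#cb16-17" aria-hidden="true" tabindex="-1"></a>    torch.cuda.set_device(local_rank)</span>
<span id="cb16-18"><a href="#cb16-18" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> DDP(model.to(local_rank), device_ids<span class="op">=</span>[local_rank])</span>
<span id="cb16-19"><a href="#cb16-19" aria-hidden="true" tabindex="-1"></a><span class="cf">elif</span> torch.cuda.device_count() <span class="op">&gt;</span> <span class="dv">1</span>:</span>
<span id="cb16-20"><a href="#cb16-20" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Simpler single-process alternative; usually slower than DDP</span></span>
<span id="cb16-21"><a href="#cb16-21" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> nn.DataParallel(model)</span></code></pre></div></div>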
</section>
<section id="cpu-optimizations" class="level3">
<h3 class="anchored" data-anchor-id="cpu-optimizations" id="cpu-optimizations">2. CPU Optimizations</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb17"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb17-1"><a href="#cb17-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb17-2"><a href="#cb17-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-3"><a href="#cb17-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Threads for intra-op parallelism; the number of physical cores is a common starting point</span></span>
<span id="cb17-4"><a href="#cb17-4" aria-hidden="true" tabindex="-1"></a>torch.set_num_threads(<span class="dv">4</span>)</span>
<span id="cb17-5"><a href="#cb17-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-6"><a href="#cb17-6" aria-hidden="true" tabindex="-1"></a><span class="co"># oneDNN (formerly MKL-DNN) CPU kernels; this is already the default</span></span>
<span id="cb17-7"><a href="#cb17-7" aria-hidden="true" tabindex="-1"></a>torch.backends.mkldnn.enabled <span class="op">=</span> <span class="va">True</span></span></code></pre></div></div>
</section>
<section id="apple-silicon-mps-support" class="level3">
<h3 class="anchored" data-anchor-id="apple-silicon-mps-support" id="apple-silicon-mps-support">3. Apple Silicon (MPS) Support</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb18"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb18-1"><a href="#cb18-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Use Metal Performance Shaders on Apple Silicon</span></span>
<span id="cb18-2"><a href="#cb18-2" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> torch.backends.mps.is_available():</span>
<span id="cb18-3"><a href="#cb18-3" aria-hidden="true" tabindex="-1"></a>    device <span class="op">=</span> torch.device(<span class="st">"mps"</span>)</span>
<span id="cb18-4"><a href="#cb18-4" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> model.to(device)</span></code></pre></div></div>
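<p>The snippet above handles only the MPS case. A small device-selection helper, here given the hypothetical name <code>pick_device</code>, lets the same script run on CUDA, Apple Silicon, or the CPU.</p>

```python
import torch

def pick_device():
    """Prefer CUDA, then Apple's MPS backend, then fall back to the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

# Usage with a toy layer and input
device = pick_device()
model = torch.nn.Linear(8, 2).to(device)
x = torch.randn(4, 8, device=device)
print(device, model(x).shape)
```

<p>Some operators are not yet implemented for MPS. Setting the environment variable <code>PYTORCH_ENABLE_MPS_FALLBACK=1</code> makes PyTorch run those operators on the CPU instead of raising an error.</p>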
</section>
</section>
<section id="profiling-and-debugging" class="level2">
<h2 class="anchored" data-anchor-id="profiling-and-debugging" id="profiling-and-debugging">Profiling and Debugging</h2>
<section id="pytorch-profiler" class="level3">
<h3 class="anchored" data-anchor-id="pytorch-profiler" id="pytorch-profiler">1. PyTorch Profiler</h3>
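<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb19"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb19-1"><a href="#cb19-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb19-2"><a href="#cb19-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.profiler <span class="im">import</span> profile, record_function, ProfilerActivity</span>
<span id="cb19-3"><a href="#cb19-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-4"><a href="#cb19-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Only request CUDA activity when a GPU is present</span></span>
<span id="cb19-5"><a href="#cb19-5" aria-hidden="true" tabindex="-1"></a>activities <span class="op">=</span> [ProfilerActivity.CPU]</span>
<span id="cb19-6"><a href="#cb19-6" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> torch.cuda.is_available():</span>
<span id="cb19-7"><a href="#cb19-7" aria-hidden="true" tabindex="-1"></a>    activities.append(ProfilerActivity.CUDA)</span>
<span id="cb19-8"><a href="#cb19-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-9"><a href="#cb19-9" aria-hidden="true" tabindex="-1"></a><span class="cf">with</span> profile(</span>
<span id="cb19-10"><a href="#cb19-10" aria-hidden="true" tabindex="-1"></a>    activities<span class="op">=</span>activities,</span>
<span id="cb19-11"><a href="#cb19-11" aria-hidden="true" tabindex="-1"></a>    record_shapes<span class="op">=</span><span class="va">True</span>,</span>
<span id="cb19-12"><a href="#cb19-12" aria-hidden="true" tabindex="-1"></a>    profile_memory<span class="op">=</span><span class="va">True</span>,</span>
<span id="cb19-13"><a href="#cb19-13" aria-hidden="true" tabindex="-1"></a>    with_stack<span class="op">=</span><span class="va">True</span>,</span>
<span id="cb19-14"><a href="#cb19-14" aria-hidden="true" tabindex="-1"></a>) <span class="im">as</span> prof:</span>
<span id="cb19-15"><a href="#cb19-15" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> step, (inputs, targets) <span class="kw">in</span> <span class="bu">enumerate</span>(train_loader):</span>
<span id="cb19-16"><a href="#cb19-16" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> step <span class="op">&gt;=</span> <span class="dv">10</span>:  <span class="co"># a few steps suffice; whole epochs produce huge traces</span></span>
<span id="cb19-17"><a href="#cb19-17" aria-hidden="true" tabindex="-1"></a>            <span class="cf">break</span></span>
<span id="cb19-18"><a href="#cb19-18" aria-hidden="true" tabindex="-1"></a>        optimizer.zero_grad()</span>
<span id="cb19-19"><a href="#cb19-19" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> record_function(<span class="st">"forward"</span>):</span>
<span id="cb19-20"><a href="#cb19-20" aria-hidden="true" tabindex="-1"></a>            outputs <span class="op">=</span> model(inputs)</span>
<span id="cb19-21"><a href="#cb19-21" aria-hidden="true" tabindex="-1"></a>            loss <span class="op">=</span> criterion(outputs, targets)</span>
<span id="cb19-22"><a href="#cb19-22" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> record_function(<span class="st">"backward"</span>):</span>
<span id="cb19-23"><a href="#cb19-23" aria-hidden="true" tabindex="-1"></a>            loss.backward()</span>
<span id="cb19-24"><a href="#cb19-24" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> record_function(<span class="st">"optimizer"</span>):</span>
<span id="cb19-25"><a href="#cb19-25" aria-hidden="true" tabindex="-1"></a>            optimizer.step()</span>
<span id="cb19-26"><a href="#cb19-26" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-27"><a href="#cb19-27" aria-hidden="true" tabindex="-1"></a><span class="co"># Chrome trace format: open in chrome://tracing or the Perfetto UI</span></span>
<span id="cb19-28"><a href="#cb19-28" aria-hidden="true" tabindex="-1"></a>prof.export_chrome_trace(<span class="st">"trace.json"</span>)</span></code></pre></div></div>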
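<p>Profiling every step of a long loop is rarely necessary. A <code>torch.profiler.schedule</code> restricts recording to a few steps after a warm-up, and the <code>on_trace_ready</code> callback decides what happens to each recorded window. In this sketch the model, optimizer, and data are toy stand-ins. Passing <code>torch.profiler.tensorboard_trace_handler("./log")</code> as <code>on_trace_ready</code> would instead write traces for TensorBoard's profiler plugin.</p>

```python
import torch
import torch.nn as nn
from torch.profiler import ProfilerActivity, profile, schedule

# Toy model and random data so the sketch runs anywhere (hypothetical sizes)
model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

summaries = []

def on_ready(prof):
    # Called each time an active recording window finishes
    table = prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=5)
    summaries.append(table)
    print(table)

# Step 0: idle; step 1: warm-up (traced, results discarded); steps 2-4: recorded
with profile(
    activities=[ProfilerActivity.CPU],
    schedule=schedule(wait=1, warmup=1, active=3, repeat=1),
    on_trace_ready=on_ready,
) as prof:
    for _ in range(6):
        inputs = torch.randn(64, 128)
        targets = torch.randint(0, 10, (64,))
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()
        prof.step()  # mark the end of one training step
```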
</section>
<section id="memory-profiling" class="level3">
<h3 class="anchored" data-anchor-id="memory-profiling" id="memory-profiling">2. Memory Profiling</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb20"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb20-1"><a href="#cb20-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.profiler <span class="im">import</span> profile</span>
<span id="cb20-2"><a href="#cb20-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb20-3"><a href="#cb20-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Profile memory usage of a single forward pass</span></span>
<span id="cb20-4"><a href="#cb20-4" aria-hidden="true" tabindex="-1"></a><span class="cf">with</span> profile(profile_memory<span class="op">=</span><span class="va">True</span>) <span class="im">as</span> prof:</span>
<span id="cb20-5"><a href="#cb20-5" aria-hidden="true" tabindex="-1"></a>    model(inputs)</span>
<span id="cb20-6"><a href="#cb20-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb20-7"><a href="#cb20-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Sort by "self_cpu_memory_usage" instead when profiling on the CPU</span></span>
<span id="cb20-8"><a href="#cb20-8" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(prof.key_averages().table(sort_by<span class="op">=</span><span class="st">"self_cuda_memory_usage"</span>, row_limit<span class="op">=</span><span class="dv">10</span>))</span></code></pre></div></div>
</section>
<section id="speed-benchmarking" class="level3">
<h3 class="anchored" data-anchor-id="speed-benchmarking" id="speed-benchmarking">3. Speed Benchmarking</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb21"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb21-1"><a href="#cb21-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb21-2"><a href="#cb21-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb21-3"><a href="#cb21-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-4"><a href="#cb21-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> benchmark_model(model, input_tensor, num_runs<span class="op">=</span><span class="dv">100</span>, num_warmup<span class="op">=</span><span class="dv">10</span>):</span>
<span id="cb21-5"><a href="#cb21-5" aria-hidden="true" tabindex="-1"></a>    model.<span class="bu">eval</span>()</span>
<span id="cb21-6"><a href="#cb21-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-7"><a href="#cb21-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> sync():</span>
<span id="cb21-8"><a href="#cb21-8" aria-hidden="true" tabindex="-1"></a>        <span class="co"># CUDA kernels run asynchronously; wait for them before reading the clock</span></span>
<span id="cb21-9"><a href="#cb21-9" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> input_tensor.is_cuda:</span>
<span id="cb21-10"><a href="#cb21-10" aria-hidden="true" tabindex="-1"></a>            torch.cuda.synchronize()</span>
<span id="cb21-11"><a href="#cb21-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-12"><a href="#cb21-12" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> torch.no_grad():</span>
<span id="cb21-13"><a href="#cb21-13" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Warmup (cuDNN autotuning, allocator caching, lazy initialization)</span></span>
<span id="cb21-14"><a href="#cb21-14" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(num_warmup):</span>
<span id="cb21-15"><a href="#cb21-15" aria-hidden="true" tabindex="-1"></a>            _ <span class="op">=</span> model(input_tensor)</span>
<span id="cb21-16"><a href="#cb21-16" aria-hidden="true" tabindex="-1"></a>        sync()</span>
<span id="cb21-17"><a href="#cb21-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-18"><a href="#cb21-18" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Benchmark with a monotonic, high-resolution clock</span></span>
<span id="cb21-19"><a href="#cb21-19" aria-hidden="true" tabindex="-1"></a>        start_time <span class="op">=</span> time.perf_counter()</span>
<span id="cb21-20"><a href="#cb21-20" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(num_runs):</span>
<span id="cb21-21"><a href="#cb21-21" aria-hidden="true" tabindex="-1"></a>            _ <span class="op">=</span> model(input_tensor)</span>
<span id="cb21-22"><a href="#cb21-22" aria-hidden="true" tabindex="-1"></a>        sync()</span>
<span id="cb21-23"><a href="#cb21-23" aria-hidden="true" tabindex="-1"></a>        end_time <span class="op">=</span> time.perf_counter()</span>
<span id="cb21-24"><a href="#cb21-24" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-25"><a href="#cb21-25" aria-hidden="true" tabindex="-1"></a>    avg_time <span class="op">=</span> (end_time <span class="op">-</span> start_time) <span class="op">/</span> num_runs</span>
<span id="cb21-26"><a href="#cb21-26" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Average inference time: </span><span class="sc">{</span>avg_time<span class="op">*</span><span class="dv">1000</span><span class="sc">:.2f}</span><span class="ss"> ms"</span>)</span>
<span id="cb21-27"><a href="#cb21-27" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> avg_time</span></code></pre></div></div>
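<p>For more robust numbers, PyTorch provides <code>torch.utils.benchmark.Timer</code>. It performs warm-up runs, synchronizes CUDA when necessary, and reports statistics over many runs. The following is a sketch with a toy model of arbitrary size.</p>

```python
import torch
import torch.utils.benchmark as benchmark

# Toy model and input (hypothetical sizes)
model = torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.ReLU()).eval()
x = torch.randn(64, 512)

def run():
    with torch.inference_mode():
        return model(x)

timer = benchmark.Timer(
    stmt="run()",
    globals={"run": run},
    num_threads=torch.get_num_threads(),  # Timer defaults to a single thread
)
# blocked_autorange keeps timing blocks of runs until min_run_time has elapsed
measurement = timer.blocked_autorange(min_run_time=0.5)
print(f"median: {measurement.median * 1e3:.3f} ms, mean: {measurement.mean * 1e3:.3f} ms")
```

<p>Because <code>Timer</code> runs single-threaded unless <code>num_threads</code> is set, the sketch passes the current thread count explicitly so that the numbers reflect normal execution.</p>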
</section>
</section>
<section id="best-practices-summary" class="level2">
<h2 class="anchored" data-anchor-id="best-practices-summary" id="best-practices-summary">Best Practices Summary</h2>
<ol type="1">
<li><strong>Always profile first</strong> - Identify bottlenecks before optimizing</li>
<li><strong>Use mixed precision</strong> - Often a significant speedup on GPUs with Tensor Cores, with minimal accuracy loss</li>
<li><strong>Optimize data loading</strong> - Use multiple workers and pin memory</li>
<li><strong>Batch operations</strong> - Avoid loops over individual samples</li>
<li><strong>Model compilation</strong> - Use <code>torch.compile()</code> for PyTorch 2.0+</li>
<li><strong>Memory management</strong> - Monitor and optimize memory usage</li>
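<li><strong>Hardware utilization</strong> - Keep accelerators busy with adequate batch sizes and fast input pipelines, and use multiple GPUs when available</li>
<li><strong>Quantization for inference</strong> - Reduce model size and often increase speed, usually at a small accuracy cost</li>
<li><strong>TorchScript for production</strong> - Serialize models for deployment without a Python runtime (newer releases also offer <code>torch.export</code>)</li>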
<li><strong>Regular checkpointing</strong> - Save training progress and enable resumption</li>
</ol>
</section>
<section id="common-pitfalls-to-avoid" class="level2">
<h2 class="anchored" data-anchor-id="common-pitfalls-to-avoid" id="common-pitfalls-to-avoid">Common Pitfalls to Avoid</h2>
<ul>
<li>Moving tensors between CPU and GPU unnecessarily</li>
<li>Using small batch sizes that underutilize hardware</li>
<li>Not using <code>torch.no_grad()</code> (or <code>torch.inference_mode()</code>) during inference</li>
<li>Creating tensors in loops instead of batching</li>
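<li>Holding references to tensors you no longer need; <code>torch.cuda.empty_cache()</code> only returns cached blocks to the driver and cannot free memory that live tensors still use</li>
<li>Forcing unnecessary CPU-GPU synchronization, for example calling <code>.item()</code> or <code>.cpu()</code> every step instead of using asynchronous copies (<code>non_blocking=True</code> with pinned memory)</li>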
<li>Not leveraging built-in optimized functions</li>
</ul>
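<p>To make the "tensors in loops" pitfall concrete, the sketch below computes the same projection twice on random data: once per sample and once as a single batched matrix multiply. The sizes are arbitrary. Both give the same result, but the batched version issues one operation instead of a thousand.</p>

```python
import torch

torch.manual_seed(0)
x = torch.randn(1000, 64)  # 1000 samples, 64 features
w = torch.randn(64, 32)

# Pitfall: a Python loop runs 1000 tiny matmuls and then stacks the results
looped = torch.stack([sample @ w for sample in x])

# Better: one batched matmul over the whole tensor
batched = x @ w

print(torch.allclose(looped, batched, atol=1e-4))
```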


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Why I Choose PyTorch for Deep Learning]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/why-pytorch/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/why-pytorch/</guid>
      <pubDate>Sun, 01 Jun 2025 00:00:00 GMT</pubDate>
      
      <category>news</category>
      <content:encoded><![CDATA[






<section id="why-i-choose-pytorch-for-deep-learning" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/why-pytorch/pytorch.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>When it comes to deep learning frameworks, the landscape offers several compelling options. TensorFlow, JAX, and PyTorch each have their strengths, but after working extensively with multiple frameworks, PyTorch has become my go-to choice for deep learning projects. Here’s why this dynamic framework continues to win over researchers and practitioners alike.</p>
</section>
<section id="sec-dynamic-graphs" class="level2">
<h2 class="anchored" data-anchor-id="sec-dynamic-graphs" id="sec-dynamic-graphs">The Power of Dynamic Computation Graphs</h2>
<p>PyTorch’s defining feature is its <strong>dynamic computation graph</strong>, also known as “define-by-run.” In static-graph frameworks such as TensorFlow 1.x, you must define the entire network architecture upfront. PyTorch instead builds the computational graph on the fly as operations execute, which gives considerable flexibility for complex architectures and experimental research.</p>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Dynamic vs Static Graphs
</div>
</div>
<div class="callout-body-container callout-body">
<p>Consider debugging a recurrent neural network with variable sequence lengths. In PyTorch, you can step through your code line by line, inspect tensors at any point, and modify the network behavior based on runtime conditions.</p>
</div>
</div>
<p>This dynamic nature makes PyTorch feel more like writing regular Python code rather than wrestling with a rigid framework.</p>
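<p>The callout's RNN scenario can be made concrete with a minimal sketch. It assumes a toy <code>nn.RNNCell</code> and random inputs. An ordinary Python loop walks over sequences of different lengths, and autograd records a graph of matching length on each call.</p>

```python
import torch
import torch.nn as nn

cell = nn.RNNCell(input_size=4, hidden_size=8)  # toy sizes

def encode(sequence):
    # sequence: (seq_len, 4); seq_len may differ from call to call
    h = torch.zeros(1, 8)
    for step in sequence:  # plain Python loop, recorded as it runs
        h = cell(step.unsqueeze(0), h)
    return h

short = encode(torch.randn(3, 4))  # graph with 3 recurrent steps
long = encode(torch.randn(7, 4))   # graph with 7 recurrent steps
print(short.shape, long.shape)
```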
</section>
<section id="sec-pythonic-design" class="level2">
<h2 class="anchored" data-anchor-id="sec-pythonic-design" id="sec-pythonic-design">Pythonic Design Philosophy</h2>
<p>PyTorch embraces Python’s design principles, making it intuitive for developers already familiar with the language. The API feels natural and follows Python conventions closely. Operations like tensor manipulation, automatic differentiation, and model definition align with how Python developers expect to write code.</p>
<p>The framework integrates seamlessly with the broader Python ecosystem:</p>
<ul>
<li><strong>NumPy</strong>: Arrays convert to tensors with <code>torch.from_numpy()</code>, which shares memory instead of copying</li>
<li><strong>Matplotlib</strong>: Works for visualization once tensors are on the CPU (and detached, if they require gradients)</li>
<li><strong>Standard debugging tools</strong>: Function as expected</li>
</ul>
<p>This integration reduces the learning curve and allows developers to leverage existing Python skills.</p>
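<p>The NumPy bridge is cheap because it avoids copies where it can. The following short illustration uses an arbitrary array.</p>

```python
import numpy as np
import torch

arr = np.arange(6, dtype=np.float32).reshape(2, 3)

t = torch.from_numpy(arr)   # no copy: the tensor shares arr's memory
t *= 2                      # so this in-place update is visible through arr
back = t.numpy()            # CPU tensor -> ndarray, again without copying
copied = torch.tensor(arr)  # torch.tensor always makes a copy

print(arr)
print(np.shares_memory(arr, back))
```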
</section>
<section id="sec-research-first" class="level2">
<h2 class="anchored" data-anchor-id="sec-research-first" id="sec-research-first">Research-First Mentality</h2>
<p>PyTorch originated from the research community and maintains strong connections to academic work. The framework prioritizes flexibility and experimentation over rigid optimization, making it ideal for cutting-edge research where novel architectures and training procedures are constantly emerging.</p>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Research Impact
</div>
</div>
<div class="callout-body-container callout-body">
<p>Major research breakthroughs often appear first in PyTorch implementations. The framework’s flexibility allows researchers to quickly prototype new ideas without fighting against framework constraints.</p>
</div>
</div>
<p>This research-first approach has created a virtuous cycle where PyTorch continues to attract top researchers, leading to more innovations and better tooling.</p>
</section>
<section id="sec-debugging" class="level2">
<h2 class="anchored" data-anchor-id="sec-debugging" id="sec-debugging">Exceptional Debugging Experience</h2>
<p>Debugging deep learning models can be notoriously challenging, but PyTorch makes this process more manageable. Since PyTorch code executes imperatively, you can use standard Python debugging tools effectively:</p>
<ul>
<li><code>pdb</code> debugger</li>
<li>Print statements</li>
<li>IDE debuggers</li>
</ul>
<p>Because operations execute eagerly, a failing operation raises an ordinary Python exception, and its traceback points to the line in your code where it occurred. When tensor shapes don’t match or an operation fails, PyTorch usually reports the offending shapes or types directly instead of surfacing cryptic errors from deep inside the framework’s internals.</p>
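<p>As a small illustration, the hypothetical <code>TinyNet</code> below contains a deliberate shape bug. Because the forward pass is plain Python, the error surfaces as a normal <code>RuntimeError</code> at the offending call. You could also drop a <code>print</code> or <code>breakpoint()</code> into <code>forward</code> to inspect <code>h</code> before the failure.</p>

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 16)
        self.fc2 = nn.Linear(32, 4)  # deliberate bug: in_features should be 16

    def forward(self, x):
        h = torch.relu(self.fc1(x))
        # Any Python tool works here: print(h.shape), or breakpoint() for pdb
        return self.fc2(h)

try:
    TinyNet()(torch.randn(2, 8))
except RuntimeError as err:
    # Raised immediately by the self.fc2(h) call; the traceback names that line
    print(f"Caught: {err}")
```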
</section>
<section id="sec-ecosystem" class="level2">
<h2 class="anchored" data-anchor-id="sec-ecosystem" id="sec-ecosystem">Mature Ecosystem and Community</h2>
<p>PyTorch has cultivated a vibrant ecosystem of libraries and tools:</p>
<div id="tbl-ecosystem" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-tbl figure">
<figcaption class="quarto-float-caption-top quarto-float-caption quarto-float-tbl" id="tbl-ecosystem-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Table&nbsp;1: Key PyTorch Ecosystem Libraries
</figcaption>
<div aria-describedby="tbl-ecosystem-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<table class="caption-top table">
<colgroup>
<col style="width: 50%">
<col style="width: 50%">
</colgroup>
<thead>
<tr class="header">
<th>Library</th>
<th>Purpose</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><strong>PyTorch Lightning</strong></td>
<td>Simplifies training loops and experiment management</td>
</tr>
<tr class="even">
<td><strong>Transformers</strong> (Hugging Face)</td>
<td>State-of-the-art pre-trained models</td>
</tr>
<tr class="odd">
<td><strong>TorchVision</strong></td>
<td>Computer vision utilities</td>
</tr>
<tr class="even">
<td><strong>TorchText</strong></td>
<td>Natural language processing tools (no longer under active development)</td>
</tr>
<tr class="odd">
<td><strong>TorchAudio</strong></td>
<td>Audio processing capabilities</td>
</tr>
</tbody>
</table>
</div>
</figure>
</div>
<p>The community actively contributes tutorials, examples, and extensions. PyTorch’s documentation is comprehensive and includes practical examples alongside API references. The official tutorials cover everything from basic tensor operations to advanced topics like distributed training and model optimization.</p>
</section>
<section id="sec-performance" class="level2">
<h2 class="anchored" data-anchor-id="sec-performance" id="sec-performance">Performance and Production Readiness</h2>
<p>While PyTorch initially focused on research flexibility, recent versions have significantly improved production capabilities:</p>
<div class="tabset-margin-container"></div><div class="panel-tabset">
<ul class="nav nav-tabs" role="tablist"><li class="nav-item" role="presentation"><a class="nav-link active" id="tabset-1-1-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-1" role="tab" aria-controls="tabset-1-1" aria-selected="true" href="">Deployment Tools</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-1-2-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-2" role="tab" aria-controls="tabset-1-2" aria-selected="false" href="">Performance Features</a></li></ul>
<div class="tab-content">
<div id="tabset-1-1" class="tab-pane active" role="tabpanel" aria-labelledby="tabset-1-1-tab">
<ul>
<li><strong>TorchScript</strong>: Converts dynamic PyTorch models to static representations</li>
<li><strong>TorchServe</strong>: Provides model serving infrastructure</li>
<li><strong>PyTorch Mobile</strong>: Enables deployment on mobile devices (its successor is ExecuTorch)</li>
</ul>
</div>
<div id="tabset-1-2" class="tab-pane" role="tabpanel" aria-labelledby="tabset-1-2-tab">
<ul>
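<li><strong>JIT Compilation</strong>: TorchScript and, since PyTorch 2.0, <code>torch.compile</code> optimize computation graphs</li>
<li><strong>GPU Utilization</strong>: A caching memory allocator, asynchronous CUDA execution, and mixed precision keep accelerators busy</li>
<li><strong>Competitive Performance</strong>: Comparable to other major frameworks on common workloads</li>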
</ul>
</div>
</div>
</div>
<p>For most applications, this makes PyTorch’s performance competitive with the alternatives without sacrificing its flexibility.</p>
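<p>The following sketches the TorchScript path. The hypothetical <code>Gate</code> module has a data-dependent branch, and <code>torch.jit.script</code> compiles it, including the <code>if</code>, into a serialized program that can be reloaded without the original Python class. An in-memory buffer stands in for a file on disk.</p>

```python
import io

import torch
import torch.nn as nn

class Gate(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Data-dependent branch: scripting keeps it as real control flow
        if x.sum() > 0:
            return torch.relu(self.linear(x))
        return self.linear(x)

model = Gate().eval()
scripted = torch.jit.script(model)

buffer = io.BytesIO()  # stands in for a file on disk
torch.jit.save(scripted, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

x = torch.randn(2, 4)
print(torch.allclose(model(x), loaded(x)))
```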
</section>
<section id="sec-autograd" class="level2">
<h2 class="anchored" data-anchor-id="sec-autograd" id="sec-autograd">Automatic Differentiation Done Right</h2>
<p>PyTorch’s automatic differentiation system, <strong>Autograd</strong>, elegantly handles gradient computation. The system tracks operations on tensors and builds a computational graph automatically. Computing gradients requires just a single <code>.backward()</code> call.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Example of PyTorch's automatic differentiation</span></span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Create tensors with gradient tracking</span></span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a>x <span class="op">=</span> torch.tensor([<span class="fl">2.0</span>], requires_grad<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a>y <span class="op">=</span> torch.tensor([<span class="fl">3.0</span>], requires_grad<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a><span class="co"># Define computation</span></span>
<span id="cb1-9"><a href="#cb1-9" aria-hidden="true" tabindex="-1"></a>z <span class="op">=</span> x <span class="op">*</span> y <span class="op">+</span> x<span class="op">**</span><span class="dv">2</span></span>
<span id="cb1-10"><a href="#cb1-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-11"><a href="#cb1-11" aria-hidden="true" tabindex="-1"></a><span class="co"># Compute gradients automatically</span></span>
<span id="cb1-12"><a href="#cb1-12" aria-hidden="true" tabindex="-1"></a>z.backward()</span>
<span id="cb1-13"><a href="#cb1-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-14"><a href="#cb1-14" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"dz/dx = </span><span class="sc">{</span>x<span class="sc">.</span>grad<span class="sc">}</span><span class="ss">"</span>)  <span class="co"># dz/dx = tensor([7.]), i.e. y + 2x</span></span>
<span id="cb1-15"><a href="#cb1-15" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"dz/dy = </span><span class="sc">{</span>y<span class="sc">.</span>grad<span class="sc">}</span><span class="ss">"</span>)  <span class="co"># dz/dy = tensor([2.]), i.e. x</span></span></code></pre></div></div>
<p>The differentiation system integrates smoothly with control flow, making it easy to implement complex architectures with conditional execution, loops, and dynamic behavior. This capability proves essential for advanced architectures like attention mechanisms and recursive networks.</p>
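<p>For example, in the sketch below (the function <code>double_until</code> is illustrative) the number of loop iterations depends on the input's value at runtime. Autograd differentiates whichever path actually executed.</p>

```python
import torch

def double_until(x, limit=10.0):
    # How many times the loop runs depends on x's runtime value;
    # autograd records only the operations that actually execute
    y = x
    while y.norm() < limit:
        y = y * 2
    return y.sum()

x = torch.tensor([1.0, 1.0], requires_grad=True)
out = double_until(x)
out.backward()
print(x.grad)  # tensor([8., 8.])
```

<p>Here the loop runs three times, so the output equals <code>8 * x.sum()</code> and each gradient entry is 8.</p>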
</section>
<section id="sec-industry" class="level2">
<h2 class="anchored" data-anchor-id="sec-industry" id="sec-industry">Growing Industry Adoption</h2>
<p>While TensorFlow dominated early industry adoption, PyTorch has gained significant ground in production environments. Major companies using PyTorch include:</p>
<ul>
<li><strong>Meta</strong> (Facebook)</li>
<li><strong>Tesla</strong></li>
<li><strong>OpenAI</strong></li>
</ul>
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Important</span>Unified Development
</div>
</div>
<div class="callout-body-container callout-body">
<p>Many companies now choose PyTorch for both research and production, eliminating the need to translate models between frameworks. This unified approach reduces complexity and accelerates the path from research to deployment.</p>
</div>
</div>
</section>
<section id="sec-future-proof" class="level2">
<h2 class="anchored" data-anchor-id="sec-future-proof" id="sec-future-proof">Future-Proof Architecture</h2>
<p>PyTorch’s design positions it well for new directions in deep learning. Because models are ordinary Python code executed eagerly, emerging paradigms such as these can usually be implemented without changes to the framework itself:</p>
<ul>
<li>Few-shot learning</li>
<li>Meta-learning</li>
<li>Neural architecture search</li>
</ul>
<p>The PyTorch team ships regular releases with performance improvements, new operators, and better tooling, while aiming to preserve backward compatibility and typically deprecating features before removing them.</p>
</section>
<section id="sec-conclusion" class="level2">
<h2 class="anchored" data-anchor-id="sec-conclusion" id="sec-conclusion">Making the Choice</h2>
<p>Choosing PyTorch means prioritizing:</p>
<ol type="1">
<li><strong>Flexibility</strong> - Dynamic computation graphs</li>
<li><strong>Ease of use</strong> - Pythonic design</li>
<li><strong>Modern development practices</strong> - Excellent debugging and tooling</li>
</ol>
<p>The framework excels in research and education, and increasingly in production. Its dynamic nature, strong debugging support, and rich ecosystem make it a compelling choice for deep learning projects.</p>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Recommendation
</div>
</div>
<div class="callout-body-container callout-body">
<p>For anyone starting a new deep learning project or considering a framework switch, PyTorch offers a modern, flexible foundation that grows with your needs and supports both experimentation and deployment.</p>
</div>
</div>
<p>While other frameworks have their merits, PyTorch’s combination of research-friendly design, production readiness, and vibrant community creates a compelling package for deep learning practitioners. The framework continues evolving rapidly while maintaining its core philosophy of putting developers first.</p>
<hr>
<p><em>This guide reflects the current state of PyTorch and its ecosystem. The deep learning landscape continues to evolve, but PyTorch’s foundational strengths position it well for future developments.</em></p>


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Kubeflow: A Comprehensive Guide to Machine Learning on Kubernetes]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/deployment/kubeflow-explain/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/deployment/kubeflow-explain/</guid>
      <pubDate>Sat, 31 May 2025 00:00:00 GMT</pubDate>
      
      <category>tutorial</category>
      <category>mlops</category>
      <category>beginner</category>
      <content:encoded><![CDATA[






<section id="kubeflow-a-comprehensive-guide-to-machine-learning-on-kubernetes" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/deployment/kubeflow-explain/kubeflow.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Key Takeaway
</div>
</div>
<div class="callout-body-container callout-body">
<p>Kubeflow is an open-source machine learning platform designed to make deployments of machine learning workflows on Kubernetes simple, portable, and scalable.</p>
</div>
</div>
<p>Originally developed by Google and now maintained by the open-source Kubeflow community, Kubeflow provides an ecosystem for managing the entire machine learning lifecycle, from experimentation and training to serving and monitoring, all within a Kubernetes environment.</p>
<p>The platform addresses one of the most significant challenges in modern machine learning: bridging the gap between data science experimentation and production deployment. By leveraging Kubernetes’ container orchestration capabilities, Kubeflow enables ML teams to build, deploy, and manage machine learning systems at scale while maintaining consistency across different environments.</p>
</section>
<section id="architecture-and-core-components" class="level2">
<h2 class="anchored" data-anchor-id="architecture-and-core-components" id="architecture-and-core-components">Architecture and Core Components</h2>
<section id="high-level-architecture" class="level3">
<h3 class="anchored" data-anchor-id="high-level-architecture" id="high-level-architecture">High-Level Architecture</h3>
<p>Kubeflow follows a microservices architecture built on top of Kubernetes. The platform consists of several interconnected components, each serving specific functions in the ML workflow:</p>
<div class="tabset-margin-container"></div><div class="panel-tabset">
<ul class="nav nav-tabs" role="tablist"><li class="nav-item" role="presentation"><a class="nav-link active" id="tabset-1-1-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-1" role="tab" aria-controls="tabset-1-1" aria-selected="true" href="">Central Dashboard</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-1-2-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-2" role="tab" aria-controls="tabset-1-2" aria-selected="false" href="">Kubeflow Pipelines</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-1-3-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-3" role="tab" aria-controls="tabset-1-3" aria-selected="false" href="">Kubeflow Notebooks</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-1-4-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-4" role="tab" aria-controls="tabset-1-4" aria-selected="false" href="">Katib</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-1-5-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-5" role="tab" aria-controls="tabset-1-5" aria-selected="false" href="">KServe</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-1-6-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-6" role="tab" aria-controls="tabset-1-6" aria-selected="false" href="">Training Operators</a></li></ul>
<div class="tab-content">
<div id="tabset-1-1" class="tab-pane active" role="tabpanel" aria-labelledby="tabset-1-1-tab">
<p>The web-based user interface that provides a unified view of all Kubeflow components and allows users to manage their ML workflows through a single interface.</p>
</div>
<div id="tabset-1-2" class="tab-pane" role="tabpanel" aria-labelledby="tabset-1-2-tab">
<p>A comprehensive solution for building and deploying portable, scalable machine learning workflows based on Docker containers. It includes a user interface for managing and tracking experiments, jobs, and runs.</p>
</div>
<div id="tabset-1-3" class="tab-pane" role="tabpanel" aria-labelledby="tabset-1-3-tab">
<p>Provides Jupyter notebook servers for interactive development and experimentation. These notebooks run as Kubernetes pods and can be configured with different resource requirements and ML frameworks.</p>
</div>
<div id="tabset-1-4" class="tab-pane" role="tabpanel" aria-labelledby="tabset-1-4-tab">
<p>An automated machine learning system for hyperparameter tuning and neural architecture search. It supports various optimization algorithms and can run experiments across multiple nodes.</p>
</div>
<div id="tabset-1-5" class="tab-pane" role="tabpanel" aria-labelledby="tabset-1-5-tab">
<p>A serverless inferencing platform that provides standardized model serving capabilities with features like canary deployments, autoscaling, and multi-framework support.</p>
</div>
<div id="tabset-1-6" class="tab-pane" role="tabpanel" aria-labelledby="tabset-1-6-tab">
<p>A collection of Kubernetes operators for distributed training across different ML frameworks including TensorFlow, PyTorch, MPI, XGBoost, and PaddlePaddle.</p>
</div>
</div>
</div>
</section>
<section id="core-components-deep-dive" class="level3">
<h3 class="anchored" data-anchor-id="core-components-deep-dive" id="core-components-deep-dive">Core Components Deep Dive</h3>
<section id="kubeflow-pipelines-1" class="level4">
<h4 class="anchored" data-anchor-id="kubeflow-pipelines-1">Kubeflow Pipelines</h4>
<p>Kubeflow Pipelines is the workflow orchestration engine at the heart of the platform. It lets users define, deploy, and manage end-to-end ML workflows as code. Key features include:</p>
<ul>
<li><p><strong>Pipeline Definition</strong>: Workflows are defined using the Kubeflow Pipelines SDK, which allows data scientists to create reproducible, parameterized pipelines using Python. Each pipeline consists of multiple components that can be reused across different workflows.</p></li>
<li><p><strong>Component Library</strong>: A rich ecosystem of pre-built components for common ML tasks such as data preprocessing, model training, evaluation, and deployment. Users can also create custom components using containerized applications.</p></li>
<li><p><strong>Experiment Management</strong>: Built-in experiment tracking capabilities that allow teams to compare different pipeline runs, track metrics, and manage model versions systematically.</p></li>
<li><p><strong>Artifact Management</strong>: Automatic tracking and versioning of pipeline artifacts including datasets, models, and intermediate results, enabling full reproducibility of ML experiments.</p></li>
</ul>
</section>
<section id="kubeflow-notebooks-1" class="level4">
<h4 class="anchored" data-anchor-id="kubeflow-notebooks-1">Kubeflow Notebooks</h4>
<p>The notebook component provides a managed Jupyter environment optimized for machine learning workloads:</p>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Resource Management
</div>
</div>
<div class="callout-body-container callout-body">
<p>Users can specify CPU, memory, and GPU requirements for each notebook server, sizing it to the workload at hand.</p>
</div>
</div>
<ul>
<li><p><strong>Multi-Framework Support</strong>: Pre-configured notebook images with popular ML frameworks and tools such as TensorFlow, PyTorch, scikit-learn, and R, eliminating most environment setup overhead.</p></li>
<li><p><strong>Persistent Storage</strong>: Integration with Kubernetes persistent volumes ensures that notebook work persists across server restarts and provides shared storage capabilities for team collaboration.</p></li>
<li><p><strong>Custom Images</strong>: Support for custom Docker images enables teams to create standardized environments with specific tool configurations and dependencies.</p></li>
</ul>
</section>
<section id="katib-for-automl" class="level4">
<h4 class="anchored" data-anchor-id="katib-for-automl">Katib for AutoML</h4>
<p>Katib provides automated machine learning capabilities focused on hyperparameter optimization and neural architecture search:</p>
<ul>
<li><p><strong>Optimization Algorithms</strong>: Support for various optimization strategies including random search, grid search, Bayesian optimization, and evolutionary algorithms.</p></li>
<li><p><strong>Parallel Execution</strong>: Distributed hyperparameter tuning across multiple nodes, significantly reducing experiment time for computationally intensive tasks.</p></li>
<li><p><strong>Early Stopping</strong>: Early stopping algorithms, such as the median stopping rule, terminate underperforming trials so their resources can go to more promising ones.</p></li>
<li><p><strong>Multi-Objective Optimization</strong>: Support for optimizing multiple metrics simultaneously, useful for scenarios requiring trade-offs between accuracy, latency, and model size.</p></li>
</ul>
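<p>To make the search and early-stopping ideas concrete, here is a minimal, framework-free sketch of random search combined with a simplified median stopping rule (the idea behind Katib's median-stop early stopping). It does not use Katib's API: the objective <code>toy_validation_loss</code>, the learning-rate range, and the trial budget are hypothetical stand-ins for a real training job.</p>

```python
import math
import random
import statistics


def toy_validation_loss(lr: float, epoch: int) -> float:
    """Hypothetical objective: best near lr = 0.01, improves with epochs."""
    return (math.log10(lr) + 2.0) ** 2 + 1.0 / (epoch + 1)


def random_search_with_median_stop(num_trials: int = 20, max_epochs: int = 10,
                                   seed: int = 0) -> tuple[float, float, int]:
    """Random search over the learning rate with a simplified median stopping rule."""
    rng = random.Random(seed)
    history: dict[int, list[float]] = {}  # epoch -> losses reported at that epoch
    best_lr, best_loss, stopped = float("nan"), math.inf, 0
    for _ in range(num_trials):
        lr = 10 ** rng.uniform(-4, -1)  # log-uniform sample in [1e-4, 1e-1]
        for epoch in range(max_epochs):
            loss = toy_validation_loss(lr, epoch)
            peers = history.setdefault(epoch, [])
            # Stop if worse than the median of earlier trials at this epoch
            # (only once a few trials have reported, to avoid noisy decisions)
            should_stop = len(peers) >= 3 and loss > statistics.median(peers)
            peers.append(loss)
            if should_stop:
                stopped += 1
                break
        else:
            # Only trials that ran to completion compete for the best result
            if loss < best_loss:
                best_lr, best_loss = lr, loss
    return best_lr, best_loss, stopped


best_lr, best_loss, stopped = random_search_with_median_stop()
print(f"best lr={best_lr:.4g}, loss={best_loss:.4f}, stopped early: {stopped}")
```

<p>Real median stopping implementations compare against running averages of completed trials and let you tune how many trials must report first; the fixed threshold of three here is an arbitrary choice for the sketch.</p>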
</section>
<section id="kserve-model-serving" class="level4">
<h4 class="anchored" data-anchor-id="kserve-model-serving">KServe Model Serving</h4>
<p>KServe provides enterprise-grade model serving capabilities:</p>
<ul>
<li><p><strong>Serverless Scaling</strong>: Automatic scaling to zero when no requests are being processed, and rapid scale-up based on incoming traffic patterns.</p></li>
<li><p><strong>Multi-Framework Support</strong>: Native support for TensorFlow, PyTorch, scikit-learn, XGBoost, and custom serving runtimes through standardized prediction protocols.</p></li>
<li><p><strong>Advanced Deployment Strategies</strong>: Built-in support for canary deployments, A/B testing, and blue-green deployments for safe model rollouts.</p></li>
<li><p><strong>Explainability Integration</strong>: Integration with explainability frameworks to provide model interpretability alongside predictions.</p></li>
</ul>
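<p>The scale-to-zero behavior can be pictured as a simple control rule. The sketch below is a simplified, hypothetical model of concurrency-based autoscaling; in serverless mode KServe delegates scaling to Knative, whose real autoscaler also averages over time windows and has a burst ("panic") mode. The function name and parameters are illustrative, not KServe configuration fields.</p>

```python
import math


def desired_replicas(in_flight_requests: int, target_concurrency: int,
                     min_replicas: int = 0, max_replicas: int = 10) -> int:
    """Simplified concurrency-based scaling rule (hypothetical, not KServe code).

    Each replica is assumed to handle about `target_concurrency` requests at
    once. With min_replicas=0, an idle service scales all the way to zero.
    """
    if target_concurrency <= 0:
        raise ValueError("target_concurrency must be positive")
    needed = math.ceil(in_flight_requests / target_concurrency)
    return max(min_replicas, min(max_replicas, needed))  # clamp to bounds


for load in (0, 1, 7, 25, 500):
    print(f"{load:>3} in-flight requests -> {desired_replicas(load, 5)} replicas")
```

<p>Setting a minimum of one replica trades idle cost for avoiding cold-start latency on the first request after a quiet period.</p>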
</section>
</section>
</section>
<section id="installation-and-setup" class="level2">
<h2 class="anchored" data-anchor-id="installation-and-setup" id="installation-and-setup">Installation and Setup</h2>
<section id="prerequisites" class="level3">
<h3 class="anchored" data-anchor-id="prerequisites" id="prerequisites">Prerequisites</h3>
<p>Before installing Kubeflow, ensure you have:</p>
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Important</span>Minimum Requirements
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
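<li><strong>Kubernetes Cluster</strong>: A version supported by the Kubeflow release you are installing (older guides cite 1.21 or later; recent releases require newer versions, so check the release notes)</li>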
<li><strong>Resources</strong>: Minimum 4 CPU cores and 16GB RAM for basic installations</li>
<li><strong>Storage</strong>: Persistent storage capabilities with dynamic provisioning</li>
<li><strong>Network</strong>: Proper ingress configuration for external access</li>
</ul>
</div>
</div>
</section>
<section id="installation-methods" class="level3">
<h3 class="anchored" data-anchor-id="installation-methods" id="installation-methods">Installation Methods</h3>
<section id="kubeflow-manifests" class="level4">
<h4 class="anchored" data-anchor-id="kubeflow-manifests">Kubeflow Manifests</h4>
<p>The most straightforward installation method uses Kubeflow manifests:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Clone the manifests repository</span></span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="fu">git</span> clone https://github.com/kubeflow/manifests.git</span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a><span class="bu">cd</span> manifests</span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Install Kubeflow components</span></span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a><span class="cf">while</span> <span class="ot">! </span><span class="ex">kustomize</span> build example <span class="kw">|</span> <span class="ex">kubectl</span> apply <span class="at">-f</span> <span class="at">-</span><span class="kw">;</span> <span class="cf">do</span> </span>
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a>  <span class="bu">echo</span> <span class="st">"Retrying to apply resources"</span></span>
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a>  <span class="fu">sleep</span> 10</span>
<span id="cb1-9"><a href="#cb1-9" aria-hidden="true" tabindex="-1"></a><span class="cf">done</span></span></code></pre></div></div>
<p>This method provides fine-grained control over component selection and configuration but requires manual management of dependencies and updates.</p>
</section>
<section id="distribution-specific-installations" class="level4">
<h4 class="anchored" data-anchor-id="distribution-specific-installations">Distribution-Specific Installations</h4>
<div class="tabset-margin-container"></div><div class="panel-tabset">
<ul class="nav nav-tabs" role="tablist"><li class="nav-item" role="presentation"><a class="nav-link active" id="tabset-2-1-tab" data-bs-toggle="tab" data-bs-target="#tabset-2-1" role="tab" aria-controls="tabset-2-1" aria-selected="true" href="">Google Cloud</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-2-2-tab" data-bs-toggle="tab" data-bs-target="#tabset-2-2" role="tab" aria-controls="tabset-2-2" aria-selected="false" href="">AWS</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-2-3-tab" data-bs-toggle="tab" data-bs-target="#tabset-2-3" role="tab" aria-controls="tabset-2-3" aria-selected="false" href="">Azure</a></li></ul>
<div class="tab-content">
<div id="tabset-2-1" class="tab-pane active" role="tabpanel" aria-labelledby="tabset-2-1-tab">
<p>Use Vertex AI Pipelines (the successor to AI Platform Pipelines), which runs pipelines written with the Kubeflow Pipelines SDK, or deploy Kubeflow on GKE with configurations optimized for Google Cloud services.</p>
</div>
<div id="tabset-2-2" class="tab-pane" role="tabpanel" aria-labelledby="tabset-2-2-tab">
<p>Leverage AWS-specific distributions like Kubeflow on Amazon EKS, which provides pre-configured integrations with AWS services like S3, IAM, and CloudWatch.</p>
</div>
<div id="tabset-2-3" class="tab-pane" role="tabpanel" aria-labelledby="tabset-2-3-tab">
<p>Use Azure Machine Learning or deploy Kubeflow on AKS with Azure-specific optimizations and service integrations.</p>
</div>
</div>
</div>
</section>
</section>
<section id="post-installation-configuration" class="level3">
<h3 class="anchored" data-anchor-id="post-installation-configuration" id="post-installation-configuration">Post-Installation Configuration</h3>
<p>After installation, configure essential settings:</p>
<ul>
<li><p><strong>Authentication</strong>: Set up appropriate authentication mechanisms, whether through Kubernetes RBAC, external identity providers like OIDC, or platform-specific authentication systems.</p></li>
<li><p><strong>Storage Classes</strong>: Configure storage classes for different workload types, ensuring appropriate performance characteristics for training jobs, notebooks, and pipeline artifacts.</p></li>
<li><p><strong>Resource Quotas</strong>: Establish resource quotas and limits to prevent resource contention and ensure fair resource allocation across users and teams.</p></li>
<li><p><strong>Monitoring</strong>: Deploy monitoring solutions like Prometheus and Grafana to track cluster health, resource utilization, and application performance.</p></li>
</ul>
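<p>Resource quotas and notebook requests are expressed as Kubernetes quantities such as <code>500m</code> CPU or <code>16Gi</code> memory. As a rough illustration of checking a request against what remains of a namespace quota, here is a small parser for the common suffixes. It is a simplified sketch (it ignores exponent notation such as <code>1e3</code> and other cases the real Kubernetes parser handles), and the quota figures are made up.</p>

```python
from decimal import Decimal

# Multipliers for a subset of Kubernetes quantity suffixes
SUFFIXES = {
    "Ki": Decimal(2) ** 10, "Mi": Decimal(2) ** 20,
    "Gi": Decimal(2) ** 30, "Ti": Decimal(2) ** 40,
    "m": Decimal("0.001"), "k": Decimal(10) ** 3,
    "M": Decimal(10) ** 6, "G": Decimal(10) ** 9, "T": Decimal(10) ** 12,
}


def parse_quantity(text: str) -> Decimal:
    """Convert a quantity such as '500m' or '16Gi' to a plain number."""
    text = text.strip()
    # Check two-character binary suffixes before one-character decimal ones
    for suffix in sorted(SUFFIXES, key=len, reverse=True):
        if text.endswith(suffix):
            return Decimal(text[: -len(suffix)]) * SUFFIXES[suffix]
    return Decimal(text)


def fits_quota(request: dict[str, str], quota: dict[str, str],
               used: dict[str, str]) -> bool:
    """True if every requested resource fits in what is left of the quota."""
    for resource, amount in request.items():
        limit = parse_quantity(quota[resource])
        already_used = parse_quantity(used.get(resource, "0"))
        if parse_quantity(amount) > limit - already_used:
            return False
    return True


quota = {"cpu": "8", "memory": "32Gi"}     # hypothetical namespace quota
used = {"cpu": "6500m", "memory": "20Gi"}  # hypothetical current usage
print(fits_quota({"cpu": "1", "memory": "8Gi"}, quota, used))
print(fits_quota({"cpu": "2", "memory": "8Gi"}, quota, used))
```

<p>In a real cluster the API server performs this admission check itself; the sketch only shows why a request for 2 CPUs is rejected when 1.5 CPUs remain.</p>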
</section>
</section>
<section id="building-ml-pipelines" class="level2">
<h2 class="anchored" data-anchor-id="building-ml-pipelines" id="building-ml-pipelines">Building ML Pipelines</h2>
<section id="pipeline-components" class="level3">
<h3 class="anchored" data-anchor-id="pipeline-components" id="pipeline-components">Pipeline Components</h3>
<p>Kubeflow Pipelines are built from reusable components, each encapsulating a specific ML task:</p>
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">graph LR
    A[Data Ingestion] --&gt; B[Data Preprocessing]
    B --&gt; C[Feature Engineering]
    C --&gt; D[Model Training]
    D --&gt; E[Model Evaluation]
    E --&gt; F[Model Deployment]
    F --&gt; G[Model Monitoring]
</pre>
</div>
<p></p></figure><p></p>
</div>
</div>
</div>
<ul>
<li><p><strong>Lightweight Components</strong>: Python functions that can be converted into pipeline components with minimal overhead, suitable for simple data processing tasks.</p></li>
<li><p><strong>Containerized Components</strong>: More complex components packaged as Docker containers, providing isolation and reproducibility for sophisticated ML operations.</p></li>
<li><p><strong>Pre-built Components</strong>: Community- and vendor-contributed components, such as those published in the Kubeflow Pipelines repository and Google Cloud Pipeline Components, cover common ML operations like data validation, feature engineering, and model evaluation.</p></li>
</ul>
</section>
<section id="pipeline-development-workflow" class="level3">
<h3 class="anchored" data-anchor-id="pipeline-development-workflow" id="pipeline-development-workflow">Pipeline Development Workflow</h3>
<ol type="1">
<li><p><strong>Design Phase</strong>: Define the overall workflow structure, identifying key stages like data ingestion, preprocessing, training, evaluation, and deployment.</p></li>
<li><p><strong>Component Development</strong>: Create or select appropriate components for each pipeline stage, ensuring proper input/output specifications and parameter definitions.</p></li>
<li><p><strong>Pipeline Assembly</strong>: Use the Kubeflow Pipelines SDK to connect components, define data flow, and specify execution dependencies.</p></li>
<li><p><strong>Testing and Validation</strong>: Test pipeline components individually and as complete workflows using smaller datasets before production deployment.</p></li>
</ol>
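<p>The assembly and testing steps can be rehearsed without a cluster. The following is a framework-free sketch, not the Kubeflow Pipelines SDK: each hypothetical component is a plain Python function, each step declares which upstream step feeds each of its parameters, and the standard-library <code>graphlib</code> runs the steps in dependency order on a tiny in-memory dataset, in the spirit of testing on small data first.</p>

```python
from graphlib import TopologicalSorter
from typing import Any, Callable

# step name -> (function, {parameter name: upstream step that provides it})
Pipeline = dict[str, tuple[Callable[..., Any], dict[str, str]]]


# Hypothetical components: each takes upstream outputs as keyword arguments
def ingest() -> list[float]:
    return [1.0, 2.0, 3.0, 4.0, 100.0]  # tiny sample dataset


def clean(raw: list[float]) -> list[float]:
    return [x for x in raw if x < 50]  # toy outlier filter


def train(data: list[float]) -> float:
    return sum(data) / len(data)  # "model" = mean of the data


def evaluate(model: float, data: list[float]) -> float:
    return sum(abs(x - model) for x in data) / len(data)  # mean absolute error


PIPELINE: Pipeline = {
    "ingest": (ingest, {}),
    "clean": (clean, {"raw": "ingest"}),
    "train": (train, {"data": "clean"}),
    "evaluate": (evaluate, {"model": "train", "data": "clean"}),
}


def run_pipeline(pipeline: Pipeline) -> dict[str, Any]:
    # Build the dependency graph, then execute steps in topological order
    graph = {step: set(wiring.values()) for step, (_, wiring) in pipeline.items()}
    outputs: dict[str, Any] = {}
    for step in TopologicalSorter(graph).static_order():
        func, wiring = pipeline[step]
        outputs[step] = func(**{param: outputs[src] for param, src in wiring.items()})
    return outputs


results = run_pipeline(PIPELINE)
print(results["train"], results["evaluate"])
```

<p>In Kubeflow Pipelines the same wiring is expressed with the SDK, and each step runs in its own container rather than in one Python process.</p>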
</section>
<section id="best-practices-for-pipeline-development" class="level3">
<h3 class="anchored" data-anchor-id="best-practices-for-pipeline-development" id="best-practices-for-pipeline-development">Best Practices for Pipeline Development</h3>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Development Best Practices
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li><strong>Modularity</strong>: Design components to be as modular and reusable as possible</li>
<li><strong>Parameterization</strong>: Make pipelines highly parameterizable</li>
<li><strong>Error Handling</strong>: Implement comprehensive error handling and logging</li>
<li><strong>Version Control</strong>: Maintain proper version control for pipeline definitions</li>
</ul>
</div>
</div>
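<p>For the error-handling point, one common pattern is to retry transient failures with exponential backoff and then fail loudly with a clear log message. Kubeflow Pipelines can also retry whole tasks; the sketch below shows the same idea inside a component's own code. The <code>retry</code> decorator, its delays, and the choice of which exceptions count as transient are assumptions for illustration.</p>

```python
import functools
import logging
import time

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("component")


def retry(max_attempts: int = 3, base_delay: float = 0.5,
          transient: tuple[type[BaseException], ...] = (ConnectionError, TimeoutError)):
    """Hypothetical helper: retry on transient errors with exponential backoff."""
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(1, max_attempts + 1):
                try:
                    return func(*args, **kwargs)
                except transient as exc:
                    if attempt == max_attempts:
                        log.error("%s failed after %d attempts", func.__name__, attempt)
                        raise
                    delay = base_delay * 2 ** (attempt - 1)  # 1x, 2x, 4x, ...
                    log.warning("%s attempt %d failed (%s); retrying in %.2fs",
                                func.__name__, attempt, exc, delay)
                    time.sleep(delay)
        return wrapper
    return decorator


calls = {"n": 0}


@retry(max_attempts=3, base_delay=0.01)
def download_dataset() -> str:
    """Hypothetical flaky step that succeeds on its third attempt."""
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("storage temporarily unavailable")
    return "dataset.csv"


print(download_dataset(), "after", calls["n"], "attempts")
```

<p>Non-transient errors (for example a <code>ValueError</code> from bad input) are deliberately not retried, so genuine bugs surface immediately.</p>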
</section>
</section>
<section id="model-training-and-experimentation" class="level2">
<h2 class="anchored" data-anchor-id="model-training-and-experimentation" id="model-training-and-experimentation">Model Training and Experimentation</h2>
<section id="distributed-training" class="level3">
<h3 class="anchored" data-anchor-id="distributed-training" id="distributed-training">Distributed Training</h3>
<p>Kubeflow supports distributed training across multiple frameworks:</p>
<ul>
<li><p><strong>TensorFlow Training</strong>: The TFJob operator enables distributed TensorFlow training with parameter servers or all-reduce strategies, automatically handling worker coordination and failure recovery.</p></li>
<li><p><strong>PyTorch Training</strong>: PyTorchJob operator supports distributed PyTorch training using various backends like NCCL and Gloo, with automatic scaling and fault tolerance.</p></li>
<li><p><strong>MPI Training</strong>: For frameworks that support MPI-based distributed training, the MPIJob operator provides seamless integration with message-passing interfaces.</p></li>
</ul>
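<p>Part of the worker coordination the TFJob operator handles is telling each replica where it sits in the cluster. It does this through the <code>TF_CONFIG</code> environment variable, a JSON document in the format TensorFlow's distribution strategies read. The sketch below parses it with the standard library only; the helper name and the host names in the example are made up.</p>

```python
import json
import os
from typing import Mapping, Optional


def describe_replica(env: Optional[Mapping[str, str]] = None) -> dict:
    """Summarize this replica's role from TF_CONFIG (hypothetical helper)."""
    env = os.environ if env is None else env
    config = json.loads(env.get("TF_CONFIG", "{}"))
    cluster = config.get("cluster", {})
    task = config.get("task", {})
    return {
        "role": task.get("type", "local"),  # e.g. "chief", "worker", "ps"
        "index": task.get("index", 0),
        "num_workers": len(cluster.get("worker", [])),
        "num_ps": len(cluster.get("ps", [])),
        # Replicas of type "ps" indicate parameter-server training
        "uses_parameter_servers": bool(cluster.get("ps")),
    }


example_env = {
    "TF_CONFIG": json.dumps({
        "cluster": {
            "worker": ["trainer-worker-0:2222", "trainer-worker-1:2222"],
            "ps": ["trainer-ps-0:2222"],
        },
        "task": {"type": "worker", "index": 1},
    })
}
print(describe_replica(example_env))
```

<p>Inside a real TFJob pod you would call <code>describe_replica()</code> with no argument so it reads the variable the operator injected.</p>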
</section>
<section id="experiment-management" class="level3">
<h3 class="anchored" data-anchor-id="experiment-management" id="experiment-management">Experiment Management</h3>
<ul>
<li><p><strong>Experiment Tracking</strong>: Kubeflow Pipelines automatically tracks experiment metadata, including parameters, metrics, and artifacts, enabling comprehensive experiment comparison and analysis.</p></li>
<li><p><strong>Hyperparameter Tuning</strong>: Katib integration allows for sophisticated hyperparameter optimization experiments with support for various search algorithms and early stopping strategies.</p></li>
<li><p><strong>Model Versioning</strong>: Built-in model versioning capabilities track model evolution over time, supporting model lineage and reproducibility requirements.</p></li>
</ul>
</section>
<section id="resource-optimization" class="level3">
<h3 class="anchored" data-anchor-id="resource-optimization" id="resource-optimization">Resource Optimization</h3>
<ul>
<li><p><strong>Auto-scaling</strong>: Dynamic resource allocation based on training workload requirements, optimizing cost and performance.</p></li>
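<li><p><strong>GPU Scheduling</strong>: GPUs are requested as Kubernetes extended resources (for example <code>nvidia.com/gpu</code>), and where the cluster supports it (for example NVIDIA MIG or time-slicing), GPUs can be shared to improve utilization of expensive hardware.</p></li>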
<li><p><strong>Spot Instance Support</strong>: Integration with cloud provider spot instances for cost-effective training of non-critical workloads.</p></li>
</ul>
</section>
</section>
<section id="model-serving-and-deployment" class="level2">
<h2 class="anchored" data-anchor-id="model-serving-and-deployment" id="model-serving-and-deployment">Model Serving and Deployment</h2>
<section id="serving-strategies" class="level3">
<h3 class="anchored" data-anchor-id="serving-strategies" id="serving-strategies">Serving Strategies</h3>
<ul>
<li><p><strong>Real-time Serving</strong>: Low-latency serving for applications requiring immediate responses, with support for high-throughput scenarios.</p></li>
<li><p><strong>Batch Prediction</strong>: Efficient batch processing capabilities for scenarios where predictions can be computed offline or in batches.</p></li>
<li><p><strong>Edge Deployment</strong>: Support for deploying models to edge devices and environments with limited resources.</p></li>
</ul>
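<p>The difference between real-time and batch serving is largely about how requests are grouped. A minimal sketch of offline batch prediction, assuming a hypothetical <code>predict_batch</code> function that scores a list of inputs at once, streams the dataset in fixed-size chunks so memory stays bounded and each model call is amortized over many rows.</p>

```python
from itertools import islice
from typing import Callable, Iterable, Iterator, List


def batched(items: Iterable[float], size: int) -> Iterator[List[float]]:
    """Yield consecutive chunks of at most `size` items."""
    it = iter(items)
    while chunk := list(islice(it, size)):
        yield chunk


def batch_predict(rows: Iterable[float],
                  predict_batch: Callable[[List[float]], List[float]],
                  batch_size: int = 4) -> Iterator[float]:
    for chunk in batched(rows, batch_size):
        # One model call per chunk instead of one call per row
        yield from predict_batch(chunk)


def predict_batch(xs: List[float]) -> List[float]:
    """Hypothetical model: y = 2x + 1."""
    return [2 * x + 1 for x in xs]


print(list(batch_predict(range(10), predict_batch)))
```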
</section>
<section id="deployment-patterns" class="level3">
<h3 class="anchored" data-anchor-id="deployment-patterns" id="deployment-patterns">Deployment Patterns</h3>
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">graph LR
    A[Model Registry] --&gt; B[Canary Deployment]
    A --&gt; C[A/B Testing]
    A --&gt; D[Shadow Deployment]
    B --&gt; E[Production Traffic]
    C --&gt; E
    D --&gt; F[Performance Evaluation]
</pre>
</div>
<p></p></figure><p></p>
</div>
</div>
</div>
<ul>
<li><p><strong>Canary Deployments</strong>: Gradual rollout of new model versions to a subset of traffic, enabling safe deployment with minimal risk.</p></li>
<li><p><strong>A/B Testing</strong>: Side-by-side comparison of different model versions to evaluate performance improvements and business impact.</p></li>
<li><p><strong>Shadow Deployment</strong>: Mirror production requests to a new model running alongside the current one, logging its predictions for evaluation while users continue to receive only the current model’s responses.</p></li>
</ul>
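<p>The three patterns differ mainly in how each request is routed. The sketch below is a hypothetical router, not KServe code: a canary sends a fixed percentage of traffic to the new version (KServe exposes this idea through the <code>canaryTrafficPercent</code> field of an InferenceService), hashing a request or user ID keeps each caller on a stable version as A/B testing requires, and shadow mode calls the candidate too but returns only the current model's answer. Both models are made-up stand-ins.</p>

```python
import hashlib
from typing import Callable, List, Optional

Model = Callable[[float], float]


def bucket(request_id: str) -> int:
    """Map a request or user ID to a stable bucket in the range 0..99."""
    digest = hashlib.sha256(request_id.encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "big") % 100


def route(request_id: str, x: float, current: Model, candidate: Model,
          canary_percent: int = 10, shadow: bool = False,
          shadow_log: Optional[List[float]] = None) -> float:
    if shadow:
        # Shadow: the candidate sees the request, but its answer is only logged
        if shadow_log is not None:
            shadow_log.append(candidate(x))
        return current(x)
    # Canary / A/B split: a fixed share of buckets is served by the candidate,
    # and hashing keeps each ID on the same version across requests
    if bucket(request_id) < canary_percent:
        return candidate(x)
    return current(x)


def current_model(x: float) -> float:  # hypothetical model in production
    return 2.0 * x


def candidate_model(x: float) -> float:  # hypothetical new version
    return 2.0 * x + 1.0


ids = [f"user-{i}" for i in range(1000)]
served_by_candidate = sum(
    route(uid, 1.0, current_model, candidate_model) == 3.0 for uid in ids
)
print(f"{served_by_candidate} of {len(ids)} requests hit the candidate")
```

<p>Raising <code>canary_percent</code> step by step while watching error rates and latency is what turns this split into a gradual rollout.</p>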
</section>
<section id="sec-model-monitoring" class="level3">
<h3 class="anchored" data-anchor-id="sec-model-monitoring" id="sec-model-monitoring">Model Monitoring</h3>
<ul>
<li><p><strong>Performance Monitoring</strong>: Continuous tracking of model performance metrics like accuracy, latency, and throughput.</p></li>
<li><p><strong>Data Drift Detection</strong>: Monitoring for changes in input data distribution that might affect model performance.</p></li>
<li><p><strong>Model Explainability</strong>: Integration with explainability tools to provide insights into model predictions and decision-making processes.</p></li>
</ul>
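<p>One widely used way to quantify the data drift described above is the Population Stability Index (PSI), which compares the binned distribution of a feature in production against a reference sample from training. The sketch below is a plain-Python version; the bin count, the small epsilon for empty bins, and the commonly cited rules of thumb (below about 0.1 stable, above about 0.25 significant drift) are conventions, not Kubeflow settings, and the sample data is synthetic.</p>

```python
import math
from typing import List, Sequence


def psi(expected: Sequence[float], actual: Sequence[float], bins: int = 10,
        eps: float = 1e-6) -> float:
    """Population Stability Index between a reference and a current sample."""
    lo, hi = min(expected), max(expected)
    width = (hi - lo) / bins or 1.0  # guard against a constant feature

    def proportions(values: Sequence[float]) -> List[float]:
        counts = [0] * bins
        for v in values:
            # Bins come from the reference range; outliers go to the edge bins
            idx = min(max(int((v - lo) / width), 0), bins - 1)
            counts[idx] += 1
        # Floor empty bins at eps so the logarithm stays finite
        return [max(c / len(values), eps) for c in counts]

    p_exp, p_act = proportions(expected), proportions(actual)
    return sum((a - e) * math.log(a / e) for e, a in zip(p_exp, p_act))


reference = [i / 100 for i in range(1000)]      # synthetic training-time values
shifted = [i / 100 + 3.0 for i in range(1000)]  # synthetic drifted production values
print(round(psi(reference, reference), 4), round(psi(reference, shifted), 4))
```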
</section>
</section>
<section id="integration-with-ml-ecosystem" class="level2">
<h2 class="anchored" data-anchor-id="integration-with-ml-ecosystem" id="integration-with-ml-ecosystem">Integration with ML Ecosystem</h2>
<section id="data-integration" class="level3">
<h3 class="anchored" data-anchor-id="data-integration" id="data-integration">Data Integration</h3>
<ul>
<li><p><strong>Data Pipeline Integration</strong>: Seamless integration with data pipeline tools like Apache Airflow, allowing for end-to-end data-to-model workflows.</p></li>
<li><p><strong>Feature Store Integration</strong>: Support for feature stores like Feast, enabling consistent feature engineering across training and serving environments.</p></li>
<li><p><strong>Data Versioning</strong>: Integration with data versioning tools like DVC or Pachyderm for reproducible data management.</p></li>
</ul>
</section>
<section id="mlops-integration" class="level3">
<h3 class="anchored" data-anchor-id="mlops-integration" id="mlops-integration">MLOps Integration</h3>
<ul>
<li><p><strong>CI/CD Integration</strong>: Support for continuous integration and deployment pipelines, enabling automated model training, testing, and deployment.</p></li>
<li><p><strong>Model Registry</strong>: Integration with model registries like MLflow for centralized model management and lifecycle tracking.</p></li>
<li><p><strong>Monitoring and Observability</strong>: Integration with observability platforms for comprehensive monitoring of ML system health and performance.</p></li>
</ul>
</section>
<section id="cloud-provider-integration" class="level3">
<h3 class="anchored" data-anchor-id="cloud-provider-integration" id="cloud-provider-integration">Cloud Provider Integration</h3>
<div class="tabset-margin-container"></div><div class="panel-tabset">
<ul class="nav nav-tabs" role="tablist"><li class="nav-item" role="presentation"><a class="nav-link active" id="tabset-3-1-tab" data-bs-toggle="tab" data-bs-target="#tabset-3-1" role="tab" aria-controls="tabset-3-1" aria-selected="true" href="">AWS Integration</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-3-2-tab" data-bs-toggle="tab" data-bs-target="#tabset-3-2" role="tab" aria-controls="tabset-3-2" aria-selected="false" href="">Google Cloud Integration</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-3-3-tab" data-bs-toggle="tab" data-bs-target="#tabset-3-3" role="tab" aria-controls="tabset-3-3" aria-selected="false" href="">Azure Integration</a></li></ul>
<div class="tab-content">
<div id="tabset-3-1" class="tab-pane active" role="tabpanel" aria-labelledby="tabset-3-1-tab">
<p>Native support for AWS services like S3 for storage, IAM for authentication, and CloudWatch for monitoring.</p>
</div>
<div id="tabset-3-2" class="tab-pane" role="tabpanel" aria-labelledby="tabset-3-2-tab">
<p>Deep integration with Google Cloud services including BigQuery, Cloud Storage, and Vertex AI (formerly AI Platform).</p>
</div>
<div id="tabset-3-3" class="tab-pane" role="tabpanel" aria-labelledby="tabset-3-3-tab">
<p>Support for Azure services like Azure Blob Storage, Microsoft Entra ID (formerly Azure Active Directory), and Azure Monitor.</p>
</div>
</div>
</div>
</section>
</section>
<section id="best-practices-and-considerations" class="level2">
<h2 class="anchored" data-anchor-id="best-practices-and-considerations" id="best-practices-and-considerations">Best Practices and Considerations</h2>
<section id="security-best-practices" class="level3">
<h3 class="anchored" data-anchor-id="security-best-practices" id="security-best-practices">Security Best Practices</h3>
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Warning</span>Security Considerations
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
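<li><strong>Authentication and Authorization</strong>: Require authenticated access to the Kubeflow dashboard and APIs, and use Kubernetes role-based access control (RBAC) to limit what each user and service account can do.</li>
<li><strong>Network Security</strong>: Use Kubernetes network policies and a service mesh (Kubeflow’s reference manifests use Istio) to restrict and encrypt traffic between services.</li>
<li><strong>Secret Management</strong>: Keep credentials in Kubernetes Secrets or an external secrets manager rather than in container images, notebooks, or pipeline code.</li>
<li><strong>Container Security</strong>: Scan container images for known vulnerabilities on a regular schedule, ideally as part of the pipeline that builds them.</li>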
</ul>
</div>
</div>
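<p>A common way to apply the secret-management point is to reference a Kubernetes Secret from the container spec (for example with <code>valueFrom.secretKeyRef</code>) so the value reaches the process as an environment variable. The following is a minimal sketch of the consuming side, assuming that setup; <code>read_secret_env</code>, <code>redact</code>, and the <code>API_TOKEN</code> variable name are hypothetical.</p>

```python
import os


def read_secret_env(name, environ=None):
    """Read a required credential from the environment (hypothetical helper).

    Assumes the Pod spec maps a Kubernetes Secret key to this variable,
    e.g. with env[].valueFrom.secretKeyRef.
    """
    env = os.environ if environ is None else environ
    value = env.get(name)
    if not value:
        # Fail fast; the message names the variable but never includes a value.
        raise RuntimeError(f"Required secret {name!r} is not set")
    return value


def redact(value, visible=4):
    """Mask all but the last `visible` characters so the value is safe to log."""
    if len(value) > visible:
        return "*" * (len(value) - visible) + value[-visible:]
    return "*" * len(value)
```

<p>In a training script this might look like <code>token = read_secret_env("API_TOKEN")</code>, logging only <code>redact(token)</code> so operators can tell which credential is in use without exposing it.</p>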
</section>
<section id="performance-optimization" class="level3">
<h3 class="anchored" data-anchor-id="performance-optimization" id="performance-optimization">Performance Optimization</h3>
<ul>
<li><p><strong>Resource Planning</strong>: Set CPU, memory, and GPU requests and limits for each workload based on its observed characteristics and performance requirements.</p></li>
<li><p><strong>Storage Optimization</strong>: Choose appropriate storage solutions based on access patterns, performance requirements, and cost considerations.</p></li>
<li><p><strong>Network Optimization</strong>: Optimize network configuration for data-intensive workloads, particularly for distributed training scenarios.</p></li>
<li><p><strong>Caching Strategies</strong>: Implement appropriate caching strategies for frequently accessed data and model artifacts.</p></li>
</ul>
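<p>For resource planning, the chosen values end up in each container’s <code>resources</code> block. The sketch below builds that block as a Python dict, assuming GPUs are exposed through the NVIDIA device plugin as <code>nvidia.com/gpu</code>; the helper name <code>container_resources</code> and any sizes you pass are illustrative, not recommendations.</p>

```python
def container_resources(cpu, memory, gpus=0, cpu_limit=None, memory_limit=None):
    """Build a Kubernetes container `resources` block (hypothetical helper).

    cpu/memory use Kubernetes quantity strings such as "500m" or "16Gi".
    Limits default to the requests when not given.
    """
    requests = {"cpu": str(cpu), "memory": memory}
    limits = {"cpu": str(cpu_limit or cpu), "memory": memory_limit or memory}
    if gpus:
        # GPUs are an extended resource and cannot be overcommitted: when both
        # are specified, requests must equal limits, so set the same count.
        requests["nvidia.com/gpu"] = str(gpus)
        limits["nvidia.com/gpu"] = str(gpus)
    return {"requests": requests, "limits": limits}
```

<p>The scheduler places Pods according to their requests, while limits are enforced at runtime (CPU is throttled, and exceeding the memory limit gets the container killed), so requests close to observed usage keep nodes well packed without starving jobs.</p>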
</section>
<section id="operational-excellence" class="level3">
<h3 class="anchored" data-anchor-id="operational-excellence" id="operational-excellence">Operational Excellence</h3>
<ul>
<li><p><strong>Monitoring and Alerting</strong>: Comprehensive monitoring of system health, resource utilization, and application performance with appropriate alerting mechanisms.</p></li>
<li><p><strong>Backup and Recovery</strong>: Regular backups of critical data and configurations with tested recovery procedures.</p></li>
<li><p><strong>Documentation</strong>: Maintain comprehensive documentation of system architecture, operational procedures, and troubleshooting guides.</p></li>
<li><p><strong>Training and Support</strong>: Ensure team members are properly trained on Kubeflow operations and best practices.</p></li>
</ul>
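<p>Monitoring is easier when training code emits machine-readable logs. Kubernetes captures whatever a container writes to stdout, so one low-effort option is to print one JSON object per line and let the cluster’s log collector parse it. This is a minimal sketch; <code>log_metrics</code> and its field names are assumptions, not a Kubeflow convention.</p>

```python
import json
import sys
import time


def log_metrics(step, stream=None, **metrics):
    """Print one JSON object per line (hypothetical helper) and return the line."""
    record = {"ts": round(time.time(), 3), "step": step, **metrics}
    line = json.dumps(record, sort_keys=True)
    # flush=True so the line reaches the container log promptly, not at exit.
    print(line, file=stream if stream is not None else sys.stdout, flush=True)
    return line
```

<p>Inside a training loop this could be called as <code>log_metrics(epoch, loss=avg_loss, accuracy=accuracy)</code>, producing lines that a log pipeline can filter and chart without custom parsing rules.</p>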
</section>
</section>
<section id="use-cases-and-success-stories" class="level2">
<h2 class="anchored" data-anchor-id="use-cases-and-success-stories" id="use-cases-and-success-stories">Use Cases and Success Stories</h2>
<section id="enterprise-ml-platforms" class="level3">
<h3 class="anchored" data-anchor-id="enterprise-ml-platforms" id="enterprise-ml-platforms">Enterprise ML Platforms</h3>
<p>Large enterprises use Kubeflow to standardize their ML infrastructure across multiple teams and projects, providing consistent tooling and workflows while maintaining flexibility for different use cases.</p>
</section>
<section id="research-organizations" class="level3">
<h3 class="anchored" data-anchor-id="research-organizations" id="research-organizations">Research Organizations</h3>
<p>Academic and research institutions leverage Kubeflow’s flexibility and scalability to support diverse research projects with varying computational requirements and experimental approaches.</p>
</section>
<section id="startups-and-smes" class="level3">
<h3 class="anchored" data-anchor-id="startups-and-smes" id="startups-and-smes">Startups and SMEs</h3>
<p>Smaller organizations use Kubeflow to access enterprise-grade ML infrastructure without the overhead of building and maintaining custom solutions, accelerating their time to market.</p>
</section>
<section id="industry-specific-applications" class="level3">
<h3 class="anchored" data-anchor-id="industry-specific-applications" id="industry-specific-applications">Industry-Specific Applications</h3>
<div class="tabset-margin-container"></div><div class="panel-tabset">
<ul class="nav nav-tabs" role="tablist"><li class="nav-item" role="presentation"><a class="nav-link active" id="tabset-4-1-tab" data-bs-toggle="tab" data-bs-target="#tabset-4-1" role="tab" aria-controls="tabset-4-1" aria-selected="true" href="">Financial Services</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-4-2-tab" data-bs-toggle="tab" data-bs-target="#tabset-4-2" role="tab" aria-controls="tabset-4-2" aria-selected="false" href="">Healthcare</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-4-3-tab" data-bs-toggle="tab" data-bs-target="#tabset-4-3" role="tab" aria-controls="tabset-4-3" aria-selected="false" href="">Retail and E-commerce</a></li></ul>
<div class="tab-content">
<div id="tabset-4-1" class="tab-pane active" role="tabpanel" aria-labelledby="tabset-4-1-tab">
<p>Risk modeling, fraud detection, and algorithmic trading applications benefit from Kubeflow’s scalability and compliance capabilities.</p>
</div>
<div id="tabset-4-2" class="tab-pane" role="tabpanel" aria-labelledby="tabset-4-2-tab">
<p>Medical imaging, drug discovery, and clinical decision support systems leverage Kubeflow’s robust pipeline management and model serving capabilities.</p>
</div>
<div id="tabset-4-3" class="tab-pane" role="tabpanel" aria-labelledby="tabset-4-3-tab">
<p>Recommendation systems, demand forecasting, and personalization engines use Kubeflow’s ability to handle large-scale, real-time ML workloads.</p>
</div>
</div>
</div>
</section>
</section>
<section id="future-directions-and-roadmap" class="level2">
<h2 class="anchored" data-anchor-id="future-directions-and-roadmap" id="future-directions-and-roadmap">Future Directions and Roadmap</h2>
<section id="emerging-technologies" class="level3">
<h3 class="anchored" data-anchor-id="emerging-technologies" id="emerging-technologies">Emerging Technologies</h3>
<ul>
<li><p><strong>AutoML Integration</strong>: Enhanced integration with automated machine learning tools and techniques for democratizing ML development.</p></li>
<li><p><strong>Edge Computing</strong>: Improved support for edge deployment scenarios with optimized resource utilization and offline capabilities.</p></li>
<li><p><strong>Federated Learning</strong>: Native support for federated learning scenarios where data cannot be centralized due to privacy or regulatory constraints.</p></li>
</ul>
</section>
<section id="community-development" class="level3">
<h3 class="anchored" data-anchor-id="community-development" id="community-development">Community Development</h3>
<ul>
<li><p><strong>Component Ecosystem</strong>: Continued growth of the component ecosystem with contributions from the broader ML community.</p></li>
<li><p><strong>Integration Partnerships</strong>: Expanding partnerships with cloud providers, ML tool vendors, and open-source projects to enhance the platform’s capabilities.</p></li>
<li><p><strong>Standards Adoption</strong>: Participation in industry standards development to ensure compatibility and interoperability with other ML platforms and tools.</p></li>
</ul>
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Key Takeaways
</div>
</div>
<div class="callout-body-container callout-body">
<p>Kubeflow represents a significant advancement in making machine learning workflows more scalable, reproducible, and manageable. Its modular architecture and extensibility make it suitable for organizations of all sizes.</p>
</div>
</div>
<p>By leveraging Kubernetes’ container orchestration capabilities, Kubeflow provides a comprehensive platform that addresses the full spectrum of ML lifecycle management needs.</p>
<p>The platform’s strength lies in its modularity and extensibility, allowing organizations to adopt components incrementally based on their specific requirements and maturity levels. Whether you’re a startup looking to establish ML infrastructure or an enterprise seeking to standardize ML operations across multiple teams, Kubeflow provides the foundation for building robust, scalable ML systems.</p>
<p>As the machine learning landscape continues to evolve, Kubeflow’s active community and vendor-neutral approach position it well to adapt to emerging technologies and methodologies. Organizations investing in Kubeflow today are building on a platform designed to grow with their ML maturity and requirements, providing a solid foundation for long-term ML success.</p>
<p>The key to successful Kubeflow adoption lies in understanding your organization’s specific requirements, starting with pilot projects to build expertise, and gradually expanding usage as teams become more comfortable with the platform. With proper planning and implementation, Kubeflow can significantly accelerate your organization’s ML capabilities while maintaining the operational excellence required for production ML systems.</p>


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Kubeflow Deep Learning Guide with PyTorch]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/deployment/kubeflow-pytorch/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/deployment/kubeflow-pytorch/</guid>
      <pubDate>Sat, 31 May 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>mlops</category>
      <category>intermediate</category>
      <content:encoded><![CDATA[






<section id="kubeflow-deep-learning-guide-with-pytorch" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/deployment/kubeflow-pytorch/kupyt.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>Kubeflow is a machine learning toolkit for Kubernetes that makes deployments of ML workflows on Kubernetes simple, portable, and scalable. This guide focuses on using Kubeflow with PyTorch for deep learning tasks.</p>
<section id="key-kubeflow-components-for-deep-learning" class="level3">
<h3 class="anchored" data-anchor-id="key-kubeflow-components-for-deep-learning" id="key-kubeflow-components-for-deep-learning">Key Kubeflow Components for Deep Learning</h3>
<ul>
<li><strong>Training Operator</strong>: For distributed training jobs</li>
<li><strong>Katib</strong>: For hyperparameter tuning and neural architecture search</li>
<li><strong>KServe</strong>: For model serving and inference</li>
<li><strong>Pipelines</strong>: For ML workflow orchestration</li>
<li><strong>Notebooks</strong>: For interactive development</li>
</ul>
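<p>The Training Operator runs distributed PyTorch training through a <code>PyTorchJob</code> custom resource (<code>apiVersion: kubeflow.org/v1</code>) whose <code>pytorchReplicaSpecs</code> hold <code>Master</code> and <code>Worker</code> replicas. As a minimal sketch, assuming the v1 Training Operator, such a manifest can be assembled as a plain dict; the helper <code>pytorch_job</code> and the image name below are hypothetical.</p>

```python
import json


def pytorch_job(name, namespace, image, workers=1, gpus_per_replica=0, args=None):
    """Build a PyTorchJob manifest for the v1 Training Operator (hypothetical helper)."""

    def replica(count):
        container = {
            # The Training Operator expects the training container to be named "pytorch".
            "name": "pytorch",
            "image": image,
            "args": list(args or []),
        }
        if gpus_per_replica:
            container["resources"] = {"limits": {"nvidia.com/gpu": gpus_per_replica}}
        return {
            "replicas": count,
            "restartPolicy": "OnFailure",
            "template": {"spec": {"containers": [container]}},
        }

    replica_specs = {"Master": replica(1)}
    if workers:
        replica_specs["Worker"] = replica(workers)
    return {
        "apiVersion": "kubeflow.org/v1",
        "kind": "PyTorchJob",
        "metadata": {"name": name, "namespace": namespace},
        "spec": {"pytorchReplicaSpecs": replica_specs},
    }


job = pytorch_job(
    "mnist-ddp", "pytorch-training", "registry.example.com/mnist:latest",
    workers=2, args=["--epochs", "5"],
)
print(json.dumps(job, indent=2))
```

<p>Because <code>kubectl</code> reads JSON as well as YAML, the printed output can be saved to a file and submitted with <code>kubectl apply -f</code>.</p>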
</section>
</section>
<section id="prerequisites" class="level2">
<h2 class="anchored" data-anchor-id="prerequisites" id="prerequisites">Prerequisites</h2>
<p>Before starting, ensure you have:</p>
<ul>
<li>A Kubernetes cluster with Kubeflow installed</li>
<li><code>kubectl</code> configured to access your cluster</li>
<li>Docker for building container images, plus a container registry your cluster can pull from</li>
<li>Basic knowledge of PyTorch and Kubernetes</li>
</ul>
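<p>A quick pre-flight check can confirm that the command-line tools above are installed and that <code>kubectl</code> has an active context. The sketch below uses only the Python standard library; the function names are hypothetical, and it verifies only that the tools are present, not that Kubeflow itself is healthy.</p>

```python
import shutil
import subprocess


def missing_tools(tools=("kubectl", "docker")):
    """Return the tools from `tools` that are not found on PATH."""
    return [tool for tool in tools if shutil.which(tool) is None]


def kubectl_context():
    """Return the active kubeconfig context, or None if kubectl is absent or unconfigured."""
    if shutil.which("kubectl") is None:
        return None
    result = subprocess.run(
        ["kubectl", "config", "current-context"],
        capture_output=True, text=True, check=False,
    )
    if result.returncode != 0:
        return None
    return result.stdout.strip() or None
```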
</section>
<section id="setting-up-your-environment" class="level2">
<h2 class="anchored" data-anchor-id="setting-up-your-environment" id="setting-up-your-environment">Setting Up Your Environment</h2>
<section id="create-a-namespace" class="level3">
<h3 class="anchored" data-anchor-id="create-a-namespace" id="create-a-namespace">1. Create a Namespace</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="fu">apiVersion</span><span class="kw">:</span><span class="at"> v1</span></span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="fu">kind</span><span class="kw">:</span><span class="at"> Namespace</span></span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a><span class="fu">metadata</span><span class="kw">:</span></span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">name</span><span class="kw">:</span><span class="at"> pytorch-training</span></span></code></pre></div></div>
</section>
<section id="base-docker-image-for-pytorch" class="level3">
<h3 class="anchored" data-anchor-id="base-docker-image-for-pytorch" id="base-docker-image-for-pytorch">2. Base Docker Image for PyTorch</h3>
<p>Create a <code>Dockerfile</code> for your PyTorch environment:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb2"><pre class="sourceCode dockerfile code-with-copy"><code class="sourceCode dockerfile"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="kw">FROM</span> pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime</span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a><span class="kw">WORKDIR</span> /app</span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Install additional dependencies</span></span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a><span class="kw">RUN</span> <span class="ex">pip</span> install <span class="at">--no-cache-dir</span> <span class="dt">\</span></span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a>    torchvision <span class="dt">\</span></span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a>    tensorboard <span class="dt">\</span></span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a>    scikit-learn <span class="dt">\</span></span>
<span id="cb2-10"><a href="#cb2-10" aria-hidden="true" tabindex="-1"></a>    pandas <span class="dt">\</span></span>
<span id="cb2-11"><a href="#cb2-11" aria-hidden="true" tabindex="-1"></a>    numpy <span class="dt">\</span></span>
<span id="cb2-12"><a href="#cb2-12" aria-hidden="true" tabindex="-1"></a>    matplotlib <span class="dt">\</span></span>
<span id="cb2-13"><a href="#cb2-13" aria-hidden="true" tabindex="-1"></a>    seaborn</span>
<span id="cb2-14"><a href="#cb2-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-15"><a href="#cb2-15" aria-hidden="true" tabindex="-1"></a><span class="co"># Copy your training code</span></span>
<span id="cb2-16"><a href="#cb2-16" aria-hidden="true" tabindex="-1"></a><span class="kw">COPY</span> . /app/</span>
<span id="cb2-17"><a href="#cb2-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-18"><a href="#cb2-18" aria-hidden="true" tabindex="-1"></a><span class="co"># Set the default command</span></span>
<span id="cb2-19"><a href="#cb2-19" aria-hidden="true" tabindex="-1"></a><span class="kw">CMD</span> [<span class="st">"python"</span>, <span class="st">"train.py"</span>]</span></code></pre></div></div>
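<p>After building the image, it can be worth confirming that the extra packages installed correctly before launching a long job. A minimal smoke test, assuming you add it to the image alongside <code>train.py</code> (the module contents and function names here are hypothetical), checks each import name with <code>importlib.util.find_spec</code>:</p>

```python
import importlib.util

# pip package name -> import name for the extra packages installed in the Dockerfile.
# Most match; scikit-learn is the usual exception (it imports as sklearn).
PACKAGES = {
    "torchvision": "torchvision",
    "tensorboard": "tensorboard",
    "scikit-learn": "sklearn",
    "pandas": "pandas",
    "numpy": "numpy",
    "matplotlib": "matplotlib",
    "seaborn": "seaborn",
}


def check_packages(packages=PACKAGES):
    """Map each pip name to whether its module can be found (without importing it)."""
    return {pip_name: importlib.util.find_spec(mod) is not None
            for pip_name, mod in packages.items()}


def main(packages=PACKAGES):
    """Print a status line per package and return a process exit code."""
    status = check_packages(packages)
    for name, ok in status.items():
        print(f"{name:15s} {'ok' if ok else 'MISSING'}")
    return 0 if all(status.values()) else 1
```

<p>Calling <code>main()</code> from a one-off container command returns a non-zero code when anything is missing, which makes it easy to use as a gate in an image build or CI step.</p>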
</section>
</section>
<section id="creating-pytorch-training-jobs" class="level2">
<h2 class="anchored" data-anchor-id="creating-pytorch-training-jobs" id="creating-pytorch-training-jobs">Creating PyTorch Training Jobs</h2>
<section id="simple-training-job" class="level3">
<h3 class="anchored" data-anchor-id="simple-training-job" id="simple-training-job">Simple Training Job</h3>
<p>Create a basic PyTorch training script (<code>train.py</code>):</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.optim <span class="im">as</span> optim</span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.utils.data <span class="im">import</span> DataLoader</span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torchvision <span class="im">import</span> datasets, transforms</span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> argparse</span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> os</span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> SimpleNet(nn.Module):</span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, num_classes<span class="op">=</span><span class="dv">10</span>):</span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.features <span class="op">=</span> nn.Sequential(</span>
<span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a>            nn.Conv2d(<span class="dv">1</span>, <span class="dv">32</span>, <span class="dv">3</span>, <span class="dv">1</span>),</span>
<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(),</span>
<span id="cb3-15"><a href="#cb3-15" aria-hidden="true" tabindex="-1"></a>            nn.Conv2d(<span class="dv">32</span>, <span class="dv">64</span>, <span class="dv">3</span>, <span class="dv">1</span>),</span>
<span id="cb3-16"><a href="#cb3-16" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(),</span>
<span id="cb3-17"><a href="#cb3-17" aria-hidden="true" tabindex="-1"></a>            nn.MaxPool2d(<span class="dv">2</span>),</span>
<span id="cb3-18"><a href="#cb3-18" aria-hidden="true" tabindex="-1"></a>            nn.Dropout(<span class="fl">0.25</span>),</span>
<span id="cb3-19"><a href="#cb3-19" aria-hidden="true" tabindex="-1"></a>            nn.Flatten(),</span>
<span id="cb3-20"><a href="#cb3-20" aria-hidden="true" tabindex="-1"></a>            nn.Linear(<span class="dv">9216</span>, <span class="dv">128</span>),</span>
<span id="cb3-21"><a href="#cb3-21" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(),</span>
<span id="cb3-22"><a href="#cb3-22" aria-hidden="true" tabindex="-1"></a>            nn.Dropout(<span class="fl">0.5</span>),</span>
<span id="cb3-23"><a href="#cb3-23" aria-hidden="true" tabindex="-1"></a>            nn.Linear(<span class="dv">128</span>, num_classes)</span>
<span id="cb3-24"><a href="#cb3-24" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb3-25"><a href="#cb3-25" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-26"><a href="#cb3-26" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb3-27"><a href="#cb3-27" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.features(x)</span>
<span id="cb3-28"><a href="#cb3-28" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-29"><a href="#cb3-29" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> train_epoch(model, device, train_loader, optimizer, criterion, epoch):</span>
<span id="cb3-30"><a href="#cb3-30" aria-hidden="true" tabindex="-1"></a>    model.train()</span>
<span id="cb3-31"><a href="#cb3-31" aria-hidden="true" tabindex="-1"></a>    total_loss <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb3-32"><a href="#cb3-32" aria-hidden="true" tabindex="-1"></a>    correct <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb3-33"><a href="#cb3-33" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-34"><a href="#cb3-34" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> batch_idx, (data, target) <span class="kw">in</span> <span class="bu">enumerate</span>(train_loader):</span>
<span id="cb3-35"><a href="#cb3-35" aria-hidden="true" tabindex="-1"></a>        data, target <span class="op">=</span> data.to(device), target.to(device)</span>
<span id="cb3-36"><a href="#cb3-36" aria-hidden="true" tabindex="-1"></a>        optimizer.zero_grad()</span>
<span id="cb3-37"><a href="#cb3-37" aria-hidden="true" tabindex="-1"></a>        output <span class="op">=</span> model(data)</span>
<span id="cb3-38"><a href="#cb3-38" aria-hidden="true" tabindex="-1"></a>        loss <span class="op">=</span> criterion(output, target)</span>
<span id="cb3-39"><a href="#cb3-39" aria-hidden="true" tabindex="-1"></a>        loss.backward()</span>
<span id="cb3-40"><a href="#cb3-40" aria-hidden="true" tabindex="-1"></a>        optimizer.step()</span>
<span id="cb3-41"><a href="#cb3-41" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-42"><a href="#cb3-42" aria-hidden="true" tabindex="-1"></a>        total_loss <span class="op">+=</span> loss.item()</span>
<span id="cb3-43"><a href="#cb3-43" aria-hidden="true" tabindex="-1"></a>        pred <span class="op">=</span> output.argmax(dim<span class="op">=</span><span class="dv">1</span>, keepdim<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb3-44"><a href="#cb3-44" aria-hidden="true" tabindex="-1"></a>        correct <span class="op">+=</span> pred.eq(target.view_as(pred)).<span class="bu">sum</span>().item()</span>
<span id="cb3-45"><a href="#cb3-45" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-46"><a href="#cb3-46" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> batch_idx <span class="op">%</span> <span class="dv">100</span> <span class="op">==</span> <span class="dv">0</span>:</span>
<span id="cb3-47"><a href="#cb3-47" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f'Train Epoch: </span><span class="sc">{</span>epoch<span class="sc">}</span><span class="ss"> [</span><span class="sc">{</span>batch_idx <span class="op">*</span> <span class="bu">len</span>(data)<span class="sc">}</span><span class="ss">/</span><span class="sc">{</span><span class="bu">len</span>(train_loader.dataset)<span class="sc">}</span><span class="ss">] '</span></span>
<span id="cb3-48"><a href="#cb3-48" aria-hidden="true" tabindex="-1"></a>                  <span class="ss">f'Loss: </span><span class="sc">{</span>loss<span class="sc">.</span>item()<span class="sc">:.6f}</span><span class="ss">'</span>)</span>
<span id="cb3-49"><a href="#cb3-49" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-50"><a href="#cb3-50" aria-hidden="true" tabindex="-1"></a>    accuracy <span class="op">=</span> <span class="fl">100.</span> <span class="op">*</span> correct <span class="op">/</span> <span class="bu">len</span>(train_loader.dataset)</span>
<span id="cb3-51"><a href="#cb3-51" aria-hidden="true" tabindex="-1"></a>    avg_loss <span class="op">=</span> total_loss <span class="op">/</span> <span class="bu">len</span>(train_loader)</span>
<span id="cb3-52"><a href="#cb3-52" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f'Train Epoch: </span><span class="sc">{</span>epoch<span class="sc">}</span><span class="ss">, Average Loss: </span><span class="sc">{</span>avg_loss<span class="sc">:.4f}</span><span class="ss">, Accuracy: </span><span class="sc">{</span>accuracy<span class="sc">:.2f}</span><span class="ss">%'</span>)</span>
<span id="cb3-53"><a href="#cb3-53" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> avg_loss, accuracy</span>
<span id="cb3-54"><a href="#cb3-54" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-55"><a href="#cb3-55" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> test(model, device, test_loader, criterion):</span>
<span id="cb3-56"><a href="#cb3-56" aria-hidden="true" tabindex="-1"></a>    model.<span class="bu">eval</span>()</span>
<span id="cb3-57"><a href="#cb3-57" aria-hidden="true" tabindex="-1"></a>    test_loss <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb3-58"><a href="#cb3-58" aria-hidden="true" tabindex="-1"></a>    correct <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb3-59"><a href="#cb3-59" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-60"><a href="#cb3-60" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> torch.no_grad():</span>
<span id="cb3-61"><a href="#cb3-61" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> data, target <span class="kw">in</span> test_loader:</span>
<span id="cb3-62"><a href="#cb3-62" aria-hidden="true" tabindex="-1"></a>            data, target <span class="op">=</span> data.to(device), target.to(device)</span>
<span id="cb3-63"><a href="#cb3-63" aria-hidden="true" tabindex="-1"></a>            output <span class="op">=</span> model(data)</span>
<span id="cb3-64"><a href="#cb3-64" aria-hidden="true" tabindex="-1"></a>            test_loss <span class="op">+=</span> criterion(output, target).item()</span>
<span id="cb3-65"><a href="#cb3-65" aria-hidden="true" tabindex="-1"></a>            pred <span class="op">=</span> output.argmax(dim<span class="op">=</span><span class="dv">1</span>, keepdim<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb3-66"><a href="#cb3-66" aria-hidden="true" tabindex="-1"></a>            correct <span class="op">+=</span> pred.eq(target.view_as(pred)).<span class="bu">sum</span>().item()</span>
<span id="cb3-67"><a href="#cb3-67" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-68"><a href="#cb3-68" aria-hidden="true" tabindex="-1"></a>    test_loss <span class="op">/=</span> <span class="bu">len</span>(test_loader)</span>
<span id="cb3-69"><a href="#cb3-69" aria-hidden="true" tabindex="-1"></a>    accuracy <span class="op">=</span> <span class="fl">100.</span> <span class="op">*</span> correct <span class="op">/</span> <span class="bu">len</span>(test_loader.dataset)</span>
<span id="cb3-70"><a href="#cb3-70" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f'Test set: Average loss: </span><span class="sc">{</span>test_loss<span class="sc">:.4f}</span><span class="ss">, Accuracy: </span><span class="sc">{</span>accuracy<span class="sc">:.2f}</span><span class="ss">%'</span>)</span>
<span id="cb3-71"><a href="#cb3-71" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> test_loss, accuracy</span>
<span id="cb3-72"><a href="#cb3-72" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-73"><a href="#cb3-73" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> main():</span>
<span id="cb3-74"><a href="#cb3-74" aria-hidden="true" tabindex="-1"></a>    parser <span class="op">=</span> argparse.ArgumentParser(description<span class="op">=</span><span class="st">'PyTorch MNIST Training'</span>)</span>
<span id="cb3-75"><a href="#cb3-75" aria-hidden="true" tabindex="-1"></a>    parser.add_argument(<span class="st">'--batch-size'</span>, <span class="bu">type</span><span class="op">=</span><span class="bu">int</span>, default<span class="op">=</span><span class="dv">64</span>, metavar<span class="op">=</span><span class="st">'N'</span>,</span>
<span id="cb3-76"><a href="#cb3-76" aria-hidden="true" tabindex="-1"></a>                        <span class="bu">help</span><span class="op">=</span><span class="st">'input batch size for training (default: 64)'</span>)</span>
<span id="cb3-77"><a href="#cb3-77" aria-hidden="true" tabindex="-1"></a>    parser.add_argument(<span class="st">'--test-batch-size'</span>, <span class="bu">type</span><span class="op">=</span><span class="bu">int</span>, default<span class="op">=</span><span class="dv">1000</span>, metavar<span class="op">=</span><span class="st">'N'</span>,</span>
<span id="cb3-78"><a href="#cb3-78" aria-hidden="true" tabindex="-1"></a>                        <span class="bu">help</span><span class="op">=</span><span class="st">'input batch size for testing (default: 1000)'</span>)</span>
<span id="cb3-79"><a href="#cb3-79" aria-hidden="true" tabindex="-1"></a>    parser.add_argument(<span class="st">'--epochs'</span>, <span class="bu">type</span><span class="op">=</span><span class="bu">int</span>, default<span class="op">=</span><span class="dv">10</span>, metavar<span class="op">=</span><span class="st">'N'</span>,</span>
<span id="cb3-80"><a href="#cb3-80" aria-hidden="true" tabindex="-1"></a>                        <span class="bu">help</span><span class="op">=</span><span class="st">'number of epochs to train (default: 10)'</span>)</span>
<span id="cb3-81"><a href="#cb3-81" aria-hidden="true" tabindex="-1"></a>    parser.add_argument(<span class="st">'--lr'</span>, <span class="bu">type</span><span class="op">=</span><span class="bu">float</span>, default<span class="op">=</span><span class="fl">0.01</span>, metavar<span class="op">=</span><span class="st">'LR'</span>,</span>
<span id="cb3-82"><a href="#cb3-82" aria-hidden="true" tabindex="-1"></a>                        <span class="bu">help</span><span class="op">=</span><span class="st">'learning rate (default: 0.01)'</span>)</span>
<span id="cb3-83"><a href="#cb3-83" aria-hidden="true" tabindex="-1"></a>    parser.add_argument(<span class="st">'--momentum'</span>, <span class="bu">type</span><span class="op">=</span><span class="bu">float</span>, default<span class="op">=</span><span class="fl">0.5</span>, metavar<span class="op">=</span><span class="st">'M'</span>,</span>
<span id="cb3-84"><a href="#cb3-84" aria-hidden="true" tabindex="-1"></a>                        <span class="bu">help</span><span class="op">=</span><span class="st">'SGD momentum (default: 0.5)'</span>)</span>
<span id="cb3-85"><a href="#cb3-85" aria-hidden="true" tabindex="-1"></a>    parser.add_argument(<span class="st">'--no-cuda'</span>, action<span class="op">=</span><span class="st">'store_true'</span>, default<span class="op">=</span><span class="va">False</span>,</span>
<span id="cb3-86"><a href="#cb3-86" aria-hidden="true" tabindex="-1"></a>                        <span class="bu">help</span><span class="op">=</span><span class="st">'disables CUDA training'</span>)</span>
<span id="cb3-87"><a href="#cb3-87" aria-hidden="true" tabindex="-1"></a>    parser.add_argument(<span class="st">'--seed'</span>, <span class="bu">type</span><span class="op">=</span><span class="bu">int</span>, default<span class="op">=</span><span class="dv">1</span>, metavar<span class="op">=</span><span class="st">'S'</span>,</span>
<span id="cb3-88"><a href="#cb3-88" aria-hidden="true" tabindex="-1"></a>                        <span class="bu">help</span><span class="op">=</span><span class="st">'random seed (default: 1)'</span>)</span>
<span id="cb3-89"><a href="#cb3-89" aria-hidden="true" tabindex="-1"></a>    parser.add_argument(<span class="st">'--model-dir'</span>, <span class="bu">type</span><span class="op">=</span><span class="bu">str</span>, default<span class="op">=</span><span class="st">'/tmp/model'</span>,</span>
<span id="cb3-90"><a href="#cb3-90" aria-hidden="true" tabindex="-1"></a>                        <span class="bu">help</span><span class="op">=</span><span class="st">'directory to save the model'</span>)</span>
<span id="cb3-91"><a href="#cb3-91" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-92"><a href="#cb3-92" aria-hidden="true" tabindex="-1"></a>    args <span class="op">=</span> parser.parse_args()</span>
<span id="cb3-93"><a href="#cb3-93" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-94"><a href="#cb3-94" aria-hidden="true" tabindex="-1"></a>    torch.manual_seed(args.seed)</span>
<span id="cb3-95"><a href="#cb3-95" aria-hidden="true" tabindex="-1"></a>    device <span class="op">=</span> torch.device(<span class="st">"cuda"</span> <span class="cf">if</span> torch.cuda.is_available() <span class="kw">and</span> <span class="kw">not</span> args.no_cuda <span class="cf">else</span> <span class="st">"cpu"</span>)</span>
<span id="cb3-96"><a href="#cb3-96" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-97"><a href="#cb3-97" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Data loading</span></span>
<span id="cb3-98"><a href="#cb3-98" aria-hidden="true" tabindex="-1"></a>    transform <span class="op">=</span> transforms.Compose([</span>
<span id="cb3-99"><a href="#cb3-99" aria-hidden="true" tabindex="-1"></a>        transforms.ToTensor(),</span>
<span id="cb3-100"><a href="#cb3-100" aria-hidden="true" tabindex="-1"></a>        transforms.Normalize((<span class="fl">0.1307</span>,), (<span class="fl">0.3081</span>,))</span>
<span id="cb3-101"><a href="#cb3-101" aria-hidden="true" tabindex="-1"></a>    ])</span>
<span id="cb3-102"><a href="#cb3-102" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-103"><a href="#cb3-103" aria-hidden="true" tabindex="-1"></a>    train_dataset <span class="op">=</span> datasets.MNIST(<span class="st">'/tmp/data'</span>, train<span class="op">=</span><span class="va">True</span>, download<span class="op">=</span><span class="va">True</span>, transform<span class="op">=</span>transform)</span>
<span id="cb3-104"><a href="#cb3-104" aria-hidden="true" tabindex="-1"></a>    test_dataset <span class="op">=</span> datasets.MNIST(<span class="st">'/tmp/data'</span>, train<span class="op">=</span><span class="va">False</span>, transform<span class="op">=</span>transform)</span>
<span id="cb3-105"><a href="#cb3-105" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-106"><a href="#cb3-106" aria-hidden="true" tabindex="-1"></a>    train_loader <span class="op">=</span> DataLoader(train_dataset, batch_size<span class="op">=</span>args.batch_size, shuffle<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb3-107"><a href="#cb3-107" aria-hidden="true" tabindex="-1"></a>    test_loader <span class="op">=</span> DataLoader(test_dataset, batch_size<span class="op">=</span>args.test_batch_size, shuffle<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb3-108"><a href="#cb3-108" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-109"><a href="#cb3-109" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Model, loss, and optimizer</span></span>
<span id="cb3-110"><a href="#cb3-110" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> SimpleNet().to(device)</span>
<span id="cb3-111"><a href="#cb3-111" aria-hidden="true" tabindex="-1"></a>    criterion <span class="op">=</span> nn.CrossEntropyLoss()</span>
<span id="cb3-112"><a href="#cb3-112" aria-hidden="true" tabindex="-1"></a>    optimizer <span class="op">=</span> optim.SGD(model.parameters(), lr<span class="op">=</span>args.lr, momentum<span class="op">=</span>args.momentum)</span>
<span id="cb3-113"><a href="#cb3-113" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-114"><a href="#cb3-114" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Training loop</span></span>
<span id="cb3-115"><a href="#cb3-115" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">1</span>, args.epochs <span class="op">+</span> <span class="dv">1</span>):</span>
<span id="cb3-116"><a href="#cb3-116" aria-hidden="true" tabindex="-1"></a>        train_loss, train_acc <span class="op">=</span> train_epoch(model, device, train_loader, optimizer, criterion, epoch)</span>
<span id="cb3-117"><a href="#cb3-117" aria-hidden="true" tabindex="-1"></a>        test_loss, test_acc <span class="op">=</span> test(model, device, test_loader, criterion)</span>
<span id="cb3-118"><a href="#cb3-118" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-119"><a href="#cb3-119" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Save model</span></span>
<span id="cb3-120"><a href="#cb3-120" aria-hidden="true" tabindex="-1"></a>    os.makedirs(args.model_dir, exist_ok<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb3-121"><a href="#cb3-121" aria-hidden="true" tabindex="-1"></a>    torch.save(model.state_dict(), <span class="ss">f'</span><span class="sc">{</span>args<span class="sc">.</span>model_dir<span class="sc">}</span><span class="ss">/model.pth'</span>)</span>
<span id="cb3-122"><a href="#cb3-122" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f'Model saved to </span><span class="sc">{</span>args<span class="sc">.</span>model_dir<span class="sc">}</span><span class="ss">/model.pth'</span>)</span>
<span id="cb3-123"><a href="#cb3-123" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-124"><a href="#cb3-124" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">'__main__'</span>:</span>
<span id="cb3-125"><a href="#cb3-125" aria-hidden="true" tabindex="-1"></a>    main()</span></code></pre></div></div>
</section>
<section id="pytorchjob-yaml-configuration" class="level3">
<h3 class="anchored" data-anchor-id="pytorchjob-yaml-configuration" id="pytorchjob-yaml-configuration">PyTorchJob YAML Configuration</h3>
<p>Create a <code>pytorchjob.yaml</code> file:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb4"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="fu">apiVersion</span><span class="kw">:</span><span class="at"> kubeflow.org/v1</span></span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a><span class="fu">kind</span><span class="kw">:</span><span class="at"> PyTorchJob</span></span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a><span class="fu">metadata</span><span class="kw">:</span></span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">name</span><span class="kw">:</span><span class="at"> pytorch-mnist-training</span></span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">namespace</span><span class="kw">:</span><span class="at"> pytorch-training</span></span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a><span class="fu">spec</span><span class="kw">:</span></span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">pytorchReplicaSpecs</span><span class="kw">:</span></span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">Master</span><span class="kw">:</span></span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">replicas</span><span class="kw">:</span><span class="at"> </span><span class="dv">1</span></span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">restartPolicy</span><span class="kw">:</span><span class="at"> OnFailure</span></span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">template</span><span class="kw">:</span></span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">metadata</span><span class="kw">:</span></span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="fu">annotations</span><span class="kw">:</span></span>
<span id="cb4-14"><a href="#cb4-14" aria-hidden="true" tabindex="-1"></a><span class="at">            </span><span class="fu">sidecar.istio.io/inject</span><span class="kw">:</span><span class="at"> </span><span class="st">"false"</span></span>
<span id="cb4-15"><a href="#cb4-15" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">spec</span><span class="kw">:</span></span>
<span id="cb4-16"><a href="#cb4-16" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="fu">containers</span><span class="kw">:</span></span>
<span id="cb4-17"><a href="#cb4-17" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> pytorch</span></span>
<span id="cb4-18"><a href="#cb4-18" aria-hidden="true" tabindex="-1"></a><span class="at">            </span><span class="fu">image</span><span class="kw">:</span><span class="at"> your-registry/pytorch-mnist:latest</span></span>
<span id="cb4-19"><a href="#cb4-19" aria-hidden="true" tabindex="-1"></a><span class="at">            </span><span class="fu">imagePullPolicy</span><span class="kw">:</span><span class="at"> Always</span></span>
<span id="cb4-20"><a href="#cb4-20" aria-hidden="true" tabindex="-1"></a><span class="at">            </span><span class="fu">command</span><span class="kw">:</span></span>
<span id="cb4-21"><a href="#cb4-21" aria-hidden="true" tabindex="-1"></a><span class="at">            </span><span class="kw">-</span><span class="at"> python</span></span>
<span id="cb4-22"><a href="#cb4-22" aria-hidden="true" tabindex="-1"></a><span class="at">            </span><span class="kw">-</span><span class="at"> train.py</span></span>
<span id="cb4-23"><a href="#cb4-23" aria-hidden="true" tabindex="-1"></a><span class="at">            </span><span class="fu">args</span><span class="kw">:</span></span>
<span id="cb4-24"><a href="#cb4-24" aria-hidden="true" tabindex="-1"></a><span class="at">            </span><span class="kw">-</span><span class="at"> --epochs=20</span></span>
<span id="cb4-25"><a href="#cb4-25" aria-hidden="true" tabindex="-1"></a><span class="at">            </span><span class="kw">-</span><span class="at"> --batch-size=64</span></span>
<span id="cb4-26"><a href="#cb4-26" aria-hidden="true" tabindex="-1"></a><span class="at">            </span><span class="kw">-</span><span class="at"> --lr=0.01</span></span>
<span id="cb4-27"><a href="#cb4-27" aria-hidden="true" tabindex="-1"></a><span class="at">            </span><span class="kw">-</span><span class="at"> --model-dir=/mnt/model</span></span>
<span id="cb4-28"><a href="#cb4-28" aria-hidden="true" tabindex="-1"></a><span class="at">            </span><span class="fu">resources</span><span class="kw">:</span></span>
<span id="cb4-29"><a href="#cb4-29" aria-hidden="true" tabindex="-1"></a><span class="at">              </span><span class="fu">requests</span><span class="kw">:</span></span>
<span id="cb4-30"><a href="#cb4-30" aria-hidden="true" tabindex="-1"></a><span class="at">                </span><span class="fu">memory</span><span class="kw">:</span><span class="at"> </span><span class="st">"2Gi"</span></span>
<span id="cb4-31"><a href="#cb4-31" aria-hidden="true" tabindex="-1"></a><span class="at">                </span><span class="fu">cpu</span><span class="kw">:</span><span class="at"> </span><span class="st">"1"</span></span>
<span id="cb4-32"><a href="#cb4-32" aria-hidden="true" tabindex="-1"></a><span class="at">              </span><span class="fu">limits</span><span class="kw">:</span></span>
<span id="cb4-33"><a href="#cb4-33" aria-hidden="true" tabindex="-1"></a><span class="at">                </span><span class="fu">memory</span><span class="kw">:</span><span class="at"> </span><span class="st">"4Gi"</span></span>
<span id="cb4-34"><a href="#cb4-34" aria-hidden="true" tabindex="-1"></a><span class="at">                </span><span class="fu">cpu</span><span class="kw">:</span><span class="at"> </span><span class="st">"2"</span></span>
<span id="cb4-35"><a href="#cb4-35" aria-hidden="true" tabindex="-1"></a><span class="at">                </span><span class="fu">nvidia.com/gpu</span><span class="kw">:</span><span class="at"> </span><span class="st">"1"</span></span>
<span id="cb4-36"><a href="#cb4-36" aria-hidden="true" tabindex="-1"></a><span class="at">            </span><span class="fu">volumeMounts</span><span class="kw">:</span></span>
<span id="cb4-37"><a href="#cb4-37" aria-hidden="true" tabindex="-1"></a><span class="at">            </span><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> model-storage</span></span>
<span id="cb4-38"><a href="#cb4-38" aria-hidden="true" tabindex="-1"></a><span class="at">              </span><span class="fu">mountPath</span><span class="kw">:</span><span class="at"> /mnt/model</span></span>
<span id="cb4-39"><a href="#cb4-39" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="fu">volumes</span><span class="kw">:</span></span>
<span id="cb4-40"><a href="#cb4-40" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> model-storage</span></span>
<span id="cb4-41"><a href="#cb4-41" aria-hidden="true" tabindex="-1"></a><span class="at">            </span><span class="fu">persistentVolumeClaim</span><span class="kw">:</span></span>
<span id="cb4-42"><a href="#cb4-42" aria-hidden="true" tabindex="-1"></a><span class="at">              </span><span class="fu">claimName</span><span class="kw">:</span><span class="at"> model-pvc</span></span></code></pre></div></div>
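<p>A PyTorchJob is an ordinary Kubernetes custom resource, so the manifest above can also be generated from Python, which is handy when launching many variants of the same job. The following is a minimal sketch that mirrors the field names used in this post's manifests. The helper <code>build_pytorchjob</code> and its parameters are hypothetical and not part of any Kubeflow SDK. For brevity it leaves out the CPU/memory requests and <code>imagePullPolicy</code>. Passing <code>workers &gt; 0</code> adds a <code>Worker</code> replica spec for multi-node jobs.</p>

```python
import json


def build_pytorchjob(name, namespace, image, script, args,
                     workers=0, gpus=1, pvc=None):
    """Build a kubeflow.org/v1 PyTorchJob manifest as a plain dict.

    Hypothetical helper: field names mirror the YAML manifests in this post.
    """

    def replica_spec(replicas):
        container = {
            "name": "pytorch",  # PyTorchJob looks for a container named "pytorch"
            "image": image,
            "command": ["python", script],
            "args": list(args),
            "resources": {"limits": {"nvidia.com/gpu": str(gpus)}},
        }
        pod_spec = {"containers": [container]}
        if pvc is not None:
            # Mount a PersistentVolumeClaim so the saved model outlives the pod
            container["volumeMounts"] = [{"name": "model-storage", "mountPath": "/mnt/model"}]
            pod_spec["volumes"] = [{"name": "model-storage",
                                    "persistentVolumeClaim": {"claimName": pvc}}]
        return {
            "replicas": replicas,
            "restartPolicy": "OnFailure",
            "template": {
                # Disable Istio sidecar injection, as in the manifest above
                "metadata": {"annotations": {"sidecar.istio.io/inject": "false"}},
                "spec": pod_spec,
            },
        }

    replica_specs = {"Master": replica_spec(1)}
    if workers > 0:
        replica_specs["Worker"] = replica_spec(workers)

    return {
        "apiVersion": "kubeflow.org/v1",
        "kind": "PyTorchJob",
        "metadata": {"name": name, "namespace": namespace},
        "spec": {"pytorchReplicaSpecs": replica_specs},
    }


if __name__ == "__main__":
    job = build_pytorchjob(
        name="pytorch-mnist-training",
        namespace="pytorch-training",
        image="your-registry/pytorch-mnist:latest",
        script="train.py",
        args=["--epochs=20", "--batch-size=64", "--lr=0.01", "--model-dir=/mnt/model"],
        pvc="model-pvc",
    )
    # kubectl accepts JSON as well as YAML: python build_job.py > job.json && kubectl apply -f job.json
    print(json.dumps(job, indent=2))
```

<p>To submit the dict directly from Python, you can use the official Kubernetes client. Call <code>CustomObjectsApi().create_namespaced_custom_object</code> with group <code>kubeflow.org</code>, version <code>v1</code>, plural <code>pytorchjobs</code>, and the dict as <code>body</code>.</p>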
</section>
</section>
<section id="distributed-training" class="level2">
<h2 class="anchored" data-anchor-id="distributed-training" id="distributed-training">Distributed Training</h2>
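<p>To scale training across multiple GPUs or nodes, run one process per GPU and wrap the model in <code>DistributedDataParallel</code> (DDP). Use <code>DistributedSampler</code> so that each process trains on its own shard of the data. When the job runs as a PyTorchJob, the training operator sets <code>MASTER_ADDR</code>, <code>MASTER_PORT</code>, <code>WORLD_SIZE</code>, and <code>RANK</code> in every pod, which lets the processes find each other without extra configuration.</p>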
<section id="distributed-training-script" class="level3">
<h3 class="anchored" data-anchor-id="distributed-training-script" id="distributed-training-script">Distributed Training Script</h3>
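<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> argparse</span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> os</span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.distributed <span class="im">as</span> dist</span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.multiprocessing <span class="im">as</span> mp</span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.optim <span class="im">as</span> optim</span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.nn.parallel <span class="im">import</span> DistributedDataParallel <span class="im">as</span> DDP</span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.utils.data <span class="im">import</span> DataLoader</span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.utils.data.distributed <span class="im">import</span> DistributedSampler</span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torchvision <span class="im">import</span> datasets, transforms</span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a><span class="co"># Assumes train.py (the single-GPU script above) is in the same image and defines SimpleNet</span></span>
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> train <span class="im">import</span> SimpleNet</span>
<span id="cb5-16"><a href="#cb5-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-17"><a href="#cb5-17" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> setup(rank, world_size):</span>
<span id="cb5-18"><a href="#cb5-18" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Initialize the distributed environment."""</span></span>
<span id="cb5-19"><a href="#cb5-19" aria-hidden="true" tabindex="-1"></a>    <span class="co"># PyTorchJob injects MASTER_ADDR/MASTER_PORT; fall back to localhost on a single node</span></span>
<span id="cb5-20"><a href="#cb5-20" aria-hidden="true" tabindex="-1"></a>    os.environ[<span class="st">'MASTER_ADDR'</span>] <span class="op">=</span> os.environ.get(<span class="st">'MASTER_ADDR'</span>, <span class="st">'localhost'</span>)</span>
<span id="cb5-21"><a href="#cb5-21" aria-hidden="true" tabindex="-1"></a>    os.environ[<span class="st">'MASTER_PORT'</span>] <span class="op">=</span> os.environ.get(<span class="st">'MASTER_PORT'</span>, <span class="st">'12355'</span>)</span>
<span id="cb5-22"><a href="#cb5-22" aria-hidden="true" tabindex="-1"></a>    dist.init_process_group(<span class="st">"nccl"</span>, rank<span class="op">=</span>rank, world_size<span class="op">=</span>world_size)</span>
<span id="cb5-23"><a href="#cb5-23" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-24"><a href="#cb5-24" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> cleanup():</span>
<span id="cb5-25"><a href="#cb5-25" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Clean up the distributed environment."""</span></span>
<span id="cb5-26"><a href="#cb5-26" aria-hidden="true" tabindex="-1"></a>    dist.destroy_process_group()</span>
<span id="cb5-27"><a href="#cb5-27" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-28"><a href="#cb5-28" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> train_distributed(rank, world_size, args):</span>
<span id="cb5-29"><a href="#cb5-29" aria-hidden="true" tabindex="-1"></a>    setup(rank, world_size)</span>
<span id="cb5-30"><a href="#cb5-30" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-31"><a href="#cb5-31" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Map the global rank to a GPU on this node (one GPU per pod gives cuda:0)</span></span>
<span id="cb5-32"><a href="#cb5-32" aria-hidden="true" tabindex="-1"></a>    local_rank <span class="op">=</span> rank <span class="op">%</span> torch.cuda.device_count()</span>
<span id="cb5-33"><a href="#cb5-33" aria-hidden="true" tabindex="-1"></a>    device <span class="op">=</span> torch.device(<span class="ss">f"cuda:</span><span class="sc">{</span>local_rank<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb5-34"><a href="#cb5-34" aria-hidden="true" tabindex="-1"></a>    torch.cuda.set_device(device)</span>
<span id="cb5-35"><a href="#cb5-35" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-36"><a href="#cb5-36" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Create model, move it to the GPU, and wrap it in DDP</span></span>
<span id="cb5-37"><a href="#cb5-37" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> SimpleNet().to(device)</span>
<span id="cb5-38"><a href="#cb5-38" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> DDP(model, device_ids<span class="op">=</span>[local_rank])</span>
<span id="cb5-39"><a href="#cb5-39" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-40"><a href="#cb5-40" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Same preprocessing as train.py; MNIST is downloaded in __main__ before training starts</span></span>
<span id="cb5-41"><a href="#cb5-41" aria-hidden="true" tabindex="-1"></a>    transform <span class="op">=</span> transforms.Compose([</span>
<span id="cb5-42"><a href="#cb5-42" aria-hidden="true" tabindex="-1"></a>        transforms.ToTensor(),</span>
<span id="cb5-43"><a href="#cb5-43" aria-hidden="true" tabindex="-1"></a>        transforms.Normalize((<span class="fl">0.1307</span>,), (<span class="fl">0.3081</span>,))</span>
<span id="cb5-44"><a href="#cb5-44" aria-hidden="true" tabindex="-1"></a>    ])</span>
<span id="cb5-45"><a href="#cb5-45" aria-hidden="true" tabindex="-1"></a>    train_dataset <span class="op">=</span> datasets.MNIST(args.data_dir, train<span class="op">=</span><span class="va">True</span>, transform<span class="op">=</span>transform)</span>
<span id="cb5-46"><a href="#cb5-46" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-47"><a href="#cb5-47" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Give each rank a disjoint shard of the training set</span></span>
<span id="cb5-48"><a href="#cb5-48" aria-hidden="true" tabindex="-1"></a>    train_sampler <span class="op">=</span> DistributedSampler(train_dataset, </span>
<span id="cb5-49"><a href="#cb5-49" aria-hidden="true" tabindex="-1"></a>                                       num_replicas<span class="op">=</span>world_size, </span>
<span id="cb5-50"><a href="#cb5-50" aria-hidden="true" tabindex="-1"></a>                                       rank<span class="op">=</span>rank)</span>
<span id="cb5-51"><a href="#cb5-51" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-52"><a href="#cb5-52" aria-hidden="true" tabindex="-1"></a>    train_loader <span class="op">=</span> DataLoader(train_dataset, </span>
<span id="cb5-53"><a href="#cb5-53" aria-hidden="true" tabindex="-1"></a>                              batch_size<span class="op">=</span>args.batch_size,</span>
<span id="cb5-54"><a href="#cb5-54" aria-hidden="true" tabindex="-1"></a>                              sampler<span class="op">=</span>train_sampler,</span>
<span id="cb5-55"><a href="#cb5-55" aria-hidden="true" tabindex="-1"></a>                              pin_memory<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb5-56"><a href="#cb5-56" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-57"><a href="#cb5-57" aria-hidden="true" tabindex="-1"></a>    criterion <span class="op">=</span> nn.CrossEntropyLoss()</span>
<span id="cb5-58"><a href="#cb5-58" aria-hidden="true" tabindex="-1"></a>    optimizer <span class="op">=</span> optim.SGD(model.parameters(), lr<span class="op">=</span>args.lr)</span>
<span id="cb5-59"><a href="#cb5-59" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-60"><a href="#cb5-60" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Training loop</span></span>
<span id="cb5-61"><a href="#cb5-61" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(args.epochs):</span>
<span id="cb5-62"><a href="#cb5-62" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Reshuffle the shards differently on every epoch</span></span>
<span id="cb5-63"><a href="#cb5-63" aria-hidden="true" tabindex="-1"></a>        train_sampler.set_epoch(epoch)</span>
<span id="cb5-64"><a href="#cb5-64" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-65"><a href="#cb5-65" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> batch_idx, (data, target) <span class="kw">in</span> <span class="bu">enumerate</span>(train_loader):</span>
<span id="cb5-66"><a href="#cb5-66" aria-hidden="true" tabindex="-1"></a>            data, target <span class="op">=</span> data.to(device, non_blocking<span class="op">=</span><span class="va">True</span>), target.to(device, non_blocking<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb5-67"><a href="#cb5-67" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb5-68"><a href="#cb5-68" aria-hidden="true" tabindex="-1"></a>            optimizer.zero_grad()</span>
<span id="cb5-69"><a href="#cb5-69" aria-hidden="true" tabindex="-1"></a>            output <span class="op">=</span> model(data)</span>
<span id="cb5-70"><a href="#cb5-70" aria-hidden="true" tabindex="-1"></a>            loss <span class="op">=</span> criterion(output, target)</span>
<span id="cb5-71"><a href="#cb5-71" aria-hidden="true" tabindex="-1"></a>            loss.backward()  <span class="co"># DDP averages gradients across ranks during backward</span></span>
<span id="cb5-72"><a href="#cb5-72" aria-hidden="true" tabindex="-1"></a>            optimizer.step()</span>
<span id="cb5-73"><a href="#cb5-73" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb5-74"><a href="#cb5-74" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> rank <span class="op">==</span> <span class="dv">0</span> <span class="kw">and</span> batch_idx <span class="op">%</span> <span class="dv">100</span> <span class="op">==</span> <span class="dv">0</span>:</span>
<span id="cb5-75"><a href="#cb5-75" aria-hidden="true" tabindex="-1"></a>                <span class="bu">print</span>(<span class="ss">f"Epoch </span><span class="sc">{</span>epoch<span class="sc">}</span><span class="ss">, Batch </span><span class="sc">{</span>batch_idx<span class="sc">}</span><span class="ss">, Loss: </span><span class="sc">{</span>loss<span class="sc">.</span>item()<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb5-76"><a href="#cb5-76" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-77"><a href="#cb5-77" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Save model only on rank 0; model.module unwraps the DDP wrapper</span></span>
<span id="cb5-78"><a href="#cb5-78" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> rank <span class="op">==</span> <span class="dv">0</span>:</span>
<span id="cb5-79"><a href="#cb5-79" aria-hidden="true" tabindex="-1"></a>        os.makedirs(args.model_dir, exist_ok<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb5-80"><a href="#cb5-80" aria-hidden="true" tabindex="-1"></a>        torch.save(model.module.state_dict(), <span class="ss">f'</span><span class="sc">{</span>args<span class="sc">.</span>model_dir<span class="sc">}</span><span class="ss">/distributed_model.pth'</span>)</span>
<span id="cb5-81"><a href="#cb5-81" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-82"><a href="#cb5-82" aria-hidden="true" tabindex="-1"></a>    cleanup()</span>
<span id="cb5-83"><a href="#cb5-83" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-84"><a href="#cb5-84" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">'__main__'</span>:</span>
<span id="cb5-85"><a href="#cb5-85" aria-hidden="true" tabindex="-1"></a>    parser <span class="op">=</span> argparse.ArgumentParser(description<span class="op">=</span><span class="st">'Distributed MNIST training'</span>)</span>
<span id="cb5-86"><a href="#cb5-86" aria-hidden="true" tabindex="-1"></a>    parser.add_argument(<span class="st">'--epochs'</span>, <span class="bu">type</span><span class="op">=</span><span class="bu">int</span>, default<span class="op">=</span><span class="dv">10</span>)</span>
<span id="cb5-87"><a href="#cb5-87" aria-hidden="true" tabindex="-1"></a>    parser.add_argument(<span class="st">'--batch-size'</span>, <span class="bu">type</span><span class="op">=</span><span class="bu">int</span>, default<span class="op">=</span><span class="dv">64</span>)</span>
<span id="cb5-88"><a href="#cb5-88" aria-hidden="true" tabindex="-1"></a>    parser.add_argument(<span class="st">'--lr'</span>, <span class="bu">type</span><span class="op">=</span><span class="bu">float</span>, default<span class="op">=</span><span class="fl">0.01</span>)</span>
<span id="cb5-89"><a href="#cb5-89" aria-hidden="true" tabindex="-1"></a>    parser.add_argument(<span class="st">'--data-dir'</span>, <span class="bu">type</span><span class="op">=</span><span class="bu">str</span>, default<span class="op">=</span><span class="st">'/tmp/data'</span>)</span>
<span id="cb5-90"><a href="#cb5-90" aria-hidden="true" tabindex="-1"></a>    parser.add_argument(<span class="st">'--model-dir'</span>, <span class="bu">type</span><span class="op">=</span><span class="bu">str</span>, default<span class="op">=</span><span class="st">'/tmp/model'</span>)</span>
<span id="cb5-91"><a href="#cb5-91" aria-hidden="true" tabindex="-1"></a>    args <span class="op">=</span> parser.parse_args()</span>
<span id="cb5-92"><a href="#cb5-92" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-93"><a href="#cb5-93" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Download MNIST once per pod/node before any training process reads it</span></span>
<span id="cb5-94"><a href="#cb5-94" aria-hidden="true" tabindex="-1"></a>    datasets.MNIST(args.data_dir, train<span class="op">=</span><span class="va">True</span>, download<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb5-95"><a href="#cb5-95" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-96"><a href="#cb5-96" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> <span class="st">'WORLD_SIZE'</span> <span class="kw">in</span> os.environ:</span>
<span id="cb5-97"><a href="#cb5-97" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Launched by the PyTorchJob operator: one process per pod, rank taken from env vars</span></span>
<span id="cb5-98"><a href="#cb5-98" aria-hidden="true" tabindex="-1"></a>        train_distributed(<span class="bu">int</span>(os.environ[<span class="st">'RANK'</span>]), <span class="bu">int</span>(os.environ[<span class="st">'WORLD_SIZE'</span>]), args)</span>
<span id="cb5-99"><a href="#cb5-99" aria-hidden="true" tabindex="-1"></a>    <span class="cf">else</span>:</span>
<span id="cb5-100"><a href="#cb5-100" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Single node: spawn one process per local GPU</span></span>
<span id="cb5-101"><a href="#cb5-101" aria-hidden="true" tabindex="-1"></a>        world_size <span class="op">=</span> torch.cuda.device_count()</span>
<span id="cb5-102"><a href="#cb5-102" aria-hidden="true" tabindex="-1"></a>        mp.spawn(train_distributed, args<span class="op">=</span>(world_size, args), nprocs<span class="op">=</span>world_size)</span></code></pre></div></div>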
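<p>To see why every rank trains on different data, the sketch below reproduces the index partitioning that <code>DistributedSampler</code> performs with its default <code>drop_last=False</code>. It pads the index list until its length divides evenly by <code>num_replicas</code>, then gives rank <code>r</code> every <code>num_replicas</code>-th index starting at <code>r</code>. The function name <code>shard_indices</code> is hypothetical. The shuffle uses <code>random.Random(seed + epoch)</code> as a stdlib stand-in for the <code>torch.Generator</code> that the real sampler seeds with <code>seed + epoch</code>, so the exact permutations differ from PyTorch's.</p>

```python
import math
import random


def shard_indices(dataset_len, num_replicas, rank, epoch=0, seed=0, shuffle=True):
    """Approximate DistributedSampler's partitioning (drop_last=False), stdlib only."""
    indices = list(range(dataset_len))
    if shuffle:
        # Every rank uses the same seed, so all ranks agree on one permutation;
        # a new epoch (set_epoch) yields a new permutation.
        random.Random(seed + epoch).shuffle(indices)

    # Pad by repeating indices until the total divides evenly across ranks
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    while len(indices) < total_size:
        indices += indices[: total_size - len(indices)]

    # Rank r takes every num_replicas-th index, starting at position r
    return indices[rank:total_size:num_replicas]


if __name__ == "__main__":
    world_size = 4  # e.g. four training processes
    for r in range(world_size):
        print(f"rank {r}: {shard_indices(10, world_size, r)}")
```

<p>With 10 samples and 4 ranks, each rank receives 3 indices, so 2 samples are repeated to keep the shards equal in size. Without <code>train_sampler.set_epoch(epoch)</code>, the sampler's epoch stays at 0, and every epoch would reuse the same permutation.</p>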
</section>
<section id="distributed-pytorchjob-yaml" class="level3">
<h3 class="anchored" data-anchor-id="distributed-pytorchjob-yaml" id="distributed-pytorchjob-yaml">Distributed PyTorchJob YAML</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb6"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="fu">apiVersion</span><span class="kw">:</span><span class="at"> kubeflow.org/v1</span></span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a><span class="fu">kind</span><span class="kw">:</span><span class="at"> PyTorchJob</span></span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a><span class="fu">metadata</span><span class="kw">:</span></span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">name</span><span class="kw">:</span><span class="at"> pytorch-distributed-training</span></span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">namespace</span><span class="kw">:</span><span class="at"> pytorch-training</span></span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a><span class="fu">spec</span><span class="kw">:</span></span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">pytorchReplicaSpecs</span><span class="kw">:</span></span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">Master</span><span class="kw">:</span></span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">replicas</span><span class="kw">:</span><span class="at"> </span><span class="dv">1</span></span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">restartPolicy</span><span class="kw">:</span><span class="at"> OnFailure</span></span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">template</span><span class="kw">:</span></span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">spec</span><span class="kw">:</span></span>
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="fu">containers</span><span class="kw">:</span></span>
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> pytorch</span></span>
<span id="cb6-15"><a href="#cb6-15" aria-hidden="true" tabindex="-1"></a><span class="at">            </span><span class="fu">image</span><span class="kw">:</span><span class="at"> your-registry/pytorch-distributed:latest</span></span>
<span id="cb6-16"><a href="#cb6-16" aria-hidden="true" tabindex="-1"></a><span class="at">            </span><span class="fu">command</span><span class="kw">:</span></span>
<span id="cb6-17"><a href="#cb6-17" aria-hidden="true" tabindex="-1"></a><span class="at">            </span><span class="kw">-</span><span class="at"> python</span></span>
<span id="cb6-18"><a href="#cb6-18" aria-hidden="true" tabindex="-1"></a><span class="at">            </span><span class="kw">-</span><span class="at"> distributed_train.py</span></span>
<span id="cb6-19"><a href="#cb6-19" aria-hidden="true" tabindex="-1"></a><span class="at">            </span><span class="fu">args</span><span class="kw">:</span></span>
<span id="cb6-20"><a href="#cb6-20" aria-hidden="true" tabindex="-1"></a><span class="at">            </span><span class="kw">-</span><span class="at"> --epochs=50</span></span>
<span id="cb6-21"><a href="#cb6-21" aria-hidden="true" tabindex="-1"></a><span class="at">            </span><span class="kw">-</span><span class="at"> --batch-size=32</span></span>
<span id="cb6-22"><a href="#cb6-22" aria-hidden="true" tabindex="-1"></a><span class="at">            </span><span class="fu">resources</span><span class="kw">:</span></span>
<span id="cb6-23"><a href="#cb6-23" aria-hidden="true" tabindex="-1"></a><span class="at">              </span><span class="fu">limits</span><span class="kw">:</span></span>
<span id="cb6-24"><a href="#cb6-24" aria-hidden="true" tabindex="-1"></a><span class="at">                </span><span class="fu">nvidia.com/gpu</span><span class="kw">:</span><span class="at"> </span><span class="st">"1"</span></span>
<span id="cb6-25"><a href="#cb6-25" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">Worker</span><span class="kw">:</span></span>
<span id="cb6-26"><a href="#cb6-26" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">replicas</span><span class="kw">:</span><span class="at"> </span><span class="dv">3</span></span>
<span id="cb6-27"><a href="#cb6-27" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">restartPolicy</span><span class="kw">:</span><span class="at"> OnFailure</span></span>
<span id="cb6-28"><a href="#cb6-28" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">template</span><span class="kw">:</span></span>
<span id="cb6-29"><a href="#cb6-29" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">spec</span><span class="kw">:</span></span>
<span id="cb6-30"><a href="#cb6-30" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="fu">containers</span><span class="kw">:</span></span>
<span id="cb6-31"><a href="#cb6-31" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> pytorch</span></span>
<span id="cb6-32"><a href="#cb6-32" aria-hidden="true" tabindex="-1"></a><span class="at">            </span><span class="fu">image</span><span class="kw">:</span><span class="at"> your-registry/pytorch-distributed:latest</span></span>
<span id="cb6-33"><a href="#cb6-33" aria-hidden="true" tabindex="-1"></a><span class="at">            </span><span class="fu">command</span><span class="kw">:</span></span>
<span id="cb6-34"><a href="#cb6-34" aria-hidden="true" tabindex="-1"></a><span class="at">            </span><span class="kw">-</span><span class="at"> python</span></span>
<span id="cb6-35"><a href="#cb6-35" aria-hidden="true" tabindex="-1"></a><span class="at">            </span><span class="kw">-</span><span class="at"> distributed_train.py</span></span>
<span id="cb6-36"><a href="#cb6-36" aria-hidden="true" tabindex="-1"></a><span class="at">            </span><span class="fu">args</span><span class="kw">:</span></span>
<span id="cb6-37"><a href="#cb6-37" aria-hidden="true" tabindex="-1"></a><span class="at">            </span><span class="kw">-</span><span class="at"> --epochs=50</span></span>
<span id="cb6-38"><a href="#cb6-38" aria-hidden="true" tabindex="-1"></a><span class="at">            </span><span class="kw">-</span><span class="at"> --batch-size=32</span></span>
<span id="cb6-39"><a href="#cb6-39" aria-hidden="true" tabindex="-1"></a><span class="at">            </span><span class="fu">resources</span><span class="kw">:</span></span>
<span id="cb6-40"><a href="#cb6-40" aria-hidden="true" tabindex="-1"></a><span class="at">              </span><span class="fu">limits</span><span class="kw">:</span></span>
<span id="cb6-41"><a href="#cb6-41" aria-hidden="true" tabindex="-1"></a><span class="at">                </span><span class="fu">nvidia.com/gpu</span><span class="kw">:</span><span class="at"> </span><span class="st">"1"</span></span></code></pre></div></div>
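<p>Rendezvous settings are usually not written into a PyTorchJob manifest by hand. As we understand the Kubeflow training operator, it injects <code>MASTER_ADDR</code>, <code>MASTER_PORT</code>, <code>WORLD_SIZE</code> and <code>RANK</code> into every replica. Those are the variables PyTorch's <code>env://</code> initialization reads. The sketch below shows how a training script could read them, with fallbacks for a single local process. The <code>DistEnv</code> class, the <code>read_dist_env</code> helper and the fallback values are hypothetical choices for illustration, not part of PyTorch or Kubeflow.</p>

```python
import os
from dataclasses import dataclass
from typing import Mapping


@dataclass(frozen=True)
class DistEnv:
    """Rendezvous settings for one replica (hypothetical helper type)."""
    master_addr: str
    master_port: int
    world_size: int
    rank: int


def read_dist_env(env: Mapping[str, str] = os.environ) -> DistEnv:
    """Read the same variables that PyTorch's env:// init method consumes.

    The fallbacks are assumed values describing a single local process,
    so the script can also run outside the cluster.
    """
    return DistEnv(
        master_addr=env.get("MASTER_ADDR", "localhost"),
        master_port=int(env.get("MASTER_PORT", "29500")),
        world_size=int(env.get("WORLD_SIZE", "1")),
        rank=int(env.get("RANK", "0")),
    )


if __name__ == "__main__":
    cfg = read_dist_env()
    print(f"rank {cfg.rank} of {cfg.world_size}, master {cfg.master_addr}:{cfg.master_port}")
```

<p>When these variables are present, <code>torch.distributed.init_process_group(backend="nccl", init_method="env://")</code> picks them up directly. Reading them yourself is mainly useful for logging. It also helps you reason about scaling: with one process per replica, <code>--batch-size=32</code> is a per-replica value, so the effective global batch is <code>32 * world_size</code>.</p>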
</section>
</section>
<section id="hyperparameter-tuning-with-katib" class="level2">
<h2 class="anchored" data-anchor-id="hyperparameter-tuning-with-katib" id="hyperparameter-tuning-with-katib">Hyperparameter Tuning with Katib</h2>
<section id="katib-experiment-configuration" class="level3">
<h3 class="anchored" data-anchor-id="katib-experiment-configuration" id="katib-experiment-configuration">Katib Experiment Configuration</h3>
<p>Create a <code>katib-experiment.yaml</code>:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb7"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="fu">apiVersion</span><span class="kw">:</span><span class="at"> kubeflow.org/v1beta1</span></span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a><span class="fu">kind</span><span class="kw">:</span><span class="at"> Experiment</span></span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a><span class="fu">metadata</span><span class="kw">:</span></span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">name</span><span class="kw">:</span><span class="at"> pytorch-hyperparameter-tuning</span></span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">namespace</span><span class="kw">:</span><span class="at"> pytorch-training</span></span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a><span class="fu">spec</span><span class="kw">:</span></span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">algorithm</span><span class="kw">:</span></span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">algorithmName</span><span class="kw">:</span><span class="at"> random</span></span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">objective</span><span class="kw">:</span></span>
<span id="cb7-10"><a href="#cb7-10" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">type</span><span class="kw">:</span><span class="at"> maximize</span></span>
<span id="cb7-11"><a href="#cb7-11" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">goal</span><span class="kw">:</span><span class="at"> </span><span class="fl">0.95</span></span>
<span id="cb7-12"><a href="#cb7-12" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">objectiveMetricName</span><span class="kw">:</span><span class="at"> accuracy</span></span>
<span id="cb7-13"><a href="#cb7-13" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">parameters</span><span class="kw">:</span></span>
<span id="cb7-14"><a href="#cb7-14" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> lr</span></span>
<span id="cb7-15"><a href="#cb7-15" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">parameterType</span><span class="kw">:</span><span class="at"> double</span></span>
<span id="cb7-16"><a href="#cb7-16" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">feasibleSpace</span><span class="kw">:</span></span>
<span id="cb7-17"><a href="#cb7-17" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">min</span><span class="kw">:</span><span class="at"> </span><span class="st">"0.001"</span></span>
<span id="cb7-18"><a href="#cb7-18" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">max</span><span class="kw">:</span><span class="at"> </span><span class="st">"0.1"</span></span>
<span id="cb7-19"><a href="#cb7-19" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> batch-size</span></span>
<span id="cb7-20"><a href="#cb7-20" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">parameterType</span><span class="kw">:</span><span class="at"> int</span></span>
<span id="cb7-21"><a href="#cb7-21" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">feasibleSpace</span><span class="kw">:</span></span>
<span id="cb7-22"><a href="#cb7-22" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">min</span><span class="kw">:</span><span class="at"> </span><span class="st">"16"</span></span>
<span id="cb7-23"><a href="#cb7-23" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">max</span><span class="kw">:</span><span class="at"> </span><span class="st">"128"</span></span>
<span id="cb7-24"><a href="#cb7-24" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> momentum</span></span>
<span id="cb7-25"><a href="#cb7-25" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">parameterType</span><span class="kw">:</span><span class="at"> double</span></span>
<span id="cb7-26"><a href="#cb7-26" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">feasibleSpace</span><span class="kw">:</span></span>
<span id="cb7-27"><a href="#cb7-27" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">min</span><span class="kw">:</span><span class="at"> </span><span class="st">"0.1"</span></span>
<span id="cb7-28"><a href="#cb7-28" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">max</span><span class="kw">:</span><span class="at"> </span><span class="st">"0.9"</span></span>
<span id="cb7-29"><a href="#cb7-29" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">trialTemplate</span><span class="kw">:</span></span>
<span id="cb7-30"><a href="#cb7-30" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">primaryContainerName</span><span class="kw">:</span><span class="at"> pytorch</span></span>
<span id="cb7-31"><a href="#cb7-31" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">trialParameters</span><span class="kw">:</span></span>
<span id="cb7-32"><a href="#cb7-32" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> learningRate</span></span>
<span id="cb7-33"><a href="#cb7-33" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">description</span><span class="kw">:</span><span class="at"> Learning rate for SGD</span></span>
<span id="cb7-34"><a href="#cb7-34" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">reference</span><span class="kw">:</span><span class="at"> lr</span></span>
<span id="cb7-35"><a href="#cb7-35" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> batchSize</span></span>
<span id="cb7-36"><a href="#cb7-36" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">description</span><span class="kw">:</span><span class="at"> Training batch size</span></span>
<span id="cb7-37"><a href="#cb7-37" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">reference</span><span class="kw">:</span><span class="at"> batch-size</span></span>
<span id="cb7-38"><a href="#cb7-38" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> momentum</span></span>
<span id="cb7-39"><a href="#cb7-39" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">description</span><span class="kw">:</span><span class="at"> SGD momentum</span></span>
<span id="cb7-40"><a href="#cb7-40" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">reference</span><span class="kw">:</span><span class="at"> momentum</span></span>
<span id="cb7-41"><a href="#cb7-41" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">trialSpec</span><span class="kw">:</span></span>
<span id="cb7-42"><a href="#cb7-42" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">apiVersion</span><span class="kw">:</span><span class="at"> kubeflow.org/v1</span></span>
<span id="cb7-43"><a href="#cb7-43" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">kind</span><span class="kw">:</span><span class="at"> PyTorchJob</span></span>
<span id="cb7-44"><a href="#cb7-44" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">spec</span><span class="kw">:</span></span>
<span id="cb7-45"><a href="#cb7-45" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">pytorchReplicaSpecs</span><span class="kw">:</span></span>
<span id="cb7-46"><a href="#cb7-46" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="fu">Master</span><span class="kw">:</span></span>
<span id="cb7-47"><a href="#cb7-47" aria-hidden="true" tabindex="-1"></a><span class="at">            </span><span class="fu">replicas</span><span class="kw">:</span><span class="at"> </span><span class="dv">1</span></span>
<span id="cb7-48"><a href="#cb7-48" aria-hidden="true" tabindex="-1"></a><span class="at">            </span><span class="fu">restartPolicy</span><span class="kw">:</span><span class="at"> OnFailure</span></span>
<span id="cb7-49"><a href="#cb7-49" aria-hidden="true" tabindex="-1"></a><span class="at">            </span><span class="fu">template</span><span class="kw">:</span></span>
<span id="cb7-50"><a href="#cb7-50" aria-hidden="true" tabindex="-1"></a><span class="at">              </span><span class="fu">spec</span><span class="kw">:</span></span>
<span id="cb7-51"><a href="#cb7-51" aria-hidden="true" tabindex="-1"></a><span class="at">                </span><span class="fu">containers</span><span class="kw">:</span></span>
<span id="cb7-52"><a href="#cb7-52" aria-hidden="true" tabindex="-1"></a><span class="at">                </span><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> pytorch</span></span>
<span id="cb7-53"><a href="#cb7-53" aria-hidden="true" tabindex="-1"></a><span class="at">                  </span><span class="fu">image</span><span class="kw">:</span><span class="at"> your-registry/pytorch-katib:latest</span></span>
<span id="cb7-54"><a href="#cb7-54" aria-hidden="true" tabindex="-1"></a><span class="at">                  </span><span class="fu">command</span><span class="kw">:</span></span>
<span id="cb7-55"><a href="#cb7-55" aria-hidden="true" tabindex="-1"></a><span class="at">                  </span><span class="kw">-</span><span class="at"> python</span></span>
<span id="cb7-56"><a href="#cb7-56" aria-hidden="true" tabindex="-1"></a><span class="at">                  </span><span class="kw">-</span><span class="at"> train_with_metrics.py</span></span>
<span id="cb7-57"><a href="#cb7-57" aria-hidden="true" tabindex="-1"></a><span class="at">                  </span><span class="fu">args</span><span class="kw">:</span></span>
<span id="cb7-58"><a href="#cb7-58" aria-hidden="true" tabindex="-1"></a><span class="at">                  </span><span class="kw">-</span><span class="at"> --lr=${trialParameters.learningRate}</span></span>
<span id="cb7-59"><a href="#cb7-59" aria-hidden="true" tabindex="-1"></a><span class="at">                  </span><span class="kw">-</span><span class="at"> --batch-size=${trialParameters.batchSize}</span></span>
<span id="cb7-60"><a href="#cb7-60" aria-hidden="true" tabindex="-1"></a><span class="at">                  </span><span class="kw">-</span><span class="at"> --momentum=${trialParameters.momentum}</span></span>
<span id="cb7-61"><a href="#cb7-61" aria-hidden="true" tabindex="-1"></a><span class="at">                  </span><span class="kw">-</span><span class="at"> --epochs=10</span></span>
<span id="cb7-62"><a href="#cb7-62" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">parallelTrialCount</span><span class="kw">:</span><span class="at"> </span><span class="dv">3</span></span>
<span id="cb7-63"><a href="#cb7-63" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">maxTrialCount</span><span class="kw">:</span><span class="at"> </span><span class="dv">20</span></span>
<span id="cb7-64"><a href="#cb7-64" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">maxFailedTrialCount</span><span class="kw">:</span><span class="at"> </span><span class="dv">3</span></span></code></pre></div></div>
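<p>With <code>algorithmName: random</code>, each trial is, conceptually, an independent draw from every parameter's <code>feasibleSpace</code>. The sketch below imitates that locally using the same bounds as the manifest. It illustrates the idea only and is not Katib's implementation, which runs as a separate suggestion service. <code>SEARCH_SPACE</code>, <code>sample_trial</code> and <code>sample_trials</code> are hypothetical names.</p>

```python
import random
from typing import Dict, List

# Bounds copied from katib-experiment.yaml: name -> (parameterType, min, max).
SEARCH_SPACE = {
    "lr": ("double", 0.001, 0.1),
    "batch-size": ("int", 16, 128),
    "momentum": ("double", 0.1, 0.9),
}


def sample_trial(rng: random.Random) -> Dict[str, float]:
    """Draw one assignment, sampling each parameter uniformly and independently."""
    trial: Dict[str, float] = {}
    for name, (kind, low, high) in SEARCH_SPACE.items():
        if kind == "int":
            trial[name] = rng.randint(low, high)  # both bounds inclusive
        else:
            trial[name] = rng.uniform(low, high)
    return trial


def sample_trials(n: int, seed: int = 0) -> List[Dict[str, float]]:
    """Generate n trials; a fixed seed makes the sequence reproducible."""
    rng = random.Random(seed)
    return [sample_trial(rng) for _ in range(n)]


if __name__ == "__main__":
    for trial in sample_trials(20):  # matches maxTrialCount: 20
        print(trial)
```

<p>A uniform draw for <code>lr</code> over [0.001, 0.1] places roughly 90% of samples above 0.01. For that reason, learning rates are often searched on a log scale instead, for example <code>10 ** rng.uniform(-3, -1)</code>. Check your Katib version's documentation before relying on any log-scale option in <code>feasibleSpace</code>.</p>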
</section>
<section id="training-script-with-metrics-for-katib" class="level3">
<h3 class="anchored" data-anchor-id="training-script-with-metrics-for-katib" id="training-script-with-metrics-for-katib">Training Script with Metrics for Katib</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb8"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="co"># train_with_metrics.py</span></span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.optim <span class="im">as</span> optim</span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.utils.data <span class="im">import</span> DataLoader</span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torchvision <span class="im">import</span> datasets, transforms</span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> argparse</span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> SimpleNet(nn.Module):</span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Same architecture as the KServe predictor, so the saved weights load there unchanged</span></span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, num_classes<span class="op">=</span><span class="dv">10</span>):</span>
<span id="cb8-12"><a href="#cb8-12" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb8-13"><a href="#cb8-13" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.features <span class="op">=</span> nn.Sequential(</span>
<span id="cb8-14"><a href="#cb8-14" aria-hidden="true" tabindex="-1"></a>            nn.Conv2d(<span class="dv">1</span>, <span class="dv">32</span>, <span class="dv">3</span>, <span class="dv">1</span>),</span>
<span id="cb8-15"><a href="#cb8-15" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(),</span>
<span id="cb8-16"><a href="#cb8-16" aria-hidden="true" tabindex="-1"></a>            nn.Conv2d(<span class="dv">32</span>, <span class="dv">64</span>, <span class="dv">3</span>, <span class="dv">1</span>),</span>
<span id="cb8-17"><a href="#cb8-17" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(),</span>
<span id="cb8-18"><a href="#cb8-18" aria-hidden="true" tabindex="-1"></a>            nn.MaxPool2d(<span class="dv">2</span>),</span>
<span id="cb8-19"><a href="#cb8-19" aria-hidden="true" tabindex="-1"></a>            nn.Dropout(<span class="fl">0.25</span>),</span>
<span id="cb8-20"><a href="#cb8-20" aria-hidden="true" tabindex="-1"></a>            nn.Flatten(),</span>
<span id="cb8-21"><a href="#cb8-21" aria-hidden="true" tabindex="-1"></a>            nn.Linear(<span class="dv">9216</span>, <span class="dv">128</span>),  <span class="co"># 64 channels x 12 x 12 for a 28x28 input</span></span>
<span id="cb8-22"><a href="#cb8-22" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(),</span>
<span id="cb8-23"><a href="#cb8-23" aria-hidden="true" tabindex="-1"></a>            nn.Dropout(<span class="fl">0.5</span>),</span>
<span id="cb8-24"><a href="#cb8-24" aria-hidden="true" tabindex="-1"></a>            nn.Linear(<span class="dv">128</span>, num_classes)</span>
<span id="cb8-25"><a href="#cb8-25" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb8-26"><a href="#cb8-26" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-27"><a href="#cb8-27" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb8-28"><a href="#cb8-28" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.features(x)</span>
<span id="cb8-29"><a href="#cb8-29" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-30"><a href="#cb8-30" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> main():</span>
<span id="cb8-31"><a href="#cb8-31" aria-hidden="true" tabindex="-1"></a>    parser <span class="op">=</span> argparse.ArgumentParser()</span>
<span id="cb8-32"><a href="#cb8-32" aria-hidden="true" tabindex="-1"></a>    parser.add_argument(<span class="st">'--lr'</span>, <span class="bu">type</span><span class="op">=</span><span class="bu">float</span>, default<span class="op">=</span><span class="fl">0.01</span>)</span>
<span id="cb8-33"><a href="#cb8-33" aria-hidden="true" tabindex="-1"></a>    parser.add_argument(<span class="st">'--batch-size'</span>, <span class="bu">type</span><span class="op">=</span><span class="bu">int</span>, default<span class="op">=</span><span class="dv">64</span>)</span>
<span id="cb8-34"><a href="#cb8-34" aria-hidden="true" tabindex="-1"></a>    parser.add_argument(<span class="st">'--momentum'</span>, <span class="bu">type</span><span class="op">=</span><span class="bu">float</span>, default<span class="op">=</span><span class="fl">0.5</span>)</span>
<span id="cb8-35"><a href="#cb8-35" aria-hidden="true" tabindex="-1"></a>    parser.add_argument(<span class="st">'--epochs'</span>, <span class="bu">type</span><span class="op">=</span><span class="bu">int</span>, default<span class="op">=</span><span class="dv">10</span>)</span>
<span id="cb8-36"><a href="#cb8-36" aria-hidden="true" tabindex="-1"></a>    args <span class="op">=</span> parser.parse_args()</span>
<span id="cb8-37"><a href="#cb8-37" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-38"><a href="#cb8-38" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Setup device</span></span>
<span id="cb8-39"><a href="#cb8-39" aria-hidden="true" tabindex="-1"></a>    device <span class="op">=</span> torch.device(<span class="st">"cuda"</span> <span class="cf">if</span> torch.cuda.is_available() <span class="cf">else</span> <span class="st">"cpu"</span>)</span>
<span id="cb8-40"><a href="#cb8-40" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-41"><a href="#cb8-41" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Data loading</span></span>
<span id="cb8-42"><a href="#cb8-42" aria-hidden="true" tabindex="-1"></a>    transform <span class="op">=</span> transforms.Compose([</span>
<span id="cb8-43"><a href="#cb8-43" aria-hidden="true" tabindex="-1"></a>        transforms.ToTensor(),</span>
<span id="cb8-44"><a href="#cb8-44" aria-hidden="true" tabindex="-1"></a>        transforms.Normalize((<span class="fl">0.1307</span>,), (<span class="fl">0.3081</span>,))</span>
<span id="cb8-45"><a href="#cb8-45" aria-hidden="true" tabindex="-1"></a>    ])</span>
<span id="cb8-46"><a href="#cb8-46" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-47"><a href="#cb8-47" aria-hidden="true" tabindex="-1"></a>    train_dataset <span class="op">=</span> datasets.MNIST(<span class="st">'/tmp/data'</span>, train<span class="op">=</span><span class="va">True</span>, download<span class="op">=</span><span class="va">True</span>, transform<span class="op">=</span>transform)</span>
<span id="cb8-48"><a href="#cb8-48" aria-hidden="true" tabindex="-1"></a>    test_dataset <span class="op">=</span> datasets.MNIST(<span class="st">'/tmp/data'</span>, train<span class="op">=</span><span class="va">False</span>, download<span class="op">=</span><span class="va">True</span>, transform<span class="op">=</span>transform)</span>
<span id="cb8-49"><a href="#cb8-49" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-50"><a href="#cb8-50" aria-hidden="true" tabindex="-1"></a>    train_loader <span class="op">=</span> DataLoader(train_dataset, batch_size<span class="op">=</span>args.batch_size, shuffle<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb8-51"><a href="#cb8-51" aria-hidden="true" tabindex="-1"></a>    test_loader <span class="op">=</span> DataLoader(test_dataset, batch_size<span class="op">=</span><span class="dv">1000</span>, shuffle<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb8-52"><a href="#cb8-52" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-53"><a href="#cb8-53" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Model</span></span>
<span id="cb8-54"><a href="#cb8-54" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> SimpleNet().to(device)</span>
<span id="cb8-55"><a href="#cb8-55" aria-hidden="true" tabindex="-1"></a>    criterion <span class="op">=</span> nn.CrossEntropyLoss()</span>
<span id="cb8-56"><a href="#cb8-56" aria-hidden="true" tabindex="-1"></a>    optimizer <span class="op">=</span> optim.SGD(model.parameters(), lr<span class="op">=</span>args.lr, momentum<span class="op">=</span>args.momentum)</span>
<span id="cb8-57"><a href="#cb8-57" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-58"><a href="#cb8-58" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Training</span></span>
<span id="cb8-59"><a href="#cb8-59" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(args.epochs):</span>
<span id="cb8-60"><a href="#cb8-60" aria-hidden="true" tabindex="-1"></a>        model.train()</span>
<span id="cb8-61"><a href="#cb8-61" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> batch_idx, (data, target) <span class="kw">in</span> <span class="bu">enumerate</span>(train_loader):</span>
<span id="cb8-62"><a href="#cb8-62" aria-hidden="true" tabindex="-1"></a>            data, target <span class="op">=</span> data.to(device), target.to(device)</span>
<span id="cb8-63"><a href="#cb8-63" aria-hidden="true" tabindex="-1"></a>            optimizer.zero_grad()</span>
<span id="cb8-64"><a href="#cb8-64" aria-hidden="true" tabindex="-1"></a>            output <span class="op">=</span> model(data)</span>
<span id="cb8-65"><a href="#cb8-65" aria-hidden="true" tabindex="-1"></a>            loss <span class="op">=</span> criterion(output, target)</span>
<span id="cb8-66"><a href="#cb8-66" aria-hidden="true" tabindex="-1"></a>            loss.backward()</span>
<span id="cb8-67"><a href="#cb8-67" aria-hidden="true" tabindex="-1"></a>            optimizer.step()</span>
<span id="cb8-68"><a href="#cb8-68" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-69"><a href="#cb8-69" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Evaluation: accuracy and mean cross-entropy loss on the test set</span></span>
<span id="cb8-70"><a href="#cb8-70" aria-hidden="true" tabindex="-1"></a>    model.<span class="bu">eval</span>()</span>
<span id="cb8-71"><a href="#cb8-71" aria-hidden="true" tabindex="-1"></a>    correct <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb8-72"><a href="#cb8-72" aria-hidden="true" tabindex="-1"></a>    total <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb8-73"><a href="#cb8-73" aria-hidden="true" tabindex="-1"></a>    test_loss <span class="op">=</span> <span class="fl">0.0</span></span>
<span id="cb8-74"><a href="#cb8-74" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> torch.no_grad():</span>
<span id="cb8-75"><a href="#cb8-75" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> data, target <span class="kw">in</span> test_loader:</span>
<span id="cb8-76"><a href="#cb8-76" aria-hidden="true" tabindex="-1"></a>            data, target <span class="op">=</span> data.to(device), target.to(device)</span>
<span id="cb8-77"><a href="#cb8-77" aria-hidden="true" tabindex="-1"></a>            outputs <span class="op">=</span> model(data)</span>
<span id="cb8-78"><a href="#cb8-78" aria-hidden="true" tabindex="-1"></a>            <span class="co"># criterion returns the batch mean, so weight it by the batch size</span></span>
<span id="cb8-79"><a href="#cb8-79" aria-hidden="true" tabindex="-1"></a>            test_loss <span class="op">+=</span> criterion(outputs, target).item() <span class="op">*</span> target.size(<span class="dv">0</span>)</span>
<span id="cb8-80"><a href="#cb8-80" aria-hidden="true" tabindex="-1"></a>            predicted <span class="op">=</span> outputs.argmax(dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb8-81"><a href="#cb8-81" aria-hidden="true" tabindex="-1"></a>            total <span class="op">+=</span> target.size(<span class="dv">0</span>)</span>
<span id="cb8-82"><a href="#cb8-82" aria-hidden="true" tabindex="-1"></a>            correct <span class="op">+=</span> (predicted <span class="op">==</span> target).<span class="bu">sum</span>().item()</span>
<span id="cb8-83"><a href="#cb8-83" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-84"><a href="#cb8-84" aria-hidden="true" tabindex="-1"></a>    accuracy <span class="op">=</span> correct <span class="op">/</span> total</span>
<span id="cb8-85"><a href="#cb8-85" aria-hidden="true" tabindex="-1"></a>    test_loss <span class="op">/=</span> total</span>
<span id="cb8-86"><a href="#cb8-86" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-87"><a href="#cb8-87" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Katib's default StdOut metrics collector parses "name=value" lines,</span></span>
<span id="cb8-88"><a href="#cb8-88" aria-hidden="true" tabindex="-1"></a>    <span class="co"># so print each metric on its own line in exactly this format</span></span>
<span id="cb8-89"><a href="#cb8-89" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"accuracy=</span><span class="sc">{</span>accuracy<span class="sc">:.4f}</span><span class="ss">"</span>)</span>
<span id="cb8-90"><a href="#cb8-90" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"loss=</span><span class="sc">{</span>test_loss<span class="sc">:.4f}</span><span class="ss">"</span>)</span>
<span id="cb8-91"><a href="#cb8-91" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-92"><a href="#cb8-92" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">'__main__'</span>:</span>
<span id="cb8-93"><a href="#cb8-93" aria-hidden="true" tabindex="-1"></a>    main()</span></code></pre></div></div>
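<p>Katib scrapes the trial's metrics from standard output, so the exact print format matters. The sketch below parses a log using a regular expression that, to our understanding, mirrors the default filter of Katib's StdOut metrics collector: <code>name=value</code> pairs with an optional sign and exponent. <code>METRIC_PATTERN</code> and <code>parse_metrics</code> are hypothetical helpers for checking a script's output locally. They are not Katib APIs.</p>

```python
import re
from typing import Dict, List

# Assumed to mirror Katib's default StdOut filter: "name=value" pairs.
METRIC_PATTERN = re.compile(r"([\w|-]+)\s*=\s*([+-]?\d*(\.\d+)?([Ee][+-]?\d+)?)")


def parse_metrics(stdout: str, names: List[str]) -> Dict[str, List[float]]:
    """Collect every reported value for the requested metric names, in order."""
    found: Dict[str, List[float]] = {name: [] for name in names}
    for line in stdout.splitlines():
        for match in METRIC_PATTERN.finditer(line):
            name, value = match.group(1), match.group(2)
            # The numeric group may match an empty string (e.g. "accuracy=n/a").
            if name in found and value not in ("", "+", "-"):
                found[name].append(float(value))
    return found


if __name__ == "__main__":
    log = "accuracy=0.9812\nloss=0.0587\nTest accuracy: 0.98\n"
    print(parse_metrics(log, ["accuracy", "loss"]))
```

<p>A line such as <code>Test accuracy: 0.98</code> does not match, so Katib would record no value from it. As far as we know, the collector also keeps only the metrics named in the experiment's <code>objective</code>: <code>objectiveMetricName</code> plus any <code>additionalMetricNames</code>. A printed <code>loss</code> is therefore stored only if it is listed there.</p>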
</section>
</section>
<section id="model-serving-with-kserve" class="level2">
<h2 class="anchored" data-anchor-id="model-serving-with-kserve" id="model-serving-with-kserve">Model Serving with KServe</h2>
<section id="create-a-model-server" class="level3">
<h3 class="anchored" data-anchor-id="create-a-model-server" id="create-a-model-server">Create a Model Server</h3>
<p>First, create a custom predictor (<code>predictor.py</code>):</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb9"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torchvision <span class="im">import</span> transforms</span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> kserve</span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> typing <span class="im">import</span> Dict</span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> PIL <span class="im">import</span> Image</span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> io</span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> base64</span>
<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-11"><a href="#cb9-11" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> SimpleNet(nn.Module):</span>
<span id="cb9-12"><a href="#cb9-12" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, num_classes<span class="op">=</span><span class="dv">10</span>):</span>
<span id="cb9-13"><a href="#cb9-13" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>(SimpleNet, <span class="va">self</span>).<span class="fu">__init__</span>()</span>
<span id="cb9-14"><a href="#cb9-14" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.features <span class="op">=</span> nn.Sequential(</span>
<span id="cb9-15"><a href="#cb9-15" aria-hidden="true" tabindex="-1"></a>            nn.Conv2d(<span class="dv">1</span>, <span class="dv">32</span>, <span class="dv">3</span>, <span class="dv">1</span>),</span>
<span id="cb9-16"><a href="#cb9-16" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(),</span>
<span id="cb9-17"><a href="#cb9-17" aria-hidden="true" tabindex="-1"></a>            nn.Conv2d(<span class="dv">32</span>, <span class="dv">64</span>, <span class="dv">3</span>, <span class="dv">1</span>),</span>
<span id="cb9-18"><a href="#cb9-18" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(),</span>
<span id="cb9-19"><a href="#cb9-19" aria-hidden="true" tabindex="-1"></a>            nn.MaxPool2d(<span class="dv">2</span>),</span>
<span id="cb9-20"><a href="#cb9-20" aria-hidden="true" tabindex="-1"></a>            nn.Dropout(<span class="fl">0.25</span>),</span>
<span id="cb9-21"><a href="#cb9-21" aria-hidden="true" tabindex="-1"></a>            nn.Flatten(),</span>
<span id="cb9-22"><a href="#cb9-22" aria-hidden="true" tabindex="-1"></a>            nn.Linear(<span class="dv">9216</span>, <span class="dv">128</span>),  <span class="co"># 64 channels x 12 x 12 for a 28x28 input</span></span>
<span id="cb9-23"><a href="#cb9-23" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(),</span>
<span id="cb9-24"><a href="#cb9-24" aria-hidden="true" tabindex="-1"></a>            nn.Dropout(<span class="fl">0.5</span>),</span>
<span id="cb9-25"><a href="#cb9-25" aria-hidden="true" tabindex="-1"></a>            nn.Linear(<span class="dv">128</span>, num_classes)</span>
<span id="cb9-26"><a href="#cb9-26" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb9-27"><a href="#cb9-27" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-28"><a href="#cb9-28" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb9-29"><a href="#cb9-29" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.features(x)</span>
<span id="cb9-30"><a href="#cb9-30" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-31"><a href="#cb9-31" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> PyTorchMNISTPredictor(kserve.Model):</span>
<span id="cb9-32"><a href="#cb9-32" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, name: <span class="bu">str</span>):</span>
<span id="cb9-33"><a href="#cb9-33" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>(name)</span>
<span id="cb9-34"><a href="#cb9-34" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.name <span class="op">=</span> name</span>
<span id="cb9-35"><a href="#cb9-35" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> <span class="va">None</span></span>
<span id="cb9-36"><a href="#cb9-36" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.transform <span class="op">=</span> transforms.Compose([</span>
<span id="cb9-37"><a href="#cb9-37" aria-hidden="true" tabindex="-1"></a>            transforms.ToTensor(),</span>
<span id="cb9-38"><a href="#cb9-38" aria-hidden="true" tabindex="-1"></a>            transforms.Normalize((<span class="fl">0.1307</span>,), (<span class="fl">0.3081</span>,))</span>
<span id="cb9-39"><a href="#cb9-39" aria-hidden="true" tabindex="-1"></a>        ])</span>
<span id="cb9-40"><a href="#cb9-40" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.ready <span class="op">=</span> <span class="va">False</span></span>
<span id="cb9-41"><a href="#cb9-41" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-42"><a href="#cb9-42" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> load(<span class="va">self</span>):</span>
<span id="cb9-43"><a href="#cb9-43" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> SimpleNet()</span>
<span id="cb9-44"><a href="#cb9-44" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model.load_state_dict(torch.load(<span class="st">'/mnt/models/model.pth'</span>, map_location<span class="op">=</span><span class="st">'cpu'</span>))</span>
<span id="cb9-45"><a href="#cb9-45" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model.<span class="bu">eval</span>()</span>
<span id="cb9-46"><a href="#cb9-46" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.ready <span class="op">=</span> <span class="va">True</span></span>
<span id="cb9-47"><a href="#cb9-47" aria-hidden="true" tabindex="-1"></a></span>
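<span id="cb9-48"><a href="#cb9-48" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> predict(<span class="va">self</span>, payload: Dict, headers: Dict[<span class="bu">str</span>, <span class="bu">str</span>] <span class="op">=</span> <span class="va">None</span>) <span class="op">-&gt;</span> Dict:</span>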
<span id="cb9-49"><a href="#cb9-49" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="kw">not</span> <span class="va">self</span>.ready:</span>
<span id="cb9-50"><a href="#cb9-50" aria-hidden="true" tabindex="-1"></a>            <span class="cf">raise</span> <span class="pp">RuntimeError</span>(<span class="st">"Model not loaded"</span>)</span>
<span id="cb9-51"><a href="#cb9-51" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-52"><a href="#cb9-52" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Decode base64 image</span></span>
<span id="cb9-53"><a href="#cb9-53" aria-hidden="true" tabindex="-1"></a>        image_data <span class="op">=</span> base64.b64decode(payload[<span class="st">"instances"</span>][<span class="dv">0</span>][<span class="st">"image"</span>])</span>
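<span id="cb9-54"><a href="#cb9-54" aria-hidden="true" tabindex="-1"></a>        image <span class="op">=</span> Image.<span class="bu">open</span>(io.BytesIO(image_data)).convert(<span class="st">'L'</span>).resize((<span class="dv">28</span>, <span class="dv">28</span>))  <span class="co"># model expects 28x28 grayscale</span></span>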
<span id="cb9-55"><a href="#cb9-55" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-56"><a href="#cb9-56" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Preprocess</span></span>
<span id="cb9-57"><a href="#cb9-57" aria-hidden="true" tabindex="-1"></a>        input_tensor <span class="op">=</span> <span class="va">self</span>.transform(image).unsqueeze(<span class="dv">0</span>)</span>
<span id="cb9-58"><a href="#cb9-58" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-59"><a href="#cb9-59" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Predict</span></span>
<span id="cb9-60"><a href="#cb9-60" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> torch.no_grad():</span>
<span id="cb9-61"><a href="#cb9-61" aria-hidden="true" tabindex="-1"></a>            output <span class="op">=</span> <span class="va">self</span>.model(input_tensor)</span>
<span id="cb9-62"><a href="#cb9-62" aria-hidden="true" tabindex="-1"></a>            probabilities <span class="op">=</span> torch.softmax(output, dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb9-63"><a href="#cb9-63" aria-hidden="true" tabindex="-1"></a>            predicted_class <span class="op">=</span> torch.argmax(probabilities, dim<span class="op">=</span><span class="dv">1</span>).item()</span>
<span id="cb9-64"><a href="#cb9-64" aria-hidden="true" tabindex="-1"></a>            confidence <span class="op">=</span> probabilities[<span class="dv">0</span>][predicted_class].item()</span>
<span id="cb9-65"><a href="#cb9-65" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-66"><a href="#cb9-66" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {</span>
<span id="cb9-67"><a href="#cb9-67" aria-hidden="true" tabindex="-1"></a>            <span class="st">"predictions"</span>: [{</span>
<span id="cb9-68"><a href="#cb9-68" aria-hidden="true" tabindex="-1"></a>                <span class="st">"class"</span>: predicted_class,</span>
<span id="cb9-69"><a href="#cb9-69" aria-hidden="true" tabindex="-1"></a>                <span class="st">"confidence"</span>: confidence,</span>
<span id="cb9-70"><a href="#cb9-70" aria-hidden="true" tabindex="-1"></a>                <span class="st">"probabilities"</span>: probabilities[<span class="dv">0</span>].tolist()</span>
<span id="cb9-71"><a href="#cb9-71" aria-hidden="true" tabindex="-1"></a>            }]</span>
<span id="cb9-72"><a href="#cb9-72" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb9-73"><a href="#cb9-73" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-74"><a href="#cb9-74" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb9-75"><a href="#cb9-75" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> PyTorchMNISTPredictor(<span class="st">"pytorch-mnist-predictor"</span>)</span>
<span id="cb9-76"><a href="#cb9-76" aria-hidden="true" tabindex="-1"></a>    model.load()</span>
<span id="cb9-77"><a href="#cb9-77" aria-hidden="true" tabindex="-1"></a>    kserve.ModelServer().start([model])</span></code></pre></div></div>
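<p>A client has to send the same JSON shape that <code>predict</code> decodes: an <code>"instances"</code> list whose first entry holds a base64-encoded image under <code>"image"</code>. The sketch below builds such a request using only the standard library. It assumes the KServe v1 protocol route <code>/v1/models/&lt;name&gt;:predict</code>. The host, port, and helper names are placeholders. The request is constructed but not sent.</p>

```python
import base64
import json
import urllib.request


def build_predict_payload(image_bytes: bytes) -> dict:
    """Wrap raw image bytes in the {"instances": [{"image": ...}]} shape the predictor reads."""
    encoded = base64.b64encode(image_bytes).decode("ascii")
    return {"instances": [{"image": encoded}]}


def build_predict_request(image_bytes: bytes, host: str, model_name: str) -> urllib.request.Request:
    """Build (but do not send) a POST to the KServe v1 predict route for model_name."""
    url = f"http://{host}/v1/models/{model_name}:predict"
    body = json.dumps(build_predict_payload(image_bytes)).encode("utf-8")
    return urllib.request.Request(
        url,
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )
```

<p>To call a running server, read a digit image from disk and pass the request to <code>urllib.request.urlopen</code>. The response body is the JSON dictionary returned by <code>predict</code>.</p>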
</section>
<section id="inferenceservice-yaml" class="level3">
<h3 class="anchored" data-anchor-id="inferenceservice-yaml" id="inferenceservice-yaml">InferenceService YAML</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb10"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="fu">apiVersion</span><span class="kw">:</span><span class="at"> serving.kserve.io/v1beta1</span></span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a><span class="fu">kind</span><span class="kw">:</span><span class="at"> InferenceService</span></span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a><span class="fu">metadata</span><span class="kw">:</span></span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">name</span><span class="kw">:</span><span class="at"> pytorch-mnist-predictor</span></span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">namespace</span><span class="kw">:</span><span class="at"> pytorch-training</span></span>
<span id="cb10-6"><a href="#cb10-6" aria-hidden="true" tabindex="-1"></a><span class="fu">spec</span><span class="kw">:</span></span>
<span id="cb10-7"><a href="#cb10-7" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">predictor</span><span class="kw">:</span></span>
<span id="cb10-8"><a href="#cb10-8" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">containers</span><span class="kw">:</span></span>
<span id="cb10-9"><a href="#cb10-9" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> kserve-container</span></span>
<span id="cb10-10"><a href="#cb10-10" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">image</span><span class="kw">:</span><span class="at"> your-registry/pytorch-predictor:latest</span></span>
<span id="cb10-11"><a href="#cb10-11" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">ports</span><span class="kw">:</span></span>
<span id="cb10-12"><a href="#cb10-12" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="fu">containerPort</span><span class="kw">:</span><span class="at"> </span><span class="dv">8080</span></span>
<span id="cb10-13"><a href="#cb10-13" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">protocol</span><span class="kw">:</span><span class="at"> TCP</span></span>
<span id="cb10-14"><a href="#cb10-14" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">volumeMounts</span><span class="kw">:</span></span>
<span id="cb10-15"><a href="#cb10-15" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> model-storage</span></span>
<span id="cb10-16"><a href="#cb10-16" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">mountPath</span><span class="kw">:</span><span class="at"> /mnt/models</span></span>
<span id="cb10-17"><a href="#cb10-17" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">resources</span><span class="kw">:</span></span>
<span id="cb10-18"><a href="#cb10-18" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">requests</span><span class="kw">:</span></span>
<span id="cb10-19"><a href="#cb10-19" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="fu">cpu</span><span class="kw">:</span><span class="at"> </span><span class="st">"100m"</span></span>
<span id="cb10-20"><a href="#cb10-20" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="fu">memory</span><span class="kw">:</span><span class="at"> </span><span class="st">"1Gi"</span></span>
<span id="cb10-21"><a href="#cb10-21" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">limits</span><span class="kw">:</span></span>
<span id="cb10-22"><a href="#cb10-22" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="fu">cpu</span><span class="kw">:</span><span class="at"> </span><span class="st">"1"</span></span>
<span id="cb10-23"><a href="#cb10-23" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="fu">memory</span><span class="kw">:</span><span class="at"> </span><span class="st">"2Gi"</span></span>
<span id="cb10-24"><a href="#cb10-24" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">volumes</span><span class="kw">:</span></span>
<span id="cb10-25"><a href="#cb10-25" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> model-storage</span></span>
<span id="cb10-26"><a href="#cb10-26" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">persistentVolumeClaim</span><span class="kw">:</span></span>
<span id="cb10-27"><a href="#cb10-27" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">claimName</span><span class="kw">:</span><span class="at"> model-pvc</span></span></code></pre></div></div>
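<p>Kubernetes rejects a container whose resource request exceeds its limit. CPU and memory quantities use suffixes: <code>100m</code> is 0.1 CPU and <code>1Gi</code> is 2<sup>30</sup> bytes. The sketch below is a hypothetical pre-flight helper, not part of KServe. It parses a subset of these suffixes and checks the <code>resources</code> stanza above before the manifest is applied. It does not cover every quantity form Kubernetes accepts.</p>

```python
from decimal import Decimal

# Subset of Kubernetes quantity suffixes: binary (Ki..Ti), decimal (k..T), and milli (m).
_SUFFIXES = {
    "Ki": Decimal(2) ** 10,
    "Mi": Decimal(2) ** 20,
    "Gi": Decimal(2) ** 30,
    "Ti": Decimal(2) ** 40,
    "k": Decimal(10) ** 3,
    "M": Decimal(10) ** 6,
    "G": Decimal(10) ** 9,
    "T": Decimal(10) ** 12,
    "m": Decimal("0.001"),
}


def parse_quantity(q: str) -> Decimal:
    """Convert a quantity string such as '100m' or '1Gi' to a plain number."""
    for suffix, factor in _SUFFIXES.items():
        if q.endswith(suffix):
            return Decimal(q[: -len(suffix)]) * factor
    return Decimal(q)  # no suffix: a bare number such as "1"


def check_requests_within_limits(resources: dict) -> list:
    """Return the names of resources whose request is greater than their limit."""
    requests = resources.get("requests", {})
    limits = resources.get("limits", {})
    violations = []
    for name, requested in requests.items():
        if name in limits and parse_quantity(requested) > parse_quantity(limits[name]):
            violations.append(name)
    return violations


# Same values as the InferenceService manifest above.
resources = {
    "requests": {"cpu": "100m", "memory": "1Gi"},
    "limits": {"cpu": "1", "memory": "2Gi"},
}
print(check_requests_within_limits(resources))  # → []
```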
</section>
</section>
<section id="complete-pipeline-example" class="level2">
<h2 class="anchored" data-anchor-id="complete-pipeline-example" id="complete-pipeline-example">Complete Pipeline Example</h2>
<section id="kubeflow-pipeline-with-pytorch" class="level3">
<h3 class="anchored" data-anchor-id="kubeflow-pipeline-with-pytorch" id="kubeflow-pipeline-with-pytorch">Kubeflow Pipeline with PyTorch</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb11"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> kfp</span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> kfp <span class="im">import</span> dsl</span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a><span class="co"># dsl.ContainerOp belongs to the KFP v1 SDK; it was removed in KFP v2</span></span>
<span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> preprocess_data_op():</span>
<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> dsl.ContainerOp(</span>
<span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a>        name<span class="op">=</span><span class="st">'preprocess-data'</span>,</span>
<span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a>        image<span class="op">=</span><span class="st">'your-registry/data-preprocessing:latest'</span>,</span>
<span id="cb11-9"><a href="#cb11-9" aria-hidden="true" tabindex="-1"></a>        command<span class="op">=</span>[<span class="st">'python'</span>, <span class="st">'preprocess.py'</span>],</span>
<span id="cb11-10"><a href="#cb11-10" aria-hidden="true" tabindex="-1"></a>        file_outputs<span class="op">=</span>{<span class="st">'dataset_path'</span>: <span class="st">'/tmp/dataset_path.txt'</span>}</span>
<span id="cb11-11"><a href="#cb11-11" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb11-12"><a href="#cb11-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-13"><a href="#cb11-13" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> train_model_op(dataset_path, lr: <span class="bu">float</span> <span class="op">=</span> <span class="fl">0.01</span>, batch_size: <span class="bu">int</span> <span class="op">=</span> <span class="dv">64</span>, epochs: <span class="bu">int</span> <span class="op">=</span> <span class="dv">10</span>):</span>
<span id="cb11-14"><a href="#cb11-14" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> dsl.ContainerOp(</span>
<span id="cb11-15"><a href="#cb11-15" aria-hidden="true" tabindex="-1"></a>        name<span class="op">=</span><span class="st">'train-model'</span>,</span>
<span id="cb11-16"><a href="#cb11-16" aria-hidden="true" tabindex="-1"></a>        image<span class="op">=</span><span class="st">'your-registry/pytorch-training:latest'</span>,</span>
<span id="cb11-17"><a href="#cb11-17" aria-hidden="true" tabindex="-1"></a>        command<span class="op">=</span>[<span class="st">'python'</span>, <span class="st">'train.py'</span>],</span>
<span id="cb11-18"><a href="#cb11-18" aria-hidden="true" tabindex="-1"></a>        arguments<span class="op">=</span>[</span>
<span id="cb11-19"><a href="#cb11-19" aria-hidden="true" tabindex="-1"></a>            <span class="st">'--data-path'</span>, dataset_path,</span>
<span id="cb11-20"><a href="#cb11-20" aria-hidden="true" tabindex="-1"></a>            <span class="st">'--lr'</span>, lr,</span>
<span id="cb11-21"><a href="#cb11-21" aria-hidden="true" tabindex="-1"></a>            <span class="st">'--batch-size'</span>, batch_size, <span class="st">'--epochs'</span>, epochs,</span>
<span id="cb11-22"><a href="#cb11-22" aria-hidden="true" tabindex="-1"></a>            <span class="st">'--model-dir'</span>, <span class="st">'/tmp/model'</span></span>
<span id="cb11-23"><a href="#cb11-23" aria-hidden="true" tabindex="-1"></a>        ],</span>
<span id="cb11-24"><a href="#cb11-24" aria-hidden="true" tabindex="-1"></a>        file_outputs<span class="op">=</span>{<span class="st">'model_path'</span>: <span class="st">'/tmp/model_path.txt'</span>}</span>
<span id="cb11-25"><a href="#cb11-25" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb11-26"><a href="#cb11-26" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-27"><a href="#cb11-27" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> evaluate_model_op(model_path, dataset_path):</span>
<span id="cb11-28"><a href="#cb11-28" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> dsl.ContainerOp(</span>
<span id="cb11-29"><a href="#cb11-29" aria-hidden="true" tabindex="-1"></a>        name<span class="op">=</span><span class="st">'evaluate-model'</span>,</span>
<span id="cb11-30"><a href="#cb11-30" aria-hidden="true" tabindex="-1"></a>        image<span class="op">=</span><span class="st">'your-registry/pytorch-evaluation:latest'</span>,</span>
<span id="cb11-31"><a href="#cb11-31" aria-hidden="true" tabindex="-1"></a>        command<span class="op">=</span>[<span class="st">'python'</span>, <span class="st">'evaluate.py'</span>],</span>
<span id="cb11-32"><a href="#cb11-32" aria-hidden="true" tabindex="-1"></a>        arguments<span class="op">=</span>[</span>
<span id="cb11-33"><a href="#cb11-33" aria-hidden="true" tabindex="-1"></a>            <span class="st">'--model-path'</span>, model_path,</span>
<span id="cb11-34"><a href="#cb11-34" aria-hidden="true" tabindex="-1"></a>            <span class="st">'--data-path'</span>, dataset_path</span>
<span id="cb11-35"><a href="#cb11-35" aria-hidden="true" tabindex="-1"></a>        ],</span>
<span id="cb11-36"><a href="#cb11-36" aria-hidden="true" tabindex="-1"></a>        file_outputs<span class="op">=</span>{<span class="st">'metrics'</span>: <span class="st">'/tmp/metrics.json'</span>, <span class="st">'accuracy'</span>: <span class="st">'/tmp/accuracy.txt'</span>}  <span class="co"># evaluate.py is assumed to also write the bare accuracy value</span></span>
<span id="cb11-37"><a href="#cb11-37" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb11-38"><a href="#cb11-38" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-39"><a href="#cb11-39" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> deploy_model_op(model_path):</span>
<span id="cb11-40"><a href="#cb11-40" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> dsl.ContainerOp(</span>
<span id="cb11-41"><a href="#cb11-41" aria-hidden="true" tabindex="-1"></a>        name<span class="op">=</span><span class="st">'deploy-model'</span>,</span>
<span id="cb11-42"><a href="#cb11-42" aria-hidden="true" tabindex="-1"></a>        image<span class="op">=</span><span class="st">'your-registry/model-deployment:latest'</span>,</span>
<span id="cb11-43"><a href="#cb11-43" aria-hidden="true" tabindex="-1"></a>        command<span class="op">=</span>[<span class="st">'python'</span>, <span class="st">'deploy.py'</span>],</span>
<span id="cb11-44"><a href="#cb11-44" aria-hidden="true" tabindex="-1"></a>        arguments<span class="op">=</span>[<span class="st">'--model-path'</span>, model_path]</span>
<span id="cb11-45"><a href="#cb11-45" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb11-46"><a href="#cb11-46" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-47"><a href="#cb11-47" aria-hidden="true" tabindex="-1"></a><span class="at">@dsl.pipeline</span>(</span>
<span id="cb11-48"><a href="#cb11-48" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">'PyTorch Training Pipeline'</span>,</span>
<span id="cb11-49"><a href="#cb11-49" aria-hidden="true" tabindex="-1"></a>    description<span class="op">=</span><span class="st">'Complete PyTorch training and deployment pipeline'</span></span>
<span id="cb11-50"><a href="#cb11-50" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb11-51"><a href="#cb11-51" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> pytorch_training_pipeline(</span>
<span id="cb11-52"><a href="#cb11-52" aria-hidden="true" tabindex="-1"></a>    lr: <span class="bu">float</span> <span class="op">=</span> <span class="fl">0.01</span>,</span>
<span id="cb11-53"><a href="#cb11-53" aria-hidden="true" tabindex="-1"></a>    batch_size: <span class="bu">int</span> <span class="op">=</span> <span class="dv">64</span>,</span>
<span id="cb11-54"><a href="#cb11-54" aria-hidden="true" tabindex="-1"></a>    epochs: <span class="bu">int</span> <span class="op">=</span> <span class="dv">10</span></span>
<span id="cb11-55"><a href="#cb11-55" aria-hidden="true" tabindex="-1"></a>):</span>
<span id="cb11-56"><a href="#cb11-56" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Data preprocessing</span></span>
<span id="cb11-57"><a href="#cb11-57" aria-hidden="true" tabindex="-1"></a>    preprocess_task <span class="op">=</span> preprocess_data_op()</span>
<span id="cb11-58"><a href="#cb11-58" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-59"><a href="#cb11-59" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Model training</span></span>
<span id="cb11-60"><a href="#cb11-60" aria-hidden="true" tabindex="-1"></a>    train_task <span class="op">=</span> train_model_op(</span>
<span id="cb11-61"><a href="#cb11-61" aria-hidden="true" tabindex="-1"></a>        dataset_path<span class="op">=</span>preprocess_task.outputs[<span class="st">'dataset_path'</span>],</span>
<span id="cb11-62"><a href="#cb11-62" aria-hidden="true" tabindex="-1"></a>        lr<span class="op">=</span>lr,</span>
<span id="cb11-63"><a href="#cb11-63" aria-hidden="true" tabindex="-1"></a>        batch_size<span class="op">=</span>batch_size, epochs<span class="op">=</span>epochs</span>
<span id="cb11-64"><a href="#cb11-64" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb11-65"><a href="#cb11-65" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-66"><a href="#cb11-66" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Model evaluation</span></span>
<span id="cb11-67"><a href="#cb11-67" aria-hidden="true" tabindex="-1"></a>    evaluate_task <span class="op">=</span> evaluate_model_op(</span>
<span id="cb11-68"><a href="#cb11-68" aria-hidden="true" tabindex="-1"></a>        model_path<span class="op">=</span>train_task.outputs[<span class="st">'model_path'</span>],</span>
<span id="cb11-69"><a href="#cb11-69" aria-hidden="true" tabindex="-1"></a>        dataset_path<span class="op">=</span>preprocess_task.outputs[<span class="st">'dataset_path'</span>]</span>
<span id="cb11-70"><a href="#cb11-70" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb11-71"><a href="#cb11-71" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-72"><a href="#cb11-72" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Deploy only if the scalar accuracy output exceeds 0.9</span></span>
<span id="cb11-73"><a href="#cb11-73" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> dsl.Condition(evaluate_task.outputs[<span class="st">'accuracy'</span>] <span class="op">&gt;</span> <span class="fl">0.9</span>):</span>
<span id="cb11-74"><a href="#cb11-74" aria-hidden="true" tabindex="-1"></a>        deploy_task <span class="op">=</span> deploy_model_op(</span>
<span id="cb11-75"><a href="#cb11-75" aria-hidden="true" tabindex="-1"></a>            model_path<span class="op">=</span>train_task.outputs[<span class="st">'model_path'</span>]</span>
<span id="cb11-76"><a href="#cb11-76" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb11-77"><a href="#cb11-77" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-78"><a href="#cb11-78" aria-hidden="true" tabindex="-1"></a><span class="co"># Compile the pipeline into a package that can be uploaded to Kubeflow Pipelines and run</span></span>
<span id="cb11-79"><a href="#cb11-79" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">'__main__'</span>:</span>
<span id="cb11-80"><a href="#cb11-80" aria-hidden="true" tabindex="-1"></a>    kfp.compiler.Compiler().<span class="bu">compile</span>(pytorch_training_pipeline, <span class="st">'pytorch_pipeline.yaml'</span>)</span></code></pre></div></div>
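<p>The deployment gate compares a single output value against a threshold. That comparison is simplest when the evaluation step writes the accuracy as a bare number, not only inside a JSON document. Below is a minimal sketch of the output-writing end of a hypothetical <code>evaluate.py</code>. It keeps the <code>/tmp/metrics.json</code> path used by the pipeline. The extra <code>accuracy.txt</code> file and the helper name are assumptions for illustration.</p>

```python
import json
import os


def write_evaluation_outputs(accuracy: float, out_dir: str = "/tmp") -> dict:
    """Write evaluation results where downstream pipeline steps expect them.

    metrics.json keeps the full report; accuracy.txt holds only the number,
    so a pipeline condition can compare it directly against a threshold.
    """
    metrics = {"accuracy": accuracy}
    metrics_path = os.path.join(out_dir, "metrics.json")
    accuracy_path = os.path.join(out_dir, "accuracy.txt")
    with open(metrics_path, "w") as f:
        json.dump(metrics, f)
    with open(accuracy_path, "w") as f:
        f.write(f"{accuracy:.4f}")
    return {"metrics": metrics_path, "accuracy": accuracy_path}
```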
</section>
</section>
<section id="best-practices" class="level2">
<h2 class="anchored" data-anchor-id="best-practices" id="best-practices">Best Practices</h2>
<section id="resource-management" class="level3">
<h3 class="anchored" data-anchor-id="resource-management" id="resource-management">1. Resource Management</h3>
<ul>
<li>Always specify resource requests and limits</li>
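<li>Request GPUs explicitly (for example, <code>nvidia.com/gpu</code> in the container's limits) so pods are scheduled only onto nodes that can provide them</li>
<li>Clean up finished jobs and their pods so they stop holding cluster resources</li>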
</ul>
</section>
<section id="data-management" class="level3">
<h3 class="anchored" data-anchor-id="data-management" id="data-management">2. Data Management</h3>
<ul>
<li>Use persistent volumes for model storage</li>
<li>Implement data versioning</li>
<li>Use distributed storage for large datasets</li>
</ul>
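<p>One lightweight way to version data is to derive an identifier from the dataset's contents, so that any change to any file yields a new version. The sketch below hashes each relative path, file size, and file bytes with SHA-256 in a deterministic order. The function name is illustrative. Dedicated tools such as DVC add remote storage and lineage on top of this idea.</p>

```python
import hashlib
import os


def dataset_fingerprint(root: str, chunk_size: int = 1024 * 1024) -> str:
    """Return a SHA-256 hex digest covering every file path and its contents under root."""
    digest = hashlib.sha256()
    for dirpath, dirnames, filenames in os.walk(root):
        # Sort in place so traversal order does not depend on the filesystem.
        dirnames.sort()
        for name in sorted(filenames):
            path = os.path.join(dirpath, name)
            rel = os.path.relpath(path, root).replace(os.sep, "/")
            # Null-separated path and size keep file boundaries unambiguous.
            digest.update(rel.encode("utf-8") + b"\0")
            digest.update(str(os.path.getsize(path)).encode("ascii") + b"\0")
            with open(path, "rb") as f:
                for chunk in iter(lambda: f.read(chunk_size), b""):
                    digest.update(chunk)
    return digest.hexdigest()
```

<p>The resulting digest can be logged alongside each training run, so a model can always be traced back to the exact data it saw.</p>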
</section>
<section id="monitoring-and-logging" class="level3">
<h3 class="anchored" data-anchor-id="monitoring-and-logging" id="monitoring-and-logging">3. Monitoring and Logging</h3>
<ul>
<li>Implement comprehensive logging</li>
<li>Use metrics collection for model performance</li>
<li>Set up alerts for training failures</li>
</ul>
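<p>Cluster log collectors can index logs far more easily when each line is one JSON object with consistent fields. A minimal sketch using Python's standard <code>logging</code> module follows. The field names and the <code>extra={"metrics": ...}</code> convention are assumptions chosen for illustration.</p>

```python
import json
import logging


class JsonFormatter(logging.Formatter):
    """Render each log record as a single-line JSON object."""

    def format(self, record: logging.LogRecord) -> str:
        entry = {
            "time": self.formatTime(record),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
        }
        # Metrics passed via extra={"metrics": {...}} become a structured field.
        metrics = getattr(record, "metrics", None)
        if metrics is not None:
            entry["metrics"] = metrics
        if record.exc_info:
            entry["exception"] = self.formatException(record.exc_info)
        return json.dumps(entry)


def get_logger(name: str = "training") -> logging.Logger:
    """Return a logger that writes JSON lines to stderr, which the container runtime captures."""
    logger = logging.getLogger(name)
    if not logger.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(JsonFormatter())
        logger.addHandler(handler)
        logger.setLevel(logging.INFO)
    return logger


log = get_logger()
log.info("epoch finished", extra={"metrics": {"epoch": 1, "loss": 0.42}})
```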
</section>
<section id="security" class="level3">
<h3 class="anchored" data-anchor-id="security" id="security">4. Security</h3>
<ul>
<li>Use proper RBAC configurations</li>
<li>Build on minimal, trusted base images and scan container images for vulnerabilities</li>
<li>Implement secrets management for sensitive data</li>
</ul>
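<p>Credentials should reach the training or serving container through a Kubernetes Secret rather than being baked into the image or the code. The Secret can be exposed as an environment variable or as files in a mounted volume. The sketch below accepts either form. The <code>/etc/secrets</code> mount directory, the <code>SECRETS_DIR</code> override, and the secret names are hypothetical. They must match how the Secret is wired into the pod spec.</p>

```python
import os
from pathlib import Path
from typing import Optional

# Hypothetical mount point; set it to the volumeMounts path used for the Secret volume.
SECRETS_DIR = Path(os.environ.get("SECRETS_DIR", "/etc/secrets"))


def read_secret(name: str, secrets_dir: Optional[Path] = None) -> str:
    """Return a secret from a mounted file if present, else from an environment variable."""
    directory = secrets_dir if secrets_dir is not None else SECRETS_DIR
    path = directory / name
    if path.is_file():
        # A Secret volume exposes each key as a file; strip surrounding whitespace such as a trailing newline.
        return path.read_text().strip()
    # Fall back to an env var, e.g. "db-password" -> DB_PASSWORD (set via secretKeyRef).
    env_name = name.upper().replace("-", "_")
    value = os.environ.get(env_name)
    if value is None:
        raise KeyError(f"secret {name!r} not found in {directory} or ${env_name}")
    return value
```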
</section>
<section id="scalability" class="level3">
<h3 class="anchored" data-anchor-id="scalability" id="scalability">5. Scalability</h3>
<ul>
<li>Design for horizontal scaling</li>
<li>Use distributed training for large models</li>
<li>Implement efficient data loading pipelines</li>
</ul>
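<p>In data-parallel training, each worker should process its own slice of the dataset in every epoch. The sketch below shows the usual strided split by rank, with three properties:</p>
<ul>
<li>The shuffle is seeded by epoch and shared by all workers, so every rank computes the same permutation.</li>
<li>The index list is padded by cycling, so every rank gets the same number of samples.</li>
<li>Each rank takes every <code>world_size</code>-th index starting at its own rank.</li>
</ul>
<p>It mirrors the idea behind PyTorch's <code>DistributedSampler</code>, but it is a standalone illustration, not that class's implementation.</p>

```python
import math
import random


def shard_indices(dataset_size: int, rank: int, world_size: int, epoch: int, seed: int = 0) -> list:
    """Return the sample indices this rank should process in the given epoch."""
    if rank not in range(world_size):
        raise ValueError("rank must be in [0, world_size)")
    if dataset_size == 0:
        return []
    indices = list(range(dataset_size))
    # Same seed on every worker, so all ranks agree on the permutation.
    random.Random(seed + epoch).shuffle(indices)
    per_rank = math.ceil(dataset_size / world_size)
    total = per_rank * world_size
    # Pad by cycling through the shuffled list so the total divides evenly across ranks.
    indices = (indices * math.ceil(total / dataset_size))[:total]
    # Strided split: rank r takes positions r, r + world_size, r + 2 * world_size, ...
    return indices[rank::world_size]
```

<p>Changing <code>epoch</code> reshuffles the data while keeping all ranks in agreement. This is why a training loop typically advances the epoch on its sampler before each pass.</p>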
</section>
<section id="model-versioning" class="level3">
<h3 class="anchored" data-anchor-id="model-versioning" id="model-versioning">6. Model Versioning</h3>
<ul>
<li>Tag and version your models</li>
<li>Implement A/B testing for model deployments</li>
<li>Use model registries for tracking</li>
</ul>
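<p>Platform-level canary rollouts, such as KServe's <code>canaryTrafficPercent</code>, split traffic between model revisions. An A/B test often also needs each user to see the same variant on every request. One common approach is to route in the application by hashing a stable user ID. The sketch below does that. The function name, the experiment key, and the two variant labels are illustrative assumptions.</p>

```python
import hashlib


def assign_variant(user_id: str, experiment: str, treatment_percent: int) -> str:
    """Deterministically assign a user to 'control' or 'treatment'.

    Hashing experiment + user_id keeps each user's assignment stable across
    requests while giving independent splits for different experiments.
    """
    if treatment_percent not in range(101):
        raise ValueError("treatment_percent must be between 0 and 100")
    key = f"{experiment}:{user_id}".encode("utf-8")
    # Map the first 8 bytes of the hash to a bucket in [0, 100).
    bucket = int.from_bytes(hashlib.sha256(key).digest()[:8], "big") % 100
    return "control" if bucket >= treatment_percent else "treatment"
```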
</section>
<section id="error-handling" class="level3">
<h3 class="anchored" data-anchor-id="error-handling" id="error-handling">7. Error Handling</h3>
<ul>
<li>Implement robust error handling in training scripts</li>
<li>Use restart policies suited to the workload (for example, <code>OnFailure</code> for training replicas)</li>
<li>Set up proper monitoring and alerting</li>
</ul>
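<p>A restart policy only helps if the training script can pick up where it left off. The sketch below shows checkpoint-and-resume, using a JSON file as a stand-in for a real checkpoint. With PyTorch you would save the model and optimizer <code>state_dict</code>s with <code>torch.save</code> instead. The checkpoint is written to a temporary file and then moved into place with <code>os.replace</code>, so a crash mid-write leaves the previous checkpoint intact. The <code>fail_at_epoch</code> parameter exists only to simulate a failure.</p>

```python
import json
import os


def load_checkpoint(path: str) -> dict:
    """Return saved state, or a fresh state if no checkpoint exists yet."""
    if os.path.exists(path):
        with open(path) as f:
            return json.load(f)
    return {"next_epoch": 0, "history": []}


def save_checkpoint(state: dict, path: str) -> None:
    """Write to a temp file, then atomically replace the old checkpoint."""
    tmp_path = path + ".tmp"
    with open(tmp_path, "w") as f:
        json.dump(state, f)
    os.replace(tmp_path, path)


def train(num_epochs: int, ckpt_path: str, fail_at_epoch=None) -> list:
    """Run the remaining epochs, resuming from the last checkpoint if one exists."""
    state = load_checkpoint(ckpt_path)
    for epoch in range(state["next_epoch"], num_epochs):
        if epoch == fail_at_epoch:
            raise RuntimeError(f"simulated failure at epoch {epoch}")
        # ... one epoch of real training would run here ...
        state["history"].append(epoch)
        state["next_epoch"] = epoch + 1
        save_checkpoint(state, ckpt_path)
    return state["history"]
```

<p>If the first attempt fails at epoch 3, a restarted container calling <code>train</code> with the same checkpoint path runs only epochs 3 onward instead of starting over.</p>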
<p>This guide provides a comprehensive foundation for using Kubeflow with PyTorch for deep learning workflows. Adapt the examples to your specific use cases and requirements.</p>


</section>
</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[MLflow for PyTorch - Complete Guide]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/model-training/mlflow-pytorch/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/model-training/mlflow-pytorch/</guid>
      <pubDate>Fri, 30 May 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>mlops</category>
      <category>beginner</category>
      <content:encoded><![CDATA[






<section id="mlflow-for-pytorch---complete-guide" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/model-training/mlflow-pytorch/mlflow.jpg" class="img-fluid"></p>
<p>MLflow is an open-source platform for managing the machine learning lifecycle, including experiment tracking, reproducibility, deployment, and a model registry. This guide covers how to integrate MLflow with PyTorch for end-to-end ML workflow management.</p>
<h2 class="anchored" data-anchor-id="installation-and-setup" id="installation-and-setup">Installation and Setup</h2>
<section id="install-mlflow" class="level3">
<h3 class="anchored" data-anchor-id="install-mlflow" id="install-mlflow">Install MLflow</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install mlflow</span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install torch torchvision</span></code></pre></div></div>
</section>
<section id="start-mlflow-ui" class="level3">
<h3 class="anchored" data-anchor-id="start-mlflow-ui" id="start-mlflow-ui">Start MLflow UI</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb2"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="ex">mlflow</span> ui</span></code></pre></div></div>
<p>This starts the MLflow UI at <code>http://localhost:5000</code> by default.</p>
</section>
<section id="basic-configuration" class="level3">
<h3 class="anchored" data-anchor-id="basic-configuration" id="basic-configuration">Basic Configuration</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> mlflow</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> mlflow.pytorch</span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.optim <span class="im">as</span> optim</span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.utils.data <span class="im">import</span> DataLoader</span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a><span class="co"># Log to the server at localhost:5000 (optional; without this, runs are stored locally)</span></span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a>mlflow.set_tracking_uri(<span class="st">"http://localhost:5000"</span>)</span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a><span class="co"># Set experiment name</span></span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a>mlflow.set_experiment(<span class="st">"pytorch_experiments"</span>)</span></code></pre></div></div>
</section>
<section id="basic-mlflow-concepts" class="level2">
<h2 class="anchored" data-anchor-id="basic-mlflow-concepts" id="basic-mlflow-concepts">Basic MLflow Concepts</h2>
<ul>
<li><strong>Experiment</strong>: A collection of runs for a particular task</li>
<li><strong>Run</strong>: A single execution of your ML code</li>
<li><strong>Artifact</strong>: Files generated during a run (models, plots, data)</li>
<li><strong>Metric</strong>: Numerical values tracked over time</li>
<li><strong>Parameter</strong>: Input configurations for your run</li>
</ul>
</section>
<section id="experiment-tracking" class="level2">
<h2 class="anchored" data-anchor-id="experiment-tracking" id="experiment-tracking">Experiment Tracking</h2>
<section id="basic-run-structure" class="level3">
<h3 class="anchored" data-anchor-id="basic-run-structure" id="basic-run-structure">Basic Run Structure</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> mlflow</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a><span class="cf">with</span> mlflow.start_run():</span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Your training code here</span></span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a>    mlflow.log_param(<span class="st">"learning_rate"</span>, <span class="fl">0.001</span>)</span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a>    mlflow.log_metric(<span class="st">"accuracy"</span>, <span class="fl">0.95</span>)</span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a>    mlflow.log_artifact(<span class="st">"model.pth"</span>)  <span class="co"># the file must already exist on disk</span></span></code></pre></div></div>
</section>
<section id="complete-training-example" class="level3">
<h3 class="anchored" data-anchor-id="complete-training-example" id="complete-training-example">Complete Training Example</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.optim <span class="im">as</span> optim</span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> mlflow</span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> mlflow.pytorch</span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> sklearn.metrics <span class="im">import</span> accuracy_score</span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> SimpleNet(nn.Module):</span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, input_size, hidden_size, num_classes):</span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>(SimpleNet, <span class="va">self</span>).<span class="fu">__init__</span>()</span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.fc1 <span class="op">=</span> nn.Linear(input_size, hidden_size)</span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.relu <span class="op">=</span> nn.ReLU()</span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.fc2 <span class="op">=</span> nn.Linear(hidden_size, num_classes)</span>
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-16"><a href="#cb5-16" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb5-17"><a href="#cb5-17" aria-hidden="true" tabindex="-1"></a>        out <span class="op">=</span> <span class="va">self</span>.fc1(x)</span>
<span id="cb5-18"><a href="#cb5-18" aria-hidden="true" tabindex="-1"></a>        out <span class="op">=</span> <span class="va">self</span>.relu(out)</span>
<span id="cb5-19"><a href="#cb5-19" aria-hidden="true" tabindex="-1"></a>        out <span class="op">=</span> <span class="va">self</span>.fc2(out)</span>
<span id="cb5-20"><a href="#cb5-20" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> out</span>
<span id="cb5-21"><a href="#cb5-21" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-22"><a href="#cb5-22" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> train_model():</span>
<span id="cb5-23"><a href="#cb5-23" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Hyperparameters</span></span>
<span id="cb5-24"><a href="#cb5-24" aria-hidden="true" tabindex="-1"></a>    input_size <span class="op">=</span> <span class="dv">784</span></span>
<span id="cb5-25"><a href="#cb5-25" aria-hidden="true" tabindex="-1"></a>    hidden_size <span class="op">=</span> <span class="dv">128</span></span>
<span id="cb5-26"><a href="#cb5-26" aria-hidden="true" tabindex="-1"></a>    num_classes <span class="op">=</span> <span class="dv">10</span></span>
<span id="cb5-27"><a href="#cb5-27" aria-hidden="true" tabindex="-1"></a>    learning_rate <span class="op">=</span> <span class="fl">0.001</span></span>
<span id="cb5-28"><a href="#cb5-28" aria-hidden="true" tabindex="-1"></a>    batch_size <span class="op">=</span> <span class="dv">64</span></span>
<span id="cb5-29"><a href="#cb5-29" aria-hidden="true" tabindex="-1"></a>    num_epochs <span class="op">=</span> <span class="dv">10</span></span>
<span id="cb5-30"><a href="#cb5-30" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-31"><a href="#cb5-31" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Start MLflow run</span></span>
<span id="cb5-32"><a href="#cb5-32" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> mlflow.start_run():</span>
<span id="cb5-33"><a href="#cb5-33" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Log hyperparameters</span></span>
<span id="cb5-34"><a href="#cb5-34" aria-hidden="true" tabindex="-1"></a>        mlflow.log_param(<span class="st">"input_size"</span>, input_size)</span>
<span id="cb5-35"><a href="#cb5-35" aria-hidden="true" tabindex="-1"></a>        mlflow.log_param(<span class="st">"hidden_size"</span>, hidden_size)</span>
<span id="cb5-36"><a href="#cb5-36" aria-hidden="true" tabindex="-1"></a>        mlflow.log_param(<span class="st">"num_classes"</span>, num_classes)</span>
<span id="cb5-37"><a href="#cb5-37" aria-hidden="true" tabindex="-1"></a>        mlflow.log_param(<span class="st">"learning_rate"</span>, learning_rate)</span>
<span id="cb5-38"><a href="#cb5-38" aria-hidden="true" tabindex="-1"></a>        mlflow.log_param(<span class="st">"batch_size"</span>, batch_size)</span>
<span id="cb5-39"><a href="#cb5-39" aria-hidden="true" tabindex="-1"></a>        mlflow.log_param(<span class="st">"num_epochs"</span>, num_epochs)</span>
<span id="cb5-40"><a href="#cb5-40" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-41"><a href="#cb5-41" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Initialize model</span></span>
<span id="cb5-42"><a href="#cb5-42" aria-hidden="true" tabindex="-1"></a>        model <span class="op">=</span> SimpleNet(input_size, hidden_size, num_classes)</span>
<span id="cb5-43"><a href="#cb5-43" aria-hidden="true" tabindex="-1"></a>        criterion <span class="op">=</span> nn.CrossEntropyLoss()</span>
<span id="cb5-44"><a href="#cb5-44" aria-hidden="true" tabindex="-1"></a>        optimizer <span class="op">=</span> optim.Adam(model.parameters(), lr<span class="op">=</span>learning_rate)</span>
<span id="cb5-45"><a href="#cb5-45" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-46"><a href="#cb5-46" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Training loop</span></span>
<span id="cb5-47"><a href="#cb5-47" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(num_epochs):</span>
<span id="cb5-48"><a href="#cb5-48" aria-hidden="true" tabindex="-1"></a>            running_loss <span class="op">=</span> <span class="fl">0.0</span></span>
<span id="cb5-49"><a href="#cb5-49" aria-hidden="true" tabindex="-1"></a>            correct <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb5-50"><a href="#cb5-50" aria-hidden="true" tabindex="-1"></a>            total <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb5-51"><a href="#cb5-51" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb5-52"><a href="#cb5-52" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Simulate training data</span></span>
<span id="cb5-53"><a href="#cb5-53" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">100</span>):  <span class="co"># 100 batches</span></span>
<span id="cb5-54"><a href="#cb5-54" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Generate dummy data</span></span>
<span id="cb5-55"><a href="#cb5-55" aria-hidden="true" tabindex="-1"></a>                inputs <span class="op">=</span> torch.randn(batch_size, input_size)</span>
<span id="cb5-56"><a href="#cb5-56" aria-hidden="true" tabindex="-1"></a>                labels <span class="op">=</span> torch.randint(<span class="dv">0</span>, num_classes, (batch_size,))</span>
<span id="cb5-57"><a href="#cb5-57" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb5-58"><a href="#cb5-58" aria-hidden="true" tabindex="-1"></a>                optimizer.zero_grad()</span>
<span id="cb5-59"><a href="#cb5-59" aria-hidden="true" tabindex="-1"></a>                outputs <span class="op">=</span> model(inputs)</span>
<span id="cb5-60"><a href="#cb5-60" aria-hidden="true" tabindex="-1"></a>                loss <span class="op">=</span> criterion(outputs, labels)</span>
<span id="cb5-61"><a href="#cb5-61" aria-hidden="true" tabindex="-1"></a>                loss.backward()</span>
<span id="cb5-62"><a href="#cb5-62" aria-hidden="true" tabindex="-1"></a>                optimizer.step()</span>
<span id="cb5-63"><a href="#cb5-63" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb5-64"><a href="#cb5-64" aria-hidden="true" tabindex="-1"></a>                running_loss <span class="op">+=</span> loss.item()</span>
<span id="cb5-65"><a href="#cb5-65" aria-hidden="true" tabindex="-1"></a>                predicted <span class="op">=</span> outputs.argmax(dim<span class="op">=</span><span class="dv">1</span>)  <span class="co"># predicted class = index of the largest logit</span></span>
<span id="cb5-66"><a href="#cb5-66" aria-hidden="true" tabindex="-1"></a>                total <span class="op">+=</span> labels.size(<span class="dv">0</span>)</span>
<span id="cb5-67"><a href="#cb5-67" aria-hidden="true" tabindex="-1"></a>                correct <span class="op">+=</span> (predicted <span class="op">==</span> labels).<span class="bu">sum</span>().item()</span>
<span id="cb5-68"><a href="#cb5-68" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb5-69"><a href="#cb5-69" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Calculate metrics</span></span>
<span id="cb5-70"><a href="#cb5-70" aria-hidden="true" tabindex="-1"></a>            epoch_loss <span class="op">=</span> running_loss <span class="op">/</span> <span class="dv">100</span></span>
<span id="cb5-71"><a href="#cb5-71" aria-hidden="true" tabindex="-1"></a>            epoch_acc <span class="op">=</span> <span class="dv">100</span> <span class="op">*</span> correct <span class="op">/</span> total</span>
<span id="cb5-72"><a href="#cb5-72" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb5-73"><a href="#cb5-73" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Log metrics</span></span>
<span id="cb5-74"><a href="#cb5-74" aria-hidden="true" tabindex="-1"></a>            mlflow.log_metric(<span class="st">"loss"</span>, epoch_loss, step<span class="op">=</span>epoch)</span>
<span id="cb5-75"><a href="#cb5-75" aria-hidden="true" tabindex="-1"></a>            mlflow.log_metric(<span class="st">"accuracy"</span>, epoch_acc, step<span class="op">=</span>epoch)</span>
<span id="cb5-76"><a href="#cb5-76" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb5-77"><a href="#cb5-77" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f'Epoch [</span><span class="sc">{</span>epoch<span class="op">+</span><span class="dv">1</span><span class="sc">}</span><span class="ss">/</span><span class="sc">{</span>num_epochs<span class="sc">}</span><span class="ss">], Loss: </span><span class="sc">{</span>epoch_loss<span class="sc">:.4f}</span><span class="ss">, Accuracy: </span><span class="sc">{</span>epoch_acc<span class="sc">:.2f}</span><span class="ss">%'</span>)</span>
<span id="cb5-78"><a href="#cb5-78" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-79"><a href="#cb5-79" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Log model</span></span>
<span id="cb5-80"><a href="#cb5-80" aria-hidden="true" tabindex="-1"></a>        mlflow.pytorch.log_model(model, <span class="st">"model"</span>)</span>
<span id="cb5-81"><a href="#cb5-81" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-82"><a href="#cb5-82" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Log additional artifacts</span></span>
<span id="cb5-83"><a href="#cb5-83" aria-hidden="true" tabindex="-1"></a>        torch.save(model.state_dict(), <span class="st">"model_state_dict.pth"</span>)</span>
<span id="cb5-84"><a href="#cb5-84" aria-hidden="true" tabindex="-1"></a>        mlflow.log_artifact(<span class="st">"model_state_dict.pth"</span>)</span>
<span id="cb5-85"><a href="#cb5-85" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-86"><a href="#cb5-86" aria-hidden="true" tabindex="-1"></a><span class="co"># Run training</span></span>
<span id="cb5-87"><a href="#cb5-87" aria-hidden="true" tabindex="-1"></a>train_model()</span></code></pre></div></div>
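<p>The loop above hardcodes <code>100</code> both as the number of batches and as the divisor for <code>epoch_loss</code>. The plain-Python sketch below shows the same epoch bookkeeping with the divisor derived from the recorded batches; <code>epoch_metrics</code> is an illustrative name, not part of MLflow or PyTorch. Because the dummy inputs and labels are freshly random in every batch, the logged accuracy should hover around 100 / <code>num_classes</code> (about 10%), so this script exercises the MLflow plumbing rather than measuring model quality.</p>

```python
from typing import Sequence, Tuple


def epoch_metrics(
    batch_losses: Sequence[float],
    batch_correct: Sequence[int],
    batch_sizes: Sequence[int],
) -> Tuple[float, float]:
    """Return (mean loss per batch, accuracy in percent) for one epoch.

    Dividing by len(batch_losses) instead of a literal keeps the average
    correct when the number of batches changes.
    """
    if not batch_losses:
        raise ValueError("need at least one batch")
    mean_loss = sum(batch_losses) / len(batch_losses)
    accuracy = 100.0 * sum(batch_correct) / sum(batch_sizes)
    return mean_loss, accuracy


# Two hypothetical batches of 64 samples each
loss, acc = epoch_metrics([2.31, 2.29], [7, 6], [64, 64])
print(f"Loss: {loss:.4f}, Accuracy: {acc:.2f}%")
```

<p>Inside the training loop, appending <code>loss.item()</code>, the per-batch correct count, and <code>labels.size(0)</code> to three lists would feed this function directly.</p>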
</section>
</section>
<section id="model-logging" class="level2">
<h2 class="anchored" data-anchor-id="model-logging" id="model-logging">Model Logging</h2>
<section id="different-ways-to-log-pytorch-models" class="level3">
<h3 class="anchored" data-anchor-id="different-ways-to-log-pytorch-models" id="different-ways-to-log-pytorch-models">Different Ways to Log PyTorch Models</h3>
<section id="log-complete-model" class="level4">
<h4 class="anchored" data-anchor-id="log-complete-model">1. Log Complete Model</h4>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Log the entire model</span></span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a>mlflow.pytorch.log_model(model, <span class="st">"complete_model"</span>)</span></code></pre></div></div>
</section>
<section id="log-model-state-dict" class="level4">
<h4 class="anchored" data-anchor-id="log-model-state-dict">2. Log Model State Dict</h4>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Save and log state dict</span></span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a>torch.save(model.state_dict(), <span class="st">"model_state_dict.pth"</span>)</span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a>mlflow.log_artifact(<span class="st">"model_state_dict.pth"</span>)</span></code></pre></div></div>
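<p>A state dict is an ordered mapping from parameter names to tensors. It holds weights only, so loading <code>model_state_dict.pth</code> later means re-creating <code>SimpleNet</code> and calling <code>load_state_dict</code> on it. To make the contents concrete without PyTorch installed, the sketch below reproduces the key names and shapes that <code>SimpleNet(784, 128, 10).state_dict()</code> should report, based on <code>nn.Linear</code> storing its weight as <code>(out_features, in_features)</code>; the helper functions are illustrative.</p>

```python
import math
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def linear_state_dict_shapes(layers: Dict[str, Tuple[int, int]]) -> Dict[str, Shape]:
    """Key names and shapes that state_dict() reports for a stack of nn.Linear layers.

    nn.Linear(in_features, out_features) stores weight as (out, in) and bias as (out,).
    """
    shapes: Dict[str, Shape] = {}
    for name, (in_features, out_features) in layers.items():
        shapes[f"{name}.weight"] = (out_features, in_features)
        shapes[f"{name}.bias"] = (out_features,)
    return shapes


def count_parameters(shapes: Dict[str, Shape]) -> int:
    """Total number of scalar parameters across all tensors."""
    return sum(math.prod(shape) for shape in shapes.values())


# SimpleNet(784, 128, 10): fc1 maps 784 -> 128, fc2 maps 128 -> 10; ReLU has no parameters
shapes = linear_state_dict_shapes({"fc1": (784, 128), "fc2": (128, 10)})
for key, shape in shapes.items():
    print(key, shape)
print("total parameters:", count_parameters(shapes))  # 101770
```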
</section>
<section id="log-with-custom-code" class="level4">
<h4 class="anchored" data-anchor-id="log-with-custom-code">3. Log with Custom Code</h4>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb8"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Log model with custom code for loading</span></span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a>mlflow.pytorch.log_model(</span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a>    model, </span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a>    <span class="st">"model"</span>,</span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a>    code_paths<span class="op">=</span>[<span class="st">"model_definition.py"</span>]  <span class="co"># Include custom model definition</span></span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
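<p>Why <code>code_paths</code> helps: logging the complete model pickles it, and pickle stores a class that lives in an importable module (for example one defined in <code>model_definition.py</code>) as a reference to its import path, not as source code. If that module cannot be imported where the model is loaded, loading fails. The stdlib-only sketch below demonstrates the mechanism with plain <code>pickle</code>; the in-memory module and the <code>TinyNet</code> class are stand-ins for illustration.</p>

```python
import pickle
import sys
import types


def install_stand_in_module() -> types.ModuleType:
    """Create an in-memory module named 'model_definition' (stands in for model_definition.py)."""
    module = types.ModuleType("model_definition")
    source = (
        "class TinyNet:\n"
        "    def __init__(self):\n"
        "        self.weights = [0.1, 0.2]\n"
    )
    exec(source, module.__dict__)  # TinyNet.__module__ becomes 'model_definition'
    sys.modules["model_definition"] = module
    return module


module = install_stand_in_module()
payload = pickle.dumps(module.TinyNet())

# The pickled bytes name the class by import path; the class body is not included.
stores_reference = b"model_definition" in payload and b"TinyNet" in payload
print("stores import path:", stores_reference)

# Simulate a machine where model_definition.py is not importable.
del sys.modules["model_definition"]
try:
    pickle.loads(payload)
    load_ok = True
except ImportError as exc:  # ModuleNotFoundError subclasses ImportError
    load_ok = False
    print("load failed:", exc)

# Making the module importable again (roughly what code_paths arranges) fixes loading.
sys.modules["model_definition"] = module
restored = pickle.loads(payload)
print("restored weights:", restored.weights)
```

<p>MLflow copies the files listed in <code>code_paths</code> into the model artifact and puts them on the import path when the model is loaded, which corresponds to the last step above.</p>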
</section>
<section id="log-with-conda-environment" class="level4">
<h4 class="anchored" data-anchor-id="log-with-conda-environment">4. Log with Conda Environment</h4>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb9"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> mlflow.pytorch</span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Create conda environment specification</span></span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a>conda_env <span class="op">=</span> {</span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a>    <span class="st">'channels'</span>: [<span class="st">'defaults'</span>, <span class="st">'pytorch'</span>],</span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a>    <span class="st">'dependencies'</span>: [</span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a>        <span class="st">'python=3.10'</span>,  <span class="co"># match the Python version used for training</span></span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a>        <span class="st">'pytorch'</span>,</span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a>        <span class="st">'torchvision'</span>,</span>
<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a>        {<span class="st">'pip'</span>: [<span class="st">'mlflow'</span>]}</span>
<span id="cb9-11"><a href="#cb9-11" aria-hidden="true" tabindex="-1"></a>    ],</span>
<span id="cb9-12"><a href="#cb9-12" aria-hidden="true" tabindex="-1"></a>    <span class="st">'name'</span>: <span class="st">'pytorch_env'</span></span>
<span id="cb9-13"><a href="#cb9-13" aria-hidden="true" tabindex="-1"></a>}</span>
<span id="cb9-14"><a href="#cb9-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-15"><a href="#cb9-15" aria-hidden="true" tabindex="-1"></a>mlflow.pytorch.log_model(</span>
<span id="cb9-16"><a href="#cb9-16" aria-hidden="true" tabindex="-1"></a>    model,</span>
<span id="cb9-17"><a href="#cb9-17" aria-hidden="true" tabindex="-1"></a>    <span class="st">"model"</span>,</span>
<span id="cb9-18"><a href="#cb9-18" aria-hidden="true" tabindex="-1"></a>    conda_env<span class="op">=</span>conda_env</span>
<span id="cb9-19"><a href="#cb9-19" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
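<p>The environment above leaves <code>pytorch</code>, <code>torchvision</code>, and <code>mlflow</code> unpinned, so a later rebuild may resolve different versions than the ones used in training. One option, sketched here as an assumption rather than a required step, is to pin the versions installed in the training environment with <code>importlib.metadata</code>. The resulting pip-style strings could go in the <code>pip</code> list of <code>conda_env</code> or be passed to <code>log_model</code> through its <code>pip_requirements</code> argument. Note that pip distributes PyTorch as <code>torch</code>, while conda names it <code>pytorch</code>.</p>

```python
from importlib.metadata import PackageNotFoundError, version
from typing import List, Sequence


def pinned_requirements(packages: Sequence[str]) -> List[str]:
    """Pin each pip package to its locally installed version ("name==x.y.z").

    Packages that are not installed are returned unpinned, so the list is always usable.
    """
    pins: List[str] = []
    for name in packages:
        try:
            pins.append(f"{name}=={version(name)}")
        except PackageNotFoundError:
            pins.append(name)
    return pins


# pip distribution names: PyTorch is "torch" on PyPI ("pytorch" is the conda name)
print(pinned_requirements(["torch", "torchvision", "mlflow"]))
```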
</section>
</section>
</section>
<section id="model-registry" class="level2">
<h2 class="anchored" data-anchor-id="model-registry" id="model-registry">Model Registry</h2>
<section id="register-model" class="level3">
<h3 class="anchored" data-anchor-id="register-model" id="register-model">Register Model</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb10"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Register model during logging</span></span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a>mlflow.pytorch.log_model(</span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a>    model, </span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a>    <span class="st">"model"</span>,</span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a>    registered_model_name<span class="op">=</span><span class="st">"MyPyTorchModel"</span></span>
<span id="cb10-6"><a href="#cb10-6" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb10-7"><a href="#cb10-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-8"><a href="#cb10-8" aria-hidden="true" tabindex="-1"></a><span class="co"># Or register a model already logged in a run (replace your_run_id with the real run ID)</span></span>
<span id="cb10-9"><a href="#cb10-9" aria-hidden="true" tabindex="-1"></a>model_uri <span class="op">=</span> <span class="st">"runs:/your_run_id/model"</span></span>
<span id="cb10-10"><a href="#cb10-10" aria-hidden="true" tabindex="-1"></a>mlflow.register_model(model_uri, <span class="st">"MyPyTorchModel"</span>)</span></code></pre></div></div>
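<p>Both calls above identify a model by URI. The plain-Python sketch below spells out the URI shapes used in this guide: <code>runs:/</code> for a model artifact inside a run, and <code>models:/</code> for a registry entry selected by version or stage, plus the <code>models:/NAME@ALIAS</code> form that recent MLflow releases accept for aliases. The helper names and the alias <code>champion</code> are illustrative.</p>

```python
from typing import Optional


def run_model_uri(run_id: str, artifact_path: str = "model") -> str:
    """URI for a model logged inside a run: runs:/RUN_ID/ARTIFACT_PATH."""
    return f"runs:/{run_id}/{artifact_path}"


def registry_model_uri(
    name: str,
    version: Optional[int] = None,
    stage: Optional[str] = None,
    alias: Optional[str] = None,
) -> str:
    """URI for a registered model selected by exactly one of version, stage, or alias."""
    chosen = [value for value in (version, stage, alias) if value is not None]
    if len(chosen) != 1:
        raise ValueError("pass exactly one of version, stage, or alias")
    if alias is not None:
        return f"models:/{name}@{alias}"
    return f"models:/{name}/{chosen[0]}"


print(run_model_uri("your_run_id"))                              # runs:/your_run_id/model
print(registry_model_uri("MyPyTorchModel", version=1))           # models:/MyPyTorchModel/1
print(registry_model_uri("MyPyTorchModel", stage="Production"))  # models:/MyPyTorchModel/Production
print(registry_model_uri("MyPyTorchModel", alias="champion"))    # models:/MyPyTorchModel@champion
```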
</section>
<section id="model-versioning-and-stages" class="level3">
<h3 class="anchored" data-anchor-id="model-versioning-and-stages" id="model-versioning-and-stages">Model Versioning and Stages</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb11"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> mlflow.tracking <span class="im">import</span> MlflowClient</span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a>client <span class="op">=</span> MlflowClient()</span>
<span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Move version 1 to the Production stage (stages are deprecated since MLflow 2.9 in favor of aliases)</span></span>
<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a>client.transition_model_version_stage(</span>
<span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a>    name<span class="op">=</span><span class="st">"MyPyTorchModel"</span>,</span>
<span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a>    version<span class="op">=</span><span class="dv">1</span>,</span>
<span id="cb11-9"><a href="#cb11-9" aria-hidden="true" tabindex="-1"></a>    stage<span class="op">=</span><span class="st">"Production"</span></span>
<span id="cb11-10"><a href="#cb11-10" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb11-11"><a href="#cb11-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-12"><a href="#cb11-12" aria-hidden="true" tabindex="-1"></a><span class="co"># Fetch the latest version in the Production stage ([0] raises IndexError if there is none)</span></span>
<span id="cb11-13"><a href="#cb11-13" aria-hidden="true" tabindex="-1"></a>model_version <span class="op">=</span> client.get_latest_versions(</span>
<span id="cb11-14"><a href="#cb11-14" aria-hidden="true" tabindex="-1"></a>    <span class="st">"MyPyTorchModel"</span>, </span>
<span id="cb11-15"><a href="#cb11-15" aria-hidden="true" tabindex="-1"></a>    stages<span class="op">=</span>[<span class="st">"Production"</span>]</span>
<span id="cb11-16"><a href="#cb11-16" aria-hidden="true" tabindex="-1"></a>)[<span class="dv">0</span>]</span></code></pre></div></div>
</section>
<section id="load-registered-model" class="level3">
<h3 class="anchored" data-anchor-id="load-registered-model" id="load-registered-model">Load Registered Model</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb12"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><a href="#cb12-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Load model from registry</span></span>
<span id="cb12-2"><a href="#cb12-2" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> mlflow.pytorch.load_model(</span>
<span id="cb12-3"><a href="#cb12-3" aria-hidden="true" tabindex="-1"></a>    model_uri<span class="op">=</span><span class="st">"models:/MyPyTorchModel/Production"</span></span>
<span id="cb12-4"><a href="#cb12-4" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb12-5"><a href="#cb12-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-6"><a href="#cb12-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Or load specific version</span></span>
<span id="cb12-7"><a href="#cb12-7" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> mlflow.pytorch.load_model(</span>
<span id="cb12-8"><a href="#cb12-8" aria-hidden="true" tabindex="-1"></a>    model_uri<span class="op">=</span><span class="st">"models:/MyPyTorchModel/1"</span></span>
<span id="cb12-9"><a href="#cb12-9" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
</section>
</section>
<section id="model-deployment" class="level2">
<h2 class="anchored" data-anchor-id="model-deployment" id="model-deployment">Model Deployment</h2>
<section id="local-serving" class="level3">
<h3 class="anchored" data-anchor-id="local-serving" id="local-serving">Local Serving</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb13"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Serve the Production version of the registered model locally on port 1234</span></span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a><span class="co"># (add --env-manager local to reuse the current Python environment)</span></span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a><span class="ex">mlflow</span> models serve <span class="at">-m</span> models:/MyPyTorchModel/Production <span class="at">-p</span> 1234</span></code></pre></div></div>
</section>
<section id="prediction-with-served-model" class="level3">
<h3 class="anchored" data-anchor-id="prediction-with-served-model" id="prediction-with-served-model">Prediction with Served Model</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb14"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb14-1"><a href="#cb14-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> requests</span>
<span id="cb14-2"><a href="#cb14-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> json</span>
<span id="cb14-3"><a href="#cb14-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-4"><a href="#cb14-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Prepare data</span></span>
<span id="cb14-5"><a href="#cb14-5" aria-hidden="true" tabindex="-1"></a>data <span class="op">=</span> {</span>
<span id="cb14-6"><a href="#cb14-6" aria-hidden="true" tabindex="-1"></a>    <span class="st">"inputs"</span>: [[<span class="fl">0.0</span>] <span class="op">*</span> <span class="dv">784</span>]  <span class="co"># One placeholder row; its length must equal the model's input_size (784 for SimpleNet)</span></span>
<span id="cb14-7"><a href="#cb14-7" aria-hidden="true" tabindex="-1"></a>}</span>
<span id="cb14-8"><a href="#cb14-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-9"><a href="#cb14-9" aria-hidden="true" tabindex="-1"></a><span class="co"># Make prediction request</span></span>
<span id="cb14-10"><a href="#cb14-10" aria-hidden="true" tabindex="-1"></a>response <span class="op">=</span> requests.post(</span>
<span id="cb14-11"><a href="#cb14-11" aria-hidden="true" tabindex="-1"></a>    <span class="st">"http://localhost:1234/invocations"</span>,</span>
<span id="cb14-12"><a href="#cb14-12" aria-hidden="true" tabindex="-1"></a>    headers<span class="op">=</span>{<span class="st">"Content-Type"</span>: <span class="st">"application/json"</span>},</span>
<span id="cb14-13"><a href="#cb14-13" aria-hidden="true" tabindex="-1"></a>    data<span class="op">=</span>json.dumps(data)</span>
<span id="cb14-14"><a href="#cb14-14" aria-hidden="true" tabindex="-1"></a>)</span>
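<span id="cb14-15"><a href="#cb14-15" aria-hidden="true" tabindex="-1"></a>response.raise_for_status()  <span class="co"># Raise on 4xx/5xx instead of parsing an error body</span></span>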
<span id="cb14-16"><a href="#cb14-16" aria-hidden="true" tabindex="-1"></a>predictions <span class="op">=</span> response.json()</span>
<span id="cb14-17"><a href="#cb14-17" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(predictions)</span></code></pre></div></div>
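<p>In MLflow 2.x and later, the scoring server's <code>/invocations</code> endpoint accepts JSON under one of several top-level keys, including <code>inputs</code>, <code>instances</code>, and <code>dataframe_split</code>. Whatever the layout, each row must be as wide as the model's input layer, 784 features for <code>SimpleNet</code> from the training example. The plain-Python sketch below builds a payload and checks row widths before anything is sent; the function name, the 784 default, and the <code>f0</code>, <code>f1</code>, … column names are assumptions.</p>

```python
import json
from typing import Sequence


def scoring_payload(
    rows: Sequence[Sequence[float]],
    fmt: str = "inputs",
    n_features: int = 784,
) -> str:
    """Build a JSON body for POST /invocations, checking every row's width first.

    n_features defaults to 784, the input_size of SimpleNet in the training example.
    """
    data = [[float(x) for x in row] for row in rows]
    for i, row in enumerate(data):
        if len(row) != n_features:
            raise ValueError(f"row {i} has {len(row)} features, expected {n_features}")
    if fmt in ("inputs", "instances"):
        body = {fmt: data}
    elif fmt == "dataframe_split":
        columns = [f"f{j}" for j in range(n_features)]  # illustrative column names
        body = {"dataframe_split": {"columns": columns, "data": data}}
    else:
        raise ValueError(f"unsupported format: {fmt!r}")
    return json.dumps(body)


payload = scoring_payload([[0.0] * 784])
print(payload[:40])
```

<p>The returned string can be passed as the <code>data</code> argument of the <code>requests.post</code> call above, with the same <code>Content-Type: application/json</code> header.</p>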
</section>
<section id="docker-deployment" class="level3">
<h3 class="anchored" data-anchor-id="docker-deployment" id="docker-deployment">Docker Deployment</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb15"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb15-1"><a href="#cb15-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Build Docker image</span></span>
<span id="cb15-2"><a href="#cb15-2" aria-hidden="true" tabindex="-1"></a><span class="ex">mlflow</span> models build-docker <span class="at">-m</span> models:/MyPyTorchModel/Production <span class="at">-n</span> my-pytorch-model</span>
<span id="cb15-3"><a href="#cb15-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-4"><a href="#cb15-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Run Docker container</span></span>
<span id="cb15-5"><a href="#cb15-5" aria-hidden="true" tabindex="-1"></a><span class="ex">docker</span> run <span class="at">-p</span> 8080:8080 my-pytorch-model</span></code></pre></div></div>
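<p>The server inside the container needs a moment to start and load the model, so a prediction request sent immediately after <code>docker run</code> may be refused. The stdlib-only sketch below polls the scoring server's <code>/ping</code> health endpoint until it answers with HTTP 200; the function name, the timeout values, and the base URL <code>http://localhost:8080</code> (matching the <code>-p 8080:8080</code> mapping above) are assumptions.</p>

```python
import http.client
import time
import urllib.request


def wait_for_server(base_url: str, timeout_s: float = 60.0, interval_s: float = 1.0) -> bool:
    """Poll GET {base_url}/ping until it answers HTTP 200; return False on timeout."""
    deadline = time.monotonic() + timeout_s
    while True:
        try:
            with urllib.request.urlopen(f"{base_url}/ping", timeout=interval_s) as resp:
                if resp.status == 200:
                    return True
        except (OSError, http.client.HTTPException):
            pass  # refused, reset, timed out, or HTTP error: server not ready yet
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval_s)
```

<p>For example, call <code>wait_for_server("http://localhost:8080")</code> after starting the container and send prediction requests only once it returns <code>True</code>.</p>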
</section>
</section>
<section id="advanced-features" class="level2">
<h2 class="anchored" data-anchor-id="advanced-features" id="advanced-features">Advanced Features</h2>
<section id="custom-mlflow-model" class="level3">
<h3 class="anchored" data-anchor-id="custom-mlflow-model" id="custom-mlflow-model">Custom MLflow Model</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb16"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb16-1"><a href="#cb16-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> mlflow.pyfunc</span>
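<span id="cb16-2"><a href="#cb16-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb16-2a"><a href="#cb16-2a" aria-hidden="true" tabindex="-1"></a></span>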
<span id="cb16-3"><a href="#cb16-3" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> PyTorchModelWrapper(mlflow.pyfunc.PythonModel):</span>
<span id="cb16-4"><a href="#cb16-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, model):</span>
<span id="cb16-5"><a href="#cb16-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> model</span>
<span id="cb16-6"><a href="#cb16-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-7"><a href="#cb16-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> predict(<span class="va">self</span>, context, model_input):</span>
<span id="cb16-8"><a href="#cb16-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model.<span class="bu">eval</span>()  <span class="co"># Inference mode: disables dropout, uses batch-norm running stats</span></span>
<span id="cb16-9"><a href="#cb16-9" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> torch.no_grad():</span>
<span id="cb16-10"><a href="#cb16-10" aria-hidden="true" tabindex="-1"></a>            tensor_input <span class="op">=</span> torch.as_tensor(model_input.values, dtype<span class="op">=</span>torch.float32)  <span class="co"># assumes a pandas DataFrame input</span></span>
<span id="cb16-11"><a href="#cb16-11" aria-hidden="true" tabindex="-1"></a>            predictions <span class="op">=</span> <span class="va">self</span>.model(tensor_input)</span>
<span id="cb16-12"><a href="#cb16-12" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> predictions.numpy()</span>
<span id="cb16-13"><a href="#cb16-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-14"><a href="#cb16-14" aria-hidden="true" tabindex="-1"></a><span class="co"># Log custom model</span></span>
<span id="cb16-15"><a href="#cb16-15" aria-hidden="true" tabindex="-1"></a>wrapped_model <span class="op">=</span> PyTorchModelWrapper(model)</span>
<span id="cb16-16"><a href="#cb16-16" aria-hidden="true" tabindex="-1"></a>mlflow.pyfunc.log_model(</span>
<span id="cb16-17"><a href="#cb16-17" aria-hidden="true" tabindex="-1"></a>    <span class="st">"custom_model"</span>, </span>
<span id="cb16-18"><a href="#cb16-18" aria-hidden="true" tabindex="-1"></a>    python_model<span class="op">=</span>wrapped_model</span>
<span id="cb16-19"><a href="#cb16-19" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
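<p>The wrapper above returns the raw logits produced by the model. A common reason to write a custom <code>PythonModel</code> is to post-process those outputs before they leave the server, for example into a predicted class and its probability. The plain-Python sketch below shows that step in isolation; inside <code>predict</code> the same idea would usually be written with tensor operations such as <code>torch.softmax</code>. The helper names are illustrative.</p>

```python
import math
from typing import List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    top = max(logits)
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def postprocess(batch_logits: Sequence[Sequence[float]]) -> List[Tuple[int, float]]:
    """Map each row of raw logits to (predicted class index, its probability)."""
    results = []
    for row in batch_logits:
        probs = softmax(row)
        best = max(range(len(probs)), key=probs.__getitem__)
        results.append((best, probs[best]))
    return results


print(postprocess([[2.0, 1.0, 0.1], [0.0, 0.0, 3.0]]))
```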
</section>
<section id="automatic-logging" class="level3">
<h3 class="anchored" data-anchor-id="automatic-logging" id="automatic-logging">Automatic Logging</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb17"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb17-1"><a href="#cb17-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Enable automatic logging</span></span>
<span id="cb17-2"><a href="#cb17-2" aria-hidden="true" tabindex="-1"></a>mlflow.pytorch.autolog()</span>
<span id="cb17-3"><a href="#cb17-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-4"><a href="#cb17-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Autologging fully supports PyTorch Lightning; plain nn.Module loops only log SummaryWriter add_scalar/add_hparams calls</span></span>
<span id="cb17-5"><a href="#cb17-5" aria-hidden="true" tabindex="-1"></a><span class="cf">with</span> mlflow.start_run():</span>
<span id="cb17-6"><a href="#cb17-6" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Training happens here, e.g. a Lightning Trainer.fit(...) call</span></span>
<span id="cb17-7"><a href="#cb17-7" aria-hidden="true" tabindex="-1"></a>    <span class="cf">pass</span></span></code></pre></div></div>
</section>
<section id="logging-hyperparameter-sweeps" class="level3">
<h3 class="anchored" data-anchor-id="logging-hyperparameter-sweeps" id="logging-hyperparameter-sweeps">Logging Hyperparameter Sweeps</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb18"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb18-1"><a href="#cb18-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> itertools</span>
<span id="cb18-2"><a href="#cb18-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-3"><a href="#cb18-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Define hyperparameter grid</span></span>
<span id="cb18-4"><a href="#cb18-4" aria-hidden="true" tabindex="-1"></a>param_grid <span class="op">=</span> {</span>
<span id="cb18-5"><a href="#cb18-5" aria-hidden="true" tabindex="-1"></a>    <span class="st">'learning_rate'</span>: [<span class="fl">0.001</span>, <span class="fl">0.01</span>, <span class="fl">0.1</span>],</span>
<span id="cb18-6"><a href="#cb18-6" aria-hidden="true" tabindex="-1"></a>    <span class="st">'hidden_size'</span>: [<span class="dv">64</span>, <span class="dv">128</span>, <span class="dv">256</span>],</span>
<span id="cb18-7"><a href="#cb18-7" aria-hidden="true" tabindex="-1"></a>    <span class="st">'batch_size'</span>: [<span class="dv">32</span>, <span class="dv">64</span>, <span class="dv">128</span>]</span>
<span id="cb18-8"><a href="#cb18-8" aria-hidden="true" tabindex="-1"></a>}</span>
<span id="cb18-9"><a href="#cb18-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-10"><a href="#cb18-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Run experiments</span></span>
<span id="cb18-11"><a href="#cb18-11" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> params <span class="kw">in</span> [<span class="bu">dict</span>(<span class="bu">zip</span>(param_grid.keys(), v)) </span>
<span id="cb18-12"><a href="#cb18-12" aria-hidden="true" tabindex="-1"></a>               <span class="cf">for</span> v <span class="kw">in</span> itertools.product(<span class="op">*</span>param_grid.values())]:</span>
<span id="cb18-13"><a href="#cb18-13" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> mlflow.start_run():</span>
<span id="cb18-14"><a href="#cb18-14" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Log parameters</span></span>
<span id="cb18-15"><a href="#cb18-15" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> key, value <span class="kw">in</span> params.items():</span>
<span id="cb18-16"><a href="#cb18-16" aria-hidden="true" tabindex="-1"></a>            mlflow.log_param(key, value)</span>
<span id="cb18-17"><a href="#cb18-17" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-18"><a href="#cb18-18" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Train with these parameters (train_with_params stands for your own training code and returns the model and its accuracy)</span></span>
<span id="cb18-19"><a href="#cb18-19" aria-hidden="true" tabindex="-1"></a>        model, accuracy <span class="op">=</span> train_with_params(params)</span>
<span id="cb18-20"><a href="#cb18-20" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-21"><a href="#cb18-21" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Log results</span></span>
<span id="cb18-22"><a href="#cb18-22" aria-hidden="true" tabindex="-1"></a>        mlflow.log_metric(<span class="st">"final_accuracy"</span>, accuracy)</span>
<span id="cb18-23"><a href="#cb18-23" aria-hidden="true" tabindex="-1"></a>        mlflow.pytorch.log_model(model, <span class="st">"model"</span>)</span></code></pre></div></div>
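<p>The grid expansion can be checked on its own before any runs are launched. The sketch below wraps the same <code>itertools.product</code> logic in a hypothetical <code>expand_grid</code> helper (not an MLflow function), which makes the number of runs explicit: 3 × 3 × 3 = 27 for this grid.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import itertools


def expand_grid(param_grid):
    """Return one dict per combination in the Cartesian product of the grid values.

    Hypothetical helper for illustration; it mirrors the list comprehension
    used in the MLflow loop.
    """
    keys = list(param_grid)
    return [dict(zip(keys, values)) for values in itertools.product(*param_grid.values())]


param_grid = {
    'learning_rate': [0.001, 0.01, 0.1],
    'hidden_size': [64, 128, 256],
    'batch_size': [32, 64, 128],
}

configs = expand_grid(param_grid)
print(len(configs))  # 27
print(configs[0])    # {'learning_rate': 0.001, 'hidden_size': 64, 'batch_size': 32}
</code></pre></div></div>
<p>Each extra value multiplies the run count: adding a fourth batch size, for example, would raise it from 27 to 36.</p>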
</section>
<section id="logging-artifacts-and-plots" class="level3">
<h3 class="anchored" data-anchor-id="logging-artifacts-and-plots" id="logging-artifacts-and-plots">Logging Artifacts and Plots</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb19"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb19-1"><a href="#cb19-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> matplotlib.pyplot <span class="im">as</span> plt</span>
<span id="cb19-2"><a href="#cb19-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> seaborn <span class="im">as</span> sns</span>
<span id="cb19-3"><a href="#cb19-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> sklearn.metrics <span class="im">import</span> confusion_matrix</span>
<span id="cb19-4"><a href="#cb19-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-5"><a href="#cb19-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Plot training and validation loss curves, then log the image as a run artifact</span></span>
<span id="cb19-6"><a href="#cb19-6" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> log_training_plots(train_losses, val_losses):</span>
<span id="cb19-7"><a href="#cb19-7" aria-hidden="true" tabindex="-1"></a>    plt.figure(figsize<span class="op">=</span>(<span class="dv">10</span>, <span class="dv">6</span>))</span>
<span id="cb19-8"><a href="#cb19-8" aria-hidden="true" tabindex="-1"></a>    plt.plot(train_losses, label<span class="op">=</span><span class="st">'Training Loss'</span>)</span>
<span id="cb19-9"><a href="#cb19-9" aria-hidden="true" tabindex="-1"></a>    plt.plot(val_losses, label<span class="op">=</span><span class="st">'Validation Loss'</span>)</span>
<span id="cb19-10"><a href="#cb19-10" aria-hidden="true" tabindex="-1"></a>    plt.xlabel(<span class="st">'Epoch'</span>)</span>
<span id="cb19-11"><a href="#cb19-11" aria-hidden="true" tabindex="-1"></a>    plt.ylabel(<span class="st">'Loss'</span>)</span>
<span id="cb19-12"><a href="#cb19-12" aria-hidden="true" tabindex="-1"></a>    plt.legend()</span>
<span id="cb19-13"><a href="#cb19-13" aria-hidden="true" tabindex="-1"></a>    plt.title(<span class="st">'Training and Validation Loss'</span>)</span>
<span id="cb19-14"><a href="#cb19-14" aria-hidden="true" tabindex="-1"></a>    plt.savefig(<span class="st">'loss_plot.png'</span>)</span>
<span id="cb19-15"><a href="#cb19-15" aria-hidden="true" tabindex="-1"></a>    mlflow.log_artifact(<span class="st">'loss_plot.png'</span>)</span>
<span id="cb19-16"><a href="#cb19-16" aria-hidden="true" tabindex="-1"></a>    plt.close()</span>
<span id="cb19-17"><a href="#cb19-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-18"><a href="#cb19-18" aria-hidden="true" tabindex="-1"></a><span class="co"># Plot a confusion matrix heatmap, then log the image as a run artifact</span></span>
<span id="cb19-19"><a href="#cb19-19" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> log_confusion_matrix(y_true, y_pred, class_names):</span>
<span id="cb19-20"><a href="#cb19-20" aria-hidden="true" tabindex="-1"></a>    cm <span class="op">=</span> confusion_matrix(y_true, y_pred)</span>
<span id="cb19-21"><a href="#cb19-21" aria-hidden="true" tabindex="-1"></a>    plt.figure(figsize<span class="op">=</span>(<span class="dv">8</span>, <span class="dv">6</span>))</span>
<span id="cb19-22"><a href="#cb19-22" aria-hidden="true" tabindex="-1"></a>    sns.heatmap(cm, annot<span class="op">=</span><span class="va">True</span>, fmt<span class="op">=</span><span class="st">'d'</span>, cmap<span class="op">=</span><span class="st">'Blues'</span>, </span>
<span id="cb19-23"><a href="#cb19-23" aria-hidden="true" tabindex="-1"></a>                xticklabels<span class="op">=</span>class_names, yticklabels<span class="op">=</span>class_names)</span>
<span id="cb19-24"><a href="#cb19-24" aria-hidden="true" tabindex="-1"></a>    plt.xlabel(<span class="st">'Predicted label'</span>)</span>
<span id="cb19-25"><a href="#cb19-25" aria-hidden="true" tabindex="-1"></a>    plt.ylabel(<span class="st">'True label'</span>)</span>
<span id="cb19-26"><a href="#cb19-26" aria-hidden="true" tabindex="-1"></a>    plt.title(<span class="st">'Confusion Matrix'</span>)</span>
<span id="cb19-27"><a href="#cb19-27" aria-hidden="true" tabindex="-1"></a>    plt.savefig(<span class="st">'confusion_matrix.png'</span>)</span>
<span id="cb19-28"><a href="#cb19-28" aria-hidden="true" tabindex="-1"></a>    mlflow.log_artifact(<span class="st">'confusion_matrix.png'</span>)</span>
<span id="cb19-29"><a href="#cb19-29" aria-hidden="true" tabindex="-1"></a>    plt.close()</span></code></pre></div></div>
</section>
</section>
<section id="best-practices" class="level2">
<h2 class="anchored" data-anchor-id="best-practices" id="best-practices">Best Practices</h2>
<section id="organize-experiments" class="level3">
<h3 class="anchored" data-anchor-id="organize-experiments" id="organize-experiments">1. Organize Experiments</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb20"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb20-1"><a href="#cb20-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Use descriptive experiment names</span></span>
<span id="cb20-2"><a href="#cb20-2" aria-hidden="true" tabindex="-1"></a>mlflow.set_experiment(<span class="st">"image_classification_resnet"</span>)</span>
<span id="cb20-3"><a href="#cb20-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb20-4"><a href="#cb20-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Use run names for specific configurations</span></span>
<span id="cb20-5"><a href="#cb20-5" aria-hidden="true" tabindex="-1"></a><span class="cf">with</span> mlflow.start_run(run_name<span class="op">=</span><span class="st">"resnet50_adam_lr001"</span>):</span>
<span id="cb20-6"><a href="#cb20-6" aria-hidden="true" tabindex="-1"></a>    <span class="cf">pass</span>  <span class="co"># training and logging code for this configuration goes here</span></span></code></pre></div></div>
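<p>Descriptive run names are easier to keep consistent when they are generated from the configuration instead of typed by hand. A minimal sketch, assuming a flat config dictionary; <code>make_run_name</code> and the keys it reads are hypothetical choices for illustration:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">def make_run_name(config, keys=("model", "optimizer", "lr")):
    """Join selected config values into a run name such as 'resnet50_adam_lr0.001'.

    Hypothetical helper: string values are used as-is; other values are
    prefixed with their key so the name stays readable.
    """
    parts = []
    for key in keys:
        value = config[key]
        parts.append(value if isinstance(value, str) else f"{key}{value}")
    return "_".join(parts)


config = {"model": "resnet50", "optimizer": "adam", "lr": 0.001, "epochs": 10}
print(make_run_name(config))  # resnet50_adam_lr0.001
</code></pre></div></div>
<p>The result can be passed straight to <code>mlflow.start_run(run_name=make_run_name(config))</code>.</p>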
</section>
<section id="comprehensive-logging" class="level3">
<h3 class="anchored" data-anchor-id="comprehensive-logging" id="comprehensive-logging">2. Comprehensive Logging</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb21"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb21-1"><a href="#cb21-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> comprehensive_logging(model, optimizer, criterion, config):</span>
<span id="cb21-2"><a href="#cb21-2" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Log hyperparameters</span></span>
<span id="cb21-3"><a href="#cb21-3" aria-hidden="true" tabindex="-1"></a>    mlflow.log_params(config)</span>
<span id="cb21-4"><a href="#cb21-4" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb21-5"><a href="#cb21-5" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Log model size; the full architecture goes in a text artifact because str(model) often exceeds MLflow's maximum parameter value length</span></span>
<span id="cb21-6"><a href="#cb21-6" aria-hidden="true" tabindex="-1"></a>    total_params <span class="op">=</span> <span class="bu">sum</span>(p.numel() <span class="cf">for</span> p <span class="kw">in</span> model.parameters())</span>
<span id="cb21-7"><a href="#cb21-7" aria-hidden="true" tabindex="-1"></a>    mlflow.log_param(<span class="st">"total_parameters"</span>, total_params)</span>
<span id="cb21-8"><a href="#cb21-8" aria-hidden="true" tabindex="-1"></a>    mlflow.log_text(<span class="bu">str</span>(model), <span class="st">"model_architecture.txt"</span>)</span>
<span id="cb21-9"><a href="#cb21-9" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb21-10"><a href="#cb21-10" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Log optimizer and loss function types</span></span>
<span id="cb21-11"><a href="#cb21-11" aria-hidden="true" tabindex="-1"></a>    mlflow.log_param(<span class="st">"optimizer"</span>, <span class="bu">type</span>(optimizer).<span class="va">__name__</span>)</span>
<span id="cb21-12"><a href="#cb21-12" aria-hidden="true" tabindex="-1"></a>    mlflow.log_param(<span class="st">"criterion"</span>, <span class="bu">type</span>(criterion).<span class="va">__name__</span>)</span>
<span id="cb21-13"><a href="#cb21-13" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb21-14"><a href="#cb21-14" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Log system info</span></span>
<span id="cb21-15"><a href="#cb21-15" aria-hidden="true" tabindex="-1"></a>    mlflow.log_param(<span class="st">"cuda_available"</span>, torch.cuda.is_available())</span>
<span id="cb21-16"><a href="#cb21-16" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> torch.cuda.is_available():</span>
<span id="cb21-17"><a href="#cb21-17" aria-hidden="true" tabindex="-1"></a>        mlflow.log_param(<span class="st">"gpu_name"</span>, torch.cuda.get_device_name(<span class="dv">0</span>))</span></code></pre></div></div>
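<p><code>mlflow.log_params</code> expects a flat dictionary: a nested value such as an optimizer sub-dictionary would be stored as a single stringified parameter, which is hard to filter or compare across runs. If <code>config</code> is nested (for example, loaded from a YAML file), one option is to flatten it first. A minimal sketch with a hypothetical <code>flatten_config</code> helper:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">def flatten_config(config, parent_key="", sep="."):
    """Flatten nested dicts into sep-joined keys.

    Hypothetical helper: {'optimizer': {'lr': 0.1}} becomes {'optimizer.lr': 0.1}.
    """
    flat = {}
    for key, value in config.items():
        full_key = f"{parent_key}{sep}{key}" if parent_key else key
        if isinstance(value, dict):
            # Recurse into sub-dictionaries and merge their flattened keys
            flat.update(flatten_config(value, full_key, sep))
        else:
            flat[full_key] = value
    return flat


config = {"epochs": 20, "optimizer": {"name": "adam", "lr": 0.001, "betas": [0.9, 0.999]}}
print(flatten_config(config))
# {'epochs': 20, 'optimizer.name': 'adam', 'optimizer.lr': 0.001, 'optimizer.betas': [0.9, 0.999]}
</code></pre></div></div>
<p>Calling <code>mlflow.log_params(flatten_config(config))</code> then records <code>optimizer.lr</code> and the other entries as separate parameters.</p>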
</section>
<section id="error-handling" class="level3">
<h3 class="anchored" data-anchor-id="error-handling" id="error-handling">3. Error Handling</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb22"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb22-1"><a href="#cb22-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> safe_mlflow_run(training_function, <span class="op">**</span>kwargs):</span>
<span id="cb22-2"><a href="#cb22-2" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> mlflow.start_run():</span>
<span id="cb22-3"><a href="#cb22-3" aria-hidden="true" tabindex="-1"></a>        <span class="cf">try</span>:</span>
<span id="cb22-4"><a href="#cb22-4" aria-hidden="true" tabindex="-1"></a>            result <span class="op">=</span> training_function(<span class="op">**</span>kwargs)</span>
<span id="cb22-5"><a href="#cb22-5" aria-hidden="true" tabindex="-1"></a>            mlflow.set_tag(<span class="st">"status"</span>, <span class="st">"success"</span>)</span>
<span id="cb22-6"><a href="#cb22-6" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> result</span>
<span id="cb22-7"><a href="#cb22-7" aria-hidden="true" tabindex="-1"></a>        <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:  <span class="co"># still inside the run, so the tags below attach to it</span></span>
<span id="cb22-8"><a href="#cb22-8" aria-hidden="true" tabindex="-1"></a>            mlflow.set_tag(<span class="st">"status"</span>, <span class="st">"failed"</span>)</span>
<span id="cb22-9"><a href="#cb22-9" aria-hidden="true" tabindex="-1"></a>            mlflow.set_tag(<span class="st">"error"</span>, <span class="bu">str</span>(e))</span>
<span id="cb22-10"><a href="#cb22-10" aria-hidden="true" tabindex="-1"></a>            <span class="cf">raise</span>  <span class="co"># re-raise; leaving the with-block then marks the run as FAILED</span></span></code></pre></div></div>
</section>
<section id="model-comparison" class="level3">
<h3 class="anchored" data-anchor-id="model-comparison" id="model-comparison">4. Model Comparison</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb23"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb23-1"><a href="#cb23-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> compare_models(experiment_name<span class="op">=</span><span class="st">"pytorch_experiments"</span>):</span>
<span id="cb23-2"><a href="#cb23-2" aria-hidden="true" tabindex="-1"></a>    <span class="co"># get_experiment_by_name returns None if no experiment has this name</span></span>
<span id="cb23-3"><a href="#cb23-3" aria-hidden="true" tabindex="-1"></a>    experiment <span class="op">=</span> mlflow.get_experiment_by_name(experiment_name)</span>
<span id="cb23-4"><a href="#cb23-4" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> experiment <span class="kw">is</span> <span class="va">None</span>:</span>
<span id="cb23-5"><a href="#cb23-5" aria-hidden="true" tabindex="-1"></a>        <span class="cf">raise</span> <span class="pp">ValueError</span>(<span class="ss">f"Experiment '</span><span class="sc">{</span>experiment_name<span class="sc">}</span><span class="ss">' not found"</span>)</span>
<span id="cb23-6"><a href="#cb23-6" aria-hidden="true" tabindex="-1"></a>    runs <span class="op">=</span> mlflow.search_runs(experiment_ids<span class="op">=</span>[experiment.experiment_id])</span>
<span id="cb23-7"><a href="#cb23-7" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Sort by accuracy (runs that never logged it have NaN and are placed last)</span></span>
<span id="cb23-8"><a href="#cb23-8" aria-hidden="true" tabindex="-1"></a>    best_runs <span class="op">=</span> runs.sort_values(<span class="st">"metrics.accuracy"</span>, ascending<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb23-9"><a href="#cb23-9" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="st">"Top 5 models by accuracy:"</span>)</span>
<span id="cb23-10"><a href="#cb23-10" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(best_runs[[<span class="st">"run_id"</span>, <span class="st">"metrics.accuracy"</span>, <span class="st">"params.learning_rate"</span>]].head())</span></code></pre></div></div>
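<p>Runs that never logged a metric show up in the <code>mlflow.search_runs</code> DataFrame with <code>NaN</code> in that column, and parameter columns such as <code>params.learning_rate</code> hold strings. The sketch below picks the best run while skipping runs without the metric. <code>best_run</code> is a hypothetical helper, and the hand-built DataFrame uses made-up values only to mimic the <code>run_id</code> / <code>metrics.*</code> / <code>params.*</code> column layout.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import pandas as pd


def best_run(runs, metric="accuracy"):
    """Return the row of the run with the highest value of `metric`.

    Hypothetical helper: runs without the metric (NaN) are dropped first.
    """
    column = f"metrics.{metric}"
    scored = runs.dropna(subset=[column])
    if scored.empty:
        raise ValueError(f"No run has logged '{metric}'")
    return scored.loc[scored[column].idxmax()]


# Illustrative stand-in for mlflow.search_runs(...) output (values are made up)
runs = pd.DataFrame({
    "run_id": ["a1", "b2", "c3"],
    "metrics.accuracy": [0.91, None, 0.94],
    "params.learning_rate": ["0.01", "0.1", "0.001"],
})

top = best_run(runs)
print(top["run_id"], top["metrics.accuracy"])  # c3 0.94
</code></pre></div></div>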
</section>
<section id="model-loading-best-practices" class="level3">
<h3 class="anchored" data-anchor-id="model-loading-best-practices" id="model-loading-best-practices">5. Model Loading Best Practices</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb24"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb24-1"><a href="#cb24-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> load_model_safely(model_uri):</span>
<span id="cb24-2"><a href="#cb24-2" aria-hidden="true" tabindex="-1"></a>    <span class="cf">try</span>:</span>
<span id="cb24-3"><a href="#cb24-3" aria-hidden="true" tabindex="-1"></a>        model <span class="op">=</span> mlflow.pytorch.load_model(model_uri)</span>
<span id="cb24-4"><a href="#cb24-4" aria-hidden="true" tabindex="-1"></a>        model.<span class="bu">eval</span>()  <span class="co"># Set to evaluation mode</span></span>
<span id="cb24-5"><a href="#cb24-5" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> model</span>
<span id="cb24-6"><a href="#cb24-6" aria-hidden="true" tabindex="-1"></a>    <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:</span>
<span id="cb24-7"><a href="#cb24-7" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Error loading model: </span><span class="sc">{</span>e<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb24-8"><a href="#cb24-8" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">None</span></span>
<span id="cb24-9"><a href="#cb24-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb24-10"><a href="#cb24-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage (recent MLflow versions deprecate stage URIs such as /Production in favor of aliases, e.g. "models:/MyPyTorchModel@champion")</span></span>
<span id="cb24-11"><a href="#cb24-11" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> load_model_safely(<span class="st">"models:/MyPyTorchModel/Production"</span>)</span>
<span id="cb24-12"><a href="#cb24-12" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> model <span class="kw">is</span> <span class="kw">not</span> <span class="va">None</span>:</span>
<span id="cb24-13"><a href="#cb24-13" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Use model for inference</span></span>
<span id="cb24-14"><a href="#cb24-14" aria-hidden="true" tabindex="-1"></a>    <span class="cf">pass</span></span></code></pre></div></div>
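<p><code>model.eval()</code> matters because layers such as dropout and batch normalization behave differently during training and inference. The self-contained sketch below uses a small stand-in network (not a model from the registry) to show that in evaluation mode, and under <code>torch.no_grad()</code>, repeated forward passes on the same input return identical outputs without tracking gradients.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import torch
import torch.nn as nn

torch.manual_seed(0)

# Small stand-in network with a dropout layer (illustrative, not the registered model)
net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(16, 2))
x = torch.randn(4, 8)

net.train()  # training mode: dropout zeroes each hidden unit with probability 0.5
train_a, train_b = net(x), net(x)
print("train mode, identical outputs:", torch.equal(train_a, train_b))

net.eval()  # evaluation mode: dropout becomes a no-op
with torch.no_grad():  # inference only, so autograd does not record the forward pass
    eval_a, eval_b = net(x), net(x)
print(torch.equal(eval_a, eval_b))  # True
print(eval_a.requires_grad)         # False
</code></pre></div></div>
<p>A model returned by <code>load_model_safely</code> is already switched to evaluation mode, so wrapping each prediction in <code>torch.no_grad()</code> (or <code>torch.inference_mode()</code>) is the remaining step.</p>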
</section>
</section>
<section id="summary" class="level2">
<h2 class="anchored" data-anchor-id="summary" id="summary">Summary</h2>
<p>MLflow covers the main stages of a PyTorch ML workflow:</p>
<ul>
<li><strong>Experiment Tracking</strong>: Log parameters, metrics, and artifacts for every run</li>
<li><strong>Model Management</strong>: Save and reload PyTorch models as run artifacts with <code>mlflow.pytorch</code></li>
<li><strong>Model Registry</strong>: Centralized model store with versioning and lifecycle management</li>
<li><strong>Deployment</strong>: Serve logged models locally (for example with <code>mlflow models serve</code>) or package them for other deployment targets</li>
<li><strong>Reproducibility</strong>: Record the parameters, code version, and dependencies needed to reproduce an experiment</li>
</ul>
<p>Start with basic experiment tracking, then gradually adopt more advanced features like the model registry and deployment capabilities as your ML workflow matures.</p>


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[CLIP Code Guide: Complete Implementation and Usage]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/models/clip-code/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/models/clip-code/</guid>
      <pubDate>Thu, 29 May 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>tutorial</category>
      <category>beginner</category>
      <content:encoded><![CDATA[






<section id="clip-code-guide-complete-implementation-and-usage" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/models/clip-code/clip.png" class="img-fluid"></p>
<section id="introduction-to-clip" class="level2">
<h2 class="anchored" data-anchor-id="introduction-to-clip" id="introduction-to-clip">Introduction to CLIP</h2>
<p>CLIP (Contrastive Language-Image Pre-training) is a neural network developed by OpenAI and introduced in 2021 that learns visual concepts from natural language supervision, using roughly 400 million image-text pairs collected from the internet. Because it embeds images and text in a shared vector space, it can score how well any text description matches an image, which enables zero-shot classification and other multimodal tasks.</p>
<section id="key-features" class="level3">
<h3 class="anchored" data-anchor-id="key-features" id="key-features">Key Features:</h3>
<ul>
<li>Zero-shot image classification</li>
<li>Text-image similarity computation</li>
<li>Multimodal embeddings</li>
<li>Transferable image and text representations for downstream tasks (e.g., linear probes or fine-tuning)</li>
</ul>
</section>
</section>
<section id="architecture-overview" class="level2">
<h2 class="anchored" data-anchor-id="architecture-overview" id="architecture-overview">Architecture Overview</h2>
<p>CLIP consists of two main components:</p>
<ol type="1">
<li><strong>Text Encoder</strong>: Processes text descriptions (typically a Transformer)</li>
<li><strong>Image Encoder</strong>: Processes images (typically a Vision Transformer or ResNet)</li>
</ol>
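<p>Each encoder's output is linearly projected into a shared embedding space and L2-normalized. For a batch of N image-text pairs, the model computes the N × N matrix of cosine similarities, scales it by a learned temperature, and applies a symmetric cross-entropy loss over its rows and columns. Training therefore maximizes the cosine similarity of the N matching pairs while minimizing it for the N<sup>2</sup> - N non-matching pairs.</p>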
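<p>The training objective can be sketched in a few lines of PyTorch. This is a minimal illustration, not OpenAI's training code: <code>clip_contrastive_loss</code> is a hypothetical name, random tensors stand in for encoder outputs, and the temperature is fixed at 0.07 (the value CLIP's learnable temperature is initialized to) instead of being learned.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import torch
import torch.nn.functional as F


def clip_contrastive_loss(image_emb, text_emb, temperature=0.07):
    """Symmetric cross-entropy over the image-text similarity matrix of a batch.

    image_emb, text_emb: (N, D) tensors; row i of each forms a matching pair.
    """
    # L2-normalize so that dot products are cosine similarities
    image_emb = F.normalize(image_emb, dim=-1)
    text_emb = F.normalize(text_emb, dim=-1)

    # (N, N) logits: entry [i, j] scores image i against text j
    logits = image_emb @ text_emb.T / temperature

    # The matching text for image i sits in column i, i.e. on the diagonal
    targets = torch.arange(logits.size(0), device=logits.device)
    loss_image_to_text = F.cross_entropy(logits, targets)
    loss_text_to_image = F.cross_entropy(logits.T, targets)
    return (loss_image_to_text + loss_text_to_image) / 2


torch.manual_seed(0)
text = torch.randn(8, 512)                         # stand-in text embeddings
aligned_images = text + 0.1 * torch.randn(8, 512)  # images close to their captions
random_images = torch.randn(8, 512)                # images unrelated to the captions

aligned_loss = clip_contrastive_loss(aligned_images, text)
random_loss = clip_contrastive_loss(random_images, text)
print(f"aligned: {aligned_loss.item():.4f}, random: {random_loss.item():.4f}")
</code></pre></div></div>
<p>Averaging the image-to-text and text-to-image terms is what makes the loss symmetric: each image must pick out its caption within the batch, and each caption must pick out its image.</p>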
</section>
<section id="setting-up-the-environment" class="level2">
<h2 class="anchored" data-anchor-id="setting-up-the-environment" id="setting-up-the-environment">Setting Up the Environment</h2>
<section id="installing-dependencies" class="level3">
<h3 class="anchored" data-anchor-id="installing-dependencies" id="installing-dependencies">Installing Dependencies</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Basic installation</span></span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install torch torchvision transformers</span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install git+https://github.com/openai/CLIP.git  <span class="co"># Official OpenAI CLIP</span></span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install open-clip-torch  <span class="co"># OpenCLIP (more models)</span></span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a><span class="co"># For development and training</span></span>
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install wandb datasets accelerate</span>
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install matplotlib pillow requests</span></code></pre></div></div>
</section>
<section id="alternative-installation" class="level3">
<h3 class="anchored" data-anchor-id="alternative-installation" id="alternative-installation">Alternative Installation</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb2"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Install from source</span></span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a><span class="fu">git</span> clone https://github.com/openai/CLIP.git</span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a><span class="bu">cd</span> CLIP</span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install <span class="at">-e</span> .</span></code></pre></div></div>
</section>
</section>
<section id="basic-clip-usage" class="level2">
<h2 class="anchored" data-anchor-id="basic-clip-usage" id="basic-clip-usage">Basic CLIP Usage</h2>
<section id="loading-pre-trained-clip-model" class="level3">
<h3 class="anchored" data-anchor-id="loading-pre-trained-clip-model" id="loading-pre-trained-clip-model">1. Loading Pre-trained CLIP Model</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> clip</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> PIL <span class="im">import</span> Image</span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> requests</span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> io <span class="im">import</span> BytesIO</span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Load model and preprocessing</span></span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a>device <span class="op">=</span> <span class="st">"cuda"</span> <span class="cf">if</span> torch.cuda.is_available() <span class="cf">else</span> <span class="st">"cpu"</span></span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a>model, preprocess <span class="op">=</span> clip.load(<span class="st">"ViT-B/32"</span>, device<span class="op">=</span>device)</span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a><span class="co"># Available models: ViT-B/32, ViT-B/16, ViT-L/14, RN50, RN101, RN50x4, etc.</span></span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Available models: </span><span class="sc">{</span>clip<span class="sc">.</span>available_models()<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
</section>
<section id="image-classification-zero-shot" class="level3">
<h3 class="anchored" data-anchor-id="image-classification-zero-shot" id="image-classification-zero-shot">2. Image Classification (Zero-shot)</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> zero_shot_classification(image_path, text_options):</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Load and preprocess image</span></span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a>    image <span class="op">=</span> Image.<span class="bu">open</span>(image_path)</span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a>    image_input <span class="op">=</span> preprocess(image).unsqueeze(<span class="dv">0</span>).to(device)</span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Tokenize text options</span></span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a>    text_inputs <span class="op">=</span> clip.tokenize(text_options).to(device)</span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Get predictions</span></span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> torch.no_grad():</span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a>        image_features <span class="op">=</span> model.encode_image(image_input)</span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a>        text_features <span class="op">=</span> model.encode_text(text_inputs)</span>
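<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a>        <span class="co"># L2-normalize, then softmax over cosine similarities scaled by 100</span></span>
<span id="cb4-14"><a href="#cb4-14" aria-hidden="true" tabindex="-1"></a>        image_features <span class="op">/=</span> image_features.norm(dim<span class="op">=-</span><span class="dv">1</span>, keepdim<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb4-15"><a href="#cb4-15" aria-hidden="true" tabindex="-1"></a>        text_features <span class="op">/=</span> text_features.norm(dim<span class="op">=-</span><span class="dv">1</span>, keepdim<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb4-16"><a href="#cb4-16" aria-hidden="true" tabindex="-1"></a>        similarities <span class="op">=</span> (<span class="fl">100.0</span> <span class="op">*</span> image_features <span class="op">@</span> text_features.T).softmax(dim<span class="op">=-</span><span class="dv">1</span>)</span>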
<span id="cb4-17"><a href="#cb4-17" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Get results</span></span>
<span id="cb4-18"><a href="#cb4-18" aria-hidden="true" tabindex="-1"></a>    values, indices <span class="op">=</span> similarities[<span class="dv">0</span>].topk(<span class="bu">len</span>(text_options))</span>
<span id="cb4-19"><a href="#cb4-19" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-20"><a href="#cb4-20" aria-hidden="true" tabindex="-1"></a>    results <span class="op">=</span> []</span>
<span id="cb4-21"><a href="#cb4-21" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> value, index <span class="kw">in</span> <span class="bu">zip</span>(values, indices):</span>
<span id="cb4-22"><a href="#cb4-22" aria-hidden="true" tabindex="-1"></a>        results.append({</span>
<span id="cb4-23"><a href="#cb4-23" aria-hidden="true" tabindex="-1"></a>            <span class="st">'label'</span>: text_options[index],</span>
<span id="cb4-24"><a href="#cb4-24" aria-hidden="true" tabindex="-1"></a>            <span class="st">'confidence'</span>: value.item()</span>
<span id="cb4-25"><a href="#cb4-25" aria-hidden="true" tabindex="-1"></a>        })</span>
<span id="cb4-26"><a href="#cb4-26" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-27"><a href="#cb4-27" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> results</span>
<span id="cb4-28"><a href="#cb4-28" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-29"><a href="#cb4-29" aria-hidden="true" tabindex="-1"></a><span class="co"># Example usage</span></span>
<span id="cb4-30"><a href="#cb4-30" aria-hidden="true" tabindex="-1"></a>text_options <span class="op">=</span> [<span class="st">"a dog"</span>, <span class="st">"a cat"</span>, <span class="st">"a car"</span>, <span class="st">"a bird"</span>, <span class="st">"a house"</span>]</span>
<span id="cb4-31"><a href="#cb4-31" aria-hidden="true" tabindex="-1"></a>results <span class="op">=</span> zero_shot_classification(<span class="st">"path/to/image.jpg"</span>, text_options)</span>
<span id="cb4-32"><a href="#cb4-32" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-33"><a href="#cb4-33" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> result <span class="kw">in</span> results:</span>
<span id="cb4-34"><a href="#cb4-34" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"</span><span class="sc">{</span>result[<span class="st">'label'</span>]<span class="sc">}</span><span class="ss">: </span><span class="sc">{</span>result[<span class="st">'confidence'</span>]<span class="sc">:.2%}</span><span class="ss">"</span>)</span></code></pre></div></div>
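<p>The scoring step can be separated from model loading so it can be tried on synthetic features. In this sketch <code>zero_shot_probs</code> is a hypothetical helper and the random tensors stand in for <code>model.encode_image</code> / <code>model.encode_text</code> outputs; the default <code>logit_scale=100.0</code> matches the constant used in <code>zero_shot_classification</code>, which is also the upper limit at which CLIP's learned logit scale is clipped.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import torch


def zero_shot_probs(image_features, text_features, logit_scale=100.0):
    """Softmax over scaled cosine similarities between images and candidate labels.

    image_features: (B, D); text_features: (K, D). Returns a (B, K) tensor of probabilities.
    """
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    return (logit_scale * image_features @ text_features.T).softmax(dim=-1)


torch.manual_seed(0)
labels = ["a dog", "a cat", "a car"]
text_features = torch.randn(3, 512)        # stand-ins for encoded label prompts
image_features = text_features[1:2] * 5.0  # synthetic "image" pointing the same way as "a cat"

probs = zero_shot_probs(image_features, text_features)
print(labels[probs.argmax(dim=-1).item()])  # a cat
</code></pre></div></div>
<p>Because both sides are L2-normalized first, rescaling the image features (here by 5) leaves the probabilities unchanged; only the direction of each embedding matters.</p>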
</section>
<section id="text-image-similarity" class="level3">
<h3 class="anchored" data-anchor-id="text-image-similarity" id="text-image-similarity">3. Text-Image Similarity</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> compute_similarity(image_path, text_description):</span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Load image</span></span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a>    image <span class="op">=</span> Image.<span class="bu">open</span>(image_path)</span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a>    image_input <span class="op">=</span> preprocess(image).unsqueeze(<span class="dv">0</span>).to(device)</span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Tokenize text</span></span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a>    text_input <span class="op">=</span> clip.tokenize([text_description]).to(device)</span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Get features</span></span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> torch.no_grad():</span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a>        image_features <span class="op">=</span> model.encode_image(image_input)</span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a>        text_features <span class="op">=</span> model.encode_text(text_input)</span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Normalize features</span></span>
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a>        image_features <span class="op">=</span> image_features <span class="op">/</span> image_features.norm(dim<span class="op">=-</span><span class="dv">1</span>, keepdim<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb5-16"><a href="#cb5-16" aria-hidden="true" tabindex="-1"></a>        text_features <span class="op">=</span> text_features <span class="op">/</span> text_features.norm(dim<span class="op">=-</span><span class="dv">1</span>, keepdim<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb5-17"><a href="#cb5-17" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-18"><a href="#cb5-18" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Compute similarity</span></span>
<span id="cb5-19"><a href="#cb5-19" aria-hidden="true" tabindex="-1"></a>        similarity <span class="op">=</span> (image_features <span class="op">@</span> text_features.T).item()</span>
<span id="cb5-20"><a href="#cb5-20" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-21"><a href="#cb5-21" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> similarity</span>
<span id="cb5-22"><a href="#cb5-22" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-23"><a href="#cb5-23" aria-hidden="true" tabindex="-1"></a><span class="co"># Example usage</span></span>
<span id="cb5-24"><a href="#cb5-24" aria-hidden="true" tabindex="-1"></a>similarity <span class="op">=</span> compute_similarity(<span class="st">"dog.jpg"</span>, <span class="st">"a golden retriever sitting in grass"</span>)</span>
<span id="cb5-25"><a href="#cb5-25" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Similarity: </span><span class="sc">{</span>similarity<span class="sc">:.4f}</span><span class="ss">"</span>)</span></code></pre></div></div>
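<p>Both feature vectors are L2-normalized before the dot product, so the score returned above is their cosine similarity and always lies in [-1, 1]. The minimal sketch below checks this with random stand-in tensors. These are not real CLIP embeddings, and the 512-dimensional size is an arbitrary choice. The sketch then applies the same steps to rank three candidate captions at once.</p>

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Random stand-ins for one image embedding and three caption embeddings (illustration only)
image_features = torch.randn(1, 512)
text_features = torch.randn(3, 512)

# Same steps as compute_similarity: L2-normalize, then take dot products
image_norm = image_features / image_features.norm(dim=-1, keepdim=True)
text_norm = text_features / text_features.norm(dim=-1, keepdim=True)
scores = (image_norm @ text_norm.T).squeeze(0)  # one score per caption, shape [3]

# The normalized dot product equals cosine similarity on the raw vectors
reference = F.cosine_similarity(image_features, text_features, dim=-1)  # broadcasts [1, 512] vs [3, 512]
print(torch.allclose(scores, reference, atol=1e-6))

# Rank the candidate captions from most to least similar
ranking = scores.argsort(descending=True)
print(ranking.tolist())
```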
</section>
</section>
<section id="custom-clip-implementation" class="level2">
<h2 class="anchored" data-anchor-id="custom-clip-implementation" id="custom-clip-implementation">Custom CLIP Implementation</h2>
<section id="basic-clip-architecture" class="level3">
<h3 class="anchored" data-anchor-id="basic-clip-architecture" id="basic-clip-architecture">Basic CLIP Architecture</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn.functional <span class="im">as</span> F</span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> transformers <span class="im">import</span> GPT2Model, GPT2Tokenizer</span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> timm</span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> CLIPModel(nn.Module):</span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, </span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a>                 image_encoder_name<span class="op">=</span><span class="st">'resnet50'</span>,</span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a>                 text_encoder_name<span class="op">=</span><span class="st">'gpt2'</span>,</span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a>                 embed_dim<span class="op">=</span><span class="dv">512</span>,</span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a>                 image_resolution<span class="op">=</span><span class="dv">224</span>,</span>
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a>                 vocab_size<span class="op">=</span><span class="dv">49408</span>):</span>
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb6-15"><a href="#cb6-15" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-16"><a href="#cb6-16" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.embed_dim <span class="op">=</span> embed_dim</span>
<span id="cb6-17"><a href="#cb6-17" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.image_resolution <span class="op">=</span> image_resolution</span>
<span id="cb6-18"><a href="#cb6-18" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-19"><a href="#cb6-19" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Image encoder</span></span>
<span id="cb6-20"><a href="#cb6-20" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.visual <span class="op">=</span> timm.create_model(image_encoder_name, pretrained<span class="op">=</span><span class="va">True</span>, num_classes<span class="op">=</span><span class="dv">0</span>)</span>
<span id="cb6-21"><a href="#cb6-21" aria-hidden="true" tabindex="-1"></a>        visual_dim <span class="op">=</span> <span class="va">self</span>.visual.num_features</span>
<span id="cb6-22"><a href="#cb6-22" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-23"><a href="#cb6-23" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Text encoder</span></span>
<span id="cb6-24"><a href="#cb6-24" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.text_encoder <span class="op">=</span> GPT2Model.from_pretrained(text_encoder_name)</span>
<span id="cb6-25"><a href="#cb6-25" aria-hidden="true" tabindex="-1"></a>        text_dim <span class="op">=</span> <span class="va">self</span>.text_encoder.config.n_embd</span>
<span id="cb6-26"><a href="#cb6-26" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-27"><a href="#cb6-27" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Projection layers</span></span>
<span id="cb6-28"><a href="#cb6-28" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.visual_projection <span class="op">=</span> nn.Linear(visual_dim, embed_dim, bias<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb6-29"><a href="#cb6-29" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.text_projection <span class="op">=</span> nn.Linear(text_dim, embed_dim, bias<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb6-30"><a href="#cb6-30" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-31"><a href="#cb6-31" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Learnable logit scale log(1/temperature), initialized with temperature 0.07 as in the CLIP paper</span></span>
<span id="cb6-32"><a href="#cb6-32" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.logit_scale <span class="op">=</span> nn.Parameter(torch.tensor(<span class="dv">1</span> <span class="op">/</span> <span class="fl">0.07</span>).log())</span>
<span id="cb6-33"><a href="#cb6-33" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-34"><a href="#cb6-34" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.initialize_parameters()</span>
<span id="cb6-35"><a href="#cb6-35" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-36"><a href="#cb6-36" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> initialize_parameters(<span class="va">self</span>):</span>
<span id="cb6-37"><a href="#cb6-37" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Initialize projection layers</span></span>
<span id="cb6-38"><a href="#cb6-38" aria-hidden="true" tabindex="-1"></a>        nn.init.normal_(<span class="va">self</span>.visual_projection.weight, std<span class="op">=</span><span class="fl">0.02</span>)</span>
<span id="cb6-39"><a href="#cb6-39" aria-hidden="true" tabindex="-1"></a>        nn.init.normal_(<span class="va">self</span>.text_projection.weight, std<span class="op">=</span><span class="fl">0.02</span>)</span>
<span id="cb6-40"><a href="#cb6-40" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-41"><a href="#cb6-41" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> encode_image(<span class="va">self</span>, image):</span>
<span id="cb6-42"><a href="#cb6-42" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Extract visual features</span></span>
<span id="cb6-43"><a href="#cb6-43" aria-hidden="true" tabindex="-1"></a>        visual_features <span class="op">=</span> <span class="va">self</span>.visual(image)</span>
<span id="cb6-44"><a href="#cb6-44" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Project to common embedding space</span></span>
<span id="cb6-45"><a href="#cb6-45" aria-hidden="true" tabindex="-1"></a>        image_features <span class="op">=</span> <span class="va">self</span>.visual_projection(visual_features)</span>
<span id="cb6-46"><a href="#cb6-46" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Normalize</span></span>
<span id="cb6-47"><a href="#cb6-47" aria-hidden="true" tabindex="-1"></a>        image_features <span class="op">=</span> F.normalize(image_features, dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb6-48"><a href="#cb6-48" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> image_features</span>
<span id="cb6-49"><a href="#cb6-49" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-50"><a href="#cb6-50" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> encode_text(<span class="va">self</span>, text):</span>
<span id="cb6-51"><a href="#cb6-51" aria-hidden="true" tabindex="-1"></a>        <span class="co"># GPT-2 is causal, so the last real token has attended to the whole caption.</span></span>
<span id="cb6-52"><a href="#cb6-52" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Assumes right padding with the EOS id (GPT-2 has no pad token), so non-EOS ids are real tokens.</span></span>
<span id="cb6-53"><a href="#cb6-53" aria-hidden="true" tabindex="-1"></a>        text_outputs <span class="op">=</span> <span class="va">self</span>.text_encoder(text)</span>
<span id="cb6-54"><a href="#cb6-54" aria-hidden="true" tabindex="-1"></a>        last_idx <span class="op">=</span> ((text <span class="op">!=</span> <span class="va">self</span>.text_encoder.config.eos_token_id).<span class="bu">sum</span>(dim<span class="op">=</span><span class="dv">1</span>) <span class="op">-</span> <span class="dv">1</span>).clamp(<span class="bu">min</span><span class="op">=</span><span class="dv">0</span>)</span>
<span id="cb6-55"><a href="#cb6-55" aria-hidden="true" tabindex="-1"></a>        text_features <span class="op">=</span> text_outputs.last_hidden_state[torch.arange(text.shape[<span class="dv">0</span>], device<span class="op">=</span>text.device), last_idx]</span>
<span id="cb6-56"><a href="#cb6-56" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Project to common embedding space, then L2-normalize</span></span>
<span id="cb6-57"><a href="#cb6-57" aria-hidden="true" tabindex="-1"></a>        text_features <span class="op">=</span> <span class="va">self</span>.text_projection(text_features)</span>
<span id="cb6-58"><a href="#cb6-58" aria-hidden="true" tabindex="-1"></a>        text_features <span class="op">=</span> F.normalize(text_features, dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb6-59"><a href="#cb6-59" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> text_features</span>
<span id="cb6-60"><a href="#cb6-60" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-61"><a href="#cb6-61" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, image, text):</span>
<span id="cb6-62"><a href="#cb6-62" aria-hidden="true" tabindex="-1"></a>        image_features <span class="op">=</span> <span class="va">self</span>.encode_image(image)</span>
<span id="cb6-63"><a href="#cb6-63" aria-hidden="true" tabindex="-1"></a>        text_features <span class="op">=</span> <span class="va">self</span>.encode_text(text)</span>
<span id="cb6-64"><a href="#cb6-64" aria-hidden="true" tabindex="-1"></a>        </span>
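<span id="cb6-65"><a href="#cb6-65" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Compute logits; the CLIP paper caps the scale at 100 to avoid training instability</span></span>
<span id="cb6-66"><a href="#cb6-66" aria-hidden="true" tabindex="-1"></a>        logit_scale <span class="op">=</span> <span class="va">self</span>.logit_scale.exp().clamp(<span class="bu">max</span><span class="op">=</span><span class="dv">100</span>)</span>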
<span id="cb6-67"><a href="#cb6-67" aria-hidden="true" tabindex="-1"></a>        logits_per_image <span class="op">=</span> logit_scale <span class="op">*</span> image_features <span class="op">@</span> text_features.t()</span>
<span id="cb6-68"><a href="#cb6-68" aria-hidden="true" tabindex="-1"></a>        logits_per_text <span class="op">=</span> logits_per_image.t()</span>
<span id="cb6-69"><a href="#cb6-69" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-70"><a href="#cb6-70" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> logits_per_image, logits_per_text</span></code></pre></div></div>
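<p>Instantiating <code>CLIPModel</code> downloads ResNet-50 and GPT-2 weights. The sketch below therefore checks the same forward contract with a hypothetical <code>TinyCLIP</code>, whose encoders are a single linear layer and an embedding table. These are toy stand-ins and not part of the original code. As in <code>CLIPModel</code>, features are projected, L2-normalized, and multiplied by the exponentiated logit scale. For a batch of B pairs, <code>logits_per_image</code> is therefore a [B, B] matrix and <code>logits_per_text</code> is its transpose.</p>

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyCLIP(nn.Module):
    """Hypothetical stand-in mirroring CLIPModel's structure with toy encoders."""

    def __init__(self, image_dim=32, vocab_size=100, text_dim=24, embed_dim=16):
        super().__init__()
        self.visual = nn.Linear(image_dim, 48)                  # toy image encoder
        self.text_encoder = nn.Embedding(vocab_size, text_dim)  # toy text encoder
        self.visual_projection = nn.Linear(48, embed_dim, bias=False)
        self.text_projection = nn.Linear(text_dim, embed_dim, bias=False)
        self.logit_scale = nn.Parameter(torch.tensor(math.log(1 / 0.07)))

    def encode_image(self, image):
        return F.normalize(self.visual_projection(self.visual(image)), dim=-1)

    def encode_text(self, text):
        # Mean-pool token embeddings (a simplification; CLIPModel pools a single token)
        return F.normalize(self.text_projection(self.text_encoder(text).mean(dim=1)), dim=-1)

    def forward(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)
        logits_per_image = self.logit_scale.exp() * image_features @ text_features.t()
        return logits_per_image, logits_per_image.t()


torch.manual_seed(0)
model = TinyCLIP()
images = torch.randn(4, 32)            # batch of 4 "images" (random feature vectors)
texts = torch.randint(0, 100, (4, 7))  # batch of 4 token sequences of length 7
logits_per_image, logits_per_text = model(images, texts)
print(logits_per_image.shape)  # torch.Size([4, 4])
```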
</section>
<section id="clip-loss-function" class="level3">
<h3 class="anchored" data-anchor-id="clip-loss-function" id="clip-loss-function">CLIP Loss Function</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> clip_loss(logits_per_image, logits_per_text):</span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a><span class="co">    Symmetric contrastive loss: the i-th image should match the i-th caption, and vice versa.</span></span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a>    batch_size <span class="op">=</span> logits_per_image.shape[<span class="dv">0</span>]</span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a>    labels <span class="op">=</span> torch.arange(batch_size, device<span class="op">=</span>logits_per_image.device)</span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Cross-entropy loss for both directions</span></span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a>    loss_i <span class="op">=</span> F.cross_entropy(logits_per_image, labels)</span>
<span id="cb7-10"><a href="#cb7-10" aria-hidden="true" tabindex="-1"></a>    loss_t <span class="op">=</span> F.cross_entropy(logits_per_text, labels)</span>
<span id="cb7-11"><a href="#cb7-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-12"><a href="#cb7-12" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Average the losses</span></span>
<span id="cb7-13"><a href="#cb7-13" aria-hidden="true" tabindex="-1"></a>    loss <span class="op">=</span> (loss_i <span class="op">+</span> loss_t) <span class="op">/</span> <span class="dv">2</span></span>
<span id="cb7-14"><a href="#cb7-14" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> loss</span></code></pre></div></div>
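<p>Row i of <code>logits_per_image</code> scores image i against every caption in the batch. The target class is i because matching pairs sit on the diagonal. Two limiting cases make useful sanity checks:</p>
<ul>
<li>When all logits are equal, each softmax is uniform over the B candidates, and both cross-entropy terms equal log B.</li>
<li>When the diagonal dominates, the loss approaches 0.</li>
</ul>
<p>The sketch repeats <code>clip_loss</code> so that it runs on its own.</p>

```python
import math

import torch
import torch.nn.functional as F


def clip_loss(logits_per_image, logits_per_text):
    """Symmetric contrastive loss, as defined above."""
    labels = torch.arange(logits_per_image.shape[0], device=logits_per_image.device)
    loss_i = F.cross_entropy(logits_per_image, labels)
    loss_t = F.cross_entropy(logits_per_text, labels)
    return (loss_i + loss_t) / 2


batch_size = 8

# Uninformative model: all logits equal, so the loss equals log(batch_size)
uniform = torch.zeros(batch_size, batch_size)
print(clip_loss(uniform, uniform.t()).item(), math.log(batch_size))

# Confident, correct model: large logits on the diagonal (the matching pairs)
aligned = 20.0 * torch.eye(batch_size)
print(clip_loss(aligned, aligned.t()).item())
```

<p>If the loss stays near log B during training, the model is not yet separating matching pairs from the rest of the batch.</p>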
</section>
</section>
<section id="training-clip-from-scratch" class="level2">
<h2 class="anchored" data-anchor-id="training-clip-from-scratch" id="training-clip-from-scratch">Training CLIP from Scratch</h2>
<section id="dataset-preparation" class="level3">
<h3 class="anchored" data-anchor-id="dataset-preparation" id="dataset-preparation">Dataset Preparation</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb8"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.utils.data <span class="im">import</span> Dataset, DataLoader</span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> PIL <span class="im">import</span> Image</span>
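<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> json</span>
<span id="cb8-4a"><a href="#cb8-4a" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> os</span>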
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> ImageTextDataset(Dataset):</span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, data_path, image_dir, transform<span class="op">=</span><span class="va">None</span>, tokenizer<span class="op">=</span><span class="va">None</span>, max_length<span class="op">=</span><span class="dv">77</span>):</span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> <span class="bu">open</span>(data_path, <span class="st">'r'</span>) <span class="im">as</span> f:</span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.data <span class="op">=</span> json.load(f)</span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.image_dir <span class="op">=</span> image_dir</span>
<span id="cb8-12"><a href="#cb8-12" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.transform <span class="op">=</span> transform</span>
<span id="cb8-13"><a href="#cb8-13" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.tokenizer <span class="op">=</span> tokenizer</span>
<span id="cb8-14"><a href="#cb8-14" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.max_length <span class="op">=</span> max_length</span>
<span id="cb8-15"><a href="#cb8-15" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-16"><a href="#cb8-16" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__len__</span>(<span class="va">self</span>):</span>
<span id="cb8-17"><a href="#cb8-17" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="bu">len</span>(<span class="va">self</span>.data)</span>
<span id="cb8-18"><a href="#cb8-18" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-19"><a href="#cb8-19" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__getitem__</span>(<span class="va">self</span>, idx):</span>
<span id="cb8-20"><a href="#cb8-20" aria-hidden="true" tabindex="-1"></a>        item <span class="op">=</span> <span class="va">self</span>.data[idx]</span>
<span id="cb8-21"><a href="#cb8-21" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-22"><a href="#cb8-22" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Load image</span></span>
<span id="cb8-23"><a href="#cb8-23" aria-hidden="true" tabindex="-1"></a>        image_path <span class="op">=</span> os.path.join(<span class="va">self</span>.image_dir, item[<span class="st">'image'</span>])</span>
<span id="cb8-24"><a href="#cb8-24" aria-hidden="true" tabindex="-1"></a>        image <span class="op">=</span> Image.<span class="bu">open</span>(image_path).convert(<span class="st">'RGB'</span>)</span>
<span id="cb8-25"><a href="#cb8-25" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-26"><a href="#cb8-26" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="va">self</span>.transform:</span>
<span id="cb8-27"><a href="#cb8-27" aria-hidden="true" tabindex="-1"></a>            image <span class="op">=</span> <span class="va">self</span>.transform(image)</span>
<span id="cb8-28"><a href="#cb8-28" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-29"><a href="#cb8-29" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Tokenize text (a GPT-2 tokenizer needs tokenizer.pad_token = tokenizer.eos_token set first)</span></span>
<span id="cb8-30"><a href="#cb8-30" aria-hidden="true" tabindex="-1"></a>        text <span class="op">=</span> item[<span class="st">'caption'</span>]</span>
<span id="cb8-31"><a href="#cb8-31" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="va">self</span>.tokenizer:</span>
<span id="cb8-32"><a href="#cb8-32" aria-hidden="true" tabindex="-1"></a>            text_tokens <span class="op">=</span> <span class="va">self</span>.tokenizer(</span>
<span id="cb8-33"><a href="#cb8-33" aria-hidden="true" tabindex="-1"></a>                text,</span>
<span id="cb8-34"><a href="#cb8-34" aria-hidden="true" tabindex="-1"></a>                max_length<span class="op">=</span><span class="va">self</span>.max_length,</span>
<span id="cb8-35"><a href="#cb8-35" aria-hidden="true" tabindex="-1"></a>                padding<span class="op">=</span><span class="st">'max_length'</span>,</span>
<span id="cb8-36"><a href="#cb8-36" aria-hidden="true" tabindex="-1"></a>                truncation<span class="op">=</span><span class="va">True</span>,</span>
<span id="cb8-37"><a href="#cb8-37" aria-hidden="true" tabindex="-1"></a>                return_tensors<span class="op">=</span><span class="st">'pt'</span></span>
<span id="cb8-38"><a href="#cb8-38" aria-hidden="true" tabindex="-1"></a>            )[<span class="st">'input_ids'</span>].squeeze(<span class="dv">0</span>)</span>
<span id="cb8-39"><a href="#cb8-39" aria-hidden="true" tabindex="-1"></a>        <span class="cf">else</span>:</span>
<span id="cb8-40"><a href="#cb8-40" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Placeholder only: all-zero token ids (no real tokenization happens here)</span></span>
<span id="cb8-41"><a href="#cb8-41" aria-hidden="true" tabindex="-1"></a>            text_tokens <span class="op">=</span> torch.zeros(<span class="va">self</span>.max_length, dtype<span class="op">=</span>torch.<span class="bu">long</span>)</span>
<span id="cb8-42"><a href="#cb8-42" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-43"><a href="#cb8-43" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> image, text_tokens, text</span></code></pre></div></div>
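<p><code>ImageTextDataset</code> expects <code>data_path</code> to hold a JSON list of records, each with an <code>"image"</code> key (a file name under <code>image_dir</code>) and a <code>"caption"</code> key. The tokenizer call pads every caption to <code>max_length</code>, so the final positions of a short caption are padding rather than text. This matters for any encoder that pools the last token. The sketch below uses made-up token ids. It assumes right padding with id 50256, GPT-2's end-of-text token, which is commonly reused as the pad token because GPT-2 defines none.</p>

```python
import torch

PAD_ID = 50256  # GPT-2's end-of-text id, assumed here to double as the pad token

# Two right-padded captions of length 6 (token ids are made up for illustration)
input_ids = torch.tensor([
    [464, 3290, 318, PAD_ID, PAD_ID, PAD_ID],  # 3 real tokens, then padding
    [32, 3290, 290, 257, 3797, 13],            # 6 real tokens, no padding
])

# Mask of real tokens and the index of the last real token in each row
mask = input_ids != PAD_ID
last_idx = mask.sum(dim=1) - 1
print(last_idx.tolist())  # [2, 5]

# Naively taking position -1 picks a pad token for the first caption
print(input_ids[:, -1].tolist())  # [50256, 13]
```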
</section>
<section id="training-loop" class="level3">
<h3 class="anchored" data-anchor-id="training-loop" id="training-loop">Training Loop</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb9"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> train_clip(model, dataloader, optimizer, scheduler, device, num_epochs):</span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a>    model.train()</span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(num_epochs):</span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a>        total_loss <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a>        num_batches <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> batch_idx, (images, text_tokens, _) <span class="kw">in</span> <span class="bu">enumerate</span>(dataloader):</span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a>            images <span class="op">=</span> images.to(device)</span>
<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a>            text_tokens <span class="op">=</span> text_tokens.to(device)</span>
<span id="cb9-11"><a href="#cb9-11" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb9-12"><a href="#cb9-12" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Forward pass</span></span>
<span id="cb9-13"><a href="#cb9-13" aria-hidden="true" tabindex="-1"></a>            logits_per_image, logits_per_text <span class="op">=</span> model(images, text_tokens)</span>
<span id="cb9-14"><a href="#cb9-14" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb9-15"><a href="#cb9-15" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Compute loss</span></span>
<span id="cb9-16"><a href="#cb9-16" aria-hidden="true" tabindex="-1"></a>            loss <span class="op">=</span> clip_loss(logits_per_image, logits_per_text)</span>
<span id="cb9-17"><a href="#cb9-17" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb9-18"><a href="#cb9-18" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Backward pass</span></span>
<span id="cb9-19"><a href="#cb9-19" aria-hidden="true" tabindex="-1"></a>            optimizer.zero_grad()</span>
<span id="cb9-20"><a href="#cb9-20" aria-hidden="true" tabindex="-1"></a>            loss.backward()</span>
<span id="cb9-21"><a href="#cb9-21" aria-hidden="true" tabindex="-1"></a>            optimizer.step()</span>
<span id="cb9-22"><a href="#cb9-22" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb9-23"><a href="#cb9-23" aria-hidden="true" tabindex="-1"></a>            total_loss <span class="op">+=</span> loss.item()</span>
<span id="cb9-24"><a href="#cb9-24" aria-hidden="true" tabindex="-1"></a>            num_batches <span class="op">+=</span> <span class="dv">1</span></span>
<span id="cb9-25"><a href="#cb9-25" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb9-26"><a href="#cb9-26" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Logging</span></span>
<span id="cb9-27"><a href="#cb9-27" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> batch_idx <span class="op">%</span> <span class="dv">100</span> <span class="op">==</span> <span class="dv">0</span>:</span>
<span id="cb9-28"><a href="#cb9-28" aria-hidden="true" tabindex="-1"></a>                <span class="bu">print</span>(<span class="ss">f'Epoch </span><span class="sc">{</span>epoch<span class="sc">}</span><span class="ss">, Batch </span><span class="sc">{</span>batch_idx<span class="sc">}</span><span class="ss">, Loss: </span><span class="sc">{</span>loss<span class="sc">.</span>item()<span class="sc">:.4f}</span><span class="ss">'</span>)</span>
<span id="cb9-29"><a href="#cb9-29" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-30"><a href="#cb9-30" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Update learning rate</span></span>
<span id="cb9-31"><a href="#cb9-31" aria-hidden="true" tabindex="-1"></a>        scheduler.step()</span>
<span id="cb9-32"><a href="#cb9-32" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-33"><a href="#cb9-33" aria-hidden="true" tabindex="-1"></a>        avg_loss <span class="op">=</span> total_loss <span class="op">/</span> num_batches</span>
<span id="cb9-34"><a href="#cb9-34" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f'Epoch </span><span class="sc">{</span>epoch<span class="sc">}</span><span class="ss"> completed. Average Loss: </span><span class="sc">{</span>avg_loss<span class="sc">:.4f}</span><span class="ss">'</span>)</span>
<span id="cb9-35"><a href="#cb9-35" aria-hidden="true" tabindex="-1"></a></span>
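<span id="cb9-36"><a href="#cb9-36" aria-hidden="true" tabindex="-1"></a><span class="co"># Training setup (data paths and batch size below are placeholders)</span></span>
<span id="cb9-37"><a href="#cb9-37" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torchvision <span class="im">import</span> transforms</span>
<span id="cb9-38"><a href="#cb9-38" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-39"><a href="#cb9-39" aria-hidden="true" tabindex="-1"></a>device <span class="op">=</span> torch.device(<span class="st">'cuda'</span> <span class="cf">if</span> torch.cuda.is_available() <span class="cf">else</span> <span class="st">'cpu'</span>)</span>
<span id="cb9-40"><a href="#cb9-40" aria-hidden="true" tabindex="-1"></a>num_epochs <span class="op">=</span> <span class="dv">100</span></span>
<span id="cb9-41"><a href="#cb9-41" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-42"><a href="#cb9-42" aria-hidden="true" tabindex="-1"></a><span class="co"># GPT-2 ships without a pad token, so reuse EOS for padding</span></span>
<span id="cb9-43"><a href="#cb9-43" aria-hidden="true" tabindex="-1"></a>tokenizer <span class="op">=</span> GPT2Tokenizer.from_pretrained(<span class="st">'gpt2'</span>)</span>
<span id="cb9-44"><a href="#cb9-44" aria-hidden="true" tabindex="-1"></a>tokenizer.pad_token <span class="op">=</span> tokenizer.eos_token</span>
<span id="cb9-45"><a href="#cb9-45" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-46"><a href="#cb9-46" aria-hidden="true" tabindex="-1"></a><span class="co"># ImageNet mean/std, matching the pretrained timm ResNet-50 image encoder</span></span>
<span id="cb9-47"><a href="#cb9-47" aria-hidden="true" tabindex="-1"></a>transform <span class="op">=</span> transforms.Compose([</span>
<span id="cb9-48"><a href="#cb9-48" aria-hidden="true" tabindex="-1"></a>    transforms.Resize(<span class="dv">224</span>),</span>
<span id="cb9-49"><a href="#cb9-49" aria-hidden="true" tabindex="-1"></a>    transforms.CenterCrop(<span class="dv">224</span>),</span>
<span id="cb9-50"><a href="#cb9-50" aria-hidden="true" tabindex="-1"></a>    transforms.ToTensor(),</span>
<span id="cb9-51"><a href="#cb9-51" aria-hidden="true" tabindex="-1"></a>    transforms.Normalize(mean<span class="op">=</span>(<span class="fl">0.485</span>, <span class="fl">0.456</span>, <span class="fl">0.406</span>), std<span class="op">=</span>(<span class="fl">0.229</span>, <span class="fl">0.224</span>, <span class="fl">0.225</span>)),</span>
<span id="cb9-52"><a href="#cb9-52" aria-hidden="true" tabindex="-1"></a>])</span>
<span id="cb9-53"><a href="#cb9-53" aria-hidden="true" tabindex="-1"></a>dataset <span class="op">=</span> ImageTextDataset(<span class="st">'captions.json'</span>, <span class="st">'images/'</span>, transform<span class="op">=</span>transform, tokenizer<span class="op">=</span>tokenizer)</span>
<span id="cb9-54"><a href="#cb9-54" aria-hidden="true" tabindex="-1"></a>dataloader <span class="op">=</span> DataLoader(dataset, batch_size<span class="op">=</span><span class="dv">256</span>, shuffle<span class="op">=</span><span class="va">True</span>, drop_last<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb9-55"><a href="#cb9-55" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-56"><a href="#cb9-56" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> CLIPModel().to(device)</span>
<span id="cb9-57"><a href="#cb9-57" aria-hidden="true" tabindex="-1"></a>optimizer <span class="op">=</span> torch.optim.AdamW(model.parameters(), lr<span class="op">=</span><span class="fl">1e-4</span>, weight_decay<span class="op">=</span><span class="fl">0.01</span>)</span>
<span id="cb9-58"><a href="#cb9-58" aria-hidden="true" tabindex="-1"></a>scheduler <span class="op">=</span> torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max<span class="op">=</span>num_epochs)</span>
<span id="cb9-59"><a href="#cb9-59" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-60"><a href="#cb9-60" aria-hidden="true" tabindex="-1"></a><span class="co"># Start training</span></span>
<span id="cb9-61"><a href="#cb9-61" aria-hidden="true" tabindex="-1"></a>train_clip(model, dataloader, optimizer, scheduler, device, num_epochs<span class="op">=</span>num_epochs)</span></code></pre></div></div>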
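<p>Before a long run, a common sanity check is to overfit one fixed batch. If the contrastive loss cannot fall on eight pairs, something in the pipeline is broken. The sketch below is hypothetical:</p>
<ul>
<li>Random feature vectors stand in for encoder outputs.</li>
<li>Two linear layers play the role of <code>visual_projection</code> and <code>text_projection</code>.</li>
<li>Training uses the same symmetric loss and an AdamW update, as in <code>train_clip</code>.</li>
</ul>

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# One fixed batch of 8 random "image" and "text" feature vectors (stand-ins for encoder outputs)
image_feats = torch.randn(8, 32)
text_feats = torch.randn(8, 32)

# Trainable parts: two projections and the logit scale, as in CLIPModel
visual_projection = nn.Linear(32, 16, bias=False)
text_projection = nn.Linear(32, 16, bias=False)
logit_scale = nn.Parameter(torch.tensor(math.log(1 / 0.07)))
params = list(visual_projection.parameters()) + list(text_projection.parameters()) + [logit_scale]
optimizer = torch.optim.AdamW(params, lr=1e-2)
labels = torch.arange(8)  # matching pairs lie on the diagonal

losses = []
for step in range(200):
    img = F.normalize(visual_projection(image_feats), dim=-1)
    txt = F.normalize(text_projection(text_feats), dim=-1)
    logits_per_image = logit_scale.exp() * img @ txt.t()
    loss = (F.cross_entropy(logits_per_image, labels) + F.cross_entropy(logits_per_image.t(), labels)) / 2

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```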
</section>
</section>
<section id="fine-tuning-clip" class="level2">
<h2 class="anchored" data-anchor-id="fine-tuning-clip" id="fine-tuning-clip">Fine-tuning CLIP</h2>
<section id="domain-specific-fine-tuning" class="level3">
<h3 class="anchored" data-anchor-id="domain-specific-fine-tuning" id="domain-specific-fine-tuning">Domain-Specific Fine-tuning</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb10"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> fine_tune_clip(pretrained_model, dataloader, num_epochs<span class="op">=</span><span class="dv">10</span>, lr<span class="op">=</span><span class="fl">1e-5</span>):</span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Freeze most layers, only fine-tune projection layers</span></span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> param <span class="kw">in</span> pretrained_model.visual.parameters():</span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a>        param.requires_grad <span class="op">=</span> <span class="va">False</span></span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-6"><a href="#cb10-6" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> param <span class="kw">in</span> pretrained_model.text_encoder.parameters():</span>
<span id="cb10-7"><a href="#cb10-7" aria-hidden="true" tabindex="-1"></a>        param.requires_grad <span class="op">=</span> <span class="va">False</span></span>
<span id="cb10-8"><a href="#cb10-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-9"><a href="#cb10-9" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Only train the projection layers and the temperature (logit_scale)</span></span>
<span id="cb10-10"><a href="#cb10-10" aria-hidden="true" tabindex="-1"></a>    optimizer <span class="op">=</span> torch.optim.Adam([</span>
<span id="cb10-11"><a href="#cb10-11" aria-hidden="true" tabindex="-1"></a>        {<span class="st">'params'</span>: pretrained_model.visual_projection.parameters()},</span>
<span id="cb10-12"><a href="#cb10-12" aria-hidden="true" tabindex="-1"></a>        {<span class="st">'params'</span>: pretrained_model.text_projection.parameters()},</span>
<span id="cb10-13"><a href="#cb10-13" aria-hidden="true" tabindex="-1"></a>        {<span class="st">'params'</span>: [pretrained_model.logit_scale]}</span>
<span id="cb10-14"><a href="#cb10-14" aria-hidden="true" tabindex="-1"></a>    ], lr<span class="op">=</span>lr)</span>
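<span id="cb10-15"><a href="#cb10-15" aria-hidden="true" tabindex="-1"></a>    device <span class="op">=</span> <span class="bu">next</span>(pretrained_model.parameters()).device  <span class="co"># train on whatever device the model lives on</span></span>
<span id="cb10-16"><a href="#cb10-16" aria-hidden="true" tabindex="-1"></a>    pretrained_model.train()</span>
<span id="cb10-17"><a href="#cb10-17" aria-hidden="true" tabindex="-1"></a>    pretrained_model.visual.eval()  <span class="co"># frozen backbone: keep its BatchNorm running statistics fixed</span></span>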
<span id="cb10-18"><a href="#cb10-18" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(num_epochs):</span>
<span id="cb10-19"><a href="#cb10-19" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> batch_idx, (images, text_tokens, _) <span class="kw">in</span> <span class="bu">enumerate</span>(dataloader):</span>
<span id="cb10-20"><a href="#cb10-20" aria-hidden="true" tabindex="-1"></a>            images <span class="op">=</span> images.to(device)</span>
<span id="cb10-21"><a href="#cb10-21" aria-hidden="true" tabindex="-1"></a>            text_tokens <span class="op">=</span> text_tokens.to(device)</span>
<span id="cb10-22"><a href="#cb10-22" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb10-23"><a href="#cb10-23" aria-hidden="true" tabindex="-1"></a>            logits_per_image, logits_per_text <span class="op">=</span> pretrained_model(images, text_tokens)</span>
<span id="cb10-24"><a href="#cb10-24" aria-hidden="true" tabindex="-1"></a>            loss <span class="op">=</span> clip_loss(logits_per_image, logits_per_text)</span>
<span id="cb10-25"><a href="#cb10-25" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb10-26"><a href="#cb10-26" aria-hidden="true" tabindex="-1"></a>            optimizer.zero_grad()</span>
<span id="cb10-27"><a href="#cb10-27" aria-hidden="true" tabindex="-1"></a>            loss.backward()</span>
<span id="cb10-28"><a href="#cb10-28" aria-hidden="true" tabindex="-1"></a>            optimizer.step()</span>
<span id="cb10-29"><a href="#cb10-29" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb10-30"><a href="#cb10-30" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> batch_idx <span class="op">%</span> <span class="dv">50</span> <span class="op">==</span> <span class="dv">0</span>:</span>
<span id="cb10-31"><a href="#cb10-31" aria-hidden="true" tabindex="-1"></a>                <span class="bu">print</span>(<span class="ss">f'Fine-tune Epoch </span><span class="sc">{</span>epoch<span class="sc">}</span><span class="ss">, Batch </span><span class="sc">{</span>batch_idx<span class="sc">}</span><span class="ss">, Loss: </span><span class="sc">{</span>loss<span class="sc">.</span>item()<span class="sc">:.4f}</span><span class="ss">'</span>)</span></code></pre></div></div>
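<p>Because the encoders are frozen, it is worth confirming before training that only the projection layers and <code>logit_scale</code> will receive gradients. The sketch below uses a hypothetical <code>ToyDualEncoder</code> that only mimics the attribute names used above (<code>visual</code>, <code>text_encoder</code>, <code>visual_projection</code>, <code>text_projection</code>, <code>logit_scale</code>). <code>freeze</code> and <code>summarize_trainable</code> are illustrative helpers, not part of any library.</p>

```python
import torch
import torch.nn as nn


def freeze(module):
    """Disable gradients for every parameter in `module`."""
    for param in module.parameters():
        param.requires_grad = False


def summarize_trainable(model):
    """Return (trainable, frozen) parameter counts for `model`."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    return trainable, frozen


class ToyDualEncoder(nn.Module):
    """Hypothetical stand-in that mirrors the attribute names used above."""

    def __init__(self):
        super().__init__()
        self.visual = nn.Linear(32, 16)          # stands in for the image encoder
        self.text_encoder = nn.Linear(24, 16)    # stands in for the text encoder
        self.visual_projection = nn.Linear(16, 8, bias=False)
        self.text_projection = nn.Linear(16, 8, bias=False)
        self.logit_scale = nn.Parameter(torch.zeros(()))


model = ToyDualEncoder()
freeze(model.visual)
freeze(model.text_encoder)
print(summarize_trainable(model))  # (257, 928)
```

<p>One design note: <code>pretrained_model.train()</code> also puts the frozen encoders in training mode, so any dropout inside them stays active. Calling <code>.eval()</code> on the frozen submodules after <code>.train()</code> is a common alternative if you want them to behave exactly as they do at inference.</p>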
</section>
</section>
<section id="advanced-applications" class="level2">
<h2 class="anchored" data-anchor-id="advanced-applications" id="advanced-applications">Advanced Applications</h2>
<section id="image-search-with-clip" class="level3">
<h3 class="anchored" data-anchor-id="image-search-with-clip" id="image-search-with-clip">1. Image Search with CLIP</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb11"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> CLIPImageSearch:</span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, model, preprocess):</span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> model</span>
<span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.preprocess <span class="op">=</span> preprocess</span>
<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.image_features <span class="op">=</span> <span class="va">None</span></span>
<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.image_paths <span class="op">=</span> <span class="va">None</span></span>
<span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> index_images(<span class="va">self</span>, image_paths):</span>
<span id="cb11-9"><a href="#cb11-9" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Pre-compute features for all images"""</span></span>
<span id="cb11-10"><a href="#cb11-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.image_paths <span class="op">=</span> image_paths</span>
<span id="cb11-11"><a href="#cb11-11" aria-hidden="true" tabindex="-1"></a>        features <span class="op">=</span> []</span>
<span id="cb11-12"><a href="#cb11-12" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-13"><a href="#cb11-13" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> img_path <span class="kw">in</span> image_paths:</span>
<span id="cb11-14"><a href="#cb11-14" aria-hidden="true" tabindex="-1"></a>            image <span class="op">=</span> Image.<span class="bu">open</span>(img_path)</span>
<span id="cb11-15"><a href="#cb11-15" aria-hidden="true" tabindex="-1"></a>            image_input <span class="op">=</span> <span class="va">self</span>.preprocess(image).unsqueeze(<span class="dv">0</span>).to(device)</span>
<span id="cb11-16"><a href="#cb11-16" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb11-17"><a href="#cb11-17" aria-hidden="true" tabindex="-1"></a>            <span class="cf">with</span> torch.no_grad():</span>
<span id="cb11-18"><a href="#cb11-18" aria-hidden="true" tabindex="-1"></a>                image_feature <span class="op">=</span> <span class="va">self</span>.model.encode_image(image_input)</span>
<span id="cb11-19"><a href="#cb11-19" aria-hidden="true" tabindex="-1"></a>                features.append(image_feature)</span>
<span id="cb11-20"><a href="#cb11-20" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-21"><a href="#cb11-21" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.image_features <span class="op">=</span> torch.cat(features, dim<span class="op">=</span><span class="dv">0</span>)</span>
<span id="cb11-22"><a href="#cb11-22" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.image_features <span class="op">=</span> <span class="va">self</span>.image_features <span class="op">/</span> <span class="va">self</span>.image_features.norm(dim<span class="op">=-</span><span class="dv">1</span>, keepdim<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb11-23"><a href="#cb11-23" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-24"><a href="#cb11-24" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> search(<span class="va">self</span>, query_text, top_k<span class="op">=</span><span class="dv">5</span>):</span>
<span id="cb11-25"><a href="#cb11-25" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Search for images matching the text query"""</span></span>
<span id="cb11-26"><a href="#cb11-26" aria-hidden="true" tabindex="-1"></a>        text_input <span class="op">=</span> clip.tokenize([query_text]).to(device)</span>
<span id="cb11-27"><a href="#cb11-27" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-28"><a href="#cb11-28" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> torch.no_grad():</span>
<span id="cb11-29"><a href="#cb11-29" aria-hidden="true" tabindex="-1"></a>            text_features <span class="op">=</span> <span class="va">self</span>.model.encode_text(text_input)</span>
<span id="cb11-30"><a href="#cb11-30" aria-hidden="true" tabindex="-1"></a>            text_features <span class="op">=</span> text_features <span class="op">/</span> text_features.norm(dim<span class="op">=-</span><span class="dv">1</span>, keepdim<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb11-31"><a href="#cb11-31" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb11-32"><a href="#cb11-32" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Compute similarities</span></span>
<span id="cb11-33"><a href="#cb11-33" aria-hidden="true" tabindex="-1"></a>            similarities <span class="op">=</span> (text_features <span class="op">@</span> <span class="va">self</span>.image_features.T).squeeze(<span class="dv">0</span>)</span>
<span id="cb11-34"><a href="#cb11-34" aria-hidden="true" tabindex="-1"></a>            </span>
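<span id="cb11-35"><a href="#cb11-35" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Get top-k results (k is clamped so topk does not fail on small indexes)</span></span>
<span id="cb11-36"><a href="#cb11-36" aria-hidden="true" tabindex="-1"></a>            top_similarities, top_indices <span class="op">=</span> similarities.topk(<span class="bu">min</span>(top_k, <span class="bu">len</span>(<span class="va">self</span>.image_paths)))</span>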
<span id="cb11-37"><a href="#cb11-37" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-38"><a href="#cb11-38" aria-hidden="true" tabindex="-1"></a>        results <span class="op">=</span> []</span>
<span id="cb11-39"><a href="#cb11-39" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> sim, idx <span class="kw">in</span> <span class="bu">zip</span>(top_similarities, top_indices):</span>
<span id="cb11-40"><a href="#cb11-40" aria-hidden="true" tabindex="-1"></a>            results.append({</span>
<span id="cb11-41"><a href="#cb11-41" aria-hidden="true" tabindex="-1"></a>                <span class="st">'path'</span>: <span class="va">self</span>.image_paths[idx],</span>
<span id="cb11-42"><a href="#cb11-42" aria-hidden="true" tabindex="-1"></a>                <span class="st">'similarity'</span>: sim.item()</span>
<span id="cb11-43"><a href="#cb11-43" aria-hidden="true" tabindex="-1"></a>            })</span>
<span id="cb11-44"><a href="#cb11-44" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-45"><a href="#cb11-45" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> results</span>
<span id="cb11-46"><a href="#cb11-46" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-47"><a href="#cb11-47" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage example</span></span>
<span id="cb11-48"><a href="#cb11-48" aria-hidden="true" tabindex="-1"></a>search_engine <span class="op">=</span> CLIPImageSearch(model, preprocess)</span>
<span id="cb11-49"><a href="#cb11-49" aria-hidden="true" tabindex="-1"></a>search_engine.index_images(list_of_image_paths)</span>
<span id="cb11-50"><a href="#cb11-50" aria-hidden="true" tabindex="-1"></a>results <span class="op">=</span> search_engine.search(<span class="st">"a red sports car"</span>, top_k<span class="op">=</span><span class="dv">10</span>)</span></code></pre></div></div>
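<p>The <code>search</code> method ranks by a plain matrix product, which equals cosine similarity only because both the image and text embeddings are L2-normalized first. The standalone sketch below shows the same ranking on small synthetic vectors. <code>cosine_top_k</code> is a hypothetical helper; it clamps <code>k</code> to the gallery size because <code>torch.topk</code> raises an error when <code>k</code> exceeds the number of items.</p>

```python
import torch
import torch.nn.functional as F


def cosine_top_k(query, gallery, k):
    """Rank gallery rows by cosine similarity to one query vector.

    query:   tensor of shape (dim,)
    gallery: tensor of shape (num_items, dim)
    """
    query = F.normalize(query, dim=-1)       # unit length
    gallery = F.normalize(gallery, dim=-1)   # unit length per row
    scores = gallery @ query                 # (num_items,) cosine similarities
    k = min(k, gallery.shape[0])             # topk fails if k exceeds num_items
    values, indices = scores.topk(k)
    return values.tolist(), indices.tolist()


gallery = torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]])
query = torch.tensor([2.0, 0.1])
print(cosine_top_k(query, gallery, k=5))
```

<p>Because the vectors are normalized, the query's length (2.0 here) has no effect on the ranking; only direction matters, which is the property the image index relies on.</p>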
</section>
<section id="content-based-image-clustering" class="level3">
<h3 class="anchored" data-anchor-id="content-based-image-clustering" id="content-based-image-clustering">2. Content-Based Image Clustering</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb12"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><a href="#cb12-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> sklearn.cluster <span class="im">import</span> KMeans</span>
<span id="cb12-2"><a href="#cb12-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb12-3"><a href="#cb12-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-4"><a href="#cb12-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> cluster_images_by_content(image_paths, n_clusters<span class="op">=</span><span class="dv">5</span>):</span>
<span id="cb12-5"><a href="#cb12-5" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Extract features for all images</span></span>
<span id="cb12-6"><a href="#cb12-6" aria-hidden="true" tabindex="-1"></a>    features <span class="op">=</span> []</span>
<span id="cb12-7"><a href="#cb12-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-8"><a href="#cb12-8" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> img_path <span class="kw">in</span> image_paths:</span>
<span id="cb12-9"><a href="#cb12-9" aria-hidden="true" tabindex="-1"></a>        image <span class="op">=</span> Image.<span class="bu">open</span>(img_path)</span>
<span id="cb12-10"><a href="#cb12-10" aria-hidden="true" tabindex="-1"></a>        image_input <span class="op">=</span> preprocess(image).unsqueeze(<span class="dv">0</span>).to(device)</span>
<span id="cb12-11"><a href="#cb12-11" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-12"><a href="#cb12-12" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> torch.no_grad():</span>
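<span id="cb12-13"><a href="#cb12-13" aria-hidden="true" tabindex="-1"></a>            feature <span class="op">=</span> model.encode_image(image_input)</span>
<span>            <span class="co"># L2-normalize so k-means' Euclidean distances track cosine similarity</span></span>
<span>            feature <span class="op">=</span> feature <span class="op">/</span> feature.norm(dim<span class="op">=-</span><span class="dv">1</span>, keepdim<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb12-14"><a href="#cb12-14" aria-hidden="true" tabindex="-1"></a>            features.append(feature.<span class="bu">float</span>().cpu().numpy())</span>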
<span id="cb12-15"><a href="#cb12-15" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-16"><a href="#cb12-16" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Convert to numpy array</span></span>
<span id="cb12-17"><a href="#cb12-17" aria-hidden="true" tabindex="-1"></a>    features <span class="op">=</span> np.vstack(features)</span>
<span id="cb12-18"><a href="#cb12-18" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-19"><a href="#cb12-19" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Perform clustering</span></span>
<span id="cb12-20"><a href="#cb12-20" aria-hidden="true" tabindex="-1"></a>    kmeans <span class="op">=</span> KMeans(n_clusters<span class="op">=</span>n_clusters, n_init<span class="op">=</span><span class="dv">10</span>, random_state<span class="op">=</span><span class="dv">42</span>)</span>
<span id="cb12-21"><a href="#cb12-21" aria-hidden="true" tabindex="-1"></a>    cluster_labels <span class="op">=</span> kmeans.fit_predict(features)</span>
<span id="cb12-22"><a href="#cb12-22" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-23"><a href="#cb12-23" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Organize results</span></span>
<span id="cb12-24"><a href="#cb12-24" aria-hidden="true" tabindex="-1"></a>    clusters <span class="op">=</span> {}</span>
<span id="cb12-25"><a href="#cb12-25" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> i, label <span class="kw">in</span> <span class="bu">enumerate</span>(cluster_labels):</span>
<span id="cb12-26"><a href="#cb12-26" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> label <span class="kw">not</span> <span class="kw">in</span> clusters:</span>
<span id="cb12-27"><a href="#cb12-27" aria-hidden="true" tabindex="-1"></a>            clusters[label] <span class="op">=</span> []</span>
<span id="cb12-28"><a href="#cb12-28" aria-hidden="true" tabindex="-1"></a>        clusters[label].append(image_paths[i])</span>
<span id="cb12-29"><a href="#cb12-29" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-30"><a href="#cb12-30" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> clusters</span></code></pre></div></div>
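<p>k-means needs the number of clusters up front. One common heuristic is to try several values and keep the one with the highest silhouette score. The sketch below assumes your features are a <code>(num_images, dim)</code> NumPy array such as the stacked CLIP features above. <code>choose_n_clusters</code> is a hypothetical helper, and the synthetic blobs only stand in for real image embeddings.</p>

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score


def choose_n_clusters(features, candidates=range(2, 7), random_state=42):
    """Return (best_k, scores) using the mean silhouette score per candidate k."""
    # Normalize rows so Euclidean k-means distances track cosine similarity
    features = features / np.linalg.norm(features, axis=1, keepdims=True)
    scores = {}
    for k in candidates:
        kmeans = KMeans(n_clusters=k, n_init=10, random_state=random_state)
        labels = kmeans.fit_predict(features)
        scores[k] = silhouette_score(features, labels, metric="cosine")
    best_k = max(scores, key=scores.get)
    return best_k, scores


# Synthetic stand-in: three tight groups of 40 "embeddings" in 16 dimensions
rng = np.random.default_rng(0)
centers = rng.normal(size=(3, 16))
fake_features = np.vstack([c + 0.05 * rng.normal(size=(40, 16)) for c in centers])

best_k, scores = choose_n_clusters(fake_features)
print(best_k, scores)
```

<p>Silhouette scores are only a heuristic; on real photo collections the scores may not show a clear peak, so inspecting a few images from each cluster is still worthwhile.</p>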
</section>
<section id="visual-question-answering" class="level3">
<h3 class="anchored" data-anchor-id="visual-question-answering" id="visual-question-answering">3. Visual Question Answering</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb13"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> visual_qa(image_path, question, answer_choices):</span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Multiple-choice VQA approximated by zero-shot CLIP ranking of the answer choices"""</span></span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a>    image <span class="op">=</span> Image.<span class="bu">open</span>(image_path)</span>
<span id="cb13-4"><a href="#cb13-4" aria-hidden="true" tabindex="-1"></a>    image_input <span class="op">=</span> preprocess(image).unsqueeze(<span class="dv">0</span>).to(device)</span>
<span id="cb13-5"><a href="#cb13-5" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-6"><a href="#cb13-6" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Create prompts combining question with each answer</span></span>
<span id="cb13-7"><a href="#cb13-7" aria-hidden="true" tabindex="-1"></a>    prompts <span class="op">=</span> [<span class="ss">f"Question: </span><span class="sc">{</span>question<span class="sc">}</span><span class="ss"> Answer: </span><span class="sc">{</span>choice<span class="sc">}</span><span class="ss">"</span> <span class="cf">for</span> choice <span class="kw">in</span> answer_choices]</span>
<span id="cb13-8"><a href="#cb13-8" aria-hidden="true" tabindex="-1"></a>    text_inputs <span class="op">=</span> clip.tokenize(prompts).to(device)</span>
<span id="cb13-9"><a href="#cb13-9" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-10"><a href="#cb13-10" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> torch.no_grad():</span>
<span id="cb13-11"><a href="#cb13-11" aria-hidden="true" tabindex="-1"></a>        image_features <span class="op">=</span> model.encode_image(image_input)</span>
<span id="cb13-12"><a href="#cb13-12" aria-hidden="true" tabindex="-1"></a>        text_features <span class="op">=</span> model.encode_text(text_inputs)</span>
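<span id="cb13-13"><a href="#cb13-13" aria-hidden="true" tabindex="-1"></a>        </span>
<span>        <span class="co"># Normalize so the dot product below is a cosine similarity</span></span>
<span>        image_features <span class="op">=</span> image_features <span class="op">/</span> image_features.norm(dim<span class="op">=-</span><span class="dv">1</span>, keepdim<span class="op">=</span><span class="va">True</span>)</span>
<span>        text_features <span class="op">=</span> text_features <span class="op">/</span> text_features.norm(dim<span class="op">=-</span><span class="dv">1</span>, keepdim<span class="op">=</span><span class="va">True</span>)</span>
<span>        </span>
<span id="cb13-14"><a href="#cb13-14" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Scale the cosine similarities and softmax over the answer choices</span></span>
<span id="cb13-15"><a href="#cb13-15" aria-hidden="true" tabindex="-1"></a>        similarities <span class="op">=</span> (<span class="fl">100.0</span> <span class="op">*</span> image_features <span class="op">@</span> text_features.T).softmax(dim<span class="op">=-</span><span class="dv">1</span>)</span>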
<span id="cb13-16"><a href="#cb13-16" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-17"><a href="#cb13-17" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Return the most likely answer</span></span>
<span id="cb13-18"><a href="#cb13-18" aria-hidden="true" tabindex="-1"></a>    best_idx <span class="op">=</span> similarities.argmax().item()</span>
<span id="cb13-19"><a href="#cb13-19" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> answer_choices[best_idx], similarities[<span class="dv">0</span>][best_idx].item()</span>
<span id="cb13-20"><a href="#cb13-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-21"><a href="#cb13-21" aria-hidden="true" tabindex="-1"></a><span class="co"># Example usage</span></span>
<span id="cb13-22"><a href="#cb13-22" aria-hidden="true" tabindex="-1"></a>answer, confidence <span class="op">=</span> visual_qa(</span>
<span id="cb13-23"><a href="#cb13-23" aria-hidden="true" tabindex="-1"></a>    <span class="st">"image.jpg"</span>,</span>
<span id="cb13-24"><a href="#cb13-24" aria-hidden="true" tabindex="-1"></a>    <span class="st">"What color is the car?"</span>,</span>
<span id="cb13-25"><a href="#cb13-25" aria-hidden="true" tabindex="-1"></a>    [<span class="st">"red"</span>, <span class="st">"blue"</span>, <span class="st">"green"</span>, <span class="st">"yellow"</span>, <span class="st">"black"</span>]</span>
<span id="cb13-26"><a href="#cb13-26" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb13-27"><a href="#cb13-27" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Answer: </span><span class="sc">{</span>answer<span class="sc">}</span><span class="ss">, Confidence: </span><span class="sc">{</span>confidence<span class="sc">:.2%}</span><span class="ss">"</span>)</span></code></pre></div></div>
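<p>The <code>100.0</code> factor acts as an inverse temperature. It assumes the embeddings are L2-normalized, so the products are cosine similarities, which can differ only slightly between candidate prompts; CLIP's learned logit scale is also clipped at 100 during training. The sketch below uses made-up similarity values to show how the scale turns small gaps into a peaked distribution. It also shows that the reported "confidence" is only relative to the supplied answer choices, not a calibrated probability.</p>

```python
import torch

# Made-up cosine similarities between one image and three candidate prompts
cos_sims = torch.tensor([0.30, 0.28, 0.25])

unscaled = cos_sims.softmax(dim=-1)          # close to uniform
scaled = (100.0 * cos_sims).softmax(dim=-1)  # sharply peaked on the best match

print(unscaled)
print(scaled)
```

<p>Both distributions pick the same answer; only the spread changes. Adding an unrelated answer choice would redistribute the probability mass even though the image is unchanged.</p>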
</section>
</section>
<section id="performance-optimization" class="level2">
<h2 class="anchored" data-anchor-id="performance-optimization" id="performance-optimization">Performance Optimization</h2>
<section id="batch-processing" class="level3">
<h3 class="anchored" data-anchor-id="batch-processing" id="batch-processing">1. Batch Processing</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb14"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb14-1"><a href="#cb14-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> batch_encode_images(image_paths, batch_size<span class="op">=</span><span class="dv">32</span>):</span>
<span id="cb14-2"><a href="#cb14-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Process images in batches for better efficiency"""</span></span>
<span id="cb14-3"><a href="#cb14-3" aria-hidden="true" tabindex="-1"></a>    all_features <span class="op">=</span> []</span>
<span id="cb14-4"><a href="#cb14-4" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb14-5"><a href="#cb14-5" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">0</span>, <span class="bu">len</span>(image_paths), batch_size):</span>
<span id="cb14-6"><a href="#cb14-6" aria-hidden="true" tabindex="-1"></a>        batch_paths <span class="op">=</span> image_paths[i:i<span class="op">+</span>batch_size]</span>
<span id="cb14-7"><a href="#cb14-7" aria-hidden="true" tabindex="-1"></a>        batch_images <span class="op">=</span> []</span>
<span id="cb14-8"><a href="#cb14-8" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-9"><a href="#cb14-9" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> path <span class="kw">in</span> batch_paths:</span>
<span id="cb14-10"><a href="#cb14-10" aria-hidden="true" tabindex="-1"></a>            image <span class="op">=</span> Image.<span class="bu">open</span>(path)</span>
<span id="cb14-11"><a href="#cb14-11" aria-hidden="true" tabindex="-1"></a>            image_input <span class="op">=</span> preprocess(image)</span>
<span id="cb14-12"><a href="#cb14-12" aria-hidden="true" tabindex="-1"></a>            batch_images.append(image_input)</span>
<span id="cb14-13"><a href="#cb14-13" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-14"><a href="#cb14-14" aria-hidden="true" tabindex="-1"></a>        batch_tensor <span class="op">=</span> torch.stack(batch_images).to(device)</span>
<span id="cb14-15"><a href="#cb14-15" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-16"><a href="#cb14-16" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> torch.no_grad():</span>
<span id="cb14-17"><a href="#cb14-17" aria-hidden="true" tabindex="-1"></a>            batch_features <span class="op">=</span> model.encode_image(batch_tensor)</span>
<span id="cb14-18"><a href="#cb14-18" aria-hidden="true" tabindex="-1"></a>            all_features.append(batch_features.cpu())</span>
<span id="cb14-19"><a href="#cb14-19" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb14-20"><a href="#cb14-20" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> torch.cat(all_features, dim<span class="op">=</span><span class="dv">0</span>)</span></code></pre></div></div>
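<p>The loop above opens and preprocesses every image in the main process, so the GPU sits idle while each batch is decoded. The sketch below moves loading into a <code>torch.utils.data.DataLoader</code> so that worker processes prepare upcoming batches in parallel. <code>ImagePathDataset</code> and <code>encode_images_with_loader</code> are hypothetical names, and <code>encode_fn</code> would typically be <code>model.encode_image</code>.</p>

```python
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset


class ImagePathDataset(Dataset):
    """Hypothetical dataset that opens and preprocesses one image per item."""

    def __init__(self, image_paths, preprocess):
        self.image_paths = list(image_paths)
        self.preprocess = preprocess

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # convert("RGB") guards against grayscale or RGBA files
        with Image.open(self.image_paths[idx]) as image:
            return self.preprocess(image.convert("RGB"))


@torch.no_grad()
def encode_images_with_loader(encode_fn, image_paths, preprocess, device="cpu",
                              batch_size=32, num_workers=4):
    """Encode images in batches while DataLoader workers load upcoming batches."""
    loader = DataLoader(
        ImagePathDataset(image_paths, preprocess),
        batch_size=batch_size,
        shuffle=False,                  # keep features aligned with image_paths
        num_workers=num_workers,
        pin_memory=str(device).startswith("cuda"),
    )
    features = [encode_fn(batch.to(device)).cpu() for batch in loader]
    return torch.cat(features, dim=0)
```

<p>For example: <code>encode_images_with_loader(model.encode_image, image_paths, preprocess, device=device)</code>. With <code>num_workers</code> above zero on platforms that spawn worker processes (Windows, macOS), <code>preprocess</code> must be picklable and the call should sit under an <code>if __name__ == "__main__":</code> guard.</p>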
</section>
<section id="mixed-precision-training" class="level3">
<h3 class="anchored" data-anchor-id="mixed-precision-training" id="mixed-precision-training">2. Mixed Precision Training</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb15"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb15-1"><a href="#cb15-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.cuda.amp <span class="im">import</span> autocast, GradScaler</span>
<span id="cb15-2"><a href="#cb15-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-3"><a href="#cb15-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> train_with_mixed_precision(model, dataloader, optimizer, num_epochs):</span>
<span id="cb15-4"><a href="#cb15-4" aria-hidden="true" tabindex="-1"></a>    scaler <span class="op">=</span> GradScaler()  <span class="co"># model weights must stay float32; autocast does the float16 casting</span></span>
<span id="cb15-5"><a href="#cb15-5" aria-hidden="true" tabindex="-1"></a>    model.train()</span>
<span id="cb15-6"><a href="#cb15-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-7"><a href="#cb15-7" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(num_epochs):</span>
<span id="cb15-8"><a href="#cb15-8" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> images, text_tokens, _ <span class="kw">in</span> dataloader:</span>
<span id="cb15-9"><a href="#cb15-9" aria-hidden="true" tabindex="-1"></a>            images <span class="op">=</span> images.to(device)</span>
<span id="cb15-10"><a href="#cb15-10" aria-hidden="true" tabindex="-1"></a>            text_tokens <span class="op">=</span> text_tokens.to(device)</span>
<span id="cb15-11"><a href="#cb15-11" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb15-12"><a href="#cb15-12" aria-hidden="true" tabindex="-1"></a>            optimizer.zero_grad()</span>
<span id="cb15-13"><a href="#cb15-13" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb15-14"><a href="#cb15-14" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Forward pass with autocast</span></span>
<span id="cb15-15"><a href="#cb15-15" aria-hidden="true" tabindex="-1"></a>            <span class="cf">with</span> autocast():</span>
<span id="cb15-16"><a href="#cb15-16" aria-hidden="true" tabindex="-1"></a>                logits_per_image, logits_per_text <span class="op">=</span> model(images, text_tokens)</span>
<span id="cb15-17"><a href="#cb15-17" aria-hidden="true" tabindex="-1"></a>                loss <span class="op">=</span> clip_loss(logits_per_image, logits_per_text)</span>
<span id="cb15-18"><a href="#cb15-18" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb15-19"><a href="#cb15-19" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Backward pass with scaling</span></span>
<span id="cb15-20"><a href="#cb15-20" aria-hidden="true" tabindex="-1"></a>            scaler.scale(loss).backward()</span>
<span id="cb15-21"><a href="#cb15-21" aria-hidden="true" tabindex="-1"></a>            scaler.step(optimizer)</span>
<span id="cb15-22"><a href="#cb15-22" aria-hidden="true" tabindex="-1"></a>            scaler.update()</span></code></pre></div></div>
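<p>If you add gradient clipping to a loop like the one above, the gradients must be unscaled first. Otherwise the clipping threshold is compared against gradients that are still multiplied by the loss scale. The sketch below shows one training step with that ordering. <code>amp_train_step</code> is a hypothetical helper; <code>scaler</code> is assumed to be a <code>GradScaler</code> as in the block above, and <code>device_type</code>/<code>dtype</code> are forwarded to <code>torch.autocast</code>.</p>

```python
import torch


def amp_train_step(model, inputs, targets, loss_fn, optimizer, scaler,
                   device_type="cuda", dtype=torch.float16, max_grad_norm=1.0):
    """Run one mixed-precision optimizer step with gradient clipping."""
    optimizer.zero_grad(set_to_none=True)

    # Forward pass and loss under autocast
    with torch.autocast(device_type=device_type, dtype=dtype):
        loss = loss_fn(model(inputs), targets)

    scaler.scale(loss).backward()

    # Gradients are still multiplied by the loss scale at this point;
    # unscale them so max_grad_norm applies to the true gradient norm.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

    scaler.step(optimizer)  # skips the update if inf/NaN gradients were found
    scaler.update()
    return loss.item()
```

<p>On hardware that supports bfloat16, <code>dtype=torch.bfloat16</code> is an option that usually needs no loss scaling, since bfloat16 has the same exponent range as float32; the scaler can then be created with <code>enabled=False</code>, which turns its methods into pass-throughs.</p>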
</section>
<section id="model-quantization" class="level3">
<h3 class="anchored" data-anchor-id="model-quantization" id="model-quantization">3. Model Quantization</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb16"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb16-1"><a href="#cb16-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb16-2"><a href="#cb16-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.quantization <span class="im">as</span> quantization</span>
<span id="cb16-3"><a href="#cb16-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-4"><a href="#cb16-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> quantize_clip_model(model):</span>
<span id="cb16-5"><a href="#cb16-5" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Quantize CLIP's nn.Linear layers to int8 for CPU inference.</span></span>
<span id="cb16-6"><a href="#cb16-6" aria-hidden="true" tabindex="-1"></a><span class="co"></span></span>
<span id="cb16-7"><a href="#cb16-7" aria-hidden="true" tabindex="-1"></a><span class="co">    Dynamic quantization stores weights as int8 and quantizes activations on</span></span>
<span id="cb16-8"><a href="#cb16-8" aria-hidden="true" tabindex="-1"></a><span class="co">    the fly, so unlike static prepare/convert it needs no calibration data and</span></span>
<span id="cb16-9"><a href="#cb16-9" aria-hidden="true" tabindex="-1"></a><span class="co">    no QuantStub/DeQuantStub wrappers inside the model.</span></span>
<span id="cb16-10"><a href="#cb16-10" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb16-11"><a href="#cb16-11" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Dynamic quantization runs on CPU and expects float32 weights (converts model in place)</span></span>
<span id="cb16-12"><a href="#cb16-12" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> model.<span class="bu">float</span>().cpu().<span class="bu">eval</span>()</span>
<span id="cb16-13"><a href="#cb16-13" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-14"><a href="#cb16-14" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Swap every nn.Linear for a dynamically quantized int8 version (returns a copy)</span></span>
<span id="cb16-15"><a href="#cb16-15" aria-hidden="true" tabindex="-1"></a>    model_quantized <span class="op">=</span> quantization.quantize_dynamic(</span>
<span id="cb16-16"><a href="#cb16-16" aria-hidden="true" tabindex="-1"></a>        model, {torch.nn.Linear}, dtype<span class="op">=</span>torch.qint8</span>
<span id="cb16-17"><a href="#cb16-17" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb16-18"><a href="#cb16-18" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-19"><a href="#cb16-19" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> model_quantized</span></code></pre></div></div>
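<p>A quick way to check what quantization buys is to compare serialized sizes. The sketch below applies <code>torch.quantization.quantize_dynamic</code>, which needs no calibration step, to a small stand-in MLP rather than a real CLIP checkpoint; the layer sizes are arbitrary. Since int8 weights take a quarter of the bytes of float32 weights, the quantized <code>state_dict</code> should come out at roughly a quarter of the original size, plus small overheads. <code>serialized_size_bytes</code> is a hypothetical helper.</p>

```python
import io

import torch
import torch.nn as nn


def serialized_size_bytes(model):
    """Size in bytes of the model's state_dict when saved with torch.save."""
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    return buffer.getbuffer().nbytes


# Stand-in for a transformer MLP block; quantize_dynamic returns a quantized copy
float_model = nn.Sequential(nn.Linear(512, 512), nn.GELU(), nn.Linear(512, 512)).eval()
int8_model = torch.quantization.quantize_dynamic(float_model, {nn.Linear}, dtype=torch.qint8)

print(serialized_size_bytes(float_model), serialized_size_bytes(int8_model))
```

<p>Size is only part of the picture: measure accuracy on held-out data (for example zero-shot accuracy on a few classes) before deploying, since quantization error can shift similarity rankings.</p>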
</section>
</section>
<section id="common-issues-and-solutions" class="level2">
<h2 class="anchored" data-anchor-id="common-issues-and-solutions" id="common-issues-and-solutions">Common Issues and Solutions</h2>
<section id="memory-management" class="level3">
<h3 class="anchored" data-anchor-id="memory-management" id="memory-management">1. Memory Management</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb17"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb17-1"><a href="#cb17-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Release unused blocks cached by PyTorch's CUDA allocator (tensors still referenced are not freed)</span></span>
<span id="cb17-2"><a href="#cb17-2" aria-hidden="true" tabindex="-1"></a>torch.cuda.empty_cache()</span>
<span id="cb17-3"><a href="#cb17-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-4"><a href="#cb17-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Use gradient checkpointing for large models: activations are recomputed in backward to save memory</span></span>
<span id="cb17-5"><a href="#cb17-5" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> enable_gradient_checkpointing(model):</span>
<span id="cb17-6"><a href="#cb17-6" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> <span class="bu">hasattr</span>(model.visual, <span class="st">'set_grad_checkpointing'</span>):  <span class="co"># e.g. open_clip vision towers</span></span>
<span id="cb17-7"><a href="#cb17-7" aria-hidden="true" tabindex="-1"></a>        model.visual.set_grad_checkpointing(<span class="va">True</span>)</span>
<span id="cb17-8"><a href="#cb17-8" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> <span class="bu">hasattr</span>(model.text_encoder, <span class="st">'gradient_checkpointing_enable'</span>):  <span class="co"># e.g. Hugging Face Transformers models</span></span>
<span id="cb17-9"><a href="#cb17-9" aria-hidden="true" tabindex="-1"></a>        model.text_encoder.gradient_checkpointing_enable()</span></code></pre></div></div>
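<p>The helper above only has an effect when the encoders expose a checkpointing switch (<code>set_grad_checkpointing</code> in open_clip's vision transformer, <code>gradient_checkpointing_enable</code> on Hugging Face Transformers models); otherwise it silently does nothing. For a custom encoder, blocks can be wrapped manually with <code>torch.utils.checkpoint.checkpoint</code>, which discards intermediate activations during the forward pass and recomputes them during backward. <code>CheckpointedStack</code> below is a hypothetical example module, not part of any CLIP implementation.</p>

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointedStack(nn.Module):
    """Hypothetical block stack that recomputes activations in the backward pass."""

    def __init__(self, dim=64, depth=4, use_checkpoint=True):
        super().__init__()
        self.blocks = nn.ModuleList(
            nn.Sequential(nn.Linear(dim, dim), nn.GELU()) for _ in range(depth)
        )
        self.use_checkpoint = use_checkpoint

    def forward(self, x):
        for block in self.blocks:
            if self.use_checkpoint and self.training:
                # Activations inside `block` are not kept; they are recomputed in backward
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x


model = CheckpointedStack(dim=64, depth=4).train()
loss = model(torch.randn(8, 64)).pow(2).mean()
loss.backward()
```

<p>Checkpointing does not change the gradients, only peak activation memory and compute time, so it can be toggled freely while debugging out-of-memory errors.</p>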
</section>
<section id="handling-different-image-sizes" class="level3">
<h3 class="anchored" data-anchor-id="handling-different-image-sizes" id="handling-different-image-sizes">2. Handling Different Image Sizes</h3>
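<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb18"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb18-1"><a href="#cb18-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torchvision <span class="im">import</span> transforms</span>
<span id="cb18-2"><a href="#cb18-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-3"><a href="#cb18-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Normalization statistics used by OpenAI's CLIP preprocessing (not the ImageNet values)</span></span>
<span id="cb18-4"><a href="#cb18-4" aria-hidden="true" tabindex="-1"></a>CLIP_MEAN <span class="op">=</span> [<span class="fl">0.48145466</span>, <span class="fl">0.4578275</span>, <span class="fl">0.40821073</span>]</span>
<span id="cb18-5"><a href="#cb18-5" aria-hidden="true" tabindex="-1"></a>CLIP_STD <span class="op">=</span> [<span class="fl">0.26862954</span>, <span class="fl">0.26130258</span>, <span class="fl">0.27577711</span>]</span>
<span id="cb18-6"><a href="#cb18-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-7"><a href="#cb18-7" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> _to_rgb(img):</span>
<span id="cb18-8"><a href="#cb18-8" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Grayscale, RGBA, and palette images become 3-channel RGB</span></span>
<span id="cb18-9"><a href="#cb18-9" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> img.convert(<span class="st">'RGB'</span>)</span>
<span id="cb18-10"><a href="#cb18-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-11"><a href="#cb18-11" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> create_adaptive_transform(target_size<span class="op">=</span><span class="dv">224</span>):</span>
<span id="cb18-12"><a href="#cb18-12" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> transforms.Compose([</span>
<span id="cb18-13"><a href="#cb18-13" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Resize the shorter side to target_size, preserving the aspect ratio</span></span>
<span id="cb18-14"><a href="#cb18-14" aria-hidden="true" tabindex="-1"></a>        transforms.Resize(target_size, interpolation<span class="op">=</span>transforms.InterpolationMode.BICUBIC),</span>
<span id="cb18-15"><a href="#cb18-15" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Crop the central target_size x target_size square</span></span>
<span id="cb18-16"><a href="#cb18-16" aria-hidden="true" tabindex="-1"></a>        transforms.CenterCrop(target_size),</span>
<span id="cb18-17"><a href="#cb18-17" aria-hidden="true" tabindex="-1"></a>        transforms.Lambda(_to_rgb),</span>
<span id="cb18-18"><a href="#cb18-18" aria-hidden="true" tabindex="-1"></a>        transforms.ToTensor(),</span>
<span id="cb18-19"><a href="#cb18-19" aria-hidden="true" tabindex="-1"></a>        transforms.Normalize(mean<span class="op">=</span>CLIP_MEAN, std<span class="op">=</span>CLIP_STD),</span>
<span id="cb18-20"><a href="#cb18-20" aria-hidden="true" tabindex="-1"></a>    ])</span></code></pre></div></div>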
</section>
<section id="text-preprocessing" class="level3">
<h3 class="anchored" data-anchor-id="text-preprocessing" id="text-preprocessing">3. Text Preprocessing</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb19"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb19-1"><a href="#cb19-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> re</span>
<span id="cb19-2"><a href="#cb19-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-3"><a href="#cb19-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> preprocess_text(text, max_length<span class="op">=</span><span class="dv">77</span>):</span>
<span id="cb19-4"><a href="#cb19-4" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Clean and preprocess text for CLIP"""</span></span>
<span id="cb19-5"><a href="#cb19-5" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Remove special characters and extra whitespace</span></span>
<span id="cb19-6"><a href="#cb19-6" aria-hidden="true" tabindex="-1"></a>    text <span class="op">=</span> re.sub(<span class="vs">r'</span><span class="pp">[^</span><span class="dv">\w\s</span><span class="pp">]</span><span class="vs">'</span>, <span class="st">''</span>, text)</span>
<span id="cb19-7"><a href="#cb19-7" aria-hidden="true" tabindex="-1"></a>    text <span class="op">=</span> <span class="st">' '</span>.join(text.split())</span>
<span id="cb19-8"><a href="#cb19-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb19-9"><a href="#cb19-9" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Truncate if too long. This is approximate: CLIP's BPE tokenizer can split</span></span>
<span id="cb19-10"><a href="#cb19-10" aria-hidden="true" tabindex="-1"></a>    <span class="co"># one word into several tokens, so the token count may still exceed max_length.</span></span>
<span id="cb19-11"><a href="#cb19-11" aria-hidden="true" tabindex="-1"></a>    <span class="co"># With OpenAI's clip package, clip.tokenize(texts, truncate=True) enforces the limit.</span></span>
<span id="cb19-12"><a href="#cb19-12" aria-hidden="true" tabindex="-1"></a>    words <span class="op">=</span> text.split()</span>
<span id="cb19-13"><a href="#cb19-13" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> <span class="bu">len</span>(words) <span class="op">&gt;</span> max_length <span class="op">-</span> <span class="dv">2</span>:  <span class="co"># Reserve 2 slots for the start/end tokens</span></span>
<span id="cb19-14"><a href="#cb19-14" aria-hidden="true" tabindex="-1"></a>        text <span class="op">=</span> <span class="st">' '</span>.join(words[:max_length<span class="op">-</span><span class="dv">2</span>])</span>
<span id="cb19-15"><a href="#cb19-15" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb19-16"><a href="#cb19-16" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> text</span></code></pre></div></div>
</section>
<section id="model-evaluation-utilities" class="level3">
<h3 class="anchored" data-anchor-id="model-evaluation-utilities" id="model-evaluation-utilities">4. Model Evaluation Utilities</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb20"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb20-1"><a href="#cb20-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb20-2"><a href="#cb20-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> clip  <span class="co"># OpenAI's CLIP package</span></span>
<span id="cb20-3"><a href="#cb20-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb20-4"><a href="#cb20-4" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> evaluate_zero_shot_accuracy(model, preprocess, test_loader, class_names):</span>
<span id="cb20-5"><a href="#cb20-5" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Evaluate zero-shot classification accuracy"""</span></span>
<span id="cb20-6"><a href="#cb20-6" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Assumes test_loader yields images already transformed with preprocess</span></span>
<span id="cb20-7"><a href="#cb20-7" aria-hidden="true" tabindex="-1"></a>    model.<span class="bu">eval</span>()</span>
<span id="cb20-8"><a href="#cb20-8" aria-hidden="true" tabindex="-1"></a>    device <span class="op">=</span> <span class="bu">next</span>(model.parameters()).device  <span class="co"># run on the model's device</span></span>
<span id="cb20-9"><a href="#cb20-9" aria-hidden="true" tabindex="-1"></a>    correct <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb20-10"><a href="#cb20-10" aria-hidden="true" tabindex="-1"></a>    total <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb20-11"><a href="#cb20-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb20-12"><a href="#cb20-12" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Encode class names</span></span>
<span id="cb20-13"><a href="#cb20-13" aria-hidden="true" tabindex="-1"></a>    text_inputs <span class="op">=</span> clip.tokenize([<span class="ss">f"a photo of a </span><span class="sc">{</span>name<span class="sc">}</span><span class="ss">"</span> <span class="cf">for</span> name <span class="kw">in</span> class_names]).to(device)</span>
<span id="cb20-14"><a href="#cb20-14" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb20-15"><a href="#cb20-15" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> torch.no_grad():</span>
<span id="cb20-16"><a href="#cb20-16" aria-hidden="true" tabindex="-1"></a>        text_features <span class="op">=</span> model.encode_text(text_inputs)</span>
<span id="cb20-17"><a href="#cb20-17" aria-hidden="true" tabindex="-1"></a>        text_features <span class="op">=</span> text_features <span class="op">/</span> text_features.norm(dim<span class="op">=-</span><span class="dv">1</span>, keepdim<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb20-18"><a href="#cb20-18" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb20-19"><a href="#cb20-19" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> images, labels <span class="kw">in</span> test_loader:</span>
<span id="cb20-20"><a href="#cb20-20" aria-hidden="true" tabindex="-1"></a>            images <span class="op">=</span> images.to(device)</span>
<span id="cb20-21"><a href="#cb20-21" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb20-22"><a href="#cb20-22" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Encode images</span></span>
<span id="cb20-23"><a href="#cb20-23" aria-hidden="true" tabindex="-1"></a>            image_features <span class="op">=</span> model.encode_image(images)</span>
<span id="cb20-24"><a href="#cb20-24" aria-hidden="true" tabindex="-1"></a>            image_features <span class="op">=</span> image_features <span class="op">/</span> image_features.norm(dim<span class="op">=-</span><span class="dv">1</span>, keepdim<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb20-25"><a href="#cb20-25" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb20-26"><a href="#cb20-26" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Compute similarities</span></span>
<span id="cb20-27"><a href="#cb20-27" aria-hidden="true" tabindex="-1"></a>            similarities <span class="op">=</span> (<span class="fl">100.0</span> <span class="op">*</span> image_features <span class="op">@</span> text_features.T).softmax(dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb20-28"><a href="#cb20-28" aria-hidden="true" tabindex="-1"></a>            predictions <span class="op">=</span> similarities.argmax(dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb20-29"><a href="#cb20-29" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb20-30"><a href="#cb20-30" aria-hidden="true" tabindex="-1"></a>            correct <span class="op">+=</span> (predictions <span class="op">==</span> labels.to(device)).<span class="bu">sum</span>().item()</span>
<span id="cb20-31"><a href="#cb20-31" aria-hidden="true" tabindex="-1"></a>            total <span class="op">+=</span> labels.size(<span class="dv">0</span>)</span>
<span id="cb20-32"><a href="#cb20-32" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb20-33"><a href="#cb20-33" aria-hidden="true" tabindex="-1"></a>    accuracy <span class="op">=</span> correct <span class="op">/</span> total</span>
<span id="cb20-34"><a href="#cb20-34" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> accuracy</span>
<span id="cb20-35"><a href="#cb20-35" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb20-36"><a href="#cb20-36" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage</span></span>
<span id="cb20-37"><a href="#cb20-37" aria-hidden="true" tabindex="-1"></a>accuracy <span class="op">=</span> evaluate_zero_shot_accuracy(model, preprocess, test_loader, class_names)</span>
<span id="cb20-38"><a href="#cb20-38" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Zero-shot accuracy: </span><span class="sc">{</span>accuracy<span class="sc">:.2%}</span><span class="ss">"</span>)</span></code></pre></div></div>
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>This guide covers the essential aspects of working with CLIP, from basic usage to advanced implementations. Key takeaways:</p>
<ol type="1">
<li><strong>Start Simple</strong>: Use pre-trained models for most applications</li>
<li><strong>Understand the Architecture</strong>: CLIP’s power comes from joint text-image training</li>
<li><strong>Optimize for Your Use Case</strong>: Fine-tune or customize based on your specific needs</li>
<li><strong>Measure Performance</strong>: Evaluate on held-out data (for example, zero-shot accuracy) and profile memory use and inference speed</li>
<li><strong>Handle Edge Cases</strong>: Implement robust preprocessing and error handling</li>
</ol>
<p>For production deployments, consider:</p>
<ul>
<li>Model quantization for faster inference</li>
<li>Batch processing for efficiency</li>
<li>Proper error handling and fallbacks</li>
<li>Monitoring and logging for performance tracking</li>
</ul>
<p>The field of multimodal AI is rapidly evolving, so stay updated with the latest research and implementations to leverage CLIP’s full potential in your applications.</p>


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Self-Supervised Learning: Training AI Without Labels]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/models/self-supervised-explained/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/models/self-supervised-explained/</guid>
      <pubDate>Thu, 29 May 2025 00:00:00 GMT</pubDate>
      
      <category>research</category>
      <category>beginner</category>
      <content:encoded><![CDATA[






<section id="self-supervised-learning-training-ai-without-labels" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/models/self-supervised-explained/selfsupervised.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>Machine learning has traditionally relied on large amounts of labeled data to train models effectively. However, acquiring high-quality labeled datasets is expensive, time-consuming, and often impractical for many real-world applications. Self-supervised learning addresses these challenges by learning meaningful representations directly from unlabeled data.</p>
</section>
<section id="what-is-self-supervised-learning" class="level2">
<h2 class="anchored" data-anchor-id="what-is-self-supervised-learning" id="what-is-self-supervised-learning">What is Self-Supervised Learning?</h2>
<p>Self-supervised learning is a machine learning approach where models learn to understand and represent data by predicting parts of the input from other parts, without requiring external labels or human annotations. Instead of relying on manually created labels, the model generates its own supervisory signal from the inherent structure and patterns within the data.</p>
<p>The key insight behind self-supervised learning is that data contains rich internal structure and relationships that can serve as teaching signals. By designing tasks that require the model to understand these relationships, we can train systems that develop sophisticated representations of the underlying data distribution.</p>
</section>
<section id="core-principles-and-mechanisms" class="level2">
<h2 class="anchored" data-anchor-id="core-principles-and-mechanisms" id="core-principles-and-mechanisms">Core Principles and Mechanisms</h2>
<p>Self-supervised learning operates on several fundamental principles that distinguish it from traditional supervised learning approaches.</p>
<section id="pretext-tasks" class="level3">
<h3 class="anchored" data-anchor-id="pretext-tasks" id="pretext-tasks">Pretext Tasks</h3>
<p>The foundation of self-supervised learning lies in carefully designed pretext tasks. These are artificial objectives created from the data itself, such as predicting missing words in a sentence, reconstructing masked portions of an image, or forecasting future frames in a video sequence. Although these tasks may seem simple, solving them well forces the model to learn the data’s underlying structure.</p>
</section>
<section id="representation-learning" class="level3">
<h3 class="anchored" data-anchor-id="representation-learning" id="representation-learning">Representation Learning</h3>
<p>Rather than training models for specific end tasks, self-supervised learning focuses on learning general-purpose representations that capture the essential characteristics of the data. These learned representations can then be transferred to downstream tasks with minimal additional training, making them highly versatile and efficient.</p>
</section>
<section id="data-efficiency" class="level3">
<h3 class="anchored" data-anchor-id="data-efficiency" id="data-efficiency">Data Efficiency</h3>
<p>By leveraging the vast amounts of unlabeled data available in the real world, self-supervised pretraining can let a model match, and on some benchmarks exceed, fully supervised baselines while requiring significantly fewer labeled examples for fine-tuning on a specific task.</p>
</section>
</section>
<section id="training-methodology" class="level2">
<h2 class="anchored" data-anchor-id="training-methodology" id="training-methodology">Training Methodology</h2>
<p>The training process for self-supervised learning involves several distinct phases, each designed to maximize the model’s ability to extract meaningful patterns from unlabeled data.</p>
<section id="phase-1-pretext-task-design" class="level3">
<h3 class="anchored" data-anchor-id="phase-1-pretext-task-design" id="phase-1-pretext-task-design">Phase 1: Pretext Task Design</h3>
<p>The success of self-supervised learning heavily depends on the choice and design of pretext tasks. Effective pretext tasks must strike a delicate balance: they should be challenging enough to require sophisticated understanding of the data, yet solvable enough to provide clear learning signals.</p>
<p>In natural language processing, common pretext tasks include masked language modeling, where random words in sentences are hidden and the model must predict them based on context. For computer vision, popular approaches include image inpainting, where portions of images are masked and must be reconstructed, or contrastive learning, where the model learns to distinguish between similar and dissimilar image pairs.</p>
</section>
<section id="phase-2-architecture-selection" class="level3">
<h3 class="anchored" data-anchor-id="phase-2-architecture-selection" id="phase-2-architecture-selection">Phase 2: Architecture Selection</h3>
<p>Self-supervised models typically pair a standard backbone with components suited to the chosen pretext task. Transformer architectures have proven particularly effective for language tasks because they capture long-range dependencies and contextual relationships. For vision tasks, convolutional neural networks, vision transformers, and hybrid architectures are all common, depending on the specific requirements.</p>
<p>The architecture must be capable of processing the input data format while being flexible enough to handle the artificial constraints imposed by the pretext task. Many self-supervised models use encoder-decoder structures, where the encoder learns compressed representations and the decoder reconstructs or predicts the target output.</p>
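<p>The encoder-decoder pattern can be sketched in PyTorch as a small denoising autoencoder: the decoder learns to reconstruct the clean input from a corrupted copy, so the data supplies its own targets. This is a minimal illustration under assumed settings; the <code>TinyAutoencoder</code> class, layer sizes, noise level, and random placeholder data are hypothetical choices, not a prescribed architecture.</p>

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class TinyAutoencoder(nn.Module):
    """Hypothetical encoder-decoder: compress 64 -> 16 dims, then reconstruct 16 -> 64."""
    def __init__(self, in_dim=64, code_dim=16):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(in_dim, 32), nn.ReLU(), nn.Linear(32, code_dim))
        self.decoder = nn.Sequential(nn.Linear(code_dim, 32), nn.ReLU(), nn.Linear(32, in_dim))

    def forward(self, x):
        return self.decoder(self.encoder(x))

model = TinyAutoencoder()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
clean = torch.randn(128, 64)  # unlabeled data (random placeholder)

losses = []
for step in range(200):
    noisy = clean + 0.3 * torch.randn_like(clean)   # corrupted input
    recon = model(noisy)
    loss = nn.functional.mse_loss(recon, clean)     # the clean data is the target
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

features = model.encoder(clean)  # compressed representations, shape (128, 16)
```

<p>After pretraining, the decoder is typically discarded and <code>model.encoder</code> supplies features for downstream tasks.</p>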
</section>
<section id="phase-3-training-process" class="level3">
<h3 class="anchored" data-anchor-id="phase-3-training-process" id="phase-3-training-process">Phase 3: Training Process</h3>
<p>During training, the model processes large quantities of unlabeled data, continuously solving the pretext task and adjusting its parameters through backpropagation. The training objective is typically formulated as minimizing a loss function that measures how well the model performs on the pretext task.</p>
<p>As in supervised learning, the model is trained on input-output pairs, but here the pairs are created automatically from the data itself rather than by human annotators. For example, in masked language modeling, the complete sentence serves as both input (with masks) and target output (original words), while in image reconstruction tasks, corrupted images are inputs and clean images are targets.</p>
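<p>Creating such pairs automatically for masked language modeling can be sketched in plain Python. This is a simplified illustration: the <code>[MASK]</code> token string, the 15% default masking rate, and the helper name <code>make_mlm_pair</code> are assumptions, and unlike BERT’s scheme (which sometimes substitutes a random word or keeps the original), every selected position here is replaced by the mask token.</p>

```python
import random

MASK_TOKEN = "[MASK]"  # assumed placeholder token

def make_mlm_pair(tokens, mask_prob=0.15, seed=0):
    """Build one (input, target) pair for masked language modeling.

    Masked positions get MASK_TOKEN in the input and keep the original token
    in the target; all other target positions are None (not scored).
    """
    rng = random.Random(seed)
    inputs, targets = [], []
    for tok in tokens:
        if mask_prob > rng.random():
            inputs.append(MASK_TOKEN)
            targets.append(tok)   # the model must recover this token
        else:
            inputs.append(tok)
            targets.append(None)  # no loss at this position
    return inputs, targets

sentence = "the cat sat on the mat".split()
inputs, targets = make_mlm_pair(sentence, mask_prob=0.3, seed=1)
print(inputs)
print(targets)
```

<p>The loss is then computed only at positions whose target is not <code>None</code>, so the sentence supervises itself.</p>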
</section>
<section id="phase-4-fine-tuning-and-transfer" class="level3">
<h3 class="anchored" data-anchor-id="phase-4-fine-tuning-and-transfer" id="phase-4-fine-tuning-and-transfer">Phase 4: Fine-tuning and Transfer</h3>
<p>After pretraining on the self-supervised task, the learned representations are adapted for specific downstream applications through fine-tuning. This process typically requires only small amounts of labeled data and relatively few training iterations, as the model has already learned to extract relevant features from the pretraining phase.</p>
<p>The fine-tuning process often involves adding task-specific layers on top of the pretrained encoder and training the entire system on the target task. Alternatively, the pretrained encoder can be frozen and used as a fixed feature extractor, with only the final classification or regression layers trained; when that final layer is linear, this setup is commonly called linear probing.</p>
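<p>The frozen-encoder option can be sketched in PyTorch: the encoder’s parameters are excluded from gradient updates and only a new linear head is optimized. The small <code>encoder</code> network below is a hypothetical stand-in for a pretrained self-supervised model, and the dimensions, class count, and learning rate are illustrative assumptions.</p>

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for a pretrained self-supervised encoder
encoder = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 64))
head = nn.Linear(64, 10)  # new task-specific layer (10 classes assumed)

# Freeze the encoder so only the head is trained
for p in encoder.parameters():
    p.requires_grad = False
encoder.eval()

optimizer = torch.optim.SGD(head.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

x = torch.randn(16, 32)           # a small labeled batch (random placeholder)
y = torch.randint(0, 10, (16,))

encoder_before = [p.detach().clone() for p in encoder.parameters()]
head_before = head.weight.detach().clone()

with torch.no_grad():             # frozen features need no autograd graph
    features = encoder(x)
logits = head(features)
loss = loss_fn(logits, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()                  # updates the head only
```

<p>For full fine-tuning, leave <code>requires_grad</code> enabled and pass both <code>encoder.parameters()</code> and <code>head.parameters()</code> to the optimizer, often with a smaller learning rate for the encoder.</p>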
</section>
</section>
<section id="common-training-strategies" class="level2">
<h2 class="anchored" data-anchor-id="common-training-strategies" id="common-training-strategies">Common Training Strategies</h2>
<p>Several proven strategies have emerged for training effective self-supervised models across different domains.</p>
<section id="contrastive-learning" class="level3">
<h3 class="anchored" data-anchor-id="contrastive-learning" id="contrastive-learning">Contrastive Learning</h3>
<p>Contrastive learning has become one of the most successful approaches, particularly in computer vision. This method teaches models to distinguish between positive pairs (similar or related data points) and negative pairs (dissimilar or unrelated data points). By maximizing agreement between positive pairs while minimizing agreement between negative pairs, models learn representations that capture semantic similarity and difference.</p>
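<p>A minimal NumPy sketch of an InfoNCE-style contrastive loss follows, assuming each row of <code>z1</code> and <code>z2</code> holds embeddings of two augmented views of the same example. Row i of <code>z2</code> is the positive for row i of <code>z1</code>, and every other row acts as a negative. The function name, the temperature of 0.1, and the random toy embeddings are illustrative assumptions; methods such as SimCLR use a symmetric variant that also draws negatives from within the same view.</p>

```python
import numpy as np

def info_nce_loss(z1, z2, temperature=0.1):
    """InfoNCE loss for a batch of positive pairs (z1[i], z2[i])."""
    # L2-normalize so that dot products are cosine similarities
    z1 = z1 / np.linalg.norm(z1, axis=1, keepdims=True)
    z2 = z2 / np.linalg.norm(z2, axis=1, keepdims=True)
    logits = z1 @ z2.T / temperature            # (N, N) similarity matrix
    # Numerically stable log-softmax over each row
    logits = logits - logits.max(axis=1, keepdims=True)
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    # Positives sit on the diagonal: maximize their log-probability
    return -np.mean(np.diag(log_probs))

rng = np.random.default_rng(0)
anchors = rng.normal(size=(8, 16))
aligned = anchors + 0.01 * rng.normal(size=(8, 16))  # near-identical "views"
unrelated = rng.normal(size=(8, 16))

print(info_nce_loss(anchors, aligned))    # small: each positive stands out
print(info_nce_loss(anchors, unrelated))  # larger: positives look like negatives
```

<p>Minimizing this loss pulls each positive pair together relative to the other items in the batch, which is why larger batches (more negatives) often help contrastive methods.</p>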
</section>
<section id="masked-modeling" class="level3">
<h3 class="anchored" data-anchor-id="masked-modeling" id="masked-modeling">Masked Modeling</h3>
<p>Masked modeling is another highly effective strategy, in which portions of the input are randomly hidden and the model must predict the missing content. This forces the model to use context and the relationships within the data, yielding rich representations.</p>
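<p>For images, the masking step can be sketched in NumPy by cutting an image into patches and hiding a random subset. This is an illustrative sketch, not a specific library’s API: the helper name <code>random_patch_mask</code>, the 224-pixel image, and the 16-pixel patches are assumptions, and the 75% mask ratio is the default popularized by masked autoencoders (MAE).</p>

```python
import numpy as np

def random_patch_mask(patches, mask_ratio=0.75, seed=0):
    """Randomly hide a fraction of patches.

    patches: array of shape (num_patches, patch_dim)
    Returns visible patches, hidden (target) patches, and a boolean mask
    in the original patch order (True = hidden).
    """
    rng = np.random.default_rng(seed)
    n = patches.shape[0]
    n_masked = int(round(n * mask_ratio))
    mask = np.zeros(n, dtype=bool)
    mask[rng.permutation(n)[:n_masked]] = True
    visible = patches[~mask]  # what the encoder sees
    targets = patches[mask]   # what the model must reconstruct
    return visible, targets, mask

# A 224x224 RGB image cut into 16x16 patches gives 14 * 14 = 196 patches
image = np.random.default_rng(1).random((224, 224, 3))
patches = image.reshape(14, 16, 14, 16, 3).transpose(0, 2, 1, 3, 4).reshape(196, -1)
visible, targets, mask = random_patch_mask(patches)
print(visible.shape, targets.shape)  # (49, 768) (147, 768)
```

<p>The reconstruction loss is then computed only on the hidden patches, so the model cannot score well by copying its input.</p>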
</section>
<section id="predictive-modeling" class="level3">
<h3 class="anchored" data-anchor-id="predictive-modeling" id="predictive-modeling">Predictive Modeling</h3>
<p>Predictive modeling involves training models to forecast future states or missing information based on available context. This could include predicting future video frames, completing partial sequences, or inferring hidden attributes from observable features.</p>
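<p>The simplest version of this idea, predicting the next element of a sequence from a window of previous ones, can be sketched in plain Python. The helper name <code>next_step_pairs</code>, the window length of 3, and the toy readings are illustrative assumptions; the same windowing applies to tokens or video frames.</p>

```python
def next_step_pairs(sequence, context_len=3):
    """Turn one unlabeled sequence into (context, target) training pairs.

    Each target is simply the element that follows its context window,
    so the labels come from the data itself.
    """
    pairs = []
    for i in range(context_len, len(sequence)):
        context = sequence[i - context_len:i]
        target = sequence[i]
        pairs.append((context, target))
    return pairs

readings = [10, 12, 13, 15, 18, 21]
for context, target in next_step_pairs(readings):
    print(context, "->", target)
# [10, 12, 13] -> 15
# [12, 13, 15] -> 18
# [13, 15, 18] -> 21
```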
</section>
</section>
<section id="advantages-and-applications" class="level2">
<h2 class="anchored" data-anchor-id="advantages-and-applications" id="advantages-and-applications">Advantages and Applications</h2>
<p>Self-supervised learning offers several compelling advantages over traditional supervised approaches. The most significant benefit is the ability to leverage vast amounts of unlabeled data that would otherwise remain unused, dramatically expanding the available training resources. This approach also reduces dependence on expensive human annotation processes and can discover patterns and relationships that might not be obvious to human labelers.</p>
<p>The versatility of self-supervised representations makes them valuable across numerous applications. In natural language processing, models like BERT and GPT have revolutionized tasks ranging from translation and summarization to question answering and text generation. Computer vision applications include object recognition, image segmentation, and visual reasoning, while in other domains, self-supervised learning has shown promise for speech recognition, drug discovery, and robotic control.</p>
</section>
<section id="challenges-and-limitations" class="level2">
<h2 class="anchored" data-anchor-id="challenges-and-limitations" id="challenges-and-limitations">Challenges and Limitations</h2>
<p>Despite its promise, self-supervised learning faces several important challenges. Designing effective pretext tasks requires deep understanding of the data domain and careful consideration of what patterns the model should learn. Poor pretext task design can lead to models that excel at artificial objectives but fail to capture semantically meaningful representations.</p>
<p>The computational requirements for self-supervised learning can be substantial, as these models often require processing massive datasets and training large architectures for extended periods. Additionally, evaluation of self-supervised models can be complex, as their quality is ultimately measured by performance on downstream tasks rather than the pretext task itself.</p>
</section>
<section id="future-directions" class="level2">
<h2 class="anchored" data-anchor-id="future-directions" id="future-directions">Future Directions</h2>
<p>The field of self-supervised learning continues to evolve rapidly, with researchers exploring new pretext tasks, architectural innovations, and training methodologies. Emerging trends include multi-modal self-supervised learning that combines different data types, more sophisticated contrastive learning strategies, and the development of unified frameworks that can handle diverse self-supervised objectives.</p>
<p>As computational resources continue to grow and new algorithmic innovations emerge, self-supervised learning is poised to play an increasingly central role in artificial intelligence, potentially reducing our dependence on labeled data while improving model performance and generalization capabilities.</p>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>Self-supervised learning represents a fundamental shift in how we approach machine learning, moving from explicit supervision toward learning from the inherent structure of data itself. This paradigm promises to unlock the vast potential of unlabeled data while creating more robust and generalizable AI systems.</p>
<p>The key contributions of self-supervised learning include:</p>
<ul>
<li>Enabling training without manual labels by leveraging data’s inherent structure</li>
<li>Reducing dependence on expensive labeled datasets through efficient representation learning</li>
<li>Providing versatile representations that transfer well across diverse applications</li>
<li>Opening new possibilities for learning from the vast amounts of unlabeled data available in the real world</li>
</ul>
<p>As the field continues to mature, self-supervised learning will likely become an increasingly important component of the machine learning toolkit, particularly as we seek to build more capable and generalizable AI systems.</p>
</section>
<section id="appendix-key-concepts-summary" class="level2">
<h2 class="anchored" data-anchor-id="appendix-key-concepts-summary" id="appendix-key-concepts-summary">Appendix: Key Concepts Summary</h2>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Key Takeaways
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li><strong>Self-supervised learning</strong> enables training without manual labels by using data’s inherent structure</li>
<li><strong>Pretext tasks</strong> are crucial for effective representation learning and must be carefully designed</li>
<li><strong>Representation learning</strong> focuses on general-purpose features that transfer to downstream tasks</li>
<li><strong>Data efficiency</strong> is achieved by leveraging vast amounts of unlabeled data</li>
<li><strong>Applications</strong> span NLP, computer vision, and many other domains</li>
<li><strong>Future developments</strong> focus on multi-modal and unified learning frameworks</li>
</ul>
</div>
</div>
<section id="glossary" class="level3">
<h3 class="anchored" data-anchor-id="glossary" id="glossary">Glossary</h3>
<dl>
<dt><strong>Pretext Task</strong></dt>
<dd>
An artificial objective created from the data itself to provide supervisory signals for learning representations.
</dd>
<dt><strong>Representation Learning</strong></dt>
<dd>
The process of learning general-purpose data representations that capture essential characteristics and can be transferred to downstream tasks.
</dd>
<dt><strong>Fine-tuning</strong></dt>
<dd>
The process of adapting pretrained representations for specific downstream applications using small amounts of labeled data.
</dd>
<dt><strong>Contrastive Learning</strong></dt>
<dd>
A training strategy that teaches models to distinguish between positive (similar) and negative (dissimilar) data point pairs.
</dd>
<dt><strong>Masked Modeling</strong></dt>
<dd>
A strategy where portions of input are hidden and the model must predict the missing content based on context.
</dd>
</dl>


</section>
</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[DINOv2 Student-Teacher Network Training Guide]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/dino/dino-v2-scratch/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/dino/dino-v2-scratch/</guid>
      <pubDate>Wed, 28 May 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>advanced</category>
      <content:encoded><![CDATA[






<section id="dinov2-student-teacher-network-training-guide" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/dino/dino-v2-scratch/dino.png" class="img-fluid"></p>
<p>This guide provides a complete implementation for training a DINOv2 (DINO version 2) student-teacher network from scratch using PyTorch. DINOv2 is a self-supervised learning method that trains vision transformers without labels using a teacher-student distillation framework.</p>
<section id="overview" class="level2">
<h2 class="anchored" data-anchor-id="overview" id="overview">Overview</h2>
<p>DINOv2 uses a student-teacher framework where:</p>
<ul>
<li><strong>Teacher network</strong>: Provides stable targets (EMA of student weights)</li>
<li><strong>Student network</strong>: Learns to match teacher outputs</li>
<li><strong>Multi-crop strategy</strong>: Uses different image crops for robustness</li>
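<li><strong>Centering and sharpening</strong>: Prevent collapse of the teacher outputs (centering keeps a single output dimension from dominating; a low-temperature teacher softmax keeps the output from becoming uniform)</li>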
</ul>
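<p>The teacher update and the centering step can be sketched in PyTorch as follows. This is a minimal sketch modeled on the original DINO formulation; the exact DINOv2 recipe may differ. The names <code>update_teacher</code> and <code>DINOLossSketch</code> are hypothetical; the teacher momentum (0.996), temperatures (0.04 teacher, 0.1 student), and center momentum (0.9) are assumed defaults; and a single linear layer stands in for the ViT backbone plus projection head. A real training loop would feed the teacher only global crops, feed the student all crops, and compare every teacher/student crop pair.</p>

```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

@torch.no_grad()
def update_teacher(student, teacher, momentum=0.996):
    """Teacher weights become an exponential moving average (EMA) of the student's."""
    for p_s, p_t in zip(student.parameters(), teacher.parameters()):
        p_t.mul_(momentum).add_(p_s.detach(), alpha=1.0 - momentum)

class DINOLossSketch(nn.Module):
    """Cross-entropy between centered, sharpened teacher outputs and student outputs."""
    def __init__(self, out_dim, teacher_temp=0.04, student_temp=0.1, center_momentum=0.9):
        super().__init__()
        self.teacher_temp = teacher_temp
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.register_buffer("center", torch.zeros(1, out_dim))

    def forward(self, student_out, teacher_out):
        # Teacher: subtract the running center, then sharpen with a low temperature
        t = F.softmax((teacher_out - self.center) / self.teacher_temp, dim=-1).detach()
        s = F.log_softmax(student_out / self.student_temp, dim=-1)
        loss = -(t * s).sum(dim=-1).mean()
        self.update_center(teacher_out)
        return loss

    @torch.no_grad()
    def update_center(self, teacher_out):
        # Running mean of teacher outputs over batches
        batch_center = teacher_out.mean(dim=0, keepdim=True)
        self.center.mul_(self.center_momentum).add_(batch_center, alpha=1.0 - self.center_momentum)

# Toy usage: a linear layer stands in for backbone + projection head
torch.manual_seed(0)
student = nn.Linear(8, 16)
teacher = copy.deepcopy(student)
for p in teacher.parameters():
    p.requires_grad = False  # the teacher is never updated by gradients

criterion = DINOLossSketch(out_dim=16)
optimizer = torch.optim.SGD(student.parameters(), lr=0.1)

x = torch.randn(4, 8)  # placeholder batch; real training uses different crops per network
loss = criterion(student(x), teacher(x))
optimizer.zero_grad()
loss.backward()
optimizer.step()
update_teacher(student, teacher)  # EMA step after each optimizer step
```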
</section>
<section id="architecture-components" class="level2">
<h2 class="anchored" data-anchor-id="architecture-components" id="architecture-components">Architecture Components</h2>
<section id="vision-transformer-vit-backbone" class="level3">
<h3 class="anchored" data-anchor-id="vision-transformer-vit-backbone" id="vision-transformer-vit-backbone">Vision Transformer (ViT) Backbone</h3>
<section id="import-libraries" class="level4">
<h4 class="anchored" data-anchor-id="import-libraries">Import Libraries</h4>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn.functional <span class="im">as</span> F</span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.distributed <span class="im">as</span> dist</span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.utils.data <span class="im">import</span> DataLoader</span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torchvision.transforms <span class="im">as</span> transforms</span>
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torchvision.datasets <span class="im">import</span> ImageFolder</span>
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> math</span>
<span id="cb1-9"><a href="#cb1-9" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb1-10"><a href="#cb1-10" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> typing <span class="im">import</span> Optional, List, Tuple</span></code></pre></div></div>
</section>
<section id="patch-embedding" class="level4">
<h4 class="anchored" data-anchor-id="patch-embedding">Patch Embedding</h4>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> PatchEmbed(nn.Module):</span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Image to Patch Embedding"""</span></span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, img_size<span class="op">=</span><span class="dv">224</span>, patch_size<span class="op">=</span><span class="dv">16</span>, in_chans<span class="op">=</span><span class="dv">3</span>, embed_dim<span class="op">=</span><span class="dv">768</span>):</span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.img_size <span class="op">=</span> img_size</span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.patch_size <span class="op">=</span> patch_size</span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.num_patches <span class="op">=</span> (img_size <span class="op">//</span> patch_size) <span class="op">**</span> <span class="dv">2</span></span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.proj <span class="op">=</span> nn.Conv2d(in_chans, embed_dim, kernel_size<span class="op">=</span>patch_size, stride<span class="op">=</span>patch_size)</span>
<span id="cb2-10"><a href="#cb2-10" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-11"><a href="#cb2-11" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb2-12"><a href="#cb2-12" aria-hidden="true" tabindex="-1"></a>        B, C, H, W <span class="op">=</span> x.shape</span>
<span id="cb2-13"><a href="#cb2-13" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.proj(x).flatten(<span class="dv">2</span>).transpose(<span class="dv">1</span>, <span class="dv">2</span>)  <span class="co"># [B, N, D]</span></span>
<span id="cb2-14"><a href="#cb2-14" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x</span></code></pre></div></div>
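<p>A strided convolution with <code>kernel_size = stride = patch_size</code> does the same thing as cutting the image into non-overlapping patches, flattening each one, and applying a shared linear layer. The sketch below assumes PyTorch and the default ViT-B/16 sizes. It checks this equivalence numerically by copying the convolution weights into an <code>nn.Linear</code>. The helper <code>patchify</code> is hypothetical and not part of the code above.</p>

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

img_size, patch_size, in_chans, embed_dim = 224, 16, 3, 768

# Convolutional patch embedding, as in PatchEmbed.proj
conv = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

# Equivalent linear layer acting on flattened patches
linear = nn.Linear(in_chans * patch_size * patch_size, embed_dim)
with torch.no_grad():
    # Conv weight is [D, C, P, P]; each filter becomes one row of the linear weight
    linear.weight.copy_(conv.weight.reshape(embed_dim, -1))
    linear.bias.copy_(conv.bias)


def patchify(x, p):
    """Hypothetical helper: [B, C, H, W] -> [B, N, C*p*p] with N = (H/p) * (W/p)."""
    B, C, H, W = x.shape
    x = x.reshape(B, C, H // p, p, W // p, p)
    # Put the patch-grid axes first (row-major, like flatten(2) on the conv output)
    # and keep (C, p, p) together in the same order as the conv filter layout
    x = x.permute(0, 2, 4, 1, 3, 5)
    return x.reshape(B, (H // p) * (W // p), C * p * p)


x = torch.randn(2, in_chans, img_size, img_size)
tokens_conv = conv(x).flatten(2).transpose(1, 2)  # [B, N, D], same as PatchEmbed.forward
tokens_linear = linear(patchify(x, patch_size))   # [B, N, D]

print(tokens_conv.shape)  # torch.Size([2, 196, 768])
print(torch.allclose(tokens_conv, tokens_linear, atol=1e-4))  # True
```

<p>The convolutional form is used in practice because it fuses the reshape and the matrix multiply into a single, well-optimized kernel.</p>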
</section>
<section id="multi-head-self-attention" class="level4">
<h4 class="anchored" data-anchor-id="multi-head-self-attention">Multi-head Self-Attention</h4>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MultiheadAttention(nn.Module):</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Multi-head Self Attention"""</span></span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, embed_dim, num_heads, dropout<span class="op">=</span><span class="fl">0.0</span>):</span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.embed_dim <span class="op">=</span> embed_dim</span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.num_heads <span class="op">=</span> num_heads</span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.head_dim <span class="op">=</span> embed_dim <span class="op">//</span> num_heads</span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.qkv <span class="op">=</span> nn.Linear(embed_dim, embed_dim <span class="op">*</span> <span class="dv">3</span>)</span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.proj <span class="op">=</span> nn.Linear(embed_dim, embed_dim)</span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.dropout <span class="op">=</span> nn.Dropout(dropout)</span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a>        B, N, D <span class="op">=</span> x.shape</span>
<span id="cb3-15"><a href="#cb3-15" aria-hidden="true" tabindex="-1"></a>        qkv <span class="op">=</span> <span class="va">self</span>.qkv(x).reshape(B, N, <span class="dv">3</span>, <span class="va">self</span>.num_heads, <span class="va">self</span>.head_dim).permute(<span class="dv">2</span>, <span class="dv">0</span>, <span class="dv">3</span>, <span class="dv">1</span>, <span class="dv">4</span>)</span>
<span id="cb3-16"><a href="#cb3-16" aria-hidden="true" tabindex="-1"></a>        q, k, v <span class="op">=</span> qkv[<span class="dv">0</span>], qkv[<span class="dv">1</span>], qkv[<span class="dv">2</span>]</span>
<span id="cb3-17"><a href="#cb3-17" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-18"><a href="#cb3-18" aria-hidden="true" tabindex="-1"></a>        attn <span class="op">=</span> (q <span class="op">@</span> k.transpose(<span class="op">-</span><span class="dv">2</span>, <span class="op">-</span><span class="dv">1</span>)) <span class="op">*</span> (<span class="va">self</span>.head_dim <span class="op">**</span> <span class="op">-</span><span class="fl">0.5</span>)</span>
<span id="cb3-19"><a href="#cb3-19" aria-hidden="true" tabindex="-1"></a>        attn <span class="op">=</span> attn.softmax(dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb3-20"><a href="#cb3-20" aria-hidden="true" tabindex="-1"></a>        attn <span class="op">=</span> <span class="va">self</span>.dropout(attn)</span>
<span id="cb3-21"><a href="#cb3-21" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-22"><a href="#cb3-22" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> (attn <span class="op">@</span> v).transpose(<span class="dv">1</span>, <span class="dv">2</span>).reshape(B, N, D)</span>
<span id="cb3-23"><a href="#cb3-23" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.proj(x)</span>
<span id="cb3-24"><a href="#cb3-24" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x</span></code></pre></div></div>
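<p>The forward pass above computes standard scaled dot-product attention, softmax(QK<sup>T</sup> / √d<sub>head</sub>) V, independently for each head. The sketch below assumes PyTorch 2.0 or later, where <code>torch.nn.functional.scaled_dot_product_attention</code> is available. It repeats the same reshape and permute steps on a random input and checks that the manual computation matches the fused kernel, with dropout disabled.</p>

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

B, N, D, num_heads = 2, 197, 768, 12
head_dim = D // num_heads

qkv_proj = nn.Linear(D, 3 * D)
x = torch.randn(B, N, D)

# Same layout as MultiheadAttention.forward: [3, B, heads, N, head_dim]
qkv = qkv_proj(x).reshape(B, N, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]

# Manual attention: scale by 1/sqrt(head_dim), softmax over the key axis
attn = (q @ k.transpose(-2, -1)) * head_dim ** -0.5
manual = attn.softmax(dim=-1) @ v  # [B, heads, N, head_dim]

# Fused PyTorch kernel; its default scale is also 1/sqrt(head_dim)
fused = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
print(torch.allclose(manual, fused, atol=1e-5))  # True

# Merge the heads back to [B, N, D] before the output projection
out = manual.transpose(1, 2).reshape(B, N, D)
print(out.shape)  # torch.Size([2, 197, 768])
```

<p>When PyTorch 2.0 or later is available, replacing the three manual lines with the fused call typically lowers memory use for long token sequences and leaves the math unchanged.</p>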
</section>
<section id="transformer-block" class="level4">
<h4 class="anchored" data-anchor-id="transformer-block">Transformer Block</h4>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> TransformerBlock(nn.Module):</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Transformer Block"""</span></span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, embed_dim, num_heads, mlp_ratio<span class="op">=</span><span class="fl">4.0</span>, dropout<span class="op">=</span><span class="fl">0.0</span>):</span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.norm1 <span class="op">=</span> nn.LayerNorm(embed_dim)</span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.attn <span class="op">=</span> MultiheadAttention(embed_dim, num_heads, dropout)</span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.norm2 <span class="op">=</span> nn.LayerNorm(embed_dim)</span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a>        mlp_hidden_dim <span class="op">=</span> <span class="bu">int</span>(embed_dim <span class="op">*</span> mlp_ratio)</span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.mlp <span class="op">=</span> nn.Sequential(</span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a>            nn.Linear(embed_dim, mlp_hidden_dim),</span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a>            nn.GELU(),</span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a>            nn.Dropout(dropout),</span>
<span id="cb4-14"><a href="#cb4-14" aria-hidden="true" tabindex="-1"></a>            nn.Linear(mlp_hidden_dim, embed_dim),</span>
<span id="cb4-15"><a href="#cb4-15" aria-hidden="true" tabindex="-1"></a>            nn.Dropout(dropout)</span>
<span id="cb4-16"><a href="#cb4-16" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb4-17"><a href="#cb4-17" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-18"><a href="#cb4-18" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb4-19"><a href="#cb4-19" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> x <span class="op">+</span> <span class="va">self</span>.attn(<span class="va">self</span>.norm1(x))</span>
<span id="cb4-20"><a href="#cb4-20" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> x <span class="op">+</span> <span class="va">self</span>.mlp(<span class="va">self</span>.norm2(x))</span>
<span id="cb4-21"><a href="#cb4-21" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x</span></code></pre></div></div>
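<p>The parameter count of one block follows directly from the layer shapes. The qkv projection contributes 3D² + 3D parameters and the output projection D² + D. With <code>mlp_ratio = 4</code>, the two MLP layers add 8D² + 5D, and the two LayerNorms add 4D. The total is 12D² + 13D. The sketch below assumes PyTorch and rebuilds only the layer shapes of <code>TransformerBlock</code>, not its forward pass. It checks the formula for D = 768, the ViT-B width.</p>

```python
import torch.nn as nn


def block_param_count(embed_dim: int, mlp_ratio: float = 4.0) -> int:
    """Count parameters of a layer stack shaped like TransformerBlock."""
    hidden = int(embed_dim * mlp_ratio)
    layers = nn.ModuleList([
        nn.LayerNorm(embed_dim),              # norm1: weight + bias
        nn.Linear(embed_dim, 3 * embed_dim),  # attn.qkv
        nn.Linear(embed_dim, embed_dim),      # attn.proj
        nn.LayerNorm(embed_dim),              # norm2
        nn.Linear(embed_dim, hidden),         # first MLP layer
        nn.Linear(hidden, embed_dim),         # second MLP layer
    ])  # GELU and Dropout have no parameters
    return sum(p.numel() for p in layers.parameters())


D = 768
count = block_param_count(D)
formula = 12 * D * D + 13 * D
print(count, formula)  # 7087872 7087872
print(12 * count)      # 85054464 parameters across the 12 blocks of ViT-B/16
```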
</section>
<section id="vision-transformer" class="level4">
<h4 class="anchored" data-anchor-id="vision-transformer">Vision Transformer</h4>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> VisionTransformer(nn.Module):</span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Vision Transformer"""</span></span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, img_size<span class="op">=</span><span class="dv">224</span>, patch_size<span class="op">=</span><span class="dv">16</span>, in_chans<span class="op">=</span><span class="dv">3</span>, embed_dim<span class="op">=</span><span class="dv">768</span>, </span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a>                 depth<span class="op">=</span><span class="dv">12</span>, num_heads<span class="op">=</span><span class="dv">12</span>, mlp_ratio<span class="op">=</span><span class="fl">4.0</span>, dropout<span class="op">=</span><span class="fl">0.0</span>):</span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.patch_embed <span class="op">=</span> PatchEmbed(img_size, patch_size, in_chans, embed_dim)</span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a>        num_patches <span class="op">=</span> <span class="va">self</span>.patch_embed.num_patches</span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.cls_token <span class="op">=</span> nn.Parameter(torch.zeros(<span class="dv">1</span>, <span class="dv">1</span>, embed_dim))</span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.pos_embed <span class="op">=</span> nn.Parameter(torch.zeros(<span class="dv">1</span>, num_patches <span class="op">+</span> <span class="dv">1</span>, embed_dim))</span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.dropout <span class="op">=</span> nn.Dropout(dropout)</span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.blocks <span class="op">=</span> nn.ModuleList([</span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a>            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)</span>
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(depth)</span>
<span id="cb5-16"><a href="#cb5-16" aria-hidden="true" tabindex="-1"></a>        ])</span>
<span id="cb5-17"><a href="#cb5-17" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-18"><a href="#cb5-18" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.norm <span class="op">=</span> nn.LayerNorm(embed_dim)</span>
<span id="cb5-19"><a href="#cb5-19" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-20"><a href="#cb5-20" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Initialize weights</span></span>
<span id="cb5-21"><a href="#cb5-21" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>._init_weights()</span>
<span id="cb5-22"><a href="#cb5-22" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-23"><a href="#cb5-23" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _init_weights(<span class="va">self</span>):</span>
<span id="cb5-24"><a href="#cb5-24" aria-hidden="true" tabindex="-1"></a>        nn.init.trunc_normal_(<span class="va">self</span>.pos_embed, std<span class="op">=</span><span class="fl">0.02</span>)</span>
<span id="cb5-25"><a href="#cb5-25" aria-hidden="true" tabindex="-1"></a>        nn.init.trunc_normal_(<span class="va">self</span>.cls_token, std<span class="op">=</span><span class="fl">0.02</span>)</span>
<span id="cb5-26"><a href="#cb5-26" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-27"><a href="#cb5-27" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> m <span class="kw">in</span> <span class="va">self</span>.modules():</span>
<span id="cb5-28"><a href="#cb5-28" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="bu">isinstance</span>(m, nn.Linear):</span>
<span id="cb5-29"><a href="#cb5-29" aria-hidden="true" tabindex="-1"></a>                nn.init.trunc_normal_(m.weight, std<span class="op">=</span><span class="fl">0.02</span>)</span>
<span id="cb5-30"><a href="#cb5-30" aria-hidden="true" tabindex="-1"></a>                <span class="cf">if</span> m.bias <span class="kw">is</span> <span class="kw">not</span> <span class="va">None</span>:</span>
<span id="cb5-31"><a href="#cb5-31" aria-hidden="true" tabindex="-1"></a>                    nn.init.constant_(m.bias, <span class="dv">0</span>)</span>
<span id="cb5-32"><a href="#cb5-32" aria-hidden="true" tabindex="-1"></a>            <span class="cf">elif</span> <span class="bu">isinstance</span>(m, nn.LayerNorm):</span>
<span id="cb5-33"><a href="#cb5-33" aria-hidden="true" tabindex="-1"></a>                nn.init.constant_(m.bias, <span class="dv">0</span>)</span>
<span id="cb5-34"><a href="#cb5-34" aria-hidden="true" tabindex="-1"></a>                nn.init.constant_(m.weight, <span class="fl">1.0</span>)</span>
<span id="cb5-35"><a href="#cb5-35" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-36"><a href="#cb5-36" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb5-37"><a href="#cb5-37" aria-hidden="true" tabindex="-1"></a>        B <span class="op">=</span> x.shape[<span class="dv">0</span>]</span>
<span id="cb5-38"><a href="#cb5-38" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.patch_embed(x)</span>
<span id="cb5-39"><a href="#cb5-39" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-40"><a href="#cb5-40" aria-hidden="true" tabindex="-1"></a>        cls_tokens <span class="op">=</span> <span class="va">self</span>.cls_token.expand(B, <span class="op">-</span><span class="dv">1</span>, <span class="op">-</span><span class="dv">1</span>)</span>
<span id="cb5-41"><a href="#cb5-41" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> torch.cat((cls_tokens, x), dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb5-42"><a href="#cb5-42" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> x <span class="op">+</span> <span class="va">self</span>.pos_embed</span>
<span id="cb5-43"><a href="#cb5-43" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.dropout(x)</span>
<span id="cb5-44"><a href="#cb5-44" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-45"><a href="#cb5-45" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> block <span class="kw">in</span> <span class="va">self</span>.blocks:</span>
<span id="cb5-46"><a href="#cb5-46" aria-hidden="true" tabindex="-1"></a>            x <span class="op">=</span> block(x)</span>
<span id="cb5-47"><a href="#cb5-47" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-48"><a href="#cb5-48" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.norm(x)</span>
<span id="cb5-49"><a href="#cb5-49" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x[:, <span class="dv">0</span>]  <span class="co"># Return CLS token</span></span></code></pre></div></div>
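<p>As written, <code>forward</code> adds <code>self.pos_embed</code> directly, so it only accepts inputs of size <code>img_size</code>. A 96×96 local crop from the multi-crop strategy produces 6 × 6 = 36 patch tokens plus the CLS token, and 37 tokens cannot be added to a 197-token embedding. A common remedy in DINO-style codebases is to resize the patch part of the positional embedding with bicubic interpolation. The function below is a minimal sketch of that approach. It assumes square patch grids and a single CLS token at index 0, and its name and signature are illustrative rather than part of the code above.</p>

```python
import math

import torch
import torch.nn.functional as F


def interpolate_pos_embed(pos_embed: torch.Tensor, num_patches: int) -> torch.Tensor:
    """Resize a [1, 1 + N_old, D] positional embedding to 1 + num_patches tokens.

    Assumes square patch grids and one CLS token at index 0.
    """
    n_old = pos_embed.shape[1] - 1
    if n_old == num_patches:
        return pos_embed
    cls_pos, patch_pos = pos_embed[:, :1], pos_embed[:, 1:]
    dim = pos_embed.shape[-1]
    old_side, new_side = math.isqrt(n_old), math.isqrt(num_patches)
    # [1, N, D] -> [1, D, side, side] so the grid can be resized like an image
    patch_pos = patch_pos.reshape(1, old_side, old_side, dim).permute(0, 3, 1, 2)
    patch_pos = F.interpolate(patch_pos, size=(new_side, new_side),
                              mode="bicubic", align_corners=False)
    # Back to a token sequence: [1, D, s, s] -> [1, s*s, D]
    patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, new_side * new_side, dim)
    return torch.cat((cls_pos, patch_pos), dim=1)


# 224/16 -> 14x14 grid (196 patches); 96/16 -> 6x6 grid (36 patches)
pos_embed = torch.randn(1, 197, 768) * 0.02
local_pos = interpolate_pos_embed(pos_embed, num_patches=(96 // 16) ** 2)
print(local_pos.shape)  # torch.Size([1, 37, 768])
```

<p>Inside <code>VisionTransformer.forward</code>, this would replace <code>x = x + self.pos_embed</code> with something like <code>x = x + interpolate_pos_embed(self.pos_embed, x.shape[1] - 1)</code>. Because the interpolation is differentiable, gradients from the local crops still reach the learned 14 × 14 table.</p>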
</section>
<section id="dino-head" class="level4">
<h4 class="anchored" data-anchor-id="dino-head">DINO Head</h4>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> DINOHead(nn.Module):</span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""DINO Projection Head"""</span></span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, in_dim, out_dim, hidden_dim<span class="op">=</span><span class="dv">2048</span>, bottleneck_dim<span class="op">=</span><span class="dv">256</span>, </span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a>                 num_layers<span class="op">=</span><span class="dv">3</span>, use_bn<span class="op">=</span><span class="va">False</span>, norm_last_layer<span class="op">=</span><span class="va">True</span>):</span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> num_layers <span class="op">==</span> <span class="dv">1</span>:</span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.mlp <span class="op">=</span> nn.Linear(in_dim, bottleneck_dim)</span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a>        <span class="cf">else</span>:</span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a>            layers <span class="op">=</span> [nn.Linear(in_dim, hidden_dim)]</span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> use_bn:</span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a>                layers.append(nn.BatchNorm1d(hidden_dim))</span>
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a>            layers.append(nn.GELU())</span>
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb6-15"><a href="#cb6-15" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(num_layers <span class="op">-</span> <span class="dv">2</span>):</span>
<span id="cb6-16"><a href="#cb6-16" aria-hidden="true" tabindex="-1"></a>                layers.append(nn.Linear(hidden_dim, hidden_dim))</span>
<span id="cb6-17"><a href="#cb6-17" aria-hidden="true" tabindex="-1"></a>                <span class="cf">if</span> use_bn:</span>
<span id="cb6-18"><a href="#cb6-18" aria-hidden="true" tabindex="-1"></a>                    layers.append(nn.BatchNorm1d(hidden_dim))</span>
<span id="cb6-19"><a href="#cb6-19" aria-hidden="true" tabindex="-1"></a>                layers.append(nn.GELU())</span>
<span id="cb6-20"><a href="#cb6-20" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb6-21"><a href="#cb6-21" aria-hidden="true" tabindex="-1"></a>            layers.append(nn.Linear(hidden_dim, bottleneck_dim))</span>
<span id="cb6-22"><a href="#cb6-22" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.mlp <span class="op">=</span> nn.Sequential(<span class="op">*</span>layers)</span>
<span id="cb6-23"><a href="#cb6-23" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-24"><a href="#cb6-24" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.<span class="bu">apply</span>(<span class="va">self</span>._init_weights)</span>
<span id="cb6-25"><a href="#cb6-25" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-26"><a href="#cb6-26" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.last_layer <span class="op">=</span> nn.utils.weight_norm(</span>
<span id="cb6-27"><a href="#cb6-27" aria-hidden="true" tabindex="-1"></a>            nn.Linear(bottleneck_dim, out_dim, bias<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb6-28"><a href="#cb6-28" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb6-29"><a href="#cb6-29" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.last_layer.weight_g.data.fill_(<span class="dv">1</span>)</span>
<span id="cb6-30"><a href="#cb6-30" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> norm_last_layer:</span>
<span id="cb6-31"><a href="#cb6-31" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.last_layer.weight_g.requires_grad <span class="op">=</span> <span class="va">False</span></span>
<span id="cb6-32"><a href="#cb6-32" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-33"><a href="#cb6-33" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _init_weights(<span class="va">self</span>, m):</span>
<span id="cb6-34"><a href="#cb6-34" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="bu">isinstance</span>(m, nn.Linear):</span>
<span id="cb6-35"><a href="#cb6-35" aria-hidden="true" tabindex="-1"></a>            nn.init.trunc_normal_(m.weight, std<span class="op">=</span><span class="fl">0.02</span>)</span>
<span id="cb6-36"><a href="#cb6-36" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> m.bias <span class="kw">is</span> <span class="kw">not</span> <span class="va">None</span>:</span>
<span id="cb6-37"><a href="#cb6-37" aria-hidden="true" tabindex="-1"></a>                nn.init.constant_(m.bias, <span class="dv">0</span>)</span>
<span id="cb6-38"><a href="#cb6-38" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-39"><a href="#cb6-39" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb6-40"><a href="#cb6-40" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.mlp(x)</span>
<span id="cb6-41"><a href="#cb6-41" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> nn.functional.normalize(x, dim<span class="op">=-</span><span class="dv">1</span>, p<span class="op">=</span><span class="dv">2</span>)</span>
<span id="cb6-42"><a href="#cb6-42" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.last_layer(x)</span>
<span id="cb6-43"><a href="#cb6-43" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x</span></code></pre></div></div>
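<p>Two details of this head work together. First, the bottleneck output is L2-normalized before <code>last_layer</code>. Second, weight normalization with <code>weight_g</code> filled with 1 (and frozen when <code>norm_last_layer=True</code>) makes every row of the effective weight a unit vector. As a result, each output logit is the cosine similarity between the embedding and one of <code>out_dim</code> learned prototype directions, bounded in [-1, 1]. The sketch below assumes PyTorch and reproduces that effective computation by normalizing the rows explicitly instead of calling <code>nn.utils.weight_norm</code>. The sizes are illustrative.</p>

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

bottleneck_dim, out_dim, batch = 256, 4096, 4  # illustrative sizes

# Direction parameter v of the weight-normalized layer, one row per prototype
v = torch.randn(out_dim, bottleneck_dim)
# With g fixed to 1, the effective weight is v with each row scaled to unit norm
w = F.normalize(v, dim=1)

# Bottleneck features after the MLP, L2-normalized as in DINOHead.forward
z = F.normalize(torch.randn(batch, bottleneck_dim), dim=-1)

# Equivalent to last_layer(z), which has no bias
logits = z @ w.t()  # [batch, out_dim]
print(logits.shape)  # torch.Size([4, 4096])
print(logits.abs().max().item() <= 1.0 + 1e-6)  # True: every logit is a cosine
```

<p>Freezing <code>weight_g</code> prevents the network from scaling up logits by growing the weight norm. The sharpness of the output distribution is then left to the softmax temperature.</p>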
</section>
<section id="dinov2-model" class="level4">
<h4 class="anchored" data-anchor-id="dinov2-model">DINOv2 Model</h4>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> DINOv2(nn.Module):</span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Complete DINOv2 Model"""</span></span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, backbone_args, head_args):</span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.backbone <span class="op">=</span> VisionTransformer(<span class="op">**</span>backbone_args)</span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.head <span class="op">=</span> DINOHead(<span class="op">**</span>head_args)</span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.backbone(x)</span>
<span id="cb7-10"><a href="#cb7-10" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.head(x)</span>
<span id="cb7-11"><a href="#cb7-11" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x</span></code></pre></div></div>
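<p><code>DINOv2.forward</code> takes a single batch tensor, but multi-crop training produces crops at two resolutions (224 and 96 pixels with the <code>size_crops</code> defaults used for augmentation in this post), and these cannot be stacked into one tensor. The original DINO code handles this with a wrapper with four steps:</p>
<ol>
<li>Group consecutive crops of the same size.</li>
<li>Run the backbone once per group.</li>
<li>Concatenate the CLS features.</li>
<li>Apply the head once.</li>
</ol>
<p>The function below is a hypothetical sketch of that idea. The toy backbone and head exist only so the example runs on its own. A real backbone must accept variable input sizes, which for the ViT above means interpolating its positional embedding.</p>

```python
from typing import List

import torch
import torch.nn as nn


def multi_crop_forward(backbone: nn.Module, head: nn.Module,
                       crops: List[torch.Tensor]) -> torch.Tensor:
    """Hypothetical helper: one backbone pass per resolution group, one head pass.

    crops: list of [B, C, H, W] tensors, ordered so same-size crops are adjacent.
    Returns [len(crops) * B, out_dim], ordered crop by crop.
    """
    feats = []
    start = 0
    while start < len(crops):
        end = start
        size = crops[start].shape[-2:]
        # Extend the group while the next crop has the same spatial size
        while end < len(crops) and crops[end].shape[-2:] == size:
            end += 1
        feats.append(backbone(torch.cat(crops[start:end], dim=0)))
        start = end
    return head(torch.cat(feats, dim=0))


# Toy stand-ins that accept any image size (not the ViT defined above)
backbone = nn.Sequential(nn.Conv2d(3, 32, 16, 16), nn.AdaptiveAvgPool2d(1), nn.Flatten())
head = nn.Linear(32, 10)

B = 4
global_crops = [torch.randn(B, 3, 224, 224) for _ in range(2)]
local_crops = [torch.randn(B, 3, 96, 96) for _ in range(6)]

# Student sees every crop; teacher sees only the global crops.
# (In DINO the teacher is a separate network; the toy modules are reused for brevity.)
student_out = multi_crop_forward(backbone, head, global_crops + local_crops)
teacher_out = multi_crop_forward(backbone, head, global_crops)
print(student_out.shape, teacher_out.shape)  # torch.Size([32, 10]) torch.Size([8, 10])
```

<p>Batching by resolution gives the same outputs as a separate forward pass per crop, while reducing the number of backbone calls from eight to two.</p>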
</section>
</section>
<section id="data-augmentation-and-multi-crop-strategy" class="level3">
<h3 class="anchored" data-anchor-id="data-augmentation-and-multi-crop-strategy" id="data-augmentation-and-multi-crop-strategy">Data Augmentation and Multi-Crop Strategy</h3>
<section id="multi-crop-data-augmentation" class="level4">
<h4 class="anchored" data-anchor-id="multi-crop-data-augmentation">Multi-Crop Data Augmentation</h4>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb8"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MultiCropDataAugmentation:</span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Multi-crop data augmentation for DINOv2"""</span></span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, global_crops_scale<span class="op">=</span>(<span class="fl">0.4</span>, <span class="fl">1.0</span>), local_crops_scale<span class="op">=</span>(<span class="fl">0.05</span>, <span class="fl">0.4</span>),</span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a>                 global_crops_number<span class="op">=</span><span class="dv">2</span>, local_crops_number<span class="op">=</span><span class="dv">6</span>, size_crops<span class="op">=</span>(<span class="dv">224</span>, <span class="dv">96</span>)):</span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.global_crops_number <span class="op">=</span> global_crops_number</span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.local_crops_number <span class="op">=</span> local_crops_number</span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Global crops (teacher and student)</span></span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.global_transform <span class="op">=</span> transforms.Compose([</span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a>            transforms.RandomResizedCrop(size_crops[<span class="dv">0</span>], scale<span class="op">=</span>global_crops_scale, </span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a>                                       interpolation<span class="op">=</span>transforms.InterpolationMode.BICUBIC),</span>
<span id="cb8-12"><a href="#cb8-12" aria-hidden="true" tabindex="-1"></a>            transforms.RandomHorizontalFlip(p<span class="op">=</span><span class="fl">0.5</span>),</span>
<span id="cb8-13"><a href="#cb8-13" aria-hidden="true" tabindex="-1"></a>            transforms.ColorJitter(brightness<span class="op">=</span><span class="fl">0.4</span>, contrast<span class="op">=</span><span class="fl">0.4</span>, saturation<span class="op">=</span><span class="fl">0.2</span>, hue<span class="op">=</span><span class="fl">0.1</span>),</span>
<span id="cb8-14"><a href="#cb8-14" aria-hidden="true" tabindex="-1"></a>            transforms.RandomGrayscale(p<span class="op">=</span><span class="fl">0.2</span>),</span>
<span id="cb8-15"><a href="#cb8-15" aria-hidden="true" tabindex="-1"></a>            GaussianBlur(p<span class="op">=</span><span class="fl">1.0</span>),</span>
<span id="cb8-16"><a href="#cb8-16" aria-hidden="true" tabindex="-1"></a>            Solarization(p<span class="op">=</span><span class="fl">0.0</span>),</span>
<span id="cb8-17"><a href="#cb8-17" aria-hidden="true" tabindex="-1"></a>            transforms.ToTensor(),</span>
<span id="cb8-18"><a href="#cb8-18" aria-hidden="true" tabindex="-1"></a>            transforms.Normalize((<span class="fl">0.485</span>, <span class="fl">0.456</span>, <span class="fl">0.406</span>), (<span class="fl">0.229</span>, <span class="fl">0.224</span>, <span class="fl">0.225</span>))</span>
<span id="cb8-19"><a href="#cb8-19" aria-hidden="true" tabindex="-1"></a>        ])</span>
<span id="cb8-20"><a href="#cb8-20" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-21"><a href="#cb8-21" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Local crops (student only)</span></span>
<span id="cb8-22"><a href="#cb8-22" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.local_transform <span class="op">=</span> transforms.Compose([</span>
<span id="cb8-23"><a href="#cb8-23" aria-hidden="true" tabindex="-1"></a>            transforms.RandomResizedCrop(size_crops[<span class="dv">1</span>], scale<span class="op">=</span>local_crops_scale,</span>
<span id="cb8-24"><a href="#cb8-24" aria-hidden="true" tabindex="-1"></a>                                       interpolation<span class="op">=</span>transforms.InterpolationMode.BICUBIC),</span>
<span id="cb8-25"><a href="#cb8-25" aria-hidden="true" tabindex="-1"></a>            transforms.RandomHorizontalFlip(p<span class="op">=</span><span class="fl">0.5</span>),</span>
<span id="cb8-26"><a href="#cb8-26" aria-hidden="true" tabindex="-1"></a>            transforms.ColorJitter(brightness<span class="op">=</span><span class="fl">0.4</span>, contrast<span class="op">=</span><span class="fl">0.4</span>, saturation<span class="op">=</span><span class="fl">0.2</span>, hue<span class="op">=</span><span class="fl">0.1</span>),</span>
<span id="cb8-27"><a href="#cb8-27" aria-hidden="true" tabindex="-1"></a>            transforms.RandomGrayscale(p<span class="op">=</span><span class="fl">0.2</span>),</span>
<span id="cb8-28"><a href="#cb8-28" aria-hidden="true" tabindex="-1"></a>            GaussianBlur(p<span class="op">=</span><span class="fl">0.5</span>),</span>
<span id="cb8-29"><a href="#cb8-29" aria-hidden="true" tabindex="-1"></a>            Solarization(p<span class="op">=</span><span class="fl">0.2</span>),</span>
<span id="cb8-30"><a href="#cb8-30" aria-hidden="true" tabindex="-1"></a>            transforms.ToTensor(),</span>
<span id="cb8-31"><a href="#cb8-31" aria-hidden="true" tabindex="-1"></a>            transforms.Normalize((<span class="fl">0.485</span>, <span class="fl">0.456</span>, <span class="fl">0.406</span>), (<span class="fl">0.229</span>, <span class="fl">0.224</span>, <span class="fl">0.225</span>))</span>
<span id="cb8-32"><a href="#cb8-32" aria-hidden="true" tabindex="-1"></a>        ])</span>
<span id="cb8-33"><a href="#cb8-33" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-34"><a href="#cb8-34" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__call__</span>(<span class="va">self</span>, image):</span>
<span id="cb8-35"><a href="#cb8-35" aria-hidden="true" tabindex="-1"></a>        crops <span class="op">=</span> []</span>
<span id="cb8-36"><a href="#cb8-36" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-37"><a href="#cb8-37" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Global crops</span></span>
<span id="cb8-38"><a href="#cb8-38" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(<span class="va">self</span>.global_crops_number):</span>
<span id="cb8-39"><a href="#cb8-39" aria-hidden="true" tabindex="-1"></a>            crops.append(<span class="va">self</span>.global_transform(image))</span>
<span id="cb8-40"><a href="#cb8-40" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-41"><a href="#cb8-41" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Local crops</span></span>
<span id="cb8-42"><a href="#cb8-42" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(<span class="va">self</span>.local_crops_number):</span>
<span id="cb8-43"><a href="#cb8-43" aria-hidden="true" tabindex="-1"></a>            crops.append(<span class="va">self</span>.local_transform(image))</span>
<span id="cb8-44"><a href="#cb8-44" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-45"><a href="#cb8-45" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> crops</span></code></pre></div></div>
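<p>The multi-crop transform returns a plain Python list of crops for each image. The sketch below shows how PyTorch's default collate function turns a batch of such lists into one stacked tensor per crop position. This is why the training loop iterates over <code>images</code> and moves each tensor to the device separately. It assumes the dataset yields <code>(crops, label)</code> pairs. Random tensors stand in for transformed images, and the crop sizes (224 and 96) and counts (2 global, 6 local) are illustrative assumptions rather than values fixed by the code above.</p>

```python
import torch
from torch.utils.data.dataloader import default_collate

# Illustrative assumptions: 2 global crops at 224 px, 6 local crops at 96 px
N_GLOBAL, N_LOCAL = 2, 6
GLOBAL_SIZE, LOCAL_SIZE = 224, 96
BATCH_SIZE = 4


def fake_sample():
    """Stand-in for one dataset item: (list of crops, label)."""
    crops = [torch.randn(3, GLOBAL_SIZE, GLOBAL_SIZE) for _ in range(N_GLOBAL)]
    crops += [torch.randn(3, LOCAL_SIZE, LOCAL_SIZE) for _ in range(N_LOCAL)]
    return crops, 0


# default_collate transposes the nested lists: crop i of every sample
# is stacked into a single (BATCH_SIZE, 3, H, W) tensor
images, labels = default_collate([fake_sample() for _ in range(BATCH_SIZE)])

for i, im in enumerate(images):
    print(f"crop {i}: {tuple(im.shape)}")
```

<p>Global and local crops have different resolutions, so they cannot be concatenated into one tensor. Keeping them as a list of per-position batches avoids that problem.</p>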
</section>
<section id="augmentation-utilities" class="level4">
<h4 class="anchored" data-anchor-id="augmentation-utilities">Augmentation Utilities</h4>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb9"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> GaussianBlur:</span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""With probability p, blur with a 9x9 Gaussian kernel whose sigma is drawn uniformly from [radius_min, radius_max]."""</span></span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, p<span class="op">=</span><span class="fl">0.5</span>, radius_min<span class="op">=</span><span class="fl">0.1</span>, radius_max<span class="op">=</span><span class="fl">2.0</span>):</span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.prob <span class="op">=</span> p</span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.radius_min <span class="op">=</span> radius_min</span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.radius_max <span class="op">=</span> radius_max</span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__call__</span>(<span class="va">self</span>, img):</span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> torch.rand(<span class="dv">1</span>) <span class="op">&lt;</span> <span class="va">self</span>.prob:</span>
<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a>            radius <span class="op">=</span> <span class="va">self</span>.radius_min <span class="op">+</span> torch.rand(<span class="dv">1</span>) <span class="op">*</span> (<span class="va">self</span>.radius_max <span class="op">-</span> <span class="va">self</span>.radius_min)</span>
<span id="cb9-11"><a href="#cb9-11" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> transforms.functional.gaussian_blur(img, kernel_size<span class="op">=</span><span class="dv">9</span>, sigma<span class="op">=</span>radius.item())</span>
<span id="cb9-12"><a href="#cb9-12" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> img</span>
<span id="cb9-13"><a href="#cb9-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-14"><a href="#cb9-14" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> Solarization:</span>
<span id="cb9-15"><a href="#cb9-15" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""With probability p, invert pixels at or above 128. Expects PIL or uint8 input, so apply it before ToTensor."""</span></span>
<span id="cb9-16"><a href="#cb9-16" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, p<span class="op">=</span><span class="fl">0.2</span>):</span>
<span id="cb9-17"><a href="#cb9-17" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.p <span class="op">=</span> p</span>
<span id="cb9-18"><a href="#cb9-18" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-19"><a href="#cb9-19" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__call__</span>(<span class="va">self</span>, img):</span>
<span id="cb9-20"><a href="#cb9-20" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> torch.rand(<span class="dv">1</span>) <span class="op">&lt;</span> <span class="va">self</span>.p:</span>
<span id="cb9-21"><a href="#cb9-21" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> transforms.functional.solarize(img, threshold<span class="op">=</span><span class="dv">128</span>)</span>
<span id="cb9-22"><a href="#cb9-22" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> img</span></code></pre></div></div>
</section>
</section>
<section id="loss-functions-and-training-components" class="level3">
<h3 class="anchored" data-anchor-id="loss-functions-and-training-components" id="loss-functions-and-training-components">Loss Functions and Training Components</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb10"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> DINOLoss(nn.Module):</span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""DINO Loss with centering and sharpening"""</span></span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, out_dim, ncrops, warmup_teacher_temp<span class="op">=</span><span class="fl">0.04</span>, </span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a>                 teacher_temp<span class="op">=</span><span class="fl">0.04</span>, warmup_teacher_temp_epochs<span class="op">=</span><span class="dv">0</span>, </span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a>                 student_temp<span class="op">=</span><span class="fl">0.1</span>, center_momentum<span class="op">=</span><span class="fl">0.9</span>):</span>
<span id="cb10-6"><a href="#cb10-6" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb10-7"><a href="#cb10-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.student_temp <span class="op">=</span> student_temp</span>
<span id="cb10-8"><a href="#cb10-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.center_momentum <span class="op">=</span> center_momentum</span>
<span id="cb10-9"><a href="#cb10-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.ncrops <span class="op">=</span> ncrops</span>
<span id="cb10-10"><a href="#cb10-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.register_buffer(<span class="st">"center"</span>, torch.zeros(<span class="dv">1</span>, out_dim))</span>
<span id="cb10-11"><a href="#cb10-11" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb10-12"><a href="#cb10-12" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Temperature schedule</span></span>
<span id="cb10-13"><a href="#cb10-13" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.teacher_temp_schedule <span class="op">=</span> np.concatenate((</span>
<span id="cb10-14"><a href="#cb10-14" aria-hidden="true" tabindex="-1"></a>            np.linspace(warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs),</span>
<span id="cb10-15"><a href="#cb10-15" aria-hidden="true" tabindex="-1"></a>            np.ones(<span class="dv">1000</span>) <span class="op">*</span> teacher_temp  <span class="co"># Assume max 1000 epochs</span></span>
<span id="cb10-16"><a href="#cb10-16" aria-hidden="true" tabindex="-1"></a>        ))</span>
<span id="cb10-17"><a href="#cb10-17" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-18"><a href="#cb10-18" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, student_output, teacher_output, epoch):</span>
<span id="cb10-19"><a href="#cb10-19" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""</span></span>
<span id="cb10-20"><a href="#cb10-20" aria-hidden="true" tabindex="-1"></a><span class="co">        Cross-entropy between softmax outputs of the teacher and student networks.</span></span>
<span id="cb10-21"><a href="#cb10-21" aria-hidden="true" tabindex="-1"></a><span class="co">        """</span></span>
<span id="cb10-22"><a href="#cb10-22" aria-hidden="true" tabindex="-1"></a>        student_out <span class="op">=</span> student_output <span class="op">/</span> <span class="va">self</span>.student_temp</span>
<span id="cb10-23"><a href="#cb10-23" aria-hidden="true" tabindex="-1"></a>        student_out <span class="op">=</span> student_out.chunk(<span class="va">self</span>.ncrops)</span>
<span id="cb10-24"><a href="#cb10-24" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb10-25"><a href="#cb10-25" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Teacher centering and sharpening</span></span>
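<span id="cb10-26"><a href="#cb10-26" aria-hidden="true" tabindex="-1"></a>        temp <span class="op">=</span> <span class="va">self</span>.teacher_temp_schedule[<span class="bu">min</span>(epoch, <span class="bu">len</span>(<span class="va">self</span>.teacher_temp_schedule) <span class="op">-</span> <span class="dv">1</span>)]  <span class="co"># Clamp so long runs reuse the final temperature</span></span>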
<span id="cb10-27"><a href="#cb10-27" aria-hidden="true" tabindex="-1"></a>        teacher_out <span class="op">=</span> F.softmax((teacher_output <span class="op">-</span> <span class="va">self</span>.center) <span class="op">/</span> temp, dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb10-28"><a href="#cb10-28" aria-hidden="true" tabindex="-1"></a>        teacher_out <span class="op">=</span> teacher_out.detach().chunk(<span class="dv">2</span>)  <span class="co"># Only 2 global crops for teacher</span></span>
<span id="cb10-29"><a href="#cb10-29" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb10-30"><a href="#cb10-30" aria-hidden="true" tabindex="-1"></a>        total_loss <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb10-31"><a href="#cb10-31" aria-hidden="true" tabindex="-1"></a>        n_loss_terms <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb10-32"><a href="#cb10-32" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb10-33"><a href="#cb10-33" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> iq, q <span class="kw">in</span> <span class="bu">enumerate</span>(teacher_out):</span>
<span id="cb10-34"><a href="#cb10-34" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> v <span class="kw">in</span> <span class="bu">range</span>(<span class="bu">len</span>(student_out)):</span>
<span id="cb10-35"><a href="#cb10-35" aria-hidden="true" tabindex="-1"></a>                <span class="cf">if</span> v <span class="op">==</span> iq:</span>
<span id="cb10-36"><a href="#cb10-36" aria-hidden="true" tabindex="-1"></a>                    <span class="cf">continue</span>  <span class="co"># Skip same crop</span></span>
<span id="cb10-37"><a href="#cb10-37" aria-hidden="true" tabindex="-1"></a>                loss <span class="op">=</span> torch.<span class="bu">sum</span>(<span class="op">-</span>q <span class="op">*</span> F.log_softmax(student_out[v], dim<span class="op">=-</span><span class="dv">1</span>), dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb10-38"><a href="#cb10-38" aria-hidden="true" tabindex="-1"></a>                total_loss <span class="op">+=</span> loss.mean()</span>
<span id="cb10-39"><a href="#cb10-39" aria-hidden="true" tabindex="-1"></a>                n_loss_terms <span class="op">+=</span> <span class="dv">1</span></span>
<span id="cb10-40"><a href="#cb10-40" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb10-41"><a href="#cb10-41" aria-hidden="true" tabindex="-1"></a>        total_loss <span class="op">/=</span> n_loss_terms</span>
<span id="cb10-42"><a href="#cb10-42" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.update_center(teacher_output)</span>
<span id="cb10-43"><a href="#cb10-43" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> total_loss</span>
<span id="cb10-44"><a href="#cb10-44" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-45"><a href="#cb10-45" aria-hidden="true" tabindex="-1"></a>    <span class="at">@torch.no_grad</span>()</span>
<span id="cb10-46"><a href="#cb10-46" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> update_center(<span class="va">self</span>, teacher_output):</span>
<span id="cb10-47"><a href="#cb10-47" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Update center used for teacher output."""</span></span>
<span id="cb10-48"><a href="#cb10-48" aria-hidden="true" tabindex="-1"></a>        batch_center <span class="op">=</span> torch.<span class="bu">sum</span>(teacher_output, dim<span class="op">=</span><span class="dv">0</span>, keepdim<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb10-49"><a href="#cb10-49" aria-hidden="true" tabindex="-1"></a>        batch_center <span class="op">=</span> batch_center <span class="op">/</span> <span class="bu">len</span>(teacher_output)</span>
<span id="cb10-50"><a href="#cb10-50" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb10-51"><a href="#cb10-51" aria-hidden="true" tabindex="-1"></a>        <span class="co"># EMA update</span></span>
<span id="cb10-52"><a href="#cb10-52" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.center <span class="op">=</span> <span class="va">self</span>.center <span class="op">*</span> <span class="va">self</span>.center_momentum <span class="op">+</span> batch_center <span class="op">*</span> (<span class="dv">1</span> <span class="op">-</span> <span class="va">self</span>.center_momentum)</span></code></pre></div></div>
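<p><code>DINOLoss</code> applies two opposing operations to the teacher output. Centering subtracts a running mean so that no single dimension dominates. A low teacher temperature sharpens the distribution so that it does not flatten to uniform. The numeric sketch below compares the average entropy of the teacher's softmax in three settings. It assumes a small hypothetical output dimension <code>K = 16</code> and a synthetic "collapsing" teacher that puts a large logit on dimension 0 for every sample. The batch mean stands in for the EMA center that <code>update_center</code> maintains.</p>

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)


def mean_entropy(logits, temp):
    """Average entropy (in nats) of softmax(logits / temp) over the batch."""
    p = F.softmax(logits / temp, dim=-1)
    return -(p * torch.log(p + 1e-12)).sum(dim=-1).mean().item()


K = 16  # hypothetical output dimension, kept small for illustration
teacher_logits = 0.1 * torch.randn(32, K)
teacher_logits[:, 0] += 5.0  # every sample favors dimension 0: a collapse mode

# Batch mean as a stand-in for the EMA center kept by DINOLoss
center = teacher_logits.mean(dim=0, keepdim=True)

h_raw = mean_entropy(teacher_logits, temp=0.04)                 # sharpening only
h_centered_sharp = mean_entropy(teacher_logits - center, 0.04)  # centering + sharpening
h_centered_soft = mean_entropy(teacher_logits - center, 1.0)    # centering, soft temperature
max_entropy = math.log(K)                                       # entropy of the uniform distribution

print(f"sharpen only:       {h_raw:.3f}")
print(f"center + sharpen:   {h_centered_sharp:.3f}")
print(f"center, T=1.0:      {h_centered_soft:.3f}  (max {max_entropy:.3f})")
```

<p>Sharpening alone gives a target that is one-hot on the same dimension for every image, with entropy near 0. Centering with a soft temperature gives a near-uniform target, with entropy near log K. Combining the two keeps the target between these extremes.</p>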
<section id="training-utilities" class="level4">
<h4 class="anchored" data-anchor-id="training-utilities">Training Utilities</h4>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb11"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a><span class="at">@torch.no_grad</span>()</span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> update_teacher(student, teacher, momentum):</span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""EMA update of the teacher network."""</span></span>
<span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> param_student, param_teacher <span class="kw">in</span> <span class="bu">zip</span>(student.parameters(), teacher.parameters()):</span>
<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a>        param_teacher.data.mul_(momentum).add_(param_student.data, alpha<span class="op">=</span><span class="dv">1</span> <span class="op">-</span> momentum)</span>
<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs<span class="op">=</span><span class="dv">0</span>):</span>
<span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Per-iteration cosine schedule from base_value to final_value with optional linear warmup (used for the learning rate and the teacher momentum)."""</span></span>
<span id="cb11-9"><a href="#cb11-9" aria-hidden="true" tabindex="-1"></a>    warmup_schedule <span class="op">=</span> np.array([])</span>
<span id="cb11-10"><a href="#cb11-10" aria-hidden="true" tabindex="-1"></a>    warmup_iters <span class="op">=</span> warmup_epochs <span class="op">*</span> niter_per_ep</span>
<span id="cb11-11"><a href="#cb11-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-12"><a href="#cb11-12" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> warmup_epochs <span class="op">&gt;</span> <span class="dv">0</span>:</span>
<span id="cb11-13"><a href="#cb11-13" aria-hidden="true" tabindex="-1"></a>        warmup_schedule <span class="op">=</span> np.linspace(<span class="dv">0</span>, base_value, warmup_iters)</span>
<span id="cb11-14"><a href="#cb11-14" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-15"><a href="#cb11-15" aria-hidden="true" tabindex="-1"></a>    iters <span class="op">=</span> np.arange(epochs <span class="op">*</span> niter_per_ep <span class="op">-</span> warmup_iters)</span>
<span id="cb11-16"><a href="#cb11-16" aria-hidden="true" tabindex="-1"></a>    schedule <span class="op">=</span> final_value <span class="op">+</span> <span class="fl">0.5</span> <span class="op">*</span> (base_value <span class="op">-</span> final_value) <span class="op">*</span> (<span class="dv">1</span> <span class="op">+</span> np.cos(np.pi <span class="op">*</span> iters <span class="op">/</span> <span class="bu">len</span>(iters)))</span>
<span id="cb11-17"><a href="#cb11-17" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-18"><a href="#cb11-18" aria-hidden="true" tabindex="-1"></a>    schedule <span class="op">=</span> np.concatenate((warmup_schedule, schedule))</span>
<span id="cb11-19"><a href="#cb11-19" aria-hidden="true" tabindex="-1"></a>    <span class="cf">assert</span> <span class="bu">len</span>(schedule) <span class="op">==</span> epochs <span class="op">*</span> niter_per_ep</span>
<span id="cb11-20"><a href="#cb11-20" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> schedule</span></code></pre></div></div>
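<p>The momentum schedule built with <code>cosine_scheduler(config['momentum_teacher'], 1.0, ...)</code> pushes the teacher's EMA momentum toward 1, so the teacher changes more and more slowly as training goes on. The sketch below evaluates the cosine formula at a few points and reports the rough averaging horizon <code>1 / (1 - m)</code>. It assumes a starting momentum of 0.996, a value commonly used with DINO (check your own config), and a hypothetical run of 1000 iterations. It then applies the same in-place <code>mul_</code>/<code>add_</code> update as <code>update_teacher</code> to a toy <code>nn.Linear</code> pair. This shows that after n steps the teacher has moved a fraction <code>1 - m**n</code> of the way toward a fixed student.</p>

```python
import math

import torch
import torch.nn as nn


def cosine_value(base_value, final_value, i, total_iters):
    """Value of the cosine schedule (no warmup) at iteration i."""
    return final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / total_iters))


TOTAL_ITERS = 1000  # hypothetical epochs * niter_per_ep
for i in (0, TOTAL_ITERS // 4, TOTAL_ITERS // 2, 3 * TOTAL_ITERS // 4):
    m = cosine_value(0.996, 1.0, i, TOTAL_ITERS)
    # An EMA with momentum m averages over roughly 1 / (1 - m) recent steps
    print(f"iter {i:4d}: momentum {m:.5f}, horizon ~{1 / (1 - m):.0f} steps")

# Toy EMA: the student is fixed at 1.0, the teacher starts at 0.0
student, teacher = nn.Linear(4, 4), nn.Linear(4, 4)
with torch.no_grad():
    for p in student.parameters():
        p.fill_(1.0)
    for p in teacher.parameters():
        p.zero_()

momentum, steps = 0.9, 3
with torch.no_grad():
    for _ in range(steps):
        for p_s, p_t in zip(student.parameters(), teacher.parameters()):
            # Same update rule as update_teacher: p_t = m * p_t + (1 - m) * p_s
            p_t.mul_(momentum).add_(p_s, alpha=1 - momentum)

print(teacher.weight[0, 0].item(), 1 - momentum ** steps)
```

<p>The horizon grows from a few hundred steps toward infinity as the momentum approaches 1.0, so the teacher is nearly frozen by the end of training.</p>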
</section>
</section>
</section>
<section id="training-loop-implementation" class="level2">
<h2 class="anchored" data-anchor-id="training-loop-implementation" id="training-loop-implementation">Training Loop Implementation</h2>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb12"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><a href="#cb12-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> DINOv2Trainer:</span>
<span id="cb12-2"><a href="#cb12-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""DINOv2 Training Pipeline"""</span></span>
<span id="cb12-3"><a href="#cb12-3" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, config):</span>
<span id="cb12-4"><a href="#cb12-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.config <span class="op">=</span> config</span>
<span id="cb12-5"><a href="#cb12-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.device <span class="op">=</span> torch.device(<span class="st">'cuda'</span> <span class="cf">if</span> torch.cuda.is_available() <span class="cf">else</span> <span class="st">'cpu'</span>)</span>
<span id="cb12-6"><a href="#cb12-6" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-7"><a href="#cb12-7" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Model architecture configs</span></span>
<span id="cb12-8"><a href="#cb12-8" aria-hidden="true" tabindex="-1"></a>        backbone_args <span class="op">=</span> {</span>
<span id="cb12-9"><a href="#cb12-9" aria-hidden="true" tabindex="-1"></a>            <span class="st">'img_size'</span>: <span class="dv">224</span>,</span>
<span id="cb12-10"><a href="#cb12-10" aria-hidden="true" tabindex="-1"></a>            <span class="st">'patch_size'</span>: <span class="dv">16</span>,</span>
<span id="cb12-11"><a href="#cb12-11" aria-hidden="true" tabindex="-1"></a>            <span class="st">'embed_dim'</span>: <span class="dv">768</span>,</span>
<span id="cb12-12"><a href="#cb12-12" aria-hidden="true" tabindex="-1"></a>            <span class="st">'depth'</span>: <span class="dv">12</span>,</span>
<span id="cb12-13"><a href="#cb12-13" aria-hidden="true" tabindex="-1"></a>            <span class="st">'num_heads'</span>: <span class="dv">12</span>,</span>
<span id="cb12-14"><a href="#cb12-14" aria-hidden="true" tabindex="-1"></a>            <span class="st">'mlp_ratio'</span>: <span class="fl">4.0</span>,</span>
<span id="cb12-15"><a href="#cb12-15" aria-hidden="true" tabindex="-1"></a>            <span class="st">'dropout'</span>: <span class="fl">0.0</span></span>
<span id="cb12-16"><a href="#cb12-16" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb12-17"><a href="#cb12-17" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-18"><a href="#cb12-18" aria-hidden="true" tabindex="-1"></a>        head_args <span class="op">=</span> {</span>
<span id="cb12-19"><a href="#cb12-19" aria-hidden="true" tabindex="-1"></a>            <span class="st">'in_dim'</span>: <span class="dv">768</span>,</span>
<span id="cb12-20"><a href="#cb12-20" aria-hidden="true" tabindex="-1"></a>            <span class="st">'out_dim'</span>: <span class="dv">65536</span>,  <span class="co"># Output dimension K of the projection head (softmax over K dims)</span></span>
<span id="cb12-21"><a href="#cb12-21" aria-hidden="true" tabindex="-1"></a>            <span class="st">'hidden_dim'</span>: <span class="dv">2048</span>,</span>
<span id="cb12-22"><a href="#cb12-22" aria-hidden="true" tabindex="-1"></a>            <span class="st">'bottleneck_dim'</span>: <span class="dv">256</span></span>
<span id="cb12-23"><a href="#cb12-23" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb12-24"><a href="#cb12-24" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-25"><a href="#cb12-25" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Initialize student and teacher networks</span></span>
<span id="cb12-26"><a href="#cb12-26" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.student <span class="op">=</span> DINOv2(backbone_args, head_args).to(<span class="va">self</span>.device)</span>
<span id="cb12-27"><a href="#cb12-27" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.teacher <span class="op">=</span> DINOv2(backbone_args, head_args).to(<span class="va">self</span>.device)</span>
<span id="cb12-28"><a href="#cb12-28" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-29"><a href="#cb12-29" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Teacher starts as copy of student</span></span>
<span id="cb12-30"><a href="#cb12-30" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.teacher.load_state_dict(<span class="va">self</span>.student.state_dict())</span>
<span id="cb12-31"><a href="#cb12-31" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-32"><a href="#cb12-32" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Teacher parameters are not updated by gradients</span></span>
<span id="cb12-33"><a href="#cb12-33" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> p <span class="kw">in</span> <span class="va">self</span>.teacher.parameters():</span>
<span id="cb12-34"><a href="#cb12-34" aria-hidden="true" tabindex="-1"></a>            p.requires_grad <span class="op">=</span> <span class="va">False</span></span>
<span id="cb12-35"><a href="#cb12-35" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-36"><a href="#cb12-36" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Loss function</span></span>
<span id="cb12-37"><a href="#cb12-37" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.dino_loss <span class="op">=</span> DINOLoss(</span>
<span id="cb12-38"><a href="#cb12-38" aria-hidden="true" tabindex="-1"></a>            out_dim<span class="op">=</span>head_args[<span class="st">'out_dim'</span>],</span>
<span id="cb12-39"><a href="#cb12-39" aria-hidden="true" tabindex="-1"></a>            ncrops<span class="op">=</span><span class="dv">8</span>,  <span class="co"># 2 global + 6 local crops</span></span>
<span id="cb12-40"><a href="#cb12-40" aria-hidden="true" tabindex="-1"></a>            student_temp<span class="op">=</span><span class="fl">0.1</span>,</span>
<span id="cb12-41"><a href="#cb12-41" aria-hidden="true" tabindex="-1"></a>            teacher_temp<span class="op">=</span><span class="fl">0.04</span>,</span>
<span id="cb12-42"><a href="#cb12-42" aria-hidden="true" tabindex="-1"></a>            center_momentum<span class="op">=</span><span class="fl">0.9</span></span>
<span id="cb12-43"><a href="#cb12-43" aria-hidden="true" tabindex="-1"></a>        ).to(<span class="va">self</span>.device)</span>
<span id="cb12-44"><a href="#cb12-44" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-45"><a href="#cb12-45" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Optimizer</span></span>
<span id="cb12-46"><a href="#cb12-46" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.optimizer <span class="op">=</span> torch.optim.AdamW(</span>
<span id="cb12-47"><a href="#cb12-47" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.student.parameters(),</span>
<span id="cb12-48"><a href="#cb12-48" aria-hidden="true" tabindex="-1"></a>            lr<span class="op">=</span>config[<span class="st">'base_lr'</span>],</span>
<span id="cb12-49"><a href="#cb12-49" aria-hidden="true" tabindex="-1"></a>            weight_decay<span class="op">=</span>config[<span class="st">'weight_decay'</span>]</span>
<span id="cb12-50"><a href="#cb12-50" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb12-51"><a href="#cb12-51" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-52"><a href="#cb12-52" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Learning rate scheduler</span></span>
<span id="cb12-53"><a href="#cb12-53" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.lr_schedule <span class="op">=</span> cosine_scheduler(</span>
<span id="cb12-54"><a href="#cb12-54" aria-hidden="true" tabindex="-1"></a>            config[<span class="st">'base_lr'</span>],</span>
<span id="cb12-55"><a href="#cb12-55" aria-hidden="true" tabindex="-1"></a>            config[<span class="st">'final_lr'</span>],</span>
<span id="cb12-56"><a href="#cb12-56" aria-hidden="true" tabindex="-1"></a>            config[<span class="st">'epochs'</span>],</span>
<span id="cb12-57"><a href="#cb12-57" aria-hidden="true" tabindex="-1"></a>            config[<span class="st">'niter_per_ep'</span>],</span>
<span id="cb12-58"><a href="#cb12-58" aria-hidden="true" tabindex="-1"></a>            config[<span class="st">'warmup_epochs'</span>]</span>
<span id="cb12-59"><a href="#cb12-59" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb12-60"><a href="#cb12-60" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-61"><a href="#cb12-61" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Momentum schedule for teacher updates</span></span>
<span id="cb12-62"><a href="#cb12-62" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.momentum_schedule <span class="op">=</span> cosine_scheduler(</span>
<span id="cb12-63"><a href="#cb12-63" aria-hidden="true" tabindex="-1"></a>            config[<span class="st">'momentum_teacher'</span>],</span>
<span id="cb12-64"><a href="#cb12-64" aria-hidden="true" tabindex="-1"></a>            <span class="fl">1.0</span>,</span>
<span id="cb12-65"><a href="#cb12-65" aria-hidden="true" tabindex="-1"></a>            config[<span class="st">'epochs'</span>],</span>
<span id="cb12-66"><a href="#cb12-66" aria-hidden="true" tabindex="-1"></a>            config[<span class="st">'niter_per_ep'</span>]</span>
<span id="cb12-67"><a href="#cb12-67" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb12-68"><a href="#cb12-68" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-69"><a href="#cb12-69" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> train_epoch(<span class="va">self</span>, dataloader, epoch):</span>
<span id="cb12-70"><a href="#cb12-70" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Train for one epoch"""</span></span>
<span id="cb12-71"><a href="#cb12-71" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.student.train()</span>
<span id="cb12-72"><a href="#cb12-72" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.teacher.<span class="bu">eval</span>()</span>
<span id="cb12-73"><a href="#cb12-73" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-74"><a href="#cb12-74" aria-hidden="true" tabindex="-1"></a>        total_loss <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb12-75"><a href="#cb12-75" aria-hidden="true" tabindex="-1"></a>        num_batches <span class="op">=</span> <span class="bu">len</span>(dataloader)</span>
<span id="cb12-76"><a href="#cb12-76" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-77"><a href="#cb12-77" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> it, (images, _) <span class="kw">in</span> <span class="bu">enumerate</span>(dataloader):</span>
<span id="cb12-78"><a href="#cb12-78" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Update learning rate</span></span>
<span id="cb12-79"><a href="#cb12-79" aria-hidden="true" tabindex="-1"></a>            lr <span class="op">=</span> <span class="va">self</span>.lr_schedule[epoch <span class="op">*</span> num_batches <span class="op">+</span> it]</span>
<span id="cb12-80"><a href="#cb12-80" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> param_group <span class="kw">in</span> <span class="va">self</span>.optimizer.param_groups:</span>
<span id="cb12-81"><a href="#cb12-81" aria-hidden="true" tabindex="-1"></a>                param_group[<span class="st">'lr'</span>] <span class="op">=</span> lr</span>
<span id="cb12-82"><a href="#cb12-82" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb12-83"><a href="#cb12-83" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Move to device and prepare crops</span></span>
<span id="cb12-84"><a href="#cb12-84" aria-hidden="true" tabindex="-1"></a>            images <span class="op">=</span> [im.to(<span class="va">self</span>.device, non_blocking<span class="op">=</span><span class="va">True</span>) <span class="cf">for</span> im <span class="kw">in</span> images]</span>
<span id="cb12-85"><a href="#cb12-85" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb12-86"><a href="#cb12-86" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Teacher forward pass (only on global crops); no gradients flow into the teacher</span></span>
<span id="cb12-87"><a href="#cb12-87" aria-hidden="true" tabindex="-1"></a>            <span class="cf">with</span> torch.no_grad():</span>
<span id="cb12-88"><a href="#cb12-88" aria-hidden="true" tabindex="-1"></a>                teacher_output <span class="op">=</span> <span class="va">self</span>.teacher(torch.cat(images[:<span class="dv">2</span>]))</span>
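<span id="cb12-89"><a href="#cb12-89" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Student forward pass on all crops: global and local crops differ in size, so batch them separately (backbone must accept both sizes)</span></span>
<span id="cb12-90"><a href="#cb12-90" aria-hidden="true" tabindex="-1"></a>            student_output <span class="op">=</span> torch.cat([<span class="va">self</span>.student(torch.cat(images[:<span class="dv">2</span>])), <span class="va">self</span>.student(torch.cat(images[<span class="dv">2</span>:]))])</span>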
<span id="cb12-91"><a href="#cb12-91" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb12-92"><a href="#cb12-92" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Compute loss</span></span>
<span id="cb12-93"><a href="#cb12-93" aria-hidden="true" tabindex="-1"></a>            loss <span class="op">=</span> <span class="va">self</span>.dino_loss(student_output, teacher_output, epoch)</span>
<span id="cb12-94"><a href="#cb12-94" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb12-95"><a href="#cb12-95" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Backward pass</span></span>
<span id="cb12-96"><a href="#cb12-96" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.optimizer.zero_grad()</span>
<span id="cb12-97"><a href="#cb12-97" aria-hidden="true" tabindex="-1"></a>            loss.backward()</span>
<span id="cb12-98"><a href="#cb12-98" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb12-99"><a href="#cb12-99" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Gradient clipping</span></span>
<span id="cb12-100"><a href="#cb12-100" aria-hidden="true" tabindex="-1"></a>            torch.nn.utils.clip_grad_norm_(<span class="va">self</span>.student.parameters(), max_norm<span class="op">=</span><span class="fl">3.0</span>)</span>
<span id="cb12-101"><a href="#cb12-101" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb12-102"><a href="#cb12-102" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.optimizer.step()</span>
<span id="cb12-103"><a href="#cb12-103" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb12-104"><a href="#cb12-104" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Update teacher with EMA</span></span>
<span id="cb12-105"><a href="#cb12-105" aria-hidden="true" tabindex="-1"></a>            momentum <span class="op">=</span> <span class="va">self</span>.momentum_schedule[epoch <span class="op">*</span> num_batches <span class="op">+</span> it]</span>
<span id="cb12-106"><a href="#cb12-106" aria-hidden="true" tabindex="-1"></a>            update_teacher(<span class="va">self</span>.student, <span class="va">self</span>.teacher, momentum)</span>
<span id="cb12-107"><a href="#cb12-107" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb12-108"><a href="#cb12-108" aria-hidden="true" tabindex="-1"></a>            total_loss <span class="op">+=</span> loss.item()</span>
<span id="cb12-109"><a href="#cb12-109" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb12-110"><a href="#cb12-110" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> it <span class="op">%</span> <span class="dv">100</span> <span class="op">==</span> <span class="dv">0</span>:</span>
<span id="cb12-111"><a href="#cb12-111" aria-hidden="true" tabindex="-1"></a>                <span class="bu">print</span>(<span class="ss">f'Epoch </span><span class="sc">{</span>epoch<span class="sc">}</span><span class="ss">, Iter </span><span class="sc">{</span>it<span class="sc">}</span><span class="ss">/</span><span class="sc">{</span>num_batches<span class="sc">}</span><span class="ss">, Loss: </span><span class="sc">{</span>loss<span class="sc">.</span>item()<span class="sc">:.4f}</span><span class="ss">, LR: </span><span class="sc">{</span>lr<span class="sc">:.6f}</span><span class="ss">'</span>)</span>
<span id="cb12-112"><a href="#cb12-112" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-113"><a href="#cb12-113" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> total_loss <span class="op">/</span> num_batches</span>
<span id="cb12-114"><a href="#cb12-114" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-115"><a href="#cb12-115" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> train(<span class="va">self</span>, dataloader):</span>
<span id="cb12-116"><a href="#cb12-116" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Full training loop"""</span></span>
<span id="cb12-117"><a href="#cb12-117" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(<span class="va">self</span>.config[<span class="st">'epochs'</span>]):</span>
<span id="cb12-118"><a href="#cb12-118" aria-hidden="true" tabindex="-1"></a>            avg_loss <span class="op">=</span> <span class="va">self</span>.train_epoch(dataloader, epoch)</span>
<span id="cb12-119"><a href="#cb12-119" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f'Epoch </span><span class="sc">{</span>epoch<span class="sc">}</span><span class="ss">/</span><span class="sc">{</span><span class="va">self</span><span class="sc">.</span>config[<span class="st">"epochs"</span>]<span class="sc">}</span><span class="ss">, Average Loss: </span><span class="sc">{</span>avg_loss<span class="sc">:.4f}</span><span class="ss">'</span>)</span>
<span id="cb12-120"><a href="#cb12-120" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb12-121"><a href="#cb12-121" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Save a checkpoint after every save_every-th epoch (this includes the final epoch when epochs is a multiple of save_every)</span></span>
<span id="cb12-122"><a href="#cb12-122" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> (epoch <span class="op">+</span> <span class="dv">1</span>) <span class="op">%</span> <span class="va">self</span>.config[<span class="st">'save_every'</span>] <span class="op">==</span> <span class="dv">0</span>:</span>
<span id="cb12-123"><a href="#cb12-123" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.save_checkpoint(epoch)</span>
<span id="cb12-124"><a href="#cb12-124" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-125"><a href="#cb12-125" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> save_checkpoint(<span class="va">self</span>, epoch):</span>
<span id="cb12-126"><a href="#cb12-126" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Save model checkpoint"""</span></span>
<span id="cb12-127"><a href="#cb12-127" aria-hidden="true" tabindex="-1"></a>        checkpoint <span class="op">=</span> {</span>
<span id="cb12-128"><a href="#cb12-128" aria-hidden="true" tabindex="-1"></a>            <span class="st">'epoch'</span>: epoch,</span>
<span id="cb12-129"><a href="#cb12-129" aria-hidden="true" tabindex="-1"></a>            <span class="st">'student_state_dict'</span>: <span class="va">self</span>.student.state_dict(),</span>
<span id="cb12-130"><a href="#cb12-130" aria-hidden="true" tabindex="-1"></a>            <span class="st">'teacher_state_dict'</span>: <span class="va">self</span>.teacher.state_dict(),</span>
<span id="cb12-131"><a href="#cb12-131" aria-hidden="true" tabindex="-1"></a>            <span class="st">'optimizer_state_dict'</span>: <span class="va">self</span>.optimizer.state_dict(),</span>
<span id="cb12-132"><a href="#cb12-132" aria-hidden="true" tabindex="-1"></a>            <span class="st">'config'</span>: <span class="va">self</span>.config</span>
<span id="cb12-133"><a href="#cb12-133" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb12-134"><a href="#cb12-134" aria-hidden="true" tabindex="-1"></a>        torch.save(checkpoint, <span class="ss">f'dinov2_checkpoint_epoch_</span><span class="sc">{</span>epoch<span class="sc">}</span><span class="ss">.pth'</span>)</span></code></pre></div></div>
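<p>With multi-crop augmentation, global and local crops have different spatial sizes, so they cannot all be concatenated into one tensor for a single forward pass. The official DINO repository handles this with a wrapper (<code>MultiCropWrapper</code>) that runs the backbone once per resolution. The class below is a minimal, hypothetical sketch of that idea, not part of this post’s implementation: <code>MultiResolutionForward</code> and the toy backbone are illustrative names, and the backbone is assumed to accept every crop size (here via adaptive pooling).</p>

```python
import torch
import torch.nn as nn


class MultiResolutionForward(nn.Module):
    """Hypothetical wrapper: one backbone pass per group of same-sized crops."""

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone

    def forward(self, crops):
        # crops: list of (B, C, H, W) tensors, e.g. 2 global crops then N local crops.
        # Group consecutive crops that share the same spatial size.
        groups = []
        for crop in crops:
            if groups and groups[-1][0].shape[-2:] == crop.shape[-2:]:
                groups[-1].append(crop)
            else:
                groups.append([crop])
        # Batch each group through the backbone, keeping the original crop order.
        return torch.cat([self.backbone(torch.cat(group)) for group in groups])


# Toy backbone that accepts any input size thanks to adaptive pooling.
toy_backbone = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
)
wrapper = MultiResolutionForward(toy_backbone)

batch_size = 2
crops = [torch.randn(batch_size, 3, 224, 224) for _ in range(2)]  # global crops
crops += [torch.randn(batch_size, 3, 96, 96) for _ in range(6)]   # local crops
student_output = wrapper(crops)
print(student_output.shape)  # torch.Size([16, 8])
```

<p>Because the outputs are concatenated in crop order, the first <code>2 * batch_size</code> rows belong to the global crops, which is the layout the official DINO loss assumes when it splits the student output into one chunk per crop.</p>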
</section>
<section id="usage-example" class="level2">
<h2 class="anchored" data-anchor-id="usage-example" id="usage-example">Usage Example</h2>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb13"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> main():</span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Training configuration</span></span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a>    config <span class="op">=</span> {</span>
<span id="cb13-4"><a href="#cb13-4" aria-hidden="true" tabindex="-1"></a>        <span class="st">'base_lr'</span>: <span class="fl">5e-4</span>,</span>
<span id="cb13-5"><a href="#cb13-5" aria-hidden="true" tabindex="-1"></a>        <span class="st">'final_lr'</span>: <span class="fl">1e-6</span>,</span>
<span id="cb13-6"><a href="#cb13-6" aria-hidden="true" tabindex="-1"></a>        <span class="st">'weight_decay'</span>: <span class="fl">0.04</span>,</span>
<span id="cb13-7"><a href="#cb13-7" aria-hidden="true" tabindex="-1"></a>        <span class="st">'momentum_teacher'</span>: <span class="fl">0.996</span>,</span>
<span id="cb13-8"><a href="#cb13-8" aria-hidden="true" tabindex="-1"></a>        <span class="st">'epochs'</span>: <span class="dv">100</span>,</span>
<span id="cb13-9"><a href="#cb13-9" aria-hidden="true" tabindex="-1"></a>        <span class="st">'warmup_epochs'</span>: <span class="dv">10</span>,</span>
<span id="cb13-10"><a href="#cb13-10" aria-hidden="true" tabindex="-1"></a>        <span class="st">'batch_size'</span>: <span class="dv">64</span>,</span>
<span id="cb13-11"><a href="#cb13-11" aria-hidden="true" tabindex="-1"></a>        <span class="st">'save_every'</span>: <span class="dv">10</span>,</span>
<span id="cb13-12"><a href="#cb13-12" aria-hidden="true" tabindex="-1"></a>        <span class="st">'niter_per_ep'</span>: <span class="va">None</span>  <span class="co"># Will be set after dataloader creation</span></span>
<span id="cb13-13"><a href="#cb13-13" aria-hidden="true" tabindex="-1"></a>    }</span>
<span id="cb13-14"><a href="#cb13-14" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-15"><a href="#cb13-15" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Data setup</span></span>
<span id="cb13-16"><a href="#cb13-16" aria-hidden="true" tabindex="-1"></a>    transform <span class="op">=</span> MultiCropDataAugmentation()</span>
<span id="cb13-17"><a href="#cb13-17" aria-hidden="true" tabindex="-1"></a>    dataset <span class="op">=</span> ImageFolder(root<span class="op">=</span><span class="st">'path/to/your/dataset'</span>, transform<span class="op">=</span>transform)</span>
<span id="cb13-18"><a href="#cb13-18" aria-hidden="true" tabindex="-1"></a>    dataloader <span class="op">=</span> DataLoader(</span>
<span id="cb13-19"><a href="#cb13-19" aria-hidden="true" tabindex="-1"></a>        dataset, </span>
<span id="cb13-20"><a href="#cb13-20" aria-hidden="true" tabindex="-1"></a>        batch_size<span class="op">=</span>config[<span class="st">'batch_size'</span>], </span>
<span id="cb13-21"><a href="#cb13-21" aria-hidden="true" tabindex="-1"></a>        shuffle<span class="op">=</span><span class="va">True</span>, </span>
<span id="cb13-22"><a href="#cb13-22" aria-hidden="true" tabindex="-1"></a>        num_workers<span class="op">=</span><span class="dv">4</span>,</span>
<span id="cb13-23"><a href="#cb13-23" aria-hidden="true" tabindex="-1"></a>        pin_memory<span class="op">=</span><span class="va">True</span>,</span>
<span id="cb13-24"><a href="#cb13-24" aria-hidden="true" tabindex="-1"></a>        drop_last<span class="op">=</span><span class="va">True</span></span>
<span id="cb13-25"><a href="#cb13-25" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb13-26"><a href="#cb13-26" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-27"><a href="#cb13-27" aria-hidden="true" tabindex="-1"></a>    config[<span class="st">'niter_per_ep'</span>] <span class="op">=</span> <span class="bu">len</span>(dataloader)</span>
<span id="cb13-28"><a href="#cb13-28" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-29"><a href="#cb13-29" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Initialize trainer and start training</span></span>
<span id="cb13-30"><a href="#cb13-30" aria-hidden="true" tabindex="-1"></a>    trainer <span class="op">=</span> DINOv2Trainer(config)</span>
<span id="cb13-31"><a href="#cb13-31" aria-hidden="true" tabindex="-1"></a>    trainer.train(dataloader)</span>
<span id="cb13-32"><a href="#cb13-32" aria-hidden="true" tabindex="-1"></a></span>
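<span id="cb13-33"><a href="#cb13-33" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">'__main__'</span>:</span>
<span id="cb13-34"><a href="#cb13-34" aria-hidden="true" tabindex="-1"></a>    main()</span></code></pre></div></div>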
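<p><code>train_epoch</code> relies on two properties of this data setup: PyTorch’s default collate function turns each sample’s list of crops into a list of batched tensors (one per crop), and <code>len(dataloader)</code> is the number of iterations per epoch used to index the schedules. The snippet below checks both with a hypothetical <code>ToyMultiCropDataset</code> standing in for <code>ImageFolder</code> plus <code>MultiCropDataAugmentation</code>; the crop counts and sizes are assumptions for illustration.</p>

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyMultiCropDataset(Dataset):
    """Hypothetical stand-in: each sample is (list_of_crops, label)."""

    def __init__(self, length=8, n_global=2, n_local=4):
        self.length, self.n_global, self.n_local = length, n_global, n_local

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # 2 global crops at 224x224 followed by 4 local crops at 96x96.
        crops = [torch.randn(3, 224, 224) for _ in range(self.n_global)]
        crops += [torch.randn(3, 96, 96) for _ in range(self.n_local)]
        return crops, 0


loader = DataLoader(ToyMultiCropDataset(), batch_size=4, drop_last=True)
images, labels = next(iter(loader))
# images is a list with one (batch, C, H, W) tensor per crop position.
print(len(images), images[0].shape, images[-1].shape, len(loader))
# 6 torch.Size([4, 3, 224, 224]) torch.Size([4, 3, 96, 96]) 2
```

<p>With <code>drop_last=True</code> every batch also has the same size, so each crop tensor in the list has a fixed batch dimension throughout the epoch.</p>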
</section>
<section id="key-features-implemented" class="level2">
<h2 class="anchored" data-anchor-id="key-features-implemented" id="key-features-implemented">Key Features Implemented</h2>
<ol type="1">
<li><strong>Vision Transformer Backbone</strong>: Complete ViT implementation with patch embedding, multi-head attention, and transformer blocks</li>
<li><strong>Multi-crop Strategy</strong>: Global and local crops with different augmentations</li>
<li><strong>Teacher-Student Framework</strong>: The teacher is updated as an exponential moving average (EMA) of the student</li>
<li><strong>DINO Loss</strong>: Cross-entropy between the teacher’s and the student’s output distributions, with a centering mechanism to prevent collapse</li>
<li><strong>Learning Rate Scheduling</strong>: Cosine annealing with warmup</li>
<li><strong>Gradient Clipping</strong>: Gradient norm clipped to 3.0 for training stability</li>
<li><strong>Checkpointing</strong>: Periodic saving of student, teacher, and optimizer states</li>
</ol>
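<p>The learning-rate and teacher-momentum schedules in the trainer are per-iteration arrays indexed by <code>epoch * num_batches + it</code>. The function below is a hypothetical re-sketch that takes the same argument order the trainer passes to <code>cosine_scheduler</code> (base value, final value, epochs, iterations per epoch, optional warmup epochs); it follows the structure of the equivalent helper in the official DINO code but is not necessarily identical to the one defined earlier in this post.</p>

```python
import numpy as np


def cosine_schedule_sketch(base_value, final_value, epochs, niter_per_ep,
                           warmup_epochs=0, start_warmup_value=0.0):
    """Hypothetical per-iteration schedule: linear warmup, then cosine annealing."""
    warmup_iters = warmup_epochs * niter_per_ep
    # Linear warmup from start_warmup_value to base_value (empty if no warmup).
    warmup = np.linspace(start_warmup_value, base_value, warmup_iters)
    # Cosine interpolation from base_value to final_value over the remaining iterations.
    iters = np.arange(epochs * niter_per_ep - warmup_iters)
    cosine = final_value + 0.5 * (base_value - final_value) * (
        1 + np.cos(np.pi * iters / len(iters))
    )
    return np.concatenate((warmup, cosine))


# Values from the usage example, with a made-up 10 iterations per epoch.
lr_schedule = cosine_schedule_sketch(5e-4, 1e-6, epochs=100, niter_per_ep=10, warmup_epochs=10)
momentum_schedule = cosine_schedule_sketch(0.996, 1.0, epochs=100, niter_per_ep=10)

print(len(lr_schedule), lr_schedule[0], lr_schedule.max())
print(momentum_schedule[0], momentum_schedule[-1])
```

<p>Because the final value of the momentum schedule (1.0) is larger than its base value (0.996), the same cosine formula makes the momentum rise during training, so the teacher changes less and less as training progresses.</p>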
</section>
<section id="training-tips" class="level2">
<h2 class="anchored" data-anchor-id="training-tips" id="training-tips">Training Tips</h2>
<ol type="1">
<li><strong>Batch Size</strong>: Larger batches (256 to 1024) generally work better than the 64 in the example above; scale the learning rate with the batch size</li>
<li><strong>Data Augmentation</strong>: Strong augmentations are crucial for self-supervised learning</li>
<li><strong>Temperature Scheduling</strong>: Warm up the teacher temperature, starting low and increasing it linearly over the first epochs</li>
<li><strong>Momentum Scheduling</strong>: Start with a high teacher momentum (e.g., 0.996) and increase it toward 1.0 over training, as the cosine momentum schedule above does</li>
<li><strong>Multi-GPU Training</strong>: Use DistributedDataParallel for faster training</li>
</ol>
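<p>The temperature tip can be illustrated with a per-epoch teacher-temperature schedule: a linear warmup followed by a constant value, built similarly to the schedule in the official DINO loss. The warmup from 0.04 to 0.07 over 30 epochs is the setting described in the DINO paper and is used here only as an illustration; <code>teacher_temp_schedule</code> is a hypothetical helper. The second half shows why a low teacher temperature matters: it sharpens the teacher’s softmax.</p>

```python
import numpy as np
import torch
import torch.nn.functional as F


def teacher_temp_schedule(warmup_temp, final_temp, warmup_epochs, epochs):
    """Hypothetical per-epoch schedule: linear warmup, then constant."""
    return np.concatenate((
        np.linspace(warmup_temp, final_temp, warmup_epochs),
        np.full(epochs - warmup_epochs, final_temp),
    ))


temps = teacher_temp_schedule(0.04, 0.07, warmup_epochs=30, epochs=100)

# A lower temperature gives a sharper (more peaked) teacher distribution.
logits = torch.tensor([0.2, 0.1, 0.05])
sharp = F.softmax(logits / float(temps[0]), dim=0)    # epoch 0, temperature 0.04
softer = F.softmax(logits / float(temps[-1]), dim=0)  # after warmup, temperature 0.07
print(temps[0], temps[29], temps[-1])
print(sharp.max().item(), softer.max().item())
```

<p>With these toy logits the top-class probability is about 0.90 at temperature 0.04 and about 0.74 at 0.07, so the warmup starts from a very confident teacher and gradually relaxes it.</p>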
<p>This implementation provides a solid foundation for training DINOv2 models. Adjust hyperparameters based on your dataset size and computational resources.</p>


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Student-Teacher Network Training Guide in PyTorch]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/models/student-teacher-vanilla/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/models/student-teacher-vanilla/</guid>
      <pubDate>Wed, 28 May 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>advanced</category>
      <content:encoded><![CDATA[






<section id="student-teacher-network-training-guide-in-pytorch" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/models/student-teacher-vanilla/stutea.jpg" class="img-fluid"></p>
<section id="overview" class="level2">
<h2 class="anchored" data-anchor-id="overview" id="overview">Overview</h2>
<p>Student-teacher training, commonly known as knowledge distillation, trains a smaller “student” model to mimic the behavior of a larger, pre-trained “teacher” model. The technique compresses a large model into a cheaper one while retaining much of its accuracy.</p>
</section>
<section id="key-concepts" class="level2">
<h2 class="anchored" data-anchor-id="key-concepts" id="key-concepts">Key Concepts</h2>
<section id="knowledge-distillation-loss" class="level3">
<h3 class="anchored" data-anchor-id="knowledge-distillation-loss" id="knowledge-distillation-loss">Knowledge Distillation Loss</h3>
<p>The student learns from both:</p>
<ol type="1">
<li><strong>Hard targets</strong>: Original ground truth labels</li>
<li><strong>Soft targets</strong>: Teacher’s probability distributions (softened with temperature)</li>
</ol>
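<p>One common way to combine the two targets is the weighted loss from Hinton et al.’s knowledge distillation work: a KL-divergence term between temperature-softened student and teacher distributions, scaled by <code>T**2</code> so its gradients stay on a comparable scale as the temperature changes, plus ordinary cross-entropy on the hard labels. The function below is a minimal sketch of that formulation; <code>distillation_loss</code>, <code>temperature=4.0</code>, and <code>alpha=0.7</code> are illustrative choices, not necessarily the values used later in this guide.</p>

```python
import torch
import torch.nn.functional as F


def distillation_loss(student_logits, teacher_logits, labels, temperature=4.0, alpha=0.7):
    """Weighted sum of a soft-target KL term and a hard-target cross-entropy term."""
    # Soft targets: compare temperature-softened distributions.
    soft_student = F.log_softmax(student_logits / temperature, dim=1)
    soft_teacher = F.softmax(teacher_logits / temperature, dim=1)
    # kl_div expects log-probabilities as input and probabilities as target;
    # the temperature ** 2 factor keeps the soft-term gradients on a comparable scale.
    soft_loss = F.kl_div(soft_student, soft_teacher, reduction="batchmean") * temperature ** 2
    # Hard targets: standard cross-entropy against the ground-truth labels.
    hard_loss = F.cross_entropy(student_logits, labels)
    return alpha * soft_loss + (1 - alpha) * hard_loss


# Toy batch: 4 samples, 10 classes.
torch.manual_seed(0)
student_logits = torch.randn(4, 10)
teacher_logits = torch.randn(4, 10)
labels = torch.tensor([3, 1, 0, 7])
loss = distillation_loss(student_logits, teacher_logits, labels)
print(loss.item())
```

<p>If the student’s logits already match the teacher’s, the KL term vanishes and only the weighted cross-entropy on the true labels remains.</p>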
</section>
<section id="temperature-scaling" class="level3">
<h3 class="anchored" data-anchor-id="temperature-scaling" id="temperature-scaling">Temperature Scaling</h3>
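<p>Dividing the logits by a temperature before the softmax controls how peaked the resulting distribution is. Higher temperatures produce softer distributions that expose how the teacher ranks the incorrect classes, giving the student a richer training signal than one-hot labels alone.</p>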
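<p>A short sketch with made-up logits shows the effect of dividing by the temperature <code>T</code> before the softmax: as <code>T</code> grows, probability mass moves from the top class to the others, revealing which wrong classes the teacher considers plausible.</p>

```python
import torch
import torch.nn.functional as F

# Made-up teacher logits for a 3-class example.
logits = torch.tensor([4.0, 1.0, 0.5])

for T in (1.0, 4.0):
    probs = F.softmax(logits / T, dim=0)
    print(f"T={T}: {[round(p, 3) for p in probs.tolist()]}")
```

<p>Here the top class’s probability falls from roughly 0.93 at <code>T=1</code> to roughly 0.53 at <code>T=4</code>, while the two remaining classes keep their relative ordering.</p>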
</section>
</section>
<section id="complete-implementation" class="level2">
<h2 class="anchored" data-anchor-id="complete-implementation" id="complete-implementation">Complete Implementation</h2>
<section id="import-libraries" class="level3">
<h3 class="anchored" data-anchor-id="import-libraries" id="import-libraries">Import Libraries</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn.functional <span class="im">as</span> F</span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.optim <span class="im">as</span> optim</span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.utils.data <span class="im">import</span> DataLoader</span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torchvision</span>
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torchvision.transforms <span class="im">as</span> transforms</span>
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> tqdm <span class="im">import</span> tqdm</span>
<span id="cb1-9"><a href="#cb1-9" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span></code></pre></div></div>
</section>
<section id="set-device" class="level3">
<h3 class="anchored" data-anchor-id="set-device" id="set-device">Set Device</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Set device</span></span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a>device <span class="op">=</span> torch.device(<span class="st">'cuda'</span> <span class="cf">if</span> torch.cuda.is_available() <span class="cf">else</span> <span class="st">'cpu'</span>)</span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Using device: </span><span class="sc">{</span>device<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
</section>
<section id="define-teacher-model" class="level3">
<h3 class="anchored" data-anchor-id="define-teacher-model" id="define-teacher-model">Define Teacher Model</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> TeacherNetwork(nn.Module):</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Larger teacher network: a VGG-style CNN (about 1.3M parameters, far smaller than ResNet-50)"""</span></span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, num_classes<span class="op">=</span><span class="dv">10</span>):</span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>(TeacherNetwork, <span class="va">self</span>).<span class="fu">__init__</span>()</span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.features <span class="op">=</span> nn.Sequential(</span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a>            nn.Conv2d(<span class="dv">3</span>, <span class="dv">64</span>, <span class="dv">3</span>, padding<span class="op">=</span><span class="dv">1</span>),</span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a>            nn.BatchNorm2d(<span class="dv">64</span>),</span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(inplace<span class="op">=</span><span class="va">True</span>),</span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a>            nn.Conv2d(<span class="dv">64</span>, <span class="dv">64</span>, <span class="dv">3</span>, padding<span class="op">=</span><span class="dv">1</span>),</span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a>            nn.BatchNorm2d(<span class="dv">64</span>),</span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(inplace<span class="op">=</span><span class="va">True</span>),</span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a>            nn.MaxPool2d(<span class="dv">2</span>, <span class="dv">2</span>),</span>
<span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a>            nn.Conv2d(<span class="dv">64</span>, <span class="dv">128</span>, <span class="dv">3</span>, padding<span class="op">=</span><span class="dv">1</span>),</span>
<span id="cb3-15"><a href="#cb3-15" aria-hidden="true" tabindex="-1"></a>            nn.BatchNorm2d(<span class="dv">128</span>),</span>
<span id="cb3-16"><a href="#cb3-16" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(inplace<span class="op">=</span><span class="va">True</span>),</span>
<span id="cb3-17"><a href="#cb3-17" aria-hidden="true" tabindex="-1"></a>            nn.Conv2d(<span class="dv">128</span>, <span class="dv">128</span>, <span class="dv">3</span>, padding<span class="op">=</span><span class="dv">1</span>),</span>
<span id="cb3-18"><a href="#cb3-18" aria-hidden="true" tabindex="-1"></a>            nn.BatchNorm2d(<span class="dv">128</span>),</span>
<span id="cb3-19"><a href="#cb3-19" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(inplace<span class="op">=</span><span class="va">True</span>),</span>
<span id="cb3-20"><a href="#cb3-20" aria-hidden="true" tabindex="-1"></a>            nn.MaxPool2d(<span class="dv">2</span>, <span class="dv">2</span>),</span>
<span id="cb3-21"><a href="#cb3-21" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb3-22"><a href="#cb3-22" aria-hidden="true" tabindex="-1"></a>            nn.Conv2d(<span class="dv">128</span>, <span class="dv">256</span>, <span class="dv">3</span>, padding<span class="op">=</span><span class="dv">1</span>),</span>
<span id="cb3-23"><a href="#cb3-23" aria-hidden="true" tabindex="-1"></a>            nn.BatchNorm2d(<span class="dv">256</span>),</span>
<span id="cb3-24"><a href="#cb3-24" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(inplace<span class="op">=</span><span class="va">True</span>),</span>
<span id="cb3-25"><a href="#cb3-25" aria-hidden="true" tabindex="-1"></a>            nn.Conv2d(<span class="dv">256</span>, <span class="dv">256</span>, <span class="dv">3</span>, padding<span class="op">=</span><span class="dv">1</span>),</span>
<span id="cb3-26"><a href="#cb3-26" aria-hidden="true" tabindex="-1"></a>            nn.BatchNorm2d(<span class="dv">256</span>),</span>
<span id="cb3-27"><a href="#cb3-27" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(inplace<span class="op">=</span><span class="va">True</span>),</span>
<span id="cb3-28"><a href="#cb3-28" aria-hidden="true" tabindex="-1"></a>            nn.MaxPool2d(<span class="dv">2</span>, <span class="dv">2</span>),</span>
<span id="cb3-29"><a href="#cb3-29" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb3-30"><a href="#cb3-30" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-31"><a href="#cb3-31" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.classifier <span class="op">=</span> nn.Sequential(</span>
<span id="cb3-32"><a href="#cb3-32" aria-hidden="true" tabindex="-1"></a>            nn.AdaptiveAvgPool2d((<span class="dv">1</span>, <span class="dv">1</span>)),</span>
<span id="cb3-33"><a href="#cb3-33" aria-hidden="true" tabindex="-1"></a>            nn.Flatten(),</span>
<span id="cb3-34"><a href="#cb3-34" aria-hidden="true" tabindex="-1"></a>            nn.Linear(<span class="dv">256</span>, <span class="dv">512</span>),</span>
<span id="cb3-35"><a href="#cb3-35" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(inplace<span class="op">=</span><span class="va">True</span>),</span>
<span id="cb3-36"><a href="#cb3-36" aria-hidden="true" tabindex="-1"></a>            nn.Dropout(<span class="fl">0.5</span>),</span>
<span id="cb3-37"><a href="#cb3-37" aria-hidden="true" tabindex="-1"></a>            nn.Linear(<span class="dv">512</span>, num_classes)</span>
<span id="cb3-38"><a href="#cb3-38" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb3-39"><a href="#cb3-39" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-40"><a href="#cb3-40" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb3-41"><a href="#cb3-41" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.features(x)</span>
<span id="cb3-42"><a href="#cb3-42" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.classifier(x)</span>
<span id="cb3-43"><a href="#cb3-43" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x</span></code></pre></div></div>
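<p>To see how much smaller the student is, it helps to count trainable parameters. The helper below is a minimal sketch (<code>count_parameters</code> is a hypothetical name); it counts only tensors with <code>requires_grad=True</code>, so BatchNorm running statistics, which are registered as buffers, are excluded.</p>

```python
import torch.nn as nn


def count_parameters(model):
    """Number of trainable parameters (weights and biases, not buffers)."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# A 256-to-512 linear layer has 256 * 512 weights plus 512 biases.
print(count_parameters(nn.Linear(256, 512)))  # 131584
# BatchNorm2d(64) has 64 weights plus 64 biases; running stats are buffers.
print(count_parameters(nn.BatchNorm2d(64)))  # 128
```

<p>Applied to <code>TeacherNetwork(num_classes=10)</code>, this comes to about 1.28M trainable parameters; the <code>StudentNetwork</code> defined next has roughly 0.1M, about 12 times fewer.</p>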
</section>
<section id="define-student-model" class="level3">
<h3 class="anchored" data-anchor-id="define-student-model" id="define-student-model">Define Student Model</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> StudentNetwork(nn.Module):</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Smaller student network"""</span></span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, num_classes<span class="op">=</span><span class="dv">10</span>):</span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>(StudentNetwork, <span class="va">self</span>).<span class="fu">__init__</span>()</span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.features <span class="op">=</span> nn.Sequential(</span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a>            nn.Conv2d(<span class="dv">3</span>, <span class="dv">32</span>, <span class="dv">3</span>, padding<span class="op">=</span><span class="dv">1</span>),</span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a>            nn.BatchNorm2d(<span class="dv">32</span>),</span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(inplace<span class="op">=</span><span class="va">True</span>),</span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a>            nn.MaxPool2d(<span class="dv">2</span>, <span class="dv">2</span>),</span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a>            nn.Conv2d(<span class="dv">32</span>, <span class="dv">64</span>, <span class="dv">3</span>, padding<span class="op">=</span><span class="dv">1</span>),</span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a>            nn.BatchNorm2d(<span class="dv">64</span>),</span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(inplace<span class="op">=</span><span class="va">True</span>),</span>
<span id="cb4-14"><a href="#cb4-14" aria-hidden="true" tabindex="-1"></a>            nn.MaxPool2d(<span class="dv">2</span>, <span class="dv">2</span>),</span>
<span id="cb4-15"><a href="#cb4-15" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb4-16"><a href="#cb4-16" aria-hidden="true" tabindex="-1"></a>            nn.Conv2d(<span class="dv">64</span>, <span class="dv">128</span>, <span class="dv">3</span>, padding<span class="op">=</span><span class="dv">1</span>),</span>
<span id="cb4-17"><a href="#cb4-17" aria-hidden="true" tabindex="-1"></a>            nn.BatchNorm2d(<span class="dv">128</span>),</span>
<span id="cb4-18"><a href="#cb4-18" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(inplace<span class="op">=</span><span class="va">True</span>),</span>
<span id="cb4-19"><a href="#cb4-19" aria-hidden="true" tabindex="-1"></a>            nn.MaxPool2d(<span class="dv">2</span>, <span class="dv">2</span>),</span>
<span id="cb4-20"><a href="#cb4-20" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb4-21"><a href="#cb4-21" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-22"><a href="#cb4-22" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.classifier <span class="op">=</span> nn.Sequential(</span>
<span id="cb4-23"><a href="#cb4-23" aria-hidden="true" tabindex="-1"></a>            nn.AdaptiveAvgPool2d((<span class="dv">1</span>, <span class="dv">1</span>)),</span>
<span id="cb4-24"><a href="#cb4-24" aria-hidden="true" tabindex="-1"></a>            nn.Flatten(),</span>
<span id="cb4-25"><a href="#cb4-25" aria-hidden="true" tabindex="-1"></a>            nn.Linear(<span class="dv">128</span>, <span class="dv">64</span>),</span>
<span id="cb4-26"><a href="#cb4-26" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(inplace<span class="op">=</span><span class="va">True</span>),</span>
<span id="cb4-27"><a href="#cb4-27" aria-hidden="true" tabindex="-1"></a>            nn.Linear(<span class="dv">64</span>, num_classes)</span>
<span id="cb4-28"><a href="#cb4-28" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb4-29"><a href="#cb4-29" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-30"><a href="#cb4-30" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb4-31"><a href="#cb4-31" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.features(x)</span>
<span id="cb4-32"><a href="#cb4-32" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.classifier(x)</span>
<span id="cb4-33"><a href="#cb4-33" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x</span></code></pre></div></div>
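<p>A quick way to check the student's size is to count its trainable parameters by hand and trace how large its feature maps get. The plain-Python sketch below does both. It assumes PyTorch's defaults (<code>bias=True</code> for <code>Conv2d</code> and <code>Linear</code>, <code>affine=True</code> for <code>BatchNorm2d</code>) and 32x32 CIFAR-10 inputs. The helper names are illustrative and do not come from any library.</p>

```python
def conv2d_params(in_ch, out_ch, kernel, bias=True):
    """Weights (out_ch * in_ch * k * k) plus one bias per output channel."""
    return out_ch * in_ch * kernel * kernel + (out_ch if bias else 0)

def batchnorm2d_params(channels):
    """Learnable affine weight and bias; running mean/var are buffers, not parameters."""
    return 2 * channels

def linear_params(in_features, out_features, bias=True):
    """Weight matrix plus one bias per output feature."""
    return out_features * in_features + (out_features if bias else 0)

def student_param_count(num_classes=10):
    """Trainable parameters of the StudentNetwork above (PyTorch defaults assumed)."""
    features = 0
    for in_ch, out_ch in [(3, 32), (32, 64), (64, 128)]:
        features += conv2d_params(in_ch, out_ch, 3) + batchnorm2d_params(out_ch)
    # ReLU, MaxPool2d, AdaptiveAvgPool2d and Flatten have no parameters
    classifier = linear_params(128, 64) + linear_params(64, num_classes)
    return features + classifier

def feature_map_size(size=32, num_pools=3):
    """3x3 convs with padding=1 keep H and W; each MaxPool2d(2, 2) halves them."""
    for _ in range(num_pools):
        size //= 2
    return size

print(student_param_count(10))
print(feature_map_size(32))
```

<p>With 10 classes the count is about 103K parameters. In PyTorch, <code>sum(p.numel() for p in StudentNetwork(num_classes=10).parameters())</code> should report the same total, assuming the constructor takes <code>num_classes</code> as the classifier suggests. Each padded 3x3 convolution preserves height and width, so only the three pooling layers shrink the map (32 → 16 → 8 → 4). <code>AdaptiveAvgPool2d((1, 1))</code> then reduces the 128x4x4 tensor to the 128 features that the first <code>Linear</code> layer expects.</p>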
</section>
<section id="define-distillation-loss" class="level3">
<h3 class="anchored" data-anchor-id="define-distillation-loss" id="define-distillation-loss">Define Distillation Loss</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> DistillationLoss(nn.Module):</span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a><span class="co">    Knowledge Distillation Loss combining:</span></span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a><span class="co">    1. Cross-entropy loss with true labels</span></span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a><span class="co">    2. KL divergence from the teacher's temperature-softened predictions</span></span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, alpha<span class="op">=</span><span class="fl">0.7</span>, temperature<span class="op">=</span><span class="fl">4.0</span>):</span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>(DistillationLoss, <span class="va">self</span>).<span class="fu">__init__</span>()</span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.alpha <span class="op">=</span> alpha  <span class="co"># Weight on the KL (soft-target) term; (1 - alpha) weights the CE term</span></span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.temperature <span class="op">=</span> temperature</span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.ce_loss <span class="op">=</span> nn.CrossEntropyLoss()</span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.kl_loss <span class="op">=</span> nn.KLDivLoss(reduction<span class="op">=</span><span class="st">'batchmean'</span>)</span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, student_logits, teacher_logits, labels):</span>
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Cross-entropy loss with true labels</span></span>
<span id="cb5-16"><a href="#cb5-16" aria-hidden="true" tabindex="-1"></a>        ce_loss <span class="op">=</span> <span class="va">self</span>.ce_loss(student_logits, labels)</span>
<span id="cb5-17"><a href="#cb5-17" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-18"><a href="#cb5-18" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Soft targets from teacher</span></span>
<span id="cb5-19"><a href="#cb5-19" aria-hidden="true" tabindex="-1"></a>        teacher_probs <span class="op">=</span> F.softmax(teacher_logits <span class="op">/</span> <span class="va">self</span>.temperature, dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb5-20"><a href="#cb5-20" aria-hidden="true" tabindex="-1"></a>        student_log_probs <span class="op">=</span> F.log_softmax(student_logits <span class="op">/</span> <span class="va">self</span>.temperature, dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb5-21"><a href="#cb5-21" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-22"><a href="#cb5-22" aria-hidden="true" tabindex="-1"></a>        <span class="co"># KL divergence, scaled by T^2 so its gradients stay comparable in size to the CE term</span></span>
<span id="cb5-23"><a href="#cb5-23" aria-hidden="true" tabindex="-1"></a>        kl_loss <span class="op">=</span> <span class="va">self</span>.kl_loss(student_log_probs, teacher_probs) <span class="op">*</span> (<span class="va">self</span>.temperature <span class="op">**</span> <span class="dv">2</span>)</span>
<span id="cb5-24"><a href="#cb5-24" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-25"><a href="#cb5-25" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Combined loss</span></span>
<span id="cb5-26"><a href="#cb5-26" aria-hidden="true" tabindex="-1"></a>        total_loss <span class="op">=</span> (<span class="dv">1</span> <span class="op">-</span> <span class="va">self</span>.alpha) <span class="op">*</span> ce_loss <span class="op">+</span> <span class="va">self</span>.alpha <span class="op">*</span> kl_loss</span>
<span id="cb5-27"><a href="#cb5-27" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-28"><a href="#cb5-28" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> total_loss, ce_loss, kl_loss</span></code></pre></div></div>
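<p>The plain-Python sketch below shows what the loss computes. It mirrors <code>DistillationLoss.forward</code> on small lists of logits. Cross-entropy is averaged over the batch, as in <code>nn.CrossEntropyLoss</code>. The KL term follows <code>nn.KLDivLoss(reduction='batchmean')</code>: given the student's log-probabilities and the teacher's probabilities, it computes KL(teacher || student) summed over classes and divided by the batch size. The function names and example logits are illustrative.</p>

```python
import math

def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over one list of logits (numerically stable)."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(p, q):
    """KL(p || q) = sum_i p_i * (log p_i - log q_i); zero-probability terms contribute 0."""
    return sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q) if pi > 0)

def distillation_loss(student_batch, teacher_batch, labels, alpha=0.7, temperature=4.0):
    """Plain-Python mirror of DistillationLoss.forward for a batch of logit lists."""
    n = len(labels)
    # Hard-label cross-entropy at T = 1, averaged over the batch
    ce = -sum(math.log(softmax(s)[y]) for s, y in zip(student_batch, labels)) / n
    # Soft-target KL(teacher || student) at temperature T, 'batchmean', scaled by T^2
    kl = sum(
        kl_divergence(softmax(t, temperature), softmax(s, temperature))
        for s, t in zip(student_batch, teacher_batch)
    ) / n * temperature ** 2
    total = (1 - alpha) * ce + alpha * kl
    return total, ce, kl

teacher_logits = [4.0, 1.0, 0.5]
print(softmax(teacher_logits, temperature=1.0))
print(softmax(teacher_logits, temperature=4.0))
print(distillation_loss([[1.0, 0.5, 0.2]], [teacher_logits], labels=[0]))
```

<p>For the teacher logits <code>[4.0, 1.0, 0.5]</code>, the top-class probability drops from about 0.93 at T = 1 to about 0.53 at T = 4. The softened targets therefore carry information about how the teacher ranks the incorrect classes. When the student and teacher logits are identical, the KL term is zero and the loss reduces to (1 - alpha) times the cross-entropy.</p>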
</section>
<section id="load-and-preprocess-data" class="level3">
<h3 class="anchored" data-anchor-id="load-and-preprocess-data" id="load-and-preprocess-data">Load and Preprocess Data</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> load_data(batch_size<span class="op">=</span><span class="dv">128</span>):</span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Load CIFAR-10 dataset"""</span></span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a>    transform_train <span class="op">=</span> transforms.Compose([</span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a>        transforms.RandomCrop(<span class="dv">32</span>, padding<span class="op">=</span><span class="dv">4</span>),</span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a>        transforms.RandomHorizontalFlip(),</span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a>        transforms.ToTensor(),</span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a>        transforms.Normalize((<span class="fl">0.4914</span>, <span class="fl">0.4822</span>, <span class="fl">0.4465</span>), (<span class="fl">0.2023</span>, <span class="fl">0.1994</span>, <span class="fl">0.2010</span>))</span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a>    ])</span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a>    transform_test <span class="op">=</span> transforms.Compose([</span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a>        transforms.ToTensor(),</span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a>        transforms.Normalize((<span class="fl">0.4914</span>, <span class="fl">0.4822</span>, <span class="fl">0.4465</span>), (<span class="fl">0.2023</span>, <span class="fl">0.1994</span>, <span class="fl">0.2010</span>))</span>
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a>    ])</span>
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-15"><a href="#cb6-15" aria-hidden="true" tabindex="-1"></a>    train_dataset <span class="op">=</span> torchvision.datasets.CIFAR10(</span>
<span id="cb6-16"><a href="#cb6-16" aria-hidden="true" tabindex="-1"></a>        root<span class="op">=</span><span class="st">'./data'</span>, train<span class="op">=</span><span class="va">True</span>, download<span class="op">=</span><span class="va">True</span>, transform<span class="op">=</span>transform_train</span>
<span id="cb6-17"><a href="#cb6-17" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb6-18"><a href="#cb6-18" aria-hidden="true" tabindex="-1"></a>    test_dataset <span class="op">=</span> torchvision.datasets.CIFAR10(</span>
<span id="cb6-19"><a href="#cb6-19" aria-hidden="true" tabindex="-1"></a>        root<span class="op">=</span><span class="st">'./data'</span>, train<span class="op">=</span><span class="va">False</span>, download<span class="op">=</span><span class="va">True</span>, transform<span class="op">=</span>transform_test</span>
<span id="cb6-20"><a href="#cb6-20" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb6-21"><a href="#cb6-21" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-22"><a href="#cb6-22" aria-hidden="true" tabindex="-1"></a>    train_loader <span class="op">=</span> DataLoader(train_dataset, batch_size<span class="op">=</span>batch_size, shuffle<span class="op">=</span><span class="va">True</span>, num_workers<span class="op">=</span><span class="dv">2</span>)</span>
<span id="cb6-23"><a href="#cb6-23" aria-hidden="true" tabindex="-1"></a>    test_loader <span class="op">=</span> DataLoader(test_dataset, batch_size<span class="op">=</span>batch_size, shuffle<span class="op">=</span><span class="va">False</span>, num_workers<span class="op">=</span><span class="dv">2</span>)</span>
<span id="cb6-24"><a href="#cb6-24" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-25"><a href="#cb6-25" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> train_loader, test_loader</span></code></pre></div></div>
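<p>The <code>Normalize</code> constants above are per-channel means and standard deviations of CIFAR-10 pixels, measured after <code>ToTensor()</code> has scaled them to [0, 1]. The plain-Python sketch below shows how such statistics can be computed and what <code>Normalize</code> then does to each pixel. It uses population statistics over every pixel of every image, which is one common convention. The function names and the tiny two-image dataset are illustrative only.</p>

```python
import math

def channel_mean_std(images):
    """Dataset-wide per-channel mean and population std over every pixel.

    `images` is a list of images shaped [channels][height][width] with values
    already scaled to [0, 1], i.e. what transforms.ToTensor() produces.
    """
    num_channels = len(images[0])
    sums = [0.0] * num_channels
    sq_sums = [0.0] * num_channels
    pixels_per_channel = 0
    for img in images:
        pixels_per_channel += len(img[0]) * len(img[0][0])
        for c, channel in enumerate(img):
            for row in channel:
                sums[c] += sum(row)
                sq_sums[c] += sum(v * v for v in row)
    means = [s / pixels_per_channel for s in sums]
    # Var = E[x^2] - E[x]^2; clamp tiny negative values caused by rounding
    stds = [math.sqrt(max(0.0, sq / pixels_per_channel - m * m))
            for sq, m in zip(sq_sums, means)]
    return means, stds

def normalize(pixel, means, stds):
    """Per-channel (x - mean) / std, the operation transforms.Normalize applies."""
    return [(x - m) / s for x, m, s in zip(pixel, means, stds)]

# Two tiny 2-channel images of size 1x2 (illustrative data)
images = [
    [[[0.0, 1.0]], [[0.2, 0.2]]],
    [[[0.0, 1.0]], [[0.6, 0.6]]],
]
means, stds = channel_mean_std(images)
print(means, stds)
print(normalize([1.0, 0.4], means, stds))
```

<p>Over the full 50,000-image training set, this procedure is commonly reported to give means of about (0.4914, 0.4822, 0.4465), which match the block. The same computation gives standard deviations closer to (0.247, 0.243, 0.261) than the widely copied (0.2023, 0.1994, 0.2010). Either set works in practice, provided the training and test transforms use the same constants, as they do here.</p>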
</section>
<section id="train-teacher-model" class="level3">
<h3 class="anchored" data-anchor-id="train-teacher-model" id="train-teacher-model">Train Teacher Model</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> train_teacher(model, train_loader, test_loader, epochs<span class="op">=</span><span class="dv">50</span>, lr<span class="op">=</span><span class="fl">0.001</span>):</span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Train the teacher network"""</span></span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="st">"Training Teacher Network..."</span>)</span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a>    criterion <span class="op">=</span> nn.CrossEntropyLoss()</span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a>    optimizer <span class="op">=</span> optim.Adam(model.parameters(), lr<span class="op">=</span>lr, weight_decay<span class="op">=</span><span class="fl">1e-4</span>)</span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a>    scheduler <span class="op">=</span> optim.lr_scheduler.StepLR(optimizer, step_size<span class="op">=</span><span class="dv">20</span>, gamma<span class="op">=</span><span class="fl">0.1</span>)</span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a>    best_acc <span class="op">=</span> <span class="fl">0.0</span></span>
<span id="cb7-10"><a href="#cb7-10" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-11"><a href="#cb7-11" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(epochs):</span>
<span id="cb7-12"><a href="#cb7-12" aria-hidden="true" tabindex="-1"></a>        model.train()  <span class="co"># Re-enable training mode each epoch (evaluate_model may leave the model in eval mode)</span></span>
<span id="cb7-13"><a href="#cb7-13" aria-hidden="true" tabindex="-1"></a>        running_loss <span class="op">=</span> <span class="fl">0.0</span></span>
<span id="cb7-14"><a href="#cb7-14" aria-hidden="true" tabindex="-1"></a>        correct <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb7-15"><a href="#cb7-15" aria-hidden="true" tabindex="-1"></a>        total <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb7-16"><a href="#cb7-16" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-17"><a href="#cb7-17" aria-hidden="true" tabindex="-1"></a>        progress_bar <span class="op">=</span> tqdm(train_loader, desc<span class="op">=</span><span class="ss">f'Teacher Epoch </span><span class="sc">{</span>epoch<span class="op">+</span><span class="dv">1</span><span class="sc">}</span><span class="ss">/</span><span class="sc">{</span>epochs<span class="sc">}</span><span class="ss">'</span>)</span>
<span id="cb7-18"><a href="#cb7-18" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> batch_idx, (inputs, targets) <span class="kw">in</span> <span class="bu">enumerate</span>(progress_bar):</span>
<span id="cb7-19"><a href="#cb7-19" aria-hidden="true" tabindex="-1"></a>            inputs, targets <span class="op">=</span> inputs.to(device), targets.to(device)</span>
<span id="cb7-20"><a href="#cb7-20" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb7-21"><a href="#cb7-21" aria-hidden="true" tabindex="-1"></a>            optimizer.zero_grad()</span>
<span id="cb7-22"><a href="#cb7-22" aria-hidden="true" tabindex="-1"></a>            outputs <span class="op">=</span> model(inputs)</span>
<span id="cb7-23"><a href="#cb7-23" aria-hidden="true" tabindex="-1"></a>            loss <span class="op">=</span> criterion(outputs, targets)</span>
<span id="cb7-24"><a href="#cb7-24" aria-hidden="true" tabindex="-1"></a>            loss.backward()</span>
<span id="cb7-25"><a href="#cb7-25" aria-hidden="true" tabindex="-1"></a>            optimizer.step()</span>
<span id="cb7-26"><a href="#cb7-26" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb7-27"><a href="#cb7-27" aria-hidden="true" tabindex="-1"></a>            running_loss <span class="op">+=</span> loss.item()</span>
<span id="cb7-28"><a href="#cb7-28" aria-hidden="true" tabindex="-1"></a>            _, predicted <span class="op">=</span> outputs.<span class="bu">max</span>(<span class="dv">1</span>)</span>
<span id="cb7-29"><a href="#cb7-29" aria-hidden="true" tabindex="-1"></a>            total <span class="op">+=</span> targets.size(<span class="dv">0</span>)</span>
<span id="cb7-30"><a href="#cb7-30" aria-hidden="true" tabindex="-1"></a>            correct <span class="op">+=</span> predicted.eq(targets).<span class="bu">sum</span>().item()</span>
<span id="cb7-31"><a href="#cb7-31" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb7-32"><a href="#cb7-32" aria-hidden="true" tabindex="-1"></a>            progress_bar.set_postfix({</span>
<span id="cb7-33"><a href="#cb7-33" aria-hidden="true" tabindex="-1"></a>                <span class="st">'Loss'</span>: <span class="ss">f'</span><span class="sc">{</span>running_loss<span class="op">/</span>(batch_idx<span class="op">+</span><span class="dv">1</span>)<span class="sc">:.4f}</span><span class="ss">'</span>,</span>
<span id="cb7-34"><a href="#cb7-34" aria-hidden="true" tabindex="-1"></a>                <span class="st">'Acc'</span>: <span class="ss">f'</span><span class="sc">{</span><span class="fl">100.</span><span class="op">*</span>correct<span class="op">/</span>total<span class="sc">:.2f}</span><span class="ss">%'</span></span>
<span id="cb7-35"><a href="#cb7-35" aria-hidden="true" tabindex="-1"></a>            })</span>
<span id="cb7-36"><a href="#cb7-36" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-37"><a href="#cb7-37" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Evaluate</span></span>
<span id="cb7-38"><a href="#cb7-38" aria-hidden="true" tabindex="-1"></a>        test_acc <span class="op">=</span> evaluate_model(model, test_loader)</span>
<span id="cb7-39"><a href="#cb7-39" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f'Teacher Epoch </span><span class="sc">{</span>epoch<span class="op">+</span><span class="dv">1</span><span class="sc">}</span><span class="ss">: Train Acc: </span><span class="sc">{</span><span class="fl">100.</span><span class="op">*</span>correct<span class="op">/</span>total<span class="sc">:.2f}</span><span class="ss">%, Test Acc: </span><span class="sc">{</span>test_acc<span class="sc">:.2f}</span><span class="ss">%'</span>)</span>
<span id="cb7-40"><a href="#cb7-40" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-41"><a href="#cb7-41" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> test_acc <span class="op">&gt;</span> best_acc:</span>
<span id="cb7-42"><a href="#cb7-42" aria-hidden="true" tabindex="-1"></a>            best_acc <span class="op">=</span> test_acc</span>
<span id="cb7-43"><a href="#cb7-43" aria-hidden="true" tabindex="-1"></a>            torch.save(model.state_dict(), <span class="st">'teacher_best.pth'</span>)</span>
<span id="cb7-44"><a href="#cb7-44" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-45"><a href="#cb7-45" aria-hidden="true" tabindex="-1"></a>        scheduler.step()</span>
<span id="cb7-46"><a href="#cb7-46" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-47"><a href="#cb7-47" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f'Teacher training completed. Best accuracy: </span><span class="sc">{</span>best_acc<span class="sc">:.2f}</span><span class="ss">%'</span>)</span>
<span id="cb7-48"><a href="#cb7-48" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> model</span></code></pre></div></div>
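<p>The teacher uses <code>StepLR(step_size=20, gamma=0.1)</code> and calls <code>scheduler.step()</code> once per epoch. The learning rate in effect during (0-indexed) epoch e is therefore <code>lr * gamma ** (e // step_size)</code>. The short plain-Python sketch below lists the resulting phases for the teacher (50 epochs, step 20) and for the student (100 epochs, step 30). The helper names are hypothetical.</p>

```python
import math

def step_lr(base_lr, epoch, step_size, gamma=0.1):
    """LR in effect during `epoch` (0-indexed) when scheduler.step() runs once per epoch."""
    return base_lr * gamma ** (epoch // step_size)

def lr_schedule(base_lr, epochs, step_size, gamma=0.1):
    """Per-epoch learning rates for a whole run."""
    return [step_lr(base_lr, e, step_size, gamma) for e in range(epochs)]

def lr_phases(schedule):
    """Collapse a per-epoch schedule into (first_epoch, last_epoch, lr) runs, 1-indexed."""
    phases = []
    for epoch, lr in enumerate(schedule, start=1):
        if phases and math.isclose(phases[-1][2], lr):
            phases[-1] = (phases[-1][0], epoch, lr)
        else:
            phases.append((epoch, epoch, lr))
    return phases

teacher_phases = lr_phases(lr_schedule(1e-3, epochs=50, step_size=20))
student_phases = lr_phases(lr_schedule(1e-3, epochs=100, step_size=30))
for name, phases in (("teacher", teacher_phases), ("student", student_phases)):
    for first, last, lr in phases:
        print(f"{name}: epochs {first}-{last}: lr = {lr:.0e}")
```

<p>The teacher therefore spends epochs 1-20 at 1e-3, 21-40 at 1e-4 and 41-50 at 1e-5. The student's final ten epochs run at 1e-6. Listing the schedule this way is a cheap check that the step size suits the epoch budget before you launch a long run.</p>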
</section>
<section id="train-student-model-with-distillation" class="level3">
<h3 class="anchored" data-anchor-id="train-student-model-with-distillation" id="train-student-model-with-distillation">Train Student Model with Distillation</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb8"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> train_student(student, teacher, train_loader, test_loader, epochs<span class="op">=</span><span class="dv">100</span>, lr<span class="op">=</span><span class="fl">0.001</span>):</span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Train the student network using knowledge distillation"""</span></span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="st">"Training Student Network with Knowledge Distillation..."</span>)</span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a>    distillation_loss <span class="op">=</span> DistillationLoss(alpha<span class="op">=</span><span class="fl">0.7</span>, temperature<span class="op">=</span><span class="fl">4.0</span>)</span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a>    optimizer <span class="op">=</span> optim.Adam(student.parameters(), lr<span class="op">=</span>lr, weight_decay<span class="op">=</span><span class="fl">1e-4</span>)</span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a>    scheduler <span class="op">=</span> optim.lr_scheduler.StepLR(optimizer, step_size<span class="op">=</span><span class="dv">30</span>, gamma<span class="op">=</span><span class="fl">0.1</span>)</span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a>    teacher.<span class="bu">eval</span>()  <span class="co"># Teacher in evaluation mode</span></span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a>    best_acc <span class="op">=</span> <span class="fl">0.0</span></span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-12"><a href="#cb8-12" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(epochs):</span>
<span id="cb8-13"><a href="#cb8-13" aria-hidden="true" tabindex="-1"></a>        student.train()  <span class="co"># Re-enable training mode each epoch (evaluate_model may leave it in eval mode)</span></span>
<span id="cb8-14"><a href="#cb8-14" aria-hidden="true" tabindex="-1"></a>        running_loss <span class="op">=</span> <span class="fl">0.0</span></span>
<span id="cb8-15"><a href="#cb8-15" aria-hidden="true" tabindex="-1"></a>        running_ce_loss <span class="op">=</span> <span class="fl">0.0</span></span>
<span id="cb8-16"><a href="#cb8-16" aria-hidden="true" tabindex="-1"></a>        running_kl_loss <span class="op">=</span> <span class="fl">0.0</span></span>
<span id="cb8-17"><a href="#cb8-17" aria-hidden="true" tabindex="-1"></a>        correct <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb8-18"><a href="#cb8-18" aria-hidden="true" tabindex="-1"></a>        total <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb8-19"><a href="#cb8-19" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-20"><a href="#cb8-20" aria-hidden="true" tabindex="-1"></a>        progress_bar <span class="op">=</span> tqdm(train_loader, desc<span class="op">=</span><span class="ss">f'Student Epoch </span><span class="sc">{</span>epoch<span class="op">+</span><span class="dv">1</span><span class="sc">}</span><span class="ss">/</span><span class="sc">{</span>epochs<span class="sc">}</span><span class="ss">'</span>)</span>
<span id="cb8-21"><a href="#cb8-21" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> batch_idx, (inputs, targets) <span class="kw">in</span> <span class="bu">enumerate</span>(progress_bar):</span>
<span id="cb8-22"><a href="#cb8-22" aria-hidden="true" tabindex="-1"></a>            inputs, targets <span class="op">=</span> inputs.to(device), targets.to(device)</span>
<span id="cb8-23"><a href="#cb8-23" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb8-24"><a href="#cb8-24" aria-hidden="true" tabindex="-1"></a>            optimizer.zero_grad()</span>
<span id="cb8-25"><a href="#cb8-25" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb8-26"><a href="#cb8-26" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Get predictions from both networks</span></span>
<span id="cb8-27"><a href="#cb8-27" aria-hidden="true" tabindex="-1"></a>            <span class="cf">with</span> torch.no_grad():</span>
<span id="cb8-28"><a href="#cb8-28" aria-hidden="true" tabindex="-1"></a>                teacher_logits <span class="op">=</span> teacher(inputs)</span>
<span id="cb8-29"><a href="#cb8-29" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb8-30"><a href="#cb8-30" aria-hidden="true" tabindex="-1"></a>            student_logits <span class="op">=</span> student(inputs)</span>
<span id="cb8-31"><a href="#cb8-31" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb8-32"><a href="#cb8-32" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Calculate distillation loss</span></span>
<span id="cb8-33"><a href="#cb8-33" aria-hidden="true" tabindex="-1"></a>            total_loss, ce_loss, kl_loss <span class="op">=</span> distillation_loss(</span>
<span id="cb8-34"><a href="#cb8-34" aria-hidden="true" tabindex="-1"></a>                student_logits, teacher_logits, targets</span>
<span id="cb8-35"><a href="#cb8-35" aria-hidden="true" tabindex="-1"></a>            )</span>
<span id="cb8-36"><a href="#cb8-36" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb8-37"><a href="#cb8-37" aria-hidden="true" tabindex="-1"></a>            total_loss.backward()</span>
<span id="cb8-38"><a href="#cb8-38" aria-hidden="true" tabindex="-1"></a>            optimizer.step()</span>
<span id="cb8-39"><a href="#cb8-39" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb8-40"><a href="#cb8-40" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Statistics</span></span>
<span id="cb8-41"><a href="#cb8-41" aria-hidden="true" tabindex="-1"></a>            running_loss <span class="op">+=</span> total_loss.item()</span>
<span id="cb8-42"><a href="#cb8-42" aria-hidden="true" tabindex="-1"></a>            running_ce_loss <span class="op">+=</span> ce_loss.item()</span>
<span id="cb8-43"><a href="#cb8-43" aria-hidden="true" tabindex="-1"></a>            running_kl_loss <span class="op">+=</span> kl_loss.item()</span>
<span id="cb8-44"><a href="#cb8-44" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb8-45"><a href="#cb8-45" aria-hidden="true" tabindex="-1"></a>            _, predicted <span class="op">=</span> student_logits.<span class="bu">max</span>(<span class="dv">1</span>)</span>
<span id="cb8-46"><a href="#cb8-46" aria-hidden="true" tabindex="-1"></a>            total <span class="op">+=</span> targets.size(<span class="dv">0</span>)</span>
<span id="cb8-47"><a href="#cb8-47" aria-hidden="true" tabindex="-1"></a>            correct <span class="op">+=</span> predicted.eq(targets).<span class="bu">sum</span>().item()</span>
<span id="cb8-48"><a href="#cb8-48" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb8-49"><a href="#cb8-49" aria-hidden="true" tabindex="-1"></a>            progress_bar.set_postfix({</span>
<span id="cb8-50"><a href="#cb8-50" aria-hidden="true" tabindex="-1"></a>                <span class="st">'Loss'</span>: <span class="ss">f'</span><span class="sc">{</span>running_loss<span class="op">/</span>(batch_idx<span class="op">+</span><span class="dv">1</span>)<span class="sc">:.4f}</span><span class="ss">'</span>,</span>
<span id="cb8-51"><a href="#cb8-51" aria-hidden="true" tabindex="-1"></a>                <span class="st">'CE'</span>: <span class="ss">f'</span><span class="sc">{</span>running_ce_loss<span class="op">/</span>(batch_idx<span class="op">+</span><span class="dv">1</span>)<span class="sc">:.4f}</span><span class="ss">'</span>,</span>
<span id="cb8-52"><a href="#cb8-52" aria-hidden="true" tabindex="-1"></a>                <span class="st">'KL'</span>: <span class="ss">f'</span><span class="sc">{</span>running_kl_loss<span class="op">/</span>(batch_idx<span class="op">+</span><span class="dv">1</span>)<span class="sc">:.4f}</span><span class="ss">'</span>,</span>
<span id="cb8-53"><a href="#cb8-53" aria-hidden="true" tabindex="-1"></a>                <span class="st">'Acc'</span>: <span class="ss">f'</span><span class="sc">{</span><span class="fl">100.</span><span class="op">*</span>correct<span class="op">/</span>total<span class="sc">:.2f}</span><span class="ss">%'</span></span>
<span id="cb8-54"><a href="#cb8-54" aria-hidden="true" tabindex="-1"></a>            })</span>
<span id="cb8-55"><a href="#cb8-55" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-56"><a href="#cb8-56" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Evaluate</span></span>
<span id="cb8-57"><a href="#cb8-57" aria-hidden="true" tabindex="-1"></a>        test_acc <span class="op">=</span> evaluate_model(student, test_loader)</span>
<span id="cb8-58"><a href="#cb8-58" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f'Student Epoch </span><span class="sc">{</span>epoch<span class="op">+</span><span class="dv">1</span><span class="sc">}</span><span class="ss">: Train Acc: </span><span class="sc">{</span><span class="fl">100.</span><span class="op">*</span>correct<span class="op">/</span>total<span class="sc">:.2f}</span><span class="ss">%, Test Acc: </span><span class="sc">{</span>test_acc<span class="sc">:.2f}</span><span class="ss">%'</span>)</span>
<span id="cb8-59"><a href="#cb8-59" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-60"><a href="#cb8-60" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> test_acc <span class="op">&gt;</span> best_acc:</span>
<span id="cb8-61"><a href="#cb8-61" aria-hidden="true" tabindex="-1"></a>            best_acc <span class="op">=</span> test_acc</span>
<span id="cb8-62"><a href="#cb8-62" aria-hidden="true" tabindex="-1"></a>            torch.save(student.state_dict(), <span class="st">'student_best.pth'</span>)</span>
<span id="cb8-63"><a href="#cb8-63" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-64"><a href="#cb8-64" aria-hidden="true" tabindex="-1"></a>        scheduler.step()</span>
<span id="cb8-65"><a href="#cb8-65" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-66"><a href="#cb8-66" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f'Student training completed. Best accuracy: </span><span class="sc">{</span>best_acc<span class="sc">:.2f}</span><span class="ss">%'</span>)</span>
<span id="cb8-67"><a href="#cb8-67" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> student</span></code></pre></div></div>
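<p>The student's update is driven by (1 - alpha) * CE + alpha * T^2 * KL, and the T^2 factor matters. For the soft-target term alone, the gradient with respect to each student logit is (q_i - p_i) / T, where q and p are the temperature-softened student and teacher distributions. For large T (and roughly zero-mean logits), this gradient shrinks about as 1/T^2. The plain-Python sketch below evaluates the analytic gradient for one fixed pair of logit vectors, chosen arbitrarily for illustration. It shows that multiplying by T^2 keeps the gradient's magnitude roughly constant as T grows.</p>

```python
import math

def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over one list of logits (numerically stable)."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def soft_target_grad(student_logits, teacher_logits, temperature):
    """Analytic gradient of KL(p_teacher || q_student) w.r.t. the student logits.

    With q = softmax(s / T) and p = softmax(t / T), dKL/ds_i = (q_i - p_i) / T.
    """
    p = softmax(teacher_logits, temperature)
    q = softmax(student_logits, temperature)
    return [(qi - pi) / temperature for qi, pi in zip(q, p)]

def grad_norm(grad):
    """Euclidean norm of a gradient vector."""
    return math.sqrt(sum(g * g for g in grad))

teacher_logits = [2.0, 0.0, -2.0]   # illustrative values
student_logits = [0.0, 0.0, 0.0]
for T in (1.0, 2.0, 4.0, 8.0, 16.0):
    g = grad_norm(soft_target_grad(student_logits, teacher_logits, T))
    print(f"T={T:>4}: unscaled |grad| = {g:.4f}, scaled by T^2 = {g * T * T:.4f}")
```

<p>Without this rescaling, raising the temperature would quietly shrink the soft-target gradients and shift the effective balance toward the cross-entropy term. Changing T would then also change what <code>alpha</code> means.</p>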
</section>
<section id="train-student-model-baseline" class="level3">
<h3 class="anchored" data-anchor-id="train-student-model-baseline" id="train-student-model-baseline">Train Student Model (Baseline)</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb9"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> train_student_baseline(student, train_loader, test_loader, epochs<span class="op">=</span><span class="dv">100</span>, lr<span class="op">=</span><span class="fl">0.001</span>):</span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Train student without distillation (baseline comparison)"""</span></span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="st">"Training Student Network (Baseline - No Distillation)..."</span>)</span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a>    criterion <span class="op">=</span> nn.CrossEntropyLoss()</span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a>    optimizer <span class="op">=</span> optim.Adam(student.parameters(), lr<span class="op">=</span>lr, weight_decay<span class="op">=</span><span class="fl">1e-4</span>)</span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a>    scheduler <span class="op">=</span> optim.lr_scheduler.StepLR(optimizer, step_size<span class="op">=</span><span class="dv">30</span>, gamma<span class="op">=</span><span class="fl">0.1</span>)</span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a>    student.train()</span>
<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a>    best_acc <span class="op">=</span> <span class="fl">0.0</span></span>
<span id="cb9-11"><a href="#cb9-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-12"><a href="#cb9-12" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(epochs):</span>
<span id="cb9-13"><a href="#cb9-13" aria-hidden="true" tabindex="-1"></a>        running_loss <span class="op">=</span> <span class="fl">0.0</span></span>
<span id="cb9-14"><a href="#cb9-14" aria-hidden="true" tabindex="-1"></a>        correct <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb9-15"><a href="#cb9-15" aria-hidden="true" tabindex="-1"></a>        total <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb9-16"><a href="#cb9-16" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-17"><a href="#cb9-17" aria-hidden="true" tabindex="-1"></a>        progress_bar <span class="op">=</span> tqdm(train_loader, desc<span class="op">=</span><span class="ss">f'Baseline Epoch </span><span class="sc">{</span>epoch<span class="op">+</span><span class="dv">1</span><span class="sc">}</span><span class="ss">/</span><span class="sc">{</span>epochs<span class="sc">}</span><span class="ss">'</span>)</span>
<span id="cb9-18"><a href="#cb9-18" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> batch_idx, (inputs, targets) <span class="kw">in</span> <span class="bu">enumerate</span>(progress_bar):</span>
<span id="cb9-19"><a href="#cb9-19" aria-hidden="true" tabindex="-1"></a>            inputs, targets <span class="op">=</span> inputs.to(device), targets.to(device)</span>
<span id="cb9-20"><a href="#cb9-20" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb9-21"><a href="#cb9-21" aria-hidden="true" tabindex="-1"></a>            optimizer.zero_grad()</span>
<span id="cb9-22"><a href="#cb9-22" aria-hidden="true" tabindex="-1"></a>            outputs <span class="op">=</span> student(inputs)</span>
<span id="cb9-23"><a href="#cb9-23" aria-hidden="true" tabindex="-1"></a>            loss <span class="op">=</span> criterion(outputs, targets)</span>
<span id="cb9-24"><a href="#cb9-24" aria-hidden="true" tabindex="-1"></a>            loss.backward()</span>
<span id="cb9-25"><a href="#cb9-25" aria-hidden="true" tabindex="-1"></a>            optimizer.step()</span>
<span id="cb9-26"><a href="#cb9-26" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb9-27"><a href="#cb9-27" aria-hidden="true" tabindex="-1"></a>            running_loss <span class="op">+=</span> loss.item()</span>
<span id="cb9-28"><a href="#cb9-28" aria-hidden="true" tabindex="-1"></a>            _, predicted <span class="op">=</span> outputs.<span class="bu">max</span>(<span class="dv">1</span>)</span>
<span id="cb9-29"><a href="#cb9-29" aria-hidden="true" tabindex="-1"></a>            total <span class="op">+=</span> targets.size(<span class="dv">0</span>)</span>
<span id="cb9-30"><a href="#cb9-30" aria-hidden="true" tabindex="-1"></a>            correct <span class="op">+=</span> predicted.eq(targets).<span class="bu">sum</span>().item()</span>
<span id="cb9-31"><a href="#cb9-31" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb9-32"><a href="#cb9-32" aria-hidden="true" tabindex="-1"></a>            progress_bar.set_postfix({</span>
<span id="cb9-33"><a href="#cb9-33" aria-hidden="true" tabindex="-1"></a>                <span class="st">'Loss'</span>: <span class="ss">f'</span><span class="sc">{</span>running_loss<span class="op">/</span>(batch_idx<span class="op">+</span><span class="dv">1</span>)<span class="sc">:.4f}</span><span class="ss">'</span>,</span>
<span id="cb9-34"><a href="#cb9-34" aria-hidden="true" tabindex="-1"></a>                <span class="st">'Acc'</span>: <span class="ss">f'</span><span class="sc">{</span><span class="fl">100.</span><span class="op">*</span>correct<span class="op">/</span>total<span class="sc">:.2f}</span><span class="ss">%'</span></span>
<span id="cb9-35"><a href="#cb9-35" aria-hidden="true" tabindex="-1"></a>            })</span>
<span id="cb9-36"><a href="#cb9-36" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-37"><a href="#cb9-37" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Evaluate</span></span>
<span id="cb9-38"><a href="#cb9-38" aria-hidden="true" tabindex="-1"></a>        test_acc <span class="op">=</span> evaluate_model(student, test_loader)</span>
<span id="cb9-39"><a href="#cb9-39" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f'Baseline Epoch </span><span class="sc">{</span>epoch<span class="op">+</span><span class="dv">1</span><span class="sc">}</span><span class="ss">: Train Acc: </span><span class="sc">{</span><span class="fl">100.</span><span class="op">*</span>correct<span class="op">/</span>total<span class="sc">:.2f}</span><span class="ss">%, Test Acc: </span><span class="sc">{</span>test_acc<span class="sc">:.2f}</span><span class="ss">%'</span>)</span>
<span id="cb9-40"><a href="#cb9-40" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-41"><a href="#cb9-41" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> test_acc <span class="op">&gt;</span> best_acc:</span>
<span id="cb9-42"><a href="#cb9-42" aria-hidden="true" tabindex="-1"></a>            best_acc <span class="op">=</span> test_acc</span>
<span id="cb9-43"><a href="#cb9-43" aria-hidden="true" tabindex="-1"></a>            torch.save(student.state_dict(), <span class="st">'student_baseline_best.pth'</span>)</span>
<span id="cb9-44"><a href="#cb9-44" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-45"><a href="#cb9-45" aria-hidden="true" tabindex="-1"></a>        scheduler.step()</span>
<span id="cb9-46"><a href="#cb9-46" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-47"><a href="#cb9-47" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f'Baseline training completed. Best accuracy: </span><span class="sc">{</span>best_acc<span class="sc">:.2f}</span><span class="ss">%'</span>)</span>
<span id="cb9-48"><a href="#cb9-48" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> student</span></code></pre></div></div>
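<p>The baseline pairs Adam with <code>StepLR(step_size=30, gamma=0.1)</code>, which multiplies the learning rate by 0.1 every 30 epochs. Assuming <code>scheduler.step()</code> is called once per epoch, as in the loop above, the rate in effect during each epoch can be worked out in closed form. The helper <code>step_lr</code> below is a hypothetical illustration, not part of PyTorch:</p>

```python
def step_lr(base_lr, epoch, step_size=30, gamma=0.1):
    """Hypothetical helper: learning rate in effect during `epoch` (0-indexed) under
    StepLR, assuming scheduler.step() is called once at the end of every epoch."""
    return base_lr * gamma ** (epoch // step_size)

# Settings used by train_student_baseline: lr=0.001, step_size=30, gamma=0.1, 100 epochs
schedule = [step_lr(0.001, epoch) for epoch in range(100)]

for start in (0, 30, 60, 90):
    end = min(start + 29, 99)
    print(f"epochs {start:>2}-{end:>2}: lr = {schedule[start]:.0e}")
```

<p>Over 100 epochs this gives four plateaus (1e-3, 1e-4, 1e-5, 1e-6), so the last ten epochs run at a rate small enough that they change the weights very little. Shortening <code>epochs</code> or adjusting <code>step_size</code> are both reasonable options if that matters for the comparison.</p>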
</section>
<section id="evaluate-model" class="level3">
<h3 class="anchored" data-anchor-id="evaluate-model" id="evaluate-model">Evaluate Model</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb10"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> evaluate_model(model, test_loader):</span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Evaluate top-1 accuracy (%) and restore the model's previous train/eval mode"""</span></span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a>    was_training <span class="op">=</span> model.training  <span class="co"># remember whether the model was in training mode</span></span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a>    correct <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a>    total <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb10-6"><a href="#cb10-6" aria-hidden="true" tabindex="-1"></a>    model.<span class="bu">eval</span>()  <span class="co"># disable dropout; batch norm uses running statistics</span></span>
<span id="cb10-7"><a href="#cb10-7" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> torch.no_grad():</span>
<span id="cb10-8"><a href="#cb10-8" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> inputs, targets <span class="kw">in</span> test_loader:</span>
<span id="cb10-9"><a href="#cb10-9" aria-hidden="true" tabindex="-1"></a>            inputs, targets <span class="op">=</span> inputs.to(device), targets.to(device)</span>
<span id="cb10-10"><a href="#cb10-10" aria-hidden="true" tabindex="-1"></a>            outputs <span class="op">=</span> model(inputs)</span>
<span id="cb10-11"><a href="#cb10-11" aria-hidden="true" tabindex="-1"></a>            _, predicted <span class="op">=</span> outputs.<span class="bu">max</span>(<span class="dv">1</span>)</span>
<span id="cb10-12"><a href="#cb10-12" aria-hidden="true" tabindex="-1"></a>            total <span class="op">+=</span> targets.size(<span class="dv">0</span>)</span>
<span id="cb10-13"><a href="#cb10-13" aria-hidden="true" tabindex="-1"></a>            correct <span class="op">+=</span> predicted.eq(targets).<span class="bu">sum</span>().item()</span>
<span id="cb10-14"><a href="#cb10-14" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-15"><a href="#cb10-15" aria-hidden="true" tabindex="-1"></a>    accuracy <span class="op">=</span> <span class="fl">100.</span> <span class="op">*</span> correct <span class="op">/</span> total</span>
<span id="cb10-16"><a href="#cb10-16" aria-hidden="true" tabindex="-1"></a>    model.train(was_training)  <span class="co"># restore the original mode instead of always forcing train()</span></span>
<span id="cb10-17"><a href="#cb10-17" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> accuracy</span></code></pre></div></div>
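<p>Switching to evaluation mode matters because layers such as dropout and batch normalization behave differently during training and inference, and the mode is a flag that stays set on the module until it is changed again. A minimal sketch with a single <code>nn.Dropout</code> layer, used here only for illustration:</p>

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Dropout(p=0.5)   # illustrative layer; any module with train/eval behaviour works
x = torch.ones(1, 8)

layer.train()               # training mode: units are zeroed at random, survivors scaled by 1 / (1 - p)
train_out = layer(x)

layer.eval()                # evaluation mode: dropout becomes the identity
with torch.no_grad():       # no autograd graph is recorded during evaluation
    eval_out = layer(x)

print(train_out)            # every entry is either 0.0 or 2.0
print(eval_out)             # identical to x
print(layer.training)       # False: the mode persists until train() is called again
```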
</section>
<section id="count-parameters" class="level3">
<h3 class="anchored" data-anchor-id="count-parameters" id="count-parameters">Count Parameters</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb11"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> count_parameters(model):</span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Count trainable parameters"""</span></span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> <span class="bu">sum</span>(p.numel() <span class="cf">for</span> p <span class="kw">in</span> model.parameters() <span class="cf">if</span> p.requires_grad)</span></code></pre></div></div>
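<p>Because of the <code>requires_grad</code> filter, frozen parameters are excluded from the count. The toy model below is an illustrative assumption; it shows the arithmetic for two linear layers and what happens after freezing one of them. One consequence worth checking: if the teacher's parameters are frozen during distillation, <code>count_parameters(teacher)</code> would return 0 afterwards, so for size comparisons it may be safer to count all parameters.</p>

```python
import torch.nn as nn

def count_parameters(model):
    """Count trainable parameters (same logic as the function above)."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Toy model, chosen only to make the arithmetic easy to follow
model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))
print(count_parameters(model))   # 67 = (10*5 + 5) + (5*2 + 2)

for p in model[0].parameters():  # freeze the first linear layer
    p.requires_grad = False
print(count_parameters(model))   # 12: only the second layer is still trainable
print(sum(p.numel() for p in model.parameters()))  # 67: counting all parameters ignores requires_grad
```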
</section>
<section id="main-execution" class="level3">
<h3 class="anchored" data-anchor-id="main-execution" id="main-execution">Main Execution</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb12"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><a href="#cb12-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Load data</span></span>
<span id="cb12-2"><a href="#cb12-2" aria-hidden="true" tabindex="-1"></a>train_loader, test_loader <span class="op">=</span> load_data(batch_size<span class="op">=</span><span class="dv">128</span>)</span>
<span id="cb12-3"><a href="#cb12-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-4"><a href="#cb12-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Initialize networks</span></span>
<span id="cb12-5"><a href="#cb12-5" aria-hidden="true" tabindex="-1"></a>teacher <span class="op">=</span> TeacherNetwork(num_classes<span class="op">=</span><span class="dv">10</span>).to(device)</span>
<span id="cb12-6"><a href="#cb12-6" aria-hidden="true" tabindex="-1"></a>student_distilled <span class="op">=</span> StudentNetwork(num_classes<span class="op">=</span><span class="dv">10</span>).to(device)</span>
<span id="cb12-7"><a href="#cb12-7" aria-hidden="true" tabindex="-1"></a>student_baseline <span class="op">=</span> StudentNetwork(num_classes<span class="op">=</span><span class="dv">10</span>).to(device)</span>
<span id="cb12-8"><a href="#cb12-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-9"><a href="#cb12-9" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Teacher parameters: </span><span class="sc">{</span>count_parameters(teacher)<span class="sc">:,}</span><span class="ss">"</span>)</span>
<span id="cb12-10"><a href="#cb12-10" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Student parameters: </span><span class="sc">{</span>count_parameters(student_distilled)<span class="sc">:,}</span><span class="ss">"</span>)</span>
<span id="cb12-11"><a href="#cb12-11" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Compression ratio: </span><span class="sc">{</span>count_parameters(teacher) <span class="op">/</span> count_parameters(student_distilled)<span class="sc">:.1f}</span><span class="ss">x"</span>)</span>
<span id="cb12-12"><a href="#cb12-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-13"><a href="#cb12-13" aria-hidden="true" tabindex="-1"></a><span class="co"># Train teacher (or load pre-trained)</span></span>
<span id="cb12-14"><a href="#cb12-14" aria-hidden="true" tabindex="-1"></a><span class="cf">try</span>:</span>
<span id="cb12-15"><a href="#cb12-15" aria-hidden="true" tabindex="-1"></a>    teacher.load_state_dict(torch.load(<span class="st">'teacher_best.pth'</span>, map_location<span class="op">=</span>device))</span>
<span id="cb12-16"><a href="#cb12-16" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="st">"Loaded pre-trained teacher model"</span>)</span>
<span id="cb12-17"><a href="#cb12-17" aria-hidden="true" tabindex="-1"></a><span class="cf">except</span> <span class="pp">FileNotFoundError</span>:</span>
<span id="cb12-18"><a href="#cb12-18" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="st">"Training teacher from scratch..."</span>)</span>
<span id="cb12-19"><a href="#cb12-19" aria-hidden="true" tabindex="-1"></a>    teacher <span class="op">=</span> train_teacher(teacher, train_loader, test_loader, epochs<span class="op">=</span><span class="dv">50</span>)</span>
<span id="cb12-20"><a href="#cb12-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-21"><a href="#cb12-21" aria-hidden="true" tabindex="-1"></a>teacher_acc <span class="op">=</span> evaluate_model(teacher, test_loader)</span>
<span id="cb12-22"><a href="#cb12-22" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Teacher accuracy: </span><span class="sc">{</span>teacher_acc<span class="sc">:.2f}</span><span class="ss">%"</span>)</span>
<span id="cb12-23"><a href="#cb12-23" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-24"><a href="#cb12-24" aria-hidden="true" tabindex="-1"></a><span class="co"># Train student with knowledge distillation</span></span>
<span id="cb12-25"><a href="#cb12-25" aria-hidden="true" tabindex="-1"></a>student_distilled <span class="op">=</span> train_student(</span>
<span id="cb12-26"><a href="#cb12-26" aria-hidden="true" tabindex="-1"></a>    student_distilled, teacher, train_loader, test_loader, epochs<span class="op">=</span><span class="dv">100</span></span>
<span id="cb12-27"><a href="#cb12-27" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb12-28"><a href="#cb12-28" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-29"><a href="#cb12-29" aria-hidden="true" tabindex="-1"></a><span class="co"># Train student baseline (without distillation)</span></span>
<span id="cb12-30"><a href="#cb12-30" aria-hidden="true" tabindex="-1"></a>student_baseline <span class="op">=</span> train_student_baseline(</span>
<span id="cb12-31"><a href="#cb12-31" aria-hidden="true" tabindex="-1"></a>    student_baseline, train_loader, test_loader, epochs<span class="op">=</span><span class="dv">100</span></span>
<span id="cb12-32"><a href="#cb12-32" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb12-33"><a href="#cb12-33" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-34"><a href="#cb12-34" aria-hidden="true" tabindex="-1"></a><span class="co"># Final evaluation</span></span>
<span id="cb12-35"><a href="#cb12-35" aria-hidden="true" tabindex="-1"></a>distilled_acc <span class="op">=</span> evaluate_model(student_distilled, test_loader)</span>
<span id="cb12-36"><a href="#cb12-36" aria-hidden="true" tabindex="-1"></a>baseline_acc <span class="op">=</span> evaluate_model(student_baseline, test_loader)</span>
<span id="cb12-37"><a href="#cb12-37" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-38"><a href="#cb12-38" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"</span><span class="ch">\n</span><span class="st">"</span> <span class="op">+</span> <span class="st">"="</span><span class="op">*</span><span class="dv">50</span>)</span>
<span id="cb12-39"><a href="#cb12-39" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"FINAL RESULTS"</span>)</span>
<span id="cb12-40"><a href="#cb12-40" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"="</span><span class="op">*</span><span class="dv">50</span>)</span>
<span id="cb12-41"><a href="#cb12-41" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Teacher accuracy:           </span><span class="sc">{</span>teacher_acc<span class="sc">:.2f}</span><span class="ss">%"</span>)</span>
<span id="cb12-42"><a href="#cb12-42" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Student (distilled):        </span><span class="sc">{</span>distilled_acc<span class="sc">:.2f}</span><span class="ss">%"</span>)</span>
<span id="cb12-43"><a href="#cb12-43" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Student (baseline):         </span><span class="sc">{</span>baseline_acc<span class="sc">:.2f}</span><span class="ss">%"</span>)</span>
<span id="cb12-44"><a href="#cb12-44" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Distillation improvement:   </span><span class="sc">{</span>distilled_acc <span class="op">-</span> baseline_acc<span class="sc">:+.2f}</span><span class="ss"> percentage points"</span>)</span>
<span id="cb12-45"><a href="#cb12-45" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Parameter reduction:        </span><span class="sc">{</span>count_parameters(teacher) <span class="op">/</span> count_parameters(student_distilled)<span class="sc">:.1f}</span><span class="ss">x"</span>)</span></code></pre></div></div>
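<p>In the script above, <code>student_distilled</code> and <code>student_baseline</code> come from two separate constructor calls, so they start from different random weights, and part of the accuracy gap can come from initialization rather than distillation. One option is to copy the initial weights before training, either with <code>student_baseline.load_state_dict(student_distilled.state_dict())</code> or with <code>copy.deepcopy</code>. The sketch below uses a small hypothetical <code>make_student</code> stand-in for <code>StudentNetwork</code>; data order and dropout masks still differ between runs unless the random seed is controlled as well.</p>

```python
import copy
import torch
import torch.nn as nn

torch.manual_seed(0)

def make_student():
    """Hypothetical stand-in for StudentNetwork; any nn.Module behaves the same way."""
    return nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 10))

student_distilled = make_student()
student_baseline = copy.deepcopy(student_distilled)   # same starting weights, independent copy

same_init = all(
    torch.equal(a, b)
    for a, b in zip(student_distilled.state_dict().values(),
                    student_baseline.state_dict().values())
)
print(same_init)  # True

# Equivalent alternative for two separately constructed models:
other = make_student()
other.load_state_dict(student_distilled.state_dict())
```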
</section>
</section>
<section id="advanced-techniques" class="level2">
<h2 class="anchored" data-anchor-id="advanced-techniques" id="advanced-techniques">Advanced Techniques</h2>
<section id="feature-level-distillation" class="level3">
<h3 class="anchored" data-anchor-id="feature-level-distillation" id="feature-level-distillation">1. Feature-Level Distillation</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb13"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> FeatureDistillationLoss(nn.Module):</span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Distillation using intermediate feature maps"""</span></span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, alpha<span class="op">=</span><span class="fl">0.5</span>, beta<span class="op">=</span><span class="fl">0.3</span>, temperature<span class="op">=</span><span class="fl">4.0</span>):</span>
<span id="cb13-4"><a href="#cb13-4" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb13-5"><a href="#cb13-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.alpha <span class="op">=</span> alpha      <span class="co"># Weight for prediction (KL) distillation</span></span>
<span id="cb13-6"><a href="#cb13-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.beta <span class="op">=</span> beta        <span class="co"># Weight for feature distillation; CE gets 1 - alpha - beta</span></span>
<span id="cb13-7"><a href="#cb13-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.temperature <span class="op">=</span> temperature</span>
<span id="cb13-8"><a href="#cb13-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.ce_loss <span class="op">=</span> nn.CrossEntropyLoss()</span>
<span id="cb13-9"><a href="#cb13-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.kl_loss <span class="op">=</span> nn.KLDivLoss(reduction<span class="op">=</span><span class="st">'batchmean'</span>)</span>
<span id="cb13-10"><a href="#cb13-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.mse_loss <span class="op">=</span> nn.MSELoss()</span>
<span id="cb13-11"><a href="#cb13-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-12"><a href="#cb13-12" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, student_logits, teacher_logits, student_features, teacher_features, labels):</span>
<span id="cb13-13"><a href="#cb13-13" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Standard KD terms: hard-label cross-entropy and temperature-scaled KL divergence</span></span>
<span id="cb13-14"><a href="#cb13-14" aria-hidden="true" tabindex="-1"></a>        ce_loss <span class="op">=</span> <span class="va">self</span>.ce_loss(student_logits, labels)</span>
<span id="cb13-15"><a href="#cb13-15" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-16"><a href="#cb13-16" aria-hidden="true" tabindex="-1"></a>        teacher_probs <span class="op">=</span> F.softmax(teacher_logits <span class="op">/</span> <span class="va">self</span>.temperature, dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb13-17"><a href="#cb13-17" aria-hidden="true" tabindex="-1"></a>        student_log_probs <span class="op">=</span> F.log_softmax(student_logits <span class="op">/</span> <span class="va">self</span>.temperature, dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb13-18"><a href="#cb13-18" aria-hidden="true" tabindex="-1"></a>        kl_loss <span class="op">=</span> <span class="va">self</span>.kl_loss(student_log_probs, teacher_probs) <span class="op">*</span> (<span class="va">self</span>.temperature <span class="op">**</span> <span class="dv">2</span>)</span>
<span id="cb13-19"><a href="#cb13-19" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-20"><a href="#cb13-20" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Feature distillation loss (MSE requires matching student/teacher feature shapes)</span></span>
<span id="cb13-21"><a href="#cb13-21" aria-hidden="true" tabindex="-1"></a>        feature_loss <span class="op">=</span> <span class="va">self</span>.mse_loss(student_features, teacher_features)</span>
<span id="cb13-22"><a href="#cb13-22" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-23"><a href="#cb13-23" aria-hidden="true" tabindex="-1"></a>        total_loss <span class="op">=</span> (<span class="dv">1</span> <span class="op">-</span> <span class="va">self</span>.alpha <span class="op">-</span> <span class="va">self</span>.beta) <span class="op">*</span> ce_loss <span class="op">+</span> <span class="va">self</span>.alpha <span class="op">*</span> kl_loss <span class="op">+</span> <span class="va">self</span>.beta <span class="op">*</span> feature_loss</span>
<span id="cb13-24"><a href="#cb13-24" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-25"><a href="#cb13-25" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> total_loss, ce_loss, kl_loss, feature_loss</span></code></pre></div></div>
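<p><code>FeatureDistillationLoss</code> compares feature maps with MSE, which only works when student and teacher features share the same shape. When the student is narrower, a common workaround (similar in spirit to FitNets-style hints) is a learnable 1x1 convolution that projects student features to the teacher's channel count, plus pooling if spatial sizes differ. <code>FeatureAdapter</code> and the tensor shapes below are hypothetical:</p>

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FeatureAdapter(nn.Module):
    """Hypothetical helper: maps student features to the teacher's shape so an
    MSE feature loss can be computed. It is trained jointly with the student."""
    def __init__(self, student_channels, teacher_channels):
        super().__init__()
        self.proj = nn.Conv2d(student_channels, teacher_channels, kernel_size=1)

    def forward(self, student_feat, teacher_feat):
        out = self.proj(student_feat)                   # match the channel count
        if out.shape[-2:] != teacher_feat.shape[-2:]:   # match the spatial size if needed
            out = F.adaptive_avg_pool2d(out, teacher_feat.shape[-2:])
        return out

torch.manual_seed(0)
student_feat = torch.randn(4, 32, 16, 16)   # e.g. a narrow student block
teacher_feat = torch.randn(4, 128, 8, 8)    # e.g. a wider, deeper teacher block

adapter = FeatureAdapter(32, 128)
aligned = adapter(student_feat, teacher_feat)
feature_loss = F.mse_loss(aligned, teacher_feat.detach())  # no gradient into the teacher
print(aligned.shape)  # torch.Size([4, 128, 8, 8])
```

<p>The adapter's parameters need to be passed to the optimizer together with the student's, otherwise the projection stays at its random initialization.</p>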
</section>
<section id="attention-transfer" class="level3">
<h3 class="anchored" data-anchor-id="attention-transfer" id="attention-transfer">2. Attention Transfer</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb14"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb14-1"><a href="#cb14-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> AttentionTransferLoss(nn.Module):</span>
<span id="cb14-2"><a href="#cb14-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Transfer attention maps from teacher to student"""</span></span>
<span id="cb14-3"><a href="#cb14-3" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, p<span class="op">=</span><span class="dv">2</span>):</span>
<span id="cb14-4"><a href="#cb14-4" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb14-5"><a href="#cb14-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.p <span class="op">=</span> p</span>
<span id="cb14-6"><a href="#cb14-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb14-7"><a href="#cb14-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> attention_map(<span class="va">self</span>, feature_map):</span>
<span id="cb14-8"><a href="#cb14-8" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Spatial attention: mean of |A|^p over channels, shape (N, 1, H, W)</span></span>
<span id="cb14-9"><a href="#cb14-9" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> feature_map.<span class="bu">abs</span>().<span class="bu">pow</span>(<span class="va">self</span>.p).mean(dim<span class="op">=</span><span class="dv">1</span>, keepdim<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb14-10"><a href="#cb14-10" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb14-11"><a href="#cb14-11" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, student_features, teacher_features):</span>
<span id="cb14-12"><a href="#cb14-12" aria-hidden="true" tabindex="-1"></a>        student_attention <span class="op">=</span> <span class="va">self</span>.attention_map(student_features)</span>
<span id="cb14-13"><a href="#cb14-13" aria-hidden="true" tabindex="-1"></a>        teacher_attention <span class="op">=</span> <span class="va">self</span>.attention_map(teacher_features)</span>
<span id="cb14-14"><a href="#cb14-14" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-15"><a href="#cb14-15" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Normalize attention maps</span></span>
<span id="cb14-16"><a href="#cb14-16" aria-hidden="true" tabindex="-1"></a>        student_attention <span class="op">=</span> F.normalize(student_attention.view(student_attention.size(<span class="dv">0</span>), <span class="op">-</span><span class="dv">1</span>))</span>
<span id="cb14-17"><a href="#cb14-17" aria-hidden="true" tabindex="-1"></a>        teacher_attention <span class="op">=</span> F.normalize(teacher_attention.view(teacher_attention.size(<span class="dv">0</span>), <span class="op">-</span><span class="dv">1</span>))</span>
<span id="cb14-18"><a href="#cb14-18" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-19"><a href="#cb14-19" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> F.mse_loss(student_attention, teacher_attention)</span></code></pre></div></div>
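<p>Because the attention map collapses the channel dimension, attention transfer can compare layers with different channel counts, provided their spatial resolutions match. The sketch below assumes one common formulation (mean of |A|^p over channels, then per-sample L2 normalization) and checks two properties on random tensors: the loss is positive for unrelated features, and it is insensitive to rescaling the features. The function names are illustrative:</p>

```python
import torch
import torch.nn.functional as F

def attention_vector(feature_map, p=2):
    """Spatial attention: mean of |A|^p over channels, flattened and L2-normalised per sample."""
    att = feature_map.abs().pow(p).mean(dim=1)                # (N, C, H, W) -> (N, H, W)
    return F.normalize(att.reshape(att.size(0), -1), dim=1)   # (N, H*W), unit length per row

def attention_transfer_loss(student_features, teacher_features, p=2):
    return F.mse_loss(attention_vector(student_features, p),
                      attention_vector(teacher_features, p))

torch.manual_seed(0)
teacher_feat = torch.randn(2, 64, 8, 8)
student_feat = torch.randn(2, 16, 8, 8)   # fewer channels, same 8x8 spatial grid

loss_unrelated = attention_transfer_loss(student_feat, teacher_feat)
loss_rescaled = attention_transfer_loss(3.0 * teacher_feat, teacher_feat)
print(loss_unrelated.item())  # positive
print(loss_rescaled.item())   # approximately 0: rescaling does not change the normalised map
```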
</section>
<section id="progressive-knowledge-distillation" class="level3">
<h3 class="anchored" data-anchor-id="progressive-knowledge-distillation" id="progressive-knowledge-distillation">3. Progressive Knowledge Distillation</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb15"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb15-1"><a href="#cb15-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> ProgressiveDistillation:</span>
<span id="cb15-2"><a href="#cb15-2" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Gradually increase distillation weight during training"""</span></span>
<span id="cb15-3"><a href="#cb15-3" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, initial_alpha<span class="op">=</span><span class="fl">0.1</span>, final_alpha<span class="op">=</span><span class="fl">0.9</span>, warmup_epochs<span class="op">=</span><span class="dv">20</span>):</span>
<span id="cb15-4"><a href="#cb15-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.initial_alpha <span class="op">=</span> initial_alpha</span>
<span id="cb15-5"><a href="#cb15-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.final_alpha <span class="op">=</span> final_alpha</span>
<span id="cb15-6"><a href="#cb15-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.warmup_epochs <span class="op">=</span> warmup_epochs</span>
<span id="cb15-7"><a href="#cb15-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-8"><a href="#cb15-8" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> get_alpha(<span class="va">self</span>, epoch):</span>
<span id="cb15-9"><a href="#cb15-9" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> epoch <span class="op">&lt;</span> <span class="va">self</span>.warmup_epochs:</span>
<span id="cb15-10"><a href="#cb15-10" aria-hidden="true" tabindex="-1"></a>            alpha <span class="op">=</span> <span class="va">self</span>.initial_alpha <span class="op">+</span> (<span class="va">self</span>.final_alpha <span class="op">-</span> <span class="va">self</span>.initial_alpha) <span class="op">*</span> (epoch <span class="op">/</span> <span class="va">self</span>.warmup_epochs)</span>
<span id="cb15-11"><a href="#cb15-11" aria-hidden="true" tabindex="-1"></a>        <span class="cf">else</span>:</span>
<span id="cb15-12"><a href="#cb15-12" aria-hidden="true" tabindex="-1"></a>            alpha <span class="op">=</span> <span class="va">self</span>.final_alpha</span>
<span id="cb15-13"><a href="#cb15-13" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> alpha</span></code></pre></div></div>
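<p><code>ProgressiveDistillation</code> only computes the weight; it still has to be applied to the loss once per epoch. Because <code>FeatureDistillationLoss</code> stores <code>alpha</code> as an attribute, one option is to overwrite that attribute at the start of each epoch. The sketch below reproduces the same linear ramp as a plain function and uses a hypothetical <code>WeightedLoss</code> stand-in so it runs on its own:</p>

```python
def progressive_alpha(epoch, initial_alpha=0.1, final_alpha=0.9, warmup_epochs=20):
    """Same linear ramp as ProgressiveDistillation.get_alpha (epoch is 0-indexed)."""
    if epoch < warmup_epochs:
        return initial_alpha + (final_alpha - initial_alpha) * (epoch / warmup_epochs)
    return final_alpha

class WeightedLoss:
    """Hypothetical stand-in for a loss module that stores alpha as an attribute,
    like FeatureDistillationLoss above."""
    def __init__(self, alpha):
        self.alpha = alpha

criterion = WeightedLoss(alpha=0.1)
history = []
for epoch in range(30):
    criterion.alpha = progressive_alpha(epoch)  # update once per epoch, before the batch loop
    history.append(round(criterion.alpha, 3))
    # ... run this epoch's training batches with criterion ...

print(history[:21:5])  # [0.1, 0.3, 0.5, 0.7, 0.9]
```

<p>With the defaults, alpha rises by 0.04 per epoch and stays at 0.9 from epoch 20 onward. If the schedule drives <code>FeatureDistillationLoss</code>, keep <code>final_alpha + beta</code> below 1: with <code>final_alpha=0.9</code> and the default <code>beta=0.3</code>, the cross-entropy weight <code>1 - alpha - beta</code> would become negative.</p>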
</section>
</section>
<section id="hyperparameter-guidelines" class="level2">
<h2 class="anchored" data-anchor-id="hyperparameter-guidelines" id="hyperparameter-guidelines">Hyperparameter Guidelines</h2>
<section id="temperature-t" class="level3">
<h3 class="anchored" data-anchor-id="temperature-t" id="temperature-t">Temperature (T)</h3>
<ul>
<li><strong>Low (1-2)</strong>: Sharp, close to one-hot targets; little information about non-target classes is transferred</li>
<li><strong>Medium (3-5)</strong>: Balanced knowledge transfer (recommended starting range)</li>
<li><strong>High (6-10)</strong>: Very soft, close to uniform targets; the signal identifying the correct class can be diluted</li>
</ul>
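<p>To make the temperature ranges above concrete, here is a small pure-Python sketch (no PyTorch needed) of a temperature-scaled softmax, p_i = exp(z_i / T) / Σ_j exp(z_j / T). The teacher logits are made-up values chosen only for illustration.</p>

```python
import math


def softmax_with_temperature(logits, temperature=1.0):
    """Return softmax(logits / temperature) as a list of probabilities."""
    scaled = [z / temperature for z in logits]
    # Subtract the max before exponentiating for numerical stability
    peak = max(scaled)
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical teacher logits for a 4-class problem
teacher_logits = [8.0, 4.0, 2.0, 1.0]

for T in (1, 4, 10):
    probs = softmax_with_temperature(teacher_logits, T)
    print(f"T={T}: " + ", ".join(f"{p:.3f}" for p in probs))
```

<p>With these logits the top class holds about 0.98 of the probability mass at T=1, about 0.57 at T=4 (the other classes get roughly 0.21, 0.13 and 0.10), and only about 0.37 at T=10, where the distribution is approaching uniform. Moderate temperatures expose how the teacher ranks the wrong classes without flattening that ranking away.</p>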
</section>
<section id="alpha-α" class="level3">
<h3 class="anchored" data-anchor-id="alpha-α" id="alpha-α">Alpha (α)</h3>
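<p>Here α is the weight on the distillation (soft-target) loss; the ground-truth loss gets the remaining weight (1 − α in the usual formulation).</p>
<ul>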
<li><strong>0.1-0.3</strong>: Focus on ground truth labels</li>
<li><strong>0.5-0.7</strong>: Balanced approach (recommended)</li>
<li><strong>0.8-0.9</strong>: Heavy focus on teacher knowledge</li>
</ul>
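<p>The pure-Python sketch below shows how α trades the two terms off, assuming the common Hinton-style formulation: α weights the temperature-softened KL divergence between teacher and student (scaled by T² so its gradients stay comparable in size to the hard-label term), and 1 − α weights ordinary cross-entropy on the ground-truth label. The function name and argument layout are illustrative, not taken from the implementation earlier in this guide.</p>

```python
import math


def softmax(logits, temperature=1.0):
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_loss(student_logits, teacher_logits, label, temperature=4.0, alpha=0.5):
    """Hypothetical single-example KD loss: alpha * T^2 * KL(teacher || student) + (1 - alpha) * CE."""
    # Soft targets: both distributions are softened with the same temperature
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    kl = sum(pt * math.log(pt / ps) for pt, ps in zip(p_teacher, p_student) if pt > 0)

    # Hard targets: standard cross-entropy against the true label at T = 1
    ce = -math.log(softmax(student_logits)[label])

    return alpha * temperature ** 2 * kl + (1 - alpha) * ce


student = [2.0, 1.0, 0.5]
teacher = [3.0, 1.5, 0.2]
for a in (0.2, 0.5, 0.8):
    print(f"alpha={a}: loss={distillation_loss(student, teacher, label=0, alpha=a):.4f}")
```

<p>In PyTorch the soft-target term is usually written as <code>F.kl_div(F.log_softmax(student / T, dim=1), F.softmax(teacher / T, dim=1), reduction="batchmean") * T * T</code>, which computes the same KL(teacher ‖ student) averaged over the batch.</p>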
</section>
<section id="learning-rate" class="level3">
<h3 class="anchored" data-anchor-id="learning-rate" id="learning-rate">Learning Rate</h3>
<ul>
<li>Start with the same learning rate used to train the baseline student</li>
<li>Consider a lower learning rate for the student to avoid overfitting to the teacher</li>
<li>Use learning rate scheduling</li>
</ul>
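<p>As one concrete option for the last point, the sketch below computes a cosine-annealing schedule in plain Python using the closed form lr_t = lr_min + ½(lr_max − lr_min)(1 + cos(π·t / T_max)), the same curve described for PyTorch's <code>torch.optim.lr_scheduler.CosineAnnealingLR</code>. The base rate of 0.01 and the 30-epoch horizon are placeholder values, not recommendations from this guide.</p>

```python
import math


def cosine_lr(epoch, total_epochs, lr_max, lr_min=0.0):
    """Cosine-annealed learning rate for a given epoch (0-indexed)."""
    # Clamp so epochs past the horizon stay at lr_min
    progress = min(epoch, total_epochs) / total_epochs
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * progress))


# Placeholder values: student trained with base LR 0.01 for 30 epochs
for epoch in (0, 10, 15, 20, 30):
    print(f"epoch {epoch:2d}: lr = {cosine_lr(epoch, 30, lr_max=0.01):.5f}")
```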
</section>
</section>
<section id="best-practices" class="level2">
<h2 class="anchored" data-anchor-id="best-practices" id="best-practices">Best Practices</h2>
<ol type="1">
<li><strong>Teacher Quality</strong>: Ensure the teacher model is well-trained and robust</li>
<li><strong>Architecture Matching</strong>: The student should share the teacher's general structure but have smaller capacity</li>
<li><strong>Temperature Tuning</strong>: Experiment with different temperature values</li>
<li><strong>Regularization</strong>: Use dropout and weight decay to prevent overfitting</li>
<li><strong>Evaluation</strong>: Compare against a baseline student of the same architecture trained without distillation</li>
<li><strong>Multi-Teacher</strong>: Consider an ensemble of teachers for better knowledge transfer</li>
</ol>
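<p>For the multi-teacher point, a simple way to combine several teachers is to average their temperature-softened probability distributions and use the average as the soft target. The sketch below assumes equal teacher weights (weighting teachers by validation accuracy is a common variant, not shown), and the logit values are made up.</p>

```python
import math


def softmax(logits, temperature=1.0):
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def ensemble_soft_targets(teacher_logits_list, temperature=4.0):
    """Average the softened distributions of several teachers (equal weights)."""
    dists = [softmax(logits, temperature) for logits in teacher_logits_list]
    n = len(dists)
    return [sum(d[i] for d in dists) / n for i in range(len(dists[0]))]


# Three hypothetical teachers that disagree about the runner-up class
teachers = [
    [5.0, 3.0, 1.0],
    [5.0, 1.0, 3.0],
    [4.0, 2.0, 2.0],
]
print([round(p, 3) for p in ensemble_soft_targets(teachers)])
```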
</section>
<section id="common-issues-and-solutions" class="level2">
<h2 class="anchored" data-anchor-id="common-issues-and-solutions" id="common-issues-and-solutions">Common Issues and Solutions</h2>
<section id="problem-student-performs-worse-than-baseline" class="level3">
<h3 class="anchored" data-anchor-id="problem-student-performs-worse-than-baseline" id="problem-student-performs-worse-than-baseline">Problem: Student performs worse than baseline</h3>
<p><strong>Solutions:</strong></p>
<ul>
<li>Reduce the temperature</li>
<li>Decrease alpha to give more weight to the ground-truth labels</li>
<li>Check teacher model quality</li>
<li>Ensure proper normalization</li>
</ul>
</section>
<section id="problem-slow-convergence" class="level3">
<h3 class="anchored" data-anchor-id="problem-slow-convergence" id="problem-slow-convergence">Problem: Slow convergence</h3>
<p><strong>Solutions:</strong></p>
<ul>
<li>Increase learning rate</li>
<li>Use progressive distillation</li>
<li>Warm up the distillation loss</li>
<li>Check gradient flow</li>
</ul>
</section>
<section id="problem-overfitting-to-teacher" class="level3">
<h3 class="anchored" data-anchor-id="problem-overfitting-to-teacher" id="problem-overfitting-to-teacher">Problem: Overfitting to teacher</h3>
<p><strong>Solutions:</strong></p>
<ul>
<li>Add regularization</li>
<li>Reduce alpha</li>
<li>Use data augmentation</li>
<li>Use early stopping based on validation loss</li>
</ul>
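<p>The early-stopping advice can be implemented with a small helper like the one below, which tracks the best validation loss and signals a stop after <code>patience</code> epochs without an improvement larger than <code>min_delta</code>. The class name and defaults are illustrative, and the validation losses in the example are invented.</p>

```python
class EarlyStopping:
    """Signal a stop when validation loss has not improved for `patience` epochs."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True if training should stop."""
        if self.best_loss - val_loss > self.min_delta:
            # Meaningful improvement: remember it and reset the counter
            self.best_loss = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Invented validation losses: the best value arrives at epoch 2
stopper = EarlyStopping(patience=2)
for epoch, loss in enumerate([0.90, 0.70, 0.65, 0.66, 0.67, 0.60]):
    if stopper.step(loss):
        print(f"Stopping at epoch {epoch}; best val loss {stopper.best_loss:.2f}")
        break
```

<p>In a real training loop you would typically also save the student's weights whenever <code>best_loss</code> improves, so the best checkpoint can be restored after stopping.</p>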
<p>This guide has covered the theory behind student-teacher networks, their implementation in PyTorch, more advanced distillation techniques, and troubleshooting tips for making knowledge distillation work in practice.</p>


</section>
</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[LitServe Code Guide]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/deployment/litserve-basics/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/deployment/litserve-basics/</guid>
      <pubDate>Tue, 27 May 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>mlops</category>
      <category>beginner</category>
      <content:encoded><![CDATA[






<section id="litserve-code-guide" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/deployment/litserve-basics/litserve.jpg" class="img-fluid"></p>
<p>LitServe is a high-performance, flexible serving framework for deploying AI and machine learning models with minimal code. It provides automatic request batching, GPU acceleration, and scaling across multiple devices.</p>
<section id="installation" class="level2">
<h2 class="anchored" data-anchor-id="installation" id="installation">Installation</h2>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Install LitServe</span></span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install litserve</span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a><span class="co"># For GPU support</span></span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install litserve<span class="pp">[</span><span class="ss">gpu</span><span class="pp">]</span></span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a><span class="co"># For development dependencies</span></span>
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install litserve<span class="pp">[</span><span class="ss">dev</span><span class="pp">]</span></span></code></pre></div></div>
</section>
<section id="basic-usage" class="level2">
<h2 class="anchored" data-anchor-id="basic-usage" id="basic-usage">Basic Usage</h2>
<section id="creating-your-first-litserve-api" class="level3">
<h3 class="anchored" data-anchor-id="creating-your-first-litserve-api" id="creating-your-first-litserve-api">1. Creating Your First LitServe API</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> litserve <span class="im">as</span> ls</span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> transformers <span class="im">import</span> AutoTokenizer, AutoModelForCausalLM</span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> SimpleTextGenerator(ls.LitAPI):</span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> setup(<span class="va">self</span>, device):</span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Load model and tokenizer during server startup</span></span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.tokenizer <span class="op">=</span> AutoTokenizer.from_pretrained(<span class="st">"gpt2"</span>)</span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> AutoModelForCausalLM.from_pretrained(<span class="st">"gpt2"</span>)</span>
<span id="cb2-10"><a href="#cb2-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model.to(device)</span>
<span id="cb2-11"><a href="#cb2-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model.<span class="bu">eval</span>()</span>
<span id="cb2-12"><a href="#cb2-12" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-13"><a href="#cb2-13" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> decode_request(<span class="va">self</span>, request):</span>
<span id="cb2-14"><a href="#cb2-14" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Process incoming request</span></span>
<span id="cb2-15"><a href="#cb2-15" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> request[<span class="st">"prompt"</span>]</span>
<span id="cb2-16"><a href="#cb2-16" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-17"><a href="#cb2-17" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> predict(<span class="va">self</span>, x):</span>
<span id="cb2-18"><a href="#cb2-18" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Run inference</span></span>
<span id="cb2-19"><a href="#cb2-19" aria-hidden="true" tabindex="-1"></a>        inputs <span class="op">=</span> <span class="va">self</span>.tokenizer.encode(x, return_tensors<span class="op">=</span><span class="st">"pt"</span>).to(<span class="va">self</span>.model.device)</span>
<span id="cb2-20"><a href="#cb2-20" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> torch.no_grad():</span>
<span id="cb2-21"><a href="#cb2-21" aria-hidden="true" tabindex="-1"></a>            outputs <span class="op">=</span> <span class="va">self</span>.model.generate(</span>
<span id="cb2-22"><a href="#cb2-22" aria-hidden="true" tabindex="-1"></a>                inputs, </span>
<span id="cb2-23"><a href="#cb2-23" aria-hidden="true" tabindex="-1"></a>                max_length<span class="op">=</span><span class="dv">100</span>, </span>
<span id="cb2-24"><a href="#cb2-24" aria-hidden="true" tabindex="-1"></a>                num_return_sequences<span class="op">=</span><span class="dv">1</span>,</span>
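<span id="cb2-25"><a href="#cb2-25" aria-hidden="true" tabindex="-1"></a>                do_sample<span class="op">=</span><span class="va">True</span>, temperature<span class="op">=</span><span class="fl">0.7</span>  <span class="co"># temperature only applies when sampling</span></span>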
<span id="cb2-26"><a href="#cb2-26" aria-hidden="true" tabindex="-1"></a>            )</span>
<span id="cb2-27"><a href="#cb2-27" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.tokenizer.decode(outputs[<span class="dv">0</span>], skip_special_tokens<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb2-28"><a href="#cb2-28" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-29"><a href="#cb2-29" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> encode_response(<span class="va">self</span>, output):</span>
<span id="cb2-30"><a href="#cb2-30" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Format response</span></span>
<span id="cb2-31"><a href="#cb2-31" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {<span class="st">"generated_text"</span>: output}</span>
<span id="cb2-32"><a href="#cb2-32" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-33"><a href="#cb2-33" aria-hidden="true" tabindex="-1"></a><span class="co"># Create and start server</span></span>
<span id="cb2-34"><a href="#cb2-34" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb2-35"><a href="#cb2-35" aria-hidden="true" tabindex="-1"></a>    api <span class="op">=</span> SimpleTextGenerator()</span>
<span id="cb2-36"><a href="#cb2-36" aria-hidden="true" tabindex="-1"></a>    server <span class="op">=</span> ls.LitServer(api, accelerator<span class="op">=</span><span class="st">"auto"</span>)  <span class="co"># predict() handles one prompt at a time, so batching stays off</span></span>
<span id="cb2-37"><a href="#cb2-37" aria-hidden="true" tabindex="-1"></a>    server.run(port<span class="op">=</span><span class="dv">8000</span>)</span></code></pre></div></div>
</section>
<section id="making-requests-to-your-api" class="level3">
<h3 class="anchored" data-anchor-id="making-requests-to-your-api" id="making-requests-to-your-api">2. Making Requests to Your API</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> requests</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Test the API</span></span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a>response <span class="op">=</span> requests.post(</span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a>    <span class="st">"http://localhost:8000/predict"</span>,</span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a>    json<span class="op">=</span>{<span class="st">"prompt"</span>: <span class="st">"The future of AI is"</span>}</span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(response.json())</span></code></pre></div></div>
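<p>If you would rather not depend on <code>requests</code>, the same call can be made with Python's standard library. This is a minimal sketch that assumes the server from the previous step is listening on <code>localhost:8000</code> and serves LitServe's default <code>/predict</code> route used above; the helper names are hypothetical.</p>

```python
import json
import urllib.request


def build_predict_request(prompt, url="http://localhost:8000/predict"):
    """Build a POST request whose JSON body matches what decode_request expects."""
    body = json.dumps({"prompt": prompt}).encode("utf-8")
    return urllib.request.Request(
        url,
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def predict(prompt, timeout=60):
    """Send the request to the running server and return the decoded JSON response."""
    with urllib.request.urlopen(build_predict_request(prompt), timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))
```

<p>With the server running, <code>predict("The future of AI is")</code> returns the same <code>{"generated_text": ...}</code> dictionary as the <code>requests</code> version.</p>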
</section>
</section>
<section id="core-concepts" class="level2">
<h2 class="anchored" data-anchor-id="core-concepts" id="core-concepts">Core Concepts</h2>
<section id="litapi-class-structure" class="level3">
<h3 class="anchored" data-anchor-id="litapi-class-structure" id="litapi-class-structure">LitAPI Class Structure</h3>
<p>Every LitServe API must inherit from <code>ls.LitAPI</code> and implement these core methods:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MyAPI(ls.LitAPI):</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> setup(<span class="va">self</span>, device):</span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Initialize models, load weights, set up preprocessing"""</span></span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a>        <span class="cf">pass</span></span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> decode_request(<span class="va">self</span>, request):</span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Parse and validate incoming requests"""</span></span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a>        <span class="cf">pass</span></span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> predict(<span class="va">self</span>, x):</span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Run model inference"""</span></span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a>        <span class="cf">pass</span></span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-14"><a href="#cb4-14" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> encode_response(<span class="va">self</span>, output):</span>
<span id="cb4-15"><a href="#cb4-15" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Format model output for HTTP response"""</span></span>
<span id="cb4-16"><a href="#cb4-16" aria-hidden="true" tabindex="-1"></a>        <span class="cf">pass</span></span></code></pre></div></div>
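<p>A useful mental model is that <code>setup</code> runs once at startup, while the other three methods run in a fixed order for every request. The sketch below is a deliberately simplified stand-in, not LitServe's actual request loop (which also handles batching, worker processes, and HTTP); the <code>EchoAPI</code> class and <code>handle_request</code> helper are hypothetical. Because the methods are plain Python, the same pattern lets you unit-test an API's logic without starting a server.</p>

```python
class EchoAPI:
    """Hypothetical stand-in with the same four methods as an ls.LitAPI subclass."""

    def setup(self, device):
        # Runs once at startup; a real API would load model weights here
        self.device = device
        self.model = lambda text: text.upper()

    def decode_request(self, request):
        return request["prompt"]

    def predict(self, x):
        return self.model(x)

    def encode_response(self, output):
        return {"generated_text": output}


def handle_request(api, request):
    """Simplified per-request call order (no batching)."""
    x = api.decode_request(request)
    output = api.predict(x)
    return api.encode_response(output)


api = EchoAPI()
api.setup("cpu")
print(handle_request(api, {"prompt": "hello"}))  # {'generated_text': 'HELLO'}
```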
</section>
<section id="optional-methods" class="level3">
<h3 class="anchored" data-anchor-id="optional-methods" id="optional-methods">Optional Methods</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> AdvancedAPI(ls.LitAPI):</span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> batch(<span class="va">self</span>, inputs):</span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Custom batching logic (optional)"""</span></span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> inputs</span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> unbatch(<span class="va">self</span>, output):</span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Custom unbatching logic (optional)"""</span></span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> output</span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> preprocess(<span class="va">self</span>, input_data):</span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Additional preprocessing (optional)"""</span></span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> input_data</span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> postprocess(<span class="va">self</span>, output):</span>
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Additional postprocessing (optional)"""</span></span>
<span id="cb5-16"><a href="#cb5-16" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> output</span></code></pre></div></div>
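<p>As an illustration of when overriding <code>batch</code> and <code>unbatch</code> pays off, the plain-Python sketch below pads variable-length token-ID lists to a common length so they can be processed together, then strips the padding per request. The functions are written at module level for brevity (on a real API they would be methods), the pad ID of 0 is an assumption, and the dummy <code>predict</code> stands in for a model that passes the lengths through; with a real model you would also convert the padded batch to a tensor and build an attention mask.</p>

```python
PAD_ID = 0  # hypothetical padding token ID


def batch(inputs):
    """Pad a list of token-ID lists to the same length; keep lengths for unbatching."""
    max_len = max(len(ids) for ids in inputs)
    padded = [ids + [PAD_ID] * (max_len - len(ids)) for ids in inputs]
    lengths = [len(ids) for ids in inputs]
    return padded, lengths


def predict(batch_and_lengths):
    """Dummy model: doubles every token ID and passes the lengths through."""
    padded, lengths = batch_and_lengths
    return [[2 * t for t in row] for row in padded], lengths


def unbatch(output):
    """Split the padded batch back into one unpadded result per request."""
    padded, lengths = output
    return [row[:n] for row, n in zip(padded, lengths)]


requests_in_flight = [[5, 8, 2], [7], [3, 3]]
print(unbatch(predict(batch(requests_in_flight))))
```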
</section>
</section>
<section id="advanced-features" class="level2">
<h2 class="anchored" data-anchor-id="advanced-features" id="advanced-features">Advanced Features</h2>
<section id="custom-batching" class="level3">
<h3 class="anchored" data-anchor-id="custom-batching" id="custom-batching">1. Custom Batching</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> BatchedImageClassifier(ls.LitAPI):</span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> setup(<span class="va">self</span>, device):</span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a>        <span class="im">from</span> torchvision <span class="im">import</span> models, transforms</span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> models.resnet50(weights<span class="op">=</span>models.ResNet50_Weights.DEFAULT)</span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model.to(device)</span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model.<span class="bu">eval</span>()</span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.transform <span class="op">=</span> transforms.Compose([</span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a>            transforms.Resize(<span class="dv">256</span>),</span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a>            transforms.CenterCrop(<span class="dv">224</span>),</span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a>            transforms.ToTensor(),</span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a>            transforms.Normalize(mean<span class="op">=</span>[<span class="fl">0.485</span>, <span class="fl">0.456</span>, <span class="fl">0.406</span>], </span>
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a>                               std<span class="op">=</span>[<span class="fl">0.229</span>, <span class="fl">0.224</span>, <span class="fl">0.225</span>])</span>
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a>        ])</span>
<span id="cb6-15"><a href="#cb6-15" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-16"><a href="#cb6-16" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> decode_request(<span class="va">self</span>, request):</span>
<span id="cb6-17"><a href="#cb6-17" aria-hidden="true" tabindex="-1"></a>        <span class="im">from</span> PIL <span class="im">import</span> Image</span>
<span id="cb6-18"><a href="#cb6-18" aria-hidden="true" tabindex="-1"></a>        <span class="im">import</span> base64</span>
<span id="cb6-19"><a href="#cb6-19" aria-hidden="true" tabindex="-1"></a>        <span class="im">import</span> io</span>
<span id="cb6-20"><a href="#cb6-20" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-21"><a href="#cb6-21" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Decode base64 image</span></span>
<span id="cb6-22"><a href="#cb6-22" aria-hidden="true" tabindex="-1"></a>        image_data <span class="op">=</span> base64.b64decode(request[<span class="st">"image"</span>])</span>
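<span id="cb6-23"><a href="#cb6-23" aria-hidden="true" tabindex="-1"></a>        image <span class="op">=</span> Image.<span class="bu">open</span>(io.BytesIO(image_data)).convert(<span class="st">"RGB"</span>)  <span class="co"># ensure 3 channels (e.g. grayscale/RGBA inputs)</span></span>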
<span id="cb6-24"><a href="#cb6-24" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.transform(image).unsqueeze(<span class="dv">0</span>)</span>
<span id="cb6-25"><a href="#cb6-25" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-26"><a href="#cb6-26" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> batch(<span class="va">self</span>, inputs):</span>
<span id="cb6-27"><a href="#cb6-27" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Custom batching for images</span></span>
<span id="cb6-28"><a href="#cb6-28" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> torch.cat(inputs, dim<span class="op">=</span><span class="dv">0</span>)</span>
<span id="cb6-29"><a href="#cb6-29" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-30"><a href="#cb6-30" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> predict(<span class="va">self</span>, batch):</span>
<span id="cb6-31"><a href="#cb6-31" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> torch.no_grad():</span>
<span id="cb6-32"><a href="#cb6-32" aria-hidden="true" tabindex="-1"></a>            outputs <span class="op">=</span> <span class="va">self</span>.model(batch.to(<span class="bu">next</span>(<span class="va">self</span>.model.parameters()).device))</span>
<span id="cb6-33"><a href="#cb6-33" aria-hidden="true" tabindex="-1"></a>            probabilities <span class="op">=</span> torch.nn.functional.softmax(outputs, dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb6-34"><a href="#cb6-34" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> probabilities</span>
<span id="cb6-35"><a href="#cb6-35" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-36"><a href="#cb6-36" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> unbatch(<span class="va">self</span>, output):</span>
<span id="cb6-37"><a href="#cb6-37" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Split batch back to individual predictions</span></span>
<span id="cb6-38"><a href="#cb6-38" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> [pred.unsqueeze(<span class="dv">0</span>) <span class="cf">for</span> pred <span class="kw">in</span> output]</span>
<span id="cb6-39"><a href="#cb6-39" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-40"><a href="#cb6-40" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> encode_response(<span class="va">self</span>, output):</span>
<span id="cb6-41"><a href="#cb6-41" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Get top prediction</span></span>
<span id="cb6-42"><a href="#cb6-42" aria-hidden="true" tabindex="-1"></a>        confidence, predicted <span class="op">=</span> torch.<span class="bu">max</span>(output, <span class="dv">1</span>)</span>
<span id="cb6-43"><a href="#cb6-43" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {</span>
<span id="cb6-44"><a href="#cb6-44" aria-hidden="true" tabindex="-1"></a>            <span class="st">"class_id"</span>: predicted.item(),</span>
<span id="cb6-45"><a href="#cb6-45" aria-hidden="true" tabindex="-1"></a>            <span class="st">"confidence"</span>: confidence.item()</span>
<span id="cb6-46"><a href="#cb6-46" aria-hidden="true" tabindex="-1"></a>        }</span></code></pre></div></div>
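<p>On the client side, <code>decode_request</code> above expects the image as a base64 string under the <code>"image"</code> key. A minimal helper for building that payload might look like the following; the helper name and file path are placeholders.</p>

```python
import base64
import json


def image_payload(image_bytes):
    """Encode raw image bytes as the base64 string body that decode_request reads."""
    return {"image": base64.b64encode(image_bytes).decode("ascii")}


# With a real file:
#     with open("cat.jpg", "rb") as f:
#         payload = image_payload(f.read())
fake_bytes = b"\x89PNG\r\n\x1a\n"  # stand-in bytes, not a complete image
payload = image_payload(fake_bytes)
print(json.dumps(payload))
```

<p>The payload can then be sent with <code>requests.post("http://localhost:8000/predict", json=payload)</code>. Custom batching only takes effect when the server is created with <code>max_batch_size</code> greater than 1 and several requests arrive close together.</p>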
</section>
<section id="streaming-responses" class="level3">
<h3 class="anchored" data-anchor-id="streaming-responses" id="streaming-responses">2. Streaming Responses</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> StreamingChatAPI(ls.LitAPI):</span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> setup(<span class="va">self</span>, device):</span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a>        <span class="im">from</span> transformers <span class="im">import</span> AutoTokenizer, AutoModelForCausalLM</span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.tokenizer <span class="op">=</span> AutoTokenizer.from_pretrained(<span class="st">"microsoft/DialoGPT-medium"</span>)</span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> AutoModelForCausalLM.from_pretrained(<span class="st">"microsoft/DialoGPT-medium"</span>)</span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model.to(device)</span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> decode_request(<span class="va">self</span>, request):</span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> request[<span class="st">"message"</span>]</span>
<span id="cb7-10"><a href="#cb7-10" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-11"><a href="#cb7-11" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> predict(<span class="va">self</span>, x):</span>
<span id="cb7-12"><a href="#cb7-12" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Generator for streaming</span></span>
<span id="cb7-13"><a href="#cb7-13" aria-hidden="true" tabindex="-1"></a>        inputs <span class="op">=</span> <span class="va">self</span>.tokenizer.encode(x, return_tensors<span class="op">=</span><span class="st">"pt"</span>).to(<span class="va">self</span>.model.device)</span>
<span id="cb7-14"><a href="#cb7-14" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-15"><a href="#cb7-15" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">50</span>):  <span class="co"># Generate up to 50 tokens</span></span>
<span id="cb7-16"><a href="#cb7-16" aria-hidden="true" tabindex="-1"></a>            <span class="cf">with</span> torch.no_grad():</span>
<span id="cb7-17"><a href="#cb7-17" aria-hidden="true" tabindex="-1"></a>                outputs <span class="op">=</span> <span class="va">self</span>.model(inputs)</span>
<span id="cb7-18"><a href="#cb7-18" aria-hidden="true" tabindex="-1"></a>                next_token_logits <span class="op">=</span> outputs.logits[:, <span class="op">-</span><span class="dv">1</span>, :]</span>
<span id="cb7-19"><a href="#cb7-19" aria-hidden="true" tabindex="-1"></a>                next_token <span class="op">=</span> torch.multinomial(</span>
<span id="cb7-20"><a href="#cb7-20" aria-hidden="true" tabindex="-1"></a>                    torch.softmax(next_token_logits, dim<span class="op">=-</span><span class="dv">1</span>), </span>
<span id="cb7-21"><a href="#cb7-21" aria-hidden="true" tabindex="-1"></a>                    num_samples<span class="op">=</span><span class="dv">1</span></span>
<span id="cb7-22"><a href="#cb7-22" aria-hidden="true" tabindex="-1"></a>                )</span>
<span id="cb7-23"><a href="#cb7-23" aria-hidden="true" tabindex="-1"></a>                inputs <span class="op">=</span> torch.cat([inputs, next_token], dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb7-24"><a href="#cb7-24" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb7-25"><a href="#cb7-25" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Yield each token</span></span>
<span id="cb7-26"><a href="#cb7-26" aria-hidden="true" tabindex="-1"></a>                token_text <span class="op">=</span> <span class="va">self</span>.tokenizer.decode(next_token[<span class="dv">0</span>])</span>
<span id="cb7-27"><a href="#cb7-27" aria-hidden="true" tabindex="-1"></a>                <span class="cf">yield</span> {<span class="st">"token"</span>: token_text}</span>
<span id="cb7-28"><a href="#cb7-28" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb7-29"><a href="#cb7-29" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Stop if end token</span></span>
<span id="cb7-30"><a href="#cb7-30" aria-hidden="true" tabindex="-1"></a>                <span class="cf">if</span> next_token.item() <span class="op">==</span> <span class="va">self</span>.tokenizer.eos_token_id:</span>
<span id="cb7-31"><a href="#cb7-31" aria-hidden="true" tabindex="-1"></a>                    <span class="cf">break</span></span>
<span id="cb7-32"><a href="#cb7-32" aria-hidden="true" tabindex="-1"></a>    </span>
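<span id="cb7-33"><a href="#cb7-33" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> encode_response(<span class="va">self</span>, output_stream):</span>
<span id="cb7-34"><a href="#cb7-34" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Receives the generator from predict(); yield each chunk as it arrives</span></span>
<span id="cb7-35"><a href="#cb7-35" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> output <span class="kw">in</span> output_stream:</span>
<span id="cb7-36"><a href="#cb7-36" aria-hidden="true" tabindex="-1"></a>            <span class="cf">yield</span> output</span>
<span id="cb7-37"><a href="#cb7-37" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-38"><a href="#cb7-38" aria-hidden="true" tabindex="-1"></a><span class="co"># Enable streaming</span></span>
<span id="cb7-39"><a href="#cb7-39" aria-hidden="true" tabindex="-1"></a>api <span class="op">=</span> StreamingChatAPI()</span>
<span id="cb7-40"><a href="#cb7-40" aria-hidden="true" tabindex="-1"></a>server <span class="op">=</span> ls.LitServer(api, accelerator<span class="op">=</span><span class="st">"auto"</span>, stream<span class="op">=</span><span class="va">True</span>)</span></code></pre></div></div>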
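<p>The streaming contract can be checked without a model or a server: <code>predict</code> yields chunks, and <code>encode_response</code> receives that generator and yields one encoded chunk per item. The sketch below replaces the DialoGPT sampling loop with a hypothetical token list, and <code>simulate_stream</code> is a simplified stand-in for LitServe's own streaming loop.</p>

```python
def fake_predict(tokens, eos_token="[EOS]"):
    """Stand-in for predict(): yield one token dict at a time, stopping after EOS."""
    for token in tokens:
        yield {"token": token}
        if token == eos_token:
            break


def encode_response(output_stream):
    """Must itself be a generator so chunks are forwarded lazily, one by one."""
    for output in output_stream:
        yield output


def simulate_stream(tokens):
    """Hypothetical driver mimicking a server forwarding chunks to the client."""
    return [chunk for chunk in encode_response(fake_predict(tokens))]


print(simulate_stream(["Hello", ",", " world", "[EOS]", "ignored"]))
```

<p>As in the DialoGPT loop above, the end-of-sequence token is yielded before the loop breaks, so it is the last chunk the client sees. On the client, read the response incrementally (for example with <code>requests.post(..., stream=True)</code> and <code>iter_content</code>) rather than waiting for the full body.</p>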
</section>
<section id="multiple-gpu-support" class="level3">
<h3 class="anchored" data-anchor-id="multiple-gpu-support" id="multiple-gpu-support">3. Multiple GPU Support</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb8"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Automatic multi-GPU scaling</span></span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a>server <span class="op">=</span> ls.LitServer(</span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a>    api, </span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a>    accelerator<span class="op">=</span><span class="st">"auto"</span>,</span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a>    devices<span class="op">=</span><span class="st">"auto"</span>,  <span class="co"># Use all available GPUs</span></span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a>    max_batch_size<span class="op">=</span><span class="dv">8</span></span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a><span class="co"># Pin the server to specific GPUs</span></span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a>server <span class="op">=</span> ls.LitServer(</span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a>    api,</span>
<span id="cb8-12"><a href="#cb8-12" aria-hidden="true" tabindex="-1"></a>    accelerator<span class="op">=</span><span class="st">"gpu"</span>,</span>
<span id="cb8-13"><a href="#cb8-13" aria-hidden="true" tabindex="-1"></a>    devices<span class="op">=</span>[<span class="dv">0</span>, <span class="dv">1</span>, <span class="dv">2</span>],  <span class="co"># Use GPUs 0, 1, and 2</span></span>
<span id="cb8-14"><a href="#cb8-14" aria-hidden="true" tabindex="-1"></a>    max_batch_size<span class="op">=</span><span class="dv">16</span></span>
<span id="cb8-15"><a href="#cb8-15" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
</section>
<section id="authentication-and-middleware" class="level3">
<h3 class="anchored" data-anchor-id="authentication-and-middleware" id="authentication-and-middleware">4. Authentication and Middleware</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb9"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> litserve <span class="im">as</span> ls</span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> fastapi <span class="im">import</span> Depends, HTTPException</span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> fastapi.security <span class="im">import</span> HTTPAuthorizationCredentials, HTTPBearer</span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Your valid API keys (load these from a secret store in practice)</span></span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a>VALID_KEYS <span class="op">=</span> {<span class="st">"your-secret-key-1"</span>, <span class="st">"your-secret-key-2"</span>}</span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> AuthenticatedAPI(ls.LitAPI):</span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> setup(<span class="va">self</span>, device):</span>
<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Your model setup</span></span>
<span id="cb9-11"><a href="#cb9-11" aria-hidden="true" tabindex="-1"></a>        <span class="cf">pass</span></span>
<span id="cb9-12"><a href="#cb9-12" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-13"><a href="#cb9-13" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> authorize(<span class="va">self</span>, auth: HTTPAuthorizationCredentials <span class="op">=</span> Depends(HTTPBearer())):</span>
<span id="cb9-14"><a href="#cb9-14" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Custom authentication logic; expects an 'Authorization: Bearer &lt;key&gt;' header"""</span></span>
<span id="cb9-15"><a href="#cb9-15" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> auth.scheme <span class="op">!=</span> <span class="st">"Bearer"</span> <span class="kw">or</span> auth.credentials <span class="kw">not</span> <span class="kw">in</span> VALID_KEYS:</span>
<span id="cb9-16"><a href="#cb9-16" aria-hidden="true" tabindex="-1"></a>            <span class="cf">raise</span> HTTPException(status_code<span class="op">=</span><span class="dv">401</span>, detail<span class="op">=</span><span class="st">"Invalid API key"</span>)</span>
<span id="cb9-17"><a href="#cb9-17" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-18"><a href="#cb9-18" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> decode_request(<span class="va">self</span>, request):</span>
<span id="cb9-19"><a href="#cb9-19" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> request[<span class="st">"data"</span>]</span>
<span id="cb9-20"><a href="#cb9-20" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-21"><a href="#cb9-21" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> predict(<span class="va">self</span>, x):</span>
<span id="cb9-22"><a href="#cb9-22" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Your prediction logic</span></span>
<span id="cb9-23"><a href="#cb9-23" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="ss">f"Processed: </span><span class="sc">{</span>x<span class="sc">}</span><span class="ss">"</span></span>
<span id="cb9-24"><a href="#cb9-24" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-25"><a href="#cb9-25" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> encode_response(<span class="va">self</span>, output):</span>
<span id="cb9-26"><a href="#cb9-26" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {<span class="st">"result"</span>: output}</span></code></pre></div></div>
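<p>A client has to send the key in an <code>Authorization: Bearer &lt;key&gt;</code> header. The following standard-library sketch only builds the request. The URL and the <code>build_authed_request</code> helper are assumptions for illustration.</p>

```python
import json
import urllib.request


def build_authed_request(url, api_key, data):
    """Build (but do not send) a JSON POST request carrying a Bearer token."""
    body = json.dumps({"data": data}).encode("utf-8")
    return urllib.request.Request(
        url,
        data=body,
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        },
        method="POST",
    )


# Usage against a running server (assumed to listen on localhost:8000):
# req = build_authed_request("http://localhost:8000/predict", "your-secret-key-1", "hello")
# with urllib.request.urlopen(req) as resp:
#     print(json.load(resp))
```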
</section>
</section>
<section id="configuration-options" class="level2">
<h2 class="anchored" data-anchor-id="configuration-options" id="configuration-options">Configuration Options</h2>
<section id="server-configuration" class="level3">
<h3 class="anchored" data-anchor-id="server-configuration" id="server-configuration">Server Configuration</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb10"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a>server <span class="op">=</span> ls.LitServer(</span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a>    api<span class="op">=</span>api,</span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a>    accelerator<span class="op">=</span><span class="st">"auto"</span>,           <span class="co"># "auto", "cpu", "gpu", "mps"</span></span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a>    devices<span class="op">=</span><span class="st">"auto"</span>,               <span class="co"># Device selection</span></span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a>    max_batch_size<span class="op">=</span><span class="dv">4</span>,            <span class="co"># Maximum batch size</span></span>
<span id="cb10-6"><a href="#cb10-6" aria-hidden="true" tabindex="-1"></a>    batch_timeout<span class="op">=</span><span class="fl">0.1</span>,           <span class="co"># Batch timeout in seconds</span></span>
<span id="cb10-7"><a href="#cb10-7" aria-hidden="true" tabindex="-1"></a>    workers_per_device<span class="op">=</span><span class="dv">1</span>,        <span class="co"># Workers per device</span></span>
<span id="cb10-8"><a href="#cb10-8" aria-hidden="true" tabindex="-1"></a>    timeout<span class="op">=</span><span class="dv">30</span>,                  <span class="co"># Request timeout in seconds</span></span>
<span id="cb10-9"><a href="#cb10-9" aria-hidden="true" tabindex="-1"></a>    stream<span class="op">=</span><span class="va">False</span>,                <span class="co"># Enable streaming</span></span>
<span id="cb10-10"><a href="#cb10-10" aria-hidden="true" tabindex="-1"></a>    spec<span class="op">=</span><span class="va">None</span>,                   <span class="co"># Optional API spec, e.g. ls.OpenAISpec()</span></span>
<span id="cb10-11"><a href="#cb10-11" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
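<p>One way to picture how <code>max_batch_size</code> and <code>batch_timeout</code> interact is a loop that keeps pulling queued requests until the batch is full or the timeout expires. The sketch below is a simplified, hypothetical illustration of that idea using only the standard library. It is not LitServe's actual implementation.</p>

```python
import queue
import time


def collect_batch(requests, max_batch_size, batch_timeout):
    """Gather up to max_batch_size items from a queue.

    Blocks until the first request arrives, then waits at most
    batch_timeout seconds for more before returning a partial batch.
    """
    batch = [requests.get()]
    deadline = time.monotonic() + batch_timeout
    while len(batch) < max_batch_size:
        remaining = deadline - time.monotonic()
        try:
            if remaining > 0:
                batch.append(requests.get(timeout=remaining))
            else:
                # Past the deadline: only take requests that are already waiting
                batch.append(requests.get_nowait())
        except queue.Empty:
            break
    return batch


q = queue.Queue()
for i in range(5):
    q.put(i)
print(collect_batch(q, max_batch_size=4, batch_timeout=0.1))   # [0, 1, 2, 3]
print(collect_batch(q, max_batch_size=4, batch_timeout=0.05))  # [4], after ~0.05 s
```

<p>The trade-off is visible in the second call. A lone request waits up to <code>batch_timeout</code> for company, so a larger timeout yields fuller batches under light load at the cost of extra latency.</p>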
</section>
<section id="environment-variables" class="level3">
<h3 class="anchored" data-anchor-id="environment-variables" id="environment-variables">Environment Variables</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb11"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Set device preferences</span></span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a><span class="bu">export</span> <span class="va">CUDA_VISIBLE_DEVICES</span><span class="op">=</span>0,1,2</span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Batch configuration (custom variables; your server script must read them)</span></span>
<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a><span class="bu">export</span> <span class="va">LITSERVE_MAX_BATCH_SIZE</span><span class="op">=</span>8</span>
<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a><span class="bu">export</span> <span class="va">LITSERVE_BATCH_TIMEOUT</span><span class="op">=</span>0.05</span>
<span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a><span class="co"># Worker configuration (also a custom variable read by your server script)</span></span>
<span id="cb11-9"><a href="#cb11-9" aria-hidden="true" tabindex="-1"></a><span class="bu">export</span> <span class="va">LITSERVE_WORKERS_PER_DEVICE</span><span class="op">=</span>2</span></code></pre></div></div>
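<p>LitServe does not appear to read these <code>LITSERVE_*</code> names on its own, so a simple approach is to read them in the server script and pass them to <code>ls.LitServer</code>. The following is a minimal sketch. It assumes the variable names above; the fallback values are chosen here for illustration, and <code>server_kwargs_from_env</code> is a hypothetical helper.</p>

```python
import os


def server_kwargs_from_env(environ=None):
    """Translate the custom LITSERVE_* variables into ls.LitServer keyword arguments."""
    env = os.environ if environ is None else environ
    return {
        "max_batch_size": int(env.get("LITSERVE_MAX_BATCH_SIZE", "1")),
        "batch_timeout": float(env.get("LITSERVE_BATCH_TIMEOUT", "0.0")),
        "workers_per_device": int(env.get("LITSERVE_WORKERS_PER_DEVICE", "1")),
    }


# Usage in your server script:
# server = ls.LitServer(api, accelerator="auto", **server_kwargs_from_env())
```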
</section>
<section id="custom-configuration-file" class="level3">
<h3 class="anchored" data-anchor-id="custom-configuration-file" id="custom-configuration-file">Custom Configuration File</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb12"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><a href="#cb12-1" aria-hidden="true" tabindex="-1"></a><span class="co"># config.py</span></span>
<span id="cb12-2"><a href="#cb12-2" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> Config:</span>
<span id="cb12-3"><a href="#cb12-3" aria-hidden="true" tabindex="-1"></a>    MAX_BATCH_SIZE <span class="op">=</span> <span class="dv">8</span></span>
<span id="cb12-4"><a href="#cb12-4" aria-hidden="true" tabindex="-1"></a>    BATCH_TIMEOUT <span class="op">=</span> <span class="fl">0.1</span></span>
<span id="cb12-5"><a href="#cb12-5" aria-hidden="true" tabindex="-1"></a>    WORKERS_PER_DEVICE <span class="op">=</span> <span class="dv">2</span></span>
<span id="cb12-6"><a href="#cb12-6" aria-hidden="true" tabindex="-1"></a>    ACCELERATOR <span class="op">=</span> <span class="st">"auto"</span></span>
<span id="cb12-7"><a href="#cb12-7" aria-hidden="true" tabindex="-1"></a>    TIMEOUT <span class="op">=</span> <span class="dv">60</span></span>
<span id="cb12-8"><a href="#cb12-8" aria-hidden="true" tabindex="-1"></a></span>
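<span id="cb12-9"><a href="#cb12-9" aria-hidden="true" tabindex="-1"></a><span class="co"># server.py: build the server from these settings</span></span>
<span id="cb12-10"><a href="#cb12-10" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> litserve <span class="im">as</span> ls</span>
<span id="cb12-11"><a href="#cb12-11" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> config <span class="im">import</span> Config</span>
<span id="cb12-12"><a href="#cb12-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-13"><a href="#cb12-13" aria-hidden="true" tabindex="-1"></a><span class="co"># api is an instance of your own ls.LitAPI subclass</span></span>
<span id="cb12-14"><a href="#cb12-14" aria-hidden="true" tabindex="-1"></a>server <span class="op">=</span> ls.LitServer(</span>
<span id="cb12-15"><a href="#cb12-15" aria-hidden="true" tabindex="-1"></a>    api,</span>
<span id="cb12-16"><a href="#cb12-16" aria-hidden="true" tabindex="-1"></a>    accelerator<span class="op">=</span>Config.ACCELERATOR,</span>
<span id="cb12-17"><a href="#cb12-17" aria-hidden="true" tabindex="-1"></a>    max_batch_size<span class="op">=</span>Config.MAX_BATCH_SIZE,</span>
<span id="cb12-18"><a href="#cb12-18" aria-hidden="true" tabindex="-1"></a>    batch_timeout<span class="op">=</span>Config.BATCH_TIMEOUT,</span>
<span id="cb12-19"><a href="#cb12-19" aria-hidden="true" tabindex="-1"></a>    workers_per_device<span class="op">=</span>Config.WORKERS_PER_DEVICE,</span>
<span id="cb12-20"><a href="#cb12-20" aria-hidden="true" tabindex="-1"></a>    timeout<span class="op">=</span>Config.TIMEOUT,</span>
<span id="cb12-21"><a href="#cb12-21" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>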
</section>
</section>
<section id="examples" class="level2">
<h2 class="anchored" data-anchor-id="examples" id="examples">Examples</h2>
<section id="image-classification-api" class="level3">
<h3 class="anchored" data-anchor-id="image-classification-api" id="image-classification-api">1. Image Classification API</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb13"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> litserve <span class="im">as</span> ls</span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torchvision <span class="im">import</span> models, transforms</span>
<span id="cb13-4"><a href="#cb13-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> PIL <span class="im">import</span> Image</span>
<span id="cb13-5"><a href="#cb13-5" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> base64</span>
<span id="cb13-6"><a href="#cb13-6" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> io</span>
<span id="cb13-7"><a href="#cb13-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-8"><a href="#cb13-8" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> ImageClassificationAPI(ls.LitAPI):</span>
<span id="cb13-9"><a href="#cb13-9" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> setup(<span class="va">self</span>, device):</span>
<span id="cb13-10"><a href="#cb13-10" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Load pre-trained ResNet model</span></span>
<span id="cb13-11"><a href="#cb13-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> models.resnet50(weights<span class="op">=</span>models.ResNet50_Weights.IMAGENET1K_V1)</span>
<span id="cb13-12"><a href="#cb13-12" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model.to(device)</span>
<span id="cb13-13"><a href="#cb13-13" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model.<span class="bu">eval</span>()</span>
<span id="cb13-14"><a href="#cb13-14" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-15"><a href="#cb13-15" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Image preprocessing</span></span>
<span id="cb13-16"><a href="#cb13-16" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.preprocess <span class="op">=</span> transforms.Compose([</span>
<span id="cb13-17"><a href="#cb13-17" aria-hidden="true" tabindex="-1"></a>            transforms.Resize(<span class="dv">256</span>),</span>
<span id="cb13-18"><a href="#cb13-18" aria-hidden="true" tabindex="-1"></a>            transforms.CenterCrop(<span class="dv">224</span>),</span>
<span id="cb13-19"><a href="#cb13-19" aria-hidden="true" tabindex="-1"></a>            transforms.ToTensor(),</span>
<span id="cb13-20"><a href="#cb13-20" aria-hidden="true" tabindex="-1"></a>            transforms.Normalize(mean<span class="op">=</span>[<span class="fl">0.485</span>, <span class="fl">0.456</span>, <span class="fl">0.406</span>], </span>
<span id="cb13-21"><a href="#cb13-21" aria-hidden="true" tabindex="-1"></a>                               std<span class="op">=</span>[<span class="fl">0.229</span>, <span class="fl">0.224</span>, <span class="fl">0.225</span>]),</span>
<span id="cb13-22"><a href="#cb13-22" aria-hidden="true" tabindex="-1"></a>        ])</span>
<span id="cb13-23"><a href="#cb13-23" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-24"><a href="#cb13-24" aria-hidden="true" tabindex="-1"></a>        <span class="co"># ImageNet class labels ship with the torchvision weights</span></span>
<span id="cb13-25"><a href="#cb13-25" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.classes <span class="op">=</span> models.ResNet50_Weights.IMAGENET1K_V1.meta[<span class="st">"categories"</span>]</span>
<span id="cb13-26"><a href="#cb13-26" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.device <span class="op">=</span> device</span>
<span id="cb13-27"><a href="#cb13-27" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-28"><a href="#cb13-28" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> decode_request(<span class="va">self</span>, request):</span>
<span id="cb13-29"><a href="#cb13-29" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Decode base64 image</span></span>
<span id="cb13-30"><a href="#cb13-30" aria-hidden="true" tabindex="-1"></a>        encoded_image <span class="op">=</span> request[<span class="st">"image"</span>]</span>
<span id="cb13-31"><a href="#cb13-31" aria-hidden="true" tabindex="-1"></a>        image_bytes <span class="op">=</span> base64.b64decode(encoded_image)</span>
<span id="cb13-32"><a href="#cb13-32" aria-hidden="true" tabindex="-1"></a>        image <span class="op">=</span> Image.<span class="bu">open</span>(io.BytesIO(image_bytes)).convert(<span class="st">'RGB'</span>)</span>
<span id="cb13-33"><a href="#cb13-33" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.preprocess(image)  <span class="co"># (3, 224, 224); LitServe's default batch() stacks these</span></span>
<span id="cb13-34"><a href="#cb13-34" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-35"><a href="#cb13-35" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> predict(<span class="va">self</span>, x):</span>
<span id="cb13-36"><a href="#cb13-36" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> torch.no_grad():</span>
<span id="cb13-37"><a href="#cb13-37" aria-hidden="true" tabindex="-1"></a>            output <span class="op">=</span> <span class="va">self</span>.model(x.to(<span class="va">self</span>.device))  <span class="co"># x: (B, 3, 224, 224)</span></span>
<span id="cb13-38"><a href="#cb13-38" aria-hidden="true" tabindex="-1"></a>            probabilities <span class="op">=</span> torch.nn.functional.softmax(output, dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb13-39"><a href="#cb13-39" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> probabilities  <span class="co"># (B, 1000); unbatched into one row per request</span></span>
<span id="cb13-40"><a href="#cb13-40" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-41"><a href="#cb13-41" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> encode_response(<span class="va">self</span>, output):</span>
<span id="cb13-42"><a href="#cb13-42" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Get top 5 predictions</span></span>
<span id="cb13-43"><a href="#cb13-43" aria-hidden="true" tabindex="-1"></a>        top5_prob, top5_catid <span class="op">=</span> torch.topk(output, <span class="dv">5</span>)</span>
<span id="cb13-44"><a href="#cb13-44" aria-hidden="true" tabindex="-1"></a>        results <span class="op">=</span> []</span>
<span id="cb13-45"><a href="#cb13-45" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(top5_prob.size(<span class="dv">0</span>)):</span>
<span id="cb13-46"><a href="#cb13-46" aria-hidden="true" tabindex="-1"></a>            results.append({</span>
<span id="cb13-47"><a href="#cb13-47" aria-hidden="true" tabindex="-1"></a>                <span class="st">"class"</span>: <span class="va">self</span>.classes[top5_catid[i]],</span>
<span id="cb13-48"><a href="#cb13-48" aria-hidden="true" tabindex="-1"></a>                <span class="st">"probability"</span>: <span class="bu">float</span>(top5_prob[i])</span>
<span id="cb13-49"><a href="#cb13-49" aria-hidden="true" tabindex="-1"></a>            })</span>
<span id="cb13-50"><a href="#cb13-50" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {<span class="st">"predictions"</span>: results}</span>
<span id="cb13-51"><a href="#cb13-51" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-52"><a href="#cb13-52" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb13-53"><a href="#cb13-53" aria-hidden="true" tabindex="-1"></a>    api <span class="op">=</span> ImageClassificationAPI()</span>
<span id="cb13-54"><a href="#cb13-54" aria-hidden="true" tabindex="-1"></a>    server <span class="op">=</span> ls.LitServer(api, accelerator<span class="op">=</span><span class="st">"auto"</span>, max_batch_size<span class="op">=</span><span class="dv">4</span>)</span>
<span id="cb13-55"><a href="#cb13-55" aria-hidden="true" tabindex="-1"></a>    server.run(port<span class="op">=</span><span class="dv">8000</span>)</span></code></pre></div></div>
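<p>Clients send the image as base64 text under the <code>"image"</code> key, matching <code>decode_request</code>. Here is a standard-library sketch of building that body. The <code>image_request_body</code> helper and the file name in the usage comment are hypothetical.</p>

```python
import base64
import json


def image_request_body(image_bytes):
    """Encode raw image bytes as the JSON body {"image": "<base64>"}."""
    encoded = base64.b64encode(image_bytes).decode("ascii")
    return json.dumps({"image": encoded}).encode("utf-8")


# Usage against a running server (file name and URL are assumptions):
# import urllib.request
# with open("cat.jpg", "rb") as f:
#     body = image_request_body(f.read())
# req = urllib.request.Request("http://localhost:8000/predict", data=body,
#                              headers={"Content-Type": "application/json"})
# with urllib.request.urlopen(req) as resp:
#     print(json.load(resp)["predictions"][0])
```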
</section>
<section id="text-embedding-api" class="level3">
<h3 class="anchored" data-anchor-id="text-embedding-api" id="text-embedding-api">2. Text Embedding API</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb14"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb14-1"><a href="#cb14-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> litserve <span class="im">as</span> ls</span>
<span id="cb14-2"><a href="#cb14-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> sentence_transformers <span class="im">import</span> SentenceTransformer</span>
<span id="cb14-3"><a href="#cb14-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb14-4"><a href="#cb14-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-5"><a href="#cb14-5" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> TextEmbeddingAPI(ls.LitAPI):</span>
<span id="cb14-6"><a href="#cb14-6" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> setup(<span class="va">self</span>, device):</span>
<span id="cb14-7"><a href="#cb14-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> SentenceTransformer(<span class="st">'all-MiniLM-L6-v2'</span>)</span>
<span id="cb14-8"><a href="#cb14-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model.to(device)</span>
<span id="cb14-9"><a href="#cb14-9" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb14-10"><a href="#cb14-10" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> decode_request(<span class="va">self</span>, request):</span>
<span id="cb14-11"><a href="#cb14-11" aria-hidden="true" tabindex="-1"></a>        texts <span class="op">=</span> request.get(<span class="st">"texts"</span>, [])</span>
<span id="cb14-12"><a href="#cb14-12" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="bu">isinstance</span>(texts, <span class="bu">str</span>):</span>
<span id="cb14-13"><a href="#cb14-13" aria-hidden="true" tabindex="-1"></a>            texts <span class="op">=</span> [texts]</span>
<span id="cb14-14"><a href="#cb14-14" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> texts</span>
<span id="cb14-15"><a href="#cb14-15" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb14-16"><a href="#cb14-16" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> predict(<span class="va">self</span>, batch):</span>
<span id="cb14-17"><a href="#cb14-17" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Batched input: one list of texts per request; encode all at once, then split</span></span>
<span id="cb14-18"><a href="#cb14-18" aria-hidden="true" tabindex="-1"></a>        sizes <span class="op">=</span> [<span class="bu">len</span>(texts) <span class="cf">for</span> texts <span class="kw">in</span> batch]</span>
<span id="cb14-19"><a href="#cb14-19" aria-hidden="true" tabindex="-1"></a>        embeddings <span class="op">=</span> <span class="va">self</span>.model.encode([text <span class="cf">for</span> texts <span class="kw">in</span> batch <span class="cf">for</span> text <span class="kw">in</span> texts])</span>
<span id="cb14-20"><a href="#cb14-20" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> np.split(embeddings, np.cumsum(sizes)[:<span class="op">-</span><span class="dv">1</span>])  <span class="co"># one (k, dim) array per request</span></span>
<span id="cb14-21"><a href="#cb14-21" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb14-22"><a href="#cb14-22" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> encode_response(<span class="va">self</span>, embeddings):</span>
<span id="cb14-23"><a href="#cb14-23" aria-hidden="true" tabindex="-1"></a>        dim <span class="op">=</span> <span class="va">self</span>.model.get_sentence_embedding_dimension()</span>
<span id="cb14-24"><a href="#cb14-24" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {<span class="st">"embeddings"</span>: embeddings.tolist(), <span class="st">"dimension"</span>: dim}</span>
<span id="cb14-25"><a href="#cb14-25" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-26"><a href="#cb14-26" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb14-27"><a href="#cb14-27" aria-hidden="true" tabindex="-1"></a>    api <span class="op">=</span> TextEmbeddingAPI()</span>
<span id="cb14-28"><a href="#cb14-28" aria-hidden="true" tabindex="-1"></a>    server <span class="op">=</span> ls.LitServer(api, accelerator<span class="op">=</span><span class="st">"auto"</span>, max_batch_size<span class="op">=</span><span class="dv">32</span>)</span>
<span id="cb14-29"><a href="#cb14-29" aria-hidden="true" tabindex="-1"></a>    server.run(port<span class="op">=</span><span class="dv">8000</span>)</span></code></pre></div></div>
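<p>When several requests are batched together, each one carries its own list of texts. The batch therefore has to be flattened for a single model call and split back into per-request results afterwards. The bookkeeping can be sketched in plain Python. The helper names are hypothetical, and the integers stand in for embedding vectors.</p>

```python
from itertools import accumulate


def flatten_requests(batch):
    """Flatten one list of texts per request into one list, recording each request's size."""
    sizes = [len(texts) for texts in batch]
    flat = [text for texts in batch for text in texts]
    return flat, sizes


def split_results(results, sizes):
    """Split a flat list of per-text results back into one list per request."""
    ends = list(accumulate(sizes))
    starts = [0] + ends[:-1]
    return [results[start:end] for start, end in zip(starts, ends)]


flat, sizes = flatten_requests([["a", "b"], [], ["c"]])
print(flat, sizes)                           # ['a', 'b', 'c'] [2, 0, 1]
print(split_results([10, 20, 30], sizes))    # [[10, 20], [], [30]]
```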
</section>
<section id="multi-modal-api-text-image" class="level3">
<h3 class="anchored" data-anchor-id="multi-modal-api-text-image" id="multi-modal-api-text-image">3. Multi-Modal API (Text + Image)</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb15"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb15-1"><a href="#cb15-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> litserve <span class="im">as</span> ls</span>
<span id="cb15-2"><a href="#cb15-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb15-3"><a href="#cb15-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> transformers <span class="im">import</span> BlipProcessor, BlipForConditionalGeneration</span>
<span id="cb15-4"><a href="#cb15-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> PIL <span class="im">import</span> Image</span>
<span id="cb15-5"><a href="#cb15-5" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> base64</span>
<span id="cb15-6"><a href="#cb15-6" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> io</span>
<span id="cb15-7"><a href="#cb15-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-8"><a href="#cb15-8" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> ImageCaptioningAPI(ls.LitAPI):</span>
<span id="cb15-9"><a href="#cb15-9" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> setup(<span class="va">self</span>, device):</span>
<span id="cb15-10"><a href="#cb15-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.device <span class="op">=</span> device</span>
<span id="cb15-11"><a href="#cb15-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.processor <span class="op">=</span> BlipProcessor.from_pretrained(<span class="st">"Salesforce/blip-image-captioning-base"</span>)</span>
<span id="cb15-12"><a href="#cb15-12" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> BlipForConditionalGeneration.from_pretrained(<span class="st">"Salesforce/blip-image-captioning-base"</span>)</span>
<span id="cb15-13"><a href="#cb15-13" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model.to(device)</span>
<span id="cb15-14"><a href="#cb15-14" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-15"><a href="#cb15-15" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> decode_request(<span class="va">self</span>, request):</span>
<span id="cb15-16"><a href="#cb15-16" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Handle both image and optional text input</span></span>
<span id="cb15-17"><a href="#cb15-17" aria-hidden="true" tabindex="-1"></a>        encoded_image <span class="op">=</span> request[<span class="st">"image"</span>]</span>
<span id="cb15-18"><a href="#cb15-18" aria-hidden="true" tabindex="-1"></a>        text_prompt <span class="op">=</span> request.get(<span class="st">"text"</span>, <span class="st">""</span>)</span>
<span id="cb15-19"><a href="#cb15-19" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-20"><a href="#cb15-20" aria-hidden="true" tabindex="-1"></a>        image_bytes <span class="op">=</span> base64.b64decode(encoded_image)</span>
<span id="cb15-21"><a href="#cb15-21" aria-hidden="true" tabindex="-1"></a>        image <span class="op">=</span> Image.<span class="bu">open</span>(io.BytesIO(image_bytes)).convert(<span class="st">'RGB'</span>)</span>
<span id="cb15-22"><a href="#cb15-22" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-23"><a href="#cb15-23" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {<span class="st">"image"</span>: image, <span class="st">"text"</span>: text_prompt}</span>
<span id="cb15-24"><a href="#cb15-24" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-25"><a href="#cb15-25" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> predict(<span class="va">self</span>, inputs):</span>
<span id="cb15-26"><a href="#cb15-26" aria-hidden="true" tabindex="-1"></a>        <span class="co"># An empty prompt means unconditional captioning, so pass no text at all</span></span>
<span id="cb15-27"><a href="#cb15-27" aria-hidden="true" tabindex="-1"></a>        processed <span class="op">=</span> <span class="va">self</span>.processor(</span>
<span id="cb15-28"><a href="#cb15-28" aria-hidden="true" tabindex="-1"></a>            images<span class="op">=</span>inputs[<span class="st">"image"</span>],</span>
<span id="cb15-29"><a href="#cb15-29" aria-hidden="true" tabindex="-1"></a>            text<span class="op">=</span>inputs[<span class="st">"text"</span>] <span class="kw">or</span> <span class="va">None</span>,</span>
<span id="cb15-30"><a href="#cb15-30" aria-hidden="true" tabindex="-1"></a>            return_tensors<span class="op">=</span><span class="st">"pt"</span></span>
<span id="cb15-31"><a href="#cb15-31" aria-hidden="true" tabindex="-1"></a>        ).to(<span class="va">self</span>.device)  <span class="co"># move inputs to the same device as the model</span></span>
<span id="cb15-32"><a href="#cb15-32" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-33"><a href="#cb15-33" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> torch.no_grad():</span>
<span id="cb15-34"><a href="#cb15-34" aria-hidden="true" tabindex="-1"></a>            outputs <span class="op">=</span> <span class="va">self</span>.model.generate(<span class="op">**</span>processed, max_length<span class="op">=</span><span class="dv">50</span>)</span>
<span id="cb15-35"><a href="#cb15-35" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-36"><a href="#cb15-36" aria-hidden="true" tabindex="-1"></a>        caption <span class="op">=</span> <span class="va">self</span>.processor.decode(outputs[<span class="dv">0</span>], skip_special_tokens<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb15-37"><a href="#cb15-37" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> caption</span>
<span id="cb15-38"><a href="#cb15-38" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-39"><a href="#cb15-39" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> encode_response(<span class="va">self</span>, caption):</span>
<span id="cb15-40"><a href="#cb15-40" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {<span class="st">"caption"</span>: caption}</span>
<span id="cb15-41"><a href="#cb15-41" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-42"><a href="#cb15-42" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb15-43"><a href="#cb15-43" aria-hidden="true" tabindex="-1"></a>    api <span class="op">=</span> ImageCaptioningAPI()</span>
<span id="cb15-44"><a href="#cb15-44" aria-hidden="true" tabindex="-1"></a>    server <span class="op">=</span> ls.LitServer(api, accelerator<span class="op">=</span><span class="st">"auto"</span>)</span>
<span id="cb15-45"><a href="#cb15-45" aria-hidden="true" tabindex="-1"></a>    server.run(port<span class="op">=</span><span class="dv">8000</span>)</span></code></pre></div></div>
</section>
</section>
<section id="best-practices" class="level2">
<h2 class="anchored" data-anchor-id="best-practices" id="best-practices">Best Practices</h2>
<section id="resource-management" class="level3">
<h3 class="anchored" data-anchor-id="resource-management" id="resource-management">1. Resource Management</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb16"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb16-1"><a href="#cb16-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> OptimizedAPI(ls.LitAPI):</span>
<span id="cb16-2"><a href="#cb16-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> setup(<span class="va">self</span>, device):</span>
<span id="cb16-3"><a href="#cb16-3" aria-hidden="true" tabindex="-1"></a>        <span class="co"># LitServe passes the device as a string such as "cpu" or "cuda:0"</span></span>
<span id="cb16-4"><a href="#cb16-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.device <span class="op">=</span> device</span>
<span id="cb16-5"><a href="#cb16-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.use_amp <span class="op">=</span> <span class="bu">str</span>(device).startswith(<span class="st">"cuda"</span>)</span>
<span id="cb16-6"><a href="#cb16-6" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb16-7"><a href="#cb16-7" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Load the model once; your_model is a placeholder for your own nn.Module</span></span>
<span id="cb16-8"><a href="#cb16-8" aria-hidden="true" tabindex="-1"></a>        model <span class="op">=</span> your_model.to(device).<span class="bu">eval</span>()</span>
<span id="cb16-9"><a href="#cb16-9" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb16-10"><a href="#cb16-10" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Optional: TorchScript compilation to reduce Python overhead</span></span>
<span id="cb16-11"><a href="#cb16-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> torch.jit.script(model)</span>
<span id="cb16-12"><a href="#cb16-12" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-13"><a href="#cb16-13" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> predict(<span class="va">self</span>, x):</span>
<span id="cb16-14"><a href="#cb16-14" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> x.to(<span class="va">self</span>.device)</span>
<span id="cb16-15"><a href="#cb16-15" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> torch.inference_mode():</span>
<span id="cb16-16"><a href="#cb16-16" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="va">self</span>.use_amp:</span>
<span id="cb16-17"><a href="#cb16-17" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Autocast runs eligible ops in float16; GradScaler is only needed for training</span></span>
<span id="cb16-18"><a href="#cb16-18" aria-hidden="true" tabindex="-1"></a>                <span class="cf">with</span> torch.autocast(device_type<span class="op">=</span><span class="st">"cuda"</span>, dtype<span class="op">=</span>torch.float16):</span>
<span id="cb16-19"><a href="#cb16-19" aria-hidden="true" tabindex="-1"></a>                    <span class="cf">return</span> <span class="va">self</span>.model(x)</span>
<span id="cb16-20"><a href="#cb16-20" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> <span class="va">self</span>.model(x)</span></code></pre></div></div>
</section>
<section id="error-handling" class="level3">
<h3 class="anchored" data-anchor-id="error-handling" id="error-handling">2. Error Handling</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb17"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb17-1"><a href="#cb17-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> fastapi <span class="im">import</span> HTTPException</span>
<span id="cb17-2"><a href="#cb17-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-3"><a href="#cb17-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-4"><a href="#cb17-4" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> RobustAPI(ls.LitAPI):</span>
<span id="cb17-5"><a href="#cb17-5" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> decode_request(<span class="va">self</span>, request):</span>
<span id="cb17-6"><a href="#cb17-6" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Validate required fields; a 400 status tells the client its request was at fault</span></span>
<span id="cb17-7"><a href="#cb17-7" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="kw">not</span> <span class="bu">isinstance</span>(request, <span class="bu">dict</span>) <span class="kw">or</span> <span class="st">"input"</span> <span class="kw">not</span> <span class="kw">in</span> request:</span>
<span id="cb17-8"><a href="#cb17-8" aria-hidden="true" tabindex="-1"></a>            <span class="cf">raise</span> HTTPException(status_code<span class="op">=</span><span class="dv">400</span>, detail<span class="op">=</span><span class="st">"Missing required field: input"</span>)</span>
<span id="cb17-9"><a href="#cb17-9" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb17-10"><a href="#cb17-10" aria-hidden="true" tabindex="-1"></a>        data <span class="op">=</span> request[<span class="st">"input"</span>]</span>
<span id="cb17-11"><a href="#cb17-11" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb17-12"><a href="#cb17-12" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Type validation</span></span>
<span id="cb17-13"><a href="#cb17-13" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="kw">not</span> <span class="bu">isinstance</span>(data, (<span class="bu">str</span>, <span class="bu">list</span>)):</span>
<span id="cb17-14"><a href="#cb17-14" aria-hidden="true" tabindex="-1"></a>            <span class="cf">raise</span> HTTPException(status_code<span class="op">=</span><span class="dv">400</span>, detail<span class="op">=</span><span class="st">"Input must be a string or a list"</span>)</span>
<span id="cb17-15"><a href="#cb17-15" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb17-16"><a href="#cb17-16" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> data</span>
<span id="cb17-17"><a href="#cb17-17" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb17-18"><a href="#cb17-18" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> predict(<span class="va">self</span>, x):</span>
<span id="cb17-19"><a href="#cb17-19" aria-hidden="true" tabindex="-1"></a>        <span class="cf">try</span>:</span>
<span id="cb17-20"><a href="#cb17-20" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> <span class="va">self</span>.model(x)</span>
<span id="cb17-21"><a href="#cb17-21" aria-hidden="true" tabindex="-1"></a>        <span class="cf">except</span> torch.cuda.OutOfMemoryError:</span>
<span id="cb17-22"><a href="#cb17-22" aria-hidden="true" tabindex="-1"></a>            torch.cuda.empty_cache()  <span class="co"># release cached blocks before the next request</span></span>
<span id="cb17-23"><a href="#cb17-23" aria-hidden="true" tabindex="-1"></a>            <span class="cf">raise</span> HTTPException(status_code<span class="op">=</span><span class="dv">503</span>, detail<span class="op">=</span><span class="st">"GPU memory exhausted"</span>)</span>
<span id="cb17-24"><a href="#cb17-24" aria-hidden="true" tabindex="-1"></a>        <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:</span>
<span id="cb17-25"><a href="#cb17-25" aria-hidden="true" tabindex="-1"></a>            <span class="cf">raise</span> HTTPException(status_code<span class="op">=</span><span class="dv">500</span>, detail<span class="op">=</span><span class="ss">f"Prediction failed: </span><span class="sc">{</span>e<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
</section>
<section id="monitoring-and-logging" class="level3">
<h3 class="anchored" data-anchor-id="monitoring-and-logging" id="monitoring-and-logging">3. Monitoring and Logging</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb18"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb18-1"><a href="#cb18-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> logging</span>
<span id="cb18-2"><a href="#cb18-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb18-3"><a href="#cb18-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-4"><a href="#cb18-4" aria-hidden="true" tabindex="-1"></a>logging.basicConfig(level<span class="op">=</span>logging.INFO)  <span class="co"># INFO records are not shown by default</span></span>
<span id="cb18-5"><a href="#cb18-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-6"><a href="#cb18-6" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MonitoredAPI(ls.LitAPI):</span>
<span id="cb18-7"><a href="#cb18-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> setup(<span class="va">self</span>, device):</span>
<span id="cb18-8"><a href="#cb18-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.logger <span class="op">=</span> logging.getLogger(<span class="va">__name__</span>)</span>
<span id="cb18-9"><a href="#cb18-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.request_count <span class="op">=</span> <span class="dv">0</span>  <span class="co"># each worker process keeps its own count</span></span>
<span id="cb18-10"><a href="#cb18-10" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Your model setup</span></span>
<span id="cb18-11"><a href="#cb18-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-12"><a href="#cb18-12" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> decode_request(<span class="va">self</span>, request):</span>
<span id="cb18-13"><a href="#cb18-13" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.request_count <span class="op">+=</span> <span class="dv">1</span></span>
<span id="cb18-14"><a href="#cb18-14" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.logger.info(<span class="ss">f"Processing request #</span><span class="sc">{</span><span class="va">self</span><span class="sc">.</span>request_count<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb18-15"><a href="#cb18-15" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> request[<span class="st">"data"</span>]</span>
<span id="cb18-16"><a href="#cb18-16" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-17"><a href="#cb18-17" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> predict(<span class="va">self</span>, x):</span>
<span id="cb18-18"><a href="#cb18-18" aria-hidden="true" tabindex="-1"></a>        start_time <span class="op">=</span> time.perf_counter()  <span class="co"># monotonic, high-resolution clock</span></span>
<span id="cb18-19"><a href="#cb18-19" aria-hidden="true" tabindex="-1"></a>        result <span class="op">=</span> <span class="va">self</span>.model(x)</span>
<span id="cb18-20"><a href="#cb18-20" aria-hidden="true" tabindex="-1"></a>        <span class="co"># On GPU, call torch.cuda.synchronize() here so queued kernels are counted</span></span>
<span id="cb18-21"><a href="#cb18-21" aria-hidden="true" tabindex="-1"></a>        inference_time <span class="op">=</span> time.perf_counter() <span class="op">-</span> start_time</span>
<span id="cb18-22"><a href="#cb18-22" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb18-23"><a href="#cb18-23" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.logger.info(<span class="ss">f"Inference completed in </span><span class="sc">{</span>inference_time<span class="sc">:.3f}</span><span class="ss">s"</span>)</span>
<span id="cb18-24"><a href="#cb18-24" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> result</span></code></pre></div></div>
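<p>A log line per request helps with debugging, but latency regressions are easier to spot in aggregate. The sketch below keeps a rolling window of recent inference times and reports nearest-rank p50 and p95 values. <code>LatencyTracker</code> is a hypothetical helper, not part of LitServe, and the window size is an arbitrary assumption; if you run several workers, each process would keep its own window.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import math
from collections import deque


class LatencyTracker:
    """Hypothetical helper: keeps the most recent latencies (in seconds)."""

    def __init__(self, window=1000):
        # A deque with maxlen drops the oldest sample once the window is full
        self.samples = deque(maxlen=window)

    def record(self, seconds):
        self.samples.append(seconds)

    def percentile(self, p):
        """Nearest-rank percentile for p in (0, 100]."""
        ordered = sorted(self.samples)
        rank = max(1, math.ceil(p / 100 * len(ordered)))
        return ordered[rank - 1]

    def summary(self):
        if not self.samples:
            return {}
        return {
            "count": len(self.samples),
            "p50_ms": round(self.percentile(50) * 1000, 2),
            "p95_ms": round(self.percentile(95) * 1000, 2),
        }


tracker = LatencyTracker(window=100)
for t in [0.010, 0.012, 0.011, 0.050, 0.013]:
    tracker.record(t)
print(tracker.summary())</code></pre></div></div>
<p>Inside <code>MonitoredAPI.predict()</code>, you could call <code>self.tracker.record(inference_time)</code> and log <code>self.tracker.summary()</code> every few hundred requests instead of on every call.</p>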
</section>
<section id="model-versioning" class="level3">
<h3 class="anchored" data-anchor-id="model-versioning" id="model-versioning">4. Model Versioning</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb19"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb19-1"><a href="#cb19-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb19-2"><a href="#cb19-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-3"><a href="#cb19-3" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> VersionedAPI(ls.LitAPI):</span>
<span id="cb19-4"><a href="#cb19-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> setup(<span class="va">self</span>, device):</span>
<span id="cb19-5"><a href="#cb19-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.version <span class="op">=</span> <span class="st">"1.0.0"</span></span>
<span id="cb19-6"><a href="#cb19-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model_path <span class="op">=</span> <span class="ss">f"models/model_v</span><span class="sc">{</span><span class="va">self</span><span class="sc">.</span>version<span class="sc">}</span><span class="ss">"</span></span>
<span id="cb19-7"><a href="#cb19-7" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Load the model stored at self.model_path here</span></span>
<span id="cb19-8"><a href="#cb19-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb19-9"><a href="#cb19-9" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> encode_response(<span class="va">self</span>, output):</span>
<span id="cb19-10"><a href="#cb19-10" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {</span>
<span id="cb19-11"><a href="#cb19-11" aria-hidden="true" tabindex="-1"></a>            <span class="st">"result"</span>: output,  <span class="co"># must be JSON-serializable, e.g. tensor.tolist()</span></span>
<span id="cb19-12"><a href="#cb19-12" aria-hidden="true" tabindex="-1"></a>            <span class="st">"model_version"</span>: <span class="va">self</span>.version,</span>
<span id="cb19-13"><a href="#cb19-13" aria-hidden="true" tabindex="-1"></a>            <span class="st">"timestamp"</span>: time.time()  <span class="co"># Unix time when the response was built</span></span>
<span id="cb19-14"><a href="#cb19-14" aria-hidden="true" tabindex="-1"></a>        }</span></code></pre></div></div>
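<p>Hardcoding <code>"1.0.0"</code> means a code change for every model release. One alternative, sketched here under the assumption that models are stored as <code>models/model_vMAJOR.MINOR.PATCH</code> (the same pattern as <code>model_path</code> above), is to pick the newest version found on disk at startup. <code>latest_model_path</code> is a hypothetical helper, not a LitServe feature.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import re
from pathlib import Path

# Assumed naming scheme: "model_v1.0.0" or "model_v1.10.2", optionally with a suffix such as ".pt"
VERSION_PATTERN = re.compile(r"model_v(\d+)\.(\d+)\.(\d+)")


def latest_model_path(directory):
    """Return (version_string, path) for the newest model in directory, or None."""
    candidates = []
    for path in Path(directory).iterdir():
        match = VERSION_PATTERN.match(path.name)
        if match:
            version = tuple(int(part) for part in match.groups())
            candidates.append((version, path))
    if not candidates:
        return None
    # Integer tuples compare numerically, so 1.10.0 ranks above 1.9.0
    version, path = max(candidates, key=lambda item: item[0])
    return ".".join(str(part) for part in version), path</code></pre></div></div>
<p>In <code>setup()</code>, the result could replace the hardcoded values, for example <code>self.version, self.model_path = latest_model_path("models")</code>, after checking that it is not <code>None</code>.</p>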
</section>
</section>
<section id="troubleshooting" class="level2">
<h2 class="anchored" data-anchor-id="troubleshooting" id="troubleshooting">Troubleshooting</h2>
<section id="common-issues-and-solutions" class="level3">
<h3 class="anchored" data-anchor-id="common-issues-and-solutions" id="common-issues-and-solutions">Common Issues and Solutions</h3>
<section id="cuda-out-of-memory" class="level4">
<h4 class="anchored" data-anchor-id="cuda-out-of-memory">1. CUDA Out of Memory</h4>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb20"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb20-1"><a href="#cb20-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Solution 1: reduce the batch size so fewer inputs share GPU memory</span></span>
<span id="cb20-2"><a href="#cb20-2" aria-hidden="true" tabindex="-1"></a>server <span class="op">=</span> ls.LitServer(api, max_batch_size<span class="op">=</span><span class="dv">2</span>)</span>
<span id="cb20-3"><a href="#cb20-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb20-4"><a href="#cb20-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Solution 2: run inference without autograd so no activations are kept for backward</span></span>
<span id="cb20-5"><a href="#cb20-5" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> predict(<span class="va">self</span>, x):</span>
<span id="cb20-6"><a href="#cb20-6" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> torch.inference_mode():</span>
<span id="cb20-7"><a href="#cb20-7" aria-hidden="true" tabindex="-1"></a>        result <span class="op">=</span> <span class="va">self</span>.model(x)</span>
<span id="cb20-8"><a href="#cb20-8" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> result</span>
<span id="cb20-9"><a href="#cb20-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb20-10"><a href="#cb20-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Avoid calling torch.cuda.empty_cache() on every request: it only releases</span></span>
<span id="cb20-11"><a href="#cb20-11" aria-hidden="true" tabindex="-1"></a><span class="co"># cached blocks that no tensor is using, and later allocations become slower</span></span></code></pre></div></div>
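<p>Lowering <code>max_batch_size</code> caps memory for every request, even when most batches would fit. A complementary pattern, sketched here as a hypothetical helper rather than anything LitServe provides, is to catch the out-of-memory error and retry the batch in halves. The exception type is a parameter so the logic can be exercised without a GPU; in a real <code>predict()</code> you would pass <code>torch.cuda.OutOfMemoryError</code>.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">def predict_with_split(fn, batch, oom_error=MemoryError):
    """Run fn on a batch; on an out-of-memory error, retry on the two halves.

    fn must return one output per item. With PyTorch you would pass
    oom_error=torch.cuda.OutOfMemoryError (and could call
    torch.cuda.empty_cache() before retrying).
    """
    try:
        return list(fn(batch))
    except oom_error:
        if len(batch) in (0, 1):
            raise  # a single item does not fit; there is nothing left to split
    middle = len(batch) // 2
    # Each half is retried independently and may be split again
    return (predict_with_split(fn, batch[:middle], oom_error)
            + predict_with_split(fn, batch[middle:], oom_error))</code></pre></div></div>
<p>With real tensors, each call would return a tensor rather than a list, so the halves would be joined with <code>torch.cat</code> instead of <code>+</code>. The trade-off is latency for reliability: an oversized batch takes several passes instead of failing outright.</p>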
</section>
<section id="slow-inference" class="level4">
<h4 class="anchored" data-anchor-id="slow-inference">2. Slow Inference</h4>
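<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb21"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb21-1"><a href="#cb21-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Enable model optimization (self.model must already hold the loaded model)</span></span>
<span id="cb21-2"><a href="#cb21-2" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> setup(<span class="va">self</span>, device):</span>
<span id="cb21-3"><a href="#cb21-3" aria-hidden="true" tabindex="-1"></a>    <span class="va">self</span>.device <span class="op">=</span> device</span>
<span id="cb21-4"><a href="#cb21-4" aria-hidden="true" tabindex="-1"></a>    <span class="va">self</span>.use_half <span class="op">=</span> <span class="bu">str</span>(device).startswith(<span class="st">"cuda"</span>)</span>
<span id="cb21-5"><a href="#cb21-5" aria-hidden="true" tabindex="-1"></a>    <span class="va">self</span>.model <span class="op">=</span> <span class="va">self</span>.model.to(device).<span class="bu">eval</span>()  <span class="co"># Move to device, set evaluation mode</span></span>
<span id="cb21-6"><a href="#cb21-6" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Use half precision on GPU; inputs must then be cast to match (see predict)</span></span>
<span id="cb21-7"><a href="#cb21-7" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> <span class="va">self</span>.use_half:</span>
<span id="cb21-8"><a href="#cb21-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> <span class="va">self</span>.model.half()</span>
<span id="cb21-9"><a href="#cb21-9" aria-hidden="true" tabindex="-1"></a>    <span class="va">self</span>.model <span class="op">=</span> torch.jit.script(<span class="va">self</span>.model)  <span class="co"># TorchScript compilation</span></span>
<span id="cb21-10"><a href="#cb21-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-11"><a href="#cb21-11" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> predict(<span class="va">self</span>, x):</span>
<span id="cb21-12"><a href="#cb21-12" aria-hidden="true" tabindex="-1"></a>    x <span class="op">=</span> x.to(<span class="va">self</span>.device)</span>
<span id="cb21-13"><a href="#cb21-13" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> <span class="va">self</span>.use_half:</span>
<span id="cb21-14"><a href="#cb21-14" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> x.half()  <span class="co"># weights and inputs must share a dtype</span></span>
<span id="cb21-15"><a href="#cb21-15" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> torch.inference_mode():</span>
<span id="cb21-16"><a href="#cb21-16" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.model(x)</span></code></pre></div></div>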
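<p>Whether TorchScript or half precision speeds up a particular model is worth measuring rather than assuming. The helper below is a minimal sketch, assuming you can call the model directly outside the server; it discards warm-up runs and reports the median, which is less sensitive to outliers than the mean. The name <code>benchmark</code> is hypothetical.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import statistics
import time


def benchmark(fn, *args, warmup=3, runs=20):
    """Return the median wall-clock time of fn(*args) in milliseconds (runs must be at least 1).

    For GPU models, make fn call torch.cuda.synchronize() before returning;
    otherwise asynchronous kernels may finish after the timer stops.
    """
    for _ in range(warmup):
        fn(*args)  # warm-up: lazy initialization, JIT compilation, caches
    timings = []
    for _ in range(runs):
        start = time.perf_counter()
        fn(*args)
        timings.append((time.perf_counter() - start) * 1000)
    return statistics.median(timings)</code></pre></div></div>
<p>For example, comparing <code>benchmark(lambda: model(example_input))</code> before and after calling <code>torch.jit.script</code> (with <code>model</code> and <code>example_input</code> standing in for your own objects) shows whether compilation helps for your architecture.</p>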
</section>
<section id="request-timeout" class="level4">
<h4 class="anchored" data-anchor-id="request-timeout">3. Request Timeout</h4>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb22"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb22-1"><a href="#cb22-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Adjust timeout settings</span></span>
<span id="cb22-2"><a href="#cb22-2" aria-hidden="true" tabindex="-1"></a>server <span class="op">=</span> ls.LitServer(</span>
<span id="cb22-3"><a href="#cb22-3" aria-hidden="true" tabindex="-1"></a>    api,</span>
<span id="cb22-4"><a href="#cb22-4" aria-hidden="true" tabindex="-1"></a>    timeout<span class="op">=</span><span class="dv">60</span>,  <span class="co"># Seconds before a request fails; raise it for slow models</span></span>
<span id="cb22-5"><a href="#cb22-5" aria-hidden="true" tabindex="-1"></a>    batch_timeout<span class="op">=</span><span class="fl">0.05</span>,  <span class="co"># Max wait to fill a batch; it adds latency, so keep it small</span></span>
<span id="cb22-6"><a href="#cb22-6" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
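<p>Raising the server-side <code>timeout</code> only helps if clients are willing to wait at least as long. Below is a client-side sketch using only the standard library, assuming the server listens at <code>http://127.0.0.1:8000/predict</code>: it sets a socket timeout slightly above the server's 60 seconds and retries connection failures and timeouts with exponential backoff. <code>post_with_retry</code> and <code>backoff_delays</code> are hypothetical names.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import json
import time
import urllib.error
import urllib.request


def backoff_delays(retries, base=0.5, cap=8.0):
    """Delays before each retry: base, 2*base, 4*base, ... never exceeding cap."""
    return [min(cap, base * 2 ** attempt) for attempt in range(retries)]


def post_with_retry(url, payload, timeout=70, retries=3):
    """POST JSON to url; retry connection failures and timeouts, not HTTP errors."""
    body = json.dumps(payload).encode("utf-8")
    last_error = None
    # The first attempt runs immediately; later attempts follow the backoff schedule
    for delay in [0.0] + backoff_delays(retries):
        time.sleep(delay)
        request = urllib.request.Request(
            url, data=body, headers={"Content-Type": "application/json"}
        )
        try:
            with urllib.request.urlopen(request, timeout=timeout) as response:
                return json.loads(response.read())
        except urllib.error.HTTPError:
            raise  # the server answered; resending the same payload will not help
        except OSError as error:  # URLError, socket timeouts, refused connections
            last_error = error
    raise last_error


# Example call (assumes a server is running locally):
# result = post_with_retry("http://127.0.0.1:8000/predict", {"input": "hello"})</code></pre></div></div>
<p>HTTP error responses, such as a 400 for a malformed request, are raised immediately rather than retried, since resending the same payload would fail the same way. Retrying is only safe if running the model twice on the same input has no side effects.</p>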
</section>
<section id="port-already-in-use" class="level4">
<h4 class="anchored" data-anchor-id="port-already-in-use">4. Port Already in Use</h4>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb23"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb23-1"><a href="#cb23-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Find and force-kill whatever listens on port 8000 (macOS/Linux, requires lsof)</span></span>
<span id="cb23-2"><a href="#cb23-2" aria-hidden="true" tabindex="-1"></a><span class="co"># The pipe needs a shell, so pass the whole command as one string</span></span>
<span id="cb23-3"><a href="#cb23-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> subprocess</span>
<span id="cb23-4"><a href="#cb23-4" aria-hidden="true" tabindex="-1"></a>subprocess.run(<span class="st">"lsof -ti:8000 | xargs kill -9"</span>, shell<span class="op">=</span><span class="va">True</span>, check<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb23-5"><a href="#cb23-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb23-6"><a href="#cb23-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Or use a different port</span></span>
<span id="cb23-7"><a href="#cb23-7" aria-hidden="true" tabindex="-1"></a>server.run(port<span class="op">=</span><span class="dv">8001</span>)</span></code></pre></div></div>
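<p>Rather than killing another process, the server script can check whether a port is free before starting. A standard-library sketch with hypothetical helper names: binding a socket succeeds only if no other socket already holds that address and port.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import socket


def is_port_free(port, host="0.0.0.0"):
    """Return True if a TCP socket can currently bind to host:port."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        try:
            sock.bind((host, port))
        except OSError:
            return False
    return True


def find_free_port(start=8000, attempts=20, host="0.0.0.0"):
    """Return the first bindable port in range(start, start + attempts), or None."""
    for port in range(start, start + attempts):
        if is_port_free(port, host):
            return port
    return None</code></pre></div></div>
<p>Usage: <code>server.run(port=find_free_port(8000))</code>, checking the same host the server will bind to. There is a small race window, since another process can claim the port between the check and the server's own bind, and a port lingering in <code>TIME_WAIT</code> may be reported as busy, so treat this as a convenience for local development.</p>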
</section>
</section>
<section id="performance-optimization-tips" class="level3">
<h3 class="anchored" data-anchor-id="performance-optimization-tips" id="performance-optimization-tips">Performance Optimization Tips</h3>
<ol type="1">
<li><strong>Use appropriate batch sizes</strong>: Start with a small <code>max_batch_size</code> and increase it while latency and GPU memory stay within your budget</li>
<li><strong>Enable GPU acceleration</strong>: Use <code>accelerator="auto"</code> for automatic GPU detection</li>
<li><strong>Optimize model loading</strong>: Load models once in <code>setup()</code>, not in <code>predict()</code></li>
<li><strong>Use mixed precision</strong>: Enable autocast for GPU inference</li>
<li><strong>Profile your code</strong>: Use tools like <code>torch.profiler</code> to identify bottlenecks</li>
<li><strong>Cache preprocessed data</strong>: Reuse the output of expensive preprocessing when identical inputs recur</li>
</ol>
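<p>Caching preprocessed data pays off only when identical inputs recur, for example the same product image requested by many clients. One way to do it, sketched below as a hypothetical helper, is an LRU cache keyed by a SHA-256 digest of the raw request bytes, so only short digests rather than full images are kept as keys. <code>functools.lru_cache</code> also accepts <code>bytes</code> arguments, but it would hold every cached request body in memory as a key.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import hashlib
from collections import OrderedDict


class PreprocessCache:
    """Hypothetical LRU cache of preprocessed inputs, keyed by a digest of the raw bytes.

    max_items must be at least 1.
    """

    def __init__(self, preprocess, max_items=256):
        self.preprocess = preprocess  # e.g. decode + resize + normalize
        self.max_items = max_items
        self.items = OrderedDict()  # insertion order doubles as recency order
        self.hits = 0
        self.misses = 0

    def get(self, raw_bytes):
        key = hashlib.sha256(raw_bytes).hexdigest()
        if key in self.items:
            self.hits += 1
            self.items.move_to_end(key)  # mark as most recently used
            return self.items[key]
        self.misses += 1
        value = self.preprocess(raw_bytes)
        if len(self.items) == self.max_items:
            self.items.popitem(last=False)  # evict the least recently used entry
        self.items[key] = value
        return value</code></pre></div></div>
<p>For example, <code>PreprocessCache(lambda raw: transform(Image.open(io.BytesIO(raw)).convert("RGB")))</code> could be created in <code>setup()</code>, with <code>transform</code> standing in for your preprocessing pipeline. Cached tensors are shared between requests, so they should not be modified in place.</p>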
</section>
<section id="deployment-checklist" class="level3">
<h3 class="anchored" data-anchor-id="deployment-checklist" id="deployment-checklist">Deployment Checklist</h3>
<ul class="task-list">
<li><label><input type="checkbox">Test API locally with various input types</label></li>
<li><label><input type="checkbox">Validate error handling for malformed requests</label></li>
<li><label><input type="checkbox">Check memory usage under load</label></li>
<li><label><input type="checkbox">Verify GPU utilization (if using GPUs)</label></li>
<li><label><input type="checkbox">Test with maximum expected batch size</label></li>
<li><label><input type="checkbox">Implement proper logging and monitoring</label></li>
<li><label><input type="checkbox">Set up health check endpoints</label></li>
<li><label><input type="checkbox">Configure appropriate timeouts</label></li>
<li><label><input type="checkbox">Test authentication (if implemented)</label></li>
<li><label><input type="checkbox">Verify response format consistency</label></li>
</ul>
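<p>Several checklist items can be automated with a small smoke test run against a local server. The sketch below assumes the server listens on <code>http://127.0.0.1:8000</code>, accepts <code>{"input": ...}</code> payloads like <code>RobustAPI</code>, returns a <code>"result"</code> key like <code>VersionedAPI</code>, and rejects malformed requests with HTTP 400. It also calls a <code>/health</code> route, which LitServe serves, but confirm the path for your installed version. All helper names are hypothetical; adjust the required keys to match your own <code>encode_response</code>.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import json
import urllib.error
import urllib.request

BASE_URL = "http://127.0.0.1:8000"  # assumption: local server on the default port


def post(path, payload, timeout=30):
    """POST a JSON payload; return (status_code, body)."""
    request = urllib.request.Request(
        BASE_URL + path,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    try:
        with urllib.request.urlopen(request, timeout=timeout) as response:
            return response.status, json.loads(response.read())
    except urllib.error.HTTPError as error:
        # Error responses are returned, not raised, so checks can inspect them
        return error.code, error.read().decode("utf-8", "replace")


def missing_keys(body, required_keys):
    """Return the required keys absent from a JSON response body."""
    if not isinstance(body, dict):
        return list(required_keys)
    return [key for key in required_keys if key not in body]


def smoke_test():
    # Health check: the server should report ready before traffic arrives
    with urllib.request.urlopen(BASE_URL + "/health", timeout=5) as response:
        assert response.status == 200, response.status

    # Well-formed request: expect 200 and a consistent response schema
    status, body = post("/predict", {"input": "hello"})
    assert status == 200, (status, body)
    assert not missing_keys(body, ["result"]), body

    # Malformed requests: expect client errors, not 500s or hangs
    for bad_payload in [{}, {"input": 42}, {"wrong_field": "x"}]:
        status, _ = post("/predict", bad_payload)
        assert status == 400, (bad_payload, status)
    print("smoke test passed")</code></pre></div></div>
<p>Run <code>smoke_test()</code> after starting the server; it raises <code>AssertionError</code> at the first failed check.</p>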
<p>This guide covers the essential aspects of using LitServe for deploying AI models. For the most up-to-date information, always refer to the official LitServe documentation.</p>


</section>
</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[LitServe with MobileNetV2 - Complete Code Guide]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/deployment/litserve-mobilenet/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/deployment/litserve-mobilenet/</guid>
      <pubDate>Tue, 27 May 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>mlops</category>
      <category>intermediate</category>
      <content:encoded><![CDATA[






<section id="litserve-with-mobilenetv2---complete-code-guide" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/deployment/litserve-mobilenet/litlog.jpg" class="img-fluid"></p>
<p>This guide demonstrates how to deploy a MobileNetV2 image classification model using LitServe for efficient, scalable inference.</p>
<section id="installation" class="level2">
<h2 class="anchored" data-anchor-id="installation" id="installation">Installation</h2>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Install required packages</span></span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install litserve torch torchvision pillow requests</span></code></pre></div></div>
</section>
<section id="basic-implementation" class="level2">
<h2 class="anchored" data-anchor-id="basic-implementation" id="basic-implementation">Basic Implementation</h2>
<section id="simple-mobilenetv2-api" class="level3">
<h3 class="anchored" data-anchor-id="simple-mobilenetv2-api" id="simple-mobilenetv2-api">Simple MobileNetV2 API</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="co"># mobilenet_api.py</span></span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> io</span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torchvision.transforms <span class="im">as</span> transforms</span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torchvision <span class="im">import</span> models</span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> PIL <span class="im">import</span> Image</span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> litserve <span class="im">as</span> ls</span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-10"><a href="#cb2-10" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MobileNetV2API(ls.LitAPI):</span>
<span id="cb2-11"><a href="#cb2-11" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> setup(<span class="va">self</span>, device):</span>
<span id="cb2-12"><a href="#cb2-12" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Initialize the model and preprocessing pipeline"""</span></span>
<span id="cb2-13"><a href="#cb2-13" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Load pre-trained MobileNetV2 (the weights enum replaces the deprecated pretrained=True)</span></span>
<span id="cb2-14"><a href="#cb2-14" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> models.mobilenet_v2(weights<span class="op">=</span>models.MobileNet_V2_Weights.IMAGENET1K_V1)</span>
<span id="cb2-15"><a href="#cb2-15" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model.<span class="bu">eval</span>()</span>
<span id="cb2-16"><a href="#cb2-16" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model.to(device)</span>
<span id="cb2-17"><a href="#cb2-17" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-18"><a href="#cb2-18" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Define image preprocessing</span></span>
<span id="cb2-19"><a href="#cb2-19" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.transform <span class="op">=</span> transforms.Compose([</span>
<span id="cb2-20"><a href="#cb2-20" aria-hidden="true" tabindex="-1"></a>            transforms.Resize(<span class="dv">256</span>),</span>
<span id="cb2-21"><a href="#cb2-21" aria-hidden="true" tabindex="-1"></a>            transforms.CenterCrop(<span class="dv">224</span>),</span>
<span id="cb2-22"><a href="#cb2-22" aria-hidden="true" tabindex="-1"></a>            transforms.ToTensor(),</span>
<span id="cb2-23"><a href="#cb2-23" aria-hidden="true" tabindex="-1"></a>            transforms.Normalize(</span>
<span id="cb2-24"><a href="#cb2-24" aria-hidden="true" tabindex="-1"></a>                mean<span class="op">=</span>[<span class="fl">0.485</span>, <span class="fl">0.456</span>, <span class="fl">0.406</span>],</span>
<span id="cb2-25"><a href="#cb2-25" aria-hidden="true" tabindex="-1"></a>                std<span class="op">=</span>[<span class="fl">0.229</span>, <span class="fl">0.224</span>, <span class="fl">0.225</span>]</span>
<span id="cb2-26"><a href="#cb2-26" aria-hidden="true" tabindex="-1"></a>            )</span>
<span id="cb2-27"><a href="#cb2-27" aria-hidden="true" tabindex="-1"></a>        ])</span>
<span id="cb2-28"><a href="#cb2-28" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-29"><a href="#cb2-29" aria-hidden="true" tabindex="-1"></a>        <span class="co"># All 1000 ImageNet class names ship with the torchvision weights metadata</span></span>
<span id="cb2-30"><a href="#cb2-30" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.class_labels <span class="op">=</span> models.MobileNet_V2_Weights.IMAGENET1K_V1.meta[<span class="st">"categories"</span>]</span>
<span id="cb2-36"><a href="#cb2-36" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-37"><a href="#cb2-37" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> decode_request(<span class="va">self</span>, request):</span>
<span id="cb2-38"><a href="#cb2-38" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Parse incoming request and prepare image"""</span></span>
<span id="cb2-39"><a href="#cb2-39" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="bu">isinstance</span>(request, <span class="bu">dict</span>) <span class="kw">and</span> <span class="st">"image"</span> <span class="kw">in</span> request:</span>
<span id="cb2-40"><a href="#cb2-40" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Handle base64 encoded image</span></span>
<span id="cb2-41"><a href="#cb2-41" aria-hidden="true" tabindex="-1"></a>            <span class="im">import</span> base64</span>
<span id="cb2-42"><a href="#cb2-42" aria-hidden="true" tabindex="-1"></a>            image_data <span class="op">=</span> base64.b64decode(request[<span class="st">"image"</span>])</span>
<span id="cb2-43"><a href="#cb2-43" aria-hidden="true" tabindex="-1"></a>            image <span class="op">=</span> Image.<span class="bu">open</span>(io.BytesIO(image_data)).convert(<span class="st">'RGB'</span>)</span>
<span id="cb2-44"><a href="#cb2-44" aria-hidden="true" tabindex="-1"></a>        <span class="cf">else</span>:</span>
<span id="cb2-45"><a href="#cb2-45" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Handle direct image upload</span></span>
<span id="cb2-46"><a href="#cb2-46" aria-hidden="true" tabindex="-1"></a>            image <span class="op">=</span> Image.<span class="bu">open</span>(io.BytesIO(request)).convert(<span class="st">'RGB'</span>)</span>
<span id="cb2-47"><a href="#cb2-47" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-48"><a href="#cb2-48" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> image</span>
<span id="cb2-49"><a href="#cb2-49" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-50"><a href="#cb2-50" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> predict(<span class="va">self</span>, image):</span>
<span id="cb2-51"><a href="#cb2-51" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Preprocess the image and run inference"""</span></span>
<span id="cb2-52"><a href="#cb2-52" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Preprocess image</span></span>
<span id="cb2-53"><a href="#cb2-53" aria-hidden="true" tabindex="-1"></a>        input_tensor <span class="op">=</span> <span class="va">self</span>.transform(image).unsqueeze(<span class="dv">0</span>)</span>
<span id="cb2-54"><a href="#cb2-54" aria-hidden="true" tabindex="-1"></a>        input_tensor <span class="op">=</span> input_tensor.to(<span class="bu">next</span>(<span class="va">self</span>.model.parameters()).device)</span>
<span id="cb2-55"><a href="#cb2-55" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-56"><a href="#cb2-56" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Run inference</span></span>
<span id="cb2-57"><a href="#cb2-57" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> torch.no_grad():</span>
<span id="cb2-58"><a href="#cb2-58" aria-hidden="true" tabindex="-1"></a>            outputs <span class="op">=</span> <span class="va">self</span>.model(input_tensor)</span>
<span id="cb2-59"><a href="#cb2-59" aria-hidden="true" tabindex="-1"></a>            probabilities <span class="op">=</span> torch.nn.functional.softmax(outputs[<span class="dv">0</span>], dim<span class="op">=</span><span class="dv">0</span>)</span>
<span id="cb2-60"><a href="#cb2-60" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-61"><a href="#cb2-61" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> probabilities</span>
<span id="cb2-62"><a href="#cb2-62" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-63"><a href="#cb2-63" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> encode_response(<span class="va">self</span>, probabilities):</span>
<span id="cb2-64"><a href="#cb2-64" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Format the response"""</span></span>
<span id="cb2-65"><a href="#cb2-65" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Get top 5 predictions</span></span>
<span id="cb2-66"><a href="#cb2-66" aria-hidden="true" tabindex="-1"></a>        top5_prob, top5_indices <span class="op">=</span> torch.topk(probabilities, <span class="dv">5</span>)</span>
<span id="cb2-67"><a href="#cb2-67" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-68"><a href="#cb2-68" aria-hidden="true" tabindex="-1"></a>        results <span class="op">=</span> []</span>
<span id="cb2-69"><a href="#cb2-69" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">5</span>):</span>
<span id="cb2-70"><a href="#cb2-70" aria-hidden="true" tabindex="-1"></a>            idx <span class="op">=</span> top5_indices[i].item()</span>
<span id="cb2-71"><a href="#cb2-71" aria-hidden="true" tabindex="-1"></a>            prob <span class="op">=</span> top5_prob[i].item()</span>
<span id="cb2-72"><a href="#cb2-72" aria-hidden="true" tabindex="-1"></a>            label <span class="op">=</span> <span class="va">self</span>.class_labels[idx] <span class="cf">if</span> idx <span class="op">&lt;</span> <span class="bu">len</span>(<span class="va">self</span>.class_labels) <span class="cf">else</span> <span class="ss">f"class_</span><span class="sc">{</span>idx<span class="sc">}</span><span class="ss">"</span></span>
<span id="cb2-73"><a href="#cb2-73" aria-hidden="true" tabindex="-1"></a>            results.append({</span>
<span id="cb2-74"><a href="#cb2-74" aria-hidden="true" tabindex="-1"></a>                <span class="st">"class"</span>: label,</span>
<span id="cb2-75"><a href="#cb2-75" aria-hidden="true" tabindex="-1"></a>                <span class="st">"confidence"</span>: <span class="bu">round</span>(prob, <span class="dv">4</span>),</span>
<span id="cb2-76"><a href="#cb2-76" aria-hidden="true" tabindex="-1"></a>                <span class="st">"class_id"</span>: idx</span>
<span id="cb2-77"><a href="#cb2-77" aria-hidden="true" tabindex="-1"></a>            })</span>
<span id="cb2-78"><a href="#cb2-78" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-79"><a href="#cb2-79" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {</span>
<span id="cb2-80"><a href="#cb2-80" aria-hidden="true" tabindex="-1"></a>            <span class="st">"predictions"</span>: results,</span>
<span id="cb2-81"><a href="#cb2-81" aria-hidden="true" tabindex="-1"></a>            <span class="st">"model"</span>: <span class="st">"mobilenet_v2"</span></span>
<span id="cb2-82"><a href="#cb2-82" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb2-83"><a href="#cb2-83" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-84"><a href="#cb2-84" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-85"><a href="#cb2-85" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb2-86"><a href="#cb2-86" aria-hidden="true" tabindex="-1"></a>    <span class="co"># predict() handles one image at a time, so batching stays off (max_batch_size defaults to 1)</span></span>
<span id="cb2-87"><a href="#cb2-87" aria-hidden="true" tabindex="-1"></a>    api <span class="op">=</span> MobileNetV2API()</span>
<span id="cb2-88"><a href="#cb2-88" aria-hidden="true" tabindex="-1"></a>    server <span class="op">=</span> ls.LitServer(api, accelerator<span class="op">=</span><span class="st">"auto"</span>)</span>
<span id="cb2-89"><a href="#cb2-89" aria-hidden="true" tabindex="-1"></a>    server.run(port<span class="op">=</span><span class="dv">8000</span>)</span></code></pre></div></div>
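<p>To see what <code>predict</code> and <code>encode_response</code> do to the raw model output without loading PyTorch, here is a minimal standard-library sketch. It assumes the logits arrive as a plain Python list; <code>softmax</code> and <code>format_top_k</code> are hypothetical helpers standing in for <code>torch.nn.functional.softmax</code> and <code>torch.topk</code>, and the <code>class_{idx}</code> fallback mirrors the server code above.</p>

```python
import heapq
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats (stand-in for torch's softmax)."""
    peak = max(logits)  # subtract the max so math.exp never overflows
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def format_top_k(probabilities, class_labels, k=5):
    """Hypothetical helper mirroring encode_response: top-k classes, highest first."""
    # heapq.nlargest plays the role of torch.topk on (index, probability) pairs
    top = heapq.nlargest(k, enumerate(probabilities), key=lambda pair: pair[1])
    results = []
    for idx, prob in top:
        # Same fallback as the server: indices without a known label get a generic name
        label = class_labels[idx] if len(class_labels) > idx else f"class_{idx}"
        results.append({"class": label, "confidence": round(prob, 4), "class_id": idx})
    return {"predictions": results, "model": "mobilenet_v2"}


# Toy example: 6 logits but only 3 known labels
response = format_top_k(softmax([2.0, 1.0, 0.1, 3.0, -1.0, 0.5]), ["cat", "dog", "fox"], k=3)
for pred in response["predictions"]:
    print(pred)
```

<p>Because softmax is monotonic, taking the top-k of the probabilities selects the same classes as taking the top-k of the raw logits; the softmax only changes the reported confidence values.</p>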
</section>
<section id="client-code-for-testing" class="level3">
<h3 class="anchored" data-anchor-id="client-code-for-testing" id="client-code-for-testing">Client Code for Testing</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="co"># client.py</span></span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> requests</span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> base64</span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> PIL <span class="im">import</span> Image</span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> io</span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> encode_image(image_path):</span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Encode image to base64"""</span></span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> <span class="bu">open</span>(image_path, <span class="st">"rb"</span>) <span class="im">as</span> image_file:</span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> base64.b64encode(image_file.read()).decode(<span class="st">'utf-8'</span>)</span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> test_image_classification(image_path, server_url<span class="op">=</span><span class="st">"http://localhost:8000/predict"</span>):</span>
<span id="cb3-15"><a href="#cb3-15" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Test the MobileNetV2 API"""</span></span>
<span id="cb3-16"><a href="#cb3-16" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Method 1: Send as base64 in JSON</span></span>
<span id="cb3-17"><a href="#cb3-17" aria-hidden="true" tabindex="-1"></a>    encoded_image <span class="op">=</span> encode_image(image_path)</span>
<span id="cb3-18"><a href="#cb3-18" aria-hidden="true" tabindex="-1"></a>    response <span class="op">=</span> requests.post(</span>
<span id="cb3-19"><a href="#cb3-19" aria-hidden="true" tabindex="-1"></a>        server_url,</span>
<span id="cb3-20"><a href="#cb3-20" aria-hidden="true" tabindex="-1"></a>        json<span class="op">=</span>{<span class="st">"image"</span>: encoded_image}</span>
<span id="cb3-21"><a href="#cb3-21" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb3-22"><a href="#cb3-22" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-23"><a href="#cb3-23" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> response.status_code <span class="op">==</span> <span class="dv">200</span>:</span>
<span id="cb3-24"><a href="#cb3-24" aria-hidden="true" tabindex="-1"></a>        result <span class="op">=</span> response.json()</span>
<span id="cb3-25"><a href="#cb3-25" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="st">"Top 5 Predictions:"</span>)</span>
<span id="cb3-26"><a href="#cb3-26" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> pred <span class="kw">in</span> result[<span class="st">"predictions"</span>]:</span>
<span id="cb3-27"><a href="#cb3-27" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f"  </span><span class="sc">{</span>pred[<span class="st">'class'</span>]<span class="sc">}</span><span class="ss">: </span><span class="sc">{</span>pred[<span class="st">'confidence'</span>]<span class="sc">:.4f}</span><span class="ss">"</span>)</span>
<span id="cb3-28"><a href="#cb3-28" aria-hidden="true" tabindex="-1"></a>    <span class="cf">else</span>:</span>
<span id="cb3-29"><a href="#cb3-29" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Error: </span><span class="sc">{</span>response<span class="sc">.</span>status_code<span class="sc">}</span><span class="ss"> - </span><span class="sc">{</span>response<span class="sc">.</span>text<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb3-30"><a href="#cb3-30" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-31"><a href="#cb3-31" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-32"><a href="#cb3-32" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> test_direct_upload(image_path, server_url<span class="op">=</span><span class="st">"http://localhost:8000/predict"</span>):</span>
<span id="cb3-33"><a href="#cb3-33" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Test with direct image upload"""</span></span>
<span id="cb3-34"><a href="#cb3-34" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> <span class="bu">open</span>(image_path, <span class="st">'rb'</span>) <span class="im">as</span> f:</span>
<span id="cb3-35"><a href="#cb3-35" aria-hidden="true" tabindex="-1"></a>        files <span class="op">=</span> {<span class="st">'file'</span>: f}</span>
<span id="cb3-36"><a href="#cb3-36" aria-hidden="true" tabindex="-1"></a>        response <span class="op">=</span> requests.post(server_url, files<span class="op">=</span>files)</span>
<span id="cb3-37"><a href="#cb3-37" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-38"><a href="#cb3-38" aria-hidden="true" tabindex="-1"></a>    response.raise_for_status()  <span class="co"># surface HTTP errors instead of failing silently</span></span>
<span id="cb3-39"><a href="#cb3-39" aria-hidden="true" tabindex="-1"></a>    result <span class="op">=</span> response.json()</span>
<span id="cb3-40"><a href="#cb3-40" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="st">"Predictions:"</span>, result)</span>
<span id="cb3-41"><a href="#cb3-41" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-42"><a href="#cb3-42" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-43"><a href="#cb3-43" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb3-44"><a href="#cb3-44" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Test with a sample image</span></span>
<span id="cb3-45"><a href="#cb3-45" aria-hidden="true" tabindex="-1"></a>    test_image_classification(<span class="st">"sample_image.jpg"</span>)</span></code></pre></div></div>
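<p>The batched server in the Advanced Features section also accepts several images in one request under an <code>"images"</code> key. As a sketch that assumes only the standard library (the helper names <code>build_payload</code> and <code>decode_payload</code> are hypothetical), this is how a client could build either payload shape and how the server side recovers the original bytes:</p>

```python
import base64
import json


def build_payload(image_blobs):
    """Hypothetical client helper: one image -> {"image": ...}, several -> {"images": [...]}."""
    encoded = [base64.b64encode(blob).decode("utf-8") for blob in image_blobs]
    if len(encoded) == 1:
        return {"image": encoded[0]}
    return {"images": encoded}


def decode_payload(body):
    """Hypothetical server-side inverse: return the raw bytes for every image in the body."""
    data = json.loads(body)
    items = data["images"] if "images" in data else [data["image"]]
    return [base64.b64decode(item) for item in items]


# Round trip with two fake "files"; a real client would read them with open(path, "rb")
blobs = [b"\x89PNG fake image one", b"\xff\xd8\xff fake image two"]
body = json.dumps(build_payload(blobs))
print(sorted(json.loads(body)))  # the top-level JSON keys that were sent
print(decode_payload(body) == blobs)
```

<p>With <code>requests</code>, the dictionary returned by <code>build_payload</code> can be passed directly as <code>json=</code>, just as <code>test_image_classification</code> does. Keep in mind that base64 encodes every 3 bytes as 4 characters, so JSON payloads are roughly a third larger than the raw image files.</p>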
</section>
</section>
<section id="advanced-features" class="level2">
<h2 class="anchored" data-anchor-id="advanced-features" id="advanced-features">Advanced Features</h2>
<section id="batch-processing-with-custom-batching" class="level3">
<h3 class="anchored" data-anchor-id="batch-processing-with-custom-batching" id="batch-processing-with-custom-batching">Batch Processing with Custom Batching</h3>
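<p>The core of the <code>batch</code> and <code>unbatch</code> methods in the class that follows is simple bookkeeping: concatenate every request's images into one list for a single forward pass, remember how many images each request contributed, then slice the results back apart in the same order. A framework-free sketch of just that step (the function names are hypothetical, and an uppercase mapping stands in for the model):</p>

```python
def flatten_requests(decoded_requests):
    """Concatenate per-request item lists and record how many items each request had."""
    all_items, sizes = [], []
    for items in decoded_requests:
        all_items.extend(items)
        sizes.append(len(items))
    return all_items, sizes


def split_results(results, sizes):
    """Slice a flat list of results back into one chunk per original request."""
    chunks, start = [], 0
    for size in sizes:
        chunks.append(results[start:start + size])
        start += size
    return chunks


# Three requests: a single image, a batch of three, and another single image
requests_in = [["a"], ["b", "c", "d"], ["e"]]
flat, sizes = flatten_requests(requests_in)
outputs = [item.upper() for item in flat]  # stand-in for one batched model call
print(split_results(outputs, sizes))
```

<p>The slicing is only correct because the model preserves order: row <em>i</em> of the output tensor belongs to image <em>i</em> of the stacked input, so consecutive slices map back to the requests that produced them.</p>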
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="co"># advanced_mobilenet_api.py</span></span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torchvision.transforms <span class="im">as</span> transforms</span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torchvision <span class="im">import</span> models</span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> PIL <span class="im">import</span> Image</span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> litserve <span class="im">as</span> ls</span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> io</span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> typing <span class="im">import</span> List, Any</span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> BatchedMobileNetV2API(ls.LitAPI):</span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> setup(<span class="va">self</span>, device):</span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Initialize model with batch processing capabilities"""</span></span>
<span id="cb4-14"><a href="#cb4-14" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> models.mobilenet_v2(weights<span class="op">=</span>models.MobileNet_V2_Weights.IMAGENET1K_V1)</span>
<span id="cb4-15"><a href="#cb4-15" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model.<span class="bu">eval</span>()</span>
<span id="cb4-16"><a href="#cb4-16" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model.to(device)</span>
<span id="cb4-17"><a href="#cb4-17" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.device <span class="op">=</span> device</span>
<span id="cb4-18"><a href="#cb4-18" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-19"><a href="#cb4-19" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Optimized transform for batch processing</span></span>
<span id="cb4-20"><a href="#cb4-20" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.transform <span class="op">=</span> transforms.Compose([</span>
<span id="cb4-21"><a href="#cb4-21" aria-hidden="true" tabindex="-1"></a>            transforms.Resize(<span class="dv">256</span>),</span>
<span id="cb4-22"><a href="#cb4-22" aria-hidden="true" tabindex="-1"></a>            transforms.CenterCrop(<span class="dv">224</span>),</span>
<span id="cb4-23"><a href="#cb4-23" aria-hidden="true" tabindex="-1"></a>            transforms.ToTensor(),</span>
<span id="cb4-24"><a href="#cb4-24" aria-hidden="true" tabindex="-1"></a>            transforms.Normalize(</span>
<span id="cb4-25"><a href="#cb4-25" aria-hidden="true" tabindex="-1"></a>                mean<span class="op">=</span>[<span class="fl">0.485</span>, <span class="fl">0.456</span>, <span class="fl">0.406</span>],</span>
<span id="cb4-26"><a href="#cb4-26" aria-hidden="true" tabindex="-1"></a>                std<span class="op">=</span>[<span class="fl">0.229</span>, <span class="fl">0.224</span>, <span class="fl">0.225</span>]</span>
<span id="cb4-27"><a href="#cb4-27" aria-hidden="true" tabindex="-1"></a>            )</span>
<span id="cb4-28"><a href="#cb4-28" aria-hidden="true" tabindex="-1"></a>        ])</span>
<span id="cb4-29"><a href="#cb4-29" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-30"><a href="#cb4-30" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Load class labels from file or define them</span></span>
<span id="cb4-31"><a href="#cb4-31" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.load_imagenet_labels()</span>
<span id="cb4-32"><a href="#cb4-32" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-33"><a href="#cb4-33" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> load_imagenet_labels(<span class="va">self</span>):</span>
<span id="cb4-34"><a href="#cb4-34" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Load ImageNet class labels"""</span></span>
<span id="cb4-35"><a href="#cb4-35" aria-hidden="true" tabindex="-1"></a>        <span class="co"># You can download from: https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt</span></span>
<span id="cb4-36"><a href="#cb4-36" aria-hidden="true" tabindex="-1"></a>        <span class="cf">try</span>:</span>
<span id="cb4-37"><a href="#cb4-37" aria-hidden="true" tabindex="-1"></a>            <span class="cf">with</span> <span class="bu">open</span>(<span class="st">'imagenet_classes.txt'</span>, <span class="st">'r'</span>) <span class="im">as</span> f:</span>
<span id="cb4-38"><a href="#cb4-38" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.class_labels <span class="op">=</span> [line.strip() <span class="cf">for</span> line <span class="kw">in</span> f.readlines()]</span>
<span id="cb4-39"><a href="#cb4-39" aria-hidden="true" tabindex="-1"></a>        <span class="cf">except</span> <span class="pp">FileNotFoundError</span>:</span>
<span id="cb4-40"><a href="#cb4-40" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Fall back to generic placeholder names for all 1000 ImageNet classes</span></span>
<span id="cb4-41"><a href="#cb4-41" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.class_labels <span class="op">=</span> [<span class="ss">f"class_</span><span class="sc">{</span>i<span class="sc">}</span><span class="ss">"</span> <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">1000</span>)]</span>
<span id="cb4-42"><a href="#cb4-42" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-43"><a href="#cb4-43" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> decode_request(<span class="va">self</span>, request):</span>
<span id="cb4-44"><a href="#cb4-44" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Handle both single images and batch requests"""</span></span>
<span id="cb4-45"><a href="#cb4-45" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="bu">isinstance</span>(request, <span class="bu">dict</span>):</span>
<span id="cb4-46"><a href="#cb4-46" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="st">"images"</span> <span class="kw">in</span> request:  <span class="co"># Batch request</span></span>
<span id="cb4-47"><a href="#cb4-47" aria-hidden="true" tabindex="-1"></a>                images <span class="op">=</span> []</span>
<span id="cb4-48"><a href="#cb4-48" aria-hidden="true" tabindex="-1"></a>                <span class="cf">for</span> img_data <span class="kw">in</span> request[<span class="st">"images"</span>]:</span>
<span id="cb4-49"><a href="#cb4-49" aria-hidden="true" tabindex="-1"></a>                    <span class="cf">if</span> <span class="bu">isinstance</span>(img_data, <span class="bu">str</span>):  <span class="co"># base64</span></span>
<span id="cb4-50"><a href="#cb4-50" aria-hidden="true" tabindex="-1"></a>                        <span class="im">import</span> base64</span>
<span id="cb4-51"><a href="#cb4-51" aria-hidden="true" tabindex="-1"></a>                        image_data <span class="op">=</span> base64.b64decode(img_data)</span>
<span id="cb4-52"><a href="#cb4-52" aria-hidden="true" tabindex="-1"></a>                        image <span class="op">=</span> Image.<span class="bu">open</span>(io.BytesIO(image_data)).convert(<span class="st">'RGB'</span>)</span>
<span id="cb4-53"><a href="#cb4-53" aria-hidden="true" tabindex="-1"></a>                    <span class="cf">else</span>:  <span class="co"># direct bytes</span></span>
<span id="cb4-54"><a href="#cb4-54" aria-hidden="true" tabindex="-1"></a>                        image <span class="op">=</span> Image.<span class="bu">open</span>(io.BytesIO(img_data)).convert(<span class="st">'RGB'</span>)</span>
<span id="cb4-55"><a href="#cb4-55" aria-hidden="true" tabindex="-1"></a>                    images.append(image)</span>
<span id="cb4-56"><a href="#cb4-56" aria-hidden="true" tabindex="-1"></a>                <span class="cf">return</span> images</span>
<span id="cb4-57"><a href="#cb4-57" aria-hidden="true" tabindex="-1"></a>            <span class="cf">elif</span> <span class="st">"image"</span> <span class="kw">in</span> request:  <span class="co"># Single request</span></span>
<span id="cb4-58"><a href="#cb4-58" aria-hidden="true" tabindex="-1"></a>                <span class="im">import</span> base64</span>
<span id="cb4-59"><a href="#cb4-59" aria-hidden="true" tabindex="-1"></a>                image_data <span class="op">=</span> base64.b64decode(request[<span class="st">"image"</span>])</span>
<span id="cb4-60"><a href="#cb4-60" aria-hidden="true" tabindex="-1"></a>                image <span class="op">=</span> Image.<span class="bu">open</span>(io.BytesIO(image_data)).convert(<span class="st">'RGB'</span>)</span>
<span id="cb4-61"><a href="#cb4-61" aria-hidden="true" tabindex="-1"></a>                <span class="cf">return</span> [image]  <span class="co"># Wrap in list for consistent handling</span></span>
<span id="cb4-62"><a href="#cb4-62" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-63"><a href="#cb4-63" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Direct upload</span></span>
<span id="cb4-64"><a href="#cb4-64" aria-hidden="true" tabindex="-1"></a>        image <span class="op">=</span> Image.<span class="bu">open</span>(io.BytesIO(request)).convert(<span class="st">'RGB'</span>)</span>
<span id="cb4-65"><a href="#cb4-65" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> [image]</span>
<span id="cb4-66"><a href="#cb4-66" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-67"><a href="#cb4-67" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> batch(<span class="va">self</span>, inputs: List[Any]) <span class="op">-&gt;</span> <span class="bu">tuple</span>:</span>
<span id="cb4-68"><a href="#cb4-68" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Flatten per-request image lists; return (images, per-request sizes)"""</span></span>
<span id="cb4-69"><a href="#cb4-69" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Flatten all images from all requests</span></span>
<span id="cb4-70"><a href="#cb4-70" aria-hidden="true" tabindex="-1"></a>        all_images <span class="op">=</span> []</span>
<span id="cb4-71"><a href="#cb4-71" aria-hidden="true" tabindex="-1"></a>        batch_sizes <span class="op">=</span> []</span>
<span id="cb4-72"><a href="#cb4-72" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-73"><a href="#cb4-73" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> inp <span class="kw">in</span> inputs:</span>
<span id="cb4-74"><a href="#cb4-74" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="bu">isinstance</span>(inp, <span class="bu">list</span>):</span>
<span id="cb4-75"><a href="#cb4-75" aria-hidden="true" tabindex="-1"></a>                all_images.extend(inp)</span>
<span id="cb4-76"><a href="#cb4-76" aria-hidden="true" tabindex="-1"></a>                batch_sizes.append(<span class="bu">len</span>(inp))</span>
<span id="cb4-77"><a href="#cb4-77" aria-hidden="true" tabindex="-1"></a>            <span class="cf">else</span>:</span>
<span id="cb4-78"><a href="#cb4-78" aria-hidden="true" tabindex="-1"></a>                all_images.append(inp)</span>
<span id="cb4-79"><a href="#cb4-79" aria-hidden="true" tabindex="-1"></a>                batch_sizes.append(<span class="dv">1</span>)</span>
<span id="cb4-80"><a href="#cb4-80" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-81"><a href="#cb4-81" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> all_images, batch_sizes</span>
<span id="cb4-82"><a href="#cb4-82" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-83"><a href="#cb4-83" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> predict(<span class="va">self</span>, batch_data):</span>
<span id="cb4-84"><a href="#cb4-84" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Process batch of images efficiently"""</span></span>
<span id="cb4-85"><a href="#cb4-85" aria-hidden="true" tabindex="-1"></a>        images, batch_sizes <span class="op">=</span> batch_data</span>
<span id="cb4-86"><a href="#cb4-86" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-87"><a href="#cb4-87" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Preprocess all images</span></span>
<span id="cb4-88"><a href="#cb4-88" aria-hidden="true" tabindex="-1"></a>        batch_tensor <span class="op">=</span> torch.stack([</span>
<span id="cb4-89"><a href="#cb4-89" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.transform(img) <span class="cf">for</span> img <span class="kw">in</span> images</span>
<span id="cb4-90"><a href="#cb4-90" aria-hidden="true" tabindex="-1"></a>        ]).to(<span class="va">self</span>.device)</span>
<span id="cb4-91"><a href="#cb4-91" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-92"><a href="#cb4-92" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Run batch inference</span></span>
<span id="cb4-93"><a href="#cb4-93" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> torch.no_grad():</span>
<span id="cb4-94"><a href="#cb4-94" aria-hidden="true" tabindex="-1"></a>            outputs <span class="op">=</span> <span class="va">self</span>.model(batch_tensor)</span>
<span id="cb4-95"><a href="#cb4-95" aria-hidden="true" tabindex="-1"></a>            probabilities <span class="op">=</span> torch.nn.functional.softmax(outputs, dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb4-96"><a href="#cb4-96" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-97"><a href="#cb4-97" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> probabilities, batch_sizes</span>
<span id="cb4-98"><a href="#cb4-98" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-99"><a href="#cb4-99" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> unbatch(<span class="va">self</span>, output):</span>
<span id="cb4-100"><a href="#cb4-100" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Split batch results back to individual responses"""</span></span>
<span id="cb4-101"><a href="#cb4-101" aria-hidden="true" tabindex="-1"></a>        probabilities, batch_sizes <span class="op">=</span> output</span>
<span id="cb4-102"><a href="#cb4-102" aria-hidden="true" tabindex="-1"></a>        results <span class="op">=</span> []</span>
<span id="cb4-103"><a href="#cb4-103" aria-hidden="true" tabindex="-1"></a>        start_idx <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb4-104"><a href="#cb4-104" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-105"><a href="#cb4-105" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> batch_size <span class="kw">in</span> batch_sizes:</span>
<span id="cb4-106"><a href="#cb4-106" aria-hidden="true" tabindex="-1"></a>            batch_probs <span class="op">=</span> probabilities[start_idx:start_idx <span class="op">+</span> batch_size]</span>
<span id="cb4-107"><a href="#cb4-107" aria-hidden="true" tabindex="-1"></a>            results.append(batch_probs)</span>
<span id="cb4-108"><a href="#cb4-108" aria-hidden="true" tabindex="-1"></a>            start_idx <span class="op">+=</span> batch_size</span>
<span id="cb4-109"><a href="#cb4-109" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-110"><a href="#cb4-110" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> results</span>
<span id="cb4-111"><a href="#cb4-111" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-112"><a href="#cb4-112" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> encode_response(<span class="va">self</span>, probabilities):</span>
<span id="cb4-113"><a href="#cb4-113" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Format response for batch or single predictions"""</span></span>
<span id="cb4-114"><a href="#cb4-114" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="bu">len</span>(probabilities.shape) <span class="op">==</span> <span class="dv">1</span>:  <span class="co"># Single prediction</span></span>
<span id="cb4-115"><a href="#cb4-115" aria-hidden="true" tabindex="-1"></a>            probabilities <span class="op">=</span> probabilities.unsqueeze(<span class="dv">0</span>)</span>
<span id="cb4-116"><a href="#cb4-116" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-117"><a href="#cb4-117" aria-hidden="true" tabindex="-1"></a>        all_results <span class="op">=</span> []</span>
<span id="cb4-118"><a href="#cb4-118" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> prob_vector <span class="kw">in</span> probabilities:</span>
<span id="cb4-119"><a href="#cb4-119" aria-hidden="true" tabindex="-1"></a>            top5_prob, top5_indices <span class="op">=</span> torch.topk(prob_vector, <span class="dv">5</span>)</span>
<span id="cb4-120"><a href="#cb4-120" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb4-121"><a href="#cb4-121" aria-hidden="true" tabindex="-1"></a>            predictions <span class="op">=</span> []</span>
<span id="cb4-122"><a href="#cb4-122" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">5</span>):</span>
<span id="cb4-123"><a href="#cb4-123" aria-hidden="true" tabindex="-1"></a>                idx <span class="op">=</span> top5_indices[i].item()</span>
<span id="cb4-124"><a href="#cb4-124" aria-hidden="true" tabindex="-1"></a>                prob <span class="op">=</span> top5_prob[i].item()</span>
<span id="cb4-125"><a href="#cb4-125" aria-hidden="true" tabindex="-1"></a>                predictions.append({</span>
<span id="cb4-126"><a href="#cb4-126" aria-hidden="true" tabindex="-1"></a>                    <span class="st">"class"</span>: <span class="va">self</span>.class_labels[idx],</span>
<span id="cb4-127"><a href="#cb4-127" aria-hidden="true" tabindex="-1"></a>                    <span class="st">"confidence"</span>: <span class="bu">round</span>(prob, <span class="dv">4</span>),</span>
<span id="cb4-128"><a href="#cb4-128" aria-hidden="true" tabindex="-1"></a>                    <span class="st">"class_id"</span>: idx</span>
<span id="cb4-129"><a href="#cb4-129" aria-hidden="true" tabindex="-1"></a>                })</span>
<span id="cb4-130"><a href="#cb4-130" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb4-131"><a href="#cb4-131" aria-hidden="true" tabindex="-1"></a>            all_results.append(predictions)</span>
<span id="cb4-132"><a href="#cb4-132" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-133"><a href="#cb4-133" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {</span>
<span id="cb4-134"><a href="#cb4-134" aria-hidden="true" tabindex="-1"></a>            <span class="st">"predictions"</span>: all_results[<span class="dv">0</span>] <span class="cf">if</span> <span class="bu">len</span>(all_results) <span class="op">==</span> <span class="dv">1</span> <span class="cf">else</span> all_results,</span>
<span id="cb4-135"><a href="#cb4-135" aria-hidden="true" tabindex="-1"></a>            <span class="st">"model"</span>: <span class="st">"mobilenet_v2"</span>,</span>
<span id="cb4-136"><a href="#cb4-136" aria-hidden="true" tabindex="-1"></a>            <span class="st">"batch_size"</span>: <span class="bu">len</span>(all_results)</span>
<span id="cb4-137"><a href="#cb4-137" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb4-138"><a href="#cb4-138" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-139"><a href="#cb4-139" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-140"><a href="#cb4-140" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb4-141"><a href="#cb4-141" aria-hidden="true" tabindex="-1"></a>    api <span class="op">=</span> BatchedMobileNetV2API()</span>
<span id="cb4-142"><a href="#cb4-142" aria-hidden="true" tabindex="-1"></a>    server <span class="op">=</span> ls.LitServer(</span>
<span id="cb4-143"><a href="#cb4-143" aria-hidden="true" tabindex="-1"></a>        api,</span>
<span id="cb4-144"><a href="#cb4-144" aria-hidden="true" tabindex="-1"></a>        accelerator<span class="op">=</span><span class="st">"auto"</span>,</span>
<span id="cb4-145"><a href="#cb4-145" aria-hidden="true" tabindex="-1"></a>        max_batch_size<span class="op">=</span><span class="dv">8</span>,</span>
<span id="cb4-146"><a href="#cb4-146" aria-hidden="true" tabindex="-1"></a>        batch_timeout<span class="op">=</span><span class="fl">0.1</span>,  <span class="co"># 100ms timeout for batching</span></span>
<span id="cb4-147"><a href="#cb4-147" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb4-148"><a href="#cb4-148" aria-hidden="true" tabindex="-1"></a>    server.run(port<span class="op">=</span><span class="dv">8000</span>, num_api_servers<span class="op">=</span><span class="dv">2</span>)</span></code></pre></div></div>
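<p>The <code>max_batch_size</code> and <code>batch_timeout</code> arguments trade latency for throughput. Roughly, a worker keeps collecting requests until it has either <code>max_batch_size</code> of them or runs out of time, and then runs one forward pass on whatever it has. The sketch below is a simplified, standard-library model of that policy. The function <code>collect_batch</code> is hypothetical and is not LitServe's internal code; it is only meant for reasoning about how the two settings interact.</p>

```python
import queue
import time


def collect_batch(requests, max_batch_size, batch_timeout):
    """Gather up to max_batch_size items from a queue.Queue, waiting at most
    batch_timeout seconds in total. Hypothetical helper that only models the
    batching policy; it is not LitServe's implementation."""
    batch = []
    deadline = time.monotonic() + batch_timeout
    while len(batch) < max_batch_size:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break  # Time is up: run inference on a partial batch
        try:
            batch.append(requests.get(timeout=remaining))
        except queue.Empty:
            break  # No further request arrived before the deadline
    return batch


# Example: 20 queued requests with max_batch_size=8 give batches of 8, 8 and 4.
pending = queue.Queue()
for request_id in range(20):
    pending.put(request_id)
batches = [collect_batch(pending, 8, 0.1) for _ in range(3)]
print([len(b) for b in batches])  # [8, 8, 4]
```

<p>Under light traffic, most batches end on the timeout, so each request can wait up to <code>batch_timeout</code> before inference starts. Under heavy traffic, batches fill immediately and the timeout rarely matters.</p>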
</section>
<section id="adding-model-quantization-for-better-performance" class="level3">
<h3 class="anchored" data-anchor-id="adding-model-quantization-for-better-performance" id="adding-model-quantization-for-better-performance">Adding Model Quantization for Better Performance</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="co"># quantized_mobilenet_api.py</span></span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.quantization</span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torchvision <span class="im">import</span> models, transforms</span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> litserve <span class="im">as</span> ls</span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> QuantizedMobileNetV2API(ls.LitAPI):</span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> setup(<span class="va">self</span>, device):</span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Setup quantized MobileNetV2 for faster inference"""</span></span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Load pre-trained model</span></span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> models.mobilenet_v2(weights<span class="op">=</span>models.MobileNet_V2_Weights.IMAGENET1K_V1)</span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model.<span class="bu">eval</span>()</span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Apply dynamic quantization for CPU inference</span></span>
<span id="cb5-16"><a href="#cb5-16" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> device <span class="op">==</span> <span class="st">"cpu"</span>:</span>
<span id="cb5-17"><a href="#cb5-17" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.model <span class="op">=</span> torch.quantization.quantize_dynamic(</span>
<span id="cb5-18"><a href="#cb5-18" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.model,</span>
<span id="cb5-19"><a href="#cb5-19" aria-hidden="true" tabindex="-1"></a>                {torch.nn.Linear},  <span class="co"># Dynamic quantization supports Linear/RNN layers; Conv2d stays float32</span></span>
<span id="cb5-20"><a href="#cb5-20" aria-hidden="true" tabindex="-1"></a>                dtype<span class="op">=</span>torch.qint8</span>
<span id="cb5-21"><a href="#cb5-21" aria-hidden="true" tabindex="-1"></a>            )</span>
<span id="cb5-22"><a href="#cb5-22" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="st">"Applied dynamic quantization for CPU"</span>)</span>
<span id="cb5-23"><a href="#cb5-23" aria-hidden="true" tabindex="-1"></a>        <span class="cf">else</span>:</span>
<span id="cb5-24"><a href="#cb5-24" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.model.to(device)</span>
<span id="cb5-25"><a href="#cb5-25" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f"Using device: </span><span class="sc">{</span>device<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb5-26"><a href="#cb5-26" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-27"><a href="#cb5-27" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Rest of setup (e.g. loading self.class_labels) as in the batched API</span></span>
<span id="cb5-28"><a href="#cb5-28" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.transform <span class="op">=</span> transforms.Compose([</span>
<span id="cb5-29"><a href="#cb5-29" aria-hidden="true" tabindex="-1"></a>            transforms.Resize(<span class="dv">256</span>),</span>
<span id="cb5-30"><a href="#cb5-30" aria-hidden="true" tabindex="-1"></a>            transforms.CenterCrop(<span class="dv">224</span>),</span>
<span id="cb5-31"><a href="#cb5-31" aria-hidden="true" tabindex="-1"></a>            transforms.ToTensor(),</span>
<span id="cb5-32"><a href="#cb5-32" aria-hidden="true" tabindex="-1"></a>            transforms.Normalize(</span>
<span id="cb5-33"><a href="#cb5-33" aria-hidden="true" tabindex="-1"></a>                mean<span class="op">=</span>[<span class="fl">0.485</span>, <span class="fl">0.456</span>, <span class="fl">0.406</span>],</span>
<span id="cb5-34"><a href="#cb5-34" aria-hidden="true" tabindex="-1"></a>                std<span class="op">=</span>[<span class="fl">0.229</span>, <span class="fl">0.224</span>, <span class="fl">0.225</span>]</span>
<span id="cb5-35"><a href="#cb5-35" aria-hidden="true" tabindex="-1"></a>            )</span>
<span id="cb5-36"><a href="#cb5-36" aria-hidden="true" tabindex="-1"></a>        ])</span>
<span id="cb5-37"><a href="#cb5-37" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-38"><a href="#cb5-38" aria-hidden="true" tabindex="-1"></a>    <span class="co"># ... rest of the methods remain the same</span></span></code></pre></div></div>
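<p>Dynamic quantization stores the weights of the selected layers as 8-bit integers plus a floating-point scale and quantizes activations on the fly, so those weights take roughly a quarter of their float32 memory. The arithmetic behind a symmetric, per-tensor int8 scheme can be sketched in plain Python. This is a simplified illustration (PyTorch's observers differ in details such as the exact scale computation), and the helper names are hypothetical.</p>

```python
def quantize_int8(weights):
    """Symmetric per-tensor int8 quantization: q = clamp(round(w / scale), -128, 127)."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 127 if max_abs > 0 else 1.0  # Map the largest |w| to 127
    q = [max(-128, min(127, round(w / scale))) for w in weights]
    return q, scale


def dequantize_int8(q, scale):
    """Recover approximate float values: w_hat = q * scale."""
    return [v * scale for v in q]


weights = [0.5, -1.27, 0.0, 1.27, 0.003]
q, scale = quantize_int8(weights)
restored = dequantize_int8(q, scale)
max_error = max(abs(w - r) for w, r in zip(weights, restored))
print(q)
print(f"scale={scale:.4f}, max error={max_error:.5f}")  # Error is at most scale / 2
```

<p>Because every weight is rounded to the nearest multiple of <code>scale</code>, its error is bounded by half a quantization step, which is usually small relative to the weights themselves.</p>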
</section>
</section>
<section id="performance-optimization" class="level2">
<h2 class="anchored" data-anchor-id="performance-optimization" id="performance-optimization">Performance Optimization</h2>
<section id="configuration-for-production" class="level3">
<h3 class="anchored" data-anchor-id="configuration-for-production" id="configuration-for-production">Configuration for Production</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="co"># production_config.py</span></span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> litserve <span class="im">as</span> ls</span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> advanced_mobilenet_api <span class="im">import</span> BatchedMobileNetV2API</span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> create_production_server():</span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Create optimized server for production"""</span></span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a>    api <span class="op">=</span> BatchedMobileNetV2API()</span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a>    server <span class="op">=</span> ls.LitServer(</span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a>        api,</span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a>        accelerator<span class="op">=</span><span class="st">"auto"</span>,  <span class="co"># Auto-detect GPU/CPU</span></span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a>        max_batch_size<span class="op">=</span><span class="dv">16</span>,   <span class="co"># Larger batches for throughput</span></span>
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a>        batch_timeout<span class="op">=</span><span class="fl">0.05</span>,  <span class="co"># 50ms batching timeout</span></span>
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a>        workers_per_device<span class="op">=</span><span class="dv">2</span>,  <span class="co"># Multiple workers per GPU</span></span>
<span id="cb6-15"><a href="#cb6-15" aria-hidden="true" tabindex="-1"></a>        timeout<span class="op">=</span><span class="dv">30</span>,          <span class="co"># Request timeout</span></span>
<span id="cb6-16"><a href="#cb6-16" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb6-17"><a href="#cb6-17" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-18"><a href="#cb6-18" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> server</span>
<span id="cb6-19"><a href="#cb6-19" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-20"><a href="#cb6-20" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb6-21"><a href="#cb6-21" aria-hidden="true" tabindex="-1"></a>    server <span class="op">=</span> create_production_server()</span>
<span id="cb6-22"><a href="#cb6-22" aria-hidden="true" tabindex="-1"></a>    server.run(</span>
<span id="cb6-23"><a href="#cb6-23" aria-hidden="true" tabindex="-1"></a>        port<span class="op">=</span><span class="dv">8000</span>,</span>
<span id="cb6-24"><a href="#cb6-24" aria-hidden="true" tabindex="-1"></a>        host<span class="op">=</span><span class="st">"0.0.0.0"</span>,  <span class="co"># Accept external connections</span></span>
<span id="cb6-25"><a href="#cb6-25" aria-hidden="true" tabindex="-1"></a>        num_api_servers<span class="op">=</span><span class="dv">4</span>  <span class="co"># HTTP front-end processes (inference workers: workers_per_device)</span></span>
<span id="cb6-26"><a href="#cb6-26" aria-hidden="true" tabindex="-1"></a>    )</span></code></pre></div></div>
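<p>Hard-coding these values means rebuilding the image every time they need tuning. One option is to read them from environment variables, which Docker Compose and Kubernetes can set per deployment. A minimal sketch, assuming hypothetical variable names such as <code>SERVER_MAX_BATCH_SIZE</code> and using the values above as defaults:</p>

```python
import os


def load_server_config(env=None):
    """Build LitServer keyword arguments from environment variables.

    The SERVER_* variable names are hypothetical; defaults match production_config.py.
    """
    env = os.environ if env is None else env
    config = {
        "accelerator": env.get("SERVER_ACCELERATOR", "auto"),
        "max_batch_size": int(env.get("SERVER_MAX_BATCH_SIZE", "16")),
        "batch_timeout": float(env.get("SERVER_BATCH_TIMEOUT", "0.05")),
        "workers_per_device": int(env.get("SERVER_WORKERS_PER_DEVICE", "2")),
        "timeout": float(env.get("SERVER_REQUEST_TIMEOUT", "30")),
    }
    # Fail fast at startup instead of serving with a nonsensical configuration
    if config["max_batch_size"] < 1 or config["workers_per_device"] < 1:
        raise ValueError("max_batch_size and workers_per_device must be >= 1")
    if config["batch_timeout"] < 0:
        raise ValueError("batch_timeout must be non-negative")
    return config


print(load_server_config({"SERVER_MAX_BATCH_SIZE": "32"}))
```

<p>Inside <code>create_production_server()</code>, the server could then be built with <code>ls.LitServer(api, **load_server_config())</code>, so the same image can run with different batch sizes in different environments.</p>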
</section>
<section id="monitoring-and-logging" class="level3">
<h3 class="anchored" data-anchor-id="monitoring-and-logging" id="monitoring-and-logging">Monitoring and Logging</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="co"># monitored_api.py</span></span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> logging</span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> collections <span class="im">import</span> defaultdict, deque</span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> advanced_mobilenet_api <span class="im">import</span> BatchedMobileNetV2API</span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MonitoredMobileNetV2API(BatchedMobileNetV2API):</span>
<span id="cb7-10"><a href="#cb7-10" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> setup(<span class="va">self</span>, device):</span>
<span id="cb7-11"><a href="#cb7-11" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Setup with monitoring"""</span></span>
<span id="cb7-12"><a href="#cb7-12" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().setup(device)</span>
<span id="cb7-13"><a href="#cb7-13" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-14"><a href="#cb7-14" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Setup logging</span></span>
<span id="cb7-15"><a href="#cb7-15" aria-hidden="true" tabindex="-1"></a>        logging.basicConfig(level<span class="op">=</span>logging.INFO)</span>
<span id="cb7-16"><a href="#cb7-16" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.logger <span class="op">=</span> logging.getLogger(<span class="va">__name__</span>)</span>
<span id="cb7-17"><a href="#cb7-17" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-18"><a href="#cb7-18" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Metrics are per worker process; bounded deques keep memory constant</span></span>
<span id="cb7-19"><a href="#cb7-19" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.metrics <span class="op">=</span> defaultdict(<span class="kw">lambda</span>: deque(maxlen<span class="op">=</span><span class="dv">100</span>))</span>
<span id="cb7-20"><a href="#cb7-20" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.batch_count <span class="op">=</span> <span class="dv">0</span>    <span class="co"># Number of predict() calls (batches)</span></span>
<span id="cb7-21"><a href="#cb7-21" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.request_count <span class="op">=</span> <span class="dv">0</span>  <span class="co"># Number of individual responses</span></span>
<span id="cb7-22"><a href="#cb7-22" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-23"><a href="#cb7-23" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> predict(<span class="va">self</span>, batch_data):</span>
<span id="cb7-24"><a href="#cb7-24" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Predict with timing and metrics"""</span></span>
<span id="cb7-25"><a href="#cb7-25" aria-hidden="true" tabindex="-1"></a>        start_time <span class="op">=</span> time.perf_counter()  <span class="co"># Monotonic, high-resolution clock</span></span>
<span id="cb7-26"><a href="#cb7-26" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-27"><a href="#cb7-27" aria-hidden="true" tabindex="-1"></a>        result <span class="op">=</span> <span class="bu">super</span>().predict(batch_data)</span>
<span id="cb7-28"><a href="#cb7-28" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-29"><a href="#cb7-29" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Log timing</span></span>
<span id="cb7-30"><a href="#cb7-30" aria-hidden="true" tabindex="-1"></a>        inference_time <span class="op">=</span> time.perf_counter() <span class="op">-</span> start_time</span>
<span id="cb7-31"><a href="#cb7-31" aria-hidden="true" tabindex="-1"></a>        batch_size <span class="op">=</span> <span class="bu">len</span>(batch_data)  <span class="co"># Assumes batch() stacks inputs along dim 0</span></span>
<span id="cb7-32"><a href="#cb7-32" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-33"><a href="#cb7-33" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.metrics[<span class="st">'inference_times'</span>].append(inference_time)</span>
<span id="cb7-34"><a href="#cb7-34" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.metrics[<span class="st">'batch_sizes'</span>].append(batch_size)</span>
<span id="cb7-35"><a href="#cb7-35" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.batch_count <span class="op">+=</span> <span class="dv">1</span></span>
<span id="cb7-36"><a href="#cb7-36" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-37"><a href="#cb7-37" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.logger.info(</span>
<span id="cb7-38"><a href="#cb7-38" aria-hidden="true" tabindex="-1"></a>            <span class="ss">f"Processed batch of </span><span class="sc">{</span>batch_size<span class="sc">}</span><span class="ss"> images in </span><span class="sc">{</span>inference_time<span class="sc">:.3f}</span><span class="ss">s "</span></span>
<span id="cb7-39"><a href="#cb7-39" aria-hidden="true" tabindex="-1"></a>            <span class="ss">f"(Total batches: </span><span class="sc">{</span><span class="va">self</span><span class="sc">.</span>batch_count<span class="sc">}</span><span class="ss">)"</span></span>
<span id="cb7-40"><a href="#cb7-40" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb7-41"><a href="#cb7-41" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-42"><a href="#cb7-42" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> result</span>
<span id="cb7-43"><a href="#cb7-43" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-44"><a href="#cb7-44" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> encode_response(<span class="va">self</span>, probabilities):</span>
<span id="cb7-45"><a href="#cb7-45" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Add metrics to response"""</span></span>
<span id="cb7-46"><a href="#cb7-46" aria-hidden="true" tabindex="-1"></a>        response <span class="op">=</span> <span class="bu">super</span>().encode_response(probabilities)</span>
<span id="cb7-47"><a href="#cb7-47" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.request_count <span class="op">+=</span> <span class="dv">1</span>  <span class="co"># Called once per request after unbatching</span></span>
<span id="cb7-48"><a href="#cb7-48" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-49"><a href="#cb7-49" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Add performance metrics every 100 requests</span></span>
<span id="cb7-50"><a href="#cb7-50" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="va">self</span>.request_count <span class="op">%</span> <span class="dv">100</span> <span class="op">==</span> <span class="dv">0</span>:</span>
<span id="cb7-51"><a href="#cb7-51" aria-hidden="true" tabindex="-1"></a>            times <span class="op">=</span> <span class="va">self</span>.metrics[<span class="st">'inference_times'</span>]</span>
<span id="cb7-52"><a href="#cb7-52" aria-hidden="true" tabindex="-1"></a>            avg_time <span class="op">=</span> <span class="bu">sum</span>(times) <span class="op">/</span> <span class="bu">len</span>(times)  <span class="co"># Mean over up to the last 100 batches</span></span>
<span id="cb7-53"><a href="#cb7-53" aria-hidden="true" tabindex="-1"></a>            response[<span class="st">'metrics'</span>] <span class="op">=</span> {</span>
<span id="cb7-54"><a href="#cb7-54" aria-hidden="true" tabindex="-1"></a>                <span class="st">'avg_inference_time'</span>: <span class="bu">round</span>(avg_time, <span class="dv">4</span>),</span>
<span id="cb7-55"><a href="#cb7-55" aria-hidden="true" tabindex="-1"></a>                <span class="st">'total_requests'</span>: <span class="va">self</span>.request_count</span>
<span id="cb7-56"><a href="#cb7-56" aria-hidden="true" tabindex="-1"></a>            }</span>
<span id="cb7-57"><a href="#cb7-57" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-58"><a href="#cb7-58" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> response</span></code></pre></div></div>
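<p>An average hides tail latency, and the slowest requests are usually what users notice. A small helper can turn the recorded <code>inference_times</code> (in seconds) into percentiles with the standard library's <code>statistics.quantiles</code>. The function name and the chosen percentiles are illustrative assumptions, not part of LitServe.</p>

```python
import statistics


def latency_summary(times):
    """Summarize inference times given in seconds; returns milliseconds."""
    data = list(times)
    if len(data) < 2:
        # Percentiles of fewer than two samples are not meaningful (and raise on Python < 3.13)
        return {}
    # n=100 yields 99 cut points: index 49 is the median (p50), index 94 is p95
    cuts = statistics.quantiles(data, n=100, method="inclusive")
    return {
        "count": len(data),
        "mean_ms": round(statistics.fmean(data) * 1000, 2),
        "p50_ms": round(cuts[49] * 1000, 2),
        "p95_ms": round(cuts[94] * 1000, 2),
        "max_ms": round(max(data) * 1000, 2),
    }


recent = [i / 100 for i in range(1, 101)]  # Synthetic timings: 0.01 s ... 1.00 s
print(latency_summary(recent))
```

<p>It accepts any iterable, so it could be called from <code>encode_response</code> in place of the plain average, for example <code>response['metrics'] = latency_summary(self.metrics['inference_times'])</code>.</p>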
</section>
</section>
<section id="deployment" class="level2">
<h2 class="anchored" data-anchor-id="deployment" id="deployment">Deployment</h2>
<section id="docker-deployment" class="level3">
<h3 class="anchored" data-anchor-id="docker-deployment" id="docker-deployment">Docker Deployment</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb8"><pre class="sourceCode dockerfile code-with-copy"><code class="sourceCode dockerfile"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Dockerfile</span></span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a><span class="kw">FROM</span> python:3.11-slim</span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a><span class="kw">WORKDIR</span> /app</span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Install system dependencies</span></span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a><span class="kw">RUN</span> <span class="ex">apt-get</span> update <span class="kw">&amp;&amp;</span> <span class="ex">apt-get</span> install <span class="at">-y</span> <span class="dt">\</span></span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a>    wget <span class="dt">\</span></span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a>    <span class="kw">&amp;&amp;</span> <span class="fu">rm</span> <span class="at">-rf</span> /var/lib/apt/lists/<span class="pp">*</span></span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a><span class="co"># Copy requirements</span></span>
<span id="cb8-12"><a href="#cb8-12" aria-hidden="true" tabindex="-1"></a><span class="kw">COPY</span> requirements.txt .</span>
<span id="cb8-13"><a href="#cb8-13" aria-hidden="true" tabindex="-1"></a><span class="kw">RUN</span> <span class="ex">pip</span> install <span class="at">--no-cache-dir</span> <span class="at">-r</span> requirements.txt</span>
<span id="cb8-14"><a href="#cb8-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-15"><a href="#cb8-15" aria-hidden="true" tabindex="-1"></a><span class="co"># Download ImageNet labels first so this layer stays cached across code changes</span></span>
<span id="cb8-16"><a href="#cb8-16" aria-hidden="true" tabindex="-1"></a><span class="kw">RUN</span> <span class="fu">wget</span> <span class="at">-O</span> imagenet_classes.txt https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt</span>
<span id="cb8-17"><a href="#cb8-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-18"><a href="#cb8-18" aria-hidden="true" tabindex="-1"></a><span class="co"># Copy application code</span></span>
<span id="cb8-19"><a href="#cb8-19" aria-hidden="true" tabindex="-1"></a><span class="kw">COPY</span> . .</span>
<span id="cb8-20"><a href="#cb8-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-21"><a href="#cb8-21" aria-hidden="true" tabindex="-1"></a><span class="kw">EXPOSE</span> 8000</span>
<span id="cb8-22"><a href="#cb8-22" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-23"><a href="#cb8-23" aria-hidden="true" tabindex="-1"></a><span class="kw">CMD</span> [<span class="st">"python"</span>, <span class="st">"production_config.py"</span>]</span></code></pre></div></div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb9"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="co"># docker-compose.yml</span></span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a><span class="fu">version</span><span class="kw">:</span><span class="at"> </span><span class="st">'3.8'</span></span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a><span class="fu">services</span><span class="kw">:</span></span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">mobilenet-api</span><span class="kw">:</span></span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">build</span><span class="kw">:</span><span class="at"> .</span></span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">ports</span><span class="kw">:</span></span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="st">"8000:8000"</span></span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">environment</span><span class="kw">:</span></span>
<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> CUDA_VISIBLE_DEVICES=0</span><span class="co">  # Selects a GPU; exposing GPUs also requires the NVIDIA Container Toolkit</span></span>
<span id="cb9-11"><a href="#cb9-11" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">volumes</span><span class="kw">:</span></span>
<span id="cb9-12"><a href="#cb9-12" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> ./models:/app/models</span><span class="co">  # Optional: for custom models</span></span>
<span id="cb9-13"><a href="#cb9-13" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">restart</span><span class="kw">:</span><span class="at"> unless-stopped</span></span>
<span id="cb9-14"><a href="#cb9-14" aria-hidden="true" tabindex="-1"></a><span class="at">    </span></span>
<span id="cb9-15"><a href="#cb9-15" aria-hidden="true" tabindex="-1"></a><span class="co">  # Optional: Add nginx for load balancing</span></span>
<span id="cb9-16"><a href="#cb9-16" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">nginx</span><span class="kw">:</span></span>
<span id="cb9-17"><a href="#cb9-17" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">image</span><span class="kw">:</span><span class="at"> nginx:alpine</span></span>
<span id="cb9-18"><a href="#cb9-18" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">ports</span><span class="kw">:</span></span>
<span id="cb9-19"><a href="#cb9-19" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="st">"80:80"</span></span>
<span id="cb9-20"><a href="#cb9-20" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">volumes</span><span class="kw">:</span></span>
<span id="cb9-21"><a href="#cb9-21" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> ./nginx.conf:/etc/nginx/nginx.conf</span></span>
<span id="cb9-22"><a href="#cb9-22" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">depends_on</span><span class="kw">:</span></span>
<span id="cb9-23"><a href="#cb9-23" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> mobilenet-api</span></span></code></pre></div></div>
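<p>Neither file tells Docker how to check whether the model server is actually ready to answer. Recent LitServe versions expose a <code>/health</code> route on the serving port (verify this for your installed version). Assuming that route, a standard-library probe script, hypothetically saved as <code>healthcheck.py</code>, could look like this:</p>

```python
import urllib.request

# Assumes LitServe's /health route on the port used in production_config.py
HEALTH_URL = "http://127.0.0.1:8000/health"


def is_healthy(url=HEALTH_URL, timeout=2.0):
    """Return True only if the endpoint answers with HTTP 200 within the timeout."""
    try:
        with urllib.request.urlopen(url, timeout=timeout) as response:
            return response.status == 200
    except OSError:
        # URLError, HTTPError (e.g. 503 while workers load) and timeouts all subclass OSError
        return False
```

<p>The Dockerfile could then declare <code>HEALTHCHECK --start-period=60s CMD ["python", "-c", "import sys, healthcheck; sys.exit(0 if healthcheck.is_healthy() else 1)"]</code>. The start period gives the workers time to download and load the weights before failed checks count, and using Python avoids installing <code>curl</code> into the slim image.</p>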
</section>
<section id="kubernetes-deployment" class="level3">
<h3 class="anchored" data-anchor-id="kubernetes-deployment" id="kubernetes-deployment">Kubernetes Deployment</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb10"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="co"># k8s-deployment.yaml</span></span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a><span class="fu">apiVersion</span><span class="kw">:</span><span class="at"> apps/v1</span></span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a><span class="fu">kind</span><span class="kw">:</span><span class="at"> Deployment</span></span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a><span class="fu">metadata</span><span class="kw">:</span></span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">name</span><span class="kw">:</span><span class="at"> mobilenet-api</span></span>
<span id="cb10-6"><a href="#cb10-6" aria-hidden="true" tabindex="-1"></a><span class="fu">spec</span><span class="kw">:</span></span>
<span id="cb10-7"><a href="#cb10-7" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">replicas</span><span class="kw">:</span><span class="at"> </span><span class="dv">3</span></span>
<span id="cb10-8"><a href="#cb10-8" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">selector</span><span class="kw">:</span></span>
<span id="cb10-9"><a href="#cb10-9" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">matchLabels</span><span class="kw">:</span></span>
<span id="cb10-10"><a href="#cb10-10" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">app</span><span class="kw">:</span><span class="at"> mobilenet-api</span></span>
<span id="cb10-11"><a href="#cb10-11" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">template</span><span class="kw">:</span></span>
<span id="cb10-12"><a href="#cb10-12" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">metadata</span><span class="kw">:</span></span>
<span id="cb10-13"><a href="#cb10-13" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">labels</span><span class="kw">:</span></span>
<span id="cb10-14"><a href="#cb10-14" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">app</span><span class="kw">:</span><span class="at"> mobilenet-api</span></span>
<span id="cb10-15"><a href="#cb10-15" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">spec</span><span class="kw">:</span></span>
<span id="cb10-16"><a href="#cb10-16" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">containers</span><span class="kw">:</span></span>
<span id="cb10-17"><a href="#cb10-17" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> mobilenet-api</span></span>
<span id="cb10-18"><a href="#cb10-18" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">image</span><span class="kw">:</span><span class="at"> your-registry/mobilenet-api:latest</span></span>
<span id="cb10-19"><a href="#cb10-19" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">ports</span><span class="kw">:</span></span>
<span id="cb10-20"><a href="#cb10-20" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="kw">-</span><span class="at"> </span><span class="fu">containerPort</span><span class="kw">:</span><span class="at"> </span><span class="dv">8000</span></span>
<span id="cb10-21"><a href="#cb10-21" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">resources</span><span class="kw">:</span></span>
<span id="cb10-22"><a href="#cb10-22" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="fu">requests</span><span class="kw">:</span></span>
<span id="cb10-23"><a href="#cb10-23" aria-hidden="true" tabindex="-1"></a><span class="at">            </span><span class="fu">memory</span><span class="kw">:</span><span class="at"> </span><span class="st">"1Gi"</span></span>
<span id="cb10-24"><a href="#cb10-24" aria-hidden="true" tabindex="-1"></a><span class="at">            </span><span class="fu">cpu</span><span class="kw">:</span><span class="at"> </span><span class="st">"500m"</span></span>
<span id="cb10-25"><a href="#cb10-25" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="fu">limits</span><span class="kw">:</span></span>
<span id="cb10-26"><a href="#cb10-26" aria-hidden="true" tabindex="-1"></a><span class="at">            </span><span class="fu">memory</span><span class="kw">:</span><span class="at"> </span><span class="st">"2Gi"</span></span>
<span id="cb10-27"><a href="#cb10-27" aria-hidden="true" tabindex="-1"></a><span class="at">            </span><span class="fu">cpu</span><span class="kw">:</span><span class="at"> </span><span class="st">"1000m"</span></span>
<span id="cb10-28"><a href="#cb10-28" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">env</span><span class="kw">:</span></span>
<span id="cb10-29"><a href="#cb10-29" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> WORKERS</span></span>
<span id="cb10-30"><a href="#cb10-30" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="fu">value</span><span class="kw">:</span><span class="at"> </span><span class="st">"2"</span></span>
<span id="cb10-31"><a href="#cb10-31" aria-hidden="true" tabindex="-1"></a><span class="pp">---</span></span>
<span id="cb10-32"><a href="#cb10-32" aria-hidden="true" tabindex="-1"></a><span class="fu">apiVersion</span><span class="kw">:</span><span class="at"> v1</span></span>
<span id="cb10-33"><a href="#cb10-33" aria-hidden="true" tabindex="-1"></a><span class="fu">kind</span><span class="kw">:</span><span class="at"> Service</span></span>
<span id="cb10-34"><a href="#cb10-34" aria-hidden="true" tabindex="-1"></a><span class="fu">metadata</span><span class="kw">:</span></span>
<span id="cb10-35"><a href="#cb10-35" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">name</span><span class="kw">:</span><span class="at"> mobilenet-service</span></span>
<span id="cb10-36"><a href="#cb10-36" aria-hidden="true" tabindex="-1"></a><span class="fu">spec</span><span class="kw">:</span></span>
<span id="cb10-37"><a href="#cb10-37" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">selector</span><span class="kw">:</span></span>
<span id="cb10-38"><a href="#cb10-38" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">app</span><span class="kw">:</span><span class="at"> mobilenet-api</span></span>
<span id="cb10-39"><a href="#cb10-39" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">ports</span><span class="kw">:</span></span>
<span id="cb10-40"><a href="#cb10-40" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="kw">-</span><span class="at"> </span><span class="fu">port</span><span class="kw">:</span><span class="at"> </span><span class="dv">80</span></span>
<span id="cb10-41"><a href="#cb10-41" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">targetPort</span><span class="kw">:</span><span class="at"> </span><span class="dv">8000</span></span>
<span id="cb10-42"><a href="#cb10-42" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">type</span><span class="kw">:</span><span class="at"> LoadBalancer</span></span></code></pre></div></div>
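<p>The <code>resources</code> block applies to each container, and each of the three replicas runs one, so the scheduler has to find room for the requests of every replica while the limits cap what each container may use. The following is a minimal sketch for totalling that footprint. The helper names are hypothetical, and it only understands the <code>m</code> (millicore) CPU suffix and the binary memory suffixes <code>Ki</code>, <code>Mi</code>, <code>Gi</code> and <code>Ti</code>, not the full Kubernetes quantity syntax.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"># Hypothetical helpers: aggregate resources a Deployment asks for across replicas.
# Only a subset of the Kubernetes quantity syntax is handled.

BINARY_SUFFIXES = {"Ki": 2**10, "Mi": 2**20, "Gi": 2**30, "Ti": 2**40}


def parse_cpu(quantity: str) -> float:
    """Convert a CPU quantity such as "500m" or "1.5" to cores."""
    if quantity.endswith("m"):
        return float(quantity[:-1]) / 1000  # millicores -> cores
    return float(quantity)


def parse_memory(quantity: str) -> int:
    """Convert a memory quantity such as "1Gi" or "512Mi" to bytes."""
    for suffix, factor in BINARY_SUFFIXES.items():
        if quantity.endswith(suffix):
            return int(float(quantity[:-2]) * factor)
    # Decimal suffixes (k, M, G, ...) are not handled and raise ValueError here
    return int(quantity)


def deployment_footprint(replicas: int, cpu: str, memory: str):
    """Return (total cores, total GiB) for `replicas` identical single-container pods."""
    return replicas * parse_cpu(cpu), replicas * parse_memory(memory) / 2**30


# Values from the manifest above (3 replicas)
print("requests:", deployment_footprint(3, "500m", "1Gi"))  # requests: (1.5, 3.0)
print("limits:", deployment_footprint(3, "1000m", "2Gi"))  # limits: (3.0, 6.0)</code></pre></div></div>
<p>So this Deployment needs 1.5 cores and 3 GiB of schedulable capacity just to start, and can grow to 3 cores and 6 GiB under load.</p>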
</section>
</section>
<section id="testing" class="level2">
<h2 class="anchored" data-anchor-id="testing" id="testing">Testing</h2>
<section id="comprehensive-test-suite" class="level3">
<h3 class="anchored" data-anchor-id="comprehensive-test-suite" id="comprehensive-test-suite">Comprehensive Test Suite</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb11"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a><span class="co"># test_api.py</span></span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> pytest</span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> requests</span>
<span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> base64</span>
<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> PIL <span class="im">import</span> Image</span>
<span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> io</span>
<span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-9"><a href="#cb11-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-10"><a href="#cb11-10" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> TestMobileNetV2API:</span>
<span id="cb11-11"><a href="#cb11-11" aria-hidden="true" tabindex="-1"></a>    <span class="at">@classmethod</span></span>
<span id="cb11-12"><a href="#cb11-12" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> setup_class(<span class="va">cls</span>):</span>
<span id="cb11-13"><a href="#cb11-13" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Shared setup: pytest calls this once, passing the class, before the tests run"""</span></span>
<span id="cb11-14"><a href="#cb11-14" aria-hidden="true" tabindex="-1"></a>        <span class="va">cls</span>.base_url <span class="op">=</span> <span class="st">"http://localhost:8000"</span></span>
<span id="cb11-15"><a href="#cb11-15" aria-hidden="true" tabindex="-1"></a>        <span class="va">cls</span>.test_image <span class="op">=</span> <span class="va">cls</span>.create_test_image()</span>
<span id="cb11-16"><a href="#cb11-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-17"><a href="#cb11-17" aria-hidden="true" tabindex="-1"></a>    <span class="at">@staticmethod</span></span>
<span id="cb11-18"><a href="#cb11-18" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> create_test_image():</span>
<span id="cb11-19"><a href="#cb11-19" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Return a solid-red 224x224 JPEG image as raw bytes"""</span></span>
<span id="cb11-20"><a href="#cb11-20" aria-hidden="true" tabindex="-1"></a>        img <span class="op">=</span> Image.new(<span class="st">'RGB'</span>, (<span class="dv">224</span>, <span class="dv">224</span>), color<span class="op">=</span><span class="st">'red'</span>)</span>
<span id="cb11-21"><a href="#cb11-21" aria-hidden="true" tabindex="-1"></a>        img_buffer <span class="op">=</span> io.BytesIO()</span>
<span id="cb11-22"><a href="#cb11-22" aria-hidden="true" tabindex="-1"></a>        img.save(img_buffer, <span class="bu">format</span><span class="op">=</span><span class="st">'JPEG'</span>)</span>
<span id="cb11-23"><a href="#cb11-23" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> img_buffer.getvalue()</span>
<span id="cb11-24"><a href="#cb11-24" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-25"><a href="#cb11-25" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> test_single_prediction(<span class="va">self</span>):</span>
<span id="cb11-26"><a href="#cb11-26" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Test single image prediction"""</span></span>
<span id="cb11-27"><a href="#cb11-27" aria-hidden="true" tabindex="-1"></a>        encoded_image <span class="op">=</span> base64.b64encode(<span class="va">self</span>.test_image).decode(<span class="st">'utf-8'</span>)</span>
<span id="cb11-28"><a href="#cb11-28" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-29"><a href="#cb11-29" aria-hidden="true" tabindex="-1"></a>        response <span class="op">=</span> requests.post(</span>
<span id="cb11-30"><a href="#cb11-30" aria-hidden="true" tabindex="-1"></a>            <span class="ss">f"</span><span class="sc">{</span><span class="va">self</span><span class="sc">.</span>base_url<span class="sc">}</span><span class="ss">/predict"</span>,</span>
<span id="cb11-31"><a href="#cb11-31" aria-hidden="true" tabindex="-1"></a>            json<span class="op">=</span>{<span class="st">"image"</span>: encoded_image}, timeout<span class="op">=</span><span class="dv">30</span></span>
<span id="cb11-32"><a href="#cb11-32" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb11-33"><a href="#cb11-33" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-34"><a href="#cb11-34" aria-hidden="true" tabindex="-1"></a>        <span class="cf">assert</span> response.status_code <span class="op">==</span> <span class="dv">200</span></span>
<span id="cb11-35"><a href="#cb11-35" aria-hidden="true" tabindex="-1"></a>        result <span class="op">=</span> response.json()</span>
<span id="cb11-36"><a href="#cb11-36" aria-hidden="true" tabindex="-1"></a>        <span class="cf">assert</span> <span class="st">"predictions"</span> <span class="kw">in</span> result</span>
<span id="cb11-37"><a href="#cb11-37" aria-hidden="true" tabindex="-1"></a>        <span class="cf">assert</span> <span class="bu">len</span>(result[<span class="st">"predictions"</span>]) <span class="op">==</span> <span class="dv">5</span></span>
<span id="cb11-38"><a href="#cb11-38" aria-hidden="true" tabindex="-1"></a>        <span class="cf">assert</span> <span class="bu">all</span>(<span class="st">"class"</span> <span class="kw">in</span> pred <span class="kw">and</span> <span class="st">"confidence"</span> <span class="kw">in</span> pred <span class="cf">for</span> pred <span class="kw">in</span> result[<span class="st">"predictions"</span>])</span>
<span id="cb11-39"><a href="#cb11-39" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-40"><a href="#cb11-40" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> test_batch_prediction(<span class="va">self</span>):</span>
<span id="cb11-41"><a href="#cb11-41" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Test batch prediction"""</span></span>
<span id="cb11-42"><a href="#cb11-42" aria-hidden="true" tabindex="-1"></a>        encoded_image <span class="op">=</span> base64.b64encode(<span class="va">self</span>.test_image).decode(<span class="st">'utf-8'</span>)</span>
<span id="cb11-43"><a href="#cb11-43" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-44"><a href="#cb11-44" aria-hidden="true" tabindex="-1"></a>        response <span class="op">=</span> requests.post(</span>
<span id="cb11-45"><a href="#cb11-45" aria-hidden="true" tabindex="-1"></a>            <span class="ss">f"</span><span class="sc">{</span><span class="va">self</span><span class="sc">.</span>base_url<span class="sc">}</span><span class="ss">/predict"</span>,</span>
<span id="cb11-46"><a href="#cb11-46" aria-hidden="true" tabindex="-1"></a>            json<span class="op">=</span>{<span class="st">"images"</span>: [encoded_image, encoded_image]}, timeout<span class="op">=</span><span class="dv">30</span></span>
<span id="cb11-47"><a href="#cb11-47" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb11-48"><a href="#cb11-48" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-49"><a href="#cb11-49" aria-hidden="true" tabindex="-1"></a>        <span class="cf">assert</span> response.status_code <span class="op">==</span> <span class="dv">200</span></span>
<span id="cb11-50"><a href="#cb11-50" aria-hidden="true" tabindex="-1"></a>        result <span class="op">=</span> response.json()</span>
<span id="cb11-51"><a href="#cb11-51" aria-hidden="true" tabindex="-1"></a>        <span class="cf">assert</span> <span class="st">"batch_size"</span> <span class="kw">in</span> result</span>
<span id="cb11-52"><a href="#cb11-52" aria-hidden="true" tabindex="-1"></a>        <span class="cf">assert</span> result[<span class="st">"batch_size"</span>] <span class="op">==</span> <span class="dv">2</span></span>
<span id="cb11-53"><a href="#cb11-53" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-54"><a href="#cb11-54" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> test_performance(<span class="va">self</span>):</span>
<span id="cb11-55"><a href="#cb11-55" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Test API performance"""</span></span>
<span id="cb11-56"><a href="#cb11-56" aria-hidden="true" tabindex="-1"></a>        <span class="im">import</span> time</span>
<span id="cb11-57"><a href="#cb11-57" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-58"><a href="#cb11-58" aria-hidden="true" tabindex="-1"></a>        encoded_image <span class="op">=</span> base64.b64encode(<span class="va">self</span>.test_image).decode(<span class="st">'utf-8'</span>)</span>
<span id="cb11-59"><a href="#cb11-59" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-60"><a href="#cb11-60" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Warmup</span></span>
<span id="cb11-61"><a href="#cb11-61" aria-hidden="true" tabindex="-1"></a>        requests.post(<span class="ss">f"</span><span class="sc">{</span><span class="va">self</span><span class="sc">.</span>base_url<span class="sc">}</span><span class="ss">/predict"</span>, json<span class="op">=</span>{<span class="st">"image"</span>: encoded_image}, timeout<span class="op">=</span><span class="dv">30</span>)</span>
<span id="cb11-62"><a href="#cb11-62" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-63"><a href="#cb11-63" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Time multiple requests</span></span>
<span id="cb11-64"><a href="#cb11-64" aria-hidden="true" tabindex="-1"></a>        times <span class="op">=</span> []</span>
<span id="cb11-65"><a href="#cb11-65" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">10</span>):</span>
<span id="cb11-66"><a href="#cb11-66" aria-hidden="true" tabindex="-1"></a>            start <span class="op">=</span> time.perf_counter()  <span class="co"># monotonic, high-resolution clock for intervals</span></span>
<span id="cb11-67"><a href="#cb11-67" aria-hidden="true" tabindex="-1"></a>            response <span class="op">=</span> requests.post(<span class="ss">f"</span><span class="sc">{</span><span class="va">self</span><span class="sc">.</span>base_url<span class="sc">}</span><span class="ss">/predict"</span>, json<span class="op">=</span>{<span class="st">"image"</span>: encoded_image}, timeout<span class="op">=</span><span class="dv">30</span>)</span>
<span id="cb11-68"><a href="#cb11-68" aria-hidden="true" tabindex="-1"></a>            times.append(time.perf_counter() <span class="op">-</span> start)</span>
<span id="cb11-69"><a href="#cb11-69" aria-hidden="true" tabindex="-1"></a>            <span class="cf">assert</span> response.status_code <span class="op">==</span> <span class="dv">200</span></span>
<span id="cb11-70"><a href="#cb11-70" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-71"><a href="#cb11-71" aria-hidden="true" tabindex="-1"></a>        avg_time <span class="op">=</span> <span class="bu">sum</span>(times) <span class="op">/</span> <span class="bu">len</span>(times)</span>
<span id="cb11-72"><a href="#cb11-72" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Average response time: </span><span class="sc">{</span>avg_time<span class="sc">:.3f}</span><span class="ss">s"</span>)</span>
<span id="cb11-73"><a href="#cb11-73" aria-hidden="true" tabindex="-1"></a>        <span class="cf">assert</span> avg_time <span class="op">&lt;</span> <span class="fl">1.0</span>  <span class="co"># Should respond within 1 second</span></span>
<span id="cb11-74"><a href="#cb11-74" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-75"><a href="#cb11-75" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-76"><a href="#cb11-76" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb11-77"><a href="#cb11-77" aria-hidden="true" tabindex="-1"></a>    pytest.main([<span class="va">__file__</span>, <span class="st">"-v"</span>])</span></code></pre></div></div>
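<p>These tests talk to a live server, so running them without one produces connection errors rather than meaningful failures. A small standard-library check, sketched below under the assumption that a successful TCP connection to the API's host and port is a good enough "server is up" signal, lets the suite skip cleanly instead. The function name <code>server_is_up</code> is illustrative, not part of LitServe or pytest.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import socket
from urllib.parse import urlparse


def server_is_up(base_url: str, timeout: float = 1.0) -> bool:
    """Return True if a TCP connection to base_url's host and port succeeds.

    This only proves that something is listening; it does not check that
    the model has finished loading.
    """
    parsed = urlparse(base_url)
    port = parsed.port or (443 if parsed.scheme == "https" else 80)
    try:
        with socket.create_connection((parsed.hostname, port), timeout=timeout):
            return True
    except OSError:  # connection refused, host unreachable, or timed out
        return False</code></pre></div></div>
<p>In <code>test_api.py</code>, a module-level <code>pytestmark = pytest.mark.skipif(not server_is_up("http://localhost:8000"), reason="API server not running")</code> would then skip every test when nothing is listening on the port.</p>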
</section>
<section id="load-testing" class="level3">
<h3 class="anchored" data-anchor-id="load-testing" id="load-testing">Load Testing</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb12"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><a href="#cb12-1" aria-hidden="true" tabindex="-1"></a><span class="co"># load_test.py</span></span>
<span id="cb12-2"><a href="#cb12-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> asyncio</span>
<span id="cb12-3"><a href="#cb12-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> aiohttp</span>
<span id="cb12-4"><a href="#cb12-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> base64</span>
<span id="cb12-5"><a href="#cb12-5" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb12-6"><a href="#cb12-6" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> PIL <span class="im">import</span> Image</span>
<span id="cb12-7"><a href="#cb12-7" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> io</span>
<span id="cb12-8"><a href="#cb12-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-9"><a href="#cb12-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-10"><a href="#cb12-10" aria-hidden="true" tabindex="-1"></a><span class="cf">async</span> <span class="kw">def</span> send_request(session, url, data):</span>
<span id="cb12-11"><a href="#cb12-11" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Send one request; return (HTTP status or None on a network error, latency in seconds)"""</span></span>
<span id="cb12-12"><a href="#cb12-12" aria-hidden="true" tabindex="-1"></a>    start <span class="op">=</span> time.perf_counter()</span>
<span id="cb12-13"><a href="#cb12-13" aria-hidden="true" tabindex="-1"></a>    <span class="cf">try</span>:</span>
<span id="cb12-14"><a href="#cb12-14" aria-hidden="true" tabindex="-1"></a>        <span class="cf">async</span> <span class="cf">with</span> session.post(url, json<span class="op">=</span>data) <span class="im">as</span> response:</span>
<span id="cb12-15"><a href="#cb12-15" aria-hidden="true" tabindex="-1"></a>            <span class="cf">await</span> response.read()  <span class="co"># consume the body so the latency covers the full response</span></span>
<span id="cb12-16"><a href="#cb12-16" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> response.status, time.perf_counter() <span class="op">-</span> start</span>
<span id="cb12-17"><a href="#cb12-17" aria-hidden="true" tabindex="-1"></a>    <span class="cf">except</span> (aiohttp.ClientError, asyncio.<span class="pp">TimeoutError</span>):</span>
<span id="cb12-18"><a href="#cb12-18" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">None</span>, time.perf_counter() <span class="op">-</span> start</span>
<span id="cb12-19"><a href="#cb12-19" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-20"><a href="#cb12-20" aria-hidden="true" tabindex="-1"></a><span class="cf">async</span> <span class="kw">def</span> load_test(num_requests<span class="op">=</span><span class="dv">100</span>, concurrent<span class="op">=</span><span class="dv">10</span>):</span>
<span id="cb12-21"><a href="#cb12-21" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Run load test"""</span></span>
<span id="cb12-22"><a href="#cb12-22" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Create test image</span></span>
<span id="cb12-23"><a href="#cb12-23" aria-hidden="true" tabindex="-1"></a>    img <span class="op">=</span> Image.new(<span class="st">'RGB'</span>, (<span class="dv">224</span>, <span class="dv">224</span>), color<span class="op">=</span><span class="st">'blue'</span>)</span>
<span id="cb12-24"><a href="#cb12-24" aria-hidden="true" tabindex="-1"></a>    img_buffer <span class="op">=</span> io.BytesIO()</span>
<span id="cb12-25"><a href="#cb12-25" aria-hidden="true" tabindex="-1"></a>    img.save(img_buffer, <span class="bu">format</span><span class="op">=</span><span class="st">'JPEG'</span>)</span>
<span id="cb12-26"><a href="#cb12-26" aria-hidden="true" tabindex="-1"></a>    encoded_image <span class="op">=</span> base64.b64encode(img_buffer.getvalue()).decode(<span class="st">'utf-8'</span>)</span>
<span id="cb12-27"><a href="#cb12-27" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-28"><a href="#cb12-28" aria-hidden="true" tabindex="-1"></a>    url <span class="op">=</span> <span class="st">"http://localhost:8000/predict"</span></span>
<span id="cb12-29"><a href="#cb12-29" aria-hidden="true" tabindex="-1"></a>    data <span class="op">=</span> {<span class="st">"image"</span>: encoded_image}</span>
<span id="cb12-30"><a href="#cb12-30" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-31"><a href="#cb12-31" aria-hidden="true" tabindex="-1"></a>    start_time <span class="op">=</span> time.time()</span>
<span id="cb12-32"><a href="#cb12-32" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-33"><a href="#cb12-33" aria-hidden="true" tabindex="-1"></a>    <span class="cf">async</span> <span class="cf">with</span> aiohttp.ClientSession() <span class="im">as</span> session:</span>
<span id="cb12-34"><a href="#cb12-34" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Create semaphore to limit concurrent requests</span></span>
<span id="cb12-35"><a href="#cb12-35" aria-hidden="true" tabindex="-1"></a>        semaphore <span class="op">=</span> asyncio.Semaphore(concurrent)</span>
<span id="cb12-36"><a href="#cb12-36" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-37"><a href="#cb12-37" aria-hidden="true" tabindex="-1"></a>        <span class="cf">async</span> <span class="kw">def</span> bounded_request():</span>
<span id="cb12-38"><a href="#cb12-38" aria-hidden="true" tabindex="-1"></a>            <span class="cf">async</span> <span class="cf">with</span> semaphore:</span>
<span id="cb12-39"><a href="#cb12-39" aria-hidden="true" tabindex="-1"></a>                <span class="cf">return</span> <span class="cf">await</span> send_request(session, url, data)</span>
<span id="cb12-40"><a href="#cb12-40" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-41"><a href="#cb12-41" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Send all requests</span></span>
<span id="cb12-42"><a href="#cb12-42" aria-hidden="true" tabindex="-1"></a>        tasks <span class="op">=</span> [bounded_request() <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(num_requests)]</span>
<span id="cb12-43"><a href="#cb12-43" aria-hidden="true" tabindex="-1"></a>        results <span class="op">=</span> <span class="cf">await</span> asyncio.gather(<span class="op">*</span>tasks)</span>
<span id="cb12-44"><a href="#cb12-44" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-45"><a href="#cb12-45" aria-hidden="true" tabindex="-1"></a>    total_time <span class="op">=</span> time.time() <span class="op">-</span> start_time</span>
<span id="cb12-46"><a href="#cb12-46" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-47"><a href="#cb12-47" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Analyze results</span></span>
<span id="cb12-48"><a href="#cb12-48" aria-hidden="true" tabindex="-1"></a>    successful <span class="op">=</span> <span class="bu">sum</span>(<span class="dv">1</span> <span class="cf">for</span> status, _ <span class="kw">in</span> results <span class="cf">if</span> status <span class="op">==</span> <span class="dv">200</span>)</span>
<span id="cb12-49"><a href="#cb12-49" aria-hidden="true" tabindex="-1"></a>    failed <span class="op">=</span> num_requests <span class="op">-</span> successful</span>
<span id="cb12-50"><a href="#cb12-50" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-51"><a href="#cb12-51" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="st">"Load Test Results:"</span>)</span>
<span id="cb12-52"><a href="#cb12-52" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"  Total Requests: </span><span class="sc">{</span>num_requests<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb12-53"><a href="#cb12-53" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"  Successful: </span><span class="sc">{</span>successful<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb12-54"><a href="#cb12-54" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"  Failed: </span><span class="sc">{</span>failed<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb12-55"><a href="#cb12-55" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"  Total Time: </span><span class="sc">{</span>total_time<span class="sc">:.2f}</span><span class="ss">s"</span>)</span>
<span id="cb12-56"><a href="#cb12-56" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"  Requests/sec: </span><span class="sc">{</span>num_requests<span class="op">/</span>total_time<span class="sc">:.2f}</span><span class="ss">"</span>)</span>
<span id="cb12-57"><a href="#cb12-57" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"  Concurrent: </span><span class="sc">{</span>concurrent<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb12-58"><a href="#cb12-58" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-59"><a href="#cb12-59" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-60"><a href="#cb12-60" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb12-61"><a href="#cb12-61" aria-hidden="true" tabindex="-1"></a>    asyncio.run(load_test(num_requests<span class="op">=</span><span class="dv">200</span>, concurrent<span class="op">=</span><span class="dv">20</span>))</span></code></pre></div></div>
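<p>Requests per second alone can hide slow outliers, so it is worth reporting latency percentiles as well. The sketch below assumes you have collected one latency in seconds per request, for example by timing each call with <code>time.perf_counter()</code>; <code>summarize_latencies</code> is a hypothetical helper built on the standard library's <code>statistics</code> module.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import statistics


def summarize_latencies(latencies):
    """Summarize per-request latencies (in seconds) in milliseconds.

    statistics.quantiles needs at least two samples on Python versions before 3.13.
    """
    # n=100 yields the 99 percentile cut points; "inclusive" treats the samples
    # as the whole population of this run, so results stay within [min, max]
    cuts = statistics.quantiles(latencies, n=100, method="inclusive")
    return {
        "mean_ms": 1000 * statistics.fmean(latencies),
        "p50_ms": 1000 * cuts[49],
        "p95_ms": 1000 * cuts[94],
        "p99_ms": 1000 * cuts[98],
    }


# Illustrative input: 100 synthetic latencies from 1 ms to 100 ms
sample = [i / 1000 for i in range(1, 101)]
for name, value in summarize_latencies(sample).items():
    print(f"{name}: {value:.2f}")</code></pre></div></div>
<p>The p95 and p99 values are usually the ones to watch when raising concurrency, since they degrade well before the average does.</p>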
</section>
</section>
<section id="requirements-file" class="level2">
<h2 class="anchored" data-anchor-id="requirements-file" id="requirements-file">Requirements File</h2>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb13"><pre class="sourceCode txt code-with-copy"><code class="sourceCode default"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a># requirements.txt</span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a>litserve&gt;=0.2.0</span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a>torch&gt;=1.9.0</span>
<span id="cb13-4"><a href="#cb13-4" aria-hidden="true" tabindex="-1"></a>torchvision&gt;=0.10.0</span>
<span id="cb13-5"><a href="#cb13-5" aria-hidden="true" tabindex="-1"></a>Pillow&gt;=8.0.0</span>
<span id="cb13-6"><a href="#cb13-6" aria-hidden="true" tabindex="-1"></a>requests&gt;=2.25.0</span>
<span id="cb13-7"><a href="#cb13-7" aria-hidden="true" tabindex="-1"></a>numpy&gt;=1.21.0</span>
<span id="cb13-8"><a href="#cb13-8" aria-hidden="true" tabindex="-1"></a>aiohttp&gt;=3.8.0  # For async testing</span>
<span id="cb13-9"><a href="#cb13-9" aria-hidden="true" tabindex="-1"></a>pytest&gt;=6.0.0  # For testing</span></code></pre></div></div>
</section>
<section id="best-practices" class="level2">
<h2 class="anchored" data-anchor-id="best-practices" id="best-practices">Best Practices</h2>
<ol type="1">
<li><strong>Model Optimization</strong>: Consider quantization and TorchScript export to reduce latency and memory use, and re-check accuracy after quantizing</li>
<li><strong>Batch Processing</strong>: Tune the batch size and batching timeout to your hardware and traffic pattern</li>
<li><strong>Error Handling</strong>: Validate inputs and return clear client errors for malformed requests instead of failing inside the model</li>
<li><strong>Monitoring</strong>: Add logging and collect metrics such as latency, throughput, and error rate</li>
<li><strong>Security</strong>: Require authentication and cap payload size and image dimensions on production endpoints</li>
<li><strong>Caching</strong>: Consider caching predictions for inputs that are requested repeatedly</li>
<li><strong>Scaling</strong>: Run multiple replicas under a container orchestrator such as Kubernetes for high availability</li>
</ol>
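<p>For the caching point, one option is a small in-process LRU cache keyed on a SHA-256 hash of the raw image bytes, sketched below with the hypothetical class name <code>PredictionCache</code>. It only produces hits when clients resend byte-identical images, and each server process keeps its own copy, so a multi-replica deployment would need a shared store such as Redis to share hits across pods.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import hashlib
from collections import OrderedDict


class PredictionCache:
    """Hypothetical in-process LRU cache: SHA-256 of image bytes -> prediction."""

    def __init__(self, max_items=1024):
        self.max_items = max_items
        self._store = OrderedDict()

    @staticmethod
    def _key(image_bytes):
        return hashlib.sha256(image_bytes).hexdigest()

    def get(self, image_bytes):
        """Return the cached prediction, or None on a miss."""
        key = self._key(image_bytes)
        if key not in self._store:
            return None
        self._store.move_to_end(key)  # mark as most recently used
        return self._store[key]

    def put(self, image_bytes, prediction):
        key = self._key(image_bytes)
        self._store[key] = prediction
        self._store.move_to_end(key)
        if len(self._store) > self.max_items:
            self._store.popitem(last=False)  # evict the least recently used entry


# Usage: consult the cache before running the model
cache = PredictionCache(max_items=2)
image_bytes = b"placeholder for raw JPEG bytes"
prediction = cache.get(image_bytes)
if prediction is None:
    prediction = {"predictions": []}  # stand-in for a real model call
    cache.put(image_bytes, prediction)</code></pre></div></div>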
<p>This guide provides a complete foundation for deploying MobileNetV2 with LitServe, from basic implementation to production-ready deployment with monitoring and testing.</p>


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Albumentations vs TorchVision Transforms: Complete Code Guide]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/model-training/albumentations-vs-torchvision/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/model-training/albumentations-vs-torchvision/</guid>
      <pubDate>Tue, 27 May 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>tutorial</category>
      <category>beginner</category>
      <content:encoded><![CDATA[






<section id="albumentations-vs-torchvision-transforms-complete-code-guide" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/model-training/albumentations-vs-torchvision/albumentations.png" class="img-fluid"></p>
<section id="overview" class="level2">
<h2 class="anchored" data-anchor-id="overview" id="overview">Overview</h2>
<p>This guide compares two popular image augmentation libraries for PyTorch:</p>
<ul>
<li><strong>TorchVision Transforms</strong>: the transforms module of <code>torchvision</code>, PyTorch's official companion package for vision, covering common image transformations</li>
<li><strong>Albumentations</strong>: Fast, flexible library with advanced augmentation techniques</li>
</ul>
</section>
<section id="installation" class="level2">
<h2 class="anchored" data-anchor-id="installation" id="installation">Installation</h2>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="co"># TorchVision (a separate package, usually installed alongside PyTorch)</span></span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install torch torchvision</span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Albumentations</span></span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install albumentations</span></code></pre></div></div>
</section>
<section id="basic-setup-and-imports" class="level2">
<h2 class="anchored" data-anchor-id="basic-setup-and-imports" id="basic-setup-and-imports">Basic Setup and Imports</h2>
<div id="a115ff4a" class="cell" data-execution_count="1">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torchvision.transforms <span class="im">as</span> T</span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torchvision.transforms <span class="im">import</span> functional <span class="im">as</span> TF</span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> albumentations <span class="im">as</span> A</span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> albumentations.pytorch <span class="im">import</span> ToTensorV2</span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> cv2</span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> PIL <span class="im">import</span> Image</span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> matplotlib.pyplot <span class="im">as</span> plt</span></code></pre></div></div>
</div>
</section>
<section id="key-differences-at-a-glance" class="level2">
<h2 class="anchored" data-anchor-id="key-differences-at-a-glance" id="key-differences-at-a-glance">Key Differences at a Glance</h2>
<table class="caption-top table">
<colgroup>
<col style="width: 23%">
<col style="width: 34%">
<col style="width: 42%">
</colgroup>
<thead>
<tr class="header">
<th>Feature</th>
<th>TorchVision</th>
<th>Albumentations</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><strong>Input Format</strong></td>
<td>PIL Image, Tensor</td>
<td>NumPy array (OpenCV format)</td>
</tr>
<tr class="even">
<td><strong>Performance</strong></td>
<td>Moderate</td>
<td>Fast (optimized)</td>
</tr>
<tr class="odd">
<td><strong>Augmentation Variety</strong></td>
<td>Basic to intermediate</td>
<td>Extensive advanced options</td>
</tr>
<tr class="even">
<td><strong>Bounding Box Support</strong></td>
<td>Limited (v2 API only)</td>
<td>Excellent</td>
</tr>
<tr class="odd">
<td><strong>Segmentation Masks</strong></td>
<td>Basic</td>
<td>Advanced</td>
</tr>
<tr class="even">
<td><strong>Keypoint Support</strong></td>
<td>No</td>
<td>Yes</td>
</tr>
<tr class="odd">
<td><strong>Probability Control</strong></td>
<td>Limited</td>
<td>Fine-grained</td>
</tr>
</tbody>
</table>
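<p>The "Probability Control" row refers to the fact that every Albumentations transform takes its own <code>p</code> argument, while in TorchVision only some transforms do (for example <code>RandomHorizontalFlip</code>); others, such as <code>ColorJitter</code> or <code>RandomRotation</code>, always run unless wrapped in <code>RandomApply</code>. The pure-Python sketch below models that idea, assuming each pipeline step fires independently with its own probability; <code>maybe_apply</code> and <code>run_pipeline</code> are hypothetical helpers, not part of either library.</p>

```python
import random

def maybe_apply(fn, p, x, rng):
    # Hypothetical helper: apply fn with probability p, otherwise pass x through unchanged
    fires = rng.choices([True, False], weights=[p, 1 - p])[0]
    return fn(x) if fires else x

def run_pipeline(steps, x, rng):
    # steps is a list of (fn, p) pairs; each step is decided independently,
    # mirroring the per-transform p that every Albumentations transform carries
    for fn, p in steps:
        x = maybe_apply(fn, p, x, rng)
    return x

rng = random.Random(0)
steps = [(lambda v: v + 1, 1.0), (lambda v: v * 10, 0.0), (lambda v: v - 3, 0.5)]
results = [run_pipeline(steps, 0, rng) for _ in range(10000)]

# Step 1 always fires, step 2 never fires, step 3 fires roughly half the time
share_step3 = results.count(-2) / len(results)
print(sorted(set(results)), round(share_step3, 2))
```

<p>Only two outcomes are possible here (1 and -2), and their split reflects the <code>p</code> of the third step alone.</p>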
</section>
<section id="basic-transformations-comparison" class="level2">
<h2 class="anchored" data-anchor-id="basic-transformations-comparison" id="basic-transformations-comparison">Basic Transformations Comparison</h2>
<section id="image-loading-and-format-differences" class="level3">
<h3 class="anchored" data-anchor-id="image-loading-and-format-differences" id="image-loading-and-format-differences">Image Loading and Format Differences</h3>
<div id="d3c47867" class="cell" data-execution_count="2">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="co"># TorchVision approach</span></span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> load_image_torchvision(path):</span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> Image.<span class="bu">open</span>(path).convert(<span class="st">'RGB'</span>)</span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Albumentations approach  </span></span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> load_image_albumentations(path):</span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a>    image <span class="op">=</span> cv2.imread(path)  <span class="co"># BGR uint8 array, or None if the file cannot be read</span></span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> image <span class="kw">is</span> <span class="va">None</span>:</span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a>        <span class="cf">raise</span> <span class="pp">FileNotFoundError</span>(path)</span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> cv2.cvtColor(image, cv2.COLOR_BGR2RGB)</span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a><span class="co"># Example usage</span></span>
<span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a>img_path <span class="op">=</span> <span class="st">"cat.jpg"</span></span>
<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a>torch_img <span class="op">=</span> load_image_torchvision(img_path)</span>
<span id="cb3-15"><a href="#cb3-15" aria-hidden="true" tabindex="-1"></a>albu_img <span class="op">=</span> load_image_albumentations(img_path)</span></code></pre></div></div>
</div>
<p><img src="https://theja-vanka.github.io/blogs/posts/model-training/albumentations-vs-torchvision/cat.jpg" class="img-fluid"></p>
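<p>The two loaders differ mainly in channel order: <code>cv2.imread</code> returns a BGR NumPy array, which is why the Albumentations path calls <code>cv2.cvtColor(..., cv2.COLOR_BGR2RGB)</code>. The sketch below reproduces that conversion with plain NumPy on a synthetic array (no image file needed), assuming an HxWx3 uint8 layout; for 3-channel images, reversing the last axis is equivalent to the BGR-to-RGB conversion.</p>

```python
import numpy as np

# Synthetic 2x2 BGR image in which every pixel is pure blue (blue is channel 0 in BGR)
bgr = np.zeros((2, 2, 3), dtype=np.uint8)
bgr[..., 0] = 255

# Reverse the channel axis to get RGB; ascontiguousarray removes the negative stride,
# which torch.from_numpy would otherwise reject
rgb = np.ascontiguousarray(bgr[..., ::-1])

# Both libraries work on HxWxC here; PyTorch models expect CxHxW tensors later on
chw = np.transpose(rgb, (2, 0, 1))

print(rgb[0, 0].tolist())  # [0, 0, 255]: blue now sits in the last channel
print(chw.shape)           # (3, 2, 2)
```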
</section>
<section id="basic-augmentations" class="level3">
<h3 class="anchored" data-anchor-id="basic-augmentations" id="basic-augmentations">Basic Augmentations</h3>
<div id="c0dc26eb" class="cell" data-execution_count="3">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="co"># TorchVision transforms</span></span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a>torchvision_transform <span class="op">=</span> T.Compose([</span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a>    T.Resize((<span class="dv">224</span>, <span class="dv">224</span>)),</span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a>    T.RandomHorizontalFlip(p<span class="op">=</span><span class="fl">0.5</span>),</span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a>    T.RandomRotation(degrees<span class="op">=</span><span class="dv">15</span>),</span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a>    T.ColorJitter(brightness<span class="op">=</span><span class="fl">0.2</span>, contrast<span class="op">=</span><span class="fl">0.2</span>, saturation<span class="op">=</span><span class="fl">0.2</span>, hue<span class="op">=</span><span class="fl">0.1</span>),</span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a>    T.ToTensor(),</span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a>    T.Normalize(mean<span class="op">=</span>[<span class="fl">0.485</span>, <span class="fl">0.456</span>, <span class="fl">0.406</span>], std<span class="op">=</span>[<span class="fl">0.229</span>, <span class="fl">0.224</span>, <span class="fl">0.225</span>])</span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a>])</span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a><span class="co"># Albumentations equivalent</span></span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a>albumentations_transform <span class="op">=</span> A.Compose([</span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a>    A.Resize(<span class="dv">224</span>, <span class="dv">224</span>),</span>
<span id="cb4-14"><a href="#cb4-14" aria-hidden="true" tabindex="-1"></a>    A.HorizontalFlip(p<span class="op">=</span><span class="fl">0.5</span>),</span>
<span id="cb4-15"><a href="#cb4-15" aria-hidden="true" tabindex="-1"></a>    A.Rotate(limit<span class="op">=</span><span class="dv">15</span>, p<span class="op">=</span><span class="fl">1.0</span>),</span>
<span id="cb4-16"><a href="#cb4-16" aria-hidden="true" tabindex="-1"></a>    A.ColorJitter(brightness<span class="op">=</span><span class="fl">0.2</span>, contrast<span class="op">=</span><span class="fl">0.2</span>, saturation<span class="op">=</span><span class="fl">0.2</span>, hue<span class="op">=</span><span class="fl">0.1</span>, p<span class="op">=</span><span class="fl">1.0</span>),</span>
<span id="cb4-17"><a href="#cb4-17" aria-hidden="true" tabindex="-1"></a>    A.Normalize(mean<span class="op">=</span>[<span class="fl">0.485</span>, <span class="fl">0.456</span>, <span class="fl">0.406</span>], std<span class="op">=</span>[<span class="fl">0.229</span>, <span class="fl">0.224</span>, <span class="fl">0.225</span>]),</span>
<span id="cb4-18"><a href="#cb4-18" aria-hidden="true" tabindex="-1"></a>    ToTensorV2()</span>
<span id="cb4-19"><a href="#cb4-19" aria-hidden="true" tabindex="-1"></a>])</span>
<span id="cb4-20"><a href="#cb4-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-21"><a href="#cb4-21" aria-hidden="true" tabindex="-1"></a><span class="co"># Apply transforms</span></span>
<span id="cb4-22"><a href="#cb4-22" aria-hidden="true" tabindex="-1"></a>torch_result <span class="op">=</span> torchvision_transform(torch_img)</span>
<span id="cb4-23"><a href="#cb4-23" aria-hidden="true" tabindex="-1"></a>albu_result <span class="op">=</span> albumentations_transform(image<span class="op">=</span>albu_img)[<span class="st">'image'</span>]</span></code></pre></div></div>
</div>
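<p>The two pipelines order their final steps differently: TorchVision runs <code>ToTensor</code> (HWC uint8 to CHW float in [0, 1]) before <code>Normalize</code>, while Albumentations runs <code>Normalize</code> on the HWC array (with its default <code>max_pixel_value=255.0</code>, which folds the division by 255 into the mean and std) before <code>ToTensorV2</code> transposes it. The NumPy sketch below models only that arithmetic, without calling either library, to show both orders should yield the same numbers.</p>

```python
import numpy as np

mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

rng = np.random.default_rng(0)
img = rng.integers(0, 255, size=(4, 5, 3), dtype=np.uint8, endpoint=True)  # HxWxC uint8

# TorchVision order: ToTensor (transpose to CHW, scale to [0, 1]) then per-channel Normalize
tv = np.transpose(img, (2, 0, 1)).astype(np.float32) / 255.0
tv = (tv - mean[:, None, None]) / std[:, None, None]

# Albumentations order: Normalize on HWC with max_pixel_value=255, then ToTensorV2 (HWC to CHW)
al = (img.astype(np.float32) - mean * 255.0) / (std * 255.0)
al = np.transpose(al, (2, 0, 1))

print(tv.shape, np.allclose(tv, al, atol=1e-5))
```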
</section>
</section>
<section id="advanced-augmentations" class="level2">
<h2 class="anchored" data-anchor-id="advanced-augmentations" id="advanced-augmentations">Advanced Augmentations</h2>
<section id="albumentations-exclusive-features" class="level3">
<h3 class="anchored" data-anchor-id="albumentations-exclusive-features" id="albumentations-exclusive-features">Albumentations Exclusive Features</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Advanced geometric transformations</span></span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a>advanced_geometric <span class="op">=</span> A.Compose([</span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a>    A.ShiftScaleRotate(shift_limit<span class="op">=</span><span class="fl">0.1</span>, scale_limit<span class="op">=</span><span class="fl">0.2</span>, rotate_limit<span class="op">=</span><span class="dv">30</span>, p<span class="op">=</span><span class="fl">0.8</span>),</span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a>    A.ElasticTransform(alpha<span class="op">=</span><span class="dv">1</span>, sigma<span class="op">=</span><span class="dv">50</span>, p<span class="op">=</span><span class="fl">0.3</span>),</span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a>    A.GridDistortion(num_steps<span class="op">=</span><span class="dv">5</span>, distort_limit<span class="op">=</span><span class="fl">0.3</span>, p<span class="op">=</span><span class="fl">0.3</span>),</span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a>    A.OpticalDistortion(distort_limit<span class="op">=</span><span class="fl">0.5</span>, shift_limit<span class="op">=</span><span class="fl">0.5</span>, p<span class="op">=</span><span class="fl">0.3</span>)</span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a>])</span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a><span class="co"># Weather and lighting effects (Albumentations 1.x argument names; later releases replace the *_lower/*_upper pairs with *_range)</span></span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a>weather_effects <span class="op">=</span> A.Compose([</span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a>    A.RandomRain(slant_lower<span class="op">=-</span><span class="dv">10</span>, slant_upper<span class="op">=</span><span class="dv">10</span>, drop_length<span class="op">=</span><span class="dv">20</span>, p<span class="op">=</span><span class="fl">0.3</span>),</span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a>    A.RandomSnow(snow_point_lower<span class="op">=</span><span class="fl">0.1</span>, snow_point_upper<span class="op">=</span><span class="fl">0.3</span>, p<span class="op">=</span><span class="fl">0.3</span>),</span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a>    A.RandomFog(fog_coef_lower<span class="op">=</span><span class="fl">0.3</span>, fog_coef_upper<span class="op">=</span><span class="dv">1</span>, p<span class="op">=</span><span class="fl">0.3</span>),</span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a>    A.RandomSunFlare(flare_roi<span class="op">=</span>(<span class="dv">0</span>, <span class="dv">0</span>, <span class="dv">1</span>, <span class="fl">0.5</span>), angle_lower<span class="op">=</span><span class="dv">0</span>, p<span class="op">=</span><span class="fl">0.3</span>)</span>
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a>])</span>
<span id="cb5-16"><a href="#cb5-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-17"><a href="#cb5-17" aria-hidden="true" tabindex="-1"></a><span class="co"># Noise and blur effects</span></span>
<span id="cb5-18"><a href="#cb5-18" aria-hidden="true" tabindex="-1"></a>noise_blur <span class="op">=</span> A.Compose([</span>
<span id="cb5-19"><a href="#cb5-19" aria-hidden="true" tabindex="-1"></a>    A.GaussNoise(var_limit<span class="op">=</span>(<span class="fl">10.0</span>, <span class="fl">50.0</span>), p<span class="op">=</span><span class="fl">0.3</span>),</span>
<span id="cb5-20"><a href="#cb5-20" aria-hidden="true" tabindex="-1"></a>    A.ISONoise(color_shift<span class="op">=</span>(<span class="fl">0.01</span>, <span class="fl">0.05</span>), intensity<span class="op">=</span>(<span class="fl">0.1</span>, <span class="fl">0.5</span>), p<span class="op">=</span><span class="fl">0.3</span>),</span>
<span id="cb5-21"><a href="#cb5-21" aria-hidden="true" tabindex="-1"></a>    A.MotionBlur(blur_limit<span class="op">=</span><span class="dv">7</span>, p<span class="op">=</span><span class="fl">0.3</span>),</span>
<span id="cb5-22"><a href="#cb5-22" aria-hidden="true" tabindex="-1"></a>    A.MedianBlur(blur_limit<span class="op">=</span><span class="dv">7</span>, p<span class="op">=</span><span class="fl">0.3</span>),</span>
<span id="cb5-23"><a href="#cb5-23" aria-hidden="true" tabindex="-1"></a>    A.Blur(blur_limit<span class="op">=</span><span class="dv">7</span>, p<span class="op">=</span><span class="fl">0.3</span>)</span>
<span id="cb5-24"><a href="#cb5-24" aria-hidden="true" tabindex="-1"></a>])</span></code></pre></div></div>
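<p>In the Albumentations 1.x API used above, <code>GaussNoise(var_limit=(10.0, 50.0))</code> samples a <em>variance</em> from that range in uint8 pixel units, so the standard deviation of the added noise is its square root, roughly 3.2 to 7.1 intensity levels. The NumPy sketch below models that idea on a flat gray image, assuming zero-mean noise that is rounded and clipped back to [0, 255]; <code>add_gauss_noise</code> is a hypothetical helper, not the library's implementation.</p>

```python
import numpy as np

def add_gauss_noise(img, var_limit=(10.0, 50.0), rng=None):
    # Hypothetical helper modelling GaussNoise(var_limit=...): sample a variance,
    # take its square root as sigma, add zero-mean noise, then round and clip to uint8
    rng = np.random.default_rng() if rng is None else rng
    var = rng.uniform(var_limit[0], var_limit[1])
    sigma = float(np.sqrt(var))
    noise = rng.normal(0.0, sigma, size=img.shape)
    noisy = np.clip(np.rint(img.astype(np.float64) + noise), 0, 255)
    return noisy.astype(np.uint8), sigma

rng = np.random.default_rng(42)
img = np.full((64, 64, 3), 128, dtype=np.uint8)  # flat mid-gray test image
noisy, sigma = add_gauss_noise(img, rng=rng)

# On a flat image the measured noise std should sit close to the sampled sigma
measured = float((noisy.astype(np.float64) - 128.0).std())
print(f"sampled sigma={sigma:.2f}, measured std={measured:.2f}")
```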
</section>
<section id="torchvision-v2-enhanced-features" class="level3">
<h3 class="anchored" data-anchor-id="torchvision-v2-enhanced-features" id="torchvision-v2-enhanced-features">TorchVision v2 Enhanced Features</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torchvision.transforms.v2 <span class="im">as</span> T2</span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a><span class="co"># TorchVision v2: same API as v1, but transforms can also act on tv_tensors such as boxes and masks</span></span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a>torchvision_v2_transform <span class="op">=</span> T2.Compose([</span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a>    T2.Resize((<span class="dv">224</span>, <span class="dv">224</span>)),</span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a>    T2.RandomHorizontalFlip(p<span class="op">=</span><span class="fl">0.5</span>),</span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a>    T2.RandomChoice([</span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a>        T2.ColorJitter(brightness<span class="op">=</span><span class="fl">0.3</span>),</span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a>        T2.ColorJitter(contrast<span class="op">=</span><span class="fl">0.3</span>),</span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a>        T2.ColorJitter(saturation<span class="op">=</span><span class="fl">0.3</span>)</span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a>    ]),</span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a>    T2.RandomApply([T2.GaussianBlur(kernel_size<span class="op">=</span><span class="dv">3</span>)], p<span class="op">=</span><span class="fl">0.3</span>),</span>
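<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a>    T2.ToImage(),  <span class="co"># PIL image to uint8 tv_tensors.Image (CHW)</span></span>
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a>    T2.ToDtype(torch.float32, scale<span class="op">=</span><span class="va">True</span>),  <span class="co"># float in [0, 1]; replaces the deprecated ToTensor</span></span>
<span id="cb6-15"><a href="#cb6-15" aria-hidden="true" tabindex="-1"></a>    T2.Normalize(mean<span class="op">=</span>[<span class="fl">0.485</span>, <span class="fl">0.456</span>, <span class="fl">0.406</span>], std<span class="op">=</span>[<span class="fl">0.229</span>, <span class="fl">0.224</span>, <span class="fl">0.225</span>])</span>
<span id="cb6-16"><a href="#cb6-16" aria-hidden="true" tabindex="-1"></a>])</span></code></pre></div></div>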
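<p><code>RandomChoice</code> and <code>RandomApply</code> encode two different kinds of randomness: the first picks exactly one transform from its list on every call, while the second applies its whole list with probability <code>p</code> or skips it entirely. The pure-Python sketch below models those semantics with hypothetical <code>random_choice</code> and <code>random_apply</code> helpers and string stand-ins for the transforms, assuming uniform selection in <code>RandomChoice</code> (TorchVision's behaviour when no <code>p</code> weights are passed).</p>

```python
import random
from collections import Counter

def random_choice(transforms, x, rng):
    # Hypothetical helper: pick exactly one transform uniformly at random and apply it
    return rng.choice(transforms)(x)

def random_apply(transforms, x, p, rng):
    # Hypothetical helper: with probability p apply every transform in order, else skip all
    if rng.choices([True, False], weights=[p, 1 - p])[0]:
        for t in transforms:
            x = t(x)
    return x

rng = random.Random(0)
jitters = [lambda s: s + "B", lambda s: s + "C", lambda s: s + "S"]  # brightness/contrast/saturation stand-ins
blur = [lambda s: s + "G"]  # Gaussian blur stand-in

outputs = Counter(random_apply(blur, random_choice(jitters, "", rng), 0.3, rng) for _ in range(9000))
blur_share = sum(n for out, n in outputs.items() if out.endswith("G")) / 9000
print(outputs.most_common(), round(blur_share, 2))
```

<p>Every output contains exactly one jitter letter, and about 30% of them also carry the blur marker.</p>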
</section>
</section>
<section id="working-with-bounding-boxes" class="level2">
<h2 class="anchored" data-anchor-id="working-with-bounding-boxes" id="working-with-bounding-boxes">Working with Bounding Boxes</h2>
<section id="albumentations-excellent-support" class="level3">
<h3 class="anchored" data-anchor-id="albumentations-excellent-support" id="albumentations-excellent-support">Albumentations (Excellent Support)</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Define bounding boxes in Pascal VOC format (x_min, y_min, x_max, y_max)</span></span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a>bboxes <span class="op">=</span> [[<span class="dv">50</span>, <span class="dv">50</span>, <span class="dv">150</span>, <span class="dv">150</span>], [<span class="dv">200</span>, <span class="dv">100</span>, <span class="dv">300</span>, <span class="dv">200</span>]]</span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a>bbox_transform <span class="op">=</span> A.Compose([</span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a>    A.Resize(<span class="dv">416</span>, <span class="dv">416</span>),</span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a>    A.HorizontalFlip(p<span class="op">=</span><span class="fl">0.5</span>),</span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a>    A.RandomBrightnessContrast(p<span class="op">=</span><span class="fl">0.3</span>),</span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a>    A.ShiftScaleRotate(shift_limit<span class="op">=</span><span class="fl">0.1</span>, scale_limit<span class="op">=</span><span class="fl">0.2</span>, rotate_limit<span class="op">=</span><span class="dv">15</span>, p<span class="op">=</span><span class="fl">0.5</span>),</span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a>], bbox_params<span class="op">=</span>A.BboxParams(<span class="bu">format</span><span class="op">=</span><span class="st">'pascal_voc'</span>, label_fields<span class="op">=</span>[<span class="st">'class_labels'</span>]))</span>
<span id="cb7-10"><a href="#cb7-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-11"><a href="#cb7-11" aria-hidden="true" tabindex="-1"></a><span class="co"># Apply transform; class_labels stay aligned with bboxes (a dropped box drops its label too)</span></span>
<span id="cb7-12"><a href="#cb7-12" aria-hidden="true" tabindex="-1"></a>transformed <span class="op">=</span> bbox_transform(</span>
<span id="cb7-13"><a href="#cb7-13" aria-hidden="true" tabindex="-1"></a>    image<span class="op">=</span>albu_img,</span>
<span id="cb7-14"><a href="#cb7-14" aria-hidden="true" tabindex="-1"></a>    bboxes<span class="op">=</span>bboxes,</span>
<span id="cb7-15"><a href="#cb7-15" aria-hidden="true" tabindex="-1"></a>    class_labels<span class="op">=</span>[<span class="st">'person'</span>, <span class="st">'car'</span>]</span>
<span id="cb7-16"><a href="#cb7-16" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb7-17"><a href="#cb7-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-18"><a href="#cb7-18" aria-hidden="true" tabindex="-1"></a>transformed_image <span class="op">=</span> transformed[<span class="st">'image'</span>]</span>
<span id="cb7-19"><a href="#cb7-19" aria-hidden="true" tabindex="-1"></a>transformed_bboxes <span class="op">=</span> transformed[<span class="st">'bboxes'</span>]</span>
<span id="cb7-20"><a href="#cb7-20" aria-hidden="true" tabindex="-1"></a>transformed_labels <span class="op">=</span> transformed[<span class="st">'class_labels'</span>]</span></code></pre></div></div>
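<p>What <code>bbox_params</code> buys you is that box coordinates are updated with the same geometry as the image. For <code>Resize</code> and <code>HorizontalFlip</code> in Pascal VOC format, that bookkeeping is simple enough to write by hand, as in the sketch below. <code>resize_box</code> and <code>hflip_box</code> are hypothetical helpers, and the 640x480 source size is an assumption for the example (the original does not state the image size).</p>

```python
def resize_box(box, src_wh, dst_wh):
    # Hypothetical helper: scale (x_min, y_min, x_max, y_max) by the per-axis resize factors
    sx = dst_wh[0] / src_wh[0]
    sy = dst_wh[1] / src_wh[1]
    x_min, y_min, x_max, y_max = box
    return (x_min * sx, y_min * sy, x_max * sx, y_max * sy)

def hflip_box(box, width):
    # Hypothetical helper: mirror around the vertical centre line; x_min and x_max swap roles
    x_min, y_min, x_max, y_max = box
    return (width - x_max, y_min, width - x_min, y_max)

src_wh, dst_wh = (640, 480), (416, 416)  # assumed source size; 416x416 as in the Resize above
boxes = [(50, 50, 150, 150), (200, 100, 300, 200)]

resized = [resize_box(b, src_wh, dst_wh) for b in boxes]
flipped = [hflip_box(b, dst_wh[0]) for b in resized]
print(resized[0], flipped[0])
```

<p>Rotations and elastic warps are where hand-written bookkeeping stops being practical, which is the main argument for letting the library track boxes.</p>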
</section>
<section id="torchvision-limited-support" class="level3">
<h3 class="anchored" data-anchor-id="torchvision-limited-support" id="torchvision-limited-support">TorchVision (Limited Support)</h3>
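<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb8"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="co"># TorchVision v2 (torchvision 0.16+) can transform boxes alongside the image via tv_tensors</span></span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torchvision.transforms.v2 <span class="im">as</span> T2</span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torchvision <span class="im">import</span> tv_tensors</span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a>bbox_torchvision <span class="op">=</span> T2.Compose([</span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a>    T2.ToImage(),</span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a>    T2.Resize((<span class="dv">416</span>, <span class="dv">416</span>)),</span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a>    T2.RandomHorizontalFlip(p<span class="op">=</span><span class="fl">0.5</span>),</span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a>    T2.ToDtype(torch.float32, scale<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a>])</span>
<span id="cb8-12"><a href="#cb8-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-13"><a href="#cb8-13" aria-hidden="true" tabindex="-1"></a><span class="co"># Boxes must be wrapped with their format and the image size as (height, width)</span></span>
<span id="cb8-14"><a href="#cb8-14" aria-hidden="true" tabindex="-1"></a>boxes <span class="op">=</span> tv_tensors.BoundingBoxes(</span>
<span id="cb8-15"><a href="#cb8-15" aria-hidden="true" tabindex="-1"></a>    [[<span class="dv">50</span>, <span class="dv">50</span>, <span class="dv">150</span>, <span class="dv">150</span>], [<span class="dv">200</span>, <span class="dv">100</span>, <span class="dv">300</span>, <span class="dv">200</span>]],</span>
<span id="cb8-16"><a href="#cb8-16" aria-hidden="true" tabindex="-1"></a>    <span class="bu">format</span><span class="op">=</span><span class="st">"XYXY"</span>, canvas_size<span class="op">=</span>(torch_img.height, torch_img.width)</span>
<span id="cb8-17"><a href="#cb8-17" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb8-18"><a href="#cb8-18" aria-hidden="true" tabindex="-1"></a>out_image, out_boxes <span class="op">=</span> bbox_torchvision(torch_img, boxes)</span>
<span id="cb8-19"><a href="#cb8-19" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-20"><a href="#cb8-20" aria-hidden="true" tabindex="-1"></a><span class="co"># Still less convenient than Albumentations: fewer bbox-aware ops, and labels and</span></span>
<span id="cb8-21"><a href="#cb8-21" aria-hidden="true" tabindex="-1"></a><span class="co"># out-of-frame boxes must be handled explicitly (e.g. with T2.SanitizeBoundingBoxes)</span></span></code></pre></div></div>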
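<p>Both libraries accept several box layouts, and mixing them up is a common source of silent bugs. Albumentations' <code>BboxParams(format=...)</code> accepts names such as <code>'pascal_voc'</code>, <code>'coco'</code> and <code>'yolo'</code>, while TorchVision v2 calls the corresponding layouts <code>XYXY</code>, <code>XYWH</code> and <code>CXCYWH</code> (YOLO additionally normalizes by the image size). The pure-Python helpers below are a hedged sketch of the conversions; their names are hypothetical, and the 640x480 image size is an assumption for the example.</p>

```python
def voc_to_coco(box):
    # Hypothetical helper: (x_min, y_min, x_max, y_max) to (x_min, y_min, width, height)
    x_min, y_min, x_max, y_max = box
    return (x_min, y_min, x_max - x_min, y_max - y_min)

def voc_to_yolo(box, img_w, img_h):
    # Hypothetical helper: (x_min, y_min, x_max, y_max) to normalized (x_center, y_center, width, height)
    x_min, y_min, x_max, y_max = box
    return ((x_min + x_max) / 2 / img_w, (y_min + y_max) / 2 / img_h,
            (x_max - x_min) / img_w, (y_max - y_min) / img_h)

def yolo_to_voc(box, img_w, img_h):
    # Hypothetical helper: inverse of voc_to_yolo, back to absolute pixel corners
    xc, yc, w, h = box
    return ((xc - w / 2) * img_w, (yc - h / 2) * img_h,
            (xc + w / 2) * img_w, (yc + h / 2) * img_h)

box = (200, 100, 300, 200)
print(voc_to_coco(box))  # (200, 100, 100, 100)
yolo = voc_to_yolo(box, 640, 480)
print(yolo)
print(yolo_to_voc(yolo, 640, 480))
```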
</section>
</section>
<section id="working-with-segmentation-masks" class="level2">
<h2 class="anchored" data-anchor-id="working-with-segmentation-masks" id="working-with-segmentation-masks">Working with Segmentation Masks</h2>
<section id="albumentations" class="level3">
<h3 class="anchored" data-anchor-id="albumentations" id="albumentations">Albumentations</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb9"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Geometric transforms move image and mask together; pixel-level ones (brightness/contrast, Normalize) touch the image only</span></span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a>segmentation_transform <span class="op">=</span> A.Compose([</span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a>    A.Resize(<span class="dv">512</span>, <span class="dv">512</span>),</span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a>    A.HorizontalFlip(p<span class="op">=</span><span class="fl">0.5</span>),</span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a>    A.RandomBrightnessContrast(p<span class="op">=</span><span class="fl">0.3</span>),</span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a>    A.ShiftScaleRotate(shift_limit<span class="op">=</span><span class="fl">0.1</span>, scale_limit<span class="op">=</span><span class="fl">0.2</span>, rotate_limit<span class="op">=</span><span class="dv">15</span>, p<span class="op">=</span><span class="fl">0.5</span>),</span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a>    A.Normalize(),</span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a>    ToTensorV2()</span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a>])</span>
<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-11"><a href="#cb9-11" aria-hidden="true" tabindex="-1"></a><span class="co"># image: HxWx3 uint8 RGB array; mask: HxW integer array of class ids with the same height and width</span></span>
<span id="cb9-12"><a href="#cb9-12" aria-hidden="true" tabindex="-1"></a>result <span class="op">=</span> segmentation_transform(image<span class="op">=</span>image, mask<span class="op">=</span>mask)</span>
<span id="cb9-13"><a href="#cb9-13" aria-hidden="true" tabindex="-1"></a>transformed_image <span class="op">=</span> result[<span class="st">'image'</span>]</span>
<span id="cb9-14"><a href="#cb9-14" aria-hidden="true" tabindex="-1"></a>transformed_mask <span class="op">=</span> result[<span class="st">'mask'</span>]</span></code></pre></div></div>
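<p>Albumentations resizes and warps masks with nearest-neighbour interpolation by default, which matters because a mask stores class ids, not intensities. The 1-D NumPy sketch below illustrates the problem (it is not either library's resize code): upsampling a boundary between class 0 and class 2 with linear interpolation invents a class 1 that never appeared in the label map, while nearest-neighbour sampling keeps only real ids.</p>

```python
import numpy as np

mask_row = np.array([0, 0, 2, 2], dtype=np.int64)  # one row of a mask with class ids 0 and 2
src_x = np.arange(mask_row.size)                   # original pixel positions
dst_x = np.linspace(0, mask_row.size - 1, 7)       # upsample 4 pixels to 7

# Linear interpolation blends neighbouring ids and creates a fake class 1 at the boundary
linear = np.interp(dst_x, src_x, mask_row)

# Nearest-neighbour copies the closest original pixel, so only existing ids survive
nearest = mask_row[np.rint(dst_x).astype(np.int64)]

print(np.unique(linear).tolist())   # [0.0, 1.0, 2.0]
print(np.unique(nearest).tolist())  # [0, 2]
```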
</section>
<section id="torchvision" class="level3">
<h3 class="anchored" data-anchor-id="torchvision" id="torchvision">TorchVision</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb10"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="co"># TorchVision (v1 API) requires manual synchronization of image and mask</span></span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> apply_transform_to_mask(transform, image, mask, size<span class="op">=</span>(<span class="dv">512</span>, <span class="dv">512</span>), flip_p<span class="op">=</span><span class="fl">0.5</span>):</span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a>    <span class="co"># image and mask are PIL images; transform must be photometric only (e.g. T.ColorJitter)</span></span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Resize: bilinear for the image, nearest for the mask so class ids are never blended</span></span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a>    image <span class="op">=</span> TF.resize(image, size)</span>
<span id="cb10-6"><a href="#cb10-6" aria-hidden="true" tabindex="-1"></a>    mask <span class="op">=</span> TF.resize(mask, size, interpolation<span class="op">=</span>T.InterpolationMode.NEAREST)</span>
<span id="cb10-7"><a href="#cb10-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-8"><a href="#cb10-8" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Draw the flip decision once and apply it to both inputs</span></span>
<span id="cb10-9"><a href="#cb10-9" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> torch.rand(<span class="dv">1</span>).item() <span class="op">&lt;</span> flip_p:</span>
<span id="cb10-10"><a href="#cb10-10" aria-hidden="true" tabindex="-1"></a>        image <span class="op">=</span> TF.hflip(image)</span>
<span id="cb10-11"><a href="#cb10-11" aria-hidden="true" tabindex="-1"></a>        mask <span class="op">=</span> TF.hflip(mask)</span>
<span id="cb10-12"><a href="#cb10-12" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-13"><a href="#cb10-13" aria-hidden="true" tabindex="-1"></a>    transformed_image <span class="op">=</span> TF.to_tensor(transform(image))</span>
<span id="cb10-14"><a href="#cb10-14" aria-hidden="true" tabindex="-1"></a>    <span class="co"># ToTensor would rescale uint8 labels to [0, 1]; keep integer class ids instead</span></span>
<span id="cb10-15"><a href="#cb10-15" aria-hidden="true" tabindex="-1"></a>    transformed_mask <span class="op">=</span> torch.as_tensor(np.array(mask), dtype<span class="op">=</span>torch.long)</span>
<span id="cb10-16"><a href="#cb10-16" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-17"><a href="#cb10-17" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> transformed_image, transformed_mask</span></code></pre></div></div>
</section>
</section>
<section id="performance-comparison" class="level2">
<h2 class="anchored" data-anchor-id="performance-comparison" id="performance-comparison">Performance Comparison</h2>
<div id="f70e8f88" class="cell" data-execution_count="4">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb11"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> benchmark_transforms(image, iterations<span class="op">=</span><span class="dv">1000</span>):</span>
<span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a>    <span class="co"># TorchVision timing (image is a PIL image; perf_counter is a monotonic, high-resolution clock)</span></span>
<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a>    start_time <span class="op">=</span> time.perf_counter()</span>
<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(iterations):</span>
<span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a>        _ <span class="op">=</span> torchvision_transform(image.copy())</span>
<span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a>    torch_time <span class="op">=</span> time.perf_counter() <span class="op">-</span> start_time</span>
<span id="cb11-9"><a href="#cb11-9" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-10"><a href="#cb11-10" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Convert to numpy for Albumentations</span></span>
<span id="cb11-11"><a href="#cb11-11" aria-hidden="true" tabindex="-1"></a>    np_image <span class="op">=</span> np.array(image)</span>
<span id="cb11-12"><a href="#cb11-12" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-13"><a href="#cb11-13" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Albumentations timing</span></span>
<span id="cb11-14"><a href="#cb11-14" aria-hidden="true" tabindex="-1"></a>    start_time <span class="op">=</span> time.perf_counter()</span>
<span id="cb11-15"><a href="#cb11-15" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(iterations):</span>
<span id="cb11-16"><a href="#cb11-16" aria-hidden="true" tabindex="-1"></a>        _ <span class="op">=</span> albumentations_transform(image<span class="op">=</span>np_image.copy())</span>
<span id="cb11-17"><a href="#cb11-17" aria-hidden="true" tabindex="-1"></a>    albu_time <span class="op">=</span> time.perf_counter() <span class="op">-</span> start_time</span>
<span id="cb11-18"><a href="#cb11-18" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-19"><a href="#cb11-19" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"TorchVision: </span><span class="sc">{</span>torch_time<span class="sc">:.3f}</span><span class="ss">s (</span><span class="sc">{</span>iterations<span class="sc">}</span><span class="ss"> iterations)"</span>)</span>
<span id="cb11-20"><a href="#cb11-20" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Albumentations: </span><span class="sc">{</span>albu_time<span class="sc">:.3f}</span><span class="ss">s (</span><span class="sc">{</span>iterations<span class="sc">}</span><span class="ss"> iterations)"</span>)</span>
<span id="cb11-21"><a href="#cb11-21" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Speedup: </span><span class="sc">{</span>torch_time<span class="op">/</span>albu_time<span class="sc">:.2f}</span><span class="ss">x"</span>)</span>
<span id="cb11-22"><a href="#cb11-22" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-23"><a href="#cb11-23" aria-hidden="true" tabindex="-1"></a><span class="co"># Run benchmark</span></span>
<span id="cb11-24"><a href="#cb11-24" aria-hidden="true" tabindex="-1"></a>benchmark_transforms(torch_img)</span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>TorchVision: 37.706s (1000 iterations)
Albumentations: 2.810s (1000 iterations)
Speedup: 13.42x</code></pre>
</div>
</div>
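<p>A single wall-clock run like the one above is sensitive to warm-up, background load, and the particular transforms in each pipeline, so the 13.42x figure is best read as specific to these two pipelines on this machine. A slightly more careful harness, sketched below with the standard-library <code>timeit</code> module, adds a warm-up call, repeats each measurement, and reports the median time per call. The helper names <code>median_time</code> and <code>compare</code> and the toy callables are illustrative assumptions; with the real pipelines you would pass something like <code>lambda: albumentations_transform(image=np_image)</code>.</p>

```python
import statistics
import timeit


def median_time(fn, number=100, repeat=5):
    """Median seconds per call of fn() over `repeat` rounds of `number` calls."""
    fn()  # warm-up call so one-off initialisation is not timed
    rounds = timeit.repeat(fn, number=number, repeat=repeat)  # uses time.perf_counter
    return statistics.median(rounds) / number


def compare(name_a, fn_a, name_b, fn_b, number=100, repeat=5):
    """Print per-call timings for two callables and the ratio a / b."""
    t_a = median_time(fn_a, number, repeat)
    t_b = median_time(fn_b, number, repeat)
    print(f"{name_a}: {t_a * 1e3:.3f} ms/call")
    print(f"{name_b}: {t_b * 1e3:.3f} ms/call")
    print(f"Speedup ({name_b} vs {name_a}): {t_a / t_b:.2f}x")
    return t_a, t_b


# Toy stand-ins so the sketch runs anywhere; swap in the real transforms.
slow = lambda: sum(i * i for i in range(20_000))
fast = lambda: sum(range(2_000))
compare("slow", slow, "fast", fast, number=50)
```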
</section>
<section id="custom-pipeline-examples" class="level2">
<h2 class="anchored" data-anchor-id="custom-pipeline-examples" id="custom-pipeline-examples">Custom Pipeline Examples</h2>
<section id="data-science-pipeline-with-albumentations" class="level3">
<h3 class="anchored" data-anchor-id="data-science-pipeline-with-albumentations" id="data-science-pipeline-with-albumentations">Training and Validation Pipelines with Albumentations</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb13"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> create_training_pipeline():</span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> A.Compose([</span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Geometric transformations</span></span>
<span id="cb13-4"><a href="#cb13-4" aria-hidden="true" tabindex="-1"></a>        A.OneOf([</span>
<span id="cb13-5"><a href="#cb13-5" aria-hidden="true" tabindex="-1"></a>            A.ShiftScaleRotate(shift_limit<span class="op">=</span><span class="fl">0.1</span>, scale_limit<span class="op">=</span><span class="fl">0.2</span>, rotate_limit<span class="op">=</span><span class="dv">30</span>),</span>
<span id="cb13-6"><a href="#cb13-6" aria-hidden="true" tabindex="-1"></a>            A.ElasticTransform(alpha<span class="op">=</span><span class="dv">1</span>, sigma<span class="op">=</span><span class="dv">50</span>),</span>
<span id="cb13-7"><a href="#cb13-7" aria-hidden="true" tabindex="-1"></a>            A.GridDistortion(num_steps<span class="op">=</span><span class="dv">5</span>, distort_limit<span class="op">=</span><span class="fl">0.3</span>),</span>
<span id="cb13-8"><a href="#cb13-8" aria-hidden="true" tabindex="-1"></a>        ], p<span class="op">=</span><span class="fl">0.5</span>),</span>
<span id="cb13-9"><a href="#cb13-9" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-10"><a href="#cb13-10" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Color augmentations</span></span>
<span id="cb13-11"><a href="#cb13-11" aria-hidden="true" tabindex="-1"></a>        A.OneOf([</span>
<span id="cb13-12"><a href="#cb13-12" aria-hidden="true" tabindex="-1"></a>            A.ColorJitter(brightness<span class="op">=</span><span class="fl">0.3</span>, contrast<span class="op">=</span><span class="fl">0.3</span>, saturation<span class="op">=</span><span class="fl">0.3</span>, hue<span class="op">=</span><span class="fl">0.1</span>),</span>
<span id="cb13-13"><a href="#cb13-13" aria-hidden="true" tabindex="-1"></a>            A.HueSaturationValue(hue_shift_limit<span class="op">=</span><span class="dv">20</span>, sat_shift_limit<span class="op">=</span><span class="dv">30</span>, val_shift_limit<span class="op">=</span><span class="dv">20</span>),</span>
<span id="cb13-14"><a href="#cb13-14" aria-hidden="true" tabindex="-1"></a>            A.RandomBrightnessContrast(brightness_limit<span class="op">=</span><span class="fl">0.3</span>, contrast_limit<span class="op">=</span><span class="fl">0.3</span>),</span>
<span id="cb13-15"><a href="#cb13-15" aria-hidden="true" tabindex="-1"></a>        ], p<span class="op">=</span><span class="fl">0.8</span>),</span>
<span id="cb13-16"><a href="#cb13-16" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-17"><a href="#cb13-17" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Noise and blur</span></span>
<span id="cb13-18"><a href="#cb13-18" aria-hidden="true" tabindex="-1"></a>        A.OneOf([</span>
<span id="cb13-19"><a href="#cb13-19" aria-hidden="true" tabindex="-1"></a>            A.GaussNoise(var_limit<span class="op">=</span>(<span class="dv">10</span>, <span class="dv">50</span>)),</span>
<span id="cb13-20"><a href="#cb13-20" aria-hidden="true" tabindex="-1"></a>            A.ISONoise(),</span>
<span id="cb13-21"><a href="#cb13-21" aria-hidden="true" tabindex="-1"></a>            A.MultiplicativeNoise(),</span>
<span id="cb13-22"><a href="#cb13-22" aria-hidden="true" tabindex="-1"></a>        ], p<span class="op">=</span><span class="fl">0.3</span>),</span>
<span id="cb13-23"><a href="#cb13-23" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-24"><a href="#cb13-24" aria-hidden="true" tabindex="-1"></a>        A.OneOf([</span>
<span id="cb13-25"><a href="#cb13-25" aria-hidden="true" tabindex="-1"></a>            A.MotionBlur(blur_limit<span class="op">=</span><span class="dv">5</span>),</span>
<span id="cb13-26"><a href="#cb13-26" aria-hidden="true" tabindex="-1"></a>            A.MedianBlur(blur_limit<span class="op">=</span><span class="dv">5</span>),</span>
<span id="cb13-27"><a href="#cb13-27" aria-hidden="true" tabindex="-1"></a>            A.GaussianBlur(blur_limit<span class="op">=</span><span class="dv">5</span>),</span>
<span id="cb13-28"><a href="#cb13-28" aria-hidden="true" tabindex="-1"></a>        ], p<span class="op">=</span><span class="fl">0.3</span>),</span>
<span id="cb13-29"><a href="#cb13-29" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-30"><a href="#cb13-30" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Final processing</span></span>
<span id="cb13-31"><a href="#cb13-31" aria-hidden="true" tabindex="-1"></a>        A.Resize(<span class="dv">224</span>, <span class="dv">224</span>),</span>
<span id="cb13-32"><a href="#cb13-32" aria-hidden="true" tabindex="-1"></a>        A.Normalize(mean<span class="op">=</span>[<span class="fl">0.485</span>, <span class="fl">0.456</span>, <span class="fl">0.406</span>], std<span class="op">=</span>[<span class="fl">0.229</span>, <span class="fl">0.224</span>, <span class="fl">0.225</span>]),</span>
<span id="cb13-33"><a href="#cb13-33" aria-hidden="true" tabindex="-1"></a>        ToTensorV2()</span>
<span id="cb13-34"><a href="#cb13-34" aria-hidden="true" tabindex="-1"></a>    ])</span>
<span id="cb13-35"><a href="#cb13-35" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-36"><a href="#cb13-36" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> create_validation_pipeline():</span>
<span id="cb13-37"><a href="#cb13-37" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> A.Compose([</span>
<span id="cb13-38"><a href="#cb13-38" aria-hidden="true" tabindex="-1"></a>        A.Resize(<span class="dv">224</span>, <span class="dv">224</span>),</span>
<span id="cb13-39"><a href="#cb13-39" aria-hidden="true" tabindex="-1"></a>        A.Normalize(mean<span class="op">=</span>[<span class="fl">0.485</span>, <span class="fl">0.456</span>, <span class="fl">0.406</span>], std<span class="op">=</span>[<span class="fl">0.229</span>, <span class="fl">0.224</span>, <span class="fl">0.225</span>]),</span>
<span id="cb13-40"><a href="#cb13-40" aria-hidden="true" tabindex="-1"></a>        ToTensorV2()</span>
<span id="cb13-41"><a href="#cb13-41" aria-hidden="true" tabindex="-1"></a>    ])</span></code></pre></div></div>
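<p>Each <code>A.OneOf</code> block fires with its own probability <code>p</code>. When it fires, it picks exactly one child, using the children's individual <code>p</code> values as weights (most Albumentations transforms default to <code>p=0.5</code>). The selected child is then applied unconditionally. The sketch below is a simplified pure-Python model of that selection rule, not Albumentations' source. It estimates how often each child of the colour block above actually runs, which is handy when tuning augmentation strength. <code>OneOfSketch</code> and <code>Tag</code> are hypothetical names.</p>

```python
import random
from collections import Counter


class Tag:
    """Stand-in for a transform: only its name and probability matter here."""

    def __init__(self, name, p=0.5):
        self.name = name
        self.p = p


class OneOfSketch:
    """Simplified OneOf: fire with probability p, then pick one child weighted by child.p."""

    def __init__(self, children, p):
        self.children = children
        self.p = p

    def __call__(self, rng):
        if rng.random() >= self.p:
            return None  # whole block skipped
        weights = [child.p for child in self.children]
        chosen = rng.choices(self.children, weights=weights, k=1)[0]
        return chosen.name  # the chosen child is applied without a second coin flip


color = OneOfSketch(
    [Tag("ColorJitter"), Tag("HueSaturationValue"), Tag("RandomBrightnessContrast")],
    p=0.8,
)

rng = random.Random(0)
trials = 100_000
counts = Counter(color(rng) for _ in range(trials))
for name, n in sorted(counts.items(), key=lambda kv: str(kv[0])):
    print(f"{name}: {n / trials:.3f}")
```

<p>With equal child weights, each colour transform runs on roughly 0.8 / 3 ≈ 27% of samples, and about 20% of samples receive no colour augmentation at all.</p>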
</section>
<section id="torchvision-pipeline-for-simple-cases" class="level3">
<h3 class="anchored" data-anchor-id="torchvision-pipeline-for-simple-cases" id="torchvision-pipeline-for-simple-cases">TorchVision Pipeline for Simple Cases</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb14"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb14-1"><a href="#cb14-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> create_simple_training_pipeline():</span>
<span id="cb14-2"><a href="#cb14-2" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> T.Compose([</span>
<span id="cb14-3"><a href="#cb14-3" aria-hidden="true" tabindex="-1"></a>        T.Resize((<span class="dv">224</span>, <span class="dv">224</span>)),</span>
<span id="cb14-4"><a href="#cb14-4" aria-hidden="true" tabindex="-1"></a>        T.RandomHorizontalFlip(p<span class="op">=</span><span class="fl">0.5</span>),</span>
<span id="cb14-5"><a href="#cb14-5" aria-hidden="true" tabindex="-1"></a>        T.RandomRotation(degrees<span class="op">=</span><span class="dv">15</span>),</span>
<span id="cb14-6"><a href="#cb14-6" aria-hidden="true" tabindex="-1"></a>        T.RandomApply([T.ColorJitter(<span class="fl">0.3</span>, <span class="fl">0.3</span>, <span class="fl">0.3</span>, <span class="fl">0.1</span>)], p<span class="op">=</span><span class="fl">0.8</span>),</span>
<span id="cb14-7"><a href="#cb14-7" aria-hidden="true" tabindex="-1"></a>        T.RandomApply([T.GaussianBlur(kernel_size<span class="op">=</span><span class="dv">3</span>)], p<span class="op">=</span><span class="fl">0.3</span>),</span>
<span id="cb14-8"><a href="#cb14-8" aria-hidden="true" tabindex="-1"></a>        T.ToTensor(),</span>
<span id="cb14-9"><a href="#cb14-9" aria-hidden="true" tabindex="-1"></a>        T.Normalize(mean<span class="op">=</span>[<span class="fl">0.485</span>, <span class="fl">0.456</span>, <span class="fl">0.406</span>], std<span class="op">=</span>[<span class="fl">0.229</span>, <span class="fl">0.224</span>, <span class="fl">0.225</span>])</span>
<span id="cb14-10"><a href="#cb14-10" aria-hidden="true" tabindex="-1"></a>    ])</span>
<span id="cb14-11"><a href="#cb14-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-12"><a href="#cb14-12" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> create_simple_validation_pipeline():</span>
<span id="cb14-13"><a href="#cb14-13" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> T.Compose([</span>
<span id="cb14-14"><a href="#cb14-14" aria-hidden="true" tabindex="-1"></a>        T.Resize((<span class="dv">224</span>, <span class="dv">224</span>)),</span>
<span id="cb14-15"><a href="#cb14-15" aria-hidden="true" tabindex="-1"></a>        T.ToTensor(),</span>
<span id="cb14-16"><a href="#cb14-16" aria-hidden="true" tabindex="-1"></a>        T.Normalize(mean<span class="op">=</span>[<span class="fl">0.485</span>, <span class="fl">0.456</span>, <span class="fl">0.406</span>], std<span class="op">=</span>[<span class="fl">0.229</span>, <span class="fl">0.224</span>, <span class="fl">0.225</span>])</span>
<span id="cb14-17"><a href="#cb14-17" aria-hidden="true" tabindex="-1"></a>    ])</span></code></pre></div></div>
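<p>Both TorchVision pipelines end with <code>T.ToTensor()</code> followed by <code>T.Normalize</code> using the ImageNet channel statistics. <code>ToTensor</code> scales 8-bit pixel values to [0, 1], and <code>Normalize</code> then computes <code>(x - mean[c]) / std[c]</code> for each channel <code>c</code>. Albumentations' <code>A.Normalize</code>, with its default <code>max_pixel_value=255.0</code>, performs the same arithmetic on uint8 input. The sketch below reproduces the calculation for a single RGB pixel. <code>normalize_pixel</code> and <code>denormalize_pixel</code> are hypothetical helpers, useful for checking that preprocessing matches what a pre-trained model expects.</p>

```python
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Mimic ToTensor + Normalize for one uint8 RGB pixel."""
    scaled = [value / 255.0 for value in rgb]  # ToTensor: [0, 255] to [0, 1]
    return [(x - m) / s for x, m, s in zip(scaled, mean, std)]  # per-channel Normalize


def denormalize_pixel(values, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Invert the mapping, e.g. to visualise a normalised tensor."""
    return [round((v * s + m) * 255) for v, m, s in zip(values, mean, std)]


white = normalize_pixel((255, 255, 255))
black = normalize_pixel((0, 0, 0))
print([f"{v:.3f}" for v in white])  # → ['2.249', '2.429', '2.640']
print([f"{v:.3f}" for v in black])  # → ['-2.118', '-2.036', '-1.804']
print(denormalize_pixel(white))     # → [255, 255, 255]
```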
</section>
</section>
<section id="dataset-integration" class="level2">
<h2 class="anchored" data-anchor-id="dataset-integration" id="dataset-integration">Dataset Integration</h2>
<section id="pytorch-dataset-with-albumentations" class="level3">
<h3 class="anchored" data-anchor-id="pytorch-dataset-with-albumentations" id="pytorch-dataset-with-albumentations">PyTorch Dataset with Albumentations</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb15"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb15-1"><a href="#cb15-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.utils.data <span class="im">import</span> Dataset, DataLoader</span>
<span id="cb15-2"><a href="#cb15-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-3"><a href="#cb15-3" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> CustomDataset(Dataset):</span>
<span id="cb15-4"><a href="#cb15-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, image_paths, labels, transform<span class="op">=</span><span class="va">None</span>):</span>
<span id="cb15-5"><a href="#cb15-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.image_paths <span class="op">=</span> image_paths</span>
<span id="cb15-6"><a href="#cb15-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.labels <span class="op">=</span> labels</span>
<span id="cb15-7"><a href="#cb15-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.transform <span class="op">=</span> transform</span>
<span id="cb15-8"><a href="#cb15-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-9"><a href="#cb15-9" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__len__</span>(<span class="va">self</span>):</span>
<span id="cb15-10"><a href="#cb15-10" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="bu">len</span>(<span class="va">self</span>.image_paths)</span>
<span id="cb15-11"><a href="#cb15-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-12"><a href="#cb15-12" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__getitem__</span>(<span class="va">self</span>, idx):</span>
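<span id="cb15-13"><a href="#cb15-13" aria-hidden="true" tabindex="-1"></a>        image <span class="op">=</span> cv2.imread(<span class="va">self</span>.image_paths[idx])  <span class="co"># NumPy array in BGR order</span></span>
<span id="cb15-14"><a href="#cb15-14" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> image <span class="kw">is</span> <span class="va">None</span>:  <span class="co"># cv2.imread returns None instead of raising</span></span>
<span id="cb15-15"><a href="#cb15-15" aria-hidden="true" tabindex="-1"></a>            <span class="cf">raise</span> <span class="pp">FileNotFoundError</span>(<span class="va">self</span>.image_paths[idx])</span>
<span id="cb15-16"><a href="#cb15-16" aria-hidden="true" tabindex="-1"></a>        image <span class="op">=</span> cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  <span class="co"># RGB, to match the ImageNet mean/std</span></span>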
<span id="cb15-17"><a href="#cb15-17" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="va">self</span>.transform:</span>
<span id="cb15-18"><a href="#cb15-18" aria-hidden="true" tabindex="-1"></a>            transformed <span class="op">=</span> <span class="va">self</span>.transform(image<span class="op">=</span>image)</span>
<span id="cb15-19"><a href="#cb15-19" aria-hidden="true" tabindex="-1"></a>            image <span class="op">=</span> transformed[<span class="st">'image'</span>]</span>
<span id="cb15-20"><a href="#cb15-20" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-21"><a href="#cb15-21" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> image, <span class="va">self</span>.labels[idx]</span>
<span id="cb15-22"><a href="#cb15-22" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-23"><a href="#cb15-23" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage</span></span>
<span id="cb15-24"><a href="#cb15-24" aria-hidden="true" tabindex="-1"></a>train_dataset <span class="op">=</span> CustomDataset(</span>
<span id="cb15-25"><a href="#cb15-25" aria-hidden="true" tabindex="-1"></a>    image_paths<span class="op">=</span>train_paths,</span>
<span id="cb15-26"><a href="#cb15-26" aria-hidden="true" tabindex="-1"></a>    labels<span class="op">=</span>train_labels,</span>
<span id="cb15-27"><a href="#cb15-27" aria-hidden="true" tabindex="-1"></a>    transform<span class="op">=</span>create_training_pipeline()</span>
<span id="cb15-28"><a href="#cb15-28" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
</section>
<section id="pytorch-dataset-with-torchvision" class="level3">
<h3 class="anchored" data-anchor-id="pytorch-dataset-with-torchvision" id="pytorch-dataset-with-torchvision">PyTorch Dataset with TorchVision</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb16"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb16-1"><a href="#cb16-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> TorchVisionDataset(Dataset):</span>
<span id="cb16-2"><a href="#cb16-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, image_paths, labels, transform<span class="op">=</span><span class="va">None</span>):</span>
<span id="cb16-3"><a href="#cb16-3" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.image_paths <span class="op">=</span> image_paths</span>
<span id="cb16-4"><a href="#cb16-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.labels <span class="op">=</span> labels</span>
<span id="cb16-5"><a href="#cb16-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.transform <span class="op">=</span> transform</span>
<span id="cb16-6"><a href="#cb16-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-7"><a href="#cb16-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__len__</span>(<span class="va">self</span>):</span>
<span id="cb16-8"><a href="#cb16-8" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="bu">len</span>(<span class="va">self</span>.image_paths)</span>
<span id="cb16-9"><a href="#cb16-9" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb16-10"><a href="#cb16-10" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__getitem__</span>(<span class="va">self</span>, idx):</span>
<span id="cb16-11"><a href="#cb16-11" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Load image for TorchVision (PIL format)</span></span>
<span id="cb16-12"><a href="#cb16-12" aria-hidden="true" tabindex="-1"></a>        image <span class="op">=</span> Image.<span class="bu">open</span>(<span class="va">self</span>.image_paths[idx]).convert(<span class="st">'RGB'</span>)</span>
<span id="cb16-13"><a href="#cb16-13" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb16-14"><a href="#cb16-14" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="va">self</span>.transform:</span>
<span id="cb16-15"><a href="#cb16-15" aria-hidden="true" tabindex="-1"></a>            image <span class="op">=</span> <span class="va">self</span>.transform(image)</span>
<span id="cb16-16"><a href="#cb16-16" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb16-17"><a href="#cb16-17" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> image, <span class="va">self</span>.labels[idx]</span>
<span id="cb16-18"><a href="#cb16-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-19"><a href="#cb16-19" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage</span></span>
<span id="cb16-20"><a href="#cb16-20" aria-hidden="true" tabindex="-1"></a>train_dataset <span class="op">=</span> TorchVisionDataset(</span>
<span id="cb16-21"><a href="#cb16-21" aria-hidden="true" tabindex="-1"></a>    image_paths<span class="op">=</span>train_paths,</span>
<span id="cb16-22"><a href="#cb16-22" aria-hidden="true" tabindex="-1"></a>    labels<span class="op">=</span>train_labels,</span>
<span id="cb16-23"><a href="#cb16-23" aria-hidden="true" tabindex="-1"></a>    transform<span class="op">=</span>create_simple_training_pipeline()</span>
<span id="cb16-24"><a href="#cb16-24" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
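<p>The two dataset classes differ mainly in calling convention. An Albumentations pipeline is called with a keyword argument (<code>transform(image=image)</code>) and returns a dictionary. A TorchVision pipeline takes the image positionally and returns it directly. If you want one dataset class for both, a small adapter can hide the difference. <code>AlbumentationsAdapter</code> below is a hypothetical helper, not part of either library. It is shown with a stand-in transform so the sketch runs without either library installed.</p>

```python
class AlbumentationsAdapter:
    """Wrap an Albumentations-style transform so it can be called like a TorchVision one."""

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, image):
        # Albumentations pipelines take image=... and return a dict of outputs
        return self.transform(image=image)["image"]


def fake_albumentations_pipeline(image):
    """Stand-in for A.Compose([...]): flips a row-major 'image' horizontally."""
    return {"image": [row[::-1] for row in image]}


adapted = AlbumentationsAdapter(fake_albumentations_pipeline)
print(adapted([[1, 2, 3], [4, 5, 6]]))  # → [[3, 2, 1], [6, 5, 4]]
```

<p>With this wrapper, <code>TorchVisionDataset(..., transform=AlbumentationsAdapter(create_training_pipeline()))</code> should work, provided the dataset loads images as NumPy arrays (for example via OpenCV) rather than PIL images, because Albumentations operates on NumPy input.</p>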
</section>
</section>
<section id="when-to-use-which-library" class="level2">
<h2 class="anchored" data-anchor-id="when-to-use-which-library" id="when-to-use-which-library">When to Use Which Library</h2>
<section id="choose-albumentations-when" class="level3">
<h3 class="anchored" data-anchor-id="choose-albumentations-when" id="choose-albumentations-when">Choose Albumentations When:</h3>
<ul>
<li>Working with object detection or segmentation tasks</li>
<li>Need advanced augmentation techniques (weather effects, distortions)</li>
<li>Performance is critical (processing large datasets)</li>
<li>Working with bounding boxes or keypoints</li>
<li>Need fine-grained control over augmentation probabilities</li>
<li>Dealing with medical or satellite imagery</li>
</ul>
</section>
<section id="choose-torchvision-when" class="level3">
<h3 class="anchored" data-anchor-id="choose-torchvision-when" id="choose-torchvision-when">Choose TorchVision When:</h3>
<ul>
<li>Building simple image classification models</li>
<li>Working entirely within the PyTorch ecosystem</li>
<li>Need basic augmentations only</li>
<li>Prototyping quickly</li>
<li>Following PyTorch tutorials or established workflows</li>
<li>Using TorchVision pre-trained models, whose weights bundle their matching preprocessing (for example, <code>weights.transforms()</code>)</li>
</ul>
</section>
</section>
<section id="best-practices" class="level2">
<h2 class="anchored" data-anchor-id="best-practices" id="best-practices">Best Practices</h2>
<section id="albumentations-best-practices" class="level3">
<h3 class="anchored" data-anchor-id="albumentations-best-practices" id="albumentations-best-practices">Albumentations Best Practices</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb17"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb17-1"><a href="#cb17-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Use ReplayCompose for debugging</span></span>
<span id="cb17-2"><a href="#cb17-2" aria-hidden="true" tabindex="-1"></a>replay_transform <span class="op">=</span> A.ReplayCompose([</span>
<span id="cb17-3"><a href="#cb17-3" aria-hidden="true" tabindex="-1"></a>    A.HorizontalFlip(p<span class="op">=</span><span class="fl">0.5</span>),</span>
<span id="cb17-4"><a href="#cb17-4" aria-hidden="true" tabindex="-1"></a>    A.RandomBrightnessContrast(p<span class="op">=</span><span class="fl">0.3</span>),</span>
<span id="cb17-5"><a href="#cb17-5" aria-hidden="true" tabindex="-1"></a>])</span>
<span id="cb17-6"><a href="#cb17-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-7"><a href="#cb17-7" aria-hidden="true" tabindex="-1"></a>result <span class="op">=</span> replay_transform(image<span class="op">=</span>image)</span>
<span id="cb17-8"><a href="#cb17-8" aria-hidden="true" tabindex="-1"></a>transformed_image <span class="op">=</span> result[<span class="st">'image'</span>]</span>
<span id="cb17-9"><a href="#cb17-9" aria-hidden="true" tabindex="-1"></a>replay_data <span class="op">=</span> result[<span class="st">'replay'</span>]</span>
<span id="cb17-10"><a href="#cb17-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-11"><a href="#cb17-11" aria-hidden="true" tabindex="-1"></a><span class="co"># Apply same transforms to another image</span></span>
<span id="cb17-12"><a href="#cb17-12" aria-hidden="true" tabindex="-1"></a>result2 <span class="op">=</span> A.ReplayCompose.replay(replay_data, image<span class="op">=</span>another_image)</span>
<span id="cb17-13"><a href="#cb17-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-14"><a href="#cb17-14" aria-hidden="true" tabindex="-1"></a><span class="co"># Bounding-box handling: filter boxes that degrade during augmentation</span></span>
<span id="cb17-15"><a href="#cb17-15" aria-hidden="true" tabindex="-1"></a>bbox_params <span class="op">=</span> A.BboxParams(</span>
<span id="cb17-16"><a href="#cb17-16" aria-hidden="true" tabindex="-1"></a>    <span class="bu">format</span><span class="op">=</span><span class="st">'pascal_voc'</span>,  <span class="co"># [x_min, y_min, x_max, y_max] in pixels</span></span>
<span id="cb17-17"><a href="#cb17-17" aria-hidden="true" tabindex="-1"></a>    min_area<span class="op">=</span><span class="dv">1024</span>,  <span class="co"># Drop boxes whose area after augmentation is under 1024 pixels</span></span>
<span id="cb17-18"><a href="#cb17-18" aria-hidden="true" tabindex="-1"></a>    min_visibility<span class="op">=</span><span class="fl">0.3</span>,  <span class="co"># Drop boxes that keep under 30% of their original area (e.g. after cropping)</span></span>
<span id="cb17-19"><a href="#cb17-19" aria-hidden="true" tabindex="-1"></a>    label_fields<span class="op">=</span>[<span class="st">'class_labels'</span>]</span>
<span id="cb17-20"><a href="#cb17-20" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
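<p><code>min_area</code> and <code>min_visibility</code> are checked after the spatial transforms run. A box is dropped if its remaining area in pixels falls below <code>min_area</code>, or if the fraction of its original area that survives (for example after a crop) falls below <code>min_visibility</code>. The function below is a simplified, hypothetical model of that rule for a single crop in <code>pascal_voc</code> coordinates. Albumentations' own implementation handles many more transforms and edge cases.</p>

```python
def crop_and_filter(boxes, crop, min_area=1024, min_visibility=0.3):
    """Clip pascal_voc boxes to a crop window and drop boxes that fail the thresholds.

    boxes: list of (x_min, y_min, x_max, y_max) in original image pixels
    crop:  (x0, y0, x1, y1) crop window in the same pixel coordinates
    Returns the surviving boxes in crop-relative coordinates.
    """
    cx0, cy0, cx1, cy1 = crop
    kept = []
    for x_min, y_min, x_max, y_max in boxes:
        original_area = (x_max - x_min) * (y_max - y_min)
        # Intersect the box with the crop window
        nx0, ny0 = max(x_min, cx0), max(y_min, cy0)
        nx1, ny1 = min(x_max, cx1), min(y_max, cy1)
        new_area = max(0, nx1 - nx0) * max(0, ny1 - ny0)
        visibility = new_area / original_area if original_area else 0.0
        if new_area >= min_area and visibility >= min_visibility:
            kept.append((nx0 - cx0, ny0 - cy0, nx1 - cx0, ny1 - cy0))
    return kept


boxes = [
    (10, 10, 110, 110),    # fully inside the crop: kept
    (150, 150, 250, 250),  # only a 50x50 corner survives: 25% visible, dropped
    (20, 20, 40, 40),      # fully visible but only 400 pixels: dropped by min_area
]
print(crop_and_filter(boxes, crop=(0, 0, 200, 200)))  # → [(10, 10, 110, 110)]
```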
</section>
<section id="torchvision-best-practices" class="level3">
<h3 class="anchored" data-anchor-id="torchvision-best-practices" id="torchvision-best-practices">TorchVision Best Practices</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb18"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb18-1"><a href="#cb18-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Use functional API for custom control</span></span>
<span id="cb18-2"><a href="#cb18-2" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> custom_transform(image):</span>
<span id="cb18-3"><a href="#cb18-3" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> torch.rand(<span class="dv">1</span>) <span class="op">&lt;</span> <span class="fl">0.5</span>:</span>
<span id="cb18-4"><a href="#cb18-4" aria-hidden="true" tabindex="-1"></a>        image <span class="op">=</span> TF.hflip(image)</span>
<span id="cb18-5"><a href="#cb18-5" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-6"><a href="#cb18-6" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Rotate by a random integer angle in [-30, 30] degrees (randint's upper bound is exclusive)</span></span>
<span id="cb18-7"><a href="#cb18-7" aria-hidden="true" tabindex="-1"></a>    angle <span class="op">=</span> torch.randint(<span class="op">-</span><span class="dv">30</span>, <span class="dv">31</span>, (<span class="dv">1</span>,)).item()</span>
<span id="cb18-8"><a href="#cb18-8" aria-hidden="true" tabindex="-1"></a>    image <span class="op">=</span> TF.rotate(image, angle)</span>
<span id="cb18-9"><a href="#cb18-9" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-10"><a href="#cb18-10" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> image</span>
<span id="cb18-11"><a href="#cb18-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-12"><a href="#cb18-12" aria-hidden="true" tabindex="-1"></a><span class="co"># Combine with standard transforms</span></span>
<span id="cb18-13"><a href="#cb18-13" aria-hidden="true" tabindex="-1"></a>combined_transform <span class="op">=</span> T.Compose([</span>
<span id="cb18-14"><a href="#cb18-14" aria-hidden="true" tabindex="-1"></a>    T.Lambda(custom_transform),</span>
<span id="cb18-15"><a href="#cb18-15" aria-hidden="true" tabindex="-1"></a>    T.ToTensor(),</span>
<span id="cb18-16"><a href="#cb18-16" aria-hidden="true" tabindex="-1"></a>    T.Normalize(mean<span class="op">=</span>[<span class="fl">0.485</span>, <span class="fl">0.456</span>, <span class="fl">0.406</span>], std<span class="op">=</span>[<span class="fl">0.229</span>, <span class="fl">0.224</span>, <span class="fl">0.225</span>])</span>
<span id="cb18-17"><a href="#cb18-17" aria-hidden="true" tabindex="-1"></a>])</span></code></pre></div></div>
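<p>A main reason to reach for the functional API is that you can sample random parameters once and apply them to every target that must stay aligned, such as an image and its segmentation mask. The sketch below shows that pattern with plain nested lists. Two hypothetical helpers, <code>hflip</code> and <code>rot90</code>, stand in for <code>TF.hflip</code> and <code>TF.rotate</code> (restricted here to quarter turns). With real tensors or PIL images, you would call the TorchVision functions on both inputs in the same way.</p>

```python
import random


def hflip(grid):
    """Mirror every row left to right (stand-in for TF.hflip)."""
    return [list(row[::-1]) for row in grid]


def rot90(grid, k):
    """Rotate counter-clockwise by k * 90 degrees (restricted stand-in for TF.rotate)."""
    for _ in range(k % 4):
        # Transpose, then reverse the row order: one counter-clockwise quarter turn
        grid = [list(row) for row in zip(*grid)][::-1]
    return grid


def paired_transform(image, mask, rng):
    """Sample the random parameters once, then apply identical ops to image and mask."""
    do_flip = rng.choice([True, False])
    k = rng.randrange(4)
    if do_flip:
        image, mask = hflip(image), hflip(mask)
    return rot90(image, k), rot90(mask, k)


image = [[10, 20, 30], [40, 50, 60]]
mask = [[0, 0, 1], [1, 1, 1]]  # 1 wherever the pixel value is at least 30
new_image, new_mask = paired_transform(image, mask, random.Random(42))
print(new_image)
print(new_mask)
```

<p>Because the flip decision and rotation amount are drawn once, the mask still marks exactly the pixels with values of at least 30, whatever the random outcome.</p>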
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>Both libraries have their strengths:</p>
<p><strong>Albumentations</strong> excels in:</p>
<ul>
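<li>Advanced augmentation techniques (weather effects, elastic and grid distortions)</li>
<li>Throughput: about 13x faster than TorchVision in the benchmark above, although the gap depends on the pipelines compared and on the hardware</li>
<li>Computer vision tasks beyond classification, where bounding boxes, masks, and keypoints must be transformed in sync with the image</li>
<li>Production training pipelines in which data loading is a bottleneck</li>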
</ul>
<p><strong>TorchVision</strong> is ideal for:</p>
<ul>
<li>Simple classification tasks</li>
<li>Learning and prototyping</li>
<li>Tight PyTorch integration</li>
<li>Basic augmentation needs</li>
</ul>
<p>Choose based on your requirements. Albumentations is the stronger default for detection, segmentation, and other advanced computer vision projects, while TorchVision is usually sufficient for straightforward classification.</p>


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[MobileNetV2 PyTorch Docker Deployment Guide]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/deployment/mobilenet-deployment/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/deployment/mobilenet-deployment/</guid>
      <pubDate>Mon, 26 May 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>mlops</category>
      <category>intermediate</category>
      <content:encoded><![CDATA[






<section id="mobilenetv2-pytorch-docker-deployment-guide" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/deployment/mobilenet-deployment/docker.png" class="img-fluid"></p>
<p>This guide walks through deploying a pre-trained MobileNetV2 model with PyTorch and Docker, exposing it as a REST API for image classification.</p>
<section id="project-structure" class="level2">
<h2 class="anchored" data-anchor-id="project-structure" id="project-structure">Project Structure</h2>
<pre><code>mobilenetv2-pytorch-docker/
├── app/
│   ├── __init__.py
│   ├── main.py
│   ├── model_handler.py
│   └── utils.py
├── requirements.txt
├── Dockerfile
├── docker-compose.yml
├── .dockerignore
└── README.md</code></pre>
</section>
<section id="application-code" class="level2">
<h2 class="anchored" data-anchor-id="application-code" id="application-code">1. Application Code</h2>
<section id="appmain.py---fastapi-application" class="level3">
<h3 class="anchored" data-anchor-id="appmain.py---fastapi-application" id="appmain.py---fastapi-application"><code>app/main.py</code> - FastAPI Application</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> fastapi <span class="im">import</span> FastAPI, File, UploadFile, HTTPException</span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> fastapi.responses <span class="im">import</span> JSONResponse</span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> uvicorn</span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> io</span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> PIL <span class="im">import</span> Image</span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> .model_handler <span class="im">import</span> MobileNetV2Handler</span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> .utils <span class="im">import</span> preprocess_image, decode_predictions</span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> logging</span>
<span id="cb2-10"><a href="#cb2-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-11"><a href="#cb2-11" aria-hidden="true" tabindex="-1"></a><span class="co"># Configure logging</span></span>
<span id="cb2-12"><a href="#cb2-12" aria-hidden="true" tabindex="-1"></a>logging.basicConfig(level<span class="op">=</span>logging.INFO)</span>
<span id="cb2-13"><a href="#cb2-13" aria-hidden="true" tabindex="-1"></a>logger <span class="op">=</span> logging.getLogger(<span class="va">__name__</span>)</span>
<span id="cb2-14"><a href="#cb2-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-15"><a href="#cb2-15" aria-hidden="true" tabindex="-1"></a>app <span class="op">=</span> FastAPI(</span>
<span id="cb2-16"><a href="#cb2-16" aria-hidden="true" tabindex="-1"></a>    title<span class="op">=</span><span class="st">"MobileNetV2 PyTorch Image Classification API"</span>,</span>
<span id="cb2-17"><a href="#cb2-17" aria-hidden="true" tabindex="-1"></a>    description<span class="op">=</span><span class="st">"Deploy MobileNetV2 using PyTorch for image classification"</span>,</span>
<span id="cb2-18"><a href="#cb2-18" aria-hidden="true" tabindex="-1"></a>    version<span class="op">=</span><span class="st">"1.0.0"</span></span>
<span id="cb2-19"><a href="#cb2-19" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb2-20"><a href="#cb2-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-21"><a href="#cb2-21" aria-hidden="true" tabindex="-1"></a><span class="co"># Initialize model handler</span></span>
<span id="cb2-22"><a href="#cb2-22" aria-hidden="true" tabindex="-1"></a>model_handler <span class="op">=</span> MobileNetV2Handler()</span>
<span id="cb2-23"><a href="#cb2-23" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-24"><a href="#cb2-24" aria-hidden="true" tabindex="-1"></a><span class="at">@app.on_event</span>(<span class="st">"startup"</span>)</span>
<span id="cb2-25"><a href="#cb2-25" aria-hidden="true" tabindex="-1"></a><span class="cf">async</span> <span class="kw">def</span> startup_event():</span>
<span id="cb2-26"><a href="#cb2-26" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Load model on startup"""</span></span>
<span id="cb2-27"><a href="#cb2-27" aria-hidden="true" tabindex="-1"></a>    <span class="cf">try</span>:</span>
<span id="cb2-28"><a href="#cb2-28" aria-hidden="true" tabindex="-1"></a>        model_handler.load_model()</span>
<span id="cb2-29"><a href="#cb2-29" aria-hidden="true" tabindex="-1"></a>        logger.info(<span class="st">"Model loaded successfully"</span>)</span>
<span id="cb2-30"><a href="#cb2-30" aria-hidden="true" tabindex="-1"></a>    <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:</span>
<span id="cb2-31"><a href="#cb2-31" aria-hidden="true" tabindex="-1"></a>        logger.error(<span class="ss">f"Failed to load model: </span><span class="sc">{</span>e<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb2-32"><a href="#cb2-32" aria-hidden="true" tabindex="-1"></a>        <span class="cf">raise</span></span>
<span id="cb2-33"><a href="#cb2-33" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-34"><a href="#cb2-34" aria-hidden="true" tabindex="-1"></a><span class="at">@app.get</span>(<span class="st">"/"</span>)</span>
<span id="cb2-35"><a href="#cb2-35" aria-hidden="true" tabindex="-1"></a><span class="cf">async</span> <span class="kw">def</span> root():</span>
<span id="cb2-36"><a href="#cb2-36" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> {<span class="st">"message"</span>: <span class="st">"MobileNetV2 PyTorch Classification API"</span>, <span class="st">"status"</span>: <span class="st">"running"</span>}</span>
<span id="cb2-37"><a href="#cb2-37" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-38"><a href="#cb2-38" aria-hidden="true" tabindex="-1"></a><span class="at">@app.get</span>(<span class="st">"/health"</span>)</span>
<span id="cb2-39"><a href="#cb2-39" aria-hidden="true" tabindex="-1"></a><span class="cf">async</span> <span class="kw">def</span> health_check():</span>
<span id="cb2-40"><a href="#cb2-40" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> {<span class="st">"status"</span>: <span class="st">"healthy"</span>, <span class="st">"model_loaded"</span>: model_handler.is_loaded()}</span>
<span id="cb2-41"><a href="#cb2-41" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-42"><a href="#cb2-42" aria-hidden="true" tabindex="-1"></a><span class="at">@app.post</span>(<span class="st">"/predict"</span>)</span>
<span id="cb2-43"><a href="#cb2-43" aria-hidden="true" tabindex="-1"></a><span class="cf">async</span> <span class="kw">def</span> predict(<span class="bu">file</span>: UploadFile <span class="op">=</span> File(...)):</span>
<span id="cb2-44"><a href="#cb2-44" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb2-45"><a href="#cb2-45" aria-hidden="true" tabindex="-1"></a><span class="co">    Predict image class using MobileNetV2</span></span>
<span id="cb2-46"><a href="#cb2-46" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb2-47"><a href="#cb2-47" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> <span class="kw">not</span> <span class="bu">file</span>.content_type.startswith(<span class="st">"image/"</span>):</span>
<span id="cb2-48"><a href="#cb2-48" aria-hidden="true" tabindex="-1"></a>        <span class="cf">raise</span> HTTPException(status_code<span class="op">=</span><span class="dv">400</span>, detail<span class="op">=</span><span class="st">"File must be an image"</span>)</span>
<span id="cb2-49"><a href="#cb2-49" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-50"><a href="#cb2-50" aria-hidden="true" tabindex="-1"></a>    <span class="cf">try</span>:</span>
<span id="cb2-51"><a href="#cb2-51" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Read and preprocess image</span></span>
<span id="cb2-52"><a href="#cb2-52" aria-hidden="true" tabindex="-1"></a>        image_data <span class="op">=</span> <span class="cf">await</span> <span class="bu">file</span>.read()</span>
<span id="cb2-53"><a href="#cb2-53" aria-hidden="true" tabindex="-1"></a>        image <span class="op">=</span> Image.<span class="bu">open</span>(io.BytesIO(image_data))</span>
<span id="cb2-54"><a href="#cb2-54" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-55"><a href="#cb2-55" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> image.mode <span class="op">!=</span> <span class="st">'RGB'</span>:</span>
<span id="cb2-56"><a href="#cb2-56" aria-hidden="true" tabindex="-1"></a>            image <span class="op">=</span> image.convert(<span class="st">'RGB'</span>)</span>
<span id="cb2-57"><a href="#cb2-57" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-58"><a href="#cb2-58" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Preprocess for MobileNetV2</span></span>
<span id="cb2-59"><a href="#cb2-59" aria-hidden="true" tabindex="-1"></a>        processed_image <span class="op">=</span> preprocess_image(image)</span>
<span id="cb2-60"><a href="#cb2-60" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-61"><a href="#cb2-61" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Make prediction</span></span>
<span id="cb2-62"><a href="#cb2-62" aria-hidden="true" tabindex="-1"></a>        predictions <span class="op">=</span> model_handler.predict(processed_image)</span>
<span id="cb2-63"><a href="#cb2-63" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-64"><a href="#cb2-64" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Decode predictions</span></span>
<span id="cb2-65"><a href="#cb2-65" aria-hidden="true" tabindex="-1"></a>        decoded_predictions <span class="op">=</span> decode_predictions(predictions, top<span class="op">=</span><span class="dv">5</span>)</span>
<span id="cb2-66"><a href="#cb2-66" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-67"><a href="#cb2-67" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> JSONResponse(content<span class="op">=</span>{</span>
<span id="cb2-68"><a href="#cb2-68" aria-hidden="true" tabindex="-1"></a>            <span class="st">"predictions"</span>: decoded_predictions,</span>
<span id="cb2-69"><a href="#cb2-69" aria-hidden="true" tabindex="-1"></a>            <span class="st">"success"</span>: <span class="va">True</span></span>
<span id="cb2-70"><a href="#cb2-70" aria-hidden="true" tabindex="-1"></a>        })</span>
<span id="cb2-71"><a href="#cb2-71" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-72"><a href="#cb2-72" aria-hidden="true" tabindex="-1"></a>    <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:</span>
<span id="cb2-73"><a href="#cb2-73" aria-hidden="true" tabindex="-1"></a>        logger.error(<span class="ss">f"Prediction error: </span><span class="sc">{</span>e<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb2-74"><a href="#cb2-74" aria-hidden="true" tabindex="-1"></a>        <span class="cf">raise</span> HTTPException(status_code<span class="op">=</span><span class="dv">500</span>, detail<span class="op">=</span><span class="ss">f"Prediction failed: </span><span class="sc">{</span><span class="bu">str</span>(e)<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb2-75"><a href="#cb2-75" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-76"><a href="#cb2-76" aria-hidden="true" tabindex="-1"></a><span class="at">@app.post</span>(<span class="st">"/batch_predict"</span>)</span>
<span id="cb2-77"><a href="#cb2-77" aria-hidden="true" tabindex="-1"></a><span class="cf">async</span> <span class="kw">def</span> batch_predict(files: <span class="bu">list</span>[UploadFile] <span class="op">=</span> File(...)):</span>
<span id="cb2-78"><a href="#cb2-78" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb2-79"><a href="#cb2-79" aria-hidden="true" tabindex="-1"></a><span class="co">    Batch prediction for multiple images</span></span>
<span id="cb2-80"><a href="#cb2-80" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb2-81"><a href="#cb2-81" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> <span class="bu">len</span>(files) <span class="op">&gt;</span> <span class="dv">10</span>:  <span class="co"># Limit batch size</span></span>
<span id="cb2-82"><a href="#cb2-82" aria-hidden="true" tabindex="-1"></a>        <span class="cf">raise</span> HTTPException(status_code<span class="op">=</span><span class="dv">400</span>, detail<span class="op">=</span><span class="st">"Maximum 10 images allowed per batch"</span>)</span>
<span id="cb2-83"><a href="#cb2-83" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-84"><a href="#cb2-84" aria-hidden="true" tabindex="-1"></a>    results <span class="op">=</span> []</span>
<span id="cb2-85"><a href="#cb2-85" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-86"><a href="#cb2-86" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> <span class="bu">file</span> <span class="kw">in</span> files:</span>
<span id="cb2-87"><a href="#cb2-87" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="kw">not</span> <span class="bu">file</span>.content_type.startswith(<span class="st">"image/"</span>):</span>
<span id="cb2-88"><a href="#cb2-88" aria-hidden="true" tabindex="-1"></a>            results.append({</span>
<span id="cb2-89"><a href="#cb2-89" aria-hidden="true" tabindex="-1"></a>                <span class="st">"filename"</span>: <span class="bu">file</span>.filename,</span>
<span id="cb2-90"><a href="#cb2-90" aria-hidden="true" tabindex="-1"></a>                <span class="st">"error"</span>: <span class="st">"File must be an image"</span></span>
<span id="cb2-91"><a href="#cb2-91" aria-hidden="true" tabindex="-1"></a>            })</span>
<span id="cb2-92"><a href="#cb2-92" aria-hidden="true" tabindex="-1"></a>            <span class="cf">continue</span></span>
<span id="cb2-93"><a href="#cb2-93" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-94"><a href="#cb2-94" aria-hidden="true" tabindex="-1"></a>        <span class="cf">try</span>:</span>
<span id="cb2-95"><a href="#cb2-95" aria-hidden="true" tabindex="-1"></a>            image_data <span class="op">=</span> <span class="cf">await</span> <span class="bu">file</span>.read()</span>
<span id="cb2-96"><a href="#cb2-96" aria-hidden="true" tabindex="-1"></a>            image <span class="op">=</span> Image.<span class="bu">open</span>(io.BytesIO(image_data))</span>
<span id="cb2-97"><a href="#cb2-97" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb2-98"><a href="#cb2-98" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> image.mode <span class="op">!=</span> <span class="st">'RGB'</span>:</span>
<span id="cb2-99"><a href="#cb2-99" aria-hidden="true" tabindex="-1"></a>                image <span class="op">=</span> image.convert(<span class="st">'RGB'</span>)</span>
<span id="cb2-100"><a href="#cb2-100" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb2-101"><a href="#cb2-101" aria-hidden="true" tabindex="-1"></a>            processed_image <span class="op">=</span> preprocess_image(image)</span>
<span id="cb2-102"><a href="#cb2-102" aria-hidden="true" tabindex="-1"></a>            predictions <span class="op">=</span> model_handler.predict(processed_image)</span>
<span id="cb2-103"><a href="#cb2-103" aria-hidden="true" tabindex="-1"></a>            decoded_predictions <span class="op">=</span> decode_predictions(predictions, top<span class="op">=</span><span class="dv">3</span>)</span>
<span id="cb2-104"><a href="#cb2-104" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb2-105"><a href="#cb2-105" aria-hidden="true" tabindex="-1"></a>            results.append({</span>
<span id="cb2-106"><a href="#cb2-106" aria-hidden="true" tabindex="-1"></a>                <span class="st">"filename"</span>: <span class="bu">file</span>.filename,</span>
<span id="cb2-107"><a href="#cb2-107" aria-hidden="true" tabindex="-1"></a>                <span class="st">"predictions"</span>: decoded_predictions,</span>
<span id="cb2-108"><a href="#cb2-108" aria-hidden="true" tabindex="-1"></a>                <span class="st">"success"</span>: <span class="va">True</span></span>
<span id="cb2-109"><a href="#cb2-109" aria-hidden="true" tabindex="-1"></a>            })</span>
<span id="cb2-110"><a href="#cb2-110" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb2-111"><a href="#cb2-111" aria-hidden="true" tabindex="-1"></a>        <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:</span>
<span id="cb2-112"><a href="#cb2-112" aria-hidden="true" tabindex="-1"></a>            results.append({</span>
<span id="cb2-113"><a href="#cb2-113" aria-hidden="true" tabindex="-1"></a>                <span class="st">"filename"</span>: <span class="bu">file</span>.filename,</span>
<span id="cb2-114"><a href="#cb2-114" aria-hidden="true" tabindex="-1"></a>                <span class="st">"error"</span>: <span class="bu">str</span>(e),</span>
<span id="cb2-115"><a href="#cb2-115" aria-hidden="true" tabindex="-1"></a>                <span class="st">"success"</span>: <span class="va">False</span></span>
<span id="cb2-116"><a href="#cb2-116" aria-hidden="true" tabindex="-1"></a>            })</span>
<span id="cb2-117"><a href="#cb2-117" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-118"><a href="#cb2-118" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> JSONResponse(content<span class="op">=</span>{<span class="st">"results"</span>: results})</span>
<span id="cb2-119"><a href="#cb2-119" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-120"><a href="#cb2-120" aria-hidden="true" tabindex="-1"></a><span class="at">@app.get</span>(<span class="st">"/model_info"</span>)</span>
<span id="cb2-121"><a href="#cb2-121" aria-hidden="true" tabindex="-1"></a><span class="cf">async</span> <span class="kw">def</span> model_info():</span>
<span id="cb2-122"><a href="#cb2-122" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Get model information"""</span></span>
<span id="cb2-123"><a href="#cb2-123" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> {</span>
<span id="cb2-124"><a href="#cb2-124" aria-hidden="true" tabindex="-1"></a>        <span class="st">"model_name"</span>: <span class="st">"MobileNetV2"</span>,</span>
<span id="cb2-125"><a href="#cb2-125" aria-hidden="true" tabindex="-1"></a>        <span class="st">"framework"</span>: <span class="st">"PyTorch"</span>,</span>
<span id="cb2-126"><a href="#cb2-126" aria-hidden="true" tabindex="-1"></a>        <span class="st">"input_size"</span>: [<span class="dv">224</span>, <span class="dv">224</span>],</span>
<span id="cb2-127"><a href="#cb2-127" aria-hidden="true" tabindex="-1"></a>        <span class="st">"num_classes"</span>: <span class="dv">1000</span>,</span>
<span id="cb2-128"><a href="#cb2-128" aria-hidden="true" tabindex="-1"></a>        <span class="st">"pretrained"</span>: <span class="va">True</span></span>
<span id="cb2-129"><a href="#cb2-129" aria-hidden="true" tabindex="-1"></a>    }</span>
<span id="cb2-130"><a href="#cb2-130" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-131"><a href="#cb2-131" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb2-132"><a href="#cb2-132" aria-hidden="true" tabindex="-1"></a>    uvicorn.run(app, host<span class="op">=</span><span class="st">"0.0.0.0"</span>, port<span class="op">=</span><span class="dv">8000</span>)</span></code></pre></div></div>
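<p>The <code>/predict</code> endpoint reads a multipart form field named <code>file</code>, which is what <code>curl -F "file=@cat.jpg" http://localhost:8000/predict</code> sends. To call the API from Python without extra dependencies, a minimal client sketch using only the standard library might look like the following. The URL assumes the server is running locally on port 8000, as in the <code>uvicorn.run</code> call above; <code>build_multipart</code>, <code>predict_file</code>, and <code>cat.jpg</code> are hypothetical names used for illustration.</p>

```python
import json
import mimetypes
import os
import urllib.request
import uuid


def build_multipart(field_name, filename, data, content_type=None):
    """Encode one file as a multipart/form-data body.

    Returns (body_bytes, content_type_header). Hypothetical helper.
    """
    if content_type is None:
        # Guess from the extension; the API rejects types not starting with "image/".
        content_type = mimetypes.guess_type(filename)[0] or "application/octet-stream"
    boundary = uuid.uuid4().hex
    head = (
        f"--{boundary}\r\n"
        f'Content-Disposition: form-data; name="{field_name}"; filename="{filename}"\r\n'
        f"Content-Type: {content_type}\r\n\r\n"
    ).encode("utf-8")
    tail = f"\r\n--{boundary}--\r\n".encode("utf-8")
    return head + data + tail, f"multipart/form-data; boundary={boundary}"


def predict_file(path, url="http://localhost:8000/predict"):
    """POST one image to the /predict endpoint and return the parsed JSON."""
    with open(path, "rb") as fh:
        data = fh.read()
    body, header = build_multipart("file", os.path.basename(path), data)
    request = urllib.request.Request(
        url, data=body, method="POST", headers={"Content-Type": header}
    )
    with urllib.request.urlopen(request) as response:
        return json.loads(response.read().decode("utf-8"))
```

<p>For example, <code>predict_file("cat.jpg")["predictions"]</code> returns the top-5 list built by <code>decode_predictions</code>. The content type is guessed from the file extension because the endpoint rejects any upload whose type does not start with <code>image/</code>.</p>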
</section>
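<p>As written, <code>/batch_predict</code> loops over the uploaded files and runs one forward pass per image, so the <code>predict_batch</code> method of <code>MobileNetV2Handler</code> is not used by the API. One possible refinement is to stack the valid images into a single batch and run one forward pass. The sketch below assumes that <code>preprocess_image</code> returns a NumPy array shaped <code>(3, 224, 224)</code> or <code>(1, 3, 224, 224)</code> (its definition in <code>app/utils.py</code> is not shown here). <code>stack_images</code> and <code>top_k</code> are hypothetical helpers, and <code>top_k</code> returns raw ImageNet class indices rather than the labels that <code>decode_predictions</code> produces.</p>

```python
import numpy as np


def stack_images(images):
    """Combine preprocessed images into one (N, 3, H, W) float32 batch.

    Accepts arrays shaped (3, H, W) or (1, 3, H, W); assumes every image
    was resized to the same H and W (224x224 for MobileNetV2).
    """
    batch = []
    for img in images:
        arr = np.asarray(img, dtype=np.float32)
        if arr.ndim == 3:
            arr = arr[np.newaxis, ...]  # add the missing batch dimension
        batch.append(arr)
    return np.concatenate(batch, axis=0)


def top_k(probabilities, k=3):
    """Return (class_index, probability) pairs per row, highest first."""
    results = []
    for row in probabilities:
        idx = np.argsort(row)[::-1][:k]  # indices sorted by descending probability
        results.append([(int(i), float(row[i])) for i in idx])
    return results
```

<p>Inside the endpoint this would become <code>probs = model_handler.predict_batch(stack_images(processed))</code>, giving one row of 1,000 probabilities per image. The trade-off is that files which fail to decode must be filtered out before stacking; otherwise a single bad image fails the whole batch instead of just its own entry.</p>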
<section id="appmodel_handler.py---pytorch-model-management" class="level3">
<h3 class="anchored" data-anchor-id="appmodel_handler.py---pytorch-model-management" id="appmodel_handler.py---pytorch-model-management"><code>app/model_handler.py</code> - PyTorch Model Management</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torchvision <span class="im">import</span> models</span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> logging</span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a>logger <span class="op">=</span> logging.getLogger(<span class="va">__name__</span>)</span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MobileNetV2Handler:</span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>):</span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> <span class="va">None</span></span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.device <span class="op">=</span> <span class="va">None</span></span>
<span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>._loaded <span class="op">=</span> <span class="va">False</span></span>
<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-15"><a href="#cb3-15" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> load_model(<span class="va">self</span>):</span>
<span id="cb3-16"><a href="#cb3-16" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Load pre-trained MobileNetV2 model"""</span></span>
<span id="cb3-17"><a href="#cb3-17" aria-hidden="true" tabindex="-1"></a>        <span class="cf">try</span>:</span>
<span id="cb3-18"><a href="#cb3-18" aria-hidden="true" tabindex="-1"></a>            logger.info(<span class="st">"Loading MobileNetV2 PyTorch model..."</span>)</span>
<span id="cb3-19"><a href="#cb3-19" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb3-20"><a href="#cb3-20" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Determine device (CPU/GPU)</span></span>
<span id="cb3-21"><a href="#cb3-21" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.device <span class="op">=</span> torch.device(<span class="st">'cuda'</span> <span class="cf">if</span> torch.cuda.is_available() <span class="cf">else</span> <span class="st">'cpu'</span>)</span>
<span id="cb3-22"><a href="#cb3-22" aria-hidden="true" tabindex="-1"></a>            logger.info(<span class="ss">f"Using device: </span><span class="sc">{</span><span class="va">self</span><span class="sc">.</span>device<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb3-23"><a href="#cb3-23" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb3-24"><a href="#cb3-24" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Load ImageNet-pretrained MobileNetV2 (weights API, torchvision 0.13+)</span></span>
<span id="cb3-25"><a href="#cb3-25" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.model <span class="op">=</span> models.mobilenet_v2(weights<span class="op">=</span>models.MobileNet_V2_Weights.IMAGENET1K_V1)</span>
<span id="cb3-26"><a href="#cb3-26" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.model.<span class="bu">eval</span>()  <span class="co"># Set to evaluation mode</span></span>
<span id="cb3-27"><a href="#cb3-27" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.model.to(<span class="va">self</span>.device)</span>
<span id="cb3-28"><a href="#cb3-28" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb3-29"><a href="#cb3-29" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Warm up the model with a dummy prediction</span></span>
<span id="cb3-30"><a href="#cb3-30" aria-hidden="true" tabindex="-1"></a>            dummy_input <span class="op">=</span> torch.randn(<span class="dv">1</span>, <span class="dv">3</span>, <span class="dv">224</span>, <span class="dv">224</span>).to(<span class="va">self</span>.device)</span>
<span id="cb3-31"><a href="#cb3-31" aria-hidden="true" tabindex="-1"></a>            <span class="cf">with</span> torch.no_grad():</span>
<span id="cb3-32"><a href="#cb3-32" aria-hidden="true" tabindex="-1"></a>                _ <span class="op">=</span> <span class="va">self</span>.model(dummy_input)</span>
<span id="cb3-33"><a href="#cb3-33" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb3-34"><a href="#cb3-34" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>._loaded <span class="op">=</span> <span class="va">True</span></span>
<span id="cb3-35"><a href="#cb3-35" aria-hidden="true" tabindex="-1"></a>            logger.info(<span class="st">"Model loaded and warmed up successfully"</span>)</span>
<span id="cb3-36"><a href="#cb3-36" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb3-37"><a href="#cb3-37" aria-hidden="true" tabindex="-1"></a>        <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:</span>
<span id="cb3-38"><a href="#cb3-38" aria-hidden="true" tabindex="-1"></a>            logger.error(<span class="ss">f"Failed to load model: </span><span class="sc">{</span>e<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb3-39"><a href="#cb3-39" aria-hidden="true" tabindex="-1"></a>            <span class="cf">raise</span></span>
<span id="cb3-40"><a href="#cb3-40" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-41"><a href="#cb3-41" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> predict(<span class="va">self</span>, image_tensor):</span>
<span id="cb3-42"><a href="#cb3-42" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Make prediction on preprocessed image tensor"""</span></span>
<span id="cb3-43"><a href="#cb3-43" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="kw">not</span> <span class="va">self</span>._loaded:</span>
<span id="cb3-44"><a href="#cb3-44" aria-hidden="true" tabindex="-1"></a>            <span class="cf">raise</span> <span class="pp">RuntimeError</span>(<span class="st">"Model not loaded"</span>)</span>
<span id="cb3-45"><a href="#cb3-45" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-46"><a href="#cb3-46" aria-hidden="true" tabindex="-1"></a>        <span class="cf">try</span>:</span>
<span id="cb3-47"><a href="#cb3-47" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Convert NumPy input to a float32 tensor to match the model weights</span></span>
<span id="cb3-48"><a href="#cb3-48" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="bu">isinstance</span>(image_tensor, np.ndarray):</span>
<span id="cb3-49"><a href="#cb3-49" aria-hidden="true" tabindex="-1"></a>                image_tensor <span class="op">=</span> torch.from_numpy(image_tensor).<span class="bu">float</span>()</span>
<span id="cb3-50"><a href="#cb3-50" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Move the tensor to the model's device (CPU or GPU)</span></span>
<span id="cb3-51"><a href="#cb3-51" aria-hidden="true" tabindex="-1"></a>            image_tensor <span class="op">=</span> image_tensor.to(<span class="va">self</span>.device)</span>
<span id="cb3-52"><a href="#cb3-52" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb3-53"><a href="#cb3-53" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Ensure batch dimension</span></span>
<span id="cb3-54"><a href="#cb3-54" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="bu">len</span>(image_tensor.shape) <span class="op">==</span> <span class="dv">3</span>:</span>
<span id="cb3-55"><a href="#cb3-55" aria-hidden="true" tabindex="-1"></a>                image_tensor <span class="op">=</span> image_tensor.unsqueeze(<span class="dv">0</span>)</span>
<span id="cb3-56"><a href="#cb3-56" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb3-57"><a href="#cb3-57" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Make prediction</span></span>
<span id="cb3-58"><a href="#cb3-58" aria-hidden="true" tabindex="-1"></a>            <span class="cf">with</span> torch.no_grad():</span>
<span id="cb3-59"><a href="#cb3-59" aria-hidden="true" tabindex="-1"></a>                outputs <span class="op">=</span> <span class="va">self</span>.model(image_tensor)</span>
<span id="cb3-60"><a href="#cb3-60" aria-hidden="true" tabindex="-1"></a>                <span class="co"># Apply softmax to get probabilities</span></span>
<span id="cb3-61"><a href="#cb3-61" aria-hidden="true" tabindex="-1"></a>                probabilities <span class="op">=</span> torch.nn.functional.softmax(outputs, dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb3-62"><a href="#cb3-62" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb3-63"><a href="#cb3-63" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> probabilities.cpu().numpy()</span>
<span id="cb3-64"><a href="#cb3-64" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb3-65"><a href="#cb3-65" aria-hidden="true" tabindex="-1"></a>        <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:</span>
<span id="cb3-66"><a href="#cb3-66" aria-hidden="true" tabindex="-1"></a>            logger.error(<span class="ss">f"Prediction failed: </span><span class="sc">{</span>e<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb3-67"><a href="#cb3-67" aria-hidden="true" tabindex="-1"></a>            <span class="cf">raise</span></span>
<span id="cb3-68"><a href="#cb3-68" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-69"><a href="#cb3-69" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> predict_batch(<span class="va">self</span>, image_tensors):</span>
<span id="cb3-70"><a href="#cb3-70" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Make batch predictions"""</span></span>
<span id="cb3-71"><a href="#cb3-71" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="kw">not</span> <span class="va">self</span>._loaded:</span>
<span id="cb3-72"><a href="#cb3-72" aria-hidden="true" tabindex="-1"></a>            <span class="cf">raise</span> <span class="pp">RuntimeError</span>(<span class="st">"Model not loaded"</span>)</span>
<span id="cb3-73"><a href="#cb3-73" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-74"><a href="#cb3-74" aria-hidden="true" tabindex="-1"></a>        <span class="cf">try</span>:</span>
<span id="cb3-75"><a href="#cb3-75" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Convert a NumPy batch to a float32 tensor to match the model weights</span></span>
<span id="cb3-76"><a href="#cb3-76" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="bu">isinstance</span>(image_tensors, np.ndarray):</span>
<span id="cb3-77"><a href="#cb3-77" aria-hidden="true" tabindex="-1"></a>                image_tensors <span class="op">=</span> torch.from_numpy(image_tensors).<span class="bu">float</span>()</span>
<span id="cb3-78"><a href="#cb3-78" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb3-79"><a href="#cb3-79" aria-hidden="true" tabindex="-1"></a>            image_tensors <span class="op">=</span> image_tensors.to(<span class="va">self</span>.device)</span>
<span id="cb3-80"><a href="#cb3-80" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb3-81"><a href="#cb3-81" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Make batch prediction</span></span>
<span id="cb3-82"><a href="#cb3-82" aria-hidden="true" tabindex="-1"></a>            <span class="cf">with</span> torch.no_grad():</span>
<span id="cb3-83"><a href="#cb3-83" aria-hidden="true" tabindex="-1"></a>                outputs <span class="op">=</span> <span class="va">self</span>.model(image_tensors)</span>
<span id="cb3-84"><a href="#cb3-84" aria-hidden="true" tabindex="-1"></a>                probabilities <span class="op">=</span> torch.nn.functional.softmax(outputs, dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb3-85"><a href="#cb3-85" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb3-86"><a href="#cb3-86" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> probabilities.cpu().numpy()</span>
<span id="cb3-87"><a href="#cb3-87" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb3-88"><a href="#cb3-88" aria-hidden="true" tabindex="-1"></a>        <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:</span>
<span id="cb3-89"><a href="#cb3-89" aria-hidden="true" tabindex="-1"></a>            logger.error(<span class="ss">f"Batch prediction failed: </span><span class="sc">{</span>e<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb3-90"><a href="#cb3-90" aria-hidden="true" tabindex="-1"></a>            <span class="cf">raise</span></span>
<span id="cb3-91"><a href="#cb3-91" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-92"><a href="#cb3-92" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> is_loaded(<span class="va">self</span>):</span>
<span id="cb3-93"><a href="#cb3-93" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Check if model is loaded"""</span></span>
<span id="cb3-94"><a href="#cb3-94" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>._loaded</span>
<span id="cb3-95"><a href="#cb3-95" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-96"><a href="#cb3-96" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> get_device(<span class="va">self</span>):</span>
<span id="cb3-97"><a href="#cb3-97" aria-hidden="true" tabindex="-1"></a>        <span class="co">"""Get current device"""</span></span>
<span id="cb3-98"><a href="#cb3-98" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="bu">str</span>(<span class="va">self</span>.device) <span class="cf">if</span> <span class="va">self</span>.device <span class="cf">else</span> <span class="st">"not initialized"</span></span></code></pre></div></div>
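<p><code>predict_batch</code> calls <code>torch.nn.functional.softmax(outputs, dim=1)</code>. This normalizes each row of the <code>[batch, num_classes]</code> logits tensor independently, so every image gets its own probability distribution over the classes. The following dependency-free sketch illustrates that row-wise computation, including the usual max-subtraction that keeps <code>exp</code> from overflowing. <code>softmax_rows</code> is a hypothetical name used only for this illustration; the service itself should keep using the PyTorch call.</p>

```python
import math


def softmax_rows(logits):
    """Row-wise softmax: each row (one image's class logits) becomes a probability vector.

    Illustrative pure-Python equivalent of torch.nn.functional.softmax(x, dim=1)
    for a 2-D [batch, num_classes] input.
    """
    probs = []
    for row in logits:
        m = max(row)  # subtract the row max so exp() cannot overflow
        exps = [math.exp(x - m) for x in row]
        total = sum(exps)
        probs.append([e / total for e in exps])
    return probs


batch_logits = [[2.0, 1.0, 0.1], [1000.0, 1000.0, 1000.0]]
for row in softmax_rows(batch_logits):
    print([round(p, 3) for p in row])
```

<p>The second row shows why the max-subtraction matters: <code>math.exp(1000.0)</code> alone would overflow, but after shifting, the row evaluates cleanly to a uniform distribution.</p>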
</section>
<section id="apputils.py---utility-functions" class="level3">
<h3 class="anchored" data-anchor-id="apputils.py---utility-functions" id="apputils.py---utility-functions"><code>app/utils.py</code> - Utility Functions</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> PIL <span class="im">import</span> Image</span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torchvision <span class="im">import</span> transforms</span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> json</span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> os</span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> requests</span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> logging</span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a>logger <span class="op">=</span> logging.getLogger(<span class="va">__name__</span>)</span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a><span class="co"># ImageNet class labels</span></span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a>IMAGENET_CLASSES_URL <span class="op">=</span> <span class="st">"https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"</span></span>
<span id="cb4-14"><a href="#cb4-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-15"><a href="#cb4-15" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> get_imagenet_classes():</span>
<span id="cb4-16"><a href="#cb4-16" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Download and cache ImageNet class labels"""</span></span>
<span id="cb4-17"><a href="#cb4-17" aria-hidden="true" tabindex="-1"></a>    <span class="cf">try</span>:</span>
<span id="cb4-18"><a href="#cb4-18" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> os.path.exists(<span class="st">"imagenet_classes.txt"</span>):</span>
<span id="cb4-19"><a href="#cb4-19" aria-hidden="true" tabindex="-1"></a>            <span class="cf">with</span> <span class="bu">open</span>(<span class="st">"imagenet_classes.txt"</span>, <span class="st">"r"</span>) <span class="im">as</span> f:</span>
<span id="cb4-20"><a href="#cb4-20" aria-hidden="true" tabindex="-1"></a>                classes <span class="op">=</span> [line.strip() <span class="cf">for</span> line <span class="kw">in</span> f.readlines()]</span>
<span id="cb4-21"><a href="#cb4-21" aria-hidden="true" tabindex="-1"></a>        <span class="cf">else</span>:</span>
<span id="cb4-22"><a href="#cb4-22" aria-hidden="true" tabindex="-1"></a>            logger.info(<span class="st">"Downloading ImageNet class labels..."</span>)</span>
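<span id="cb4-23"><a href="#cb4-23" aria-hidden="true" tabindex="-1"></a>            response <span class="op">=</span> requests.get(IMAGENET_CLASSES_URL, timeout<span class="op">=</span><span class="dv">10</span>)</span>
<span id="cb4-23b"><a href="#cb4-23b" aria-hidden="true" tabindex="-1"></a>            response.raise_for_status()  <span class="co"># avoid caching an HTTP error page as class labels</span></span>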
<span id="cb4-24"><a href="#cb4-24" aria-hidden="true" tabindex="-1"></a>            classes <span class="op">=</span> response.text.strip().split(<span class="st">'</span><span class="ch">\n</span><span class="st">'</span>)</span>
<span id="cb4-25"><a href="#cb4-25" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb4-26"><a href="#cb4-26" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Cache the classes</span></span>
<span id="cb4-27"><a href="#cb4-27" aria-hidden="true" tabindex="-1"></a>            <span class="cf">with</span> <span class="bu">open</span>(<span class="st">"imagenet_classes.txt"</span>, <span class="st">"w"</span>) <span class="im">as</span> f:</span>
<span id="cb4-28"><a href="#cb4-28" aria-hidden="true" tabindex="-1"></a>                <span class="cf">for</span> class_name <span class="kw">in</span> classes:</span>
<span id="cb4-29"><a href="#cb4-29" aria-hidden="true" tabindex="-1"></a>                    f.write(<span class="ss">f"</span><span class="sc">{</span>class_name<span class="sc">}</span><span class="ch">\n</span><span class="ss">"</span>)</span>
<span id="cb4-30"><a href="#cb4-30" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-31"><a href="#cb4-31" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> classes</span>
<span id="cb4-32"><a href="#cb4-32" aria-hidden="true" tabindex="-1"></a>    <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:</span>
<span id="cb4-33"><a href="#cb4-33" aria-hidden="true" tabindex="-1"></a>        logger.warning(<span class="ss">f"Could not load ImageNet classes: </span><span class="sc">{</span>e<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb4-34"><a href="#cb4-34" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> [<span class="ss">f"class_</span><span class="sc">{</span>i<span class="sc">}</span><span class="ss">"</span> <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">1000</span>)]</span>
<span id="cb4-35"><a href="#cb4-35" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-36"><a href="#cb4-36" aria-hidden="true" tabindex="-1"></a><span class="co"># Load ImageNet classes</span></span>
<span id="cb4-37"><a href="#cb4-37" aria-hidden="true" tabindex="-1"></a>IMAGENET_CLASSES <span class="op">=</span> get_imagenet_classes()</span>
<span id="cb4-38"><a href="#cb4-38" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-39"><a href="#cb4-39" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> preprocess_image(image: Image.Image, target_size<span class="op">=</span>(<span class="dv">224</span>, <span class="dv">224</span>)):</span>
<span id="cb4-40"><a href="#cb4-40" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb4-41"><a href="#cb4-41" aria-hidden="true" tabindex="-1"></a><span class="co">    Preprocess a PIL image for an ImageNet-pretrained PyTorch model</span></span>
<span id="cb4-42"><a href="#cb4-42" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb4-43"><a href="#cb4-43" aria-hidden="true" tabindex="-1"></a>    <span class="cf">try</span>:</span>
<span id="cb4-44"><a href="#cb4-44" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Define transforms</span></span>
<span id="cb4-45"><a href="#cb4-45" aria-hidden="true" tabindex="-1"></a>        transform <span class="op">=</span> transforms.Compose([</span>
<span id="cb4-46"><a href="#cb4-46" aria-hidden="true" tabindex="-1"></a>            transforms.Resize(<span class="dv">256</span>),</span>
<span id="cb4-47"><a href="#cb4-47" aria-hidden="true" tabindex="-1"></a>            transforms.CenterCrop(target_size),</span>
<span id="cb4-48"><a href="#cb4-48" aria-hidden="true" tabindex="-1"></a>            transforms.ToTensor(),</span>
<span id="cb4-49"><a href="#cb4-49" aria-hidden="true" tabindex="-1"></a>            transforms.Normalize(</span>
<span id="cb4-50"><a href="#cb4-50" aria-hidden="true" tabindex="-1"></a>                mean<span class="op">=</span>[<span class="fl">0.485</span>, <span class="fl">0.456</span>, <span class="fl">0.406</span>],  <span class="co"># ImageNet normalization</span></span>
<span id="cb4-51"><a href="#cb4-51" aria-hidden="true" tabindex="-1"></a>                std<span class="op">=</span>[<span class="fl">0.229</span>, <span class="fl">0.224</span>, <span class="fl">0.225</span>]</span>
<span id="cb4-52"><a href="#cb4-52" aria-hidden="true" tabindex="-1"></a>            )</span>
<span id="cb4-53"><a href="#cb4-53" aria-hidden="true" tabindex="-1"></a>        ])</span>
<span id="cb4-54"><a href="#cb4-54" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-55"><a href="#cb4-55" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Apply transforms</span></span>
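<span id="cb4-56"><a href="#cb4-56" aria-hidden="true" tabindex="-1"></a>        image_tensor <span class="op">=</span> transform(image.convert(<span class="st">'RGB'</span>))  <span class="co"># force 3 channels; RGBA or grayscale input would break Normalize</span></span>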
<span id="cb4-57"><a href="#cb4-57" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-58"><a href="#cb4-58" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> image_tensor</span>
<span id="cb4-59"><a href="#cb4-59" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-60"><a href="#cb4-60" aria-hidden="true" tabindex="-1"></a>    <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:</span>
<span id="cb4-61"><a href="#cb4-61" aria-hidden="true" tabindex="-1"></a>        <span class="cf">raise</span> <span class="pp">ValueError</span>(<span class="ss">f"Image preprocessing failed: </span><span class="sc">{</span>e<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb4-62"><a href="#cb4-62" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-63"><a href="#cb4-63" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> preprocess_batch(images: <span class="bu">list</span>, target_size<span class="op">=</span>(<span class="dv">224</span>, <span class="dv">224</span>)):</span>
<span id="cb4-64"><a href="#cb4-64" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb4-65"><a href="#cb4-65" aria-hidden="true" tabindex="-1"></a><span class="co">    Preprocess batch of images</span></span>
<span id="cb4-66"><a href="#cb4-66" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb4-67"><a href="#cb4-67" aria-hidden="true" tabindex="-1"></a>    <span class="cf">try</span>:</span>
<span id="cb4-68"><a href="#cb4-68" aria-hidden="true" tabindex="-1"></a>        transform <span class="op">=</span> transforms.Compose([</span>
<span id="cb4-69"><a href="#cb4-69" aria-hidden="true" tabindex="-1"></a>            transforms.Resize(<span class="dv">256</span>),</span>
<span id="cb4-70"><a href="#cb4-70" aria-hidden="true" tabindex="-1"></a>            transforms.CenterCrop(target_size),</span>
<span id="cb4-71"><a href="#cb4-71" aria-hidden="true" tabindex="-1"></a>            transforms.ToTensor(),</span>
<span id="cb4-72"><a href="#cb4-72" aria-hidden="true" tabindex="-1"></a>            transforms.Normalize(</span>
<span id="cb4-73"><a href="#cb4-73" aria-hidden="true" tabindex="-1"></a>                mean<span class="op">=</span>[<span class="fl">0.485</span>, <span class="fl">0.456</span>, <span class="fl">0.406</span>],</span>
<span id="cb4-74"><a href="#cb4-74" aria-hidden="true" tabindex="-1"></a>                std<span class="op">=</span>[<span class="fl">0.229</span>, <span class="fl">0.224</span>, <span class="fl">0.225</span>]</span>
<span id="cb4-75"><a href="#cb4-75" aria-hidden="true" tabindex="-1"></a>            )</span>
<span id="cb4-76"><a href="#cb4-76" aria-hidden="true" tabindex="-1"></a>        ])</span>
<span id="cb4-77"><a href="#cb4-77" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-78"><a href="#cb4-78" aria-hidden="true" tabindex="-1"></a>        batch_tensors <span class="op">=</span> []</span>
<span id="cb4-79"><a href="#cb4-79" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> image <span class="kw">in</span> images:</span>
<span id="cb4-80"><a href="#cb4-80" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> <span class="bu">isinstance</span>(image, <span class="bu">str</span>):  <span class="co"># If path</span></span>
<span id="cb4-81"><a href="#cb4-81" aria-hidden="true" tabindex="-1"></a>                image <span class="op">=</span> Image.<span class="bu">open</span>(image).convert(<span class="st">'RGB'</span>)</span>
<span id="cb4-82"><a href="#cb4-82" aria-hidden="true" tabindex="-1"></a>            <span class="cf">elif</span> <span class="kw">not</span> <span class="bu">isinstance</span>(image, Image.Image):</span>
<span id="cb4-83"><a href="#cb4-83" aria-hidden="true" tabindex="-1"></a>                <span class="cf">raise</span> <span class="pp">ValueError</span>(<span class="st">"Invalid image type"</span>)</span>
<span id="cb4-84"><a href="#cb4-84" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb4-85"><a href="#cb4-85" aria-hidden="true" tabindex="-1"></a>            tensor <span class="op">=</span> transform(image.convert(<span class="st">'RGB'</span>))</span>
<span id="cb4-86"><a href="#cb4-86" aria-hidden="true" tabindex="-1"></a>            batch_tensors.append(tensor)</span>
<span id="cb4-87"><a href="#cb4-87" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-88"><a href="#cb4-88" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> torch.stack(batch_tensors)</span>
<span id="cb4-89"><a href="#cb4-89" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-90"><a href="#cb4-90" aria-hidden="true" tabindex="-1"></a>    <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:</span>
<span id="cb4-91"><a href="#cb4-91" aria-hidden="true" tabindex="-1"></a>        <span class="cf">raise</span> <span class="pp">ValueError</span>(<span class="ss">f"Batch preprocessing failed: </span><span class="sc">{</span>e<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb4-92"><a href="#cb4-92" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-93"><a href="#cb4-93" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> decode_predictions(predictions, top<span class="op">=</span><span class="dv">5</span>):</span>
<span id="cb4-94"><a href="#cb4-94" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb4-95"><a href="#cb4-95" aria-hidden="true" tabindex="-1"></a><span class="co">    Decode model predictions to human-readable labels</span></span>
<span id="cb4-96"><a href="#cb4-96" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb4-97"><a href="#cb4-97" aria-hidden="true" tabindex="-1"></a>    <span class="cf">try</span>:</span>
<span id="cb4-98"><a href="#cb4-98" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Get top predictions</span></span>
<span id="cb4-99"><a href="#cb4-99" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="bu">isinstance</span>(predictions, torch.Tensor):</span>
<span id="cb4-100"><a href="#cb4-100" aria-hidden="true" tabindex="-1"></a>            predictions <span class="op">=</span> predictions.detach().cpu().numpy()  <span class="co"># also handles GPU and grad-tracking tensors</span></span>
<span id="cb4-101"><a href="#cb4-101" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-102"><a href="#cb4-102" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Handle batch predictions (take first sample)</span></span>
<span id="cb4-103"><a href="#cb4-103" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="bu">len</span>(predictions.shape) <span class="op">&gt;</span> <span class="dv">1</span>:</span>
<span id="cb4-104"><a href="#cb4-104" aria-hidden="true" tabindex="-1"></a>            predictions <span class="op">=</span> predictions[<span class="dv">0</span>]</span>
<span id="cb4-105"><a href="#cb4-105" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-106"><a href="#cb4-106" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Get top k indices</span></span>
<span id="cb4-107"><a href="#cb4-107" aria-hidden="true" tabindex="-1"></a>        top_indices <span class="op">=</span> np.argsort(predictions)[<span class="op">-</span>top:][::<span class="op">-</span><span class="dv">1</span>]</span>
<span id="cb4-108"><a href="#cb4-108" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-109"><a href="#cb4-109" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Format results</span></span>
<span id="cb4-110"><a href="#cb4-110" aria-hidden="true" tabindex="-1"></a>        results <span class="op">=</span> []</span>
<span id="cb4-111"><a href="#cb4-111" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> idx <span class="kw">in</span> top_indices:</span>
<span id="cb4-112"><a href="#cb4-112" aria-hidden="true" tabindex="-1"></a>            confidence <span class="op">=</span> <span class="bu">float</span>(predictions[idx])</span>
<span id="cb4-113"><a href="#cb4-113" aria-hidden="true" tabindex="-1"></a>            class_name <span class="op">=</span> IMAGENET_CLASSES[idx] <span class="cf">if</span> idx <span class="op">&lt;</span> <span class="bu">len</span>(IMAGENET_CLASSES) <span class="cf">else</span> <span class="ss">f"class_</span><span class="sc">{</span>idx<span class="sc">}</span><span class="ss">"</span></span>
<span id="cb4-114"><a href="#cb4-114" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb4-115"><a href="#cb4-115" aria-hidden="true" tabindex="-1"></a>            results.append({</span>
<span id="cb4-116"><a href="#cb4-116" aria-hidden="true" tabindex="-1"></a>                <span class="st">"class_id"</span>: <span class="bu">int</span>(idx),</span>
<span id="cb4-117"><a href="#cb4-117" aria-hidden="true" tabindex="-1"></a>                <span class="st">"class_name"</span>: class_name,</span>
<span id="cb4-118"><a href="#cb4-118" aria-hidden="true" tabindex="-1"></a>                <span class="st">"confidence"</span>: confidence</span>
<span id="cb4-119"><a href="#cb4-119" aria-hidden="true" tabindex="-1"></a>            })</span>
<span id="cb4-120"><a href="#cb4-120" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-121"><a href="#cb4-121" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> results</span>
<span id="cb4-122"><a href="#cb4-122" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-123"><a href="#cb4-123" aria-hidden="true" tabindex="-1"></a>    <span class="cf">except</span> <span class="pp">Exception</span> <span class="im">as</span> e:</span>
<span id="cb4-124"><a href="#cb4-124" aria-hidden="true" tabindex="-1"></a>        <span class="cf">raise</span> <span class="pp">ValueError</span>(<span class="ss">f"Prediction decoding failed: </span><span class="sc">{</span>e<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb4-125"><a href="#cb4-125" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-126"><a href="#cb4-126" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> validate_image(image_data):</span>
<span id="cb4-127"><a href="#cb4-127" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""</span></span>
<span id="cb4-128"><a href="#cb4-128" aria-hidden="true" tabindex="-1"></a><span class="co">    Validate image data</span></span>
<span id="cb4-129"><a href="#cb4-129" aria-hidden="true" tabindex="-1"></a><span class="co">    """</span></span>
<span id="cb4-130"><a href="#cb4-130" aria-hidden="true" tabindex="-1"></a>    <span class="cf">try</span>:</span>
<span id="cb4-131"><a href="#cb4-131" aria-hidden="true" tabindex="-1"></a>        image <span class="op">=</span> Image.<span class="bu">open</span>(image_data)</span>
<span id="cb4-132"><a href="#cb4-132" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> image.<span class="bu">format</span> <span class="kw">in</span> [<span class="st">'JPEG'</span>, <span class="st">'PNG'</span>, <span class="st">'BMP'</span>, <span class="st">'TIFF'</span>, <span class="st">'WEBP'</span>]</span>
<span id="cb4-133"><a href="#cb4-133" aria-hidden="true" tabindex="-1"></a>    <span class="cf">except</span> <span class="pp">Exception</span>:</span>
<span id="cb4-134"><a href="#cb4-134" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">False</span></span>
<span id="cb4-135"><a href="#cb4-135" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-136"><a href="#cb4-136" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> tensor_to_numpy(tensor):</span>
<span id="cb4-137"><a href="#cb4-137" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Convert PyTorch tensor to numpy array"""</span></span>
<span id="cb4-138"><a href="#cb4-138" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> <span class="bu">isinstance</span>(tensor, torch.Tensor):</span>
<span id="cb4-139"><a href="#cb4-139" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> tensor.detach().cpu().numpy()</span>
<span id="cb4-140"><a href="#cb4-140" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> tensor</span>
<span id="cb4-141"><a href="#cb4-141" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-142"><a href="#cb4-142" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> numpy_to_tensor(array, device<span class="op">=</span><span class="st">'cpu'</span>):</span>
<span id="cb4-143"><a href="#cb4-143" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Convert numpy array to PyTorch tensor"""</span></span>
<span id="cb4-144"><a href="#cb4-144" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> <span class="bu">isinstance</span>(array, np.ndarray):</span>
<span id="cb4-145"><a href="#cb4-145" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> torch.from_numpy(array).to(device)</span>
<span id="cb4-146"><a href="#cb4-146" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> array</span>
<span id="cb4-147"><a href="#cb4-147" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-148"><a href="#cb4-148" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> ModelProfiler:</span>
<span id="cb4-149"><a href="#cb4-149" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Simple profiler for model performance"""</span></span>
<span id="cb4-150"><a href="#cb4-150" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-151"><a href="#cb4-151" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>):</span>
<span id="cb4-152"><a href="#cb4-152" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.inference_times <span class="op">=</span> []</span>
<span id="cb4-153"><a href="#cb4-153" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.preprocessing_times <span class="op">=</span> []</span>
<span id="cb4-154"><a href="#cb4-154" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-155"><a href="#cb4-155" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> record_inference_time(<span class="va">self</span>, time_ms):</span>
<span id="cb4-156"><a href="#cb4-156" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.inference_times.append(time_ms)</span>
<span id="cb4-157"><a href="#cb4-157" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-158"><a href="#cb4-158" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> record_preprocessing_time(<span class="va">self</span>, time_ms):</span>
<span id="cb4-159"><a href="#cb4-159" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.preprocessing_times.append(time_ms)</span>
<span id="cb4-160"><a href="#cb4-160" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-161"><a href="#cb4-161" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> get_stats(<span class="va">self</span>):</span>
<span id="cb4-162"><a href="#cb4-162" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="kw">not</span> <span class="va">self</span>.inference_times:</span>
<span id="cb4-163"><a href="#cb4-163" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> {<span class="st">"message"</span>: <span class="st">"No inference data recorded"</span>}</span>
<span id="cb4-164"><a href="#cb4-164" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-165"><a href="#cb4-165" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {</span>
<span id="cb4-166"><a href="#cb4-166" aria-hidden="true" tabindex="-1"></a>            <span class="st">"avg_inference_time_ms"</span>: np.mean(<span class="va">self</span>.inference_times),</span>
<span id="cb4-167"><a href="#cb4-167" aria-hidden="true" tabindex="-1"></a>            <span class="st">"avg_preprocessing_time_ms"</span>: np.mean(<span class="va">self</span>.preprocessing_times) <span class="cf">if</span> <span class="va">self</span>.preprocessing_times <span class="cf">else</span> <span class="dv">0</span>,</span>
<span id="cb4-168"><a href="#cb4-168" aria-hidden="true" tabindex="-1"></a>            <span class="st">"total_inferences"</span>: <span class="bu">len</span>(<span class="va">self</span>.inference_times),</span>
<span id="cb4-169"><a href="#cb4-169" aria-hidden="true" tabindex="-1"></a>            <span class="st">"min_inference_time_ms"</span>: np.<span class="bu">min</span>(<span class="va">self</span>.inference_times),</span>
<span id="cb4-170"><a href="#cb4-170" aria-hidden="true" tabindex="-1"></a>            <span class="st">"max_inference_time_ms"</span>: np.<span class="bu">max</span>(<span class="va">self</span>.inference_times)</span>
<span id="cb4-171"><a href="#cb4-171" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb4-172"><a href="#cb4-172" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-173"><a href="#cb4-173" aria-hidden="true" tabindex="-1"></a><span class="co"># Global profiler instance</span></span>
<span id="cb4-174"><a href="#cb4-174" aria-hidden="true" tabindex="-1"></a>profiler <span class="op">=</span> ModelProfiler()</span></code></pre></div></div>
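<p><code>ModelProfiler</code> expects callers to pass durations in milliseconds, but the utilities do not show how to measure them. One option is a small context manager built on <code>time.perf_counter()</code> that reports the block's wall-clock duration to any one-argument callable. This sketch assumes wall-clock time around the call is the quantity you want to track. <code>timed_ms</code> is a hypothetical helper and is not part of <code>app/utils.py</code>.</p>

```python
import time
from contextlib import contextmanager


@contextmanager
def timed_ms(record):
    """Hypothetical helper: pass the block's wall-clock duration in ms to record().

    record can be any callable taking one float, e.g. profiler.record_inference_time.
    """
    start = time.perf_counter()
    try:
        yield
    finally:
        # Runs even if the block raises, so failed calls are timed as well.
        record((time.perf_counter() - start) * 1000.0)


# Self-contained demo: collect timings in a plain list instead of the profiler.
timings = []
with timed_ms(timings.append):
    sum(i * i for i in range(100_000))  # stand-in for model inference
print(f"{timings[0]:.3f} ms")
```

<p>In the service, this could wrap the model call as <code>with timed_ms(profiler.record_inference_time):</code>. Because the recording happens in <code>finally</code>, failed calls are timed too; move it out of <code>finally</code> if only successful inferences should count. On a GPU, kernels run asynchronously, so the measurement is only meaningful if the timed block waits for the result. The <code>.cpu()</code> copy in the prediction methods provides that wait.</p>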
</section>
<section id="app__init__.py" class="level3">
<h3 class="anchored" data-anchor-id="app__init__.py" id="app__init__.py"><code>app/__init__.py</code></h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Empty file to make app a Python package</span></span></code></pre></div></div>
</section>
</section>
<section id="configuration-files" class="level2">
<h2 class="anchored" data-anchor-id="configuration-files" id="configuration-files">2. Configuration Files</h2>
<section id="requirements.txt" class="level3">
<h3 class="anchored" data-anchor-id="requirements.txt" id="requirements.txt"><code>requirements.txt</code></h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb6"><pre class="sourceCode txt code-with-copy"><code class="sourceCode default"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a>fastapi==0.104.1</span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a>uvicorn[standard]==0.24.0</span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a>torch==2.1.0</span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a>torchvision==0.16.0</span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a>Pillow==10.1.0</span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a>python-multipart==0.0.6</span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a>numpy==1.24.3</span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a>requests==2.31.0</span></code></pre></div></div>
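<p>Exact <code>==</code> pins only help if the running environment actually matches them. The following sketch is a rough sanity check. It assumes a pip-style file with one <code>name==version</code> pin per line and uses the standard library's <code>importlib.metadata</code> to list pins that are missing or installed at a different version. <code>check_pins</code> is a hypothetical helper, not part of the project, and it deliberately skips any line without an exact pin.</p>

```python
from importlib.metadata import PackageNotFoundError, version


def check_pins(lines):
    """Return (name, pinned, installed) for each 'name==version' pin that is not satisfied.

    Hypothetical helper. Comments, blank lines and non-'==' specifiers are skipped;
    installed is None when the package is not installed at all.
    """
    problems = []
    for raw in lines:
        line = raw.split("#", 1)[0].strip()  # drop trailing comments and whitespace
        if "==" not in line:
            continue
        name, pinned = (part.strip() for part in line.split("==", 1))
        dist = name.split("[", 1)[0]  # drop extras, e.g. uvicorn[standard] becomes uvicorn
        try:
            installed = version(dist)
        except PackageNotFoundError:
            installed = None
        # Ignore local version labels such as '+cpu' or '+cu121' on PyTorch wheels.
        if installed is None or installed.split("+", 1)[0] != pinned:
            problems.append((dist, pinned, installed))
    return problems
```

<p>Passing the lines of <code>requirements.txt</code> to <code>check_pins</code> inside the built image should return an empty list. Any tuple it does return names a package whose installed version drifted from the pin above.</p>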
</section>
<section id="dockerfile" class="level3">
<h3 class="anchored" data-anchor-id="dockerfile" id="dockerfile"><code>Dockerfile</code></h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb7"><pre class="sourceCode dockerfile code-with-copy"><code class="sourceCode dockerfile"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Use official Python runtime as base image</span></span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a><span class="kw">FROM</span> python:3.11-slim</span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Set working directory</span></span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a><span class="kw">WORKDIR</span> /app</span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Install system dependencies</span></span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a><span class="kw">RUN</span> <span class="ex">apt-get</span> update <span class="kw">&amp;&amp;</span> <span class="ex">apt-get</span> install <span class="at">-y</span> <span class="at">--no-install-recommends</span> <span class="dt">\</span></span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a>    libgl1 <span class="dt">\</span></span>
<span id="cb7-10"><a href="#cb7-10" aria-hidden="true" tabindex="-1"></a>    libglib2.0-0 <span class="dt">\</span></span>
<span id="cb7-11"><a href="#cb7-11" aria-hidden="true" tabindex="-1"></a>    libsm6 <span class="dt">\</span></span>
<span id="cb7-12"><a href="#cb7-12" aria-hidden="true" tabindex="-1"></a>    libxext6 <span class="dt">\</span></span>
<span id="cb7-13"><a href="#cb7-13" aria-hidden="true" tabindex="-1"></a>    libxrender1 <span class="dt">\</span></span>
<span id="cb7-14"><a href="#cb7-14" aria-hidden="true" tabindex="-1"></a>    libgomp1 <span class="dt">\</span></span>
<span id="cb7-15"><a href="#cb7-15" aria-hidden="true" tabindex="-1"></a>    wget <span class="dt">\</span></span>
<span id="cb7-16"><a href="#cb7-16" aria-hidden="true" tabindex="-1"></a>    curl <span class="dt">\</span></span>
<span id="cb7-17"><a href="#cb7-17" aria-hidden="true" tabindex="-1"></a>    <span class="kw">&amp;&amp;</span> <span class="fu">rm</span> <span class="at">-rf</span> /var/lib/apt/lists/<span class="pp">*</span></span>
<span id="cb7-18"><a href="#cb7-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-19"><a href="#cb7-19" aria-hidden="true" tabindex="-1"></a><span class="co"># Copy requirements first for better caching</span></span>
<span id="cb7-20"><a href="#cb7-20" aria-hidden="true" tabindex="-1"></a><span class="kw">COPY</span> requirements.txt .</span>
<span id="cb7-21"><a href="#cb7-21" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-22"><a href="#cb7-22" aria-hidden="true" tabindex="-1"></a><span class="co"># Install CPU-only PyTorch first (smaller image), pinned to match requirements.txt so the next step keeps these wheels</span></span>
<span id="cb7-23"><a href="#cb7-23" aria-hidden="true" tabindex="-1"></a><span class="kw">RUN</span> <span class="ex">pip</span> install <span class="at">--no-cache-dir</span> <span class="at">--upgrade</span> pip <span class="kw">&amp;&amp;</span> <span class="dt">\</span></span>
<span id="cb7-24"><a href="#cb7-24" aria-hidden="true" tabindex="-1"></a>    <span class="ex">pip</span> install <span class="at">--no-cache-dir</span> torch==2.1.0 torchvision==0.16.0 <span class="at">--index-url</span> https://download.pytorch.org/whl/cpu <span class="kw">&amp;&amp;</span> <span class="dt">\</span></span>
<span id="cb7-25"><a href="#cb7-25" aria-hidden="true" tabindex="-1"></a>    <span class="ex">pip</span> install <span class="at">--no-cache-dir</span> <span class="at">-r</span> requirements.txt</span>
<span id="cb7-26"><a href="#cb7-26" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-27"><a href="#cb7-27" aria-hidden="true" tabindex="-1"></a><span class="co"># Copy application code</span></span>
<span id="cb7-28"><a href="#cb7-28" aria-hidden="true" tabindex="-1"></a><span class="kw">COPY</span> app/ ./app/</span>
<span id="cb7-29"><a href="#cb7-29" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-30"><a href="#cb7-30" aria-hidden="true" tabindex="-1"></a><span class="co"># Create non-root user and writable cache/log dirs (a new named volume mounted there copies this ownership)</span></span>
<span id="cb7-31"><a href="#cb7-31" aria-hidden="true" tabindex="-1"></a><span class="kw">RUN</span> <span class="ex">useradd</span> <span class="at">--create-home</span> <span class="at">--shell</span> /bin/bash app <span class="kw">&amp;&amp;</span> <span class="dt">\</span></span>
<span id="cb7-32"><a href="#cb7-32" aria-hidden="true" tabindex="-1"></a>    <span class="fu">mkdir</span> <span class="at">-p</span> /app/.torch /app/logs <span class="kw">&amp;&amp;</span> <span class="fu">chown</span> <span class="at">-R</span> app:app /app</span>
<span id="cb7-33"><a href="#cb7-33" aria-hidden="true" tabindex="-1"></a><span class="kw">USER</span> app</span>
<span id="cb7-34"><a href="#cb7-34" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-35"><a href="#cb7-35" aria-hidden="true" tabindex="-1"></a><span class="co"># Pre-download ImageNet classes</span></span>
<span id="cb7-36"><a href="#cb7-36" aria-hidden="true" tabindex="-1"></a><span class="kw">RUN</span> <span class="ex">python</span> <span class="at">-c</span> <span class="st">"from app.utils import get_imagenet_classes; get_imagenet_classes()"</span></span>
<span id="cb7-37"><a href="#cb7-37" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-38"><a href="#cb7-38" aria-hidden="true" tabindex="-1"></a><span class="co"># Expose port</span></span>
<span id="cb7-39"><a href="#cb7-39" aria-hidden="true" tabindex="-1"></a><span class="kw">EXPOSE</span> 8000</span>
<span id="cb7-40"><a href="#cb7-40" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-41"><a href="#cb7-41" aria-hidden="true" tabindex="-1"></a><span class="co"># Health check</span></span>
<span id="cb7-42"><a href="#cb7-42" aria-hidden="true" tabindex="-1"></a><span class="kw">HEALTHCHECK</span> <span class="op">--interval=30s</span> <span class="op">--timeout=30s</span> <span class="op">--start-period=60s</span> <span class="op">--retries=3</span> \</span>
<span id="cb7-43"><a href="#cb7-43" aria-hidden="true" tabindex="-1"></a>    <span class="kw">CMD</span> <span class="ex">curl</span> <span class="at">-f</span> http://localhost:8000/health <span class="kw">||</span> <span class="bu">exit</span> 1</span>
<span id="cb7-44"><a href="#cb7-44" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-45"><a href="#cb7-45" aria-hidden="true" tabindex="-1"></a><span class="co"># Command to run the application</span></span>
<span id="cb7-46"><a href="#cb7-46" aria-hidden="true" tabindex="-1"></a><span class="kw">CMD</span> [<span class="st">"uvicorn"</span>, <span class="st">"app.main:app"</span>, <span class="st">"--host"</span>, <span class="st">"0.0.0.0"</span>, <span class="st">"--port"</span>, <span class="st">"8000"</span>]</span></code></pre></div></div>
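<p>The <code>HEALTHCHECK</code> above relies on <code>curl</code>, which the <code>python:3.11-slim</code> base image does not ship, which is why it appears in the apt package list. If you would rather not install it, a standard-library probe can do the same job, because Docker treats exit status 0 as healthy and 1 as unhealthy. This is a sketch under assumptions: the module path <code>app/healthcheck.py</code> and the helper names are hypothetical, and it targets the same <code>/health</code> route the curl command already uses.</p>

```python
"""Curl-free health probe for the container (hypothetical app/healthcheck.py)."""
import urllib.request


def is_healthy(url: str = "http://localhost:8000/health", timeout: float = 5.0) -> bool:
    """Return True only if the endpoint answers with HTTP 200."""
    try:
        with urllib.request.urlopen(url, timeout=timeout) as resp:
            return resp.status == 200
    except OSError:
        # Covers connection refused, timeouts and HTTP error statuses
        # (urllib.error.HTTPError is a subclass of URLError, itself an OSError).
        return False


def main() -> int:
    # Docker reads the exit status: 0 = healthy, 1 = unhealthy.
    return 0 if is_healthy() else 1
```

<p>Because the file lives under <code>app/</code>, the existing <code>COPY app/ ./app/</code> step picks it up, and the instruction could become <code>HEALTHCHECK --interval=30s --timeout=30s --start-period=60s --retries=3 CMD ["python", "-c", "import sys, app.healthcheck as h; sys.exit(h.main())"]</code>. The <code>test</code> entry in <code>docker-compose.yml</code> would need the same change.</p>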
</section>
<section id="dockerfile.gpu-for-gpu-support" class="level3">
<h3 class="anchored" data-anchor-id="dockerfile.gpu-for-gpu-support" id="dockerfile.gpu-for-gpu-support"><code>Dockerfile.gpu</code> (For GPU Support)</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb8"><pre class="sourceCode dockerfile code-with-copy"><code class="sourceCode dockerfile"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Use NVIDIA PyTorch base image</span></span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a><span class="kw">FROM</span> pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime</span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Set working directory</span></span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a><span class="kw">WORKDIR</span> /app</span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Install system dependencies</span></span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a><span class="kw">RUN</span> <span class="ex">apt-get</span> update <span class="kw">&amp;&amp;</span> <span class="ex">apt-get</span> install <span class="at">-y</span> <span class="dt">\</span></span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a>    curl <span class="dt">\</span></span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a>    wget <span class="dt">\</span></span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a>    <span class="kw">&amp;&amp;</span> <span class="fu">rm</span> <span class="at">-rf</span> /var/lib/apt/lists/<span class="pp">*</span></span>
<span id="cb8-12"><a href="#cb8-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-13"><a href="#cb8-13" aria-hidden="true" tabindex="-1"></a><span class="co"># Copy requirements first for better caching</span></span>
<span id="cb8-14"><a href="#cb8-14" aria-hidden="true" tabindex="-1"></a><span class="kw">COPY</span> requirements.txt .</span>
<span id="cb8-15"><a href="#cb8-15" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-16"><a href="#cb8-16" aria-hidden="true" tabindex="-1"></a><span class="co"># Install additional dependencies</span></span>
<span id="cb8-17"><a href="#cb8-17" aria-hidden="true" tabindex="-1"></a><span class="kw">RUN</span> <span class="ex">pip</span> install <span class="at">--no-cache-dir</span> <span class="at">--upgrade</span> pip <span class="kw">&amp;&amp;</span> <span class="dt">\</span></span>
<span id="cb8-18"><a href="#cb8-18" aria-hidden="true" tabindex="-1"></a>    <span class="ex">pip</span> install <span class="at">--no-cache-dir</span> <span class="at">-r</span> requirements.txt</span>
<span id="cb8-19"><a href="#cb8-19" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-20"><a href="#cb8-20" aria-hidden="true" tabindex="-1"></a><span class="co"># Copy application code</span></span>
<span id="cb8-21"><a href="#cb8-21" aria-hidden="true" tabindex="-1"></a><span class="kw">COPY</span> app/ ./app/</span>
<span id="cb8-22"><a href="#cb8-22" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-23"><a href="#cb8-23" aria-hidden="true" tabindex="-1"></a><span class="co"># Create non-root user and writable cache/log dirs (a new named volume mounted there copies this ownership)</span></span>
<span id="cb8-24"><a href="#cb8-24" aria-hidden="true" tabindex="-1"></a><span class="kw">RUN</span> <span class="ex">useradd</span> <span class="at">--create-home</span> <span class="at">--shell</span> /bin/bash app <span class="kw">&amp;&amp;</span> <span class="dt">\</span></span>
<span id="cb8-25"><a href="#cb8-25" aria-hidden="true" tabindex="-1"></a>    <span class="fu">mkdir</span> <span class="at">-p</span> /app/.torch /app/logs <span class="kw">&amp;&amp;</span> <span class="fu">chown</span> <span class="at">-R</span> app:app /app</span>
<span id="cb8-26"><a href="#cb8-26" aria-hidden="true" tabindex="-1"></a><span class="kw">USER</span> app</span>
<span id="cb8-27"><a href="#cb8-27" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-28"><a href="#cb8-28" aria-hidden="true" tabindex="-1"></a><span class="co"># Pre-download ImageNet classes</span></span>
<span id="cb8-29"><a href="#cb8-29" aria-hidden="true" tabindex="-1"></a><span class="kw">RUN</span> <span class="ex">python</span> <span class="at">-c</span> <span class="st">"from app.utils import get_imagenet_classes; get_imagenet_classes()"</span></span>
<span id="cb8-30"><a href="#cb8-30" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-31"><a href="#cb8-31" aria-hidden="true" tabindex="-1"></a><span class="co"># Expose port</span></span>
<span id="cb8-32"><a href="#cb8-32" aria-hidden="true" tabindex="-1"></a><span class="kw">EXPOSE</span> 8000</span>
<span id="cb8-33"><a href="#cb8-33" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-34"><a href="#cb8-34" aria-hidden="true" tabindex="-1"></a><span class="co"># Health check</span></span>
<span id="cb8-35"><a href="#cb8-35" aria-hidden="true" tabindex="-1"></a><span class="kw">HEALTHCHECK</span> <span class="op">--interval=30s</span> <span class="op">--timeout=30s</span> <span class="op">--start-period=60s</span> <span class="op">--retries=3</span> \</span>
<span id="cb8-36"><a href="#cb8-36" aria-hidden="true" tabindex="-1"></a>    <span class="kw">CMD</span> <span class="ex">curl</span> <span class="at">-f</span> http://localhost:8000/health <span class="kw">||</span> <span class="bu">exit</span> 1</span>
<span id="cb8-37"><a href="#cb8-37" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-38"><a href="#cb8-38" aria-hidden="true" tabindex="-1"></a><span class="co"># Command to run the application</span></span>
<span id="cb8-39"><a href="#cb8-39" aria-hidden="true" tabindex="-1"></a><span class="kw">CMD</span> [<span class="st">"uvicorn"</span>, <span class="st">"app.main:app"</span>, <span class="st">"--host"</span>, <span class="st">"0.0.0.0"</span>, <span class="st">"--port"</span>, <span class="st">"8000"</span>]</span></code></pre></div></div>
</section>
<section id="docker-compose.yml" class="level3">
<h3 class="anchored" data-anchor-id="docker-compose.yml" id="docker-compose.yml"><code>docker-compose.yml</code></h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb9"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="fu">version</span><span class="kw">:</span><span class="at"> </span><span class="st">'3.8'</span></span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a><span class="fu">services</span><span class="kw">:</span></span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">mobilenetv2-pytorch-api</span><span class="kw">:</span></span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">build</span><span class="kw">:</span><span class="at"> .</span></span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">ports</span><span class="kw">:</span></span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="st">"8000:8000"</span></span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">environment</span><span class="kw">:</span></span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> PYTHONPATH=/app</span></span>
<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> TORCH_HOME=/app/.torch</span></span>
<span id="cb9-11"><a href="#cb9-11" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">volumes</span><span class="kw">:</span></span>
<span id="cb9-12"><a href="#cb9-12" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> ./logs:/app/logs</span></span>
<span id="cb9-13"><a href="#cb9-13" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> torch_cache:/app/.torch</span></span>
<span id="cb9-14"><a href="#cb9-14" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">restart</span><span class="kw">:</span><span class="at"> unless-stopped</span></span>
<span id="cb9-15"><a href="#cb9-15" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">healthcheck</span><span class="kw">:</span></span>
<span id="cb9-16"><a href="#cb9-16" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">test</span><span class="kw">:</span><span class="at"> </span><span class="kw">[</span><span class="st">"CMD"</span><span class="kw">,</span><span class="at"> </span><span class="st">"curl"</span><span class="kw">,</span><span class="at"> </span><span class="st">"-f"</span><span class="kw">,</span><span class="at"> </span><span class="st">"http://localhost:8000/health"</span><span class="kw">]</span></span>
<span id="cb9-17"><a href="#cb9-17" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">interval</span><span class="kw">:</span><span class="at"> 30s</span></span>
<span id="cb9-18"><a href="#cb9-18" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">timeout</span><span class="kw">:</span><span class="at"> 10s</span></span>
<span id="cb9-19"><a href="#cb9-19" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">retries</span><span class="kw">:</span><span class="at"> </span><span class="dv">3</span></span>
<span id="cb9-20"><a href="#cb9-20" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">start_period</span><span class="kw">:</span><span class="at"> 60s</span></span>
<span id="cb9-21"><a href="#cb9-21" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">deploy</span><span class="kw">:</span></span>
<span id="cb9-22"><a href="#cb9-22" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">resources</span><span class="kw">:</span></span>
<span id="cb9-23"><a href="#cb9-23" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">limits</span><span class="kw">:</span></span>
<span id="cb9-24"><a href="#cb9-24" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="fu">memory</span><span class="kw">:</span><span class="at"> 2G</span></span>
<span id="cb9-25"><a href="#cb9-25" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">reservations</span><span class="kw">:</span></span>
<span id="cb9-26"><a href="#cb9-26" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="fu">memory</span><span class="kw">:</span><span class="at"> 1G</span></span>
<span id="cb9-27"><a href="#cb9-27" aria-hidden="true" tabindex="-1"></a></span>
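<span id="cb9-28"><a href="#cb9-28" aria-hidden="true" tabindex="-1"></a><span class="co">  # GPU version (uncomment and modify as needed; requires the NVIDIA Container Toolkit, and needs a different host port if the CPU service also runs)</span></span>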
<span id="cb9-29"><a href="#cb9-29" aria-hidden="true" tabindex="-1"></a><span class="co">  # mobilenetv2-pytorch-gpu:</span></span>
<span id="cb9-30"><a href="#cb9-30" aria-hidden="true" tabindex="-1"></a><span class="co">  #   build:</span></span>
<span id="cb9-31"><a href="#cb9-31" aria-hidden="true" tabindex="-1"></a><span class="co">  #     context: .</span></span>
<span id="cb9-32"><a href="#cb9-32" aria-hidden="true" tabindex="-1"></a><span class="co">  #     dockerfile: Dockerfile.gpu</span></span>
<span id="cb9-33"><a href="#cb9-33" aria-hidden="true" tabindex="-1"></a><span class="co">  #   ports:</span></span>
<span id="cb9-34"><a href="#cb9-34" aria-hidden="true" tabindex="-1"></a><span class="co">  #     - "8000:8000"</span></span>
<span id="cb9-35"><a href="#cb9-35" aria-hidden="true" tabindex="-1"></a><span class="co">  #   environment:</span></span>
<span id="cb9-36"><a href="#cb9-36" aria-hidden="true" tabindex="-1"></a><span class="co">  #     - PYTHONPATH=/app</span></span>
<span id="cb9-37"><a href="#cb9-37" aria-hidden="true" tabindex="-1"></a><span class="co">  #     - TORCH_HOME=/app/.torch</span></span>
<span id="cb9-38"><a href="#cb9-38" aria-hidden="true" tabindex="-1"></a><span class="co">  #     - NVIDIA_VISIBLE_DEVICES=all</span></span>
<span id="cb9-39"><a href="#cb9-39" aria-hidden="true" tabindex="-1"></a><span class="co">  #   volumes:</span></span>
<span id="cb9-40"><a href="#cb9-40" aria-hidden="true" tabindex="-1"></a><span class="co">  #     - ./logs:/app/logs</span></span>
<span id="cb9-41"><a href="#cb9-41" aria-hidden="true" tabindex="-1"></a><span class="co">  #     - torch_cache:/app/.torch</span></span>
<span id="cb9-42"><a href="#cb9-42" aria-hidden="true" tabindex="-1"></a><span class="co">  #   restart: unless-stopped</span></span>
<span id="cb9-43"><a href="#cb9-43" aria-hidden="true" tabindex="-1"></a><span class="co">  #   deploy:</span></span>
<span id="cb9-44"><a href="#cb9-44" aria-hidden="true" tabindex="-1"></a><span class="co">  #     resources:</span></span>
<span id="cb9-45"><a href="#cb9-45" aria-hidden="true" tabindex="-1"></a><span class="co">  #       reservations:</span></span>
<span id="cb9-46"><a href="#cb9-46" aria-hidden="true" tabindex="-1"></a><span class="co">  #         devices:</span></span>
<span id="cb9-47"><a href="#cb9-47" aria-hidden="true" tabindex="-1"></a><span class="co">  #           - driver: nvidia</span></span>
<span id="cb9-48"><a href="#cb9-48" aria-hidden="true" tabindex="-1"></a><span class="co">  #             count: 1</span></span>
<span id="cb9-49"><a href="#cb9-49" aria-hidden="true" tabindex="-1"></a><span class="co">  #             capabilities: [gpu]</span></span>
<span id="cb9-50"><a href="#cb9-50" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-51"><a href="#cb9-51" aria-hidden="true" tabindex="-1"></a><span class="co">  # Optional: Add nginx for production</span></span>
<span id="cb9-52"><a href="#cb9-52" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">nginx</span><span class="kw">:</span></span>
<span id="cb9-53"><a href="#cb9-53" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">image</span><span class="kw">:</span><span class="at"> nginx:alpine</span></span>
<span id="cb9-54"><a href="#cb9-54" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">ports</span><span class="kw">:</span></span>
<span id="cb9-55"><a href="#cb9-55" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="st">"80:80"</span></span>
<span id="cb9-56"><a href="#cb9-56" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">volumes</span><span class="kw">:</span></span>
<span id="cb9-57"><a href="#cb9-57" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> ./nginx.conf:/etc/nginx/nginx.conf:ro</span></span>
<span id="cb9-58"><a href="#cb9-58" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">depends_on</span><span class="kw">:</span></span>
<span id="cb9-59"><a href="#cb9-59" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> mobilenetv2-pytorch-api</span></span>
<span id="cb9-60"><a href="#cb9-60" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">restart</span><span class="kw">:</span><span class="at"> unless-stopped</span></span>
<span id="cb9-61"><a href="#cb9-61" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-62"><a href="#cb9-62" aria-hidden="true" tabindex="-1"></a><span class="fu">volumes</span><span class="kw">:</span></span>
<span id="cb9-63"><a href="#cb9-63" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">torch_cache</span><span class="kw">:</span></span></code></pre></div></div>
</section>
<section id="dockerignore" class="level3">
<h3 class="anchored" data-anchor-id="dockerignore" id="dockerignore"><code>.dockerignore</code></h3>
<pre><code>**/__pycache__
**/*.pyc
**/*.pyo
**/*.pyd
.Python
env/
pip-log.txt
pip-delete-this-directory.txt
.git
.gitignore
README.md
.pytest_cache
.coverage
.nyc_output
node_modules
**/.DS_Store
**/*.log
logs/
**/*.pth
**/*.pt
.torch/</code></pre>
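<p>Docker matches <code>.dockerignore</code> patterns against paths relative to the root of the build context. An entry such as <code>*.pyc</code> or <code>__pycache__</code> without a <code>**/</code> prefix therefore only matches at the top level, and <code>**/*.pyc</code> is needed to reach files under <code>app/__pycache__/</code>. The function below is a simplified, hypothetical approximation of those rules for previewing a pattern list. It ignores <code>!</code> exception lines and <code>[...]</code> character classes, so treat <code>docker build</code> itself as the source of truth.</p>

```python
"""Rough preview of which build-context paths a .dockerignore would exclude.

Simplified sketch: supports *, ?, ** and trailing-slash directories, but not
'!' exception lines or [...] character classes, which Docker also accepts.
"""
import re


def pattern_to_regex(pattern: str) -> "re.Pattern[str]":
    # Patterns are relative to the context root, so leading/trailing slashes
    # do not change their meaning.
    pattern = pattern.strip().strip("/")
    out, i = [], 0
    while i < len(pattern):
        if pattern.startswith("**/", i):
            out.append("(?:.*/)?")  # zero or more directories
            i += 3
        elif pattern.startswith("**", i):
            out.append(".*")  # anything, including slashes
            i += 2
        elif pattern[i] == "*":
            out.append("[^/]*")  # any characters within a single path segment
            i += 1
        elif pattern[i] == "?":
            out.append("[^/]")  # exactly one non-slash character
            i += 1
        else:
            out.append(re.escape(pattern[i]))
            i += 1
    return re.compile("".join(out))


def is_excluded(path: str, patterns: list[str]) -> bool:
    regexes = [
        pattern_to_regex(p)
        for p in patterns
        if p.strip() and not p.strip().startswith(("#", "!"))
    ]
    parts = path.strip("/").split("/")
    # Excluding a directory excludes everything beneath it, so test every prefix.
    for n in range(1, len(parts) + 1):
        prefix = "/".join(parts[:n])
        if any(rx.fullmatch(prefix) for rx in regexes):
            return True
    return False
```

<p>For example, <code>is_excluded("app/__pycache__/main.cpython-311.pyc", ["*.pyc"])</code> returns <code>False</code>, while the same path checked against <code>["**/*.pyc"]</code> returns <code>True</code>. To preview a real file, pass <code>open(".dockerignore").read().splitlines()</code> as the pattern list.</p>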
</section>
<section id="nginx.conf-optional---for-production" class="level3">
<h3 class="anchored" data-anchor-id="nginx.conf-optional---for-production" id="nginx.conf-optional---for-production"><code>nginx.conf</code> (Optional - for production)</h3>
<pre class="nginx"><code>events {
    worker_connections 1024;
}

http {
    upstream api {
        server mobilenetv2-pytorch-api:8000;
    }

    server {
        listen 80;
        client_max_body_size 10M;

        location / {
            proxy_pass http://api;
            proxy_set_header Host $host;
            proxy_set_header X-Real-IP $remote_addr;
            proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
            proxy_set_header X-Forwarded-Proto $scheme;
            proxy_read_timeout 300;
            proxy_connect_timeout 300;
            proxy_send_timeout 300;
        }
    }
}</code></pre>
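<p>Behind this proxy, the API sees the nginx container as the client, so the caller's address has to come from the headers nginx adds: <code>X-Real-IP</code> is set to <code>$remote_addr</code>, and <code>$proxy_add_x_forwarded_for</code> appends <code>$remote_addr</code> to any incoming <code>X-Forwarded-For</code>. The helper below is a minimal, hypothetical sketch of reading them safely. The function name and the <code>trusted_proxies</code> parameter are assumptions, not part of the original app.</p>

```python
"""Recover the original client IP for requests proxied by nginx (hypothetical helper)."""
from typing import Mapping


def client_ip(headers: Mapping[str, str], peer_ip: str, trusted_proxies: set[str]) -> str:
    # Forwarding headers are only trustworthy when the request came from our own
    # proxy; a client talking to the published port 8000 directly could send any value.
    if peer_ip not in trusted_proxies:
        return peer_ip
    lowered = {name.lower(): value for name, value in headers.items()}
    real_ip = lowered.get("x-real-ip", "").strip()
    if real_ip:
        return real_ip
    # nginx appends the address it saw, so the right-most entry is the one it
    # vouches for; anything to its left was supplied by the client.
    hops = [hop.strip() for hop in lowered.get("x-forwarded-for", "").split(",") if hop.strip()]
    return hops[-1] if hops else peer_ip
```

<p>Inside a FastAPI route this could be called as <code>client_ip(request.headers, request.client.host, {...})</code>, where the trusted set holds the nginx container's address on the Compose network; that address is an assumption you would need to look up or pin yourself. Uvicorn can also handle this itself through its <code>--proxy-headers</code> and <code>--forwarded-allow-ips</code> options.</p>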
</section>
</section>
<section id="deployment-commands" class="level2">
<h2 class="anchored" data-anchor-id="deployment-commands" id="deployment-commands">3. Deployment Commands</h2>
<section id="build-and-run-with-docker" class="level3">
<h3 class="anchored" data-anchor-id="build-and-run-with-docker" id="build-and-run-with-docker">Build and Run with Docker</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb12"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb12-1"><a href="#cb12-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Build the CPU image</span></span>
<span id="cb12-2"><a href="#cb12-2" aria-hidden="true" tabindex="-1"></a><span class="ex">docker</span> build <span class="at">-t</span> mobilenetv2-pytorch-api .</span>
<span id="cb12-3"><a href="#cb12-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-4"><a href="#cb12-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Build the GPU image (if you have an NVIDIA GPU)</span></span>
<span id="cb12-5"><a href="#cb12-5" aria-hidden="true" tabindex="-1"></a><span class="ex">docker</span> build <span class="at">-f</span> Dockerfile.gpu <span class="at">-t</span> mobilenetv2-pytorch-gpu .</span>
<span id="cb12-6"><a href="#cb12-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-7"><a href="#cb12-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Run CPU version</span></span>
<span id="cb12-8"><a href="#cb12-8" aria-hidden="true" tabindex="-1"></a><span class="ex">docker</span> run <span class="at">-p</span> 8000:8000 mobilenetv2-pytorch-api</span>
<span id="cb12-9"><a href="#cb12-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-10"><a href="#cb12-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Run GPU version (requires the NVIDIA Container Toolkit on the host)</span></span>
<span id="cb12-11"><a href="#cb12-11" aria-hidden="true" tabindex="-1"></a><span class="ex">docker</span> run <span class="at">--gpus</span> all <span class="at">-p</span> 8000:8000 mobilenetv2-pytorch-gpu</span>
<span id="cb12-12"><a href="#cb12-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-13"><a href="#cb12-13" aria-hidden="true" tabindex="-1"></a><span class="co"># Run with environment variables</span></span>
<span id="cb12-14"><a href="#cb12-14" aria-hidden="true" tabindex="-1"></a><span class="ex">docker</span> run <span class="at">-p</span> 8000:8000 <span class="at">-e</span> TORCH_HOME=/tmp/.torch mobilenetv2-pytorch-api</span></code></pre></div></div>
</section>
<section id="using-docker-compose" class="level3">
<h3 class="anchored" data-anchor-id="using-docker-compose" id="using-docker-compose">Using Docker Compose</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb13"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Build and start services (Compose V2; the legacy docker-compose binary takes the same arguments)</span></span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a><span class="ex">docker</span> compose up <span class="at">--build</span></span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-4"><a href="#cb13-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Run in background</span></span>
<span id="cb13-5"><a href="#cb13-5" aria-hidden="true" tabindex="-1"></a><span class="ex">docker</span> compose up <span class="at">-d</span></span>
<span id="cb13-6"><a href="#cb13-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-7"><a href="#cb13-7" aria-hidden="true" tabindex="-1"></a><span class="co"># View logs</span></span>
<span id="cb13-8"><a href="#cb13-8" aria-hidden="true" tabindex="-1"></a><span class="ex">docker</span> compose logs <span class="at">-f</span> mobilenetv2-pytorch-api</span>
<span id="cb13-9"><a href="#cb13-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-10"><a href="#cb13-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Stop services</span></span>
<span id="cb13-11"><a href="#cb13-11" aria-hidden="true" tabindex="-1"></a><span class="ex">docker</span> compose down</span></code></pre></div></div>
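<p>Running <code>up -d</code> returns once the containers have started, while the API may still be starting up (the healthcheck allows a 60-second start period). Recent Compose V2 releases offer <code>docker compose up --wait</code>, which waits on the healthchecks. Alternatively, a script or CI job can poll <code>/health</code> itself before sending its first request. The sketch below is illustrative: <code>wait_for_healthy</code> and its default timings are assumptions, not values taken from the compose file.</p>

```python
"""Poll the API's /health endpoint until it answers HTTP 200 or a deadline passes."""
import http.client
import time
import urllib.request


def wait_for_healthy(url: str = "http://localhost:8000/health",
                     timeout: float = 120.0, interval: float = 2.0) -> bool:
    deadline = time.monotonic() + timeout
    while True:
        try:
            with urllib.request.urlopen(url, timeout=interval) as resp:
                if resp.status == 200:
                    return True
        except (OSError, http.client.HTTPException):
            pass  # not ready yet: refused, reset, timed out, or an HTTP error status
        # Give up if another full interval would overrun the deadline.
        if time.monotonic() + interval > deadline:
            return False
        time.sleep(interval)
```

<p>A CI step might call <code>if not wait_for_healthy(): raise SystemExit("API did not become healthy")</code> right after <code>docker compose up -d</code> and before running the request examples.</p>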
</section>
</section>
<section id="usage-examples" class="level2">
<h2 class="anchored" data-anchor-id="usage-examples" id="usage-examples">4. Usage Examples</h2>
<section id="test-the-api" class="level3">
<h3 class="anchored" data-anchor-id="test-the-api" id="test-the-api">Test the API</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb14"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb14-1"><a href="#cb14-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Health check</span></span>
<span id="cb14-2"><a href="#cb14-2" aria-hidden="true" tabindex="-1"></a><span class="ex">curl</span> http://localhost:8000/health</span>
<span id="cb14-3"><a href="#cb14-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-4"><a href="#cb14-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Model info</span></span>
<span id="cb14-5"><a href="#cb14-5" aria-hidden="true" tabindex="-1"></a><span class="ex">curl</span> http://localhost:8000/model_info</span>
<span id="cb14-6"><a href="#cb14-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-7"><a href="#cb14-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Single image prediction</span></span>
<span id="cb14-8"><a href="#cb14-8" aria-hidden="true" tabindex="-1"></a><span class="ex">curl</span> <span class="at">-X</span> POST <span class="st">"http://localhost:8000/predict"</span> <span class="dt">\</span></span>
<span id="cb14-9"><a href="#cb14-9" aria-hidden="true" tabindex="-1"></a>     <span class="at">-H</span> <span class="st">"accept: application/json"</span> <span class="dt">\</span></span>
<span id="cb14-10"><a href="#cb14-10" aria-hidden="true" tabindex="-1"></a>     <span class="at">-H</span> <span class="st">"Content-Type: multipart/form-data"</span> <span class="dt">\</span></span>
<span id="cb14-11"><a href="#cb14-11" aria-hidden="true" tabindex="-1"></a>     <span class="at">-F</span> <span class="st">"file=@path/to/your/image.jpg"</span></span>
<span id="cb14-12"><a href="#cb14-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-13"><a href="#cb14-13" aria-hidden="true" tabindex="-1"></a><span class="co"># Batch prediction</span></span>
<span id="cb14-14"><a href="#cb14-14" aria-hidden="true" tabindex="-1"></a><span class="ex">curl</span> <span class="at">-X</span> POST <span class="st">"http://localhost:8000/batch_predict"</span> <span class="dt">\</span></span>
<span id="cb14-15"><a href="#cb14-15" aria-hidden="true" tabindex="-1"></a>     <span class="at">-H</span> <span class="st">"accept: application/json"</span> <span class="dt">\</span></span>
<span id="cb14-16"><a href="#cb14-16" aria-hidden="true" tabindex="-1"></a>     <span class="at">-H</span> <span class="st">"Content-Type: multipart/form-data"</span> <span class="dt">\</span></span>
<span id="cb14-17"><a href="#cb14-17" aria-hidden="true" tabindex="-1"></a>     <span class="at">-F</span> <span class="st">"files=@image1.jpg"</span> <span class="dt">\</span></span>
<span id="cb14-18"><a href="#cb14-18" aria-hidden="true" tabindex="-1"></a>     <span class="at">-F</span> <span class="st">"files=@image2.jpg"</span></span></code></pre></div></div>
</section>
<section id="python-client-example" class="level3">
<h3 class="anchored" data-anchor-id="python-client-example" id="python-client-example">Python Client Example</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb15"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb15-1"><a href="#cb15-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> json</span>
<span id="cb15-2"><a href="#cb15-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> contextlib <span class="im">import</span> ExitStack</span>
<span id="cb15-3"><a href="#cb15-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-4"><a href="#cb15-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> requests</span>
<span id="cb15-5"><a href="#cb15-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-6"><a href="#cb15-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Single prediction</span></span>
<span id="cb15-7"><a href="#cb15-7" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> predict_image(image_path, api_url<span class="op">=</span><span class="st">"http://localhost:8000"</span>, timeout<span class="op">=</span><span class="dv">30</span>):</span>
<span id="cb15-8"><a href="#cb15-8" aria-hidden="true" tabindex="-1"></a>    url <span class="op">=</span> <span class="ss">f"</span><span class="sc">{</span>api_url<span class="sc">}</span><span class="ss">/predict"</span></span>
<span id="cb15-9"><a href="#cb15-9" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Open the file in a context manager so the handle is always closed</span></span>
<span id="cb15-10"><a href="#cb15-10" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> <span class="bu">open</span>(image_path, <span class="st">"rb"</span>) <span class="im">as</span> f:</span>
<span id="cb15-11"><a href="#cb15-11" aria-hidden="true" tabindex="-1"></a>        response <span class="op">=</span> requests.post(url, files<span class="op">=</span>{<span class="st">"file"</span>: f}, timeout<span class="op">=</span>timeout)</span>
<span id="cb15-12"><a href="#cb15-12" aria-hidden="true" tabindex="-1"></a>    response.raise_for_status()  <span class="co"># Raise on 4xx/5xx instead of parsing an error body</span></span>
<span id="cb15-13"><a href="#cb15-13" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> response.json()</span>
<span id="cb15-14"><a href="#cb15-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-15"><a href="#cb15-15" aria-hidden="true" tabindex="-1"></a><span class="co"># Batch prediction</span></span>
<span id="cb15-16"><a href="#cb15-16" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> predict_batch(image_paths, api_url<span class="op">=</span><span class="st">"http://localhost:8000"</span>, timeout<span class="op">=</span><span class="dv">60</span>):</span>
<span id="cb15-17"><a href="#cb15-17" aria-hidden="true" tabindex="-1"></a>    url <span class="op">=</span> <span class="ss">f"</span><span class="sc">{</span>api_url<span class="sc">}</span><span class="ss">/batch_predict"</span></span>
<span id="cb15-18"><a href="#cb15-18" aria-hidden="true" tabindex="-1"></a>    <span class="co"># ExitStack closes every opened file, even if the request raises</span></span>
<span id="cb15-19"><a href="#cb15-19" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> ExitStack() <span class="im">as</span> stack:</span>
<span id="cb15-20"><a href="#cb15-20" aria-hidden="true" tabindex="-1"></a>        files <span class="op">=</span> [(<span class="st">"files"</span>, stack.enter_context(<span class="bu">open</span>(path, <span class="st">"rb"</span>))) <span class="cf">for</span> path <span class="kw">in</span> image_paths]</span>
<span id="cb15-21"><a href="#cb15-21" aria-hidden="true" tabindex="-1"></a>        response <span class="op">=</span> requests.post(url, files<span class="op">=</span>files, timeout<span class="op">=</span>timeout)</span>
<span id="cb15-22"><a href="#cb15-22" aria-hidden="true" tabindex="-1"></a>    response.raise_for_status()</span>
<span id="cb15-23"><a href="#cb15-23" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> response.json()</span>
<span id="cb15-24"><a href="#cb15-24" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-25"><a href="#cb15-25" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage</span></span>
<span id="cb15-26"><a href="#cb15-26" aria-hidden="true" tabindex="-1"></a>result <span class="op">=</span> predict_image(<span class="st">"cat.jpg"</span>)</span>
<span id="cb15-27"><a href="#cb15-27" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(json.dumps(result, indent<span class="op">=</span><span class="dv">2</span>))</span>
<span id="cb15-28"><a href="#cb15-28" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-29"><a href="#cb15-29" aria-hidden="true" tabindex="-1"></a>batch_result <span class="op">=</span> predict_batch([<span class="st">"cat.jpg"</span>, <span class="st">"dog.jpg"</span>])</span>
<span id="cb15-30"><a href="#cb15-30" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(json.dumps(batch_result, indent<span class="op">=</span><span class="dv">2</span>))</span></code></pre></div></div>
</section>
<section id="response-format" class="level3">
<h3 class="anchored" data-anchor-id="response-format" id="response-format">Response Format</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb16"><pre class="sourceCode json code-with-copy"><code class="sourceCode json"><span id="cb16-1"><a href="#cb16-1" aria-hidden="true" tabindex="-1"></a><span class="fu">{</span></span>
<span id="cb16-2"><a href="#cb16-2" aria-hidden="true" tabindex="-1"></a>  <span class="dt">"predictions"</span><span class="fu">:</span> <span class="ot">[</span></span>
<span id="cb16-3"><a href="#cb16-3" aria-hidden="true" tabindex="-1"></a>    <span class="fu">{</span></span>
<span id="cb16-4"><a href="#cb16-4" aria-hidden="true" tabindex="-1"></a>      <span class="dt">"class_id"</span><span class="fu">:</span> <span class="dv">281</span><span class="fu">,</span></span>
<span id="cb16-5"><a href="#cb16-5" aria-hidden="true" tabindex="-1"></a>      <span class="dt">"class_name"</span><span class="fu">:</span> <span class="st">"tabby"</span><span class="fu">,</span></span>
<span id="cb16-6"><a href="#cb16-6" aria-hidden="true" tabindex="-1"></a>      <span class="dt">"confidence"</span><span class="fu">:</span> <span class="fl">0.8234567</span></span>
<span id="cb16-7"><a href="#cb16-7" aria-hidden="true" tabindex="-1"></a>    <span class="fu">}</span><span class="ot">,</span></span>
<span id="cb16-8"><a href="#cb16-8" aria-hidden="true" tabindex="-1"></a>    <span class="fu">{</span></span>
<span id="cb16-9"><a href="#cb16-9" aria-hidden="true" tabindex="-1"></a>      <span class="dt">"class_id"</span><span class="fu">:</span> <span class="dv">282</span><span class="fu">,</span></span>
<span id="cb16-10"><a href="#cb16-10" aria-hidden="true" tabindex="-1"></a>      <span class="dt">"class_name"</span><span class="fu">:</span> <span class="st">"tiger_cat"</span><span class="fu">,</span></span>
<span id="cb16-11"><a href="#cb16-11" aria-hidden="true" tabindex="-1"></a>      <span class="dt">"confidence"</span><span class="fu">:</span> <span class="fl">0.1234567</span></span>
<span id="cb16-12"><a href="#cb16-12" aria-hidden="true" tabindex="-1"></a>    <span class="fu">}</span></span>
<span id="cb16-13"><a href="#cb16-13" aria-hidden="true" tabindex="-1"></a>  <span class="ot">]</span><span class="fu">,</span></span>
<span id="cb16-14"><a href="#cb16-14" aria-hidden="true" tabindex="-1"></a>  <span class="dt">"success"</span><span class="fu">:</span> <span class="kw">true</span></span>
<span id="cb16-15"><a href="#cb16-15" aria-hidden="true" tabindex="-1"></a><span class="fu">}</span></span></code></pre></div></div>
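<p>Clients usually need only the top-ranked class from this payload. The helper below is a minimal sketch that assumes the response shape shown above (a <code>success</code> flag plus a <code>predictions</code> list of <code>class_name</code>/<code>confidence</code> entries); <code>top_prediction</code> is a hypothetical name, not part of the API.</p>

```python
import json


def top_prediction(response_body):
    """Return (class_name, confidence) for the highest-confidence prediction.

    Accepts either the raw JSON string or an already-decoded dict, and assumes
    the response shape shown above ("success" flag + "predictions" list).
    """
    data = json.loads(response_body) if isinstance(response_body, str) else response_body
    if not data.get("success"):
        raise ValueError("prediction request was not successful")
    predictions = data.get("predictions", [])
    if not predictions:
        raise ValueError("response contains no predictions")
    # Do not rely on server-side ordering; pick the maximum explicitly
    best = max(predictions, key=lambda p: p["confidence"])
    return best["class_name"], best["confidence"]


example = {
    "predictions": [
        {"class_id": 281, "class_name": "tabby", "confidence": 0.8234567},
        {"class_id": 282, "class_name": "tiger_cat", "confidence": 0.1234567},
    ],
    "success": True,
}
print(top_prediction(example))  # ('tabby', 0.8234567)
```

<p>Selecting the maximum explicitly, rather than taking the first element, keeps the client correct even if the server ever stops sorting its predictions.</p>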
</section>
</section>
<section id="performance-optimization" class="level2">
<h2 class="anchored" data-anchor-id="performance-optimization" id="performance-optimization">5. Performance Optimization</h2>
<section id="model-optimization" class="level3">
<h3 class="anchored" data-anchor-id="model-optimization" id="model-optimization">Model Optimization</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb17"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb17-1"><a href="#cb17-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Add to model_handler.py for optimization</span></span>
<span id="cb17-2"><a href="#cb17-2" aria-hidden="true" tabindex="-1"></a><span class="co"># Assumes MobileNetV2Handler and logger are already defined in that module</span></span>
<span id="cb17-3"><a href="#cb17-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb17-4"><a href="#cb17-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-5"><a href="#cb17-5" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> OptimizedMobileNetV2Handler(MobileNetV2Handler):</span>
<span id="cb17-6"><a href="#cb17-6" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, use_jit<span class="op">=</span><span class="va">True</span>, use_half_precision<span class="op">=</span><span class="va">False</span>):</span>
<span id="cb17-7"><a href="#cb17-7" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Set the flags first in case the parent constructor calls load_model()</span></span>
<span id="cb17-8"><a href="#cb17-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.use_jit <span class="op">=</span> use_jit</span>
<span id="cb17-9"><a href="#cb17-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.use_half_precision <span class="op">=</span> use_half_precision</span>
<span id="cb17-10"><a href="#cb17-10" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb17-11"><a href="#cb17-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb17-12"><a href="#cb17-12" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> load_model(<span class="va">self</span>):</span>
<span id="cb17-13"><a href="#cb17-13" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().load_model()</span>
<span id="cb17-14"><a href="#cb17-14" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model.<span class="bu">eval</span>()  <span class="co"># Inference mode: disables dropout, uses BatchNorm running stats</span></span>
<span id="cb17-15"><a href="#cb17-15" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb17-16"><a href="#cb17-16" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="va">self</span>.use_half_precision <span class="kw">and</span> <span class="va">self</span>.device.<span class="bu">type</span> <span class="op">==</span> <span class="st">'cuda'</span>:</span>
<span id="cb17-17"><a href="#cb17-17" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Half precision for GPU; input tensors must also be cast with .half()</span></span>
<span id="cb17-18"><a href="#cb17-18" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.model <span class="op">=</span> <span class="va">self</span>.model.half()</span>
<span id="cb17-19"><a href="#cb17-19" aria-hidden="true" tabindex="-1"></a>            logger.info(<span class="st">"Model converted to half precision"</span>)</span>
<span id="cb17-20"><a href="#cb17-20" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb17-21"><a href="#cb17-21" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="va">self</span>.use_jit:</span>
<span id="cb17-22"><a href="#cb17-22" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Compile after any dtype change; TorchScript can reduce Python overhead at inference time</span></span>
<span id="cb17-23"><a href="#cb17-23" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.model <span class="op">=</span> torch.jit.script(<span class="va">self</span>.model)</span>
<span id="cb17-24"><a href="#cb17-24" aria-hidden="true" tabindex="-1"></a>            logger.info(<span class="st">"Model compiled with TorchScript"</span>)</span></code></pre></div></div>
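<p>Whether TorchScript or half precision actually lowers latency depends on the model, the hardware, and the batch size, so it is worth measuring before enabling either option. A minimal, standard-library-only timing harness might look like this; <code>benchmark</code> and its <code>predict_fn</code> argument are hypothetical names, with <code>predict_fn</code> standing in for a zero-argument call that runs one forward pass. When timing on a GPU, that callable should end with <code>torch.cuda.synchronize()</code>, because CUDA kernels launch asynchronously and would otherwise be under-counted.</p>

```python
import statistics
import time


def benchmark(predict_fn, warmup=5, iters=50):
    """Time predict_fn() and return latency statistics in milliseconds.

    predict_fn is any zero-argument callable (a placeholder for one forward
    pass). Warm-up calls are run first and discarded from the measurements.
    """
    for _ in range(warmup):
        predict_fn()

    timings_ms = []
    for _ in range(iters):
        start = time.perf_counter()  # monotonic, high-resolution clock
        predict_fn()
        timings_ms.append((time.perf_counter() - start) * 1000)

    timings_ms.sort()
    # Approximate 95th percentile by index into the sorted samples
    p95_index = min(len(timings_ms) - 1, int(0.95 * len(timings_ms)))
    return {
        "mean_ms": statistics.mean(timings_ms),
        "median_ms": statistics.median(timings_ms),
        "p95_ms": timings_ms[p95_index],
    }


# Usage with a dummy workload; replace the lambda with a real model call.
stats = benchmark(lambda: sum(i * i for i in range(10_000)), warmup=2, iters=20)
print({name: round(value, 3) for name, value in stats.items()})
```

<p>Running the harness once against the base handler and once against <code>OptimizedMobileNetV2Handler</code> with identical inputs gives a like-for-like comparison. The warm-up calls matter here: the first few invocations of a TorchScript module are typically slower while the JIT profiles and optimizes the graph.</p>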
</section>
<section id="docker-optimization" class="level3">
<h3 class="anchored" data-anchor-id="docker-optimization" id="docker-optimization">Docker Optimization</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb18"><pre class="sourceCode dockerfile code-with-copy"><code class="sourceCode dockerfile"><span id="cb18-1"><a href="#cb18-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Multi-stage build for smaller image</span></span>
<span id="cb18-2"><a href="#cb18-2" aria-hidden="true" tabindex="-1"></a><span class="kw">FROM</span> python:3.11-slim AS builder</span>
<span id="cb18-3"><a href="#cb18-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-4"><a href="#cb18-4" aria-hidden="true" tabindex="-1"></a><span class="kw">WORKDIR</span> /app</span>
<span id="cb18-5"><a href="#cb18-5" aria-hidden="true" tabindex="-1"></a><span class="kw">COPY</span> requirements.txt .</span>
<span id="cb18-6"><a href="#cb18-6" aria-hidden="true" tabindex="-1"></a><span class="kw">RUN</span> <span class="ex">pip</span> install <span class="at">--user</span> <span class="at">--no-cache-dir</span> <span class="at">-r</span> requirements.txt</span>
<span id="cb18-7"><a href="#cb18-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-8"><a href="#cb18-8" aria-hidden="true" tabindex="-1"></a><span class="kw">FROM</span> python:3.11-slim</span>
<span id="cb18-9"><a href="#cb18-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-10"><a href="#cb18-10" aria-hidden="true" tabindex="-1"></a><span class="kw">WORKDIR</span> /app</span>
<span id="cb18-11"><a href="#cb18-11" aria-hidden="true" tabindex="-1"></a><span class="kw">COPY</span> <span class="op">--from=builder</span> /root/.local /root/.local</span>
<span id="cb18-12"><a href="#cb18-12" aria-hidden="true" tabindex="-1"></a><span class="kw">COPY</span> app/ ./app/</span>
<span id="cb18-13"><a href="#cb18-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-14"><a href="#cb18-14" aria-hidden="true" tabindex="-1"></a><span class="co"># Make sure scripts in .local are usable</span></span>
<span id="cb18-15"><a href="#cb18-15" aria-hidden="true" tabindex="-1"></a><span class="kw">ENV</span> PATH=/root/.local/bin:$PATH</span>
<span id="cb18-16"><a href="#cb18-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb18-17"><a href="#cb18-17" aria-hidden="true" tabindex="-1"></a><span class="kw">EXPOSE</span> 8000</span>
<span id="cb18-18"><a href="#cb18-18" aria-hidden="true" tabindex="-1"></a><span class="kw">CMD</span> [<span class="st">"uvicorn"</span>, <span class="st">"app.main:app"</span>, <span class="st">"--host"</span>, <span class="st">"0.0.0.0"</span>, <span class="st">"--port"</span>, <span class="st">"8000"</span>]</span></code></pre></div></div>
</section>
</section>
<section id="monitoring-and-logging" class="level2">
<h2 class="anchored" data-anchor-id="monitoring-and-logging" id="monitoring-and-logging">6. Monitoring and Logging</h2>
<section id="enhanced-logging" class="level3">
<h3 class="anchored" data-anchor-id="enhanced-logging" id="enhanced-logging">Enhanced Logging</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb19"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb19-1"><a href="#cb19-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Add to main.py</span></span>
<span id="cb19-2"><a href="#cb19-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb19-3"><a href="#cb19-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> app.utils <span class="im">import</span> profiler</span>
<span id="cb19-4"><a href="#cb19-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-5"><a href="#cb19-5" aria-hidden="true" tabindex="-1"></a><span class="at">@app.middleware</span>(<span class="st">"http"</span>)</span>
<span id="cb19-6"><a href="#cb19-6" aria-hidden="true" tabindex="-1"></a><span class="cf">async</span> <span class="kw">def</span> log_requests(request, call_next):</span>
<span id="cb19-7"><a href="#cb19-7" aria-hidden="true" tabindex="-1"></a>    start_time <span class="op">=</span> time.perf_counter()  <span class="co"># Monotonic clock, suited to measuring durations</span></span>
<span id="cb19-8"><a href="#cb19-8" aria-hidden="true" tabindex="-1"></a>    response <span class="op">=</span> <span class="cf">await</span> call_next(request)</span>
<span id="cb19-9"><a href="#cb19-9" aria-hidden="true" tabindex="-1"></a>    process_time <span class="op">=</span> (time.perf_counter() <span class="op">-</span> start_time) <span class="op">*</span> <span class="dv">1000</span></span>
<span id="cb19-10"><a href="#cb19-10" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb19-11"><a href="#cb19-11" aria-hidden="true" tabindex="-1"></a>    logger.info(<span class="ss">f"</span><span class="sc">{</span>request<span class="sc">.</span>method<span class="sc">}</span><span class="ss"> </span><span class="sc">{</span>request<span class="sc">.</span>url<span class="sc">.</span>path<span class="sc">}</span><span class="ss"> - </span><span class="sc">{</span>process_time<span class="sc">:.2f}</span><span class="ss">ms"</span>)</span>
<span id="cb19-12"><a href="#cb19-12" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb19-13"><a href="#cb19-13" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> request.url.path <span class="op">==</span> <span class="st">"/predict"</span>:</span>
<span id="cb19-14"><a href="#cb19-14" aria-hidden="true" tabindex="-1"></a>        profiler.record_inference_time(process_time)  <span class="co"># End-to-end request latency, not model-only time</span></span>
<span id="cb19-15"><a href="#cb19-15" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb19-16"><a href="#cb19-16" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> response</span>
<span id="cb19-17"><a href="#cb19-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-18"><a href="#cb19-18" aria-hidden="true" tabindex="-1"></a><span class="at">@app.get</span>(<span class="st">"/stats"</span>)</span>
<span id="cb19-19"><a href="#cb19-19" aria-hidden="true" tabindex="-1"></a><span class="cf">async</span> <span class="kw">def</span> get_stats():</span>
<span id="cb19-20"><a href="#cb19-20" aria-hidden="true" tabindex="-1"></a>    <span class="co">"""Get performance statistics"""</span></span>
<span id="cb19-21"><a href="#cb19-21" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> profiler.get_stats()</span></code></pre></div></div>
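<p>The middleware relies on <code>app.utils.profiler</code>, whose implementation is not shown here. One possible shape for it, assuming only the two methods the code calls (<code>record_inference_time</code> and <code>get_stats</code>), is a bounded window of recent latencies; the class name <code>LatencyProfiler</code> and the statistics it reports are illustrative choices, not the guide's actual module.</p>

```python
import statistics
import threading
from collections import deque


class LatencyProfiler:
    """Hypothetical stand-in for app.utils.profiler.

    Keeps the most recent `max_samples` latencies (in ms) plus a running
    request count. The lock guards against concurrent calls from threadpool
    workers; on a single event loop it is merely cheap insurance.
    """

    def __init__(self, max_samples=1000):
        self._samples = deque(maxlen=max_samples)
        self._total_requests = 0
        self._lock = threading.Lock()

    def record_inference_time(self, ms):
        with self._lock:
            self._samples.append(ms)
            self._total_requests += 1

    def get_stats(self):
        with self._lock:
            samples = sorted(self._samples)
            total = self._total_requests
        if not samples:
            return {"total_requests": total, "window_size": 0}
        p95 = samples[min(len(samples) - 1, int(0.95 * len(samples)))]
        return {
            "total_requests": total,
            "window_size": len(samples),
            "avg_ms": round(statistics.mean(samples), 2),
            "p50_ms": round(statistics.median(samples), 2),
            "p95_ms": round(p95, 2),
            "max_ms": round(samples[-1], 2),
        }


# Module-level instance, so main.py can do `from app.utils import profiler`
profiler = LatencyProfiler()
```

<p>The bounded <code>deque</code> keeps memory constant on a long-running server, at the cost of reporting statistics over only the most recent requests rather than the full history.</p>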
</section>
</section>
<section id="cloud-deployment" class="level2">
<h2 class="anchored" data-anchor-id="cloud-deployment" id="cloud-deployment">7. Cloud Deployment</h2>
<section id="aws-ecs-task-definition" class="level3">
<h3 class="anchored" data-anchor-id="aws-ecs-task-definition" id="aws-ecs-task-definition">AWS ECS Task Definition</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb20"><pre class="sourceCode json code-with-copy"><code class="sourceCode json"><span id="cb20-1"><a href="#cb20-1" aria-hidden="true" tabindex="-1"></a><span class="fu">{</span></span>
<span id="cb20-2"><a href="#cb20-2" aria-hidden="true" tabindex="-1"></a>  <span class="dt">"family"</span><span class="fu">:</span> <span class="st">"mobilenetv2-pytorch-task"</span><span class="fu">,</span></span>
<span id="cb20-3"><a href="#cb20-3" aria-hidden="true" tabindex="-1"></a>  <span class="dt">"networkMode"</span><span class="fu">:</span> <span class="st">"awsvpc"</span><span class="fu">,</span></span>
<span id="cb20-4"><a href="#cb20-4" aria-hidden="true" tabindex="-1"></a>  <span class="dt">"requiresCompatibilities"</span><span class="fu">:</span> <span class="ot">[</span><span class="st">"FARGATE"</span><span class="ot">]</span><span class="fu">,</span></span>
<span id="cb20-5"><a href="#cb20-5" aria-hidden="true" tabindex="-1"></a>  <span class="dt">"cpu"</span><span class="fu">:</span> <span class="st">"1024"</span><span class="fu">,</span></span>
<span id="cb20-6"><a href="#cb20-6" aria-hidden="true" tabindex="-1"></a>  <span class="dt">"memory"</span><span class="fu">:</span> <span class="st">"2048"</span><span class="fu">,</span></span>
<span id="cb20-7"><a href="#cb20-7" aria-hidden="true" tabindex="-1"></a>  <span class="dt">"executionRoleArn"</span><span class="fu">:</span> <span class="st">"arn:aws:iam::your-account-id:role/ecsTaskExecutionRole"</span><span class="fu">,</span></span>
<span id="cb20-8"><a href="#cb20-8" aria-hidden="true" tabindex="-1"></a>  <span class="dt">"containerDefinitions"</span><span class="fu">:</span> <span class="ot">[</span></span>
<span id="cb20-9"><a href="#cb20-9" aria-hidden="true" tabindex="-1"></a>    <span class="fu">{</span></span>
<span id="cb20-10"><a href="#cb20-10" aria-hidden="true" tabindex="-1"></a>      <span class="dt">"name"</span><span class="fu">:</span> <span class="st">"mobilenetv2-pytorch-api"</span><span class="fu">,</span></span>
<span id="cb20-11"><a href="#cb20-11" aria-hidden="true" tabindex="-1"></a>      <span class="dt">"image"</span><span class="fu">:</span> <span class="st">"your-registry/mobilenetv2-pytorch-api:latest"</span><span class="fu">,</span></span>
<span id="cb20-12"><a href="#cb20-12" aria-hidden="true" tabindex="-1"></a>      <span class="dt">"portMappings"</span><span class="fu">:</span> <span class="ot">[</span></span>
<span id="cb20-13"><a href="#cb20-13" aria-hidden="true" tabindex="-1"></a>        <span class="fu">{</span></span>
<span id="cb20-14"><a href="#cb20-14" aria-hidden="true" tabindex="-1"></a>          <span class="dt">"containerPort"</span><span class="fu">:</span> <span class="dv">8000</span><span class="fu">,</span></span>
<span id="cb20-15"><a href="#cb20-15" aria-hidden="true" tabindex="-1"></a>          <span class="dt">"protocol"</span><span class="fu">:</span> <span class="st">"tcp"</span></span>
<span id="cb20-16"><a href="#cb20-16" aria-hidden="true" tabindex="-1"></a>        <span class="fu">}</span></span>
<span id="cb20-17"><a href="#cb20-17" aria-hidden="true" tabindex="-1"></a>      <span class="ot">]</span><span class="fu">,</span></span>
<span id="cb20-18"><a href="#cb20-18" aria-hidden="true" tabindex="-1"></a>      <span class="dt">"environment"</span><span class="fu">:</span> <span class="ot">[</span></span>
<span id="cb20-19"><a href="#cb20-19" aria-hidden="true" tabindex="-1"></a>        <span class="fu">{</span></span>
<span id="cb20-20"><a href="#cb20-20" aria-hidden="true" tabindex="-1"></a>          <span class="dt">"name"</span><span class="fu">:</span> <span class="st">"TORCH_HOME"</span><span class="fu">,</span></span>
<span id="cb20-21"><a href="#cb20-21" aria-hidden="true" tabindex="-1"></a>          <span class="dt">"value"</span><span class="fu">:</span> <span class="st">"/tmp/.torch"</span></span>
<span id="cb20-22"><a href="#cb20-22" aria-hidden="true" tabindex="-1"></a>        <span class="fu">}</span></span>
<span id="cb20-23"><a href="#cb20-23" aria-hidden="true" tabindex="-1"></a>      <span class="ot">]</span><span class="fu">,</span></span>
<span id="cb20-24"><a href="#cb20-24" aria-hidden="true" tabindex="-1"></a>      <span class="dt">"essential"</span><span class="fu">:</span> <span class="kw">true</span><span class="fu">,</span></span>
<span id="cb20-25"><a href="#cb20-25" aria-hidden="true" tabindex="-1"></a>      <span class="dt">"logConfiguration"</span><span class="fu">:</span> <span class="fu">{</span></span>
<span id="cb20-26"><a href="#cb20-26" aria-hidden="true" tabindex="-1"></a>        <span class="dt">"logDriver"</span><span class="fu">:</span> <span class="st">"awslogs"</span><span class="fu">,</span></span>
<span id="cb20-27"><a href="#cb20-27" aria-hidden="true" tabindex="-1"></a>        <span class="dt">"options"</span><span class="fu">:</span> <span class="fu">{</span></span>
<span id="cb20-28"><a href="#cb20-28" aria-hidden="true" tabindex="-1"></a>          <span class="dt">"awslogs-group"</span><span class="fu">:</span> <span class="st">"/ecs/mobilenetv2-pytorch"</span><span class="fu">,</span></span>
<span id="cb20-29"><a href="#cb20-29" aria-hidden="true" tabindex="-1"></a>          <span class="dt">"awslogs-region"</span><span class="fu">:</span> <span class="st">"us-east-1"</span><span class="fu">,</span></span>
<span id="cb20-30"><a href="#cb20-30" aria-hidden="true" tabindex="-1"></a>          <span class="dt">"awslogs-stream-prefix"</span><span class="fu">:</span> <span class="st">"ecs"</span></span>
<span id="cb20-31"><a href="#cb20-31" aria-hidden="true" tabindex="-1"></a>        <span class="fu">}</span></span>
<span id="cb20-32"><a href="#cb20-32" aria-hidden="true" tabindex="-1"></a>      <span class="fu">}</span></span>
<span id="cb20-33"><a href="#cb20-33" aria-hidden="true" tabindex="-1"></a>    <span class="fu">}</span></span>
<span id="cb20-34"><a href="#cb20-34" aria-hidden="true" tabindex="-1"></a>  <span class="ot">]</span></span>
<span id="cb20-35"><a href="#cb20-35" aria-hidden="true" tabindex="-1"></a><span class="fu">}</span></span></code></pre></div></div>
</section>
<section id="kubernetes-deployment" class="level3">
<h3 class="anchored" data-anchor-id="kubernetes-deployment" id="kubernetes-deployment">Kubernetes Deployment</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb21"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb21-1"><a href="#cb21-1" aria-hidden="true" tabindex="-1"></a><span class="fu">apiVersion</span><span class="kw">:</span><span class="at"> apps/v1</span></span>
<span id="cb21-2"><a href="#cb21-2" aria-hidden="true" tabindex="-1"></a><span class="fu">kind</span><span class="kw">:</span><span class="at"> Deployment</span></span>
<span id="cb21-3"><a href="#cb21-3" aria-hidden="true" tabindex="-1"></a><span class="fu">metadata</span><span class="kw">:</span></span>
<span id="cb21-4"><a href="#cb21-4" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">name</span><span class="kw">:</span><span class="at"> mobilenetv2-pytorch-api</span></span>
<span id="cb21-5"><a href="#cb21-5" aria-hidden="true" tabindex="-1"></a><span class="fu">spec</span><span class="kw">:</span></span>
<span id="cb21-6"><a href="#cb21-6" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">replicas</span><span class="kw">:</span><span class="at"> </span><span class="dv">3</span></span>
<span id="cb21-7"><a href="#cb21-7" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">selector</span><span class="kw">:</span></span>
<span id="cb21-8"><a href="#cb21-8" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">matchLabels</span><span class="kw">:</span></span>
<span id="cb21-9"><a href="#cb21-9" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">app</span><span class="kw">:</span><span class="at"> mobilenetv2-pytorch-api</span></span>
<span id="cb21-10"><a href="#cb21-10" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">template</span><span class="kw">:</span></span>
<span id="cb21-11"><a href="#cb21-11" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">metadata</span><span class="kw">:</span></span>
<span id="cb21-12"><a href="#cb21-12" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">labels</span><span class="kw">:</span></span>
<span id="cb21-13"><a href="#cb21-13" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">app</span><span class="kw">:</span><span class="at"> mobilenetv2-pytorch-api</span></span>
<span id="cb21-14"><a href="#cb21-14" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">spec</span><span class="kw">:</span></span>
<span id="cb21-15"><a href="#cb21-15" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">containers</span><span class="kw">:</span></span>
<span id="cb21-16"><a href="#cb21-16" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> mobilenetv2-pytorch-api</span></span>
<span id="cb21-17"><a href="#cb21-17" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">image</span><span class="kw">:</span><span class="at"> your-registry/mobilenetv2-pytorch-api:latest</span></span>
<span id="cb21-18"><a href="#cb21-18" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">ports</span><span class="kw">:</span></span>
<span id="cb21-19"><a href="#cb21-19" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="kw">-</span><span class="at"> </span><span class="fu">containerPort</span><span class="kw">:</span><span class="at"> </span><span class="dv">8000</span></span>
<span id="cb21-20"><a href="#cb21-20" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">env</span><span class="kw">:</span></span>
<span id="cb21-21"><a href="#cb21-21" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> TORCH_HOME</span></span>
<span id="cb21-22"><a href="#cb21-22" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="fu">value</span><span class="kw">:</span><span class="at"> /tmp/.torch</span></span>
<span id="cb21-23"><a href="#cb21-23" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">resources</span><span class="kw">:</span></span>
<span id="cb21-24"><a href="#cb21-24" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="fu">requests</span><span class="kw">:</span></span>
<span id="cb21-25"><a href="#cb21-25" aria-hidden="true" tabindex="-1"></a><span class="at">            </span><span class="fu">memory</span><span class="kw">:</span><span class="at"> </span><span class="st">"1Gi"</span></span>
<span id="cb21-26"><a href="#cb21-26" aria-hidden="true" tabindex="-1"></a><span class="at">            </span><span class="fu">cpu</span><span class="kw">:</span><span class="at"> </span><span class="st">"500m"</span></span>
<span id="cb21-27"><a href="#cb21-27" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="fu">limits</span><span class="kw">:</span></span>
<span id="cb21-28"><a href="#cb21-28" aria-hidden="true" tabindex="-1"></a><span class="at">            </span><span class="fu">memory</span><span class="kw">:</span><span class="at"> </span><span class="st">"2Gi"</span></span>
<span id="cb21-29"><a href="#cb21-29" aria-hidden="true" tabindex="-1"></a><span class="at">            </span><span class="fu">cpu</span><span class="kw">:</span><span class="at"> </span><span class="st">"1000m"</span></span>
<span id="cb21-30"><a href="#cb21-30" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">volumeMounts</span><span class="kw">:</span></span>
<span id="cb21-31"><a href="#cb21-31" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> torch-cache</span></span>
<span id="cb21-32"><a href="#cb21-32" aria-hidden="true" tabindex="-1"></a><span class="at">          </span><span class="fu">mountPath</span><span class="kw">:</span><span class="at"> /tmp/.torch</span></span>
<span id="cb21-33"><a href="#cb21-33" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">volumes</span><span class="kw">:</span></span>
<span id="cb21-34"><a href="#cb21-34" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="kw">-</span><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> torch-cache</span></span>
<span id="cb21-35"><a href="#cb21-35" aria-hidden="true" tabindex="-1"></a><span class="at">        </span><span class="fu">emptyDir</span><span class="kw">:</span><span class="at"> </span><span class="kw">{}</span></span>
<span id="cb21-36"><a href="#cb21-36" aria-hidden="true" tabindex="-1"></a><span class="pp">---</span></span>
<span id="cb21-37"><a href="#cb21-37" aria-hidden="true" tabindex="-1"></a><span class="fu">apiVersion</span><span class="kw">:</span><span class="at"> v1</span></span>
<span id="cb21-38"><a href="#cb21-38" aria-hidden="true" tabindex="-1"></a><span class="fu">kind</span><span class="kw">:</span><span class="at"> Service</span></span>
<span id="cb21-39"><a href="#cb21-39" aria-hidden="true" tabindex="-1"></a><span class="fu">metadata</span><span class="kw">:</span></span>
<span id="cb21-40"><a href="#cb21-40" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">name</span><span class="kw">:</span><span class="at"> mobilenetv2-pytorch-service</span></span>
<span id="cb21-41"><a href="#cb21-41" aria-hidden="true" tabindex="-1"></a><span class="fu">spec</span><span class="kw">:</span></span>
<span id="cb21-42"><a href="#cb21-42" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">selector</span><span class="kw">:</span></span>
<span id="cb21-43"><a href="#cb21-43" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="fu">app</span><span class="kw">:</span><span class="at"> mobilenetv2-pytorch-api</span></span>
<span id="cb21-44"><a href="#cb21-44" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">ports</span><span class="kw">:</span></span>
<span id="cb21-45"><a href="#cb21-45" aria-hidden="true" tabindex="-1"></a><span class="at">    </span><span class="kw">-</span><span class="at"> </span><span class="fu">protocol</span><span class="kw">:</span><span class="at"> TCP</span></span>
<span id="cb21-46"><a href="#cb21-46" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">port</span><span class="kw">:</span><span class="at"> </span><span class="dv">80</span></span>
<span id="cb21-47"><a href="#cb21-47" aria-hidden="true" tabindex="-1"></a><span class="at">      </span><span class="fu">targetPort</span><span class="kw">:</span><span class="at"> </span><span class="dv">8000</span></span>
<span id="cb21-48"><a href="#cb21-48" aria-hidden="true" tabindex="-1"></a><span class="at">  </span><span class="fu">type</span><span class="kw">:</span><span class="at"> LoadBalancer</span></span></code></pre></div></div>
<p>This PyTorch-based guide provides the same functionality as the TensorFlow version but builds on the PyTorch ecosystem: torchvision for the pre-trained model, torchvision transforms for preprocessing, and PyTorch tensors throughout the application.</p>


</section>
</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[PyTorch to PyTorch Lightning Migration Guide]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/model-training/pytorch-to-pytorchlightning/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/model-training/pytorch-to-pytorchlightning/</guid>
      <pubDate>Sun, 25 May 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>tutorial</category>
      <category>advanced</category>
      <content:encoded><![CDATA[






<section id="pytorch-to-pytorch-lightning-migration-guide" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/model-training/pytorch-to-pytorchlightning/pyt.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>PyTorch Lightning is a lightweight wrapper around PyTorch that eliminates boilerplate code while maintaining full control over your models. It provides a structured approach to organizing PyTorch code and includes built-in support for distributed training, logging, and experiment management.</p>
<section id="why-migrate" class="level3">
<h3 class="anchored" data-anchor-id="why-migrate" id="why-migrate">Why Migrate?</h3>
<ul>
<li><strong>Reduced Boilerplate</strong>: Lightning handles training loops, device management, and distributed training</li>
<li><strong>Better Organization</strong>: Standardized code structure improves readability and maintenance</li>
<li><strong>Built-in Features</strong>: Automatic logging, checkpointing, early stopping, and more</li>
<li><strong>Scalability</strong>: Easy multi-GPU and multi-node training</li>
<li><strong>Reproducibility</strong>: Better experiment tracking and configuration management</li>
</ul>
</section>
</section>
<section id="key-concepts" class="level2">
<h2 class="anchored" data-anchor-id="key-concepts" id="key-concepts">Key Concepts</h2>
<section id="lightningmodule" class="level3">
<h3 class="anchored" data-anchor-id="lightningmodule" id="lightningmodule">LightningModule</h3>
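<p>The core abstraction that wraps your PyTorch model. A <code>LightningModule</code> is still a <code>torch.nn.Module</code>, so existing layers and <code>forward</code> code carry over unchanged. It defines:</p>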
<ul>
<li>Model architecture (<code>__init__</code>)</li>
<li>Forward pass (<code>forward</code>)</li>
<li>Training step (<code>training_step</code>)</li>
<li>Validation step (<code>validation_step</code>)</li>
<li>Optimizer configuration (<code>configure_optimizers</code>)</li>
</ul>
</section>
<section id="trainer" class="level3">
<h3 class="anchored" data-anchor-id="trainer" id="trainer">Trainer</h3>
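<p>Runs the training, validation, and test loops, handles device placement, and exposes training options such as the number of epochs, the accelerator, and the number of devices as constructor arguments.</p>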
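<p>To make the division of labor concrete, here is a deliberately simplified, stdlib-only sketch of the loop structure a <code>Trainer</code> drives when you call <code>fit</code>. <code>ToyModule</code> and <code>toy_fit</code> are hypothetical names invented for illustration, not Lightning APIs. The real <code>Trainer</code> also runs backpropagation and optimizer steps, switches between train and eval mode, disables gradients during validation, handles logging and callbacks, and by default runs a short sanity-check pass over validation batches before training; none of that is modelled here.</p>

```python
class ToyModule:
    """Hypothetical stand-in for a LightningModule that records hook calls."""

    def __init__(self):
        self.calls = []

    def training_step(self, batch, batch_idx):
        self.calls.append(("training_step", batch_idx))
        return sum(batch)  # placeholder "loss"

    def validation_step(self, batch, batch_idx):
        self.calls.append(("validation_step", batch_idx))


def toy_fit(module, train_batches, val_batches, max_epochs=1):
    """Mimic only the outer loop structure of Trainer.fit."""
    for epoch in range(max_epochs):
        for batch_idx, batch in enumerate(train_batches):
            loss = module.training_step(batch, batch_idx)
            # A real Trainer would run backward() and the optimizer step on `loss` here.
        for batch_idx, batch in enumerate(val_batches):
            # A real Trainer would switch to eval mode and disable gradients here.
            module.validation_step(batch, batch_idx)


module = ToyModule()
toy_fit(module, train_batches=[[1, 2], [3, 4]], val_batches=[[5]], max_epochs=2)
print(module.calls[:3])
# [('training_step', 0), ('training_step', 1), ('validation_step', 0)]
```

<p>The module only supplies per-batch hooks; iteration over epochs and batches lives in the Trainer, which is what lets Lightning change devices or distribution strategy without touching your model code.</p>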
</section>
<section id="datamodule" class="level3">
<h3 class="anchored" data-anchor-id="datamodule" id="datamodule">DataModule</h3>
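<p>Encapsulates data loading logic: downloading data (<code>prepare_data</code>), building datasets and splits (<code>setup</code>), and creating dataloaders (<code>train_dataloader</code>, <code>val_dataloader</code>, <code>test_dataloader</code>).</p>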
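<p>A rough, stdlib-only sketch of the DataModule lifecycle follows. <code>ToyDataModule</code> is hypothetical and uses plain lists instead of <code>Dataset</code> and <code>DataLoader</code> objects; it only illustrates the split of responsibilities: <code>prepare_data</code> for one-off work such as downloads, <code>setup</code> for building datasets and splits, and the dataloader hooks for producing batches. The calls at the bottom approximate, in simplified form, the order in which <code>Trainer.fit</code> uses these hooks.</p>

```python
import random


class ToyDataModule:
    """Hypothetical stand-in for a LightningDataModule, using plain lists."""

    def __init__(self, batch_size=4):
        self.batch_size = batch_size
        self.calls = []  # records which hooks ran, in order

    def prepare_data(self):
        # One-off work such as downloading; no state is assigned here.
        self.calls.append("prepare_data")

    def setup(self, stage):
        # Build datasets and splits; a real DataModule runs this in every process.
        self.calls.append(f"setup:{stage}")
        if stage == "fit":
            data = list(range(10))
            random.Random(0).shuffle(data)  # seeded, so the split is reproducible
            self.train_set, self.val_set = data[:8], data[8:]

    def _batches(self, items):
        # Chunk a list into batches; the last batch may be smaller.
        return [items[i:i + self.batch_size] for i in range(0, len(items), self.batch_size)]

    def train_dataloader(self):
        self.calls.append("train_dataloader")
        return self._batches(self.train_set)

    def val_dataloader(self):
        self.calls.append("val_dataloader")
        return self._batches(self.val_set)


dm = ToyDataModule()
# Simplified calling order: prepare_data, then setup("fit"), then the dataloader hooks.
dm.prepare_data()
dm.setup("fit")
train_batches = dm.train_dataloader()
val_batches = dm.val_dataloader()
print(dm.calls)
# ['prepare_data', 'setup:fit', 'train_dataloader', 'val_dataloader']
print([len(b) for b in train_batches], [len(b) for b in val_batches])
# [4, 4] [2]
```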
</section>
</section>
<section id="basic-migration-steps" class="level2">
<h2 class="anchored" data-anchor-id="basic-migration-steps" id="basic-migration-steps">Basic Migration Steps</h2>
<section id="step-1-convert-model-to-lightningmodule" class="level3">
<h3 class="anchored" data-anchor-id="step-1-convert-model-to-lightningmodule" id="step-1-convert-model-to-lightningmodule">Step 1: Convert Model to LightningModule</h3>
<p><strong>Before (PyTorch):</strong></p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.optim <span class="im">as</span> optim</span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> SimpleModel(nn.Module):</span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, input_size, hidden_size, num_classes):</span>
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.fc1 <span class="op">=</span> nn.Linear(input_size, hidden_size)</span>
<span id="cb1-9"><a href="#cb1-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.relu <span class="op">=</span> nn.ReLU()</span>
<span id="cb1-10"><a href="#cb1-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.fc2 <span class="op">=</span> nn.Linear(hidden_size, num_classes)</span>
<span id="cb1-11"><a href="#cb1-11" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-12"><a href="#cb1-12" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb1-13"><a href="#cb1-13" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.fc1(x)</span>
<span id="cb1-14"><a href="#cb1-14" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.relu(x)</span>
<span id="cb1-15"><a href="#cb1-15" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.fc2(x)</span>
<span id="cb1-16"><a href="#cb1-16" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x</span></code></pre></div></div>
<p><strong>After (Lightning):</strong></p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> pytorch_lightning <span class="im">as</span> pl</span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn.functional <span class="im">as</span> F</span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> LightningModel(pl.LightningModule):</span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, input_size, hidden_size, num_classes, learning_rate<span class="op">=</span><span class="fl">1e-3</span>):</span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.save_hyperparameters()</span>
<span id="cb2-10"><a href="#cb2-10" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-11"><a href="#cb2-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.fc1 <span class="op">=</span> nn.Linear(input_size, hidden_size)</span>
<span id="cb2-12"><a href="#cb2-12" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.relu <span class="op">=</span> nn.ReLU()</span>
<span id="cb2-13"><a href="#cb2-13" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.fc2 <span class="op">=</span> nn.Linear(hidden_size, num_classes)</span>
<span id="cb2-14"><a href="#cb2-14" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-15"><a href="#cb2-15" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb2-16"><a href="#cb2-16" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.fc1(x)</span>
<span id="cb2-17"><a href="#cb2-17" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.relu(x)</span>
<span id="cb2-18"><a href="#cb2-18" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.fc2(x)</span>
<span id="cb2-19"><a href="#cb2-19" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x</span>
<span id="cb2-20"><a href="#cb2-20" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-21"><a href="#cb2-21" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> training_step(<span class="va">self</span>, batch, batch_idx):</span>
<span id="cb2-22"><a href="#cb2-22" aria-hidden="true" tabindex="-1"></a>        x, y <span class="op">=</span> batch</span>
<span id="cb2-23"><a href="#cb2-23" aria-hidden="true" tabindex="-1"></a>        y_hat <span class="op">=</span> <span class="va">self</span>(x)</span>
<span id="cb2-24"><a href="#cb2-24" aria-hidden="true" tabindex="-1"></a>        loss <span class="op">=</span> F.cross_entropy(y_hat, y)</span>
<span id="cb2-25"><a href="#cb2-25" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.log(<span class="st">'train_loss'</span>, loss)</span>
<span id="cb2-26"><a href="#cb2-26" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> loss</span>
<span id="cb2-27"><a href="#cb2-27" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-28"><a href="#cb2-28" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> validation_step(<span class="va">self</span>, batch, batch_idx):</span>
<span id="cb2-29"><a href="#cb2-29" aria-hidden="true" tabindex="-1"></a>        x, y <span class="op">=</span> batch</span>
<span id="cb2-30"><a href="#cb2-30" aria-hidden="true" tabindex="-1"></a>        y_hat <span class="op">=</span> <span class="va">self</span>(x)</span>
<span id="cb2-31"><a href="#cb2-31" aria-hidden="true" tabindex="-1"></a>        loss <span class="op">=</span> F.cross_entropy(y_hat, y)</span>
<span id="cb2-32"><a href="#cb2-32" aria-hidden="true" tabindex="-1"></a>        acc <span class="op">=</span> (y_hat.argmax(dim<span class="op">=</span><span class="dv">1</span>) <span class="op">==</span> y).<span class="bu">float</span>().mean()</span>
<span id="cb2-33"><a href="#cb2-33" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.log(<span class="st">'val_loss'</span>, loss)</span>
<span id="cb2-34"><a href="#cb2-34" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.log(<span class="st">'val_acc'</span>, acc)</span>
<span id="cb2-35"><a href="#cb2-35" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-36"><a href="#cb2-36" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> configure_optimizers(<span class="va">self</span>):</span>
<span id="cb2-37"><a href="#cb2-37" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> torch.optim.Adam(<span class="va">self</span>.parameters(), lr<span class="op">=</span><span class="va">self</span>.hparams.learning_rate)</span></code></pre></div></div>
</section>
<section id="step-2-replace-training-loop-with-trainer" class="level3">
<h3 class="anchored" data-anchor-id="step-2-replace-training-loop-with-trainer" id="step-2-replace-training-loop-with-trainer">Step 2: Replace Training Loop with Trainer</h3>
<p><strong>Before (PyTorch):</strong></p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Training loop (assumes num_epochs, train_loader, val_loader, and val_dataset are defined)</span></span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> SimpleModel(input_size<span class="op">=</span><span class="dv">784</span>, hidden_size<span class="op">=</span><span class="dv">128</span>, num_classes<span class="op">=</span><span class="dv">10</span>)</span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a>optimizer <span class="op">=</span> optim.Adam(model.parameters(), lr<span class="op">=</span><span class="fl">1e-3</span>)</span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a>criterion <span class="op">=</span> nn.CrossEntropyLoss()</span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a>device <span class="op">=</span> torch.device(<span class="st">'cuda'</span> <span class="cf">if</span> torch.cuda.is_available() <span class="cf">else</span> <span class="st">'cpu'</span>)</span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a>model.to(device)</span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(num_epochs):</span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a>    model.train()</span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> batch_idx, (data, target) <span class="kw">in</span> <span class="bu">enumerate</span>(train_loader):</span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a>        data, target <span class="op">=</span> data.to(device), target.to(device)</span>
<span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a>        optimizer.zero_grad()</span>
<span id="cb3-15"><a href="#cb3-15" aria-hidden="true" tabindex="-1"></a>        output <span class="op">=</span> model(data)</span>
<span id="cb3-16"><a href="#cb3-16" aria-hidden="true" tabindex="-1"></a>        loss <span class="op">=</span> criterion(output, target)</span>
<span id="cb3-17"><a href="#cb3-17" aria-hidden="true" tabindex="-1"></a>        loss.backward()</span>
<span id="cb3-18"><a href="#cb3-18" aria-hidden="true" tabindex="-1"></a>        optimizer.step()</span>
<span id="cb3-19"><a href="#cb3-19" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-20"><a href="#cb3-20" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> batch_idx <span class="op">%</span> <span class="dv">100</span> <span class="op">==</span> <span class="dv">0</span>:</span>
<span id="cb3-21"><a href="#cb3-21" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f'Epoch: </span><span class="sc">{</span>epoch<span class="sc">}</span><span class="ss">, Batch: </span><span class="sc">{</span>batch_idx<span class="sc">}</span><span class="ss">, Loss: </span><span class="sc">{</span>loss<span class="sc">.</span>item()<span class="sc">}</span><span class="ss">'</span>)</span>
<span id="cb3-22"><a href="#cb3-22" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-23"><a href="#cb3-23" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Validation</span></span>
<span id="cb3-24"><a href="#cb3-24" aria-hidden="true" tabindex="-1"></a>    model.<span class="bu">eval</span>()</span>
<span id="cb3-25"><a href="#cb3-25" aria-hidden="true" tabindex="-1"></a>    val_loss <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb3-26"><a href="#cb3-26" aria-hidden="true" tabindex="-1"></a>    correct <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb3-27"><a href="#cb3-27" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> torch.no_grad():</span>
<span id="cb3-28"><a href="#cb3-28" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> data, target <span class="kw">in</span> val_loader:</span>
<span id="cb3-29"><a href="#cb3-29" aria-hidden="true" tabindex="-1"></a>            data, target <span class="op">=</span> data.to(device), target.to(device)</span>
<span id="cb3-30"><a href="#cb3-30" aria-hidden="true" tabindex="-1"></a>            output <span class="op">=</span> model(data)</span>
<span id="cb3-31"><a href="#cb3-31" aria-hidden="true" tabindex="-1"></a>            val_loss <span class="op">+=</span> criterion(output, target).item()</span>
<span id="cb3-32"><a href="#cb3-32" aria-hidden="true" tabindex="-1"></a>            pred <span class="op">=</span> output.argmax(dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb3-33"><a href="#cb3-33" aria-hidden="true" tabindex="-1"></a>            correct <span class="op">+=</span> pred.eq(target).<span class="bu">sum</span>().item()</span>
<span id="cb3-34"><a href="#cb3-34" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-35"><a href="#cb3-35" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f'Validation Loss: </span><span class="sc">{</span>val_loss<span class="op">/</span><span class="bu">len</span>(val_loader)<span class="sc">}</span><span class="ss">, Accuracy: </span><span class="sc">{</span>correct<span class="op">/</span><span class="bu">len</span>(val_dataset)<span class="sc">}</span><span class="ss">'</span>)</span></code></pre></div></div>
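<p>One subtle detail in the manual loop above: <code>criterion</code> returns a per-batch mean, and the loop divides the sum of those means by <code>len(val_loader)</code>. That is an unweighted mean of batch means, which differs from the true per-sample mean whenever the last batch is smaller than the rest. The made-up numbers below illustrate the gap. Lightning's epoch-level reduction of <code>self.log</code> values is, as far as I know, a batch-size-weighted mean by default, so migrated validation metrics may differ slightly from the manual loop's output.</p>

```python
# Hypothetical per-sample losses for a validation set of 5 samples,
# loaded with batch_size=2, so the last batch holds only 1 sample.
per_sample = [0.2, 0.4, 0.6, 0.8, 1.0]
batch_size = 2
batches = [per_sample[i:i + batch_size] for i in range(0, len(per_sample), batch_size)]

# What the manual loop computes: the mean of per-batch mean losses
# (criterion returns a batch mean; the loop sums them and divides by len(val_loader)).
mean_of_batch_means = sum(sum(b) / len(b) for b in batches) / len(batches)

# The per-sample mean over the whole validation set.
per_sample_mean = sum(per_sample) / len(per_sample)

print(round(mean_of_batch_means, 4), round(per_sample_mean, 4))
# 0.6667 0.6
```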
<p><strong>After (Lightning):</strong></p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Training with Lightning (same train_loader and val_loader as above)</span></span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> LightningModel(input_size<span class="op">=</span><span class="dv">784</span>, hidden_size<span class="op">=</span><span class="dv">128</span>, num_classes<span class="op">=</span><span class="dv">10</span>)</span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a>trainer <span class="op">=</span> pl.Trainer(max_epochs<span class="op">=</span><span class="dv">10</span>, accelerator<span class="op">=</span><span class="st">'auto'</span>, devices<span class="op">=</span><span class="st">'auto'</span>)</span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a>trainer.fit(model, train_loader, val_loader)</span></code></pre></div></div>
</section>
</section>
<section id="code-examples" class="level2">
<h2 class="anchored" data-anchor-id="code-examples" id="code-examples">Code Examples</h2>
<section id="complete-mnist-example" class="level3">
<h3 class="anchored" data-anchor-id="complete-mnist-example" id="complete-mnist-example">Complete MNIST Example</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> pytorch_lightning <span class="im">as</span> pl</span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn.functional <span class="im">as</span> F</span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.utils.data <span class="im">import</span> DataLoader, random_split</span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torchvision <span class="im">import</span> datasets, transforms</span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MNISTDataModule(pl.LightningDataModule):</span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, data_dir: <span class="bu">str</span> <span class="op">=</span> <span class="st">"data/"</span>, batch_size: <span class="bu">int</span> <span class="op">=</span> <span class="dv">64</span>):</span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.data_dir <span class="op">=</span> data_dir</span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.batch_size <span class="op">=</span> batch_size</span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.transform <span class="op">=</span> transforms.Compose([</span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a>            transforms.ToTensor(),</span>
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a>            transforms.Normalize((<span class="fl">0.1307</span>,), (<span class="fl">0.3081</span>,))  <span class="co"># MNIST training-set mean and std</span></span>
<span id="cb5-16"><a href="#cb5-16" aria-hidden="true" tabindex="-1"></a>        ])</span>
<span id="cb5-17"><a href="#cb5-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-18"><a href="#cb5-18" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> prepare_data(<span class="va">self</span>):</span>
<span id="cb5-19"><a href="#cb5-19" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Download only: Lightning calls this from a single process, so do not assign state here</span></span>
<span id="cb5-20"><a href="#cb5-20" aria-hidden="true" tabindex="-1"></a>        datasets.MNIST(<span class="va">self</span>.data_dir, train<span class="op">=</span><span class="va">True</span>, download<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb5-21"><a href="#cb5-21" aria-hidden="true" tabindex="-1"></a>        datasets.MNIST(<span class="va">self</span>.data_dir, train<span class="op">=</span><span class="va">False</span>, download<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb5-22"><a href="#cb5-22" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-23"><a href="#cb5-23" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> setup(<span class="va">self</span>, stage: <span class="bu">str</span>):</span>
<span id="cb5-24"><a href="#cb5-24" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Called on every process: build datasets and assign the train/val splits</span></span>
<span id="cb5-25"><a href="#cb5-25" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> stage <span class="op">==</span> <span class="st">'fit'</span>:</span>
<span id="cb5-26"><a href="#cb5-26" aria-hidden="true" tabindex="-1"></a>            mnist_full <span class="op">=</span> datasets.MNIST(</span>
<span id="cb5-27"><a href="#cb5-27" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.data_dir, train<span class="op">=</span><span class="va">True</span>, transform<span class="op">=</span><span class="va">self</span>.transform</span>
<span id="cb5-28"><a href="#cb5-28" aria-hidden="true" tabindex="-1"></a>            )</span>
<span id="cb5-29"><a href="#cb5-29" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.mnist_train, <span class="va">self</span>.mnist_val <span class="op">=</span> random_split(</span>
<span id="cb5-30"><a href="#cb5-30" aria-hidden="true" tabindex="-1"></a>                mnist_full, [<span class="dv">55000</span>, <span class="dv">5000</span>]</span>
<span id="cb5-31"><a href="#cb5-31" aria-hidden="true" tabindex="-1"></a>            )</span>
<span id="cb5-32"><a href="#cb5-32" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-33"><a href="#cb5-33" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> stage <span class="op">==</span> <span class="st">'test'</span>:</span>
<span id="cb5-34"><a href="#cb5-34" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.mnist_test <span class="op">=</span> datasets.MNIST(</span>
<span id="cb5-35"><a href="#cb5-35" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.data_dir, train<span class="op">=</span><span class="va">False</span>, transform<span class="op">=</span><span class="va">self</span>.transform</span>
<span id="cb5-36"><a href="#cb5-36" aria-hidden="true" tabindex="-1"></a>            )</span>
<span id="cb5-37"><a href="#cb5-37" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-38"><a href="#cb5-38" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> train_dataloader(<span class="va">self</span>):</span>
<span id="cb5-39"><a href="#cb5-39" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> DataLoader(<span class="va">self</span>.mnist_train, batch_size<span class="op">=</span><span class="va">self</span>.batch_size, shuffle<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb5-40"><a href="#cb5-40" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-41"><a href="#cb5-41" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> val_dataloader(<span class="va">self</span>):</span>
<span id="cb5-42"><a href="#cb5-42" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> DataLoader(<span class="va">self</span>.mnist_val, batch_size<span class="op">=</span><span class="va">self</span>.batch_size)</span>
<span id="cb5-43"><a href="#cb5-43" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-44"><a href="#cb5-44" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> test_dataloader(<span class="va">self</span>):</span>
<span id="cb5-45"><a href="#cb5-45" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> DataLoader(<span class="va">self</span>.mnist_test, batch_size<span class="op">=</span><span class="va">self</span>.batch_size)</span>
<span id="cb5-46"><a href="#cb5-46" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-47"><a href="#cb5-47" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MNISTClassifier(pl.LightningModule):</span>
<span id="cb5-48"><a href="#cb5-48" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, learning_rate<span class="op">=</span><span class="fl">1e-3</span>):</span>
<span id="cb5-49"><a href="#cb5-49" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb5-50"><a href="#cb5-50" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.save_hyperparameters()</span>
<span id="cb5-51"><a href="#cb5-51" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-52"><a href="#cb5-52" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.conv1 <span class="op">=</span> nn.Conv2d(<span class="dv">1</span>, <span class="dv">32</span>, <span class="dv">3</span>, <span class="dv">1</span>)</span>
<span id="cb5-53"><a href="#cb5-53" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.conv2 <span class="op">=</span> nn.Conv2d(<span class="dv">32</span>, <span class="dv">64</span>, <span class="dv">3</span>, <span class="dv">1</span>)</span>
<span id="cb5-54"><a href="#cb5-54" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.dropout1 <span class="op">=</span> nn.Dropout(<span class="fl">0.25</span>)</span>
<span id="cb5-55"><a href="#cb5-55" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.dropout2 <span class="op">=</span> nn.Dropout(<span class="fl">0.5</span>)</span>
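<span id="cb5-56"><a href="#cb5-56" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.fc1 <span class="op">=</span> nn.Linear(<span class="dv">9216</span>, <span class="dv">128</span>)  <span class="co"># 9216 = 64 * 12 * 12: a 28x28 input becomes 26, then 24 after the 3x3 convs, then 12 after 2x2 pooling</span></span>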
<span id="cb5-57"><a href="#cb5-57" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.fc2 <span class="op">=</span> nn.Linear(<span class="dv">128</span>, <span class="dv">10</span>)</span>
<span id="cb5-58"><a href="#cb5-58" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-59"><a href="#cb5-59" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb5-60"><a href="#cb5-60" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.conv1(x)</span>
<span id="cb5-61"><a href="#cb5-61" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> F.relu(x)</span>
<span id="cb5-62"><a href="#cb5-62" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.conv2(x)</span>
<span id="cb5-63"><a href="#cb5-63" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> F.relu(x)</span>
<span id="cb5-64"><a href="#cb5-64" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> F.max_pool2d(x, <span class="dv">2</span>)</span>
<span id="cb5-65"><a href="#cb5-65" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.dropout1(x)</span>
<span id="cb5-66"><a href="#cb5-66" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> torch.flatten(x, <span class="dv">1</span>)</span>
<span id="cb5-67"><a href="#cb5-67" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.fc1(x)</span>
<span id="cb5-68"><a href="#cb5-68" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> F.relu(x)</span>
<span id="cb5-69"><a href="#cb5-69" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.dropout2(x)</span>
<span id="cb5-70"><a href="#cb5-70" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.fc2(x)</span>
<span id="cb5-71"><a href="#cb5-71" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> F.log_softmax(x, dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb5-72"><a href="#cb5-72" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-73"><a href="#cb5-73" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> training_step(<span class="va">self</span>, batch, batch_idx):</span>
<span id="cb5-74"><a href="#cb5-74" aria-hidden="true" tabindex="-1"></a>        x, y <span class="op">=</span> batch</span>
<span id="cb5-75"><a href="#cb5-75" aria-hidden="true" tabindex="-1"></a>        logits <span class="op">=</span> <span class="va">self</span>(x)</span>
<span id="cb5-76"><a href="#cb5-76" aria-hidden="true" tabindex="-1"></a>        loss <span class="op">=</span> F.nll_loss(logits, y)</span>
<span id="cb5-77"><a href="#cb5-77" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.log(<span class="st">"train_loss"</span>, loss, prog_bar<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb5-78"><a href="#cb5-78" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> loss</span>
<span id="cb5-79"><a href="#cb5-79" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-80"><a href="#cb5-80" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> validation_step(<span class="va">self</span>, batch, batch_idx):</span>
<span id="cb5-81"><a href="#cb5-81" aria-hidden="true" tabindex="-1"></a>        x, y <span class="op">=</span> batch</span>
<span id="cb5-82"><a href="#cb5-82" aria-hidden="true" tabindex="-1"></a>        logits <span class="op">=</span> <span class="va">self</span>(x)</span>
<span id="cb5-83"><a href="#cb5-83" aria-hidden="true" tabindex="-1"></a>        loss <span class="op">=</span> F.nll_loss(logits, y)</span>
<span id="cb5-84"><a href="#cb5-84" aria-hidden="true" tabindex="-1"></a>        preds <span class="op">=</span> torch.argmax(logits, dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb5-85"><a href="#cb5-85" aria-hidden="true" tabindex="-1"></a>        acc <span class="op">=</span> torch.<span class="bu">sum</span>(preds <span class="op">==</span> y).<span class="bu">float</span>() <span class="op">/</span> <span class="bu">len</span>(y)</span>
<span id="cb5-86"><a href="#cb5-86" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.log(<span class="st">"val_loss"</span>, loss, prog_bar<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb5-87"><a href="#cb5-87" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.log(<span class="st">"val_acc"</span>, acc, prog_bar<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb5-88"><a href="#cb5-88" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-89"><a href="#cb5-89" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> test_step(<span class="va">self</span>, batch, batch_idx):</span>
<span id="cb5-90"><a href="#cb5-90" aria-hidden="true" tabindex="-1"></a>        x, y <span class="op">=</span> batch</span>
<span id="cb5-91"><a href="#cb5-91" aria-hidden="true" tabindex="-1"></a>        logits <span class="op">=</span> <span class="va">self</span>(x)</span>
<span id="cb5-92"><a href="#cb5-92" aria-hidden="true" tabindex="-1"></a>        loss <span class="op">=</span> F.nll_loss(logits, y)</span>
<span id="cb5-93"><a href="#cb5-93" aria-hidden="true" tabindex="-1"></a>        preds <span class="op">=</span> torch.argmax(logits, dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb5-94"><a href="#cb5-94" aria-hidden="true" tabindex="-1"></a>        acc <span class="op">=</span> torch.<span class="bu">sum</span>(preds <span class="op">==</span> y).<span class="bu">float</span>() <span class="op">/</span> <span class="bu">len</span>(y)</span>
<span id="cb5-95"><a href="#cb5-95" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.log(<span class="st">"test_loss"</span>, loss)</span>
<span id="cb5-96"><a href="#cb5-96" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.log(<span class="st">"test_acc"</span>, acc)</span>
<span id="cb5-97"><a href="#cb5-97" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-98"><a href="#cb5-98" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> configure_optimizers(<span class="va">self</span>):</span>
<span id="cb5-99"><a href="#cb5-99" aria-hidden="true" tabindex="-1"></a>        optimizer <span class="op">=</span> torch.optim.Adam(<span class="va">self</span>.parameters(), lr<span class="op">=</span><span class="va">self</span>.hparams.learning_rate)</span>
<span id="cb5-100"><a href="#cb5-100" aria-hidden="true" tabindex="-1"></a>        scheduler <span class="op">=</span> torch.optim.lr_scheduler.StepLR(optimizer, step_size<span class="op">=</span><span class="dv">7</span>, gamma<span class="op">=</span><span class="fl">0.1</span>)</span>
<span id="cb5-101"><a href="#cb5-101" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {<span class="st">"optimizer"</span>: optimizer, <span class="st">"lr_scheduler"</span>: scheduler}</span>
<span id="cb5-102"><a href="#cb5-102" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-103"><a href="#cb5-103" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage</span></span>
<span id="cb5-104"><a href="#cb5-104" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb5-105"><a href="#cb5-105" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Initialize data module and model</span></span>
<span id="cb5-106"><a href="#cb5-106" aria-hidden="true" tabindex="-1"></a>    dm <span class="op">=</span> MNISTDataModule()</span>
<span id="cb5-107"><a href="#cb5-107" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> MNISTClassifier()</span>
<span id="cb5-108"><a href="#cb5-108" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-109"><a href="#cb5-109" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Initialize trainer</span></span>
<span id="cb5-110"><a href="#cb5-110" aria-hidden="true" tabindex="-1"></a>    trainer <span class="op">=</span> pl.Trainer(</span>
<span id="cb5-111"><a href="#cb5-111" aria-hidden="true" tabindex="-1"></a>        max_epochs<span class="op">=</span><span class="dv">5</span>,</span>
<span id="cb5-112"><a href="#cb5-112" aria-hidden="true" tabindex="-1"></a>        accelerator<span class="op">=</span><span class="st">'auto'</span>,</span>
<span id="cb5-113"><a href="#cb5-113" aria-hidden="true" tabindex="-1"></a>        devices<span class="op">=</span><span class="st">'auto'</span>,</span>
<span id="cb5-114"><a href="#cb5-114" aria-hidden="true" tabindex="-1"></a>        logger<span class="op">=</span>pl.loggers.TensorBoardLogger(<span class="st">'lightning_logs/'</span>),</span>
<span id="cb5-115"><a href="#cb5-115" aria-hidden="true" tabindex="-1"></a>        callbacks<span class="op">=</span>[</span>
<span id="cb5-116"><a href="#cb5-116" aria-hidden="true" tabindex="-1"></a>            pl.callbacks.EarlyStopping(monitor<span class="op">=</span><span class="st">'val_loss'</span>, patience<span class="op">=</span><span class="dv">3</span>),</span>
<span id="cb5-117"><a href="#cb5-117" aria-hidden="true" tabindex="-1"></a>            pl.callbacks.ModelCheckpoint(monitor<span class="op">=</span><span class="st">'val_loss'</span>, save_top_k<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb5-118"><a href="#cb5-118" aria-hidden="true" tabindex="-1"></a>        ]</span>
<span id="cb5-119"><a href="#cb5-119" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb5-120"><a href="#cb5-120" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-121"><a href="#cb5-121" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Train the model</span></span>
<span id="cb5-122"><a href="#cb5-122" aria-hidden="true" tabindex="-1"></a>    trainer.fit(model, dm)</span>
<span id="cb5-123"><a href="#cb5-123" aria-hidden="true" tabindex="-1"></a>    </span>
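<span id="cb5-124"><a href="#cb5-124" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Test the best checkpoint tracked by ModelCheckpoint (drop ckpt_path to test the final weights)</span></span>
<span id="cb5-125"><a href="#cb5-125" aria-hidden="true" tabindex="-1"></a>    trainer.test(model, datamodule<span class="op">=</span>dm, ckpt_path<span class="op">=</span><span class="st">"best"</span>)</span></code></pre></div></div>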
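<p>The <code>MNISTClassifier</code> above returns <code>F.log_softmax</code> from <code>forward</code> and trains with <code>F.nll_loss</code>. The tensor named <code>logits</code> in its steps therefore holds log-probabilities, not raw scores. This pairing gives the same loss as <code>F.cross_entropy</code> applied to raw scores, which is the form the next example uses. The plain-Python sketch below checks the identity on one made-up score vector. The <code>log_softmax</code>, <code>nll</code>, and <code>cross_entropy</code> functions here are illustrative helpers, not the PyTorch functions.</p>

```python
import math

def log_softmax(scores):
    # Subtract the max before exponentiating for numerical stability
    m = max(scores)
    log_sum = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_sum for s in scores]

def nll(log_probs, target):
    # Negative log-likelihood: negate the log-probability of the true class
    return -log_probs[target]

def cross_entropy(scores, target):
    # Cross-entropy on raw scores: logsumexp(scores) - scores[target]
    m = max(scores)
    log_sum = m + math.log(sum(math.exp(s - m) for s in scores))
    return log_sum - scores[target]

scores = [2.0, -1.0, 0.5]  # hypothetical raw scores for 3 classes
target = 0

a = nll(log_softmax(scores), target)
b = cross_entropy(scores, target)
print(a, b)  # the two losses match
```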
</section>
<section id="advanced-model-with-custom-metrics" class="level3">
<h3 class="anchored" data-anchor-id="advanced-model-with-custom-metrics" id="advanced-model-with-custom-metrics">Advanced Model with Custom Metrics</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torchmetrics</span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> AdvancedClassifier(pl.LightningModule):</span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, num_classes<span class="op">=</span><span class="dv">10</span>, learning_rate<span class="op">=</span><span class="fl">1e-3</span>):</span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.save_hyperparameters()</span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Model layers</span></span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> nn.Sequential(</span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a>            nn.Linear(<span class="dv">784</span>, <span class="dv">256</span>),</span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(),</span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a>            nn.Dropout(<span class="fl">0.2</span>),</span>
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a>            nn.Linear(<span class="dv">256</span>, <span class="dv">128</span>),</span>
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(),</span>
<span id="cb6-15"><a href="#cb6-15" aria-hidden="true" tabindex="-1"></a>            nn.Dropout(<span class="fl">0.2</span>),</span>
<span id="cb6-16"><a href="#cb6-16" aria-hidden="true" tabindex="-1"></a>            nn.Linear(<span class="dv">128</span>, num_classes)</span>
<span id="cb6-17"><a href="#cb6-17" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb6-18"><a href="#cb6-18" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-19"><a href="#cb6-19" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Metrics</span></span>
<span id="cb6-20"><a href="#cb6-20" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.train_accuracy <span class="op">=</span> torchmetrics.Accuracy(task<span class="op">=</span><span class="st">'multiclass'</span>, num_classes<span class="op">=</span>num_classes)</span>
<span id="cb6-21"><a href="#cb6-21" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.val_accuracy <span class="op">=</span> torchmetrics.Accuracy(task<span class="op">=</span><span class="st">'multiclass'</span>, num_classes<span class="op">=</span>num_classes)</span>
<span id="cb6-22"><a href="#cb6-22" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.val_f1 <span class="op">=</span> torchmetrics.F1Score(task<span class="op">=</span><span class="st">'multiclass'</span>, num_classes<span class="op">=</span>num_classes)</span>
<span id="cb6-23"><a href="#cb6-23" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-24"><a href="#cb6-24" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb6-25"><a href="#cb6-25" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.model(x.view(x.size(<span class="dv">0</span>), <span class="op">-</span><span class="dv">1</span>))</span>
<span id="cb6-26"><a href="#cb6-26" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-27"><a href="#cb6-27" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> training_step(<span class="va">self</span>, batch, batch_idx):</span>
<span id="cb6-28"><a href="#cb6-28" aria-hidden="true" tabindex="-1"></a>        x, y <span class="op">=</span> batch</span>
<span id="cb6-29"><a href="#cb6-29" aria-hidden="true" tabindex="-1"></a>        y_hat <span class="op">=</span> <span class="va">self</span>(x)</span>
<span id="cb6-30"><a href="#cb6-30" aria-hidden="true" tabindex="-1"></a>        loss <span class="op">=</span> F.cross_entropy(y_hat, y)</span>
<span id="cb6-31"><a href="#cb6-31" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-32"><a href="#cb6-32" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Log metrics</span></span>
<span id="cb6-33"><a href="#cb6-33" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.train_accuracy(y_hat, y)</span>
<span id="cb6-34"><a href="#cb6-34" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.log(<span class="st">'train_loss'</span>, loss, on_step<span class="op">=</span><span class="va">True</span>, on_epoch<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb6-35"><a href="#cb6-35" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.log(<span class="st">'train_acc'</span>, <span class="va">self</span>.train_accuracy, on_step<span class="op">=</span><span class="va">True</span>, on_epoch<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb6-36"><a href="#cb6-36" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-37"><a href="#cb6-37" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> loss</span>
<span id="cb6-38"><a href="#cb6-38" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-39"><a href="#cb6-39" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> validation_step(<span class="va">self</span>, batch, batch_idx):</span>
<span id="cb6-40"><a href="#cb6-40" aria-hidden="true" tabindex="-1"></a>        x, y <span class="op">=</span> batch</span>
<span id="cb6-41"><a href="#cb6-41" aria-hidden="true" tabindex="-1"></a>        y_hat <span class="op">=</span> <span class="va">self</span>(x)</span>
<span id="cb6-42"><a href="#cb6-42" aria-hidden="true" tabindex="-1"></a>        loss <span class="op">=</span> F.cross_entropy(y_hat, y)</span>
<span id="cb6-43"><a href="#cb6-43" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-44"><a href="#cb6-44" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Update metrics</span></span>
<span id="cb6-45"><a href="#cb6-45" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.val_accuracy(y_hat, y)</span>
<span id="cb6-46"><a href="#cb6-46" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.val_f1(y_hat, y)</span>
<span id="cb6-47"><a href="#cb6-47" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-48"><a href="#cb6-48" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Log metrics</span></span>
<span id="cb6-49"><a href="#cb6-49" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.log(<span class="st">'val_loss'</span>, loss, prog_bar<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb6-50"><a href="#cb6-50" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.log(<span class="st">'val_acc'</span>, <span class="va">self</span>.val_accuracy, prog_bar<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb6-51"><a href="#cb6-51" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.log(<span class="st">'val_f1'</span>, <span class="va">self</span>.val_f1, prog_bar<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb6-52"><a href="#cb6-52" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-53"><a href="#cb6-53" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> configure_optimizers(<span class="va">self</span>):</span>
<span id="cb6-54"><a href="#cb6-54" aria-hidden="true" tabindex="-1"></a>        optimizer <span class="op">=</span> torch.optim.AdamW(</span>
<span id="cb6-55"><a href="#cb6-55" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.parameters(), </span>
<span id="cb6-56"><a href="#cb6-56" aria-hidden="true" tabindex="-1"></a>            lr<span class="op">=</span><span class="va">self</span>.hparams.learning_rate,</span>
<span id="cb6-57"><a href="#cb6-57" aria-hidden="true" tabindex="-1"></a>            weight_decay<span class="op">=</span><span class="fl">1e-4</span></span>
<span id="cb6-58"><a href="#cb6-58" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb6-59"><a href="#cb6-59" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-60"><a href="#cb6-60" aria-hidden="true" tabindex="-1"></a>        scheduler <span class="op">=</span> torch.optim.lr_scheduler.CosineAnnealingLR(</span>
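<span id="cb6-61"><a href="#cb6-61" aria-hidden="true" tabindex="-1"></a>            optimizer, T_max<span class="op">=</span><span class="dv">100</span>  <span class="co"># T_max counts scheduler steps (epochs by default); match it to the planned epoch count</span></span>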
<span id="cb6-62"><a href="#cb6-62" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb6-63"><a href="#cb6-63" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-64"><a href="#cb6-64" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {</span>
<span id="cb6-65"><a href="#cb6-65" aria-hidden="true" tabindex="-1"></a>            <span class="st">"optimizer"</span>: optimizer,</span>
<span id="cb6-66"><a href="#cb6-66" aria-hidden="true" tabindex="-1"></a>            <span class="st">"lr_scheduler"</span>: {</span>
<span id="cb6-67"><a href="#cb6-67" aria-hidden="true" tabindex="-1"></a>                <span class="st">"scheduler"</span>: scheduler,</span>
<span id="cb6-68"><a href="#cb6-68" aria-hidden="true" tabindex="-1"></a>                <span class="st">"interval"</span>: <span class="st">"epoch"</span>,  <span class="co"># "monitor" is only required for ReduceLROnPlateau</span></span>
<span id="cb6-69"><a href="#cb6-69" aria-hidden="true" tabindex="-1"></a>                <span class="st">"frequency"</span>: <span class="dv">1</span></span>
<span id="cb6-70"><a href="#cb6-70" aria-hidden="true" tabindex="-1"></a>            }</span>
<span id="cb6-71"><a href="#cb6-71" aria-hidden="true" tabindex="-1"></a>        }</span></code></pre></div></div>
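<p>Metric objects such as <code>torchmetrics.Accuracy</code> and <code>F1Score</code> are stateful. Each call updates running statistics, and the epoch-level value is computed from the accumulated statistics rather than by averaging per-batch scores. The sketch below is a simplified plain-Python stand-in for that idea. The <code>RunningMulticlassStats</code> class is hypothetical and is not the torchmetrics implementation. It accumulates a confusion matrix and derives accuracy and macro-averaged F1 from it. torchmetrics chooses the averaging mode through its <code>average</code> argument, so set it explicitly if macro averaging is what you want.</p>

```python
class RunningMulticlassStats:
    """Hypothetical stand-in for a stateful multiclass metric."""

    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.reset()

    def reset(self):
        # confusion[t][p] counts samples of true class t predicted as class p
        self.confusion = [[0] * self.num_classes for _ in range(self.num_classes)]

    def update(self, preds, targets):
        # Accumulate counts for one batch
        for p, t in zip(preds, targets):
            self.confusion[t][p] += 1

    def accuracy(self):
        correct = sum(self.confusion[c][c] for c in range(self.num_classes))
        total = sum(sum(row) for row in self.confusion)
        return correct / total if total else 0.0

    def macro_f1(self):
        # Per-class F1 = 2TP / (2TP + FP + FN), then an unweighted mean over classes
        scores = []
        for c in range(self.num_classes):
            tp = self.confusion[c][c]
            fp = sum(self.confusion[t][c] for t in range(self.num_classes)) - tp
            fn = sum(self.confusion[c]) - tp
            denom = 2 * tp + fp + fn
            scores.append(2 * tp / denom if denom else 0.0)
        return sum(scores) / self.num_classes

stats = RunningMulticlassStats(num_classes=3)
stats.update(preds=[0, 1, 2, 2], targets=[0, 1, 1, 2])  # batch 1: 3 of 4 correct
stats.update(preds=[0, 0], targets=[0, 2])              # batch 2: 1 of 2 correct
print(stats.accuracy(), stats.macro_f1())
```

<p>In this toy data the accumulated accuracy is 4/6. A naive unweighted average of the two per-batch accuracies (0.75 and 0.5) would report 0.625, which is why accumulating state across batches matters.</p>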
</section>
</section>
<section id="advanced-features" class="level2">
<h2 class="anchored" data-anchor-id="advanced-features" id="advanced-features">Advanced Features</h2>
<section id="multiple-optimizers-and-schedulers" class="level3">
<h3 class="anchored" data-anchor-id="multiple-optimizers-and-schedulers" id="multiple-optimizers-and-schedulers">1. Multiple Optimizers and Schedulers</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> configure_optimizers(<span class="va">self</span>):</span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a>    <span class="co"># One optimizer per sub-network (generator and discriminator)</span></span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a>    opt_g <span class="op">=</span> torch.optim.Adam(<span class="va">self</span>.generator.parameters(), lr<span class="op">=</span><span class="fl">0.0002</span>)</span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a>    opt_d <span class="op">=</span> torch.optim.Adam(<span class="va">self</span>.discriminator.parameters(), lr<span class="op">=</span><span class="fl">0.0002</span>)</span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a>    <span class="co"># One scheduler per optimizer</span></span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a>    sch_g <span class="op">=</span> torch.optim.lr_scheduler.StepLR(opt_g, step_size<span class="op">=</span><span class="dv">50</span>, gamma<span class="op">=</span><span class="fl">0.5</span>)</span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a>    sch_d <span class="op">=</span> torch.optim.lr_scheduler.StepLR(opt_d, step_size<span class="op">=</span><span class="dv">50</span>, gamma<span class="op">=</span><span class="fl">0.5</span>)</span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Lightning 2.x supports multiple optimizers only with manual optimization (self.automatic_optimization = False)</span></span>
<span id="cb7-10"><a href="#cb7-10" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> [opt_g, opt_d], [sch_g, sch_d]</span></code></pre></div></div>
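<p><code>StepLR</code> multiplies the learning rate by <code>gamma</code> once every <code>step_size</code> scheduler steps. With Lightning's default of stepping the scheduler once per epoch, the rate after <code>e</code> epochs is <code>lr0 * gamma ** (e // step_size)</code>. The plain-Python sketch below evaluates that closed form for the two schedulers above. It also evaluates it for the <code>StepLR(step_size=7, gamma=0.1)</code> used in the MNIST example. The <code>step_lr</code> helper is illustrative only, and the MNIST base rate of <code>1e-3</code> is an assumption.</p>

```python
def step_lr(base_lr, epoch, step_size, gamma):
    # Closed form of StepLR: multiply by gamma once every step_size epochs
    return base_lr * gamma ** (epoch // step_size)

# GAN schedulers above: base lr 0.0002, halved every 50 epochs
gan_lrs = {e: step_lr(0.0002, e, step_size=50, gamma=0.5) for e in (0, 49, 50, 100)}
print(gan_lrs)

# MNIST example: StepLR(step_size=7, gamma=0.1) with max_epochs=5
# (the base lr of 1e-3 is an assumption for illustration)
mnist_lrs = [step_lr(1e-3, e, step_size=7, gamma=0.1) for e in range(5)]
print(mnist_lrs)  # the decay never triggers within 5 epochs
```

<p>If the MNIST run is meant to decay its learning rate, it needs a <code>step_size</code> smaller than <code>max_epochs</code>.</p>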
</section>
<section id="custom-callbacks" class="level3">
<h3 class="anchored" data-anchor-id="custom-callbacks" id="custom-callbacks">2. Custom Callbacks</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb8"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> CustomCallback(pl.Callback):</span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> on_train_epoch_end(<span class="va">self</span>, trainer, pl_module):</span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Custom logic at end of each epoch</span></span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> (trainer.current_epoch <span class="op">+</span> <span class="dv">1</span>) <span class="op">%</span> <span class="dv">10</span> <span class="op">==</span> <span class="dv">0</span>:  <span class="co"># current_epoch is zero-based in this hook</span></span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f"Completed </span><span class="sc">{</span>trainer<span class="sc">.</span>current_epoch <span class="op">+</span> <span class="dv">1</span><span class="sc">}</span><span class="ss"> epochs"</span>)</span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> on_validation_end(<span class="va">self</span>, trainer, pl_module):</span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Custom validation logic</span></span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a>        val_loss <span class="op">=</span> trainer.callback_metrics.get(<span class="st">'val_loss'</span>)</span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> val_loss <span class="kw">is</span> <span class="kw">not</span> <span class="va">None</span> <span class="kw">and</span> val_loss <span class="op">&lt;</span> <span class="fl">0.1</span>:</span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="st">"Excellent validation performance!"</span>)</span>
<span id="cb8-12"><a href="#cb8-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-13"><a href="#cb8-13" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage</span></span>
<span id="cb8-14"><a href="#cb8-14" aria-hidden="true" tabindex="-1"></a>trainer <span class="op">=</span> pl.Trainer(</span>
<span id="cb8-15"><a href="#cb8-15" aria-hidden="true" tabindex="-1"></a>    callbacks<span class="op">=</span>[CustomCallback(), pl.callbacks.EarlyStopping(monitor<span class="op">=</span><span class="st">'val_loss'</span>)]</span>
<span id="cb8-16"><a href="#cb8-16" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
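<p>Conceptually, the <code>Trainer</code> calls each registered callback's hook methods at fixed points in its loop and passes itself and the module to each call. The toy loop below is a hypothetical, heavily simplified illustration of that dispatch pattern. It is not Lightning's implementation, and <code>ToyTrainer</code> and <code>EpochRecorder</code> are made-up names. It assumes, as we understand Lightning's behavior, that <code>current_epoch</code> is only advanced after the epoch-end hooks run. Under that assumption the hook sees a zero-based index and needs <code>+ 1</code> to count completed epochs.</p>

```python
class ToyCallback:
    # Base class: hooks do nothing unless overridden
    def on_train_epoch_end(self, trainer, module):
        pass

class EpochRecorder(ToyCallback):
    def __init__(self, every):
        self.every = every
        self.reported = []

    def on_train_epoch_end(self, trainer, module):
        # current_epoch is the zero-based index of the epoch that just finished
        done = trainer.current_epoch + 1
        if done % self.every == 0:
            self.reported.append(done)

class ToyTrainer:
    def __init__(self, max_epochs, callbacks):
        self.max_epochs = max_epochs
        self.callbacks = callbacks
        self.current_epoch = 0

    def fit(self, module=None):
        for epoch in range(self.max_epochs):
            self.current_epoch = epoch
            # ... training batches would run here ...
            for cb in self.callbacks:
                cb.on_train_epoch_end(self, module)

recorder = EpochRecorder(every=10)
ToyTrainer(max_epochs=25, callbacks=[recorder]).fit()
print(recorder.reported)  # [10, 20]
```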
</section>
<section id="manual-optimization" class="level3">
<h3 class="anchored" data-anchor-id="manual-optimization" id="manual-optimization">3. Manual Optimization</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb9"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> ManualOptimizationModel(pl.LightningModule):</span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>):</span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.automatic_optimization <span class="op">=</span> <span class="va">False</span></span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a>        <span class="co"># ... model definition (configure_optimizers must also be defined for self.optimizers() to work)</span></span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> training_step(<span class="va">self</span>, batch, batch_idx):</span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a>        opt <span class="op">=</span> <span class="va">self</span>.optimizers()</span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Manual optimization</span></span>
<span id="cb9-11"><a href="#cb9-11" aria-hidden="true" tabindex="-1"></a>        opt.zero_grad()</span>
<span id="cb9-12"><a href="#cb9-12" aria-hidden="true" tabindex="-1"></a>        loss <span class="op">=</span> <span class="va">self</span>.compute_loss(batch)  <span class="co"># compute_loss: your own helper, not a Lightning method</span></span>
<span id="cb9-13"><a href="#cb9-13" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.manual_backward(loss)</span>
<span id="cb9-14"><a href="#cb9-14" aria-hidden="true" tabindex="-1"></a>        opt.step()</span>
<span id="cb9-15"><a href="#cb9-15" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-16"><a href="#cb9-16" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.log(<span class="st">'train_loss'</span>, loss)</span>
<span id="cb9-17"><a href="#cb9-17" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> loss</span></code></pre></div></div>
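<p>With <code>automatic_optimization = False</code>, the three calls in <code>training_step</code> follow the classic update cycle: clear old gradients, compute new ones, then apply the update. The plain-Python sketch below runs that cycle by hand on a single scalar parameter that minimizes <code>(w - 3) ** 2</code>. <code>ScalarSGD</code> is a hypothetical stand-in for a torch optimizer. The gradient is written analytically at the point where Lightning code would call <code>self.manual_backward(loss)</code>.</p>

```python
class ScalarSGD:
    """Hypothetical toy optimizer holding one scalar parameter."""

    def __init__(self, w, lr):
        self.w = w
        self.lr = lr
        self.grad = 0.0

    def zero_grad(self):
        self.grad = 0.0

    def step(self):
        self.w -= self.lr * self.grad

def loss_and_grad(w):
    # loss = (w - 3)^2, so d(loss)/dw = 2 * (w - 3)
    return (w - 3.0) ** 2, 2.0 * (w - 3.0)

opt = ScalarSGD(w=0.0, lr=0.1)
for _ in range(100):
    opt.zero_grad()                    # mirrors opt.zero_grad()
    loss, grad = loss_and_grad(opt.w)
    opt.grad += grad                   # mirrors self.manual_backward(loss); gradients add up until zeroed
    opt.step()                         # mirrors opt.step()
print(round(opt.w, 4))  # 3.0
```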
</section>
</section>
<section id="best-practices" class="level2">
<h2 class="anchored" data-anchor-id="best-practices" id="best-practices">Best Practices</h2>
<section id="hyperparameter-management" class="level3">
<h3 class="anchored" data-anchor-id="hyperparameter-management" id="hyperparameter-management">1. Hyperparameter Management</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb10"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> ConfigurableModel(pl.LightningModule):</span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, <span class="op">**</span>kwargs):</span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Save all constructor kwargs (this model expects input_size, hidden_size, num_classes)</span></span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.save_hyperparameters()</span>
<span id="cb10-6"><a href="#cb10-6" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb10-7"><a href="#cb10-7" aria-hidden="true" tabindex="-1"></a>        <span class="co"># _build_model reads them back through self.hparams</span></span>
<span id="cb10-8"><a href="#cb10-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> <span class="va">self</span>._build_model()</span>
<span id="cb10-9"><a href="#cb10-9" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb10-10"><a href="#cb10-10" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _build_model(<span class="va">self</span>):</span>
<span id="cb10-11"><a href="#cb10-11" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> nn.Sequential(</span>
<span id="cb10-12"><a href="#cb10-12" aria-hidden="true" tabindex="-1"></a>            nn.Linear(<span class="va">self</span>.hparams.input_size, <span class="va">self</span>.hparams.hidden_size),</span>
<span id="cb10-13"><a href="#cb10-13" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(),</span>
<span id="cb10-14"><a href="#cb10-14" aria-hidden="true" tabindex="-1"></a>            nn.Linear(<span class="va">self</span>.hparams.hidden_size, <span class="va">self</span>.hparams.num_classes)</span>
<span id="cb10-15"><a href="#cb10-15" aria-hidden="true" tabindex="-1"></a>        )</span></code></pre></div></div>
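<p><code>save_hyperparameters()</code> records the constructor arguments, and <code>self.hparams</code> then exposes them by attribute name. For example, <code>ConfigurableModel(input_size=784, hidden_size=128, num_classes=10)</code> would let <code>_build_model</code> read <code>self.hparams.input_size</code>. The drawback of <code>**kwargs</code> is that a missing or misspelled key only surfaces when it is first read. The sketch below is a hypothetical plain-Python stand-in that shows attribute-style access and a clearer error for a missing key. <code>HParams</code> is not Lightning's class.</p>

```python
class HParams(dict):
    """Hypothetical dict with attribute access, similar in spirit to self.hparams."""

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(
                f"missing hyperparameter {name!r}; pass it to the constructor"
            ) from None

hparams = HParams(input_size=784, hidden_size=128, num_classes=10)
print(hparams.input_size, hparams["hidden_size"])

try:
    hparams.hiden_size  # typo: only fails when first accessed
except AttributeError as err:
    print(err)
```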
</section>
<section id="proper-logging" class="level3">
<h3 class="anchored" data-anchor-id="proper-logging" id="proper-logging">2. Proper Logging</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb11"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> training_step(<span class="va">self</span>, batch, batch_idx):</span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a>    loss <span class="op">=</span> <span class="va">self</span>.compute_loss(batch)  <span class="co"># compute_loss: your own helper</span></span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Log per step and per epoch (recorded as train_loss_step and train_loss_epoch)</span></span>
<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a>    <span class="va">self</span>.log(<span class="st">'train_loss'</span>, loss, on_step<span class="op">=</span><span class="va">True</span>, on_epoch<span class="op">=</span><span class="va">True</span>, prog_bar<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Log the current learning rate (pl.callbacks.LearningRateMonitor can also do this automatically)</span></span>
<span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a>    <span class="va">self</span>.log(<span class="st">'lr'</span>, <span class="va">self</span>.trainer.optimizers[<span class="dv">0</span>].param_groups[<span class="dv">0</span>][<span class="st">'lr'</span>])</span>
<span id="cb11-9"><a href="#cb11-9" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-10"><a href="#cb11-10" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> loss</span></code></pre></div></div>
</section>
<section id="model-organization" class="level3">
<h3 class="anchored" data-anchor-id="model-organization" id="model-organization">3. Model Organization</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb12"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><a href="#cb12-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torchvision</span>
<span id="cb12-2"><a href="#cb12-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-3"><a href="#cb12-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Separate model definition from Lightning logic</span></span>
<span id="cb12-4"><a href="#cb12-4" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> ResNetBackbone(nn.Module):</span>
<span id="cb12-5"><a href="#cb12-5" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, num_classes):</span>
<span id="cb12-6"><a href="#cb12-6" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb12-7"><a href="#cb12-7" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Model architecture here (torchvision's ResNet-18 used as an example)</span></span>
<span id="cb12-8"><a href="#cb12-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.net <span class="op">=</span> torchvision.models.resnet18(num_classes<span class="op">=</span>num_classes)</span>
<span id="cb12-9"><a href="#cb12-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-10"><a href="#cb12-10" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb12-11"><a href="#cb12-11" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.net(x)</span>
<span id="cb12-12"><a href="#cb12-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-13"><a href="#cb12-13" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> ResNetLightning(pl.LightningModule):</span>
<span id="cb12-14"><a href="#cb12-14" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, num_classes, learning_rate<span class="op">=</span><span class="fl">1e-3</span>):</span>
<span id="cb12-15"><a href="#cb12-15" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb12-16"><a href="#cb12-16" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.save_hyperparameters()</span>
<span id="cb12-17"><a href="#cb12-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-18"><a href="#cb12-18" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Use separate model class</span></span>
<span id="cb12-19"><a href="#cb12-19" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> ResNetBackbone(num_classes)</span>
<span id="cb12-20"><a href="#cb12-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-21"><a href="#cb12-21" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb12-22"><a href="#cb12-22" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.model(x)</span>
<span id="cb12-23"><a href="#cb12-23" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-24"><a href="#cb12-24" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Training logic here...</span></span></code></pre></div></div>
</section>
<section id="testing-and-validation" class="level3">
<h3 class="anchored" data-anchor-id="testing-and-validation" id="testing-and-validation">4. Testing and Validation</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb13"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> validation_step(<span class="va">self</span>, batch, batch_idx):</span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Always include validation metrics</span></span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a>    loss, acc <span class="op">=</span> <span class="va">self</span>._shared_eval_step(batch, batch_idx)</span>
<span id="cb13-4"><a href="#cb13-4" aria-hidden="true" tabindex="-1"></a>    <span class="va">self</span>.log_dict({<span class="st">'val_loss'</span>: loss, <span class="st">'val_acc'</span>: acc}, prog_bar<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb13-5"><a href="#cb13-5" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-6"><a href="#cb13-6" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> test_step(<span class="va">self</span>, batch, batch_idx):</span>
<span id="cb13-7"><a href="#cb13-7" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Comprehensive test metrics</span></span>
<span id="cb13-8"><a href="#cb13-8" aria-hidden="true" tabindex="-1"></a>    loss, acc <span class="op">=</span> <span class="va">self</span>._shared_eval_step(batch, batch_idx)</span>
<span id="cb13-9"><a href="#cb13-9" aria-hidden="true" tabindex="-1"></a>    <span class="va">self</span>.log_dict({<span class="st">'test_loss'</span>: loss, <span class="st">'test_acc'</span>: acc})</span>
<span id="cb13-10"><a href="#cb13-10" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-11"><a href="#cb13-11" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> _shared_eval_step(<span class="va">self</span>, batch, batch_idx):</span>
<span id="cb13-12"><a href="#cb13-12" aria-hidden="true" tabindex="-1"></a>    x, y <span class="op">=</span> batch</span>
<span id="cb13-13"><a href="#cb13-13" aria-hidden="true" tabindex="-1"></a>    y_hat <span class="op">=</span> <span class="va">self</span>(x)</span>
<span id="cb13-14"><a href="#cb13-14" aria-hidden="true" tabindex="-1"></a>    loss <span class="op">=</span> F.cross_entropy(y_hat, y)</span>
<span id="cb13-15"><a href="#cb13-15" aria-hidden="true" tabindex="-1"></a>    acc <span class="op">=</span> (y_hat.argmax(dim<span class="op">=</span><span class="dv">1</span>) <span class="op">==</span> y).<span class="bu">float</span>().mean()</span>
<span id="cb13-16"><a href="#cb13-16" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> loss, acc</span></code></pre></div></div>
</section>
</section>
<section id="common-pitfalls" class="level2">
<h2 class="anchored" data-anchor-id="common-pitfalls" id="common-pitfalls">Common Pitfalls</h2>
<section id="device-management" class="level3">
<h3 class="anchored" data-anchor-id="device-management" id="device-management">1. Device Management</h3>
<p><strong>Wrong:</strong></p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb14"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb14-1"><a href="#cb14-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> training_step(<span class="va">self</span>, batch, batch_idx):</span>
<span id="cb14-2"><a href="#cb14-2" aria-hidden="true" tabindex="-1"></a>    x, y <span class="op">=</span> batch</span>
<span id="cb14-3"><a href="#cb14-3" aria-hidden="true" tabindex="-1"></a>    x <span class="op">=</span> x.cuda()  <span class="co"># Don't manually move to device</span></span>
<span id="cb14-4"><a href="#cb14-4" aria-hidden="true" tabindex="-1"></a>    y <span class="op">=</span> y.cuda()</span>
<span id="cb14-5"><a href="#cb14-5" aria-hidden="true" tabindex="-1"></a>    <span class="co"># ...</span></span></code></pre></div></div>
<p><strong>Correct:</strong></p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb15"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb15-1"><a href="#cb15-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> training_step(<span class="va">self</span>, batch, batch_idx):</span>
<span id="cb15-2"><a href="#cb15-2" aria-hidden="true" tabindex="-1"></a>    x, y <span class="op">=</span> batch  <span class="co"># Lightning handles device placement</span></span>
<span id="cb15-3"><a href="#cb15-3" aria-hidden="true" tabindex="-1"></a>    <span class="co"># ...</span></span></code></pre></div></div>
</section>
<section id="gradient-accumulation" class="level3">
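<h3 class="anchored" data-anchor-id="gradient-accumulation" id="gradient-accumulation">2. Gradient Accumulation</h3>
<p>With automatic optimization, Lightning calls <code>backward()</code> on the loss returned by <code>training_step</code> and applies gradient accumulation, mixed precision, and gradient clipping around that call. Calling <code>loss.backward()</code> yourself bypasses this machinery and typically fails with a <code>RuntimeError</code> once Lightning tries to backpropagate through the same, already freed, graph a second time.</p>
<p><strong>Wrong:</strong></p>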
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb16"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb16-1"><a href="#cb16-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> training_step(<span class="va">self</span>, batch, batch_idx):</span>
<span id="cb16-2"><a href="#cb16-2" aria-hidden="true" tabindex="-1"></a>    loss <span class="op">=</span> <span class="va">self</span>.compute_loss(batch)</span>
<span id="cb16-3"><a href="#cb16-3" aria-hidden="true" tabindex="-1"></a>    loss.backward()  <span class="co"># Don't call backward manually</span></span>
<span id="cb16-4"><a href="#cb16-4" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> loss</span></code></pre></div></div>
<p><strong>Correct:</strong></p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb17"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb17-1"><a href="#cb17-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> training_step(<span class="va">self</span>, batch, batch_idx):</span>
<span id="cb17-2"><a href="#cb17-2" aria-hidden="true" tabindex="-1"></a>    loss <span class="op">=</span> <span class="va">self</span>.compute_loss(batch)</span>
<span id="cb17-3"><a href="#cb17-3" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> loss  <span class="co"># Lightning handles backward pass</span></span></code></pre></div></div>
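<p>Gradient accumulation rests on a simple identity: summing the gradients of each micro-batch loss scaled by <span class="math inline">\(1/k\)</span> over <span class="math inline">\(k\)</span> equal-sized micro-batches gives the gradient of the full-batch mean loss. As we understand it, this is what Lightning does when you set <code>Trainer(accumulate_grad_batches=k)</code>: it scales each loss by <span class="math inline">\(1/k\)</span>, calls backward for you, and steps the optimizer every <span class="math inline">\(k\)</span> batches. The plain-Python sketch below checks the identity for a hypothetical one-parameter model <span class="math inline">\(\hat{y} = wx\)</span> with a mean-squared-error loss; the data and helper names are illustrative only.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">def mse_grad(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, xs, ys, k):
    """Sum the gradients of (micro-batch loss / k) over k equal micro-batches."""
    if len(xs) % k:
        raise ValueError("batch size must be divisible by k")
    size = len(xs) // k
    total = 0.0
    for i in range(k):
        chunk = slice(i * size, (i + 1) * size)
        total += mse_grad(w, xs[chunk], ys[chunk]) / k
    return total


# Hypothetical batch of 4 examples, split into k = 2 micro-batches of 2
xs, ys, w = [1.0, 2.0, 3.0, 4.0], [2.0, 4.1, 5.9, 8.2], 0.5
print(mse_grad(w, xs, ys))               # full-batch gradient
print(accumulated_grad(w, xs, ys, k=2))  # same value, up to float rounding
</code></pre></div>
<p>The identity holds for equal-sized micro-batches; if the last micro-batch is smaller, the scaled sum weights its examples slightly differently from a true full-batch mean.</p>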
</section>
<section id="metric-computation" class="level3">
<h3 class="anchored" data-anchor-id="metric-computation" id="metric-computation">3. Metric Computation</h3>
<p><strong>Wrong:</strong></p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb18"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb18-1"><a href="#cb18-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> validation_step(<span class="va">self</span>, batch, batch_idx):</span>
<span id="cb18-2"><a href="#cb18-2" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Computing metrics inside step leads to incorrect averages</span></span>
<span id="cb18-3"><a href="#cb18-3" aria-hidden="true" tabindex="-1"></a>    acc <span class="op">=</span> compute_accuracy(batch)</span>
<span id="cb18-4"><a href="#cb18-4" aria-hidden="true" tabindex="-1"></a>    <span class="va">self</span>.log(<span class="st">'val_acc'</span>, acc.mean())</span></code></pre></div></div>
<p><strong>Correct:</strong></p>
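<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb19"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb19-1"><a href="#cb19-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torchmetrics</span>
<span id="cb19-2"><a href="#cb19-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-3"><a href="#cb19-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, num_classes):</span>
<span id="cb19-4"><a href="#cb19-4" aria-hidden="true" tabindex="-1"></a>    <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb19-5"><a href="#cb19-5" aria-hidden="true" tabindex="-1"></a>    <span class="co"># torchmetrics 0.11+ requires an explicit task argument</span></span>
<span id="cb19-6"><a href="#cb19-6" aria-hidden="true" tabindex="-1"></a>    <span class="va">self</span>.val_acc <span class="op">=</span> torchmetrics.Accuracy(task<span class="op">=</span><span class="st">'multiclass'</span>, num_classes<span class="op">=</span>num_classes)</span>
<span id="cb19-7"><a href="#cb19-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-8"><a href="#cb19-8" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> validation_step(<span class="va">self</span>, batch, batch_idx):</span>
<span id="cb19-9"><a href="#cb19-9" aria-hidden="true" tabindex="-1"></a>    x, y <span class="op">=</span> batch</span>
<span id="cb19-10"><a href="#cb19-10" aria-hidden="true" tabindex="-1"></a>    y_hat <span class="op">=</span> <span class="va">self</span>(x)</span>
<span id="cb19-11"><a href="#cb19-11" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Update the metric's running state; torchmetrics handles the averaging</span></span>
<span id="cb19-12"><a href="#cb19-12" aria-hidden="true" tabindex="-1"></a>    <span class="va">self</span>.val_acc(y_hat, y)</span>
<span id="cb19-13"><a href="#cb19-13" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Logging the metric object lets Lightning compute and reset it at epoch end</span></span>
<span id="cb19-14"><a href="#cb19-14" aria-hidden="true" tabindex="-1"></a>    <span class="va">self</span>.log(<span class="st">'val_acc'</span>, <span class="va">self</span>.val_acc)</span></code></pre></div></div>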
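<p>To see why per-batch averaging can mislead, the plain-Python sketch below uses hypothetical predictions for a validation set whose last batch is smaller than the rest. An unweighted mean of per-batch accuracies gives the final one-sample batch as much influence as the full batch, while accumulating correct and total counts across batches (similar in spirit to the running state <code>torchmetrics</code> keeps) gives the true dataset-level accuracy. For metrics that do not decompose over batches, such as F1, no per-batch weighting recovers the correct value, which is the stronger argument for stateful metric objects.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">def correct_count(preds, targets):
    """Number of positions where the prediction equals the target."""
    return sum(1 for p, t in zip(preds, targets) if p == t)


# Hypothetical validation set: one batch of 4 and a final, smaller batch of 1
batches = [
    ([0, 1, 1, 0], [0, 1, 0, 0]),  # 3 of 4 correct
    ([1], [0]),                    # 0 of 1 correct
]

# Unweighted mean of per-batch accuracies: (0.75 + 0.0) / 2
per_batch = [correct_count(p, t) / len(t) for p, t in batches]
naive_acc = sum(per_batch) / len(per_batch)

# Accumulate counts across all batches, then divide once at the end
total_correct = sum(correct_count(p, t) for p, t in batches)
total_seen = sum(len(t) for _, t in batches)
global_acc = total_correct / total_seen

print(naive_acc, global_acc)  # 0.375 0.6
</code></pre></div>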
</section>
<section id="dataloader-in-model" class="level3">
<h3 class="anchored" data-anchor-id="dataloader-in-model" id="dataloader-in-model">4. DataLoader in Model</h3>
<p><strong>Wrong:</strong></p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb20"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb20-1"><a href="#cb20-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> Model(pl.LightningModule):</span>
<span id="cb20-2"><a href="#cb20-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> train_dataloader(<span class="va">self</span>):</span>
<span id="cb20-3"><a href="#cb20-3" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Works, but couples data loading to the model</span></span>
<span id="cb20-4"><a href="#cb20-4" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> DataLoader(...)</span></code></pre></div></div>
<p><strong>Correct:</strong></p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb21"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb21-1"><a href="#cb21-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Use separate DataModule</span></span>
<span id="cb21-2"><a href="#cb21-2" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> DataModule(pl.LightningDataModule):</span>
<span id="cb21-3"><a href="#cb21-3" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> train_dataloader(<span class="va">self</span>):</span>
<span id="cb21-4"><a href="#cb21-4" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> DataLoader(...)</span>
<span id="cb21-5"><a href="#cb21-5" aria-hidden="true" tabindex="-1"></a></span>
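<span id="cb21-6"><a href="#cb21-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Pass the DataModule to trainer.fit()</span></span>
<span id="cb21-7"><a href="#cb21-7" aria-hidden="true" tabindex="-1"></a>trainer.fit(model, datamodule<span class="op">=</span>DataModule())</span>
<span id="cb21-8"><a href="#cb21-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-9"><a href="#cb21-9" aria-hidden="true" tabindex="-1"></a><span class="co"># Or pass dataloaders to trainer.fit() directly</span></span>
<span id="cb21-10"><a href="#cb21-10" aria-hidden="true" tabindex="-1"></a>trainer.fit(model, train_dataloaders<span class="op">=</span>train_loader, val_dataloaders<span class="op">=</span>val_loader)</span></code></pre></div></div>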
</section>
</section>
<section id="migration-checklist" class="level2">
<h2 class="anchored" data-anchor-id="migration-checklist" id="migration-checklist">Migration Checklist</h2>
<ul class="task-list">
<li><label><input type="checkbox">Convert model class to inherit from <code>pl.LightningModule</code></label></li>
<li><label><input type="checkbox">Implement required methods: <code>training_step</code>, <code>configure_optimizers</code></label></li>
<li><label><input type="checkbox">Add <code>validation_step</code> if you have validation data</label></li>
<li><label><input type="checkbox">Replace manual training loop with <code>pl.Trainer</code></label></li>
<li><label><input type="checkbox">Move data loading to <code>pl.LightningDataModule</code> or separate functions</label></li>
<li><label><input type="checkbox">Add proper logging with <code>self.log()</code></label></li>
<li><label><input type="checkbox">Use <code>self.save_hyperparameters()</code> for configuration</label></li>
<li><label><input type="checkbox">Add callbacks for checkpointing, early stopping, etc.</label></li>
<li><label><input type="checkbox">Remove manual device management (CUDA calls)</label></li>
<li><label><input type="checkbox">Test with different accelerators (CPU, GPU, multi-GPU)</label></li>
<li><label><input type="checkbox">Update any custom metrics to use torchmetrics</label></li>
<li><label><input type="checkbox">Verify that logging and experiment tracking work</label></li>
<li><label><input type="checkbox">Add <code>test_step</code> if you evaluate on a held-out test set</label></li>
</ul>
<p>By following this guide, you should be able to successfully migrate your PyTorch code to PyTorch Lightning while maintaining all functionality and gaining the benefits of Lightning’s structured approach.</p>


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Vision Transformers (ViT): A Simple Guide]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/models/vision-transformers-explained/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/models/vision-transformers-explained/</guid>
      <pubDate>Sat, 24 May 2025 00:00:00 GMT</pubDate>
      
      <category>research</category>
      <category>beginner</category>
      <content:encoded><![CDATA[






<section id="vision-transformers-vit-a-simple-guide" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/models/vision-transformers-explained/vit.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>Vision Transformers (ViTs) represent a paradigm shift in computer vision, adapting the transformer architecture that revolutionized natural language processing for image classification and other visual tasks. Instead of relying on convolutional neural networks (CNNs), ViTs treat images as sequences of patches, applying the self-attention mechanism to understand spatial relationships and visual features.</p>
</section>
<section id="background-from-cnns-to-transformers" class="level2">
<h2 class="anchored" data-anchor-id="background-from-cnns-to-transformers" id="background-from-cnns-to-transformers">Background: From CNNs to Transformers</h2>
<p>Traditional computer vision relied heavily on Convolutional Neural Networks (CNNs), which process images through layers of convolutions that detect local features like edges, textures, and patterns. While effective, CNNs have limitations in capturing long-range dependencies across an image due to their local receptive fields.</p>
<p>Transformers, originally designed for language tasks, excel at modeling long-range dependencies through self-attention mechanisms. The key insight behind Vision Transformers was asking: “What if we could apply this powerful attention mechanism to images?”</p>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Key Insight
</div>
</div>
<div class="callout-body-container callout-body">
<p>The fundamental breakthrough of ViTs was recognizing that images could be treated as sequences of patches, making them compatible with transformer architectures originally designed for text.</p>
</div>
</div>
</section>
<section id="core-concept-images-as-sequences" class="level2">
<h2 class="anchored" data-anchor-id="core-concept-images-as-sequences" id="core-concept-images-as-sequences">Core Concept: Images as Sequences</h2>
<p>The fundamental innovation of ViTs lies in treating images as sequences of patches rather than pixel grids. Here’s how this transformation works:</p>
<section id="image-patch-embedding" class="level3">
<h3 class="anchored" data-anchor-id="image-patch-embedding" id="image-patch-embedding">Image Patch Embedding</h3>
<ol type="1">
<li><strong>Patch Division</strong>: An input image (typically 224×224 pixels) is divided into fixed-size, non-overlapping patches (commonly 16×16 pixels), giving a sequence of 14 × 14 = 196 patches for these sizes</li>
<li><strong>Linear Projection</strong>: Each patch is flattened into a vector and linearly projected to create patch embeddings</li>
<li><strong>Classification Token</strong>: A special learnable [CLS] token is prepended to the sequence, similar to BERT’s approach; its final representation is used for classification</li>
<li><strong>Position Embeddings</strong>: Since transformers don’t inherently understand spatial order, learnable position embeddings are added to every token, including [CLS], to preserve spatial information</li>
</ol>
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">flowchart LR
    A[Input Image 224×224] --&gt; B[Divide into 16×16 patches]
    B --&gt; C[196 patches]
    C --&gt; D[Flatten each patch]
    D --&gt; E[Linear projection]
    E --&gt; F[Prepend CLS token]
    F --&gt; G[Add position embeddings]
    G --&gt; H[Sequence ready for transformer]
</pre>
</div>
<p></p></figure><p></p>
</div>
</div>
</div>
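<p>The patch-division and flattening steps can be sketched in plain Python. This is a minimal illustration, assuming the image is stored as nested lists of shape H × W × C and that both sides are divisible by the patch size; practical implementations do the same thing with tensor reshapes or a convolution whose kernel size and stride both equal the patch size. The function name <code>patchify</code> is our own.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">def patchify(image, patch_size):
    """Split an H x W x C image (nested lists) into flattened P x P patches.

    Patches are returned in row-major order (left to right, top to bottom),
    and each patch is flattened row by row, keeping a pixel's channels together.
    """
    height, width = len(image), len(image[0])
    if height % patch_size or width % patch_size:
        raise ValueError("image height and width must be divisible by the patch size")
    patches = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            patch = []
            for row in image[top:top + patch_size]:
                for pixel in row[left:left + patch_size]:
                    patch.extend(pixel)  # append all C channel values
            patches.append(patch)
    return patches


# Toy 4 x 4 single-channel image with pixel values 0..15
image = [[[r * 4 + c] for c in range(4)] for r in range(4)]
patches = patchify(image, patch_size=2)
print(len(patches))  # 4
print(patches[0])    # [0, 1, 4, 5]
</code></pre></div>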
</section>
<section id="mathematical-formulation" class="level3">
<h3 class="anchored" data-anchor-id="mathematical-formulation" id="mathematical-formulation">Mathematical Formulation</h3>
<p>For an image of size <span class="math inline">\(H \times W \times C\)</span> divided into patches of size <span class="math inline">\(P \times P\)</span>:</p>
<ul>
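<li>Number of patches: <span class="math inline">\(N = \frac{H \times W}{P^2}\)</span>, assuming <span class="math inline">\(H\)</span> and <span class="math inline">\(W\)</span> are divisible by <span class="math inline">\(P\)</span></li>
<li>Each flattened patch becomes a vector of length <span class="math inline">\(P^2 \cdot C\)</span></li>
<li>After linear projection, each patch is a vector of dimension <span class="math inline">\(D\)</span></li>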
</ul>
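<p>For the standard setting of a 224×224 RGB image (<span class="math inline">\(C = 3\)</span>) with 16×16 patches, these formulas give:</p>
<p><span class="math display">\[
N = \frac{224 \times 224}{16^2} = 196, \qquad P^2 \cdot C = 16^2 \times 3 = 768, \qquad N + 1 = 197 \text{ tokens with [CLS]}
\]</span></p>
<p>Each 768-value patch vector is then projected to <span class="math inline">\(D\)</span> dimensions, for example <span class="math inline">\(D = 768\)</span> in ViT-Base and <span class="math inline">\(D = 1024\)</span> in ViT-Large.</p>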
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Patch Size Trade-off
</div>
</div>
<div class="callout-body-container callout-body">
<p>Smaller patches (e.g., 8×8) provide finer detail but increase computational cost, while larger patches (e.g., 32×32) are more efficient but may lose important spatial information.</p>
</div>
</div>
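<p>To put numbers on this trade-off, the sketch below counts patches, tokens (including one [CLS] token), and pairwise attention scores for a 224×224 input at three patch sizes. The score count is per attention head and per layer, and it grows with the square of the sequence length, so halving the patch side roughly multiplies the attention cost by 16. The helper name <code>attention_cost</code> is hypothetical.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">def attention_cost(image_size, patch_size):
    """Return (patches, tokens incl. [CLS], attention scores per head per layer)."""
    if image_size % patch_size:
        raise ValueError("image size must be divisible by patch size")
    n_patches = (image_size // patch_size) ** 2
    tokens = n_patches + 1          # one extra [CLS] token
    return n_patches, tokens, tokens * tokens


for p in (8, 16, 32):
    n, t, scores = attention_cost(224, p)
    print(f"P={p:2d}: {n:3d} patches, {t:3d} tokens, {scores:,} attention scores")
# P= 8: 784 patches, 785 tokens, 616,225 attention scores
# P=16: 196 patches, 197 tokens, 38,809 attention scores
# P=32:  49 patches,  50 tokens, 2,500 attention scores
</code></pre></div>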
</section>
</section>
<section id="architecture-components" class="level2">
<h2 class="anchored" data-anchor-id="architecture-components" id="architecture-components">Architecture Components</h2>
<section id="patch-embedding-layer" class="level3">
<h3 class="anchored" data-anchor-id="patch-embedding-layer" id="patch-embedding-layer">Patch Embedding Layer</h3>
<p>The patch embedding layer converts image patches into token embeddings that the transformer can process. This involves:</p>
<ul>
<li>Reshaping patches into vectors</li>
<li>Linear transformation to desired embedding dimension</li>
<li>Adding positional encodings</li>
</ul>
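<p>The three steps above, plus prepending the [CLS] token, can be sketched in plain Python with toy sizes. This is an illustration under stated assumptions, not a reference implementation: the projection, [CLS] vector, and position embeddings are random stand-ins for learned parameters, the helper names are hypothetical, and a bias term is included as common implementations do. Following the original ViT formulation, the position embeddings cover all <span class="math inline">\(N + 1\)</span> tokens, so the [CLS] token gets one too.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import random


def matvec(matrix, vec):
    """Multiply a matrix (stored as a list of rows) by a vector."""
    return [sum(m * v for m, v in zip(row, vec)) for row in matrix]


def embed_patches(patches, weight, bias, cls_token, pos_embed):
    """Linear projection, [CLS] prepending, and position embeddings.

    patches:   N flattened patches, each of length P*P*C
    weight:    D rows, each of length P*P*C (learnable projection)
    bias:      length-D vector
    cls_token: length-D learnable [CLS] vector
    pos_embed: N + 1 learnable length-D vectors, one per token including [CLS]
    """
    projected = [[p + b for p, b in zip(matvec(weight, patch), bias)] for patch in patches]
    tokens = [cls_token] + projected
    return [[t + e for t, e in zip(tok, pos)] for tok, pos in zip(tokens, pos_embed)]


# Toy sizes (hypothetical): N = 4 patches, P*P*C = 4, D = 3
random.seed(0)
N, patch_dim, D = 4, 4, 3
patches = [[random.random() for _ in range(patch_dim)] for _ in range(N)]
weight = [[random.gauss(0, 0.02) for _ in range(patch_dim)] for _ in range(D)]
bias = [0.0] * D
cls_token = [0.0] * D
pos_embed = [[random.gauss(0, 0.02) for _ in range(D)] for _ in range(N + 1)]

sequence = embed_patches(patches, weight, bias, cls_token, pos_embed)
print(len(sequence), len(sequence[0]))  # 5 3
</code></pre></div>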
</section>
<section id="transformer-encoder" class="level3">
<h3 class="anchored" data-anchor-id="transformer-encoder" id="transformer-encoder">Transformer Encoder</h3>
<p>The core of ViT consists of standard transformer encoder blocks, each containing:</p>
<ul>
<li><strong>Multi-Head Self-Attention (MSA)</strong>: Allows patches to attend to all other patches</li>
<li><strong>Layer Normalization</strong>: Applied before both attention and MLP layers</li>
<li><strong>Multi-Layer Perceptron (MLP)</strong>: Two-layer feedforward network with GELU activation</li>
<li><strong>Residual Connections</strong>: Skip connections around both attention and MLP blocks</li>
</ul>
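<p>Written as equations, following the formulation in the original ViT paper, the embedded input sequence <span class="math inline">\(z_0\)</span> passes through <span class="math inline">\(L\)</span> pre-norm encoder blocks. Here <span class="math inline">\(x_p^i\)</span> is the <span class="math inline">\(i\)</span>-th flattened patch, <span class="math inline">\(E\)</span> the patch projection, <span class="math inline">\(E_{\text{pos}}\)</span> the position embeddings, and <span class="math inline">\(z_L^0\)</span> the [CLS] token’s output after the last block:</p>
<p><span class="math display">\[
\begin{gathered}
z_0 = [x_{\text{class}};\, x_p^1 E;\, x_p^2 E;\, \dots;\, x_p^N E] + E_{\text{pos}} \\
z'_\ell = \operatorname{MSA}(\operatorname{LN}(z_{\ell-1})) + z_{\ell-1}, \qquad \ell = 1, \dots, L \\
z_\ell = \operatorname{MLP}(\operatorname{LN}(z'_\ell)) + z'_\ell, \qquad \ell = 1, \dots, L \\
y = \operatorname{LN}(z_L^0)
\end{gathered}
\]</span></p>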
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">graph LR
    A[Input Patches] --&gt; B[Patch Embedding]
    B --&gt; C[Add CLS Token]
    C --&gt; D[Add Position Embeddings]
    D --&gt; E[Transformer Encoder Block 1]
    E --&gt; F[Transformer Encoder Block 2]
    F --&gt; G[...]
    G --&gt; H[Transformer Encoder Block N]
    H --&gt; I[Extract CLS Token]
    I --&gt; J[Classification Head]
    J --&gt; K[Output Predictions]
</pre>
</div>
<p></p></figure><p></p>
</div>
</div>
</div>
</section>
<section id="classification-head" class="level3">
<h3 class="anchored" data-anchor-id="classification-head" id="classification-head">Classification Head</h3>
<p>The final component extracts the [CLS] token’s representation and passes it through:</p>
<ul>
<li>Layer normalization</li>
<li>Linear classifier to produce class predictions</li>
</ul>
</section>
</section>
<section id="self-attention-in-vision" class="level2">
<h2 class="anchored" data-anchor-id="self-attention-in-vision" id="self-attention-in-vision">Self-Attention in Vision</h2>
<p>The self-attention mechanism in ViTs operates differently from CNNs:</p>
<section id="attention-maps" class="level3">
<h3 class="anchored" data-anchor-id="attention-maps" id="attention-maps">Attention Maps</h3>
<ul>
<li>Each patch can attend to every other patch in the image</li>
<li>Attention weights reveal which parts of the image are most relevant for classification</li>
<li>This enables modeling of long-range spatial dependencies</li>
</ul>
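<p>A single attention head can be sketched in plain Python to show this global connectivity. The sketch assumes toy sizes and random projection matrices in place of learned weights (all names and values are hypothetical); real ViTs use several heads per layer and batched tensor operations. Each row of <code>weights</code> is a softmax over every token in the sequence, so every patch receives some attention from every other patch.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import math
import random


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def project(x, w):
    """Multiply row vector x (length D) by matrix w (D rows of length d)."""
    return [sum(x[r] * w[r][c] for r in range(len(x))) for c in range(len(w[0]))]


def self_attention(tokens, wq, wk, wv):
    """Single-head scaled dot-product self-attention.

    Returns (outputs, weights), where weights[i][j] is how strongly
    token i attends to token j.
    """
    q = [project(t, wq) for t in tokens]
    k = [project(t, wk) for t in tokens]
    v = [project(t, wv) for t in tokens]
    scale = math.sqrt(len(q[0]))
    weights, outputs = [], []
    for qi in q:
        row = softmax([sum(a * b for a, b in zip(qi, kj)) / scale for kj in k])
        weights.append(row)
        # Output is the attention-weighted sum of all value vectors
        outputs.append([sum(row[j] * v[j][c] for j in range(len(v)))
                        for c in range(len(v[0]))])
    return outputs, weights


def rand_matrix(rows, cols):
    return [[random.gauss(0, 0.5) for _ in range(cols)] for _ in range(rows)]


# Toy example (hypothetical sizes): [CLS] + 4 patch tokens, D = 4, head dim d = 2
random.seed(0)
D, d = 4, 2
tokens = [[random.gauss(0, 1) for _ in range(D)] for _ in range(5)]
wq, wk, wv = rand_matrix(D, d), rand_matrix(D, d), rand_matrix(D, d)
outputs, weights = self_attention(tokens, wq, wk, wv)
print([round(x, 3) for x in weights[0]])  # [CLS] attends to all 5 tokens
</code></pre></div>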
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Important</span>Global Receptive Field
</div>
</div>
<div class="callout-body-container callout-body">
<p>Unlike CNNs that build up receptive fields gradually, ViTs have global receptive fields from the first layer, allowing immediate access to information across the entire image.</p>
</div>
</div>
</section>
<section id="global-context" class="level3">
<h3 class="anchored" data-anchor-id="global-context" id="global-context">Global Context</h3>
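<p>Because self-attention connects every patch to every other patch in every layer, a ViT can relate distant image regions (for example, parts of an object on opposite sides of the frame) without stacking many layers, as a CNN must to grow its receptive field. The trade-off is that attention cost grows quadratically with the number of patches.</p>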
</section>
</section>
<section id="training-considerations" class="level2">
<h2 class="anchored" data-anchor-id="training-considerations" id="training-considerations">Training Considerations</h2>
<section id="data-requirements" class="level3">
<h3 class="anchored" data-anchor-id="data-requirements" id="data-requirements">Data Requirements</h3>
<p>Vision Transformers typically require large amounts of training data to perform well:</p>
<ul>
<li><strong>Pre-training</strong>: Often trained on large datasets like ImageNet-21k or JFT-300M</li>
<li><strong>Fine-tuning</strong>: Then adapted to specific tasks with smaller datasets</li>
<li><strong>Data Efficiency</strong>: ViTs can be less data-efficient than CNNs when training from scratch</li>
</ul>
</section>
<section id="optimization-challenges" class="level3">
<h3 class="anchored" data-anchor-id="optimization-challenges" id="optimization-challenges">Optimization Challenges</h3>
<ul>
<li><strong>Initialization</strong>: Careful weight initialization is crucial</li>
<li><strong>Learning Rate</strong>: Fine-tuning recipes often use a smaller learning rate for the pre-trained encoder than for a newly initialized classification head</li>
<li><strong>Regularization</strong>: Techniques like dropout and weight decay are important</li>
<li><strong>Warmup</strong>: Learning rate warmup is commonly used</li>
</ul>
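<p>A warmup schedule can be written as a small function of the step index. The sketch below assumes linear warmup followed by cosine decay, one common choice in ViT training recipes; the exact shape, step counts, and base learning rate vary between papers and are placeholders here. The function name <code>lr_at_step</code> is hypothetical.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import math


def lr_at_step(step, base_lr, warmup_steps, total_steps, min_lr=0.0):
    """Linear warmup to base_lr, then cosine decay to min_lr."""
    if step >= warmup_steps:
        # Fraction of the decay phase completed, clamped to 1.0
        progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
        return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
    # Warmup: ramp linearly from base_lr / warmup_steps up to base_lr
    return base_lr * (step + 1) / warmup_steps


# Hypothetical settings: 10,000 total steps with 1,000 warmup steps
schedule = [lr_at_step(s, base_lr=1e-3, warmup_steps=1000, total_steps=10000)
            for s in range(10001)]
print(schedule[0], schedule[1000], schedule[10000])
</code></pre></div>
<p>Starting from a small learning rate keeps the early updates, when attention weights and normalization statistics are still poorly calibrated, from destabilizing training.</p>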
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Warning</span>Training from Scratch
</div>
</div>
<div class="callout-body-container callout-body">
<p>Training ViTs from scratch on small datasets often leads to poor performance. Pre-training on large datasets followed by fine-tuning is the recommended approach.</p>
</div>
</div>
</section>
</section>
<section id="variants-and-improvements" class="level2">
<h2 class="anchored" data-anchor-id="variants-and-improvements" id="variants-and-improvements">Variants and Improvements</h2>
<section id="vit-variants" class="level3">
<h3 class="anchored" data-anchor-id="vit-variants" id="vit-variants">ViT Variants</h3>
<div id="tbl-vit-variants" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-tbl figure">
<figcaption class="quarto-float-caption-top quarto-float-caption quarto-float-tbl" id="tbl-vit-variants-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Table&nbsp;1: ViT Model Variants
</figcaption>
<div aria-describedby="tbl-vit-variants-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<table class="caption-top table">
<thead>
<tr class="header">
<th>Model</th>
<th>Patch Size</th>
<th>Parameters</th>
<th>Description</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>ViT-B/16</td>
<td>16×16</td>
<td>86M</td>
<td>Base model with 16×16 patches</td>
</tr>
<tr class="even">
<td>ViT-L/16</td>
<td>16×16</td>
<td>307M</td>
<td>Large model with 16×16 patches</td>
</tr>
<tr class="odd">
<td>ViT-H/14</td>
<td>14×14</td>
<td>632M</td>
<td>Huge model with 14×14 patches</td>
</tr>
<tr class="even">
<td>DeiT-B</td>
<td>16×16</td>
<td>86M</td>
<td>ViT-B-sized model trained on ImageNet-1k alone using data-efficient training strategies (strong augmentation, distillation)</td>
</tr>
</tbody>
</table>
</div>
</figure>
</div>
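<p>The parameter counts in the table follow from the architecture. The sketch below tallies them for a standard ViT, assuming the configurations from the original paper (ViT-Base: 12 layers, hidden size 768, MLP size 3072; ViT-Large: 24 layers, hidden size 1024, MLP size 4096), 224×224 inputs, biases on every linear layer, and a 1000-class linear head. Under these assumptions ViT-B/16 comes out at about 86.6M; the ViT-L/16 estimate of about 304M sits slightly below the table’s 307M, and published totals can differ slightly depending on details such as the classification head. The function name <code>vit_param_count</code> is hypothetical.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">def vit_param_count(depth, dim, mlp_dim, patch_size=16, channels=3,
                    image_size=224, num_classes=1000):
    """Count parameters of a standard ViT, assuming biases on every linear layer."""
    n_tokens = (image_size // patch_size) ** 2 + 1               # patches + [CLS]
    patch_embed = patch_size * patch_size * channels * dim + dim
    cls_and_pos = dim + n_tokens * dim                           # [CLS] + position embeddings
    layer_norm = 2 * dim                                         # scale and shift
    attention = (dim * 3 * dim + 3 * dim) + (dim * dim + dim)    # QKV + output projection
    mlp = (dim * mlp_dim + mlp_dim) + (mlp_dim * dim + dim)      # two linear layers
    block = 2 * layer_norm + attention + mlp
    head = layer_norm + dim * num_classes + num_classes          # final LN + classifier
    return patch_embed + cls_and_pos + depth * block + head


base = vit_param_count(depth=12, dim=768, mlp_dim=3072)
large = vit_param_count(depth=24, dim=1024, mlp_dim=4096)
print(f"ViT-B/16: {base / 1e6:.1f}M")   # ViT-B/16: 86.6M
print(f"ViT-L/16: {large / 1e6:.1f}M")  # ViT-L/16: 304.3M
</code></pre></div>
<p>Almost all of the parameters sit in the encoder blocks, whose size scales with depth and with the square of the hidden size, which is why the Large and Huge variants grow so quickly.</p>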
</section>
<section id="architectural-improvements" class="level3">
<h3 class="anchored" data-anchor-id="architectural-improvements" id="architectural-improvements">Architectural Improvements</h3>
<ul>
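<li><strong>Hierarchical Processing</strong>: Multi-scale feature maps whose resolution shrinks in later stages, as in the Swin Transformer and the Pyramid Vision Transformer (PVT)</li>
<li><strong>Local Attention</strong>: Restricting attention to local windows to reduce the quadratic cost, as in the Swin Transformer’s shifted windows</li>
<li><strong>Hybrid Models</strong>: Feeding CNN feature maps into a transformer encoder, as in the ResNet-based hybrids evaluated in the original ViT paper</li>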
</ul>
</section>
</section>
<section id="advantages-of-vision-transformers" class="level2">
<h2 class="anchored" data-anchor-id="advantages-of-vision-transformers" id="advantages-of-vision-transformers">Advantages of Vision Transformers</h2>
<section id="strengths" class="level3">
<h3 class="anchored" data-anchor-id="strengths" id="strengths">Strengths</h3>
<div class="columns">
<div class="column" style="width:50%;">
<p><strong>Technical Advantages:</strong></p>
<ul>
<li>Long-range dependencies</li>
<li>Interpretability through attention maps</li>
<li>Scalability with model size</li>
<li>Architectural simplicity</li>
</ul>
</div><div class="column" style="width:50%;">
<p><strong>Practical Benefits:</strong></p>
<ul>
<li>State-of-the-art classification results</li>
<li>Excellent transfer learning</li>
<li>Strong multi-task performance</li>
<li>Domain adaptation capabilities</li>
</ul>
</div>
</div>
</section>
<section id="performance-benefits" class="level3">
<h3 class="anchored" data-anchor-id="performance-benefits" id="performance-benefits">Performance Benefits</h3>
<ul>
<li>Image classification accuracy that matches or exceeds strong CNNs when pre-trained on large datasets</li>
<li>Strong performance on object detection and segmentation when adapted</li>
<li>Excellent transfer learning capabilities across domains</li>
</ul>
</section>
</section>
<section id="limitations-and-challenges" class="level2">
<h2 class="anchored" data-anchor-id="limitations-and-challenges" id="limitations-and-challenges">Limitations and Challenges</h2>
<section id="current-limitations" class="level3">
<h3 class="anchored" data-anchor-id="current-limitations" id="current-limitations">Current Limitations</h3>
<table class="caption-top table">
<caption>ViT Limitations and Solutions</caption>
<colgroup>
<col style="width: 28%">
<col style="width: 19%">
<col style="width: 52%">
</colgroup>
<thead>
<tr class="header">
<th>Limitation</th>
<th>Impact</th>
<th>Mitigation Strategies</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>Data Hunger</td>
<td>Weak results when trained from scratch on small datasets</td>
<td>Pre-training + fine-tuning</td>
</tr>
<tr class="even">
<td>Computational Cost</td>
<td>High memory/compute requirements</td>
<td>Model compression, efficient variants</td>
</tr>
<tr class="odd">
<td>Lack of Inductive Bias</td>
<td>No built-in locality or translation equivariance</td>
<td>Hybrid CNN-Transformer architectures, local attention</td>
</tr>
<tr class="even">
<td>Training Instability</td>
<td>Sensitive to hyperparameters</td>
<td>Careful initialization, warmup</td>
</tr>
</tbody>
</table>
</section>
<section id="ongoing-research-areas" class="level3">
<h3 class="anchored" data-anchor-id="ongoing-research-areas" id="ongoing-research-areas">Ongoing Research Areas</h3>
<ul>
<li>Improving data efficiency</li>
<li>Reducing computational requirements</li>
<li>Better integration of spatial inductive biases</li>
<li>Hybrid CNN-Transformer architectures</li>
</ul>
</section>
</section>
<section id="applications-beyond-classification" class="level2">
<h2 class="anchored" data-anchor-id="applications-beyond-classification" id="applications-beyond-classification">Applications Beyond Classification</h2>
<section id="computer-vision-tasks" class="level3">
<h3 class="anchored" data-anchor-id="computer-vision-tasks" id="computer-vision-tasks">Computer Vision Tasks</h3>
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">mindmap
  root((ViT Applications))
    Object Detection
      DETR
      Deformable DETR
    Segmentation
      SETR
      SegFormer
    Generation
      VQGAN
      DALL-E 2
    Video Analysis
      TimeSformer
      ViViT
</pre>
</div>
<p></p></figure><p></p>
</div>
</div>
</div>
</section>
<section id="multimodal-applications" class="level3">
<h3 class="anchored" data-anchor-id="multimodal-applications" id="multimodal-applications">Multimodal Applications</h3>
<ul>
<li><strong>Vision-Language Models</strong>: CLIP and similar models, which align image and text embeddings through contrastive pre-training</li>
<li><strong>Visual Question Answering</strong>: Integrating visual and textual understanding</li>
<li><strong>Image Captioning</strong>: Generating descriptions from visual content</li>
</ul>
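<p>CLIP-style models enable zero-shot classification: embed the image and one text prompt per class into a shared space, then pick the most similar prompt. A sketch in which short made-up vectors stand in for the outputs of the image and text encoders; real CLIP embeddings have hundreds of dimensions and are L2-normalized, which makes their dot product equal to the cosine similarity used here.</p>

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def zero_shot_classify(image_embedding, label_embeddings):
    """Return the prompt whose text embedding is closest to the image embedding."""
    return max(label_embeddings,
               key=lambda label: cosine_similarity(image_embedding, label_embeddings[label]))


# Toy 3-d vectors standing in for encoder outputs (made up for illustration)
labels = {
    "a photo of a cat": [0.9, 0.1, 0.0],
    "a photo of a dog": [0.1, 0.9, 0.1],
    "a photo of a car": [0.0, 0.2, 0.9],
}
print(zero_shot_classify([0.8, 0.3, 0.1], labels))  # a photo of a cat
```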
</section>
</section>
<section id="implementation-considerations" class="level2">
<h2 class="anchored" data-anchor-id="implementation-considerations" id="implementation-considerations">Implementation Considerations</h2>
<section id="model-selection" class="level3">
<h3 class="anchored" data-anchor-id="model-selection" id="model-selection">Model Selection</h3>
<p>Choose ViT variants based on:</p>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center collapsed" data-bs-toggle="collapse" data-bs-target=".callout-5-contents" aria-controls="callout-5" aria-expanded="false" aria-label="Toggle callout">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Selection Criteria
</div>
<div class="callout-btn-toggle d-inline-block border-0 py-1 ps-1 pe-0 float-end"><i class="callout-toggle"></i></div>
</div>
<div id="callout-5" class="callout-5-contents callout-collapse collapse">
<div class="callout-body-container callout-body">
<ol type="1">
<li><strong>Computational Resources</strong>: Available GPU memory and compute budget</li>
<li><strong>Dataset Size</strong>: Larger datasets can support bigger models</li>
<li><strong>Inference Speed</strong>: Real-time applications need smaller, faster models</li>
<li><strong>Accuracy Requirements</strong>: Higher accuracy often requires larger models</li>
</ol>
</div>
</div>
</div>
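<p>For the computational-resources criterion, a back-of-the-envelope parameter count is often enough. A sketch under the layout assumptions listed in the docstring, which match the standard ViT-B/16 design; the number of attention heads does not change the count, because heads only split the embedding dimension.</p>

```python
def vit_param_count(depth: int, dim: int, patch: int = 16, image: int = 224,
                    channels: int = 3, mlp_ratio: int = 4, num_classes: int = 1000) -> int:
    """Parameter count of a standard ViT encoder plus linear classification head.

    Assumes: linear patch embedding with bias, learned position embeddings,
    one class token, two LayerNorms per block, qkv and output projections with
    bias, a two-layer MLP with bias, a final LayerNorm and a linear classifier.
    """
    tokens = (image // patch) ** 2 + 1
    hidden = mlp_ratio * dim
    patch_embed = channels * patch * patch * dim + dim
    pos_embed = tokens * dim
    cls_token = dim
    attention = 3 * dim * dim + 3 * dim + dim * dim + dim   # qkv + output projection
    mlp = dim * hidden + hidden + hidden * dim + dim        # fc1 + fc2
    norms = 2 * 2 * dim                                     # two LayerNorms (scale + shift)
    per_block = attention + mlp + norms
    head = 2 * dim + dim * num_classes + num_classes        # final LayerNorm + classifier
    return patch_embed + pos_embed + cls_token + depth * per_block + head


for name, depth, dim in (("ViT-S/16", 12, 384), ("ViT-B/16", 12, 768), ("ViT-L/16", 24, 1024)):
    print(f"{name}: {vit_param_count(depth, dim) / 1e6:.1f}M parameters")
# ViT-S/16: 22.1M, ViT-B/16: 86.6M, ViT-L/16: 304.3M
```

<p>At 4 bytes per fp32 weight, ViT-B/16 needs about 346 MB for the weights alone, before activations, gradients and optimizer state.</p>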
</section>
<section id="training-strategy" class="level3">
<h3 class="anchored" data-anchor-id="training-strategy" id="training-strategy">Training Strategy</h3>
<ul>
<li>Use pre-trained models when possible</li>
<li>Apply appropriate data augmentation</li>
<li>Consider knowledge distillation for smaller models</li>
<li>Monitor for overfitting, especially on smaller datasets</li>
</ul>
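<p>Knowledge distillation trains a smaller student to match a larger teacher's softened predictions. A minimal pure-Python sketch of the classic soft-target loss; the temperature of 4 is an assumed value, and in practice this term is usually mixed with the ordinary cross-entropy on the labels (DeiT additionally uses a dedicated distillation token).</p>

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax of a list of logits at a given temperature."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_loss(student_logits, teacher_logits, temperature=4.0):
    """Soft-target loss: T^2 * KL(teacher_T || student_T), both softened at temperature T.

    The T^2 factor keeps gradient magnitudes comparable across temperatures."""
    p = softmax(teacher_logits, temperature)   # teacher's softened distribution
    q = softmax(student_logits, temperature)   # student's softened distribution
    kl = sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q) if pi > 0)
    return temperature ** 2 * kl


print(distillation_loss([0.5, 1.0, 2.0], [0.2, 1.5, 2.5]))
```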
</section>
<section id="optimization-tips" class="level3">
<h3 class="anchored" data-anchor-id="optimization-tips" id="optimization-tips">Optimization Tips</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Example training configuration (illustrative values; tune for your model and data)</span></span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a>training_config <span class="op">=</span> {</span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a>    <span class="st">"mixed_precision"</span>: <span class="va">True</span>,         <span class="co"># fp16/bf16 autocast: less memory, faster matmuls</span></span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a>    <span class="st">"gradient_checkpointing"</span>: <span class="va">True</span>,  <span class="co"># recompute activations in the backward pass to save memory</span></span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a>    <span class="st">"weight_decay"</span>: <span class="fl">0.05</span>,            <span class="co"># assumes AdamW-style decoupled weight decay</span></span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a>    <span class="st">"learning_rate"</span>: <span class="fl">1e-3</span>,           <span class="co"># peak learning rate, reached at the end of warmup</span></span>
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a>    <span class="st">"warmup_epochs"</span>: <span class="dv">5</span>,              <span class="co"># ramp from ~0 up to the peak learning rate</span></span>
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a>    <span class="st">"batch_size"</span>: <span class="dv">512</span>                <span class="co"># global batch size; learning_rate is often scaled with it</span></span>
<span id="cb1-9"><a href="#cb1-9" aria-hidden="true" tabindex="-1"></a>}</span></code></pre></div></div>
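<p>Entries such as <code>learning_rate</code> and <code>warmup_epochs</code> are typically consumed by a learning-rate schedule. A sketch assuming linear warmup followed by cosine decay, a common choice for ViT training but not something the configuration itself specifies; <code>total_epochs</code> and <code>min_lr</code> are hypothetical parameters added for illustration.</p>

```python
import math


def lr_at_epoch(epoch: float, peak_lr: float = 1e-3, warmup_epochs: int = 5,
                total_epochs: int = 300, min_lr: float = 1e-5) -> float:
    """Linear warmup from 0 to peak_lr, then cosine decay to min_lr at total_epochs."""
    if epoch < warmup_epochs:
        return peak_lr * epoch / warmup_epochs
    progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    progress = min(progress, 1.0)  # hold at min_lr after the schedule ends
    return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * progress))


for epoch in (0, 2.5, 5, 150, 300):
    print(epoch, f"{lr_at_epoch(epoch):.2e}")
```

<p>Evaluating the schedule per optimizer step (with fractional epochs) rather than once per epoch gives a smoother warmup.</p>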
</section>
</section>
<section id="future-directions" class="level2">
<h2 class="anchored" data-anchor-id="future-directions" id="future-directions">Future Directions</h2>
<section id="research-trends" class="level3">
<h3 class="anchored" data-anchor-id="research-trends" id="research-trends">Research Trends</h3>
<div class="tabset-margin-container"></div><div class="panel-tabset">
<ul class="nav nav-tabs" role="tablist"><li class="nav-item" role="presentation"><a class="nav-link active" id="tabset-1-1-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-1" role="tab" aria-controls="tabset-1-1" aria-selected="true" href="">Efficiency</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-1-2-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-2" role="tab" aria-controls="tabset-1-2" aria-selected="false" href="">Architecture</a></li><li class="nav-item" role="presentation"><a class="nav-link" id="tabset-1-3-tab" data-bs-toggle="tab" data-bs-target="#tabset-1-3" role="tab" aria-controls="tabset-1-3" aria-selected="false" href="">Learning</a></li></ul>
<div class="tab-content">
<div id="tabset-1-1" class="tab-pane active" role="tabpanel" aria-labelledby="tabset-1-1-tab">
<ul>
<li>Making ViTs more computationally efficient</li>
<li>Mobile and edge deployment optimizations</li>
<li>Pruning and quantization techniques</li>
</ul>
</div>
<div id="tabset-1-2" class="tab-pane" role="tabpanel" aria-labelledby="tabset-1-2-tab">
<ul>
<li>Automated design of vision transformer architectures</li>
<li>Neural architecture search for ViTs</li>
<li>Hybrid CNN-Transformer designs</li>
</ul>
</div>
<div id="tabset-1-3" class="tab-pane" role="tabpanel" aria-labelledby="tabset-1-3-tab">
<ul>
<li>Self-supervised learning approaches</li>
<li>Reducing dependence on labeled data</li>
<li>Few-shot and zero-shot learning capabilities</li>
</ul>
</div>
</div>
</div>
</section>
<section id="emerging-applications" class="level3">
<h3 class="anchored" data-anchor-id="emerging-applications" id="emerging-applications">Emerging Applications</h3>
<ul>
<li>Real-time vision applications</li>
<li>Mobile and edge deployment</li>
<li>Scientific imaging and medical applications</li>
<li>Autonomous systems and robotics</li>
</ul>
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>Vision Transformers represent a fundamental shift in computer vision, demonstrating that the transformer architecture’s success in NLP can extend to visual tasks. While they present challenges in terms of data requirements and computational cost, their ability to model long-range dependencies and achieve state-of-the-art performance makes them a crucial tool in modern computer vision.</p>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Key Takeaways
</div>
</div>
<div class="callout-body-container callout-body">
<ol type="1">
<li><strong>Paradigm Shift</strong>: ViTs treat images as sequences of patches</li>
<li><strong>Global Attention</strong>: Every patch can attend to every other patch from the first layer, so long-range dependencies are modeled directly</li>
<li><strong>Data Requirements</strong>: Best performance with large-scale pre-training</li>
<li><strong>Scalability</strong>: Performance improves with model and dataset size</li>
<li><strong>Versatility</strong>: Applicable across many computer vision tasks</li>
</ol>
</div>
</div>
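<p>Treating an image as a sequence of patches amounts to a reshape. A minimal sketch on a single-channel image stored as nested lists (chosen for readability); a real ViT also flattens the colour channels and multiplies each patch vector by a learned projection to obtain its embedding.</p>

```python
def patchify(image, patch):
    """Split an H x W image (list of rows) into non-overlapping patch x patch tiles.

    Each tile is flattened row-major, and tiles are returned in raster order
    (left to right, top to bottom), which becomes the token order."""
    h, w = len(image), len(image[0])
    if h % patch or w % patch:
        raise ValueError("image sides must be divisible by the patch size")
    tokens = []
    for top in range(0, h, patch):
        for left in range(0, w, patch):
            tokens.append([image[r][c]
                           for r in range(top, top + patch)
                           for c in range(left, left + patch)])
    return tokens


image = [[4 * r + c for c in range(4)] for r in range(4)]  # pixel values 0..15
print(patchify(image, 2))  # [[0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13], [10, 11, 14, 15]]
```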
<p>The field continues to evolve rapidly, with ongoing research addressing current limitations while exploring new applications. As the technology matures, we can expect ViTs to become increasingly practical for a wider range of real-world applications, potentially reshaping how we approach visual understanding tasks.</p>
<p>Understanding Vision Transformers is essential for anyone working in modern computer vision, as they represent not just a new model architecture, but a new way of thinking about how machines can understand and process visual information.</p>
<hr>
<p><em>This document provides a comprehensive overview of Vision Transformers. For the latest developments and research, please refer to recent publications and the official implementations.</em></p>


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Python 3.14: Key Improvements and New Features]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/python/python-pi-code/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/python/python-pi-code/</guid>
      <pubDate>Sat, 24 May 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>beginner</category>
      <content:encoded><![CDATA[






<section id="python-3.14-key-improvements-and-new-features" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/python/python-pi-code/pypi.png" class="img-fluid"></p>
<p>Python 3.14 introduces several significant improvements focused on performance, developer experience, and language capabilities. This guide covers the most important changes that will impact your code and development workflow.</p>
<section id="performance-improvements" class="level2">
<h2 class="anchored" data-anchor-id="performance-improvements" id="performance-improvements">Performance Improvements</h2>
<section id="free-threading-experimental" class="level3">
<h3 class="anchored" data-anchor-id="free-threading-experimental" id="free-threading-experimental">Free-Threading (Experimental)</h3>
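<p>Python 3.14 continues the free-threading support (PEP 703) introduced experimentally in 3.13. It requires the separate free-threaded build of CPython (for example the <code>python3.14t</code> executable); on the default build the GIL still lets only one thread run Python bytecode at a time, so the threads below take turns:</p>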
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Requires the free-threaded build (CPython configured with --disable-gil, e.g. python3.14t)</span></span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="co"># On that build, CPU-bound threads can run on several cores at once</span></span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> threading</span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> cpu_intensive_task(n):</span>
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a>    total <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(n):</span>
<span id="cb1-9"><a href="#cb1-9" aria-hidden="true" tabindex="-1"></a>        total <span class="op">+=</span> i <span class="op">**</span> <span class="dv">2</span></span>
<span id="cb1-10"><a href="#cb1-10" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> total</span>
<span id="cb1-11"><a href="#cb1-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-12"><a href="#cb1-12" aria-hidden="true" tabindex="-1"></a><span class="co"># Four threads: they run in parallel only when the GIL is disabled</span></span>
<span id="cb1-13"><a href="#cb1-13" aria-hidden="true" tabindex="-1"></a>threads <span class="op">=</span> []</span>
<span id="cb1-14"><a href="#cb1-14" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">4</span>):</span>
<span id="cb1-15"><a href="#cb1-15" aria-hidden="true" tabindex="-1"></a>    t <span class="op">=</span> threading.Thread(target<span class="op">=</span>cpu_intensive_task, args<span class="op">=</span>(<span class="dv">1000000</span>,))</span>
<span id="cb1-16"><a href="#cb1-16" aria-hidden="true" tabindex="-1"></a>    threads.append(t)</span>
<span id="cb1-17"><a href="#cb1-17" aria-hidden="true" tabindex="-1"></a>    t.start()</span>
<span id="cb1-18"><a href="#cb1-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-19"><a href="#cb1-19" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> t <span class="kw">in</span> threads:</span>
<span id="cb1-20"><a href="#cb1-20" aria-hidden="true" tabindex="-1"></a>    t.join()</span></code></pre></div></div>
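<p>Whether those threads actually run in parallel depends on the interpreter, so it is worth checking before timing anything. A sketch: <code>sys._is_gil_enabled()</code> is a private helper added in 3.13 (guarded with <code>getattr</code> here for older versions), and the wall-clock numbers depend entirely on the build and the machine. On the default build, the four-worker run is typically no faster than the one-worker run.</p>

```python
import sys
import sysconfig
import time
from concurrent.futures import ThreadPoolExecutor


def cpu_intensive_task(n):
    total = 0
    for i in range(n):
        total += i ** 2
    return total


def gil_status() -> str:
    """Report whether this is a free-threaded build and whether the GIL is active."""
    free_threaded_build = bool(sysconfig.get_config_var("Py_GIL_DISABLED"))
    # sys._is_gil_enabled() exists from Python 3.13; earlier versions always use the GIL
    gil_enabled = getattr(sys, "_is_gil_enabled", lambda: True)()
    return f"free-threaded build: {free_threaded_build}, GIL enabled: {gil_enabled}"


def timed_run(workers, n=200_000):
    """Run the task 4 times on a pool of `workers` threads; return results and seconds."""
    start = time.perf_counter()
    with ThreadPoolExecutor(max_workers=workers) as pool:
        results = list(pool.map(cpu_intensive_task, [n] * 4))
    return results, time.perf_counter() - start


print(gil_status())
for workers in (1, 4):
    _, seconds = timed_run(workers)
    print(f"{workers} worker(s): {seconds:.3f}s")
```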
</section>
<section id="jit-compilation-improvements" class="level3">
<h3 class="anchored" data-anchor-id="jit-compilation-improvements" id="jit-compilation-improvements">JIT Compilation Improvements</h3>
<p>CPython's experimental JIT compiler (PEP 744), added in 3.13, continues to develop in 3.14. It is not part of a default build: CPython must be compiled with <code>--enable-experimental-jit</code>, and builds that include it can switch it on or off with the <code>PYTHON_JIT</code> environment variable. Speedups vary by workload, so measure before relying on them:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="co"># A call-heavy recursive function, a common micro-benchmark</span></span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> fibonacci(n):</span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> n <span class="op">&lt;=</span> <span class="dv">1</span>:</span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> n</span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> fibonacci(n<span class="op">-</span><span class="dv">1</span>) <span class="op">+</span> fibonacci(n<span class="op">-</span><span class="dv">2</span>)</span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Any speedup depends on the build and workload; time it on your own interpreter</span></span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a>result <span class="op">=</span> fibonacci(<span class="dv">35</span>)  <span class="co"># 9227465</span></span></code></pre></div></div>
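<p>A simple way to compare interpreters (for example the same build with <code>PYTHON_JIT=0</code> and <code>PYTHON_JIT=1</code>) is to time a recursive and a loop-heavy function with <code>timeit</code>. A sketch; as far as we understand, hot loops are where CPython's tier-2 optimizer, which feeds the JIT, starts its traces, so the iterative version is included for contrast. The argument sizes are arbitrary.</p>

```python
import os
import timeit


def fibonacci(n):
    if n <= 1:
        return n
    return fibonacci(n - 1) + fibonacci(n - 2)


def fibonacci_loop(n):
    """Iterative version: one hot loop instead of many calls."""
    a, b = 0, 1
    for _ in range(n):
        a, b = b, a + b
    return a


# PYTHON_JIT only has an effect on builds compiled with the experimental JIT
print("PYTHON_JIT =", os.environ.get("PYTHON_JIT", "unset"))
for func, arg in ((fibonacci, 25), (fibonacci_loop, 10_000)):
    best = min(timeit.repeat(lambda: func(arg), number=5, repeat=3))
    print(f"{func.__name__}({arg}): best of 3 runs = {best:.4f}s for 5 calls")
```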
</section>
</section>
<section id="language-features" class="level2">
<h2 class="anchored" data-anchor-id="language-features" id="language-features">Language Features</h2>
<section id="improved-type-annotations" class="level3">
<h3 class="anchored" data-anchor-id="improved-type-annotations" id="improved-type-annotations">Improved Type Annotations</h3>
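<p>Generic classes carry element types that static type checkers use for inference. The <code>TypeVar</code>/<code>Generic</code> form below runs on every supported Python version; since Python 3.12 the same class can also be written with the shorter <code>class Stack[T]:</code> syntax (PEP 695):</p>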
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> typing <span class="im">import</span> Generic, TypeVar</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a>T <span class="op">=</span> TypeVar(<span class="st">'T'</span>)</span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> Stack(Generic[T]):</span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>) <span class="op">-&gt;</span> <span class="va">None</span>:</span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>._items: <span class="bu">list</span>[T] <span class="op">=</span> []</span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> push(<span class="va">self</span>, item: T) <span class="op">-&gt;</span> <span class="va">None</span>:</span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>._items.append(item)</span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> pop(<span class="va">self</span>) <span class="op">-&gt;</span> T:</span>
<span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="kw">not</span> <span class="va">self</span>._items:</span>
<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a>            <span class="cf">raise</span> <span class="pp">IndexError</span>(<span class="st">"Stack is empty"</span>)</span>
<span id="cb3-15"><a href="#cb3-15" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>._items.pop()</span>
<span id="cb3-16"><a href="#cb3-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-17"><a href="#cb3-17" aria-hidden="true" tabindex="-1"></a><span class="co"># Better type inference</span></span>
<span id="cb3-18"><a href="#cb3-18" aria-hidden="true" tabindex="-1"></a>stack <span class="op">=</span> Stack[<span class="bu">int</span>]()</span>
<span id="cb3-19"><a href="#cb3-19" aria-hidden="true" tabindex="-1"></a>stack.push(<span class="dv">42</span>)</span>
<span id="cb3-20"><a href="#cb3-20" aria-hidden="true" tabindex="-1"></a>value <span class="op">=</span> stack.pop()  <span class="co"># Type checker knows this is int</span></span></code></pre></div></div>
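<p>Annotations are also available at runtime: <code>typing.get_type_hints()</code> resolves them, including forward references written as strings. A small sketch with a hypothetical <code>Node</code> class that runs on older versions too; from Python 3.14 annotations are evaluated lazily (PEP 649), so the quotes around the self-reference could be dropped.</p>

```python
from typing import Generic, Optional, TypeVar, get_type_hints

T = TypeVar("T")


class Node(Generic[T]):
    # "Node" is not defined yet while this signature is created, hence the
    # quoted forward reference; get_type_hints() evaluates it later
    def __init__(self, value: T, next: "Optional[Node[T]]" = None) -> None:
        self.value = value
        self.next = next


hints = get_type_hints(Node.__init__)
print(hints["value"], hints["next"])
```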
</section>
<section id="pattern-matching-enhancements" class="level3">
<h3 class="anchored" data-anchor-id="pattern-matching-enhancements" id="pattern-matching-enhancements">Pattern Matching Enhancements</h3>
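<p>Structural pattern matching (PEP 634, added in Python 3.10) combines destructuring with guard conditions. The example below uses mapping patterns, class patterns such as <code>int(user_id)</code>, and an <code>if</code> guard:</p>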
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> process_data(data):</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a>    <span class="cf">match</span> data:</span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a>        <span class="cf">case</span> {<span class="st">"type"</span>: <span class="st">"user"</span>, <span class="st">"id"</span>: <span class="bu">int</span>(user_id), <span class="st">"active"</span>: <span class="va">True</span>}:</span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> <span class="ss">f"Active user: </span><span class="sc">{</span>user_id<span class="sc">}</span><span class="ss">"</span></span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a>        <span class="cf">case</span> {<span class="st">"type"</span>: <span class="st">"user"</span>, <span class="st">"id"</span>: <span class="bu">int</span>(user_id), <span class="st">"active"</span>: <span class="va">False</span>}:</span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> <span class="ss">f"Inactive user: </span><span class="sc">{</span>user_id<span class="sc">}</span><span class="ss">"</span></span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a>        <span class="cf">case</span> {<span class="st">"type"</span>: <span class="st">"admin"</span>, <span class="st">"permissions"</span>: <span class="bu">list</span>(perms)} <span class="cf">if</span> <span class="bu">len</span>(perms) <span class="op">&gt;</span> <span class="dv">0</span>:</span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> <span class="ss">f"Admin with </span><span class="sc">{</span><span class="bu">len</span>(perms)<span class="sc">}</span><span class="ss"> permissions"</span></span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a>        <span class="cf">case</span> _:</span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> <span class="st">"Unknown data format"</span></span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a><span class="co"># Matches the first case and returns "Active user: 123"</span></span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a>result <span class="op">=</span> process_data({<span class="st">"type"</span>: <span class="st">"user"</span>, <span class="st">"id"</span>: <span class="dv">123</span>, <span class="st">"active"</span>: <span class="va">True</span>})</span></code></pre></div></div>
</section>
<section id="new-string-methods" class="level3">
<h3 class="anchored" data-anchor-id="new-string-methods" id="new-string-methods">New String Methods</h3>
<p>Python 3.14 does not add <code>split_keep_separator()</code>, <code>to_title_case()</code> or <code>normalize_whitespace()</code> to <code>str</code>; calling them raises <code>AttributeError</code>. The same operations take one line each with existing tools:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> re</span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a>text <span class="op">=</span> <span class="st">"Hello, World! How are you today?"</span></span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Split but keep the separators: a capturing group makes re.split() return them too</span></span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a>words <span class="op">=</span> re.split(<span class="vs">r"( )"</span>, text)</span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Result: ["Hello,", " ", "World!", " ", "How", " ", "are", " ", "you", " ", "today?"]</span></span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a><span class="co"># Title case from hyphen- or underscore-separated text</span></span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a>title_text <span class="op">=</span> <span class="st">"hello-world_example"</span></span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a>result <span class="op">=</span> re.sub(<span class="vs">r"[-_]+"</span>, <span class="st">" "</span>, title_text).title()  <span class="co"># "Hello World Example"</span></span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a><span class="co"># Collapse runs of whitespace and trim both ends</span></span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a>messy_text <span class="op">=</span> <span class="st">"  </span><span class="ch">\t\n</span><span class="st">  Hello  World  </span><span class="ch">\t\n</span><span class="st">  "</span></span>
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a>clean_text <span class="op">=</span> <span class="st">" "</span>.join(messy_text.split())  <span class="co"># "Hello World"</span></span></code></pre></div></div>
</section>
</section>
<section id="error-handling-improvements" class="level2">
<h2 class="anchored" data-anchor-id="error-handling-improvements" id="error-handling-improvements">Error Handling Improvements</h2>
<section id="enhanced-exception-groups" class="level3">
<h3 class="anchored" data-anchor-id="enhanced-exception-groups" id="enhanced-exception-groups">Enhanced Exception Groups</h3>
<p><code>ExceptionGroup</code> and <code>except*</code> (PEP 654, added in Python 3.11) handle several concurrent failures at once. <code>asyncio.TaskGroup</code> raises an <code>ExceptionGroup</code> when its tasks fail; <code>asyncio.gather(..., return_exceptions=True)</code>, by contrast, returns the exceptions as ordinary results, so <code>except*</code> clauses around it never run:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> asyncio</span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a><span class="cf">async</span> <span class="kw">def</span> fetch_data(url):</span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> <span class="st">"invalid"</span> <span class="kw">in</span> url:</span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a>        <span class="cf">raise</span> <span class="pp">ValueError</span>(<span class="ss">f"Invalid URL: </span><span class="sc">{</span>url<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> <span class="st">"timeout"</span> <span class="kw">in</span> url:</span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a>        <span class="cf">raise</span> <span class="pp">TimeoutError</span>(<span class="ss">f"Timeout for: </span><span class="sc">{</span>url<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> <span class="ss">f"Data from </span><span class="sc">{</span>url<span class="sc">}</span><span class="ss">"</span></span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a><span class="cf">async</span> <span class="kw">def</span> fetch_multiple():</span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a>    urls <span class="op">=</span> [<span class="st">"http://valid.com"</span>, <span class="st">"http://invalid.com"</span>, <span class="st">"http://timeout.com"</span>]</span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a>    <span class="cf">try</span>:</span>
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a>        <span class="co"># TaskGroup (3.11+) collects task errors into an ExceptionGroup;</span></span>
<span id="cb6-15"><a href="#cb6-15" aria-hidden="true" tabindex="-1"></a>        <span class="co"># after the first failure it cancels any tasks still running</span></span>
<span id="cb6-16"><a href="#cb6-16" aria-hidden="true" tabindex="-1"></a>        <span class="cf">async</span> <span class="cf">with</span> asyncio.TaskGroup() <span class="im">as</span> tg:</span>
<span id="cb6-17"><a href="#cb6-17" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> url <span class="kw">in</span> urls:</span>
<span id="cb6-18"><a href="#cb6-18" aria-hidden="true" tabindex="-1"></a>                tg.create_task(fetch_data(url))</span>
<span id="cb6-19"><a href="#cb6-19" aria-hidden="true" tabindex="-1"></a>    <span class="cf">except</span><span class="op">*</span> <span class="pp">ValueError</span> <span class="im">as</span> eg:</span>
<span id="cb6-20"><a href="#cb6-20" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Value errors: </span><span class="sc">{</span><span class="bu">len</span>(eg.exceptions)<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb6-21"><a href="#cb6-21" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> exc <span class="kw">in</span> eg.exceptions:</span>
<span id="cb6-22"><a href="#cb6-22" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f"  - </span><span class="sc">{</span>exc<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb6-23"><a href="#cb6-23" aria-hidden="true" tabindex="-1"></a>    <span class="cf">except</span><span class="op">*</span> <span class="pp">TimeoutError</span> <span class="im">as</span> eg:</span>
<span id="cb6-24"><a href="#cb6-24" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f"Timeout errors: </span><span class="sc">{</span><span class="bu">len</span>(eg.exceptions)<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb6-25"><a href="#cb6-25" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> exc <span class="kw">in</span> eg.exceptions:</span>
<span id="cb6-26"><a href="#cb6-26" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f"  - </span><span class="sc">{</span>exc<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb6-27"><a href="#cb6-27" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-28"><a href="#cb6-28" aria-hidden="true" tabindex="-1"></a>asyncio.run(fetch_multiple())</span></code></pre></div></div>
</section>
<section id="improved-traceback-information" class="level3">
<h3 class="anchored" data-anchor-id="improved-traceback-information" id="improved-traceback-information">Improved Traceback Information</h3>
<p>Since Python 3.11 (PEP 657), tracebacks mark the exact subexpression that failed with <code>~</code> and <code>^</code> characters, which shows which step of a chained lookup broke. The exception message itself is unchanged:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> process_nested_data(data):</span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> data[<span class="st">"users"</span>][<span class="dv">0</span>][<span class="st">"profile"</span>][<span class="st">"email"</span>].upper()</span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Left uncaught, the traceback would underline data["users"][0], the failing subscript</span></span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a><span class="cf">try</span>:</span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a>    result <span class="op">=</span> process_nested_data({<span class="st">"users"</span>: []})</span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a><span class="cf">except</span> (<span class="pp">KeyError</span>, <span class="pp">IndexError</span>) <span class="im">as</span> e:</span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a>    <span class="co"># str(e) is just "list index out of range"; the location is in the traceback</span></span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Access error: </span><span class="sc">{</span>e<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
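<p>The same position data is available programmatically through the <code>traceback</code> module, which helps when errors are logged rather than printed. A sketch assuming Python 3.11+, where <code>FrameSummary</code> carries <code>colno</code> and <code>end_colno</code>; the <code>getattr</code> fallbacks keep it working (with <code>None</code> columns) where that data is missing, for example under <code>-X no_debug_ranges</code>. The helper <code>failure_location()</code> is hypothetical.</p>

```python
import traceback


def process_nested_data(data):
    return data["users"][0]["profile"]["email"].upper()


def failure_location(exc: BaseException):
    """Return (function, line, start column, end column) for the innermost frame."""
    frame = traceback.extract_tb(exc.__traceback__)[-1]
    return (frame.name, frame.lineno,
            getattr(frame, "colno", None), getattr(frame, "end_colno", None))


try:
    process_nested_data({"users": []})
except IndexError as exc:
    message = str(exc)
    location = failure_location(exc)
    print(repr(message))  # the message names no path
    print(location)       # the columns span the failing subexpression
```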
</section>
</section>
<section id="standard-library-updates" class="level2">
<h2 class="anchored" data-anchor-id="standard-library-updates" id="standard-library-updates">Standard Library Updates</h2>
<section id="enhanced-pathlib" class="level3">
<h3 class="anchored" data-anchor-id="enhanced-pathlib" id="enhanced-pathlib">Enhanced <code>pathlib</code></h3>
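<p>Recent releases have extended <code>pathlib</code>. <code>glob()</code> and <code>rglob()</code> gained a <code>recurse_symlinks</code> keyword in Python 3.13, and Python 3.14 adds <code>copy()</code>, <code>copy_into()</code>, <code>move()</code> and <code>move_into()</code>. There is no built-in empty-directory check or atomic-write method, so the example uses standard idioms for those:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb8"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> pathlib <span class="im">import</span> Path</span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a>path <span class="op">=</span> Path(<span class="st">"./my_project"</span>)</span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a>path.mkdir(exist_ok<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a><span class="co"># No dedicated "is empty" method: iterdir() is lazy, so any() stops at the first entry</span></span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="kw">not</span> <span class="bu">any</span>(path.iterdir()):</span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="st">"Directory is empty"</span>)</span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Recursive glob; the recurse_symlinks keyword requires Python 3.13+</span></span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a>python_files <span class="op">=</span> <span class="bu">list</span>(path.rglob(<span class="st">"*.py"</span>, recurse_symlinks<span class="op">=</span><span class="va">True</span>))</span>
<span id="cb8-12"><a href="#cb8-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-13"><a href="#cb8-13" aria-hidden="true" tabindex="-1"></a><span class="co"># pathlib has no atomic-write method: write_text() truncates the file and rewrites it in place</span></span>
<span id="cb8-14"><a href="#cb8-14" aria-hidden="true" tabindex="-1"></a>config_file <span class="op">=</span> path <span class="op">/</span> <span class="st">"config.json"</span></span>
<span id="cb8-15"><a href="#cb8-15" aria-hidden="true" tabindex="-1"></a>config_file.write_text(<span class="st">'{"version": "1.0"}'</span>)</span>
<span id="cb8-16"><a href="#cb8-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-17"><a href="#cb8-17" aria-hidden="true" tabindex="-1"></a><span class="co"># New in Python 3.14: Path.copy() returns a Path pointing at the copy</span></span>
<span id="cb8-18"><a href="#cb8-18" aria-hidden="true" tabindex="-1"></a>backup <span class="op">=</span> config_file.copy(path <span class="op">/</span> <span class="st">"config.backup.json"</span>)</span></code></pre></div></div>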
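<p>Because <code>pathlib</code> does not provide an atomic-write method, a common pattern is to write to a temporary file in the same directory and then swap it into place with <code>os.replace()</code>, which replaces the target atomically when both paths are on the same filesystem. A minimal sketch; <code>atomic_write_text</code> is a hypothetical helper name, not a standard-library function:</p>

```python
import os
import tempfile
from pathlib import Path


def atomic_write_text(target, text, encoding="utf-8"):
    """Write text so readers see either the old file or the new one, never a partial write."""
    target = Path(target)
    # The temp file must live in the same directory (same filesystem) for os.replace() to be atomic
    fd, tmp_name = tempfile.mkstemp(dir=target.parent, prefix=target.name, suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding=encoding) as tmp:
            tmp.write(text)
            tmp.flush()
            os.fsync(tmp.fileno())  # push the bytes to disk before the rename
        os.replace(tmp_name, target)
    except BaseException:
        # Remove the temp file if writing or renaming failed, then re-raise
        try:
            os.unlink(tmp_name)
        except FileNotFoundError:
            pass
        raise


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        config = Path(d) / "config.json"
        atomic_write_text(config, '{"version": "1.0"}')
        print(config.read_text(encoding="utf-8"))
```

<p>The <code>fsync()</code> call trades a little speed for durability; drop it if surviving a power loss is not a concern.</p>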
</section>
<section id="improved-asyncio" class="level3">
<h3 class="anchored" data-anchor-id="improved-asyncio" id="improved-asyncio">Improved <code>asyncio</code></h3>
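<p>Two structured-concurrency tools are worth adopting if you have not already. Note that both <code>asyncio.TaskGroup</code> and <code>asyncio.timeout()</code> were added in Python 3.11, so they are not new in 3.14:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb9"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> asyncio</span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Hypothetical stand-ins for real I/O</span></span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a><span class="cf">async</span> <span class="kw">def</span> fetch_data(url):</span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a>    <span class="cf">await</span> asyncio.sleep(<span class="fl">0.1</span>)</span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> <span class="ss">f"data from </span><span class="sc">{</span>url<span class="sc">}</span><span class="ss">"</span></span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a><span class="cf">async</span> <span class="kw">def</span> slow_operation():</span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a>    <span class="cf">await</span> asyncio.sleep(<span class="dv">10</span>)</span>
<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> <span class="st">"done"</span></span>
<span id="cb9-11"><a href="#cb9-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-12"><a href="#cb9-12" aria-hidden="true" tabindex="-1"></a><span class="co"># TaskGroup (Python 3.11+): structured concurrency with automatic cleanup</span></span>
<span id="cb9-13"><a href="#cb9-13" aria-hidden="true" tabindex="-1"></a><span class="cf">async</span> <span class="kw">def</span> main():</span>
<span id="cb9-14"><a href="#cb9-14" aria-hidden="true" tabindex="-1"></a>    <span class="cf">async</span> <span class="cf">with</span> asyncio.TaskGroup() <span class="im">as</span> tg:</span>
<span id="cb9-15"><a href="#cb9-15" aria-hidden="true" tabindex="-1"></a>        task1 <span class="op">=</span> tg.create_task(fetch_data(<span class="st">"url1"</span>))</span>
<span id="cb9-16"><a href="#cb9-16" aria-hidden="true" tabindex="-1"></a>        task2 <span class="op">=</span> tg.create_task(fetch_data(<span class="st">"url2"</span>))</span>
<span id="cb9-17"><a href="#cb9-17" aria-hidden="true" tabindex="-1"></a>        task3 <span class="op">=</span> tg.create_task(fetch_data(<span class="st">"url3"</span>))</span>
<span id="cb9-18"><a href="#cb9-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-19"><a href="#cb9-19" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Reaching this line means every task succeeded; if any task fails,</span></span>
<span id="cb9-20"><a href="#cb9-20" aria-hidden="true" tabindex="-1"></a>    <span class="co"># the others are cancelled and an ExceptionGroup is raised instead</span></span>
<span id="cb9-21"><a href="#cb9-21" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="st">"All tasks completed successfully"</span>)</span>
<span id="cb9-22"><a href="#cb9-22" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> [task1.result(), task2.result(), task3.result()]</span>
<span id="cb9-23"><a href="#cb9-23" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-24"><a href="#cb9-24" aria-hidden="true" tabindex="-1"></a><span class="co"># asyncio.timeout() (Python 3.11+) cancels the block once the deadline passes</span></span>
<span id="cb9-25"><a href="#cb9-25" aria-hidden="true" tabindex="-1"></a><span class="cf">async</span> <span class="kw">def</span> with_timeout():</span>
<span id="cb9-26"><a href="#cb9-26" aria-hidden="true" tabindex="-1"></a>    <span class="cf">try</span>:</span>
<span id="cb9-27"><a href="#cb9-27" aria-hidden="true" tabindex="-1"></a>        <span class="cf">async</span> <span class="cf">with</span> asyncio.timeout(<span class="fl">5.0</span>):</span>
<span id="cb9-28"><a href="#cb9-28" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> <span class="cf">await</span> slow_operation()</span>
<span id="cb9-29"><a href="#cb9-29" aria-hidden="true" tabindex="-1"></a>    <span class="cf">except</span> <span class="pp">TimeoutError</span>:  <span class="co"># asyncio.TimeoutError has been an alias of TimeoutError since 3.11</span></span>
<span id="cb9-30"><a href="#cb9-30" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="st">"Operation timed out"</span>)</span>
<span id="cb9-31"><a href="#cb9-31" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">None</span></span>
<span id="cb9-32"><a href="#cb9-32" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-33"><a href="#cb9-33" aria-hidden="true" tabindex="-1"></a>asyncio.run(main())</span></code></pre></div></div>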
</section>
<section id="new-itertools-functions" class="level3">
<h3 class="anchored" data-anchor-id="new-itertools-functions" id="new-itertools-functions">New <code>itertools</code> Functions</h3>
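<p>Iterator helpers worth knowing, though neither is new in 3.14: <code>pairwise()</code> arrived in Python 3.10 and <code>batched()</code> in Python 3.12:</p>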
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb10"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> itertools</span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a><span class="co"># batched() (Python 3.12+): tuples of length n; the last one may be shorter</span></span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a>data <span class="op">=</span> <span class="bu">range</span>(<span class="dv">15</span>)</span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a>batches <span class="op">=</span> <span class="bu">list</span>(itertools.batched(data, <span class="dv">4</span>))</span>
<span id="cb10-6"><a href="#cb10-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Result: [(0, 1, 2, 3), (4, 5, 6, 7), (8, 9, 10, 11), (12, 13, 14)]</span></span>
<span id="cb10-7"><a href="#cb10-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-8"><a href="#cb10-8" aria-hidden="true" tabindex="-1"></a><span class="co"># pairwise() (Python 3.10+): overlapping pairs of consecutive items</span></span>
<span id="cb10-9"><a href="#cb10-9" aria-hidden="true" tabindex="-1"></a>points <span class="op">=</span> [(<span class="dv">0</span>, <span class="dv">0</span>), (<span class="dv">1</span>, <span class="dv">1</span>), (<span class="dv">2</span>, <span class="dv">4</span>), (<span class="dv">3</span>, <span class="dv">9</span>)]</span>
<span id="cb10-10"><a href="#cb10-10" aria-hidden="true" tabindex="-1"></a>segments <span class="op">=</span> <span class="bu">list</span>(itertools.pairwise(points))</span>
<span id="cb10-11"><a href="#cb10-11" aria-hidden="true" tabindex="-1"></a><span class="co"># Result: [((0, 0), (1, 1)), ((1, 1), (2, 4)), ((2, 4), (3, 9))]</span></span></code></pre></div></div>
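<p>If a project must still run on Python versions older than 3.12, a pure-Python equivalent of <code>batched()</code> is short. This sketch follows the “roughly equivalent” recipe in the <code>itertools</code> documentation, without its <code>strict</code> option; <code>batched_compat</code> is a hypothetical name:</p>

```python
from itertools import islice


def batched_compat(iterable, n):
    """Yield successive tuples of length n from iterable; the last tuple may be shorter."""
    if n < 1:
        raise ValueError("n must be at least one")
    it = iter(iterable)
    # islice pulls at most n items; an empty tuple means the iterator is exhausted
    while batch := tuple(islice(it, n)):
        yield batch


print(list(batched_compat(range(15), 4)))
```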
</section>
</section>
<section id="development-tools" class="level2">
<h2 class="anchored" data-anchor-id="development-tools" id="development-tools">Development Tools</h2>
<section id="better-repl-experience" class="level3">
<h3 class="anchored" data-anchor-id="better-repl-experience" id="better-repl-experience">Better REPL Experience</h3>
<p>Python 3.13 replaced the default interactive shell with a new REPL, and Python 3.14 builds on it:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb11"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Python 3.13: multiline editing, with history that recalls whole blocks</span></span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a><span class="co"># Python 3.13: colored prompts and tracebacks</span></span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Python 3.13: help, exit and quit work without parentheses</span></span>
<span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Python 3.13: F1 opens help, F2 browses history, F3 toggles paste mode</span></span>
<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Python 3.14: syntax highlighting and autocompletion of module names in imports</span></span>
<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Set PYTHON_BASIC_REPL=1 to fall back to the old REPL.</span></span>
<span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a><span class="co"># Commands such as %time, %edit and %history are IPython magics,</span></span>
<span id="cb11-9"><a href="#cb11-9" aria-hidden="true" tabindex="-1"></a><span class="co"># not part of the standard REPL.</span></span></code></pre></div></div>
</section>
<section id="debugging-improvements" class="level3">
<h3 class="anchored" data-anchor-id="debugging-improvements" id="debugging-improvements">Debugging Improvements</h3>
<p><code>breakpoint()</code> (Python 3.7+) drops into <code>pdb</code>; Python 3.14 adds the ability to attach <code>pdb</code> to an already-running process:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb12"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><a href="#cb12-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> complex_function(data):</span>
<span id="cb12-2"><a href="#cb12-2" aria-hidden="true" tabindex="-1"></a>    <span class="co"># breakpoint() takes no condition argument; wrap it in an if when you need one</span></span>
<span id="cb12-3"><a href="#cb12-3" aria-hidden="true" tabindex="-1"></a>    <span class="bu">breakpoint</span>()</span>
<span id="cb12-4"><a href="#cb12-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-5"><a href="#cb12-5" aria-hidden="true" tabindex="-1"></a>    processed <span class="op">=</span> []</span>
<span id="cb12-6"><a href="#cb12-6" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> item <span class="kw">in</span> data:</span>
<span id="cb12-7"><a href="#cb12-7" aria-hidden="true" tabindex="-1"></a>        result <span class="op">=</span> item <span class="op">*</span> <span class="dv">2</span></span>
<span id="cb12-8"><a href="#cb12-8" aria-hidden="true" tabindex="-1"></a>        processed.append(result)</span>
<span id="cb12-9"><a href="#cb12-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-10"><a href="#cb12-10" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> processed</span>
<span id="cb12-11"><a href="#cb12-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-12"><a href="#cb12-12" aria-hidden="true" tabindex="-1"></a><span class="co"># Conditional breakpoints can be set from the pdb prompt:</span></span>
<span id="cb12-13"><a href="#cb12-13" aria-hidden="true" tabindex="-1"></a><span class="co">#   (Pdb) break &lt;lineno&gt;, item &lt; 0</span></span>
<span id="cb12-14"><a href="#cb12-14" aria-hidden="true" tabindex="-1"></a><span class="co"># New in Python 3.14: attach pdb to a running process by its PID</span></span>
<span id="cb12-15"><a href="#cb12-15" aria-hidden="true" tabindex="-1"></a><span class="co">#   python -m pdb -p 1234</span></span>
<span id="cb12-16"><a href="#cb12-16" aria-hidden="true" tabindex="-1"></a><span class="co"># Setting PYTHONBREAKPOINT=0 disables every breakpoint() call</span></span></code></pre></div></div>
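<p><code>breakpoint()</code> simply calls <code>sys.breakpointhook()</code>, so a conditional breakpoint is just an <code>if</code> guard around it, and tests can swap the hook out to check where execution would stop. A minimal sketch; <code>double_all</code>, <code>hits</code> and the recording hook are illustrative names, not part of any API:</p>

```python
import sys


def double_all(data):
    """Double each item, pausing in the debugger only when a negative value shows up."""
    processed = []
    for item in data:
        if item < 0:        # the "condition" of the conditional breakpoint
            breakpoint()    # delegates to sys.breakpointhook()
        processed.append(item * 2)
    return processed


# Swap in a hook that records hits instead of opening pdb (handy in tests or CI)
hits = []
original_hook = sys.breakpointhook
sys.breakpointhook = lambda *args, **kwargs: hits.append(args)
try:
    result = double_all([1, -2, 3, -4])
finally:
    sys.breakpointhook = original_hook  # always restore the real hook

print(result, len(hits))
```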
</section>
</section>
<section id="migration-considerations" class="level2">
<h2 class="anchored" data-anchor-id="migration-considerations" id="migration-considerations">Migration Considerations</h2>
<section id="deprecated-features" class="level3">
<h3 class="anchored" data-anchor-id="deprecated-features" id="deprecated-features">Deprecated Features</h3>
<p>Removals and deprecations to review before upgrading. Not all of them are specific to 3.14, so check the official “What’s New in Python 3.14” document for the complete list:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb13"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a>name <span class="op">=</span> <span class="st">"World"</span></span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Not deprecated: %-formatting still works and is not scheduled for removal,</span></span>
<span id="cb13-4"><a href="#cb13-4" aria-hidden="true" tabindex="-1"></a><span class="co"># but f-strings are usually clearer</span></span>
<span id="cb13-5"><a href="#cb13-5" aria-hidden="true" tabindex="-1"></a>old_way <span class="op">=</span> <span class="st">"Hello %s"</span> <span class="op">%</span> name</span>
<span id="cb13-6"><a href="#cb13-6" aria-hidden="true" tabindex="-1"></a>new_way <span class="op">=</span> <span class="ss">f"Hello </span><span class="sc">{</span>name<span class="sc">}</span><span class="ss">"</span></span>
<span id="cb13-7"><a href="#cb13-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-8"><a href="#cb13-8" aria-hidden="true" tabindex="-1"></a><span class="co"># asyncio: the child watcher classes are removed in 3.14, and get_event_loop()</span></span>
<span id="cb13-9"><a href="#cb13-9" aria-hidden="true" tabindex="-1"></a><span class="co"># now raises RuntimeError when there is no current event loop; prefer asyncio.run()</span></span>
<span id="cb13-10"><a href="#cb13-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-11"><a href="#cb13-11" aria-hidden="true" tabindex="-1"></a><span class="co"># distutils was already removed in Python 3.12 (PEP 632);</span></span>
<span id="cb13-12"><a href="#cb13-12" aria-hidden="true" tabindex="-1"></a><span class="co"># use setuptools or another build backend instead</span></span></code></pre></div></div>
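<p>Deprecation warnings are hidden by default outside <code>__main__</code>, so removals can arrive unnoticed. Running a test suite with <code>python -W error::DeprecationWarning</code> turns them into errors; the sketch below shows an equivalent check from code using the <code>warnings</code> module. <code>legacy_api</code> and <code>find_deprecations</code> are hypothetical names:</p>

```python
import warnings


def legacy_api():
    # Hypothetical function standing in for any deprecated API
    warnings.warn("legacy_api() is deprecated; use new_api()", DeprecationWarning, stacklevel=2)
    return "ok"


def find_deprecations(func):
    """Call func and return the messages of any DeprecationWarnings it emits."""
    with warnings.catch_warnings(record=True) as caught:
        # "always" makes sure repeated warnings are not suppressed by the default filters
        warnings.simplefilter("always", DeprecationWarning)
        func()
    return [str(w.message) for w in caught if issubclass(w.category, DeprecationWarning)]


print(find_deprecations(legacy_api))
```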
</section>
<section id="compatibility-notes" class="level3">
<h3 class="anchored" data-anchor-id="compatibility-notes" id="compatibility-notes">Compatibility Notes</h3>
<p>Changes in Python 3.14 that can affect existing code:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb14"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb14-1"><a href="#cb14-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Annotations are now evaluated lazily (PEP 649 / PEP 749). Code that reads</span></span>
<span id="cb14-2"><a href="#cb14-2" aria-hidden="true" tabindex="-1"></a><span class="co"># __annotations__ directly should prefer annotationlib.get_annotations()</span></span>
<span id="cb14-3"><a href="#cb14-3" aria-hidden="true" tabindex="-1"></a><span class="co"># (new in 3.14) or inspect.get_annotations().</span></span>
<span id="cb14-4"><a href="#cb14-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-5"><a href="#cb14-5" aria-hidden="true" tabindex="-1"></a><span class="co"># multiprocessing and ProcessPoolExecutor now default to the "forkserver"</span></span>
<span id="cb14-6"><a href="#cb14-6" aria-hidden="true" tabindex="-1"></a><span class="co"># start method on Unix platforms other than macOS; request "fork" explicitly</span></span>
<span id="cb14-7"><a href="#cb14-7" aria-hidden="true" tabindex="-1"></a><span class="co"># if your code depends on it:</span></span>
<span id="cb14-8"><a href="#cb14-8" aria-hidden="true" tabindex="-1"></a><span class="co">#   ctx = multiprocessing.get_context("fork")</span></span>
<span id="cb14-9"><a href="#cb14-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-10"><a href="#cb14-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Run your test suite with -W error::DeprecationWarning to surface</span></span>
<span id="cb14-11"><a href="#cb14-11" aria-hidden="true" tabindex="-1"></a><span class="co"># APIs that are scheduled for removal</span></span></code></pre></div></div>
</section>
</section>
<section id="performance-benchmarks" class="level2">
<h2 class="anchored" data-anchor-id="performance-benchmarks" id="performance-benchmarks">Performance Benchmarks</h2>
<p>Indicative improvements (these figures are approximate and depend heavily on the workload, so measure your own code before relying on them):</p>
<ul>
<li><strong>General Python code</strong>: 10-15% faster execution</li>
<li><strong>Multi-threaded applications</strong>: Up to 40% improvement with free-threading</li>
<li><strong>String operations</strong>: 20-25% faster for common operations</li>
<li><strong>Import time</strong>: 15-20% faster module loading</li>
<li><strong>Memory usage</strong>: 5-10% reduction in typical applications</li>
</ul>
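<p>A simple way to check such figures on your own code is to time a representative snippet with the standard-library <code>timeit</code> module and run the same script under each interpreter (for example <code>python3.13 bench.py</code> and then <code>python3.14 bench.py</code>). A minimal sketch; <code>benchmark</code> is a hypothetical helper that reports the best of several repeats, which reduces noise from other processes:</p>

```python
import sys
import timeit


def benchmark(stmt, setup="pass", repeat=5, number=10_000):
    """Return the best per-loop time in seconds for stmt, taking the minimum over several repeats."""
    timings = timeit.repeat(stmt, setup=setup, repeat=repeat, number=number)
    return min(timings) / number


if __name__ == "__main__":
    t = benchmark("'-'.join(str(i) for i in range(100))")
    print(f"Python {sys.version.split()[0]}: {t * 1e6:.2f} µs per loop")
```

<p>Micro-benchmarks like this only cover the snippet being timed; for whole applications, profile real workloads.</p>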
</section>
<section id="getting-started" class="level2">
<h2 class="anchored" data-anchor-id="getting-started" id="getting-started">Getting Started</h2>
<section id="installation" class="level3">
<h3 class="anchored" data-anchor-id="installation" id="installation">Installation</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb15"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb15-1"><a href="#cb15-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Install Python 3.14 from python.org or your OS package manager, then:</span></span>
<span id="cb15-2"><a href="#cb15-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-3"><a href="#cb15-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Check version</span></span>
<span id="cb15-4"><a href="#cb15-4" aria-hidden="true" tabindex="-1"></a><span class="ex">python3.14</span> <span class="at">--version</span></span>
<span id="cb15-5"><a href="#cb15-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-6"><a href="#cb15-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Upgrade pip for this interpreter</span></span>
<span id="cb15-7"><a href="#cb15-7" aria-hidden="true" tabindex="-1"></a><span class="ex">python3.14</span> <span class="at">-m</span> pip install <span class="at">--upgrade</span> pip</span>
<span id="cb15-8"><a href="#cb15-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-9"><a href="#cb15-9" aria-hidden="true" tabindex="-1"></a><span class="co"># Free-threading needs the separate free-threaded build (python3.14t);</span></span>
<span id="cb15-10"><a href="#cb15-10" aria-hidden="true" tabindex="-1"></a><span class="co"># on that build, -X gil=1 or PYTHON_GIL=1 re-enables the GIL</span></span>
<span id="cb15-11"><a href="#cb15-11" aria-hidden="true" tabindex="-1"></a><span class="ex">python3.14t</span> script.py</span></code></pre></div></div>
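<p>To confirm at runtime which kind of interpreter a script is on, the standard library exposes two checks: the <code>Py_GIL_DISABLED</code> build variable and, on 3.13+, the underscore-prefixed <code>sys._is_gil_enabled()</code>. A minimal sketch; <code>gil_status</code> is a hypothetical helper, and the fallback assumes older versions always run with the GIL:</p>

```python
import sys
import sysconfig


def gil_status():
    """Report whether this is a free-threaded build and whether the GIL is currently active."""
    # Py_GIL_DISABLED is 1 on free-threaded builds, 0 or None otherwise
    free_threaded_build = bool(sysconfig.get_config_var("Py_GIL_DISABLED"))
    # sys._is_gil_enabled() exists on 3.13+; earlier versions always have the GIL
    is_enabled = getattr(sys, "_is_gil_enabled", lambda: True)
    return {"free_threaded_build": free_threaded_build, "gil_enabled": bool(is_enabled())}


print(gil_status())
```

<p>On a free-threaded build the GIL can still be switched back on (for example by an extension module that does not declare free-threading support), which is why the two values are reported separately.</p>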
</section>
<section id="migration-checklist" class="level3">
<h3 class="anchored" data-anchor-id="migration-checklist" id="migration-checklist">Migration Checklist</h3>
<ol type="1">
<li><strong>Test your existing code</strong> with Python 3.14</li>
<li><strong>Update type annotations</strong> to use new features</li>
<li><strong>Review deprecated warnings</strong> in your codebase</li>
<li><strong>Consider the free-threaded build</strong> for multi-threaded, CPU-bound applications, after confirming that your C extensions support it</li>
<li><strong>Update development tools</strong> and IDE configurations</li>
<li><strong>Benchmark performance</strong> improvements in your applications</li>
</ol>
</section>
</section>
<section id="best-practices" class="level2">
<h2 class="anchored" data-anchor-id="best-practices" id="best-practices">Best Practices</h2>
<section id="taking-advantage-of-new-features" class="level3">
<h3 class="anchored" data-anchor-id="taking-advantage-of-new-features" id="taking-advantage-of-new-features">Taking Advantage of New Features</h3>
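<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb16"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb16-1"><a href="#cb16-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> asyncio</span>
<span id="cb16-2"><a href="#cb16-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-3"><a href="#cb16-3" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> APIError(<span class="pp">Exception</span>):  <span class="co"># hypothetical application-specific error</span></span>
<span id="cb16-4"><a href="#cb16-4" aria-hidden="true" tabindex="-1"></a>    <span class="cf">pass</span></span>
<span id="cb16-5"><a href="#cb16-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-6"><a href="#cb16-6" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> process_item(item):  <span class="co"># hypothetical per-item handler</span></span>
<span id="cb16-7"><a href="#cb16-7" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> item</span>
<span id="cb16-8"><a href="#cb16-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-9"><a href="#cb16-9" aria-hidden="true" tabindex="-1"></a><span class="cf">async</span> <span class="kw">def</span> operation(i):  <span class="co"># hypothetical coroutine</span></span>
<span id="cb16-10"><a href="#cb16-10" aria-hidden="true" tabindex="-1"></a>    <span class="cf">await</span> asyncio.sleep(<span class="dv">0</span>)</span>
<span id="cb16-11"><a href="#cb16-11" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> i</span>
<span id="cb16-12"><a href="#cb16-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-13"><a href="#cb16-13" aria-hidden="true" tabindex="-1"></a><span class="co"># Structural pattern matching (Python 3.10+) for complex data processing</span></span>
<span id="cb16-14"><a href="#cb16-14" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> process_api_response(response):</span>
<span id="cb16-15"><a href="#cb16-15" aria-hidden="true" tabindex="-1"></a>    <span class="cf">match</span> response:</span>
<span id="cb16-16"><a href="#cb16-16" aria-hidden="true" tabindex="-1"></a>        <span class="cf">case</span> {<span class="st">"status"</span>: <span class="st">"success"</span>, <span class="st">"data"</span>: <span class="bu">list</span>(items)} <span class="cf">if</span> <span class="bu">len</span>(items) <span class="op">&gt;</span> <span class="dv">0</span>:</span>
<span id="cb16-17"><a href="#cb16-17" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> [process_item(item) <span class="cf">for</span> item <span class="kw">in</span> items]</span>
<span id="cb16-18"><a href="#cb16-18" aria-hidden="true" tabindex="-1"></a>        <span class="cf">case</span> {<span class="st">"status"</span>: <span class="st">"error"</span>, <span class="st">"message"</span>: <span class="bu">str</span>(msg)}:</span>
<span id="cb16-19"><a href="#cb16-19" aria-hidden="true" tabindex="-1"></a>            <span class="cf">raise</span> APIError(msg)</span>
<span id="cb16-20"><a href="#cb16-20" aria-hidden="true" tabindex="-1"></a>        <span class="cf">case</span> _:</span>
<span id="cb16-21"><a href="#cb16-21" aria-hidden="true" tabindex="-1"></a>            <span class="cf">raise</span> <span class="pp">ValueError</span>(<span class="st">"Unexpected response format"</span>)</span>
<span id="cb16-22"><a href="#cb16-22" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-23"><a href="#cb16-23" aria-hidden="true" tabindex="-1"></a><span class="co"># TaskGroup (Python 3.11+): every result is ready once the block exits</span></span>
<span id="cb16-24"><a href="#cb16-24" aria-hidden="true" tabindex="-1"></a><span class="cf">async</span> <span class="kw">def</span> robust_async_operation():</span>
<span id="cb16-25"><a href="#cb16-25" aria-hidden="true" tabindex="-1"></a>    <span class="cf">async</span> <span class="cf">with</span> asyncio.TaskGroup() <span class="im">as</span> tg:</span>
<span id="cb16-26"><a href="#cb16-26" aria-hidden="true" tabindex="-1"></a>        tasks <span class="op">=</span> [tg.create_task(operation(i)) <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">5</span>)]</span>
<span id="cb16-27"><a href="#cb16-27" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> [task.result() <span class="cf">for</span> task <span class="kw">in</span> tasks]</span>
<span id="cb16-28"><a href="#cb16-28" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-29"><a href="#cb16-29" aria-hidden="true" tabindex="-1"></a><span class="co"># Collapse runs of whitespace (str has no normalize_whitespace() method)</span></span>
<span id="cb16-30"><a href="#cb16-30" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> clean_user_input(text):</span>
<span id="cb16-31"><a href="#cb16-31" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> <span class="st">" "</span>.join(text.split())</span></code></pre></div></div>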
<p>Python 3.14 represents a significant step forward in Python’s evolution, focusing on performance, developer experience, and language consistency. The improvements make Python more efficient and enjoyable to work with while maintaining the language’s commitment to readability and simplicity.</p>


</section>
</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Python 3.14: The Next Evolution in Python Development]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/python/python-pi/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/python/python-pi/</guid>
      <pubDate>Fri, 23 May 2025 00:00:00 GMT</pubDate>
      
      <category>news</category>
      <content:encoded><![CDATA[






<section id="python-3.14-the-next-evolution-in-python-development" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/python/python-pi/pythonpi.jpg" class="img-fluid"></p>
<p>Python continues its steady march forward with the anticipated release of Python 3.14, marking another significant milestone in the language’s evolution. Arriving on Python’s annual release cycle, Python 3.14 continues the push to improve performance, developer experience, and language capabilities.</p>
<section id="release-timeline-and-development" class="level2">
<h2 class="anchored" data-anchor-id="release-timeline-and-development" id="release-timeline-and-development">Release Timeline and Development</h2>
<p>Following Python’s established release schedule, Python 3.14 continues the pattern of annual feature releases (PEP 602) that began with Python 3.9. The development process follows the standard Python Enhancement Proposal (PEP) system, where community members propose, discuss, and refine new features before implementation.</p>
<p>The release represents months of collaborative work from core developers, contributors, and the broader Python community, focusing on backward compatibility while introducing meaningful improvements to the language.</p>
</section>
<section id="performance-enhancements" class="level2">
<h2 class="anchored" data-anchor-id="performance-enhancements" id="performance-enhancements">Performance Enhancements</h2>
<p>Python 3.14 builds upon the performance improvements introduced in recent versions. The development team has continued optimizing the interpreter, with particular attention to:</p>
<ul>
<li><strong>Memory Management</strong>: Further refinements to Python’s garbage collection system, reducing memory overhead and improving allocation efficiency</li>
<li><strong>Bytecode Optimization</strong>: Enhanced compilation processes that generate more efficient bytecode</li>
<li><strong>Standard Library Performance</strong>: Optimizations to frequently-used modules and functions</li>
</ul>
<p>These improvements contribute to faster execution times and reduced resource consumption, particularly beneficial for long-running applications and data-intensive workloads.</p>
</section>
<section id="language-features-and-syntax" class="level2">
<h2 class="anchored" data-anchor-id="language-features-and-syntax" id="language-features-and-syntax">Language Features and Syntax</h2>
<p>While maintaining Python’s philosophy of readability and simplicity, Python 3.14 introduces carefully considered language enhancements:</p>
<section id="type-system-improvements" class="level3">
<h3 class="anchored" data-anchor-id="type-system-improvements" id="type-system-improvements">Type System Improvements</h3>
<p>The static typing ecosystem continues to mature. The headline change is that annotations are now evaluated lazily (PEP 649), so forward references no longer need to be quoted, and the new <code>annotationlib</code> module provides tools for introspecting them. These changes make it easier for developers to write type-safe code while maintaining Python’s dynamic nature.</p>
</section>
<section id="developer-experience-enhancements" class="level3">
<h3 class="anchored" data-anchor-id="developer-experience-enhancements" id="developer-experience-enhancements">Developer Experience Enhancements</h3>
<p>Several quality-of-life improvements have been introduced to make Python development more efficient and enjoyable. Error messages have been further refined to provide clearer guidance, and debugging capabilities have been enhanced.</p>
</section>
</section>
<section id="standard-library-updates" class="level2">
<h2 class="anchored" data-anchor-id="standard-library-updates" id="standard-library-updates">Standard Library Updates</h2>
<p>Python’s “batteries included” philosophy remains strong in 3.14, with updates across the standard library:</p>
<ul>
<li><strong>New Modules</strong>: For example, <code>annotationlib</code> for introspecting annotations and <code>compression.zstd</code> for Zstandard compression</li>
<li><strong>Deprecated Module Updates</strong>: Continued modernization of older modules while maintaining backward compatibility</li>
<li><strong>Security Enhancements</strong>: Strengthened cryptographic modules and security-related functionality</li>
</ul>
</section>
<section id="breaking-changes-and-migration" class="level2">
<h2 class="anchored" data-anchor-id="breaking-changes-and-migration" id="breaking-changes-and-migration">Breaking Changes and Migration</h2>
<p>Python 3.14 maintains the project’s commitment to stability. Any breaking changes are minimal and well-documented, with clear migration paths provided. The development team continues to balance innovation with the needs of existing codebases.</p>
<p>Most Python 3.13 code should run without modification on Python 3.14, though developers are encouraged to review the official migration guide for any project-specific considerations.</p>
</section>
<section id="community-impact" class="level2">
<h2 class="anchored" data-anchor-id="community-impact" id="community-impact">Community Impact</h2>
<p>The release reflects the vibrant Python ecosystem, with contributions from developers worldwide. The Python Software Foundation’s governance model ensures that changes serve the broad community while maintaining the language’s core principles.</p>
</section>
<section id="looking-forward" class="level2">
<h2 class="anchored" data-anchor-id="looking-forward" id="looking-forward">Looking Forward</h2>
<p>Python 3.14 sets the foundation for future developments while addressing current needs. The development team continues to explore areas such as:</p>
<ul>
<li>Performance optimization strategies</li>
<li>Improved tooling integration</li>
<li>Enhanced support for modern development practices</li>
<li>Continued evolution of the type system</li>
</ul>
</section>
<section id="getting-started-with-python-3.14" class="level2">
<h2 class="anchored" data-anchor-id="getting-started-with-python-3.14" id="getting-started-with-python-3.14">Getting Started with Python 3.14</h2>
<p>Developers interested in trying Python 3.14 can download it from the official Python website. The comprehensive documentation includes migration notes, feature explanations, and examples to help developers transition smoothly.</p>
<p>For production environments, thorough testing is recommended before upgrading, though the Python team’s commitment to stability makes the transition process straightforward for most applications.</p>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>Python 3.14 represents another solid step forward for the language, balancing innovation with stability. The release demonstrates the Python community’s continued dedication to creating a language that remains accessible to beginners while powerful enough for the most demanding applications.</p>
<p>Now in its fourth decade, Python continues to evolve thoughtfully with releases like 3.14, maintaining its position as one of the world’s most popular and versatile programming languages.</p>
<hr>
<p><em>For the most current information about Python 3.14, including detailed release notes and migration guides, visit the official Python documentation at python.org.</em></p>


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[DINOv2: A Deep Dive into Architecture and Training]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/dino/dino-v2-explained/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/dino/dino-v2-explained/</guid>
      <pubDate>Sat, 17 May 2025 00:00:00 GMT</pubDate>
      
      <category>research</category>
      <category>intermediate</category>
      <content:encoded><![CDATA[






<section id="dinov2-a-deep-dive-into-architecture-and-training" class="level1">

<p>In 2023, Meta AI Research unveiled DINOv2 (Self-Distillation with No Labels v2), a breakthrough in self-supervised visual learning that produces remarkably versatile and robust visual features. This article provides a detailed exploration of DINOv2’s architecture and training methodology, explaining how its frozen features achieve state-of-the-art performance across diverse visual tasks without fine-tuning the backbone for each task.</p>
<p><img src="https://theja-vanka.github.io/blogs/posts/dino/dino-v2-explained/dinov2.jpeg" class="img-fluid"></p>
<section id="architectural-foundation-vision-transformers" class="level2">
<h2 class="anchored" data-anchor-id="architectural-foundation-vision-transformers" id="architectural-foundation-vision-transformers">Architectural Foundation: Vision Transformers</h2>
<p>At the heart of DINOv2 is the Vision Transformer (ViT) architecture, which has proven highly effective for computer vision tasks. DINOv2 specifically uses:</p>
<section id="vit-backbone-variants" class="level3">
<h3 class="anchored" data-anchor-id="vit-backbone-variants" id="vit-backbone-variants">ViT Backbone Variants</h3>
<ul>
<li><strong>ViT-S/14</strong>: Small model (22M parameters)</li>
<li><strong>ViT-B/14</strong>: Base model (87M parameters)</li>
<li><strong>ViT-L/14</strong>: Large model (304M parameters)</li>
<li><strong>ViT-g/14</strong>: Giant model (1.1B parameters)</li>
</ul>
<p>The “/14” indicates a patch size of 14×14 pixels: each image is split into non-overlapping 14×14 patches, and each patch is linearly projected into one token before the transformer processes it.</p>
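<p>To make the patch arithmetic concrete: a 224×224 image at patch size 14 gives 16×16 = 256 patch tokens, plus one [CLS] token. The sketch below counts tokens for any input size; <code>num_tokens</code> is a hypothetical helper, and <code>num_registers</code> is an optional count you supply (the default of 0 describes a model without register tokens):</p>

```python
def num_tokens(height, width, patch_size=14, num_cls=1, num_registers=0):
    """Count the tokens a ViT sees for one image: patch tokens plus [CLS] and any register tokens."""
    if height % patch_size or width % patch_size:
        raise ValueError("image height and width must be multiples of the patch size")
    patches = (height // patch_size) * (width // patch_size)
    return patches + num_cls + num_registers


print(num_tokens(224, 224))                   # 257: 16 x 16 patches + 1 [CLS]
print(num_tokens(224, 224, num_registers=4))  # 261: same image with 4 extra tokens
```

<p>Because attention cost grows roughly quadratically with the token count, doubling the image side length roughly quadruples the patches and increases attention cost about sixteenfold.</p>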
</section>
<section id="architectural-enhancements" class="level3">
<h3 class="anchored" data-anchor-id="architectural-enhancements" id="architectural-enhancements">Architectural Enhancements</h3>
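<p>DINOv2's backbones are standard pre-norm ViTs with a few refinements over the original DINO setup (one of them added in a follow-up release):</p>
<ol type="1">
<li><p><strong>LayerScale</strong>: Each residual branch is multiplied by a learnable per-channel scale that is initialized to a small value. This helps stabilize training of deep models at scale.</p></li>
<li><p><strong>SwiGLU Feed-Forward Layers</strong>: The giant model (ViT-g/14) replaces the standard GELU MLP with a SwiGLU feed-forward block. The smaller models keep the standard MLP.</p></li>
<li><p><strong>Register Tokens</strong>: These were introduced in follow-up work (“Vision Transformers Need Registers”, 2023) and released as separate “with registers” checkpoints. They are extra learnable tokens, placed alongside the [CLS] token, that absorb global computation and remove high-norm artifacts from the attention maps.</p></li>
<li><p><strong>Learned Positional Embeddings</strong>: The models use absolute, learnable positional embeddings. These are interpolated when the input resolution differs from the training resolution; the models are adapted to 518×518 images in a short final training phase.</p></li>
<li><p><strong>Pre-Normalization</strong>: As in the standard ViT, layer normalization is applied before the multi-head attention and feed-forward sublayers, and a final layer norm is applied to the output.</p></li>
</ol>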
</section>
</section>
<section id="training-methodology-self-distillation-framework" class="level2">
<h2 class="anchored" data-anchor-id="training-methodology-self-distillation-framework" id="training-methodology-self-distillation-framework">Training Methodology: Self-Distillation Framework</h2>
<p>DINOv2’s training methodology centers around self-distillation, where a model essentially teaches itself. This is implemented through a student-teacher framework:</p>
<section id="teacher-student-architecture" class="level3">
<h3 class="anchored" data-anchor-id="teacher-student-architecture" id="teacher-student-architecture">Teacher-Student Architecture</h3>
<ul>
<li><strong>Student Network</strong>: The network being trained, updated via backpropagation</li>
<li><strong>Teacher Network</strong>: An exponential moving average (EMA) of the student’s parameters</li>
<li>Both networks share the same architecture but different parameters</li>
</ul>
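<p>Because the teacher is a slowly moving average of the student, it provides stable targets that still evolve as training progresses. On its own, this does not rule out collapse, where the network outputs the same representation for all inputs. Combined with the centering and sharpening described below, it keeps training away from such trivial solutions.</p>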
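<p>The teacher update takes only a few lines. The sketch below is a minimal NumPy illustration, not the official implementation. It assumes parameters are stored in a dict of float arrays and uses 0.996 as an illustrative momentum (the base value reported for the original DINO). In practice, the momentum is scheduled toward 1 over the course of training.</p>

```python
import numpy as np


def ema_update(teacher_params, student_params, momentum=0.996):
    """In-place EMA: teacher <- momentum * teacher + (1 - momentum) * student.

    Both arguments map parameter names to float NumPy arrays of identical shape.
    The teacher receives no gradients; this update is its only source of change.
    """
    for name, student_value in student_params.items():
        teacher_value = teacher_params[name]
        teacher_value *= momentum
        teacher_value += (1.0 - momentum) * student_value
    return teacher_params


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    student = {"w": rng.normal(size=(4, 4))}
    teacher = {"w": np.zeros((4, 4))}
    for _ in range(1000):
        ema_update(teacher, student, momentum=0.99)  # student held fixed for illustration
    # With a fixed student, the teacher converges toward it geometrically.
    print(np.abs(teacher["w"] - student["w"]).max())
```

<p>In a real training loop the student changes at every step, so the teacher behaves like a temporal ensemble of recent students. A momentum close to 1 makes it change slowly.</p>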
</section>
<section id="multi-crop-strategy" class="level3">
<h3 class="anchored" data-anchor-id="multi-crop-strategy" id="multi-crop-strategy">Multi-Crop Strategy</h3>
<p>A key component of DINOv2’s training is its sophisticated multi-crop approach:</p>
<ol type="1">
<li><strong>Global Views</strong>: Two large crops covering significant portions of the image (224×224 pixels)</li>
<li><strong>Local Views</strong>: Multiple smaller crops capturing image details (96×96 pixels)</li>
</ol>
<p>The student network processes both global and local views, while the teacher network only processes global views. This forces the model to learn both global context and local details.</p>
</section>
<section id="self-supervised-objective" class="level3">
<h3 class="anchored" data-anchor-id="self-supervised-objective" id="self-supervised-objective">Self-Supervised Objective</h3>
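<p>The image-level training objective is a cross-entropy loss. It encourages the student's output distribution for each view to match the teacher's output distribution for a different, global view of the same image. Summed over all view pairs, the loss is:</p>
<p><span class="math display">\[\mathcal{L} = \sum_{x \in \{x^g_1,\, x^g_2\}} \; \sum_{\substack{x' \in V \\ x' \neq x}} H\big(P_t(x),\, P_s(x')\big), \qquad H(a, b) = -\sum_k a_k \log b_k\]</span></p>
<p>Where:</p>
<ul>
<li><span class="math inline">\(H\)</span> is the cross-entropy</li>
<li><span class="math inline">\(P_t(x)\)</span> is the teacher's output distribution for a global view <span class="math inline">\(x\)</span></li>
<li><span class="math inline">\(P_s(x')\)</span> is the student's output distribution for another view <span class="math inline">\(x'\)</span> (local or global)</li>
<li><span class="math inline">\(V\)</span> is the set of all views: the two global views plus the local views</li>
</ul>
<p>The teacher's outputs are sharpened by a softmax temperature lower than the student's, which makes its targets more peaked. In the original DINO recipe, this temperature is warmed up linearly from 0.04 to 0.07 early in training. DINOv2 combines this image-level loss with a patch-level masked-token loss adopted from iBOT.</p>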
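<p>The sketch below makes the view pairing and the role of the temperatures concrete. It is a minimal NumPy version of the DINO-style image-level loss, following the original DINO recipe: running-mean centering, with a student temperature of 0.1 and a teacher temperature of 0.04 as illustrative values. DINOv2 itself replaces this centering step with Sinkhorn-Knopp normalization and adds patch-level terms, neither of which is shown here. The function names are hypothetical, and the student's crops are assumed to be ordered with the global crops first.</p>

```python
import numpy as np


def softmax(x, axis=-1):
    z = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def log_softmax(x, axis=-1):
    z = x - x.max(axis=axis, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=axis, keepdims=True))


def dino_loss(student_out, teacher_out, center, student_temp=0.1, teacher_temp=0.04):
    """Average cross-entropy H(P_t(x), P_s(x')) over all pairs of different views.

    student_out: (n_views, batch, K) logits for every crop, global crops first.
    teacher_out: (n_global, batch, K) logits for the global crops only
                 (in real training the teacher branch receives no gradient).
    center:      (K,) running mean of teacher logits, subtracted before the softmax.
    """
    teacher_probs = softmax((teacher_out - center) / teacher_temp)  # centered, sharpened targets
    student_logp = log_softmax(student_out / student_temp)
    total, n_pairs = 0.0, 0
    for t_idx in range(teacher_probs.shape[0]):
        for s_idx in range(student_logp.shape[0]):
            if s_idx == t_idx:
                continue  # teacher and student see the same crop: skip this pair
            ce = -(teacher_probs[t_idx] * student_logp[s_idx]).sum(axis=-1)  # (batch,)
            total += ce.mean()
            n_pairs += 1
    return total / n_pairs


def update_center(center, teacher_out, momentum=0.9):
    """EMA of the batch-mean teacher logits (the centering term)."""
    batch_mean = teacher_out.reshape(-1, teacher_out.shape[-1]).mean(axis=0)
    return momentum * center + (1.0 - momentum) * batch_mean


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n_global, n_local, batch, K = 2, 6, 4, 16
    student_out = rng.normal(size=(n_global + n_local, batch, K))
    teacher_out = rng.normal(size=(n_global, batch, K))
    center = np.zeros(K)
    print("loss:", dino_loss(student_out, teacher_out, center))
    center = update_center(center, teacher_out)
```

<p>With 2 global and 6 local crops, each of the 2 teacher targets is paired with 7 student views, giving 14 terms per image.</p>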
</section>
</section>
<section id="data-curation-and-processing" class="level2">
<h2 class="anchored" data-anchor-id="data-curation-and-processing" id="data-curation-and-processing">Data Curation and Processing</h2>
<p>DINOv2’s impressive performance comes not just from architecture but from meticulous data preparation:</p>
<section id="lvd-142m-dataset" class="level3">
<h3 class="anchored" data-anchor-id="lvd-142m-dataset" id="lvd-142m-dataset">LVD-142M Dataset</h3>
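<p>The researchers curated LVD-142M, a dataset of 142 million images. Rather than relying on labels or captions, they started from a large pool of uncurated, publicly available web images. From this pool, they retrieved images that are visually similar to images in a collection of curated datasets (such as ImageNet-22k). The pipeline also removed:</p>
<ul>
<li>Near-duplicate images, including near-duplicates of images in the evaluation test sets</li>
<li>Unsafe or inappropriate material</li>
</ul>
<p>Identifiable human faces were blurred rather than having the images removed.</p>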
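<p>The deduplication step can be illustrated with a toy version: embed each image, then drop any image whose cosine similarity to an already-kept image exceeds a threshold. The paper describes an embedding-based copy-detection pipeline run at a much larger scale. The greedy O(N²) loop, the function name, and the 0.95 threshold below are illustrative assumptions.</p>

```python
import numpy as np


def deduplicate(embeddings, threshold=0.95):
    """Greedy near-duplicate removal on image embeddings.

    embeddings: (N, D) array, one descriptor per image (hypothetically from a copy-detection model).
    An image is kept only if its cosine similarity to every already-kept image is below `threshold`.
    Returns the indices of the kept images, in input order.
    """
    normed = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    kept = []
    for i, vec in enumerate(normed):
        if kept and np.max(normed[kept] @ vec) >= threshold:
            continue  # too similar to an image that is already in the dataset
        kept.append(i)
    return kept


if __name__ == "__main__":
    emb = np.array([[1.0, 0.0], [0.999, 0.01], [0.0, 1.0]])
    print(deduplicate(emb))  # [0, 2]: the second row is a near-copy of the first
```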
</section>
<section id="data-augmentation-pipeline" class="level3">
<h3 class="anchored" data-anchor-id="data-augmentation-pipeline" id="data-augmentation-pipeline">Data Augmentation Pipeline</h3>
<p>During training, DINOv2 employs a robust augmentation strategy:</p>
<ol type="1">
<li><strong>Random resized cropping</strong>: Different sized views of the same image</li>
<li><strong>Random horizontal flips</strong>: Mirroring images horizontally</li>
<li><strong>Color jittering</strong>: Altering brightness, contrast, saturation, and hue</li>
<li><strong>Gaussian blur</strong>: Adding controlled blur to some views</li>
<li><strong>Solarization</strong>: Inverting pixels above a threshold (applied selectively)</li>
</ol>
<p>These augmentations create diverse views while preserving the semantic content, forcing the model to learn invariance to these transformations.</p>
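<p>Two of these transforms are simple enough to write directly on pixel arrays. The sketch below is a minimal NumPy illustration, assuming uint8 images in (H, W, C) layout and a solarization threshold of 128. In practice these transforms come from an image library. The probabilities with which each transform is applied to each view are configuration choices not shown here.</p>

```python
import numpy as np


def solarize(image, threshold=128):
    """Invert every pixel value at or above `threshold` in a uint8 image of any shape."""
    return np.where(image >= threshold, 255 - image, image).astype(np.uint8)


def horizontal_flip(image):
    """Mirror an (H, W, C) image left to right."""
    return image[:, ::-1, :]


if __name__ == "__main__":
    pixels = np.array([0, 100, 128, 200, 255], dtype=np.uint8)
    print(solarize(pixels).tolist())  # [0, 100, 127, 55, 0]
```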
</section>
</section>
<section id="distributed-training-strategy" class="level2">
<h2 class="anchored" data-anchor-id="distributed-training-strategy" id="distributed-training-strategy">Distributed Training Strategy</h2>
<p>Training a model of DINOv2’s scale requires sophisticated distributed computing approaches:</p>
<section id="optimization-details" class="level3">
<h3 class="anchored" data-anchor-id="optimization-details" id="optimization-details">Optimization Details</h3>
<ul>
<li><strong>Optimizer</strong>: AdamW with a cosine learning rate schedule</li>
<li><strong>Gradient Accumulation</strong>: Used to handle effectively larger batch sizes</li>
<li><strong>Mixed Precision</strong>: FP16 calculations to speed up training</li>
<li><strong>Sharding</strong>: Fully-Sharded Data Parallel (FSDP) splits model parameters, gradients, and optimizer states across GPUs</li>
</ul>
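<p>The cosine learning rate schedule can be written as a function of the training step. The sketch below assumes a linear warmup followed by cosine decay. The function name and the example values are hypothetical and do not come from DINOv2's configuration. At every step, the optimizer's learning rate would be set from this function. The same helper can also drive other scheduled quantities, such as weight decay.</p>

```python
import math


def cosine_schedule(step, total_steps, base_value, final_value, warmup_steps=0, warmup_start=0.0):
    """Linear warmup from `warmup_start` to `base_value`, then cosine decay to `final_value`."""
    if step < warmup_steps:
        return warmup_start + (base_value - warmup_start) * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)  # 0 -> 1 after warmup
    return final_value + 0.5 * (base_value - final_value) * (1.0 + math.cos(math.pi * progress))


if __name__ == "__main__":
    # Hypothetical settings: 100 steps, 10 warmup steps, peak 1e-3, final 1e-6.
    for step in (0, 5, 10, 55, 100):
        print(step, cosine_schedule(step, 100, 1e-3, 1e-6, warmup_steps=10))
```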
</section>
<section id="effective-batch-size" class="level3">
<h3 class="anchored" data-anchor-id="effective-batch-size" id="effective-batch-size">Effective Batch Size</h3>
<p>DINOv2 is trained with large effective batch sizes (on the order of a few thousand images), with each batch distributed across many GPUs.</p>
</section>
</section>
<section id="regularization-techniques" class="level2">
<h2 class="anchored" data-anchor-id="regularization-techniques" id="regularization-techniques">Regularization Techniques</h2>
<p>To prevent representation collapse and ensure diverse, meaningful features, DINOv2 employs:</p>
<ol type="1">
<li><strong>Centering</strong>: Ensuring the average output across the batch remains close to zero</li>
<li><strong>Sharpening</strong>: Gradually decreasing the temperature parameter of the teacher’s softmax</li>
<li><strong>DALL-E VAE Integration</strong>: Using a pre-trained DALL-E VAE to improve representation quality</li>
<li><strong>Weight Decay</strong>: Applied differently to various components of the model</li>
</ol>
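<p>DINOv2's KoLeo regularizer is based on the Kozachenko-Leonenko differential entropy estimator and encourages the features in a batch to spread out. The minimal NumPy sketch below assumes L2-normalized features and a plain nearest-neighbour search over the batch. In DINOv2, the regularizer is reportedly applied to [CLS] features with a small loss weight. The function name here is hypothetical.</p>

```python
import numpy as np


def koleo_loss(features, eps=1e-8):
    """KoLeo regularizer: negative mean log distance to each sample's nearest neighbour.

    features: (n, d) array for one batch. Rows are L2-normalized first; minimizing the
    loss pushes nearest neighbours apart, spreading the batch over the unit sphere.
    """
    x = features / np.linalg.norm(features, axis=1, keepdims=True)
    dists = np.linalg.norm(x[:, None, :] - x[None, :, :], axis=-1)  # (n, n) pairwise distances
    np.fill_diagonal(dists, np.inf)  # a point is not its own neighbour
    nearest = dists.min(axis=1)
    return float(-np.mean(np.log(nearest + eps)))


if __name__ == "__main__":
    spread = np.eye(3)  # mutually orthogonal unit vectors
    clustered = np.array([[1.0, 0.0, 0.0], [1.0, 0.01, 0.0], [0.0, 1.0, 0.0]])
    print(koleo_loss(spread), koleo_loss(clustered))  # the clustered batch has the higher loss
```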
</section>
<section id="feature-extraction-and-deployment" class="level2">
<h2 class="anchored" data-anchor-id="feature-extraction-and-deployment" id="feature-extraction-and-deployment">Feature Extraction and Deployment</h2>
<p>After training, DINOv2 can be used in different ways:</p>
<section id="feature-extraction-methods" class="level3">
<h3 class="anchored" data-anchor-id="feature-extraction-methods" id="feature-extraction-methods">Feature Extraction Methods</h3>
<ul>
<li><strong>[CLS] Token</strong>: The class token representation serves as a global image descriptor</li>
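<li><strong>Register Tokens</strong>: Present only in the later “with registers” checkpoints. These extra tokens absorb global computation during the forward pass and are usually discarded rather than used as features.</li>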
<li><strong>Patch Tokens</strong>: Local features corresponding to specific regions of the image</li>
</ul>
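<p>In practice, extracting these features amounts to slicing the backbone's output token sequence. The sketch below assumes the token order [CLS, registers…, patches…] used by DINOv2-style models, with <code>num_registers = 0</code> for checkpoints without registers. It operates on a NumPy array standing in for the final hidden states, and the function names are hypothetical.</p>

```python
import numpy as np


def split_tokens(hidden_states, num_registers=0):
    """Split a (batch, 1 + num_registers + num_patches, dim) sequence into its parts.

    Assumes the token order [CLS, register_1..register_R, patch_1..patch_N].
    Returns (cls, registers, patches) with shapes (B, D), (B, R, D), (B, N, D).
    """
    cls = hidden_states[:, 0]
    registers = hidden_states[:, 1:1 + num_registers]
    patches = hidden_states[:, 1 + num_registers:]
    return cls, registers, patches


def patches_to_grid(patches, grid_h, grid_w):
    """Reshape (B, grid_h * grid_w, D) patch tokens into a (B, grid_h, grid_w, D) feature map."""
    batch, num_patches, dim = patches.shape
    if num_patches != grid_h * grid_w:
        raise ValueError("patch count does not match the requested grid")
    return patches.reshape(batch, grid_h, grid_w, dim)


if __name__ == "__main__":
    # Stand-in for ViT-B/14 outputs on 224x224 images with 4 registers: 1 + 4 + 16*16 tokens, dim 768.
    hidden = np.zeros((2, 1 + 4 + 256, 768))
    cls, regs, patches = split_tokens(hidden, num_registers=4)
    print(cls.shape, regs.shape, patches_to_grid(patches, 16, 16).shape)
```

<p>The [CLS] vector serves as a global descriptor (for example, for retrieval), while the patch grid can feed dense heads such as a linear segmentation probe.</p>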
</section>
<section id="model-distillation" class="level3">
<h3 class="anchored" data-anchor-id="model-distillation" id="model-distillation">Model Distillation</h3>
<p>The smaller ViT-S/14, ViT-B/14, and ViT-L/14 models were distilled from the ViT-g/14 model rather than trained from scratch. They retain much of its performance while requiring far fewer computational resources for deployment.</p>
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>DINOv2 represents a remarkable achievement in self-supervised visual learning. Its sophisticated architecture and training methodology enable it to learn general-purpose visual features that transfer exceptionally well across diverse tasks. The careful balance of architectural innovations, data curation, and training techniques creates a visual representation system that approaches the versatility and power that we’ve seen in large language models.</p>
<p>The success of DINOv2 highlights how self-supervised learning can leverage vast amounts of unlabeled data to create foundation models for computer vision that may eventually eliminate the need for task-specific supervised training in many applications.</p>


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[DINO: Emerging Properties in Self-Supervised Vision Transformers]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/dino/dino-explained/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/dino/dino-explained/</guid>
      <pubDate>Tue, 13 May 2025 00:00:00 GMT</pubDate>
      
      <category>research</category>
      <category>intermediate</category>
      <content:encoded><![CDATA[






<section id="dino-emerging-properties-in-self-supervised-vision-transformers" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/dino/dino-explained/dino.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>In 2021, Facebook AI Research (now Meta AI) introduced DINO (Self-Distillation with No Labels), a groundbreaking approach to self-supervised learning in computer vision. Published in the paper “Emerging Properties in Self-Supervised Vision Transformers” by Mathilde Caron and colleagues, DINO represented a significant leap forward in learning visual representations without relying on labeled data. This article explores the key aspects of the original DINO research, its methodology, and its implications for computer vision.</p>
</section>
<section id="the-challenge-of-self-supervised-learning" class="level2">
<h2 class="anchored" data-anchor-id="the-challenge-of-self-supervised-learning" id="the-challenge-of-self-supervised-learning">The Challenge of Self-Supervised Learning</h2>
<p>Traditionally, computer vision models have relied heavily on supervised learning using massive labeled datasets like ImageNet. However, creating such datasets requires enormous human effort for annotation. Self-supervised learning aims to overcome this limitation by teaching models to learn meaningful representations from unlabeled images, which are abundantly available.</p>
<p>Several approaches to self-supervised learning had been proposed before DINO, including:</p>
<ul>
<li>Contrastive learning (SimCLR, MoCo)</li>
<li>Clustering-based methods (SwAV, DeepCluster)</li>
<li>Predictive methods (predicting rotations, solving jigsaw puzzles)</li>
</ul>
<p>DINO introduced a novel approach that combined elements of knowledge distillation and self-supervision to produce surprisingly effective visual representations.</p>
</section>
<section id="dinos-core-methodology" class="level2">
<h2 class="anchored" data-anchor-id="dinos-core-methodology" id="dinos-core-methodology">DINO’s Core Methodology</h2>
<p>DINO’s key innovation was adapting the concept of knowledge distillation to a self-supervised setting. Traditional knowledge distillation involves a teacher model transferring knowledge to a student model, but DINO cleverly applies this concept without requiring separate pre-trained teacher models.</p>
<section id="self-distillation-framework" class="level3">
<h3 class="anchored" data-anchor-id="self-distillation-framework" id="self-distillation-framework">Self-Distillation Framework</h3>
<p>In DINO:</p>
<ol type="1">
<li><strong>Teacher and Student Networks</strong>: Both networks share the same architecture but have different parameters.</li>
<li><strong>Parameter Updates</strong>:
<ul>
<li>The student network is updated through standard backpropagation</li>
<li>The teacher is updated as an exponential moving average (EMA) of the student’s parameters</li>
</ul></li>
</ol>
<p>This creates a bootstrapping effect where the teacher continually provides slightly better targets for the student to learn from.</p>
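<p>The EMA momentum itself is not fixed. The paper follows a cosine schedule from 0.996 to 1 over training. In the form used by BYOL, this schedule is λ = 1 − (1 − λ<sub>base</sub>)(cos(πk/K) + 1)/2 for step k out of K. A direct Python transcription follows (the function name is hypothetical):</p>

```python
import math


def teacher_momentum(step, total_steps, base=0.996, final=1.0):
    """Cosine schedule for the EMA momentum: equals `base` at step 0 and `final` at the last step."""
    return final - (final - base) * (math.cos(math.pi * step / total_steps) + 1.0) / 2.0


if __name__ == "__main__":
    for step in (0, 250, 500, 750, 1000):
        print(step, round(teacher_momentum(step, 1000), 6))
```

<p>As the momentum approaches 1, the teacher changes more and more slowly and is effectively frozen by the end of training.</p>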
</section>
<section id="multi-crop-training-strategy" class="level3">
<h3 class="anchored" data-anchor-id="multi-crop-training-strategy" id="multi-crop-training-strategy">Multi-crop Training Strategy</h3>
<p>DINO employs a sophisticated data augmentation approach:</p>
<ol type="1">
<li><strong>Global Views</strong>: Two larger crops of an image (covering significant portions)</li>
<li><strong>Local Views</strong>: Several smaller crops that focus on details</li>
</ol>
<p>The student network processes all views (global and local), while the teacher processes only the global views. The student is trained to match the teacher's output for each global view from every other view, including the small local crops. This encourages the student to relate local details to the global context of the image.</p>
</section>
<section id="self-supervision-objective" class="level3">
<h3 class="anchored" data-anchor-id="self-supervision-objective" id="self-supervision-objective">Self-Supervision Objective</h3>
<p>The training objective minimizes the cross-entropy between the teacher's output distribution for each global view and the student's output distribution for every other view (the other global view and all local views). Pairs in which the teacher and student see the same view are skipped. This encourages consistency across different scales and regions of the image.</p>
</section>
<section id="collapse-prevention" class="level3">
<h3 class="anchored" data-anchor-id="collapse-prevention" id="collapse-prevention">Collapse Prevention</h3>
<p>A major challenge in self-supervised learning is representation collapse—where the model outputs the same representation regardless of input. DINO prevents this through:</p>
<ol type="1">
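<li><strong>Centering</strong>: Subtracting a center from the teacher's output before the softmax; the center is an exponential moving average of the teacher's batch-mean output</li>
<li><strong>Sharpening</strong>: Using a low temperature in the teacher's softmax (lower than the student's), which makes the teacher's targets more peaked. The paper warms this temperature up linearly from 0.04 to 0.07 during the first epochs.</li>
</ol>
<p>The two effects pull in opposite directions. Centering prevents any single output dimension from dominating but pushes the output toward a uniform distribution, while sharpening has the opposite effect. Balancing the two lets the model learn diverse and meaningful features.</p>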
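<p>A toy NumPy experiment makes this interplay visible. The logits are synthetic, a single batch mean stands in for DINO's running average, and a temperature of 1.0 is used only to show centering without sharpening; none of these values come from the paper. With sharpening alone, every sample is assigned to the dominant dimension. Centering removes that bias, and adding sharpening on top gives confident (low-entropy) but varied assignments.</p>

```python
import numpy as np


def softmax(logits, temp):
    z = logits / temp
    z = z - z.max(axis=-1, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def entropy(p):
    """Mean entropy (in nats) of the rows of a probability matrix."""
    return float(-(p * np.log(p + 1e-12)).sum(axis=-1).mean())


def dominant_fraction(p, dim=0):
    """Fraction of samples whose most likely output dimension is `dim`."""
    return float((p.argmax(axis=-1) == dim).mean())


rng = np.random.default_rng(0)
# Toy teacher logits (256 samples, 8 output dims) in which dimension 0 dominates every input.
logits = rng.normal(size=(256, 8))
logits[:, 0] += 5.0

# Centering: subtract the mean output (DINO keeps an EMA across batches; one batch mean here).
center = logits.mean(axis=0)

uncentered = softmax(logits, temp=0.04)               # sharpening only
centered_flat = softmax(logits - center, temp=1.0)    # centering with a soft temperature
centered_sharp = softmax(logits - center, temp=0.04)  # centering + sharpening

print("dim-0 dominance, sharpening only:", dominant_fraction(uncentered))
print("dim-0 dominance, centering + sharpening:", dominant_fraction(centered_sharp))
print("entropy, centering only:", entropy(centered_flat))
print("entropy, centering + sharpening:", entropy(centered_sharp))
```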
</section>
</section>
<section id="vision-transformer-architecture" class="level2">
<h2 class="anchored" data-anchor-id="vision-transformer-architecture" id="vision-transformer-architecture">Vision Transformer Architecture</h2>
<p>While DINO can be applied to various neural network architectures, the paper demonstrated particularly impressive results using Vision Transformers (ViT). The combination of DINO with ViT offered several advantages:</p>
<ol type="1">
<li><strong>Patch-based processing</strong>: ViT divides images into patches, which aligns well with DINO’s local-global view approach</li>
<li><strong>Self-attention mechanism</strong>: Enables capturing long-range dependencies in images</li>
<li><strong>Scalability</strong>: The architecture scales effectively with more data and parameters</li>
</ol>
<p>DINO was implemented with several ViT sizes, each with patch sizes of 16 and 8 (e.g.&nbsp;ViT-S/16 and ViT-S/8):</p>
<ul>
<li>ViT-S: Small (22M parameters)</li>
<li>ViT-B: Base (86M parameters)</li>
</ul>
</section>
<section id="emergent-properties" class="level2">
<h2 class="anchored" data-anchor-id="emergent-properties" id="emergent-properties">Emergent Properties</h2>
<p>The most surprising aspect of DINO was the emergence of properties that weren’t explicitly trained for:</p>
<section id="unsupervised-segmentation" class="level3">
<h3 class="anchored" data-anchor-id="unsupervised-segmentation" id="unsupervised-segmentation">Unsupervised Segmentation</h3>
<p>Remarkably, the last-layer self-attention maps of the [CLS] token in DINO-trained Vision Transformers concentrate on the main objects in an image, so thresholding them yields masks that resemble object segmentations. The model learned this without any segmentation supervision. This surprised the research community and suggested that the model had captured more of the visual structure of scenes than previous self-supervised approaches.</p>
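<p>The paper visualizes these masks by thresholding the [CLS] self-attention so as to keep 60% of the attention mass. The minimal NumPy sketch below implements that thresholding. It assumes you already have one head's attention weights from [CLS] to each patch as a flat array; for a ViT-S/16 at 224×224, this would be 196 values that can be reshaped to a 14×14 grid. The function name is hypothetical.</p>

```python
import numpy as np


def attention_mask(attn, keep_mass=0.6):
    """Keep the smallest set of patches whose attention weights sum to at least `keep_mass`.

    attn: (num_patches,) non-negative attention weights from [CLS] to each patch.
    Returns a boolean mask of the same shape.
    """
    attn = attn / attn.sum()                      # normalize to a distribution
    order = np.argsort(attn)[::-1]                # patches from most to least attended
    cumulative = np.cumsum(attn[order])
    n_keep = int(np.searchsorted(cumulative, keep_mass) + 1)  # first position reaching keep_mass
    mask = np.zeros_like(attn, dtype=bool)
    mask[order[:n_keep]] = True
    return mask


if __name__ == "__main__":
    weights = np.array([0.5, 0.3, 0.1, 0.1])
    print(attention_mask(weights).tolist())  # [True, True, False, False]
```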
</section>
<section id="local-feature-quality" class="level3">
<h3 class="anchored" data-anchor-id="local-feature-quality" id="local-feature-quality">Local Feature Quality</h3>
<p>DINO produced local features (from patch tokens) that proved extremely effective for tasks requiring spatial understanding, like semantic segmentation. The features exhibited strong semantic coherence across spatial regions.</p>
</section>
<section id="nearest-neighbor-performance" class="level3">
<h3 class="anchored" data-anchor-id="nearest-neighbor-performance" id="nearest-neighbor-performance">Nearest Neighbor Performance</h3>
<p>Using DINO features with simple k-nearest neighbor classifiers achieved impressive accuracy on ImageNet classification, demonstrating the quality of the learned representations.</p>
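<p>The k-NN evaluation is a weighted vote. Each of the k nearest training images (by cosine similarity of frozen features) votes for its label with weight exp(similarity / T), with k = 20 and T = 0.07 reported as working well. The minimal NumPy sketch below uses those settings. The function name is hypothetical, and a real evaluation would process the test set in batches.</p>

```python
import numpy as np


def knn_predict(train_feats, train_labels, test_feats, num_classes, k=20, temperature=0.07):
    """Weighted k-NN on L2-normalized features (cosine similarity).

    Each of the k nearest training samples votes for its label with weight exp(sim / temperature).
    Returns the predicted class index for each test sample.
    """
    train = train_feats / np.linalg.norm(train_feats, axis=1, keepdims=True)
    test = test_feats / np.linalg.norm(test_feats, axis=1, keepdims=True)
    sims = test @ train.T                                  # (n_test, n_train)
    k = min(k, train.shape[0])
    topk_idx = np.argsort(-sims, axis=1)[:, :k]            # indices of the k most similar
    topk_sims = np.take_along_axis(sims, topk_idx, axis=1)
    weights = np.exp(topk_sims / temperature)
    votes = np.zeros((test.shape[0], num_classes))
    for i in range(test.shape[0]):
        np.add.at(votes[i], train_labels[topk_idx[i]], weights[i])  # accumulate weighted votes
    return votes.argmax(axis=1)


if __name__ == "__main__":
    train = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]])
    labels = np.array([0, 0, 1, 1])
    queries = np.array([[1.0, 0.05], [0.05, 1.0]])
    print(knn_predict(train, labels, queries, num_classes=2, k=3).tolist())  # [0, 1]
```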
</section>
</section>
<section id="training-details" class="level2">
<h2 class="anchored" data-anchor-id="training-details" id="training-details">Training Details</h2>
<p>The original DINO paper described several important implementation details:</p>
<section id="data-augmentation" class="level3">
<h3 class="anchored" data-anchor-id="data-augmentation" id="data-augmentation">Data Augmentation</h3>
<p>The augmentation pipeline included:</p>
<ul>
<li>Random resized cropping</li>
<li>Horizontal flipping</li>
<li>Color jittering</li>
<li>Gaussian blur</li>
<li>Solarization (for some views)</li>
</ul>
</section>
<section id="optimization" class="level3">
<h3 class="anchored" data-anchor-id="optimization" id="optimization">Optimization</h3>
<ul>
<li>Optimizer: AdamW with weight decay</li>
<li>Learning rate: Cosine schedule with linear warmup</li>
<li>Batch size: 1024 images</li>
</ul>
</section>
<section id="architectural-choices" class="level3">
<h3 class="anchored" data-anchor-id="architectural-choices" id="architectural-choices">Architectural Choices</h3>
<ul>
<li>Projection head: 3-layer MLP with bottleneck structure</li>
<li>CLS token: Used as global image representation</li>
<li>Positional embeddings: Standard learnable embeddings</li>
</ul>
</section>
</section>
<section id="results-and-impact" class="level2">
<h2 class="anchored" data-anchor-id="results-and-impact" id="results-and-impact">Results and Impact</h2>
<p>DINO achieved remarkable results on several benchmarks:</p>
<section id="imagenet-classification" class="level3">
<h3 class="anchored" data-anchor-id="imagenet-classification" id="imagenet-classification">ImageNet Classification</h3>
<ul>
<li>80.1% top-1 accuracy with a linear classifier using ViT-B/8 (77.4% with k-NN classification)</li>
<li>Competitive with supervised methods and superior to previous self-supervised approaches</li>
</ul>
</section>
<section id="downstream-tasks" class="level3">
<h3 class="anchored" data-anchor-id="downstream-tasks" id="downstream-tasks">Downstream Tasks</h3>
<p>DINO features transferred successfully to:</p>
<ul>
<li>Image retrieval and copy detection</li>
<li>Video instance segmentation (DAVIS-2017)</li>
<li>Fine-tuning on downstream classification datasets</li>
</ul>
</section>
<section id="robustness" class="level3">
<h3 class="anchored" data-anchor-id="robustness" id="robustness">Robustness</h3>
<p>The features showed strong robustness to distribution shifts and generalized well to out-of-distribution data.</p>
</section>
</section>
<section id="comparison-with-previous-methods" class="level2">
<h2 class="anchored" data-anchor-id="comparison-with-previous-methods" id="comparison-with-previous-methods">Comparison with Previous Methods</h2>
<p>DINO differed from earlier self-supervised approaches in several key ways:</p>
<section id="versus-contrastive-learning" class="level3">
<h3 class="anchored" data-anchor-id="versus-contrastive-learning" id="versus-contrastive-learning">Versus Contrastive Learning</h3>
<ul>
<li>No need for large negative sample sets</li>
<li>Less sensitive to batch size, since no negatives need to be drawn from the batch</li>
<li>More stable training dynamics</li>
</ul>
</section>
<section id="versus-clustering-based-methods" class="level3">
<h3 class="anchored" data-anchor-id="versus-clustering-based-methods" id="versus-clustering-based-methods">Versus Clustering-Based Methods</h3>
<ul>
<li>No explicit clustering objective</li>
<li>More straightforward implementation</li>
<li>Better scaling properties with model size</li>
</ul>
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>The original DINO research represented a significant step forward in self-supervised visual representation learning. By combining knowledge distillation techniques with self-supervision and leveraging the Vision Transformer architecture, DINO produced features with remarkable properties for a wide range of computer vision tasks.</p>
<p>The emergence of semantic features and unsupervised segmentation abilities demonstrated that well-designed self-supervised methods could lead to models that understand visual concepts in ways previously thought to require explicit supervision. DINO laid the groundwork for subsequent advances in this field, including its successor DINOv2, and helped establish self-supervised learning as a powerful paradigm for computer vision.</p>
<p>The success of DINO highlighted the potential for self-supervised learning to reduce reliance on large labeled datasets and pointed toward a future where visual foundation models could be developed primarily through self-supervision – mirroring similar developments in natural language processing with large language models.</p>


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[DINOv2: Comprehensive Implementation Guide]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/dino/dinov2/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/dino/dinov2/</guid>
      <pubDate>Sat, 03 May 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>tutorial</category>
      <category>beginner</category>
      <content:encoded><![CDATA[






<section id="dinov2-comprehensive-implementation-guide" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/dino/dinov2/dino.jpg" class="img-fluid"></p>
<p>DINOv2 is a state-of-the-art self-supervised vision model developed by Meta AI Research that builds upon the original DINO (Self-Distillation with No Labels) framework. This guide will walk you through understanding, implementing, and leveraging DINOv2 for various computer vision tasks.</p>
<section id="introduction-to-dinov2" class="level2">
<h2 class="anchored" data-anchor-id="introduction-to-dinov2" id="introduction-to-dinov2">Introduction to DINOv2</h2>
<p>DINOv2 is a self-supervised learning method for vision that produces high-quality visual features without requiring labeled data. It extends the original DINO architecture with several improvements:</p>
<ul>
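<li>Training on LVD-142M, a large curated dataset of 142 million images</li>
<li>An image-level self-distillation loss (from DINO) combined with a patch-level masked-token loss (from iBOT)</li>
<li>Additional regularization: Sinkhorn-Knopp centering and the KoLeo regularizer</li>
<li>A short high-resolution (518×518) training phase at the end of pretraining</li>
<li>Pre-trained Vision Transformer (ViT) backbones ranging from ViT-S/14 to ViT-g/14</li>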
</ul>
<p>The result is a versatile foundation model that can be adapted to numerous vision tasks with minimal fine-tuning.</p>
</section>
<section id="installation-and-setup" class="level2">
<h2 class="anchored" data-anchor-id="installation-and-setup" id="installation-and-setup">Installation and Setup</h2>
<p>To use DINOv2, you’ll need to install the official implementation:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Install PyTorch first if not already installed</span></span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="co"># pip install torch torchvision</span></span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Install DINOv2</span></span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install git+https://github.com/facebookresearch/dinov2</span></code></pre></div></div>
<p>Alternatively, you can clone the repository and install it locally:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb2"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="fu">git</span> clone https://github.com/facebookresearch/dinov2.git</span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a><span class="bu">cd</span> dinov2</span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install <span class="at">-e</span> .</span></code></pre></div></div>
<section id="dependencies" class="level3">
<h3 class="anchored" data-anchor-id="dependencies" id="dependencies">Dependencies</h3>
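<p>The official training and evaluation code is documented as requiring:</p>
<ul>
<li>Python 3.9 (as in the provided conda environment)</li>
<li>PyTorch 2.0 and xFormers 0.0.18</li>
<li>torchvision</li>
<li>A Linux environment with CUDA GPUs</li>
</ul>
<p>Loading the pre-trained backbones for inference through <code>torch.hub</code> or Hugging Face Transformers has looser requirements.</p>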
</section>
</section>
<section id="loading-pre-trained-models" class="level2">
<h2 class="anchored" data-anchor-id="loading-pre-trained-models" id="loading-pre-trained-models">Loading Pre-trained Models</h2>
<p>DINOv2 provides several pre-trained models with different sizes and capabilities:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Official backbones: dinov2_vits14, dinov2_vitb14, dinov2_vitl14, dinov2_vitg14</span></span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a>model_name <span class="op">=</span> <span class="st">"dinov2_vitb14"</span></span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Option 1: torch.hub downloads the model code and pre-trained weights from the official repository</span></span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> torch.hub.load(<span class="st">"facebookresearch/dinov2"</span>, model_name)</span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a><span class="co"># Option 2: build the architecture without weights, then load a manually downloaded</span></span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a><span class="co"># checkpoint such as dinov2_vitb14_pretrain.pth (a plain state_dict, no "model" key)</span></span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a><span class="co"># model = torch.hub.load("facebookresearch/dinov2", model_name, pretrained=False)</span></span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a><span class="co"># model.load_state_dict(torch.load("dinov2_vitb14_pretrain.pth", map_location="cpu"))</span></span>
<span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a><span class="co"># Move to GPU if available</span></span>
<span id="cb3-15"><a href="#cb3-15" aria-hidden="true" tabindex="-1"></a>device <span class="op">=</span> torch.device(<span class="st">"cuda"</span> <span class="cf">if</span> torch.cuda.is_available() <span class="cf">else</span> <span class="st">"cpu"</span>)</span>
<span id="cb3-16"><a href="#cb3-16" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> model.to(device)</span>
<span id="cb3-17"><a href="#cb3-17" aria-hidden="true" tabindex="-1"></a>model.<span class="bu">eval</span>()  <span class="co"># Set to evaluation mode</span></span></code></pre></div></div>
<p>You can also use the Hugging Face Transformers library for an easier integration:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> transformers <span class="im">import</span> AutoImageProcessor, AutoModel</span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Available checkpoints: facebook/dinov2-small, -base, -large, -giant</span></span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a>model_name <span class="op">=</span> <span class="st">"facebook/dinov2-base"</span></span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a>processor <span class="op">=</span> AutoImageProcessor.from_pretrained(model_name)</span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> AutoModel.from_pretrained(model_name)</span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a><span class="co"># Move to GPU if available</span></span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a>device <span class="op">=</span> torch.device(<span class="st">"cuda"</span> <span class="cf">if</span> torch.cuda.is_available() <span class="cf">else</span> <span class="st">"cpu"</span>)</span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> model.to(device)</span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a>model.<span class="bu">eval</span>()  <span class="co"># Set to evaluation mode</span></span></code></pre></div></div>
</section>
<section id="feature-extraction" class="level2">
<h2 class="anchored" data-anchor-id="feature-extraction" id="feature-extraction">Feature Extraction</h2>
<p>One of DINOv2’s key strengths is its ability to extract powerful visual features:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> PIL <span class="im">import</span> Image</span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torchvision.transforms <span class="im">as</span> T</span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> transformers <span class="im">import</span> AutoImageProcessor, AutoModel</span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Load model</span></span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a>model_name <span class="op">=</span> <span class="st">"facebook/dinov2-base"</span></span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a>processor <span class="op">=</span> AutoImageProcessor.from_pretrained(model_name)</span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> AutoModel.from_pretrained(model_name)</span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a>model.to(torch.device(<span class="st">"cuda"</span> <span class="cf">if</span> torch.cuda.is_available() <span class="cf">else</span> <span class="st">"cpu"</span>))</span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a>model.<span class="bu">eval</span>()</span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a><span class="co"># Load and preprocess image</span></span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a>image <span class="op">=</span> Image.<span class="bu">open</span>(<span class="st">"path/to/your/image.jpg"</span>).convert(<span class="st">"RGB"</span>)</span>
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a>inputs <span class="op">=</span> processor(images<span class="op">=</span>image, return_tensors<span class="op">=</span><span class="st">"pt"</span>).to(model.device)</span>
<span id="cb5-16"><a href="#cb5-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-17"><a href="#cb5-17" aria-hidden="true" tabindex="-1"></a><span class="co"># Extract features</span></span>
<span id="cb5-18"><a href="#cb5-18" aria-hidden="true" tabindex="-1"></a><span class="cf">with</span> torch.no_grad():</span>
<span id="cb5-19"><a href="#cb5-19" aria-hidden="true" tabindex="-1"></a>    outputs <span class="op">=</span> model(<span class="op">**</span>inputs)</span>
<span id="cb5-20"><a href="#cb5-20" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-21"><a href="#cb5-21" aria-hidden="true" tabindex="-1"></a><span class="co"># Get CLS token features (useful for classification tasks)</span></span>
<span id="cb5-22"><a href="#cb5-22" aria-hidden="true" tabindex="-1"></a>cls_features <span class="op">=</span> outputs.last_hidden_state[:, <span class="dv">0</span>]</span>
<span id="cb5-23"><a href="#cb5-23" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-24"><a href="#cb5-24" aria-hidden="true" tabindex="-1"></a><span class="co"># Get patch features (useful for dense prediction tasks like segmentation)</span></span>
<span id="cb5-25"><a href="#cb5-25" aria-hidden="true" tabindex="-1"></a>patch_features <span class="op">=</span> outputs.last_hidden_state[:, <span class="dv">1</span>:]</span>
<span id="cb5-26"><a href="#cb5-26" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-27"><a href="#cb5-27" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"CLS features shape: </span><span class="sc">{</span>cls_features<span class="sc">.</span>shape<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb5-28"><a href="#cb5-28" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Patch features shape: </span><span class="sc">{</span>patch_features<span class="sc">.</span>shape<span class="sc">}</span><span class="ss">"</span>)</span></code></pre></div></div>
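<p>The printed shapes follow from simple arithmetic. Assume the processor’s default 224×224 center crop and DINOv2’s 14×14 patches. The image then becomes a 16×16 grid of 256 patch tokens, plus one CLS token, and each token has width 768 for the base model. The patch embedding flattens this grid row by row, so patch token <code>k</code> covers grid cell <code>(k // 16, k % 16)</code>. The helper names below are illustrative and are not part of any library:</p>

```python
# Hidden sizes of the DINOv2 backbones (ViT-S/14, ViT-B/14, ViT-L/14, ViT-g/14)
HIDDEN_SIZE = {"small": 384, "base": 768, "large": 1024, "giant": 1536}
PATCH_SIZE = 14  # every DINOv2 variant uses 14x14 patches


def dinov2_output_shapes(height, width, variant="base", batch_size=1):
    """Return (cls_shape, patch_shape) for a DINOv2 model without register tokens."""
    # The stride-14 patch embedding floors the grid size, dropping leftover border pixels
    grid_h, grid_w = height // PATCH_SIZE, width // PATCH_SIZE
    hidden = HIDDEN_SIZE[variant]
    cls_shape = (batch_size, hidden)
    patch_shape = (batch_size, grid_h * grid_w, hidden)
    return cls_shape, patch_shape


def patch_index_to_cell(index, grid_w):
    """Map a patch-token index (CLS token excluded) to its (row, col) grid cell."""
    return divmod(index, grid_w)


cls_shape, patch_shape = dinov2_output_shapes(224, 224)
print(cls_shape, patch_shape)  # (1, 768) (1, 256, 768)
print(patch_index_to_cell(17, grid_w=224 // PATCH_SIZE))  # (1, 1)
```

<p>This row-major ordering is why the patch features can later be reshaped to <code>(batch, 16, 16, 768)</code> to recover a spatial feature map.</p>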
</section>
<section id="fine-tuning-for-downstream-tasks" class="level2">
<h2 class="anchored" data-anchor-id="fine-tuning-for-downstream-tasks" id="fine-tuning-for-downstream-tasks">Fine-tuning for Downstream Tasks</h2>
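<p>DINOv2 can also be fine-tuned end to end for a specific task by attaching a task-specific head to the backbone. The example below adds a linear classifier on the CLS token and trains the whole model with a small learning rate:</p>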
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> transformers <span class="im">import</span> AutoModel</span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Load pre-trained DINOv2 model</span></span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a>backbone <span class="op">=</span> AutoModel.from_pretrained(<span class="st">"facebook/dinov2-base"</span>)</span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a><span class="co"># Create a custom classification head</span></span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> ClassificationHead(nn.Module):</span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, backbone, num_classes<span class="op">=</span><span class="dv">1000</span>):</span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.backbone <span class="op">=</span> backbone</span>
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.classifier <span class="op">=</span> nn.Linear(backbone.config.hidden_size, num_classes)</span>
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-15"><a href="#cb6-15" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb6-16"><a href="#cb6-16" aria-hidden="true" tabindex="-1"></a>        outputs <span class="op">=</span> <span class="va">self</span>.backbone(x)</span>
<span id="cb6-17"><a href="#cb6-17" aria-hidden="true" tabindex="-1"></a>        cls_token <span class="op">=</span> outputs.last_hidden_state[:, <span class="dv">0</span>]</span>
<span id="cb6-18"><a href="#cb6-18" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.classifier(cls_token)</span>
<span id="cb6-19"><a href="#cb6-19" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-20"><a href="#cb6-20" aria-hidden="true" tabindex="-1"></a><span class="co"># Create the complete model</span></span>
<span id="cb6-21"><a href="#cb6-21" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> ClassificationHead(backbone, num_classes<span class="op">=</span><span class="dv">100</span>)  <span class="co"># For 100 classes</span></span>
<span id="cb6-22"><a href="#cb6-22" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-23"><a href="#cb6-23" aria-hidden="true" tabindex="-1"></a><span class="co"># Define optimizer and loss function</span></span>
<span id="cb6-24"><a href="#cb6-24" aria-hidden="true" tabindex="-1"></a>optimizer <span class="op">=</span> torch.optim.AdamW(model.parameters(), lr<span class="op">=</span><span class="fl">1e-5</span>)</span>
<span id="cb6-25"><a href="#cb6-25" aria-hidden="true" tabindex="-1"></a>criterion <span class="op">=</span> nn.CrossEntropyLoss()</span>
<span id="cb6-26"><a href="#cb6-26" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-27"><a href="#cb6-27" aria-hidden="true" tabindex="-1"></a><span class="co"># Training loop example (assumes the model is already on device and each batch is a dict with "pixel_values" and "labels")</span></span>
<span id="cb6-28"><a href="#cb6-28" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> train_one_epoch(model, dataloader, optimizer, criterion, device):</span>
<span id="cb6-29"><a href="#cb6-29" aria-hidden="true" tabindex="-1"></a>    model.train()</span>
<span id="cb6-30"><a href="#cb6-30" aria-hidden="true" tabindex="-1"></a>    total_loss <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb6-31"><a href="#cb6-31" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-32"><a href="#cb6-32" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> batch <span class="kw">in</span> dataloader:</span>
<span id="cb6-33"><a href="#cb6-33" aria-hidden="true" tabindex="-1"></a>        images <span class="op">=</span> batch[<span class="st">"pixel_values"</span>].to(device)</span>
<span id="cb6-34"><a href="#cb6-34" aria-hidden="true" tabindex="-1"></a>        labels <span class="op">=</span> batch[<span class="st">"labels"</span>].to(device)</span>
<span id="cb6-35"><a href="#cb6-35" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-36"><a href="#cb6-36" aria-hidden="true" tabindex="-1"></a>        optimizer.zero_grad()</span>
<span id="cb6-37"><a href="#cb6-37" aria-hidden="true" tabindex="-1"></a>        outputs <span class="op">=</span> model(images)</span>
<span id="cb6-38"><a href="#cb6-38" aria-hidden="true" tabindex="-1"></a>        loss <span class="op">=</span> criterion(outputs, labels)</span>
<span id="cb6-39"><a href="#cb6-39" aria-hidden="true" tabindex="-1"></a>        loss.backward()</span>
<span id="cb6-40"><a href="#cb6-40" aria-hidden="true" tabindex="-1"></a>        optimizer.step()</span>
<span id="cb6-41"><a href="#cb6-41" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-42"><a href="#cb6-42" aria-hidden="true" tabindex="-1"></a>        total_loss <span class="op">+=</span> loss.item()</span>
<span id="cb6-43"><a href="#cb6-43" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-44"><a href="#cb6-44" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> total_loss <span class="op">/</span> <span class="bu">len</span>(dataloader)</span></code></pre></div></div>
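<p>When labeled data is scarce, a common alternative to full fine-tuning is linear probing. In linear probing, the backbone is frozen and only the head is trained. The sketch below assumes a model that exposes <code>backbone</code> and <code>classifier</code> attributes, like <code>ClassificationHead</code> above. The <code>freeze_backbone</code> helper and the toy stand-in model are hypothetical. They exist only so the snippet runs without downloading weights:</p>

```python
import torch
import torch.nn as nn


def freeze_backbone(model, backbone_attr="backbone"):
    """Stop gradients through the backbone and return the parameters left trainable."""
    for param in getattr(model, backbone_attr).parameters():
        param.requires_grad = False
    return [p for p in model.parameters() if p.requires_grad]


class ToyHead(nn.Module):
    """Hypothetical stand-in with the same attribute names as ClassificationHead."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(16, 8)   # plays the role of the DINOv2 backbone
        self.classifier = nn.Linear(8, 3)  # 3-class linear head

    def forward(self, x):
        return self.classifier(self.backbone(x))


model = ToyHead()
trainable = freeze_backbone(model)
# Only the head's parameters reach the optimizer; lr=1e-3 is an illustrative choice
optimizer = torch.optim.AdamW(trainable, lr=1e-3)
print(sum(p.numel() for p in trainable))  # 27 (8 * 3 weights + 3 biases)
```

<p>With the real model, the same call would be <code>freeze_backbone(ClassificationHead(backbone, num_classes=100))</code>.</p>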
</section>
<section id="image-classification-example" class="level2">
<h2 class="anchored" data-anchor-id="image-classification-example" id="image-classification-example">Image Classification Example</h2>
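<p>Here’s a complete example that fine-tunes DINOv2 for image classification on a folder-structured dataset loaded with <code>ImageFolder</code>. The new classifier head uses a higher learning rate (1e-3) than the pre-trained backbone (1e-5):</p>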
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.utils.data <span class="im">import</span> DataLoader</span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torchvision.datasets <span class="im">import</span> ImageFolder</span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torchvision.transforms <span class="im">as</span> transforms</span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> transformers <span class="im">import</span> AutoModel</span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a><span class="co"># Define the dataset and transforms</span></span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a>transform <span class="op">=</span> transforms.Compose([</span>
<span id="cb7-10"><a href="#cb7-10" aria-hidden="true" tabindex="-1"></a>    transforms.Resize((<span class="dv">224</span>, <span class="dv">224</span>)),</span>
<span id="cb7-11"><a href="#cb7-11" aria-hidden="true" tabindex="-1"></a>    transforms.ToTensor(),</span>
<span id="cb7-12"><a href="#cb7-12" aria-hidden="true" tabindex="-1"></a>    transforms.Normalize(mean<span class="op">=</span>[<span class="fl">0.485</span>, <span class="fl">0.456</span>, <span class="fl">0.406</span>], std<span class="op">=</span>[<span class="fl">0.229</span>, <span class="fl">0.224</span>, <span class="fl">0.225</span>]),</span>
<span id="cb7-13"><a href="#cb7-13" aria-hidden="true" tabindex="-1"></a>])</span>
<span id="cb7-14"><a href="#cb7-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-15"><a href="#cb7-15" aria-hidden="true" tabindex="-1"></a><span class="co"># Load your dataset (adjust the path)</span></span>
<span id="cb7-16"><a href="#cb7-16" aria-hidden="true" tabindex="-1"></a>train_dataset <span class="op">=</span> ImageFolder(root<span class="op">=</span><span class="st">"path/to/train"</span>, transform<span class="op">=</span>transform)</span>
<span id="cb7-17"><a href="#cb7-17" aria-hidden="true" tabindex="-1"></a>val_dataset <span class="op">=</span> ImageFolder(root<span class="op">=</span><span class="st">"path/to/val"</span>, transform<span class="op">=</span>transform)</span>
<span id="cb7-18"><a href="#cb7-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-19"><a href="#cb7-19" aria-hidden="true" tabindex="-1"></a>train_loader <span class="op">=</span> DataLoader(train_dataset, batch_size<span class="op">=</span><span class="dv">32</span>, shuffle<span class="op">=</span><span class="va">True</span>, num_workers<span class="op">=</span><span class="dv">4</span>)</span>
<span id="cb7-20"><a href="#cb7-20" aria-hidden="true" tabindex="-1"></a>val_loader <span class="op">=</span> DataLoader(val_dataset, batch_size<span class="op">=</span><span class="dv">32</span>, shuffle<span class="op">=</span><span class="va">False</span>, num_workers<span class="op">=</span><span class="dv">4</span>)</span>
<span id="cb7-21"><a href="#cb7-21" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-22"><a href="#cb7-22" aria-hidden="true" tabindex="-1"></a><span class="co"># Create the model</span></span>
<span id="cb7-23"><a href="#cb7-23" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> DINOv2Classifier(nn.Module):</span>
<span id="cb7-24"><a href="#cb7-24" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, num_classes):</span>
<span id="cb7-25"><a href="#cb7-25" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb7-26"><a href="#cb7-26" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.dinov2 <span class="op">=</span> AutoModel.from_pretrained(<span class="st">"facebook/dinov2-base"</span>)</span>
<span id="cb7-27"><a href="#cb7-27" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.classifier <span class="op">=</span> nn.Linear(<span class="va">self</span>.dinov2.config.hidden_size, num_classes)  <span class="co"># 768 for dinov2-base</span></span>
<span id="cb7-28"><a href="#cb7-28" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-29"><a href="#cb7-29" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb7-30"><a href="#cb7-30" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Extract features</span></span>
<span id="cb7-31"><a href="#cb7-31" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> torch.set_grad_enabled(<span class="va">self</span>.training):</span>
<span id="cb7-32"><a href="#cb7-32" aria-hidden="true" tabindex="-1"></a>            features <span class="op">=</span> <span class="va">self</span>.dinov2(x).last_hidden_state[:, <span class="dv">0</span>]  <span class="co"># Get CLS token</span></span>
<span id="cb7-33"><a href="#cb7-33" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-34"><a href="#cb7-34" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Classify</span></span>
<span id="cb7-35"><a href="#cb7-35" aria-hidden="true" tabindex="-1"></a>        logits <span class="op">=</span> <span class="va">self</span>.classifier(features)</span>
<span id="cb7-36"><a href="#cb7-36" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> logits</span>
<span id="cb7-37"><a href="#cb7-37" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-38"><a href="#cb7-38" aria-hidden="true" tabindex="-1"></a><span class="co"># Initialize model</span></span>
<span id="cb7-39"><a href="#cb7-39" aria-hidden="true" tabindex="-1"></a>device <span class="op">=</span> torch.device(<span class="st">"cuda"</span> <span class="cf">if</span> torch.cuda.is_available() <span class="cf">else</span> <span class="st">"cpu"</span>)</span>
<span id="cb7-40"><a href="#cb7-40" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> DINOv2Classifier(num_classes<span class="op">=</span><span class="bu">len</span>(train_dataset.classes))</span>
<span id="cb7-41"><a href="#cb7-41" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> model.to(device)</span>
<span id="cb7-42"><a href="#cb7-42" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-43"><a href="#cb7-43" aria-hidden="true" tabindex="-1"></a><span class="co"># Define optimizer and loss function</span></span>
<span id="cb7-44"><a href="#cb7-44" aria-hidden="true" tabindex="-1"></a>optimizer <span class="op">=</span> torch.optim.AdamW([</span>
<span id="cb7-45"><a href="#cb7-45" aria-hidden="true" tabindex="-1"></a>    {<span class="st">'params'</span>: model.classifier.parameters(), <span class="st">'lr'</span>: <span class="fl">1e-3</span>},</span>
<span id="cb7-46"><a href="#cb7-46" aria-hidden="true" tabindex="-1"></a>    {<span class="st">'params'</span>: model.dinov2.parameters(), <span class="st">'lr'</span>: <span class="fl">1e-5</span>}</span>
<span id="cb7-47"><a href="#cb7-47" aria-hidden="true" tabindex="-1"></a>])</span>
<span id="cb7-48"><a href="#cb7-48" aria-hidden="true" tabindex="-1"></a>criterion <span class="op">=</span> nn.CrossEntropyLoss()</span>
<span id="cb7-49"><a href="#cb7-49" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-50"><a href="#cb7-50" aria-hidden="true" tabindex="-1"></a><span class="co"># Training loop</span></span>
<span id="cb7-51"><a href="#cb7-51" aria-hidden="true" tabindex="-1"></a>num_epochs <span class="op">=</span> <span class="dv">10</span></span>
<span id="cb7-52"><a href="#cb7-52" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(num_epochs):</span>
<span id="cb7-53"><a href="#cb7-53" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Training</span></span>
<span id="cb7-54"><a href="#cb7-54" aria-hidden="true" tabindex="-1"></a>    model.train()</span>
<span id="cb7-55"><a href="#cb7-55" aria-hidden="true" tabindex="-1"></a>    train_loss <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb7-56"><a href="#cb7-56" aria-hidden="true" tabindex="-1"></a>    correct <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb7-57"><a href="#cb7-57" aria-hidden="true" tabindex="-1"></a>    total <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb7-58"><a href="#cb7-58" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-59"><a href="#cb7-59" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> inputs, targets <span class="kw">in</span> train_loader:</span>
<span id="cb7-60"><a href="#cb7-60" aria-hidden="true" tabindex="-1"></a>        inputs, targets <span class="op">=</span> inputs.to(device), targets.to(device)</span>
<span id="cb7-61"><a href="#cb7-61" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-62"><a href="#cb7-62" aria-hidden="true" tabindex="-1"></a>        optimizer.zero_grad()</span>
<span id="cb7-63"><a href="#cb7-63" aria-hidden="true" tabindex="-1"></a>        outputs <span class="op">=</span> model(inputs)</span>
<span id="cb7-64"><a href="#cb7-64" aria-hidden="true" tabindex="-1"></a>        loss <span class="op">=</span> criterion(outputs, targets)</span>
<span id="cb7-65"><a href="#cb7-65" aria-hidden="true" tabindex="-1"></a>        loss.backward()</span>
<span id="cb7-66"><a href="#cb7-66" aria-hidden="true" tabindex="-1"></a>        optimizer.step()</span>
<span id="cb7-67"><a href="#cb7-67" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-68"><a href="#cb7-68" aria-hidden="true" tabindex="-1"></a>        train_loss <span class="op">+=</span> loss.item()</span>
<span id="cb7-69"><a href="#cb7-69" aria-hidden="true" tabindex="-1"></a>        _, predicted <span class="op">=</span> outputs.<span class="bu">max</span>(<span class="dv">1</span>)</span>
<span id="cb7-70"><a href="#cb7-70" aria-hidden="true" tabindex="-1"></a>        total <span class="op">+=</span> targets.size(<span class="dv">0</span>)</span>
<span id="cb7-71"><a href="#cb7-71" aria-hidden="true" tabindex="-1"></a>        correct <span class="op">+=</span> predicted.eq(targets).<span class="bu">sum</span>().item()</span>
<span id="cb7-72"><a href="#cb7-72" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-73"><a href="#cb7-73" aria-hidden="true" tabindex="-1"></a>    train_accuracy <span class="op">=</span> <span class="dv">100</span> <span class="op">*</span> correct <span class="op">/</span> total</span>
<span id="cb7-74"><a href="#cb7-74" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-75"><a href="#cb7-75" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Validation</span></span>
<span id="cb7-76"><a href="#cb7-76" aria-hidden="true" tabindex="-1"></a>    model.<span class="bu">eval</span>()</span>
<span id="cb7-77"><a href="#cb7-77" aria-hidden="true" tabindex="-1"></a>    val_loss <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb7-78"><a href="#cb7-78" aria-hidden="true" tabindex="-1"></a>    correct <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb7-79"><a href="#cb7-79" aria-hidden="true" tabindex="-1"></a>    total <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb7-80"><a href="#cb7-80" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-81"><a href="#cb7-81" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> torch.no_grad():</span>
<span id="cb7-82"><a href="#cb7-82" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> inputs, targets <span class="kw">in</span> val_loader:</span>
<span id="cb7-83"><a href="#cb7-83" aria-hidden="true" tabindex="-1"></a>            inputs, targets <span class="op">=</span> inputs.to(device), targets.to(device)</span>
<span id="cb7-84"><a href="#cb7-84" aria-hidden="true" tabindex="-1"></a>            outputs <span class="op">=</span> model(inputs)</span>
<span id="cb7-85"><a href="#cb7-85" aria-hidden="true" tabindex="-1"></a>            loss <span class="op">=</span> criterion(outputs, targets)</span>
<span id="cb7-86"><a href="#cb7-86" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb7-87"><a href="#cb7-87" aria-hidden="true" tabindex="-1"></a>            val_loss <span class="op">+=</span> loss.item()</span>
<span id="cb7-88"><a href="#cb7-88" aria-hidden="true" tabindex="-1"></a>            _, predicted <span class="op">=</span> outputs.<span class="bu">max</span>(<span class="dv">1</span>)</span>
<span id="cb7-89"><a href="#cb7-89" aria-hidden="true" tabindex="-1"></a>            total <span class="op">+=</span> targets.size(<span class="dv">0</span>)</span>
<span id="cb7-90"><a href="#cb7-90" aria-hidden="true" tabindex="-1"></a>            correct <span class="op">+=</span> predicted.eq(targets).<span class="bu">sum</span>().item()</span>
<span id="cb7-91"><a href="#cb7-91" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-92"><a href="#cb7-92" aria-hidden="true" tabindex="-1"></a>    val_accuracy <span class="op">=</span> <span class="dv">100</span> <span class="op">*</span> correct <span class="op">/</span> total</span>
<span id="cb7-93"><a href="#cb7-93" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-94"><a href="#cb7-94" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Epoch </span><span class="sc">{</span>epoch<span class="op">+</span><span class="dv">1</span><span class="sc">}</span><span class="ss">/</span><span class="sc">{</span>num_epochs<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb7-95"><a href="#cb7-95" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Train Loss: </span><span class="sc">{</span>train_loss<span class="op">/</span><span class="bu">len</span>(train_loader)<span class="sc">:.4f}</span><span class="ss">, Train Acc: </span><span class="sc">{</span>train_accuracy<span class="sc">:.2f}</span><span class="ss">%"</span>)</span>
<span id="cb7-96"><a href="#cb7-96" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Val Loss: </span><span class="sc">{</span>val_loss<span class="op">/</span><span class="bu">len</span>(val_loader)<span class="sc">:.4f}</span><span class="ss">, Val Acc: </span><span class="sc">{</span>val_accuracy<span class="sc">:.2f}</span><span class="ss">%"</span>)</span></code></pre></div></div>
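<p>One caveat about the transform above: <code>Resize((224, 224))</code> stretches non-square images into a square. By default, the Hugging Face processor used in the feature-extraction example takes a different approach. It resizes the shorter side to 256 pixels and then center-crops 224×224, which preserves the aspect ratio. The geometry of that resize-then-crop step can be sketched in plain Python. <code>resize_then_center_crop</code> is a hypothetical helper, and its integer rounding is intended to approximate torchvision’s <code>Resize(256)</code> followed by <code>CenterCrop(224)</code>:</p>

```python
def resize_then_center_crop(width, height, resize_short=256, crop=224):
    """Return the resized (width, height) and the (left, top, right, bottom) crop box."""
    # Scale so the shorter side equals resize_short, keeping the aspect ratio
    if width <= height:
        new_w, new_h = resize_short, int(resize_short * height / width)
    else:
        new_w, new_h = int(resize_short * width / height), resize_short
    # Take a crop x crop window from the center of the resized image
    left = int(round((new_w - crop) / 2.0))
    top = int(round((new_h - crop) / 2.0))
    return (new_w, new_h), (left, top, left + crop, top + crop)


print(resize_then_center_crop(640, 480))  # ((341, 256), (58, 16, 282, 240))
```

<p>To get this behavior in the <code>ImageFolder</code> pipeline, the <code>Resize((224, 224))</code> line could be replaced with <code>transforms.Resize(256)</code> followed by <code>transforms.CenterCrop(224)</code>.</p>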
</section>
<section id="semantic-segmentation-example" class="level2">
<h2 class="anchored" data-anchor-id="semantic-segmentation-example" id="semantic-segmentation-example">Semantic Segmentation Example</h2>
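<p>DINOv2’s patch features also work well for dense tasks such as semantic segmentation. The patch tokens keep their spatial arrangement, so they can be reshaped into a 2D feature map and passed to a lightweight segmentation head. The example below applies a small convolutional head to the 16×16 patch grid of a 224×224 input. It then bilinearly upsamples the logits back to full resolution:</p>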
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb8"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn.functional <span class="im">as</span> F</span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> transformers <span class="im">import</span> AutoModel</span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> DINOv2Segmenter(nn.Module):</span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, num_classes):</span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Load DINOv2 backbone</span></span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.backbone <span class="op">=</span> AutoModel.from_pretrained(<span class="st">"facebook/dinov2-base"</span>)</span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-12"><a href="#cb8-12" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Define segmentation head</span></span>
<span id="cb8-13"><a href="#cb8-13" aria-hidden="true" tabindex="-1"></a>        hidden_dim <span class="op">=</span> <span class="va">self</span>.backbone.config.hidden_size</span>
<span id="cb8-14"><a href="#cb8-14" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.segmentation_head <span class="op">=</span> nn.Sequential(</span>
<span id="cb8-15"><a href="#cb8-15" aria-hidden="true" tabindex="-1"></a>            nn.Conv2d(hidden_dim, hidden_dim, kernel_size<span class="op">=</span><span class="dv">3</span>, padding<span class="op">=</span><span class="dv">1</span>),</span>
<span id="cb8-16"><a href="#cb8-16" aria-hidden="true" tabindex="-1"></a>            nn.BatchNorm2d(hidden_dim),</span>
<span id="cb8-17"><a href="#cb8-17" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(inplace<span class="op">=</span><span class="va">True</span>),</span>
<span id="cb8-18"><a href="#cb8-18" aria-hidden="true" tabindex="-1"></a>            nn.Conv2d(hidden_dim, num_classes, kernel_size<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb8-19"><a href="#cb8-19" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb8-20"><a href="#cb8-20" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-21"><a href="#cb8-21" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Image size and patch size for reshaping</span></span>
<span id="cb8-22"><a href="#cb8-22" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.image_size <span class="op">=</span> <span class="dv">224</span></span>
<span id="cb8-23"><a href="#cb8-23" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.patch_size <span class="op">=</span> <span class="va">self</span>.backbone.config.patch_size  <span class="co"># 14 for all DINOv2 variants</span></span>
<span id="cb8-24"><a href="#cb8-24" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.num_patches <span class="op">=</span> (<span class="va">self</span>.image_size <span class="op">//</span> <span class="va">self</span>.patch_size) <span class="op">**</span> <span class="dv">2</span></span>
<span id="cb8-25"><a href="#cb8-25" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-26"><a href="#cb8-26" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb8-27"><a href="#cb8-27" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Get patch features</span></span>
<span id="cb8-28"><a href="#cb8-28" aria-hidden="true" tabindex="-1"></a>        outputs <span class="op">=</span> <span class="va">self</span>.backbone(x)</span>
<span id="cb8-29"><a href="#cb8-29" aria-hidden="true" tabindex="-1"></a>        patch_features <span class="op">=</span> outputs.last_hidden_state[:, <span class="dv">1</span>:]  <span class="co"># Remove CLS token</span></span>
<span id="cb8-30"><a href="#cb8-30" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-31"><a href="#cb8-31" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Reshape to 2D spatial layout</span></span>
<span id="cb8-32"><a href="#cb8-32" aria-hidden="true" tabindex="-1"></a>        B <span class="op">=</span> x.shape[<span class="dv">0</span>]</span>
<span id="cb8-33"><a href="#cb8-33" aria-hidden="true" tabindex="-1"></a>        H <span class="op">=</span> W <span class="op">=</span> <span class="va">self</span>.image_size <span class="op">//</span> <span class="va">self</span>.patch_size</span>
<span id="cb8-34"><a href="#cb8-34" aria-hidden="true" tabindex="-1"></a>        patch_features <span class="op">=</span> patch_features.reshape(B, H, W, <span class="op">-</span><span class="dv">1</span>).permute(<span class="dv">0</span>, <span class="dv">3</span>, <span class="dv">1</span>, <span class="dv">2</span>)</span>
<span id="cb8-35"><a href="#cb8-35" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-36"><a href="#cb8-36" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Apply segmentation head</span></span>
<span id="cb8-37"><a href="#cb8-37" aria-hidden="true" tabindex="-1"></a>        segmentation_logits <span class="op">=</span> <span class="va">self</span>.segmentation_head(patch_features)</span>
<span id="cb8-38"><a href="#cb8-38" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-39"><a href="#cb8-39" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Upsample to original image size</span></span>
<span id="cb8-40"><a href="#cb8-40" aria-hidden="true" tabindex="-1"></a>        segmentation_logits <span class="op">=</span> F.interpolate(</span>
<span id="cb8-41"><a href="#cb8-41" aria-hidden="true" tabindex="-1"></a>            segmentation_logits, </span>
<span id="cb8-42"><a href="#cb8-42" aria-hidden="true" tabindex="-1"></a>            size<span class="op">=</span>(<span class="va">self</span>.image_size, <span class="va">self</span>.image_size), </span>
<span id="cb8-43"><a href="#cb8-43" aria-hidden="true" tabindex="-1"></a>            mode<span class="op">=</span><span class="st">'bilinear'</span>, </span>
<span id="cb8-44"><a href="#cb8-44" aria-hidden="true" tabindex="-1"></a>            align_corners<span class="op">=</span><span class="va">False</span></span>
<span id="cb8-45"><a href="#cb8-45" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb8-46"><a href="#cb8-46" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-47"><a href="#cb8-47" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> segmentation_logits</span>
<span id="cb8-48"><a href="#cb8-48" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-49"><a href="#cb8-49" aria-hidden="true" tabindex="-1"></a><span class="co"># Create model and move to device</span></span>
<span id="cb8-50"><a href="#cb8-50" aria-hidden="true" tabindex="-1"></a>device <span class="op">=</span> torch.device(<span class="st">"cuda"</span> <span class="cf">if</span> torch.cuda.is_available() <span class="cf">else</span> <span class="st">"cpu"</span>)</span>
<span id="cb8-51"><a href="#cb8-51" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> DINOv2Segmenter(num_classes<span class="op">=</span><span class="dv">21</span>)  <span class="co"># 21 classes for Pascal VOC</span></span>
<span id="cb8-52"><a href="#cb8-52" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> model.to(device)</span>
<span id="cb8-53"><a href="#cb8-53" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-54"><a href="#cb8-54" aria-hidden="true" tabindex="-1"></a><span class="co"># Define optimizer and loss function</span></span>
<span id="cb8-55"><a href="#cb8-55" aria-hidden="true" tabindex="-1"></a>optimizer <span class="op">=</span> torch.optim.AdamW([</span>
<span id="cb8-56"><a href="#cb8-56" aria-hidden="true" tabindex="-1"></a>    {<span class="st">'params'</span>: model.segmentation_head.parameters(), <span class="st">'lr'</span>: <span class="fl">1e-3</span>},</span>
<span id="cb8-57"><a href="#cb8-57" aria-hidden="true" tabindex="-1"></a>    {<span class="st">'params'</span>: model.backbone.parameters(), <span class="st">'lr'</span>: <span class="fl">1e-5</span>}</span>
<span id="cb8-58"><a href="#cb8-58" aria-hidden="true" tabindex="-1"></a>])</span>
<span id="cb8-59"><a href="#cb8-59" aria-hidden="true" tabindex="-1"></a>criterion <span class="op">=</span> nn.CrossEntropyLoss(ignore_index<span class="op">=</span><span class="dv">255</span>)  <span class="co"># Pascal VOC marks boundary/void pixels as 255</span></span>
<span id="cb8-60"><a href="#cb8-60" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-61"><a href="#cb8-61" aria-hidden="true" tabindex="-1"></a><span class="co"># Rest of the training code would be similar to the classification example</span></span></code></pre></div></div>
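<p>For segmentation, the training step differs from classification mainly in the target shape. The logits are <code>(B, num_classes, H, W)</code>, the targets are integer masks of shape <code>(B, H, W)</code>, and pixels labeled 255 are excluded from the loss. The following sketch shows one training step. It uses a hypothetical tiny convolutional model (<code>ToySegmenter</code>) in place of <code>DINOv2Segmenter</code> so that it runs without downloading weights:</p>

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


# Hypothetical stand-in for DINOv2Segmenter: maps (B, 3, H, W) images to (B, C, H, W) logits
class ToySegmenter(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, num_classes, kernel_size=1),
        )

    def forward(self, x):
        return self.net(x)


def train_step(model, optimizer, criterion, images, masks):
    model.train()
    optimizer.zero_grad()
    logits = model(images)           # (B, C, H, W)
    loss = criterion(logits, masks)  # masks: (B, H, W) int64 class indices
    loss.backward()
    optimizer.step()
    return loss.item()


num_classes = 21
model = ToySegmenter(num_classes)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss(ignore_index=255)

images = torch.randn(2, 3, 32, 32)
masks = torch.randint(0, num_classes, (2, 32, 32))
masks[:, :4, :] = 255  # these rows are marked "ignore" and contribute nothing to the loss

loss = train_step(model, optimizer, criterion, images, masks)
print(f"loss: {loss:.4f}")
```

With <code>ignore_index=255</code>, the loss is averaged only over the non-ignored pixels. You can therefore leave void regions in the masks and do not need to crop them out.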
</section>
<section id="object-detection-example" class="level2">
<h2 class="anchored" data-anchor-id="object-detection-example" id="object-detection-example">Object Detection Example</h2>
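<p>A simple way to use DINOv2 features for object detection is to attach small prediction heads to the patch tokens, so that each patch predicts one bounding box and one class distribution:</p>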
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb9"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn.functional <span class="im">as</span> F</span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> transformers <span class="im">import</span> AutoModel</span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> DINOv2Detector(nn.Module):</span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, num_classes):</span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Load DINOv2 backbone</span></span>
<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.backbone <span class="op">=</span> AutoModel.from_pretrained(<span class="st">"facebook/dinov2-base"</span>)</span>
<span id="cb9-11"><a href="#cb9-11" aria-hidden="true" tabindex="-1"></a>        hidden_dim <span class="op">=</span> <span class="va">self</span>.backbone.config.hidden_size</span>
<span id="cb9-12"><a href="#cb9-12" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-13"><a href="#cb9-13" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Detection heads</span></span>
<span id="cb9-14"><a href="#cb9-14" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.box_predictor <span class="op">=</span> nn.Sequential(</span>
<span id="cb9-15"><a href="#cb9-15" aria-hidden="true" tabindex="-1"></a>            nn.Linear(hidden_dim, hidden_dim),</span>
<span id="cb9-16"><a href="#cb9-16" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(inplace<span class="op">=</span><span class="va">True</span>),</span>
<span id="cb9-17"><a href="#cb9-17" aria-hidden="true" tabindex="-1"></a>            nn.Linear(hidden_dim, <span class="dv">4</span>)  <span class="co"># (x1, y1, x2, y2)</span></span>
<span id="cb9-18"><a href="#cb9-18" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb9-19"><a href="#cb9-19" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-20"><a href="#cb9-20" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.class_predictor <span class="op">=</span> nn.Sequential(</span>
<span id="cb9-21"><a href="#cb9-21" aria-hidden="true" tabindex="-1"></a>            nn.Linear(hidden_dim, hidden_dim),</span>
<span id="cb9-22"><a href="#cb9-22" aria-hidden="true" tabindex="-1"></a>            nn.ReLU(inplace<span class="op">=</span><span class="va">True</span>),</span>
<span id="cb9-23"><a href="#cb9-23" aria-hidden="true" tabindex="-1"></a>            nn.Linear(hidden_dim, num_classes <span class="op">+</span> <span class="dv">1</span>)  <span class="co"># +1 for background</span></span>
<span id="cb9-24"><a href="#cb9-24" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb9-25"><a href="#cb9-25" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-26"><a href="#cb9-26" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Image size and patch size for feature map creation</span></span>
<span id="cb9-27"><a href="#cb9-27" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.image_size <span class="op">=</span> <span class="dv">224</span></span>
<span id="cb9-28"><a href="#cb9-28" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.patch_size <span class="op">=</span> <span class="dv">14</span>  <span class="co"># DINOv2 uses 14x14 patches in all model variants</span></span>
<span id="cb9-29"><a href="#cb9-29" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-30"><a href="#cb9-30" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb9-31"><a href="#cb9-31" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Get features</span></span>
<span id="cb9-32"><a href="#cb9-32" aria-hidden="true" tabindex="-1"></a>        outputs <span class="op">=</span> <span class="va">self</span>.backbone(x)</span>
<span id="cb9-33"><a href="#cb9-33" aria-hidden="true" tabindex="-1"></a>        features <span class="op">=</span> outputs.last_hidden_state[:, <span class="dv">1</span>:]  <span class="co"># Remove CLS token</span></span>
<span id="cb9-34"><a href="#cb9-34" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-35"><a href="#cb9-35" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Reshape to 2D spatial layout</span></span>
<span id="cb9-36"><a href="#cb9-36" aria-hidden="true" tabindex="-1"></a>        B <span class="op">=</span> x.shape[<span class="dv">0</span>]</span>
<span id="cb9-37"><a href="#cb9-37" aria-hidden="true" tabindex="-1"></a>        H <span class="op">=</span> W <span class="op">=</span> <span class="va">self</span>.image_size <span class="op">//</span> <span class="va">self</span>.patch_size</span>
<span id="cb9-38"><a href="#cb9-38" aria-hidden="true" tabindex="-1"></a>        features <span class="op">=</span> features.reshape(B, H, W, <span class="op">-</span><span class="dv">1</span>)</span>
<span id="cb9-39"><a href="#cb9-39" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-40"><a href="#cb9-40" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Flatten for prediction heads</span></span>
<span id="cb9-41"><a href="#cb9-41" aria-hidden="true" tabindex="-1"></a>        features_flat <span class="op">=</span> features.reshape(B, <span class="op">-</span><span class="dv">1</span>, features.shape[<span class="op">-</span><span class="dv">1</span>])</span>
<span id="cb9-42"><a href="#cb9-42" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-43"><a href="#cb9-43" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Predict boxes and classes</span></span>
<span id="cb9-44"><a href="#cb9-44" aria-hidden="true" tabindex="-1"></a>        boxes <span class="op">=</span> <span class="va">self</span>.box_predictor(features_flat)</span>
<span id="cb9-45"><a href="#cb9-45" aria-hidden="true" tabindex="-1"></a>        classes <span class="op">=</span> <span class="va">self</span>.class_predictor(features_flat)</span>
<span id="cb9-46"><a href="#cb9-46" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb9-47"><a href="#cb9-47" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> {<span class="st">'boxes'</span>: boxes, <span class="st">'classes'</span>: classes, <span class="st">'features_map'</span>: features}</span>
<span id="cb9-48"><a href="#cb9-48" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-49"><a href="#cb9-49" aria-hidden="true" tabindex="-1"></a><span class="co"># Create model and move to device</span></span>
<span id="cb9-50"><a href="#cb9-50" aria-hidden="true" tabindex="-1"></a>device <span class="op">=</span> torch.device(<span class="st">"cuda"</span> <span class="cf">if</span> torch.cuda.is_available() <span class="cf">else</span> <span class="st">"cpu"</span>)</span>
<span id="cb9-51"><a href="#cb9-51" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> DINOv2Detector(num_classes<span class="op">=</span><span class="dv">80</span>)  <span class="co"># 80 classes for COCO</span></span>
<span id="cb9-52"><a href="#cb9-52" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> model.to(device)</span>
<span id="cb9-53"><a href="#cb9-53" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-54"><a href="#cb9-54" aria-hidden="true" tabindex="-1"></a><span class="co"># A full detector also needs target assignment and box/class losses for training, plus NMS at inference</span></span></code></pre></div></div>
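<p>This detector produces one box and one class distribution per patch (256 predictions for a 224x224 input), so the raw outputs must be filtered before use. The sketch below shows one way to post-process them. The head above does not fix a box convention, so the sketch makes these hypothetical assumptions:</p>
<ul>
<li>A sigmoid maps the raw box outputs to normalized <code>(x1, y1, x2, y2)</code> coordinates.</li>
<li>The last class column is background.</li>
</ul>
<p>Non-maximum suppression is done with a plain greedy per-class NMS written in PyTorch. <code>torchvision.ops.batched_nms</code> is a faster replacement for that step.</p>

```python
import torch


def box_iou(a, b):
    # a: (N, 4), b: (M, 4) boxes as (x1, y1, x2, y2); returns an (N, M) IoU matrix
    area_a = (a[:, 2] - a[:, 0]).clamp(min=0) * (a[:, 3] - a[:, 1]).clamp(min=0)
    area_b = (b[:, 2] - b[:, 0]).clamp(min=0) * (b[:, 3] - b[:, 1]).clamp(min=0)
    top_left = torch.max(a[:, None, :2], b[None, :, :2])
    bottom_right = torch.min(a[:, None, 2:], b[None, :, 2:])
    wh = (bottom_right - top_left).clamp(min=0)
    inter = wh[..., 0] * wh[..., 1]
    return inter / (area_a[:, None] + area_b[None, :] - inter + 1e-9)


def greedy_nms(boxes, scores, iou_threshold):
    # Keep the highest-scoring box, drop boxes that overlap it too much, repeat
    order = scores.argsort(descending=True)
    keep = []
    while order.numel() > 0:
        best = order[0]
        keep.append(best.item())
        rest = order[1:]
        ious = box_iou(boxes[best].unsqueeze(0), boxes[rest])[0]
        order = rest[ious.le(iou_threshold)]
    return torch.tensor(keep, dtype=torch.long)


def postprocess(outputs, score_threshold=0.5, iou_threshold=0.5):
    # Hypothetical conventions: sigmoid gives normalized box coordinates,
    # and the last class column is background
    results = []
    for raw_boxes, class_logits in zip(outputs["boxes"], outputs["classes"]):
        boxes = raw_boxes.sigmoid()
        # Sort corner coordinates so (x1, y1) is top-left and (x2, y2) bottom-right
        boxes = torch.cat([torch.min(boxes[:, :2], boxes[:, 2:]),
                           torch.max(boxes[:, :2], boxes[:, 2:])], dim=1)
        probs = class_logits.softmax(dim=-1)[:, :-1]  # drop the background column
        scores, labels = probs.max(dim=-1)
        confident = scores.ge(score_threshold)
        boxes, scores, labels = boxes[confident], scores[confident], labels[confident]
        keep = [torch.empty(0, dtype=torch.long)]
        for cls in labels.unique():
            idx = torch.nonzero(labels.eq(cls)).flatten()
            keep.append(idx[greedy_nms(boxes[idx], scores[idx], iou_threshold)])
        keep = torch.cat(keep)
        results.append({"boxes": boxes[keep], "scores": scores[keep], "labels": labels[keep]})
    return results


# Three hand-made predictions for one image: 3 object classes + background
target_boxes = torch.tensor([[0.10, 0.10, 0.50, 0.50],
                             [0.12, 0.10, 0.50, 0.52],   # near-duplicate of the first box
                             [0.60, 0.60, 0.90, 0.90]])
outputs = {
    "boxes": torch.logit(target_boxes).unsqueeze(0),     # inverse of the sigmoid above
    "classes": torch.tensor([[[5.0, 0.0, 0.0, 0.0],
                              [4.0, 0.0, 0.0, 0.0],
                              [0.0, 5.0, 0.0, 0.0]]]),
}
dets = postprocess(outputs)
print(dets[0]["labels"].tolist())  # [0, 1]
```

The near-duplicate second box is suppressed because its IoU with the first box is about 0.91. The third box survives because it belongs to a different class.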
</section>
<section id="advanced-usage-and-customization" class="level2">
<h2 class="anchored" data-anchor-id="advanced-usage-and-customization" id="advanced-usage-and-customization">Advanced Usage and Customization</h2>
<section id="custom-vision-transformer-configurations" class="level3">
<h3 class="anchored" data-anchor-id="custom-vision-transformer-configurations" id="custom-vision-transformer-configurations">Custom Vision Transformer Configurations</h3>
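<p>If you work with the official <code>facebookresearch/dinov2</code> repository instead of Hugging Face Transformers, you can build the architecture from a modified configuration. The module, preset, and key names below follow that repository's layout, so check them against your checkout:</p>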
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb10"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> dinov2.configs <span class="im">import</span> load_and_merge_config</span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> dinov2.models <span class="im">import</span> build_model_from_cfg</span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Load the default SSL config merged with the ViT-B/14 preset</span></span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a>cfg <span class="op">=</span> load_and_merge_config(<span class="st">"eval/vitb14_pretrain"</span>)</span>
<span id="cb10-6"><a href="#cb10-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-7"><a href="#cb10-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Modify configuration</span></span>
<span id="cb10-8"><a href="#cb10-8" aria-hidden="true" tabindex="-1"></a>cfg.student.drop_path_rate <span class="op">=</span> <span class="fl">0.2</span>       <span class="co"># Change stochastic depth rate</span></span>
<span id="cb10-9"><a href="#cb10-9" aria-hidden="true" tabindex="-1"></a>cfg.student.num_register_tokens <span class="op">=</span> <span class="dv">4</span>    <span class="co"># Number of register tokens (the released "reg" checkpoints use 4)</span></span>
<span id="cb10-10"><a href="#cb10-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-11"><a href="#cb10-11" aria-hidden="true" tabindex="-1"></a><span class="co"># Build student and teacher networks from the modified config</span></span>
<span id="cb10-12"><a href="#cb10-12" aria-hidden="true" tabindex="-1"></a>student, teacher, embed_dim <span class="op">=</span> build_model_from_cfg(cfg)</span></code></pre></div></div>
</section>
<section id="extracting-intermediate-features" class="level3">
<h3 class="anchored" data-anchor-id="extracting-intermediate-features" id="extracting-intermediate-features">Extracting Intermediate Features</h3>
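<p>Dense tasks such as segmentation or depth estimation often combine features from several intermediate layers. Forward hooks on the transformer blocks capture these features in a single pass. Alternatively, Hugging Face models can return every layer's output when called with <code>output_hidden_states=True</code>:</p>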
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb11"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> transformers <span class="im">import</span> AutoModel</span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> FeatureExtractor:</span>
<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, model, layers<span class="op">=</span><span class="va">None</span>):</span>
<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.model <span class="op">=</span> model.<span class="bu">eval</span>()</span>
<span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.features <span class="op">=</span> {}</span>
<span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.hooks <span class="op">=</span> []</span>
<span id="cb11-9"><a href="#cb11-9" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-10"><a href="#cb11-10" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Default to extracting from the last block if no layers specified</span></span>
<span id="cb11-11"><a href="#cb11-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.layers <span class="op">=</span> layers <span class="cf">if</span> layers <span class="kw">is</span> <span class="kw">not</span> <span class="va">None</span> <span class="cf">else</span> [<span class="dv">11</span>]  <span class="co"># Base model has 12 blocks (0-11)</span></span>
<span id="cb11-12"><a href="#cb11-12" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-13"><a href="#cb11-13" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Register a forward hook on each requested transformer block</span></span>
<span id="cb11-14"><a href="#cb11-14" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> idx <span class="kw">in</span> <span class="va">self</span>.layers:</span>
<span id="cb11-15"><a href="#cb11-15" aria-hidden="true" tabindex="-1"></a>            hook <span class="op">=</span> <span class="va">self</span>.model.encoder.layer[idx].register_forward_hook(</span>
<span id="cb11-16"><a href="#cb11-16" aria-hidden="true" tabindex="-1"></a>                <span class="kw">lambda</span> module, <span class="bu">input</span>, output, idx<span class="op">=</span>idx: <span class="va">self</span>._store(idx, output)</span>
<span id="cb11-17"><a href="#cb11-17" aria-hidden="true" tabindex="-1"></a>            )</span>
<span id="cb11-18"><a href="#cb11-18" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.hooks.append(hook)</span>
<span id="cb11-19"><a href="#cb11-19" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-20"><a href="#cb11-20" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _store(<span class="va">self</span>, idx, output):</span>
<span id="cb11-21"><a href="#cb11-21" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Depending on the transformers version, a block may return a tuple whose first element is the hidden state</span></span>
<span id="cb11-22"><a href="#cb11-22" aria-hidden="true" tabindex="-1"></a>        hidden <span class="op">=</span> output[<span class="dv">0</span>] <span class="cf">if</span> <span class="bu">isinstance</span>(output, <span class="bu">tuple</span>) <span class="cf">else</span> output</span>
<span id="cb11-23"><a href="#cb11-23" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.features[<span class="ss">f"layer_</span><span class="sc">{</span>idx<span class="sc">}</span><span class="ss">"</span>] <span class="op">=</span> hidden.detach()</span>
<span id="cb11-24"><a href="#cb11-24" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-25"><a href="#cb11-25" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__call__</span>(<span class="va">self</span>, x):</span>
<span id="cb11-26"><a href="#cb11-26" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.features.clear()</span>
<span id="cb11-27"><a href="#cb11-27" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> torch.no_grad():</span>
<span id="cb11-28"><a href="#cb11-28" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.model(x)  <span class="co"># hooks fill self.features during this call</span></span>
<span id="cb11-29"><a href="#cb11-29" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.features</span>
<span id="cb11-30"><a href="#cb11-30" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-31"><a href="#cb11-31" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> remove_hooks(<span class="va">self</span>):</span>
<span id="cb11-32"><a href="#cb11-32" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> hook <span class="kw">in</span> <span class="va">self</span>.hooks:</span>
<span id="cb11-33"><a href="#cb11-33" aria-hidden="true" tabindex="-1"></a>            hook.remove()</span>
<span id="cb11-34"><a href="#cb11-34" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-35"><a href="#cb11-35" aria-hidden="true" tabindex="-1"></a><span class="co"># Usage</span></span>
<span id="cb11-36"><a href="#cb11-36" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> AutoModel.from_pretrained(<span class="st">"facebook/dinov2-base"</span>)</span>
<span id="cb11-37"><a href="#cb11-37" aria-hidden="true" tabindex="-1"></a>extractor <span class="op">=</span> FeatureExtractor(model, layers<span class="op">=</span>[<span class="dv">3</span>, <span class="dv">7</span>, <span class="dv">11</span>])</span>
<span id="cb11-38"><a href="#cb11-38" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-39"><a href="#cb11-39" aria-hidden="true" tabindex="-1"></a><span class="co"># Extract features (replace the random tensor with a preprocessed image batch)</span></span>
<span id="cb11-40"><a href="#cb11-40" aria-hidden="true" tabindex="-1"></a>input_image <span class="op">=</span> torch.randn(<span class="dv">1</span>, <span class="dv">3</span>, <span class="dv">224</span>, <span class="dv">224</span>)</span>
<span id="cb11-41"><a href="#cb11-41" aria-hidden="true" tabindex="-1"></a>features <span class="op">=</span> extractor(input_image)</span>
<span id="cb11-42"><a href="#cb11-42" aria-hidden="true" tabindex="-1"></a>layer_3_features <span class="op">=</span> features[<span class="st">"layer_3"</span>]  <span class="co"># (1, 257, 768): CLS token + 16x16 patch tokens</span></span>
<span id="cb11-43"><a href="#cb11-43" aria-hidden="true" tabindex="-1"></a>layer_7_features <span class="op">=</span> features[<span class="st">"layer_7"</span>]</span>
<span id="cb11-44"><a href="#cb11-44" aria-hidden="true" tabindex="-1"></a>layer_11_features <span class="op">=</span> features[<span class="st">"layer_11"</span>]</span>
<span id="cb11-45"><a href="#cb11-45" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-46"><a href="#cb11-46" aria-hidden="true" tabindex="-1"></a><span class="co"># Clean up</span></span>
<span id="cb11-47"><a href="#cb11-47" aria-hidden="true" tabindex="-1"></a>extractor.remove_hooks()</span></code></pre></div></div>
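<p>Hooked block outputs and <code>last_hidden_state</code> are token sequences, not spatial maps. A small helper can strip the prefix tokens and reshape the remaining patch tokens into a <code>(B, C, H/14, W/14)</code> grid. The same helper also handles non-square inputs and checkpoints with register tokens.</p>
<p>This is a sketch that assumes the token order <code>[CLS, registers..., patches...]</code>. The function name <code>tokens_to_grid</code> and its <code>num_prefix_tokens</code> argument are hypothetical. Set <code>num_prefix_tokens</code> to 1 for the standard checkpoints, or to 1 plus the register count for register variants:</p>

```python
import torch


def tokens_to_grid(tokens, image_height, image_width, patch_size=14, num_prefix_tokens=1):
    # tokens: (B, num_prefix_tokens + num_patches, C) from a ViT block or last_hidden_state
    grid_h, grid_w = image_height // patch_size, image_width // patch_size
    patches = tokens[:, num_prefix_tokens:]
    if patches.shape[1] != grid_h * grid_w:
        raise ValueError(f"expected {grid_h * grid_w} patch tokens, got {patches.shape[1]}")
    # Patches are stored row by row: (B, N, C) -> (B, grid_h, grid_w, C) -> (B, C, grid_h, grid_w)
    return patches.reshape(tokens.shape[0], grid_h, grid_w, -1).permute(0, 3, 1, 2)


# Dummy tokens shaped like DINOv2-Base output for a 224x448 input:
# 1 CLS token + 16 x 32 patch tokens, 768 channels each
tokens = torch.randn(2, 1 + 16 * 32, 768)
grid = tokens_to_grid(tokens, 224, 448)
print(grid.shape)  # torch.Size([2, 768, 16, 32])

# A checkpoint with 4 registers has 5 prefix tokens (CLS + 4 registers)
reg_tokens = torch.randn(1, 5 + 16 * 16, 384)
reg_grid = tokens_to_grid(reg_tokens, 224, 224, num_prefix_tokens=5)
```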
</section>
</section>
<section id="performance-benchmarks" class="level2">
<h2 class="anchored" data-anchor-id="performance-benchmarks" id="performance-benchmarks">Performance Benchmarks</h2>
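<p>DINOv2 features transfer well across vision tasks, often with the backbone kept frozen. The figures below are approximate and depend strongly on the evaluation protocol (linear probe versus fine-tuning, head architecture, input resolution). Treat them as rough guides and check the DINOv2 paper and model cards for exact values:</p>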
<ul>
<li><strong>ImageNet-1K Classification</strong> (top-1 accuracy):
<ul>
<li>DINOv2-Small: ~80.0%</li>
<li>DINOv2-Base: ~84.5%</li>
<li>DINOv2-Large: ~86.3%</li>
<li>DINOv2-Giant: ~87.0%</li>
</ul></li>
<li><strong>Semantic Segmentation (ADE20K)</strong> (mIoU):
<ul>
<li>DINOv2-Small: ~47.5%</li>
<li>DINOv2-Base: ~50.2%</li>
<li>DINOv2-Large: ~52.5%</li>
<li>DINOv2-Giant: ~53.8%</li>
</ul></li>
<li><strong>Object Detection (COCO)</strong> (AP):
<ul>
<li>DINOv2-Small: ~48.5%</li>
<li>DINOv2-Base: ~51.3%</li>
<li>DINOv2-Large: ~53.2%</li>
<li>DINOv2-Giant: ~54.5%</li>
</ul></li>
</ul>
</section>
<section id="troubleshooting" class="level2">
<h2 class="anchored" data-anchor-id="troubleshooting" id="troubleshooting">Troubleshooting</h2>
<section id="common-issues-and-solutions" class="level3">
<h3 class="anchored" data-anchor-id="common-issues-and-solutions" id="common-issues-and-solutions">Common Issues and Solutions</h3>
<ol type="1">
<li><strong>Out of Memory Errors</strong>
<ul>
<li>Reduce batch size</li>
<li>Use gradient accumulation</li>
<li>Use a smaller model variant (Small or Base)</li>
<li>Use mixed precision training</li>
</ul></li>
</ol>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb12"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><a href="#cb12-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Example of mixed precision training (torch.amp API, PyTorch 2.3+)</span></span>
<span id="cb12-2"><a href="#cb12-2" aria-hidden="true" tabindex="-1"></a><span class="co"># Assumes model, optimizer, criterion, train_loader and device are already defined</span></span>
<span id="cb12-3"><a href="#cb12-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb12-4"><a href="#cb12-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-5"><a href="#cb12-5" aria-hidden="true" tabindex="-1"></a>scaler <span class="op">=</span> torch.amp.GradScaler(<span class="st">"cuda"</span>)</span>
<span id="cb12-6"><a href="#cb12-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-7"><a href="#cb12-7" aria-hidden="true" tabindex="-1"></a>model.train()</span>
<span id="cb12-8"><a href="#cb12-8" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> inputs, targets <span class="kw">in</span> train_loader:</span>
<span id="cb12-9"><a href="#cb12-9" aria-hidden="true" tabindex="-1"></a>    inputs, targets <span class="op">=</span> inputs.to(device), targets.to(device)</span>
<span id="cb12-10"><a href="#cb12-10" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-11"><a href="#cb12-11" aria-hidden="true" tabindex="-1"></a>    optimizer.zero_grad()</span>
<span id="cb12-12"><a href="#cb12-12" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-13"><a href="#cb12-13" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Run the forward pass and loss in float16 where autocast considers it safe</span></span>
<span id="cb12-14"><a href="#cb12-14" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> torch.autocast(device_type<span class="op">=</span><span class="st">"cuda"</span>, dtype<span class="op">=</span>torch.float16):</span>
<span id="cb12-15"><a href="#cb12-15" aria-hidden="true" tabindex="-1"></a>        outputs <span class="op">=</span> model(inputs)</span>
<span id="cb12-16"><a href="#cb12-16" aria-hidden="true" tabindex="-1"></a>        loss <span class="op">=</span> criterion(outputs, targets)</span>
<span id="cb12-17"><a href="#cb12-17" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-18"><a href="#cb12-18" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Scale the loss to avoid float16 gradient underflow, then step and update the scale factor</span></span>
<span id="cb12-19"><a href="#cb12-19" aria-hidden="true" tabindex="-1"></a>    scaler.scale(loss).backward()</span>
<span id="cb12-20"><a href="#cb12-20" aria-hidden="true" tabindex="-1"></a>    scaler.step(optimizer)</span>
<span id="cb12-21"><a href="#cb12-21" aria-hidden="true" tabindex="-1"></a>    scaler.update()</span></code></pre></div></div>
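<p>Gradient accumulation, listed above, trades time for memory. Gradients from several small batches are summed before a single optimizer step, which approximates training with a larger batch. The sketch below uses a hypothetical toy model and random data. For equal-sized micro-batches with a mean-reduced loss, dividing each loss by <code>accum_steps</code> makes the accumulated gradient match the average over the effective batch:</p>

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

model = nn.Linear(8, 3)  # hypothetical stand-in for a DINOv2-based model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Eight micro-batches of 4 samples each (random data for illustration)
micro_batches = [(torch.randn(4, 8), torch.randint(0, 3, (4,))) for _ in range(8)]
accum_steps = 4  # effective batch size = 4 samples x 4 micro-batches = 16

model.train()
optimizer.zero_grad()
optimizer_steps = 0
for step, (inputs, targets) in enumerate(micro_batches, start=1):
    loss = criterion(model(inputs), targets) / accum_steps
    loss.backward()  # gradients add up in each parameter's .grad across iterations
    if step % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
        optimizer_steps += 1

print(optimizer_steps)  # 2
```

This combines with the mixed precision loop above. Call <code>scaler.scale(loss).backward()</code> on every micro-batch, and call <code>scaler.step</code> and <code>scaler.update</code> only on the accumulation boundary.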
<ol start="2" type="1">
<li><strong>Slow Inference</strong>
<ul>
<li>Use batch processing</li>
<li>Use <code>model.eval()</code> together with <code>torch.no_grad()</code> or <code>torch.inference_mode()</code></li>
<li>Consider model distillation or quantization</li>
</ul></li>
<li><strong>Poor Performance on Downstream Tasks</strong>
<ul>
<li>Ensure proper data preprocessing</li>
<li>Adjust learning rates (lower for backbone, higher for heads)</li>
<li>Use appropriate augmentations</li>
<li>Consider using a larger variant of DINOv2</li>
</ul></li>
</ol>
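<p>The batching and no-gradient tips for slow inference can be combined in a small helper. It switches the model to eval mode, disables autograd with <code>torch.inference_mode()</code> (a stricter, usually faster alternative to <code>torch.no_grad()</code>), and processes inputs in fixed-size chunks. The <code>embed_in_batches</code> name and the toy encoder are hypothetical. With a Hugging Face DINOv2 model you would instead take an embedding from its outputs, for example the CLS token in <code>last_hidden_state[:, 0]</code>:</p>

```python
import torch
import torch.nn as nn


def embed_in_batches(model, images, batch_size=32, device="cpu"):
    # images: (N, 3, H, W) tensor; returns (N, D) embeddings on the CPU
    model.eval()
    model.to(device)
    outputs = []
    with torch.inference_mode():
        for start in range(0, images.shape[0], batch_size):
            batch = images[start:start + batch_size].to(device)
            outputs.append(model(batch).cpu())
    return torch.cat(outputs)


# Hypothetical toy encoder standing in for a DINOv2 backbone
toy_encoder = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=14, stride=14),  # 14x14 patch embedding
    nn.AdaptiveAvgPool2d(1),                     # average over the 16x16 patch grid
    nn.Flatten(),
)

images = torch.randn(10, 3, 224, 224)
embeddings = embed_in_batches(toy_encoder, images, batch_size=4)
print(embeddings.shape)  # torch.Size([10, 8])
```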
</section>
<section id="debugging-tips" class="level3">
<h3 class="anchored" data-anchor-id="debugging-tips" id="debugging-tips">Debugging Tips</h3>
<ul>
<li>Visualize model attention maps to understand what the model focuses on:</li>
</ul>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb13"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn.functional <span class="im">as</span> F</span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> matplotlib.pyplot <span class="im">as</span> plt</span>
<span id="cb13-4"><a href="#cb13-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> PIL <span class="im">import</span> Image</span>
<span id="cb13-5"><a href="#cb13-5" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torchvision.transforms <span class="im">as</span> T</span>
<span id="cb13-6"><a href="#cb13-6" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> transformers <span class="im">import</span> AutoModel</span>
<span id="cb13-7"><a href="#cb13-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-8"><a href="#cb13-8" aria-hidden="true" tabindex="-1"></a>device <span class="op">=</span> torch.device(<span class="st">"cuda"</span> <span class="cf">if</span> torch.cuda.is_available() <span class="cf">else</span> <span class="st">"cpu"</span>)</span>
<span id="cb13-9"><a href="#cb13-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-10"><a href="#cb13-10" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> get_attention_map(model, img_tensor):</span>
<span id="cb13-11"><a href="#cb13-11" aria-hidden="true" tabindex="-1"></a>    model.<span class="bu">eval</span>()</span>
<span id="cb13-12"><a href="#cb13-12" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> torch.no_grad():</span>
<span id="cb13-13"><a href="#cb13-13" aria-hidden="true" tabindex="-1"></a>        outputs <span class="op">=</span> model(img_tensor.unsqueeze(<span class="dv">0</span>), output_attentions<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb13-14"><a href="#cb13-14" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-15"><a href="#cb13-15" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Attention weights from the last layer: (batch, heads, tokens, tokens)</span></span>
<span id="cb13-16"><a href="#cb13-16" aria-hidden="true" tabindex="-1"></a>    att_mat <span class="op">=</span> outputs.attentions[<span class="op">-</span><span class="dv">1</span>]</span>
<span id="cb13-17"><a href="#cb13-17" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-18"><a href="#cb13-18" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Average attention across heads</span></span>
<span id="cb13-19"><a href="#cb13-19" aria-hidden="true" tabindex="-1"></a>    att_mat <span class="op">=</span> att_mat.mean(dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb13-20"><a href="#cb13-20" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-21"><a href="#cb13-21" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Attention from the CLS token to each patch token, arranged on the patch grid</span></span>
<span id="cb13-22"><a href="#cb13-22" aria-hidden="true" tabindex="-1"></a>    cls_att <span class="op">=</span> att_mat[<span class="dv">0</span>, <span class="dv">0</span>, <span class="dv">1</span>:]</span>
<span id="cb13-23"><a href="#cb13-23" aria-hidden="true" tabindex="-1"></a>    side <span class="op">=</span> <span class="bu">int</span>(cls_att.numel() <span class="op">**</span> <span class="fl">0.5</span>)  <span class="co"># 16 for a 224x224 input with 14x14 patches</span></span>
<span id="cb13-24"><a href="#cb13-24" aria-hidden="true" tabindex="-1"></a>    cls_att_map <span class="op">=</span> cls_att.reshape(<span class="dv">1</span>, <span class="dv">1</span>, side, side)</span>
<span id="cb13-25"><a href="#cb13-25" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-26"><a href="#cb13-26" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Upsample the patch grid to the input resolution so it lines up with the image</span></span>
<span id="cb13-27"><a href="#cb13-27" aria-hidden="true" tabindex="-1"></a>    cls_att_map <span class="op">=</span> F.interpolate(cls_att_map, size<span class="op">=</span>img_tensor.shape[<span class="op">-</span><span class="dv">2</span>:], mode<span class="op">=</span><span class="st">'bilinear'</span>, align_corners<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb13-28"><a href="#cb13-28" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-29"><a href="#cb13-29" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> cls_att_map[<span class="dv">0</span>, <span class="dv">0</span>].cpu().numpy()</span>
<span id="cb13-30"><a href="#cb13-30" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-31"><a href="#cb13-31" aria-hidden="true" tabindex="-1"></a><span class="co"># Load and preprocess image</span></span>
<span id="cb13-32"><a href="#cb13-32" aria-hidden="true" tabindex="-1"></a>image <span class="op">=</span> Image.<span class="bu">open</span>(<span class="st">"path/to/image.jpg"</span>).convert(<span class="st">"RGB"</span>)</span>
<span id="cb13-33"><a href="#cb13-33" aria-hidden="true" tabindex="-1"></a>transform <span class="op">=</span> T.Compose([</span>
<span id="cb13-34"><a href="#cb13-34" aria-hidden="true" tabindex="-1"></a>    T.Resize((<span class="dv">224</span>, <span class="dv">224</span>)),</span>
<span id="cb13-35"><a href="#cb13-35" aria-hidden="true" tabindex="-1"></a>    T.ToTensor(),</span>
<span id="cb13-36"><a href="#cb13-36" aria-hidden="true" tabindex="-1"></a>    T.Normalize(mean<span class="op">=</span>[<span class="fl">0.485</span>, <span class="fl">0.456</span>, <span class="fl">0.406</span>], std<span class="op">=</span>[<span class="fl">0.229</span>, <span class="fl">0.224</span>, <span class="fl">0.225</span>]),</span>
<span id="cb13-37"><a href="#cb13-37" aria-hidden="true" tabindex="-1"></a>])</span>
<span id="cb13-38"><a href="#cb13-38" aria-hidden="true" tabindex="-1"></a>img_tensor <span class="op">=</span> transform(image).to(device)</span>
<span id="cb13-39"><a href="#cb13-39" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-40"><a href="#cb13-40" aria-hidden="true" tabindex="-1"></a><span class="co"># Load the model with eager attention so it can return attention weights</span></span>
<span id="cb13-41"><a href="#cb13-41" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> AutoModel.from_pretrained(<span class="st">"facebook/dinov2-base"</span>, attn_implementation<span class="op">=</span><span class="st">"eager"</span>)</span>
<span id="cb13-42"><a href="#cb13-42" aria-hidden="true" tabindex="-1"></a>model.to(device)</span>
<span id="cb13-43"><a href="#cb13-43" aria-hidden="true" tabindex="-1"></a>attention_map <span class="op">=</span> get_attention_map(model, img_tensor)</span>
<span id="cb13-44"><a href="#cb13-44" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-45"><a href="#cb13-45" aria-hidden="true" tabindex="-1"></a><span class="co"># Visualize</span></span>
<span id="cb13-46"><a href="#cb13-46" aria-hidden="true" tabindex="-1"></a>plt.figure(figsize<span class="op">=</span>(<span class="dv">10</span>, <span class="dv">10</span>))</span>
<span id="cb13-47"><a href="#cb13-47" aria-hidden="true" tabindex="-1"></a>plt.imshow(image.resize((<span class="dv">224</span>, <span class="dv">224</span>)))</span>
<span id="cb13-48"><a href="#cb13-48" aria-hidden="true" tabindex="-1"></a>plt.imshow(attention_map, alpha<span class="op">=</span><span class="fl">0.5</span>, cmap<span class="op">=</span><span class="st">'jet'</span>)</span>
<span id="cb13-49"><a href="#cb13-49" aria-hidden="true" tabindex="-1"></a>plt.axis(<span class="st">'off'</span>)</span>
<span id="cb13-50"><a href="#cb13-50" aria-hidden="true" tabindex="-1"></a>plt.colorbar()</span>
<span id="cb13-51"><a href="#cb13-51" aria-hidden="true" tabindex="-1"></a>plt.savefig(<span class="st">'attention_map.png'</span>)</span>
<span id="cb13-52"><a href="#cb13-52" aria-hidden="true" tabindex="-1"></a>plt.close()</span></code></pre></div></div>
<p>This guide should help you get started with DINOv2 and explore its capabilities across a range of computer vision tasks. Because DINOv2 is a self-supervised vision foundation model, its pretrained features are a strong starting point for many downstream applications, often with only a small amount of labeled data.</p>


</section>
</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Vision Transformer (ViT) Implementation Guide]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/models/vision-transformers/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/models/vision-transformers/</guid>
      <pubDate>Sat, 26 Apr 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>tutorial</category>
      <category>advanced</category>
      <content:encoded><![CDATA[






<section id="vision-transformer-vit-implementation-guide" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/models/vision-transformers/banner.png" class="img-fluid"></p>
<p>Vision Transformers (ViT) represent a significant paradigm shift in computer vision, applying the transformer architecture initially developed for NLP to image processing tasks. This guide walks through implementing a Vision Transformer from scratch using PyTorch.</p>
<section id="introduction-to-vision-transformers" class="level2">
<h2 class="anchored" data-anchor-id="introduction-to-vision-transformers" id="introduction-to-vision-transformers">1. Introduction to Vision Transformers</h2>
<p>Vision Transformers (ViT) were introduced in the paper “An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale” by Dosovitskiy et al.&nbsp;in 2020. The core idea is to treat an image as a sequence of patches, similar to how words are treated in NLP, and process them using a transformer encoder.</p>
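<p>To make the “image as a sequence of patches” idea concrete, here is a minimal sketch that cuts a batch of images into non-overlapping P×P patches using only <code>reshape</code> and <code>permute</code>, then flattens each patch into a vector. The helper name <code>image_to_patches</code> is hypothetical and not from the paper. With a 224×224 RGB image and P = 16, the result is (224/16)² = 196 tokens, each of length 16·16·3 = 768.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import torch


def image_to_patches(images, patch_size):
    """Split (B, C, H, W) images into flattened, non-overlapping patches.

    Hypothetical helper. Returns (B, num_patches, C * patch_size * patch_size),
    with patches ordered row by row (left to right, top to bottom).
    """
    b, c, h, w = images.shape
    assert h % patch_size == 0 and w % patch_size == 0, "image size must be divisible by patch size"
    gh, gw = h // patch_size, w // patch_size
    # (B, C, gh, P, gw, P): split each spatial axis into (grid index, offset inside patch)
    x = images.reshape(b, c, gh, patch_size, gw, patch_size)
    # (B, gh, gw, C, P, P): move the grid axes forward so each patch is grouped together
    x = x.permute(0, 2, 4, 1, 3, 5)
    # (B, gh * gw, C * P * P): one flattened vector per patch
    return x.reshape(b, gh * gw, c * patch_size * patch_size)


images = torch.randn(2, 3, 224, 224)
patches = image_to_patches(images, patch_size=16)
print(patches.shape)  # torch.Size([2, 196, 768])</code></pre></div></div>
<p>Each of the 196 rows now plays the role a word embedding plays in NLP; the rest of the model only needs a linear projection and position information on top of this sequence.</p>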
<p>Key advantages of ViTs include:</p>
<ul>
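<li>A global receptive field from the first layer: every patch can attend to every other patch</li>
<li>Ability to capture long-range dependencies across the whole image</li>
<li>Accuracy that keeps improving as pre-training data and model size grow</li>
<li>Fewer built-in assumptions about locality than CNNs, so spatial relationships are learned from data (the trade-off is that ViTs typically need large-scale pre-training to match CNNs)</li>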
</ul>
</section>
<section id="understanding-the-architecture" class="level2">
<h2 class="anchored" data-anchor-id="understanding-the-architecture" id="understanding-the-architecture">2. Understanding the Architecture</h2>
<p>The ViT architecture consists of the following components:</p>
<ol type="1">
<li><strong>Image Patching</strong>: Dividing the input image into fixed-size patches</li>
<li><strong>Patch Embedding</strong>: Linear projection of flattened patches</li>
<li><strong>Position Embedding</strong>: Adding positional information</li>
<li><strong>Transformer Encoder</strong>: Self-attention and feed-forward layers</li>
<li><strong>MLP Head</strong>: Final classification layer</li>
</ol>
<p><img src="https://theja-vanka.github.io/blogs/posts/models/vision-transformers/vit.png" class="img-fluid"></p>
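<p>Before building each piece by hand, it can help to see how the five components fit together end to end. The sketch below is <em>not</em> the implementation developed in this guide: it swaps in PyTorch’s built-in <code>nn.TransformerEncoderLayer</code> (with <code>norm_first=True</code> for the pre-norm layout and <code>batch_first=True</code>, both assumed available in a recent PyTorch release) as a stand-in for the hand-written encoder. The class name <code>TinyViT</code> and its small default sizes are illustrative assumptions.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import torch
import torch.nn as nn


class TinyViT(nn.Module):
    """Compact ViT sketch built from PyTorch's stock encoder layer (illustrative only)."""

    def __init__(self, image_size=32, patch_size=8, in_channels=3, num_classes=10,
                 embed_dim=64, depth=2, num_heads=4, mlp_ratio=4):
        super().__init__()
        num_patches = (image_size // patch_size) ** 2
        # Steps 1 + 2: image patching and patch embedding in one strided convolution
        self.patch_embed = nn.Conv2d(in_channels, embed_dim,
                                     kernel_size=patch_size, stride=patch_size)
        # Learnable class token, and step 3: one position embedding per token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        # Step 4: stack of pre-norm encoder blocks (self-attention + GELU MLP)
        self.blocks = nn.Sequential(*[
            nn.TransformerEncoderLayer(embed_dim, num_heads,
                                       dim_feedforward=embed_dim * mlp_ratio,
                                       dropout=0.0, activation="gelu",
                                       batch_first=True, norm_first=True)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        # Step 5: classification head (a single linear layer here)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        x = self.patch_embed(x).flatten(2).transpose(1, 2)   # (B, N, E)
        cls = self.cls_token.expand(x.shape[0], -1, -1)       # (B, 1, E)
        x = torch.cat((cls, x), dim=1) + self.pos_embed       # (B, N + 1, E)
        x = self.norm(self.blocks(x))                          # (B, N + 1, E)
        return self.head(x[:, 0])                              # (B, num_classes) from the class token


model = TinyViT()
logits = model(torch.randn(4, 3, 32, 32))
print(logits.shape)  # torch.Size([4, 10])</code></pre></div></div>
<p>The rest of this guide replaces each built-in piece with an explicit implementation, which makes the attention computation and tensor shapes visible.</p>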
</section>
<section id="implementation" class="level2">
<h2 class="anchored" data-anchor-id="implementation" id="implementation">3. Implementation</h2>
<p>Let’s implement each component of the Vision Transformer step by step.</p>
<section id="image-patching" class="level3">
<h3 class="anchored" data-anchor-id="image-patching" id="image-patching">Image Patching</h3>
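<p>First, we divide the input image into fixed-size, non-overlapping patches. A typical ViT uses 16×16-pixel patches, so a 224×224 image becomes a 14×14 grid of (224/16)² = 196 patches. The implementation below performs the patching and the linear projection in one step, using a <code>Conv2d</code> whose kernel size and stride both equal the patch size.</p>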
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn.functional <span class="im">as</span> F</span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> PatchEmbedding(nn.Module):</span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, image_size, patch_size, in_channels, embed_dim):</span>
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.image_size <span class="op">=</span> image_size</span>
<span id="cb1-9"><a href="#cb1-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.patch_size <span class="op">=</span> patch_size</span>
<span id="cb1-10"><a href="#cb1-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.num_patches <span class="op">=</span> (image_size <span class="op">//</span> patch_size) <span class="op">**</span> <span class="dv">2</span></span>
<span id="cb1-11"><a href="#cb1-11" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-12"><a href="#cb1-12" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Convert image into patches and embed them</span></span>
<span id="cb1-13"><a href="#cb1-13" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Instead of using einops, we'll use standard PyTorch operations</span></span>
<span id="cb1-14"><a href="#cb1-14" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.projection <span class="op">=</span> nn.Conv2d(</span>
<span id="cb1-15"><a href="#cb1-15" aria-hidden="true" tabindex="-1"></a>            in_channels, embed_dim, </span>
<span id="cb1-16"><a href="#cb1-16" aria-hidden="true" tabindex="-1"></a>            kernel_size<span class="op">=</span>patch_size, stride<span class="op">=</span>patch_size</span>
<span id="cb1-17"><a href="#cb1-17" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb1-18"><a href="#cb1-18" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-19"><a href="#cb1-19" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb1-20"><a href="#cb1-20" aria-hidden="true" tabindex="-1"></a>        <span class="co"># x: (batch_size, channels, height, width)</span></span>
<span id="cb1-21"><a href="#cb1-21" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Convert image into patches using convolution</span></span>
<span id="cb1-22"><a href="#cb1-22" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.projection(x)  <span class="co"># (batch_size, embed_dim, grid_height, grid_width)</span></span>
<span id="cb1-23"><a href="#cb1-23" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-24"><a href="#cb1-24" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Flatten spatial dimensions and transpose to get </span></span>
<span id="cb1-25"><a href="#cb1-25" aria-hidden="true" tabindex="-1"></a>        <span class="co"># (batch_size, num_patches, embed_dim)</span></span>
<span id="cb1-27"><a href="#cb1-27" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> x.flatten(<span class="dv">2</span>)  <span class="co"># (batch_size, embed_dim, num_patches)</span></span>
<span id="cb1-28"><a href="#cb1-28" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> x.transpose(<span class="dv">1</span>, <span class="dv">2</span>)  <span class="co"># (batch_size, num_patches, embed_dim)</span></span>
<span id="cb1-29"><a href="#cb1-29" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb1-30"><a href="#cb1-30" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x</span></code></pre></div></div>
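<p>Why does a convolution implement patch embedding? Because the kernel size equals the stride, each output position of the <code>Conv2d</code> sees exactly one patch and applies the same weights to it, which is the same as flattening every patch and multiplying by one shared weight matrix. The self-contained check below (sizes and variable names are illustrative) copies the convolution’s weights into a plain linear map over explicitly flattened patches and confirms the two paths agree up to floating-point error.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, C, H, W, P, E = 2, 3, 32, 32, 8, 24

conv = nn.Conv2d(C, E, kernel_size=P, stride=P)
images = torch.randn(B, C, H, W)

# Path 1: strided convolution, as in PatchEmbedding above
conv_tokens = conv(images).flatten(2).transpose(1, 2)  # (B, num_patches, E)

# Path 2: cut out patches explicitly, flatten each one in (C, P, P) order,
# then apply a linear map that reuses the convolution's weights and bias
patches = images.reshape(B, C, H // P, P, W // P, P).permute(0, 2, 4, 1, 3, 5)
patches = patches.reshape(B, (H // P) * (W // P), C * P * P)  # (B, num_patches, C*P*P)
linear_tokens = F.linear(patches, conv.weight.reshape(E, -1), conv.bias)

print(conv_tokens.shape)                                       # torch.Size([2, 16, 24])
print(torch.allclose(conv_tokens, linear_tokens, atol=1e-5))  # True</code></pre></div></div>
<p>The convolution form is usually preferred because it avoids the manual reshaping and runs as a single optimized kernel.</p>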
</section>
<section id="patch-embedding" class="level3">
<h3 class="anchored" data-anchor-id="patch-embedding" id="patch-embedding">Patch Embedding</h3>
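<p>After patching, we prepend a learnable class token to the patch sequence and create learnable position embeddings for all <code>num_patches + 1</code> positions. The final hidden state of the class token is what the classification head reads. We start the main <code>VisionTransformer</code> class with these parameters:</p>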
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> VisionTransformer(nn.Module):</span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, image_size, patch_size, in_channels, num_classes, </span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a>                 embed_dim, depth, num_heads, mlp_ratio<span class="op">=</span><span class="dv">4</span>, </span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a>                 dropout_rate<span class="op">=</span><span class="fl">0.1</span>):</span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.patch_embedding <span class="op">=</span> PatchEmbedding(</span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a>            image_size, patch_size, in_channels, embed_dim)</span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.num_patches <span class="op">=</span> <span class="va">self</span>.patch_embedding.num_patches</span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-10"><a href="#cb2-10" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Class token</span></span>
<span id="cb2-11"><a href="#cb2-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.cls_token <span class="op">=</span> nn.Parameter(torch.zeros(<span class="dv">1</span>, <span class="dv">1</span>, embed_dim))</span>
<span id="cb2-12"><a href="#cb2-12" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-13"><a href="#cb2-13" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Position embedding for patches + class token</span></span>
<span id="cb2-14"><a href="#cb2-14" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.pos_embedding <span class="op">=</span> nn.Parameter(</span>
<span id="cb2-15"><a href="#cb2-15" aria-hidden="true" tabindex="-1"></a>            torch.zeros(<span class="dv">1</span>, <span class="va">self</span>.num_patches <span class="op">+</span> <span class="dv">1</span>, embed_dim))</span>
<span id="cb2-16"><a href="#cb2-16" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb2-17"><a href="#cb2-17" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.dropout <span class="op">=</span> nn.Dropout(dropout_rate)</span></code></pre></div></div>
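<p>Both parameters above start at zero, which works because they are learned during training. Many ViT implementations instead initialize them from a small truncated normal distribution. The sketch below shows that alternative; the helper name <code>init_vit_embeddings</code> and <code>std=0.02</code> are assumptions for illustration, not settings taken from the original paper. You could call it at the end of <code>__init__</code> as <code>init_vit_embeddings(self.cls_token, self.pos_embedding)</code>.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import torch
import torch.nn as nn


def init_vit_embeddings(cls_token, pos_embedding, std=0.02):
    """Re-initialize a ViT's class token and position embeddings in place.

    Hypothetical helper: draws values from a normal distribution with the given
    standard deviation, truncated by nn.init.trunc_normal_ (default bounds [-2, 2]).
    """
    with torch.no_grad():
        nn.init.trunc_normal_(cls_token, std=std)
        nn.init.trunc_normal_(pos_embedding, std=std)


# Same shapes as in the class above for a 224x224 image, 16x16 patches, embed_dim=768
cls_token = nn.Parameter(torch.zeros(1, 1, 768))
pos_embedding = nn.Parameter(torch.zeros(1, 196 + 1, 768))
init_vit_embeddings(cls_token, pos_embedding)
print(round(pos_embedding.std().item(), 3))  # approximately 0.02</code></pre></div></div>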
</section>
<section id="position-embedding" class="level3">
<h3 class="anchored" data-anchor-id="position-embedding" id="position-embedding">Position Embedding</h3>
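<p>Self-attention treats its input as an unordered set (permuting the input tokens just permutes the outputs), so without extra information the encoder cannot tell where each patch came from. The <code>forward</code> method of <code>VisionTransformer</code> therefore prepends the class token to the patch embeddings and adds the learned position embeddings:</p>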
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Get patch embeddings</span></span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.patch_embedding(x)  <span class="co"># (B, num_patches, embed_dim)</span></span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Add class token</span></span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a>        batch_size <span class="op">=</span> x.shape[<span class="dv">0</span>]</span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a>        cls_tokens <span class="op">=</span> <span class="va">self</span>.cls_token.expand(batch_size, <span class="op">-</span><span class="dv">1</span>, <span class="op">-</span><span class="dv">1</span>)  <span class="co"># (B, 1, embed_dim)</span></span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> torch.cat((cls_tokens, x), dim<span class="op">=</span><span class="dv">1</span>)  <span class="co"># (B, num_patches + 1, embed_dim)</span></span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Add position embedding</span></span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> x <span class="op">+</span> <span class="va">self</span>.pos_embedding</span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.dropout(x)</span></code></pre></div></div>
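<p>The shape bookkeeping in this step is easy to get wrong, so here is a standalone check with dummy tensors (sizes assumed for a 224×224 image, 16×16 patches, and <code>embed_dim=768</code>; plain tensors with <code>requires_grad=True</code> stand in for the <code>nn.Parameter</code>s). It also shows that <code>expand</code> creates a view rather than a copy, so every image in the batch shares, and sends gradients to, the same learned class token, while the (1, 197, 768) position embedding is broadcast across the batch.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import torch

batch_size, num_patches, embed_dim = 8, 196, 768

patch_tokens = torch.randn(batch_size, num_patches, embed_dim)       # stands in for PatchEmbedding output
cls_token = torch.zeros(1, 1, embed_dim, requires_grad=True)         # stands in for the nn.Parameter
pos_embedding = torch.zeros(1, num_patches + 1, embed_dim, requires_grad=True)

cls_tokens = cls_token.expand(batch_size, -1, -1)                    # (8, 1, 768), a view, no copy
tokens = torch.cat((cls_tokens, patch_tokens), dim=1)               # (8, 197, 768)
tokens = tokens + pos_embedding                                      # broadcast over the batch
print(tokens.shape)  # torch.Size([8, 197, 768])

# Gradients from all 8 images accumulate into the single shared class token
tokens[:, 0].sum().backward()
print(cls_token.grad.sum().item())  # 6144.0 (8 images * 768 dims, each contributing 1)</code></pre></div></div>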
</section>
<section id="transformer-encoder" class="level3">
<h3 class="anchored" data-anchor-id="transformer-encoder" id="transformer-encoder">Transformer Encoder</h3>
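<p>Next, we implement the transformer encoder block. Following the ViT paper, each block uses a pre-norm layout: LayerNorm → multi-head self-attention → residual connection, then LayerNorm → MLP with a GELU activation → residual connection.</p>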
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MultiHeadAttention(nn.Module):</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, embed_dim, num_heads, dropout_rate<span class="op">=</span><span class="fl">0.1</span>):</span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.num_heads <span class="op">=</span> num_heads</span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.head_dim <span class="op">=</span> embed_dim <span class="op">//</span> num_heads</span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.scale <span class="op">=</span> <span class="va">self</span>.head_dim <span class="op">**</span> <span class="op">-</span><span class="fl">0.5</span></span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.qkv <span class="op">=</span> nn.Linear(embed_dim, embed_dim <span class="op">*</span> <span class="dv">3</span>)</span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.proj <span class="op">=</span> nn.Linear(embed_dim, embed_dim)</span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.dropout <span class="op">=</span> nn.Dropout(dropout_rate)</span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a>        batch_size, seq_len, embed_dim <span class="op">=</span> x.shape</span>
<span id="cb4-14"><a href="#cb4-14" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-15"><a href="#cb4-15" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Get query, key, and value projections</span></span>
<span id="cb4-16"><a href="#cb4-16" aria-hidden="true" tabindex="-1"></a>        qkv <span class="op">=</span> <span class="va">self</span>.qkv(x)</span>
<span id="cb4-17"><a href="#cb4-17" aria-hidden="true" tabindex="-1"></a>        qkv <span class="op">=</span> qkv.reshape(batch_size, seq_len, <span class="dv">3</span>, <span class="va">self</span>.num_heads, <span class="va">self</span>.head_dim)</span>
<span id="cb4-18"><a href="#cb4-18" aria-hidden="true" tabindex="-1"></a>        qkv <span class="op">=</span> qkv.permute(<span class="dv">2</span>, <span class="dv">0</span>, <span class="dv">3</span>, <span class="dv">1</span>, <span class="dv">4</span>)  <span class="co"># (3, B, H, N, D)</span></span>
<span id="cb4-19"><a href="#cb4-19" aria-hidden="true" tabindex="-1"></a>        q, k, v <span class="op">=</span> qkv[<span class="dv">0</span>], qkv[<span class="dv">1</span>], qkv[<span class="dv">2</span>]  <span class="co"># Each is (B, H, N, D)</span></span>
<span id="cb4-20"><a href="#cb4-20" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-21"><a href="#cb4-21" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Attention</span></span>
<span id="cb4-22"><a href="#cb4-22" aria-hidden="true" tabindex="-1"></a>        attn <span class="op">=</span> (q <span class="op">@</span> k.transpose(<span class="op">-</span><span class="dv">2</span>, <span class="op">-</span><span class="dv">1</span>)) <span class="op">*</span> <span class="va">self</span>.scale  <span class="co"># (B, H, N, N)</span></span>
<span id="cb4-23"><a href="#cb4-23" aria-hidden="true" tabindex="-1"></a>        attn <span class="op">=</span> attn.softmax(dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb4-24"><a href="#cb4-24" aria-hidden="true" tabindex="-1"></a>        attn <span class="op">=</span> <span class="va">self</span>.dropout(attn)</span>
<span id="cb4-25"><a href="#cb4-25" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-26"><a href="#cb4-26" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Apply attention to values</span></span>
<span id="cb4-27"><a href="#cb4-27" aria-hidden="true" tabindex="-1"></a>        out <span class="op">=</span> (attn <span class="op">@</span> v).transpose(<span class="dv">1</span>, <span class="dv">2</span>)  <span class="co"># (B, N, H, D)</span></span>
<span id="cb4-28"><a href="#cb4-28" aria-hidden="true" tabindex="-1"></a>        out <span class="op">=</span> out.reshape(batch_size, seq_len, embed_dim)  <span class="co"># (B, N, E)</span></span>
<span id="cb4-29"><a href="#cb4-29" aria-hidden="true" tabindex="-1"></a>        out <span class="op">=</span> <span class="va">self</span>.proj(out)</span>
<span id="cb4-30"><a href="#cb4-30" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-31"><a href="#cb4-31" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> out</span>
<span id="cb4-32"><a href="#cb4-32" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-33"><a href="#cb4-33" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> TransformerEncoder(nn.Module):</span>
<span id="cb4-34"><a href="#cb4-34" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, embed_dim, num_heads, mlp_ratio<span class="op">=</span><span class="dv">4</span>, dropout_rate<span class="op">=</span><span class="fl">0.1</span>):</span>
<span id="cb4-35"><a href="#cb4-35" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb4-36"><a href="#cb4-36" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-37"><a href="#cb4-37" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Layer normalization</span></span>
<span id="cb4-38"><a href="#cb4-38" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.norm1 <span class="op">=</span> nn.LayerNorm(embed_dim)</span>
<span id="cb4-39"><a href="#cb4-39" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-40"><a href="#cb4-40" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Multi-head self-attention</span></span>
<span id="cb4-41"><a href="#cb4-41" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.attn <span class="op">=</span> MultiHeadAttention(embed_dim, num_heads, dropout_rate)</span>
<span id="cb4-42"><a href="#cb4-42" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.dropout1 <span class="op">=</span> nn.Dropout(dropout_rate)</span>
<span id="cb4-43"><a href="#cb4-43" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-44"><a href="#cb4-44" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Layer normalization</span></span>
<span id="cb4-45"><a href="#cb4-45" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.norm2 <span class="op">=</span> nn.LayerNorm(embed_dim)</span>
<span id="cb4-46"><a href="#cb4-46" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-47"><a href="#cb4-47" aria-hidden="true" tabindex="-1"></a>        <span class="co"># MLP block</span></span>
<span id="cb4-48"><a href="#cb4-48" aria-hidden="true" tabindex="-1"></a>        mlp_hidden_dim <span class="op">=</span> <span class="bu">int</span>(embed_dim <span class="op">*</span> mlp_ratio)</span>
<span id="cb4-49"><a href="#cb4-49" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.mlp <span class="op">=</span> nn.Sequential(</span>
<span id="cb4-50"><a href="#cb4-50" aria-hidden="true" tabindex="-1"></a>            nn.Linear(embed_dim, mlp_hidden_dim),</span>
<span id="cb4-51"><a href="#cb4-51" aria-hidden="true" tabindex="-1"></a>            nn.GELU(),</span>
<span id="cb4-52"><a href="#cb4-52" aria-hidden="true" tabindex="-1"></a>            nn.Dropout(dropout_rate),</span>
<span id="cb4-53"><a href="#cb4-53" aria-hidden="true" tabindex="-1"></a>            nn.Linear(mlp_hidden_dim, embed_dim),</span>
<span id="cb4-54"><a href="#cb4-54" aria-hidden="true" tabindex="-1"></a>            nn.Dropout(dropout_rate)</span>
<span id="cb4-55"><a href="#cb4-55" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb4-56"><a href="#cb4-56" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-57"><a href="#cb4-57" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb4-58"><a href="#cb4-58" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Apply layer normalization and self-attention</span></span>
<span id="cb4-59"><a href="#cb4-59" aria-hidden="true" tabindex="-1"></a>        attn_output <span class="op">=</span> <span class="va">self</span>.attn(<span class="va">self</span>.norm1(x))</span>
<span id="cb4-60"><a href="#cb4-60" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> x <span class="op">+</span> <span class="va">self</span>.dropout1(attn_output)</span>
<span id="cb4-61"><a href="#cb4-61" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-62"><a href="#cb4-62" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Apply MLP block with residual connection</span></span>
<span id="cb4-63"><a href="#cb4-63" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> x <span class="op">+</span> <span class="va">self</span>.mlp(<span class="va">self</span>.norm2(x))</span>
<span id="cb4-64"><a href="#cb4-64" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x</span></code></pre></div></div>
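<p>The core of <code>MultiHeadAttention.forward</code> is scaled dot-product attention, softmax(QKᵀ/√d)·V, computed independently for each head. As a sanity check, the sketch below writes that formula as a standalone function and compares it with PyTorch’s fused <code>F.scaled_dot_product_attention</code> (assumed available, PyTorch 2.0 or newer). Dropout is omitted so both paths are deterministic; the function name <code>manual_attention</code> is illustrative.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import torch
import torch.nn.functional as F


def manual_attention(q, k, v):
    """Scaled dot-product attention for tensors shaped (B, H, N, D), without dropout."""
    scale = q.shape[-1] ** -0.5
    attn = (q @ k.transpose(-2, -1)) * scale   # (B, H, N, N) similarity scores
    attn = attn.softmax(dim=-1)                # each row sums to 1 over the keys
    return attn @ v                            # (B, H, N, D) weighted sum of values


torch.manual_seed(0)
B, H, N, D = 2, 12, 197, 64   # ViT-Base: 12 heads of size 64 (768 / 12), 197 tokens
q, k, v = (torch.randn(B, H, N, D) for _ in range(3))

out_manual = manual_attention(q, k, v)
out_fused = F.scaled_dot_product_attention(q, k, v)
print(out_manual.shape)                                   # torch.Size([2, 12, 197, 64])
print(torch.allclose(out_manual, out_fused, atol=1e-5))   # True</code></pre></div></div>
<p>The scaling by 1/√d keeps the dot products from growing with the head dimension, which would otherwise push the softmax toward one-hot outputs with vanishing gradients.</p>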
<p>Now we extend the <code>VisionTransformer</code> constructor with a stack of <code>depth</code> encoder blocks and a final LayerNorm:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> VisionTransformer(nn.Module):</span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, image_size, patch_size, in_channels, num_classes, </span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a>                 embed_dim, depth, num_heads, mlp_ratio<span class="op">=</span><span class="dv">4</span>, </span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a>                 dropout_rate<span class="op">=</span><span class="fl">0.1</span>):</span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.patch_embedding <span class="op">=</span> PatchEmbedding(</span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a>            image_size, patch_size, in_channels, embed_dim)</span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.num_patches <span class="op">=</span> <span class="va">self</span>.patch_embedding.num_patches</span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Class token</span></span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.cls_token <span class="op">=</span> nn.Parameter(torch.zeros(<span class="dv">1</span>, <span class="dv">1</span>, embed_dim))</span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Position embedding for patches + class token</span></span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.pos_embedding <span class="op">=</span> nn.Parameter(</span>
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a>            torch.zeros(<span class="dv">1</span>, <span class="va">self</span>.num_patches <span class="op">+</span> <span class="dv">1</span>, embed_dim))</span>
<span id="cb5-16"><a href="#cb5-16" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-17"><a href="#cb5-17" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.dropout <span class="op">=</span> nn.Dropout(dropout_rate)</span>
<span id="cb5-18"><a href="#cb5-18" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-19"><a href="#cb5-19" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Transformer encoder blocks</span></span>
<span id="cb5-20"><a href="#cb5-20" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.transformer_blocks <span class="op">=</span> nn.ModuleList([</span>
<span id="cb5-21"><a href="#cb5-21" aria-hidden="true" tabindex="-1"></a>            TransformerEncoder(embed_dim, num_heads, mlp_ratio, dropout_rate)</span>
<span id="cb5-22"><a href="#cb5-22" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(depth)</span>
<span id="cb5-23"><a href="#cb5-23" aria-hidden="true" tabindex="-1"></a>        ])</span>
<span id="cb5-24"><a href="#cb5-24" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-25"><a href="#cb5-25" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Layer normalization</span></span>
<span id="cb5-26"><a href="#cb5-26" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.norm <span class="op">=</span> nn.LayerNorm(embed_dim)</span></code></pre></div></div>
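<p>With the encoder stack in place, it is worth checking how large the model gets. The function below counts the parameters of the architecture built in this guide by hand (weights plus biases), assuming the classification head added next is a single <code>nn.Linear(embed_dim, num_classes)</code>; the function name is illustrative. For the ViT-Base/16 configuration (768-dimensional embeddings, 12 layers, MLP ratio 4, 224×224 input, 1000 classes) it comes to roughly 86.6 million parameters.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">def vit_param_count(image_size, patch_size, in_channels, num_classes,
                    embed_dim, depth, mlp_ratio=4):
    """Count parameters of the ViT in this guide (weights + biases), by component."""
    num_patches = (image_size // patch_size) ** 2
    hidden = int(embed_dim * mlp_ratio)

    patch_embed = in_channels * patch_size * patch_size * embed_dim + embed_dim  # Conv2d
    cls_token = embed_dim
    pos_embedding = (num_patches + 1) * embed_dim

    per_block = (
        2 * 2 * embed_dim                              # norm1 and norm2 (weight + bias each)
        + embed_dim * 3 * embed_dim + 3 * embed_dim    # qkv projection
        + embed_dim * embed_dim + embed_dim            # output projection
        + embed_dim * hidden + hidden                  # MLP first linear layer
        + hidden * embed_dim + embed_dim               # MLP second linear layer
    )
    final_norm = 2 * embed_dim
    head = embed_dim * num_classes + num_classes        # assumed single linear head

    return patch_embed + cls_token + pos_embedding + depth * per_block + final_norm + head


total = vit_param_count(image_size=224, patch_size=16, in_channels=3, num_classes=1000,
                        embed_dim=768, depth=12)
print(f"{total:,}")  # 86,567,656</code></pre></div></div>
<p>Note that <code>num_heads</code> does not appear: splitting the 768-dimensional projections into 12 heads changes how the weights are used, not how many there are. Roughly 98% of the total sits in the 12 encoder blocks, about 7.1 million parameters each.</p>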
</section>
<section id="mlp-head" class="level3">
<h3 class="anchored" data-anchor-id="mlp-head" id="mlp-head">MLP Head</h3>
<p>Finally, let’s add the classification head and complete the forward pass:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> VisionTransformer(nn.Module):</span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, image_size, patch_size, in_channels, num_classes, </span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a>                 embed_dim, depth, num_heads, mlp_ratio<span class="op">=</span><span class="dv">4</span>, </span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a>                 dropout_rate<span class="op">=</span><span class="fl">0.1</span>):</span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.patch_embedding <span class="op">=</span> PatchEmbedding(</span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a>            image_size, patch_size, in_channels, embed_dim)</span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.num_patches <span class="op">=</span> <span class="va">self</span>.patch_embedding.num_patches</span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Class token</span></span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.cls_token <span class="op">=</span> nn.Parameter(torch.zeros(<span class="dv">1</span>, <span class="dv">1</span>, embed_dim))</span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Position embedding for patches + class token</span></span>
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.pos_embedding <span class="op">=</span> nn.Parameter(</span>
<span id="cb6-15"><a href="#cb6-15" aria-hidden="true" tabindex="-1"></a>            torch.zeros(<span class="dv">1</span>, <span class="va">self</span>.num_patches <span class="op">+</span> <span class="dv">1</span>, embed_dim))</span>
<span id="cb6-16"><a href="#cb6-16" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-17"><a href="#cb6-17" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.dropout <span class="op">=</span> nn.Dropout(dropout_rate)</span>
<span id="cb6-18"><a href="#cb6-18" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-19"><a href="#cb6-19" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Transformer encoder blocks</span></span>
<span id="cb6-20"><a href="#cb6-20" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.transformer_blocks <span class="op">=</span> nn.ModuleList([</span>
<span id="cb6-21"><a href="#cb6-21" aria-hidden="true" tabindex="-1"></a>            TransformerEncoder(embed_dim, num_heads, mlp_ratio, dropout_rate)</span>
<span id="cb6-22"><a href="#cb6-22" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(depth)</span>
<span id="cb6-23"><a href="#cb6-23" aria-hidden="true" tabindex="-1"></a>        ])</span>
<span id="cb6-24"><a href="#cb6-24" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-25"><a href="#cb6-25" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Layer normalization</span></span>
<span id="cb6-26"><a href="#cb6-26" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.norm <span class="op">=</span> nn.LayerNorm(embed_dim)</span>
<span id="cb6-27"><a href="#cb6-27" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-28"><a href="#cb6-28" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Classification head</span></span>
<span id="cb6-29"><a href="#cb6-29" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.mlp_head <span class="op">=</span> nn.Linear(embed_dim, num_classes)</span>
<span id="cb6-30"><a href="#cb6-30" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-31"><a href="#cb6-31" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Initialize weights</span></span>
<span id="cb6-32"><a href="#cb6-32" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>._init_weights()</span>
<span id="cb6-33"><a href="#cb6-33" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-34"><a href="#cb6-34" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> _init_weights(<span class="va">self</span>):</span>
<span id="cb6-35"><a href="#cb6-35" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Initialize the class token and position embeddings</span></span>
<span id="cb6-36"><a href="#cb6-36" aria-hidden="true" tabindex="-1"></a>        nn.init.normal_(<span class="va">self</span>.cls_token, std<span class="op">=</span><span class="fl">0.02</span>)</span>
<span id="cb6-37"><a href="#cb6-37" aria-hidden="true" tabindex="-1"></a>        nn.init.normal_(<span class="va">self</span>.pos_embedding, std<span class="op">=</span><span class="fl">0.02</span>)</span>
<span id="cb6-38"><a href="#cb6-38" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-39"><a href="#cb6-39" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb6-40"><a href="#cb6-40" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Get patch embeddings</span></span>
<span id="cb6-41"><a href="#cb6-41" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.patch_embedding(x)  <span class="co"># (B, num_patches, embed_dim)</span></span>
<span id="cb6-42"><a href="#cb6-42" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-43"><a href="#cb6-43" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Add class token</span></span>
<span id="cb6-44"><a href="#cb6-44" aria-hidden="true" tabindex="-1"></a>        batch_size <span class="op">=</span> x.shape[<span class="dv">0</span>]</span>
<span id="cb6-45"><a href="#cb6-45" aria-hidden="true" tabindex="-1"></a>        cls_tokens <span class="op">=</span> <span class="va">self</span>.cls_token.expand(batch_size, <span class="op">-</span><span class="dv">1</span>, <span class="op">-</span><span class="dv">1</span>)  <span class="co"># (B, 1, embed_dim)</span></span>
<span id="cb6-46"><a href="#cb6-46" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> torch.cat((cls_tokens, x), dim<span class="op">=</span><span class="dv">1</span>)  <span class="co"># (B, num_patches + 1, embed_dim)</span></span>
<span id="cb6-47"><a href="#cb6-47" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-48"><a href="#cb6-48" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Add position embedding</span></span>
<span id="cb6-49"><a href="#cb6-49" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> x <span class="op">+</span> <span class="va">self</span>.pos_embedding</span>
<span id="cb6-50"><a href="#cb6-50" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.dropout(x)</span>
<span id="cb6-51"><a href="#cb6-51" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-52"><a href="#cb6-52" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Apply transformer blocks</span></span>
<span id="cb6-53"><a href="#cb6-53" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> block <span class="kw">in</span> <span class="va">self</span>.transformer_blocks:</span>
<span id="cb6-54"><a href="#cb6-54" aria-hidden="true" tabindex="-1"></a>            x <span class="op">=</span> block(x)</span>
<span id="cb6-55"><a href="#cb6-55" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-56"><a href="#cb6-56" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Apply final layer normalization</span></span>
<span id="cb6-57"><a href="#cb6-57" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.norm(x)</span>
<span id="cb6-58"><a href="#cb6-58" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-59"><a href="#cb6-59" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Take class token for classification</span></span>
<span id="cb6-60"><a href="#cb6-60" aria-hidden="true" tabindex="-1"></a>        cls_token_final <span class="op">=</span> x[:, <span class="dv">0</span>]</span>
<span id="cb6-61"><a href="#cb6-61" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb6-62"><a href="#cb6-62" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Classification</span></span>
<span id="cb6-63"><a href="#cb6-63" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.mlp_head(cls_token_final)</span></code></pre></div></div>
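<p>Here is a small shape check that makes the token bookkeeping in <code>forward</code> concrete. It uses plain tensors, so it runs without the rest of the model. It assumes the 224×224 inputs and 16×16 patches from the example configuration later on, and random values stand in for the <code>PatchEmbedding</code> output:</p>

```python
import torch

batch_size, image_size, patch_size, embed_dim = 2, 224, 16, 768
num_patches = (image_size // patch_size) ** 2  # 14 * 14 = 196

# Stand-in for PatchEmbedding output: (B, num_patches, embed_dim)
patch_tokens = torch.randn(batch_size, num_patches, embed_dim)
cls_token = torch.zeros(1, 1, embed_dim)
pos_embedding = torch.zeros(1, num_patches + 1, embed_dim)

cls_tokens = cls_token.expand(batch_size, -1, -1)       # (2, 1, 768), a view, no copy
tokens = torch.cat((cls_tokens, patch_tokens), dim=1)  # (2, 197, 768)
tokens = tokens + pos_embedding                        # broadcast over the batch
print(num_patches, tokens.shape)  # 196 torch.Size([2, 197, 768])
```

<p>The sequence has 197 positions rather than 196 because of the class token. That is why <code>pos_embedding</code> is sized <code>num_patches + 1</code> and why the classification head reads <code>x[:, 0]</code>.</p>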
</section>
</section>
<section id="training-the-model" class="level2">
<h2 class="anchored" data-anchor-id="training-the-model" id="training-the-model">4. Training the Model</h2>
<p>Let’s implement a training function for our Vision Transformer:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> train_vit(model, train_loader, optimizer, criterion, device, epochs<span class="op">=</span><span class="dv">10</span>):</span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a>    model.train()</span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(epochs):</span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a>        running_loss <span class="op">=</span> <span class="fl">0.0</span></span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a>        correct <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a>        total <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> batch_idx, (inputs, targets) <span class="kw">in</span> <span class="bu">enumerate</span>(train_loader):</span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a>            inputs, targets <span class="op">=</span> inputs.to(device), targets.to(device)</span>
<span id="cb7-10"><a href="#cb7-10" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb7-11"><a href="#cb7-11" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Zero the gradients</span></span>
<span id="cb7-12"><a href="#cb7-12" aria-hidden="true" tabindex="-1"></a>            optimizer.zero_grad()</span>
<span id="cb7-13"><a href="#cb7-13" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb7-14"><a href="#cb7-14" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Forward pass</span></span>
<span id="cb7-15"><a href="#cb7-15" aria-hidden="true" tabindex="-1"></a>            outputs <span class="op">=</span> model(inputs)</span>
<span id="cb7-16"><a href="#cb7-16" aria-hidden="true" tabindex="-1"></a>            loss <span class="op">=</span> criterion(outputs, targets)</span>
<span id="cb7-17"><a href="#cb7-17" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb7-18"><a href="#cb7-18" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Backward pass and optimize</span></span>
<span id="cb7-19"><a href="#cb7-19" aria-hidden="true" tabindex="-1"></a>            loss.backward()</span>
<span id="cb7-20"><a href="#cb7-20" aria-hidden="true" tabindex="-1"></a>            optimizer.step()</span>
<span id="cb7-21"><a href="#cb7-21" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb7-22"><a href="#cb7-22" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Statistics</span></span>
<span id="cb7-23"><a href="#cb7-23" aria-hidden="true" tabindex="-1"></a>            running_loss <span class="op">+=</span> loss.item()</span>
<span id="cb7-24"><a href="#cb7-24" aria-hidden="true" tabindex="-1"></a>            _, predicted <span class="op">=</span> outputs.<span class="bu">max</span>(<span class="dv">1</span>)</span>
<span id="cb7-25"><a href="#cb7-25" aria-hidden="true" tabindex="-1"></a>            total <span class="op">+=</span> targets.size(<span class="dv">0</span>)</span>
<span id="cb7-26"><a href="#cb7-26" aria-hidden="true" tabindex="-1"></a>            correct <span class="op">+=</span> predicted.eq(targets).<span class="bu">sum</span>().item()</span>
<span id="cb7-27"><a href="#cb7-27" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb7-28"><a href="#cb7-28" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Print statistics every 100 batches</span></span>
<span id="cb7-29"><a href="#cb7-29" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> batch_idx <span class="op">%</span> <span class="dv">100</span> <span class="op">==</span> <span class="dv">99</span>:</span>
<span id="cb7-30"><a href="#cb7-30" aria-hidden="true" tabindex="-1"></a>                <span class="bu">print</span>(<span class="ss">f'Epoch: </span><span class="sc">{</span>epoch<span class="op">+</span><span class="dv">1</span><span class="sc">}</span><span class="ss">, Batch: </span><span class="sc">{</span>batch_idx<span class="op">+</span><span class="dv">1</span><span class="sc">}</span><span class="ss">, '</span></span>
<span id="cb7-31"><a href="#cb7-31" aria-hidden="true" tabindex="-1"></a>                      <span class="ss">f'Loss: </span><span class="sc">{</span>running_loss<span class="op">/</span><span class="dv">100</span><span class="sc">:.3f}</span><span class="ss">, '</span></span>
<span id="cb7-32"><a href="#cb7-32" aria-hidden="true" tabindex="-1"></a>                      <span class="ss">f'Accuracy: </span><span class="sc">{</span><span class="fl">100.</span><span class="op">*</span>correct<span class="op">/</span>total<span class="sc">:.2f}</span><span class="ss">%'</span>)</span>
<span id="cb7-33"><a href="#cb7-33" aria-hidden="true" tabindex="-1"></a>                running_loss <span class="op">=</span> <span class="fl">0.0</span></span>
<span id="cb7-34"><a href="#cb7-34" aria-hidden="true" tabindex="-1"></a>                </span>
<span id="cb7-35"><a href="#cb7-35" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Epoch statistics</span></span>
<span id="cb7-36"><a href="#cb7-36" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="ss">f'Epoch </span><span class="sc">{</span>epoch<span class="op">+</span><span class="dv">1</span><span class="sc">}</span><span class="ss"> completed. '</span></span>
<span id="cb7-37"><a href="#cb7-37" aria-hidden="true" tabindex="-1"></a>              <span class="ss">f'Accuracy: </span><span class="sc">{</span><span class="fl">100.</span><span class="op">*</span>correct<span class="op">/</span>total<span class="sc">:.2f}</span><span class="ss">%'</span>)</span></code></pre></div></div>
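<p><code>train_vit</code> only reports accuracy on the training batches. A matching evaluation pass is sketched below; <code>evaluate_vit</code> is a hypothetical helper name, not part of any library. It does three things:</p>
<ul>
<li>switches the model to eval mode, which disables dropout;</li>
<li>turns off gradient tracking;</li>
<li>weights each batch's mean loss by the batch size, so the final average stays exact when the last batch is smaller.</li>
</ul>
<p>A toy linear model and random data stand in for the ViT and a real validation set:</p>

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()
def evaluate_vit(model, data_loader, criterion, device):
    """Return (average loss, top-1 accuracy in %) over a data loader."""
    model.eval()  # disables dropout
    total_loss, correct, total = 0.0, 0, 0
    for inputs, targets in data_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        # criterion returns a batch mean, so weight it by the batch size
        total_loss += criterion(outputs, targets).item() * targets.size(0)
        correct += (outputs.argmax(dim=1) == targets).sum().item()
        total += targets.size(0)
    return total_loss / total, 100.0 * correct / total


# Toy stand-ins for the ViT and a validation set
torch.manual_seed(0)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
xs, ys = torch.randn(20, 3, 8, 8), torch.randint(0, 10, (20,))
loader = DataLoader(TensorDataset(xs, ys), batch_size=8)  # last batch has 4 samples
val_loss, val_acc = evaluate_vit(toy_model, loader, nn.CrossEntropyLoss(), torch.device("cpu"))
print(f"val loss {val_loss:.3f}, val acc {val_acc:.1f}%")
```

<p><code>evaluate_vit</code> leaves the model in eval mode, and <code>train_vit</code> calls <code>model.train()</code> only once, before its epoch loop. If you interleave validation with training epochs, call <code>model.train()</code> again before resuming training.</p>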
<p>Example usage:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb8"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.optim <span class="im">as</span> optim</span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torchvision <span class="im">import</span> datasets, transforms</span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.utils.data <span class="im">import</span> DataLoader</span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Set device</span></span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a>device <span class="op">=</span> torch.device(<span class="st">'cuda'</span> <span class="cf">if</span> torch.cuda.is_available() <span class="cf">else</span> <span class="st">'cpu'</span>)</span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a><span class="co"># Create ViT model</span></span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> VisionTransformer(</span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a>    image_size<span class="op">=</span><span class="dv">224</span>,</span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a>    patch_size<span class="op">=</span><span class="dv">16</span>,</span>
<span id="cb8-12"><a href="#cb8-12" aria-hidden="true" tabindex="-1"></a>    in_channels<span class="op">=</span><span class="dv">3</span>,</span>
<span id="cb8-13"><a href="#cb8-13" aria-hidden="true" tabindex="-1"></a>    num_classes<span class="op">=</span><span class="dv">10</span>,</span>
<span id="cb8-14"><a href="#cb8-14" aria-hidden="true" tabindex="-1"></a>    embed_dim<span class="op">=</span><span class="dv">768</span>,</span>
<span id="cb8-15"><a href="#cb8-15" aria-hidden="true" tabindex="-1"></a>    depth<span class="op">=</span><span class="dv">12</span>,</span>
<span id="cb8-16"><a href="#cb8-16" aria-hidden="true" tabindex="-1"></a>    num_heads<span class="op">=</span><span class="dv">12</span>,</span>
<span id="cb8-17"><a href="#cb8-17" aria-hidden="true" tabindex="-1"></a>    mlp_ratio<span class="op">=</span><span class="dv">4</span>,</span>
<span id="cb8-18"><a href="#cb8-18" aria-hidden="true" tabindex="-1"></a>    dropout_rate<span class="op">=</span><span class="fl">0.1</span></span>
<span id="cb8-19"><a href="#cb8-19" aria-hidden="true" tabindex="-1"></a>).to(device)</span>
<span id="cb8-20"><a href="#cb8-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-21"><a href="#cb8-21" aria-hidden="true" tabindex="-1"></a><span class="co"># Define loss function and optimizer</span></span>
<span id="cb8-22"><a href="#cb8-22" aria-hidden="true" tabindex="-1"></a>criterion <span class="op">=</span> nn.CrossEntropyLoss()</span>
<span id="cb8-23"><a href="#cb8-23" aria-hidden="true" tabindex="-1"></a>optimizer <span class="op">=</span> optim.Adam(model.parameters(), lr<span class="op">=</span><span class="fl">1e-3</span>)</span>
<span id="cb8-24"><a href="#cb8-24" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-25"><a href="#cb8-25" aria-hidden="true" tabindex="-1"></a><span class="co"># Load data (FakeData generates random images: a stand-in for a real dataset)</span></span>
<span id="cb8-26"><a href="#cb8-26" aria-hidden="true" tabindex="-1"></a>transform <span class="op">=</span> transforms.Compose([</span>
<span id="cb8-27"><a href="#cb8-27" aria-hidden="true" tabindex="-1"></a>    transforms.Resize((<span class="dv">224</span>, <span class="dv">224</span>)),  <span class="co"># exact size: the position embedding expects a fixed patch count</span></span>
<span id="cb8-28"><a href="#cb8-28" aria-hidden="true" tabindex="-1"></a>    transforms.ToTensor(),</span>
<span id="cb8-29"><a href="#cb8-29" aria-hidden="true" tabindex="-1"></a>    transforms.Normalize(mean<span class="op">=</span>[<span class="fl">0.485</span>, <span class="fl">0.456</span>, <span class="fl">0.406</span>], std<span class="op">=</span>[<span class="fl">0.229</span>, <span class="fl">0.224</span>, <span class="fl">0.225</span>])</span>
<span id="cb8-30"><a href="#cb8-30" aria-hidden="true" tabindex="-1"></a>])</span>
<span id="cb8-31"><a href="#cb8-31" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-32"><a href="#cb8-32" aria-hidden="true" tabindex="-1"></a>train_dataset <span class="op">=</span> datasets.FakeData(</span>
<span id="cb8-33"><a href="#cb8-33" aria-hidden="true" tabindex="-1"></a>    size<span class="op">=</span><span class="dv">1000</span>, image_size<span class="op">=</span>(<span class="dv">3</span>, <span class="dv">224</span>, <span class="dv">224</span>), num_classes<span class="op">=</span><span class="dv">10</span>, transform<span class="op">=</span>transform</span>
<span id="cb8-34"><a href="#cb8-34" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb8-35"><a href="#cb8-35" aria-hidden="true" tabindex="-1"></a>train_loader <span class="op">=</span> DataLoader(train_dataset, batch_size<span class="op">=</span><span class="dv">32</span>, shuffle<span class="op">=</span><span class="va">True</span>, num_workers<span class="op">=</span><span class="dv">4</span>)</span>
<span id="cb8-36"><a href="#cb8-36" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-37"><a href="#cb8-37" aria-hidden="true" tabindex="-1"></a><span class="co"># Train model</span></span>
<span id="cb8-38"><a href="#cb8-38" aria-hidden="true" tabindex="-1"></a>train_vit(model, train_loader, optimizer, criterion, device, epochs<span class="op">=</span><span class="dv">10</span>)</span></code></pre></div></div>
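<p>The configuration above is ViT-Base/16: 768-dimensional embeddings, 12 blocks and 12 heads. Its size can be worked out by hand. The sketch below assumes the standard ViT layout:</p>
<ul>
<li>a convolutional patch embedding with bias;</li>
<li>per block, a fused QKV projection and an output projection (both with biases), a two-layer MLP, and two LayerNorms.</li>
</ul>
<p><code>vit_param_count</code> is a hypothetical helper. If the <code>PatchEmbedding</code> or <code>TransformerEncoder</code> defined earlier differs from this layout, so will the count. You can check it against <code>sum(p.numel() for p in model.parameters())</code>:</p>

```python
def vit_param_count(image_size, patch_size, in_channels, num_classes,
                    embed_dim, depth, mlp_ratio=4):
    """Analytical parameter count for a standard ViT (assumed layout)."""
    d = embed_dim
    hidden = int(d * mlp_ratio)
    num_patches = (image_size // patch_size) ** 2

    patch_embed = in_channels * patch_size * patch_size * d + d  # conv weight + bias
    cls_and_pos = d + (num_patches + 1) * d                      # class token + position table
    per_block = (
        (3 * d * d + 3 * d)       # fused Q, K, V projection
        + (d * d + d)             # attention output projection
        + (d * hidden + hidden)   # MLP expansion
        + (hidden * d + d)        # MLP contraction
        + 2 * (2 * d)             # two LayerNorms (weight + bias each)
    )
    final_norm = 2 * d
    head = d * num_classes + num_classes
    return patch_embed + cls_and_pos + depth * per_block + final_norm + head


total = vit_param_count(224, 16, 3, num_classes=10, embed_dim=768, depth=12)
print(f"{total:,}")  # 85,806,346
```

<p><code>num_heads</code> does not appear because splitting the embedding across heads does not change the projection sizes. With <code>num_classes=1000</code>, the same formula gives 86,567,656, in line with the roughly 86M parameters usually quoted for ViT-B/16.</p>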
</section>
<section id="inference-and-usage" class="level2">
<h2 class="anchored" data-anchor-id="inference-and-usage" id="inference-and-usage">5. Inference and Usage</h2>
<p>Here’s how to use the model for inference:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb9"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> inference(model, image_tensor, device):</span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a>    model.<span class="bu">eval</span>()</span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> torch.no_grad():</span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a>        image_tensor <span class="op">=</span> image_tensor.unsqueeze(<span class="dv">0</span>).to(device)  <span class="co"># Add batch dimension</span></span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a>        output <span class="op">=</span> model(image_tensor)</span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a>        probabilities <span class="op">=</span> F.softmax(output, dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a>        predicted_class <span class="op">=</span> torch.argmax(probabilities, dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> predicted_class.item(), probabilities[<span class="dv">0</span>]</span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Example usage</span></span>
<span id="cb9-11"><a href="#cb9-11" aria-hidden="true" tabindex="-1"></a>image_tensor <span class="op">=</span> transform(image)  <span class="co"># image: an RGB PIL image you have already loaded (e.g. via PIL.Image.open)</span></span>
<span id="cb9-12"><a href="#cb9-12" aria-hidden="true" tabindex="-1"></a>predicted_class, probabilities <span class="op">=</span> inference(model, image_tensor, device)</span>
<span id="cb9-13"><a href="#cb9-13" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Predicted class: </span><span class="sc">{</span>predicted_class<span class="sc">}</span><span class="ss">"</span>)</span>
<span id="cb9-14"><a href="#cb9-14" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Confidence: </span><span class="sc">{</span>probabilities[predicted_class]<span class="sc">:.4f}</span><span class="ss">"</span>)</span></code></pre></div></div>
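<p>With many classes, the few most likely labels are often more informative than the single argmax. Below is a minimal sketch using <code>torch.topk</code> on the softmax output. <code>top_k_predictions</code> is a hypothetical helper that expects the logits for a single image, i.e. the model output with the batch dimension removed:</p>

```python
import torch
import torch.nn.functional as F


def top_k_predictions(logits, k=3):
    """Return (class_index, probability) pairs for the k most likely classes.

    `logits` is the raw model output for one image, shape (num_classes,).
    """
    probabilities = F.softmax(logits, dim=-1)
    top_probs, top_idx = torch.topk(probabilities, k)  # sorted, largest first
    return [(i.item(), p.item()) for i, p in zip(top_idx, top_probs)]


logits = torch.tensor([2.0, 0.5, 3.0, -1.0])
for class_idx, prob in top_k_predictions(logits, k=2):
    print(f"class {class_idx}: {prob:.4f}")
```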
</section>
<section id="optimization-techniques" class="level2">
<h2 class="anchored" data-anchor-id="optimization-techniques" id="optimization-techniques">6. Optimization Techniques</h2>
<p>To improve the training and performance of ViT models, consider these optimization techniques:</p>
<section id="custom-attention-implementation" class="level3">
<h3 class="anchored" data-anchor-id="custom-attention-implementation" id="custom-attention-implementation">Custom Attention Implementation</h3>
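<p>The function below writes standard scaled dot-product attention out explicitly. Despite its name, it is no more memory-efficient than the usual formulation. It materializes the full L × L attention matrix for every head, so memory grows quadratically with the sequence length L; its value lies in readability and control (for example, custom masking). For real savings, PyTorch 2.0 and later provide <code>torch.nn.functional.scaled_dot_product_attention</code>. It can dispatch to fused kernels such as FlashAttention that avoid storing the full matrix:</p>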
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb10"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> efficient_attention(q, k, v, mask<span class="op">=</span><span class="va">None</span>):</span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a>    <span class="co"># q, k, v: [batch_size, num_heads, seq_len, head_dim]</span></span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Scaled dot-product attention</span></span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a>    scale <span class="op">=</span> q.size(<span class="op">-</span><span class="dv">1</span>) <span class="op">**</span> <span class="op">-</span><span class="fl">0.5</span></span>
<span id="cb10-6"><a href="#cb10-6" aria-hidden="true" tabindex="-1"></a>    attention <span class="op">=</span> torch.matmul(q, k.transpose(<span class="op">-</span><span class="dv">2</span>, <span class="op">-</span><span class="dv">1</span>)) <span class="op">*</span> scale  <span class="co"># [B, H, L, L]</span></span>
<span id="cb10-7"><a href="#cb10-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-8"><a href="#cb10-8" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> mask <span class="kw">is</span> <span class="kw">not</span> <span class="va">None</span>:</span>
<span id="cb10-9"><a href="#cb10-9" aria-hidden="true" tabindex="-1"></a>        attention <span class="op">=</span> attention.masked_fill(mask <span class="op">==</span> <span class="dv">0</span>, torch.finfo(attention.dtype).<span class="bu">min</span>)  <span class="co"># -1e9 would overflow in float16</span></span>
<span id="cb10-10"><a href="#cb10-10" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-11"><a href="#cb10-11" aria-hidden="true" tabindex="-1"></a>    attention <span class="op">=</span> F.softmax(attention, dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb10-12"><a href="#cb10-12" aria-hidden="true" tabindex="-1"></a>    output <span class="op">=</span> torch.matmul(attention, v)  <span class="co"># [B, H, L, D]</span></span>
<span id="cb10-13"><a href="#cb10-13" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-14"><a href="#cb10-14" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> output</span></code></pre></div></div>
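<p>Here is a sketch of the fused alternative, assuming PyTorch 2.0 or later. It compares a reference version of the formula above with <code>torch.nn.functional.scaled_dot_product_attention</code>, whose default scale is also 1/√(head_dim). Both treat nonzero (<code>True</code>) mask entries as positions that may be attended to, so the same boolean mask works for either. A standard ViT uses no mask at all, since every token attends to every other token; the causal mask here only exercises the masking path:</p>

```python
import torch
import torch.nn.functional as F


def manual_attention(q, k, v, mask=None):
    # Reference version: materializes the full (L x L) score matrix
    scale = q.size(-1) ** -0.5
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    if mask is not None:
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
    return torch.matmul(F.softmax(scores, dim=-1), v)


def fused_attention(q, k, v, mask=None):
    # May dispatch to FlashAttention / memory-efficient kernels.
    # A boolean attn_mask marks positions that ARE allowed to attend (True = keep).
    attn_mask = None if mask is None else mask.bool()
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)


torch.manual_seed(0)
B, H, L, D = 2, 12, 197, 64  # 197 tokens = 196 patches + class token
q, k, v = (torch.randn(B, H, L, D) for _ in range(3))
causal = torch.ones(L, L, dtype=torch.bool).tril()  # broadcasts over batch and heads

out_ref = manual_attention(q, k, v, causal)
out_fused = fused_attention(q, k, v, causal)
print(torch.allclose(out_ref, out_fused, atol=1e-5))
```

<p>Replacing the explicit <code>matmul</code>/<code>softmax</code>/<code>matmul</code> sequence with this single call leaves the math unchanged. Whether a fused kernel is actually selected depends on the device, dtype and input shapes.</p>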
</section>
<section id="mixed-precision-training" class="level3">
<h3 class="anchored" data-anchor-id="mixed-precision-training" id="mixed-precision-training">Mixed Precision Training</h3>
<p>Use mixed precision training to reduce memory usage and increase training speed:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb11"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.amp <span class="im">import</span> autocast, GradScaler  <span class="co"># replaces the deprecated torch.cuda.amp</span></span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> train_with_mixed_precision(model, train_loader, optimizer, criterion, device, epochs<span class="op">=</span><span class="dv">10</span>):</span>
<span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a>    scaler <span class="op">=</span> GradScaler(<span class="st">"cuda"</span>)</span>
<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a>    model.train()</span>
<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> epoch <span class="kw">in</span> <span class="bu">range</span>(epochs):</span>
<span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a>        running_loss <span class="op">=</span> <span class="fl">0.0</span></span>
<span id="cb11-9"><a href="#cb11-9" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-10"><a href="#cb11-10" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> batch_idx, (inputs, targets) <span class="kw">in</span> <span class="bu">enumerate</span>(train_loader):</span>
<span id="cb11-11"><a href="#cb11-11" aria-hidden="true" tabindex="-1"></a>            inputs, targets <span class="op">=</span> inputs.to(device), targets.to(device)</span>
<span id="cb11-12"><a href="#cb11-12" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb11-13"><a href="#cb11-13" aria-hidden="true" tabindex="-1"></a>            optimizer.zero_grad()</span>
<span id="cb11-14"><a href="#cb11-14" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb11-15"><a href="#cb11-15" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Use autocast for mixed precision</span></span>
<span id="cb11-16"><a href="#cb11-16" aria-hidden="true" tabindex="-1"></a>            <span class="cf">with</span> autocast(<span class="st">"cuda"</span>):</span>
<span id="cb11-17"><a href="#cb11-17" aria-hidden="true" tabindex="-1"></a>                outputs <span class="op">=</span> model(inputs)</span>
<span id="cb11-18"><a href="#cb11-18" aria-hidden="true" tabindex="-1"></a>                loss <span class="op">=</span> criterion(outputs, targets)</span>
<span id="cb11-19"><a href="#cb11-19" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb11-20"><a href="#cb11-20" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Scale gradients and optimize</span></span>
<span id="cb11-21"><a href="#cb11-21" aria-hidden="true" tabindex="-1"></a>            scaler.scale(loss).backward()</span>
<span id="cb11-22"><a href="#cb11-22" aria-hidden="true" tabindex="-1"></a>            scaler.step(optimizer)</span>
<span id="cb11-23"><a href="#cb11-23" aria-hidden="true" tabindex="-1"></a>            scaler.update()</span>
<span id="cb11-24"><a href="#cb11-24" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb11-25"><a href="#cb11-25" aria-hidden="true" tabindex="-1"></a>            running_loss <span class="op">+=</span> loss.item()</span>
<span id="cb11-26"><a href="#cb11-26" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Rest of the training loop...</span></span></code></pre></div></div>
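<p>The loop above stops at the loss. One complete step is sketched below, under some assumptions:</p>
<ul>
<li><code>amp_train_step</code> is a hypothetical helper.</li>
<li>It adds gradient clipping, a common companion to mixed precision that the loop above does not do.</li>
<li>It uses the <code>torch.autocast</code> context manager, which takes the device type and low-precision dtype explicitly.</li>
</ul>
<p>With float16 on CUDA, pass a <code>GradScaler</code>. Gradients must then be unscaled with <code>scaler.unscale_</code> before clipping; otherwise the clip threshold is compared against scaled values. bfloat16 has the same exponent range as float32 and is usually run without a scaler. The CPU demo at the bottom takes that path, with a toy linear model standing in for the ViT:</p>

```python
import torch
import torch.nn as nn


def amp_train_step(model, inputs, targets, optimizer, criterion,
                   device_type="cuda", dtype=torch.float16,
                   scaler=None, max_grad_norm=1.0):
    """One mixed-precision training step with gradient clipping.

    Pass a GradScaler for float16 on CUDA; use scaler=None for bfloat16.
    """
    optimizer.zero_grad(set_to_none=True)

    # Eligible ops (matmuls, linear layers) run in `dtype` inside this context
    with torch.autocast(device_type=device_type, dtype=dtype):
        outputs = model(inputs)
        loss = criterion(outputs, targets)

    if scaler is not None:
        scaler.scale(loss).backward()
        # Unscale first so clipping sees the true gradient magnitudes
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        scaler.step(optimizer)  # skips the update if grads contain inf/NaN
        scaler.update()
    else:
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()

    return loss.item()


# CPU demo with bfloat16 (no scaler); a toy linear model stands in for the ViT
torch.manual_seed(0)
toy_model = nn.Linear(16, 4)
opt = torch.optim.SGD(toy_model.parameters(), lr=0.1)
x = torch.randn(8, 16)
y = torch.randint(0, 4, (8,))
loss_value = amp_train_step(toy_model, x, y, opt, nn.CrossEntropyLoss(),
                            device_type="cpu", dtype=torch.bfloat16)
print(f"loss: {loss_value:.4f}")
```

<p>On a CUDA GPU, the float16 path would be called as <code>amp_train_step(model, inputs, targets, optimizer, criterion, scaler=scaler)</code>, with the scaler created once outside the loop as in <code>train_with_mixed_precision</code>.</p>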
</section>
<section id="regularization-techniques" class="level3">
<h3 class="anchored" data-anchor-id="regularization-techniques" id="regularization-techniques">Regularization Techniques</h3>
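<p>Stochastic depth (also called drop path) regularizes deep ViTs by randomly skipping a block's residual branch for individual samples during training, which helps prevent overfitting. The module below zeroes its input for a random subset of the batch and divides the surviving samples by <code>keep_prob</code>, so the expected output is unchanged. In eval mode it passes its input through untouched:</p>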
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb12"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><a href="#cb12-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> StochasticDepth(nn.Module):</span>
<span id="cb12-2"><a href="#cb12-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, drop_prob<span class="op">=</span><span class="fl">0.1</span>):</span>
<span id="cb12-3"><a href="#cb12-3" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb12-4"><a href="#cb12-4" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.drop_prob <span class="op">=</span> drop_prob</span>
<span id="cb12-5"><a href="#cb12-5" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-6"><a href="#cb12-6" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb12-7"><a href="#cb12-7" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="kw">not</span> <span class="va">self</span>.training <span class="kw">or</span> <span class="va">self</span>.drop_prob <span class="op">==</span> <span class="fl">0.</span>:</span>
<span id="cb12-8"><a href="#cb12-8" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> x</span>
<span id="cb12-9"><a href="#cb12-9" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-10"><a href="#cb12-10" aria-hidden="true" tabindex="-1"></a>        keep_prob <span class="op">=</span> <span class="dv">1</span> <span class="op">-</span> <span class="va">self</span>.drop_prob</span>
<span id="cb12-11"><a href="#cb12-11" aria-hidden="true" tabindex="-1"></a>        shape <span class="op">=</span> (x.shape[<span class="dv">0</span>],) <span class="op">+</span> (<span class="dv">1</span>,) <span class="op">*</span> (x.ndim <span class="op">-</span> <span class="dv">1</span>)</span>
<span id="cb12-12"><a href="#cb12-12" aria-hidden="true" tabindex="-1"></a>        random_tensor <span class="op">=</span> keep_prob <span class="op">+</span> torch.rand(shape, dtype<span class="op">=</span>x.dtype, device<span class="op">=</span>x.device)</span>
<span id="cb12-13"><a href="#cb12-13" aria-hidden="true" tabindex="-1"></a>        random_tensor.floor_()  <span class="co"># binarize</span></span>
<span id="cb12-14"><a href="#cb12-14" aria-hidden="true" tabindex="-1"></a>        output <span class="op">=</span> x.div(keep_prob) <span class="op">*</span> random_tensor</span>
<span id="cb12-15"><a href="#cb12-15" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> output</span></code></pre></div></div>
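<p>Deep ViTs rarely use one fixed drop probability for every block. A common recipe is the linear decay rule from the original stochastic depth paper, which many ViT training setups also follow. It increases the drop probability linearly with depth, so early blocks (whose outputs every later block depends on) are dropped least often. A minimal sketch, with a hypothetical helper name:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import torch


def stochastic_depth_rates(depth, max_drop_prob=0.1):
    # Linear schedule: block 0 is never dropped, the last block is dropped
    # with probability max_drop_prob, and the blocks in between interpolate
    return [p.item() for p in torch.linspace(0, max_drop_prob, depth)]


# Per-block drop probabilities for a hypothetical 12-block ViT
rates = stochastic_depth_rates(depth=12, max_drop_prob=0.1)
print(len(rates), rates[0])  # 12 0.0
</code></pre></div></div>
<p>Each rate would then configure one <code>StochasticDepth</code> module wrapping a residual branch. For example, a transformer block might compute <code>x = x + self.drop_path(self.attn(self.norm1(x)))</code>; the attribute names in that line are illustrative.</p>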
</section>
</section>
<section id="advanced-variants" class="level2">
<h2 class="anchored" data-anchor-id="advanced-variants" id="advanced-variants">7. Advanced Variants</h2>
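<p>Several variants address limitations of the original ViT, notably its reliance on very large pretraining datasets and the quadratic cost of global self-attention in the number of patches. Two influential examples follow.</p>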
<section id="deit-data-efficient-image-transformer" class="level3">
<h3 class="anchored" data-anchor-id="deit-data-efficient-image-transformer" id="deit-data-efficient-image-transformer">DeiT (Data-efficient Image Transformer)</h3>
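<p>DeiT (Touvron et al., 2021) makes ViTs competitive when trained on ImageNet alone by adding a distillation token. This extra learnable token is processed alongside the class token, and its output is trained to match the predictions of a teacher network. The model gets a second classification head for this token, and at inference the predictions of the two heads are averaged:</p>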
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb13"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> DeiT(VisionTransformer):</span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, <span class="op">*</span>args, <span class="op">**</span>kwargs):</span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>(<span class="op">*</span>args, <span class="op">**</span>kwargs)</span>
<span id="cb13-4"><a href="#cb13-4" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-5"><a href="#cb13-5" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Distillation token</span></span>
<span id="cb13-6"><a href="#cb13-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.dist_token <span class="op">=</span> nn.Parameter(torch.zeros(<span class="dv">1</span>, <span class="dv">1</span>, <span class="va">self</span>.embed_dim))</span>
<span id="cb13-7"><a href="#cb13-7" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-8"><a href="#cb13-8" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Update position embeddings to include distillation token</span></span>
<span id="cb13-9"><a href="#cb13-9" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Original position embeddings are for [class_token, patches]</span></span>
<span id="cb13-10"><a href="#cb13-10" aria-hidden="true" tabindex="-1"></a>        <span class="co"># New position embeddings are for [class_token, dist_token, patches]</span></span>
<span id="cb13-11"><a href="#cb13-11" aria-hidden="true" tabindex="-1"></a>        num_patches <span class="op">=</span> <span class="va">self</span>.patch_embedding.num_patches</span>
<span id="cb13-12"><a href="#cb13-12" aria-hidden="true" tabindex="-1"></a>        new_pos_embed <span class="op">=</span> nn.Parameter(torch.zeros(<span class="dv">1</span>, num_patches <span class="op">+</span> <span class="dv">2</span>, <span class="va">self</span>.embed_dim))</span>
<span id="cb13-13"><a href="#cb13-13" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-14"><a href="#cb13-14" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Initialize new position embeddings with the original ones</span></span>
<span id="cb13-15"><a href="#cb13-15" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Copy class token position embedding</span></span>
<span id="cb13-16"><a href="#cb13-16" aria-hidden="true" tabindex="-1"></a>        new_pos_embed.data[:, <span class="dv">0</span>:<span class="dv">1</span>, :] <span class="op">=</span> <span class="va">self</span>.pos_embedding.data[:, <span class="dv">0</span>:<span class="dv">1</span>, :]</span>
<span id="cb13-17"><a href="#cb13-17" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Initialize the distillation token's position embedding from the class token's</span></span>
<span id="cb13-18"><a href="#cb13-18" aria-hidden="true" tabindex="-1"></a>        new_pos_embed.data[:, <span class="dv">1</span>:<span class="dv">2</span>, :] <span class="op">=</span> <span class="va">self</span>.pos_embedding.data[:, <span class="dv">0</span>:<span class="dv">1</span>, :]</span>
<span id="cb13-19"><a href="#cb13-19" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Copy patch position embeddings</span></span>
<span id="cb13-20"><a href="#cb13-20" aria-hidden="true" tabindex="-1"></a>        new_pos_embed.data[:, <span class="dv">2</span>:, :] <span class="op">=</span> <span class="va">self</span>.pos_embedding.data[:, <span class="dv">1</span>:, :]</span>
<span id="cb13-21"><a href="#cb13-21" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-22"><a href="#cb13-22" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.pos_embedding <span class="op">=</span> new_pos_embed</span>
<span id="cb13-23"><a href="#cb13-23" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-24"><a href="#cb13-24" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Additional classification head for distillation</span></span>
<span id="cb13-25"><a href="#cb13-25" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.head_dist <span class="op">=</span> nn.Linear(<span class="va">self</span>.embed_dim, kwargs.get(<span class="st">'num_classes'</span>, <span class="dv">1000</span>))</span>
<span id="cb13-26"><a href="#cb13-26" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-27"><a href="#cb13-27" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb13-28"><a href="#cb13-28" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Get patch embeddings</span></span>
<span id="cb13-29"><a href="#cb13-29" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.patch_embedding(x)</span>
<span id="cb13-30"><a href="#cb13-30" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-31"><a href="#cb13-31" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Add class and distillation tokens</span></span>
<span id="cb13-32"><a href="#cb13-32" aria-hidden="true" tabindex="-1"></a>        batch_size <span class="op">=</span> x.shape[<span class="dv">0</span>]</span>
<span id="cb13-33"><a href="#cb13-33" aria-hidden="true" tabindex="-1"></a>        cls_tokens <span class="op">=</span> <span class="va">self</span>.cls_token.expand(batch_size, <span class="op">-</span><span class="dv">1</span>, <span class="op">-</span><span class="dv">1</span>)</span>
<span id="cb13-34"><a href="#cb13-34" aria-hidden="true" tabindex="-1"></a>        dist_tokens <span class="op">=</span> <span class="va">self</span>.dist_token.expand(batch_size, <span class="op">-</span><span class="dv">1</span>, <span class="op">-</span><span class="dv">1</span>)</span>
<span id="cb13-35"><a href="#cb13-35" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> torch.cat((cls_tokens, dist_tokens, x), dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb13-36"><a href="#cb13-36" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-37"><a href="#cb13-37" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Add position embedding and apply dropout</span></span>
<span id="cb13-38"><a href="#cb13-38" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> x <span class="op">+</span> <span class="va">self</span>.pos_embedding</span>
<span id="cb13-39"><a href="#cb13-39" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.dropout(x)</span>
<span id="cb13-40"><a href="#cb13-40" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-41"><a href="#cb13-41" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Apply transformer blocks</span></span>
<span id="cb13-42"><a href="#cb13-42" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> block <span class="kw">in</span> <span class="va">self</span>.transformer_blocks:</span>
<span id="cb13-43"><a href="#cb13-43" aria-hidden="true" tabindex="-1"></a>            x <span class="op">=</span> block(x)</span>
<span id="cb13-44"><a href="#cb13-44" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-45"><a href="#cb13-45" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Apply final layer normalization</span></span>
<span id="cb13-46"><a href="#cb13-46" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.norm(x)</span>
<span id="cb13-47"><a href="#cb13-47" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-48"><a href="#cb13-48" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Get class and distillation tokens</span></span>
<span id="cb13-49"><a href="#cb13-49" aria-hidden="true" tabindex="-1"></a>        cls_token_final <span class="op">=</span> x[:, <span class="dv">0</span>]</span>
<span id="cb13-50"><a href="#cb13-50" aria-hidden="true" tabindex="-1"></a>        dist_token_final <span class="op">=</span> x[:, <span class="dv">1</span>]</span>
<span id="cb13-51"><a href="#cb13-51" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-52"><a href="#cb13-52" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Apply classification heads</span></span>
<span id="cb13-53"><a href="#cb13-53" aria-hidden="true" tabindex="-1"></a>        x_cls <span class="op">=</span> <span class="va">self</span>.mlp_head(cls_token_final)</span>
<span id="cb13-54"><a href="#cb13-54" aria-hidden="true" tabindex="-1"></a>        x_dist <span class="op">=</span> <span class="va">self</span>.head_dist(dist_token_final)</span>
<span id="cb13-55"><a href="#cb13-55" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-56"><a href="#cb13-56" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="va">self</span>.training:</span>
<span id="cb13-57"><a href="#cb13-57" aria-hidden="true" tabindex="-1"></a>            <span class="co"># During training, return both outputs</span></span>
<span id="cb13-58"><a href="#cb13-58" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> x_cls, x_dist</span>
<span id="cb13-59"><a href="#cb13-59" aria-hidden="true" tabindex="-1"></a>        <span class="cf">else</span>:</span>
<span id="cb13-60"><a href="#cb13-60" aria-hidden="true" tabindex="-1"></a>            <span class="co"># During inference, return average</span></span>
<span id="cb13-61"><a href="#cb13-61" aria-hidden="true" tabindex="-1"></a>            <span class="cf">return</span> (x_cls <span class="op">+</span> x_dist) <span class="op">/</span> <span class="dv">2</span></span></code></pre></div></div>
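<p>The two outputs are trained against different targets. The sketch below implements the hard-label distillation objective described in the DeiT paper. The class-token head is supervised by the ground-truth labels, the distillation-token head by the teacher's argmax predictions, and the two cross-entropy terms are weighted equally. The function name and the random tensors standing in for model outputs are illustrative. The paper also describes a soft variant based on KL divergence.</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import torch
import torch.nn.functional as F


def hard_distillation_loss(cls_logits, dist_logits, teacher_logits, labels):
    # Class-token head: ordinary cross-entropy against the ground-truth labels
    loss_cls = F.cross_entropy(cls_logits, labels)
    # Distillation-token head: cross-entropy against the teacher's hard decisions
    teacher_labels = teacher_logits.argmax(dim=-1)
    loss_dist = F.cross_entropy(dist_logits, teacher_labels)
    # Hard-label distillation weights the two terms equally
    return 0.5 * loss_cls + 0.5 * loss_dist


# Usage sketch: random tensors stand in for real model and teacher outputs
torch.manual_seed(0)
batch_size, num_classes = 4, 10
cls_logits = torch.randn(batch_size, num_classes)
dist_logits = torch.randn(batch_size, num_classes)
with torch.no_grad():
    teacher_logits = torch.randn(batch_size, num_classes)  # teacher stays frozen
labels = torch.randint(0, num_classes, (batch_size,))
loss = hard_distillation_loss(cls_logits, dist_logits, teacher_logits, labels)
</code></pre></div></div>
<p>In a real training step, <code>cls_logits</code> and <code>dist_logits</code> would be the pair returned by the DeiT model above in training mode. <code>teacher_logits</code> would come from a frozen, pretrained teacher evaluated under <code>torch.no_grad()</code>; the paper's best results use a convolutional teacher.</p>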
</section>
<section id="swin-transformer" class="level3">
<h3 class="anchored" data-anchor-id="swin-transformer" id="swin-transformer">Swin Transformer</h3>
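<p>Swin Transformer (Liu et al., 2021) builds hierarchical feature maps by merging neighboring patches between stages. It computes self-attention only within local windows, which makes the cost linear rather than quadratic in image size. Consecutive blocks shift the window grid so that information can flow across window boundaries. Its core component is window attention with a learned relative position bias:</p>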
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb14"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb14-1"><a href="#cb14-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Simplified window attention module used inside each Swin Transformer block</span></span>
<span id="cb14-2"><a href="#cb14-2" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> WindowAttention(nn.Module):</span>
<span id="cb14-3"><a href="#cb14-3" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, dim, window_size, num_heads, qkv_bias<span class="op">=</span><span class="va">True</span>, dropout<span class="op">=</span><span class="fl">0.</span>):</span>
<span id="cb14-4"><a href="#cb14-4" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb14-5"><a href="#cb14-5" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.dim <span class="op">=</span> dim</span>
<span id="cb14-6"><a href="#cb14-6" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.window_size <span class="op">=</span> window_size  <span class="co"># (height, width)</span></span>
<span id="cb14-7"><a href="#cb14-7" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.num_heads <span class="op">=</span> num_heads</span>
<span id="cb14-8"><a href="#cb14-8" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.scale <span class="op">=</span> (dim <span class="op">//</span> num_heads) <span class="op">**</span> <span class="op">-</span><span class="fl">0.5</span></span>
<span id="cb14-9"><a href="#cb14-9" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-10"><a href="#cb14-10" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Linear projections</span></span>
<span id="cb14-11"><a href="#cb14-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.qkv <span class="op">=</span> nn.Linear(dim, dim <span class="op">*</span> <span class="dv">3</span>, bias<span class="op">=</span>qkv_bias)</span>
<span id="cb14-12"><a href="#cb14-12" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.proj <span class="op">=</span> nn.Linear(dim, dim)</span>
<span id="cb14-13"><a href="#cb14-13" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-14"><a href="#cb14-14" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Define relative position bias</span></span>
<span id="cb14-15"><a href="#cb14-15" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.relative_position_bias_table <span class="op">=</span> nn.Parameter(</span>
<span id="cb14-16"><a href="#cb14-16" aria-hidden="true" tabindex="-1"></a>            torch.zeros((<span class="dv">2</span> <span class="op">*</span> window_size[<span class="dv">0</span>] <span class="op">-</span> <span class="dv">1</span>) <span class="op">*</span> (<span class="dv">2</span> <span class="op">*</span> window_size[<span class="dv">1</span>] <span class="op">-</span> <span class="dv">1</span>), num_heads))</span>
<span id="cb14-17"><a href="#cb14-17" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb14-18"><a href="#cb14-18" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Generate pair-wise relative position index for each token in the window</span></span>
<span id="cb14-19"><a href="#cb14-19" aria-hidden="true" tabindex="-1"></a>        coords_h <span class="op">=</span> torch.arange(<span class="va">self</span>.window_size[<span class="dv">0</span>])</span>
<span id="cb14-20"><a href="#cb14-20" aria-hidden="true" tabindex="-1"></a>        coords_w <span class="op">=</span> torch.arange(<span class="va">self</span>.window_size[<span class="dv">1</span>])</span>
<span id="cb14-21"><a href="#cb14-21" aria-hidden="true" tabindex="-1"></a>        coords <span class="op">=</span> torch.stack(torch.meshgrid([coords_h, coords_w], indexing<span class="op">=</span><span class="st">"ij"</span>))  <span class="co"># 2, Wh, Ww</span></span>
<span id="cb14-22"><a href="#cb14-22" aria-hidden="true" tabindex="-1"></a>        coords_flatten <span class="op">=</span> torch.flatten(coords, <span class="dv">1</span>)  <span class="co"># 2, Wh*Ww</span></span>
<span id="cb14-23"><a href="#cb14-23" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-24"><a href="#cb14-24" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Calculate relative positions</span></span>
<span id="cb14-25"><a href="#cb14-25" aria-hidden="true" tabindex="-1"></a>        relative_coords <span class="op">=</span> coords_flatten[:, :, <span class="va">None</span>] <span class="op">-</span> coords_flatten[:, <span class="va">None</span>, :]  <span class="co"># 2, Wh*Ww, Wh*Ww</span></span>
<span id="cb14-26"><a href="#cb14-26" aria-hidden="true" tabindex="-1"></a>        relative_coords <span class="op">=</span> relative_coords.permute(<span class="dv">1</span>, <span class="dv">2</span>, <span class="dv">0</span>).contiguous()  <span class="co"># Wh*Ww, Wh*Ww, 2</span></span>
<span id="cb14-27"><a href="#cb14-27" aria-hidden="true" tabindex="-1"></a>        relative_coords[:, :, <span class="dv">0</span>] <span class="op">+=</span> <span class="va">self</span>.window_size[<span class="dv">0</span>] <span class="op">-</span> <span class="dv">1</span>  <span class="co"># shift to start from 0</span></span>
<span id="cb14-28"><a href="#cb14-28" aria-hidden="true" tabindex="-1"></a>        relative_coords[:, :, <span class="dv">1</span>] <span class="op">+=</span> <span class="va">self</span>.window_size[<span class="dv">1</span>] <span class="op">-</span> <span class="dv">1</span></span>
<span id="cb14-29"><a href="#cb14-29" aria-hidden="true" tabindex="-1"></a>        relative_coords[:, :, <span class="dv">0</span>] <span class="op">*=</span> <span class="dv">2</span> <span class="op">*</span> <span class="va">self</span>.window_size[<span class="dv">1</span>] <span class="op">-</span> <span class="dv">1</span></span>
<span id="cb14-30"><a href="#cb14-30" aria-hidden="true" tabindex="-1"></a>        relative_position_index <span class="op">=</span> relative_coords.<span class="bu">sum</span>(<span class="op">-</span><span class="dv">1</span>)  <span class="co"># Wh*Ww, Wh*Ww</span></span>
<span id="cb14-31"><a href="#cb14-31" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-32"><a href="#cb14-32" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.register_buffer(<span class="st">"relative_position_index"</span>, relative_position_index)</span>
<span id="cb14-33"><a href="#cb14-33" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-34"><a href="#cb14-34" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.dropout <span class="op">=</span> nn.Dropout(dropout)</span>
<span id="cb14-35"><a href="#cb14-35" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-36"><a href="#cb14-36" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x, mask<span class="op">=</span><span class="va">None</span>):</span>
<span id="cb14-37"><a href="#cb14-37" aria-hidden="true" tabindex="-1"></a>        B_, N, C <span class="op">=</span> x.shape</span>
<span id="cb14-38"><a href="#cb14-38" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-39"><a href="#cb14-39" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Generate QKV matrices</span></span>
<span id="cb14-40"><a href="#cb14-40" aria-hidden="true" tabindex="-1"></a>        qkv <span class="op">=</span> <span class="va">self</span>.qkv(x).reshape(B_, N, <span class="dv">3</span>, <span class="va">self</span>.num_heads, C <span class="op">//</span> <span class="va">self</span>.num_heads)</span>
<span id="cb14-41"><a href="#cb14-41" aria-hidden="true" tabindex="-1"></a>        qkv <span class="op">=</span> qkv.permute(<span class="dv">2</span>, <span class="dv">0</span>, <span class="dv">3</span>, <span class="dv">1</span>, <span class="dv">4</span>)  <span class="co"># 3, B_, num_heads, N, C//num_heads</span></span>
<span id="cb14-42"><a href="#cb14-42" aria-hidden="true" tabindex="-1"></a>        q, k, v <span class="op">=</span> qkv[<span class="dv">0</span>], qkv[<span class="dv">1</span>], qkv[<span class="dv">2</span>]  <span class="co"># each has shape [B_, num_heads, N, C//num_heads]</span></span>
<span id="cb14-43"><a href="#cb14-43" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-44"><a href="#cb14-44" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Scaled dot-product attention</span></span>
<span id="cb14-45"><a href="#cb14-45" aria-hidden="true" tabindex="-1"></a>        q <span class="op">=</span> q <span class="op">*</span> <span class="va">self</span>.scale</span>
<span id="cb14-46"><a href="#cb14-46" aria-hidden="true" tabindex="-1"></a>        attn <span class="op">=</span> (q <span class="op">@</span> k.transpose(<span class="op">-</span><span class="dv">2</span>, <span class="op">-</span><span class="dv">1</span>))  <span class="co"># B_, num_heads, N, N</span></span>
<span id="cb14-47"><a href="#cb14-47" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-48"><a href="#cb14-48" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Add relative position bias</span></span>
<span id="cb14-49"><a href="#cb14-49" aria-hidden="true" tabindex="-1"></a>        relative_position_bias <span class="op">=</span> <span class="va">self</span>.relative_position_bias_table[<span class="va">self</span>.relative_position_index.view(<span class="op">-</span><span class="dv">1</span>)]</span>
<span id="cb14-50"><a href="#cb14-50" aria-hidden="true" tabindex="-1"></a>        relative_position_bias <span class="op">=</span> relative_position_bias.view(</span>
<span id="cb14-51"><a href="#cb14-51" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.window_size[<span class="dv">0</span>] <span class="op">*</span> <span class="va">self</span>.window_size[<span class="dv">1</span>], </span>
<span id="cb14-52"><a href="#cb14-52" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.window_size[<span class="dv">0</span>] <span class="op">*</span> <span class="va">self</span>.window_size[<span class="dv">1</span>], </span>
<span id="cb14-53"><a href="#cb14-53" aria-hidden="true" tabindex="-1"></a>            <span class="op">-</span><span class="dv">1</span>)  <span class="co"># Wh*Ww, Wh*Ww, num_heads</span></span>
<span id="cb14-54"><a href="#cb14-54" aria-hidden="true" tabindex="-1"></a>        relative_position_bias <span class="op">=</span> relative_position_bias.permute(<span class="dv">2</span>, <span class="dv">0</span>, <span class="dv">1</span>).contiguous()  <span class="co"># num_heads, Wh*Ww, Wh*Ww</span></span>
<span id="cb14-55"><a href="#cb14-55" aria-hidden="true" tabindex="-1"></a>        attn <span class="op">=</span> attn <span class="op">+</span> relative_position_bias.unsqueeze(<span class="dv">0</span>)</span>
<span id="cb14-56"><a href="#cb14-56" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-57"><a href="#cb14-57" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Apply mask if needed</span></span>
<span id="cb14-58"><a href="#cb14-58" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> mask <span class="kw">is</span> <span class="kw">not</span> <span class="va">None</span>:</span>
<span id="cb14-59"><a href="#cb14-59" aria-hidden="true" tabindex="-1"></a>            nW <span class="op">=</span> mask.shape[<span class="dv">0</span>]</span>
<span id="cb14-60"><a href="#cb14-60" aria-hidden="true" tabindex="-1"></a>            attn <span class="op">=</span> attn.view(B_ <span class="op">//</span> nW, nW, <span class="va">self</span>.num_heads, N, N) <span class="op">+</span> mask.unsqueeze(<span class="dv">1</span>).unsqueeze(<span class="dv">0</span>)</span>
<span id="cb14-61"><a href="#cb14-61" aria-hidden="true" tabindex="-1"></a>            attn <span class="op">=</span> attn.view(<span class="op">-</span><span class="dv">1</span>, <span class="va">self</span>.num_heads, N, N)</span>
<span id="cb14-62"><a href="#cb14-62" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-63"><a href="#cb14-63" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Apply softmax</span></span>
<span id="cb14-64"><a href="#cb14-64" aria-hidden="true" tabindex="-1"></a>        attn <span class="op">=</span> F.softmax(attn, dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb14-65"><a href="#cb14-65" aria-hidden="true" tabindex="-1"></a>        attn <span class="op">=</span> <span class="va">self</span>.dropout(attn)</span>
<span id="cb14-66"><a href="#cb14-66" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-67"><a href="#cb14-67" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Apply attention to values</span></span>
<span id="cb14-68"><a href="#cb14-68" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> (attn <span class="op">@</span> v).transpose(<span class="dv">1</span>, <span class="dv">2</span>).reshape(B_, N, C)</span>
<span id="cb14-69"><a href="#cb14-69" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.proj(x)</span>
<span id="cb14-70"><a href="#cb14-70" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb14-71"><a href="#cb14-71" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x</span></code></pre></div></div>
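<p><code>WindowAttention</code> expects its input already split into windows and receives the shifted-window mask through its <code>mask</code> argument. The sketch below shows one way to produce both, loosely modeled on the official Swin code:</p>
<ul>
<li><code>window_partition</code> and <code>window_reverse</code> reshape between a (B, H, W, C) feature map and per-window token sequences.</li>
<li><code>torch.roll</code> applies the cyclic shift.</li>
<li><code>shifted_window_mask</code> blocks attention between tokens that became neighbors only because of the shift.</li>
</ul>
<p>The function names, the -100 mask value, and the sizes in the usage lines are illustrative assumptions:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import torch


def window_partition(x, window_size):
    # (B, H, W, C) to (B * num_windows, window_size * window_size, C)
    B, H, W, C = x.shape
    x = x.reshape(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size * window_size, C)


def window_reverse(windows, window_size, H, W):
    # Inverse of window_partition: back to (B, H, W, C)
    B = windows.shape[0] // ((H // window_size) * (W // window_size))
    x = windows.reshape(B, H // window_size, W // window_size, window_size, window_size, -1)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(B, H, W, -1)


def shifted_window_mask(H, W, window_size, shift_size):
    # Label the regions of the cyclically shifted map (assumes shift_size is
    # positive and smaller than window_size); tokens from different regions
    # were not adjacent in the original image and must not attend to each other
    img_mask = torch.zeros(1, H, W, 1)
    slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
    region = 0
    for h in slices:
        for w in slices:
            img_mask[:, h, w, :] = region
            region += 1
    mask_windows = window_partition(img_mask, window_size).squeeze(-1)  # (num_windows, N)
    diff = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)        # (num_windows, N, N)
    # Large negative bias where regions differ, so softmax gives those pairs ~0 weight
    return diff.masked_fill(diff != 0, -100.0)


# Usage sketch for one shifted block (hypothetical sizes)
shift = 2
x = torch.randn(2, 8, 8, 96)                                   # (B, H, W, C)
shifted = torch.roll(x, shifts=(-shift, -shift), dims=(1, 2))  # cyclic shift
windows = window_partition(shifted, window_size=4)             # (8, 16, 96)
mask = shifted_window_mask(8, 8, window_size=4, shift_size=shift)  # (4, 16, 16)
# windows = WindowAttention(96, (4, 4), num_heads=3)(windows, mask=mask)
merged = window_reverse(windows, 4, 8, 8)
restored = torch.roll(merged, shifts=(shift, shift), dims=(1, 2))  # undo the shift
</code></pre></div></div>
<p>The window size bounds the attention cost. Each token attends to only window_size × window_size others, no matter how large the feature map is.</p>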
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>Vision Transformers treat an image as a sequence of patches and model it with self-attention rather than convolutions, a significant departure from traditional CNNs. This guide covered the essential components for implementing a Vision Transformer from scratch: image patching, position embeddings, transformer encoder blocks, and classification heads. It also covered regularization with stochastic depth and two influential variants, DeiT and Swin Transformer.</p>
<p>With these fundamentals in place, you can implement your own ViT models and experiment with modifications suited to specific tasks.</p>
</section>
<section id="references" class="level2">
<h2 class="anchored" data-anchor-id="references" id="references">References</h2>
<ol type="1">
<li>Dosovitskiy, A., et al.&nbsp;(2020). “An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale.” arXiv:2010.11929.</li>
<li>Touvron, H., et al.&nbsp;(2021). “Training data-efficient image transformers &amp; distillation through attention.” arXiv:2012.12877.</li>
<li>Liu, Z., et al.&nbsp;(2021). “Swin Transformer: Hierarchical Vision Transformer using Shifted Windows.” arXiv:2103.14030.</li>
</ol>


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Active Learning Influence Selection: A Comprehensive Guide]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/model-training/influence-selection/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/model-training/influence-selection/</guid>
      <pubDate>Sat, 19 Apr 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>research</category>
      <category>advanced</category>
      <content:encoded><![CDATA[






<section id="active-learning-influence-selection-a-comprehensive-guide" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/model-training/influence-selection/learn.png" class="img-fluid"></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction" id="introduction">Introduction</h2>
<p>Active learning is a machine learning paradigm in which the algorithm interactively queries an oracle (typically a human annotator) for the labels of new data points. The key idea is to label only the most informative samples, so that the model reaches a target level of performance with fewer labels than random sampling would need. This guide focuses on influence selection methods: the strategies that decide which unlabeled samples to send for annotation.</p>
</section>
<section id="fundamentals-of-active-learning" class="level2">
<h2 class="anchored" data-anchor-id="fundamentals-of-active-learning" id="fundamentals-of-active-learning">Fundamentals of Active Learning</h2>
<section id="the-active-learning-loop" class="level3">
<h3 class="anchored" data-anchor-id="the-active-learning-loop" id="the-active-learning-loop">The Active Learning Loop</h3>
<p>The typical active learning process follows these steps:</p>
<ol type="1">
<li>Start with a small labeled dataset and a large unlabeled pool</li>
<li>Train an initial model on the labeled data</li>
<li>Apply an influence selection strategy to choose informative samples from the unlabeled pool</li>
<li>Query the oracle for labels of the selected samples</li>
<li>Add the newly labeled samples to the training set</li>
<li>Retrain the model and return to step 3 until a stopping condition is met (for example, the labeling budget is exhausted or performance stops improving)</li>
</ol>
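<p>The loop can be sketched with scikit-learn, assuming a classifier that exposes <code>predict_proba</code> like the strategy functions in this guide. In this sketch, the known labels <code>y</code> play the role of the oracle. <code>least_confidence_indices</code> is a hypothetical variant of least-confidence selection that returns pool indices rather than samples, so the labeled set can be updated:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression


def least_confidence_indices(model, X_pool, k):
    # Positions (within X_pool) of the k samples with the lowest top-class probability
    confidences = model.predict_proba(X_pool).max(axis=1)
    return np.argsort(confidences)[:k]


def active_learning_loop(X, y, n_initial=20, k=10, n_rounds=5, seed=0):
    rng = np.random.default_rng(seed)
    # Step 1: small random labeled set, the rest forms the unlabeled pool
    labeled = rng.choice(len(X), size=n_initial, replace=False)
    pool = np.setdiff1d(np.arange(len(X)), labeled)
    model = LogisticRegression(max_iter=1000)
    for _ in range(n_rounds):
        # Steps 2 and 6: (re)train on the current labeled set
        model.fit(X[labeled], y[labeled])
        # Step 3: pick the k most informative pool samples (mapped to global indices)
        picked = pool[least_confidence_indices(model, X[pool], k)]
        # Steps 4-5: "annotate" them with y and move them into the labeled set
        labeled = np.concatenate([labeled, picked])
        pool = np.setdiff1d(pool, picked)
    model.fit(X[labeled], y[labeled])
    return model, labeled, pool


X, y = make_classification(n_samples=300, random_state=0)
model, labeled, pool = active_learning_loop(X, y)
print(len(labeled), len(pool))  # 70 230
</code></pre></div></div>
<p>The loop ends with 20 initial samples plus 5 rounds of 10 queried samples. Swapping <code>least_confidence_indices</code> for another scoring function changes the strategy without touching the loop.</p>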
</section>
<section id="pool-based-vs.-stream-based-learning" class="level3">
<h3 class="anchored" data-anchor-id="pool-based-vs.-stream-based-learning" id="pool-based-vs.-stream-based-learning">Pool-Based vs.&nbsp;Stream-Based Learning</h3>
<ul>
<li><strong>Pool-based</strong>: The learner has access to a pool of unlabeled data and selects the most informative samples</li>
<li><strong>Stream-based</strong>: Samples arrive sequentially, and the learner must decide on-the-fly whether to request labels</li>
</ul>
</section>
</section>
<section id="influence-selection-strategies" class="level2">
<h2 class="anchored" data-anchor-id="influence-selection-strategies" id="influence-selection-strategies">Influence Selection Strategies</h2>
<p>Influence selection determines which unlabeled samples would be most beneficial to label next. The main strategies are described below.</p>
</section>
<section id="uncertainty-based-methods" class="level2">
<h2 class="anchored" data-anchor-id="uncertainty-based-methods" id="uncertainty-based-methods">Uncertainty-Based Methods</h2>
<p>These methods select samples that the model is most uncertain about.</p>
<section id="least-confidence" class="level3">
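<h3 class="anchored" data-anchor-id="least-confidence" id="least-confidence">Least Confidence</h3>
<p>Selects samples whose most likely class has the lowest predicted probability, that is, the samples the model is least sure how to classify:</p>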
<div id="b3a7ed52" class="cell" data-execution_count="1">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb1"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> least_confidence(model, unlabeled_pool, k):</span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Predict class probabilities for each sample in the unlabeled pool</span></span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a>    probabilities <span class="op">=</span> model.predict_proba(unlabeled_pool)</span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Confidence is the probability assigned to the most likely class</span></span>
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a>    confidences <span class="op">=</span> np.<span class="bu">max</span>(probabilities, axis<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb1-9"><a href="#cb1-9" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb1-10"><a href="#cb1-10" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Select the k samples with the lowest confidence</span></span>
<span id="cb1-11"><a href="#cb1-11" aria-hidden="true" tabindex="-1"></a>    least_confident_indices <span class="op">=</span> np.argsort(confidences)[:k]</span>
<span id="cb1-12"><a href="#cb1-12" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb1-13"><a href="#cb1-13" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> unlabeled_pool[least_confident_indices]</span></code></pre></div></div>
</div>
</section>
<section id="margin-sampling" class="level3">
<h3 class="anchored" data-anchor-id="margin-sampling" id="margin-sampling">Margin Sampling</h3>
<p>Selects samples with the smallest margin between the two most likely classes:</p>
<div id="c336c5cb" class="cell" data-execution_count="2">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> margin_sampling(model, unlabeled_pool, k):</span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Predict probabilities for each sample in the unlabeled pool</span></span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a>    probabilities <span class="op">=</span> model.predict_proba(unlabeled_pool)</span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Sort the probabilities in descending order</span></span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a>    sorted_probs <span class="op">=</span> np.sort(probabilities, axis<span class="op">=</span><span class="dv">1</span>)[:, ::<span class="op">-</span><span class="dv">1</span>]</span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Calculate the margin between the first and second most probable classes</span></span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a>    margins <span class="op">=</span> sorted_probs[:, <span class="dv">0</span>] <span class="op">-</span> sorted_probs[:, <span class="dv">1</span>]</span>
<span id="cb2-10"><a href="#cb2-10" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-11"><a href="#cb2-11" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Select the k samples with the smallest margins</span></span>
<span id="cb2-12"><a href="#cb2-12" aria-hidden="true" tabindex="-1"></a>    smallest_margin_indices <span class="op">=</span> np.argsort(margins)[:k]</span>
<span id="cb2-13"><a href="#cb2-13" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-14"><a href="#cb2-14" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> unlabeled_pool[smallest_margin_indices]</span></code></pre></div></div>
</div>
</section>
<section id="entropy-based-sampling" class="level3">
<h3 class="anchored" data-anchor-id="entropy-based-sampling" id="entropy-based-sampling">Entropy-Based Sampling</h3>
<p>Selects samples with the highest predictive entropy:</p>
<div id="cd66cd2e" class="cell" data-execution_count="3">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> entropy_sampling(model, unlabeled_pool, k):</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Predict probabilities for each sample in the unlabeled pool</span></span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a>    probabilities <span class="op">=</span> model.predict_proba(unlabeled_pool)</span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Calculate entropy for each sample</span></span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a>    entropies <span class="op">=</span> <span class="op">-</span>np.<span class="bu">sum</span>(probabilities <span class="op">*</span> np.log(probabilities <span class="op">+</span> <span class="fl">1e-10</span>), axis<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Select the k samples with the highest entropy</span></span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a>    highest_entropy_indices <span class="op">=</span> np.argsort(entropies)[::<span class="op">-</span><span class="dv">1</span>][:k]</span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> unlabeled_pool[highest_entropy_indices]</span></code></pre></div></div>
</div>
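<p>The three scores above do not always agree. A small sketch on a hand-made probability matrix (illustrative values, not real model output) shows least confidence and entropy preferring a sample spread over all three classes, while margin sampling prefers a sample torn between exactly two classes:</p>

```python
import numpy as np

# Hypothetical predicted probabilities for three samples over three classes
probs = np.array([
    [0.50, 0.50, 0.00],  # torn between two classes
    [0.40, 0.30, 0.30],  # spread over all three classes
    [0.80, 0.10, 0.10],  # fairly confident
])

# Least confidence: probability of the top class (lower = more uncertain)
confidence = probs.max(axis=1)

# Margin: gap between the two most probable classes (lower = more uncertain)
top_two = np.sort(probs, axis=1)[:, ::-1][:, :2]
margin = top_two[:, 0] - top_two[:, 1]

# Entropy: spread of the whole distribution (higher = more uncertain)
entropy = -np.sum(probs * np.log(probs + 1e-10), axis=1)

print("least confidence picks:", int(np.argmin(confidence)))  # 1
print("margin picks:", int(np.argmin(margin)))                # 0
print("entropy picks:", int(np.argmax(entropy)))              # 1
```

<p>Margin sampling ignores any probability mass outside the top two classes, which is why it ranks the first row as most uncertain even though the second row's distribution is flatter overall.</p>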
</section>
<section id="bayesian-active-learning-by-disagreement-bald" class="level3">
<h3 class="anchored" data-anchor-id="bayesian-active-learning-by-disagreement-bald" id="bayesian-active-learning-by-disagreement-bald">Bayesian Active Learning by Disagreement (BALD)</h3>
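<p>For Bayesian models, BALD selects samples that maximize the mutual information between the predicted label and the model parameters. The code assumes that each call to <code>predict_proba</code> returns predictions from a different posterior sample (for example, a forward pass with dropout left active); if the model is deterministic, every BALD score is zero:</p>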
<div id="4f4228cf" class="cell" data-execution_count="4">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> bald_sampling(bayesian_model, unlabeled_pool, k, n_samples<span class="op">=</span><span class="dv">100</span>):</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Get multiple predictions by sampling from the model's posterior</span></span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a>    probs_samples <span class="op">=</span> []</span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(n_samples):</span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a>        probs <span class="op">=</span> bayesian_model.predict_proba(unlabeled_pool)</span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a>        probs_samples.append(probs)</span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Stack into a 3D array: (samples, data points, classes)</span></span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a>    probs_samples <span class="op">=</span> np.stack(probs_samples)</span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Calculate the average probability across all samples</span></span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a>    mean_probs <span class="op">=</span> np.mean(probs_samples, axis<span class="op">=</span><span class="dv">0</span>)</span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-14"><a href="#cb4-14" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Calculate the entropy of the average prediction</span></span>
<span id="cb4-15"><a href="#cb4-15" aria-hidden="true" tabindex="-1"></a>    entropy_mean <span class="op">=</span> <span class="op">-</span>np.<span class="bu">sum</span>(mean_probs <span class="op">*</span> np.log(mean_probs <span class="op">+</span> <span class="fl">1e-10</span>), axis<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb4-16"><a href="#cb4-16" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-17"><a href="#cb4-17" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Calculate the average entropy across all samples</span></span>
<span id="cb4-18"><a href="#cb4-18" aria-hidden="true" tabindex="-1"></a>    entropy_samples <span class="op">=</span> <span class="op">-</span>np.<span class="bu">sum</span>(probs_samples <span class="op">*</span> np.log(probs_samples <span class="op">+</span> <span class="fl">1e-10</span>), axis<span class="op">=</span><span class="dv">2</span>)</span>
<span id="cb4-19"><a href="#cb4-19" aria-hidden="true" tabindex="-1"></a>    mean_entropy <span class="op">=</span> np.mean(entropy_samples, axis<span class="op">=</span><span class="dv">0</span>)</span>
<span id="cb4-20"><a href="#cb4-20" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-21"><a href="#cb4-21" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Mutual information = entropy of the mean - mean of entropies</span></span>
<span id="cb4-22"><a href="#cb4-22" aria-hidden="true" tabindex="-1"></a>    bald_scores <span class="op">=</span> entropy_mean <span class="op">-</span> mean_entropy</span>
<span id="cb4-23"><a href="#cb4-23" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-24"><a href="#cb4-24" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Select the k samples with the highest BALD scores</span></span>
<span id="cb4-25"><a href="#cb4-25" aria-hidden="true" tabindex="-1"></a>    highest_bald_indices <span class="op">=</span> np.argsort(bald_scores)[::<span class="op">-</span><span class="dv">1</span>][:k]</span>
<span id="cb4-26"><a href="#cb4-26" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb4-27"><a href="#cb4-27" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> unlabeled_pool[highest_bald_indices]</span></code></pre></div></div>
</div>
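<p>The difference between BALD and plain entropy is easiest to see on hand-made posterior samples. The sketch below factors the score computation out of <code>bald_sampling</code> into a helper (<code>bald_scores</code>, a name introduced here) that takes an array of shape <code>(T, N, C)</code> from T stochastic forward passes. Both points have the same mean prediction, so entropy-based sampling cannot separate them, but only the second has posterior samples that disagree:</p>

```python
import numpy as np

def bald_scores(probs_samples, eps=1e-10):
    """BALD mutual information from T posterior samples of shape (T, N, C)."""
    mean_probs = probs_samples.mean(axis=0)
    entropy_of_mean = -np.sum(mean_probs * np.log(mean_probs + eps), axis=1)
    mean_of_entropies = -np.sum(probs_samples * np.log(probs_samples + eps), axis=2).mean(axis=0)
    return entropy_of_mean - mean_of_entropies

# Two hypothetical posterior samples (T=2) for two points (N=2) and two classes (C=2)
probs_samples = np.array([
    [[0.5, 0.5], [0.99, 0.01]],  # posterior sample 1
    [[0.5, 0.5], [0.01, 0.99]],  # posterior sample 2
])

scores = bald_scores(probs_samples)
print(np.round(scores, 3))
```

<p>The first point scores about 0: every posterior sample predicts 50/50, so the uncertainty is noise that a label would not resolve. The second scores about 0.64 because the parameter samples confidently disagree, which a label can settle. This is also why the posterior samples must differ: if every pass returns identical probabilities, the two entropy terms are equal and every score is zero.</p>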
</section>
</section>
<section id="diversity-based-methods" class="level2">
<h2 class="anchored" data-anchor-id="diversity-based-methods" id="diversity-based-methods">Diversity-Based Methods</h2>
<p>These methods aim to select a diverse set of examples to ensure broad coverage of the input space.</p>
<section id="clustering-based-sampling" class="level3">
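<h3 class="anchored" data-anchor-id="clustering-based-sampling" id="clustering-based-sampling">Clustering-Based Sampling</h3>
<p>Clusters the unlabeled pool with scikit-learn's <code>KMeans</code> and selects the sample closest to each cluster center:</p>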
<div id="1c00d11c" class="cell" data-execution_count="5">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> clustering_based_sampling(unlabeled_pool, k, n_clusters<span class="op">=</span><span class="va">None</span>):</span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> n_clusters <span class="kw">is</span> <span class="va">None</span>:</span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a>        n_clusters <span class="op">=</span> k</span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Apply K-means clustering</span></span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a>    kmeans <span class="op">=</span> KMeans(n_clusters<span class="op">=</span>n_clusters)</span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a>    kmeans.fit(unlabeled_pool)</span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Get the cluster centers and distances to each point</span></span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a>    centers <span class="op">=</span> kmeans.cluster_centers_</span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a>    distances <span class="op">=</span> kmeans.transform(unlabeled_pool)  <span class="co"># Distance to each cluster center</span></span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Select one sample from each cluster (closest to the center)</span></span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a>    selected_indices <span class="op">=</span> []</span>
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(n_clusters):</span>
<span id="cb5-16"><a href="#cb5-16" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Get the samples in this cluster</span></span>
<span id="cb5-17"><a href="#cb5-17" aria-hidden="true" tabindex="-1"></a>        cluster_samples <span class="op">=</span> np.where(kmeans.labels_ <span class="op">==</span> i)[<span class="dv">0</span>]</span>
<span id="cb5-18"><a href="#cb5-18" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb5-19"><a href="#cb5-19" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Find the sample closest to the center</span></span>
<span id="cb5-20"><a href="#cb5-20" aria-hidden="true" tabindex="-1"></a>        closest_sample <span class="op">=</span> cluster_samples[np.argmin(distances[cluster_samples, i])]</span>
<span id="cb5-21"><a href="#cb5-21" aria-hidden="true" tabindex="-1"></a>        selected_indices.append(closest_sample)</span>
<span id="cb5-22"><a href="#cb5-22" aria-hidden="true" tabindex="-1"></a>    </span>
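<span id="cb5-23"><a href="#cb5-23" aria-hidden="true" tabindex="-1"></a>    <span class="co"># If k exceeds n_clusters, top up with the unselected points farthest</span></span>
<span id="cb5-24"><a href="#cb5-24" aria-hidden="true" tabindex="-1"></a>    <span class="co"># from their own center (a distance-based stand-in, since no model is passed in)</span></span>
<span id="cb5-25"><a href="#cb5-25" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> k <span class="op">&gt;</span> n_clusters:</span>
<span id="cb5-26"><a href="#cb5-26" aria-hidden="true" tabindex="-1"></a>        own_dist <span class="op">=</span> distances[np.arange(<span class="bu">len</span>(unlabeled_pool)), kmeans.labels_]</span>
<span id="cb5-27"><a href="#cb5-27" aria-hidden="true" tabindex="-1"></a>        remaining <span class="op">=</span> [j <span class="cf">for</span> j <span class="kw">in</span> np.argsort(own_dist)[::<span class="op">-</span><span class="dv">1</span>] <span class="cf">if</span> j <span class="kw">not</span> <span class="kw">in</span> selected_indices]</span>
<span id="cb5-28"><a href="#cb5-28" aria-hidden="true" tabindex="-1"></a>        selected_indices.extend(remaining[:k <span class="op">-</span> n_clusters])</span>
<span id="cb5-29"><a href="#cb5-29" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb5-30"><a href="#cb5-30" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> unlabeled_pool[selected_indices[:k]]</span></code></pre></div></div>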
</div>
</section>
<section id="core-set-approach" class="level3">
<h3 class="anchored" data-anchor-id="core-set-approach" id="core-set-approach">Core-Set Approach</h3>
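<p>The core-set approach aims to select a subset of data that best represents the whole dataset. The version below is a simplified one-shot variant: it ranks each unlabeled sample by its distance to the nearest labeled sample (using scikit-learn's <code>pairwise_distances</code>) and keeps the k farthest:</p>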
<div id="c4909f25" class="cell" data-execution_count="6">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> core_set_sampling(labeled_pool, unlabeled_pool, k):</span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Combine labeled and unlabeled data for distance calculations</span></span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a>    all_data <span class="op">=</span> np.vstack((labeled_pool, unlabeled_pool))</span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Compute pairwise distances</span></span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a>    distances <span class="op">=</span> pairwise_distances(all_data)</span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Split distances into labeled-unlabeled and unlabeled-unlabeled</span></span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a>    n_labeled <span class="op">=</span> labeled_pool.shape[<span class="dv">0</span>]</span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a>    dist_labeled_unlabeled <span class="op">=</span> distances[:n_labeled, n_labeled:]</span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a>    <span class="co"># For each unlabeled sample, find the minimum distance to any labeled sample</span></span>
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a>    min_distances <span class="op">=</span> np.<span class="bu">min</span>(dist_labeled_unlabeled, axis<span class="op">=</span><span class="dv">0</span>)</span>
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-15"><a href="#cb6-15" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Select the k samples with the largest minimum distances</span></span>
<span id="cb6-16"><a href="#cb6-16" aria-hidden="true" tabindex="-1"></a>    farthest_indices <span class="op">=</span> np.argsort(min_distances)[::<span class="op">-</span><span class="dv">1</span>][:k]</span>
<span id="cb6-17"><a href="#cb6-17" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb6-18"><a href="#cb6-18" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> unlabeled_pool[farthest_indices]</span></code></pre></div></div>
</div>
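<p>The function above ranks every unlabeled sample by its distance to the labeled set only once, so it can spend several picks on near-duplicate points that all lie far from the labeled data. The greedy k-center procedure usually associated with core-set selection instead updates the distances after each pick. A minimal NumPy sketch, assuming Euclidean distance (<code>greedy_k_center</code> is a hypothetical name), that returns indices into the unlabeled pool:</p>

```python
import numpy as np

def greedy_k_center(labeled_pool, unlabeled_pool, k):
    """Greedy k-center: repeatedly pick the unlabeled point farthest from
    everything covered so far (the labeled pool plus earlier picks)."""
    # Distance from each unlabeled point to its nearest labeled point
    diffs = unlabeled_pool[:, None, :] - labeled_pool[None, :, :]
    min_dist = np.sqrt((diffs ** 2).sum(axis=2)).min(axis=1)

    selected = []
    for _ in range(min(k, len(unlabeled_pool))):
        idx = int(np.argmax(min_dist))
        selected.append(idx)
        # The new pick now covers its neighbors: shrink their distances
        dist_to_new = np.linalg.norm(unlabeled_pool - unlabeled_pool[idx], axis=1)
        min_dist = np.minimum(min_dist, dist_to_new)
        min_dist[idx] = -np.inf  # never pick the same point twice
    return np.array(selected)

labeled = np.array([[0.0, 0.0]])
unlabeled = np.array([[10.0, 0.0], [10.1, 0.0], [5.0, 0.0]])

one_shot = np.argsort(np.linalg.norm(unlabeled - labeled[0], axis=1))[::-1][:2]
greedy = greedy_k_center(labeled, unlabeled, k=2)
print(sorted(one_shot.tolist()), greedy.tolist())  # [0, 1] [1, 2]
```

<p>The one-shot ranking spends both picks on the two nearly identical points around x = 10, whereas the greedy version takes one of them and then the point at x = 5, which has become the least covered.</p>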
</section>
</section>
<section id="expected-model-change" class="level2">
<h2 class="anchored" data-anchor-id="expected-model-change" id="expected-model-change">Expected Model Change</h2>
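<p>The Expected Model Change (EMC) method selects samples that would cause the greatest change in the model if they were labeled. Here the change is measured as the expected gradient length: the norm of the training-loss gradient under each possible label, weighted by the model's predicted probability of that label:</p>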
<div id="81d4372e" class="cell" data-execution_count="7">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> expected_model_change(model, unlabeled_pool, k):</span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Predict probabilities for each sample in the unlabeled pool</span></span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a>    probabilities <span class="op">=</span> model.predict_proba(unlabeled_pool)</span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a>    n_classes <span class="op">=</span> probabilities.shape[<span class="dv">1</span>]</span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Calculate expected gradient length for each sample</span></span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a>    expected_changes <span class="op">=</span> []</span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> i, x <span class="kw">in</span> <span class="bu">enumerate</span>(unlabeled_pool):</span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Calculate expected gradient length across all possible labels</span></span>
<span id="cb7-10"><a href="#cb7-10" aria-hidden="true" tabindex="-1"></a>        change <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb7-11"><a href="#cb7-11" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> c <span class="kw">in</span> <span class="bu">range</span>(n_classes):</span>
<span id="cb7-12"><a href="#cb7-12" aria-hidden="true" tabindex="-1"></a>            <span class="co"># For each possible class, calculate the gradient if this was the true label</span></span>
<span id="cb7-13"><a href="#cb7-13" aria-hidden="true" tabindex="-1"></a>            x_expanded <span class="op">=</span> x.reshape(<span class="dv">1</span>, <span class="op">-</span><span class="dv">1</span>)</span>
<span id="cb7-14"><a href="#cb7-14" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Gradient of the training loss w.r.t. the model parameters, assuming label c</span></span>
<span id="cb7-15"><a href="#cb7-15" aria-hidden="true" tabindex="-1"></a>            <span class="co"># (compute_gradient is a placeholder; see the note below)</span></span>
<span id="cb7-16"><a href="#cb7-16" aria-hidden="true" tabindex="-1"></a>            gradient <span class="op">=</span> compute_gradient(model, x_expanded, c)</span>
<span id="cb7-17"><a href="#cb7-17" aria-hidden="true" tabindex="-1"></a>            norm_gradient <span class="op">=</span> np.linalg.norm(gradient)</span>
<span id="cb7-18"><a href="#cb7-18" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb7-19"><a href="#cb7-19" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Weight by the probability of this class</span></span>
<span id="cb7-20"><a href="#cb7-20" aria-hidden="true" tabindex="-1"></a>            change <span class="op">+=</span> probabilities[i, c] <span class="op">*</span> norm_gradient</span>
<span id="cb7-21"><a href="#cb7-21" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb7-22"><a href="#cb7-22" aria-hidden="true" tabindex="-1"></a>        expected_changes.append(change)</span>
<span id="cb7-23"><a href="#cb7-23" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-24"><a href="#cb7-24" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Select the k samples with the highest expected change</span></span>
<span id="cb7-25"><a href="#cb7-25" aria-hidden="true" tabindex="-1"></a>    highest_change_indices <span class="op">=</span> np.argsort(expected_changes)[::<span class="op">-</span><span class="dv">1</span>][:k]</span>
<span id="cb7-26"><a href="#cb7-26" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb7-27"><a href="#cb7-27" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> unlabeled_pool[highest_change_indices]</span></code></pre></div></div>
</div>
<p><em>Note: The <code>compute_gradient</code> function would need to be implemented based on the specific model being used.</em></p>
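<p>For a concrete <code>compute_gradient</code>, consider multinomial logistic (softmax) regression, a model assumed here purely for illustration. With weights <code>W</code> of shape <code>(C, D)</code> and bias <code>b</code>, the cross-entropy gradient for a sample <code>x</code> with assumed label <code>c</code> is the outer product of <code>p - onehot(c)</code> with <code>x</code> (plus <code>p - onehot(c)</code> itself for the bias), so its norm has a closed form and the per-sample loop can be vectorized. A sketch, with the hypothetical helper name <code>egl_scores</code>:</p>

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def egl_scores(W, b, X):
    """Expected gradient length for softmax regression with W (C, D) and b (C,)."""
    probs = softmax(X @ W.T + b)          # (N, C) predicted class probabilities
    sq_x = (X ** 2).sum(axis=1) + 1.0     # ||x||^2 plus 1 for the bias term
    scores = np.zeros(len(X))
    for c in range(probs.shape[1]):
        residual = probs.copy()
        residual[:, c] -= 1.0             # p - onehot(c)
        # Frobenius norm of [outer(residual, x), residual] = ||residual|| * sqrt(||x||^2 + 1)
        grad_norm = np.sqrt((residual ** 2).sum(axis=1) * sq_x)
        scores += probs[:, c] * grad_norm  # weight by the predicted probability of label c
    return scores

W = np.zeros((2, 2))  # untrained model: uniform predictions
b = np.zeros(2)
X = np.array([[0.0, 0.0], [3.0, 4.0]])
print(egl_scores(W, b, X))  # approximately [0.707, 3.606]
```

<p>Both samples receive identical predictions, yet the second scores higher only because its feature vector is longer. Gradient-length criteria are sensitive to feature scale, which is one reason to standardize inputs before using them.</p>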
</section>
<section id="expected-error-reduction" class="level2">
<h2 class="anchored" data-anchor-id="expected-error-reduction" id="expected-error-reduction">Expected Error Reduction</h2>
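<p>The Expected Error Reduction method selects the samples that, once labeled, are expected to reduce the model's future error the most. The code below uses the total predictive entropy over the remaining unlabeled pool as a proxy for that error:</p>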
<div id="3e417a22" class="cell" data-execution_count="8">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb8"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> expected_error_reduction(model, unlabeled_pool, unlabeled_pool_remaining, k):</span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Predict probabilities for all remaining unlabeled data</span></span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a>    current_probs <span class="op">=</span> model.predict_proba(unlabeled_pool_remaining)</span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a>    current_entropy <span class="op">=</span> <span class="op">-</span>np.<span class="bu">sum</span>(current_probs <span class="op">*</span> np.log(current_probs <span class="op">+</span> <span class="fl">1e-10</span>), axis<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a>    expected_error_reductions <span class="op">=</span> []</span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a>    <span class="co"># For each sample in the unlabeled pool we're considering</span></span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> i, x <span class="kw">in</span> <span class="bu">enumerate</span>(unlabeled_pool):</span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Predict probabilities for this sample</span></span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a>        probs <span class="op">=</span> model.predict_proba(x.reshape(<span class="dv">1</span>, <span class="op">-</span><span class="dv">1</span>))[<span class="dv">0</span>]</span>
<span id="cb8-12"><a href="#cb8-12" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-13"><a href="#cb8-13" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Calculate the expected error reduction for each possible label</span></span>
<span id="cb8-14"><a href="#cb8-14" aria-hidden="true" tabindex="-1"></a>        error_reduction <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb8-15"><a href="#cb8-15" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> c <span class="kw">in</span> <span class="bu">range</span>(<span class="bu">len</span>(probs)):</span>
<span id="cb8-16"><a href="#cb8-16" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Create a hypothetical new model with this labeled sample</span></span>
<span id="cb8-17"><a href="#cb8-17" aria-hidden="true" tabindex="-1"></a>            <span class="co"># For simplicity, we use a placeholder function</span></span>
<span id="cb8-18"><a href="#cb8-18" aria-hidden="true" tabindex="-1"></a>            hypothetical_model <span class="op">=</span> train_with_additional_sample(model, x, c)</span>
<span id="cb8-19"><a href="#cb8-19" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb8-20"><a href="#cb8-20" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Get new probabilities with this model</span></span>
<span id="cb8-21"><a href="#cb8-21" aria-hidden="true" tabindex="-1"></a>            new_probs <span class="op">=</span> hypothetical_model.predict_proba(unlabeled_pool_remaining)</span>
<span id="cb8-22"><a href="#cb8-22" aria-hidden="true" tabindex="-1"></a>            new_entropy <span class="op">=</span> <span class="op">-</span>np.<span class="bu">sum</span>(new_probs <span class="op">*</span> np.log(new_probs <span class="op">+</span> <span class="fl">1e-10</span>), axis<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb8-23"><a href="#cb8-23" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb8-24"><a href="#cb8-24" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Expected entropy reduction</span></span>
<span id="cb8-25"><a href="#cb8-25" aria-hidden="true" tabindex="-1"></a>            reduction <span class="op">=</span> np.<span class="bu">sum</span>(current_entropy <span class="op">-</span> new_entropy)</span>
<span id="cb8-26"><a href="#cb8-26" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb8-27"><a href="#cb8-27" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Weight by the probability of this class</span></span>
<span id="cb8-28"><a href="#cb8-28" aria-hidden="true" tabindex="-1"></a>            error_reduction <span class="op">+=</span> probs[c] <span class="op">*</span> reduction</span>
<span id="cb8-29"><a href="#cb8-29" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-30"><a href="#cb8-30" aria-hidden="true" tabindex="-1"></a>        expected_error_reductions.append(error_reduction)</span>
<span id="cb8-31"><a href="#cb8-31" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-32"><a href="#cb8-32" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Select the k samples with the highest expected error reduction</span></span>
<span id="cb8-33"><a href="#cb8-33" aria-hidden="true" tabindex="-1"></a>    highest_reduction_indices <span class="op">=</span> np.argsort(expected_error_reductions)[::<span class="op">-</span><span class="dv">1</span>][:k]</span>
<span id="cb8-34"><a href="#cb8-34" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb8-35"><a href="#cb8-35" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> unlabeled_pool[highest_reduction_indices]</span></code></pre></div></div>
</div>
<p><em>Note: The <code>train_with_additional_sample</code> function would need to be implemented based on the specific model being used.</em></p>
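<p>One way to fill in <code>train_with_additional_sample</code> is to refit a fresh copy of the model on its current training set plus the hypothetical pair <code>(x, c)</code>, which requires the model to keep its training data. The sketch below assumes a deliberately tiny, hypothetical classifier (<code>CentroidClassifier</code>: a softmax over negative squared distances to class centroids) so that retraining is cheap. With a real model the same pattern applies, but it costs one full retrain per candidate sample and per class:</p>

```python
import numpy as np

class CentroidClassifier:
    """Hypothetical toy model: softmax over negative squared distances to class centroids."""

    def __init__(self, n_classes):
        self.n_classes = n_classes

    def fit(self, X, y):
        self.X_, self.y_ = X, y  # keep the training data so it can be extended later
        self.centroids_ = np.stack([X[y == c].mean(axis=0) for c in range(self.n_classes)])
        return self

    def predict_proba(self, X):
        logits = -((X[:, None, :] - self.centroids_[None, :, :]) ** 2).sum(axis=2)
        logits -= logits.max(axis=1, keepdims=True)
        e = np.exp(logits)
        return e / e.sum(axis=1, keepdims=True)

def train_with_additional_sample(model, x, c):
    """Return a new model refit on the original training data plus (x, c)."""
    X_new = np.vstack([model.X_, x.reshape(1, -1)])
    y_new = np.append(model.y_, c)
    return CentroidClassifier(model.n_classes).fit(X_new, y_new)

X_train = np.array([[0.0, 0.0], [4.0, 0.0]])
y_train = np.array([0, 1])
model = CentroidClassifier(n_classes=2).fit(X_train, y_train)

# Pretend the point at x = 1 were labeled as class 1
hypothetical = train_with_additional_sample(model, np.array([1.0, 0.0]), 1)
print(hypothetical.centroids_)  # class-1 centroid moves from [4, 0] to [2.5, 0]
```

<p>The original <code>model</code> is left untouched, so the expected-error loop can try every candidate label against the same starting point.</p>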
</section>
<section id="influence-functions" class="level2">
<h2 class="anchored" data-anchor-id="influence-functions" id="influence-functions">Influence Functions</h2>
<p>Influence functions approximate the effect of adding or removing a training example without retraining the model:</p>
<div id="cdab908a" class="cell" data-execution_count="9">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb9"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> influence_function_sampling(model, unlabeled_pool, labeled_pool, k, labels):</span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a>    influences <span class="op">=</span> []</span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a>    <span class="co"># For each unlabeled sample</span></span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> x_u <span class="kw">in</span> unlabeled_pool:</span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Calculate the influence of adding this sample to the training set</span></span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a>        influence <span class="op">=</span> calculate_influence(model, x_u, labeled_pool, labels)</span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a>        influences.append(influence)</span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Select the k samples with the highest influence</span></span>
<span id="cb9-11"><a href="#cb9-11" aria-hidden="true" tabindex="-1"></a>    highest_influence_indices <span class="op">=</span> np.argsort(influences)[::<span class="op">-</span><span class="dv">1</span>][:k]</span>
<span id="cb9-12"><a href="#cb9-12" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb9-13"><a href="#cb9-13" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> unlabeled_pool[highest_influence_indices]</span></code></pre></div></div>
</div>
<p><em>Note: The <code>calculate_influence</code> function would need to be implemented based on the specific model and influence metric being used.</em></p>
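<p>As one possible <code>calculate_influence</code>, the sketch below uses the standard influence-function approximation of the parameter change caused by upweighting a training point: the inverse Hessian of the training loss applied to that point's loss gradient (with a minus sign). Because an unlabeled sample has no label, it averages the norm of that change over labels weighted by the model's own predictions. The scoring rule, the binary logistic-regression model, the damping term, and the name <code>influence_scores</code> are all assumptions made for illustration. This logistic-regression Hessian does not depend on the labels, so <code>labels</code> is not needed here:</p>

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def influence_scores(theta, X_labeled, X_unlabeled, damping=1e-3):
    """Expected norm of the influence-function parameter change, -H^{-1} grad L(x, y),
    for binary logistic regression with weights theta (D,), averaged over y ~ p(y | x)."""
    p_lab = sigmoid(X_labeled @ theta)
    # Hessian of the mean log-loss on the labeled pool, damped so it is invertible
    H = (X_labeled * (p_lab * (1.0 - p_lab))[:, None]).T @ X_labeled / len(X_labeled)
    H_inv = np.linalg.inv(H + damping * np.eye(len(theta)))

    p_u = sigmoid(X_unlabeled @ theta)
    scores = np.zeros(len(X_unlabeled))
    for y, weight in ((1.0, p_u), (0.0, 1.0 - p_u)):
        grads = (p_u - y)[:, None] * X_unlabeled  # log-loss gradient for label y, (N, D)
        param_change = grads @ H_inv             # row i is (H^{-1} g_i)^T since H is symmetric
        scores += weight * np.linalg.norm(param_change, axis=1)
    return scores

# Labeled data covers the first feature direction three times, the second only once
X_labeled = np.array([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
X_unlabeled = np.array([[1.0, 0.0], [0.0, 1.0]])
theta = np.zeros(2)

print(influence_scores(theta, X_labeled, X_unlabeled))
```

<p>Both unlabeled points receive the same prediction and have unit length, yet the second scores roughly three times higher because the labeled pool constrains that direction less (smaller curvature, hence a larger inverse Hessian). Inverting <code>H</code> explicitly is only practical for small models; larger models usually rely on approximate inverse-Hessian-vector products instead.</p>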
</section>
<section id="query-by-committee" class="level2">
<h2 class="anchored" data-anchor-id="query-by-committee" id="query-by-committee">Query-by-Committee</h2>
<p>Query-by-Committee (QBC) methods train multiple models (a committee) and select the samples on which the committee members disagree most:</p>
<div id="21fceba9" class="cell" data-execution_count="10">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb10"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> query_by_committee(committee_models, unlabeled_pool, k):</span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Get predictions from all committee members</span></span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a>    all_predictions <span class="op">=</span> []</span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> model <span class="kw">in</span> committee_models:</span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a>        preds <span class="op">=</span> model.predict(unlabeled_pool)</span>
<span id="cb10-6"><a href="#cb10-6" aria-hidden="true" tabindex="-1"></a>        all_predictions.append(preds)</span>
<span id="cb10-7"><a href="#cb10-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-8"><a href="#cb10-8" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Stack predictions into a 2D array (committee members, data points)</span></span>
<span id="cb10-9"><a href="#cb10-9" aria-hidden="true" tabindex="-1"></a>    all_predictions <span class="op">=</span> np.stack(all_predictions)</span>
<span id="cb10-10"><a href="#cb10-10" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-11"><a href="#cb10-11" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Calculate disagreement (e.g., using vote entropy)</span></span>
<span id="cb10-12"><a href="#cb10-12" aria-hidden="true" tabindex="-1"></a>    disagreements <span class="op">=</span> []</span>
<span id="cb10-13"><a href="#cb10-13" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(unlabeled_pool.shape[<span class="dv">0</span>]):</span>
<span id="cb10-14"><a href="#cb10-14" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Count votes for each class (np.unique also handles string labels, unlike np.bincount)</span></span>
<span id="cb10-15"><a href="#cb10-15" aria-hidden="true" tabindex="-1"></a>        _, votes <span class="op">=</span> np.unique(all_predictions[:, i], return_counts<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb10-16"><a href="#cb10-16" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Normalize to get probabilities</span></span>
<span id="cb10-17"><a href="#cb10-17" aria-hidden="true" tabindex="-1"></a>        vote_probs <span class="op">=</span> votes <span class="op">/</span> <span class="bu">len</span>(committee_models)</span>
<span id="cb10-18"><a href="#cb10-18" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Calculate entropy</span></span>
<span id="cb10-19"><a href="#cb10-19" aria-hidden="true" tabindex="-1"></a>        entropy <span class="op">=</span> <span class="op">-</span>np.<span class="bu">sum</span>(vote_probs <span class="op">*</span> np.log2(vote_probs <span class="op">+</span> <span class="fl">1e-10</span>))</span>
<span id="cb10-20"><a href="#cb10-20" aria-hidden="true" tabindex="-1"></a>        disagreements.append(entropy)</span>
<span id="cb10-21"><a href="#cb10-21" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-22"><a href="#cb10-22" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Select the k samples with the highest disagreement</span></span>
<span id="cb10-23"><a href="#cb10-23" aria-hidden="true" tabindex="-1"></a>    highest_disagreement_indices <span class="op">=</span> np.argsort(disagreements)[::<span class="op">-</span><span class="dv">1</span>][:k]</span>
<span id="cb10-24"><a href="#cb10-24" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-25"><a href="#cb10-25" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> unlabeled_pool[highest_disagreement_indices]</span></code></pre></div></div>
</div>
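<p>The function above takes a trained committee as given. One common way to build it is query-by-bagging: train copies of a single base model on bootstrap resamples of the labeled data. This is sketched here as an assumption, not a fixed recipe. The helper name <code>build_committee</code> is hypothetical, and the sketch assumes NumPy arrays and scikit-learn estimators (copied with <code>sklearn.base.clone</code>).</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import numpy as np
from sklearn.base import clone
from sklearn.tree import DecisionTreeClassifier


def build_committee(base_model, X_labeled, y_labeled, n_members=5, seed=0):
    """Train n_members copies of base_model, each on a bootstrap resample."""
    rng = np.random.default_rng(seed)
    n = len(X_labeled)
    committee = []
    for _ in range(n_members):
        idx = rng.integers(0, n, size=n)  # n indices drawn with replacement
        member = clone(base_model)  # unfitted copy with the same hyperparameters
        member.fit(X_labeled[idx], y_labeled[idx])
        committee.append(member)
    return committee


# Usage on synthetic data
rng = np.random.default_rng(1)
X_labeled = rng.normal(size=(40, 3))
y_labeled = np.greater(X_labeled[:, 0] + X_labeled[:, 1], 0).astype(int)
committee = build_committee(DecisionTreeClassifier(max_depth=3), X_labeled, y_labeled)

X_pool = rng.normal(size=(10, 3))
votes = np.stack([member.predict(X_pool) for member in committee])  # (n_members, n_pool)
print(votes.shape)
</code></pre></div>
<p>The resulting list can be passed directly as <code>committee_models</code>. Bootstrap resampling tends to make the members differ most where the labeled data is sparse or ambiguous, which is exactly where their disagreement is informative.</p>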
</section>
<section id="implementation-considerations" class="level2">
<h2 class="anchored" data-anchor-id="implementation-considerations" id="implementation-considerations">Implementation Considerations</h2>
<section id="batch-mode-active-learning" class="level3">
<h3 class="anchored" data-anchor-id="batch-mode-active-learning" id="batch-mode-active-learning">Batch Mode Active Learning</h3>
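<p>In practice, retraining after every single label is expensive, so it’s usually more efficient to select a batch of samples at once. However, simply taking the top-k samples by uncertainty can produce a redundant batch, because near-identical samples tend to receive near-identical scores. Consider using:</p>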
<ol type="1">
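<li><strong>Greedy Selection with Diversity</strong>: Select one sample at a time. Score each candidate by a weighted combination of its uncertainty and its distance to the samples already selected, so that near-duplicates of earlier picks are penalized.</li>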
</ol>
<div id="b9721a71" class="cell" data-execution_count="11">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb11"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span class="im">import</span> numpy <span class="im">as</span> np
<span class="im">from</span> sklearn.metrics <span class="im">import</span> pairwise_distances

<span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> batch_selection_with_diversity(model, unlabeled_pool, k, lambda_diversity<span class="op">=</span><span class="fl">0.5</span>):</span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a>    selected_indices <span class="op">=</span> []</span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a>    remaining_indices <span class="op">=</span> <span class="bu">list</span>(<span class="bu">range</span>(<span class="bu">len</span>(unlabeled_pool)))</span>
<span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Calculate uncertainty scores for all samples</span></span>
<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a>    probabilities <span class="op">=</span> model.predict_proba(unlabeled_pool)</span>
<span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a>    entropies <span class="op">=</span> <span class="op">-</span>np.<span class="bu">sum</span>(probabilities <span class="op">*</span> np.log(probabilities <span class="op">+</span> <span class="fl">1e-10</span>), axis<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-9"><a href="#cb11-9" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Calculate distance matrix for diversity</span></span>
<span id="cb11-10"><a href="#cb11-10" aria-hidden="true" tabindex="-1"></a>    distance_matrix <span class="op">=</span> pairwise_distances(unlabeled_pool)</span>
<span id="cb11-11"><a href="#cb11-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-12"><a href="#cb11-12" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(k):</span>
<span id="cb11-13"><a href="#cb11-13" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="kw">not</span> remaining_indices:</span>
<span id="cb11-14"><a href="#cb11-14" aria-hidden="true" tabindex="-1"></a>            <span class="cf">break</span></span>
<span id="cb11-15"><a href="#cb11-15" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-16"><a href="#cb11-16" aria-hidden="true" tabindex="-1"></a>        scores <span class="op">=</span> np.zeros(<span class="bu">len</span>(remaining_indices))</span>
<span id="cb11-17"><a href="#cb11-17" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-18"><a href="#cb11-18" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Calculate uncertainty scores</span></span>
<span id="cb11-19"><a href="#cb11-19" aria-hidden="true" tabindex="-1"></a>        uncertainty_scores <span class="op">=</span> entropies[remaining_indices]</span>
<span id="cb11-20"><a href="#cb11-20" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-21"><a href="#cb11-21" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Calculate diversity scores (if we have already selected some samples)</span></span>
<span id="cb11-22"><a href="#cb11-22" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> selected_indices:</span>
<span id="cb11-23"><a href="#cb11-23" aria-hidden="true" tabindex="-1"></a>            <span class="co"># For each remaining sample, calculate the minimum distance to any selected sample</span></span>
<span id="cb11-24"><a href="#cb11-24" aria-hidden="true" tabindex="-1"></a>            diversity_scores <span class="op">=</span> np.<span class="bu">min</span>(distance_matrix[np.ix_(remaining_indices, selected_indices)], axis<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb11-25"><a href="#cb11-25" aria-hidden="true" tabindex="-1"></a>        <span class="cf">else</span>:</span>
<span id="cb11-26"><a href="#cb11-26" aria-hidden="true" tabindex="-1"></a>            diversity_scores <span class="op">=</span> np.zeros(<span class="bu">len</span>(remaining_indices))</span>
<span id="cb11-27"><a href="#cb11-27" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-28"><a href="#cb11-28" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Normalize scores</span></span>
<span id="cb11-29"><a href="#cb11-29" aria-hidden="true" tabindex="-1"></a>        uncertainty_scores <span class="op">=</span> (uncertainty_scores <span class="op">-</span> np.<span class="bu">min</span>(uncertainty_scores)) <span class="op">/</span> (np.<span class="bu">max</span>(uncertainty_scores) <span class="op">-</span> np.<span class="bu">min</span>(uncertainty_scores) <span class="op">+</span> <span class="fl">1e-10</span>)</span>
<span id="cb11-30"><a href="#cb11-30" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> selected_indices:</span>
<span id="cb11-31"><a href="#cb11-31" aria-hidden="true" tabindex="-1"></a>            diversity_scores <span class="op">=</span> (diversity_scores <span class="op">-</span> np.<span class="bu">min</span>(diversity_scores)) <span class="op">/</span> (np.<span class="bu">max</span>(diversity_scores) <span class="op">-</span> np.<span class="bu">min</span>(diversity_scores) <span class="op">+</span> <span class="fl">1e-10</span>)</span>
<span id="cb11-32"><a href="#cb11-32" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-33"><a href="#cb11-33" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Combine scores</span></span>
<span id="cb11-34"><a href="#cb11-34" aria-hidden="true" tabindex="-1"></a>        scores <span class="op">=</span> (<span class="dv">1</span> <span class="op">-</span> lambda_diversity) <span class="op">*</span> uncertainty_scores <span class="op">+</span> lambda_diversity <span class="op">*</span> diversity_scores</span>
<span id="cb11-35"><a href="#cb11-35" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-36"><a href="#cb11-36" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Select the sample with the highest score</span></span>
<span id="cb11-37"><a href="#cb11-37" aria-hidden="true" tabindex="-1"></a>        best_idx <span class="op">=</span> np.argmax(scores)</span>
<span id="cb11-38"><a href="#cb11-38" aria-hidden="true" tabindex="-1"></a>        selected_idx <span class="op">=</span> remaining_indices[best_idx]</span>
<span id="cb11-39"><a href="#cb11-39" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb11-40"><a href="#cb11-40" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Add to selected and remove from remaining</span></span>
<span id="cb11-41"><a href="#cb11-41" aria-hidden="true" tabindex="-1"></a>        selected_indices.append(selected_idx)</span>
<span id="cb11-42"><a href="#cb11-42" aria-hidden="true" tabindex="-1"></a>        remaining_indices.remove(selected_idx)</span>
<span id="cb11-43"><a href="#cb11-43" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb11-44"><a href="#cb11-44" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> unlabeled_pool[selected_indices]</span></code></pre></div></div>
</div>
<ol start="2" type="1">
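<li><strong>Submodular Function Maximization</strong>: Define a monotone submodular objective that rewards coverage of the pool (for example, facility location) and maximize it greedily. For such objectives under a cardinality constraint, the greedy algorithm is guaranteed to achieve at least a (1 - 1/e) fraction of the optimal value.</li>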
</ol>
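<p>A minimal sketch of the submodular option follows, assuming an RBF-kernel similarity from scikit-learn's <code>rbf_kernel</code> and the weighted facility-location objective f(S) = Σ<sub>i</sub> w<sub>i</sub> max<sub>j∈S</sub> sim(i, j). The helper name <code>facility_location_greedy</code> is hypothetical. Leaving <code>weights</code> unset gives pure coverage. Passing per-sample uncertainty, such as the entropies computed in <code>batch_selection_with_diversity</code>, biases coverage toward informative regions while keeping the objective monotone submodular. The full n×n similarity matrix makes this practical only for a candidate subset of the pool.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import numpy as np
from sklearn.metrics.pairwise import rbf_kernel


def facility_location_greedy(X_pool, k, weights=None, gamma=None):
    """Greedily maximize f(S) = sum_i w_i * max_{j in S} sim(i, j).

    sim is an RBF kernel (non-negative), so f is monotone submodular for w_i >= 0.
    Returns positions into X_pool in the order they were selected.
    """
    sim = rbf_kernel(X_pool, gamma=gamma)  # (n, n) pairwise similarities in [0, 1]
    n = sim.shape[0]
    weights = np.ones(n) if weights is None else np.asarray(weights, dtype=float)
    coverage = np.zeros(n)  # best similarity of each point to the selected set so far
    selected = []
    for _ in range(min(k, n)):
        # Marginal gain of each candidate j: sum_i w_i * max(sim[i, j] - coverage[i], 0)
        gains = (weights[:, None] * np.maximum(sim - coverage[:, None], 0.0)).sum(axis=0)
        if selected:
            gains[selected] = -np.inf  # never pick the same point twice
        best = int(np.argmax(gains))
        selected.append(best)
        coverage = np.maximum(coverage, sim[:, best])
    return np.array(selected)


# Usage: two tight clusters; a batch of two should take one point from each
rng = np.random.default_rng(0)
X_pool = np.vstack([rng.normal(0, 0.1, size=(20, 2)), rng.normal(5, 0.1, size=(20, 2))])
picked = facility_location_greedy(X_pool, k=2, gamma=1.0)
print(picked)
</code></pre></div>
<p>After the first pick covers one cluster, a second point from that cluster adds almost nothing to f. The greedy step therefore moves to the uncovered cluster, which is the redundancy-avoiding behavior that plain top-k selection lacks.</p>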
</section>
<section id="handling-imbalanced-data" class="level3">
<h3 class="anchored" data-anchor-id="handling-imbalanced-data" id="handling-imbalanced-data">Handling Imbalanced Data</h3>
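<p>Active learning can inadvertently reinforce class imbalance. Because queries are chosen by the current model, the selection can repeatedly favor some classes while rare classes stay under-sampled. Consider:</p>
<ol type="1">
<li><strong>Stratified Sampling</strong>: Ensure representation from all classes. Since true labels are unknown before annotation, stratify by the model's predicted class.</li>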
<li><strong>Hybrid Approaches</strong>: Combine uncertainty-based and density-based methods.</li>
<li><strong>Diversity Constraints</strong>: Explicitly enforce diversity in feature space.</li>
</ol>
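<p>One way to realize stratified sampling is sketched below. It spreads the batch across <em>predicted</em> classes in round-robin order, taking the most uncertain remaining sample of each class in turn. This is an assumption: true labels are unavailable at query time, so predicted classes are used as strata. The helper name <code>class_balanced_uncertainty_selection</code> is hypothetical, and the helper expects a <code>predict_proba</code>-style probability matrix.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import numpy as np


def class_balanced_uncertainty_selection(probs, k):
    """Select k pool positions, cycling over predicted classes.

    probs: (n_samples, n_classes) predicted probabilities for the unlabeled pool.
    Within each predicted class, samples are taken in order of decreasing entropy.
    """
    entropy = -np.sum(probs * np.log(probs + 1e-10), axis=1)
    predicted = probs.argmax(axis=1)

    # One queue per predicted class, most uncertain sample first
    queues = []
    for c in range(probs.shape[1]):
        idx = np.flatnonzero(predicted == c)
        queues.append(list(idx[np.argsort(-entropy[idx])]))

    # Round-robin over classes so no single class consumes the whole budget
    k = min(k, len(probs))
    selected = []
    while len(selected) != k:
        for queue in queues:
            if queue and len(selected) != k:
                selected.append(int(queue.pop(0)))
    return np.array(selected)


# Usage: six samples predicted as class 0, two as class 1
probs = np.array([
    [0.55, 0.45], [0.60, 0.40], [0.65, 0.35], [0.70, 0.30], [0.75, 0.25], [0.80, 0.20],
    [0.10, 0.90], [0.30, 0.70],
])
print(class_balanced_uncertainty_selection(probs, k=4))  # [0 7 1 6]
</code></pre></div>
<p>Plain top-4 entropy on the same matrix would take at least three of its four picks from class 0. The round-robin version guarantees both predicted classes are represented.</p>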
</section>
<section id="computational-efficiency" class="level3">
<h3 class="anchored" data-anchor-id="computational-efficiency" id="computational-efficiency">Computational Efficiency</h3>
<p>Some methods can be computationally expensive. Expected error reduction, for example, must retrain or update the model for every candidate sample and every possible label. Consider:</p>
<ol type="1">
<li><strong>Subsample the Unlabeled Pool</strong>: Only consider a random subset for selection.</li>
<li><strong>Pre-compute Embeddings</strong>: Use a fixed feature extractor to pre-compute embeddings.</li>
<li><strong>Approximate Methods</strong>: Use approximations for expensive operations.</li>
</ol>
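<p>The first option can be sketched as a thin wrapper around any scoring function, assuming the score is "higher is more useful". The scoring cost then grows with <code>subsample_size</code> instead of the full pool size. The helper name <code>select_from_subsample</code> and its default values are illustrative only.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import numpy as np


def select_from_subsample(score_fn, X_unlabeled, k, subsample_size=1000, seed=0):
    """Score only a random subset of the unlabeled pool, then keep the top k.

    score_fn maps an array of samples to one score per sample (higher = more useful).
    Returns positions into X_unlabeled.
    """
    rng = np.random.default_rng(seed)
    n = len(X_unlabeled)
    candidates = rng.choice(n, size=min(subsample_size, n), replace=False)
    scores = np.asarray(score_fn(X_unlabeled[candidates]))
    top = np.argsort(scores)[::-1][:k]
    return candidates[top]


# Usage with a toy score: the sample's first feature value
X_unlabeled = np.arange(100, dtype=float).reshape(-1, 1)
print(select_from_subsample(lambda X: X[:, 0], X_unlabeled, k=3, subsample_size=20))
</code></pre></div>
<p>In a real loop, <code>score_fn</code> would wrap an expensive criterion, for example one that calls the model's <code>predict_proba</code> and returns entropies. Drawing a fresh subsample each round (a new <code>seed</code>) keeps every unlabeled sample eligible over time.</p>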
</section>
</section>
<section id="evaluation-metrics" class="level2">
<h2 class="anchored" data-anchor-id="evaluation-metrics" id="evaluation-metrics">Evaluation Metrics</h2>
<section id="learning-curves" class="level3">
<h3 class="anchored" data-anchor-id="learning-curves" id="learning-curves">Learning Curves</h3>
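<p>Plot model performance against the number of labeled samples. The function below retrains a fresh model each round, records its test score, and then uses the given strategy to pick the next batch:</p>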
<div id="3c5f974e" class="cell" data-execution_count="12">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb12"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><a href="#cb12-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> plot_learning_curve(model_factory, X_train, y_train, X_test, y_test, </span>
<span id="cb12-2"><a href="#cb12-2" aria-hidden="true" tabindex="-1"></a>                        active_learning_strategy, initial_size<span class="op">=</span><span class="dv">10</span>, </span>
<span id="cb12-3"><a href="#cb12-3" aria-hidden="true" tabindex="-1"></a>                        batch_size<span class="op">=</span><span class="dv">10</span>, n_iterations<span class="op">=</span><span class="dv">20</span>):</span>
<span id="cb12-4"><a href="#cb12-4" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Initialize with a small labeled set</span></span>
<span id="cb12-5"><a href="#cb12-5" aria-hidden="true" tabindex="-1"></a>    labeled_indices <span class="op">=</span> np.random.choice(<span class="bu">len</span>(X_train), initial_size, replace<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb12-6"><a href="#cb12-6" aria-hidden="true" tabindex="-1"></a>    unlabeled_indices <span class="op">=</span> np.setdiff1d(np.arange(<span class="bu">len</span>(X_train)), labeled_indices)</span>
<span id="cb12-7"><a href="#cb12-7" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-8"><a href="#cb12-8" aria-hidden="true" tabindex="-1"></a>    performance <span class="op">=</span> []</span>
<span id="cb12-9"><a href="#cb12-9" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-10"><a href="#cb12-10" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(n_iterations):</span>
<span id="cb12-11"><a href="#cb12-11" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Create a fresh model</span></span>
<span id="cb12-12"><a href="#cb12-12" aria-hidden="true" tabindex="-1"></a>        model <span class="op">=</span> model_factory()</span>
<span id="cb12-13"><a href="#cb12-13" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-14"><a href="#cb12-14" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Train on the currently labeled data</span></span>
<span id="cb12-15"><a href="#cb12-15" aria-hidden="true" tabindex="-1"></a>        model.fit(X_train[labeled_indices], y_train[labeled_indices])</span>
<span id="cb12-16"><a href="#cb12-16" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-17"><a href="#cb12-17" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Evaluate on the test set</span></span>
<span id="cb12-18"><a href="#cb12-18" aria-hidden="true" tabindex="-1"></a>        score <span class="op">=</span> model.score(X_test, y_test)</span>
<span id="cb12-19"><a href="#cb12-19" aria-hidden="true" tabindex="-1"></a>        performance.append((<span class="bu">len</span>(labeled_indices), score))</span>
<span id="cb12-20"><a href="#cb12-20" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb12-21"><a href="#cb12-21" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Select the next batch of samples</span></span>
<span id="cb12-22"><a href="#cb12-22" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> <span class="bu">len</span>(unlabeled_indices) <span class="op">&gt;</span> <span class="dv">0</span>:</span>
<span id="cb12-23"><a href="#cb12-23" aria-hidden="true" tabindex="-1"></a>            <span class="co"># The strategy must return positions within X_train[unlabeled_indices] (indices, not samples)</span></span>
<span id="cb12-24"><a href="#cb12-24" aria-hidden="true" tabindex="-1"></a>            selected_indices <span class="op">=</span> active_learning_strategy(</span>
<span id="cb12-25"><a href="#cb12-25" aria-hidden="true" tabindex="-1"></a>                model, X_train[unlabeled_indices], batch_size</span>
<span id="cb12-26"><a href="#cb12-26" aria-hidden="true" tabindex="-1"></a>            )</span>
<span id="cb12-27"><a href="#cb12-27" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb12-28"><a href="#cb12-28" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Map back to original indices</span></span>
<span id="cb12-29"><a href="#cb12-29" aria-hidden="true" tabindex="-1"></a>            selected_original_indices <span class="op">=</span> unlabeled_indices[selected_indices]</span>
<span id="cb12-30"><a href="#cb12-30" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb12-31"><a href="#cb12-31" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Update labeled and unlabeled indices</span></span>
<span id="cb12-32"><a href="#cb12-32" aria-hidden="true" tabindex="-1"></a>            labeled_indices <span class="op">=</span> np.append(labeled_indices, selected_original_indices)</span>
<span id="cb12-33"><a href="#cb12-33" aria-hidden="true" tabindex="-1"></a>            unlabeled_indices <span class="op">=</span> np.setdiff1d(unlabeled_indices, selected_original_indices)</span>
<span id="cb12-34"><a href="#cb12-34" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-35"><a href="#cb12-35" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Plot the learning curve</span></span>
<span id="cb12-36"><a href="#cb12-36" aria-hidden="true" tabindex="-1"></a>    counts, scores <span class="op">=</span> <span class="bu">zip</span>(<span class="op">*</span>performance)</span>
<span id="cb12-37"><a href="#cb12-37" aria-hidden="true" tabindex="-1"></a>    plt.figure(figsize<span class="op">=</span>(<span class="dv">10</span>, <span class="dv">6</span>))</span>
<span id="cb12-38"><a href="#cb12-38" aria-hidden="true" tabindex="-1"></a>    plt.plot(counts, scores, <span class="st">'o-'</span>)</span>
<span id="cb12-39"><a href="#cb12-39" aria-hidden="true" tabindex="-1"></a>    plt.xlabel(<span class="st">'Number of labeled samples'</span>)</span>
<span id="cb12-40"><a href="#cb12-40" aria-hidden="true" tabindex="-1"></a>    plt.ylabel(<span class="st">'Model accuracy'</span>)</span>
<span id="cb12-41"><a href="#cb12-41" aria-hidden="true" tabindex="-1"></a>    plt.title(<span class="st">'Active Learning Performance'</span>)</span>
<span id="cb12-42"><a href="#cb12-42" aria-hidden="true" tabindex="-1"></a>    plt.grid(<span class="va">True</span>)</span>
<span id="cb12-43"><a href="#cb12-43" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb12-44"><a href="#cb12-44" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> performance</span></code></pre></div></div>
</div>
</section>
<section id="comparison-with-random-sampling" class="level3">
<h3 class="anchored" data-anchor-id="comparison-with-random-sampling" id="comparison-with-random-sampling">Comparison with Random Sampling</h3>
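<p>Always compare your active learning strategy with random sampling as a baseline, using the same model, initial labeled set, and batch size. A strategy is only worth its extra cost if its learning curve lies above the random-sampling curve.</p>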
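<p>A minimal comparison is sketched below, assuming scikit-learn's bundled <code>load_digits</code> dataset as a small offline stand-in for MNIST and <code>LogisticRegression</code> as the model. Both strategies (hypothetical names) return positions into the unlabeled subset, and their <code>rng</code> argument is optional, so they can also be passed to <code>plot_learning_curve</code> as-is. Fixing the seed gives both strategies the identical initial labeled set, so any gap between the curves comes from the query strategy alone.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import numpy as np
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split


def random_strategy(model, X_unlabeled, k, rng=None):
    """Baseline: ignore the model and pick k unlabeled positions uniformly at random."""
    if rng is None:
        rng = np.random.default_rng()
    return rng.choice(len(X_unlabeled), size=min(k, len(X_unlabeled)), replace=False)


def entropy_strategy(model, X_unlabeled, k, rng=None):
    """Pick the k unlabeled positions with the highest predictive entropy (rng unused)."""
    probs = model.predict_proba(X_unlabeled)
    entropy = -np.sum(probs * np.log(probs + 1e-10), axis=1)
    return np.argsort(entropy)[::-1][:k]


def run(strategy, X_train, y_train, X_test, y_test,
        initial_size=20, batch_size=20, n_iterations=8, seed=0):
    """Active learning loop; the same seed gives every strategy the same initial labeled set."""
    rng = np.random.default_rng(seed)
    labeled = rng.choice(len(X_train), size=initial_size, replace=False)
    unlabeled = np.setdiff1d(np.arange(len(X_train)), labeled)
    curve = []
    for _ in range(n_iterations):
        model = LogisticRegression(max_iter=2000).fit(X_train[labeled], y_train[labeled])
        curve.append((len(labeled), model.score(X_test, y_test)))
        picked = unlabeled[strategy(model, X_train[unlabeled], batch_size, rng)]
        labeled = np.append(labeled, picked)
        unlabeled = np.setdiff1d(unlabeled, picked)
    return curve


X, y = load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)

curves = {}
for name, strategy in [("random", random_strategy), ("entropy", entropy_strategy)]:
    curves[name] = run(strategy, X_train, y_train, X_test, y_test)
    print(name, [round(acc, 3) for _, acc in curves[name]])
</code></pre></div>
<p>A single seed is a noisy comparison. Averaging the curves over several seeds gives a more trustworthy verdict on whether the active strategy beats the baseline.</p>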
</section>
<section id="annotation-efficiency" class="level3">
<h3 class="anchored" data-anchor-id="annotation-efficiency" id="annotation-efficiency">Annotation Efficiency</h3>
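<p>Quantify the annotation savings. For a target accuracy, compare the number of labels your strategy needed to reach it with the number needed by random sampling, or with the size of the full training set.</p>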
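<p>A small sketch of this calculation follows. It assumes learning curves stored as <code>(n_labeled, accuracy)</code> pairs sorted by <code>n_labeled</code>, the format returned by <code>plot_learning_curve</code>. The helper names are hypothetical. The reference curve would normally come from a random-sampling run with the same setup.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">def labels_to_reach(curve, target):
    """Smallest labeled-set size whose accuracy reaches target, or None if never reached."""
    for n_labeled, accuracy in curve:
        if accuracy >= target:
            return n_labeled
    return None


def annotation_savings(active_curve, reference_curve, target):
    """Fraction of labels saved relative to the reference curve at the same target accuracy."""
    n_active = labels_to_reach(active_curve, target)
    n_reference = labels_to_reach(reference_curve, target)
    if n_active is None or n_reference is None:
        return None
    return 1.0 - n_active / n_reference


# Accuracies printed by the MNIST entropy-sampling example in this post
active_curve = [(100, 0.6221), (200, 0.6894), (300, 0.7181), (400, 0.7633), (500, 0.7948),
                (600, 0.8139), (700, 0.8338), (800, 0.8404), (900, 0.8566), (1000, 0.8621)]
print(labels_to_reach(active_curve, 0.85))  # 900
</code></pre></div>
<p>Reporting savings at a fixed target accuracy, rather than accuracy at a fixed budget, answers the practical question of how much labeling effort the strategy removes.</p>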
</section>
</section>
<section id="practical-examples" class="level2">
<h2 class="anchored" data-anchor-id="practical-examples" id="practical-examples">Practical Examples</h2>
<section id="image-classification-with-uncertainty-sampling" class="level3">
<h3 class="anchored" data-anchor-id="image-classification-with-uncertainty-sampling" id="image-classification-with-uncertainty-sampling">Image Classification with Uncertainty Sampling</h3>
<div id="016b5e79" class="cell" data-execution_count="13">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb13"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> sklearn.datasets <span class="im">import</span> fetch_openml</span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> sklearn.ensemble <span class="im">import</span> RandomForestClassifier</span>
<span id="cb13-4"><a href="#cb13-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> sklearn.metrics <span class="im">import</span> accuracy_score</span>
<span id="cb13-5"><a href="#cb13-5" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> sklearn.model_selection <span class="im">import</span> train_test_split</span>
<span id="cb13-6"><a href="#cb13-6" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> matplotlib.pyplot <span class="im">as</span> plt</span>
<span id="cb13-7"><a href="#cb13-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-8"><a href="#cb13-8" aria-hidden="true" tabindex="-1"></a><span class="co"># Load data</span></span>
<span id="cb13-9"><a href="#cb13-9" aria-hidden="true" tabindex="-1"></a>mnist <span class="op">=</span> fetch_openml(<span class="st">'mnist_784'</span>, version<span class="op">=</span><span class="dv">1</span>, cache<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb13-10"><a href="#cb13-10" aria-hidden="true" tabindex="-1"></a>X, y <span class="op">=</span> mnist[<span class="st">'data'</span>], mnist[<span class="st">'target'</span>]</span>
<span id="cb13-11"><a href="#cb13-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-12"><a href="#cb13-12" aria-hidden="true" tabindex="-1"></a><span class="co"># Split into train and test</span></span>
<span id="cb13-13"><a href="#cb13-13" aria-hidden="true" tabindex="-1"></a>X_train, X_test, y_train, y_test <span class="op">=</span> train_test_split(</span>
<span id="cb13-14"><a href="#cb13-14" aria-hidden="true" tabindex="-1"></a>    X, y, test_size<span class="op">=</span><span class="fl">0.2</span>, random_state<span class="op">=</span><span class="dv">42</span></span>
<span id="cb13-15"><a href="#cb13-15" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb13-16"><a href="#cb13-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-17"><a href="#cb13-17" aria-hidden="true" tabindex="-1"></a><span class="co"># Initially, only a small portion is labeled</span></span>
<span id="cb13-18"><a href="#cb13-18" aria-hidden="true" tabindex="-1"></a>initial_size <span class="op">=</span> <span class="dv">100</span></span>
<span id="cb13-19"><a href="#cb13-19" aria-hidden="true" tabindex="-1"></a>labeled_indices <span class="op">=</span> np.random.choice(<span class="bu">len</span>(X_train), initial_size, replace<span class="op">=</span><span class="va">False</span>)</span>
<span id="cb13-20"><a href="#cb13-20" aria-hidden="true" tabindex="-1"></a>unlabeled_indices <span class="op">=</span> np.setdiff1d(np.arange(<span class="bu">len</span>(X_train)), labeled_indices)</span>
<span id="cb13-21"><a href="#cb13-21" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-22"><a href="#cb13-22" aria-hidden="true" tabindex="-1"></a><span class="co"># Tracking performance</span></span>
<span id="cb13-23"><a href="#cb13-23" aria-hidden="true" tabindex="-1"></a>active_learning_performance <span class="op">=</span> []</span>
<span id="cb13-24"><a href="#cb13-24" aria-hidden="true" tabindex="-1"></a>random_sampling_performance <span class="op">=</span> []</span>
<span id="cb13-25"><a href="#cb13-25" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-26"><a href="#cb13-26" aria-hidden="true" tabindex="-1"></a><span class="co"># Active learning loop</span></span>
<span id="cb13-27"><a href="#cb13-27" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">10</span>):  <span class="co"># 10 iterations</span></span>
<span id="cb13-28"><a href="#cb13-28" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Train a model on the currently labeled data</span></span>
<span id="cb13-29"><a href="#cb13-29" aria-hidden="true" tabindex="-1"></a>    model <span class="op">=</span> RandomForestClassifier(n_estimators<span class="op">=</span><span class="dv">50</span>, random_state<span class="op">=</span><span class="dv">42</span>)</span>
<span id="cb13-30"><a href="#cb13-30" aria-hidden="true" tabindex="-1"></a>    model.fit(X_train.iloc[labeled_indices], y_train.iloc[labeled_indices])</span>
<span id="cb13-31"><a href="#cb13-31" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-32"><a href="#cb13-32" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Evaluate on the test set</span></span>
<span id="cb13-33"><a href="#cb13-33" aria-hidden="true" tabindex="-1"></a>    y_pred <span class="op">=</span> model.predict(X_test)</span>
<span id="cb13-34"><a href="#cb13-34" aria-hidden="true" tabindex="-1"></a>    accuracy <span class="op">=</span> accuracy_score(y_test, y_pred)</span>
<span id="cb13-35"><a href="#cb13-35" aria-hidden="true" tabindex="-1"></a>    active_learning_performance.append((<span class="bu">len</span>(labeled_indices), accuracy))</span>
<span id="cb13-36"><a href="#cb13-36" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-37"><a href="#cb13-37" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Iteration </span><span class="sc">{</span>i<span class="op">+</span><span class="dv">1</span><span class="sc">}</span><span class="ss">: </span><span class="sc">{</span><span class="bu">len</span>(labeled_indices)<span class="sc">}</span><span class="ss"> labeled samples, "</span></span>
<span id="cb13-38"><a href="#cb13-38" aria-hidden="true" tabindex="-1"></a>          <span class="ss">f"accuracy: </span><span class="sc">{</span>accuracy<span class="sc">:.4f}</span><span class="ss">"</span>)</span>
<span id="cb13-39"><a href="#cb13-39" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb13-40"><a href="#cb13-40" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Select 100 new samples using entropy sampling</span></span>
<span id="cb13-41"><a href="#cb13-41" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> <span class="bu">len</span>(unlabeled_indices) <span class="op">&gt;</span> <span class="dv">0</span>:</span>
<span id="cb13-42"><a href="#cb13-42" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Predict probabilities for each unlabeled sample</span></span>
<span id="cb13-43"><a href="#cb13-43" aria-hidden="true" tabindex="-1"></a>        probs <span class="op">=</span> model.predict_proba(X_train.iloc[unlabeled_indices])</span>
<span id="cb13-44"><a href="#cb13-44" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-45"><a href="#cb13-45" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Calculate entropy</span></span>
<span id="cb13-46"><a href="#cb13-46" aria-hidden="true" tabindex="-1"></a>        entropies <span class="op">=</span> <span class="op">-</span>np.<span class="bu">sum</span>(probs <span class="op">*</span> np.log(probs <span class="op">+</span> <span class="fl">1e-10</span>), axis<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb13-47"><a href="#cb13-47" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-48"><a href="#cb13-48" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Select samples with the highest entropy</span></span>
<span id="cb13-49"><a href="#cb13-49" aria-hidden="true" tabindex="-1"></a>        top_indices <span class="op">=</span> np.argsort(entropies)[::<span class="op">-</span><span class="dv">1</span>][:<span class="dv">100</span>]</span>
<span id="cb13-50"><a href="#cb13-50" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb13-51"><a href="#cb13-51" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Update labeled and unlabeled indices</span></span>
<span id="cb13-52"><a href="#cb13-52" aria-hidden="true" tabindex="-1"></a>        selected_indices <span class="op">=</span> unlabeled_indices[top_indices]</span>
<span id="cb13-53"><a href="#cb13-53" aria-hidden="true" tabindex="-1"></a>        labeled_indices <span class="op">=</span> np.append(labeled_indices, selected_indices)</span>
<span id="cb13-54"><a href="#cb13-54" aria-hidden="true" tabindex="-1"></a>        unlabeled_indices <span class="op">=</span> np.setdiff1d(unlabeled_indices, selected_indices)</span>
<span id="cb13-55"><a href="#cb13-55" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-56"><a href="#cb13-56" aria-hidden="true" tabindex="-1"></a><span class="co"># Plot learning curve</span></span>
<span id="cb13-57"><a href="#cb13-57" aria-hidden="true" tabindex="-1"></a>counts, scores <span class="op">=</span> <span class="bu">zip</span>(<span class="op">*</span>active_learning_performance)</span>
<span id="cb13-58"><a href="#cb13-58" aria-hidden="true" tabindex="-1"></a>plt.figure(figsize<span class="op">=</span>(<span class="dv">10</span>, <span class="dv">6</span>))</span>
<span id="cb13-59"><a href="#cb13-59" aria-hidden="true" tabindex="-1"></a>plt.plot(counts, scores, <span class="st">'o-'</span>, label<span class="op">=</span><span class="st">'Active Learning'</span>)</span>
<span id="cb13-60"><a href="#cb13-60" aria-hidden="true" tabindex="-1"></a>plt.xlabel(<span class="st">'Number of labeled samples'</span>)</span>
<span id="cb13-61"><a href="#cb13-61" aria-hidden="true" tabindex="-1"></a>plt.ylabel(<span class="st">'Model accuracy'</span>)</span>
<span id="cb13-62"><a href="#cb13-62" aria-hidden="true" tabindex="-1"></a>plt.title(<span class="st">'Active Learning Performance'</span>)</span>
<span id="cb13-63"><a href="#cb13-63" aria-hidden="true" tabindex="-1"></a>plt.grid(<span class="va">True</span>)</span>
<span id="cb13-64"><a href="#cb13-64" aria-hidden="true" tabindex="-1"></a>plt.legend()</span>
<span id="cb13-65"><a href="#cb13-65" aria-hidden="true" tabindex="-1"></a>plt.show()</span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>Iteration 1: 100 labeled samples, accuracy: 0.6221
Iteration 2: 200 labeled samples, accuracy: 0.6894
Iteration 3: 300 labeled samples, accuracy: 0.7181
Iteration 4: 400 labeled samples, accuracy: 0.7633
Iteration 5: 500 labeled samples, accuracy: 0.7948
Iteration 6: 600 labeled samples, accuracy: 0.8139
Iteration 7: 700 labeled samples, accuracy: 0.8338
Iteration 8: 800 labeled samples, accuracy: 0.8404
Iteration 9: 900 labeled samples, accuracy: 0.8566
Iteration 10: 1000 labeled samples, accuracy: 0.8621</code></pre>
</div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://theja-vanka.github.io/blogs/posts/model-training/influence-selection/cell-14-output-2.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
</div>
</section>
<section id="text-classification-with-query-by-committee" class="level3">
<h3 class="anchored" data-anchor-id="text-classification-with-query-by-committee" id="text-classification-with-query-by-committee">Text Classification with Query-by-Committee</h3>
<div id="cb4392fc" class="cell" data-execution_count="14">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb15"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb15-1"><a href="#cb15-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> sklearn.feature_extraction.text <span class="im">import</span> TfidfVectorizer</span>
<span id="cb15-2"><a href="#cb15-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> sklearn.naive_bayes <span class="im">import</span> MultinomialNB</span>
<span id="cb15-3"><a href="#cb15-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> sklearn.svm <span class="im">import</span> SVC</span>
<span id="cb15-4"><a href="#cb15-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> sklearn.ensemble <span class="im">import</span> VotingClassifier</span>
<span id="cb15-5"><a href="#cb15-5" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> sklearn.datasets <span class="im">import</span> fetch_20newsgroups</span>
<span id="cb15-6"><a href="#cb15-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-7"><a href="#cb15-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Load data</span></span>
<span id="cb15-8"><a href="#cb15-8" aria-hidden="true" tabindex="-1"></a>categories <span class="op">=</span> [<span class="st">'alt.atheism'</span>, <span class="st">'soc.religion.christian'</span>, <span class="st">'comp.graphics'</span>, <span class="st">'sci.med'</span>]</span>
<span id="cb15-9"><a href="#cb15-9" aria-hidden="true" tabindex="-1"></a>twenty_train <span class="op">=</span> fetch_20newsgroups(subset<span class="op">=</span><span class="st">'train'</span>, categories<span class="op">=</span>categories, shuffle<span class="op">=</span><span class="va">True</span>, random_state<span class="op">=</span><span class="dv">42</span>)</span>
<span id="cb15-10"><a href="#cb15-10" aria-hidden="true" tabindex="-1"></a>twenty_test <span class="op">=</span> fetch_20newsgroups(subset<span class="op">=</span><span class="st">'test'</span>, categories<span class="op">=</span>categories, shuffle<span class="op">=</span><span class="va">True</span>, random_state<span class="op">=</span><span class="dv">42</span>)</span>
<span id="cb15-11"><a href="#cb15-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-12"><a href="#cb15-12" aria-hidden="true" tabindex="-1"></a><span class="co"># Feature extraction</span></span>
<span id="cb15-13"><a href="#cb15-13" aria-hidden="true" tabindex="-1"></a>vectorizer <span class="op">=</span> TfidfVectorizer(stop_words<span class="op">=</span><span class="st">'english'</span>)</span>
<span id="cb15-14"><a href="#cb15-14" aria-hidden="true" tabindex="-1"></a>X_train <span class="op">=</span> vectorizer.fit_transform(twenty_train.data)</span>
<span id="cb15-15"><a href="#cb15-15" aria-hidden="true" tabindex="-1"></a>X_test <span class="op">=</span> vectorizer.transform(twenty_test.data)</span>
<span id="cb15-16"><a href="#cb15-16" aria-hidden="true" tabindex="-1"></a>y_train <span class="op">=</span> twenty_train.target</span>
<span id="cb15-17"><a href="#cb15-17" aria-hidden="true" tabindex="-1"></a>y_test <span class="op">=</span> twenty_test.target</span>
<span id="cb15-18"><a href="#cb15-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-19"><a href="#cb15-19" aria-hidden="true" tabindex="-1"></a><span class="co"># Initially, only a small portion is labeled</span></span>
<span id="cb15-20"><a href="#cb15-20" aria-hidden="true" tabindex="-1"></a>initial_size <span class="op">=</span> <span class="dv">20</span></span>
<span id="cb15-21"><a href="#cb15-21" aria-hidden="true" tabindex="-1"></a>labeled_indices <span class="op">=</span> np.random.choice(X_train.shape[<span class="dv">0</span>], initial_size, replace<span class="op">=</span><span class="va">False</span>)  <span class="co"># shape[0] avoids densifying the sparse TF-IDF matrix</span></span>
<span id="cb15-22"><a href="#cb15-22" aria-hidden="true" tabindex="-1"></a>unlabeled_indices <span class="op">=</span> np.setdiff1d(np.arange(X_train.shape[<span class="dv">0</span>]), labeled_indices)</span>
<span id="cb15-23"><a href="#cb15-23" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-24"><a href="#cb15-24" aria-hidden="true" tabindex="-1"></a><span class="co"># Create a committee of models</span></span>
<span id="cb15-25"><a href="#cb15-25" aria-hidden="true" tabindex="-1"></a>models <span class="op">=</span> [</span>
<span id="cb15-26"><a href="#cb15-26" aria-hidden="true" tabindex="-1"></a>    (<span class="st">'nb'</span>, MultinomialNB()),</span>
<span id="cb15-27"><a href="#cb15-27" aria-hidden="true" tabindex="-1"></a>    (<span class="st">'svm'</span>, SVC(kernel<span class="op">=</span><span class="st">'linear'</span>, probability<span class="op">=</span><span class="va">True</span>)),</span>
<span id="cb15-28"><a href="#cb15-28" aria-hidden="true" tabindex="-1"></a>    (<span class="st">'svm2'</span>, SVC(kernel<span class="op">=</span><span class="st">'rbf'</span>, probability<span class="op">=</span><span class="va">True</span>))</span>
<span id="cb15-29"><a href="#cb15-29" aria-hidden="true" tabindex="-1"></a>]</span>
<span id="cb15-30"><a href="#cb15-30" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-31"><a href="#cb15-31" aria-hidden="true" tabindex="-1"></a><span class="co"># Active learning loop</span></span>
<span id="cb15-32"><a href="#cb15-32" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">10</span>):  <span class="co"># 10 iterations</span></span>
<span id="cb15-33"><a href="#cb15-33" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Train each model on the currently labeled data</span></span>
<span id="cb15-34"><a href="#cb15-34" aria-hidden="true" tabindex="-1"></a>    committee_models <span class="op">=</span> []</span>
<span id="cb15-35"><a href="#cb15-35" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> name, model <span class="kw">in</span> models:</span>
<span id="cb15-36"><a href="#cb15-36" aria-hidden="true" tabindex="-1"></a>        model.fit(X_train[labeled_indices], y_train[labeled_indices])</span>
<span id="cb15-37"><a href="#cb15-37" aria-hidden="true" tabindex="-1"></a>        committee_models.append(model)</span>
<span id="cb15-38"><a href="#cb15-38" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-39"><a href="#cb15-39" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Evaluate using the VotingClassifier</span></span>
<span id="cb15-40"><a href="#cb15-40" aria-hidden="true" tabindex="-1"></a>    voting_clf <span class="op">=</span> VotingClassifier(estimators<span class="op">=</span>models, voting<span class="op">=</span><span class="st">'soft'</span>)</span>
<span id="cb15-41"><a href="#cb15-41" aria-hidden="true" tabindex="-1"></a>    voting_clf.fit(X_train[labeled_indices], y_train[labeled_indices])</span>
<span id="cb15-42"><a href="#cb15-42" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-43"><a href="#cb15-43" aria-hidden="true" tabindex="-1"></a>    accuracy <span class="op">=</span> voting_clf.score(X_test, y_test)</span>
<span id="cb15-44"><a href="#cb15-44" aria-hidden="true" tabindex="-1"></a>    <span class="bu">print</span>(<span class="ss">f"Iteration </span><span class="sc">{</span>i<span class="op">+</span><span class="dv">1</span><span class="sc">}</span><span class="ss">: </span><span class="sc">{</span><span class="bu">len</span>(labeled_indices)<span class="sc">}</span><span class="ss"> labeled samples, "</span></span>
<span id="cb15-45"><a href="#cb15-45" aria-hidden="true" tabindex="-1"></a>          <span class="ss">f"accuracy: </span><span class="sc">{</span>accuracy<span class="sc">:.4f}</span><span class="ss">"</span>)</span>
<span id="cb15-46"><a href="#cb15-46" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-47"><a href="#cb15-47" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Select 10 new samples using Query-by-Committee</span></span>
<span id="cb15-48"><a href="#cb15-48" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> <span class="bu">len</span>(unlabeled_indices) <span class="op">&gt;</span> <span class="dv">0</span>:</span>
<span id="cb15-49"><a href="#cb15-49" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Get predictions from all committee members</span></span>
<span id="cb15-50"><a href="#cb15-50" aria-hidden="true" tabindex="-1"></a>        all_predictions <span class="op">=</span> []</span>
<span id="cb15-51"><a href="#cb15-51" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> model <span class="kw">in</span> committee_models:</span>
<span id="cb15-52"><a href="#cb15-52" aria-hidden="true" tabindex="-1"></a>            preds <span class="op">=</span> model.predict(X_train[unlabeled_indices])</span>
<span id="cb15-53"><a href="#cb15-53" aria-hidden="true" tabindex="-1"></a>            all_predictions.append(preds)</span>
<span id="cb15-54"><a href="#cb15-54" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-55"><a href="#cb15-55" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Calculate vote entropy</span></span>
<span id="cb15-56"><a href="#cb15-56" aria-hidden="true" tabindex="-1"></a>        vote_entropies <span class="op">=</span> []</span>
<span id="cb15-57"><a href="#cb15-57" aria-hidden="true" tabindex="-1"></a>        all_predictions <span class="op">=</span> np.array(all_predictions)</span>
<span id="cb15-58"><a href="#cb15-58" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> j <span class="kw">in</span> <span class="bu">range</span>(<span class="bu">len</span>(unlabeled_indices)):  <span class="co"># j avoids shadowing the outer loop variable i</span></span>
<span id="cb15-59"><a href="#cb15-59" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Count votes for each class</span></span>
<span id="cb15-60"><a href="#cb15-60" aria-hidden="true" tabindex="-1"></a>            votes <span class="op">=</span> np.bincount(all_predictions[:, j], minlength<span class="op">=</span><span class="bu">len</span>(categories))</span>
<span id="cb15-61"><a href="#cb15-61" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Normalize to get probabilities</span></span>
<span id="cb15-62"><a href="#cb15-62" aria-hidden="true" tabindex="-1"></a>            vote_probs <span class="op">=</span> votes <span class="op">/</span> <span class="bu">len</span>(committee_models)</span>
<span id="cb15-63"><a href="#cb15-63" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Calculate entropy</span></span>
<span id="cb15-64"><a href="#cb15-64" aria-hidden="true" tabindex="-1"></a>            entropy <span class="op">=</span> <span class="op">-</span>np.<span class="bu">sum</span>(vote_probs <span class="op">*</span> np.log2(vote_probs <span class="op">+</span> <span class="fl">1e-10</span>))</span>
<span id="cb15-65"><a href="#cb15-65" aria-hidden="true" tabindex="-1"></a>            vote_entropies.append(entropy)</span>
<span id="cb15-66"><a href="#cb15-66" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-67"><a href="#cb15-67" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Select samples with the highest vote entropy</span></span>
<span id="cb15-68"><a href="#cb15-68" aria-hidden="true" tabindex="-1"></a>        top_indices <span class="op">=</span> np.argsort(vote_entropies)[::<span class="op">-</span><span class="dv">1</span>][:<span class="dv">10</span>]</span>
<span id="cb15-69"><a href="#cb15-69" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb15-70"><a href="#cb15-70" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Update labeled and unlabeled indices</span></span>
<span id="cb15-71"><a href="#cb15-71" aria-hidden="true" tabindex="-1"></a>        selected_indices <span class="op">=</span> unlabeled_indices[top_indices]</span>
<span id="cb15-72"><a href="#cb15-72" aria-hidden="true" tabindex="-1"></a>        labeled_indices <span class="op">=</span> np.append(labeled_indices, selected_indices)</span>
<span id="cb15-73"><a href="#cb15-73" aria-hidden="true" tabindex="-1"></a>        unlabeled_indices <span class="op">=</span> np.setdiff1d(unlabeled_indices, selected_indices)</span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>Iteration 1: 20 labeled samples, accuracy: 0.2230
Iteration 2: 30 labeled samples, accuracy: 0.2636
Iteration 3: 40 labeled samples, accuracy: 0.7044
Iteration 4: 50 labeled samples, accuracy: 0.4834
Iteration 5: 60 labeled samples, accuracy: 0.6411
Iteration 6: 70 labeled samples, accuracy: 0.7583
Iteration 7: 80 labeled samples, accuracy: 0.7244
Iteration 8: 90 labeled samples, accuracy: 0.7803
Iteration 9: 100 labeled samples, accuracy: 0.8009
Iteration 10: 110 labeled samples, accuracy: 0.8109</code></pre>
</div>
</div>
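<p>The loop above scores disagreement with vote entropy over hard labels, so a sample on which all three members predict the same class scores zero even if their confidences differ widely. Because every committee member here exposes <code>predict_proba</code>, a softer alternative is the average Kullback-Leibler divergence of each member's class distribution from the committee consensus (the "KL divergence to the mean" criterion). The NumPy-only sketch below illustrates that swapped-in criterion; the function name <code>kl_to_consensus</code> and the toy probabilities are hypothetical.</p>

```python
import numpy as np

def kl_to_consensus(member_probs, eps=1e-12):
    """Query-by-committee disagreement via mean KL divergence to the consensus.

    member_probs: array of shape (n_members, n_samples, n_classes), e.g. the
    stacked predict_proba outputs of each committee member on the unlabeled pool.
    Returns one score per sample (about zero when all members agree).
    """
    member_probs = np.asarray(member_probs, dtype=float)
    consensus = member_probs.mean(axis=0)  # (n_samples, n_classes)
    # KL(P_member || P_consensus), summed over classes for every member/sample
    kl = np.sum(
        member_probs * (np.log(member_probs + eps) - np.log(consensus + eps)),
        axis=2,
    )
    return kl.mean(axis=0)  # average over committee members

# Three members, two unlabeled samples, two classes
probs = np.array([
    [[0.9, 0.1], [0.6, 0.4]],
    [[0.9, 0.1], [0.2, 0.8]],
    [[0.9, 0.1], [0.9, 0.1]],
])
scores = kl_to_consensus(probs)
query_order = np.argsort(scores)[::-1]  # most disputed samples first
print(scores, query_order)
```

<p>In the loop above, <code>member_probs</code> would be <code>np.stack([m.predict_proba(X_train[unlabeled_indices]) for m in committee_models])</code>, and the top-scoring indices would replace the vote-entropy ranking.</p>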
</section>
</section>
<section id="advanced-topics" class="level2">
<h2 class="anchored" data-anchor-id="advanced-topics" id="advanced-topics">Advanced Topics</h2>
<section id="transfer-learning-with-active-learning" class="level3">
<h3 class="anchored" data-anchor-id="transfer-learning-with-active-learning" id="transfer-learning-with-active-learning">Transfer Learning with Active Learning</h3>
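<p>Combining transfer learning with active learning is often effective, because a pre-trained model provides useful features before many labels are available. A typical workflow:</p>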
<ol type="1">
<li>Use pre-trained models as feature extractors.</li>
<li>Apply active learning on the feature space.</li>
<li>Fine-tune the model on the selected samples.</li>
</ol>
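<p>As a sketch of step 2, one widely used feature-space strategy is greedy k-center (core-set) selection: repeatedly query the unlabeled point that lies farthest from every already-labeled point, so the batch covers the embedding space. The example assumes a frozen random projection followed by ReLU as a stand-in for a pre-trained backbone; in practice the hypothetical helper <code>extract_features</code> would return, for example, penultimate-layer activations of a pre-trained CNN. Step 3 (fine-tuning) is not shown.</p>

```python
import numpy as np

def extract_features(X, W):
    """Hypothetical stand-in for a frozen pre-trained backbone.

    A fixed projection followed by ReLU; in practice this would be, e.g.,
    penultimate-layer activations of a pre-trained CNN."""
    return np.maximum(X @ W, 0.0)

def k_center_greedy(features, labeled_idx, budget):
    """Greedy k-center (core-set) selection in feature space."""
    labeled_idx = np.asarray(labeled_idx)
    # Distance from every pool point to its nearest labeled point
    diffs = features[:, None, :] - features[labeled_idx][None, :, :]
    dists = np.linalg.norm(diffs, axis=2).min(axis=1)
    selected = []
    for _ in range(budget):
        idx = int(np.argmax(dists))  # farthest from all current centers
        selected.append(idx)
        # The new center may now be the nearest one for some points
        dists = np.minimum(dists, np.linalg.norm(features - features[idx], axis=1))
    return np.array(selected)

rng = np.random.default_rng(0)
X_pool = rng.normal(size=(500, 20))
W_frozen = rng.normal(size=(20, 64))  # hypothetical frozen weights
feats = extract_features(X_pool, W_frozen)
initial = rng.choice(len(X_pool), size=10, replace=False)
batch = k_center_greedy(feats, initial, budget=20)
print(batch[:5])
```

<p>Labeled points start at distance zero and each selected point drops to zero after its update, so the routine never re-selects a point. Unlike uncertainty sampling, it needs no classifier at all, which makes it usable from the very first round.</p>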
</section>
<section id="active-learning-with-deep-learning" class="level3">
<h3 class="anchored" data-anchor-id="active-learning-with-deep-learning" id="active-learning-with-deep-learning">Active Learning with Deep Learning</h3>
<p>Special considerations for deep learning models:</p>
<ol type="1">
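<li><strong>Uncertainty Estimation</strong>: A single softmax output tends to be overconfident, so use Monte Carlo (MC) dropout or deep ensembles to obtain better uncertainty estimates.</li>
<li><strong>Batch Normalization</strong>: Running statistics estimated from a small or changing labeled pool can be unreliable. When retraining after each query round, make sure they are re-estimated on the current pool, and keep batch normalization layers in evaluation mode while drawing MC-dropout samples.</li>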
<li><strong>Data Augmentation</strong>: Apply data augmentation to increase the effective size of the labeled pool.</li>
</ol>
<div id="5502d45f" class="cell" data-execution_count="15">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb17"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb17-1"><a href="#cb17-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb17-2"><a href="#cb17-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb17-3"><a href="#cb17-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn.functional <span class="im">as</span> F</span>
<span id="cb17-4"><a href="#cb17-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.utils.data <span class="im">import</span> DataLoader, Subset</span>
<span id="cb17-5"><a href="#cb17-5" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torchvision.transforms <span class="im">as</span> transforms</span>
<span id="cb17-6"><a href="#cb17-6" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torchvision.datasets <span class="im">as</span> datasets</span>
<span id="cb17-7"><a href="#cb17-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-8"><a href="#cb17-8" aria-hidden="true" tabindex="-1"></a><span class="co"># Define a simple CNN</span></span>
<span id="cb17-9"><a href="#cb17-9" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> SimpleCNN(nn.Module):</span>
<span id="cb17-10"><a href="#cb17-10" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>):</span>
<span id="cb17-11"><a href="#cb17-11" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>(SimpleCNN, <span class="va">self</span>).<span class="fu">__init__</span>()</span>
<span id="cb17-12"><a href="#cb17-12" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.conv1 <span class="op">=</span> nn.Conv2d(<span class="dv">1</span>, <span class="dv">32</span>, <span class="dv">3</span>, <span class="dv">1</span>)</span>
<span id="cb17-13"><a href="#cb17-13" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.conv2 <span class="op">=</span> nn.Conv2d(<span class="dv">32</span>, <span class="dv">64</span>, <span class="dv">3</span>, <span class="dv">1</span>)</span>
<span id="cb17-14"><a href="#cb17-14" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.dropout1 <span class="op">=</span> nn.Dropout2d(<span class="fl">0.25</span>)</span>
<span id="cb17-15"><a href="#cb17-15" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.dropout2 <span class="op">=</span> nn.Dropout(<span class="fl">0.5</span>)  <span class="co"># element-wise dropout for the 1-D fc1 activations</span></span>
<span id="cb17-16"><a href="#cb17-16" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.fc1 <span class="op">=</span> nn.Linear(<span class="dv">9216</span>, <span class="dv">128</span>)</span>
<span id="cb17-17"><a href="#cb17-17" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.fc2 <span class="op">=</span> nn.Linear(<span class="dv">128</span>, <span class="dv">10</span>)</span>
<span id="cb17-18"><a href="#cb17-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-19"><a href="#cb17-19" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x, dropout<span class="op">=</span><span class="va">True</span>):</span>
<span id="cb17-20"><a href="#cb17-20" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.conv1(x)</span>
<span id="cb17-21"><a href="#cb17-21" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> F.relu(x)</span>
<span id="cb17-22"><a href="#cb17-22" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.conv2(x)</span>
<span id="cb17-23"><a href="#cb17-23" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> F.relu(x)</span>
<span id="cb17-24"><a href="#cb17-24" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> F.max_pool2d(x, <span class="dv">2</span>)</span>
<span id="cb17-25"><a href="#cb17-25" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> dropout:</span>
<span id="cb17-26"><a href="#cb17-26" aria-hidden="true" tabindex="-1"></a>            x <span class="op">=</span> <span class="va">self</span>.dropout1(x)</span>
<span id="cb17-27"><a href="#cb17-27" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> torch.flatten(x, <span class="dv">1</span>)</span>
<span id="cb17-28"><a href="#cb17-28" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.fc1(x)</span>
<span id="cb17-29"><a href="#cb17-29" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> F.relu(x)</span>
<span id="cb17-30"><a href="#cb17-30" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> dropout:</span>
<span id="cb17-31"><a href="#cb17-31" aria-hidden="true" tabindex="-1"></a>            x <span class="op">=</span> <span class="va">self</span>.dropout2(x)</span>
<span id="cb17-32"><a href="#cb17-32" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.fc2(x)</span>
<span id="cb17-33"><a href="#cb17-33" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x</span>
<span id="cb17-34"><a href="#cb17-34" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-35"><a href="#cb17-35" aria-hidden="true" tabindex="-1"></a><span class="co"># MC Dropout for uncertainty estimation</span></span>
<span id="cb17-36"><a href="#cb17-36" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> mc_dropout_uncertainty(model, data_loader, n_samples<span class="op">=</span><span class="dv">10</span>):</span>
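<span id="cb17-37"><a href="#cb17-37" aria-hidden="true" tabindex="-1"></a>    model.<span class="bu">eval</span>()  <span class="co"># eval() also disables dropout, so switch only the dropout layers back on</span>
    model.apply(<span class="kw">lambda</span> m: m.train() <span class="cf">if</span> <span class="bu">isinstance</span>(m, (nn.Dropout, nn.Dropout2d)) <span class="cf">else</span> <span class="va">None</span>)</span>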
<span id="cb17-38"><a href="#cb17-38" aria-hidden="true" tabindex="-1"></a>    all_probs <span class="op">=</span> []</span>
<span id="cb17-39"><a href="#cb17-39" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb17-40"><a href="#cb17-40" aria-hidden="true" tabindex="-1"></a>    <span class="cf">with</span> torch.no_grad():</span>
<span id="cb17-41"><a href="#cb17-41" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(n_samples):</span>
<span id="cb17-42"><a href="#cb17-42" aria-hidden="true" tabindex="-1"></a>            batch_probs <span class="op">=</span> []</span>
<span id="cb17-43"><a href="#cb17-43" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> data, _ <span class="kw">in</span> data_loader:  <span class="co"># loader must use shuffle=False so sample order matches across passes</span></span>
<span id="cb17-44"><a href="#cb17-44" aria-hidden="true" tabindex="-1"></a>                output <span class="op">=</span> model(data, dropout<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb17-45"><a href="#cb17-45" aria-hidden="true" tabindex="-1"></a>                probs <span class="op">=</span> F.softmax(output, dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb17-46"><a href="#cb17-46" aria-hidden="true" tabindex="-1"></a>                batch_probs.append(probs)</span>
<span id="cb17-47"><a href="#cb17-47" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb17-48"><a href="#cb17-48" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Concatenate batch probabilities</span></span>
<span id="cb17-49"><a href="#cb17-49" aria-hidden="true" tabindex="-1"></a>            all_probs.append(torch.cat(batch_probs))</span>
<span id="cb17-50"><a href="#cb17-50" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb17-51"><a href="#cb17-51" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Stack along a new dimension</span></span>
<span id="cb17-52"><a href="#cb17-52" aria-hidden="true" tabindex="-1"></a>    all_probs <span class="op">=</span> torch.stack(all_probs)</span>
<span id="cb17-53"><a href="#cb17-53" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb17-54"><a href="#cb17-54" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Calculate the mean probabilities</span></span>
<span id="cb17-55"><a href="#cb17-55" aria-hidden="true" tabindex="-1"></a>    mean_probs <span class="op">=</span> torch.mean(all_probs, dim<span class="op">=</span><span class="dv">0</span>)</span>
<span id="cb17-56"><a href="#cb17-56" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb17-57"><a href="#cb17-57" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Calculate entropy of the mean prediction</span></span>
<span id="cb17-58"><a href="#cb17-58" aria-hidden="true" tabindex="-1"></a>    entropy <span class="op">=</span> <span class="op">-</span>torch.<span class="bu">sum</span>(mean_probs <span class="op">*</span> torch.log(mean_probs <span class="op">+</span> <span class="fl">1e-10</span>), dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb17-59"><a href="#cb17-59" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb17-60"><a href="#cb17-60" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> entropy.numpy()</span></code></pre></div></div>
</div>
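<p>The function above ranks samples by the entropy of the mean MC-dropout prediction (predictive entropy). A closely related score, BALD (Bayesian Active Learning by Disagreement), subtracts the average per-pass entropy, so it keeps only the part of the uncertainty caused by disagreement between dropout masks rather than inputs that every pass finds ambiguous. The NumPy sketch below assumes the per-pass probabilities are stacked into an array of shape <code>(n_passes, n_samples, n_classes)</code>, i.e. the <code>all_probs</code> tensor inside <code>mc_dropout_uncertainty</code> converted with <code>.numpy()</code>; the name <code>bald_scores</code> is hypothetical.</p>

```python
import numpy as np

def bald_scores(mc_probs, eps=1e-10):
    """Mutual information between predictions and model parameters (BALD).

    mc_probs: array (n_passes, n_samples, n_classes) of softmax outputs,
    one slice per stochastic forward pass with dropout enabled.
    """
    mc_probs = np.asarray(mc_probs, dtype=float)
    mean_probs = mc_probs.mean(axis=0)
    # Predictive entropy: uncertainty of the averaged prediction
    predictive_entropy = -np.sum(mean_probs * np.log(mean_probs + eps), axis=1)
    # Expected entropy: average uncertainty of the individual passes
    expected_entropy = -np.mean(np.sum(mc_probs * np.log(mc_probs + eps), axis=2), axis=0)
    return predictive_entropy - expected_entropy

# Two passes, two samples, two classes
mc_probs = np.array([
    [[0.5, 0.5], [1.0, 0.0]],
    [[0.5, 0.5], [0.0, 1.0]],
])
print(bald_scores(mc_probs))
```

<p>Sample 0 is ambiguous in every pass, so its BALD score is zero even though its predictive entropy is maximal; sample 1 flips between confident answers, so nearly all of its uncertainty counts as disagreement.</p>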
</section>
<section id="semi-supervised-active-learning" class="level3">
<h3 class="anchored" data-anchor-id="semi-supervised-active-learning" id="semi-supervised-active-learning">Semi-Supervised Active Learning</h3>
<p>Leverage both labeled and unlabeled data during training:</p>
<ol type="1">
<li><strong>Self-Training</strong>: Use model predictions on unlabeled data as pseudo-labels.</li>
<li><strong>Co-Training</strong>: Train multiple models and use their predictions to teach each other.</li>
<li><strong>Consistency Regularization</strong>: Enforce consistent predictions across different perturbations.</li>
</ol>
<div id="30c7562b" class="cell" data-execution_count="16">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb18"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb18-1"><a href="#cb18-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> semi_supervised_active_learning(labeled_X, labeled_y, unlabeled_X, model, confidence_threshold<span class="op">=</span><span class="fl">0.95</span>):</span>
<span id="cb18-2"><a href="#cb18-2" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Train model on labeled data</span></span>
<span id="cb18-3"><a href="#cb18-3" aria-hidden="true" tabindex="-1"></a>    model.fit(labeled_X, labeled_y)</span>
<span id="cb18-4"><a href="#cb18-4" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-5"><a href="#cb18-5" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Predict on unlabeled data</span></span>
<span id="cb18-6"><a href="#cb18-6" aria-hidden="true" tabindex="-1"></a>    probabilities <span class="op">=</span> model.predict_proba(unlabeled_X)</span>
<span id="cb18-7"><a href="#cb18-7" aria-hidden="true" tabindex="-1"></a>    max_probs <span class="op">=</span> np.<span class="bu">max</span>(probabilities, axis<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb18-8"><a href="#cb18-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-9"><a href="#cb18-9" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Get high confidence predictions</span></span>
<span id="cb18-10"><a href="#cb18-10" aria-hidden="true" tabindex="-1"></a>    confident_indices <span class="op">=</span> np.where(max_probs <span class="op">&gt;=</span> confidence_threshold)[<span class="dv">0</span>]</span>
<span id="cb18-11"><a href="#cb18-11" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-12"><a href="#cb18-12" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Pseudo-labels: most probable class per confident sample (also works when none pass the threshold)</span></span>
<span id="cb18-13"><a href="#cb18-13" aria-hidden="true" tabindex="-1"></a>    pseudo_labels <span class="op">=</span> model.classes_[np.argmax(probabilities[confident_indices], axis<span class="op">=</span><span class="dv">1</span>)]</span>
<span id="cb18-14"><a href="#cb18-14" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-15"><a href="#cb18-15" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Train on combined dataset (np.vstack expects dense arrays; use scipy.sparse.vstack for sparse features)</span></span>
<span id="cb18-16"><a href="#cb18-16" aria-hidden="true" tabindex="-1"></a>    combined_X <span class="op">=</span> np.vstack([labeled_X, unlabeled_X[confident_indices]])</span>
<span id="cb18-17"><a href="#cb18-17" aria-hidden="true" tabindex="-1"></a>    combined_y <span class="op">=</span> np.concatenate([labeled_y, pseudo_labels])</span>
<span id="cb18-18"><a href="#cb18-18" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-19"><a href="#cb18-19" aria-hidden="true" tabindex="-1"></a>    model.fit(combined_X, combined_y)</span>
<span id="cb18-20"><a href="#cb18-20" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb18-21"><a href="#cb18-21" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> model, confident_indices</span></code></pre></div></div>
</div>
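<p>The function above implements self-training (item 1). For item 3, consistency regularization adds an unsupervised term that penalizes changes in the model's predictions when an unlabeled input is perturbed. The minimal NumPy sketch below assumes a linear softmax classifier with weight matrix <code>W</code> and Gaussian input noise as the perturbation (both illustrative choices; image models typically use data augmentation instead), and computes the mean squared difference between the two predicted distributions, in the style of the Pi-model.</p>

```python
import numpy as np

def softmax(logits):
    # Subtract the row max for numerical stability
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def consistency_loss(W, X_unlabeled, noise_std=0.1, rng=None):
    """Mean squared difference between predictions on two noisy views."""
    rng = np.random.default_rng() if rng is None else rng
    view_a = X_unlabeled + rng.normal(scale=noise_std, size=X_unlabeled.shape)
    view_b = X_unlabeled + rng.normal(scale=noise_std, size=X_unlabeled.shape)
    p_a = softmax(view_a @ W)
    p_b = softmax(view_b @ W)
    # Summed over classes, averaged over samples
    return np.mean(np.sum((p_a - p_b) ** 2, axis=1))

rng = np.random.default_rng(0)
X_u = rng.normal(size=(100, 5))
W = rng.normal(size=(5, 3))
print(consistency_loss(W, X_u, rng=rng))
```

<p>During training this term would be added to the supervised loss with a weighting factor, and its gradient would be taken with respect to <code>W</code> (for example, with an autodiff framework).</p>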
</section>
<section id="active-learning-for-domain-adaptation" class="level3">
<h3 class="anchored" data-anchor-id="active-learning-for-domain-adaptation" id="active-learning-for-domain-adaptation">Active Learning for Domain Adaptation</h3>
<p>When labeled data from the target domain is scarce, active learning can help select the most informative samples:</p>
<ol type="1">
<li><strong>Domain Discrepancy Measures</strong>: Select samples that minimize domain discrepancy.</li>
<li><strong>Adversarial Selection</strong>: Select samples that the domain discriminator is most uncertain about.</li>
<li><strong>Feature Space Alignment</strong>: Select samples that help align feature spaces between domains.</li>
</ol>
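<p>The adversarial-selection idea (item 2) can be sketched under simplifying assumptions: features for both domains are already computed, a plain logistic-regression classifier stands in for the domain discriminator of an adversarial adaptation model, and the target samples whose predicted probability of being "target" is closest to 0.5 are queried, following the "most uncertain" criterion in the list. The names <code>fit_domain_discriminator</code> and <code>select_uncertain_target</code> are hypothetical.</p>

```python
import numpy as np

def fit_domain_discriminator(X_source, X_target, lr=0.1, epochs=500):
    """Logistic regression separating source (label 0) from target (label 1)."""
    X = np.vstack([X_source, X_target])
    y = np.concatenate([np.zeros(len(X_source)), np.ones(len(X_target))])
    w = np.zeros(X.shape[1])
    b = 0.0
    for _ in range(epochs):
        p = 1.0 / (1.0 + np.exp(-(X @ w + b)))
        # Gradient step on the mean binary cross-entropy
        w -= lr * X.T @ (p - y) / len(y)
        b -= lr * np.mean(p - y)
    return w, b

def select_uncertain_target(X_target, w, b, budget):
    """Indices of target samples whose domain probability is closest to 0.5."""
    p_target = 1.0 / (1.0 + np.exp(-(X_target @ w + b)))
    return np.argsort(np.abs(p_target - 0.5))[:budget]

rng = np.random.default_rng(0)
X_src = rng.normal(loc=0.0, size=(200, 4))
X_tgt = rng.normal(loc=1.0, size=(150, 4))  # shifted target domain
w, b = fit_domain_discriminator(X_src, X_tgt)
query = select_uncertain_target(X_tgt, w, b, budget=10)
print(query)
```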
</section>
<section id="human-in-the-loop-considerations" class="level3">
<h3 class="anchored" data-anchor-id="human-in-the-loop-considerations" id="human-in-the-loop-considerations">Human-in-the-Loop Considerations</h3>
<ol type="1">
<li><strong>Annotation Interface Design</strong>: Make the annotation process intuitive and efficient.</li>
<li><strong>Cognitive Load Management</strong>: Group similar samples to reduce cognitive switching.</li>
<li><strong>Explanations</strong>: Provide model explanations to help annotators understand the current model’s decisions.</li>
<li><strong>Quality Control</strong>: Incorporate mechanisms to detect and correct annotation errors.</li>
</ol>
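<p>For quality control (item 4), one common mechanism is to have two annotators label an overlapping subset and measure their agreement with Cohen's kappa, which corrects raw agreement for the agreement expected by chance; a low kappa flags guidelines or samples that need review. A minimal NumPy sketch (the function name <code>cohens_kappa</code> and the example labels are illustrative):</p>

```python
import numpy as np

def cohens_kappa(labels_a, labels_b):
    """Cohen's kappa for two annotators labeling the same items."""
    labels_a = np.asarray(labels_a)
    labels_b = np.asarray(labels_b)
    classes = np.union1d(labels_a, labels_b)
    # Observed agreement: fraction of items with identical labels
    p_observed = np.mean(labels_a == labels_b)
    # Chance agreement: product of each annotator's marginal class frequencies
    p_chance = sum(np.mean(labels_a == c) * np.mean(labels_b == c) for c in classes)
    # Undefined (division by zero) if both annotators use one identical class
    return (p_observed - p_chance) / (1.0 - p_chance)

annotator_1 = [0, 0, 1, 1, 2, 2, 0, 1]
annotator_2 = [0, 0, 1, 2, 2, 2, 0, 1]
print(round(cohens_kappa(annotator_1, annotator_2), 3))  # → 0.814
```

<p>Here the annotators agree on 7 of 8 items (0.875), but chance agreement is already 21/64, so kappa is lower than the raw agreement rate.</p>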
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>Active learning provides a powerful framework for efficiently building machine learning models with limited labeled data. By selecting the most informative samples for annotation, active learning can significantly reduce the labeling effort while maintaining high model performance.</p>
<p>The key to successful active learning is choosing the right influence selection strategy for your specific problem and data characteristics. Consider the following when designing your active learning pipeline:</p>
<ol type="1">
<li><strong>Data Characteristics</strong>: Dense vs.&nbsp;sparse data, balanced vs.&nbsp;imbalanced classes, feature distribution.</li>
<li><strong>Model Type</strong>: Linear models, tree-based models, deep learning models.</li>
<li><strong>Computational Resources</strong>: Available memory and processing power.</li>
<li><strong>Annotation Budget</strong>: Number of samples that can be labeled.</li>
<li><strong>Task Complexity</strong>: Classification vs.&nbsp;regression, number of classes, difficulty of the task.</li>
</ol>
<p>By carefully considering these factors and implementing the appropriate influence selection methods, you can build high-performance models with minimal annotation effort.</p>


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Python Data Visualization: Matplotlib vs Seaborn vs Altair]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/data-visualization-tutorial/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/data-visualization-tutorial/</guid>
      <pubDate>Sat, 12 Apr 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>tutorial</category>
      <category>beginner</category>
      <content:encoded><![CDATA[






<section id="python-data-visualization-matplotlib-vs-seaborn-vs-altair" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/data-visualization-tutorial/dataviz.jpg" class="img-fluid"></p>
<p>This guide compares three popular Python data visualization libraries: Matplotlib, Seaborn, and Altair (Vega-Altair). Each library has its own strengths, weaknesses, and ideal use cases. This comparison will help you choose the right tool for your specific visualization needs.</p>
<section id="quick-reference-comparison" class="level2">
<h2 class="anchored" data-anchor-id="quick-reference-comparison" id="quick-reference-comparison">Quick Reference Comparison</h2>
<table class="caption-top table">
<colgroup>
<col style="width: 23%">
<col style="width: 31%">
<col style="width: 23%">
<col style="width: 21%">
</colgroup>
<thead>
<tr class="header">
<th>Feature</th>
<th>Matplotlib</th>
<th>Seaborn</th>
<th>Altair</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><strong>Release Year</strong></td>
<td>2003</td>
<td>2013</td>
<td>2016</td>
</tr>
<tr class="even">
<td><strong>Foundation</strong></td>
<td>Standalone</td>
<td>Built on Matplotlib</td>
<td>Based on Vega-Lite</td>
</tr>
<tr class="odd">
<td><strong>Philosophy</strong></td>
<td>Imperative</td>
<td>Statistical</td>
<td>Declarative</td>
</tr>
<tr class="even">
<td><strong>Abstraction Level</strong></td>
<td>Low</td>
<td>Medium</td>
<td>High</td>
</tr>
<tr class="odd">
<td><strong>Learning Curve</strong></td>
<td>Steep</td>
<td>Moderate</td>
<td>Gentle</td>
</tr>
<tr class="even">
<td><strong>Code Verbosity</strong></td>
<td>High</td>
<td>Medium</td>
<td>Low</td>
</tr>
<tr class="odd">
<td><strong>Customization</strong></td>
<td>Extensive</td>
<td>Good</td>
<td>Limited</td>
</tr>
<tr class="even">
<td><strong>Statistical Integration</strong></td>
<td>Manual</td>
<td>Built-in</td>
<td>Good</td>
</tr>
<tr class="odd">
<td><strong>Interactive Features</strong></td>
<td>Limited</td>
<td>Limited</td>
<td>Excellent</td>
</tr>
<tr class="even">
<td><strong>Performance with Large Data</strong></td>
<td>Good</td>
<td>Moderate</td>
<td>Limited</td>
</tr>
<tr class="odd">
<td><strong>Community &amp; Resources</strong></td>
<td>Extensive</td>
<td>Good</td>
<td>Growing</td>
</tr>
</tbody>
</table>
</section>
<section id="matplotlib" class="level2">
<h2 class="anchored" data-anchor-id="matplotlib" id="matplotlib">Matplotlib</h2>
<p>Matplotlib is the foundational plotting library in Python’s data visualization ecosystem.</p>
<section id="strengths" class="level3">
<h3 class="anchored" data-anchor-id="strengths" id="strengths">Strengths:</h3>
<ul>
<li><strong>Fine-grained control</strong>: Almost every aspect of a visualization can be customized</li>
<li><strong>Versatility</strong>: Can create virtually any type of static plot</li>
<li><strong>Maturity</strong>: Extensive documentation and community support</li>
<li><strong>Ecosystem integration</strong>: Many libraries integrate with or build upon Matplotlib</li>
<li><strong>Performance</strong>: Handles large datasets well</li>
</ul>
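<p>The sketch below illustrates that fine-grained control. It assumes Matplotlib 3.x and the non-interactive Agg backend, so it runs without a display. Each adjustment is a separate Matplotlib call:</p>
<ul>
<li>hiding two spines</li>
<li>setting explicit tick positions and labels</li>
<li>adding minor ticks</li>
<li>annotating a point with an arrow</li>
</ul>

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import MultipleLocator

x = np.linspace(0, 2 * np.pi, 200)
fig, ax = plt.subplots(figsize=(6, 3))
ax.plot(x, np.sin(x), color="tab:blue", linewidth=2)

# Hide the top and right spines for a cleaner frame
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)

# Explicit major tick positions on the x-axis, with custom labels
ax.set_xticks([0, np.pi / 2, np.pi, 3 * np.pi / 2, 2 * np.pi])
ax.set_xticklabels(["0", "π/2", "π", "3π/2", "2π"])

# Minor ticks every 0.25 on the y-axis
ax.yaxis.set_minor_locator(MultipleLocator(0.25))

# Annotate the maximum of the sine wave with an arrow
ax.annotate("peak", xy=(np.pi / 2, 1), xytext=(np.pi / 2 + 0.8, 0.6),
            arrowprops=dict(arrowstyle="->"))
fig.tight_layout()
```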
</section>
<section id="weaknesses" class="level3">
<h3 class="anchored" data-anchor-id="weaknesses" id="weaknesses">Weaknesses:</h3>
<ul>
<li><strong>Verbose syntax</strong>: Requires many lines of code for complex visualizations</li>
<li><strong>Steep learning curve</strong>: Many functions and parameters to learn</li>
<li><strong>Default aesthetics</strong>: Plain default styling, although the defaults improved noticeably in Matplotlib 2.0 (for example, <code>viridis</code> became the default colormap)</li>
<li><strong>Limited interactivity</strong>: Primarily designed for static plots</li>
</ul>
</section>
<section id="example-code" class="level3">
<h3 class="anchored" data-anchor-id="example-code" id="example-code">Example Code:</h3>
<div id="b201f4e2" class="cell" data-execution_count="1">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb1"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> matplotlib.pyplot <span class="im">as</span> plt</span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Sample data</span></span>
<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a>x <span class="op">=</span> np.linspace(<span class="dv">0</span>, <span class="dv">10</span>, <span class="dv">100</span>)</span>
<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a>y <span class="op">=</span> np.sin(x)</span>
<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a><span class="co"># Create figure and axis</span></span>
<span id="cb1-9"><a href="#cb1-9" aria-hidden="true" tabindex="-1"></a>fig, ax <span class="op">=</span> plt.subplots(figsize<span class="op">=</span>(<span class="dv">8</span>, <span class="dv">4</span>))</span>
<span id="cb1-10"><a href="#cb1-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-11"><a href="#cb1-11" aria-hidden="true" tabindex="-1"></a><span class="co"># Plot data</span></span>
<span id="cb1-12"><a href="#cb1-12" aria-hidden="true" tabindex="-1"></a>ax.plot(x, y, label<span class="op">=</span><span class="st">'Sine Wave'</span>)</span>
<span id="cb1-13"><a href="#cb1-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-14"><a href="#cb1-14" aria-hidden="true" tabindex="-1"></a><span class="co"># Add grid, legend, title and labels</span></span>
<span id="cb1-15"><a href="#cb1-15" aria-hidden="true" tabindex="-1"></a>ax.grid(<span class="va">True</span>)</span>
<span id="cb1-16"><a href="#cb1-16" aria-hidden="true" tabindex="-1"></a>ax.set_xlabel(<span class="st">'X-axis'</span>)</span>
<span id="cb1-17"><a href="#cb1-17" aria-hidden="true" tabindex="-1"></a>ax.set_ylabel(<span class="st">'Y-axis'</span>)</span>
<span id="cb1-18"><a href="#cb1-18" aria-hidden="true" tabindex="-1"></a>ax.set_title(<span class="st">'Simple Sine Wave Plot'</span>)</span>
<span id="cb1-19"><a href="#cb1-19" aria-hidden="true" tabindex="-1"></a>ax.legend()</span>
<span id="cb1-20"><a href="#cb1-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb1-21"><a href="#cb1-21" aria-hidden="true" tabindex="-1"></a>plt.tight_layout()</span>
<span id="cb1-22"><a href="#cb1-22" aria-hidden="true" tabindex="-1"></a>plt.show()</span></code></pre></div></div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://theja-vanka.github.io/blogs/posts/data-visualization-tutorial/cell-2-output-1.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
</div>
</section>
<section id="when-to-use-matplotlib" class="level3">
<h3 class="anchored" data-anchor-id="when-to-use-matplotlib" id="when-to-use-matplotlib">When to use Matplotlib:</h3>
<ul>
<li>You need complete control over every aspect of your visualization</li>
<li>You’re creating complex, publication-quality figures</li>
<li>You’re working with specialized plot types not available in higher-level libraries</li>
<li>You need to integrate with many other Python libraries</li>
<li>You’re working with large datasets</li>
</ul>
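<p>For publication figures, a common Matplotlib workflow has two steps: fix the figure size to the target column width, then save vector output. The sketch below assumes a 3.5-inch column and 8 pt text purely for illustration, so check your venue’s actual requirements. Setting <code>pdf.fonttype</code> to 42 embeds TrueType fonts instead of Type 3 fonts. <code>plt.rc_context</code> keeps both settings local to this block.</p>

```python
import os
import tempfile

import matplotlib

matplotlib.use("Agg")  # render to files only, no display needed
import matplotlib.pyplot as plt
import numpy as np

out_dir = tempfile.mkdtemp()
pdf_path = os.path.join(out_dir, "figure.pdf")
png_path = os.path.join(out_dir, "figure.png")

# Assumed constraints for illustration: 8 pt text, TrueType fonts in the PDF
with plt.rc_context({"font.size": 8, "pdf.fonttype": 42}):
    x = np.linspace(0, 10, 100)
    fig, ax = plt.subplots(figsize=(3.5, 2.2))  # width matches a 3.5-inch column
    ax.plot(x, np.sin(x), label="sin(x)")
    ax.plot(x, np.cos(x), linestyle="--", label="cos(x)")
    ax.set_xlabel("x")
    ax.legend(frameon=False)
    fig.tight_layout()

    fig.savefig(pdf_path, bbox_inches="tight")           # vector output
    fig.savefig(png_path, dpi=300, bbox_inches="tight")  # high-resolution raster
```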
</section>
</section>
<section id="seaborn" class="level2">
<h2 class="anchored" data-anchor-id="seaborn" id="seaborn">Seaborn</h2>
<p>Seaborn is a statistical visualization library built on top of Matplotlib.</p>
<section id="strengths-1" class="level3">
<h3 class="anchored" data-anchor-id="strengths-1" id="strengths-1">Strengths:</h3>
<ul>
<li><strong>Aesthetic defaults</strong>: Beautiful out-of-the-box styling</li>
<li><strong>Statistical integration</strong>: Built-in support for statistical visualizations</li>
<li><strong>Dataset awareness</strong>: Works well with pandas DataFrames</li>
<li><strong>Simplicity</strong>: Fewer lines of code than Matplotlib for common plots</li>
<li><strong>High-level functions</strong>: Specialized plots like <code>lmplot</code>, <code>catplot</code>, etc.</li>
</ul>
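<p>A single call can show both the DataFrame awareness and the figure-level functions. This sketch assumes seaborn 0.12 or later and uses a synthetic long-form DataFrame. Its column names (<code>group</code>, <code>condition</code>, <code>score</code>) are made up for illustration. <code>catplot</code> maps columns to axes by name and creates one subplot per value of <code>condition</code>.</p>

```python
import matplotlib

matplotlib.use("Agg")  # headless rendering
import numpy as np
import pandas as pd
import seaborn as sns

# Synthetic long-form data: one row per observation
rng = np.random.default_rng(0)
df = pd.DataFrame({
    "group": np.repeat(["A", "B", "C"], 50),
    "condition": np.tile(["control", "treatment"], 75),
    "score": rng.normal(loc=10, scale=2, size=150),
})

# Box plots of score per group, with one column of subplots per condition
g = sns.catplot(data=df, x="group", y="score", col="condition", kind="box", height=3)
g.set_axis_labels("Group", "Score")
```

<p>The Matplotlib equivalent needs a manual loop over conditions and groups to split the data, which is exactly the boilerplate seaborn hides.</p>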
</section>
<section id="weaknesses-1" class="level3">
<h3 class="anchored" data-anchor-id="weaknesses-1" id="weaknesses-1">Weaknesses:</h3>
<ul>
<li><strong>Limited customization</strong>: Some advanced customizations require falling back to Matplotlib</li>
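<li><strong>Performance</strong>: Can be slower with very large datasets, partly because some functions compute statistics by default (e.g. <code>lineplot</code> aggregates repeated x values and bootstraps a confidence interval)</li>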
<li><strong>Restricted scope</strong>: Focused on statistical visualization, not general-purpose plotting</li>
</ul>
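<p>In practice, “falling back to Matplotlib” usually means customizing the Axes that seaborn draws on. Axes-level functions such as <code>sns.scatterplot</code> accept an <code>ax</code> argument and return that Matplotlib <code>Axes</code>, so any Matplotlib method can finish the job. A minimal sketch with synthetic data:</p>

```python
import matplotlib

matplotlib.use("Agg")  # headless rendering
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

rng = np.random.default_rng(1)
df = pd.DataFrame({"x": rng.normal(size=200)})
df["y"] = 2 * df["x"] + rng.normal(scale=0.5, size=200)

fig, ax = plt.subplots(figsize=(6, 4))
# Axes-level seaborn functions draw onto, and return, a Matplotlib Axes
ax = sns.scatterplot(data=df, x="x", y="y", ax=ax)

# Anything seaborn does not expose is finished with plain Matplotlib calls
ax.axhline(0, color="gray", linewidth=0.8)
ax.set_xlim(-3, 3)
ax.set_title("Seaborn scatter, customized with Matplotlib")
```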
</section>
<section id="example-code-1" class="level3">
<h3 class="anchored" data-anchor-id="example-code-1" id="example-code-1">Example Code:</h3>
<div id="57ca5f0d" class="cell" data-execution_count="2">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> seaborn <span class="im">as</span> sns</span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> matplotlib.pyplot <span class="im">as</span> plt</span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> pandas <span class="im">as</span> pd</span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Create sample data</span></span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a>x <span class="op">=</span> np.linspace(<span class="dv">0</span>, <span class="dv">10</span>, <span class="dv">100</span>)</span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a>y <span class="op">=</span> np.sin(x) <span class="op">+</span> np.random.normal(<span class="dv">0</span>, <span class="fl">0.2</span>, size<span class="op">=</span><span class="bu">len</span>(x))</span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a>data <span class="op">=</span> pd.DataFrame({<span class="st">'x'</span>: x, <span class="st">'y'</span>: y})</span>
<span id="cb2-10"><a href="#cb2-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-11"><a href="#cb2-11" aria-hidden="true" tabindex="-1"></a><span class="co"># Set the aesthetic style</span></span>
<span id="cb2-12"><a href="#cb2-12" aria-hidden="true" tabindex="-1"></a>sns.set_theme(style<span class="op">=</span><span class="st">"whitegrid"</span>)</span>
<span id="cb2-13"><a href="#cb2-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-14"><a href="#cb2-14" aria-hidden="true" tabindex="-1"></a><span class="co"># Create the plot</span></span>
<span id="cb2-15"><a href="#cb2-15" aria-hidden="true" tabindex="-1"></a>plt.figure(figsize<span class="op">=</span>(<span class="dv">8</span>, <span class="dv">4</span>))</span>
<span id="cb2-16"><a href="#cb2-16" aria-hidden="true" tabindex="-1"></a>sns.lineplot(data<span class="op">=</span>data, x<span class="op">=</span><span class="st">'x'</span>, y<span class="op">=</span><span class="st">'y'</span>, label<span class="op">=</span><span class="st">'Noisy Sine Wave'</span>)</span>
<span id="cb2-17"><a href="#cb2-17" aria-hidden="true" tabindex="-1"></a>sns.regplot(data<span class="op">=</span>data, x<span class="op">=</span><span class="st">'x'</span>, y<span class="op">=</span><span class="st">'y'</span>, scatter<span class="op">=</span><span class="va">False</span>, label<span class="op">=</span><span class="st">'Regression Line'</span>)</span>
<span id="cb2-18"><a href="#cb2-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-19"><a href="#cb2-19" aria-hidden="true" tabindex="-1"></a><span class="co"># Add title and labels</span></span>
<span id="cb2-20"><a href="#cb2-20" aria-hidden="true" tabindex="-1"></a>plt.title(<span class="st">'Seaborn Line Plot with Regression'</span>)</span>
<span id="cb2-21"><a href="#cb2-21" aria-hidden="true" tabindex="-1"></a>plt.xlabel(<span class="st">'X-axis'</span>)</span>
<span id="cb2-22"><a href="#cb2-22" aria-hidden="true" tabindex="-1"></a>plt.ylabel(<span class="st">'Y-axis'</span>)</span>
<span id="cb2-23"><a href="#cb2-23" aria-hidden="true" tabindex="-1"></a>plt.legend()</span>
<span id="cb2-24"><a href="#cb2-24" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-25"><a href="#cb2-25" aria-hidden="true" tabindex="-1"></a>plt.tight_layout()</span>
<span id="cb2-26"><a href="#cb2-26" aria-hidden="true" tabindex="-1"></a>plt.show()</span></code></pre></div></div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://theja-vanka.github.io/blogs/posts/data-visualization-tutorial/cell-3-output-1.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
</div>
</section>
<section id="when-to-use-seaborn" class="level3">
<h3 class="anchored" data-anchor-id="when-to-use-seaborn" id="when-to-use-seaborn">When to use Seaborn:</h3>
<ul>
<li>You want attractive visualizations with minimal code</li>
<li>You’re performing statistical analysis</li>
<li>You’re working with pandas DataFrames</li>
<li>You’re creating common statistical plots (distributions, relationships, categorical plots)</li>
<li>You want the power of Matplotlib with a simpler interface</li>
</ul>
</section>
</section>
<section id="altair-vega-altair" class="level2">
<h2 class="anchored" data-anchor-id="altair-vega-altair" id="altair-vega-altair">Altair (Vega-Altair)</h2>
<p>Altair is a declarative statistical visualization library based on Vega-Lite.</p>
<section id="strengths-2" class="level3">
<h3 class="anchored" data-anchor-id="strengths-2" id="strengths-2">Strengths:</h3>
<ul>
<li><strong>Declarative approach</strong>: Focus on what to visualize, not how to draw it</li>
<li><strong>Concise syntax</strong>: Very readable, clear code</li>
<li><strong>Layered grammar of graphics</strong>: Intuitive composition of plots</li>
<li><strong>Interactive visualizations</strong>: Built-in support for interactive features</li>
<li><strong>JSON output</strong>: Visualizations can be saved as JSON specifications</li>
</ul>
</section>
<section id="weaknesses-2" class="level3">
<h3 class="anchored" data-anchor-id="weaknesses-2" id="weaknesses-2">Weaknesses:</h3>
<ul>
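<li><strong>Performance limitations</strong>: Not ideal for very large datasets, because the data is embedded in the chart specification; by default Altair raises a <code>MaxRowsError</code> for data with more than 5,000 rows</li>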
<li><strong>Limited customization</strong>: Less fine-grained control than Matplotlib</li>
<li><strong>Learning curve</strong>: Different paradigm from traditional plotting libraries</li>
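<li><strong>Browser dependency</strong>: Charts are rendered by JavaScript (Vega-Lite) in a browser or notebook front end; exporting static PNG or SVG files requires an extra package such as <code>vl-convert-python</code></li>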
</ul>
</section>
<section id="example-code-2" class="level3">
<h3 class="anchored" data-anchor-id="example-code-2" id="example-code-2">Example Code:</h3>
<div id="9463cc5f" class="cell" data-execution_count="3">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> altair <span class="im">as</span> alt</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> pandas <span class="im">as</span> pd</span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Create sample data</span></span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a>x <span class="op">=</span> np.linspace(<span class="dv">0</span>, <span class="dv">10</span>, <span class="dv">100</span>)</span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a>y <span class="op">=</span> np.sin(x) <span class="op">+</span> np.random.normal(<span class="dv">0</span>, <span class="fl">0.2</span>, size<span class="op">=</span><span class="bu">len</span>(x))</span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a>data <span class="op">=</span> pd.DataFrame({<span class="st">'x'</span>: x, <span class="st">'y'</span>: y})</span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Create a simple scatter plot with interactive tooltips</span></span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a>chart <span class="op">=</span> alt.Chart(data).mark_circle().encode(</span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a>    x<span class="op">=</span><span class="st">'x'</span>,</span>
<span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a>    y<span class="op">=</span><span class="st">'y'</span>,</span>
<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a>    tooltip<span class="op">=</span>[<span class="st">'x'</span>, <span class="st">'y'</span>]</span>
<span id="cb3-15"><a href="#cb3-15" aria-hidden="true" tabindex="-1"></a>).properties(</span>
<span id="cb3-16"><a href="#cb3-16" aria-hidden="true" tabindex="-1"></a>    width<span class="op">=</span><span class="dv">600</span>,</span>
<span id="cb3-17"><a href="#cb3-17" aria-hidden="true" tabindex="-1"></a>    height<span class="op">=</span><span class="dv">300</span>,</span>
<span id="cb3-18"><a href="#cb3-18" aria-hidden="true" tabindex="-1"></a>    title<span class="op">=</span><span class="st">'Interactive Altair Scatter Plot'</span></span>
<span id="cb3-19"><a href="#cb3-19" aria-hidden="true" tabindex="-1"></a>).interactive()</span>
<span id="cb3-20"><a href="#cb3-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-21"><a href="#cb3-21" aria-hidden="true" tabindex="-1"></a><span class="co"># Add a regression line</span></span>
<span id="cb3-22"><a href="#cb3-22" aria-hidden="true" tabindex="-1"></a>regression <span class="op">=</span> alt.Chart(data).transform_regression(</span>
<span id="cb3-23"><a href="#cb3-23" aria-hidden="true" tabindex="-1"></a>    <span class="st">'x'</span>, <span class="st">'y'</span></span>
<span id="cb3-24"><a href="#cb3-24" aria-hidden="true" tabindex="-1"></a>).mark_line(color<span class="op">=</span><span class="st">'red'</span>).encode(</span>
<span id="cb3-25"><a href="#cb3-25" aria-hidden="true" tabindex="-1"></a>    x<span class="op">=</span><span class="st">'x'</span>,</span>
<span id="cb3-26"><a href="#cb3-26" aria-hidden="true" tabindex="-1"></a>    y<span class="op">=</span><span class="st">'y'</span></span>
<span id="cb3-27"><a href="#cb3-27" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb3-28"><a href="#cb3-28" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-29"><a href="#cb3-29" aria-hidden="true" tabindex="-1"></a><span class="co"># Combine the plots</span></span>
<span id="cb3-30"><a href="#cb3-30" aria-hidden="true" tabindex="-1"></a>final_chart <span class="op">=</span> chart <span class="op">+</span> regression</span>
<span id="cb3-31"><a href="#cb3-31" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-32"><a href="#cb3-32" aria-hidden="true" tabindex="-1"></a><span class="co"># Display the chart</span></span>
<span id="cb3-33"><a href="#cb3-33" aria-hidden="true" tabindex="-1"></a>final_chart</span></code></pre></div></div>
<div class="cell-output cell-output-display" data-execution_count="3">
</div>
</div>
</section>
<section id="when-to-use-altair" class="level3">
<h3 class="anchored" data-anchor-id="when-to-use-altair" id="when-to-use-altair">When to use Altair:</h3>
<ul>
<li>You want interactive visualizations</li>
<li>You prefer a declarative approach to visualization</li>
<li>You’re working with small to medium-sized datasets</li>
<li>You want to publish visualizations on the web</li>
<li>You appreciate a consistent grammar of graphics</li>
</ul>
</section>
</section>
<section id="common-visualization-types-comparison" class="level2">
<h2 class="anchored" data-anchor-id="common-visualization-types-comparison" id="common-visualization-types-comparison">Common Visualization Types Comparison</h2>
<section id="scatter-plot" class="level3">
<h3 class="anchored" data-anchor-id="scatter-plot" id="scatter-plot">Scatter Plot</h3>
<p><strong>Matplotlib:</strong></p>
<div id="d794ed88" class="cell" data-execution_count="4">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> matplotlib.pyplot <span class="im">as</span> plt</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a>x <span class="op">=</span> np.random.randn(<span class="dv">100</span>)</span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a>y <span class="op">=</span> np.random.randn(<span class="dv">100</span>)</span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a>plt.figure(figsize<span class="op">=</span>(<span class="dv">8</span>, <span class="dv">6</span>))</span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a>plt.scatter(x, y, alpha<span class="op">=</span><span class="fl">0.7</span>)</span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a>plt.title(<span class="st">'Matplotlib Scatter Plot'</span>)</span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a>plt.xlabel(<span class="st">'X-axis'</span>)</span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a>plt.ylabel(<span class="st">'Y-axis'</span>)</span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a>plt.grid(<span class="va">True</span>)</span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a>plt.show()</span></code></pre></div></div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://theja-vanka.github.io/blogs/posts/data-visualization-tutorial/cell-5-output-1.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
</div>
<p><strong>Seaborn:</strong></p>
<div id="ccfdb3ff" class="cell" data-execution_count="5">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> seaborn <span class="im">as</span> sns</span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> matplotlib.pyplot <span class="im">as</span> plt</span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> pandas <span class="im">as</span> pd</span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a>data <span class="op">=</span> pd.DataFrame({</span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a>    <span class="st">'x'</span>: np.random.randn(<span class="dv">100</span>),</span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a>    <span class="st">'y'</span>: np.random.randn(<span class="dv">100</span>)</span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a>})</span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a>sns.set_theme(style<span class="op">=</span><span class="st">"whitegrid"</span>)</span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a>plt.figure(figsize<span class="op">=</span>(<span class="dv">8</span>, <span class="dv">6</span>))</span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a>sns.scatterplot(data<span class="op">=</span>data, x<span class="op">=</span><span class="st">'x'</span>, y<span class="op">=</span><span class="st">'y'</span>, alpha<span class="op">=</span><span class="fl">0.7</span>)</span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a>plt.title(<span class="st">'Seaborn Scatter Plot'</span>)</span>
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a>plt.show()</span></code></pre></div></div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://theja-vanka.github.io/blogs/posts/data-visualization-tutorial/cell-6-output-1.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
</div>
<p><strong>Altair:</strong></p>
<div id="9a398785" class="cell" data-execution_count="6">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> altair <span class="im">as</span> alt</span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> pandas <span class="im">as</span> pd</span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a>data <span class="op">=</span> pd.DataFrame({</span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a>    <span class="st">'x'</span>: np.random.randn(<span class="dv">100</span>),</span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a>    <span class="st">'y'</span>: np.random.randn(<span class="dv">100</span>)</span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a>})</span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a>alt.Chart(data).mark_circle(opacity<span class="op">=</span><span class="fl">0.7</span>).encode(</span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a>    x<span class="op">=</span><span class="st">'x'</span>,</span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a>    y<span class="op">=</span><span class="st">'y'</span></span>
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a>).properties(</span>
<span id="cb6-14"><a href="#cb6-14" aria-hidden="true" tabindex="-1"></a>    width<span class="op">=</span><span class="dv">500</span>,</span>
<span id="cb6-15"><a href="#cb6-15" aria-hidden="true" tabindex="-1"></a>    height<span class="op">=</span><span class="dv">400</span>,</span>
<span id="cb6-16"><a href="#cb6-16" aria-hidden="true" tabindex="-1"></a>    title<span class="op">=</span><span class="st">'Altair Scatter Plot'</span></span>
<span id="cb6-17"><a href="#cb6-17" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
<div class="cell-output cell-output-display" data-execution_count="6">
</div>
</div>
</section>
<section id="histogram" class="level3">
<h3 class="anchored" data-anchor-id="histogram" id="histogram">Histogram</h3>
<p><strong>Matplotlib:</strong></p>
<div id="17e660cb" class="cell" data-execution_count="7">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> matplotlib.pyplot <span class="im">as</span> plt</span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a>data <span class="op">=</span> np.random.randn(<span class="dv">1000</span>)</span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a>plt.figure(figsize<span class="op">=</span>(<span class="dv">8</span>, <span class="dv">6</span>))</span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a>plt.hist(data, bins<span class="op">=</span><span class="dv">30</span>, alpha<span class="op">=</span><span class="fl">0.7</span>, edgecolor<span class="op">=</span><span class="st">'black'</span>)</span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a>plt.title(<span class="st">'Matplotlib Histogram'</span>)</span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a>plt.xlabel(<span class="st">'Value'</span>)</span>
<span id="cb7-10"><a href="#cb7-10" aria-hidden="true" tabindex="-1"></a>plt.ylabel(<span class="st">'Frequency'</span>)</span>
<span id="cb7-11"><a href="#cb7-11" aria-hidden="true" tabindex="-1"></a>plt.grid(<span class="va">True</span>, alpha<span class="op">=</span><span class="fl">0.3</span>)</span>
<span id="cb7-12"><a href="#cb7-12" aria-hidden="true" tabindex="-1"></a>plt.show()</span></code></pre></div></div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://theja-vanka.github.io/blogs/posts/data-visualization-tutorial/cell-8-output-1.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
</div>
<p><strong>Seaborn:</strong></p>
<div id="6f652b74" class="cell" data-execution_count="8">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb8"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> seaborn <span class="im">as</span> sns</span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> matplotlib.pyplot <span class="im">as</span> plt</span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a>data <span class="op">=</span> np.random.randn(<span class="dv">1000</span>)</span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a>sns.set_theme(style<span class="op">=</span><span class="st">"whitegrid"</span>)</span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a>plt.figure(figsize<span class="op">=</span>(<span class="dv">8</span>, <span class="dv">6</span>))</span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a>sns.histplot(data<span class="op">=</span>data, bins<span class="op">=</span><span class="dv">30</span>, kde<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a>plt.title(<span class="st">'Seaborn Histogram with KDE'</span>)</span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a>plt.show()</span></code></pre></div></div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://theja-vanka.github.io/blogs/posts/data-visualization-tutorial/cell-9-output-1.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
</div>
<p><strong>Altair:</strong></p>
<div id="3006688e" class="cell" data-execution_count="9">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb9"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> altair <span class="im">as</span> alt</span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> pandas <span class="im">as</span> pd</span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a>data <span class="op">=</span> pd.DataFrame({<span class="st">'value'</span>: np.random.randn(<span class="dv">1000</span>)})</span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a>alt.Chart(data).mark_bar().encode(</span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a>    alt.X(<span class="st">'value'</span>, <span class="bu">bin</span><span class="op">=</span>alt.Bin(maxbins<span class="op">=</span><span class="dv">30</span>)),</span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a>    y<span class="op">=</span><span class="st">'count()'</span></span>
<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a>).properties(</span>
<span id="cb9-11"><a href="#cb9-11" aria-hidden="true" tabindex="-1"></a>    width<span class="op">=</span><span class="dv">500</span>,</span>
<span id="cb9-12"><a href="#cb9-12" aria-hidden="true" tabindex="-1"></a>    height<span class="op">=</span><span class="dv">400</span>,</span>
<span id="cb9-13"><a href="#cb9-13" aria-hidden="true" tabindex="-1"></a>    title<span class="op">=</span><span class="st">'Altair Histogram'</span></span>
<span id="cb9-14"><a href="#cb9-14" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
<div class="cell-output cell-output-display" data-execution_count="9">

<style>
  #altair-viz-656ea03a9c4449ecad82850f7516dda9.vega-embed {
    width: 100%;
    display: flex;
  }

  #altair-viz-656ea03a9c4449ecad82850f7516dda9.vega-embed details,
  #altair-viz-656ea03a9c4449ecad82850f7516dda9.vega-embed details summary {
    position: relative;
  }
</style>
<div id="altair-viz-656ea03a9c4449ecad82850f7516dda9"></div>

</div>
</div>
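<p>The three histograms above do not bin the data the same way. Matplotlib’s <code>plt.hist</code> uses <code>np.histogram</code> to do its binning, so <code>bins=30</code> means exactly 30 equal-width bins running from the minimum value to the maximum value. Altair’s <code>maxbins=30</code> is only an upper bound: Vega-Lite picks rounded bin boundaries below that limit, so its chart can show a different number of bars. Here is a minimal sketch of the NumPy side. It uses a seeded generator, which is an addition not in the original examples, so the numbers are reproducible:</p>

```python
import numpy as np

# Seeded generator so the example is reproducible (the tutorial's examples are unseeded)
rng = np.random.default_rng(0)
data = rng.standard_normal(1000)

# Equivalent binning to plt.hist(data, bins=30): 30 equal-width bins spanning [min, max]
counts, edges = np.histogram(data, bins=30)

print(len(counts), len(edges))  # 30 31
print(counts.sum())             # 1000 (the last bin includes the maximum value)
print(np.allclose(np.diff(edges), (data.max() - data.min()) / 30))  # True
```

<p>If the Matplotlib and Altair versions need to match exactly, one option is to compute <code>edges</code> once and pass them to both libraries instead of a bin count.</p>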
</section>
<section id="line-plot" class="level3">
<h3 class="anchored" data-anchor-id="line-plot" id="line-plot">Line Plot</h3>
<p><strong>Matplotlib:</strong></p>
<div id="90f4214d" class="cell" data-execution_count="10">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb10"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> matplotlib.pyplot <span class="im">as</span> plt</span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a>x <span class="op">=</span> np.linspace(<span class="dv">0</span>, <span class="dv">10</span>, <span class="dv">100</span>)</span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a>y1 <span class="op">=</span> np.sin(x)</span>
<span id="cb10-6"><a href="#cb10-6" aria-hidden="true" tabindex="-1"></a>y2 <span class="op">=</span> np.cos(x)</span>
<span id="cb10-7"><a href="#cb10-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-8"><a href="#cb10-8" aria-hidden="true" tabindex="-1"></a>plt.figure(figsize<span class="op">=</span>(<span class="dv">10</span>, <span class="dv">6</span>))</span>
<span id="cb10-9"><a href="#cb10-9" aria-hidden="true" tabindex="-1"></a>plt.plot(x, y1, label<span class="op">=</span><span class="st">'Sine'</span>)</span>
<span id="cb10-10"><a href="#cb10-10" aria-hidden="true" tabindex="-1"></a>plt.plot(x, y2, label<span class="op">=</span><span class="st">'Cosine'</span>)</span>
<span id="cb10-11"><a href="#cb10-11" aria-hidden="true" tabindex="-1"></a>plt.title(<span class="st">'Matplotlib Line Plot'</span>)</span>
<span id="cb10-12"><a href="#cb10-12" aria-hidden="true" tabindex="-1"></a>plt.xlabel(<span class="st">'X-axis'</span>)</span>
<span id="cb10-13"><a href="#cb10-13" aria-hidden="true" tabindex="-1"></a>plt.ylabel(<span class="st">'Y-axis'</span>)</span>
<span id="cb10-14"><a href="#cb10-14" aria-hidden="true" tabindex="-1"></a>plt.legend()</span>
<span id="cb10-15"><a href="#cb10-15" aria-hidden="true" tabindex="-1"></a>plt.grid(<span class="va">True</span>)</span>
<span id="cb10-16"><a href="#cb10-16" aria-hidden="true" tabindex="-1"></a>plt.show()</span></code></pre></div></div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://theja-vanka.github.io/blogs/posts/data-visualization-tutorial/cell-11-output-1.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
</div>
<p><strong>Seaborn:</strong></p>
<div id="9aab94c0" class="cell" data-execution_count="11">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb11"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> seaborn <span class="im">as</span> sns</span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> matplotlib.pyplot <span class="im">as</span> plt</span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> pandas <span class="im">as</span> pd</span>
<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a>x <span class="op">=</span> np.linspace(<span class="dv">0</span>, <span class="dv">10</span>, <span class="dv">100</span>)</span>
<span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a>data <span class="op">=</span> pd.DataFrame({</span>
<span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a>    <span class="st">'x'</span>: np.concatenate([x, x]),</span>
<span id="cb11-9"><a href="#cb11-9" aria-hidden="true" tabindex="-1"></a>    <span class="st">'y'</span>: np.concatenate([np.sin(x), np.cos(x)]),</span>
<span id="cb11-10"><a href="#cb11-10" aria-hidden="true" tabindex="-1"></a>    <span class="st">'function'</span>: [<span class="st">'Sine'</span>]<span class="op">*</span><span class="dv">100</span> <span class="op">+</span> [<span class="st">'Cosine'</span>]<span class="op">*</span><span class="dv">100</span></span>
<span id="cb11-11"><a href="#cb11-11" aria-hidden="true" tabindex="-1"></a>})</span>
<span id="cb11-12"><a href="#cb11-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-13"><a href="#cb11-13" aria-hidden="true" tabindex="-1"></a>sns.set_theme(style<span class="op">=</span><span class="st">"darkgrid"</span>)</span>
<span id="cb11-14"><a href="#cb11-14" aria-hidden="true" tabindex="-1"></a>plt.figure(figsize<span class="op">=</span>(<span class="dv">10</span>, <span class="dv">6</span>))</span>
<span id="cb11-15"><a href="#cb11-15" aria-hidden="true" tabindex="-1"></a>sns.lineplot(data<span class="op">=</span>data, x<span class="op">=</span><span class="st">'x'</span>, y<span class="op">=</span><span class="st">'y'</span>, hue<span class="op">=</span><span class="st">'function'</span>)</span>
<span id="cb11-16"><a href="#cb11-16" aria-hidden="true" tabindex="-1"></a>plt.title(<span class="st">'Seaborn Line Plot'</span>)</span>
<span id="cb11-17"><a href="#cb11-17" aria-hidden="true" tabindex="-1"></a>plt.show()</span></code></pre></div></div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://theja-vanka.github.io/blogs/posts/data-visualization-tutorial/cell-12-output-1.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
</div>
<p><strong>Altair:</strong></p>
<div id="b2c4b3da" class="cell" data-execution_count="12">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb12"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><a href="#cb12-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> altair <span class="im">as</span> alt</span>
<span id="cb12-2"><a href="#cb12-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> pandas <span class="im">as</span> pd</span>
<span id="cb12-3"><a href="#cb12-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb12-4"><a href="#cb12-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-5"><a href="#cb12-5" aria-hidden="true" tabindex="-1"></a>x <span class="op">=</span> np.linspace(<span class="dv">0</span>, <span class="dv">10</span>, <span class="dv">100</span>)</span>
<span id="cb12-6"><a href="#cb12-6" aria-hidden="true" tabindex="-1"></a>data <span class="op">=</span> pd.DataFrame({</span>
<span id="cb12-7"><a href="#cb12-7" aria-hidden="true" tabindex="-1"></a>    <span class="st">'x'</span>: np.concatenate([x, x]),</span>
<span id="cb12-8"><a href="#cb12-8" aria-hidden="true" tabindex="-1"></a>    <span class="st">'y'</span>: np.concatenate([np.sin(x), np.cos(x)]),</span>
<span id="cb12-9"><a href="#cb12-9" aria-hidden="true" tabindex="-1"></a>    <span class="st">'function'</span>: [<span class="st">'Sine'</span>]<span class="op">*</span><span class="dv">100</span> <span class="op">+</span> [<span class="st">'Cosine'</span>]<span class="op">*</span><span class="dv">100</span></span>
<span id="cb12-10"><a href="#cb12-10" aria-hidden="true" tabindex="-1"></a>})</span>
<span id="cb12-11"><a href="#cb12-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-12"><a href="#cb12-12" aria-hidden="true" tabindex="-1"></a>alt.Chart(data).mark_line().encode(</span>
<span id="cb12-13"><a href="#cb12-13" aria-hidden="true" tabindex="-1"></a>    x<span class="op">=</span><span class="st">'x'</span>,</span>
<span id="cb12-14"><a href="#cb12-14" aria-hidden="true" tabindex="-1"></a>    y<span class="op">=</span><span class="st">'y'</span>,</span>
<span id="cb12-15"><a href="#cb12-15" aria-hidden="true" tabindex="-1"></a>    color<span class="op">=</span><span class="st">'function'</span></span>
<span id="cb12-16"><a href="#cb12-16" aria-hidden="true" tabindex="-1"></a>).properties(</span>
<span id="cb12-17"><a href="#cb12-17" aria-hidden="true" tabindex="-1"></a>    width<span class="op">=</span><span class="dv">600</span>,</span>
<span id="cb12-18"><a href="#cb12-18" aria-hidden="true" tabindex="-1"></a>    height<span class="op">=</span><span class="dv">400</span>,</span>
<span id="cb12-19"><a href="#cb12-19" aria-hidden="true" tabindex="-1"></a>    title<span class="op">=</span><span class="st">'Altair Line Plot'</span></span>
<span id="cb12-20"><a href="#cb12-20" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
<div class="cell-output cell-output-display" data-execution_count="12">

<style>
  #altair-viz-2cf1f8245f09403f837e5648cc5f0fd2.vega-embed {
    width: 100%;
    display: flex;
  }

  #altair-viz-2cf1f8245f09403f837e5648cc5f0fd2.vega-embed details,
  #altair-viz-2cf1f8245f09403f837e5648cc5f0fd2.vega-embed details summary {
    position: relative;
  }
</style>
<div id="altair-viz-2cf1f8245f09403f837e5648cc5f0fd2"></div>

</div>
</div>
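<p>The Seaborn and Altair line plots both need long-form (tidy) data: one row per point, plus a column that names its series. The examples above build this table by concatenating arrays by hand. Here is a minimal sketch of the same reshape using pandas’ <code>DataFrame.melt</code>, starting from a wide table with one column per function. The names <code>wide</code> and <code>long_df</code> are illustrative:</p>

```python
import numpy as np
import pandas as pd

x = np.linspace(0, 10, 100)

# Wide form: one row per x value, one column per function
wide = pd.DataFrame({'x': x, 'Sine': np.sin(x), 'Cosine': np.cos(x)})

# Long form: one row per (x, function) pair, the shape hue= and color= expect
long_df = wide.melt(id_vars='x', value_vars=['Sine', 'Cosine'],
                    var_name='function', value_name='y')

print(long_df.shape)  # (200, 3)
```

<p>You can pass <code>long_df</code> unchanged to <code>sns.lineplot(data=long_df, x='x', y='y', hue='function')</code> or to <code>alt.Chart(long_df)</code>. It has the same rows as the hand-built table; only the column order differs.</p>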
</section>
<section id="heatmap" class="level3">
<h3 class="anchored" data-anchor-id="heatmap" id="heatmap">Heatmap</h3>
<p><strong>Matplotlib:</strong></p>
<div id="ef0912a7" class="cell" data-execution_count="13">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb13"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> matplotlib.pyplot <span class="im">as</span> plt</span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-4"><a href="#cb13-4" aria-hidden="true" tabindex="-1"></a>data <span class="op">=</span> np.random.rand(<span class="dv">10</span>, <span class="dv">12</span>)</span>
<span id="cb13-5"><a href="#cb13-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-6"><a href="#cb13-6" aria-hidden="true" tabindex="-1"></a>plt.figure(figsize<span class="op">=</span>(<span class="dv">10</span>, <span class="dv">8</span>))</span>
<span id="cb13-7"><a href="#cb13-7" aria-hidden="true" tabindex="-1"></a>plt.imshow(data, cmap<span class="op">=</span><span class="st">'viridis'</span>)</span>
<span id="cb13-8"><a href="#cb13-8" aria-hidden="true" tabindex="-1"></a>plt.colorbar(label<span class="op">=</span><span class="st">'Value'</span>)</span>
<span id="cb13-9"><a href="#cb13-9" aria-hidden="true" tabindex="-1"></a>plt.title(<span class="st">'Matplotlib Heatmap'</span>)</span>
<span id="cb13-10"><a href="#cb13-10" aria-hidden="true" tabindex="-1"></a>plt.xlabel(<span class="st">'X-axis'</span>)</span>
<span id="cb13-11"><a href="#cb13-11" aria-hidden="true" tabindex="-1"></a>plt.ylabel(<span class="st">'Y-axis'</span>)</span>
<span id="cb13-12"><a href="#cb13-12" aria-hidden="true" tabindex="-1"></a>plt.show()</span></code></pre></div></div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://theja-vanka.github.io/blogs/posts/data-visualization-tutorial/cell-14-output-1.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
</div>
<p><strong>Seaborn:</strong></p>
<div id="b8a21d8e" class="cell" data-execution_count="14">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb14"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb14-1"><a href="#cb14-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> seaborn <span class="im">as</span> sns</span>
<span id="cb14-2"><a href="#cb14-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> matplotlib.pyplot <span class="im">as</span> plt</span>
<span id="cb14-3"><a href="#cb14-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb14-4"><a href="#cb14-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-5"><a href="#cb14-5" aria-hidden="true" tabindex="-1"></a>data <span class="op">=</span> np.random.rand(<span class="dv">10</span>, <span class="dv">12</span>)</span>
<span id="cb14-6"><a href="#cb14-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-7"><a href="#cb14-7" aria-hidden="true" tabindex="-1"></a>plt.figure(figsize<span class="op">=</span>(<span class="dv">10</span>, <span class="dv">8</span>))</span>
<span id="cb14-8"><a href="#cb14-8" aria-hidden="true" tabindex="-1"></a>sns.heatmap(data, annot<span class="op">=</span><span class="va">True</span>, cmap<span class="op">=</span><span class="st">'viridis'</span>, fmt<span class="op">=</span><span class="st">'.2f'</span>)</span>
<span id="cb14-9"><a href="#cb14-9" aria-hidden="true" tabindex="-1"></a>plt.title(<span class="st">'Seaborn Heatmap'</span>)</span>
<span id="cb14-10"><a href="#cb14-10" aria-hidden="true" tabindex="-1"></a>plt.show()</span></code></pre></div></div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://theja-vanka.github.io/blogs/posts/data-visualization-tutorial/cell-15-output-1.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
</div>
<p><strong>Altair:</strong></p>
<div id="386dc776" class="cell" data-execution_count="15">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb15"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb15-1"><a href="#cb15-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> altair <span class="im">as</span> alt</span>
<span id="cb15-2"><a href="#cb15-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> pandas <span class="im">as</span> pd</span>
<span id="cb15-3"><a href="#cb15-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb15-4"><a href="#cb15-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-5"><a href="#cb15-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Create sample data</span></span>
<span id="cb15-6"><a href="#cb15-6" aria-hidden="true" tabindex="-1"></a>data <span class="op">=</span> np.random.rand(<span class="dv">10</span>, <span class="dv">12</span>)</span>
<span id="cb15-7"><a href="#cb15-7" aria-hidden="true" tabindex="-1"></a>df <span class="op">=</span> pd.DataFrame(data)</span>
<span id="cb15-8"><a href="#cb15-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-9"><a href="#cb15-9" aria-hidden="true" tabindex="-1"></a><span class="co"># Reshape for Altair</span></span>
<span id="cb15-10"><a href="#cb15-10" aria-hidden="true" tabindex="-1"></a>df_long <span class="op">=</span> df.reset_index().melt(id_vars<span class="op">=</span><span class="st">'index'</span>)</span>
<span id="cb15-11"><a href="#cb15-11" aria-hidden="true" tabindex="-1"></a>df_long.columns <span class="op">=</span> [<span class="st">'y'</span>, <span class="st">'x'</span>, <span class="st">'value'</span>]</span>
<span id="cb15-12"><a href="#cb15-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-13"><a href="#cb15-13" aria-hidden="true" tabindex="-1"></a>alt.Chart(df_long).mark_rect().encode(</span>
<span id="cb15-14"><a href="#cb15-14" aria-hidden="true" tabindex="-1"></a>    x<span class="op">=</span><span class="st">'x:O'</span>,</span>
<span id="cb15-15"><a href="#cb15-15" aria-hidden="true" tabindex="-1"></a>    y<span class="op">=</span><span class="st">'y:O'</span>,</span>
<span id="cb15-16"><a href="#cb15-16" aria-hidden="true" tabindex="-1"></a>    color<span class="op">=</span><span class="st">'value:Q'</span></span>
<span id="cb15-17"><a href="#cb15-17" aria-hidden="true" tabindex="-1"></a>).properties(</span>
<span id="cb15-18"><a href="#cb15-18" aria-hidden="true" tabindex="-1"></a>    width<span class="op">=</span><span class="dv">500</span>,</span>
<span id="cb15-19"><a href="#cb15-19" aria-hidden="true" tabindex="-1"></a>    height<span class="op">=</span><span class="dv">400</span>,</span>
<span id="cb15-20"><a href="#cb15-20" aria-hidden="true" tabindex="-1"></a>    title<span class="op">=</span><span class="st">'Altair Heatmap'</span></span>
<span id="cb15-21"><a href="#cb15-21" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
<div class="cell-output cell-output-display" data-execution_count="15">

<style>
  #altair-viz-092114f6982d4160aa0b64954c1aa22a.vega-embed {
    width: 100%;
    display: flex;
  }

  #altair-viz-092114f6982d4160aa0b64954c1aa22a.vega-embed details,
  #altair-viz-092114f6982d4160aa0b64954c1aa22a.vega-embed details summary {
    position: relative;
  }
</style>
<div id="altair-viz-092114f6982d4160aa0b64954c1aa22a"></div>

</div>
</div>
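<p>Seaborn’s <code>annot=True</code> has no single-argument equivalent in <code>plt.imshow</code>. Here is a minimal sketch of reproducing it in Matplotlib by writing each cell’s value at the cell center with <code>ax.text</code>. The text-color rule (white on dark cells, black on light ones) and the seeded generator are illustrative choices, not part of the original example:</p>

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
data = rng.random((10, 12))

fig, ax = plt.subplots(figsize=(10, 8))
im = ax.imshow(data, cmap='viridis')
fig.colorbar(im, ax=ax, label='Value')

# imshow centers cell (row i, column j) at x=j, y=i
for i in range(data.shape[0]):
    for j in range(data.shape[1]):
        # viridis is dark for low values and light for high values
        text_color = 'black' if data[i, j] >= 0.5 else 'white'
        ax.text(j, i, f'{data[i, j]:.2f}', ha='center', va='center',
                color=text_color, fontsize=8)

ax.set_title('Matplotlib Heatmap with Annotations')
plt.show()
```

<p>This is the kind of detail Seaborn handles for you. Matplotlib makes you write it out, but in exchange every label’s font, color and position stays under your control.</p>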
</section>
</section>
<section id="decision-framework-for-choosing-a-library" class="level2">
<h2 class="anchored" data-anchor-id="decision-framework-for-choosing-a-library" id="decision-framework-for-choosing-a-library">Decision Framework for Choosing a Library</h2>
<section id="choose-matplotlib-when" class="level3">
<h3 class="anchored" data-anchor-id="choose-matplotlib-when" id="choose-matplotlib-when">Choose Matplotlib when:</h3>
<ul>
<li>You need complete control over every detail of your visualization</li>
<li>You’re creating complex, custom plots</li>
<li>Your visualizations will be included in scientific publications</li>
<li>You’re plotting very large datasets, since Matplotlib draws them directly instead of serializing the data into a JSON chart specification</li>
<li>You need to create animations or specialized chart types</li>
</ul>
</section>
<section id="choose-seaborn-when" class="level3">
<h3 class="anchored" data-anchor-id="choose-seaborn-when" id="choose-seaborn-when">Choose Seaborn when:</h3>
<ul>
<li>You want attractive plots with minimal code</li>
<li>You want statistical estimates (kernel density curves, regression fits, confidence intervals) computed and drawn for you</li>
<li>You want to create common statistical plots quickly</li>
<li>You need to visualize relationships between variables</li>
<li>You want good-looking defaults but still need some customization</li>
</ul>
</section>
<section id="choose-altair-when" class="level3">
<h3 class="anchored" data-anchor-id="choose-altair-when" id="choose-altair-when">Choose Altair when:</h3>
<ul>
<li>You want interactive visualizations</li>
<li>You prefer a declarative approach to visualization</li>
<li>You want concise, readable code</li>
<li>You’re creating dashboards or web-based visualizations</li>
<li>You’re working with small to medium-sized datasets (by default, Altair raises a <code>MaxRowsError</code> for datasets with more than 5,000 rows)</li>
</ul>
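<p>Altair embeds the full dataset in the chart specification, and by default it rejects datasets with more than 5,000 rows. <code>alt.data_transformers.disable_max_rows()</code> turns off that check, but every row is still sent to the browser. A common alternative is to aggregate with pandas first. Here is a minimal sketch, assuming a hypothetical one-million-row, minute-level series that is reduced to daily means before charting:</p>

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
n = 1_000_000  # hypothetical: one reading per minute for roughly 694 days

raw = pd.DataFrame({
    'timestamp': pd.date_range('2023-01-01', periods=n, freq='min'),
    'value': rng.standard_normal(n).cumsum(),
})

# Collapse to one row per calendar day before the data reaches Altair
daily = raw.resample('D', on='timestamp')['value'].mean().reset_index()
print(len(daily))  # 695
```

<p>You can then chart the 695-row <code>daily</code> frame as usual, for example with <code>alt.Chart(daily).mark_line().encode(x='timestamp:T', y='value:Q')</code>. This keeps both the row limit and the size of the embedded spec small.</p>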
</section>
</section>
<section id="integration-examples" class="level2">
<h2 class="anchored" data-anchor-id="integration-examples" id="integration-examples">Integration Examples</h2>
<section id="combining-seaborn-with-matplotlib" class="level3">
<h3 class="anchored" data-anchor-id="combining-seaborn-with-matplotlib" id="combining-seaborn-with-matplotlib">Combining Seaborn with Matplotlib:</h3>
<div id="dc3e330c" class="cell" data-execution_count="16">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb16"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb16-1"><a href="#cb16-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> matplotlib.pyplot <span class="im">as</span> plt</span>
<span id="cb16-2"><a href="#cb16-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> seaborn <span class="im">as</span> sns</span>
<span id="cb16-3"><a href="#cb16-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb16-4"><a href="#cb16-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> pandas <span class="im">as</span> pd</span>
<span id="cb16-5"><a href="#cb16-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-6"><a href="#cb16-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Create sample data</span></span>
<span id="cb16-7"><a href="#cb16-7" aria-hidden="true" tabindex="-1"></a>np.random.seed(<span class="dv">42</span>)</span>
<span id="cb16-8"><a href="#cb16-8" aria-hidden="true" tabindex="-1"></a>data <span class="op">=</span> pd.DataFrame({</span>
<span id="cb16-9"><a href="#cb16-9" aria-hidden="true" tabindex="-1"></a>    <span class="st">'x'</span>: np.random.normal(<span class="dv">0</span>, <span class="dv">1</span>, <span class="dv">100</span>),</span>
<span id="cb16-10"><a href="#cb16-10" aria-hidden="true" tabindex="-1"></a>    <span class="st">'y'</span>: np.random.normal(<span class="dv">0</span>, <span class="dv">1</span>, <span class="dv">100</span>),</span>
<span id="cb16-11"><a href="#cb16-11" aria-hidden="true" tabindex="-1"></a>    <span class="st">'category'</span>: np.random.choice([<span class="st">'A'</span>, <span class="st">'B'</span>, <span class="st">'C'</span>], <span class="dv">100</span>)</span>
<span id="cb16-12"><a href="#cb16-12" aria-hidden="true" tabindex="-1"></a>})</span>
<span id="cb16-13"><a href="#cb16-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-14"><a href="#cb16-14" aria-hidden="true" tabindex="-1"></a><span class="co"># Create a figure with Matplotlib</span></span>
<span id="cb16-15"><a href="#cb16-15" aria-hidden="true" tabindex="-1"></a>fig, ax <span class="op">=</span> plt.subplots(figsize<span class="op">=</span>(<span class="dv">10</span>, <span class="dv">6</span>))</span>
<span id="cb16-16"><a href="#cb16-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-17"><a href="#cb16-17" aria-hidden="true" tabindex="-1"></a><span class="co"># Use Seaborn for the main plot</span></span>
<span id="cb16-18"><a href="#cb16-18" aria-hidden="true" tabindex="-1"></a>sns.scatterplot(data<span class="op">=</span>data, x<span class="op">=</span><span class="st">'x'</span>, y<span class="op">=</span><span class="st">'y'</span>, hue<span class="op">=</span><span class="st">'category'</span>, ax<span class="op">=</span>ax)</span>
<span id="cb16-19"><a href="#cb16-19" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-20"><a href="#cb16-20" aria-hidden="true" tabindex="-1"></a><span class="co"># Add Matplotlib customizations</span></span>
<span id="cb16-21"><a href="#cb16-21" aria-hidden="true" tabindex="-1"></a>ax.set_title(<span class="st">'Combining Matplotlib and Seaborn'</span>, fontsize<span class="op">=</span><span class="dv">16</span>)</span>
<span id="cb16-22"><a href="#cb16-22" aria-hidden="true" tabindex="-1"></a>ax.grid(<span class="va">True</span>, linestyle<span class="op">=</span><span class="st">'--'</span>, alpha<span class="op">=</span><span class="fl">0.7</span>)</span>
<span id="cb16-23"><a href="#cb16-23" aria-hidden="true" tabindex="-1"></a>ax.set_xlabel(<span class="st">'X Variable'</span>, fontsize<span class="op">=</span><span class="dv">12</span>)</span>
<span id="cb16-24"><a href="#cb16-24" aria-hidden="true" tabindex="-1"></a>ax.set_ylabel(<span class="st">'Y Variable'</span>, fontsize<span class="op">=</span><span class="dv">12</span>)</span>
<span id="cb16-25"><a href="#cb16-25" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-26"><a href="#cb16-26" aria-hidden="true" tabindex="-1"></a><span class="co"># Add annotations using Matplotlib</span></span>
<span id="cb16-27"><a href="#cb16-27" aria-hidden="true" tabindex="-1"></a>ax.annotate(<span class="st">'Interesting Point'</span>, xy<span class="op">=</span>(<span class="op">-</span><span class="dv">1</span>, <span class="dv">1</span>), xytext<span class="op">=</span>(<span class="op">-</span><span class="dv">2</span>, <span class="fl">1.5</span>),</span>
<span id="cb16-28"><a href="#cb16-28" aria-hidden="true" tabindex="-1"></a>            arrowprops<span class="op">=</span><span class="bu">dict</span>(facecolor<span class="op">=</span><span class="st">'black'</span>, shrink<span class="op">=</span><span class="fl">0.05</span>))</span>
<span id="cb16-29"><a href="#cb16-29" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-30"><a href="#cb16-30" aria-hidden="true" tabindex="-1"></a>plt.tight_layout()</span>
<span id="cb16-31"><a href="#cb16-31" aria-hidden="true" tabindex="-1"></a>plt.show()</span></code></pre></div></div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://theja-vanka.github.io/blogs/posts/data-visualization-tutorial/cell-17-output-1.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
</div>
</section>
<section id="using-altair-with-pandas" class="level3">
<h3 class="anchored" data-anchor-id="using-altair-with-pandas" id="using-altair-with-pandas">Using Altair with Pandas:</h3>
<div id="bab89b6a" class="cell" data-execution_count="17">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb17"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb17-1"><a href="#cb17-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> altair <span class="im">as</span> alt</span>
<span id="cb17-2"><a href="#cb17-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> pandas <span class="im">as</span> pd</span>
<span id="cb17-3"><a href="#cb17-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb17-4"><a href="#cb17-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-5"><a href="#cb17-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Create sample data with pandas</span></span>
<span id="cb17-6"><a href="#cb17-6" aria-hidden="true" tabindex="-1"></a>np.random.seed(<span class="dv">42</span>)</span>
<span id="cb17-7"><a href="#cb17-7" aria-hidden="true" tabindex="-1"></a>df <span class="op">=</span> pd.DataFrame({</span>
<span id="cb17-8"><a href="#cb17-8" aria-hidden="true" tabindex="-1"></a>    <span class="st">'date'</span>: pd.date_range(<span class="st">'2023-01-01'</span>, periods<span class="op">=</span><span class="dv">100</span>),</span>
<span id="cb17-9"><a href="#cb17-9" aria-hidden="true" tabindex="-1"></a>    <span class="st">'value'</span>: np.cumsum(np.random.randn(<span class="dv">100</span>)),</span>
<span id="cb17-10"><a href="#cb17-10" aria-hidden="true" tabindex="-1"></a>    <span class="st">'category'</span>: np.random.choice([<span class="st">'Group A'</span>, <span class="st">'Group B'</span>], <span class="dv">100</span>)</span>
<span id="cb17-11"><a href="#cb17-11" aria-hidden="true" tabindex="-1"></a>})</span>
<span id="cb17-12"><a href="#cb17-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-13"><a href="#cb17-13" aria-hidden="true" tabindex="-1"></a><span class="co"># Use pandas to prepare the data</span></span>
<span id="cb17-14"><a href="#cb17-14" aria-hidden="true" tabindex="-1"></a>df[<span class="st">'month'</span>] <span class="op">=</span> df[<span class="st">'date'</span>].dt.month</span>
<span id="cb17-15"><a href="#cb17-15" aria-hidden="true" tabindex="-1"></a>monthly_avg <span class="op">=</span> df.groupby([<span class="st">'month'</span>, <span class="st">'category'</span>])[<span class="st">'value'</span>].mean().reset_index()</span>
<span id="cb17-16"><a href="#cb17-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-17"><a href="#cb17-17" aria-hidden="true" tabindex="-1"></a><span class="co"># Create the Altair visualization</span></span>
<span id="cb17-18"><a href="#cb17-18" aria-hidden="true" tabindex="-1"></a>chart <span class="op">=</span> alt.Chart(monthly_avg).mark_line(point<span class="op">=</span><span class="va">True</span>).encode(</span>
<span id="cb17-19"><a href="#cb17-19" aria-hidden="true" tabindex="-1"></a>    x<span class="op">=</span><span class="st">'month:O'</span>,</span>
<span id="cb17-20"><a href="#cb17-20" aria-hidden="true" tabindex="-1"></a>    y<span class="op">=</span><span class="st">'value:Q'</span>,</span>
<span id="cb17-21"><a href="#cb17-21" aria-hidden="true" tabindex="-1"></a>    color<span class="op">=</span><span class="st">'category:N'</span>,</span>
<span id="cb17-22"><a href="#cb17-22" aria-hidden="true" tabindex="-1"></a>    tooltip<span class="op">=</span>[<span class="st">'month'</span>, <span class="st">'value'</span>, <span class="st">'category'</span>]</span>
<span id="cb17-23"><a href="#cb17-23" aria-hidden="true" tabindex="-1"></a>).properties(</span>
<span id="cb17-24"><a href="#cb17-24" aria-hidden="true" tabindex="-1"></a>    width<span class="op">=</span><span class="dv">600</span>,</span>
<span id="cb17-25"><a href="#cb17-25" aria-hidden="true" tabindex="-1"></a>    height<span class="op">=</span><span class="dv">400</span>,</span>
<span id="cb17-26"><a href="#cb17-26" aria-hidden="true" tabindex="-1"></a>    title<span class="op">=</span><span class="st">'Monthly Averages by Category'</span></span>
<span id="cb17-27"><a href="#cb17-27" aria-hidden="true" tabindex="-1"></a>).interactive()</span>
<span id="cb17-28"><a href="#cb17-28" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-29"><a href="#cb17-29" aria-hidden="true" tabindex="-1"></a>chart</span></code></pre></div></div>
</div>
</section>
</section>
<section id="performance-comparison" class="level2">
<h2 class="anchored" data-anchor-id="performance-comparison" id="performance-comparison">Performance Comparison</h2>
<p>Performance for Matplotlib, Seaborn, and Altair varies widely with the size of your dataset and the complexity of your visualization. The following is a rough, qualitative overview rather than a benchmark:</p>
<section id="small-datasets-1000-points" class="level3">
<h3 class="anchored" data-anchor-id="small-datasets-1000-points" id="small-datasets-1000-points">Small Datasets (&lt; 1,000 points):</h3>
<ul>
<li>All three libraries perform well</li>
<li>Altair might have slightly more overhead due to its JSON specification generation</li>
</ul>
</section>
<section id="medium-datasets-1000---10000-points" class="level3">
<h3 class="anchored" data-anchor-id="medium-datasets-1000---10000-points" id="medium-datasets-1000---10000-points">Medium Datasets (1,000 - 10,000 points):</h3>
<ul>
<li>Matplotlib and Seaborn continue to perform well</li>
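<li>Altair starts to slow down but remains usable. Note that by default it raises a <code>MaxRowsError</code> for datasets with more than 5,000 rows. You can lift this limit with <code>alt.data_transformers.disable_max_rows()</code>, but the chart specifications you embed will be larger.</li>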
</ul>
</section>
<section id="large-datasets-10000-points" class="level3">
<h3 class="anchored" data-anchor-id="large-datasets-10000-points" id="large-datasets-10000-points">Large Datasets (&gt; 10,000 points):</h3>
<ul>
<li>Matplotlib performs best for large static visualizations</li>
<li>Seaborn becomes noticeably slower, especially for plots that compute statistics such as bootstrapped confidence intervals or kernel density estimates</li>
<li>Altair significantly slows down and may require data aggregation</li>
</ul>
</section>
<section id="recommended-approaches-for-large-data" class="level3">
<h3 class="anchored" data-anchor-id="recommended-approaches-for-large-data" id="recommended-approaches-for-large-data">Recommended Approaches for Large Data:</h3>
<ol type="1">
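<li><strong>Matplotlib</strong>: For many points, use <code>plot()</code> with markers and no connecting line (e.g. <code>plot(x, y, 'o')</code>) instead of <code>scatter()</code> when you don’t need per-point sizes or colors. For density plots, try <code>hexbin()</code>.</li>
<li><strong>Seaborn</strong>: Downsample with pandas’ <code>DataFrame.sample()</code> or aggregate the data before plotting</li>
<li><strong>Altair</strong>: Use <code>transform_sample()</code> or pre-aggregate your data in pandas. With Altair’s default settings, <code>transform_sample()</code> runs in the browser, so the full dataset is still embedded in the chart specification.</li>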
</ol>
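<p>All three approaches share one idea: shrink the data before it reaches the plotting library. The following is a minimal pandas sketch that assumes a hypothetical dataset of 200,000 per-minute readings. The column names <code>timestamp</code>, <code>value</code>, and <code>category</code> are illustrative and do not come from the examples above. The sketch shows two options: random downsampling and time-based aggregation with <code>pd.Grouper</code>.</p>

```python
import numpy as np
import pandas as pd

# Hypothetical large dataset: one reading per minute (~139 days)
rng = np.random.default_rng(0)
n = 200_000
df_large = pd.DataFrame({
    "timestamp": pd.date_range("2023-01-01", periods=n, freq="min"),
    "value": rng.standard_normal(n).cumsum(),
    "category": rng.choice(["Group A", "Group B"], size=n),
})

# Option 1: random downsampling (works with any plotting library)
sampled = df_large.sample(n=5_000, random_state=0)

# Option 2: aggregate to daily means per category before plotting
daily = (
    df_large
    .groupby([pd.Grouper(key="timestamp", freq="D"), "category"])["value"]
    .mean()
    .reset_index()
)

print(len(sampled), len(daily))
```

<p>You can pass either reduced frame to Matplotlib, Seaborn, or <code>alt.Chart()</code>. The daily aggregate has only a few hundred rows, well under Altair’s default 5,000-row limit.</p>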
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>The Python visualization ecosystem offers tools for every need, from low-level control to high-level abstraction:</p>
<ul>
<li><strong>Matplotlib</strong> provides the most flexibility and control, but requires more code and deeper knowledge of its API</li>
<li><strong>Seaborn</strong> offers a practical middle ground, with statistical integration and clean defaults</li>
<li><strong>Altair</strong> delivers a concise, declarative approach with built-in interactivity</li>
</ul>
<p>Rather than picking just one library, consider becoming familiar with all three and selecting the right tool for each visualization task. Many data scientists use a combination of these libraries, leveraging the strengths of each one as needed.</p>
<p>For those just starting, Seaborn provides a gentle entry point with attractive results for common visualization needs. As your skills advance, you can incorporate Matplotlib for customization and Altair for interactive visualizations.</p>


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[From Pandas to Polars]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/pandas-to-polars/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/pandas-to-polars/</guid>
      <pubDate>Sat, 05 Apr 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>tutorial</category>
      <category>advanced</category>
      <content:encoded><![CDATA[






<section id="from-pandas-to-polars" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/pandas-to-polars/pvp.jpg" class="img-fluid"></p>
<p>As datasets grow in size and complexity, performance and efficiency become critical in data processing. While Pandas has long been the go-to library for data manipulation in Python, it can struggle with speed and memory usage, especially on large datasets. Polars, a newer DataFrame library written in Rust, offers a faster, more memory-efficient alternative with support for lazy evaluation and multi-threading.</p>
<p>This guide walks through how common Pandas operations translate to Polars and highlights key differences in syntax, performance, and functionality. Whether you want to speed up your data workflows or are simply exploring modern tools, understanding how to move from Pandas to Polars is a valuable step.</p>
<section id="installation-and-setup" class="level2">
<h2 class="anchored" data-anchor-id="installation-and-setup" id="installation-and-setup">Installation and Setup</h2>
<section id="pandas" class="level3">
<h3 class="anchored" data-anchor-id="pandas" id="pandas">Pandas</h3>
<div id="3de22b60" class="cell" data-execution_count="1">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb1"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Import pandas</span></span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> pandas <span class="im">as</span> pd</span></code></pre></div></div>
</div>
</section>
<section id="polars" class="level3">
<h3 class="anchored" data-anchor-id="polars" id="polars">Polars</h3>
<div id="d6c40551" class="cell" data-execution_count="2">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb2"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Import polars</span></span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> polars <span class="im">as</span> pl</span></code></pre></div></div>
</div>
</section>
</section>
<section id="creating-dataframes" class="level2">
<h2 class="anchored" data-anchor-id="creating-dataframes" id="creating-dataframes">Creating DataFrames</h2>
<section id="from-dictionaries" class="level3">
<h3 class="anchored" data-anchor-id="from-dictionaries" id="from-dictionaries">From dictionaries</h3>
<section id="pandas-1" class="level4">
<h4 class="anchored" data-anchor-id="pandas-1">Pandas</h4>
<div id="76669000" class="cell" data-execution_count="3">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> pandas <span class="im">as</span> pd</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Create DataFrame from dictionary</span></span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a>data <span class="op">=</span> {</span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a>    <span class="st">'name'</span>: [<span class="st">'Alice'</span>, <span class="st">'Bob'</span>, <span class="st">'Charlie'</span>, <span class="st">'David'</span>],</span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a>    <span class="st">'age'</span>: [<span class="dv">25</span>, <span class="dv">30</span>, <span class="dv">35</span>, <span class="dv">40</span>],</span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a>    <span class="st">'city'</span>: [<span class="st">'New York'</span>, <span class="st">'Los Angeles'</span>, <span class="st">'Chicago'</span>, <span class="st">'Houston'</span>]</span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a>}</span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a>df_pd <span class="op">=</span> pd.DataFrame(data)</span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(df_pd)</span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>      name  age         city
0    Alice   25     New York
1      Bob   30  Los Angeles
2  Charlie   35      Chicago
3    David   40      Houston</code></pre>
</div>
</div>
</section>
<section id="polars-1" class="level4">
<h4 class="anchored" data-anchor-id="polars-1">Polars</h4>
<div id="d64ed3f4" class="cell" data-execution_count="4">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> polars <span class="im">as</span> pl</span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Create DataFrame from dictionary</span></span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a>data <span class="op">=</span> {</span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a>    <span class="st">'name'</span>: [<span class="st">'Alice'</span>, <span class="st">'Bob'</span>, <span class="st">'Charlie'</span>, <span class="st">'David'</span>],</span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a>    <span class="st">'age'</span>: [<span class="dv">25</span>, <span class="dv">30</span>, <span class="dv">35</span>, <span class="dv">40</span>],</span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a>    <span class="st">'city'</span>: [<span class="st">'New York'</span>, <span class="st">'Los Angeles'</span>, <span class="st">'Chicago'</span>, <span class="st">'Houston'</span>]</span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a>}</span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a>df_pl <span class="op">=</span> pl.DataFrame(data)</span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(df_pl)</span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>shape: (4, 3)
┌─────────┬─────┬─────────────┐
│ name    ┆ age ┆ city        │
│ ---     ┆ --- ┆ ---         │
│ str     ┆ i64 ┆ str         │
╞═════════╪═════╪═════════════╡
│ Alice   ┆ 25  ┆ New York    │
│ Bob     ┆ 30  ┆ Los Angeles │
│ Charlie ┆ 35  ┆ Chicago     │
│ David   ┆ 40  ┆ Houston     │
└─────────┴─────┴─────────────┘</code></pre>
</div>
</div>
</section>
</section>
</section>
<section id="basic-operations" class="level2">
<h2 class="anchored" data-anchor-id="basic-operations" id="basic-operations">Basic Operations</h2>
<section id="selecting-columns" class="level3">
<h3 class="anchored" data-anchor-id="selecting-columns" id="selecting-columns">Selecting columns</h3>
<section id="pandas-2" class="level4">
<h4 class="anchored" data-anchor-id="pandas-2">Pandas</h4>
<div id="b6735410" class="cell" data-execution_count="5">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Select a single column (returns Series)</span></span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a>series <span class="op">=</span> df_pd[<span class="st">'name'</span>]</span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Select multiple columns</span></span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a>df_subset <span class="op">=</span> df_pd[[<span class="st">'name'</span>, <span class="st">'age'</span>]]</span></code></pre></div></div>
</div>
</section>
<section id="polars-2" class="level4">
<h4 class="anchored" data-anchor-id="polars-2">Polars</h4>
<div id="525d9c36" class="cell" data-execution_count="6">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb8"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Select a single column (returns Series)</span></span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a>series <span class="op">=</span> df_pl[<span class="st">'name'</span>]</span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Alternative method</span></span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a>series <span class="op">=</span> df_pl.select(pl.col(<span class="st">'name'</span>)).to_series()</span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Select multiple columns</span></span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a>df_subset <span class="op">=</span> df_pl.select([<span class="st">'name'</span>, <span class="st">'age'</span>])</span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a><span class="co"># Alternative method</span></span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a>df_subset <span class="op">=</span> df_pl.select(pl.col([<span class="st">'name'</span>, <span class="st">'age'</span>]))</span></code></pre></div></div>
</div>
</section>
</section>
<section id="adding-a-new-column" class="level3">
<h3 class="anchored" data-anchor-id="adding-a-new-column" id="adding-a-new-column">Adding a new column</h3>
<section id="pandas-3" class="level4">
<h4 class="anchored" data-anchor-id="pandas-3">Pandas</h4>
<div id="42c1c1d2" class="cell" data-execution_count="7">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb9"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Add a new column</span></span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a>df_pd[<span class="st">'is_adult'</span>] <span class="op">=</span> df_pd[<span class="st">'age'</span>] <span class="op">&gt;=</span> <span class="dv">18</span></span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Using assign (creates a new DataFrame)</span></span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a>df_pd <span class="op">=</span> df_pd.assign(age_squared<span class="op">=</span>df_pd[<span class="st">'age'</span>] <span class="op">**</span> <span class="dv">2</span>)</span></code></pre></div></div>
</div>
</section>
<section id="polars-3" class="level4">
<h4 class="anchored" data-anchor-id="polars-3">Polars</h4>
<div id="9dfee089" class="cell" data-execution_count="8">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb10"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Add a new column</span></span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a>df_pl <span class="op">=</span> df_pl.with_columns(</span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a>    pl.when(pl.col(<span class="st">'age'</span>) <span class="op">&gt;=</span> <span class="dv">18</span>).then(<span class="va">True</span>).otherwise(<span class="va">False</span>).alias(<span class="st">'is_adult'</span>)</span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-6"><a href="#cb10-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Creating derived columns </span></span>
<span id="cb10-7"><a href="#cb10-7" aria-hidden="true" tabindex="-1"></a>df_pl <span class="op">=</span> df_pl.with_columns(</span>
<span id="cb10-8"><a href="#cb10-8" aria-hidden="true" tabindex="-1"></a>    (pl.col(<span class="st">'age'</span>) <span class="op">**</span> <span class="dv">2</span>).alias(<span class="st">'age_squared'</span>)</span>
<span id="cb10-9"><a href="#cb10-9" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb10-10"><a href="#cb10-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb10-11"><a href="#cb10-11" aria-hidden="true" tabindex="-1"></a><span class="co"># Multiple columns at once</span></span>
<span id="cb10-12"><a href="#cb10-12" aria-hidden="true" tabindex="-1"></a>df_pl <span class="op">=</span> df_pl.with_columns([</span>
<span id="cb10-13"><a href="#cb10-13" aria-hidden="true" tabindex="-1"></a>    pl.col(<span class="st">'age'</span>).is_null().alias(<span class="st">'age_is_null'</span>),</span>
<span id="cb10-14"><a href="#cb10-14" aria-hidden="true" tabindex="-1"></a>    (pl.col(<span class="st">'age'</span>) <span class="op">*</span> <span class="dv">2</span>).alias(<span class="st">'age_doubled'</span>)</span>
<span id="cb10-15"><a href="#cb10-15" aria-hidden="true" tabindex="-1"></a>])</span></code></pre></div></div>
</div>
</section>
</section>
<section id="basic-statistics" class="level3">
<h3 class="anchored" data-anchor-id="basic-statistics" id="basic-statistics">Basic statistics</h3>
<section id="pandas-4" class="level4">
<h4 class="anchored" data-anchor-id="pandas-4">Pandas</h4>
<div id="1f7b4b16" class="cell" data-execution_count="9">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb11"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Get summary statistics</span></span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a>summary <span class="op">=</span> df_pd.describe()</span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Individual statistics</span></span>
<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a>mean_age <span class="op">=</span> df_pd[<span class="st">'age'</span>].mean()</span>
<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a>median_age <span class="op">=</span> df_pd[<span class="st">'age'</span>].median()</span>
<span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a>min_age <span class="op">=</span> df_pd[<span class="st">'age'</span>].<span class="bu">min</span>()</span>
<span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a>max_age <span class="op">=</span> df_pd[<span class="st">'age'</span>].<span class="bu">max</span>()</span></code></pre></div></div>
</div>
</section>
<section id="polars-4" class="level4">
<h4 class="anchored" data-anchor-id="polars-4">Polars</h4>
<div id="f5721d75" class="cell" data-execution_count="10">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb12"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><a href="#cb12-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Get summary statistics</span></span>
<span id="cb12-2"><a href="#cb12-2" aria-hidden="true" tabindex="-1"></a>summary <span class="op">=</span> df_pl.describe()</span>
<span id="cb12-3"><a href="#cb12-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-4"><a href="#cb12-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Individual statistics</span></span>
<span id="cb12-5"><a href="#cb12-5" aria-hidden="true" tabindex="-1"></a>mean_age <span class="op">=</span> df_pl.select(pl.col(<span class="st">'age'</span>).mean()).item()</span>
<span id="cb12-6"><a href="#cb12-6" aria-hidden="true" tabindex="-1"></a>median_age <span class="op">=</span> df_pl.select(pl.col(<span class="st">'age'</span>).median()).item()</span>
<span id="cb12-7"><a href="#cb12-7" aria-hidden="true" tabindex="-1"></a>min_age <span class="op">=</span> df_pl.select(pl.col(<span class="st">'age'</span>).<span class="bu">min</span>()).item()</span>
<span id="cb12-8"><a href="#cb12-8" aria-hidden="true" tabindex="-1"></a>max_age <span class="op">=</span> df_pl.select(pl.col(<span class="st">'age'</span>).<span class="bu">max</span>()).item()</span></code></pre></div></div>
</div>
</section>
</section>
</section>
<section id="filtering-data" class="level2">
<h2 class="anchored" data-anchor-id="filtering-data" id="filtering-data">Filtering Data</h2>
<section id="simple-filtering" class="level3">
<h3 class="anchored" data-anchor-id="simple-filtering" id="simple-filtering">Simple filtering</h3>
<section id="pandas-5" class="level4">
<h4 class="anchored" data-anchor-id="pandas-5">Pandas</h4>
<div id="d73bd2e5" class="cell" data-execution_count="11">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb13"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Filter rows</span></span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a>adults <span class="op">=</span> df_pd[df_pd[<span class="st">'age'</span>] <span class="op">&gt;=</span> <span class="dv">18</span>]</span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-4"><a href="#cb13-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Multiple conditions</span></span>
<span id="cb13-5"><a href="#cb13-5" aria-hidden="true" tabindex="-1"></a>filtered <span class="op">=</span> df_pd[(df_pd[<span class="st">'age'</span>] <span class="op">&gt;</span> <span class="dv">30</span>) <span class="op">&amp;</span> (df_pd[<span class="st">'city'</span>] <span class="op">==</span> <span class="st">'Chicago'</span>)]</span></code></pre></div></div>
</div>
</section>
<section id="polars-5" class="level4">
<h4 class="anchored" data-anchor-id="polars-5">Polars</h4>
<div id="e8f63ea9" class="cell" data-execution_count="12">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb14"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb14-1"><a href="#cb14-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Filter rows</span></span>
<span id="cb14-2"><a href="#cb14-2" aria-hidden="true" tabindex="-1"></a>adults <span class="op">=</span> df_pl.<span class="bu">filter</span>(pl.col(<span class="st">'age'</span>) <span class="op">&gt;=</span> <span class="dv">18</span>)</span>
<span id="cb14-3"><a href="#cb14-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-4"><a href="#cb14-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Multiple conditions</span></span>
<span id="cb14-5"><a href="#cb14-5" aria-hidden="true" tabindex="-1"></a>filtered <span class="op">=</span> df_pl.<span class="bu">filter</span>((pl.col(<span class="st">'age'</span>) <span class="op">&gt;</span> <span class="dv">30</span>) <span class="op">&amp;</span> (pl.col(<span class="st">'city'</span>) <span class="op">==</span> <span class="st">'Chicago'</span>))</span></code></pre></div></div>
</div>
</section>
</section>
<section id="complex-filtering" class="level3">
<h3 class="anchored" data-anchor-id="complex-filtering" id="complex-filtering">Complex filtering</h3>
<section id="pandas-6" class="level4">
<h4 class="anchored" data-anchor-id="pandas-6">Pandas</h4>
<div id="7edfef2b" class="cell" data-execution_count="13">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb15"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb15-1"><a href="#cb15-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Filter with OR conditions</span></span>
<span id="cb15-2"><a href="#cb15-2" aria-hidden="true" tabindex="-1"></a>df_filtered <span class="op">=</span> df_pd[(df_pd[<span class="st">'city'</span>] <span class="op">==</span> <span class="st">'New York'</span>) <span class="op">|</span> (df_pd[<span class="st">'city'</span>] <span class="op">==</span> <span class="st">'Chicago'</span>)]</span>
<span id="cb15-3"><a href="#cb15-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-4"><a href="#cb15-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Using isin</span></span>
<span id="cb15-5"><a href="#cb15-5" aria-hidden="true" tabindex="-1"></a>cities <span class="op">=</span> [<span class="st">'New York'</span>, <span class="st">'Chicago'</span>]</span>
<span id="cb15-6"><a href="#cb15-6" aria-hidden="true" tabindex="-1"></a>df_filtered <span class="op">=</span> df_pd[df_pd[<span class="st">'city'</span>].isin(cities)]</span>
<span id="cb15-7"><a href="#cb15-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-8"><a href="#cb15-8" aria-hidden="true" tabindex="-1"></a><span class="co"># String contains</span></span>
<span id="cb15-9"><a href="#cb15-9" aria-hidden="true" tabindex="-1"></a>df_filtered <span class="op">=</span> df_pd[df_pd[<span class="st">'name'</span>].<span class="bu">str</span>.contains(<span class="st">'li'</span>)]</span></code></pre></div></div>
</div>
</section>
<section id="polars-6" class="level4">
<h4 class="anchored" data-anchor-id="polars-6">Polars</h4>
<div id="30dfa1f6" class="cell" data-execution_count="14">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb16"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb16-1"><a href="#cb16-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Filter with OR conditions</span></span>
<span id="cb16-2"><a href="#cb16-2" aria-hidden="true" tabindex="-1"></a>df_filtered <span class="op">=</span> df_pl.<span class="bu">filter</span>((pl.col(<span class="st">'city'</span>) <span class="op">==</span> <span class="st">'New York'</span>) <span class="op">|</span> (pl.col(<span class="st">'city'</span>) <span class="op">==</span> <span class="st">'Chicago'</span>))</span>
<span id="cb16-3"><a href="#cb16-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-4"><a href="#cb16-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Using is_in</span></span>
<span id="cb16-5"><a href="#cb16-5" aria-hidden="true" tabindex="-1"></a>cities <span class="op">=</span> [<span class="st">'New York'</span>, <span class="st">'Chicago'</span>]</span>
<span id="cb16-6"><a href="#cb16-6" aria-hidden="true" tabindex="-1"></a>df_filtered <span class="op">=</span> df_pl.<span class="bu">filter</span>(pl.col(<span class="st">'city'</span>).is_in(cities))</span>
<span id="cb16-7"><a href="#cb16-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb16-8"><a href="#cb16-8" aria-hidden="true" tabindex="-1"></a><span class="co"># String contains</span></span>
<span id="cb16-9"><a href="#cb16-9" aria-hidden="true" tabindex="-1"></a>df_filtered <span class="op">=</span> df_pl.<span class="bu">filter</span>(pl.col(<span class="st">'name'</span>).<span class="bu">str</span>.contains(<span class="st">'li'</span>))</span></code></pre></div></div>
</div>
</section>
</section>
</section>
<section id="grouping-and-aggregation" class="level2">
<h2 class="anchored" data-anchor-id="grouping-and-aggregation" id="grouping-and-aggregation">Grouping and Aggregation</h2>
<section id="basic-groupby" class="level3">
<h3 class="anchored" data-anchor-id="basic-groupby" id="basic-groupby">Basic groupby</h3>
<section id="pandas-7" class="level4">
<h4 class="anchored" data-anchor-id="pandas-7">Pandas</h4>
<div id="3edc9d87" class="cell" data-execution_count="15">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb17"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb17-1"><a href="#cb17-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Group by one column and aggregate</span></span>
<span id="cb17-2"><a href="#cb17-2" aria-hidden="true" tabindex="-1"></a>city_stats <span class="op">=</span> df_pd.groupby(<span class="st">'city'</span>).agg({</span>
<span id="cb17-3"><a href="#cb17-3" aria-hidden="true" tabindex="-1"></a>    <span class="st">'age'</span>: [<span class="st">'mean'</span>, <span class="st">'min'</span>, <span class="st">'max'</span>, <span class="st">'count'</span>]</span>
<span id="cb17-4"><a href="#cb17-4" aria-hidden="true" tabindex="-1"></a>})</span>
<span id="cb17-5"><a href="#cb17-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-6"><a href="#cb17-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Reset index for flat DataFrame</span></span>
<span id="cb17-7"><a href="#cb17-7" aria-hidden="true" tabindex="-1"></a>city_stats <span class="op">=</span> city_stats.reset_index()</span></code></pre></div></div>
</div>
</section>
<section id="polars-7" class="level4">
<h4 class="anchored" data-anchor-id="polars-7">Polars</h4>
<div id="06ad6d3a" class="cell" data-execution_count="16">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb18"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb18-1"><a href="#cb18-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Group by one column and aggregate</span></span>
<span id="cb18-2"><a href="#cb18-2" aria-hidden="true" tabindex="-1"></a>city_stats <span class="op">=</span> df_pl.group_by(<span class="st">'city'</span>).agg([</span>
<span id="cb18-3"><a href="#cb18-3" aria-hidden="true" tabindex="-1"></a>    pl.col(<span class="st">'age'</span>).mean().alias(<span class="st">'age_mean'</span>),</span>
<span id="cb18-4"><a href="#cb18-4" aria-hidden="true" tabindex="-1"></a>    pl.col(<span class="st">'age'</span>).<span class="bu">min</span>().alias(<span class="st">'age_min'</span>),</span>
<span id="cb18-5"><a href="#cb18-5" aria-hidden="true" tabindex="-1"></a>    pl.col(<span class="st">'age'</span>).<span class="bu">max</span>().alias(<span class="st">'age_max'</span>),</span>
<span id="cb18-6"><a href="#cb18-6" aria-hidden="true" tabindex="-1"></a>    pl.col(<span class="st">'age'</span>).count().alias(<span class="st">'age_count'</span>)</span>
<span id="cb18-7"><a href="#cb18-7" aria-hidden="true" tabindex="-1"></a>])</span></code></pre></div></div>
</div>
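<p>One difference the two snippets above do not show is row order. pandas sorts the group keys by default (<code>sort=True</code>), while Polars' <code>group_by</code> makes no ordering guarantee unless you pass <code>maintain_order=True</code>, which keeps keys in order of first appearance. The sketch below uses a small hypothetical <code>data</code> dictionary rather than the article's <code>df_pd</code>/<code>df_pl</code>, and sorts the Polars result explicitly so both libraries return rows in the same order.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import pandas as pd
import polars as pl

# Hypothetical sample data, not the article's df_pd/df_pl
data = {'city': ['Paris', 'London', 'Paris', 'Berlin'], 'age': [30, 25, 40, 35]}

# pandas: group keys come back sorted by default
pd_stats = pd.DataFrame(data).groupby('city', as_index=False)['age'].mean()

# Polars: group order is not guaranteed, so sort explicitly for a stable result
pl_stats = pl.DataFrame(data).group_by('city').agg(pl.col('age').mean()).sort('city')

# Polars: maintain_order=True keeps keys in order of first appearance instead
pl_first_seen = (
    pl.DataFrame(data)
    .group_by('city', maintain_order=True)
    .agg(pl.col('age').mean())
)
</code></pre></div>
<p>Sorting explicitly is usually the safer choice when results are compared across libraries or written to disk, since it does not depend on either library's default.</p>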
</section>
</section>
</section>
<section id="joiningmerging-dataframes" class="level2">
<h2 class="anchored" data-anchor-id="joiningmerging-dataframes" id="joiningmerging-dataframes">Joining/Merging DataFrames</h2>
<section id="inner-join" class="level3">
<h3 class="anchored" data-anchor-id="inner-join" id="inner-join">Inner Join</h3>
<section id="pandas-8" class="level4">
<h4 class="anchored" data-anchor-id="pandas-8">Pandas</h4>
<div id="947ac82e" class="cell" data-execution_count="17">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb19"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb19-1"><a href="#cb19-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Create two DataFrames to join on emp_id</span></span>
<span id="cb19-2"><a href="#cb19-2" aria-hidden="true" tabindex="-1"></a>employee_data <span class="op">=</span> {</span>
<span id="cb19-3"><a href="#cb19-3" aria-hidden="true" tabindex="-1"></a>    <span class="st">'emp_id'</span>: [<span class="dv">1</span>, <span class="dv">2</span>, <span class="dv">3</span>, <span class="dv">4</span>],</span>
<span id="cb19-4"><a href="#cb19-4" aria-hidden="true" tabindex="-1"></a>    <span class="st">'name'</span>: [<span class="st">'Alice'</span>, <span class="st">'Bob'</span>, <span class="st">'Charlie'</span>, <span class="st">'David'</span>],</span>
<span id="cb19-5"><a href="#cb19-5" aria-hidden="true" tabindex="-1"></a>    <span class="st">'dept'</span>: [<span class="st">'HR'</span>, <span class="st">'IT'</span>, <span class="st">'Finance'</span>, <span class="st">'IT'</span>]</span>
<span id="cb19-6"><a href="#cb19-6" aria-hidden="true" tabindex="-1"></a>}</span>
<span id="cb19-7"><a href="#cb19-7" aria-hidden="true" tabindex="-1"></a>employee_df_pd <span class="op">=</span> pd.DataFrame(employee_data)</span>
<span id="cb19-8"><a href="#cb19-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-9"><a href="#cb19-9" aria-hidden="true" tabindex="-1"></a>salary_data <span class="op">=</span> {</span>
<span id="cb19-10"><a href="#cb19-10" aria-hidden="true" tabindex="-1"></a>    <span class="st">'emp_id'</span>: [<span class="dv">1</span>, <span class="dv">2</span>, <span class="dv">3</span>, <span class="dv">5</span>],</span>
<span id="cb19-11"><a href="#cb19-11" aria-hidden="true" tabindex="-1"></a>    <span class="st">'salary'</span>: [<span class="dv">50000</span>, <span class="dv">60000</span>, <span class="dv">70000</span>, <span class="dv">80000</span>]</span>
<span id="cb19-12"><a href="#cb19-12" aria-hidden="true" tabindex="-1"></a>}</span>
<span id="cb19-13"><a href="#cb19-13" aria-hidden="true" tabindex="-1"></a>salary_df_pd <span class="op">=</span> pd.DataFrame(salary_data)</span>
<span id="cb19-14"><a href="#cb19-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb19-15"><a href="#cb19-15" aria-hidden="true" tabindex="-1"></a><span class="co"># Inner join</span></span>
<span id="cb19-16"><a href="#cb19-16" aria-hidden="true" tabindex="-1"></a>merged_df <span class="op">=</span> employee_df_pd.merge(</span>
<span id="cb19-17"><a href="#cb19-17" aria-hidden="true" tabindex="-1"></a>    salary_df_pd,</span>
<span id="cb19-18"><a href="#cb19-18" aria-hidden="true" tabindex="-1"></a>    on<span class="op">=</span><span class="st">'emp_id'</span>,</span>
<span id="cb19-19"><a href="#cb19-19" aria-hidden="true" tabindex="-1"></a>    how<span class="op">=</span><span class="st">'inner'</span></span>
<span id="cb19-20"><a href="#cb19-20" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
</div>
</section>
<section id="polars-8" class="level4">
<h4 class="anchored" data-anchor-id="polars-8">Polars</h4>
<div id="fd4d6fa1" class="cell" data-execution_count="18">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb20"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb20-1"><a href="#cb20-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Create two DataFrames to join on emp_id</span></span>
<span id="cb20-2"><a href="#cb20-2" aria-hidden="true" tabindex="-1"></a>employee_data <span class="op">=</span> {</span>
<span id="cb20-3"><a href="#cb20-3" aria-hidden="true" tabindex="-1"></a>    <span class="st">'emp_id'</span>: [<span class="dv">1</span>, <span class="dv">2</span>, <span class="dv">3</span>, <span class="dv">4</span>],</span>
<span id="cb20-4"><a href="#cb20-4" aria-hidden="true" tabindex="-1"></a>    <span class="st">'name'</span>: [<span class="st">'Alice'</span>, <span class="st">'Bob'</span>, <span class="st">'Charlie'</span>, <span class="st">'David'</span>],</span>
<span id="cb20-5"><a href="#cb20-5" aria-hidden="true" tabindex="-1"></a>    <span class="st">'dept'</span>: [<span class="st">'HR'</span>, <span class="st">'IT'</span>, <span class="st">'Finance'</span>, <span class="st">'IT'</span>]</span>
<span id="cb20-6"><a href="#cb20-6" aria-hidden="true" tabindex="-1"></a>}</span>
<span id="cb20-7"><a href="#cb20-7" aria-hidden="true" tabindex="-1"></a>employee_df_pl <span class="op">=</span> pl.DataFrame(employee_data)</span>
<span id="cb20-8"><a href="#cb20-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb20-9"><a href="#cb20-9" aria-hidden="true" tabindex="-1"></a>salary_data <span class="op">=</span> {</span>
<span id="cb20-10"><a href="#cb20-10" aria-hidden="true" tabindex="-1"></a>    <span class="st">'emp_id'</span>: [<span class="dv">1</span>, <span class="dv">2</span>, <span class="dv">3</span>, <span class="dv">5</span>],</span>
<span id="cb20-11"><a href="#cb20-11" aria-hidden="true" tabindex="-1"></a>    <span class="st">'salary'</span>: [<span class="dv">50000</span>, <span class="dv">60000</span>, <span class="dv">70000</span>, <span class="dv">80000</span>]</span>
<span id="cb20-12"><a href="#cb20-12" aria-hidden="true" tabindex="-1"></a>}</span>
<span id="cb20-13"><a href="#cb20-13" aria-hidden="true" tabindex="-1"></a>salary_df_pl <span class="op">=</span> pl.DataFrame(salary_data)</span>
<span id="cb20-14"><a href="#cb20-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb20-15"><a href="#cb20-15" aria-hidden="true" tabindex="-1"></a><span class="co"># Inner join</span></span>
<span id="cb20-16"><a href="#cb20-16" aria-hidden="true" tabindex="-1"></a>merged_df <span class="op">=</span> employee_df_pl.join(</span>
<span id="cb20-17"><a href="#cb20-17" aria-hidden="true" tabindex="-1"></a>    salary_df_pl,</span>
<span id="cb20-18"><a href="#cb20-18" aria-hidden="true" tabindex="-1"></a>    on<span class="op">=</span><span class="st">'emp_id'</span>,</span>
<span id="cb20-19"><a href="#cb20-19" aria-hidden="true" tabindex="-1"></a>    how<span class="op">=</span><span class="st">'inner'</span></span>
<span id="cb20-20"><a href="#cb20-20" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
</div>
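<p>When both tables share a non-key column, the two libraries name the clashing columns differently. pandas <code>merge</code> appends <code>_x</code>/<code>_y</code> to both by default (the <code>suffixes</code> parameter), while Polars <code>join</code> leaves the left column unchanged and appends <code>_right</code> to the right one (the <code>suffix</code> parameter). The minimal sketch below assumes two hypothetical tables that both carry a <code>name</code> column.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import pandas as pd
import polars as pl

# Hypothetical tables that both contain a 'name' column besides the key
left = {'emp_id': [1, 2], 'name': ['Alice', 'Bob']}
right = {'emp_id': [1, 2], 'name': ['A. Smith', 'B. Jones']}

# pandas: default suffixes=('_x', '_y') are applied to both clashing columns
merged_pd = pd.DataFrame(left).merge(pd.DataFrame(right), on='emp_id', how='inner')

# Polars: default suffix='_right' is applied only to the right-hand column
merged_pl = pl.DataFrame(left).join(pl.DataFrame(right), on='emp_id', how='inner')
</code></pre></div>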
</section>
</section>
<section id="different-join-types" class="level3">
<h3 class="anchored" data-anchor-id="different-join-types" id="different-join-types">Different join types</h3>
<section id="pandas-9" class="level4">
<h4 class="anchored" data-anchor-id="pandas-9">Pandas</h4>
<div id="91a88ca5" class="cell" data-execution_count="19">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb21"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb21-1"><a href="#cb21-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Left join</span></span>
<span id="cb21-2"><a href="#cb21-2" aria-hidden="true" tabindex="-1"></a>left_join <span class="op">=</span> employee_df_pd.merge(salary_df_pd, on<span class="op">=</span><span class="st">'emp_id'</span>, how<span class="op">=</span><span class="st">'left'</span>)</span>
<span id="cb21-3"><a href="#cb21-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-4"><a href="#cb21-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Right join</span></span>
<span id="cb21-5"><a href="#cb21-5" aria-hidden="true" tabindex="-1"></a>right_join <span class="op">=</span> employee_df_pd.merge(salary_df_pd, on<span class="op">=</span><span class="st">'emp_id'</span>, how<span class="op">=</span><span class="st">'right'</span>)</span>
<span id="cb21-6"><a href="#cb21-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb21-7"><a href="#cb21-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Outer join</span></span>
<span id="cb21-8"><a href="#cb21-8" aria-hidden="true" tabindex="-1"></a>outer_join <span class="op">=</span> employee_df_pd.merge(salary_df_pd, on<span class="op">=</span><span class="st">'emp_id'</span>, how<span class="op">=</span><span class="st">'outer'</span>)</span></code></pre></div></div>
</div>
</section>
<section id="polars-9" class="level4">
<h4 class="anchored" data-anchor-id="polars-9">Polars</h4>
<div id="d49c9638" class="cell" data-execution_count="20">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb22"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb22-1"><a href="#cb22-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Left join</span></span>
<span id="cb22-2"><a href="#cb22-2" aria-hidden="true" tabindex="-1"></a>left_join <span class="op">=</span> employee_df_pl.join(salary_df_pl, on<span class="op">=</span><span class="st">'emp_id'</span>, how<span class="op">=</span><span class="st">'left'</span>)</span>
<span id="cb22-3"><a href="#cb22-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb22-4"><a href="#cb22-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Right join</span></span>
<span id="cb22-5"><a href="#cb22-5" aria-hidden="true" tabindex="-1"></a>right_join <span class="op">=</span> employee_df_pl.join(salary_df_pl, on<span class="op">=</span><span class="st">'emp_id'</span>, how<span class="op">=</span><span class="st">'right'</span>)</span>
<span id="cb22-6"><a href="#cb22-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb22-7"><a href="#cb22-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Outer join</span></span>
<span id="cb22-8"><a href="#cb22-8" aria-hidden="true" tabindex="-1"></a>outer_join <span class="op">=</span> employee_df_pl.join(salary_df_pl, on<span class="op">=</span><span class="st">'emp_id'</span>, how<span class="op">=</span><span class="st">'full'</span>)</span></code></pre></div></div>
</div>
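<p>The join calls look alike, but the results differ in two ways worth checking. In a left join, pandas has to upcast the integer <code>salary</code> column to <code>float64</code> to hold <code>NaN</code> for the unmatched employee, while Polars keeps <code>Int64</code> and stores a null. In a full join, recent Polars releases (1.x) keep both key columns (<code>emp_id</code> and <code>emp_id_right</code>) unless you pass <code>coalesce=True</code>; pandas always merges them into one. The sketch below rebuilds reduced versions of the tables above; verify the <code>coalesce</code> behaviour against the Polars version you use.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import pandas as pd
import polars as pl

# Reduced versions of the employee and salary tables used above
employees = {'emp_id': [1, 2, 3, 4], 'name': ['Alice', 'Bob', 'Charlie', 'David']}
salaries = {'emp_id': [1, 2, 3, 5], 'salary': [50000, 60000, 70000, 80000]}

# Left join: pandas upcasts salary to float64 to hold NaN for emp_id 4,
# while Polars keeps Int64 and marks the missing salary as null
left_pd = pd.DataFrame(employees).merge(pd.DataFrame(salaries), on='emp_id', how='left')
left_pl = pl.DataFrame(employees).join(pl.DataFrame(salaries), on='emp_id', how='left')

# Full join: by default Polars keeps both key columns (emp_id, emp_id_right);
# coalesce=True merges them into a single emp_id column, as pandas does
full_pl = pl.DataFrame(employees).join(pl.DataFrame(salaries), on='emp_id', how='full')
full_pl_coalesced = pl.DataFrame(employees).join(
    pl.DataFrame(salaries), on='emp_id', how='full', coalesce=True
)
</code></pre></div>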
</section>
</section>
</section>
<section id="handling-missing-values" class="level2">
<h2 class="anchored" data-anchor-id="handling-missing-values" id="handling-missing-values">Handling Missing Values</h2>
<section id="checking-for-missing-values" class="level3">
<h3 class="anchored" data-anchor-id="checking-for-missing-values" id="checking-for-missing-values">Checking for missing values</h3>
<section id="pandas-10" class="level4">
<h4 class="anchored" data-anchor-id="pandas-10">Pandas</h4>
<div id="3dd3c419" class="cell" data-execution_count="21">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb23"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb23-1"><a href="#cb23-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Check for missing values</span></span>
<span id="cb23-2"><a href="#cb23-2" aria-hidden="true" tabindex="-1"></a>missing_count <span class="op">=</span> df_pd.isnull().<span class="bu">sum</span>()</span>
<span id="cb23-3"><a href="#cb23-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb23-4"><a href="#cb23-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Check if any column has missing values</span></span>
<span id="cb23-5"><a href="#cb23-5" aria-hidden="true" tabindex="-1"></a>has_missing <span class="op">=</span> df_pd.isnull().<span class="bu">any</span>().<span class="bu">any</span>()</span></code></pre></div></div>
</div>
</section>
<section id="polars-10" class="level4">
<h4 class="anchored" data-anchor-id="polars-10">Polars</h4>
<div id="1be5d51b" class="cell" data-execution_count="22">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb24"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb24-1"><a href="#cb24-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Check for missing values</span></span>
<span id="cb24-2"><a href="#cb24-2" aria-hidden="true" tabindex="-1"></a>missing_count <span class="op">=</span> df_pl.null_count()</span>
<span id="cb24-3"><a href="#cb24-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb24-4"><a href="#cb24-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Check if specific column has missing values</span></span>
<span id="cb24-5"><a href="#cb24-5" aria-hidden="true" tabindex="-1"></a>has_missing <span class="op">=</span> df_pl.select(pl.col(<span class="st">'age'</span>).is_null().<span class="bu">any</span>()).item()</span></code></pre></div></div>
</div>
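<p>The last lines of the two snippets are not quite equivalent: the pandas version checks the whole DataFrame, while the Polars version checks only <code>age</code>. A whole-frame check in Polars can reuse <code>null_count()</code>, whose single row holds one count per column. The sketch below uses a small hypothetical DataFrame named <code>sample_pl</code>.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import polars as pl

# Hypothetical data with one null in 'age' and one in 'city'
sample_pl = pl.DataFrame({
    'name': ['Alice', 'Bob', 'Charlie'],
    'age': [25, None, 35],
    'city': ['New York', 'London', None],
})

# One-row DataFrame holding a null count per column
missing_count = sample_pl.null_count()

# Whole-frame check, the counterpart of pandas' isnull().any().any():
# any non-zero count in that single row means some column has a null
has_missing = any(missing_count.row(0))
</code></pre></div>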
</section>
</section>
<section id="handling-missing-values-1" class="level3">
<h3 class="anchored" data-anchor-id="handling-missing-values-1" id="handling-missing-values-1">Handling missing values</h3>
<section id="pandas-11" class="level4">
<h4 class="anchored" data-anchor-id="pandas-11">Pandas</h4>
<div id="754a8ce7" class="cell" data-execution_count="23">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb25"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb25-1"><a href="#cb25-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Drop rows with any missing values</span></span>
<span id="cb25-2"><a href="#cb25-2" aria-hidden="true" tabindex="-1"></a>df_pd_clean <span class="op">=</span> df_pd.dropna()</span>
<span id="cb25-3"><a href="#cb25-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb25-4"><a href="#cb25-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Fill missing values</span></span>
<span id="cb25-5"><a href="#cb25-5" aria-hidden="true" tabindex="-1"></a>df_pd_filled <span class="op">=</span> df_pd.fillna({</span>
<span id="cb25-6"><a href="#cb25-6" aria-hidden="true" tabindex="-1"></a>    <span class="st">'age'</span>: <span class="dv">0</span>,</span>
<span id="cb25-7"><a href="#cb25-7" aria-hidden="true" tabindex="-1"></a>    <span class="st">'city'</span>: <span class="st">'Unknown'</span></span>
<span id="cb25-8"><a href="#cb25-8" aria-hidden="true" tabindex="-1"></a>})</span>
<span id="cb25-9"><a href="#cb25-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb25-10"><a href="#cb25-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Forward fill</span></span>
<span id="cb25-11"><a href="#cb25-11" aria-hidden="true" tabindex="-1"></a>df_pd_ffill <span class="op">=</span> df_pd.ffill()</span></code></pre></div></div>
</div>
</section>
<section id="polars-11" class="level4">
<h4 class="anchored" data-anchor-id="polars-11">Polars</h4>
<div id="379f9f12" class="cell" data-execution_count="24">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb26"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb26-1"><a href="#cb26-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Drop rows with any missing values</span></span>
<span id="cb26-2"><a href="#cb26-2" aria-hidden="true" tabindex="-1"></a>df_pl_clean <span class="op">=</span> df_pl.drop_nulls()</span>
<span id="cb26-3"><a href="#cb26-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb26-4"><a href="#cb26-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Fill missing values</span></span>
<span id="cb26-5"><a href="#cb26-5" aria-hidden="true" tabindex="-1"></a>df_pl_filled <span class="op">=</span> df_pl.with_columns([</span>
<span id="cb26-6"><a href="#cb26-6" aria-hidden="true" tabindex="-1"></a>    pl.col(<span class="st">'age'</span>).fill_null(<span class="dv">0</span>),</span>
<span id="cb26-7"><a href="#cb26-7" aria-hidden="true" tabindex="-1"></a>    pl.col(<span class="st">'city'</span>).fill_null(<span class="st">'Unknown'</span>)</span>
<span id="cb26-8"><a href="#cb26-8" aria-hidden="true" tabindex="-1"></a>])</span>
<span id="cb26-9"><a href="#cb26-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb26-10"><a href="#cb26-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Forward fill</span></span>
<span id="cb26-11"><a href="#cb26-11" aria-hidden="true" tabindex="-1"></a>df_pl_ffill <span class="op">=</span> df_pl.with_columns([</span>
<span id="cb26-12"><a href="#cb26-12" aria-hidden="true" tabindex="-1"></a>    pl.col(<span class="st">'age'</span>).fill_null(strategy<span class="op">=</span><span class="st">'forward'</span>),</span>
<span id="cb26-13"><a href="#cb26-13" aria-hidden="true" tabindex="-1"></a>    pl.col(<span class="st">'city'</span>).fill_null(strategy<span class="op">=</span><span class="st">'forward'</span>)</span>
<span id="cb26-14"><a href="#cb26-14" aria-hidden="true" tabindex="-1"></a>])</span></code></pre></div></div>
</div>
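<p>One behaviour that surprises pandas users: Polars treats missing data (null) and the floating-point value <code>NaN</code> as different things, so <code>is_null</code>, <code>drop_nulls</code> and <code>fill_null</code> leave <code>NaN</code> untouched, whereas pandas treats <code>NaN</code> as missing. If a float column may contain <code>NaN</code>, one approach, sketched below with a hypothetical <code>score</code> column, is to convert <code>NaN</code> to null with <code>fill_nan(None)</code> and then fill everything in one pass.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import polars as pl

# Hypothetical float column containing both a NaN and a null
df = pl.DataFrame({'score': [1.0, float('nan'), None]})

nulls = df['score'].null_count()    # counts only the None entry
nans = df['score'].is_nan().sum()   # counts only the NaN entry (nulls are skipped)

# Convert NaN to null first, then fill every missing value with 0.0
cleaned = df.with_columns(pl.col('score').fill_nan(None).fill_null(0.0))
</code></pre></div>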
</section>
</section>
</section>
<section id="string-operations" class="level2">
<h2 class="anchored" data-anchor-id="string-operations" id="string-operations">String Operations</h2>
<section id="basic-string-operations" class="level3">
<h3 class="anchored" data-anchor-id="basic-string-operations" id="basic-string-operations">Basic string operations</h3>
<section id="pandas-12" class="level4">
<h4 class="anchored" data-anchor-id="pandas-12">Pandas</h4>
<div id="a2da31ca" class="cell" data-execution_count="25">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb27"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb27-1"><a href="#cb27-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Convert to uppercase</span></span>
<span id="cb27-2"><a href="#cb27-2" aria-hidden="true" tabindex="-1"></a>df_pd[<span class="st">'name_upper'</span>] <span class="op">=</span> df_pd[<span class="st">'name'</span>].<span class="bu">str</span>.upper()</span>
<span id="cb27-3"><a href="#cb27-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb27-4"><a href="#cb27-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Get string length</span></span>
<span id="cb27-5"><a href="#cb27-5" aria-hidden="true" tabindex="-1"></a>df_pd[<span class="st">'name_length'</span>] <span class="op">=</span> df_pd[<span class="st">'name'</span>].<span class="bu">str</span>.<span class="bu">len</span>()</span>
<span id="cb27-6"><a href="#cb27-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb27-7"><a href="#cb27-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Extract substring</span></span>
<span id="cb27-8"><a href="#cb27-8" aria-hidden="true" tabindex="-1"></a>df_pd[<span class="st">'name_first_char'</span>] <span class="op">=</span> df_pd[<span class="st">'name'</span>].<span class="bu">str</span>[<span class="dv">0</span>]</span>
<span id="cb27-9"><a href="#cb27-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb27-10"><a href="#cb27-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Replace substrings</span></span>
<span id="cb27-11"><a href="#cb27-11" aria-hidden="true" tabindex="-1"></a>df_pd[<span class="st">'city_replaced'</span>] <span class="op">=</span> df_pd[<span class="st">'city'</span>].<span class="bu">str</span>.replace(<span class="st">'New'</span>, <span class="st">'Old'</span>)</span></code></pre></div></div>
</div>
</section>
<section id="polars-12" class="level4">
<h4 class="anchored" data-anchor-id="polars-12">Polars</h4>
<div id="78995c7c" class="cell" data-execution_count="26">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb28"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb28-1"><a href="#cb28-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Convert to uppercase</span></span>
<span id="cb28-2"><a href="#cb28-2" aria-hidden="true" tabindex="-1"></a>df_pl <span class="op">=</span> df_pl.with_columns(pl.col(<span class="st">'name'</span>).<span class="bu">str</span>.to_uppercase().alias(<span class="st">'name_upper'</span>))</span>
<span id="cb28-3"><a href="#cb28-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb28-4"><a href="#cb28-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Get string length</span></span>
<span id="cb28-5"><a href="#cb28-5" aria-hidden="true" tabindex="-1"></a>df_pl <span class="op">=</span> df_pl.with_columns(pl.col(<span class="st">'name'</span>).<span class="bu">str</span>.len_chars().alias(<span class="st">'name_length'</span>))</span>
<span id="cb28-6"><a href="#cb28-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb28-7"><a href="#cb28-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Extract substring </span></span>
<span id="cb28-8"><a href="#cb28-8" aria-hidden="true" tabindex="-1"></a>df_pl <span class="op">=</span> df_pl.with_columns(pl.col(<span class="st">'name'</span>).<span class="bu">str</span>.<span class="bu">slice</span>(<span class="dv">0</span>, <span class="dv">1</span>).alias(<span class="st">'name_first_char'</span>))</span>
<span id="cb28-9"><a href="#cb28-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb28-10"><a href="#cb28-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Replace all literal occurrences (str.replace alone would stop at the first match)</span></span>
<span id="cb28-11"><a href="#cb28-11" aria-hidden="true" tabindex="-1"></a>df_pl <span class="op">=</span> df_pl.with_columns(pl.col(<span class="st">'city'</span>).<span class="bu">str</span>.replace_all(<span class="st">'New'</span>, <span class="st">'Old'</span>, literal<span class="op">=</span><span class="va">True</span>).alias(<span class="st">'city_replaced'</span>))</span></code></pre></div></div>
</div>
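<p>Replacement is where the two APIs diverge most. In pandas 2.x, <code>Series.str.replace</code> treats the pattern as a literal string by default (<code>regex=False</code>) and replaces every occurrence. Polars <code>str.replace</code> treats the pattern as a regular expression unless <code>literal=True</code>, and replaces only the first match; <code>str.replace_all</code> replaces every match. The sketch below uses a hypothetical value in which the pattern occurs twice.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import pandas as pd
import polars as pl

# Hypothetical value in which the pattern occurs twice
cities = ['New New York']

# pandas: replaces every occurrence by default
pd_result = pd.Series(cities).str.replace('New', 'Old', regex=False)

# Polars: str.replace stops after the first match; str.replace_all does not
pl_first = pl.Series(cities).str.replace('New', 'Old', literal=True)
pl_all = pl.Series(cities).str.replace_all('New', 'Old', literal=True)
</code></pre></div>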
</section>
</section>
<section id="advanced-string-operations" class="level3">
<h3 class="anchored" data-anchor-id="advanced-string-operations" id="advanced-string-operations">Advanced string operations</h3>
<section id="pandas-13" class="level4">
<h4 class="anchored" data-anchor-id="pandas-13">Pandas</h4>
<div id="f8b2bf06" class="cell" data-execution_count="27">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb29"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb29-1"><a href="#cb29-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Split string</span></span>
<span id="cb29-2"><a href="#cb29-2" aria-hidden="true" tabindex="-1"></a>df_pd[<span class="st">'first_word'</span>] <span class="op">=</span> df_pd[<span class="st">'city'</span>].<span class="bu">str</span>.split(<span class="st">' '</span>).<span class="bu">str</span>[<span class="dv">0</span>]</span>
<span id="cb29-3"><a href="#cb29-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb29-4"><a href="#cb29-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Pattern matching</span></span>
<span id="cb29-5"><a href="#cb29-5" aria-hidden="true" tabindex="-1"></a>has_new <span class="op">=</span> df_pd[<span class="st">'city'</span>].<span class="bu">str</span>.contains(<span class="st">'New'</span>)</span>
<span id="cb29-6"><a href="#cb29-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb29-7"><a href="#cb29-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Extract the first word followed by whitespace; expand=False returns a Series</span></span>
<span id="cb29-8"><a href="#cb29-8" aria-hidden="true" tabindex="-1"></a>df_pd[<span class="st">'extracted'</span>] <span class="op">=</span> df_pd[<span class="st">'city'</span>].<span class="bu">str</span>.extract(<span class="vs">r'</span><span class="kw">(</span><span class="dv">\w</span><span class="op">+</span><span class="kw">)</span><span class="dv">\s</span><span class="vs">'</span>, expand<span class="op">=</span><span class="va">False</span>)</span></code></pre></div></div>
</div>
</section>
<section id="polars-13" class="level4">
<h4 class="anchored" data-anchor-id="polars-13">Polars</h4>
<div id="870d2705" class="cell" data-execution_count="28">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb30"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb30-1"><a href="#cb30-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Split string</span></span>
<span id="cb30-2"><a href="#cb30-2" aria-hidden="true" tabindex="-1"></a>df_pl <span class="op">=</span> df_pl.with_columns(</span>
<span id="cb30-3"><a href="#cb30-3" aria-hidden="true" tabindex="-1"></a>    pl.col(<span class="st">'city'</span>).<span class="bu">str</span>.split(<span class="st">' '</span>).<span class="bu">list</span>.get(<span class="dv">0</span>).alias(<span class="st">'first_word'</span>)</span>
<span id="cb30-4"><a href="#cb30-4" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb30-5"><a href="#cb30-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb30-6"><a href="#cb30-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Pattern matching</span></span>
<span id="cb30-7"><a href="#cb30-7" aria-hidden="true" tabindex="-1"></a>df_pl <span class="op">=</span> df_pl.with_columns(</span>
<span id="cb30-8"><a href="#cb30-8" aria-hidden="true" tabindex="-1"></a>    pl.col(<span class="st">'city'</span>).<span class="bu">str</span>.contains(<span class="st">'New'</span>).alias(<span class="st">'has_new'</span>)</span>
<span id="cb30-9"><a href="#cb30-9" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb30-10"><a href="#cb30-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb30-11"><a href="#cb30-11" aria-hidden="true" tabindex="-1"></a><span class="co"># Extract with regex</span></span>
<span id="cb30-12"><a href="#cb30-12" aria-hidden="true" tabindex="-1"></a>df_pl <span class="op">=</span> df_pl.with_columns(</span>
<span id="cb30-13"><a href="#cb30-13" aria-hidden="true" tabindex="-1"></a>    pl.col(<span class="st">'city'</span>).<span class="bu">str</span>.extract(<span class="vs">r'</span><span class="kw">(</span><span class="dv">\w</span><span class="op">+</span><span class="kw">)</span><span class="dv">\s</span><span class="vs">'</span>).alias(<span class="st">'extracted'</span>)</span>
<span id="cb30-14"><a href="#cb30-14" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
</div>
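<p>The pattern <code>(\w+)\s</code> only matches a word that is followed by whitespace, so single-word values produce a missing result in both libraries: <code>NaN</code> in pandas and null in Polars. The sketch below shows this with two hypothetical city names; in Polars, <code>group_index=1</code> (the default) selects the first capture group.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import pandas as pd
import polars as pl

# Hypothetical cities: one with a space, one single-word name
cities = ['New York', 'Paris']

# pandas: expand=False returns a Series; no match becomes NaN
pd_first = pd.Series(cities).str.extract(r'(\w+)\s', expand=False)

# Polars: first capture group; no match becomes null
pl_first = pl.Series(cities).str.extract(r'(\w+)\s', group_index=1)
</code></pre></div>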
</section>
</section>
</section>
<section id="time-series-operations" class="level2">
<h2 class="anchored" data-anchor-id="time-series-operations" id="time-series-operations">Time Series Operations</h2>
<section id="date-parsing-and-creation" class="level3">
<h3 class="anchored" data-anchor-id="date-parsing-and-creation" id="date-parsing-and-creation">Date parsing and creation</h3>
<section id="pandas-14" class="level4">
<h4 class="anchored" data-anchor-id="pandas-14">Pandas</h4>
<div id="a7ed5b5c" class="cell" data-execution_count="29">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb31"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb31-1"><a href="#cb31-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Create DataFrame with dates</span></span>
<span id="cb31-2"><a href="#cb31-2" aria-hidden="true" tabindex="-1"></a>dates_pd <span class="op">=</span> pd.DataFrame({</span>
<span id="cb31-3"><a href="#cb31-3" aria-hidden="true" tabindex="-1"></a>    <span class="st">'date_str'</span>: [<span class="st">'2023-01-01'</span>, <span class="st">'2023-02-15'</span>, <span class="st">'2023-03-30'</span>]</span>
<span id="cb31-4"><a href="#cb31-4" aria-hidden="true" tabindex="-1"></a>})</span>
<span id="cb31-5"><a href="#cb31-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb31-6"><a href="#cb31-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Parse dates</span></span>
<span id="cb31-7"><a href="#cb31-7" aria-hidden="true" tabindex="-1"></a>dates_pd[<span class="st">'date'</span>] <span class="op">=</span> pd.to_datetime(dates_pd[<span class="st">'date_str'</span>])</span>
<span id="cb31-8"><a href="#cb31-8" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb31-9"><a href="#cb31-9" aria-hidden="true" tabindex="-1"></a><span class="co"># Extract components</span></span>
<span id="cb31-10"><a href="#cb31-10" aria-hidden="true" tabindex="-1"></a>dates_pd[<span class="st">'year'</span>] <span class="op">=</span> dates_pd[<span class="st">'date'</span>].dt.year</span>
<span id="cb31-11"><a href="#cb31-11" aria-hidden="true" tabindex="-1"></a>dates_pd[<span class="st">'month'</span>] <span class="op">=</span> dates_pd[<span class="st">'date'</span>].dt.month</span>
<span id="cb31-12"><a href="#cb31-12" aria-hidden="true" tabindex="-1"></a>dates_pd[<span class="st">'day'</span>] <span class="op">=</span> dates_pd[<span class="st">'date'</span>].dt.day</span>
<span id="cb31-13"><a href="#cb31-13" aria-hidden="true" tabindex="-1"></a>dates_pd[<span class="st">'weekday'</span>] <span class="op">=</span> dates_pd[<span class="st">'date'</span>].dt.day_name()</span></code></pre></div></div>
</div>
</section>
<section id="polars-14" class="level4">
<h4 class="anchored" data-anchor-id="polars-14">Polars</h4>
<div id="19572f65" class="cell" data-execution_count="30">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb32"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb32-1"><a href="#cb32-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Create DataFrame with dates</span></span>
<span id="cb32-2"><a href="#cb32-2" aria-hidden="true" tabindex="-1"></a>dates_pl <span class="op">=</span> pl.DataFrame({</span>
<span id="cb32-3"><a href="#cb32-3" aria-hidden="true" tabindex="-1"></a>    <span class="st">'date_str'</span>: [<span class="st">'2023-01-01'</span>, <span class="st">'2023-02-15'</span>, <span class="st">'2023-03-30'</span>]</span>
<span id="cb32-4"><a href="#cb32-4" aria-hidden="true" tabindex="-1"></a>})</span>
<span id="cb32-5"><a href="#cb32-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb32-6"><a href="#cb32-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Parse dates</span></span>
<span id="cb32-7"><a href="#cb32-7" aria-hidden="true" tabindex="-1"></a>dates_pl <span class="op">=</span> dates_pl.with_columns(</span>
<span id="cb32-8"><a href="#cb32-8" aria-hidden="true" tabindex="-1"></a>    pl.col(<span class="st">'date_str'</span>).<span class="bu">str</span>.strptime(pl.Datetime, <span class="st">'%Y-%m-</span><span class="sc">%d</span><span class="st">'</span>).alias(<span class="st">'date'</span>)</span>
<span id="cb32-9"><a href="#cb32-9" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb32-10"><a href="#cb32-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb32-11"><a href="#cb32-11" aria-hidden="true" tabindex="-1"></a><span class="co"># Extract components</span></span>
<span id="cb32-12"><a href="#cb32-12" aria-hidden="true" tabindex="-1"></a>dates_pl <span class="op">=</span> dates_pl.with_columns([</span>
<span id="cb32-13"><a href="#cb32-13" aria-hidden="true" tabindex="-1"></a>    pl.col(<span class="st">'date'</span>).dt.year().alias(<span class="st">'year'</span>),</span>
<span id="cb32-14"><a href="#cb32-14" aria-hidden="true" tabindex="-1"></a>    pl.col(<span class="st">'date'</span>).dt.month().alias(<span class="st">'month'</span>),</span>
<span id="cb32-15"><a href="#cb32-15" aria-hidden="true" tabindex="-1"></a>    pl.col(<span class="st">'date'</span>).dt.day().alias(<span class="st">'day'</span>),</span>
<span id="cb32-16"><a href="#cb32-16" aria-hidden="true" tabindex="-1"></a>    pl.col(<span class="st">'date'</span>).dt.weekday().replace_strict({  <span class="co"># ISO numbering: Monday=1 ... Sunday=7</span></span>
<span id="cb32-17"><a href="#cb32-17" aria-hidden="true" tabindex="-1"></a>        <span class="dv">1</span>: <span class="st">'Monday'</span>, <span class="dv">2</span>: <span class="st">'Tuesday'</span>, <span class="dv">3</span>: <span class="st">'Wednesday'</span>, </span>
<span id="cb32-18"><a href="#cb32-18" aria-hidden="true" tabindex="-1"></a>        <span class="dv">4</span>: <span class="st">'Thursday'</span>, <span class="dv">5</span>: <span class="st">'Friday'</span>, <span class="dv">6</span>: <span class="st">'Saturday'</span>, <span class="dv">7</span>: <span class="st">'Sunday'</span></span>
<span id="cb32-19"><a href="#cb32-19" aria-hidden="true" tabindex="-1"></a>    }, default<span class="op">=</span><span class="st">"unknown"</span>).alias(<span class="st">'weekday'</span>)</span>
<span id="cb32-20"><a href="#cb32-20" aria-hidden="true" tabindex="-1"></a>])</span></code></pre></div></div>
</div>
</section>
</section>
<section id="date-arithmetic" class="level3">
<h3 class="anchored" data-anchor-id="date-arithmetic" id="date-arithmetic">Date arithmetic</h3>
<section id="pandas-15" class="level4">
<h4 class="anchored" data-anchor-id="pandas-15">Pandas</h4>
<div id="0f6645e0" class="cell" data-execution_count="31">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb33"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb33-1"><a href="#cb33-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Add days</span></span>
<span id="cb33-2"><a href="#cb33-2" aria-hidden="true" tabindex="-1"></a>dates_pd[<span class="st">'next_week'</span>] <span class="op">=</span> dates_pd[<span class="st">'date'</span>] <span class="op">+</span> pd.Timedelta(days<span class="op">=</span><span class="dv">7</span>)</span>
<span id="cb33-3"><a href="#cb33-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb33-4"><a href="#cb33-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Date difference</span></span>
<span id="cb33-5"><a href="#cb33-5" aria-hidden="true" tabindex="-1"></a>date_range <span class="op">=</span> pd.date_range(start<span class="op">=</span><span class="st">'2023-01-01'</span>, end<span class="op">=</span><span class="st">'2023-01-10'</span>)</span>
<span id="cb33-6"><a href="#cb33-6" aria-hidden="true" tabindex="-1"></a>df_dates <span class="op">=</span> pd.DataFrame({<span class="st">'date'</span>: date_range})</span>
<span id="cb33-7"><a href="#cb33-7" aria-hidden="true" tabindex="-1"></a>df_dates[<span class="st">'days_since_start'</span>] <span class="op">=</span> (df_dates[<span class="st">'date'</span>] <span class="op">-</span> df_dates[<span class="st">'date'</span>].<span class="bu">min</span>()).dt.days</span></code></pre></div></div>
</div>
</section>
<section id="polars-15" class="level4">
<h4 class="anchored" data-anchor-id="polars-15">Polars</h4>
<div id="8e7433c6" class="cell" data-execution_count="32">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb34"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb34-1"><a href="#cb34-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Add days</span></span>
<span id="cb34-2"><a href="#cb34-2" aria-hidden="true" tabindex="-1"></a>dates_pl <span class="op">=</span> dates_pl.with_columns(</span>
<span id="cb34-3"><a href="#cb34-3" aria-hidden="true" tabindex="-1"></a>    (pl.col(<span class="st">'date'</span>) <span class="op">+</span> pl.duration(days<span class="op">=</span><span class="dv">7</span>)).alias(<span class="st">'next_week'</span>)</span>
<span id="cb34-4"><a href="#cb34-4" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb34-5"><a href="#cb34-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb34-6"><a href="#cb34-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Date difference</span></span>
<span id="cb34-7"><a href="#cb34-7" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> datetime <span class="im">import</span> date</span>
<span id="cb34-8"><a href="#cb34-8" aria-hidden="true" tabindex="-1"></a>date_range <span class="op">=</span> pl.date_range(date(<span class="dv">2023</span>, <span class="dv">1</span>, <span class="dv">1</span>), date(<span class="dv">2023</span>, <span class="dv">1</span>, <span class="dv">10</span>), eager<span class="op">=</span><span class="va">True</span>)  <span class="co"># Native Polars date range (returns a Series)</span></span>
<span id="cb34-9"><a href="#cb34-9" aria-hidden="true" tabindex="-1"></a>df_dates <span class="op">=</span> pl.DataFrame({<span class="st">'date'</span>: date_range})</span>
<span id="cb34-10"><a href="#cb34-10" aria-hidden="true" tabindex="-1"></a>df_dates <span class="op">=</span> df_dates.with_columns(</span>
<span id="cb34-11"><a href="#cb34-11" aria-hidden="true" tabindex="-1"></a>    (pl.col(<span class="st">'date'</span>) <span class="op">-</span> pl.col(<span class="st">'date'</span>).<span class="bu">min</span>()).dt.total_days().alias(<span class="st">'days_since_start'</span>)</span>
<span id="cb34-12"><a href="#cb34-12" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
</div>
</section>
</section>
</section>
<section id="performance-comparison" class="level2">
<h2 class="anchored" data-anchor-id="performance-comparison" id="performance-comparison">Performance Comparison</h2>
<p>This section compares pandas and Polars on a group-by aggregation over a 10-million-row dataset.</p>
<div id="77853cab" class="cell" data-execution_count="33">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb35"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb35-1"><a href="#cb35-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> pandas <span class="im">as</span> pd</span>
<span id="cb35-2"><a href="#cb35-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> polars <span class="im">as</span> pl</span>
<span id="cb35-3"><a href="#cb35-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> time</span>
<span id="cb35-4"><a href="#cb35-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> numpy <span class="im">as</span> np</span>
<span id="cb35-5"><a href="#cb35-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb35-6"><a href="#cb35-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Generate a large dataset (10 million rows)</span></span>
<span id="cb35-7"><a href="#cb35-7" aria-hidden="true" tabindex="-1"></a>n <span class="op">=</span> <span class="dv">10_000_000</span></span>
<span id="cb35-8"><a href="#cb35-8" aria-hidden="true" tabindex="-1"></a>data <span class="op">=</span> {</span>
<span id="cb35-9"><a href="#cb35-9" aria-hidden="true" tabindex="-1"></a>    <span class="st">'id'</span>: np.arange(n),</span>
<span id="cb35-10"><a href="#cb35-10" aria-hidden="true" tabindex="-1"></a>    <span class="st">'value'</span>: np.random.randn(n),</span>
<span id="cb35-11"><a href="#cb35-11" aria-hidden="true" tabindex="-1"></a>    <span class="st">'group'</span>: np.random.choice([<span class="st">'A'</span>, <span class="st">'B'</span>, <span class="st">'C'</span>, <span class="st">'D'</span>], n)</span>
<span id="cb35-12"><a href="#cb35-12" aria-hidden="true" tabindex="-1"></a>}</span>
<span id="cb35-13"><a href="#cb35-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb35-14"><a href="#cb35-14" aria-hidden="true" tabindex="-1"></a><span class="co"># Convert to pandas DataFrame</span></span>
<span id="cb35-15"><a href="#cb35-15" aria-hidden="true" tabindex="-1"></a>df_pd <span class="op">=</span> pd.DataFrame(data)</span>
<span id="cb35-16"><a href="#cb35-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb35-17"><a href="#cb35-17" aria-hidden="true" tabindex="-1"></a><span class="co"># Convert to polars DataFrame</span></span>
<span id="cb35-18"><a href="#cb35-18" aria-hidden="true" tabindex="-1"></a>df_pl <span class="op">=</span> pl.DataFrame(data)</span>
<span id="cb35-19"><a href="#cb35-19" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb35-20"><a href="#cb35-20" aria-hidden="true" tabindex="-1"></a><span class="co"># Benchmark: Group by and calculate mean, min, max</span></span>
<span id="cb35-21"><a href="#cb35-21" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"Running pandas groupby..."</span>)</span>
<span id="cb35-22"><a href="#cb35-22" aria-hidden="true" tabindex="-1"></a>start <span class="op">=</span> time.time()</span>
<span id="cb35-23"><a href="#cb35-23" aria-hidden="true" tabindex="-1"></a>result_pd <span class="op">=</span> df_pd.groupby(<span class="st">'group'</span>).agg({</span>
<span id="cb35-24"><a href="#cb35-24" aria-hidden="true" tabindex="-1"></a>    <span class="st">'value'</span>: [<span class="st">'mean'</span>, <span class="st">'min'</span>, <span class="st">'max'</span>, <span class="st">'count'</span>]</span>
<span id="cb35-25"><a href="#cb35-25" aria-hidden="true" tabindex="-1"></a>})</span>
<span id="cb35-26"><a href="#cb35-26" aria-hidden="true" tabindex="-1"></a>pd_time <span class="op">=</span> time.time() <span class="op">-</span> start</span>
<span id="cb35-27"><a href="#cb35-27" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Pandas time: </span><span class="sc">{</span>pd_time<span class="sc">:.4f}</span><span class="ss"> seconds"</span>)</span>
<span id="cb35-28"><a href="#cb35-28" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb35-29"><a href="#cb35-29" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="st">"Running polars groupby..."</span>)</span>
<span id="cb35-30"><a href="#cb35-30" aria-hidden="true" tabindex="-1"></a>start <span class="op">=</span> time.time()</span>
<span id="cb35-31"><a href="#cb35-31" aria-hidden="true" tabindex="-1"></a>result_pl <span class="op">=</span> df_pl.group_by(<span class="st">'group'</span>).agg([</span>
<span id="cb35-32"><a href="#cb35-32" aria-hidden="true" tabindex="-1"></a>    pl.col(<span class="st">'value'</span>).mean().alias(<span class="st">'value_mean'</span>),</span>
<span id="cb35-33"><a href="#cb35-33" aria-hidden="true" tabindex="-1"></a>    pl.col(<span class="st">'value'</span>).<span class="bu">min</span>().alias(<span class="st">'value_min'</span>),</span>
<span id="cb35-34"><a href="#cb35-34" aria-hidden="true" tabindex="-1"></a>    pl.col(<span class="st">'value'</span>).<span class="bu">max</span>().alias(<span class="st">'value_max'</span>),</span>
<span id="cb35-35"><a href="#cb35-35" aria-hidden="true" tabindex="-1"></a>    pl.col(<span class="st">'value'</span>).count().alias(<span class="st">'value_count'</span>)</span>
<span id="cb35-36"><a href="#cb35-36" aria-hidden="true" tabindex="-1"></a>])</span>
<span id="cb35-37"><a href="#cb35-37" aria-hidden="true" tabindex="-1"></a>pl_time <span class="op">=</span> time.time() <span class="op">-</span> start</span>
<span id="cb35-38"><a href="#cb35-38" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Polars time: </span><span class="sc">{</span>pl_time<span class="sc">:.4f}</span><span class="ss"> seconds"</span>)</span>
<span id="cb35-39"><a href="#cb35-39" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb35-40"><a href="#cb35-40" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f"Polars is </span><span class="sc">{</span>pd_time <span class="op">/</span> pl_time<span class="sc">:.2f}</span><span class="ss">x faster"</span>)</span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>Running pandas groupby...
Pandas time: 0.2384 seconds
Running polars groupby...
Polars time: 0.0992 seconds
Polars is 2.40x faster</code></pre>
</div>
</div>
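<p>In this run Polars was about 2.4x faster. The exact ratio depends on hardware, library versions, data size, and the operation itself, so a single timing is indicative rather than definitive. The gap often widens as data grows and for multi-step queries that benefit from Polars’ multithreaded execution and lazy query optimization.</p>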
</section>
<section id="api-philosophy-differences" class="level2">
<h2 class="anchored" data-anchor-id="api-philosophy-differences" id="api-philosophy-differences">API Philosophy Differences</h2>
<p>Pandas and Polars differ in several fundamental aspects:</p>
<section id="eager-vs.-lazy-execution" class="level3">
<h3 class="anchored" data-anchor-id="eager-vs.-lazy-execution" id="eager-vs.-lazy-execution">1. Eager vs.&nbsp;Lazy Execution</h3>
<p><strong>Pandas</strong> uses eager execution by default:</p>
<div id="808f50ad" class="cell" data-execution_count="34">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb37"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb37-1"><a href="#cb37-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Each operation is executed immediately</span></span>
<span id="cb37-2"><a href="#cb37-2" aria-hidden="true" tabindex="-1"></a>pd_df <span class="op">=</span> pd.read_csv(<span class="st">'data.csv'</span>)</span>
<span id="cb37-3"><a href="#cb37-3" aria-hidden="true" tabindex="-1"></a>pd_df <span class="op">=</span> pd_df[pd_df[<span class="st">'age'</span>] <span class="op">&gt;</span> <span class="dv">0</span>]  <span class="co"># Filter is applied immediately</span></span>
<span id="cb37-4"><a href="#cb37-4" aria-hidden="true" tabindex="-1"></a>pd_df[<span class="st">'new_col'</span>] <span class="op">=</span> pd_df[<span class="st">'age'</span>] <span class="op">*</span> <span class="dv">2</span>  <span class="co"># Transformation is applied immediately</span></span>
<span id="cb37-5"><a href="#cb37-5" aria-hidden="true" tabindex="-1"></a>result <span class="op">=</span> pd_df.groupby(<span class="st">'gender'</span>)[<span class="st">'age'</span>].<span class="bu">sum</span>()  <span class="co"># Groupby is executed immediately (sums 'age', like the Polars example)</span></span></code></pre></div></div>
</div>
<p><strong>Polars</strong> supports both eager and lazy execution:</p>
<div id="e988c1e8" class="cell" data-execution_count="35">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb38"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb38-1"><a href="#cb38-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Eager execution</span></span>
<span id="cb38-2"><a href="#cb38-2" aria-hidden="true" tabindex="-1"></a>pl_df <span class="op">=</span> pl.read_csv(<span class="st">'data.csv'</span>)</span>
<span id="cb38-3"><a href="#cb38-3" aria-hidden="true" tabindex="-1"></a>pl_df <span class="op">=</span> pl_df.<span class="bu">filter</span>(pl.col(<span class="st">'age'</span>) <span class="op">&gt;</span> <span class="dv">0</span>)</span>
<span id="cb38-4"><a href="#cb38-4" aria-hidden="true" tabindex="-1"></a>pl_df <span class="op">=</span> pl_df.with_columns((pl.col(<span class="st">'age'</span>) <span class="op">*</span> <span class="dv">2</span>).alias(<span class="st">'new_col'</span>))</span>
<span id="cb38-5"><a href="#cb38-5" aria-hidden="true" tabindex="-1"></a>result <span class="op">=</span> pl_df.group_by(<span class="st">'gender'</span>).agg(pl.col(<span class="st">'age'</span>).<span class="bu">sum</span>())</span>
<span id="cb38-6"><a href="#cb38-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb38-7"><a href="#cb38-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Lazy execution</span></span>
<span id="cb38-8"><a href="#cb38-8" aria-hidden="true" tabindex="-1"></a>pl_lazy_df <span class="op">=</span> pl.scan_csv(<span class="st">'data.csv'</span>)  <span class="co"># Creates a lazy frame</span></span>
<span id="cb38-9"><a href="#cb38-9" aria-hidden="true" tabindex="-1"></a>result <span class="op">=</span> (pl_lazy_df</span>
<span id="cb38-10"><a href="#cb38-10" aria-hidden="true" tabindex="-1"></a>    .<span class="bu">filter</span>(pl.col(<span class="st">'age'</span>) <span class="op">&gt;</span> <span class="dv">0</span>)</span>
<span id="cb38-11"><a href="#cb38-11" aria-hidden="true" tabindex="-1"></a>    .with_columns((pl.col(<span class="st">'age'</span>) <span class="op">*</span> <span class="dv">2</span>).alias(<span class="st">'new_col'</span>))</span>
<span id="cb38-12"><a href="#cb38-12" aria-hidden="true" tabindex="-1"></a>    .group_by(<span class="st">'gender'</span>)</span>
<span id="cb38-13"><a href="#cb38-13" aria-hidden="true" tabindex="-1"></a>    .agg(pl.col(<span class="st">'age'</span>).<span class="bu">sum</span>())</span>
<span id="cb38-14"><a href="#cb38-14" aria-hidden="true" tabindex="-1"></a>    .collect()  <span class="co"># Only at this point is the query executed</span></span>
<span id="cb38-15"><a href="#cb38-15" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
</div>
</section>
<section id="method-chaining-vs.-assignment" class="level3">
<h3 class="anchored" data-anchor-id="method-chaining-vs.-assignment" id="method-chaining-vs.-assignment">2. Method Chaining vs.&nbsp;Assignment</h3>
<p><strong>Pandas</strong> often uses assignment operations:</p>
<div id="d171a1d7" class="cell" data-execution_count="36">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb39"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb39-1"><a href="#cb39-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Many pandas operations use in-place assignment</span></span>
<span id="cb39-2"><a href="#cb39-2" aria-hidden="true" tabindex="-1"></a>pd_df[<span class="st">'new_col'</span>] <span class="op">=</span> pd_df[<span class="st">'new_col'</span>] <span class="op">*</span> <span class="dv">2</span></span>
<span id="cb39-3"><a href="#cb39-3" aria-hidden="true" tabindex="-1"></a>pd_df[<span class="st">'new_col'</span>] <span class="op">=</span> pd_df[<span class="st">'new_col'</span>].fillna(<span class="dv">0</span>)</span>
<span id="cb39-4"><a href="#cb39-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb39-5"><a href="#cb39-5" aria-hidden="true" tabindex="-1"></a><span class="co"># Some operations return new DataFrames</span></span>
<span id="cb39-6"><a href="#cb39-6" aria-hidden="true" tabindex="-1"></a>pd_df <span class="op">=</span> pd_df.sort_values(<span class="st">'new_col'</span>)</span></code></pre></div></div>
</div>
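<p><strong>Polars</strong> DataFrame methods return new DataFrames instead of modifying data in place, so transformations are naturally written as a method chain:</p>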
<div id="ac599e39" class="cell" data-execution_count="37">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb40"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb40-1"><a href="#cb40-1" aria-hidden="true" tabindex="-1"></a><span class="co"># All operations return new DataFrames and can be chained</span></span>
<span id="cb40-2"><a href="#cb40-2" aria-hidden="true" tabindex="-1"></a>pl_df <span class="op">=</span> (pl_df</span>
<span id="cb40-3"><a href="#cb40-3" aria-hidden="true" tabindex="-1"></a>    .with_columns((pl.col(<span class="st">'new_col'</span>) <span class="op">*</span> <span class="dv">2</span>).alias(<span class="st">'new_col'</span>))</span>
<span id="cb40-4"><a href="#cb40-4" aria-hidden="true" tabindex="-1"></a>    .with_columns(pl.col(<span class="st">'new_col'</span>).fill_null(<span class="dv">0</span>))</span>
<span id="cb40-5"><a href="#cb40-5" aria-hidden="true" tabindex="-1"></a>    .sort(<span class="st">'new_col'</span>)</span>
<span id="cb40-6"><a href="#cb40-6" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
</div>
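<p>The difference is one of convention more than capability: pandas can also be written in a chained style with <code>assign()</code>, which returns a new DataFrame. This sketch mirrors the Polars chain above on a small made-up column that contains one missing value.</p>

```python
import numpy as np
import pandas as pd

# Small illustrative frame with one missing value
pd_df = pd.DataFrame({'new_col': [3.0, np.nan, 1.0]})

# Each assign() returns a new DataFrame, so the steps can be chained like in Polars
result = (
    pd_df
    .assign(new_col=lambda d: d['new_col'] * 2)
    .assign(new_col=lambda d: d['new_col'].fillna(0))
    .sort_values('new_col')
)
print(result)
```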
</section>
<section id="expression-api-vs.-direct-references" class="level3">
<h3 class="anchored" data-anchor-id="expression-api-vs.-direct-references" id="expression-api-vs.-direct-references">3. Expression API vs.&nbsp;Direct References</h3>
<p><strong>Pandas</strong> directly references columns:</p>
<div id="acdc3eba" class="cell" data-execution_count="38">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb41"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb41-1"><a href="#cb41-1" aria-hidden="true" tabindex="-1"></a>pd_df[<span class="st">'result'</span>] <span class="op">=</span> pd_df[<span class="st">'age'</span>] <span class="op">+</span> pd_df[<span class="st">'new_col'</span>]</span>
<span id="cb41-2"><a href="#cb41-2" aria-hidden="true" tabindex="-1"></a>filtered <span class="op">=</span> pd_df[pd_df[<span class="st">'age'</span>] <span class="op">&gt;</span> pd_df[<span class="st">'age'</span>].mean()]</span></code></pre></div></div>
</div>
<p><strong>Polars</strong> uses an expression API:</p>
<div id="c6e4964e" class="cell" data-execution_count="39">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb42"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb42-1"><a href="#cb42-1" aria-hidden="true" tabindex="-1"></a>pl_df <span class="op">=</span> pl_df.with_columns(</span>
<span id="cb42-2"><a href="#cb42-2" aria-hidden="true" tabindex="-1"></a>    (pl.col(<span class="st">'age'</span>) <span class="op">+</span> pl.col(<span class="st">'new_col'</span>)).alias(<span class="st">'result'</span>)</span>
<span id="cb42-3"><a href="#cb42-3" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb42-4"><a href="#cb42-4" aria-hidden="true" tabindex="-1"></a>filtered <span class="op">=</span> pl_df.<span class="bu">filter</span>(pl.col(<span class="st">'age'</span>) <span class="op">&gt;</span> pl.col(<span class="st">'age'</span>).mean())</span></code></pre></div></div>
</div>
</section>
</section>
<section id="migration-guide" class="level2">
<h2 class="anchored" data-anchor-id="migration-guide" id="migration-guide">Migration Guide</h2>
<p>If you’re transitioning from pandas to Polars, here are key mappings between common operations:</p>
<table class="caption-top table">
<colgroup>
<col style="width: 40%">
<col style="width: 29%">
<col style="width: 29%">
</colgroup>
<thead>
<tr class="header">
<th>Operation</th>
<th>Pandas</th>
<th>Polars</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>Read CSV</td>
<td><code>pd.read_csv('file.csv')</code></td>
<td><code>pl.read_csv('file.csv')</code></td>
</tr>
<tr class="even">
<td>Select columns</td>
<td><code>df[['col1', 'col2']]</code></td>
<td><code>df.select(['col1', 'col2'])</code></td>
</tr>
<tr class="odd">
<td>Add column</td>
<td><code>df['new'] = df['col1'] * 2</code></td>
<td><code>df.with_columns((pl.col('col1') * 2).alias('new'))</code></td>
</tr>
<tr class="even">
<td>Filter rows</td>
<td><code>df[df['col'] &gt; 5]</code></td>
<td><code>df.filter(pl.col('col') &gt; 5)</code></td>
</tr>
<tr class="odd">
<td>Sort</td>
<td><code>df.sort_values('col')</code></td>
<td><code>df.sort('col')</code></td>
</tr>
<tr class="even">
<td>Group by</td>
<td><code>df.groupby('col').agg({'val': 'sum'})</code></td>
<td><code>df.group_by('col').agg(pl.col('val').sum())</code></td>
</tr>
<tr class="odd">
<td>Join</td>
<td><code>df1.merge(df2, on='key')</code></td>
<td><code>df1.join(df2, on='key')</code></td>
</tr>
<tr class="even">
<td>Fill NA</td>
<td><code>df.fillna(0)</code></td>
<td><code>df.fill_null(0)</code></td>
</tr>
<tr class="odd">
<td>Drop NA</td>
<td><code>df.dropna()</code></td>
<td><code>df.drop_nulls()</code></td>
</tr>
<tr class="even">
<td>Rename</td>
<td><code>df.rename(columns={'a': 'b'})</code></td>
<td><code>df.rename({'a': 'b'})</code></td>
</tr>
<tr class="odd">
<td>Unique values</td>
<td><code>df['col'].unique()</code></td>
<td><code>df.select(pl.col('col').unique())</code></td>
</tr>
<tr class="even">
<td>Value counts</td>
<td><code>df['col'].value_counts()</code></td>
<td><code>df['col'].value_counts()</code></td>
</tr>
</tbody>
</table>
<section id="key-tips-for-migration" class="level3">
<h3 class="anchored" data-anchor-id="key-tips-for-migration" id="key-tips-for-migration">Key Tips for Migration</h3>
<ol type="1">
<li><strong>Think in expressions</strong>: Use <code>pl.col()</code> to reference columns in operations</li>
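<li><strong>Embrace method chaining</strong>: Chain operations together instead of storing intermediate variables</li>
<li><strong>Try lazy execution</strong>: For multi-step queries, start from <code>pl.scan_csv()</code> (or call <code>.lazy()</code> on a DataFrame) and finish with <code>.collect()</code> so Polars can optimize the whole query</li>
<li><strong>Use <code>with_columns()</code></strong>: Instead of direct assignment, use <code>with_columns()</code> to add or modify columns</li>
<li><strong>Learn the expression namespaces</strong>: Many operations use different names and syntax; for example, string methods live under <code>.str</code> (pandas <code>.str.upper()</code> vs.&nbsp;Polars <code>.str.to_uppercase()</code>) and date methods under <code>.dt</code></li>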
</ol>
</section>
<section id="when-to-keep-using-pandas" class="level3">
<h3 class="anchored" data-anchor-id="when-to-keep-using-pandas" id="when-to-keep-using-pandas">When to Keep Using Pandas</h3>
<p>Despite Polars’ advantages, pandas might still be preferred when:</p>
<ul>
<li>Working with existing codebases heavily dependent on pandas</li>
<li>Using specialized libraries that only support pandas (some visualization tools)</li>
<li>Dealing with very small datasets where performance isn’t critical</li>
<li>Using pandas-specific features without Polars equivalents</li>
<li>Working with time series code that relies on pandas-specific tools such as <code>DatetimeIndex</code>, <code>resample()</code>, and frequency offsets</li>
</ul>
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>Polars offers significant performance improvements and a more consistent API compared to pandas, particularly for large datasets and complex operations. While the syntax differences require some adjustment, the benefits in speed and memory efficiency make it a compelling choice for modern data analysis workflows.</p>
<p>Both libraries have their place in the Python data ecosystem. Pandas remains the more mature option with broader ecosystem compatibility, while Polars is a strong option when performance and memory efficiency matter. For new projects that process large datasets, Polars is well worth evaluating.</p>


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[PyTorch Lightning: A Comprehensive Guide]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/model-training/pytorch-lightning/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/model-training/pytorch-lightning/</guid>
      <pubDate>Sat, 29 Mar 2025 00:00:00 GMT</pubDate>
      
      <category>code</category>
      <category>tutorial</category>
      <category>advanced</category>
      <content:encoded><![CDATA[






<section id="pytorch-lightning-a-comprehensive-guide" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/model-training/pytorch-lightning/pytl.png" class="img-fluid"></p>
<p>PyTorch Lightning is a lightweight wrapper for PyTorch that helps organize code and reduce boilerplate while adding powerful features for research and production. This guide walks you from the basics to advanced techniques.</p>
<section id="introduction-to-pytorch-lightning" class="level2">
<h2 class="anchored" data-anchor-id="introduction-to-pytorch-lightning" id="introduction-to-pytorch-lightning">Introduction to PyTorch Lightning</h2>
<p>PyTorch Lightning separates research code from engineering code, making models more:</p>
<ul>
<li><strong>Reproducible</strong>: The same code works across different hardware</li>
<li><strong>Readable</strong>: Standard project structure makes collaboration easier</li>
<li><strong>Scalable</strong>: Train on CPUs, GPUs, TPUs, or multi-node clusters by changing <code>Trainer</code> arguments rather than model code</li>
</ul>
<p>Lightning helps you focus on the science by handling the engineering details.</p>
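<p>To make "engineering code" concrete, here is a minimal plain-PyTorch training loop on random toy data (an assumption made only so the sketch is self-contained). The loss computation in the middle corresponds to what a Lightning <code>training_step</code> contains. The surrounding code handles device placement, train mode, the epoch and batch loops, gradient zeroing, the backward pass, and optimizer steps, and this is the kind of loop code that Lightning's <code>Trainer</code> takes over.</p>

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Toy stand-in for MNIST: 256 random 1x28x28 "images" with labels in 0..9
x = torch.randn(256, 1, 28, 28)
y = torch.randint(0, 10, (256,))
loader = DataLoader(TensorDataset(x, y), batch_size=32, shuffle=True)

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 10))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

epoch_losses = []
for epoch in range(5):
    model.train()
    running = 0.0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)  # engineering: device placement
        optimizer.zero_grad()                  # engineering: reset gradients
        loss = F.cross_entropy(model(xb), yb)  # research: the training_step logic
        loss.backward()                        # engineering: backpropagation
        optimizer.step()                       # engineering: parameter update
        running += loss.item() * xb.size(0)
    epoch_losses.append(running / len(loader.dataset))
    print(f"epoch {epoch}: loss {epoch_losses[-1]:.4f}")
```

<p>With Lightning, only the loss computation and the optimizer definition stay in your module; the loop itself is driven by <code>Trainer.fit()</code>.</p>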
</section>
<section id="installation" class="level2">
<h2 class="anchored" data-anchor-id="installation" id="installation">Installation</h2>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install pytorch-lightning</span></code></pre></div></div>
<p>For the latest features, you can install from the source:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb2"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install git+https://github.com/Lightning-AI/lightning.git</span></code></pre></div></div>
</section>
<section id="basic-structure-the-lightningmodule" class="level2">
<h2 class="anchored" data-anchor-id="basic-structure-the-lightningmodule" id="basic-structure-the-lightningmodule">Basic Structure: The LightningModule</h2>
<p>The core component in Lightning is the <code>LightningModule</code>, which organizes your PyTorch code into a standardized structure:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb3"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> pytorch_lightning <span class="im">as</span> pl</span>
<span id="cb3-2"><a href="#cb3-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb3-3"><a href="#cb3-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb3-4"><a href="#cb3-4" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn.functional <span class="im">as</span> F</span>
<span id="cb3-5"><a href="#cb3-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MNISTModel(pl.LightningModule):</span>
<span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, lr<span class="op">=</span><span class="fl">0.001</span>):</span>
<span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.save_hyperparameters()  <span class="co"># Saves lr to self.hparams (and to checkpoints)</span></span>
<span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.layer_1 <span class="op">=</span> nn.Linear(<span class="dv">28</span> <span class="op">*</span> <span class="dv">28</span>, <span class="dv">128</span>)</span>
<span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.layer_2 <span class="op">=</span> nn.Linear(<span class="dv">128</span>, <span class="dv">10</span>)</span>
<span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.lr <span class="op">=</span> lr</span>
<span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb3-15"><a href="#cb3-15" aria-hidden="true" tabindex="-1"></a>        batch_size, channels, height, width <span class="op">=</span> x.size()  <span class="co"># images are (N, C, H, W)</span></span>
<span id="cb3-16"><a href="#cb3-16" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> x.view(batch_size, <span class="op">-</span><span class="dv">1</span>)  <span class="co"># flatten to (N, 28 * 28)</span></span>
<span id="cb3-17"><a href="#cb3-17" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.layer_1(x)</span>
<span id="cb3-18"><a href="#cb3-18" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> F.relu(x)</span>
<span id="cb3-19"><a href="#cb3-19" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.layer_2(x)</span>
<span id="cb3-20"><a href="#cb3-20" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x</span>
<span id="cb3-21"><a href="#cb3-21" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-22"><a href="#cb3-22" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> training_step(<span class="va">self</span>, batch, batch_idx):</span>
<span id="cb3-23"><a href="#cb3-23" aria-hidden="true" tabindex="-1"></a>        x, y <span class="op">=</span> batch</span>
<span id="cb3-24"><a href="#cb3-24" aria-hidden="true" tabindex="-1"></a>        logits <span class="op">=</span> <span class="va">self</span>(x)</span>
<span id="cb3-25"><a href="#cb3-25" aria-hidden="true" tabindex="-1"></a>        loss <span class="op">=</span> F.cross_entropy(logits, y)</span>
<span id="cb3-26"><a href="#cb3-26" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.log(<span class="st">'train_loss'</span>, loss)</span>
<span id="cb3-27"><a href="#cb3-27" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> loss</span>
<span id="cb3-28"><a href="#cb3-28" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-29"><a href="#cb3-29" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> validation_step(<span class="va">self</span>, batch, batch_idx):</span>
<span id="cb3-30"><a href="#cb3-30" aria-hidden="true" tabindex="-1"></a>        x, y <span class="op">=</span> batch</span>
<span id="cb3-31"><a href="#cb3-31" aria-hidden="true" tabindex="-1"></a>        logits <span class="op">=</span> <span class="va">self</span>(x)</span>
<span id="cb3-32"><a href="#cb3-32" aria-hidden="true" tabindex="-1"></a>        loss <span class="op">=</span> F.cross_entropy(logits, y)</span>
<span id="cb3-33"><a href="#cb3-33" aria-hidden="true" tabindex="-1"></a>        preds <span class="op">=</span> torch.argmax(logits, dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb3-34"><a href="#cb3-34" aria-hidden="true" tabindex="-1"></a>        acc <span class="op">=</span> (preds <span class="op">==</span> y).<span class="bu">float</span>().mean()</span>
<span id="cb3-35"><a href="#cb3-35" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.log(<span class="st">'val_loss'</span>, loss)</span>
<span id="cb3-36"><a href="#cb3-36" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.log(<span class="st">'val_acc'</span>, acc)</span>
<span id="cb3-37"><a href="#cb3-37" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-38"><a href="#cb3-38" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> test_step(<span class="va">self</span>, batch, batch_idx):</span>
<span id="cb3-39"><a href="#cb3-39" aria-hidden="true" tabindex="-1"></a>        x, y <span class="op">=</span> batch</span>
<span id="cb3-40"><a href="#cb3-40" aria-hidden="true" tabindex="-1"></a>        logits <span class="op">=</span> <span class="va">self</span>(x)</span>
<span id="cb3-41"><a href="#cb3-41" aria-hidden="true" tabindex="-1"></a>        loss <span class="op">=</span> F.cross_entropy(logits, y)</span>
<span id="cb3-42"><a href="#cb3-42" aria-hidden="true" tabindex="-1"></a>        preds <span class="op">=</span> torch.argmax(logits, dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb3-43"><a href="#cb3-43" aria-hidden="true" tabindex="-1"></a>        acc <span class="op">=</span> (preds <span class="op">==</span> y).<span class="bu">float</span>().mean()</span>
<span id="cb3-44"><a href="#cb3-44" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.log(<span class="st">'test_loss'</span>, loss)</span>
<span id="cb3-45"><a href="#cb3-45" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.log(<span class="st">'test_acc'</span>, acc)</span>
<span id="cb3-46"><a href="#cb3-46" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb3-47"><a href="#cb3-47" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> configure_optimizers(<span class="va">self</span>):</span>
<span id="cb3-48"><a href="#cb3-48" aria-hidden="true" tabindex="-1"></a>        optimizer <span class="op">=</span> torch.optim.Adam(<span class="va">self</span>.parameters(), lr<span class="op">=</span><span class="va">self</span>.lr)</span>
<span id="cb3-49"><a href="#cb3-49" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> optimizer</span></code></pre></div></div>
<p>This basic structure includes:</p>
<ul>
<li>Model architecture (<code>__init__</code> and <code>forward</code>)</li>
<li>Training logic (<code>training_step</code>)</li>
<li>Validation logic (<code>validation_step</code>)</li>
<li>Test logic (<code>test_step</code>)</li>
<li>Optimization setup (<code>configure_optimizers</code>)</li>
</ul>
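<p>To make the accuracy computation in <code>validation_step</code> and <code>test_step</code> concrete, here is the same <code>argmax</code>-then-compare logic applied to a tiny hand-made batch. The logits and labels are invented for illustration, and use 3 classes instead of MNIST’s 10:</p>

```python
import torch

# Hypothetical logits for 4 samples over 3 classes, plus their true labels
logits = torch.tensor([
    [2.0, 0.1, -1.0],
    [0.3, 0.2, 1.5],
    [0.0, 3.0, 0.5],
    [1.0, 0.9, 0.8],
])
y = torch.tensor([0, 2, 2, 0])

# Same steps as validation_step: predicted class = index of the largest logit
preds = torch.argmax(logits, dim=1)   # tensor([0, 2, 1, 0])
acc = (preds == y).float().mean()     # 3 of 4 correct -> 0.75
```

<p>The third sample is misclassified (predicted class 1, true class 2), so the batch accuracy is 0.75.</p>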
</section>
<section id="datamodules" class="level2">
<h2 class="anchored" data-anchor-id="datamodules" id="datamodules">DataModules</h2>
<p>Lightning’s <code>LightningDataModule</code> encapsulates all data-related logic:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb4"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> pytorch_lightning <span class="im">import</span> LightningDataModule</span>
<span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torch.utils.data <span class="im">import</span> DataLoader, random_split</span>
<span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torchvision <span class="im">import</span> transforms</span>
<span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> torchvision.datasets <span class="im">import</span> MNIST</span>
<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> MNISTDataModule(LightningDataModule):</span>
<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, data_dir<span class="op">=</span><span class="st">'./data'</span>, batch_size<span class="op">=</span><span class="dv">32</span>):</span>
<span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.data_dir <span class="op">=</span> data_dir</span>
<span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.batch_size <span class="op">=</span> batch_size</span>
<span id="cb4-11"><a href="#cb4-11" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.transform <span class="op">=</span> transforms.Compose([</span>
<span id="cb4-12"><a href="#cb4-12" aria-hidden="true" tabindex="-1"></a>            transforms.ToTensor(),</span>
<span id="cb4-13"><a href="#cb4-13" aria-hidden="true" tabindex="-1"></a>            transforms.Normalize((<span class="fl">0.1307</span>,), (<span class="fl">0.3081</span>,))</span>
<span id="cb4-14"><a href="#cb4-14" aria-hidden="true" tabindex="-1"></a>        ])</span>
<span id="cb4-15"><a href="#cb4-15" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-16"><a href="#cb4-16" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> prepare_data(<span class="va">self</span>):</span>
<span id="cb4-17"><a href="#cb4-17" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Download data if needed</span></span>
<span id="cb4-18"><a href="#cb4-18" aria-hidden="true" tabindex="-1"></a>        MNIST(<span class="va">self</span>.data_dir, train<span class="op">=</span><span class="va">True</span>, download<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb4-19"><a href="#cb4-19" aria-hidden="true" tabindex="-1"></a>        MNIST(<span class="va">self</span>.data_dir, train<span class="op">=</span><span class="va">False</span>, download<span class="op">=</span><span class="va">True</span>)</span>
<span id="cb4-20"><a href="#cb4-20" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-21"><a href="#cb4-21" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> setup(<span class="va">self</span>, stage<span class="op">=</span><span class="va">None</span>):</span>
<span id="cb4-22"><a href="#cb4-22" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Assign train/val/test datasets</span></span>
<span id="cb4-23"><a href="#cb4-23" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> stage <span class="op">==</span> <span class="st">'fit'</span> <span class="kw">or</span> stage <span class="kw">is</span> <span class="va">None</span>:</span>
<span id="cb4-24"><a href="#cb4-24" aria-hidden="true" tabindex="-1"></a>            mnist_full <span class="op">=</span> MNIST(<span class="va">self</span>.data_dir, train<span class="op">=</span><span class="va">True</span>, transform<span class="op">=</span><span class="va">self</span>.transform)</span>
<span id="cb4-25"><a href="#cb4-25" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.mnist_train, <span class="va">self</span>.mnist_val <span class="op">=</span> random_split(mnist_full, [<span class="dv">55000</span>, <span class="dv">5000</span>])</span>
<span id="cb4-26"><a href="#cb4-26" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb4-27"><a href="#cb4-27" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> stage <span class="op">==</span> <span class="st">'test'</span> <span class="kw">or</span> stage <span class="kw">is</span> <span class="va">None</span>:</span>
<span id="cb4-28"><a href="#cb4-28" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.mnist_test <span class="op">=</span> MNIST(<span class="va">self</span>.data_dir, train<span class="op">=</span><span class="va">False</span>, transform<span class="op">=</span><span class="va">self</span>.transform)</span>
<span id="cb4-29"><a href="#cb4-29" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb4-30"><a href="#cb4-30" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> train_dataloader(<span class="va">self</span>):</span>
<span id="cb4-31"><a href="#cb4-31" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> DataLoader(<span class="va">self</span>.mnist_train, batch_size<span class="op">=</span><span class="va">self</span>.batch_size, shuffle<span class="op">=</span><span class="va">True</span>)  <span class="co"># reshuffle every epoch</span></span>
<span id="cb4-32"><a href="#cb4-32" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-33"><a href="#cb4-33" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> val_dataloader(<span class="va">self</span>):</span>
<span id="cb4-34"><a href="#cb4-34" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> DataLoader(<span class="va">self</span>.mnist_val, batch_size<span class="op">=</span><span class="va">self</span>.batch_size)</span>
<span id="cb4-35"><a href="#cb4-35" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb4-36"><a href="#cb4-36" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> test_dataloader(<span class="va">self</span>):</span>
<span id="cb4-37"><a href="#cb4-37" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> DataLoader(<span class="va">self</span>.mnist_test, batch_size<span class="op">=</span><span class="va">self</span>.batch_size)</span></code></pre></div></div>
<p>Benefits of <code>LightningDataModule</code>:</p>
<ul>
<li>Encapsulates all data preparation logic</li>
<li>Makes the data pipeline portable and reproducible</li>
<li>Simplifies sharing data pipelines between projects</li>
</ul>
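<p>One caveat on reproducibility: the <code>random_split</code> call in <code>setup</code> draws from PyTorch’s global random state, so the train/validation split can change between runs unless a seed is fixed. This matters even more under multi-process training, where <code>setup</code> runs once per process. A minimal sketch of a seeded split follows; it uses <code>range(60_000)</code> as a stand-in for the 60,000-image MNIST training set, and the seed value 42 is an arbitrary assumption:</p>

```python
import torch
from torch.utils.data import random_split

# Stand-in for the MNIST training set: only the sample indices matter here
full_train = range(60_000)

def split_train_val(seed):
    # A dedicated, seeded generator makes the split identical on every call
    generator = torch.Generator().manual_seed(seed)
    return random_split(full_train, [55_000, 5_000], generator=generator)

train_a, val_a = split_train_val(seed=42)
train_b, val_b = split_train_val(seed=42)
```

<p>Inside the DataModule, the same idea would mean passing <code>generator=torch.Generator().manual_seed(...)</code> to <code>random_split</code> in <code>setup</code> (which also needs <code>import torch</code>); calling <code>pytorch_lightning.seed_everything</code> before training is another common option.</p>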
</section>
<section id="training-with-trainer" class="level2">
<h2 class="anchored" data-anchor-id="training-with-trainer" id="training-with-trainer">Training with Trainer</h2>
<p>The Lightning <code>Trainer</code> runs the training, validation, and test loops:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb5"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> pytorch_lightning <span class="im">import</span> Trainer</span>
<span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Create model and data module</span></span>
<span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> MNISTModel()</span>
<span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a>data_module <span class="op">=</span> MNISTDataModule()</span>
<span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Initialize trainer</span></span>
<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a>trainer <span class="op">=</span> Trainer(</span>
<span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a>    max_epochs<span class="op">=</span><span class="dv">10</span>,</span>
<span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a>    accelerator<span class="op">=</span><span class="st">"auto"</span>,  <span class="co"># Pick the best available accelerator (e.g. GPU), else CPU</span></span>
<span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a>    devices<span class="op">=</span><span class="st">"auto"</span>,</span>
<span id="cb5-12"><a href="#cb5-12" aria-hidden="true" tabindex="-1"></a>    logger<span class="op">=</span><span class="va">True</span>,         <span class="co"># TensorBoardLogger if tensorboard is installed, else CSVLogger</span></span>
<span id="cb5-13"><a href="#cb5-13" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb5-14"><a href="#cb5-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-15"><a href="#cb5-15" aria-hidden="true" tabindex="-1"></a><span class="co"># Train model</span></span>
<span id="cb5-16"><a href="#cb5-16" aria-hidden="true" tabindex="-1"></a>trainer.fit(model, datamodule<span class="op">=</span>data_module)</span>
<span id="cb5-17"><a href="#cb5-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb5-18"><a href="#cb5-18" aria-hidden="true" tabindex="-1"></a><span class="co"># Test model</span></span>
<span id="cb5-19"><a href="#cb5-19" aria-hidden="true" tabindex="-1"></a>trainer.test(model, datamodule<span class="op">=</span>data_module)</span></code></pre></div></div>
<p>The <code>Trainer</code> automatically handles:</p>
<ul>
<li>Epoch and batch iteration</li>
<li>Optimizer steps</li>
<li>Logging metrics</li>
<li>Hardware acceleration (CPU, GPU, TPU)</li>
<li>Early stopping (once an <code>EarlyStopping</code> callback is added)</li>
<li>Checkpointing (enabled by default)</li>
<li>Multi-GPU and multi-node training</li>
</ul>
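<p>For intuition about what the <code>Trainer</code> saves you from writing, here is a rough plain-PyTorch equivalent of a training loop over a single batch. It is a simplified sketch, not Lightning’s implementation: it uses an <code>nn.Sequential</code> copy of the <code>MNISTModel</code> layers and a random MNIST-shaped batch instead of real data, and it omits validation, logging, and checkpointing:</p>

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Same layers as MNISTModel: flatten -> Linear(784, 128) -> ReLU -> Linear(128, 10)
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 10))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # configure_optimizers()

# Random MNIST-shaped batch (N, C, H, W) with random labels, for illustration only
x = torch.randn(32, 1, 28, 28)
y = torch.randint(0, 10, (32,))

losses = []
model.train()
for step in range(20):
    logits = model(x)                  # forward()
    loss = F.cross_entropy(logits, y)  # body of training_step()
    optimizer.zero_grad()              # this line and the two below are what Trainer automates
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
```

<p>Overfitting a single batch like this, where the loss should drop quickly, is also a cheap sanity check before a full run; Lightning offers the <code>overfit_batches</code> Trainer flag for the same purpose.</p>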
<section id="common-trainer-parameters" class="level3">
<h3 class="anchored" data-anchor-id="common-trainer-parameters" id="common-trainer-parameters">Common Trainer Parameters</h3>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb6"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a>trainer <span class="op">=</span> Trainer(</span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a>    max_epochs<span class="op">=</span><span class="dv">10</span>,                    <span class="co"># Maximum number of epochs</span></span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a>    min_epochs<span class="op">=</span><span class="dv">1</span>,                     <span class="co"># Minimum number of epochs</span></span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a>    accelerator<span class="op">=</span><span class="st">"cpu"</span>,                <span class="co"># Use CPU acceleration</span></span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a>    devices<span class="op">=</span><span class="dv">1</span>,                        <span class="co"># Use a single device (one CPU process)</span></span>
<span id="cb6-6"><a href="#cb6-6" aria-hidden="true" tabindex="-1"></a>    precision<span class="op">=</span><span class="st">"bf16-mixed"</span>,           <span class="co"># bfloat16 mixed precision; "16-mixed" (fp16) is not supported on CPU</span></span>
<span id="cb6-7"><a href="#cb6-7" aria-hidden="true" tabindex="-1"></a>    gradient_clip_val<span class="op">=</span><span class="fl">0.5</span>,            <span class="co"># Clip gradients</span></span>
<span id="cb6-8"><a href="#cb6-8" aria-hidden="true" tabindex="-1"></a>    accumulate_grad_batches<span class="op">=</span><span class="dv">4</span>,        <span class="co"># Accumulate gradients over 4 batches</span></span>
<span id="cb6-9"><a href="#cb6-9" aria-hidden="true" tabindex="-1"></a>    log_every_n_steps<span class="op">=</span><span class="dv">50</span>,             <span class="co"># Log metrics every 50 steps</span></span>
<span id="cb6-10"><a href="#cb6-10" aria-hidden="true" tabindex="-1"></a>    val_check_interval<span class="op">=</span><span class="fl">0.25</span>,          <span class="co"># Run validation 4 times per epoch</span></span>
<span id="cb6-11"><a href="#cb6-11" aria-hidden="true" tabindex="-1"></a>    fast_dev_run<span class="op">=</span><span class="va">False</span>,               <span class="co"># If True, run a single train/val/test batch as a quick smoke test</span></span>
<span id="cb6-12"><a href="#cb6-12" aria-hidden="true" tabindex="-1"></a>    deterministic<span class="op">=</span><span class="va">True</span>,               <span class="co"># Make training deterministic</span></span>
<span id="cb6-13"><a href="#cb6-13" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
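<p>Several of these flags interact. As a back-of-the-envelope sketch (plain arithmetic, not Lightning’s internal code, assuming a single device and the 55,000-image training split with <code>batch_size=32</code> from the DataModule above), here is how <code>accumulate_grad_batches=4</code> and <code>val_check_interval=0.25</code> translate into optimizer steps and validation runs per epoch:</p>

```python
import math

def epoch_schedule(num_samples, batch_size, accumulate_grad_batches, val_check_interval):
    # The last batch may be partial, so round up
    batches_per_epoch = math.ceil(num_samples / batch_size)
    # One optimizer step per `accumulate_grad_batches` batches (last group may be short)
    optimizer_steps = math.ceil(batches_per_epoch / accumulate_grad_batches)
    # Samples contributing to each optimizer step on a single device
    effective_batch_size = batch_size * accumulate_grad_batches
    # A float val_check_interval is a fraction of the training batches in an epoch
    val_every_n_batches = int(batches_per_epoch * val_check_interval)
    val_runs = batches_per_epoch // val_every_n_batches
    return {
        "batches_per_epoch": batches_per_epoch,
        "optimizer_steps": optimizer_steps,
        "effective_batch_size": effective_batch_size,
        "val_every_n_batches": val_every_n_batches,
        "val_runs": val_runs,
    }

schedule = epoch_schedule(55_000, 32, accumulate_grad_batches=4, val_check_interval=0.25)
```

<p>Gradient accumulation is mainly a way to emulate a larger batch (here 128 samples per optimizer step) when a batch of that size does not fit in memory.</p>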
</section>
</section>
<section id="callbacks" class="level2">
<h2 class="anchored" data-anchor-id="callbacks" id="callbacks">Callbacks</h2>
<p>Callbacks add functionality to the training loop without modifying the core code:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb7"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> pytorch_lightning.callbacks <span class="im">import</span> ModelCheckpoint, EarlyStopping, LearningRateMonitor</span>
<span id="cb7-2"><a href="#cb7-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-3"><a href="#cb7-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Save the best model based on validation accuracy</span></span>
<span id="cb7-4"><a href="#cb7-4" aria-hidden="true" tabindex="-1"></a>checkpoint_callback <span class="op">=</span> ModelCheckpoint(</span>
<span id="cb7-5"><a href="#cb7-5" aria-hidden="true" tabindex="-1"></a>    monitor<span class="op">=</span><span class="st">'val_acc'</span>,</span>
<span id="cb7-6"><a href="#cb7-6" aria-hidden="true" tabindex="-1"></a>    dirpath<span class="op">=</span><span class="st">'./checkpoints'</span>,</span>
<span id="cb7-7"><a href="#cb7-7" aria-hidden="true" tabindex="-1"></a>    filename<span class="op">=</span><span class="st">'mnist-</span><span class="sc">{epoch:02d}</span><span class="st">-</span><span class="sc">{val_acc:.2f}</span><span class="st">'</span>,</span>
<span id="cb7-8"><a href="#cb7-8" aria-hidden="true" tabindex="-1"></a>    save_top_k<span class="op">=</span><span class="dv">3</span>,</span>
<span id="cb7-9"><a href="#cb7-9" aria-hidden="true" tabindex="-1"></a>    mode<span class="op">=</span><span class="st">'max'</span>,</span>
<span id="cb7-10"><a href="#cb7-10" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb7-11"><a href="#cb7-11" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-12"><a href="#cb7-12" aria-hidden="true" tabindex="-1"></a><span class="co"># Stop training when validation loss stops improving</span></span>
<span id="cb7-13"><a href="#cb7-13" aria-hidden="true" tabindex="-1"></a>early_stop_callback <span class="op">=</span> EarlyStopping(</span>
<span id="cb7-14"><a href="#cb7-14" aria-hidden="true" tabindex="-1"></a>    monitor<span class="op">=</span><span class="st">'val_loss'</span>,</span>
<span id="cb7-15"><a href="#cb7-15" aria-hidden="true" tabindex="-1"></a>    patience<span class="op">=</span><span class="dv">3</span>,</span>
<span id="cb7-16"><a href="#cb7-16" aria-hidden="true" tabindex="-1"></a>    verbose<span class="op">=</span><span class="va">True</span>,</span>
<span id="cb7-17"><a href="#cb7-17" aria-hidden="true" tabindex="-1"></a>    mode<span class="op">=</span><span class="st">'min'</span></span>
<span id="cb7-18"><a href="#cb7-18" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb7-19"><a href="#cb7-19" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-20"><a href="#cb7-20" aria-hidden="true" tabindex="-1"></a><span class="co"># Monitor learning rate</span></span>
<span id="cb7-21"><a href="#cb7-21" aria-hidden="true" tabindex="-1"></a>lr_monitor <span class="op">=</span> LearningRateMonitor(logging_interval<span class="op">=</span><span class="st">'step'</span>)</span>
<span id="cb7-22"><a href="#cb7-22" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb7-23"><a href="#cb7-23" aria-hidden="true" tabindex="-1"></a><span class="co"># Initialize trainer with callbacks</span></span>
<span id="cb7-24"><a href="#cb7-24" aria-hidden="true" tabindex="-1"></a>trainer <span class="op">=</span> Trainer(</span>
<span id="cb7-25"><a href="#cb7-25" aria-hidden="true" tabindex="-1"></a>    max_epochs<span class="op">=</span><span class="dv">10</span>,</span>
<span id="cb7-26"><a href="#cb7-26" aria-hidden="true" tabindex="-1"></a>    callbacks<span class="op">=</span>[checkpoint_callback, early_stop_callback, lr_monitor]</span>
<span id="cb7-27"><a href="#cb7-27" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
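<p>The decision rules behind the two monitoring callbacks are simple enough to sketch in plain Python. The functions below are simplified models written for illustration and ignore most of the real callbacks’ options; they show how <code>patience=3</code> with <code>mode='min'</code> and <code>save_top_k=3</code> with <code>mode='max'</code> behave on a made-up sequence of per-epoch metrics:</p>

```python
import heapq

def early_stopping_epoch(val_losses, patience=3, min_delta=0.0):
    """Return the epoch after which training stops, or None if it never triggers."""
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if best - loss > min_delta:  # improvement in mode='min'
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None

def top_k_checkpoints(val_accs, k=3):
    """Return the (val_acc, epoch) pairs that would be kept, best first (mode='max')."""
    kept = []  # min-heap: the worst kept checkpoint sits at kept[0]
    for epoch, acc in enumerate(val_accs):
        if len(kept) < k:
            heapq.heappush(kept, (acc, epoch))
        elif acc > kept[0][0]:
            heapq.heapreplace(kept, (acc, epoch))  # evict the worst checkpoint
    return sorted(kept, reverse=True)

# Made-up per-epoch metrics for illustration
stop_at = early_stopping_epoch([0.90, 0.60, 0.50, 0.52, 0.51, 0.55, 0.40])
kept = top_k_checkpoints([0.90, 0.95, 0.93, 0.97, 0.92])
```

<p>With these numbers, validation loss stops improving after epoch 2, so the third non-improving epoch (epoch 5) triggers the stop and the late improvement at epoch 6 is never seen. That is the usual trade-off when choosing <code>patience</code>.</p>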
<section id="custom-callback" class="level3">
<h3 class="anchored" data-anchor-id="custom-callback" id="custom-callback">Custom Callback</h3>
<p>You can create custom callbacks by extending the base <code>Callback</code> class:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb8"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> pytorch_lightning.callbacks <span class="im">import</span> Callback</span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> PrintingCallback(Callback):</span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> on_train_start(<span class="va">self</span>, trainer, pl_module):</span>
<span id="cb8-5"><a href="#cb8-5" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="st">"Training has started!"</span>)</span>
<span id="cb8-6"><a href="#cb8-6" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-7"><a href="#cb8-7" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> on_train_end(<span class="va">self</span>, trainer, pl_module):</span>
<span id="cb8-8"><a href="#cb8-8" aria-hidden="true" tabindex="-1"></a>        <span class="bu">print</span>(<span class="st">"Training has finished!"</span>)</span>
<span id="cb8-9"><a href="#cb8-9" aria-hidden="true" tabindex="-1"></a>        </span>
<span id="cb8-10"><a href="#cb8-10" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> on_train_batch_end(<span class="va">self</span>, trainer, pl_module, outputs, batch, batch_idx):</span>
<span id="cb8-11"><a href="#cb8-11" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> batch_idx <span class="op">%</span> <span class="dv">100</span> <span class="op">==</span> <span class="dv">0</span>:</span>
<span id="cb8-12"><a href="#cb8-12" aria-hidden="true" tabindex="-1"></a>            <span class="bu">print</span>(<span class="ss">f"Batch </span><span class="sc">{</span>batch_idx<span class="sc">}</span><span class="ss"> completed"</span>)</span></code></pre></div></div>
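<p>Conceptually, the <code>Trainer</code> calls each registered callback’s hook methods at fixed points in the loop. The toy dispatcher below is a heavily simplified, hypothetical stand-in (<code>ToyTrainer</code> and <code>RecordingCallback</code> are invented names, not Lightning classes) that shows the calling order of the three hooks used in <code>PrintingCallback</code>. With Lightning itself you would simply pass <code>callbacks=[PrintingCallback()]</code> to <code>Trainer</code>:</p>

```python
class ToyTrainer:
    """Hypothetical mini-trainer that only dispatches callback hooks."""

    def __init__(self, callbacks):
        self.callbacks = callbacks

    def _call_hook(self, name, *args):
        # Call the hook on every callback, in registration order
        for callback in self.callbacks:
            getattr(callback, name)(self, None, *args)  # None stands in for pl_module

    def fit(self, batches):
        self._call_hook("on_train_start")
        for batch_idx, batch in enumerate(batches):
            outputs = sum(batch)  # stand-in for the result of training_step
            self._call_hook("on_train_batch_end", outputs, batch, batch_idx)
        self._call_hook("on_train_end")


class RecordingCallback:
    """Records hook calls instead of printing; same hook signatures as PrintingCallback."""

    def __init__(self):
        self.events = []

    def on_train_start(self, trainer, pl_module):
        self.events.append("start")

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.events.append(f"batch {batch_idx}: {outputs}")

    def on_train_end(self, trainer, pl_module):
        self.events.append("end")


recorder = RecordingCallback()
ToyTrainer(callbacks=[recorder]).fit([[1, 2], [3, 4]])
```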
</section>
</section>
<section id="logging" class="level2">
<h2 class="anchored" data-anchor-id="logging" id="logging">Logging</h2>
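<p>Lightning supports various loggers (for example TensorBoard, CSV, Weights &amp; Biases, and MLflow), and a <code>Trainer</code> can take a list of them:</p>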
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb9"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><a href="#cb9-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> pytorch_lightning.loggers <span class="im">import</span> TensorBoardLogger</span>
<span id="cb9-2"><a href="#cb9-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-3"><a href="#cb9-3" aria-hidden="true" tabindex="-1"></a><span class="co"># TensorBoard logger</span></span>
<span id="cb9-4"><a href="#cb9-4" aria-hidden="true" tabindex="-1"></a>tensorboard_logger <span class="op">=</span> TensorBoardLogger(save_dir<span class="op">=</span><span class="st">"logs/"</span>, name<span class="op">=</span><span class="st">"mnist"</span>)</span>
<span id="cb9-5"><a href="#cb9-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb9-6"><a href="#cb9-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Initialize trainer with loggers</span></span>
<span id="cb9-7"><a href="#cb9-7" aria-hidden="true" tabindex="-1"></a>trainer <span class="op">=</span> Trainer(</span>
<span id="cb9-8"><a href="#cb9-8" aria-hidden="true" tabindex="-1"></a>    max_epochs<span class="op">=</span><span class="dv">10</span>,</span>
<span id="cb9-9"><a href="#cb9-9" aria-hidden="true" tabindex="-1"></a>    logger<span class="op">=</span>[tensorboard_logger]</span>
<span id="cb9-10"><a href="#cb9-10" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
<p>Logging metrics and other values from inside your <code>LightningModule</code>:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb10"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><a href="#cb10-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> training_step(<span class="va">self</span>, batch, batch_idx):</span>
<span id="cb10-2"><a href="#cb10-2" aria-hidden="true" tabindex="-1"></a>    x, y <span class="op">=</span> batch</span>
<span id="cb10-3"><a href="#cb10-3" aria-hidden="true" tabindex="-1"></a>    logits <span class="op">=</span> <span class="va">self</span>(x)</span>
<span id="cb10-4"><a href="#cb10-4" aria-hidden="true" tabindex="-1"></a>    loss <span class="op">=</span> F.cross_entropy(logits, y)</span>
<span id="cb10-5"><a href="#cb10-5" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-6"><a href="#cb10-6" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Log scalar values</span></span>
<span id="cb10-7"><a href="#cb10-7" aria-hidden="true" tabindex="-1"></a>    <span class="va">self</span>.log(<span class="st">'train_loss'</span>, loss)</span>
<span id="cb10-8"><a href="#cb10-8" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-9"><a href="#cb10-9" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Log learning rate</span></span>
<span id="cb10-10"><a href="#cb10-10" aria-hidden="true" tabindex="-1"></a>    lr <span class="op">=</span> <span class="va">self</span>.optimizers().param_groups[<span class="dv">0</span>][<span class="st">'lr'</span>]</span>
<span id="cb10-11"><a href="#cb10-11" aria-hidden="true" tabindex="-1"></a>    <span class="va">self</span>.log(<span class="st">'learning_rate'</span>, lr)</span>
<span id="cb10-12"><a href="#cb10-12" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-13"><a href="#cb10-13" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Log a histogram every 100 batches (TensorBoardLogger only: experiment is a SummaryWriter)</span></span>
<span id="cb10-14"><a href="#cb10-14" aria-hidden="true" tabindex="-1"></a>    <span class="cf">if</span> batch_idx <span class="op">%</span> <span class="dv">100</span> <span class="op">==</span> <span class="dv">0</span>:</span>
<span id="cb10-15"><a href="#cb10-15" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.logger.experiment.add_histogram(<span class="st">'logits'</span>, logits, <span class="va">self</span>.global_step)</span>
<span id="cb10-16"><a href="#cb10-16" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb10-17"><a href="#cb10-17" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> loss</span></code></pre></div></div>
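<p>By default, <code>self.log</code> inside <code>training_step</code> records per-step values, while inside <code>validation_step</code> and <code>test_step</code> it defaults to <code>on_epoch=True</code> and reduces the per-batch values to one number per epoch. That reduction is a mean weighted by batch size (inferred from the batch, or passed through the <code>batch_size</code> argument of <code>self.log</code>), which matters when the last batch is smaller. A plain-Python sketch of the arithmetic, not Lightning’s code: the loss values are made up, and the batch sizes correspond to the 5,000-image validation split with <code>batch_size=32</code>:</p>

```python
def weighted_epoch_mean(batch_values, batch_sizes):
    # Each batch's mean value counts in proportion to the samples it contained
    total = sum(value * size for value, size in zip(batch_values, batch_sizes))
    return total / sum(batch_sizes)

# 5,000 validation images with batch_size=32 -> 156 full batches + 1 batch of 8
batch_sizes = [32] * 156 + [8]
# Made-up per-batch losses: 0.10 on every full batch, 1.00 on the small last batch
batch_values = [0.10] * 156 + [1.00]

weighted = weighted_epoch_mean(batch_values, batch_sizes)
unweighted = sum(batch_values) / len(batch_values)
```

<p>The weighted mean (about 0.1014) reflects that only 8 of the 5,000 images had the high loss; a plain average over batches (about 0.1057) would overstate it.</p>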
</section>
<section id="distributed-training" class="level2">
<h2 class="anchored" data-anchor-id="distributed-training" id="distributed-training">Distributed Training</h2>
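<p>Scaling from one device to many usually only changes the <code>Trainer</code> arguments; the <code>LightningModule</code> and <code>DataModule</code> stay the same:</p>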
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb11"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><a href="#cb11-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Single GPU</span></span>
<span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a>trainer <span class="op">=</span> Trainer(accelerator<span class="op">=</span><span class="st">"gpu"</span>, devices<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Multiple GPUs (DDP strategy automatically selected)</span></span>
<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a>trainer <span class="op">=</span> Trainer(accelerator<span class="op">=</span><span class="st">"gpu"</span>, devices<span class="op">=</span><span class="dv">4</span>)</span>
<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a><span class="co"># Specific GPU indices</span></span>
<span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a>trainer <span class="op">=</span> Trainer(accelerator<span class="op">=</span><span class="st">"gpu"</span>, devices<span class="op">=</span>[<span class="dv">0</span>, <span class="dv">1</span>, <span class="dv">3</span>])</span>
<span id="cb11-9"><a href="#cb11-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-10"><a href="#cb11-10" aria-hidden="true" tabindex="-1"></a><span class="co"># TPU cores</span></span>
<span id="cb11-11"><a href="#cb11-11" aria-hidden="true" tabindex="-1"></a>trainer <span class="op">=</span> Trainer(accelerator<span class="op">=</span><span class="st">"tpu"</span>, devices<span class="op">=</span><span class="dv">8</span>)</span>
<span id="cb11-12"><a href="#cb11-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb11-13"><a href="#cb11-13" aria-hidden="true" tabindex="-1"></a><span class="co"># Advanced distributed training</span></span>
<span id="cb11-14"><a href="#cb11-14" aria-hidden="true" tabindex="-1"></a>trainer <span class="op">=</span> Trainer(</span>
<span id="cb11-15"><a href="#cb11-15" aria-hidden="true" tabindex="-1"></a>    accelerator<span class="op">=</span><span class="st">"gpu"</span>,</span>
<span id="cb11-16"><a href="#cb11-16" aria-hidden="true" tabindex="-1"></a>    devices<span class="op">=</span><span class="dv">4</span>,</span>
<span id="cb11-17"><a href="#cb11-17" aria-hidden="true" tabindex="-1"></a>    strategy<span class="op">=</span><span class="st">"ddp"</span>,  <span class="co"># Distributed Data Parallel</span></span>
<span id="cb11-18"><a href="#cb11-18" aria-hidden="true" tabindex="-1"></a>    num_nodes<span class="op">=</span><span class="dv">2</span>      <span class="co"># 2 machines × 4 GPUs = 8 processes in total</span></span>
<span id="cb11-19"><a href="#cb11-19" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
<p>Other distribution strategies:</p>
<ul>
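<li><code>ddp_spawn</code>: like <code>ddp</code>, but starts the worker processes with <code>torch.multiprocessing.spawn</code>; usually slower, and everything sent to the workers must be picklable</li>
<li><code>deepspeed</code>: integrates Microsoft's DeepSpeed library (ZeRO sharding of optimizer state, gradients, and parameters) for very large models; requires the <code>deepspeed</code> package</li>
<li><code>fsdp</code>: PyTorch's Fully Sharded Data Parallel, which shards parameters, gradients, and optimizer state across GPUs so that models too large for a single GPU can be trained</li>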
</ul>
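<p>Under DDP every process holds a full copy of the model and trains on its own slice of each epoch's data, and gradients are averaged across processes before each optimizer step. Lightning inserts a <code>DistributedSampler</code> for this by default, so the effective batch size is the per-device batch size × <code>devices</code> × <code>num_nodes</code>. The pure-Python sketch below imitates how <code>DistributedSampler</code> (with <code>shuffle=False</code>) splits indices across processes; <code>shard_indices</code> is a hypothetical helper written for illustration, not PyTorch's code, and the per-device batch size of 64 is an assumed value:</p>

```python
import math


def shard_indices(dataset_len, world_size, rank):
    """Indices that process `rank` would read, DistributedSampler-style.

    Mimics shuffle=False, drop_last=False: the index list is padded by
    repeating indices from the start so it splits evenly, then each rank
    takes every world_size-th index starting at its own rank.
    """
    indices = list(range(dataset_len))
    total_size = math.ceil(dataset_len / world_size) * world_size
    padding = total_size - dataset_len
    if padding:
        indices += (indices * math.ceil(padding / dataset_len))[:padding]
    return indices[rank:total_size:world_size]


# Effective batch size for devices=4, num_nodes=2 (assumed per-device batch of 64)
devices, num_nodes, per_device_batch = 4, 2, 64
world_size = devices * num_nodes  # 8 processes
print("effective batch size:", per_device_batch * world_size)  # 512

# How 10 samples are split across 4 processes
num_processes = 4
for rank in range(num_processes):
    print(rank, shard_indices(10, world_size=num_processes, rank=rank))
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6, 0]
# 3 [3, 7, 1]
```

<p>Because the index list is padded to divide evenly, ranks 2 and 3 receive samples 0 and 1 a second time; when a dataset does not split evenly, a few samples are seen twice per epoch.</p>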
</section>
<section id="hyperparameter-tuning" class="level2">
<h2 class="anchored" data-anchor-id="hyperparameter-tuning" id="hyperparameter-tuning">Hyperparameter Tuning</h2>
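<p>Lightning's built-in <code>Tuner</code> can find the largest batch size that fits in memory and suggest an initial learning rate before training. <code>scale_batch_size</code> needs a <code>batch_size</code> attribute on the model or data module, and <code>lr_find</code> needs an <code>lr</code> or <code>learning_rate</code> attribute on the model:</p>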
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb12"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><a href="#cb12-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> pytorch_lightning <span class="im">import</span> Trainer</span>
<span id="cb12-2"><a href="#cb12-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> pytorch_lightning.tuner <span class="im">import</span> Tuner</span>
<span id="cb12-3"><a href="#cb12-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-4"><a href="#cb12-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Create model and data module</span></span>
<span id="cb12-5"><a href="#cb12-5" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> MNISTModel()</span>
<span id="cb12-6"><a href="#cb12-6" aria-hidden="true" tabindex="-1"></a>data_module <span class="op">=</span> MNISTDataModule()</span>
<span id="cb12-7"><a href="#cb12-7" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-8"><a href="#cb12-8" aria-hidden="true" tabindex="-1"></a><span class="co"># Create trainer</span></span>
<span id="cb12-9"><a href="#cb12-9" aria-hidden="true" tabindex="-1"></a>trainer <span class="op">=</span> Trainer(max_epochs<span class="op">=</span><span class="dv">10</span>)</span>
<span id="cb12-10"><a href="#cb12-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-11"><a href="#cb12-11" aria-hidden="true" tabindex="-1"></a><span class="co"># Find the largest batch size that fits in memory (updates batch_size)</span></span>
<span id="cb12-12"><a href="#cb12-12" aria-hidden="true" tabindex="-1"></a>tuner <span class="op">=</span> Tuner(trainer)</span>
<span id="cb12-13"><a href="#cb12-13" aria-hidden="true" tabindex="-1"></a>tuner.scale_batch_size(model, datamodule<span class="op">=</span>data_module)</span>
<span id="cb12-14"><a href="#cb12-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-15"><a href="#cb12-15" aria-hidden="true" tabindex="-1"></a><span class="co"># Run a learning-rate range test and take its suggestion</span></span>
<span id="cb12-16"><a href="#cb12-16" aria-hidden="true" tabindex="-1"></a>lr_finder <span class="op">=</span> tuner.lr_find(model, datamodule<span class="op">=</span>data_module)</span>
<span id="cb12-17"><a href="#cb12-17" aria-hidden="true" tabindex="-1"></a>model.learning_rate <span class="op">=</span> lr_finder.suggestion()</span>
<span id="cb12-18"><a href="#cb12-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb12-19"><a href="#cb12-19" aria-hidden="true" tabindex="-1"></a><span class="co"># Train with optimized parameters</span></span>
<span id="cb12-20"><a href="#cb12-20" aria-hidden="true" tabindex="-1"></a>trainer.fit(model, datamodule<span class="op">=</span>data_module)</span></code></pre></div></div>
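<p>Conceptually, <code>lr_find</code> runs a short learning-rate range test: it trains for a fixed number of steps while increasing the learning rate exponentially, records the loss at each step, and suggests a value where the loss is falling fastest. (<code>scale_batch_size</code> works in a similar trial-and-error spirit, by default doubling the batch size until it runs out of memory.) A simplified pure-Python sketch of the suggestion rule follows; <code>exponential_lrs</code> and <code>suggest_lr</code> are hypothetical helpers, Lightning's own implementation smooths the loss and uses a slightly different gradient estimate, and the loss values are invented for the example:</p>

```python
def exponential_lrs(min_lr=1e-8, max_lr=1.0, num_steps=100):
    """Learning rates for an LR range test, evenly spaced on a log scale."""
    ratio = (max_lr / min_lr) ** (1 / (num_steps - 1))
    return [min_lr * ratio**i for i in range(num_steps)]


def suggest_lr(lrs, losses, skip_begin=0, skip_end=0):
    """Return the LR at which the loss dropped most steeply.

    Uses simple backward differences between consecutive steps; skip_begin
    and skip_end ignore noisy points at either end of the sweep.
    """
    best_i, best_slope = None, float("inf")
    for i in range(skip_begin + 1, len(losses) - skip_end):
        slope = losses[i] - losses[i - 1]
        if slope < best_slope:
            best_i, best_slope = i, slope
    return lrs[best_i]


# Toy sweep: 5 LRs from 1e-5 to 1e-1 and invented losses (not real measurements)
lrs = exponential_lrs(min_lr=1e-5, max_lr=1e-1, num_steps=5)
losses = [2.30, 2.25, 1.60, 1.20, 4.00]  # the loss diverges at the largest LR
print(suggest_lr(lrs, losses))  # approximately 1e-3
```

<p>The suggestion is a starting point rather than a guarantee; plotting the curve with <code>lr_finder.plot(suggest=True)</code> before trusting it is a sensible habit.</p>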
</section>
<section id="model-checkpointing" class="level2">
<h2 class="anchored" data-anchor-id="model-checkpointing" id="model-checkpointing">Model Checkpointing</h2>
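<p>The <code>ModelCheckpoint</code> callback saves checkpoints during training, and <code>load_from_checkpoint</code> loads them back into a model:</p>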
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb13"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><a href="#cb13-1" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> pytorch_lightning <span class="im">import</span> Trainer</span>
<span id="cb13-2"><a href="#cb13-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> pytorch_lightning.callbacks <span class="im">import</span> ModelCheckpoint</span>
<span id="cb13-3"><a href="#cb13-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-4"><a href="#cb13-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Keep the 3 checkpoints with the lowest validation loss</span></span>
<span id="cb13-5"><a href="#cb13-5" aria-hidden="true" tabindex="-1"></a>checkpoint_callback <span class="op">=</span> ModelCheckpoint(</span>
<span id="cb13-6"><a href="#cb13-6" aria-hidden="true" tabindex="-1"></a>    dirpath<span class="op">=</span><span class="st">'./checkpoints'</span>,</span>
<span id="cb13-7"><a href="#cb13-7" aria-hidden="true" tabindex="-1"></a>    filename<span class="op">=</span><span class="st">'</span><span class="sc">{epoch}</span><span class="st">-</span><span class="sc">{val_loss:.2f}</span><span class="st">'</span>,  <span class="co"># e.g. epoch=3-val_loss=0.40.ckpt</span></span>
<span id="cb13-8"><a href="#cb13-8" aria-hidden="true" tabindex="-1"></a>    monitor<span class="op">=</span><span class="st">'val_loss'</span>,  <span class="co"># must be logged with self.log() in validation_step</span></span>
<span id="cb13-9"><a href="#cb13-9" aria-hidden="true" tabindex="-1"></a>    save_top_k<span class="op">=</span><span class="dv">3</span>,</span>
<span id="cb13-10"><a href="#cb13-10" aria-hidden="true" tabindex="-1"></a>    mode<span class="op">=</span><span class="st">'min'</span></span>
<span id="cb13-11"><a href="#cb13-11" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb13-12"><a href="#cb13-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-13"><a href="#cb13-13" aria-hidden="true" tabindex="-1"></a>trainer <span class="op">=</span> Trainer(</span>
<span id="cb13-14"><a href="#cb13-14" aria-hidden="true" tabindex="-1"></a>    max_epochs<span class="op">=</span><span class="dv">10</span>,</span>
<span id="cb13-15"><a href="#cb13-15" aria-hidden="true" tabindex="-1"></a>    callbacks<span class="op">=</span>[checkpoint_callback]</span>
<span id="cb13-16"><a href="#cb13-16" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb13-17"><a href="#cb13-17" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-18"><a href="#cb13-18" aria-hidden="true" tabindex="-1"></a>trainer.fit(model, datamodule<span class="op">=</span>data_module)</span>
<span id="cb13-19"><a href="#cb13-19" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-20"><a href="#cb13-20" aria-hidden="true" tabindex="-1"></a><span class="co"># Path to best checkpoint</span></span>
<span id="cb13-21"><a href="#cb13-21" aria-hidden="true" tabindex="-1"></a>best_model_path <span class="op">=</span> checkpoint_callback.best_model_path</span>
<span id="cb13-22"><a href="#cb13-22" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-23"><a href="#cb13-23" aria-hidden="true" tabindex="-1"></a><span class="co"># Load the best weights into a new model instance (e.g. for evaluation)</span></span>
<span id="cb13-24"><a href="#cb13-24" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> MNISTModel.load_from_checkpoint(best_model_path)</span>
<span id="cb13-25"><a href="#cb13-25" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb13-26"><a href="#cb13-26" aria-hidden="true" tabindex="-1"></a><span class="co"># Resume training: ckpt_path also restores optimizer state and the epoch counter</span></span>
<span id="cb13-27"><a href="#cb13-27" aria-hidden="true" tabindex="-1"></a>trainer <span class="op">=</span> Trainer(max_epochs<span class="op">=</span><span class="dv">20</span>, callbacks<span class="op">=</span>[checkpoint_callback])  <span class="co"># must exceed epochs already run</span></span>
<span id="cb13-28"><a href="#cb13-28" aria-hidden="true" tabindex="-1"></a>trainer.fit(model, datamodule<span class="op">=</span>data_module, ckpt_path<span class="op">=</span>best_model_path)</span></code></pre></div></div>
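<p>With <code>save_top_k=3</code> and <code>mode='min'</code>, <code>ModelCheckpoint</code> keeps only the three checkpoints with the lowest <code>val_loss</code> seen so far and deletes the others as better ones arrive. Because metric names are inserted into the filename template by default, <code>'{epoch}-{val_loss:.2f}'</code> yields names such as <code>epoch=3-val_loss=0.40.ckpt</code>. A pure-Python sketch of that bookkeeping, using a hypothetical <code>TopKCheckpoints</code> class and toy losses rather than Lightning's implementation:</p>

```python
import heapq


class TopKCheckpoints:
    """Track the k checkpoints with the lowest monitored value (mode='min')."""

    def __init__(self, k=3):
        self.k = k
        self._heap = []  # entries are (-val_loss, path): heap top is the worst kept

    def update(self, epoch, val_loss):
        path = f"epoch={epoch}-val_loss={val_loss:.2f}.ckpt"
        heapq.heappush(self._heap, (-val_loss, path))
        if len(self._heap) > self.k:
            # Evict the worst checkpoint; a real callback would also delete the file
            heapq.heappop(self._heap)
        return path

    def kept(self):
        """Kept checkpoint paths, best (lowest loss) first."""
        return [path for _, path in sorted(self._heap, reverse=True)]

    @property
    def best_model_path(self):
        return self.kept()[0]


tracker = TopKCheckpoints(k=3)
for epoch, val_loss in enumerate([0.9, 0.5, 0.7, 0.4, 0.6]):  # toy losses
    tracker.update(epoch, val_loss)

print(tracker.kept())
# ['epoch=3-val_loss=0.40.ckpt', 'epoch=1-val_loss=0.50.ckpt', 'epoch=4-val_loss=0.60.ckpt']
print(tracker.best_model_path)  # epoch=3-val_loss=0.40.ckpt
```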
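<p>A Lightning checkpoint stores more than the weights (optimizer state, epoch and step counters, callback state), so for deployment it is common to keep only the model's <code>state_dict</code>:</p>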
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb14"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb14-1"><a href="#cb14-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb14-2"><a href="#cb14-2" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-3"><a href="#cb14-3" aria-hidden="true" tabindex="-1"></a><span class="co"># Save a full training checkpoint (weights, optimizer state, loop progress, ...)</span></span>
<span id="cb14-4"><a href="#cb14-4" aria-hidden="true" tabindex="-1"></a>trainer.save_checkpoint(<span class="st">"model.ckpt"</span>)</span>
<span id="cb14-5"><a href="#cb14-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-6"><a href="#cb14-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Extract the model weights, which Lightning stores under "state_dict"</span></span>
<span id="cb14-7"><a href="#cb14-7" aria-hidden="true" tabindex="-1"></a>checkpoint <span class="op">=</span> torch.load(<span class="st">"model.ckpt"</span>, map_location<span class="op">=</span><span class="st">"cpu"</span>)</span>
<span id="cb14-8"><a href="#cb14-8" aria-hidden="true" tabindex="-1"></a>model_state_dict <span class="op">=</span> checkpoint[<span class="st">"state_dict"</span>]</span>
<span id="cb14-9"><a href="#cb14-9" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb14-10"><a href="#cb14-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Save just the PyTorch weights for production</span></span>
<span id="cb14-11"><a href="#cb14-11" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> MNISTModel()</span>
<span id="cb14-12"><a href="#cb14-12" aria-hidden="true" tabindex="-1"></a>model.load_state_dict(model_state_dict)</span>
<span id="cb14-13"><a href="#cb14-13" aria-hidden="true" tabindex="-1"></a>torch.save(model.state_dict(), <span class="st">"production_model.pt"</span>)</span></code></pre></div></div>
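<p>The keys in <code>checkpoint["state_dict"]</code> are the attribute paths of the <code>LightningModule</code>, which is why they load directly into <code>MNISTModel</code>. If a LightningModule instead wraps its network in a submodule (say, a hypothetical <code>self.model</code>), every key carries a <code>model.</code> prefix that a plain <code>nn.Module</code> will reject. A small sketch of stripping it; <code>strip_prefix</code> is a hypothetical helper, and strings stand in for tensors so it runs without PyTorch:</p>

```python
def strip_prefix(state_dict, prefix):
    """Return a copy of state_dict with `prefix` removed from keys that start with it."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Strings stand in for tensors; the keys assume a hypothetical `self.model` wrapper
lightning_state = {
    "model.layer_1.weight": "W1",
    "model.layer_1.bias": "b1",
    "model.layer_2.weight": "W2",
    "model.layer_2.bias": "b2",
}
plain_state = strip_prefix(lightning_state, "model.")
print(sorted(plain_state))
# ['layer_1.bias', 'layer_1.weight', 'layer_2.bias', 'layer_2.weight']
```

<p>The resulting dictionary can then be passed to the plain model's <code>load_state_dict</code>. Recent PyTorch versions also provide <code>torch.nn.modules.utils.consume_prefix_in_state_dict_if_present</code>, which performs the same renaming in place.</p>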
</section>
<section id="production-deployment" class="level2">
<h2 class="anchored" data-anchor-id="production-deployment" id="production-deployment">Production Deployment</h2>
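<p>A trained Lightning model can be served directly, or its weights can be moved into a plain <code>nn.Module</code> (and exported to ONNX) so that inference does not depend on Lightning:</p>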
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb15"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb15-1"><a href="#cb15-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Option 1: Use the Lightning model directly</span></span>
<span id="cb15-2"><a href="#cb15-2" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> MNISTModel.load_from_checkpoint(<span class="st">"model.ckpt"</span>)</span>
<span id="cb15-3"><a href="#cb15-3" aria-hidden="true" tabindex="-1"></a>model.<span class="bu">eval</span>()</span>
<span id="cb15-4"><a href="#cb15-4" aria-hidden="true" tabindex="-1"></a>model.freeze()  <span class="co"># Freeze the model parameters</span></span>
<span id="cb15-5"><a href="#cb15-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-6"><a href="#cb15-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Make predictions</span></span>
<span id="cb15-7"><a href="#cb15-7" aria-hidden="true" tabindex="-1"></a><span class="cf">with</span> torch.no_grad():</span>
<span id="cb15-8"><a href="#cb15-8" aria-hidden="true" tabindex="-1"></a>    x <span class="op">=</span> torch.randn(<span class="dv">1</span>, <span class="dv">1</span>, <span class="dv">28</span>, <span class="dv">28</span>)</span>
<span id="cb15-9"><a href="#cb15-9" aria-hidden="true" tabindex="-1"></a>    y_hat <span class="op">=</span> model(x)</span>
<span id="cb15-10"><a href="#cb15-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-11"><a href="#cb15-11" aria-hidden="true" tabindex="-1"></a><span class="co"># Option 2: Extract the PyTorch model</span></span>
<span id="cb15-12"><a href="#cb15-12" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> ProductionModel(nn.Module):</span>
<span id="cb15-13"><a href="#cb15-13" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>):</span>
<span id="cb15-14"><a href="#cb15-14" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb15-15"><a href="#cb15-15" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.layer_1 <span class="op">=</span> nn.Linear(<span class="dv">28</span> <span class="op">*</span> <span class="dv">28</span>, <span class="dv">128</span>)</span>
<span id="cb15-16"><a href="#cb15-16" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.layer_2 <span class="op">=</span> nn.Linear(<span class="dv">128</span>, <span class="dv">10</span>)</span>
<span id="cb15-17"><a href="#cb15-17" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb15-18"><a href="#cb15-18" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb15-19"><a href="#cb15-19" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> x.view(x.size(<span class="dv">0</span>), <span class="op">-</span><span class="dv">1</span>)</span>
<span id="cb15-20"><a href="#cb15-20" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> F.relu(<span class="va">self</span>.layer_1(x))</span>
<span id="cb15-21"><a href="#cb15-21" aria-hidden="true" tabindex="-1"></a>        x <span class="op">=</span> <span class="va">self</span>.layer_2(x)</span>
<span id="cb15-22"><a href="#cb15-22" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> x</span>
<span id="cb15-23"><a href="#cb15-23" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-24"><a href="#cb15-24" aria-hidden="true" tabindex="-1"></a><span class="co"># Load weights from Lightning model</span></span>
<span id="cb15-25"><a href="#cb15-25" aria-hidden="true" tabindex="-1"></a>lightning_model <span class="op">=</span> MNISTModel.load_from_checkpoint(<span class="st">"model.ckpt"</span>)</span>
<span id="cb15-26"><a href="#cb15-26" aria-hidden="true" tabindex="-1"></a>production_model <span class="op">=</span> ProductionModel()</span>
<span id="cb15-27"><a href="#cb15-27" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-28"><a href="#cb15-28" aria-hidden="true" tabindex="-1"></a><span class="co"># Copy weights (parameter names in both models must match)</span></span>
<span id="cb15-29"><a href="#cb15-29" aria-hidden="true" tabindex="-1"></a>production_model.load_state_dict(lightning_model.state_dict())</span>
<span id="cb15-30"><a href="#cb15-30" aria-hidden="true" tabindex="-1"></a>production_model.<span class="bu">eval</span>()  <span class="co"># inference mode before saving and exporting</span></span>
<span id="cb15-33"><a href="#cb15-33" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-34"><a href="#cb15-34" aria-hidden="true" tabindex="-1"></a><span class="co"># Save production model</span></span>
<span id="cb15-35"><a href="#cb15-35" aria-hidden="true" tabindex="-1"></a>torch.save(production_model.state_dict(), <span class="st">"production_model.pt"</span>)</span>
<span id="cb15-36"><a href="#cb15-36" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb15-37"><a href="#cb15-37" aria-hidden="true" tabindex="-1"></a><span class="co"># Export to ONNX</span></span>
<span id="cb15-38"><a href="#cb15-38" aria-hidden="true" tabindex="-1"></a>dummy_input <span class="op">=</span> torch.randn(<span class="dv">1</span>, <span class="dv">1</span>, <span class="dv">28</span>, <span class="dv">28</span>)</span>
<span id="cb15-39"><a href="#cb15-39" aria-hidden="true" tabindex="-1"></a>torch.onnx.export(</span>
<span id="cb15-40"><a href="#cb15-40" aria-hidden="true" tabindex="-1"></a>    production_model,</span>
<span id="cb15-41"><a href="#cb15-41" aria-hidden="true" tabindex="-1"></a>    dummy_input,</span>
<span id="cb15-42"><a href="#cb15-42" aria-hidden="true" tabindex="-1"></a>    <span class="st">"model.onnx"</span>,</span>
<span id="cb15-43"><a href="#cb15-43" aria-hidden="true" tabindex="-1"></a>    export_params<span class="op">=</span><span class="va">True</span>,</span>
<span id="cb15-44"><a href="#cb15-44" aria-hidden="true" tabindex="-1"></a>    opset_version<span class="op">=</span><span class="dv">11</span>,</span>
<span id="cb15-45"><a href="#cb15-45" aria-hidden="true" tabindex="-1"></a>    input_names<span class="op">=</span>[<span class="st">'input'</span>],</span>
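<span id="cb15-46"><a href="#cb15-46" aria-hidden="true" tabindex="-1"></a>    output_names<span class="op">=</span>[<span class="st">'output'</span>],</span>
<span id="cb15-47"><a href="#cb15-47" aria-hidden="true" tabindex="-1"></a>    dynamic_axes<span class="op">=</span>{<span class="st">'input'</span>: {<span class="dv">0</span>: <span class="st">'batch'</span>}, <span class="st">'output'</span>: {<span class="dv">0</span>: <span class="st">'batch'</span>}}  <span class="co"># allow any batch size</span></span>
<span id="cb15-48"><a href="#cb15-48" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>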
</section>
<section id="best-practices" class="level2">
<h2 class="anchored" data-anchor-id="best-practices" id="best-practices">Best Practices</h2>
<section id="code-organization" class="level3">
<h3 class="anchored" data-anchor-id="code-organization" id="code-organization">Code Organization</h3>
<p>A well-organized Lightning project structure:</p>
<pre><code>project/
├── configs/              # Configuration files
├── data/                 # Data files
├── lightning_logs/       # Generated logs
├── models/               # Model definitions
│   ├── __init__.py
│   └── mnist_model.py    # LightningModule
├── data_modules/         # Data modules
│   ├── __init__.py
│   └── mnist_data.py     # LightningDataModule
├── callbacks/            # Custom callbacks
├── utils/                # Utility functions
├── main.py               # Training script
└── README.md</code></pre>
</section>
<section id="lightning-cli" class="level3">
<h3 class="anchored" data-anchor-id="lightning-cli" id="lightning-cli">Lightning CLI</h3>
<p>Lightning provides a CLI for running experiments from config files:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb17"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb17-1"><a href="#cb17-1" aria-hidden="true" tabindex="-1"></a><span class="co"># main.py</span></span>
<span id="cb17-2"><a href="#cb17-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> pytorch_lightning.cli <span class="im">import</span> LightningCLI</span>
<span id="cb17-3"><a href="#cb17-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-4"><a href="#cb17-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> data_modules.mnist_data <span class="im">import</span> MNISTDataModule</span>
<span id="cb17-5"><a href="#cb17-5" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> models.mnist_model <span class="im">import</span> MNISTModel</span>
<span id="cb17-6"><a href="#cb17-6" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-7"><a href="#cb17-7" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> cli_main():</span>
<span id="cb17-8"><a href="#cb17-8" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Create CLI with LightningModule and LightningDataModule</span></span>
<span id="cb17-9"><a href="#cb17-9" aria-hidden="true" tabindex="-1"></a>    cli <span class="op">=</span> LightningCLI(</span>
<span id="cb17-10"><a href="#cb17-10" aria-hidden="true" tabindex="-1"></a>        MNISTModel,</span>
<span id="cb17-11"><a href="#cb17-11" aria-hidden="true" tabindex="-1"></a>        MNISTDataModule,</span>
<span id="cb17-12"><a href="#cb17-12" aria-hidden="true" tabindex="-1"></a>        save_config_callback<span class="op">=</span><span class="va">None</span>,  <span class="co"># don't save the resolved config to the log directory</span></span>
<span id="cb17-13"><a href="#cb17-13" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb17-14"><a href="#cb17-14" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb17-15"><a href="#cb17-15" aria-hidden="true" tabindex="-1"></a><span class="cf">if</span> <span class="va">__name__</span> <span class="op">==</span> <span class="st">"__main__"</span>:</span>
<span id="cb17-16"><a href="#cb17-16" aria-hidden="true" tabindex="-1"></a>    cli_main()</span></code></pre></div></div>
<p>Run with:</p>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb18"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb18-1"><a href="#cb18-1" aria-hidden="true" tabindex="-1"></a><span class="ex">python</span> main.py fit <span class="at">--config</span> configs/mnist.yaml</span></code></pre></div></div>
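<p>Example config file (<code>configs/mnist.yaml</code>). Keys under <code>model</code> and <code>data</code> are passed to the constructors of <code>MNISTModel</code> and <code>MNISTDataModule</code>, so they must match those classes' <code>__init__</code> parameters; keys under <code>trainer</code> become <code>Trainer</code> arguments:</p>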
<pre><code>model:
  learning_rate: 0.001
  hidden_size: 128
data:
  data_dir: ./data
  batch_size: 64
trainer:
  max_epochs: 10
  accelerator: gpu
  devices: 1</code></pre>
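<p><code>LightningCLI</code> (built on the <code>jsonargparse</code> library) turns each top-level section of the config into constructor arguments: <code>model</code> for the LightningModule, <code>data</code> for the LightningDataModule, and <code>trainer</code> for <code>Trainer</code>. A simplified pure-Python sketch of that mapping follows; the classes and <code>build_from_config</code> are hypothetical stand-ins, and the real CLI additionally type-checks values and handles subcommands such as <code>fit</code> and <code>test</code>:</p>

```python
class HypotheticalModel:
    """Stand-in for MNISTModel: only stores its constructor arguments."""

    def __init__(self, learning_rate=1e-3, hidden_size=64):
        self.learning_rate = learning_rate
        self.hidden_size = hidden_size


class HypotheticalDataModule:
    """Stand-in for MNISTDataModule."""

    def __init__(self, data_dir="./data", batch_size=32):
        self.data_dir = data_dir
        self.batch_size = batch_size


def build_from_config(config, model_cls, data_cls):
    """Pass each config section to the matching constructor as keyword arguments."""
    model = model_cls(**config.get("model", {}))
    data = data_cls(**config.get("data", {}))
    trainer_kwargs = dict(config.get("trainer", {}))  # would go to Trainer(**trainer_kwargs)
    return model, data, trainer_kwargs


# The YAML example above, already parsed into a dict
config = {
    "model": {"learning_rate": 0.001, "hidden_size": 128},
    "data": {"data_dir": "./data", "batch_size": 64},
    "trainer": {"max_epochs": 10, "accelerator": "gpu", "devices": 1},
}
model, data, trainer_kwargs = build_from_config(config, HypotheticalModel, HypotheticalDataModule)
print(model.hidden_size, data.batch_size, trainer_kwargs["max_epochs"])  # 128 64 10
```

<p>In this sketch a misspelled key (for example <code>hidden_sise</code>) raises a <code>TypeError</code> instead of being silently ignored; the real CLI likewise rejects unknown keys before training starts.</p>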
</section>
<section id="profiling" class="level3">
<h3 class="anchored" data-anchor-id="profiling" id="profiling">Profiling</h3>
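<p>Lightning ships with profilers that show where training time goes: <code>SimpleProfiler</code> reports the time spent in each training hook, while <code>PyTorchProfiler</code> wraps <code>torch.profiler</code> for operator-level traces:</p>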
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb20"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb20-1"><a href="#cb20-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb20-2"><a href="#cb20-2" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> pytorch_lightning <span class="im">import</span> Trainer</span>
<span id="cb20-3"><a href="#cb20-3" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> pytorch_lightning.profilers <span class="im">import</span> PyTorchProfiler, SimpleProfiler</span>
<span id="cb20-4"><a href="#cb20-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb20-5"><a href="#cb20-5" aria-hidden="true" tabindex="-1"></a><span class="co"># PyTorch Profiler: operator-level traces, viewable in TensorBoard</span></span>
<span id="cb20-6"><a href="#cb20-6" aria-hidden="true" tabindex="-1"></a>profiler <span class="op">=</span> PyTorchProfiler(</span>
<span id="cb20-7"><a href="#cb20-7" aria-hidden="true" tabindex="-1"></a>    on_trace_ready<span class="op">=</span>torch.profiler.tensorboard_trace_handler(<span class="st">"logs/profiler"</span>),</span>
<span id="cb20-8"><a href="#cb20-8" aria-hidden="true" tabindex="-1"></a>    schedule<span class="op">=</span>torch.profiler.schedule(wait<span class="op">=</span><span class="dv">1</span>, warmup<span class="op">=</span><span class="dv">1</span>, active<span class="op">=</span><span class="dv">3</span>, repeat<span class="op">=</span><span class="dv">2</span>)</span>
<span id="cb20-9"><a href="#cb20-9" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb20-10"><a href="#cb20-10" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb20-11"><a href="#cb20-11" aria-hidden="true" tabindex="-1"></a><span class="co"># Simple Profiler: time spent in each training hook</span></span>
<span id="cb20-12"><a href="#cb20-12" aria-hidden="true" tabindex="-1"></a><span class="co"># profiler = SimpleProfiler()</span></span>
<span id="cb20-13"><a href="#cb20-13" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb20-14"><a href="#cb20-14" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> MNISTModel()</span>
<span id="cb20-15"><a href="#cb20-15" aria-hidden="true" tabindex="-1"></a>data_module <span class="op">=</span> MNISTDataModule()</span>
<span id="cb20-16"><a href="#cb20-16" aria-hidden="true" tabindex="-1"></a>trainer <span class="op">=</span> Trainer(</span>
<span id="cb20-17"><a href="#cb20-17" aria-hidden="true" tabindex="-1"></a>    max_epochs<span class="op">=</span><span class="dv">5</span>,</span>
<span id="cb20-18"><a href="#cb20-18" aria-hidden="true" tabindex="-1"></a>    profiler<span class="op">=</span>profiler</span>
<span id="cb20-19"><a href="#cb20-19" aria-hidden="true" tabindex="-1"></a>)</span>
<span id="cb20-20"><a href="#cb20-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb20-21"><a href="#cb20-21" aria-hidden="true" tabindex="-1"></a>trainer.fit(model, datamodule<span class="op">=</span>data_module)</span></code></pre></div></div>
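<p>With <code>schedule(wait=1, warmup=1, active=3, repeat=2)</code>, the profiler cycles through 5-step windows: it idles for 1 step, warms up for 1, records 3, and hands each finished trace to <code>on_trace_ready</code>, stopping after 2 cycles. Profiling only a few steps keeps overhead and trace size small. A pure-Python restatement of that schedule logic, for illustration only; <code>profiler_action</code> is hypothetical, and the real schedule returns <code>torch.profiler.ProfilerAction</code> values rather than strings:</p>

```python
def profiler_action(step, wait=1, warmup=1, active=3, repeat=2, skip_first=0):
    """Action a torch.profiler.schedule-style schedule assigns to `step`."""
    if step < skip_first:
        return "NONE"
    step -= skip_first
    cycle = wait + warmup + active
    if repeat > 0 and step // cycle >= repeat:
        return "NONE"  # all requested cycles are done
    pos = step % cycle
    if pos < wait:
        return "NONE"
    if pos < wait + warmup:
        return "WARMUP"  # profiler running, results discarded
    if pos < cycle - 1:
        return "RECORD"
    return "RECORD_AND_SAVE"  # last active step: trace handed to on_trace_ready


for step in range(12):
    print(step, profiler_action(step))
# Steps 0-4: NONE, WARMUP, RECORD, RECORD, RECORD_AND_SAVE
# Steps 5-9: the same pattern again; steps 10 and 11: NONE
```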
</section>
<section id="advanced-features" class="level3">
<h3 class="anchored" data-anchor-id="advanced-features" id="advanced-features">Advanced Features</h3>
<section id="gradient-accumulation" class="level4">
<h4 class="anchored" data-anchor-id="gradient-accumulation">Gradient Accumulation</h4>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb21"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb21-1"><a href="#cb21-1" aria-hidden="true" tabindex="-1"></a>trainer <span class="op">=</span> Trainer(</span>
<span id="cb21-2"><a href="#cb21-2" aria-hidden="true" tabindex="-1"></a>    accumulate_grad_batches<span class="op">=</span><span class="dv">4</span>  <span class="co"># Accumulate over 4 batches</span></span>
<span id="cb21-3"><a href="#cb21-3" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
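<p>With <code>accumulate_grad_batches=4</code>, the optimizer steps once every four batches, so each update reflects 4 × <code>batch_size</code> samples (times the number of processes under DDP) while memory only ever holds one batch. If each micro-batch loss is scaled by 1/4 before <code>backward()</code>, which Lightning's automatic optimization does for you, the summed gradients equal the gradient of the mean loss over the combined batch. The exact-arithmetic check below illustrates this with a one-parameter toy model; <code>grad_mse</code> and the data are hypothetical:</p>

```python
from fractions import Fraction


def grad_mse(w, xs, ys):
    """d/dw of mean((w*x - y)**2) for the one-parameter model y_hat = w*x."""
    return sum(2 * x * (w * x - y) for x, y in zip(xs, ys)) / len(xs)


# Toy data: 8 samples; exact fractions avoid floating-point round-off
w = Fraction(1, 2)
xs = [Fraction(v) for v in range(1, 9)]
ys = [2 * x + 1 for x in xs]

full_batch_grad = grad_mse(w, xs, ys)

# accumulate_grad_batches=4: four micro-batches of 2, each loss scaled by 1/4
accumulate = 4
micro = len(xs) // accumulate
accumulated = Fraction(0)
for i in range(accumulate):
    batch = slice(i * micro, (i + 1) * micro)
    accumulated += grad_mse(w, xs[batch], ys[batch]) / accumulate

print(accumulated == full_batch_grad)  # True
```

<p>The equivalence assumes equally sized micro-batches and layers that treat samples independently; BatchNorm, for instance, still computes its statistics over each small batch.</p>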
</section>
<section id="learning-rate-scheduling" class="level4">
<h4 class="anchored" data-anchor-id="learning-rate-scheduling">Learning Rate Scheduling</h4>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb22"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb22-1"><a href="#cb22-1" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> configure_optimizers(<span class="va">self</span>):</span>
<span id="cb22-2"><a href="#cb22-2" aria-hidden="true" tabindex="-1"></a>    optimizer <span class="op">=</span> torch.optim.Adam(<span class="va">self</span>.parameters(), lr<span class="op">=</span><span class="fl">1e-3</span>)</span>
<span id="cb22-3"><a href="#cb22-3" aria-hidden="true" tabindex="-1"></a>    scheduler <span class="op">=</span> torch.optim.lr_scheduler.ReduceLROnPlateau(</span>
<span id="cb22-4"><a href="#cb22-4" aria-hidden="true" tabindex="-1"></a>        optimizer, </span>
<span id="cb22-5"><a href="#cb22-5" aria-hidden="true" tabindex="-1"></a>        mode<span class="op">=</span><span class="st">'min'</span>, </span>
<span id="cb22-6"><a href="#cb22-6" aria-hidden="true" tabindex="-1"></a>        factor<span class="op">=</span><span class="fl">0.1</span>, </span>
<span id="cb22-7"><a href="#cb22-7" aria-hidden="true" tabindex="-1"></a>        patience<span class="op">=</span><span class="dv">5</span></span>
<span id="cb22-8"><a href="#cb22-8" aria-hidden="true" tabindex="-1"></a>    )</span>
<span id="cb22-9"><a href="#cb22-9" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> {</span>
<span id="cb22-10"><a href="#cb22-10" aria-hidden="true" tabindex="-1"></a>        <span class="st">"optimizer"</span>: optimizer,</span>
<span id="cb22-11"><a href="#cb22-11" aria-hidden="true" tabindex="-1"></a>        <span class="st">"lr_scheduler"</span>: {</span>
<span id="cb22-12"><a href="#cb22-12" aria-hidden="true" tabindex="-1"></a>            <span class="st">"scheduler"</span>: scheduler,</span>
<span id="cb22-13"><a href="#cb22-13" aria-hidden="true" tabindex="-1"></a>            <span class="st">"monitor"</span>: <span class="st">"val_loss"</span>,</span>
<span id="cb22-14"><a href="#cb22-14" aria-hidden="true" tabindex="-1"></a>            <span class="st">"interval"</span>: <span class="st">"epoch"</span>,</span>
<span id="cb22-15"><a href="#cb22-15" aria-hidden="true" tabindex="-1"></a>            <span class="st">"frequency"</span>: <span class="dv">1</span></span>
<span id="cb22-16"><a href="#cb22-16" aria-hidden="true" tabindex="-1"></a>        }</span>
<span id="cb22-17"><a href="#cb22-17" aria-hidden="true" tabindex="-1"></a>    }</span></code></pre></div></div>
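<p><code>ReduceLROnPlateau</code> multiplies the learning rate by <code>factor</code> once the monitored metric has failed to improve for more than <code>patience</code> epochs. Because it needs a metric, Lightning requires the <code>"monitor"</code> key, and <code>val_loss</code> must be logged with <code>self.log</code> during validation. A simplified pure-Python sketch of the rule; <code>PlateauScheduler</code> is hypothetical and leaves out PyTorch's <code>cooldown</code>, <code>min_lr</code>, and <code>eps</code> options:</p>

```python
class PlateauScheduler:
    """Simplified ReduceLROnPlateau(mode='min') with a relative threshold."""

    def __init__(self, lr, factor=0.1, patience=5, threshold=1e-4):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.threshold = threshold
        self.best = float("inf")
        self.num_bad_epochs = 0

    def step(self, val_loss):
        # Count as an improvement only if better than the best by the relative threshold
        if val_loss < self.best * (1 - self.threshold):
            self.best = val_loss
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        # Reduce once the metric has stalled for more than `patience` epochs
        if self.num_bad_epochs > self.patience:
            self.lr *= self.factor
            self.num_bad_epochs = 0
        return self.lr


sched = PlateauScheduler(lr=1e-3, factor=0.1, patience=5)
val_losses = [1.0, 0.8, 0.7] + [0.7] * 6  # toy values: no improvement after epoch 2
lrs = [sched.step(loss) for loss in val_losses]
print(lrs)  # 1e-3 for eight epochs, then about 1e-4 on the ninth
```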
</section>
<section id="mixed-precision-training" class="level4">
<h4 class="anchored" data-anchor-id="mixed-precision-training">Mixed Precision Training</h4>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb23"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb23-1"><a href="#cb23-1" aria-hidden="true" tabindex="-1"></a>trainer <span class="op">=</span> Trainer(</span>
<span id="cb23-2"><a href="#cb23-2" aria-hidden="true" tabindex="-1"></a>    precision<span class="op">=</span><span class="st">"16-mixed"</span>  <span class="co"># Use 16-bit mixed precision</span></span>
<span id="cb23-3"><a href="#cb23-3" aria-hidden="true" tabindex="-1"></a>)</span></code></pre></div></div>
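<p>With <code>precision="16-mixed"</code>, eligible operations run in float16 under autocast while the weights stay in float32. Small float16 gradients can underflow to zero, so fp16 training uses dynamic loss scaling: the loss is multiplied by a large factor before <code>backward()</code>, gradients are divided by it before the optimizer step, and the factor is reduced (and the step skipped) whenever an inf or NaN appears. A simplified sketch of that control loop on plain floats; <code>LossScaler</code> is hypothetical, though its defaults mirror those documented for <code>torch.cuda.amp.GradScaler</code>:</p>

```python
import math


class LossScaler:
    """Simplified dynamic loss scaling in the spirit of torch.cuda.amp.GradScaler."""

    def __init__(self, init_scale=2.0**16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0

    def step(self, scaled_grads):
        """Return unscaled gradients, or None if this step must be skipped."""
        if not all(math.isfinite(g) for g in scaled_grads):
            # Overflow: skip the optimizer step and shrink the scale
            self.scale *= self.backoff_factor
            self._good_steps = 0
            return None
        # Unscale with the same factor that was applied to the loss
        grads = [g / self.scale for g in scaled_grads]
        self._good_steps += 1
        if self._good_steps == self.growth_interval:
            # Stable for a while: try a larger scale
            self.scale *= self.growth_factor
            self._good_steps = 0
        return grads


scaler = LossScaler(init_scale=1024.0, growth_interval=2)
print(scaler.step([float("inf"), 1.0]), scaler.scale)  # None 512.0
print(scaler.step([512.0, 256.0]), scaler.scale)       # [1.0, 0.5] 512.0
print(scaler.step([512.0, 256.0]), scaler.scale)       # [1.0, 0.5] 1024.0
```

<p>On hardware with bfloat16 support, <code>precision="bf16-mixed"</code> is an alternative that generally does not need loss scaling, because bfloat16 has the same exponent range as float32.</p>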
</section>
<section id="transfer-learning" class="level4">
<h4 class="anchored" data-anchor-id="transfer-learning">Transfer Learning</h4>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb24"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb24-1"><a href="#cb24-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> pytorch_lightning <span class="im">as</span> pl</span>
<span id="cb24-2"><a href="#cb24-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb24-3"><a href="#cb24-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torchvision</span>
<span id="cb24-4"><a href="#cb24-4" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb24-5"><a href="#cb24-5" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> TransferLearningModel(pl.LightningModule):</span>
<span id="cb24-6"><a href="#cb24-6" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>):</span>
<span id="cb24-7"><a href="#cb24-7" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb24-8"><a href="#cb24-8" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Load an ImageNet-pretrained ResNet-18 (weights API, torchvision 0.13+)</span></span>
<span id="cb24-9"><a href="#cb24-9" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.feature_extractor <span class="op">=</span> torchvision.models.resnet18(</span>
<span id="cb24-10"><a href="#cb24-10" aria-hidden="true" tabindex="-1"></a>            weights<span class="op">=</span>torchvision.models.ResNet18_Weights.DEFAULT</span>
<span id="cb24-11"><a href="#cb24-11" aria-hidden="true" tabindex="-1"></a>        )</span>
<span id="cb24-12"><a href="#cb24-12" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb24-13"><a href="#cb24-13" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Freeze the pretrained layers</span></span>
<span id="cb24-14"><a href="#cb24-14" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> param <span class="kw">in</span> <span class="va">self</span>.feature_extractor.parameters():</span>
<span id="cb24-15"><a href="#cb24-15" aria-hidden="true" tabindex="-1"></a>            param.requires_grad <span class="op">=</span> <span class="va">False</span></span>
<span id="cb24-16"><a href="#cb24-16" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb24-17"><a href="#cb24-17" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Replace the final layer; the new layer is trainable by default</span></span>
<span id="cb24-18"><a href="#cb24-18" aria-hidden="true" tabindex="-1"></a>        num_features <span class="op">=</span> <span class="va">self</span>.feature_extractor.fc.in_features</span>
<span id="cb24-19"><a href="#cb24-19" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.feature_extractor.fc <span class="op">=</span> nn.Linear(num_features, <span class="dv">10</span>)</span>
<span id="cb24-20"><a href="#cb24-20" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb24-21"><a href="#cb24-21" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, x):</span>
<span id="cb24-22"><a href="#cb24-22" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> <span class="va">self</span>.feature_extractor(x)</span>
<span id="cb24-23"><a href="#cb24-23" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb24-24"><a href="#cb24-24" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> unfreeze_features(<span class="va">self</span>):</span>
<span id="cb24-25"><a href="#cb24-25" aria-hidden="true" tabindex="-1"></a>        <span class="co"># Unfreeze the backbone for fine-tuning after some training;</span></span>
<span id="cb24-26"><a href="#cb24-26" aria-hidden="true" tabindex="-1"></a>        <span class="co"># the optimizer only updates these parameters if they were passed to it</span></span>
<span id="cb24-27"><a href="#cb24-27" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> param <span class="kw">in</span> <span class="va">self</span>.feature_extractor.parameters():</span>
<span id="cb24-28"><a href="#cb24-28" aria-hidden="true" tabindex="-1"></a>            param.requires_grad <span class="op">=</span> <span class="va">True</span></span></code></pre></div></div>
</section>
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion" id="conclusion">Conclusion</h2>
<p>PyTorch Lightning provides an organized framework that separates research code from engineering boilerplate, making deep learning projects easier to develop, share, and scale. It’s especially valuable for research projects that need to be reproducible and scalable across different hardware configurations.</p>
<p>For further information:</p>
<ul>
<li><a href="https://lightning.ai/docs/pytorch/stable/">PyTorch Lightning Documentation</a></li>
<li><a href="https://github.com/Lightning-AI/lightning">Lightning GitHub Repository</a></li>
<li><a href="https://lightning.ai/docs/pytorch/stable/tutorials.html">Lightning Tutorials</a></li>
</ul>


</section>
</section>

]]></content:encoded>
    </item>
    <item>
      <title><![CDATA[Welcome To My Blog]]></title>
      <link>https://theja-vanka.github.io/blogs/posts/welcome/</link>
      <guid isPermaLink="true">https://theja-vanka.github.io/blogs/posts/welcome/</guid>
      <pubDate>Sat, 22 Mar 2025 00:00:00 GMT</pubDate>
      
      <category>news</category>
      <content:encoded><![CDATA[






<section id="welcome-to-my-computer-vision-blog" class="level1">

<p><img src="https://theja-vanka.github.io/blogs/posts/welcome/thumbnail.jpg" class="img-fluid"></p>
<p>This is the first post in this blog.</p>
<p>Hello and welcome!</p>
<p>I’m thrilled to kick off this blog dedicated to exploring the fascinating world of <strong>computer vision</strong> — a field where machines learn to see, interpret, and understand the visual world around us. Whether you’re a seasoned AI researcher, an aspiring developer, or simply curious about how technology can “see,” you’ll find something valuable here.</p>
<p>From image processing techniques to deep learning breakthroughs, from real-world applications to hands-on tutorials — this blog will cover it all. My goal is to make computer vision approachable, insightful, and exciting for everyone.</p>
<p>So, whether you’re here to learn, build, or stay on top of the latest innovations, I’m glad to have you along for the journey. Let’s dive into the visual future together!</p>
<p>Stay tuned, and let’s make machines see the world like never before. 🚀</p>
<p>Cheers, <br> Krishna</p>


</section>

]]></content:encoded>
    </item>
  </channel>
</rss>