.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "intermediate/ensembling.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_intermediate_ensembling.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_intermediate_ensembling.py:

Model ensembling
================

This tutorial illustrates how to vectorize model ensembling using ``torch.vmap``.

What is model ensembling?
-------------------------
Model ensembling combines the predictions from multiple models together.
Traditionally this is done by running each model on some inputs separately
and then combining the predictions. However, if you're running models with
the same architecture, then it may be possible to combine them together
using ``torch.vmap``. ``vmap`` is a function transform that maps functions
across dimensions of the input tensors. One of its use cases is eliminating
for-loops and speeding them up through vectorization.

Let's demonstrate how to do this using an ensemble of simple MLPs.

.. GENERATED FROM PYTHON SOURCE LINES 20-43

.. code-block:: default


    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    torch.manual_seed(0)

    # Here's a simple MLP
    class SimpleMLP(nn.Module):
        def __init__(self):
            super(SimpleMLP, self).__init__()
            self.fc1 = nn.Linear(784, 128)
            self.fc2 = nn.Linear(128, 128)
            self.fc3 = nn.Linear(128, 10)

        def forward(self, x):
            x = x.flatten(1)
            x = self.fc1(x)
            x = F.relu(x)
            x = self.fc2(x)
            x = F.relu(x)
            x = self.fc3(x)
            return x

.. GENERATED FROM PYTHON SOURCE LINES 44-48

Let's generate a batch of dummy data and pretend that we're working with an
MNIST dataset. Thus, the dummy images are 28 by 28, and we have a minibatch
of size 64. Furthermore, let's say we want to combine the predictions from
10 different models.

.. GENERATED FROM PYTHON SOURCE LINES 48-57

.. code-block:: default


    device = 'cuda'
    num_models = 10

    data = torch.randn(100, 64, 1, 28, 28, device=device)
    targets = torch.randint(10, (6400,), device=device)

    models = [SimpleMLP().to(device) for _ in range(num_models)]

.. GENERATED FROM PYTHON SOURCE LINES 58-62

We have a couple of options for generating predictions. Maybe we want to
give each model a different randomized minibatch of data. Alternatively,
maybe we want to run the same minibatch of data through each model (e.g.
if we were testing the effect of different model initializations).

.. GENERATED FROM PYTHON SOURCE LINES 64-65

Option 1: different minibatch for each model

.. GENERATED FROM PYTHON SOURCE LINES 65-69

.. code-block:: default


    minibatches = data[:num_models]
    predictions_diff_minibatch_loop = [model(minibatch) for model, minibatch in zip(models, minibatches)]

.. GENERATED FROM PYTHON SOURCE LINES 70-71

Option 2: Same minibatch

.. GENERATED FROM PYTHON SOURCE LINES 71-75

.. code-block:: default


    minibatch = data[0]
    predictions2 = [model(minibatch) for model in models]

.. GENERATED FROM PYTHON SOURCE LINES 76-89

Using vmap to vectorize the ensemble
------------------------------------

Let's use ``vmap`` to speed up the for-loop. We must first prepare the models
for use with ``vmap``.

First, let's combine the states of the models together by stacking each
parameter. For example, ``model[i].fc1.weight`` has shape ``[128, 784]``; we
are going to stack the ``.fc1.weight`` of each of the 10 models to produce a
big weight of shape ``[10, 128, 784]``.

PyTorch offers the ``torch.func.stack_module_state`` convenience function to
do this.

.. GENERATED FROM PYTHON SOURCE LINES 89-93

.. code-block:: default


    from torch.func import stack_module_state

    params, buffers = stack_module_state(models)
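To make the stacking concrete, here is a rough manual sketch of what
``stack_module_state`` does for the parameters: each entry stacks the
corresponding parameter from all 10 models along a new leading dimension.
This is for intuition only (``manual_params`` is a hypothetical name, and the
real function also handles buffers and other details); prefer
``stack_module_state`` in practice.

.. code-block:: default


    # Rough, illustrative equivalent of the parameter stacking above:
    # one stacked tensor per parameter name, with a new leading dim of
    # size num_models.
    manual_params = {
        name: torch.stack([dict(m.named_parameters())[name] for m in models])
        for name, _ in models[0].named_parameters()
    }
    print(manual_params['fc1.weight'].shape)  # torch.Size([10, 128, 784])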
.. GENERATED FROM PYTHON SOURCE LINES 94-98

Next, we need to define a function to ``vmap`` over. Given parameters,
buffers, and inputs, the function should run the model using those
parameters, buffers, and inputs. We'll use ``torch.func.functional_call``
to help out:

.. GENERATED FROM PYTHON SOURCE LINES 98-110

.. code-block:: default


    from torch.func import functional_call
    import copy

    # Construct a "stateless" version of one of the models. It is "stateless" in
    # the sense that the parameters are meta Tensors and do not have storage.
    base_model = copy.deepcopy(models[0])
    base_model = base_model.to('meta')

    def fmodel(params, buffers, x):
        return functional_call(base_model, (params, buffers), (x,))

.. GENERATED FROM PYTHON SOURCE LINES 111-117

Option 1: get predictions using a different minibatch for each model.

By default, ``vmap`` maps a function across the first dimension of all
inputs to the passed-in function. After using ``stack_module_state``, each
of the params and buffers has an additional dimension of size ``num_models``
at the front, and ``minibatches`` has a leading dimension of size
``num_models``.

.. GENERATED FROM PYTHON SOURCE LINES 117-129

.. code-block:: default


    print([p.size(0) for p in params.values()])  # show the leading 'num_models' dimension

    assert minibatches.shape == (num_models, 64, 1, 28, 28)  # verify minibatch has leading dimension of size 'num_models'

    from torch import vmap

    predictions1_vmap = vmap(fmodel)(params, buffers, minibatches)

    # verify the ``vmap`` predictions match the for-loop predictions
    assert torch.allclose(predictions1_vmap, torch.stack(predictions_diff_minibatch_loop), atol=1e-3, rtol=1e-5)

.. GENERATED FROM PYTHON SOURCE LINES 130-135

Option 2: get predictions using the same minibatch of data.

``vmap`` has an ``in_dims`` argument that specifies which dimensions to map
over. By using ``None``, we tell ``vmap`` to use the same minibatch for all
of the 10 models.

.. GENERATED FROM PYTHON SOURCE LINES 135-140

.. code-block:: default


    predictions2_vmap = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch)

    assert torch.allclose(predictions2_vmap, torch.stack(predictions2), atol=1e-3, rtol=1e-5)

.. GENERATED FROM PYTHON SOURCE LINES 141-147

A quick note: there are limitations around what types of functions can be
transformed by ``vmap``. The best functions to transform are pure functions:
functions whose outputs are determined only by their inputs and that have no
side effects (e.g. mutation). ``vmap`` is unable to handle mutation of
arbitrary Python data structures, but it is able to handle many in-place
PyTorch operations.
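As a hedged illustration of that constraint (the function names below are
illustrative and not part of the original tutorial), the sketch contrasts a
function with a Python-level side effect, which ``vmap`` cannot vectorize
correctly, with an in-place PyTorch operation, which it can handle:

.. code-block:: default


    results = []

    def impure(x):
        # Mutates a Python list: under ``vmap``, ``results`` would end up
        # holding a single batched tensor rather than one sum per sample.
        results.append(x.sum())
        return x * 2

    def inplace_ok(x):
        y = x.clone()
        y.mul_(2)  # in-place PyTorch op on a locally created tensor: fine
        return y

    xs = torch.randn(10, 3, device=device)
    out = vmap(inplace_ok)(xs)  # works as expected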
.. GENERATED FROM PYTHON SOURCE LINES 149-152

Performance
-----------
Curious about performance numbers? Here's how the numbers look.

.. GENERATED FROM PYTHON SOURCE LINES 152-163

.. code-block:: default


    from torch.utils.benchmark import Timer
    without_vmap = Timer(
        stmt="[model(minibatch) for model, minibatch in zip(models, minibatches)]",
        globals=globals())
    with_vmap = Timer(
        stmt="vmap(fmodel)(params, buffers, minibatches)",
        globals=globals())
    print(f'Predictions without vmap {without_vmap.timeit(100)}')
    print(f'Predictions with vmap {with_vmap.timeit(100)}')

.. GENERATED FROM PYTHON SOURCE LINES 164-172

There's a large speedup using ``vmap``!

In general, vectorization with ``vmap`` should be faster than running a
function in a for-loop and competitive with manual batching. There are some
exceptions, though: for example, if we haven't implemented the ``vmap`` rule
for a particular operation, or if the underlying kernels weren't optimized
for older hardware (GPUs). If you see any of these cases, please let us know
by opening an issue on GitHub.

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  0.000 seconds)


.. _sphx_glr_download_intermediate_ensembling.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example


    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: ensembling.py <ensembling.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: ensembling.ipynb <ensembling.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_