.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "prototype/nestedtensor.py"

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_prototype_nestedtensor.py:


NestedTensors
===============================================================

NestedTensors are similar to regular tensors, except for their shape:

* for a regular tensor, each dimension has a size

* for a nestedtensor, not all dimensions have regular sizes; some of them are jagged

Nestedtensors are a natural solution for representing sequential data within various domains:

* in NLP, sentences can have variable lengths, so a batch of sentences forms a nestedtensor

* in CV, images can have variable shapes, so a batch of images forms a nestedtensor

In this tutorial, we will demonstrate basic usage of nestedtensors and motivate their usefulness
for operating on sequential data of varying lengths with a real-world example.

NestedTensors are currently a prototype feature and are subject to change.

.. code-block:: default

    import torch
    import torch.nn.functional as F

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


NestedTensor Initialization
---------------------------

From the Python frontend, a nestedtensor can be created from a list of tensors.
We denote ``nt[i]`` as the ``i``-th tensor component of a nestedtensor.

.. code-block:: default

    nt = torch.nested.nested_tensor([torch.arange(12).reshape(2, 6),
                                     torch.arange(18).reshape(3, 6)],
                                    dtype=torch.float, device=device)
    print(f"{nt=}")


By padding every underlying tensor to the same shape,
a nestedtensor can be converted to a regular tensor.

.. code-block:: default

    padded_out_tensor = torch.nested.to_padded_tensor(nt, padding=0.0)
    print(f"{padded_out_tensor=}")


All tensors possess an ``is_nested`` attribute that indicates whether they are nested:

.. code-block:: default

    print(f"nt is nested: {nt.is_nested}")
    print(f"padded_out_tensor is nested: {padded_out_tensor.is_nested}")


It is common to construct nestedtensors from batches of irregularly shaped tensors,
i.e., dimension 0 is assumed to be the batch dimension.
Indexing dimension 0 returns the corresponding underlying tensor component;
``nt[0]``, for example, is the first one.

.. code-block:: default

    print("First underlying tensor component:", nt[0], sep='\n')
    print("Last column of the 2nd underlying tensor component:", nt[1, :, -1], sep='\n')

    # When indexing a nestedtensor's 0th dimension, the result is a regular tensor.
    print(f"First underlying tensor component is nested: {nt[0].is_nested}")


Note that slicing in dimension 0 is not supported yet,
which means it is not currently possible to construct a view that combines the underlying tensor components.

Nested Tensor Operations
------------------------

As each operation must be explicitly implemented for nestedtensors,
operation coverage for nestedtensors is currently narrower than that of regular tensors.
For now, only basic operations such as index, dropout, softmax, transpose, reshape, linear, and bmm are covered.
However, coverage is being expanded.
If you need certain operations, please file an `issue `__
to help us prioritize coverage.
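
Because slicing along dimension 0 is not supported, one way to select a subset of the batch is to split the nestedtensor into its components with ``unbind()`` and build a new nestedtensor from the ones you need. The following is a minimal sketch, assuming a PyTorch build in which ``unbind`` is implemented for nestedtensors; ``nt_demo``, ``components``, and ``subset`` are illustrative names only. Since ``torch.nested.nested_tensor`` copies its inputs, the result is a new tensor rather than a view.

.. code-block:: default

    import torch

    # A small nestedtensor with the same components as nt above, kept on the CPU;
    # a separate name is used so that the tutorial's nt is left untouched
    nt_demo = torch.nested.nested_tensor([torch.arange(12).reshape(2, 6),
                                          torch.arange(18).reshape(3, 6)],
                                         dtype=torch.float)

    # unbind() splits dimension 0 into a tuple of regular (non-nested) tensors
    # (assumes unbind() is implemented for nestedtensors in your PyTorch version)
    components = nt_demo.unbind()
    print([tuple(c.shape) for c in components])

    # Keep only the 2nd component and rebuild a nestedtensor from it;
    # nested_tensor() copies the data, so this is not a view into nt_demo
    subset = torch.nested.nested_tensor([components[1]])
    print(f"subset is nested: {subset.is_nested}, batch size: {subset.size(0)}")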

**reshape**

The reshape op is for changing the shape of a tensor.
Its full semantics for regular tensors can be found
`here `__.
For regular tensors, when specifying the new shape,
a single dimension may be -1, in which case it is inferred
from the remaining dimensions and the number of elements.

The semantics for nestedtensors are similar, except that -1 is no longer inferred.
Instead, it inherits the old size (here 2 for ``nt[0]`` and 3 for ``nt[1]``).
-1 is the only legal size to specify for a jagged dimension.

.. code-block:: default

    nt_reshaped = nt.reshape(2, -1, 2, 3)
    print(f"{nt_reshaped=}")


**transpose**

The transpose op is for swapping two dimensions of a tensor.
Its full semantics can be found
`here `__.
Note that for nestedtensors, dimension 0 is special;
it is assumed to be the batch dimension,
so transposes involving nestedtensor dimension 0 are not supported.

.. code-block:: default

    nt_transposed = nt_reshaped.transpose(1, 2)
    print(f"{nt_transposed=}")


**others**

Other operations have the same semantics as for regular tensors.
Applying the operation on a nestedtensor is equivalent to
applying the operation to the underlying tensor components,
with the result being a nestedtensor as well.

.. code-block:: default

    nt_mm = torch.nested.nested_tensor([torch.randn((2, 3, 4)), torch.randn((2, 3, 5))], device=device)
    nt3 = torch.matmul(nt_transposed, nt_mm)
    print(f"Result of Matmul:\n {nt3}")

    nt4 = F.dropout(nt3, 0.1)
    print(f"Result of Dropout:\n {nt4}")

    nt5 = F.softmax(nt4, -1)
    print(f"Result of Softmax:\n {nt5}")


Why Nested Tensor
-----------------

When data is sequential, it is often the case that each sample has a different length.
For example, in a batch of sentences, each sentence has a different number of words.
A common technique for handling varying sequences is to manually pad each data tensor
to the same shape in order to form a batch.
For instance, suppose we have 2 sentences of different lengths and a vocabulary that maps each word to a number.
In order to represent this batch as a single tensor, we pad with 0 to the maximum length in the batch.

.. code-block:: default

    sentences = [["goodbye", "padding"],
                 ["embrace", "nested", "tensor"]]
    vocabulary = {"goodbye": 1.0, "padding": 2.0,
                  "embrace": 3.0, "nested": 4.0, "tensor": 5.0}
    padded_sentences = torch.tensor([[1.0, 2.0, 0.0],
                                     [3.0, 4.0, 5.0]])
    nested_sentences = torch.nested.nested_tensor([torch.tensor([1.0, 2.0]),
                                                   torch.tensor([3.0, 4.0, 5.0])])
    print(f"{padded_sentences=}")
    print(f"{nested_sentences=}")


This technique of padding a batch of data to its maximum length is not optimal.
The padded data is not needed for computation and wastes memory by allocating
larger tensors than necessary.
Further, not all operations have the same semantics when applied to padded data.
For matrix multiplication, one needs to pad with 0 to ignore the padded entries,
while for softmax one has to pad with -inf.

.. code-block:: default

    padded_sentences_for_softmax = torch.tensor([[1.0, 2.0, float("-inf")],
                                                 [3.0, 4.0, 5.0]])
    print(F.softmax(padded_sentences_for_softmax, -1))
    print(F.softmax(nested_sentences, -1))


Let us take a look at a practical example: the multi-head attention component
utilized in `Transformers `__.
The nestedtensor version is straightforward.

.. code-block:: default

    import math

    def mha_nested(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, nheads: int,
                   W_q: torch.Tensor, W_k: torch.Tensor, W_v: torch.Tensor, W_out: torch.Tensor,
                   b_q: torch.Tensor = None, b_k: torch.Tensor = None, b_v: torch.Tensor = None, b_out: torch.Tensor = None,
                   dropout_p: float = 0.0) -> torch.Tensor:
        """Compute multi-head attention with nested tensors.

        Args:
            query (torch.Tensor): query of shape (N, L_t, E_q)
            key (torch.Tensor): key of shape (N, L_s, E_k)
            value (torch.Tensor): value of shape (N, L_s, E_v)
            nheads (int): number of heads in multi-head attention
            W_q (torch.Tensor): Weight for query input projection of shape (E_total, E_q)
            W_k (torch.Tensor): Weight for key input projection of shape (E_total, E_k)
            W_v (torch.Tensor): Weight for value input projection of shape (E_total, E_v)
            W_out (torch.Tensor): Weight for output projection of shape (E_out, E_total)
            b_q (torch.Tensor, optional): Bias for query input projection of shape E_total. Defaults to None.
            b_k (torch.Tensor, optional): Bias for key input projection of shape E_total. Defaults to None.
            b_v (torch.Tensor, optional): Bias for value input projection of shape E_total. Defaults to None.
            b_out (torch.Tensor, optional): Bias for output projection of shape E_out. Defaults to None.
            dropout_p (float, optional): Dropout probability. Defaults to 0.0.
        Where:
            N is the batch size
            L_t is the target sequence length (jagged)
            L_s is the source sequence length (jagged)
            E_q is the embedding size for query
            E_k is the embedding size for key
            E_v is the embedding size for value
            E_total is the embedding size for all heads combined
            E_out is the output embedding size
        Returns:
            torch.Tensor: Output of shape (N, L_t, E_out)
        """

        N = query.size(0)
        E_total = W_q.size(0)
        assert E_total % nheads == 0, "Embedding dim is not divisible by nheads"
        E_head = E_total // nheads

        # apply input projection
        # (N, L_t, E_q) -> (N, L_t, E_total)
        query = F.linear(query, W_q, b_q)
        # (N, L_s, E_k) -> (N, L_s, E_total)
        key = F.linear(key, W_k, b_k)
        # (N, L_s, E_v) -> (N, L_s, E_total)
        value = F.linear(value, W_v, b_v)

        # reshape query, key, value to separate by head
        # (N, L_t, E_total) -> (N, L_t, nheads, E_head) -> (N, nheads, L_t, E_head)
        query = query.reshape(N, -1, nheads, E_head).transpose(1, 2)
        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
        key = key.reshape(N, -1, nheads, E_head).transpose(1, 2)
        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
        value = value.reshape(N, -1, nheads, E_head).transpose(1, 2)

        # query matmul key^T
        # (N, nheads, L_t, E_head) x (N, nheads, L_s, E_head)^T -> (N, nheads, L_t, L_s)
        keyT = key.transpose(-1, -2)
        attn_weights = torch.matmul(query, keyT)

        # scale down
        attn_weights = attn_weights * (1.0 / math.sqrt(E_head))

        # softmax
        attn_weights = F.softmax(attn_weights, dim=-1)

        # dropout
        if dropout_p > 0.0:
            attn_weights = F.dropout(attn_weights, p=dropout_p)

        # attention_weights matmul value
        # (N, nheads, L_t, L_s) x (N, nheads, L_s, E_head) -> (N, nheads, L_t, E_head)
        attn_output = torch.matmul(attn_weights, value)

        # merge heads
        # (N, nheads, L_t, E_head) -> (N, L_t, nheads, E_head) -> (N, L_t, E_total)
        attn_output = attn_output.transpose(1, 2).reshape(N, -1, E_total)

        # apply output projection
        # (N, L_t, E_total) -> (N, L_t, E_out)
        attn_output = F.linear(attn_output, W_out,
                               b_out)

        return attn_output


The zero-padded tensor version additionally requires attention masks, so that the padded entries
can be excluded from the softmax and zeroed out in the output.

.. code-block:: default

    def mha_padded(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, nheads: int,
                   attn_mask_q: torch.Tensor, attn_mask_kv: torch.Tensor,
                   W_q: torch.Tensor, W_k: torch.Tensor, W_v: torch.Tensor, W_out: torch.Tensor,
                   b_q: torch.Tensor = None, b_k: torch.Tensor = None, b_v: torch.Tensor = None, b_out: torch.Tensor = None,
                   dropout_p: float = 0.0) -> torch.Tensor:
        """Compute multi-head attention for padded out dense tensors.

        Args:
            query (torch.Tensor): query of shape (N, L_t, E_q)
            key (torch.Tensor): key of shape (N, L_s, E_k)
            value (torch.Tensor): value of shape (N, L_s, E_v)
            nheads (int): number of heads in multi-head attention
            attn_mask_q (torch.Tensor): boolean mask indicating locations that should not take part in attention for query, shape (N, L_t)
            attn_mask_kv (torch.Tensor): boolean mask indicating locations that should not take part in attention for key and value, shape (N, L_s)
            W_q (torch.Tensor): Weight for query input projection of shape (E_total, E_q)
            W_k (torch.Tensor): Weight for key input projection of shape (E_total, E_k)
            W_v (torch.Tensor): Weight for value input projection of shape (E_total, E_v)
            W_out (torch.Tensor): Weight for output projection of shape (E_out, E_total)
            b_q (torch.Tensor, optional): Bias for query input projection of shape E_total. Defaults to None.
            b_k (torch.Tensor, optional): Bias for key input projection of shape E_total. Defaults to None.
            b_v (torch.Tensor, optional): Bias for value input projection of shape E_total. Defaults to None.
            b_out (torch.Tensor, optional): Bias for output projection of shape E_out. Defaults to None.
            dropout_p (float, optional): Dropout probability. Defaults to 0.0.
        Where:
            N is the batch size
            L_t is the target sequence length (padded)
            L_s is the source sequence length (padded)
            E_q is the embedding size for query
            E_k is the embedding size for key
            E_v is the embedding size for value
            E_total is the embedding size for all heads combined
            E_out is the output embedding size
        Returns:
            torch.Tensor: Output of shape (N, L_t, E_out)
        """
        N = query.size(0)
        L_t = query.size(1)
        L_s = key.size(1)
        E_total = W_q.size(0)
        assert E_total % nheads == 0, "Embedding dim is not divisible by nheads"
        assert L_t == L_s, "This implementation assumes equal query and key sequence lengths"
        E_head = E_total // nheads

        # apply input projection
        # (N, L_t, E_q) -> (N, L_t, E_total)
        query = F.linear(query, W_q, b_q)
        # (N, L_s, E_k) -> (N, L_s, E_total)
        key = F.linear(key, W_k, b_k)
        # (N, L_s, E_v) -> (N, L_s, E_total)
        value = F.linear(value, W_v, b_v)

        # reshape query, key, value to separate by head
        # (N, L_t, E_total) -> (N, L_t, nheads, E_head) -> (N, nheads, L_t, E_head) -> (N * nheads, L_t, E_head)
        query = query.reshape(N, -1, nheads, E_head).transpose(1, 2).reshape(N * nheads, -1, E_head)
        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head) -> (N * nheads, L_s, E_head)
        key = key.reshape(N, -1, nheads, E_head).transpose(1, 2).reshape(N * nheads, -1, E_head)
        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head) -> (N * nheads, L_s, E_head)
        value = value.reshape(N, -1, nheads, E_head).transpose(1, 2).reshape(N * nheads, -1, E_head)

        # query bmm key^T
        # (N * nheads, L_t, E_head) x (N * nheads, L_s, E_head)^T -> (N * nheads, L_t, L_s)
        keyT = key.transpose(-1, -2)
        attn_weights = torch.bmm(query, keyT)

        # scale down
        attn_weights = attn_weights * (1.0 / math.sqrt(E_head))

        # Have to manipulate the key padding mask in order to apply it to the attention weights:
        # (N, L_s) -> (N, 1, 1, L_s) -> (N, nheads, 1, L_s) -> (N * nheads, 1, L_s)
        key_padding_mask = attn_mask_kv.view(N, 1, 1, L_s).expand(-1, nheads, -1, -1).reshape(N * nheads, 1, L_s).to(device=query.device)
        attn_mask = torch.zeros(key_padding_mask.shape, device=query.device,
                                dtype=torch.float32)
        attn_mask = attn_mask.masked_fill_(key_padding_mask, float("-inf"))
        # Zero out the attention weights where the mask is True by adding -inf prior to softmax
        attn_weights.add_(attn_mask)

        # softmax
        attn_weights = F.softmax(attn_weights, dim=-1).nan_to_num_(0.0)

        # dropout
        if dropout_p > 0.0:
            attn_weights = F.dropout(attn_weights, p=dropout_p)

        # attention_weights bmm value
        # (N * nheads, L_t, L_s) x (N * nheads, L_s, E_head) -> (N * nheads, L_t, E_head)
        attn_output = attn_weights.bmm(value)

        # merge heads
        # (N * nheads, L_t, E_head) -> (N, nheads, L_t, E_head) -> (N, L_t, nheads, E_head) -> (N, L_t, E_total)
        attn_output = attn_output.reshape(N, nheads, -1, E_head).transpose(1, 2).reshape(N, -1, E_total)

        # apply output projection
        # (N, L_t, E_total) -> (N, L_t, E_out)
        attn_output = F.linear(attn_output, W_out, b_out)

        # padding-specific step: remove output projection bias from padded entries
        attn_output[attn_mask_q, :] = 0.0

        return attn_output


Set hyperparameters following `the Transformer paper `__.

.. code-block:: default

    N = 512
    E_q, E_k, E_v, E_total, E_out = 512, 512, 512, 512, 512
    nheads = 8


The only exception is the dropout probability, which is set to 0 so that the two implementations can be checked for correctness.

.. code-block:: default

    dropout_p = 0.0


Let us generate some realistic fake data from Zipf's law.

.. code-block:: default

    import numpy as np

    def zipf_sentence_lengths(alpha: float, batch_size: int) -> np.ndarray:
        # generate fake corpus by unigram Zipf distribution
        # from wikitext-2 corpus, we get rank "." = 3, "!" = 386, "?" = 858
        sentence_lengths = np.empty(batch_size, dtype=int)
        for ibatch in range(batch_size):
            sentence_lengths[ibatch] = 1
            word = np.random.zipf(alpha)
            while word != 3 and word != 386 and word != 858:
                sentence_lengths[ibatch] += 1
                word = np.random.zipf(alpha)
        return sentence_lengths

    alpha = 1.2

    sentence_lengths = zipf_sentence_lengths(alpha, N)
    L_t = int(np.max(sentence_lengths))
    L_s = L_t


Create the inputs.

.. code-block:: default

    # create parameters
    W_q, b_q = torch.randn((E_total, E_q), device=device), torch.randn(E_total, device=device)
    W_k, b_k = torch.randn((E_total, E_k), device=device), torch.randn(E_total, device=device)
    W_v, b_v = torch.randn((E_total, E_v), device=device), torch.randn(E_total, device=device)
    W_out, b_out = torch.randn((E_out, E_total), device=device), torch.randn(E_out, device=device)

    # create nested input
    queries = []
    keys = []
    values = []
    for i in range(N):
        l = sentence_lengths[i]
        s = l
        queries.append(torch.randn((l, E_q), device=device))
        keys.append(torch.randn((s, E_k), device=device))
        values.append(torch.randn((s, E_v), device=device))
    query = torch.nested.nested_tensor(queries)
    key = torch.nested.nested_tensor(keys)
    value = torch.nested.nested_tensor(values)

    # pad input
    padded_query = torch.nested.to_padded_tensor(query, 0.0, (N, L_t, E_q))
    padded_key = torch.nested.to_padded_tensor(key, 0.0, (N, L_s, E_k))
    padded_value = torch.nested.to_padded_tensor(value, 0.0, (N, L_s, E_v))

    # create attention masks
    attn_mask_q = torch.zeros((N, L_t), dtype=torch.bool)
    attn_mask_kv = torch.zeros((N, L_s), dtype=torch.bool)

    # We need to mask out the padding entries in the attention weights.
    for i, entry_length in enumerate(sentence_lengths):
        attn_mask_q[i, entry_length:] = True
        attn_mask_kv[i, entry_length:] = True


Check correctness and performance.

.. code-block:: default

    import timeit

    def sync():
        # CUDA kernels run asynchronously, so wait for them to finish before reading the timer
        if device.type == "cuda":
            torch.cuda.synchronize()

    sync()
    t0 = timeit.default_timer()
    out_nested = mha_nested(
        query, key, value, nheads,
        W_q, W_k, W_v, W_out,
        b_q=b_q, b_k=b_k, b_v=b_v, b_out=b_out,
        dropout_p=dropout_p)

    sync()
    t1 = timeit.default_timer()
    out_padded = mha_padded(
        padded_query, padded_key, padded_value, nheads,
        attn_mask_q, attn_mask_kv,
        W_q, W_k, W_v, W_out,
        b_q=b_q, b_k=b_k, b_v=b_v, b_out=b_out,
        dropout_p=dropout_p)
    sync()
    t2 = timeit.default_timer()

    print("nested and padded calculations differ by", (torch.nested.to_padded_tensor(out_nested, 0.0, (N, L_t, E_out)) - out_padded).abs().max().item())
    print("nestedtensor multi-head attention takes", t1 - t0, "seconds")
    print("padded tensor multi-head attention takes", t2 - t1, "seconds")


Although the nestedtensor version avoids wasted computation on padding, it is not faster than the equivalent padded tensor version.
This is because the nestedtensor version implements a few of the kernels, like softmax, in a non-optimal way.

There are plans to implement performance-critical operations using the new PyTorch 2.0 stack.
For now, some performant kernels are provided for specific use cases, e.g.,
self-attention evaluation by the multi-head attention formula.

.. code-block:: default

    # embeddings are assumed to be the same
    E = E_total
    mha_lib = torch.nn.MultiheadAttention(E, nheads, batch_first=True, device=device)
    mha_lib.eval()


Extract the parameters for the correctness check.

.. code-block:: default

    mha_lib.in_proj_weight.requires_grad_(False)
    mha_lib.in_proj_bias.requires_grad_(False)
    mha_lib.out_proj.weight.requires_grad_(False)
    mha_lib.out_proj.bias.requires_grad_(False)
    W_q, b_q = mha_lib.in_proj_weight[: E, :], mha_lib.in_proj_bias[: E]
    W_k, b_k = mha_lib.in_proj_weight[E : 2 * E, :], mha_lib.in_proj_bias[E : 2 * E]
    W_v, b_v = mha_lib.in_proj_weight[2 * E :, :], mha_lib.in_proj_bias[2 * E :]
    W_out, b_out = mha_lib.out_proj.weight, mha_lib.out_proj.bias


Setting ``need_weights`` to ``False`` enables the fast path in the library.
Under the hood, this calls ``_scaled_dot_product_attention``.
If your tensors are on CUDA, then a fused, efficient attention kernel will be used.
For more detailed performance characteristics, look at the benchmark in
``pytorch/benchmarks/transformer/sdp.py``.

.. code-block:: default

    with torch.inference_mode():
        sync()
        t0 = timeit.default_timer()
        out_lib, out_lib_weights = mha_lib(query, query, query, need_weights=False)

        sync()
        t1 = timeit.default_timer()
        padded_out = mha_padded(
            padded_query, padded_query, padded_query, nheads,
            attn_mask_q, attn_mask_q,
            W_q, W_k, W_v, W_out,
            b_q=b_q, b_k=b_k, b_v=b_v, b_out=b_out,
            dropout_p=dropout_p)
        sync()
        t2 = timeit.default_timer()

    nested_time = t1 - t0
    padded_time = t2 - t1
    print("Nested and padded calculations differ by", (torch.nested.to_padded_tensor(out_lib, 0.0) - padded_out).abs().max().item())
    print("Nested library multi-head attention takes", nested_time, "seconds")
    print("Padded tensor multi-head attention takes", padded_time, "seconds")
    print(f"Nested Speedup: {padded_time / nested_time:.3f}")