
sehkmg/Transformer-Navigation


Transformer-Navigation

Navigate inside the Transformer architecture.

This code is heavily based on the code in the blog post "The Annotated Transformer". For more information, read the paper "Attention Is All You Need" or the post "The Annotated Transformer".

Transformer Architecture

Tracking tensors


0. Vocabulary
source vocabulary: ['etudiant', 'je', 'quel', 'suis', 'mois', '<blank>']
target vocabulary: ['a', 'am', 'month', 'i', 'student', 'what', '<blank>', '<s>', '</s>']

1. Preparing the batch and masks

1-1. Batch (text)
source batch (text): [['je', 'suis', 'etudiant'], ['quel', 'mois', '<blank>']]
target batch (text): [['<s>', 'i', 'am', 'a', 'student'], ['<s>', 'what', 'month', '</s>', '<blank>']]
target batch true output: [['i', 'am', 'a', 'student', '</s>'], ['what', 'month', '</s>', '<blank>', '<blank>']]
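The decoder input and the true output are the same padded target sequence shifted by one position. The plain-Python sketch below assumes, as in The Annotated Transformer's `Batch` class (which slices `trg[:, :-1]` and `trg[:, 1:]`), that each target sentence is wrapped in `<s>`/`</s>` and right-padded with `<blank>`. `shift_target` is a hypothetical helper, not part of this repository.

```python
# Sketch (assumption): full target = <s> + tokens + </s>, right-padded with <blank>;
# the decoder input drops the last token and the true output drops <s>.
PAD = "<blank>"


def shift_target(sentences, pad=PAD):
    """Return (decoder_input, true_output) token lists for a batch of sentences."""
    full = [["<s>"] + s + ["</s>"] for s in sentences]
    max_len = max(len(f) for f in full)
    full = [f + [pad] * (max_len - len(f)) for f in full]  # right-pad to equal length
    decoder_input = [f[:-1] for f in full]  # every token except the last
    true_output = [f[1:] for f in full]     # every token except <s>
    return decoder_input, true_output


dec_in, dec_out = shift_target([["i", "am", "a", "student"], ["what", "month"]])
print(dec_in)   # [['<s>', 'i', 'am', 'a', 'student'], ['<s>', 'what', 'month', '</s>', '<blank>']]
print(dec_out)  # [['i', 'am', 'a', 'student', '</s>'], ['what', 'month', '</s>', '<blank>', '<blank>']]
```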

1-2. Batch (number)
source batch (number):
tensor([[ 1,  3,  0],
        [ 2,  4,  5]])

target batch (number):
tensor([[ 7,  3,  1,  0,  4],
        [ 7,  5,  2,  8,  6]])
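Each number is simply the token's position in the corresponding vocabulary list from step 0 (for example `'je'` is 1 and `'<blank>'` is 5 in the source vocabulary). A minimal sketch with a hypothetical `numericalize` helper reproduces both tensors as nested lists:

```python
SRC_VOCAB = ['etudiant', 'je', 'quel', 'suis', 'mois', '<blank>']
TGT_VOCAB = ['a', 'am', 'month', 'i', 'student', 'what', '<blank>', '<s>', '</s>']


def numericalize(batch, vocab):
    """Replace every token by its index in the vocabulary list."""
    stoi = {tok: i for i, tok in enumerate(vocab)}  # string-to-index lookup
    return [[stoi[tok] for tok in sentence] for sentence in batch]


src_ids = numericalize([['je', 'suis', 'etudiant'], ['quel', 'mois', '<blank>']], SRC_VOCAB)
tgt_ids = numericalize([['<s>', 'i', 'am', 'a', 'student'],
                        ['<s>', 'what', 'month', '</s>', '<blank>']], TGT_VOCAB)
print(src_ids)  # [[1, 3, 0], [2, 4, 5]]
print(tgt_ids)  # [[7, 3, 1, 0, 4], [7, 5, 2, 8, 6]]
```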

1-3. Mask
source mask:
tensor([[[ 1,  1,  1]],

        [[ 1,  1,  0]]], dtype=torch.uint8)

target mask:
tensor([[[ 1,  0,  0,  0,  0],
         [ 1,  1,  0,  0,  0],
         [ 1,  1,  1,  0,  0],
         [ 1,  1,  1,  1,  0],
         [ 1,  1,  1,  1,  1]],

        [[ 1,  0,  0,  0,  0],
         [ 1,  1,  0,  0,  0],
         [ 1,  1,  1,  0,  0],
         [ 1,  1,  1,  1,  0],
         [ 1,  1,  1,  1,  0]]], dtype=torch.uint8)
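The source mask only hides padding, while the target mask combines a padding mask with a subsequent (look-ahead) mask, so position i can attend only to positions 0..i. That is why the last row of the second target mask ends in 0: position 4 holds `<blank>`. The Annotated Transformer builds the target mask the same way. The sketch below uses hypothetical plain-Python helpers and reproduces both masks from the numeric batches.

```python
def source_mask(ids, pad_idx):
    """Shape (batch, 1, src_len): 1 for real tokens, 0 for padding."""
    return [[[int(t != pad_idx) for t in seq]] for seq in ids]


def subsequent_mask(size):
    """Lower-triangular (size, size) mask: position i may look at positions j <= i."""
    return [[int(j <= i) for j in range(size)] for i in range(size)]


def target_mask(ids, pad_idx):
    """Shape (batch, tgt_len, tgt_len): padding mask AND subsequent mask."""
    size = len(ids[0])
    sub = subsequent_mask(size)
    masks = []
    for seq in ids:
        keep = [int(t != pad_idx) for t in seq]  # padding mask, shared by every row
        masks.append([[keep[j] & sub[i][j] for j in range(size)] for i in range(size)])
    return masks


src_mask = source_mask([[1, 3, 0], [2, 4, 5]], pad_idx=5)              # '<blank>' is 5 in the source vocab
tgt_mask = target_mask([[7, 3, 1, 0, 4], [7, 5, 2, 8, 6]], pad_idx=6)  # '<blank>' is 6 in the target vocab
print(src_mask)        # [[[1, 1, 1]], [[1, 1, 0]]]
print(tgt_mask[1][4])  # [1, 1, 1, 1, 0]
```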

2. Encoding part

2-1. Source batch embedding + positional encoding (src_embed)
tensor([[[-0.0991, -0.3589, -0.5712, -0.2789,  1.8014,  0.9144],
         [ 0.4986, -0.9652, -0.4070,  1.2863, -1.8285, -0.4280],
         [ 1.9674, -1.2571, -1.0147,  1.0888,  1.1725,  2.6795]],

        [[ 1.3305,  0.0000, -1.0701, -0.5923, -1.0296,  1.6793],
         [ 0.8715,  1.3400, -0.0303,  0.0000,  1.6201,  1.5863],
         [ 1.4312, -1.9720,  1.0480,  0.1499,  0.5277,  0.4981]]])
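The learned embedding table is not printed, so the values above cannot be recomputed. The sketch below only shows the computation, assuming the formulation of "Attention Is All You Need" and The Annotated Transformer: embeddings are scaled by sqrt(d_model) and a sinusoidal positional encoding is added. The exact `0.0000` entries (for example in the second sentence) are consistent with dropout applied afterwards, as The Annotated Transformer's `PositionalEncoding` does. `embed` and its `table` argument are hypothetical.

```python
import math


def positional_encoding(max_len, d_model):
    """PE[pos][2i] = sin(pos / 10000^(2i/d_model)), PE[pos][2i+1] = cos(same angle)."""
    pe = [[0.0] * d_model for _ in range(max_len)]
    for pos in range(max_len):
        for i in range(0, d_model, 2):  # i plays the role of the even index 2i
            angle = pos / (10000 ** (i / d_model))
            pe[pos][i] = math.sin(angle)
            if i + 1 < d_model:
                pe[pos][i + 1] = math.cos(angle)
    return pe


def embed(token_ids, table, d_model):
    """Look up tokens in a hypothetical embedding table, scale by sqrt(d_model), add PE."""
    pe = positional_encoding(len(token_ids), d_model)
    scale = math.sqrt(d_model)
    return [[table[t][k] * scale + pe[pos][k] for k in range(d_model)]
            for pos, t in enumerate(token_ids)]


pe = positional_encoding(3, 6)  # d_model = 6, as in the tensors above
print(pe[0])  # [0.0, 1.0, 0.0, 1.0, 0.0, 1.0]
```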

2-2. MultiHeadedAttention(src_embed (query), src_embed (key), src_embed (value), src_mask)
Q:
tensor([[[[ 0.2873,  0.2007],
          [-0.9330,  0.9019],
          [-0.0082,  1.6451]],

         [[-0.5665, -0.7560],
          [ 1.1108,  0.1960],
          [ 0.9804, -0.2669]],

         [[ 0.3628, -0.4734],
          [-0.1485,  0.5912],
          [ 0.4117,  0.2190]]],


        [[[ 0.4587,  1.4866],
          [ 0.5982,  0.2759],
          [ 0.4676,  0.0464]],

         [[ 1.8160,  0.3228],
          [ 0.3496,  0.1092],
          [ 0.0640, -1.3059]],

         [[-0.0644,  0.6141],
          [ 0.2253, -0.3131],
          [-0.0457, -0.2041]]]])
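The printed Q has shape (batch = 2, heads = 3, seq_len = 3, d_k = 2), so each 6-dimensional vector is split into h = 3 heads of d_k = 2. The Annotated Transformer does this after the query/key/value projections with `view(nbatches, -1, h, d_k).transpose(1, 2)`; the hypothetical list-based `split_heads` below performs the same contiguous slicing on a toy input.

```python
def split_heads(x, h):
    """Split one sentence (seq_len, d_model) into h heads of shape (seq_len, d_k)."""
    d_model = len(x[0])
    if d_model % h != 0:
        raise ValueError("d_model must be divisible by the number of heads")
    d_k = d_model // h
    # head `hd` takes the slice [hd * d_k, (hd + 1) * d_k) of every position
    return [[row[hd * d_k:(hd + 1) * d_k] for row in x] for hd in range(h)]


x = [[0, 1, 2, 3, 4, 5],
     [6, 7, 8, 9, 10, 11]]  # toy (seq_len = 2, d_model = 6) input
print(split_heads(x, h=3))  # [[[0, 1], [6, 7]], [[2, 3], [8, 9]], [[4, 5], [10, 11]]]
```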

K^T:
tensor([[[[-0.6992,  1.4689,  0.7740],
          [-0.8642,  1.2513,  0.2392]],

         [[-0.6000, -0.5576, -1.6619],
          [ 0.1727,  0.3718,  0.8012]],

         [[-0.1372, -0.2086,  0.4862],
          [ 0.2631, -1.5935, -1.8292]]],


        [[[ 0.7847, -0.6059,  0.3646],
          [-0.0430, -0.8606, -0.0435]],

         [[-1.3169, -0.4885, -1.3489],
          [ 0.5333,  0.1433, -1.1985]],

         [[ 1.4939,  0.8204, -0.8060],
          [-2.0194,  0.1332, -1.2614]]]])

Q * K^T / sqrt(d_k), with d_k = 2:
tensor([[[[-0.2647,  0.4760,  0.1912],
          [-0.0899, -0.1711, -0.3580],
          [-1.0013,  1.4471,  0.2738]],

         [[ 0.1480,  0.0246,  0.2375],
          [-0.4473, -0.3864, -1.1942],
          [-0.4486, -0.4567, -1.3033]],

         [[-0.1233,  0.4799,  0.7370],
          [ 0.1244, -0.6442, -0.8158],
          [ 0.0008, -0.3075, -0.1417]]],


        [[[ 0.2093, -1.1012,  0.0725],
          [ 0.3235, -0.4242,  0.1457],
          [ 0.2580, -0.2286,  0.1191]],

         [[-1.5694, -0.5946, -2.0058],
          [-0.2844, -0.1097, -0.4260],
          [-0.5520, -0.1544,  1.0457]],

         [[-0.9450,  0.0205, -0.5111],
          [ 0.6851,  0.1012,  0.1509],
          [ 0.2431, -0.0457,  0.2081]]]])
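These scores are already divided by sqrt(d_k) = sqrt(2). For the first sentence and first head, Q row 1 · K^T column 1 = 0.2873 × (-0.6992) + 0.2007 × (-0.8642) ≈ -0.3743, and -0.3743 / sqrt(2) ≈ -0.2647, which is the first printed entry. The sketch below (hypothetical helper `scaled_scores`) recomputes that head's whole score matrix from the printed Q and K^T.

```python
import math


def scaled_scores(q, k_t):
    """(seq_q, d_k) times (d_k, seq_k), divided by sqrt(d_k)."""
    d_k = len(k_t)
    return [[sum(q_row[m] * k_t[m][j] for m in range(d_k)) / math.sqrt(d_k)
             for j in range(len(k_t[0]))]
            for q_row in q]


# First sentence, first head of step 2-2 (values copied from the printout)
Q = [[0.2873, 0.2007], [-0.9330, 0.9019], [-0.0082, 1.6451]]
K_T = [[-0.6992, 1.4689, 0.7740], [-0.8642, 1.2513, 0.2392]]
scores = scaled_scores(Q, K_T)
print([[round(s, 4) for s in row] for row in scores])
```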

mask:
tensor([[[[ 1,  1,  1]]],


        [[[ 1,  1,  0]]]], dtype=torch.uint8)

softmax(Q * K^T / sqrt(d_k)) with masking:
tensor([[[[ 0.2139,  0.4486,  0.3375],
          [ 0.3722,  0.3432,  0.2846],
          [ 0.0619,  0.7164,  0.2216]],

         [[ 0.3359,  0.2969,  0.3673],
          [ 0.3942,  0.4190,  0.1868],
          [ 0.4137,  0.4103,  0.1760]],

         [[ 0.1926,  0.3521,  0.4553],
          [ 0.5393,  0.2500,  0.2106],
          [ 0.3843,  0.2824,  0.3333]]],


        [[[ 0.7876,  0.2124,  0.0000],
          [ 0.6787,  0.3213,  0.0000],
          [ 0.6193,  0.3807,  0.0000]],

         [[ 0.2739,  0.7261,  0.0000],
          [ 0.4564,  0.5436,  0.0000],
          [ 0.4019,  0.5981,  0.0000]],

         [[ 0.2758,  0.7242,  0.0000],
          [ 0.6420,  0.3580,  0.0000],
          [ 0.5717,  0.4283,  0.0000]]]])
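With the source mask [1, 1, 0] for the second sentence, the padded key position receives probability 0. Following The Annotated Transformer, which fills masked scores with -1e9 before the softmax, the sketch below recomputes the first head of the second sentence from its printed scores. `masked_softmax` is a hypothetical helper; the single mask row is repeated for every query, mirroring how the (batch, 1, 1, src_len) mask broadcasts.

```python
import math


def masked_softmax(scores, mask, fill=-1e9):
    """Row-wise softmax; entries whose mask value is 0 are replaced by `fill` first."""
    probs = []
    for row, keep_row in zip(scores, mask):
        filled = [s if keep else fill for s, keep in zip(row, keep_row)]
        top = max(filled)  # subtract the row maximum for numerical stability
        exps = [math.exp(v - top) for v in filled]
        total = sum(exps)
        probs.append([e / total for e in exps])
    return probs


# Second sentence, first head: printed scaled scores and the source mask row [1, 1, 0]
scores = [[0.2093, -1.1012, 0.0725],
          [0.3235, -0.4242, 0.1457],
          [0.2580, -0.2286, 0.1191]]
probs = masked_softmax(scores, [[1, 1, 0]] * 3)  # same mask row for every query
print([[round(p, 4) for p in row] for row in probs])
```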

V:
tensor([[[[ 0.0429,  1.2085],
          [-0.6691, -0.5429],
          [-0.3989,  1.0405]],

         [[ 0.1204,  0.7988],
          [-0.4093, -1.5120],
          [ 0.2817,  0.3272]],

         [[-1.6174,  0.0450],
          [ 0.1738,  0.8663],
          [-2.2803,  0.3509]]],


        [[[-0.2660, -0.3585],
          [ 0.1631,  0.3127],
          [ 0.2958,  1.5232]],

         [[ 0.4697,  0.7489],
          [-0.1156,  1.5091],
          [ 1.4921, -0.5300]],

         [[-2.0205, -0.7742],
          [-1.4917, -0.9358],
          [-2.0415,  0.3973]]]])

softmax(Q * K^T / sqrt(d_k)) * V:
tensor([[[[-0.4729,  0.4067],
          [-0.3635,  0.6218],
          [-0.6309, -0.1760]],

         [[ 0.0248, -0.0672],
          [-0.0794, -0.2861],
          [-0.0762, -0.2582]],

         [[-1.4318,  0.5260],
          [-0.9209,  0.2676],
          [-1.4806,  0.4209]]],


        [[[-0.1943, -0.2399],
          [-0.1424, -0.1587],
          [-0.1141, -0.1144]],

         [[ 0.0497,  1.4454],
          [ 0.1684,  1.2913],
          [ 0.1329,  1.3373]],

         [[-1.8195, -0.9902],
          [-2.0346, -0.9245],
          [-1.9933, -0.9371]]]])
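Multiplying the printed weights by V does not give this result directly: for the first row of the first head, 0.2139 × 0.0429 + 0.4486 × (-0.6691) + 0.3375 × (-0.3989) ≈ -0.4256, whereas -0.4729 is printed. The ratio is 0.9, which is consistent with dropout (p = 0.1, rescaling surviving weights by 1/(1 - p)) being applied to the attention weights before this product, as in The Annotated Transformer's `attention` function. This is an inference from the numbers, not something the repository states. `weighted_sum` is a hypothetical helper.

```python
def weighted_sum(weights, v, keep_scale=1.0):
    """(seq_q, seq_k) weights times (seq_k, d_k) values -> (seq_q, d_k).

    keep_scale = 1 / (1 - p) mimics inverted dropout when no weight was dropped.
    """
    return [[keep_scale * sum(w[j] * v[j][c] for j in range(len(v)))
             for c in range(len(v[0]))]
            for w in weights]


P = [[0.2139, 0.4486, 0.3375]]  # first row, first sentence, first head
V = [[0.0429, 1.2085], [-0.6691, -0.5429], [-0.3989, 1.0405]]
plain = weighted_sum(P, V)                         # about [[-0.4256, 0.3661]]
rescaled = weighted_sum(P, V, keep_scale=1 / 0.9)  # about [[-0.4729, 0.4068]]
print(plain, rescaled)
```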

Merge the heads, apply the output linear map, and add the result to src_embed (residual connection)
tensor([[[ 0.6415, -0.8070, -0.0989, -0.3344,  2.1601,  0.1320],
         [ 0.4986, -1.6185, -0.0037,  0.9838, -2.2176, -0.7925],
         [ 2.7234, -1.5721, -0.1605,  1.3591,  1.6872,  1.6879]],

        [[ 2.7209,  0.0000, -1.0673, -0.0841,  0.0637,  1.5050],
         [ 2.2868,  1.0852,  0.0689,  0.4620,  1.6201,  1.4560],
         [ 2.8498, -2.2118,  1.0480,  0.6093,  1.6342,  0.3860]]])
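Merging concatenates the three 2-dimensional head outputs at each position back into one 6-dimensional vector. The output projection that follows is not printed, so the final numbers cannot be recomputed here. In The Annotated Transformer the residual step is `x + dropout(sublayer(norm(x)))`; entries that equal the input exactly (for example 0.4986 at the start of the second row) are consistent with dropout zeroing those positions of the sublayer output. `merge_heads` and `add_residual` are hypothetical helpers.

```python
def merge_heads(heads):
    """(h, seq_len, d_k) -> (seq_len, h * d_k): concatenate the heads at each position."""
    return [[val for head in heads for val in head[pos]] for pos in range(len(heads[0]))]


def add_residual(x, sublayer_out):
    """Element-wise x + sublayer(x), the residual connection around a sub-layer."""
    return [[a + b for a, b in zip(xr, sr)] for xr, sr in zip(x, sublayer_out)]


# First position of the first source sentence, one row per head (copied from step 2-2)
heads = [[[-0.4729, 0.4067]], [[0.0248, -0.0672]], [[-1.4318, 0.5260]]]
merged = merge_heads(heads)
print(merged)  # [[-0.4729, 0.4067, 0.0248, -0.0672, -1.4318, 0.526]]
```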

2-3. X = X + FF(X), then layer normalization (every row below has mean 0 and standard deviation 1)
tensor([[[ 1.8808, -0.3297, -0.3010, -0.6214,  0.2684, -0.8972],
         [ 0.7203, -0.1711,  0.7326,  0.8862, -1.7084, -0.4596],
         [ 1.5323, -0.8540, -1.1683,  0.2520, -0.3636,  0.6015]],

        [[ 1.2566,  0.3957, -1.6387, -0.2336, -0.3900,  0.6101],
         [ 0.9580,  1.0129, -1.4811, -0.2915,  0.5165, -0.7148],
         [ 1.8221, -1.1817,  0.1599, -0.3333, -0.0959, -0.3710]]])
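The position-wise feed-forward network from the paper is FFN(x) = max(0, x W1 + b1) W2 + b2, applied to each position independently. Its trained weights and inner size d_ff are not printed, so the toy weights and d_ff = 3 below are made up purely for illustration; the dropout that The Annotated Transformer applies between the two layers is omitted. `linear` and `feed_forward` are hypothetical helpers.

```python
def linear(x, w, b):
    """Row-wise x W + b with w of shape (d_in, d_out)."""
    return [[sum(row[i] * w[i][o] for i in range(len(row))) + b[o] for o in range(len(b))]
            for row in x]


def feed_forward(x, w1, b1, w2, b2):
    """FFN(x) = max(0, x W1 + b1) W2 + b2, applied to each row independently."""
    hidden = [[max(0.0, val) for val in row] for row in linear(x, w1, b1)]  # ReLU
    return linear(hidden, w2, b2)


# Toy sizes (hypothetical): d_model = 2, d_ff = 3
w1 = [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]]
b1 = [0.0, 0.0, 0.0]
w2 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
b2 = [0.5, 0.5]
print(feed_forward([[1.0, -1.0], [2.0, 3.0]], w1, b1, w2, b2))  # [[1.5, 0.5], [7.5, 8.5]]
```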

3. Decoding part

3-1. Target batch embedding + positional encoding (tgt_embed)
tensor([[[-0.3361, -0.5411,  0.5297,  2.3342,  0.8973,  0.1402],
         [-0.4919, -0.6049, -1.6053,  0.2676,  0.8980,  0.2953],
         [ 1.1844, -0.0000, -0.1628,  2.7691,  1.4741,  0.4418],
         [-1.2462, -1.9267,  0.0446,  2.3728,  0.4344,  1.2240],
         [-2.0308,  0.6776, -0.4236,  1.1829,  0.9161, -0.4064]],

        [[-0.3361, -0.5411,  0.5297,  2.3342,  0.8973,  0.1402],
         [ 1.0250,  1.5371,  0.4503, -0.3613, -0.7242,  1.5025],
         [ 1.0694,  0.3562, -1.0864,  0.0000, -0.1417,  2.7160],
         [ 0.1968, -0.1253, -0.0000,  0.5154, -0.2522,  2.0241],
         [-0.1829,  0.0000,  0.0000,  2.0638, -1.6618,  0.0000]]])

3-2. Output from encoder (memory)
tensor([[[ 1.8808, -0.3297, -0.3010, -0.6214,  0.2684, -0.8972],
         [ 0.7203, -0.1711,  0.7326,  0.8862, -1.7084, -0.4596],
         [ 1.5323, -0.8540, -1.1683,  0.2520, -0.3636,  0.6015]],

        [[ 1.2566,  0.3957, -1.6387, -0.2336, -0.3900,  0.6101],
         [ 0.9580,  1.0129, -1.4811, -0.2915,  0.5165, -0.7148],
         [ 1.8221, -1.1817,  0.1599, -0.3333, -0.0959, -0.3710]]])

3-3. MultiHeadedAttention(tgt_embed (query), tgt_embed (key), tgt_embed (value), tgt_mask)
Q:
tensor([[[[-0.5421,  1.5738],
          [-0.7382,  0.7951],
          [-1.6815,  1.0081],
          [ 0.2136,  1.1754],
          [-0.0839,  1.0675]],

         [[ 1.2160, -0.9961],
          [ 0.8194, -1.0782],
          [ 1.8926, -1.9814],
          [ 0.5622, -0.5471],
          [ 0.0154, -0.1365]],

         [[-1.8108, -1.0346],
          [-0.7414, -0.6354],
          [-0.7128, -0.2254],
          [-1.9131, -1.1539],
          [-1.6068, -1.5226]]],


        [[[-0.5421,  1.5738],
          [ 0.6245, -2.5758],
          [ 0.2095, -1.5605],
          [ 0.9281, -1.0101],
          [-0.6299, -0.6413]],

         [[ 1.2160, -0.9961],
          [-1.6406,  0.3268],
          [-0.7254, -0.4050],
          [-0.9473,  0.0529],
          [ 0.2568, -1.0327]],

         [[-1.8108, -1.0346],
          [ 1.1267,  0.6528],
          [ 0.4754,  0.2389],
          [-0.5415, -0.4035],
          [-0.5191, -0.2838]]]])

K^T:
tensor([[[[-1.6267,  0.1434, -0.5614, -1.9235, -0.9945],
          [-1.5008, -0.8060, -1.2070, -2.0382, -0.2845]],

         [[ 1.0226, -0.1413,  1.0822,  0.0854,  1.3647],
          [-0.6622,  0.7058,  0.1667, -0.4866, -1.1672]],

         [[-0.6153, -0.2148, -0.0583, -0.9386, -0.7262],
          [-0.6187, -0.2110, -0.5467, -0.2972,  0.1690]]],


        [[[-1.6267,  0.1631, -0.0309, -1.3011, -1.9562],
          [-1.5008,  0.1211, -0.9794, -1.7500, -1.4243]],

         [[ 1.0226, -0.4784, -1.0563, -1.0146,  1.6665],
          [-0.6622, -0.8010,  0.2722, -0.3651, -1.6178]],

         [[-0.6153, -0.3968, -0.4770, -1.0090, -0.7809],
          [-0.6187,  1.4376,  0.9992,  0.9124,  0.8331]]]])

Q * K^T / sqrt(d_k):
tensor([[[[-1.0467, -0.9519, -1.1280, -1.5309,  0.0646],
          [ 0.0053, -0.5280, -0.3855, -0.1418,  0.3592],
          [ 0.8643, -0.7450, -0.1928,  0.8342,  0.9797],
          [-1.4931, -0.6482, -1.0880, -1.9846, -0.3867],
          [-1.0364, -0.6169, -0.8777, -1.4244, -0.1558]],

         [[ 1.3457, -0.6186,  0.8131,  0.4162,  1.9956],
          [ 1.0974, -0.6200,  0.4999,  0.4205,  1.6806],
          [ 2.2964, -1.1779,  1.2147,  0.7960,  3.4617],
          [ 0.6627, -0.3292,  0.3657,  0.2222,  0.9941],
          [ 0.0750, -0.0696, -0.0043,  0.0479,  0.1275]],

         [[ 1.2405,  0.4294,  0.4746,  1.4192,  0.8062],
          [ 0.6005,  0.2074,  0.2762,  0.6255,  0.3048],
          [ 0.4088,  0.1419,  0.1165,  0.5204,  0.3391],
          [ 1.3372,  0.4627,  0.5249,  1.5122,  0.8445],
          [ 1.3653,  0.4712,  0.6548,  1.3864,  0.6432]]],


        [[[-1.0467,  0.0722, -1.0781, -1.4488, -0.8352],
          [ 2.0153, -0.1485,  1.7703,  2.6130,  1.7303],
          [ 1.4151, -0.1094,  1.0762,  1.7383,  1.2818],
          [ 0.0045,  0.0206,  0.6793,  0.3962, -0.2664],
          [ 1.4051, -0.1275,  0.4579,  1.3731,  1.5171]],

         [[ 1.3457,  0.1528, -1.1000, -0.6152,  2.5725],
          [-1.3393,  0.3699,  1.2883,  1.0926, -2.3071],
          [-0.3349,  0.4748,  0.4638,  0.6250, -0.3915],
          [-0.7097,  0.2905,  0.7177,  0.6660, -1.1767],
          [ 0.6693,  0.4980, -0.3906,  0.0823,  1.4841]],

         [[ 1.2405, -0.5436, -0.1203,  0.6244,  0.3904],
          [-0.7758,  0.3474,  0.0813, -0.3827, -0.2376],
          [-0.3114,  0.1094,  0.0084, -0.1851, -0.1218],
          [ 0.4122, -0.2582, -0.1025,  0.1260,  0.0613],
          [ 0.3500, -0.1428, -0.0254,  0.1873,  0.1195]]]])

mask:
tensor([[[[ 1,  0,  0,  0,  0],
          [ 1,  1,  0,  0,  0],
          [ 1,  1,  1,  0,  0],
          [ 1,  1,  1,  1,  0],
          [ 1,  1,  1,  1,  1]]],


        [[[ 1,  0,  0,  0,  0],
          [ 1,  1,  0,  0,  0],
          [ 1,  1,  1,  0,  0],
          [ 1,  1,  1,  1,  0],
          [ 1,  1,  1,  1,  0]]]], dtype=torch.uint8)

softmax(Q * K^T / sqrt(d_k)) with masking:
tensor([[[[ 1.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.6303,  0.3697,  0.0000,  0.0000,  0.0000],
          [ 0.6462,  0.1293,  0.2245,  0.0000,  0.0000],
          [ 0.1839,  0.4280,  0.2757,  0.1125,  0.0000],
          [ 0.1474,  0.2242,  0.1727,  0.1000,  0.3556]],

         [[ 1.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.8478,  0.1522,  0.0000,  0.0000,  0.0000],
          [ 0.7299,  0.0226,  0.2475,  0.0000,  0.0000],
          [ 0.3626,  0.1345,  0.2695,  0.2334,  0.0000],
          [ 0.2076,  0.1797,  0.1918,  0.2021,  0.2188]],

         [[ 1.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.5970,  0.4030,  0.0000,  0.0000,  0.0000],
          [ 0.3980,  0.3048,  0.2972,  0.0000,  0.0000],
          [ 0.3276,  0.1366,  0.1454,  0.3903,  0.0000],
          [ 0.2935,  0.1200,  0.1442,  0.2997,  0.1426]]],


        [[[ 1.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.8970,  0.1030,  0.0000,  0.0000,  0.0000],
          [ 0.5181,  0.1128,  0.3691,  0.0000,  0.0000],
          [ 0.1832,  0.1861,  0.3597,  0.2710,  0.0000],
          [ 0.3888,  0.0840,  0.1508,  0.3765,  0.0000]],

         [[ 1.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.1533,  0.8467,  0.0000,  0.0000,  0.0000],
          [ 0.1828,  0.4108,  0.4064,  0.0000,  0.0000],
          [ 0.0844,  0.2296,  0.3519,  0.3341,  0.0000],
          [ 0.3643,  0.3069,  0.1262,  0.2025,  0.0000]],

         [[ 1.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.2454,  0.7546,  0.0000,  0.0000,  0.0000],
          [ 0.2564,  0.3905,  0.3530,  0.0000,  0.0000],
          [ 0.3496,  0.1788,  0.2090,  0.2626,  0.0000],
          [ 0.3177,  0.1941,  0.2182,  0.2700,  0.0000]]]])

V:
tensor([[[[-0.1569, -0.1303],
          [ 0.4915,  1.1576],
          [ 0.2951, -0.1244],
          [ 0.2081,  0.4328],
          [ 0.2947,  0.7468]],

         [[-0.6981, -0.6486],
          [-0.3091, -0.6712],
          [-0.3662, -0.4734],
          [-0.3013, -0.5132],
          [-1.3433, -1.7526]],

         [[-0.7434,  0.9040],
          [ 0.3286, -0.5518],
          [-1.4179,  0.4413],
          [-0.1633,  0.4952],
          [ 0.2411, -0.3765]]],


        [[[-0.1569, -0.1303],
          [ 1.1811,  0.0495],
          [ 1.4490,  0.9874],
          [ 1.2789,  0.9373],
          [ 1.2890, -0.4136]],

         [[-0.6981, -0.6486],
          [ 0.6316, -0.6768],
          [ 0.7180, -0.4615],
          [ 0.5463, -0.5267],
          [-0.3281, -1.2527]],

         [[-0.7434,  0.9040],
          [-0.3534, -1.5435],
          [-0.0166, -1.5146],
          [ 0.1169, -1.0003],
          [-2.0980, -0.2976]]]])

softmax(Q * K^T / sqrt(d_k)) * V:
tensor([[[[-0.1744, -0.1448],
          [ 0.0920,  0.3843],
          [ 0.0315,  0.0417],
          [ 0.3181,  0.5398],
          [ 0.1705,  0.2979]],

         [[-0.7757, -0.7207],
          [-0.7099, -0.7245],
          [-0.6747, -0.6731],
          [-0.5153, -0.6365],
          [-0.6951, -0.9259]],

         [[-0.8260,  1.0045],
          [ 0.1471, -0.2471],
          [-0.3569, -0.0412],
          [-0.2916,  0.4601],
          [-0.4420,  0.3972]]],


        [[[-0.1744, -0.1448],
          [-0.0212, -0.1242],
          [ 0.6520,  0.3362],
          [ 1.1765,  0.6606],
          [ 0.7100,  0.5013]],

         [[-0.7757, -0.7207],
          [ 0.4754, -0.7472],
          [ 0.4707, -0.6490],
          [ 0.5791, -0.6094],
          [ 0.1565, -0.6766]],

         [[-0.8260,  1.0045],
          [-0.4990, -1.0476],
          [-0.3652, -0.4122],
          [-0.3287, -0.5990],
          [-0.2314, -0.3482]]]])

Merge the heads, apply the output linear map, and add the result to tgt_embed (the result is called tgt_embed_attn)
tensor([[[ 0.2576, -1.3072,  1.8100,  0.7006,  0.6683,  0.4353],
         [-0.9247, -1.0474, -0.8220, -0.5448,  0.5661, -0.0052],
         [ 1.0528, -0.7254,  0.9485,  1.6799,  1.5480,  0.1961],
         [-1.2074, -2.0542,  0.9351,  1.5558,  0.0306,  0.8459],
         [-2.0669,  0.1645,  0.8701, -0.0027,  0.7121, -0.7640]],

        [[ 0.2576, -1.3072,  0.5297,  0.7006,  0.8973,  0.4353],
         [ 0.3881,  0.7784,  1.2159, -0.5447, -0.0654,  0.0339],
         [ 0.3439,  0.2162, -1.0864, -0.0267,  0.2738,  1.2346],
         [-0.8429,  0.0335,  0.0000,  0.8225,  0.4094,  2.0241],
         [-0.9050, -0.0864,  0.9056,  1.9073, -1.6618, -1.3210]]])

3-4. MultiHeadedAttention(tgt_embed_attn (query), memory (key), memory (value), src_mask)
Q:
tensor([[[[-0.0758,  0.9863],
          [-0.1229, -0.9282],
          [ 1.4340,  0.6957],
          [-0.2881, -0.4703],
          [-0.6102, -0.8839]],

         [[-1.6173, -0.7890],
          [-1.4365, -2.1836],
          [-1.5107, -1.5196],
          [-2.2072, -0.4909],
          [-1.0444, -1.1330]],

         [[ 0.0027, -0.5079],
          [ 0.8844, -0.9470],
          [ 0.3176, -1.0832],
          [ 0.2971, -0.1874],
          [ 0.1125,  0.5323]]],


        [[[ 0.5328,  0.4803],
          [-1.1035,  0.9301],
          [-0.4100, -1.4574],
          [-1.3444, -1.9944],
          [ 0.3436, -0.3239]],

         [[-1.8617, -1.4616],
          [ 1.2383,  0.4401],
          [ 0.2157, -0.4388],
          [-1.4478, -0.2748],
          [-0.6261,  0.7962]],

         [[ 0.4649, -1.0808],
          [-1.0231,  0.7122],
          [ 0.2729, -0.4496],
          [ 0.4535,  0.1840],
          [-0.6205,  0.8188]]]])

K^T:
tensor([[[[ 0.2184,  0.2875,  0.1470],
          [-0.7032,  0.7268, -0.8534]],

         [[ 0.9941,  1.4162,  0.3437],
          [-0.1453, -0.6175, -0.4588]],

         [[-1.8275,  0.8175, -0.4833],
          [ 0.9539, -0.6637,  0.6972]]],


        [[[ 1.2250,  1.3999, -0.5532],
          [-1.0456, -1.6467, -0.1479]],

         [[ 0.8125,  0.6212,  0.7205],
          [-0.1804,  0.5947, -0.6608]],

         [[-0.4169, -1.6140, -1.1815],
          [ 0.3413,  0.3242,  0.9922]]]])

Q * K^T / sqrt(d_k):
tensor([[[[-0.5021,  0.4915, -0.6030],
          [ 0.4426, -0.5020,  0.5473],
          [-0.1245,  0.6490, -0.2707],
          [ 0.1894, -0.3003,  0.2538],
          [ 0.3453, -0.5783,  0.4699]],

         [[-1.0558, -1.2751, -0.1370],
          [-0.7855, -0.4852,  0.3593],
          [-0.9058, -0.8494,  0.1259],
          [-1.5010, -1.9960, -0.3771],
          [-0.6178, -0.5512,  0.1138]],

         [[-0.3460,  0.2399, -0.2513],
          [-1.7816,  0.9557, -0.7691],
          [-1.1411,  0.6920, -0.6426],
          [-0.5104,  0.2597, -0.1939],
          [ 0.2137, -0.1848,  0.2240]]],


        [[[ 0.1064, -0.0319, -0.2586],
          [-1.6434, -2.1752,  0.3344],
          [ 0.7224,  1.2912,  0.3127],
          [ 0.3100,  0.9915,  0.7344],
          [ 0.5371,  0.7172, -0.1005]],

         [[-0.8832, -1.4323, -0.2656],
          [ 0.6553,  0.7290,  0.4253],
          [ 0.1799, -0.0898,  0.3149],
          [-0.7968, -0.7515, -0.6093],
          [-0.4613,  0.0598, -0.6911]],

         [[-0.3978, -0.7784, -1.1467],
          [ 0.4734,  1.3309,  1.3543],
          [-0.1890, -0.4146, -0.5435],
          [-0.0893, -0.4754, -0.2498],
          [ 0.3805,  0.8959,  1.0928]]]])

mask:
tensor([[[[ 1,  1,  1]]],


        [[[ 1,  1,  0]]]], dtype=torch.uint8)

softmax(Q * K^T / sqrt(d_k)) with masking:
tensor([[[[ 0.2172,  0.5865,  0.1963],
          [ 0.4001,  0.1556,  0.4443],
          [ 0.2481,  0.5376,  0.2143],
          [ 0.3732,  0.2287,  0.3981],
          [ 0.3953,  0.1570,  0.4478]],

         [[ 0.2321,  0.1864,  0.5816],
          [ 0.1821,  0.2459,  0.5721],
          [ 0.2056,  0.2175,  0.5769],
          [ 0.2134,  0.1301,  0.6565],
          [ 0.2411,  0.2577,  0.5011]],

         [[ 0.2567,  0.4612,  0.2822],
          [ 0.0521,  0.8045,  0.1434],
          [ 0.1124,  0.7026,  0.1850],
          [ 0.2206,  0.4766,  0.3028],
          [ 0.3729,  0.2503,  0.3768]]],


        [[[ 0.5345,  0.4655,  0.0000],
          [ 0.6299,  0.3701,  0.0000],
          [ 0.3615,  0.6385,  0.0000],
          [ 0.3359,  0.6641,  0.0000],
          [ 0.4551,  0.5449,  0.0000]],

         [[ 0.6339,  0.3661,  0.0000],
          [ 0.4816,  0.5184,  0.0000],
          [ 0.5670,  0.4330,  0.0000],
          [ 0.4887,  0.5113,  0.0000],
          [ 0.3726,  0.6274,  0.0000]],

         [[ 0.5940,  0.4060,  0.0000],
          [ 0.2979,  0.7021,  0.0000],
          [ 0.5562,  0.4438,  0.0000],
          [ 0.5954,  0.4046,  0.0000],
          [ 0.3739,  0.6261,  0.0000]]]])

V:
tensor([[[[ 0.7363,  0.0968],
          [ 0.7019, -0.1864],
          [ 0.3013, -0.5080]],

         [[ 0.7667,  0.1734],
          [-0.2915,  1.3010],
          [ 0.4510,  1.5238]],

         [[-0.5390, -0.9138],
          [ 0.7364, -1.1734],
          [-0.6351, -1.3181]]],


        [[[ 0.9191,  0.4546],
          [ 0.7295,  1.4285],
          [ 0.3693, -0.8896]],

         [[ 0.9879,  0.4118],
          [ 1.1443, -0.6735],
          [ 0.2565,  0.9689]],

         [[ 0.0065, -1.2190],
          [-0.5786, -0.5889],
          [-0.5328, -0.9883]]]])

softmax(Q * K^T / sqrt(d_k)) * V:
tensor([[[[ 0.2434, -0.0875],
          [ 0.5974, -0.2400],
          [ 0.2747, -0.0943],
          [ 0.6170, -0.2319],
          [ 0.5957, -0.2427]],

         [[ 0.4288,  1.2988],
          [ 0.3622,  1.3591],
          [ 0.3938,  1.3308],
          [ 0.4687,  1.3408],
          [ 0.3731,  1.2676]],

         [[ 0.0244, -1.2751],
          [ 0.5259, -1.3118],
          [ 0.4443, -1.1870],
          [ 0.1762, -1.0648],
          [-0.2844, -1.2568]]],


        [[[ 0.9231,  1.0088],
          [ 0.9432,  0.9056],
          [ 0.8867,  1.1960],
          [ 0.8813,  1.2237],
          [ 0.9064,  1.0948]],

         [[ 1.1613,  0.0161],
          [ 0.5286,  0.2203],
          [ 1.1729, -0.0646],
          [ 1.1865, -0.1591],
          [ 0.7977, -0.4695]],

         [[-0.2567, -1.0702],
          [-0.4492, -0.8629],
          [ 0.0040, -0.7533],
          [-0.2558, -1.0712],
          [-0.3998, -0.9161]]]])
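In encoder-decoder attention the queries come from the decoder (5 target positions) while keys and values come from the encoder memory (3 source positions), so each head produces a 5 × 3 weight matrix and a 5 × 2 output, and the single source-mask row is shared by all five queries. The sketch below chains scaling, masking, softmax, and the weighted sum for one head of one sentence, recomputing the second sentence's first head from the printed Q, K^T, and V. The `keep_scale = 1/0.9` factor assumes the same dropout rescaling that the encoder numbers suggest; `attention` is a hypothetical helper, not the repository's function.

```python
import math


def attention(q, k_t, v, mask_row, keep_scale=1.0):
    """Scaled dot-product attention for one head of one sentence.

    q: (tgt_len, d_k) queries, k_t: (d_k, src_len) transposed keys,
    v: (src_len, d_k) values, mask_row: (src_len,) shared by every query.
    keep_scale: 1.0 for no dropout; 1 / (1 - p) mimics the rescaling of
    surviving weights under inverted dropout (assumes nothing was dropped).
    """
    d_k = len(k_t)
    out, weights = [], []
    for q_row in q:
        scores = [sum(q_row[m] * k_t[m][j] for m in range(d_k)) / math.sqrt(d_k)
                  for j in range(len(k_t[0]))]
        scores = [s if keep else -1e9 for s, keep in zip(scores, mask_row)]  # hide padding
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        p = [e / total for e in exps]
        weights.append(p)
        out.append([keep_scale * sum(p[j] * v[j][c] for j in range(len(v)))
                    for c in range(len(v[0]))])
    return out, weights


# Second sentence, first head of step 3-4 (values copied from the printout)
q = [[0.5328, 0.4803], [-1.1035, 0.9301], [-0.4100, -1.4574],
     [-1.3444, -1.9944], [0.3436, -0.3239]]
k_t = [[1.2250, 1.3999, -0.5532], [-1.0456, -1.6467, -0.1479]]
v = [[0.9191, 0.4546], [0.7295, 1.4285], [0.3693, -0.8896]]
out, weights = attention(q, k_t, v, mask_row=[1, 1, 0], keep_scale=1 / 0.9)
print(len(out), len(out[0]))  # 5 2
```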

Merge the heads, apply the output linear map, and add the result to tgt_embed_attn
tensor([[[-0.2161, -2.4811,  1.1464,  1.3303,  0.2541,  0.4871],
         [-1.3193, -2.5867, -1.2033,  0.2351, -0.2230,  0.0062],
         [ 0.7021, -2.1894,  0.5505,  2.3140,  0.8843,  0.1961],
         [-1.8023, -3.4669,  0.6040,  2.2281, -0.7768,  0.9494],
         [-2.7583, -0.8034,  0.1077,  0.7447,  0.3146, -0.4146]],

        [[ 0.2576, -1.8936, -0.4380,  0.7006, -0.3638,  0.8867],
         [-0.2066,  0.5637, -0.0672, -1.1310, -0.9635,  0.5576],
         [ 0.0467, -0.5195, -1.8140, -1.3613,  0.2738,  1.4879],
         [-1.1212, -0.4227, -1.0964, -0.5154, -0.8633,  2.4635],
         [-1.1198, -0.0412, -0.2934,  0.6077, -2.5765, -0.7313]]])

3-5. X = X + FF(X), then layer normalization (every row below has mean 0 and standard deviation 1)
tensor([[[-0.1272, -1.8795,  0.3698,  1.0207,  0.5115,  0.1045],
         [-0.2001, -1.4223, -0.7834,  1.2942,  0.4322,  0.6795],
         [ 0.3626, -1.7664, -0.3882,  1.1899,  0.3173,  0.2848],
         [-0.7579, -1.5202,  0.3895,  1.3258,  0.1286,  0.4343],
         [-1.6360,  0.0659,  0.6310,  0.9940,  0.6550, -0.7098]],

        [[ 0.3640, -1.6767, -0.5457,  0.5169,  0.1269,  1.2146],
         [-0.3392,  1.3893,  0.6082, -0.7329, -1.3625,  0.4371],
         [ 0.2417, -0.2835, -1.2188, -0.9513,  1.2164,  0.9954],
         [-0.9828, -0.2044, -0.4296, -0.0254, -0.2876,  1.9299],
         [-0.7952, -0.1242, -0.0943,  1.5632, -1.2098,  0.6603]]])

out:
tensor([[[-0.1272, -1.8795,  0.3698,  1.0207,  0.5115,  0.1045],
         [-0.2001, -1.4223, -0.7834,  1.2942,  0.4322,  0.6795],
         [ 0.3626, -1.7664, -0.3882,  1.1899,  0.3173,  0.2848],
         [-0.7579, -1.5202,  0.3895,  1.3258,  0.1286,  0.4343],
         [-1.6360,  0.0659,  0.6310,  0.9940,  0.6550, -0.7098]],

        [[ 0.3640, -1.6767, -0.5457,  0.5169,  0.1269,  1.2146],
         [-0.3392,  1.3893,  0.6082, -0.7329, -1.3625,  0.4371],
         [ 0.2417, -0.2835, -1.2188, -0.9513,  1.2164,  0.9954],
         [-0.9828, -0.2044, -0.4296, -0.0254, -0.2876,  1.9299],
         [-0.7952, -0.1242, -0.0943,  1.5632, -1.2098,  0.6603]]])
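Every row of `out` (like the encoder output in 2-3) has mean about 0 and standard deviation about 1, which is what a freshly initialized layer normalization produces. The sketch assumes The Annotated Transformer's `LayerNorm` formula with gain 1 and bias 0, `(x - mean) / (std + eps)` using the unbiased standard deviation, and checks it against the first row of `out`. `mean_and_std` and `layer_norm` are hypothetical helpers.

```python
import math


def mean_and_std(row):
    """Mean and unbiased (n - 1) standard deviation of a list of numbers."""
    n = len(row)
    mean = sum(row) / n
    return mean, math.sqrt(sum((val - mean) ** 2 for val in row) / (n - 1))


def layer_norm(row, eps=1e-6):
    """LayerNorm with gain 1 and bias 0: (x - mean) / (std + eps)."""
    mean, std = mean_and_std(row)
    return [(val - mean) / (std + eps) for val in row]


first_out_row = [-0.1272, -1.8795, 0.3698, 1.0207, 0.5115, 0.1045]
print(mean_and_std(first_out_row))  # mean close to 0, std close to 1
# Normalizing an already-normalized row leaves it almost unchanged
print([round(val, 4) for val in layer_norm(first_out_row)])
```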

4. Prediction

4-1. Calculate probabilities over the target vocabulary (linear map + softmax)

target vocabulary: ['a', 'am', 'month', 'i', 'student', 'what', '<blank>', '<s>', '</s>']
tensor([[[ 0.1593,  0.2304,  0.2158,  0.0366,  0.0287,  0.1716,  0.0475,
           0.0771,  0.0329],
         [ 0.1775,  0.1055,  0.2827,  0.0478,  0.0193,  0.2719,  0.0189,
           0.0265,  0.0499],
         [ 0.2190,  0.1487,  0.2939,  0.0531,  0.0175,  0.1848,  0.0242,
           0.0293,  0.0296],
         [ 0.1144,  0.1960,  0.1891,  0.0353,  0.0387,  0.2022,  0.0599,
           0.0902,  0.0741],
         [ 0.0652,  0.2027,  0.0562,  0.0103,  0.0422,  0.0570,  0.1214,
           0.3786,  0.0664]],

        [[ 0.1261,  0.0600,  0.2794,  0.1511,  0.0339,  0.2632,  0.0161,
           0.0157,  0.0545],
         [ 0.0193,  0.0192,  0.0304,  0.3145,  0.1829,  0.0173,  0.1027,
           0.0722,  0.2414],
         [ 0.1132,  0.0315,  0.1908,  0.2659,  0.0712,  0.2250,  0.0138,
           0.0266,  0.0620],
         [ 0.0306,  0.0193,  0.1023,  0.2058,  0.1102,  0.2093,  0.0229,
           0.0248,  0.2748],
         [ 0.0900,  0.0885,  0.1555,  0.0921,  0.0535,  0.1170,  0.0773,
           0.0569,  0.2692]]])
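Each row above is a probability distribution over the nine target tokens; the first row, for instance, sums to 0.9999, i.e. 1 up to rounding. A minimal sketch of this step, assuming a plain linear projection followed by softmax: the projection weights are not printed, so `w` and `b` below are toy values and `generator` is a hypothetical helper. The Annotated Transformer's `Generator` returns `log_softmax` instead, which differs only by a logarithm.

```python
import math


def generator(x, w, b):
    """Map each d_model row to vocabulary probabilities: softmax(x W + b)."""
    probs = []
    for row in x:
        logits = [sum(row[i] * w[i][o] for i in range(len(row))) + b[o]
                  for o in range(len(b))]
        top = max(logits)  # subtract the max for numerical stability
        exps = [math.exp(z - top) for z in logits]
        total = sum(exps)
        probs.append([e / total for e in exps])
    return probs


# Toy sizes (hypothetical): d_model = 2, vocabulary of 3 tokens
toy = generator([[1.0, -1.0]], w=[[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]], b=[0.0, 0.0, 0.0])
print(toy)  # softmax of the logits [1.0, 0.0, -1.0]
```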

4-2. Predict the most probable token at each position
argmax:
tensor([[ 1,  2,  2,  5,  7],
        [ 2,  3,  3,  8,  8]])

predicted sentences:
[['am', 'month', 'month', 'what', '<s>'], ['month', 'i', 'i', '</s>', '</s>']]
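Prediction takes the index of the largest probability at each position and maps it back through the target vocabulary. The sketch below uses hypothetical helpers `argmax` and `decode` and reproduces the predicted sentences from the printed argmax indices.

```python
TGT_VOCAB = ['a', 'am', 'month', 'i', 'student', 'what', '<blank>', '<s>', '</s>']


def argmax(row):
    """Index of the largest value in a list."""
    return max(range(len(row)), key=lambda i: row[i])


def decode(ids, vocab):
    """Map lists of indices back to lists of tokens."""
    return [[vocab[i] for i in sentence] for sentence in ids]


# Fourth row of the first sentence in step 4-1: the maximum, 0.2022, is at index 5
print(argmax([0.1144, 0.1960, 0.1891, 0.0353, 0.0387, 0.2022, 0.0599, 0.0902, 0.0741]))  # 5
print(decode([[1, 2, 2, 5, 7], [2, 3, 3, 8, 8]], TGT_VOCAB))
# [['am', 'month', 'month', 'what', '<s>'], ['month', 'i', 'i', '</s>', '</s>']]
```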

4-3. Compare to the true output
target batch true output: [['i', 'am', 'a', 'student', '</s>'], ['what', 'month', '</s>', '<blank>', '<blank>']]
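Comparing position by position while skipping `<blank>` padding, none of the eight real target tokens is predicted correctly. The hypothetical `token_matches` helper below counts this; token accuracy is an illustrative metric, not one the repository reports.

```python
def token_matches(predicted, target, pad="<blank>"):
    """Return (correct, total) over target positions that are not padding."""
    correct = total = 0
    for pred_sent, true_sent in zip(predicted, target):
        for p, t in zip(pred_sent, true_sent):
            if t == pad:
                continue  # padding positions are not scored
            total += 1
            correct += int(p == t)
    return correct, total


predicted = [['am', 'month', 'month', 'what', '<s>'], ['month', 'i', 'i', '</s>', '</s>']]
target = [['i', 'am', 'a', 'student', '</s>'], ['what', 'month', '</s>', '<blank>', '<blank>']]
print(token_matches(predicted, target))  # (0, 8)
```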
