Navigating the inside of the Transformer architecture.
This code is closely based on the code from the blog post "The Annotated Transformer". For more details, read the paper "Attention Is All You Need" or the post "The Annotated Transformer".
0. Vocabulary
source vocabulary: ['etudiant', 'je', 'quel', 'suis', 'mois', '<blank>']
target vocabulary: ['a', 'am', 'month', 'i', 'student', 'what', '<blank>', '<s>', '</s>']
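However the vocabularies are built in the original code, a token's index here is simply its position in the list. The plain dictionaries below are a simplified stand-in, and the src_stoi/tgt_stoi lookups are reused in the later sketches:

# Simplified stand-in for the vocabulary objects; index = position in the list.
src_vocab = ['etudiant', 'je', 'quel', 'suis', 'mois', '<blank>']
tgt_vocab = ['a', 'am', 'month', 'i', 'student', 'what', '<blank>', '<s>', '</s>']

# token -> index (stoi) lookups used in the sketches below
src_stoi = {tok: i for i, tok in enumerate(src_vocab)}
tgt_stoi = {tok: i for i, tok in enumerate(tgt_vocab)}

print(src_stoi['je'], src_stoi['suis'], src_stoi['etudiant'])  # 1 3 0
print(tgt_stoi['<s>'], tgt_stoi['i'], tgt_stoi['am'])          # 7 3 1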
1. Preparing batch and mask
1-1. Batch (text)
source batch (text): [['je', 'suis', 'etudiant'], ['quel', 'mois', '<blank>']]
target batch (text): [['<s>', 'i', 'am', 'a', 'student'], ['<s>', 'what', 'month', '</s>', '<blank>']]
target batch true output: [['i', 'am', 'a', 'student', '</s>'], ['what', 'month', '</s>', '<blank>', '<blank>']]
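The decoder input and the true output are the same padded sequence shifted by one token: the decoder sees everything except the last token and is trained to predict everything except the first. A minimal sketch of that shift (The Annotated Transformer's Batch class does the same thing with tensor slicing):

# Full padded target sequences, with <s> prepended and </s> appended.
full_tgt = [['<s>', 'i', 'am', 'a', 'student', '</s>'],
            ['<s>', 'what', 'month', '</s>', '<blank>', '<blank>']]

tgt_input = [seq[:-1] for seq in full_tgt]  # fed to the decoder
tgt_true  = [seq[1:]  for seq in full_tgt]  # compared against the decoder's output

print(tgt_input)  # [['<s>', 'i', 'am', 'a', 'student'], ['<s>', 'what', 'month', '</s>', '<blank>']]
print(tgt_true)   # [['i', 'am', 'a', 'student', '</s>'], ['what', 'month', '</s>', '<blank>', '<blank>']]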
1-2. Batch (number)
source batch (number):
tensor([[ 1, 3, 0],
[ 2, 4, 5]])
target batch (number):
tensor([[ 7, 3, 1, 0, 4],
[ 7, 5, 2, 8, 6]])
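A minimal sketch of the text-to-index conversion, reusing the src_stoi/tgt_stoi lookups from section 0 (the sequences are already padded to equal length with '<blank>', so the result is rectangular):

import torch

src_text = [['je', 'suis', 'etudiant'], ['quel', 'mois', '<blank>']]
tgt_text = [['<s>', 'i', 'am', 'a', 'student'], ['<s>', 'what', 'month', '</s>', '<blank>']]

# Look every token up in its vocabulary.
src = torch.tensor([[src_stoi[t] for t in seq] for seq in src_text])
tgt = torch.tensor([[tgt_stoi[t] for t in seq] for seq in tgt_text])

print(src)  # tensor([[1, 3, 0], [2, 4, 5]])
print(tgt)  # tensor([[7, 3, 1, 0, 4], [7, 5, 2, 8, 6]])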
1-3. Mask
source mask:
tensor([[[ 1, 1, 1]],
[[ 1, 1, 0]]], dtype=torch.uint8)
target mask:
tensor([[[ 1, 0, 0, 0, 0],
[ 1, 1, 0, 0, 0],
[ 1, 1, 1, 0, 0],
[ 1, 1, 1, 1, 0],
[ 1, 1, 1, 1, 1]],
[[ 1, 0, 0, 0, 0],
[ 1, 1, 0, 0, 0],
[ 1, 1, 1, 0, 0],
[ 1, 1, 1, 1, 0],
[ 1, 1, 1, 1, 0]]], dtype=torch.uint8)
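The source mask hides the padded positions, and the target mask additionally hides "future" positions so that each target token can attend only to itself and the tokens before it. A minimal sketch of how both masks are built, following the subsequent_mask idea from The Annotated Transformer and reusing src/tgt from the sketch above:

import torch

def subsequent_mask(size):
    # Lower-triangular matrix: position i may attend only to positions <= i.
    return torch.tril(torch.ones(1, size, size, dtype=torch.uint8))

src_pad, tgt_pad = src_stoi['<blank>'], tgt_stoi['<blank>']  # 5 and 6 here

# Source mask: 1 wherever the token is not padding; shape (batch, 1, src_len).
src_mask = (src != src_pad).unsqueeze(-2).type(torch.uint8)

# Target mask: padding mask AND causal mask; shape (batch, tgt_len, tgt_len).
tgt_mask = (tgt != tgt_pad).unsqueeze(-2).type(torch.uint8) & subsequent_mask(tgt.size(-1))

print(src_mask)  # matches the source mask above
print(tgt_mask)  # matches the target mask above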
2. Encoding part
2-1. Source batch embedding + positional encoding (src_embed)
tensor([[[-0.0991, -0.3589, -0.5712, -0.2789, 1.8014, 0.9144],
[ 0.4986, -0.9652, -0.4070, 1.2863, -1.8285, -0.4280],
[ 1.9674, -1.2571, -1.0147, 1.0888, 1.1725, 2.6795]],
[[ 1.3305, 0.0000, -1.0701, -0.5923, -1.0296, 1.6793],
[ 0.8715, 1.3400, -0.0303, 0.0000, 1.6201, 1.5863],
[ 1.4312, -1.9720, 1.0480, 0.1499, 0.5277, 0.4981]]])
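src_embed is the composition of a token embedding (scaled by sqrt(d_model)) and a sinusoidal positional encoding; with d_model = 6 here, each token becomes a 6-dimensional vector. A minimal sketch following the Embeddings and PositionalEncoding modules of The Annotated Transformer (the weights are freshly initialized, so the numbers will differ from the printout; the exact 0.0000 entries above most likely come from the dropout applied inside the positional encoding):

import math
import torch
import torch.nn as nn

d_model, max_len = 6, 50

# Token embedding, scaled by sqrt(d_model) as in the paper.
lut = nn.Embedding(len(src_vocab), d_model)

# Sinusoidal positional encoding table:
# PE(pos, 2i) = sin(pos / 10000^(2i/d_model)), PE(pos, 2i+1) = cos(...).
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len).unsqueeze(1).float()
div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)

dropout = nn.Dropout(0.1)

# src_embed(src): scaled embedding + positional encoding, then dropout.
src_embedded = dropout(lut(src) * math.sqrt(d_model) + pe[:src.size(1)].unsqueeze(0))
print(src_embedded.shape)  # torch.Size([2, 3, 6])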
2-2. MultiHeadedAttention(src_embed (query), src_embed (key), src_embed (value), src_mask)
Q:
tensor([[[[ 0.2873, 0.2007],
[-0.9330, 0.9019],
[-0.0082, 1.6451]],
[[-0.5665, -0.7560],
[ 1.1108, 0.1960],
[ 0.9804, -0.2669]],
[[ 0.3628, -0.4734],
[-0.1485, 0.5912],
[ 0.4117, 0.2190]]],
[[[ 0.4587, 1.4866],
[ 0.5982, 0.2759],
[ 0.4676, 0.0464]],
[[ 1.8160, 0.3228],
[ 0.3496, 0.1092],
[ 0.0640, -1.3059]],
[[-0.0644, 0.6141],
[ 0.2253, -0.3131],
[-0.0457, -0.2041]]]])
K^T:
tensor([[[[-0.6992, 1.4689, 0.7740],
[-0.8642, 1.2513, 0.2392]],
[[-0.6000, -0.5576, -1.6619],
[ 0.1727, 0.3718, 0.8012]],
[[-0.1372, -0.2086, 0.4862],
[ 0.2631, -1.5935, -1.8292]]],
[[[ 0.7847, -0.6059, 0.3646],
[-0.0430, -0.8606, -0.0435]],
[[-1.3169, -0.4885, -1.3489],
[ 0.5333, 0.1433, -1.1985]],
[[ 1.4939, 0.8204, -0.8060],
[-2.0194, 0.1332, -1.2614]]]])
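Q and K (and V further below) are produced by three separate learned linear maps applied to src_embed and then split into h = 3 heads of size d_k = d_model / h = 2, which is why each of the two batch elements above contains three 3x2 matrices (2x3 for K^T). A minimal sketch of that projection-and-split step, in the spirit of MultiHeadedAttention from The Annotated Transformer (fresh weights, so the values differ):

import torch.nn as nn

h = 3
d_k = d_model // h  # 2

linear_q = nn.Linear(d_model, d_model)
linear_k = nn.Linear(d_model, d_model)
linear_v = nn.Linear(d_model, d_model)

def split_heads(x):
    # (batch, seq, d_model) -> (batch, h, seq, d_k)
    batch, seq, _ = x.shape
    return x.view(batch, seq, h, d_k).transpose(1, 2)

query = split_heads(linear_q(src_embedded))
key   = split_heads(linear_k(src_embedded))
value = split_heads(linear_v(src_embedded))
print(query.shape, key.transpose(-2, -1).shape)  # (2, 3, 3, 2) and (2, 3, 2, 3)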
Q * K^T:
tensor([[[[-0.2647, 0.4760, 0.1912],
[-0.0899, -0.1711, -0.3580],
[-1.0013, 1.4471, 0.2738]],
[[ 0.1480, 0.0246, 0.2375],
[-0.4473, -0.3864, -1.1942],
[-0.4486, -0.4567, -1.3033]],
[[-0.1233, 0.4799, 0.7370],
[ 0.1244, -0.6442, -0.8158],
[ 0.0008, -0.3075, -0.1417]]],
[[[ 0.2093, -1.1012, 0.0725],
[ 0.3235, -0.4242, 0.1457],
[ 0.2580, -0.2286, 0.1191]],
[[-1.5694, -0.5946, -2.0058],
[-0.2844, -0.1097, -0.4260],
[-0.5520, -0.1544, 1.0457]],
[[-0.9450, 0.0205, -0.5111],
[ 0.6851, 0.1012, 0.1509],
[ 0.2431, -0.0457, 0.2081]]]])
mask:
tensor([[[[ 1, 1, 1]]],
[[[ 1, 1, 0]]]], dtype=torch.uint8)
softmax(Q * K^T) with masking:
tensor([[[[ 0.2139, 0.4486, 0.3375],
[ 0.3722, 0.3432, 0.2846],
[ 0.0619, 0.7164, 0.2216]],
[[ 0.3359, 0.2969, 0.3673],
[ 0.3942, 0.4190, 0.1868],
[ 0.4137, 0.4103, 0.1760]],
[[ 0.1926, 0.3521, 0.4553],
[ 0.5393, 0.2500, 0.2106],
[ 0.3843, 0.2824, 0.3333]]],
[[[ 0.7876, 0.2124, 0.0000],
[ 0.6787, 0.3213, 0.0000],
[ 0.6193, 0.3807, 0.0000]],
[[ 0.2739, 0.7261, 0.0000],
[ 0.4564, 0.5436, 0.0000],
[ 0.4019, 0.5981, 0.0000]],
[[ 0.2758, 0.7242, 0.0000],
[ 0.6420, 0.3580, 0.0000],
[ 0.5717, 0.4283, 0.0000]]]])
V:
tensor([[[[ 0.0429, 1.2085],
[-0.6691, -0.5429],
[-0.3989, 1.0405]],
[[ 0.1204, 0.7988],
[-0.4093, -1.5120],
[ 0.2817, 0.3272]],
[[-1.6174, 0.0450],
[ 0.1738, 0.8663],
[-2.2803, 0.3509]]],
[[[-0.2660, -0.3585],
[ 0.1631, 0.3127],
[ 0.2958, 1.5232]],
[[ 0.4697, 0.7489],
[-0.1156, 1.5091],
[ 1.4921, -0.5300]],
[[-2.0205, -0.7742],
[-1.4917, -0.9358],
[-2.0415, 0.3973]]]])
softmax(Q * K^T) * V:
tensor([[[[-0.4729, 0.4067],
[-0.3635, 0.6218],
[-0.6309, -0.1760]],
[[ 0.0248, -0.0672],
[-0.0794, -0.2861],
[-0.0762, -0.2582]],
[[-1.4318, 0.5260],
[-0.9209, 0.2676],
[-1.4806, 0.4209]]],
[[[-0.1943, -0.2399],
[-0.1424, -0.1587],
[-0.1141, -0.1144]],
[[ 0.0497, 1.4454],
[ 0.1684, 1.2913],
[ 0.1329, 1.3373]],
[[-1.8195, -0.9902],
[-2.0346, -0.9245],
[-1.9933, -0.9371]]]])
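The four printouts above (Q * K^T, the mask, the masked softmax, and the product with V) correspond to the steps inside the attention function of The Annotated Transformer, i.e. scaled dot-product attention, which also divides the scores by sqrt(d_k) before the softmax. A minimal sketch, reusing query/key/value and src_mask from the earlier sketches:

import math
import torch
import torch.nn.functional as F

def attention(query, key, value, mask=None):
    # Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Masked positions get a huge negative score and hence ~0 probability.
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    return torch.matmul(p_attn, value), p_attn

# src_mask has shape (batch, 1, src_len); add a dimension so it broadcasts over the heads.
attended, p_attn = attention(query, key, value, mask=src_mask.unsqueeze(1))
print(attended.shape)  # torch.Size([2, 3, 3, 2])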
Merge the heads, apply a linear map, and add the result to src_embed (residual connection)
tensor([[[ 0.6415, -0.8070, -0.0989, -0.3344, 2.1601, 0.1320],
[ 0.4986, -1.6185, -0.0037, 0.9838, -2.2176, -0.7925],
[ 2.7234, -1.5721, -0.1605, 1.3591, 1.6872, 1.6879]],
[[ 2.7209, 0.0000, -1.0673, -0.0841, 0.0637, 1.5050],
[ 2.2868, 1.0852, 0.0689, 0.4620, 1.6201, 1.4560],
[ 2.8498, -2.2118, 1.0480, 0.6093, 1.6342, 0.3860]]])
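This "merge, linear map, add" step concatenates the h head outputs back into one d_model-dimensional vector per position, applies a final linear layer, and adds the result to the sublayer input as a residual connection. A minimal sketch, reusing attended and src_embedded from above (The Annotated Transformer's SublayerConnection also applies LayerNorm before the sublayer and dropout after it; both are omitted here for brevity):

import torch.nn as nn

# (batch, h, seq, d_k) -> (batch, seq, h * d_k) = (batch, seq, d_model)
merged = attended.transpose(1, 2).contiguous().view(attended.size(0), -1, h * d_k)

linear_out = nn.Linear(d_model, d_model)  # final output projection of multi-head attention

x = src_embedded + linear_out(merged)  # residual connection
print(x.shape)  # torch.Size([2, 3, 6])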
2-3. X = X + FF(X)
tensor([[[ 1.8808, -0.3297, -0.3010, -0.6214, 0.2684, -0.8972],
[ 0.7203, -0.1711, 0.7326, 0.8862, -1.7084, -0.4596],
[ 1.5323, -0.8540, -1.1683, 0.2520, -0.3636, 0.6015]],
[[ 1.2566, 0.3957, -1.6387, -0.2336, -0.3900, 0.6101],
[ 0.9580, 1.0129, -1.4811, -0.2915, 0.5165, -0.7148],
[ 1.8221, -1.1817, 0.1599, -0.3333, -0.0959, -0.3710]]])
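The second encoder sublayer is the position-wise feed-forward network, again with a residual connection. A minimal sketch in the spirit of PositionwiseFeedForward (two linear layers with a ReLU in between; the hidden width d_ff is not stated in this demo, so the value below is an arbitrary guess):

import torch.nn as nn

d_ff = 12  # hypothetical hidden width for this toy model

feed_forward = nn.Sequential(
    nn.Linear(d_model, d_ff),  # expand
    nn.ReLU(),
    nn.Linear(d_ff, d_model),  # project back
)

# X = X + FF(X), again omitting the LayerNorm/dropout of SublayerConnection.
memory = x + feed_forward(x)
print(memory.shape)  # torch.Size([2, 3, 6]) -- this becomes the encoder output ("memory")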
3. Decoding part
3-1. Target batch embedding + positional encoding (tgt_embed)
tensor([[[-0.3361, -0.5411, 0.5297, 2.3342, 0.8973, 0.1402],
[-0.4919, -0.6049, -1.6053, 0.2676, 0.8980, 0.2953],
[ 1.1844, -0.0000, -0.1628, 2.7691, 1.4741, 0.4418],
[-1.2462, -1.9267, 0.0446, 2.3728, 0.4344, 1.2240],
[-2.0308, 0.6776, -0.4236, 1.1829, 0.9161, -0.4064]],
[[-0.3361, -0.5411, 0.5297, 2.3342, 0.8973, 0.1402],
[ 1.0250, 1.5371, 0.4503, -0.3613, -0.7242, 1.5025],
[ 1.0694, 0.3562, -1.0864, 0.0000, -0.1417, 2.7160],
[ 0.1968, -0.1253, -0.0000, 0.5154, -0.2522, 2.0241],
[-0.1829, 0.0000, 0.0000, 2.0638, -1.6618, 0.0000]]])
3-2. Output from encoder (memory)
tensor([[[ 1.8808, -0.3297, -0.3010, -0.6214, 0.2684, -0.8972],
[ 0.7203, -0.1711, 0.7326, 0.8862, -1.7084, -0.4596],
[ 1.5323, -0.8540, -1.1683, 0.2520, -0.3636, 0.6015]],
[[ 1.2566, 0.3957, -1.6387, -0.2336, -0.3900, 0.6101],
[ 0.9580, 1.0129, -1.4811, -0.2915, 0.5165, -0.7148],
[ 1.8221, -1.1817, 0.1599, -0.3333, -0.0959, -0.3710]]])
3-3. MultiHeadedAttention(tgt_embed (query), tgt_embed (key), tgt_embed (value), tgt_mask)
Q:
tensor([[[[-0.5421, 1.5738],
[-0.7382, 0.7951],
[-1.6815, 1.0081],
[ 0.2136, 1.1754],
[-0.0839, 1.0675]],
[[ 1.2160, -0.9961],
[ 0.8194, -1.0782],
[ 1.8926, -1.9814],
[ 0.5622, -0.5471],
[ 0.0154, -0.1365]],
[[-1.8108, -1.0346],
[-0.7414, -0.6354],
[-0.7128, -0.2254],
[-1.9131, -1.1539],
[-1.6068, -1.5226]]],
[[[-0.5421, 1.5738],
[ 0.6245, -2.5758],
[ 0.2095, -1.5605],
[ 0.9281, -1.0101],
[-0.6299, -0.6413]],
[[ 1.2160, -0.9961],
[-1.6406, 0.3268],
[-0.7254, -0.4050],
[-0.9473, 0.0529],
[ 0.2568, -1.0327]],
[[-1.8108, -1.0346],
[ 1.1267, 0.6528],
[ 0.4754, 0.2389],
[-0.5415, -0.4035],
[-0.5191, -0.2838]]]])
K^T:
tensor([[[[-1.6267, 0.1434, -0.5614, -1.9235, -0.9945],
[-1.5008, -0.8060, -1.2070, -2.0382, -0.2845]],
[[ 1.0226, -0.1413, 1.0822, 0.0854, 1.3647],
[-0.6622, 0.7058, 0.1667, -0.4866, -1.1672]],
[[-0.6153, -0.2148, -0.0583, -0.9386, -0.7262],
[-0.6187, -0.2110, -0.5467, -0.2972, 0.1690]]],
[[[-1.6267, 0.1631, -0.0309, -1.3011, -1.9562],
[-1.5008, 0.1211, -0.9794, -1.7500, -1.4243]],
[[ 1.0226, -0.4784, -1.0563, -1.0146, 1.6665],
[-0.6622, -0.8010, 0.2722, -0.3651, -1.6178]],
[[-0.6153, -0.3968, -0.4770, -1.0090, -0.7809],
[-0.6187, 1.4376, 0.9992, 0.9124, 0.8331]]]])
Q * K^T:
tensor([[[[-1.0467, -0.9519, -1.1280, -1.5309, 0.0646],
[ 0.0053, -0.5280, -0.3855, -0.1418, 0.3592],
[ 0.8643, -0.7450, -0.1928, 0.8342, 0.9797],
[-1.4931, -0.6482, -1.0880, -1.9846, -0.3867],
[-1.0364, -0.6169, -0.8777, -1.4244, -0.1558]],
[[ 1.3457, -0.6186, 0.8131, 0.4162, 1.9956],
[ 1.0974, -0.6200, 0.4999, 0.4205, 1.6806],
[ 2.2964, -1.1779, 1.2147, 0.7960, 3.4617],
[ 0.6627, -0.3292, 0.3657, 0.2222, 0.9941],
[ 0.0750, -0.0696, -0.0043, 0.0479, 0.1275]],
[[ 1.2405, 0.4294, 0.4746, 1.4192, 0.8062],
[ 0.6005, 0.2074, 0.2762, 0.6255, 0.3048],
[ 0.4088, 0.1419, 0.1165, 0.5204, 0.3391],
[ 1.3372, 0.4627, 0.5249, 1.5122, 0.8445],
[ 1.3653, 0.4712, 0.6548, 1.3864, 0.6432]]],
[[[-1.0467, 0.0722, -1.0781, -1.4488, -0.8352],
[ 2.0153, -0.1485, 1.7703, 2.6130, 1.7303],
[ 1.4151, -0.1094, 1.0762, 1.7383, 1.2818],
[ 0.0045, 0.0206, 0.6793, 0.3962, -0.2664],
[ 1.4051, -0.1275, 0.4579, 1.3731, 1.5171]],
[[ 1.3457, 0.1528, -1.1000, -0.6152, 2.5725],
[-1.3393, 0.3699, 1.2883, 1.0926, -2.3071],
[-0.3349, 0.4748, 0.4638, 0.6250, -0.3915],
[-0.7097, 0.2905, 0.7177, 0.6660, -1.1767],
[ 0.6693, 0.4980, -0.3906, 0.0823, 1.4841]],
[[ 1.2405, -0.5436, -0.1203, 0.6244, 0.3904],
[-0.7758, 0.3474, 0.0813, -0.3827, -0.2376],
[-0.3114, 0.1094, 0.0084, -0.1851, -0.1218],
[ 0.4122, -0.2582, -0.1025, 0.1260, 0.0613],
[ 0.3500, -0.1428, -0.0254, 0.1873, 0.1195]]]])
mask:
tensor([[[[ 1, 0, 0, 0, 0],
[ 1, 1, 0, 0, 0],
[ 1, 1, 1, 0, 0],
[ 1, 1, 1, 1, 0],
[ 1, 1, 1, 1, 1]]],
[[[ 1, 0, 0, 0, 0],
[ 1, 1, 0, 0, 0],
[ 1, 1, 1, 0, 0],
[ 1, 1, 1, 1, 0],
[ 1, 1, 1, 1, 0]]]], dtype=torch.uint8)
softmax(Q * K^T) with masking:
tensor([[[[ 1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.6303, 0.3697, 0.0000, 0.0000, 0.0000],
[ 0.6462, 0.1293, 0.2245, 0.0000, 0.0000],
[ 0.1839, 0.4280, 0.2757, 0.1125, 0.0000],
[ 0.1474, 0.2242, 0.1727, 0.1000, 0.3556]],
[[ 1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.8478, 0.1522, 0.0000, 0.0000, 0.0000],
[ 0.7299, 0.0226, 0.2475, 0.0000, 0.0000],
[ 0.3626, 0.1345, 0.2695, 0.2334, 0.0000],
[ 0.2076, 0.1797, 0.1918, 0.2021, 0.2188]],
[[ 1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.5970, 0.4030, 0.0000, 0.0000, 0.0000],
[ 0.3980, 0.3048, 0.2972, 0.0000, 0.0000],
[ 0.3276, 0.1366, 0.1454, 0.3903, 0.0000],
[ 0.2935, 0.1200, 0.1442, 0.2997, 0.1426]]],
[[[ 1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.8970, 0.1030, 0.0000, 0.0000, 0.0000],
[ 0.5181, 0.1128, 0.3691, 0.0000, 0.0000],
[ 0.1832, 0.1861, 0.3597, 0.2710, 0.0000],
[ 0.3888, 0.0840, 0.1508, 0.3765, 0.0000]],
[[ 1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.1533, 0.8467, 0.0000, 0.0000, 0.0000],
[ 0.1828, 0.4108, 0.4064, 0.0000, 0.0000],
[ 0.0844, 0.2296, 0.3519, 0.3341, 0.0000],
[ 0.3643, 0.3069, 0.1262, 0.2025, 0.0000]],
[[ 1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.2454, 0.7546, 0.0000, 0.0000, 0.0000],
[ 0.2564, 0.3905, 0.3530, 0.0000, 0.0000],
[ 0.3496, 0.1788, 0.2090, 0.2626, 0.0000],
[ 0.3177, 0.1941, 0.2182, 0.2700, 0.0000]]]])
V:
tensor([[[[-0.1569, -0.1303],
[ 0.4915, 1.1576],
[ 0.2951, -0.1244],
[ 0.2081, 0.4328],
[ 0.2947, 0.7468]],
[[-0.6981, -0.6486],
[-0.3091, -0.6712],
[-0.3662, -0.4734],
[-0.3013, -0.5132],
[-1.3433, -1.7526]],
[[-0.7434, 0.9040],
[ 0.3286, -0.5518],
[-1.4179, 0.4413],
[-0.1633, 0.4952],
[ 0.2411, -0.3765]]],
[[[-0.1569, -0.1303],
[ 1.1811, 0.0495],
[ 1.4490, 0.9874],
[ 1.2789, 0.9373],
[ 1.2890, -0.4136]],
[[-0.6981, -0.6486],
[ 0.6316, -0.6768],
[ 0.7180, -0.4615],
[ 0.5463, -0.5267],
[-0.3281, -1.2527]],
[[-0.7434, 0.9040],
[-0.3534, -1.5435],
[-0.0166, -1.5146],
[ 0.1169, -1.0003],
[-2.0980, -0.2976]]]])
softmax(Q * K^T) * V:
tensor([[[[-0.1744, -0.1448],
[ 0.0920, 0.3843],
[ 0.0315, 0.0417],
[ 0.3181, 0.5398],
[ 0.1705, 0.2979]],
[[-0.7757, -0.7207],
[-0.7099, -0.7245],
[-0.6747, -0.6731],
[-0.5153, -0.6365],
[-0.6951, -0.9259]],
[[-0.8260, 1.0045],
[ 0.1471, -0.2471],
[-0.3569, -0.0412],
[-0.2916, 0.4601],
[-0.4420, 0.3972]]],
[[[-0.1744, -0.1448],
[-0.0212, -0.1242],
[ 0.6520, 0.3362],
[ 1.1765, 0.6606],
[ 0.7100, 0.5013]],
[[-0.7757, -0.7207],
[ 0.4754, -0.7472],
[ 0.4707, -0.6490],
[ 0.5791, -0.6094],
[ 0.1565, -0.6766]],
[[-0.8260, 1.0045],
[-0.4990, -1.0476],
[-0.3652, -0.4122],
[-0.3287, -0.5990],
[-0.2314, -0.3482]]]])
Merge the heads, apply a linear map, and add the result to tgt_embed (this gives tgt_embed_attn)
tensor([[[ 0.2576, -1.3072, 1.8100, 0.7006, 0.6683, 0.4353],
[-0.9247, -1.0474, -0.8220, -0.5448, 0.5661, -0.0052],
[ 1.0528, -0.7254, 0.9485, 1.6799, 1.5480, 0.1961],
[-1.2074, -2.0542, 0.9351, 1.5558, 0.0306, 0.8459],
[-2.0669, 0.1645, 0.8701, -0.0027, 0.7121, -0.7640]],
[[ 0.2576, -1.3072, 0.5297, 0.7006, 0.8973, 0.4353],
[ 0.3881, 0.7784, 1.2159, -0.5447, -0.0654, 0.0339],
[ 0.3439, 0.2162, -1.0864, -0.0267, 0.2738, 1.2346],
[-0.8429, 0.0335, 0.0000, 0.8225, 0.4094, 2.0241],
[-0.9050, -0.0864, 0.9056, 1.9073, -1.6618, -1.3210]]])
3-4. MultiHeadedAttention(tgt_embed_attn (query), memory (key), memory (value), src_mask)
Q:
tensor([[[[-0.0758, 0.9863],
[-0.1229, -0.9282],
[ 1.4340, 0.6957],
[-0.2881, -0.4703],
[-0.6102, -0.8839]],
[[-1.6173, -0.7890],
[-1.4365, -2.1836],
[-1.5107, -1.5196],
[-2.2072, -0.4909],
[-1.0444, -1.1330]],
[[ 0.0027, -0.5079],
[ 0.8844, -0.9470],
[ 0.3176, -1.0832],
[ 0.2971, -0.1874],
[ 0.1125, 0.5323]]],
[[[ 0.5328, 0.4803],
[-1.1035, 0.9301],
[-0.4100, -1.4574],
[-1.3444, -1.9944],
[ 0.3436, -0.3239]],
[[-1.8617, -1.4616],
[ 1.2383, 0.4401],
[ 0.2157, -0.4388],
[-1.4478, -0.2748],
[-0.6261, 0.7962]],
[[ 0.4649, -1.0808],
[-1.0231, 0.7122],
[ 0.2729, -0.4496],
[ 0.4535, 0.1840],
[-0.6205, 0.8188]]]])
K^T:
tensor([[[[ 0.2184, 0.2875, 0.1470],
[-0.7032, 0.7268, -0.8534]],
[[ 0.9941, 1.4162, 0.3437],
[-0.1453, -0.6175, -0.4588]],
[[-1.8275, 0.8175, -0.4833],
[ 0.9539, -0.6637, 0.6972]]],
[[[ 1.2250, 1.3999, -0.5532],
[-1.0456, -1.6467, -0.1479]],
[[ 0.8125, 0.6212, 0.7205],
[-0.1804, 0.5947, -0.6608]],
[[-0.4169, -1.6140, -1.1815],
[ 0.3413, 0.3242, 0.9922]]]])
Q * K^T:
tensor([[[[-0.5021, 0.4915, -0.6030],
[ 0.4426, -0.5020, 0.5473],
[-0.1245, 0.6490, -0.2707],
[ 0.1894, -0.3003, 0.2538],
[ 0.3453, -0.5783, 0.4699]],
[[-1.0558, -1.2751, -0.1370],
[-0.7855, -0.4852, 0.3593],
[-0.9058, -0.8494, 0.1259],
[-1.5010, -1.9960, -0.3771],
[-0.6178, -0.5512, 0.1138]],
[[-0.3460, 0.2399, -0.2513],
[-1.7816, 0.9557, -0.7691],
[-1.1411, 0.6920, -0.6426],
[-0.5104, 0.2597, -0.1939],
[ 0.2137, -0.1848, 0.2240]]],
[[[ 0.1064, -0.0319, -0.2586],
[-1.6434, -2.1752, 0.3344],
[ 0.7224, 1.2912, 0.3127],
[ 0.3100, 0.9915, 0.7344],
[ 0.5371, 0.7172, -0.1005]],
[[-0.8832, -1.4323, -0.2656],
[ 0.6553, 0.7290, 0.4253],
[ 0.1799, -0.0898, 0.3149],
[-0.7968, -0.7515, -0.6093],
[-0.4613, 0.0598, -0.6911]],
[[-0.3978, -0.7784, -1.1467],
[ 0.4734, 1.3309, 1.3543],
[-0.1890, -0.4146, -0.5435],
[-0.0893, -0.4754, -0.2498],
[ 0.3805, 0.8959, 1.0928]]]])
mask:
tensor([[[[ 1, 1, 1]]],
[[[ 1, 1, 0]]]], dtype=torch.uint8)
softmax(Q * K^T) with masking:
tensor([[[[ 0.2172, 0.5865, 0.1963],
[ 0.4001, 0.1556, 0.4443],
[ 0.2481, 0.5376, 0.2143],
[ 0.3732, 0.2287, 0.3981],
[ 0.3953, 0.1570, 0.4478]],
[[ 0.2321, 0.1864, 0.5816],
[ 0.1821, 0.2459, 0.5721],
[ 0.2056, 0.2175, 0.5769],
[ 0.2134, 0.1301, 0.6565],
[ 0.2411, 0.2577, 0.5011]],
[[ 0.2567, 0.4612, 0.2822],
[ 0.0521, 0.8045, 0.1434],
[ 0.1124, 0.7026, 0.1850],
[ 0.2206, 0.4766, 0.3028],
[ 0.3729, 0.2503, 0.3768]]],
[[[ 0.5345, 0.4655, 0.0000],
[ 0.6299, 0.3701, 0.0000],
[ 0.3615, 0.6385, 0.0000],
[ 0.3359, 0.6641, 0.0000],
[ 0.4551, 0.5449, 0.0000]],
[[ 0.6339, 0.3661, 0.0000],
[ 0.4816, 0.5184, 0.0000],
[ 0.5670, 0.4330, 0.0000],
[ 0.4887, 0.5113, 0.0000],
[ 0.3726, 0.6274, 0.0000]],
[[ 0.5940, 0.4060, 0.0000],
[ 0.2979, 0.7021, 0.0000],
[ 0.5562, 0.4438, 0.0000],
[ 0.5954, 0.4046, 0.0000],
[ 0.3739, 0.6261, 0.0000]]]])
V:
tensor([[[[ 0.7363, 0.0968],
[ 0.7019, -0.1864],
[ 0.3013, -0.5080]],
[[ 0.7667, 0.1734],
[-0.2915, 1.3010],
[ 0.4510, 1.5238]],
[[-0.5390, -0.9138],
[ 0.7364, -1.1734],
[-0.6351, -1.3181]]],
[[[ 0.9191, 0.4546],
[ 0.7295, 1.4285],
[ 0.3693, -0.8896]],
[[ 0.9879, 0.4118],
[ 1.1443, -0.6735],
[ 0.2565, 0.9689]],
[[ 0.0065, -1.2190],
[-0.5786, -0.5889],
[-0.5328, -0.9883]]]])
softmax(Q * K^T) * V:
tensor([[[[ 0.2434, -0.0875],
[ 0.5974, -0.2400],
[ 0.2747, -0.0943],
[ 0.6170, -0.2319],
[ 0.5957, -0.2427]],
[[ 0.4288, 1.2988],
[ 0.3622, 1.3591],
[ 0.3938, 1.3308],
[ 0.4687, 1.3408],
[ 0.3731, 1.2676]],
[[ 0.0244, -1.2751],
[ 0.5259, -1.3118],
[ 0.4443, -1.1870],
[ 0.1762, -1.0648],
[-0.2844, -1.2568]]],
[[[ 0.9231, 1.0088],
[ 0.9432, 0.9056],
[ 0.8867, 1.1960],
[ 0.8813, 1.2237],
[ 0.9064, 1.0948]],
[[ 1.1613, 0.0161],
[ 0.5286, 0.2203],
[ 1.1729, -0.0646],
[ 1.1865, -0.1591],
[ 0.7977, -0.4695]],
[[-0.2567, -1.0702],
[-0.4492, -0.8629],
[ 0.0040, -0.7533],
[-0.2558, -1.0712],
[-0.3998, -0.9161]]]])
Merge the heads, apply a linear map, and add the result to tgt_embed_attn (residual connection)
tensor([[[-0.2161, -2.4811, 1.1464, 1.3303, 0.2541, 0.4871],
[-1.3193, -2.5867, -1.2033, 0.2351, -0.2230, 0.0062],
[ 0.7021, -2.1894, 0.5505, 2.3140, 0.8843, 0.1961],
[-1.8023, -3.4669, 0.6040, 2.2281, -0.7768, 0.9494],
[-2.7583, -0.8034, 0.1077, 0.7447, 0.3146, -0.4146]],
[[ 0.2576, -1.8936, -0.4380, 0.7006, -0.3638, 0.8867],
[-0.2066, 0.5637, -0.0672, -1.1310, -0.9635, 0.5576],
[ 0.0467, -0.5195, -1.8140, -1.3613, 0.2738, 1.4879],
[-1.1212, -0.4227, -1.0964, -0.5154, -0.8633, 2.4635],
[-1.1198, -0.0412, -0.2934, 0.6077, -2.5765, -0.7313]]])
3-5. X = X + FF(X)
tensor([[[-0.1272, -1.8795, 0.3698, 1.0207, 0.5115, 0.1045],
[-0.2001, -1.4223, -0.7834, 1.2942, 0.4322, 0.6795],
[ 0.3626, -1.7664, -0.3882, 1.1899, 0.3173, 0.2848],
[-0.7579, -1.5202, 0.3895, 1.3258, 0.1286, 0.4343],
[-1.6360, 0.0659, 0.6310, 0.9940, 0.6550, -0.7098]],
[[ 0.3640, -1.6767, -0.5457, 0.5169, 0.1269, 1.2146],
[-0.3392, 1.3893, 0.6082, -0.7329, -1.3625, 0.4371],
[ 0.2417, -0.2835, -1.2188, -0.9513, 1.2164, 0.9954],
[-0.9828, -0.2044, -0.4296, -0.0254, -0.2876, 1.9299],
[-0.7952, -0.1242, -0.0943, 1.5632, -1.2098, 0.6603]]])
out (the final output of the decoder, fed to the generator in section 4):
tensor([[[-0.1272, -1.8795, 0.3698, 1.0207, 0.5115, 0.1045],
[-0.2001, -1.4223, -0.7834, 1.2942, 0.4322, 0.6795],
[ 0.3626, -1.7664, -0.3882, 1.1899, 0.3173, 0.2848],
[-0.7579, -1.5202, 0.3895, 1.3258, 0.1286, 0.4343],
[-1.6360, 0.0659, 0.6310, 0.9940, 0.6550, -0.7098]],
[[ 0.3640, -1.6767, -0.5457, 0.5169, 0.1269, 1.2146],
[-0.3392, 1.3893, 0.6082, -0.7329, -1.3625, 0.4371],
[ 0.2417, -0.2835, -1.2188, -0.9513, 1.2164, 0.9954],
[-0.9828, -0.2044, -0.4296, -0.0254, -0.2876, 1.9299],
[-0.7952, -0.1242, -0.0943, 1.5632, -1.2098, 0.6603]]])
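Putting section 3 together: each decoder layer runs masked self-attention over the target embeddings (3-3), then encoder-decoder attention in which the queries come from the decoder while the keys and values come from the encoder memory under the source mask (3-4), then the feed-forward block (3-5), each wrapped in a residual connection. A pseudocode-style sketch of that flow, with self_attn, src_attn and feed_forward standing in for the multi-head attention and feed-forward modules sketched earlier (LayerNorm/dropout omitted):

def decoder_layer(x, memory, src_mask, tgt_mask, self_attn, src_attn, feed_forward):
    # 3-3: masked self-attention -- queries, keys and values all come from x.
    x = x + self_attn(x, x, x, tgt_mask)
    # 3-4: encoder-decoder attention -- keys/values come from the encoder memory,
    #      and the *source* mask hides the padded source positions.
    x = x + src_attn(x, memory, memory, src_mask)
    # 3-5: position-wise feed-forward.
    return x + feed_forward(x)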
4. Prediction
4-1. Calculate probabilities (linear map + softmax)
target vocabulary: ['a', 'am', 'month', 'i', 'student', 'what', '<blank>', '<s>', '</s>']
tensor([[[ 0.1593, 0.2304, 0.2158, 0.0366, 0.0287, 0.1716, 0.0475,
0.0771, 0.0329],
[ 0.1775, 0.1055, 0.2827, 0.0478, 0.0193, 0.2719, 0.0189,
0.0265, 0.0499],
[ 0.2190, 0.1487, 0.2939, 0.0531, 0.0175, 0.1848, 0.0242,
0.0293, 0.0296],
[ 0.1144, 0.1960, 0.1891, 0.0353, 0.0387, 0.2022, 0.0599,
0.0902, 0.0741],
[ 0.0652, 0.2027, 0.0562, 0.0103, 0.0422, 0.0570, 0.1214,
0.3786, 0.0664]],
[[ 0.1261, 0.0600, 0.2794, 0.1511, 0.0339, 0.2632, 0.0161,
0.0157, 0.0545],
[ 0.0193, 0.0192, 0.0304, 0.3145, 0.1829, 0.0173, 0.1027,
0.0722, 0.2414],
[ 0.1132, 0.0315, 0.1908, 0.2659, 0.0712, 0.2250, 0.0138,
0.0266, 0.0620],
[ 0.0306, 0.0193, 0.1023, 0.2058, 0.1102, 0.2093, 0.0229,
0.0248, 0.2748],
[ 0.0900, 0.0885, 0.1555, 0.0921, 0.0535, 0.1170, 0.0773,
0.0569, 0.2692]]])
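This last step is the Generator of The Annotated Transformer: a single linear layer from d_model to the target vocabulary size (6 -> 9 here) followed by a softmax over the vocabulary (the original code uses log_softmax; plain softmax is shown to match the probabilities printed above). A minimal sketch, where out is the decoder output from 3-5:

import torch.nn as nn
import torch.nn.functional as F

proj = nn.Linear(d_model, len(tgt_vocab))  # 6 -> 9

probs = F.softmax(proj(out), dim=-1)  # shape (batch, tgt_len, tgt_vocab) = (2, 5, 9)
print(probs.sum(-1))                  # each position's probabilities sum to 1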
4-2. Predict based on the probabilities
argmax:
tensor([[ 1, 2, 2, 5, 7],
[ 2, 3, 3, 8, 8]])
predicted sentences:
[['am', 'month', 'month', 'what', '<s>'], ['month', 'i', 'i', '</s>', '</s>']]
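Prediction simply takes the most probable vocabulary index at every position and maps it back to a token (for real inference one would instead decode step by step, e.g. with the greedy_decode helper of The Annotated Transformer). A minimal sketch, reusing probs and tgt_vocab from above:

pred_idx = probs.argmax(dim=-1)  # in the demo above: [[1, 2, 2, 5, 7], [2, 3, 3, 8, 8]]
pred_tokens = [[tgt_vocab[i] for i in row] for row in pred_idx.tolist()]
print(pred_tokens)
# with the demo's indices: [['am', 'month', 'month', 'what', '<s>'],
#                           ['month', 'i', 'i', '</s>', '</s>']]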
4-3. Compare to the true output
target batch true output: [['i', 'am', 'a', 'student', '</s>'], ['what', 'month', '</s>', '<blank>', '<blank>']]
