自然语言处理中的注意力机制:Self-Attention到Multi-Head Attention

摘要

注意力机制(Attention Mechanism)是现代自然语言处理领域最重要的技术突破之一,从最初的序列到序列模型中的简单注意力,到Transformer架构中的自注意力机制,再到多头注意力的创新设计,注意力机制彻底改变了NLP任务的处理方式。本文将深入探讨注意力机制的发展历程、核心原理、技术实现和实际应用,重点分析Self-Attention和Multi-Head Attention的设计思想与优化策略,并提供详细的代码实现和案例分析。

关键词:注意力机制、Self-Attention、Multi-Head Attention、Transformer、自然语言处理、深度学习


1. 引言

自然语言处理(NLP)领域在过去十年中经历了革命性的变化,其中注意力机制的引入是最关键的技术突破之一。从2014年Bahdanau等人首次在神经机器翻译中引入注意力机制,到2017年Vaswani等人提出的Transformer架构完全基于注意力机制,这一技术已经成为现代NLP系统的核心组件。

传统的循环神经网络(RNN)和长短期记忆网络(LSTM)在处理长序列时面临梯度消失和计算效率低下的问题。注意力机制通过允许模型直接关注输入序列中的任意位置,有效解决了这些问题,并显著提升了模型的性能和可解释性。

本文将系统性地介绍注意力机制的发展历程,从基础的加性注意力到现代的多头自注意力机制,深入分析其数学原理、实现细节和应用场景,为读者提供全面而深入的技术理解。

2. 注意力机制基础理论

2.1 注意力机制的核心思想

注意力机制的核心思想源于人类的认知过程。当我们阅读一段文本或观察一个场景时,我们不会平等地关注所有信息,而是会将注意力集中在最相关的部分。在深度学习中,注意力机制模拟了这一过程,允许模型在处理序列数据时动态地分配注意力权重。

数学上,注意力机制可以表述为一个函数,它接受查询(Query)、键(Key)和值(Value)作为输入,输出加权的值的组合:

Attention(Q, K, V) = softmax(f(Q, K))V

其中:

  • Q(Query):查询向量,表示当前需要关注的信息
  • K(Key):键向量,表示可以被关注的信息
  • V(Value):值向量,表示实际的信息内容
  • f(Q, K):相似度函数,计算查询和键之间的匹配程度
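下面给出一个与上式对应的最小实现示例(仅作示意,函数名 simple_attention 为本文演示所取,非任何库的标准接口),其中相似度函数 f 取最简单的点积:

import torch
import torch.nn.functional as F

def simple_attention(Q, K, V):
    """最小的注意力计算示例:f(Q, K) 取点积"""
    scores = torch.matmul(Q, K.transpose(-2, -1))   # [batch, seq_q, seq_k]
    weights = F.softmax(scores, dim=-1)              # 沿键的维度归一化
    return torch.matmul(weights, V), weights         # 对值加权求和得到输出

# 用法示例
Q = torch.randn(2, 4, 8)   # [batch, seq_q, d]
K = torch.randn(2, 6, 8)   # [batch, seq_k, d]
V = torch.randn(2, 6, 8)
out, w = simple_attention(Q, K, V)   # out: [2, 4, 8], w: [2, 4, 6]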

2.2 早期注意力机制的发展

2.2.1 加性注意力(Additive Attention)

最早的注意力机制由Bahdanau等人在2014年提出,被称为加性注意力或Bahdanau注意力:

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

class AdditiveAttention(nn.Module):
    """加性注意力机制实现"""

    def __init__(self, hidden_size, attention_size):
        super(AdditiveAttention, self).__init__()
        self.hidden_size = hidden_size
        self.attention_size = attention_size

        # 线性变换层
        self.W_q = nn.Linear(hidden_size, attention_size, bias=False)
        self.W_k = nn.Linear(hidden_size, attention_size, bias=False)
        self.v = nn.Linear(attention_size, 1, bias=False)

        # 初始化参数
        self._init_weights()

    def _init_weights(self):
        """初始化权重"""
        nn.init.xavier_uniform_(self.W_q.weight)
        nn.init.xavier_uniform_(self.W_k.weight)
        nn.init.xavier_uniform_(self.v.weight)

    def forward(self, query, keys, values, mask=None):
        """
        前向传播

        Args:
            query: [batch_size, hidden_size] 查询向量
            keys: [batch_size, seq_len, hidden_size] 键序列
            values: [batch_size, seq_len, hidden_size] 值序列
            mask: [batch_size, seq_len] 掩码

        Returns:
            context: [batch_size, hidden_size] 上下文向量
            attention_weights: [batch_size, seq_len] 注意力权重
        """
        batch_size, seq_len, _ = keys.size()

        # 扩展查询向量维度
        query_expanded = query.unsqueeze(1).expand(batch_size, seq_len, -1)

        # 计算注意力分数
        # e_ij = v^T * tanh(W_q * query + W_k * key_j)
        query_proj = self.W_q(query_expanded)  # [batch_size, seq_len, attention_size]
        keys_proj = self.W_k(keys)             # [batch_size, seq_len, attention_size]

        # 加性注意力计算
        energy = torch.tanh(query_proj + keys_proj)        # [batch_size, seq_len, attention_size]
        attention_scores = self.v(energy).squeeze(-1)      # [batch_size, seq_len]

        # 应用掩码
        if mask is not None:
            attention_scores.masked_fill_(mask == 0, -1e9)

        # 计算注意力权重
        attention_weights = F.softmax(attention_scores, dim=-1)  # [batch_size, seq_len]

        # 计算上下文向量
        context = torch.bmm(attention_weights.unsqueeze(1), values).squeeze(1)

        return context, attention_weights

# 使用示例
def test_additive_attention():
    """测试加性注意力机制"""
    batch_size, seq_len, hidden_size = 2, 5, 128
    attention_size = 64

    # 创建模型
    attention = AdditiveAttention(hidden_size, attention_size)

    # 创建测试数据
    query = torch.randn(batch_size, hidden_size)
    keys = torch.randn(batch_size, seq_len, hidden_size)
    values = torch.randn(batch_size, seq_len, hidden_size)
    mask = torch.ones(batch_size, seq_len)

    # 前向传播
    context, weights = attention(query, keys, values, mask)

    print(f"Context shape: {context.shape}")
    print(f"Attention weights shape: {weights.shape}")
    print(f"Attention weights sum: {weights.sum(dim=-1)}")

    return context, weights

# 运行测试
if __name__ == "__main__":
    test_additive_attention()

2.2.2 乘性注意力(Multiplicative Attention)

乘性注意力由Luong等人在2015年提出,计算效率更高:

class MultiplicativeAttention(nn.Module):
    """乘性注意力机制实现"""

    def __init__(self, hidden_size, scale=True):
        super(MultiplicativeAttention, self).__init__()
        self.hidden_size = hidden_size
        self.scale = scale

        if scale:
            self.scaling_factor = math.sqrt(hidden_size)

    def forward(self, query, keys, values, mask=None):
        """
        前向传播

        Args:
            query: [batch_size, hidden_size] 查询向量
            keys: [batch_size, seq_len, hidden_size] 键序列
            values: [batch_size, seq_len, hidden_size] 值序列
            mask: [batch_size, seq_len] 掩码

        Returns:
            context: [batch_size, hidden_size] 上下文向量
            attention_weights: [batch_size, seq_len] 注意力权重
        """
        # 计算注意力分数
        # e_ij = query^T * key_j
        attention_scores = torch.bmm(
            query.unsqueeze(1),    # [batch_size, 1, hidden_size]
            keys.transpose(1, 2)   # [batch_size, hidden_size, seq_len]
        ).squeeze(1)               # [batch_size, seq_len]

        # 缩放(可选)
        if self.scale:
            attention_scores = attention_scores / self.scaling_factor

        # 应用掩码
        if mask is not None:
            attention_scores.masked_fill_(mask == 0, -1e9)

        # 计算注意力权重
        attention_weights = F.softmax(attention_scores, dim=-1)

        # 计算上下文向量
        context = torch.bmm(attention_weights.unsqueeze(1), values).squeeze(1)

        return context, attention_weights

class GeneralAttention(nn.Module):
    """通用注意力机制(带权重矩阵的乘性注意力)"""

    def __init__(self, hidden_size):
        super(GeneralAttention, self).__init__()
        self.hidden_size = hidden_size
        self.W = nn.Linear(hidden_size, hidden_size, bias=False)

        # 初始化权重
        nn.init.xavier_uniform_(self.W.weight)

    def forward(self, query, keys, values, mask=None):
        """
        前向传播

        Args:
            query: [batch_size, hidden_size] 查询向量
            keys: [batch_size, seq_len, hidden_size] 键序列
            values: [batch_size, seq_len, hidden_size] 值序列
            mask: [batch_size, seq_len] 掩码

        Returns:
            context: [batch_size, hidden_size] 上下文向量
            attention_weights: [batch_size, seq_len] 注意力权重
        """
        # 变换查询向量
        query_transformed = self.W(query)  # [batch_size, hidden_size]

        # 计算注意力分数
        # e_ij = query_transformed^T * key_j
        attention_scores = torch.bmm(
            query_transformed.unsqueeze(1),  # [batch_size, 1, hidden_size]
            keys.transpose(1, 2)             # [batch_size, hidden_size, seq_len]
        ).squeeze(1)                         # [batch_size, seq_len]

        # 应用掩码
        if mask is not None:
            attention_scores.masked_fill_(mask == 0, -1e9)

        # 计算注意力权重
        attention_weights = F.softmax(attention_scores, dim=-1)

        # 计算上下文向量
        context = torch.bmm(attention_weights.unsqueeze(1), values).squeeze(1)

        return context, attention_weights

2.3 注意力机制的数学基础

2.3.1 相似度函数

注意力机制的核心是计算查询和键之间的相似度。常用的相似度函数包括:

  1. 点积相似度:sim(q, k) = q^T k
  2. 缩放点积相似度:sim(q, k) = (q^T k) / √d_k
  3. 加性相似度:sim(q, k) = v^T tanh(W_q q + W_k k)
  4. 双线性相似度:sim(q, k) = q^T W k
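
作为对照,下面的小例子(仅作示意)用PyTorch逐一计算上述四种相似度,其中W_q、W_k、v、W均为随机初始化的演示用参数:

import math
import torch

d, d_att = 8, 16
q = torch.randn(d)           # 查询向量
k = torch.randn(d)           # 键向量
W_q = torch.randn(d_att, d)  # 加性相似度的投影矩阵(示例参数)
W_k = torch.randn(d_att, d)
v = torch.randn(d_att)
W = torch.randn(d, d)        # 双线性相似度的权重矩阵(示例参数)

dot = q @ k                                    # 点积相似度
scaled_dot = (q @ k) / math.sqrt(d)            # 缩放点积相似度
additive = v @ torch.tanh(W_q @ q + W_k @ k)   # 加性相似度
bilinear = q @ W @ k                           # 双线性相似度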

2.3.2 注意力权重计算

注意力权重通过softmax函数归一化:

def compute_attention_weights(scores, mask=None, temperature=1.0):
    """
    计算注意力权重

    Args:
        scores: [batch_size, seq_len] 注意力分数
        mask: [batch_size, seq_len] 掩码
        temperature: 温度参数,控制分布的尖锐程度

    Returns:
        weights: [batch_size, seq_len] 注意力权重
    """
    # 应用温度缩放
    scaled_scores = scores / temperature

    # 应用掩码
    if mask is not None:
        scaled_scores.masked_fill_(mask == 0, -1e9)

    # 计算softmax
    weights = F.softmax(scaled_scores, dim=-1)

    return weights

def attention_entropy(weights):
    """
    计算注意力权重的熵,衡量注意力分布的集中程度

    Args:
        weights: [batch_size, seq_len] 注意力权重

    Returns:
        entropy: [batch_size] 注意力熵
    """
    # 避免log(0)
    weights_safe = weights + 1e-8
    entropy = -torch.sum(weights_safe * torch.log(weights_safe), dim=-1)

    return entropy

3. Self-Attention机制深度解析

3.1 Self-Attention的核心概念

Self-Attention(自注意力)机制是Transformer架构的核心组件,它允许序列中的每个位置都能关注到序列中的所有位置,包括它自己。与传统的注意力机制不同,Self-Attention的查询、键和值都来自同一个输入序列。

3.2 Scaled Dot-Product Attention

Transformer中使用的是缩放点积注意力(Scaled Dot-Product Attention):
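
其计算公式为:

Attention(Q, K, V) = softmax(Q K^T / √d_k) V

其中d_k是键向量的维度。除以√d_k是为了防止点积随维度增大而数值过大,使softmax落入梯度极小的饱和区域。对应的实现如下: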

class ScaledDotProductAttention(nn.Module):
    """缩放点积注意力机制"""

    def __init__(self, d_model, dropout=0.1):
        super(ScaledDotProductAttention, self).__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None, return_attention=False):
        """
        前向传播

        Args:
            query: [batch_size, seq_len, d_model] 查询矩阵
            key: [batch_size, seq_len, d_model] 键矩阵
            value: [batch_size, seq_len, d_model] 值矩阵
            mask: [batch_size, seq_len, seq_len] 注意力掩码
            return_attention: 是否返回注意力权重

        Returns:
            output: [batch_size, seq_len, d_model] 输出
            attention_weights: [batch_size, seq_len, seq_len] 注意力权重(可选)
        """
        batch_size, seq_len, d_model = query.size()

        # 计算注意力分数
        # scores = Q * K^T / √d_k
        scores = torch.bmm(query, key.transpose(1, 2)) / math.sqrt(d_model)

        # 应用掩码
        if mask is not None:
            scores.masked_fill_(mask == 0, -1e9)

        # 计算注意力权重
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # 计算输出
        output = torch.bmm(attention_weights, value)

        if return_attention:
            return output, attention_weights
        else:
            return output

class SelfAttention(nn.Module):
    """自注意力机制"""

    def __init__(self, d_model, dropout=0.1):
        super(SelfAttention, self).__init__()
        self.d_model = d_model

        # 线性变换层
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        # 注意力计算
        self.attention = ScaledDotProductAttention(d_model, dropout)

        # 初始化权重
        self._init_weights()

    def _init_weights(self):
        """初始化权重"""
        for module in [self.W_q, self.W_k, self.W_v, self.W_o]:
            nn.init.xavier_uniform_(module.weight)
            nn.init.constant_(module.bias, 0)

    def forward(self, x, mask=None, return_attention=False):
        """
        前向传播

        Args:
            x: [batch_size, seq_len, d_model] 输入序列
            mask: [batch_size, seq_len, seq_len] 注意力掩码
            return_attention: 是否返回注意力权重

        Returns:
            output: [batch_size, seq_len, d_model] 输出
            attention_weights: [batch_size, seq_len, seq_len] 注意力权重(可选)
        """
        # 线性变换得到Q、K、V
        Q = self.W_q(x)  # [batch_size, seq_len, d_model]
        K = self.W_k(x)  # [batch_size, seq_len, d_model]
        V = self.W_v(x)  # [batch_size, seq_len, d_model]

        # 计算注意力
        if return_attention:
            attention_output, attention_weights = self.attention(
                Q, K, V, mask, return_attention=True
            )
        else:
            attention_output = self.attention(Q, K, V, mask)

        # 输出线性变换
        output = self.W_o(attention_output)

        if return_attention:
            return output, attention_weights
        else:
            return output

3.3 位置编码(Positional Encoding)

由于Self-Attention机制本身不包含位置信息,需要添加位置编码:

class PositionalEncoding(nn.Module):
    """位置编码"""

    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)

        # 创建位置编码矩阵
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

        # 计算除数项
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(10000.0) / d_model))

        # 应用sin和cos函数
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        # 添加batch维度
        pe = pe.unsqueeze(0).transpose(0, 1)

        # 注册为buffer,不参与梯度更新
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        前向传播

        Args:
            x: [seq_len, batch_size, d_model] 输入序列

        Returns:
            output: [seq_len, batch_size, d_model] 添加位置编码后的序列
        """
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

class LearnablePositionalEncoding(nn.Module):
    """可学习的位置编码"""

    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super(LearnablePositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)

        # 可学习的位置嵌入
        self.position_embeddings = nn.Embedding(max_len, d_model)

        # 初始化
        nn.init.normal_(self.position_embeddings.weight, std=0.02)

    def forward(self, x):
        """
        前向传播

        Args:
            x: [batch_size, seq_len, d_model] 输入序列

        Returns:
            output: [batch_size, seq_len, d_model] 添加位置编码后的序列
        """
        batch_size, seq_len, d_model = x.size()

        # 创建位置索引
        position_ids = torch.arange(seq_len, dtype=torch.long, device=x.device)
        position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

        # 获取位置编码
        position_embeddings = self.position_embeddings(position_ids)

        # 添加位置编码
        x = x + position_embeddings

        return self.dropout(x)

class RelativePositionalEncoding(nn.Module):
    """相对位置编码"""

    def __init__(self, d_model, max_relative_position=128):
        super(RelativePositionalEncoding, self).__init__()
        self.d_model = d_model
        self.max_relative_position = max_relative_position

        # 相对位置嵌入
        self.relative_position_embeddings = nn.Embedding(
            2 * max_relative_position + 1, d_model
        )

        # 初始化
        nn.init.normal_(self.relative_position_embeddings.weight, std=0.02)

    def forward(self, seq_len):
        """
        生成相对位置编码

        Args:
            seq_len: 序列长度

        Returns:
            relative_positions: [seq_len, seq_len, d_model] 相对位置编码
        """
        # 创建相对位置矩阵
        range_vec = torch.arange(seq_len)
        range_mat = range_vec.unsqueeze(0).expand(seq_len, -1)
        distance_mat = range_mat - range_mat.transpose(0, 1)

        # 裁剪到最大相对位置
        distance_mat_clipped = torch.clamp(
            distance_mat, -self.max_relative_position, self.max_relative_position
        )

        # 转换为正数索引
        final_mat = distance_mat_clipped + self.max_relative_position

        # 获取相对位置编码
        relative_positions = self.relative_position_embeddings(final_mat)

        return relative_positions

3.4 Self-Attention的变体

3.4.1 因果自注意力(Causal Self-Attention)

用于语言模型等自回归任务:

class CausalSelfAttention(nn.Module):
    """因果自注意力机制(用于语言模型)"""

    def __init__(self, d_model, dropout=0.1):
        super(CausalSelfAttention, self).__init__()
        self.d_model = d_model

        # 线性变换层
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

        # 初始化权重
        self._init_weights()

    def _init_weights(self):
        """初始化权重"""
        for module in [self.W_q, self.W_k, self.W_v, self.W_o]:
            nn.init.xavier_uniform_(module.weight)
            nn.init.constant_(module.bias, 0)

    def _create_causal_mask(self, seq_len, device):
        """创建因果掩码"""
        mask = torch.tril(torch.ones(seq_len, seq_len, device=device))
        return mask.unsqueeze(0)  # [1, seq_len, seq_len]

    def forward(self, x, return_attention=False):
        """
        前向传播

        Args:
            x: [batch_size, seq_len, d_model] 输入序列
            return_attention: 是否返回注意力权重

        Returns:
            output: [batch_size, seq_len, d_model] 输出
            attention_weights: [batch_size, seq_len, seq_len] 注意力权重(可选)
        """
        batch_size, seq_len, d_model = x.size()

        # 线性变换得到Q、K、V
        Q = self.W_q(x)  # [batch_size, seq_len, d_model]
        K = self.W_k(x)  # [batch_size, seq_len, d_model]
        V = self.W_v(x)  # [batch_size, seq_len, d_model]

        # 计算注意力分数
        scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(d_model)

        # 应用因果掩码
        causal_mask = self._create_causal_mask(seq_len, x.device)
        scores.masked_fill_(causal_mask == 0, -1e9)

        # 计算注意力权重
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # 计算输出
        attention_output = torch.bmm(attention_weights, V)
        output = self.W_o(attention_output)

        if return_attention:
            return output, attention_weights
        else:
            return output

3.4.2 稀疏注意力(Sparse Attention)

为了处理长序列,可以使用稀疏注意力模式:

class SparseAttention(nn.Module):
    """稀疏注意力机制"""

    def __init__(self, d_model, pattern='local', window_size=128, dropout=0.1):
        super(SparseAttention, self).__init__()
        self.d_model = d_model
        self.pattern = pattern
        self.window_size = window_size

        # 线性变换层
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def _create_sparse_mask(self, seq_len, device):
        """创建稀疏注意力掩码"""
        mask = torch.zeros(seq_len, seq_len, device=device)

        if self.pattern == 'local':
            # 局部注意力模式
            for i in range(seq_len):
                start = max(0, i - self.window_size // 2)
                end = min(seq_len, i + self.window_size // 2 + 1)
                mask[i, start:end] = 1

        elif self.pattern == 'strided':
            # 步长注意力模式
            stride = self.window_size
            for i in range(seq_len):
                # 局部窗口
                start = max(0, i - self.window_size // 2)
                end = min(seq_len, i + self.window_size // 2 + 1)
                mask[i, start:end] = 1

                # 步长位置
                for j in range(0, seq_len, stride):
                    if j < seq_len:
                        mask[i, j] = 1

        elif self.pattern == 'random':
            # 随机稀疏模式
            sparsity = 0.1  # 保留10%的连接
            mask = torch.rand(seq_len, seq_len, device=device) < sparsity
            mask = mask.float()

        # 确保对角线为1(自注意力)
        mask.fill_diagonal_(1)

        return mask.unsqueeze(0)  # [1, seq_len, seq_len]

    def forward(self, x, return_attention=False):
        """
        前向传播

        Args:
            x: [batch_size, seq_len, d_model] 输入序列
            return_attention: 是否返回注意力权重

        Returns:
            output: [batch_size, seq_len, d_model] 输出
            attention_weights: [batch_size, seq_len, seq_len] 注意力权重(可选)
        """
        batch_size, seq_len, d_model = x.size()

        # 线性变换得到Q、K、V
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)

        # 计算注意力分数
        scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(d_model)

        # 应用稀疏掩码
        sparse_mask = self._create_sparse_mask(seq_len, x.device)
        scores.masked_fill_(sparse_mask == 0, -1e9)

        # 计算注意力权重
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # 计算输出
        attention_output = torch.bmm(attention_weights, V)
        output = self.W_o(attention_output)

        if return_attention:
            return output, attention_weights
        else:
            return output

4. Multi-Head Attention机制详解

4.1 Multi-Head Attention的设计理念

Multi-Head Attention(多头注意力)是Transformer架构的核心创新之一。其基本思想是将注意力机制并行化,让模型能够同时关注不同类型的信息和不同的表示子空间。
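
形式化地,多头注意力可以表示为:

MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O
head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V)

其中h为头数,每个头在d_k = d_model / h维的子空间中独立执行缩放点积注意力,各头的输出拼接后再经输出矩阵W^O投影回d_model维。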

4.2 Multi-Head Attention的实现

class MultiHeadAttention(nn.Module):
    """多头注意力机制"""

    def __init__(self, d_model, num_heads, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # 线性变换层
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

        # 初始化权重
        self._init_weights()

    def _init_weights(self):
        """初始化权重"""
        for module in [self.W_q, self.W_k, self.W_v, self.W_o]:
            nn.init.xavier_uniform_(module.weight)
            nn.init.constant_(module.bias, 0)

    def forward(self, query, key, value, mask=None, return_attention=False):
        """
        前向传播

        Args:
            query: [batch_size, seq_len_q, d_model] 查询序列
            key: [batch_size, seq_len_k, d_model] 键序列
            value: [batch_size, seq_len_v, d_model] 值序列
            mask: [batch_size, seq_len_q, seq_len_k] 注意力掩码
            return_attention: 是否返回注意力权重

        Returns:
            output: [batch_size, seq_len_q, d_model] 输出
            attention_weights: [batch_size, num_heads, seq_len_q, seq_len_k] 注意力权重(可选)
        """
        batch_size, seq_len_q, d_model = query.size()
        seq_len_k = key.size(1)
        seq_len_v = value.size(1)

        # 线性变换得到Q、K、V
        Q = self.W_q(query)  # [batch_size, seq_len_q, d_model]
        K = self.W_k(key)    # [batch_size, seq_len_k, d_model]
        V = self.W_v(value)  # [batch_size, seq_len_v, d_model]

        # 重塑为多头形式
        Q = Q.view(batch_size, seq_len_q, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, seq_len_k, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len_v, self.num_heads, self.d_k).transpose(1, 2)
        # 现在形状为 [batch_size, num_heads, seq_len, d_k]

        # 调整掩码维度
        if mask is not None:
            mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1)

        # 计算多头注意力
        attention_output, attention_weights = self._scaled_dot_product_attention(
            Q, K, V, mask, return_attention=True
        )

        # 合并多头输出
        attention_output = attention_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len_q, d_model
        )

        # 输出线性变换
        output = self.W_o(attention_output)

        if return_attention:
            return output, attention_weights
        else:
            return output

    def _scaled_dot_product_attention(self, Q, K, V, mask=None, return_attention=False):
        """
        缩放点积注意力

        Args:
            Q: [batch_size, num_heads, seq_len_q, d_k] 查询矩阵
            K: [batch_size, num_heads, seq_len_k, d_k] 键矩阵
            V: [batch_size, num_heads, seq_len_v, d_k] 值矩阵
            mask: [batch_size, num_heads, seq_len_q, seq_len_k] 注意力掩码
            return_attention: 是否返回注意力权重

        Returns:
            output: [batch_size, num_heads, seq_len_q, d_k] 输出
            attention_weights: [batch_size, num_heads, seq_len_q, seq_len_k] 注意力权重(可选)
        """
        # 计算注意力分数
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        # 应用掩码
        if mask is not None:
            scores.masked_fill_(mask == 0, -1e9)

        # 计算注意力权重
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # 计算输出
        output = torch.matmul(attention_weights, V)

        if return_attention:
            return output, attention_weights
        else:
            return output

4.3 Multi-Head Attention的优化变体

4.3.1 分组查询注意力(Grouped Query Attention)

class GroupedQueryAttention(nn.Module):
    """分组查询注意力机制"""

    def __init__(self, d_model, num_heads, num_kv_heads=None, dropout=0.1):
        super(GroupedQueryAttention, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads or num_heads

        assert d_model % num_heads == 0
        assert num_heads % self.num_kv_heads == 0

        self.d_k = d_model // num_heads
        self.num_queries_per_kv = num_heads // self.num_kv_heads

        # 线性变换层
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, self.num_kv_heads * self.d_k)
        self.W_v = nn.Linear(d_model, self.num_kv_heads * self.d_k)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

        # 初始化权重
        self._init_weights()

    def _init_weights(self):
        """初始化权重"""
        for module in [self.W_q, self.W_k, self.W_v, self.W_o]:
            nn.init.xavier_uniform_(module.weight)
            nn.init.constant_(module.bias, 0)

    def forward(self, query, key, value, mask=None, return_attention=False):
        """
        前向传播

        Args:
            query: [batch_size, seq_len_q, d_model] 查询序列
            key: [batch_size, seq_len_k, d_model] 键序列
            value: [batch_size, seq_len_v, d_model] 值序列
            mask: [batch_size, seq_len_q, seq_len_k] 注意力掩码
            return_attention: 是否返回注意力权重

        Returns:
            output: [batch_size, seq_len_q, d_model] 输出
            attention_weights: [batch_size, num_heads, seq_len_q, seq_len_k] 注意力权重(可选)
        """
        batch_size, seq_len_q, d_model = query.size()
        seq_len_k = key.size(1)

        # 线性变换得到Q、K、V
        Q = self.W_q(query)  # [batch_size, seq_len_q, d_model]
        K = self.W_k(key)    # [batch_size, seq_len_k, num_kv_heads * d_k]
        V = self.W_v(value)  # [batch_size, seq_len_v, num_kv_heads * d_k]

        # 重塑Q为多头形式
        Q = Q.view(batch_size, seq_len_q, self.num_heads, self.d_k).transpose(1, 2)
        # [batch_size, num_heads, seq_len_q, d_k]

        # 重塑K、V为分组形式
        K = K.view(batch_size, seq_len_k, self.num_kv_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len_k, self.num_kv_heads, self.d_k).transpose(1, 2)
        # [batch_size, num_kv_heads, seq_len_k, d_k]

        # 扩展K、V以匹配Q的头数
        K = K.repeat_interleave(self.num_queries_per_kv, dim=1)
        V = V.repeat_interleave(self.num_queries_per_kv, dim=1)
        # [batch_size, num_heads, seq_len_k, d_k]

        # 调整掩码维度
        if mask is not None:
            mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1)

        # 计算注意力
        attention_output, attention_weights = self._scaled_dot_product_attention(
            Q, K, V, mask, return_attention=True
        )

        # 合并多头输出
        attention_output = attention_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len_q, d_model
        )

        # 输出线性变换
        output = self.W_o(attention_output)

        if return_attention:
            return output, attention_weights
        else:
            return output

    def _scaled_dot_product_attention(self, Q, K, V, mask=None, return_attention=False):
        """缩放点积注意力"""
        # 计算注意力分数
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        # 应用掩码
        if mask is not None:
            scores.masked_fill_(mask == 0, -1e9)

        # 计算注意力权重
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # 计算输出
        output = torch.matmul(attention_weights, V)

        if return_attention:
            return output, attention_weights
        else:
            return output

4.3.2 Flash Attention

class FlashAttention(nn.Module):
    """Flash Attention实现(简化版)"""

    def __init__(self, d_model, num_heads, block_size=64, dropout=0.1):
        super(FlashAttention, self).__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.block_size = block_size

        # 线性变换层
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

        # 初始化权重
        self._init_weights()

    def _init_weights(self):
        """初始化权重"""
        for module in [self.W_q, self.W_k, self.W_v, self.W_o]:
            nn.init.xavier_uniform_(module.weight)
            nn.init.constant_(module.bias, 0)

    def forward(self, query, key, value, mask=None, return_attention=False):
        """
        前向传播(简化的Flash Attention)

        Args:
            query: [batch_size, seq_len_q, d_model] 查询序列
            key: [batch_size, seq_len_k, d_model] 键序列
            value: [batch_size, seq_len_v, d_model] 值序列
            mask: [batch_size, seq_len_q, seq_len_k] 注意力掩码
            return_attention: 是否返回注意力权重

        Returns:
            output: [batch_size, seq_len_q, d_model] 输出
            attention_weights: [batch_size, num_heads, seq_len_q, seq_len_k] 注意力权重(可选)
        """
        batch_size, seq_len_q, d_model = query.size()
        seq_len_k = key.size(1)

        # 线性变换得到Q、K、V
        Q = self.W_q(query)
        K = self.W_k(key)
        V = self.W_v(value)

        # 重塑为多头形式
        Q = Q.view(batch_size, seq_len_q, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, seq_len_k, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len_k, self.num_heads, self.d_k).transpose(1, 2)

        # 分块计算注意力(简化实现)
        output = self._flash_attention_forward(Q, K, V, mask)

        # 合并多头输出
        output = output.transpose(1, 2).contiguous().view(
            batch_size, seq_len_q, d_model
        )

        # 输出线性变换
        output = self.W_o(output)

        if return_attention:
            # 注意:Flash Attention通常不返回完整的注意力权重矩阵
            return output, None
        else:
            return output

    def _flash_attention_forward(self, Q, K, V, mask=None):
        """
        Flash Attention前向传播(简化版)

        Args:
            Q: [batch_size, num_heads, seq_len_q, d_k] 查询矩阵
            K: [batch_size, num_heads, seq_len_k, d_k] 键矩阵
            V: [batch_size, num_heads, seq_len_v, d_k] 值矩阵
            mask: [batch_size, seq_len_q, seq_len_k] 注意力掩码

        Returns:
            output: [batch_size, num_heads, seq_len_q, d_k] 输出
        """
        batch_size, num_heads, seq_len_q, d_k = Q.size()
        seq_len_k = K.size(2)

        # 初始化输出
        output = torch.zeros_like(Q)

        # 分块处理
        for i in range(0, seq_len_q, self.block_size):
            end_i = min(i + self.block_size, seq_len_q)
            Q_block = Q[:, :, i:end_i, :]  # [batch_size, num_heads, block_size, d_k]

            # 初始化块的累积值
            block_output = torch.zeros_like(Q_block)
            block_max = torch.full((batch_size, num_heads, end_i - i, 1),
                                   -float('inf'), device=Q.device)
            block_sum = torch.zeros((batch_size, num_heads, end_i - i, 1), device=Q.device)

            for j in range(0, seq_len_k, self.block_size):
                end_j = min(j + self.block_size, seq_len_k)
                K_block = K[:, :, j:end_j, :]  # [batch_size, num_heads, block_size, d_k]
                V_block = V[:, :, j:end_j, :]  # [batch_size, num_heads, block_size, d_k]

                # 计算注意力分数
                scores = torch.matmul(Q_block, K_block.transpose(-2, -1)) / math.sqrt(d_k)

                # 应用掩码(如果有)
                if mask is not None:
                    mask_block = mask[:, i:end_i, j:end_j]
                    mask_block = mask_block.unsqueeze(1).expand(-1, num_heads, -1, -1)
                    scores.masked_fill_(mask_block == 0, -1e9)

                # 在线softmax计算
                block_max_new = torch.max(block_max, torch.max(scores, dim=-1, keepdim=True)[0])

                # 更新之前的值
                alpha = torch.exp(block_max - block_max_new)
                block_output = block_output * alpha
                block_sum = block_sum * alpha

                # 计算当前块的贡献
                exp_scores = torch.exp(scores - block_max_new)
                current_sum = torch.sum(exp_scores, dim=-1, keepdim=True)
                current_output = torch.matmul(exp_scores, V_block)

                # 累积
                block_output = block_output + current_output
                block_sum = block_sum + current_sum
                block_max = block_max_new

            # 归一化
            output[:, :, i:end_i, :] = block_output / block_sum

        return output

4.4 注意力机制的可视化与分析

import matplotlib.pyplot as plt
import seaborn as sns

class AttentionVisualizer:
    """注意力机制可视化工具"""

    def __init__(self):
        self.attention_weights = None
        self.tokens = None

    def visualize_attention_matrix(self, attention_weights, tokens=None,
                                   head_idx=0, layer_name="Attention"):
        """
        可视化注意力权重矩阵

        Args:
            attention_weights: [batch_size, num_heads, seq_len, seq_len] 注意力权重
            tokens: 词汇列表
            head_idx: 要可视化的注意力头索引
            layer_name: 层名称
        """
        # 提取指定头的注意力权重
        weights = attention_weights[0, head_idx].detach().cpu().numpy()

        # 创建热力图
        plt.figure(figsize=(10, 8))
        sns.heatmap(weights,
                    xticklabels=tokens if tokens else False,
                    yticklabels=tokens if tokens else False,
                    cmap='Blues',
                    cbar=True)

        plt.title(f'{layer_name} - Head {head_idx}')
        plt.xlabel('Key Positions')
        plt.ylabel('Query Positions')
        plt.tight_layout()
        plt.show()

    def visualize_attention_heads(self, attention_weights, tokens=None,
                                  max_heads=8, layer_name="Attention"):
        """
        可视化多个注意力头

        Args:
            attention_weights: [batch_size, num_heads, seq_len, seq_len] 注意力权重
            tokens: 词汇列表
            max_heads: 最大显示头数
            layer_name: 层名称
        """
        num_heads = min(attention_weights.size(1), max_heads)

        fig, axes = plt.subplots(2, 4, figsize=(16, 8))
        axes = axes.flatten()

        for head_idx in range(num_heads):
            weights = attention_weights[0, head_idx].detach().cpu().numpy()

            sns.heatmap(weights,
                        ax=axes[head_idx],
                        xticklabels=tokens if tokens else False,
                        yticklabels=tokens if tokens else False,
                        cmap='Blues',
                        cbar=True)

            axes[head_idx].set_title(f'Head {head_idx}')
            axes[head_idx].set_xlabel('Key Positions')
            axes[head_idx].set_ylabel('Query Positions')

        plt.suptitle(f'{layer_name} - Multiple Heads')
        plt.tight_layout()
        plt.show()

    def analyze_attention_patterns(self, attention_weights):
        """
        分析注意力模式

        Args:
            attention_weights: [batch_size, num_heads, seq_len, seq_len] 注意力权重

        Returns:
            analysis: 分析结果字典
        """
        # 取batch中第一个样本,形状为 [num_heads, seq_len, seq_len]
        weights = attention_weights[0].detach().cpu().numpy()
        num_heads, seq_len, _ = weights.shape

        analysis = {
            'entropy': [],
            'max_attention': [],
            'diagonal_attention': [],
            'local_attention': []
        }

        for head_idx in range(num_heads):
            head_weights = weights[head_idx]

            # 计算熵(注意力分散程度)
            entropy = -np.sum(head_weights * np.log(head_weights + 1e-8), axis=-1)
            analysis['entropy'].append(np.mean(entropy))

            # 最大注意力值
            max_att = np.max(head_weights, axis=-1)
            analysis['max_attention'].append(np.mean(max_att))

            # 对角线注意力(自注意力强度)
            diagonal_att = np.diag(head_weights)
            analysis['diagonal_attention'].append(np.mean(diagonal_att))

            # 局部注意力(相邻位置注意力)
            local_att = 0
            for i in range(seq_len - 1):
                local_att += head_weights[i, i+1] + head_weights[i+1, i]
            analysis['local_attention'].append(local_att / (2 * (seq_len - 1)))

        return analysis

    def plot_attention_statistics(self, analysis):
        """
        绘制注意力统计图表

        Args:
            analysis: 分析结果字典
        """
        fig, axes = plt.subplots(2, 2, figsize=(12, 10))

        # 熵分布
        axes[0, 0].bar(range(len(analysis['entropy'])), analysis['entropy'])
        axes[0, 0].set_title('Attention Entropy by Head')
        axes[0, 0].set_xlabel('Head Index')
        axes[0, 0].set_ylabel('Entropy')

        # 最大注意力值
        axes[0, 1].bar(range(len(analysis['max_attention'])), analysis['max_attention'])
        axes[0, 1].set_title('Max Attention by Head')
        axes[0, 1].set_xlabel('Head Index')
        axes[0, 1].set_ylabel('Max Attention')

        # 对角线注意力
        axes[1, 0].bar(range(len(analysis['diagonal_attention'])), analysis['diagonal_attention'])
        axes[1, 0].set_title('Diagonal Attention by Head')
        axes[1, 0].set_xlabel('Head Index')
        axes[1, 0].set_ylabel('Diagonal Attention')

        # 局部注意力
        axes[1, 1].bar(range(len(analysis['local_attention'])), analysis['local_attention'])
        axes[1, 1].set_title('Local Attention by Head')
        axes[1, 1].set_xlabel('Head Index')
        axes[1, 1].set_ylabel('Local Attention')

        plt.tight_layout()
        plt.show()

# 使用示例
def test_attention_visualization():
    """测试注意力可视化"""
    # 创建测试数据
    batch_size, num_heads, seq_len, d_model = 1, 8, 10, 64

    # 创建多头注意力模型
    attention = MultiHeadAttention(d_model, num_heads)

    # 创建测试输入
    x = torch.randn(batch_size, seq_len, d_model)

    # 前向传播
    output, attention_weights = attention(x, x, x, return_attention=True)

    # 创建可视化器
    visualizer = AttentionVisualizer()

    # 创建示例词汇
    tokens = [f'token_{i}' for i in range(seq_len)]

    # 可视化注意力矩阵
    visualizer.visualize_attention_matrix(attention_weights, tokens, head_idx=0)

    # 分析注意力模式
    analysis = visualizer.analyze_attention_patterns(attention_weights)

    # 绘制统计图表
    visualizer.plot_attention_statistics(analysis)

    return output, attention_weights, analysis

5. 注意力机制的实际应用

5.1 机器翻译中的注意力机制

class TranslationAttention(nn.Module):
    """机器翻译中的注意力机制"""

    def __init__(self, encoder_hidden_size, decoder_hidden_size, attention_size):
        super(TranslationAttention, self).__init__()
        self.encoder_hidden_size = encoder_hidden_size
        self.decoder_hidden_size = decoder_hidden_size
        self.attention_size = attention_size

        # 注意力网络
        self.W_encoder = nn.Linear(encoder_hidden_size, attention_size)
        self.W_decoder = nn.Linear(decoder_hidden_size, attention_size)
        self.v = nn.Linear(attention_size, 1)

        # 上下文向量投影
        self.W_context = nn.Linear(encoder_hidden_size, decoder_hidden_size)

        # 初始化权重
        self._init_weights()

    def _init_weights(self):
        """初始化权重"""
        for module in [self.W_encoder, self.W_decoder, self.v, self.W_context]:
            nn.init.xavier_uniform_(module.weight)
            nn.init.constant_(module.bias, 0)

    def forward(self, encoder_outputs, decoder_hidden, encoder_mask=None):
        """
        前向传播

        Args:
            encoder_outputs: [batch_size, src_len, encoder_hidden_size] 编码器输出
            decoder_hidden: [batch_size, decoder_hidden_size] 解码器隐藏状态
            encoder_mask: [batch_size, src_len] 编码器掩码

        Returns:
            context_vector: [batch_size, decoder_hidden_size] 上下文向量
            attention_weights: [batch_size, src_len] 注意力权重
        """
        batch_size, src_len, encoder_hidden_size = encoder_outputs.size()

        # 扩展解码器隐藏状态
        decoder_hidden_expanded = decoder_hidden.unsqueeze(1).expand(
            batch_size, src_len, -1
        )

        # 计算注意力分数
        encoder_proj = self.W_encoder(encoder_outputs)
        decoder_proj = self.W_decoder(decoder_hidden_expanded)

        energy = torch.tanh(encoder_proj + decoder_proj)
        attention_scores = self.v(energy).squeeze(-1)

        # 应用掩码
        if encoder_mask is not None:
            attention_scores.masked_fill_(encoder_mask == 0, -1e9)

        # 计算注意力权重
        attention_weights = F.softmax(attention_scores, dim=-1)

        # 计算上下文向量
        context_vector = torch.bmm(
            attention_weights.unsqueeze(1), encoder_outputs
        ).squeeze(1)

        # 投影到解码器维度
        context_vector = self.W_context(context_vector)

        return context_vector, attention_weights

class Seq2SeqWithAttention(nn.Module):
    """带注意力机制的序列到序列模型"""

    def __init__(self, src_vocab_size, tgt_vocab_size, embedding_size,
                 hidden_size, num_layers=2, dropout=0.1):
        super(Seq2SeqWithAttention, self).__init__()

        # 嵌入层
        self.src_embedding = nn.Embedding(src_vocab_size, embedding_size)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, embedding_size)

        # 编码器(双向LSTM)
        self.encoder = nn.LSTM(embedding_size, hidden_size, num_layers,
                               batch_first=True, dropout=dropout, bidirectional=True)

        # 解码器:输入为目标词嵌入与上下文向量的拼接
        # (TranslationAttention已把上下文向量投影到hidden_size维)
        self.decoder = nn.LSTM(embedding_size + hidden_size, hidden_size,
                               num_layers, batch_first=True, dropout=dropout)

        # 注意力机制
        self.attention = TranslationAttention(hidden_size * 2, hidden_size, hidden_size)

        # 输出投影
        self.output_projection = nn.Linear(hidden_size, tgt_vocab_size)

        self.dropout = nn.Dropout(dropout)

    def forward(self, src_tokens, tgt_tokens, src_mask=None):
        """
        前向传播

        Args:
            src_tokens: [batch_size, src_len] 源语言词汇索引
            tgt_tokens: [batch_size, tgt_len] 目标语言词汇索引
            src_mask: [batch_size, src_len] 源语言掩码

        Returns:
            logits: [batch_size, tgt_len, tgt_vocab_size] 输出logits
            attention_weights: [batch_size, tgt_len, src_len] 注意力权重
        """
        batch_size, src_len = src_tokens.size()
        tgt_len = tgt_tokens.size(1)

        # 编码器
        src_embeddings = self.src_embedding(src_tokens)
        encoder_outputs, (encoder_hidden, encoder_cell) = self.encoder(src_embeddings)

        # 初始化解码器状态:取编码器状态的最后num_layers个方向(简化处理)
        # encoder_hidden形状为 [num_layers * 2, batch_size, hidden_size]
        num_layers = self.decoder.num_layers
        decoder_hidden = encoder_hidden[-num_layers:].contiguous()
        decoder_cell = encoder_cell[-num_layers:].contiguous()

        # 解码器前向传播
        tgt_embeddings = self.tgt_embedding(tgt_tokens)

        outputs = []
        attention_weights_list = []

        for t in range(tgt_len):
            # 用解码器顶层隐藏状态计算注意力
            context_vector, attention_weights = self.attention(
                encoder_outputs, decoder_hidden[-1], src_mask
            )

            # 准备解码器输入
            decoder_input = torch.cat([
                tgt_embeddings[:, t:t+1, :],   # 当前目标词嵌入
                context_vector.unsqueeze(1)    # 上下文向量
            ], dim=-1)

            # 解码器前向传播(状态保持 [num_layers, batch_size, hidden_size] 形状)
            decoder_output, (decoder_hidden, decoder_cell) = self.decoder(
                decoder_input, (decoder_hidden, decoder_cell)
            )

            # 输出投影
            output = self.output_projection(decoder_output.squeeze(1))

            outputs.append(output)
            attention_weights_list.append(attention_weights)

        # 合并输出
        logits = torch.stack(outputs, dim=1)
        attention_weights = torch.stack(attention_weights_list, dim=1)

        return logits, attention_weights

5.2 文本摘要中的注意力机制

class HierarchicalAttention(nn.Module):
    """层次化注意力机制(用于文档级任务)"""

    def __init__(self, word_hidden_size, sentence_hidden_size, attention_size):
        super(HierarchicalAttention, self).__init__()

        # 词级注意力
        self.word_attention = nn.Sequential(
            nn.Linear(word_hidden_size, attention_size),
            nn.Tanh(),
            nn.Linear(attention_size, 1)
        )

        # 句子级注意力
        self.sentence_attention = nn.Sequential(
            nn.Linear(sentence_hidden_size, attention_size),
            nn.Tanh(),
            nn.Linear(attention_size, 1)
        )

    def forward(self, word_outputs, sentence_outputs, word_mask=None, sentence_mask=None):
        """
        前向传播

        Args:
            word_outputs: [batch_size, num_sentences, max_words, word_hidden_size] 词级输出
            sentence_outputs: [batch_size, num_sentences, sentence_hidden_size] 句子级输出
            word_mask: [batch_size, num_sentences, max_words] 词级掩码
            sentence_mask: [batch_size, num_sentences] 句子级掩码

        Returns:
            document_representation: [batch_size, sentence_hidden_size] 文档表示
            word_attention_weights: 词级注意力权重
            sentence_attention_weights: 句子级注意力权重
        """
        batch_size, num_sentences, max_words, word_hidden_size = word_outputs.size()

        # 词级注意力
        word_attention_scores = self.word_attention(word_outputs).squeeze(-1)

        if word_mask is not None:
            word_attention_scores.masked_fill_(word_mask == 0, -1e9)

        word_attention_weights = F.softmax(word_attention_scores, dim=-1)

        # 计算句子表示(词级加权平均)
        sentence_representations = torch.sum(
            word_attention_weights.unsqueeze(-1) * word_outputs, dim=2
        )

        # 句子级注意力
        sentence_attention_scores = self.sentence_attention(sentence_representations).squeeze(-1)

        if sentence_mask is not None:
            sentence_attention_scores.masked_fill_(sentence_mask == 0, -1e9)

        sentence_attention_weights = F.softmax(sentence_attention_scores, dim=-1)

        # 计算文档表示(句子级加权平均)
        document_representation = torch.sum(
            sentence_attention_weights.unsqueeze(-1) * sentence_representations, dim=1
        )

        return document_representation, word_attention_weights, sentence_attention_weights

5.3 问答系统中的注意力机制

class QuestionAnsweringAttention(nn.Module):
    """问答系统中的注意力机制"""

    def __init__(self, hidden_size, attention_size):
        super(QuestionAnsweringAttention, self).__init__()
        self.hidden_size = hidden_size
        self.attention_size = attention_size

        # 双向注意力网络
        self.W_context = nn.Linear(hidden_size, attention_size)
        self.W_question = nn.Linear(hidden_size, attention_size)
        self.W_similarity = nn.Linear(attention_size, 1)

        # 自注意力网络
        self.self_attention = MultiHeadAttention(hidden_size, num_heads=8)

        # 输出层
        self.start_pointer = nn.Linear(hidden_size * 2, 1)
        self.end_pointer = nn.Linear(hidden_size * 2, 1)

        # 初始化权重
        self._init_weights()

    def _init_weights(self):
        """初始化权重"""
        for module in [self.W_context, self.W_question, self.W_similarity,
                       self.start_pointer, self.end_pointer]:
            nn.init.xavier_uniform_(module.weight)
            nn.init.constant_(module.bias, 0)

    def forward(self, context_encoding, question_encoding, context_mask=None):
        """
        前向传播

        Args:
            context_encoding: [batch_size, context_len, hidden_size] 上下文编码
            question_encoding: [batch_size, question_len, hidden_size] 问题编码
            context_mask: [batch_size, context_len] 上下文掩码

        Returns:
            start_logits: [batch_size, context_len] 开始位置logits
            end_logits: [batch_size, context_len] 结束位置logits
            attention_weights: [batch_size, context_len] 注意力权重(问题经平均池化为单一表示)
        """
        batch_size, context_len, hidden_size = context_encoding.size()
        question_len = question_encoding.size(1)

        # 计算问题表示(平均池化)
        question_representation = torch.mean(question_encoding, dim=1)  # [batch_size, hidden_size]

        # 扩展问题表示
        question_expanded = question_representation.unsqueeze(1).expand(
            batch_size, context_len, -1
        )

        # 计算上下文-问题注意力
        context_proj = self.W_context(context_encoding)
        question_proj = self.W_question(question_expanded)

        similarity_scores = self.W_similarity(
            torch.tanh(context_proj + question_proj)
        ).squeeze(-1)

        # 应用掩码
        if context_mask is not None:
            similarity_scores.masked_fill_(context_mask == 0, -1e9)

        attention_weights = F.softmax(similarity_scores, dim=-1)

        # 计算注意力加权的上下文表示
        attended_context = attention_weights.unsqueeze(-1) * context_encoding

        # 自注意力增强
        enhanced_context = self.self_attention(attended_context, attended_context, attended_context)

        # 合并原始上下文和增强上下文
        combined_context = torch.cat([context_encoding, enhanced_context], dim=-1)

        # 预测开始和结束位置
        start_logits = self.start_pointer(combined_context).squeeze(-1)
        end_logits = self.end_pointer(combined_context).squeeze(-1)

        # 应用掩码
        if context_mask is not None:
            start_logits.masked_fill_(context_mask == 0, -1e9)
            end_logits.masked_fill_(context_mask == 0, -1e9)

        return start_logits, end_logits, attention_weights

6. 注意力机制的优化策略

6.1 计算效率优化

6.1.1 线性注意力(Linear Attention)

class LinearAttention(nn.Module):
    """线性注意力机制"""

    def __init__(self, d_model, num_heads, feature_map='elu'):
        super(LinearAttention, self).__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.feature_map = feature_map

        # 线性变换层
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        # 初始化权重
        self._init_weights()

    def _init_weights(self):
        """初始化权重"""
        for module in [self.W_q, self.W_k, self.W_v, self.W_o]:
            nn.init.xavier_uniform_(module.weight)
            nn.init.constant_(module.bias, 0)

    def _feature_map_func(self, x):
        """特征映射函数"""
        if self.feature_map == 'elu':
            return F.elu(x) + 1
        elif self.feature_map == 'relu':
            return F.relu(x)
        else:
            return x

    def forward(self, query, key, value, mask=None):
        """
        前向传播

        Args:
            query: [batch_size, seq_len_q, d_model] 查询序列
            key: [batch_size, seq_len_k, d_model] 键序列
            value: [batch_size, seq_len_v, d_model] 值序列
            mask: [batch_size, seq_len_q, seq_len_k] 注意力掩码(本简化实现未使用)

        Returns:
            output: [batch_size, seq_len_q, d_model] 输出
        """
        batch_size, seq_len_q, d_model = query.size()
        seq_len_k = key.size(1)

        # 线性变换得到Q、K、V
        Q = self.W_q(query)
        K = self.W_k(key)
        V = self.W_v(value)

        # 重塑为多头形式
        Q = Q.view(batch_size, seq_len_q, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, seq_len_k, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len_k, self.num_heads, self.d_k).transpose(1, 2)

        # 应用特征映射
        Q = self._feature_map_func(Q)
        K = self._feature_map_func(K)

        # 线性注意力计算
        # O = Q * (K^T * V) / (Q * K^T * 1)
        KV = torch.matmul(K.transpose(-2, -1), V)  # [batch_size, num_heads, d_k, d_k]
        QKV = torch.matmul(Q, KV)                  # [batch_size, num_heads, seq_len_q, d_k]

        # 归一化项
        K_sum = torch.sum(K, dim=-2, keepdim=True)              # [batch_size, num_heads, 1, d_k]
        normalizer = torch.matmul(Q, K_sum.transpose(-2, -1))   # [batch_size, num_heads, seq_len_q, 1]

        # 避免除零
        normalizer = torch.clamp(normalizer, min=1e-6)

        # 计算输出
        attention_output = QKV / normalizer

        # 合并多头输出
        attention_output = attention_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len_q, d_model
        )

        # 输出线性变换
        output = self.W_o(attention_output)

        return output

6.1.2 局部注意力(Local Attention)

class LocalAttention(nn.Module):
    """局部注意力机制"""

    def __init__(self, d_model, num_heads, window_size=128, dropout=0.1):
        super(LocalAttention, self).__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.window_size = window_size

        # 线性变换层
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

        # 初始化权重
        self._init_weights()

    def _init_weights(self):
        """初始化权重"""
        for module in [self.W_q, self.W_k, self.W_v, self.W_o]:
            nn.init.xavier_uniform_(module.weight)
            nn.init.constant_(module.bias, 0)

    def _create_local_mask(self, seq_len, device):
        """创建局部注意力掩码"""
        mask = torch.zeros(seq_len, seq_len, device=device)

        for i in range(seq_len):
            start = max(0, i - self.window_size // 2)
            end = min(seq_len, i + self.window_size // 2 + 1)
            mask[i, start:end] = 1

        return mask

    def forward(self, query, key, value, mask=None):
        """
        前向传播

        Args:
            query: [batch_size, seq_len_q, d_model] 查询序列
            key: [batch_size, seq_len_k, d_model] 键序列
            value: [batch_size, seq_len_v, d_model] 值序列
            mask: [batch_size, seq_len_q, seq_len_k] 注意力掩码

        Returns:
            output: [batch_size, seq_len_q, d_model] 输出
        """
        batch_size, seq_len_q, d_model = query.size()
        seq_len_k = key.size(1)

        # 线性变换得到Q、K、V
        Q = self.W_q(query)
        K = self.W_k(key)
        V = self.W_v(value)

        # 重塑为多头形式
        Q = Q.view(batch_size, seq_len_q, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, seq_len_k, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len_k, self.num_heads, self.d_k).transpose(1, 2)

        # 计算注意力分数
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        # 创建局部掩码
        local_mask = self._create_local_mask(seq_len_q, query.device)
        local_mask = local_mask.unsqueeze(0).unsqueeze(0).expand(
            batch_size, self.num_heads, -1, -1
        )

        # 应用局部掩码
        scores.masked_fill_(local_mask == 0, -1e9)

        # 应用额外掩码
        if mask is not None:
            mask = mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
            scores.masked_fill_(mask == 0, -1e9)

        # 计算注意力权重
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # 计算输出
        attention_output = torch.matmul(attention_weights, V)

        # 合并多头输出
        attention_output = attention_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len_q, d_model
        )

        # 输出线性变换
        output = self.W_o(attention_output)

        return output

6.2 内存优化策略

6.2.1 梯度检查点(Gradient Checkpointing)

class CheckpointedMultiHeadAttention(nn.Module):
    """带梯度检查点的多头注意力"""

    def __init__(self, d_model, num_heads, dropout=0.1, use_checkpoint=True):
        super(CheckpointedMultiHeadAttention, self).__init__()
        self.attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.use_checkpoint = use_checkpoint

    def forward(self, query, key, value, mask=None):
        """
        前向传播

        Args:
            query: [batch_size, seq_len_q, d_model] 查询序列
            key: [batch_size, seq_len_k, d_model] 键序列
            value: [batch_size, seq_len_v, d_model] 值序列
            mask: [batch_size, seq_len_q, seq_len_k] 注意力掩码

        Returns:
            output: [batch_size, seq_len_q, d_model] 输出
        """
        if self.use_checkpoint and self.training:
            # 使用梯度检查点
            return torch.utils.checkpoint.checkpoint(
                self.attention, query, key, value, mask
            )
        else:
            return self.attention(query, key, value, mask)

7. 技术挑战与解决方案

7.1 长序列处理挑战

7.1.1 计算复杂度问题

传统的Self-Attention机制的时间复杂度为O(n²),其中n是序列长度。对于长序列,这会导致计算和内存开销急剧增加。
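
下面用一个粗略的估算示例(仅作示意)直观展示这种二次增长:在float32精度下,单个注意力头的n×n分数矩阵约占n² × 4字节。

def attention_matrix_bytes(seq_len, dtype_bytes=4):
    """估算单个注意力头的 n x n 分数矩阵占用的字节数(float32)"""
    return seq_len * seq_len * dtype_bytes

for n in [1024, 4096, 16384, 65536]:
    print(f"seq_len={n:6d}: {attention_matrix_bytes(n) / 1024**3:.3f} GB")
# 序列长度从1K增长到64K时,单个分数矩阵约从0.004 GB膨胀到16 GB(尚未计入批大小与头数)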

解决方案

  1. 稀疏注意力模式:只计算部分位置之间的注意力
  2. 线性注意力:将复杂度降低到O(n)
  3. 分层注意力:在不同层次上应用注意力机制

7.1.2 内存消耗问题

class MemoryEfficientAttention(nn.Module):
    """内存高效的注意力机制"""

    def __init__(self, d_model, num_heads, chunk_size=1024, dropout=0.1):
        super(MemoryEfficientAttention, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.chunk_size = chunk_size

        # 线性变换层
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        """
        分块计算注意力以节省内存

        Args:
            query: [batch_size, seq_len_q, d_model] 查询序列
            key: [batch_size, seq_len_k, d_model] 键序列
            value: [batch_size, seq_len_v, d_model] 值序列
            mask: [batch_size, seq_len_q, seq_len_k] 注意力掩码

        Returns:
            output: [batch_size, seq_len_q, d_model] 输出
        """
        batch_size, seq_len_q, d_model = query.size()
        seq_len_k = key.size(1)

        # 线性变换得到Q、K、V
        Q = self.W_q(query)
        K = self.W_k(key)
        V = self.W_v(value)

        # 重塑为多头形式
        Q = Q.view(batch_size, seq_len_q, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, seq_len_k, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len_k, self.num_heads, self.d_k).transpose(1, 2)

        # 分块计算注意力
        output_chunks = []

        for i in range(0, seq_len_q, self.chunk_size):
            end_i = min(i + self.chunk_size, seq_len_q)
            Q_chunk = Q[:, :, i:end_i, :]

            # 计算当前块的注意力分数
            scores = torch.matmul(Q_chunk, K.transpose(-2, -1)) / math.sqrt(self.d_k)

            # 应用掩码
            if mask is not None:
                mask_chunk = mask[:, i:end_i, :]
                mask_chunk = mask_chunk.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
                scores.masked_fill_(mask_chunk == 0, -1e9)

            # 计算注意力权重
            attention_weights = F.softmax(scores, dim=-1)
            attention_weights = self.dropout(attention_weights)

            # 计算输出
            chunk_output = torch.matmul(attention_weights, V)
            output_chunks.append(chunk_output)

        # 合并所有块
        attention_output = torch.cat(output_chunks, dim=2)

        # 合并多头输出
        attention_output = attention_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len_q, d_model
        )

        # 输出线性变换
        output = self.W_o(attention_output)

        return output

7.2 训练稳定性问题

7.2.1 梯度消失和爆炸

class StableMultiHeadAttention(nn.Module):
    """稳定的多头注意力机制"""

    def __init__(self, d_model, num_heads, dropout=0.1,
                 use_layer_norm=True, use_residual=True):
        super(StableMultiHeadAttention, self).__init__()
        self.attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.use_layer_norm = use_layer_norm
        self.use_residual = use_residual

        if use_layer_norm:
            self.layer_norm = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        """
        前向传播

        Args:
            query: [batch_size, seq_len_q, d_model] 查询序列
            key: [batch_size, seq_len_k, d_model] 键序列
            value: [batch_size, seq_len_v, d_model] 值序列
            mask: [batch_size, seq_len_q, seq_len_k] 注意力掩码

        Returns:
            output: [batch_size, seq_len_q, d_model] 输出
        """
        # 注意力计算
        attention_output = self.attention(query, key, value, mask)

        # 应用dropout
        attention_output = self.dropout(attention_output)

        # 残差连接
        if self.use_residual:
            output = query + attention_output
        else:
            output = attention_output

        # 层归一化
        if self.use_layer_norm:
            output = self.layer_norm(output)

        return output

8. 总结与展望

8.1 核心贡献

注意力机制在自然语言处理领域的发展历程体现了深度学习技术的不断演进和创新。从最初的加性注意力到现代的多头自注意力机制,这一技术的发展带来了以下核心贡献:

  1. 突破序列建模限制:注意力机制有效解决了传统RNN在处理长序列时的梯度消失问题,使得模型能够捕获长距离依赖关系。

  2. 提升并行计算效率:Self-Attention机制的并行化特性显著提高了模型训练和推理的效率,为大规模语言模型的发展奠定了基础。

  3. 增强模型可解释性:注意力权重提供了模型决策过程的可视化途径,增强了深度学习模型的可解释性。

  4. 推动架构创新:Transformer架构完全基于注意力机制,开创了新的神经网络设计范式,影响了整个深度学习领域。

8.2 技术发展趋势

8.2.1 效率优化方向

未来注意力机制的发展将更加注重计算效率和内存优化:

  • 线性注意力:继续探索将注意力复杂度从O(n²)降低到O(n)的方法
  • 稀疏注意力:设计更加智能的稀疏模式,在保持性能的同时减少计算量
  • 硬件优化:针对特定硬件架构优化注意力计算,如GPU、TPU等

8.2.2 架构创新方向

  • 混合注意力:结合不同类型的注意力机制,如局部注意力和全局注意力的混合
  • 动态注意力:根据输入内容动态调整注意力模式和参数
  • 多模态注意力:扩展到处理文本、图像、音频等多模态数据的注意力机制

8.3 应用前景

注意力机制在未来将在以下领域发挥更大作用:

  1. 大语言模型:作为GPT、BERT等大型语言模型的核心组件,注意力机制将继续推动自然语言理解和生成能力的提升。

  2. 多模态AI:在视觉-语言模型、语音识别、视频理解等多模态任务中发挥关键作用。

  3. 科学计算:在蛋白质结构预测、药物发现、气候建模等科学计算领域展现巨大潜力。

  4. 边缘计算:通过效率优化,使注意力机制能够在移动设备和边缘设备上高效运行。

8.4 未来挑战

尽管注意力机制取得了巨大成功,但仍面临以下挑战:

  1. 可解释性:虽然注意力权重提供了一定的可解释性,但对于复杂任务的决策过程仍需要更深入的理解。

  2. 鲁棒性:提高模型对对抗样本和分布偏移的鲁棒性。

  3. 公平性:确保注意力机制不会放大训练数据中的偏见和不公平性。

  4. 能耗问题:大规模注意力模型的能耗问题需要通过算法和硬件协同优化来解决。

8.5 结语

注意力机制作为现代自然语言处理的核心技术,不仅改变了我们处理序列数据的方式,更为人工智能的发展开辟了新的道路。从Self-Attention到Multi-Head Attention,从Transformer到大语言模型,注意力机制的演进历程展现了深度学习技术的强大创新能力。

随着技术的不断发展,我们有理由相信,注意力机制将在未来的人工智能系统中发挥更加重要的作用,为构建更加智能、高效、可解释的AI系统提供强有力的技术支撑。


参考文献

  1. Bahdanau, D., Cho, K., & Bengio, Y. (2014). Neural machine translation by jointly learning to align and translate. arXiv preprint arXiv:1409.0473.

  2. Luong, M. T., Pham, H., & Manning, C. D. (2015). Effective approaches to attention-based neural machine translation. arXiv preprint arXiv:1508.04025.

  3. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., … & Polosukhin, I. (2017). Attention is all you need. Advances in neural information processing systems, 30.

  4. Devlin, J., Chang, M. W., Lee, K., & Toutanova, K. (2018). Bert: Pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805.

  5. Brown, T., Mann, B., Ryder, N., Subbiah, M., Kaplan, J. D., Dhariwal, P., … & Amodei, D. (2020). Language models are few-shot learners. Advances in neural information processing systems, 33, 1877-1901.

  6. Kitaev, N., Kaiser, Ł., & Levskaya, A. (2020). Reformer: The efficient transformer. arXiv preprint arXiv:2001.04451.

  7. Wang, S., Li, B. Z., Khabsa, M., Fang, H., & Ma, H. (2020). Linformer: Self-attention with linear complexity. arXiv preprint arXiv:2006.04768.

  8. Child, R., Gray, S., Radford, A., & Sutskever, I. (2019). Generating long sequences with sparse transformers. arXiv preprint arXiv:1904.10509.


发布时间:2025年3月15日
