1.代码展示

Transformer示例代码如下。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
import os
import math
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import dataset
import time
from tempfile import TemporaryDirectory
from typing import Tuple

# 定义Transformer模型
class TransformerModel(nn.Module):
def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int,
nlayers: int, dropout: float = 0.5):
super().__init__()
self.model_type = 'Transformer'
self.pos_encoder = PositionalEncoding(d_model, dropout)
encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
self.encoder = nn.Embedding(ntoken, d_model)
self.d_model = d_model
self.decoder = nn.Linear(d_model, ntoken)

self.init_weights()

def init_weights(self) -> None:
initrange = 0.1
self.encoder.weight.data.uniform_(-initrange, initrange)
self.decoder.bias.data.zero_()
self.decoder.weight.data.uniform_(-initrange, initrange)

def forward(self, src: Tensor, src_mask: Tensor) -> Tensor:
src = self.encoder(src) * math.sqrt(self.d_model)
src = self.pos_encoder(src)
output = self.transformer_encoder(src, src_mask)
output = self.decoder(output)
return output


# 定义位置编码
class PositionalEncoding(nn.Module):
def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
super().__init__()
self.dropout = nn.Dropout(p=dropout)

position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
pe = torch.zeros(max_len, 1, d_model)
pe[:, 0, 0::2] = torch.sin(position * div_term)
pe[:, 0, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe)

def forward(self, x: Tensor) -> Tensor:
x = x + self.pe[:x.size(0)]
return self.dropout(x)

def generate_square_subsequent_mask(sz: int) -> Tensor:
"""Generates an upper-triangular matrix of -inf, with zeros on diag."""
mask = torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)
return mask

# 加载本地Wikitext-2数据集并预处理
def load_and_preprocess_data(local_data_path: str):
# 验证数据路径是否存在
if not os.path.exists(local_data_path):
raise FileNotFoundError(f"Dataset not found at {local_data_path}")

# 使用torchtext加载本地数据
def yield_tokens(file_path):
with open(file_path, encoding="utf-8") as f:
for line in f:
yield tokenizer(line.strip())

tokenizer = get_tokenizer('basic_english')
vocab = build_vocab_from_iterator(yield_tokens(os.path.join(local_data_path, 'wiki.train.tokens')), specials=['<unk>'])
vocab.set_default_index(vocab['<unk>'])

def data_process(file_path):
with open(file_path, encoding="utf-8") as f:
return torch.cat([torch.tensor(vocab(tokenizer(line.strip())), dtype=torch.long) for line in f])

train_data = data_process(os.path.join(local_data_path, 'wiki.train.tokens'))
val_data = data_process(os.path.join(local_data_path, 'wiki.valid.tokens'))
test_data = data_process(os.path.join(local_data_path, 'wiki.test.tokens'))

return train_data, val_data, test_data, vocab


# 数据批处理函数
def batchify(data: Tensor, bsz: int) -> Tensor:
seq_len = data.size(0) // bsz
data = data[:seq_len * bsz]
data = data.view(bsz, seq_len).t().contiguous()
return data.to(device)


# 数据加载与处理
local_data_path = "wikitext-2" # 替换为本地Wikitext-2数据集路径
train_data, val_data, test_data, vocab = load_and_preprocess_data(local_data_path)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 20
eval_batch_size = 10
train_data = batchify(train_data, batch_size)
val_data = batchify(val_data, eval_batch_size)
test_data = batchify(test_data, eval_batch_size)


# 模型与训练配置
ntokens = len(vocab)
emsize = 200
d_hid = 200
nlayers = 2
nhead = 2
dropout = 0.2
model = TransformerModel(ntokens, emsize, nhead, d_hid, nlayers, dropout).to(device)

criterion = nn.CrossEntropyLoss()
lr = 5.0
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

bptt = 35


# 训练和评估函数
def train(model: nn.Module):
model.train()
total_loss = 0.
log_interval = 200
start_time = time.time()
src_mask = generate_square_subsequent_mask(bptt).to(device)

num_batches = len(train_data) // bptt
for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
data, targets = get_batch(train_data, i)
seq_len = data.size(0)
if seq_len != bptt:
src_mask = src_mask[:seq_len, :seq_len]
output = model(data, src_mask)
loss = criterion(output.view(-1, ntokens), targets)

optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
optimizer.step()

total_loss += loss.item()
if batch % log_interval == 0 and batch > 0:
lr = scheduler.get_last_lr()[0]
ms_per_batch = (time.time() - start_time) * 1000 / log_interval
cur_loss = total_loss / log_interval
ppl = math.exp(cur_loss)
print(f'| epoch {epoch:3d} | {batch:5d}/{num_batches:5d} batches | '
f'lr {lr:02.2f} | ms/batch {ms_per_batch:5.2f} | '
f'loss {cur_loss:5.2f} | ppl {ppl:8.2f}')
total_loss = 0
start_time = time.time()


def evaluate(model: nn.Module, eval_data: Tensor) -> float:
model.eval()
total_loss = 0.
src_mask = generate_square_subsequent_mask(bptt).to(device)
with torch.no_grad():
for i in range(0, eval_data.size(0) - 1, bptt):
data, targets = get_batch(eval_data, i)
seq_len = data.size(0)
if seq_len != bptt:
src_mask = src_mask[:seq_len, :seq_len]
output = model(data, src_mask)
output_flat = output.view(-1, ntokens)
total_loss += seq_len * criterion(output_flat, targets).item()
return total_loss / (len(eval_data) - 1)


def get_batch(source: Tensor, i: int) -> Tuple[Tensor, Tensor]:
seq_len = min(bptt, len(source) - 1 - i)
data = source[i:i+seq_len]
target = source[i+1:i+1+seq_len].reshape(-1)
return data, target


# 训练循环
best_val_loss = float('inf')
epochs = 3

with TemporaryDirectory() as tempdir:
best_model_params_path = os.path.join(tempdir, "best_model_params.pt")

for epoch in range(1, epochs + 1):
epoch_start_time = time.time()
train(model)
val_loss = evaluate(model, val_data)
val_ppl = math.exp(val_loss)
elapsed = time.time() - epoch_start_time
print('-' * 89)
print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | '
f'valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}')
print('-' * 89)

if val_loss < best_val_loss:
best_val_loss = val_loss
torch.save(model.state_dict(), best_model_params_path)

scheduler.step()

model.load_state_dict(torch.load(best_model_params_path))

# 测试模型
test_loss = evaluate(model, test_data)
test_ppl = math.exp(test_loss)
print('=' * 89)
print(f'| End of training | test loss {test_loss:5.2f} | test ppl {test_ppl:8.2f}')
print('=' * 89)

2.数据集下载

数据集下载地址如下:wikitext-2 CSDN链接

下载后,数据集与py文件的部署方式如下:

  • project/
    • wikitext-2
      • wiki.test.token
      • wiki.train.token
      • wiki.valid.token
    • transformer_tutorial.py

3.代码解读

3.1 主要组件的实现

3.1.1 Transformer 模型 (TransformerModel)

TransformerModel类是整个模型的核心,它是基于 PyTorch 提供的 nn.TransformerEncoder 实现的。
输入:输入的文本会被嵌入为词向量(通过 nn.Embedding),加上位置编码(PositionalEncoding)。
主要模块:
TransformerEncoder:堆叠多个 Transformer 编码器层(TransformerEncoderLayer)。
PositionalEncoding:为序列中每个位置添加位置信息,以保留单词之间的相对顺序。
Linear 层(解码器):将编码器的输出映射到词表大小的向量,用于预测下一个词。
初始化权重:使用 init_weights 方法对嵌入层和解码层的权重进行初始化。

3.1.2 位置编码 (PositionalEncoding)

位置编码为输入的嵌入向量添加位置信息。
使用正弦和余弦函数计算不同维度的编码,确保编码可以保留序列的相对位置关系。

3.1.3 掩码 (generate_square_subsequent_mask)

生成一个上三角矩阵,用于屏蔽序列中未来时间步的输入,确保 Transformer 模型仅能看到之前的时间步。

3.2 数据加载与预处理

3.2.1 本地数据加载 (load_and_preprocess_data)

从本地目录中加载 Wikitext-2 数据集,包括 wiki.train.tokens, wiki.valid.tokens, 和 wiki.test.tokens。
预处理:
使用 basic_english 分词器对文本进行分词。
使用 TorchText 的 build_vocab_from_iterator 创建词表,并将所有单词映射为索引。
将文本转化为 PyTorch 张量。

3.2.2 批处理函数 (batchify)

将整个数据集分割为多个固定长度的批次。
数据被转置为形状 [seq_len, batch_size],便于 Transformer 模型的处理。

3.3 训练和评估函数

3.3.1 训练函数 (train)

逐批获取输入数据,使用 generate_square_subsequent_mask 生成掩码。
对每个批次的输入,计算模型输出和损失:
损失函数:交叉熵损失 (nn.CrossEntropyLoss)。
使用梯度裁剪(clip_grad_norm_)防止梯度爆炸。
打印训练过程中的损失、困惑度(Perplexity,语言建模中的重要指标)。

3.3.2 评估函数 (evaluate)

在验证或测试集上评估模型性能。
模型设置为评估模式(model.eval()),以防止 dropout 等正则化方法的干扰。

3.3.3 数据生成 (get_batch)

生成输入和目标对:
输入:当前时间步的序列。
目标:下一时间步对应的序列。

3.4训练过程

3.4.1 模型超参数

词表大小 (ntokens):等于数据集中所有词的数量。
嵌入维度 (emsize):词向量的维度。
隐藏层维度 (d_hid):Transformer 编码器的前馈网络隐藏层大小。
头数 (nhead):多头注意力机制的头数。
层数 (nlayers):编码器的层数。
dropout:用于防止过拟合的丢弃率。

3.4.2 优化器与学习率调度器

使用 SGD 优化器和学习率调度器 (StepLR),每次学习率会按设定的衰减率减少。

3.4.3 训练循环

逐个 epoch 地训练模型:
记录训练损失和验证损失。
在验证集上性能更优时,保存模型的参数。

3.5 测试模型

加载最优模型参数并在测试集上评估。
打印测试集的损失和困惑度(Perplexity)。

3.6 代码工作流程总结

加载数据:从本地加载 Wikitext-2 数据集并进行分词和索引映射。
创建模型:定义并初始化 Transformer 模型。
训练模型:通过批量数据训练模型,记录损失和困惑度。
验证模型:在验证集上选择最佳模型。
测试模型:用测试集评估最终模型性能。

4. 运行结果

| epoch 1 | 200/ 2928 batches | lr 5.00 | ms/batch 131.84 | loss 8.14 | ppl 3423.94
| epoch 1 | 400/ 2928 batches | lr 5.00 | ms/batch 140.47 | loss 6.88 | ppl 970.93
| epoch 1 | 600/ 2928 batches | lr 5.00 | ms/batch 139.09 | loss 6.43 | ppl 623.05
| epoch 1 | 800/ 2928 batches | lr 5.00 | ms/batch 135.78 | loss 6.30 | ppl 545.26
| epoch 1 | 1000/ 2928 batches | lr 5.00 | ms/batch 138.33 | loss 6.19 | ppl 486.50
| epoch 1 | 1200/ 2928 batches | lr 5.00 | ms/batch 140.05 | loss 6.15 | ppl 467.70
| epoch 1 | 1400/ 2928 batches | lr 5.00 | ms/batch 142.60 | loss 6.11 | ppl 451.95
| epoch 1 | 1600/ 2928 batches | lr 5.00 | ms/batch 140.72 | loss 6.10 | ppl 448.07
| epoch 1 | 1800/ 2928 batches | lr 5.00 | ms/batch 143.08 | loss 6.02 | ppl 410.79
| epoch 1 | 2000/ 2928 batches | lr 5.00 | ms/batch 146.73 | loss 6.02 | ppl 410.46
| epoch 1 | 2200/ 2928 batches | lr 5.00 | ms/batch 149.91 | loss 5.90 | ppl 364.34
| epoch 1 | 2400/ 2928 batches | lr 5.00 | ms/batch 152.67 | loss 5.97 | ppl 391.99
| epoch 1 | 2600/ 2928 batches | lr 5.00 | ms/batch 153.82 | loss 5.95 | ppl 383.01
| epoch 1 | 2800/ 2928 batches | lr 5.00 | ms/batch 156.59 | loss 5.88 | ppl 358.70

| end of epoch 1 | time: 441.33s | valid loss 5.81 | valid ppl 333.35

| epoch 2 | 200/ 2928 batches | lr 4.75 | ms/batch 158.28 | loss 5.86 | ppl 351.74
| epoch 2 | 400/ 2928 batches | lr 4.75 | ms/batch 157.82 | loss 5.84 | ppl 345.05
| epoch 2 | 600/ 2928 batches | lr 4.75 | ms/batch 158.74 | loss 5.67 | ppl 288.98
| epoch 2 | 800/ 2928 batches | lr 4.75 | ms/batch 160.91 | loss 5.71 | ppl 301.57
| epoch 2 | 1000/ 2928 batches | lr 4.75 | ms/batch 160.65 | loss 5.66 | ppl 286.73
| epoch 2 | 1200/ 2928 batches | lr 4.75 | ms/batch 161.34 | loss 5.68 | ppl 294.02
| epoch 2 | 1400/ 2928 batches | lr 4.75 | ms/batch 161.62 | loss 5.69 | ppl 296.31
| epoch 2 | 1600/ 2928 batches | lr 4.75 | ms/batch 162.15 | loss 5.71 | ppl 302.16
| epoch 2 | 1800/ 2928 batches | lr 4.75 | ms/batch 161.82 | loss 5.65 | ppl 285.07
| epoch 2 | 2000/ 2928 batches | lr 4.75 | ms/batch 161.89 | loss 5.68 | ppl 291.94
| epoch 2 | 2200/ 2928 batches | lr 4.75 | ms/batch 161.86 | loss 5.56 | ppl 258.91
| epoch 2 | 2400/ 2928 batches | lr 4.75 | ms/batch 162.87 | loss 5.65 | ppl 282.90
| epoch 2 | 2600/ 2928 batches | lr 4.75 | ms/batch 161.08 | loss 5.65 | ppl 284.46
| epoch 2 | 2800/ 2928 batches | lr 4.75 | ms/batch 161.55 | loss 5.58 | ppl 264.79

| end of epoch 2 | time: 489.75s | valid loss 5.63 | valid ppl 279.19

| epoch 3 | 200/ 2928 batches | lr 4.51 | ms/batch 161.83 | loss 5.60 | ppl 270.67
| epoch 3 | 400/ 2928 batches | lr 4.51 | ms/batch 161.85 | loss 5.63 | ppl 278.09
| epoch 3 | 600/ 2928 batches | lr 4.51 | ms/batch 161.18 | loss 5.44 | ppl 229.32
| epoch 3 | 800/ 2928 batches | lr 4.51 | ms/batch 160.05 | loss 5.49 | ppl 241.14
| epoch 3 | 1000/ 2928 batches | lr 4.51 | ms/batch 159.72 | loss 5.44 | ppl 231.02
| epoch 3 | 1200/ 2928 batches | lr 4.51 | ms/batch 160.88 | loss 5.48 | ppl 240.77
| epoch 3 | 1400/ 2928 batches | lr 4.51 | ms/batch 159.98 | loss 5.50 | ppl 243.48
| epoch 3 | 1600/ 2928 batches | lr 4.51 | ms/batch 159.63 | loss 5.53 | ppl 252.02
| epoch 3 | 1800/ 2928 batches | lr 4.51 | ms/batch 160.42 | loss 5.47 | ppl 236.82
| epoch 3 | 2000/ 2928 batches | lr 4.51 | ms/batch 158.95 | loss 5.48 | ppl 240.99
| epoch 3 | 2200/ 2928 batches | lr 4.51 | ms/batch 156.87 | loss 5.37 | ppl 214.50
| epoch 3 | 2400/ 2928 batches | lr 4.51 | ms/batch 157.96 | loss 5.47 | ppl 236.30
| epoch 3 | 2600/ 2928 batches | lr 4.51 | ms/batch 157.73 | loss 5.47 | ppl 238.29
| epoch 3 | 2800/ 2928 batches | lr 4.51 | ms/batch 374.93 | loss 5.41 | ppl 223.21

| end of epoch 3 | time: 623.88s | valid loss 5.60 | valid ppl 270.49

=========================================================================================
| End of training | test loss 5.51 | test ppl 246.59