1 DCFormer

1.1 config.json

This file is the configuration file for the DCFormer model; it defines the model's hyperparameters and settings.

- Field descriptions
**architectures:** Name of the model architecture, here "DCFormer".
**auto_map:** Maps the auto classes to their implementations, specifying which configuration and model classes AutoConfig and AutoModelForCausalLM resolve to.
AutoConfig: points to DCFormerConfig, the model's configuration class.
**AutoModelForCausalLM:** points to the DCFormer model class, the class actually used for inference.
**block_size:** Maximum input sequence length (2048), which determines the context window the model handles during training and inference.
**bos_token_id and eos_token_id:** The special begin-of-sequence and end-of-sequence tokens, mapped to vocabulary IDs 1 and 2 respectively.
**dim:** Hidden dimension of the model, set to 2560.
**head_dim:** Dimension of each attention head, set to 80.
**intermediate_size:** Dimension of the feed-forward network's intermediate layer, set to 6912.
**n_head:** Number of attention heads, set to 32.
**n_layer:** Number of Transformer layers, set to 32.
**norm_eps:** Epsilon used in normalization, set to 1e-6.
**q_chunk_size:** Chunk size for queries; it controls memory usage when training large models.
**use_dcmha:** Whether to enable Dynamically Composable Multi-Head Attention (DCMHA). Set to true, meaning DCFormer's improved attention mechanism is used.
**use_qk_norm:** Whether to normalize queries and keys (QK Norm). Set to true, which helps avoid instability during training.
**vocab_size:** Vocabulary size, set to 50257, matching the standard size of large models such as the GPT series.
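
As a quick illustration, the configuration can be loaded and inspected as sketched below. This is a minimal sketch: it assumes the Hugging Face repository id Caiyun-AI/DCFormer-2.8B and network access; the printed values are the ones listed above.

```python
from transformers import AutoConfig

# trust_remote_code is required because DCFormerConfig ships with the repository
config = AutoConfig.from_pretrained("Caiyun-AI/DCFormer-2.8B", trust_remote_code=True)

print(config.dim, config.n_head, config.head_dim)   # 2560 32 80
print(config.n_layer, config.intermediate_size)     # 32 6912
print(config.block_size, config.vocab_size)         # 2048 50257
```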

1.2 configuration_dcformer.py

The DCFormerConfig class inherits from PretrainedConfig and is used to initialize and manage the configuration of the DCFormer model.

-DCFormerConfig initializes the model's core hyperparameters; most of them correspond to fields in config.json.
-Some values are derived automatically in the constructor; for example, intermediate_size defaults to None and is then computed from the model dimension (a sketch of this logic follows below).
-Because the class inherits from PretrainedConfig, the model integrates seamlessly with the Hugging Face Transformers framework.
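
The shape of that derivation is sketched below. This is a simplified sketch, not the verbatim constructor; the exact defaults live in configuration_dcformer.py, and the 4 * dim fallback mirrors the post-processing shown later for DCPythiaConfig.

```python
# Sketch of the derived-field logic inside a config __init__ (simplified).
def derive_fields(dim: int, n_head: int, n_local_heads: int = -1,
                  intermediate_size=None):
    if n_local_heads == -1:              # default: one KV head per query head
        n_local_heads = n_head
    if intermediate_size is None:        # default FFN width when not given explicitly
        intermediate_size = 4 * dim
    head_dim = dim // n_head
    return n_local_heads, intermediate_size, head_dim

print(derive_fields(dim=2560, n_head=32))   # (32, 10240, 80)
```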

1.3 generation_demo.py

This file shows how to run inference and text generation with the DCFormer model.

-Loads the pretrained DCFormer-2.8B model and its tokenizer.
-Runs inference on CUDA for computational efficiency.
-Generates 100 tokens and decodes the generated IDs back into text with the tokenizer.
-The example is short and direct, showing how to load the model and generate text quickly through the Hugging Face AutoModelForCausalLM and AutoTokenizer interfaces (a sketch follows below).
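
A minimal sketch of such a script is shown below. It mirrors the DCPythia demo analyzed in section 2 and assumes the repository id Caiyun-AI/DCFormer-2.8B and a CUDA-capable GPU; the exact flags in the shipped generation_demo.py may differ.

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("Caiyun-AI/DCFormer-2.8B")
model = AutoModelForCausalLM.from_pretrained("Caiyun-AI/DCFormer-2.8B", trust_remote_code=True)

device = torch.device("cuda")
model = model.to(device=device, dtype=torch.float16)

# the custom model exposes setup_caches/generate, as shown in modeling_dcformer.py below
with torch.device(device):
    model.setup_caches(max_batch_size=1, max_seq_length=2048, set_kv_cache=True)

input_ids = tokenizer.encode("Beijing is the capital of China. London is the capital of",
                             return_tensors="pt")
with torch.no_grad():
    generated_ids = model.generate(input_ids.to(device), num_tokens_to_generate=100)
print(tokenizer.decode(generated_ids[0]))
```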

1.4 maxtext2torch.py

The purpose of maxtext2torch.py is to convert model weights stored in the MaxText format into a format PyTorch can load, typically to load a pretrained checkpoint or migrate a model. MaxText is a JAX-based training framework, and its checkpoints use a storage format of their own.
Once converted to PyTorch format, the weights can be loaded directly with the transformers library for inference or fine-tuning.

-Load the MaxText weights: the MaxText checkpoint is stored in its own format, so it first has to be parsed according to the layout defined by the model authors (typically binary or otherwise custom checkpoint files).
-Load the PyTorch model: during conversion, AutoModelForCausalLM instantiates a PyTorch model, assuming a matching DCFormer model definition is available.
-Load the weights: the tensors read from the MaxText checkpoint are loaded into the PyTorch model via model.load_state_dict().
-Save the converted model: the model with the MaxText weights loaded is saved as a .pt file for later use.
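
The overall flow can be sketched as follows. This is only a schematic outline: read_maxtext_weights and map_name are hypothetical placeholders (the real parsing and parameter-name mapping depend on the MaxText checkpoint layout), and the actual script's argument names may differ.

```python
import torch
from transformers import AutoModelForCausalLM

def read_maxtext_weights(ckpt_dir: str) -> dict:
    """Placeholder: parse the MaxText checkpoint into a {name: array} dict."""
    raise NotImplementedError("depends on the MaxText checkpoint layout")

def map_name(maxtext_name: str) -> str:
    """Placeholder: translate a MaxText parameter name to its PyTorch counterpart."""
    raise NotImplementedError

def convert(maxtext_ckpt_dir: str, hf_model_id: str, out_path: str) -> None:
    maxtext_state = read_maxtext_weights(maxtext_ckpt_dir)                  # 1) parse checkpoint
    model = AutoModelForCausalLM.from_pretrained(hf_model_id,               # 2) build torch model
                                                 trust_remote_code=True)
    torch_state = {map_name(k): torch.as_tensor(v) for k, v in maxtext_state.items()}
    model.load_state_dict(torch_state, strict=False)                        # 3) load weights
    torch.save(model.state_dict(), out_path)                                # 4) save as .pt
```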

1.5 modeling_dcformer.py

This file implements the DCFormer model: Transformer self-attention extended with dynamic weight composition, windowed attention, and grouped attention, aimed at improving training and inference performance.

Below is a walkthrough of the core code:

1.5.1 KVKWCache

class KVKWCache(nn.Module):
def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, window_size=2048, dtype=torch.float16, use_kw_cache=True):
super().__init__()
self.head_dim = head_dim
self.kw_dim = 2 * n_heads
self.n_heads = n_heads
self.window_size = window_size
self.use_kw_cache = use_kw_cache
if window_size is None:
self.seq_length = max_seq_length
else:
self.seq_length = min(window_size, max_seq_length)
cache_shape = (max_batch_size, n_heads, self.seq_length, head_dim)
kw_cache_shape = (max_batch_size, self.seq_length, 2, n_heads, n_heads)
self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))
if self.use_kw_cache:
self.register_buffer('kw_cache', torch.zeros(kw_cache_shape, dtype=dtype))

def update(self, input_pos, k_val, v_val, kw_val=None): # kw_val B,N,S,2,N B2NSD
# input_pos: [S], k_val: [B, H, S, D]
assert input_pos.shape[-1] == k_val.shape[2]
B,N,S,D = v_val.shape
k_out = self.k_cache
v_out = self.v_cache
if self.use_kw_cache:
kw_out = self.kw_cache
else:
kw_out = None

if self.window_size is None:
k_out[:, :, input_pos] = k_val
v_out[:, :, input_pos] = v_val
if self.use_kw_cache and kw_val is not None:
kw_out[:,input_pos] = kw_val
elif S == 1:
input_pos = input_pos % self.seq_length
v_out[:, :, input_pos] = v_val
k_out[:, :, input_pos] = k_val
if self.use_kw_cache and kw_val is not None:
kw_out[:,input_pos] = kw_val
else: # prefill
start = max(0, input_pos[-1]-self.seq_length+1)
input_pos = input_pos[start:] % self.seq_length
v_out[:, :, input_pos] = v_val[:,:,start:]
k_out[:, :, input_pos] = k_val[:,:,start:]
if self.use_kw_cache and kw_val is not None:
kw_out[:, input_pos] = kw_val[:,start:]
return k_out, v_out, kw_out

This is the cache module, used to store and update the attention layers' key-value (KV) pairs efficiently.

Purpose:
-Supports sequence caching and a sliding-window mechanism during inference.
-Provides an optional weight cache (kw_cache) for the dynamically weighted attention mechanism.
Main method:
-update: updates the cache, supporting both single-token (incremental decoding) and batched (prefill) updates.
Key logic:
-The window limit (window_size) lets the model attend only to local context.
-Supports dynamically generated weights (kw_cache). A usage sketch follows below.
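
A minimal usage sketch, assuming the KVKWCache class above is in scope; the shapes follow the comments in update and the dimensions in config.json (32 heads, head_dim 80).

```python
import torch

B, N, D = 1, 32, 80
cache = KVKWCache(max_batch_size=B, max_seq_length=2048, n_heads=N, head_dim=D,
                  window_size=None, use_kw_cache=True)

# prefill: write positions 0..S-1 in one shot
S = 16
pos = torch.arange(S)
k = torch.randn(B, N, S, D, dtype=torch.float16)
v = torch.randn(B, N, S, D, dtype=torch.float16)
kw = torch.randn(B, S, 2, N, N, dtype=torch.float16)
k_out, v_out, kw_out = cache.update(pos, k, v, kw_val=kw)

# one-token decode step: write a single new position
pos = torch.tensor([S])
k1 = torch.randn(B, N, 1, D, dtype=torch.float16)
v1 = torch.randn(B, N, 1, D, dtype=torch.float16)
kw1 = torch.randn(B, 1, 2, N, N, dtype=torch.float16)
k_out, v_out, kw_out = cache.update(pos, k1, v1, kw_val=kw1)
print(k_out.shape, kw_out.shape)  # torch.Size([1, 32, 2048, 80]) torch.Size([1, 2048, 2, 32, 32])
```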

1.5.2 DCFormer

class DCFormer(PreTrainedModel):
config_class=DCFormerConfig
'''
DCFormer's implementation is adapted from https://github.com/pytorch-labs/gpt-fast/blob/main/model.py#L89
'''

def __init__(self, config: DCFormerConfig) -> None:
super().__init__(config)
self.config = config

self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
self.layers = nn.ModuleList(DCFormerBlock(config, lidx) for lidx in range(config.n_layer))
self.norm = RMSNorm(config.dim, eps=config.norm_eps)
self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
self.use_gradient_checkpointing = config.use_gradient_checkpointing
self.is_training = config.is_training

self.freqs_cis: Optional[Tensor] = None
self.mask_cache: Optional[Tensor] = None
self.window_size = config.window_size
self.max_batch_size = -1
self.max_seq_length = -1

def setup_caches(self, max_batch_size, max_seq_length, set_kv_cache=True):
if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
return
head_dim = self.config.dim // self.config.n_head
max_seq_length = find_multiple(max_seq_length, 8)
self.max_seq_length = max_seq_length
self.max_batch_size = max_batch_size
if not self.is_training:
for b in self.layers:
if set_kv_cache:
use_kw_cache = False if b.attention.query_wise else True
b.attention.kv_cache = KVKWCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, window_size=b.attention.window_size, use_kw_cache=use_kw_cache)
b.attention.dyn_w_proj.merge_weights()
if not b.attention.use_sw:
dtype = b.attention.wo.weight.dtype
device = b.attention.wo.weight.device
b.attention.dyn_w_proj.sw = b.attention.dyn_w_proj.sw.to(device=device, dtype=dtype)
b.attention.dyn_w_proj.pre_proj.w = b.attention.dyn_w_proj.pre_proj.w.to(device=device, dtype=dtype)
b.attention.dyn_w_proj.post_proj.w = b.attention.dyn_w_proj.post_proj.w.to(device=device, dtype=dtype)

self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base).to(self.tok_embeddings.weight.device)
if self.is_training:
self.causal_mask = torch.tril(torch.ones(self.config.block_size, self.config.block_size, dtype=torch.bool, device=self.tok_embeddings.weight.device))
elif self.window_size is None:
self.causal_mask = torch.tril(torch.ones(max_seq_length, max_seq_length, dtype=torch.bool, device=self.tok_embeddings.weight.device))
else:
self.causal_mask = torch.stack([make_window_mask(max_seq_length, self.config.window_size), torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))]) # LG

def generate(self, input_ids, num_tokens_to_generate=10, compiled_decode_one_token=None):
batch_size, seq_length = input_ids.shape
input_pos = torch.arange(seq_length, device=self.device)
generated_ids = torch.zeros(batch_size, seq_length + num_tokens_to_generate, dtype=torch.int, device=self.device)
generated_ids[:, :seq_length] = input_ids.to(self.device).to(torch.int)
logits = self.forward(input_ids, input_pos=input_pos,return_tensor=True)
_next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
next_token = torch.zeros(self.max_batch_size, 1, device=self.device, dtype=torch.int)
next_token[:batch_size] = _next_token
generated_ids[:, seq_length] = next_token[:batch_size, 0]
input_pos = torch.tensor([seq_length], device=self.device)
for _ in range(1, num_tokens_to_generate):
if compiled_decode_one_token is not None:
next_token = compiled_decode_one_token(self, next_token.clone(), input_pos)
else:
next_token = self.decode_one_token(next_token.clone(), input_pos)
generated_ids[:, input_pos+1] = next_token.int()[:batch_size]
input_pos += 1
return generated_ids

def decode_one_token(self, cur_token, input_pos):
logits = self.forward(
cur_token,
input_pos=input_pos,
return_tensor=True
)
new_token = torch.argmax(logits[:, -1], dim=-1)[:,None]
return new_token

def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None, return_tensor=False) -> Tensor:
assert self.freqs_cis is not None, "Caches must be initialized first"
if input_pos is None:
input_pos = torch.arange(idx.shape[-1], device=idx.device, dtype=torch.int)
if self.window_size is None or self.is_training:
mask = self.causal_mask[None, None, input_pos]
else:
mask = self.causal_mask[None, None,:,input_pos]
freqs_cis = self.freqs_cis[input_pos][:idx.shape[-1]]
x = self.tok_embeddings(idx)
for i, layer in enumerate(self.layers):
if self.is_training or self.window_size is None :
layer_mask = mask
gen_mask = None
elif self.window_size is not None:
layer_mask = mask[:,:,1] if layer.attention.window_size is None else mask[:,:,0]
gen_mask = mask[:,:,1] if layer.attention.window_size is not None else None
if self.use_gradient_checkpointing:
x = checkpoint(layer, x, input_pos, freqs_cis, layer_mask)
else:
x = layer(x, input_pos, freqs_cis, layer_mask, gen_mask=gen_mask)
x = self.norm(x)
logits = self.output(x)
if return_tensor:
return logits
else:
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
return CausalLMOutput(logits=logits)

This is the model's core class. It inherits from PreTrainedModel, so it can be loaded and saved through the Hugging Face APIs.

Structure:
-Token embedding layer (tok_embeddings).
-A stack of Transformer blocks (DCFormerBlock).
-RMSNorm normalization and the output projection.

Method walkthrough:
-setup_caches:
Initializes the KV caches.
Precomputes the rotary-position frequencies (freqs_cis) used for rotary position encoding.
-generate:
Generates text in a loop, optionally using a compiled fast decode path (compiled_decode_one_token).
-forward:
Passes the input through each Transformer block in turn.
Supports gradient checkpointing (to save memory).
Switches the attention mask according to the window type and training mode.

Highlights:
Uses RMSNorm instead of LayerNorm.
Optimizes inference via the KV cache and the window mechanism.

1.5.3 DCFormerBlock

class DCFormerBlock(nn.Module):
def __init__(self, config: DCFormerConfig, lidx) -> None:
super().__init__()
self.lidx = lidx
self.attention = DCMHAttention(config, lidx)
self.feed_forward = FeedForward(config)
self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
self.attention_norm = RMSNorm(config.dim, config.norm_eps)

def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor, gen_mask=None) -> Tensor:
h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos, gen_mask=gen_mask, fast_infer=True)
out = h + self.feed_forward(self.ffn_norm(h))
return out

A single Transformer block, consisting of:
-Dynamically weighted multi-head attention (DCMHAttention).
-A feed-forward network (FeedForward).
-Two normalization layers (RMSNorm).

Implementation:
Both the attention and the feed-forward sub-layers are added back through residual connections (h = x + …); a generic sketch of this pre-norm residual pattern follows below.
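
A stripped-down, generic sketch of the pre-norm residual pattern (with stand-in sub-modules, not the DCFormer ones):

```python
import torch
import torch.nn as nn

class PreNormBlock(nn.Module):
    """Generic pre-norm Transformer block: h = x + Mixer(Norm(x)); out = h + FFN(Norm(h))."""
    def __init__(self, dim: int, mixer: nn.Module, ffn: nn.Module):
        super().__init__()
        self.norm1, self.norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)
        self.mixer, self.ffn = mixer, ffn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = x + self.mixer(self.norm1(x))    # attention sub-layer + residual
        return h + self.ffn(self.norm2(h))   # feed-forward sub-layer + residual

dim = 16
block = PreNormBlock(dim, nn.Linear(dim, dim), nn.Linear(dim, dim))
print(block(torch.randn(2, 4, dim)).shape)   # torch.Size([2, 4, 16])
```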

1.5.4 DCMHAttention

class DCMHAttention(nn.Module):
def __init__(self, config: DCFormerConfig, lidx, use_sw=False):
super().__init__()
assert config.dim % config.n_head == 0
total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
# key, query, value projections for all heads, but in a batch
self.lidx = lidx
self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
self.wo = nn.Linear(config.dim, config.dim, bias=False)
self.kv_cache = None

self.n_head = config.n_head
self.head_dim = config.head_dim
self.n_local_heads = config.n_local_heads
self.is_training = config.is_training
self.dim = config.dim
self.use_dcmha = config.use_dcmha
self.scale_factor = 1 / math.sqrt(self.head_dim)
self.q_chunk_size = config.q_chunk_size
self.use_sw = use_sw
self.dyn_w_proj = DynamicWeightProjection(num_heads=self.n_head, query_input_dim=config.dim, dynamic_squeeze_ratio=self.n_head//2, dynamic_w_hidden_dim=self.n_head*4, use_sw=use_sw)
self.use_qk_norm = config.use_qk_norm
if self.use_qk_norm:
self.q_norm = RMSnorm(hid_dim=self.head_dim)
self.k_norm = RMSnorm(hid_dim=self.head_dim)

self.window_types = {
"LG":[256, None],
"LGLL":[256, None, 256, 256],
"LGL6":[256, None, 256, 256, 256, 256, 256, 256],
}

self.query_wise = config.query_wise
if config.window_type is None: # LG
self.window_size = None if self.lidx % 2 == 1 else config.window_size
else:
window_l = self.window_types[config.window_type]
self.window_size = window_l[self.lidx % len(window_l)]

if not self.is_training:
self._register_load_state_dict_pre_hook(self.load_hook)

def load_hook(self, state_dict, prefix, *args):
if prefix + "wq.weight" in state_dict:
wq = state_dict.pop(prefix + "wq.weight")
wk = state_dict.pop(prefix + "wk.weight")
wv = state_dict.pop(prefix + "wv.weight")
state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

def _generate_fast(self, x, input_pos, q, k, v, k_mask):
B,T,D = x.shape
N,I = self.n_head, self.dyn_w_proj.dynamic_hidden_dim # 32, 2
dw_hidden, dd = (x @ self.dyn_w_proj.dw_m).split([2*2*N*(2*I), 2*2*N*1], -1) # BTD, D(4K+4N) -> BT(4K+4N) -> BT(4K), BT(4N)
dw_hidden = dw_hidden.view((B,T,4,-1,1)) # BT(4K) -> BT4K1
dw = (self.dyn_w_proj.dw_hidden_activation(dw_hidden) * self.dyn_w_proj.qkw_m).sum(-2) # gelu, BT4K1, 4K(IM)->BT4K(IM)->BT4(IM)
w1, w2 = dw.view((B,T,2,2,-1,N)).split(I,-2) # BT4(IM)->BT{pre/post}{q/k}IM->[BT22IM] * 2
w1 = self.dyn_w_proj.dw1_norm(w1) # BT22IN
qkdd = self.dyn_w_proj.dw_activation(dd.view((B,T,2,2,N))) # BT2{2}N1->BT2{2}N tanh
qkw = torch.einsum('BTKJIN,BTKJIM->BTKJNM', w1, w2) + torch.diag_embed(qkdd) # j=k=2, BT2{2}NM q/k, pre/post
if self.query_wise: # TODO: do not generate kw and kdd
qw, _ = qkw.unbind(3) # BS2NM
kw_new = None
qw = qw + self.dyn_w_proj.sw
else:
qw, kw_new = qkw.unbind(3) # BS{pre/post}{q/k}NM -> BS{pre/post}NM * 2
kw_new = kw_new + self.dyn_w_proj.sw # BS2NM + 2NM-> BS2NM
if self.kv_cache is not None:
k, v, kw_out = self.kv_cache.update(input_pos, k, v, kw_val=kw_new) #BNT2M
logits = q @ k.transpose(-2, -1) * self.scale_factor
if self.query_wise:
w = qw # B12NM
else:
w = qw + kw_out # B12NM,BS2NM -> BS2NM
wl, w = w.permute(0,2,3,4,1).unbind(1) # BS2NM->B2NMS->[BNMS]*2
logits = (logits * wl).sum(1).unsqueeze(2) # BN1S, BNMS -> BNMS-> BMS-> BM1S
min_value = torch.finfo(torch.float16).min
logits = torch.where(k_mask, logits, min_value)
probs = logits.softmax(-1)
probs = (probs * w).sum(1).unsqueeze(2)
y = probs @ v
return y

def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None, fast_infer=True, gen_mask=None) -> Tensor:
bsz, seqlen, _ = x.shape

kv_size = self.n_local_heads * self.head_dim
q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

q = q.view(bsz, seqlen, self.n_head, self.head_dim) # BSND
k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)

if self.use_qk_norm:
q, k = self.q_norm(q), self.k_norm(k)

q = apply_rotary_emb(q, freqs_cis)
k = apply_rotary_emb(k, freqs_cis)

q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) # BNSD

if self.is_training:
N, D, I = self.n_head, self.head_dim, self.dyn_w_proj.dynamic_hidden_dim; # 6.7B
B,T,E = x.shape
if self.use_dcmha:
project_logits = True
project_probs = True
if project_probs:
dw_hidden, dd = (x @ self.dyn_w_proj.dw_m).split([2*2*N*(2*I), 2*2*N*1], -1)
dw_hidden = self.dyn_w_proj.dw_hidden_activation(dw_hidden)
dw_hidden = dw_hidden.view(dw_hidden.shape[:2]+(4,-1)) #B T (4 K) -> B T 4 K # reshape
dw = torch.einsum('B T C K, C K D -> B T C D', dw_hidden, self.dyn_w_proj.qkw_m) # BT4K,4K(MI)->BT4(MI)
shape = (B,T,2*2,-1,N)# if project_logits else (B,T,2,N,-1) # BT(pre/post)(q/k)IN
w1, w2 = dw.view(shape).split(I,-2)
w1 = self.dyn_w_proj.dw1_norm(w1) # BT22IN
if self.use_sw:
pre_sw, post_sw = self.dyn_w_proj.sw.unbind(0)
else:
pre_sw, post_sw = None, None
pre_qw1, pre_kw1, post_qw1, post_kw1 = w1.unbind(2) # BT(2{*2})IN->[BTIN]*4
pre_qw2, pre_kw2, post_qw2, post_kw2 = w2.unbind(2)
qkdd = F.tanh(dd).squeeze(-1).view(shape[:-2] + (N,)) # BT(2{*2})N1->BT(2{*2})N
pre_qdd, pre_kdd, post_qdd, post_kdd = qkdd.unbind(2) # BT(2{*2})N->[BTN]*4

y = torch.zeros(B, N, T, D).to(q.device, dtype=torch.float16)
for i in range(T // self.q_chunk_size):
start, stop = i * self.q_chunk_size, (i + 1) * self.q_chunk_size
kv_start = max(0, stop - self.q_chunk_size -self.window_size)
_q = q[:, :, start : stop, :]
_k, _v = k[:, :, kv_start : stop, :], v[:, :, kv_start : stop, :]
_atten_mask = mask[:, :, start : stop, kv_start : stop]
_pre_proj_dw_args = slice_dw(pre_sw, pre_qw1, pre_qw2, pre_kw1, pre_kw2, pre_qdd, pre_kdd, start, stop, kv_start) \
if project_logits else None
_post_proj_dw_args = slice_dw(post_sw, post_qw1, post_qw2, post_kw1, post_kw2, post_qdd, post_kdd, start,stop,kv_start) \
if project_probs else None
_o = _atten_context(_q, _k, _v, _atten_mask, _pre_proj_dw_args, _post_proj_dw_args)
y[:,:,start:stop] = _o
else:
y = torch.zeros(B, N, T, D).to(q.device, dtype=torch.float16)
for i in range(T // self.q_chunk_size):
start, stop = i * self.q_chunk_size, (i + 1) * self.q_chunk_size
kv_start = max(0, stop - self.q_chunk_size -self.window_size)
_q = q[:, :, start : stop, :]
_k, _v = k[:, :, kv_start : stop, :], v[:, :, kv_start : stop, :]
_atten_mask = mask[:, :, start : stop, kv_start : stop]
_pre_proj_dw_args, _post_proj_dw_args = None, None
_o = _atten_context(_q, _k, _v, _atten_mask, _pre_proj_dw_args, _post_proj_dw_args)
y[:,:,start:stop] = _o
else: # inference
if seqlen == 1: # one-token generation
k_mask = mask if self.window_size is None else gen_mask[:, :, :,:self.kv_cache.seq_length]
if fast_infer:
y = self._generate_fast(x, input_pos, q, k, v, k_mask)
else:
assert not self.query_wise
# generate dw from hidden_state
pre_proj_dw_args, post_proj_dw_args, kw_new = self.dyn_w_proj(x, gen_cache=True)

# update kvkw cache
kw_new = kw_new + self.dyn_w_proj.sw # absorb residual or sw into kw cache
if self.kv_cache is not None:
k, v, kw_out = self.kv_cache.update(input_pos, k, v, kw_val=kw_new) # BNSD, BNSD, BS2NN

logits = q @ k.transpose(-2, -1) * self.scale_factor
# merge pre_w and apply it
pre_qw1, pre_qw2, pre_kw1, pre_kw2, pre_qdd, pre_kdd = pre_proj_dw_args
pre_qw = torch.einsum('BTGIN, BTGIM->BTNM',pre_qw1, pre_qw2) + torch.diag_embed(pre_qdd.squeeze(2))
pre_w = pre_qw + kw_out[:,:,0] # B1NM, BSNM -> BSNM
logits = self.dyn_w_proj.pre_proj(logits, proj_w=pre_w.squeeze(1))

logits = torch.where(k_mask, logits, torch.finfo(torch.float16).min)
probs = logits.softmax(-1)

# merge post_w and apply it
post_qw1, post_qw2, post_kw1, post_kw2, post_qdd, post_kdd = post_proj_dw_args
post_qw = torch.einsum('BTGIN, BTGIM->BTNM', post_qw1, post_qw2) + torch.diag_embed(post_qdd.squeeze(2))
post_w = post_qw + kw_out[:,:,1]
probs = self.dyn_w_proj.post_proj(probs, proj_w=post_w.squeeze(1))

y = probs @ v
else: # prefill
k_mask = mask[:,:,:,:k.shape[-2]]
pre_proj_dw_args, post_proj_dw_args,kw_new = self.dyn_w_proj(x, gen_cache=True)
kw_new = kw_new + self.dyn_w_proj.sw # absorb residual or sw into kw cache
if self.kv_cache is not None:
self.kv_cache.update(input_pos, k, v, kw_val=kw_new) # BNSD, BNSD, BS2NN
logits = q @ k.transpose(-2, -1) * self.scale_factor
logits = self.dyn_w_proj.pre_proj(logits, dws=pre_proj_dw_args, query_vec=x, key_vec=x, fast_infer=True) # XD BN1S
logits = torch.where(k_mask, logits, torch.finfo(torch.float16).min)
probs = logits.softmax(-1)
probs = self.dyn_w_proj.post_proj(probs, dws=post_proj_dw_args, query_vec=x, key_vec=x, fast_infer=True) # BN1S
y = probs @ v

y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
y = self.wo(y)
return y

This is the dynamically weighted multi-head attention module, the model's core modification.

Functionality:
Computes queries, keys, and values through the fused wqkv projection.
Optionally normalizes queries and keys.
Restricts the attention range according to the window type and position mask.
Adjusts the attention weights dynamically through DynamicWeightProjection.

Inference optimizations:
_generate_fast provides a fast single-step decoding path.
Queries can be processed in chunks (q_chunk_size) to reduce memory overhead.

Key logic of the dynamic weight projection:
The composition weights are generated by DynamicWeightProjection.
In fast_infer mode, cached weights are reused to speed up computation; a toy sketch of the cross-head composition follows below.
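
To make the "compose across heads" idea concrete, here is a toy sketch of what the pre/post projection does to an attention-score tensor: the N per-head score maps are mixed by a static N×N matrix plus a token-dependent low-rank term and a per-head diagonal gate. This is a simplified, query-wise-only illustration of the logic in _cross_head_proj, with random stand-in tensors, not the exact implementation.

```python
import torch

B, N, T, S, I = 1, 4, 3, 3, 1                # batch, heads, query len, key len, low rank
scores = torch.randn(B, N, T, S)             # per-head attention logits (or probabilities)
sw = torch.eye(N)                            # static cross-head mixing matrix (identity = no mixing)
w1 = 0.1 * torch.randn(B, T, I, N)           # dynamic low-rank factors, predicted per query token
w2 = 0.1 * torch.randn(B, T, I, N)
dd = 0.1 * torch.tanh(torch.randn(B, T, N))  # dynamic per-head diagonal gate

out = torch.einsum('BNTS,NM->BMTS', scores, sw)          # static mixing across heads
hidden = torch.einsum('BNTS,BTIN->BITS', scores, w1)     # collapse N heads into I channels
out = out + torch.einsum('BITS,BTIM->BMTS', hidden, w2)  # expand back to N heads
out = out + scores * dd.transpose(1, 2).unsqueeze(-1)    # per-head diagonal term
print(out.shape)                                         # torch.Size([1, 4, 3, 3])
```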

1.5.5 DynamicWeightProjection

class DynamicWeightProjection(nn.Module):

def __init__(self, num_heads=32, num_groups=1, residual=True, query_input_dim=4096, dynamic_squeeze_ratio=16, dynamic_w_hidden_dim=128,dtype=torch.float16,use_sw=False):
super().__init__()
self.num_heads = num_heads
self.num_groups = num_groups
self.query_input_dim = query_input_dim
self.dynamic_squeeze_ratio = dynamic_squeeze_ratio
self.dynamic_w_hidden_dim = dynamic_w_hidden_dim
self.dw_hidden_activation = nn.GELU()
self.num_heads_per_group = self.num_heads // self.num_groups
self.dw_activation = nn.Tanh()
self.dw1_norm = RMSnormNoscale(dim=-1)
self.use_sw = use_sw
self.pre_proj = CrossHeadProjection('pre', num_heads=self.num_heads, use_sw=use_sw)
self.post_proj = CrossHeadProjection('post', num_heads=self.num_heads, use_sw=use_sw)

dynamic_hidden_dim = self.num_heads_per_group // self.dynamic_squeeze_ratio
self.dynamic_hidden_dim = dynamic_hidden_dim
self.dw1 = nn.parameter.Parameter(torch.zeros(self.query_input_dim, self.num_groups, 4, self.dynamic_w_hidden_dim, dtype=dtype)) #(4096, 1, 4, 128)
G, K, M = self.num_groups, self.dynamic_w_hidden_dim, self.num_heads_per_group
I = dynamic_hidden_dim * 2
self.qkw = nn.parameter.Parameter(torch.zeros([G, 4, K, I, M], dtype=dtype)) # (1, 4, 128, 4, 32)
self.dd = nn.parameter.Parameter(torch.zeros(self.query_input_dim, self.num_groups, self.num_heads_per_group * 4, dtype=dtype)) # (4096, 1, 128)

self.merge_weights()

def merge_weights(self):
self.dw_m = nn.parameter.Parameter(torch.cat([self.dw1.reshape(self.query_input_dim, -1), self.dd.squeeze(1)], dim=-1)).to(self.dw1.device) # E,(4*K + K) K=2*N*I
self.qkw_m = nn.parameter.Parameter(self.qkw.permute(0,1,2,3,4).reshape(4,self.dynamic_w_hidden_dim,-1)).to(self.dw1.device) #(4,K,I*M)
if self.use_sw:
self.sw = nn.parameter.Parameter(torch.stack([self.pre_proj.w, self.post_proj.w]).squeeze(1) + torch.eye(self.num_heads) ).to(self.dw1.device) # (2,N,N) sw + identity matrix
else:
self.sw = (torch.eye(self.num_heads).expand(2,self.num_heads,self.num_heads)).to(self.dw1.device) # identity matrix (2,N,N)

def forward(self,query_vec,KW:Optional[torch.Tensor]=None, gen_cache:Optional[bool]=True):
dw_hidden = torch.einsum('BTD,DGCK->BTGCK', query_vec, self.dw1) # C=4 [pre,post]*[query,key]
dw_hidden = self.dw_hidden_activation(dw_hidden) #BTGCK
w1, w2 = torch.split(torch.einsum('BTGCK,GCKIM->BTGCIM', dw_hidden, self.qkw), self.qkw.shape[-2]//2, dim=-2) #BTGC(2I)M -> [BTGCIM] * 2
w1 = self.dw1_norm(w1) # BTGCIM
pre_qw1, pre_kw1, post_qw1, post_kw1 = unbind(w1, 4, dim=3) # BTG4IM->[BTGIM]*4
pre_qw2, pre_kw2, post_qw2, post_kw2 = unbind(w2, 4, dim=3)
dd = torch.einsum('BTD,DGM->BTGM', query_vec, self.dd) # BTG(4M)
dd = self.dw_activation(dd)
pre_qdd, pre_kdd, post_qdd, post_kdd = torch.split(dd, dd.shape[-1] // 4, dim=-1) # BTG(4N)->[BTGN]*4
pre_dw_args = (pre_qw1, pre_qw2, pre_kw1, pre_kw2, pre_qdd, pre_kdd)
post_dw_args = (post_qw1, post_qw2, post_kw1, post_kw2, post_qdd, post_kdd)
if gen_cache: # generate KW cache
pre_kw = torch.einsum('BSGIM, BSGIN->BSMN', pre_kw1, pre_kw2) + torch.diag_embed(pre_kdd.squeeze(2)) # merge kw and kdd
post_kw = torch.einsum('BSGIM, BSGIN->BSMN', post_kw1, post_kw2) + torch.diag_embed(post_kdd.squeeze(2))
KW = torch.stack((pre_kw, post_kw), dim=-3) # BSMN,BSMN->BS2MN
return pre_dw_args, post_dw_args, KW

Implements the dynamic weight projection, generating adaptive composition weights for queries and keys.

Features:
Generates the projection weights dynamically (dw_hidden, qkw).
Supports grouped weights (num_groups).
Provides both pre-softmax (pre_proj) and post-softmax (post_proj) projections.

Key methods:
merge_weights: fuses the parameter tensors into merged matrices used to speed up inference.
forward: generates the weights dynamically from the input and optionally builds the KW cache; a sketch of how the low-rank factors and the diagonal term combine follows below.
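
As a minimal sketch of how the pieces fit together: per token, the two low-rank factors and the diagonal term combine into an N×N head-composition matrix, the same einsum + diag_embed pattern used in forward above. The tensors here are random stand-ins.

```python
import torch

B, T, I, N = 1, 2, 2, 8                  # batch, tokens, low rank (dynamic_hidden_dim), heads
w1 = torch.randn(B, T, I, N)             # would be RMS-normalized (dw1_norm) in the real module
w2 = torch.randn(B, T, I, N)
dd = torch.tanh(torch.randn(B, T, N))    # diagonal term after the tanh activation

# per-token N x N composition matrix: rank-I outer product plus a diagonal
W = torch.einsum('BTIN,BTIM->BTNM', w1, w2) + torch.diag_embed(dd)
print(W.shape)                           # torch.Size([1, 2, 8, 8])
```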

1.5.6 FeedForward

class FeedForward(nn.Module):
def __init__(self, config: DCFormerConfig) -> None:
super().__init__()
self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)

def forward(self, x: Tensor) -> Tensor:
return self.w2(F.silu(self.w1(x)) * self.w3(x))

This is the feed-forward module, a standard Transformer component in its gated form:
-Three linear projections (w1, w3, w2).
-SiLU activation used as a gate: w2(SiLU(w1(x)) * w3(x)), i.e. a SwiGLU-style FFN.

1.5.7 RMSNorm

class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))

def _norm(self, x):
return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

def forward(self, x: Tensor) -> Tensor:
output = self._norm(x.float()).type_as(x)
return output * self.weight

RMSNorm is a normalization technique used in place of LayerNorm; it omits mean-centering and offers good training stability.
Formula:
$$
\text{output} = \frac{x}{\sqrt{\text{mean}(x^2) + \epsilon}} \cdot \text{scale}
$$

1.5.8 Helper functions

precompute_freqs_cis
-Precomputes the rotary position embedding (RoPE) frequencies.
-Stores the frequencies in polar form (cos/sin pairs), so the rotary encoding can be applied quickly.

make_window_mask
Builds the window mask that restricts attention to a local context range (see the sketch after this list).

apply_rotary_emb
Applies the rotary position embedding, improving the Transformer's ability to model long sequences.

slice_dw
Slices the dynamic weights so they can be applied chunk by chunk during windowed attention.
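
A small self-contained sketch of the window mask, equivalent to make_window_mask shown in the helper code of section 2 (reimplemented here only so the snippet runs standalone):

```python
import torch

def window_causal_mask(t: int, window_size: int) -> torch.Tensor:
    # same logic as make_window_mask: causal AND no further back than `window_size` positions
    col = torch.arange(t).unsqueeze(0).expand(t, t)
    row = torch.arange(t).unsqueeze(1).expand(t, t)
    return (col + window_size >= row) & (col <= row)

print(window_causal_mask(6, 2).int())
# each query row attends to itself and at most `window_size` earlier positions:
# tensor([[1, 0, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0],
#         [0, 1, 1, 1, 0, 0],
#         [0, 0, 1, 1, 1, 0],
#         [0, 0, 0, 1, 1, 1]])
```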

2 DCPythia

2.1 config.json

This file defines the model's configuration parameters, including the architecture, token IDs, and model dimensions. It supplies the hyperparameters and settings used when the model is loaded.

Main configuration fields:
architectures: the model architecture type, here "DCPythia", i.e. the model is based on the DCPythia architecture.
auto_map: maps the auto classes to their implementations; AutoConfig points to configuration_dcpythia.DCPythiaConfig and AutoModelForCausalLM points to modeling_dcpythia.DCPythia.
block_size: the block size, which typically bounds the maximum sequence length the model can process.
vocab_size: the vocabulary size, which determines the model's input and output dimensions.
dim: the model's hidden dimension, i.e. the size of each token's representation vector.
n_layer: the number of stacked Transformer layers.
n_head: the number of heads in each self-attention layer.
intermediate_size: the width of the feed-forward layer inside each Transformer block.
bos_token_id, eos_token_id: IDs of the sequence start and end tokens.
torch_dtype: the PyTorch data type; float16 here, i.e. half precision, which saves memory and speeds up computation.
use_dcmha: whether to use DCMHA (Dynamically Composable Multi-Head Attention).
use_parallel_residual: whether to enable the parallel residual connection.
use_linear_bias: whether the linear layers use a bias.
use_qk_norm: whether to normalize queries (Q) and keys (K) in the self-attention computation.
transformers_version: the version of the Transformers library used.

2.2 configuration_dcpythia.py

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from typing import Optional, Tuple, List

class DCPythiaConfig(PretrainedConfig):
model_type = "dcpythia"
'''
DCPythiaConfig is a config class for DCPythia, which is adapted from
https://github.com/pytorch-labs/gpt-fast/blob/main/model.py#L21
'''

DCPythiaConfig inherits from PretrainedConfig, so it can be used as a configuration class within the Hugging Face transformers library; configuration classes manage a model's parameters and settings.
model_type is set to "dcpythia", indicating that this configuration is specific to the DCPythia model.
The docstring notes that the class is adapted from the gpt-fast project linked above.

Constructor (__init__):
This is the initializer of DCPythiaConfig and defines the hyperparameters used by the DCPythia model.
Common hyperparameters include:
block_size: maximum sequence length, which usually bounds the number of input tokens.
vocab_size: vocabulary size.
n_layer: number of Transformer layers.
n_head: number of heads in the multi-head attention mechanism.
dim: hidden dimension of the model.
intermediate_size: intermediate-layer dimension, usually a multiple of dim.
head_dim: per-head dimension, usually dim / n_head.
use_gradient_checkpointing: whether to enable gradient checkpointing to save memory.
use_dcmha: whether to use DCMHA (the dynamically composable multi-head attention described in section 1).
use_qk_norm: whether to normalize queries and keys (Q and K).
window_size and window_type: window size and window schedule, used for the local (sliding-window) attention layers.
rotary_pct: fraction of each head dimension that receives rotary position encoding.

This is followed by post-processing of the arguments:

if self.n_local_heads == -1:
self.n_local_heads = self.n_head
if self.intermediate_size is None:
self.intermediate_size = 4 * self.dim
self.head_dim = self.dim // self.n_head

If n_local_heads is not specified (i.e. it is -1), it is set to n_head.
If intermediate_size is not specified, it defaults to 4 * dim.
head_dim is computed as dim // n_head.

The parent constructor is then called:

super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)

This calls the constructor of the parent class PretrainedConfig, passing the special token IDs and the remaining keyword arguments.

2.3 generation_demo.py

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

Imports torch and transformers, which are used to load the model and run inference.
Sets TOKENIZERS_PARALLELISM to 'false' to avoid issues with multi-threaded tokenization.

Loading the model and tokenizer:

tokenizer = AutoTokenizer.from_pretrained("Caiyun-AI/DCPythia-6.9B")
model = AutoModelForCausalLM.from_pretrained("Caiyun-AI/DCPythia-6.9B", trust_remote_code=True)

AutoTokenizer and AutoModelForCausalLM load the tokenizer and model from the "Caiyun-AI/DCPythia-6.9B" repository.
trust_remote_code=True allows the custom modeling code shipped with the repository to be downloaded and executed.

Selecting the device:

device = torch.device('cuda')

Runs the model and the inference loop on the GPU.

Configuring the model:

MAX_BATCH_SIZE = 1
MAX_SEQ_LENGTH = 2048
NUM_TOKENS_TO_GENERATE = 100
COMPILE = True
_ = model.to(device=device, dtype=torch.float16)

Sets the batch size, maximum sequence length, and number of tokens to generate.
COMPILE = True toggles compilation (i.e. whether torch.compile is used).
Moves the model to the GPU and casts it to float16 to reduce memory usage.

Setting up the caches:

with torch.device(device):
model.setup_caches(max_batch_size=MAX_BATCH_SIZE, max_seq_length=MAX_SEQ_LENGTH, set_kv_cache=True)

Configures the model's caches on the chosen device, specifying the batch size and maximum sequence length and enabling the key-value cache.

Defining the decode_one_token function:

def decode_one_token(model, cur_token, input_pos):
logits = model(cur_token, input_pos=input_pos, return_tensor=True)
new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
return new_token

decode_one_token produces the next token from the current token and position: the logits at the last position are taken and the highest-probability token is selected (greedy decoding).

Generating text:

prompt = "Beijing is the capital of China. London is the capital of"
input_ids = tokenizer.encode(prompt, return_tensors='pt')
compiled_decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True) if COMPILE else None

Encodes the given prompt into token IDs.
torch.compile optimizes decode_one_token to reduce per-step overhead (here with mode="reduce-overhead" and fullgraph=True).

Running generation:

with torch.no_grad():
generated_ids = model.generate(input_ids.to(device), num_tokens_to_generate=NUM_TOKENS_TO_GENERATE, compiled_decode_one_token=compiled_decode_one_token)
text = tokenizer.decode(generated_ids[0])
print('generated text:', text)

model.generate produces the generated token IDs.
tokenizer.decode converts the generated tokens back to text, which is then printed.

2.4 modeling_dcpythia.py

class KVKWCache(nn.Module):
def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, window_size=2048, dtype=torch.float16, use_kw_cache=True):
super().__init__()
self.head_dim = head_dim
self.kw_dim = 2 * n_heads
self.n_heads = n_heads
self.window_size = window_size
self.use_kw_cache = use_kw_cache
if window_size is None:
self.seq_length = max_seq_length
else:
self.seq_length = min(window_size, max_seq_length)
cache_shape = (max_batch_size, n_heads, self.seq_length, head_dim)
kw_cache_shape = (max_batch_size, self.seq_length, 2, n_heads, n_heads)
self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))
if self.use_kw_cache:
self.register_buffer('kw_cache', torch.zeros(kw_cache_shape, dtype=dtype))

def update(self, input_pos, k_val, v_val, kw_val=None): # kw_val B,N,S,2,N B2NSD
# input_pos: [S], k_val: [B, H, S, D]
assert input_pos.shape[-1] == k_val.shape[2]
B,N,S,D = v_val.shape
k_out = self.k_cache
v_out = self.v_cache
if self.use_kw_cache:
kw_out = self.kw_cache
else:
kw_out = None

if self.window_size is None:
k_out[:, :, input_pos] = k_val
v_out[:, :, input_pos] = v_val
if self.use_kw_cache and kw_val is not None:
kw_out[:,input_pos] = kw_val
elif S == 1:
input_pos = input_pos % self.seq_length
v_out[:, :, input_pos] = v_val
k_out[:, :, input_pos] = k_val
if self.use_kw_cache and kw_val is not None:
kw_out[:,input_pos] = kw_val
else: # prefill
start = max(0, input_pos[-1]-self.seq_length+1)
input_pos = input_pos[start:] % self.seq_length
v_out[:, :, input_pos] = v_val[:,:,start:]
k_out[:, :, input_pos] = k_val[:,:,start:]
if self.use_kw_cache and kw_val is not None:
kw_out[:, input_pos] = kw_val[:,start:]
return k_out, v_out, kw_out

Manages the model's key-value cache, adapting to different sequence lengths, window sizes, and head configurations.
Supports the regular key/value caches (k_cache and v_cache) plus the additional kw_cache for dynamic composition weights.
The update method refreshes the cache and is central to inference; it handles window truncation and prefill.

class DCPythia(PreTrainedModel):
config_class=DCPythiaConfig

def __init__(self, config: DCPythiaConfig) -> None:
super().__init__(config)
self.config = config

self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
self.layers = nn.ModuleList(DCPythiaBlock(config, lidx) for lidx in range(config.n_layer))
self.norm = nn.LayerNorm(config.dim, eps=config.norm_eps)
self.output = nn.Linear(config.dim, config.vocab_size, bias=False) # no bias in pythia
self.use_gradient_checkpointing = config.use_gradient_checkpointing
self.is_training = config.is_training

self.freqs_cis: Optional[Tensor] = None
self.rotary_ndims = int(config.head_dim * config.rotary_pct)
self.mask_cache: Optional[Tensor] = None
self.window_size = config.window_size
self.max_batch_size = -1
self.max_seq_length = -1

def setup_caches(self, max_batch_size, max_seq_length, set_kv_cache=True):
if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
return
head_dim = self.config.dim // self.config.n_head
max_seq_length = find_multiple(max_seq_length, 8)
self.max_seq_length = max_seq_length
self.max_batch_size = max_batch_size
if not self.is_training:
for b in self.layers:
if set_kv_cache:
use_kw_cache = False if b.attention.query_wise else True
b.attention.kv_cache = KVKWCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, window_size=b.attention.window_size, use_kw_cache=use_kw_cache)
b.attention.dyn_w_proj.merge_weights()
if not b.attention.use_sw:
dtype = b.attention.wo.weight.dtype
device = b.attention.wo.weight.device
b.attention.dyn_w_proj.sw = b.attention.dyn_w_proj.sw.to(device=device, dtype=dtype)
b.attention.dyn_w_proj.pre_proj.w = b.attention.dyn_w_proj.pre_proj.w.to(device=device, dtype=dtype)
b.attention.dyn_w_proj.post_proj.w = b.attention.dyn_w_proj.post_proj.w.to(device=device, dtype=dtype)

self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.rotary_ndims, self.config.rope_base).to(self.tok_embeddings.weight.device)
if self.is_training:
self.causal_mask = torch.tril(torch.ones(self.config.block_size, self.config.block_size, dtype=torch.bool, device=self.tok_embeddings.weight.device))
elif self.window_size is None:
self.causal_mask = torch.tril(torch.ones(max_seq_length, max_seq_length, dtype=torch.bool, device=self.tok_embeddings.weight.device))
else:
self.causal_mask = torch.stack([make_window_mask(max_seq_length, self.config.window_size), torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))]) # LG

def generate(self, input_ids, num_tokens_to_generate=10, compiled_decode_one_token=None):
batch_size, seq_length = input_ids.shape
input_pos = torch.arange(seq_length, device=self.device)
generated_ids = torch.zeros(batch_size, seq_length + num_tokens_to_generate, dtype=torch.int, device=self.device)
generated_ids[:, :seq_length] = input_ids.to(self.device).to(torch.int)
logits = self.forward(input_ids, input_pos=input_pos,return_tensor=True)
_next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
next_token = torch.zeros(self.max_batch_size, 1, device=self.device, dtype=torch.int)
next_token[:batch_size] = _next_token
generated_ids[:, seq_length] = next_token[:batch_size, 0]
input_pos = torch.tensor([seq_length], device=self.device)
for _ in range(1, num_tokens_to_generate):
if compiled_decode_one_token is not None:
next_token = compiled_decode_one_token(self, next_token.clone(), input_pos)
else:
next_token = self.decode_one_token(next_token.clone(), input_pos)
generated_ids[:, input_pos+1] = next_token.int()[:batch_size]
input_pos += 1
return generated_ids

def decode_one_token(self, cur_token, input_pos):
logits = self.forward(
cur_token,
input_pos=input_pos,
return_tensor=True,
)
new_token = torch.argmax(logits[:, -1], dim=-1)[:,None]
return new_token

def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None, return_tensor=False) -> Tensor:
assert self.freqs_cis is not None, "Caches must be initialized first"
if input_pos is None:
input_pos = torch.arange(idx.shape[-1], device=idx.device, dtype=torch.int)
if self.window_size is None or self.is_training:
mask = self.causal_mask[None, None, input_pos]
else:
mask = self.causal_mask[None, None,:,input_pos]
freqs_cis = self.freqs_cis[input_pos][:idx.shape[-1]]
x = self.tok_embeddings(idx)
for i, layer in enumerate(self.layers):
if self.is_training or self.window_size is None :
layer_mask = mask
gen_mask = None
elif self.window_size is not None:
layer_mask = mask[:,:,1] if layer.attention.window_size is None else mask[:,:,0]
gen_mask = mask[:,:,1] if layer.attention.window_size is not None else None
if self.use_gradient_checkpointing:
x = checkpoint(layer, x, input_pos, freqs_cis, layer_mask)
else:
x = layer(x, input_pos, freqs_cis, layer_mask, gen_mask=gen_mask)
x = self.norm(x)
logits = self.output(x)
if return_tensor:
return logits
else:
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
return CausalLMOutput(logits=logits)

The model backbone, based on Hugging Face's PreTrainedModel.
It contains the following submodules:
-tok_embeddings: the token embedding layer.
-A stack of DCPythiaBlock layers.
-LayerNorm and the output linear layer.
-Supports both training and inference modes, with flexible cache management.
The setup_caches method initializes the model's caches for efficient inference.
The generate method performs step-by-step generation, advancing the input positions as it goes.

class DCPythiaBlock(nn.Module):
def __init__(self, config: DCPythiaConfig, lidx) -> None:
super().__init__()
self.lidx = lidx
self.attention = DCMHAttention(config, lidx)
self.feed_forward = FeedForward(config)
self.ffn_norm = nn.LayerNorm(config.dim, eps=config.norm_eps)
self.attention_norm = nn.LayerNorm(config.dim, eps=config.norm_eps)
self.use_parallel_residual = config.use_parallel_residual

def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor, gen_mask=None) -> Tensor:
h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos, fast_infer=True, gen_mask=gen_mask)
if self.use_parallel_residual:
out = h + self.feed_forward(self.ffn_norm(x))
else:
out = h + self.feed_forward(self.ffn_norm(h))
return out

Represents one layer of the model, combining the attention mechanism and the feed-forward network.
Optionally uses a parallel residual layout (use_parallel_residual), in which the attention and feed-forward branches both read the block input (a sketch contrasting the two layouts follows below).
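
A minimal sketch of the two residual layouts, with identity stand-ins for the real attention/FFN sub-modules, matching the branch in forward above:

```python
import torch

def sequential_block(x, attn, ffn, ln1, ln2):
    h = x + attn(ln1(x))
    return h + ffn(ln2(h))                 # FFN reads the attention output

def parallel_block(x, attn, ffn, ln1, ln2):
    h = x + attn(ln1(x))
    return h + ffn(ln2(x))                 # FFN reads the block input x (GPT-NeoX / Pythia style)

ident = lambda t: t                        # stand-ins for the real sub-modules
x = torch.randn(2, 4, 16)
print(torch.allclose(sequential_block(x, ident, ident, ident, ident),
                     parallel_block(x, ident, ident, ident, ident)))   # False: the layouts differ
```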

class DynamicWeightProjection(nn.Module):

def __init__(self, num_heads=32, num_groups=1, residual=True, query_input_dim=4096, dynamic_squeeze_ratio=16, dynamic_w_hidden_dim=128,dtype=torch.float16,use_sw=False):
super().__init__()
self.num_heads = num_heads
self.num_groups = num_groups
self.query_input_dim = query_input_dim
self.dynamic_squeeze_ratio = dynamic_squeeze_ratio
self.dynamic_w_hidden_dim = dynamic_w_hidden_dim
self.dw_hidden_activation = nn.GELU()
self.num_heads_per_group = self.num_heads // self.num_groups
self.dw_activation = nn.Tanh()
self.dw1_norm = RMSnormNoscale(dim=-1)
self.use_sw = use_sw
self.pre_proj = CrossHeadProjection('pre', num_heads=self.num_heads, use_sw=use_sw)
self.post_proj = CrossHeadProjection('post', num_heads=self.num_heads, use_sw=use_sw)

dynamic_hidden_dim = self.num_heads_per_group // self.dynamic_squeeze_ratio
self.dynamic_hidden_dim = dynamic_hidden_dim
self.dw1 = nn.parameter.Parameter(torch.zeros(self.query_input_dim, self.num_groups, 4, self.dynamic_w_hidden_dim, dtype=dtype)) #(4096, 1, 4, 128)
G, K, M = self.num_groups, self.dynamic_w_hidden_dim, self.num_heads_per_group
I = dynamic_hidden_dim * 2
self.qkw = nn.parameter.Parameter(torch.zeros([G, 4, K, I, M], dtype=dtype)) # (1, 4, 128, 4, 32)
self.dd = nn.parameter.Parameter(torch.zeros(self.query_input_dim, self.num_groups, self.num_heads_per_group * 4, dtype=dtype)) # (4096, 1, 128)

self.merge_weights()

def merge_weights(self):
self.dw_m = nn.parameter.Parameter(torch.cat([self.dw1.reshape(self.query_input_dim, -1), self.dd.squeeze(1)], dim=-1)).to(self.dw1.device) # E,(4*K + K) K=2*N*I
self.qkw_m = nn.parameter.Parameter(self.qkw.permute(0,1,2,3,4).reshape(4,self.dynamic_w_hidden_dim,-1)).to(self.dw1.device) #(4,K,I*M)
if self.use_sw:
self.sw = nn.parameter.Parameter(torch.stack([self.pre_proj.w, self.post_proj.w]).squeeze(1) + torch.eye(self.num_heads) ).to(self.dw1.device) # (2,N,N) sw + identity matrix
else:
self.sw = (torch.eye(self.num_heads).expand(2,self.num_heads,self.num_heads)).to(self.dw1.device) # identity matrix (2,N,N)

def forward(self,query_vec,KW:Optional[torch.Tensor]=None, gen_cache:Optional[bool]=True):
dw_hidden = torch.einsum('BTD,DGCK->BTGCK', query_vec, self.dw1) # C=4 [pre,post]*[query,key]
dw_hidden = self.dw_hidden_activation(dw_hidden) #BTGCK
w1, w2 = torch.split(torch.einsum('BTGCK,GCKIM->BTGCIM', dw_hidden, self.qkw), self.qkw.shape[-2]//2, dim=-2) #BTGC(2I)M -> [BTGCIM] * 2
w1 = self.dw1_norm(w1) # BTGCIM
pre_qw1, pre_kw1, post_qw1, post_kw1 = unbind(w1, 4, dim=3) # BTG4IM->[BTGIM]*4
pre_qw2, pre_kw2, post_qw2, post_kw2 = unbind(w2, 4, dim=3)
dd = torch.einsum('BTD,DGM->BTGM', query_vec, self.dd) # BTG(4M)
dd = self.dw_activation(dd)
pre_qdd, pre_kdd, post_qdd, post_kdd = torch.split(dd, dd.shape[-1] // 4, dim=-1) # BTG(4N)->[BTGN]*4
pre_dw_args = (pre_qw1, pre_qw2, pre_kw1, pre_kw2, pre_qdd, pre_kdd)
post_dw_args = (post_qw1, post_qw2, post_kw1, post_kw2, post_qdd, post_kdd)
if gen_cache: # generate KW cache
pre_kw = torch.einsum('BSGIM, BSGIN->BSMN', pre_kw1, pre_kw2) + torch.diag_embed(pre_kdd.squeeze(2)) # merge kw and kdd
post_kw = torch.einsum('BSGIM, BSGIN->BSMN', post_kw1, post_kw2) + torch.diag_embed(post_kdd.squeeze(2))
KW = torch.stack((pre_kw, post_kw), dim=-3) # BSMN,BSMN->BS2MN
return pre_dw_args, post_dw_args, KW

A dynamic weight projection module that generates adaptive weight matrices.
The attention composition weights are adjusted dynamically from the input sequence's features.
Provides efficient weight generation and a caching mechanism.

class DCMHAttention(nn.Module):
def __init__(self, config: DCPythiaConfig, lidx, use_sw=False):
super().__init__()
assert config.dim % config.n_head == 0
total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
# key, query, value projections for all heads, but in a batch
self.lidx = lidx
self.wqkv = nn.Linear(config.dim, total_head_dim, bias=config.use_linear_bias)
self.wo = nn.Linear(config.dim, config.dim, bias=config.use_linear_bias)
self.kv_cache = None

self.n_head = config.n_head
self.head_dim = config.head_dim
self.n_local_heads = config.n_local_heads
self.is_training = config.is_training
self.dim = config.dim
self.use_dcmha = config.use_dcmha
self.scale_factor = 1 / math.sqrt(self.head_dim)
self.q_chunk_size = config.q_chunk_size
self.use_sw = use_sw
self.dyn_w_proj = DynamicWeightProjection(num_heads=self.n_head, query_input_dim=config.dim, dynamic_squeeze_ratio=self.n_head//2, dynamic_w_hidden_dim=self.n_head*4, use_sw=use_sw)
self.use_qk_norm = config.use_qk_norm
if self.use_qk_norm:
self.q_norm = RMSnorm(hid_dim=self.head_dim)
self.k_norm = RMSnorm(hid_dim=self.head_dim)

self.window_types = {
"LG":[256, None],
"LGLL":[256, None, 256, 256],
"LGL6":[256, None, 256, 256, 256, 256, 256, 256],
}

self.query_wise = config.query_wise
if config.window_type is None: # LG
self.window_size = None if self.lidx % 2 == 1 else config.window_size
else:
window_l = self.window_types[config.window_type]
self.window_size = window_l[self.lidx % len(window_l)]

self.rotary_ndims = int(self.head_dim * config.rotary_pct)

if not self.is_training:
self._register_load_state_dict_pre_hook(self.load_hook)

def load_hook(self, state_dict, prefix, *args):
if prefix + "wq.weight" in state_dict:
wq = state_dict.pop(prefix + "wq.weight")
wk = state_dict.pop(prefix + "wk.weight")
wv = state_dict.pop(prefix + "wv.weight")
state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
if prefix + "wq.bias" in state_dict:
wq_b = state_dict.pop(prefix + "wq.bias")
wk_b = state_dict.pop(prefix + "wk.bias")
wv_b = state_dict.pop(prefix + "wv.bias")
state_dict[prefix + "wqkv.bias"] = torch.cat([wq_b, wk_b, wv_b])

def _generate_fast(self, x, input_pos, q, k, v, k_mask):
B,T,D = x.shape
N,I = self.n_head, self.dyn_w_proj.dynamic_hidden_dim # 32, 2
dw_hidden, dd = (x @ self.dyn_w_proj.dw_m).split([2*2*N*(2*I), 2*2*N*1], -1) # BTD, D(4K+4N) -> BT(4K+4N) -> BT(4K), BT(4N)
dw_hidden = dw_hidden.view((B,T,4,-1,1)) # BT(4K) -> BT4K1
dw = (self.dyn_w_proj.dw_hidden_activation(dw_hidden) * self.dyn_w_proj.qkw_m).sum(-2) # gelu, BT4K1, 4K(IM)->BT4K(IM)->BT4(IM)
w1, w2 = dw.view((B,T,2,2,-1,N)).split(I,-2) # BT4(IM)->BT{pre/post}{q/k}IM->[BT22IM] * 2
w1 = self.dyn_w_proj.dw1_norm(w1) # BT22IN
qkdd = self.dyn_w_proj.dw_activation(dd.view((B,T,2,2,N))) # BT2{2}N1->BT2{2}N tanh
qkw = torch.einsum('BTKJIN,BTKJIM->BTKJNM', w1, w2) + torch.diag_embed(qkdd) # j=k=2, BT2{2}NM q/k, pre/post
if self.query_wise: # TODO: do not generate kw and kdd
qw, _ = qkw.unbind(3) # BS2NM
kw_new = None
qw = qw + self.dyn_w_proj.sw
else:
qw, kw_new = qkw.unbind(3) # BS{pre/post}{q/k}NM -> BS{pre/post}NM * 2
kw_new = kw_new + self.dyn_w_proj.sw # BS2NM + 2NM-> BS2NM
if self.kv_cache is not None:
k, v, kw_out = self.kv_cache.update(input_pos, k, v, kw_val=kw_new) #BNT2M
logits = q @ k.transpose(-2, -1) * self.scale_factor
if self.query_wise:
w = qw # B12NM
else:
w = qw + kw_out # B12NM,BS2NM -> BS2NM
wl, w = w.permute(0,2,3,4,1).unbind(1) # BS2NM->B2NMS->[BNMS]*2
logits = (logits * wl).sum(1).unsqueeze(2) # BN1S, BNMS -> BNMS-> BMS-> BM1S
min_value = torch.finfo(torch.float16).min
logits = torch.where(k_mask, logits, min_value)
probs = logits.softmax(-1)
probs = (probs * w).sum(1).unsqueeze(2)
y = probs @ v
return y

def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None, fast_infer=True, gen_mask=None) -> Tensor:
bsz, seqlen, _ = x.shape

kv_size = self.n_local_heads * self.head_dim
q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

q = q.view(bsz, seqlen, self.n_head, self.head_dim) # BSND
k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)

if self.use_qk_norm:
q, k = self.q_norm(q), self.k_norm(k)

if self.rotary_ndims == self.head_dim:
q = apply_rotary_emb(q, freqs_cis) #BTND
k = apply_rotary_emb(k, freqs_cis)
else:
q_rot = q[..., : self.rotary_ndims]
q_pass = q[..., self.rotary_ndims :]
k_rot = k[..., : self.rotary_ndims]
k_pass = k[..., self.rotary_ndims :]
q_rot = apply_rotary_emb(q_rot, freqs_cis, mode='half') #BTND
k_rot = apply_rotary_emb(k_rot, freqs_cis, mode='half')
q = torch.cat((q_rot, q_pass), dim=-1)
k = torch.cat((k_rot, k_pass), dim=-1)

q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) # BNSD

if self.is_training:
N, D, I = self.n_head, self.head_dim, self.dyn_w_proj.dynamic_hidden_dim; # 6.7B
B,T,E = x.shape
if self.use_dcmha:
project_logits = True
project_probs = True
if project_probs:
dw_hidden, dd = (x @ self.dyn_w_proj.dw_m).split([2*2*N*(2*I), 2*2*N*1], -1)
dw_hidden = self.dyn_w_proj.dw_hidden_activation(dw_hidden)
dw_hidden = dw_hidden.view(dw_hidden.shape[:2]+(4,-1)) #B T (4 K) -> B T 4 K # reshape
dw = torch.einsum('B T C K, C K D -> B T C D', dw_hidden, self.dyn_w_proj.qkw_m) # BT4K,4K(MI)->BT4(MI)
shape = (B,T,2*2,-1,N)# if project_logits else (B,T,2,N,-1) # BT(pre/post)(q/k)IN
w1, w2 = dw.view(shape).split(I,-2)
w1 = self.dyn_w_proj.dw1_norm(w1) # BT22IN
if self.use_sw:
pre_sw, post_sw = self.dyn_w_proj.sw.unbind(0)
else:
pre_sw, post_sw = None, None
pre_qw1, pre_kw1, post_qw1, post_kw1 = w1.unbind(2) # BT(2{*2})IN->[BTIN]*4
pre_qw2, pre_kw2, post_qw2, post_kw2 = w2.unbind(2)
qkdd = F.tanh(dd).squeeze(-1).view(shape[:-2] + (N,)) # BT(2{*2})N1->BT(2{*2})N
pre_qdd, pre_kdd, post_qdd, post_kdd = qkdd.unbind(2) # BT(2{*2})N->[BTN]*4

y = torch.zeros(B, N, T, D).to(q.device, dtype=torch.float16)
for i in range(T // self.q_chunk_size):
start, stop = i * self.q_chunk_size, (i + 1) * self.q_chunk_size
kv_start = max(0, stop - self.q_chunk_size -self.window_size)
_q = q[:, :, start : stop, :]
_k, _v = k[:, :, kv_start : stop, :], v[:, :, kv_start : stop, :]
_atten_mask = mask[:, :, start : stop, kv_start : stop]
_pre_proj_dw_args = slice_dw(pre_sw, pre_qw1, pre_qw2, pre_kw1, pre_kw2, pre_qdd, pre_kdd, start, stop, kv_start) \
if project_logits else None
_post_proj_dw_args = slice_dw(post_sw, post_qw1, post_qw2, post_kw1, post_kw2, post_qdd, post_kdd, start,stop,kv_start) \
if project_probs else None
_o = _atten_context(_q, _k, _v, _atten_mask, _pre_proj_dw_args, _post_proj_dw_args)
y[:,:,start:stop] = _o
else:
y = torch.zeros(B, N, T, D).to(q.device, dtype=torch.float16)
for i in range(T // self.q_chunk_size):
start, stop = i * self.q_chunk_size, (i + 1) * self.q_chunk_size
kv_start = max(0, stop - self.q_chunk_size -self.window_size)
_q = q[:, :, start : stop, :]
_k, _v = k[:, :, kv_start : stop, :], v[:, :, kv_start : stop, :]
_atten_mask = mask[:, :, start : stop, kv_start : stop]
_pre_proj_dw_args, _post_proj_dw_args = None, None
_o = _atten_context(_q, _k, _v, _atten_mask, _pre_proj_dw_args, _post_proj_dw_args)
y[:,:,start:stop] = _o
else: # inference
if seqlen == 1: # one-token generation
k_mask = mask if self.window_size is None else gen_mask[:, :, :,:self.kv_cache.seq_length]
if fast_infer:
y = self._generate_fast(x, input_pos, q, k, v, k_mask)
else:
assert not self.query_wise
# generate dw from hidden_state
pre_proj_dw_args, post_proj_dw_args, kw_new = self.dyn_w_proj(x, gen_cache=True)

# update kvkw cache
kw_new = kw_new + self.dyn_w_proj.sw # absorb residual or sw into kw cache
if self.kv_cache is not None:
k, v, kw_out = self.kv_cache.update(input_pos, k, v, kw_val=kw_new) # BNSD, BNSD, BS2NN

logits = q @ k.transpose(-2, -1) * self.scale_factor
# merge pre_w and apply it
pre_qw1, pre_qw2, pre_kw1, pre_kw2, pre_qdd, pre_kdd = pre_proj_dw_args
pre_qw = torch.einsum('BTGIN, BTGIM->BTNM',pre_qw1, pre_qw2) + torch.diag_embed(pre_qdd.squeeze(2))
pre_w = pre_qw + kw_out[:,:,0] # B1NM, BSNM -> BSNM
logits = self.dyn_w_proj.pre_proj(logits, proj_w=pre_w.squeeze(1))

logits = torch.where(k_mask, logits, torch.finfo(torch.float16).min)
probs = logits.softmax(-1)

# merge post_w and apply it
post_qw1, post_qw2, post_kw1, post_kw2, post_qdd, post_kdd = post_proj_dw_args
post_qw = torch.einsum('BTGIN, BTGIM->BTNM', post_qw1, post_qw2) + torch.diag_embed(post_qdd.squeeze(2))
post_w = post_qw + kw_out[:,:,1]
probs = self.dyn_w_proj.post_proj(probs, proj_w=post_w.squeeze(1))

y = probs @ v
else: # prefill
k_mask = mask[:,:,:,:k.shape[-2]]
pre_proj_dw_args, post_proj_dw_args,kw_new = self.dyn_w_proj(x, gen_cache=True)
kw_new = kw_new + self.dyn_w_proj.sw # absorb residual or sw into kw cache
if self.kv_cache is not None:
self.kv_cache.update(input_pos, k, v, kw_val=kw_new) # BNSD, BNSD, BS2NN
logits = q @ k.transpose(-2, -1) * self.scale_factor
logits = self.dyn_w_proj.pre_proj(logits, dws=pre_proj_dw_args, query_vec=x, key_vec=x, fast_infer=True) # XD BN1S
logits = torch.where(k_mask, logits, torch.finfo(torch.float16).min)
probs = logits.softmax(-1)
probs = self.dyn_w_proj.post_proj(probs, dws=post_proj_dw_args, query_vec=x, key_vec=x, fast_infer=True) # BN1S
y = probs @ v

y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
y = self.wo(y)
return y

The multi-head attention implementation, supporting dynamic weight projection and windowed attention.
Uses rotary position embeddings (RoPE) for positional encoding; when rotary_pct < 1, only the first rotary_ndims dimensions of each head are rotated.
Provides a fast inference path that speeds up generation via _generate_fast.
Supports multi-stage window schedules (such as LG and LGLL).

class FeedForward(nn.Module):
def __init__(self, config: DCPythiaConfig) -> None:
super().__init__()
self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=config.use_linear_bias)
self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=config.use_linear_bias)

def forward(self, x: Tensor) -> Tensor:
return self.w2(F.gelu(self.w1(x)))

A two-layer feed-forward network with GELU activation.
Performs the non-linear transformation and feature mapping.

Helper functions:

def _atten_context(query, key, value, atten_mask, pre_proj_dw_args, post_proj_dw_args):
logits = query @ key.transpose(-2, -1)
if pre_proj_dw_args is not None: logits = _cross_head_proj(logits, *pre_proj_dw_args)
logits = torch.where(atten_mask, logits, torch.finfo(torch.float16).min)
probs = logits.softmax(-1)
if post_proj_dw_args is not None: probs = _cross_head_proj(probs, *post_proj_dw_args)
o = probs @ value # BNTS,BNSD->BNTD
return o

def _cross_head_proj(inputs, sw, qw1, qw2, kw1, kw2, qdd, kdd, loop_over_dynamic_hd=False):
out = inputs + torch.einsum('BNTS,NM->BMTS', inputs, sw) if sw is not None else inputs
for i in range(2): # qw1.shape[-2]):
qhidden = (inputs * qw1[..., i, :].transpose(-2, -1).unsqueeze(-1)).sum(1) # BNTS,(BTN->BNT->BNT1)->BNTS->BTS
qout = qhidden.unsqueeze(1) * qw2[..., i, :].transpose(-2, -1).unsqueeze(-1) # (BTS->B1TS),(BTN->BNT->BNT1)->BNTS
out = out + qout
khidden = (inputs * kw1[..., i, :].transpose(-2, -1).unsqueeze(-2)).sum(1) # BNTS,(BSN->BNS->BN1S)->BNTS->BTS
kout = khidden.unsqueeze(1) * kw2[..., i, :].transpose(-2, -1).unsqueeze(-2) # (BTS->B1TS),(BSN->BNS->BNS1)->BNTS
out = out + kout
qdout = inputs * qdd.transpose(-2, -1).unsqueeze(-1); out = out + qdout # BNTS,(BTN->BNT->BNT1)->BNTS
kdout = inputs * kdd.transpose(-2, -1).unsqueeze(-2); out = out + kdout # BNTS,(BSN->BNS->BN1S)->BNTS
return out

def find_multiple(n: int, k: int) -> int:
if n % k == 0:
return n
return n + k - (n % k)

def make_window_mask(t, window_size):
col_idx = torch.tile(torch.arange(t).unsqueeze(0), [t, 1])
row_idx = torch.tile(torch.arange(t).unsqueeze(1), [1, t])
bias_mask = (col_idx + window_size >= row_idx).tril().view(t, t)
return bias_mask

def slice_dw(sw, qw1, qw2, kw1, kw2, qdd, kdd, start, stop, kv_start):
return (sw,
qw1[:, start : stop] if qw1 is not None else None,
qw2[:, start : stop] if qw2 is not None else None,
kw1[:, kv_start : stop] if kw1 is not None else None,
kw2[:, kv_start : stop] if kw2 is not None else None,
qdd[:, start : stop] if qdd is not None else None,
kdd[:, kv_start : stop] if kdd is not None else None)

def precompute_freqs_cis(
seq_len: int, n_elem: int, base: int = 10000
) -> Tensor:
freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
t = torch.arange(seq_len, device=freqs.device)
freqs = torch.outer(t, freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
return cache.to(dtype=torch.float16)

def unbind(ary, n, dim=0):
return [torch.squeeze(a, dim=dim) for a in torch.split(ary, ary.shape[dim] // n, dim=dim)]

def apply_rotary_emb(x: Tensor, freqs_cis: Tensor, mode='half') -> Tensor:
if mode == 'half':
xshaped = x.float().reshape(*x.shape[:-1], 2,-1).transpose(-1,-2)
elif mode == 'alternative':
xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
freqs_cis = freqs_cis.view(-1, xshaped.size(1), 1, xshaped.size(3), 2)
x_out2 = torch.stack(
[
xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
],
-1,
)
x_out2 = x_out2.flatten(3)
return x_out2.type_as(x)

Provides utility functions for positional encoding (precompute_freqs_cis), window-mask construction (make_window_mask), rotary embedding application (apply_rotary_emb), and tensor splitting (unbind), alongside the attention helpers _atten_context and _cross_head_proj. (A quick sanity check of a few of them follows below.)
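
A quick sanity check of a few of these helpers, assuming the functions above are in scope (e.g. defined in the same module):

```python
import torch

print(find_multiple(2047, 8))        # 2048: setup_caches pads max_seq_length to a multiple of 8
print(make_window_mask(5, 2).int())  # banded causal mask: each row attends to at most 2 earlier positions
cache = precompute_freqs_cis(seq_len=16, n_elem=8)
print(cache.shape)                   # torch.Size([16, 4, 2]): a (cos, sin) pair per position and frequency
q = torch.randn(1, 16, 4, 8)         # B, T, N, head_dim layout, as used before the transpose in forward
print(apply_rotary_emb(q, cache).shape)  # torch.Size([1, 16, 4, 8])
```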