Upload wkv.py
Browse files
Trained_20G/qwen_r1_7b_withgate_freezemlp__16G_hf/wkv.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import torch
|
| 2 |
from einops import rearrange
|
| 3 |
|
|
@@ -41,6 +42,9 @@ else:
|
|
| 41 |
def compile_decorator(func):
|
| 42 |
return func
|
| 43 |
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
class Rwkv_Tmix_x070(nn.Module):
|
| 46 |
def __init__(self, args: RwkvHybridConfig, layer_id, **kwargs):
|
|
@@ -205,23 +209,22 @@ class Rwkv_Tmix_x070(nn.Module):
|
|
| 205 |
output_final_state,
|
| 206 |
cu_seqlens
|
| 207 |
):
|
| 208 |
-
if
|
| 209 |
r, w, k, v, a, b = map(lambda x: rearrange(
|
| 210 |
x, 'b l (h d) -> b h l d', h=self.n_head), (r, w, k, v, a, b))
|
| 211 |
o, state = native_recurrent_rwkv7(
|
| 212 |
r=r, k=k, v=v, w=w,
|
| 213 |
a=a, b=b,
|
| 214 |
scale=1.0,
|
| 215 |
-
initial_state=s
|
| 216 |
output_final_state=True,
|
| 217 |
head_first=True,
|
| 218 |
)
|
| 219 |
-
state = state.transpose(-1, -2)
|
| 220 |
x = rearrange(o, "b h l d -> b l (h d)")
|
| 221 |
else:
|
| 222 |
r, w, k, v, a, b = map(lambda x: rearrange(
|
| 223 |
x, 'b l (h d) -> b l h d', h=self.n_head), (r, w, k, v, a, b))
|
| 224 |
-
wkv7_func = chunk_rwkv7 if
|
| 225 |
o, state = wkv7_func(
|
| 226 |
r=r, k=k, v=v, w=w,
|
| 227 |
a=a, b=b,
|
|
|
|
| 1 |
+
import os
|
| 2 |
import torch
|
| 3 |
from einops import rearrange
|
| 4 |
|
|
|
|
| 42 |
def compile_decorator(func):
|
| 43 |
return func
|
| 44 |
|
| 45 |
+
wkv_mode = os.environ.get("WKV_MODE", "fused")
|
| 46 |
+
wkv_mode = wkv_mode.lower()
|
| 47 |
+
assert wkv_mode in ['fused', 'chunk', 'pytorch']
|
| 48 |
|
| 49 |
class Rwkv_Tmix_x070(nn.Module):
|
| 50 |
def __init__(self, args: RwkvHybridConfig, layer_id, **kwargs):
|
|
|
|
| 209 |
output_final_state,
|
| 210 |
cu_seqlens
|
| 211 |
):
|
| 212 |
+
if wkv_mode == 'pytorch':
|
| 213 |
r, w, k, v, a, b = map(lambda x: rearrange(
|
| 214 |
x, 'b l (h d) -> b h l d', h=self.n_head), (r, w, k, v, a, b))
|
| 215 |
o, state = native_recurrent_rwkv7(
|
| 216 |
r=r, k=k, v=v, w=w,
|
| 217 |
a=a, b=b,
|
| 218 |
scale=1.0,
|
| 219 |
+
initial_state=s,
|
| 220 |
output_final_state=True,
|
| 221 |
head_first=True,
|
| 222 |
)
|
|
|
|
| 223 |
x = rearrange(o, "b h l d -> b l (h d)")
|
| 224 |
else:
|
| 225 |
r, w, k, v, a, b = map(lambda x: rearrange(
|
| 226 |
x, 'b l (h d) -> b l h d', h=self.n_head), (r, w, k, v, a, b))
|
| 227 |
+
wkv7_func = chunk_rwkv7 if wkv_mode == 'chunk' else fused_recurrent_rwkv7
|
| 228 |
o, state = wkv7_func(
|
| 229 |
r=r, k=k, v=v, w=w,
|
| 230 |
a=a, b=b,
|