kunjcr2 committed
Commit d711d91 · verified · 1 Parent(s): 7431e28

🚀 Push GatorGPT2 with custom config, model, tokenizer manifest

__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .configuration_gator import GatorConfig
+ from .modeling_gator import GatorModel
+ __all__ = ["GatorConfig", "GatorModel"]
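Because __init__.py uses relative imports, these modules are meant to be imported as a package rather than as loose scripts. A minimal local-use sketch, assuming the three Python files sit in a hypothetical package directory named gator/ (the directory name is not part of this commit):

# Hypothetical layout: gator/__init__.py, gator/configuration_gator.py, gator/modeling_gator.py
from gator import GatorConfig, GatorModel

config = GatorConfig()        # defaults declared in configuration_gator.py
model = GatorModel(config)    # randomly initialised weights, no checkpoint loaded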
config.json ADDED
@@ -0,0 +1,15 @@
+ {
+   "architectures": [
+     "GatorModel"
+   ],
+   "model_type": "gator-transformer",
+   "hidden_size": 448,
+   "num_attention_heads": 8,
+   "num_hidden_layers": 10,
+   "vocab_size": 50257,
+   "max_position_embeddings": 1024,
+   "auto_map": {
+     "AutoConfig": "configuration_gator.GatorConfig",
+     "AutoModelForCausalLM": "modeling_gator.GatorModel"
+   }
+ }
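The auto_map block is what lets transformers resolve the custom classes straight from the repo. A minimal loading sketch, assuming the repo id is kunjcr2/GatorGPT2 (inferred from the author and commit message, not stated in the files) and that trust_remote_code is acceptable:

from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "kunjcr2/GatorGPT2"   # assumed repo id, not confirmed by this diff
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

This commit adds no tokenizer files for transformers, so tokenization goes through the tiktoken manifest at the bottom instead.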
configuration_gator.py ADDED
@@ -0,0 +1,12 @@
+ from transformers import PretrainedConfig
+
+ class GatorConfig(PretrainedConfig):
+     model_type = "gator-transformer"
+     def __init__(self, hidden_size=448, num_attention_heads=8, num_hidden_layers=10,
+                  vocab_size=50257, max_position_embeddings=1024, **kwargs):
+         super().__init__(**kwargs)
+         self.hidden_size = hidden_size
+         self.num_attention_heads = num_attention_heads
+         self.num_hidden_layers = num_hidden_layers
+         self.vocab_size = vocab_size
+         self.max_position_embeddings = max_position_embeddings
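GatorConfig records the five architecture hyperparameters and defers everything else to PretrainedConfig, so it round-trips through JSON like any other transformers config. A small sketch (the output directory name is arbitrary; the gator package import is the hypothetical layout from above):

from gator import GatorConfig

cfg = GatorConfig()                       # hidden_size=448, num_attention_heads=8, ...
cfg.save_pretrained("gator-config")       # writes gator-config/config.json
reloaded = GatorConfig.from_pretrained("gator-config")
assert reloaded.hidden_size == 448 and reloaded.num_hidden_layers == 10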
modeling_gator.py ADDED
@@ -0,0 +1,92 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import math
+ from transformers import PreTrainedModel
+ from .configuration_gator import GatorConfig
+
+ class RMSNorm(nn.Module):
+     def __init__(self, dim, eps=1e-5):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+     def forward(self, x):
+         norm = x.norm(2, dim=-1, keepdim=True) / math.sqrt(x.shape[-1])
+         return self.weight * (x / (norm + self.eps))
+
+ class Rope(nn.Module):
+     def __init__(self, d_model, max_len=1024):
+         super().__init__()
+         assert d_model % 2 == 0
+         self.register_buffer("pos", torch.arange(max_len).unsqueeze(1))
+         self.register_buffer("inv_freq", torch.exp(
+             torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)))
+     def forward(self, x):
+         t = x.size(1)
+         freqs = self.pos[:t] * self.inv_freq
+         cos, sin = torch.cos(freqs), torch.sin(freqs)
+         x = x.view(*x.shape[:-1], -1, 2)
+         x1, x2 = x[..., 0], x[..., 1]
+         x_rot = torch.stack([x1*cos - x2*sin, x1*sin + x2*cos], dim=-1)
+         return x_rot.view(*x.shape[:-2], -1)
+
+ class GQA(nn.Module):
+     def __init__(self, d_model, n_heads, gqa_groups, max_len):
+         super().__init__()
+         self.n_heads = n_heads
+         self.head_dim = d_model // n_heads
+         self.n_kv = n_heads // gqa_groups
+         self.q_proj = nn.Linear(d_model, n_heads*self.head_dim, bias=False)
+         self.k_proj = nn.Linear(d_model, self.n_kv*self.head_dim, bias=False)
+         self.v_proj = nn.Linear(d_model, self.n_kv*self.head_dim, bias=False)
+         self.o_proj = nn.Linear(d_model, d_model, bias=False)
+         self.rope_q = Rope(n_heads*self.head_dim, max_len)
+         self.rope_k = Rope(self.n_kv*self.head_dim, max_len)
+     def forward(self, x):
+         B, T, C = x.shape
+         q = self.rope_q(self.q_proj(x)).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
+         k = self.rope_k(self.k_proj(x)).view(B, T, self.n_kv, self.head_dim).transpose(1, 2)
+         v = self.v_proj(x).view(B, T, self.n_kv, self.head_dim).transpose(1, 2)
+         expand = self.n_heads // self.n_kv
+         k = k.repeat_interleave(expand, dim=1)
+         v = v.repeat_interleave(expand, dim=1)
+         attn = torch.softmax((q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim), dim=-1)
+         out = attn @ v
+         out = out.transpose(1, 2).contiguous().view(B, T, C)
+         return self.o_proj(out)
+
+ class MLP(nn.Module):
+     def __init__(self, d_model, d_ff):
+         super().__init__()
+         self.fc1 = nn.Linear(d_model, 2*d_ff, bias=False)
+         self.fc2 = nn.Linear(d_ff, d_model, bias=False)
+     def forward(self, x):
+         up, gate = self.fc1(x).chunk(2, dim=-1)
+         return self.fc2(up * F.silu(gate))
+
+ class Block(nn.Module):
+     def __init__(self, cfg):
+         super().__init__()
+         self.rms1 = RMSNorm(cfg.hidden_size)
+         self.rms2 = RMSNorm(cfg.hidden_size)
+         self.attn = GQA(cfg.hidden_size, cfg.num_attention_heads, 2, cfg.max_position_embeddings)
+         self.mlp = MLP(cfg.hidden_size, 2*cfg.hidden_size)
+     def forward(self, x):
+         x = x + self.attn(self.rms1(x))
+         x = x + self.mlp(self.rms2(x))
+         return x
+
+ class GatorModel(PreTrainedModel):
+     config_class = GatorConfig
+     def __init__(self, config):
+         super().__init__(config)
+         self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
+         self.blocks = nn.ModuleList([Block(config) for _ in range(config.num_hidden_layers)])
+         self.norm = RMSNorm(config.hidden_size)
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+         self.lm_head.weight = self.embed.weight
+     def forward(self, input_ids):
+         h = self.embed(input_ids)
+         for blk in self.blocks: h = blk(h)
+         h = self.norm(h)
+         return {"logits": self.lm_head(h)}
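GatorModel.forward takes only input_ids and returns a plain dict of logits rather than a ModelOutput, and GQA.forward as written applies unmasked attention with no KV cache, so the stock generate() helper will not work out of the box; decoding has to be scripted by hand. A minimal forward-pass and greedy-decoding sketch on a randomly initialised model (shapes follow the config above; the gator package import is the hypothetical layout from earlier):

import torch
from gator import GatorConfig, GatorModel

cfg = GatorConfig()
model = GatorModel(cfg).eval()

input_ids = torch.randint(0, cfg.vocab_size, (1, 16))   # (batch, seq)
with torch.no_grad():
    logits = model(input_ids)["logits"]                  # shape (1, 16, 50257)

# Greedy decoding by hand: re-run the full prefix each step, since there is no KV cache.
ids = input_ids
for _ in range(20):
    with torch.no_grad():
        next_id = model(ids)["logits"][:, -1].argmax(dim=-1, keepdim=True)
    ids = torch.cat([ids, next_id], dim=-1)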
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b72daa34aefd98e8b836b4cbdcdaff3c88d1503e920ca8a89ca82b6ea3310e51
+ size 162621067
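The weights file is stored through Git LFS, so only the pointer (oid and size) shows up in the diff. As a rough sanity check, the architecture above at fp32 accounts for that ~162 MB; a back-of-the-envelope parameter count from the config values, assuming the gqa_groups=2 hard-coded in Block and the tied input/output embeddings:

hidden, heads, layers, vocab = 448, 8, 10, 50257
head_dim = hidden // heads        # 56
n_kv = heads // 2                 # 4 KV heads (gqa_groups = 2 in Block)
d_ff = 2 * hidden                 # 896, as set in Block.__init__

embed = vocab * hidden                                        # shared with lm_head (tied)
attn = 2 * hidden * hidden + 2 * hidden * (n_kv * head_dim)   # q/o plus k/v projections
mlp = hidden * (2 * d_ff) + d_ff * hidden                     # fc1 plus fc2
norms = 2 * hidden                                            # two RMSNorms per block
total = embed + layers * (attn + mlp + norms) + hidden        # plus the final RMSNorm

print(total, round(total * 4 / 1e6, 1), "MB at fp32")         # ~40.6M params, ~162.4 MB

That estimate is in line with the 162,621,067-byte pointer above.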
tokenizer_manifest.json ADDED
@@ -0,0 +1,4 @@
+ {
+   "library": "tiktoken",
+   "encoding": "p50k_base"
+ }
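No transformers tokenizer ships with this commit; the manifest points at tiktoken's p50k_base encoding instead, so token ids are produced outside the transformers tokenizer stack. A minimal end-to-end sketch, again assuming the kunjcr2/GatorGPT2 repo id from earlier:

import tiktoken
import torch
from transformers import AutoModelForCausalLM

repo_id = "kunjcr2/GatorGPT2"                    # assumed repo id, as above
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True).eval()

enc = tiktoken.get_encoding("p50k_base")
ids = torch.tensor([enc.encode("GatorGPT2 is a small language model that")])
with torch.no_grad():
    next_id = model(ids)["logits"][:, -1].argmax(-1)   # greedy next-token prediction
print(enc.decode([next_id.item()]))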