[example] GPT
- Reference
- gpt.py
import numpy as np
import plotille
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
from gr.pygr import mlab
from IPython import embed


class MultiheadAttention(nn.Module):
    def __init__(self, key_dim, num_heads, drop=0.1):
        super().__init__()
        self.scale = np.power(key_dim, 0.5)
        self.n_heads = num_heads
        self.dropout = nn.Dropout(drop)

    def forward(self, q, k, v, attn_mask):
        q = self.split_heads(q)
        k = self.split_heads(k, key=True)
        v = self.split_heads(v)
        w = torch.matmul(q, k)
        w = w / self.scale
        w.masked_fill_(attn_mask, -np.inf)
        attn = F.softmax(w, dim=-1)
        attn = self.dropout(attn)
        context = torch.matmul(attn, v)
        context = self.merge_heads(context)
        return context, attn

    def split_heads(self, x, key=False):
        seq, bs, emb = x.size()
        d_k = emb // self.n_heads
        x = x.view(seq, bs, self.n_heads, d_k)
        if key:
            # bs, self.n_heads, d_k, seq
            x = x.permute(1, 2, 3, 0)
        else:
            # bs, self.n_heads, seq, d_k
            x = x.permute(1, 2, 0, 3)
        return x

    def merge_heads(self, x):
        bs, heads, seq, d_k = x.size()
        x = x.permute(2, 0, 1, 3)
        x = x.reshape(seq, bs, self.n_heads * d_k)
        return x


class MHA(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.n_heads = num_heads
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=False)
        # self.attn = MultiheadAttention(embed_dim, num_heads)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.out = nn.Linear(embed_dim, embed_dim)
        layers = (self.qkv, self.out)
        for layer in layers:
            torch.nn.init.xavier_uniform_(layer.weight)
        self.out.bias.data.zero_()

    def forward(self, x):
        seq, bs, emb = x.size()
        q, k, v = self.qkv(x).split(emb, dim=2)
        mask = (torch.tril(torch.ones(seq, seq)) == 0).to(x.device)
        context, weight = self.attn(q, k, v, attn_mask=mask)
        return self.out(context)


class MLP(nn.Module):
    def __init__(self, embed_dim, factor=4):
        super(MLP, self).__init__()
        self.fc = nn.Linear(embed_dim, embed_dim * factor)
        self.fc2 = nn.Linear(embed_dim * factor, embed_dim)
        torch.nn.init.normal_(self.fc.weight, std=0.02)
        torch.nn.init.uniform_(self.fc.bias, -0.001, 0.001)
        torch.nn.init.normal_(self.fc2.weight, std=0.02)
        torch.nn.init.uniform_(self.fc2.bias, -0.001, 0.001)

    def forward(self, x):
        x = self.fc(x)
        x = F.gelu(x)
        x = self.fc2(x)
        return x


class Block(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(Block, self).__init__()
        self.ln_1 = nn.LayerNorm(embed_dim)
        self.attn = MHA(embed_dim, num_heads)
        self.ln_2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class GPTModel(nn.Module):
    def __init__(self, input_dims, output_dims, max_len):
        super().__init__()
        self.n_layers = 3
        self.n_heads = 16
        self.d_model = 512
        self.max_len = max_len
        self.we = nn.Linear(input_dims, self.d_model, bias=False)
        self.wp = nn.Embedding(self.max_len, self.d_model, padding_idx=0)
        self.blocks = nn.ModuleList(
            [Block(self.d_model, self.n_heads) for _ in range(self.n_layers)]
        )
        self.norm = nn.LayerNorm(self.d_model)
        self.wd = nn.Linear(self.d_model, output_dims, bias=False)
        torch.nn.init.normal_(self.we.weight, std=0.02)
        torch.nn.init.uniform_(self.wp.weight, -0.01, 0.01)
        torch.nn.init.normal_(self.wd.weight, std=0.02)

    def forward(self, src):
        seq_len, mb, _ = src.size()
        src_embed = self.we(src)
        pos_idx = torch.arange(len(src), device=src.device)
        pos_embed = self.wp(pos_idx).unsqueeze(1)
        hx = src_embed + pos_embed
        for block in self.blocks:
            hx = block(hx)
        hx = self.norm(hx)
        out = self.wd(self.norm(hx))
        src = torch.cat([src[1:], out[-1:]], dim=0).detach()
        return out, src


if __name__ == '__main__':
    n_epochs = 2500
    prev_steps = 64
    next_steps = 64
    test_steps = 512
    mb = 8  # 128
    device = 'cuda'
    dataset = np.sin(np.arange(4096) / 10.)
    model = GPTModel(1, 1, prev_steps + next_steps).to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.000001, betas=(0.9, 0.95), eps=1e-8)
    step = 0
    train_loss_list = list()
    test_loss_list = list()
    for _ in tqdm.trange(n_epochs):
        # make batch id
        bid = np.arange(len(dataset) - (prev_steps + next_steps))
        np.random.shuffle(bid)
        bid = bid[:len(bid) // mb * mb]
        bid = bid.reshape((len(bid) // mb, 1, mb))
        pos = np.arange(prev_steps + next_steps).reshape(1, -1, 1)
        idxes = bid + pos  # mini-batch x seq x data-index
        # fitting
        for idx in tqdm.tqdm(idxes, leave=False):
            data = dataset[idx].reshape((prev_steps + next_steps, mb, 1))
            tgt = torch.tensor(
                data + np.random.normal(0, 0.5, data.shape),  # data + noise
                dtype=torch.float32, device=device
            )
            gen, _ = model(tgt)
            optimizer.zero_grad()
            loss = (0.5 * (tgt[1:] - gen[:-1]) ** 2).mean()
            loss.backward()
            optimizer.step()
            step += 1 / len(idxes)
            train_loss_list.append((step, loss.item()))
        # eval
        idx = np.random.randint(0, len(dataset) - (prev_steps + test_steps), 1).reshape(-1, 1)
        idx = idx + np.arange(prev_steps + test_steps).reshape(-1, 1)
        data = dataset[idx].reshape(prev_steps + test_steps, 1, 1)
        tgt = torch.tensor(
            data + np.random.normal(0, 0.5, data.shape),
            dtype=torch.float32, device=device
        )
        src = tgt[:prev_steps]
        gen = tgt[:prev_steps]
        with torch.no_grad():
            for _ in range(test_steps):
                gen_, src = model(src)
                gen = torch.cat([gen, gen_[-1:]], dim=0)
        mlab.plot(data.reshape(-1))
        mlab.oplot(gen.squeeze_().cpu().numpy())
        loss = (0.5 * (data.reshape(-1) - gen.squeeze_().cpu().numpy()) ** 2).mean()
        test_loss_list.append((step, loss.item()))
        tqdm.tqdm.write(plotille.scatter(*zip(*train_loss_list[-1000:]), height=25))
        tqdm.tqdm.write(plotille.scatter(*zip(*test_loss_list[-1000:]), height=25))
    embed()
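
Two details make this a GPT-style (decoder-only, autoregressive) setup: the causal attention mask built in MHA.forward, and the one-step-shifted MSE loss in the training loop. Below is a minimal, self-contained sketch of both; the toy shapes (seq=5, mb=2) are hypothetical and chosen only for illustration, while the real script uses seq = prev_steps + next_steps = 128 and mb = 8.

import torch

# Hypothetical toy shapes for illustration only.
seq, mb, feat = 5, 2, 1

# Causal mask as in MHA.forward: True marks the future positions
# each query step is forbidden to attend to.
mask = torch.tril(torch.ones(seq, seq)) == 0
print(mask[0])   # tensor([False,  True,  True,  True,  True])

# One-step-shifted MSE loss as in the training loop: the prediction made at
# step t (gen[:-1]) is scored against the observation at step t+1 (tgt[1:]).
tgt = torch.randn(seq, mb, feat)
gen = torch.randn(seq, mb, feat)   # stand-in for GPTModel output
loss = (0.5 * (tgt[1:] - gen[:-1]) ** 2).mean()
print(loss.item())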