GPT Example
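A minimal GPT-style decoder written from scratch in PyTorch and trained to continue a sine wave: the model conditions on a short prefix of samples and autoregressively generates the following samples, with training progress plotted via plotille and the gr mlab module.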

gpt.py
import numpy as np
import plotille
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
from gr.pygr import mlab
from IPython import embed
 
 
class MultiheadAttention(nn.Module):
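    """Scaled dot-product attention over heads split from (seq, batch, embed) inputs."""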
    def __init__(self, key_dim, num_heads, drop=0.1):
        super().__init__()
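        # softmax temperature: sqrt of the key dimension (the full model dim here, not the per-head dim)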
        self.temperature = np.power(key_dim, 0.5)
        self.n_heads = num_heads
        self.dropout = nn.Dropout(drop)
 
    def forward(self, q, k, v, attn_mask):
        q = self.split_heads(q)
        k = self.split_heads(k)
        v = self.split_heads(v)
        energy = torch.bmm(q, k.transpose(1, 2)) / self.temperature
        energy.masked_fill_(attn_mask, -np.inf)
        attn = F.softmax(energy, dim=2)
        context = torch.bmm(attn, v)
        context = self.merge_heads(context)
        context = self.dropout(context)
        return context, attn
 
    def split_heads(self, x):
        seq, bs, emb = x.size()
        d_k = emb // self.n_heads
        x = x.view(seq, bs, self.n_heads, d_k)
        x = x.permute(1, 2, 0, 3)
        x = x.reshape(bs * self.n_heads, seq, d_k)
        return x
 
    def merge_heads(self, x):
        bs_heads, seq, d_k = x.size()
        bs = bs_heads // self.n_heads
        x = x.view(bs, self.n_heads, seq, d_k)
        x = x.permute(2, 0, 1, 3)
        x = x.reshape(seq, bs, self.n_heads * d_k)
        return x
 
 
class MHA(nn.Module):
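    """Multi-head self-attention with learned Q/K/V/output projections and a causal mask."""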
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.n_heads = num_heads
        self.attn = MultiheadAttention(embed_dim, num_heads)
        # self.attn = nn.MultiheadAttention(embed_dim, num_heads)
 
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.out = nn.Linear(embed_dim, embed_dim)
 
        layers = (self.query, self.key, self.value, self.out)
        for layer in layers:
            torch.nn.init.normal_(layer.weight, std=0.02)
            torch.nn.init.uniform_(layer.bias, -0.001, 0.001)
 
    def forward(self, x):
        seq = x.size(0)
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
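        # causal mask: True above the diagonal, so each position attends only to itself and earlier positions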
        mask = (torch.tril(torch.ones(seq, seq)) == 0).to(x.device)
        context, attn_weights = self.attn(q, k, v, attn_mask=mask)
        return self.out(context)
 
 
class MLP(nn.Module):
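    """Position-wise feed-forward network: Linear -> GELU -> Linear, widening the hidden size by `factor`."""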
    def __init__(self, embed_dim, factor=4):
        super(MLP, self).__init__()
        self.fc = nn.Linear(embed_dim, embed_dim * factor)
        self.fc2 = nn.Linear(embed_dim * factor, embed_dim)
 
        torch.nn.init.normal_(self.fc.weight, std=0.02)
        torch.nn.init.uniform_(self.fc.bias, -0.001, 0.001)
        torch.nn.init.normal_(self.fc2.weight, std=0.02)
        torch.nn.init.uniform_(self.fc2.bias, -0.001, 0.001)
 
    def forward(self, x):
        h = F.gelu(self.fc(x))
        return self.fc2(h)
 
 
class Block(nn.Module):
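    """Pre-LayerNorm transformer block: self-attention and MLP, each wrapped in a residual connection."""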
    def __init__(self, embed_dim, num_heads):
        super(Block, self).__init__()
        self.ln_1 = nn.LayerNorm(embed_dim)
        self.attn = MHA(embed_dim, num_heads)
        self.ln_2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim)
 
    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
 
 
class GPTModel(nn.Module):
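    """GPT-style decoder: linear input embedding, learned positional embeddings, a stack of transformer blocks, and an output projection tied to the input embedding."""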
    def __init__(self, input_dims, output_dims):
        super().__init__()
        self.n_layers = 3
        self.n_heads = 16
        self.d_model = 512
        self.max_len = 32
 
        self.we = nn.Linear(input_dims, self.d_model, bias=False)
        self.wp = nn.Embedding(self.max_len, self.d_model, padding_idx=0)
        self.blocks = nn.ModuleList([Block(self.d_model, self.n_heads) for _ in range(self.n_layers)])
 
        self.norm = nn.LayerNorm(self.d_model)
        self.wd = nn.Linear(self.d_model, output_dims, bias=True)
 
        torch.nn.init.normal_(self.we.weight, std=0.02)
        torch.nn.init.uniform_(self.wp.weight, -0.01, 0.01)
        torch.nn.init.normal_(self.wd.weight, std=0.02)
        torch.nn.init.normal_(self.wd.bias, std=0.001)
 
    def forward(self, src):
        seq_len, mb, _ = src.size()
 
        src_embed = self.we(src)
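        # add a learned positional embedding per timestep, broadcast over the batch dimension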
        pos_embed = self.wp(torch.arange(len(src), device=src.device)).unsqueeze(1)
        hx = src_embed + pos_embed
 
        for block in self.blocks:
            hx = block(hx)
        hx = self.norm(hx)
 
        # project back to the input space by reusing the input embedding weights (weight tying);
        # the untied head self.wd is left unused here (see the commented line below)
        out = (hx.view(seq_len * mb, -1) @ self.we.weight).view(seq_len, mb, -1)
        # out = self.wd(self.norm(hx))
        # slide the conditioning window: drop the oldest step and append the newest prediction (detached)
        src = torch.cat([src[1:], out[-1:]], dim=0).detach()
        return out, src
 
 
if __name__ == '__main__':
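    # train a small GPT-style model to continue a sine wave:
    # condition on prev_steps samples and predict the next_steps samples that follow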
 
    n_epochs = 2500
    seq_len = 16
    prev_steps = 16
    next_steps = 32
    mb = 32
    device = 'cuda'
 
    # toy dataset: a single sine wave sampled at 1024 points
    dataset = np.sin(np.arange(1024) / 10.)
 
    model = GPTModel(1, 1).to(device)
    optimizer = optim.Adam(
        model.parameters(), lr=0.00001, betas=(0.9, 0.95), eps=1e-8
    )
 
    step = 0
    loss_list = list()
 
    for _ in tqdm.trange(n_epochs):
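        # sample random window start indices and expand each into a (prev_steps + next_steps)-long index window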
        bid = np.random.randint(
            0, len(dataset) - (prev_steps + next_steps), (len(dataset) // mb, mb)
        ).reshape((len(dataset) // mb, 1, mb))
        pos = np.arange(prev_steps + next_steps).reshape(1, -1, 1)
        idxes = bid + pos
 
        for idx in idxes:
            data = dataset[idx].reshape((prev_steps + next_steps, mb, 1))
            data = torch.tensor(data, dtype=torch.float32, device=device)
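            # first prev_steps samples condition the model; the remaining next_steps are the prediction target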
            src, tgt = data[:prev_steps], data[prev_steps:]
            gen = torch.empty(0, mb, 1, dtype=torch.float32, device=device)
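            # autoregressive rollout: run the model next_steps times, feeding its latest prediction back in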
            for _ in range(next_steps):
                gen_, src = model(src)
                gen = torch.cat([gen, gen_[-1:]], dim=0)
 
            optimizer.zero_grad()
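            # mean squared error (scaled by 0.5) between the generated continuation and the target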
            loss = (0.5 * (tgt - gen) ** 2).mean()
            loss.backward()
            optimizer.step()
 
            step += 1 / len(idxes)
            loss_list.append((step, loss.item()))
 
        # plot the last minibatch's first sequence: ground truth, then conditioning prefix + generated continuation
        mlab.plot(data[:, 0, 0].cpu().numpy())
        mlab.oplot(
            torch.cat([data[:prev_steps, 0, 0], gen[:, 0, 0]], dim=0)
            .detach().cpu().numpy()
        )
        # print an ASCII scatter plot of the recent training loss
        tqdm.tqdm.write(plotille.scatter(*zip(*loss_list[-1000:])))
 
    # drop into an IPython shell for interactive inspection after training
    embed()