[example] GPT
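
A self-contained PyTorch script that trains a small GPT-style decoder to predict the next value of a noisy sine wave. The --custom_mha and --custom_block flags (each accepting '1' or 'true') switch between the hand-written attention/block modules defined below and the torch.nn built-ins. Training and test losses are drawn in the terminal with plotille, and each epoch's autoregressive rollout is plotted with gr.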

gpt.py
import numpy as np
import plotille
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
from gr.pygr import mlab
from IPython import embed
from argparse import ArgumentParser
 
 
def parse_args():
    parser = ArgumentParser()
    parser.add_argument('--custom_mha', type=lambda x: x in ('1', 'true'), default=False)
    parser.add_argument('--custom_block', type=lambda x: x in ('1', 'true'), default=True)
    return parser.parse_args()
 
args = parse_args()
 
 
class MultiheadAttention(nn.Module):
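    """Hand-written multi-head scaled dot-product attention (enabled with --custom_mha).
    Expects inputs of shape (seq, batch, embed) and an additive attention mask
    (0 where allowed, -inf where masked), matching nn.MultiheadAttention."""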
    def __init__(self, key_dim, num_heads, drop=0.1):
        super().__init__()
        self.scale = np.sqrt(key_dim / num_heads)  # scale by sqrt(d_k), the per-head dimension
        self.n_heads = num_heads
        self.dropout = nn.Dropout(drop)
 
    def forward(self, q, k, v, attn_mask):
        q = self.split_heads(q)
        k = self.split_heads(k, key=True)
        v = self.split_heads(v)
        w = torch.matmul(q, k)
        w = w / self.scale
        w = w + attn_mask  # additive causal mask: 0 where allowed, -inf where masked
        attn = F.softmax(w, dim=-1)
        attn = self.dropout(attn)
        context = torch.matmul(attn, v)
        context = self.merge_heads(context)
        return context, attn
 
    def split_heads(self, x, key=False):
        seq, bs, emb = x.size()
        d_k = emb // self.n_heads
        x = x.view(seq, bs, self.n_heads, d_k)
        if key:
            # bs, self.n_heads, d_k, seq
            x = x.permute(1, 2, 3, 0)
        else:
            # bs, self.n_heads, seq, d_k
            x = x.permute(1, 2, 0, 3)  
        return x
 
    def merge_heads(self, x):
        bs, heads, seq, d_k = x.size()
        x = x.permute(2, 0, 1, 3)
        x = x.reshape(seq, bs, self.n_heads * d_k)
        return x
 
 
class MHA(nn.Module):
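    """Self-attention layer: a fused QKV projection followed by either the custom
    MultiheadAttention above (--custom_mha) or nn.MultiheadAttention (which applies
    its own input/output projections on top), then an output projection."""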
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.n_heads = num_heads
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=False)
        if args.custom_mha:
            self.attn = MultiheadAttention(embed_dim, num_heads)
        else:
            self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.out = nn.Linear(embed_dim, embed_dim)
 
        layers = (self.qkv, self.out)
        for layer in layers:
            torch.nn.init.xavier_uniform_(layer.weight)
        self.out.bias.data.zero_()
 
    def forward(self, x, mask):
        seq, bsz, emb = x.size()
        q, k, v = self.qkv(x).split(emb, dim=2)
        context, weight = self.attn(q, k, v, attn_mask=mask)
        return self.out(context)
 
 
class MLP(nn.Module):
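    """Position-wise feed-forward network: Linear -> GELU -> Linear, widening by `factor`."""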
    def __init__(self, embed_dim, factor=4):
        super().__init__()
        self.fc = nn.Linear(embed_dim, embed_dim * factor)
        self.fc2 = nn.Linear(embed_dim * factor, embed_dim)
 
        torch.nn.init.normal_(self.fc.weight, std=0.02)
        torch.nn.init.uniform_(self.fc.bias, -0.001, 0.001)
        torch.nn.init.normal_(self.fc2.weight, std=0.02)
        torch.nn.init.uniform_(self.fc2.bias, -0.001, 0.001)
 
    def forward(self, x):
        x = self.fc(x)
        x = F.gelu(x)
        x = self.fc2(x)
        return x
 
 
class CustomBlock(nn.Module):
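    """Pre-LayerNorm transformer block (enabled with --custom_block):
    LayerNorm -> attention -> residual, then LayerNorm -> MLP -> residual."""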
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.ln_1 = nn.LayerNorm(embed_dim)
        self.attn = MHA(embed_dim, num_heads)
        self.ln_2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim)
 
    def forward(self, x, src_mask=None):
        x = x + self.attn(self.ln_1(x), src_mask)
        x = x + self.mlp(self.ln_2(x))
        return x
 
 
class Block(nn.TransformerEncoderLayer):
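    """nn.TransformerEncoderLayer with GELU activation and a pre-norm forward pass."""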
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
        super().__init__(d_model, nhead, dim_feedforward, dropout)
        self.activation = F.gelu
 
    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        # MHA
        x = self.norm1(src)
        x = self.self_attn(x, x, x, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(x)
        # MLP
        x = self.linear2(self.dropout(self.activation(self.linear1(self.norm2(src)))))
        src = src + self.dropout2(x)
        return src
 
 
class GPTModel(nn.Module):
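    """GPT-style decoder over continuous inputs: a linear input embedding, learned
    positional embeddings, a stack of pre-norm blocks under a causal mask, and a
    linear readout. forward() also returns the input window rolled forward by one
    step, which the evaluation loop uses for autoregressive generation."""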
    def __init__(self, input_dims, output_dims, max_len):
        super().__init__()
        self.n_layers = 3
        self.n_heads = 16
        self.d_model = 512
        self.max_len = max_len
 
        self.we = nn.Linear(input_dims, self.d_model, bias=False)
        self.wp = nn.Embedding(self.max_len, self.d_model, padding_idx=0)
        if args.custom_block:
            self.blocks = nn.ModuleList([
                CustomBlock(self.d_model, self.n_heads) for _ in range(self.n_layers)
            ])
        else:
            self.blocks = nn.ModuleList([
                Block(self.d_model, self.n_heads) for _ in range(self.n_layers)
            ])
 
        self.norm = nn.LayerNorm(self.d_model)
        self.wd = nn.Linear(self.d_model, output_dims, bias=False)
 
        torch.nn.init.normal_(self.we.weight, std=0.02)
        torch.nn.init.uniform_(self.wp.weight, -0.01, 0.01)
        torch.nn.init.normal_(self.wd.weight, std=0.02)
 
    def forward(self, src):
        src_embed = self.we(src)
        pos_idx = torch.arange(len(src), device=src.device)
        pos_embed = self.wp(pos_idx).unsqueeze(1)
        hx = src_embed + pos_embed
        src_mask = self.generate_src_mask(src.size(0), src.device)
 
        for block in self.blocks:
            hx = block(hx, src_mask=src_mask)
        hx = self.norm(hx)
 
        out = self.wd(hx)
        src = torch.cat([src[1:], out[-1:]], dim=0).detach()  # roll the window: drop the oldest step, append the newest prediction
        return out, src
 
    @staticmethod
    def generate_src_mask(size, device):
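        """Additive causal mask: 0.0 on and below the diagonal (attention allowed), -inf above."""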
        mask = (torch.triu(torch.ones(size, size)) == 1).transpose(0, 1)
        mask = mask.float().to(device)
        mask = mask.masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask
 
 
if __name__ == '__main__':
 
    n_epochs = 2500
    prev_steps = 64
    next_steps = 64
    test_steps = 512
    bsz = 8  # 4  # 128  # the batch size needs to stay small
    device = 'cuda'
 
    dataset = np.sin(np.arange(4096) / 10.)  # toy data: one long sine wave
 
    model = GPTModel(1, 1, prev_steps + next_steps).to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.000001, betas=(0.9, 0.95), eps=1e-8)
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)
 
    step = 0
    train_loss_list = list()
    test_loss_list = list()
 
    for epoch in tqdm.trange(n_epochs):
        # build shuffled start indices for training windows, grouped into mini-batches
        bid = np.arange(len(dataset)-(prev_steps + next_steps))
        np.random.shuffle(bid)
        bid = bid[:len(bid) // bsz * bsz]
        bid = bid.reshape((len(bid) // bsz,  1, bsz))
        pos = np.arange(prev_steps + next_steps).reshape(1, -1, 1)
        idxes = bid + pos  # mini-batch x seq x data-index
 
        # fitting: teacher-forced next-step prediction on noisy windows
        for i, idx in enumerate(tqdm.tqdm(idxes, leave=False)):
            data = dataset[idx].reshape((prev_steps + next_steps, bsz, 1))
            tgt = torch.tensor(
                data + np.random.normal(0, 0.5, data.shape),  # data + noise
                dtype=torch.float32, device=device
            )
            gen, _ = model(tgt)
 
            optimizer.zero_grad()
            loss = (0.5 * (tgt[1:] - gen[:-1]) ** 2).mean()
            loss.backward()
            optimizer.step()
            scheduler.step(epoch + i / len(idxes))
 
            step += 1 / len(idxes)
            train_loss_list.append((step, loss.item()))
 
        # eval: autoregressive rollout for test_steps from a noisy prefix of length prev_steps
        idx = np.random.randint(0, len(dataset)-(prev_steps + test_steps), 1).reshape(-1, 1)
        idx = idx + np.arange(prev_steps + test_steps).reshape(-1, 1)
        data = dataset[idx].reshape(prev_steps + test_steps, 1, 1)
        tgt = torch.tensor(
            data + np.random.normal(0, 0.5, data.shape), 
            dtype=torch.float32, device=device
        )
        src = tgt[:prev_steps]
        gen = tgt[:prev_steps]
 
        with torch.no_grad():
            for _ in range(test_steps):
                gen_, src = model(src)
                gen = torch.cat([gen, gen_[-1:]], dim=0)
 
        gen_np = gen.squeeze().cpu().numpy()
        mlab.plot(data.reshape(-1))
        mlab.oplot(gen_np)
 
        loss = (0.5 * (data.reshape(-1) - gen_np) ** 2).mean()
        test_loss_list.append((step, loss.item()))
 
        tqdm.tqdm.write(plotille.scatter(*zip(*train_loss_list[-1000:]), height=25))
        tqdm.tqdm.write(plotille.scatter(*zip(*test_loss_list[-1000:]), height=25))
        tqdm.tqdm.write(str(args))
 
    embed()