"""Train a small GPT-style causal transformer to regress continuous sequences.

The model receives a window of ``block_size`` feature vectors and is trained
with a squared-error loss to predict the same window shifted one step into the
future.  After every epoch it is rolled out autoregressively on a held-out
slice and the result is plotted against the ground truth.

The attention / block implementation follows
https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
"""
import argparse
import json
import math
import urllib.request

import numpy as np
import plotille
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
from gr.pygr import mlab
from IPython import embed
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader


def parse_args():
    """Parse the command-line hyper-parameters of the experiment."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--lr', type=float, default=0.0001)
    parser.add_argument('--max_epoch', type=int, default=200)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--data_repeat', type=int, default=1)
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--block_size', type=int, default=32)
    parser.add_argument('--test_steps', type=int, default=512)
    parser.add_argument('--n_workers', type=int, default=1)
    parser.add_argument('--weight_decay', type=float, default=0.1)
    parser.add_argument('--noise_scale', type=float, default=0.1)
    parser.add_argument('--max_grad_norm', type=float, default=1.0)
    parser.add_argument('--dataset', choices=['BasicDataset', 'MotionDataset'], default='MotionDataset')
    return parser.parse_args()


# Parsed at import time (as in the original script) because GPTModel reads
# ``args.dropout`` as its default dropout rate.
args = parse_args()


class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with a lower-triangular (causal) mask.

    https://github.com/karpathy/minGPT/blob/master/mingpt/model.py

    Args:
        d_model: width of the input/output features; must be divisible by ``n_head``.
        n_head: number of attention heads.
        block_size: longest sequence the causal mask is built for.
        dropout: dropout probability for the attention weights and the output projection.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_head``.
    """

    def __init__(self, d_model, n_head, block_size, dropout):
        super().__init__()
        if d_model % n_head != 0:
            raise ValueError(f'd_model ({d_model}) must be divisible by n_head ({n_head})')
        # key, query, value projections for all heads
        self.key = nn.Linear(d_model, d_model)
        self.query = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        # regularization
        self.attn_drop = nn.Dropout(dropout)
        self.resid_drop = nn.Dropout(dropout)
        # output projection
        self.proj = nn.Linear(d_model, d_model)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size)
        )
        self.n_head = n_head

    def forward(self, x, layer_past=None):
        """Apply causal self-attention to ``x`` of shape (B, T, C); returns (B, T, C).

        ``layer_past`` is accepted for interface compatibility and is not used.
        """
        B, T, C = x.size()
        head_size = C // self.n_head

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        k = self.key(x).view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)
        q = self.query(x).view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)
        v = self.value(x).view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_size))
        # positions to the right of the diagonal get a very negative score so softmax zeroes them
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, -1e10)
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side

        # output projection
        return self.resid_drop(self.proj(y))


class Block(nn.Module):
    """Pre-LayerNorm transformer block: causal attention and a 4x-wide GELU MLP, each with a residual."""

    def __init__(self, d_model, n_head, block_size, dropout):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_head, block_size, dropout)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


class GPTModel(nn.Module):
    """Stack of causal transformer blocks mapping continuous inputs to continuous outputs.

    Args:
        input_dims: number of features per input time step.
        output_dims: number of features per predicted time step.
        block_size: maximum sequence length (size of the learned positional table).
        dropout: dropout probability; defaults to ``args.dropout`` when ``None``.
    """

    def __init__(self, input_dims, output_dims, block_size, dropout=None):
        super().__init__()
        if dropout is None:
            dropout = args.dropout

        self.n_layers = 6
        self.n_heads = 8
        self.d_model = 512
        self.block_size = block_size

        # linear "embedding" of the continuous input, plus a learned positional table
        self.we = nn.Linear(input_dims, self.d_model, bias=True)
        self.wp = nn.Parameter(torch.zeros(1, self.block_size, self.d_model))
        self.blocks = nn.Sequential(*[
            Block(self.d_model, self.n_heads, self.block_size, dropout)
            for _ in range(self.n_layers)
        ])
        self.norm = nn.LayerNorm(self.d_model)
        self.wd = nn.Linear(self.d_model, output_dims, bias=True)

        self.apply(self._init_weights)
        print(f'n_params: {sum(p.numel() for p in self.parameters())}')

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and LayerNorm to the identity."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, src):
        """Predict the next step for every position of ``src``.

        Args:
            src: tensor of shape (B, T, C) with ``T <= block_size``.

        Returns:
            A pair ``(out, next_src)``: ``out`` holds the per-position
            predictions, and ``next_src`` is ``src`` with its first step
            dropped and the last prediction appended (detached), i.e. the
            input window for the next autoregressive step.

        Raises:
            ValueError: if the sequence is longer than ``block_size``.
        """
        B, T, C = src.size()
        if T > self.block_size:
            raise ValueError(f'sequence length {T} exceeds block_size {self.block_size}')

        hx = self.we(src) + self.wp[:, :T, :]
        hx = self.blocks(hx)
        # The final LayerNorm is applied once; the original applied it twice in a row.
        out = self.wd(self.norm(hx))

        next_src = torch.cat([src[:, 1:, :], out[:, -1:, :]], dim=1).detach()
        return out, next_src


class _NoisyWindowDataset(Dataset):
    """Shared sampling logic for datasets backed by one long (N, dims) float32 array.

    Subclasses must set ``data``, ``block_size``, ``repeat`` and
    ``noise_scale``; they may override ``noise_multiplier``.
    """

    # Factor applied to the sampled noise (scalar or per-feature array).
    noise_multiplier = 1.0

    def __len__(self):
        # Samples are drawn at random, so the length only sets how many draws make one epoch.
        return len(self.data) * self.repeat

    def __getitem__(self, idx):
        """Return a random noisy window ``x`` and the same window shifted by one step ``y``.

        ``idx`` is ignored: we "cheat" and pick a spot in the dataset at random.
        NOTE(review): sampling uses the global NumPy RNG; confirm the seeding
        behaviour is acceptable when several DataLoader workers are used.
        """
        start = np.random.randint(0, len(self.data) - (self.block_size + 1))
        window = self.data[start:start + self.block_size + 1]
        noise = np.random.normal(0, self.noise_scale, window.shape) * self.noise_multiplier
        # Add the noise out of place: ``window`` is a view into ``self.data``,
        # so an in-place ``+=`` would permanently corrupt the stored sequence.
        chunk = (window + noise).astype(np.float32)
        x = torch.tensor(chunk[:-1], dtype=torch.float32)
        y = torch.tensor(chunk[1:], dtype=torch.float32)
        return x, y

    def get_test_data(self, test_steps, device):
        """Return a clean random slice for autoregressive evaluation.

        Returns:
            ``(tgt, src, gen)`` where ``tgt`` has shape (1, test_steps, dims)
            and ``src`` / ``gen`` are both its first ``block_size`` steps.
        """
        start = np.random.randint(0, len(self.data) - (test_steps + 1))
        data = self.data[start:start + test_steps].reshape(1, -1, self.data.shape[-1])
        tgt = torch.tensor(data, device=device)
        src = tgt[:, :self.block_size]
        gen = tgt[:, :self.block_size]
        return tgt, src, gen


class BasicDataset(_NoisyWindowDataset):
    """One-dimensional sine wave, ``sin(t / 10)`` for ``t`` in ``[0, 10240)``."""

    def __init__(self, block_size, repeat, noise_scale):
        self.block_size = block_size
        # Alternative signals tried previously:
        #   np.sin(np.arange(10240) / 10.) * 0.5 + 2.5
        #   np.abs(np.sin(np.arange(10240) / 10.))
        #   np.sin(np.arange(10240) / 10.) * (np.sin(np.arange(10240) / 10.) > 0.0)
        signal = np.sin(np.arange(10240) / 10.)
        self.data = signal.astype(np.float32).reshape(-1, 1)
        self.dims = self.data.shape[-1]
        self.data_std = self.data.std(0)
        self.repeat = repeat
        self.noise_scale = noise_scale
        # The data is not normalised, so the noise is scaled by the data's std.
        self.noise_multiplier = self.data_std


class MotionDataset(_NoisyWindowDataset):
    """Standardised frames of the DeepMimic ``humanoid3d_backflip`` motion clip.

    The clip is downloaded on construction, so network access is required.
    """

    def __init__(self, block_size, repeat, noise_scale):
        self.block_size = block_size
        url = "https://raw.githubusercontent.com/xbpeng/DeepMimic/master/data/motions/humanoid3d_backflip.txt"
        with urllib.request.urlopen(url) as response:
            frames = json.loads(response.read())['Frames']

        data = np.array(frames, dtype=np.float32)
        # Prepend a copy of column 3 as an extra leading feature
        # (NOTE(review): the purpose of this duplicate is not evident here — confirm).
        data = np.hstack([data[:, 3:4], data])
        # Repeat the clip 100 times along the time axis to get one long sequence.
        data = np.tile(data, (100, 1))

        self.dims = data.shape[-1]
        self.data_mean = data.mean(0, keepdims=True)
        std = data.std(0, keepdims=True)
        # A constant column has zero std; divide by 1 there instead of producing NaN/inf.
        self.data_std = np.where(std > 0, std, 1.0).astype(np.float32)
        self.data = ((data - self.data_mean) / self.data_std).astype(np.float32)

        self.repeat = repeat
        self.noise_scale = noise_scale


# Explicit name -> class mapping for the ``--dataset`` choices.
DATASETS = {
    'BasicDataset': BasicDataset,
    'MotionDataset': MotionDataset,
}


def build_optimizer(model, lr, weight_decay):
    """Create AdamW with weight decay applied only to ``nn.Linear`` weight matrices.

    Biases, LayerNorm parameters and the positional table are excluded.  The
    groups are chosen by module type because a substring test on
    ``"LayerNorm.weight"`` never matches this model's parameter names
    (``ln1.weight``, ``norm.weight``, ...).
    """
    decay, no_decay = [], []
    for module in model.modules():
        for name, param in module.named_parameters(recurse=False):
            if isinstance(module, nn.Linear) and name == 'weight':
                decay.append(param)
            else:
                no_decay.append(param)
    optim_groups = [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]
    return optim.AdamW(optim_groups, lr=lr, betas=(0.9, 0.95))


def warmup_cosine(optimizer, lr_max, epoch, max_epoch, warmup=1.0):
    """Set the learning rate: linear warmup, then a single cosine decay to zero.

    Args:
        optimizer: optimizer whose param groups are updated in place.
        lr_max: peak learning rate, reached at the end of warmup.
        epoch: current (fractional) epoch.
        max_epoch: epoch at which the learning rate reaches zero.
        warmup: length of the linear warmup, in epochs.
    """
    if warmup > 0 and epoch < warmup:
        w = epoch / warmup
    else:
        # Anneal over the remaining epochs.  The original used cos(pi * epoch),
        # which collapsed the rate right after warmup and then oscillated
        # with a period of two epochs.
        span = max(max_epoch - warmup, 1e-8)
        progress = min(max((epoch - warmup) / span, 0.0), 1.0)
        w = 0.5 * (1 + np.cos(np.pi * progress))
    for param_group in optimizer.param_groups:
        param_group['lr'] = w * lr_max


def main():
    """Train the model, evaluating and plotting an autoregressive rollout after every epoch."""
    # create the dataloader
    dataset = DATASETS[args.dataset](args.block_size, args.data_repeat, args.noise_scale)
    loader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.n_workers)
    n_batches = len(loader)

    # create the model
    dim = dataset.data.shape[-1]
    model = GPTModel(dim, dim, args.block_size).to(args.device)

    # create the optimizer
    optimizer = build_optimizer(model, args.lr, args.weight_decay)

    step = 0
    train_loss_list = []
    test_score_list = []

    for epoch in tqdm.trange(args.max_epoch):
        # fitting
        model.train()
        for i, (src, tgt) in tqdm.tqdm(enumerate(loader), total=n_batches, leave=False):
            src, tgt = src.to(args.device), tgt.to(args.device)

            # set the rate before stepping so the first update is warmed up too
            warmup_cosine(optimizer, args.lr, epoch + i / n_batches, args.max_epoch)

            optimizer.zero_grad()
            gen, _ = model(src)
            loss = (0.5 * (tgt - gen) ** 2).mean()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()

            step += 1 / n_batches
            train_loss_list.append((step, loss.item()))

        tqdm.tqdm.write(plotille.scatter(*zip(*train_loss_list[-1000:]), height=25))

        # eval: roll the model out one step at a time from the first block_size frames
        model.eval()
        tgt, src, gen = dataset.get_test_data(args.test_steps, args.device)
        with torch.no_grad():
            for _ in range(args.test_steps - args.block_size):
                gen_, src = model(src)
                gen = torch.cat([gen, gen_[:, -1:, :]], dim=1)

        loss = (0.5 * (tgt - gen) ** 2).mean()
        score = (-loss).exp()
        test_score_list.append((step, score.item()))

        # plot the first feature of the ground truth and of the rollout
        mlab.plot(tgt.cpu().numpy()[0, :, 0])
        mlab.oplot(gen.cpu().numpy()[0, :, 0])

        tqdm.tqdm.write(plotille.scatter(*zip(*test_score_list[-1000:]), height=25))
        tqdm.tqdm.write(str(args))

    # drop into an interactive shell with the trained objects in scope
    embed()


if __name__ == '__main__':
    main()