V-MPO: example
- vmpo_example.py
#!/usr/bin/env python
"""
conda install swig
pip install gym[box2d]
pip install plotille
"""
import argparse
import threading
from collections import deque

import gym
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
from IPython import embed
import plotille


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--comment', type=str, default='baseline')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--n_envs', type=int, default=8)
    parser.add_argument('--horizon', type=int, default=2048)
    parser.add_argument('--mini_batch_size', type=int, default=128)
    parser.add_argument('--lr', type=float, default=0.0005)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--lam', type=float, default=0.99)
    parser.add_argument('--max_grad', type=float, default=10.0)
    parser.add_argument('--eps_eta', type=float, default=1.0)
    parser.add_argument('--eps_alpha', type=float, default=0.5)
    return parser.parse_args()


class Model(nn.Module):
    def __init__(self, n_features, n_actions):
        super().__init__()
        self.fc1 = nn.Linear(n_features, 32)
        self.norm1 = nn.LayerNorm(32)
        self.vf = nn.Linear(32, 1)
        self.policy = nn.Linear(32, n_actions)

    def forward(self, x):
        x = torch.relu(self.norm1(self.fc1(x)))
        logit = self.policy(x)
        return self.vf(x), F.log_softmax(logit, -1), F.softmax(logit, -1)


if __name__ == '__main__':
    args = parse_args()
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Create the vectorized training environments
    env = gym.vector.make('LunarLander-v2', num_envs=args.n_envs)
    env.seed(args.seed)
    n_features = env.observation_space.shape[1]
    n_actions = env.action_space[0].n

    # Create the online model, the target (old) model, and the optimizer;
    # the Lagrange multipliers eta and alpha are trained alongside the network.
    model = Model(n_features, n_actions)
    model_old = Model(n_features, n_actions)
    model_old.load_state_dict(model.state_dict())
    eta = torch.tensor(1.0, requires_grad=True)
    alpha = torch.tensor(5.0, requires_grad=True)
    optimizer = optim.Adam([*model.parameters(), eta, alpha],
                           lr=args.lr, weight_decay=0.01)

    # Run a rendered test game in a background thread
    def run_test_game():
        test_env = gym.make('LunarLander-v2')
        test_env.seed(args.seed)
        while True:
            test_state = test_env.reset()
            score = 0
            while True:
                with torch.no_grad():
                    _, _, prob = model_old(
                        torch.from_numpy(test_state).float().view(1, -1)
                    )
                action = prob.multinomial(1).item()
                test_state, r, done, _ = test_env.step(action)
                score += r
                test_env.render()
                if done:
                    score = 0
                    break

    threading.Thread(target=run_test_game, daemon=True).start()

    # Create the rollout buffers
    D_obs = np.zeros((args.horizon, args.n_envs, n_features))
    D_action = np.zeros((args.horizon, args.n_envs))
    D_reward = np.zeros((args.horizon, args.n_envs))
    D_done = np.zeros((args.horizon, args.n_envs))
    D_value = np.zeros((args.horizon, args.n_envs))
    D_logp = np.zeros((args.horizon, args.n_envs, n_actions))

    # Start training
    frames = 0
    score = np.zeros(args.n_envs)
    n_scores_in_epoch = 0
    scores = deque(maxlen=10000)
    obs_prime = env.reset()
    while True:
        obs = obs_prime
        for D_i in tqdm.trange(args.horizon, desc='Rollout'):
            # Play the game & collect data with the target model
            with torch.no_grad():
                value, logp, prob = model_old(torch.from_numpy(obs).float())
            action = prob.multinomial(num_samples=1).numpy().reshape(-1)
            obs_prime, reward, done, info = env.step(action)

            # Record episode scores
            score += reward
            scores.extend([(frames, s) for s in score[done]])
            score[done] = 0
            n_scores_in_epoch += done.sum()

            # Store the transition (rewards are scaled down by 100)
            D_obs[D_i] = obs
            D_action[D_i] = action
            D_reward[D_i] = reward / 100.
            D_done[D_i] = done
            D_value[D_i] = value.view(-1).numpy()
            D_logp[D_i] = logp.numpy()

            obs = obs_prime
            frames += args.n_envs

        # Data collection finished
        D_i = 0

        # Per-step discount; zero where an episode ended
        gamma = args.gamma * (1 - D_done)

        # Compute bootstrapped returns
        D_ret = np.zeros((args.horizon + 1, args.n_envs))
        with torch.no_grad():
            value, _, _ = model_old(torch.from_numpy(D_obs[-1]).float())
        D_ret[-1] = value.view(-1).numpy()
        for t in reversed(range(args.horizon)):
            D_ret[t] = D_reward[t] + gamma[t] * D_ret[t + 1]
        D_ret = D_ret[:-1]

        # Compute advantages (GAE)
        value_ = np.vstack([D_value, value.numpy().transpose(1, 0)])
        delta = D_reward + gamma * value_[1:] - value_[:-1]
        D_adv = np.zeros((args.horizon, args.n_envs))
        gae = 0
        for t in reversed(range(args.horizon)):
            gae = gae * gamma[t] * args.lam + delta[t]
            D_adv[t] = gae

        # Flatten the env dimension
        FD_obs = D_obs.reshape(-1, n_features)
        FD_action = D_action.reshape(-1)
        FD_logp = D_logp.reshape(-1, n_actions)
        FD_ret = D_ret.reshape(-1)
        FD_adv = D_adv.reshape(-1)

        # top_k: positions of the top-50% advantages
        FD_top_k = FD_adv > np.median(FD_adv)

        # Prepare shuffled mini-batch indices
        idx = np.arange(args.horizon * args.n_envs)
        np.random.shuffle(idx)
        n_mini_batchs = args.horizon * args.n_envs // args.mini_batch_size
        for mb_i in tqdm.trange(n_mini_batchs, desc='Fit'):
            # Prepare a mini-batch
            sel = idx[mb_i * args.mini_batch_size:(mb_i + 1) * args.mini_batch_size]
            obs = torch.from_numpy(FD_obs[sel]).float()
            action = torch.from_numpy(FD_action[sel]).long()
            ret = torch.from_numpy(FD_ret[sel]).float()
            adv = torch.from_numpy(FD_adv[sel]).float()
            logp_old = torch.from_numpy(FD_logp[sel]).float()
            top_k = torch.from_numpy(FD_top_k[sel]).bool()

            # Forward pass with the online model
            value, logp, prob = model(obs)
            logp_a = logp.gather(1, action.view(-1, 1)).view(-1)
            entropy = -(prob * logp).sum(1).mean().item()

            # loss_v: squared error against the bootstrapped returns
            loss_v = 0.5 * (ret - value.view(ret.shape)).pow(2).mean()

            # loss_pi: weighted maximum likelihood on the top-k samples
            with torch.no_grad():
                aug_adv_max = (adv[top_k] / eta).max()
                aug_adv = (adv[top_k] / eta - aug_adv_max).exp() \
                    / (adv[top_k] / eta - aug_adv_max).exp().sum()
            loss_pi = -(aug_adv * logp_a[top_k]).sum()

            # loss_eta (dual func.)
            loss_eta = eta * args.eps_eta + aug_adv_max \
                + eta * (adv[top_k] / eta - aug_adv_max).exp().mean().log().detach()

            # loss_alpha: KL trust-region between the old and the new policy
            # kld = F.kl_div(logp_old.detach(), logp, reduction='batchmean')
            kld = (logp_old.exp() * (logp_old - logp)
                   - logp_old.exp() + logp.exp()).sum(1)
            loss_alpha = (alpha * (args.eps_alpha - kld.detach())
                          + alpha.detach() * kld).mean()

            # total_loss
            loss = loss_v + loss_pi + loss_eta + loss_alpha
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_([*model.parameters(), eta, alpha],
                                           args.max_grad)
            optimizer.step()
            with torch.no_grad():
                # keep the Lagrange multipliers positive and bounded
                eta.fill_(eta.clamp(min=1e-6, max=1e6))
                alpha.fill_(alpha.clamp(min=1e-6, max=1e6))

        # Print training results
        scores_x, scores_y = zip(*scores)
        mean_score = np.mean(scores_y[-n_scores_in_epoch:])
        print(plotille.scatter(scores_x, scores_y, height=25, color_mode='byte'))
        print(plotille.histogram(scores_y[-n_scores_in_epoch:],
                                 bins=n_scores_in_epoch // 5,
                                 height=25, color_mode='byte'))
        n_scores_in_epoch = 0
        print(
            f'{frames:,} => {mean_score:.3f}, '
            f'ret: {D_ret.mean():.3f}, '
            f'v: {value.mean():.3f}, '
            f'ent.: {entropy:.3f}, '
            f'kld: {kld.mean():.3f}, '
            f'L_v: {loss_v.item():.3f}, '
            f'L_pi: {loss_pi.item():.3f}, '
            f'L_eta: {loss_eta.item():.3f}, '
            f'L_alpha: {loss_alpha.item():.3f}, '
            f'eta: {eta.item():.3f}, '
            f'alpha: {alpha.item():.3f}, '
        )

        # Replace the target (old) model with the updated one
        model_old.load_state_dict(model.state_dict())
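For reference, the four loss terms assembled in the fit loop correspond, up to the max-subtraction and stop-gradient details used in the code, to the usual V-MPO objectives. The notation below is ours, not the script's: B is the mini-batch, B~ the top-half-advantage subset selected by FD_top_k, A(s,a) the GAE advantage, G_t the bootstrapped return, and sg[.] a stop-gradient.

\begin{aligned}
L_V &= \frac{1}{2\,|B|}\sum_{t \in B}\bigl(G_t - V_\phi(s_t)\bigr)^2 \\
\psi(s,a) &= \frac{\exp\bigl(A(s,a)/\eta\bigr)}{\sum_{(s',a')\in\tilde{B}}\exp\bigl(A(s',a')/\eta\bigr)} \\
L_\pi &= -\sum_{(s,a)\in\tilde{B}}\psi(s,a)\,\log \pi_\theta(a\mid s) \\
L_\eta &= \eta\,\varepsilon_\eta + \eta\,\log\Bigl(\frac{1}{|\tilde{B}|}\sum_{(s,a)\in\tilde{B}}\exp\bigl(A(s,a)/\eta\bigr)\Bigr) \\
L_\alpha &= \operatorname{mean}\Bigl[\alpha\bigl(\varepsilon_\alpha - \mathrm{sg}[D_{\mathrm{KL}}]\bigr) + \mathrm{sg}[\alpha]\,D_{\mathrm{KL}}\Bigr],
\qquad D_{\mathrm{KL}} = D_{\mathrm{KL}}\bigl(\pi_{\mathrm{old}}(\cdot\mid s)\,\|\,\pi_\theta(\cdot\mid s)\bigr)
\end{aligned}

The total loss optimized by Adam is L = L_V + L_pi + L_eta + L_alpha, and after each step eta and alpha are clamped into [1e-6, 1e6] to keep the dual variables positive.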