사용자 도구

사이트 도구


v-mpo:example

차이

문서의 선택한 두 판 사이의 차이를 보여줍니다.

차이 보기로 링크

다음 판
이전 판
v-mpo:example [2020/06/14 12:01] – 만듦 rex8312v-mpo:example [2024/03/23 02:42] (현재) – 바깥 편집 127.0.0.1
줄 1: 줄 1:
-====== V-MPO: example ======+====== Example: V-MPO ======
  
 <code python vmpo_example.py> <code python vmpo_example.py>
줄 5: 줄 5:
  
 """ """
 +conda create -n ppo python=3.7 numpy ipython matplotlib swig termcolor tqdm scipy tensorboard
 +conda install pytorch torchvision cudatoolkit=10.2 -c pytorch
 pip install gym[box2d] pip install gym[box2d]
 pip install plotille pip install plotille
줄 10: 줄 12:
  
 import argparse import argparse
 +import threading
 from collections import deque from collections import deque
  
줄 15: 줄 18:
 import matplotlib.pyplot as plt import matplotlib.pyplot as plt
 import numpy as np import numpy as np
 +import plotille
 import torch import torch
 import torch.nn as nn import torch.nn as nn
줄 21: 줄 25:
 import tqdm import tqdm
 from IPython import embed from IPython import embed
- +from termcolor import cprint 
-import plotille+from torch.utils.tensorboard import SummaryWriter
  
  
 def parse_args(): def parse_args():
     parser = argparse.ArgumentParser()     parser = argparse.ArgumentParser()
-    parser.add_argument('--comment', type=str, default='baseline')+    parser.add_argument('--logdir', type=str
 +    parser.add_argument('--max_frames', type=int) 
 +    parser.add_argument('--eval'action='store_true')
     parser.add_argument('--seed', type=int, default=0)     parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--n_envs', type=int, default=8)     parser.add_argument('--n_envs', type=int, default=8)
줄 33: 줄 39:
     parser.add_argument('--mini_batch_size', type=int, default=128)     parser.add_argument('--mini_batch_size', type=int, default=128)
     parser.add_argument('--lr', type=float, default=0.0005)     parser.add_argument('--lr', type=float, default=0.0005)
 +    parser.add_argument('--weight_decay', type=float, default=0.0)
     parser.add_argument('--gamma', type=float, default=0.99)     parser.add_argument('--gamma', type=float, default=0.99)
     parser.add_argument('--lam', type=float, default=0.99)     parser.add_argument('--lam', type=float, default=0.99)
     parser.add_argument('--max_grad', type=float, default=10.0)     parser.add_argument('--max_grad', type=float, default=10.0)
-    parser.add_argument('--eps_eta', type=float, default=0.01)+    parser.add_argument('--eps_eta', type=float, default=1.0)
     parser.add_argument('--eps_alpha', type=float, default=0.01)     parser.add_argument('--eps_alpha', type=float, default=0.01)
 +    parser.add_argument('--value_coef', type=float, default=0.5)
     return parser.parse_args()     return parser.parse_args()
  
줄 48: 줄 56:
         self.vf = nn.Linear(32, 1)         self.vf = nn.Linear(32, 1)
         self.policy = nn.Linear(32, n_actions)         self.policy = nn.Linear(32, n_actions)
 +
 +        self.apply(self._init_weights)
 +
 +    def _init_weights(self, module):
 +        if isinstance(module, (nn.Linear, nn.Embedding)):
 +            module.weight.data.normal_(mean=0.0, std=0.02)
 +            if isinstance(module, nn.Linear) and module.bias is not None:
 +                module.bias.data.zero_()
 +        elif isinstance(module, nn.LayerNorm):
 +            module.bias.data.zero_()
 +            module.weight.data.fill_(1.0)
  
     def forward(self, x):     def forward(self, x):
줄 60: 줄 79:
     np.random.seed(args.seed)     np.random.seed(args.seed)
     torch.manual_seed(args.seed)     torch.manual_seed(args.seed)
 +    writer = SummaryWriter(args.logdir)
  
     # 환경 생성     # 환경 생성
줄 66: 줄 86:
     n_features = env.observation_space.shape[1]     n_features = env.observation_space.shape[1]
     n_actions = env.action_space[0].n     n_actions = env.action_space[0].n
- 
-    test_env = gym.make('LunarLander-v2') 
-    test_env.seed(args.seed) 
  
     # 모델 & 옵티마이저 생성     # 모델 & 옵티마이저 생성
줄 76: 줄 93:
     eta = torch.tensor(1.0, requires_grad=True)     eta = torch.tensor(1.0, requires_grad=True)
     alpha = torch.tensor(5.0, requires_grad=True)     alpha = torch.tensor(5.0, requires_grad=True)
-    optimizer = optim.Adam([*model.parameters(), eta, alpha], lr=args.lr, weight_decay=0.01)+    optimizer = optim.Adam([*model.parameters(), eta, alpha], lr=args.lr, weight_decay=args.weight_decay)
  
     # 버퍼 생성     # 버퍼 생성
줄 84: 줄 101:
     D_done = np.zeros((args.horizon, args.n_envs))     D_done = np.zeros((args.horizon, args.n_envs))
     D_value = np.zeros((args.horizon, args.n_envs))     D_value = np.zeros((args.horizon, args.n_envs))
-    D_logp = np.zeros((args.horizon, args.n_envs, n_actions))+    D_logp_a = np.zeros((args.horizon, args.n_envs))
  
     # 학습 시작     # 학습 시작
줄 91: 줄 108:
     n_scores_in_epoch = 0     n_scores_in_epoch = 0
     scores = deque(maxlen=10000)     scores = deque(maxlen=10000)
 +    eval_scores = list()
 +
 +    # 평가 게임 시작
 +    def eval():
 +        env = gym.make('LunarLander-v2')
 +        env.seed(args.seed)
 +        while True:
 +            obs = env.reset()
 +            score = 0
 +            while True:
 +                with torch.no_grad():
 +                    _, _, prob = model_old(
 +                        torch.from_numpy(obs).float().view(1, -1)
 +                    )
 +                action = prob.multinomial(1).item()
 +                obs, r, done, _ = env.step(action)
 +                score += r
 +                env.render()
 +                if done:
 +                    eval_scores.append(score)
 +                    break
 +
 +    if args.eval:
 +        threading.Thread(target=eval, daemon=True).start()
  
     obs_prime = env.reset()     obs_prime = env.reset()
줄 116: 줄 157:
             D_done[D_i] = done             D_done[D_i] = done
             D_value[D_i] = value.view(-1).numpy()             D_value[D_i] = value.view(-1).numpy()
-            D_logp[D_i] = logp.numpy()+            D_logp_a[D_i] = logp.numpy()[range(args.n_envs), action]
  
             obs = obs_prime             obs = obs_prime
줄 145: 줄 186:
         FD_obs = D_obs.reshape(-1, n_features)         FD_obs = D_obs.reshape(-1, n_features)
         FD_action = D_action.reshape(-1)         FD_action = D_action.reshape(-1)
-        FD_logp D_logp.reshape(-1, n_actions)+        FD_logp_a D_logp_a.reshape(-1)
         FD_ret = D_ret.reshape(-1)         FD_ret = D_ret.reshape(-1)
         FD_adv = D_adv.reshape(-1)         FD_adv = D_adv.reshape(-1)
줄 162: 줄 203:
             ret = torch.from_numpy(FD_ret[sel]).float()             ret = torch.from_numpy(FD_ret[sel]).float()
             adv = torch.from_numpy(FD_adv[sel]).float()             adv = torch.from_numpy(FD_adv[sel]).float()
-            logp_old = torch.from_numpy(FD_logp[sel]).float()+            logp_a_old = torch.from_numpy(FD_logp_a[sel]).float()
             top_k = torch.from_numpy(FD_top_k[sel]).bool()             top_k = torch.from_numpy(FD_top_k[sel]).bool()
  
줄 168: 줄 209:
             value, logp, prob = model(obs)             value, logp, prob = model(obs)
             logp_a = logp.gather(1, action.view(-1, 1)).view(-1)             logp_a = logp.gather(1, action.view(-1, 1)).view(-1)
-            entropy = -(prob * logp).sum(1).mean().item()+            entropy = -(prob * logp).sum(-1).mean()
  
             # loss_v             # loss_v
-            loss_v = 0.5 * (ret - value.view(ret.shape)).pow(2).mean()+            loss_v = 0.5 * (ret - value.view(ret.shape)).pow(2).mean(
 +            loss_v = F.smooth_l1_loss(value, ret.view(value.shape))
             # loss_pi             # loss_pi
             with torch.no_grad():             with torch.no_grad():
-                aug_adv = (adv[top_k] / eta).exp() / (adv[top_k] / eta).exp().sum() +                aug_adv_max = (adv[top_k] / eta).max() 
-            loss_pi = -(aug_adv * logp_a[top_k]).sum() +                aug_adv = (adv[top_k] / eta - aug_adv_max).exp() #  / (adv[top_k] / eta - aug_adv_max).exp().sum() 
-            # loss_eta +                norm_aug_adv = aug_adv / aug_adv.sum() 
-            loss_eta = eta * args.eps_eta + eta * (adv[top_k] / eta).exp().mean().log()+            loss_pi = -(norm_aug_adv * logp_a[top_k]).sum() 
 +            # loss_eta (dual func.) 
 +            loss_eta = eta * args.eps_eta + aug_adv_max + eta * (adv[top_k] / eta - aug_adv_max).exp().mean().log()
             # loss_alpha             # loss_alpha
 +            prob_a_old, prob_a = logp_a_old.exp(), logp_a.exp()
             # kld = F.kl_div(logp_old.detach(), logp, reduction='batchmean')             # kld = F.kl_div(logp_old.detach(), logp, reduction='batchmean')
-            kld = (logp_old.exp() * (logp_old - logp) - logp_old.exp() + logp.exp()).sum(1) +            kld = (logp_old.exp() * (logp_old - logp) - logp_old.exp() + logp.exp()).sum(1)  
 +            kld = prob_a_old * (logp_a_old - logp_a) - prob_a_old + prob_a 
             loss_alpha = (alpha * (args.eps_alpha - kld.detach()) + alpha.detach() * kld).mean()             loss_alpha = (alpha * (args.eps_alpha - kld.detach()) + alpha.detach() * kld).mean()
             # total_loss             # total_loss
-            loss = loss_v + loss_pi + loss_eta + loss_alpha+            loss = loss_v * args.value_coef + loss_pi + loss_eta + loss_alpha
  
             optimizer.zero_grad()             optimizer.zero_grad()
줄 191: 줄 237:
  
             with torch.no_grad():             with torch.no_grad():
-                eta.fill_(eta.clamp(min=1e-6, max=1e6)) +                eta와 alpha는 반드시 0 이상 
-                alpha.fill_(alpha.clamp(min=1e-6, max=1e6))+                eta.data.copy_(eta.clamp(min=1e-6, max=1e6)) 
 +                alpha.data.copy_(alpha.clamp(min=1e-6, max=1e6)) 
 + 
 +        # target 모델 교체 
 +        model_old.load_state_dict(model.state_dict())
  
         # 학습결과 출력         # 학습결과 출력
줄 198: 줄 248:
         mean_score = np.mean(scores_y[-n_scores_in_epoch:])         mean_score = np.mean(scores_y[-n_scores_in_epoch:])
         print(plotille.scatter(scores_x, scores_y, height=25, color_mode='byte'))         print(plotille.scatter(scores_x, scores_y, height=25, color_mode='byte'))
-        print(plotille.histogram(scores_y[-n_scores_in_epoch:], bins=n_scores_in_epoch//5, height=25, color_mode='byte'))+        print(plotille.histogram( 
 +            scores_y[-n_scores_in_epoch:],  
 +            bins=n_scores_in_epoch // 5, height=25, color_mode='byte' 
 +        ))
         n_scores_in_epoch = 0         n_scores_in_epoch = 0
         print(         print(
줄 204: 줄 257:
             f'ret: {D_ret.mean():.3f}, '             f'ret: {D_ret.mean():.3f}, '
             f'v: {value.mean():.3f}, '             f'v: {value.mean():.3f}, '
-            f'ent.: {entropy:.3f}, '+            f'ent.: {entropy.item():.3f}, '
             f'kld: {kld.mean():.3f}, '             f'kld: {kld.mean():.3f}, '
             f'L_v: {loss_v.item():.3f}, '             f'L_v: {loss_v.item():.3f}, '
줄 213: 줄 266:
             f'alpha: {alpha.item():.3f}, '             f'alpha: {alpha.item():.3f}, '
         )         )
-        # target 모델 교체 +        writer.add_scalar('metric/score', mean_score, global_step=frames) 
-        model_old.load_state_dict(model.state_dict())+        writer.add_scalar('metric/ret', D_ret.mean(), global_step=frames) 
 +        writer.add_scalar('metric/v', value.mean(), global_step=frames) 
 +        writer.add_scalar('metric/ent', entropy.item(), global_step=frames) 
 +        writer.add_scalar('loss/v', loss_v.item(), global_step=frames) 
 +        writer.add_scalar('loss/pi', loss_pi.item(), global_step=frames) 
 +        writer.add_scalar('loss/eta', eta.item(), global_step=frames) 
 +        writer.add_scalar('loss/alpha', alpha.item(), global_step=frames)
  
-        # 테스트 게임 플레이 +        if args.max_frames is not None and frames > args.max_frames
-        test_state = test_env.reset() +            writer.add_hparams(dict(algo='vmpo'), dict(final_score=mean_score)
-        score = 0 +            break 
-        while True+</code>
-            with torch.no_grad(): +
-                _, _, prob = model_old( +
-                    torch.from_numpy(test_state).float().view(1, -1) +
-                ) +
-            action prob.multinomial(1).item() +
-            test_stater, done, _ = test_env.step(action) +
-            score ++
-            test_env.render(+
-            if done: +
-                print(f'test score: {score:.3f}'+
-                score = 0 +
-                break+
  
-</code>+{{tag>example "VMPO" MPO PPO lunar_lander}}
v-mpo/example.1592136096.txt.gz · 마지막으로 수정됨: (바깥 편집)