사용자 도구

사이트 도구


v-mpo:example

차이

문서의 선택한 두 판 사이의 차이를 보여줍니다.

차이 보기로 링크

양쪽 이전 판이전 판
다음 판
이전 판
v-mpo:example [2020/10/13 17:10] rex8312v-mpo:example [2024/03/23 02:42] (현재) – 바깥 편집 127.0.0.1
줄 1: 줄 1:
-====== V-MPO: example ======+====== Example: V-MPO ======
  
 <code python vmpo_example.py> <code python vmpo_example.py>
줄 5: 줄 5:
  
 """ """
-conda install swig+conda create -n ppo python=3.7 numpy ipython matplotlib swig termcolor tqdm scipy tensorboard 
 +conda install pytorch torchvision cudatoolkit=10.2 -c pytorch
 pip install gym[box2d] pip install gym[box2d]
 pip install plotille pip install plotille
줄 17: 줄 18:
 import matplotlib.pyplot as plt import matplotlib.pyplot as plt
 import numpy as np import numpy as np
 +import plotille
 import torch import torch
 import torch.nn as nn import torch.nn as nn
줄 23: 줄 25:
 import tqdm import tqdm
 from IPython import embed from IPython import embed
- +from termcolor import cprint 
-import plotille+from torch.utils.tensorboard import SummaryWriter
  
  
 def parse_args(): def parse_args():
     parser = argparse.ArgumentParser()     parser = argparse.ArgumentParser()
-    parser.add_argument('--comment', type=str, default='baseline')+    parser.add_argument('--logdir', type=str
 +    parser.add_argument('--max_frames', type=int) 
 +    parser.add_argument('--eval'action='store_true')
     parser.add_argument('--seed', type=int, default=0)     parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--n_envs', type=int, default=8)     parser.add_argument('--n_envs', type=int, default=8)
줄 35: 줄 39:
     parser.add_argument('--mini_batch_size', type=int, default=128)     parser.add_argument('--mini_batch_size', type=int, default=128)
     parser.add_argument('--lr', type=float, default=0.0005)     parser.add_argument('--lr', type=float, default=0.0005)
 +    parser.add_argument('--weight_decay', type=float, default=0.0)
     parser.add_argument('--gamma', type=float, default=0.99)     parser.add_argument('--gamma', type=float, default=0.99)
     parser.add_argument('--lam', type=float, default=0.99)     parser.add_argument('--lam', type=float, default=0.99)
줄 40: 줄 45:
     parser.add_argument('--eps_eta', type=float, default=1.0)     parser.add_argument('--eps_eta', type=float, default=1.0)
     parser.add_argument('--eps_alpha', type=float, default=0.01)     parser.add_argument('--eps_alpha', type=float, default=0.01)
 +    parser.add_argument('--value_coef', type=float, default=0.5)
     return parser.parse_args()     return parser.parse_args()
  
줄 50: 줄 56:
         self.vf = nn.Linear(32, 1)         self.vf = nn.Linear(32, 1)
         self.policy = nn.Linear(32, n_actions)         self.policy = nn.Linear(32, n_actions)
 +
 +        self.apply(self._init_weights)
 +
 +    def _init_weights(self, module):
 +        if isinstance(module, (nn.Linear, nn.Embedding)):
 +            module.weight.data.normal_(mean=0.0, std=0.02)
 +            if isinstance(module, nn.Linear) and module.bias is not None:
 +                module.bias.data.zero_()
 +        elif isinstance(module, nn.LayerNorm):
 +            module.bias.data.zero_()
 +            module.weight.data.fill_(1.0)
  
     def forward(self, x):     def forward(self, x):
줄 62: 줄 79:
     np.random.seed(args.seed)     np.random.seed(args.seed)
     torch.manual_seed(args.seed)     torch.manual_seed(args.seed)
 +    writer = SummaryWriter(args.logdir)
  
     # 환경 생성     # 환경 생성
줄 75: 줄 93:
     eta = torch.tensor(1.0, requires_grad=True)     eta = torch.tensor(1.0, requires_grad=True)
     alpha = torch.tensor(5.0, requires_grad=True)     alpha = torch.tensor(5.0, requires_grad=True)
-    optimizer = optim.Adam([*model.parameters(), eta, alpha], lr=args.lr, weight_decay=0.01) +    optimizer = optim.Adam([*model.parameters(), eta, alpha], lr=args.lr, weight_decay=args.weight_decay)
- +
-    # 테스트 게임 시작 +
-    def run_test_game(): +
-        test_env = gym.make('LunarLander-v2'+
-        test_env.seed(args.seed) +
- +
-        while True: +
-            test_state = test_env.reset() +
-            score = 0 +
-            while True: +
-                with torch.no_grad(): +
-                    _, _, prob = model_old( +
-                        torch.from_numpy(test_state).float().view(1, -1) +
-                    ) +
-                action = prob.multinomial(1).item() +
-                test_state, r, done, _ = test_env.step(action) +
-                score += r +
-                test_env.render() +
-                if done: +
-                    score = 0 +
-                    break +
-     +
-    threading.Thread(target=run_test_game, daemon=True).start()+
  
     # 버퍼 생성     # 버퍼 생성
줄 106: 줄 101:
     D_done = np.zeros((args.horizon, args.n_envs))     D_done = np.zeros((args.horizon, args.n_envs))
     D_value = np.zeros((args.horizon, args.n_envs))     D_value = np.zeros((args.horizon, args.n_envs))
-    D_logp = np.zeros((args.horizon, args.n_envs, n_actions))+    D_logp_a = np.zeros((args.horizon, args.n_envs))
  
     # 학습 시작     # 학습 시작
줄 113: 줄 108:
     n_scores_in_epoch = 0     n_scores_in_epoch = 0
     scores = deque(maxlen=10000)     scores = deque(maxlen=10000)
 +    eval_scores = list()
 +
 +    # 평가 게임 시작
 +    def eval():
 +        env = gym.make('LunarLander-v2')
 +        env.seed(args.seed)
 +        while True:
 +            obs = env.reset()
 +            score = 0
 +            while True:
 +                with torch.no_grad():
 +                    _, _, prob = model_old(
 +                        torch.from_numpy(obs).float().view(1, -1)
 +                    )
 +                action = prob.multinomial(1).item()
 +                obs, r, done, _ = env.step(action)
 +                score += r
 +                env.render()
 +                if done:
 +                    eval_scores.append(score)
 +                    break
 +
 +    if args.eval:
 +        threading.Thread(target=eval, daemon=True).start()
  
     obs_prime = env.reset()     obs_prime = env.reset()
줄 138: 줄 157:
             D_done[D_i] = done             D_done[D_i] = done
             D_value[D_i] = value.view(-1).numpy()             D_value[D_i] = value.view(-1).numpy()
-            D_logp[D_i] = logp.numpy()+            D_logp_a[D_i] = logp.numpy()[range(args.n_envs), action]
  
             obs = obs_prime             obs = obs_prime
줄 167: 줄 186:
         FD_obs = D_obs.reshape(-1, n_features)         FD_obs = D_obs.reshape(-1, n_features)
         FD_action = D_action.reshape(-1)         FD_action = D_action.reshape(-1)
-        FD_logp D_logp.reshape(-1, n_actions)+        FD_logp_a D_logp_a.reshape(-1)
         FD_ret = D_ret.reshape(-1)         FD_ret = D_ret.reshape(-1)
         FD_adv = D_adv.reshape(-1)         FD_adv = D_adv.reshape(-1)
줄 184: 줄 203:
             ret = torch.from_numpy(FD_ret[sel]).float()             ret = torch.from_numpy(FD_ret[sel]).float()
             adv = torch.from_numpy(FD_adv[sel]).float()             adv = torch.from_numpy(FD_adv[sel]).float()
-            logp_old = torch.from_numpy(FD_logp[sel]).float()+            logp_a_old = torch.from_numpy(FD_logp_a[sel]).float()
             top_k = torch.from_numpy(FD_top_k[sel]).bool()             top_k = torch.from_numpy(FD_top_k[sel]).bool()
  
줄 190: 줄 209:
             value, logp, prob = model(obs)             value, logp, prob = model(obs)
             logp_a = logp.gather(1, action.view(-1, 1)).view(-1)             logp_a = logp.gather(1, action.view(-1, 1)).view(-1)
-            entropy = -(prob * logp).sum(1).mean().item()+            entropy = -(prob * logp).sum(-1).mean()
  
             # loss_v             # loss_v
-            loss_v = 0.5 * (ret - value.view(ret.shape)).pow(2).mean()+            loss_v = 0.5 * (ret - value.view(ret.shape)).pow(2).mean(
 +            loss_v = F.smooth_l1_loss(value, ret.view(value.shape))
             # loss_pi             # loss_pi
             with torch.no_grad():             with torch.no_grad():
                 aug_adv_max = (adv[top_k] / eta).max()                 aug_adv_max = (adv[top_k] / eta).max()
-                aug_adv = (adv[top_k] / eta - aug_adv_max).exp() / (adv[top_k] / eta - aug_adv_max).exp().sum() +                aug_adv = (adv[top_k] / eta - aug_adv_max).exp() #  / (adv[top_k] / eta - aug_adv_max).exp().sum() 
-            loss_pi = -(aug_adv * logp_a[top_k]).sum()+                norm_aug_adv = aug_adv / aug_adv.sum() 
 +            loss_pi = -(norm_aug_adv * logp_a[top_k]).sum()
             # loss_eta (dual func.)             # loss_eta (dual func.)
             loss_eta = eta * args.eps_eta + aug_adv_max + eta * (adv[top_k] / eta - aug_adv_max).exp().mean().log()             loss_eta = eta * args.eps_eta + aug_adv_max + eta * (adv[top_k] / eta - aug_adv_max).exp().mean().log()
             # loss_alpha             # loss_alpha
 +            prob_a_old, prob_a = logp_a_old.exp(), logp_a.exp()
             # kld = F.kl_div(logp_old.detach(), logp, reduction='batchmean')             # kld = F.kl_div(logp_old.detach(), logp, reduction='batchmean')
-            kld = (logp_old.exp() * (logp_old - logp) - logp_old.exp() + logp.exp()).sum(1) +            kld = (logp_old.exp() * (logp_old - logp) - logp_old.exp() + logp.exp()).sum(1)  
 +            kld = prob_a_old * (logp_a_old - logp_a) - prob_a_old + prob_a 
             loss_alpha = (alpha * (args.eps_alpha - kld.detach()) + alpha.detach() * kld).mean()             loss_alpha = (alpha * (args.eps_alpha - kld.detach()) + alpha.detach() * kld).mean()
             # total_loss             # total_loss
-            loss = loss_v + loss_pi + loss_eta + loss_alpha+            loss = loss_v * args.value_coef + loss_pi + loss_eta + loss_alpha
  
             optimizer.zero_grad()             optimizer.zero_grad()
줄 214: 줄 237:
  
             with torch.no_grad():             with torch.no_grad():
-                eta.fill_(eta.clamp(min=1e-6, max=1e6)) +                eta와 alpha는 반드시 0 이상 
-                alpha.fill_(alpha.clamp(min=1e-6, max=1e6))+                eta.data.copy_(eta.clamp(min=1e-6, max=1e6)) 
 +                alpha.data.copy_(alpha.clamp(min=1e-6, max=1e6)) 
 + 
 +        # target 모델 교체 
 +        model_old.load_state_dict(model.state_dict())
  
         # 학습결과 출력         # 학습결과 출력
줄 221: 줄 248:
         mean_score = np.mean(scores_y[-n_scores_in_epoch:])         mean_score = np.mean(scores_y[-n_scores_in_epoch:])
         print(plotille.scatter(scores_x, scores_y, height=25, color_mode='byte'))         print(plotille.scatter(scores_x, scores_y, height=25, color_mode='byte'))
-        print(plotille.histogram(scores_y[-n_scores_in_epoch:], bins=n_scores_in_epoch//5, height=25, color_mode='byte'))+        print(plotille.histogram( 
 +            scores_y[-n_scores_in_epoch:],  
 +            bins=n_scores_in_epoch // 5, height=25, color_mode='byte' 
 +        ))
         n_scores_in_epoch = 0         n_scores_in_epoch = 0
         print(         print(
줄 227: 줄 257:
             f'ret: {D_ret.mean():.3f}, '             f'ret: {D_ret.mean():.3f}, '
             f'v: {value.mean():.3f}, '             f'v: {value.mean():.3f}, '
-            f'ent.: {entropy:.3f}, '+            f'ent.: {entropy.item():.3f}, '
             f'kld: {kld.mean():.3f}, '             f'kld: {kld.mean():.3f}, '
             f'L_v: {loss_v.item():.3f}, '             f'L_v: {loss_v.item():.3f}, '
줄 236: 줄 266:
             f'alpha: {alpha.item():.3f}, '             f'alpha: {alpha.item():.3f}, '
         )         )
-        # target 모델 교체 +        writer.add_scalar('metric/score', mean_score, global_step=frames) 
-        model_old.load_state_dict(model.state_dict())+        writer.add_scalar('metric/ret', D_ret.mean(), global_step=frames) 
 +        writer.add_scalar('metric/v', value.mean(), global_step=frames) 
 +        writer.add_scalar('metric/ent', entropy.item(), global_step=frames) 
 +        writer.add_scalar('loss/v', loss_v.item(), global_step=frames) 
 +        writer.add_scalar('loss/pi', loss_pi.item(), global_step=frames) 
 +        writer.add_scalar('loss/eta', eta.item(), global_step=frames) 
 +        writer.add_scalar('loss/alpha', alpha.item(), global_step=frames) 
 + 
 +        if args.max_frames is not None and frames > args.max_frames: 
 +            writer.add_hparams(dict(algo='vmpo'), dict(final_score=mean_score)) 
 +            break
 </code> </code>
  
-{{tag>example "VMPO" MPO}}+{{tag>example "VMPO" MPO PPO lunar_lander}}
v-mpo/example.1602609028.txt.gz · 마지막으로 수정됨: (바깥 편집)