v-mpo:example
차이
문서의 선택한 두 판 사이의 차이를 보여줍니다.
| 양쪽 이전 판 · 이전 판 · 다음 판 | 이전 판 | ||
| v-mpo:example [2020/06/15 12:40] – rex8312 | v-mpo:example [2024/03/23 02:42] (현재) – 바깥 편집 127.0.0.1 | ||
|---|---|---|---|
| 줄 1: | 줄 1: | ||
| - | ====== V-MPO: example | + | ====== |
| <code python vmpo_example.py> | <code python vmpo_example.py> | ||
| 줄 5: | 줄 5: | ||
| """ | """ | ||
| + | conda create -n ppo python=3.7 numpy ipython matplotlib swig termcolor tqdm scipy tensorboard | ||
| + | conda install pytorch torchvision cudatoolkit=10.2 -c pytorch | ||
| pip install gym[box2d] | pip install gym[box2d] | ||
| pip install plotille | pip install plotille | ||
| 줄 10: | 줄 12: | ||
| import argparse | import argparse | ||
| + | import threading | ||
| from collections import deque | from collections import deque | ||
| 줄 15: | 줄 18: | ||
| import matplotlib.pyplot as plt | import matplotlib.pyplot as plt | ||
| import numpy as np | import numpy as np | ||
| + | import plotille | ||
| import torch | import torch | ||
| import torch.nn as nn | import torch.nn as nn | ||
| 줄 21: | 줄 25: | ||
| import tqdm | import tqdm | ||
| from IPython import embed | from IPython import embed | ||
| - | + | from termcolor import cprint | |
| - | import | + | from torch.utils.tensorboard |
| def parse_args(): | def parse_args(): | ||
| parser = argparse.ArgumentParser() | parser = argparse.ArgumentParser() | ||
| - | parser.add_argument(' | + | parser.add_argument(' |
| + | parser.add_argument(' | ||
| + | parser.add_argument(' | ||
| parser.add_argument(' | parser.add_argument(' | ||
| parser.add_argument(' | parser.add_argument(' | ||
| - | parser.add_argument(' | + | parser.add_argument(' |
| - | parser.add_argument(' | + | parser.add_argument(' |
| - | parser.add_argument(' | + | parser.add_argument(' |
| - | parser.add_argument(' | + | parser.add_argument(' |
| + | parser.add_argument(' | ||
| parser.add_argument(' | parser.add_argument(' | ||
| parser.add_argument(' | parser.add_argument(' | ||
| - | parser.add_argument(' | + | parser.add_argument(' |
| parser.add_argument(' | parser.add_argument(' | ||
| + | parser.add_argument(' | ||
| return parser.parse_args() | return parser.parse_args() | ||
| 줄 44: | 줄 52: | ||
| def __init__(self, | def __init__(self, | ||
| super().__init__() | super().__init__() | ||
| - | self.fc1 = nn.Linear(n_features, | + | self.fc1 = nn.Linear(n_features, |
| - | self.norm1 = nn.LayerNorm(64) | + | self.norm1 = nn.LayerNorm(32) |
| - | self.vf = nn.Linear(64, 1) | + | self.vf = nn.Linear(32, 1) |
| - | self.policy = nn.Linear(64, n_actions) | + | self.policy = nn.Linear(32, n_actions) |
| + | |||
| + | self.apply(self._init_weights) | ||
| + | |||
| + | def _init_weights(self, | ||
| + | if isinstance(module, | ||
| + | module.weight.data.normal_(mean=0.0, | ||
| + | if isinstance(module, | ||
| + | module.bias.data.zero_() | ||
| + | elif isinstance(module, | ||
| + | module.bias.data.zero_() | ||
| + | module.weight.data.fill_(1.0) | ||
| def forward(self, | def forward(self, | ||
| 줄 60: | 줄 79: | ||
| np.random.seed(args.seed) | np.random.seed(args.seed) | ||
| torch.manual_seed(args.seed) | torch.manual_seed(args.seed) | ||
| + | writer = SummaryWriter(args.logdir) | ||
| # 환경 생성 | # 환경 생성 | ||
| 줄 66: | 줄 86: | ||
| n_features = env.observation_space.shape[1] | n_features = env.observation_space.shape[1] | ||
| n_actions = env.action_space[0].n | n_actions = env.action_space[0].n | ||
| - | |||
| - | test_env = gym.make(' | ||
| - | test_env.seed(args.seed) | ||
| # 모델 & 옵티마이저 생성 | # 모델 & 옵티마이저 생성 | ||
| 줄 76: | 줄 93: | ||
| eta = torch.tensor(1.0, | eta = torch.tensor(1.0, | ||
| alpha = torch.tensor(5.0, | alpha = torch.tensor(5.0, | ||
| - | optimizer = optim.Adam([*model.parameters(), | + | optimizer = optim.Adam([*model.parameters(), |
| # 버퍼 생성 | # 버퍼 생성 | ||
| 줄 84: | 줄 101: | ||
| D_done = np.zeros((args.horizon, | D_done = np.zeros((args.horizon, | ||
| D_value = np.zeros((args.horizon, | D_value = np.zeros((args.horizon, | ||
| - | | + | |
| # 학습 시작 | # 학습 시작 | ||
| 줄 91: | 줄 108: | ||
| n_scores_in_epoch = 0 | n_scores_in_epoch = 0 | ||
| scores = deque(maxlen=10000) | scores = deque(maxlen=10000) | ||
| + | eval_scores = list() | ||
| + | |||
| + | # 평가 게임 시작 | ||
| + | def eval(): | ||
| + | env = gym.make(' | ||
| + | env.seed(args.seed) | ||
| + | while True: | ||
| + | obs = env.reset() | ||
| + | score = 0 | ||
| + | while True: | ||
| + | with torch.no_grad(): | ||
| + | _, _, prob = model_old( | ||
| + | torch.from_numpy(obs).float().view(1, | ||
| + | ) | ||
| + | action = prob.multinomial(1).item() | ||
| + | obs, r, done, _ = env.step(action) | ||
| + | score += r | ||
| + | env.render() | ||
| + | if done: | ||
| + | eval_scores.append(score) | ||
| + | break | ||
| + | |||
| + | if args.eval: | ||
| + | threading.Thread(target=eval, | ||
| obs_prime = env.reset() | obs_prime = env.reset() | ||
| 줄 116: | 줄 157: | ||
| D_done[D_i] = done | D_done[D_i] = done | ||
| D_value[D_i] = value.view(-1).numpy() | D_value[D_i] = value.view(-1).numpy() | ||
| - | | + | |
| obs = obs_prime | obs = obs_prime | ||
| 줄 145: | 줄 186: | ||
| FD_obs = D_obs.reshape(-1, | FD_obs = D_obs.reshape(-1, | ||
| FD_action = D_action.reshape(-1) | FD_action = D_action.reshape(-1) | ||
| - | | + | |
| FD_ret = D_ret.reshape(-1) | FD_ret = D_ret.reshape(-1) | ||
| FD_adv = D_adv.reshape(-1) | FD_adv = D_adv.reshape(-1) | ||
| 줄 162: | 줄 203: | ||
| ret = torch.from_numpy(FD_ret[sel]).float() | ret = torch.from_numpy(FD_ret[sel]).float() | ||
| adv = torch.from_numpy(FD_adv[sel]).float() | adv = torch.from_numpy(FD_adv[sel]).float() | ||
| - | | + | |
| top_k = torch.from_numpy(FD_top_k[sel]).bool() | top_k = torch.from_numpy(FD_top_k[sel]).bool() | ||
| 줄 168: | 줄 209: | ||
| value, logp, prob = model(obs) | value, logp, prob = model(obs) | ||
| logp_a = logp.gather(1, | logp_a = logp.gather(1, | ||
| - | entropy = -(prob * logp).sum(1).mean().item() | + | entropy = -(prob * logp).sum(-1).mean() |
| # loss_v | # loss_v | ||
| - | loss_v = 0.5 * (ret - value.view(ret.shape)).pow(2).mean() | + | |
| + | loss_v = F.smooth_l1_loss(value, | ||
| # loss_pi | # loss_pi | ||
| with torch.no_grad(): | with torch.no_grad(): | ||
| - | aug_adv = (adv[top_k] / eta).exp() / (adv[top_k] / eta).exp().sum() | + | |
| - | loss_pi = -(aug_adv | + | |
| - | # loss_eta | + | norm_aug_adv = aug_adv / aug_adv.sum() |
| - | loss_eta = eta * args.eps_eta + eta * (adv[top_k] / eta).exp().mean().log() | + | loss_pi = -(norm_aug_adv |
| + | # loss_eta | ||
| + | loss_eta = eta * args.eps_eta | ||
| # loss_alpha | # loss_alpha | ||
| + | prob_a_old, prob_a = logp_a_old.exp(), | ||
| # kld = F.kl_div(logp_old.detach(), | # kld = F.kl_div(logp_old.detach(), | ||
| - | kld = (logp_old.exp() * (logp_old - logp) - logp_old.exp() + logp.exp()).sum(1) | + | |
| + | kld = prob_a_old * (logp_a_old - logp_a) - prob_a_old + prob_a | ||
| loss_alpha = (alpha * (args.eps_alpha - kld.detach()) + alpha.detach() * kld).mean() | loss_alpha = (alpha * (args.eps_alpha - kld.detach()) + alpha.detach() * kld).mean() | ||
| # total_loss | # total_loss | ||
| - | loss = loss_v + loss_pi + loss_eta + loss_alpha | + | loss = loss_v |
| optimizer.zero_grad() | optimizer.zero_grad() | ||
| 줄 191: | 줄 237: | ||
| with torch.no_grad(): | with torch.no_grad(): | ||
| - | eta.fill_(eta.clamp(min=1e-6, | + | |
| - | alpha.fill_(alpha.clamp(min=1e-6, | + | eta.data.copy_(eta.clamp(min=1e-6, |
| + | alpha.data.copy_(alpha.clamp(min=1e-6, | ||
| + | |||
| + | # target 모델 교체 | ||
| + | model_old.load_state_dict(model.state_dict()) | ||
| # 학습결과 출력 | # 학습결과 출력 | ||
| 줄 198: | 줄 248: | ||
| mean_score = np.mean(scores_y[-n_scores_in_epoch: | mean_score = np.mean(scores_y[-n_scores_in_epoch: | ||
| print(plotille.scatter(scores_x, | print(plotille.scatter(scores_x, | ||
| - | print(plotille.histogram(scores_y[-n_scores_in_epoch: | + | print(plotille.histogram( |
| + | | ||
| + | | ||
| + | | ||
| n_scores_in_epoch = 0 | n_scores_in_epoch = 0 | ||
| print( | print( | ||
| 줄 204: | 줄 257: | ||
| f'ret: {D_ret.mean(): | f'ret: {D_ret.mean(): | ||
| f'v: {value.mean(): | f'v: {value.mean(): | ||
| - | f' | + | f' |
| f'kld: {kld.mean(): | f'kld: {kld.mean(): | ||
| f'L_v: {loss_v.item(): | f'L_v: {loss_v.item(): | ||
| 줄 213: | 줄 266: | ||
| f' | f' | ||
| ) | ) | ||
| - | | + | |
| - | | + | writer.add_scalar(' |
| + | writer.add_scalar(' | ||
| + | writer.add_scalar(' | ||
| + | writer.add_scalar(' | ||
| + | writer.add_scalar(' | ||
| + | writer.add_scalar(' | ||
| + | | ||
| - | | + | |
| - | test_state = test_env.reset() | + | |
| - | score = 0 | + | |
| - | while True: | + | </ |
| - | | + | |
| - | _, _, prob = model_old( | + | |
| - | torch.from_numpy(test_state).float().view(1, | + | |
| - | ) | + | |
| - | action | + | |
| - | test_state, r, done, _ = test_env.step(action) | + | |
| - | score += r | + | |
| - | test_env.render() | + | |
| - | | + | |
| - | | + | |
| - | score = 0 | + | |
| - | break | + | |
| - | </code> | + | {{tag>example " |
v-mpo/example.1592224833.txt.gz · 마지막으로 수정됨: (바깥 편집)