====== V-MPO: example ======

<code python vmpo_example.py>
""" | """ | ||
+ | conda create -n ppo python=3.7 numpy ipython matplotlib swig termcolor tqdm scipy tensorboard | ||
+ | conda install pytorch torchvision cudatoolkit=10.2 -c pytorch | ||
pip install gym[box2d] | pip install gym[box2d] | ||
pip install plotille | pip install plotille | ||
줄 10: | 줄 12: | ||
import argparse
import threading
from collections import deque

import gym
import matplotlib.pyplot as plt
import numpy as np
import plotille
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
from IPython import embed
from termcolor import cprint
from torch.utils.tensorboard import SummaryWriter

def parse_args():
    parser = argparse.ArgumentParser()
    # NOTE: most option names and all default values are truncated in the
    # wiki diff; only the arguments referenced later in this file are
    # reconstructed here, and the defaults are illustrative
    parser.add_argument('--logdir', type=str, default='runs/vmpo')
    parser.add_argument('--eval', action='store_true')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--n-envs', type=int, default=8)
    parser.add_argument('--horizon', type=int, default=128)
    parser.add_argument('--eps-eta', type=float, default=0.01)
    parser.add_argument('--eps-alpha', type=float, default=0.01)
    # ... (remaining arguments elided in diff) ...
    return parser.parse_args()

# ... (elided in diff: policy/value network class up to its output heads) ...
        # head output sizes are truncated in the diff; 1 and n_actions are
        # the natural completions for a value head and a policy head
        self.vf = nn.Linear(32, 1)
        self.policy = nn.Linear(32, n_actions)

        self.apply(self._init_weights)

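    # recursive initializer applied to every submodule: normally distributed
    # weights for affine layers, zero biases, identity scale for LayerNorm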
    def _init_weights(self, module):
        # the isinstance targets and std are truncated in the diff; this is
        # the conventional reconstruction of the pattern
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, x):
        # ... (elided in diff) ...

# ... (elided in diff: training entry point, argument parsing) ...
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    writer = SummaryWriter(args.logdir)

    # create the environment
    # ... (elided in diff: vectorized env construction) ...
    n_features = env.observation_space.shape[1]
    n_actions = env.action_space[0].n

    # create the model & optimizer
    # ... (elided in diff: model and model_old construction) ...
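    # eta (temperature) and alpha (KL multiplier) are the two Lagrange
    # multipliers of V-MPO's constrained update; they are ordinary scalar
    # tensors optimized jointly with the network parameters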
    eta = torch.tensor(1.0, requires_grad=True)
    alpha = torch.tensor(5.0, requires_grad=True)
    # the learning rate is truncated in the diff, so Adam's default is kept
    optimizer = optim.Adam([*model.parameters(), eta, alpha])

    # create the buffers
    # ... (elided in diff: D_obs, D_action, ...) ...
    D_done = np.zeros((args.horizon, args.n_envs))
    D_value = np.zeros((args.horizon, args.n_envs))
    # this buffer's name is truncated in the diff; D_logp_a (log-prob of the
    # taken action) is inferred from FD_logp_a_old below
    D_logp_a = np.zeros((args.horizon, args.n_envs))

    # start training
    # ... (elided in diff) ...
    n_scores_in_epoch = 0
    scores = deque(maxlen=10000)
    eval_scores = list()

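    # evaluation runs in a separate background thread so rendering does not
    # block training; it always acts with the frozen target policy model_old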
    # start the evaluation game
    def eval():
        # the env id is truncated in the diff; LunarLander-v2 matches the
        # gym[box2d] dependency and the discrete action space
        env = gym.make('LunarLander-v2')
        env.seed(args.seed)
        while True:
            obs = env.reset()
            score = 0
            while True:
                with torch.no_grad():
                    _, _, prob = model_old(
                        torch.from_numpy(obs).float().view(1, -1)
                    )
                action = prob.multinomial(1).item()
                obs, r, done, _ = env.step(action)
                score += r
                env.render()
                if done:
                    eval_scores.append(score)
                    break

    if args.eval:
        # daemon=True is assumed; the Thread arguments are truncated in the diff
        threading.Thread(target=eval, daemon=True).start()

    obs_prime = env.reset()
    # ... (elided in diff: rollout loop filling the D_* buffers) ...
            D_done[D_i] = done
            D_value[D_i] = value.view(-1).numpy()
            # the stored quantity is truncated in the diff; the taken action's
            # log-prob matches the use of FD_logp_a_old below
            D_logp_a[D_i] = logp_a.view(-1).numpy()
            obs = obs_prime

    # ... (elided in diff: return (D_ret) and advantage (D_adv) computation) ...
        FD_obs = D_obs.reshape(-1, n_features)
        FD_action = D_action.reshape(-1)
        FD_logp_a_old = D_logp_a.reshape(-1)  # name truncated in the diff
        FD_ret = D_ret.reshape(-1)
        FD_adv = D_adv.reshape(-1)

        # ... (elided in diff: FD_top_k computation and minibatch sampling) ...
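            # E-step sample selection: as in the V-MPO paper, only the half
            # of the batch with the highest advantages (top_k) is used for
            # the policy and temperature losses below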
            ret = torch.from_numpy(FD_ret[sel]).float()
            adv = torch.from_numpy(FD_adv[sel]).float()
            logp_a_old = torch.from_numpy(FD_logp_a_old[sel]).float()  # truncated in the diff
            top_k = torch.from_numpy(FD_top_k[sel]).bool()

            # ... (elided in diff: obs / action minibatch tensors) ...
            value, logp, prob = model(obs)
            logp_a = logp.gather(1, action.view(-1, 1)).view(-1)  # gather arg truncated in the diff
            entropy = -(prob * logp).sum(-1).mean()
            # loss_v: Huber loss between values and returns (args truncated in the diff)
            loss_v = F.smooth_l1_loss(value.view(ret.shape), ret)
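            # M-step policy loss: maximum likelihood toward the nonparametric
            # target psi(a|s) proportional to exp(adv / eta); the softmax
            # weights are computed under no_grad so eta is trained only
            # through its own dual loss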
            # loss_pi
            with torch.no_grad():
                aug_adv = (adv[top_k] / eta).exp()
                norm_aug_adv = aug_adv / aug_adv.sum()
            loss_pi = -(norm_aug_adv * logp_a[top_k]).sum()
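            # the temperature eta is trained through its convex dual:
            # eta * eps_eta + eta * log(mean(exp(adv / eta))) lowers eta
            # until the weighting above uses up the KL budget eps_eta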
            # loss_eta
            loss_eta = eta * args.eps_eta + eta * (adv[top_k] / eta).exp().mean().log()  # tail truncated in the diff
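            # the trust-region multiplier alpha is handled the same way; the
            # two detach() calls below decouple the updates so alpha descends
            # the constraint violation while the policy descends alpha * KL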
            # loss_alpha
            prob_a_old, prob_a = logp_a_old.exp(), logp_a.exp()
            # kld = F.kl_div(logp_old.detach(), ...)
            kld = prob_a_old * (logp_a_old - logp_a) - prob_a_old + prob_a
            loss_alpha = (alpha * (args.eps_alpha - kld.detach()) + alpha.detach() * kld).mean()
            # total_loss
            loss = loss_v + loss_pi + loss_eta + loss_alpha  # tail truncated in the diff
            optimizer.zero_grad()
            # ... (elided in diff: backward pass and optimizer step) ...
            with torch.no_grad():
                # keep the multipliers positive; the upper bounds are
                # truncated in the diff, so 1e6 is illustrative
                eta.data.copy_(eta.clamp(min=1e-6, max=1e6))
                alpha.data.copy_(alpha.clamp(min=1e-6, max=1e6))

        # swap in the new target model
        model_old.load_state_dict(model.state_dict())
        # print training results
        # ... (elided in diff) ...
        mean_score = np.mean(scores_y[-n_scores_in_epoch:])
        # the plot options are truncated in the diff
        print(plotille.scatter(scores_x, scores_y))
        print(plotille.histogram(
            scores_y[-n_scores_in_epoch:],
            # ... (options elided in diff) ...
        ))
        n_scores_in_epoch = 0
        print(
            # ... (leading fields elided in diff; the format specifiers are
            # truncated, so .3f is illustrative) ...
            f'ret: {D_ret.mean():.3f}, '
            f'v: {value.mean():.3f}, '
            f'ent: {entropy.item():.3f}, '
            f'kld: {kld.mean():.3f}, '
            f'L_v: {loss_v.item():.3f}, '
            # ... (trailing fields elided in diff) ...
        )

        # TensorBoard logging; the scalar tags and the step value are
        # truncated in the diff, so the names below (and `epoch` as the
        # global step) are illustrative
        writer.add_scalar('score', mean_score, epoch)
        writer.add_scalar('loss_v', loss_v.item(), epoch)
        writer.add_scalar('loss_pi', loss_pi.item(), epoch)
        writer.add_scalar('loss_eta', loss_eta.item(), epoch)
        writer.add_scalar('loss_alpha', loss_alpha.item(), epoch)
        writer.add_scalar('entropy', entropy.item(), epoch)

</code>
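
The four loss terms can be sanity-checked in isolation. The following sketch is not part of the original script: it recomputes the V-MPO losses on made-up tensors (the batch size, action count, and eps values are arbitrary) so the shapes and the gradient flow into the two multipliers can be inspected.

<code python>
import torch

torch.manual_seed(0)
eps_eta, eps_alpha = 0.01, 0.01
eta = torch.tensor(1.0, requires_grad=True)
alpha = torch.tensor(5.0, requires_grad=True)

# dummy minibatch: 64 samples, 4 discrete actions
adv = torch.randn(64)
action = torch.randint(0, 4, (64, 1))
logp_a_old = torch.randn(64, 4).log_softmax(-1).gather(1, action).view(-1)
logp_a = torch.randn(64, 4).log_softmax(-1).gather(1, action).view(-1)

# E-step: keep the better half of the samples
top_k = adv >= adv.median()

# policy loss: fixed softmax weights over the top-k advantages
with torch.no_grad():
    aug_adv = (adv[top_k] / eta).exp()
    norm_aug_adv = aug_adv / aug_adv.sum()
loss_pi = -(norm_aug_adv * logp_a[top_k]).sum()

# temperature dual
loss_eta = eta * eps_eta + eta * (adv[top_k] / eta).exp().mean().log()

# decoupled KL penalty with a per-sample KL estimate
prob_a_old, prob_a = logp_a_old.exp(), logp_a.exp()
kld = prob_a_old * (logp_a_old - logp_a) - prob_a_old + prob_a
loss_alpha = (alpha * (eps_alpha - kld.detach()) + alpha.detach() * kld).mean()

(loss_pi + loss_eta + loss_alpha).backward()
print(eta.grad, alpha.grad)  # both multipliers receive gradients
</code>

Because the softmax weights in loss_pi are computed under no_grad, eta receives gradients only from loss_eta, and alpha only from the first term of loss_alpha; this decoupling is what lets both constraints be enforced by plain gradient descent.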
{{tag>example "…"}}