Example: V-MPO

vmpo_example.py
#!/usr/bin/env python
 
"""
conda create -n ppo python=3.7 numpy ipython matplotlib swig termcolor tqdm scipy tensorboard
conda install pytorch torchvision cudatoolkit=10.2 -c pytorch
pip install gym[box2d]
pip install plotille
"""
 
import argparse
import threading
from collections import deque
 
import gym
import matplotlib.pyplot as plt
import numpy as np
import plotille
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
from IPython import embed
from termcolor import cprint
from torch.utils.tensorboard import SummaryWriter
 
 
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--logdir', type=str)
    parser.add_argument('--max_frames', type=int)
    parser.add_argument('--eval', action='store_true')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--n_envs', type=int, default=8)
    parser.add_argument('--horizon', type=int, default=2048)
    parser.add_argument('--mini_batch_size', type=int, default=128)
    parser.add_argument('--lr', type=float, default=0.0005)
    parser.add_argument('--weight_decay', type=float, default=0.0)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--lam', type=float, default=0.99)
    parser.add_argument('--max_grad', type=float, default=10.0)
    parser.add_argument('--eps_eta', type=float, default=1.0)
    parser.add_argument('--eps_alpha', type=float, default=0.01)
    parser.add_argument('--value_coef', type=float, default=0.5)
    return parser.parse_args()
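

# Example invocation (the log directory is illustrative):
#   python vmpo_example.py --logdir runs/vmpo --max_frames 2000000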
 
 
class Model(nn.Module):
    def __init__(self, n_features, n_actions):
        super().__init__()
        self.fc1 = nn.Linear(n_features, 32)
        self.norm1 = nn.LayerNorm(32)
        self.vf = nn.Linear(32, 1)
        self.policy = nn.Linear(32, n_actions)
 
        self.apply(self._init_weights)
 
    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
 
    def forward(self, x):
        x = torch.relu(self.norm1(self.fc1(x)))
        logit = self.policy(x)
        return self.vf(x), F.log_softmax(logit, -1), F.softmax(logit, -1)
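

# The model is a small shared torso (Linear -> LayerNorm -> ReLU, 32 units)
# with two heads: a scalar state-value estimate and a categorical policy.
# forward() returns (value, log-probabilities, probabilities).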
 
 
if __name__ == '__main__':
 
    args = parse_args()
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    writer = SummaryWriter(args.logdir)
 
    # Create the vectorized environments
    env = gym.vector.make('LunarLander-v2', num_envs=args.n_envs)
    env.seed(args.seed)
    n_features = env.observation_space.shape[1]
    n_actions = env.action_space[0].n
 
    # Create the model & optimizer
    model = Model(n_features, n_actions)
    model_old = Model(n_features, n_actions)
    model_old.load_state_dict(model.state_dict())
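    # eta (temperature of the non-parametric target distribution) and alpha
    # (coefficient of the policy KL penalty) are the Lagrange multipliers of
    # V-MPO; they are optimized jointly with the network parameters.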
    eta = torch.tensor(1.0, requires_grad=True)
    alpha = torch.tensor(5.0, requires_grad=True)
    optimizer = optim.Adam([*model.parameters(), eta, alpha], lr=args.lr, weight_decay=args.weight_decay)
 
    # Create the rollout buffers
    D_obs = np.zeros((args.horizon, args.n_envs, n_features))
    D_action = np.zeros((args.horizon, args.n_envs))
    D_reward = np.zeros((args.horizon, args.n_envs))
    D_done = np.zeros((args.horizon, args.n_envs))
    D_value = np.zeros((args.horizon, args.n_envs))
    D_logp_a = np.zeros((args.horizon, args.n_envs))
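    # One on-policy rollout of `horizon` steps from all `n_envs` environments
    # is collected with model_old, then flattened and fit in mini-batches.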
 
    # Start training
    frames = 0
    score = np.zeros(args.n_envs)
    n_scores_in_epoch = 0
    scores = deque(maxlen=10000)
    eval_scores = list()
 
    # Evaluation: play and render episodes with model_old in a background thread
    def eval():
        env = gym.make('LunarLander-v2')
        env.seed(args.seed)
        while True:
            obs = env.reset()
            score = 0
            while True:
                with torch.no_grad():
                    _, _, prob = model_old(
                        torch.from_numpy(obs).float().view(1, -1)
                    )
                action = prob.multinomial(1).item()
                obs, r, done, _ = env.step(action)
                score += r
                env.render()
                if done:
                    eval_scores.append(score)
                    break
 
    if args.eval:
        threading.Thread(target=eval, daemon=True).start()
 
    obs_prime = env.reset()
 
    while True:
 
        obs = obs_prime
        for D_i in tqdm.trange(args.horizon, desc='Rollout'):
            # Play the game & collect data
            with torch.no_grad():
                value, logp, prob = model_old(torch.from_numpy(obs).float())
                action = prob.multinomial(num_samples=1).numpy().reshape(-1)
            obs_prime, reward, done, info = env.step(action)
 
            # Record episode scores
            score += reward
            scores.extend([(frames, s) for s in score[done]])
            score[done] = 0
            n_scores_in_epoch += done.sum()
 
            # Store the transition
            D_obs[D_i] = obs
            D_action[D_i] = action
            D_reward[D_i] = reward / 100.  # scale rewards (LunarLander rewards are O(100))
            D_done[D_i] = done
            D_value[D_i] = value.view(-1).numpy()
            D_logp_a[D_i] = logp.numpy()[range(args.n_envs), action]
 
            obs = obs_prime
            frames += args.n_envs
 
        # Rollout collection finished
        # Per-step discount, zeroed at episode boundaries
        gamma = args.gamma * (1 - D_done)
        # Compute bootstrapped n-step returns
        D_ret = np.zeros((args.horizon + 1, args.n_envs))
        with torch.no_grad():
            # Bootstrap from the state following the last stored transition
            value, _, _ = model_old(torch.from_numpy(obs_prime).float())
        D_ret[-1] = value.view(-1).numpy()
        for t in reversed(range(args.horizon)):
            D_ret[t] = D_reward[t] + gamma[t] * D_ret[t+1]
        D_ret = D_ret[:-1]
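        # Generalized Advantage Estimation:
        #   delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
        #   A_t     = delta_t + gamma * lambda * A_{t+1}
        # with gamma zeroed at episode boundaries by the done mask above.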
        # Compute advantages with GAE
        value_ = np.vstack([D_value, value.numpy().transpose(1, 0)])
        delta = D_reward + gamma * value_[1:] - value_[:-1]
        D_adv = np.zeros((args.horizon, args.n_envs))
        gae = 0
        for t in reversed(range(args.horizon)):
            gae = gae * gamma[t] * args.lam + delta[t]
            D_adv[t] = gae
 
        # Flatten the (horizon, n_envs) dimensions
        FD_obs = D_obs.reshape(-1, n_features)
        FD_action = D_action.reshape(-1)
        FD_logp_a = D_logp_a.reshape(-1)
        FD_ret = D_ret.reshape(-1)
        FD_adv = D_adv.reshape(-1)
        # top_k: positions of the top-50% (above-median) advantages
        FD_top_k = FD_adv > np.median(FD_adv)
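        # V-MPO fits the policy and the temperature eta only on the samples
        # whose advantage is above the batch median (the best half of the data).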
 
        # Prepare shuffled mini-batch indices
        idx = np.arange(args.horizon * args.n_envs)
        np.random.shuffle(idx)
        n_mini_batches = args.horizon * args.n_envs // args.mini_batch_size
        for mb_i in tqdm.trange(n_mini_batches, desc='Fit'):
            # Build the mini-batch
            sel = idx[mb_i * args.mini_batch_size: (mb_i+1) * args.mini_batch_size]
            obs = torch.from_numpy(FD_obs[sel]).float()
            action = torch.from_numpy(FD_action[sel]).long()
            ret = torch.from_numpy(FD_ret[sel]).float()
            adv = torch.from_numpy(FD_adv[sel]).float()
            logp_a_old = torch.from_numpy(FD_logp_a[sel]).float()
            top_k = torch.from_numpy(FD_top_k[sel]).bool()
 
            # Forward pass (build the computation graph)
            value, logp, prob = model(obs)
            logp_a = logp.gather(1, action.view(-1, 1)).view(-1)
            entropy = -(prob * logp).sum(-1).mean()
 
            # loss_v
            # loss_v = 0.5 * (ret - value.view(ret.shape)).pow(2).mean()
            loss_v = F.smooth_l1_loss(value, ret.view(value.shape))
            # loss_pi
            with torch.no_grad():
                aug_adv_max = (adv[top_k] / eta).max()
                aug_adv = (adv[top_k] / eta - aug_adv_max).exp() #  / (adv[top_k] / eta - aug_adv_max).exp().sum()
                norm_aug_adv = aug_adv / aug_adv.sum()
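            # psi(s, a) = exp(A / eta) / sum(exp(A / eta)) over the top-half
            # samples; the policy loss below is the psi-weighted negative
            # log-likelihood (weighted maximum likelihood).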
            loss_pi = -(norm_aug_adv * logp_a[top_k]).sum()
            # loss_eta: temperature dual, L_eta = eta * eps_eta + eta * log(mean(exp(A / eta))),
            # written in log-sum-exp form for numerical stability
            loss_eta = eta * args.eps_eta + eta * (aug_adv_max + (adv[top_k] / eta - aug_adv_max).exp().mean().log())
            # loss_alpha
            prob_a_old, prob_a = logp_a_old.exp(), logp_a.exp()
            # kld = F.kl_div(logp_old.detach(), logp, reduction='batchmean')
            # kld = (logp_old.exp() * (logp_old - logp) - logp_old.exp() + logp.exp()).sum(1) 
            kld = prob_a_old * (logp_a_old - logp_a) - prob_a_old + prob_a 
            loss_alpha = (alpha * (args.eps_alpha - kld.detach()) + alpha.detach() * kld).mean()
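            # Lagrangian coordinate update for alpha: the first term moves alpha
            # toward enforcing KL <= eps_alpha (alpha grows when the measured KL
            # exceeds the bound), the second penalizes the KL with alpha fixed.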
            # total_loss
            loss = loss_v * args.value_coef + loss_pi + loss_eta + loss_alpha
 
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_([*model.parameters(), eta, alpha], args.max_grad)
            optimizer.step()
 
            with torch.no_grad():
                # eta and alpha must remain strictly positive
                eta.data.copy_(eta.clamp(min=1e-6, max=1e6))
                alpha.data.copy_(alpha.clamp(min=1e-6, max=1e6))
 
        # Update the old (behaviour/target-policy) model
        model_old.load_state_dict(model.state_dict())
 
        # Print / log training results
        scores_x, scores_y = zip(*scores)
        mean_score = np.mean(scores_y[-n_scores_in_epoch:])
        print(plotille.scatter(scores_x, scores_y, height=25, color_mode='byte'))
        print(plotille.histogram(
            scores_y[-n_scores_in_epoch:],
            bins=max(n_scores_in_epoch // 5, 1), height=25, color_mode='byte'
        ))
        n_scores_in_epoch = 0
        print(
            f'{frames:,} => {mean_score:.3f}, '
            f'ret: {D_ret.mean():.3f}, '
            f'v: {value.mean():.3f}, '
            f'ent.: {entropy.item():.3f}, '
            f'kld: {kld.mean():.3f}, '
            f'L_v: {loss_v.item():.3f}, '
            f'L_pi: {loss_pi.item():.3f}, '
            f'L_eta: {loss_eta.item():.3f}, '
            f'L_alpha: {loss_alpha.item():.3f}, '
            f'eta: {eta.item():.3f}, '
            f'alpha: {alpha.item():.3f}, '
        )
        writer.add_scalar('metric/score', mean_score, global_step=frames)
        writer.add_scalar('metric/ret', D_ret.mean(), global_step=frames)
        writer.add_scalar('metric/v', value.mean(), global_step=frames)
        writer.add_scalar('metric/ent', entropy.item(), global_step=frames)
        writer.add_scalar('loss/v', loss_v.item(), global_step=frames)
        writer.add_scalar('loss/pi', loss_pi.item(), global_step=frames)
        writer.add_scalar('loss/eta', loss_eta.item(), global_step=frames)
        writer.add_scalar('loss/alpha', loss_alpha.item(), global_step=frames)
        writer.add_scalar('mpo/eta', eta.item(), global_step=frames)
        writer.add_scalar('mpo/alpha', alpha.item(), global_step=frames)
 
        if args.max_frames is not None and frames > args.max_frames:
            writer.add_hparams(dict(algo='vmpo'), dict(final_score=mean_score))
            break