V-MPO: example
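
The script below is a self-contained V-MPO example for gym's LunarLander-v2 (discrete actions). Rollouts are collected from a vectorized environment with an old (target) copy of the network; the policy, the value function and the two Lagrange multipliers eta (temperature) and alpha (trust region) are then fit on shuffled mini-batches. It is written against the pre-0.26 gym API (env.seed, 4-tuple step returns) and needs gym[box2d] and plotille installed, as noted in the docstring; it can be run directly, e.g. python vmpo_example.py --seed 0.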

vmpo_example.py
#!/usr/bin/env python
 
"""
pip install gym[box2d]
pip install plotille
"""
 
import argparse
from collections import deque
 
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
 
import plotille
 
 
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--comment', type=str, default='baseline')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--n_envs', type=int, default=8)
    parser.add_argument('--horizon', type=int, default=2048)
    parser.add_argument('--mini_batch_size', type=int, default=128)
    parser.add_argument('--lr', type=float, default=0.0005)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--lam', type=float, default=0.99)
    parser.add_argument('--max_grad', type=float, default=10.0)
    parser.add_argument('--eps_eta', type=float, default=0.01)
    parser.add_argument('--eps_alpha', type=float, default=0.01)
    return parser.parse_args()
 
 
class Model(nn.Module):
    def __init__(self, n_features, n_actions):
        super().__init__()
        self.fc1 = nn.Linear(n_features, 32)
        self.norm1 = nn.LayerNorm(32)
        self.vf = nn.Linear(32, 1)
        self.policy = nn.Linear(32, n_actions)
 
    def forward(self, x):
        x = torch.relu(self.norm1(self.fc1(x)))
        logit = self.policy(x)
        return self.vf(x), F.log_softmax(logit, -1), F.softmax(logit, -1)
 
 
if __name__ == '__main__':
 
    args = parse_args()
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
 
    # create the training (vectorized) and test environments
    env = gym.vector.make('LunarLander-v2', num_envs=args.n_envs)
    env.seed(args.seed)
    n_features = env.observation_space.shape[1]
    n_actions = env.action_space[0].n
 
    test_env = gym.make('LunarLander-v2')
    test_env.seed(args.seed)
 
    # create the model, the old (target) copy used for rollouts, and the optimizer
    model = Model(n_features, n_actions)
    model_old = Model(n_features, n_actions)
    model_old.load_state_dict(model.state_dict())
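    # eta (temperature) and alpha (trust-region coefficient) are V-MPO's two
    # Lagrange multipliers; they are optimized jointly with the network weights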
    eta = torch.tensor(1.0, requires_grad=True)
    alpha = torch.tensor(5.0, requires_grad=True)
    optimizer = optim.Adam([*model.parameters(), eta, alpha], lr=args.lr, weight_decay=0.01)
 
    # allocate rollout buffers
    D_obs = np.zeros((args.horizon, args.n_envs, n_features))
    D_action = np.zeros((args.horizon, args.n_envs))
    D_reward = np.zeros((args.horizon, args.n_envs))
    D_done = np.zeros((args.horizon, args.n_envs))
    D_value = np.zeros((args.horizon, args.n_envs))
    D_logp = np.zeros((args.horizon, args.n_envs, n_actions))
 
    # start training
    frames = 0
    score = np.zeros(args.n_envs)
    n_scores_in_epoch = 0
    scores = deque(maxlen=10000)
 
    obs_prime = env.reset()
 
    while True:
 
        obs = obs_prime
        for D_i in tqdm.trange(args.horizon, desc='Rollout'):
            # play with the target policy & collect data
            with torch.no_grad():
                value, logp, prob = model_old(torch.from_numpy(obs).float())
                action = prob.multinomial(num_samples=1).numpy().reshape(-1)
            obs_prime, reward, done, info = env.step(action)
 
            # record finished-episode scores
            score += reward
            scores.extend([(frames, s) for s in score[done]])
            score[done] = 0
            n_scores_in_epoch += done.sum()
 
            # store the transition
            D_obs[D_i] = obs
            D_action[D_i] = action
            D_reward[D_i] = reward / 100.  # rewards scaled down by 100
            D_done[D_i] = done
            D_value[D_i] = value.view(-1).numpy()
            D_logp[D_i] = logp.numpy()
 
            obs = obs_prime
            frames += args.n_envs
 
        # rollout finished: build return and advantage targets
        D_i = 0
        # discount factor, zeroed where an episode ended
        gamma = args.gamma * (1 - D_done)
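        # value targets: R_t = r_t + gamma_t * R_{t+1}, bootstrapped with the
        # target network's value of the final next-observation; gamma_t is zero
        # where an episode ended, so returns never cross episode boundaries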
        # compute returns
        D_ret = np.zeros((args.horizon + 1, args.n_envs))
        with torch.no_grad():
            value, _, _ = model_old(torch.from_numpy(obs_prime).float())  # bootstrap from the last next-observation
        D_ret[-1] = value.view(-1).numpy()
        for t in reversed(range(args.horizon)):
            D_ret[t] = D_reward[t] + gamma[t] * D_ret[t+1]
        D_ret = D_ret[:-1]
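        # GAE: delta_t = r_t + gamma_t * V(s_{t+1}) - V(s_t)
        #      adv_t   = delta_t + gamma_t * lam * adv_{t+1}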
        # compute advantages (GAE)
        value_ = np.vstack([D_value, value.numpy().transpose(1, 0)])
        delta = D_reward + gamma * value_[1:] - value_[:-1]
        D_adv = np.zeros((args.horizon, args.n_envs))
        gae = 0
        for t in reversed(range(args.horizon)):
            gae = gae * gamma[t] * args.lam + delta[t]
            D_adv[t] = gae
 
        # flatten the (time, env) dimensions
        FD_obs = D_obs.reshape(-1, n_features)
        FD_action = D_action.reshape(-1)
        FD_logp = D_logp.reshape(-1, n_actions)
        FD_ret = D_ret.reshape(-1)
        FD_adv = D_adv.reshape(-1)
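        # V-MPO's E-step uses only the better half of the samples: advantages
        # above the median are kept, the rest do not enter the policy loss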
        # top_k (positions of the top-50% advantages)
        FD_top_k = FD_adv > np.median(FD_adv)
 
        # prepare shuffled mini-batch indices
        idx = np.arange(args.horizon * args.n_envs)
        np.random.shuffle(idx)
        n_mini_batches = args.horizon * args.n_envs // args.mini_batch_size
        for mb_i in tqdm.trange(n_mini_batches, desc='Fit'):
            # build the mini-batch
            sel = idx[mb_i * args.mini_batch_size: (mb_i+1) * args.mini_batch_size]
            obs = torch.from_numpy(FD_obs[sel]).float()
            action = torch.from_numpy(FD_action[sel]).long()
            ret = torch.from_numpy(FD_ret[sel]).float()
            adv = torch.from_numpy(FD_adv[sel]).float()
            logp_old = torch.from_numpy(FD_logp[sel]).float()
            top_k = torch.from_numpy(FD_top_k[sel]).bool()
 
            # forward pass with the current model
            value, logp, prob = model(obs)
            logp_a = logp.gather(1, action.view(-1, 1)).view(-1)
            entropy = -(prob * logp).sum(1).mean().item()
 
            # loss_v
            loss_v = 0.5 * (ret - value.view(ret.shape)).pow(2).mean()
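            # E-step weights: psi = softmax(adv / eta) over the kept samples;
            # the policy loss is the psi-weighted negative log-likelihood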
            # loss_pi
            with torch.no_grad():
                aug_adv = (adv[top_k] / eta).exp() / (adv[top_k] / eta).exp().sum()
            loss_pi = -(aug_adv * logp_a[top_k]).sum()
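            # temperature (dual) loss: eta * eps_eta + eta * log(mean(exp(adv / eta)));
            # it tunes eta so the weights psi do not collapse onto a few samples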
            # loss_eta
            loss_eta = eta * args.eps_eta + eta * (adv[top_k] / eta).exp().mean().log()
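            # trust region: alpha grows while KL(pi_old || pi) exceeds eps_alpha,
            # and the detached-alpha term pushes the policy back toward pi_old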
            # loss_alpha
            # kld = F.kl_div(logp_old.detach(), logp, reduction='batchmean')
            kld = (logp_old.exp() * (logp_old - logp) - logp_old.exp() + logp.exp()).sum(1) 
            loss_alpha = (alpha * (args.eps_alpha - kld.detach()) + alpha.detach() * kld).mean()
            # total_loss
            loss = loss_v + loss_pi + loss_eta + loss_alpha
 
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_([*model.parameters(), eta, alpha], args.max_grad)
            optimizer.step()
 
            # eta and alpha are Lagrange multipliers and must stay positive
            with torch.no_grad():
                eta.clamp_(min=1e-6, max=1e6)
                alpha.clamp_(min=1e-6, max=1e6)
 
        # report training statistics
        scores_x, scores_y = zip(*scores)
        mean_score = np.mean(scores_y[-n_scores_in_epoch:])
        print(plotille.scatter(scores_x, scores_y, height=25, color_mode='byte'))
        print(plotille.histogram(scores_y[-n_scores_in_epoch:], bins=n_scores_in_epoch//5, height=25, color_mode='byte'))
        n_scores_in_epoch = 0
        print(
            f'{frames:,} => {mean_score:.3f}, '
            f'ret: {D_ret.mean():.3f}, '
            f'v: {value.mean():.3f}, '
            f'ent.: {entropy:.3f}, '
            f'kld: {kld.mean():.3f}, '
            f'L_v: {loss_v.item():.3f}, '
            f'L_pi: {loss_pi.item():.3f}, '
            f'L_eta: {loss_eta.item():.3f}, '
            f'L_alpha: {loss_alpha.item():.3f}, '
            f'eta: {eta.item():.3f}, '
            f'alpha: {alpha.item():.3f}, '
        )
        # swap in the new target model
        model_old.load_state_dict(model.state_dict())
 
        # play one rendered test episode with the updated target policy
        # (use a separate variable so the per-env training scores are not clobbered)
        test_state = test_env.reset()
        test_score = 0
        while True:
            with torch.no_grad():
                _, _, prob = model_old(
                    torch.from_numpy(test_state).float().view(1, -1)
                )
            action = prob.multinomial(1).item()
            test_state, r, done, _ = test_env.step(action)
            test_score += r
            test_env.render()
            if done:
                print(f'test score: {test_score:.3f}')
                break
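
For reference, the four loss terms computed inside the mini-batch loop above can be read as a single function. The sketch below is not part of the original script: the file name and the function `vmpo_loss` with its signature are made up for this summary, the top-50% filter is applied per mini-batch rather than once per rollout, and the two terms of the script's KL expression that cancel for normalized distributions are dropped.

vmpo_loss_sketch.py
import torch


def vmpo_loss(value, logp, logp_old, action, ret, adv, eta, alpha,
              eps_eta=0.01, eps_alpha=0.01):
    # value:    (B, 1) state values from the current model
    # logp:     (B, A) log-probabilities from the current model
    # logp_old: (B, A) log-probabilities from the old (target) model
    # action:   (B,)   long tensor of taken actions
    # ret, adv: (B,)   return and advantage targets
    # eta, alpha: scalar tensors with requires_grad=True
    logp_a = logp.gather(1, action.view(-1, 1)).view(-1)

    # value loss
    loss_v = 0.5 * (ret - value.view(-1)).pow(2).mean()

    # E-step: keep the top half of the advantages and weight them with
    # softmax(adv / eta); the policy loss is the weighted negative log-likelihood
    top_k = adv > adv.median()
    with torch.no_grad():
        psi = torch.softmax(adv[top_k] / eta, dim=0)
    loss_pi = -(psi * logp_a[top_k]).sum()

    # dual (temperature) loss for eta
    loss_eta = eta * eps_eta + eta * (adv[top_k] / eta).exp().mean().log()

    # trust-region loss for alpha (detach trick as in the script)
    kld = (logp_old.exp() * (logp_old - logp)).sum(1)
    loss_alpha = (alpha * (eps_alpha - kld.detach()) + alpha.detach() * kld).mean()

    return loss_v + loss_pi + loss_eta + loss_alpha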