사용자 도구

사이트 도구


a2c

A2C

import math
import os
import sys
 
import torch
import torch.nn.functional as F
import torch.optim as optim
from envs import create_atari_env, create_car_racing_env
from model import ActorCritic
from torch.autograd import Variable
from torchvision import datasets, transforms
 
def train(args, model, env, optimizer=None):
    """Train a recurrent actor-critic with n-step A2C and GAE.

    Runs ``args.num_updates`` parameter updates. Each update rolls the
    policy out for at most ``args.num_steps`` environment steps (stopping
    early when an episode ends), computes the GAE-weighted policy-gradient
    loss, the value loss and an entropy bonus, clips the gradient norm and
    takes one optimizer step.

    Args:
        args: Namespace-like config. Required attributes: ``seed``,
            ``num_updates``, ``num_steps``, ``max_episode_length`` (counted
            in environment steps), ``gamma`` (discount), ``tau`` (GAE
            lambda) and ``lr`` (only used when ``optimizer`` is None).
            Optional attributes: ``entropy_coef`` (default 0.01),
            ``value_loss_coef`` (default 0.5), ``max_grad_norm``
            (default 40).
        model: Network to train, updated in place. It is called as
            ``model((obs_batch, (hx, cx)))`` and must return
            ``(value, logit, (hx, cx))``; it must expose ``lstm_size``.
            If None, an ``ActorCritic`` is built from the env spaces.
        env: Environment exposing ``observation_space``, ``action_space``,
            ``seed``, ``reset`` and a 4-tuple ``step``. Observations must be
            numpy arrays (they are converted with ``torch.from_numpy``).
        optimizer: Optimizer over ``model``'s parameters. Defaults to Adam
            with learning rate ``args.lr``.

    Returns:
        The trained model (the same object as ``model`` when one is given).
    """
    torch.manual_seed(args.seed)
    entropy_coef = getattr(args, "entropy_coef", 0.01)
    value_loss_coef = getattr(args, "value_loss_coef", 0.5)
    max_grad_norm = getattr(args, "max_grad_norm", 40)

    print("env: ", env.observation_space.shape, env.action_space)
    env.seed(args.seed)

    # Train the caller's model. Previously a fresh model silently replaced
    # it, so the trained weights were lost and a caller-supplied optimizer
    # stepped parameters that never received gradients.
    if model is None:
        model = ActorCritic(env.observation_space.shape[0], env.action_space)

    if optimizer is None:
        optimizer = optim.Adam(model.parameters(), lr=args.lr)

    model.train()

    state = env.reset()
    print("state: ", state.shape)
    state = torch.from_numpy(state)
    done = True
    episode_length = 0

    for _update in range(args.num_updates):
        if done:
            # New episode: start from a zero recurrent state.
            cx = torch.zeros(1, model.lstm_size)
            hx = torch.zeros(1, model.lstm_size)
        else:
            # Keep the recurrent state but cut backprop at the rollout edge.
            cx = cx.detach()
            hx = hx.detach()

        values = []
        log_probs = []
        rewards = []
        entropies = []

        for _step in range(args.num_steps):
            # Count environment steps (not updates) so that
            # max_episode_length caps the episode length in steps.
            episode_length += 1
            value, logit, (hx, cx) = model((state.unsqueeze(0), (hx, cx)))
            prob = F.softmax(logit, dim=-1)
            log_prob = F.log_softmax(logit, dim=-1)
            entropies.append(-(log_prob * prob).sum(1, keepdim=True))

            action = prob.multinomial(num_samples=1).detach()
            log_probs.append(log_prob.gather(1, action))

            state, reward, done, _info = env.step(action.numpy())
            done = done or episode_length >= args.max_episode_length
            rewards.append(max(min(reward, 1), -1))  # clip reward to [-1, 1]
            values.append(value)

            if done:
                episode_length = 0
                state = torch.from_numpy(env.reset())
                break
            state = torch.from_numpy(state)

        # Bootstrap the return from the critic unless the episode ended.
        R = torch.zeros(1, 1)
        if not done:
            with torch.no_grad():
                R, _, _ = model((state.unsqueeze(0), (hx, cx)))
        values.append(R)

        policy_loss, value_loss = _a2c_losses(
            rewards, values, log_probs, entropies,
            args.gamma, args.tau, entropy_coef)

        optimizer.zero_grad()
        (policy_loss + value_loss_coef * value_loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()

    return model


def _a2c_losses(rewards, values, log_probs, entropies, gamma, tau,
                entropy_coef=0.01):
    """Compute the A2C policy and value losses for one rollout.

    Args:
        rewards: Per-step scalar rewards, oldest first.
        values: Critic outputs for each step plus one trailing bootstrap
            value, so ``len(values) == len(rewards) + 1``.
        log_probs: Log-probabilities of the sampled actions, one per step.
        entropies: Policy entropies, one per step.
        gamma: Discount factor.
        tau: GAE lambda.
        entropy_coef: Weight of the entropy bonus.

    Returns:
        ``(policy_loss, value_loss)``. Gradients flow through ``values``,
        ``log_probs`` and ``entropies``; the GAE weights are detached.
    """
    policy_loss = 0
    value_loss = 0
    R = values[-1].detach()
    gae = torch.zeros(1, 1)
    for i in reversed(range(len(rewards))):
        R = gamma * R + rewards[i]
        advantage = R - values[i]
        value_loss = value_loss + 0.5 * advantage.pow(2)

        # Generalized Advantage Estimation; built from detached values so
        # it only weights the policy gradient.
        delta_t = (rewards[i] + gamma * values[i + 1].detach()
                   - values[i].detach())
        gae = gae * gamma * tau + delta_t

        policy_loss = (policy_loss - log_probs[i] * gae
                       - entropy_coef * entropies[i])
    return policy_loss, value_loss
a2c.txt · 마지막으로 수정됨: 2024/03/23 02:38 저자 127.0.0.1