사용자 도구

사이트 도구


sad

문서의 이전 판입니다!


SAD (Simplified Action Decoder)

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from IPython import embed
 
 
def repeat_tensor(tensor, repetion):
    """Tile a tensor along each axis (PyTorch port of an old TensorFlow helper).

    The previous implementation called ``tf.variable_scope``/``tf.tile``, but
    ``tf`` is never imported in this file (and ``variable_scope`` only exists
    in TF1), so any call raised ``NameError``. This port keeps the same
    algorithm on torch tensors:

    1. append a trailing axis of size 1,
    2. tile by ``[1] + repetion``,
    3. reshape to ``repetion[i] * tensor.shape[i]`` per axis.

    For a 1-D tensor and ``repetion == [k]`` this repeats every element ``k``
    times in place, e.g. ``[1, 2, 3] -> [1, 1, 2, 2, 3, 3]``.

    Args:
        tensor: input ``torch.Tensor``; its rank must equal ``len(repetion)``.
        repetion: sequence (list/tuple) of per-axis repeat counts. The
            parameter name is kept as-is for backward compatibility.

    Returns:
        A tensor whose axis ``i`` has size ``repetion[i] * tensor.shape[i]``.

    Raises:
        ValueError: if ``len(repetion)`` does not match ``tensor.dim()``.
    """
    reps = [int(r) for r in repetion]
    if len(reps) != tensor.dim():
        raise ValueError(
            f"repetion has {len(reps)} entries but tensor has rank {tensor.dim()}"
        )
    # Tile the new trailing axis by the repeat counts; the leading 1 keeps
    # the first axis as-is (mirrors tf.tile(exp_tensor, [1] + repetion)).
    tiled = tensor.unsqueeze(-1).repeat(1, *reps)
    return tiled.reshape([r * s for r, s in zip(reps, tensor.shape)])
 
 
# Payoff table of the two-player matrix game. Both players receive the same
# reward, looked up as payoff_values[card_0, card_1, action_0, action_1].
payoff_values = np.array(
    [
        [
            [[10, 0, 0], [4, 8, 4], [10, 0, 0]],
            [[0, 0, 10], [4, 8, 4], [0, 0, 10]],
        ],
        [
            [[0, 0, 10], [4, 8, 4], [0, 0, 0]],
            [[10, 0, 0], [4, 8, 4], [10, 0, 0]],
        ],
    ]
)

# Game sizes, read off the (cards, cards, actions, actions) table shape.
n_cards = payoff_values.shape[0]    # distinct cards a player can be dealt
n_actions = payoff_values.shape[2]  # actions available to each player
bs = 32                             # games played in parallel per update
 
 
def epsilon_greedy(q_vals, eps):
    """Sample epsilon-greedy actions for a batch of Q-values.

    Every action gets probability ``eps / n_actions`` and the argmax action
    gets the remaining ``1 - eps`` on top of that.

    Args:
        q_vals: tensor of shape (batch, n_actions).
        eps: exploration probability in [0, 1].

    Returns:
        ``(sampled, greedy)``: two LongTensors of shape (batch,) holding the
        sampled action and the argmax action of each row.
    """
    n_act = q_vals.shape[1]
    greedy_idx = q_vals.argmax(1)
    probs = torch.full_like(q_vals, eps / n_act)
    probs += (1 - eps) * torch.eye(n_act)[greedy_idx]
    sampled = probs.multinomial(1).view(-1)
    return sampled, greedy_idx


def train_step(net0, net1, optimizer, payoff, cards_0, cards_1, eps, greedy, vdn):
    """Play one batch of the matrix game and apply one optimizer step.

    Agent 0 observes only its own card. Agent 1 observes its own card and
    agent 0's sampled action. When ``greedy`` is 1 (SAD), agent 1 also sees
    agent 0's greedy action.

    Args:
        net0: agent 0's Q-network; ``in_features`` is the number of cards.
        net1: agent 1's Q-network over the one-hot joint observation.
        optimizer: optimizer over the parameters of both networks.
        payoff: array indexed ``[card_0, card_1, action_0, action_1]``.
        cards_0, cards_1: integer arrays of shape (batch,) with dealt cards.
        eps: epsilon for epsilon-greedy action selection.
        greedy: 1 to feed agent 0's greedy action to agent 1 (SAD), 0 for IQL.
        vdn: weight of agent 1's greedy Q-value in agent 0's target.

    Returns:
        ``(rew, loss)``: float rewards of shape (batch,) and the scalar loss.
    """
    n_act = net0.out_features
    card_idx_0 = torch.as_tensor(cards_0, dtype=torch.long)
    card_idx_1 = torch.as_tensor(cards_1, dtype=torch.long)

    with torch.no_grad():
        input_0 = torch.eye(net0.in_features)[card_idx_0]
        u0, u0_greedy = epsilon_greedy(net0(input_0), eps)

        # Flat index of (card_1, u0, u0_greedy) into agent 1's one-hot input;
        # the u0_greedy term vanishes in IQL mode (greedy == 0).
        joint_in1 = card_idx_1 * n_act ** 2 + u0 * n_act + u0_greedy * greedy
        input_1 = torch.eye(net1.in_features)[joint_in1]
        q_vals_1 = net1(input_1)
        u1, u1_greedy = epsilon_greedy(q_vals_1, eps)
        q1_greedy = q_vals_1.gather(1, u1_greedy.view(-1, 1)).squeeze(1)

    payoff = np.asarray(payoff)
    rew = torch.as_tensor(
        payoff[np.asarray(cards_0), np.asarray(cards_1), u0.numpy(), u1.numpy()],
        dtype=torch.float32,
    )

    # squeeze(1) keeps q0/q1 at shape (batch,) to line up with `rew`. Left as
    # (batch, 1) they would broadcast the squared error into a (batch, batch)
    # matrix, regressing every Q-value toward the batch-mean reward.
    q0 = net0(input_0).gather(1, u0.view(-1, 1)).squeeze(1)
    q1 = net1(input_1).gather(1, u1.view(-1, 1)).squeeze(1)

    optimizer.zero_grad()
    loss = ((rew + q1_greedy * vdn - q0) ** 2).mean() + ((rew - q1) ** 2).mean()
    loss.backward()
    optimizer.step()
    return rew, loss


if __name__ == '__main__':

    seed = 42
    vdn = 0
    final_epsilon = 0.05

    n_runs = 20
    n_episodes = 100000
    n_readings = 100

    np.random.seed(seed)
    torch.manual_seed(seed)

    # Only modes 2 (IQL) and 4 (SAD) are run and plotted.
    mode_labels = ['', '', 'IQL', 'IQL+aux', 'SAD', '']
    colors = ['', '', '#1f77b4', '', '#d62728']

    all_r = np.zeros((6, n_runs, n_readings + 1))
    interval = n_episodes // n_readings

    for bad_mode in [2, 4]:
        print('Running for ', mode_labels[bad_mode])
        greedy = 1 if bad_mode > 3 else 0
        for n_r in range(n_runs):
            net0 = nn.Linear(n_cards, n_actions)
            net1 = nn.Linear(n_cards * n_actions ** 2, n_actions)
            optimizer = optim.SGD(
                [
                    {'params': net0.parameters()},
                    {'params': net1.parameters()},
                ],
                lr=0.01,
            )
            print('run ', n_r, ' out of ', n_runs)
            for j in range(n_episodes + 1):
                cards_0 = np.random.choice(n_cards, size=bs)
                cards_1 = np.random.choice(n_cards, size=bs)
                # Every `interval`-th batch is played greedily (eps = 0) and
                # recorded. Otherwise eps decays linearly to final_epsilon.
                # As before, the recorded batch is also trained on.
                is_eval = j % interval == 0
                eps = 0 if is_eval else max(final_epsilon, 1 - 2 * j / n_episodes)

                rew, loss = train_step(net0, net1, optimizer, payoff_values,
                                       cards_0, cards_1, eps, greedy, vdn)

                if is_eval:
                    all_r[bad_mode, n_r, j // interval] = rew.mean().item()

                if j % (n_episodes // 10) == 0:
                    print(j, 'rew', rew.mean().item(), 'loss', loss.item())

    plt.figure(figsize=(6, 6))
    x_vals = np.arange(n_readings + 1) * interval
    for bad_mode in [2, 4]:
        vals = all_r[bad_mode]
        y_m = vals.mean(0)
        y_std = vals.std(0) / (n_runs ** 0.5)  # standard error over runs
        color = colors[bad_mode]
        plt.plot(x_vals, y_m, color=color, label=mode_labels[bad_mode])
        # Pass the line colour so each band matches its curve.
        plt.fill_between(x_vals, y_m + y_std, y_m - y_std, color=color, alpha=0.3)
    plt.ylim([7.5, 10.3])
    plt.legend()
    plt.xlabel('Epoch')
    plt.ylabel('Reward')
    plt.savefig('matrix_game.png')
sad.1591053639.txt.gz · 마지막으로 수정됨: 2024/03/23 02:37 (바깥 편집)