SAD (Simplified Action Decoder)
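
A minimal PyTorch script for the two-player, two-step matrix game commonly used to illustrate the Simplified Action Decoder (SAD). Each player privately draws one of two cards and picks one of three actions, and both receive the joint payoff given by the tensor below. Under plain independent Q-learning (IQL, `bad_mode = 2`), player 1 only observes player 0's sampled, possibly exploratory, action. In SAD mode (`bad_mode = 4`), player 1 additionally observes player 0's greedy action, the idea being that the greedy action carries the informative signal to the partner while the executed action can stay exploratory. The script trains both variants over many independent runs and plots the mean greedy-evaluation reward.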

 
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
 
 
# payoff tensor, indexed as [card_0, card_1, action_0, action_1] -> reward
payoff_values = [              
      [[[10, 0, 0], 
        [4, 8, 4], 
        [10, 0, 0]],
       [[0, 0, 10], 
        [4, 8, 4], 
        [0, 0, 10]]],
      [[[0, 0, 10], 
        [4, 8, 4], 
        [0, 0, 0]],
       [[10, 0, 0], 
        [4, 8, 4], 
        [10, 0, 0]]]
  ]
payoff_values = np.array(payoff_values)
 
n_cards = 2    # number of possible private cards per player
n_actions = 3  # number of actions per player
bs = 32        # batch size (episodes played in parallel)
 
 
if __name__ == '__main__':
 
    seed = 42
    vdn = 0              # 1 adds player 1's greedy Q-value to player 0's target (see the loss below)
    final_epsilon = 0.05

    n_runs = 50          # independent training runs
    n_episodes = 50000   # training episodes per run
    n_readings = 25      # number of greedy evaluation readings per run
 
    np.random.seed(seed)
    torch.manual_seed(seed)
 
 
    # only indices 2 (IQL) and 4 (SAD) are used below
    mode_labels = ['', '', 'IQL', '', 'SAD', '']

    # evaluation rewards per mode, run, and reading
    all_r = np.zeros((6, n_runs, n_readings + 1))
    interval = n_episodes // n_readings
 
    for n_r in range(n_runs):
        print('run ', n_r, ' out of ', n_runs)
        # bad_mode 2 = IQL (independent Q-learning), 4 = SAD
        for bad_mode in [2, 4]:
            print('Running for ', mode_labels[bad_mode])
            # one-hot input size over the joint index (card, action, greedy action)
            input_size = n_cards * n_actions ** 2
            # separate Q-network for each player
            net0 = nn.Sequential(
                nn.Linear(input_size, 32),
                nn.ReLU(),
                nn.Linear(32, n_actions),
            )
            net1 = nn.Sequential(
                nn.Linear(input_size, 32),
                nn.ReLU(),
                nn.Linear(32, n_actions),
            )
            optimizer = optim.Adam(
                [
                    {'params': net0.parameters()},
                    {'params': net1.parameters()},
                ],
                lr=0.001,
            )
            # in SAD mode (bad_mode 4) player 1 also observes player 0's greedy action
            greedy = 1 if bad_mode > 3 else 0
 
            for j in range(n_episodes + 1):
                # deal each player a private card for every episode in the batch
                cards_0 = np.random.choice(n_cards, size=(bs))
                cards_1 = np.random.choice(n_cards, size=(bs))
                eps = 0
                # every `interval`-th episode is a greedy evaluation episode (eps stays 0)
                if j % interval != 0:
                    eps = max(final_epsilon, 1 - 2 * j / n_episodes)
 
                with torch.no_grad():
                    # player 0 observes only its own card (one-hot over the joint input space)
                    joint_in1 = cards_0 * n_actions**2
                    input_0 = np.eye(input_size)[joint_in1]
                    input_0 = torch.from_numpy(input_0).float()
                    q_vals = net0(input_0)
                    qv0, qv0_i = q_vals.max(1)
                    # epsilon-greedy action distribution for player 0
                    probs = q_vals * 0 + eps / n_actions
                    probs += (1 - eps) * torch.eye(n_actions)[qv0_i]
                    u0 = probs.multinomial(1).view(-1)
                    u0_greedy = probs.max(1)[1]
                    # chosen-action Q-values are recomputed with gradients below
                    # q0 = (q_vals * torch.eye(n_actions)[u0]).sum(1)
                    # q0_greedy = (q_vals * torch.eye(n_actions)[u0_greedy]).sum(1)
 
                    # player 1 observes its card, player 0's sampled action, and
                    # (when greedy == 1, i.e. SAD) player 0's greedy action
                    joint_in1 = cards_1 * n_actions**2 + \
                                u0.numpy() * n_actions + \
                                u0_greedy.numpy() * greedy
                    input_1 = np.eye(input_size)[joint_in1]
                    input_1 = torch.from_numpy(input_1).float()
                    q_vals = net1(input_1)
                    qv1, qv1_i = q_vals.max(1)
                    # epsilon-greedy action distribution for player 1
                    probs = q_vals * 0 + eps / n_actions
                    probs += (1 - eps) * torch.eye(n_actions)[qv1_i]
                    u1 = probs.multinomial(1).view(-1)
                    u1_greedy = probs.max(1)[1]
                    # q1 = (q_vals * torch.eye(n_actions)[u1]).sum(1)
                    q1_greedy = (q_vals * torch.eye(n_actions)[u1_greedy]).sum(1)
 
                    # look up the joint payoff for each episode in the batch
                    rew = [
                        payoff_values[cards_0[i], cards_1[i], u0[i], u1[i]]
                        for i in range(bs)
                    ]
                    rew = torch.from_numpy(np.array(rew)).float()
 
                # recompute the chosen-action Q-values with gradients enabled
                q0 = net0(input_0)[torch.arange(0, bs).long(), u0]
                q1 = net1(input_1)[torch.arange(0, bs).long(), u1]

                optimizer.zero_grad()
                # with vdn = 0 both players regress their Q-value to the immediate reward;
                # vdn = 1 would add player 1's greedy Q-value to player 0's target
                loss = (((rew + q1_greedy * vdn) - q0) ** 2).mean() + ((rew - q1) ** 2).mean()
                loss.backward()
                optimizer.step()

                # swap the networks so each alternates between the player-0 and player-1 role
                net0, net1 = net1, net0
 
                if eps == 0:
                    # record the mean reward of this greedy evaluation episode
                    all_r[bad_mode, n_r, int(j / interval)] = rew.mean().item()
                    print(j, 'rew', rew.mean().item(), 'loss', loss.item())
 
        # redraw the learning curves after every run (matrix_game.png is overwritten)
        mode_labels = ['', '', 'IQL', 'IQL+aux', 'SAD', '']
        colors = ['', '', '#1f77b4', '', '#d62728']
        plt.figure(figsize=(6, 6))
        x_vals = np.arange(n_readings + 1) * interval
        for bad_mode in [2, 4]:
            vals = all_r[bad_mode][:n_r + 1]
            y_m = vals.mean(0)
            # standard error of the mean over the runs completed so far
            y_std = vals.std(0) / ((n_r + 1) ** 0.5)
            plt.plot(x_vals, y_m, colors[bad_mode], label=mode_labels[bad_mode])
            plt.fill_between(x_vals, y_m + y_std, y_m - y_std, alpha=0.3)
        plt.ylim([7.5, 10.3])
        plt.legend()
        plt.xlabel('Epoch')
        plt.ylabel('Reward')
        plt.savefig('matrix_game.png')
        plt.clf()
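
When run as a script, each of the 50 runs trains the IQL and SAD variants for 50,000 episodes, and after every run matrix_game.png is overwritten with the mean greedy-evaluation reward (with a standard-error band) of both variants across the runs completed so far.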