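# 사용자 도구 (User tools) · 사이트 도구 (Site tools)
#
# Page: sad
#
# SAD (Simplified Action Decoder): two-player cooperative matrix-game demo
# comparing independent Q-learning (IQL) with SAD.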

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from IPython import embed
 
 
# Payoff tensor of the cooperative matrix game, indexed as
# payoff_values[card_0, card_1, action_0, action_1]: the two dealt cards pick
# a 3x3 matrix whose row is agent 0's action and whose column is agent 1's.
payoff_values = np.array([
    # card_0 == 0
    [
        [[10, 0, 0], [4, 8, 4], [10, 0, 0]],   # card_1 == 0
        [[0, 0, 10], [4, 8, 4], [0, 0, 10]],   # card_1 == 1
    ],
    # card_0 == 1
    [
        [[0, 0, 10], [4, 8, 4], [0, 0, 0]],    # card_1 == 0
        [[10, 0, 0], [4, 8, 4], [10, 0, 0]],   # card_1 == 1
    ],
])

n_cards = 2    # each player's card is drawn uniformly from range(n_cards)
n_actions = 3  # actions available to each player
bs = 32        # games played in parallel per training step
 
 
def make_q_net(input_size, n_actions, hidden_size=32):
    """Build a one-hidden-layer MLP that maps an observation vector to Q-values.

    Args:
        input_size: Length of the (one-hot) observation vector.
        n_actions: Number of outputs, one Q-value per action.
        hidden_size: Width of the hidden ReLU layer.

    Returns:
        An ``nn.Sequential`` mapping a (batch, input_size) tensor to a
        (batch, n_actions) tensor.
    """
    return nn.Sequential(
        nn.Linear(input_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, n_actions),
    )


def one_hot(indices, size):
    """Return a float32 tensor whose rows are one-hot encodings of ``indices``.

    Args:
        indices: Integer numpy array of shape (batch,) with values in [0, size).
        size: Length of each one-hot row.
    """
    return torch.from_numpy(np.eye(size)[indices]).float()


def epsilon_greedy(q_vals, eps):
    """Sample epsilon-greedy actions for a batch of Q-values.

    Every action gets probability ``eps / n_actions``; the arg-max action gets
    an additional ``1 - eps``.

    Args:
        q_vals: Tensor of shape (batch, n_actions).
        eps: Exploration rate in [0, 1].

    Returns:
        ``(actions, greedy_actions)``: int64 tensors of shape (batch,) with the
        sampled actions and the arg-max actions of ``q_vals``. The greedy
        actions come from the Q-values directly, so they remain correct even
        when ``eps == 1`` makes the sampling distribution uniform.
    """
    n_act = q_vals.shape[1]
    _, greedy_actions = q_vals.max(1)
    probs = torch.full_like(q_vals, eps / n_act)
    probs += (1 - eps) * torch.eye(n_act)[greedy_actions]
    actions = probs.multinomial(1).view(-1)
    return actions, greedy_actions


def plot_rewards(all_r, modes, mode_labels, colors, x_vals, n_done,
                 path='matrix_game.png'):
    """Plot mean evaluation reward with standard-error bands and save it.

    Args:
        all_r: Array indexed [mode, run, reading] of greedy evaluation rewards.
        modes: Mode ids to plot.
        mode_labels: Legend label per mode id.
        colors: Matplotlib color per mode id.
        x_vals: Training-step index of each reading.
        n_done: Number of completed runs to aggregate over.
        path: Output image file.
    """
    plt.figure(figsize=(6, 6))
    for mode in modes:
        vals = all_r[mode][:n_done]
        y_m = vals.mean(0)
        # Standard error over the runs actually completed (not the planned
        # total), so intermediate plots don't understate the spread.
        y_std = vals.std(0) / np.sqrt(len(vals))
        plt.plot(x_vals, y_m, colors[mode], label=mode_labels[mode])
        plt.fill_between(x_vals, y_m + y_std, y_m - y_std,
                         color=colors[mode], alpha=0.3)
    plt.ylim([7.5, 10.3])
    plt.legend()
    plt.xlabel('Epoch')
    plt.ylabel('Reward')
    plt.savefig(path)
    # Close (not just clear) the figure: a new one is created after every run,
    # and clf() alone would leave all of them open.
    plt.close()


if __name__ == '__main__':

    seed = 42
    # Weight of agent 1's greedy Q-value in agent 0's regression target;
    # 0 means agent 0 regresses directly on the reward.
    vdn = 0
    final_epsilon = 0.05

    n_runs = 50
    n_episodes = 50000
    n_readings = 25

    np.random.seed(seed)
    torch.manual_seed(seed)

    # Indexed by mode id; only modes 2 (IQL) and 4 (SAD) are trained below.
    mode_labels = ['', '', 'IQL', 'IQL+aux', 'SAD', '']
    colors = ['', '', '#1f77b4', '', '#d62728']
    modes = [2, 4]

    all_r = np.zeros((6, n_runs, n_readings + 1))
    interval = n_episodes // n_readings
    x_vals = np.arange(n_readings + 1) * interval
    input_size = n_cards * n_actions ** 2
    batch_idx = torch.arange(bs)

    for n_r in range(n_runs):
        print('run ', n_r, ' out of ', n_runs)
        for bad_mode in modes:
            print('Running for ', mode_labels[bad_mode])
            net0 = make_q_net(input_size, n_actions)
            net1 = make_q_net(input_size, n_actions)
            optimizer = optim.Adam(
                [
                    {'params': net0.parameters()},
                    {'params': net1.parameters()},
                ],
                lr=0.001,
            )
            # Under SAD (mode 4) agent 1 also observes agent 0's greedy action.
            greedy = 1 if bad_mode > 3 else 0

            for j in range(n_episodes + 1):
                cards_0 = np.random.choice(n_cards, size=bs)
                cards_1 = np.random.choice(n_cards, size=bs)
                # Every `interval` steps play fully greedily to take a reading;
                # otherwise eps decays linearly from 1 down to final_epsilon.
                is_reading = j % interval == 0
                eps = 0 if is_reading else max(final_epsilon, 1 - 2 * j / n_episodes)

                with torch.no_grad():
                    # Agent 0 observes only its own card.
                    input_0 = one_hot(cards_0 * n_actions ** 2, input_size)
                    u0, u0_greedy = epsilon_greedy(net0(input_0), eps)

                    # Agent 1 observes its own card and agent 0's sampled
                    # action, plus (SAD only) agent 0's greedy action.
                    joint_in1 = (cards_1 * n_actions ** 2
                                 + u0.numpy() * n_actions
                                 + u0_greedy.numpy() * greedy)
                    input_1 = one_hot(joint_in1, input_size)
                    q_vals_1 = net1(input_1)
                    u1, u1_greedy = epsilon_greedy(q_vals_1, eps)
                    q1_greedy = q_vals_1[batch_idx, u1_greedy]

                    rew = torch.from_numpy(
                        payoff_values[cards_0, cards_1, u0.numpy(), u1.numpy()]
                    ).float()

                q0 = net0(input_0)[batch_idx, u0]
                q1 = net1(input_1)[batch_idx, u1]

                optimizer.zero_grad()
                loss = (((rew + q1_greedy * vdn) - q0) ** 2).mean() + ((rew - q1) ** 2).mean()
                loss.backward()
                optimizer.step()

                # NOTE(review): the networks swap roles after every step, and
                # the input carries no role indicator. Agent 0's one-hot index
                # (card * n_actions**2) equals agent 1's index whenever u0 == 0
                # (and, under SAD, u0_greedy == 0). Confirm this is intended.
                net0, net1 = net1, net0

                if is_reading:
                    all_r[bad_mode, n_r, j // interval] = rew.mean().item()
                    print(j, 'rew', rew.mean().item(), 'loss', loss.item())

        # Refresh the plot after every run using the runs finished so far.
        plot_rewards(all_r, modes, mode_labels, colors, x_vals, n_r + 1)
# sad.txt · Last modified (마지막으로 수정됨): 2024/03/23 02:38 by (저자) 127.0.0.1