sad
SAD (Simplified Action Decoder) — a single-file toy experiment comparing IQL and SAD on a cooperative two-player matrix game.
"""SAD (Simplified Action Decoder) on a cooperative two-player matrix game.

Two agents each privately observe a card and pick one of three actions; the
joint payoff comes from ``payoff_values``.  Two independent Q-networks are
trained with independent Q-learning (IQL, ``bad_mode == 2``) or with SAD
(``bad_mode == 4``), where agent 1 additionally observes agent 0's *greedy*
action alongside the sampled one.  Mean reward curves for both modes are
plotted to ``matrix_game.png`` after every run.
"""
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim

# Payoff tensor indexed as [card_0, card_1, action_0, action_1].
payoff_values = [
    [[[10, 0, 0], [4, 8, 4], [10, 0, 0]],
     [[0, 0, 10], [4, 8, 4], [0, 0, 10]]],
    [[[0, 0, 10], [4, 8, 4], [0, 0, 0]],
     [[10, 0, 0], [4, 8, 4], [10, 0, 0]]],
]
payoff_values = np.array(payoff_values)

n_cards = 2
n_actions = 3
bs = 32  # batch size: independent games played per gradient step

if __name__ == '__main__':
    seed = 42
    vdn = 0               # 0 disables the VDN-style joint-value target below
    final_epsilon = 0.05  # floor for the linearly annealed epsilon
    n_runs = 50
    n_episodes = 50000
    n_readings = 25       # greedy-evaluation points per run

    np.random.seed(seed)
    torch.manual_seed(seed)

    # Indexed by ``bad_mode``; only 2 (IQL) and 4 (SAD) are exercised here.
    # (The original defined this list twice with conflicting contents.)
    mode_labels = ['', '', 'IQL', 'IQL+aux', 'SAD', '']
    all_r = np.zeros((6, n_runs, n_readings + 1))
    interval = n_episodes // n_readings

    for n_r in range(n_runs):
        print('run ', n_r, ' out of ', n_runs)
        for bad_mode in [2, 4]:
            print('Running for ', mode_labels[bad_mode])

            # One-hot input size: own card x (teammate action, greedy action).
            input_size = n_cards * n_actions ** 2
            net0 = nn.Sequential(
                nn.Linear(input_size, 32),
                nn.ReLU(),
                nn.Linear(32, n_actions),
            )
            net1 = nn.Sequential(
                nn.Linear(input_size, 32),
                nn.ReLU(),
                nn.Linear(32, n_actions),
            )
            optimizer = optim.Adam(
                [
                    {'params': net0.parameters()},
                    {'params': net1.parameters()},
                ],
                lr=0.001,
            )
            # SAD: agent 1 also sees agent 0's greedy action.
            greedy = 1 if bad_mode > 3 else 0

            for j in range(n_episodes + 1):
                cards_0 = np.random.choice(n_cards, size=bs)
                cards_1 = np.random.choice(n_cards, size=bs)

                # Every ``interval`` episodes act fully greedily (eps == 0)
                # and record the reward as an evaluation reading.
                eps = 0
                if j % interval != 0:
                    eps = max(final_epsilon, 1 - 2 * j / n_episodes)

                with torch.no_grad():
                    # Agent 0 observes only its own card.
                    joint_in1 = cards_0 * n_actions ** 2
                    input_0 = np.eye(input_size)[joint_in1]
                    input_0 = torch.from_numpy(input_0).float()
                    q_vals = net0(input_0)
                    qv0_i = q_vals.max(1)[1]
                    # Epsilon-greedy action distribution.
                    probs = q_vals * 0 + eps / n_actions
                    probs += (1 - eps) * torch.eye(n_actions)[qv0_i]
                    u0 = probs.multinomial(1).view(-1)
                    u0_greedy = probs.max(1)[1]

                    # Agent 1 observes its card, agent 0's sampled action,
                    # and (in SAD mode only) agent 0's greedy action.
                    joint_in1 = cards_1 * n_actions ** 2 + \
                        u0.numpy() * n_actions + \
                        u0_greedy.numpy() * greedy
                    input_1 = np.eye(input_size)[joint_in1]
                    input_1 = torch.from_numpy(input_1).float()
                    q_vals = net1(input_1)
                    qv1_i = q_vals.max(1)[1]
                    probs = q_vals * 0 + eps / n_actions
                    probs += (1 - eps) * torch.eye(n_actions)[qv1_i]
                    u1 = probs.multinomial(1).view(-1)
                    u1_greedy = probs.max(1)[1]
                    q1_greedy = (q_vals * torch.eye(n_actions)[u1_greedy]).sum(1)

                    rew = [
                        payoff_values[cards_0[i], cards_1[i], u0[i], u1[i]]
                        for i in range(bs)
                    ]
                    rew = torch.from_numpy(np.array(rew)).float()

                # Recompute chosen-action Q-values with gradients enabled.
                q0 = net0(input_0)[torch.arange(0, bs).long(), u0]
                q1 = net1(input_1)[torch.arange(0, bs).long(), u1]
                optimizer.zero_grad()
                # With vdn == 0 both terms regress Q onto the raw reward.
                loss = (((rew + q1_greedy * vdn) - q0) ** 2).mean() + \
                    ((rew - q1) ** 2).mean()
                loss.backward()
                optimizer.step()

                # Alternate which network plays which role each episode.
                net0, net1 = net1, net0

                if eps == 0:
                    all_r[bad_mode, n_r, int(j / interval)] = rew.mean().item()
                    print(j, 'rew', rew.mean().item(), 'loss', loss.item())

        # Re-plot the learning curves after every completed run.
        colors = ['', '', '#1f77b4', '', '#d62728']
        plt.figure(figsize=(6, 6))
        x_vals = np.arange(n_readings + 1) * interval
        for bad_mode in [2, 4]:
            vals = all_r[bad_mode][:n_r + 1]
            y_m = vals.mean(0)
            # Standard error over the runs completed so far.  The original
            # divided by sqrt(n_runs), which understated the error bars
            # while fewer than n_runs runs had finished.
            y_std = vals.std(0) / ((n_r + 1) ** 0.5)
            plt.plot(x_vals, y_m, colors[bad_mode], label=mode_labels[bad_mode])
            plt.fill_between(x_vals, y_m + y_std, y_m - y_std, alpha=0.3)
        plt.ylim([7.5, 10.3])
        plt.legend()
        plt.xlabel('Epoch')
        plt.ylabel('Reward')
        plt.savefig('matrix_game.png')
        plt.clf()
sad.txt · Last modified: 2024/03/23 02:38 by 127.0.0.1