====== SAD (Simplified Action Decoder) ======
{{:
  * [[https://
  * [[https://
  * [[https://

<code python>
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from IPython import embed
# payoff values
payoff_values = [
    [[[10, 0, 0],
      [4, 8, 4],
      [10, 0, 0]],
     [[0, 0, 10],
      [4, 8, 4],
      [0, 0, 10]]],
    [[[0, 0, 10],
      [4, 8, 4],
      [0, 0, 0]],
     [[10, 0, 0],
      [4, 8, 4],
      [10, 0, 0]]]
]
payoff_values = np.array(payoff_values)
n_cards = 2
# ...
final_epsilon = 0.05
n_runs = 50
n_episodes = 50000
n_readings = 25
np.random.seed(seed)
# ...
interval = n_episodes // n_readings

for n_r in range(n_runs):
    print(...)
    for bad_mode in [2, 4]:
        print(...)
        # agent 0: sees only its own card
        net0 = nn.Sequential(
            nn.Linear(n_cards, 32),
            nn.ReLU(),
            nn.Linear(32, n_actions),
        )
        # agent 1: sees a joint one-hot built from joint_in1 below
        net1 = nn.Sequential(
            nn.Linear(input_size, 32),
            nn.ReLU(),
            nn.Linear(32, n_actions),
        )
        optimizer = optim.Adam(
            [
                {'params': net0.parameters()},
                {'params': net1.parameters()},
            ],
            lr=0.001,
        )
        # bad_mode 4 (SAD-like): agent 1 also sees agent 0's greedy action
        greedy = 1 if bad_mode > 3 else 0

        for j in range(n_episodes+1):
            cards_0 = np.random.choice(n_cards, ...)
            # ...
            eps = 0
            if j % (interval) != 0:
                eps = max(final_epsilon, ...)

            with torch.no_grad():
                input_0 = np.eye(n_cards)[cards_0]
                input_0 = torch.from_numpy(input_0).float()
                q_vals = net0(input_0)
                qv0, qv0_i = q_vals.max(1)
                # ...
                joint_in1 = ... + \
                    u0.numpy() * n_actions + \
                    u0_greedy.numpy() * greedy
                input_1 = np.eye(input_size)[joint_in1]
                input_1 = torch.from_numpy(input_1).float()
                q_vals = net1(input_1)
                qv1, qv1_i = q_vals.max(1)
                # ...
                q1_greedy = (q_vals * torch.eye(n_actions)[u1_greedy]).sum(1)

            rew = [
                payoff_values[cards_0[i], cards_1[i], u0[i], u1[i]]
                for i in range(bs)
            ]
            rew = torch.from_numpy(np.array(rew)).float()
            q0 = net0(input_0)[torch.arange(0, bs).long(), u0]
            q1 = net1(input_1)[torch.arange(0, bs).long(), u1]

            optimizer.zero_grad()
            # ...
            loss.backward()
            optimizer.step()

            net0, net1 = net1, net0  # swap the two networks between episodes

            if eps == 0:
                all_r[bad_mode, n_r, j // interval] = rew.mean()

            if ...:
                print(j, ...)

    colors = ['', '', '#...', '', '#...']
    plt.figure(figsize=(6, ...))
    x_vals = np.arange(n_readings+1) * interval
    for bad_mode in [2, 4]:
        vals = all_r[bad_mode][:n_r+1]
        y_m = vals.mean(0)
        y_std = vals.std(0) / (n_runs**0.5)
        plt.plot(x_vals, y_m, colors[bad_mode], ...)
        plt.fill_between(x_vals, ...)
    plt.ylim([7.5, ...])
    plt.legend()
    None
    plt.xlabel(...)
    plt.ylabel(...)
    plt.savefig(...)
    plt.clf()
</code>
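
The part of the script that builds agent 1's input is where the SAD idea shows up: besides the action agent 0 actually sampled (u0), agent 1 also receives the action agent 0 would have taken greedily (u0_greedy), but only when greedy is 1, i.e. for bad_mode > 3. The exact index layout is elided in the excerpt above, so the following is only a sketch of one consistent packing: the u0 and u0_greedy terms mirror the visible code, while the leading card term, input_size and the example values are assumptions.

<code python>
import numpy as np

# Sketch only: one way to pack agent 1's observation as a single one-hot index
# over (own card, partner's sampled action, partner's greedy action).
# The card term and input_size are assumptions; the u0 / u0_greedy terms
# mirror the ones visible in the excerpt above.
n_cards, n_actions = 2, 3
greedy = 1          # 1 when bad_mode > 3 (SAD-style run), else 0

card_1 = 1          # hypothetical card observed by agent 1
u0 = 2              # action agent 0 actually sampled (may be exploratory)
u0_greedy = 0       # action agent 0 would have taken greedily

input_size = n_cards * n_actions * n_actions
joint_in1 = card_1 * n_actions * n_actions \
    + u0 * n_actions \
    + u0_greedy * greedy      # this term vanishes when greedy == 0

input_1 = np.eye(input_size)[joint_in1]          # one-hot vector fed to net1
print(input_size, joint_in1, input_1.argmax())   # -> 18 15 15
</code>

With greedy == 0 the last term drops out, so agent 1 effectively sees only its own card and agent 0's sampled action; with greedy == 1 the greedy action stays visible even on exploratory steps, which is the decoupling of exploration and information that SAD relies on.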