사용자 도구

사이트 도구


sad

차이

문서의 선택한 두 판 사이의 차이를 보여줍니다.

차이 보기로 링크

다음 판
이전 판
sad [2020/06/01 23:20] – 만듦 rex8312sad [2024/03/23 02:38] (현재) – 바깥 편집 127.0.0.1
줄 1: 줄 1:
 ====== SAD (Simplified Action Decoder) ====== ====== SAD (Simplified Action Decoder) ======
  
-<code python>+{{:sad:pasted:20200602-143049.png}}
  
 +  * [[https://arxiv.org/pdf/1912.02288.pdf|SIMPLIFIED ACTION DECODER FOR DEEP MULTI-AGENT REINFORCEMENT LEARNING]]
 +  * [[https://arxiv.org/pdf/1811.01458.pdf|Bayesian Action Decoder for Deep Multi-Agent Reinforcement Learning]]
 +  * [[https://arxiv.org/pdf/1709.04326.pdf|Learning with Opponent-Learning Awareness]]
 +
 +<code python>
 import numpy as np import numpy as np
 import matplotlib.pyplot as plt import matplotlib.pyplot as plt
줄 9: 줄 14:
 import torch.optim as optim import torch.optim as optim
 from IPython import embed from IPython import embed
- 
- 
-def repeat_tensor(tensor, repetion): 
-    with tf.variable_scope("rep"): 
-        exp_tensor = tf.expand_dims(tensor, -1) 
-        tensor_t = tf.tile(exp_tensor,[1] + repetion) 
-        tensor_r = tf.reshape(tensor_t, repetion * tf.shape(tensor) ) 
-    return tensor_r 
  
  
 # payoff values # payoff values
 payoff_values = [               payoff_values = [              
-      [[[10, 0, 0], [4, 8, 4], [10, 0, 0]], +      [[[10, 0, 0],  
-       [[0, 0, 10], [4, 8, 4], [0, 0, 10]]], +        [4, 8, 4],  
-      [[[0, 0, 10], [4, 8, 4], [0, 0, 0]], +        [10, 0, 0]], 
-       [[10, 0, 0], [4, 8, 4], [10, 0, 0]]]+       [[0, 0, 10],  
 +        [4, 8, 4],  
 +        [0, 0, 10]]], 
 +      [[[0, 0, 10],  
 +        [4, 8, 4],  
 +        [0, 0, 0]], 
 +       [[10, 0, 0],  
 +        [4, 8, 4],  
 +        [10, 0, 0]]]
   ]   ]
-payoff_values = np.array( payoff_values )+payoff_values = np.array(payoff_values)
  
 n_cards = 2 n_cards = 2
줄 39: 줄 44:
     final_epsilon = 0.05     final_epsilon = 0.05
  
-    n_runs = 20 +    n_runs = 50 
-    n_episodes = 100000 +    n_episodes = 50000 
-    n_readings = 100+    n_readings = 25
  
     np.random.seed(seed)     np.random.seed(seed)
줄 52: 줄 57:
     interval = n_episodes // n_readings     interval = n_episodes // n_readings
  
-    for bad_mode in [2, 4]: +    for n_r in range(n_runs): 
-        print('Running for ', mode_labels[bad_mode]) +        print('run ', n_r, ' out of ', n_runs) 
-        for n_r in range(n_runs): +        for bad_mode in [2, 4]: 
-            net0 = nn.Linear(n_cards, n_actions) +            print('Running for ', mode_labels[bad_mode]) 
-            input_size_1 = n_cards * n_actions ** 2 +            input_size = n_cards * n_actions ** 2 
-            net1 = nn.Linear(input_size_1, n_actions) +            net0 = nn.Sequential( 
-            optimizer = optim.SGD(+                nn.Linear(input_size, 32), 
 +                nn.ReLU(), 
 +                nn.Linear(32, n_actions), 
 +            ) 
 +            net1 = nn.Sequential( 
 +                nn.Linear(input_size, 32), 
 +                nn.ReLU(), 
 +                nn.Linear(32, n_actions), 
 +            
 +            optimizer = optim.Adam(
                 [                 [
                     {'params': net0.parameters()},                     {'params': net0.parameters()},
                     {'params': net1.parameters()},                     {'params': net1.parameters()},
                 ],                 ],
-                lr=0.01,+                lr=0.001,
             )             )
             greedy = 1 if bad_mode > 3 else 0             greedy = 1 if bad_mode > 3 else 0
-            print('run ', n_r, ' out of ', n_runs)+            
             for j in range(n_episodes+1):             for j in range(n_episodes+1):
                 cards_0 = np.random.choice(n_cards, size=(bs))                 cards_0 = np.random.choice(n_cards, size=(bs))
줄 72: 줄 86:
                 eps = 0                 eps = 0
                 if j % (interval) != 0:                 if j % (interval) != 0:
-                    eps = max(final_epsilon, 1 - 2 * j/n_episodes)+                    eps = max(final_epsilon, 1 - 2 * j / n_episodes)
                                  
                 with torch.no_grad():                 with torch.no_grad():
-                    input_0 = np.eye(n_cards)[cards_0+                    joint_in1 = cards_0 * n_actions**2 
-                    input_0 = torch.from_numpy(input_0).to(torch.float32)+                    input_0 = np.eye(input_size)[joint_in1
 +                    input_0 = torch.from_numpy(input_0).float()
                     q_vals = net0(input_0)                     q_vals = net0(input_0)
                     qv0, qv0_i = q_vals.max(1)                     qv0, qv0_i = q_vals.max(1)
줄 89: 줄 104:
                                 u0.numpy() * n_actions + \                                 u0.numpy() * n_actions + \
                                 u0_greedy.numpy() * greedy                                 u0_greedy.numpy() * greedy
-                    input_1 = np.eye(input_size_1)[joint_in1] +                    input_1 = np.eye(input_size)[joint_in1] 
-                    input_1 = torch.from_numpy(input_1).to(torch.float32)+                    input_1 = torch.from_numpy(input_1).float()
                     q_vals = net1(input_1)                     q_vals = net1(input_1)
                     qv1, qv1_i = q_vals.max(1)                     qv1, qv1_i = q_vals.max(1)
줄 100: 줄 115:
                     q1_greedy = (q_vals * torch.eye(n_actions)[u1_greedy]).sum(1)                     q1_greedy = (q_vals * torch.eye(n_actions)[u1_greedy]).sum(1)
  
-                rew = [ +                    rew = [ 
-                    payoff_values[cards_0[i], cards_1[i], u0[i], u1[i]] +                        payoff_values[cards_0[i], cards_1[i], u0[i], u1[i]] 
-                    for i in range(bs) +                        for i in range(bs) 
-                +                    
-                rew = torch.from_numpy(np.array(rew)).to(torch.float32)+                    rew = torch.from_numpy(np.array(rew)).float()
  
-                q0 = net0(input_0).gather(1u0.view(-1, 1)+                q0 = net0(input_0)[torch.arange(0bs).long(), u0] 
-                q1 = net1(input_1).gather(1u1.view(-1, 1))+                q1 = net1(input_1)[torch.arange(0bs).long(), u1]
  
                 optimizer.zero_grad()                 optimizer.zero_grad()
줄 113: 줄 128:
                 loss.backward()                 loss.backward()
                 optimizer.step()                 optimizer.step()
 +
 +                net0, net1 = net1, net0
                                  
                 if eps == 0:                 if eps == 0:
-                    all_r[bad_mode, n_r, int( j / interval)] = rew.mean().item() +                    all_r[bad_mode, n_r, int(j / interval)] = rew.mean().item()
- +
-                if j % (n_episodes // 10) == 0:+
                     print(j, 'rew', rew.mean().item(), 'loss', loss.item() )                     print(j, 'rew', rew.mean().item(), 'loss', loss.item() )
  
-    mode_labels = ['', '', 'IQL', 'IQL+aux', 'SAD', ''+        mode_labels = ['', '', 'IQL', 'IQL+aux', 'SAD', ''
-    colors = ['','','#1f77b4','','#d62728'+        colors = ['','','#1f77b4','','#d62728'
-    plt.figure(figsize=(6, 6)) +        plt.figure(figsize=(6, 6)) 
-    x_vals = np.arange(n_readings+1)* interval +        x_vals = np.arange(n_readings+1) * interval 
-    for bad_mode in [2,4]: +        for bad_mode in [2, 4]: 
-        vals = all_r[bad_mode] +            vals = all_r[bad_mode][:n_r+1
-        y_m = vals.mean(0) +            y_m = vals.mean(0) 
-        y_std = vals.std(0) / ( n_runs**0.5 ) +            y_std = vals.std(0) / (n_runs**0.5) 
-        plt.plot( x_vals, y_m, colors[bad_mode],label = mode_labels[ bad_mode ]) +            plt.plot(x_vals, y_m, colors[bad_mode], label=mode_labels[bad_mode]) 
-        plt.fill_between(x_vals, y_m + y_std ,y_m - y_std,alpha = 0.3) +            plt.fill_between(x_vals, y_m+y_std, y_m-y_std, alpha=0.3) 
-        plt.ylim([7.5, 10.3]) +            plt.ylim([7.5, 10.3]) 
-        plt.legend() +            plt.legend() 
-        None +            None 
-    plt.xlabel('Epoch'+        plt.xlabel('Epoch'
-    plt.ylabel('Reward'+        plt.ylabel('Reward'
-    plt.savefig('matrix_game.png')+        plt.savefig('matrix_game.png') 
 +        plt.clf() 
 </code> </code>
sad.1591053639.txt.gz · 마지막으로 수정됨: (바깥 편집)