사용자 도구

사이트 도구


continuous_control

차이

문서의 선택한 두 판 사이의 차이를 보여줍니다.

차이 보기로 링크

양쪽 이전 판이전 판
다음 판
이전 판
continuous_control [2020/08/20 18:25] rex8312continuous_control [2024/03/23 02:38] (현재) – 바깥 편집 127.0.0.1
줄 88: 줄 88:
  
 <code python> <code python>
 +class Model(nn.Module):
 +    def __init__(self, args, n_features, n_actions, var):
 +        super().__init__()
 + 
 +        self.action_scale = torch.FloatTensor([[
 +            0.20833333, 1.        , 1.        , 1.        , 0.25      ,
 +            1.        , 1.        , 1.        , 0.12077295, 1.        ,
 +            1.        , 1.        , 0.15923567, 0.15923567, 1.        ,
 +            1.        , 1.        , 0.07961783, 1.        , 1.        ,
 +            1.        , 0.15923567, 0.12077295, 1.        , 1.        ,
 +            1.        , 0.15923567, 0.15923567, 1.        , 1.        ,
 +            1.        , 0.10775862, 1.        , 1.        , 1.        ,
 +            0.15923567
 +        ]])
 + 
 +        vc = 4
 + 
 +        self.critic = nn.Sequential(
 +            init_params(nn.Linear(n_features, vc * 1024)),
 +            nn.LayerNorm(vc * 1024),
 +            nn.ReLU(),
 +            init_params(nn.Linear(vc * 1024, vc * 512)),
 +            nn.LayerNorm(vc * 512),
 +            nn.ReLU(),
 +            init_params(nn.Linear(vc * 512, 1), True, 0.01),
 +        )
 + 
 +        self.mean = nn.Sequential(
 +            init_params(nn.Linear(n_features, 1024)),
 +            nn.LayerNorm(1024),
 +            nn.ReLU(),
 +            init_params(nn.Linear(1024, 512)),
 +            nn.LayerNorm(512),
 +            nn.ReLU(),
 +            init_params(nn.Linear(512, n_actions), True, 0.01),
 +        )
 + 
 +        self.logstd = nn.Sequential(
 +            init_params(nn.Linear(n_features, 1024)),
 +            nn.LayerNorm(1024),
 +            nn.ReLU(),
 +            init_params(nn.Linear(1024, 512)),
 +            nn.LayerNorm(512),
 +            nn.ReLU(),
 +            init_params(nn.Linear(512, n_actions), True, np.log(var)),
 +        )
 +        self.max_logvar = np.log(1)# np.log(2 * var)
 +        self.min_logvar = np.log(1e-9)
 +        self.max_var = 1  # 2 * std
 +        self.min_var = 1e-9
 +        
 +        self.apply(self._init_weights)
 + 
 +    def forward(self, x):
 +        return self.critic(x), self.mean(x), self._var(x)
 + 
 +    def _var(self, x):
 +        logvar = self.logvar(x) 
 +        logvar = torch.clamp(logvar, self.min_logvar, self.max_logvar)
 +        var = torch.exp(logvar)
 +        return var
 +        
 +    def _init_weights(self, module):
 +        if isinstance(module, (nn.Linear, nn.Embedding)):
 +            module.weight.data.normal_(mean=0.0, std=0.02)
 +            if isinstance(module, nn.Linear) and module.bias is not None:
 +                module.bias.data.zero_()
 +        elif isinstance(module, nn.LayerNorm):
 +            module.bias.data.zero_()
 +            module.weight.data.fill_(1.0)
 +
 +
 def sample_action(mu, var): def sample_action(mu, var):
     return mu + torch.randn(var.size()) * var.sqrt()     return mu + torch.randn(var.size()) * var.sqrt()
줄 110: 줄 182:
  
   - TD3: https://towardsdatascience.com/td3-learning-to-run-with-ai-40dfc512f93   - TD3: https://towardsdatascience.com/td3-learning-to-run-with-ai-40dfc512f93
 +
 +
 +{{tag>RL continuous_control action_space}}
continuous_control.1597947939.txt.gz · 마지막으로 수정됨: 2024/03/23 02:37 (바깥 편집)