欢迎您访问 最编程 本站为您分享编程语言代码,编程技术文章!
您现在的位置是: 首页

maddpg学习

最编程 2024-07-29 16:11:51
...

模仿的是PARL的example修改成基于torch的模型:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MAModle:
    '''
    提供的代码是用Python编写的,并使用PyTorch库定义了一个包含行动者(Actor)和评论者(Critic)的多代理模型(MAModle)。
    MAModle类有获取代理策略和价值的方法,以及获取行动者和评论者参数的方法。
    Actor类是一个PyTorch模块,它接收一个观察值并输出一个动作。它有一个前馈神经网络,包含两个隐藏层,每层有64个神经元。如果动作是连续的,它还会输出标准偏差。
    Critic类是另一个PyTorch模块,它接收一个状态和动作并输出一个Q值。它也有一个前馈神经网络,包含两个隐藏层,每层有64个神经元。
    Actor和Critic类中的forward方法定义了神经网络的前向传播。
    '''
    def __init__(self,
                 obs_dim,
                 act_dim,
                 critic_dim,
                 continuous_actions=False
                 ):
        super(MAModle,self).__init__()
        self.actor=Actor(obs_dim,act_dim,continuous_actions)
        self.critic=Critic(critic_dim)
    
    def policy(self,obs):
        return self.actor(obs)

    def value(self,obs,act):
        return self.critic(obs,act)
    
    def get_actor_param(self):
        return self.actor.parameters()
    
    def get_critic_param(self):
        return self.critic.parameters()

# input:  agent_i_obs_dim
# output: agent_i_action_dim
class Actor(nn.Module):
    def __init__(self, obs_dim,act_dim,continuous_actions=False):
        super(Actor, self).__init__()
        self.continuous_actions=continuous_actions
        self.fc1 = nn.Linear(obs_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, act_dim)
        if continuous_actions:
            self.std = nn.Linear(64, act_dim)

    def forward(self, x):
        hid1 = F.relu(self.fc1(x))
        hid2 = F.relu(self.fc2(hid1))
        means = self.fc3(hid2)
        if self.continuous_actions:
            std=self.std(hid2)
            return (means,std)
        return means

# input:  all_obs_dim+all_action_dim
# output: 1  (Q-value)
class Critic(nn.Module):
    def __init__(self, critic_dim):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(critic_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 1)

    def forward(self, state, action):
        x = torch.cat([state,action], dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        q_value = self.fc3(x)
        q_value = torch.squeeze(q_value,dim=1)
        return q_value


if __name__=="__main__":
    # 创建一个MAModle实例
    ma_model = MAModle(obs_dim=10, act_dim=2, critic_dim=12)

    # 创建一些模拟的观察值和动作
    obs = torch.randn(4, 10)
    act = torch.randn(4, 2)

    # 测试policy方法
    print(ma_model.policy(obs).shape)

    # 测试value方法
    print(ma_model.value(obs, act).shape)

    # 测试get_actor_param和get_critic_param方法
    # print(list(ma_model.get_actor_param()))
    # print(list(ma_model.get_critic_param()))