【深度强化学习】 PPO 模型解析,附Pytorch完整代码

warning: 这篇文章距离上次修改已过236天,其中的内容可能已经有所变动。
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

class PolicyNetwork(nn.Module):
    def __init__(self, num_inputs, num_actions):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(num_inputs, 64)
        self.fc2 = nn.Linear(64, num_actions)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        action_scores = self.fc2(x)
        return action_scores

    def get_action_probs(self, state):
        state = torch.from_numpy(state).float().unsqueeze(0)
        action_scores = self.forward(state)
        action_probs = F.softmax(action_scores, dim=1).data.squeeze()
        return action_probs

    def get_action(self, state):
        action_probs = self.get_action_probs(state)
        dist = Categorical(action_probs)
        action = dist.sample()
        return action.item(), action_probs

class ValueNetwork(nn.Module):
    def __init__(self, num_inputs):
        super(ValueNetwork, self).__init__()
        self.fc1 = nn.Linear(num_inputs, 64)
        self.fc2 = nn.Linear(64, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        state_values = self.fc2(x)
        return state_values

    def get_value(self, state):
        state = torch.from_numpy(state).float().unsqueeze(0)
        state_value = self.forward(state).data.squeeze().item()
        return state_value

def compute_advantages(batch_size, gamma, lambd, values, rewards, masks, next_values):
    advantages = torch.zeros(batch_size)
    adv_t = 0
    for t in reversed(range(len(rewards))):
        adv_t = rewards[t] + gamma * adv_t * masks[t]
        advantages[t] = adv_t - values[t]
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    return advantages

def train_ppo(env, policy_net, value_net, optimizer_policy, optimizer_value, gamma=0.99, lambd=0.95, num_episodes=1000, batch_size=64):
    for i_episode in range(num_episodes):
        state = env.reset()
        log_probs = []
        values = []
        rewards = []
        masks = []

        for t in range(100):
            action, log_prob = policy_net.get_action_probs(state)
            next_state, reward, done, _ = env.step(action)
            log_probs.append(log_prob)
            values.append(value
Python
none
最后修改于:2024年08月16日 10:29

评论已关闭

推荐阅读

DDPG 模型解析,附Pytorch完整代码
2024年11月24日
DQN 模型解析,附Pytorch完整代码
2024年11月24日
AIGC实战——Transformer模型
2024年12月01日
Socket TCP 和 UDP 编程基础(Python)
2024年11月30日
python , tcp , udp
如何使用 ChatGPT 进行学术润色?你需要这些指令
2024年12月01日
AI
最新 Python 调用 OpenAi 详细教程实现问答、图像合成、图像理解、语音合成、语音识别(详细教程)
2024年11月24日
ChatGPT 和 DALL·E 2 配合生成故事绘本
2024年12月01日
omegaconf,一个超强的 Python 库!
2024年11月24日
【视觉AIGC识别】误差特征、人脸伪造检测、其他类型假图检测
2024年12月01日
[超级详细]如何在深度学习训练模型过程中使用 GPU 加速
2024年11月29日
Python 物理引擎pymunk最完整教程
2024年11月27日
MediaPipe 人体姿态与手指关键点检测教程
2024年11月27日
深入了解 Taipy:Python 打造 Web 应用的全面教程
2024年11月26日
基于Transformer的时间序列预测模型
2024年11月25日
Python在金融大数据分析中的AI应用(股价分析、量化交易)实战
2024年11月25日
AIGC Gradio系列学习教程之Components
2024年12月01日
Python3 `asyncio` — 异步 I/O,事件循环和并发工具
2024年11月30日
llama-factory SFT系列教程:大模型在自定义数据集 LoRA 训练与部署
2024年12月01日
Python 多线程和多进程用法
2024年11月24日
Python socket详解,全网最全教程
2024年11月27日