Original code: https://pettingzoo.farama.org/tutorials/tianshou/intermediate/
The PettingZoo game environment used in the code is Tic-Tac-Toe.
I added model-saving code to the original so that, after training, the model can be loaded at any time to test and visualize the result of training (the agent's performance).
The training code is in train.py; the testing code (which I wrote additionally) is in t_ttt.py.

train.py

# Train agents with Tianshou
"""This is a minimal example of using Tianshou with MARL to train agents.
Author: Will (https://github.com/WillDudley)

Python version used: 3.8.10

Requirements:
pettingzoo == 1.22.0
git+https://github.com/thu-ml/tianshou
"""

import os
from typing import Optional, Tuple

import gym
import numpy as np
import torch
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.policy import BasePolicy, DQNPolicy, MultiAgentPolicyManager, RandomPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils.net.common import Net

from pettingzoo.classic import tictactoe_v3

import pickle  # I use pickle to save the model

def _get_agents(
    agent_learn: Optional[BasePolicy] = None,
    agent_opponent: Optional[BasePolicy] = None,
    optim: Optional[torch.optim.Optimizer] = None,
) -> Tuple[BasePolicy, torch.optim.Optimizer, list]:
    env = _get_env()
    observation_space = (
        env.observation_space["observation"]
        if isinstance(env.observation_space, gym.spaces.Dict)
        else env.observation_space
    )
    if agent_learn is None:  # the learner defaults to the DQN algorithm
        # model
        net = Net(
            state_shape=observation_space["observation"].shape
            or observation_space["observation"].n,
            action_shape=env.action_space.shape or env.action_space.n,
            hidden_sizes=[128, 128, 128, 128],
            device="cuda" if torch.cuda.is_available() else "cpu",
        ).to("cuda" if torch.cuda.is_available() else "cpu")
        if optim is None:
            optim = torch.optim.Adam(net.parameters(), lr=1e-4)
        agent_learn = DQNPolicy(
            model=net,
            optim=optim,
            discount_factor=0.9,
            estimation_step=3,
            target_update_freq=320,
        )

    if agent_opponent is None:  # the opponent defaults to a random policy
        agent_opponent = RandomPolicy()

    agents = [agent_opponent, agent_learn]  # two agents: index 0 is the opponent, index 1 is the learner
    policy = MultiAgentPolicyManager(agents, env)
    return policy, optim, env.agents


def _get_env():
    """This function is needed to provide callables for DummyVectorEnv."""
    return PettingZooEnv(tictactoe_v3.env())


if __name__ == "__main__":
    # ======== Step 1: Environment setup =========
    train_envs = DummyVectorEnv([_get_env for _ in range(10)])  # environments for training
    test_envs = DummyVectorEnv([_get_env for _ in range(10)])  # environments for testing

    # seed
    seed = 1
    np.random.seed(seed)
    torch.manual_seed(seed)
    train_envs.seed(seed)
    test_envs.seed(seed)

    # ======== Step 2: Agent setup =========
    policy, optim, agents = _get_agents()

    # ======== Step 3: Collector setup =========
    train_collector = Collector(
        policy,
        train_envs,
        VectorReplayBuffer(20_000, len(train_envs)),
        exploration_noise=True,
    )
    test_collector = Collector(policy, test_envs, exploration_noise=True)
    # policy.set_eps(1)
    train_collector.collect(n_step=64 * 10)  # batch size * training_num

    # ======== Step 4: Callback functions setup =========
    def save_best_fn(policy):
        model_save_path = os.path.join("log", "ttt", "dqn", "policy.pth")
        os.makedirs(os.path.join("log", "ttt", "dqn"), exist_ok=True)
        torch.save(policy.policies[agents[1]].state_dict(), model_save_path)

    def stop_fn(mean_rewards):
        # stop training once the learner's mean test reward reaches 0.6
        return mean_rewards >= 0.6

    def train_fn(epoch, env_step):
        policy.policies[agents[1]].set_eps(0.1)

    def test_fn(epoch, env_step):
        policy.policies[agents[1]].set_eps(0.05)

    def reward_metric(rews):
        # rews has one column per agent; report only the learner's (index 1) reward
        return rews[:, 1]

    # ======== Step 5: Run the trainer =========
    result = offpolicy_trainer(  # off-policy training is used here
        policy=policy,
        train_collector=train_collector,
        test_collector=test_collector,
        max_epoch=50,
        step_per_epoch=1000,
        step_per_collect=50,
        episode_per_test=10,
        batch_size=64,
        train_fn=train_fn,
        test_fn=test_fn,
        stop_fn=stop_fn,
        save_best_fn=save_best_fn,
        update_per_step=0.1,
        test_in_train=False,
        reward_metric=reward_metric,
    )

    # return result, policy.policies[agents[1]]
    print(f"\n==========Result==========\n{result}")
    print("\n(the trained policy can be accessed via policy.policies[agents[1]])")

    # The following code was added by me:
    # Save the model parameters:
    torch.save(policy.policies[agents[1]].state_dict(), 'tictactoe_dqn.pth')  # only agent 1 (the learner) is saved here
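
The pickle import at the top of train.py is there for model saving, so here is a minimal sketch of how the whole learner policy object could be pickled as an alternative to saving only its state_dict, if appended at the end of the main block. The file name tictactoe_dqn.pkl is just an assumption for illustration, not part of the original code.

    # Alternative sketch: pickle the entire learner policy object, so the network
    # does not have to be rebuilt before loading (the .pkl file name is only an example).
    with open('tictactoe_dqn.pkl', 'wb') as f:
        pickle.dump(policy.policies[agents[1]], f)

Loading it back is then a single pickle.load call, which restores the full DQNPolicy object including its network weights.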
    

t_ttt.py

# Test the trained agent
# 'tictactoe_dqn.pth'
import torch
import tianshou as ts
from tianshou.policy import DQNPolicy
from tianshou.utils.net.common import Net
from pettingzoo.classic import tictactoe_v3
import pickle

env = tictactoe_v3.env(render_mode="human")
env = ts.env.PettingZooEnv(env)

device="cuda" if torch.cuda.is_available() else "cpu"
net = Net(
    state_shape=env.observation_space["observation"].shape
                or env.observation_space["observation"].n,
    action_shape=env.action_space.shape or env.action_space.n,
    hidden_sizes=[128, 128, 128, 128],
    device=device,
).to(device)
optim = torch.optim.Adam(net.parameters(), lr=1e-4)
p1 = DQNPolicy(
    model=net,
    optim=optim,
    discount_factor=0.9,
    estimation_step=3,
    target_update_freq=320,
)
# Load the model parameters
p1.load_state_dict(torch.load('tictactoe_dqn.pth'))

p2 = ts.policy.RandomPolicy()
# keep the same agent order as in training: index 0 is the random opponent, index 1 is the learner
po = ts.policy.MultiAgentPolicyManager([p2, p1], env)

venv = ts.env.DummyVectorEnv([lambda: env])

collector = ts.data.Collector(po, venv)
result = collector.collect(n_episode=1, render=0.2)  # sleep about 0.2 s between rendered frames
print(result)
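
To get a rough measure of how well the trained agent does against the random opponent, the same collector can also be run for many episodes without rendering. A minimal sketch, assuming the collect result exposes per-episode rewards under the "rews" key with one column per agent, as the training script's reward_metric relies on:

# Sketch: evaluate over 100 episodes without rendering; assumes result["rews"]
# has shape (episodes, num_agents), with the learner in column 1 as during training.
eval_result = collector.collect(n_episode=100)
print("mean reward of the trained agent:", eval_result["rews"][:, 1].mean())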

The output looks like this (the screenshot is incomplete):
[screenshot of the test run output]

The problem I am currently running into: loading the model with Tianshou's method policy.load_state_dict(torch.load('tictactoe_dqn.pth')) does not work; it keeps complaining that this function does not exist. So I still use pickle to save and load the model.
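
One guess, not verified against the exact Tianshou version used here: the state_dict saved in train.py belongs to the single DQNPolicy, so it has to be loaded back into a DQNPolicy built with the same network (which is what t_ttt.py does), or reached through the manager's policies dict, rather than called on the wrapper object itself. A sketch of both directions; the agent name "player_2" is assumed to be the learner's seat from training:

# Load the saved weights into the single DQN policy (what t_ttt.py already does):
p1.load_state_dict(torch.load('tictactoe_dqn.pth'))

# Or reach the sub-policy inside the MultiAgentPolicyManager by agent name
# ("player_2" is an assumption, matching the learner's seat during training):
po.policies["player_2"].load_state_dict(torch.load('tictactoe_dqn.pth'))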

更多推荐