Integrating Q-Learning into a game

I can't figure out how to connect the two: train() expects an environment with reset() and step(), which Board doesn't provide, and I'm not sure where training should be started relative to the tkinter main loop. Network code:

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from collections import namedtuple


UP = "up"
DOWN = "down"
LEFT = "left"
RIGHT = "right"
TAKE = "take"
PASS = "pass"
PLAYER = "player"
GOLD = "gold"
WALL = "wall"
EMPTY = "empty"


# CNN architecture for the Q-network
class QNetwork(nn.Module):
    def __init__(self, input_channels, output_size):
        super(QNetwork, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=input_channels, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(128 * 7 * 7, 512)  # NB: the padded convolutions keep the spatial size, so this assumes a 7x7 map
        self.fc2 = nn.Linear(512, output_size)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Training hyperparameters
LEARNING_RATE = 0.001
GAMMA = 0.99
BATCH_SIZE = 64
REPLAY_MEMORY_SIZE = 10000
TARGET_UPDATE_FREQUENCY = 100
NOISE_SCALE = 0.1

# namedtuple for storing transitions in the replay buffer
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

# Replay buffer
replay_buffer = []

# Append a transition to the replay buffer, evicting the oldest entry when it is full
def add_to_replay_buffer(transition):
    replay_buffer.append(transition)
    if len(replay_buffer) > REPLAY_MEMORY_SIZE:
        replay_buffer.pop(0)

# Epsilon-greedy action selection; always returns an integer action index
# (0=up, 1=down, 2=left, 3=right, 4=take), matching Player.act in the game code
def select_action(state, policy_net, epsilon):
    if random.random() < epsilon:
        return random.randrange(5)  # random exploration over all 5 actions
    with torch.no_grad():
        return policy_net(state).argmax().item()  # greedy action from the Q-network

# Update the policy network on a random mini-batch: regress Q(s, a) toward r + GAMMA * max_a' Q_target(s', a')
def optimize_model(policy_net, target_net, optimizer):
    if len(replay_buffer) < BATCH_SIZE:
        return
    transitions = random.sample(replay_buffer, BATCH_SIZE)
    batch = Transition(*zip(*transitions))
    
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)), dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
    
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)
    
    state_action_values = policy_net(state_batch).gather(1, action_batch)
    
    next_state_values = torch.zeros(BATCH_SIZE)
    next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
    
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch
    
    loss = nn.MSELoss()(state_action_values, expected_state_action_values.unsqueeze(1))
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Deep Q-Learning training loop with a periodically synced target network
def train(policy_net, target_net, optimizer, env, num_episodes):
    # env is assumed to provide reset() -> raw state and step(action) -> (raw state, reward, done)
    for episode in range(num_episodes):
        state = process_state(env.reset())
        total_reward = 0
        for t in range(1000):  # cap the number of steps per episode to avoid an endless loop
            epsilon = max(0.01, 0.1 - 0.01 * (episode / 200))  # anneal epsilon as training progresses
            action = select_action(state, policy_net, epsilon)
            next_state, reward, done = env.step(action)
            reward = torch.tensor([reward], dtype=torch.float32)
            total_reward += reward.item()
            # process the next state only once; terminal states are stored as None
            next_state = process_state(next_state) if not done else None
            # actions are stored with shape (1, 1) so that torch.cat + gather(1, ...) works in optimize_model
            add_to_replay_buffer(Transition(state, torch.tensor([[action]]), next_state, reward))
            state = next_state
            optimize_model(policy_net, target_net, optimizer)
            if done:
                break
        if episode % TARGET_UPDATE_FREQUENCY == 0:
            target_net.load_state_dict(policy_net.state_dict())  # sync the target network weights

# Convert the raw game-state dict into a (1, 1, height, width) float tensor
def process_state(state):
    # Board.load_level builds the map as map[x][y] (x = column, y = row),
    # so the outer length is the width and the inner length is the height
    map_width, map_height = len(state["map"]), len(state["map"][0])
    processed_state = np.zeros((map_height, map_width))  # image-style layout: rows x cols
    for x in range(map_width):
        for y in range(map_height):
            cell = state["map"][x][y]
            if cell == "#":  # walls are stored in the map as "#", not as the WALL constant
                processed_state[y][x] = 1.0
            elif cell.isdigit():
                processed_state[y][x] = 0.5  # gold
    player_x, player_y = state["player_position"]
    processed_state[player_y][player_x] = 0.75  # player position
    processed_state = processed_state / np.max(processed_state)  # normalize to [0, 1]
    processed_state = processed_state[np.newaxis, np.newaxis, :, :]  # add batch and channel dimensions
    return torch.tensor(processed_state, dtype=torch.float32)  # PyTorch tensor of shape (1, 1, H, W)
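
For context: as far as I can tell, train() assumes that env exposes reset() returning the raw game state and step(action) returning (next_state, reward, done). Something like the adapter below is what I have in mind, but I'm not sure it is right; BoardEnv, its reward scheme and its reset logic are my own placeholders, not part of the game:

class BoardEnv:
    """Sketch of the reset/step interface that train() expects, wrapped around Board."""
    def __init__(self, board):
        self.board = board
        self.player = board.players[0]  # single-player assumption, same as get_game_state

    def reset(self):
        # placeholder: a real reset would reload the level; here we only return the current state
        return self.board.get_game_state()

    def step(self, action):
        gold_before = self.player.gold
        dx, dy = 0, 0
        if action == 0:
            dy -= 1  # up
        elif action == 1:
            dy += 1  # down
        elif action == 2:
            dx -= 1  # left
        elif action == 3:
            dx += 1  # right
        elif action == 4:
            self.player.take()
        self.player.move(dx, dy)
        reward = self.player.gold - gold_before  # placeholder reward: gold collected this step
        done = self.board.gold >= self.board.level["gold"]  # level complete
        return self.board.get_game_state(), reward, done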

Game code:

import time
import sys
import json
from importlib import import_module
from pathlib import Path
from random import randrange, shuffle
import tkinter as tk
from plitk import load_tileset, PliTk
from QNetwork import *  # the network code above, saved as QNetwork.py

SCALE = 1
DELAY = 50

UP = "up"
DOWN = "down"
LEFT = "left"
RIGHT = "right"
TAKE = "take"
PASS = "pass"
PLAYER = "player"
GOLD = "gold"
WALL = "wall"
EMPTY = "empty"


class Board:
    def __init__(self, game, canvas, label):
        self.game = game
        self.canvas = canvas
        self.label = label
        self.tileset = load_tileset(game["tileset"])
        self.screen = PliTk(canvas, 0, 0, 0, 0, self.tileset, SCALE)

    def load_players(self):
        self.players = []
        for i, name in enumerate(self.game["players"]):
            script = import_module(name).script
            tile = self.game["tiles"]["@"][i]
            # create_nn (defined below) needs the board and should return the trained policy network
            self.players.append(Player(name, script, self, tile, create_nn(self)))
        shuffle(self.players)

    def load_level(self):
        self.gold = 0
        self.steps = 0
        self.level = self.game["levels"][self.level_index]
        data = self.game["maps"][self.level["map"]]
        cols, rows = len(data[0]), len(data)
        self.map = [[data[y][x] for y in range(rows)] for x in range(cols)]
        self.has_player = [[None for y in range(rows)] for x in range(cols)]
        self.canvas.config(width=cols * self.tileset["tile_width"] * SCALE,
                           height=rows * self.tileset["tile_height"] * SCALE)
        self.level["gold"] = sum(sum(int(cell)
            if cell.isdigit() else 0 for cell in row) for row in data)
        self.screen.resize(cols, rows)
        for y in range(rows):
            for x in range(cols):
                self.update(x, y)
        for p in self.players:
            self.add_player(p, *self.level["start"])
        self.update_score()

    def get(self, x, y):
        if x < 0 or y < 0 or x >= self.screen.cols or y >= self.screen.rows:
            return "#"
        return self.map[x][y]

    def update(self, x, y):
        if self.has_player[x][y]:
            self.screen.set_tile(x, y, self.has_player[x][y].tile)
        else:
            self.screen.set_tile(x, y, self.game["tiles"][self.map[x][y]])

    def remove_player(self, player):
        self.has_player[player.x][player.y] = None
        self.update(player.x, player.y)

    def add_player(self, player, x, y):
        player.x, player.y = x, y
        self.has_player[x][y] = player
        self.update(x, y)

    def take_gold(self, x, y):
        self.gold += self.check("gold", x, y)
        self.map[x][y] = " "
        self.update(x, y)
        self.update_score()
    
    def get_game_state(self):
        state = {
            "map": self.map,
            "player_position": (self.players[0].x, self.players[0].y)  # Предположим, что у нас только один игрок
        }
        return state

    def check(self, cmd, *args):
        if cmd == "level":
            return self.level_index + 1
        x, y = args
        item = self.get(x, y)
        if cmd == "wall":
            return item == "#"
        if cmd == "gold":
            return int(item) if item.isdigit() else 0
        if cmd == "player":
            return item != "#" and self.has_player[x][y]
#############################################################################
    def play(self):
        for p in self.players:
            p.act()  # changed to call act() without arguments
            if self.gold >= self.level["gold"]:
                return self.select_next_level()
        self.steps += 1
        return self.steps < self.level["steps"]

    def update_score(self):
        lines = [("Level:%4d\n" % (self.level_index + 1))]
        players = sorted(self.players, key=lambda x: x.gold, reverse=True)
        for p in players:
            lines.append("%s:%4d" % (p.name, p.gold))
        self.label["text"] = "\n".join(lines)

    def select_next_level(self):
        self.level_index += 1
        if self.level_index < len(self.game["levels"]):
            self.load_level()
            return True
        return False


class Player:
    def __init__(self, name, script, board, tile, policy_net):
        self.name = name
        self.script = script
        self.board = board
        self.tile = tile
        self.x, self.y = 0, 0
        self.gold = 0
        self.policy_net = policy_net  # the neural network is passed in through the constructor

    def act(self):
        state = self.get_state()  # current game state as a tensor
        with torch.no_grad():
            action = self.policy_net(state).argmax().item()  # greedy action predicted by the network
        dx, dy = 0, 0
        if action == 0:
            dy -= 1
        elif action == 1:
            dy += 1
        elif action == 2:
            dx -= 1
        elif action == 3:
            dx += 1
        elif action == 4:
            self.take()
        self.move(dx, dy)

    def get_state(self):
        state = self.board.get_game_state()
        return process_state(state)

    def move(self, dx, dy):
        x, y = self.x + dx, self.y + dy
        board = self.board
        board.remove_player(self)
        if not board.check("wall", x, y) and not board.check("player", x, y):
            self.x, self.y = x, y
        board.add_player(self, self.x, self.y)

    def take(self):
        gold = self.board.check("gold", self.x, self.y)
        if gold:
            self.gold += gold
            self.board.take_gold(self.x, self.y)


def start_game():
    def update():
        t = time.time()
        if board.play():
            dt = int((time.time() - t) * 1000)
            root.after(max(DELAY - dt, 0), update)
        else:
            label["text"] += "\n\nGAME OVER!"

    root = tk.Tk()
    root.configure(background="black")
    canvas = tk.Canvas(root, bg="black", highlightthickness=0)
    canvas.pack(side=tk.LEFT)
    label = tk.Label(root, font=("TkFixedFont",),
                     justify=tk.RIGHT, fg="white", bg="gray20")
    label.pack(side=tk.RIGHT, anchor="n")
    filename = sys.argv[1] if len(sys.argv) == 2 else "game.json"
    game = json.loads(Path(filename).read_text())
    board = Board(game, canvas, label)
    root.after(0, update)
    root.mainloop()
    create_nn(board)  # note: mainloop() blocks, so this only runs after the game window is closed

def create_nn(board):
    # Build the networks
    input_channels = 1  # the processed state is a single-channel image
    output_size = 5  # number of actions
    policy_net = QNetwork(input_channels, output_size)
    target_net = QNetwork(input_channels, output_size)  # target network

    # Start the target network with the same weights as the policy network
    target_net.load_state_dict(policy_net.state_dict())

    # Optimizer
    optimizer = torch.optim.Adam(policy_net.parameters(), lr=LEARNING_RATE)

    # Run training; train() expects board to behave like an environment with reset()/step()
    num_episodes = 1000  # number of training episodes
    train(policy_net, target_net, optimizer, board, num_episodes)
    return policy_net  # hand the trained network back to the caller

start_game()
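
Part of what confuses me is the order of operations: root.mainloop() blocks, so the create_nn(board) call above only runs after the window is closed, and load_players()/load_level() are never called before the first board.play(). The flow I am aiming for is roughly the sketch below, but training inside load_players would happen before the level map is loaded, so I am clearly missing something (level_index = 0 is my own guess, it is never set anywhere in the snippet):

# Rough intended flow (sketch, not working code):
#   board = Board(game, canvas, label)
#   board.level_index = 0      # assumption: start at the first level
#   board.load_players()       # currently also triggers training via create_nn(board)
#   board.load_level()
#   root.after(0, update)      # start the GUI loop only after training has finished
#   root.mainloop()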
