Не обучается нейронная сеть играть в змейку

Недавно решил создать нейронную сеть, обучаемую генетическим алгоритмом: написал игру на Pygame, саму нейронную сеть и генетический алгоритм. После запуска всё работает, но даже через 30–40 эпох результат змейки практически не меняется. Помогите, пожалуйста, найти ошибку.

import random
import keras
import time
from keras import backend as K
from tensorflow.keras.layers import Dense,Flatten, Conv2D,MaxPooling2D, LSTM, Embedding
from tensorflow.keras import layers
from tensorflow.keras.optimizers import Adam
import numpy
import pyautogui
from PIL import Image, ImageGrab            # библиотека базы выборок Mnist
from tensorflow import keras
from tensorflow.keras.layers import Dense, Flatten, Dropout, Conv2D, MaxPooling2D
from PIL import Image
import pygame
# --- genetic-algorithm hyper-parameters ---
POPULATION = 30      # individuals evaluated per epoch
P_CROSSOVER = 0.9    # chance that a parent pair is recombined in pr_cross()
P_MUTATION = 0.01    # per-gene chance of being re-drawn in mutation()

# --- shared state, mutated by the training loop ---
weights, os_balls = [], []   # per-individual Keras weight lists / per-individual scores
window_side = 1920, 1020     # size of the frameless window game() opens first

# 1 -> in epoch 1, create_model() loads each individual's weights from disk
version = int(input("::"))
def create_model(epoch, i):
    """Build the policy network for individual ``i``.

    Architecture: 1440 inputs -> Dense(16, relu) -> Dense(16, relu)
    -> Dense(8, relu) -> Dense(4, softmax); the 4 outputs are turned into a
    move code by hod().

    Weights:
      * epoch 1: the freshly initialised weights (or, when ``version == 1``,
        the ones loaded from disk) are appended to the global ``weights``;
      * later epochs: ``weights[i]`` (produced by Turnier()/cross()) is loaded.

    Args:
        epoch: current epoch number (1-based).
        i: index of the individual in the population.

    Returns:
        The compiled Keras model.
    """
    model = keras.Sequential()
    model.add(Dense(16, activation="relu", input_shape=(1440,)))
    model.add(Dense(16, activation="relu"))
    model.add(Dense(8, activation="relu"))
    model.add(Dense(4, activation="softmax"))
    model.compile(optimizer=Adam(0.001), loss="mse")
    if epoch != 1:
        model.set_weights(weights[i])
    else:
        if version == 1:
            # summary() prints the table itself and returns None; wrapping it in
            # print() only added a stray "None" line.
            model.summary()
            model.load_weights(rf"D:\\python\\neyro_weighst{str(i)}.h5")
            print("Загрузка успешно")
        weights.append(model.get_weights())
    return model

def hod(l):
    """Translate a one-row prediction batch into a move code.

    ``l[0]`` holds the network's scores. The code is the 1-based position of
    the first score equal to the maximum among the first three entries;
    anything else yields 4. game() maps 1/2/3/4 to the steps
    (-5, 0) / (5, 0) / (0, -5) / (0, 5).
    """
    row = l[0]
    best = max(row)
    for code, value in enumerate(row[:3], start=1):
        if value == best:
            return code
    return 4



class _SnakeDied(Exception):
    """Raised inside game() when the snake hits a wall or its own body."""


def resolve_move(last, mov):
    """Return the direction the snake actually takes this step.

    ``mov`` is the requested 5-pixel step and ``last`` the previous one. A
    request to reverse straight back (e.g. left while moving right) is
    ignored and ``last`` is kept; any other known direction is accepted.
    An unknown ``mov`` also keeps ``last``.

    The previous inline version tested ``(-5, 0)`` twice and never tested
    ``(0, -5)``, so the snake could never turn up and could reverse from
    right to left into its own body.
    """
    opposite = {(-5, 0): (5, 0), (5, 0): (-5, 0), (0, 5): (0, -5), (0, -5): (0, 5)}
    if mov in opposite and last != opposite[mov]:
        return mov
    return last


def game(neyro, i, lost_hod):
    """Play one game of Snake driven by ``neyro`` and return its fitness.

    Every step a screen region is grabbed, downscaled to 60x24 grayscale,
    flattened to 1440 values and passed to ``neyro.predict``; hod() turns
    the output into a direction. The game ends on a collision with a wall or
    the snake's body, or after ``lost_hod`` steps (the budget doubles each
    time an apple is eaten).

    Args:
        neyro: model exposing ``predict`` (1440 inputs, 4 outputs).
        i: index of the individual, used only in log messages.
        lost_hod: initial step budget.

    Returns:
        Fitness: +1 per completed step, doubled on every apple, -50 on a
        collision.
    """
    balls = 0
    l_hod = 0
    global count_apple
    pygame.init()
    count_apple = 0
    # 1 right after an apple was eaten: skip the self-collision test for one
    # step and re-insert the move so the move history grows with the body.
    just_ate = 0
    pygame.display.set_mode(window_side, flags=pygame.NOFRAME)

    class Snake(pygame.sprite.Sprite):
        # segments are numbered in creation order (0 = head)
        number = 0

        def __init__(self, x, y, filename):
            pygame.sprite.Sprite.__init__(self)
            self.image = pygame.image.load(filename).convert_alpha()
            self.rect = self.image.get_rect(center=(x, y))
            self.number = Snake.number
            Snake.number += 1
            # update() only starts moving this segment once wait counts down to 1
            self.wait = Snake.number

        def update(self, moves):
            if self.wait == 1 or self.wait == 0:
                # segment k follows the k-th most recent move
                move = moves[::-1]
                self.rect.x += move[self.number][0]
                self.rect.y += move[self.number][1]
            else:
                self.wait -= 1

    class Apple(pygame.sprite.Sprite):
        def __init__(self):
            pygame.sprite.Sprite.__init__(self)
            self.image = pygame.image.load("D:\\python\\apple.png").convert_alpha()  # 5x5 PNG, fully red
            occupied = [seg.rect.center for seg in snakes]
            while True:
                x = random.randint(1, W)
                y = random.randint(1, H)
                # snap to the 5-pixel grid
                x = x - x % 5 + 5
                y = y - y % 5 + 5
                if (x, y) not in occupied and (x >= -5 and x <= W) and (y >= -5 and y <= H):
                    break
            self.rect = self.image.get_rect(center=(x, y))

    H, W = 150, 300
    sc = pygame.display.set_mode((W, H))

    clock = pygame.time.Clock()
    FPS = 60
    s1 = Snake((W // 2) + 5, H // 2 + 5, "D:\\python\\had.png")  # head: 5x5 PNG, fully green
    kkk = [(5, 0), (-5, 0), (0, 5), (0, -5)]
    move = kkk[random.randint(0, 3)]
    last_dir = move
    snakes = pygame.sprite.Group()
    apples = pygame.sprite.Group()
    # Register the head before placing the first apple so Apple() can avoid it.
    snakes.add(s1)
    a1 = Apple()
    apples.add(a1)
    moves = []

    def catch():
        """Grow the snake and respawn the apple when the head reaches it."""
        nonlocal a1, just_ate, balls, lost_hod
        global count_apple
        if s1.rect.collidepoint(a1.rect.center):
            k = Snake(*a1.rect.center, "D:\\python\\snake.png")  # body: 5x5 PNG, fully white
            snakes.add(k)
            a1.kill()
            a1 = Apple()
            apples.add(a1)
            just_ate = 1
            count_apple += 1
            balls = balls * 2
            lost_hod *= 2

    def dead():
        """Raise _SnakeDied if the head overlaps the body or leaves the field."""
        if just_ate == 0:
            for seg in snakes.sprites()[1:]:  # skip the head itself
                if s1.rect.collidepoint(seg.rect.center):
                    raise _SnakeDied
        if s1.rect.y <= -2 or s1.rect.y >= H:
            raise _SnakeDied
        if s1.rect.x <= -2 or s1.rect.x >= W:  # was `<=--2`, a double-negation typo
            raise _SnakeDied

    time.sleep(1)

    while l_hod != lost_hod:
        sc.fill((0, 0, 0))
        apples.draw(sc)
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                exit()
        # Observation: 300x120 screen region -> 60x24 grayscale -> 1440 values.
        # NOTE(review): this reads absolute screen coordinates, so it only sees
        # the game if the window sits there, and it runs before this
        # iteration's display.update(), i.e. it sees the previous frame.
        img2 = ImageGrab.grab((0, 30, 300, 150))
        img22 = img2.resize((60, 24))
        scra = numpy.asarray(img22.convert("L"))
        scc = scra.reshape(-1).tolist()
        scr = neyro.predict([scc])
        hodd = hod(scr)

        if hodd == 1:
            move = -5, 0
        elif hodd == 2:
            move = 5, 0
        elif hodd == 3:
            move = 0, -5
        elif hodd == 4:
            move = 0, 5
        move = resolve_move(last_dir, move)
        last_dir = move
        sc.blit(s1.image, s1.rect)
        moves.append(move)
        snakes.update(moves)
        del moves[0]
        try:
            dead()
        except _SnakeDied:
            balls -= 50
            moves.insert(0, move)
            print(f"Особь номер от столкновения {i} погибла с рузельтатом {balls} ")
            break
        if just_ate == 1:
            moves.insert(0, move)
            just_ate = 0
        catch()
        snakes.draw(sc)
        l_hod += 1
        balls += 1
        pygame.display.update()
        clock.tick(FPS)
    print(f"Особь номер {i} погибла с рузельтатом {balls} ")
    return balls



def Turnier(scores=None, genomes=None, population=None):
    """Select the genomes that form the next generation.

    The two best-scoring genomes come first, unchanged (elitism). The
    remaining slots are filled by 3-way tournament selection: three distinct
    individuals are drawn and the best of them wins.

    Every returned genome is an independent copy. cross() recombines the
    returned list in place, so shared references would let it overwrite the
    elites or make two "different" parents the same object.

    Args:
        scores: fitness per individual; defaults to the global ``os_balls``.
        genomes: weight list per individual (as from ``model.get_weights()``,
            i.e. a list of numpy arrays); defaults to the global ``weights``.
        population: size of the returned list; defaults to ``POPULATION``.

    Returns:
        A list of ``population`` weight lists.
    """
    if scores is None:
        scores = os_balls
    if genomes is None:
        genomes = weights
    if population is None:
        population = POPULATION

    def clone(idx):
        return [w.copy() for w in genomes[idx]]

    print(scores)
    # Rank indices, not scores: list.index() returned the same individual twice
    # when the top two scores were equal.
    ranked = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True)
    ves = [clone(k) for k in ranked[:min(2, population)]]

    tour_size = min(3, len(scores))
    while len(ves) < population:
        contenders = random.sample(range(len(scores)), tour_size)
        # The old loop never updated its running maximum, so it returned the
        # last contender that beat the first one rather than the best one.
        winner = max(contenders, key=lambda k: scores[k])
        ves.append(clone(winner))

    print("ТУРНИР!!!")
    print(max(scores))
    print(sum(scores) / len(scores))
    return ves

def mutation(P_MUTATION, iteration):
    """Re-draw random genes of ``iteration`` in place and return it.

    Each position is, with probability ``P_MUTATION``, overwritten by a new
    value in (-1, 1): a random sign times a random magnitude in [0, 1).
    Other positions are left untouched. The returned object is
    ``iteration`` itself.
    """
    genes = iteration
    pos = 0
    while pos < len(genes):
        if random.random() < P_MUTATION:
            sign = 1 if random.randint(0, 1) else -1
            genes[pos] = sign * random.random()
        pos += 1
    return genes

def pr_cross(P_CROSSOVER, P_MUTATION, iter1, iter2):
    """One-point crossover of two equal-length 1-D gene arrays, then mutation.

    With probability ``P_CROSSOVER`` the arrays are cut at a random point,
    their tails are swapped, and both children go through mutation().
    Otherwise the parents are returned as they are (and are not mutated).

    Returns:
        ``(child1, child2)``.
    """
    if random.random() < P_CROSSOVER:
        size = len(iter1)
        if size >= 4:
            # keep at least two genes from each parent on either side of the cut
            ind = random.randint(2, size - 2)
        elif size >= 2:
            # short genes (e.g. a Dense layer with < 4 units): randint(2, size - 2)
            # would raise ValueError here, so allow any interior cut point
            ind = random.randint(1, size - 1)
        else:
            # 0 or 1 gene: nothing to swap; the children are mutated copies
            ind = size
        iter_1 = mutation(P_MUTATION, numpy.concatenate((iter1[:ind], iter2[ind:])))
        iter_2 = mutation(P_MUTATION, numpy.concatenate((iter2[:ind], iter1[ind:])))
        return (iter_1, iter_2)
    else:
        return (iter1, iter2)
def cross():
    """Recombine consecutive pairs of the global ``weights`` in place.

    Individuals 0 and 1 (the elites Turnier() places first) are left
    untouched; pairs (2, 3), (4, 5), ... go through pr_cross(). Each
    ``weights[i]`` is a Keras weight list that is assumed to alternate
    kernel, bias (true for the all-Dense model built by create_model()).
    Kernels are crossed row by row; biases are crossed as whole vectors.
    """
    # Stop one short so that an odd population leaves the last individual
    # unpaired instead of indexing weights[POPULATION].
    for i in range(2, POPULATION - 1, 2):
        mom, dad = weights[i], weights[i + 1]
        for j in range(0, len(mom), 2):
            kernel_a, kernel_b = mom[j], dad[j]
            for row in range(len(kernel_a)):
                kernel_a[row], kernel_b[row] = pr_cross(
                    P_CROSSOVER, P_MUTATION, kernel_a[row], kernel_b[row]
                )
            mom[j + 1], dad[j + 1] = pr_cross(P_CROSSOVER, P_MUTATION, mom[j + 1], dad[j + 1])

epoch = 1
osob = 1  # NOTE(review): not read anywhere in this file
lost_hod = 60  # initial per-game step budget; grows by one every epoch

# Evolve forever: evaluate every individual, then select, recombine and mutate.
while True:
    os_balls = []  # scores of this epoch, read by Turnier()
    print(f"Эпоха = {epoch}")
    for i in range(POPULATION):
        model = create_model(epoch, i)
        os_balls.append(game(model, i, lost_hod))
        # same path create_model() loads from when version == 1
        model.save_weights(rf"D:\\python\\neyro_weighst{str(i)}.h5")
        time.sleep(1)
    time.sleep(1)

    weights = Turnier()  # selection
    cross()              # crossover + mutation, in place on `weights`
    print("Кросснигровер удался!!!!!!!!!!!!")

    lost_hod += 1
    epoch += 1

Ответы (1 шт):

Автор решения: sem4ik
model.add(Dense(16,activation = "relu", input_shape = (1440,)))
    model.add(Dense(16,activation='relu'))
    model.add(Dense(8,activation='relu'))
    model.add(Dense(4,activation = "softmax"))

Вот это уберите и почитайте про обучение с подкреплением (reinforcement learning). Возможно, сделать так, как вы хотите, и можно, но пользы от этого будет мало. Эволюционные алгоритмы применяют тогда, когда сложно явно сформулировать, чего вы хотите от своего кода, но точно не для такой прямолинейной задачи.

→ Ссылка