Проблема переобучения сети transformer

Признаюсь сразу, что в нейронках не силён, однако нужно сегодня защищать проект. Попросил o1 решить мне задачу, в итоге получил вот этот код:

import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt
from typing import List, Dict

# miditok / miditoolkit
from miditok import REMI
from miditok.constants import CHORD_MAPS
from miditoolkit import MidiFile

# torchtoolkit (или другой способ сделать split)
from torchtoolkit.data import create_subsets

# Для паддинга последовательностей
from torch.nn.utils.rnn import pad_sequence
from torch import LongTensor

# Для смешанной точности и градиентного аккумулирования
from torch.cuda.amp import autocast, GradScaler

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Используемое устройство: {device}")

BASE_PATH = Path("C:/Users/Kitaro/Desktop/ai/maestro-v2.0.0")  # Папка с Maestro (MIDI)
TOKENS_PATH = Path("C:/Users/Kitaro/Desktop/ai/preprocessed_tokens")
TOKENS_PATH.mkdir(parents=True, exist_ok=True)

midi_paths = list(BASE_PATH.glob("**/*.mid")) + list(BASE_PATH.glob("**/*.midi"))
if not midi_paths:
    raise ValueError("MIDI-файлы не найдены. Проверьте указанный путь.")
print(f"Найдено {len(midi_paths)} MIDI-файлов.")

# Параметры для токенизатора
pitch_range = range(21, 109)
beat_res = {(0, 4): 8, (4, 12): 4}
nb_velocities = 32
additional_tokens = {
    "Chord": True,
    "Rest": True,
    "Tempo": True,
    "Program": False,
    "tempo_range": (40, 250),
    "nb_tempos": 32,
    "rest_range": (2, 8),
    "chord_maps": CHORD_MAPS,
    "chord_unknown": False,
    "chord_tokens_with_root_note": True,
}
special_tokens = ["PAD", "BOS", "EOS"]

tokenizer = REMI(
    pitch_range=pitch_range,
    beat_res=beat_res,
    nb_velocities=nb_velocities,
    additional_tokens=additional_tokens,
    special_tokens=special_tokens,
)

Дальше идёт закомментированная часть.

# def tokenize_all_midi_files(BASE_PATH: Path, tokenizer: REMI, tokens_path: Path):
#     # Собираем все .mid и .midi из подпапок
#     midi_paths = list(BASE_PATH.glob("**/*.mid")) + list(BASE_PATH.glob("**/*.midi"))
#     if not midi_paths:
#         raise ValueError("MIDI-файлы не найдены. Проверьте указанный путь.")
#     print(f"Найдено {len(midi_paths)} MIDI-файлов для токенизации.")

#     for midi_file in tqdm(midi_paths, desc="Токенизация"):
#         try:
#             midi = MidiFile(str(midi_file))
#             tokens = tokenizer.midi_to_tokens(midi)

#             # Преобразуем токены в обычный список для JSON
#             # REMI.midi_to_tokens() возвращает список из нескольких треков (обычно 1),
#             # поэтому [token.tolist() for token in tokens].
#             tokens_list = [track.ids for track in tokens]


#             # Сохраняем в JSON
#             output_file = tokens_path / f"{midi_file.stem}.json"
#             with open(output_file, "w", encoding="utf-8") as json_file:
#                 json.dump({"ids": tokens_list}, json_file)
#         except Exception as e:
#             print(f"Ошибка при чтении/токенизации файла {midi_file}: {e}")

#     print("Токенизация завершена!")

# # Раскомментируйте для запуска токенизации:
# tokenize_all_midi_files(BASE_PATH, tokenizer, TOKENS_PATH)
class MIDIDataset(Dataset):
    def __init__(
        self,
        files_paths: List[Path],
        min_seq_len: int,
        max_seq_len: int,
        tokenizer: REMI,
        sliding_step: int = 64
    ):
        """
        :param files_paths: список путей к JSON-файлам (токенам).
        :param min_seq_len: минимальная длина последовательности, чтобы не брать слишком короткие хвосты.
        :param max_seq_len: максимальная длина (обрезка).
        :param sliding_step: на сколько токенов сдвигать "окно" после каждого сэмпла.
        """
        self.samples = []
        for file_path in tqdm(files_paths, desc="Загрузка и нарезка токенов"):
            with open(file_path, "r", encoding="utf-8") as json_file:
                data = json.load(json_file)

            # Предполагаем, что data["ids"] это список (часто один трек)
            # Берём первый элемент data["ids"][0], если там один трек
            tokens = data["ids"][0]
            tokens_length = len(tokens)

            i = 0
            while i < tokens_length - min_seq_len:
                end_i = i + max_seq_len
                if end_i > tokens_length:
                    end_i = tokens_length
                # Выбираем срез
                sample_slice = tokens[i:end_i]
                self.samples.append(LongTensor(sample_slice))
                # "Скользящее окно"
                i += sliding_step
    def __getitem__(self, idx):
        # Простейший сценарий: input_ids = labels = один и тот же
        sample = self.samples[idx]
        return {"input_ids": sample, "labels": sample}

    def __len__(self):
        return len(self.samples)
    
tokens_files = list(TOKENS_PATH.glob("*.json"))
if not tokens_files:
    raise ValueError("Не найдено ни одного .json-файла в папке preprocessed_tokens.")

# УМЕНЬШАЕМ max_seq_len, ЧТОБЫ УПАСТЬ В ПАМЯТЬ
min_seq_len = 128
max_seq_len = 256  # <-- вместо 512
sliding_step = 64

dataset = MIDIDataset(tokens_files, min_seq_len, max_seq_len, sliding_step)
print(f"Всего сэмплов: {len(dataset)}")

# Разделяем на train/valid = 80/20
train_subset, valid_subset = create_subsets(dataset, [0.2])
print(f"Тренировочная выборка: {len(train_subset)}")
print(f"Валидационная выборка: {len(valid_subset)}")
def collate_fn(batch):
    input_ids = [item["input_ids"] for item in batch]
    labels = [item["labels"] for item in batch]
    # Паддинг до одинаковой длины
    input_ids_padded = pad_sequence(input_ids, batch_first=True, padding_value=0)
    labels_padded = pad_sequence(labels, batch_first=True, padding_value=0)
    return {"input_ids": input_ids_padded, "labels": labels_padded}
batch_size_physical = 2  # Физический размер батча (помещается в память)
accum_steps = 8           # Градиентное аккумулирование => эффективный batch ~ 16*4=64

train_loader = DataLoader(train_subset, batch_size=batch_size_physical, shuffle=True, collate_fn=collate_fn)
valid_loader = DataLoader(valid_subset, batch_size=batch_size_physical, shuffle=False, collate_fn=collate_fn)
class MusicTransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, dff, dropout_rate):
        super().__init__()
        self.self_attention = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, dff),
            nn.ReLU(),
            nn.Linear(dff, d_model),
        )
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.layer_norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        attn_output, _ = self.self_attention(x, x, x)
        attn_output = self.dropout(attn_output)
        out1 = self.layer_norm1(x + attn_output)

        ffn_output = self.feed_forward(out1)
        ffn_output = self.dropout(ffn_output)
        out2 = self.layer_norm2(out1 + ffn_output)
        return out2

class MusicTransformer(nn.Module):
    def __init__(self, num_classes, d_model, num_layers, num_heads, dff, dropout_rate):
        super().__init__()
        self.embedding = nn.Embedding(num_classes, d_model)
        self.encoder_layers = nn.ModuleList([
            MusicTransformerEncoderLayer(d_model, num_heads, dff, dropout_rate)
            for _ in range(num_layers)
        ])
        self.dropout = nn.Dropout(dropout_rate)
        self.output_layer = nn.Linear(d_model, num_classes)

    def forward(self, inputs):
        # inputs: (batch, seq_len)
        x = self.embedding(inputs)
        x = self.dropout(x)
        for encoder_layer in self.encoder_layers:
            x = encoder_layer(x)
        logits = self.output_layer(x)  # (batch, seq_len, num_classes)
        return logits
class MusicTransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, dff, dropout_rate):
        super().__init__()
        self.self_attention = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, dff),
            nn.ReLU(),
            nn.Linear(dff, d_model),
        )
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.layer_norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        attn_output, _ = self.self_attention(x, x, x)
        attn_output = self.dropout(attn_output)
        out1 = self.layer_norm1(x + attn_output)
        ffn_output = self.feed_forward(out1)
        ffn_output = self.dropout(ffn_output)
        out2 = self.layer_norm2(out1 + ffn_output)
        return out2

class MusicTransformer(nn.Module):
    def __init__(self, num_classes, d_model, num_layers, num_heads, dff, dropout_rate):
        super().__init__()
        self.embedding = nn.Embedding(num_classes, d_model)
        self.encoder_layers = nn.ModuleList([
            MusicTransformerEncoderLayer(d_model, num_heads, dff, dropout_rate)
            for _ in range(num_layers)
        ])
        self.dropout = nn.Dropout(dropout_rate)
        self.output_layer = nn.Linear(d_model, num_classes)

    def forward(self, inputs):
        x = self.embedding(inputs)
        x = self.dropout(x)
        for layer in self.encoder_layers:
            x = layer(x)
        return self.output_layer(x)
    num_classes = 420
model = MusicTransformer(
    num_classes=num_classes,
    d_model=64,
    num_layers=2,
    num_heads=2,
    dff=256,
    dropout_rate=0.1
).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scaler = GradScaler()

num_epochs = 4

num_epochs = 4
validation_losses = []
train_accuracies = []

try:
    for epoch in range(num_epochs):
        # ---------- TRAIN ----------
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        optimizer.zero_grad()

        for step, batch_data in enumerate(tqdm(train_loader, desc=f"[Epoch {epoch+1}/{num_epochs}] Training")):
            input_ids = batch_data["input_ids"].to(device)
            labels = batch_data["labels"].to(device)

            # Смешанная точность
            with autocast():
                outputs = model(input_ids)
                loss = criterion(outputs.view(-1, num_classes), labels.view(-1))

            # backward + grad scaling
            scaler.scale(loss).backward()

            # Градиентное аккумулирование
            if (step + 1) % accum_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            running_loss += loss.item()

            # accuracy (token-level)
            _, predicted = torch.max(outputs, dim=-1)
            correct += (predicted == labels).sum().item()
            total += labels.numel()

            # Можно логи каждые N шагов
            if (step + 1) % 20000 == 0:
                acc = 100.0 * correct / total
                print(f" Step {step+1}, Loss: {loss.item():.4f}, Accuracy: {acc:.2f}%")

        epoch_train_loss = running_loss / len(train_loader)
        epoch_train_acc = 100.0 * correct / total
        train_accuracies.append(epoch_train_acc)
        print(f"[Epoch {epoch+1}] Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.2f}%")

        # ---------- VALID ----------
        model.eval()
        valid_loss_val = 0.0
        valid_correct = 0
        valid_total = 0

        with torch.no_grad():
            for batch_data in tqdm(valid_loader, desc="Validation", leave=False):
                input_ids = batch_data["input_ids"].to(device)
                labels = batch_data["labels"].to(device)

                with autocast():
                    outputs = model(input_ids)
                    loss = criterion(outputs.view(-1, num_classes), labels.view(-1))

                valid_loss_val += loss.item()

                _, predicted = torch.max(outputs, dim=-1)
                valid_correct += (predicted == labels).sum().item()
                valid_total += labels.numel()

        avg_valid_loss = valid_loss_val / len(valid_loader)
        validation_losses.append(avg_valid_loss)

        valid_acc = 100.0 * valid_correct / valid_total
        print(f"[Epoch {epoch+1}] Valid Loss: {avg_valid_loss:.4f}, Valid Acc: {valid_acc:.2f}%\n")

except KeyboardInterrupt:
    print("Обучение прервано вручную!")
    torch.save(model.state_dict(), "trained_model_interrupted.pth")
    print("Сохранена модель trained_model_interrupted.pth")
    plt.figure(figsize=(10, 4))

plt.subplot(1, 2, 1)
plt.plot(range(1, len(validation_losses) + 1), validation_losses, label="Validation Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Validation Loss Over Epochs")
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(range(1, len(train_accuracies) + 1), train_accuracies, label="Train Accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy (%)")
plt.title("Train Accuracy Over Epochs")
plt.legend()

plt.tight_layout()
plt.show()
model_save_path = "trained_model_final.pth"
torch.save(model.state_dict(), model_save_path)
print(f"Модель сохранена в {model_save_path}")

Закомментированная часть

# Если захотите загрузить потом в другом месте:
model_loaded = MusicTransformer(num_classes, d_model, num_layers, num_heads, dff, dropout_rate)
model_loaded.load_state_dict(torch.load("trained_model_final.pth"))
model_loaded.eval()
model_loaded.to(device)
print("Модель успешно загружена!")
'''
BOS_token = tokenizer.special_tokens.index("BOS")
EOS_token = tokenizer.special_tokens.index("EOS")

def generate_music_greedy(model, seed_seq, max_length=200, min_tokens=30):
    model.eval()
    generated = seed_seq.clone().to(device)
    for step in range(max_length):
        inputs = generated.unsqueeze(0)  # (1, seq_len)
        outputs = model(inputs)          # (1, seq_len, num_classes)
        next_logits = outputs[0, -1, :]
        next_token_id = torch.argmax(next_logits, dim=-1).unsqueeze(0)  # (1,)

        generated = torch.cat([generated, next_token_id], dim=0)

        if step > min_tokens and next_token_id.item() == EOS_token:
            print("Сгенерирован EOS, завершаем генерацию.")
            break

    return generated.detach().cpu().tolist()

# Простой seed — BOS
seed_seq = torch.tensor([BOS_token], dtype=torch.long)

print("\n--- Генерация музыки (greedy) ---")
generated_ids = generate_music_greedy(model, seed_seq, max_length=300, min_tokens=50)
print("Сгенерированы токены (пример первых 20):", generated_ids[:20])
print(f"Всего токенов сгенерировано: {len(generated_ids)}")

# Преобразуем в MIDI
generated_midi = tokenizer.tokens_to_midi([generated_ids])
out_midi_path = "generated_music.mid"
generated_midi.dump(out_midi_path)
print(f"Сгенерированный MIDI-файл сохранён как {out_midi_path}")

1 эпоха

2 Эпоха-3 Эпоха

Датасет с кегля : https://www.kaggle.com/code/robbynevels/maestro-metadata-wav-midi-performance-events/input

Помогите, я не понимаю....


Ответы (0 шт):