Ошибка при оптимизации потерь RuntimeError: Trying to backward through the graph a second time

Я хочу оптимизировать свою модель с помощью апекса NVIDIA amp. Я добавляю масштабирование потерь в свой training_loop (), но получаю сообщение об ошибке. Я читала о подобных проблемах у других в репозитории, там рекомендуют изменить процесс расчета loss. Я попыталась добавить retain_graph=True в loss_backward() но возникает та же ошибка. Как изменить расчет loss, чтобы не возникало исключений?

loss_fn = nn.CTCLoss(reduction='mean', zero_infinity=True)
scaler = GradScaler(growth_interval=1000)

def val_loop(data_loader, model, tokenizer, device):
    acc_avg = AverageMeter()
    for images, texts, _, _ in data_loader:
        batch_size = len(texts)
        text_preds = predict(images, model, tokenizer, device)
        acc_avg.update(get_accuracy(texts, text_preds), batch_size)
    print(f'Validation, acc: {acc_avg.avg:.4f}')
    return acc_avg.avg


def train_loop(data_loader, model, criterion, optimizer, epoch):
    loss_avg = AverageMeter()
    model.train()
    for images, texts, enc_pad_texts, text_lens in data_loader:
        model.zero_grad()
        images = images.to(DEVICE)
        batch_size = len(texts)
        output = model(images)
        output_lenghts = torch.full(
            size=(output.size(1),),
            fill_value=output.size(0),
            dtype=torch.long
        )

        loss = criterion(output, enc_pad_texts, output_lenghts, text_lens)
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        loss_avg.update(loss.item(), batch_size)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 2)
        optimizer.step() 

    for param_group in optimizer.param_groups:
        lr = param_group['lr']
    print(f'\nEpoch {epoch}, Loss: {loss_avg.avg:.5f}, LR: {lr:.7f}')
    return loss_avg.avg

    for param_group in optimizer.param_groups:
        lr = param_group['lr']
    print(f'\nEpoch {epoch}, Loss: {loss_avg.avg:.5f}, LR: {lr:.7f}')
    return loss_avg.avg


def predict(images, model, tokenizer, device):
    model.eval()
    images = images.to(device)
    with torch.no_grad():
        output = model(images)
     
    pred = torch.argmax(output.detach().cpu(), -1).permute(1, 0).numpy()
    text_preds = tokenizer.decode(pred)
    
    return text_preds


def get_loaders(tokenizer, config):
    train_transforms = get_train_transforms(
        height=config['image']['height'],
        width=config['image']['width']
    )
    train_loader = get_data_loader(
        json_path=config['train']['json_path'],
        root_path=config['train']['root_path'],
        transforms=train_transforms,
        tokenizer=tokenizer,
        batch_size=config['train']['batch_size'],
        drop_last=True
    )
    val_transforms = get_val_transforms(
        height=config['image']['height'],
        width=config['image']['width']
    )
    val_loader = get_data_loader(
        transforms=val_transforms,
        json_path=config['val']['json_path'],
        root_path=config['val']['root_path'],
        tokenizer=tokenizer,
        batch_size=config['val']['batch_size'],
        drop_last=False
    )
    return train_loader, val_loader


def train(config):
    tokenizer = Tokenizer(config['alphabet'])
    os.makedirs(config['save_dir'], exist_ok=True)
    train_loader, val_loader = get_loaders(tokenizer, config)

    model = CRNN(number_class_symbols=tokenizer.get_num_chars())
    model.to(DEVICE)

    criterion = torch.nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)
    optimizer = torch.optim.SGD(model.parameters(), lr= 1e-3, weight_decay=0.1)

    model.load_state_dict(torch.load("/content/drive/MyDrive/model-2-0.7323.ckpt"), strict=False)

    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    scheduler = scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer,
        epochs=config.get('num_epochs'),
        steps_per_epoch=len(train_loader),
        max_lr=0.01,
        pct_start=0.07,
        anneal_strategy='cos',
        final_div_factor=10 ** 5
    )

    best_acc = -np.inf
    acc_avg = val_loop(val_loader, model, tokenizer, DEVICE)
    for epoch in range(config['num_epochs']):
     
        loss_avg = train_loop(train_loader, model, criterion, optimizer, epoch)
        acc_avg = val_loop(val_loader, model, tokenizer, DEVICE)
        scheduler.step(acc_avg)
        print(f"Epoch: {epoch} Loss_avg: {loss_avg} Acc_avg: {acc_avg} Step {scheduler.step(acc_avg)}" )
        if acc_avg > best_acc:
            best_acc = acc_avg
            model_save_path = os.path.join(
                config['save_dir'], f'model-{epoch}-{acc_avg:.4f}.ckpt')
            torch.save(model.state_dict(), model_save_path)
            print('Model weights saved')

<ipython-input-31-b4336415bd64> in train(config)
    126     for epoch in range(config['num_epochs']):
    127 
--> 128         loss_avg = train_loop(train_loader, model, criterion, optimizer, epoch)
    129         acc_avg = val_loop(val_loader, model, tokenizer, DEVICE)
    130         scheduler.step(acc_avg)

<ipython-input-31-b4336415bd64> in train_loop(data_loader, model, criterion, optimizer, epoch)
     40             scaled_loss.backward()
     41         loss_avg.update(loss.item(), batch_size)
---> 42         loss.backward(retain_graph=True)
     43 
     44         torch.nn.utils.clip_grad_norm_(model.parameters(), 2)

/usr/local/lib/python3.7/dist-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
    305             See https://pytorch.org/docs/master/distributions.html
    306 
--> 307             Instead of:
    308 
    309             probs = policy_network(state)

/usr/local/lib/python3.7/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    154     inputs: _TensorOrTensors,
    155     grad_outputs: Optional[_TensorOrTensors] = None,
--> 156     retain_graph: Optional[bool] = None,
    157     create_graph: bool = False,
    158     only_inputs: bool = True,

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

Ответы (0 шт):