Построение метрик fpr, tpr
Пытаюсь посчитать fpr, tpr, но получаю ошибку 'DeePixBiS' object has no attribute 'predict_proba' Как можно исправить?
model = DeePixBiS()
model.load_state_dict(torch.load('./DeePixBiS.pth'))
probs = model.predict_proba(X_test)
preds = probs[:,1]
fpr, tpr, threshold = metrics.roc_curve(y_test, preds)
roc_auc = metrics.auc(fpr, tpr)
import matplotlib.pyplot as plt
plt.title('Receiver Operating Characteristic')
plt.plot(fpr, tpr, 'b', label = 'AUC = %0.2f' % roc_auc)
plt.legend(loc = 'lower right')
plt.plot([0, 1], [0, 1],'r--')
plt.xlim([0, 1])
plt.ylim([0, 1])
plt.ylabel('True Positive Rate')
plt.xlabel('False Positive Rate')
plt.show()
DeePixBiS
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import models
# import torch.sigmoid as sigmoid
class DeePixBiS(nn.Module):
def __init__(self, pretrained=True):
super().__init__()
dense = models.densenet161(pretrained=pretrained)
features = list(dense.features.children())
self.enc = nn.Sequential(*features[:8])
self.dec = nn.Conv2d(384, 1, kernel_size=1, stride=1, padding=0)
self.linear = nn.Linear(14 * 14, 1)
def forward(self, x):
enc = self.enc(x)
dec = self.dec(enc)
out_map = F.sigmoid(dec)
# print(out_map.shape)
out = self.linear(out_map.view(-1, 14 * 14))
out = F.sigmoid(out)
out = torch.flatten(out)
return out_map, out
Predict
def predict(mask, label, threshold=0.5, score_type='combined'):
with torch.no_grad():
if score_type == 'pixel':
score = torch.mean(mask, axis=(1, 2, 3))
elif score_type == 'binary':
score = label
else:
score = (torch.mean(mask, axis=(1, 2, 3)) + label) / 2
preds = (score > threshold).type(torch.FloatTensor)
return preds, score
UPD:
import sklearn.metrics as metrics
from Model import DeePixBiS
import torch
model = DeePixBiS()
model.load_state_dict(torch.load('./DeePixBiS.pth'))
probs = model.predict(X_test)
preds = probs[:,1]
fpr, tpr, threshold = metrics.roc_curve(y_test, preds)
roc_auc = metrics.auc(fpr, tpr)
model = tf.keras.model.load_model(weights_file)
# method I: plt
import matplotlib.pyplot as plt
plt.title('Receiver Operating Characteristic')
plt.plot(fpr, tpr, 'b', label = 'AUC = %0.2f' % roc_auc)
plt.legend(loc = 'lower right')
plt.plot([0, 1], [0, 1],'r--')
plt.xlim([0, 1])
plt.ylim([0, 1])
plt.ylabel('True Positive Rate')
plt.xlabel('False Positive Rate')
plt.show()