Глубокое обучение PyTorch
что я делаю не так? почему на графике, на обычных и на тестовых данных непонятно что?(только начал проходить ,не судите строго)
import torch.nn
from sklearn import datasets
from IPython.display import clear_output
import matplotlib.pyplot as plt
from collections import OrderedDict
from IPython.display import clear_output
import matplotlib.pyplot as plt
def custom_compare(x, y):
if str(x) != str(y):
raise RuntimeError(f'Ожидаемое значение: {y}. Фактическое: {x}')
def to_list(x, precision=2):
return [round(x, precision) for x in x.tolist()]
def to_list_m(m, precision=2):
res = []
for l in m.tolist():
res.append([round(x, precision) for x in l])
return res
data = datasets.load_iris()
X = torch.tensor(data["data"])
y = torch.tensor(data["target"])
model = torch.nn.Sequential(
OrderedDict([
('linear_1', torch.nn.Linear(4, 4)),
('activation_1', torch.nn.ReLU()),
('linear_2', torch.nn.Linear(4, 3)),
('log_softmax', torch.nn.LogSoftmax(dim=1))
])
)
model = model.to(dtype=X.dtype)
optimizer = torch.optim.SGD(
model.parameters(),
lr=1.,
)
losses = []
loss_fn = torch.nn.NLLLoss()
for step in range(1, 1001):
optimizer.zero_grad()
y_pred = model(X)
loss = loss_fn(y_pred, torch.tensor(y, dtype=torch.long))
losses.append(loss.item())
loss.backward()
optimizer.step()
print(losses[-1])
test_data = torch.tensor(
[[6.1925, 2.8127, 4.8053, 1.8340],
[6.3122, 3.3084, 4.6696, 1.5903],
[6.4308, 2.7846, 5.6049, 2.0857],
[6.1065, 2.7853, 3.9796, 1.3208],
[5.1890, 4.1195, 1.4894, 0.1041],
[6.4091, 2.7052, 5.3171, 1.9026],
[5.4741, 3.5095, 1.3004, 0.1830],
[6.7288, 3.0913, 4.7038, 1.4960],
[5.0875, 3.5243, 1.4057, 0.3117],
[5.3994, 3.8903, 1.7050, 0.4009],
[5.6904, 4.3916, 1.4684, 0.3964],
[4.9079, 3.0955, 1.4920, 0.1092],
[7.7159, 3.8090, 6.7016, 2.2142],
[4.8113, 3.0182, 1.3959, 0.2815],
[6.4310, 3.2257, 5.2900, 2.3065],
[6.9995, 3.1955, 4.7015, 1.3973],
[5.6823, 2.9997, 4.2251, 1.2014],
[5.5815, 2.7192, 4.1900, 1.2832],
[5.9034, 3.1997, 4.7991, 1.8313],
[5.7005, 2.6195, 3.4773, 0.9757],
[4.9751, 3.5004, 1.3134, 0.2750],
[6.0946, 2.9318, 4.6946, 1.3818],
[5.0014, 3.2270, 1.1918, 0.2007],
[5.8717, 3.0227, 4.2037, 1.5053],
[5.2060, 3.4223, 1.3902, 0.2009],
[4.3859, 3.2013, 1.3159, 0.2079],
[7.3128, 2.8799, 6.3334, 1.8338],
[6.7187, 3.1061, 5.5931, 2.4143],
[6.6812, 3.0036, 4.9912, 1.7009],
[7.1003, 2.9924, 5.8891, 2.0872],
[4.9905, 2.2989, 3.3021, 0.9962],
[6.2052, 3.4357, 5.4059, 2.2973],
[4.4913, 2.2953, 1.3294, 0.3075],
[7.9164, 3.7958, 6.4071, 1.9915],
[4.9946, 3.5099, 1.6099, 0.6029],
[6.9187, 3.1006, 5.4146, 2.0737],
[6.7196, 3.1102, 4.4057, 1.3974],
[6.3969, 2.8016, 5.5884, 2.1988],
[5.1093, 2.5149, 3.0349, 1.0820],
[5.0814, 3.3899, 1.5106, 0.2116],
[5.5119, 4.2125, 1.4162, 0.2252],
[6.5873, 2.8985, 4.6158, 1.3120],
[6.7868, 2.7933, 4.8130, 1.4186],
[5.7971, 2.6743, 3.8844, 1.1489],
[6.4954, 3.1968, 5.0977, 2.0127],
[6.3132, 2.5050, 4.8871, 1.4825],
[4.9923, 3.3990, 1.4873, 0.1776],
[5.8016, 2.6736, 5.1037, 1.8772],
[6.5899, 3.0080, 4.4031, 1.4098],
[6.7034, 3.2995, 5.6906, 2.5213],
[5.5726, 2.5027, 3.9056, 1.1082],
[4.6110, 3.1552, 1.4819, 0.2269],
[5.3962, 3.7292, 1.5056, 0.1840],
[4.6978, 3.1884, 1.2872, 0.2045],
[7.7259, 2.6203, 6.9175, 2.2707],
[4.9117, 3.5911, 1.3559, 0.1051],
[5.5060, 2.4992, 3.9971, 1.2857],
[6.0250, 2.2070, 3.9895, 0.9892],
[6.2824, 2.7039, 4.8852, 1.7950],
[6.5009, 2.9797, 5.4999, 1.8144],
[5.7074, 2.7965, 4.0783, 1.3030],
[7.1991, 3.0009, 5.7894, 1.5942],
[5.4843, 2.3981, 3.6846, 0.9985],
[5.9921, 2.9082, 4.4937, 1.5119],
[5.7057, 2.8921, 4.1759, 1.3012],
[4.9960, 3.3127, 1.3778, 0.1983],
[7.3930, 2.8093, 6.1120, 1.9125],
[4.8980, 3.0966, 1.5072, 0.2077],
[5.7947, 2.6964, 4.0938, 0.9683],
[5.5206, 2.4189, 3.8104, 1.0849],
[6.3227, 3.3981, 5.6103, 2.4238],
[5.3746, 3.4166, 1.5101, 0.4062],
[7.7002, 3.0019, 6.1238, 2.3163],
[7.6668, 2.7952, 6.6920, 2.0174],
[5.7909, 2.7058, 5.1240, 1.8787],
[4.8202, 3.4038, 1.9038, 0.1706],
[4.6180, 3.3850, 1.3792, 0.2890],
[6.0207, 3.3968, 4.4946, 1.5936],
[6.7062, 2.5106, 5.7983, 1.7929],
[7.1862, 3.1946, 6.0065, 1.8011],
[6.9320, 3.1704, 5.6946, 2.3006],
[5.0794, 3.5046, 1.3968, 0.1867],
[5.4282, 2.9711, 4.5280, 1.5053],
[6.2885, 2.8019, 5.0646, 1.5129],
[6.1996, 2.9017, 4.2869, 1.3103],
[6.9157, 3.0832, 5.0973, 2.3123],
[5.5830, 2.7712, 4.9095, 1.9998],
[4.5919, 3.5793, 0.9960, 0.1591],
[6.4994, 2.9805, 5.7897, 2.1949],
[5.6909, 3.8175, 1.7168, 0.3178],
[4.7868, 3.0152, 1.3686, 0.0763],
[5.7133, 2.5093, 4.9902, 1.9970],
[5.7222, 2.7908, 4.4754, 1.2860],
[5.8094, 2.5860, 3.9722, 1.1959],
[6.7882, 3.0047, 5.4856, 2.0806],
[4.9022, 3.0257, 1.3815, 0.2020],
[6.7021, 3.3029, 5.7350, 2.1114],
[4.3860, 2.8866, 1.4093, 0.1907],
[6.1990, 2.2013, 4.4786, 1.5149],
[5.8978, 3.0139, 5.1014, 1.8140],
[6.1015, 2.9760, 4.6244, 1.4117],
[6.1213, 2.9865, 4.9206, 1.7828],
[5.0933, 3.8093, 1.5047, 0.3021],
[6.0865, 2.8123, 4.7077, 1.2249],
[7.6094, 2.9810, 6.6075, 2.0891],
[6.3063, 3.2640, 5.9935, 2.4758],
[5.6007, 3.0030, 4.0734, 1.2912],
[4.6801, 3.1838, 1.5982, 0.1884],
[4.8895, 2.4996, 4.5087, 1.7239],
[5.8046, 2.7879, 5.0710, 2.3901],
[6.0216, 3.0023, 4.7838, 1.8201],
[6.5022, 2.8065, 4.6073, 1.4858],
[7.1887, 3.6003, 6.0947, 2.4976],
[6.9059, 3.0827, 4.8975, 1.5047],
[4.9918, 3.3971, 1.5863, 0.3994],
[6.7818, 3.2047, 5.8888, 2.3061],
[5.0897, 3.6993, 1.5022, 0.4099],
[6.3094, 2.2760, 4.3909, 1.2844],
[4.7939, 3.3883, 1.6162, 0.1956],
[5.1004, 3.3099, 1.6809, 0.4906],
[4.6038, 3.1900, 1.4267, 0.2038],
[6.4056, 3.0822, 5.4876, 1.7860],
[4.7875, 3.0931, 1.5968, 0.2039],
[5.2887, 3.6922, 1.5052, 0.1877],
[5.1729, 2.6834, 3.8810, 1.3679],
[5.6013, 2.9898, 4.4625, 1.5142],
[6.0093, 2.1654, 4.9892, 1.5131],
[4.9704, 3.6133, 1.4107, 0.1914],
[6.0135, 2.6794, 5.1117, 1.5889],
[4.9013, 2.4039, 3.3006, 0.9957],
[6.2933, 2.8898, 5.6077, 1.7835],
[6.1177, 2.5915, 5.6039, 1.3830],
[4.3163, 3.0072, 1.1065, 0.0893],
[5.0958, 3.7942, 1.9179, 0.3851],
[5.1961, 3.4681, 1.5010, 0.1978],
[6.4864, 2.9841, 5.2081, 2.0097],
[6.2862, 2.5160, 4.9928, 1.9029],
[5.3791, 3.4206, 1.7202, 0.2003],
[6.7209, 2.9944, 5.2160, 2.3220],
[5.4871, 2.3214, 4.0096, 1.3024],
[6.4028, 3.1972, 4.5295, 1.4939],
[5.0231, 2.9747, 1.6009, 0.2373],
[4.9869, 1.9933, 3.5222, 1.0113],
[5.0910, 3.8294, 1.5923, 0.2157],
[4.4141, 3.0291, 1.2948, 0.1959],
[6.3924, 2.8996, 4.2981, 1.2741],
[5.3970, 3.8881, 1.3111, 0.3878],
[5.5865, 2.8803, 3.5829, 1.3133],
[5.8151, 3.9625, 1.1992, 0.1796],
[5.5356, 2.6235, 4.3694, 1.1910]], dtype=torch.float64)
log_probs = model(test_data)
pred_class_probs = log_probs.exp()
pred_class = torch.argmax(log_probs, dim=1)
print(to_list(pred_class))