Создание собственного датасета из тензоров
передо мной стоит задача классификации звука по спектрограммам. У меня есть решение этой задачи одним способом (я конвертирую все аудиозаписи в спектрограммы - > сохраню их как картинки и обучу на этом нейронку) но я хочу пойти более простым путем, то есть не сохранять картинки а сразу преобразовать аудиофайлы в тензоры, но есть проблема, я никак не могу найти путной информации по тому как создать свой датасет из тензоров именно в TensorFlow. Приведу пример такого кода на PyTorch.
class SoundDataset(Dataset):
def __init__(self, file_names, labels):
self.file_names = file_names
self.labels = labels
def __getitem__(self,index):
#format the file path and load the file
path = self.file_names[index]
scale, sr = librosa.load(path)
filter_banks = librosa.filters.mel(n_fft=2048, sr=22050, n_mels=10)
mel_spectrogram = librosa.feature.melspectrogram(scale, sr=sr, n_fft=2048, hop_length=512, n_mels=32)
log_mel_spectrogram = librosa.power_to_db(mel_spectrogram)
trch = torch.from_numpy(log_mel_spectrogram)
if log_mel_spectrogram.shape !=(10,87):
delta = 87 - log_mel_spectrogram.shape[1]
trch = torch.nn.functional.pad(trch, (0,delta))
return trch,self.labels[index]
def __len__(self):
return len(self.file_names)
Вот создается класс, который принимает в себя пути к аудиозаписям и конвертирует их в тензоры, и паддит нулями если вдруг тензоры не подходят по шейпам. Как я могу создать такой же класс для TensorFlow. Далее пример кода, который создает кортежи с путями к файлам и их классом и создает объект класса SoundDataset и соответсвенно генерирует датасет из этих файлов. Все это написано для PyTorch. Подскажите каким образом можно реализовать для TensorFlow.
path = '/content/drive/MyDrive/МДМА/audiodata/for-rerecorded/training/'
files = []
labels = []
lbl = '1 0'.split()
for lab in lbl:
if lab == '0':
c = 'fake'
else:
c ='real'
names = os.listdir(path+c)
for n in names:
pth = path+c+'/'+n
files.append(pth)
labels.append(int(lab))
train_dataset = SoundDataset(files, labels)
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size = 20)