Создание собственного датасета из тензоров

передо мной стоит задача классификации звука по спектрограммам. У меня есть решение этой задачи одним способом (я конвертирую все аудиозаписи в спектрограммы - > сохраню их как картинки и обучу на этом нейронку) но я хочу пойти более простым путем, то есть не сохранять картинки а сразу преобразовать аудиофайлы в тензоры, но есть проблема, я никак не могу найти путной информации по тому как создать свой датасет из тензоров именно в TensorFlow. Приведу пример такого кода на PyTorch.

class SoundDataset(Dataset):
  def __init__(self, file_names, labels):
    self.file_names = file_names
    self.labels = labels
  def __getitem__(self,index):
    #format the file path and load the file
    path = self.file_names[index]
    scale, sr = librosa.load(path)
    filter_banks = librosa.filters.mel(n_fft=2048, sr=22050, n_mels=10)
    mel_spectrogram = librosa.feature.melspectrogram(scale, sr=sr, n_fft=2048, hop_length=512, n_mels=32)
    log_mel_spectrogram = librosa.power_to_db(mel_spectrogram)
    trch = torch.from_numpy(log_mel_spectrogram)
    if log_mel_spectrogram.shape !=(10,87):
      delta = 87 - log_mel_spectrogram.shape[1]
      trch = torch.nn.functional.pad(trch, (0,delta))

    return trch,self.labels[index]
  def __len__(self):
    return len(self.file_names)

Вот создается класс, который принимает в себя пути к аудиозаписям и конвертирует их в тензоры, и паддит нулями если вдруг тензоры не подходят по шейпам. Как я могу создать такой же класс для TensorFlow. Далее пример кода, который создает кортежи с путями к файлам и их классом и создает объект класса SoundDataset и соответсвенно генерирует датасет из этих файлов. Все это написано для PyTorch. Подскажите каким образом можно реализовать для TensorFlow.

path = '/content/drive/MyDrive/МДМА/audiodata/for-rerecorded/training/'
files = []
labels = []
lbl = '1 0'.split()
for lab in lbl:
  if lab == '0': 
    c = 'fake' 
  else: 
    c ='real'
  names = os.listdir(path+c)
  for n in names:
    pth = path+c+'/'+n
    files.append(pth)
    labels.append(int(lab))
train_dataset = SoundDataset(files, labels)
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size = 20)

Ответы (0 шт):