Помогите правильно написать tf.while_loop

Помогите правильно написать через tf.while_loop с возможностью распараллеливания через parallel_iterations. Уже всю голову сломал с этими lambda. В текущем виде все очень тормозит.

@tf.function
def data_augment(image, Ytr): 
    ...
    i = 0    
    while i < EMB_SIZE_2:      
      modified = tf.concat([tf.slice(modified, [0, 0, 0], [WINDOW, WINDOW, EMB_SIZE_2 + i]), tf.image.random_jpeg_quality(tf.slice(image, [0,0,EMB_SIZE_2 + i], [WINDOW,WINDOW,1]), 90, 100)], 2)           
      i += 1         
    return modified, Ytr 

Ответы (1 шт):

Автор решения: RSR
i = tf.constant(0)
c = lambda modified, i: tf.less(i, EMB_SIZE_2)
b = lambda modified, i: (tf.concat([tf.slice(modified, [0, 0, 0], [WINDOW, WINDOW, EMB_SIZE_2 + i]), tf.image.random_jpeg_quality(tf.slice(image, [0,0,EMB_SIZE_2 + i], [WINDOW,WINDOW,1]), 90, 100)], 2), tf.add(i, 1)) 
modified, i = tf.while_loop(c, b, (modified,i), shape_invariants=(tf.TensorShape([WINDOW,WINDOW,None]), i.get_shape()), parallel_iterations=EMB_SIZE_2) 
→ Ссылка