Помогите правильно написать tf.while_loop
Помогите правильно написать через tf.while_loop с возможностью распараллеливания через parallel_iterations. Уже всю голову сломал с этими lambda. В текущем виде все очень тормозит.
@tf.function
def data_augment(image, Ytr):
...
i = 0
while i < EMB_SIZE_2:
modified = tf.concat([tf.slice(modified, [0, 0, 0], [WINDOW, WINDOW, EMB_SIZE_2 + i]), tf.image.random_jpeg_quality(tf.slice(image, [0,0,EMB_SIZE_2 + i], [WINDOW,WINDOW,1]), 90, 100)], 2)
i += 1
return modified, Ytr
Ответы (1 шт):
Автор решения: RSR
→ Ссылка
i = tf.constant(0)
c = lambda modified, i: tf.less(i, EMB_SIZE_2)
b = lambda modified, i: (tf.concat([tf.slice(modified, [0, 0, 0], [WINDOW, WINDOW, EMB_SIZE_2 + i]), tf.image.random_jpeg_quality(tf.slice(image, [0,0,EMB_SIZE_2 + i], [WINDOW,WINDOW,1]), 90, 100)], 2), tf.add(i, 1))
modified, i = tf.while_loop(c, b, (modified,i), shape_invariants=(tf.TensorShape([WINDOW,WINDOW,None]), i.get_shape()), parallel_iterations=EMB_SIZE_2)