tensorflow学习之tf.where

    技术2024-08-01  76

    第一种形式:tf.where(tensor),numpy也可以,返回其中为true的元素的索引

    y_true = np.array([0, 1, 0, 1, 0]) a = tf.where(y_true) with tf.Session() as sess: print(a.eval()) 输出: [[1] [3]]

    第二种形式:tf.where(tensor,a,b), 也可以为numpy, 将tensor中的true位置元素替换为a中对应位置元素,false的替换为b中对应位置元素。

    y_true = np.array([[0, 1, 0, 1, 0], [0, 1, 0, 1, 0]]) a = tf.where(y_true) with tf.Session() as sess: print(a.eval()) 输出: [[0 1] [0 3] [1 1] [1 3]]

    公众号:

    Processed: 0.011, SQL: 9