Welcome to Vigges Developer Community: Open, Learning, Share
Welcome To Ask or Share your Answers For Others


python - tensor_scatter_nd_update ValueError: Shapes must be equal rank, but are 0 and 1

I've always been able to use tf.tensor_scatter_nd_update to write into tensors without any problems, but I can't figure out why it's not working with some specific tensors.

As a simple example, say I want to set certain values in input=[[0 0 0]] to update=[[1 2 3]], based on a boolean mask mask=[[1 0 1]]. I would simply do:

input=tf.tensor_scatter_nd_update(input,tf.where(mask),update)

expecting the result of the operation to be input=[[1 0 3]].

Instead I'm getting

ValueError: Dimensions [2,2) of input[shape=[1,3]] = [] must match dimensions [1,2) of updates[shape=[1,3]] = [3]: Shapes must be equal rank, but are 0 and 1 for ... with input shapes: [1,3], [?,2], [1,3].

I really can't work out what's wrong; I've always been able to use the function without issue even in much more complex cases.



1 Answer


I figured it out.

Part of the problem is indeed that tf.where(), when called with only a condition, returns a 2-D tensor of indices. This mattered because I was also using it to generate the updates vector:

input=tf.tensor_scatter_nd_update(input,tf.where(mask),tf.where(something_else))
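
For context, the TensorFlow documentation gives the required shape of updates as indices.shape[:-1] + tensor.shape[indices.shape[-1]:]. Below is a minimal pure-Python sketch of that rule, applied to the shapes from the question. The helper name expected_updates_shape is hypothetical (it is not a TensorFlow function), and N = 2 assumes the question's mask=[[1 0 1]].

```python
def expected_updates_shape(tensor_shape, indices_shape):
    """Shape `updates` must have, following the documented scatter-nd rule.

    Hypothetical helper for illustration only; not part of TensorFlow.
    """
    index_depth = indices_shape[-1]  # number of leading axes each index row addresses
    return tuple(indices_shape[:-1]) + tuple(tensor_shape[index_depth:])


# Question's case: input has shape [1, 3]; tf.where(mask) has shape [N, 2].
# mask = [[1, 0, 1]] has two true entries, so N = 2.
print(expected_updates_shape((1, 3), (2, 2)))  # (2,)
```

So with full (row, column) indices, updates has to be a 1-D vector with one value per index row, which is why a [1,3] updates tensor is rejected with a rank mismatch.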

The solution is to remove the extra dimension with tf.squeeze:

input=tf.tensor_scatter_nd_update(input,tf.where(mask),tf.squeeze(tf.where(something_else)))
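
For the simplified example in the question, one way to get a correctly shaped updates vector is to pick the values out of update with the same mask. This is only a sketch under the question's stated values, not necessarily what the original code did; the names tensor, mask_bool, indices and values are chosen for illustration (tensor is used instead of input to avoid shadowing the Python builtin).

```python
import tensorflow as tf

tensor = tf.constant([[0, 0, 0]])
update = tf.constant([[1, 2, 3]])
mask = tf.constant([[1, 0, 1]])

mask_bool = tf.cast(mask, tf.bool)

# One (row, col) index pair per true entry: shape [2, 2].
indices = tf.where(mask_bool)
# The entries of `update` at those same positions, in the same order: shape [2].
values = tf.boolean_mask(update, mask_bool)

result = tf.tensor_scatter_nd_update(tensor, indices, values)
print(result.numpy())

# When mask, update and tensor all share one shape, an elementwise select
# produces the same result without any scatter.
alt = tf.where(mask_bool, update, tensor)
print(alt.numpy())
```

Here indices has shape [2,2] and values has shape [2], so updates has one entry per index row as the op requires.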
