edo1z blog

プログラミングなどに関するブログです

TensorFlow - tf.strided_slice関数を調べる

tf.strided_sliceを調べます。TensorFlowのGithubにのってる説明ページはこれです。

tf.strided_slice(input_, begin, end, strides=None, begin_mask=0, end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=0, var=None, name=None) {#strided_slice}

実験してみる

コード

import tensorflow as tf

tensor = tf.constant([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
result = tf.strided_slice(tensor, [0], [9], [2])

with tf.Session() as sess:
    result = sess.run(result)
    print(result)

結果

[1 3 5 7 9]

とりあえず、inputは元データ、beginは開始位置、endは終了位置、stridesは間隔のようです。以前はstridesを指定しない場合、デフォルトでstridesを1とみなしていたようですが、最近のtensorflowのアップデートで、stridesも明示しないとエラーになるようになったようです。ちなみに、終了位置は普通の配列のスライスと同じで、指定したインデックスのひとつ前までになります。

これだけなら簡単なのですが、他にも色々引数があるし、inputもbeginも多次元配列に対応しているようです。

コード

import tensorflow as tf

tensor = tf.constant([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]])
result = tf.strided_slice(tensor, [0, 5], [2, 8], [1, 1])

with tf.Session() as sess:
    result = sess.run(result)
    print(result)

結果

[[ 6  7  8]
 [16 17 18]]

begin、end、stridesはnumpyのshapeのような感じで入れていくようです。コードだとbeginは[0, 5]ですが、これは1次元目は0から始まり、2次元目は5から始まるということになるようです。