2016-06-24 3 views
1

Я написал код в Tensorflow, чтобы вычислить расстояние редактирования между одной строкой и набором строк. Я не могу понять ошибку.Computing Edit Distance (ошибка feed_dict)

import tensorflow as tf 
sess = tf.Session() 

# Create input data 
test_string = ['foo'] 
ref_strings = ['food', 'bar'] 

def create_sparse_vec(word_list): 
    num_words = len(word_list) 
    indices = [[xi, 0, yi] for xi,x in enumerate(word_list) for yi,y in enumerate(x)] 
    chars = list(''.join(word_list)) 
    return(tf.SparseTensor(indices, chars, [num_words,1,1])) 


test_string_sparse = create_sparse_vec(test_string*len(ref_strings)) 
ref_string_sparse = create_sparse_vec(ref_strings) 

sess.run(tf.edit_distance(test_string_sparse, ref_string_sparse, normalize=True)) 

Этот код работает и при запуске он выводит:

array([[ 0.25], 
     [ 1. ]], dtype=float32) 

Но когда я пытаюсь сделать это путем подачи разреженных тензоров через разреженные заполнители, я получаю сообщение об ошибке.

test_input = tf.sparse_placeholder(dtype=tf.string) 
ref_input = tf.sparse_placeholder(dtype=tf.string) 

edit_distances = tf.edit_distance(test_input, ref_input, normalize=True) 

feed_dict = {test_input: test_string_sparse, 
      ref_input: ref_string_sparse} 

sess.run(edit_distances, feed_dict=feed_dict) 

Вот отслеживающий ошибки:

Traceback (most recent call last): 

    File "<ipython-input-29-4e06de0b7af3>", line 1, in <module> 
    sess.run(edit_distances, feed_dict=feed_dict) 

    File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/client/session.py", line 372, in run 
run_metadata_ptr) 

    File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/client/session.py", line 597, in _run 
    for subfeed, subfeed_val in _feed_fn(feed, feed_val): 

    File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/client/session.py", line 558, in _feed_fn 
    return feed_fn(feed, feed_val) 

    File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/client/session.py", line 268, in <lambda> 
    [feed.indices, feed.values, feed.shape], feed_val)), 

TypeError: zip argument #2 must support iteration 

Любая идея, что здесь происходит?

+0

ошибка, вероятно, происходит от значения '' test_string_parse' или ref_string_parse', вы можете предоставить код для их создания –

ответ

2

TL; DR: Для типа возврата create_sparse_vec() используйте tf.SparseTensorValue вместо tf.SparseTensor.

Проблема здесь происходит от типа возвращаемого create_sparse_vec(), который tf.SparseTensor, и не понял, как значение подачи в вызове sess.run().

Когда вы кормите (плотный) tf.Tensor, ожидаемым типом значения является массив NumPy (или некоторые объекты, которые могут быть преобразованы в массив). Когда вы подаете tf.SparseTensor, ожидаемый тип значения - это tf.SparseTensorValue, который аналогичен tf.SparseTensor, но его indices, values и shape являются массивами NumPy (или некоторыми объектами, которые могут быть преобразованы в массивы, например списки в вашем примере).

следующий код должен работать:

def create_sparse_vec(word_list): 
    num_words = len(word_list) 
    indices = [[xi, 0, yi] for xi,x in enumerate(word_list) for yi,y in enumerate(x)] 
    chars = list(''.join(word_list)) 
    return tf.SparseTensorValue(indices, chars, [num_words,1,1]) 
+0

Спасибо, что работает идеально?!. – nfmcclure