Что я делаю неправильно?sklearn BallTree дает неожиданные результаты
Я пытаюсь использовать BallTree от sklearn, чтобы создать похожие коллекции, а затем создать некоторые предложения по элементам, которые могут отсутствовать в данной коллекции.
import random
from sklearn.neighbors import BallTree
import numpy
collections = [] # 10k sample collections of between
# 7 and 15 (of a possible 300...) items
for sample in range(0, 10000): # build sample data
items = random.sample(range(1, 300), random.randint(7, 15))
collections.append(items)
darray = numpy.zeros((len(collections), max(map(len, collections)))) # 10k x 15 matrix
for c_cnt, items in enumerate(collections): # populate matrix
for cnt, i in enumerate(sorted(items)):
darray[C_cnt][cnt] = i
query = BallTree(darray).query(darray[0], k=15)
nearest_neighbors = query[1][0]
# test the results against the first item!
all_sets = [set(darray[0]) & set(darray[item]) for item in nearest_neighbors]
for item in all_sets:
print item # intersection of the neighbor
Я получаю следующие результаты:
set([0.0, 130.0, 167.0, 290.0, 162.0, 144.0, 17.0, 214.0]) # Nearest neighbor is itself! Awesome!
set([0.0]) # WTF? The second closest item shares only 1 item?
set([0.0, 290.0])
set([0.0, 17.0])
set([0.0, 130.0])
set([0.0])
set([0.0])
set([0.0])
set([0.0])
set([0.0])
set([0.0])
set([0.0])
set([0.0, 162.0])
set([0.0, 144.0, 162.0]) # uhh okay, i would expect this to be higher up
set([0.0, 144.0, 17.0])
Я наблюдаю, что выше предложенные элементы, как правило, имеют одинаковую длину ненулевых значений в качестве массива я пытающийся сравнить. Есть ли какая-то подготовка, которую я могу сделать с моими данными, чтобы исправить это?