Найти минимальное косинусное расстояние между двумя матрицами

У меня есть два 2D np.arrays давайте назовем их A и B, оба имеют форму. Для каждого вектора в двумерном массиве A мне нужно найти вектор в матрице B с минимальным косинусным расстоянием. Для этого у меня просто есть двойной цикл for, внутри которого я пытаюсь найти минимальное значение. В общем, я делаю следующее:

from scipy.spatial.distance import cosine
l, res = A.shape[0], []
for i in xrange(l):
    minimum = min((cosine(A[i], B[j]), j) for j in xrange(l))
    res.append(minimum[1])

В приведенном выше коде один цикл скрыт за пониманием. Все работает нормально, но двойной цикл for делает его слишком медленным (я пытался переписать его с двойным пониманием, что сделало вещи немного быстрее, но все же медленно).

Я считаю, что есть функция numpy, которая может быстрее достичь следующего (используя некоторую линейную алгебру).

Так есть ли способ добиться того, чего я хочу быстрее?

4 голоса | спросил Salvador Dali 21 stEurope/Moscowp30Europe/Moscow09bEurope/MoscowMon, 21 Sep 2015 09:43:50 +0300 2015, 09:43:50

2 ответа


0

Из cosine docs у нас есть следующая информация -

scipy.spatial.distance.cosine (u, v) : вычисляет косинусное расстояние между одномерными массивами.

Косинусное расстояние между u и v, определяется как

 введите описание изображения здесь

где u⋅v - это скалярное произведение u и v.

Используя приведенную выше формулу, мы получили бы одно векторизованное решение, используя ` Возможность трансляции NumPy , вот так -

# Get the dot products, L2 norms and thus cosine distances
dots = np.dot(A,B.T)
l2norms = np.sqrt(((A**2).sum(1)[:,None])*((B**2).sum(1)))
cosine_dists = 1 - (dots/l2norms)

# Get min values (if needed) and corresponding indices along the rows for res.
# Take care of zero L2 norm values, by using nanmin and nanargmin  
minval = np.nanmin(cosine_dists,axis=1)
cosine_dists[np.isnan(cosine_dists).all(1),0] = 0
res = np.nanargmin(cosine_dists,axis=1)

Испытания во время выполнения -

In [81]: def org_app(A,B):
    ...:    l, res, minval = A.shape[0], [], []
    ...:    for i in xrange(l):
    ...:        minimum = min((cosine(A[i], B[j]), j) for j in xrange(l))
    ...:        res.append(minimum[1])
    ...:        minval.append(minimum[0])
    ...:    return res, minval
    ...: 
    ...: def vectorized(A,B):
    ...:     dots = np.dot(A,B.T)
    ...:     l2norms = np.sqrt(((A**2).sum(1)[:,None])*((B**2).sum(1)))
    ...:     cosine_dists = 1 - (dots/l2norms)
    ...:     minval = np.nanmin(cosine_dists,axis=1)
    ...:     cosine_dists[np.isnan(cosine_dists).all(1),0] = 0
    ...:     res = np.nanargmin(cosine_dists,axis=1)
    ...:     return res, minval
    ...: 

In [82]: A = np.random.rand(400,500)
    ...: B = np.random.rand(400,500)
    ...: 

In [83]: %timeit org_app(A,B)
1 loops, best of 3: 10.8 s per loop

In [84]: %timeit vectorized(A,B)
10 loops, best of 3: 145 ms per loop

Проверить результаты -

In [86]: x1, y1 = org_app(A, B)
    ...: x2, y2 = vectorized(A, B)
    ...: 

In [87]: np.allclose(np.asarray(x1),x2)
Out[87]: True

In [88]: np.allclose(np.asarray(y1)[~np.isnan(np.asarray(y1))],y2[~np.isnan(y2)])
Out[88]: True
ответил Divakar 21 stEurope/Moscowp30Europe/Moscow09bEurope/MoscowMon, 21 Sep 2015 10:22:54 +0300 2015, 10:22:54
0

Использование scipy.spatial.distance.cdist:

from scipy.spatial.distance import cdist

def cdist_func(A, B):
    dists = cdist(A, B, 'cosine')
    return np.argmin(dists, axis=1), np.min(dists, axis=1)

Это дает те же результаты, что и ответ Дивакара:

x2, y2 = vectorized(A, B)
x3, y3 = cdist_func(A, B)

np.allclose(x2, x3) # True
np.allclose(y2, y3) # True

Но это не так быстро:

%timeit vectorized(A, B) # 11.9 ms per loop
%timeit cdist_func(A, B) # 85.9 ms per loop
ответил user2034412 21 stEurope/Moscowp30Europe/Moscow09bEurope/MoscowMon, 21 Sep 2015 19:52:21 +0300 2015, 19:52:21

Похожие вопросы

Популярные теги

security × 330linux × 316macos × 2827 × 268performance × 244command-line × 241sql-server × 235joomla-3.x × 222java × 189c++ × 186windows × 180cisco × 168bash × 158c# × 142gmail × 139arduino-uno × 139javascript × 134ssh × 133seo × 132mysql × 132