Вы можете получить окончательный трехмерный результат E
без создания большой промежуточной матрицы с использованием batched_dot
:
import theano.tensor as tt
A = tt.tensor3('A') # A.shape = (D, N, H)
B = tt.tensor3('B') # B.shape = (D, H, K)
E = tt.batched_dot(A, B) # E.shape = (D, N, K)
К сожалению, для этого вам необходимо переставить размеры на входные и выходные массивы. Хотя это может быть сделано с dimshuffle
в Теано, кажется batched_dot
не может справиться с произвольно strided массивами и поэтому следующее поднимает ValueError: Some matrix has no unit stride
когда E
оценивается:
import theano.tensor as tt
A = tt.tensor3('A') # A.shape = (N, H, D)
B = tt.tensor3('B') # B.shape = (K, H, D)
A_perm = A.dimshuffle((2, 0, 1)) # A_perm.shape = (D, N, H)
B_perm = B.dimshuffle((2, 1, 0)) # B_perm.shape = (D, H, K)
E_perm = tt.batched_dot(A_perm, B_perm) # E_perm.shape = (D, N, K)
E = E_perm.dimshuffle((1, 2, 0)) # E.shape = (N, K, D)
batched_dot
использует scan
вдоль первой (размер D
) измерение , Поскольку scan
выполняется последовательно, это может быть вычислительно менее эффективным, чем вычисление всех продуктов параллельно при работе на графическом процессоре.
Вы можете обменять между эффективностью использования памяти в соответствии с методом приближения и параллелизма в широковещательном подходе с использованием scan
. Идея в том, чтобы вычислить полный продукт C
для партий размером M
параллельно (предполагая, что M
является точным фактором D
), итерация партиями с scan
:
import theano as th
import theano.tensor as tt
A = tt.tensor3('A') # A.shape = (N, H, D)
B = tt.tensor3('B') # B.shape = (K, H, D)
A_batched = A.reshape((N, H, M, D/M))
B_batched = B.reshape((K, H, M, D/M))
E_batched, _ = th.scan(
lambda a, b: (a[:, :, None, :] * b[:, :, :, None]).sum(1),
sequences=[A_batched.T, B_batched.T]
)
E = E_batched.reshape((D, K, N)).T # E.shape = (N, K, D)
Какой размер вы хотите просуммировать? Первый, 0? или «H», который от 2-го до последнего в исходных массивах? – hpaulj
В 'numpy' это может быть выражено как' np.einsum ('nhd, khd-> nkd', A, B) ' – hpaulj
Я хотел бы сделать это над H. Это должно быть суммой (1), считая, что тензор имеет форму (1, H, D) до суммирования. – Theo