[Tensor]#5 torch.einsum

Clay Ryu's sound lab · June 23, 2023

Official documentation: https://pytorch.org/docs/stable/generated/torch.einsum.html

Einsum

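torch.einsum computes products and sums of tensors (scalars, vectors, matrices, or higher-order tensors) from a single equation string written in Einstein summation notation. Each dimension of each operand is labeled with a letter, and any label that does not appear after '->' is summed over. Because the equation states explicitly which indices are multiplied and which are reduced, the calculation is easy to follow at a glance. Let's look at some examples.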
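As a minimal sketch of the labeling rule, the equation strings below act on a small 2×2 matrix M and a vector x (both are illustrative inputs, not from the original examples). A label kept after '->' is not reduced, a label missing after '->' is summed over, a label repeated within one operand selects the diagonal, and a label shared by two operands is multiplied elementwise before any summation. Each comment names the equivalent dedicated operation.

```python
import torch

# Illustrative inputs (not from the original examples)
M = torch.tensor([[1., 2.],
                  [3., 4.]])
x = torch.tensor([1., 10.])

# Labels kept after '->' are not reduced; reordering them permutes the axes
transposed = torch.einsum('ij->ji', M)      # same as M.T
# Labels missing after '->' are summed over
row_sums = torch.einsum('ij->i', M)         # same as M.sum(dim=1)
total = torch.einsum('ij->', M)             # same as M.sum()
# A label repeated within one operand takes the diagonal; dropping it sums the diagonal
trace = torch.einsum('ii->', M)             # same as torch.trace(M)
# A label shared by two operands and dropped: multiply, then sum (matrix-vector product)
mv = torch.einsum('ij,j->i', M, x)          # same as M @ x
# No shared labels: every pair of elements is multiplied (outer product)
outer = torch.einsum('i,j->ij', x, x)       # same as torch.outer(x, x)

print(transposed, row_sums, total, trace, mv, outer, sep='\n')
```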

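import torch

# Dot product: the shared label i is missing after '->', so it is summed and the result is a scalar
a = torch.ones(1)
b = torch.ones(1)
c = torch.einsum('i,i->', a, b)

# Matrix multiplication: (8, 4) x (4, 4) -> (8, 4); j is summed
A = torch.arange(32.).reshape(8, 4)
B = torch.arange(16.).reshape(4, 4)
C_einsum = torch.einsum('ij,jk->ik', A, B)

# Batched matrix multiplication: (32, 5, 5) x (32, 5, 2) -> (32, 5, 2); i is the batch label, k is summed
A = torch.arange(800.).reshape(32, 5, 5)/800.0
B = torch.arange(320.).reshape(32, 5, 2)/320.0
C_einsum = torch.einsum('ijk,ikl->ijl', A, B)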
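One way to confirm that the three equation strings above mean what they appear to is to compare each result with the dedicated PyTorch operation (torch.dot, the @ operator, torch.bmm). The sketch below also spells out 'ij,jk->ik' as explicit loops: j does not appear after '->', so it is the summed index. The names C_mm, C_bmm, C_loop, A3 and B3 are illustrative renames chosen so the 2-D and 3-D tensors do not overwrite each other.

```python
import torch

# Reproduce the three examples above under separate names
a = torch.ones(1)
b = torch.ones(1)
c = torch.einsum('i,i->', a, b)

A = torch.arange(32.).reshape(8, 4)
B = torch.arange(16.).reshape(4, 4)
C_mm = torch.einsum('ij,jk->ik', A, B)

A3 = torch.arange(800.).reshape(32, 5, 5) / 800.0
B3 = torch.arange(320.).reshape(32, 5, 2) / 320.0
C_bmm = torch.einsum('ijk,ikl->ijl', A3, B3)

# 'ij,jk->ik' written as explicit loops: j is summed because it is missing after '->'
C_loop = torch.zeros(8, 4)
for i in range(A.shape[0]):
    for k in range(B.shape[1]):
        for j in range(A.shape[1]):
            C_loop[i, k] += A[i, j] * B[j, k]

# Each einsum call should agree with the corresponding dedicated operation
print(torch.allclose(c, torch.dot(a, b)))        # dot product
print(torch.allclose(C_mm, A @ B))               # matrix multiplication
print(torch.allclose(C_mm, C_loop))              # explicit loops
print(torch.allclose(C_bmm, torch.bmm(A3, B3)))  # batched matrix multiplication
print(C_mm.shape, C_bmm.shape)
```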

Compare Time Consumption

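import time

import torch

# Note: each timing below comes from a single call and also includes
# the time spent formatting and printing the result.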

u = torch.randn(100000)
v = torch.randn(100000)

start_time = time.time()
print(f"Result of (u * v).sum() is {(u * v).sum()}")
print(f"Time taken_*: {time.time() - start_time}")
print("--------------------------------------------------------")
start_time = time.time()
print(f"Result of torch.mm(u.unsqueeze(0), v.unsqueeze(1)) is {torch.mm(u.unsqueeze(0), v.unsqueeze(1))}")
print(f"Time taken_mm: {time.time() - start_time}")
print("--------------------------------------------------------")
start_time = time.time()
print(f"Result of torch.matmul(u, v) is {torch.matmul(u, v)}")
print(f"Time taken_matmul: {time.time() - start_time}")
print("--------------------------------------------------------")
start_time = time.time()
print(f"Result of u @ v is {u @ v}")
print(f"Time taken_@: {time.time() - start_time}")
print("--------------------------------------------------------")
start_time = time.time()
print(f"Result of torch.einsum is {torch.einsum('i,i->', u, v)}")
print(f"Time taken_einsum: {time.time() - start_time}")

Output:
Result of (u * v).sum() is -195.31497192382812
Time taken_*: 0.0002617835998535156
--------------------------------------------------------
Result of torch.mm(u.unsqueeze(0), v.unsqueeze(1)) is tensor([[-195.3152]])
Time taken_mm: 0.0009050369262695312
--------------------------------------------------------
Result of torch.matmul(u, v) is -195.31488037109375
Time taken_matmul: 0.00019359588623046875
--------------------------------------------------------
Result of u @ v is -195.31488037109375
Time taken_@: 0.000171661376953125
--------------------------------------------------------
Result of torch.einsum is -195.3148956298828
Time taken_einsum: 0.00012087821960449219
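The five results agree to within about 1e-4 but not exactly: float32 addition is not associative, and the different code paths can accumulate the 100000 products in different orders. Note also that torch.mm returns a 1×1 matrix rather than a 0-d tensor. The timings each come from a single call that also includes formatting and printing, so they are too noisy to rank the methods. Below is a minimal sketch of a more careful comparison, assuming PyTorch's torch.utils.benchmark.Timer is available: it computes each value once, compares it with a float64 reference, and reports the median of many timed runs. The names fns, reference, values and medians are illustrative. Timer runs single-threaded by default (num_threads=1), so its absolute numbers are not directly comparable with the ones above, and results will vary by machine.

```python
import torch
import torch.utils.benchmark as benchmark

torch.manual_seed(0)  # fixed seed so reruns use the same inputs
u = torch.randn(100000)
v = torch.randn(100000)

# The five dot-product variants from the comparison above, wrapped as callables
fns = {
    '*': lambda: (u * v).sum(),
    'mm': lambda: torch.mm(u.unsqueeze(0), v.unsqueeze(1)),
    'matmul': lambda: torch.matmul(u, v),
    '@': lambda: u @ v,
    'einsum': lambda: torch.einsum('i,i->', u, v),
}

# Higher-precision reference: the same dot product computed in float64
reference = torch.dot(u.double(), v.double()).item()

values, medians = {}, {}
for name, fn in fns.items():
    # Compute each result once, outside of any timing
    values[name] = fn()
    # blocked_autorange repeats fn() until at least min_run_time seconds are measured;
    # the lambda call adds the same small overhead to every variant
    timer = benchmark.Timer(stmt='fn()', globals={'fn': fn})
    medians[name] = timer.blocked_autorange(min_run_time=0.2).median
    # torch.mm returns shape (1, 1); reshape(()) turns every result into a 0-d tensor
    scalar = values[name].reshape(()).item()
    print(f"{name:>7}: shape {tuple(values[name].shape)}, value {scalar:.6f}, "
          f"|diff vs float64| {abs(scalar - reference):.2e}, "
          f"median {medians[name] * 1e6:.1f} us")
```

Comparing against a float64 reference with a tolerance, rather than comparing the float32 results with each other for exact equality, separates genuine disagreements from ordinary rounding differences.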