import numpy as np
import torch

# 1. tensor: creation with an explicit dtype, and basic tensor attributes
a = torch.tensor([[1, 2], [3, 4]], dtype=torch.int16)  # 2x2 int16 tensor
b = torch.tensor([2], dtype=torch.float32)  # 1-element float32 tensor
c = torch.tensor([3], dtype=torch.float64)  # 1-element float64 tensor
print(f"{a}\n{b}\n{c}\n")
print(f"shape of tensor {a.shape}")
print(f"data type of tensor {a.dtype}")
print(f"device tensor is stored on {a.device}")

# 2. element-wise add / subtract, and reducing a tensor with .sum()
a = torch.tensor([3, 2])
b = torch.tensor([5, 3])
print(f"input {a}, {b}\n")
# Named `total`, not `sum`: a module-level `sum` would shadow the builtin.
total = a + b  # element-wise: [3 + 5, 2 + 3]
print(f"sum: {total}\n")
sub = a - b  # element-wise: [3 - 5, 2 - 3]
print(f"sub: {sub}\n")
sum_elements_a = a.sum()  # reduces all elements of `a` to a single 0-d tensor
print(f"sum_elements: {sum_elements_a}")

# 3. multiply: matrix product vs element-wise product
a = torch.arange(0, 9).view(3, 3)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
b = torch.arange(0, 9).view(3, 3)
print(f"input tensor\n{a},\n{b}\n")
c = torch.matmul(a, b)  # matrix product, same as a @ b
print(f"mat mul\n{c}\n")
d = torch.mul(a, b)  # element-wise product, same as a * b
print(f"elementwise mul\n{d}")

# 4. view (reshape without copying) and 2-D transpose
a = torch.tensor([2, 4, 5, 6, 7, 8])
print(f"input tensor\n{a}\n")
# Same 6 elements laid out as 2 rows x 3 columns; shares storage with `a`.
b = a.view(2, 3)
print(f"view\n{b}\n")
# .t() only accepts tensors of at most 2 dims; use transpose/permute for higher ranks.
b_t = b.t()
print(f"transpose\n{b_t}")

# 5. slicing / indexing a 2-D tensor
a = torch.arange(1, 13).view(4, 3)  # rows [1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]
print(f"input tensor\n{a}\n")
print("slicing")
first_col = a[:, 0]  # every row, column 0 -> 1-D tensor of length 4
first_row = a[0, :]  # row 0, every column -> 1-D tensor of length 3
elem_1_1 = a[1, 1]  # a single element -> 0-d tensor
print(first_col)
print(first_row)
print(elem_1_1)

# 6. numpy <-> tensor conversion
a = torch.arange(1, 13).view(4, 3)
print(f"input\n{a}\n")
# Tensor.numpy() does not copy: the ndarray shares memory with `a` (CPU tensors only),
# so an in-place change to either one is visible through the other.
a_np = a.numpy()
print(f"tensor to numpy\n{a_np}\n")
b = np.array([1, 2, 3])
# torch.from_numpy likewise shares memory with `b` instead of copying it.
b_torch = torch.from_numpy(b)
print(f"numpy to tensor\n{b_torch}")

# 7. concatenate: join tensors along an EXISTING dimension
a = torch.arange(1, 10).view(3, 3)
b = torch.arange(1, 10).view(3, 3)
c = torch.arange(1, 10).view(3, 3)
print(f"input\n{a}\n{b}\n{c}\n")
# torch.cat is the canonical name; torch.concat is only an alias that older
# PyTorch releases do not provide.
abc_0 = torch.cat([a, b, c], dim=0)  # inputs stacked row-wise: shape (9, 3)
print(f"concat (dim=0)\n{abc_0}\n{abc_0.shape}\n")
abc_1 = torch.cat([a, b, c], dim=1)  # inputs side by side column-wise: shape (3, 9)
print(f"concat (dim=1)\n{abc_1}\n{abc_1.shape}")

# 8. stack: join same-shaped tensors along a NEW dimension
a = torch.arange(1, 10).view(3, 3)
b = torch.arange(1, 10).view(3, 3)
c = torch.arange(1, 10).view(3, 3)
print(f"input\n{a}\n{b}\n{c}\n")
abc_0 = torch.stack([a, b, c], dim=0)  # shape (3, 3, 3); abc_0[i] is the i-th input
print(f"stack (dim=0)\n{abc_0}\n{abc_0.shape}\n")
abc_1 = torch.stack([a, b, c], dim=1)  # shape (3, 3, 3); abc_1[:, i] is the i-th input
print(f"stack (dim=1)\n{abc_1}\n{abc_1.shape}")

# 9. transpose (swap exactly two dims) vs permute (reorder every dim)
a = torch.arange(1, 10).view(3, 3)
a_t = torch.transpose(a, 0, 1)  # for a 2-D tensor this equals a.t()
print(f"input\n{a}")
print(f"transpose\n{a_t}\n")
b = torch.arange(1, 25).view(4, 3, 2)
b_t = torch.transpose(b, 0, 2)  # swap dims 0 and 2: shape (4, 3, 2) -> (2, 3, 4)
print(f"input\n{b}, {b.shape}")
print(f"transpose\n{b_t}\n{b_t.shape}\n")
b_permute = b.permute(2, 0, 1)  # new dim order = old dims (2, 0, 1): (4, 3, 2) -> (2, 4, 3)
print(f"permute\n{b_permute}, {b_permute.shape}")
