[Tensor]#9 gather, scatter

Clay Ryu's sound lab · April 28, 2024

This post was created with the help of GPT-4.

Gather

The gather function collects values from a tensor along a specified dimension, based on indices provided in a separate tensor. It's useful for creating a new tensor by picking specific elements from an input tensor.

Here's an example to illustrate:

import torch

# Let's create a tensor of shape (3, 4)
t = torch.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8],
                  [9, 10, 11, 12]])

# We want to gather elements from each row
# The indices we will use for each row are specified in the following tensor
indices = torch.tensor([[0, 2], [1, 3], [2, 3]])

# Gather elements
result = torch.gather(t, 1, indices)

print(result)
tensor([[ 1,  3],
        [ 6,  8],
        [11, 12]])

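In this example, torch.gather(t, 1, indices) takes values from t according to indices. The 1 is the dim argument: each value in indices is a position along dimension 1, that is, a column position within its own row, so result[i][j] = t[i][indices[i][j]]. The result always has the same shape as indices, here (3, 2). For the first row, it picks the elements at positions 0 and 2, giving [1, 3]; the second row picks positions 1 and 3, giving [6, 8]; the third row picks positions 2 and 3, giving [11, 12].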
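To make the indexing rule concrete, here is a minimal sketch that compares torch.gather with an explicit double loop and then tries dim=0 on the same tensor. The helper name gather_dim1_loop and the row_indices tensor are made up for illustration and are not part of PyTorch or of the example above.

```python
import torch

t = torch.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8],
                  [9, 10, 11, 12]])
indices = torch.tensor([[0, 2], [1, 3], [2, 3]])

def gather_dim1_loop(src, index):
    """Hypothetical reference loop for torch.gather(src, 1, index) on 2-D tensors."""
    # The output has the shape of index and the dtype of src
    out = torch.empty_like(index, dtype=src.dtype)
    for i in range(index.size(0)):
        for j in range(index.size(1)):
            # dim=1: the index value replaces the column position
            out[i, j] = src[i, index[i, j]]
    return out

gathered = torch.gather(t, 1, indices)
looped = gather_dim1_loop(t, indices)
print(torch.equal(gathered, looped))  # True

# With dim=0 the index value replaces the row position instead:
# out[i][j] = t[row_indices[i][j]][j]
row_indices = torch.tensor([[2, 0, 1, 0]])
along_rows = torch.gather(t, 0, row_indices)
print(along_rows)  # tensor([[9, 2, 7, 4]])
```

The loop is only meant to show the rule; in practice torch.gather does the same work without Python-level iteration.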

Scatter

The scatter_ function (note the trailing underscore, which marks an in-place operation) writes values from a source tensor into a target tensor at the positions given by an index tensor. It's a way of filling a tensor with values at specific positions along a chosen dimension.

Here's an example:

import torch

# Create a tensor filled with zeros, which we'll fill using scatter_
t = torch.zeros(3, 5)

# Indices where we want to scatter values
indices = torch.tensor([[0, 1, 2],
                        [2, 0, 1]])

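# Values to scatter (floats, because scatter_ requires the source
# to have the same dtype as t, and torch.zeros creates float32)
values = torch.tensor([[10., 20., 30.],
                       [40., 50., 60.]])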

# Scatter values into t
result = t.scatter_(1, indices, values)

print(result)
tensor([[10., 20., 30.,  0.,  0.],
        [50., 60., 40.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.]])

In this example, scatter_ modifies t in place and also returns it, so result and t are the same tensor. The first argument 1 is the dim argument: each value in indices is a column position within its own row, so t[i][indices[i][j]] = values[i][j]. The first row of indices is [0, 1, 2], so the values [10, 20, 30] land in positions 0, 1, and 2 of the first row of t. The second row of indices is [2, 0, 1], so 40 goes to position 2, 50 to position 0, and 60 to position 1, giving [50, 60, 40, 0, 0]. The last row of t remains zeros because indices has only two rows, so nothing is written to it.
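The same rule can be checked against an explicit loop. The sketch below also uses the out-of-place torch.scatter, which returns a new tensor and leaves its input untouched, to contrast it with the in-place scatter_. The helper scatter_dim1_loop and the variable names are hypothetical, and values is written as floats on the assumption that the source must match the dtype of the zeros tensor.

```python
import torch

indices = torch.tensor([[0, 1, 2],
                        [2, 0, 1]])
# Floats, so the source dtype matches torch.zeros (float32)
values = torch.tensor([[10., 20., 30.],
                       [40., 50., 60.]])

def scatter_dim1_loop(dest, index, src):
    """Hypothetical reference loop for dest.scatter(1, index, src) on 2-D tensors."""
    out = dest.clone()  # work on a copy so dest is not modified
    for i in range(index.size(0)):
        for j in range(index.size(1)):
            # dim=1: the index value chooses the destination column
            out[i, index[i, j]] = src[i, j]
    return out

base = torch.zeros(3, 5)

# Out-of-place: returns a new tensor, base stays all zeros
out_of_place = torch.scatter(base, 1, indices, values)
looped = scatter_dim1_loop(base, indices, values)
base_untouched = bool((base == 0).all())
print(torch.equal(out_of_place, looped), base_untouched)  # True True

# In-place: base itself is modified and returned
in_place = base.scatter_(1, indices, values)
print(in_place is base)  # True
```

Both examples here use an index tensor with no repeated positions within a row, which keeps the result unambiguous.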
