# https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit_pytorch.py
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
# Minimum number of patches. Not referenced in this chunk; presumably used
# elsewhere in the file to reject images split into too few patches
# (NOTE(review): confirm against the model constructor).
MIN_NUM_PATCHES: int = 16
# Load libraries (imports above).
class channel_selection(nn.Module):
    """Learnable per-channel mask applied by element-wise multiplication.

    ``indexes`` starts as an all-ones vector of length ``num_channels``.
    During pruning, the entries of ``indexes`` that correspond to channels
    to be pruned are set to 0, which zeroes those channels in the output.
    """

    def __init__(self, num_channels: int) -> None:
        """Initialize ``indexes`` to an all-ones vector of length ``num_channels``.

        Parameters
        ----------
        num_channels:
            Length of the mask; it is broadcast against the last dimension
            of the tensors passed to :meth:`forward`.
        """
        super().__init__()
        # Assigned as an nn.Parameter, so it is registered with the module
        # (appears in parameters() and state_dict()). Starting at 1 makes the
        # module an identity until entries are set to 0 during pruning.
        self.indexes = nn.Parameter(torch.ones(num_channels))

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Scale each channel of ``input_tensor`` by its mask entry.

        Parameters
        ----------
        input_tensor:
            Tensor of shape (B, num_patches + 1, dim). ``indexes`` broadcasts
            against the last dimension, so ``dim`` must be broadcast-compatible
            with ``num_channels`` (normally equal to it).

        Returns
        -------
        torch.Tensor
            ``input_tensor * indexes``; same shape as the input when
            ``dim == num_channels``.
        """
        return input_tensor.mul(self.indexes)

    def extra_repr(self) -> str:
        # Shown by print(module) / repr(module) to make the mask size visible.
        return f"num_channels={self.indexes.numel()}"