PyTorch and C++

이승화 · May 13, 2022

According to the StyleGAN2 paper, PyTorch's built-in up/downsampling, bias addition, and LeakyReLU impose a significant overhead on training time and GPU memory footprint. The authors therefore implemented filtered up/downsampling and fused bias + activation as CUDA kernels, replacing the PyTorch implementation and saving both time and memory.

  1. CUDA
    • A Platform for Heterogeneous Computing
    • A Programming interface for utilizing NVIDIA GPU

Here, heterogeneous computing means using several kinds of computing resources together, such as a CPU with a GPU or an NVIDIA GPU alongside an AMD GPU. Efficient programming splits the work so that the host (CPU) and the device (GPU) each handle what they do best, and the CUDA compiler (NVCC) compiles CUDA code into separate host code and device code.
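As a minimal sketch of how that split looks in source code, NVCC routes each function to the host or device compiler based on its qualifier. The names square, square_all, run and the pointer d_x are purely illustrative, and d_x is assumed to point to device memory holding at least 128 floats.

// __device__ : device code, callable only from other device code
__device__ float square(float v) { return v * v; }

// __global__ : device code, launched from the host as a kernel
__global__ void square_all(float *x) { x[threadIdx.x] = square(x[threadIdx.x]); }

// __host__ (or no qualifier) : ordinary CPU code handled by the host compiler
__host__ void run(float *d_x) { square_all<<<1, 128>>>(d_x); }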

2. CUDA Kernel

  • CUDA C++ extends C++ by allowing the programmer to define C++ functions, called kernels, that, when called, are executed N times in parallel by N different CUDA threads, as opposed to only once like regular C++ functions.

As the NVIDIA docs describe, a kernel resembles an ordinary function but is inherently parallel: one kernel (one set of instructions) is executed by many threads at once.
To understand this better, it helps to look at how the GPU itself is organized.

A CPU relies on a few powerful cores to process complex work serially, whereas a GPU is specialized for running simple operations in parallel across many small cores.
In particular, the GPU follows the SIMT (Single Instruction, Multiple Threads) execution model.
Before explaining SIMT, let's first go over the GPU's thread hierarchy.

All the threads created by a single kernel launch on the GPU are collectively called a grid.
A grid consists of multiple thread blocks (TBs), and a thread block is a bundle of threads.

Threads can be indexed with threadIdx.x/y/z, blockIdx.x/y/z, blockDim, and gridDim, as in the sketch below.
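For example, a minimal element-wise kernel combines these built-ins into a unique global index per thread (the names vector_add, a, b, c and n are just illustrative):

// each thread handles one element: the same instruction applied to different data
__global__ void vector_add(const float *a, const float *b, float *c, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;  // global thread index
  if (i < n) {                                    // guard threads past the end
    c[i] = a[i] + b[i];
  }
}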

The threads in the same thread block run on the same Streaming Multiprocessor (SM).

All the blocks in the same grid contain the same number of threads.

We should note that threads, thread blocks and the grid are essentially a programmer's perspective. In order to get a complete gist of thread blocks, it is critical to know them from a hardware perspective.
The hardware groups threads that execute the same instruction into warps. Several warps constitute a thread block. Several thread blocks are assigned to a Streaming Multiprocessor (SM). Several SM constitute the whole GPU unit (which executes the whole Kernel Grid).

The physical structure of a GPU can be described in terms of SPs and SMs.
An SP (Streaming Processor) is what is usually called a CUDA core.

The Thread Block (TB) scheduler assigns TBs to Streaming Multiprocessors (SMs).
A warp is a group of 32 threads taken from one TB and is the smallest unit used to execute a single instruction. (Why 32? Because the set of CUDA cores that processes a warp handles 32 threads at a time.) The warp is the basic unit of GPU instruction execution: every thread in a warp executes the same instruction, but each thread may read different data values. In other words, each thread in a warp applies the same instruction to different data. This execution model is called SIMD (Single Instruction, Multiple Data) or SIMT (Single Instruction, Multiple Threads). When writing a program you only reason about a single thread; the hardware then groups many threads together and runs them in parallel.

Putting it all together: when a CUDA kernel is launched, the programmer specifies the number of blocks in the grid and the number of threads per block. Each block is assigned to an SM, the threads inside a block are grouped into warps, and those warps run on the CUDA cores inside the SM, as in the launch sketch below.
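A minimal host-side launch of the hypothetical vector_add kernel sketched above could look like this, assuming d_a, d_b and d_c already point to device memory; the block size of 256 is only a common choice, not a requirement:

void launch_vector_add(const float *d_a, const float *d_b, float *d_c, int n) {
  dim3 block(256);                         // threads per block, chosen by the programmer
  dim3 grid((n + block.x - 1) / block.x);  // enough blocks to cover all n elements
  vector_add<<<grid, block>>>(d_a, d_b, d_c, n);
}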

  3. Using a custom CUDA kernel in PyTorch

    1. Write a C++ file which defines the functions that will be called from Python, and binds those functions to Python with pybind11. Also declare functions that are defined in CUDA (.cu) files.
    2. In the CUDA files, we write our actual CUDA kernels.
    3. cpp_extension package will then take care of compiling the C++ sources with a C++ compiler like gcc and the CUDA sources with NVIDIA’s nvcc compiler
#include <ATen/ATen.h>
#include <torch/extension.h>

// declare the function that is defined in the .cu file
torch::Tensor upfirdn2d_op(const torch::Tensor &input,
                           const torch::Tensor &kernel, int up_x, int up_y,
                           int down_x, int down_y, int pad_x0, int pad_x1,
                           int pad_y0, int pad_y1);

#define CHECK_CUDA(x)                                                          \
  TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x)                                                    \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x)                                                         \
  CHECK_CUDA(x);                                                               \
  CHECK_CONTIGUOUS(x)

  
// write function that will be called from python
torch::Tensor upfirdn2d(const torch::Tensor &input, const torch::Tensor &kernel,
                        int up_x, int up_y, int down_x, int down_y, int pad_x0,
                        int pad_x1, int pad_y0, int pad_y1) {
  CHECK_INPUT(input);
  CHECK_INPUT(kernel);

  at::DeviceGuard guard(input.device());

  return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
                      pad_y0, pad_y1);
}

  
// bind function by PYBIND11
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
}
  
// TORCH_EXTENSION_NAME is a macro that the build system sets to the extension name
// (here it resolves to "upfirdn2d")
// With JIT compilation, cpp_extension.load() is all that is needed; no setup.py as in the AOT approach

pybind11 is a header-only library that exposes C++ types in Python and vice versa, mainly to create Python bindings of existing C++ code.
In short, it is the library that connects C++ and Python code.
Basic usage looks like this:

// cpp
PYBIND11_MODULE(module_name, helper) {
  helper.def("python_function_name", &cpp_function);
}

# python
import module_name
module_name.python_function_name()
# python: the wrapper module that loads and calls the compiled extension
from collections import abc
import os

import torch
from torch.nn import functional as F
from torch.autograd import Function
from torch.utils.cpp_extension import load 
  # load() compiles the extension and returns it as a Python module
  # load(name = the name of the extension to build;
  #             this MUST be the same as the name of the pybind11 module,
  #      sources = a list of relative or absolute paths to C++ source files)
  # CUDA support with mixed compilation is provided:
  # simply pass CUDA source files (.cu or .cuh) along with the other sources.


module_path = os.path.dirname(__file__)
upfirdn2d_op = load(
    "upfirdn2d",
    sources=[
        os.path.join(module_path, "upfirdn2d.cpp"),
        os.path.join(module_path, "upfirdn2d_kernel.cu"),
    ],
)


class UpFirDn2dBackward(Function):
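    # backward of UpFirDn2d, itself a Function so that double backward works; the
    # gradient is another upfirdn2d call with the flipped kernel, swapped up/down
    # factors and the adjusted gradient padding g_pad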
    @staticmethod
    def forward(
        ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
    ):

        up_x, up_y = up
        down_x, down_y = down
        g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad

        grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)

        grad_input = upfirdn2d_op.upfirdn2d(
            grad_output,
            grad_kernel,
            down_x,
            down_y,
            up_x,
            up_y,
            g_pad_x0,
            g_pad_x1,
            g_pad_y0,
            g_pad_y1,
        )
        grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])

        ctx.save_for_backward(kernel)

        pad_x0, pad_x1, pad_y0, pad_y1 = pad

        ctx.up_x = up_x
        ctx.up_y = up_y
        ctx.down_x = down_x
        ctx.down_y = down_y
        ctx.pad_x0 = pad_x0
        ctx.pad_x1 = pad_x1
        ctx.pad_y0 = pad_y0
        ctx.pad_y1 = pad_y1
        ctx.in_size = in_size
        ctx.out_size = out_size

        return grad_input

    @staticmethod
    def backward(ctx, gradgrad_input):
        kernel, = ctx.saved_tensors

        gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)

        gradgrad_out = upfirdn2d_op.upfirdn2d(
            gradgrad_input,
            kernel,
            ctx.up_x,
            ctx.up_y,
            ctx.down_x,
            ctx.down_y,
            ctx.pad_x0,
            ctx.pad_x1,
            ctx.pad_y0,
            ctx.pad_y1,
        )
        # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
        gradgrad_out = gradgrad_out.view(
            ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
        )

        return gradgrad_out, None, None, None, None, None, None, None, None


class UpFirDn2d(Function):
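    # autograd Function wrapping the CUDA op; the (N, C, H, W) input is reshaped to
    # (N*C, H, W, 1), the (major, in_h, in_w, minor) layout the kernel expects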
    @staticmethod
    def forward(ctx, input, kernel, up, down, pad):
        up_x, up_y = up
        down_x, down_y = down
        pad_x0, pad_x1, pad_y0, pad_y1 = pad

        kernel_h, kernel_w = kernel.shape
        batch, channel, in_h, in_w = input.shape
        ctx.in_size = input.shape

        input = input.reshape(-1, in_h, in_w, 1)

        ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))

        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
        ctx.out_size = (out_h, out_w)

        ctx.up = (up_x, up_y)
        ctx.down = (down_x, down_y)
        ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)

        g_pad_x0 = kernel_w - pad_x0 - 1
        g_pad_y0 = kernel_h - pad_y0 - 1
        g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
        g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1

        ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)

        out = upfirdn2d_op.upfirdn2d(
            input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
        )
        # out = out.view(major, out_h, out_w, minor)
        out = out.view(-1, channel, out_h, out_w)

        return out

    @staticmethod
    def backward(ctx, grad_output):
        kernel, grad_kernel = ctx.saved_tensors

        grad_input = None

        if ctx.needs_input_grad[0]:
            grad_input = UpFirDn2dBackward.apply(
                grad_output,
                kernel,
                grad_kernel,
                ctx.up,
                ctx.down,
                ctx.pad,
                ctx.g_pad,
                ctx.in_size,
                ctx.out_size,
            )

        return grad_input, None, None, None, None


def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
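    # normalize the up/down/pad arguments, then dispatch to the native PyTorch
    # implementation on CPU or to the custom CUDA op on GPU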
    if not isinstance(up, abc.Iterable):
        up = (up, up)

    if not isinstance(down, abc.Iterable):
        down = (down, down)

    if len(pad) == 2:
        pad = (pad[0], pad[1], pad[0], pad[1])

    if input.device.type == "cpu":
        out = upfirdn2d_native(input, kernel, *up, *down, *pad)

    else:
        out = UpFirDn2d.apply(input, kernel, up, down, pad)

    return out


def upfirdn2d_native(
    input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
):
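    # pure PyTorch reference implementation: upsample by zero-insertion, pad,
    # apply the FIR filter with conv2d, then downsample by strided slicing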
    _, channel, in_h, in_w = input.shape
    input = input.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = input.shape
    kernel_h, kernel_w = kernel.shape

    out = input.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    out = F.pad(
        out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
    )
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape(
        [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
    )
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x

    return out.view(-1, channel, out_h, out_w)  
  
// upfirdn2d_kernel.cu (the CUDA source file)
#include <torch/types.h>

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/CUDAContext.h>

#include <cuda.h>
#include <cuda_runtime.h>

static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
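  // floor division that rounds toward negative infinity, unlike C++'s
  // truncating integer division (indices below can be negative)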
  int c = a / b;

  if (c * b > a) {
    c--;
  }

  return c;
}

struct UpFirDn2DKernelParams {
  int up_x;
  int up_y;
  int down_x;
  int down_y;
  int pad_x0;
  int pad_x1;
  int pad_y0;
  int pad_y1;

  int major_dim;  // leading dimension of the reshaped input (batch * channel)
  int in_h;       // input height
  int in_w;       // input width
  int minor_dim;  // trailing dimension (1 after the reshape on the Python side)
  int kernel_h;
  int kernel_w;
  int out_h;
  int out_w;
  int loop_major; // how many major indices each thread block loops over
  int loop_x;     // how many output-column chunks each thread block loops over
};

template <typename scalar_t>
__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
                                       const scalar_t *kernel,
                                       const UpFirDn2DKernelParams p) {
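  // generic fallback kernel: handles arbitrary up/down factors and filter sizes;
  // each thread computes output elements straight from global memory, looping over
  // the filter taps without shared-memory tiling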
  int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int out_y = minor_idx / p.minor_dim;
  minor_idx -= out_y * p.minor_dim;
  int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
  int major_idx_base = blockIdx.z * p.loop_major;

  if (out_x_base >= p.out_w || out_y >= p.out_h ||
      major_idx_base >= p.major_dim) {
    return;
  }

  int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
  int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
  int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
  int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;

  for (int loop_major = 0, major_idx = major_idx_base;
       loop_major < p.loop_major && major_idx < p.major_dim;
       loop_major++, major_idx++) {
    for (int loop_x = 0, out_x = out_x_base;
         loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
      int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
      int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
      int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
      int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;

      const scalar_t *x_p =
          &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
                 minor_idx];
      const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
      int x_px = p.minor_dim;
      int k_px = -p.up_x;
      int x_py = p.in_w * p.minor_dim;
      int k_py = -p.up_y * p.kernel_w;

      scalar_t v = 0.0f;

      for (int y = 0; y < h; y++) {
        for (int x = 0; x < w; x++) {
          v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
          x_p += x_px;
          k_p += k_px;
        }

        x_p += x_py - w * x_px;
        k_p += k_py - w * k_px;
      }

      out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
          minor_idx] = v;
    }
  }
}

template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
          int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
                                 const scalar_t *kernel,
                                 const UpFirDn2DKernelParams p) {
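  // specialized tiled kernel: up/down factors, filter size and tile size are
  // compile-time template parameters; each block stages the flipped filter and an
  // input tile in shared memory before computing its output tile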
  const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
  const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;

  __shared__ volatile float sk[kernel_h][kernel_w];
  __shared__ volatile float sx[tile_in_h][tile_in_w];

  int minor_idx = blockIdx.x;
  int tile_out_y = minor_idx / p.minor_dim;
  minor_idx -= tile_out_y * p.minor_dim;
  tile_out_y *= tile_out_h;
  int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
  int major_idx_base = blockIdx.z * p.loop_major;

  if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
      major_idx_base >= p.major_dim) {
    return;
  }

  for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
       tap_idx += blockDim.x) {
    int ky = tap_idx / kernel_w;
    int kx = tap_idx - ky * kernel_w;
    scalar_t v = 0.0;

    if (kx < p.kernel_w & ky < p.kernel_h) {
      v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
    }

    sk[ky][kx] = v;
  }

  for (int loop_major = 0, major_idx = major_idx_base;
       loop_major < p.loop_major & major_idx < p.major_dim;
       loop_major++, major_idx++) {
    for (int loop_x = 0, tile_out_x = tile_out_x_base;
         loop_x < p.loop_x & tile_out_x < p.out_w;
         loop_x++, tile_out_x += tile_out_w) {
      int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
      int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
      int tile_in_x = floor_div(tile_mid_x, up_x);
      int tile_in_y = floor_div(tile_mid_y, up_y);

      __syncthreads();

      for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
           in_idx += blockDim.x) {
        int rel_in_y = in_idx / tile_in_w;
        int rel_in_x = in_idx - rel_in_y * tile_in_w;
        int in_x = rel_in_x + tile_in_x;
        int in_y = rel_in_y + tile_in_y;

        scalar_t v = 0.0;

        if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
          v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
                        p.minor_dim +
                    minor_idx];
        }

        sx[rel_in_y][rel_in_x] = v;
      }

      __syncthreads();
      for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
           out_idx += blockDim.x) {
        int rel_out_y = out_idx / tile_out_w;
        int rel_out_x = out_idx - rel_out_y * tile_out_w;
        int out_x = rel_out_x + tile_out_x;
        int out_y = rel_out_y + tile_out_y;

        int mid_x = tile_mid_x + rel_out_x * down_x;
        int mid_y = tile_mid_y + rel_out_y * down_y;
        int in_x = floor_div(mid_x, up_x);
        int in_y = floor_div(mid_y, up_y);
        int rel_in_x = in_x - tile_in_x;
        int rel_in_y = in_y - tile_in_y;
        int kernel_x = (in_x + 1) * up_x - mid_x - 1;
        int kernel_y = (in_y + 1) * up_y - mid_y - 1;

        scalar_t v = 0.0;

#pragma unroll
        for (int y = 0; y < kernel_h / up_y; y++)
#pragma unroll
          for (int x = 0; x < kernel_w / up_x; x++)
            v += sx[rel_in_y + y][rel_in_x + x] *
                 sk[kernel_y + y * up_y][kernel_x + x * up_x];

        if (out_x < p.out_w & out_y < p.out_h) {
          out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
              minor_idx] = v;
        }
      }
    }
  }
}

torch::Tensor upfirdn2d_op(const torch::Tensor &input,
                           const torch::Tensor &kernel, int up_x, int up_y,
                           int down_x, int down_y, int pad_x0, int pad_x1,
                           int pad_y0, int pad_y1) {
  int curDevice = -1;
  cudaGetDevice(&curDevice);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  UpFirDn2DKernelParams p;

  auto x = input.contiguous();
  auto k = kernel.contiguous();

  p.major_dim = x.size(0);
  p.in_h = x.size(1);
  p.in_w = x.size(2);
  p.minor_dim = x.size(3);
  p.kernel_h = k.size(0);
  p.kernel_w = k.size(1);
  p.up_x = up_x;
  p.up_y = up_y;
  p.down_x = down_x;
  p.down_y = down_y;
  p.pad_x0 = pad_x0;
  p.pad_x1 = pad_x1;
  p.pad_y0 = pad_y0;
  p.pad_y1 = pad_y1;

  p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
            p.down_y;
  p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
            p.down_x;

  auto out =
      at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());

  int mode = -1;
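  // mode picks one of the specialized tiled kernels below when the up/down factors
  // and filter size match a known configuration; -1 keeps the generic fallback kernel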

  int tile_out_h = -1;
  int tile_out_w = -1;

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 1;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 3 && p.kernel_w <= 3) {
    mode = 2;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 3;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 2 && p.kernel_w <= 2) {
    mode = 4;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 5;
    tile_out_h = 8;
    tile_out_w = 32;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
      p.kernel_h <= 2 && p.kernel_w <= 2) {
    mode = 6;
    tile_out_h = 8;
    tile_out_w = 32;
  }

  dim3 block_size;
  dim3 grid_size;
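  // grid: x covers the output rows (times minor_dim), y covers the output columns
  // in chunks, and z covers the major (batch * channel) dimension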

  if (tile_out_h > 0 && tile_out_w > 0) {
    p.loop_major = (p.major_dim - 1) / 16384 + 1;
    p.loop_x = 1;
    block_size = dim3(32 * 8, 1, 1);
    grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
                     (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
                     (p.major_dim - 1) / p.loop_major + 1);
  } else {
    p.loop_major = (p.major_dim - 1) / 16384 + 1;
    p.loop_x = 4;
    block_size = dim3(4, 32, 1);
    grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
                     (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
                     (p.major_dim - 1) / p.loop_major + 1);
  }

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
    switch (mode) {
    case 1:
      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 2:
      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 3:
      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 4:
      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 5:
      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 6:
      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    default:
      upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
          out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
          k.data_ptr<scalar_t>(), p);
    }
  });

  return out;
}
