tilelang

Texnee · September 18, 2025

1. Lowering

1.1. Front-End DSL (Python side)

Example Python code: gemm.py

import tilelang
import tilelang.language as T

@tilelang.jit(out_idx=[-1], verbose=True)
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def gemm(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)

            T.copy(C_local, C[by * block_M, bx * block_N])

    return gemm
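To make the semantics of the DSL concrete, here is a minimal plain-Python sketch of what `gemm` computes, following the same tiling: one block per (bx, by) output tile, a K loop over block_K-wide slices, and a float accumulator that is written back at the end. It is not TileLang code; `tiled_matmul_reference` is a hypothetical name, and the sketch ignores shared memory, pipelining, and float16 rounding.

```python
from typing import List

Matrix = List[List[float]]


def tiled_matmul_reference(A: Matrix, B: Matrix, block_M: int, block_N: int, block_K: int) -> Matrix:
    """Compute C = A @ B with the same tile/loop structure as the `gemm` kernel.

    Edge tiles are handled with bounds checks (a choice made for this sketch).
    """
    M, K, N = len(A), len(A[0]), len(B[0])
    C = [[0.0] * N for _ in range(M)]
    # Grid: T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M)) as (bx, by)
    for by in range((M + block_M - 1) // block_M):
        for bx in range((N + block_N - 1) // block_N):
            rows = range(by * block_M, min((by + 1) * block_M, M))
            cols = range(bx * block_N, min((bx + 1) * block_N, N))
            # C_local: per-block float accumulator, zeroed like T.clear(C_local)
            C_local = [[0.0] * len(cols) for _ in rows]
            # K loop over block_K slices (T.Pipelined overlaps these on the GPU)
            for k in range((K + block_K - 1) // block_K):
                ks = range(k * block_K, min((k + 1) * block_K, K))
                # T.gemm(A_shared, B_shared, C_local) on one pair of tiles
                for i, r in enumerate(rows):
                    for j, c in enumerate(cols):
                        C_local[i][j] += sum(A[r][kk] * B[kk][c] for kk in ks)
            # T.copy(C_local, C[by * block_M, bx * block_N]): write the tile back
            for i, r in enumerate(rows):
                for j, c in enumerate(cols):
                    C[r][c] = C_local[i][j]
    return C
```

Comparing it against a naive triple loop on small, unevenly sized matrices is a quick way to convince yourself that the tiling does not change the result.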
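  • tilelang/language/kernel.py: KernelLaunchFrame exposes CUDA's grid/block/thread concepts and also provides a CPU fallback and thread-extent helpers (get_thread_*). It is the actual execution context when you write a kernel in the Python DSL.

  • tilelang/ir.py registers functions through the FFI (Foreign Function Interface), so Python-side stubs such as T.* and tl.Gemm are internally bound to C++ IR nodes (src/ir.cc).
    stub: a "shell function" that stands in for the real implementation, which is bound to it later.

  • The submodules under language/ provide high-level syntax (T.Parallel, T.Pipelined, etc.) and schedule sugar, and ultimately emit a PrimFunc compatible with TVM Script.
    schedule sugar: syntactic conveniences that let you write a schedule intuitively with simple keywords or constructs.
    PrimFunc: one of TVM's IRs (Intermediate Representations); a function representation that defines low-level (primitive) operations.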
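The stub-to-native binding can be pictured as a global name registry: the native side registers an implementation under a string name, and the Python stub only looks it up and forwards its arguments. The sketch below is a toy, pure-Python model of that pattern. `register_global`, `get_global`, `_REGISTRY`, and the name `"tl.demo.make_gemm_node"` are all hypothetical and are not TileLang's or TVM's actual API (TVM's real counterparts are functions along the lines of `tvm.register_func` / `tvm.get_global_func`).

```python
from typing import Any, Callable, Dict

# Hypothetical global registry standing in for the C++ side of the FFI.
_REGISTRY: Dict[str, Callable[..., Any]] = {}


def register_global(name: str):
    """Decorator: register a 'native' implementation under a global name."""
    def wrap(fn: Callable[..., Any]) -> Callable[..., Any]:
        if name in _REGISTRY:
            raise KeyError(f"{name} already registered")
        _REGISTRY[name] = fn
        return fn
    return wrap


def get_global(name: str) -> Callable[..., Any]:
    """Look up a registered implementation; fail loudly if it is missing."""
    try:
        return _REGISTRY[name]
    except KeyError:
        raise LookupError(f"global function {name!r} is not registered") from None


# 'Native' side: builds an IR-node-like record (a dict here, a C++ object in reality).
@register_global("tl.demo.make_gemm_node")
def _make_gemm_node(a: str, b: str, c: str) -> dict:
    return {"op": "tl.Gemm", "A": a, "B": b, "C": c}


# Python-side stub: no logic of its own, it only forwards to the registered function.
def gemm_stub(a: str, b: str, c: str) -> dict:
    return get_global("tl.demo.make_gemm_node")(a, b, c)
```

The point of the indirection is that the Python layer never needs to know how the node is built; swapping the registered implementation changes behavior without touching the stub.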

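1.2. Pass Pipeline (intermediate optimization stage)

engine/lower.py:190 → the compilation entry point, which runs two phases:

  • Phase 1: LowerAndLegalize
    Legalizes the IR, sets up fragment/shared layouts, lowers DSL (Domain-Specific Language) ops, and performs safety checks.
    Example: transform/frontend_legalize.cc.

  • Phase 2: OptimizeForTarget
    Target-specific optimizations (warp specialization, async pipeline planning, vectorization, memory coalescing, etc.).
    These are controlled by flags in the pass context.

Calls go from Python into the C++ passes through the FFI; the core passes are implemented in src/transform/*.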
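Conceptually, each phase is an ordered list of module-to-module transformations, and flags in the pass context decide which optional passes run. The toy model below illustrates only that structure: the "module" is a list of strings rather than an IRModule, and the pass names, their order, and the `enable_warp_specialization` flag are illustrative stand-ins, not the actual contents of `engine/lower.py` or TileLang's pass-config keys.

```python
from typing import Callable, Dict, List

Module = List[str]               # stand-in for an IRModule: a log of applied passes
Pass = Callable[[Module], Module]


def make_pass(name: str) -> Pass:
    """Create a toy pass that records its name on the module (no real rewriting)."""
    def run(mod: Module) -> Module:
        return mod + [name]      # passes are functional: they return a new module
    return run


def lower_and_legalize(mod: Module) -> Module:
    # Phase 1: always-on legalization and layout passes (illustrative names).
    for p in (make_pass("FrontendLegalize"), make_pass("LayoutInference"), make_pass("LowerTileOp")):
        mod = p(mod)
    return mod


def optimize_for_target(mod: Module, ctx: Dict[str, bool]) -> Module:
    # Phase 2: target-specific passes, some gated by pass-context flags.
    passes: List[Pass] = []
    if ctx.get("enable_warp_specialization", False):
        passes.append(make_pass("WarpSpecialized"))
    passes += [make_pass("PipelinePlanning"), make_pass("InjectSoftwarePipeline"),
               make_pass("MergeSharedMemoryAllocations")]
    for p in passes:
        mod = p(mod)
    return mod


def lower(mod: Module, ctx: Dict[str, bool]) -> Module:
    return optimize_for_target(lower_and_legalize(mod), ctx)
```

Keeping passes as pure functions of a module is what makes a pipeline like this easy to reorder or gate per target.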

=== Host IRModule (after lowering) ===
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def gemm(args: T.handle, arg_type_ids: T.handle("int32"), num_args: T.int32, out_ret_value: T.handle("void"), out_ret_tcode: T.handle("int32"), resource_handle: T.handle) -> T.int32:
        A_desc = T.handle("uint8x128", "grid_constant")
        A = T.handle("float16", "global")
        B_desc = T.handle("uint8x128", "grid_constant")
        B = T.handle("float16", "global")
        T.func_attr({"calling_conv": 1, "target": T.target({"keys": ["cpu"], "kind": "c", "tag": ""}), "thread_extent": {}, "tir.is_entry_func": True, "tma_descriptor_args": {"A_desc": ["__tvm_tensormap_create_tiled", A_desc, 6, 2, A, 1024, 1024, 2, 2048, 32, 128, 1, 1, 0, 2, 2, 0], "B_desc": ["__tvm_tensormap_create_tiled", B_desc, 6, 2, B, 1024, 1024, 2, 2048, 64, 32, 1, 1, 0, 3, 2, 0]}})
        assert num_args == 3, "gemm: num_args should be 3"
        assert not T.isnullptr(args), "gemm: TVMValue* arg pointer was NULL"
        assert not T.isnullptr(arg_type_ids), "gemm: int* type_codes was NULL"
        arg_type_ids_1 = T.decl_buffer((3,), "int32", data=arg_type_ids)
        A_handle_code: T.int32 = arg_type_ids_1[0]
        assert A_handle_code == 0 or A_handle_code == 4 or A_handle_code == 7 or A_handle_code >= 64, "gemm: Expect arg[0] to be pointer"
        B_handle_code: T.int32 = arg_type_ids_1[1]
        assert B_handle_code == 0 or B_handle_code == 4 or B_handle_code == 7 or B_handle_code >= 64, "gemm: Expect arg[1] to be pointer"
        C_handle_code: T.int32 = arg_type_ids_1[2]
        assert C_handle_code == 0 or C_handle_code == 4 or C_handle_code == 7 or C_handle_code >= 64, "gemm: Expect arg[2] to be pointer"
        A_handle: T.handle = T.tvm_struct_get(args, 0, 12, "handle")
        B_handle: T.handle = T.tvm_struct_get(args, 1, 12, "handle")
        C_handle: T.handle = T.tvm_struct_get(args, 2, 12, "handle")
        assert not T.isnullptr(A_handle), "gemm.A_handle is expected to have non-NULL DLTensor* pointer"
        assert 2 == T.tvm_struct_get(A_handle, 0, 4, "int32"), "gemm.A_handle.ndim is expected to equal 2"
        gemm_A_handle_shape: T.handle("int64") = T.tvm_struct_get(A_handle, 0, 2, "handle")
        gemm_A_handle_shape_1 = T.decl_buffer((2,), "int64", data=gemm_A_handle_shape)
        gemm_A_handle_strides: T.handle("int64") = T.tvm_struct_get(A_handle, 0, 3, "handle")
        gemm_A_handle_strides_1 = T.decl_buffer((2,), "int64", data=gemm_A_handle_strides)
        dev_id: T.int32 = T.tvm_struct_get(A_handle, 0, 9, "int32")
        with T.LetStmt(T.tvm_struct_get(A_handle, 0, 1, "handle"), var=A):
            T.attr(A, "storage_alignment", 64)
            assert not T.isnullptr(B_handle), "gemm.B_handle is expected to have non-NULL DLTensor* pointer"
            assert 2 == T.tvm_struct_get(B_handle, 0, 4, "int32"), "gemm.B_handle.ndim is expected to equal 2"
            gemm_B_handle_shape: T.handle("int64") = T.tvm_struct_get(B_handle, 0, 2, "handle")
            gemm_B_handle_shape_1 = T.decl_buffer((2,), "int64", data=gemm_B_handle_shape)
            gemm_B_handle_strides: T.handle("int64") = T.tvm_struct_get(B_handle, 0, 3, "handle")
            gemm_B_handle_strides_1 = T.decl_buffer((2,), "int64", data=gemm_B_handle_strides)
            with T.LetStmt(T.tvm_struct_get(B_handle, 0, 1, "handle"), var=B):
                T.attr(B, "storage_alignment", 64)
                assert not T.isnullptr(C_handle), "gemm.C_handle is expected to have non-NULL DLTensor* pointer"
                assert 2 == T.tvm_struct_get(C_handle, 0, 4, "int32"), "gemm.C_handle.ndim is expected to equal 2"
                gemm_C_handle_shape: T.handle("int64") = T.tvm_struct_get(C_handle, 0, 2, "handle")
                gemm_C_handle_shape_1 = T.decl_buffer((2,), "int64", data=gemm_C_handle_shape)
                gemm_C_handle_strides: T.handle("int64") = T.tvm_struct_get(C_handle, 0, 3, "handle")
                gemm_C_handle_strides_1 = T.decl_buffer((2,), "int64", data=gemm_C_handle_strides)
                C: T.handle("float16", "global") = T.tvm_struct_get(C_handle, 0, 1, "handle")
                T.attr(C, "storage_alignment", 64)
                T.attr("default", "device_id", dev_id)
                T.attr("default", "device_type", 2)
                assert T.tvm_struct_get(A_handle, 0, 5, "uint8") == T.uint8(2) and T.tvm_struct_get(A_handle, 0, 6, "uint8") == T.uint8(16) and T.tvm_struct_get(A_handle, 0, 7, "uint16") == T.uint16(1), "gemm.A_handle.dtype is expected to be float16"
                assert T.Cast("int32", gemm_A_handle_shape_1[0]) == 1024, "Argument gemm.A_handle.shape[0] has an unsatisfied constraint: 1024 == T.Cast(\"int32\", gemm_A_handle_shape[0])"
                assert T.Cast("int32", gemm_A_handle_shape_1[1]) == 1024, "Argument gemm.A_handle.shape[1] has an unsatisfied constraint: 1024 == T.Cast(\"int32\", gemm_A_handle_shape[1])"
                assert T.if_then_else(T.isnullptr(gemm_A_handle_strides), 1, T.Cast("int32", gemm_A_handle_strides_1[1])) == 1, "Argument gemm.A_handle.strides[1] has an unsatisfied constraint: 1 == T.if_then_else(T.isnullptr(gemm_A_handle_strides), 1, T.Cast(\"int32\", gemm_A_handle_strides_1[1]))"
                assert T.if_then_else(T.isnullptr(gemm_A_handle_strides), T.Cast("int32", gemm_A_handle_shape_1[1]), T.Cast("int32", gemm_A_handle_strides_1[0])) == 1024, "Argument gemm.A_handle.strides[0] has an unsatisfied constraint: 1024 == T.if_then_else(T.isnullptr(gemm_A_handle_strides), T.Cast(\"int32\", gemm_A_handle_shape[1]), T.Cast(\"int32\", gemm_A_handle_strides_1[0]))"
                assert T.uint64(0) == T.tvm_struct_get(A_handle, 0, 8, "uint64"), "Argument gemm.A_handle.byte_offset has an unsatisfied constraint: T.uint64(0) == T.tvm_struct_get(A_handle, 0, 8, \"uint64\")"
                assert T.tvm_struct_get(A_handle, 0, 10, "int32") == 2, "Argument gemm.A_handle.device_type has an unsatisfied constraint: 2 == T.tvm_struct_get(A_handle, 0, 10, \"int32\")"
                assert not T.isnullptr(A), "gemm.A_handle is expected to have non-NULL data pointer"
                assert T.tvm_struct_get(B_handle, 0, 5, "uint8") == T.uint8(2) and T.tvm_struct_get(B_handle, 0, 6, "uint8") == T.uint8(16) and T.tvm_struct_get(B_handle, 0, 7, "uint16") == T.uint16(1), "gemm.B_handle.dtype is expected to be float16"
                assert T.Cast("int32", gemm_B_handle_shape_1[0]) == 1024, "Argument gemm.B_handle.shape[0] has an unsatisfied constraint: 1024 == T.Cast(\"int32\", gemm_B_handle_shape[0])"
                assert T.Cast("int32", gemm_B_handle_shape_1[1]) == 1024, "Argument gemm.B_handle.shape[1] has an unsatisfied constraint: 1024 == T.Cast(\"int32\", gemm_B_handle_shape[1])"
                assert T.if_then_else(T.isnullptr(gemm_B_handle_strides), 1, T.Cast("int32", gemm_B_handle_strides_1[1])) == 1, "Argument gemm.B_handle.strides[1] has an unsatisfied constraint: 1 == T.if_then_else(T.isnullptr(gemm_B_handle_strides), 1, T.Cast(\"int32\", gemm_B_handle_strides_1[1]))"
                assert T.if_then_else(T.isnullptr(gemm_B_handle_strides), T.Cast("int32", gemm_B_handle_shape_1[1]), T.Cast("int32", gemm_B_handle_strides_1[0])) == 1024, "Argument gemm.B_handle.strides[0] has an unsatisfied constraint: 1024 == T.if_then_else(T.isnullptr(gemm_B_handle_strides), T.Cast(\"int32\", gemm_B_handle_shape[1]), T.Cast(\"int32\", gemm_B_handle_strides_1[0]))"
                assert T.uint64(0) == T.tvm_struct_get(B_handle, 0, 8, "uint64"), "Argument gemm.B_handle.byte_offset has an unsatisfied constraint: T.uint64(0) == T.tvm_struct_get(B_handle, 0, 8, \"uint64\")"
                assert T.tvm_struct_get(B_handle, 0, 10, "int32") == 2, "Argument gemm.B_handle.device_type has an unsatisfied constraint: 2 == T.tvm_struct_get(B_handle, 0, 10, \"int32\")"
                assert dev_id == T.tvm_struct_get(B_handle, 0, 9, "int32"), "Argument gemm.B_handle.device_id has an unsatisfied constraint: dev_id == T.tvm_struct_get(B_handle, 0, 9, \"int32\")"
                assert not T.isnullptr(B), "gemm.B_handle is expected to have non-NULL data pointer"
                assert T.tvm_struct_get(C_handle, 0, 5, "uint8") == T.uint8(2) and T.tvm_struct_get(C_handle, 0, 6, "uint8") == T.uint8(16) and T.tvm_struct_get(C_handle, 0, 7, "uint16") == T.uint16(1), "gemm.C_handle.dtype is expected to be float16"
                assert T.Cast("int32", gemm_C_handle_shape_1[0]) == 1024, "Argument gemm.C_handle.shape[0] has an unsatisfied constraint: 1024 == T.Cast(\"int32\", gemm_C_handle_shape[0])"
                assert T.Cast("int32", gemm_C_handle_shape_1[1]) == 1024, "Argument gemm.C_handle.shape[1] has an unsatisfied constraint: 1024 == T.Cast(\"int32\", gemm_C_handle_shape[1])"
                assert T.if_then_else(T.isnullptr(gemm_C_handle_strides), 1, T.Cast("int32", gemm_C_handle_strides_1[1])) == 1, "Argument gemm.C_handle.strides[1] has an unsatisfied constraint: 1 == T.if_then_else(T.isnullptr(gemm_C_handle_strides), 1, T.Cast(\"int32\", gemm_C_handle_strides_1[1]))"
                assert T.if_then_else(T.isnullptr(gemm_C_handle_strides), T.Cast("int32", gemm_C_handle_shape_1[1]), T.Cast("int32", gemm_C_handle_strides_1[0])) == 1024, "Argument gemm.C_handle.strides[0] has an unsatisfied constraint: 1024 == T.if_then_else(T.isnullptr(gemm_C_handle_strides), T.Cast(\"int32\", gemm_C_handle_shape[1]), T.Cast(\"int32\", gemm_C_handle_strides_1[0]))"
                assert T.uint64(0) == T.tvm_struct_get(C_handle, 0, 8, "uint64"), "Argument gemm.C_handle.byte_offset has an unsatisfied constraint: T.uint64(0) == T.tvm_struct_get(C_handle, 0, 8, \"uint64\")"
                assert T.tvm_struct_get(C_handle, 0, 10, "int32") == 2, "Argument gemm.C_handle.device_type has an unsatisfied constraint: 2 == T.tvm_struct_get(C_handle, 0, 10, \"int32\")"
                assert dev_id == T.tvm_struct_get(C_handle, 0, 9, "int32"), "Argument gemm.C_handle.device_id has an unsatisfied constraint: dev_id == T.tvm_struct_get(C_handle, 0, 9, \"int32\")"
                assert not T.isnullptr(C), "gemm.C_handle is expected to have non-NULL data pointer"
                A_1 = T.decl_buffer((1024, 1024), "float16", data=A, strides=(1024, 1))
                B_1 = T.decl_buffer((1024, 1024), "float16", data=B, strides=(1024, 1))
                C_1 = T.decl_buffer((1024, 1024), "float16", data=C, strides=(1024, 1))
                assert T.FloorMod(1024, 8) == 0, "A: Vectorize dimension in buffer must be divisible by 8"
                assert T.FloorMod(1024, 8) == 0, "B: Vectorize dimension in buffer must be divisible by 8"
                assert T.FloorMod(1024, 8) == 0, "C: Vectorize dimension in buffer must be divisible by 8"
                T.call_packed("__tvm_set_device", 2, dev_id)
                with T.attr(0, "compute_scope", "gemm_compute_"):
                    with T.LetStmt(T.tvm_stack_alloca("arg_value", 16), var=A_desc):
                        T.call_packed("__tvm_tensormap_create_tiled", A_desc, 6, 2, A, 1024, 1024, 2, 2048, 32, 128, 1, 1, 0, 2, 2, 0)
                        with T.LetStmt(T.tvm_stack_alloca("arg_value", 16), var=B_desc):
                            T.call_packed("__tvm_tensormap_create_tiled", B_desc, 6, 2, B, 1024, 1024, 2, 2048, 64, 32, 1, 1, 0, 3, 2, 0)
                            T.call_packed("gemm_kernel", A_desc, B_desc, C, 8, 8, 256, 1, 1, 49152)
                return 0
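Most of the host function above is argument validation on the incoming DLTensor structs. The dtype check compares (code, bits, lanes) against (2, 16, 1), i.e. DLPack's float type code with 16 bits and one lane, and the device-type check compares against 2, DLPack's CUDA device type. The sketch below restates those checks in plain Python; `FakeDLTensor`, `_require`, and `check_arg` are hypothetical stand-ins for the raw struct reads done with `T.tvm_struct_get`.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

KDL_FLOAT = 2  # DLPack DLDataTypeCode kDLFloat
KDL_CUDA = 2   # DLPack DLDeviceType kDLCUDA


@dataclass
class FakeDLTensor:
    """Hypothetical stand-in for the DLTensor fields read via T.tvm_struct_get."""
    data: Optional[int]                 # None models a NULL data pointer
    ndim: int
    dtype: Tuple[int, int, int]         # (code, bits, lanes)
    shape: Tuple[int, ...]
    strides: Optional[Tuple[int, ...]]  # None models NULL strides (compact row-major)
    byte_offset: int = 0
    device_type: int = KDL_CUDA
    device_id: int = 0


def _require(cond: bool, msg: str) -> None:
    if not cond:
        raise ValueError(msg)


def check_arg(name: str, t: FakeDLTensor, shape: Tuple[int, int] = (1024, 1024),
              dev_id: Optional[int] = None) -> None:
    """Mirror the host-side assert chain for one 2-D, row-major float16 CUDA tensor."""
    _require(t.ndim == 2, f"{name}.ndim is expected to equal 2")
    _require(t.dtype == (KDL_FLOAT, 16, 1), f"{name}.dtype is expected to be float16")
    _require(tuple(t.shape) == shape, f"{name}.shape is expected to be {shape}")
    # NULL strides means compact: strides[1] == 1 and strides[0] == shape[1]
    strides = t.strides if t.strides is not None else (t.shape[1], 1)
    _require(tuple(strides) == (shape[1], 1), f"{name}.strides must be row-major contiguous")
    _require(t.byte_offset == 0, f"{name}.byte_offset must be 0")
    _require(t.device_type == KDL_CUDA, f"{name}.device_type must be CUDA (2)")
    if dev_id is not None:  # B and C must live on the same device as A
        _require(t.device_id == dev_id, f"{name}.device_id must equal {dev_id}")
    _require(t.data is not None, f"{name} is expected to have a non-NULL data pointer")
```

Because the shapes were specialized to 1024 at compile time, these checks are what stops a caller from passing a tensor the kernel's hard-coded indexing cannot handle.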
  • Inside OptimizeForTarget, GPU-specific passes such as WarpSpecialized, PipelinePlanning, and MergeSharedMemoryAllocations are run.

The following code is the device module after these GPU-specific passes have run.

=== Device IRModule (after target-specific optimizations) ===
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def gemm_kernel(A_desc: T.handle("uint8x128", "grid_constant"), B_desc: T.handle("uint8x128", "grid_constant"), C: T.handle("float16", "global")):
        T.func_attr({"calling_conv": 2, "dyn_shared_memory_buf": 49152, "target": T.target({"arch": "sm_90", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}), "thread_extent": {"blockIdx.x": 8, "blockIdx.y": 8, "threadIdx.x": 256, "threadIdx.y": 1, "threadIdx.z": 1}, "tir.is_global_func": T.bool(True), "tir.kernel_launch_params": ["blockIdx.x", "blockIdx.y", "threadIdx.x", "threadIdx.y", "threadIdx.z", "tir.use_dyn_shared_memory"], "tir.noalias": True})
        C_1 = T.decl_buffer((1048576,), "float16", data=C)
        C_local = T.handle("float32", "local")
        C_local_1 = T.decl_buffer((128,), data=C_local, scope="local")
        bx = T.launch_thread("blockIdx.x", 8)
        buf_dyn_shmem = T.allocate([49152], "uint8", "shared.dyn")
        C_local = T.allocate([128], "float32", "local")
        by = T.launch_thread("blockIdx.y", 8)
        tx = T.launch_thread("threadIdx.x", 256)
        T.create_barriers(6)
        if T.tl_shuffle_elect(0):
            T.call_extern("handle", "tl::prefetch_tma_descriptor", A_desc)
            T.call_extern("handle", "tl::prefetch_tma_descriptor", B_desc)
            T.ptx_init_barrier_thread_count(T.get_mbarrier(0), 128)
            T.ptx_init_barrier_thread_count(T.get_mbarrier(1), 128)
            T.ptx_init_barrier_thread_count(T.get_mbarrier(2), 128)
            T.ptx_init_barrier_thread_count(T.get_mbarrier(3), 128)
            T.ptx_init_barrier_thread_count(T.get_mbarrier(4), 128)
            T.ptx_init_barrier_thread_count(T.get_mbarrier(5), 128)
        T.tvm_storage_sync("shared")
        ty = T.launch_thread("threadIdx.y", 1)
        tz = T.launch_thread("threadIdx.z", 1)
        T.attr([128, 128], "kWarpSpecializationScope", 0)
        if 128 <= tx:
            T.set_max_nreg(24, 0)
            for k in range(32):
                T.mbarrier_wait_parity(T.get_mbarrier(k % 3 + 3), T.bitwise_xor(k % 6 // 3, 1))
                if T.tl_shuffle_elect(128):
                    T.mbarrier_expect_tx(T.get_mbarrier(k % 3), 8192)
                    T.tma_load(A_desc, T.get_mbarrier(k % 3), T.tvm_access_ptr(T.type_annotation("float16"), buf_dyn_shmem, k % 3 * 4096, 4096, 2), k * 32, by * 128, 0)
                    T.mbarrier_expect_tx(T.get_mbarrier(k % 3), 8192)
                    T.tma_load(B_desc, T.get_mbarrier(k % 3), T.tvm_access_ptr(T.type_annotation("float16"), buf_dyn_shmem, 12288 + k % 3 * 4096, 2048, 2), bx * 128, k * 32, 0)
                    T.tma_load(B_desc, T.get_mbarrier(k % 3), T.tvm_access_ptr(T.type_annotation("float16"), buf_dyn_shmem, 12288 + (k % 3 * 4096 + 2048), 2048, 2), bx * 128 + 64, k * 32, 0)
                T.ptx_arrive_barrier(T.get_mbarrier(k % 3))
        else:
            T.set_max_nreg(240, 1)
            for i in T.unroll(64):
                C_local_1[i * 2:i * 2 + 2] = T.Broadcast(T.float32(0.0), 2)
            T.fence_proxy_async()
            for k in range(32):
                T.mbarrier_wait_parity(T.get_mbarrier(k % 3), k % 6 // 3)
                T.tl_gemm("tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>", T.tvm_access_ptr(T.type_annotation("float16"), buf_dyn_shmem, k % 3 * 4096, 4096, 1), T.tvm_access_ptr(T.type_annotation("float16"), buf_dyn_shmem, 12288 + k % 3 * 4096, 4096, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local, 0, 16384, 3))
                T.ptx_arrive_barrier(T.get_mbarrier(k % 3 + 3))
            for i in T.unroll(64):
                C_1[by * 131072 + i // 32 * 65536 + tx // 32 * 16384 + i % 2 * 8192 + tx % 32 // 4 * 1024 + bx * 128 + i % 32 // 2 * 8 + tx % 4 * 2:by * 131072 + i // 32 * 65536 + tx // 32 * 16384 + i % 2 * 8192 + tx % 32 // 4 * 1024 + bx * 128 + i % 32 // 2 * 8 + tx % 4 * 2 + 2] = T.Cast("float16x2", C_local_1[i * 2:i * 2 + 2])
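The final store loop writes the float accumulator back through a flattened index into the row-major 1024x1024 `C`. Splitting its constants (131072 = 128·1024, 65536 = 64·1024, 16384 = 16·1024, 8192 = 8·1024) recovers a (row, col) pair for each thread `tx` and loop index `i`. The sketch below performs that decomposition and checks that the 128 consumer threads (tx < 128), each writing 64 `float16x2` pairs, cover the 128x128 output tile exactly once. The row/col split is a reading of the expression in the dump, and `store_coords` / `tile_coverage` are hypothetical helpers.

```python
def store_coords(bx: int, by: int, tx: int, i: int, N: int = 1024):
    """Decode the flattened C index of the store loop into (row, col).

    Returns the coordinates of the first of the two float16 lanes written.
    """
    flat = (by * 131072 + i // 32 * 65536 + tx // 32 * 16384 + i % 2 * 8192
            + tx % 32 // 4 * 1024 + bx * 128 + i % 32 // 2 * 8 + tx % 4 * 2)
    row = by * 128 + (i // 32) * 64 + (tx // 32) * 16 + (i % 2) * 8 + (tx % 32) // 4
    col = bx * 128 + (i % 32 // 2) * 8 + (tx % 4) * 2
    assert flat == row * N + col  # the split reproduces the original expression
    return row, col


def tile_coverage(bx: int = 0, by: int = 0):
    """Collect every (row, col) written by the consumer warpgroup for one block."""
    written = []
    for tx in range(128):      # consumer threads only (the else branch, tx < 128)
        for i in range(64):    # T.unroll(64) store loop
            row, col = store_coords(bx, by, tx, i)
            written += [(row, col), (row, col + 1)]  # float16x2: two adjacent columns
    return written
```

The coverage is exact because the 7 bits of `tx` and the 6 bits of `i` are each split across row and column terms with no overlap, so every (tx, i, lane) maps to a distinct element.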

1.3. Code Generation & Runtime

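engine/lower.py:232 → after optimization, chooses between generating code right away and emitting source only.

  • CUDA/HIP: invokes nvcc/hipcc.

  • CPU: the TVM LLVM backend.

  • The C++ code generators (src/target/codegen_cuda.cc, etc.) perform TileLang-specific lowering (e.g., FP8/FP6 handling, launch-config extraction).

  • The runtime layer (src/runtime/runtime.cc) supports Hopper features (CUDA TensorMap, etc.).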
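The extracted launch configuration (grid 8x8, 256 threads, 49152 bytes of dynamic shared memory, as in the host call `T.call_packed("gemm_kernel", ..., 8, 8, 256, 1, 1, 49152)`) can be reproduced by hand. The dump is consistent with M = N = K = 1024, block_M = block_N = 128, block_K = 32, and num_stages = 3; these block sizes are inferred from the TMA box sizes and the `gemm_ss<128, 128, 32, ...>` template arguments, not stated in the original. Note that the kernel launches 256 threads although the DSL asked for `threads=128`: the `if 128 <= tx` branch in the device IR is an extra 128-thread producer warpgroup added by warp specialization. `plan_launch` and `LaunchPlan` are hypothetical names.

```python
from dataclasses import dataclass
from typing import Tuple


def ceildiv(a: int, b: int) -> int:
    return -(-a // b)


@dataclass(frozen=True)
class LaunchPlan:
    grid: Tuple[int, int]  # (blockIdx.x extent, blockIdx.y extent)
    threads: int           # threadIdx.x extent after warp specialization
    k_iters: int           # trip count of the pipelined K loop
    stage_bytes: int       # one A tile + one B tile in shared memory
    smem_bytes: int        # dynamic shared memory for all stages


def plan_launch(M: int, N: int, K: int, block_M: int, block_N: int, block_K: int,
                num_stages: int, threads: int, elem_bytes: int = 2,
                producer_threads: int = 128) -> LaunchPlan:
    """Reproduce the launch numbers seen in the dump.

    Assumes one extra producer warpgroup of `producer_threads` threads and that only
    the A/B tiles live in dynamic shared memory (C_local stays in registers).
    """
    a_tile = block_M * block_K * elem_bytes  # A_shared: (block_M, block_K) float16
    b_tile = block_K * block_N * elem_bytes  # B_shared: (block_K, block_N) float16
    return LaunchPlan(
        grid=(ceildiv(N, block_N), ceildiv(M, block_M)),
        threads=threads + producer_threads,
        k_iters=ceildiv(K, block_K),
        stage_bytes=a_tile + b_tile,
        smem_bytes=(a_tile + b_tile) * num_stages,
    )
```

With the inferred parameters this gives grid (8, 8), 256 threads, 32 K iterations, and 3 × 16384 = 49152 bytes. Each A or B tile is 8192 bytes, which is also the byte count each `expect_transaction(8192)` announces to its mbarrier.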

The code produced by this process:


=== Generated Kernel Source ===
#include <tl_templates/cuda/gemm.h>
#include <tl_templates/cuda/copy.h>
#include <tl_templates/cuda/reduce.h>
#include <tl_templates/cuda/ldsm.h>
#include <tl_templates/cuda/threadblock_swizzle.h>
#include <tl_templates/cuda/debug.h>
#ifdef ENABLE_BF16
#include <tl_templates/cuda/cuda_bf16_fallbacks.cuh>
#endif

extern "C" __global__ void gemm_kernel(__grid_constant__ const CUtensorMap A_desc, __grid_constant__ const CUtensorMap B_desc, half_t* __restrict__ C);
extern "C" __global__ void __launch_bounds__(256, 1) gemm_kernel(__grid_constant__ const CUtensorMap A_desc, __grid_constant__ const CUtensorMap B_desc, half_t* __restrict__ C) {
  extern __shared__ __align__(1024) uchar buf_dyn_shmem[];
  float C_local[128];
  __shared__ uint64_t mbarrier_mem[6];
  auto mbarrier = reinterpret_cast<Barrier*>(mbarrier_mem);
  if (tl::tl_shuffle_elect<0>()) {
    tl::prefetch_tma_descriptor(A_desc);
    tl::prefetch_tma_descriptor(B_desc);
    mbarrier[0].init(128);
    mbarrier[1].init(128);
    mbarrier[2].init(128);
    mbarrier[3].init(128);
    mbarrier[4].init(128);
    mbarrier[5].init(128);
  }
  __syncthreads();
  if (128 <= ((int)threadIdx.x)) {
    tl::warpgroup_reg_dealloc<24>();
    for (int k = 0; k < 32; ++k) {
      mbarrier[((k % 3) + 3)].wait((((k % 6) / 3) ^ 1));
      if (tl::tl_shuffle_elect<128>()) {
        mbarrier[(k % 3)].expect_transaction(8192);
        tl::tma_load(A_desc, mbarrier[(k % 3)], (&(((half_t*)buf_dyn_shmem)[((k % 3) * 4096)])), (k * 32), (((int)blockIdx.y) * 128));
        mbarrier[(k % 3)].expect_transaction(8192);
        tl::tma_load(B_desc, mbarrier[(k % 3)], (&(((half_t*)buf_dyn_shmem)[(((k % 3) * 4096) + 12288)])), (((int)blockIdx.x) * 128), (k * 32));
        tl::tma_load(B_desc, mbarrier[(k % 3)], (&(((half_t*)buf_dyn_shmem)[(((k % 3) * 4096) + 14336)])), ((((int)blockIdx.x) * 128) + 64), (k * 32));
      }
      mbarrier[(k % 3)].arrive();
    }
  } else {
    tl::warpgroup_reg_alloc<240>();
    #pragma unroll
    for (int i = 0; i < 64; ++i) {
      *(float2*)(C_local + (i * 2)) = make_float2(0x0p+0f/*0.000000e+00*/, 0x0p+0f/*0.000000e+00*/);
    }
    tl::fence_proxy_async();
    for (int k_1 = 0; k_1 < 32; ++k_1) {
      mbarrier[(k_1 % 3)].wait(((k_1 % 6) / 3));
      tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>((&(((half_t*)buf_dyn_shmem)[((k_1 % 3) * 4096)])), (&(((half_t*)buf_dyn_shmem)[(((k_1 % 3) * 4096) + 12288)])), (&(C_local[0])));
      mbarrier[((k_1 % 3) + 3)].arrive();
    }
    #pragma unroll
    for (int i_1 = 0; i_1 < 64; ++i_1) {
      uint1 __1;
      float2 v_ = *(float2*)(C_local + (i_1 * 2));
      ((half2*)(&(__1.x)))->x = (half_t)(v_.x);
      ((half2*)(&(__1.x)))->y = (half_t)(v_.y);
      *(uint1*)(C + ((((((((((int)blockIdx.y) * 131072) + ((i_1 >> 5) * 65536)) + ((((int)threadIdx.x) >> 5) * 16384)) + ((i_1 & 1) * 8192)) + (((((int)threadIdx.x) & 31) >> 2) * 1024)) + (((int)blockIdx.x) * 128)) + (((i_1 & 31) >> 1) * 8)) + ((((int)threadIdx.x) & 3) * 2))) = __1;
    }
  }
}


#define ERROR_BUF_SIZE 1024
static char error_buf[ERROR_BUF_SIZE];

extern "C" const char* get_last_error() {
    return error_buf;
}

extern "C" int init() {
    error_buf[0] = '\0';
    
    cudaError_t result_gemm_kernel = cudaFuncSetAttribute(gemm_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 49152);
    if (result_gemm_kernel != CUDA_SUCCESS) {
        snprintf(error_buf, ERROR_BUF_SIZE, "Failed to set the allowed dynamic shared memory size to %d with error: %s", 49152, cudaGetErrorString(result_gemm_kernel));
        return -1;
    }

    return 0;
}

extern "C" int call(half_t* __restrict__ A, half_t* __restrict__ B, half_t* __restrict__ C, cudaStream_t stream=cudaStreamDefault) {

        CUtensorMap A_desc;
        CUtensorMapDataType A_desc_type= (CUtensorMapDataType)6;
        cuuint32_t A_desc_tensorRank= 2;
        void *A_desc_globalAddress= A;
        cuuint64_t A_desc_globalDim[2]= {1024,1024};
        cuuint64_t A_desc_globalStride[2]= {2,2048};
        cuuint32_t A_desc_boxDim[2]= {32,128};
        cuuint32_t A_desc_elementStrides[2]= {1,1};
        CUtensorMapInterleave A_desc_interleave= (CUtensorMapInterleave)0;
        CUtensorMapSwizzle A_desc_swizzle= (CUtensorMapSwizzle)2;
        CUtensorMapL2promotion A_desc_l2Promotion= (CUtensorMapL2promotion)2;
        CUtensorMapFloatOOBfill A_desc_oobFill= (CUtensorMapFloatOOBfill)0;

        CUresult A_desc_result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)(
    &A_desc, A_desc_type, A_desc_tensorRank, A_desc_globalAddress, A_desc_globalDim, A_desc_globalStride + 1, A_desc_boxDim, A_desc_elementStrides, A_desc_interleave, A_desc_swizzle, A_desc_l2Promotion, A_desc_oobFill);

        if (A_desc_result != CUDA_SUCCESS) {
                std::stringstream ss;
                ss << "Error: Failed to initialize the TMA descriptor A_desc";
                snprintf(error_buf, ERROR_BUF_SIZE, "%s", ss.str().c_str());
                return -1;
        }

        CUtensorMap B_desc;
        CUtensorMapDataType B_desc_type= (CUtensorMapDataType)6;
        cuuint32_t B_desc_tensorRank= 2;
        void *B_desc_globalAddress= B;
        cuuint64_t B_desc_globalDim[2]= {1024,1024};
        cuuint64_t B_desc_globalStride[2]= {2,2048};
        cuuint32_t B_desc_boxDim[2]= {64,32};
        cuuint32_t B_desc_elementStrides[2]= {1,1};
        CUtensorMapInterleave B_desc_interleave= (CUtensorMapInterleave)0;
        CUtensorMapSwizzle B_desc_swizzle= (CUtensorMapSwizzle)3;
        CUtensorMapL2promotion B_desc_l2Promotion= (CUtensorMapL2promotion)2;
        CUtensorMapFloatOOBfill B_desc_oobFill= (CUtensorMapFloatOOBfill)0;

        CUresult B_desc_result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)(
    &B_desc, B_desc_type, B_desc_tensorRank, B_desc_globalAddress, B_desc_globalDim, B_desc_globalStride + 1, B_desc_boxDim, B_desc_elementStrides, B_desc_interleave, B_desc_swizzle, B_desc_l2Promotion, B_desc_oobFill);

        if (B_desc_result != CUDA_SUCCESS) {
                std::stringstream ss;
                ss << "Error: Failed to initialize the TMA descriptor B_desc";
                snprintf(error_buf, ERROR_BUF_SIZE, "%s", ss.str().c_str());
                return -1;
        }
        gemm_kernel<<<dim3(8, 8, 1), dim3(256, 1, 1), 49152, stream>>>(A_desc, B_desc, C);
        TILELANG_CHECK_LAST_ERROR("gemm_kernel");

        return 0;
}
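The `call` wrapper spells out the `cuTensorMapEncodeTiled` arguments that the host IR passed positionally to `__tvm_tensormap_create_tiled`. The sketch below decodes those positional lists with CUDA driver enum values (CU_TENSOR_MAP_DATA_TYPE_FLOAT16 = 6, CU_TENSOR_MAP_SWIZZLE_64B / 128B = 2 / 3, CU_TENSOR_MAP_L2_PROMOTION_L2_128B = 2). It also makes a pattern in this dump visible: the swizzle width equals the byte width of the box's inner dimension (32 halves = 64 B for A, 64 halves = 128 B for B), and since a swizzled box's inner dimension may not exceed the swizzle span, B's 128-column tile is fetched as two 64-column boxes (the two `tma_load` calls on B_desc). `decode_tma_args` and `TmaDesc` are hypothetical, and the argument order is read off this particular dump.

```python
from dataclasses import dataclass
from typing import Sequence, Tuple

# Subset of CUDA driver enum values (CUtensorMapDataType, CUtensorMapSwizzle, CUtensorMapL2promotion).
DATA_TYPES = {6: ("FLOAT16", 2), 7: ("FLOAT32", 4)}
SWIZZLE_BYTES = {0: 0, 1: 32, 2: 64, 3: 128}  # NONE, 32B, 64B, 128B
L2_PROMOTION = {0: "NONE", 1: "L2_64B", 2: "L2_128B", 3: "L2_256B"}


@dataclass(frozen=True)
class TmaDesc:
    dtype: str
    elem_bytes: int
    global_dim: Tuple[int, ...]
    global_stride_bytes: Tuple[int, ...]  # strides of dims 1..rank-1, as the driver sees them
    box_dim: Tuple[int, ...]              # box_dim[0] is the innermost (contiguous) dimension
    swizzle_bytes: int
    l2_promotion: str

    @property
    def box_bytes(self) -> int:
        n = self.elem_bytes
        for d in self.box_dim:
            n *= d
        return n


def decode_tma_args(args: Sequence) -> TmaDesc:
    """Decode the list after the descriptor handle:
    [dtype, rank, addr, dims*rank, strides*rank, box*rank, elem_strides*rank,
     interleave, swizzle, l2_promotion, oob_fill]."""
    dtype_code, rank = args[0], args[1]
    rest = list(args[3:])
    dims = rest[:rank]
    strides = rest[rank:2 * rank]  # strides[0] is the element size; the driver gets strides + 1
    box = rest[2 * rank:3 * rank]
    _interleave, swizzle, l2, _oob = rest[4 * rank:4 * rank + 4]
    name, elem_bytes = DATA_TYPES[dtype_code]
    swizzle_bytes = SWIZZLE_BYTES[swizzle]
    # With swizzling, the box's inner dimension in bytes must fit within the swizzle span.
    if swizzle_bytes and box[0] * elem_bytes > swizzle_bytes:
        raise ValueError(f"inner box of {box[0] * elem_bytes} B exceeds {swizzle_bytes} B swizzle")
    return TmaDesc(name, elem_bytes, tuple(dims), tuple(strides[1:]), tuple(box),
                   swizzle_bytes, L2_PROMOTION[l2])
```

Decoding A's list yields an 8192-byte box and B's a 4096-byte box, so the single `expect_transaction(8192)` for B covers exactly its two TMA loads.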

1.4. JIT Execution & Caching

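  • jit/__init__.py, jit/kernel.py: the JIT logic.
  • PrimFunc → Lower → Build → wrapped by a backend adapter (ctypes, DLPack, etc.).
  • cache/kernel_cache.py: kernel caching; compiled kernels are reused based on a hash of the IR plus target/backend options.
  • Hooks into the environment variables/toggles in env.py (enable_cache, TILELANG_CACHE_DIR).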
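The caching idea (reuse a compiled kernel when the IR and the target/backend options are identical) can be sketched with a content hash. This is a minimal illustration, not the logic in `cache/kernel_cache.py`: which fields TileLang actually hashes, its on-disk layout, and the `~/.cache/tilelang-sketch` fallback directory below are assumptions; only the `TILELANG_CACHE_DIR` variable name comes from the list above.

```python
import hashlib
import json
import os
from pathlib import Path


def cache_key(ir_text: str, target: str, options: dict) -> str:
    """Hash the IR text together with the target and (sorted) backend options."""
    payload = json.dumps({"ir": ir_text, "target": target, "options": options},
                         sort_keys=True)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


def cache_path(key: str) -> Path:
    """Resolve the cache directory from TILELANG_CACHE_DIR, else a hypothetical default."""
    root = os.environ.get("TILELANG_CACHE_DIR",
                          os.path.join(os.path.expanduser("~"), ".cache", "tilelang-sketch"))
    return Path(root) / key[:2] / key


class KernelCache:
    """In-memory layer of the cache: build once per key, then reuse."""

    def __init__(self) -> None:
        self._mem = {}

    def get_or_build(self, ir_text, target, options, build):
        key = cache_key(ir_text, target, options)
        if key not in self._mem:
            self._mem[key] = build()  # the expensive part: lower + codegen + nvcc
        return self._mem[key]
```

Serializing with `sort_keys=True` keeps the key stable regardless of the order in which options were supplied.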

1.5. Environment & Target Detection

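  • env.py:111: controls include paths, the cache directory, and environment variables.
  • utils/target.py: auto-detects CUDA/HIP and answers capability queries (target_is_hopper, target_has_async_copy, etc.).
    The pass stage uses this information to decide which optimization flags to enable.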
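Capability queries like these come down to comparing the CUDA compute capability. As a hedged sketch: cp.async (LDGSTS) is available from sm_80 (Ampere) onward, and TMA plus warpgroup-level MMA arrived with sm_90 (Hopper). The helpers below are hypothetical and deliberately named differently from `target_is_hopper` / `target_has_async_copy`; they parse plain arch strings and are not the functions in utils/target.py.

```python
import re
from typing import Tuple


def parse_sm(arch: str) -> Tuple[int, int]:
    """Parse 'sm_90', 'sm_90a', 'sm_86', ... into (major, minor)."""
    m = re.fullmatch(r"sm_(\d+)(\d)[a-z]?", arch)
    if m is None:
        raise ValueError(f"not a CUDA arch string: {arch!r}")
    return int(m.group(1)), int(m.group(2))


def sketch_is_hopper(arch: str) -> bool:
    """Hopper is compute capability 9.x (e.g. sm_90 / sm_90a)."""
    return parse_sm(arch)[0] == 9


def sketch_has_async_copy(arch: str) -> bool:
    """cp.async (LDGSTS) exists on compute capability 8.0 and newer."""
    return parse_sm(arch) >= (8, 0)


def pick_pipeline(arch: str) -> str:
    """Illustrative decision of the kind the pass stage makes from these queries.

    Architectures newer than Hopper are not modeled in this toy.
    """
    if sketch_is_hopper(arch):
        return "tma+warp_specialized"
    if sketch_has_async_copy(arch):
        return "cp.async pipeline"
    return "synchronous copies"
```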

2. A real lowering example

3. WarpSpecialized Pipeline

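  • tilelang/engine/phase.py:118: if the target is judged to be Hopper-class, the tilelang.transform.WarpSpecialized() and PipelinePlanning/InjectSoftwarePipeline passes are run.
    The check that decides whether the target is Hopper-class can be found in allow_tma_and_warp_specialized.
  • Copy-only warps and Tensor Core compute warps are split apart and scheduled so that they overlap. The actual role split and timeline reordering are implemented in C++ in src/transform/warp_specialized_rewriter.cc:1, which marks each statement as Producer (SIMT copy), Consumer (Tensor Core), or Both and builds a warp-level pipeline.
  • Before that, the logic that extracts the copy and compute stages and plans the pipeline stages lives in src/transform/pipeline_planning.cc (from line 38 on). It builds a schedule in which memory movement such as TMA loads or LDGSTS (cp.async) copies overlaps with Tensor Core computation.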
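The producer/consumer handshake in the device IR can be checked with a small simulation. With num_stages = 3 there are six mbarriers: `mbarrier[s]` ("full": the TMA data for stage s has landed) and `mbarrier[s + 3]` ("empty": the consumers are done reading stage s). In iteration k both sides use stage `k % 3`, and the wait parity `k % 6 // 3` equals `(k // 3) % 2`, i.e. it flips each time the ring of stages wraps around. The producer waits on "empty" with the opposite parity; assuming the usual mbarrier semantics, where waiting on the parity of the phase before the current one returns immediately, this lets the producer fill all three stages before any consumer has released one. The sketch is a plain-Python model of that ring buffer, not code from warp_specialized_rewriter.cc; `ToyMBarrier` and `simulate` are hypothetical, and one `arrive()` stands in for a whole completed phase.

```python
class ToyMBarrier:
    """Tracks only completed phases; one arrive() completes the current phase.

    wait(parity) may pass once the phase with that parity has completed, which is
    exactly when the in-progress phase has the other parity. A fresh barrier
    therefore lets wait(1) through immediately.
    """

    def __init__(self) -> None:
        self.completed = 0

    def can_pass(self, parity: int) -> bool:
        return self.completed % 2 != parity

    def arrive(self) -> None:
        self.completed += 1


def simulate(num_iters: int = 32, num_stages: int = 3, producer_flip: bool = True) -> int:
    """Run the producer/consumer ring to completion and return the peak number of
    stages filled but not yet consumed. Raises RuntimeError on deadlock."""
    full = [ToyMBarrier() for _ in range(num_stages)]   # mbarrier[s]
    empty = [ToyMBarrier() for _ in range(num_stages)]  # mbarrier[s + num_stages]
    smem = [-1] * num_stages                            # which k each stage holds
    kp = kc = 0                                         # next producer / consumer k
    peak = 0
    while kc < num_iters:
        progressed = False
        # Producer warpgroup: run ahead as far as the "empty" barriers allow.
        while kp < num_iters:
            s, parity = kp % num_stages, (kp // num_stages) % 2
            if not empty[s].can_pass(parity ^ 1 if producer_flip else parity):
                break
            smem[s] = kp        # stands in for the TMA loads into stage s
            full[s].arrive()    # stands in for arrive + TMA transaction completion
            kp += 1
            progressed = True
        peak = max(peak, kp - kc)
        # Consumer warpgroup: one tile per step.
        s, parity = kc % num_stages, (kc // num_stages) % 2
        if full[s].can_pass(parity):
            if smem[s] != kc:
                raise ValueError(f"consumer read stage {s} holding k={smem[s]}, expected {kc}")
            empty[s].arrive()   # release the stage back to the producer
            kc += 1
            progressed = True
        if not progressed:
            raise RuntimeError("deadlock: neither warpgroup can make progress")
    return peak
```

Running it with `producer_flip=False` shows why the XOR with 1 matters: both sides then wait on fresh barriers at parity 0 and the pipeline deadlocks before the first load.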