CUTLASS的api

CUTLASS库是NVIDIA的开源库,能够通过调节各种参数逼近甚至超越传统cuBLAS库的矩阵乘性能,但是其C++风格式的源码晦涩难懂,通常需要联系多个类才能看懂源码,本文从CUTLASS的表层api入手,逐层递进,对最终的核函数进行解释分析。注意,本文看重的是大矩阵乘法最后的性能,所以分析样例选取的是大矩阵(三个维度都大于等于8192)的性能较优参数,对于小型矩阵或者形状不规则矩阵在此不过多赘述。

CUTLASS库中有profile工具,能够在指定架构上实现多种参数的调优核函数生成,并自动测试Flops,这里我选择的架构是Sm80,GPU为单卡A100 40G,数据类型为A half,B half,C half,也是目前比较主流的数据类型,通过profile工具,修改相应的矩阵乘规模和输入输出数据类型,能够自动进行多轮Flops性能测试,在选取最优值后,进入CUTLASS/example目录。

因为profile工具会自动进行调优,为了在自己设计的文件中运行单独的核函数重新评测性能,可以借用example目录中的例子文件,这里我借用的是ampere_tf32_tensorop_gemm案例,通过将内部的数据类型进行更换,指定对应的矩阵乘尺寸,同时执行该文件,即可单独运行调优出最好的核函数。

注意,执行example的cu文件需要链接一些特殊的库,可以直接在编译指令中链接,编译指令如下所示:

nvcc -I /path/CUTLASS/include -I /path/CUTLASS/tools/util/include -I /path/CUTLASS/examples/common ampere_tf32_tensorop_gemm.cu -o ampere_tf32_tensorop_gemm -arch=compute_80 -O3 -code=sm_80 --ptxas-options=-v

-I 表示增加include的文件夹,-arch 表示了GPU的架构,最后打印PTX汇编消息可以省略

如果只研究表层api,那么介绍到这里就可以结束了,你可以通过调节cu文件中的一些代码,使之能够与cuBLAS库进行性能比较,或者与自己写的核函数进行比较,但是CUTLASS本身是很好的开源库,目前很多项目的矩阵乘法都是根据CUTLASS库进行改变,就算不使用CUTLASS的模板,手写的核函数也大部分利用到了CUTLASS中提及的优化方法,所以如果想要继续研究CUTLASS内部的核函数,就要继续往深层挖掘~

核函数的调用

我们先看一看example的cu文件中如何调用核函数的:

using Gemm = CUTLASS::gemm::device::Gemm<ElementInputA,
                                         LayoutInputA,
                                         ElementInputB,
                                         LayoutInputB,
                                         ElementOutput,
                                         LayoutOutput,
                                         ElementAccumulator,
                                         MMAOp,
                                         SmArch,
                                         ShapeMMAThreadBlock,
                                         ShapeMMAWarp,
                                         ShapeMMAOp,
                                         EpilogueOp,
                                         SwizzleThreadBlock,
                                         NumStages>;
Gemm gemm_op;
status = gemm_op();

上述代码和源代码排布不一样(省略了参数初始化代码),但是已经可以看出,最终执行的核函数就是CUTLASS::gemm::device::Gemm中的operator函数,所以我们通过include文件进一步定位:

在/path/CUTLASS/include/CUTLASS/gemm/device/gemm.h中,可以找到相应类的模板定义:

template <
    /// Element type for A matrix operand
    typename ElementA_,
    /// Layout type for A matrix operand
    typename LayoutA_,
    /// Element type for B matrix operand
    typename ElementB_,
    /// Layout type for B matrix operand
    typename LayoutB_,
    /// Element type for C and D matrix operands
    typename ElementC_,
    /// Layout type for C and D matrix operands
    typename LayoutC_,
    /// Element type for internal accumulation
    typename ElementAccumulator_ = ElementC_,
    /// Operator class tag
    typename OperatorClass_ = arch::OpClassSimt,
    /// Tag indicating architecture to tune for
    typename ArchTag_ = arch::Sm70,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::InstructionShape,
    /// Epilogue output operator
    typename EpilogueOutputOp_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::EpilogueOutputOp,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle_ =
        typename threadblock::GemmIdentityThreadblockSwizzle<>,
    /// Number of stages used in the pipelined mainloop
    int Stages =
        DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                                 ElementC_, ElementAccumulator_>::kStages,
    /// Access granularity of A matrix in units of elements
    int AlignmentA =
        DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                                 ElementC_, ElementAccumulator_>::kAlignmentA,
    /// Access granularity of B matrix in units of elements
    int AlignmentB =
        DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                                 ElementC_, ElementAccumulator_>::kAlignmentB,
    /// If true, kernel supports split-K with serial reduction
    bool SplitKSerial = false,
    /// Operation performed by GEMM
    typename Operator_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::Operator,
    /// Gather operand A by using an index array
    bool GatherA = false,
    /// Gather operand B by using an index array
    bool GatherB = false,
    /// Scatter result D by using an index array
    bool ScatterD = false,
    /// Permute result D
    typename PermuteDLayout = layout::NoPermute>
class Gemm {
 public:
​
  using ElementA = ElementA_;
  using LayoutA = LayoutA_;
  using TensorRefA = TensorRef<ElementA const, LayoutA>;
  using ElementB = ElementB_;
  using LayoutB = LayoutB_;
  using TensorRefB = TensorRef<ElementB const, LayoutB>;
  using ElementC = ElementC_;
  using LayoutC = LayoutC_;
  using TensorRefC = TensorRef<ElementC const, LayoutC>;
  using TensorRefD = TensorRef<ElementC, LayoutC>;
  using ElementAccumulator = ElementAccumulator_;
  using OperatorClass = OperatorClass_;
  using ArchTag = ArchTag_;
  using ThreadblockShape = ThreadblockShape_;
  using WarpShape = WarpShape_;
  using InstructionShape = InstructionShape_;
  using EpilogueOutputOp = EpilogueOutputOp_;
  using ThreadblockSwizzle = ThreadblockSwizzle_;
  using Operator = Operator_;
  static int const kStages = Stages;
  static int const kAlignmentA = AlignmentA;
  static int const kAlignmentB = AlignmentB;
  static int const kAlignmentC = EpilogueOutputOp::kCount;
  static bool const kSplitKSerial = SplitKSerial;
  static ComplexTransform const kTransformA = ComplexTransform::kNone;
  static ComplexTransform const kTransformB = ComplexTransform::kNone;

这里比较有意思,C++里有模板的用法,通过尖括号来表示模板中的参数,但是对于一个同名模板,可以根据传入参数不同执行各自的函数,在CUTLASS中,对于Gemm,除了上述定义外,还有一个当C矩阵是列主序的时候特地的Gemm模板,这个模板里面似乎做了一个转置和调换AB矩阵的功能,但是类似地,列主序时的Gemm里面声明了一个非列主序的相同类模板的Gemm实例,里面调用的run还有operator,最后还是会指向非特地的这个类里面。

一般来说找类中的operator函数,这个函数是指使用括号时会调用的函数,这个类中的op函数指向的是run函数,我们来看run函数如下:

Status run(cudaStream_t stream = nullptr) {
​
    ThreadblockSwizzle threadblock_swizzle;
​
    dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
    dim3 block(GemmKernel::kThreadCount, 1, 1);
​
    cudaError_t result;
​
    int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
​
    if (smem_size >= (48 << 10)) {
      result = cudaFuncSetAttribute(Kernel<GemmKernel>,
                                    cudaFuncAttributeMaxDynamicSharedMemorySize,
                                    smem_size);
​
      if (result != cudaSuccess) {
        return Status::kErrorInternal;
      }
    }
​
    CUTLASS::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
​
    result = cudaGetLastError();
​
    return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
  }
​

十分清晰,里面声明了gridDim和blockDim,以及共享内存的大小,是个非常经典的核函数调用函数,说明我们总算找到入口了!所以说CUTLASS所有的秘密都藏在了GemmKernel里面,至于核函数内部的类和如何实现的,且听下回分解。