The CUTLASS API
CUTLASS is NVIDIA's open-source template library for matrix multiplication. By tuning its many parameters it can approach, and sometimes exceed, the GEMM performance of the traditional cuBLAS library, but its template-heavy C++ source is hard to read: understanding one piece usually means chasing through several classes. This article starts from CUTLASS's surface-level API and digs in layer by layer until we can analyze the final kernel. Note that the focus is on the end performance of large matrix multiplications, so the analysis uses parameters tuned for large matrices (all three dimensions at least 8192); small or irregularly shaped matrices are not covered here.
CUTLASS ships with a profiler tool that, for a specified architecture, generates kernels tuned over many parameter combinations and benchmarks their FLOPS automatically. Here I target Sm80 on a single A100 40GB with data types A = half, B = half, C = half, one of the most mainstream configurations today; given the GEMM problem size and the input/output data types, the profiler runs multiple rounds of FLOPS measurements on its own.
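For reference, a sweep at the sizes used in this article can be launched roughly as follows (a sketch: the flag spellings follow the profiler's --help output and may vary between CUTLASS versions):

./tools/profiler/cutlass_profiler --operation=gemm --m=8192 --n=8192 --k=8192 --A=f16:column --B=f16:column --C=f16:column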
After picking the best-performing configuration, we move to the CUTLASS/examples directory. Because the profiler tunes automatically, to run one specific kernel in a file of your own and re-measure its performance you can borrow one of the example files. I borrowed the ampere_tf32_tensorop_gemm example: swap its data types, set the desired GEMM dimensions, and run the file to execute just the tuned kernel, as sketched below.
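A minimal sketch of the type swap, reusing the example file's own using-declaration names (in the original they select the TF32 configuration):

using ElementInputA = cutlass::half_t;       // was cutlass::tfloat32_t
using ElementInputB = cutlass::half_t;       // was cutlass::tfloat32_t
using ElementOutput = cutlass::half_t;       // was float
using ElementAccumulator = cutlass::half_t;  // was float; half accumulation matches the A/B/C = half setup profiled above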
Note that compiling the example .cu file requires a few extra CUTLASS header directories, which can be passed directly on the compile command line:
nvcc -I /path/CUTLASS/include -I /path/CUTLASS/tools/util/include -I /path/CUTLASS/examples/common ampere_tf32_tensorop_gemm.cu -o ampere_tf32_tensorop_gemm -arch=compute_80 -O3 -code=sm_80 --ptxas-options=-v
-I adds an include directory and -arch specifies the GPU architecture; the trailing --ptxas-options=-v, which prints ptxas resource-usage information (registers, shared memory), can be omitted.
If the surface API is all you care about, the story could end here: with small changes to the .cu file you can benchmark against cuBLAS (see the baseline sketch below) or against kernels of your own. But CUTLASS is an excellent open-source library: many projects derive their GEMMs from it, and even hand-written kernels that skip CUTLASS's templates mostly reuse the optimization techniques CUTLASS demonstrates. So, to keep studying CUTLASS's internal kernels, we have to dig deeper.
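For the cuBLAS side of such a comparison, a minimal half-precision baseline might look like the sketch below (assumptions: column-major m x k, k x n, and m x n operands already resident on the device as d_A, d_B, d_C; cublasGemmEx is cuBLAS's mixed-precision GEMM entry point):

#include <cublas_v2.h>
#include <cuda_fp16.h>

// One half-precision GEMM, C = alpha * A * B + beta * C, tensor cores allowed.
void cublas_hgemm_baseline(cublasHandle_t handle, int m, int n, int k,
                           const __half* d_A, const __half* d_B, __half* d_C) {
  const __half alpha = __float2half(1.0f);
  const __half beta = __float2half(0.0f);
  cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
               &alpha, d_A, CUDA_R_16F, m,   // lda = m for column-major A
               d_B, CUDA_R_16F, k,           // ldb = k
               &beta, d_C, CUDA_R_16F, m,    // ldc = m
               CUBLAS_COMPUTE_16F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}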
Invoking the kernel
Let us first look at how the example .cu file invokes the kernel:
using Gemm = cutlass::gemm::device::Gemm<ElementInputA,
                                         LayoutInputA,
                                         ElementInputB,
                                         LayoutInputB,
                                         ElementOutput,
                                         LayoutOutput,
                                         ElementAccumulator,
                                         MMAOp,
                                         SmArch,
                                         ShapeMMAThreadBlock,
                                         ShapeMMAWarp,
                                         ShapeMMAOp,
                                         EpilogueOp,
                                         SwizzleThreadBlock,
                                         NumStages>;
Gemm gemm_op;
status = gemm_op();
The code above is arranged differently from the actual source (the argument-initialization code is omitted; a sketch of it follows below), but it is already clear that what ultimately executes is the operator() of cutlass::gemm::device::Gemm, so we can chase the include files to find it.
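A sketch of that omitted setup, following the structure of the example file (tensor_a through tensor_d are cutlass::HostTensor objects, and alpha, beta, split_k_slices are scalars defined earlier; the names are the example's, the values illustrative):

cutlass::gemm::GemmCoord problem_size(M, N, K);
typename Gemm::Arguments arguments{problem_size,
                                   tensor_a.device_ref(),  // A in device memory
                                   tensor_b.device_ref(),  // B
                                   tensor_c.device_ref(),  // C, the epilogue source
                                   tensor_d.device_ref(),  // D, the destination
                                   {alpha, beta},          // epilogue: D = alpha * acc + beta * C
                                   split_k_slices};        // 1 for a plain GEMM
status = gemm_op.initialize(arguments);  // validates the problem, builds params_
status = gemm_op();                      // operator() then launches the kernel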
In /path/CUTLASS/include/cutlass/gemm/device/gemm.h we find the template definition of that class:
template <
    /// Element type for A matrix operand
    typename ElementA_,
    /// Layout type for A matrix operand
    typename LayoutA_,
    /// Element type for B matrix operand
    typename ElementB_,
    /// Layout type for B matrix operand
    typename LayoutB_,
    /// Element type for C and D matrix operands
    typename ElementC_,
    /// Layout type for C and D matrix operands
    typename LayoutC_,
    /// Element type for internal accumulation
    typename ElementAccumulator_ = ElementC_,
    /// Operator class tag
    typename OperatorClass_ = arch::OpClassSimt,
    /// Tag indicating architecture to tune for
    typename ArchTag_ = arch::Sm70,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::InstructionShape,
    /// Epilogue output operator
    typename EpilogueOutputOp_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::EpilogueOutputOp,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle_ =
        typename threadblock::GemmIdentityThreadblockSwizzle<>,
    /// Number of stages used in the pipelined mainloop
    int Stages =
        DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                                 ElementC_, ElementAccumulator_>::kStages,
    /// Access granularity of A matrix in units of elements
    int AlignmentA =
        DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                                 ElementC_, ElementAccumulator_>::kAlignmentA,
    /// Access granularity of B matrix in units of elements
    int AlignmentB =
        DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                                 ElementC_, ElementAccumulator_>::kAlignmentB,
    /// If true, kernel supports split-K with serial reduction
    bool SplitKSerial = false,
    /// Operation performed by GEMM
    typename Operator_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::Operator,
    /// Gather operand A by using an index array
    bool GatherA = false,
    /// Gather operand B by using an index array
    bool GatherB = false,
    /// Scatter result D by using an index array
    bool ScatterD = false,
    /// Permute result D
    typename PermuteDLayout = layout::NoPermute>
class Gemm {
 public:
  using ElementA = ElementA_;
  using LayoutA = LayoutA_;
  using TensorRefA = TensorRef<ElementA const, LayoutA>;
  using ElementB = ElementB_;
  using LayoutB = LayoutB_;
  using TensorRefB = TensorRef<ElementB const, LayoutB>;
  using ElementC = ElementC_;
  using LayoutC = LayoutC_;
  using TensorRefC = TensorRef<ElementC const, LayoutC>;
  using TensorRefD = TensorRef<ElementC, LayoutC>;
  using ElementAccumulator = ElementAccumulator_;
  using OperatorClass = OperatorClass_;
  using ArchTag = ArchTag_;
  using ThreadblockShape = ThreadblockShape_;
  using WarpShape = WarpShape_;
  using InstructionShape = InstructionShape_;
  using EpilogueOutputOp = EpilogueOutputOp_;
  using ThreadblockSwizzle = ThreadblockSwizzle_;
  using Operator = Operator_;
  static int const kStages = Stages;
  static int const kAlignmentA = AlignmentA;
  static int const kAlignmentB = AlignmentB;
  static int const kAlignmentC = EpilogueOutputOp::kCount;
  static bool const kSplitKSerial = SplitKSerial;
  static ComplexTransform const kTransformA = ComplexTransform::kNone;
  static ComplexTransform const kTransformB = ComplexTransform::kNone;
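To connect this template back to the surface API: the example's arguments (ShapeMMAThreadBlock and friends) simply fill these slots. As an illustration, a plausible Sm80 half-precision instantiation could use tile shapes like the following (typical tuned values, not the only valid ones):

using ShapeMMAThreadBlock = cutlass::gemm::GemmShape<128, 256, 32>;  // threadblock tile (M, N, K)
using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>;           // warp tile
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 16>;              // one Sm80 tensor-core MMA instruction
using MMAOp = cutlass::arch::OpClassTensorOp;                        // tensor cores rather than SIMT
using SmArch = cutlass::arch::Sm80;
constexpr int NumStages = 3;                                         // mainloop pipeline depth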
There is an interesting C++ point here: the same template name can be partially specialized, so different template arguments select different implementations. Besides the definition above, CUTLASS provides a partial specialization of Gemm for the case where the C matrix is column-major. It appears to transpose the problem and swap the A and B operands, which works because a column-major D occupies the same memory as a row-major D^T, and (A*B)^T = B^T * A^T. Internally, the column-major specialization instantiates the non-specialized Gemm template, and its run() and operator() ultimately forward to that non-specialized class.
The usual place to look is the class's operator(), the function invoked when an instance is called with parentheses; in this class it simply forwards to run().
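The forwarding overload is a one-liner (quoted as it appears in gemm.h; a second overload additionally takes Arguments and initializes state first):

/// Runs the kernel using initialized state.
Status operator()(cudaStream_t stream = nullptr) {
  return run(stream);
}

run() itself is where the launch happens: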
Status run(cudaStream_t stream = nullptr) {
  ThreadblockSwizzle threadblock_swizzle;

  dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
  dim3 block(GemmKernel::kThreadCount, 1, 1);

  cudaError_t result;

  int smem_size = int(sizeof(typename GemmKernel::SharedStorage));

  if (smem_size >= (48 << 10)) {
    result = cudaFuncSetAttribute(Kernel<GemmKernel>,
                                  cudaFuncAttributeMaxDynamicSharedMemorySize,
                                  smem_size);

    if (result != cudaSuccess) {
      return Status::kErrorInternal;
    }
  }

  cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);

  result = cudaGetLastError();

  return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
}
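One detail worth a back-of-the-envelope check: 48 << 10 is the default 48 KB per-block shared-memory limit, and a kernel that needs more must opt in through cudaFuncSetAttribute. With the large tiles this article targets, the opt-in fires routinely; for instance (illustrative numbers: a 128x256 threadblock tile, K-slice of 32, half precision, 3 pipeline stages):

// smem ~= stages * (BM*BK + BK*BN) * sizeof(half)
//       = 3 * (128*32 + 32*256) * 2 bytes
//       = 73,728 bytes (72 KB), well above the 48 KB default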
Crystal clear: it sets up gridDim, blockDim, and the shared-memory size, a textbook kernel-launch wrapper, which means we have finally found the entry point! In other words, all of CUTLASS's secrets hide inside GemmKernel. As for the classes inside the kernel and how it is implemented, that is a story for the next installment.