필자는 진로를 ML 시스템 레벨 최적화로 결정하였고, 커리어를 위해 cuda를 배워야겠다고 생각하고 있다. 그러던 중, 커뮤니티에서 triton의 사용에 대해 이야기를 듣기 시작한지 몇 주 되었다. 하지만 triton은 block 사이즈를 어떻게 지정하느냐에 따라 성능이 다르다고 알고 있기에 아직 많이 공부할 필요성을 느끼지 못했었다.
그러데 모델 성능을 높일 수 있도록 바로 cuda로 코드를 짤 수 있는 역량도 부족하기에 어떻게 cuda와 친해질지 골머리를 앓고 있었다. 책으로만 cuda programming을 공부하기도 어렵고, 강의로 배우고 한계가 존재한다고 생각했다. 이 부분을 해결할 방법으로 pytorch의 코드 부분 중 성능을 높일 부분을 찾고, 성능을 높일 수 있는 코드를 붙여서 bottleneck을 없얠 수 있는 과정에서 triton을 거쳐 cuda로 코드를 작성할 수 있는 방법을 알게 되었다.
이번 글에서는 pytorch engineer로서 profiling 할 수 있는 툴과 방법에 대해서 알아보는 단계에 대해 글을 작성해보려고 한다. 특히 profile과 pytorch와 triton, 그리고 cuda를 엮어서 작업할 때 사용할 툴에 대해 알아갈 예정이다.
profile한다면 어떻게 하지?
하나의 operator를 torch.autograd.profile로 측정해보자.
cuda는 pytorch와 다르게 async하기 때문에 time module을 사용해서 측정할 수 없다.
그렇기 때문에 cuda는 start event와 end event를 사용해서 측정한다.
def time_pytorch_function(func, example): start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) # Warmup for _ in range(5): func(example) start.record() func(example) end.record() torch.cuda.synchronize() # Since cuda is asynchronous, First synchronize it. return start.elapsed_time(end)
실제로 코드를 살펴보면 start와 end라는 event를 선언해서 time 측정은 start와 end event의 시간 차를 반환해주는 것을 볼 수 있다. 그렇다면 이 함수를 활용해서 실제로 autograd의 profile 기능을 사용해서 측정해보자.
먼저 torch.square는 aten::pow와 aten::square를 사용한 것을 알 수 있다. cpu time 은 총 3.447ms, cuda time은 총 3.506ms임을 확인할 수 있다.
마찬가지로 이번에는 a*a의 연산을 profile해보자. 이 연산작업에서는 aten::mul을 사용한다. 그리고 간단히 결과만 확인했을 때 cpu time과 cuda time이 상대적으로 좀 더 빠르다는 것을 확인할 수 있다.
with torch.autograd.profiler.profile(use_cuda=True) as prof: torch.square(b) # row limit with sorted by most time consuming print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
저런 로그처럼 출력하는 방식은 바로 위 코드처럼 table형태로 만들어주고 sort key값을 줌으로써 profile하는데 도움을 줄 수 있다.
Pytorch profiler : 제일 익숙한 profiler
Chrome trace를 사용해서 json으로 export된 파일을 시각화 툴을 사용해서 분석된다.
하지만 이 profiler는 사실 얼마나 시간이 걸리는지에 대해서만 알려줄 수 있는 형태고 더 성능을 올릴 수 있는 방법에 대해 알기 어렵다. 그리고 아직 필자가 profiler를 제대로 활용하는 것이 익숙하지 않고 다른 블로그에서 소개를 많이 하기 때문에, 해당 profiler에 대한 소개는 생략하겠다.
custom cpp extension을 붙이는 법
그래서 cuda로 짠 코드를 python에서 사용하고 싶다면 아래와 같이 2가지 방법 중 한 개를 선택하면 된다.
1. pybind를 사용하기
2. torch.utils.cpp_extension에 구현된 load_inline을 사용해서 c++로 구현된 함수를 붙인다.
1번의 방식은 아직 익숙하지 않기 때문에 간단한 2번 방법에 대해서 알아보자.
cuda_source = ''' __global__ void square_matrix_kernel(const float* matrix, float* result, int width, int height) { int row = blockIdx.y * blockDim.y + threadIdx.y; int col = blockIdx.x * blockDim.x + threadIdx.x; if (row < height && col < width) { int idx = row * width + col; result[idx] = matrix[idx] * matrix[idx]; } } torch::Tensor square_matrix(torch::Tensor matrix) { const auto height = matrix.size(0); const auto width = matrix.size(1); auto result = torch::empty_like(matrix); dim3 threads_per_block(16, 16); dim3 number_of_blocks((width + threads_per_block.x - 1) / threads_per_block.x, (height + threads_per_block.y - 1) / threads_per_block.y); square_matrix_kernel<<<number_of_blocks, threads_per_block>>>( matrix.data_ptr<float>(), result.data_ptr<float>(), width, height); return result; } '''
위와 같이 cuda의 square matrix kernel이라고 하는 kernel을 직접 구현한 것을 볼 수 있다. 간단한 예제로 result에다가 matrix를 순회하며 제곱을 다시 넣어주는 것을 볼 수 있다. 그리고 square_matrix 함수를 c++ 함수로 만들어서 python에서 matrix라는 torch::Tensor형 input이 들어오면, 그 input을 kernel에 넘겨주는 wrapper를 만들었다.
cpp_source = "torch::Tensor square_matrix(torch::Tensor matrix);" # Load the CUDA kernel as a PyTorch extension square_matrix_extension = load_inline( name='square_matrix_extension', cpp_sources=cpp_source, cuda_sources=cuda_source, functions=['square_matrix'], with_cuda=True, extra_cuda_cflags=["-O2"], build_directory='./load_inline_cuda', ) a = torch.tensor([[1., 2., 3.], [4., 5., 6.]], device='cuda') print(square_matrix_extension.square_matrix(a)) # tensor([[ 1., 4., 9.], # [16., 25., 36.]], device='cuda:0')
cuda_source와 cpp_source를 나눠서 선언했는데, 이는 사실 load_inline에서 각각 명시해서 넣어준다. cpp_source는 header의 역할, cuda_source는 실제로 구현한 부분을 넣어준다고 생각하면 될 것 같다.
load_inline에서는 cuda_cflags인 compiler 옵션이나 name을 명시할 수도 있다.
결과적으로 a라는 tensor를 만들고 square_matrix_extension에 있는 square_matrix를 호출하면 실제로 위에서 구현된 kernel이 사용된다.
더 성능을 높일 수 있는 방법을 알려주는 profiler가 있다면?
triton
triton은 dsl로 cuda 언어로 script를 만들어준는게 아니라, cuda assembly code인 ptx파일을 생성한다.
그리고 triton은 python을 사용해서 구현한다. 하지만 triton은 triton 대로 performance 이슈가 존재한다.
실제로 위 그래프와 같이 pytorch native code보다 triton보다 나은 경우들을 만나는 경우가 적지 않다. 즉, 성능 상의 이슈가 있어서 이 문제를 해결해야 triton을 사용하는데 의의가 있다.
@triton.jit(interpret=True) def square_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE: tl.constexpr): # The rows of the softmax are independent, so we parallelize across those row_idx = tl.program_id(0) # The stride represents how much we need to increase the pointer to advance 1 row row_start_ptr = input_ptr + row_idx * input_row_stride # The block size is the next power of two greater than n_cols, so we can fit each # row in a single block col_offsets = tl.arange(0, BLOCK_SIZE) breakpoint() input_ptrs = row_start_ptr + col_offsets # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')) square_output = row * row breakpoint() # Write back output to DRAM output_row_start_ptr = output_ptr + row_idx * output_row_stride output_ptrs = output_row_start_ptr + col_offsets tl.store(output_ptrs, square_output, mask=col_offsets < n_cols)
위의 코드를 잠깐 살펴보자. 이 코드에서는 triton으로 똑같은 예제인 square_kernel을 구현한 경우다. cuda구현 부분과 비슷한데 이 코드는 row 단위로 계산을 하는 것을 볼 수 있다. 이때 row 값을 SRAM으로 올려서 loading하는 과정, 그리고 outputdmf DRAM으로 저장하는 과정을 볼 수 있다.
이런 문제는 profiler로 해결할 수 있다. 결과적으로 block size를 1024로 바꾸니 성능을 향상시킬 수 있었다.
그리고 triton.jit(interpret=true) 값을 줘서 실제로 python debugger를 사용할 수 있다. trition로 구현한 것들을 모두 wrappedTensor 객체가 감싸기 때문에 var_name.tensor로 tensor값을 디버깅할 수 있다. 하지만, 그냥 Python breakpoint를 사용하면 속도가 너무 느려서 이 다음에 나오는 profiler를 사용하는 것을 추천한다.
Trick : Generate a triton kernel
이렇게 위에서 triton과 pytorch profiler에 대해서 배웠는데, 그럼 triton도 cuda처럼 kernel을 구현해야 하는데 너무 공부하기 어렵다고 생각이 들었다.
하지만, 반대로 torch.compile기능을 사용해서 pytorch로 구현된 연산자를 triton kernel로 변환해서 output으로 만들면 더 배우기 쉽다.
실제로 row major로 matrix연산을 직접 만들었는데, compile된 kernel의 경우 다른 방식으로 구현되었다면 스스로 공부하는데 피드백을 더 빨리 받을 수 있다.
간단한 코드로 어떻게 할 수 있을지 확인해보자.
# compile_square.py import torch def square(a): a = torch.square(a) return torch.square(a) opt_square = torch.compile(square) opt_square(torch.randn(10000,10000).cuda()) # 위 python script를 이 명령어를 진행하면 아래와 같이 log를 볼 수 있다. TORCH_LOGS = "OUTPUT_CODE" python compile_square.py
위 python script를 이 명령어를 진행하면 아래와 같이 log를 볼 수 있다.
-- 생략 -- [2024-04-13 15:23:42,514] [0/0] torch._inductor.graph.__output_code: [DEBUG] @triton.jit [2024-04-13 15:23:42,514] [0/0] torch._inductor.graph.__output_code: [DEBUG] def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): [2024-04-13 15:23:42,514] [0/0] torch._inductor.graph.__output_code: [DEBUG] xnumel = 100000000 [2024-04-13 15:23:42,514] [0/0] torch._inductor.graph.__output_code: [DEBUG] xoffset = tl.program_id(0) * XBLOCK [2024-04-13 15:23:42,514] [0/0] torch._inductor.graph.__output_code: [DEBUG] xindex = xoffset + tl.arange(0, XBLOCK)[:] [2024-04-13 15:23:42,514] [0/0] torch._inductor.graph.__output_code: [DEBUG] xmask = xindex < xnumel [2024-04-13 15:23:42,514] [0/0] torch._inductor.graph.__output_code: [DEBUG] x0 = xindex [2024-04-13 15:23:42,514] [0/0] torch._inductor.graph.__output_code: [DEBUG] tmp0 = tl.load(in_ptr0 + (x0), xmask) [2024-04-13 15:23:42,514] [0/0] torch._inductor.graph.__output_code: [DEBUG] tmp1 = tmp0 * tmp0 [2024-04-13 15:23:42,514] [0/0] torch._inductor.graph.__output_code: [DEBUG] tmp2 = tmp1 * tmp1 [2024-04-13 15:23:42,514] [0/0] torch._inductor.graph.__output_code: [DEBUG] tl.store(out_ptr0 + (x0), tmp2, xmask) [2024-04-13 15:23:42,514] [0/0] torch._inductor.graph.__output_code: [DEBUG] ''') -- 생략 --
실제로 pytorch operator의 optimize된 triton kernel을 구할 수 있으니, 이 코드를 살펴보면서 kernel 구현을 더 보완하면 좋을 것이다.
ncu → Not work in vast cloud. But a good profiler
ncu profiler가 이번 글에서 소개하고 싶은 profiler다. 실제로 operator의 성능을 측정하고 최적화 가이드라인을 받을 수 있기에 beginner들이 활용하면 더 좋을 툴이라고 생각한다. 간단한 Log 예시를 살펴보자.
ncu profiler는 아래와 같은 Log를 보여주는 profiler다.
void at::native::<unnamed>::distribution_elementwise_grid_stride_kernel<float, (int)4, void at::native::templates::cuda::normal_and_transform<float, float, (unsigned long)4, at::CUDAGeneratorImpl *, void at::native::templates::cuda::normal_kernel<at::CUDAGeneratorImpl *>(const at::TensorBase &, double, double, T1)::[lambda() (instance 1)]::operator ()() const::[lambda() (instance 2)]::operator ()() const::[lambda(float) (instance 1)]>(at::TensorIteratorBase &, T4, T5)::[lambda(curandStatePhilox4_32_10 *) (instance 2)], void at::native::<unnamed>::distribution_nullary_kernel<float, float, (int)4, at::CUDAGeneratorImpl *, void at::native::templates::cuda::normal_and_transform<float, float, (unsigned long)4, at::CUDAGeneratorImpl *, void at::native::templates::cuda::normal_kernel<at::CUDAGeneratorImpl *>(const at::TensorBase &, double, double, T1)::[lambda() (instance 1)]::operator ()() const::[lambda() (instance 2)]::operator ()() const::[lambda(float) (instance 1)]>(at::TensorIteratorBase &, T4, T5)::[lambda(curandStatePhilox4_32_10 *) (instance 2)], void at::native::templates::cuda::normal_kernel<at::CUDAGeneratorImpl *>(const at::TensorBase &, double, double, T1)::[lambda() (instance 1)]::operator ()() const::[lambda() (instance 2)]::operator ()() const::[lambda(float) (instance 1)]>(at::TensorIteratorBase &, T4, const T5 &, T6)::[lambda(int, float) (instance 1)]>(int, at::PhiloxCudaState, T3, T4) (864, 1, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 8.0 Section: GPU Speed Of Light Throughput ----------------------- ------------- ------------ Metric Name Metric Unit Metric Value ----------------------- ------------- ------------ DRAM Frequency cycle/nsecond 1.21 SM Frequency cycle/nsecond 1.07 Elapsed Cycles cycle 15606 Memory Throughput % 15.59 DRAM Throughput % 0.00 Duration usecond 14.46 L1/TEX Cache Throughput % 12.67 L2 Cache Throughput % 21.89 SM Active Cycles cycle 13050.31 Compute (SM) Throughput % 56.42 ----------------------- ------------- ------------ OPT This kernel exhibits low compute throughput and memory bandwidth utilization relative to the peak performance of this device. Achieved compute throughput and/or memory bandwidth below 60.0% of peak typically indicate latency issues. Look at Scheduler Statistics and Warp State Statistics for potential reasons.
위에 보이는 것처럼 여러 ThroughPut 수치와 다른 정보들을 보여주는데 OPT라는 log를 주목해야 한다.
여기서 kernel이 low compute throughput을 갖고 있다는 것과 memory bandwidth의 성능이 60%밖에 못 내고 있다는 것을 알려주고 있다. 마지막으로 친절하게 Scheduler Statistics and warp state statistics를 다시 살펴봐라라는 메뉴얼도 제공해주고 있다. 이 정보를 가지고 kernel구현을 보강하는 것을 권장한다.
# 사용법 log만 보고 싶을 때 ncu python train.py # 시각화 정보도 같이 보고 싶을 때 ncu --set full -o output $(which python) train.py
즉, optimization opportunity를 ncu profiler로 제공받을 수 있다!
마무리
이번 글에서는 pytorch → triton → cuda로 내려가면서 optimization을 진행할 때 사용될 수 있는 툴에 대해서 알아보았다. 다음 시간에는 실제로 이 툴들을 활용해서 operator의 최적화를 직접 진행해보는 것을 공유해볼 예정이다.
이 글에서는 udemy의 테크니컬 라이팅 글쓰기를 바탕으로 이전 글들의 보완점을 수정하여 작성되었음을 밝힙니다.
Share article