Compiling Torch Code – an Example of Memory Reduction

Starting from version 2.x, PyTorch, a popular deep-learning framework, ships with the JIT compiler torch.compile. In this post I share a non-trivial example demonstrating how this tool can reduce the memory footprint on the GPU. The point of departure is a subroutine which computes a similarity measure, similar to covariance but less friendly to compute.

def similarity(x, y):
    # broadcast to an intermediate tensor of shape (n_samples, n_dim, n_dim)
    xy = x[:, :, None] - y[:, None, :]
    # for each pair of dimensions, count the samples that differ by less than 1
    xy = xy.abs().lt(1).sum(axis=0)
    # move the (n_dim, n_dim) result back to the CPU
    xy = xy.to('cpu')
    return xy

For two tensors of shape \( (n_{samples},n_{dim})\) it produces a similarity tensor of shape \( (n_{dim},n_{dim})\). However, the logic relies on broadcasting, constructing and then reducing an intermediate tensor of shape \( (n_{samples},n_{dim},n_{dim})\). Thus, the naive implementation takes \( O(n_{samples}\cdot n_{dim}^2)\) memory, as seen in the profiler. After compilation, this bottleneck is removed 💪
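
Back-of-the-envelope check: for the shapes profiled below (\( n_{samples}=256\), \( n_{dim}=2000\), float32), the broadcasted intermediate alone should occupy about 3.81 Gb, which matches what the profiler attributes to aten::sub:

# memory of the intermediate tensor of shape (n_samples, n_dim, n_dim) in float32
n_samples, n_dim = 256, 2000
intermediate_bytes = n_samples * n_dim * n_dim * 4  # 4 bytes per float32
print(f'{intermediate_bytes / 2**30:.2f} GiB')  # ~3.81 GiB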

This is what the profiling code looks like:

import torch
from torch.profiler import profile, record_function, ProfilerActivity

# random input: 256 samples of dimension 2000, placed on the GPU
x = torch.randn((256, 2000)).float().cuda()
torch.cuda.synchronize()

#@torch.compile(mode='max-autotune') # compare the effect with and without !
def similarity(x, y):
    # broadcast to an intermediate tensor of shape (n_samples, n_dim, n_dim)
    xy = x[:, :, None] - y[:, None, :]
    # for each pair of dimensions, count the samples that differ by less than 1
    xy = xy.abs().lt(1).sum(axis=0)
    xy = xy.to('cpu')
    return xy

# trace CPU and CUDA activity, tracking memory allocations
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], profile_memory=True) as prof:
    with record_function("memory profile"):
        similarity(x, x)
        torch.cuda.synchronize()

profiler_summary = prof.key_averages().table(sort_by="self_cuda_memory_usage", row_limit=10)
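
Instead of uncommenting the decorator, the compiled variant can also be obtained by wrapping the function explicitly; a minimal sketch (note that the first call triggers compilation, so it pays to warm it up before profiling):

# compile an optimized version with the same call signature
similarity_compiled = torch.compile(similarity, mode='max-autotune')
similarity_compiled(x, x)  # first call compiles; profile subsequent calls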

And this is a useful utility to convert the profiling output into a pandas table:

# torch profiler output to pandas

import pandas as pd
import io
import re

# the first line of the summary is a ruler of dashes; its length gives the table width
total_width = re.search('\n', profiler_summary).start()
# each run of dashes marks a column; add 2 for the separating spaces
widths = [t.end() - t.start() + 2 for t in re.finditer('-{1,}', profiler_summary[:total_width])]

# parse the fixed-width table, promote the first data row to headers,
# then drop the header and ruler rows
df = pd.read_fwf(io.StringIO(profiler_summary), widths=widths)
df.columns = df.loc[0]
df.drop([0, 1], axis=0, inplace=True)
df.set_index(df.columns[0], inplace=True)
df.head(10)
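
The memory column is parsed as strings such as '3.81 Gb'; for sorting or plotting it can be converted to numbers. A small sketch, assuming the profiler's unit suffixes are b/Kb/Mb/Gb:

# convert profiler memory strings like '976.56 Mb' or '0 b' to bytes
def mem_to_bytes(s):
    value, unit = s.split()
    return float(value) * {'b': 1, 'Kb': 2**10, 'Mb': 2**20, 'Gb': 2**30}[unit]

df['Self CUDA Mem'] = df['Self CUDA Mem'].map(mem_to_bytes)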

This is the output without the compiler; note the huge memory overhead of the broadcasting tensor operations:

Name                                                   Self CPU    Self CUDA  Self CUDA Mem  # of Calls
aten::empty_strided                                    32.000us    0.000us    7.63 Gb        2
aten::sub                                              69.000us    4.624ms    3.81 Gb        1
aten::resize_                                          10.000us    0.000us    3.81 Gb        1
aten::lt                                               51.000us    5.857ms    976.56 Mb      1
aten::slice                                            32.000us    0.000us    0 b            4
aten::as_strided                                       7.000us     0.000us    0 b            7
aten::unsqueeze                                        10.000us    0.000us    0 b            2
cudaLaunchKernel                                       185.000us   0.000us    0 b            14
void at::native::elementwise_kernel<128, 2, at::nati…  0.000us     4.624ms    0 b            2
aten::abs                                              43.000us    9.711ms    0 b            2
Profiling without torch.compile

