Cuando se usa el compilador de PyTorch, los modelos pueden correr hasta 10 veces más rápido. Pero, ¿qué pasa por debajo?. Sin compilación, la GPU ejecuta un kernel, una función sobre la GPU, por cada operación de torch en el código. Eso genera dos cuellos de botella grandes: el tiempo gastado moviendo datos en memoria y el overhead de iniciar cada nuevo kernel. Cada vez que la GPU lanza un kernel paga un costo fijo, y cada resultado intermedio significa escribir y leer memoria global.
Ahí entra la fusión. El compilador Inductor de PyTorch agrupa automáticamente operaciones dependientes en un único kernel Triton eficiente. Eso mantiene los datos en memoria rápida cerca de los registros y reduce el overhead por lanzamiento. En este artículo se revisa un ejemplo concreto y se delinean temas para lectura adicional.
¿Qué es la fusión vertical?
Pensá la fusión vertical como una forma de "encadenar" pasos, de modo que la salida de uno alimenta directamente al siguiente. Se llama "vertical" porque al graficar el computational graph, las operaciones quedan apiladas verticalmente: cada una depende del resultado anterior.
Es el patrón de fusión más común en deep learning porque las redes neuronales son cadenas de operaciones: normalización, capas lineales, funciones de activación, y así. La ganancia grande está en eliminar los resultados intermedios, esos tensores temporales que ya no necesitan escribirse ni leerse de memoria global. Quedan en los registros rápidos donde la GPU los alcanza de inmediato.
Ejemplo de pointwise fusion
Las operaciones pointwise son kernels matemáticos simples que actúan elemento por elemento: suma, multiplicación, funciones de activación. Un patrón típico dentro de una capa neuronal:
import torch
def pointwise_example(x, w, b):
# Multiples operaciones elemento a elemento
tmp = x * w # multiplicar
tmp = tmp + b # sumar
tmp = tmp.sigmoid() # activacion sigmoid
return tmpSin fusion: tres kernels separados
Sin fusión, Inductor crea tres kernels Triton independientes. El patrón clave es que cada kernel carga datos, hace una operación y escribe el resultado.
@triton.jit
def mul_kernel(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK: tl.constexpr):
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex
tmp0 = tl.load(in_ptr0 + x0, xmask)
tmp1 = tl.load(in_ptr1 + x0, xmask)
tmp2 = tmp0 * tmp1
tl.store(out_ptr0 + x0, tmp2, xmask)Para abreviar, solo se incluyen las firmas de los otros dos kernels (ver el repositorio Git para el código completo).
Entre los tres kernels se realizan ocho operaciones de memoria: dos lecturas para multiplicar, dos para sumar (resultado y bias), una lectura para el sigmoid, más las tres escrituras de resultados. Tráfico de memoria considerable.
Con fusion: un solo kernel
Con fusión, torch.compile crea un único kernel que carga todos los inputs una vez, ejecuta las tres operaciones en cadena y guarda solo el resultado final. Los valores intermedios (tmp2 y tmp4) quedan en registros, la memoria más rápida de la GPU, sin tocar nunca la global.
Beneficios concretos
- Lanzamientos de kernel: de 3 a 1.
- Buffers intermedios: 2 eliminados (resultado de multiplicar y de sumar).
- Ancho de banda de memoria: de 8 operaciones (leer 5 tensores, escribir 3) a 4 operaciones (leer 3, escribir 1). Reducción del 50% en tráfico de memoria.
¿Qué otros tipos de fusion usa Inductor?
La pointwise fusion es solo un tipo de fusión vertical. Inductor implementa otras variantes:
Reduction Fusion: combina operaciones de reducción como max, mean o sum con las operaciones que vienen antes y después. Es crítica para batch normalization.
GEMM + Epilogue Fusion: anexa matemática simple al final de cálculos matriciales pesados. En vez de hacer la multiplicación matricial, escribir el resultado en memoria y volver a leerlo para sumar el bias y aplicar ReLU, el bias y la activación ocurren justo después del multiply dentro del mismo kernel.
Prologue Fusion: lo opuesto al epilogue, el preprocesamiento ocurre mientras se cargan los datos. Por ejemplo, normalizar el input antes de una multiplicación matricial al vuelo.
Además de la fusión vertical, el tipo más prominente, Inductor también usa Horizontal Fusion: corre múltiples operaciones independientes sobre el mismo input simultáneamente. Por ejemplo, computar sin(x) y cos(x) en un solo kernel cargando x una sola vez en lugar de dos.
¿Cómo ver la fusion en tu propio código?
Para probarlo, basta con un script breve usando un patrón de reducción:
import torch
def reduction_example(x):
tmp = x * 2.0
result = tmp.sum(dim=-1)
result = result + 1.0
return result
x = torch.randn(1024, 1024, device='cuda')
compiled_fn = torch.compile(reduction_example)
result_fused = compiled_fn(x)Corriendo el script con la variable de entorno TORCH_LOGS="output_code" (TORCH_LOGS="output_code" python fusion_example.py), Inductor imprime los kernels Triton generados en la terminal. Buscando un nombre como triton_per_fused_add_mul_sum_0: el prefijo per indica kernel "por reducción", y el sufijo revela que add, mul y sum fueron fusionadas en uno solo.
Conclusión
La fusión es una de las optimizaciones más importantes de torch.compile. Al enlazar operaciones dependientes dentro de kernels únicos, corta el tráfico de memoria y el overhead de lanzamiento, las dos fuentes principales de lentitud en GPU. No requiere cambiar la implementación: basta agregar el decorator del compilador y dejar que Inductor haga el trabajo.
Más información en la documentación de torch.compile y el repositorio Git con el código fuente del ejemplo.



