Cuando se usa el compilador de PyTorch, los modelos pueden correr hasta 10 veces más rápido. Pero, ¿qué pasa por debajo?. Sin compilación, la GPU ejecuta un kernel, una función sobre la GPU, por cada operación de torch en el código. Eso genera dos cuellos de botella grandes: el tiempo gastado moviendo datos en memoria y el overhead de iniciar cada nuevo kernel. Cada vez que la GPU lanza un kernel paga un costo fijo, y cada resultado intermedio significa escribir y leer memoria global.

Ahí entra la fusión. El compilador Inductor de PyTorch agrupa automáticamente operaciones dependientes en un único kernel Triton eficiente. Eso mantiene los datos en memoria rápida cerca de los registros y reduce el overhead por lanzamiento. En este artículo se revisa un ejemplo concreto y se delinean temas para lectura adicional.

¿Qué es la fusión vertical?

Pensá la fusión vertical como una forma de "encadenar" pasos, de modo que la salida de uno alimenta directamente al siguiente. Se llama "vertical" porque al graficar el computational graph, las operaciones quedan apiladas verticalmente: cada una depende del resultado anterior.

Es el patrón de fusión más común en deep learning porque las redes neuronales son cadenas de operaciones: normalización, capas lineales, funciones de activación, y así. La ganancia grande está en eliminar los resultados intermedios, esos tensores temporales que ya no necesitan escribirse ni leerse de memoria global. Quedan en los registros rápidos donde la GPU los alcanza de inmediato.

Ejemplo de pointwise fusion

Las operaciones pointwise son kernels matemáticos simples que actúan elemento por elemento: suma, multiplicación, funciones de activación. Un patrón típico dentro de una capa neuronal:

Código
import torch

def pointwise_example(x, w, b):
    # Multiples operaciones elemento a elemento
    tmp = x * w        # multiplicar
    tmp = tmp + b      # sumar
    tmp = tmp.sigmoid() # activacion sigmoid
    return tmp

Sin fusion: tres kernels separados

Sin fusión, Inductor crea tres kernels Triton independientes. El patrón clave es que cada kernel carga datos, hace una operación y escribe el resultado.

Código
@triton.jit
def mul_kernel(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK: tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + x0, xmask)
    tmp1 = tl.load(in_ptr1 + x0, xmask)
    tmp2 = tmp0 * tmp1
    tl.store(out_ptr0 + x0, tmp2, xmask)

Para abreviar, solo se incluyen las firmas de los otros dos kernels (ver el repositorio Git para el código completo).

Entre los tres kernels se realizan ocho operaciones de memoria: dos lecturas para multiplicar, dos para sumar (resultado y bias), una lectura para el sigmoid, más las tres escrituras de resultados. Tráfico de memoria considerable.

Con fusion: un solo kernel

Con fusión, torch.compile crea un único kernel que carga todos los inputs una vez, ejecuta las tres operaciones en cadena y guarda solo el resultado final. Los valores intermedios (tmp2 y tmp4) quedan en registros, la memoria más rápida de la GPU, sin tocar nunca la global.

Beneficios concretos

  • Lanzamientos de kernel: de 3 a 1.
  • Buffers intermedios: 2 eliminados (resultado de multiplicar y de sumar).
  • Ancho de banda de memoria: de 8 operaciones (leer 5 tensores, escribir 3) a 4 operaciones (leer 3, escribir 1). Reducción del 50% en tráfico de memoria.

¿Qué otros tipos de fusion usa Inductor?

La pointwise fusion es solo un tipo de fusión vertical. Inductor implementa otras variantes:

Reduction Fusion: combina operaciones de reducción como max, mean o sum con las operaciones que vienen antes y después. Es crítica para batch normalization.

GEMM + Epilogue Fusion: anexa matemática simple al final de cálculos matriciales pesados. En vez de hacer la multiplicación matricial, escribir el resultado en memoria y volver a leerlo para sumar el bias y aplicar ReLU, el bias y la activación ocurren justo después del multiply dentro del mismo kernel.

Prologue Fusion: lo opuesto al epilogue, el preprocesamiento ocurre mientras se cargan los datos. Por ejemplo, normalizar el input antes de una multiplicación matricial al vuelo.

Además de la fusión vertical, el tipo más prominente, Inductor también usa Horizontal Fusion: corre múltiples operaciones independientes sobre el mismo input simultáneamente. Por ejemplo, computar sin(x) y cos(x) en un solo kernel cargando x una sola vez en lugar de dos.

¿Cómo ver la fusion en tu propio código?

Para probarlo, basta con un script breve usando un patrón de reducción:

Código
import torch

def reduction_example(x):
    tmp = x * 2.0
    result = tmp.sum(dim=-1)
    result = result + 1.0
    return result

x = torch.randn(1024, 1024, device='cuda')
compiled_fn = torch.compile(reduction_example)
result_fused = compiled_fn(x)

Corriendo el script con la variable de entorno TORCH_LOGS="output_code" (TORCH_LOGS="output_code" python fusion_example.py), Inductor imprime los kernels Triton generados en la terminal. Buscando un nombre como triton_per_fused_add_mul_sum_0: el prefijo per indica kernel "por reducción", y el sufijo revela que add, mul y sum fueron fusionadas en uno solo.

Conclusión

La fusión es una de las optimizaciones más importantes de torch.compile. Al enlazar operaciones dependientes dentro de kernels únicos, corta el tráfico de memoria y el overhead de lanzamiento, las dos fuentes principales de lentitud en GPU. No requiere cambiar la implementación: basta agregar el decorator del compilador y dejar que Inductor haga el trabajo.

Más información en la documentación de torch.compile y el repositorio Git con el código fuente del ejemplo.