Pourquoi l'entraînement distribué est difficile

Lorsqu'on fragmente (shard) un tenseur sur un groupe de processus, chaque gradient qui remonte à travers ce fragment doit être identique à celui qu'on aurait obtenu sur un seul GPU. Faire cela manuellement oblige à disséminer des collectives dans le modèle, à gérer les hypothèses de placement à l'intérieur des opérateurs et à maintenir des chemins de code séparés pour FSDP, le parallélisme de tenseur et le parallélisme de pipeline. C'est étonnamment facile de se tromper, et les bugs sont presque toujours silencieux.

DTensor tente d'unifier ces préoccupations.

DTensor (Distributed Tensor de PyTorch) attache à chaque tenseur une petite métadonnée décrivant son placement : Replicate, Shard(dim) ou Partial(sum). Les opérateurs propagent ensuite automatiquement ces placements et insèrent les collectives appropriées quand les tenseurs doivent se déplacer entre les mises en page. En théorie, cela donne des abstractions plus propres et un passage à l'échelle plus sûr. En pratique, cela résout une classe de problèmes et en crée une autre.

Quatre tentatives pour paralléliser un module de trois lignes

Une illustration concrète est fournie avec un module de modulation tiré d'un transformateur de diffusion. Ce module projette un conditionnement (timestep, étiquette de classe, caractéristiques de texte…) en une échelle par canal qu'il multiplie dans les activations des jetons (tokens). L'objectif : fragmenter les tokens sur un groupe de processus, calculer localement, rassembler le résultat et produire des gradients identiques à la référence mono-GPU.

  • Tentative 1 : torch.chunk et all_gather. Le résultat avant est correct, mais tous les gradients sont faux. Le problème vient de la rétropropagation de torch.chunk : localement, chaque rang place le gradient entrant dans la tranche correspondante et remplit de zéros le reste. Mais dans le contexte distribué, on a besoin du gradient complet sur chaque rang, et chunk ignore l'existence des autres rangs.

  • Tentative 2 : un scatter personnalisé. On remplace torch.chunk par une fonction autograd personnalisée dont la rétropropagation all-gather et concatène les gradients partiels. Cette fois, tokens.grad est cohérent entre les rangs, mais il est exactement le double de la ligne de base. La cause : all_gather appelle reduce_scatter dans sa rétropropagation (somme entre rangs puis division), mais le gradient amont est identique sur les deux rangs (la perte est calculée sur la sortie rassemblée et répliquée), donc la somme double chaque valeur.

  • Tentative 3 : un all_gather-vers-replicate personnalisé. On écrit une deuxième fonction autograd dont la rétropropagation est simplement un chunk (chaque rang prend sa propre tranche du gradient amont, sans réduction). tokens.grad correspond enfin à la référence ! Mais cond.grad et weight.grad restent erronés : différents sur chaque rang, leur somme donne la valeur correcte. La raison : cond est répliqué, mais sur chaque rang il n'interagit qu'avec la tranche locale des tokens ; chaque rang ne contient donc que la contribution de sa moitié du travail. Quand un tenseur répliqué est consommé avec un tenseur fragmenté, son gradient atterrit comme une somme partielle et nécessite une réduction explicite.

  • Tentative 4 : copie-parallèle pour cond et self.weight. On ajoute une troisième fonction autograd : identité à l'aller, all_reduce au retour. Les deux tenseurs répliqués consommés avec les tokens fragmentés sont enveloppés. Maintenant, tous les gradients (tokens.grad, cond.grad, weight.grad) correspondent à la référence.

Le tableau de bord des quatre tentatives montre qu'avec trois fonctions autograd personnalisées, le parallélisme est désormais couplé à la forme exacte de la passe avant.

Coûts cachés et couplage à grande échelle

Même lorsque DTensor garantit la correction mathématique, l'abstraction a un prix. L'article technique souligne que « DTensor rend l'entraînement distribué correct en attachant des métadonnées de placement à chaque tenseur. À grande échelle, cela peut aussi introduire des coûts qui érodent silencieusement le débit, sauf si on les conçoit en connaissance de cause. »

En pratique, les métadonnées de placement doivent être propagées et vérifiées pour chaque opération, ce qui ajoute une surcharge de calcul et de communication. De plus, le fait que le parallélisme soit « automatique » peut masquer des inefficacités : un développeur peut croire que l'abstraction gère tout, alors que la disposition exacte des tenseurs (sharding, réplication, somme partielle) influence fortement les collectives insérées et donc le débit réel.

Vers des conceptions qui tiennent compte des coûts

L'article technique, publié le 18 mai 2026 par Wei Zhang, ingénieur de recherche chez Runway, sert d'avertissement : sans une compréhension fine des mécanismes sous-jacents (quand insérer un all_reduce plutôt qu'un reduce_scatter, comment éviter les doubles comptages), les abstractions distribuées peuvent donner l'illusion de la simplicité tout en imposant des pénalités de performance. La leçon est que la correction mathématique ne garantit pas l'efficacité, et que les concepteurs de modèles doivent rester vigilants, même avec des outils comme DTensor.

Implications pour l'industrie

Cette analyse a des implications directes pour les équipes d'ingénierie qui travaillent sur des modèles de grande taille (LLM, diffusion, etc.). L'entraînement distribué est un domaine où les bugs de gradient sont coûteux et difficiles à détecter. Les abstractions comme DTensor apportent une sécurité importante, mais elles ne dispensent pas d'une compréhension approfondie des compromis entre correction, performance et complexité de code. Runway, une entreprise spécialisée dans les modèles génératifs visuels, souligne ainsi que la maîtrise des collectives et des métadonnées de placement reste cruciale pour l'entraînement efficace à grande échelle.