Wasserstein GAN

Wasserstein GAN

2017-04-11T00:00+02:00

Les GAN (Generative Adversarial Networks) ont déjà été présentés sur ce blog https://data-alchemy.net/blog/post/3(https://data-alchemy.net/blog/post/3). Ils constituent un tout nouveau pan des réseaux de neurones et se sont imposés ces deux dernières années comme le modèle de référence pour générer de la donnée "originale", et plus globalement pour modéliser la distribution d'une donnée revoir notre article sur les VAEs(https://data-alchemy.net/blog/post/5)

Une publication récente a effectué un petit coup de semonce parmi les chercheurs qui travaillent ces architectures. Celle-ci : Wasserstein GAN - Martin Arjovsky Soumith Chintala Léon Bottou(https://arxiv.org/abs/1701.07875) propose un nouveau mode d'apprentissage de ces réseaux qui se distingue car :

  • Il donne un canevas de travail enfin un minimum fiable dans la convergence de ces réseaux, notamment en proposant une fonction d'erreur qui soit réellement reliée à la qualité de la donnée générée, et en assurant une convergence moins hasardeuse qu'auparavant.
  • Il s'inspire de la Théorie Optimale du Transport (Cedric Villani) qui arrive de plus en plus dans le champ théorique des réseaux de neurones pour proposer une approche mathématiquement plus structurée.

(L'immense majorité des graphiques et équations de ce post sont directement issus de la publication originelle)

Différentes distances...

Reprenant le fil depuis notre article sur les VAEs et notamment la définition d'une distribution de données, nous savons qu'un GAN cherche à évoluer de manière à se rapprocher d'une distribution "cible". Reprenant un exemple classique où l'on veut générer des photographies de visages de 32 par 32 pixels, l'espace global est l'ensemble des images possibles de 32x32 pixels, et nous cherchons à "comprendre" (approximer) la distribution des images représentant un visage sur ces dimensions.

À chaque itération, notre réseau de génération représente donc une approximation de cette information que nous allons confronter à un modèle (un dataset d'images de référence) pour savoir comment ce réseau doit évoluer afin de mieux approximer la distribution cible. Afin de savoir exactement "comment" évoluer, il va exploiter une distance mathématique permettant d'exprimer en quoi diffèrent les deux distributions : celle en cours d'apprentissage et celle que nous voulons reproduire. Cette distance est importante car c'est d'elle (où d'une approximation de cette distance) qu'est tirée la fonction de coût dont la dérivée permet de faire évoluer les poids du réseau.

La publication WGAN s'intéresse tout particulièrement à différentes distances :

  • La TV (Total Variation) Distance. Présentée dans la publication à titre de référence, elle n'est pas utilisée en pratique.
  • La Kullback-Leibler Divergence. Cette distance est généralement la plus utilisée dans les approches d'inférence variationelle et a beaucoup été exploitée pour les GANs ou VAEs. Sans trop rentrer dans les mathématiques, cette distance présente plusieurs défauts : elle n'est pas symétrique (la distance entre Pa et Pb n'est pas la distance entre Pb et Pa), et sa définition peut donner lieu à des valeurs "absurdes" dès lors que les deux distributions comparées ne sont pas définies sur le même sous-espace.
  • La Jensen-Shannon Divergence. Cette distance est dérivée de la KL Divergence en adressant un certain nombre de ses soucis. La JS est symétrique (ce qui est déjà quelque peu rassurant) et est aussi moins réductrice quand aux définitions respectives des distributions comparées. Elle est reconnue comme plus utilisée dans le Tutorial GAN NIPS 2016 de Goodfellow, mais sans apporter une énorme amélioration aux problèmes de convergence rencontrés.

Les auteurs de la publication WGAN proposent eux une nouvelle distance : La Earth Mover Distance ou Wasserstein-1. Cette distance vient directement des travaux de Villanie (cocorico) sur la théorie optimale du transport, théorie qui s'introduit de plus en plus dans les travaux en Deep Learning. L'objet de ce blog restant de conserver une distance respectueuse des mathématiques, nous n'en donnerons qu'une intuition : entre deux distributions de données, cette distance exprime le coût nécessaire pour transporter l'information de l'une à l'autre. Cette "definition" intuitive n'est naturellement pas directement exploitable algorithmiquement, mais peut être approximée d'une manière fiable et ainsi être utilisée pour faire converger nos réseaux.

Le schéma ci-dessous, issu de la publication originale, présente un cas usuel dans l'entraînement des GANs. Les points bleus à gauche représentent la donnée réelle que nous voulons générer. Les points verts eux présentent la donnée 'fausse' générée à un instant donné par le GAN. Le discriminateur est donc supposé effectuer une distinction entre ces deux distributions et la distance utilisée doit fournir un gradiant exploitable permettant d'évoluer entre ces distributions.

curves1 La courbe rouge présente donc le gradient d'un discriminateur classique de GAN, et l'on voit que ce dernier est quasiment inexploitable, constant sur les distributions en question. À l'inverse, la courbe bleu claire illustre le gradient présenté par un WGAN, exploitable sur l'ensemble de l'espace de travail.

Résultats

L'immense apport de cette publication est d'offrir une convergence contrôlée, soit : une fonction de coût que l'on puisse réellement suivre lors de l'apprentissage des réseaux. Ci-dessous se trouvent trois apprentissages différents de génération d'image depuis des GANs "classiques". On remarque que l'évolution de la fonction de coût ne présente pas de lien compréhensible avec la progression en qualité du réseau :

curves2 uselesslosses

En comparaison, un WGAN présente lors de son apprentissage une courbe réellement en lien avec la qualité des éléments générés :

curves3 youpi

Dans la continuité de cette propriété, un WGAN est ainsi beaucoup plus stable à entraîner, et beaucoup plus robuste face aux choix de réseaux de génération comme de discrimination. Entre autres exemples, est présenté dans la publication originelle une comparaison entre à gauche un WGAN et à droite un GAN usuel, sans techniques de régularisation avancée qui d'ordinaire sont indispensable à de telles architectures. Le GAN usuel reste incapable de converger vers la distribution cible, là où le WGAN parvient encore à générer des échantillons valables

curves4 compare

Pour continuer le sujet

Depuis la sortie de la publication, deux travaux sont intéressants à noter :

  • (Improved Training of Wasserstein GANs)()[https://arxiv.org/abs/1704.00028]. Cette publication vise à compléter l'analyse des WGANs et à proposer une meilleur approche optimisant les chances de réussite de l'apprentissage du réseau. Les auteurs ont pu ainsi entraîner une large quantité d'architectures de réseaux disponibles avec une recherche d'hyper-paramètres particulièrement limitée
  • (BEGAN: Boundary Equilibrium Generative Adversarial Networks - David Berthelot, Thomas Schumm, Luke Metz)()[https://arxiv.org/abs/1703.10717]. Ré-utilisant la distance de Wasserstein et proposant une nouvelle approche, cette architecture est considérée aujourd'hui comme celle présentant le plus d'interêt, en attendant la prochaine.