Modelos de difusión guiados

Rubén Rodríguez Abril

En los modelos de difusión guiados, la media de cada pixel es perturbada por un algoritmo tras la generación de la imagen. La guía puede ser realizada por un clasificador externo o por espacios de representaciones comunes de texto-imagen como CLIP. En este artículo analizaremos el funcionamiento de GLIDE, uno de los modelos más influyentes de la actualidad, precursor en muchos aspectos de herramientas de generación de arte como Stable Diffusion o Midjourney.

Introducción

Tal y como vimos en los dos artículos anteriores, en cada paso del proceso de denoising, el modelo de difusión calcula para cada pixel la media y la desviación típica de una distribución gaussiana. El valor del pixel se obtiene haciendo muestreo sampling de dicha distribución.

En los sistemas de difusión guiados dicha media es perturbada de acuerdo con patrones de información proporcionados por el propio usuario, como la categoría deseada para la imagen o una descripción textual.

La guía (perturbación) puede realizarse por un clasificador externo, por la propia red generadora, o por un espacio de representación común de texto e imágenes, como CLIP. A continuación, analizaremos cada uno de los tres casos.

GLIDE1

Figura 1. Las imágenes de arriba fueron creadas por el modelo principal de GLIDE a partir de la descripción “a stained glass window of a panda eating bamboo” (“vidriera de un panda comiendo bambú”). La red no fue guiada en absoluto. Fuente: GLIDE.

Difusión guiada con clasificador externo

Durante la inferencia, en cada paso de denoising la imagen producida es sometida a un clasificador auxiliar. Esta segunda red, que ha sido entrenada previamente, clasifica la representación usando una distribución probabilística y tiene como capa de salida a la función softmax. La función de pérdida es la probabilidad logarítmica de la categoría correcta. Su gradiente respecto a las activaciones (no respecto a los pesos sinápticos) se retropropaga a través de la red hasta llegar a la entrada. Señala en qué sentido deben modificarse los píxeles de la imagen para que dicha probabilidad logarítmica sea máxima. Este gradiente perturba la distribución probabilística de la U-Net de acuerdo con la siguiente ecuación:

\hat{\mu}_\theta(x_t|y) = \mu_\theta(x_t|y) + s \cdot \Sigma_\theta(x_t|y) \nabla_{x_t} \log p_\phi(y|x_t)

donde xt son los datos de entrada, y la clase seleccionada, μ’θ(xt|y) es la media perturbada, μθ(xt|y) la media primitiva, Σθ(xt|y) la desviación típica y xxlogpφ(y|xt) el valor del gradiente de la función de pérdida en el pixel de que se trate.

Por ejemplo, si la clase es “tormenta”, el clasificador, a través del gradiente xxlogpφ(“tormenta”|xt) ajusta la imagen xt producida por la U-Net para que sea lo más parecido posible a una tormenta. La variable s, denominada coeficiente de guía, controla este ajuste, y el incremento de la misma aumenta la calidad de las imágenes producidas, aunque a costa de disminuir su la diversidad.

Difusión guiada sin clasificador externo

El trabajo Ho & Salimans del año 2021 propuso prescindir por completo de clasificadores externos y que fuese el propio modelo generativo el que se guiara a sí mismo. Durante el entrenamiento, el modelo aprendía a generar imágenes tanto con etiquetas como sin ellas. Esto último se realizaba sustituyendo a la clase y por una etiqueta nula 0. La media εθ(xt|y) de la distribución es perturbada del siguiente modo:

\hat{\epsilon}_\theta(x_t|y) = \epsilon_\theta(x_t|\emptyset) + s \cdot \left(\epsilon_\theta(x_t|y) – \epsilon_\theta(x_t|\emptyset)\right)

El mismo esquema es aplicable las imágenes generadas durante el proceso de reversión por medio de una descripción de texto o prompt (c):

\hat{\epsilon}_\theta(x_t|c) = \epsilon_\theta(x_t|\emptyset) + s \cdot \left(\epsilon_\theta(x_t|c) – \epsilon_\theta(x_t|\emptyset)\right)

Como puede comprobarse leyendo la ecuación, cuanto más cercano a 1 sea el valor de s mayor será la importancia del prompt en la generación de la imagen.

GLIDE2

Figura 2. Cuando GLIDE es guiado sin clasificador la mejora en las imágenes producidas es evidente. Fuente: GLIDE.

Difusión guiada por CLIP

CLIP, analizado en uno de nuestros artículos anteriores, construye espacios semánticos comunes para texto e imagen denominados espacios latentes compartidos. Dos codificadores transforman imágenes y cadenas de caracteres en vectores, respectivamente:

-Como codificador de imagen puede utilizarse una red residual (como ResNet-50) clasificadora o un transformer visual (ViT). En el primero de los casos se usa como representación el conjunto de activaciones de su última capa lineal. En el segundo caso, se usa el token inicial [CLS].

-Como codificador de texto se usa un modelo de lenguaje. En este caso, el valor del token final [EOS] antes de pasar por la función softmax es usado para representar a toda la cadena de texto.

Las representaciones obtenidas por cada uno de ellos se someten a una transformación lineal mediante matrices (Wi para imágenes y Wc para texto) y tras ello se obtienen vectores en el espacio compartido. La cercanía semántica (de 0 a 1) entre una imagen xt y su descripción textual ci se obtiene mediante el producto escalar de ambos vectores:

f(xt) = Wi(codificador_imagenes(xt))

g(c) = Wc(codificador_texto(c))

cercanía = f(xt)·g(c)

En los modelos de difusión, el gradiente de este producto escalar es utilizado para guiar al proceso de reversión:

\hat{\mu}_\theta(x_t|c) = \mu_\theta(x_t|c) + s \cdot \Sigma_\theta(x_t|c) \nabla_{x_t} \left(f(x_t) \cdot g(c)\right)

El gradiente es retropropagado a través del codificador de imagen hasta su capa inicial. Y tras ello, se usa para perturbar, en cada pixel, la media que previamente había sido calculada por el modelo de difusión.

GLIDE3

Figura 3. Resultados usando una guía basada en CLIP. Fuente: GLIDE.

GLIDE

GLIDE (Nichol et al, 2022) es uno de los ejemplos más tempranos de modelos de difusión guiados.

Arquitectura

En el modelo se utilizan cuatro redes:

-Un modelo principal genera una imagen de 64×64 mediante un proceso de limpieza de ruido (denoising) condicionado a una cadena textual. 3,5B parámetros.

-Otra segunda red, de incremento de resolución (upsampling), transforma esa imagen en otra de 256×256, también mediante un proceso de limpieza condicionado. 1,5B parámetros.

-Un codificador de texto (al estilo de GPT) convierte una cadena de texto en una matriz con tokens cargados de información semántica, que se inyectan en el modelo principal. 1,2B parámetros.

-La última es una red de guiado. Se trata de un transformer visual cuya función es de guiar a la primera y perturbar la media de cada pixel.

Codificador de texto

Es un transformer. La representación de la cadena es la matriz de contexto, compuesta de K tokens, en la forma que tiene al salir del último módulo y justo antes de pasar por la última capa lineal y la función softmax. El token final [EOS] es utilizado para codificar la clase, proporcionando un resumen de la cadena de texto.

Dos redes generadoras

Son dos modelos de difusión: El primero de ellos crea una imagen de 64×64 a partir de ruido, y condicionado por una descripción textual (prompt). El segundo de ellos crea una imagen de 256×256, también a partir de ruido. Está condicionado no sólo por la descripción textual sino también por la imagen de 64×64 generada por la primera red.

Ambas redes se basan en el modelo ADM (Ablated Diffusion Model), que es una versión modificada de la U-Net que incorpora un mecanismo de atención, consistente en una capa de atención global situada a una resolución de 8×8. Está dotada de una única cabeza. Hay un vector-consulta por cada pixel. Sus componentes son los valores que toma el pixel en cada capa de característica. El contexto de la atención (es decir, los vectores clave y valor) vienen conformados por los vectores gráficos (los correspondientes a los pixeles) concatenados con los tokens producidos por el transformer. Una matriz de proyección asegura que estos últimos tengan la misma dimensión que los primeros.

Por otro lado, la información del embedding temporal, que codifica el paso/momento concreto en que se haya el proceso de difusión, y del embedding de clase (el token [EOS]) es integrada en la arquitectura mediante modulación de los mapas de características, en el modo señalado en el artículo anterior.

Red auxiliar de guía

Es un transformer visual 64×64 con CLIP (ViT-L CLIP) entrenado previamente. El gradiente de su función de pérdida CLIP es retropropagado a través del codificador de imagen hasta su capa inicial. La derivada no se calcula respecto de los pesos sinápticos sino respecto de los valores de activación, que son modificados. Los valores del gradiente en la primera capa son utilizados para perturbar la media de la red generadora principal.

Entrenamiento

Las bases de datos utilizadas para entrenar a GLIDE son las mismas utilizadas DALL-E. En todos los casos se trata de imágenes acompañadas de sus respectivas descripciones textuales.

El entrenamiento de la red principal y la red de upsampling abarcó 2,5 y 1,6 millones de iteraciones respectivamente. En este contexto, una iteración se define como el procesamiento de un lote durante una etapa del proceso de limpieza de ruido. Se usaron lotes de 2048 imágenes.