Destilación de modelos

Rubén Rodríguez Abril

La destilación de modelos es una técnica de transferencia de conocimientos entre dos inteligencias artificiales. Intervienen en ella dos modelos: el modelo profesor, que es por lo general de gran tamaño y se ejecuta en centros de datos, y el modelo alumno, muy inferior en dimensión, destinado a ser empleado en pequeños dispositivos. El primero entrena con sus salidas al segundo.

INTRODUCCIÓN

En la actualidad, los grandes modelos de lenguaje están compuestos de billones de parámetros y por esta razón sólo son ejecutables (incluso en inferencia) en gigantescos centros de datos en los que centenares de procesadores trabajan paralelamente.

Por esta razón, en el ámbito del aprendizaje profundo, desde el trabajo pionero de Bucilă et al del año 2006, se han abierto nuevas vías de investigación para reducir el tamaño de los modelos de lenguaje y permitir que éstos sean utilizables en teléfonos móviles y dispositivos IoT. De este modo, la IA no se limitaría al ámbito de la computación en la nube (cloud computing), sino que también se extendería a la computación en el borde (edge computing) o computación ultracompacta (tiny computing).

TIPOS DE DESTILACIÓN DE MODELOS

En el presente artículo se introduce al lector en la técnica de la destilación de modelos, en la que un modelo profesor entrena y transfiere conocimientos a un modelo estudiante, con mucho menor número de capas y parámetros, sin una merma significativa en eficiencia y conocimientos.

La destilación de modelo fue introducida por Geoffrey Hinton y su equipo en el trabajo: «Distilling the Knowledge in a Neural Network», del año 2015. Aunque esta técnica se ha usado fundamentalmente en el marco de los modelos de lenguaje, también ha encontrado aplicación en otras tareas de aprendizaje profundo como la clasificación de objetos.

La transferencia de conocimiento puede realizarse a través de múltiples vías. El artículo “Distiller: A Systematic Study of Model Distillation Methods in Natural Language Processing” agrupa a todas ellas en tres grandes categorías:

  • Aumento de datos: creación de datos sintéticos a partir de la modificación de los datos originales.
  • Destilación en capa de predicción: las salidas del profesor son utilizadas para entrenar al alumno.
  • Destilación en capa interna: se produce una transferencia de representaciones de profesor a alumno.

AUMENTO DE DATOS

La base de datos de profesor y alumno es incrementada mediante modificaciones en los datos originales, utilizando diversas técnicas:

  • Modelos de lenguaje enmascarados (Masked Language Models, MLMs). Se enmascaran palabras al azar dentro de una cadena de texto y se usa BERT para completar estos huecos.
  • Aumento aleatorio. Se eliminan, permutan o sustituyen palabras de una oración de forma aleatoria.
  • Traducción inversa. Los textos son traducidos a una lengua extranjera y luego retraducidos a la lengua original.
  • Mixup. Uso de interpolación lineal entre los distintos embeddings.

DESTILACIÓN EN CAPA DE PREDICCIÓN

En un modelo de lenguaje o en un sistema de clasificación de imágenes, las salidas del profesor (que tiene como última capa una función softmax) describen distribuciones de probabilidad sobre categorías y reciben el nombre de etiquetas blandas (soft targets). Al aplicar sobre una de ellas la función argmax, la distribución colapsa en un vector one-hot en el que todos sus componentes son nulos salvo el correspondiente a la categoría más probable, que toma el valor 1. En este caso, se habla de etiquetas duras (hard targets). El profesor se entrena con las etiquetas duras proporcionadas por la base de datos de entrenamiento. El alumno a su vez es entrenado con las etiquetas blandas generadas por el profesor. En ambos casos, se usa la entropía cruzada como función de pérdida.

La utilización de etiquetas blandas permite que el alumno aprenda no sólo de la categoría primaria (la más probable) sino también de la información contenida en la puntuación de las otras categorías. Sin embargo, en muchas ocasiones la puntuación atribuida a las categorías secundarias es bastante baja, lo que provoca gradientes débiles y ralentiza el aprendizaje.

Dos soluciones se han planteado para mitigar este problema:

  • La primera es calcular la función de pérdida a partir de los logits de la última capa lineal (en cuyo caso, la función de pérdida sería la del error cuadrático medio).
  • La segunda es utilizar una temperatura superior a 1, en cuyo caso la función softmax tomaría esta forma:
q_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}

TEMPERATURA

La temperatura T, que aparece en la fórmula, es el factor común por el que se dividen todos los logits de las diversas categorías. Su incremento es una suerte de reescalado y produce el aplanamiento de la distribución de probabilidades, reduciéndose la distancia entre la categoría primaria y el resto. Utilizando T = 1 el vector de salida es más parecido a una etiqueta dura. Con una temperatura alta (p.e. T = 20) el modelo puede aprender las relaciones sutiles entre todas las clases clases.

RESULTADOS

Los autores decidieron realizar experimentos en el ámbito de la clasificación de dígitos utilizando la base de datos MNIST. Se diseñó un modelo profesor compuesto de dos capas ocultas de 1200 unidades con una función de activación ReLU. Este modelo fue entrenado en 60.000 imágenes con dropout (apagado) y mecanismos de aumento de datos (desplazamiento de dos píxeles en todas direcciones). Cometió 67 errores durante la fase de comprobación. El modelo alumno estaba compuesto de dos capas internas de 800 unidades. Sin destilación (es decir, utilizando etiquetas duras –hard targets- y siendo entrenado directamente desde la base de datos), el alumno cometió 146 errores. Con destilación (entrenamiento a partir de etiquetas blandas –soft targets- generadas por el profesor, a una temperatura de 20) los fallos descendieron a 74. Los investigadores encontraron que la temperatura T = 20 era la idónea para esta magnitud de modelos.

Además, se llevó a cabo un experimento extremo con resultados sorprendentes: durante el entrenamiento del alumno, las imágenes relativas al número 3 fueron eliminadas por completo de la base de datos. Sin embargo, cuando en inferencia se le presentaron primera vez imágenes del dígito 3 (que nunca había visto antes), fue capaz de clasificarlas correctamente, con un margen de error sorprendentemente bajo. Este fenómeno sugiere que el modelo pudo extrapolar la información secundaria de las puntuaciones de los dígitos 2 y 7, que son los más cercanos, gráficamente hablando, al dígito 3.

DESTILACIÓN DE REPRESENTACIONES INTERMEDIAS

En los transformers, la información relativa a niveles superiores de significación (conceptos abstractos, estilo de escritura, posible ideología o estado de ánimo del autor de un texto) es almacenada y procesada en los pesos sinápticos de las capas intermedias. Por el contrario, la salida y capas adyacentes se encargan del procesamiento de la información de bajo nivel (relacionada, fundamentalmente, con fonemas, caracteres individuales o incluso pares de bytes dentro del formato Unicode).

Este fenómeno hace recomendable aplicar la destilación no sólo en la capa de salida, sino también en las capas intermedias. Sin embargo, ello viene complicado por el hecho de que el modelo alumno tiene un tamaño inferior al del modelo profesor y no existe una correspondencia plena entre las capas de ambas redes. Para soslayar este problema se utilizan varias estrategias de mapeo de capas:

CONEXIÓN DIRECTA DE CAPAS SELECCIONADAS

Los diseñadores del entrenamiento escogen capas del modelo profesor y las emparejan con capas del modelo alumno. Como función de pérdida se suele aplicar una función de tipo L2 o la divergencia de Kullback-Leibler. Para que este mapeo sea implementable las dos capas deben de tener el mismo número de unidades/neuronas.

Es el modelo usado en TinyBERT (Jiao et al., 2020).

PONDERACIÓN DE CAPAS DEL PROFESOR

Cada capa del modelo alumno se vincula a varias de las capas del modelo profesor, cuyas salidas son ponderadas. Así, por ejemplo, la capa 4 de un modelo alumno podría quedar vinculada, mediante la función de pérdida, a una ponderación de las capas 3, 6 y 9 del modelo profesor:

Hponderada4 = αHprofesor3 + βHprofesor6 + γHprofesor9

funcion_perdida(Halumno4,Hponderada4)

donde los coeficientes α, β, γ pueden ser constantes o variables aprendidas.

Este mecanismo se usa en DistilBERT (Sanh et al, 2019), donde el modelo alumno tiene 6 capas y el modelo profesor 12.

PROYECCIÓN DE CAPAS DEL PROFESOR

Las salidas de las capas del profesor son sometidas a una transformación lineal (multiplicación matricial) y tras ello son comparadas con la correspondiente capa del alumno a través de la función de pérdida.

Hproyeccion4 = WHprofesor4 + b

funcion_perdida(Halumno4,Hproyeccion4)

Sistema adoptado por MobileBERT (Sun et al, 2020).

La matriz W y el sesgo b son aprendidas. Su valor puede ser determinado de dos maneras: simultáneamente con los pesos sinápticos de profesor y alumno (entrenamiento de una etapa) o bien en una fase anterior (entrenamiento de dos etapas).

En el primero de los casos, la función de pérdida tiene esta estructura:

\mathcal{L} = \mathcal{L}_{\text{entropia-cruzada}} + \lambda \left\| W H_{\text{profesor}} – H_{\text{alumno}} \right\|^2

El primer término de entropía cruzada se aplica en la capa de salida del alumno, en base a las etiquetas blandas proporcionadas por la profesor. El segundo término se aplica en la capa de conexión del modelo alumno.

El entrenamiento de dos etapas tiene el inconveniente de que durante su primera fase, en la que se construyen espacios de representación común entre ambos modelos, los pesos del alumno no se modifican y siguen siendo aleatorios, lo cual dificulta la optimización de W y b. Además, dado que la transformación lineal es una simple operación alebraica, basta utilizar métodos matriciales estándar para determinar los valores de W a partir de los de Hproyeccion y Halumno, sin entrenamiento. Para soslayar estos inconvenientes o bien se preentrena ligeramente al alumno, o bien se usa un modelo iterativo en el cual se va alternando el entrenamiento de los pesos del alumno con el refinamiento de W al concluir el procesamiento de cada lote.

MECANISMO DE ATENCIÓN

Un vector de atención a pondera la contribución de las diferentes capas del profesor en el entrenamiento de una capa del alumno (en este caso, la cuarta):

Hatencion4 = ΣaiHprofesori

funcion_perdida(Halumno4,Hatencion4)

Los componentes ai son aprendidos. Una función softmax asegura que la suma de todos ellos sea igual a 1. Como el lector podrá comprobar, este sistema es una generalización del mecanismo de ponderación de capas.

El mecanismo de atención fue usado en el modelo MiniLM (Wang et al., 2020), entre otros.

METADESTILACIÓN

El artículo “Distiller: A Systematic Study of Model Distillation Methods in Natural Language Processing” propone un enfoque integral en el que se combinan todas las técnicas citadas para crear un procedimiento (pipeline) de metadestilación, denominado Distiller. En este marco, la función de pérdida es una suma de términos relativos al aumento de datos, la destilación en capa intermedia y la destilación en capa final (o de predicción). El entrenamiento pondera dinámicamente la importancia de todos los sumandos.

Figura 1. Esquema del “pipeline” de Distiller. En verde se describen los tres grandes procedimientos de destilación: aumento de datos, destilación en capa intermedia y destilación en capa final. Fuente He et al.

CONCLUSIÓN

Las técnicas de destilación están posibilitando la creación la creación de modelos de lenguaje eficientes con cada vez menores requisitos de memoria y de capacidad, que pueden ser utilizados en dispositivos móviles o en entornos industriales con software restringido. En el futuro, se espera que nuevas técnicas, como la de la destilación adaptativa, posibiliten al alumno decidir qué partes del conocimiento del profesor necesita retener. En el próximo artículo, analizaremos otro grupo de técnicas cuyo enfoque reside no tanto el entrenamiento de una red a partir de otra como en la compresión de los pesos sinápticos de un modelo.

SERIES