PaLM y PaLM 2

Rubén Rodríguez Abril

Este artículo versa sobre uno de los modelos de de transformers de Google, PaLM (caracterizado por su sistema de entrenamiento paralelo, Pathways), y su versión mejorada, PaLM 2.

Aunque Bard, el interfaz de usuario de LaMDA, nació con la intención de rivalizar con el buque insignia de Open AI, ChatGPT, pronto se demostró que sus capacidades lingüísticas estaban muy por debajo de las de este último. Es más, la realización de una demostración en vivo en París, el 8 de Febrero de 2023, de las capacidades de Bard tuvo resultados decepcionantes y provocó incluso un descenso de la capitalización en bolsa de Alphabet ese mismo día. Esto motivó que Sundar Pichai, el CEO de Google, anunciara el 21 de Marzo de 2023 la aparición de una nueva versión de Bard potenciada con un nuevo modelo de lenguaje, PaLM, que es el objeto de estudio de este artículo.

PaLM

Arquitectura de PaLM: novedades

La arquitectura general de PaLM es la de un descodificador autorregresivo, que tiene una estructura similar a la descrita por el artículo original de Vaswani et al, pero con las siguientes particularidades:

Función de activación SwiGLU. Es una combinación de las funciones de activación swish(x) y GLU(x), representadas por las ecuaciones

swish (x) = x · sigmoide(x)

GLU(x) = x + sigmoide(Wx + b)

donde W es una matriz de pesos sinápticos y b es el sesgo. Ambas funciones producen un resultado muy parecido al del rectificador (ReLU), pero con la particularidad de que en ambos casos la función tiene una derivada no nula en todo su dominio, paliando con ello el problema de la desaparición del gradiente. La estructura de la primera función, swish, permanece invariante a lo largo de todo el entrenamiento. En la segunda función, por el contrario, interviene una capa lineal, Wx + b, cuyos parámetros son aprendidos.

Figura 1: Comparación entre las funciones ReLU, GELU y swish.

Capas paralelas. En cada módulo de transformer, la unidad de atención y la capa de normalización están situadas paralela y no serialmente (difiriendo así del modelo original), tal y como se muestra en la imagen. De ello resulta un entrenamiento 15% veces más rápido, sin que haya merma en la calidad alguna de los resultados de la red (al menos, de 540.000 millones de parámetros) en adelante.

Figura 2: A la izquierda, estructura de cada módulo en Vaswani et al. A la derecha, estructura de los módulos de PaLM.

Atención multiconsulta (Multi-Query attention): En cada unidad de atención, las matrices de clave (key) y valor (value) son compartidas por todas las cabezas.

Misma codificación de embedding para la entrada y salida, aunque éste es un rasgo no es propiamente una novedad, sino que es compartido por buena parte de los modelos anteriores.

-En la codificación de contenido (embedding) se utiliza el vocabulario SentencePiece, compuesto de 256k tokens.

Pathways: un mecanismo de entrenamiento paralelo

En la actualidad, la cantidad de parámetros y de datos involucrados en el entrenamiento de un gran modelo de machine learning es tal que es necesario emplear centenares o incluso miles de procesadores (GPUs, NPUs, TPUs) en la tarea, ubicados en grandes centros de datos. Los cálculos de todos estos chips se realizan en paralelo, lo cual impone la utilización de algoritmos que coordinen las operaciones y gestionen eficientemente la transferencia de datos a lo largo de toda la red.

Uno de estos algoritmos es Pathways, que da nombre al modelo de lenguaje que estamos analizando en este mismo artículo (Pathways Language Model, PaLM). En Pathways, las TPUs (Tensorial Processing Units, unidades de procesamiento tensorial) de un centro de datos son agrupadas físicamente en islas (islands), cada una de las cuales tiene su propia topología propia en 2 o 3 dimensiones. Cada isla está controlada por un planificador (scheduler) y todo el centro de datos lo está por un manager de recursos (resource manager).

Las diferentes funciones son ejecutadas por subgrupos de TPUs denominadas lonchas virtuales (virtual slices). La creación de estas lonchas se realiza a petición de la secuencia principal del programa. Cuando la petición llega al manager de recursos, éste agrupa a un número de TPUs de una isla en el modo señalado en la imagen:

Figura 3: Funcionamiento genérico de Pathways. El diagrama del medio señala la estructura física de un centro de datos. Los TPUs (celeste claro) se agrupan en islas (en verde claro), cada una de las cuales está coordinada por un “planificador” (“scheduler”), representado por un asterisco azul. Las lonchas virtuales son creadas dentro de cada isla a petición del programador, se identifican por letras de abecedario (A, B, C) y su misión es ejecutar funciones determinadas. El diagrama de la derecha describe el funcionamiento del sistema a lo largo del tiempo. Dos planificadores dan inicio a la ejecución de las funciones en sus respectivas islas, que a lo largo del tiempo y se transmiten información entre sí (flechas en azul). Fuente: Barham et al.

En PaLM, el esquema es más simple. Sólo hay dos lonchas, denominadas “pods”, cada una de las cuales ocupa una isla. Durante el entrenamiento, los lotes se dividen en dos, y cada pod se encarga de procesar una mitad. En una primera fase, se calculan los gradientes, y en la segunda se actualizan los parámetros del modelo:

Figura 4: Una vez que cada pod calcula los gradientes para cada mitad del lote, transmite los resultados al otro pod. Entonces, se procede a actualizar los parámetros de la red. Fuente: Chowdhery et al.

Proceso de entrenamiento

El modelo es inicializado aleatoriamente y sometido a las fases de preentrenamiento y afinamiento. Durante el preentrenamiento se utilizó la base de datos descrita en Du et al, 2021, consistente en páginas web de alta calidad, libros, entradas de Wikipedia y datos procedentes de redes sociales (utilizados por Adiwardana et al, 2022). Un 77% de los textos estaba redactado en inglés, y el resto en otros idiomas. Por lo que se refiere al código, Java (18,8%), HTML (17,5%) y Javascript (11,6%) fueron los lenguajes predominantes.

La longitud de las cadenas de entrada fue de 2048 tokens. Durante el proceso y en las versiones más grandes del modelo, la longitud de los lotes se incrementó progresivamente de 512 a 1024 y 2048.

Evaluación

PaLM fue sometido nada más y nada menos que a 29 benchmarks (colecciones de pruebas diferentes) en tareas de NLP de lengua inglesa, como SuperGLUE, BIG-Bench, tareas de Winograd, tareas de completado de frases, razonamiento de sentido común, traducción o programación, entre muchas otras.

El modelo mejoró el estado de la técnica en los tres supuestos de aprendizaje sin ejemplos previos (zero-shot), con un sólo ejemplo (one shot) o con algunos de ellos (few-shot).

PaLM 2

PaLM 2, la versión mejorada de PaLM, fue presentado en la conferencia I/O de Google de 10 de Mayo de 2023. Sus principales innovaciones, que fueron expandidas a Bard, son las siguientes:

Multilingüismo: Es capaz de realizar tareas lingüísticas como la traducción, la composición de poemas, la resolución de acertijos en más de 100 lenguajes diferentes, debido en buena parte a la ampliación de la base de datos de entrenamiento a textos redactados en otras lenguas diferentes al inglés.

Razonamiento: Debido a la inclusión de trabajos científicos en sus datos de entrenamiento, el modelo muestra grandes capacidades de deducción lógica y matemática y de razonamiento de sentido común.

Codificación: PaLM 2 muestra grandes capacidades de programación no sólo en lenguajes comunes como Python y Javascript, sino también en otros no tan populares, como Prolog, Fortran y Verilog.

PaLM 2 está disponible en varios tamaños, de menor a mayor: Gecko, Otter, Bison y Unicorn. El primero de ellos, Gecko, es tan ligero, que puede utilizarse localmente en dispositivos móviles, lo que permite su uso en aplicaciones incluso sin conexión.

Las capacidades del modelo han sido aprovechadas para crear versiones del mismo especializados en temas científicos, como, Med-PaLM 2, entrenada para responder a cuestiones médicas, y Sec-PaLM 2, que tiene aplicaciones en el ámbito de la ciberseguridad, particularmente en la detección de scripts maliciosos.