欢迎您访问 最编程 本站为您分享编程语言代码,编程技术文章!
您现在的位置是: 首页

Flamingo a Visual Language Model for Few-Shot Learning

最编程 2024-03-08 20:12:47
...

Flamingo: a Visual Language Model for Few-Shot Learning

TL; DR:Flamingo 在 VL-adapter 的结构上有创新,Perceiver Resampler + gated xattn,一种看起来比较复杂且高级的将图像特征注入到语言模型的方式。同时,优秀的结构设计使得 Flamingo 能够处理图/文/视交错数据,从而有多模态 few-shot learning (in-context learning) 的能力。


CLIP 的出现使得多模态模型有了 zero-shot 的能力,可以说是多模态领域里程碑式的工作。然而,CLIP 终归是一个表征模型,其能支持的任务也只能是分类、检索这一类任务。本文提出 Flamingo,Flamingo 的结构创新的优势有三:一是可以桥接预训练好的视觉模型和语言模型;二是可以处理任意交错的图文对数据;三是可以同时以图像和视频数据作为输入。有了这些优势,Flamingo 就能在规模巨大的互联网图文交错数据上进行训练。从而,Flamingo 结合 prompt 实现了多模态领域的 few-shot learning (in-context learning) 能力。

先来看下 Flamingo 的 few-shot / in-context learning 能力。通过在 prompt 中提供几个示例,Flamingo 能够理解任务的要求并给出合理的答案。in-context learning 在 NLP 中已经火了一段时间了,由于只有一种模态,NLP 中的 in-context learning 是很直接的,就像小时候试卷上的题目示例一样,可以直接用自然语言描述。而在多模态 in-context learning 中,为了增强模型的对于上下文的理解能力,必须在训练时就不能只有简单的图文对,而是需要加入大量的图文交错数据,这是实现多模态 in-context learning 的关键。

在这里插入图片描述

方法

本节分为模型结构和训练数据两个小节,来介绍 Flamingo 的具体方法。

模型结构

多模态模型的关键在于中间 adapter 层的设计,如何设计一种合理的 adapter 模块,将预训练的视觉编码器和语言模型连接起来。Flamingo 在模型结构方面的创新也是在这里,主要是两个结构:一是 Perceiver Resampler 模块,将任意个数的输入的视觉(视频/图像)特征转换为固定个数的 queries;二是 Gated XATTN-DENSE 模块,将 Perceiver Resampler 输出的视觉 queries,与新插入到 LM 中的层计算交叉注意力,从而将视觉信息注入到 LM 的生成过程中。

Flamingo 的整体结构概览图如下所示,它可以接收任意交错的图/文/视模态的输入数据,并输出文本。

在这里插入图片描述

Perceiver Resampler

Perceiver Resampler 模块的详细结构与伪代码如下图所示。它可以接收任意多个的视频帧(如果是图像,可以视作是单帧视频),经过视觉编码器(文中是 NFNet)提取特征,加上时间 embedding 之后全部展平,得到一个视觉 token 序列 X f X_f Xf X f X_f Xf 的序列长度是任意的。还会有一个可学习的固定长度的 latent query 序列 X X X X f X_f Xf X X X 同时作为输入,送入到 Perceiver Resampler 中计算交叉注意力,注意力的 K 和 V 是 X f X_f Xf X X X 拼接而成,而 Q 就是 X X X 。在经过多层 Attention + FFW 处理之后,输出固定长度(文中是 64)的 latent query,作为视觉表征。这里的 queries 就有点类似于 ViT 中的 cls token,拼接在数据 token 之后,并作为最后的输出,只不过 token 数目更多。

注意这里的 positional embedding 只是 time embedding,即标识图片的时序信息,即图片来自哪一帧,这应该是因为单张图片内的空间位置信息在 vision encoder 中已经处理过了。换个角度想,如果只想做针对图片的编码,并且是用 Transformer 的话,Perceiver Resampler 这里可以处理变长序列,转换为定长序列的特性,再配合上 Transformer 位置编码插值,或许可以用于解决任意图像分辨率的问题。

在这里插入图片描述

Gated xattn dense

Flamingo 将固定长度的视觉 query 注入到语言模型的方法称为 Gated xattn dense,其详细结构示意图及伪代码如下图所示。具体来说,在预训练好的 LM 的各层交替地插入一些随机初始化的交叉注意力层。所谓 gated 门控,在每一新插入的层之后的残差链接之前添加一个 tanh gating,即 tanh ( α ) \text{tanh}(\alpha) tanh(α),其中 α \alpha α 是一个可学习的标量值,初始值为 0,从而保证初始化时的输出与原 LM 一致,思路有点 controlnet 零卷积的感觉?

在这里插入图片描述

训练数据

Flamingo 的训练数据有三类:图文交错数据集、图像文本对数据集和视频文本对数据集。其中图文交错数据集是 Flamingo 数据的重点,其多模态 in-context learning (few-shot learning) 的能力可以说主要就是来自图文交错数据。作者收集了一个大规模图文交错数据集 M3W,通过解析 HTML 获取并标记图片在文本中的位置。

Flamingo 的训练还有一个关键点,就是如何处理图文交错数据。首先每张图片的特征用 vision encoder + resampler 处理得到固定长度的 token 序列,在文本序列中用 <image> 这个 tag 来标记出图像的位置。值得一提的是,每个文本 token 只会 attend 到之前最近的一张图片,而不是所有之前的图片,之前的其他图片会被 mask 掉。

在这里插入图片描述

实验

Flamingo 从模型结构到训练数据的创新还是比较多的,这里重点看一下消融实验,看各个设计哪些是真正有效的。下表展示的是 few-shot 性能的消融实验,第一行是 Flamingo 设计的模型,以下各行是去掉某种设计后的 few-shot 性能,注意每一行都应该与第一行比较。

数据方面:M3W 这种图文交错数据对结果的影响是最大的,不仅是在数据消融中影响最大,而是在所有消融中影响都是最大的,没有的话性能掉了接近 20 个点。而视频文本对数据看起来影响没有那么大。这里比的是 few-shot 的性能,那倒是挺合理的, 图文交错数据应该是多模态 few-shot / in-context learning 的关键。

模型方面:gating、resampler、cross attention layer 的缺失都会对性能有一定影响,但也都没有图文交错数据的影响大。

在这里插入图片描述

总结

Flamingo 的技术创新点很多,图文交错数据训练、multimodal few-shot / in-context learning、Perceiver Resampler、gated xattn dense 等。特别是图文交错数据实现了多模态的 few-shot / in-context learning。是很有价值的一篇多模态语言模型的工作。在 LLM 时代,也有很多多模态大模型会参考 resampler 的结构设计。