CLIP 模型

Photo by Farhan Khan on Unsplash
Photo by Farhan Khan on Unsplash
CLIP(Contrastive Language-Image Pre-training)是由 OpenAI 於 2021 年提出的模型。它透過融合視覺與語言共同編碼達成強大的泛化能力,並具有廣泛的潛在用途。本文章將介紹 CLIP 的理論與實作。

CLIP(Contrastive Language-Image Pre-training)是由 OpenAI 於 2021 年提出的模型。它透過融合視覺與語言共同編碼達成強大的泛化能力,並具有廣泛的潛在用途。本文章將介紹 CLIP 的理論與實作。

完整程式碼可以在 下載。

架構

OpenAI 於 2021 年提出 CLIP 模型。其設計的動機源於傳統圖像分類模型受限於固定的分類標籤,缺乏泛化能力。為克服此限制,CLIP 採用一種全新的預訓練方法,透過對比式學習(contrastive learning)從大量圖像與文本配對資料中學習視覺與語言共同的表徵空間(joint embedding space)。這種方法使 CLIP 在從未訓練過的任務上也能表現出卓越的零樣本(zero-shot)泛化能力。

CLIP 通過聯合訓練影像編碼器(image encoder)與文本編碼器(text encoder),學習一個多模態嵌入空間(multi-modal embedding space),以最大化每個真實配對影像與文本嵌入的餘弦相似度(cosine similarity),同時最小化錯誤配對的 embedding 的 cosine similarity,並對這些 similarity 分數使用對稱交叉熵損失(symmetric cross entropy loss)進行優化。

Contrastive Pre-training (source from Learning Transferable Visual Models From Natural Language Supervision)
Contrastive Pre-training (source from Learning Transferable Visual Models From Natural Language Supervision)

下面的 pseudo code 顯示,在 pre-training 時,如何利用 text 和 image 的 embeddings 來計算 loss。

# image_encoder - ResNet or Vision Transformer
# text_encoder - CBOW or Text Transformer
# I[n, h, w, c] - minibatch of aligned images
# T[n, l] - minibatch of aligned texts
# W_i[d_i, d_e] - learned proj of image to embed
# W_t[d_t, d_e] - learned proj of text to embed
# t - learned temperature parameter
# extract feature representations of each modality
I_f = image_encoder(I) #[n, d_i]
T_f = text_encoder(T) #[n, d_t]
# joint multimodal embedding [n, d_e]
I_e = l2_normalize(np.dot(I_f, W_i), axis=1)
T_e = l2_normalize(np.dot(T_f, W_t), axis=1)
# scaled pairwise cosine similarities [n, n]
logits = np.dot(I_e, T_e.T) * np.exp(t)
# symmetric loss function
labels = np.arange(n)
loss_i = cross_entropy_loss(logits, labels, axis=0)
loss_t = cross_entropy_loss(logits, labels, axis=1)
loss = (loss_i + loss_t)/2

文本編碼器(Text Encoder)

CLIP 的 text encoder 是採用 Transformer 的 encoder 架構。如果還不熟悉 Transformer 的話,請先參考以下文章。

CLIP 採用以下的參數設定:

  • Layers: 12
  • Attention heads: 8
  • d_model: 512
  • Tokenizer: Byte Pair Encoding (BPE)
  • Vocab size: 49152
  • Max sequence length: 76
  • Total parameters: 63M

另外,CLIP 對於 Transformer 的 encoder 有些許的修改。此修改方式是參照 GPT 對 Transformer 的 decoder 的修改。如,將 LayerNorm 改為置於每個 sub-block 之前(pre-LM 結構)。

Text encoder 將 text sequence 以 [SOS][EOS] token 包圍,並在 Transformer 的 encoder 的最後 layer 的 activations 中,取 [EOS] 位置的 activation 作為 text 的特徵表徵(feature representation)。

OpenAI 有開放 CLIP 的原始碼,請參照這裡。以下的 text encoder 實作,是基於官方的原始碼改寫的精簡版。

import torch
import torch.nn as nn


class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model, n_head):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model)
        )
        self.ln_2 = nn.LayerNorm(d_model)

    def forward(self, x):
        """
        Parameters
        ----------
        x: Tensor of shape [batch_size, sequence_length, d_model]
        """
        attn_out, _ = self.attn(self.ln_1(x), self.ln_1(x), self.ln_1(x))
        x = x + attn_out
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    def __init__(self, width, layers, heads):
        super().__init__()
        self.resblocks = nn.Sequential(
            *[ResidualAttentionBlock(width, heads) for _ in range(layers)]
        )

    def forward(self, x):
        return self.resblocks(x)


class TextEncoder(nn.Module):
    def __init__(self, vocab_size, context_length, width, layers, heads):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, width)
        self.positional_embedding = nn.Parameter(torch.empty(context_length, width))
        nn.init.normal_(self.positional_embedding, std=0.01)
        self.transformer = Transformer(width, layers, heads)
        self.ln_final = nn.LayerNorm(width)
        self.context_length = context_length

    def forward(self, text):
        """
        Parameters
        ----------
        text: Tensor of shape [batch_size, context_length] containing token indices.
        """

        x = self.token_embedding(text)  # [batch_size, context_length, width]
        x = x + self.positional_embedding
        x = x.permute(1, 0, 2)  # [context_length, batch, width]
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # [batch, context_length, width]
        x = self.ln_final(x)
        return x[:, -1, :]  # [batch, width]

影像編碼器(Image Encoder)

OpenAI 在 CLIP 論文中,比較了 5 種基於 ResNets 和 3 種基於 Vision Transformers 的不同的 image encoder 的效能表現,其分別如下:

  • ResNet-50
  • ResNet-101
  • RN50x4: 擴展 4 倍的 ResNet-50
  • RN50x16: 擴展 16 倍的 ResNet-50
  • RN50x64: 擴展 64 倍的 ResNet-50
  • ViT-B/32
  • ViT-B/16
  • ViT-L/14

論文中的 zero-shot 效能測試結果中,Vision Transformer(ViT)架構優於 CNN 架構(ResNet 系列)。而在 3 個 ViT 的模型中,ViT-L/14 的表現最佳,遠高於其他模型,其次為 ViT-B/16 和 ViT-B/32。這顯示 Vision Transformer 在大量預訓練數據下有極佳的泛化能力。

在 CLIP 的原始碼中,作者直接實作 Vision Transformer。下面的程式碼中,我們直接使用 torchvision 中的 vit_b_16。在 constructor 中,我們移除了 vit_b_16 模型的 classification head。

from torchvision.models import vit_b_16, ViT_B_16_Weights


class ImageEncoder(nn.Module):
    def __init__(self, output_dim=512):
        super().__init__()
        self.vit = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
        self.vit.heads = nn.Identity()  # [batch, 768]
        self.proj = nn.Linear(self.vit.hidden_dim, output_dim)

    def forward(self, x):
        features = self.vit(x)  # [batch, 768]
        features = self.proj(features)  # [batch, output_dim]
        return features

CLIP

下面的程式碼中的 CLIP 整合了上述的 text 和 image encoders。我們可以看到,此程式碼其實就是先前 pesudo code 的實作。

class CLIP(nn.Module):
    def __init__(self, image_embed_dim, text_vocab_size, context_length, text_width, text_layers, text_heads):
        super().__init__()
        self.image_encoder = ImageEncoder(image_embed_dim)
        self.text_encoder = TextEncoder(text_vocab_size, context_length, text_width, text_layers, text_heads)
        self.logit_scale = nn.Parameter(torch.ones([]) * torch.log(torch.tensor(1 / 0.07)))

    def encode_image(self, image):
        return self.image_encoder(image)

    def encode_text(self, text):
        return self.text_encoder(text)

    def forward(self, image, text):
        image_features = F.normalize(self.encode_image(image), dim=-1)
        text_features = F.normalize(self.encode_text(text), dim=-1)
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()
        return logits_per_image, logits_per_text

以下程式碼顯示如何使用 CLIP。

if __name__ == "__main__":
    vocab_size = 49408
    context_length = 77
    image_embed_dim = 512
    text_width = 512
    text_layers = 12
    text_heads = 8

    model = CLIP(
        image_embed_dim=image_embed_dim,
        text_vocab_size=vocab_size,
        context_length=context_length,
        text_width=text_width,
        text_layers=text_layers,
        text_heads=text_heads
    )

    batch_size = 2
    image = torch.randn(batch_size, 3, 224, 224)
    text = torch.randint(0, vocab_size, (batch_size, context_length))

    logits_per_image, logits_per_text = model(image, text)
    print("logits_per_image shape:", logits_per_image.shape)
    print("logits_per_text shape:", logits_per_text.shape)

應用

由於 CLIP 把圖像和文字都嵌入到同一個向量空間,所以我們能做跨模態檢索(cross-modal retrieval)。最直接的應用就是視覺搜尋(visual search),其包含以下三種任務:

  • Image-to-Image Search:用輸入的 image 找出相近的 images。
  • Image-to-Text Search:用輸入的 image 找出相近的 texts。
  • Text-to-Image Search:用輸入的 text 找出相近的 images。

範例

我們將利用 CLIP 來實作 image-to-image search。在此範例中,我們有 5 張電風扇商品圖、5 張冰箱商品圖、和 5 張洗衣機商品圖。

首先,我們先用 CLIP 的 image encoder 將這 15 張商品圖轉換成 embeddings。

import os
from pathlib import Path

import clip
import faiss
import numpy as np
import torch
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

PRODUCT_IMAGES = Path(__file__).parent.parent / "product_images"


def build_database_embeddings(database_folder):
    image_files = [f for f in os.listdir(database_folder) if f.endswith(('.jpg', '.jpeg', '.png'))]
    embeddings = []
    filenames = []

    for img_file in image_files:
        img_path = os.path.join(database_folder, img_file)
        image = preprocess(Image.open(img_path)).unsqueeze(0).to(device)

        with torch.no_grad():
            embedding = model.encode_image(image).cpu().numpy()

        embeddings.append(embedding / np.linalg.norm(embedding))
        filenames.append(img_file)

    embeddings = np.vstack(embeddings).astype('float32')
    return embeddings, filenames


print("Building database embeddings...")
db_embeddings, filenames = build_database_embeddings(PRODUCT_IMAGES)

然後,將這些 embeddings 加入到 FAISS 的索引中。當然你可以使用其他的 vector databases,如 ChromaDBMilvus

def build_faiss_index(embeddings):
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatIP(dimension)  # Inner product similarity
    index.add(embeddings)
    return index


print("Creating FAISS index...")
faiss_index = build_faiss_index(db_embeddings)

最後,我們想要找與 SEARCH_IMAGE 相似的圖。

def search_similar_product(user_image_path, index, filenames):
    image = preprocess(Image.open(user_image_path)).unsqueeze(0).to(device)

    with torch.no_grad():
        query_embedding = model.encode_image(image).cpu().numpy()

    query_embedding /= np.linalg.norm(query_embedding)
    distances, indices = index.search(query_embedding, k=3)
    return [(filenames[i], distances[0][idx]) for idx, i in enumerate(indices[0])]


SEARCH_IMAGE = "./search-image-1.jpg"

print("Searching for similar products...")
results = search_similar_product(SEARCH_IMAGE, faiss_index, filenames)

print("Top matched products:")
for filename, score in results:
    print(f"Product: {filename}, Similarity Score: {score:.4f}")

結語

CLIP 就像是教電腦同時學會看圖和讀字,讓它不僅能理解圖片,也能把文字描述和圖片精準連結起來。透過這種強大的能力,我們能更輕鬆地開發各種圖文互動的應用。

參考

發佈留言

發佈留言必須填寫的電子郵件地址不會公開。 必填欄位標示為 *

You May Also Like