CLIP 모델의 동작 방식을 이해하기 위해 이미지와 텍스트를 직접 입력하는 방식으로 Zero-shot 분류를 실험해보았다.

이번 글에서는 한 단계 더 나아가 실제 데이터셋(Oxford-IIIT Pet Dataset)을 사용해 CLIP의 Zero-shot 분류 성능을 정량적으로 평가하고 Single Prompt 방식과 Template Ensemble 방식의 차이를 비교해보았다.


1. 실험 환경 및 목적

사용 모델

  • openai/clip-vit-base-patch32
  • 추가 학습(Fine-tuning) 없이 Pretrained CLIP 그대로 사용

데이터셋

  • Oxford-IIIT Pet Dataset
  • 고양이 12종 + 개 25종 = 총 37 클래스
  • Test split에서 무작위 2,000장 샘플링

실험 목적

  1. 단일 프롬프트(Single Prompt) 기반 Zero-shot 분류 성능 확인
  2. 여러 텍스트 템플릿을 사용하는 Prompt(Template) Ensemble 효과 검증
  3. 정확도(Top-1 Accuracy)와 Confidence Margin 비교

2. 모델 및 데이터셋 로드

 
device = "cuda" if torch.cuda.is_available() else "cpu" 
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device) 
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") 
model.eval()
  • GPU 사용 가능 시 CUDA 적용
  • 평가 목적이므로 model.eval()로 설정
pet_dataset = OxfordIIITPet( root=os.path.expanduser("~/.cache"),
split='test', target_types='category', download=True )
  • 클래스 라벨은 품종 이름 문자열로 사용
  • CLIP의 텍스트 입력에 그대로 활용 가능

3. 평가 데이터 구성 (무작위 샘플링)

indices = torch.randperm(len(pet_dataset))[:2000].tolist()
images = [pet_dataset[i][0] for i in indices]
gt_labels = [pet_dataset.classes[pet_dataset[i][1]] for i in indices]
  • 전체 테스트셋 중 2,000장 랜덤 추출
  • 실제 분류 문제와 동일하게 정답 클래스 이름 문자열을 기준으로 평가

4. Baseline: Single Prompt 방식

(1) 단일 템플릿 정의

single_template = ["a photo of a {}"]
single_texts_per_class = [[t.format(c)] for c in pet_dataset.classes]
  • 가장 기본적인 CLIP 예제 형태
  • 각 클래스당 하나의 문장만 사용
 

(2) 이미지 & 텍스트 임베딩 계산

 
single_inputs = processor( 
    images=images, 
    text=single_flat_texts, 
    return_tensors="pt", 
    padding=True 
).to(device)
  • 이미지와 텍스트를 동시에 CLIP에 입력
  • 내부적으로 이미지 인코더 / 텍스트 인코더가 분리 처리됨
single_image_features = single_outputs.image_embeds
single_text_features = single_outputs.text_embeds

(3) 정규화 및 유사도 계산

single_image_features = single_image_features / single_image_features.norm(dim=-1, keepdim=True) 
single_text_features = single_text_features / single_text_features.norm(dim=-1, keepdim=True)
single_similarity = single_image_features @ single_text_features.T
  • CLIP은 코사인 유사도 기반
  • 정규화는 필수

(4) Single Prompt 성능 평가

  • Top-1 Accuracy
  • Confidence Margin
    • 1등과 2등 클래스 간 유사도 차이
single_acc = single_correct / len(images) * 100 single_avg_margin = np.mean(single_margins)

5. Template Ensemble 방식

(1) 다중 텍스트 템플릿 설계

  • 단순 객체 인식이 아닌
    • 품종
    • 촬영 구도
    • 이미지 품질
    • 시선, 자세
      를 반영한 문장들
templates = [ 
    "a photo of a {}, a type of pet.", 
    "a photo of the {}, a type of cat or dog.", 
    "a photo of a {}, a breed of dog.", 
    ... 
]

 CLIP이 학습 당시 접했을 법한 다양한 표현을 의도적으로 포함


(2) 텍스트 임베딩 앙상블

text_features = text_features.view(num_classes, num_templates, -1).mean(dim=1)
  • 동일 클래스에 대한 여러 문장 임베딩을 평균
  • 특정 표현에 대한 편향 감소
  • Zero-shot 분류 성능 안정화

(3) Ensemble 성능 평가

ensemble_similarity = image_features @ text_features.T
ensemble_acc = ensemble_correct / len(images) * 100 
ensemble_avg_margin = np.mean(ensemble_margins)
  • 이미지 ↔ 클래스 대표 임베딩 비교
 

6. 최종 코드

import os
import torch
import matplotlib.pyplot as plt
from torchvision.datasets import OxfordIIITPet
from transformers import CLIPModel, CLIPProcessor
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
import numpy as np

# 모델 로드
device = "cuda" if torch.cuda.is_available() else "cpu"
# device = "cpu"
print(f"Using device: {device}")

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model.eval()

# Oxford-IIIT Pet 로드
pet_dataset = OxfordIIITPet(
    root=os.path.expanduser("~/.cache"),
    split='test',
    target_types='category',
    download=True
)

# 무작위 이미지 선택 (2000장)
indices = torch.randperm(len(pet_dataset))[:2000].tolist()
images = [pet_dataset[i][0] for i in indices]
gt_labels = [pet_dataset.classes[pet_dataset[i][1]] for i in indices]

# 1. SINGLE PROMPT (BASELINE)
print("\nSINGLE PROMPT EVALUATION")
single_template = ["a photo of a {}"]
single_texts_per_class = [[t.format(c) for t in single_template] for c in pet_dataset.classes]
single_flat_texts = [t for texts in single_texts_per_class for t in texts]

# Single prompt CLIP 입력
single_inputs = processor(
    images=images,
    text=single_flat_texts,
    return_tensors="pt",
    padding=True
).to(device)

with torch.no_grad():
    single_outputs = model(**single_inputs)
    single_image_features = single_outputs.image_embeds
    single_text_features = single_outputs.text_embeds

# Single 정규화
single_image_features = single_image_features / single_image_features.norm(dim=-1, keepdim=True)
single_text_features = single_text_features / single_text_features.norm(dim=-1, keepdim=True)

# Single 클래스별 평균 (37, 1, 512) → (37, 512)
num_classes = len(pet_dataset.classes)
single_text_features = single_text_features.view(num_classes, 1, -1).mean(dim=1)
single_similarity = single_image_features @ single_text_features.T

# Single 성능 계산
single_topk = 5
single_values, single_indices_pred = single_similarity.topk(single_topk, dim=1)

single_correct = 0
single_margins = []
for img_idx in range(len(images)):
    gt_idx = pet_dataset.class_to_idx[gt_labels[img_idx]]
    if single_indices_pred[img_idx, 0] == gt_idx:
        single_correct += 1
    if single_topk > 1:
        margin = (single_values[img_idx, 0] - single_values[img_idx, 1]).item()
        single_margins.append(margin)

single_acc = single_correct / len(images) * 100
single_avg_margin = np.mean(single_margins) if single_margins else 0

# 2. 템플릿 앙상블
print("\nTEMPLATE ENSEMBLE EVALUATION")
# Oxford-IIIT Pet 및 동물 분류에 최적화된 10대 템플릿
templates = [
    "a photo of a {}, a type of pet.",               # 가장 강력한 기본형
    "a photo of the {}, a type of cat or dog.",      # 대분류(개/고양이) 명시
    "a photo of a {}, a breed of dog.",              # 품종 맥락 추가
    "a close-up photo of a {}.",                     # 근접 촬영 대응
    "a photo of a sitting {}.",                      # 자세 정보 추가
    "a pet portrait of a {}.",                       # 인물화 형식의 구도
    "the {} is shown in the image.",                 # 객체 중심 설명
    "a blurry photo of a {}.",                       # 저화질/노이즈 대응
    "a photo of a {} looking at the camera.",        # 시선 처리 대응
    "a high quality photo of a {}."                  # 고화질 특징 강조
]

num_templates = len(templates)
texts_per_class = [[template.format(c) for template in templates] for c in pet_dataset.classes]
flat_texts = [t for texts in texts_per_class for t in texts]

# Ensemble CLIP 입력
inputs = processor(
    images=images, 
    text=flat_texts, 
    return_tensors="pt", 
    padding=True
).to(device)

with torch.no_grad():
    outputs = model(**inputs)
    image_features = outputs.image_embeds
    text_features = outputs.text_embeds

# Ensemble 정규화 & 평균
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
text_features = text_features.view(num_classes, num_templates, -1).mean(dim=1)
ensemble_similarity = image_features @ text_features.T

# Ensemble 성능 계산
ensemble_values, ensemble_indices_pred = ensemble_similarity.topk(single_topk, dim=1)
ensemble_correct = 0
ensemble_margins = []
for img_idx in range(len(images)):
    gt_idx = pet_dataset.class_to_idx[gt_labels[img_idx]]
    if ensemble_indices_pred[img_idx, 0] == gt_idx:
        ensemble_correct += 1
    if single_topk > 1:
        margin = (ensemble_values[img_idx, 0] - ensemble_values[img_idx, 1]).item()
        ensemble_margins.append(margin)

ensemble_acc = ensemble_correct / len(images) * 100
ensemble_avg_margin = np.mean(ensemble_margins) if ensemble_margins else 0

#최종 비교 결과
print("\n" + "="*80)
print(f'SINGLE PROMPT vs {len(templates)} TEMPLATE ENSEMBLE COMPARISON')
print("="*80)
print(f"{'Metric':<25} {'Single':<12} {'Ensemble':<12} {'Improvement':<12}")
print("-"*80)
print(f"Top-1 Accuracy     : {single_acc:6.1f}% ({single_correct:3d})  {ensemble_acc:6.1f}% ({ensemble_correct:3d})  {ensemble_acc-single_acc:+6.1f}%")
print(f"Avg Margin         : {single_avg_margin:8.3f}    {ensemble_avg_margin:8.3f}    {ensemble_avg_margin-single_avg_margin:+7.3f}")
print(f"Improvement Rate   : {'':<25} {'':<12} {((ensemble_correct/single_correct-1)*100 if single_correct>0 else 0):+6.1f}%")
print("="*80)


# 상세 결과 (첫 10개)
print("\nDETAILED RESULTS (first 5 images):")
for img_idx in range(min(5, len(images))):
    gt_idx = pet_dataset.class_to_idx[gt_labels[img_idx]]
    single_correct = "O" if single_indices_pred[img_idx, 0] == gt_idx else "X"
    ensemble_correct = "O" if ensemble_indices_pred[img_idx, 0] == gt_idx else "X"
    
    print(f"\nImage {img_idx+1} | Answer: {gt_labels[img_idx]}")
    print(f"  Single:  {pet_dataset.classes[single_indices_pred[img_idx, 0]]} {single_correct} ({single_values[img_idx, 0]:.3f})")
    print(f"  Ensemble:{pet_dataset.classes[ensemble_indices_pred[img_idx, 0]]} {ensemble_correct} ({ensemble_values[img_idx, 0]:.3f})")

# 이미지 시각화
n_show = min(8, len(images))
fig, axes = plt.subplots(2, 4, figsize=(20, 10))
axes = axes.ravel()

for idx, (ax, img, gt) in enumerate(zip(axes, images[:n_show], gt_labels[:n_show])):
    ax.imshow(img)
    gt_idx = pet_dataset.class_to_idx[gt]
    single_pred = pet_dataset.classes[single_indices_pred[idx, 0]]
    ensemble_pred = pet_dataset.classes[ensemble_indices_pred[idx, 0]]
    single_status = "O" if single_indices_pred[idx, 0] == gt_idx else "X"
    ensemble_status = "O" if ensemble_indices_pred[idx, 0] == gt_idx else "X"
    
    title = f"Answer: {gt}\nSingle: {single_pred} {single_status}\nEnsemble: {ensemble_pred} {ensemble_status}"
    ax.set_title(title, fontsize=14)
    ax.axis("off")

plt.suptitle(f'Single Prompt vs {len(templates)}', fontsize=14)
plt.tight_layout()

7. 결과 비교

 
 

관찰 결과

  • Template Ensemble 방식이 일관되게 성능이 우수하다고는 단정 지을 수 없으나 여러 테스트 결과 정확도는 단일보다는 평균적으로 높음

8. 회고: 이 실험을 통해 얻은 인사이트

  1. CLIP은 프롬프트에 매우 민감한 모델이다.
  2. Zero-shot 분류 성능의 핵심은 모델이 아니라 텍스트 설계이다.
  3. Template Ensemble은 학습 없이 성능을 끌어올릴 수 있는 강력한 방법이다.
  4. 단순 Accuracy보다 Confidence Margin을 함께 보는 것이 중요하다.

+ Recent posts