1. CLIP 모델 개요 – 내가 구현한 것은 무엇인가?

이번 글에서는 OpenAI의 CLIP(Contrastive Language–Image Pretraining) 모델을 이용해
이미지를 입력하면, 미리 정의한 텍스트 라벨 중 어떤 개념과 가장 유사한지 추론하는 모델을 구현해보았다.

특징은 다음과 같다.

  • 사전 학습된 CLIP 모델 사용
  • 추가 학습 없이 (Zero-shot) 이미지 분류 수행
  • 이미지와 텍스트를 같은 임베딩 공간으로 변환
  • 여러 개의 텍스트 프롬프트를 사용하는 Prompt Ensemble 적용 여부 비교

즉,

“이 이미지는 고양이입니다” 같은 정답 라벨을 학습시키는 방식이 아니라
이미지와 문장 간의 의미적 유사도를 계산해 분류하는 방식이다.

 

초기 구현 단계에서는 CLIP의 구조와 동작 원리를 명확히 이해하지 못했기 때문에
이미지 URL, 텍스트 라벨, 프롬프트 템플릿을 직접 하드코딩하여 실험 형태로 구현했다.


2. 라이브러리 및 모델 로딩

from PIL import Image import requests 
import torch from transformers 
import CLIPProcessor, CLIPModel from IPython.display 
import display
  • transformers : Hugging Face에서 제공하는 CLIP 모델 로딩
  • CLIPModel : 이미지 인코더 + 텍스트 인코더를 포함한 모델
  • CLIPProcessor : 이미지 전처리 + 텍스트 토크나이징을 동시에 담당
  • PIL, requests : 이미지 URL을 불러오기 위한 용도
device = "cuda" if torch.cuda.is_available() else "cpu" 
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device) 
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
  • ViT-B/32 기반 CLIP 모델 사용
  • GPU가 있다면 CUDA 사용
  • 학습은 하지 않고, pretrained weight 그대로 사용

3. 이미지 데이터 준비 (직접 입력)

urls = [ "https://images.pexels.com/photos/103123/pexels-photo-103123.jpeg", ... ]

초기 구현 단계에서는 데이터셋을 쓰지 않고,
이미지 URL을 직접 입력해서 CLIP이 어떤 결과를 내는지 확인했다.

images = [] for url in urls: img = Image.open(requests.get(url, stream=True).raw) images.append(img)
  • URL → PIL Image 객체 변환
  • 이후 CLIP Processor를 통해 모델 입력 형태로 변환됨

4. 텍스트 라벨 & 프롬프트 템플릿 구성

labels = ["animal", "object", "banana", "bird", "person","flower","food","scenery"]
  • CLIP은 정해진 클래스가 없음
  • 내가 직접 분류하고 싶은 개념(라벨)을 정의해야 함
templates = [ "a photo of a {}", "a close-up photo of a {}", "a blurry photo of a {}", ... ]

여기서 중요한 개념이 Prompt Engineering이다.

같은 라벨이라도

  • “a photo of a bird”
  • “a close-up photo of a bird”
  • “a photo of a bird in the wild”

처럼 표현이 다르면 임베딩 결과도 달라진다.

texts = [t.format(l) for l in labels for t in templates]

그래서 모든 라벨 × 모든 템플릿 조합을 만들어 텍스트 후보 집합을 구성했다.


5. 텍스트 임베딩 사전 계산

inputs_text = processor(text=texts, return_tensors="pt", padding=True).to(device)
  • 텍스트는 한 번에 배치 처리 가능
  • 이미지와 달리 길이만 다를 뿐 구조가 동일하기 때문
with torch.no_grad(): 
	text_embeds = model.get_text_features(**inputs_text) 
	text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
  • get_text_features() → 텍스트 임베딩 추출
  • L2 정규화
    • CLIP은 코사인 유사도 기반 비교를 하기 때문

6. Prompt Ensemble (프롬프트 앙상블)

text_embeds_ensemble = text_embeds.view(len(labels), num_templates, -1).mean(dim=1)

여기서 프롬프트 앙상블을 적용했다.

  • 같은 라벨에 대해 여러 문장 프롬프트 사용
  • 각각의 임베딩을 평균내어 라벨 대표 임베딩 생성

📌 효과

  • 특정 문장 표현에 과도하게 의존하지 않음
  • Zero-shot 분류 성능이 더 안정적

7. 이미지 임베딩 및 유사도 계산

inputs_image = processor(images=image, return_tensors="pt").to(device)
  • 이미지는 1장씩 처리
  • 이미지마다 해상도/비율이 다르기 때문
image_embeds = model.get_image_features(**inputs_image) 
image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)

8. 프롬프트 앙상블 적용 전 vs 후 비교

(1) 프롬프트 앙상블 미적용

logits_templates = image_embeds @ text_embeds.T 
probs_templates = logits_templates.softmax(dim=1)
  • 이미지 ↔ 모든 텍스트 문장 비교
  • 가장 유사한 문장 단위 결과 출력

(2) 프롬프트 앙상블 적용

logits_labels = image_embeds @ text_embeds_ensemble.T 
probs_labels = logits_labels.softmax(dim=1)
  • 이미지 ↔ 라벨 단위 비교
  • 실제 분류에 가까운 결과

9. 최종 코드 및 결과

from PIL import Image
import requests
import torch
from transformers import CLIPProcessor, CLIPModel
from IPython.display import display

# 모델과 프로세서
device = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# 이미지 불러오기
urls = [
    "https://images.pexels.com/photos/103123/pexels-photo-103123.jpeg",
    "https://images.pexels.com/photos/414712/pexels-photo-414712.jpeg",
    "https://images.pexels.com/photos/1108099/pexels-photo-1108099.jpeg",
    "https://images.pexels.com/photos/17474740/pexels-photo-17474740.jpeg",
    "https://images.pexels.com/photos/10875195/pexels-photo-10875195.jpeg",
    "https://images.pexels.com/photos/34026276/pexels-photo-34026276.jpeg"
]

images = []
for url in urls:
    img = Image.open(requests.get(url, stream=True).raw)
    images.append(img)

# 라벨과 템플릿 준비
templates = [
    "a photo of a {}",
    "a close-up photo of a {}",
    "a blurry photo of a {}",
    "a cropped photo of a {}",
    "a photo of a {} in the wild",
    "a photo of a {} outdoors",
]
labels = ["animal", "object", "banana", "bird", "person","flower","food","scenery"]

# 모든 텍스트 후보 만들기
texts = [t.format(l) for l in labels for t in templates]

# 텍스트 임베딩 미리 계산
# 배치처리 가능으로 인해 texts들을 한꺼번에 임베딩 할 수 있음
inputs_text = processor(text=texts, return_tensors="pt", padding=True).to(device)

with torch.no_grad():
    text_embeds = model.get_text_features(**inputs_text)
    text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)

num_templates = len(templates)

# 프롬프트 앙상블
text_embeds_ensemble = text_embeds.view(len(labels), num_templates, -1).mean(dim=1)  # [num_labels, dim]


# 이미지별 처리
topk = 3
thumb_size = (400, 500)

for idx, image in enumerate(images):
    # 이미지 임베딩
    # text와 다르게 이미지는 사이즈 등이 다르기 때문에 processer에서는 1:1로 다루려고 하므로 각각 하였다.
    inputs_image = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        image_embeds = model.get_image_features(**inputs_image)
        image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)

    # 프롬프트 앙상블 미적용
    logits_templates = image_embeds @ text_embeds.T
    probs_templates = logits_templates.softmax(dim=1)
    
    # top-3
    values, indices = probs_templates.topk(topk, dim=1)
    print("\n프롬프트 앙상블 미적용: Top-3 템플릿")
    for rank, (v, i) in enumerate(zip(values[0], indices[0]), 1):
        print(f"{rank}. '{texts[i]}' - 확률: {v.item():.3f}")

    # 프롬프트 앙상블 적용
    logits_labels = image_embeds @ text_embeds_ensemble.T  # [1, num_labels]
    probs_labels = logits_labels.softmax(dim=1)

    # top-k 라벨 출력
    values, indices = probs_labels.topk(topk, dim=1)
    print(f"\n프롬프트 앙상블 적용: Top-{topk} 라벨")
    for rank, (v, i) in enumerate(zip(values[0], indices[0]), 1):
        print(f"{rank}. '{labels[i]}' - 확률: {v.item():.3f}")

    # 이미지 출력
    image_copy = image.copy()
    image_copy.thumbnail(thumb_size)
    display(image_copy)

9. 구현 회고 (초기 구현의 한계)

이 코드는 CLIP을 처음 다루면서 작성한 코드라서:

  • 이미지, 라벨, 프롬프트를 모두 직접 정의
  • 실제 데이터셋 기반 학습/평가는 없음
  • “CLIP이 이런 식으로 동작하는구나”를 확인하는 실험용 구현

하지만 이 과정을 통해 다음을 명확히 이해할 수 있었다.

  • CLIP은 분류 모델이 아니라 임베딩 모델
  • Zero-shot 분류의 핵심은 텍스트 설계
  • Prompt Ensemble이 성능에 미치는 영향

+ Recent posts