# CAMP-VQA/src/extractor/extract_clip_embeds.py
import re
import torch
import clip
from PIL import Image


def get_quality_hint_from_metadata(mos, width, height, bitrate, bitdepth, framerate, quality_hints):
    """Map video metadata to short natural-language quality hints used to condition BLIP prompts."""
    hint = []
    # Rescale a 0-100 MOS onto the 5-point scale before bucketing.
    if mos > 5:
        mos = (mos / 100) * 5
    if mos >= 4.5:
        hint.append(quality_hints["mos"]["excellent"])
    elif 3.5 <= mos < 4.5:
        hint.append(quality_hints["mos"]["good"])
    elif 2.5 <= mos < 3.5:
        hint.append(quality_hints["mos"]["fair"])
    elif 1.5 <= mos < 2.5:
        hint.append(quality_hints["mos"]["bad"])
    else:
        hint.append(quality_hints["mos"]["poor"])
    res = width * height
    if res < 640 * 480:
        hint.append(quality_hints["resolution"]["low"])
    elif res < 1280 * 720:
        hint.append(quality_hints["resolution"]["sd"])
    else:
        hint.append(quality_hints["resolution"]["hd"])
    if bitrate < 500_000:
        hint.append(quality_hints["bitrate"]["low"])
    elif bitrate < 1_000_000:
        hint.append(quality_hints["bitrate"]["medium"])
    else:
        hint.append(quality_hints["bitrate"]["high"])
    if 0 < bitdepth <= 8:
        hint.append(quality_hints["bitdepth"]["low"])
    elif bitdepth == 0:
        # A bit depth of 0 (e.g. unreported) maps to the standard hint.
        hint.append(quality_hints["bitdepth"]["standard"])
    else:
        hint.append(quality_hints["bitdepth"]["high"])
    if framerate < 24:
        hint.append(quality_hints["framerate"]["low"])
    elif framerate > 60:
        hint.append(quality_hints["framerate"]["high"])
    else:
        hint.append(quality_hints["framerate"]["standard"])
    return " ".join(hint)


def generate_caption(blip_processor, blip_model, device, image, prompt):
    """Generate a prompt-conditioned BLIP caption for a single PIL image."""
    inputs = blip_processor(image, prompt, return_tensors="pt").to(device)
    generated_ids = blip_model.generate(**inputs, max_new_tokens=50)
    caption = blip_processor.decode(generated_ids[0], skip_special_tokens=True)
    return caption


def tensor_to_pil(image_tensor):
    """Convert an HWC uint8-range frame tensor (optionally with a leading batch dim) to a PIL image."""
    if isinstance(image_tensor, torch.Tensor):
        arr = image_tensor.cpu().numpy()
        if arr.ndim == 4 and arr.shape[0] == 1:
            arr = arr[0]  # remove batch dimension
        arr = arr.astype('uint8')
        return Image.fromarray(arr)
    # Not a tensor (e.g. already a PIL image): return unchanged.
    return image_tensor


def extract_semantic_captions(blip_processor, blip_model, curr_frame, frag_residual, frag_frame,
                              prompts, device, metadata=None, use_metadata_prompt=False):
    """Caption the current frame and its fragment views with BLIP, optionally prepending
    metadata-derived quality hints to the prompts."""
    quality_prompt_base = prompts["quality_prompt_base"]
    residual_prompt = prompts["residual_prompt"]
    frag_prompt = prompts["frag_prompt"]
    quality_hint = ""
    if use_metadata_prompt and metadata:
        mos, width, height, bitrate, bitdepth, framerate = metadata
        quality_hint = get_quality_hint_from_metadata(
            mos, width, height, bitrate, bitdepth, framerate,
            quality_hints=prompts["quality_hints"],
        )
    prompt_hints = []
    if quality_hint:
        prompt_hints.append(quality_hint)
    quality_prompt = "\n\n".join(prompt_hints + [quality_prompt_base])
    fragment_prompt = "\n\n".join(prompt_hints)
    captions = {
        "curr_frame_quality": generate_caption(
            blip_processor, blip_model, device, curr_frame, prompt=quality_prompt),
        "frag_residual": generate_caption(
            blip_processor, blip_model, device, frag_residual,
            prompt=(fragment_prompt + "\n\n" + residual_prompt)),
        "frag_frame": generate_caption(
            blip_processor, blip_model, device, frag_frame,
            prompt=(fragment_prompt + "\n\n" + frag_prompt)),
    }
    return captions
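
# For reference, the prompts dict accessed above is expected to look roughly like the
# sketch below (keys taken from the accesses in this file; the string values are
# illustrative placeholders, not the project's actual prompt wording):
#
#   prompts = {
#       "quality_prompt_base": "Describe the visual quality of this frame.",
#       "residual_prompt": "Describe the compression artifacts visible in this residual.",
#       "frag_prompt": "Describe the distortions visible in these frame fragments.",
#       "quality_hints": {
#           "mos": {"excellent": ..., "good": ..., "fair": ..., "bad": ..., "poor": ...},
#           "resolution": {"low": ..., "sd": ..., "hd": ...},
#           "bitrate": {"low": ..., "medium": ..., "high": ...},
#           "bitdepth": {"low": ..., "standard": ..., "high": ...},
#           "framerate": {"low": ..., "standard": ..., "high": ...},
#       },
#   }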


def clean_caption_text(text):
    """Strip stock-footage boilerplate phrases and collapse whitespace in a caption."""
    text = re.sub(r"- .*?stock videos & royalty-free footage", "", text)
    text = re.sub(r"\s+", " ", text)
    return text.strip()


def dedup_keywords(text, split_tokens=(",", ".", ";")):
    """Split a caption on punctuation and keep only the first occurrence of each keyword."""
    for token in split_tokens:
        text = text.replace(token, ",")
    parts = [p.strip().lower() for p in text.split(",") if p.strip()]
    seen = set()
    unique_parts = []
    for part in parts:
        if part not in seen:
            unique_parts.append(part)
            seen.add(part)
    return " ".join(unique_parts)  # space-joined keywords work well as CLIP text input


def get_clip_text_embedding(clip_model, device, text):
    """Encode text with CLIP and return an L2-normalised embedding."""
    # truncate=True guards against captions longer than CLIP's 77-token context window.
    text_tokens = clip.tokenize([text], truncate=True).to(device)
    with torch.no_grad():
        with torch.amp.autocast(device_type='cuda'):
            text_features = clip_model.encode_text(text_tokens)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    return text_features.squeeze()


def get_clip_image_embedding(clip_model, clip_preprocess, device, image):
    """Encode a PIL image with CLIP and return an L2-normalised embedding."""
    image_input = clip_preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        with torch.amp.autocast(device_type='cuda'):
            image_features = clip_model.encode_image(image_input)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    return image_features.squeeze()


def extract_semantic_embeddings(clip_model, clip_preprocess, device, curr_frame, captions):
    """Turn the current frame and its cleaned captions into CLIP image/quality/artifact embeddings."""
    if not isinstance(curr_frame, Image.Image):
        curr_frame = Image.fromarray(curr_frame)
    quality_caption = dedup_keywords(clean_caption_text(captions["curr_frame_quality"]))
    artifact_caption_1 = dedup_keywords(clean_caption_text(captions["frag_residual"]))
    artifact_caption_2 = dedup_keywords(clean_caption_text(captions["frag_frame"]))
    # Merge the two artifact captions and dedup again across them.
    artifact_caption = dedup_keywords(f"{artifact_caption_1}, {artifact_caption_2}")
    image_embed = get_clip_image_embedding(clip_model, clip_preprocess, device, curr_frame)
    quality_embed = get_clip_text_embedding(clip_model, device, quality_caption)
    artifact_embed = get_clip_text_embedding(clip_model, device, artifact_caption)
    return image_embed, quality_embed, artifact_embed
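
# Each returned embedding is a 1-D CLIP feature vector (e.g. 512-dim for a ViT-B/32
# backbone; the exact width depends on which CLIP model was loaded).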


def extract_features_clip_embed(frames_info, metadata, clip_model, clip_preprocess,
                                blip_processor, blip_model, prompts, device):
    """Run BLIP captioning and CLIP encoding over a clip's sampled frames and
    return stacked image / quality / artifact embeddings."""
    feature_image_embed = []
    feature_quality_embed = []
    feature_artifact_embed = []
    for curr_frame, frag_residual, frag_frame in frames_info:
        curr_frame = tensor_to_pil(curr_frame)
        frag_residual = tensor_to_pil(frag_residual)
        frag_frame = tensor_to_pil(frag_frame)
        captions = extract_semantic_captions(
            blip_processor, blip_model,
            curr_frame, frag_residual, frag_frame, prompts,
            device,
            metadata=metadata,
            use_metadata_prompt=True,
        )
        image_embed, quality_embed, artifact_embed = extract_semantic_embeddings(
            clip_model, clip_preprocess, device, curr_frame, captions)
        feature_image_embed.append(image_embed)
        feature_quality_embed.append(quality_embed)
        feature_artifact_embed.append(artifact_embed)
    # Stack per-frame embeddings into (num_frames, embed_dim) tensors.
    image_embedding = torch.stack(feature_image_embed, dim=0)
    quality_embedding = torch.stack(feature_quality_embed, dim=0)
    artifact_embedding = torch.stack(feature_artifact_embed, dim=0)
    return image_embedding, quality_embedding, artifact_embedding
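
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The model checkpoints, the random
# frames and the metadata tuple below are assumptions for demonstration, not
# part of the CAMP-VQA pipeline itself; `prompts` must follow the structure
# sketched after extract_semantic_captions above.
#
#   from transformers import BlipProcessor, BlipForConditionalGeneration
#
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
#   blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
#   blip_model = BlipForConditionalGeneration.from_pretrained(
#       "Salesforce/blip-image-captioning-base").to(device)
#
#   metadata = (72, 1920, 1080, 2_500_000, 8, 30)  # (mos, width, height, bitrate, bitdepth, fps)
#   frames_info = [tuple(torch.randint(0, 256, (224, 224, 3), dtype=torch.uint8)
#                        for _ in range(3))]       # (curr_frame, frag_residual, frag_frame)
#
#   image_emb, quality_emb, artifact_emb = extract_features_clip_embed(
#       frames_info, metadata, clip_model, clip_preprocess,
#       blip_processor, blip_model, prompts, device)
# ---------------------------------------------------------------------------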