import re

import clip
import numpy as np
import torch
from numpy.linalg import norm
from PIL import Image


def get_quality_hint_from_metadata(mos, width, height, bitrate, bitdepth, framerate, quality_hints):
    hint = []

    # Normalize a 0-100 MOS onto the 0-5 scale before bucketing.
    if mos > 5:
        mos = (mos / 100) * 5
    if mos >= 4.5:
        hint.append(quality_hints["mos"]["excellent"])
    elif 3.5 <= mos < 4.5:
        hint.append(quality_hints["mos"]["good"])
    elif 2.5 <= mos < 3.5:
        hint.append(quality_hints["mos"]["fair"])
    elif 1.5 <= mos < 2.5:
        hint.append(quality_hints["mos"]["bad"])
    else:
        hint.append(quality_hints["mos"]["poor"])

    # Resolution bucket: below VGA is low, below 720p is SD, otherwise HD.
    res = width * height
    if res < 640 * 480:
        hint.append(quality_hints["resolution"]["low"])
    elif res < 1280 * 720:
        hint.append(quality_hints["resolution"]["sd"])
    else:
        hint.append(quality_hints["resolution"]["hd"])

    # Bitrate bucket (bits per second).
    if bitrate < 500_000:
        hint.append(quality_hints["bitrate"]["low"])
    elif bitrate < 1_000_000:
        hint.append(quality_hints["bitrate"]["medium"])
    else:
        hint.append(quality_hints["bitrate"]["high"])

    # Bit depth bucket; a reported depth of 0 falls back to the standard hint.
    if 0 < bitdepth <= 8:
        hint.append(quality_hints["bitdepth"]["low"])
    elif bitdepth == 0:
        hint.append(quality_hints["bitdepth"]["standard"])
    else:
        hint.append(quality_hints["bitdepth"]["high"])

    # Frame rate bucket.
    if framerate < 24:
        hint.append(quality_hints["framerate"]["low"])
    elif framerate > 60:
        hint.append(quality_hints["framerate"]["high"])
    else:
        hint.append(quality_hints["framerate"]["standard"])

    return " ".join(hint)
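

# For reference, a minimal sketch of the `quality_hints` mapping this function
# indexes into. The keys come from the lookups above; the hint phrasings are
# illustrative placeholders, not the project's actual prompt strings.
#
# quality_hints = {
#     "mos": {"excellent": "...", "good": "...", "fair": "...", "bad": "...", "poor": "..."},
#     "resolution": {"low": "...", "sd": "...", "hd": "..."},
#     "bitrate": {"low": "...", "medium": "...", "high": "..."},
#     "bitdepth": {"low": "...", "standard": "...", "high": "..."},
#     "framerate": {"low": "...", "standard": "...", "high": "..."},
# }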


def generate_caption(blip_processor, blip_model, device, image, prompt):
    # Prompt-conditioned BLIP captioning for a single image.
    inputs = blip_processor(image, prompt, return_tensors="pt").to(device)
    generated_ids = blip_model.generate(**inputs, max_new_tokens=50)
    caption = blip_processor.decode(generated_ids[0], skip_special_tokens=True)
    return caption


def tensor_to_pil(image_tensor):
    # Convert an (optionally batched) uint8 image tensor or array in HWC layout to a PIL image.
    if isinstance(image_tensor, torch.Tensor):
        arr = image_tensor.cpu().numpy()
    else:
        arr = np.asarray(image_tensor)
    if arr.ndim == 4 and arr.shape[0] == 1:
        arr = arr[0]
    arr = arr.astype("uint8")
    return Image.fromarray(arr)


def extract_semantic_captions(blip_processor, blip_model, curr_frame, frag_residual, frag_frame, prompts, device, metadata=None, use_metadata_prompt=False):
    quality_prompt_base = prompts["quality_prompt_base"]
    residual_prompt = prompts["residual_prompt"]
    frag_prompt = prompts["frag_prompt"]

    # Optionally prepend a metadata-derived quality hint to the prompts.
    quality_hint = ""
    if use_metadata_prompt and metadata:
        mos, width, height, bitrate, bitdepth, framerate = metadata
        quality_hint = get_quality_hint_from_metadata(mos, width, height, bitrate, bitdepth, framerate, quality_hints=prompts["quality_hints"])

    prompt_hints = []
    if quality_hint:
        prompt_hints.append(quality_hint)

    quality_prompt = "\n\n".join(prompt_hints + [quality_prompt_base])
    fragment_prompt = "\n\n".join(prompt_hints)

    # Caption the current frame for overall quality and the two fragment views for artifacts.
    captions = {
        "curr_frame_quality": generate_caption(blip_processor, blip_model, device, curr_frame, prompt=quality_prompt),
        "frag_residual": generate_caption(blip_processor, blip_model, device, frag_residual, prompt=(fragment_prompt + "\n\n" + residual_prompt)),
        "frag_frame": generate_caption(blip_processor, blip_model, device, frag_frame, prompt=(fragment_prompt + "\n\n" + frag_prompt)),
    }
    return captions
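

# For reference, the prompt configuration consumed above is expected to carry at
# least these keys (names taken from the lookups in extract_semantic_captions;
# the example texts are illustrative placeholders only):
#
# prompts = {
#     "quality_prompt_base": "<base question about overall visual quality>",
#     "residual_prompt": "<question about the fragment residual view>",
#     "frag_prompt": "<question about the fragmented frame view>",
#     "quality_hints": {...},  # see the sketch after get_quality_hint_from_metadata
# }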


def clean_caption_text(text):
    # Strip the stock-footage boilerplate BLIP sometimes emits and collapse whitespace.
    text = re.sub(r"- .*?stock videos & royalty-free footage", "", text)
    text = re.sub(r"\s+", " ", text)
    return text.strip()


def dedup_keywords(text, split_tokens=(",", ".", ";")):
    # Normalize the separators, then keep only the first occurrence of each phrase.
    for token in split_tokens:
        text = text.replace(token, ",")
    parts = [p.strip().lower() for p in text.split(",") if p.strip()]
    seen = set()
    unique_parts = []
    for part in parts:
        if part not in seen:
            unique_parts.append(part)
            seen.add(part)
    return " ".join(unique_parts)


def get_clip_text_embedding(clip_model, device, text):
    # Captions can exceed CLIP's 77-token context, so truncate instead of raising.
    text_tokens = clip.tokenize([text], truncate=True).to(device)
    with torch.no_grad(), torch.amp.autocast(device_type="cuda"):
        text_features = clip_model.encode_text(text_tokens)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    return text_features.squeeze()


def get_clip_image_embedding(clip_model, clip_preprocess, device, image):
    image_input = clip_preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad(), torch.amp.autocast(device_type="cuda"):
        image_features = clip_model.encode_image(image_input)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    return image_features.squeeze()


def extract_semantic_embeddings(clip_model, clip_preprocess, device, curr_frame, captions):
    if not isinstance(curr_frame, Image.Image):
        curr_frame = Image.fromarray(curr_frame)

    # Clean and deduplicate the BLIP captions before encoding them with CLIP.
    quality_caption = dedup_keywords(clean_caption_text(captions["curr_frame_quality"]))
    artifact_caption_1 = dedup_keywords(clean_caption_text(captions["frag_residual"]))
    artifact_caption_2 = dedup_keywords(clean_caption_text(captions["frag_frame"]))
    artifact_caption = dedup_keywords(f"{artifact_caption_1}, {artifact_caption_2}")

    image_embed = get_clip_image_embedding(clip_model, clip_preprocess, device, curr_frame)
    quality_embed = get_clip_text_embedding(clip_model, device, quality_caption)
    artifact_embed = get_clip_text_embedding(clip_model, device, artifact_caption)
    return image_embed, quality_embed, artifact_embed


def extract_features_clip_embed(frames_info, metadata, clip_model, clip_preprocess, blip_processor, blip_model, prompts, device):
    feature_image_embed = []
    feature_quality_embed = []
    feature_artifact_embed = []
    for curr_frame, frag_residual, frag_frame in frames_info:
        curr_frame = tensor_to_pil(curr_frame)
        frag_residual = tensor_to_pil(frag_residual)
        frag_frame = tensor_to_pil(frag_frame)

        captions = extract_semantic_captions(
            blip_processor, blip_model,
            curr_frame, frag_residual, frag_frame, prompts,
            device,
            metadata=metadata,
            use_metadata_prompt=True,
        )
        image_embed, quality_embed, artifact_embed = extract_semantic_embeddings(clip_model, clip_preprocess, device, curr_frame, captions)
        feature_image_embed.append(image_embed)
        feature_quality_embed.append(quality_embed)
        feature_artifact_embed.append(artifact_embed)

    # Stack the per-frame embeddings into (num_frames, embed_dim) tensors.
    image_embedding = torch.stack(feature_image_embed, dim=0)
    quality_embedding = torch.stack(feature_quality_embed, dim=0)
    artifact_embedding = torch.stack(feature_artifact_embed, dim=0)

    return image_embedding, quality_embedding, artifact_embedding
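

if __name__ == "__main__":
    # Illustrative wiring sketch only: the BLIP checkpoint, CLIP variant, prompt
    # texts, and dummy inputs below are assumptions for demonstration, not the
    # project's actual configuration.
    from transformers import BlipForConditionalGeneration, BlipProcessor

    device = "cuda" if torch.cuda.is_available() else "cpu"
    clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
    blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)

    prompts = {
        "quality_prompt_base": "Describe the visual quality of this frame.",
        "residual_prompt": "Describe the compression artifacts visible in this residual.",
        "frag_prompt": "Describe the distortions visible in this fragmented frame.",
        "quality_hints": {
            "mos": {k: f"{k} perceptual quality." for k in ["excellent", "good", "fair", "bad", "poor"]},
            "resolution": {k: f"{k} resolution." for k in ["low", "sd", "hd"]},
            "bitrate": {k: f"{k} bitrate." for k in ["low", "medium", "high"]},
            "bitdepth": {k: f"{k} bit depth." for k in ["low", "standard", "high"]},
            "framerate": {k: f"{k} frame rate." for k in ["low", "standard", "high"]},
        },
    }

    # One dummy frame triple (current frame, fragment residual, fragmented frame) in HWC uint8 layout.
    dummy = torch.randint(0, 256, (224, 224, 3), dtype=torch.uint8)
    frames_info = [(dummy, dummy.clone(), dummy.clone())]
    metadata = (3.8, 1280, 720, 800_000, 8, 30)  # (mos, width, height, bitrate, bitdepth, framerate)

    img_emb, qual_emb, art_emb = extract_features_clip_embed(
        frames_info, metadata, clip_model, clip_preprocess, blip_processor, blip_model, prompts, device
    )
    print(img_emb.shape, qual_emb.shape, art_emb.shape)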