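"""Semantic feature extraction for video quality assessment.

Builds metadata-conditioned prompts, captions each frame (and its fragment
residual / fragment frame) with BLIP, and encodes the cleaned captions plus
the frame itself with CLIP to produce per-frame image, quality, and artifact
embeddings.
"""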
import re
import torch
import clip
import numpy as np
from numpy.linalg import norm
from PIL import Image

def get_quality_hint_from_metadata(mos, width, height, bitrate, bitdepth, framerate, quality_hints):
    hint = []
    # Normalise MOS to a 1-5 scale if it was supplied on a 0-100 scale.
    if mos > 5:
        mos = (mos / 100) * 5
    if mos >= 4.5:
        hint.append(quality_hints["mos"]["excellent"])
    elif 3.5 <= mos < 4.5:
        hint.append(quality_hints["mos"]["good"])
    elif 2.5 <= mos < 3.5:
        hint.append(quality_hints["mos"]["fair"])
    elif 1.5 <= mos < 2.5:
        hint.append(quality_hints["mos"]["bad"])
    else:
        hint.append(quality_hints["mos"]["poor"])
    # Resolution bucket: below VGA, below 720p, otherwise HD.
    res = width * height
    if res < 640 * 480:
        hint.append(quality_hints["resolution"]["low"])
    elif res < 1280 * 720:
        hint.append(quality_hints["resolution"]["sd"])
    else:
        hint.append(quality_hints["resolution"]["hd"])
    # Bitrate bucket (bits per second).
    if bitrate < 500_000:
        hint.append(quality_hints["bitrate"]["low"])
    elif bitrate < 1_000_000:
        hint.append(quality_hints["bitrate"]["medium"])
    else:
        hint.append(quality_hints["bitrate"]["high"])
    # Bit depth: up to 8 bits is low; a value of 0 falls back to the standard hint.
    if 0 < bitdepth <= 8:
        hint.append(quality_hints["bitdepth"]["low"])
    elif bitdepth == 0:
        hint.append(quality_hints["bitdepth"]["standard"])
    else:
        hint.append(quality_hints["bitdepth"]["high"])
    # Frame rate bucket: below 24 fps, above 60 fps, or standard.
    if framerate < 24:
        hint.append(quality_hints["framerate"]["low"])
    elif framerate > 60:
        hint.append(quality_hints["framerate"]["high"])
    else:
        hint.append(quality_hints["framerate"]["standard"])
    return " ".join(hint)

def generate_caption(blip_processor, blip_model, device, image, prompt):
    inputs = blip_processor(image, prompt, return_tensors="pt").to(device)
    generated_ids = blip_model.generate(**inputs, max_new_tokens=50)
    caption = blip_processor.decode(generated_ids[0], skip_special_tokens=True)
    return caption

def tensor_to_pil(image_tensor):
    # Convert an HWC uint8 tensor or array (optionally with a leading batch
    # dimension of 1) into a PIL image; pass PIL images through unchanged.
    if isinstance(image_tensor, Image.Image):
        return image_tensor
    if isinstance(image_tensor, torch.Tensor):
        arr = image_tensor.cpu().numpy()
    else:
        arr = np.asarray(image_tensor)
    if arr.ndim == 4 and arr.shape[0] == 1:
        arr = arr[0]  # remove batch dimension
    arr = arr.astype('uint8')
    return Image.fromarray(arr)

def extract_semantic_captions(blip_processor, blip_model, curr_frame, frag_residual, frag_frame, prompts, device, metadata=None, use_metadata_prompt=False):
    quality_prompt_base = prompts["quality_prompt_base"]
    residual_prompt = prompts["residual_prompt"]
    frag_prompt = prompts["frag_prompt"]
    quality_hint = ""
    if use_metadata_prompt and metadata:
        mos, width, height, bitrate, bitdepth, framerate = metadata
        quality_hint = get_quality_hint_from_metadata(mos, width, height, bitrate, bitdepth, framerate, quality_hints=prompts["quality_hints"])
    prompt_hints = []
    if quality_hint:
        prompt_hints.append(quality_hint)
    quality_prompt = "\n\n".join(prompt_hints + [quality_prompt_base])
    fragment_prompt = "\n\n".join(prompt_hints)
    # print('content_prompt:', content_prompt)
    # print('quality_prompt:', quality_prompt)
    # print('residual_prompt:', fragment_prompt + "\n\n" + residual_prompt)
    # print('frame_fragment_prompt:', fragment_prompt + "\n\n" + frag_prompt)
    captions = {
        "curr_frame_quality": generate_caption(blip_processor, blip_model, device, curr_frame, prompt=quality_prompt),
        "frag_residual": generate_caption(blip_processor, blip_model, device, frag_residual, prompt=(fragment_prompt + "\n\n" + residual_prompt)),
        "frag_frame": generate_caption(blip_processor, blip_model, device, frag_frame, prompt=(fragment_prompt + "\n\n" + frag_prompt))
    }
    return captions

def clean_caption_text(text):
    # Strip stock-footage boilerplate phrases and collapse repeated whitespace.
    text = re.sub(r"- .*?stock videos & royalty-free footage", "", text)
    text = re.sub(r"\s+", " ", text)
    return text.strip()


def dedup_keywords(text, split_tokens=(",", ".", ";")):
    # Split the caption on common separators and drop duplicate keywords
    # while preserving their first-seen order.
    for token in split_tokens:
        text = text.replace(token, ",")
    parts = [p.strip().lower() for p in text.split(",") if p.strip()]
    seen = set()
    unique_parts = []
    for part in parts:
        if part not in seen:
            unique_parts.append(part)
            seen.add(part)
    return " ".join(unique_parts)  # good for embedding

def get_clip_text_embedding(clip_model, device, text):
    text_tokens = clip.tokenize([text]).to(device)
    with torch.no_grad():
        with torch.amp.autocast(device_type='cuda'):
            text_features = clip_model.encode_text(text_tokens)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    return text_features.squeeze()


def get_clip_image_embedding(clip_model, clip_preprocess, device, image):
    image_input = clip_preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        with torch.amp.autocast(device_type='cuda'):
            image_features = clip_model.encode_image(image_input)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    return image_features.squeeze()

def extract_semantic_embeddings(clip_model, clip_preprocess, device, curr_frame, captions):
    if not isinstance(curr_frame, Image.Image):
        curr_frame = Image.fromarray(curr_frame)
    # Clean and deduplicate the BLIP captions before embedding them with CLIP.
    quality_caption = dedup_keywords(clean_caption_text(captions["curr_frame_quality"]))
    artifact_caption_1 = dedup_keywords(clean_caption_text(captions["frag_residual"]))
    artifact_caption_2 = dedup_keywords(clean_caption_text(captions["frag_frame"]))
    artifact_caption = dedup_keywords(f"{artifact_caption_1}, {artifact_caption_2}")
    image_embed = get_clip_image_embedding(clip_model, clip_preprocess, device, curr_frame)
    quality_embed = get_clip_text_embedding(clip_model, device, quality_caption)
    artifact_embed = get_clip_text_embedding(clip_model, device, artifact_caption)
    return image_embed, quality_embed, artifact_embed

def extract_features_clip_embed(frames_info, metadata, clip_model, clip_preprocess, blip_processor, blip_model, prompts, device):
    feature_image_embed = []
    feature_quality_embed = []
    feature_artifact_embed = []
    for curr_frame, frag_residual, frag_frame in frames_info:
        curr_frame = tensor_to_pil(curr_frame)
        frag_residual = tensor_to_pil(frag_residual)
        frag_frame = tensor_to_pil(frag_frame)
        captions = extract_semantic_captions(
            blip_processor, blip_model,
            curr_frame, frag_residual, frag_frame, prompts,
            device,
            metadata=metadata,
            use_metadata_prompt=True,
        )
        image_embed, quality_embed, artifact_embed = extract_semantic_embeddings(clip_model, clip_preprocess, device, curr_frame, captions)
        feature_image_embed.append(image_embed)
        feature_quality_embed.append(quality_embed)
        feature_artifact_embed.append(artifact_embed)
    # stack per-frame features into (num_frames, embed_dim) tensors
    image_embedding = torch.stack(feature_image_embed, dim=0)
    quality_embedding = torch.stack(feature_quality_embed, dim=0)
    artifact_embedding = torch.stack(feature_artifact_embed, dim=0)
    # print("image_embedding.shape:", image_embedding.shape, "quality_embedding.shape:", quality_embedding.shape, "artifact_embedding.shape:", artifact_embedding.shape)
    return image_embedding, quality_embedding, artifact_embedding
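

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the pipeline proper. The checkpoint
    # names, prompt strings, and dummy inputs below are illustrative assumptions;
    # the real prompts and quality_hints come from the project's prompt config.
    from transformers import BlipProcessor, BlipForConditionalGeneration

    device = "cuda" if torch.cuda.is_available() else "cpu"
    clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
    blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    blip_model = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    ).to(device)

    prompts = {
        "quality_prompt_base": "Describe the visual quality of this frame:",
        "residual_prompt": "Describe the compression artifacts visible in this residual:",
        "frag_prompt": "Describe the distortions visible in these fragments:",
        "quality_hints": {
            "mos": {k: f"{k} perceptual quality." for k in ("excellent", "good", "fair", "bad", "poor")},
            "resolution": {k: f"{k} resolution." for k in ("low", "sd", "hd")},
            "bitrate": {k: f"{k} bitrate." for k in ("low", "medium", "high")},
            "bitdepth": {k: f"{k} bit depth." for k in ("low", "standard", "high")},
            "framerate": {k: f"{k} frame rate." for k in ("low", "standard", "high")},
        },
    }

    def dummy_frame():
        # Random HWC uint8 frame with a leading batch dimension of 1.
        return torch.randint(0, 256, (1, 224, 224, 3), dtype=torch.uint8)

    # Two (current frame, fragment residual, fragment frame) triplets.
    frames_info = [(dummy_frame(), dummy_frame(), dummy_frame()) for _ in range(2)]
    # (mos, width, height, bitrate, bitdepth, framerate)
    metadata = (3.8, 1920, 1080, 2_500_000, 8, 30)

    image_emb, quality_emb, artifact_emb = extract_features_clip_embed(
        frames_info, metadata, clip_model, clip_preprocess,
        blip_processor, blip_model, prompts, device,
    )
    print(image_emb.shape, quality_emb.shape, artifact_emb.shape)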