"""CAMP-VQA feature extraction entry point (src/main_camp-vqa.py)."""
import argparse
import os
import time
import json
import logging
import torch
from torchvision import transforms
from tqdm import tqdm
import clip
from transformers import Blip2Processor, Blip2ForConditionalGeneration
from extractor.extract_frag import VideoDataset_feature
from extractor.extract_clip_embeds import extract_features_clip_embed
from extractor.extract_slowfast_clip import SlowFast, extract_features_slowfast_pool
from extractor.extract_swint_clip import SwinT, extract_features_swint_pool
from utils.logger_setup import setup_logging
def setup_device():
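    """Select the CUDA device if available, otherwise fall back to CPU."""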
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
torch.cuda.set_device(0)
logging.info(f"Using device: {device}")
return device
def load_prompts(json_path):
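    """Load prompt definitions (used for the semantic embeddings) from a JSON file."""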
with open(json_path, "r", encoding="utf-8") as f:
return json.load(f)
def get_transform(resize):
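    """Square-resize and normalize frames. The 0.45/0.225 statistics are the
    Kinetics normalization values commonly used for pretrained video backbones."""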
return transforms.Compose([transforms.Resize([resize, resize]),
transforms.ToTensor(),
transforms.Normalize(mean=[0.45, 0.45, 0.45], std=[0.225, 0.225, 0.225])])
# Machine-specific video directories for each supported database; adjust for your setup.
VIDEOS_DIRS = {
    'konvid_1k': '/media/on23019/server1/video_dataset/KoNViD_1k/KoNViD_1k_videos/',
    'live_vqc': '/media/on23019/server1/video_dataset/LIVE-VQC/Video/',
    'cvd_2014': '/media/on23019/server1/video_dataset/CVD2014/',
    'youtube_ugc_h264': '/media/on23019/server1/video_dataset/youtube_ugc/all_h264/',
    'live_yt_gaming': '/media/on23019/server1/video_dataset/LIVE-YT-Gaming/mp4/',
    'kvq': '/media/on23019/server1/video_dataset/KVQ/',
    'finevd': '/media/on23019/server1/video_dataset/FineVD/videos_all/videos_all_videos/',
    'finevd_test': '/media/on23019/server1/video_dataset/FineVD/videos_all/videos_all_videos/',
    'lsvq_train': '/media/on23019/server1/LSVQ/',
    'lsvq_test': '/media/on23019/server1/LSVQ/',
    'lsvq_test_1080p': '/media/on23019/server1/LSVQ/',
    'test': '../test_videos/',
}
def load_dataset(database, resize_transform, resize, patch_size, target_size, top_n):
    """Resolve the video directory and metadata CSV for `database` and build the dataset."""
    if database not in VIDEOS_DIRS:
        raise ValueError(f"Unknown database: {database!r}. Supported: {sorted(VIDEOS_DIRS)}")
    videos_dir = VIDEOS_DIRS[database]
    metadata_csv = f'../metadata/{database.upper()}_metadata.csv'
    logging.info(f"Metadata CSV: {metadata_csv}")
    return VideoDataset_feature(metadata_csv, videos_dir, database, resize_transform,
                                resize, patch_size, target_size, top_n)
def process_videos(data_loader, model_slowfast, model_swint, clip_model, clip_preprocess, blip_processor, blip_model, prompts, device):
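    """Extract SlowFast, Swin, and CLIP/BLIP semantic features for each video.
    For every backbone, features from the full frames, residual fragments, and
    frame fragments are averaged over time and concatenated; the fused VQA
    feature concatenates all three backbones. Returns the four feature lists.
    """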
slowfast_embs_list = []
swint_embs_list = []
semantic_embs_list = []
features_list = []
model_slowfast.eval()
model_swint.eval()
with torch.no_grad():
for i, (video_segments, video_res_frag_all, video_frag_all, video_name, frames_info, metadata) in enumerate(tqdm(data_loader, desc="Processing Videos")):
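            # DataLoader batch_size is 1, so each iteration yields a single video
            # with its frame, residual-fragment, and frame-fragment views.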
start_time = time.time()
try:
# slowfast features
_, _, slowfast_frame_feats = extract_features_slowfast_pool(video_segments, model_slowfast, device)
_, _, slowfast_res_frag_feats = extract_features_slowfast_pool(video_res_frag_all, model_slowfast, device)
_, _, slowfast_frame_frag_feats = extract_features_slowfast_pool(video_frag_all, model_slowfast, device)
slowfast_frame_feats_avg = slowfast_frame_feats.mean(dim=0)
slowfast_res_frag_feats_avg = slowfast_res_frag_feats.mean(dim=0)
slowfast_frame_frag_feats_avg = slowfast_frame_frag_feats.mean(dim=0)
# swinT feature
swint_frame_feats = extract_features_swint_pool(video_segments, model_swint, device)
swint_res_frag_feats = extract_features_swint_pool(video_res_frag_all, model_swint, device)
swint_frame_frag_feats = extract_features_swint_pool(video_frag_all, model_swint, device)
swint_frame_feats_avg = swint_frame_feats.mean(dim=0)
swint_res_frag_feats_avg = swint_res_frag_feats.mean(dim=0)
swint_frame_frag_feats_avg = swint_frame_frag_feats.mean(dim=0)
# semantic features
image_embedding, quality_embedding, artifact_embedding = extract_features_clip_embed(frames_info, metadata, clip_model, clip_preprocess, blip_processor, blip_model, prompts, device)
image_embedding_avg = image_embedding.mean(dim=0)
quality_embedding_avg = quality_embedding.mean(dim=0)
artifact_embedding_avg = artifact_embedding.mean(dim=0)
# frame + residual fragment + frame fragment features
slowfast_features = torch.cat((slowfast_frame_feats_avg, slowfast_res_frag_feats_avg, slowfast_frame_frag_feats_avg), dim=0)
swint_features = torch.cat((swint_frame_feats_avg, swint_res_frag_feats_avg, swint_frame_frag_feats_avg), dim=0)
clip_features = torch.cat((image_embedding_avg, quality_embedding_avg, artifact_embedding_avg), dim=0)
logging.info(f"Slowfast Feature shape: {slowfast_features.shape}")
logging.info(f"SwinT Feature shape: {swint_features.shape}")
logging.info(f"Semantic CLIP Feature shape: {clip_features.shape}")
                # Move features to CPU so accumulated results don't hold GPU memory.
                slowfast_embs_list.append(slowfast_features.cpu())
                swint_embs_list.append(swint_features.cpu())
                semantic_embs_list.append(clip_features.cpu())
                vqa_feats = torch.cat((slowfast_features, swint_features, clip_features), dim=0)
                logging.info(f"VQA Feature shape: {vqa_feats.shape}")
                features_list.append(vqa_feats.cpu())
logging.info(f"Processed {video_name[0]} in {time.time() - start_time:.2f} seconds")
torch.cuda.empty_cache()
            except KeyboardInterrupt:
                # Abort the whole run instead of skipping to the next video.
                logging.error("Processing interrupted by user.")
                break
            except Exception as e:
                # Skip the failing video and continue; note that the saved feature
                # rows will then no longer align one-to-one with the metadata CSV.
                logging.error(f"Failed to process video {video_name[0]}: {e}")
return slowfast_embs_list, swint_embs_list, semantic_embs_list, features_list
def save_features(features_list, pt_path):
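    """Stack per-video features into a single tensor and write it to pt_path."""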
if features_list:
features_tensor = torch.stack(features_list)
        try:
            torch.save(features_tensor, pt_path)
            logging.info(f"Features saved to {pt_path}: {features_tensor.shape}\n")
        except Exception as e:
            logging.error(f"Failed to save features to {pt_path}: {e}")
else:
logging.warning("No features processed. Nothing to save.")
def main(config):
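    """Build output paths, load models and data, extract features, and save them."""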
feature_save_path = os.path.abspath(os.path.join(config.feature_save_path, config.feat_name))
slowfast_pt_path = f'{feature_save_path}/slowfast_{config.database}_features.pt'
swint_pt_path = f'{feature_save_path}/swint_{config.database}_features.pt'
semantic_pt_path = f'{feature_save_path}/semantic_embs_{config.database}_features.pt'
vqa_pt_path = f'{feature_save_path}/{config.feat_name}_{config.database}_features.pt'
    setup_logging(config.log_file)
    logging.info(vqa_pt_path)
    os.makedirs(feature_save_path, exist_ok=True)
    prompts = load_prompts(config.prompt_path)
    device = setup_device()
resize_transform = get_transform(config.resize)
dataset = load_dataset(config.database, resize_transform, config.resize, config.patch_size, config.target_size, config.top_n)
data_loader = torch.utils.data.DataLoader(
dataset, batch_size=1, shuffle=False, num_workers=min(config.num_workers, os.cpu_count()), pin_memory=True
)
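    # Videos differ in length and fragment counts, so they are loaded one per batch.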
logging.info(f"Dataset loaded. Total videos: {len(dataset)}, Total batches: {len(data_loader)}")
model_slowfast = SlowFast().to(device)
model_swint = SwinT(model_name='swin_large_patch4_window7_224', global_pool='avg', pretrained=True).to(device) # swin_large_patch4_window7_224.ms_in22k_ft_in1k
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl", use_fast=True)
blip_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl").to(device)
slowfast_embs_list, swint_embs_list, semantic_embs_list, features_list = process_videos(data_loader, model_slowfast, model_swint, clip_model, clip_preprocess, blip_processor, blip_model, prompts, device)
save_features(slowfast_embs_list, slowfast_pt_path)
save_features(swint_embs_list, swint_pt_path)
save_features(semantic_embs_list, semantic_pt_path)
save_features(features_list, vqa_pt_path)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
    parser.add_argument('--database', type=str, default='konvid_1k', help='dataset key (must be listed in VIDEOS_DIRS)')
    parser.add_argument('--feat_name', type=str, default='camp-vqa', help='name prefix for the fused feature file')
    parser.add_argument('--num_workers', type=int, default=4, help='DataLoader worker processes')
    parser.add_argument('--resize', type=int, default=224, help='square resize applied to sampled frames')
    parser.add_argument('--patch_size', type=int, default=16, help='patch size used when sampling fragments')
    parser.add_argument('--target_size', type=int, default=224, help='output size of the fragment views')
    parser.add_argument('--top_n', type=int, default=14 * 14, help='number of patches kept per frame (14x14 grid)')
    parser.add_argument('--feature_save_path', type=str, default='../features', help='directory for saved feature tensors')
    parser.add_argument('--prompt_path', type=str, default='./config/prompts.json', help='JSON file with prompt definitions')
    parser.add_argument('--log_file', type=str, default='./utils/logging_vqa_feats.log', help='log file path')
config = parser.parse_args()
print(config.feat_name)
main(config)
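# Example invocation (dataset paths in VIDEOS_DIRS and ../metadata must exist locally):
#   python main_camp-vqa.py --database konvid_1k --feat_name camp-vqa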