import json
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from PIL import Image
from tqdm import tqdm
from transformers import AutoProcessor, AutoTokenizer, LlavaOnevisionForConditionalGeneration

def wa5(logits):
    """Softmax the five rating logits and return their weighted average in [0, 1]."""
    logprobs = np.array([logits["Excellent"], logits["Good"], logits["Fair"], logits["Poor"], logits["Bad"]])
    probs = np.exp(logprobs) / np.sum(np.exp(logprobs))
    # Weights map Excellent -> 1.0 down to Bad -> 0.0 in steps of 0.25.
    return np.inner(probs, np.array([1, 0.75, 0.5, 0.25, 0]))
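
# Quick sanity check for wa5 (my addition, not part of the original script):
# with all five logits equal, the softmax is uniform and the weighted average
# comes out to (1 + 0.75 + 0.5 + 0.25 + 0) / 5 = 0.5.
assert abs(wa5({t: 0.0 for t in ["Excellent", "Good", "Fair", "Poor", "Bad"]}) - 0.5) < 1e-9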

model_id = "models/q-sit"
model = LlavaOnevisionForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
).to(0)
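
# If you have no GPU (an assumption about your setup, not covered by the
# original script), a CPU fallback in float32 would look like this sketch:
#     model = LlavaOnevisionForConditionalGeneration.from_pretrained(
#         model_id, torch_dtype=torch.float32, low_cpu_mem_usage=True
#     )
# The .to(0, ...) device moves below would then need to be dropped as well.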

processor = AutoProcessor.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Token ids of the five rating words; their logits are compared after generation.
# Taking index 0 assumes the tokenizer does not prepend a BOS token (if it did,
# the rating word would sit at index 1 instead).
toks = ["Excellent", "Good", "Fair", "Poor", "Bad"]
ids_ = [id_[0] for id_ in tokenizer(toks)["input_ids"]]
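
# Optional check (my addition): inspect how each rating word tokenizes; the id
# lookup above keeps only the first sub-token of each word.
#     for t in toks:
#         print(t, tokenizer.tokenize(t))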

conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Assume you are an image quality evaluator. \nYour rating should be chosen from the following five categories: Excellent, Good, Fair, Poor, and Bad (from high to low). \nHow would you rate the quality of this image?"},
            {"type": "image"},
        ],
    },
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

def predict_score(image_file):
    """Return the weighted quality score, the most likely rating, and the per-level probabilities."""
    raw_image = Image.open(image_file).convert("RGB")
    inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(0, torch.float16)

    # Append the fixed answer prefix so the next generated token is the rating word.
    prefix_text = "The quality of this image is "
    prefix_ids = tokenizer(prefix_text, return_tensors="pt")["input_ids"].to(0)
    inputs["input_ids"] = torch.cat([inputs["input_ids"], prefix_ids], dim=-1)
    inputs["attention_mask"] = torch.ones_like(inputs["input_ids"])

    # Generate a single token and read the logits of the five rating words from it.
    output = model.generate(**inputs, max_new_tokens=1, output_logits=True, return_dict_in_generate=True)
    last_logits = output.logits[-1][0]
    logits_dict = {tok: last_logits[id_].item() for tok, id_ in zip(toks, ids_)}
    weighted_score = wa5(logits_dict)

    # Softmax over the five logits to get a probability per rating level.
    probs = {tok: np.exp(logits_dict[tok]) for tok in toks}
    total = sum(probs.values())
    probs = {tok: prob / total for tok, prob in probs.items()}
    rating = max(probs, key=probs.get)

    return weighted_score, rating, probs
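
# Note (my addition): the plain exp-softmax in predict_score can overflow for
# very large logits. A numerically safer sketch subtracts the max logit first,
# which leaves the normalized probabilities unchanged:
#     m = max(logits_dict.values())
#     probs = {tok: np.exp(logits_dict[tok] - m) for tok in toks}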

image_dir = Path("./your_images")
# Note: on case-sensitive filesystems these patterns skip e.g. "*.JPG" files.
image_files = sorted(list(image_dir.glob("*.jpg")) + list(image_dir.glob("*.png")))

print(f"Found {len(image_files)} images")

results = []
for img_path in tqdm(image_files, desc="Processing images"):
    try:
        score, rating, probs = predict_score(img_path)
        results.append({
            "image": str(img_path.name),
            "path": str(img_path),
            "score_0_1": round(float(score), 4),
            "score_0_5": round(float(score * 5), 4),
            "rating": rating,
            "probs": {k: round(v, 4) for k, v in probs.items()},
        })
    except Exception as e:
        print(f"Error processing {img_path.name}: {e}")

with open("quality_scores.json", "w") as f:
    json.dump(results, f, indent=2)

df = pd.DataFrame([{
    "image": r["image"],
    "score_0_5": r["score_0_5"],
    "rating": r["rating"],
    "prob_excellent": r["probs"]["Excellent"],
    "prob_good": r["probs"]["Good"],
    "prob_fair": r["probs"]["Fair"],
    "prob_poor": r["probs"]["Poor"],
    "prob_bad": r["probs"]["Bad"],
} for r in results])
df.to_csv("quality_scores.csv", index=False)

df_sorted = df.sort_values("score_0_5", ascending=False)
print("\n=== Top 10 High Quality Images ===")
print(df_sorted.head(10)[["image", "score_0_5", "rating"]])
print("\n=== Bottom 10 Low Quality Images ===")
print(df_sorted.tail(10)[["image", "score_0_5", "rating"]])

print("\nResults saved to quality_scores.json and quality_scores.csv")