tooba248 committed on
Commit
7572379
·
verified ·
1 Parent(s): e403cae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -71
app.py CHANGED
@@ -1,96 +1,98 @@
1
- import gradio as gr
2
  import torch
3
  import clip
4
  from datasets import load_dataset
5
  from PIL import Image
6
- import faiss
 
7
  import requests
8
  from io import BytesIO
 
 
9
 
 
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
 
12
- # 1) Load base CLIP model + preprocess
13
  model_clip, preprocess = clip.load("ViT-B/32", device=device)
14
 
15
- # 2) Load your finetuned weights (state_dict) into model_clip
16
- state_dict = torch.load("best_model.pt", map_location=device)
17
- missing, unexpected = model_clip.load_state_dict(state_dict, strict=False)
18
- print(f"⚠️ Missing keys: {missing}\n⚠️ Unexpected keys: {unexpected}")
19
  model_clip.eval()
20
 
21
- # 3) Build retrieval pool from Flickr30k test split
22
- dataset = load_dataset("nlphuji/flickr30k", split="test")
23
 
24
- images, captions = [], []
25
- img_embs, txt_embs = [], []
 
 
 
26
 
27
- print("🔄 Preparing retrieval pool embeddings...")
28
- for example in dataset:
 
29
  try:
30
- # load & store raw image + caption
31
  img = Image.open(requests.get(example["image"], stream=True).raw).convert("RGB")
 
 
 
 
 
32
  images.append(img)
33
  captions.append(example["sentence"])
 
 
 
 
34
 
35
- # encode image
36
- img_t = preprocess(img).unsqueeze(0).to(device)
37
- with torch.no_grad():
38
- v = model_clip.encode_image(img_t)
39
- v /= v.norm(dim=-1, keepdim=True)
40
- img_embs.append(v.cpu())
41
 
42
- # encode text
43
- t = clip.tokenize([example["sentence"]]).to(device)
44
- with torch.no_grad():
45
- tfeat = model_clip.encode_text(t)
46
- tfeat /= tfeat.norm(dim=-1, keepdim=True)
47
- txt_embs.append(tfeat.cpu())
48
- except:
49
- continue
50
 
51
- # cat into tensors
52
- img_embs = torch.cat(img_embs, dim=0)
53
- txt_embs = torch.cat(txt_embs, dim=0)
 
 
 
 
54
 
55
- # build FAISS indices (Inner‐Product = cosine)
56
- img_index = faiss.IndexFlatIP(img_embs.shape[1])
57
- img_index.add(img_embs.numpy())
 
 
 
 
58
 
59
- txt_index = faiss.IndexFlatIP(txt_embs.shape[1])
60
- txt_index.add(txt_embs.numpy())
 
 
 
 
 
 
 
61
 
62
- # 4) Gradio callbacks
63
- def image_to_text(inp_img):
64
- im = preprocess(inp_img).unsqueeze(0).to(device)
65
- with torch.no_grad():
66
- v = model_clip.encode_image(im)
67
- v /= v.norm(dim=-1, keepdim=True)
68
- D, I = txt_index.search(v.cpu().numpy(), 1)
69
- score = D[0][0] * 100
70
- return f"{captions[I[0][0]]}\n(Match Score: {score:.2f}%)"
71
-
72
- def text_to_image(inp_txt):
73
- tok = clip.tokenize([inp_txt]).to(device)
74
- with torch.no_grad():
75
- t = model_clip.encode_text(tok)
76
- t /= t.norm(dim=-1, keepdim=True)
77
- D, I = img_index.search(t.cpu().numpy(), 1)
78
- score = D[0][0] * 100
79
- return images[I[0][0]], f"Match Score: {score:.2f}%"
80
-
81
- # 5) Gradio UI
82
- with gr.Blocks() as demo:
83
- gr.Markdown("## 🔄 Cross-Modal Retriever (Flickr30k Test Split)\nUpload an image or enter text to retrieve the best match.")
84
-
85
- with gr.Tab("🖼️ Image → Text"):
86
- img_in = gr.Image(type="pil", label="Upload Image")
87
- txt_out = gr.Textbox(label="Retrieved Caption")
88
- gr.Button("Search Caption").click(image_to_text, img_in, txt_out)
89
-
90
- with gr.Tab("📝 Text → Image"):
91
- txt_in = gr.Textbox(label="Enter Text")
92
- img_out = gr.Image(label="Retrieved Image")
93
- score_out = gr.Textbox(label="Score")
94
- gr.Button("Search Image").click(text_to_image, txt_in, [img_out, score_out])
95
-
96
- demo.launch()
 
 
1
  import torch
2
  import clip
3
  from datasets import load_dataset
4
  from PIL import Image
5
+ import gradio as gr
6
+ from torchvision import transforms
7
  import requests
8
  from io import BytesIO
9
+ import numpy as np
10
+ import faiss
11
 
12
# Set device: prefer GPU when available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the base CLIP ViT-B/32 model and its matching image preprocessing pipeline.
model_clip, preprocess = clip.load("ViT-B/32", device=device)

# Load the fine-tuned weights. Use strict=False so harmless key mismatches
# (e.g. a checkpoint saved from a wrapped/partial model) do not crash startup;
# the previous revision of this file did the same. Report mismatches for review.
fine_tuned_state_dict = torch.load("best_model.pt", map_location=device)
missing, unexpected = model_clip.load_state_dict(fine_tuned_state_dict, strict=False)
if missing or unexpected:
    print(f"Warning: missing keys: {missing}; unexpected keys: {unexpected}")

# Inference only: disable dropout/batch-norm training behavior.
model_clip.eval()
23
 
24
# Load a small retrieval pool: the first 50 samples of the Flickr30k test split.
dataset = load_dataset("nlphuji/flickr30k", split="test[:50]")

# Precompute one L2-normalized CLIP image embedding per successfully loaded sample.
image_embeddings = []
images = []
captions = []
valid_indices = []  # original dataset indices of the samples kept

print("Extracting embeddings...")

for i, example in enumerate(dataset):
    try:
        # NOTE(review): this assumes example["image"] is a URL string; on many
        # HF image datasets the "image" column is already a decoded PIL.Image —
        # confirm the schema before relying on the HTTP fetch path.
        img = Image.open(requests.get(example["image"], stream=True).raw).convert("RGB")
        img_tensor = preprocess(img).unsqueeze(0).to(device)
        with torch.no_grad():
            img_feat = model_clip.encode_image(img_tensor)
            # Normalize so that inner product == cosine similarity.
            img_feat /= img_feat.norm(dim=-1, keepdim=True)
        # .float() matters on CUDA: CLIP emits float16 there, but
        # faiss.IndexFlatIP.add() accepts only float32 arrays.
        image_embeddings.append(img_feat.float().cpu())
        images.append(img)
        captions.append(example["sentence"])
        valid_indices.append(i)
    except Exception as e:
        # Best-effort pool build: skip unreadable samples but say why.
        print(f"Skipping sample {i} due to error: {e}")
        continue

# Stack the per-sample (1, D) features into a single (N, D) float32 matrix.
image_embeddings = torch.cat(image_embeddings, dim=0)

# Build an inner-product FAISS index; with normalized vectors this ranks by
# cosine similarity.
image_index = faiss.IndexFlatIP(image_embeddings.shape[1])
image_index.add(image_embeddings.numpy())
 
 
 
 
 
56
 
57
# Search function: text query -> top matching (image, caption) pairs.
def search_by_text(query):
    """Return up to 5 (image, caption) pairs from the pool best matching `query`.

    The query is embedded with the fine-tuned CLIP text encoder, L2-normalized,
    and matched against the precomputed image index by inner product (cosine).
    """
    with torch.no_grad():
        tokens = clip.tokenize([query]).to(device)
        text_feat = model_clip.encode_text(tokens)
        text_feat /= text_feat.norm(dim=-1, keepdim=True)
        # float32 conversion: CLIP on CUDA yields float16, which FAISS rejects.
        text_feat_np = text_feat.float().cpu().numpy()

    # Never request more neighbors than the index holds: FAISS pads short
    # results with index -1, which would silently resolve to images[-1].
    k = min(5, image_index.ntotal)
    D, I = image_index.search(text_feat_np, k)
    results = []
    for idx in I[0]:
        if idx < 0:  # defensive: skip any FAISS padding entries
            continue
        results.append((images[idx], captions[idx]))
    return results
72
 
73
# Gradio callback: turn retrieval results into a single Markdown document.
def display_results(text_query):
    """Render the top matches for `text_query` as Markdown.

    Each result becomes a heading, its caption, and the image inlined as a
    base64 PNG data URI.
    """
    sections = []
    for rank, (img, caption) in enumerate(search_by_text(text_query), start=1):
        sections.append(f"### Result {rank}\n")
        sections.append(f"**Caption:** {caption}\n\n")
        sections.append(f"![img](data:image/png;base64,{image_to_base64(img)})\n\n")
    return "".join(sections)
82
 
83
# Helper: serialize a PIL image to a base64-encoded PNG string.
import base64
from io import BytesIO

def image_to_base64(image):
    """Encode `image` as an in-memory PNG and return its base64 text form."""
    with BytesIO() as buffer:
        image.save(buffer, format="PNG")
        encoded = base64.b64encode(buffer.getvalue())
    return encoded.decode()
91
+
92
# Gradio UI: one text box in, a Markdown document (captions + inline images) out.
iface = gr.Interface(
    fn=display_results,
    inputs=gr.Textbox(lines=2, placeholder="Enter text to search..."),
    outputs="markdown",
    title="Text-to-Image Retrieval with CLIP",
    description="Enter a sentence to retrieve similar images using a fine-tuned CLIP model.",
)

# Start the web app.
iface.launch()