Spaces:
Build error
Build error
Add submit button
Browse files
- app.py +12 -4
- fromage/models.py +14 -10
app.py
CHANGED
|
@@ -127,19 +127,27 @@ with gr.Blocks(css=css) as demo:
|
|
| 127 |
share_button = gr.Button("Share to community", elem_id="share-btn")
|
| 128 |
|
| 129 |
with gr.Row():
|
| 130 |
-
with gr.Column(scale=0.3, min_width=
|
| 131 |
ret_scale_factor = gr.Slider(minimum=0.0, maximum=3.0, value=1.0, step=0.1, interactive=True, label="Multiplier for returning images (higher means more frequent)")
|
| 132 |
max_ret_images = gr.Number(minimum=0, maximum=3, value=1, precision=1, interactive=True, label="Max images to return")
|
| 133 |
gr_max_len = gr.Number(value=32, precision=1, label="Max # of words returned", interactive=True)
|
| 134 |
gr_temperature = gr.Number(value=0.0, label="Temperature", interactive=True)
|
| 135 |
|
| 136 |
-
with gr.Column(scale=0.7, min_width=
|
| 137 |
image_btn = gr.UploadButton("🖼️ Image Input", file_types=["image"])
|
| 138 |
-
text_input = gr.Textbox(label="
|
| 139 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
|
| 141 |
text_input.submit(generate_for_prompt, [text_input, gr_state, ret_scale_factor, max_ret_images, gr_max_len, gr_temperature], [gr_state, chatbot])
|
| 142 |
text_input.submit(lambda: "", None, text_input) # Reset chatbox.
|
|
|
|
|
|
|
|
|
|
| 143 |
image_btn.upload(upload_image, [gr_state, image_btn], [gr_state, chatbot])
|
| 144 |
clear_btn.click(reset, [], [gr_state, chatbot])
|
| 145 |
share_button.click(None, [], [], _js=share_js)
|
|
|
|
| 127 |
share_button = gr.Button("Share to community", elem_id="share-btn")
|
| 128 |
|
| 129 |
with gr.Row():
|
| 130 |
+
with gr.Column(scale=0.3, min_width=100):
|
| 131 |
ret_scale_factor = gr.Slider(minimum=0.0, maximum=3.0, value=1.0, step=0.1, interactive=True, label="Multiplier for returning images (higher means more frequent)")
|
| 132 |
max_ret_images = gr.Number(minimum=0, maximum=3, value=1, precision=1, interactive=True, label="Max images to return")
|
| 133 |
gr_max_len = gr.Number(value=32, precision=1, label="Max # of words returned", interactive=True)
|
| 134 |
gr_temperature = gr.Number(value=0.0, label="Temperature", interactive=True)
|
| 135 |
|
| 136 |
+
with gr.Column(scale=0.7, min_width=400):
|
| 137 |
image_btn = gr.UploadButton("🖼️ Image Input", file_types=["image"])
|
| 138 |
+
text_input = gr.Textbox(label="Chat Input", lines=1, placeholder="Upload an image above [optional]. Then enter a text prompt, and press enter!")
|
| 139 |
+
|
| 140 |
+
with gr.Row():
|
| 141 |
+
with gr.Column(scale=0.5):
|
| 142 |
+
submit_btn = gr.Button("Submit", interactive=True, variant="primary")
|
| 143 |
+
with gr.Column(scale=0.5):
|
| 144 |
+
clear_btn = gr.Button("Clear History")
|
| 145 |
|
| 146 |
text_input.submit(generate_for_prompt, [text_input, gr_state, ret_scale_factor, max_ret_images, gr_max_len, gr_temperature], [gr_state, chatbot])
|
| 147 |
text_input.submit(lambda: "", None, text_input) # Reset chatbox.
|
| 148 |
+
submit_btn.click(generate_for_prompt, [text_input, gr_state, ret_scale_factor, max_ret_images, gr_max_len, gr_temperature], [gr_state, chatbot])
|
| 149 |
+
submit_btn.click(lambda: "", None, text_input) # Reset chatbox.
|
| 150 |
+
|
| 151 |
image_btn.upload(upload_image, [gr_state, image_btn], [gr_state, chatbot])
|
| 152 |
clear_btn.click(reset, [], [gr_state, chatbot])
|
| 153 |
share_button.click(None, [], [], _js=share_js)
|
fromage/models.py
CHANGED
|
@@ -634,21 +634,25 @@ def load_fromage(embeddings_dir: str, model_args_path: str, model_ckpt_path: str
|
|
| 634 |
ret_token_idx = tokenizer('[RET]', add_special_tokens=False).input_ids
|
| 635 |
assert len(ret_token_idx) == 1, ret_token_idx
|
| 636 |
model_kwargs['retrieval_token_idx'] = ret_token_idx[0]
|
| 637 |
-
|
| 638 |
-
|
|
|
|
|
|
|
|
|
|
| 639 |
args = namedtuple('args', model_kwargs)(**model_kwargs)
|
| 640 |
|
| 641 |
# Initialize model for inference.
|
| 642 |
model = Fromage(tokenizer, args, path_array=path_array, emb_matrix=emb_matrix)
|
| 643 |
model = model.eval()
|
| 644 |
-
|
| 645 |
-
|
| 646 |
-
|
| 647 |
-
|
| 648 |
-
|
| 649 |
-
|
| 650 |
-
|
| 651 |
-
|
|
|
|
| 652 |
|
| 653 |
logit_scale = model.model.logit_scale.exp()
|
| 654 |
emb_matrix = torch.tensor(emb_matrix, dtype=logit_scale.dtype).to(logit_scale.device)
|
|
|
|
| 634 |
ret_token_idx = tokenizer('[RET]', add_special_tokens=False).input_ids
|
| 635 |
assert len(ret_token_idx) == 1, ret_token_idx
|
| 636 |
model_kwargs['retrieval_token_idx'] = ret_token_idx[0]
|
| 637 |
+
|
| 638 |
+
debug = False
|
| 639 |
+
if debug:
|
| 640 |
+
model_kwargs['opt_version'] = 'facebook/opt-125m'
|
| 641 |
+
model_kwargs['visual_encoder'] = 'openai/clip-vit-base-patch32'
|
| 642 |
args = namedtuple('args', model_kwargs)(**model_kwargs)
|
| 643 |
|
| 644 |
# Initialize model for inference.
|
| 645 |
model = Fromage(tokenizer, args, path_array=path_array, emb_matrix=emb_matrix)
|
| 646 |
model = model.eval()
|
| 647 |
+
if not debug:
|
| 648 |
+
model = model.bfloat16()
|
| 649 |
+
model = model.cuda()
|
| 650 |
+
|
| 651 |
+
# Load pretrained linear mappings and [RET] embeddings.
|
| 652 |
+
checkpoint = torch.load(model_ckpt_path)
|
| 653 |
+
model.load_state_dict(checkpoint['state_dict'], strict=False)
|
| 654 |
+
with torch.no_grad():
|
| 655 |
+
model.model.input_embeddings.weight[model.model.retrieval_token_idx, :].copy_(checkpoint['state_dict']['ret_input_embeddings.weight'].cpu().detach())
|
| 656 |
|
| 657 |
logit_scale = model.model.logit_scale.exp()
|
| 658 |
emb_matrix = torch.tensor(emb_matrix, dtype=logit_scale.dtype).to(logit_scale.device)
|