Bug fixes
Files changed:
- .gitignore +1 -0
- app.py +19 -15
- fromage/models.py +3 -2
.gitignore CHANGED
@@ -1 +1,2 @@
 .DS_Store
+venv/
app.py CHANGED
@@ -19,13 +19,15 @@ model = models.load_fromage('./', args_path, ckpt_path)
 
 
 def upload_image(state, image_input):
-
+    conversation = state[0]
+    chat_history = state[1]
+    conversation += [(f"", "")]
     input_image = Image.open(image_input.name).resize((224, 224)).convert('RGB')
-    return [
+    return [conversation, chat_history, input_image], conversation
 
 
 def reset():
-    return [[], None], []
+    return [[], [], None], []
 
 
 def save_image_to_local(image: Image.Image):
@@ -37,16 +39,19 @@ def save_image_to_local(image: Image.Image):
 
 def generate_for_prompt(input_text, state, ret_scale_factor, max_nm_rets, num_words, temperature):
     input_prompt = 'Q: ' + input_text + '\nA:'
-
-    chat_history
+    conversation = state[0]
+    chat_history = state[1]
+    input_image = state[2]
     print('Generating for', chat_history, flush=True)
 
     # If an image was uploaded, prepend it to the model.
     model_inputs = None
     if input_image is not None:
-        model_inputs = [input_image
+        model_inputs = chat_history + [input_image]
     else:
-        model_inputs =
+        model_inputs = chat_history
+
+    model_inputs.append(input_prompt)
 
     top_p = 1.0
     if temperature != 0.0:
@@ -74,15 +79,13 @@ def generate_for_prompt(input_text, state, ret_scale_factor, max_nm_rets, num_words, temperature):
         response += f'<img src="/file={filename}">'
 
     # TODO(jykoh): Persist image inputs.
-    chat_history
-
-    chat_history += '\n'
-
-    state.append((input_text, response))
+    chat_history = model_inputs + model_outputs
+    conversation.append((input_text, response))
 
     # Set input image to None.
     print('state', state, flush=True)
-
+    print('updated state', [conversation, chat_history, None], flush=True)
+    return [conversation, chat_history, None], conversation
 
 
 with gr.Blocks() as demo:
@@ -91,7 +94,7 @@ with gr.Blocks() as demo:
     )
 
     chatbot = gr.Chatbot()
-    gr_state = gr.State([[], None])  # chat_history, input_image
+    gr_state = gr.State([[], [], None])  # conversation, chat_history, input_image
 
     with gr.Row():
         with gr.Column(scale=0.3, min_width=0):
@@ -106,7 +109,8 @@ with gr.Blocks() as demo:
         clear_btn = gr.Button("Clear History")
 
     text_input.submit(generate_for_prompt, [text_input, gr_state, ret_scale_factor, max_ret_images, gr_max_len, gr_temperature], [gr_state, chatbot])
+    text_input.submit(lambda: "", None, text_input)  # Reset chatbox.
     image_btn.upload(upload_image, [gr_state, image_btn], [gr_state, chatbot])
     clear_btn.click(reset, [], [gr_state, chatbot])
 
-demo.launch(share=False, debug=True, server_name="
+demo.launch(share=False, debug=True, server_name="127.0.0.1")
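For reference, the new app.py threads a three-slot gr.State through every event handler. Below is a minimal, self-contained sketch of that pattern, not the Space's actual code: the FROMAGe model call is replaced by an echo stub, and the component layout is an assumption.

# Sketch of the [conversation, chat_history, input_image] state pattern
# used above. Hypothetical stand-alone demo; the model call is stubbed.
import gradio as gr

def upload_image(state, image_file):
    conversation, chat_history, _ = state
    conversation = conversation + [("(image uploaded)", "")]
    # Stash the file in the third slot for the next generate call.
    return [conversation, chat_history, image_file], conversation

def generate(input_text, state):
    conversation, chat_history, input_image = state
    # A real model would consume chat_history (and the image) here.
    response = 'echo: ' + input_text
    chat_history = chat_history + [input_text, response]
    conversation = conversation + [(input_text, response)]
    # Clear the image slot after one use, as the commit does.
    return [conversation, chat_history, None], conversation

def reset():
    return [[], [], None], []

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    gr_state = gr.State([[], [], None])  # conversation, chat_history, input_image
    text_input = gr.Textbox()
    image_btn = gr.UploadButton('Upload image')
    clear_btn = gr.Button('Clear History')

    text_input.submit(generate, [text_input, gr_state], [gr_state, chatbot])
    text_input.submit(lambda: '', None, text_input)  # reset the textbox
    image_btn.upload(upload_image, [gr_state, image_btn], [gr_state, chatbot])
    clear_btn.click(reset, [], [gr_state, chatbot])

demo.launch(server_name='127.0.0.1')

Keeping the Chatbot-facing conversation (a list of (user, bot) string pairs) separate from the model-facing chat_history (raw prompts, outputs, and the optional image) is what lets reset() and the post-generation cleanup clear one without corrupting the other.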
fromage/models.py CHANGED
@@ -628,13 +628,14 @@ def load_fromage(embeddings_dir: str, model_args_path: str, model_ckpt_path: str
 
   # Initialize tokenizer.
   tokenizer = GPT2Tokenizer.from_pretrained(model_kwargs['opt_version'])
-  tokenizer.pad_token = tokenizer.eos_token
   # Add special tokens to the model to enable [RET].
   tokenizer.add_special_tokens({"cls_token": "<|image|>"})
   tokenizer.add_tokens('[RET]')
   ret_token_idx = tokenizer('[RET]', add_special_tokens=False).input_ids
   assert len(ret_token_idx) == 1, ret_token_idx
   model_kwargs['retrieval_token_idx'] = ret_token_idx[0]
+  # model_kwargs['opt_version'] = 'facebook/opt-125m'
+  # model_kwargs['visual_encoder'] = 'openai/clip-vit-base-patch32'
   args = namedtuple('args', model_kwargs)(**model_kwargs)
 
   # Initialize model for inference.
@@ -643,7 +644,7 @@ def load_fromage(embeddings_dir: str, model_args_path: str, model_ckpt_path: str
   model = model.bfloat16()
   model = model.cuda()
 
-
+  # Load pretrained linear mappings and [RET] embeddings.
   checkpoint = torch.load(model_ckpt_path)
   model.load_state_dict(checkpoint['state_dict'], strict=False)
   with torch.no_grad():
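A note on the strict=False above: the demo checkpoint stores only the trained parameters (per the new comment, the linear mappings and [RET] embeddings), while the frozen OPT and CLIP weights keep the values loaded by from_pretrained. A minimal sketch of that partial-loading behavior, with hypothetical names:

import torch

def load_partial_checkpoint(model: torch.nn.Module, ckpt_path: str) -> torch.nn.Module:
    # The checkpoint contains only a subset of the model's parameters.
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    # strict=False skips state-dict keys missing from the checkpoint instead
    # of raising, so the frozen pretrained weights keep their current values.
    missing, unexpected = model.load_state_dict(checkpoint['state_dict'], strict=False)
    print(f'left {len(missing)} pretrained tensors untouched; '
          f'ignored {len(unexpected)} unexpected keys')
    return model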