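# Gradio demo comparing answers and CPU inference time for three SQuAD v2
# question-answering checkpoints: a base BERT-large model, a pruned variant,
# and a pruned model exported to ONNX with FP16 optimizations.
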
import time

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from onnxruntime import InferenceSession
from transformers import AutoModelForQuestionAnswering, AutoTokenizer
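
# Hard cap on the number of question + context tokens fed to the models.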
MAX_SEQUENCE_LENGTH = 512
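
# Hub repo ids for each model choice exposed in the dropdown.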
models = {
    "Base model": "madlag/bert-large-uncased-whole-word-masking-finetuned-squadv2",
    "Pruned model": "madlag/bert-large-uncased-wwm-squadv2-x2.63-f82.6-d16-hybrid-v1",
    "Pruned ONNX Optimized FP16": "tryolabs/bert-large-uncased-wwm-squadv2-optimized-f16",
}
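
# Load everything once at startup: the ONNX entry is a local file path
# (handed to onnxruntime later), the other two are in-memory HF models.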
loaded_models = {
    "Pruned ONNX Optimized FP16": hf_hub_download(
        repo_id=models["Pruned ONNX Optimized FP16"], filename="model.onnx"
    ),
    "Base model": AutoModelForQuestionAnswering.from_pretrained(models["Base model"]),
    "Pruned model": AutoModelForQuestionAnswering.from_pretrained(
        models["Pruned model"]
    ),
}


def run_ort_inference(model_name, inputs):
    # A new ONNX Runtime session is built from the downloaded model file on
    # every request; only the actual sess.run() call is timed.
    sess = InferenceSession(
        loaded_models[model_name], providers=["CPUExecutionProvider"]
    )
    start_time = time.time()
    output = sess.run(None, input_feed=inputs)
    end_time = time.time()
    # output[0] and output[1] are the start and end logits as numpy arrays.
    return (output[0], output[1]), (end_time - start_time)


def run_normal_hf(model_name, inputs):
    # PyTorch forward pass; .values() yields (start_logits, end_logits).
    start_time = time.time()
    with torch.no_grad():  # no gradients needed at inference time
        output = loaded_models[model_name](**inputs).values()
    end_time = time.time()
    return output, (end_time - start_time)


def inference(model_name, context, question):
    tokenizer = AutoTokenizer.from_pretrained(models[model_name])
    if model_name == "Pruned ONNX Optimized FP16":
        # ONNX Runtime expects plain numpy arrays keyed by input name;
        # truncation=True is needed for max_length to actually be applied.
        inputs = dict(
            tokenizer(
                question, context, return_tensors="np",
                truncation=True, max_length=MAX_SEQUENCE_LENGTH,
            )
        )
        output, inference_time = run_ort_inference(model_name, inputs)
        answer_start_scores, answer_end_scores = torch.tensor(output[0]), torch.tensor(
            output[1]
        )
    else:
        inputs = tokenizer(
            question, context, return_tensors="pt",
            truncation=True, max_length=MAX_SEQUENCE_LENGTH,
        )
        output, inference_time = run_normal_hf(model_name, inputs)
        answer_start_scores, answer_end_scores = output
    input_ids = inputs["input_ids"].tolist()[0]
    # Most likely start/end token positions; +1 because the end of the
    # slice is exclusive.
    answer_start = torch.argmax(answer_start_scores)
    answer_end = torch.argmax(answer_end_scores) + 1
    answer = tokenizer.convert_tokens_to_string(
        tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end])
    )
    return answer, f"{inference_time:.4f}s"


model_field = gr.Dropdown(
    choices=["Base model", "Pruned model", "Pruned ONNX Optimized FP16"],
    value="Pruned ONNX Optimized FP16",
    label="Model",
)
input_text_field = gr.Textbox(placeholder="Enter the text here", label="Text")
input_question_field = gr.Text(placeholder="Enter the question here", label="Question")
output_model = gr.Text(label="Model output")
output_inference_time = gr.Text(label="Inference time in seconds")
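
# Pre-filled example shown under the interface.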
examples = [
    [
        "Pruned ONNX Optimized FP16",
        "The first little pig was very lazy. He didn't want to work at all and he built his house out of straw. The second little pig worked a little bit harder but he was somewhat lazy too and he built his house out of sticks. Then, they sang and danced and played together the rest of the day.",
        "Who worked a little bit harder?",
    ]
]

demo = gr.Interface(
    inference,
    inputs=[model_field, input_text_field, input_question_field],
    outputs=[output_model, output_inference_time],
    examples=examples,
)

demo.launch()