import torch
import torch.nn as nn
import torchvision.transforms as T
import gradio as gr
from PIL import Image
from copy import deepcopy
import os, sys

sys.path.append('./DETRPose')
sys.path.append('./DETRPose/tools/inference')

from DETRPose.src.core import LazyConfig, instantiate
from DETRPose.tools.inference.annotator import Annotator
from DETRPose.tools.inference.annotator_crowdpose import AnnotatorCrowdpose
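
# Model registry: maps each display name to its LazyConfig file and the
# checkpoint suffix used to build the release download URL.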
DETRPOSE_MODELS = {
    # For COCO2017
    "DETRPose-N": ['./DETRPose/configs/detrpose/detrpose_hgnetv2_n.py', 'n'],
    "DETRPose-S": ['./DETRPose/configs/detrpose/detrpose_hgnetv2_s.py', 's'],
    "DETRPose-M": ['./DETRPose/configs/detrpose/detrpose_hgnetv2_m.py', 'm'],
    "DETRPose-L": ['./DETRPose/configs/detrpose/detrpose_hgnetv2_l.py', 'l'],
    "DETRPose-X": ['./DETRPose/configs/detrpose/detrpose_hgnetv2_x.py', 'x'],
    # For CrowdPose
    "DETRPose-N-CrowdPose": ['./DETRPose/configs/detrpose/detrpose_hgnetv2_n_crowdpose.py', 'n_crowdpose'],
    "DETRPose-S-CrowdPose": ['./DETRPose/configs/detrpose/detrpose_hgnetv2_s_crowdpose.py', 's_crowdpose'],
    "DETRPose-M-CrowdPose": ['./DETRPose/configs/detrpose/detrpose_hgnetv2_m_crowdpose.py', 'm_crowdpose'],
    "DETRPose-L-CrowdPose": ['./DETRPose/configs/detrpose/detrpose_hgnetv2_l_crowdpose.py', 'l_crowdpose'],
    "DETRPose-X-CrowdPose": ['./DETRPose/configs/detrpose/detrpose_hgnetv2_x_crowdpose.py', 'x_crowdpose'],
}
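
# Preprocessing: resize to the 640x640 input used by the demo and convert to a tensor.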
transforms = T.Compose(
    [
        T.Resize((640, 640)),
        T.ToTensor(),
    ]
)
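
# Example images shown in the gr.Examples gallery.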
example_images = [
    ["assets/example1.jpg"],
    ["assets/example2.jpg"],
]
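
# Markdown/HTML description rendered at the top of the demo.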
description = """
<h1 align="center">
    <ins>DETRPose</ins>
    <br>
    Real-time end-to-end transformer model for multi-person pose estimation
</h1>

<h2 align="center">
    <a href="https://www.linkedin.com/in/sebastianjr/">Sebastian Janampa</a>
    and
    <a href="https://www.linkedin.com/in/marios-pattichis-207b0119/">Marios Pattichis</a>
</h2>

<h2 align="center">
    <a href="https://github.com/SebastianJanampa/DETRPose.git">GitHub</a> |
    <a href="https://colab.research.google.com/github/SebastianJanampa/DETRPose/blob/main/DETRPose_tutorial.ipynb">Colab</a>
</h2>

## Getting Started
DETRPose is the first real-time end-to-end transformer model for multi-person pose estimation,
achieving outstanding results on the COCO and CrowdPose datasets. In this work, we propose a
new denoising technique suitable for pose estimation that uses the Object Keypoint Similarity (OKS) metric
to generate positive and negative queries. Additionally, we develop a new classification head
and a new classification loss that are variations of the LQE head and the varifocal loss used in D-FINE.

To get started, upload an image or select one of the examples below.
You can choose between different model sizes, change the confidence threshold, and visualize the results.

### Acknowledgement
This work has been supported by [LambdaLab](https://lambda.ai)
"""
def create_model(model_name):
    config_path = DETRPOSE_MODELS[model_name][0]
    model_name = DETRPOSE_MODELS[model_name][1]

    cfg = LazyConfig.load(config_path)
    if hasattr(cfg.model.backbone, 'pretrained'):
        # Avoid downloading ImageNet-pretrained backbone weights; the full
        # checkpoint is loaded below.
        cfg.model.backbone.pretrained = False

    download_url = f"https://github.com/SebastianJanampa/DETRPose/releases/download/model_weights/detrpose_hgnetv2_{model_name}.pth"
    state_dict = torch.hub.load_state_dict_from_url(
        download_url, map_location="cpu", file_name=f"detrpose_hgnetv2_{model_name}.pth"
    )

    model = instantiate(cfg.model)
    postprocessor = instantiate(cfg.postprocessor)
    model.load_state_dict(state_dict['model'], strict=True)

    class Model(nn.Module):
        """Wraps the network and post-processor into a single inference module."""
        def __init__(self):
            super().__init__()
            self.model = model.deploy()
            self.postprocessor = postprocessor.deploy()

        def forward(self, images, orig_target_sizes):
            outputs = self.model(images)
            outputs = self.postprocessor(outputs, orig_target_sizes)
            return outputs

    model = Model()
    model.eval()

    # CrowdPose models use a different keypoint skeleton, so pick the matching annotator.
    global Drawer
    if 'crowdpose' in model_name:
        Drawer = AnnotatorCrowdpose
    else:
        Drawer = Annotator

    return model
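
# Draw keypoints for every detection whose score exceeds the threshold and
# return the annotated image (channel order flipped for display).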
def draw(image, scores, labels, keypoints, h, w, thrh):
    annotator = Drawer(deepcopy(image))
    for kpt, score in zip(keypoints, scores):
        if score > thrh:
            annotator.kpts(kpt, [h, w])
    annotated_image = annotator.result()
    return annotated_image[..., ::-1]
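
# Keep only detections whose score exceeds the threshold (helper, currently unused).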
def filter(lines, scores, threshold):
    filtered_lines, filter_scores = [], []
    for line, scr in zip(lines, scores):
        idx = scr > threshold
        filtered_lines.append(line[idx])
        filter_scores.append(scr[idx])
    return filtered_lines, filter_scores
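
# Main inference entry point, wired to the "Detect Human Keypoints" button.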
def process_results(
    image_path,
    model_size,
    threshold
):
    """Runs the selected model on the image and returns the annotated result
    together with the raw detections (cached for later threshold updates)."""
    if image_path is None:
        raise gr.Error("Please upload an image first.")

    model = create_model(model_size)

    im_pil = Image.open(image_path).convert("RGB")
    w, h = im_pil.size
    orig_size = torch.tensor([[w, h]])

    im_data = transforms(im_pil).unsqueeze(0)
    output = model(im_data, orig_size)
    scores, labels, keypoints = output
    scores, labels, keypoints = scores[0], labels[0], keypoints[0]

    annotated_image = draw(im_pil, scores, labels, keypoints, h, w, thrh=threshold)
    return annotated_image, (scores, labels, keypoints, h, w)
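
# Re-draw the cached detections when the confidence threshold slider changes,
# without re-running the model.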
def update_threshold(
    image_path,
    raw_results,
    threshold
):
    scores, labels, keypoints, h, w = raw_results
    im_pil = Image.open(image_path).convert("RGB")
    annotated_image = draw(im_pil, scores, labels, keypoints, h, w, thrh=threshold)
    return annotated_image
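
# Re-run detection when a different model is selected (not currently wired to the UI).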
def update_model(
    image_path,
    model_size,
    threshold
):
    if image_path is None:
        raise gr.Error("Please upload an image first.")
    return process_results(image_path, model_size, threshold)
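
# Build and launch the Gradio interface.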
def main():
    global Drawer

    # Create the Gradio interface
    with gr.Blocks() as demo:
        gr.Markdown(description)

        with gr.Row():
            with gr.Column():
                gr.Markdown("""## Input Image""")
                image_path = gr.Image(label="Upload image", type="filepath")
                model_size = gr.Dropdown(
                    choices=list(DETRPOSE_MODELS.keys()), label="Choose a DETRPose model.", value="DETRPose-M"
                )
                threshold = gr.Slider(
                    label="Confidence Threshold",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    interactive=True,
                    value=0.50,
                )
                submit_btn = gr.Button("Detect Human Keypoints")
                gr.Examples(examples=example_images, inputs=[image_path, model_size])

            with gr.Column():
                gr.Markdown("""## Results""")
                image_output = gr.Image(label="Detected Human Keypoints")

        # Define the action when the button is clicked
        raw_results = gr.State()
        plot_inputs = [
            raw_results,
            threshold,
        ]
        submit_btn.click(
            fn=process_results,
            inputs=[image_path, model_size] + plot_inputs[1:],
            outputs=[image_output, raw_results],
        )

        # Re-draw with the cached detections when the threshold slider changes
        threshold.change(fn=update_threshold, inputs=[image_path] + plot_inputs, outputs=[image_output])

    demo.launch()


if __name__ == "__main__":
    main()