Spaces:
Sleeping
Sleeping
| import pathlib | |
| import gradio as gr | |
| import torch | |
| from transformers import AutoFeatureExtractor, DetrForObjectDetection | |
| from visualization import visualize_attention_map, visualize_prediction | |
| from style import css, description, title | |
| from PIL import Image | |
def make_prediction(img, feature_extractor, model):
    """Run the detection model on one image and collect its outputs.

    Args:
        img: Image exposing a PIL-style ``size`` attribute; it is passed to
            ``feature_extractor`` unchanged.
        feature_extractor: Callable that turns ``img`` into keyword arguments
            for ``model`` and that provides a ``post_process`` method.
        model: Callable whose result can be indexed with
            ``"decoder_attentions"`` and ``"encoder_attentions"``.
            NOTE(review): presumably the checkpoint config enables attention
            outputs, otherwise these keys would be missing -- confirm.

    Returns:
        A 3-tuple ``(detections, decoder_attentions, encoder_attentions)``,
        where ``detections`` is the first entry of the post-processed batch.
    """
    inputs = feature_extractor(img, return_tensors="pt")

    # Inference only: disabling autograd avoids building a graph, which saves
    # memory and yields tensors that do not require grad.
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL reports size as (width, height); the reversed tuple is what gets
    # handed to post_process as the single target size of this batch.
    target_sizes = torch.tensor([tuple(reversed(img.size))])
    processed_outputs = feature_extractor.post_process(outputs, target_sizes)

    return (
        processed_outputs[0],
        outputs["decoder_attentions"],
        outputs["encoder_attentions"],
    )
def construct_model_name(
    experiment_type,
    convbase,
    attention_heads_num,
    enc_dec_layers,
    ffn_dim,
    act_func,
    d_model,
    dilation=None
):
    """Build the model repository id for a hyper-parameter combination.

    Args:
        experiment_type: Accepted for signature parity with the caller; it
            does not influence the result.
        convbase: Backbone label, ``"RESNET-50"`` or ``"RESNET-101"``.
        attention_heads_num: Number of attention heads (``-{n}ah`` segment).
        enc_dec_layers: Number of encoder/decoder layers (``-{n}l`` segment).
        ffn_dim: FFN dimension; only 1024 and 4096 add a suffix.
        act_func: Activation name; only ``"GeLU"`` adds a suffix.
        d_model: Hidden size; any value other than 256 adds ``-{d_model}d``.
        dilation: ``"True"`` (string, as a dropdown delivers it) or the
            boolean ``True`` enables the dilation suffix; anything else
            (including ``None`` and ``"False"``) does not.

    Returns:
        The repository id as a string.
    """
    if convbase == "RESNET-101":
        # For these layer counts a fixed repository is used and every other
        # parameter is ignored.
        fixed_r101_repos = {
            6: "polejowska/detr-r101-official",
            4: "polejowska/detr-r101-cd45rb-8ah-4l",
            12: "polejowska/detr-r101-cd45rb-8ah-12l",
        }
        if enc_dec_layers in fixed_r101_repos:
            return fixed_r101_repos[enc_dec_layers]

    segments = ["polejowska/"]
    # Only RESNET-50 contributes a backbone segment here; any other value
    # falls through without one, exactly as before.
    if convbase == "RESNET-50":
        segments.append("detr-r50")

    segments.append("-cd45rb")
    segments.append(f"-{attention_heads_num}ah")
    segments.append(f"-{enc_dec_layers}l")

    if attention_heads_num == 1:
        segments.append("-corrected")
    if d_model != 256:
        segments.append(f"-{d_model}d")

    if ffn_dim == 1024:
        segments.append("-1024ffn")
    elif ffn_dim == 4096:
        # The spelling is part of the repository id -- do not "fix" it.
        segments.append("-4096ffn-correcetd")

    if act_func == "GeLU":
        segments.append("-gelu-corrected")

    # Accept a real boolean as well as the dropdown's string value; ``is True``
    # keeps truthy non-bool values such as 1 from matching.
    if dilation is True or dilation == "True":
        segments.append("-dilation-corrected")

    return "".join(segments)
# Fixed repositories used by the reproducibility experiments, keyed by the
# exact label shown in the experiment dropdown.
_REPRODUCIBILITY_MODEL_REPOS = {
    "Reproducability check (1)": "polejowska/detr-r50-cd45rb-all-2ah",
    "Reproducability check (2)": "polejowska/detr-r50-cd45rb-all-4ah",
    "Reproducability check (3)": "polejowska/detr-r50-cd45rb-all-8ah",
    "Reproducability check (4)": "polejowska/detr-r50-cd45rb-all-16ah",
}


def detect_objects(
    experiment_type,
    convbase,
    attention_heads_num,
    enc_dec_layers,
    ffn_dim,
    act_func,
    d_model,
    dilation,
    image_input,
    threshold=0.7,
    display_mask=False,
    img_input_mask=None
):
    """Load the selected model, run it on the image and render the results.

    Args:
        experiment_type: ``"Parameters verification"`` (the repository is
            derived from the remaining hyper-parameters) or one of the
            ``"Reproducability check (n)"`` labels (fixed repository).
        convbase, attention_heads_num, enc_dec_layers, ffn_dim, act_func,
        d_model, dilation: Forwarded to ``construct_model_name`` when
            ``experiment_type`` is ``"Parameters verification"``.
        image_input: Image to analyse.
        threshold: Forwarded to ``visualize_prediction``.
        display_mask: Forwarded to ``visualize_prediction``.
        img_input_mask: Forwarded to ``visualize_prediction`` as ``mask``.

    Returns:
        A 3-tuple ``(prediction image, decoder attention image,
        encoder attention image)``.

    Raises:
        gr.Error: If no image was supplied.
        ValueError: If ``experiment_type`` is not a known experiment.
    """
    if image_input is None:
        # Surface a readable message in the UI instead of failing inside the
        # feature extractor.
        raise gr.Error("Please upload or select an image first.")

    if experiment_type == "Parameters verification":
        model_repo = construct_model_name(
            experiment_type,
            convbase,
            attention_heads_num,
            enc_dec_layers,
            ffn_dim,
            act_func,
            d_model,
            dilation,
        )
    elif experiment_type in _REPRODUCIBILITY_MODEL_REPOS:
        model_repo = _REPRODUCIBILITY_MODEL_REPOS[experiment_type]
    else:
        # Previously this path left ``model_repo`` unbound.
        raise ValueError(f"Unknown experiment type: {experiment_type!r}")

    model = DetrForObjectDetection.from_pretrained(model_repo)
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_repo)

    (
        processed_outputs,
        decoder_attention_map,
        encoder_attention_map,
    ) = make_prediction(image_input, feature_extractor, model)

    viz_img = visualize_prediction(
        pil_img=image_input,
        output_dict=processed_outputs,
        threshold=threshold,
        id2label=model.config.id2label,
        display_mask=display_mask,
        mask=img_input_mask
    )
    decoder_attention_map_img = visualize_attention_map(
        image_input, decoder_attention_map
    )
    encoder_attention_map_img = visualize_attention_map(
        image_input, encoder_attention_map
    )

    return (
        viz_img,
        decoder_attention_map_img,
        encoder_attention_map_img,
    )
def set_example_image(example: list):
    """Map a clicked dataset sample onto the two image inputs.

    Args:
        example: Sample whose first item is used for the image and whose
            second item is used for the mask.

    Returns:
        A pair of ``gr.Image`` objects carrying those two values, in that
        order.
    """
    image_value = example[0]
    mask_value = example[1]
    return gr.Image(value=image_value), gr.Image(value=mask_value)
def _dropdown_row(label, choices, value):
    """Create a labelled dropdown inside its own row and return the dropdown."""
    with gr.Row():
        return gr.Dropdown(
            value=value,
            choices=choices,
            label=label,
            show_label=True,
        )


# Each sample pairs an "*_HE.png" image with its "_mask" counterpart, which is
# derived purely by renaming the path.
example_samples = [
    [he_path.as_posix(), he_path.as_posix().replace("_HE", "_mask")]
    for he_path in sorted(pathlib.Path("cd45rb_test_imgs").rglob("*_HE.png"))
]

# NOTE(review): the container nesting below was reconstructed from a source
# that had lost its indentation -- confirm the layout against the running app.
with gr.Blocks(css=css) as app:
    gr.Markdown(title)

    with gr.Tabs():
        with gr.TabItem("Image upload and detections visualization"):
            with gr.Row():
                # Left column: model / experiment configuration.
                with gr.Column():
                    experiment_type = _dropdown_row(
                        "Select an experiment type",
                        [
                            "Parameters verification",
                            "Reproducability check (1)",
                            "Reproducability check (2)",
                            "Reproducability check (3)",
                            "Reproducability check (4)",
                        ],
                        "Parameters verification",
                    )
                    convbase = _dropdown_row(
                        "Select a base model for convolution part",
                        ["RESNET-50", "RESNET-101"],
                        "RESNET-50",
                    )
                    attention_heads_num = _dropdown_row(
                        "The number of attention heads in encoder and decoder",
                        [1, 2, 4, 8, 16],
                        8,
                    )
                    enc_dec_layers = _dropdown_row(
                        "The number of layers in encoder and decoder",
                        [4, 6, 12],
                        6,
                    )
                    ffn_dim = _dropdown_row(
                        "Select FFN dimension",
                        [1024, 2048, 4096],
                        2048,
                    )
                    act_func = _dropdown_row(
                        "Select an activation function",
                        ["ReLU", "GeLU"],
                        "ReLU",
                    )
                    d_model = _dropdown_row(
                        "Select a hidden size",
                        [128, 256, 512],
                        256,
                    )
                    dilation = _dropdown_row(
                        "Use dilation",
                        ["True", "False"],
                        "False",
                    )
                    with gr.Row():
                        slider_input = gr.Slider(
                            minimum=0.2,
                            maximum=1,
                            value=0.7,
                            label="Prediction threshold",
                        )

                # Right column: image input, examples and the mask toggle.
                with gr.Column():
                    with gr.Row():
                        img_input = gr.Image(type="pil")
                        # Hidden component that carries the mask of the
                        # selected example.
                        img_input_mask = gr.Image(type="pil", visible=False)
                    with gr.Row():
                        example_images = gr.Dataset(
                            components=[img_input, img_input_mask],
                            samples=example_samples,
                            samples_per_page=2,
                        )
                    with gr.Row():
                        display_mask = gr.Checkbox(
                            label="Display masks",
                        )

            with gr.Row():
                detect_button = gr.Button("Detect leukocytes")

            with gr.Row():
                with gr.Column():
                    img_output_from_upload = gr.Image(width=900, height=900)

        with gr.TabItem("Attentions visualization"):
            gr.Markdown("""Encoder attentions""")
            with gr.Row():
                encoder_att_map_output = gr.Image(width=850, height=850)
            gr.Markdown("""Decoder attentions""")
            with gr.Row():
                decoder_att_map_output = gr.Image(width=850, height=850)

        with gr.TabItem("Dataset details"):
            with gr.Row():
                gr.Markdown(description)

    # Input order must match the positional parameters of detect_objects.
    detection_inputs = [
        experiment_type,
        convbase,
        attention_heads_num,
        enc_dec_layers,
        ffn_dim,
        act_func,
        d_model,
        dilation,
        img_input,
        slider_input,
        display_mask,
        img_input_mask,
    ]
    # Output order must match the tuple returned by detect_objects.
    detection_outputs = [
        img_output_from_upload,
        decoder_att_map_output,
        encoder_att_map_output,
    ]
    detect_button.click(
        detect_objects,
        inputs=detection_inputs,
        outputs=detection_outputs,
        queue=True,
    )

    example_images.click(
        fn=set_example_image,
        inputs=[example_images],
        outputs=[img_input, img_input_mask],
        show_progress=True,
    )

app.launch()