Spaces:
Sleeping
Sleeping
import os

import gradio as gr
import torch
import torch.nn as nn
from PIL import Image, ImageFilter
from transformers import AutoModelForSemanticSegmentation, SegformerImageProcessor
# Pretrained SegFormer checkpoint, loaded once at import so every request reuses it.
MODEL_ID = "mattmdjaga/segformer_b2_clothes"
processor = SegformerImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForSemanticSegmentation.from_pretrained(MODEL_ID)

# Text shown on the Gradio page.
# NOTE(review): the trailing "π" may be a mis-encoded emoji; confirm the intended title.
title = "Background remover π"
description = " Image segmentation model which removes the background and optionally adds a white border."
article = f'Inference done on "{MODEL_ID}" model'

# Example rows for the interface, one per file in folder_path, in the same
# order as the interface inputs: [border_size, image_path].
folder_path = "Images"
example_list = []
if os.path.isdir(folder_path):
    # Sorted so the example gallery order does not depend on the filesystem.
    for file_name in sorted(os.listdir(folder_path)):
        file_path = os.path.join(folder_path, file_name)
        # Skip sub-directories and hidden files (e.g. .gitkeep), which Gradio
        # cannot load as example images.
        if file_name.startswith('.') or not os.path.isfile(file_path):
            continue
        example_list.append(['Large', file_path])
def _foreground_mask(image):
    """Return a mask image that is 255 where the model predicts a label other than 0, else 0."""
    inputs = processor(images=image, return_tensors="pt")
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits.cpu()
    # Resize the logits to the input resolution. PIL's size is (width, height),
    # while interpolate expects (height, width), hence the reversal.
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=image.size[::-1],
        mode="bilinear",
        align_corners=False,
    )
    pred_seg = upsampled_logits.argmax(dim=1)[0]
    # Label 0 is treated as background; every other label is kept.
    mask = (pred_seg != 0).numpy().astype('uint8') * 255
    # A 2-D uint8 array converts to an 'L' (8-bit grayscale) image.
    return Image.fromarray(mask)


def _add_white_border(rgba_image, kernel_size):
    """Composite *rgba_image* over a white silhouette of itself dilated by a MaxFilter.

    kernel_size is the (odd) MaxFilter window size in pixels. A window of size k
    grows the silhouette by (k - 1) // 2 pixels, so a size of 1 adds no dilation;
    only the SMOOTH blur then softens the edge.
    """
    stroke_image = Image.new("RGBA", rgba_image.size, (255, 255, 255, 255))
    img_alpha = rgba_image.getchannel(3).point(lambda x: 255 if x > 0 else 0)
    stroke_alpha = img_alpha.filter(ImageFilter.MaxFilter(kernel_size))
    stroke_alpha = stroke_alpha.filter(ImageFilter.SMOOTH)
    stroke_image.putalpha(stroke_alpha)
    # The cut-out goes on top, so the white stroke shows only around its edge.
    return Image.alpha_composite(stroke_image, rgba_image)


def predict(border_size, image):
    """Remove the background from *image* and optionally outline the subject in white.

    Args:
        border_size: One of 'None', 'Small', 'Medium', 'Large' (the dropdown
            choices). Any other value, including None when nothing is selected,
            is treated as 'None'.
        image: PIL image from the Gradio input, or None if nothing was uploaded.

    Returns:
        An RGBA image of the input's size in which pixels predicted as label 0
        are fully transparent, with an optional white border, or None when no
        image was given.
    """
    # MaxFilter window sizes in pixels (not radii).
    # NOTE(review): 'Small' = 1 performs no dilation; confirm this is intended.
    border_kernel_sizes = {'Large': 5, 'Medium': 3, 'Small': 1, 'None': 0}
    if image is None:
        return None

    image = image.convert('RGB')
    non_background_pil_mask = _foreground_mask(image)

    # Start fully transparent, then paste the opaque image only where the mask is set.
    composite_image = Image.new('RGBA', image.size, color=(0, 0, 0, 0))
    composite_image.paste(image.convert('RGBA'), mask=non_background_pil_mask)

    kernel_size = border_kernel_sizes.get(border_size, 0)
    if kernel_size == 0:
        return composite_image
    return _add_white_border(composite_image, kernel_size)
# Web UI: the dropdown and image inputs map, in order, to predict(border_size, image).
iface = gr.Interface(
    fn=predict,
    inputs=[
        # Explicit default: depending on the Gradio version, a Dropdown without
        # `value` may start unselected and send None to predict on submit.
        gr.Dropdown(['None', 'Small', 'Medium', 'Large'], value='None', label='Select Border Size'),
        gr.Image(type='pil', label='Select Image.'),
    ],
    outputs=gr.Image(type='pil', label='Output with background removed (sorta?)'),
    title=title,
    description=description,
    article=article,
    examples=example_list,
)
iface.launch()