import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure
from PIL import Image
from transformers import SamModel, SamProcessor


# Checkpoint id used to load both the model and its processor; kept in one
# place so the two are always loaded from the same checkpoint.
MODEL_CHECKPOINT = "Zigeng/SlimSAM-uniform-77"

model = SamModel.from_pretrained(MODEL_CHECKPOINT)
processor = SamProcessor.from_pretrained(MODEL_CHECKPOINT)

# Most recent click, in the nested [[[x, y]]] layout that is passed to
# SamProcessor as ``input_points``. Updated by get_pixel_coordinates and used
# by perform_prediction as its default prompt.
input_points = []


def show_mask(mask, ax, random_color=False):
    """Overlay a mask on a Matplotlib axes as a translucent colour layer.

    Args:
        mask: Array-like mask (torch tensor, NumPy array or nested list)
            whose last two dimensions are (height, width); any leading
            dimensions must have size 1. Each entry is multiplied into the
            RGBA colour, so for a boolean / 0-1 mask the True/1 pixels are
            painted and the 0 pixels stay fully transparent.
        ax: Axes-like object providing ``imshow``.
        random_color: If True, use a random RGB colour; otherwise use the
            fixed blue (30, 144, 255). The alpha is always 0.6.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    # Convert explicitly so tensors, arrays and lists all take the same NumPy
    # broadcasting path below, instead of relying on torch/NumPy operator mixing.
    mask = np.asarray(mask)
    h, w = mask.shape[-2:]
    # (h, w, 1) * (1, 1, 4) broadcasts to an (h, w, 4) RGBA image.
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def get_pixel_coordinates(image, evt: gr.SelectData):
    """Gradio ``select`` handler: segment the object at the clicked pixel.

    Args:
        image: The image currently shown in the input component.
        evt: Gradio selection event; ``evt.index`` holds the clicked (x, y).

    Returns:
        The image produced by ``perform_prediction`` for this click.
    """
    global input_points
    x, y = evt.index[0], evt.index[1]
    points = [[[x, y]]]
    # The module-level variable is still updated for any code that reads it,
    # but the points are also passed explicitly. The prediction then no longer
    # depends on shared state that another concurrently running handler
    # (e.g. with a raised concurrency limit) could overwrite in between.
    input_points = points
    return perform_prediction(image, points)


def perform_prediction(image, points=None):
    """Run SAM on *image* with a point prompt and render the best mask.

    Args:
        image: Input image (the Gradio component is configured with
            ``type="pil"``), or None.
        points: Point prompt in the nested ``[[[x, y]]]`` layout passed to
            ``SamProcessor`` as ``input_points``. Defaults to the module-level
            ``input_points`` for backward compatibility.

    Returns:
        The output of ``show_mask_on_image(..., return_image=True)`` for the
        highest-IoU mask, or None when there is no image.
    """
    if points is None:
        points = input_points
    if image is None:
        return None

    inputs = processor(images=image, input_points=points, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)
    # Pick the mask with the highest predicted IoU score. argmax runs over the
    # flattened scores; presumably that equals the mask index only for a single
    # image with a single point prompt (the layout built above) — TODO confirm
    # before sending larger prompt batches.
    max_iou_index = torch.argmax(outputs.iou_scores)

    # Map the predicted masks back to the original image size, using the
    # original and reshaped sizes recorded by the processor.
    predicted_masks = processor.image_processor.post_process_masks(
        outputs.pred_masks,
        inputs["original_sizes"],
        inputs["reshaped_input_sizes"],
    )
    predicted_mask = predicted_masks[0]  # masks for the first (only) image

    return show_mask_on_image(
        image, predicted_mask[:, max_iou_index], return_image=True
    )


def show_mask_on_image(raw_image, mask, return_image=False):
    """Draw *mask* over *raw_image*, then display it or return it as an image.

    Args:
        raw_image: Image drawn underneath the mask (anything ``np.array``
            accepts, e.g. a PIL image).
        mask: Mask as a torch tensor or array-like. A 4-D mask is squeezed;
            ``show_mask`` then treats its last two dimensions as
            (height, width).
        return_image: If False (default), draw on a pyplot figure and call
            ``plt.show()``. If True, render off-screen and return the picture.

    Returns:
        An RGB ``PIL.Image.Image`` when *return_image* is True, else None.
    """
    if not isinstance(mask, torch.Tensor):
        mask = torch.Tensor(mask)
    if len(mask.shape) == 4:
        mask = mask.squeeze()
    mask = mask.cpu().detach()

    if not return_image:
        # Interactive / notebook use: show the figure through pyplot.
        _fig, ax = plt.subplots(1, 1, figsize=(15, 15))
        _draw_mask_overlay(ax, raw_image, mask)
        plt.show()
        return None

    # Render with a standalone Agg figure rather than pyplot. plt.show() on an
    # interactive backend blocks and closes the figure, so plt.gcf() afterwards
    # yields a blank one. Pyplot's global figure state is not recommended for
    # servers, and canvas.tostring_rgb() was removed in Matplotlib 3.10.
    fig = Figure(figsize=(15, 15))
    canvas = FigureCanvasAgg(fig)
    ax = fig.subplots()
    _draw_mask_overlay(ax, raw_image, mask)
    canvas.draw()
    # buffer_rgba() already carries the rendered (height, width, 4) shape, so
    # no manual reshape is needed. convert("RGB") drops alpha and also copies
    # the pixels out of the canvas-owned buffer.
    return Image.fromarray(np.asarray(canvas.buffer_rgba())).convert("RGB")


def _draw_mask_overlay(ax, raw_image, mask):
    """Draw *raw_image* on *ax*, overlay *mask* via ``show_mask``, hide the axes."""
    ax.imshow(np.array(raw_image))
    show_mask(mask, ax)
    ax.axis("off")


# Gradio UI: a header, then the input and output images in a single gr.Row.
# Clicking the input image triggers segmentation of the clicked point.
with gr.Blocks() as demo:
    # The HTML lines stay at column 0 inside the string so Markdown does not
    # treat indented lines as a code block.
    gr.Markdown(
        """
<div style='text-align: center; font-family: "Times New Roman";'>
<h1 style='color: #FF6347;'>One Click Image Segmentation App</h1>
<h3 style='color: #4682B4;'>Model: SlimSAM-uniform-77</h3>
<h3 style='color: #32CD32;'>Made By: Md. Mahmudun Nabi</h3>
</div>
"""
    )
    with gr.Row():
        img = gr.Image(type="pil", label="Input Image", height=400, width=600)
        output_image = gr.Image(label="Masked Image")

    # The select event supplies the clicked pixel to get_pixel_coordinates.
    img.select(get_pixel_coordinates, inputs=[img], outputs=[output_image])


# Launch the local server only when run as a script (no public share link).
if __name__ == "__main__":
    demo.launch(share=False)