1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97
import random
import threading

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
def setup_seed(seed=33):
    """Seed every random number generator the demo uses so results are reproducible.

    Args:
        seed: integer seed applied to Python's ``random``, NumPy's global RNG,
            PyTorch's CPU generator and all CUDA generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    # Disable cuDNN autotuning and request deterministic kernels; trades speed
    # for run-to-run repeatability.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def show_anns(anns, image):
    """Paint each mask in *anns* with a random solid colour and add it onto *image*.

    Args:
        anns: list of mask records (as returned by
            ``SamAutomaticMaskGenerator.generate``); each record's
            ``"segmentation"`` entry is used as an index mask into the colour
            layer (presumably a boolean H x W array) and ``"area"`` is used for
            ordering.
        image: the image the masks were computed on; it is combined with the
            colour layer via ``cv2.add``, so it must match that layer's shape
            and dtype (uint8).

    Returns:
        *image* itself when *anns* is empty, otherwise ``cv2.add(image, layer)``.
    """
    if not len(anns):
        return image

    # Largest masks are painted first, so smaller masks written later overwrite
    # them and remain visible.
    by_area = sorted(anns, key=lambda record: record["area"], reverse=True)
    seg_shape = by_area[0]["segmentation"].shape
    layer = np.zeros((seg_shape[0], seg_shape[1], 3), dtype=np.uint8)

    for record in by_area:
        # Colours come from NumPy's global RNG, so setup_seed() makes them repeatable.
        layer[record["segmentation"]] = np.random.choice(range(256), size=3)

    return cv2.add(image, layer)
# Location of the SAM ViT-H weights; machine-specific, adjust as needed.
sam_checkpoint = "/disk1/datasets/models/sam/sam_vit_h_4b8939.pth"
model_type = "vit_h"
# Use the GPU when one is available, but fall back to CPU so the demo can still
# start (slowly) on machines without CUDA instead of failing at load time.
device = "cuda" if torch.cuda.is_available() else "cpu"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
# setup_seed() reseeds process-global RNGs (and show_anns draws colours from
# NumPy's global RNG), while the Gradio queue is configured with
# concurrency_count=5. Serialising calls keeps each request's output a function
# of its own seed and avoids concurrent requests sharing the model mid-inference.
_SEGMENT_LOCK = threading.Lock()


def segment_anything(image, points_per_side, pred_iou_thresh, seed, sam=sam):
    """Automatically segment *image* with SAM and render the masks for Gradio.

    Args:
        image: input image from the Gradio ``Image`` component.
        points_per_side: sampling-grid density passed to
            ``SamAutomaticMaskGenerator``; converted to ``int``.
        pred_iou_thresh: predicted-IoU threshold passed to the generator.
        seed: random seed applied (via ``setup_seed``) before generation;
            converted to ``int``.
        sam: the SAM model to run; defaults to the module-level model.

    Returns:
        A ``(rendered_image, mask_count)`` tuple: the image with coloured mask
        overlays from ``show_anns`` and the number of masks generated.
    """
    # Slider values may arrive as floats; np.random.seed rejects float seeds,
    # and the generator's point grid expects an integer count.
    seed = int(seed)
    points_per_side = int(points_per_side)

    with _SEGMENT_LOCK:
        setup_seed(seed)
        mask_generator = SamAutomaticMaskGenerator(
            sam,
            points_per_side=points_per_side,
            pred_iou_thresh=pred_iou_thresh,
        )
        masks = mask_generator.generate(image)
        seg_res_img = show_anns(masks, image)
    return seg_res_img, len(masks)
# Web UI: an image plus the three generator parameters in, the rendered
# segmentation and the mask count out. Labels follow the order of
# segment_anything's parameters so users can tell the sliders apart.
interface = gr.Interface(
    fn=segment_anything,
    inputs=[
        gr.components.Image(label="输入图像", height=500),
        gr.Slider(16, 128, label="每边采样点数 (points_per_side)"),
        gr.Slider(0, 1, step=0.01, label="预测IoU阈值 (pred_iou_thresh)"),
        gr.Slider(1, 999, label="随机种子 (seed)"),
    ],
    outputs=[
        gr.components.Image(label="分割结果", height=500, interactive=True),
        gr.components.Number(label="分割数"),
    ],
    examples=[
        ["./images/girl.jpg", 32, 0.86, 31],
        ["./images/zdt.png", 64, 0.86, 33],
        ["./images/green wormcopy.jpg", 64, 0.86, 33],
    ],
).queue(concurrency_count=5)

# Only start the server when run as a script, not when the module is imported.
if __name__ == "__main__":
    interface.launch(
        share=False,
        server_name="0.0.0.0",
        server_port=7860,
        favicon_path="./images/icon.ico",
    )
|