tiancheng91 commited on
Commit
2418413
1 Parent(s): 451c867
Files changed (1) hide show
  1. app.py +20 -5
app.py CHANGED
@@ -1,9 +1,9 @@
1
  import gradio as gr
 
 
2
  from gradio.mix import Parallel
3
  from transformers import pipeline
4
  from sentence_transformers import SentenceTransformer, util
5
- from rembg import remove
6
-
7
 
8
  generator = pipeline('text-generation', model='gpt2')
9
  model_txt_embedding = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
@@ -19,13 +19,28 @@ def embedding_sim(text1, text2):
19
  embeddings2 = model_txt_embedding.encode(text2, convert_to_tensor=True)
20
  return str(util.cos_sim(embeddings1, embeddings2).item())
21
 
22
- def remove_bg(image):
 
 
 
 
 
 
 
 
 
 
23
  # return image
24
- return remove(image)
 
25
 
26
  tab1 = gr.Interface(fn=embedding, inputs="text", outputs="text")
27
  tab2 = gr.Interface(fn=embedding_sim, inputs=["text", "text"], outputs="text")
28
- tab3 = gr.Interface(fn=remove_bg, inputs="image", outputs="image")
 
 
 
 
29
 
30
  app = gr.TabbedInterface([tab1, tab2, tab3], ["文本向量", "文本相似度", "抠图"])
31
  app.queue()
 
1
  import gradio as gr
2
+ import rembg
3
+ import functools
4
  from gradio.mix import Parallel
5
  from transformers import pipeline
6
  from sentence_transformers import SentenceTransformer, util
 
 
7
 
8
  generator = pipeline('text-generation', model='gpt2')
9
  model_txt_embedding = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
 
19
  embeddings2 = model_txt_embedding.encode(text2, convert_to_tensor=True)
20
  return str(util.cos_sim(embeddings1, embeddings2).item())
21
 
22
+ # 抠图
23
+ remove_bg_models = {
24
+ "常规": "u2net",
25
+ "人物": "u2net_human_seg",
26
+ "衣服": "u2net_cloth_seg"
27
+ }
28
+ @functools.lru_cache()
29
+ def get_remove_session(model):
30
+ return rembg.new_session(model)
31
+
32
+ def remove_bg(image, model_cn, only_mask):
33
  # return image
34
+ session = get_remove_session(remove_bg_models.get(model_cn))
35
+ return rembg.remove(image, session=session, only_mask=only_mask)
36
 
37
  tab1 = gr.Interface(fn=embedding, inputs="text", outputs="text")
38
  tab2 = gr.Interface(fn=embedding_sim, inputs=["text", "text"], outputs="text")
39
+ tab3 = gr.Interface(fn=remove_bg, inputs=[
40
+ "image",
41
+ gr.Dropdown(label="模型", choices=list(remove_bg_models.keys()), value="u2net"),
42
+ gr.Checkbox(label="只返回蒙版", value=False)
43
+ ], outputs="image")
44
 
45
  app = gr.TabbedInterface([tab1, tab2, tab3], ["文本向量", "文本相似度", "抠图"])
46
  app.queue()