# koutu / app.py
# Author: tiancheng91 · commit 2431fe1 "update params" · 1.63 kB
# (File-page residue from the original copy: "raw", "history blame contribute delete", "No virus".)
import gradio as gr
import rembg
import functools
from gradio.mix import Parallel
from transformers import pipeline
from sentence_transformers import SentenceTransformer, util
# Models are loaded once at import time; the request handlers below read these globals.
# NOTE(review): `generator` (GPT-2 text generation) is not referenced anywhere else
# in this file, yet it is loaded at startup — consider removing it.
generator = pipeline(task="text-generation", model="gpt2")
# Sentence-embedding model shared by the text-embedding and text-similarity tabs.
model_txt_embedding = SentenceTransformer(
    model_name_or_path="sentence-transformers/all-mpnet-base-v2"
)
# Text embedding
def embedding(txt):
    """Encode *txt* with the sentence-embedding model and return the vector as text.

    The Gradio tab wired to this function uses a plain "text" output, so the
    embedding is rendered as ``str()`` of a Python list of floats.

    Args:
        txt: Input string from the Gradio textbox.

    Returns:
        The string form of the embedding's ``tolist()`` result.
    """
    # Request only a NumPy array. In sentence-transformers, convert_to_tensor=True
    # overrides convert_to_numpy, so passing both flags was contradictory.
    # ndarray.tolist() yields the same Python floats a tensor's tolist() would.
    vector = model_txt_embedding.encode(txt, convert_to_numpy=True)
    return str(vector.tolist())
# Text similarity
def embedding_sim(text1, text2):
    """Return the cosine similarity of two texts, formatted as a string.

    Each text is encoded separately (text1 first, then text2) into a tensor.
    The two tensors are compared with ``util.cos_sim``, and the single score is
    unwrapped with ``.item()`` and stringified for the Gradio "text" output.
    """
    vec_a, vec_b = (
        model_txt_embedding.encode(text, convert_to_tensor=True)
        for text in (text1, text2)
    )
    similarity = util.cos_sim(vec_a, vec_b).item()
    return str(similarity)
# Background removal: maps each dropdown label shown in the UI to a rembg model name.
# Insertion order determines the order of choices in the model dropdown.
remove_bg_models = {
    "常规": "u2net",  # "general"
    "人物": "u2net_human_seg",  # "person"
    "衣服": "u2net_cloth_seg",  # "clothing"
}
@functools.lru_cache(maxsize=128)
def get_remove_session(model):
    """Create a rembg session for *model*, memoized per model name.

    ``lru_cache`` keeps each returned session alive. Repeated requests for the
    same model therefore reuse it instead of calling ``rembg.new_session`` again.
    """
    session = rembg.new_session(model)
    return session
def remove_bg(image, model_cn, only_mask):
    """Remove the background from *image* using the rembg model chosen in the UI.

    Args:
        image: Value of the Gradio "image" input, passed straight to
            ``rembg.remove`` (presumably a NumPy array under Gradio's default
            image type — TODO confirm).
        model_cn: Dropdown label, expected to be a key of ``remove_bg_models``.
            Unknown or missing labels fall back to the "u2net" model.
        only_mask: Value of the "return mask only" checkbox, forwarded to
            ``rembg.remove``.

    Returns:
        The result of ``rembg.remove`` for *image*.
    """
    # Fall back explicitly to "u2net" instead of passing None to rembg.new_session.
    # rembg's handling of None depends on its version, and None would also create a
    # separate cache entry in get_remove_session.
    model_name = remove_bg_models.get(model_cn, "u2net")
    session = get_remove_session(model_name)
    return rembg.remove(image, session=session, only_mask=only_mask)
# Gradio UI: one tab per feature, each wired to one of the handlers above.
tab1 = gr.Interface(fn=embedding, inputs="text", outputs="text")
tab2 = gr.Interface(fn=embedding_sim, inputs=["text", "text"], outputs="text")

# Background-removal inputs: the image, the model label, and the "mask only" toggle.
_remove_bg_inputs = [
    "image",
    gr.Dropdown(label="模型", choices=list(remove_bg_models), value="常规"),
    gr.Checkbox(label="只返回蒙版", value=False),
]
tab3 = gr.Interface(fn=remove_bg, inputs=_remove_bg_inputs, outputs="image")

# Tab titles: text embedding, text similarity, background removal.
app = gr.TabbedInterface(
    [tab1, tab2, tab3],
    ["文本向量", "文本相似度", "抠图"],
)
app.queue()
app.launch()