devbox / app.py
tiancheng91's picture
update params
2431fe1
import gradio as gr
import rembg
import functools
from gradio.mix import Parallel
from transformers import pipeline
from sentence_transformers import SentenceTransformer, util
# GPT-2 text-generation pipeline, constructed once when the module is imported.
# NOTE(review): `generator` is not referenced anywhere in the visible code —
# confirm it is still needed, since building it costs startup time and memory.
generator = pipeline('text-generation', model='gpt2')
# Sentence-embedding model shared by embedding() and embedding_sim().
model_txt_embedding = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
# Text embedding
def embedding(txt):
    """Encode ``txt`` with the shared sentence-embedding model.

    Args:
        txt: The text to embed (wired to a Gradio "text" input).

    Returns:
        The embedding converted to a plain Python list and rendered with
        ``str()``, so it can be shown in a Gradio "text" output.
    """
    # The original call also passed convert_to_numpy=True. encode() documents
    # that convert_to_tensor overrides convert_to_numpy, so that flag had no
    # effect and only obscured which return type is actually produced.
    vector = model_txt_embedding.encode(txt, convert_to_tensor=True)
    return str(vector.tolist())
# Text similarity
def embedding_sim(text1, text2):
    """Return the cosine similarity of two texts as a string.

    Both texts are encoded with the shared sentence-embedding model (first
    ``text1``, then ``text2``), compared with ``util.cos_sim``, and the single
    resulting score is rendered with ``str()`` for a Gradio "text" output.

    Args:
        text1: First text to compare.
        text2: Second text to compare.

    Returns:
        ``str`` of the scalar similarity score.
    """
    vec_a, vec_b = (
        model_txt_embedding.encode(text, convert_to_tensor=True)
        for text in (text1, text2)
    )
    score = util.cos_sim(vec_a, vec_b)
    return str(score.item())
# Background removal
# Maps the label shown in the UI dropdown to the rembg model name that is
# passed to rembg.new_session(). Insertion order is the dropdown order.
remove_bg_models = {
    "常规": "u2net",  # label: "general"
    "人物": "u2net_human_seg",  # label: "people"
    "衣服": "u2net_cloth_seg",  # label: "clothes"
}
@functools.lru_cache(maxsize=128)
def get_remove_session(model):
    """Return a rembg session for ``model``, reusing a cached one if available.

    The result is memoized per ``model`` value by ``functools.lru_cache``
    (128 entries, the decorator's default), so ``rembg.new_session`` runs
    only on the first request for a given model name.

    Args:
        model: Model name forwarded to ``rembg.new_session``.

    Returns:
        The session object created by ``rembg.new_session``.
    """
    session = rembg.new_session(model)
    return session
def remove_bg(image, model_cn, only_mask):
    """Remove the background from ``image`` using the selected rembg model.

    Args:
        image: The input image (wired to a Gradio "image" input). ``None``
            is accepted and returned unchanged.
        model_cn: A key of ``remove_bg_models`` (the label chosen in the
            dropdown). Unknown labels fall back to the default "常规" entry.
        only_mask: Forwarded to ``rembg.remove``; when true, only the mask
            is returned instead of the cut-out image.

    Returns:
        The result of ``rembg.remove``, or ``None`` when ``image`` is ``None``.
    """
    # Nothing to process (e.g. the form was submitted without an image);
    # return an empty result instead of handing None to rembg.
    if image is None:
        return None

    # A bare .get() yields None for an unrecognized label, which would then be
    # passed on as the model name. Fall back to the dropdown's default model.
    model_name = remove_bg_models.get(model_cn, remove_bg_models["常规"])
    session = get_remove_session(model_name)
    return rembg.remove(image, session=session, only_mask=only_mask)
# Tab 1: one text box in, the embedding (as text) out.
tab1 = gr.Interface(
    fn=embedding,
    inputs="text",
    outputs="text",
)

# Tab 2: two text boxes in, their similarity score (as text) out.
tab2 = gr.Interface(
    fn=embedding_sim,
    inputs=["text", "text"],
    outputs="text",
)

# Tab 3: image + model dropdown + "mask only" checkbox in, image out.
tab3 = gr.Interface(
    fn=remove_bg,
    inputs=[
        "image",
        # Iterating a dict yields its keys, i.e. the labels in remove_bg_models.
        gr.Dropdown(label="模型", choices=list(remove_bg_models), value="常规"),
        gr.Checkbox(label="只返回蒙版", value=False),
    ],
    outputs="image",
)

# Combine the three interfaces into one tabbed app; names match tab order.
app = gr.TabbedInterface(
    [tab1, tab2, tab3],
    ["文本向量", "文本相似度", "抠图"],
)

# Enable request queuing, then start the server.
app.queue()
app.launch()