# Scraped page header (not code): author deepskyreal,
# commit 8300616 "update translator engine to alibaba", file size 4.93 kB.
import os
import gradio as gr
import numpy as np
import translators as ts
from PIL import Image
from gradio import Blocks, Button, Textbox, Row, Column, Dropdown, Examples, Audio, Markdown
from langchain import Cohere, LLMChain, PromptTemplate
from transformers import BlipProcessor, BlipForConditionalGeneration
from bark_speaker.txt2audio import gen_tts, AVAILABLE_PROMPTS
from comic_style.comic_style import inference
from sad_talker.src.gradio_demo import SadTalker
# BLIP image-captioning checkpoint; the processor and the model must come from
# the same checkpoint so preprocessing matches what the model was trained on.
_BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-base"
processor = BlipProcessor.from_pretrained(_BLIP_CHECKPOINT)
model = BlipForConditionalGeneration.from_pretrained(_BLIP_CHECKPOINT)
def translate_into_cn(source):
    """Translate English text into Chinese using the ``translators`` package.

    The Alibaba engine is used with a fixed ``en`` -> ``zh`` language pair.

    Args:
        source: English text to translate (in this app, the essay returned by
            the LLM).

    Returns:
        The translated text as returned by ``ts.translate_text``. Empty,
        whitespace-only or ``None`` input is returned unchanged without
        contacting the translation service.
    """
    # Nothing to translate: skip the remote call instead of spending a request
    # (which may also fail) on an empty query.
    if not source or not source.strip():
        return source
    return ts.translate_text(query_text=source, translator='alibaba', from_language='en', to_language='zh')
def predict_step(cohere_key, img, style):
    """Caption an image with BLIP, then have Cohere write a short styled essay about it.

    Args:
        cohere_key: Cohere API key typed into the UI. Surrounding whitespace is
            ignored.
        img: The uploaded image. The UI component is ``gr.Image(type="pil")``,
            so this is normally a ``PIL.Image.Image``; array-like input is also
            accepted.
        style: Essay style name chosen in the dropdown (e.g. ``"luxun"``). It is
            interpolated verbatim into the prompt.

    Returns:
        A ``(caption, essay_zh)`` tuple: the BLIP caption of the image, and the
        Cohere essay translated into Chinese by ``translate_into_cn``.

    Raises:
        gr.Error: if no image was uploaded or the Cohere key is blank, so the
            reason is shown in the UI instead of an opaque stack trace.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")
    if not cohere_key or not cohere_key.strip():
        raise gr.Error("Please enter your Cohere key.")
    api_key = cohere_key.strip()

    # Normalise to 3-channel RGB. convert() handles RGBA / L / P uploads
    # correctly. Forcing mode 'RGB' onto e.g. a 4-channel array misreads the
    # pixel buffer (or raises, depending on the Pillow version).
    if isinstance(img, Image.Image):
        rgb_image = img.convert("RGB")
    else:
        rgb_image = Image.fromarray(np.asarray(img)).convert("RGB")

    # Step 1: caption the image with BLIP.
    pixel_values = processor(images=rgb_image, return_tensors="pt", max_length=1024, verbose=True).pixel_values
    output = model.generate(pixel_values)
    captions = [text.strip() for text in processor.batch_decode(output, skip_special_tokens=True)]
    caption = captions[0]  # exactly one image was encoded, so one caption comes back

    # Step 2: ask Cohere for a ~50-word essay about the caption in the chosen style.
    question = ("Requirements: \nYou are a writing master. According to the content: {}, write a 50 words essay in any "
                "form, by the style of \"{}\" as the final output content. "
                "\nfinal output content:").format(caption, style)
    print("question:{}".format(question))

    # The template is a pure passthrough: the full prompt is built above and
    # handed to the chain as its single "question" variable.
    prompt = PromptTemplate(template="{question}", input_variables=["question"])
    llm = Cohere(cohere_api_key=api_key, model="command", temperature=0.3, verbose=True)
    llm_chain = LLMChain(prompt=prompt, llm=llm)
    result = llm_chain.run(question)
    print("result:{}".format(result))

    # Step 3: translate the English essay into Chinese for display / TTS.
    return caption, translate_into_cn(result.strip())
# SadTalker wrapper; its `test` method is wired to the "Generate Talker" button
# below, taking the comic image path and the generated audio and producing a video.
# NOTE(review): lazy_load=True presumably defers loading the SadTalker weights
# until first use — confirm in sad_talker.src.gradio_demo.
sad_talker = SadTalker(lazy_load=True)

# Gradio UI: three side-by-side columns that walk through the pipeline
# image -> comic image -> caption -> styled essay -> speech -> talking-head video.
# Layout follows statement order inside the nested `with` blocks.
with Blocks() as demo:
    with Row():
        # Column 1: API key, image upload + comic preview, examples, essay style.
        with Column(scale=1):
            Markdown("[Cohere](https://dashboard.cohere.ai/)")
            cohere_key = gr.Text(label="Cohere Key:")
            Markdown("Scene 1:Img2Img(图生图)")
            with Row():
                image_upload = gr.Image(type="pil", label="Essay Image")
                # A file path rather than a PIL image, because it is handed to
                # sad_talker.test as the source image in Step 4.
                comic_style_output = gr.Image(type="filepath", label="Comic Style")
            # Bundled example images that can be loaded into image_upload.
            # NOTE(review): no `outputs` and no caching are configured, so `fn`
            # is presumably not run from here; the comic image should instead come
            # from the image_upload.change listener — confirm for the pinned Gradio version.
            Examples(
                examples=[os.path.join(os.path.dirname(__file__), "example1.jpeg"),
                          os.path.join(os.path.dirname(__file__), "example2.jpg")],
                fn=inference,
                inputs=image_upload,
            )
            # The selected style name is passed to predict_step and inserted
            # verbatim into the essay prompt.
            dropdown = Dropdown(
                ["shakespeare", "luxun", "xuzhimo", "moyan", "laoshe"],
                value="luxun",
                label="Essay Style",
                info="选择你需要的文章的风格"
            )
            essay_btn = Button("Generate Essay", variant='primary')
        # Column 2: image caption, generated essay, and text-to-speech.
        with Column(scale=1):
            Markdown("Scene 2:ReadImg(识图)")
            prediction_output = Textbox(label="Prediction")
            Markdown("Scene 3:GenEssay(风格小作文)")
            essay_output = Textbox(label="Essay", info="大约50字")
            Markdown("Scene 4:Txt2Aud(文字转语音)")
            # NOTE(review): `.style()` belongs to the Gradio 3.x API and was removed
            # in Gradio 4 — keep gradio pinned below 4 or port this call.
            audio_out = Audio(label="Generated Audio", type="filepath").style(height=20)
            # Speaker preset for gen_tts; the default is a Chinese ("zh") speaker,
            # matching the Chinese essay text.
            audio_option = Dropdown(AVAILABLE_PROMPTS, value="Speaker 7 (zh)", label="Acoustic Prompt",
                                    elem_id="speaker_option")
            audio_btn = Button("Generate Audio", variant='primary')
        # Column 3: talking-head video.
        with Column(scale=1):
            Markdown("Scene 5: Img&Aud2Talker(图片&语音转talker)")
            gen_video = gr.Video(label="Generated video", format="mp4")
            talker_btn = Button('Generate Talker', elem_id="sadtalker_generate", variant='primary')
    # Step 1: whenever the uploaded image changes, render its comic-style version.
    image_upload.change(fn=inference, inputs=image_upload, outputs=comic_style_output)
    # Step 2: caption the image and generate the styled (translated) essay.
    # Also exposed under the API name "essay_generate".
    essay_btn.click(fn=predict_step, inputs=[cohere_key, image_upload, dropdown], outputs=[prediction_output, essay_output],
                    api_name="essay_generate")
    # Step 3: synthesize speech for the essay text with the selected speaker preset.
    audio_btn.click(fn=gen_tts, inputs=[essay_output, audio_option], outputs=audio_out)
    # Step 4: animate the comic image with the generated audio.
    talker_btn.click(fn=sad_talker.test, inputs=[comic_style_output, audio_out], outputs=[gen_video])
# Start the server. debug=True is a Gradio launch option that blocks the main
# thread and prints errors to the console.
demo.launch(debug=True)