# ClueAI demo app (Gradio): a text-generation tab backed by the clueai client and an
# image-generation tab backed by the clueai.cn text-to-image HTTP endpoint.
from __future__ import annotations

import base64
import csv
import logging
import subprocess
from io import BytesIO

import gradio as gr
import requests
from PIL import Image

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')


def runcmd(command, timeout=60):
    """Run ``command`` through the shell and log its outcome.

    Args:
        command: Shell command line. It is executed with ``shell=True``, so only
            pass trusted, constant strings.
        timeout: Seconds to wait for the command before giving up.

    Returns:
        The ``subprocess.CompletedProcess`` when the command finished (successfully
        or not), or ``None`` if it timed out.
    """
    try:
        ret = subprocess.run(
            command,
            shell=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            encoding="utf-8",
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        # Without this, a pip run slower than `timeout` raises here and aborts
        # startup, even when an already-installed package could be imported below.
        logger.error("timeout after %ss: %s", timeout, command)
        return None
    if ret.returncode == 0:
        logger.info("success: %s", ret)
    else:
        logger.error("error: %s", ret)
    return ret


# Upgrade the client at startup, then import it (the import must follow the install).
runcmd("pip3 install --upgrade clueai")
import clueai  # noqa: E402

cl = clueai.Client("", check_api_key=False)

# Disabled styling for the "luck_*" buttons, kept for reference only: this string is
# a bare expression and is never passed to gr.Blocks.
'''
#luck_t2i_btn_1, #luck_s2i_btn_1, #luck_i2i_btn_1, #luck_ici_btn_1{
    color: #fff;
    --tw-gradient-from: #BED336;
    --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to);
    --tw-gradient-to: #BED336;
    border-color: #BED336;
}
#luck_easy_btn_1, #luck_iti_btn_1, #luck_tsi_btn_1, #luck_isi_btn_1{
    color: #fff;
    --tw-gradient-from: #BED336;
    --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to);
    --tw-gradient-to: #BED336;
    border-color: #BED336;
}
'''

# Stylesheet passed to gr.Blocks below.
css = '''
.container {
    max-width: 800px;
    margin: auto;
}
#gen_btn_1{
    color: #fff;
    --tw-gradient-from: #f44336;
    --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to);
    --tw-gradient-to: #ff9800;
    border-color: #ff9800;
}
#t2i_btn_1, #s2i_btn_1, #i2i_btn_1, #ici_btn_1, #easy_btn_1, #iti_btn_1, #tsi_btn_1, #isi_btn_1{
    color: #fff;
    --tw-gradient-from: #f44336;
    --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to);
    --tw-gradient-to: #ff9800;
    border-color: #ff9800;
}
#import_t2i_btn_1, #import_s2i_btn_1, #import_i2i_btn_1, #import_ici_btn_1{
    color: #fff;
    --tw-gradient-from: #BED336;
    --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to);
    --tw-gradient-to: #BED336;
    border-color: #BED336;
}
#import_easy_btn_1, #import_iti_btn_1, #import_tsi_btn_1, #import_isi_btn_1{
    color: #fff;
    --tw-gradient-from: #BED336;
    --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to);
    --tw-gradient-to: #BED336;
    border-color: #BED336;
}
#record_btn{
}
#record_btn > div > button > span {
    width: 2.375rem;
    height: 2.375rem;
}
#record_btn > div > button > span > span {
    width: 2.375rem;
    height: 2.375rem;
}
audio {
    margin-bottom: 10px;
}
div#record_btn > .mt-6{
    margin-top: 0!important;
}
div#record_btn > .mt-6 button {
    font-size: 1em;
    width: 100%;
    padding: 20px;
    height: 60px;
}
div#txt2img_tab {
    color: #BED336;
}
'''

# Generation settings that match the initial values of the "高级操作" controls in the
# text tab; a request using exactly these is sent without an explicit generate_config.
default_generate_config = {
    "do_sample": False,
    "top_p": 0,
    "top_k": 50,
    "max_length": 64,
    "temperature": 1,
    "num_beams": 1,
    "length_penalty": 0.6,
}

# Populated by read_examples():
task_styles = []                # unique task styles, in first-seen order (dropdown choices)
examples_list = []              # [task_style, example] rows for gr.Examples
task_style_to_task_prefix = {}  # task style -> prompt prefix
examples_set = set()            # (task_style, example) pairs, for O(1) "is this an example?" checks

# Response text the generation service returns instead of a result for flagged content.
# NOTE(review): must match the service's reply byte-for-byte (including the comma's
# width) — confirm against the live API.
_CENSORED_TEXT = "含有违规词,不予展示"
# Total number of generation attempts while the reply is _CENSORED_TEXT.
_MAX_GENERATE_TRIES = 3


def read_examples(input_file):
    """Load task styles, prompt prefixes and example inputs from a CSV file.

    The first row is a header and is skipped; blank rows are ignored. Every other row
    must have exactly three columns: task style, task prefix, example text.

    Results are appended to the module-level ``task_styles`` (each style once),
    ``task_style_to_task_prefix`` (last prefix seen wins), ``examples_list`` and
    ``examples_set``.

    Raises:
        ValueError: if a non-blank row does not have exactly three columns.
    """
    # newline="" is what the csv module requires; utf-8 because the data is Chinese text
    # and the platform default encoding is not guaranteed to be UTF-8.
    with open(input_file, encoding="utf-8", newline="") as finput:
        rows = csv.reader(finput)
        next(rows, None)  # skip the header row
        for row in rows:
            if not row:  # csv yields [] for blank lines
                continue
            task_style, task_prefix, example = row
            # Several examples may share one style; list each style only once.
            if task_style not in task_style_to_task_prefix:
                task_styles.append(task_style)
            task_style_to_task_prefix[task_style] = task_prefix
            examples_list.append([task_style, example])
            examples_set.add((task_style, example))


read_examples("./examples.csv")


def preprocess(text: str, task: str) -> str:
    """Build the model prompt for ``task`` from the user's ``text``.

    For the "问答" task, question marks (half- and full-width) become ":" and a ":" is
    appended. The result is ``<task prefix>\\n<text>\\n答案:``.

    Raises:
        KeyError: if ``task`` has no prefix in ``task_style_to_task_prefix``.
    """
    if task == "问答":
        # Both the ASCII "?" and the full-width "\uff1f" end a question in Chinese input.
        text = text.replace("?", ":").replace("\uff1f", ":")
        text = text + ":"
    return task_style_to_task_prefix[task] + "\n" + text + "\n答案:"


def inference_gen(text, task, do_sample, top_p, top_k, max_token, temperature, beam_size, length_penalty):
    """Gradio callback for the text-generation tab.

    Args:
        text: Raw user input.
        task: Task style selected in the dropdown (a key of ``task_style_to_task_prefix``).
        do_sample, top_p, top_k, max_token, temperature, beam_size, length_penalty:
            Values of the "高级操作" controls; sent as the ``generate_config`` keys
            do_sample/top_p/top_k/max_length/temperature/num_beams/length_penalty.

    Returns:
        The generated text. When the reply is ``_CENSORED_TEXT`` the request is repeated,
        up to ``_MAX_GENERATE_TRIES`` attempts in total, and the last reply is returned.
        Returns ``None`` (clearing the output box) if the client raises.
    """
    # examples_set holds raw example text, so check membership before building the prompt.
    is_default_example = (task, text) in examples_set
    prompt = preprocess(text, task)
    generate_config = {
        "do_sample": do_sample,
        "top_p": top_p,
        "top_k": top_k,
        "max_length": max_token,
        "temperature": temperature,
        "num_beams": beam_size,
        "length_penalty": length_penalty,
    }
    is_default_example = is_default_example and generate_config == default_generate_config
    # A stock example with stock settings is sent without generate_config, presumably so
    # the service applies its own defaults / cached result — confirm with the API owner.
    extra_kwargs = {} if is_default_example else {"generate_config": generate_config}

    for _ in range(_MAX_GENERATE_TRIES):
        try:
            prediction = cl.generate(model_name='clueai-base', prompt=prompt, **extra_kwargs)
        except Exception as e:  # UI boundary: log and show an empty result instead of crashing
            logger.exception("error, %s", e)
            return None
        if prediction.generations[0].text != _CENSORED_TEXT:
            break
    return prediction.generations[0].text


t2i_default_img_path_list = []  # initial (empty) contents of the image gallery

# Text-to-image HTTP endpoint queried by inference_image().
_TEXT2IMAGE_URL = "https://www.clueai.cn/clueai/hf_text2image"


def luck_inference_image(text, n_text, guidance_scale, style, shape, clarity, steps, shape_scale):
    """Callback for the "手气不错" button: inference_image() with ``luck=True``."""
    return inference_image(text, n_text, guidance_scale, style, shape, clarity, steps, shape_scale, luck=True)


def _decode_image(encoded):
    """Decode one base64-encoded image string from the service into a PIL image."""
    return Image.open(BytesIO(base64.b64decode(encoded.encode('utf-8'))))


def inference_image(text, n_text, guidance_scale, style, shape, clarity, steps, shape_scale, luck=False):
    """Gradio callback for the image-generation tab.

    Sends the prompt and settings to the text-to-image endpoint and decodes the
    base64 images in the ``"images"`` field of its JSON reply.

    Args:
        text: Description of the desired image.
        n_text: Elements the image should not contain (sent as ``negative_prompt``).
        guidance_scale, style, shape, clarity, steps, shape_scale: UI settings,
            forwarded as query parameters (``steps`` as ``num_inference_steps``).
        luck: Forwarded as the ``luck`` query parameter.

    Returns:
        A list of PIL images, or ``None`` if the request fails or the reply is not
        JSON with an ``"images"`` field.
    """
    query = {
        "text": text,
        "negative_prompt": n_text,
        "guidance_scale": guidance_scale,
        "num_inference_steps": steps,
        "style": style,
        "shape": shape,
        "clarity": clarity,
        "shape_scale": shape_scale,
        "luck": luck,
    }
    # Let requests percent-encode the query (a "&" or "#" in a prompt would otherwise
    # corrupt it); str() keeps each value rendered as an f-string would ("True", "7.5",
    # "None") instead of requests silently dropping None values.
    params = {key: str(value) for key, value in query.items()}
    # NOTE(review): no timeout is set, so a stalled server blocks this worker
    # indefinitely — consider one once typical generation latency is known.
    try:
        res = requests.get(_TEXT2IMAGE_URL, params=params)
        encoded_images = res.json()["images"]
    except (requests.RequestException, ValueError, KeyError, TypeError) as e:
        logger.exception("error, %s", e)
        return None
    return [_decode_image(encoded) for encoded in encoded_images]


image_styles = ['无', '细节大师', '对称美', '虚拟引擎', '空间感', '机械风格', '形状艺术', '治愈', '电影构图', '电影构图(治愈)', '荒芜感', '漫画', '逃离艺术', '斯皮尔伯格', '幻想', '杰作', '壁画', '朦胧', '黑白(3d)', '梵高', '毕加索', '莫奈', '丰子恺', '现代', '欧美']

with gr.Blocks(css=css, title="ClueAI") as demo:
    gr.Markdown('''

ClueAI全能师

''')
    with gr.TabItem("文本生成", id='_tab'):
        with gr.Row(variant="compact").style(equal_height=True):
            text = gr.Textbox(
                "标题:俄天然气管道泄漏爆炸",
                label="编辑内容",
                show_label=False,
                max_lines=20,
                placeholder="在这里输入...",
            )
            task = gr.Dropdown(label="任务", show_label=True, choices=task_styles, value="标题生成文章")
            btn = gr.Button("生成", elem_id="gen_btn_1").style(full_width=False)
        with gr.Accordion("高级操作", open=False):
            do_sample = gr.Radio([True, False], label="是否采样", value=False)
            top_p = gr.Slider(0, 1, value=0, step=0.1, label="越大多样性越高, 按照概率采样")
            top_k = gr.Slider(1, 100, value=50, step=1, label="越大多样性越高,按照top k采样")
            max_token = gr.Slider(1, 512, value=64, step=1, label="生成的最大长度")
            # A higher temperature flattens (smooths) the next-token distribution.
            temperature = gr.Slider(0, 1, value=1, step=0.1, label="temperature, 越大下一个token预测概率越平滑")
            beam_size = gr.Slider(1, 4, value=1, step=1, label="beam size, 越大解码窗口越广,")
            length_penalty = gr.Slider(-1, 1, value=0.6, step=0.1, label="大于0鼓励长句子,小于0鼓励短句子")
        with gr.Row(variant="compact").style(equal_height=True):
            output_text = gr.Textbox(
                label="输出",
                show_label=True,
                max_lines=50,
                placeholder="在这里展示结果",
            )
        gr.Examples(examples_list, [task, text], label="示例")
        input_params = [text, task, do_sample, top_p, top_k, max_token, temperature, beam_size, length_penalty]
        #text.submit(inference_gen, inputs=input_params, outputs=output_text)
        btn.click(inference_gen, inputs=input_params, outputs=output_text)

    with gr.TabItem("图像生成", id='txt2img_tab'):
        with gr.Row(variant="compact").style(equal_height=True):
            text = gr.Textbox(
                "美丽的风景",
                label="编辑内容",
                show_label=False,
                max_lines=2,
                placeholder="在这里输入你的描述...",
            )
            btn = gr.Button("生成图像", elem_id="t2i_btn_1").style(full_width=False)
        with gr.Row().style(equal_height=True):
            generate_prompt_btn = gr.Button("手气不错", elem_id="luck_t2i_btn_1")
            style = gr.Dropdown(label="风格", show_label=True, choices=image_styles, value="无")
        with gr.Accordion("高级操作", open=False):
            n_text = gr.Textbox(
                "",
                label="不想要生成的元素",
                show_label=True,
                max_lines=2,
                placeholder="在这里输入你不需要包含的内容...",
            )
            guidance_scale = gr.Slider(1, 20, value=7.5, step=0.5, label="和你的描述匹配程度,越大越匹配")
            shape = gr.Radio(["1x1", "16x9", "手机壁纸"], label="尺寸", value="1x1")
            shape_scale = gr.Radio([1, 2, 3], label="对图放大倍数", value=1)
            steps = gr.Slider(10, 150, value=50, step=1, label="越大质量越好,生成时间越长")
            clarity = gr.Radio(["标清", "高清"], label="清晰度", value="标清")
        gr.Examples(["秋日的晚霞", "星空", "室内装修", "婚礼鲜花"], text, label="示例")
        t2i_gallery = gr.Gallery(
            t2i_default_img_path_list, label="生成图像", show_label=False
        ).style(grid=[2], height="auto")
        input_params = [text, n_text, guidance_scale, style, shape, clarity, steps, shape_scale]
        generate_prompt_btn.click(luck_inference_image, inputs=input_params, outputs=[t2i_gallery])
        text.submit(inference_image, inputs=input_params, outputs=t2i_gallery)
        btn.click(inference_image, inputs=input_params, outputs=t2i_gallery)

    # Page Count
    gr.Markdown("""
""")

# Request queueing is currently disabled; enable to serve several requests concurrently.
#demo.queue(concurrency_count=3)
demo.launch()