"""Gradio app: expand a short idea into prompts, then render them as images.

Step 1 feeds the user's idea to a local text-generation pipeline and fills
nine prompt boxes; step 2 sends each prompt to the image model selected in
the dropdown and shows the nine results.
"""

import os
import random
import re
import sys
from pathlib import Path

import gradio as gr
import pandas as pd
from transformers import pipeline, set_seed

# Read the model catalogue from CSV. Each entry is one row (a pandas Series)
# that is indexed below by its "name" and "url" columns.
models = [row for _, row in pd.read_csv('models.csv').iterrows()]
current_model = models[0]

# Pipeline and core function for prompt generation.
gpt2_pipe = pipeline('text-generation', model='Gustavosta/MagicPrompt-Stable-Diffusion', tokenizer='gpt2')

with open("ideas.txt", "r") as f:
    line = f.readlines()

# Punctuation stripped from a randomly chosen idea before it is used as a seed.
_PUNCTUATION = re.compile(r"[,:\-–.!;?_]")
# Whitespace-delimited tokens with a dot in the middle (e.g. "site.com").
# \S is used instead of [^ ] so a match can never span the newlines that
# separate candidates in the joined output.
_DOTTED_TOKEN = re.compile(r"\S+\.\S+")


def _random_idea():
    """Return a random entry of ``line`` without newlines, lower-cased then capitalized."""
    return line[random.randrange(0, len(line))].replace("\n", "").lower().capitalize()


def generate(starting_text):
    """Generate prompt candidates that continue ``starting_text``.

    Args:
        starting_text: Text to continue. When empty (or None), a random entry
            from ``ideas.txt`` with its punctuation removed is used instead.

    Returns:
        The accepted candidates as one string, each followed by a blank line
        (so splitting on ``"\\n\\n"`` yields them), or ``""`` when every
        candidate was rejected.
    """
    # Re-seed on every call so repeated clicks give different results.
    seed = random.randint(100, 1000000)
    set_seed(seed)

    # NOTE(review): the original indentation was lost; both the random pick and
    # the punctuation strip are assumed to apply only to the empty-input case.
    if not starting_text:
        starting_text = _PUNCTUATION.sub('', _random_idea())

    response = gpt2_pipe(
        starting_text,
        max_length=(len(starting_text) + random.randint(60, 90)),
        num_return_sequences=4,
    )

    # Keep only results that add something to the input and do not stop on a
    # dangling separator character.
    min_length = len(starting_text) + 4
    response_list = []
    for x in response:
        resp = x['generated_text'].strip()
        if resp != starting_text and len(resp) > min_length and not resp.endswith((":", "-", "—")):
            response_list.append(resp + '\n')

    response_end = "\n".join(response_list)
    response_end = _DOTTED_TOKEN.sub('', response_end)
    response_end = response_end.replace("<", "").replace(">", "")

    # Always return a string: callers split the result, so an implicit None
    # for "no usable candidate" would crash them.
    return response_end


examples = [_random_idea() for _ in range(8)]

# Alternative prompt source kept for reference:
# text_gen = gr.Interface.load("spaces/Omnibus/MagicPrompt-Stable-Diffusion_link")

# TODO: no need to load and keep all models in memory; it increases CPU usage
# and latency.
models2 = [
    gr.Interface.load(f"models/{model['url']}", live=True, preprocess=True)
    for model in models
]


def text_it(inputs, text_gen=generate):
    """Return the first candidate produced by ``text_gen`` for ``inputs``.

    Args:
        inputs: Idea text passed through to ``text_gen``.
        text_gen: Callable returning candidates separated by blank lines.

    Returns:
        The text before the first blank line, or ``""`` when ``text_gen``
        returns nothing.
    """
    generated = text_gen(inputs) or ""
    return generated.split('\n\n')[0]


def set_model(current_model_index):
    """Make ``models[current_model_index]`` the current model.

    Args:
        current_model_index: Position of the chosen model in ``models``.

    Returns:
        A Gradio update that sets the dropdown value to the model's name.
    """
    global current_model
    current_model = models[current_model_index]
    return gr.update(value=f"{current_model['name']}")


def send_it(inputs, model_choice):
    """Run ``inputs`` through the loaded model at index ``model_choice``.

    Args:
        inputs: Prompt text to send to the model.
        model_choice: Position of the model in ``models2``.

    Returns:
        Whatever the loaded model interface returns for the prompt.
    """
    proc = models2[model_choice]
    return proc(inputs)


with gr.Blocks() as myface:
    gr.HTML()
    # NOTE(review): the row nesting below is reconstructed from a file whose
    # indentation was lost; confirm it against the intended layout.
    with gr.Row():
        with gr.Row():
            input_text = gr.Textbox(label="Prompt idea", lines=1)
            # Model selection dropdown; type="index" makes callbacks receive
            # the position of the choice rather than its label.
            model_name1 = gr.Dropdown(
                label="Choose Model",
                choices=[m["name"] for m in models],
                type="index",
                value=current_model["name"],
                interactive=True,
            )
        with gr.Row():
            see_prompts = gr.Button("Step 1 - Generate Prompts", variant="primary")
            run = gr.Button("Step 2 - Generate Images", variant="primary")

    # 3 x 3 grid: each row of images is followed by the row of prompts that
    # produced them.
    image_outputs = []
    prompt_boxes = []
    for _ in range(3):
        with gr.Row():
            image_outputs.extend([gr.Image(label="") for _ in range(3)])
        with gr.Row():
            prompt_boxes.extend(
                [gr.Textbox(label="Generated Prompt", lines=2) for _ in range(3)]
            )

    model_name1.change(set_model, inputs=model_name1, outputs=model_name1)

    # One listener per cell so each prompt/image pair is processed as its own
    # event.
    for prompt_box, image_output in zip(prompt_boxes, image_outputs):
        run.click(send_it, inputs=[prompt_box, model_name1], outputs=[image_output])
    for prompt_box in prompt_boxes:
        see_prompts.click(text_it, inputs=[input_text], outputs=[prompt_box])

myface.queue(concurrency_count=64)
myface.launch(inline=True, show_api=False, max_threads=64)