import gradio as gr
from share_btn import community_icon_html, loading_icon_html, share_js

import random
import re
import threading

import torch
from transformers import AutoModelWithLMHead, AutoTokenizer, pipeline, set_seed
import gradio as grad
from diffusers import StableDiffusionPipeline

# GPT-2 fine-tune used to write the horoscope paragraph.
tokenizer = AutoTokenizer.from_pretrained("shahp7575/gpt2-horoscopes")
model = AutoModelWithLMHead.from_pretrained("shahp7575/gpt2-horoscopes")

# Load the prompt expander and the image model once at start-up, like the
# horoscope model above, instead of re-initialising both on every click.
prompt_pipe = pipeline(
    "text-generation",
    model="Gustavosta/MagicPrompt-Stable-Diffusion",
    tokenizer="gpt2",
)
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
# The diffusion pipeline (including its scheduler, which keeps per-run state)
# is now shared between queued requests, so only one request may use it at a time.
sd_lock = threading.Lock()

STAR_SIGNS = [
    "aries", "taurus", "gemini", "cancer", "leo", "virgo",
    "libra", "scorpio", "sagittarius", "capricorn", "aquarius", "pisces",
]
CATEGORIES = ["love", "career", "wellness"]


def fn(sign, cat):
    """Generate a horoscope for *sign*/*cat* and an image illustrating it.

    Args:
        sign: star sign chosen in the UI (one of ``STAR_SIGNS``).
        cat: horoscope category chosen in the UI (one of ``CATEGORIES``).

    Returns:
        ``[image, starting_text]``: the image produced by the diffusion
        pipeline and the horoscope text that seeded it.
    """
    prompt = f"<|category|> {cat} <|horoscope|> {sign}"
    prompt_encoded = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
    sample_outputs = model.generate(
        prompt_encoded,
        do_sample=True,
        top_k=40,
        max_length=300,
        top_p=0.95,
        temperature=0.95,
        num_beams=4,
        # Only the first sequence is used below, so don't pay for more.
        num_return_sequences=1,
    )
    final_out = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
    # Drop the first 4 space-separated words, which presumably echo the prompt
    # (category marker, category, horoscope marker, sign) — TODO confirm.
    starting_text = " ".join(final_out.split(" ")[4:])

    seed = random.randint(100, 1000000)
    set_seed(seed)
    # Extend the text by 60-90 *tokens*. The old `len(starting_text) + n`
    # counted characters, which could push GPT-2 past its position limit.
    response = prompt_pipe(
        starting_text,
        max_new_tokens=random.randint(60, 90),
        num_return_sequences=1,
    )

    with sd_lock:
        image = sd_pipe(response[0]["generated_text"], num_inference_steps=10).images[0]
    return [image, starting_text]


block = gr.Blocks(css="./css.css")

with block:
    with gr.Group():
        with gr.Box():
            with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
                sign_input = gr.Dropdown(
                    label="Star Sign",
                    choices=STAR_SIGNS,
                    value=STAR_SIGNS[0],  # never hand fn() a None sign
                    show_label=True,
                    max_lines=1,
                    placeholder="Enter your prompt",
                    elem_id="prompt-text-input",
                ).style(
                    border=(True, False, True, True),
                    rounded=(True, False, False, True),
                    container=False,
                )
                category_input = gr.Dropdown(
                    choices=CATEGORIES,
                    value=CATEGORIES[0],  # never hand fn() a None category
                    label="Category",
                    show_label=True,
                    max_lines=1,
                    placeholder="Enter your prompt",
                    elem_id="prompt-text-input",
                ).style(
                    border=(True, True, True, True),
                    rounded=(True, False, False, True),
                    container=False,
                )
                btn = gr.Button("Generate image").style(
                    margin=False,
                    rounded=(False, True, True, False),
                    full_width=False,
                )
        gallery = gr.Image(
            interactive=False, label="Generated images", show_label=False, elem_id="gallery"
        ).style(grid=[2], height="auto")
        # Separate name so the star-sign dropdown is not shadowed; previously
        # the click handler read this textbox as the "sign" input.
        horoscope_output = gr.Textbox("Text")

        with gr.Group(elem_id="container-advanced-btns"):
            with gr.Group(elem_id="share-btn-container"):
                community_icon = gr.HTML(community_icon_html)
                loading_icon = gr.HTML(loading_icon_html)
                share_button = gr.Button("Share to community", elem_id="share-btn")

        btn.click(fn=fn, inputs=[sign_input, category_input], outputs=[gallery, horoscope_output])
        share_button.click(
            None,
            [],
            [],
            _js=share_js,
        )

block.queue(concurrency_count=40, max_size=20).launch(max_threads=150)