# huguru / app.py
# Captured from the Hugging Face Space file view; last commit "modify-prompt (#4)"
# (47178c0) by pngwn (HF staff). File size at capture: 3.76 kB.
import functools
import random
import re

import gradio as gr
import gradio as grad
import torch
from diffusers import StableDiffusionPipeline
from transformers import AutoModelForCausalLM, AutoModelWithLMHead, AutoTokenizer, pipeline, set_seed

from share_btn import community_icon_html, loading_icon_html, share_js
# Horoscope text model and its tokenizer, loaded once at import time and shared by
# every request. AutoModelForCausalLM replaces the deprecated AutoModelWithLMHead;
# for a GPT-2 checkpoint both resolve to the same language-model-head class.
HOROSCOPE_MODEL_ID = "shahp7575/gpt2-horoscopes"
tokenizer = AutoTokenizer.from_pretrained(HOROSCOPE_MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(HOROSCOPE_MODEL_ID)
# The twelve zodiac signs accepted by ``fn``, lower-case.
ZODIAC_SIGNS = (
    "aries", "taurus", "gemini", "cancer", "leo", "virgo",
    "libra", "scorpio", "sagittarius", "capricorn", "aquarius", "pisces",
)

# Sign used when the incoming value is missing or not a recognised zodiac sign.
# "scorpio" preserves the value this app previously hard-coded for every request.
DEFAULT_SIGN = "scorpio"


def normalize_sign(sign, default=DEFAULT_SIGN):
    """Return *sign* stripped and lower-cased if it names a zodiac sign, else *default*.

    Guards against ``None`` (nothing selected) and against arbitrary text
    reaching ``fn`` through a mis-wired UI input.
    """
    if not isinstance(sign, str):
        return default
    candidate = sign.strip().lower()
    return candidate if candidate in ZODIAC_SIGNS else default


@functools.lru_cache(maxsize=1)
def _magic_prompt_pipe():
    """Load, once, the MagicPrompt text-generation pipeline that expands image prompts."""
    return pipeline("text-generation", model="Gustavosta/MagicPrompt-Stable-Diffusion", tokenizer="gpt2")


@functools.lru_cache(maxsize=1)
def _stable_diffusion_pipe():
    """Load, once, the Stable Diffusion v1.5 text-to-image pipeline."""
    return StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")


def fn(sign, cat):
    """Generate a horoscope for *sign*/*cat* and an image illustrating it.

    Args:
        sign: Zodiac sign from the UI. Anything that is not one of
            ``ZODIAC_SIGNS`` (case-insensitive) falls back to ``DEFAULT_SIGN``.
        cat: Horoscope category, inserted verbatim into the model prompt.

    Returns:
        ``[image, starting_text]``: the first image produced by the Stable
        Diffusion pipeline, and the generated horoscope text.
    """
    sign = normalize_sign(sign)
    prompt = f"<|category|> {cat} <|horoscope|> {sign}"
    prompt_encoded = tokenizer.encode(prompt, return_tensors="pt")  # shape (1, prompt_len)
    sample_outputs = model.generate(
        prompt_encoded,
        do_sample=True,
        top_k=40,
        max_length=300,
        top_p=0.95,
        temperature=0.95,
        num_beams=4,
        # Only the first returned sequence is used below, so don't generate more.
        num_return_sequences=1,
    )
    final_out = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
    # Drop the first four space-separated pieces, presumably the echoed prompt
    # (category and sign plus empty strings left by skipped special tokens) -- verify.
    starting_text = " ".join(final_out.split(" ")[4:])

    # transformers.set_seed also seeds Python's `random`, so the randint below is
    # reproducible for a given seed.
    seed = random.randint(100, 1000000)
    set_seed(seed)
    # NOTE(review): max_length is measured in tokens, but len(starting_text) counts
    # characters; this over-allows length and is kept as in the original.
    response = _magic_prompt_pipe()(
        starting_text + " " + sign + " art.",
        max_length=(len(starting_text) + random.randint(60, 90)),
        num_return_sequences=1,
    )
    image = _stable_diffusion_pipe()(response[0]["generated_text"], num_inference_steps=5).images[0]
    return [image, starting_text]
# Gradio UI:
# - star-sign and category pickers and a generate button;
# - the generated image and horoscope text;
# - a "Share to community" button driven by the JS from share_btn.
block = gr.Blocks(css="./css.css")
with block:
    with gr.Group():
        with gr.Box():
            with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
                # Star-sign input. Its value is passed as fn's `sign` argument.
                # NOTE(review): this and the category dropdown share elem_id
                # "prompt-text-input"; HTML ids should be unique, but ./css.css
                # (not visible here) may target it, so it is left unchanged.
                text = gr.Dropdown(
                    label="Star Sign",
                    choices=["aries", "taurus", "gemini", "cancer", "leo", "virgo", "libra", "scorpio", "sagittarius", "capricorn", "aquarius", "pisces"],
                    show_label=True,
                    max_lines=1,
                    placeholder="Enter your prompt",
                    elem_id="prompt-text-input",
                ).style(
                    border=(True, False, True, True),
                    rounded=(True, False, False, True),
                    container=False,
                )
                # Category input. Its value is passed as fn's `cat` argument.
                text2 = gr.Dropdown(
                    choices=["love", "career", "wellness"],
                    label="Category",
                    show_label=True,
                    max_lines=1,
                    placeholder="Enter your prompt",
                    elem_id="prompt-text-input",
                ).style(
                    border=(True, True, True, True),
                    rounded=(True, False, False, True),
                    container=False,
                )
                btn = gr.Button("Generate image").style(
                    margin=False,
                    rounded=(False, True, True, False),
                    full_width=False,
                )
        gallery = gr.Image(
            interactive=False,
            label="Generated images", show_label=False, elem_id="gallery"
        ).style(grid=[2], height="auto")
        # Output box for the horoscope text. It gets its own name: reusing `text`
        # here previously overwrote the star-sign dropdown, so the click handler
        # sent this box's content to fn instead of the chosen sign.
        horoscope_text = gr.Textbox(label="Text")
        with gr.Group(elem_id="container-advanced-btns"):
            with gr.Group(elem_id="share-btn-container"):
                community_icon = gr.HTML(community_icon_html)
                loading_icon = gr.HTML(loading_icon_html)
                share_button = gr.Button("Share to community", elem_id="share-btn")

    btn.click(fn=fn, inputs=[text, text2], outputs=[gallery, horoscope_text])
    # Client-side only: share_js runs in the browser, no Python callback.
    share_button.click(
        None,
        [],
        [],
        _js=share_js,
    )

# NOTE(review): concurrency_count=40 lets many heavy diffusion calls run at once
# while max_size=20 caps the waiting queue; confirm the hardware supports this.
block.queue(concurrency_count=40, max_size=20).launch(max_threads=150)