File size: 4,038 Bytes
78e2fb5
 
f82013f
 
78e2fb5
f82013f
 
9659c11
f82013f
b04693a
9659c11
 
 
 
 
 
 
 
 
7890ebe
f82013f
 
 
 
 
 
 
 
 
 
 
 
9659c11
f82013f
 
 
 
 
 
9659c11
 
 
f82013f
78e2fb5
 
 
 
 
 
 
 
28d6e58
 
9659c11
ea20bfc
78e2fb5
 
 
 
 
 
 
 
28d6e58
 
9659c11
28d6e58
ea20bfc
28d6e58
 
 
 
 
 
 
 
 
78e2fb5
 
 
 
 
 
28d6e58
90ec86f
78e2fb5
 
f82013f
78e2fb5
 
 
 
 
 
 
f82013f
aba8522
 
 
 
 
 
12c1229
78e2fb5
dc3112c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import gradio as gr
from share_btn import community_icon_html, loading_icon_html, share_js
import random
import re

import torch
from transformers import AutoModelWithLMHead, AutoTokenizer, pipeline, set_seed
from optimum.intel.openvino import OVStableDiffusionPipeline


# Static output resolution for Stable Diffusion. The OpenVINO pipeline below is
# reshaped to these dimensions before compiling, and fn() passes the same values
# on every call.
height = 128
width = 128

# Horoscope text model; fn() encodes its prompt with `tokenizer` and samples
# from `model.generate`.
horoscope_model_id = "shahp7575/gpt2-horoscopes"
tokenizer = AutoTokenizer.from_pretrained(horoscope_model_id)
model = AutoModelWithLMHead.from_pretrained(horoscope_model_id)

# Prompt-expansion pipeline; fn() feeds it the horoscope text and hands the
# generated text to Stable Diffusion as the image prompt.
text_generation_pipe = pipeline(
    "text-generation",
    model="Gustavosta/MagicPrompt-Stable-Diffusion",
    tokenizer="gpt2",
)

# Load uncompiled so the model can be fixed to static shapes first; compile()
# must come after reshape().
stable_diffusion_pipe = OVStableDiffusionPipeline.from_pretrained(
    "echarlaix/stable-diffusion-v1-5-openvino",
    revision="fp16",
    compile=False,
)
stable_diffusion_pipe.reshape(
    batch_size=1,
    height=height,
    width=width,
    num_images_per_prompt=1,
)
stable_diffusion_pipe.compile()

def fn(sign, cat):
    """Generate a horoscope for a star sign and category, plus a matching image.

    Args:
        sign: Star sign chosen in the UI (e.g. "Aries").
        cat: Horoscope category chosen in the UI (e.g. "Love").

    Returns:
        A two-item list ``[image, starting_text]``: the first image produced by
        the Stable Diffusion pipeline and the horoscope text it was derived from.
    """
    # Condition the horoscope model with its tagged prompt format.
    input_ids = torch.tensor(
        tokenizer.encode(f"<|category|> {cat} <|horoscope|> {sign}")
    ).unsqueeze(0)

    generated = model.generate(
        input_ids,
        num_beams=4,
        num_return_sequences=1,
        do_sample=True,
        temperature=0.95,
        top_k=40,
        top_p=0.95,
        max_length=300,
    )

    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    # Drop the first four space-separated pieces. Presumably this is the echoed
    # category/sign prefix of the prompt; TODO confirm against real decodes.
    pieces = decoded.split(" ")
    starting_text = " ".join(pieces[4:])

    # Reseed generation with a fresh random seed for each request.
    set_seed(random.randint(100, 1000000))

    # Expand the horoscope into an image prompt. The length budget is counted in
    # characters of the horoscope plus a random 60-90 margin.
    magic_input = starting_text + " " + sign + " art"
    length_budget = len(starting_text) + random.randint(60, 90)
    suggestions = text_generation_pipe(
        magic_input,
        max_length=length_budget,
        num_return_sequences=1,
    )
    image_prompt = suggestions[0]["generated_text"]

    rendered = stable_diffusion_pipe(
        image_prompt,
        height=height,
        width=width,
        num_inference_steps=30,
    )
    return [rendered.images[0], starting_text]


# Gradio UI: the star-sign and category dropdowns feed `fn`, which fills the
# image and horoscope-text outputs; the share button runs client-side JS only.
block = gr.Blocks(css="./css.css")

with block:
    with gr.Group():
        with gr.Box():
            with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
                # Star-sign selector; passed to fn() as `sign`.
                text = gr.Dropdown(
                    label="Star Sign",
                    choices=["Aries", "Taurus","Gemini", "Cancer", "Leo", "Virgo", "Libra", "Scorpio", "Sagittarius", "Capricorn", "Aquarius", "Pisces"],
                    show_label=True,
                    max_lines=1,
                    placeholder="Enter your prompt",
                    elem_id="prompt-text-input",
                ).style(
                    border=(True, False, True, True),
                    rounded=(True, False, False, True),
                    container=False,
                )

                # Category selector; passed to fn() as `cat`.
                # NOTE(review): this reuses elem_id "prompt-text-input" from the
                # dropdown above; HTML ids should be unique. Kept because css.css
                # may style that id. Confirm before renaming.
                text2 = gr.Dropdown(
                    choices=["Love", "Career", "Wellness"],
                    label="Category",
                    show_label=True,
                    max_lines=1,
                    placeholder="Enter your prompt",
                    elem_id="prompt-text-input",
                ).style(
                    border=(True, True, True, True),
                    rounded=(True, False, False, True),
                    container=False,
                )

                btn = gr.Button("Generate image").style(
                    margin=False,
                    rounded=(False, True, True, False),
                    full_width=False,
                )

        gallery = gr.Image(
            interactive=False,
            label="Generated images", show_label=False, elem_id="gallery"
        ).style(grid=[2], height="auto")
        # Horoscope text output. It has its own name so that `text` still refers
        # to the star-sign dropdown when the click handler below is wired.
        # Reusing `text` here would make fn() receive this box's contents as
        # `sign` instead of the selected star sign.
        # NOTE(review): "Text" is passed positionally, which in Gradio is the
        # initial value, not the label. Confirm which was intended.
        horoscope_output = gr.Textbox("Text")

        with gr.Group(elem_id="container-advanced-btns"):
            with gr.Group(elem_id="share-btn-container"):
                community_icon = gr.HTML(community_icon_html)
                loading_icon = gr.HTML(loading_icon_html)
                share_button = gr.Button("Share to community", elem_id="share-btn")

        btn.click(fn=fn, inputs=[text, text2], outputs=[gallery, horoscope_output])
        # No Python callback: the share action is implemented entirely in share_js.
        share_button.click(
            None,
            [],
            [],
            _js=share_js,
        )


block.queue(concurrency_count=40, max_size=20).launch(max_threads=150)