File size: 3,737 Bytes
78e2fb5
 
f82013f
 
78e2fb5
f82013f
 
 
 
 
b04693a
7890ebe
 
 
f82013f
 
 
 
 
7890ebe
f82013f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36e8b20
f82013f
78e2fb5
 
 
 
 
 
 
 
28d6e58
 
f82013f
ea20bfc
78e2fb5
 
 
 
 
 
 
 
28d6e58
 
 
 
ea20bfc
28d6e58
 
 
 
 
 
 
 
 
78e2fb5
 
 
 
 
 
28d6e58
90ec86f
78e2fb5
 
f82013f
78e2fb5
 
 
 
 
 
 
f82013f
aba8522
 
 
 
 
 
12c1229
78e2fb5
dc3112c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import functools
import random
import re
import threading

import gradio as gr
import gradio as grad
import torch
from diffusers import StableDiffusionPipeline
from transformers import (
    AutoModelForCausalLM,
    AutoModelWithLMHead,
    AutoTokenizer,
    pipeline,
    set_seed,
)

from share_btn import community_icon_html, loading_icon_html, share_js

# Horoscope text generator used by fn(); loaded once at import time.
# (Presumably a GPT-2 checkpoint fine-tuned on horoscopes, judging by its id.)
HOROSCOPE_MODEL_ID = "shahp7575/gpt2-horoscopes"
tokenizer = AutoTokenizer.from_pretrained(HOROSCOPE_MODEL_ID)
# AutoModelWithLMHead is deprecated in transformers; AutoModelForCausalLM is the
# documented replacement for decoder-only checkpoints such as GPT-2 and resolves
# to the same model class.
model = AutoModelForCausalLM.from_pretrained(HOROSCOPE_MODEL_ID)

# Star signs accepted by fn(); any other value falls back to _DEFAULT_SIGN,
# which is the sign this app previously hard-coded for every request.
_STAR_SIGNS = frozenset(
    {
        "aries", "taurus", "gemini", "cancer", "leo", "virgo",
        "libra", "scorpio", "sagittarius", "capricorn", "aquarius", "pisces",
    }
)
_DEFAULT_SIGN = "scorpio"

# Serialises Stable Diffusion calls. The pipeline is now shared between
# requests and the queue allows many concurrent workers; the pipeline's
# scheduler presumably keeps per-call state, so simultaneous calls on one
# instance are not assumed to be safe — TODO confirm against diffusers docs.
_SD_LOCK = threading.Lock()


def _normalize_sign(sign):
    """Return *sign* lower-cased if it is a known star sign, else the default."""
    candidate = (sign or "").strip().lower()
    return candidate if candidate in _STAR_SIGNS else _DEFAULT_SIGN


@functools.lru_cache(maxsize=1)
def _load_magic_prompt():
    """Build the MagicPrompt text-generation pipeline once and reuse it."""
    return pipeline(
        "text-generation",
        model="Gustavosta/MagicPrompt-Stable-Diffusion",
        tokenizer="gpt2",
    )


@functools.lru_cache(maxsize=1)
def _load_stable_diffusion():
    """Load the Stable Diffusion v1-4 pipeline once and reuse it."""
    return StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")


def fn(sign, cat):
    """Generate a horoscope for *sign*/*cat* and an image illustrating it.

    Args:
        sign: Star sign chosen in the UI. Case-insensitive; anything that is
            not a recognised sign (including ``None``) falls back to
            ``"scorpio"``.
        cat: Horoscope category inserted into the prompt (the UI offers
            ``"love"``, ``"career"`` and ``"wellness"``).

    Returns:
        ``[image, starting_text]``: the first image produced by Stable
        Diffusion, and the horoscope text that seeded the image prompt.
    """
    sign = _normalize_sign(sign)
    prompt = f"<|category|> {cat} <|horoscope|> {sign}"
    prompt_encoded = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)

    sample_outputs = model.generate(
        prompt_encoded,
        do_sample=True,
        top_k=40,
        max_length=300,
        top_p=0.95,
        temperature=0.95,
        num_beams=4,
        num_return_sequences=4,
    )

    final_out = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
    # Drop the first four space-separated pieces of the decoded text;
    # presumably the echoed prompt (category and sign) — TODO confirm.
    starting_text = " ".join(final_out.split(" ")[4:])

    # Fetch the cached pipeline before seeding, matching the original order
    # (the pipeline was constructed before set_seed was called).
    magic_prompt = _load_magic_prompt()
    seed = random.randint(100, 1000000)
    set_seed(seed)
    # NOTE: max_length counts tokens while len() counts characters; kept as-is.
    response = magic_prompt(
        starting_text,
        max_length=(len(starting_text) + random.randint(60, 90)),
        num_return_sequences=1,
    )

    with _SD_LOCK:
        sd_pipe = _load_stable_diffusion()
        image = sd_pipe(response[0]["generated_text"], num_inference_steps=10).images[0]
    return [image, starting_text]


# Gradio UI: pick a star sign and category, then generate a horoscope text
# plus an illustrative image via fn().
block = gr.Blocks(css="./css.css")

with block:
    with gr.Group():
        with gr.Box():
            with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
                # Star-sign selector, passed to fn() as `sign`. Defaults to the
                # sign the app used to hard-code so a click without a choice
                # still yields a valid prompt.
                text = gr.Dropdown(
                    label="Star Sign",
                    choices=[
                        "aries", "taurus", "gemini", "cancer", "leo", "virgo",
                        "libra", "scorpio", "sagittarius", "capricorn", "aquarius", "pisces",
                    ],
                    value="scorpio",
                    show_label=True,
                    max_lines=1,
                    placeholder="Enter your prompt",
                    elem_id="prompt-text-input",
                ).style(
                    border=(True, False, True, True),
                    rounded=(True, False, False, True),
                    container=False,
                )

                # Horoscope category selector, passed to fn() as `cat`.
                # Defaults to the first choice so `None` never reaches the prompt.
                # NOTE(review): shares elem_id with the sign dropdown (duplicate
                # HTML id); kept because css.css may target it — confirm.
                text2 = gr.Dropdown(
                    choices=["love", "career", "wellness"],
                    value="love",
                    label="Category",
                    show_label=True,
                    max_lines=1,
                    placeholder="Enter your prompt",
                    elem_id="prompt-text-input",
                ).style(
                    border=(True, True, True, True),
                    rounded=(True, False, False, True),
                    container=False,
                )

                btn = gr.Button("Generate image").style(
                    margin=False,
                    rounded=(False, True, True, False),
                    full_width=False,
                )

        gallery = gr.Image(
            interactive=False,
            label="Generated images", show_label=False, elem_id="gallery"
        ).style(grid=[2], height="auto")
        # Output box for the generated horoscope. It has its own name so it no
        # longer shadows the star-sign dropdown (`text`) used as an input below.
        horoscope_output = gr.Textbox(label="Horoscope")

        with gr.Group(elem_id="container-advanced-btns"):
            with gr.Group(elem_id="share-btn-container"):
                community_icon = gr.HTML(community_icon_html)
                loading_icon = gr.HTML(loading_icon_html)
                share_button = gr.Button("Share to community", elem_id="share-btn")

        btn.click(fn=fn, inputs=[text, text2], outputs=[gallery, horoscope_output])
        # Client-side only: runs share_js in the browser, no Python callback.
        share_button.click(
            None,
            [],
            [],
            _js=share_js,
        )

block.queue(concurrency_count=40, max_size=20).launch(max_threads=150)