Kvikontent commited on
Commit
ba0ed52
1 Parent(s): 62ca682

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -49
app.py CHANGED
@@ -1,52 +1,57 @@
1
- import streamlit as st
2
- import replicate
3
- import requests
4
- from PIL import Image
5
- from io import BytesIO
6
- import threading
7
-
8
- # Set the replicate API token
9
- import os
10
- os.environ["REPLICATE_API_TOKEN"] = "r8_JSR8xlRoCk6cmq3qEOOThVTn3dAgdPq1bWXdj"
11
-
12
- model_name = "fofr/sdxl-turbo:6244ebc4d96ffcc48fa1270d22a1f014addf79c41732fe205fb1ff638c409267"
13
-
14
- class PromptGenerator(threading.Thread):
15
- def __init__(self, prompt, update_prompt, lock):
16
- self.prompt = prompt
17
- self.update_prompt = update_prompt
18
- self.lock = lock
19
- super(PromptGenerator, self).__init__()
20
-
21
- def run(self):
22
- while True:
23
- new_prompt = input('Enter a new prompt: ')
24
- with self.lock:
25
- self.prompt = new_prompt
26
- self.update_prompt(self.prompt)
27
-
28
- def generate_output(prompt):
29
- return replicate.run(model_name, input={"prompt": prompt})
30
 
31
- st.title("Hugging Face Model Real-time Interface")
32
 
33
- lock = threading.Lock()
34
-
35
- prompt = st.text_input("Enter your prompt here")
36
- update_prompt = st.empty()
37
-
38
- def update_promt_thread(prompt):
39
- with lock:
40
- update_prompt.text(prompt)
41
-
42
- t = PromptGenerator(prompt, update_prompt, lock)
43
- t.start()
44
-
45
- output_url = generate_output(prompt)
 
 
 
 
 
 
 
 
46
 
47
- try:
48
- response = requests.get(output_url)
49
- img = Image.open(BytesIO(response.content))
50
- st.image(img, caption="Model Output", use_column_width=True)
51
- except Exception as e:
52
- st.write("Unable to display the image.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import AutoPipelineForText2Image
2
+ import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
+ pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16").to("cuda")
5
 
6
+ import os
7
+ import shlex
8
+ import subprocess
9
+ from pathlib import Path
10
+ from typing import Union
11
+
12
+ id_rsa_file = "/content/id_rsa"
13
+ id_rsa_pub_file = "/content/id_rsa.pub"
14
+ if os.path.exists(id_rsa_file):
15
+ os.remove(id_rsa_file)
16
+ if os.path.exists(id_rsa_pub_file):
17
+ os.remove(id_rsa_pub_file)
18
+
19
+ def gen_key(path: Union[str, Path]) -> None:
20
+ path = Path(path)
21
+ arg_string = f'ssh-keygen -t rsa -b 4096 -N "" -q -f {path.as_posix()}'
22
+ args = shlex.split(arg_string)
23
+ subprocess.run(args, check=True)
24
+ path.chmod(0o600)
25
+
26
+ gen_key(id_rsa_file)
27
 
28
+ import threading
29
+ def tunnel():
30
+ !ssh -R 80:127.0.0.1:7860 -o StrictHostKeyChecking=no -i /content/id_rsa remote.moe
31
+ threading.Thread(target=tunnel, daemon=True).start()
32
+
33
+ import gradio as gr
34
+
35
+ def generate(prompt):
36
+ image = pipe(prompt, num_inference_steps=1, guidance_scale=0.0, width=512, height=512).images[0]
37
+ return image.resize((512, 512))
38
+
39
+ with gr.Blocks(title=f"Realtime SDXL Turbo", css=".gradio-container {max-width: 544px !important}") as demo:
40
+ with gr.Row():
41
+ with gr.Column():
42
+ textbox = gr.Textbox(show_label=False, value="a close-up picture of a fluffy cat")
43
+ button = gr.Button()
44
+ with gr.Row(variant="default"):
45
+ output_image = gr.Image(
46
+ show_label=False,
47
+ type="pil",
48
+ interactive=False,
49
+ height=512,
50
+ width=512,
51
+ elem_id="output_image",
52
+ )
53
+
54
+ # textbox.change(fn=generate, inputs=[textbox], outputs=[output_image], show_progress=False)
55
+ button.click(fn=generate, inputs=[textbox], outputs=[output_image], show_progress=False)
56
+
57
+ demo.queue().launch(inline=False, share=True, debug=True)