Komorebizyd committed
Commit bab0afc
1 Parent(s): a5931f6

add vispunk to draw function

Files changed (1)
streamlit_app.py  +66 -5
streamlit_app.py CHANGED
@@ -1,5 +1,11 @@
 from huggingface_hub import InferenceClient
 import streamlit as st
+from io import BytesIO
+from PIL import Image
+import requests
+import base64
+import time
+import base64
 
 if "draw_model" not in st.session_state:
     st.session_state.draw_model_list = {
@@ -35,7 +41,8 @@ if "draw_model" not in st.session_state:
         "Dalle-proteus-v0.2":"https://api-inference.huggingface.co/models/dataautogpt3/ProteusV0.2",
         }
     st.session_state.draw_model = st.session_state.draw_model_list["Dalle-v1.1"]
-
+    st.session_state.image_choice = True
+    st.session_state.image_choice_name = "Huggingface"
 
 show_app = st.container()
 
@@ -43,20 +50,74 @@ def change_paramater():
     st.session_state.draw_model = st.session_state.draw_model
 
 
-def free_text_to_image(text):
+def image_choice():
+    if st.session_state.image_choice:
+        st.session_state.image_choice = False
+        st.session_state.image_choice_name = "Vispunk"
+    else:
+        st.session_state.image_choice = True
+        st.session_state.image_choice_name = "Huggingface"
+
+
+def huggingface_text_to_image(text):
     client = InferenceClient(model=st.session_state.draw_model_list[st.session_state.draw_model])
     image = client.text_to_image(text)
     return image
 
 
+def query_vispunk(prompt):
+    def request_generate(prompt):
+        url = "https://motion-api.vispunk.com/v1/generate/generate_image"
+        headers = {"Content-Type": "application/json"}
+        data = {"prompt": prompt}
+        try:
+            response = requests.post(url, headers=headers, json=data)
+            return True,response.json()["task_id"]
+        except Exception as e:
+            st.error(f"Error: {e}")
+            return False,None
+
+
+    def request_image(task_id):
+        url = "https://motion-api.vispunk.com/v1/generate/check_image_task"
+        headers = {"Content-Type": "application/json"}
+        data = {"task_id": task_id}
+        try:
+            response = requests.post(url, headers=headers, json=data)
+            return True,response.json()["images"][0]
+        except Exception as e:
+            return False,e
+
+    flag_generate,task_id = request_generate(prompt)
+    if flag_generate:
+        while True:
+            flag_wait,image_src = request_image(task_id)
+            if not flag_wait:
+                time.sleep(1)
+            else:
+                image_data = base64.b64decode(image_src)
+                image = BytesIO(image_data)
+                return True,image
+    else:
+        return False,task_id
+
+
 def main(prompt):
     show_app.write("**You:** " + prompt)
-    image = free_text_to_image(prompt)
-    show_app.image(image,use_column_width=True)
+    if st.session_state.image_choice:
+        image = huggingface_text_to_image(prompt)
+    else:
+        flag,image = query_vispunk(prompt)
+    show_app.image(image,caption=prompt,use_column_width=True)
 
 
 with st.sidebar:
-    st.session_state.draw_model = st.selectbox('Draw Models', sorted(st.session_state.draw_model_list.keys(),key=lambda x:x.split("-")[0]),on_change=change_paramater)
+    st.session_state.image_choice = st.toggle(st.session_state.image_choice_name,value=st.session_state.image_choice,on_change=image_choice)
+    if st.session_state.image_choice:
+        st.session_state.draw_model = st.selectbox('Draw Models', sorted(st.session_state.draw_model_list.keys(),key=lambda x:x.split("-")[0]),on_change=change_paramater)
+    else:
+        pass
+
 
 prompt = st.chat_input("Send your prompt")
 if prompt: