tugot17 commited on
Commit
eab86f7
·
1 Parent(s): 0e6f299

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +8 -7
  2. img_gen.py +69 -0
  3. prompt_generation.py +1 -0
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import streamlit as st
2
  from gtts import gTTS
3
 
4
- from img_gen_v2 import generate_story
5
  from prompt_generation import pipeline
6
 
7
 
@@ -15,7 +15,7 @@ def page_navigation(current_page):
15
  current_page -= 1
16
 
17
  with col2:
18
- st.write(f'Step {current_page} of 10')
19
 
20
  if current_page < 10:
21
  with col3:
@@ -23,10 +23,11 @@ def page_navigation(current_page):
23
  if current_page == 0:
24
  user_input = st.session_state.user_input
25
  prompt_response = pipeline(user_input, 10)
26
- steps = prompt_response.get("steps")
27
  init_prompt = prompt_response.get("story")
28
 
29
- init_img, img_dict = generate_story(init_prompt, steps)
 
30
 
31
  st.session_state.pipeline_response = prompt_response
32
  st.session_state.init_img = init_img
@@ -42,7 +43,7 @@ def get_pipeline_data(page_number):
42
  pipeline_response = st.session_state.pipeline_response
43
  text_output = pipeline_response.get("steps")[page_number - 1]
44
  img_dict = st.session_state.img_dict
45
- img = img_dict[page_number-1].get("image")
46
 
47
  return {"text_output": text_output, "image_obj": img}
48
 
@@ -56,7 +57,7 @@ def main():
56
 
57
  # Display content for each page
58
  if current_page == 0:
59
- st.write("Tell me what story you would like me to tell:")
60
  user_input = st.text_area("")
61
  st.session_state.user_input = user_input
62
 
@@ -69,7 +70,7 @@ def main():
69
  # Display text output
70
  st.write(text_output)
71
 
72
- tts = gTTS(text_output)
73
  tts.save('audio.mp3')
74
  st.audio('audio.mp3')
75
 
 
1
  import streamlit as st
2
  from gtts import gTTS
3
 
4
+ from img_gen import generate_story
5
  from prompt_generation import pipeline
6
 
7
 
 
15
  current_page -= 1
16
 
17
  with col2:
18
+ print(f'Step {current_page} of 10')
19
 
20
  if current_page < 10:
21
  with col3:
 
23
  if current_page == 0:
24
  user_input = st.session_state.user_input
25
  prompt_response = pipeline(user_input, 10)
26
+ image_prompts_steps = prompt_response.get("image_prompts_steps")
27
  init_prompt = prompt_response.get("story")
28
 
29
+ init_img, img_dict = generate_story(init_prompt,
30
+ image_prompts_steps)
31
 
32
  st.session_state.pipeline_response = prompt_response
33
  st.session_state.init_img = init_img
 
43
  pipeline_response = st.session_state.pipeline_response
44
  text_output = pipeline_response.get("steps")[page_number - 1]
45
  img_dict = st.session_state.img_dict
46
+ img = img_dict[page_number - 1].get("image")
47
 
48
  return {"text_output": text_output, "image_obj": img}
49
 
 
57
 
58
  # Display content for each page
59
  if current_page == 0:
60
+ st.write("Describe a story you would like me to tell:")
61
  user_input = st.text_area("")
62
  st.session_state.user_input = user_input
63
 
 
70
  # Display text output
71
  st.write(text_output)
72
 
73
+ tts = gTTS(text_output.split(".", 1)[1])
74
  tts.save('audio.mp3')
75
  st.audio('audio.mp3')
76
 
img_gen.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from diffusers import StableDiffusionImg2ImgPipeline, \
4
+ StableDiffusionPipeline
5
+
6
+
7
+ def check_cuda_device():
8
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
9
+ return device
10
+
11
+
12
+ def get_the_model(device=None):
13
+ model_id = "stabilityai/stable-diffusion-2"
14
+ pipe = StableDiffusionPipeline.from_pretrained(model_id,
15
+ torch_dtype=torch.float16)
16
+ if device:
17
+ pipe.to(device)
18
+ else:
19
+ device = check_cuda_device()
20
+ pipe.to(device)
21
+
22
+ return pipe
23
+
24
+
25
+ def get_image_to_image_model(path=None, device=None):
26
+ model_id = "stabilityai/stable-diffusion-2"
27
+ if path:
28
+ pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
29
+ path,
30
+ torch_dtype=torch.float16)
31
+ else:
32
+ pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
33
+ model_id,
34
+ torch_dtype=torch.float16)
35
+ if device:
36
+ if device == "cuda" or device == "cpu":
37
+ pipe.to(device)
38
+ else:
39
+ device = check_cuda_device()
40
+ pipe.to(device)
41
+
42
+ return pipe
43
+
44
+
45
+ def gen_initial_img(int_prompt):
46
+ model = get_the_model(None)
47
+ image = model(int_prompt, num_inference_steps=100).images[0]
48
+
49
+ return image
50
+
51
+
52
+ def generate_story(int_prompt, steps, iterations=133):
53
+ image_dic = {}
54
+ init_img = gen_initial_img(int_prompt)
55
+ img2img_model = get_image_to_image_model()
56
+ img = init_img
57
+
58
+ for idx, step in enumerate(steps):
59
+ print(f"step: {idx}")
60
+ print(step)
61
+ image = img2img_model(prompt=step, image=img, strength=0.75, guidance_scale=7.5,
62
+ num_inference_steps=iterations).images[0]
63
+ image_dic[idx] = {
64
+ "image": image,
65
+ "prompt": step
66
+ }
67
+ img = image
68
+
69
+ return init_img, image_dic
prompt_generation.py CHANGED
@@ -97,6 +97,7 @@ def pipeline(user_description: str, n_steps: int = 10) -> dict:
97
 
98
  image_prompts = [fut.result() for fut in image_prompts_futures]
99
 
 
100
  return {"story": story, "steps": steps, "image_prompts": image_prompts}
101
 
102
 
 
97
 
98
  image_prompts = [fut.result() for fut in image_prompts_futures]
99
 
100
+ print(story)
101
  return {"story": story, "steps": steps, "image_prompts": image_prompts}
102
 
103