RanM commited on
Commit
e199da3
·
verified ·
1 Parent(s): 8facf69

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -0
app.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import asyncio
3
+ from concurrent.futures import ProcessPoolExecutor
4
+ from io import BytesIO
5
+ from PIL import Image
6
+ from diffusers import StableDiffusionPipeline
7
+ import gradio as gr
8
+ from generate_prompts import generate_prompt
9
+
10
+ # Load the model once at the start
11
+ print("Loading the Stable Diffusion model...")
12
+ try:
13
+ model = StableDiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo")
14
+ print("Model loaded successfully.")
15
+ except Exception as e:
16
+ print(f"Error loading model: {e}")
17
+ model = None
18
+
19
+ def generate_image(prompt, prompt_name):
20
+ try:
21
+ if model is None:
22
+ raise ValueError("Model not loaded properly.")
23
+
24
+ print(f"Generating image for {prompt_name} with prompt: {prompt}")
25
+ output = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
26
+ print(f"Model output for {prompt_name}: {output}")
27
+
28
+ if output is None:
29
+ raise ValueError(f"Model returned None for {prompt_name}")
30
+
31
+ if hasattr(output, 'images') and output.images:
32
+ print(f"Image generated for {prompt_name}")
33
+ image = output.images[0]
34
+ buffered = BytesIO()
35
+ image.save(buffered, format="JPEG")
36
+ image_bytes = buffered.getvalue()
37
+ return image_bytes
38
+ else:
39
+ print(f"No images found in model output for {prompt_name}")
40
+ raise ValueError(f"No images found in model output for {prompt_name}")
41
+ except Exception as e:
42
+ print(f"An error occurred while generating image for {prompt_name}: {e}")
43
+ return None
44
+
45
+ async def queue_api_calls(sentence_mapping, character_dict, selected_style):
46
+ print("Starting to queue API calls...")
47
+ prompts = []
48
+ for paragraph_number, sentences in sentence_mapping.items():
49
+ combined_sentence = " ".join(sentences)
50
+ prompt = generate_prompt(combined_sentence, sentence_mapping, character_dict, selected_style)
51
+ prompts.append((paragraph_number, prompt))
52
+ print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")
53
+
54
+ loop = asyncio.get_running_loop()
55
+ with ProcessPoolExecutor() as pool:
56
+ tasks = [
57
+ loop.run_in_executor(pool, generate_image, prompt, f"Prompt {paragraph_number}")
58
+ for paragraph_number, prompt in prompts
59
+ ]
60
+ responses = await asyncio.gather(*tasks)
61
+
62
+ images = {paragraph_number: response for (paragraph_number, _), response in zip(prompts, responses)}
63
+ print("Finished queuing API calls. Generated images: ", images)
64
+ return images
65
+
66
+ def process_prompt(sentence_mapping, character_dict, selected_style):
67
+ print("Processing prompt...")
68
+ print(f"Sentence Mapping: {sentence_mapping}")
69
+ print(f"Character Dict: {character_dict}")
70
+ print(f"Selected Style: {selected_style}")
71
+ try:
72
+ loop = asyncio.get_running_loop()
73
+ print("Using existing event loop.")
74
+ except RuntimeError:
75
+ loop = asyncio.new_event_loop()
76
+ asyncio.set_event_loop(loop)
77
+ print("Created new event loop.")
78
+
79
+ cmpt_return = loop.run_until_complete(queue_api_calls(sentence_mapping, character_dict, selected_style))
80
+ print("Prompt processing complete. Generated images: ", cmpt_return)
81
+ return cmpt_return
82
+
83
+ gradio_interface = gr.Interface(
84
+ fn=process_prompt,
85
+ inputs=[
86
+ gr.JSON(label="Sentence Mapping"),
87
+ gr.JSON(label="Character Dict"),
88
+ gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style")
89
+ ],
90
+ outputs="json"
91
+ ).queue(default_concurrency_limit=20) # Set concurrency limit if needed
92
+
93
+ if __name__ == "__main__":
94
+ print("Launching Gradio interface...")
95
+ gradio_interface.launch()