	Update app.py
app.py CHANGED
@@ -1,15 +1,24 @@
 import os
-import asyncio
+import multiprocessing
 from generate_prompts import generate_prompt
 from diffusers import AutoPipelineForText2Image
 from io import BytesIO
 import gradio as gr
+import json
 
-async def generate_image(prompt, prompt_name):
+# Define a global variable to hold the model
+model = None
+
+def initialize_model():
+    global model
+    if model is None:  # Ensure the model is loaded only once per process
+        print("Loading the model...")
+        model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
+        print("Model loaded successfully.")
+
+def generate_image(prompt, prompt_name):
     try:
         print(f"Generating response for {prompt_name} with prompt: {prompt}")
-        # Load the model instance for each prompt
-        model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
         output = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
         print(f"Output for {prompt_name}: {output}")
 
@@ -21,21 +30,20 @@ async def generate_image(prompt, prompt_name):
                 image.save(buffered, format="JPEG")
                 image_bytes = buffered.getvalue()
                 print(f"Image bytes length for {prompt_name}: {len(image_bytes)}")
-                return image_bytes
+                return prompt_name, image_bytes
             except Exception as e:
                 print(f"Error saving image for {prompt_name}: {e}")
-                return None
+                return prompt_name, None
         else:
             raise Exception(f"No images returned by the model for {prompt_name}.")
     except Exception as e:
         print(f"Error generating image for {prompt_name}: {e}")
-        return None
+        return prompt_name, None
 
-
-    print(f"
+def process_prompts(sentence_mapping, character_dict, selected_style):
+    print(f"process_prompts called with sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}")
+
     prompts = []
-
-    # Generate prompts for each paragraph
     for paragraph_number, sentences in sentence_mapping.items():
         combined_sentence = " ".join(sentences)
         print(f"combined_sentence for paragraph {paragraph_number}: {combined_sentence}")
@@ -43,33 +51,27 @@ async def queue_api_calls(sentence_mapping, character_dict, selected_style):
         prompts.append((paragraph_number, prompt))
         print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")
 
-
-
-
-
-
+    num_prompts = len(prompts)
+    print(f"Number of prompts: {num_prompts}")
+
+    # Limit the number of worker processes to the number of prompts
+    with multiprocessing.Pool(processes=num_prompts, initializer=initialize_model) as pool:
+        tasks = [(prompt, f"Prompt {paragraph_number}") for paragraph_number, prompt in prompts]
+        results = pool.starmap(generate_image, tasks)
 
-    images = {
+    images = {prompt_name: image for prompt_name, image in results}
     print(f"Images generated: {images}")
     return images
 
 def process_prompt(sentence_mapping, character_dict, selected_style):
     print(f"process_prompt called with sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}")
-
-
-
-
-
-
-        asyncio.set_event_loop(loop)
-    print("Event loop created.")
-
-    # This sends the prompts to function that sets up the async calls. Once all the calls to the API complete, it returns a list of the gr.Textbox with value= set.
-    cmpt_return = loop.run_until_complete(queue_api_calls(sentence_mapping, character_dict, selected_style))
-    print(f"process_prompt completed with return value: {cmpt_return}")
-    return cmpt_return
+    # Check if inputs are already in dict form
+    if isinstance(sentence_mapping, str):
+        sentence_mapping = json.loads(sentence_mapping)
+    if isinstance(character_dict, str):
+        character_dict = json.loads(character_dict)
+    return process_prompts(sentence_mapping, character_dict, selected_style)
 
-# Gradio interface with high concurrency limit
 gradio_interface = gr.Interface(
     fn=process_prompt,
     inputs=[
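
Note on the change: the commit replaces the old asyncio event-loop flow with a multiprocessing.Pool whose initializer loads the SDXL-Turbo pipeline once per worker process. The sketch below is a minimal, self-contained illustration of that initializer pattern, not the app's code: load_model is a hypothetical stand-in for the real AutoPipelineForText2Image.from_pretrained call, so the example runs without the diffusers dependency.

import multiprocessing

# Per-process global that stands in for the expensive pipeline object in app.py.
model = None

def load_model():
    # Hypothetical stand-in for AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo").
    return {"name": "dummy-model"}

def initialize_model():
    # Runs once in each worker process when the Pool starts it.
    global model
    if model is None:
        model = load_model()

def generate_image(prompt, prompt_name):
    # Each worker reuses the model loaded by its own initializer.
    return prompt_name, f"{model['name']} rendered: {prompt}"

if __name__ == "__main__":
    prompts = [(1, "a red fox"), (2, "a quiet harbor")]
    tasks = [(prompt, f"Prompt {n}") for n, prompt in prompts]
    # One worker per prompt, mirroring the pool sizing in the new process_prompts.
    with multiprocessing.Pool(processes=len(tasks), initializer=initialize_model) as pool:
        results = pool.starmap(generate_image, tasks)
    images = {name: img for name, img in results}
    print(images)

The new process_prompts follows the same shape: the pool is sized to the number of prompts, each worker loads its own model copy via the initializer, and starmap fans the (prompt, prompt_name) pairs out across the workers, with the results collected into a dict keyed by prompt name.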