	first commit
 - README.md +3 -3
 - app.py +948 -0
 - examples/prompt_background.txt +8 -0
 - examples/prompt_background_advanced.txt +0 -0
 - examples/prompt_boy.txt +15 -0
 - examples/prompt_girl.txt +16 -0
 - examples/prompt_props.txt +43 -0
 - model.py +1095 -0
 - prompt_util.py +154 -0
 - requirements.txt +16 -0
 - share_btn.py +70 -0
 - util.py +315 -0
 
    	
README.md CHANGED
@@ -1,12 +1,12 @@
 ---
-title:
-emoji:
+title: Semantic Palette with Stable Diffusion 3
+emoji: 🧠🎨3️
 colorFrom: red
 colorTo: yellow
 sdk: gradio
 sdk_version: 4.36.1
 app_file: app.py
-pinned:
+pinned: true
 license: mit
 ---
    	
app.py ADDED
@@ -0,0 +1,948 @@
# Copyright (c) 2024 Jaerin Lee

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import sys

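# NOTE: '../../src' presumably points at the local StreamMultiDiffusion source tree;
# appending a non-existent directory to sys.path is harmless.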
sys.path.append('../../src')

import argparse
import random
import time
import json
import os
import glob
import pathlib
from functools import partial
from pprint import pprint

import numpy as np
from PIL import Image
import torch

import gradio as gr
from huggingface_hub import snapshot_download

from model import StableMultiDiffusion3Pipeline
from util import seed_everything
from prompt_util import preprocess_prompts, _quality_dict, _style_dict


### Utils




def log_state(state):
    pprint(vars(opt))
    if isinstance(state, gr.State):
        state = state.value
    pprint(vars(state))


def is_empty_image(im: Image.Image) -> bool:
    if im is None:
        return True
    im = np.array(im)
    has_alpha = (im.shape[2] == 4)
    if not has_alpha:
        return False
    elif im.sum() == 0:
        return True
    else:
        return False


### Argument passing

parser = argparse.ArgumentParser(description='Semantic Palette demo powered by StreamMultiDiffusion with SD3 support.')
parser.add_argument('-H', '--height', type=int, default=1024)
parser.add_argument('-W', '--width', type=int, default=2560)
parser.add_argument('--model', type=str, default=None, help='Hugging face model repository or local path for a SD1.5 model checkpoint to run.')
parser.add_argument('--bootstrap_steps', type=int, default=2)
parser.add_argument('--seed', type=int, default=-1)
parser.add_argument('--device', type=int, default=0)
parser.add_argument('--port', type=int, default=8000)
opt = parser.parse_args()


### Global variables and data structures

device = f'cuda:{opt.device}' if opt.device >= 0 else 'cpu'


if opt.model is None:
    model_dict = {
        'Stable Diffusion 3': 'stabilityai/stable-diffusion-3-medium-diffusers',
    }
else:
    if opt.model.endswith('.safetensors'):
        opt.model = os.path.abspath(os.path.join('checkpoints', opt.model))
    model_dict = {os.path.splitext(os.path.basename(opt.model))[0]: opt.model}

dtype = torch.float32 if device == 'cpu' else torch.float16
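# One pipeline instance is built per entry in model_dict and kept resident on `device`
# (fp16 on GPU, fp32 on CPU).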
models = {
    k: StableMultiDiffusion3Pipeline(device, dtype=dtype, hf_key=v, has_i2t=False)
    for k, v in model_dict.items()
}


prompt_suggestions = [
    '1girl, souryuu asuka langley, neon genesis evangelion, solo, upper body, v, smile, looking at viewer',
    '1boy, solo, portrait, looking at viewer, white t-shirt, brown hair',
    '1girl, arima kana, oshi no ko, solo, upper body, from behind',
]

opt.max_palettes = 4
opt.default_prompt_strength = 1.0
opt.default_mask_strength = 1.0
opt.default_mask_std = 0.0
opt.default_negative_prompt = (
    'nsfw, worst quality, bad quality, normal quality, cropped, framed'
)
opt.verbose = True
opt.colors = [
    '#000000',
    '#2692F3',
    '#F89E12',
    '#16C232',
    '#F92F6C',
    # '#AC6AEB',
    # '#92C62C',
    # '#92C6EC',
    # '#FECAC0',
]


### Event handlers

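# Each handler below receives the per-session `state` (a simple namespace) holding the
# palette prompts, names, negative prompts, per-palette mask/prompt strengths and stds,
# the selected model/style/quality, the current/active palette indices, and the seed.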
def add_palette(state):
    old_actives = state.active_palettes
    state.active_palettes = min(state.active_palettes + 1, opt.max_palettes)

    if opt.verbose:
        log_state(state)

    if state.active_palettes != old_actives:
        return [state] + [
            gr.update() if state.active_palettes != opt.max_palettes else gr.update(visible=False)
        ] + [
            gr.update() if i != state.active_palettes - 1 else gr.update(value=state.prompt_names[i + 1], visible=True)
            for i in range(opt.max_palettes)
        ]
    else:
        return [state] + [gr.update() for i in range(opt.max_palettes + 1)]


def select_palette(state, button, idx):
    if idx < 0 or idx > opt.max_palettes:
        idx = 0
    old_idx = state.current_palette
    if old_idx == idx:
        return [state] + [gr.update() for _ in range(opt.max_palettes + 7)]

    state.current_palette = idx

    if opt.verbose:
        log_state(state)

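    # Returned updates: the state, one update per palette button (re-styling the old and
    # newly selected buttons), then the rename box, prompt, negative prompt, and the
    # three per-palette sliders.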
    updates = [state] + [
        gr.update() if i not in (idx, old_idx) else
        gr.update(variant='secondary') if i == old_idx else gr.update(variant='primary')
        for i in range(opt.max_palettes + 1)
    ]
    label = 'Background' if idx == 0 else f'Palette {idx}'
    updates.extend([
        gr.update(value=button, interactive=(idx > 0)),
        gr.update(value=state.prompts[idx], label=f'Edit Prompt for {label}'),
        gr.update(value=state.neg_prompts[idx], label=f'Edit Negative Prompt for {label}'),
        (
            gr.update(value=state.mask_strengths[idx - 1], interactive=True) if idx > 0 else
            gr.update(value=opt.default_mask_strength, interactive=False)
        ),
        (
            gr.update(value=state.prompt_strengths[idx - 1], interactive=True) if idx > 0 else
            gr.update(value=opt.default_prompt_strength, interactive=False)
        ),
        (
            gr.update(value=state.mask_stds[idx - 1], interactive=True) if idx > 0 else
            gr.update(value=opt.default_mask_std, interactive=False)
        ),
    ])
    return updates


def change_prompt_strength(state, strength):
    if state.current_palette == 0:
        return state

    state.prompt_strengths[state.current_palette - 1] = strength
    if opt.verbose:
        log_state(state)

    return state


def change_std(state, std):
    if state.current_palette == 0:
        return state

    state.mask_stds[state.current_palette - 1] = std
    if opt.verbose:
        log_state(state)

    return state


def change_mask_strength(state, strength):
    if state.current_palette == 0:
        return state

    state.mask_strengths[state.current_palette - 1] = strength
    if opt.verbose:
        log_state(state)

    return state


def reset_seed(state, seed):
    state.seed = seed
    if opt.verbose:
        log_state(state)

    return state

def rename_prompt(state, name):
    state.prompt_names[state.current_palette] = name
    if opt.verbose:
        log_state(state)

    return [state] + [
        gr.update() if i != state.current_palette else gr.update(value=name)
        for i in range(opt.max_palettes + 1)
    ]


def change_prompt(state, prompt):
    state.prompts[state.current_palette] = prompt
    if opt.verbose:
        log_state(state)

    return state


def change_neg_prompt(state, neg_prompt):
    state.neg_prompts[state.current_palette] = neg_prompt
    if opt.verbose:
        log_state(state)

    return state


def select_model(state, model_id):
    state.model_id = model_id
    if opt.verbose:
        log_state(state)

    return state


def select_style(state, style_name):
    state.style_name = style_name
    if opt.verbose:
        log_state(state)

    return state


def select_quality(state, quality_name):
    state.quality_name = quality_name
    if opt.verbose:
        log_state(state)

    return state


def import_state(state, json_text):
    current_palette = state.current_palette
    # active_palettes = state.active_palettes
    state = argparse.Namespace(**json.loads(json_text))
    state.active_palettes = opt.max_palettes
    return [state] + [
        gr.update(value=v, visible=True) for v in state.prompt_names
    ] + [
        # state.model_id,
        # state.style_name,
        # state.quality_name,
        state.prompts[current_palette],
        state.prompt_names[current_palette],
        state.neg_prompts[current_palette],
        state.prompt_strengths[current_palette - 1],
        state.mask_strengths[current_palette - 1],
        state.mask_stds[current_palette - 1],
        state.seed,
    ]


### Main worker

def generate(state, *args, **kwargs):
    return models[state.model_id](*args, **kwargs)



def run(state, drawpad):
    seed_everything(state.seed if state.seed >=0 else np.random.randint(2147483647))
    print('Generate!')

    background = drawpad['background'].convert('RGBA')
    inpainting_mode = np.asarray(background).sum() != 0
    print('Inpainting mode: ', inpainting_mode)

    user_input = np.asarray(drawpad['layers'][0]) # (H, W, 4)
    foreground_mask = torch.tensor(user_input[..., -1])[None, None] # (1, 1, H, W)
    user_input = torch.tensor(user_input[..., :-1]) # (H, W, 3)

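    # Match drawn pixels exactly against the palette colors (hex -> RGB) to build one
    # binary mask per palette; `has_masks` keeps only the palettes that were actually painted.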
    palette = torch.tensor([
        tuple(int(s[i+1:i+3], 16) for i in (0, 2, 4))
        for s in opt.colors[1:]
    ]) # (N, 3)
    masks = (palette[:, None, None, :] == user_input[None]).all(dim=-1)[:, None, ...] # (N, 1, H, W)
    has_masks = [i for i, m in enumerate(masks.sum(dim=(1, 2, 3)) == 0) if not m]
    print('Has mask: ', has_masks)
    masks = masks * foreground_mask
    masks = masks[has_masks]

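    # A non-empty background image switches to inpainting; otherwise an all-ones mask for
    # the background prompt is prepended so the whole canvas is generated.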
    if inpainting_mode:
        prompts = [state.prompts[v + 1] for v in has_masks]
        negative_prompts = [state.neg_prompts[v + 1] for v in has_masks]
        mask_strengths = [state.mask_strengths[v] for v in has_masks]
        mask_stds = [state.mask_stds[v] for v in has_masks]
        prompt_strengths = [state.prompt_strengths[v] for v in has_masks]
    else:
        masks = torch.cat([torch.ones_like(foreground_mask), masks], dim=0)
        prompts = [state.prompts[0]] + [state.prompts[v + 1] for v in has_masks]
        negative_prompts = [state.neg_prompts[0]] + [state.neg_prompts[v + 1] for v in has_masks]
        mask_strengths = [1] + [state.mask_strengths[v] for v in has_masks]
        mask_stds = [0] + [state.mask_stds[v] for v in has_masks]
        prompt_strengths = [1] + [state.prompt_strengths[v] for v in has_masks]

    prompts, negative_prompts = preprocess_prompts(
        prompts, negative_prompts, style_name=state.style_name, quality_name=state.quality_name)

    return generate(
        state,
        prompts,
        negative_prompts,
        masks=masks,
        mask_strengths=mask_strengths,
        mask_stds=mask_stds,
        prompt_strengths=prompt_strengths,
        background=background.convert('RGB'),
        background_prompt=state.prompts[0],
        background_negative_prompt=state.neg_prompts[0],
        height=opt.height,
        width=opt.width,
        bootstrap_steps=2,
        guidance_scale=0,
    )



### Load examples


root = pathlib.Path(__file__).parent
print(root)
example_root = os.path.join(root, 'examples')
example_images = glob.glob(os.path.join(example_root, '*.webp'))
example_images = [Image.open(i) for i in example_images]

with open(os.path.join(example_root, 'prompt_background_advanced.txt')) as f:
    prompts_background = [l.strip() for l in f.readlines() if l.strip() != '']

with open(os.path.join(example_root, 'prompt_girl.txt')) as f:
    prompts_girl = [l.strip() for l in f.readlines() if l.strip() != '']

with open(os.path.join(example_root, 'prompt_boy.txt')) as f:
    prompts_boy = [l.strip() for l in f.readlines() if l.strip() != '']

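# prompt_props.txt: each non-empty line is '<name>, <prompt>'; split on the first comma
# into a name -> prompt mapping.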
with open(os.path.join(example_root, 'prompt_props.txt')) as f:
    prompts_props = [l.strip() for l in f.readlines() if l.strip() != '']
    prompts_props = {l.split(',')[0].strip(): ','.join(l.split(',')[1:]).strip() for l in prompts_props}

prompt_background = lambda: random.choice(prompts_background)
prompt_girl = lambda: random.choice(prompts_girl)
prompt_boy = lambda: random.choice(prompts_boy)
prompt_props = lambda: np.random.choice(list(prompts_props.keys()), size=(opt.max_palettes - 2), replace=False).tolist()


### Main application

css = f"""
#run-button {{
    font-size: 30pt;
    background-image: linear-gradient(to right, #4338ca 0%, #26a0da 51%, #4338ca 100%);
    margin: 0;
    padding: 15px 45px;
    text-align: center;
    text-transform: uppercase;
    transition: 0.5s;
    background-size: 200% auto;
    color: white;
    box-shadow: 0 0 20px #eee;
    border-radius: 10px;
    display: block;
    background-position: right center;
}}

#run-button:hover {{
    background-position: left center;
    color: #fff;
    text-decoration: none;
}}

#semantic-palette {{
    border-style: solid;
    border-width: 0.2em;
    border-color: #eee;
}}

#semantic-palette:hover {{
    box-shadow: 0 0 20px #eee;
}}

#output-screen {{
    width: 100%;
    aspect-ratio: {opt.width} / {opt.height};
}}

.layer-wrap {{
    display: none;
}}

.rainbow {{
    text-align: center;
    text-decoration: underline;
    font-size: 32px;
    font-family: monospace;
    letter-spacing: 5px;
}}
.rainbow_text_animated {{
    background: linear-gradient(to right, #6666ff, #0099ff , #00ff00, #ff3399, #6666ff);
    -webkit-background-clip: text;
    background-clip: text;
    color: transparent;
    animation: rainbow_animation 6s ease-in-out infinite;
    background-size: 400% 100%;
}}

@keyframes rainbow_animation {{
    0%,100% {{
        background-position: 0 0;
    }}

    50% {{
        background-position: 100% 0;
    }}
}}

.gallery {{
  --z: 16px;  /* control the zig-zag  */
         
     | 
| 468 | 
         
            +
              --s: 144px; /* control the size */
         
     | 
| 469 | 
         
            +
              --g: 4px;   /* control the gap */
         
     | 
| 470 | 
         
            +
              
         
     | 
| 471 | 
         
            +
              display: grid;
         
     | 
| 472 | 
         
            +
              gap: var(--g);
         
     | 
| 473 | 
         
            +
              width: calc(2*var(--s) + var(--g));
         
     | 
| 474 | 
         
            +
              grid-auto-flow: column;
         
     | 
| 475 | 
         
            +
            }}
         
     | 
| 476 | 
         
            +
            .gallery > a {{
         
     | 
| 477 | 
         
            +
              width: 0;
         
     | 
| 478 | 
         
            +
              min-width: calc(100% + var(--z)/2);
         
     | 
| 479 | 
         
            +
              height: var(--s);
         
     | 
| 480 | 
         
            +
              object-fit: cover;
         
     | 
| 481 | 
         
            +
              -webkit-mask: var(--mask);
         
     | 
| 482 | 
         
            +
                      mask: var(--mask);
         
     | 
| 483 | 
         
            +
              cursor: pointer;
         
     | 
| 484 | 
         
            +
              transition: .5s;
         
     | 
| 485 | 
         
            +
            }}
         
     | 
| 486 | 
         
            +
            .gallery > a:hover {{
         
     | 
| 487 | 
         
            +
              width: calc(var(--s)/2);
         
     | 
| 488 | 
         
            +
            }}
         
     | 
| 489 | 
         
            +
            .gallery > a:first-child {{
         
     | 
| 490 | 
         
            +
              place-self: start;
         
     | 
| 491 | 
         
            +
              clip-path: polygon(calc(2*var(--z)) 0,100% 0,100% 100%,0 100%);
         
     | 
| 492 | 
         
            +
              --mask: 
         
     | 
| 493 | 
         
            +
                conic-gradient(from -135deg at right,#0000,#000 1deg 89deg,#0000 90deg) 
         
     | 
| 494 | 
         
            +
                  50%/100% calc(2*var(--z)) repeat-y;
         
     | 
| 495 | 
         
            +
            }}
         
     | 
| 496 | 
         
            +
            .gallery > a:last-child {{
         
     | 
| 497 | 
         
            +
              place-self: end;
         
     | 
| 498 | 
         
            +
              clip-path: polygon(0 0,100% 0,calc(100% - 2*var(--z)) 100%,0 100%);
         
     | 
| 499 | 
         
            +
              --mask: 
         
     | 
| 500 | 
         
            +
                conic-gradient(from   45deg at left ,#0000,#000 1deg 89deg,#0000 90deg) 
         
     | 
| 501 | 
         
            +
                  50% calc(50% - var(--z))/100% calc(2*var(--z)) repeat-y;
         
     | 
| 502 | 
         
            +
            }}
         
     | 
| 503 | 
         
            +
            """
         
     | 
| 504 | 
         
            +
             
     | 
| 505 | 
         
            +
            for i in range(opt.max_palettes + 1):
         
     | 
| 506 | 
         
            +
                css = css + f"""
         
     | 
| 507 | 
         
            +
            .secondary#semantic-palette-{i} {{
         
     | 
| 508 | 
         
            +
                background-image: linear-gradient(to right, #374151 0%, #374151 71%, {opt.colors[i]} 100%);
         
     | 
| 509 | 
         
            +
                color: white;
         
     | 
| 510 | 
         
            +
            }}
         
     | 
| 511 | 
         
            +
             
     | 
| 512 | 
         
            +
            .primary#semantic-palette-{i} {{
         
     | 
| 513 | 
         
            +
                background-image: linear-gradient(to right, #4338ca 0%, #4338ca 71%, {opt.colors[i]} 100%);
         
     | 
| 514 | 
         
            +
                color: white;
         
     | 
| 515 | 
         
            +
            }}
         
     | 
| 516 | 
         
            +
            """
         
     | 
| 517 | 
         
            +
             
     | 
| 518 | 
         
            +
             
     | 
| 519 | 
         
            +
            with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
         
     | 
| 520 | 
         
            +
             
     | 
| 521 | 
         
            +
                iface = argparse.Namespace()
         
     | 
| 522 | 
         
            +
             
     | 
| 523 | 
         
            +
                def _define_state():
         
     | 
| 524 | 
         
            +
                    state = argparse.Namespace()
         
     | 
| 525 | 
         
            +
             
     | 
| 526 | 
         
            +
                    # Cursor.
         
     | 
| 527 | 
         
            +
                    state.current_palette = 0 # 0: Background; 1,2,3,...: Layers
         
     | 
| 528 | 
         
            +
                    state.model_id = list(model_dict.keys())[0]
         
     | 
| 529 | 
         
            +
                    state.style_name = '(None)'
         
     | 
| 530 | 
         
            +
                    state.quality_name = '(None)' # 'Standard v3.1'
         
     | 
| 531 | 
         
            +
             
     | 
| 532 | 
         
            +
                    # State variables (one-hot).
         
     | 
| 533 | 
         
            +
                    state.active_palettes = 1
         
     | 
| 534 | 
         
            +
             
     | 
| 535 | 
         
            +
                    # Front-end initialized to the default values.
         
     | 
| 536 | 
         
            +
                    prompt_props_ = prompt_props()
         
     | 
| 537 | 
         
            +
                    state.prompt_names = [
         
     | 
| 538 | 
         
            +
                        '🌄 Background',
         
     | 
| 539 | 
         
            +
                        '👧 Girl',
         
     | 
| 540 | 
         
            +
                        '👦 Boy',
         
     | 
| 541 | 
         
            +
                    ] + prompt_props_ + ['🎨 New Palette' for _ in range(opt.max_palettes - 5)]
         
     | 
| 542 | 
         
            +
                    state.prompts = [
         
     | 
| 543 | 
         
            +
                        prompt_background(),
         
     | 
| 544 | 
         
            +
                        prompt_girl(),
         
     | 
| 545 | 
         
            +
                        prompt_boy(),
         
     | 
| 546 | 
         
            +
                    ] + [prompts_props[k] for k in prompt_props_] + ['' for _ in range(opt.max_palettes - 5)]
         
     | 
| 547 | 
         
            +
                    state.neg_prompts = [
         
     | 
| 548 | 
         
            +
                        opt.default_negative_prompt
         
     | 
| 549 | 
         
            +
                        + (', humans, humans, humans' if i == 0 else '')
         
     | 
| 550 | 
         
            +
                        for i in range(opt.max_palettes + 1)
         
     | 
| 551 | 
         
            +
                    ]
         
     | 
| 552 | 
         
            +
                    state.prompt_strengths = [opt.default_prompt_strength for _ in range(opt.max_palettes)]
         
     | 
| 553 | 
         
            +
                    state.mask_strengths = [opt.default_mask_strength for _ in range(opt.max_palettes)]
         
     | 
| 554 | 
         
            +
                    state.mask_stds = [opt.default_mask_std for _ in range(opt.max_palettes)]
         
     | 
| 555 | 
         
            +
                    state.seed = opt.seed
         
     | 
| 556 | 
         
            +
                    return state
         
     | 
| 557 | 
         
            +
             
     | 
| 558 | 
         
            +
                state = gr.State(value=_define_state)
         
     | 
| 559 | 
         
            +
             
     | 
| 560 | 
         
            +
             
     | 
| 561 | 
         
            +
                ### Demo user interface
         
     | 
| 562 | 
         
            +
             
     | 
| 563 | 
         
            +
                gr.HTML(
         
     | 
| 564 | 
         
            +
                    """
         
     | 
| 565 | 
         
            +
            <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
         
     | 
| 566 | 
         
            +
                <div>
         
     | 
| 567 | 
         
            +
                    <h1>🧠 Semantic Palette with <font class="rainbow rainbow_text_animated">Stable Diffusion 3</font> 🎨</h1>
         
     | 
| 568 | 
         
            +
                    <h5 style="margin: 0;">powered by</h5>
         
     | 
| 569 | 
         
            +
                    <h3>StreamMultiDiffusion: Real-Time Interactive Generation with Region-Based Semantic Control</h3>
         
     | 
| 570 | 
         
            +
                    <h5 style="margin: 0;">If you ❤️ our project, please visit our Github and give us a 🌟!</h5>
         
     | 
| 571 | 
         
            +
                    </br>
         
     | 
| 572 | 
         
            +
                    <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
         
     | 
| 573 | 
         
            +
                        <a href='https://jaerinlee.com/research/StreamMultiDiffusion'>
         
     | 
| 574 | 
         
            +
                            <img src='https://img.shields.io/badge/Project-Page-green' alt='Project Page'>
         
     | 
| 575 | 
         
            +
                        </a>
         
     | 
| 576 | 
         
            +
                         
         
     | 
| 577 | 
         
            +
                        <a href='https://arxiv.org/abs/2403.09055'>
         
     | 
| 578 | 
         
            +
                            <img src="https://img.shields.io/badge/arXiv-2403.09055-red">
         
     | 
| 579 | 
         
            +
                        </a>
         
     | 
| 580 | 
         
            +
                         
         
     | 
| 581 | 
         
            +
                        <a href='https://github.com/ironjr/StreamMultiDiffusion'>
         
     | 
| 582 | 
         
            +
                            <img src='https://img.shields.io/github/stars/ironjr/StreamMultiDiffusion?label=Github&color=blue'>
         
     | 
| 583 | 
         
            +
                        </a>
         
     | 
| 584 | 
         
            +
                         
         
     | 
| 585 | 
         
            +
                        <a href='https://twitter.com/_ironjr_'>
         
     | 
| 586 | 
         
            +
                            <img src='https://img.shields.io/twitter/url?label=_ironjr_&url=https%3A%2F%2Ftwitter.com%2F_ironjr_'>
         
     | 
| 587 | 
         
            +
                        </a>
         
     | 
| 588 | 
         
            +
                         
         
     | 
| 589 | 
         
            +
                        <a href='https://github.com/ironjr/StreamMultiDiffusion/blob/main/LICENSE'>
         
     | 
| 590 | 
         
            +
                            <img src='https://img.shields.io/badge/license-MIT-lightgrey'>
         
     | 
| 591 | 
         
            +
                        </a>
         
     | 
| 592 | 
         
            +
                         
         
     | 
| 593 | 
         
            +
                        <a href='https://huggingface.co/spaces/ironjr/StreamMultiDiffusion'>
         
     | 
| 594 | 
         
            +
                            <img src='https://img.shields.io/badge/%F0%9F%A4%97%20Demo-StreamMultiDiffusion-yellow'>
         
     | 
| 595 | 
         
            +
                        </a>
         
     | 
| 596 | 
         
            +
                         
         
     | 
| 597 | 
         
            +
                        <a href='https://huggingface.co/spaces/ironjr/SemanticPalette'>
         
     | 
| 598 | 
         
            +
                            <img src='https://img.shields.io/badge/%F0%9F%A4%97%20Demo-SD1.5-yellow'>
         
     | 
| 599 | 
         
            +
                        </a>
         
     | 
| 600 | 
         
            +
                         
         
     | 
| 601 | 
         
            +
                        <a href='https://huggingface.co/spaces/ironjr/SemanticPaletteXL'>
         
     | 
| 602 | 
         
            +
                            <img src='https://img.shields.io/badge/%F0%9F%A4%97%20Demo-SDXL-yellow'>
         
     | 
| 603 | 
         
            +
                        </a>
         
     | 
| 604 | 
         
            +
                         
         
     | 
| 605 | 
         
            +
                        <a href='https://huggingface.co/spaces/ironjr/SemanticPalette3'>
         
     | 
| 606 | 
         
            +
                            <img src='https://img.shields.io/badge/%F0%9F%A4%97%20Demo-SD3-yellow'>
         
     | 
| 607 | 
         
            +
                        </a>
         
     | 
| 608 | 
         
            +
                    </div>
         
     | 
| 609 | 
         
            +
                </div>
         
     | 
| 610 | 
         
            +
            </div>
         
     | 
| 611 | 
         
            +
            <div>
         
     | 
| 612 | 
         
            +
                </br>
         
     | 
| 613 | 
         
            +
            </div>
         
     | 
| 614 | 
         
            +
                    """
         
     | 
| 615 | 
         
            +
                )
         
     | 
| 616 | 
         
            +
             
     | 
| 617 | 
         
            +
                with gr.Row():
         
     | 
| 618 | 
         
            +
             
     | 
| 619 | 
         
            +
                    iface.image_slot = gr.Image(
         
     | 
| 620 | 
         
            +
                        interactive=False,
         
     | 
| 621 | 
         
            +
                        show_label=False,
         
     | 
| 622 | 
         
            +
                        show_download_button=True,
         
     | 
| 623 | 
         
            +
                        type='pil',
         
     | 
| 624 | 
         
            +
                        label='Generated Result',
         
     | 
| 625 | 
         
            +
                        elem_id='output-screen',
         
     | 
| 626 | 
         
            +
                        value=lambda: random.choice(example_images),
         
     | 
| 627 | 
         
            +
                    )
         
     | 
| 628 | 
         
            +
             
     | 
| 629 | 
         
            +
                with gr.Row():
         
     | 
| 630 | 
         
            +
             
     | 
| 631 | 
         
            +
                    with gr.Column(scale=1):
         
     | 
| 632 | 
         
            +
             
     | 
| 633 | 
         
            +
                        with gr.Group(elem_id='semantic-palette'):
         
     | 
| 634 | 
         
            +
             
     | 
| 635 | 
         
            +
                            gr.HTML(
         
     | 
| 636 | 
         
            +
                                """
         
     | 
| 637 | 
         
            +
            <div style="justify-content: center; align-items: center;">
         
     | 
| 638 | 
         
            +
                <br/>
         
     | 
| 639 | 
         
            +
                <h3 style="margin: 0; text-align: center;"><b>🧠 Semantic Palette 🎨</b></h3>
         
     | 
| 640 | 
         
            +
                <br/>
         
     | 
| 641 | 
         
            +
            </div>
         
     | 
| 642 | 
         
            +
                                """
         
     | 
| 643 | 
         
            +
                            )
         
     | 
| 644 | 
         
            +
             
     | 
| 645 | 
         
            +
                            iface.btn_semantics = [gr.Button(
         
     | 
| 646 | 
         
            +
                                value=state.value.prompt_names[0],
         
     | 
| 647 | 
         
            +
                                variant='primary',
         
     | 
| 648 | 
         
            +
                                elem_id='semantic-palette-0',
         
     | 
| 649 | 
         
            +
                            )]
         
     | 
| 650 | 
         
            +
                            for i in range(opt.max_palettes):
         
     | 
| 651 | 
         
            +
                                iface.btn_semantics.append(gr.Button(
         
     | 
| 652 | 
         
            +
                                    value=state.value.prompt_names[i + 1],
         
     | 
| 653 | 
         
            +
                                    variant='secondary',
         
     | 
| 654 | 
         
            +
                                    visible=(i < state.value.active_palettes),
         
     | 
| 655 | 
         
            +
                                    elem_id=f'semantic-palette-{i + 1}'
         
     | 
| 656 | 
         
            +
                                ))
         
     | 
| 657 | 
         
            +
             
     | 
| 658 | 
         
            +
                            iface.btn_add_palette = gr.Button(
         
     | 
| 659 | 
         
            +
                                value='Create New Semantic Brush',
         
     | 
| 660 | 
         
            +
                                variant='primary',
         
     | 
| 661 | 
         
            +
                            )
         
     | 
| 662 | 
         
            +
             
     | 
| 663 | 
         
            +
                        with gr.Accordion(label='Import/Export Semantic Palette', open=False):
         
     | 
| 664 | 
         
            +
                            iface.tbox_state_import = gr.Textbox(label='Put Palette JSON Here To Import')
         
     | 
| 665 | 
         
            +
                            iface.json_state_export = gr.JSON(label='Exported Palette')
         
     | 
| 666 | 
         
            +
                            iface.btn_export_state = gr.Button("Export Palette ➡️ JSON", variant='primary')
         
     | 
| 667 | 
         
            +
                            iface.btn_import_state = gr.Button("Import JSON ➡️ Palette", variant='secondary')
         
     | 
| 668 | 
         
            +
             
     | 
| 669 | 
         
            +
                        gr.HTML(
         
     | 
| 670 | 
         
            +
                            """
         
     | 
| 671 | 
         
            +
            <div>
         
     | 
| 672 | 
         
            +
            </br>
         
     | 
| 673 | 
         
            +
            </div>
         
     | 
| 674 | 
         
            +
            <div style="justify-content: center; align-items: center;">
         
     | 
| 675 | 
         
            +
            <h3 style="margin: 0; text-align: center;"><b>❓Usage❓</b></h3>
         
     | 
| 676 | 
         
            +
            </br>
         
     | 
| 677 | 
         
            +
            <div style="justify-content: center; align-items: left; text-align: left;">
         
     | 
| 678 | 
         
            +
                <p>1-1. Type in the background prompt. Background is not required if you paint the whole drawpad.</p>
         
     | 
| 679 | 
         
            +
                <p>1-2. (Optional: <em><b>Inpainting mode</b></em>) Uploading a background image will make the app into inpainting mode. Removing the image returns to the creation mode. In the inpainting mode, increasing the <em>Mask Blur STD</em> > 8 for every colored palette is recommended for smooth boundaries.</p>
         
     | 
| 680 | 
         
            +
                <p>2. Select a semantic brush by clicking onto one in the <b>Semantic Palette</b> above. Edit prompt for the semantic brush.</p>
         
     | 
| 681 | 
         
            +
                <p>2-1. If you are willing to draw more diverse images, try <b>Create New Semantic Brush</b>.</p>
         
     | 
| 682 | 
         
            +
                <p>3. Start drawing in the <b>Semantic Drawpad</b> tab. The brush color is directly linked to the semantic brushes.</p>
         
     | 
| 683 | 
         
            +
                <p>4. Click [<b>GENERATE!</b>] button to create your (large-scale) artwork!</p>
         
     | 
| 684 | 
         
            +
            </div>
         
     | 
| 685 | 
         
            +
            </div>
         
     | 
| 686 | 
         
            +
                            """
         
     | 
| 687 | 
         
            +
                        )
         
     | 
| 688 | 
         
            +
             
     | 
| 689 | 
         
            +
                        gr.HTML(
         
     | 
| 690 | 
         
            +
                            """
         
     | 
| 691 | 
         
            +
            <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
         
     | 
| 692 | 
         
            +
            <h5 style="margin: 0;"><b>... or run in your own 🤗 space!</b></h5>
         
     | 
| 693 | 
         
            +
            </div>
         
     | 
| 694 | 
         
            +
                            """
         
     | 
| 695 | 
         
            +
                        )
         
     | 
| 696 | 
         
            +
             
     | 
| 697 | 
         
            +
                        gr.DuplicateButton()
         
     | 
| 698 | 
         
            +
             
     | 
| 699 | 
         
            +
                    with gr.Column(scale=4):
         
     | 
| 700 | 
         
            +
             
     | 
| 701 | 
         
            +
                        with gr.Row():
         
     | 
| 702 | 
         
            +
             
     | 
| 703 | 
         
            +
                            with gr.Column(scale=3):
         
     | 
| 704 | 
         
            +
             
     | 
| 705 | 
         
            +
                                iface.ctrl_semantic = gr.ImageEditor(
         
     | 
| 706 | 
         
            +
                                    image_mode='RGBA',
         
     | 
| 707 | 
         
            +
                                    sources=['upload', 'clipboard', 'webcam'],
         
     | 
| 708 | 
         
            +
                                    transforms=['crop'],
         
     | 
| 709 | 
         
            +
                                    crop_size=(opt.width, opt.height),
         
     | 
| 710 | 
         
            +
                                    brush=gr.Brush(
         
     | 
| 711 | 
         
            +
                                        colors=opt.colors[1:],
         
     | 
| 712 | 
         
            +
                                        color_mode="fixed",
         
     | 
| 713 | 
         
            +
                                    ),
         
     | 
| 714 | 
         
            +
                                    layers=False,
         
     | 
| 715 | 
         
            +
                                    canvas_size=(opt.width, opt.height),
         
     | 
| 716 | 
         
            +
                                    type='pil',
         
     | 
| 717 | 
         
            +
                                    label='Semantic Drawpad',
         
     | 
| 718 | 
         
            +
                                    elem_id='drawpad',
         
     | 
| 719 | 
         
            +
                                )
         
     | 
| 720 | 
         
            +
             
     | 
| 721 | 
         
            +
                            with gr.Column(scale=1):
         
     | 
| 722 | 
         
            +
             
     | 
| 723 | 
         
            +
                                iface.btn_generate = gr.Button(
         
     | 
| 724 | 
         
            +
                                    value='Generate!',
         
     | 
| 725 | 
         
            +
                                    variant='primary',
         
     | 
| 726 | 
         
            +
                                    # scale=1,
         
     | 
| 727 | 
         
            +
                                    elem_id='run-button'
         
     | 
| 728 | 
         
            +
                                )
         
     | 
| 729 | 
         
            +
             
     | 
| 730 | 
         
            +
             
     | 
| 731 | 
         
            +
                                gr.HTML(
         
     | 
| 732 | 
         
            +
                                    """
         
     | 
| 733 | 
         
            +
            <h3 style="text-align: center;">Try other demos in HF 🤗 Space!</h3>
         
     | 
| 734 | 
         
            +
            <div style="display: flex; justify-content: center; text-align: center;">
         
     | 
| 735 | 
         
            +
                <div><b style="color: #2692F3">Semantic Palette<br>Animagine XL 3.1</b></div>
         
     | 
| 736 | 
         
            +
                <div style="margin-left: 10px; margin-right: 10px; margin-top: 8px">or</div>
         
     | 
| 737 | 
         
            +
                <div><b style="color: #F89E12">Official Demo of<br>StreamMultiDiffusion</b></div>
         
     | 
| 738 | 
         
            +
            </div>
         
     | 
| 739 | 
         
            +
            <div style="display: inline-block; margin-top: 10px">
         
     | 
| 740 | 
         
            +
                <div class="gallery">
         
     | 
| 741 | 
         
            +
                    <a href="https://huggingface.co/spaces/ironjr/SemanticPaletteXL" target="_blank">
         
     | 
| 742 | 
         
            +
                        <img alt="AnimagineXL3.1 Demo" src="https://github.com/ironjr/StreamMultiDiffusion/blob/main/demo/semantic_palette_sd3/examples/icons/sdxl.webp?raw=true">
         
     | 
| 743 | 
         
            +
                    </a>
         
     | 
| 744 | 
         
            +
                    <a href="https://huggingface.co/spaces/ironjr/StreamMultiDiffusion" target="_blank">
         
     | 
| 745 | 
         
            +
                        <img alt="StreamMultiDiffusion Demo" src="https://github.com/ironjr/StreamMultiDiffusion/blob/main/demo/semantic_palette_sd3/examples/icons/smd.gif?raw=true">
         
     | 
| 746 | 
         
            +
                    </a>
         
     | 
| 747 | 
         
            +
                </div>
         
     | 
| 748 | 
         
            +
            </div>
         
     | 
| 749 | 
         
            +
                                    """
         
     | 
| 750 | 
         
            +
                                )
         
     | 
| 751 | 
         
            +
             
     | 
| 752 | 
         
            +
                                # iface.model_select = gr.Radio(
         
     | 
| 753 | 
         
            +
                                #     list(model_dict.keys()),
         
     | 
| 754 | 
         
            +
                                #     label='Stable Diffusion Checkpoint',
         
     | 
| 755 | 
         
            +
                                #     info='Choose your favorite style.',
         
     | 
| 756 | 
         
            +
                                #     value=state.value.model_id,
         
     | 
| 757 | 
         
            +
                                # )
         
     | 
| 758 | 
         
            +
             
     | 
| 759 | 
         
            +
                                # with gr.Accordion(label='Prompt Engineering', open=True):
         
     | 
| 760 | 
         
            +
                                #     iface.quality_select = gr.Dropdown(
         
     | 
| 761 | 
         
            +
                                #         label='Quality Presets',
         
     | 
| 762 | 
         
            +
                                #         interactive=True,
         
     | 
| 763 | 
         
            +
                                #         choices=list(_quality_dict.keys()),
         
     | 
| 764 | 
         
            +
                                #         value='Standard v3.1',
         
     | 
| 765 | 
         
            +
                                #     )
         
     | 
| 766 | 
         
            +
                                #     iface.style_select = gr.Radio(
         
     | 
| 767 | 
         
            +
                                #         label='Style Preset',
         
     | 
| 768 | 
         
            +
                                #         container=True,
         
     | 
| 769 | 
         
            +
                                #         interactive=True,
         
     | 
| 770 | 
         
            +
                                #         choices=list(_style_dict.keys()),
         
     | 
| 771 | 
         
            +
                                #         value='(None)',
         
     | 
| 772 | 
         
            +
                                #     )
         
     | 
| 773 | 
         
            +
             
     | 
| 774 | 
         
            +
                        with gr.Group(elem_id='control-panel'):
         
     | 
| 775 | 
         
            +
             
     | 
| 776 | 
         
            +
                            with gr.Row():
         
     | 
| 777 | 
         
            +
                                iface.tbox_prompt = gr.Textbox(
         
     | 
| 778 | 
         
            +
                                    label='Edit Prompt for Background',
         
     | 
| 779 | 
         
            +
                                    info='What do you want to draw?',
         
     | 
| 780 | 
         
            +
                                    value=state.value.prompts[0],
         
     | 
| 781 | 
         
            +
                                    placeholder=lambda: random.choice(prompt_suggestions),
         
     | 
| 782 | 
         
            +
                                    scale=2,
         
     | 
| 783 | 
         
            +
                                )
         
     | 
| 784 | 
         
            +
             
     | 
| 785 | 
         
            +
                                iface.tbox_name = gr.Textbox(
         
     | 
| 786 | 
         
            +
                                    label='Edit Brush Name',
         
     | 
| 787 | 
         
            +
                                    info='Just for your convenience.',
         
     | 
| 788 | 
         
            +
                                    value=state.value.prompt_names[0],
         
     | 
| 789 | 
         
            +
                                    placeholder='🌄 Background',
         
     | 
| 790 | 
         
            +
                                    scale=1,
         
     | 
| 791 | 
         
            +
                                )
         
     | 
| 792 | 
         
            +
             
     | 
| 793 | 
         
            +
                            with gr.Row():
         
     | 
| 794 | 
         
            +
                                iface.tbox_neg_prompt = gr.Textbox(
         
     | 
| 795 | 
         
            +
                                    label='Edit Negative Prompt for Background',
         
     | 
| 796 | 
         
            +
                                    info='Add unwanted objects for this semantic brush.',
         
     | 
| 797 | 
         
            +
                                    value=opt.default_negative_prompt,
         
     | 
| 798 | 
         
            +
                                    scale=2,
         
     | 
| 799 | 
         
            +
                                )
         
     | 
| 800 | 
         
            +
             
     | 
| 801 | 
         
            +
                                iface.slider_strength = gr.Slider(
         
     | 
| 802 | 
         
            +
                                    label='Prompt Strength',
         
     | 
| 803 | 
         
            +
                                    info='Blends fg & bg in the prompt level, >0.8 Preferred.',
         
     | 
| 804 | 
         
            +
                                    minimum=0.5,
         
     | 
| 805 | 
         
            +
                                    maximum=1.0,
         
     | 
| 806 | 
         
            +
                                    value=opt.default_prompt_strength,
         
     | 
| 807 | 
         
            +
                                    scale=1,
         
     | 
| 808 | 
         
            +
                                )
         
     | 
| 809 | 
         
            +
             
     | 
| 810 | 
         
            +
                            with gr.Row():
         
     | 
| 811 | 
         
            +
                                iface.slider_alpha = gr.Slider(
         
     | 
| 812 | 
         
            +
                                    label='Mask Alpha',
         
     | 
| 813 | 
         
            +
                                    info='Factor multiplied to the mask before quantization. Extremely sensitive, >0.98 Preferred.',
         
     | 
| 814 | 
         
            +
                                    minimum=0.5,
         
     | 
| 815 | 
         
            +
                                    maximum=1.0,
         
     | 
| 816 | 
         
            +
                                    value=opt.default_mask_strength,
         
     | 
| 817 | 
         
            +
                                )
         
     | 
| 818 | 
         
            +
             
     | 
| 819 | 
         
            +
                                iface.slider_std = gr.Slider(
         
     | 
| 820 | 
         
            +
                                    label='Mask Blur STD',
         
     | 
| 821 | 
         
            +
                                    info='Blends fg & bg in the latent level, 0 for generation, 8-32 for inpainting.',
         
     | 
| 822 | 
         
            +
                                    minimum=0.0001,
         
     | 
| 823 | 
         
            +
                                    maximum=100.0,
         
     | 
| 824 | 
         
            +
                                    value=opt.default_mask_std,
         
     | 
| 825 | 
         
            +
                                )
         
     | 
| 826 | 
         
            +
             
     | 
| 827 | 
         
            +
                                iface.slider_seed = gr.Slider(
         
     | 
| 828 | 
         
            +
                                    label='Seed',
         
     | 
| 829 | 
         
            +
                                    info='The global seed.',
         
     | 
| 830 | 
         
            +
                                    minimum=-1,
         
     | 
| 831 | 
         
            +
                                    maximum=2147483647,
         
     | 
| 832 | 
         
            +
                                    step=1,
         
     | 
| 833 | 
         
            +
                                    value=opt.seed,
         
     | 
| 834 | 
         
            +
                                )
         
     | 
| 835 | 
         
            +
             
     | 
| 836 | 
         
            +
                ### Attach event handlers
         
     | 
| 837 | 
         
            +
             
     | 
| 838 | 
         
            +
                for idx, btn in enumerate(iface.btn_semantics):
         
     | 
| 839 | 
         
            +
                    btn.click(
         
     | 
| 840 | 
         
            +
                        fn=partial(select_palette, idx=idx),
         
     | 
| 841 | 
         
            +
                        inputs=[state, btn],
         
     | 
| 842 | 
         
            +
                        outputs=[state] + iface.btn_semantics + [
         
     | 
| 843 | 
         
            +
                            iface.tbox_name,
         
     | 
| 844 | 
         
            +
                            iface.tbox_prompt,
         
     | 
| 845 | 
         
            +
                            iface.tbox_neg_prompt,
         
     | 
| 846 | 
         
            +
                            iface.slider_alpha,
         
     | 
| 847 | 
         
            +
                            iface.slider_strength,
         
     | 
| 848 | 
         
            +
                            iface.slider_std,
         
     | 
| 849 | 
         
            +
                        ],
         
     | 
| 850 | 
         
            +
                        api_name=f'select_palette_{idx}',
         
     | 
| 851 | 
         
            +
                    )
         
     | 
| 852 | 
         
            +
             
     | 
| 853 | 
         
            +
                iface.btn_add_palette.click(
         
     | 
| 854 | 
         
            +
                    fn=add_palette,
         
     | 
| 855 | 
         
            +
                    inputs=state,
         
     | 
| 856 | 
         
            +
                    outputs=[state, iface.btn_add_palette] + iface.btn_semantics[1:],
         
     | 
| 857 | 
         
            +
                    api_name='create_new',
         
     | 
| 858 | 
         
            +
                )
         
     | 
| 859 | 
         
            +
             
     | 
| 860 | 
         
            +
                iface.btn_generate.click(
         
     | 
| 861 | 
         
            +
                    fn=run,
         
     | 
| 862 | 
         
            +
                    inputs=[state, iface.ctrl_semantic],
         
     | 
| 863 | 
         
            +
                    outputs=iface.image_slot,
         
     | 
| 864 | 
         
            +
                    api_name='run',
         
     | 
| 865 | 
         
            +
                )
         
     | 
| 866 | 
         
            +
             
     | 
| 867 | 
         
            +
                iface.slider_alpha.input(
         
     | 
| 868 | 
         
            +
                    fn=change_mask_strength,
         
     | 
| 869 | 
         
            +
                    inputs=[state, iface.slider_alpha],
         
     | 
| 870 | 
         
            +
                    outputs=state,
         
     | 
| 871 | 
         
            +
                    api_name='change_alpha',
         
     | 
| 872 | 
         
            +
                )
         
     | 
| 873 | 
         
            +
                iface.slider_std.input(
         
     | 
| 874 | 
         
            +
                    fn=change_std,
         
     | 
| 875 | 
         
            +
                    inputs=[state, iface.slider_std],
         
     | 
| 876 | 
         
            +
                    outputs=state,
         
     | 
| 877 | 
         
            +
                    api_name='change_std',
         
     | 
| 878 | 
         
            +
                )
         
     | 
| 879 | 
         
            +
                iface.slider_strength.input(
         
     | 
| 880 | 
         
            +
                    fn=change_prompt_strength,
         
     | 
| 881 | 
         
            +
                    inputs=[state, iface.slider_strength],
         
     | 
| 882 | 
         
            +
                    outputs=state,
         
     | 
| 883 | 
         
            +
                    api_name='change_strength',
         
     | 
| 884 | 
         
            +
                )
         
     | 
| 885 | 
         
            +
                iface.slider_seed.input(
         
     | 
| 886 | 
         
            +
                    fn=reset_seed,
         
     | 
| 887 | 
         
            +
                    inputs=[state, iface.slider_seed],
         
     | 
| 888 | 
         
            +
                    outputs=state,
         
     | 
| 889 | 
         
            +
                    api_name='reset_seed',
         
     | 
| 890 | 
         
            +
                )
         
     | 
| 891 | 
         
            +
             
     | 
| 892 | 
         
            +
                iface.tbox_name.input(
         
     | 
| 893 | 
         
            +
                    fn=rename_prompt,
         
     | 
| 894 | 
         
            +
                    inputs=[state, iface.tbox_name],
         
     | 
| 895 | 
         
            +
                    outputs=[state] + iface.btn_semantics,
         
     | 
| 896 | 
         
            +
                    api_name='prompt_rename',
         
     | 
| 897 | 
         
            +
                )
         
     | 
| 898 | 
         
            +
                iface.tbox_prompt.input(
         
     | 
| 899 | 
         
            +
                    fn=change_prompt,
         
     | 
| 900 | 
         
            +
                    inputs=[state, iface.tbox_prompt],
         
     | 
| 901 | 
         
            +
                    outputs=state,
         
     | 
| 902 | 
         
            +
                    api_name='prompt_edit',
         
     | 
| 903 | 
         
            +
                )
         
     | 
| 904 | 
         
            +
                iface.tbox_neg_prompt.input(
         
     | 
| 905 | 
         
            +
                    fn=change_neg_prompt,
         
     | 
| 906 | 
         
            +
                    inputs=[state, iface.tbox_neg_prompt],
         
     | 
| 907 | 
         
            +
                    outputs=state,
         
     | 
| 908 | 
         
            +
                    api_name='neg_prompt_edit',
         
     | 
| 909 | 
         
            +
                )
         
     | 
| 910 | 
         
            +
             
     | 
| 911 | 
         
            +
                # iface.model_select.change(
         
     | 
| 912 | 
         
            +
                #     fn=select_model,
         
     | 
| 913 | 
         
            +
                #     inputs=[state, iface.model_select],
         
     | 
| 914 | 
         
            +
                #     outputs=state,
         
     | 
| 915 | 
         
            +
                #     api_name='model_select',
         
     | 
| 916 | 
         
            +
                # )
         
     | 
| 917 | 
         
            +
                # iface.style_select.change(
         
     | 
| 918 | 
         
            +
                #     fn=select_style,
         
     | 
| 919 | 
         
            +
                #     inputs=[state, iface.style_select],
         
     | 
| 920 | 
         
            +
                #     outputs=state,
         
     | 
| 921 | 
         
            +
                #     api_name='style_select',
         
     | 
| 922 | 
         
            +
                # )
         
     | 
| 923 | 
         
            +
                # iface.quality_select.change(
         
     | 
| 924 | 
         
            +
                #     fn=select_quality,
         
     | 
| 925 | 
         
            +
                #     inputs=[state, iface.quality_select],
         
     | 
| 926 | 
         
            +
                #     outputs=state,
         
     | 
| 927 | 
         
            +
                #     api_name='quality_select',
         
     | 
| 928 | 
         
            +
                # )
         
     | 
| 929 | 
         
            +
             
     | 
| 930 | 
         
            +
                iface.btn_export_state.click(lambda x: vars(x), state, iface.json_state_export)
         
     | 
| 931 | 
         
            +
                iface.btn_import_state.click(import_state, [state, iface.tbox_state_import], [
         
     | 
| 932 | 
         
            +
                    state,
         
     | 
| 933 | 
         
            +
                    *iface.btn_semantics,
         
     | 
| 934 | 
         
            +
                    # iface.model_select,
         
     | 
| 935 | 
         
            +
                    # iface.style_select,
         
     | 
| 936 | 
         
            +
                    # iface.quality_select,
         
     | 
| 937 | 
         
            +
                    iface.tbox_prompt,
         
     | 
| 938 | 
         
            +
                    iface.tbox_name,
         
     | 
| 939 | 
         
            +
                    iface.tbox_neg_prompt,
         
     | 
| 940 | 
         
            +
                    iface.slider_strength,
         
     | 
| 941 | 
         
            +
                    iface.slider_alpha,
         
     | 
| 942 | 
         
            +
                    iface.slider_std,
         
     | 
| 943 | 
         
            +
                    iface.slider_seed,
         
     | 
| 944 | 
         
            +
                ])
         
     | 
| 945 | 
         
            +
             
     | 
| 946 | 
         
            +
             
     | 
| 947 | 
         
            +
            if __name__ == '__main__':
         
     | 
| 948 | 
         
            +
                demo.launch(server_port=opt.port)
         
     | 
    	
examples/prompt_background.txt ADDED
@@ -0,0 +1,8 @@
Maximalism, best quality, high quality, no humans, background, clear sky, black sky, starry universe, planets
Maximalism, best quality, high quality, no humans, background, clear sky, blue sky
Maximalism, best quality, high quality, no humans, background, universe, void, black, galaxy, galaxy, stars, stars, stars
Maximalism, best quality, high quality, no humans, background, galaxy
Maximalism, best quality, high quality, no humans, background, sky, daylight
Maximalism, best quality, high quality, no humans, background, skyscrapers, rooftop, city of light, helicopters, bright night, sky
Maximalism, best quality, high quality, flowers, flowers, flowers, flower garden, no humans, background
Maximalism, best quality, high quality, flowers, flowers, flowers, flower garden
examples/prompt_background_advanced.txt ADDED
The diff for this file is too large to render. See the raw diff.
examples/prompt_boy.txt ADDED
@@ -0,0 +1,15 @@
1boy, looking at viewer, brown hair, blue shirt
1boy, looking at viewer, brown hair, red shirt
1boy, looking at viewer, brown hair, purple shirt
1boy, looking at viewer, brown hair, orange shirt
1boy, looking at viewer, brown hair, yellow shirt
1boy, looking at viewer, brown hair, green shirt
1boy, looking back, side shaved hair, cyberpunk cloths, robotic suit, large body
1boy, looking back, short hair, renaissance cloths, noble boy
1boy, looking back, long hair, ponytail, leather jacket, heavy metal boy
1boy, looking at viewer, a king, kingly grace, majestic cloths, crown
1boy, looking at viewer, an astronaut, brown hair, faint smile, engineer
1boy, looking at viewer, a medieval knight, helmet, swordman, plate armour
1boy, looking at viewer, black haired, old eastern cloth
1boy, looking back, messy hair, suit, short beard, noir
1boy, looking at viewer, cute face, light smile, starry eyes, jeans
examples/prompt_girl.txt ADDED
@@ -0,0 +1,16 @@
1girl, looking at viewer, pretty face, light smile, haughty smile, proud, long wavy hair, charcoal dark eyes, chinese cloths
1girl, looking at viewer, princess, pretty face, light smile, haughty smile, proud, long wavy hair, charcoal dark eyes, majestic gown
1girl, looking at viewer, astronaut girl, long red hair, space suit, black starry eyes, happy face, pretty face
1girl, looking at viewer, fantasy adventurer, backpack
1girl, looking at viewer, astronaut girl, spacesuit, eva, happy face
1girl, looking at viewer, soldier, rusty cloths, backpack, pretty face, sad smile, tears
1girl, looking at viewer, majestic cloths, long hair, glittering eye, pretty face
1girl, looking at viewer, from behind, majestic cloths, long hair, glittering eye
1girl, looking at viewer, evil smile, very short hair, suit, evil genius
1girl, looking at viewer, elven queen, green hair, haughty face, eyes wide open, crazy smile, brown jacket, leaves
1girl, looking at viewer, purple hair, happy face, black leather jacket
1girl, looking at viewer, pink hair, happy face, blue jeans, black leather jacket
1girl, looking at viewer, knight, medium length hair, red hair, plate armour, blue eyes, sad, pretty face, determined face
1girl, looking at viewer, pretty face, light smile, orange hair, casual cloths
1girl, looking at viewer, pretty face, large smile, open mouth, uniform, mcdonald employee, short wavy hair
1girl, looking at viewer, brown hair, ponytail, happy face, bright smile, blue jeans and white shirt
examples/prompt_props.txt ADDED
@@ -0,0 +1,43 @@
+🏯 Palace, Gyeongbokgung palace
+🌳 Garden, Chinese garden
+🏛️ Rome, Ancient city of Rome
+🧱 Wall, Castle wall
+🔴 Mars, Martian desert, Red rocky desert
+🌻 Grassland, Grasslands
+🏡 Village, A fantasy village
+🐉 Dragon, a flying chinese dragon
+🌏 Earth, Earth seen from ISS
+🚀 Space Station, the international space station
+🪻 Grassland, Rusty grassland with flowers
+🖼️ Tapestry, majestic tapestry, glittering effect, glowing in light, mural painting with mountain
+🏙️ City Ruin, city, ruins, ruins, ruins, deserted
+🏙️ Renaissance City, renaissance city, renaissance city, renaissance city
+🌷 Flowers, Flower garden
+🌼 Flowers, Flower garden, spring garden
+🌹 Flowers, Flowers flowers, flowers
+⛰️ Dolomites Mountains, Dolomites
+⛰️ Himalayas Mountains, Himalayas
+⛰️ Alps Mountains, Alps
+⛰️ Mountains, Mountains
+❄️⛰️ Mountains, Winter mountains
+🌷⛰️ Mountains, Spring mountains
+🌞⛰️ Mountains, Summer mountains
+🌵 Desert, A sandy desert, dunes
+🪨🌵 Desert, A rocky desert
+💦 Waterfall, A giant waterfall
+🌊 Ocean, Ocean
+⛱️ Seashore, Seashore
+🌅 Sea Horizon, Sea horizon
+🌊 Lake, Clear blue lake
+💻 Computer, A giant supercomputer
+🌳 Tree, A giant tree
+🌳 Forest, A forest
+🌳🌳 Forest, A dense forest
+🌲 Forest, Winter forest
+🌴 Forest, Summer forest, tropical forest
+👒 Hat, A hat
+🐶 Dog, Doggy body parts
+😻 Cat, A cat
+🦉 Owl, A small sitting owl
+🦅 Eagle, A small sitting eagle
+🚀 Rocket, A flying rocket
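Each line of `prompt_props.txt` packs two fields into one string: an emoji-prefixed display label for the semantic palette UI, then the actual prompt text after the first comma (the boy/girl files are plain one-line prompts). Below is a minimal parsing sketch under that assumption; the helper name `load_prop_examples` is illustrative and not part of this commit.

```py
from pathlib import Path

def load_prop_examples(path: str) -> list[tuple[str, str]]:
    """Parse 'emoji Label, prompt text' lines from the props example file."""
    entries = []
    for line in Path(path).read_text(encoding='utf-8').splitlines():
        line = line.strip()
        if not line:
            continue
        label, _, prompt = line.partition(', ')  # split on the first comma only
        entries.append((label, prompt or label))
    return entries

# load_prop_examples('examples/prompt_props.txt')[0]
# -> ('🏯 Palace', 'Gyeongbokgung palace')
```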
    	
model.py ADDED
@@ -0,0 +1,1095 @@
+# Copyright (c) 2024 Jaerin Lee
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import inspect
+from typing import Any, Callable, Dict, List, Literal, Tuple, Optional, Union
+from tqdm import tqdm
+from PIL import Image
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.transforms as T
+from einops import rearrange
+
+from transformers import (
+    CLIPTextModelWithProjection,
+    CLIPTokenizer,
+    T5EncoderModel,
+    T5TokenizerFast,
+)
+from transformers import Blip2Processor, Blip2ForConditionalGeneration
+
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.loaders import FromSingleFileMixin, SD3LoraLoaderMixin
+from diffusers.models.attention_processor import (
+    AttnProcessor2_0,
+    FusedAttnProcessor2_0,
+    LoRAAttnProcessor2_0,
+    LoRAXFormersAttnProcessor,
+    XFormersAttnProcessor,
+)
+from diffusers.models.autoencoders import AutoencoderKL
+from diffusers.models.transformers import SD3Transformer2DModel
+from diffusers.pipelines.stable_diffusion_3 import StableDiffusion3PipelineOutput
+from diffusers.schedulers import (
+    FlowMatchEulerDiscreteScheduler,
+    FlashFlowMatchEulerDiscreteScheduler,
+)
+from diffusers.utils import (
+    is_torch_xla_available,
+    logging,
+    replace_example_docstring,
+)
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers import (
+    DiffusionPipeline,
+    StableDiffusion3Pipeline,
+)
+
+from peft import PeftModel
+
+from util import load_model, gaussian_lowpass, blend, get_panorama_views, shift_to_mask_bbox_center
+
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> import torch
+        >>> from diffusers import StableDiffusion3Pipeline
+
+        >>> pipe = StableDiffusion3Pipeline.from_pretrained(
+        ...     "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
+        ... )
+        >>> pipe.to("cuda")
+        >>> prompt = "A cat holding a sign that says hello world"
+        >>> image = pipe(prompt).images[0]
+        >>> image.save("sd3.png")
+        ```
+"""
+
+
+class StableMultiDiffusion3Pipeline(nn.Module):
+    def __init__(
+        self,
+        device: torch.device,
+        dtype: torch.dtype = torch.float16,
+        hf_key: Optional[str] = None,
+        lora_key: Optional[str] = None,
+        load_from_local: bool = False, # Turn on if you have already downloaded LoRA & Hugging Face hub is down.
+        default_mask_std: float = 1.0, # 8.0
+        default_mask_strength: float = 1.0,
+        default_prompt_strength: float = 1.0, # 8.0
+        default_bootstrap_steps: int = 1,
+        default_boostrap_mix_steps: float = 1.0,
+        default_bootstrap_leak_sensitivity: float = 0.2,
+        default_preprocess_mask_cover_alpha: float = 0.3,
+        t_index_list: List[int] = [0, 4, 12, 25, 37], # [0, 5, 16, 18, 20, 37], # # [0, 12, 25, 37], # Magic number.
+        mask_type: Literal['discrete', 'semi-continuous', 'continuous'] = 'discrete',
+        has_i2t: bool = True,
+        lora_weight: float = 1.0,
+    ) -> None:
+        r"""Stabilized MultiDiffusion for fast sampling.
+
+        Accelerated region-based text-to-image synthesis with Latent Consistency
+        Model while preserving mask fidelity and quality.
+
+        Args:
+            device (torch.device): Specify CUDA device.
+            hf_key (Optional[str]): Custom StableDiffusion checkpoint for
+                stylized generation.
+            lora_key (Optional[str]): Custom Lightning LoRA for acceleration.
+            load_from_local (bool): Turn on if you have already downloaded LoRA
+                & Hugging Face hub is down.
+            default_mask_std (float): Preprocess mask with Gaussian blur with
+                specified standard deviation.
+            default_mask_strength (float): Preprocess mask by multiplying it
+                globally with the specified variable. Caution: extremely
+                sensitive. Recommended range: 0.98-1.
+            default_prompt_strength (float): Preprocess foreground prompts
+                globally by linearly interpolating its embedding with the
+                background prompt embedding with specified mix ratio. Useful
+                control handle for foreground blending. Recommended range:
+                0.5-1.
+            default_bootstrap_steps (int): Bootstrapping stage steps to
+                encourage region separation. Recommended range: 1-3.
+            default_boostrap_mix_steps (float): Bootstrapping background is a
+                linear interpolation between background latent and the white
+                image latent. This handle controls the mix ratio. Available
+                range: 0-(number of bootstrapping inference steps). For
+                example, 2.3 means that for the first two steps, white image
+                is used as a bootstrapping background and in the third step,
+                mixture of white (0.3) and registered background (0.7) is used
+                as a bootstrapping background.
+            default_bootstrap_leak_sensitivity (float): Postprocessing at each
+                inference step by masking away the remaining bootstrap
+                backgrounds. Recommended range: 0-1.
+            default_preprocess_mask_cover_alpha (float): Optional preprocessing
+                where each mask covered by other masks is reduced in its alpha
+                value by this specified factor.
+            t_index_list (List[int]): The default scheduling for the scheduler.
+            mask_type (Literal['discrete', 'semi-continuous', 'continuous']):
+                Defines the mask quantization modes. Details in the code of
+                `self.process_mask`. Basically, this (subtly) controls the
+                smoothness of foreground-background blending. More continuous
+                means more blending, but a smaller generated patch depending on
+                the mask standard deviation.
+            has_i2t (bool): Automatic background image to text prompt con-
+                version with the BLIP-2 model. May not be necessary for the non-
+                streaming application.
+            lora_weight (float): Adjusts weight of the LCM/Lightning LoRA.
+                Heavily affects the overall quality!
+        """
+        super().__init__()
+
+        self.device = device
+        self.dtype = dtype
+
+        self.default_mask_std = default_mask_std
+        self.default_mask_strength = default_mask_strength
+        self.default_prompt_strength = default_prompt_strength
+        self.default_t_list = t_index_list
+        self.default_bootstrap_steps = default_bootstrap_steps
+        self.default_boostrap_mix_steps = default_boostrap_mix_steps
+        self.default_bootstrap_leak_sensitivity = default_bootstrap_leak_sensitivity
+        self.default_preprocess_mask_cover_alpha = default_preprocess_mask_cover_alpha
+        self.mask_type = mask_type
+
+        # Create model.
+        print(f'[INFO] Loading Stable Diffusion...')
+        if hf_key is not None:
+            print(f'[INFO] Using Hugging Face custom model key: {hf_key}')
+        else:
+            hf_key = "stabilityai/stable-diffusion-3-medium-diffusers"
+
+        transformer = SD3Transformer2DModel.from_pretrained(
+            hf_key,
+            subfolder="transformer",
+            torch_dtype=torch.float16,
+        ).to(self.device)
+
+        transformer = PeftModel.from_pretrained(transformer, "jasperai/flash-sd3").to(self.device)
+
+        self.pipe = StableDiffusion3Pipeline.from_pretrained(
+            "stabilityai/stable-diffusion-3-medium-diffusers",
+            transformer=transformer,
+            torch_dtype=torch.float16,
+            text_encoder_3=None,
+            tokenizer_3=None
+        ).to(self.device)
+
+        # Create the BLIP-2 image-to-text captioning model.
+        if has_i2t:
+            self.i2t_processor = Blip2Processor.from_pretrained('Salesforce/blip2-opt-2.7b')
+            self.i2t_model = Blip2ForConditionalGeneration.from_pretrained('Salesforce/blip2-opt-2.7b')
+
+        # Use the Flash-SD3 flow matching scheduler by default.
+        self.pipe.scheduler = FlashFlowMatchEulerDiscreteScheduler.from_pretrained(
+            "stabilityai/stable-diffusion-3-medium-diffusers", subfolder="scheduler")
+        self.pipe = self.pipe.to(self.device)
+
+        self.scheduler = self.pipe.scheduler
+        self.default_num_inference_steps = 4
+        self.default_guidance_scale = 0.0
+
+        if t_index_list is None:
+            self.prepare_flashflowmatch_schedule(
+                list(range(self.default_num_inference_steps)),
+                self.default_num_inference_steps,
+            )
+        else:
+            self.prepare_flashflowmatch_schedule(t_index_list, 50)
+
+        self.vae = self.pipe.vae
+        self.tokenizer = self.pipe.tokenizer
+        self.tokenizer_2 = self.pipe.tokenizer_2
+        self.tokenizer_3 = self.pipe.tokenizer_3
+        self.text_encoder = self.pipe.text_encoder
+        self.text_encoder_2 = self.pipe.text_encoder_2
+        self.text_encoder_3 = self.pipe.text_encoder_3
+        self.transformer = self.pipe.transformer
+        self.vae_scale_factor = self.pipe.vae_scale_factor
+
+        # Prepare white background for bootstrapping.
+        self.get_white_background(1024, 1024)
+
+        print(f'[INFO] Model is loaded!')
+
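The constructor above patches the SD3 transformer with the `jasperai/flash-sd3` LoRA, rebuilds a `StableDiffusion3Pipeline` around it without the T5 text encoder, and swaps in the flash flow-matching scheduler with a 4-step default. A minimal instantiation sketch, assuming a CUDA device and using only arguments documented in the signature above, would look like this (variable names are illustrative):

```py
import torch
from model import StableMultiDiffusion3Pipeline

# Defaults mirror the signature: fp16, 1-step bootstrapping, and the
# five-index schedule tuned for that bootstrapping depth.
smd = StableMultiDiffusion3Pipeline(
    device=torch.device('cuda'),
    dtype=torch.float16,
    t_index_list=[0, 4, 12, 25, 37],  # indices over the 50-step reference schedule
    has_i2t=False,                    # skip BLIP-2 when prompts are always provided
)
```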
+    def prepare_flashflowmatch_schedule(
+        self,
+        t_index_list: Optional[List[int]] = None,
+        num_inference_steps: Optional[int] = None,
+    ) -> None:
+        r"""Set up a different inference schedule for the diffusion model.
+
+        You do not have to run this explicitly if you want to use the default
+        setting, but if you want other time schedules, run this function
+        between the module initialization and the main call.
+
+        Note:
+          - Recommended t_index_lists for LCMs:
+              - [0, 12, 25, 37]: Default schedule for 4 steps. Best for
+                  panorama. Not recommended if you want to use bootstrapping.
+                  Because the bootstrapping stage affects the initial structuring
+                  of the generated image, and in this four-step LCM it is done
+                  only at the first step, the structure may be distorted.
+              - [0, 4, 12, 25, 37]: Recommended if you would use 1-step boot-
+                  strapping. Default initialization in this implementation.
+              - [0, 5, 16, 18, 20, 37]: Recommended if you would use 2-step
+                  bootstrapping.
+          - Due to the characteristic of the SD1.5 LCM LoRA, setting
+            `num_inference_steps` larger than 20 may result in overly blurry
+            and unrealistic images. Beware!
+
+        Args:
+            t_index_list (Optional[List[int]]): The specified scheduling steps
+                relative to the maximum timestep `num_inference_steps`, which
+                is 50 by default. That means that
+                `t_index_list=[0, 12, 25, 37]` gives relative time indices based
+                on the full scale of 50. If None, reinitialize the module with
+                the default value.
+            num_inference_steps (Optional[int]): The maximum timestep of the
+                sampler. Defines the relative scale of the `t_index_list`. Rarely
+                used in practice. If None, reinitialize the module with the
+                default value.
+        """
+        if t_index_list is None:
+            t_index_list = self.default_t_list
+        if num_inference_steps is None:
+            num_inference_steps = self.default_num_inference_steps
+
+        self.scheduler.set_timesteps(num_inference_steps)
+        self.timesteps = self.scheduler.timesteps[torch.tensor(t_index_list)].to(self.device)
+
+        # FlashFlowMatchEulerDiscreteScheduler
+        # https://github.com/initml/diffusers/blob/clement/feature/flash_sd3/src/diffusers/schedulers/scheduling_flash_flow_match_euler_discrete.py
+
+        self.sigmas = self.scheduler.sigmas[torch.tensor(t_index_list)].to(self.device)
+        self.sigmas_next = torch.cat([self.sigmas, self.sigmas.new_zeros(1)])[1:].to(self.device)
+
+        noise_lvs = self.sigmas * (self.sigmas**2 + 1)**(-0.5)
+        self.noise_lvs = noise_lvs[None, :, None, None, None]
+        self.next_noise_lvs = torch.cat([noise_lvs[1:], noise_lvs.new_zeros(1)])[None, :, None, None, None]
+
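The schedule preparation above first picks a subset of the scheduler's 50 sigmas via `t_index_list`, then converts each selected sigma into a per-step noise level with `sigma / sqrt(sigma**2 + 1)`, broadcast to (batch, step, channel, height, width) for later blending. A standalone sketch of that index-then-convert logic, assuming the stock diffusers flow-matching scheduler config (the Flash variant keeps the same sigma bookkeeping), is:

```py
import torch
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", subfolder="scheduler")
scheduler.set_timesteps(50)                        # full 50-step reference schedule

t_index_list = [0, 4, 12, 25, 37]                  # default indices from __init__
sigmas = scheduler.sigmas[torch.tensor(t_index_list)]
noise_lvs = sigmas * (sigmas**2 + 1) ** (-0.5)     # map each sigma into (0, 1)
print(noise_lvs.shape)                             # torch.Size([5])
```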
| 301 | 
         
            +
                @torch.no_grad()
         
     | 
| 302 | 
         
            +
                def get_text_prompts(self, image: Image.Image) -> str:
         
     | 
| 303 | 
         
            +
                    r"""A convenient method to extract text prompt from an image.
         
     | 
| 304 | 
         
            +
             
     | 
| 305 | 
         
            +
                    This is called if the user does not provide background prompt but only
         
     | 
| 306 | 
         
            +
                    the background image. We use BLIP-2 to automatically generate prompts.
         
     | 
| 307 | 
         
            +
             
     | 
| 308 | 
         
            +
                    Args:
         
     | 
| 309 | 
         
            +
                        image (Image.Image): A PIL image.
         
     | 
| 310 | 
         
            +
             
     | 
| 311 | 
         
            +
                    Returns:
         
     | 
| 312 | 
         
            +
                        A single string of text prompt.
         
     | 
| 313 | 
         
            +
                    """
         
     | 
| 314 | 
         
            +
                    if hasattr(self, 'i2t_model'):
         
     | 
| 315 | 
         
            +
                        question = 'Question: What are in the image? Answer:'
         
     | 
| 316 | 
         
            +
                        inputs = self.i2t_processor(image, question, return_tensors='pt')
         
     | 
| 317 | 
         
            +
                        out = self.i2t_model.generate(**inputs, max_new_tokens=77)
         
     | 
| 318 | 
         
            +
                        prompt = self.i2t_processor.decode(out[0], skip_special_tokens=True).strip()
         
     | 
| 319 | 
         
            +
                        return prompt
         
     | 
| 320 | 
         
            +
                    else:
         
     | 
| 321 | 
         
            +
                        return ''
         
     | 
| 322 | 
         
            +
             
     | 
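                 # Illustrative sketch (assumption, not from this file): `i2t_processor` and
                 # `i2t_model` are expected to be a BLIP-2 captioner from Hugging Face
                 # transformers, e.g.:
                 #
                 # >>> from transformers import Blip2Processor, Blip2ForConditionalGeneration
                 # >>> i2t_processor = Blip2Processor.from_pretrained('Salesforce/blip2-opt-2.7b')
                 # >>> i2t_model = Blip2ForConditionalGeneration.from_pretrained('Salesforce/blip2-opt-2.7b')
                 # >>> prompt = smd.get_text_prompts(Image.open('background.png'))  # hypothetical file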
| 323 | 
         
            +
                @torch.no_grad()
         
     | 
| 324 | 
         
            +
                def encode_imgs(
         
     | 
| 325 | 
         
            +
                    self,
         
     | 
| 326 | 
         
            +
                    imgs: torch.Tensor,
         
     | 
| 327 | 
         
            +
                    generator: Optional[torch.Generator] = None,
         
     | 
| 328 | 
         
            +
                    vae: Optional[nn.Module] = None,
         
     | 
| 329 | 
         
            +
                ) -> torch.Tensor:
         
     | 
| 330 | 
         
            +
                    r"""A wrapper function for VAE encoder of the latent diffusion model.
         
     | 
| 331 | 
         
            +
             
     | 
| 332 | 
         
            +
                    Args:
         
     | 
| 333 | 
         
            +
                        imgs (torch.Tensor): An image to get StableDiffusion latents.
         
     | 
| 334 | 
         
            +
                            Expected shape: (B, 3, H, W). Expected pixel scale: [0, 1].
         
     | 
| 335 | 
         
            +
                        generator (Optional[torch.Generator]): Seed for KL-Autoencoder.
         
     | 
| 336 | 
         
            +
                        vae (Optional[nn.Module]): Explicitly specify VAE (used for
         
     | 
| 337 | 
         
            +
                            the demo application with TinyVAE).
         
     | 
| 338 | 
         
            +
             
     | 
| 339 | 
         
            +
                    Returns:
         
     | 
| 340 | 
         
            +
                        An image latent embedding with 1/8 size (depending on the auto-
         
     | 
| 341 | 
         
            +
                        encoder). Shape: (B, 4, H//8, W//8).
         
     | 
| 342 | 
         
            +
                    """
         
     | 
| 343 | 
         
            +
                    def _retrieve_latents(
         
     | 
| 344 | 
         
            +
                        encoder_output: torch.Tensor,
         
     | 
| 345 | 
         
            +
                        generator: Optional[torch.Generator] = None,
         
     | 
| 346 | 
         
            +
                        sample_mode: str = 'sample',
         
     | 
| 347 | 
         
            +
                    ):
         
     | 
| 348 | 
         
            +
                        if hasattr(encoder_output, 'latent_dist') and sample_mode == 'sample':
         
     | 
| 349 | 
         
            +
                            return encoder_output.latent_dist.sample(generator)
         
     | 
| 350 | 
         
            +
                        elif hasattr(encoder_output, 'latent_dist') and sample_mode == 'argmax':
         
     | 
| 351 | 
         
            +
                            return encoder_output.latent_dist.mode()
         
     | 
| 352 | 
         
            +
                        elif hasattr(encoder_output, 'latents'):
         
     | 
| 353 | 
         
            +
                            return encoder_output.latents
         
     | 
| 354 | 
         
            +
                        else:
         
     | 
| 355 | 
         
            +
                            raise AttributeError('Could not access latents of provided encoder_output')
         
     | 
| 356 | 
         
            +
             
     | 
| 357 | 
         
            +
                    vae = self.vae if vae is None else vae
         
     | 
| 358 | 
         
            +
                    imgs = 2 * imgs - 1
         
     | 
| 359 | 
         
            +
                    latents = vae.config.scaling_factor * _retrieve_latents(vae.encode(imgs), generator=generator)
         
     | 
| 360 | 
         
            +
                    return latents
         
     | 
| 361 | 
         
            +
             
     | 
| 362 | 
         
            +
                @torch.no_grad()
         
     | 
| 363 | 
         
            +
                def decode_latents(self, latents: torch.Tensor, vae: Optional[nn.Module] = None) -> torch.Tensor:
         
     | 
| 364 | 
         
            +
                    r"""A wrapper function for VAE decoder of the latent diffusion model.
         
     | 
| 365 | 
         
            +
             
     | 
| 366 | 
         
            +
                    Args:
         
     | 
| 367 | 
         
            +
                        latents (torch.Tensor): An image latent to get associated images.
         
     | 
| 368 | 
         
            +
                            Expected shape: (B, 4, H//8, W//8).
         
     | 
| 369 | 
         
            +
                        vae (Optional[nn.Module]): Explicitly specify VAE (used for
         
     | 
| 370 | 
         
            +
                            the demo application with TinyVAE).
         
     | 
| 371 | 
         
            +
             
     | 
| 372 | 
         
            +
                    Returns:
         
     | 
| 373 | 
         
            +
                        An image tensor decoded from the given latent, with pixel values
         
     | 
| 374 | 
         
            +
                        clipped to [0, 1]. Shape: (B, 3, H, W).
         
     | 
| 375 | 
         
            +
                    """
         
     | 
| 376 | 
         
            +
                    vae = self.vae if vae is None else vae
         
     | 
| 377 | 
         
            +
                    latents = 1 / vae.config.scaling_factor * latents
         
     | 
| 378 | 
         
            +
                    imgs = vae.decode(latents).sample
         
     | 
| 379 | 
         
            +
                    imgs = (imgs / 2 + 0.5).clip_(0, 1)
         
     | 
| 380 | 
         
            +
                    return imgs
         
     | 
| 381 | 
         
            +
             
     | 
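                 # Illustrative round trip, assuming a loaded pipeline instance `smd` and a
                 # hypothetical input file: images in [0, 1] are encoded to scaled latents
                 # and decoded back to pixel space.
                 #
                 # >>> img = T.ToTensor()(Image.open('photo.png'))[None]      # (1, 3, H, W), values in [0, 1]
                 # >>> lat = smd.encode_imgs(img.to(smd.device, smd.dtype))   # (1, C, H//8, W//8)
                 # >>> rec = smd.decode_latents(lat)                          # (1, 3, H, W), values in [0, 1]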
| 382 | 
         
            +
                @torch.no_grad()
         
     | 
| 383 | 
         
            +
                def get_white_background(self, height: int, width: int) -> torch.Tensor:
         
     | 
| 384 | 
         
            +
                    r"""White background image latent for bootstrapping or in case of
         
     | 
| 385 | 
         
            +
                    absent background.
         
     | 
| 386 | 
         
            +
             
     | 
| 387 | 
         
            +
                    Additionally stores the maximally-sized white latent for fast retrieval
         
     | 
| 388 | 
         
            +
                    in the future. By default, we initially call this with 1024x1024 sized
         
     | 
| 389 | 
         
            +
                    white image, so the function is rarely visited twice.
         
     | 
| 390 | 
         
            +
             
     | 
| 391 | 
         
            +
                    Args:
         
     | 
| 392 | 
         
            +
                        height (int): The height of the white *image*, not its latent.
         
     | 
| 393 | 
         
            +
                        width (int): The width of the white *image*, not its latent.
         
     | 
| 394 | 
         
            +
             
     | 
| 395 | 
         
            +
                    Returns:
         
     | 
| 396 | 
         
            +
                        A white image latent of size (1, 4, height//8, width//8). A cropped
         
     | 
| 397 | 
         
            +
                        version of the stored white latent is returned if the requested
         
     | 
| 398 | 
         
            +
                        size is smaller than what we already have created.
         
     | 
| 399 | 
         
            +
                    """
         
     | 
| 400 | 
         
            +
                    if not hasattr(self, 'white') or self.white.shape[-2] < height or self.white.shape[-1] < width:
         
     | 
| 401 | 
         
            +
                        white = torch.ones(1, 3, height, width, dtype=self.dtype, device=self.device)
         
     | 
| 402 | 
         
            +
                        self.white = self.encode_imgs(white)
         
     | 
| 403 | 
         
            +
                        return self.white
         
     | 
| 404 | 
         
            +
                    return self.white[..., :(height // self.vae_scale_factor), :(width // self.vae_scale_factor)]
         
     | 
| 405 | 
         
            +
             
     | 
| 406 | 
         
            +
                @torch.no_grad()
         
     | 
| 407 | 
         
            +
                def process_mask(
         
     | 
| 408 | 
         
            +
                    self,
         
     | 
| 409 | 
         
            +
                    masks: Union[torch.Tensor, Image.Image, List[Image.Image]],
         
     | 
| 410 | 
         
            +
                    strength: Optional[Union[torch.Tensor, float]] = None,
         
     | 
| 411 | 
         
            +
                    std: Optional[Union[torch.Tensor, float]] = None,
         
     | 
| 412 | 
         
            +
                    height: int = 1024,
         
     | 
| 413 | 
         
            +
                    width: int = 1024,
         
     | 
| 414 | 
         
            +
                    use_boolean_mask: bool = True,
         
     | 
| 415 | 
         
            +
                    timesteps: Optional[torch.Tensor] = None,
         
     | 
| 416 | 
         
            +
                    preprocess_mask_cover_alpha: Optional[float] = None,
         
     | 
| 417 | 
         
            +
                ) -> Tuple[torch.Tensor]:
         
     | 
| 418 | 
         
            +
                    r"""Fast preprocess of masks for region-based generation with fine-
         
     | 
| 419 | 
         
            +
                    grained controls.
         
     | 
| 420 | 
         
            +
             
     | 
| 421 | 
         
            +
                    Mask preprocessing is done in four steps:
         
     | 
| 422 | 
         
            +
                     1. Resizing: Resize the masks into the specified width and height by
         
     | 
| 423 | 
         
            +
                        nearest neighbor interpolation.
         
     | 
| 424 | 
         
            +
                     2. (Optional) Ordering: Masks with higher indices are considered to
         
     | 
| 425 | 
         
            +
                        cover the masks with smaller indices. Covered masks are decayed
         
     | 
| 426 | 
         
            +
                        in their alpha values by the specified factor of
         
     | 
| 427 | 
         
            +
                        `preprocess_mask_cover_alpha`.
         
     | 
| 428 | 
         
            +
                     3. Blurring: Gaussian blur is applied to the mask with the specified
         
     | 
| 429 | 
         
            +
                        standard deviation (isotropic). This results in gradual increase of
         
     | 
| 430 | 
         
            +
                        masked region as the timesteps evolve, naturally blending fore-
         
     | 
| 431 | 
         
            +
                        ground and the predesignated background. Not strictly required if
         
     | 
| 432 | 
         
            +
                        you want to produce images from scratch without a background.
         
     | 
| 433 | 
         
            +
                     4. Quantization: Split the real-numbered masks of value between [0, 1]
         
     | 
| 434 | 
         
            +
                        into predefined noise levels for each quantized scheduling step of
         
     | 
| 435 | 
         
            +
                        the diffusion sampler. For example, if the diffusion model sampler
         
     | 
| 436 | 
         
            +
                        has noise level of [0.9977, 0.9912, 0.9735, 0.8499, 0.5840], which
         
     | 
| 437 | 
         
            +
                        is the default noise level of this module with schedule [0, 4, 12,
         
     | 
| 438 | 
         
            +
                        25, 37], the masks are split into binary masks whose values are
         
     | 
| 439 | 
         
            +
                        greater than these levels. This results in a gradual increase of the mask
         
     | 
| 440 | 
         
            +
                        region as the timesteps increase. Details are described in our
         
     | 
| 441 | 
         
            +
                        paper at https://arxiv.org/pdf/2403.09055.pdf.
         
     | 
| 442 | 
         
            +
             
     | 
| 443 | 
         
            +
                    On the Three Modes of `mask_type`:
         
     | 
| 444 | 
         
            +
                        `self.mask_type` is predefined at the initialization stage of this
         
     | 
| 445 | 
         
            +
                        pipeline. Three possible modes are available: 'discrete', 'semi-
         
     | 
| 446 | 
         
            +
                        continuous', and 'continuous'. These define the mask quantization
         
     | 
| 447 | 
         
            +
                        modes we use. Basically, this (subtly) controls the smoothness of
         
     | 
| 448 | 
         
            +
                        foreground-background blending. The continuous mode produces nonbinary
         
     | 
| 449 | 
         
            +
                        masks to further blend foreground and background latents by linear-
         
     | 
| 450 | 
         
            +
                        ly interpolating between them. The semi-continuous mode only applies a
         
     | 
| 451 | 
         
            +
                        continuous mask at the last step of the LCM sampler. Due to the
         
     | 
| 452 | 
         
            +
                        large step size of the LCM scheduler, we find that our continuous
         
     | 
| 453 | 
         
            +
                        blending helps generating seamless inpainting and editing results.
         
     | 
| 454 | 
         
            +
             
     | 
| 455 | 
         
            +
                    Args:
         
     | 
| 456 | 
         
            +
                        masks (Union[torch.Tensor, Image.Image, List[Image.Image]]): Masks.
         
     | 
| 457 | 
         
            +
                        strength (Optional[Union[torch.Tensor, float]]): Mask strength that
         
     | 
| 458 | 
         
            +
                            overrides the default value. A globally multiplied factor to
         
     | 
| 459 | 
         
            +
                            the mask at the initial stage of processing. Can be applied
         
     | 
| 460 | 
         
            +
                            separately for each mask.
         
     | 
| 461 | 
         
            +
                        std (Optional[Union[torch.Tensor, float]]): Mask blurring Gaussian
         
     | 
| 462 | 
         
            +
                            kernel's standard deviation. Overrides the default value. Can
         
     | 
| 463 | 
         
            +
                            be applied separately for each mask.
         
     | 
| 464 | 
         
            +
                        height (int): The height of the expected generation. Mask is
         
     | 
| 465 | 
         
            +
                            resized to (height//8, width//8) with nearest neighbor inter-
         
     | 
| 466 | 
         
            +
                            polation.
         
     | 
| 467 | 
         
            +
                        width (int): The width of the expected generation. Mask is resized
         
     | 
| 468 | 
         
            +
                            to (height//8, width//8) with nearest neighbor interpolation.
         
     | 
| 469 | 
         
            +
                        use_boolean_mask (bool): Specify this to treat the mask image as
         
     | 
| 470 | 
         
            +
                            a boolean tensor. The region darker than 0.5 of
         
     | 
| 471 | 
         
            +
                            the maximal pixel value (that is, 127.5) is considered as the
         
     | 
| 472 | 
         
            +
                            designated mask.
         
     | 
| 473 | 
         
            +
                        timesteps (Optional[torch.Tensor]): Defines the scheduler noise
         
     | 
| 474 | 
         
            +
                            levels that act as bins for mask quantization.
         
     | 
| 475 | 
         
            +
                        preprocess_mask_cover_alpha (Optional[float]): Optional pre-
         
     | 
| 476 | 
         
            +
                            processing where each mask covered by other masks is reduced in
         
     | 
| 477 | 
         
            +
                            its alpha value by this specified factor. Overrides the default
         
     | 
| 478 | 
         
            +
                            value.
         
     | 
| 479 | 
         
            +
             
     | 
| 480 | 
         
            +
                    Returns: A tuple of tensors.
         
     | 
| 481 | 
         
            +
                      - masks: Preprocessed (ordered, blurred, and quantized) binary/non-
         
     | 
| 482 | 
         
            +
                            binary masks (see the explanation on `mask_type` above) for
         
     | 
| 483 | 
         
            +
                            region-based image synthesis.
         
     | 
| 484 | 
         
            +
                      - masks_blurred: Gaussian blurred masks. Used for optionally
         
     | 
| 485 | 
         
            +
                            specified foreground-background blending after image
         
     | 
| 486 | 
         
            +
                            generation.
         
     | 
| 487 | 
         
            +
                      - std: Mask blur standard deviation. Used for optionally specified
         
     | 
| 488 | 
         
            +
                            foreground-background blending after image generation.
         
     | 
| 489 | 
         
            +
                    """
         
     | 
| 490 | 
         
            +
                    if isinstance(masks, Image.Image):
         
     | 
| 491 | 
         
            +
                        masks = [masks]
         
     | 
| 492 | 
         
            +
                    if isinstance(masks, (tuple, list)):
         
     | 
| 493 | 
         
            +
                        # Assumes white background for Image.Image;
         
     | 
| 494 | 
         
            +
                        # inverted boolean masks with shape (1, 1, H, W) for torch.Tensor.
         
     | 
| 495 | 
         
            +
                        if use_boolean_mask:
         
     | 
| 496 | 
         
            +
                            proc = lambda m: T.ToTensor()(m)[None, -1:] < 0.5
         
     | 
| 497 | 
         
            +
                        else:
         
     | 
| 498 | 
         
            +
                            proc = lambda m: 1.0 - T.ToTensor()(m)[None, -1:]
         
     | 
| 499 | 
         
            +
                        masks = torch.cat([proc(mask) for mask in masks], dim=0).float().clip_(0, 1)
         
     | 
| 500 | 
         
            +
                    masks = F.interpolate(masks.float(), size=(height, width), mode='bilinear', align_corners=False)
         
     | 
| 501 | 
         
            +
                    masks = masks.to(self.device)
         
     | 
| 502 | 
         
            +
             
     | 
| 503 | 
         
            +
                    # A mask's alpha is decayed by the specified factor where foreground masks with higher indices cover it.
         
     | 
| 504 | 
         
            +
                    if preprocess_mask_cover_alpha is None:
         
     | 
| 505 | 
         
            +
                        preprocess_mask_cover_alpha = self.default_preprocess_mask_cover_alpha
         
     | 
| 506 | 
         
            +
                    if preprocess_mask_cover_alpha > 0:
         
     | 
| 507 | 
         
            +
                        masks = torch.stack([
         
     | 
| 508 | 
         
            +
                            torch.where(
         
     | 
| 509 | 
         
            +
                                masks[i + 1:].sum(dim=0) > 0,
         
     | 
| 510 | 
         
            +
                                mask * preprocess_mask_cover_alpha,
         
     | 
| 511 | 
         
            +
                                mask,
         
     | 
| 512 | 
         
            +
                            ) if i < len(masks) - 1 else mask
         
     | 
| 513 | 
         
            +
                            for i, mask in enumerate(masks)
         
     | 
| 514 | 
         
            +
                        ], dim=0)
         
     | 
| 515 | 
         
            +
             
     | 
| 516 | 
         
            +
                    # Scheduler noise levels for mask quantization.
         
     | 
| 517 | 
         
            +
                    if timesteps is None:
         
     | 
| 518 | 
         
            +
                        noise_lvs = self.noise_lvs
         
     | 
| 519 | 
         
            +
                        next_noise_lvs = self.next_noise_lvs
         
     | 
| 520 | 
         
            +
                    else:
         
     | 
| 521 | 
         
            +
                        noise_lvs_ = self.sigmas * (self.sigmas**2 + 1)**(-0.5)
         
     | 
| 522 | 
         
            +
                        # noise_lvs_ = (1 - self.scheduler.alphas_cumprod[timesteps].to(self.device)) ** 0.5
         
     | 
| 523 | 
         
            +
                        noise_lvs = noise_lvs_[None, :, None, None, None].to(masks.device)
         
     | 
| 524 | 
         
            +
                        next_noise_lvs = torch.cat([noise_lvs_[1:], noise_lvs_.new_zeros(1)])[None, :, None, None, None]
         
     | 
| 525 | 
         
            +
             
     | 
| 526 | 
         
            +
                    # Mask preprocessing parameters are fetched from the default settings.
         
     | 
| 527 | 
         
            +
                    if std is None:
         
     | 
| 528 | 
         
            +
                        std = self.default_mask_std
         
     | 
| 529 | 
         
            +
                    if isinstance(std, (int, float)):
         
     | 
| 530 | 
         
            +
                        std = [std] * len(masks)
         
     | 
| 531 | 
         
            +
                    if isinstance(std, (list, tuple)):
         
     | 
| 532 | 
         
            +
                        std = torch.as_tensor(std, dtype=torch.float, device=self.device)
         
     | 
| 533 | 
         
            +
             
     | 
| 534 | 
         
            +
                    if strength is None:
         
     | 
| 535 | 
         
            +
                        strength = self.default_mask_strength
         
     | 
| 536 | 
         
            +
                    if isinstance(strength, (int, float)):
         
     | 
| 537 | 
         
            +
                        strength = [strength] * len(masks)
         
     | 
| 538 | 
         
            +
                    if isinstance(strength, (list, tuple)):
         
     | 
| 539 | 
         
            +
                        strength = torch.as_tensor(strength, dtype=torch.float, device=self.device)
         
     | 
| 540 | 
         
            +
             
     | 
| 541 | 
         
            +
                    if (std > 0).any():
         
     | 
| 542 | 
         
            +
                        std = torch.where(std > 0, std, 1e-5)
         
     | 
| 543 | 
         
            +
                        masks = gaussian_lowpass(masks, std)
         
     | 
| 544 | 
         
            +
                    masks_blurred = masks
         
     | 
| 545 | 
         
            +
             
     | 
| 546 | 
         
            +
                    # NOTE: This `strength` aligns with `denoising strength`. However, with LCM, using strength < 0.96
         
     | 
| 547 | 
         
            +
                    #       gives unpleasant results.
         
     | 
| 548 | 
         
            +
                    masks = masks * strength[:, None, None, None]
         
     | 
| 549 | 
         
            +
                    masks = masks.unsqueeze(1).repeat(1, noise_lvs.shape[1], 1, 1, 1)
         
     | 
| 550 | 
         
            +
             
     | 
| 551 | 
         
            +
                    # Mask is quantized according to the current noise levels specified by the scheduler.
         
     | 
| 552 | 
         
            +
                    if self.mask_type == 'discrete':
         
     | 
| 553 | 
         
            +
                        # Discrete mode.
         
     | 
| 554 | 
         
            +
                        masks = masks > noise_lvs
         
     | 
| 555 | 
         
            +
                    elif self.mask_type == 'semi-continuous':
         
     | 
| 556 | 
         
            +
                        # Semi-continuous mode (continuous at the last step only).
         
     | 
| 557 | 
         
            +
                        masks = torch.cat((
         
     | 
| 558 | 
         
            +
                            masks[:, :-1] > noise_lvs[:, :-1],
         
     | 
| 559 | 
         
            +
                            (
         
     | 
| 560 | 
         
            +
                                (masks[:, -1:] - next_noise_lvs[:, -1:]) / (noise_lvs[:, -1:] - next_noise_lvs[:, -1:])
         
     | 
| 561 | 
         
            +
                            ).clip_(0, 1),
         
     | 
| 562 | 
         
            +
                        ), dim=1)
         
     | 
| 563 | 
         
            +
                    elif self.mask_type == 'continuous':
         
     | 
| 564 | 
         
            +
                        # Continuous mode: Have the exact same `1` coverage with discrete mode, but the mask gradually
         
     | 
| 565 | 
         
            +
                        #                  decreases continuously after the discrete mode boundary to become `0` at the
         
     | 
| 566 | 
         
            +
                        #                  next lower threshold.
         
     | 
| 567 | 
         
            +
                        masks = ((masks - next_noise_lvs) / (noise_lvs - next_noise_lvs)).clip_(0, 1)
         
     | 
| 568 | 
         
            +
             
     | 
| 569 | 
         
            +
                    # NOTE: Post processing mask strength does not align with conventional 'denoising_strength'. However,
         
     | 
| 570 | 
         
            +
                    #       fine-grained mask alpha channel tuning is available with this form.
         
     | 
| 571 | 
         
            +
                    # masks = masks * strength[None, :, None, None, None]
         
     | 
| 572 | 
         
            +
             
     | 
| 573 | 
         
            +
                    h = height // self.vae_scale_factor
         
     | 
| 574 | 
         
            +
                    w = width // self.vae_scale_factor
         
     | 
| 575 | 
         
            +
                    masks = rearrange(masks.float(), 'p t () h w -> (p t) () h w')
         
     | 
| 576 | 
         
            +
                    masks = F.interpolate(masks, size=(h, w), mode='nearest')
         
     | 
| 577 | 
         
            +
                    masks = rearrange(masks.to(self.dtype), '(p t) () h w -> p t () h w', p=len(std))
         
     | 
| 578 | 
         
            +
                    return masks, masks_blurred, std
         
     | 
| 579 | 
         
            +
             
     | 
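                 # Illustrative sketch of the quantization above, not part of the original
                 # file: in 'discrete' mode, a blurred mask value in [0, 1] is compared
                 # against the per-step noise levels, so softer (smaller) values only turn
                 # the region on at later, less noisy steps.
                 #
                 # >>> noise_lvs = torch.tensor([0.98, 0.85, 0.58, 0.30])  # hypothetical schedule
                 # >>> m = torch.tensor(0.7)                                # one blurred mask pixel
                 # >>> m > noise_lvs
                 # tensor([False, False,  True,  True])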
| 580 | 
         
            +
                def scheduler_step(
         
     | 
| 581 | 
         
            +
                    self,
         
     | 
| 582 | 
         
            +
                    noise_pred: torch.Tensor,
         
     | 
| 583 | 
         
            +
                    idx: int,
         
     | 
| 584 | 
         
            +
                    latent: torch.Tensor,
         
     | 
| 585 | 
         
            +
                ) -> torch.Tensor:
         
     | 
| 586 | 
         
            +
                    r"""Denoise-only step for reverse diffusion scheduler.
         
     | 
| 587 | 
         
            +
                    
         
     | 
| 588 | 
         
            +
                    Designed to match the interface of the original `pipe.scheduler.step`,
         
     | 
| 589 | 
         
            +
                    which is a combination of this method and the following
         
     | 
| 590 | 
         
            +
                    `scheduler_add_noise`.
         
     | 
| 591 | 
         
            +
             
     | 
| 592 | 
         
            +
                    Args:
         
     | 
| 593 | 
         
            +
                        noise_pred (torch.Tensor): Noise prediction results from the U-Net.
         
     | 
| 594 | 
         
            +
                        idx (int): Instead of timesteps (in [0, 1000]-scale) use indices
         
     | 
| 595 | 
         
            +
                            for the timesteps tensor (ranged in [0, len(timesteps)-1]).
         
     | 
| 596 | 
         
            +
                        latent (torch.Tensor): Noisy latent.
         
     | 
| 597 | 
         
            +
             
     | 
| 598 | 
         
            +
                    Returns:
         
     | 
| 599 | 
         
            +
                        A denoised tensor with the same size as latent.
         
     | 
| 600 | 
         
            +
                    """
         
     | 
| 601 | 
         
            +
                    # Upcast to avoid precision issues when computing prev_sample.
         
     | 
| 602 | 
         
            +
                    latent = latent.to(torch.float32)
         
     | 
| 603 | 
         
            +
                    prev_sample = latent - noise_pred * self.sigmas[idx]
         
     | 
| 604 | 
         
            +
                    return prev_sample.to(self.dtype)
         
     | 
| 605 | 
         
            +
             
     | 
| 606 | 
         
            +
                def scheduler_add_noise(
         
     | 
| 607 | 
         
            +
                    self,
         
     | 
| 608 | 
         
            +
                    latent: torch.Tensor,
         
     | 
| 609 | 
         
            +
                    noise: Optional[torch.Tensor],
         
     | 
| 610 | 
         
            +
                    idx: int,
         
     | 
| 611 | 
         
            +
                ) -> torch.Tensor:
         
     | 
| 612 | 
         
            +
                    r"""Separated noise-add step for the reverse diffusion scheduler.
         
     | 
| 613 | 
         
            +
                    
         
     | 
| 614 | 
         
            +
                    Designed to match the interface of the original
         
     | 
| 615 | 
         
            +
                    `pipe.scheduler.add_noise`.
         
     | 
| 616 | 
         
            +
             
     | 
| 617 | 
         
            +
                    Args:
         
     | 
| 618 | 
         
            +
                        latent (torch.Tensor): Denoised latent.
         
     | 
| 619 | 
         
            +
                        noise (torch.Tensor): Added noise. Can be None. If None, a random
         
     | 
| 620 | 
         
            +
                            noise is newly sampled for addition.
         
     | 
| 621 | 
         
            +
                        idx (int): Instead of timesteps (in [0, 1000]-scale) use indices
         
     | 
| 622 | 
         
            +
                            for the timesteps tensor (ranged in [0, len(timesteps)-1]).
         
     | 
| 623 | 
         
            +
             
     | 
| 624 | 
         
            +
                    Returns:
         
     | 
| 625 | 
         
            +
                        A noisy tensor with the same size as latent.
         
     | 
| 626 | 
         
            +
                    """
         
     | 
| 627 | 
         
            +
                    if idx < len(self.sigmas) and idx >= 0:
         
     | 
| 628 | 
         
            +
                        noise = torch.randn_like(latent) if noise is None else noise
         
     | 
| 629 | 
         
            +
                        return (1.0 - self.sigmas[idx]) * latent + self.sigmas[idx] * noise
         
     | 
| 630 | 
         
            +
                    else:
         
     | 
| 631 | 
         
            +
                        return latent
         
     | 
| 632 | 
         
            +
             
     | 
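                 # Illustrative sketch, not part of the original file: under flow matching,
                 # x_t = (1 - sigma) * x0 + sigma * eps and the model predicts v ~ eps - x0.
                 # `scheduler_step` therefore recovers x0 = x_t - sigma * v, and
                 # `scheduler_add_noise` re-noises it to the next level. With the same eps,
                 # the two calls compose into the usual Euler update:
                 #
                 # >>> sigma, sigma_next = 0.8, 0.5
                 # >>> x0, eps = torch.randn(2, 4), torch.randn(2, 4)
                 # >>> x_t, v = (1 - sigma) * x0 + sigma * eps, eps - x0
                 # >>> torch.allclose(x_t + (sigma_next - sigma) * v,
                 # ...                (1 - sigma_next) * x0 + sigma_next * eps)
                 # True
                 #
                 # Sampling a fresh eps instead is what lets the pipeline blend per-region
                 # latents before re-noising them for the next step.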
| 633 | 
         
            +
                @torch.no_grad()
         
     | 
| 634 | 
         
            +
                def __call__(
         
     | 
| 635 | 
         
            +
                    self,
         
     | 
| 636 | 
         
            +
                    prompts: Optional[Union[str, List[str]]] = None,
         
     | 
| 637 | 
         
            +
                    negative_prompts: Union[str, List[str]] = '',
         
     | 
| 638 | 
         
            +
                    suffix: Optional[str] = None, #', background is ',
         
     | 
| 639 | 
         
            +
                    background: Optional[Union[torch.Tensor, Image.Image]] = None,
         
     | 
| 640 | 
         
            +
                    background_prompt: Optional[str] = None,
         
     | 
| 641 | 
         
            +
                    background_negative_prompt: str = '',
         
     | 
| 642 | 
         
            +
                    height: int = 1024,
         
     | 
| 643 | 
         
            +
                    width: int = 1024,
         
     | 
| 644 | 
         
            +
                    num_inference_steps: Optional[int] = None,
         
     | 
| 645 | 
         
            +
                    guidance_scale: Optional[float] = None,
         
     | 
| 646 | 
         
            +
                    prompt_strengths: Optional[Union[torch.Tensor, float, List[float]]] = None,
         
     | 
| 647 | 
         
            +
                    masks: Optional[Union[Image.Image, List[Image.Image]]] = None,
         
     | 
| 648 | 
         
            +
                    mask_strengths: Optional[Union[torch.Tensor, float, List[float]]] = None,
         
     | 
| 649 | 
         
            +
                    mask_stds: Optional[Union[torch.Tensor, float, List[float]]] = None,
         
     | 
| 650 | 
         
            +
                    use_boolean_mask: bool = True,
         
     | 
| 651 | 
         
            +
                    do_blend: bool = True,
         
     | 
| 652 | 
         
            +
                    tile_size: int = 1024,
         
     | 
| 653 | 
         
            +
                    bootstrap_steps: Optional[int] = None,
         
     | 
| 654 | 
         
            +
                    boostrap_mix_steps: Optional[float] = None,
         
     | 
| 655 | 
         
            +
                    bootstrap_leak_sensitivity: Optional[float] = None,
         
     | 
| 656 | 
         
            +
                    preprocess_mask_cover_alpha: Optional[float] = None,
         
     | 
| 657 | 
         
            +
                    # SDXL Pipeline setting.
         
     | 
| 658 | 
         
            +
                    guidance_rescale: float = 0.7,
         
     | 
| 659 | 
         
            +
                    output_type = 'pil',
         
     | 
| 660 | 
         
            +
                    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         
     | 
| 661 | 
         
            +
                    clip_skip: Optional[int] = None,
         
     | 
| 662 | 
         
            +
                ) -> Image.Image:
         
     | 
| 663 | 
         
            +
                    r"""Arbitrary-size image generation from multiple pairs of (regional)
         
     | 
| 664 | 
         
            +
                    text prompt-mask pairs.
         
     | 
| 665 | 
         
            +
             
     | 
| 666 | 
         
            +
                    This is the main routine of this pipeline.
         
     | 
| 667 | 
         
            +
             
     | 
| 668 | 
         
            +
                    Example:
         
     | 
| 669 | 
         
            +
                        >>> device = torch.device('cuda:0')
         
     | 
| 670 | 
         
            +
                        >>> smd = StableMultiDiffusionPipeline(device)
         
     | 
| 671 | 
         
            +
                        >>> prompts = {... specify prompts}
         
     | 
| 672 | 
         
            +
                        >>> masks = {... specify mask tensors}
         
     | 
| 673 | 
         
            +
                        >>> height, width = masks.shape[-2:]
         
     | 
| 674 | 
         
            +
                        >>> image = smd(
         
     | 
| 675 | 
         
            +
                        >>>     prompts, masks=masks.float(), height=height, width=width)
         
     | 
| 676 | 
         
            +
                        >>> image.save('my_beautiful_creation.png')
         
     | 
| 677 | 
         
            +
             
     | 
| 678 | 
         
            +
                    Args:
         
     | 
| 679 | 
         
            +
                        prompts (Union[str, List[str]]): A text prompt.
         
     | 
| 680 | 
         
            +
                        negative_prompts (Union[str, List[str]]): A negative text prompt.
         
     | 
| 681 | 
         
            +
                        suffix (Optional[str]): One option for blending foreground prompts
         
     | 
| 682 | 
         
            +
                            with background prompts by simply appending background prompt
         
     | 
| 683 | 
         
            +
                            to the end of each foreground prompt with this `middle word` in
         
     | 
| 684 | 
         
            +
                            between. For example, if you set this as `, background is`,
         
     | 
| 685 | 
         
            +
                            then the foreground prompt will be changed into
         
     | 
| 686 | 
         
            +
                            `(fg), background is (bg)` before conditional generation.
         
     | 
| 687 | 
         
            +
                        background (Optional[Union[torch.Tensor, Image.Image]]): a
         
     | 
| 688 | 
         
            +
                            background image, if the user wants to draw in front of the
         
     | 
| 689 | 
         
            +
                            specified image. The background prompt will be automatically generated
         
     | 
| 690 | 
         
            +
                            with a BLIP-2 model.
         
     | 
| 691 | 
         
            +
                        background_prompt (Optional[str]): The background prompt is used
         
     | 
| 692 | 
         
            +
                            for preprocessing foreground prompt embeddings to blend
         
     | 
| 693 | 
         
            +
                            foreground and background.
         
     | 
| 694 | 
         
            +
                        background_negative_prompt (Optional[str]): The negative background
         
     | 
| 695 | 
         
            +
                            prompt.
         
     | 
| 696 | 
         
            +
                        height (int): Height of a generated image. It is tiled if larger
         
     | 
| 697 | 
         
            +
                            than `tile_size`.
         
     | 
| 698 | 
         
            +
                        width (int): Width of a generated image. It is tiled if larger
         
     | 
| 699 | 
         
            +
                            than `tile_size`.
         
     | 
| 700 | 
         
            +
                        num_inference_steps (Optional[int]): Number of inference steps.
         
     | 
| 701 | 
         
            +
                            Default inference scheduling is used if none is specified.
         
     | 
| 702 | 
         
            +
                        guidance_scale (Optional[float]): Classifier guidance scale.
         
     | 
| 703 | 
         
            +
                            Default value is used if none is specified.
         
     | 
| 704 | 
         
            +
                        prompt_strengths (float): Overrides the default value. Preprocesses
         
     | 
| 705 | 
         
            +
                            foreground prompts globally by linearly interpolating its
         
     | 
| 706 | 
         
            +
                            embedding with the background prompt embedding at the specified
         
     | 
| 707 | 
         
            +
                            mix ratio. Useful control handle for foreground blending.
         
     | 
| 708 | 
         
            +
                            Recommended range: 0.5-1.
         
     | 
| 709 | 
         
            +
                        masks (Optional[Union[Image.Image, List[Image.Image]]]): a list of
         
     | 
| 710 | 
         
            +
                            mask images. Each mask associates with each of the text prompts
         
     | 
| 711 | 
         
            +
                            and each of the negative prompts. If specified as an image, it
         
     | 
| 712 | 
         
            +
                            regards the image as a boolean mask. Also accepts torch.Tensor
         
     | 
| 713 | 
         
            +
                            masks, which can have nonbinary values for fine-grained
         
     | 
| 714 | 
         
            +
                            controls in mixing regional generations.
         
     | 
| 715 | 
         
            +
                        mask_strengths (Optional[Union[torch.Tensor, float, List[float]]]):
         
     | 
| 716 | 
         
            +
                            Overrides the default value. Can be assigned for each mask
         
     | 
| 717 | 
         
            +
                            separately. Preprocess mask by multiplying it globally with the
         
     | 
| 718 | 
         
            +
                            specified variable. Caution: extremely sensitive. Recommended
         
     | 
| 719 | 
         
            +
                            range: 0.98-1.
         
     | 
| 720 | 
         
            +
                        mask_stds (Optional[Union[torch.Tensor, float, List[float]]]):
         
     | 
| 721 | 
         
            +
                            Overrides the default value. Can be assigned for each mask
         
     | 
| 722 | 
         
            +
                            separately. Preprocess mask with Gaussian blur with specified
         
     | 
| 723 | 
         
            +
                            standard deviation. Recommended range: 0-64.
         
     | 
| 724 | 
         
            +
                        use_boolean_mask (bool): Turn this off if you want to treat the
         
     | 
| 725 | 
         
            +
                            mask image as a nonbinary one. The module will use the last
         
     | 
| 726 | 
         
            +
                            channel of the given image in `masks` as the mask value.
         
     | 
| 727 | 
         
            +
                        do_blend (bool): Blend the generated foreground and the optionally
         
     | 
| 728 | 
         
            +
                            predefined background by smooth boundary obtained from Gaussian
         
     | 
| 729 | 
         
            +
                            blurs of the foreground `masks` with the given `mask_stds`.
         
     | 
| 730 | 
         
            +
                        tile_size (Optional[int]): Tile size of the panorama generation.
         
     | 
| 731 | 
         
            +
                            Works best with the default training size of the Stable-
         
     | 
| 732 | 
         
            +
                            Diffusion model, i.e., 1024 for SD3.
         
     | 
| 733 | 
         
            +
                        bootstrap_steps (int): Overrides the default value. Bootstrapping
         
     | 
| 734 | 
         
            +
                            stage steps to encourage region separation. Recommended range:
         
     | 
| 735 | 
         
            +
                            1-3.
         
     | 
| 736 | 
         
            +
                        boostrap_mix_steps (float): Overrides the default value.
         
     | 
| 737 | 
         
            +
                            Bootstrapping background is a linear interpolation between
         
     | 
| 738 | 
         
            +
                            background latent and the white image latent. This handle
         
     | 
| 739 | 
         
            +
                            controls the mix ratio. Available range: 0-(number of
         
     | 
| 740 | 
         
            +
                            bootstrapping inference steps). For example, 2.3 means that for
         
     | 
| 741 | 
         
            +
                            the first two steps, white image is used as a bootstrapping
         
     | 
| 742 | 
         
            +
                            background and in the third step, mixture of white (0.3) and
         
     | 
| 743 | 
         
            +
                            registered background (0.7) is used as a bootstrapping
         
     | 
| 744 | 
         
            +
                            background.
         
     | 
| 745 | 
         
            +
                        bootstrap_leak_sensitivity (float): Overrides the default value.
         
     | 
| 746 | 
         
            +
                            Postprocessing at each inference step by masking away the
         
     | 
| 747 | 
         
            +
                            remaining bootstrap backgrounds. Recommended range: 0-1.
         
     | 
| 748 | 
         
            +
                        preprocess_mask_cover_alpha (float): Overrides the default value.
         
     | 
| 749 | 
         
            +
                            Optional preprocessing where each mask covered by other masks
         
     | 
| 750 | 
         
            +
                            is reduced in its alpha value by this specified factor.
         
     | 
| 751 | 
         
            +
             
     | 
| 752 | 
         
            +
                    Returns: A PIL.Image of the generated panorama (large-size) image.
         
     | 
| 753 | 
         
            +
                    """
         
     | 
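                     # Illustrative usage sketch with hypothetical file names, not part of
                     # the original file: draw two regions on top of a user-provided
                     # background image.
                     #
                     # >>> background = Image.open('beach.png')
                     # >>> masks = [Image.open('mask_castle.png'), Image.open('mask_dog.png')]
                     # >>> image = smd(
                     # >>>     prompts=['a sandcastle', 'a corgi wearing sunglasses'],
                     # >>>     negative_prompts='blurry, low quality',
                     # >>>     background=background,
                     # >>>     masks=masks,
                     # >>>     height=1024, width=1024,
                     # >>> )
                     # >>> image.save('beach_scene.png')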
| 754 | 
         
            +
             
     | 
| 755 | 
         
            +
                    ### Simplest cases
         
     | 
| 756 | 
         
            +
             
     | 
| 757 | 
         
            +
                    # prompts is None: return background.
         
     | 
| 758 | 
         
            +
                    # masks is None but prompts is not None: return prompts
         
     | 
| 759 | 
         
            +
                    # masks is not None and prompts is not None: Do StableMultiDiffusion.
         
     | 
| 760 | 
         
            +
             
     | 
| 761 | 
         
            +
                    if prompts is None or (isinstance(prompts, (list, tuple, str)) and len(prompts) == 0):
         
     | 
| 762 | 
         
            +
                        if background is None and background_prompt is not None:
         
     | 
| 763 | 
         
            +
                            return self.sample(background_prompt, background_negative_prompt, height, width, num_inference_steps, guidance_scale)
         
     | 
| 764 | 
         
            +
                        return background
         
     | 
| 765 | 
         
            +
                    elif masks is None or (isinstance(masks, (list, tuple)) and len(masks) == 0):
         
     | 
| 766 | 
         
            +
                        return self.sample(prompts, negative_prompts, height, width, num_inference_steps, guidance_scale)
         
     | 
| 767 | 
         
            +
             
     | 
| 768 | 
         
            +
             
     | 
| 769 | 
         
            +
                    ### Prepare generation
         
     | 
| 770 | 
         
            +
             
     | 
| 771 | 
         
            +
                    if num_inference_steps is not None:
         
     | 
| 772 | 
         
            +
                        self.prepare_flashflowmatch_schedule(list(range(num_inference_steps)), num_inference_steps)
         
     | 
| 773 | 
         
            +
             
     | 
| 774 | 
         
            +
                    if guidance_scale is None:
         
     | 
| 775 | 
         
            +
                        guidance_scale = self.default_guidance_scale
         
     | 
| 776 | 
         
            +
                    self.pipe._guidance_scale = guidance_scale
         
     | 
| 777 | 
         
            +
        self.pipe._clip_skip = clip_skip
        self.pipe._joint_attention_kwargs = joint_attention_kwargs
        self.pipe._interrupt = False
        do_classifier_free_guidance = guidance_scale > 1.0


        ### Prompts & Masks

        # asserts #m > 0 and #p > 0.
        # #m == #p == #n > 0: We happily generate according to the prompts & masks.
        # #m != #p: #p should be 1 and we will broadcast text embeds of p through m masks.
        # #p != #n: #n should be 1 and we will broadcast negative embeds n through p prompts.

        if isinstance(masks, Image.Image):
            masks = [masks]
        if isinstance(prompts, str):
            prompts = [prompts]
        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]
        num_masks = len(masks)
        num_prompts = len(prompts)
        num_nprompts = len(negative_prompts)
        assert num_prompts in (num_masks, 1), \
            f'The number of prompts {num_prompts} should match the number of masks {num_masks}!'
        assert num_nprompts in (num_prompts, 1), \
            f'The number of negative prompts {num_nprompts} should match the number of prompts {num_prompts}!'

        fg_masks, masks_g, std = self.process_mask(
            masks,
            mask_strengths,
            mask_stds,
            height=height,
            width=width,
            use_boolean_mask=use_boolean_mask,
            timesteps=self.timesteps,
            preprocess_mask_cover_alpha=preprocess_mask_cover_alpha,
        )  # (p, t, 1, H, W)
        bg_masks = (1 - fg_masks.sum(dim=0)).clip_(0, 1)  # (T, 1, h, w)
        has_background = bg_masks.sum() > 0

        h = (height + self.vae_scale_factor - 1) // self.vae_scale_factor
        w = (width + self.vae_scale_factor - 1) // self.vae_scale_factor


        ### Background

        # background == None && background_prompt == None: Initialize with white background.
        # background == None && background_prompt != None: Generate background *along with other prompts*.
        # background != None && background_prompt == None: Retrieve text prompt using BLIP.
        # background != None && background_prompt != None: Use the given arguments.

        # not has_background: no effect of prompt_strength (the mix ratio between fg prompt & bg prompt)
        # has_background && prompt_strength != 1: mix only for this case.

        bg_latent = None
        if has_background:
            if background is None and background_prompt is not None:
                fg_masks = torch.cat((bg_masks[None], fg_masks), dim=0)
                if suffix is not None:
                    prompts = [p + suffix + background_prompt for p in prompts]
                prompts = [background_prompt] + prompts
                negative_prompts = [background_negative_prompt] + negative_prompts
                has_background = False  # Regard that background does not exist.
            else:
                if background is None and background_prompt is None:
                    background = torch.ones(1, 3, height, width, dtype=self.dtype, device=self.device)
                    background_prompt = 'simple white background image'
                elif background is not None and background_prompt is None:
                    background_prompt = self.get_text_prompts(background)
                if suffix is not None:
                    prompts = [p + suffix + background_prompt for p in prompts]
                prompts = [background_prompt] + prompts
                negative_prompts = [background_negative_prompt] + negative_prompts
                if isinstance(background, Image.Image):
                    background = T.ToTensor()(background).to(dtype=self.dtype, device=self.device)[None]
                background = F.interpolate(background, size=(height, width), mode='bicubic', align_corners=False)
                bg_latent = self.encode_imgs(background)

        # Bootstrapping stage preparation.

        if bootstrap_steps is None:
            bootstrap_steps = self.default_bootstrap_steps
        if boostrap_mix_steps is None:
            boostrap_mix_steps = self.default_boostrap_mix_steps
        if bootstrap_leak_sensitivity is None:
            bootstrap_leak_sensitivity = self.default_bootstrap_leak_sensitivity
        if bootstrap_steps > 0:
            height_ = min(height, tile_size)
            width_ = min(width, tile_size)
            white = self.get_white_background(height, width)  # (1, 4, h, w)


        ### Prepare text embeddings (optimized for the minimal encoder batch size)

        # SD3 pipeline settings.
        batch_size = 1
        num_images_per_prompt = 1

        original_size = (height, width)
        target_size = (height, width)
        crops_coords_top_left = (0, 0)
        negative_original_size = None
        negative_target_size = None
        negative_crops_coords_top_left = (0, 0)

        prompt_2 = None
        prompt_3 = None
        negative_prompt_2 = None
        negative_prompt_3 = None
        prompt_embeds = None
        negative_prompt_embeds = None
        pooled_prompt_embeds = None
        negative_pooled_prompt_embeds = None
        text_encoder_lora_scale = None

        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.pipe.encode_prompt(
            prompt=prompts,
            prompt_2=prompt_2,
            prompt_3=prompt_3,
            negative_prompt=negative_prompts,
            negative_prompt_2=negative_prompt_2,
            negative_prompt_3=negative_prompt_3,
            do_classifier_free_guidance=do_classifier_free_guidance,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            device=self.device,
            clip_skip=self.pipe.clip_skip,
            num_images_per_prompt=num_images_per_prompt,
        )

        if has_background:
            # First channel is background prompt text embeds. Background prompt itself is not used for generation.
            s = prompt_strengths
            if prompt_strengths is None:
                s = self.default_prompt_strength
            if isinstance(s, (int, float)):
                s = [s] * num_prompts
            if isinstance(s, (list, tuple)):
                assert len(s) == num_prompts, \
                    f'The number of prompt strengths {len(s)} should match the number of prompts {num_prompts}!'
                s = torch.as_tensor(s, dtype=self.dtype, device=self.device)
            s = s[:, None, None]

            be = prompt_embeds[:1]
            fe = prompt_embeds[1:]
            prompt_embeds = torch.lerp(be, fe, s)  # (p, 77, 1024)

            if negative_prompt_embeds is not None:
                bu = negative_prompt_embeds[:1]
                fu = negative_prompt_embeds[1:]
                if num_prompts > num_nprompts:
                    # #negative prompts == 1; #prompts > 1.
                    assert fu.shape[0] == 1 and fe.shape[0] == num_prompts
                    fu = fu.repeat(num_prompts, 1, 1)
                negative_prompt_embeds = torch.lerp(bu, fu, s)  # (n, 77, 1024)

            be = pooled_prompt_embeds[:1]
            fe = pooled_prompt_embeds[1:]
            pooled_prompt_embeds = torch.lerp(be, fe, s[..., 0])  # (p, 1280)

            if negative_pooled_prompt_embeds is not None:
                bu = negative_pooled_prompt_embeds[:1]
                fu = negative_pooled_prompt_embeds[1:]
                if num_prompts > num_nprompts:
                    # #negative prompts == 1; #prompts > 1.
                    assert fu.shape[0] == 1 and fe.shape[0] == num_prompts
                    fu = fu.repeat(num_prompts, 1)
                negative_pooled_prompt_embeds = torch.lerp(bu, fu, s[..., 0])  # (n, 1280)
        elif negative_prompt_embeds is not None and num_prompts > num_nprompts:
            # #negative prompts == 1; #prompts > 1.
            assert negative_prompt_embeds.shape[0] == 1 and prompt_embeds.shape[0] == num_prompts
            negative_prompt_embeds = negative_prompt_embeds.repeat(num_prompts, 1, 1)

            assert negative_pooled_prompt_embeds.shape[0] == 1 and pooled_prompt_embeds.shape[0] == num_prompts
            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(num_prompts, 1)
        # assert negative_prompt_embeds.shape[0] == prompt_embeds.shape[0] == num_prompts
        if num_masks > num_prompts:
            assert masks.shape[0] == num_masks and num_prompts == 1
            prompt_embeds = prompt_embeds.repeat(num_masks, 1, 1)
            if negative_prompt_embeds is not None:
                negative_prompt_embeds = negative_prompt_embeds.repeat(num_masks, 1, 1)

            pooled_prompt_embeds = pooled_prompt_embeds.repeat(num_masks, 1)
            if negative_pooled_prompt_embeds is not None:
                negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(num_masks, 1)

        # SD3 pipeline settings.
        if do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
        del negative_prompt_embeds, negative_pooled_prompt_embeds

        prompt_embeds = prompt_embeds.to(self.device)
        pooled_prompt_embeds = pooled_prompt_embeds.to(self.device)


        ### Run

        # Latent initialization.
        num_channels_latents = self.transformer.config.in_channels
        noise = torch.randn((1, num_channels_latents, h, w), dtype=self.dtype, device=self.device)
        if self.timesteps[0] < 999 and has_background:
            latent = self.scheduler_add_noise(bg_latent, noise, 0)
        else:
            noise = torch.randn((1, num_channels_latents, h, w), dtype=self.dtype, device=self.device)
            latent = noise

        if has_background:
            noise_bg_latents = [
                self.scheduler_add_noise(bg_latent, noise, i) for i in range(len(self.timesteps))
            ] + [bg_latent]

        # Tiling (if needed).
        if height > tile_size or width > tile_size:
            t = (tile_size + self.vae_scale_factor - 1) // self.vae_scale_factor
            views, tile_masks = get_panorama_views(h, w, t)
            tile_masks = tile_masks.to(self.device)
        else:
            views = [(0, h, 0, w)]
            tile_masks = latent.new_ones((1, 1, h, w))
        value = torch.zeros_like(latent)
        count_all = torch.zeros_like(latent)

        with torch.autocast('cuda'):
            for i, t in enumerate(tqdm(self.timesteps)):
                if self.pipe.interrupt:
                    continue

                fg_mask = fg_masks[:, i]
                bg_mask = bg_masks[i:i + 1]

                value.zero_()
                count_all.zero_()
                for j, (h_start, h_end, w_start, w_end) in enumerate(views):
                    fg_mask_ = fg_mask[..., h_start:h_end, w_start:w_end]
                    latent_ = latent[..., h_start:h_end, w_start:w_end].repeat(num_masks, 1, 1, 1)

                    # Bootstrap for tight background.
                    if i < bootstrap_steps:
                        mix_ratio = min(1, max(0, boostrap_mix_steps - i))
                        # Treat the first foreground latent as the background latent if one does not exist.
                        bg_latent_ = noise_bg_latents[i][..., h_start:h_end, w_start:w_end] if has_background else latent_[:1]
                        white_ = white[..., h_start:h_end, w_start:w_end]
                        white_ = self.scheduler_add_noise(white_, noise[..., h_start:h_end, w_start:w_end], i)
                        bg_latent_ = mix_ratio * white_ + (1.0 - mix_ratio) * bg_latent_
                        latent_ = (1.0 - fg_mask_) * bg_latent_ + fg_mask_ * latent_

                        # Centering.
                        latent_ = shift_to_mask_bbox_center(latent_, fg_mask_, reverse=True)

                    # Expand the latents if we are doing classifier-free guidance.
                    latent_model_input = torch.cat([latent_] * 2) if do_classifier_free_guidance else latent_
                    # Broadcast to batch dimension in a way that's compatible with ONNX/Core ML.
                    timestep = t.expand(latent_model_input.shape[0])

                    # Perform one step of the reverse diffusion.
                    noise_pred = self.transformer(
                        hidden_states=latent_model_input,
                        timestep=timestep,
                        encoder_hidden_states=prompt_embeds,
                        pooled_projections=pooled_prompt_embeds,
                        joint_attention_kwargs=joint_attention_kwargs,
                        return_dict=False,
                    )[0]

                    if do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

                    if do_classifier_free_guidance and guidance_rescale > 0.0:
                        # Based on Section 3.4 of https://arxiv.org/pdf/2305.08891.pdf
                        noise_pred = rescale_noise_cfg(noise_pred, noise_pred_cond, guidance_rescale=guidance_rescale)

                    latent_ = self.scheduler_step(noise_pred, i, latent_)

                    if i < bootstrap_steps:
                        # Uncentering.
                        latent_ = shift_to_mask_bbox_center(latent_, fg_mask_)

                        # Remove leakage (optional).
                        leak = (latent_ - bg_latent_).pow(2).mean(dim=1, keepdim=True)
                        leak_sigmoid = torch.sigmoid(leak / bootstrap_leak_sensitivity) * 2 - 1
                        fg_mask_ = fg_mask_ * leak_sigmoid

                    # Mix the latents.
                    fg_mask_ = fg_mask_ * tile_masks[:, j:j + 1, h_start:h_end, w_start:w_end]
                    value[..., h_start:h_end, w_start:w_end] += (fg_mask_ * latent_).sum(dim=0, keepdim=True)
                    count_all[..., h_start:h_end, w_start:w_end] += fg_mask_.sum(dim=0, keepdim=True)

                latent = torch.where(count_all > 0, value / count_all, value)
                bg_mask = (1 - count_all).clip_(0, 1)  # (T, 1, h, w)
                if has_background:
                    latent = (1 - bg_mask) * latent + bg_mask * noise_bg_latents[i + 1]  # bg_latent

                # Noise is added after mixing.
                if i < len(self.timesteps) - 1:
                    latent = self.scheduler_add_noise(latent, None, i + 1)

        if not output_type == "latent":
            latent = (latent / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latent, return_dict=False)[0]
        else:
            image = latent

        # Return PIL Image.
        image = image[0].clip_(-1, 1) * 0.5 + 0.5
        if has_background and do_blend:
            fg_mask = torch.sum(masks_g, dim=0).clip_(0, 1)
            image = blend(image, background[0], fg_mask)
        else:
            image = T.ToPILImage()(image)
        return image
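For orientation, the call tail above composes one latent per prompt under its mask at every denoising step and merges them into a single canvas. The snippet below is only a minimal usage sketch: the instance name `model`, the example mask files, and the numeric settings are assumptions for illustration, while the keyword arguments (prompts, negative_prompts, masks, background_prompt, height, width, guidance_scale, bootstrap_steps) are the ones consumed by the code shown in this hunk.

from PIL import Image

# `model` is assumed to be an instance of the pipeline wrapper defined earlier
# in model.py (its constructor is outside this hunk).
mask_sky = Image.open('examples/mask_sky.png')        # hypothetical mask files
mask_castle = Image.open('examples/mask_castle.png')

image = model(
    prompts=['clear blue sky', 'a grand medieval castle'],
    negative_prompts='worst quality, low quality',     # single negative, broadcast over both prompts
    masks=[mask_sky, mask_castle],                      # one mask per prompt
    background_prompt='fantasy landscape, matte painting',
    height=1024,
    width=1024,
    guidance_scale=7.0,                                 # > 1.0 enables classifier-free guidance
    bootstrap_steps=2,                                  # tight-background bootstrapping steps
)
image.save('semantic_palette_result.png')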
    	
prompt_util.py
ADDED

@@ -0,0 +1,154 @@

from typing import Dict, List, Tuple, Union


quality_prompt_list = [
    {
        "name": "(None)",
        "prompt": "{prompt}",
        "negative_prompt": "nsfw, lowres",
    },
    {
        "name": "Standard v3.0",
        "prompt": "{prompt}, masterpiece, best quality",
        "negative_prompt": "nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name",
    },
    {
        "name": "Standard v3.1",
        "prompt": "{prompt}, masterpiece, best quality, very aesthetic, absurdres",
        "negative_prompt": "nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
    },
    {
        "name": "Light v3.1",
        "prompt": "{prompt}, (masterpiece), best quality, very aesthetic, perfect face",
        "negative_prompt": "nsfw, (low quality, worst quality:1.2), very displeasing, 3d, watermark, signature, ugly, poorly drawn",
    },
    {
        "name": "Heavy v3.1",
        "prompt": "{prompt}, (masterpiece), (best quality), (ultra-detailed), very aesthetic, illustration, disheveled hair, perfect composition, moist skin, intricate details",
        "negative_prompt": "nsfw, longbody, lowres, bad anatomy, bad hands, missing fingers, pubic hair, extra digit, fewer digits, cropped, worst quality, low quality, very displeasing",
    },
]

style_list = [
    {
        "name": "(None)",
        "prompt": "{prompt}",
        "negative_prompt": "",
    },
    {
        "name": "Cinematic",
        "prompt": "{prompt}, cinematic still, emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
        "negative_prompt": "nsfw, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
    },
    {
        "name": "Photographic",
        "prompt": "{prompt}, cinematic photo, 35mm photograph, film, bokeh, professional, 4k, highly detailed",
        "negative_prompt": "nsfw, drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
    },
    {
        "name": "Anime",
        "prompt": "{prompt}, anime artwork, anime style, key visual, vibrant, studio anime, highly detailed",
        "negative_prompt": "nsfw, photo, deformed, black and white, realism, disfigured, low contrast",
    },
    {
        "name": "Manga",
        "prompt": "{prompt}, manga style, vibrant, high-energy, detailed, iconic, Japanese comic style",
        "negative_prompt": "nsfw, ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
    },
    {
        "name": "Digital Art",
        "prompt": "{prompt}, concept art, digital artwork, illustrative, painterly, matte painting, highly detailed",
        "negative_prompt": "nsfw, photo, photorealistic, realism, ugly",
    },
    {
        "name": "Pixel art",
        "prompt": "{prompt}, pixel-art, low-res, blocky, pixel art style, 8-bit graphics",
        "negative_prompt": "nsfw, sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
    },
    {
        "name": "Fantasy art",
        "prompt": "{prompt}, ethereal fantasy concept art, magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
        "negative_prompt": "nsfw, photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white",
    },
    {
        "name": "Neonpunk",
        "prompt": "{prompt}, neonpunk style, cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
        "negative_prompt": "nsfw, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
    },
    {
        "name": "3D Model",
        "prompt": "{prompt}, professional 3d model, octane render, highly detailed, volumetric, dramatic lighting",
        "negative_prompt": "nsfw, ugly, deformed, noisy, low poly, blurry, painting",
    },
]


_style_dict = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
_quality_dict = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in quality_prompt_list}


def preprocess_prompt(
    positive: str,
    negative: str = "",
    style_dict: Dict[str, Tuple[str, str]] = _quality_dict,
    style_name: str = "Standard v3.1",  # "Heavy v3.1"
    add_style: bool = True,
) -> Tuple[str, str]:
    p, n = style_dict.get(style_name, style_dict["(None)"])

    if add_style and positive.strip():
        formatted_positive = p.format(prompt=positive)
    else:
        formatted_positive = positive

    combined_negative = n
    if negative.strip():
        if combined_negative:
            combined_negative += ", " + negative
        else:
            combined_negative = negative

    return formatted_positive, combined_negative


def preprocess_prompts(
    positives: List[str],
    negatives: List[str] = None,
    style_dict = _style_dict,
    style_name: str = "Manga",  # "(None)"
    quality_dict = _quality_dict,
    quality_name: str = "Standard v3.1",  # "Heavy v3.1"
    add_style: bool = True,
    add_quality_tags = True,
) -> Tuple[List[str], List[str]]:
    if negatives is None:
        negatives = ['' for _ in positives]

    positives_ = []
    negatives_ = []
    for pos, neg in zip(positives, negatives):
        pos, neg = preprocess_prompt(pos, neg, quality_dict, quality_name, add_quality_tags)
        pos, neg = preprocess_prompt(pos, neg, style_dict, style_name, add_style)
        positives_.append(pos)
        negatives_.append(neg)
    return positives_, negatives_


def print_prompts(
    positives: Union[str, List[str]],
    negatives: Union[str, List[str]],
    has_background: bool = False,
) -> None:
    if isinstance(positives, str):
        positives = [positives]
    if isinstance(negatives, str):
        negatives = [negatives]

    for i, prompt in enumerate(positives):
        prefix = ((f'Prompt{i}' if i > 0 else 'Background Prompt')
                  if has_background else f'Prompt{i + 1}')
        print(prefix + ': ' + prompt)
    for i, prompt in enumerate(negatives):
        prefix = ((f'Negative Prompt{i}' if i > 0 else 'Background Negative Prompt')
                  if has_background else f'Negative Prompt{i + 1}')
        print(prefix + ': ' + prompt)
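A short usage sketch of the helpers above; the prompt strings are illustrative, and the style and quality names come from the lists defined in this file.

from prompt_util import preprocess_prompts, print_prompts

positives = ['1girl looking at viewer, cherry blossoms', 'night sky with stars']
negatives = ['bad hands', '']

# Each prompt is first wrapped with the "Standard v3.1" quality template and then
# the "Manga" style template; per-prompt negatives are appended to the template negatives.
pos, neg = preprocess_prompts(
    positives,
    negatives,
    style_name='Manga',
    quality_name='Standard v3.1',
)
print_prompts(pos, neg, has_background=False)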
    	
requirements.txt
ADDED

@@ -0,0 +1,16 @@

torch==2.0.1
torchvision
xformers==0.0.22
einops
diffusers @ git+https://github.com/initml/diffusers.git@clement/feature/flash_sd3
transformers
huggingface_hub[torch]
gradio==4.39.0
Pillow
emoji
numpy
tqdm
jupyterlab
peft>=0.10.0
sentencepiece
protobuf
    	
        share_btn.py
    ADDED
    
    | 
         @@ -0,0 +1,70 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
share_js = """async () => {
    async function uploadFile(file) {
        const UPLOAD_URL = 'https://huggingface.co/uploads';
        const response = await fetch(UPLOAD_URL, {
            method: 'POST',
            headers: {
                'Content-Type': file.type,
                'X-Requested-With': 'XMLHttpRequest',
            },
            body: file, /// <- File inherits from Blob
        });
        const url = await response.text();
        return url;
    }
    async function getBase64(file) {
        var reader = new FileReader();
        reader.readAsDataURL(file);
        reader.onload = function () {
            console.log(reader.result);
        };
        reader.onerror = function (error) {
            console.log('Error: ', error);
        };
    }
    const toDataURL = url => fetch(url)
        .then(response => response.blob())
        .then(blob => new Promise((resolve, reject) => {
            const reader = new FileReader()
            reader.onloadend = () => resolve(reader.result)
            reader.onerror = reject
            reader.readAsDataURL(blob)
        }));
    async function dataURLtoFile(dataurl, filename) {
        var arr = dataurl.split(','), mime = arr[0].match(/:(.*?);/)[1],
        bstr = atob(arr[1]), n = bstr.length, u8arr = new Uint8Array(n);
        while (n--) {
            u8arr[n] = bstr.charCodeAt(n);
        }
        return new File([u8arr], filename, {type:mime});
    };

    const gradioEl = document.querySelector('body > gradio-app');
    const imgEls = gradioEl.querySelectorAll('#output-screen img');
    if(!imgEls.length){
        return;
    };

    const urls = await Promise.all([...imgEls].map((imgEl) => {
        const origURL = imgEl.src;
        const imgId = Date.now() % 200;
        const fileName = 'semantic-palette-xl-' + imgId + '.png';
        return toDataURL(origURL)
            .then(dataUrl => {
                return dataURLtoFile(dataUrl, fileName);
            })
        })).then(fileData => {return Promise.all([...fileData].map((file) => {
            return uploadFile(file);
        }))});

    const htmlImgs = urls.map(url => `<img src='${url}' width='2560' height='1024'>`);
    const descriptionMd = `<div style='display: flex; flex-wrap: wrap; column-gap: 0.75rem;'>
${htmlImgs.join(`\n`)}
</div>`;
    const params = new URLSearchParams({
        title: `My creation`,
        description: descriptionMd,
    });
    const paramsStr = params.toString();
    window.open(`https://huggingface.co/spaces/ironjr/SemanticPaletteXL/discussions/new?${paramsStr}`, '_blank');
}"""
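For reference, a minimal sketch of how a snippet like share_js is typically attached to a Gradio button. This is illustrative only: the component names and layout below are assumptions, not taken from app.py. The one real constraint is that the output component must carry elem_id='output-screen', because the JS above queries document.querySelectorAll('#output-screen img').

    # Hypothetical wiring sketch (not part of this commit's app.py).
    import gradio as gr
    from share_btn import share_js

    with gr.Blocks() as demo:
        # The result viewer must use elem_id='output-screen' for share_js to find it.
        output = gr.Image(elem_id='output-screen')
        share_button = gr.Button('Share to community')
        # fn=None: nothing runs on the server; Gradio executes the JS in the browser.
        share_button.click(fn=None, inputs=None, outputs=None, js=share_js)

    demo.launch()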
    	
util.py
ADDED
@@ -0,0 +1,315 @@
# Copyright (c) 2024 Jaerin Lee

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import concurrent.futures
import time
from typing import Any, Callable, List, Literal, Tuple, Union

from PIL import Image
import numpy as np

import torch
import torch.nn.functional as F
import torch.cuda.amp as amp
import torchvision.transforms as T
import torchvision.transforms.functional as TF

from diffusers import (
    DiffusionPipeline,
    StableDiffusionPipeline,
    StableDiffusionXLPipeline,
)

def seed_everything(seed: int) -> None:
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True


def load_model(
    model_key: str,
    sd_version: Literal['1.5', 'xl'],
    device: torch.device,
    dtype: torch.dtype,
) -> torch.nn.Module:
    if model_key.endswith('.safetensors'):
        if sd_version == '1.5':
            pipeline = StableDiffusionPipeline
        elif sd_version == 'xl':
            pipeline = StableDiffusionXLPipeline
        else:
            raise ValueError(f'Stable Diffusion version {sd_version} not supported.')
        return pipeline.from_single_file(model_key, torch_dtype=dtype).to(device)
    try:
        return DiffusionPipeline.from_pretrained(model_key, variant='fp16', torch_dtype=dtype).to(device)
    except:
        return DiffusionPipeline.from_pretrained(model_key, variant=None, torch_dtype=dtype).to(device)

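A minimal usage sketch for seed_everything and load_model (illustrative, not part of util.py; the checkpoint id below is just an example, not necessarily what app.py loads):

    import torch
    from util import seed_everything, load_model

    seed_everything(2024)
    # Hub repository ids go through DiffusionPipeline.from_pretrained; a local
    # '.safetensors' path would instead dispatch on sd_version ('1.5' or 'xl').
    pipe = load_model(
        'stabilityai/stable-diffusion-xl-base-1.0',  # example model key
        sd_version='xl',
        device=torch.device('cuda'),
        dtype=torch.float16,
    )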
def get_cutoff(cutoff: float = None, scale: float = None) -> float:
    if cutoff is not None:
        return cutoff

    if scale is not None and cutoff is None:
        return 0.5 / scale

    raise ValueError('Either one of `cutoff`, or `scale` should be specified.')


def get_scale(cutoff: float = None, scale: float = None) -> float:
    if scale is not None:
        return scale

    if cutoff is not None and scale is None:
        return 0.5 / cutoff

    raise ValueError('Either one of `cutoff`, or `scale` should be specified.')

def filter_2d_by_kernel_1d(x: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    assert len(k.shape) in (1,), 'Kernel size should be one of (1,).'
    #  assert len(k.shape) in (1, 2), 'Kernel size should be one of (1, 2).'

    b, c, h, w = x.shape
    ks = k.shape[-1]
    k = k.view(1, 1, -1).repeat(c, 1, 1)

    x = x.permute(0, 2, 1, 3)
    x = x.reshape(b * h, c, w)
    x = F.pad(x, (ks // 2, (ks - 1) // 2), mode='replicate')
    x = F.conv1d(x, k, groups=c)
    x = x.reshape(b, h, c, w).permute(0, 3, 2, 1).reshape(b * w, c, h)
    x = F.pad(x, (ks // 2, (ks - 1) // 2), mode='replicate')
    x = F.conv1d(x, k, groups=c)
    x = x.reshape(b, w, c, h).permute(0, 2, 3, 1)
    return x


def filter_2d_by_kernel_2d(x: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    assert len(k.shape) in (2, 3), 'Kernel size should be one of (2, 3).'

    x = F.pad(x, (
        k.shape[-2] // 2, (k.shape[-2] - 1) // 2,
        k.shape[-1] // 2, (k.shape[-1] - 1) // 2,
    ), mode='replicate')

    b, c, _, _ = x.shape
    if len(k.shape) == 2 or (len(k.shape) == 3 and k.shape[0] == 1):
        k = k.view(1, 1, *k.shape[-2:]).repeat(c, 1, 1, 1)
        x = F.conv2d(x, k, groups=c)
    elif len(k.shape) == 3:
        assert k.shape[0] == b, \
            'The number of kernels should match the batch size.'

        k = k.unsqueeze(1)
        x = F.conv2d(x.permute(1, 0, 2, 3), k, groups=b).permute(1, 0, 2, 3)
    return x


@amp.autocast(False)
def filter_by_kernel(
    x: torch.Tensor,
    k: torch.Tensor,
    is_batch: bool = False,
) -> torch.Tensor:
    k_dim = len(k.shape)
    if k_dim == 1 or k_dim == 2 and is_batch:
        return filter_2d_by_kernel_1d(x, k)
    elif k_dim == 2 or k_dim == 3 and is_batch:
        return filter_2d_by_kernel_2d(x, k)
    else:
        raise ValueError('Kernel size should be one of (1, 2, 3).')

def gen_gauss_lowpass_filter_2d(
    std: torch.Tensor,
    window_size: int = None,
) -> torch.Tensor:
    # Gaussian kernel size is odd in order to preserve the center.
    if window_size is None:
        window_size = (
            2 * int(np.ceil(3 * std.max().detach().cpu().numpy())) + 1)

    y = torch.arange(
        window_size, dtype=std.dtype, device=std.device
    ).view(-1, 1).repeat(1, window_size)
    grid = torch.stack((y.t(), y), dim=-1)
    grid -= 0.5 * (window_size - 1) # (W, W)
    var = (std * std).unsqueeze(-1).unsqueeze(-1)
    distsq = (grid * grid).sum(dim=-1).unsqueeze(0).repeat(*std.shape, 1, 1)
    k = torch.exp(-0.5 * distsq / var)
    k /= k.sum(dim=(-2, -1), keepdim=True)
    return k


def gaussian_lowpass(
    x: torch.Tensor,
    std: Union[float, Tuple[float], torch.Tensor] = None,
    cutoff: Union[float, torch.Tensor] = None,
    scale: Union[float, torch.Tensor] = None,
) -> torch.Tensor:
    if std is None:
        cutoff = get_cutoff(cutoff, scale)
        std = 0.5 / (np.pi * cutoff)
    if isinstance(std, (float, int)):
        std = (std, std)
    if isinstance(std, torch.Tensor):
        """Using nn.functional.conv2d with Gaussian kernels built in runtime is
        80% faster than transforms.functional.gaussian_blur for individual
        items.

        (in GPU); However, in CPU, the result is exactly opposite. But you
        won't gonna run this on CPU, right?
        """
        if len(list(s for s in std.shape if s != 1)) >= 2:
            raise NotImplementedError(
                'Anisotropic Gaussian filter is not currently available.')

        # k.shape == (B, W, W).
        k = gen_gauss_lowpass_filter_2d(std=std.view(-1))
        if k.shape[0] == 1:
            return filter_by_kernel(x, k[0], False)
        else:
            return filter_by_kernel(x, k, True)
    else:
        # Gaussian kernel size is odd in order to preserve the center.
        window_size = tuple(2 * int(np.ceil(3 * s)) + 1 for s in std)
        return TF.gaussian_blur(x, window_size, std)

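A minimal sketch of the two gaussian_lowpass code paths (illustrative, not part of util.py): a scalar std falls back to torchvision's gaussian_blur with the same sigma for the whole batch, while a per-sample std tensor builds one Gaussian kernel per batch item and filters with grouped convolutions:

    import torch
    from util import gaussian_lowpass

    x = torch.rand(2, 3, 64, 64, device='cuda')
    # Scalar std: same blur for every item (torchvision path).
    blurred = gaussian_lowpass(x, std=2.0)
    # Tensor std: one runtime-built kernel per batch item (conv2d path).
    blurred_per_item = gaussian_lowpass(x, std=torch.tensor([1.0, 4.0], device='cuda'))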
def blend(
    fg: Union[torch.Tensor, Image.Image],
    bg: Union[torch.Tensor, Image.Image],
    mask: Union[torch.Tensor, Image.Image],
    std: float = 0.0,
) -> Image.Image:
    if not isinstance(fg, torch.Tensor):
        fg = T.ToTensor()(fg)
    if not isinstance(bg, torch.Tensor):
        bg = T.ToTensor()(bg)
    if not isinstance(mask, torch.Tensor):
        mask = (T.ToTensor()(mask) < 0.5).float()[:1]
    if std > 0:
        mask = gaussian_lowpass(mask[None], std)[0].clip_(0, 1)
    return T.ToPILImage()(fg * mask + bg * (1 - mask))

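A minimal usage sketch for blend (illustrative; the file names are placeholders, and all three images are assumed to share the same size). When a PIL mask is given, dark pixels (< 0.5) select the foreground, and std > 0 feathers the seam with the Gaussian low-pass above:

    from PIL import Image
    from util import blend

    fg = Image.open('foreground.png').convert('RGB')   # placeholder path
    bg = Image.open('background.png').convert('RGB')   # placeholder path
    mask = Image.open('mask.png').convert('L')         # placeholder path
    composite = blend(fg, bg, mask, std=2.0)
    composite.save('blended.png')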
def get_panorama_views(
    panorama_height: int,
    panorama_width: int,
    window_size: int = 64,
) -> tuple[List[Tuple[int]], torch.Tensor]:
    stride = window_size // 2
    is_horizontal = panorama_width > panorama_height
    num_blocks_height = (panorama_height - window_size + stride - 1) // stride + 1
    num_blocks_width = (panorama_width - window_size + stride - 1) // stride + 1
    total_num_blocks = num_blocks_height * num_blocks_width

    half_fwd = torch.linspace(0, 1, (window_size + 1) // 2)
    half_rev = half_fwd.flip(0)
    if window_size % 2 == 1:
        half_rev = half_rev[1:]
    c = torch.cat((half_fwd, half_rev))
    one = torch.ones_like(c)
    f = c.clone()
    f[:window_size // 2] = 1
    b = c.clone()
    b[-(window_size // 2):] = 1

    h = [one] if num_blocks_height == 1 else [f] + [c] * (num_blocks_height - 2) + [b]
    w = [one] if num_blocks_width == 1 else [f] + [c] * (num_blocks_width - 2) + [b]

    views = []
    masks = torch.zeros(total_num_blocks, panorama_height, panorama_width) # (n, h, w)
    for i in range(total_num_blocks):
        hi, wi = i // num_blocks_width, i % num_blocks_width
        h_start = hi * stride
        h_end = min(h_start + window_size, panorama_height)
        w_start = wi * stride
        w_end = min(w_start + window_size, panorama_width)
        views.append((h_start, h_end, w_start, w_end))

        h_width = h_end - h_start
        w_width = w_end - w_start
        masks[i, h_start:h_end, w_start:w_end] = h[hi][:h_width, None] * w[wi][None, :w_width]

    # Sum of the mask weights at each pixel `masks.sum(dim=1)` must be unity.
    return views, masks[None] # (1, n, h, w)

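A minimal sketch for get_panorama_views (illustrative, not part of util.py): a wide canvas is tiled into overlapping windows, and the returned masks taper linearly in the overlaps so the per-window weights sum to one at every position. For example, a 64x256 latent grid (a 512x2048 image at 8x downsampling) with 64x64 windows:

    from util import get_panorama_views

    views, masks = get_panorama_views(64, 256, window_size=64)
    # views: list of (h_start, h_end, w_start, w_end) latent-space crops
    # masks: (1, num_windows, 64, 256) blending weights; masks.sum(dim=1) == 1
    print(len(views), masks.shape)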
def shift_to_mask_bbox_center(im: torch.Tensor, mask: torch.Tensor, reverse: bool = False) -> List[int]:
    h, w = mask.shape[-2:]
    device = mask.device
    mask = mask.reshape(-1, h, w)
    # assert mask.shape[0] == im.shape[0]
    h_occupied = mask.sum(dim=-2) > 0
    w_occupied = mask.sum(dim=-1) > 0
    l = torch.argmax(h_occupied * torch.arange(w, 0, -1).to(device), 1, keepdim=True).cpu()
    r = torch.argmax(h_occupied * torch.arange(w).to(device), 1, keepdim=True).cpu()
    t = torch.argmax(w_occupied * torch.arange(h, 0, -1).to(device), 1, keepdim=True).cpu()
    b = torch.argmax(w_occupied * torch.arange(h).to(device), 1, keepdim=True).cpu()
    tb = (t + b + 1) // 2
    lr = (l + r + 1) // 2
    shifts = (tb - (h // 2), lr - (w // 2))
    shifts = torch.cat(shifts, dim=1) # (p, 2)
    if reverse:
        shifts = shifts * -1
    return torch.stack([i.roll(shifts=s.tolist(), dims=(-2, -1)) for i, s in zip(im, shifts)], dim=0)

class Streamer:
    def __init__(self, fn: Callable, ema_alpha: float = 0.9) -> None:
        self.fn = fn
        self.ema_alpha = ema_alpha

        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        self.future = self.executor.submit(fn)
        self.image = None

        self.prev_exec_time = 0
        self.ema_exec_time = 0

    @property
    def throughput(self) -> float:
        return 1.0 / self.ema_exec_time if self.ema_exec_time else float('inf')

    def timed_fn(self) -> Any:
        start = time.time()
        res = self.fn()
        end = time.time()
        self.prev_exec_time = end - start
        self.ema_exec_time = self.ema_exec_time * self.ema_alpha + self.prev_exec_time * (1 - self.ema_alpha)
        return res

    def __call__(self) -> Any:
        if self.future.done() or self.image is None:
            # get the result (the new image) and start a new task
            image = self.future.result()
            self.future = self.executor.submit(self.timed_fn)
            self.image = image
            return image
        else:
            # if self.fn() is not ready yet, use the previous image
            # NOTE: This assumes that we have access to a previously generated image here.
            # If there's no previous image (i.e., this is the first invocation), you could fall
            # back to some default image or handle it differently based on your requirements.
            return self.image
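A minimal sketch for Streamer (illustrative, not part of util.py): it keeps a single background worker running the generation function and always returns the most recent finished result, so a polling UI loop never blocks on a slow step. Note that the very first call does wait for the first result, since there is no previous image to fall back on:

    import time
    from util import Streamer

    def render_frame():
        time.sleep(0.5)  # stand-in for one (slow) generation step
        return 'frame'

    streamer = Streamer(render_frame)
    for _ in range(5):
        frame = streamer()  # new frame if ready, otherwise the previous one
        print(frame, f'{streamer.throughput:.2f} it/s')
        time.sleep(0.1)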