lakshmikarpolam commited on
Commit
89fa300
1 Parent(s): bc86086

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -0
app.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from diffusers import UNet2DConditionModel, DiffusionPipeline, LCMScheduler
3
+ import torch
4
+ from PIL import Image
5
+
6
+ # Function to generate and display image
7
+ def generate_and_display_image(prompt):
8
+ # Initialize the UNet model
9
+ unet = UNet2DConditionModel.from_pretrained("path/to/fine-tuned/weight", torch_dtype=torch.float16, variant="fp16")
10
+
11
+ # Initialize the diffusion pipeline
12
+ pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", unet=unet, torch_dtype=torch.float16, variant="fp16")
13
+ pipeline.safety_checker = None
14
+ pipeline.requires_safety_checker = False
15
+
16
+ # Set the loaded scheduler in the pipeline
17
+ pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config)
18
+ pipeline.to("cuda")
19
+
20
+ # Set the number of inference steps
21
+ inference_steps = 4
22
+
23
+ # Generate image
24
+ image = pipeline(prompt, num_inference_steps=inference_steps, guidance_scale=2).images[0]
25
+ image = image.resize((512, 512))
26
+
27
+ # Display the generated image
28
+ st.image(image, caption="Generated Image", use_column_width=True)
29
+
30
+ # Main function
31
+ def main():
32
+ st.title("Image Generation with Diffusion Models")
33
+
34
+ # Input prompt
35
+ prompt = st.text_input("Enter your prompt")
36
+
37
+ # Button to generate and display image
38
+ if st.button("Generate Image"):
39
+ if prompt:
40
+ generate_and_display_image(prompt)
41
+ else:
42
+ st.warning("Please provide a prompt.")
43
+
44
+ if __name__ == "__main__":
45
+ main()