Shabbir-Anjum commited on
Commit
13bc63a
1 Parent(s): 7d671a1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -24
app.py CHANGED
@@ -1,31 +1,26 @@
 
1
  import streamlit as st
2
- from transformers import pipeline
3
 
4
- # Load the Diffusion pipeline for text generation
5
- generator = pipeline("text-generation", model="stabilityai/stable-diffusion-3-medium")
6
-
7
- def generate_prompt(prompt_text):
8
- # Generate response using the Diffusion model
9
- response = generator(prompt_text, top_p=0.9, max_length=100)[0]['generated_text']
10
- return response
11
 
 
12
  def main():
13
- st.title('Diffusion Model Prompt Generator')
14
-
15
- # Text input for the prompt
16
- prompt_text = st.text_area("Enter your prompt here:", height=200)
17
-
18
- # Button to generate prompt
19
- if st.button("Generate"):
20
- if prompt_text:
21
- with st.spinner('Generating...'):
22
- generated_text = generate_prompt(prompt_text)
23
- st.success('Generation complete!')
24
- st.text_area('Generated Text:', value=generated_text, height=400)
25
- else:
26
- st.warning('Please enter a prompt.')
27
-
28
- if __name__ == '__main__':
29
  main()
30
 
31
 
 
 
1
+ import torch
2
  import streamlit as st
3
+ from diffusers import StableDiffusion3Pipeline
4
 
5
+ # Load the model
6
+ pipeline = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16)
7
+ pipeline = pipeline.to("cuda") # Move model to GPU if available
 
 
 
 
8
 
9
+ # Streamlit UI
10
  def main():
11
+ st.title("Stable Diffusion 3 Medium Demo")
12
+ prompt = st.text_input("Enter your prompt:", "A cat holding a sign that says hello world")
13
+
14
+ if st.button("Generate Image"):
15
+ with st.spinner("Generating..."):
16
+ try:
17
+ image = pipeline(prompt, negative_prompt="", num_inference_steps=28, guidance_scale=7.0).images[0]
18
+ st.image(image, caption="Generated Image", use_column_width=True)
19
+ except Exception as e:
20
+ st.error(f"Error: {e}")
21
+
22
+ if __name__ == "__main__":
 
 
 
 
23
  main()
24
 
25
 
26
+