Brij1808 commited on
Commit
500c32a
·
verified ·
1 Parent(s): 80df7a6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -0
app.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import requests
3
+ import base64
4
+ from PIL import Image, ImageFilter
5
+ from io import BytesIO
6
+ from transformers import pipeline
7
+ import streamlit as st
8
+
9
+ # API Endpoint for image generation
10
+ url = "http://34.198.214.220:8000/generate/"
11
+
12
+ # Streamlit sidebar for model selection
13
+ model_option = st.sidebar.selectbox("Select Model", ["Fluently-XL-Final", "Flux-Uncensored"], index=0)
14
+
15
+ st.title("Text to Image Generator")
16
+
17
+ # Streamlit input field for prompt
18
+ prompt = st.text_input("Enter Prompt", "") # Default prompt
19
+
20
+ # Use a pipeline as a high-level helper for image classification
21
+ pipe = pipeline("image-classification", model="giacomoarienti/nsfw-classifier")
22
+
23
+ def classify_image(image):
24
+ """
25
+ Classifies an image using the NSFW classifier.
26
+
27
+ Args:
28
+ image: The PIL image object to be classified.
29
+
30
+ Returns:
31
+ A dictionary containing the classification results.
32
+ """
33
+ try:
34
+ # Classify the image using the pipeline
35
+ results = pipe(image)
36
+ return results
37
+ except Exception as e:
38
+ st.error(f"Error during classification: {e}")
39
+ return None
40
+
41
+ def blur_image(image):
42
+ """
43
+ Applies a Gaussian Blur to an image and saves it.
44
+
45
+ Args:
46
+ image: The PIL image object to be blurred.
47
+ """
48
+ # Apply Gaussian Blur filter to the image
49
+ blurred_image = image.filter(ImageFilter.GaussianBlur(radius=40))
50
+
51
+ # Display the blurred image
52
+ st.image(blurred_image, caption="Blurred Image", use_container_width=True)
53
+
54
+ def process_image(image):
55
+ """
56
+ Processes the image by classifying it and applying actions based on results.
57
+
58
+ Args:
59
+ image: The PIL image object.
60
+ """
61
+ results = classify_image(image)
62
+
63
+ if results:
64
+
65
+ # Check if either 'porn' label > 0.7 or 'sexy' label > 0.85
66
+ porn_score = next((item['score'] for item in results if item['label'] == 'porn'), 0)
67
+ sexy_score = next((item['score'] for item in results if item['label'] == 'sexy'), 0)
68
+ if porn_score > 0.7 or sexy_score > 0.85:
69
+ blur_image(image) # Apply blur and show the blurred image
70
+ else:
71
+ st.image(image, caption="Original Image", use_container_width=True) # Show the original image even if it does not meet the threshold
72
+ else:
73
+ st.error("Error: Image classification failed.")
74
+ # Button to generate image
75
+ if st.button('Generate Image'):
76
+ payload = {
77
+ "prompt": prompt, # User input prompt
78
+ "model": model_option # Model selected by user
79
+ }
80
+
81
+ # Generate the image using the API
82
+ response = requests.post(url, json=payload)
83
+
84
+ # Check if the request was successful
85
+ if response.status_code == 200:
86
+ response_data = response.json()
87
+
88
+ # Extract the base64 image string
89
+ if "image_base64" in response_data:
90
+ base64_string = response_data["image_base64"]
91
+
92
+ # Decode the base64 string into an image
93
+ image_data = base64.b64decode(base64_string)
94
+ image = Image.open(BytesIO(image_data))
95
+
96
+ # Process the generated image
97
+ process_image(image)
98
+
99
+ else:
100
+ st.error("No image data found in the response!")
101
+ else:
102
+ st.error(f"Failed to generate image. Status code: {response.status_code}, Error: {response.text}")