File size: 3,276 Bytes
500c32a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103

import requests
import base64
from PIL import Image, ImageFilter
from io import BytesIO
from transformers import pipeline
import streamlit as st

# API Endpoint for image generation
# NOTE(review): hard-coded plain-HTTP address; consider moving it to Streamlit
# secrets or an environment variable so it can change without a code edit.
url = "http://34.198.214.220:8000/generate/"

# Streamlit sidebar for model selection
model_option = st.sidebar.selectbox("Select Model", ["Fluently-XL-Final", "Flux-Uncensored"], index=0)

st.title("Text to Image Generator")

# Streamlit input field for the prompt (starts out empty)
prompt = st.text_input("Enter Prompt", "")


@st.cache_resource
def _load_nsfw_classifier():
    """Build the NSFW image-classification pipeline only once.

    Streamlit re-executes this entire script on every widget interaction, so
    constructing the pipeline at module level would reload the model on each
    rerun. ``st.cache_resource`` keeps one instance alive across reruns.
    """
    return pipeline("image-classification", model="giacomoarienti/nsfw-classifier")


# Shared classifier used by classify_image(); cached, so cheap on reruns.
pipe = _load_nsfw_classifier()

def classify_image(image):
    """Run the NSFW classifier over a PIL image.

    Args:
        image: The PIL image to score.

    Returns:
        Whatever ``pipe(image)`` produced -- a list of dicts with ``"label"``
        and ``"score"`` keys, as consumed by process_image() -- or ``None``
        if the pipeline raised. In the failure case the error has already
        been shown to the user via ``st.error``.
    """
    try:
        predictions = pipe(image)
    except Exception as exc:  # UI boundary: surface any model failure to the user
        st.error(f"Error during classification: {exc}")
        predictions = None
    return predictions

def blur_image(image, radius=40):
    """Blur an image with a Gaussian filter and display it in the app.

    Nothing is written to disk; the blurred copy is rendered with ``st.image``.

    Args:
        image: The PIL image to blur. The caller's object is not modified.
        radius: Gaussian blur radius in pixels. Defaults to 40, the strength
            previously hard-coded here.

    Returns:
        The blurred PIL image (callers that ignore the return value are
        unaffected).
    """
    # Image.filter() raises ValueError for palette ("P") images, so expand
    # palette modes to RGBA first; RGBA also preserves any transparency.
    if image.mode in ("P", "PA"):
        image = image.convert("RGBA")

    blurred_image = image.filter(ImageFilter.GaussianBlur(radius=radius))

    # Display the blurred image
    st.image(blurred_image, caption="Blurred Image", use_container_width=True)
    return blurred_image

# Per-label classifier score thresholds; an image is blurred when any listed
# label scores strictly above its threshold.
# NOTE(review): only these labels are checked. If the model emits other
# explicit-content labels, consider adding them here.
DEFAULT_BLUR_THRESHOLDS = {"porn": 0.7, "sexy": 0.85}


def _exceeds_thresholds(results, thresholds):
    """Return True if any thresholded label's score is strictly above its limit."""
    for label, limit in thresholds.items():
        # A label missing from the classifier output counts as score 0.
        score = next((item["score"] for item in results if item["label"] == label), 0)
        if score > limit:
            return True
    return False


def process_image(image, *, thresholds=None):
    """Classify an image and display it, blurred if it scores as explicit.

    Args:
        image: The PIL image to check and display.
        thresholds: Optional mapping of classifier label -> score threshold.
            The image is blurred when any listed label scores strictly above
            its threshold. Defaults to DEFAULT_BLUR_THRESHOLDS
            (``porn`` > 0.7 or ``sexy`` > 0.85).

    If classification fails or yields no results, the image is not shown at
    all; an error message is displayed instead.
    """
    if thresholds is None:
        thresholds = DEFAULT_BLUR_THRESHOLDS

    results = classify_image(image)
    if not results:
        # Fail closed: never display an image that could not be classified.
        st.error("Error: Image classification failed.")
        return

    if _exceeds_thresholds(results, thresholds):
        blur_image(image)  # Apply blur and show the blurred image
    else:
        # Below every threshold: show the image unmodified.
        st.image(image, caption="Original Image", use_container_width=True)
# (connect, read) timeouts in seconds for the generation API. Generation can be
# slow, so the read timeout is generous; an unreachable host still fails fast.
REQUEST_TIMEOUT = (10, 300)


def _request_image(prompt_text, model_name):
    """Ask the generation API for an image and decode it into a PIL image.

    Args:
        prompt_text: The text prompt to send.
        model_name: The model identifier selected in the sidebar.

    Returns:
        The decoded PIL image, or None if any step failed. On failure the
        reason has already been reported to the user with ``st.error``.
    """
    payload = {
        "prompt": prompt_text,  # User input prompt
        "model": model_name,  # Model selected by user
    }

    # Without a timeout, requests can block this script run indefinitely.
    try:
        response = requests.post(url, json=payload, timeout=REQUEST_TIMEOUT)
    except requests.exceptions.RequestException as e:
        st.error(f"Failed to reach the image generation API: {e}")
        return None

    if response.status_code != 200:
        st.error(f"Failed to generate image. Status code: {response.status_code}, Error: {response.text}")
        return None

    try:
        response_data = response.json()
    except ValueError as e:  # body is not valid JSON
        st.error(f"Invalid response from the image generation API: {e}")
        return None

    if not isinstance(response_data, dict) or "image_base64" not in response_data:
        st.error("No image data found in the response!")
        return None

    try:
        image_data = base64.b64decode(response_data["image_base64"])
        image = Image.open(BytesIO(image_data))
        # Image.open() is lazy; load() forces a full decode so corrupt or
        # truncated data fails here rather than later inside the classifier.
        image.load()
    except (TypeError, ValueError, OSError) as e:  # e.g. binascii.Error, UnidentifiedImageError
        st.error(f"Could not decode the generated image: {e}")
        return None
    return image


# Button to generate image
if st.button('Generate Image'):
    if not prompt.strip():
        # Don't spend an API call on an empty prompt.
        st.warning("Please enter a prompt first.")
    else:
        generated = _request_image(prompt, model_option)
        if generated is not None:
            # Process the generated image
            process_image(generated)