karwanjiru commited on
Commit
370d98b
·
1 Parent(s): ddf531e
Files changed (1) hide show
  1. app.py +68 -12
app.py CHANGED
@@ -1,13 +1,59 @@
1
  import gradio as gr
 
 
 
 
2
  from huggingface_hub import InferenceClient
3
  import requests
4
  from PIL import Image
5
  from io import BytesIO
6
 
7
- # Initialize the client
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
9
 
10
- # Define the function to respond to user inputs
11
  def respond(message, history, system_message, max_tokens, temperature, top_p):
12
  messages = [{"role": "system", "content": system_message}]
13
 
@@ -28,7 +74,7 @@ def respond(message, history, system_message, max_tokens, temperature, top_p):
28
 
29
  return response.choices[0].message['content']
30
 
31
- # Define the function to generate posts
32
  def generate_post(prompt, max_tokens, temperature, top_p):
33
  response = client.chat_completion(
34
  [{"role": "user", "content": prompt}],
@@ -38,21 +84,20 @@ def generate_post(prompt, max_tokens, temperature, top_p):
38
  )
39
  return response.choices[0].message['content']
40
 
41
- # Define the function to moderate posts
42
  def moderate_post(post):
43
  # Implement your post moderation logic here
44
  if "inappropriate" in post:
45
  return "Post does not adhere to community guidelines."
46
  return "Post adheres to community guidelines."
47
 
48
- # Define the function to generate images
49
  def generate_image(prompt):
50
- # Replace with actual model or API endpoint for image generation
51
- response = client.text_to_image(prompt)
52
- image = Image.open(BytesIO(response))
53
  return image
54
 
55
- # Define the function to moderate images
56
  def moderate_image(image):
57
  # Convert the PIL image to a format that can be sent for moderation
58
  buffered = BytesIO()
@@ -73,10 +118,21 @@ def moderate_image(image):
73
  return "Image does not adhere to community guidelines."
74
 
75
  # Create the Gradio interface
76
- demo = gr.Blocks()
 
 
 
 
 
 
 
 
 
 
77
 
78
- with demo:
79
  gr.Markdown("# AI-driven Content Generation and Moderation Bot")
 
80
 
81
  with gr.Tabs():
82
  with gr.TabItem("Chat"):
@@ -131,4 +187,4 @@ with demo:
131
  moderate_image_button.click(moderate_image, uploaded_image, image_moderation_result)
132
 
133
  if __name__ == "__main__":
134
- demo.launch()
 
1
  import gradio as gr
2
+ import numpy as np
3
+ import random
4
+ from diffusers import DiffusionPipeline
5
+ import torch
6
  from huggingface_hub import InferenceClient
7
  import requests
8
  from PIL import Image
9
  from io import BytesIO
10
 
11
+ # Device configuration
12
+ device = "cuda" if torch.cuda.is_available() else "cpu"
13
+
14
+ # Load the diffusion pipeline
15
+ if torch.cuda.is_available():
16
+ torch.cuda.max_memory_allocated(device=device)
17
+ pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
18
+ pipe.enable_xformers_memory_efficient_attention()
19
+ pipe = pipe.to(device)
20
+ else:
21
+ pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True)
22
+ pipe = pipe.to(device)
23
+
24
+ MAX_SEED = np.iinfo(np.int32).max
25
+ MAX_IMAGE_SIZE = 1024
26
+
27
+ # Inference function for generating images
28
+ def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
29
+ if randomize_seed:
30
+ seed = random.randint(0, MAX_SEED)
31
+
32
+ generator = torch.Generator().manual_seed(seed)
33
+
34
+ image = pipe(
35
+ prompt=prompt,
36
+ negative_prompt=negative_prompt,
37
+ guidance_scale=guidance_scale,
38
+ num_inference_steps=num_inference_steps,
39
+ width=width,
40
+ height=height,
41
+ generator=generator
42
+ ).images[0]
43
+
44
+ return image
45
+
46
+ # Examples for the text-to-image generation
47
+ examples = [
48
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
49
+ "An astronaut riding a green horse",
50
+ "A delicious ceviche cheesecake slice",
51
+ ]
52
+
53
+ # Initialize the InferenceClient
54
  client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
55
 
56
+ # Respond function for the chatbot
57
  def respond(message, history, system_message, max_tokens, temperature, top_p):
58
  messages = [{"role": "system", "content": system_message}]
59
 
 
74
 
75
  return response.choices[0].message['content']
76
 
77
+ # Function to generate posts
78
  def generate_post(prompt, max_tokens, temperature, top_p):
79
  response = client.chat_completion(
80
  [{"role": "user", "content": prompt}],
 
84
  )
85
  return response.choices[0].message['content']
86
 
87
+ # Function to moderate posts
88
  def moderate_post(post):
89
  # Implement your post moderation logic here
90
  if "inappropriate" in post:
91
  return "Post does not adhere to community guidelines."
92
  return "Post adheres to community guidelines."
93
 
94
+ # Function to generate images using the diffusion pipeline
95
  def generate_image(prompt):
96
+ generator = torch.manual_seed(random.randint(0, MAX_SEED))
97
+ image = pipe(prompt=prompt, generator=generator).images[0]
 
98
  return image
99
 
100
+ # Function to moderate images
101
  def moderate_image(image):
102
  # Convert the PIL image to a format that can be sent for moderation
103
  buffered = BytesIO()
 
118
  return "Image does not adhere to community guidelines."
119
 
120
  # Create the Gradio interface
121
+ css = """
122
+ #col-container {
123
+ margin: 0 auto;
124
+ max-width: 520px;
125
+ }
126
+ """
127
+
128
+ if torch.cuda.is_available():
129
+ power_device = "GPU"
130
+ else:
131
+ power_device = "CPU"
132
 
133
+ with gr.Blocks(css=css) as demo:
134
  gr.Markdown("# AI-driven Content Generation and Moderation Bot")
135
+ gr.Markdown(f"Currently running on {power_device}.")
136
 
137
  with gr.Tabs():
138
  with gr.TabItem("Chat"):
 
187
  moderate_image_button.click(moderate_image, uploaded_image, image_moderation_result)
188
 
189
  if __name__ == "__main__":
190
+ demo.launch()