vatistasdimitris committed
Commit 1355278 · verified · 1 Parent(s): fb3692c

Update app.py

Files changed (1):
  1. app.py +106 -49
app.py CHANGED
@@ -1,26 +1,70 @@
  import gradio as gr
- from gtts import gTTS
- import speech_recognition as sr
- import tempfile
- import os
  from huggingface_hub import InferenceClient

- # Initialize the Hugging Face model client
  client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")

- def respond(audio_file, history, system_message, max_tokens, temperature, top_p):
-     # Convert audio file to text
-     recognizer = sr.Recognizer()
-     temp_audio_path = audio_file.name
-
-     try:
-         with sr.AudioFile(temp_audio_path) as source:
-             audio_data = recognizer.record(source)
-             message = recognizer.recognize_google(audio_data)
-     except Exception as e:
-         return "Error in recognizing audio", None
-
-     # Prepare messages for the model
      messages = [{"role": "system", "content": system_message}]
      for val in history:
          if val[0]:
@@ -30,46 +74,59 @@ def respond(audio_file, history, system_message, max_tokens, temperature, top_p)

      messages.append({"role": "user", "content": message})

-     # Get response from the model
      response = ""
-     for msg in client.chat_completion(
          messages,
          max_tokens=max_tokens,
          stream=True,
          temperature=temperature,
          top_p=top_p,
      ):
-         token = msg.choices[0].delta.content
          response += token

-     # Convert response to speech
-     tts = gTTS(response, lang='en', slow=False)
-     with tempfile.NamedTemporaryFile(delete=False, suffix='.mp3') as temp_file:
-         tts.save(temp_file.name)
-         temp_file.seek(0)
-         audio_output = temp_file.read()
-
-     os.remove(temp_file.name)
-
-     return response, audio_output
-
- # Create Gradio interface
- demo = gr.Interface(
-     fn=respond,
-     inputs=[
-         gr.Audio(type="filepath", label="Record your audio"),
-         gr.State(value=[]),  # No `label` parameter
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
-     ],
-     outputs=[
-         gr.Textbox(label="Response"),
-         gr.Audio(label="Response Audio", type="file")
-     ],
-     live=True
- )

- if __name__ == "__main__":
-     demo.launch(share=True)
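Review note: in both the removed and the added respond, the streamed chat_completion chunks can carry a delta whose content is None (the exact behavior depends on the huggingface_hub version), in which case response += token raises a TypeError. A minimal guard, sketched against the streaming call used in this diff:

    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content
        if token:  # skip None/empty deltas instead of crashing
            response += token
            yield response

The rewritten app.py, as committed: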
 
  import gradio as gr
+ import numpy as np
+ import random
+ from diffusers import DiffusionPipeline
+ import torch
  from huggingface_hub import InferenceClient
+ from PIL import Image
+ from io import BytesIO

+ # Initialize the Hugging Face client for chat
  client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")

+ # Initialize the DiffusionPipeline for image generation
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ if torch.cuda.is_available():
+     torch.cuda.max_memory_allocated(device=device)
+     pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
+     pipe.enable_xformers_memory_efficient_attention()
+     pipe = pipe.to(device)
+ else:
+     pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True)
+     pipe = pipe.to(device)

+ MAX_SEED = np.iinfo(np.int32).max
+ MAX_IMAGE_SIZE = 1024
+
+ def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+     generator = torch.Generator().manual_seed(seed)
+     image = pipe(
+         prompt=prompt,
+         negative_prompt=negative_prompt,
+         guidance_scale=guidance_scale,
+         num_inference_steps=num_inference_steps,
+         width=width,
+         height=height,
+         generator=generator,
+     ).images[0]
+     return image
+
+ def respond(
+     message,
+     history: list[tuple[str, str]],
+     system_message,
+     max_tokens,
+     temperature,
+     top_p,
+ ):
+     # Check for image generation request
+     if "generate an image" in message.lower():
+         prompt = message.replace("generate an image", "").strip()
+         image = infer(
+             prompt=prompt,
+             negative_prompt="",
+             seed=0,
+             randomize_seed=True,
+             width=512,
+             height=512,
+             guidance_scale=7.5,
+             num_inference_steps=50
+         )
+         buffered = BytesIO()
+         image.save(buffered, format="PNG")
+         img_str = buffered.getvalue()
+         return "Here is your generated image:", img_str

      messages = [{"role": "system", "content": system_message}]
      for val in history:
          if val[0]:
  ...

      messages.append({"role": "user", "content": message})

      response = ""
+     for message in client.chat_completion(
          messages,
          max_tokens=max_tokens,
          stream=True,
          temperature=temperature,
          top_p=top_p,
      ):
+         token = message.choices[0].delta.content
          response += token
+         yield response

+ # Define Gradio Blocks interface
+ with gr.Blocks() as demo:
+     gr.Markdown("# Chat and Image Generation")
+
+     with gr.Row():
+         with gr.Column():
+             chat_interface = gr.ChatInterface(
+                 respond,
+                 additional_inputs=[
+                     gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
+                     gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
+                     gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
+                     gr.Slider(
+                         minimum=0.1,
+                         maximum=1.0,
+                         value=0.95,
+                         step=0.05,
+                         label="Top-p (nucleus sampling)",
+                     ),
+                 ],
+             )
+
+     def process_image_request(prompt):
+         image = infer(
+             prompt=prompt,
+             negative_prompt="",
+             seed=0,
+             randomize_seed=True,
+             width=512,
+             height=512,
+             guidance_scale=7.5,
+             num_inference_steps=50
+         )
+         buffered = BytesIO()
+         image.save(buffered, format="PNG")
+         return buffered.getvalue()
+
+     gr.Examples(
+         examples=["Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", "An astronaut riding a green horse", "A delicious ceviche cheesecake slice"],
+         inputs=[gr.Textbox(label="Prompt", placeholder="Enter your prompt")],
+         outputs=[gr.Image()]
+     )

+ demo.queue().launch()
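Review notes on the committed version:

- respond is a generator (it yields the streamed text), so the `return "Here is your generated image:", img_str` in the image branch only stops the generator; the tuple never reaches gr.ChatInterface, and raw PNG bytes are not a renderable chat message in any case.
- torch.cuda.max_memory_allocated() only reports peak allocated memory; calling it configures nothing, so that line can be dropped.
- SDXL-Turbo is distilled for guidance_scale=0.0 and very few inference steps (the model card suggests 1-4), so guidance_scale=7.5 with 50 steps works against the model's design and wastes compute.
- process_image_request is defined but never wired to any component, and gr.Examples has no fn, so clicking an example only fills the prompt textbox; nothing in the UI ever calls infer. The PIL Image import and MAX_IMAGE_SIZE are likewise unused.

A hedged sketch of an image branch that does reach the chat UI, assuming a Gradio version whose chatbot messages accept a (filepath, alt_text) tuple for media (verify against your installed release before relying on this):

    import tempfile

    def respond(message, history, system_message, max_tokens, temperature, top_p):
        if "generate an image" in message.lower():
            prompt = message.replace("generate an image", "").strip()
            image = infer(
                prompt=prompt,
                negative_prompt="",
                seed=0,
                randomize_seed=True,
                width=512,
                height=512,
                guidance_scale=0.0,     # SDXL-Turbo expects no classifier-free guidance
                num_inference_steps=2,  # 1-4 steps per the model card
            )
            # Gradio serves files from disk, so persist the PIL image first
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
                image.save(f.name)
            yield (f.name, prompt)  # yield, not return, inside a generator
            return
        # ... text branch as committed ...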