Icar committed
Commit 04389c8
1 Parent(s): b179b6a

Create app.py

Files changed (1)
app.py +165 -0
app.py ADDED
@@ -0,0 +1,165 @@
+ import torch
+ import gradio as gr
+ from gtts import gTTS
+ from transformers import pipeline
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
+
+ # CSS rule for the component with elem_id="input" (the output audio player).
+ css = """
+ #input {background-color: #FFCCCB}
+ """
+
+ # Utility Functions
+ # Flatten a list of token-id lists into one flat list.
+ flatten = lambda l: [item for sublist in l for item in sublist]
+
+ def to_data(x):
+     """Move a tensor off the GPU and return it as a numpy array."""
+     if torch.cuda.is_available():
+         x = x.cpu()
+     return x.data.numpy()
+
+ def to_var(x):
+     """Wrap x in a tensor and move it to the GPU when one is available."""
+     if not torch.is_tensor(x):
+         x = torch.Tensor(x)
+     if torch.cuda.is_available():
+         x = x.cuda()
+     return x
+
+ def clear():
+     """Reset the audio input, chat window, and dialog history."""
+     return None, [], []
+
+ def append(text, history, dialog_hx, personas):
+     """Append typed text to the chat history and generate a response."""
+     history.append([text, None])
+     history, audio, dialog_hx = bot.respond(history, dialog_hx, personas)
+     return history, audio, None, dialog_hx
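+ # Note: append() and clear() are wired to Gradio events below and reference the
+ # module-level `bot` created after the class definition; the global resolves at
+ # call time, so the forward reference is safe.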
+
+ class AI_Companion:
+     """
+     Class that implements the AI Companion.
+     """
+
+     def __init__(self, asr="openai/whisper-tiny", chatbot="af1tang/personaGPT"):
+         """
+         Create an instance of the Companion.
+
+         Parameters:
+             asr: Hugging Face ASR model card. Default: openai/whisper-tiny
+             chatbot: Hugging Face conversational model card. Default: af1tang/personaGPT
+         """
+         self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+         # pipeline() takes an integer device index: -1 for CPU, 0 for the first GPU
+         # (comparing a torch.device against the string "cpu" is unreliable).
+         self.asr = pipeline("automatic-speech-recognition", model=asr, device=0 if torch.cuda.is_available() else -1)
+         self.model = GPT2LMHeadModel.from_pretrained(chatbot).to(self.device)
+         self.tokenizer = GPT2Tokenizer.from_pretrained(chatbot)
+         self.personas = []
+         self.sett = {
+             "do_sample": True,
+             "top_k": 10,
+             "top_p": 0.92,
+             "max_length": 1000,
+         }
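+         # Note: these sampling defaults are kept for reference only; respond()
+         # passes its own hard-coded generation arguments to model.generate().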
+
+     def listen(self, audio, history):
+         """
+         Convert speech to text.
+
+         Parameters:
+             audio: audio filepath
+             history: chat history
+
+         Returns:
+             history: history with the recognized text appended
+             audio: empty Gradio component to clear the voice input
+         """
+         text = self.asr(audio)["text"]
+         history.append([text, None])
+         return history, None
+
+     def add_fact(self, audio, personas, msg):
+         '''
+         Add a fact to the persona.
+         Takes in audio, converts it into text, and adds it to the facts list;
+         falls back to the typed message when no audio is given.
+
+         Parameters:
+             audio: audio of the spoken fact
+             personas: current list of persona facts
+             msg: typed fact, used when audio is None
+         '''
+         if audio is not None:
+             text = self.asr(audio)
+             personas.append(text['text'] + self.tokenizer.eos_token)
+         else:
+             personas.append(msg + self.tokenizer.eos_token)
+         return None, personas, None
+
+     def respond(self, history, dialog_hx, personas, **kwargs):
+         """
+         Generate a response to the user input.
+
+         Parameters:
+             history: chat history
+             dialog_hx: token-level dialog history
+             personas: list of persona facts
+
+         Returns:
+             history: history with the response appended
+             audio: filepath of the spoken response
+             dialog_hx: updated dialog history
+         """
+         # personaGPT conditions generation on the persona facts wrapped in its
+         # special tokens: <|p2|> ... <|sep|> <|start|>.
+         person = self.tokenizer.encode(''.join(['<|p2|>'] + personas + ['<|sep|>'] + ['<|start|>']))
+         user_inp = self.tokenizer.encode(history[-1][0] + self.tokenizer.eos_token)
+         dialog_hx.append(user_inp)
+         bot_input_ids = to_var([person + flatten(dialog_hx)]).long()
+         with torch.no_grad():
+             # Beam search with num_beams=2; top_k/top_p only take effect
+             # when sampling is enabled.
+             full_msg = self.model.generate(bot_input_ids,
+                                            repetition_penalty=1.4,
+                                            top_k=10,
+                                            top_p=0.92,
+                                            max_new_tokens=256,
+                                            num_beams=2,
+                                            pad_token_id=self.tokenizer.eos_token_id)
+         # Keep only the tokens generated after the prompt.
+         response = to_data(full_msg.detach()[0])[bot_input_ids.shape[-1]:]
+         dialog_hx.append(response)
+         history[-1][1] = self.tokenizer.decode(response, skip_special_tokens=True)
+         self.speak(history[-1][1])
+         return history, "out.mp3", dialog_hx
+
+     def talk(self, audio, history, dialog_hx, personas, text):
+         """Route microphone or typed input to respond() and return updated state."""
+         if audio is not None:
+             history, _ = self.listen(audio, history)
+         else:
+             history.append([text, None])
+         history, audio, dialog_hx = self.respond(history, dialog_hx, personas)
+         return history, None, audio, dialog_hx, None
+
+     def speak(self, text):
+         """
+         Speak the given text using gTTS.
+
+         Parameters:
+             text: text to be spoken
+         """
+         tts = gTTS(text, lang='en')
+         tts.save('out.mp3')
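+         # Note: gTTS synthesizes speech via Google's online TTS service, so this
+         # call needs network access; out.mp3 is overwritten on every response.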
+
+ # Initialize the AI Companion with a default persona.
+ bot = AI_Companion()
+ personas = []
+ for i in ["I'm a 19 year old girl", 'I study at IIT Indore', 'I am an easy-going and fun loving person', 'I love to swim', 'I am friendly, nice, fun and kind', 'I am studious and get good grades']:
+     response = i + bot.tokenizer.eos_token
+     personas.append(response)
+
+ # Create the interface. Pass the css defined at the top so the #input rule
+ # styles the output audio component.
+ with gr.Blocks(css=css) as demo:
+     dialog_hx = gr.State([])
+     personas = gr.State(personas)
+     chatbot = gr.Chatbot([], elem_id="chatbot").style(height=300)
+     audio = gr.Audio(source="microphone", type="filepath", label="Input")
+     msg = gr.Textbox()
+     audio1 = gr.Audio(type="filepath", label="Output", elem_id="input")
+     with gr.Row():
+         b1 = gr.Button("Submit")
+         b2 = gr.Button("Clear")
+         b3 = gr.Button("Add Fact")
+     b1.click(bot.talk, [audio, chatbot, dialog_hx, personas, msg], [chatbot, audio, audio1, dialog_hx, msg])
+     msg.submit(append, [msg, chatbot, dialog_hx, personas], [chatbot, audio1, msg, dialog_hx])
+     b2.click(clear, [], [audio, chatbot, dialog_hx])
+     b3.click(bot.add_fact, [audio, personas, msg], [audio, personas, msg])
+ demo.launch(share=True)
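For a quick check outside the Gradio UI, a minimal sketch along these lines exercises the same path the Submit button takes. It assumes the dependencies above are installed and that network access is available for the model downloads and gTTS; the snippet is illustrative and not part of the committed file:

    # Hypothetical smoke test: drive AI_Companion.respond() directly.
    bot = AI_Companion()
    personas = ["I am a friendly companion." + bot.tokenizer.eos_token]
    history = [["Hi! How are you today?", None]]  # [user_text, bot_reply] pairs
    dialog_hx = []                                # token-level dialog history
    history, audio_path, dialog_hx = bot.respond(history, dialog_hx, personas)
    print("Bot:", history[-1][1])
    print("Spoken reply saved to:", audio_path)   # "out.mp3"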