Logeswaransr commited on
Commit
2333954
·
1 Parent(s): 07a5e28

Upload interaction.py

Browse files
Files changed (1) hide show
  1. interaction.py +71 -0
interaction.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ from gtts import gTTS
3
+ import base64
4
+
5
+ headers = {"Authorization": f"Bearer {API_Key}"}
6
+
7
+ basic_prompt = '''
8
+ You are a Virtual Assistant designed for assisting Alzheimer's Patients. Your name is Mysteria. You are currently assigned to a patient named Loki. The Guardian assigned to this patient is Sylvie. The Doctor assigned to the patient is Kang.
9
+
10
+ Details of Patient:
11
+ DOB: 14/04/1965
12
+ Last name: Odinson
13
+
14
+ Details of Guardian:
15
+ Name: Sylvie
16
+ Relation: Wife
17
+
18
+ Details of Doctor:
19
+ Name: Kang
20
+ Field: Psychology
21
+ Experience: 5 years
22
+ Office: End of Time
23
+ Next Appointment: 12/01/2024, 6:30 pm
24
+
25
+ You should respond only when "Mysteria" is announced.
26
+ When you are asked to shut up, You should stop responding, until you are awakened.
27
+ When you are awakened, Try to maintain a Conversation.
28
+
29
+ Maintain a conversation with the user, and, Answer the Questions properly.
30
+
31
+ '''
32
+ tags = {'user':'[Q]', 'assistant':"[A]", 'stop_query':''}
33
+
34
+ def build_prompt(query, conversation):
35
+ prompt=basic_prompt+tags['stop_query']
36
+ for msg in conversation:
37
+ prompt+='\n'
38
+ prompt+=tags[msg['role']]
39
+ prompt+=msg['content']
40
+ # prompt+=tags['stop_query']
41
+ prompt+='\n'+tags['user']+query # +tags['stop_query']
42
+ prompt+='\n'+tags['assistant']
43
+ return prompt, len(prompt)
44
+
45
+ def query(payload):
46
+ response = requests.post(API_URL, headers=headers, json=payload)
47
+ response = response.json()
48
+ return response
49
+
50
+ def generate_response(inputs, conversation):
51
+ prompt, next_index = build_prompt(inputs, conversation)
52
+ payload = { 'inputs': prompt ,
53
+ 'parameters':{'max_new_tokens':50}}
54
+ model_response = query(payload)
55
+ model_response = model_response[0]['generated_text']
56
+ response = model_response[next_index:]
57
+ try:
58
+ ind = response.index('[')
59
+ except:
60
+ ind = len(response)
61
+ return response[:ind]
62
+
63
+ def audio_response(response):
64
+ audio_stream="response_audio.mp3"
65
+ tts = gTTS(response)
66
+ tts.save(audio_stream)
67
+ with open(audio_stream, 'rb') as file:
68
+ audio_data = file.read()
69
+ audio_base64 = base64.b64encode(audio_data).decode('utf-8')
70
+ audio_tag = f'<audio autoplay="true" src="data:audio/mp3;base64,{audio_base64}">'
71
+ return audio_tag