zhaoyi3264 commited on
Commit
6235bf0
·
1 Parent(s): b4afea3

Add multilingual support

Browse files
Files changed (1) hide show
  1. app.py +50 -12
app.py CHANGED
@@ -4,28 +4,65 @@ import gradio as gr
4
  import requests
5
 
6
  API_TOKEN = os.environ['API_TOKEN']
 
7
 
8
- state = []
9
-
10
  headers = {'Authorization': f'Bearer {API_TOKEN}'}
11
 
12
- def query(payload, model_id):
13
- API_URL = f'https://api-inference.huggingface.co/models/{model_id}'
14
- response = requests.post(API_URL, headers=headers, json=payload)
 
 
 
 
 
 
 
 
 
 
 
15
  return response.json()
16
 
17
- def chat(message):
18
- response = query({
19
- 'inputs': message,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  'parameters': {
21
  'max_length': 500,
22
  }
23
- }, 'IssakaAI/health-chatbot')
24
  reply = ''
25
  if isinstance(response, list):
26
- reply = response[0]['generated_text'][len(message) + 1:]
 
 
 
27
  elif isinstance(response, dict):
28
- reply = f'{response["error"]}. Please wait about {int(response["estimated_time"])} seconds.'
29
  state.append((message, reply))
30
  return gr.Textbox.update(value=''), state
31
 
@@ -39,8 +76,9 @@ with gr.Blocks() as blk:
39
  chatbot = gr.Chatbot()
40
  with gr.Box():
41
  message = gr.Textbox(value='What is the menstrual cycle?', lines=10)
 
42
  send = gr.Button('Send', variant='primary')
43
  clear = gr.Button('Clear history', variant='secondary')
44
- send.click(fn=chat, inputs=message, outputs=[message, chatbot])
45
  clear.click(fn=clear_message, inputs=[], outputs=chatbot)
46
  blk.launch(debug=True)
 
4
  import requests
5
 
6
  API_TOKEN = os.environ['API_TOKEN']
7
+ G_TRANS_API_TOKEN = os.environ['G_TRANS_API_TOKEN']
8
 
9
+ API_URL = 'https://api-inference.huggingface.co/models/{}'
10
+ G_TRANS_API = 'https://translation.googleapis.com/language/translate/v2'
11
  headers = {'Authorization': f'Bearer {API_TOKEN}'}
12
 
13
+ def detect_lang(message):
14
+ response = requests.get(G_TRANS_API+'/detect', params={'key': G_TRANS_API_TOKEN, 'q': message})
15
+ return response.json()
16
+
17
+ def translate_src_to_en(message, src_lang):
18
+ response = requests.get(G_TRANS_API, params={'key': G_TRANS_API_TOKEN, 'source': src_lang, 'target': 'en', 'q': message})
19
+ return response.json()
20
+
21
+ def translate_en_to_src(message, src_lang):
22
+ response = requests.get(G_TRANS_API, params={'key': G_TRANS_API_TOKEN, 'source': 'en', 'target': src_lang, 'q': message})
23
+ return response.json()
24
+
25
+ def query_model(model_id, payload):
26
+ response = requests.post(API_URL.format(model_id), headers=headers, json=payload)
27
  return response.json()
28
 
29
+ def parse_model_response(response):
30
+ return response[0]['generated_text']
31
+
32
+ def parse_model_error(response):
33
+ return f'{response["error"]}. Please wait about {int(response["estimated_time"])} seconds.'
34
+
35
+ def parse_translation_response(response):
36
+ return response['data']['translations'][0]['translatedText']
37
+
38
+ def query_model(model_id, payload):
39
+ response = requests.post(API_URL.format(model_id), headers=headers, json=payload)
40
+ return response.json()
41
+
42
+ state = []
43
+
44
+ def chat(message, multi):
45
+ message_en = message
46
+ if multi:
47
+ response = detect_lang(message)
48
+ lang = response['data']['detections'][0][0]['language'][:2]
49
+ if lang != 'en':
50
+ response = translate_src_to_en(message, lang)
51
+ message_en = parse_translation_response(response)
52
+ response = query_model('IssakaAI/health-chatbot', {
53
+ 'inputs': message_en,
54
  'parameters': {
55
  'max_length': 500,
56
  }
57
+ })
58
  reply = ''
59
  if isinstance(response, list):
60
+ reply = parse_model_response(response)[len(message_en) + 1:]
61
+ if multi and lang != 'en':
62
+ response = translate_en_to_src(reply, lang)
63
+ reply = parse_translation_response(response)
64
  elif isinstance(response, dict):
65
+ reply = parse_model_error(response)
66
  state.append((message, reply))
67
  return gr.Textbox.update(value=''), state
68
 
 
76
  chatbot = gr.Chatbot()
77
  with gr.Box():
78
  message = gr.Textbox(value='What is the menstrual cycle?', lines=10)
79
+ multi = gr.Checkbox(False, label='Multilingual chatbot')
80
  send = gr.Button('Send', variant='primary')
81
  clear = gr.Button('Clear history', variant='secondary')
82
+ send.click(fn=chat, inputs=[message, multi], outputs=[message, chatbot])
83
  clear.click(fn=clear_message, inputs=[], outputs=chatbot)
84
  blk.launch(debug=True)