File size: 3,073 Bytes
d7ba0dd
 
 
 
 
 
6235bf0
d7ba0dd
6235bf0
 
d7ba0dd
 
6235bf0
 
 
 
 
 
 
 
 
 
 
 
 
 
d7ba0dd
 
6235bf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d7ba0dd
 
 
6235bf0
d7ba0dd
 
6235bf0
 
 
 
d7ba0dd
6235bf0
d7ba0dd
 
 
 
 
 
 
 
 
 
 
 
 
6235bf0
d7ba0dd
 
6235bf0
d7ba0dd
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import os

import gradio as gr
import requests

# Credentials are read from the environment at import time; a missing
# variable raises KeyError before anything else in this module runs.
API_TOKEN = os.environ['API_TOKEN']
G_TRANS_API_TOKEN = os.environ['G_TRANS_API_TOKEN']

# URL template for the hosted model endpoint; `{}` is filled with a model id.
API_URL = 'https://api-inference.huggingface.co/models/{}'
# Base URL of the translation REST API ('/detect' is appended for detection).
G_TRANS_API = 'https://translation.googleapis.com/language/translate/v2'
# Bearer-token header sent with every model query.
headers = {'Authorization': f'Bearer {API_TOKEN}'}

def detect_lang(message):
    """Ask the translation API which language ``message`` is written in.

    Args:
        message: Text whose language should be detected.

    Returns:
        The decoded JSON body returned by the ``/detect`` endpoint.

    Raises:
        requests.exceptions.Timeout: If the service does not answer in time.
    """
    response = requests.get(
        f'{G_TRANS_API}/detect',
        params={'key': G_TRANS_API_TOKEN, 'q': message},
        # Without a timeout a stalled connection would block the UI callback
        # indefinitely.
        timeout=30,
    )
    return response.json()

def translate_src_to_en(message, src_lang):
    """Translate ``message`` from ``src_lang`` into English.

    Args:
        message: Text to translate.
        src_lang: Language code of ``message`` as understood by the
            translation API.

    Returns:
        The decoded JSON body returned by the translation endpoint.

    Raises:
        requests.exceptions.Timeout: If the service does not answer in time.
    """
    response = requests.get(
        G_TRANS_API,
        params={
            'key': G_TRANS_API_TOKEN,
            'source': src_lang,
            'target': 'en',
            'q': message,
            # The API treats input as HTML unless told otherwise and then
            # returns entity-escaped output (e.g. "&#39;"); chat messages are
            # plain text, so request plain-text handling.
            'format': 'text',
        },
        # Without a timeout a stalled connection would block the UI callback
        # indefinitely.
        timeout=30,
    )
    return response.json()

def translate_en_to_src(message, src_lang):
    """Translate the English text ``message`` into ``src_lang``.

    Args:
        message: English text to translate.
        src_lang: Target language code as understood by the translation API.

    Returns:
        The decoded JSON body returned by the translation endpoint.

    Raises:
        requests.exceptions.Timeout: If the service does not answer in time.
    """
    response = requests.get(
        G_TRANS_API,
        params={
            'key': G_TRANS_API_TOKEN,
            'source': 'en',
            'target': src_lang,
            'q': message,
            # The API treats input as HTML unless told otherwise and then
            # returns entity-escaped output (e.g. "&#39;"); model replies are
            # plain text, so request plain-text handling.
            'format': 'text',
        },
        # Without a timeout a stalled connection would block the UI callback
        # indefinitely.
        timeout=30,
    )
    return response.json()

def query_model(model_id, payload):
    """POST ``payload`` as JSON to the endpoint for ``model_id``.

    Returns the decoded JSON body of the response.

    NOTE(review): a function with this same name is defined again further
    down in this file. That later ``def`` rebinds the name, so this
    definition is never the one called and is a candidate for removal.
    """
    response = requests.post(API_URL.format(model_id), headers=headers, json=payload)
    return response.json()

def parse_model_response(response):
    """Return the ``generated_text`` of the first entry in ``response``.

    ``response`` is indexed as a sequence of mappings; an empty sequence
    raises ``IndexError`` and a missing key raises ``KeyError``.
    """
    first_candidate = response[0]
    generated = first_candidate['generated_text']
    return generated

def parse_model_error(response):
    """Format an error payload from the model API as a user-facing message.

    Args:
        response: Decoded error body. Must contain ``"error"``; may contain
            ``"estimated_time"``, a number that is truncated to an integer
            and reported as seconds.

    Returns:
        ``"<error>. Please wait about <n> seconds."`` when an estimated time
        is present, otherwise just ``"<error>."``.

    Raises:
        KeyError: If ``response`` has no ``"error"`` entry.
    """
    message = f'{response["error"]}.'
    estimated_time = response.get('estimated_time')
    if estimated_time is None:
        # Error bodies do not always carry a wait estimate; indexing the key
        # unconditionally turned such errors into a KeyError instead of a
        # readable message.
        return message
    return f'{message} Please wait about {int(estimated_time)} seconds.'

def parse_translation_response(response):
    """Return the ``translatedText`` of the first translation in ``response``.

    The value is read from ``response['data']['translations'][0]``; a
    missing key raises ``KeyError`` and an empty list raises ``IndexError``.
    """
    translations = response['data']['translations']
    first_translation = translations[0]
    return first_translation['translatedText']

def query_model(model_id, payload):
    """Send ``payload`` to a hosted model and return the decoded JSON reply.

    Args:
        model_id: Model identifier substituted into ``API_URL``.
        payload: JSON-serialisable request body.

    Returns:
        The decoded JSON body of the response, whatever its shape.

    Raises:
        requests.exceptions.Timeout: If the service does not answer in time.
    """
    response = requests.post(
        API_URL.format(model_id),
        headers=headers,
        json=payload,
        # Without a timeout a stalled connection would block the UI callback
        # indefinitely; generation can be slow, so allow a generous window.
        timeout=60,
    )
    return response.json()

# Conversation history as (user message, reply) tuples. `chat` appends to it
# and `clear_message` empties it.
# NOTE(review): this is a single module-level list, so it is presumably shared
# by every browser session served by this process - confirm whether
# per-session state (e.g. gr.State) is wanted.
state = []

def chat(message, multi):
    """Answer ``message`` with the health chatbot and record the exchange.

    When ``multi`` is true the message language is detected first. A
    non-English message is translated to English before it is sent to the
    model, and the model's reply is translated back into that language.

    Args:
        message: Text typed by the user.
        multi: Whether to run the detect/translate round trip.

    Returns:
        A 2-tuple: an update that empties the input textbox, and the
        module-level ``state`` list of ``(message, reply)`` pairs.
    """
    message_en = message
    # Language the reply must be translated back into; None means the
    # conversation stays in English and no translation is performed.
    reply_lang = None
    if multi:
        detection = detect_lang(message)
        detected = detection['data']['detections'][0][0]['language']
        # Only the primary subtag decides whether translation is needed.
        primary = detected.split('-')[0].lower()
        # "und" is the code for an undetermined language; it cannot be used
        # as a translation source, so such messages are passed through as-is.
        if primary not in ('en', 'und'):
            # Use the code exactly as reported. Cutting it to two characters
            # corrupts three-letter codes ("haw" would become "ha", a
            # different language) and drops variants such as "zh-TW".
            reply_lang = detected
            translated = translate_src_to_en(message, reply_lang)
            message_en = parse_translation_response(translated)
    response = query_model('IssakaAI/health-chatbot', {
        'inputs': message_en,
        'parameters': {
            'max_length': 500,
        },
    })
    reply = ''
    if isinstance(response, list):
        # Assumes the generated text starts with the prompt followed by one
        # separator character, which is sliced off here - TODO confirm.
        reply = parse_model_response(response)[len(message_en) + 1:]
        if reply_lang is not None:
            translated = translate_en_to_src(reply, reply_lang)
            reply = parse_translation_response(translated)
    elif isinstance(response, dict):
        # A dict body is treated as an error report from the model API.
        reply = parse_model_error(response)
    state.append((message, reply))
    return gr.Textbox.update(value=''), state

def clear_message():
    """Forget the stored conversation and blank the chatbot widget.

    Empties the module-level ``state`` list in place and returns an update
    that sets the chatbot's value to an empty list.
    """
    del state[:]
    empty_history = []
    return gr.Chatbot.update(value=empty_history)

# Build the UI: a title, then one row holding the chat history next to a box
# with the input controls.
with gr.Blocks() as blk:
    gr.Markdown('# Interact with IssakaAI NLP models')
    with gr.Row():
        chatbot = gr.Chatbot()
        with gr.Box():
            # Input textbox, pre-filled with an example question.
            message = gr.Textbox(value='What is the menstrual cycle?', lines=10)
            # Unchecked by default; passed to `chat` as its `multi` argument.
            multi = gr.Checkbox(False, label='Multilingual chatbot')
            send = gr.Button('Send', variant='primary')
            clear = gr.Button('Clear history', variant='secondary')
    # `chat` returns (textbox update, history), matching the two outputs here.
    send.click(fn=chat, inputs=[message, multi], outputs=[message, chatbot])
    clear.click(fn=clear_message, inputs=[], outputs=chatbot)
# Runs at import time; there is no `if __name__ == '__main__'` guard.
blk.launch(debug=True)