File size: 1,259 Bytes
be9ad62
 
 
b52e32b
5b47eb8
f56f478
be9ad62
 
5b47eb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f56f478
6b96a20
f56f478
 
6b96a20
5b47eb8
29341fc
be9ad62
f56f478
be9ad62
5b47eb8
be9ad62
 
5b47eb8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Hugging Face Hub checkpoint served by this app.
model_id = "Mohammed-Altaf/medical_chatbot-8bit"

# ignore_mismatched_sizes=True makes loading proceed (with re-initialized
# weights) instead of raising when a checkpoint tensor's shape differs from the
# model config. NOTE(review): confirm this is actually required for this
# checkpoint — it can silently mask a broken or incompatible download.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    ignore_mismatched_sizes=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)


def get_clean_response(response):
    """Extract the assistant's reply from decoded model output.

    The text is split on "\\n" and scanned line by line for the
    "[|Human|]" / "[|AI|]" turn markers:

    * a line starting with "[|Human|]" is dropped, and marks that a human turn
      has been seen;
    * a line starting with "[|AI|]" is kept with the tag (and one following
      space, if present) removed;
    * any other line is kept only once a "[|Human|]" line has been seen, so
      prompt text echoed before the first human marker is discarded.

    Args:
        response: the decoded text, or a list whose first element is the
            decoded text (remaining elements are ignored).

    Returns:
        The kept lines, each terminated by "\\n"; "" when nothing is kept or
        when an empty list is passed.
    """
    human_tag = "[|Human|]"
    ai_tag = "[|AI|]"

    if isinstance(response, list):
        if not response:
            return ''
        response = response[0]

    kept = []
    seen_human = False  # becomes True once any "[|Human|]" line is encountered
    for line in response.split("\n"):
        if line.startswith(human_tag):
            seen_human = True
        elif line.startswith(ai_tag):
            # Strip the tag by length rather than splitting on spaces, so text
            # glued to the tag ("[|AI|]Hello") is not lost; drop only the single
            # separator space and keep any further spacing verbatim.
            text = line[len(ai_tag):]
            if text.startswith(' '):
                text = text[1:]
            kept.append(text)
        elif seen_human:
            kept.append(line)

    return ''.join(line + '\n' for line in kept)


def generate_text(input_text, *, max_new_tokens=100):
    """Generate a chatbot reply for *input_text* and return its cleaned text.

    Args:
        input_text: raw user text from the Gradio textbox; it is tokenized
            as-is (no chat-turn markers are added here).
        max_new_tokens: upper bound on the number of tokens generated after the
            prompt. Keyword-only so the single-input Gradio interface still
            calls this with just the text.

    Returns:
        The full decoded sequence (prompt plus continuation, special tokens
        skipped) after filtering through ``get_clean_response``.
    """
    # The tokenizer returns a BatchEncoding (input_ids + attention_mask), not
    # just input ids; it is unpacked into generate() below.
    inputs = tokenizer(input_text, return_tensors="pt")
    # Keep the tensors on the model's device (a no-op when both are on CPU).
    inputs = inputs.to(model.device)

    # max_new_tokens bounds only the continuation. The previous max_length=100
    # counted the prompt tokens too, so long questions left little or no room
    # for an answer.
    output = model.generate(**inputs, max_new_tokens=max_new_tokens)

    output_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return get_clean_response(output_text)


# One textbox in, one textbox out. Note that launch() runs at module level, so
# the Gradio server starts as soon as this file is executed.
iface = gr.Interface(
    fn=generate_text,
    inputs='text',
    outputs=['text'],
    title='Medical ChatBot',
)
iface.launch()