File size: 2,177 Bytes
8b8062c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11a7380
8b8062c
 
 
 
 
 
 
 
 
 
 
 
 
 
cd3eb22
8b8062c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import streamlit as st
from streamlit_chat import message
import requests
from transformers import AutoModelWithLMHead, AutoTokenizer

def _load_chat_model():
    """Load the DialoGPT tokenizer and the fine-tuned model weights.

    The tokenizer comes from the public ``microsoft/DialoGPT-small``
    checkpoint; the model weights are read from the local
    ``output-small-save`` directory.

    Returns:
        A ``(tokenizer, model)`` tuple, tokenizer loaded first.
    """
    # NOTE(review): AutoModelWithLMHead is deprecated in transformers;
    # AutoModelForCausalLM is the documented replacement for GPT-2 style
    # models such as DialoGPT — consider switching.
    return (
        AutoTokenizer.from_pretrained('microsoft/DialoGPT-small'),
        AutoModelWithLMHead.from_pretrained('output-small-save'),
    )


# Streamlit re-executes this whole script on every interaction, so an uncached
# load re-reads the weights from disk on every rerun. When the running
# Streamlit provides st.cache_resource, keep one loaded pair per server
# process; otherwise fall back to loading on every run as before.
# show_spinner=False keeps the first (uncached) call from rendering a spinner
# element ahead of the st.set_page_config() call further down the script.
if hasattr(st, "cache_resource"):
    _load_chat_model = st.cache_resource(show_spinner=False)(_load_chat_model)

tokenizer, model = _load_chat_model()

# Browser tab title and icon for the app.
st.set_page_config(page_title="COVID Doctor using DialoGPT", page_icon=":robot:")

# Hosted Hugging Face Inference API endpoint and auth header for DialoGPT-small.
# NOTE(review): query() below generates with the local model and references
# neither API_URL nor headers, yet the 'api_key' secret must still be
# configured or this lookup fails when the script starts.
API_URL = (
    "https://api-inference.huggingface.co/models/"
    "microsoft/DialoGPT-small"
)
headers = dict(Authorization=st.secrets['api_key'])

# Page heading plus a link to the project repository.
st.header("Hello - Welcome to COVID Doctor using DialoGPT")
st.markdown("[Github](https://github.com/rushic24/DialoGPT-Finetune)")

# Chat transcript kept in session state across reruns: 'generated' holds the
# bot replies and 'past' the user messages; the same index is the same turn.
for _state_key in ('generated', 'past'):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = []

def query(payload, max_new_tokens=100):
    """Generate one DialoGPT reply for ``payload["inputs"]["text"]``.

    The payload keeps the Hugging Face Inference API request shape used by
    the caller (``{"inputs": {...}, "parameters": {...}}``), but generation
    runs on the module-level ``model`` and ``tokenizer``.

    Args:
        payload: Request dict. ``payload["inputs"]["text"]`` is the user
            utterance. Optional ``payload["parameters"]`` entries are passed
            to ``model.generate`` as keyword arguments and override the
            sampling defaults below (e.g. ``repetition_penalty``).
        max_new_tokens: Maximum number of tokens generated after the prompt.

    Returns:
        ``{"generated_text": reply}``, where ``reply`` is the decoded
        continuation with special tokens removed.
    """
    # NOTE(review): "past_user_inputs" / "generated_responses" in the payload
    # are not fed to the model, so every reply is generated without the
    # earlier turns as context.
    input_ids = tokenizer.encode(
        payload["inputs"]["text"] + tokenizer.eos_token, return_tensors='pt'
    )
    prompt_len = input_ids.shape[-1]

    generate_kwargs = {
        # max_length includes the prompt tokens; a fixed 100 left long
        # prompts little or no room for a reply, so budget past the prompt.
        "max_length": prompt_len + max_new_tokens,
        "pad_token_id": tokenizer.eos_token_id,
        "no_repeat_ngram_size": 3,
        "do_sample": True,
        "top_k": 10,
        "top_p": 0.7,
        "temperature": 0.8,
    }
    # Apply caller-supplied generation parameters, which were previously
    # accepted and silently dropped.
    generate_kwargs.update(payload.get("parameters") or {})

    chat_history_ids = model.generate(input_ids, **generate_kwargs)
    # Keep only the newly generated tokens of the first returned sequence.
    reply_ids = chat_history_ids[0, prompt_len:]
    output = tokenizer.decode(reply_ids, skip_special_tokens=True)
    return {"generated_text": output}

def get_text():
    """Render the "You: " text box and return its current contents.

    The box is keyed ``"input"`` and pre-filled with a sample COVID question.
    """
    sample_question = (
        "I have shortness of breath and are worried, I don’t have a cough or "
        "sore throat, so they will not test me, should I do a private test?"
    )
    return st.text_input("You: ", sample_question, key="input")


# Read the current message; if it is non-empty, ask the model for a reply and
# append the exchange to the session transcript.
# NOTE(review): nothing here checks whether this input was already answered,
# so a rerun with the text box unchanged appends the same turn again.
user_input = get_text()

if user_input:
    conversation = {
        "past_user_inputs": st.session_state.past,
        "generated_responses": st.session_state.generated,
        "text": user_input,
    }
    output = query(
        {"inputs": conversation, "parameters": {"repetition_penalty": 1.33}}
    )
    st.session_state.past.append(user_input)
    st.session_state.generated.append(output["generated_text"])

# Show the transcript newest turn first; within each turn the bot reply is
# drawn above the user's message. Widget keys are derived from the turn index.
bot_replies = st.session_state['generated']
for turn in reversed(range(len(bot_replies))):
    message(bot_replies[turn], key=str(turn))
    message(st.session_state['past'][turn], is_user=True, key=f"{turn}_user")