|
import re

import torch
import streamlit as st

from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
|
|
|
# Streamlit reruns the whole script on every interaction, so a plain
# module-level list would be reset on each rerun; session_state persists it.
if "history" not in st.session_state:
    st.session_state.history = []
history = st.session_state.history
|
|
|
def clean_text(text):
    # Strip everything except letters and whitespace (raw string avoids the
    # invalid-escape warning the original '\s' pattern triggers).
    return re.sub(r'[^a-zA-Z\s]', '', text).strip()
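
# Note: clean_text is never called below. One plausible use (an assumption,
# not part of the original flow) is sanitizing the user's message, e.g.
# user_input = clean_text(user_input) at the top of generate_response.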
|
|
|
# DialoGPT is a GPT-2-style causal LM; AutoModelForSeq2SeqLM cannot load it.
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")

# Guard the GPU/fp16 path so the app still runs on CPU-only machines.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small").to(device)
if device == "cuda":
    model = model.half()
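
# Aside (an assumption beyond the original flow): DialoGPT was trained with
# tokenizer.eos_token separating conversation turns, so joining history with
# eos_token instead of plain spaces generally produces better replies.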
|
|
|
def generate_response(user_input):
    # Guard empty input; the original's `if not history` check could never
    # fire because it ran after history.append().
    if not user_input.strip():
        return ""

    # Build the prompt from prior turns in chronological order (the original
    # joined them reversed), then append the new user message.
    past_turns = " ".join(f"{msg} {resp}".strip() for msg, resp in history)
    combined_messages = (past_turns + " User: " + user_input).strip()

    # DialoGPT's context window is 1024 tokens, so truncate directly to that
    # rather than encoding to 4096 and slicing afterwards. GPT-2-style models
    # take no segment ids, so none are built.
    tokens = tokenizer.encode(combined_messages, max_length=1024, truncation=True)
    input_ids = torch.tensor([tokens], dtype=torch.long).to(device)
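
    # Aside (not in the original): recent transformers releases also accept
    # max_new_tokens, which bounds only the generated continuation and is
    # usually clearer than max_length for causal models.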
|
|
|
    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            max_length=1024,
            min_length=20,
            length_penalty=2.0,
            early_stopping=True,
            num_beams=4,
            # bad_words_callback is not a transformers argument; bad_words_ids
            # is the supported way to ban the literal "User:" string.
            bad_words_ids=[tokenizer.encode("User:", add_special_tokens=False)],
            # GPT-2 has no pad token; reusing EOS silences the generate warning.
            pad_token_id=tokenizer.eos_token_id,
        )

    # generate() returns prompt + continuation; the original sliced from
    # len(tokens)-1 (off by one) and read from an undefined `output` variable.
    output = outputs[0].tolist()[len(tokens):]
    decoded_output = tokenizer.decode(output, skip_special_tokens=True)
|
|
|
    # Record the completed turn. str.capitalize() would lowercase everything
    # after the first character ("Ai: ..."), so it is dropped.
    history.append((user_input, decoded_output))
    return f"AI: {decoded_output}"
|
|
|
st.title("Simple Chat App using DialoGPT (HuggingFace & Streamlit)")
|
|
|
# The Clear button was a no-op in the original; checking it before the render
# loop resets the conversation on the same rerun as the click.
if st.button("Clear"):
    history.clear()

# Each history entry holds a complete (message, response) pair, so render both
# per entry; the original's i % 2 split over entries silently dropped turns.
# st.beta_columns was removed in Streamlit 1.x; st.columns replaces it.
for message, response in history:
    col1, col2 = st.columns([0.8, 0.2])
    with col1:
        st.markdown(f">> {message}")
        st.markdown(f"{response}")
    with col2:
        st.write("")
|
|
|
new_message = st.text_area("Type something...")
if st.button("Submit"):
    generated_response = generate_response(new_message)
    st.markdown(generated_response)
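
# Launch with: streamlit run <this_file>.py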