|
import re

import streamlit as st

from transformers import pipeline

from transformers import AutoTokenizer, TFAutoModelForMaskedLM, TFAutoModelForSeq2SeqLM
|
|
|
|
|
# Chat transcript as a list of (user_message, ai_reply) tuples, appended to and
# updated by generate_response(). Streamlit re-runs this script from the top on
# every widget interaction, so a plain module-level list would be reset to []
# each time; the list therefore lives in session state and `history` is bound
# to that same object (in-place mutations persist across reruns).
if "history" not in st.session_state:
    st.session_state["history"] = []
history = st.session_state["history"]
|
|
|
# Raw string: "\s" in a plain literal is an invalid escape sequence
# (SyntaxWarning on Python 3.12+). Compiled once instead of per call.
_NON_LETTER_OR_SPACE = re.compile(r"[^a-zA-Z\s]")


def clean_text(text):
    """Remove every character that is not an ASCII letter or whitespace.

    Digits, punctuation and non-ASCII letters are dropped; the result is then
    stripped of leading/trailing whitespace.

    Args:
        text: The string to clean.

    Returns:
        The cleaned string (possibly empty).
    """
    return _NON_LETTER_OR_SPACE.sub("", text).strip()
|
|
|
# Streamlit re-executes this script on every interaction; keeping the tokenizer
# and model in session state avoids reloading them from disk on each click.
# NOTE(review): on Streamlit >= 1.18, @st.cache_resource would share a single
# copy across all sessions instead of one per session — consider switching.
if "t5_model" not in st.session_state:
    # t5-small is an encoder-decoder (seq2seq) checkpoint, so it is loaded with
    # the seq2seq auto class. No .half()/.cuda(): those are PyTorch methods that
    # TF/Keras models do not have; TensorFlow places ops on an available GPU itself.
    st.session_state["t5_tokenizer"] = AutoTokenizer.from_pretrained("t5-small")
    st.session_state["t5_model"] = TFAutoModelForSeq2SeqLM.from_pretrained("t5-small")

tokenizer = st.session_state["t5_tokenizer"]
model = st.session_state["t5_model"]
|
|
|
def generate_response(user_input):
    """Generate a model reply to *user_input* and record the exchange in ``history``.

    The prompt is ``"summarize: "`` followed by the earlier user messages
    (most recent first, joined with ``" . "``) and then the new message, each
    group introduced by ``"Human:"``. The reply comes from a 4-beam
    ``model.generate`` call capped at 256 tokens.

    Args:
        user_input: The text the user just submitted.

    Returns:
        ``"AI: <reply>"`` with only the reply's first letter upper-cased, or
        ``""`` when *user_input* is empty or whitespace-only (in which case
        nothing is added to ``history``).
    """
    # Reject blank input before touching history so no empty turn is recorded.
    if not user_input or not user_input.strip():
        return ""

    # Earlier user turns, newest first (ordering kept as before).
    # NOTE(review): chronological order may suit the model better — confirm.
    previous_messages = [msg for msg, _ in reversed(history)]
    history.append((user_input, ""))

    # Only emit the "previous messages" group when there is one, so a first
    # turn doesn't get an empty " Human:  . " prefix.
    segments = []
    if previous_messages:
        segments.append(" . ".join(previous_messages))
    segments.append(user_input)
    combined_messages = " Human: " + " . Human: ".join(segments)

    input_str = "summarize: " + combined_messages

    # A single sequence needs no padding. return_tensors="tf" already yields
    # batched tensors of shape (1, seq_len), ready for generate().
    encodings = tokenizer(input_str, return_tensors="tf")

    # No explicit device scope: TensorFlow places ops on an available GPU by default.
    output = model.generate(
        encodings["input_ids"],
        attention_mask=encodings["attention_mask"],
        max_length=256,
        num_beams=4,
        early_stopping=True,
    )

    predicted_sentence = tokenizer.decode(output[0], skip_special_tokens=True)

    history[-1] = (user_input, predicted_sentence)

    # str.capitalize() would lower-case everything after the first character
    # ("AI: Hello" -> "Ai: hello"); upper-case only the reply's first letter.
    reply = predicted_sentence[:1].upper() + predicted_sentence[1:]
    return f"AI: {reply}"
|
|
|
# The model loaded above is t5-small, so the title names T5.
st.title("Simple Chat App using T5 Model (HuggingFace & Streamlit)")

# One Clear control for the whole transcript. It is handled before the history
# is drawn so a click empties the transcript within the same rerun. history is
# cleared in place so any other reference to the same list sees the change.
if st.button("Clear"):
    history.clear()

# Each history entry is a (user_message, ai_reply) pair; show both halves.
for message, response in history:
    user_col, _ = st.columns([0.8, 0.2])
    with user_col:
        st.markdown(f">> {message}")

    reply_col, _ = st.columns([0.8, 0.2])
    with reply_col:
        st.markdown(f" {response}")

new_message = st.text_area("Type something...")

if st.button("Submit"):
    generated_response = generate_response(new_message)
    st.markdown(generated_response)