Hackavist commited on
Commit
6aa1bf6
1 Parent(s): 497fdf1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -33
app.py CHANGED
@@ -1,40 +1,53 @@
1
- import torch
2
- from transformers import DistilBertTokenizerFast, DistilBertForQuestionAnswering
3
- import json
4
  import streamlit as st
5
  from transformers import pipeline
6
 
7
- model_name = "distilbert-base-cased"
8
- tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
9
- model = DistilBertForQuestionAnswering.from_pretrained(model_name)
10
 
11
- def get_user_input():
12
- return st.chat_input("You:", "")
13
 
14
- def generate_response(prompt):
15
- try:
16
- response = model(prompt, max_length=50, do_sample=True)
17
- return response[0]['generated_text'].strip()
18
- except Exception as e:
19
- print(f"Error while generating response: {str(e)}")
20
- return "Sorry, I cannot respond right now."
21
 
22
- if __name__ == "__main__":
23
- st.set_page_config(page_title="AI Chat App", page_icon=":robot_face:", layout="wide")
24
-
25
- st.write("# Welcome to the AI Chat App! :wave:")
26
- st.markdown("<style>.container { width:100% !important; }</style>", unsafe_allow_html=True)
 
27
 
28
- user_input = ""
29
- while True:
30
- # Get user input
31
- user_input = get_user_input()
32
-
33
- if not user_input:
34
- continue
35
-
36
- # Generate response
37
- response = generate_response(user_input)
38
-
39
- # Display response
40
- st.markdown(f"<b>AI:</b> {response}", unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
 
 
2
  import streamlit as st
3
  from transformers import pipeline
4
 
5
+ # Initialize the chat history
6
+ history = []
 
7
 
8
+ def clean_text(text):
9
+ return re.sub('[^a-zA-Z\s]', '', text).strip()
10
 
11
+ # Load DistilBert model
12
+ model = pipeline("question-answering", model="distilbert-base-cased")
 
 
 
 
 
13
 
14
+ def generate_response(user_input):
15
+ # Add user input to history
16
+ history.append((user_input, ""))
17
+
18
+ if not history:
19
+ return ""
20
 
21
+ last_user_message = history[-1][0]
22
+ user_input = clean_text(last_user_message)
23
+
24
+ if len(user_input) > 0:
25
+ result = model(question=user_input, context="Placeholder text")
26
+ answer = result['answer']
27
+ history[-1] = (last_user_message, answer)
28
+
29
+ return f"AI: {answer}"
30
+
31
+ st.title("Simple Chat App using DistilBert Model (HuggingFace & Streamlit)")
32
+
33
+ for i in range(len(history)):
34
+ message = history[i][0]
35
+ response = history[i][1]
36
+
37
+ if i % 2 == 0:
38
+ col1, col2 = st.beta_columns([0.8, 0.2])
39
+ with col1:
40
+ st.markdown(f">> {message}")
41
+ with col2:
42
+ st.write("")
43
+ else:
44
+ col1, col2 = st.beta_columns([0.8, 0.2])
45
+ with col1:
46
+ st.markdown(f" {response}")
47
+ with col2:
48
+ st.button("Clear")
49
+
50
+ new_message = st.text_area("Type something...")
51
+ if st.button("Submit"):
52
+ generated_response = generate_response(new_message)
53
+ st.markdown(generated_response)