Update app.py
app.py CHANGED
Old side of the hunk (removed lines are prefixed with -):

@@ -2,53 +2,39 @@ import torch
 from transformers import DistilBertTokenizerFast, DistilBertForQuestionAnswering
 import json
 import streamlit as st
 
 model_name = "distilbert-base-cased"
 tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
 model = DistilBertForQuestionAnswering.from_pretrained(model_name)
 
-def ...
-    ...
 
-def ...
-    ...
-    start_index = torch.argmax(start_scores)
-    end_index = torch.argmax(end_scores) + 1
-    formatted_answer = format_response(start_index, end_index - 1, context[start_index:end_index].tolist())
-    return formatted_answer
-
-def update_displayed_query(query):
-    st.session_state['current_query'] = query
-
-def interactive():
-    #print("Hi! I am a simple AI chatbot built using Hugging Face and Streamlit.")
-    query = ""
-    while query != st.session_state.get('current_query', ''):
-        if ('current_query' not in st.session_state) or (st.session_state['current_query'] != query):
-            st.session_state['current_query'] = query
-            st.text_input("Ask me something or type 'quit' to exit:", "", key='my_input', on_change=handle_query_change, help="Enter your question here...")  # added the help parameter to show the hint
-            st.echo(False)  # hide the default output of the text input widget
-
-        if ('shown_query' not in st.session_state) or (st.session_state['shown_query'] != query):
-            st.session_state['shown_query'] = query
-            if len(query) > 0 and query != 'quit':
-                try:
-                    context = "The capital of France is Paris."
-                    response = get_answers(query, context)
-                    st.write(json.dumps(response))
-                except Exception as e:
-                    st.write(f"Error occurred: {str(e)}")
-
-        query = st.session_state.get('current_query', '')
 
 if __name__ == "__main__":
-    st.set_page_config(layout="wide")
-
-    st.
-    st.
-
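The bodies of the two removed helper functions are only partially captured in the diff view; what survives is the span-decoding tail that takes an argmax over start_scores and end_scores. Below is a minimal sketch of how such an extractive-QA helper is typically written against DistilBertForQuestionAnswering. It reuses the module-level tokenizer and model from the top of app.py; the name get_answers matches the call site in the removed interactive() loop, but the exact signature and variable names are assumptions. Note that the answer span is decoded from the tokenized inputs (token indices) rather than by slicing the raw context string, and that the plain distilbert-base-cased checkpoint carries no QA fine-tuning, so a SQuAD-tuned checkpoint such as distilbert-base-cased-distilled-squad would normally be loaded instead.

```python
import torch

# Minimal sketch (assumed signature and names, not the removed code itself).
# Reuses the module-level `tokenizer` and `model` defined at the top of app.py.
def get_answers(question, context):
    inputs = tokenizer(question, context, return_tensors="pt", truncation=True)
    with torch.no_grad():
        outputs = model(**inputs)
    start_index = torch.argmax(outputs.start_logits)   # index of the first answer token
    end_index = torch.argmax(outputs.end_logits) + 1   # one past the last answer token
    answer_ids = inputs["input_ids"][0][start_index:end_index]
    answer = tokenizer.decode(answer_ids, skip_special_tokens=True)
    return {"answer": answer, "start": int(start_index), "end": int(end_index) - 1}
```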
New side of the hunk (added lines are prefixed with +):

 from transformers import DistilBertTokenizerFast, DistilBertForQuestionAnswering
 import json
 import streamlit as st
+from transformers import pipeline
 
 model_name = "distilbert-base-cased"
 tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
 model = DistilBertForQuestionAnswering.from_pretrained(model_name)
 
+def get_user_input():
+    return st.chat_input("You:", "")
 
+def generate_response(prompt):
+    try:
+        response = model(prompt, max_length=50, do_sample=True)
+        return response[0]['generated_text'].strip()
+    except Exception as e:
+        print(f"Error while generating response: {str(e)}")
+        return "Sorry, I cannot respond right now."
 
 if __name__ == "__main__":
+    st.set_page_config(page_title="AI Chat App", page_icon=":robot_face:", layout="wide")
+
+    st.write("# Welcome to the AI Chat App! :wave:")
+    st.markdown("<style>.container { width:100% !important; }</style>", unsafe_allow_html=True)
+
+    user_input = ""
+    while True:
+        # Get user input
+        user_input = get_user_input()
+
+        if not user_input:
+            continue
+
+        # Generate response
+        response = generate_response(user_input)
+
+        # Display response
+        st.markdown(f"<b>AI:</b> {response}", unsafe_allow_html=True)
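As committed, generate_response calls the DistilBertForQuestionAnswering instance as if it were a text-generation pipeline (max_length, do_sample, a 'generated_text' field), so every call fails and falls through to the fallback message, and the while True loop sits awkwardly with Streamlit's model of rerunning the whole script on each interaction. The sketch below shows one way the intended chat behaviour could be wired up with the pipeline helper this commit imports; the SQuAD-tuned checkpoint, the fixed context string, and the st.cache_resource caching are assumptions for illustration, not part of the commit.

```python
# Hedged sketch, not the committed code: a chat-style QA app that works with
# Streamlit's rerun-per-interaction execution model (no blocking while loop).
import streamlit as st
from transformers import pipeline

CONTEXT = "The capital of France is Paris."  # placeholder context, as in the old version

@st.cache_resource  # load the pipeline once per server process, not on every rerun
def load_qa_pipeline():
    # Assumed checkpoint: a SQuAD-fine-tuned DistilBERT, so the QA head has trained weights.
    return pipeline("question-answering", model="distilbert-base-cased-distilled-squad")

st.set_page_config(page_title="AI Chat App", page_icon=":robot_face:", layout="wide")
st.write("# Welcome to the AI Chat App! :wave:")

qa = load_qa_pipeline()

# st.chat_input returns None until the user submits a message; each submission
# reruns the script, so no explicit loop is needed.
user_input = st.chat_input("Ask me something...")
if user_input:
    with st.chat_message("user"):
        st.write(user_input)
    try:
        result = qa(question=user_input, context=CONTEXT)
        answer = result["answer"]
    except Exception as exc:
        answer = f"Sorry, I cannot respond right now ({exc})."
    with st.chat_message("assistant"):
        st.write(answer)
```

Storing each exchange in st.session_state and replaying it above the chat input would turn this into a persistent conversation; the committed version keeps no history.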