"""Minimal extractive question-answering chatbot served with Streamlit.

A DistilBERT question-answering model picks the span of a context string that
best answers the user's question, and the result is shown as JSON.
"""
import json

import streamlit as st
import torch
from transformers import DistilBertForQuestionAnswering, DistilBertTokenizerFast

# The plain "distilbert-base-cased" checkpoint ships without a trained QA head,
# so from_pretrained() would initialise that head randomly and every "answer"
# would be noise. This checkpoint is cased DistilBERT fine-tuned on SQuAD.
model_name = "distilbert-base-cased-distilled-squad"
# NOTE(review): Streamlit re-executes this script on every interaction, so
# these objects are rebuilt each time; consider st.cache_resource if the
# installed Streamlit version provides it.
tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
model = DistilBertForQuestionAnswering.from_pretrained(model_name)

# Context used when interactive() is called without one.
DEFAULT_CONTEXT = "The capital of France is Paris."


def format_response(start_index, end_index, raw_answer, score=None):
    """Package an extracted answer span as a JSON-serialisable dict.

    Args:
        start_index: Position of the first answer token in the encoded input.
        end_index: Position of the last answer token (inclusive).
        raw_answer: The answer text, already sliced out of the context.
        score: Optional confidence for the span (``None`` if not computed).

    Returns:
        ``{'answer': <raw_answer stripped>, 'score': score}``.
    """
    # start_index/end_index are kept for backward compatibility with existing
    # callers; the text comes from raw_answer, which get_answers slices
    # directly out of the context so original casing and spacing survive.
    return {'answer': str(raw_answer).strip(), 'score': score}


def get_answers(question, context):
    """Return the span of ``context`` that best answers ``question``.

    Only tokens that belong to the context are eligible, and the chosen end
    token is never before the chosen start token. Contexts longer than the
    model's maximum input length are truncated.

    Args:
        question: The user's question.
        context: The text to extract the answer from.

    Returns:
        A dict ``{'answer': str, 'score': float | None}``. ``answer`` is empty
        and ``score`` is ``None`` when the context yields no tokens.
    """
    encoding = tokenizer(
        question,
        context,
        truncation="only_second",
        return_offsets_mapping=True,
        return_tensors="pt",
    )
    # Offsets are not a model input; pop them before the forward pass.
    offsets = encoding.pop("offset_mapping")[0].tolist()
    # sequence_ids: 0 = question token, 1 = context token, None = special.
    context_mask = torch.tensor(
        [segment == 1 for segment in encoding.sequence_ids(0)], dtype=torch.bool
    )
    if not bool(context_mask.any()):
        return format_response(0, 0, "", score=None)

    with torch.no_grad():
        outputs = model(**encoding)

    start_logits = outputs.start_logits[0].masked_fill(~context_mask, float("-inf"))
    start_index = int(torch.argmax(start_logits))

    # The end may only fall on a context token at or after the start.
    end_allowed = context_mask.clone()
    end_allowed[:start_index] = False
    end_logits = outputs.end_logits[0].masked_fill(~end_allowed, float("-inf"))
    end_index = int(torch.argmax(end_logits))

    # Product of the (masked) start and end probabilities of the chosen span.
    start_prob = torch.softmax(start_logits, dim=-1)[start_index]
    end_prob = torch.softmax(end_logits, dim=-1)[end_index]
    score = float(start_prob * end_prob)

    # Offsets of context tokens are character positions in `context`.
    char_start = offsets[start_index][0]
    char_end = offsets[end_index][1]
    return format_response(
        start_index, end_index, context[char_start:char_end], score=score
    )


def update_displayed_query(query):
    """Store ``query`` in session state under ``'current_query'``."""
    st.session_state['current_query'] = query


def interactive(context=DEFAULT_CONTEXT):
    """Render one pass of the chatbot UI and answer the current question.

    Streamlit re-executes the script on every widget interaction, so this
    function must not loop: it draws the input box once, answers whatever it
    currently holds, and returns.

    Args:
        context: Text that answers are extracted from.
    """
    st.write("Hi! I am a simple AI chatbot built using Hugging Face and Streamlit.")
    # Streamlit calls on_change with no arguments; read the new value from
    # session state instead of expecting it as a parameter.
    query = st.text_input(
        "Ask me something or type 'quit' to exit:",
        "",
        key='my_input',
        on_change=lambda: update_displayed_query(st.session_state['my_input']),
    )
    query = query.strip()
    st.session_state['shown_query'] = query

    if not query:
        return
    if query == 'quit':
        st.write("Goodbye! Clear the box above to ask another question.")
        return

    try:
        response = get_answers(query, context)
    except Exception as e:  # UI boundary: report instead of crashing the page
        st.write(f"Error occurred: {str(e)}")
    else:
        st.write(json.dumps(response))


if __name__ == "__main__":
    # set_page_config must be the first Streamlit command of the run.
    st.set_page_config(layout="wide")
    st.title('AI Chatbot Built Using Hugging Face and Streamlit')
    st.subheader('Welcome to Our Intelligent Conversational Agent!')
    st.write('Please enter your question below or type "quit" to close the session.')
    interactive()