Hackavist commited on
Commit
dc2cdaf
1 Parent(s): 2a7a1be

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -0
app.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import DistilBertTokenizerFast, DistilBertForQuestionAnswering
2
+ import json
3
+ import streamlit as st
4
+
5
+ model_name = "distilbert-base-cased"
6
+ tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
7
+ model = DistilBertForQuestionAnswering.from_pretrained(model_name)
8
+
9
+ def format_response(start_index, end_index, raw_answer):
10
+ answer_tokens = tokenizer.convert_tokens_to_string([tokenizer.convert_ids_to_tokens(i)[0] for i in range(start_index, end_index+1)])
11
+ return {'answer': answer_tokens.strip(), 'score': None}
12
+
13
+ def get_answers(question, context):
14
+ inputs = tokenizer.encode_plus(question, context, return_tensors="pt")
15
+ start_scores, end_scores = model(**inputs).values()
16
+ start_index = torch.argmax(start_scores)
17
+ end_index = torch.argmax(end_scores) + 1
18
+ formatted_answer = format_response(start_index, end_index - 1, context[start_index:end_index].tolist())
19
+ return formatted_answer
20
+
21
+ def interactive():
22
+ print("Hi! I am a simple AI chatbot built using Hugging Face.")
23
+ while True:
24
+ query = input("\nAsk me something or type 'quit' to exit:\n").lower().strip()
25
+ if query == "quit":
26
+ break
27
+
28
+ try:
29
+ # Add some basic context here; replace with your own dataset later
30
+ context = "The capital of France is Paris."
31
+ response = get_answers(query, context)
32
+ print(f"\n{json.dumps(response)}")
33
+ except Exception as e:
34
+ print(f"Error occurred: {str(e)}")
35
+
36
+ if __name__ == "__main__":
37
+ interactive()