Hackavist commited on
Commit
22a519a
1 Parent(s): 6092c26

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -0
app.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import DistilBertTokenizerFast, DistilBertForQuestionAnswering
3
+
4
+ model_name = "distilbert-base-cased"
5
+ tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
6
+ model = DistilBertForQuestionAnswering.from_pretrained(model_name)
7
+
8
+ def format_response(start_index, end_index, raw_answer):
9
+ answer_tokens = tokenizer.convert_tokens_to_string([tokenizer.convert_ids_to_tokens(i)[0] for i in range(start_index, end_index+1)])
10
+ return answer_tokens.strip()
11
+
12
+ def get_answers(question, context):
13
+ inputs = tokenizer.encode_plus(question, context, return_tensors="pt")
14
+ start_scores, end_scores = model(**inputs).values()
15
+ start_index = torch.argmax(start_scores)
16
+ end_index = torch.argmax(end_scores) + 1
17
+ formatted_answer = format_response(start_index, end_index - 1, context[start_index:end_index].tolist())
18
+ return formatted_answer
19
+
20
+ def main():
21
+ print("Hi! I am a simple AI chatbot built using Hugging Face.")
22
+ print("Type 'quit' to exit the program.")
23
+ query = ""
24
+ while True:
25
+ query = input("\nYour Question: ")
26
+ if query.lower() == "quit":
27
+ break
28
+ if len(query) > 0:
29
+ context = "The capital of France is Paris."
30
+ try:
31
+ response = get_answers(query, context)
32
+ print(f"Response: {response}")
33
+ except Exception as e:
34
+ print(f"Error occurred: {str(e)}")
35
+
36
+ if __name__ == "__main__":
37
+ main()