Hackavist commited on
Commit
4190c41
·
verified ·
1 Parent(s): 8182c47

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -5
app.py CHANGED
@@ -3,7 +3,8 @@ from transformers import DistilBertTokenizerFast, DistilBertForQuestionAnswering
3
 
4
  model_name = "distilbert-base-cased"
5
  tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
6
- model = DistilBertForQuestionAnswering.from_pretrained(model_name)
 
7
 
8
  def format_response(start_index, end_index, raw_answer):
9
  answer_tokens = tokenizer.convert_tokens_to_string([tokenizer.convert_ids_to_tokens(i)[0] for i in range(start_index, end_index+1)])
@@ -22,16 +23,16 @@ def main():
22
  print("Type 'quit' to exit the program.")
23
  query = ""
24
  while True:
25
- query = input("Your Question: ")
26
  if query.lower() == "quit":
27
  break
28
- if len(query) > 0:
29
  context = "The capital of France is Paris."
30
  try:
31
  response = get_answers(query, context)
32
- print(f"Response: {response}")
33
  except Exception as e:
34
- print(f"Error occurred: {str(e)}")
35
 
36
  if __name__ == "__main__":
37
  main()
 
3
 
4
  model_name = "distilbert-base-cased"
5
  tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
6
+ # Set output_attentions and output_hidden_states to False to prevent weight initialization warnings
7
+ model = DistilBertForQuestionAnswering.from_pretrained(model_name, output_attentions=False, output_hidden_states=False)
8
 
9
  def format_response(start_index, end_index, raw_answer):
10
  answer_tokens = tokenizer.convert_tokens_to_string([tokenizer.convert_ids_to_tokens(i)[0] for i in range(start_index, end_index+1)])
 
23
  print("Type 'quit' to exit the program.")
24
  query = ""
25
  while True:
26
+ query = input("Your Question: ").strip()
27
  if query.lower() == "quit":
28
  break
29
+ elif len(query) > 0:
30
  context = "The capital of France is Paris."
31
  try:
32
  response = get_answers(query, context)
33
+ print(f"\nResponse: {response}\n")
34
  except Exception as e:
35
+ print(f"\nError occurred: {str(e)}\n")
36
 
37
  if __name__ == "__main__":
38
  main()