Sadat07 commited on
Commit
886038e
1 Parent(s): 49841d4

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +49 -6
README.md CHANGED
@@ -57,13 +57,56 @@ The model's predictions may be biased or overly reliant on the training dataset,
57
  Use the code below to get started with the model.
58
 
59
  ```python
60
- from transformers import pipeline
61
- qa_pipeline = pipeline('question-answering', model='bert_squad')
62
 
63
- context = "BERT is a transformers model for natural language processing."
64
- question = "What is BERT used for?"
65
- result = qa_pipeline(question=question, context=context)
66
- print(result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  ```
68
 
69
  ## Training Details
 
57
  Use the code below to get started with the model.
58
 
59
  ```python
60
+ import torch
61
+ from transformers import AutoTokenizer, AutoModelForQuestionAnswering
62
 
63
+ # Load the model and tokenizer
64
+ model_name = "Sadat07/bert_squad"
65
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
66
+ model = AutoModelForQuestionAnswering.from_pretrained(model_name)
67
+
68
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
69
+ model.to(device)
70
+
71
+
72
+ context = """
73
+ The person who invented light was
74
+ Thomas Edison.He was born in 1879.
75
+ """
76
+ question = "When did Thomas Edison invent?"
77
+
78
+
79
+ inputs = tokenizer(question, context, return_tensors="pt", truncation=True, max_length=512)
80
+ input_ids = inputs["input_ids"].to(device)
81
+ attention_mask = inputs["attention_mask"].to(device)
82
+
83
+
84
+ print("Tokenized Input:", tokenizer.decode(input_ids[0]))
85
+
86
+ # Perform inference
87
+ with torch.no_grad():
88
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask)
89
+ start_scores = outputs.start_logits
90
+ end_scores = outputs.end_logits
91
+
92
+ # Logits
93
+ print("Start logits:", start_scores)
94
+ print("End logits:", end_scores)
95
+
96
+ # Get start and end indices
97
+ start_idx = torch.argmax(start_scores)
98
+ end_idx = torch.argmax(end_scores) + 1
99
+
100
+ # Decode the answer
101
+ if start_idx >= end_idx:
102
+ print("Model did not predict a valid answer. Please check context and question.")
103
+ else:
104
+ answer = tokenizer.convert_tokens_to_string(
105
+ tokenizer.convert_ids_to_tokens(input_ids[0][start_idx:end_idx])
106
+ )
107
+ print(f"Question: {question}")
108
+ print(f"Answer: {answer}")
109
+
110
  ```
111
 
112
  ## Training Details