Update app.py
app.py CHANGED
@@ -1,16 +1,17 @@
 import gradio as gr
 from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline
 import json
+import torch
 
 # Define hyperparameters
-learning_rate = 3e-5
-batch_size =
-epochs =
-max_seq_length =
-warmup_steps =
-weight_decay = 0.01
-dropout_prob = 0.
-gradient_clip_value = 1.0
+learning_rate = 3e-5  # Slightly lower learning rate
+batch_size = 8  # Smaller batch size to allow for more precise updates
+epochs = 4  # Slightly more training epochs
+max_seq_length = 256  # Smaller sequence length, especially if the majority of your questions and contexts are shorter
+warmup_steps = 200  # Longer warmup phase
+weight_decay = 0.01  # Keep weight decay as it is
+dropout_prob = 0.2  # Slightly higher dropout for regularization
+gradient_clip_value = 1.0  # Keep gradient clip value as it is
 
 context_val = ''
 
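A note on the new hyperparameter block: the values are now concrete, but nothing in the visible diff consumes them. If the intent is to fine-tune the QA model, they would typically be wired into transformers.TrainingArguments. A minimal sketch under that assumption; the checkpoint name and train_dataset are placeholders, not from this Space:

from transformers import AutoModelForQuestionAnswering, Trainer, TrainingArguments

# Placeholder checkpoint; dropout is a model-config field, so it is set at load time.
model = AutoModelForQuestionAnswering.from_pretrained(
    "bert-base-uncased",
    hidden_dropout_prob=dropout_prob,  # 0.2
)

args = TrainingArguments(
    output_dir="qa-finetune",
    learning_rate=learning_rate,             # 3e-5
    per_device_train_batch_size=batch_size,  # 8
    num_train_epochs=epochs,                 # 4
    warmup_steps=warmup_steps,               # 200
    weight_decay=weight_decay,               # 0.01
    max_grad_norm=gradient_clip_value,       # 1.0
)

# max_seq_length (256) applies at tokenization time rather than here;
# train_dataset is a hypothetical pre-tokenized QA dataset.
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()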
@@ -35,10 +36,19 @@ def q_n_a_fn(context, text):
     with torch.no_grad():
         outputs = q_n_a_model(**inputs)
 
-    #
+    # Get the predicted answer span indices
     start_idx, end_idx = torch.argmax(outputs.start_logits), torch.argmax(outputs.end_logits)
-
-
+
+    # Ensure indices are within bounds
+    start_idx = min(start_idx, len(inputs["input_ids"][0]) - 1)
+    end_idx = min(end_idx, len(inputs["input_ids"][0]) - 1)
+
+    # Find the answer tokens in the input
+    answer_tokens = inputs["input_ids"][0][start_idx : end_idx + 1]
+
+    # Decode the answer tokens into a human-readable answer
+    answer = tokenizer.decode(answer_tokens, skip_special_tokens=True)
+
     return answer
 
 def classification_fn(text):
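Worth flagging on this hunk: the clamping keeps both indices in range, but because the start and end logits are argmaxed independently, end_idx can still land before start_idx, in which case answer_tokens is empty and the function returns an empty string. The pipeline helper imported at the top of app.py instead searches for the best valid (start <= end) span. A minimal sketch of that alternative, with a placeholder checkpoint:

from transformers import pipeline

# Placeholder checkpoint; the Space's actual model is never named in the diff.
qa = pipeline("question-answering", model="distilbert-base-cased-distilled-squad")

result = qa(
    question="What does the commit change?",
    context="The commit rewrites the answer decoding in app.py.",
)
print(result["answer"], result["score"])  # best valid span plus a confidence score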
@@ -75,4 +85,4 @@ with gr.Blocks(theme='gradio/soft') as demo:
         gr.Interface(fn=classification_fn, inputs=[context], outputs="text")
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
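For completeness: the hunks never show how inputs is built (tokenization happens in lines 17-34, outside this diff). A self-contained sketch of the decoding path the commit introduces, with an assumed tokenization step; the checkpoint, max_length, and example strings are placeholders:

import torch
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

checkpoint = "distilbert-base-cased-distilled-squad"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
q_n_a_model = AutoModelForQuestionAnswering.from_pretrained(checkpoint)

def answer_question(question, context):
    # Assumed tokenization; the diff only shows q_n_a_model(**inputs).
    inputs = tokenizer(question, context, return_tensors="pt",
                       truncation=True, max_length=256)
    with torch.no_grad():
        outputs = q_n_a_model(**inputs)
    # Same span selection as the commit: independent argmax over the logits.
    start_idx = torch.argmax(outputs.start_logits)
    end_idx = torch.argmax(outputs.end_logits)
    answer_tokens = inputs["input_ids"][0][start_idx : end_idx + 1]
    return tokenizer.decode(answer_tokens, skip_special_tokens=True)

print(answer_question("What color is the sky?", "The sky is blue."))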