subu4444 committed on
Commit f9b83a8
1 parent: e347113

Update app.py

Files changed (1): app.py (+22, -12)
app.py CHANGED
@@ -1,16 +1,17 @@
 import gradio as gr
 from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline
 import json
+import torch
 
 # Define hyperparameters
-learning_rate = 3e-5
-batch_size = 16
-epochs = 3
-max_seq_length = 512
-warmup_steps = 100
-weight_decay = 0.01
-dropout_prob = 0.1
-gradient_clip_value = 1.0
+learning_rate = 3e-5       # Slightly lower learning rate
+batch_size = 8             # Smaller batch size to allow for more precise updates
+epochs = 4                 # Slightly more training epochs
+max_seq_length = 256       # Smaller sequence length, especially if most questions and contexts are shorter
+warmup_steps = 200         # Longer warmup phase
+weight_decay = 0.01        # Keep weight decay as it is
+dropout_prob = 0.2         # Slightly higher dropout for regularization
+gradient_clip_value = 1.0  # Keep gradient clip value as it is
 
 context_val = ''
 
@@ -35,10 +36,19 @@ def q_n_a_fn(context, text):
     with torch.no_grad():
         outputs = q_n_a_model(**inputs)
 
-    # Decode and return the answer
+    # Get the predicted answer span indices
     start_idx, end_idx = torch.argmax(outputs.start_logits), torch.argmax(outputs.end_logits)
-    answer = tokenizer.decode(inputs["input_ids"][0][start_idx:end_idx+1])
-
+
+    # Ensure indices are within bounds
+    start_idx = min(start_idx, len(inputs["input_ids"][0]) - 1)
+    end_idx = min(end_idx, len(inputs["input_ids"][0]) - 1)
+
+    # Find the answer tokens in the input
+    answer_tokens = inputs["input_ids"][0][start_idx : end_idx + 1]
+
+    # Decode the answer tokens into a human-readable answer
+    answer = tokenizer.decode(answer_tokens, skip_special_tokens=True)
+
     return answer
 
 def classification_fn(text):
@@ -75,4 +85,4 @@ with gr.Blocks(theme='gradio/soft') as demo:
     gr.Interface(fn=classification_fn, inputs=[context], outputs="text")
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
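
The hunks above show only the changed region of q_n_a_fn, so the tokenization step and the model loading are not visible in this diff. Below is a minimal sketch of how the updated inference path might fit together; the checkpoint name, the tokenizer call, and the q_n_a_fn signature outside the shown hunk are assumptions, not taken from the commit.

    # Sketch only: reconstructs the inference path around the changed lines.
    # "deepset/roberta-base-squad2" is a placeholder checkpoint; the Space's
    # actual model name does not appear in this diff.
    import torch
    from transformers import AutoModelForQuestionAnswering, AutoTokenizer

    model_name = "deepset/roberta-base-squad2"  # assumption, not from the diff
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    q_n_a_model = AutoModelForQuestionAnswering.from_pretrained(model_name)

    def q_n_a_fn(context, text):
        # Encode question + context together, truncating to the configured length
        inputs = tokenizer(text, context, return_tensors="pt",
                           truncation=True, max_length=256)

        with torch.no_grad():
            outputs = q_n_a_model(**inputs)

        # Greedy span selection: argmax over start and end logits
        start_idx = torch.argmax(outputs.start_logits)
        end_idx = torch.argmax(outputs.end_logits)

        # Clamp both indices so a degenerate prediction cannot index
        # past the end of the sequence (the fix this commit introduces)
        last = len(inputs["input_ids"][0]) - 1
        start_idx = min(start_idx, last)
        end_idx = min(end_idx, last)

        # Slice out the answer tokens and decode them, dropping [CLS]/[SEP]
        answer_tokens = inputs["input_ids"][0][start_idx : end_idx + 1]
        return tokenizer.decode(answer_tokens, skip_special_tokens=True)

    print(q_n_a_fn("Gradio lets you build ML demos in Python.",
                   "What does Gradio do?"))

Note that the clamp guards against out-of-range indices but not against an inverted span (start past end), which still yields an empty slice and an empty answer; skip_special_tokens additionally strips special tokens that a raw decode would have left in the text.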