S-Dreamer committed
Commit ca0727a · verified · 1 Parent(s): 288b7ec

Update app.py

Files changed (1)
  1. app.py +12 -7
app.py CHANGED
```diff
@@ -1,26 +1,31 @@
 import gradio as gr
 from transformers import AutoModelForQuestionAnswering, AutoTokenizer
 import torch
+import torch.nn.functional as F
 
 # Load model and tokenizer
-MODEL_NAME = "your-hf-username/raft-qa"
+MODEL_NAME = "S-Dreamer/raft-qa-space"
 tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 model = AutoModelForQuestionAnswering.from_pretrained(MODEL_NAME)
 
 def answer_question(context, question):
-    inputs = tokenizer(question, context, return_tensors="pt", truncation=True, max_length=512)
+    inputs = tokenizer(
+        question, context, return_tensors="pt", truncation=True, max_length=512, stride=128, return_overflowing_tokens=True
+    )
     with torch.no_grad():
         outputs = model(**inputs)
 
-    start_scores, end_scores = outputs.start_logits, outputs.end_logits
-    start_idx = torch.argmax(start_scores)
-    end_idx = torch.argmax(end_scores) + 1
-    answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(inputs["input_ids"][0][start_idx:end_idx]))
+    start_probs = F.softmax(outputs.start_logits, dim=-1)
+    end_probs = F.softmax(outputs.end_logits, dim=-1)
+    start_idx = torch.argmax(start_probs)
+    end_idx = torch.argmax(end_probs) + 1
+
+    answer = tokenizer.decode(inputs["input_ids"][0][start_idx:end_idx], skip_special_tokens=True)
 
     return answer if answer.strip() else "No answer found."
 
 # Define UI
-with gr.Blocks(theme="soft") as demo:
+with gr.Blocks() as demo:
     gr.Markdown("# 🤖 RAFT: Retrieval-Augmented Fine-Tuning for QA")
     gr.Markdown("Ask a question based on the provided context and see how RAFT improves response accuracy!")
```
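Two notes on the revised `answer_question`, neither of which is part of this commit. First, softmax is monotonic, so `torch.argmax` over `start_probs` and `end_probs` picks the same indices as argmax over the raw logits; the probabilities only earn their keep if surfaced as a confidence score. Second, with a fast tokenizer, `return_overflowing_tokens=True` yields one encoded row per 512-token window plus an `overflow_to_sample_mapping` key that `model(**inputs)` will reject, the windows need padding to stack into a single tensor, and a flat argmax over `(num_chunks, seq_len)` logits can land outside the first window even though only `inputs["input_ids"][0]` is decoded. A minimal sketch of per-window handling, reusing the `tokenizer` and `model` loaded above (the function name and the span-scoring heuristic are illustrative assumptions, not the Space's actual code):

```python
import torch
import torch.nn.functional as F

def answer_question_chunked(context, question):
    # Tokenize into overlapping 512-token windows, as in the commit;
    # padding=True lets unequal windows stack into one "pt" tensor.
    inputs = tokenizer(
        question, context, return_tensors="pt", truncation=True,
        max_length=512, stride=128, return_overflowing_tokens=True,
        padding=True,
    )
    # Fast tokenizers add this bookkeeping key; model.forward() rejects it.
    inputs.pop("overflow_to_sample_mapping", None)

    with torch.no_grad():
        outputs = model(**inputs)  # logits have shape (num_chunks, seq_len)

    best_answer, best_score = "", float("-inf")
    for i in range(inputs["input_ids"].size(0)):
        start_probs = F.softmax(outputs.start_logits[i], dim=-1)
        end_probs = F.softmax(outputs.end_logits[i], dim=-1)
        start_idx = int(torch.argmax(start_probs))
        end_idx = int(torch.argmax(end_probs)) + 1
        if end_idx <= start_idx:  # skip degenerate spans (end before start)
            continue
        # Score the window by its best start/end probabilities; a fuller
        # version would also mask pad positions via attention_mask.
        score = float(start_probs[start_idx] * end_probs[end_idx - 1])
        if score > best_score:
            span = tokenizer.decode(
                inputs["input_ids"][i][start_idx:end_idx],
                skip_special_tokens=True,
            )
            best_answer, best_score = span, score
    return best_answer if best_answer.strip() else "No answer found."
```

Scoring each window by the product of its best start and end probabilities and keeping the highest-scoring span is a common heuristic for long contexts; it also gives the stride-overlapped windows a fair chance to recover answers that straddle a 512-token boundary.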