Greatmonkey commited on
Commit
bacc64d
1 Parent(s): 75ae447

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +117 -0
app.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoModelForQuestionAnswering, AutoTokenizer
3
+
4
+ def main():
5
+ # Custom CSS for styling
6
+ custom_css = """
7
+ <style>
8
+ body {
9
+ font-family: 'Arial', sans-serif;
10
+ background-color: #FF0000; /* Set your desired background color here */
11
+ }
12
+ .title {
13
+ font-size: 2.5em;
14
+ color: #333;
15
+ text-align: center;
16
+ padding: 1em;
17
+ background-color: #3498db;
18
+ color: #fff;
19
+ border-radius: 10px;
20
+ margin-bottom: 1em;
21
+ }
22
+ .input-container {
23
+ margin: 2em;
24
+ }
25
+ .button-container {
26
+ text-align: center;
27
+ }
28
+ .result-container {
29
+ margin: 2em;
30
+ padding: 1em;
31
+ background-color: #fff;
32
+ border-radius: 10px;
33
+ box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
34
+ }
35
+ .answer {
36
+ background-color: #3498db;
37
+ padding: 10px;
38
+ border-radius: 5px;
39
+ margin-top: 10px;
40
+ }
41
+ img {
42
+ max-width: 100%;
43
+ height: auto;
44
+ border-radius: 10px;
45
+ box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
46
+ }
47
+ </style>
48
+ """
49
+ st.markdown(custom_css, unsafe_allow_html=True)
50
+
51
+ # Title
52
+ st.markdown("<div class='title'>Question Answering with Transformers</div>", unsafe_allow_html=True)
53
+
54
+ # Model Selection Dropdown
55
+ model_name = st.selectbox("Select Model", ["deepset/roberta-base-squad2", "bert-large-uncased-whole-word-masking-finetuned-squad", "distilbert-base-cased-distilled-squad",
56
+ "bert-base-uncased",
57
+ "albert-base-v2"])
58
+ model = AutoModelForQuestionAnswering.from_pretrained(model_name)
59
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
60
+
61
+ # Get user input
62
+ st.markdown("<div class='input-container'>", unsafe_allow_html=True)
63
+ context = st.text_area("Enter the context (max 400 words):")
64
+ question = st.text_input("Enter your question:")
65
+ st.markdown("</div>", unsafe_allow_html=True)
66
+
67
+ if st.button("Get Answer"):
68
+ #st.markdown("<div class='result-container'>", unsafe_allow_html=True)
69
+ if not question or not context:
70
+ st.warning("Please enter both a question and a context.")
71
+ else:
72
+ # Tokenize input
73
+ try:
74
+ # Check word count in the context
75
+ if len(context.split()) > 400:
76
+ raise ValueError("Context exceeds 400 words limit.")
77
+
78
+ inputs = tokenizer.encode_plus(question, context, return_tensors='pt')
79
+
80
+ # Get predictions
81
+ outputs = model(**inputs)
82
+ start_logits = outputs.start_logits
83
+ end_logits = outputs.end_logits
84
+
85
+ # Get top N answer spans
86
+ top_n = 3
87
+ start_indexes = start_logits.argsort(dim=1, descending=True)[:, :top_n]
88
+ end_indexes = end_logits.argsort(dim=1, descending=True)[:, :top_n]
89
+
90
+ # Display detailed answers
91
+ st.subheader(f"Question: {question}")
92
+ for i in range(top_n):
93
+ start_index = start_indexes[0, i].item()
94
+ end_index = end_indexes[0, i].item()
95
+ answer = tokenizer.decode(inputs['input_ids'][0, start_index:end_index + 1])
96
+
97
+ # Highlight answer in context
98
+ highlighted_context = f"{context[:start_index]}**{context[start_index:end_index+1]}**{context[end_index+1:]}"
99
+
100
+ # Display confidence scores
101
+ confidence_start = start_logits[0, start_index].item()
102
+ confidence_end = end_logits[0, end_index].item()
103
+ if answer == "":
104
+ continue
105
+ else:
106
+ st.markdown(f"<div class='answer'><strong>Answer:</strong> {answer}<br>"
107
+ f"<strong>Confidence (Start):</strong> {confidence_start:.4f}<br>"
108
+ f"<strong>Confidence (End):</strong> {confidence_end:.4f}</div>", unsafe_allow_html=True)
109
+
110
+ except ValueError as ve:
111
+ st.error(str(ve))
112
+ except Exception as e:
113
+ st.error(f"An error occurred: {e}")
114
+ st.markdown("</div>", unsafe_allow_html=True)
115
+
116
+ if __name__ == "__main__":
117
+ main()