finall
app.py
ADDED
@@ -0,0 +1,85 @@
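# Streamlit demo: multiple-choice question answering on CosmosQA with a
# Flax/JAX GPT-2 checkpoint (flax-community/gpt2-Cosmos).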
import random

import jax
import jax.numpy as jnp
import streamlit as st
from transformers import GPT2Tokenizer

from model_file import FlaxGPT2ForMultipleChoice

# GPT-2 ships without a pad token, so reuse the end-of-text token for padding.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token='<|endoftext|>')

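# Page header and hard-coded CosmosQA-style demo examples: index i of each list
# below forms one example (question, context and four answer choices).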
st.title('GPT2 for common sense reasoning')
st.write('Multiple Choice Question Answering using CosmosQA Dataset')

change_example = st.checkbox("Try Random Examples")

qquestion = [
    'In the future , will this person go to see other bands play ?',
    'What may have been the reason for you and your roommate wanting a mellow night ?',
    'what may have happened before this event ?',
    'What may happen during one of your visits to Conneticut ?',
]
ccontext = [
    'Good Old War and person L : I saw both of these bands Wednesday night , and they both blew me away . seriously . Good Old War is acoustic and makes me smile . I really can not help but be happy when I listen to them ; I think it is the fact that they seemed so happy themselves when they played .',
    'Last night , the roomie and I were both feeling pretty mellow . My week has been over scheduled , and it wore me out and prevented me from working one day . So I did n\'t make plans for Friday night , much as I wanted to . Instead , we went with our moods . And our moods said , " hey , why do n\'t we go out for ice cream , walk to the video store , rent a movie , walk to the firehouse , chat up Fireman Dave and then walk home to watch the movie ? ',
    'Sensible people would have checked the weather or preemptively backed the truck into the carport . But that is not the kind of people we are , so I pulled on a sweatshirt , backed the other two cars out of the driveway , and had Mike back the truck in safely . I hope our neighbors were entertained . My pajamas legs were soaked .',
    'After spending a few days in New York City this week , I ventured into Connecticut to spend the weekend with some friends . We had great weather and I enjoyed reconnecting with them after way too long . We enjoyed some great meals including Lenny & Joe \'s for amazing fried clams and lobster rolls , the River Tavern in Chester , and Rourke Diner in Middletown ( try the " Irish Embassy " for breakfast ) .',
]
cchoice_0 = [
    'None of the above choices .',
    'I was not doing too much work that week .',
    'We left the truck and its load just get rained on',
    'I would avoid old relationships .',
]
cchoice_1 = [
    'This person likes music and likes to see the show , they will see other bands play .',
    'I was burnt out from too much work .',
    'None of the above choices .',
    'I would end up going to New York instead .',
]
cchoice_2 = [
    'This person only likes Good Old War and Person L , no other bands .',
    'My roommate was burnt out from too much work .',
    'We did not check weather to know if we should move the vehicles around to protect the load',
    'I will only eat Irish food .',
]
cchoice_3 = [
    'Other Bands is not on tour and this person can not see them .',
    'None of the above choices .',
    'The pickup fit around the cars and into carport',
    'I would eat some amazing food .',
]

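# Input widgets: the checkbox above swaps the default example for a randomly
# chosen one; every field remains editable.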
if change_example:
    number = random.randint(1, 3)
    height = 70
else:
    number = 0
    height = 35

context = st.text_area('Context', ccontext[number], height=height)
question = st.text_input('Question', qquestion[number])
# st.beta_columns in the original; the beta_ prefix was dropped in later Streamlit releases.
buff, col, buff2 = st.columns([5, 1, 2])
choice_a = buff.text_input('choice 0', cchoice_0[number])
choice_b = buff.text_input('choice 1', cchoice_1[number])
choice_c = buff.text_input('choice 2', cchoice_2[number])
choice_d = buff.text_input('choice 3', cchoice_3[number])

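# Build the model inputs: the same (context + question) prompt is paired with
# each of the four candidate answers.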
def preprocess(context, question, choice_a, choice_b, choice_c, choice_d):
    a = {}
    a['context&question'] = context + question
    a['first_sentence'] = [a['context&question']] * 4
    a['second_sentence'] = [choice_a, choice_b, choice_c, choice_d]
    return a


preprocessed_data = preprocess(context, question, choice_a, choice_b, choice_c, choice_d)


def tokenize(examples):
    # Four (prompt, choice) pairs, padded/truncated to the GPT-2 context length.
    return tokenizer(examples['first_sentence'], examples['second_sentence'],
                     padding='max_length', truncation=True, return_tensors='jax')


tokenized_data = tokenize(preprocessed_data)

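# Load the fine-tuned Flax checkpoint and score all four choices in a single
# (batch=1, num_choices=4, seq_len) forward pass; the highest logit wins.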
model = FlaxGPT2ForMultipleChoice.from_pretrained("flax-community/gpt2-Cosmos", input_shape=(1, 4, 1))

input_id = jnp.array(tokenized_data['input_ids']).reshape(1, 4, -1)
att_mask = jnp.array(tokenized_data['attention_mask']).reshape(1, 4, -1)

if st.button("Run"):
    with st.spinner(text="Getting results..."):
        outputs = model(input_id, att_mask)
        final_output = jnp.argmax(outputs, axis=-1)
        # Pull the winning choice index back to the host as a plain Python int.
        output = jax.device_get(final_output).item()
        st.success(f"The answer is choice {output}")
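The tokenize-and-reshape step above can be sanity-checked on its own, outside Streamlit and without the custom Flax head. The following is a minimal sketch of that step only, assuming nothing beyond the public GPT2Tokenizer API; the prompt and choices are placeholders, not values taken from the app:

from transformers import GPT2Tokenizer

tok = GPT2Tokenizer.from_pretrained("gpt2", pad_token='<|endoftext|>')

prompt = "Some context . Some question ?"                    # placeholder context + question
choices = ["choice 0", "choice 1", "choice 2", "choice 3"]   # placeholder answers

# Tokenize four (prompt, choice) pairs, padded to GPT-2's 1024-token context.
enc = tok([prompt] * 4, choices, padding='max_length', truncation=True, return_tensors='np')

input_ids = enc['input_ids'].reshape(1, 4, -1)          # (batch, num_choices, seq_len)
attention_mask = enc['attention_mask'].reshape(1, 4, -1)
print(input_ids.shape, attention_mask.shape)            # -> (1, 4, 1024) (1, 4, 1024)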