Vivek commited on
Commit
b1b7999
1 Parent(s): d006373
Files changed (2) hide show
  1. __pycache__/model_file.cpython-39.pyc +0 -0
  2. app.py +0 -67
__pycache__/model_file.cpython-39.pyc ADDED
Binary file (9 kB). View file
 
app.py DELETED
@@ -1,67 +0,0 @@
1
- import streamlit as st
2
- import transformers
3
- from transformers import (
4
- GPT2Config,
5
- GPT2Tokenizer)
6
- tokenizer = GPT2Tokenizer.from_pretrained("gpt2",pad_token='<|endoftext|>')
7
- from model_file import FlaxGPT2ForMultipleChoice
8
- import jax.numpy as jnp
9
-
10
- st.title('GPT2 for common sense reasoning')
11
- st.write('Multiple Choice Question Answering using CosmosQA Dataset')
12
-
13
- context=st.text_area('Context',height=25)
14
- st.write(context)
15
- #context = st.text_input('Context :')
16
-
17
-
18
-
19
-
20
-
21
- question=st.text_input('Question')
22
-
23
-
24
- buff, col, buff2 = st.beta_columns([5,1,2])
25
- choice_a=buff.text_input('choice 0:')
26
- choice_b=buff.text_input('choice 1:')
27
- choice_c=buff.text_input('choice 2:')
28
- choice_d=buff.text_input('choice 3:')
29
-
30
- a={}
31
- def preprocess(context,question,choice_a,choice_b,choice_c,choice_d):
32
- a['context&question']=context+question
33
- a['first_sentence']=[a['context&question'],a['context&question'],a['context&question'],a['context&question']]
34
- a['second_sentence']=choice_a,choice_b,choice_c,choice_d
35
- return a
36
-
37
- preprocessed_data=preprocess(context,question,choice_a,choice_b,choice_c,choice_d)
38
-
39
- def tokenize(examples):
40
- b=tokenizer(examples['first_sentence'],examples['second_sentence'],padding='max_length',truncation=True,max_length=256,return_tensors='jax')
41
- return b
42
-
43
- tokenized_data=tokenize(preprocessed_data)
44
-
45
-
46
- model = FlaxGPT2ForMultipleChoice.from_pretrained("flax-community/gpt2-Cosmos",input_shape=(1,4,1))
47
-
48
- input_id=jnp.array(tokenized_data['input_ids'])
49
- att_mask=jnp.array(tokenized_data['attention_mask'])
50
-
51
-
52
- if st.button("Run"):
53
- with st.spinner(text="Getting results..."):
54
- outputs=model(input_id,att_mask)
55
- final_output=jnp.argmax(outputs,axis=-1)
56
- if final_output==0:
57
- result='0'
58
- elif final_output==1:
59
- result='1'
60
- elif final_output==2:
61
- result='2'
62
- elif final_output==3:
63
- result='3'
64
- st.success(f"The answer is choice {result1}")
65
-
66
-
67
-