juanpasanper commited on
Commit
2f82811
1 Parent(s): b1fb2c7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -1
app.py CHANGED
@@ -1,5 +1,21 @@
1
  import gradio as gr
2
- model = Model.from_pretrained(model_name=este_si_me_sirvio.bin)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  model.load_state_dict(torch.load(juanpasanper/tigo_question_answer))
4
  def question_answer(context, question):
5
  predictions, raw_outputs = model.predict([{"context": context, "qas": [{"question": question, "id": "0",}],}])
 
1
  import gradio as gr
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ class Model(nn.Module):
6
+ def __init__(self, model_name='bert_model'):
7
+ super(Model, self).__init__()
8
+ self.bert = transformers.BertModel.from_pretrained(config['MODEL_ID'], return_dict=False)
9
+ self.bert_drop = nn.Dropout(0.0)
10
+ self.out = nn.Linear(config['HIDDEN_SIZE'], config['NUM_LABELS'])
11
+ self.model_name = model_name
12
+ def forward(self, ids, mask, token_type_ids):
13
+ _, o2 = self.bert(ids, attention_mask = mask, token_type_ids = token_type_ids)
14
+ bo = self.bert_drop(o2)
15
+ output = self.out(bo)
16
+ return output
17
+
18
+ model = Model(model_name=este_si_me_sirvio.bin)
19
  model.load_state_dict(torch.load(juanpasanper/tigo_question_answer))
20
  def question_answer(context, question):
21
  predictions, raw_outputs = model.predict([{"context": context, "qas": [{"question": question, "id": "0",}],}])