gojiteji commited on
Commit
df8a0e2
1 Parent(s): 6dba677

fixed enc-dec model

Browse files
Files changed (1) hide show
  1. app.py +5 -8
app.py CHANGED
@@ -29,16 +29,13 @@ def BALTHASAR(sue):#mT5
29
  p_answer=None
30
  probs=None
31
  i=0
32
- txt="<pad>"
33
- probs=mT5Model(inputs_embeds=encoder_output.last_hidden_state,decoder_input_ids=mT5Tokenizer(txt,return_tensors="pt").input_ids).logits[0]
34
- id=torch.argmax(probs[i+1])
35
- txt=txt+"<X>"
36
  i=i+1
37
- probs=mT5Model(inputs_embeds=encoder_output.last_hidden_state,decoder_input_ids=mT5Tokenizer(txt,return_tensors="pt").input_ids).logits[0]
38
- id=torch.argmax(probs[i+1])
39
  txt=txt+mT5Tokenizer.decode(id)
40
- votes.append(1 if probs[i+1][allow]>probs[i+1][deny] else -1)
41
- return "承認" if probs[i+1][allow]>probs[i+1][deny] else "否定"
42
 
43
  def CASPER(sue):#GPT2
44
  allow=GPT2Tokenizer("承認").input_ids[1]
 
29
  p_answer=None
30
  probs=None
31
  i=0
32
+ txt="<pad><X>"
 
 
 
33
  i=i+1
34
+ probs=mT5Model(inputs_embeds=encoder_output.last_hidden_state,decoder_input_ids=mT5Tokenizer(txt,return_tensors="pt").input_ids[:,:-1]).logits[0]
35
+ id=torch.argmax(probs[-1])
36
  txt=txt+mT5Tokenizer.decode(id)
37
+ votes.append(1 if probs[-1][allow]>probs[-1][deny] else -1)
38
+ return "承認" if probs[-1][allow]>probs[-1][deny] else "否定"
39
 
40
  def CASPER(sue):#GPT2
41
  allow=GPT2Tokenizer("承認").input_ids[1]