|
import torch |
|
import gradio as gr |
|
from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModelForSeq2SeqLM, AutoModelForCausalLM |
|
|
|
def _load_pair(model_cls, checkpoint):
    """Load and return ``(tokenizer, model)`` for ``checkpoint``, tokenizer first."""
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = model_cls.from_pretrained(checkpoint)
    return tokenizer, model


# Japanese BERT masked LM (used by MELCHIOR).
BERTTokenizer, BERTModel = _load_pair(AutoModelForMaskedLM, "cl-tohoku/bert-base-japanese")

# Multilingual cased BERT masked LM (used by BALTHASAR).
mBERTTokenizer, mBERTModel = _load_pair(AutoModelForMaskedLM, "bert-base-multilingual-cased")

# rinna Japanese GPT-2 causal LM (used by CASPER).
GPT2Tokenizer, GPT2Model = _load_pair(AutoModelForCausalLM, "rinna/japanese-gpt2-medium")

# Running log of every vote cast: each voter appends +1 (承認) or -1 (否定).
votes = []
|
def MELCHIOR(sue):
    """Cast MELCHIOR's vote on the proposal ``sue`` using the Japanese BERT MLM.

    The proposal is embedded in a dialogue prompt that ends in
    ``MELCHIOR 「[MASK]」``. The masked-LM logits of the tokens for "承認"
    (approve) and "否定" (deny) are compared at that ``[MASK]`` position.

    Args:
        sue: The proposal text typed by the user.

    Returns:
        "承認" if the approve logit is strictly larger, otherwise "否定".

    Side effects:
        Appends +1 (approve) or -1 (deny) to the module-level ``votes`` list.
    """
    # Index 0 is presumably the leading [CLS] token; index 1 is taken as the
    # word's id (assumes each word is a single vocabulary entry — TODO confirm).
    allow = BERTTokenizer("承認").input_ids[1]
    deny = BERTTokenizer("否定").input_ids[1]

    prompt = (
        'MELCHIORは科学者としての人格を持っています。人間とMELCHIORの対話です。人間「'
        + sue
        + '。承認 か 否定 のどちらかで答えてください。」'
        + "MELCHIOR 「[MASK]」"
    )
    encoded = BERTTokenizer(prompt, return_tensors="pt")

    # Inference only: no need to record operations for autograd.
    with torch.no_grad():
        logits = BERTModel(**encoded).logits

    # Find the prompt's own [MASK] instead of assuming it sits at a fixed offset
    # from the end. Take the last occurrence in case the user text itself
    # contains "[MASK]".
    mask_positions = (encoded.input_ids[0] == BERTTokenizer.mask_token_id).nonzero()
    mask = logits[0, mask_positions[-1].item(), :]

    approved = bool(mask[allow] > mask[deny])
    votes.append(1 if approved else -1)
    return "承認" if approved else "否定"
|
|
|
def BALTHASAR(sue):
    """Cast BALTHASAR's vote on the proposal ``sue`` using multilingual BERT.

    The proposal is embedded in a dialogue prompt that ends in
    ``BALTHASAR 「[MASK]」``. The masked-LM logits of the tokens "Yes" and "No"
    are compared at that ``[MASK]`` position.

    Args:
        sue: The proposal text typed by the user.

    Returns:
        "承認" if the "Yes" logit is strictly larger, otherwise "否定".

    Side effects:
        Appends +1 (approve) or -1 (deny) to the module-level ``votes`` list.
    """
    # Index 0 is presumably the leading [CLS] token; index 1 is taken as the
    # word's id (assumes "Yes"/"No" are single vocabulary entries — TODO confirm).
    allow = mBERTTokenizer("Yes").input_ids[1]
    deny = mBERTTokenizer("No").input_ids[1]

    prompt = (
        'BALTHASARは母としての人格を持っています。人間とBALTHASARの対話です。人間「'
        + sue
        + '。YesかNoか。」'
        + "BALTHASAR 「[MASK]」"
    )
    encoded = mBERTTokenizer(prompt, return_tensors="pt")

    # Inference only: no need to record operations for autograd.
    with torch.no_grad():
        logits = mBERTModel(**encoded).logits

    # Find the prompt's own [MASK] instead of assuming it sits at a fixed offset
    # from the end. Take the last occurrence in case the user text itself
    # contains "[MASK]".
    mask_positions = (encoded.input_ids[0] == mBERTTokenizer.mask_token_id).nonzero()
    mask = logits[0, mask_positions[-1].item(), :]

    approved = bool(mask[allow] > mask[deny])
    votes.append(1 if approved else -1)
    return "承認" if approved else "否定"
|
|
|
|
|
def CASPER(sue):
    """Cast CASPER's vote on the proposal ``sue`` using the Japanese GPT-2 LM.

    The proposal is embedded in a dialogue prompt ending in "カスパー「私は,".
    The next-token logits for "承認" (approve) and "否定" (deny) after that
    prompt are compared.

    Args:
        sue: The proposal text typed by the user.

    Returns:
        "承認" if the approve logit is strictly larger, otherwise "否定".

    Side effects:
        Appends +1 (approve) or -1 (deny) to the module-level ``votes`` list.
    """
    # NOTE(review): index 1 is taken as the word's token id, presumably because
    # the tokenizer emits a leading piece first. Confirm against the output of
    # GPT2Tokenizer("承認") for this checkpoint.
    allow = GPT2Tokenizer("承認").input_ids[1]
    deny = GPT2Tokenizer("否定").input_ids[1]

    inpt = GPT2Tokenizer(
        '女としての人格を持ったAI・カスパーと人間の対話です。人間「'
        + sue
        + '。これに承認か否定か。」'
        + "カスパー「私は,",
        return_tensors="pt",
    )

    # The final token is dropped before running the model. It is presumably an
    # end-of-sequence marker appended by the tokenizer, which must not precede
    # the continuation being scored (TODO confirm).
    # Inference only: no need to record operations for autograd.
    with torch.no_grad():
        logits = GPT2Model(
            input_ids=inpt.input_ids[:, :-1],
            attention_mask=inpt.attention_mask[:, :-1],
        ).logits[0]

    # Raw logits (not probabilities) for the token following the trimmed prompt.
    next_token = logits[-1]
    approved = bool(next_token[allow] > next_token[deny])
    votes.append(1 if approved else -1)
    return "承認" if approved else "否定"
|
|
|
|
|
def greet(sue):
    """Run the three voters on ``sue`` and format the majority decision.

    Voters are called in the order MELCHIOR, CASPER, BALTHASAR. Each returns
    "承認" (approve) or "否定" (deny).

    Args:
        sue: The proposal text typed by the user.

    Returns:
        A display string: each voter's label and verdict separated by spaces,
        then a "___" rule, then "|可決|" if a majority approved or "| 否決 |"
        otherwise, then a closing rule.
    """
    verdicts = [
        ("BERT-1", MELCHIOR(sue)),
        ("GPT-2", CASPER(sue)),
        ("mBERT-3", BALTHASAR(sue)),
    ]

    # Tally this request's own verdicts. Reading the tail of the shared
    # module-level ``votes`` list could mix in votes appended by a concurrent
    # request.
    approvals = sum(verdict == "承認" for _, verdict in verdicts)
    passed = approvals > len(verdicts) - approvals  # strict majority (+1/-1 sum > 0)

    summary = " ".join(label + verdict for label, verdict in verdicts)
    decision = "|可決|" if passed else "| 否決 |"
    return summary + "\n___\n\n" + decision + "\n___"
|
|
|
|
|
# Page styling: Shippori Mincho web font, black background, orange text/borders.
# NOTE(review): `weight:200%` is not a valid CSS property (perhaps `font-weight`
# or `font-size` was intended), so browsers ignore it. Left unchanged here.
css = (
    "@import url('https://fonts.googleapis.com/css2?family=Shippori+Mincho:wght@800&display=swap');"
    " .gradio-container {background-color: black}"
    " .gr-button {background-color: blue;color:black; weight:200%;font-family:'Shippori Mincho', serif;}"
    ".block{color:orange;}"
    " ::placeholder {font-size:35%}"
    " .gr-box {text-align: center;font-size: 125%;border-color:orange;background-color: #000000;weight:200%;font-family:'Shippori Mincho', serif;}"
    ":disabled {color: orange;opacity:1.0;}"
)

with gr.Blocks(css=css) as demo:
    # Proposal input ("enter a resolution (majority vote)").
    sue = gr.Textbox(label="NAGI System", placeholder="決議を入力(多数決)")
    # "File suit" button that triggers the vote.
    greet_btn = gr.Button("提訴")
    # Result box. The placeholder is a disclaimer: results come from prompting
    # pretrained models and the developer accepts no responsibility for them.
    output = gr.Textbox(
        label="決議",
        placeholder="本システムは事前学習モデルのpromptにより行われています.決議結果に対して当サービス開発者は一切の責任を負いません.",
    )
    greet_btn.click(fn=greet, inputs=sue, outputs=output)

demo.launch()