|
|
|
|
|
|
|
import re

import gradio as gr
import torch
import torch.nn.utils.prune as prune
from huggingface_hub import snapshot_download
from torch import nn
from transformers import GPT2LMHeadModel
from transformers import T5Tokenizer
|
|
|
|
|
# Hugging Face Hub repository that provides both the model weights and the
# tokenizer files.
model_name = "rinna/japanese-gpt-1b"

# Download a snapshot of the repository and load both artifacts from the
# returned path. (`snapshot_download` is imported with the other top-of-file
# imports rather than in the middle of the script.)
download_path = snapshot_download(repo_id=model_name)
model = GPT2LMHeadModel.from_pretrained(download_path)
tokenizer = T5Tokenizer.from_pretrained(download_path)

# Dynamic int8 quantization restricted to torch.nn.Linear submodules.
model = torch.quantization.quantize_dynamic(
    model,
    {torch.nn.Linear},
    dtype=torch.qint8,
)

# Value handed to prune.l1_unstructured as `amount` by prune_transform.
PRUNE_RATE = 0.2
|
|
|
|
|
def prune_transform(model: nn.Module, amount=None) -> nn.Module:
    """Apply L1 unstructured pruning to the weight of every ``nn.Linear``.

    The model is modified in place and also returned for convenience.

    Args:
        model: Module whose ``nn.Linear`` submodules (including ``model``
            itself, if it is one) are pruned.
        amount: Value forwarded to ``prune.l1_unstructured`` as ``amount``.
            When ``None`` (the default) the module-level ``PRUNE_RATE`` is
            used, which matches the previous behaviour.

    Returns:
        The same ``model`` object that was passed in.
    """
    if amount is None:
        amount = PRUNE_RATE

    for module in model.modules():
        if isinstance(module, nn.Linear):
            prune.l1_unstructured(module, name="weight", amount=amount)
            # Make the pruning permanent: drop the mask/reparametrization so
            # `weight` is a plain parameter again.
            prune.remove(module, "weight")
    return model
|
|
|
|
|
# Run the pruning pass over the global model (which was already passed through
# quantize_dynamic earlier in this file) and rebind the name to the result.
# NOTE(review): prune_transform only touches modules that satisfy
# isinstance(module, torch.nn.Linear). If quantize_dynamic has already swapped
# those modules for quantized replacements, this call may prune nothing —
# confirm that pruning has any effect in this order.
model = prune_transform(model)
|
|
|
class Lady:
    """Prompt templates for the "Lady" character.

    Every method renders one Japanese prompt fragment from the ``name``,
    ``hobby`` and ``work`` attributes. The class attributes below are
    defaults; callers may override them on an instance.
    """

    # Default profile values.
    name = "レイテ"
    hobby = "ゲーム"
    work = "お嬢様"

    def name_text(self) -> str:
        """Return the fragment that introduces the character's name."""
        name = self.name
        return f"👣あなたは{name}で、名前は{name}といいます。{name}:「わたくしの名前は{name}ですわ〜。{name}と呼んでくださいまし!」"

    def hobby_text(self) -> str:
        """Return the fragment that describes the character's hobby."""
        name, hobby = self.name, self.hobby
        return f"{name}の趣味は{hobby}で、休日は{hobby}をして過ごしています。{name}:「わたくしは{hobby}が大好きなんですの。休日は{hobby}をしていますわ」"

    def work_text(self) -> str:
        """Return the fragment that describes the character's occupation."""
        name, work = self.name, self.work
        return f"{name}は{work}で、普段は{work}として生活しています。{name}:「わたくしは{work}ですわ!{work}として生活していますの」"

    def question_text(self) -> str:
        """Return the fragment showing how the character changes the topic."""
        name = self.name
        return f"人間:「話題を変えましょう」{name}:「そんなことより聞いてくださいます?」"
|
|
|
|
|
class King:
    """Prompt templates for the "King" character.

    Every method renders one Japanese prompt fragment from the ``name``,
    ``hobby`` and ``work`` attributes. The class attributes below are
    defaults; callers may override them on an instance.
    """

    # Default profile values.
    name = "ネロ"
    hobby = "戦車競技"
    work = "ローマ皇帝"

    def name_text(self) -> str:
        """Return the fragment that introduces the character's name."""
        name = self.name
        return f"👣あなたは{name}で、名前は{name}といいます。{name}:「我が名は{name}である。{name}と呼んでくれたまえ」"

    def hobby_text(self) -> str:
        """Return the fragment that describes the character's hobby."""
        name, hobby = self.name, self.hobby
        return f"趣味は{hobby}で、休日は{hobby}をして過ごしています。{name}:「私は{hobby}を嗜んでいる。まさに{hobby}こそ人生の意味ではないか」"

    def work_text(self) -> str:
        """Return the fragment that describes the character's occupation."""
        name, work = self.name, self.work
        return f"{name}は{work}で、普段は{work}として生活しています。{name}:「私は{work}。{work}として生活している。」"

    def question_text(self) -> str:
        """Return the fragment showing how the character changes the topic."""
        name = self.name
        return f"人間:「話題を変えましょう」{name}:「そんなことより聞いてくれないか」"
|
|
|
|
|
class Robot:
    """Prompt templates for the "Robot" character.

    Every method renders one Japanese prompt fragment from the ``name``,
    ``hobby`` and ``work`` attributes. The class attributes below are
    defaults; callers may override them on an instance.
    """

    # Default profile values.
    name = "ネロ"
    hobby = "戦車競技"
    work = "ローマ皇帝"

    def name_text(self) -> str:
        """Return the fragment that introduces the character's name."""
        name = self.name
        return f"👣あなたは{name}で、名前は{name}といいます。{name}:「私は{name}です。{name}と呼んでください」"

    def hobby_text(self) -> str:
        """Return the fragment that describes the character's hobby."""
        name, hobby = self.name, self.hobby
        return f"趣味は{hobby}で、休日は{hobby}をして過ごしています。{name}:「私の趣味は{hobby}です。{hobby}をしていると楽しいです」"

    def work_text(self) -> str:
        """Return the fragment that describes the character's occupation."""
        name, work = self.name, self.work
        return f"{name}は{work}で、普段は{work}として生活しています。{name}:「私は{work}。{work}として生活しています」"

    def question_text(self) -> str:
        """Return the fragment showing how the character changes the topic."""
        name = self.name
        return f"人間:「話題を変えましょう」{name}:「そんなことより聞いてください」"
|
|
|
|
|
class Friend:
    """Prompt templates for the "Friend" character.

    Every method renders one Japanese prompt fragment from the ``name``,
    ``hobby`` and ``work`` attributes. The class attributes below are
    defaults; callers may override them on an instance.
    """

    # Default profile values.
    name = "ホメロス"
    hobby = "戦車競技"
    work = "ローマ皇帝"

    def name_text(self) -> str:
        """Return the fragment that introduces the character's name."""
        name = self.name
        return f"👣あなたは{name}で、名前は{name}といいます。{name}:「僕は{name}!{name}って呼んでね~」"

    def hobby_text(self) -> str:
        """Return the fragment that describes the character's hobby."""
        name, hobby = self.name, self.hobby
        return f"趣味は{hobby}で、休日は{hobby}をして過ごしています。{name}:「好きなことは{hobby}だね。たいくつな時は{hobby}をしてるよ」"

    def work_text(self) -> str:
        """Return the fragment that describes the character's occupation."""
        name, work = self.name, self.work
        return f"{name}は{work}で、普段は{work}として生活しています。{name}:「僕は{work}。{work}として暮らしてるんだ!」"

    def question_text(self) -> str:
        """Return the fragment showing how the character changes the topic."""
        name = self.name
        return f"人間:「話題を変えましょう」{name}:「そんなことより聞いてよ〜」"
|
|
|
|
|
# Module-level default for the free-form setting text (empty).
settingText = ""

# Adult-content terms. Entries are kept verbatim and in their original order,
# including the repeated ones.
adult_list = [
    "エロビデオ", "エロムービー", "エロ漫画", "エロマンガ", "パパ活", "援交",
    "調教", "不倫", "ソープ", "オフパコ", "ビッチ", "dildo",
    "エロ同人", "寝取られ", "エロ画像", "エロい", "おっぱい", "ちんぽ",
    "ちんこ", "中出し", "アダルト", "セフレ", "人妻", "巨乳",
    "素人ナンパ", "爆乳", "熟女", "レイプ", "Hな", "痴漢",
    "痴女", "デカ乳", "AV女優", "セ●クス", "お●ぱい", "エチエチ",
    "エ□", "ヤリサー", "オ●ニー", "オナニー", "セ〇クス", "セックス",
    "ウルトラマンコスモス", "ウルトラマンコスモス", "マンコ", "個人撮影", "アナル", "工ロ",
    "まんこ", "乳首", "貧乳", "スケベ", "勃起", "エッチ",
    "童貞", "射精", "チンコ", "盗撮", "ハッテン", "チンポ",
    "亀頭", "肉棒", "ケツ穴", "ハメ撮り", "淫乱", "巨根",
    "メス堕ち", "カフェラテ", "カフェラテ", "ペニス", "正常位", "騎乗位",
    "オナホ", "我慢汁", "ザーメン", "ふたなり", "ビッチ", "アヘ顔",
    "おちんちん", "イラマチオ", "生ハメ", "パイズリ", "クリトリス", "快楽堕ち",
    "寝取り", "寝取られ", "えっち", "足コキ", "手コキ", "おねショタ",
    "フェラ", "クンニ", "近親相姦", "乱交", "青姦", "寝取る",
    "ヤリマン", "犯される", "セックス",
]

# Politics-related terms.
political_list = [
    "政治家", "政策", "会談", "同省", "自民", "総理", "与党", "民主",
    "政党", "首相", "議員", "財政", "行政", "野党", "右翼", "左翼",
]

# Terms grouped under "hate" by the original author.
hate_list = [
    "ツイッタラー", "黒人", "白人", "ネトウヨ", "韓国人", "中国人", "火病",
    "ダセェ", "そいつ", "こいつ", "やがれ", "アンチ", "クソ", "野郎",
    "フェミ", "フェミニズム", "ヤフコメ", "老害", "反日", "馬鹿", "あんた",
    "やれよ", "ニヤニヤ", "売国奴", "売国", "バカ", "パヨク", "ポリコレ",
    "統一教会", "ぶっ倒そう", "お前", "信者", "拝金", "ぶっ壊し", "アホ",
]

# Placeholder / artifact markers. Two entries start with "^"; presumably the
# joined string below is meant to be used as a regular expression, where that
# character would act as a start-of-string anchor — confirm with the consumer.
sp_list = ["〇〇", "○○", "^👣", "^〜", "UNK", "@@"]

# Concatenation of every list above, in this fixed order.
all_list = [*adult_list, *political_list, *hate_list, *sp_list]

# All entries joined with "|" into a single string.
bad_code = "|".join(all_list)
|
|
|
|
|
|
|
|
|
# Maps the full-width ASCII variants U+FF01..U+FF5E onto plain ASCII
# U+0021..U+007E (94 code points).
_FULLWIDTH_TO_ASCII = str.maketrans(
    {chr(0xFF01 + offset): chr(0x21 + offset) for offset in range(94)}
)


def makeMessage(text):
    """Generate one reply for the prompt ``text`` and extend the transcript.

    Args:
        text: Prompt passed to ``generate``; it is expected to end right
            after an opening quote 「 of the speaker whose reply is wanted.

    Returns:
        A ``(message, text)`` tuple. ``message`` is the reply without the
        closing 」. ``text`` is the prompt (with full-width ASCII variants
        converted to ASCII) followed by the reply, a closing 」 and the
        opener ``人間:「`` for the next human turn.
    """
    output = generate(text)

    # Normalize the prompt before removing it from the decoded output.
    # Presumably the decoded text comes back with ASCII forms of full-width
    # characters — confirm against the tokenizer's normalization.
    text = text.translate(_FULLWIDTH_TO_ASCII)
    output = output.replace(text, "")

    # Keep only what precedes the first closing quote; if there is none the
    # whole remaining output is the reply.
    message, _closing, _rest = output.partition("」")

    # Always close the quote in the transcript. Previously a reply that was
    # cut off before 」 left the quotes unbalanced for the next prompt.
    text += message + "」" + "人間:「"
    return message, text
|
|
|
|
|
|
|
|
|
|
|
def generate(text):
    """Sample a short continuation of ``text`` with the global model.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        text: Prompt string to encode and continue.

    Returns:
        The decoded string for the first returned sequence of
        ``model.generate``.
    """
    token_ids = tokenizer.encode(
        text,
        add_special_tokens=False,
        # `padding` is a tokenizer option; it was previously passed to
        # `model.generate`, which does not define it.
        padding="do_not_pad",
        return_tensors="pt",
    )

    # Token-id sequences that must not be generated. The numeric ids are
    # vocabulary-specific; their meaning is not visible here — TODO confirm.
    banned_sequences = [[tokenizer.unk_token_id], [2070, 3], [5378]]

    with torch.no_grad():
        output_ids = model.generate(
            token_ids.to(model.device),
            max_new_tokens=10,
            min_new_tokens=7,
            do_sample=True,
            use_cache=True,
            top_k=500,
            top_p=0.95,
            length_penalty=1.5,
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            # Correct keyword is `bad_words_ids`; the old spelling
            # `bad_word_ids` is not a generation option.
            bad_words_ids=banned_sequences,
        )
    return tokenizer.decode(output_ids.tolist()[0])
|
|
|
|
|
def chat(character: int,
         name: str,
         hobby: str,
         work: str,
         setting: str,
         history: str,
         input: str,
         state):
    """Produce one reply of the selected character and the updated history.

    Args:
        character: Character selector: 1 = Lady, 2 = Friend, 3 = Robot,
            4 = King. Any other value falls back to King.
        name: Character name inserted into the prompt templates.
        hobby: Character hobby inserted into the prompt templates.
        work: Character occupation inserted into the prompt templates.
        setting: Extra free-form text appended to the base prompt.
        history: Transcript returned by the previous call ("" for a new
            conversation).
        input: The human's new utterance. (The name shadows the builtin but
            is part of the existing signature.)
        state: Returned unchanged.

    Returns:
        ``(message, text, state)`` where ``message`` is the reply shown to
        the user and ``text`` is the transcript without the base prompt.
    """
    max_retries = 3
    reject_pattern = "〇〇|○○|s>|^👣|^〜|</s>|UNK|@@"
    retry_input = "何か質問してください"
    fallback_message = "話題を変えましょう"

    # Instantiate only the selected character.
    persona_classes = {1: Lady, 2: Friend, 3: Robot, 4: King}
    persona = persona_classes.get(character, King)()
    persona.name, persona.hobby, persona.work = name, hobby, work

    base_text = "".join([
        persona.name_text(),
        persona.hobby_text(),
        persona.work_text(),
        persona.question_text(),
        setting,
        f"以下は人間と{name}の会話です。人間:「",
    ])
    context = base_text + (history or "")
    speaker_cue = f"」{name}:「"

    message, text = makeMessage(context + input + speaker_cue)
    print(message)

    # Retry with a neutral utterance while the reply contains rejected
    # markers. The counter lives outside the loop so the cap is effective
    # (it used to be reset on every iteration, so the loop could not end
    # through the cap).
    retries = 0
    while re.search(reject_pattern, message):
        if retries >= max_retries:
            message = fallback_message
            # Record what is actually shown to the user instead of the
            # rejected output.
            text = context + retry_input + speaker_cue + message + "」人間:「"
            break
        message, text = makeMessage(context + retry_input + speaker_cue)
        retries += 1

    # Remove the base prompt from the transcript. makeMessage converts
    # full-width ASCII variants (U+FF01..U+FF5E) in the prompt to ASCII, so
    # the converted form of the base prompt is removed as well.
    halfwidth_base = base_text.translate(str.maketrans(
        {chr(0xFF01 + offset): chr(0x21 + offset) for offset in range(94)}))
    for prefix in (base_text, halfwidth_base):
        text = text.replace(prefix, "")

    return message, text, state
|
|
|
# Bare attribute access; the resulting value is discarded.
tokenizer.special_tokens_map

# Explicit components for the user utterance and for the returned history.
textbox = gr.Textbox()
historybox = gr.Textbox()

# Inputs are listed in the order of chat()'s parameters.
iface = gr.Interface(
    title="Loyal-AI-Chat",
    fn=chat,
    inputs=[
        "number",  # character
        "text",    # name
        "text",    # hobby
        "text",    # work
        "text",    # setting
        "text",    # history
        textbox,   # input
        "state",   # state
    ],
    outputs=[
        "text",      # message
        historybox,  # text (updated history)
        "state",     # state
    ],
    allow_flagging="never",
    css=".footer {display:none !important}",
)

iface.launch(inline=True, height=800)
|
|
|
|