zupposhi-maker / Zmaker.py
Oishiyo's picture
コメント文の更新
93770f8
import torch
from transformers import AutoModelForCausalLM, T5Tokenizer
import csv, re, mojimoji
class Zmaker:
#GPT2のモデル名
gpt_model_name = "rinna/japanese-gpt2-medium"
#文章の最大長
min_len, max_len = 1, 128
#予測時のパラメータ
top_k, top_p = 40, 0.95 #top-k検索の閾値
num_text = 1 #出力する文の数
temp = 0.1
repeat_ngram_size = 1
#推論にCPU利用を強制するか
use_cpu = True
def __init__(self, ft_path = None):
"""コンストラクタ
コンストラクタ。モデルをファイルから読み込む場合と,
新規作成する場合で動作を分ける.
Args:
ft_path : ファインチューニングされたモデルのパス.
Returns:
なし
"""
#モデルの設定
self.__SetModel(ft_path)
#モデルの状態をCPUかGPUかで切り替える
if self.use_cpu: #CPUの利用を強制する場合の処理
device = torch.device('cpu')
else: #特に指定が無いなら,GPUがあるときはGPUを使い,CPUのみの場合はCPUを使う
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
self.model.to(device)
def __SetModel(self, ft_path = None):
"""GPT2の設定
GPT2のTokenizerおよびモデルを設定する.
ユーザー定義後と顔文字も語彙として認識されるように設定する.
Args:
ft_path : ファインチューニング済みのモデルを読み込む
何も指定しないとself.gpt_model_nameの事前学習モデルを
ネットからダウンロードする.
Returns:
なし
"""
#GPT2のTokenizerのインスタンスを生成
self.tokenizer = T5Tokenizer.from_pretrained(self.gpt_model_name)
self.tokenizer.do_lower_case = True # due to some bug of tokenizer config loading
#モデルの読み込み
if ft_path is not None:
self.model = AutoModelForCausalLM.from_pretrained(
ft_path, #torch_dtype = torch.bfloat16
)
else:
print("fine-tuned model was not found")
#モデルをevalモードに
self.model.eval()
def __TextCleaning(self, texts):
"""テキストの前処理をする
テキストの前処理を行う.具体的に行うこととしては...
・全角/半角スペースの除去
・半角数字/アルファベットの全角化
"""
#半角スペース,タブ,改行改ページを削除
texts = [re.sub("[\u3000 \t \s \n]", "", t) for t in texts]
#半角/全角を変換
texts = [mojimoji.han_to_zen(t) for t in texts]
return texts
def GenLetter(self, prompt):
"""怪文書の生成
GPT2で怪文書を生成する.
promptに続く文章を生成して出力する
Args:
prompt : 文章の先頭
Retunrs:
生成された文章のリスト
"""
#テキストをクリーニング
prompt_clean = [prompt]
#文章をtokenizerでエンコード
x = self.tokenizer.encode(
prompt_clean[0], return_tensors="pt",
add_special_tokens=False
)
#デバイスの選択
if self.use_cpu: #CPUの利用を強制する場合の処理
device = torch.device('cpu')
else: #特に指定が無いなら,GPUがあるときはGPUを使い,CPUのみの場合はCPUを使う
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
x = x.to(device)
#gpt2による推論
with torch.no_grad():
y = self.model.generate(
x, #入力
min_length=self.min_len, # 文章の最小長
max_length=self.max_len, # 文章の最大長
do_sample=True, # 次の単語を確率で選ぶ
top_k=self.top_k, # Top-Kサンプリング
top_p=self.top_p, # Top-pサンプリング
temperature=self.temp, # 確率分布の調整
no_repeat_ngram_size = self.repeat_ngram_size, #同じ単語を何回繰り返していいか
num_return_sequences=self.num_text, # 生成する文章の数
pad_token_id=self.tokenizer.pad_token_id, # パディングのトークンID
bos_token_id=self.tokenizer.bos_token_id, # テキスト先頭のトークンID
eos_token_id=self.tokenizer.eos_token_id, # テキスト終端のトークンID
early_stopping=True
)
# 特殊トークンをスキップして推論結果を文章にデコード
res = self.tokenizer.batch_decode(y, skip_special_tokens=True)
return res