Spaces:
Runtime error
Runtime error
import torch | |
from transformers import AutoModelForCausalLM, T5Tokenizer | |
import csv, re, mojimoji | |
class Zmaker: | |
#GPT2のモデル名 | |
gpt_model_name = "rinna/japanese-gpt2-medium" | |
#文章の最大長 | |
min_len, max_len = 1, 128 | |
#予測時のパラメータ | |
top_k, top_p = 40, 0.95 #top-k検索の閾値 | |
num_text = 1 #出力する文の数 | |
temp = 0.1 | |
repeat_ngram_size = 1 | |
#推論にCPU利用を強制するか | |
use_cpu = True | |
def __init__(self, ft_path = None): | |
"""コンストラクタ | |
コンストラクタ。モデルをファイルから読み込む場合と, | |
新規作成する場合で動作を分ける. | |
Args: | |
ft_path : ファインチューニングされたモデルのパス. | |
Returns: | |
なし | |
""" | |
#モデルの設定 | |
self.__SetModel(ft_path) | |
#モデルの状態をCPUかGPUかで切り替える | |
if self.use_cpu: #CPUの利用を強制する場合の処理 | |
device = torch.device('cpu') | |
else: #特に指定が無いなら,GPUがあるときはGPUを使い,CPUのみの場合はCPUを使う | |
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') | |
self.model.to(device) | |
def __SetModel(self, ft_path = None): | |
"""GPT2の設定 | |
GPT2のTokenizerおよびモデルを設定する. | |
ユーザー定義後と顔文字も語彙として認識されるように設定する. | |
Args: | |
ft_path : ファインチューニング済みのモデルを読み込む | |
何も指定しないとself.gpt_model_nameの事前学習モデルを | |
ネットからダウンロードする. | |
Returns: | |
なし | |
""" | |
#GPT2のTokenizerのインスタンスを生成 | |
self.tokenizer = T5Tokenizer.from_pretrained(self.gpt_model_name) | |
self.tokenizer.do_lower_case = True # due to some bug of tokenizer config loading | |
#モデルの読み込み | |
if ft_path is not None: | |
self.model = AutoModelForCausalLM.from_pretrained( | |
ft_path, #torch_dtype = torch.bfloat16 | |
) | |
else: | |
print("fine-tuned model was not found") | |
#モデルをevalモードに | |
self.model.eval() | |
def __TextCleaning(self, texts): | |
"""テキストの前処理をする | |
テキストの前処理を行う.具体的に行うこととしては... | |
・全角/半角スペースの除去 | |
・半角数字/アルファベットの全角化 | |
""" | |
#半角スペース,タブ,改行改ページを削除 | |
texts = [re.sub("[\u3000 \t \s \n]", "", t) for t in texts] | |
#半角/全角を変換 | |
texts = [mojimoji.han_to_zen(t) for t in texts] | |
return texts | |
def GenLetter(self, prompt): | |
"""怪文書の生成 | |
GPT2で怪文書を生成する. | |
promptに続く文章を生成して出力する | |
Args: | |
prompt : 文章の先頭 | |
Retunrs: | |
生成された文章のリスト | |
""" | |
#テキストをクリーニング | |
prompt_clean = [prompt] | |
#文章をtokenizerでエンコード | |
x = self.tokenizer.encode( | |
prompt_clean[0], return_tensors="pt", | |
add_special_tokens=False | |
) | |
#デバイスの選択 | |
if self.use_cpu: #CPUの利用を強制する場合の処理 | |
device = torch.device('cpu') | |
else: #特に指定が無いなら,GPUがあるときはGPUを使い,CPUのみの場合はCPUを使う | |
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') | |
x = x.to(device) | |
#gpt2による推論 | |
with torch.no_grad(): | |
y = self.model.generate( | |
x, #入力 | |
min_length=self.min_len, # 文章の最小長 | |
max_length=self.max_len, # 文章の最大長 | |
do_sample=True, # 次の単語を確率で選ぶ | |
top_k=self.top_k, # Top-Kサンプリング | |
top_p=self.top_p, # Top-pサンプリング | |
temperature=self.temp, # 確率分布の調整 | |
no_repeat_ngram_size = self.repeat_ngram_size, #同じ単語を何回繰り返していいか | |
num_return_sequences=self.num_text, # 生成する文章の数 | |
pad_token_id=self.tokenizer.pad_token_id, # パディングのトークンID | |
bos_token_id=self.tokenizer.bos_token_id, # テキスト先頭のトークンID | |
eos_token_id=self.tokenizer.eos_token_id, # テキスト終端のトークンID | |
early_stopping=True | |
) | |
# 特殊トークンをスキップして推論結果を文章にデコード | |
res = self.tokenizer.batch_decode(y, skip_special_tokens=True) | |
return res |