Mizuiro-sakura committed a7f2cd0 (parent 91b112b): Update README.md

---
license: mit
language: ja
---
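This repository provides LoRA adapter weights for cyberagent/open-calm-large fine-tuned on the databricks-dolly instruction dataset. The example below loads the base model, attaches the adapter with PEFT, and generates a response to a Japanese instruction.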
```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "cyberagent/open-calm-large"
lora_weights = "Mizuiro-sakura/open-calm-large-finetuned-databricks-dolly"

# Load the base model
model = AutoModelForCausalLM.from_pretrained(model_name)

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Attach the LoRA adapter weights
model = PeftModel.from_pretrained(
    model,
    lora_weights,
    adapter_name=lora_weights
)

# Switch to evaluation mode
model.eval()

# Build an Alpaca-style prompt from an instruction and optional input
def generate_prompt(data_point):
    if data_point["input"]:
        return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{data_point["instruction"]}

### Input:
{data_point["input"]}

### Response:"""
    else:
        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{data_point["instruction"]}

### Response:"""

# Text-generation helper
def generate(instruction, input=None, max_tokens=256):
    # Run inference
    prompt = generate_prompt({"instruction": instruction, "input": input})
    input_ids = tokenizer(prompt, return_tensors="pt", truncation=True).input_ids.to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.75,
            top_k=40,
            no_repeat_ngram_size=2,
        )
    outputs = outputs[0].tolist()

    # Decoding is complete once we hit the EOS token
    if tokenizer.eos_token_id in outputs:
        eos_index = outputs.index(tokenizer.eos_token_id)
    else:
        eos_index = len(outputs)
    decoded = tokenizer.decode(outputs[:eos_index])

    # Extract only the response portion
    sentinel = "### Response:"
    sentinel_loc = decoded.find(sentinel)
    if sentinel_loc >= 0:
        print(decoded[sentinel_loc + len(sentinel):])
    else:
        print("Warning: Expected prompt template to be emitted. Ignoring output.")

generate("自然言語処理とは?")  # "What is natural language processing?"
```
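
The snippet above runs on the CPU; the original version hard-coded an `mps` device for the input ids without moving the model, which fails off Apple silicon, so the example now uses `model.device`. If you do want an accelerator, a minimal sketch of explicit device selection before calling `generate` (the device choice here is an assumption, not part of the original card):

```python
# Optional: move the model to an accelerator before generating.
# Device selection is an assumption; the original snippet targeted
# Apple-silicon "mps" only.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
model = model.to(device)
```

Because the generation helper sends `input_ids` to `model.device`, no other change is needed.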