Transformers
Safetensors
Japanese
text-generation-inference
unsloth
llama
trl
Inference Endpoints
poprap commited on
Commit
f2636a2
1 Parent(s): 6530c70

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +68 -0
README.md CHANGED
@@ -31,6 +31,74 @@ LLM-jp-3-13bに対して以下のデータセットを用いてSFTを行った
31
  サンプルコード(ipynb)がレポジトリに含まれています。
32
  `dakesan0-inference-testcode.ipynb`
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  # Special thanks
35
 
36
  本コンペを運営いただいた方々に深く御礼申し上げます。
 
31
  サンプルコード(ipynb)がレポジトリに含まれています。
32
  `dakesan0-inference-testcode.ipynb`
33
 
34
+ unslothを用いた推論のみ動作を確認しています。
35
+
36
+ ```py
37
+ from unsloth import FastLanguageModel
38
+ from peft import PeftModel
39
+ import torch
40
+ import json
41
+ from tqdm import tqdm
42
+ import re
43
+ import datasets
44
+
45
+ model_id = "llm-jp/llm-jp-3-13b"
46
+ adapter_id = "poprap/llm-jp-3-13b-it-2-3"
47
+ adapter_dpo_id = "poprap/llm-jp-3-13b-dpo"
48
+
49
+ dtype = None
50
+ load_in_4bit = True
51
+
52
+ model, tokenizer = FastLanguageModel.from_pretrained(
53
+ model_name=model_id,
54
+ dtype=dtype,
55
+ load_in_4bit=load_in_4bit,
56
+ trust_remote_code=True,
57
+ )
58
+
59
+ model = PeftModel.from_pretrained(model, adapter_id, token = HF_TOKEN)
60
+ model = PeftModel.from_pretrained(model, adapter_dpo_id, token = HF_TOKEN)
61
+
62
+ ds = []
63
+
64
+ with open("elyza-tasks-100-TV_0.jsonl", "r") as f:
65
+ item = ""
66
+ for line in f:
67
+ line = line.strip()
68
+ item += line
69
+ if item.endswith("}"):
70
+ ds.append(json.loads(item))
71
+ item = ""
72
+
73
+ # 推論するためにモデルのモードを変更
74
+ FastLanguageModel.for_inference(model)
75
+
76
+ results = []
77
+ for dt in tqdm(ds):
78
+ input = dt["input"]
79
+
80
+ prompt = f"""### 指示\n{input}\n上記指示に簡潔に回答してください。\n### 回答\n"""
81
+
82
+ inputs = tokenizer([prompt], return_tensors = "pt").to(model.device)
83
+
84
+ outputs = model.generate(
85
+ **inputs,
86
+ max_new_tokens=1024,
87
+ use_cache = True,
88
+ do_sample=False,
89
+ repetition_penalty=1.2
90
+ )
91
+ prediction = tokenizer.decode(outputs[0], skip_special_tokens=True).split('\n### 回答')[-1]
92
+
93
+ results.append({"task_id": dt['task_id'], "input": input, "output": prediction})
94
+
95
+ json_file_id = re.sub(".*/", "", adapter_id)
96
+ with open(f"{json_file_id}_output.jsonl", 'w', encoding='utf-8') as f:
97
+ for result in results:
98
+ json.dump(result, f, ensure_ascii=False)
99
+ f.write('\n')
100
+ ```
101
+
102
  # Special thanks
103
 
104
  本コンペを運営いただいた方々に深く御礼申し上げます。