imagfff committed on
Commit a37cae1 · verified · 1 Parent(s): de8d87a

Update README.md

Files changed (1): README.md +111 -0
README.md CHANGED

This llama model was trained 2x faster with [Unsloth](https://github.com/unslothai/unsloth) and Huggingface's TRL library.

[<img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20made%20with%20love.png" width="200"/>](https://github.com/unslothai/unsloth)

# How to generate the submitted JSONL file

1. Install the required libraries

```bash
pip install unsloth
pip uninstall unsloth -y && pip install --upgrade --no-cache-dir "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
```
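
If the install succeeded, the following quick check (a minimal sketch, not part of the original recipe) should import cleanly and report a visible GPU:

```python
# Sanity check: confirm the packages import and a CUDA device is visible.
import torch
import unsloth  # noqa: F401  # importing is enough to surface install problems

print("torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
```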

2. Run the following code

```python
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List

import torch
from tqdm import tqdm
from unsloth import FastLanguageModel


@dataclass
class ModelConfig:
    model_name: str = "imagfff/llm-jp-3-13b-it"
    max_seq_length: int = 2048
    dtype: Any = None  # None lets Unsloth auto-detect the dtype
    load_in_4bit: bool = True
    token: str = "HF token"  # replace with your Hugging Face access token


def load_model(config: ModelConfig) -> tuple[Any, Any]:
    """Load the model and tokenizer."""
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=config.model_name,
        max_seq_length=config.max_seq_length,
        dtype=config.dtype,
        load_in_4bit=config.load_in_4bit,
        token=config.token,
    )
    FastLanguageModel.for_inference(model)  # switch to the faster inference mode
    return model, tokenizer


def load_datasets(file_path: str) -> List[Dict[str, Any]]:
    """Load the dataset from a JSONL file (tolerates records split across lines)."""
    datasets = []
    try:
        with open(file_path) as f:
            item = ""
            for line in f:
                line = line.strip()
                item += line
                if item.endswith("}"):  # a complete JSON object has accumulated
                    datasets.append(json.loads(item))
                    item = ""
        return datasets
    except (FileNotFoundError, json.JSONDecodeError) as e:
        raise RuntimeError(f"Failed to load the dataset: {e}") from e


def generate_prediction(model: Any, tokenizer: Any, input_text: str) -> str:
    """Run inference with the model."""
    prompt = f"### 指示\n{input_text}\n### 回答\n"
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            use_cache=True,
            do_sample=False,  # greedy decoding for reproducible output
            repetition_penalty=1.2,
        )
    # Keep only the text that follows the answer marker.
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return decoded.split("\n### 回答")[-1]


def save_results(results: List[Dict[str, Any]], output_path: str) -> None:
    """Save the results to a JSONL file."""
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with open(output_path, "w", encoding="utf-8") as f:
        for result in results:
            json.dump(result, f, ensure_ascii=False)
            f.write("\n")


def main():
    config = ModelConfig()
    model, tokenizer = load_model(config)

    datasets = load_datasets("./elyza-tasks-100-TV_0.jsonl")

    results = []
    for dt in tqdm(datasets, desc="Running inference"):
        prediction = generate_prediction(model, tokenizer, dt["input"])
        results.append(
            {"task_id": dt["task_id"], "input": dt["input"], "output": prediction}
        )

    model_basename = config.model_name.split("/")[-1]
    save_results(results, f"/content/{model_basename}_output.jsonl")


if __name__ == "__main__":
    main()
```
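
After the script finishes, it may help to spot-check the generated file before submitting. A minimal sketch, assuming the default ModelConfig (so the output lands at /content/llm-jp-3-13b-it_output.jsonl):

```python
import json

# Path follows from the default model_name in ModelConfig; adjust if you changed it.
path = "/content/llm-jp-3-13b-it_output.jsonl"

with open(path, encoding="utf-8") as f:
    records = [json.loads(line) for line in f]

# elyza-tasks-100-TV should yield 100 records, one per task.
print(len(records), "records")
print(records[0]["task_id"], "->", records[0]["output"][:80])
```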