base_model:
- google/gemma-2-9b
tags:
- text-generation-inference
- transformers
- unsloth
- gemma2
- trl
license: gemma
language:
- en
- ja
datasets:
- kanhatakeyama/wizardlm8x22b-logical-math-coding-sft_additional-ja
- kanhatakeyama/AutoMultiTurnByCalm3-22B
- kanhatakeyama/ramdom-to-fixed-multiturn-Calm3
Model Card for Model ID
Instruction tuning: this model was fine-tuned for instruction following on the datasets listed in the metadata above.
Usage
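The following notebook cells (Jupyter/omnicampus style; lines starting with `!` are shell commands) run inference with vLLM on ELYZA-tasks-100-TV and write the answers to a JSONL file.

```python
# Install the pinned vLLM version.
!pip install vllm==0.6.4.post1 --force-reinstall
# vLLM errors unless packaging==24.1 is installed; pin it after reinstalling vLLM.
!pip install packaging==24.1
```

```python
import vllm

print(vllm.__version__)
```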
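Because this setup depends on two exact pins (vLLM 0.6.4.post1 and packaging 24.1), a quick check of what is actually installed can save a confusing error later. A minimal sketch using the standard library's `importlib.metadata`; the helper name `find_version_mismatches` is hypothetical, and it compares version strings exactly rather than applying PEP 440 specifier rules.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Optional


def find_version_mismatches(pins: Dict[str, str]) -> Dict[str, Optional[str]]:
    """Map each package whose installed version differs from its pin to the installed version.

    The value is None when the package is not installed. Versions are compared
    as exact strings, not as PEP 440 specifiers.
    """
    mismatches: Dict[str, Optional[str]] = {}
    for package, wanted in pins.items():
        try:
            installed: Optional[str] = version(package)
        except PackageNotFoundError:
            installed = None
        if installed != wanted:
            mismatches[package] = installed
    return mismatches


# Pins from the usage notes: vLLM 0.6.4.post1 and packaging 24.1.
problems = find_version_mismatches({"vllm": "0.6.4.post1", "packaging": "24.1"})
print(problems or "All pinned versions are installed.")
```

An empty result means both pins are satisfied; otherwise the printed dict shows what is installed instead (or `None` if a package is missing).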
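```python
# Replace with the model you want to submit to the competition.
MODEL_NAME = "bay-llm/gemma-9b-SFT-90-16bit"

# Load the model with vLLM on a single GPU.
llm = vllm.LLM(
    model=MODEL_NAME,
    tensor_parallel_size=1,        # number of GPUs to shard the model across
    gpu_memory_utilization=0.95,   # fraction of GPU memory vLLM may reserve
    trust_remote_code=True,
    max_model_len=1024,            # max prompt + generated tokens per sequence
)
tokenizer = llm.get_tokenizer()
```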
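As a rough way to reason about `gpu_memory_utilization=0.95` and `max_model_len=1024`: vLLM reserves that fraction of GPU memory, loads the weights, and uses most of what remains for the KV cache. The sketch below does the back-of-the-envelope arithmetic. The parameter count and attention shape are assumptions taken from the public Gemma 2 9B configuration (verify them against the model's `config.json`), the 24 GiB GPU is hypothetical, and the estimate ignores activations, CUDA graphs, and Gemma 2's sliding-window layers.

```python
# Back-of-the-envelope GPU memory sketch. All constants are ASSUMPTIONS:
# verify them against the model's config.json and your actual GPU.
GIB = 1024 ** 3

PARAMS = 9.24e9        # assumed Gemma 2 9B parameter count
BYTES_PER_PARAM = 2    # 16-bit weights (bf16/fp16)
NUM_LAYERS = 42        # assumed num_hidden_layers
NUM_KV_HEADS = 8       # assumed num_key_value_heads
HEAD_DIM = 256         # assumed head_dim
KV_BYTES = 2           # 16-bit KV cache entries


def weight_gib(params: float, bytes_per_param: int) -> float:
    """Memory needed just to hold the weights, in GiB."""
    return params * bytes_per_param / GIB


def kv_bytes_per_token(layers: int, kv_heads: int, head_dim: int, bytes_per_elem: int) -> int:
    """KV-cache size for one token: one key and one value vector per KV head per layer."""
    return 2 * layers * kv_heads * head_dim * bytes_per_elem


def budget_gib(gpu_gib: float, utilization: float) -> float:
    """Total memory vLLM may reserve: gpu_memory_utilization * GPU memory."""
    return gpu_gib * utilization


gpu_gib = 24.0  # hypothetical GPU size
weights = weight_gib(PARAMS, BYTES_PER_PARAM)
budget = budget_gib(gpu_gib, 0.95)
per_seq = kv_bytes_per_token(NUM_LAYERS, NUM_KV_HEADS, HEAD_DIM, KV_BYTES) * 1024 / GIB

print(f"weights                          ~{weights:.2f} GiB")
print(f"vLLM budget                      ~{budget:.2f} GiB")
print(f"left for KV cache and activations ~{budget - weights:.2f} GiB")
print(f"KV cache per 1024-token sequence ~{per_seq:.3f} GiB")
```

Under these assumptions, each full 1024-token sequence needs roughly a third of a GiB of KV cache, so a short `max_model_len` leaves room for many sequences to be batched at once.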
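```python
# Load ELYZA-tasks-100-TV. Upload the file beforehand.
# In the omnicampus development environment, drag and drop the task jsonl
# into the left-hand pane before running this cell.
import json

datasets = []
with open("../elyza-tasks-100-TV_0.jsonl", "r", encoding="utf-8") as f:
    item = ""
    for line in f:
        line = line.strip()
        item += line
        # A record may span several lines; parse once the buffer ends with "}".
        if item.endswith("}"):
            datasets.append(json.loads(item))
            item = ""

print(datasets[0])
```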
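The loader above joins stripped lines until the buffer ends with `}`. That works for one-record-per-line files, but a record spread over several lines breaks if a nested object happens to close at the end of a line before the record itself ends. A minimal sketch of an alternative (a swapped-in technique, not the original code) uses the standard library's `json.JSONDecoder.raw_decode` to read consecutive JSON values regardless of line breaks; the names `iter_json_objects` and `load_jsonl` are hypothetical.

```python
import json
from typing import Iterator, List


def iter_json_objects(text: str) -> Iterator[dict]:
    """Yield consecutive top-level JSON values from text.

    raw_decode parses one complete value and reports where it ended, so records
    may span any number of lines without relying on line endings.
    """
    decoder = json.JSONDecoder()
    pos = 0
    while True:
        # raw_decode does not skip leading whitespace, so skip it here.
        while pos < len(text) and text[pos].isspace():
            pos += 1
        if pos >= len(text):
            return
        obj, pos = decoder.raw_decode(text, pos)
        yield obj


def load_jsonl(path: str) -> List[dict]:
    """Read the whole file and return every JSON record it contains."""
    with open(path, "r", encoding="utf-8") as f:
        return list(iter_json_objects(f.read()))
```

With this helper, `datasets = load_jsonl("../elyza-tasks-100-TV_0.jsonl")` would replace the accumulation loop; malformed input still raises `json.JSONDecodeError`.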
```python
# Wrap each task input as a single-turn chat message.
messages_list = [[{"role": "user", "content": d["input"]}] for d in datasets]
prompts = [messages[0]["content"] for messages in messages_list]

# Apply the model's chat template and tokenize; add_generation_prompt=True
# appends the tokens that open the model's reply turn.
prompt_token_ids = [
    tokenizer.apply_chat_template(messages, add_generation_prompt=True)
    for messages in messages_list
]

sampling_params = vllm.SamplingParams(
    temperature=0.5,
    max_tokens=512,
)
outputs = llm.generate(prompt_token_ids=prompt_token_ids, sampling_params=sampling_params)

for prompt, response in zip(prompts, outputs):
    print("prompt:", prompt)
    print("output:", response.outputs[0].text.strip())
    print("-" * 80)
```
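```python
import json

# Collect one record per task: the task index, the prompt, and the model's answer.
data = [
    {
        "task_id": i,
        "input": prompts[i],
        "output": outputs[i].outputs[0].text.strip(),
    }
    for i in range(len(datasets))
]

# Write the records as JSONL (one JSON object per line).
file_path = "submmit.jsonl"
with open(file_path, "w", encoding="utf-8") as file:
    for entry in data:
        json.dump(entry, file, ensure_ascii=False)  # keep Japanese text unescaped
        file.write("\n")
```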
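Before submitting, it can help to confirm that the output file parses line by line and that every record carries the keys the script writes (`task_id`, `input`, `output`). A minimal sketch; the function names are hypothetical, and whether the competition requires exactly these keys is an assumption based only on what the script above writes.

```python
import json
from typing import Iterable

# Keys written by the submission script above.
REQUIRED_KEYS = ("task_id", "input", "output")


def validate_submission_lines(lines: Iterable[str], required_keys=REQUIRED_KEYS) -> int:
    """Check that each non-blank line is a JSON object containing required_keys.

    Returns the number of records; raises ValueError on the first bad line.
    """
    count = 0
    for line_no, line in enumerate(lines, start=1):
        if not line.strip():
            continue  # tolerate blank lines, e.g. a trailing newline
        try:
            record = json.loads(line)
        except json.JSONDecodeError as exc:
            raise ValueError(f"line {line_no}: invalid JSON ({exc})") from exc
        if not isinstance(record, dict):
            raise ValueError(f"line {line_no}: expected a JSON object")
        missing = [key for key in required_keys if key not in record]
        if missing:
            raise ValueError(f"line {line_no}: missing keys {missing}")
        count += 1
    return count


def validate_submission(path: str) -> int:
    """Validate a JSONL file on disk and return its record count."""
    with open(path, "r", encoding="utf-8") as f:
        return validate_submission_lines(f)
```

For example, `validate_submission("submmit.jsonl") == len(datasets)` should hold if every task produced a record.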
Uploaded model
- Developed by: bay-llm
- License: gemma
- Finetuned from model: unsloth/gemma-2-9b-bnb-4bit
This gemma2 model was trained 2x faster with Unsloth and Hugging Face's TRL library.