base_model:
  - google/gemma-2-9b
tags:
  - text-generation-inference
  - transformers
  - unsloth
  - gemma2
  - trl
license: gemma
language:
  - en
  - ja
datasets:
  - kanhatakeyama/wizardlm8x22b-logical-math-coding-sft_additional-ja
  - kanhatakeyama/AutoMultiTurnByCalm3-22B
  - kanhatakeyama/ramdom-to-fixed-multiturn-Calm3

Model Card for Model ID

Instruction tuning

The model has been fine-tuned on the instruction datasets listed under datasets above.

Usage

!pip install vllm==0.6.4.post1 --force-reinstall

import time
import torch
import transformers
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
)
import vllm  # NOTE: this raises an error unless packaging==24.1 is installed
print(vllm.__version__)

MAX_LENGTH = 1000
MODEL_NAME = "bay-llm/gemma-9b-SFT-90-16bit"  # replace with the model you want to submit to the competition

llm = vllm.LLM(
    model=MODEL_NAME,
    tensor_parallel_size=1,
    gpu_memory_utilization=0.95,
    trust_remote_code=True,
    max_model_len=1024,
    
)
tokenizer = llm.get_tokenizer()

# Load ELYZA-tasks-100-TV. Upload the file beforehand.
# Load the dataset.
# In the omnicampus development environment, drag and drop the task jsonl into the left pane before running.
import json
datasets = []
with open("../elyza-tasks-100-TV_0.jsonl", "r", encoding="utf-8") as f:
    item = ""
    for line in f:
        # A record may span several lines: accumulate until the buffer closes with "}".
        line = line.strip()
        item += line
        if item.endswith("}"):
            datasets.append(json.loads(item))
            item = ""

print(datasets[0])

messages_list = [
    [{"role": "user", "content": datasets[i]["input"]}] for i in range(len(datasets))
]

prompts = [line[0]["content"] for line in messages_list]
prompt_token_ids = [tokenizer.apply_chat_template(messages, add_generation_prompt=True) for messages in messages_list]
sampling_params = vllm.SamplingParams(
    temperature=0.5,
    max_tokens=512,
)
outputs = llm.generate(prompt_token_ids=prompt_token_ids, sampling_params=sampling_params)
for prompt, response in zip(prompts, outputs):
    print("prompt:", prompt)
    print("output:", response.outputs[0].text.strip())
    print("-"*80)

import json
data = [{
    "task_id": i,
    "input": prompts[i],
    "output": outputs[i].outputs[0].text.strip()
} for i in range(len(datasets))]
file_path = 'submmit.jsonl'
with open(file_path, 'w', encoding='utf-8') as file:
    for entry in data:
        json.dump(entry, file, ensure_ascii=False)
        file.write('\n')
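
Before submitting, it can help to read the file back and confirm that every line parses and carries the three fields written above. The following is a minimal sketch using only the standard library; the helper name `check_submission` and the rule that an empty output counts as a problem are assumptions made for illustration, not requirements stated by the competition.

```python
import json

# Keys written by the export step above.
REQUIRED_KEYS = ("task_id", "input", "output")


def check_submission(path):
    """Return a list of problems found in a JSONL file (hypothetical helper).

    An empty list means every line parsed and had the expected keys.
    """
    problems = []
    with open(path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                problems.append(f"line {line_no}: blank line")
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as exc:
                problems.append(f"line {line_no}: invalid JSON ({exc.msg})")
                continue
            if not isinstance(record, dict):
                problems.append(f"line {line_no}: not a JSON object")
                continue
            missing = [key for key in REQUIRED_KEYS if key not in record]
            if missing:
                problems.append(f"line {line_no}: missing keys {missing}")
            elif not str(record["output"]).strip():
                # Assumption: an empty generation is worth flagging.
                problems.append(f"line {line_no}: empty output")
    return problems
```

Calling `check_submission('submmit.jsonl')` after the export step returns an empty list when the file looks well formed, and otherwise one message per offending line.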

Uploaded model

  • Developed by: bay-llm
  • License: gemma
  • Finetuned from model: unsloth/gemma-2-9b-bnb-4bit

This gemma2 model was trained 2x faster with Unsloth and Hugging Face's TRL library.