---
library_name: transformers
license: apache-2.0
base_model: google/gemma-7b
---
## Model Card for Firefly-Gemma
[gemma-7B-it-firefly](https://huggingface.co/yys/gemma-7B-it-firefly) is fine-tuned from [gemma-7b-it](https://huggingface.co/google/gemma-7b-it) to act as a helpful and harmless AI assistant.
We use [Firefly](https://github.com/yangjianxin1/Firefly) to train the model with LoRA.
<img src="open_llm_leaderboard.png" width="800">
We recommend installing `transformers>=4.38.2`.
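If you want to verify the installed version before loading the model, a minimal check could look like the sketch below. The helper names `version_tuple` and `meets_minimum` are hypothetical and not part of transformers; the comparison only looks at the leading numeric components and ignores pre-release suffixes.

```python
import re

MIN_TRANSFORMERS = "4.38.2"  # minimum version recommended in this model card


def version_tuple(v):
    """Return the leading numeric components of a version string as a tuple of ints."""
    parts = []
    for piece in v.split("."):
        m = re.match(r"\d+", piece)
        if m is None:
            break
        parts.append(int(m.group()))
    return tuple(parts)


def meets_minimum(installed, minimum=MIN_TRANSFORMERS):
    """Hypothetical helper: True if `installed` is at least `minimum`."""
    return version_tuple(installed)[:3] >= version_tuple(minimum)
```

In practice you would call it as `meets_minimum(transformers.__version__)` after `import transformers`.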
## Performance
We evaluated the model on the [Open LLM Leaderboard](https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard), where it performs well.
## Usage
The chat template is similar to that of the official gemma-7b-it:
```text
<bos><start_of_turn>user
hello, who are you?<end_of_turn>
<start_of_turn>model
I am an AI program developed by Firefly<eos>
```
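As a rough sketch, the single-turn prompt in this template can be assembled with a small helper. The function name `build_prompt` is hypothetical, and the sketch only covers one user turn; how earlier turns should be joined in a multi-turn conversation is not specified here, so that case is left out.

```python
def build_prompt(user_message: str) -> str:
    """Wrap one user message in the single-turn template shown above.

    The string ends right after "<start_of_turn>model" so that the model
    continues with its own reply.
    """
    return (
        "<bos><start_of_turn>user\n"
        f"{user_message}<end_of_turn>\n"
        "<start_of_turn>model"
    )


print(build_prompt("hello, who are you?"))
```

Because the string already contains `<bos>`, you may want to tokenize it with `add_special_tokens=False` if your tokenizer prepends a BOS token on its own.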
You can run inference with the [chat script](https://github.com/yangjianxin1/Firefly/blob/master/script/chat/chat.py) in Firefly.
You can also use the following code:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name_or_path = "yys/gemma-7B-it-firefly"

# Load the model in half precision; device_map='auto' places it on the available devices.
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    device_map='auto',
)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

prompt = "Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences and must-see attractions."
# Build the single-turn chat prompt by hand.
text = f"""
<bos><start_of_turn>user
{prompt}<end_of_turn>
<start_of_turn>model
""".strip()

# The prompt already starts with <bos>, so do not let the tokenizer prepend a second one.
model_inputs = tokenizer([text], return_tensors="pt", add_special_tokens=False).to(model.device)

generated_ids = model.generate(
    model_inputs.input_ids,
    attention_mask=model_inputs.attention_mask,
    max_new_tokens=1500,
    do_sample=True,  # needed for top_p and temperature to take effect
    top_p=0.9,
    temperature=0.35,
    repetition_penalty=1.0,
    eos_token_id=tokenizer.encode('<eos>', add_special_tokens=False),
)
# Drop the prompt tokens so that only the newly generated tokens are decoded.
generated_ids = [
    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(response)
```
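For a decoder-only model, `generate` returns each prompt followed by the newly generated tokens, which is why the code above slices off the first `len(input_ids)` entries before decoding. The sketch below shows that step in isolation on plain Python lists; the helper name `strip_prompt` and the token ids are made up for illustration.

```python
def strip_prompt(prompt_ids, output_ids):
    """Remove each prompt from the front of its matching generated sequence."""
    return [out[len(inp):] for inp, out in zip(prompt_ids, output_ids)]


# Made-up token ids: the prompt is [2, 106, 1645]; the model appended [235285, 1144, 1].
prompt_ids = [[2, 106, 1645]]
output_ids = [[2, 106, 1645, 235285, 1144, 1]]
new_tokens = strip_prompt(prompt_ids, output_ids)
print(new_tokens)
```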