nengrenjie83 committed on
Commit: f31d363
Parent(s): 3e89706

Upload model.py

Files changed (1): model.py (+9, -11)
model.py CHANGED
@@ -6,18 +6,16 @@ import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
 model_id = 'elyza/ELYZA-japanese-Llama-2-7b-instruct'
-if torch.cuda.is_available():
-    model = AutoModelForCausalLM.from_pretrained(
-        model_id,
-        torch_dtype=torch.bfloat16,
-        device_map='auto',
-        use_auth_token=True,
-        use_cache=True,
-    ).float()
-else:
-    model = None
-tokenizer = AutoTokenizer.from_pretrained(model_id)
 
+# Prepare the tokenizer and model
+tokenizer = AutoTokenizer.from_pretrained(
+    "elyza/ELYZA-japanese-Llama-2-7b-instruct"
+)
+model = AutoModelForCausalLM.from_pretrained(
+    "elyza/ELYZA-japanese-Llama-2-7b-instruct",
+    torch_dtype=torch.float16,
+    device_map="auto"
+).float()
 
 def get_prompt(message: str, chat_history: list[tuple[str, str]],
                system_prompt: str) -> str:
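
Below is a minimal usage sketch, not part of this commit, showing how the model loaded by the updated code might be exercised. The prompt string, generation parameters, and the streaming loop are assumptions for illustration; the repo's own get_prompt() builds the real prompt.

import torch
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = 'elyza/ELYZA-japanese-Llama-2-7b-instruct'
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map='auto',
)

# Assumed Llama-2-style instruction prompt; the actual format comes from get_prompt().
prompt = '[INST] こんにちは [/INST]'
inputs = tokenizer(prompt, return_tensors='pt').to(model.device)

# Stream tokens as they are generated, using the TextIteratorStreamer already imported in model.py.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
thread = Thread(target=model.generate, kwargs=dict(**inputs, streamer=streamer, max_new_tokens=64))
thread.start()
for text in streamer:
    print(text, end='', flush=True)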