2z299 commited on
Commit
e271162
·
verified ·
1 Parent(s): 83566e9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -2
app.py CHANGED
@@ -26,11 +26,18 @@ def generate_text(
26
  return ""
27
 
28
  # トークナイザーとモデルのロード(GPUが使える場合はGPUへ移動)
29
- tokenizer = AutoTokenizer.from_pretrained('Local-Novel-LLM-project/Vecteus-v1-abliterated', torch_dtype=torch.float16, attn_implementation="flash_attention_2", device_map="cuda", use_auth_token=HF_TOKEN)
30
  model = AutoModelForCausalLM.from_pretrained('Local-Novel-LLM-project/Vecteus-v1-abliterated', use_auth_token=HF_TOKEN)
31
 
 
 
 
 
 
 
 
32
  # 入力テキストをトークン化してトークン数を取得
33
- input_ids = tokenizer.encode(input_text, return_tensors="pt").to("cuda")
34
  input_token_count = input_ids.shape[1]
35
 
36
  # 総トークン数の上限を入力トークン数 + max_length(max_lengthはトークン数として扱う)
 
26
  return ""
27
 
28
  # トークナイザーとモデルのロード(GPUが使える場合はGPUへ移動)
29
+ tokenizer = AutoTokenizer.from_pretrained('Local-Novel-LLM-project/Vecteus-v1-abliterated', attn_implementation="flash_attention_2", use_auth_token=HF_TOKEN)
30
  model = AutoModelForCausalLM.from_pretrained('Local-Novel-LLM-project/Vecteus-v1-abliterated', use_auth_token=HF_TOKEN)
31
 
32
+ # GPUが利用可能ならGPUへ移動。bf16がサポートされている場合はbf16を使用
33
+ device = "cuda" if torch.cuda.is_available() else "cpu"
34
+ if device == "cuda" and hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported():
35
+ model.to(device, dtype=torch.bfloat16)
36
+ else:
37
+ model.to(device)
38
+
39
  # 入力テキストをトークン化してトークン数を取得
40
+ input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
41
  input_token_count = input_ids.shape[1]
42
 
43
  # 総トークン数の上限を入力トークン数 + max_length(max_lengthはトークン数として扱う)