mobicham commited on
Commit
4057de4
1 Parent(s): 3ccc970

update README.md with CPU/GPU runtime guide

Browse files
Files changed (1) hide show
  1. README.md +9 -2
README.md CHANGED
@@ -30,9 +30,16 @@ pip install pip --upgrade && pip install transformers --upgrade
30
  ``` Python
31
  #Load model
32
  import transformers, torch
 
 
 
33
  compute_dtype = torch.float16
 
 
 
 
 
34
  cache_path = ''
35
- device = 'cuda'
36
  model_id = "mobiuslabsgmbh/aanaphi2-v0.1"
37
  model = transformers.AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=compute_dtype,
38
  cache_dir=cache_path,
@@ -50,7 +57,7 @@ model.eval();
50
  @torch.no_grad()
51
  def generate(prompt, max_length=1024):
52
  prompt_chat = prompt_format(prompt)
53
- inputs = tokenizer(prompt_chat, return_tensors="pt", return_attention_mask=True).to('cuda')
54
  outputs = model.generate(**inputs, max_length=max_length, eos_token_id= tokenizer.eos_token_id)
55
  text = tokenizer.batch_decode(outputs[:,:-1])[0]
56
  return text
 
30
  ``` Python
31
  #Load model
32
  import transformers, torch
33
+
34
+ #GPU runtime
35
+ device = 'cuda'
36
  compute_dtype = torch.float16
37
+
38
+ ##CPU runtime
39
+ #device = 'cpu'
40
+ #compute_dtype = torch.float32
41
+
42
  cache_path = ''
 
43
  model_id = "mobiuslabsgmbh/aanaphi2-v0.1"
44
  model = transformers.AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=compute_dtype,
45
  cache_dir=cache_path,
 
57
  @torch.no_grad()
58
  def generate(prompt, max_length=1024):
59
  prompt_chat = prompt_format(prompt)
60
+ inputs = tokenizer(prompt_chat, return_tensors="pt", return_attention_mask=True).to(device)
61
  outputs = model.generate(**inputs, max_length=max_length, eos_token_id= tokenizer.eos_token_id)
62
  text = tokenizer.batch_decode(outputs[:,:-1])[0]
63
  return text