shakhal commited on
Commit
576552f
1 Parent(s): 3cf2a82

Update README.md

Browse files

Fix for running on cpu

Files changed (1) hide show
  1. README.md +3 -3
README.md CHANGED
@@ -34,8 +34,8 @@ tokenizer = AutoTokenizer.from_pretrained("GeneZC/MiniChat-3B", use_fast=False)
34
  # GPU.
35
  model = AutoModelForCausalLM.from_pretrained("GeneZC/MiniChat-3B", use_cache=True, device_map="auto", torch_dtype=torch.float16).eval()
36
  # CPU.
37
- # model = AutoModelForCausalLM.from_pretrained("GeneZC/MiniChat-3B", use_cache=True, device_map="cpu", torch_dtype=torch.float16).eval()
38
-
39
  conv = get_default_conv_template("minichat")
40
 
41
  question = "Implement a program to find the common elements in two arrays without using any extra data structures."
@@ -44,7 +44,7 @@ conv.append_message(conv.roles[1], None)
44
  prompt = conv.get_prompt()
45
  input_ids = tokenizer([prompt]).input_ids
46
  output_ids = model.generate(
47
- torch.as_tensor(input_ids).cuda(),
48
  do_sample=True,
49
  temperature=0.7,
50
  max_new_tokens=1024,
 
34
  # GPU.
35
  model = AutoModelForCausalLM.from_pretrained("GeneZC/MiniChat-3B", use_cache=True, device_map="auto", torch_dtype=torch.float16).eval()
36
  # CPU.
37
+ # model = AutoModelForCausalLM.from_pretrained("GeneZC/MiniChat-3B", use_cache=True, device_map="cpu", torch_dtype=torch.float32).eval()
38
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
39
  conv = get_default_conv_template("minichat")
40
 
41
  question = "Implement a program to find the common elements in two arrays without using any extra data structures."
 
44
  prompt = conv.get_prompt()
45
  input_ids = tokenizer([prompt]).input_ids
46
  output_ids = model.generate(
47
+ torch.as_tensor(input_ids).to(device),
48
  do_sample=True,
49
  temperature=0.7,
50
  max_new_tokens=1024,