guoday committed on
Commit
5085756
1 Parent(s): 1aa20ce

Update README.md

Files changed (1)
  1. README.md +3 -3
README.md CHANGED
@@ -47,18 +47,18 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/deepseek-coder-5.7bmqa-base", trust_remote_code=True)
 model = AutoModelForCausalLM.from_pretrained("deepseek-ai/deepseek-coder-5.7bmqa-base", trust_remote_code=True).cuda()
-input_text = """<fim_prefix>def quick_sort(arr):
+input_text = """<|fim▁begin|>def quick_sort(arr):
     if len(arr) <= 1:
         return arr
     pivot = arr[0]
     left = []
     right = []
-<fim_middle>
+<|fim▁hole|>
         if arr[i] < pivot:
             left.append(arr[i])
         else:
             right.append(arr[i])
-    return quick_sort(left) + [pivot] + quick_sort(right)<fim_suffix>"""
+    return quick_sort(left) + [pivot] + quick_sort(right)<|fim▁end|>"""
 inputs = tokenizer(input_text, return_tensors="pt").cuda()
 outputs = model.generate(**inputs, max_length=128)
 print(tokenizer.decode(outputs[0], skip_special_tokens=True)[len(input_text):])
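
The updated example uses DeepSeek Coder's fill-in-the-middle (FIM) sentinels to mark the code before the hole, the hole itself, and the code after it. Below is a minimal sketch, not part of the commit, of building the same FIM prompt from separate prefix/suffix strings and stitching the generated middle back into the function. The sentinel strings are copied from the updated example above; the variable names, the `max_new_tokens` value, and the use of `.to(model.device)` for device placement are illustrative assumptions.

```python
# Minimal FIM sketch (illustrative; not part of the commit).
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "deepseek-ai/deepseek-coder-5.7bmqa-base"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).cuda()

# Code before and after the hole the model should fill in.
prefix = """def quick_sort(arr):
    if len(arr) <= 1:
        return arr
    pivot = arr[0]
    left = []
    right = []
"""
suffix = """
        if arr[i] < pivot:
            left.append(arr[i])
        else:
            right.append(arr[i])
    return quick_sort(left) + [pivot] + quick_sort(right)"""

# FIM prompt: begin sentinel + prefix, hole sentinel, suffix + end sentinel,
# following the updated README example above.
prompt = f"<|fim▁begin|>{prefix}<|fim▁hole|>{suffix}<|fim▁end|>"

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)  # assumption: .to() for device placement
outputs = model.generate(**inputs, max_new_tokens=64)

# Decode only the newly generated tokens, i.e. the missing middle of the function.
middle = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)

# Reassemble the completed function.
print(prefix + middle + suffix)
```

Slicing off the prompt tokens before decoding avoids relying on character offsets into the decoded string, which can shift if special tokens are dropped during decoding.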