Kiet Bui commited on
Commit
6c57aa3
1 Parent(s): f366f59

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +5 -5
README.md CHANGED
@@ -29,8 +29,8 @@ There was great political strife in the air, and tensions were high. People rush
29
  model = GPTJForCausalLM.from_pretrained(pretrain_name, load_in_8bit=load_in_8bit, device_map='auto', torch_dtype=torch.float16)
30
  model = PeftModel.from_pretrained(model,lora_weights,torch_dtype=torch.float16,device_map={'':0})
31
  model = torch.compile(model)
32
-
33
- GenerationConfig(
34
  temperature=0.1,
35
  top_p=0.75,
36
  top_k=40,
@@ -38,11 +38,11 @@ There was great political strife in the air, and tensions were high. People rush
38
  )
39
 
40
  text = '[User]: What's the best food in Hanoi?''
41
- input_ids = st.session_state.tokenizer(text, return_tensors='pt')['input_ids'].to('cuda')
42
  with torch.no_grad():
43
- output = st.session_state['model'].generate(input_ids=input_ids, generation_config=st.session_state.gen_config,return_dict_in_generate=True, output_scores=True,max_new_tokens=256)
44
  s = output.sequences[0]
45
- output = st.session_state.tokenizer.decode(s)
46
  print('Raw:',output)
47
 
48
  ```
 
29
  model = GPTJForCausalLM.from_pretrained(pretrain_name, load_in_8bit=load_in_8bit, device_map='auto', torch_dtype=torch.float16)
30
  model = PeftModel.from_pretrained(model,lora_weights,torch_dtype=torch.float16,device_map={'':0})
31
  model = torch.compile(model)
32
+ tokenizer = AutoTokenizer.from_pretrained('pygmalion-6b') #The orginal pretrained
33
+ gen_config=GenerationConfig(
34
  temperature=0.1,
35
  top_p=0.75,
36
  top_k=40,
 
38
  )
39
 
40
  text = '[User]: What's the best food in Hanoi?''
41
+ input_ids = tokenizer(text, return_tensors='pt')['input_ids'].to('cuda')
42
  with torch.no_grad():
43
+ output = model.generate(input_ids=input_ids, generation_config=gen_config,return_dict_in_generate=True, output_scores=True,max_new_tokens=256)
44
  s = output.sequences[0]
45
+ output = tokenizer.decode(s)
46
  print('Raw:',output)
47
 
48
  ```