rodrigomasini committed
Commit
3ac5658
1 Parent(s): b99cc1c

Update app_v1.py

Files changed (1)
  1. app_v1.py +16 -17
app_v1.py CHANGED
@@ -6,7 +6,7 @@ import os
 import torch
 
 # Clear up some memory
-# torch.cuda.empty_cache()
+#torch.cuda.empty_cache()
 
 # Try reducing the number of threads PyTorch uses
 # torch.set_num_threads(1)
@@ -54,22 +54,21 @@ model = AutoGPTQForCausalLM.from_quantized(
     quantize_config=quantize_config
 )
 
-st.write(model.hf_device_map)
+#st.write(model.hf_device_map)
+user_input = st.text_input("Input a phrase")
 
-#user_input = st.text_input("Input a phrase")
-
-#prompt_template = f'USER: {user_input}\nASSISTANT:'
+prompt_template = f'USER: {user_input}\nASSISTANT:'
 
 # Generate output when the "Generate" button is pressed
-#if st.button("Generate the prompt"):
-#    inputs = tokenizer(prompt_template, return_tensors="pt")
-#    outputs = model.generate(
-#        input_ids=inputs.input_ids.to("cuda:0"),
-#        attention_mask=inputs.attention_mask.to("cuda:0"),
-#        max_length=512 + inputs.input_ids.size(-1),
-#        temperature=0.1,
-#        top_p=0.95,
-#        repetition_penalty=1.15
-#    )
-#    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-#    st.text_area("Prompt", value=generated_text)
+if st.button("Generate the prompt"):
+    inputs = tokenizer(prompt_template, return_tensors="pt")
+    outputs = model.generate(
+        input_ids=inputs.input_ids.to("cuda:0"),
+        attention_mask=inputs.attention_mask.to("cuda:0"),
+        max_length=512 + inputs.input_ids.size(-1),
+        temperature=0.1,
+        top_p=0.95,
+        repetition_penalty=1.15
+    )
+    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    st.text_area("Prompt", value=generated_text)
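
For context, a minimal sketch of how the updated section of app_v1.py could run end to end, assuming a Streamlit app that loads a GPTQ-quantized model with auto_gptq. The model directory, tokenizer setup, and quantize_config values below are illustrative assumptions not shown in this diff; only the input/generate block comes from the commit.

import streamlit as st
import torch
from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

# Placeholder path; the real model location is configured earlier in app_v1.py.
MODEL_DIR = "path/to/gptq-model"

tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True)
quantize_config = BaseQuantizeConfig(bits=4, group_size=128)  # assumed values
model = AutoGPTQForCausalLM.from_quantized(
    MODEL_DIR,
    device="cuda:0",
    quantize_config=quantize_config,
)

user_input = st.text_input("Input a phrase")
prompt_template = f'USER: {user_input}\nASSISTANT:'

# Generate output when the "Generate" button is pressed
if st.button("Generate the prompt"):
    inputs = tokenizer(prompt_template, return_tensors="pt")
    # Note: temperature and top_p only take effect when do_sample=True is
    # passed; as committed, generate() runs greedy decoding and recent
    # transformers versions warn that the sampling flags are unused.
    outputs = model.generate(
        input_ids=inputs.input_ids.to("cuda:0"),
        attention_mask=inputs.attention_mask.to("cuda:0"),
        max_length=512 + inputs.input_ids.size(-1),
        temperature=0.1,
        top_p=0.95,
        repetition_penalty=1.15,
    )
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    st.text_area("Prompt", value=generated_text)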