dmar1313 committed
Commit 7ec2237 · 1 Parent(s): c5303ee

Create truthbot.py

Files changed (1)
  1. truthbot.py +57 -0
truthbot.py ADDED
@@ -0,0 +1,57 @@
+ from transformers import AutoTokenizer, pipeline, logging
+ from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
+
+ model_name_or_path = "TheBloke/Llama-2-13B-GPTQ"
+ model_basename = "gptq_model-4bit-128g"
+
+ use_triton = False
+
+ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
+
+ model = AutoGPTQForCausalLM.from_quantized(model_name_or_path,
+         model_basename=model_basename,
+         use_safetensors=True,
+         trust_remote_code=True,
+         device="cuda:0",
+         use_triton=use_triton,
+         quantize_config=None)
+
+ """
+ To download from a specific branch, use the revision parameter, as in this example:
+
+ model = AutoGPTQForCausalLM.from_quantized(model_name_or_path,
+         revision="gptq-4bit-32g-actorder_True",
+         model_basename=model_basename,
+         use_safetensors=True,
+         trust_remote_code=True,
+         device="cuda:0",
+         quantize_config=None)
+ """
+
+ prompt = "Tell me about AI"
+ prompt_template=f'''{prompt}
+ '''
+
+ print("\n\n*** Generate:")
+
+ input_ids = tokenizer(prompt_template, return_tensors='pt').input_ids.cuda()
+ output = model.generate(inputs=input_ids, temperature=0.7, max_new_tokens=512)
+ print(tokenizer.decode(output[0]))
+
+ # Inference can also be done using transformers' pipeline
+
+ # Prevent printing spurious transformers error when using pipeline with AutoGPTQ
+ logging.set_verbosity(logging.CRITICAL)
+
+ print("*** Pipeline:")
+ pipe = pipeline(
+     "text-generation",
+     model=model,
+     tokenizer=tokenizer,
+     max_new_tokens=512,
+     temperature=0.7,
+     top_p=0.95,
+     repetition_penalty=1.15
+ )
+
+ print(pipe(prompt_template)[0]['generated_text'])
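
One note on the generation settings above: recent transformers releases only honour temperature and top_p when sampling is enabled, so with do_sample left at its default of False both calls fall back to greedy decoding and emit a warning. A minimal sketch of the adjustment (not part of this commit), reusing the objects defined in truthbot.py above:

# Sketch only: pass do_sample=True so the sampling parameters actually take effect.
output = model.generate(inputs=input_ids, do_sample=True, temperature=0.7, max_new_tokens=512)

pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,
    do_sample=True,        # enable sampling; otherwise temperature/top_p are ignored
    temperature=0.7,
    top_p=0.95,
    repetition_penalty=1.15
)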