Lihuchen commited on
Commit
cf63839
1 Parent(s): 686aa3a

Upload 5 files

Browse files
Files changed (5) hide show
  1. app.py +3 -3
  2. confidence.py +21 -4
  3. requirements.txt +1 -1
  4. self_check_gpt.py +17 -0
  5. tiny_llama.py +40 -0
app.py CHANGED
@@ -1,14 +1,14 @@
1
  import gradio as gr
2
- from confidence import run
3
 
4
 
5
def greet(query):
    # Gradio callback: run the self-consistency confidence pipeline on the
    # user's query and return the annotated answer text directly.
    return run(query, sample_size=5)
8
 
9
 
10
  sample_list = [
11
- "Tell me something about Lihu Chen, e.g., birth date and place and short bio ",
12
  ]
13
 
14
  iface = gr.Interface(fn=greet, inputs="text", outputs="text", examples=sample_list, cache_examples=True)
 
1
  import gradio as gr
2
+ from confidence import run_nli
3
 
4
 
5
  def greet(query):
6
+ results = run_nli(query, sample_size=5)
7
  return results
8
 
9
 
10
  sample_list = [
11
+ "Tell me something about Albert Einstein, e.g., birth date and place and short bio ",
12
  ]
13
 
14
  iface = gr.Interface(fn=greet, inputs="text", outputs="text", examples=sample_list, cache_examples=True)
confidence.py CHANGED
@@ -3,7 +3,8 @@ nltk.download('punkt')
3
  from nltk.tokenize import sent_tokenize
4
  #from tiny_llama import generate_answer
5
  #from llama_generate import generate_answer
6
- from cpu_llama_generate import generate_answer
 
7
 
8
  def get_yes_or_no(result):
9
  if 'yes' in str.lower(result)[:5]:return 'Yes'
@@ -23,7 +24,7 @@ def check_score(context, sentences):
23
  for sentence in sentences:
24
  content = template.format(a=context.strip().replace('/n', ''), b=sentence.strip().replace('/n', ''))
25
  result = generate_answer(content, sample_num=1)[0]
26
- print(result)
27
  results.append(result)
28
 
29
  results = [get_yes_or_no(r) for r in results]
@@ -36,7 +37,7 @@ def check_score(context, sentences):
36
  return scores
37
 
38
 
39
- def run(query, sample_size=5):
40
  sampled = generate_answer(query, sample_size+1)
41
  answer = sampled[0]
42
  proofs = sampled[1:]
@@ -58,6 +59,22 @@ def run(query, sample_size=5):
58
  return final_content
59
 
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  if __name__ == '__main__':
62
  # result = generate_answer(query="Who is Lihu Chen?", sample_num=3)
63
  # print(result)
@@ -72,5 +89,5 @@ if __name__ == '__main__':
72
  # print(result)
73
  # result = """
74
 
75
- answer = run(query='who is Lihu Chen', sample_size=5)
76
  print(answer)
 
3
  from nltk.tokenize import sent_tokenize
4
  #from tiny_llama import generate_answer
5
  #from llama_generate import generate_answer
6
+ from tiny_llama import generate_answer
7
+ from self_check_gpt import nli_confidence
8
 
9
  def get_yes_or_no(result):
10
  if 'yes' in str.lower(result)[:5]:return 'Yes'
 
24
  for sentence in sentences:
25
  content = template.format(a=context.strip().replace('/n', ''), b=sentence.strip().replace('/n', ''))
26
  result = generate_answer(content, sample_num=1)[0]
27
+ #print(result)
28
  results.append(result)
29
 
30
  results = [get_yes_or_no(r) for r in results]
 
37
  return scores
38
 
39
 
40
+ def run_prompt(query, sample_size=5):
41
  sampled = generate_answer(query, sample_size+1)
42
  answer = sampled[0]
43
  proofs = sampled[1:]
 
59
  return final_content
60
 
61
 
62
def run_nli(query, sample_size=5):
    """Answer *query* and annotate each sentence with an NLI confidence score.

    Draws ``sample_size + 1`` samples from the language model: the first is
    the answer shown to the user, the remaining ones act as evidence
    passages for the SelfCheckGPT NLI consistency check.

    Returns the answer text with a per-sentence score in parentheses after
    each sentence, followed by a line giving the mean score of the answer.
    """
    sampled = generate_answer(query, sample_size + 1)
    answer = sampled[0]
    proofs = sampled[1:]
    sentences = sent_tokenize(answer)

    # Guard: an empty generation yields no sentences, which would make the
    # mean computation below divide by zero.
    if not sentences:
        return answer

    scores = nli_confidence(proofs, sentences)

    # Pair sentences with their scores explicitly instead of indexing.
    final_content = ''.join(
        sent.strip() + ' ({a}) '.format(a=score)
        for sent, score in zip(sentences, scores)
    )
    final_content += '\nThe confidence score of this answer is {a}'.format(
        a=sum(scores) / len(scores))
    return final_content
+
76
+
77
+
78
  if __name__ == '__main__':
79
  # result = generate_answer(query="Who is Lihu Chen?", sample_num=3)
80
  # print(result)
 
89
  # print(result)
90
  # result = """
91
 
92
+ answer = run_nli(query='tell me something about Albert Einstein', sample_size=5)
93
  print(answer)
requirements.txt CHANGED
@@ -4,4 +4,4 @@ optimum>=1.12.0
4
  auto-gptq
5
  torch==2.1.0
6
  transformers>=4.32.0
7
- ctransformers>=0.2.24
 
4
  auto-gptq
5
  torch==2.1.0
6
  transformers>=4.32.0
7
+ selfcheckgpt
self_check_gpt.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import torch
from selfcheckgpt.modeling_selfcheck import SelfCheckNLI

# Use the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Module-level singleton: loading the NLI checker is expensive, so it is done
# once at import time and shared by every call to nli_confidence().
selfcheck_nli = SelfCheckNLI(device=device)  # set device to 'cuda' if GPU is available
9
def nli_confidence(proofs, sentences):
    """Score each sentence of an answer against the sampled passages.

    Uses the module-level SelfCheckNLI model. Its predict() output is
    flipped (1 - score) so that, presumably, higher values mean higher
    confidence rather than higher inconsistency — confirm against the
    selfcheckgpt documentation.
    """
    raw_scores = selfcheck_nli.predict(
        sentences=sentences,        # list of sentences from the answer
        sampled_passages=proofs,    # list of sampled evidence passages
    )
    return [1 - s for s in raw_scores]
tiny_llama.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os
# NOTE(review): hard-coded, machine-specific cache path — should be made
# configurable (env var / settings) before deploying elsewhere.
os.environ['TRANSFORMERS_CACHE'] = "data/parietal/store3/soda/lihu/hf_model/"
from transformers import AutoTokenizer
import transformers
import torch

# Chat-tuned TinyLlama checkpoint used for every generation in this module.
model = "PY007/TinyLlama-1.1B-Chat-v0.3"
tokenizer = AutoTokenizer.from_pretrained(model)
# Module-level pipeline: the model is loaded once at import time and shared
# by all calls to generate_answer().
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    torch_dtype=torch.float16,
    device_map="auto",
)
# Token id used to stop generation at the end of the assistant turn.
# NOTE(review): assumed to correspond to this checkpoint's chat end token —
# confirm against the model's tokenizer config.
CHAT_EOS_TOKEN_ID = 32002
17
+ def generate_answer(query, sample_num=3):
18
+ #prompt = "Who is Lihu Chen?"
19
+ formatted_prompt = (
20
+ f"<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n"
21
+
22
+ )
23
+
24
+ sequences = pipeline(
25
+ formatted_prompt,
26
+ do_sample=True,
27
+ top_k=50,
28
+ top_p = 0.9,
29
+ num_return_sequences=sample_num,
30
+ repetition_penalty=1.1,
31
+ max_new_tokens=150,
32
+ eos_token_id=CHAT_EOS_TOKEN_ID,
33
+ )
34
+ answers = list()
35
+ for seq in sequences:
36
+ answer = seq['generated_text'].replace(formatted_prompt, "")
37
+ answers.append(answer)
38
+ #print(f"Result: {answer}")
39
+ #print("------------------------------------------")
40
+ return answers