llm_with_confidence / llama_generate.py
Lihuchen's picture
Upload 3 files
f2ff742
raw
history blame
5.67 kB
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from nltk.tokenize import sent_tokenize
# Inference device: a CUDA GPU when one is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Hugging Face Hub id of the GPTQ-quantized Llama-2-7B chat checkpoint.
model_name_or_path = "TheBloke/Llama-2-7b-Chat-GPTQ"

# device_map="auto" lets transformers decide where the weights are placed;
# remote code execution stays disabled and the "main" branch is loaded.
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    device_map="auto",
    trust_remote_code=False,
    revision="main",
)

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
# Reuse the unknown token for padding instead of registering a new '[PAD]'
# token, which would also require resizing the model's embedding matrix.
tokenizer.pad_token = tokenizer.unk_token
def clean(result):
    """Extract the assistant's reply from a decoded Llama-2 chat generation.

    Everything up to and including the last ``[/INST]`` marker (the echoed
    prompt) is discarded. The ``<s>``, ``</s>`` and ``<unk>`` token strings are
    then removed, and surrounding whitespace is trimmed.

    :param result: decoded text, possibly containing the prompt and special tokens.
    :return: the cleaned reply text.
    """
    # rpartition keeps the text after the LAST marker, or the whole string
    # when the marker is absent.
    _, _, reply = result.rpartition("[/INST]")
    reply = reply.strip()
    for marker in ('<s>', '</s>', '<unk>'):
        reply = reply.replace(marker, '').strip()
    return reply
def single_generate(query, max_new_tokens=150, temperature=1.0):
    """Sample one chat completion for ``query`` and return it in a list.

    The query is wrapped as a single user turn with the tokenizer's chat
    template. A sampled continuation is generated, and each decoded sequence
    is passed through ``clean`` to drop the echoed prompt and special-token
    strings.

    :param query: user message text.
    :param max_new_tokens: generation budget (default 150).
    :param temperature: sampling temperature (default 1.0).
    :return: list of cleaned strings, one per generated sequence. There is a
        single element, since exactly one prompt is encoded.
    """
    messages = [{"role": "user", "content": query}]
    input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt")
    # Move the prompt to where the weights already are. The model was loaded
    # with device_map="auto", so re-moving the model on every call is redundant
    # and can conflict with that dispatch.
    input_ids = input_ids.to(model.device)
    generated_ids = model.generate(
        input_ids,
        # One unpadded prompt: every position is real input.
        attention_mask=torch.ones_like(input_ids),
        pad_token_id=tokenizer.pad_token_id,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
    )
    decoded = tokenizer.batch_decode(generated_ids)
    return [clean(text) for text in decoded]
def prepare_input(contents):
    """Encode several user messages as one left-padded batch of token ids.

    Each message is wrapped in the tokenizer's chat template as a single user
    turn. The sequences are then left-padded with ``tokenizer.pad_token_id``
    to the length of the longest one so they can be stacked. Left padding
    keeps every prompt directly adjacent to the tokens generated after it.

    :param contents: non-empty sequence of user message strings.
    :return: ``torch.LongTensor`` of shape ``(len(contents), longest_prompt)``.
    :raises ValueError: if ``contents`` is empty.
    """
    # Pad the batch as a whole: padding=True on a single encoded prompt is a
    # no-op, so per-prompt padding never equalises lengths and torch.stack
    # would fail on prompts of different lengths.
    encoded = [
        tokenizer.apply_chat_template([{"role": "user", "content": content}])
        for content in contents
    ]
    if not encoded:
        raise ValueError("prepare_input() needs at least one message")
    pad_id = tokenizer.pad_token_id
    longest = max(len(ids) for ids in encoded)
    padded = [[pad_id] * (longest - len(ids)) + list(ids) for ids in encoded]
    return torch.tensor(padded, dtype=torch.long)
def batch_generate(queries, max_new_tokens=100, temperature=1.0):
    """Sample one chat completion per query in a single batched ``generate`` call.

    :param queries: non-empty sequence of user message strings.
    :param max_new_tokens: generation budget per sequence (default 100).
    :param temperature: sampling temperature (default 1.0).
    :return: list of cleaned replies, aligned with ``queries``.
    """
    # Send the batch to where the weights already are (the model was loaded
    # with device_map="auto"); the model itself is not moved.
    input_ids = prepare_input(queries).to(model.device)
    pad_id = tokenizer.pad_token_id
    # Mask only the leading run of pad tokens in each row (left padding): a
    # position is attended once a non-pad token has appeared at or before it.
    # An unpadded batch therefore gets an all-ones mask.
    attention_mask = (input_ids != pad_id).long().cumsum(dim=1).clamp(max=1)
    generated_ids = model.generate(
        input_ids,
        attention_mask=attention_mask,
        pad_token_id=pad_id,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
    )
    decoded = tokenizer.batch_decode(generated_ids)
    return [clean(text) for text in decoded]
def get_yes_or_no(result):
    """Classify a model reply as 'Yes', 'No' or 'N/A' by its first word.

    Leading punctuation and whitespace (e.g. ``**Yes**``) are skipped. The
    first word is then compared case-insensitively with the whole words "yes"
    and "no". Requiring a whole word, instead of searching for "yes"/"no"
    anywhere in the first five characters, stops replies such as "Unknown" or
    "Cannot tell" from being read as "No".

    :param result: cleaned reply text.
    :return: 'Yes', 'No', or 'N/A' when the reply starts with neither word.
    """
    import re

    match = re.match(r"\W*(yes|no)\b", result, flags=re.IGNORECASE)
    if match is None:
        return 'N/A'
    return 'Yes' if match.group(1).lower() == 'yes' else 'No'
def check_score(context, sentences):
    """Score how well each sentence is supported by ``context``, using the model as judge.

    For every sentence the model is asked a Yes/No support question (one
    ``single_generate`` call per sentence). The reply is mapped to 1.0 for
    'Yes', 0.0 for 'No', and 0.5 for anything else.

    :param context: reference text (``run`` passes one of the other sampled answers).
    :param sentences: sentences to verify against the context.
    :return: list of float scores aligned with ``sentences``.
    """
    score_mapping = {'Yes': 1.0, 'No': 0.0}
    template = (
        "\n"
        "Context: {a}\n"
        "Sentence: {b}\n"
        "Is the sentence supported by the context above?\n"
        "Answer Yes or No (Don't give explanations):\n"
    )
    # Flatten real newlines ('\n') so each field stays on its own prompt line.
    # Matching the literal two characters '/n' would leave newlines intact and
    # corrupt text such as '/news'.
    flat_context = context.strip().replace('\n', ' ')
    scores = []
    for sentence in sentences:
        prompt = template.format(a=flat_context, b=sentence.strip().replace('\n', ' '))
        reply = single_generate(prompt)[0]
        scores.append(score_mapping.get(get_yes_or_no(reply), 0.5))
    return scores
def sample_answer(query, num):
    """Draw ``num`` independently sampled replies to ``query``.

    Each sample is a separate ``single_generate`` call, and only its first
    (and only) decoded reply is kept.

    :param query: user message text.
    :param num: number of samples; zero or a negative value yields an empty list.
    :return: list of ``num`` cleaned reply strings, in sampling order.
    """
    return [single_generate(query)[0] for _ in range(num)]
def run(query, sample_size=5):
    """Answer ``query`` and annotate each sentence with a self-consistency confidence.

    ``sample_size + 1`` replies are sampled. The first is the answer and the
    rest serve as evidence. Each answer sentence is checked against every
    evidence reply with ``check_score``, and its confidence is the mean of
    those scores. The overall confidence is the mean over sentences.

    :param query: user question.
    :param sample_size: number of evidence replies (must be >= 1).
    :return: the answer with ``" (score) "`` after each sentence, followed by a
        line stating the overall confidence ('N/A' if the answer has no sentences).
    :raises ValueError: if ``sample_size`` < 1, since there would be nothing
        to check the answer against.
    """
    if sample_size < 1:
        raise ValueError("sample_size must be at least 1, got {a}".format(a=sample_size))
    answer, *proofs = sample_answer(query, sample_size + 1)
    sentences = sent_tokenize(answer)
    # all_scores[p][s] is the support for sentence s according to proof p.
    all_scores = [check_score(proof, sentences) for proof in proofs]
    # zip(*all_scores) regroups the scores per sentence (one value per proof).
    sentence_confidences = [sum(per_proof) / len(per_proof) for per_proof in zip(*all_scores)]
    final_content = ''.join(
        sentence.strip() + ' ({a}) '.format(a=confidence)
        for sentence, confidence in zip(sentences, sentence_confidences)
    )
    if sentence_confidences:
        overall = sum(sentence_confidences) / len(sentence_confidences)
    else:
        # Empty answer: there is no sentence to score, so no mean exists.
        overall = 'N/A'
    final_content += '\nThe confidence score of this answer is {a}'.format(a=overall)
    return final_content
if __name__ == '__main__':
    # Demo: answer a biography question and annotate it with confidences
    # computed against 10 additional sampled replies. The other helpers can be
    # exercised the same way, e.g. sample_answer(query, num),
    # batch_generate([query, ...]) or check_score(context, sent_tokenize(text)).
    demo_query = 'Tell me something about Gaël Varoquaux, e.g., birth date and place and short bio '
    report = run(query=demo_query, sample_size=10)
    print(report)