import os
import time
from datetime import datetime
import logging
from pathlib import Path

import requests
import json
import numpy as np
import pandas as pd
import spacy
from sentence_transformers import CrossEncoder
import litellm
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, AutoModelForTokenClassification
import torch
import cohere
from openai import OpenAI
import anthropic
import replicate
# pip install -U google-generativeai
import google.generativeai as genai
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage

import src.backend.util as util
import src.envs as envs

litellm.set_verbose = True

# Set up basic configuration for logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')

# Load spaCy model for word tokenization
nlp = spacy.load("en_core_web_sm")

os.environ["HUGGINGFACE_API_KEY"] = envs.TOKEN
class ModelLoadingException(Exception):
    """Exception raised for errors in loading a model.

    Attributes:
        model_id (str): The model identifier.
        revision (str): The model revision.
    """

    def __init__(self, model_id, revision, messages="Error initializing model"):
        self.model_id = model_id
        self.revision = revision
        super().__init__(f"{messages} id={model_id} revision={revision}")
class SummaryGenerator:
    """A class to generate summaries using a causal language model.

    Attributes:
        model (str): huggingface/{model_id}
        api_base (str): https://api-inference.huggingface.co/models/{model_id}
        summaries_df (DataFrame): DataFrame to store generated summaries.
        revision (str): Model revision.
        avg_length (float): Average length of summaries.
        answer_rate (float): Rate of non-empty summaries.
    """

    def __init__(self, model_id, revision, device):
        """
        Initializes the SummaryGenerator with a model.

        Args:
            model_id (str): Identifier for the model.
            revision (str): Revision of the model.
            device (str): Device on which to run local models.
        """
        self.model_id = model_id
        self.model = f"huggingface/{model_id}"
        self.api_base = f"https://api-inference.huggingface.co/models/{model_id}"
        self.summaries_df = pd.DataFrame()
        self.revision = revision
        self.device = device
        self.avg_length = None
        self.answer_rate = None
        self.exceptions = None
        self.local_model = None
        self.local_pipeline = None
    def generate_summaries(self, df, save_path=None):
        """Generate summaries for a given DataFrame of source docs.

        Args:
            df (DataFrame): DataFrame containing source docs.
            save_path (str, optional): CSV path used to cache generated summaries.

        Returns:
            summaries_df (DataFrame): Generated summaries by the model.
        """
        exceptions = []
        if (save_path is not None) and os.path.exists(save_path):
            self.summaries_df = pd.read_csv(save_path)
            print(f'Loaded generated summaries from {save_path}')
        else:
            source, summary, dataset = [], [], []
            print(f"Total: {df.shape[0]}")
            for index, row in tqdm(df.iterrows(), total=df.shape[0]):
                _source = row['text']
                _dataset = row['dataset']

                system_prompt = envs.SYSTEM_PROMPT
                user_prompt = f"{envs.USER_PROMPT}\nPassage:\n{_source}"

                _summary = None
                while not _summary:
                    try:
                        _summary = self.generate_summary(system_prompt, user_prompt)
                        # print(f"Finish index {index}")
                        break
                    except Exception as e:
                        if 'Rate limit reached' in str(e):
                            wait_time = 300
                            current_time = datetime.now().strftime('%H:%M:%S')
                            print(f"Rate limit hit at {current_time}. Waiting for 5 minutes before retrying...")
                            time.sleep(wait_time)
                        elif 'is currently loading' in str(e):
                            wait_time = 200
                            print(f"Model is loading, waiting for {wait_time} seconds...")
                            time.sleep(wait_time)
                        elif '429' in str(e):  # for Gemini models
                            wait_time = 60
                            print(f"Quota has been reached, waiting for {wait_time} seconds...")
                            time.sleep(wait_time)
                        else:
                            print(f"Error at index {index}: {e}")
                            _summary = ""
                            exceptions.append(index)
                            break

                summary.append(_summary)
                source.append(_source)
                dataset.append(_dataset)

                # Sleep to prevent hitting rate limits too frequently
                time.sleep(1)

            self.summaries_df = pd.DataFrame(list(zip(source, summary, dataset)),
                                             columns=["source", "summary", "dataset"])

            if save_path is not None:
                print(f'Saving summaries to {save_path}')
                fpath = Path(save_path)
                fpath.parent.mkdir(parents=True, exist_ok=True)
                self.summaries_df.to_csv(fpath)

        self.exceptions = exceptions
        self._compute_avg_length()
        self._compute_answer_rate()

        return self.summaries_df
    def generate_summary(self, system_prompt: str, user_prompt: str):
        """Generate a single summary, routing the request to the appropriate API or local model."""
        # Decide which backend serves this model: Together AI, Replicate, or a local Transformers pipeline.
        using_together_api = False
        together_ai_api_models = ['mixtral', 'dbrx', 'wizardlm', 'llama-3-', 'qwen', 'zero-one-ai']  # , 'mistralai'
        using_replicate_api = False
        replicate_api_models = ['snowflake', 'llama-3.1-405b']
        using_pipeline = False
        pipeline_models = ['llama-3.1', 'phi-3-mini', 'falcon-7b']

        for replicate_api_model in replicate_api_models:
            if replicate_api_model in self.model_id.lower():
                using_replicate_api = True
                break

        if not using_replicate_api:
            for together_ai_api_model in together_ai_api_models:
                if together_ai_api_model in self.model_id.lower():
                    using_together_api = True
                    break

        if not using_replicate_api and not using_together_api:
            for pipeline_model in pipeline_models:
                if pipeline_model in self.model_id.lower():
                    using_pipeline = True
                    break
        # Using Together AI API
        # if 'mixtral' in self.model_id.lower() or 'dbrx' in self.model_id.lower() or 'wizardlm' in self.model_id.lower():
        if using_together_api:
            # print('using together api')
            # suffix = "completions" if ('mixtral' in self.model_id.lower() or 'base' in self.model_id.lower()) else "chat/completions"
            suffix = "chat/completions"
            url = f"https://api.together.xyz/v1/{suffix}"

            payload = {
                "model": self.model_id,
                "max_new_tokens": 250,
                "temperature": 0.0,
                "messages": [{"role": "system", "content": system_prompt},
                             {"role": "user", "content": user_prompt}],
            }
            headers = {
                "accept": "application/json",
                "content-type": "application/json",
                "Authorization": f"Bearer {os.environ['TOGETHER_API_KEY']}"
            }

            response = requests.post(url, json=payload, headers=headers)
            print(response)
            try:
                result = json.loads(response.text)
                # print(result)
                result = result["choices"][0]
                if 'message' in result:
                    result = result["message"]["content"].strip()
                else:
                    result = result["text"]
                    result_candidates = [candidate for candidate in result.split('\n\n') if len(candidate) > 0]
                    result = result_candidates[0]
                # print(result)
            except Exception:
                # print(response)
                result = ''
            print(result)
            return result
        # Using OpenAI API
        elif 'gpt' in self.model_id.lower():
            client = OpenAI()
            response = client.chat.completions.create(
                model=self.model_id.replace('openai/', ''),
                messages=[{"role": "system", "content": system_prompt},
                          {"role": "user", "content": user_prompt}],
                temperature=0.0,
                max_tokens=250,
            )
            # print(response)
            result = response.choices[0].message.content
            print(result)
            return result

        # Using Google AI API for Gemini models
        elif 'gemini' in self.model_id.lower():
            genai.configure(api_key=os.getenv('GOOGLE_AI_API_KEY'))
            generation_config = {
                "temperature": 0,
                "top_p": 0.95,  # cannot change
                "top_k": 0,
                "max_output_tokens": 250,
                # "response_mime_type": "application/json",
            }

            safety_settings = [
                {
                    "category": "HARM_CATEGORY_HARASSMENT",
                    "threshold": "BLOCK_NONE"
                },
                {
                    "category": "HARM_CATEGORY_HATE_SPEECH",
                    "threshold": "BLOCK_NONE"
                },
                {
                    "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                    "threshold": "BLOCK_NONE"
                },
                {
                    "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
                    "threshold": "BLOCK_NONE"
                },
            ]

            model = genai.GenerativeModel(model_name=self.model_id.lower().split('google/')[-1],
                                          generation_config=generation_config,
                                          system_instruction=system_prompt,
                                          safety_settings=safety_settings)
            # print(model)
            convo = model.start_chat(history=[])
            convo.send_message(user_prompt)
            # print(convo.last)
            result = convo.last.text
            print(result)
            return result
        # Using Replicate API
        elif using_replicate_api:
            print("using replicate")
            if 'snowflake' in self.model_id.lower():
                input_data = {
                    "prompt": user_prompt,
                    "temperature": 0,
                    "max_new_tokens": 250,
                    "stop_sequences": "<|im_end|>",
                    "prompt_template": f"<|im_start|>system\n{system_prompt}<|im_end|>\n" + "<|im_start|>user\n{prompt}<|im_end|>\n\n<|im_start|>assistant\n",
                }
            else:
                input_data = {
                    "prompt": user_prompt,
                    "system_prompt": system_prompt,
                    "temperature": 0,
                    "max_new_tokens": 250
                }

            response = replicate.run(
                self.model_id,
                input=input_data
            )
            # print(response)
            if isinstance(response, list):
                response = ''.join(response)
            print(response)
            return response

        # Using Anthropic API
        elif 'claude' in self.model_id.lower():
            client = anthropic.Anthropic()
            message = client.messages.create(
                model=self.model_id.split('/')[-1],
                max_tokens=250,
                temperature=0,
                system=system_prompt,
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": user_prompt
                            }
                        ]
                    }
                ]
            )
            result = message.content[0].text
            print(result)
            return result
        # Using Mistral API
        elif 'mistral-large' in self.model_id.lower():
            api_key = os.environ["MISTRAL_API_KEY"]
            client = MistralClient(api_key=api_key)

            messages = [
                ChatMessage(role="system", content=system_prompt),
                ChatMessage(role="user", content=user_prompt)
            ]
            # No streaming
            chat_response = client.chat(
                model=self.model_id,
                messages=messages,
            )
            result = chat_response.choices[0].message.content
            print(result)
            return result

        # Using HF API or download checkpoints
        elif self.local_model is None and self.local_pipeline is None:
            # try:  # try to use the HuggingFace Inference API first
            #     print('** using huggingface api')
            #     response = litellm.completion(
            #         model=self.model,
            #         messages=[{"role": "system", "content": system_prompt},
            #                   {"role": "user", "content": user_prompt}],
            #         temperature=0.0,
            #         max_tokens=250,
            #         api_base=self.api_base,
            #     )
            #     result = response['choices'][0]['message']['content']
            #     result = result.split('<|im_end|>')[0]
            #     print(result)
            #     return result
            # except Exception as e:
            #     if 'Rate limit reached' in str(e):
            #         wait_time = 300
            #         current_time = datetime.now().strftime('%H:%M:%S')
            #         print(f"Rate limit hit at {current_time}. Waiting for 5 minutes before retrying...")
            #         time.sleep(wait_time)
            #     else:
            if using_pipeline:
                self.local_pipeline = pipeline(
                    "text-generation",
                    model=self.model_id,
                    model_kwargs={"torch_dtype": torch.bfloat16},
                    device_map="auto",
                )
            else:
                self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf" if 'openelm' in self.model_id.lower() else self.model_id, trust_remote_code=True)
                print("Tokenizer loaded")
                self.local_model = AutoModelForCausalLM.from_pretrained(self.model_id, trust_remote_code=True, device_map="auto", torch_dtype="auto")
                print(self.local_model.device)
                print("Local model loaded")
        # Using local model/pipeline
        if self.local_pipeline:
            print('Using Transformers pipeline')
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ]
            outputs = self.local_pipeline(
                messages,
                max_new_tokens=250,
            )
            result = outputs[0]["generated_text"][-1]['content']
            print(result)
            return result

        elif self.local_model:  # cannot call API; using a locally loaded model
            print('Using local model')
            if 'gemma' in self.model_id.lower() or 'mistral-7b' in self.model_id.lower():
                messages = [
                    # gemma-1.1 and mistral-7b do not accept a system role
                    {"role": "user", "content": system_prompt + ' ' + user_prompt}
                ]
                prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
            elif 'phi-2' in self.model_id.lower():
                prompt = system_prompt + '\n' + user_prompt
            elif 'intel' in self.model_id.lower():
                prompt = f"### System:\n{system_prompt}\n### User:\n{user_prompt}\n### Assistant:\n"
            else:
                messages = [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}
                ]
                prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
            # print(prompt)
            # print('-'*50)

            input_ids = self.tokenizer(prompt, return_tensors="pt").to(self.device)
            with torch.no_grad():
                outputs = self.local_model.generate(**input_ids, max_new_tokens=250, do_sample=True, temperature=0.01, pad_token_id=self.tokenizer.eos_token_id)
            if 'glm' in self.model_id.lower():
                outputs = outputs[:, input_ids['input_ids'].shape[1]:]
            result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Strip the prompt from the decoded output
            if 'gemma-2' in self.model_id.lower():
                result = result.split(user_prompt + '\nmodel')[-1].strip()
            elif 'intel' in self.model_id.lower():
                result = result.split("### Assistant:\n")[-1]
            else:
                print(prompt)
                print('-' * 50)
                result = result.replace(prompt.strip(), '')

            print(result)
            return result
    def _compute_avg_length(self):
        """
        Compute the average length of non-empty summaries using spaCy.
        """
        total_word_count = 0
        total_count = 0

        for summary in self.summaries_df['summary']:
            if util.is_summary_valid(summary):
                doc = nlp(summary)
                words = [token.text for token in doc if token.is_alpha]
                total_word_count += len(words)
                total_count += 1

        self.avg_length = 0 if total_count == 0 else total_word_count / total_count

    def _compute_answer_rate(self):
        """
        Compute the rate of non-empty summaries.
        """
        valid_count = sum(1 for summary in self.summaries_df['summary']
                          if util.is_summary_valid(summary))
        total_count = len(self.summaries_df)
        self.answer_rate = 0 if total_count == 0 else valid_count / total_count
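

# Usage sketch for SummaryGenerator (illustrative only; the model ID, CSV path, and
# device below are assumptions, not values fixed by this module -- the leaderboard's
# own config normally supplies them, and the input CSV is assumed to have 'text' and
# 'dataset' columns as read in generate_summaries()):
#
#     generator = SummaryGenerator("mistralai/Mistral-7B-Instruct-v0.2", revision="main", device="cuda")
#     df = pd.read_csv("leaderboard_dataset.csv")
#     summaries_df = generator.generate_summaries(df, save_path="generation_results/summaries.csv")
#     print(generator.avg_length, generator.answer_rate)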
class EvaluationModel:
    """A class to evaluate generated summaries.

    Attributes:
        model (AutoModelForTokenClassification): The HHEM evaluation model.
        scores (list): List of evaluation scores.
        factual_consistency_rate (float): Rate of factually consistent summaries.
        hallucination_rate (float): Rate of hallucination in summaries.
    """

    def __init__(self, model_path, device):
        """
        Initializes the EvaluationModel with the HHEM model.

        Args:
            model_path (str): Path or Hugging Face ID of the HHEM model.
            device (str): Device on which to run the model.
        """
        self.model = AutoModelForTokenClassification.from_pretrained(model_path)
        self.device = device
        self.model.to(self.device)
        self.scores = []
        self.factual_consistency_rate = None
        self.hallucination_rate = None
    def predict(self, text_pairs):
        """Run HHEM on (premise, hypothesis) pairs and return factual consistency scores.

        All HHEM 2.1 settings, e.g., the prompt template, are hardcoded in this function.

        Args:
            text_pairs: list of tuples, each containing two strings (premise, hypothesis).

        Returns:
            list of float: probability that each hypothesis is consistent with its premise.
        """
        prompt = "<pad> Determine if the hypothesis is true given the premise?\n\nPremise: {text1}\n\nHypothesis: {text2}"
        tokenizer = AutoTokenizer.from_pretrained('t5-base')
        inputs = tokenizer(
            [prompt.format(text1=pair[0], text2=pair[1]) for pair in text_pairs],
            return_tensors='pt', padding='longest').to(self.device)
        self.model.eval()
        with torch.no_grad():
            output = self.model(**inputs)
        logits = output.logits
        logits = logits[:, 0, :]  # get the logits on the first token
        logits = torch.softmax(logits, dim=-1)
        scores = [round(x, 5) for x in logits[:, 1].tolist()]  # list of float
        return scores
    def evaluate_hallucination(self, summaries_df):
        """
        Evaluate the hallucination rate in summaries. Updates the 'scores' attribute
        of the instance with the computed scores.

        Args:
            summaries_df (DataFrame): DataFrame containing source docs and summaries.

        Returns:
            tuple: (list of HEM scores, dict of evaluation results). Also updates the
            'scores' attribute of the instance.
        """
        hem_scores = []
        sources = []
        summaries = []
        source_summary_pairs = util.create_pairs(summaries_df)

        for doc, summary in source_summary_pairs:
            if util.is_summary_valid(summary):
                try:
                    summary = summary.replace('<bos>', '').replace('<eos>', '').strip()
                    score = self.predict([(doc, summary)])[0]
                    # print(score)
                    # if score < 0.5:
                    #     print(doc)
                    #     print('-'*10)
                    #     print(summary)
                    #     print('='*20)
                    hem_scores.append(score)
                    sources.append(doc)
                    summaries.append(summary)
                except Exception as e:
                    logging.error(f"Error while running HEM: {e}")
                    raise

        self.scores = hem_scores
        eval_results = {'source': sources, 'summary': summaries, 'HEM scores': hem_scores}
        return hem_scores, eval_results
    def compute_factual_consistency_rate(self, threshold=0.5):
        """
        Compute the factual consistency rate of the evaluated summaries based on
        the previously calculated scores. This method relies on the 'scores'
        attribute being populated, typically via the 'evaluate_hallucination' method.

        Args:
            threshold (float): Scores at or above this value count as factually consistent.

        Returns:
            float: Factual Consistency Rate. Also updates the 'factual_consistency_rate'
            and 'hallucination_rate' attributes of the instance.

        Raises:
            ValueError: If scores have not been calculated prior to calling this method.
        """
        if not self.scores:
            error_msg = "Scores not calculated. Call evaluate_hallucination() first."
            logging.error(error_msg)
            raise ValueError(error_msg)

        # Count scores at or above the threshold (default 0.5) as factually consistent
        num_above_threshold = sum(score >= threshold for score in self.scores)
        num_total = len(self.scores)
        if not num_total:
            raise ValueError("No scores available to compute factual consistency rate.")

        self.factual_consistency_rate = (num_above_threshold / num_total) * 100
        self.hallucination_rate = 100 - self.factual_consistency_rate

        return self.factual_consistency_rate
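

# Usage sketch for EvaluationModel (illustrative only; the model path, device, and
# threshold are placeholders/assumptions, not values fixed by this module, and
# summaries_df is assumed to come from SummaryGenerator.generate_summaries()):
#
#     evaluator = EvaluationModel(model_path="<HHEM model path or HF id>", device="cuda")
#     hem_scores, eval_results = evaluator.evaluate_hallucination(summaries_df)
#     fcr = evaluator.compute_factual_consistency_rate(threshold=0.5)
#     print(f"Factual consistency: {fcr:.2f}%  Hallucination: {evaluator.hallucination_rate:.2f}%")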