import html
import json
import os
import re
import shutil

import gradio as gr
import lancedb
import pandas as pd
import requests
import torch
import transformers
from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM
# Define the device: run on the GPU when torch can see one, otherwise the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Hugging Face Hub repository of the generation model.
model_name = "PleIAs/Pleias-Rag"

# Get Hugging Face token from the environment; refuse to start without one
# (an unset or empty HF_TOKEN both count as missing).
hf_token = os.getenv("HF_TOKEN")
if hf_token is None or hf_token == "":
    raise ValueError("Please set the HF_TOKEN environment variable")
# Initialize model and tokenizer
# Both are loaded from the Hugging Face Hub repository `model_name`
# (authenticated with hf_token); the model is then moved to `device`, the same
# device CassandreChatBot.predict sends its input ids to.
tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(model_name, token=hf_token)
model.to(device)
# Set tokenizer configuration
# "<|answer_end|>" becomes the end-of-sequence token; its id is what predict
# passes to model.generate as eos_token_id, so generation stops there.
# NOTE(review): assumes "<|answer_end|>" exists in the tokenizer vocabulary —
# confirm, otherwise eos_token_id will not point at it.
tokenizer.eos_token = "<|answer_end|>"
# Module-level copy of the EOS id. NOTE(review): not referenced elsewhere in the
# visible code (predict reads tokenizer.eos_token_id directly).
eos_token_id=tokenizer.eos_token_id
tokenizer.pad_token = tokenizer.eos_token
# NOTE(review): this overrides the line above. In transformers, assigning
# pad_token_id re-derives pad_token from the id, so padding (and the
# pad_token_id handed to generate) is token id 1, not the EOS token.
# Confirm which of the two was intended.
tokenizer.pad_token_id = 1
# Define variables: module-level generation settings. Decoding in
# CassandreChatBot.predict is greedy (do_sample=False), so the sampling knobs
# below do not influence the output.
temperature: float = 0.0  # sampling only; ignored under greedy decoding
top_p: float = 0.95  # sampling only; not passed to model.generate

# Length bounds of the generated answer, in new tokens.
max_new_tokens: int = 1500
min_new_tokens: int = 800

repetition_penalty: float = 1.0  # 1.0 means no penalty
early_stopping: bool = False  # only meaningful for beam search
# Connect to the LanceDB database. The path is relative to the current working
# directory; `table` is the collection queried by hybrid_search().
_LANCEDB_URI = "content 5/lancedb_data"
db = lancedb.connect(_LANCEDB_URI)
table = db.open_table("sciencev4")
def hybrid_search(text, limit=6):
    """Search the LanceDB table and format the hits for the model and the UI.

    Runs ``table.search(text, query_type="hybrid")`` and keeps the first
    ``limit`` rows. Each row must provide the columns ``hash``, ``section``
    and ``text``.

    Args:
        text: The user's query.
        limit: Maximum number of sources to retrieve (default 6).

    Returns:
        ``(document, document_html)``:
        ``document`` is the plain-text source listing inserted into the model
        prompt, one ``**<hash>**\\n<section>\\n<text>`` entry per hit, with
        entries separated by blank lines. ``document_html`` is an HTML
        fragment with one ``<div class="source" id="<hash>">`` card per hit.
        All database text is HTML-escaped there.
    """
    results = table.search(text, query_type="hybrid").limit(limit).to_pandas()
    document = []
    document_html = []
    for _, row in results.iterrows():
        hash_id = str(row['hash'])
        title = row['section']
        content = row['text']
        # Plain-text entry for the prompt: the source id in bold, then title and body.
        document.append(f"**{hash_id}**\n{title}\n{content}")
        # HTML card for the UI. Database content is escaped so it cannot inject
        # markup. The element id makes the card addressable by an in-page link
        # (#<hash>), which the `:target` CSS rule highlights.
        safe_id = html.escape(hash_id)
        document_html.append(
            f'<div class="source" id="{safe_id}">'
            f'<p><b>{safe_id}</b> : {html.escape(str(title))}</p>'
            f'<p>{html.escape(str(content))}</p>'
            '</div>'
        )
    document = "\n\n".join(document)
    document_html = '<div>' + "".join(document_html) + '</div>'
    return document, document_html
class CassandreChatBot:
    """Retrieval-augmented question answering over the LanceDB table.

    predict() retrieves sources with hybrid_search(), builds a
    Query / Source / Analysis prompt and decodes an answer greedily with the
    module-level `model` and `tokenizer`.
    """

    def __init__(self, system_prompt="Tu es Appli, un asistant de recherche qui donne des responses sourcées"):
        # NOTE(review): stored but never used by predict(); the prompt sent to
        # the model is built only from the query and the retrieved sources.
        self.system_prompt = system_prompt

    def predict(self, user_message):
        """Answer *user_message* using retrieved sources.

        Args:
            user_message: The user's question or instruction.

        Returns:
            ``(answer_html, sources_html)``. ``answer_html`` is a heading plus
            a ``<div class="generation">`` with the generated text, whose
            citations are rendered by format_references(). ``sources_html``
            is a heading followed by the source cards from hybrid_search().
            Returns ``(None, None)`` if generation or decoding raises; the
            traceback is printed. Errors from the search or from prompt
            tokenization are not caught and propagate to the caller.
        """
        fiches, fiches_html = hybrid_search(user_message)
        detailed_prompt = f"""### Query ###\n{user_message}\n\n### Source ###\n{fiches}\n\n### Analysis ###\n"""
        # Convert inputs to tensor (a single unpadded sequence, so every
        # position is attended to).
        input_ids = tokenizer.encode(detailed_prompt, return_tensors="pt").to(device)
        attention_mask = torch.ones_like(input_ids)
        prompt_length = input_ids.shape[1]
        try:
            print("Input length:", prompt_length)
            # Greedy decoding. `temperature` is not forwarded: it only applies to
            # sampling, and transformers warns when it is combined with
            # do_sample=False. `output_scores` is not requested: it would keep
            # one vocabulary-sized score tensor per generated step (up to
            # max_new_tokens), and those scores were never read.
            output = model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                early_stopping=early_stopping,
                min_new_tokens=min_new_tokens,
                repetition_penalty=repetition_penalty,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                return_dict_in_generate=True,
            )
            sequence = output.sequences[0]
            print("Output sequence length:", len(sequence))
            print("New tokens generated:", len(sequence) - prompt_length)
            # Decode only the newly generated tokens, not the echoed prompt.
            generated_text = tokenizer.decode(sequence[prompt_length:])
            # The EOS marker would otherwise appear as literal text in the page.
            if tokenizer.eos_token:
                generated_text = generated_text.replace(tokenizer.eos_token, "")
            generated_text = (
                '<h3>Réponse</h3>\n<div class="generation">'
                + format_references(generated_text)
                + "</div>"
            )
            fiches_html = '<h3>Sources</h3>\n' + fiches_html
            return generated_text, fiches_html
        except Exception as e:
            # Top-level UI boundary: report the failure and let the caller
            # decide what to show.
            print(f"Error during generation: {str(e)}")
            import traceback
            traceback.print_exc()
            return None, None
def format_references(text):
    """Render inline citations in generated text as numbered HTML tooltips.

    The model cites sources inline as ``<ref text="QUOTE">[SOURCE_ID]</ref>``.
    Each complete citation becomes ``[n]``, with n counting from 1 in order of
    appearance. It is rendered as an ``<a class="tooltip">`` that links to
    ``#SOURCE_ID`` and carries QUOTE in ``data-text``; the ``.tooltip`` CSS
    rule displays that attribute on hover.

    Text outside citations is HTML-escaped so generated output cannot inject
    markup. If the text ends with an incomplete citation (for example when
    generation hits the token limit), the remainder is kept as escaped plain
    text instead of being dropped.

    Args:
        text: Decoded model output.

    Returns:
        An HTML fragment.
    """
    ref_start_marker = '<ref text="'
    ref_text_end_marker = '">['
    ref_end_marker = '</ref>'
    parts = []
    current_pos = 0
    ref_number = 1
    while True:
        start_pos = text.find(ref_start_marker, current_pos)
        if start_pos == -1:
            break
        end_pos = text.find(ref_text_end_marker, start_pos + len(ref_start_marker))
        if end_pos == -1:
            break
        ref_end_pos = text.find(ref_end_marker, end_pos)
        if ref_end_pos == -1:
            break
        parts.append(html.escape(text[current_pos:start_pos], quote=False))
        # Quote shown in the tooltip, flattened to a single line.
        ref_text = text[start_pos + len(ref_start_marker):end_pos].replace('\n', ' ').strip()
        # The id is written as "[id]"; drop the brackets so it matches the
        # element id of the corresponding source card.
        ref_id = text[end_pos + len(ref_text_end_marker):ref_end_pos].strip().strip('[]').strip()
        # quote=True (the default) also escapes '"', which would otherwise
        # terminate the data-text attribute early.
        parts.append(
            f'<a href="#{html.escape(ref_id)}" class="tooltip" '
            f'data-text="{html.escape(ref_text)}">[{ref_number}]</a>'
        )
        current_pos = ref_end_pos + len(ref_end_marker)
        ref_number += 1
    # Text after the last complete citation, including any unterminated one.
    parts.append(html.escape(text[current_pos:], quote=False))
    return ''.join(parts)
# Initialize the CassandreChatBot
# A single module-level instance, built with the default system prompt; every
# request from the UI goes through it via gradio_interface().
cassandre_bot = CassandreChatBot()
# CSS for styling, passed to gr.Blocks. Rules defined in this string:
# - .generation: horizontal margins, presumably for the answer container.
# - :target: highlights the element whose id matches the URL fragment
#   (presumably a source card reached through a citation link).
# - .source: floats cards side by side, each at most 17% of the row width.
# - .tooltip: grey superscript marker with a pointer cursor.
# - .tooltip:hover::after: on hover, shows the element's `data-text` attribute
#   in a 500px-wide popup below the marker.
css = """
.generation {
margin-left:2em;
margin-right:2em;
}
:target {
background-color: #CCF3DF;
}
.source {
float:left;
max-width:17%;
margin-left:2%;
}
.tooltip {
position: relative;
cursor: pointer;
font-variant-position: super;
color: #97999b;
}
.tooltip:hover::after {
content: attr(data-text);
position: absolute;
left: 0;
top: 120%;
white-space: pre-wrap;
width: 500px;
max-width: 500px;
z-index: 1;
background-color: #f9f9f9;
color: #000;
border: 1px solid #ddd;
border-radius: 5px;
padding: 5px;
display: block;
box-shadow: 0 4px 8px rgba(0,0,0,0.1);
}
"""
# Gradio interface
def gradio_interface(user_message):
    """Handle a click on the question button.

    Args:
        user_message: Contents of the question textbox (may be empty).

    Returns:
        ``(answer_html, sources_html)`` for the two HTML output panels.
        A blank question is rejected with a short notice before any search or
        generation work is done; a failed generation (``predict`` returning
        ``(None, None)``) is reported instead of leaving both panels empty.
    """
    # Without this guard an empty submission would still run a database search
    # and a generation of at least min_new_tokens tokens.
    if not user_message or not user_message.strip():
        return "<p>Veuillez saisir une question.</p>", ""
    response, sources = cassandre_bot.predict(user_message)
    if response is None:
        return "<p>Une erreur est survenue pendant la génération de la réponse.</p>", ""
    return response, sources
# Create Gradio app. Layout: the question box and button on the left, the answer
# on the right, and the retrieved sources in a full-width row underneath.
with gr.Blocks(css=css) as demo:
    # The previous heading literal contained a stray "]" left over from
    # stripped markup, which was rendered to the user verbatim.
    gr.HTML("""<h1 style="text-align:center">Cassandre</h1>""")
    with gr.Row():
        with gr.Column(scale=2):
            text_input = gr.Textbox(label="Votre question ou votre instruction", lines=3)
            text_button = gr.Button("Interroger Cassandre")
        with gr.Column(scale=3):
            text_output = gr.HTML(label="La réponse de Cassandre")
    with gr.Row():
        embedding_output = gr.HTML(label="Les sources utilisées")
    # Event listeners must be registered inside the Blocks context.
    text_button.click(gradio_interface, inputs=text_input, outputs=[text_output, embedding_output])

# Launch the app
if __name__ == "__main__":
    demo.launch()