import transformers import re from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM from vllm import LLM, SamplingParams import torch import gradio as gr import json import os import shutil import requests import chromadb import difflib import pandas as pd from chromadb.config import Settings from chromadb.utils import embedding_functions # Define the device device = "cuda" if torch.cuda.is_available() else "cpu" model_checkpoint = "PleIAs/Estienne" token_classifier = pipeline( "token-classification", model=editorial_model, aggregation_strategy="simple", device=device ) tokenizer = AutoTokenizer.from_pretrained(editorial_model, model_max_length=512) def split_text(text, max_tokens=500): # Split the text by newline characters parts = text.split("\n") chunks = [] current_chunk = "" for part in parts: # Add part to current chunk if current_chunk: temp_chunk = current_chunk + "\n" + part else: temp_chunk = part # Tokenize the temporary chunk num_tokens = len(tokenizer.tokenize(temp_chunk)) if num_tokens <= max_tokens: current_chunk = temp_chunk else: if current_chunk: chunks.append(current_chunk) current_chunk = part if current_chunk: chunks.append(current_chunk) # If no newlines were found and still exceeding max_tokens, split further if len(chunks) == 1 and len(tokenizer.tokenize(chunks[0])) > max_tokens: long_text = chunks[0] chunks = [] while len(tokenizer.tokenize(long_text)) > max_tokens: split_point = len(long_text) // 2 while split_point < len(long_text) and not re.match(r'\s', long_text[split_point]): split_point += 1 # Ensure split_point does not go out of range if split_point >= len(long_text): split_point = len(long_text) - 1 chunks.append(long_text[:split_point].strip()) long_text = long_text[split_point:].strip() if long_text: chunks.append(long_text) return chunks #Curtesy of claude def generate_html_diff(old_text, new_text): d = difflib.Differ() diff = list(d.compare(old_text.split(), new_text.split())) html_diff = [] for word in diff: if word.startswith(' '): html_diff.append(word[2:]) elif word.startswith('+ '): html_diff.append(f'{word[2:]}') # We're not adding anything for words that start with '- ' return ' '.join(html_diff) # Class to encapsulate the Falcon chatbot class MistralChatBot: def __init__(self, system_prompt="Le dialogue suivant est une conversation"): self.system_prompt = system_prompt def predict(self, user_message): #We drop the newlines. editorial_text = re.sub("\n", " ¶ ", user_message) # Tokenize the prompt and check if it exceeds 500 tokens num_tokens = len(tokenizer.tokenize(prompt)) if num_tokens > 500: # Split the prompt into chunks batch_prompts = split_text(prompt, max_tokens=500) else: batch_prompts = [prompt] out = token_classifier(batch_prompts) out = "".join(out) generated_text = '