import gradio as gr
from huggingface_hub import InferenceClient, login, snapshot_download
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
import os
import pandas as pd
from datetime import datetime

"""
For more information on `huggingface_hub` Inference API support, please check the docs:
https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""

# Authenticate against the Hugging Face Hub (the token is read from the Space secret "TOKEN").
HF_TOKEN = os.getenv("TOKEN")
login(HF_TOKEN)

#model = "meta-llama/Llama-3.2-1B-Instruct"
#model = "google/mt5-small"
model = "mistralai/Mistral-7B-Instruct-v0.3"

client = InferenceClient(model)

# Download the prebuilt FAISS index and the case database from the dataset repo.
folder = snapshot_download(repo_id="umaiku/faiss_index", repo_type="dataset", local_dir=os.getcwd())

# The embedding model must match the one the FAISS index was built with.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2")

vector_db = FAISS.load_local("faiss_index_mpnet", embeddings, allow_dangerous_deserialization=True)

# Full case texts; currently only used by the commented-out lookup in the loop below.
df = pd.read_csv("faiss_index/bger_cedh_db 1954-2024.csv")


def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    score,
):
    messages = [{"role": "system", "content": system_message}]

    print(datetime.now())
    print(system_message)

    # `score` is only used by the score-threshold retriever, which is currently disabled.
    # retriever = vector_db.as_retriever(search_type="similarity_score_threshold", search_kwargs={"score_threshold": score, "k": 10})
    retriever = vector_db.as_retriever(search_type="similarity", search_kwargs={"k": 10})
    # retriever = vector_db.as_retriever(search_type="mmr")

    documents = retriever.invoke(message)

    spacer = " \n"

    # Build the context block, one delimited entry per retrieved case extract.
    context = ""

    print(len(documents))

    for doc in documents:
        #case_text = df[df["case_url"] == doc.metadata["case_url"]].case_text.values[0]
        context += "#######" + spacer
        context += "# Case number: " + doc.metadata["case_nb"] + spacer
        context += "# Case source: " + ("Swiss Federal Court" if doc.metadata["case_ref"] == "ATF" else "European Court of Human Rights") + spacer
        context += "# Case date: " + doc.metadata["case_date"] + spacer
        context += "# Case url: " + doc.metadata["case_url"] + spacer
        context += "# Case text: " + doc.page_content + spacer
        #context += "Case text: " + case_text[:8000] + spacer

    prompt = f"""
A user is asking you the following question: {message}

Please answer the user in the same language as the question, using ONLY the context given below and not any prior knowledge or information found on the internet.

# Context:
The following case extracts have been found in either Swiss Federal Court or European Court of Human Rights cases and could fit the question:
{context}

# Task:
If the retrieved context does not contain relevant cases or the issue has not been addressed within the context, just say "I can't find enough relevant information".
Don't make up an answer or give irrelevant information not requested by the user.
Otherwise, if relevant cases were found, answer in the language of the user's question using the context you found relevant, and reference the sources, including the URLs and dates.

# Instructions:
Always answer the user using the language of his question: {message}
"""

    print(prompt)

    # Multi-turn history is currently not forwarded to the model.
#    for val in history:
#        if val[0]:
#            messages.append({"role": "user", "content": val[0]})
#        if val[1]:
#            messages.append({"role": "assistant", "content": val[1]})

    messages.append({"role": "user", "content": prompt})

    response = ""

    # Stream the completion back to the chat interface token by token.
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content
        if token:  # the final streamed chunk can carry an empty delta
            response += token
        yield response


"""
For information on how to customize the ChatInterface, peruse the gradio docs:
https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are an assistant in Swiss Jurisprudence cases.", label="System message"),
        gr.Slider(minimum=1, maximum=24000, value=5000, step=1, label="Max new tokens"),
        # Minimum lowered to 0.0 so the default value of 0 is actually within the slider's range.
        gr.Slider(minimum=0.0, maximum=4.0, value=0.0, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
        gr.Slider(minimum=0, maximum=1, value=0.75, step=0.05, label="Score Threshold"),
    ],
    description="# 📜 ALexI: Artificial Legal Intelligence for Swiss Jurisprudence",
)


if __name__ == "__main__":
    demo.launch(debug=True)
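
# The "Score Threshold" slider feeds `respond`'s `score` argument, but the active
# similarity retriever ignores it. A minimal sketch of wiring it back in, based on
# the commented-out score-threshold retriever above; the helper name `make_retriever`
# and the fallback to plain similarity at score == 0 are illustrative assumptions,
# not behaviour of the original app.
#
# def make_retriever(vector_db, score: float, k: int = 10):
#     if score > 0:
#         # Keep only documents whose similarity score passes the user-chosen threshold.
#         return vector_db.as_retriever(
#             search_type="similarity_score_threshold",
#             search_kwargs={"score_threshold": score, "k": k},
#         )
#     # With the threshold at 0, plain top-k similarity search is equivalent.
#     return vector_db.as_retriever(search_type="similarity", search_kwargs={"k": k})
#
# Inside respond(), `retriever = make_retriever(vector_db, score)` would replace
# the current hard-coded similarity retriever.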