|
import streamlit as st |
|
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, AutoConfig, AutoModelForSequenceClassification |
|
from langchain_community.llms import HuggingFacePipeline |
|
from langchain.prompts import PromptTemplate |
|
from langchain.chains import LLMChain |
|
from langchain_community.embeddings import HuggingFaceEmbeddings |
|
from PyPDF2 import PdfReader |
|
from docx import Document |
|
import csv |
|
import json |
|
import torch |
|
from langchain_community.vectorstores import FAISS |
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
from huggingface_hub import login |
|
|
|
|
|
# Authenticate against the Hugging Face Hub using the token stored in
# Streamlit secrets (needed to download checkpoints that require an
# authenticated account, such as the Mistral repo).
huggingface_token = st.secrets["HUGGINGFACE_TOKEN"]

login(huggingface_token)
|
|
|
|
|
# Base instruction-tuned LLM shared by the translate / summarize / explain features.
model_name = 'mistralai/Mistral-7B-Instruct-v0.3'

# NOTE(review): model_config is loaded but never referenced below — confirm
# whether it can be removed.
model_config = AutoConfig.from_pretrained(model_name)

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Reuse EOS as the padding token (presumably the tokenizer ships without
# one — TODO confirm) and pad on the right, as expected for causal LMs.
tokenizer.pad_token = tokenizer.eos_token

tokenizer.padding_side = "right"


model = AutoModelForCausalLM.from_pretrained(model_name)
|
|
|
# Shared transformers text-generation pipeline; wrapped by LangChain below.
text_generation_pipeline = pipeline(

    model=model,

    tokenizer=tokenizer,

    task="text-generation",

    # Low temperature + mild repetition penalty for focused, non-repetitive answers.
    temperature=0.2,

    repetition_penalty=1.1,

    # NOTE(review): return_full_text=True makes the pipeline echo the whole
    # prompt at the start of its output, and nothing downstream strips it —
    # confirm this is intended.
    return_full_text=True,

    # Cap on generated tokens per call (prompt tokens not counted).
    max_new_tokens=300,

)
|
|
|
# Mistral-instruct style prompt: retrieved context followed by the user
# question, inside [INST] ... [/INST] tags.
prompt_template = """

### [INST]

Instruction: Answer the question based on your knowledge. Here is context to help:



{context}



### QUESTION:

{question}



[/INST]

"""


# LangChain LLM wrapper around the raw transformers pipeline.
mistral_llm = HuggingFacePipeline(pipeline=text_generation_pipeline)


# Prompt object exposing the {context} and {question} slots that every
# feature below (translate, summarize, explain) fills in.
prompt = PromptTemplate(

    input_variables=["context", "question"],

    template=prompt_template,

)


# NOTE(review): LLMChain is deprecated in newer LangChain releases in favor
# of `prompt | llm` composition — confirm against the pinned version.
llm_chain = LLMChain(llm=mistral_llm, prompt=prompt)
|
|
|
|
|
def handle_uploaded_file(uploaded_file):
    """Extract plain text from an uploaded file.

    Supports ``.txt``, ``.pdf``, ``.docx``, ``.csv`` and ``.json`` uploads,
    dispatching on the file-name extension.

    Args:
        uploaded_file: A Streamlit ``UploadedFile`` — a file-like object
            exposing ``read()`` and a ``name`` attribute.

    Returns:
        str: The extracted text; a fixed Spanish message for unsupported
        extensions; or, on any extraction error, the exception message
        (deliberate best-effort so the UI never crashes).
    """
    try:
        name = uploaded_file.name
        if name.endswith(".txt"):
            text = uploaded_file.read().decode("utf-8")
        elif name.endswith(".pdf"):
            reader = PdfReader(uploaded_file)
            # extract_text() can return None for pages without a text
            # layer; substitute "" so the join never raises TypeError.
            text = "".join(page.extract_text() or "" for page in reader.pages)
        elif name.endswith(".docx"):
            doc = Document(uploaded_file)
            text = "\n".join(para.text for para in doc.paragraphs)
        elif name.endswith(".csv"):
            content = uploaded_file.read().decode("utf-8").splitlines()
            text = " ".join(" ".join(row) for row in csv.reader(content))
        elif name.endswith(".json"):
            text = json.dumps(json.load(uploaded_file), indent=4)
        else:
            text = "Tipo de archivo no soportado."
        return text
    except Exception as e:
        # Best-effort by design: surface the error text in the chat UI
        # instead of crashing the Streamlit app.
        return str(e)
|
|
|
|
|
def translate(text, target_language):
    """Ask the LLM chain to translate *text* into *target_language*.

    The instruction is embedded in the prompt's ``question`` slot; the
    ``context`` slot is left empty.
    """
    request = (
        f"Por favor, traduzca el siguiente documento al {target_language}:\n"
        f"{text}\n"
        "Aseg煤rese de que la traducci贸n sea precisa y conserve el significado original del documento."
    )
    return llm_chain.run(context="", question=request)
|
|
|
|
|
def summarize(text, length):
    """Ask the LLM chain for a summary of *text*.

    *length* is a Spanish phrase describing the desired size (e.g.
    "de aproximadamente 50 palabras") interpolated into the instruction.
    """
    request = (
        f"Por favor, haga un resumen {length} del siguiente documento:\n"
        f"{text}\n"
        "Aseg煤rese de que el resumen sea conciso y conserve el significado original del documento."
    )
    return llm_chain.run(context="", question=request)
|
|
|
|
|
@st.cache_resource
def load_classification_model():
    """Load and cache the Spanish legal-document classifier.

    Returns:
        tuple: ``(model, tokenizer)`` for the
        ``mrm8488/legal-longformer-base-8192-spanish`` checkpoint.
    """
    checkpoint = "mrm8488/legal-longformer-base-8192-spanish"
    clf_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    clf_model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
    return clf_model, clf_tokenizer
|
|
|
# Instantiate the (cached) classifier once at module import.
classification_model, classification_tokenizer = load_classification_model()


# Maps the classifier's output index to a document-category label.
id2label = {0: "multas", 1: "politicas_de_privacidad", 2: "contratos", 3: "denuncias", 4: "otros"}
|
|
|
def classify_text(text):
    """Classify *text* into one of the legal-document categories.

    Runs the longformer classifier (truncating to 4096 tokens) and maps
    the argmax logit to its label via ``id2label``.

    Returns:
        str: One of "multas", "politicas_de_privacidad", "contratos",
        "denuncias" or "otros".
    """
    encoded = classification_tokenizer(
        text,
        return_tensors="pt",
        max_length=4096,
        truncation=True,
        padding="max_length",
    )
    classification_model.eval()
    with torch.no_grad():
        logits = classification_model(**encoded).logits
    predicted = logits.argmax(dim=-1).item()
    return id2label[predicted]
|
|
|
|
|
def load_json_documents(category):
    """Load the reference Q&A documents for a document category.

    Reads ``./<category>.json`` from the working directory and flattens
    each entry of its ``questions_and_answers`` list into a single
    "question answer" string.

    Args:
        category: Category name, used directly as the JSON file stem.

    Returns:
        list[str]: One string per Q&A pair, or ``[]`` when the file is
        missing, is not valid JSON, or lacks the expected keys.
    """
    try:
        with open(f"./{category}.json", "r", encoding="utf-8") as f:
            data = json.load(f)["questions_and_answers"]
        return [entry["question"] + " " + entry["answer"] for entry in data]
    except (FileNotFoundError, json.JSONDecodeError, KeyError):
        # A missing or broken knowledge file should degrade to "no extra
        # context" instead of crashing the Streamlit app (the original
        # only caught FileNotFoundError).
        return []
|
|
|
|
|
@st.cache_resource
def create_vector_store(docs):
    """Build a FAISS vector store over the reference documents.

    Args:
        docs: The reference documents — either a single string or a list
            of strings (``load_json_documents`` returns a list).

    Returns:
        FAISS: Store of overlapping ~1000-character chunks embedded with
        ``all-MiniLM-l6-v2`` on CPU.
    """
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-l6-v2",
        model_kwargs={"device": "cpu"},
    )
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
    # BUG FIX: split_text() expects a single string, but the caller
    # (explain_text) passes the list returned by load_json_documents —
    # split each document separately and flatten the chunks.
    if isinstance(docs, str):
        docs = [docs]
    split_docs = [chunk for doc in docs for chunk in text_splitter.split_text(doc)]
    return FAISS.from_texts(split_docs, embeddings)
|
|
|
def explain_text(user_input, document_text):
    """Answer *user_input* about *document_text*.

    Classifies the document first; when it falls into a category with a
    local knowledge base, retrieves the most similar Q&A chunks and feeds
    them to the LLM as context (otherwise the context is empty).
    """
    known_categories = ("multas", "politicas_de_privacidad", "contratos", "denuncias")
    category = classify_text(document_text)
    context = ""
    if category in known_categories:
        docs = load_json_documents(category)
        if docs:
            store = create_vector_store(docs)
            matches = store.similarity_search(user_input)
            context = " ".join(match.page_content for match in matches)
    return llm_chain.run(context=context, question=user_input)
|
|
|
def main():
    """Streamlit entry point: render the LexAIcon UI and dispatch the
    selected operation (Resumir / Traducir / Explicar)."""
    st.title("LexAIcon")
    st.write("Puedes conversar con este chatbot basado en Mistral-7B-Instruct y subir archivos para que el chatbot los procese.")

    with st.sidebar:
        st.caption("[Consigue un HuggingFace Token](https://huggingface.co/settings/tokens)")

    operation = st.radio("Selecciona una operaci贸n", ["Resumir", "Traducir", "Explicar"])

    # "Explicar" needs both a question and a document; the other two
    # operations only need the document.
    if operation == "Explicar":
        user_input = st.text_area("Introduce tu pregunta:", "")
        uploaded_file = st.file_uploader("Sube un archivo", type=["txt", "pdf", "docx", "csv", "json"])
        if uploaded_file and user_input:
            document_text = handle_uploaded_file(uploaded_file)
            bot_response = explain_text(user_input, document_text)
            st.write(f"**Assistant:** {bot_response}")
        return

    uploaded_file = st.file_uploader("Sube un archivo", type=["txt", "pdf", "docx", "csv", "json"])
    if not uploaded_file:
        return
    document_text = handle_uploaded_file(uploaded_file)

    if operation == "Traducir":
        target_language = st.selectbox("Selecciona el idioma de traducci贸n", ["espa帽ol", "ingl茅s", "franc茅s", "alem谩n"])
        if target_language:
            bot_response = translate(document_text, target_language)
            st.write(f"**Assistant:** {bot_response}")
    elif operation == "Resumir":
        # Map the UI choice to the length phrase interpolated into the prompt.
        length_phrases = {
            "corto": "de aproximadamente 50 palabras",
            "medio": "de aproximadamente 100 palabras",
            "largo": "de aproximadamente 500 palabras",
        }
        summary_length = st.selectbox("Selecciona la longitud del resumen", ["corto", "medio", "largo"])
        if summary_length:
            bot_response = summarize(document_text, length_phrases[summary_length])
            st.write(f"**Assistant:** {bot_response}")
|
|
|
# Standard script entry guard so importing this module does not launch the UI.
if __name__ == "__main__":

    main()