|
import pandas as pd |
|
import nltk |
|
nltk.download('punkt') |
|
from nltk.tokenize import sent_tokenize |
|
import chromadb |
|
from chromadb.utils import embedding_functions |
|
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer |
|
|
|
import gradio as gr |
|
|
|
import re |
|
|
|
|
|
|
|
|
|
# Number of leading rows of the CSV that get indexed.
_MAX_INDEXED_EMAILS = 10000

# Rows sent per `collection.add` call.
# NOTE(review): newer Chroma releases appear to reject very large single
# `add` calls; 1000 is a conservative slice size — confirm the actual limit
# for the installed chromadb version.
_ADD_BATCH_SIZE = 1000


def _index_emails(target, frame, limit=_MAX_INDEXED_EMAILS, batch_size=_ADD_BATCH_SIZE):
    """Add the first ``limit`` rows of ``frame`` to ``target`` in batches.

    The "body" column supplies the document text and the "file" column the
    document id; every entry carries the same ``{"source": "enron_emails"}``
    metadata.

    Args:
        target: Chroma collection (anything exposing a compatible ``add``).
        frame: DataFrame with "body" and "file" columns.
        limit: Maximum number of leading rows to index.
        batch_size: Number of rows passed to each ``add`` call.
    """
    subset = frame.head(limit)
    bodies = subset["body"].tolist()
    file_ids = subset["file"].tolist()
    for start in range(0, len(file_ids), batch_size):
        chunk_ids = file_ids[start:start + batch_size]
        target.add(
            documents=bodies[start:start + batch_size],
            ids=chunk_ids,
            # A fresh dict per entry so no metadata object is shared.
            metadatas=[{"source": "enron_emails"} for _ in chunk_ids],
        )


emails = pd.read_csv("./cleaned_data.csv")

# A single in-memory client. The previous code also built a PersistentClient
# at "./content" but rebound `client` before ever using it, so nothing was
# stored on disk; that dead object is dropped here.
client = chromadb.Client()

# get_or_create keeps a re-execution in the same process (e.g. re-running a
# notebook cell) from failing because the collection already exists.
collection = client.get_or_create_collection("enron_emails")

# Only populate an empty collection, so a re-execution does not try to add
# the same ids a second time.
if collection.count() == 0:
    _index_emails(collection, emails)
|
|
|
|
|
|
|
|
|
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer |
|
|
|
# Hugging Face checkpoint used for both the tokenizer and the seq2seq model.
_SUMMARIZER_CHECKPOINT = "varl42/modello42"

tokenizer = AutoTokenizer.from_pretrained(_SUMMARIZER_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(_SUMMARIZER_CHECKPOINT)

# Fetch a handle on the "enron_emails" collection through a Chroma client;
# the query helper below reads the module-level `collection` name.
client = chromadb.Client()
collection = client.get_collection("enron_emails")
|
|
|
|
|
|
|
_NO_DOCUMENTS_MESSAGE = "No documents found or the response structure is not as expected."


def query_collection(query_text: str, n_results: int = 2) -> str:
    """Return the stored documents closest to ``query_text`` as one string.

    The module-level ``collection`` is queried and the matching documents are
    joined with a blank line between them.

    Args:
        query_text: Free-text query to search the collection with.
        n_results: Number of documents to request (defaults to 2).

    Returns:
        The joined documents; a "no documents" message when the query is
        empty/blank or nothing was returned; or an error message when the
        lookup raises. The function itself never raises, because its result
        is displayed directly in the UI.
    """
    try:
        # A blank query has nothing to match; skip the vector search.
        if not query_text or not query_text.strip():
            return _NO_DOCUMENTS_MESSAGE

        response = collection.query(query_texts=[query_text], n_results=n_results)

        # "documents" holds one list per query text. Presumably a query with
        # no hits yields `[[]]` (confirm against the Chroma docs), so the
        # inner list has to be checked rather than the outer one.
        per_query = response.get("documents") or []
        documents = per_query[0] if per_query else []
        if not documents:
            return _NO_DOCUMENTS_MESSAGE

        return "\n\n".join(documents)
    except Exception as e:
        return f"An error occurred while querying: {e}"
|
|
|
|
|
def summarize_documents(text_input: str) -> str:
    """Summarize ``text_input`` with the module-level seq2seq model.

    The text is tokenized (truncated to 512 tokens), a summary is generated
    with beam search, decoded, and then passed through two regex-based
    punctuation fix-ups.

    Args:
        text_input: Text to summarize.

    Returns:
        The post-processed summary; an empty string when ``text_input`` is
        empty or whitespace-only; or an error message when anything in the
        pipeline raises. The function itself never raises, because its
        result is displayed directly in the UI.
    """
    try:
        # Nothing to summarize: avoid forcing the model to emit at least
        # `min_length` tokens for an empty prompt.
        if not text_input or not text_input.strip():
            return ""

        inputs = tokenizer(text_input, return_tensors="pt", truncation=True, max_length=512)

        summary_ids = model.generate(
            inputs["input_ids"],
            # Pass the mask explicitly instead of letting generate() infer it.
            attention_mask=inputs.get("attention_mask"),
            max_length=512,
            min_length=125,
            length_penalty=2.0,
            num_beams=4,
            early_stopping=True,
        )
        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

        # Punctuation fix-ups, kept exactly as originally written.
        # 1) "word?<space>" / "word!<space>" becomes "word?. " / "word!. ".
        summary = re.sub(r"(\w+)([?!])\s", r"\1\2. ", summary)
        # 2) Append "." after any character other than . ? ! that is followed
        #    by whitespace plus an uppercase letter, or that ends the string.
        # NOTE(review): this also fires before capitalized words in the
        # middle of a sentence (e.g. proper names) — confirm this is intended.
        summary = re.sub(r"([^.?!])(?=\s+[A-Z]|$)", r"\1.", summary)

        return summary
    except Exception as e:
        return f"An error occurred while summarizing: {e}"
|
|
|
def query_then_summarize(query_text, _):
    """Run the retrieval step for the "Query" button.

    Looks up ``query_text`` via ``query_collection`` and returns a pair
    ``(retrieved_text, "")``; the empty string is the value for the summary
    output. The second positional argument is accepted but not used. If the
    lookup raises, the first element is an error message instead.
    """
    try:
        retrieved_text = query_collection(query_text)
    except Exception as e:
        return f"An error occurred: {e}", ""
    return retrieved_text, ""
|
|
|
def summarize_from_query(_, query_results):
    """Run the summarization step for the "Summarize" button.

    Passes ``query_results`` to ``summarize_documents`` and returns the pair
    ``(query_results, summary)``, so the results text is handed back
    unchanged. The first positional argument is accepted but not used. If
    summarization raises, the second element is an error message instead.
    """
    try:
        return query_results, summarize_documents(query_results)
    except Exception as e:
        return query_results, f"An error occurred while summarizing: {e}"
|
|
|
|
|
|
|
|
|
|
|
# Gradio UI: a query box with a "Query" button, an editable results box, and a
# "Summarize" button that fills the summary box from the results box.
# NOTE(review): the original indentation was lost, so the exact nesting under
# gr.Row() is assumed (query box + Query button on one row) — confirm layout.
with gr.Blocks() as app:
    with gr.Row():
        query_input = gr.Textbox(label="Enter your query")
        query_button = gr.Button("Query")

    query_results = gr.Text(
        label="Query Results",
        placeholder="Query results will appear here...",
        interactive=True,
    )
    summarize_button = gr.Button("Summarize")
    summary_output = gr.Textbox(label="Summary", placeholder="Summary will appear here...")

    # Query: fill the results box from the query text and blank the summary.
    query_button.click(
        query_then_summarize,
        inputs=[query_input, query_results],
        outputs=[query_results, summary_output],
    )

    # Summarize: the handler ignores its first argument, so feed it the query
    # textbox rather than the Query button component that was wired in before.
    summarize_button.click(
        summarize_from_query,
        inputs=[query_input, query_results],
        outputs=[query_results, summary_output],
    )

app.launch()
|
|
|
|