assesment / app.py
varl42's picture
Update app.py
5c05bec verified
raw
history blame
4.26 kB
import pandas as pd
import nltk
nltk.download('punkt')
from nltk.tokenize import sent_tokenize
import chromadb
from chromadb.utils import embedding_functions
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import gradio as gr
import re
#######################################################
# Load the email dataset. The code below reads its "body" column as the
# document text and its "file" column as the per-document ID.
emails = pd.read_csv("./cleaned_data.csv")
######################################################
# Build the vector index queried by the app.
#
# A single client is used. NOTE(review): chromadb.Client() is expected to be
# in-memory only, so the index is rebuilt on every start -- confirm this is
# intended before switching to a persistent client.
MAX_INDEXED_EMAILS = 10000  # only the first N rows of the CSV are indexed
ADD_BATCH_SIZE = 1000  # keeps each add() call well below Chroma's batch limit

client = chromadb.Client()
# get_or_create avoids an "already exists" error if this script is executed
# again in the same process (e.g. by a reloader).
collection = client.get_or_create_collection("enron_emails")

# Populate only an empty collection so a re-run does not re-add the same IDs.
if collection.count() == 0:
    indexed_rows = emails[:MAX_INDEXED_EMAILS]
    bodies = indexed_rows["body"].tolist()
    file_ids = indexed_rows["file"].tolist()
    # Documents are embedded by ChromaDB's built-in embedding function.
    # Adding in slices avoids the ValueError Chroma raises when one add()
    # exceeds its maximum batch size.
    for start in range(0, len(bodies), ADD_BATCH_SIZE):
        batch_docs = bodies[start:start + ADD_BATCH_SIZE]
        batch_ids = file_ids[start:start + ADD_BATCH_SIZE]
        collection.add(
            documents=batch_docs,
            ids=batch_ids,
            metadatas=[{"source": "enron_emails"}] * len(batch_docs),  # Optional metadata
        )
####################################################
# Summarization model and its tokenizer, both loaded from the same identifier.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Single source of truth for the checkpoint so the two loads cannot drift apart.
MODEL_ID = "varl42/modello42"

# Tokenizer that converts text to model inputs and decodes generated ids.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Sequence-to-sequence model used to generate summaries.
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
##################################################
# Re-bind the module-level client and fetch the "enron_emails" collection
# that the query helper below reads.
# NOTE(review): this relies on chromadb.Client() resolving to the same
# in-memory store that was populated earlier in this script -- confirm.
client = chromadb.Client()
collection = client.get_collection(name="enron_emails")
##################################################
def query_collection(query_text, n_results=2):
    """Return the stored documents closest to ``query_text`` as one string.

    Queries the module-level ``collection`` and joins the returned documents
    with a blank line between them.

    Args:
        query_text: Free-text search query.
        n_results: Number of documents to request (defaults to 2).

    Returns:
        The matching documents separated by blank lines; a fixed
        "no documents" message when the query yields nothing usable; or an
        error message if anything raises. This function never raises, because
        its result is displayed directly in the UI.
    """
    try:
        response = collection.query(query_texts=[query_text], n_results=n_results)
        # The response holds one list of documents per query text; a single
        # query is sent, so only the first list is relevant.
        per_query_documents = response.get("documents") or [[]]
        # An empty result looks like {"documents": [[]]}: the outer list is
        # non-empty, so the inner list must be checked as well. Missing (None)
        # or empty documents are dropped so join() cannot fail on them.
        documents = [doc for doc in (per_query_documents[0] or []) if doc]
        if not documents:
            # Handle cases where no documents are found or the structure is unexpected
            return "No documents found or the response structure is not as expected."
        return "\n\n".join(documents)
    except Exception as e:
        return f"An error occurred while querying: {e}"
def summarize_documents(text_input):
    """Summarize ``text_input`` with the module-level model and tokenizer.

    Args:
        text_input: Text to summarize. Input is truncated to 512 tokens
            before generation.

    Returns:
        The generated summary after punctuation post-processing; a fixed
        message when there is no text to summarize; or an error message if
        tokenization or generation raises. This function never raises,
        because its result is displayed directly in the UI.
    """
    # Nothing to summarize: skip the model entirely. Because generation is
    # called with min_length=125, an empty prompt would otherwise still
    # produce a long, meaningless "summary".
    if not text_input or not text_input.strip():
        return "No text provided to summarize."
    try:
        # Tokenize input text for the model
        inputs = tokenizer(text_input, return_tensors="pt", truncation=True, max_length=512)
        # Beam-search generation of the summary token ids
        summary_ids = model.generate(
            inputs['input_ids'],
            max_length=512,
            min_length=125,
            length_penalty=2.0,
            num_beams=4,
            early_stopping=True,
        )
        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        # After a word ending in "?" or "!" followed by whitespace, replace
        # that whitespace character with ". ".
        summary = re.sub(r"(\w+)([?!])\s", r"\1\2. ", summary)
        # Append "." after any character other than . ? ! that is followed by
        # whitespace plus an uppercase letter, or by the end of the text.
        # NOTE(review): this also inserts a period before capitalized words
        # in the middle of a sentence -- confirm that is intended.
        summary = re.sub(r"([^.?!])(?=\s+[A-Z]|$)", r"\1.", summary)
        return summary
    except Exception as e:
        return f"An error occurred while summarizing: {e}"
def query_then_summarize(query_text, _):
    """Run the document query and clear any previously shown summary.

    Args:
        query_text: Free-text search query passed to ``query_collection``.
        _: Ignored; present only to match the inputs wired to this handler.

    Returns:
        A ``(query_results, summary)`` pair in which the summary is always
        the empty string. If the query raises, the first element is an error
        message instead.
    """
    try:
        retrieved = query_collection(query_text)
    except Exception as e:
        retrieved = f"An error occurred: {e}"
    # The second element is always blank: summarizing is a separate step.
    return retrieved, ""
def summarize_from_query(_, query_results):
    """Summarize the text currently held in the query-results field.

    Args:
        _: Ignored; present only to match the inputs wired to this handler.
        query_results: Text to summarize, passed to ``summarize_documents``.

    Returns:
        A ``(query_results, summary)`` pair. The first element is the input
        returned unchanged; the second is the summary, or an error message if
        summarization raises.
    """
    try:
        summary_text = summarize_documents(query_results)
    except Exception as e:
        summary_text = f"An error occurred while summarizing: {e}"
    return query_results, summary_text
###################################################
# Set up the Gradio interface: a query box with its button, an editable
# results field, and a separate button that summarizes whatever text is
# currently in the results field.
with gr.Blocks() as app:
    with gr.Row():
        query_input = gr.Textbox(label="Enter your query")
        query_button = gr.Button("Query")
    # interactive=True lets the user edit the retrieved text before summarizing.
    query_results = gr.Text(label="Query Results", placeholder="Query results will appear here...", interactive=True)
    summarize_button = gr.Button("Summarize")
    summary_output = gr.Textbox(label="Summary", placeholder="Summary will appear here...")

    # Query: fills the results field and blanks the summary field.
    query_button.click(
        query_then_summarize,
        inputs=[query_input, query_results],
        outputs=[query_results, summary_output],
    )
    # Summarize: the handler ignores its first argument, so the query textbox
    # is passed as a placeholder to keep every listed input a data component
    # rather than a button.
    summarize_button.click(
        summarize_from_query,
        inputs=[query_input, query_results],
        outputs=[query_results, summary_output],
    )

app.launch()