import nltk
nltk.download('punkt')
from nltk.tokenize import sent_tokenize
import chromadb
from chromadb.utils import embedding_functions
import gradio as gr
# Load the email dataset (only needed when re-running the ingestion block below)
# import pandas as pd
# emails = pd.read_csv("/content/drive/MyDrive/Clean/cleaned_data.csv")
# Open the persisted ChromaDB store; PersistentClient expects the directory
# that contains chroma.sqlite3, not the database file itself
client = chromadb.PersistentClient(path="blob/main")
# Load the ChromaDB collection
collection = client.get_collection("enron_emails")
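# If this script may run before the store has been populated,
# get_or_create_collection avoids a "collection not found" error. A sketch,
# assuming the collection was ingested with ChromaDB's default embedding
# function (the embedding function must match the one used at ingest time):
# collection = client.get_or_create_collection(
#     name="enron_emails",
#     embedding_function=embedding_functions.DefaultEmbeddingFunction(),
# )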
# Create a ChromaDB client
# client = chromadb.Client()
# collection = client.create_collection("enron_emails")
# Add documents and IDs to the collection, using ChromaDB's built-in text encoding
# collection.add(
#     documents=emails["body"].tolist()[:1000],
#     ids=emails["file"].tolist()[:1000],
#     metadatas=[{"source": "enron_emails"}] * len(emails[:1000]),  # Optional metadata
# )
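# After a one-off ingestion like the block above, collection.count() gives a
# quick sanity check (sketch):
# print(collection.count())  # expect 1000 for the slice ingested above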
# Load model directly
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Load the trained model
model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-small")
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
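# Quick smoke test that the checkpoint loads and generates (a sketch; the
# sample string is hypothetical):
# sample = "summarize: The meeting covered quarterly trading positions and risk limits."
# ids = model.generate(**tokenizer(sample, return_tensors="pt"), max_length=40)
# print(tokenizer.decode(ids[0], skip_special_tokens=True))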
def query_collection(query_text):
    try:
        # Perform the query
        response = collection.query(query_texts=[query_text], n_results=2)
        # Extract documents from the response
        if 'documents' in response and len(response['documents']) > 0:
            # Assuming each query only has one set of responses, hence response['documents'][0]
            documents = response['documents'][0]  # This gets the first (and possibly only) list of documents
            return "\n\n".join(documents)
        else:
            # Handle cases where no documents are found or the structure is unexpected
            return "No documents found or the response structure is not as expected."
    except Exception as e:
        return f"An error occurred while querying: {e}"
def summarize_documents(text_input):
    try:
        # T5 checkpoints expect a task prefix for summarization
        text_input = "summarize: " + text_input
        # Tokenize input text for the model
        inputs = tokenizer(text_input, return_tensors="pt", truncation=True, max_length=512)
        # Generate a summary with the model
        summary_ids = model.generate(
            inputs['input_ids'],
            max_length=150,
            min_length=40,
            length_penalty=2.0,
            num_beams=4,
            early_stopping=True,
        )
        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        return summary
    except Exception as e:
        return f"An error occurred while summarizing: {e}"
def query_then_summarize(query_text, _):
    try:
        # Perform the query
        query_results = query_collection(query_text)
        # Return the results with an empty summary initially
        return query_results, ""
    except Exception as e:
        return f"An error occurred: {e}", ""
def summarize_from_query(_, query_results):
    try:
        # Use the query results for summarization
        summary = summarize_documents(query_results)
        return query_results, summary
    except Exception as e:
        return query_results, f"An error occurred while summarizing: {e}"
# Set up the Gradio interface
with gr.Blocks() as app:
    with gr.Row():
        query_input = gr.Textbox(label="Enter your query")
        query_button = gr.Button("Query")
    query_results = gr.Text(label="Query Results", placeholder="Query results will appear here...", interactive=True)
    summarize_button = gr.Button("Summarize")
    summary_output = gr.Textbox(label="Summary", placeholder="Summary will appear here...")
    query_button.click(query_then_summarize, inputs=[query_input, query_results], outputs=[query_results, summary_output])
    # Buttons are not valid input components; pass the query textbox instead
    # (summarize_from_query ignores its first argument anyway)
    summarize_button.click(summarize_from_query, inputs=[query_input, query_results], outputs=[query_results, summary_output])
app.launch()
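# launch() serves the app locally by default; Gradio's share=True option
# creates a temporary public URL instead (useful when running in Colab):
# app.launch(share=True)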