File size: 3,687 Bytes
3c4cefa
 
 
c16cd03
5b0b77b
 
3c4cefa
 
f8318a4
3c4cefa
c16cd03
 
5b0b77b
3c4cefa
c16cd03
 
 
3c4cefa
fe37341
 
3c4cefa
 
fe37341
 
 
 
3c4cefa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import nltk
nltk.download('punkt')
from nltk.tokenize import sent_tokenize
import chromadb
from chromadb.utils import embedding_functions


# Load the email dataset
# emails = pd.read_csv("/content/drive/MyDrive/Clean/cleaned_data.csv")


# Open the on-disk ChromaDB store. Only the persistent client is created: an
# in-memory `chromadb.Client()` bound to the same name would be discarded
# immediately and never used.
# NOTE(review): `path` ends in a file name ("chroma.sqlite3"); PersistentClient
# appears to expect the directory that holds the database files — confirm.
client = chromadb.PersistentClient(path="blob/main/chroma.sqlite3")

# Load the existing "enron_emails" collection from that store
collection = client.get_collection("enron_emails")

# Reference only (not executed): presumably how the "enron_emails" collection
# was originally created and populated.
# client = chromadb.Client()
# collection = client.create_collection("enron_emails")
#
# Add documents and IDs to the collection, using ChromaDB's built-in text encoding
# collection.add(
#    documents=emails["body"].tolist()[:1000],
#    ids=emails["file"].tolist()[:1000],
#    metadatas=[{"source": "enron_emails"}] * len(emails[:1000]),  # Optional metadata
# )



# Hugging Face auto-classes used to load the seq2seq checkpoint
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Model and tokenizer are loaded from the same checkpoint identifier
_T5_CHECKPOINT = "google-t5/t5-small"

# Pretrained sequence-to-sequence model
model = AutoModelForSeq2SeqLM.from_pretrained(_T5_CHECKPOINT)

# Tokenizer matching that model
tokenizer = AutoTokenizer.from_pretrained(_T5_CHECKPOINT)

def query_collection(query_text):
    """Return the documents matching *query_text* as a single string.

    Queries the module-level ``collection`` for two results and joins the
    returned documents with a blank line between them.

    Args:
        query_text: Free-text query to search for.

    Returns:
        The matched documents joined by ``"\\n\\n"``, or a human-readable
        message when the query is blank, no documents came back, or the
        lookup failed. Errors are reported in the returned text rather than
        raised.
    """
    try:
        # A blank query has nothing to search for; skip the lookup entirely.
        if not query_text or not query_text.strip():
            return "Please enter a query."

        # Perform the query
        response = collection.query(query_texts=[query_text], n_results=2)

        # The response holds one list of documents per query text; a single
        # query was sent, so only the first list is relevant. It may be
        # missing, None, or empty when nothing matched.
        result_lists = response.get('documents') or []
        first_result = result_lists[0] if result_lists else None
        documents = [doc for doc in first_result or [] if doc is not None]

        if not documents:
            # Covers both "no hits" and an unexpected response structure.
            return "No documents found or the response structure is not as expected."
        return "\n\n".join(documents)
    except Exception as e:
        return f"An error occurred while querying: {e}"


def summarize_documents(text_input):
    """Summarize *text_input* with the module-level ``model`` and ``tokenizer``.

    Args:
        text_input: Text to condense.

    Returns:
        The decoded summary, an empty string when there is no text to
        summarize, or an error message if tokenization or generation fails.
        Errors are reported in the returned text rather than raised.
    """
    try:
        # Nothing to summarize: skip the model instead of generating from an
        # empty prompt.
        if not text_input or not text_input.strip():
            return ""

        # T5 picks its task from a text prefix; "summarize: " requests
        # summarization. Input is truncated to 512 tokens.
        inputs = tokenizer(
            "summarize: " + text_input,
            return_tensors="pt",
            truncation=True,
            max_length=512,
        )
        # Beam search (4 beams) producing between 40 and 150 tokens; the
        # attention mask is forwarded explicitly alongside the input ids.
        summary_ids = model.generate(
            inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_length=150,
            min_length=40,
            length_penalty=2.0,
            num_beams=4,
            early_stopping=True,
        )
        return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    except Exception as e:
        return f"An error occurred while summarizing: {e}"

def query_then_summarize(query_text, _):
    """Run the document query and hand back an empty summary alongside it.

    Args:
        query_text: Text passed on to ``query_collection``.
        _: Ignored.

    Returns:
        A ``(results_text, summary_text)`` pair. ``summary_text`` is always
        the empty string; ``results_text`` is the query output, or an error
        message if the query call raised.
    """
    cleared_summary = ""
    try:
        results_text = query_collection(query_text)
    except Exception as e:
        results_text = f"An error occurred: {e}"
    return results_text, cleared_summary

def summarize_from_query(_, query_results):
    """Summarize the given query results, returning them with the summary.

    Args:
        _: Ignored.
        query_results: Text passed on to ``summarize_documents``.

    Returns:
        A ``(query_results, summary_text)`` pair. ``query_results`` is passed
        through unchanged; ``summary_text`` is the summary, or an error
        message if the summarization call raised.
    """
    try:
        outcome = summarize_documents(query_results)
    except Exception as e:
        outcome = f"An error occurred while summarizing: {e}"
    return query_results, outcome


# Setup the Gradio interface
# NOTE(review): `gr` is presumably gradio (`import gradio as gr`); that import
# is not visible in this part of the file — confirm it exists.
with gr.Blocks() as app:
    # Top row: query box with its trigger button
    with gr.Row():
        query_input = gr.Textbox(label="Enter your query")
        query_button = gr.Button("Query")
    # Results box is interactive, so its text can be edited before summarizing
    query_results = gr.Text(label="Query Results", placeholder="Query results will appear here...", interactive=True)
    summarize_button = gr.Button("Summarize")
    summary_output = gr.Textbox(label="Summary", placeholder="Summary will appear here...")

    # "Query" fills the results box and clears the summary box.
    query_button.click(query_then_summarize, inputs=[query_input, query_results], outputs=[query_results, summary_output])
    # "Summarize" summarizes whatever is in the results box. Its first input
    # is ignored by the handler; the query textbox is passed (rather than the
    # Query button) so both listeners are wired with the same components.
    summarize_button.click(summarize_from_query, inputs=[query_input, query_results], outputs=[query_results, summary_output])

app.launch()