Ahmadkhan12 committed
Commit 2cd131f
1 Parent(s): 87ab71f

Update app.py

Files changed (1):
  1. app.py +18 -35
app.py CHANGED
@@ -1,3 +1,4 @@
+import torch
 from transformers import AutoTokenizer, T5ForConditionalGeneration
 from datasets import load_dataset
 import faiss
@@ -5,10 +6,23 @@ import numpy as np
 import streamlit as st
 
 # Load the datasets from Hugging Face
-datasets_dict = {
-    "BillSum": load_dataset("billsum"),
-    "EurLex": load_dataset("eurlex", trust_remote_code=True)  # Set trust_remote_code=True
-}
+datasets_dict = {}
+
+# Function to load datasets safely
+def load_datasets():
+    global datasets_dict
+    try:
+        datasets_dict["BillSum"] = load_dataset("billsum")
+    except Exception as e:
+        st.error(f"Error loading BillSum dataset: {e}")
+
+    try:
+        datasets_dict["EurLex"] = load_dataset("eurlex", trust_remote_code=True)  # Set trust_remote_code=True
+    except Exception as e:
+        st.error(f"Error loading EurLex dataset: {e}")
+
+# Load datasets at the start
+load_datasets()
 
 # Load the T5 model and tokenizer for summarization
 t5_tokenizer = AutoTokenizer.from_pretrained("t5-base")
@@ -26,37 +40,6 @@ def prepare_dataset(dataset_name):
     documents = dataset['train']['text'][:100]  # Use a subset for demo purposes
     titles = dataset['train']['title'][:100]  # Get corresponding titles
 
-prepare_dataset(selected_dataset)
-
-# Function to embed text for retrieval
-def embed_text(text):
-    input_ids = t5_tokenizer.encode(text, return_tensors="pt", max_length=512, truncation=True)
-    with torch.no_grad():
-        outputs = t5_model.encoder(input_ids)
-    return outputs.last_hidden_state.mean(dim=1).numpy()
-
-# Create embeddings for the documents
-doc_embeddings = np.vstack([embed_text(doc) for doc in documents]).astype(np.float32)
-
-# Initialize FAISS index
-index = faiss.IndexFlatL2(doc_embeddings.shape[1])
-index.add(doc_embeddings)
-
-# Define functions for retrieving and summarizing cases
-def retrieve_cases(query, top_k=3):
-    query_embedding = embed_text(query)
-    distances, indices = index.search(query_embedding, top_k)
-    return [(documents[i], titles[i]) for i in indices[0]]  # Return documents and their titles
-
-def summarize_cases(cases):
-    summaries = []
-    for case, _ in cases:
-        input_ids = t5_tokenizer.encode(case, return_tensors="pt", max_length=512, truncation=True)
-        outputs = t5_model.generate(input_ids, max_length=60, min_length=30, length_penalty=2.0, num_beams=4, early_stopping=True)
-        summary = t5_tokenizer.decode(outputs[0], skip_special_tokens=True)
-        summaries.append(summary)
-    return summaries
-
 # Streamlit App Code
 st.title("Legal Case Summarizer")
 st.write("Select a dataset and enter keywords to retrieve and summarize relevant cases.")