Spaces:
Runtime error
Runtime error
LLaMaWhisperer
committed on
Commit
•
102dc72
1
Parent(s):
8ffce3e
MVP of the project: added the base functionality for it to become a simple chatbot
Browse files- .gitignore +1 -0
- README.md +50 -2
- app.py +5 -0
- legal_llama/__init__.py +0 -0
- legal_llama/bill_retrieval.py +145 -0
- legal_llama/chat_bot_interface.py +83 -0
- legal_llama/dialog_management.py +83 -0
- legal_llama/summarizer.py +54 -0
- requirements.txt +4 -0
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
/.streamlit/
|
README.md
CHANGED
@@ -1,4 +1,52 @@
|
|
1 |
-
# LegalLLaMa (*WORK IN PROGRESS*)
|
2 |
LegalLLaMa: Your friendly neighborhood lawyer llama, turning legal jargon into a piece of cake!
|
3 |
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# LegalLLaMa 🦙 (*WORK IN PROGRESS*)
|
2 |
LegalLLaMa: Your friendly neighborhood lawyer llama, turning legal jargon into a piece of cake!
|
3 |
|
4 |
+
Legal LLaMa is a chatbot developed to provide summaries of U.S. legislative bills based on user queries. It's built using Hugging Face's Transformers library, and is hosted using Streamlit on Hugging Face Spaces.
|
5 |
+
|
6 |
+
You can interact with the live demo of Legal LLaMa on Hugging Face Spaces [here](https://huggingface.co/spaces/LLaMaWhisperer/legalLLaMa).
|
7 |
+
|
8 |
+
The chatbot uses a frame-based dialog management system to handle conversations, and leverages the ProPublica and Congress APIs to fetch information about legislative bills. The summaries of bills are generated using a state-of-the-art text summarization model.
|
9 |
+
|
10 |
+
## Features 🎁
|
11 |
+
|
12 |
+
- Frame-based dialog management
|
13 |
+
- Intent recognition and slot filling
|
14 |
+
- Real-time interaction with users
|
15 |
+
- Bill retrieval using ProPublica and Congress APIs
|
16 |
+
- Bill summarization using Transformer models
|
17 |
+
|
18 |
+
## Future Work 💡
|
19 |
+
|
20 |
+
Legal LLaMa is still a work in progress, and there are plans to make it even more useful and user-friendly. Here are some of the planned improvements:
|
21 |
+
|
22 |
+
- Enhance intent recognition and slot filling using Natural Language Understanding (NLU) models
|
23 |
+
- Expand the chatbot's capabilities to handle more tasks, such as providing summaries of recent bills by a particular congressman
|
24 |
+
- Train a custom summarization model specifically for legislative texts
|
25 |
+
|
26 |
+
## Getting Started 🚀
|
27 |
+
|
28 |
+
To get the project running on your local machine, follow these steps:
|
29 |
+
|
30 |
+
1. Clone the repository:
|
31 |
+
```commandline
|
32 |
+
git clone https://github.com/YuvrajSharma9981/LegalLLaMa.git
|
33 |
+
```
|
34 |
+
2. Install the required packages:
|
35 |
+
```commandline
|
36 |
+
pip install -r requirements.txt
|
37 |
+
```
|
38 |
+
|
39 |
+
3. Run the Streamlit app:
|
40 |
+
```commandline
|
41 |
+
streamlit run app.py
|
42 |
+
```
|
43 |
+
|
44 |
+
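Please note that you will need to obtain API keys from ProPublica and Congress to access their APIs. The app reads them from Streamlit secrets (for example `.streamlit/secrets.toml`) under the names `PRO_PUBLICA_API_KEY` and `CONGRESS_API_KEY`.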
|
45 |
+
|
46 |
+
## Contributing 🤝
|
47 |
+
|
48 |
+
Contributions to improve Legal LLaMa are welcome. Feel free to submit a pull request or create an issue for any bugs, feature requests, or questions about the project.
|
49 |
+
|
50 |
+
## License 📄
|
51 |
+
|
52 |
+
This project is licensed under the GPL-3.0 License - see the [LICENSE](LICENSE) file for details.
|
app.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from legal_llama.chat_bot_interface import ChatBotInterface
|
2 |
+
|
3 |
+
def main():
    """Create the chat interface and hand control to its conversation handler."""
    ChatBotInterface().continue_conversation()


if __name__ == '__main__':
    main()
|
legal_llama/__init__.py
ADDED
File without changes
|
legal_llama/bill_retrieval.py
ADDED
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
import streamlit as st
|
3 |
+
import xml.etree.ElementTree as ET
|
4 |
+
|
5 |
+
|
6 |
+
class BillRetriever:
    """
    A class used to retrieve bills using the ProPublica Congress API & United States Congress API.

    The ProPublica API is used to search for bills matching a query; the Congress API is then
    used to locate the bill's text, which is downloaded as XML and flattened to plain text.
    """
    PROPUBLICA_URL = "https://api.propublica.org/congress/v1/bills/search.json"
    CONGRESS_URL_BASE = "https://api.congress.gov/v3/bill/{congress}/{billType}/{billNumber}/text"
    # Seconds to wait for any HTTP response. Without a timeout, `requests` waits forever on a
    # stalled server and the caller never gets an answer.
    REQUEST_TIMEOUT = 30

    def __init__(self, api_key=None):
        """
        Initialize the BillRetriever with API keys read from Streamlit secrets.

        Parameters:
            api_key (str, optional): Currently unused - both keys are always read from
                `st.secrets`. Kept so existing callers that pass it keep working. Default is None.

        Raises:
            KeyError: If `PRO_PUBLICA_API_KEY` or `CONGRESS_API_KEY` is missing from the secrets.
        """
        self.pro_publica_api_key = st.secrets["PRO_PUBLICA_API_KEY"]
        self.congress_api_key = st.secrets["CONGRESS_API_KEY"]

    def make_api_call(self, api_url, api_key, params=None):
        """
        Make an API call to the specified URL with optional parameters and API key.

        Parameters:
            api_url (str): The URL of the API endpoint.
            api_key (str): The API Key for the API, sent in the `X-API-Key` header when truthy.
            params (dict, optional): Optional parameters to pass with the API call. Default is None.

        Returns:
            dict: JSON response data if the request is successful, None otherwise.
        """
        headers = {"X-API-Key": api_key} if api_key else {}

        try:
            response = requests.get(api_url, params=params, headers=headers,
                                    timeout=self.REQUEST_TIMEOUT)
            response.raise_for_status()  # Raise an exception for non-2xx status codes
            return response.json()
        except requests.exceptions.RequestException as e:
            print(f"Error occurred: {e}")
            return None
        except ValueError as e:
            # Raised when the body is not valid JSON
            print(f"Invalid response received: {e}")
            return None

    def search_bill_propublica(self, query):
        """
        Search for a bill using the ProPublica Congress API.

        Parameters:
            query (str): The query string to search for.

        Returns:
            dict: JSON response data if the request is successful, None otherwise.
        """
        params = {"query": query, "sort": "date", "dir": "desc"}
        return self.make_api_call(self.PROPUBLICA_URL, params=params, api_key=self.pro_publica_api_key)

    def get_bill_text_congress(self, congress, bill_type, bill_number):
        """
        Retrieve the text metadata of a bill using the Congress API.

        Parameters:
            congress (str): The number of the congress.
            bill_type (str): The type of the bill.
            bill_number (str): The number of the bill.

        Returns:
            dict: JSON response data if the request is successful, None otherwise.
        """
        url = self.CONGRESS_URL_BASE.format(congress=congress, billType=bill_type, billNumber=bill_number)
        return self.make_api_call(url, api_key=self.congress_api_key)

    def get_bill_by_query(self, query):
        """
        Search for a bill by query and retrieve its text.

        Bills returned by the search are tried in order; the first one whose text can actually
        be located, downloaded and parsed is returned.

        Parameters:
            query (str): The query string to search for.

        Returns:
            str: The text of the bill if the request is successful, None otherwise.
        """
        # First search for the bill using the ProPublica API
        propublica_data = self.search_bill_propublica(query)
        if not propublica_data or 'results' not in propublica_data:
            return None

        results = propublica_data['results']
        if not results:
            # An empty result list would otherwise raise IndexError on results[0]
            return None

        # Iterate over the list of bills, till we find the bill which has text available on Congress Website
        for bill_data in results[0].get('bills') or []:
            try:
                congress = bill_data['bill_id'].split('-')[1]
                bill_type = bill_data['bill_type']
                bill_number = bill_data['number'].split('.')[-1]
            except (KeyError, IndexError, AttributeError):
                # Malformed search entry - skip it rather than abort the whole lookup
                continue

            # Then get the text of the bill using the Congress API
            congress_data = self.get_bill_text_congress(congress, bill_type, bill_number)
            xml_url = self._find_xml_url(congress_data)
            if xml_url is None:
                continue

            bill_text = self.extract_bill_text(xml_url)
            if bill_text:
                return bill_text
            # Download or parsing failed for this bill; fall through to the next candidate
        return None

    @staticmethod
    def _find_xml_url(congress_data):
        """
        Pick the XML download URL out of a Congress API text response.

        Parameters:
            congress_data (dict or None): The response of `get_bill_text_congress`.

        Returns:
            str: The URL of the XML rendition of the first text version, None if there is none.
        """
        if not congress_data or not congress_data.get('textVersions'):
            return None

        formats = congress_data['textVersions'][0].get('formats') or []
        for fmt in formats:
            url = fmt.get('url')
            if url and url.lower().endswith('.xml'):
                return url

        # Backward-compatible fallback: the third entry is the one this code historically used
        if len(formats) > 2:
            return formats[2].get('url')
        return None

    def extract_bill_text(self, url):
        """
        Extract the text content from a bill's XML data.

        Parameters:
            url (str): The URL of the bill's XML data.

        Returns:
            str: The text content of the bill, None if the download or the XML parsing fails.
        """
        # Get the XML data from the URL
        try:
            response = requests.get(url, timeout=self.REQUEST_TIMEOUT)
            response.raise_for_status()  # Do not try to parse an HTTP error page as a bill
        except requests.exceptions.RequestException as e:
            print(f"Error occurred: {e}")
            return None

        # Parse the raw bytes so the parser honours the encoding declared in the XML prolog
        # instead of assuming UTF-8 (which could raise an uncaught UnicodeDecodeError).
        try:
            root = ET.fromstring(response.content)
        except ET.ParseError as e:
            print(f"Error parsing XML: {e}")
            return None

        return self.get_all_text(root)

    @staticmethod
    def get_all_text(element):
        """
        Extract text from an XML element and all of its descendants, in document order.

        Parameters:
            element (xml.etree.ElementTree.Element): An XML element.

        Returns:
            str: The concatenated text from the element and its children (including the tail
                text that follows each child), or an empty string if there is none.
        """
        # itertext() yields the element's text, each descendant's text and each child's tail
        # in document order; joining once avoids repeated string concatenation.
        return "".join(element.itertext())
|
legal_llama/chat_bot_interface.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from legal_llama.dialog_management import DialogManager
|
2 |
+
import streamlit as st
|
3 |
+
|
4 |
+
|
5 |
+
class ChatBotInterface:
    """Streamlit chat front end that forwards user prompts to a DialogManager."""

    # Greeting shown at the top of the page. Adjacent string literals are concatenated
    # verbatim, so every fragment that continues a sentence must end with a space.
    GREETING = ("Hello there! I'm Legal LLaMa, your friendly guide to the complex world of U.S. legislation."
                "\n\nThink of me as a law student who is always eager to learn and share knowledge. Right now, "
                "my skills are a bit limited, but I can certainly help you understand the gist of the latest "
                "bills proposed in the U.S. Congress. You just have to provide me with a topic - could be "
                "climate change, prison reform, healthcare, you name it! I'll then fetch the latest related "
                "bill and serve you up a digestible summary.\n\nRemember, being a law student (and a LLaMa, no "
                "less!) is tough, so if I miss a step, bear with me. I promise to get better with every "
                "interaction. So, what topic intrigues you today?")

    # Shown when the dialog manager produces no response, so the literal text "None" is
    # never rendered or stored in the chat history.
    FALLBACK_RESPONSE = ("Sorry, I couldn't find or summarize a bill for that request. "
                         "Please try rephrasing it or pick another topic.")

    def __init__(self):
        """Initializes the chatbot interface, sets the page title, and initializes the DialogManager."""
        # Set up Streamlit page configuration
        st.set_page_config(page_title="Legal LLaMa 🦙")
        st.title("Legal LLaMa 🦙")

        # Define roles
        self.user = "user"
        self.llama = "Assistant"

        # Initialize the DialogManager for managing conversations
        self.dialog_manager = DialogManager()

        # Initialize chat history in the session state if it doesn't exist
        if "messages" not in st.session_state:
            st.session_state.messages = []

        # Start the conversation with a greeting message (displayed, but not added to the history)
        self.display_message(self.llama, self.GREETING)

    @staticmethod
    def display_chat_history():
        """Displays the chat history stored in the session state."""
        for message in st.session_state.messages:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

    @staticmethod
    def add_message_to_history(role, chat):
        """Adds a message to the chat history in the session state."""
        st.session_state.messages.append({"role": role, "content": chat})

    @staticmethod
    def display_message(role, text):
        """Displays a chat message in the chat interface."""
        st.chat_message(role).markdown(text)

    def handle_user_input(self, user_input):
        """Handles user input by recognizing the intent and updating the dialog frame."""
        # In future, use the IntentRecognizer to check for intent
        intent = "bill_summarization"

        # Update the dialog frame based on the recognized intent
        self.dialog_manager.set_frame(intent, user_input)

    def continue_conversation(self):
        """Continues the conversation by displaying chat history, handling user input, and generating responses."""
        # Display chat history
        self.display_chat_history()

        # Handle user input
        if prompt := st.chat_input("Ask your questions here!"):
            # Display user message
            self.display_message(self.user, prompt)

            # Add user message to chat history
            self.add_message_to_history(self.user, prompt)

            # Handle user input (recognize intent and update frame)
            self.handle_user_input(prompt)

            with st.spinner('Processing your request...'):
                # Generate response based on the current dialog frame
                response = self.dialog_manager.generate_response()

            # The dialog manager returns None when it has nothing to say
            if response is None:
                response = self.FALLBACK_RESPONSE

            # Display assistant response
            self.display_message(self.llama, response)

            # Add assistant response to chat history
            self.add_message_to_history(self.llama, response)
|
legal_llama/dialog_management.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from legal_llama.bill_retrieval import BillRetriever
|
2 |
+
from legal_llama.summarizer import BillSummarizer
|
3 |
+
|
4 |
+
|
5 |
+
class DialogManager:
    """
    A class for managing conversation frames.

    A frame is a dict holding an `intent` plus the slots that intent needs. `set_frame`
    selects and fills a frame; `generate_response` acts on it.
    """

    def __init__(self):
        """
        Initialize the DialogManager with predefined frames.
        """
        self.frames = {
            "bill_summarization": {
                "intent": "bill_summarization",
                "bill_query": None,
            },
            # Add more frames here as needed
        }
        self.current_frame = None

    def set_frame(self, intent, slot):
        """
        Set the current frame based on the recognized intent and provided slot value.

        For an unknown intent the current frame is reset to None, so that
        `generate_response` reports "no frame" instead of failing on an empty frame.

        Parameters:
            intent (str): The recognized intent.
            slot (str): The value of the slot provided by the user.
        """
        # Update this function in the future to check for intent.
        template = self.frames.get(intent)
        if template is None:
            self.current_frame = None
            print(f"Unrecognized intent: {intent}")
            return

        # Copy so that filling slots never mutates the predefined template
        self.current_frame = template.copy()
        self.update_slot('bill_query', slot)

    def update_slot(self, slot_name, slot_value):
        """
        Update the value of a slot in the current frame.

        Parameters:
            slot_name (str): The name of the slot.
            slot_value (str): The new value of the slot.
        """
        if self.current_frame is not None and slot_name in self.current_frame:
            # If the current frame is set and the slot name exists in the frame, update the slot value
            self.current_frame[slot_name] = slot_value
        else:
            print(f"Cannot update slot '{slot_name}' - no current frame or slot does not exist")

    def generate_response(self):
        """
        Generate a response based on the current frame.

        Returns:
            str: The generated response, or None when no frame is set, the frame's intent is
                not handled, or the bill could not be retrieved or summarized.
        """
        # Check if a frame has been set
        if self.current_frame is None:
            print("No frame has been set")
            return None

        frame = self.current_frame
        intent = frame.get('intent')
        if intent != 'bill_summarization':
            print(f"Unrecognized frame intent: {intent}")
            return None

        # Extract the bill's text
        bill_retriever = BillRetriever()
        bill_text = bill_retriever.get_bill_by_query(frame['bill_query'])
        if bill_text is None:
            print("Unable to retrieve bill text")
            return None

        # Summarize the bill's text
        summarizer = BillSummarizer()
        summary = summarizer.summarize(bill_text)
        if summary is None:
            print("Unable to summarize bill text")
            return None

        return summary
|
legal_llama/summarizer.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
2 |
+
import streamlit as st
|
3 |
+
|
4 |
+
|
5 |
+
@st.cache_resource
def load_model():
    """
    Load the tokenizer and the seq2seq model of the legal summarization checkpoint.

    Decorated with `st.cache_resource` so Streamlit can reuse the loaded objects.

    Returns:
        tuple: `(tokenizer, model)` for the "nsi319/legal-led-base-16384" checkpoint.
    """
    checkpoint = "nsi319/legal-led-base-16384"
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForSeq2SeqLM.from_pretrained(checkpoint),
    )
|
10 |
+
|
11 |
+
|
12 |
+
class BillSummarizer:
    """Summarizes bill texts with a Hugging Face seq2seq model obtained from `load_model`."""

    # Returned to the user whenever a summary cannot be produced
    FAILURE_MESSAGE = "Sorry, I couldn't summarize this bill. Please try again."

    def __init__(self):
        """
        Initialize a BillSummarizer, which uses the Hugging Face transformers library to summarize bills.

        Loading failures are reported but not raised; `summarize` then returns the
        failure message instead of a summary.
        """
        # Define both attributes up front so a failed load leaves the instance in a known state
        self.tokenizer = None
        self.model = None
        try:
            self.tokenizer, self.model = load_model()
        except Exception as e:
            print(f"Error initializing summarizer pipeline: {e}")

    def summarize(self, bill_text):
        """
        Summarize a bill's text using the loaded tokenizer and model.

        Parameters:
            bill_text (str): The text of the bill to be summarized.

        Returns:
            str: The summarized text, or an apology message if the text is empty, the model
                is unavailable, or summarization fails.
        """
        if not bill_text:
            print("Error summarizing text: no bill text provided")
            return self.FAILURE_MESSAGE

        if self.tokenizer is None or self.model is None:
            print("Error summarizing text: summarizer model is not loaded")
            return self.FAILURE_MESSAGE

        try:
            # `padding="max_length"` already pads to `max_length`; the deprecated
            # `pad_to_max_length` flag was redundant and has been dropped.
            input_tokenized = self.tokenizer.encode(bill_text, return_tensors='pt',
                                                    padding="max_length",
                                                    max_length=6144,
                                                    truncation=True)

            summary_ids = self.model.generate(input_tokenized,
                                              num_beams=4,
                                              no_repeat_ngram_size=3,
                                              length_penalty=2,
                                              min_length=350,
                                              max_length=500)

            # A single input is passed in, so only the first generated sequence is decoded
            return self.tokenizer.decode(summary_ids[0],
                                         skip_special_tokens=True,
                                         clean_up_tokenization_spaces=False)
        except Exception as e:
            # Top-level boundary for model errors: report and degrade to a user-facing message
            print(f"Error summarizing text: {e}")
            return self.FAILURE_MESSAGE
|
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
transformers~=4.31.0
|
2 |
+
torch~=2.0.1
|
3 |
+
streamlit~=1.24.1
|
4 |
+
requests~=2.31.0
|