LLaMaWhisperer committed on
Commit
102dc72
1 Parent(s): 8ffce3e

MVP of the project: added the base functionality for it to become a simple chatbot

.gitignore ADDED
@@ -0,0 +1 @@
+ /.streamlit/
README.md CHANGED
@@ -1,4 +1,52 @@
- # LegalLLaMa (*WORK IN PROGRESS*)
+ # LegalLLaMa 🦙 (*WORK IN PROGRESS*)
  LegalLLaMa: Your friendly neighborhood lawyer llama, turning legal jargon into a piece of cake!
 
- LegalLLaMA is a chatbot powered by a fine-tuned LLaMa model, providing summaries and insights from U.S. Congressional bills. Bridging the gap between law and AI, one conversation at a time.
+ Legal LLaMa is a chatbot developed to provide summaries of U.S. legislative bills based on user queries. It's built using Hugging Face's Transformers library, and is hosted using Streamlit on Hugging Face Spaces.
+
+ You can interact with the live demo of Legal LLaMa on Hugging Face Spaces [here](https://huggingface.co/spaces/LLaMaWhisperer/legalLLaMa).
+
+ The chatbot uses a frame-based dialog management system to handle conversations, and leverages the ProPublica and Congress APIs to fetch information about legislative bills. The summaries of bills are generated using a state-of-the-art text summarization model.
+
+ ## Features 🎁
+
+ - Frame-based dialog management
+ - Intent recognition and slot filling
+ - Real-time interaction with users
+ - Bill retrieval using ProPublica and Congress APIs
+ - Bill summarization using Transformer models
+
+ ## Future Work 💡
+
+ Legal LLaMa is still a work in progress, and there are plans to make it even more useful and user-friendly. Here are some of the planned improvements:
+
+ - Enhance intent recognition and slot filling using Natural Language Understanding (NLU) models
+ - Expand the chatbot's capabilities to handle more tasks, such as providing summaries of recent bills by a particular congressman
+ - Train a custom summarization model specifically for legislative texts
+
+ ## Getting Started 🚀
+
+ To get the project running on your local machine, follow these steps:
+
+ 1. Clone the repository:
+ ```commandline
+ git clone https://github.com/YuvrajSharma9981/LegalLLaMa.git
+ ```
+ 2. Install the required packages:
+ ```commandline
+ pip install -r requirements.txt
+ ```
+
+ 3. Run the Streamlit app:
+ ```commandline
+ streamlit run app.py
+ ```
+
+ Please note that you will need to obtain API keys from ProPublica and Congress to access their APIs.
+
+ ## Contributing 🤝
+
+ Contributions to improve Legal LLaMa are welcome. Feel free to submit a pull request or create an issue for any bugs, feature requests, or questions about the project.
+
+ ## License 📄
+
+ This project is licensed under the GPL-3.0 License - see the [LICENSE](LICENSE) file for details.
app.py ADDED
@@ -0,0 +1,5 @@
+ from legal_llama.chat_bot_interface import ChatBotInterface
+
+ if __name__ == '__main__':
+     chat_bot = ChatBotInterface()
+     chat_bot.continue_conversation()
legal_llama/__init__.py ADDED
File without changes
legal_llama/bill_retrieval.py ADDED
@@ -0,0 +1,145 @@
+ import requests
+ import streamlit as st
+ import xml.etree.ElementTree as ET
+
+
+ class BillRetriever:
+     """
+     A class used to retrieve bills using the ProPublica Congress API & United States Congress API.
+     """
+     PROPUBLICA_URL = "https://api.propublica.org/congress/v1/bills/search.json"
+     CONGRESS_URL_BASE = "https://api.congress.gov/v3/bill/{congress}/{billType}/{billNumber}/text"
+
+     def __init__(self, api_key=None):
+         """
+         Initialize the BillRetriever with API keys.
+
+         Parameters:
+             api_key (str, optional): The API key to be used for authentication. Default is None.
+         """
+         self.pro_publica_api_key = st.secrets["PRO_PUBLICA_API_KEY"]
+         self.congress_api_key = st.secrets["CONGRESS_API_KEY"]
+
+     def make_api_call(self, api_url, api_key, params=None):
+         """
+         Make an API call to the specified URL with optional parameters and API key.
+
+         Parameters:
+             api_url (str): The URL of the API endpoint.
+             api_key (str): The API Key for the API
+             params (dict, optional): Optional parameters to pass with the API call. Default is None.
+
+         Returns:
+             dict: JSON response data if the request is successful, None otherwise.
+         """
+         headers = {"X-API-Key": api_key} if api_key else {}
+
+         try:
+             response = requests.get(api_url, params=params, headers=headers)
+             response.raise_for_status()  # Raise an exception for non-2xx status codes
+             return response.json()
+         except requests.exceptions.RequestException as e:
+             print(f"Error occurred: {e}")
+             return None
+         except ValueError as e:
+             print(f"Invalid response received: {e}")
+             return None
+
+     def search_bill_propublica(self, query):
+         """
+         Search for a bill using the ProPublica Congress API.
+
+         Parameters:
+             query (str): The query string to search for.
+
+         Returns:
+             dict: JSON response data if the request is successful, None otherwise.
+         """
+         params = {"query": query, "sort": "date", "dir": "desc"}
+         return self.make_api_call(self.PROPUBLICA_URL, params=params, api_key=self.pro_publica_api_key)
+
+     def get_bill_text_congress(self, congress, bill_type, bill_number):
+         """
+         Retrieve the text of a bill using the Congress API.
+
+         Parameters:
+             congress (str): The number of the congress.
+             bill_type (str): The type of the bill.
+             bill_number (str): The number of the bill.
+
+         Returns:
+             dict: JSON response data if the request is successful, None otherwise.
+         """
+         url = self.CONGRESS_URL_BASE.format(congress=congress, billType=bill_type, billNumber=bill_number)
+         return self.make_api_call(url, api_key=self.congress_api_key)
+
+     def get_bill_by_query(self, query):
+         """
+         Search for a bill by query and retrieve its text.
+
+         Parameters:
+             query (str): The query string to search for.
+
+         Returns:
+             str: The text of the bill if the request is successful, None otherwise.
+         """
+         # First search for the bill using the ProPublica API
+         propublica_data = self.search_bill_propublica(query)
+         if propublica_data and 'results' in propublica_data:
+             # Iterate over the list of bills, till we find the bill which has text available on Congress Website
+             for bill_data in propublica_data['results'][0]['bills']:
+                 congress = bill_data['bill_id'].split('-')[1]
+                 bill_type = bill_data['bill_type']
+                 bill_number = bill_data['number'].split('.')[-1]
+
+                 # Then get the text of the bill using the Congress API
+                 congress_data = self.get_bill_text_congress(congress, bill_type, bill_number)
+                 if congress_data and 'textVersions' in congress_data and congress_data['textVersions']:
+                     # Check if textVersions list is not empty
+                     xml_url = congress_data['textVersions'][0]['formats'][2]['url']
+                     return self.extract_bill_text(xml_url)
+         return None
+
+     def extract_bill_text(self, url):
+         """
+         Extract the text content from a bill's XML data.
+
+         Parameters:
+             url (str): The URL of the bill's XML data.
+
+         Returns:
+             str: The text content of the bill.
+         """
+         # Get the XML data from the URL
+         try:
+             xml_data = requests.get(url).content
+         except requests.exceptions.RequestException as e:
+             print(f"Error occurred: {e}")
+             return None
+
+         # Decode bytes to string and parse XML
+         try:
+             root = ET.fromstring(xml_data.decode('utf-8'))
+         except ET.ParseError as e:
+             print(f"Error parsing XML: {e}")
+             return None
+
+         return self.get_all_text(root)
+
+     @staticmethod
+     def get_all_text(element):
+         """
+         Recursively extract text from an XML element and its children.
+
+         Parameters:
+             element (xml.etree.ElementTree.Element): An XML element.
+
+         Returns:
+             str: The concatenated text from the element and its children.
+         """
+         text = element.text or ''  # Get the text of the current element, if it exists
+         for child in element:
+             text += BillRetriever.get_all_text(child)  # Recursively get the text of all child elements
+             if child.tail:
+                 text += child.tail  # Add any trailing text of the child element
+         return text
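
The identifier parsing and the recursive text extraction in `bill_retrieval.py` can be tried without any network access. The following is a minimal sketch, not the project's code: `parse_bill_identifiers`, `SAMPLE_XML`, and the sample `bill_id`/`number` values are made up for illustration and only assume the shapes that the `split` calls in the diff appear to expect (a hyphen before the congress number, a dot before the bill number).

```python
import xml.etree.ElementTree as ET


def parse_bill_identifiers(bill_id, number):
    """Split identifiers the same way the diff does (hypothetical helper)."""
    congress = bill_id.split("-")[1]      # text after the hyphen
    bill_number = number.split(".")[-1]   # text after the last dot
    return congress, bill_number


def get_all_text(element):
    """Concatenate an element's text, its children's text, and each child's tail."""
    text = element.text or ""
    for child in element:
        text += get_all_text(child)
        if child.tail:
            text += child.tail
    return text


# Made-up XML fragment; real bill XML is much larger and structured differently.
SAMPLE_XML = (
    "<bill><title>Clean Air Act</title>"
    "<section>Sec. 1. <b>Short</b> title.</section></bill>"
)

root = ET.fromstring(SAMPLE_XML)
extracted = get_all_text(root)

# The standard library offers the same traversal via Element.itertext().
builtin = "".join(root.itertext())

# Assumed sample values, chosen only to exercise the split logic.
congress, bill_number = parse_bill_identifiers("hr1234-117", "H.R.1234")

print(extracted)
print(congress, bill_number)
```

Because `"".join(root.itertext())` yields the same string as the hand-written recursion here, the static method could arguably be replaced by that one-liner; whether that is desirable for the real bill XML is untested.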
legal_llama/chat_bot_interface.py ADDED
@@ -0,0 +1,83 @@
+ from legal_llama.dialog_management import DialogManager
+ import streamlit as st
+
+
+ class ChatBotInterface:
+     def __init__(self):
+         """Initializes the chatbot interface, sets the page title, and initializes the DialogManager."""
+         # Set up Streamlit page configuration
+         st.set_page_config(page_title="Legal LLaMa 🦙")
+         st.title("Legal LLaMa 🦙")
+
+         # Define roles
+         self.user = "user"
+         self.llama = "Assistant"
+
+         # Initialize the DialogManager for managing conversations
+         self.dialog_manager = DialogManager()
+
+         # Initialize chat history in the session state if it doesn't exist
+         if "messages" not in st.session_state:
+             st.session_state.messages = []
+
+         # Start the conversation with a greeting message
+         first_message = ("Hello there! I'm Legal LLaMa, your friendly guide to the complex world of U.S. legislation."
+                          "\n\nThink of me as a law student who is always eager to learn and share knowledge. Right now, "
+                          "my skills are a bit limited, but I can certainly help you understand the gist of the latest "
+                          "bills proposed in the U.S. Congress. You just have to provide me with a topic - could be "
+                          "climate change, prison reform, healthcare, you name it! I'll then fetch the latest related "
+                          "bill and serve you up a digestible summary.\n\nRemember, being a law student (and a LLaMa, no "
+                          "less!) is tough, so if I miss a step, bear with me. I promise to get better with every "
+                          "interaction. So, what topic intrigues you today?")
+         self.display_message(self.llama, first_message)
+
+     @staticmethod
+     def display_chat_history():
+         """Displays the chat history stored in the session state."""
+         for message in st.session_state.messages:
+             with st.chat_message(message["role"]):
+                 st.markdown(message["content"])
+
+     @staticmethod
+     def add_message_to_history(role, chat):
+         """Adds a message to the chat history in the session state."""
+         st.session_state.messages.append({"role": role, "content": chat})
+
+     @staticmethod
+     def display_message(role, text):
+         """Displays a chat message in the chat interface."""
+         st.chat_message(role).markdown(text)
+
+     def handle_user_input(self, user_input):
+         """Handles user input by recognizing the intent and updating the dialog frame."""
+         # In future, use the IntentRecognizer to check for intent
+         intent = "bill_summarization"
+
+         # Update the dialog frame based on the recognized intent
+         self.dialog_manager.set_frame(intent, user_input)
+
+     def continue_conversation(self):
+         """Continues the conversation by displaying chat history, handling user input, and generating responses."""
+         # Display chat history
+         self.display_chat_history()
+
+         # Handle user input
+         if prompt := st.chat_input("Ask your questions here!"):
+             # Display user message
+             self.display_message(self.user, prompt)
+
+             # Add user message to chat history
+             self.add_message_to_history(self.user, prompt)
+
+             # Handle user input (recognize intent and update frame)
+             self.handle_user_input(prompt)
+
+             with st.spinner('Processing your request...'):
+                 # Generate response based on the current dialog frame
+                 response = self.dialog_manager.generate_response()
+
+             # Display assistant response
+             self.display_message(self.llama, response)
+
+             # Add assistant response to chat history
+             self.add_message_to_history(self.llama, response)
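
Streamlit re-executes the script from top to bottom on each user interaction, which is why the chat history is kept in `st.session_state` and replayed at the start of `continue_conversation`. The sketch below imitates that flow in plain Python so it can run without Streamlit. It is an illustration under assumptions: `fake_session_state` (a dict standing in for `st.session_state`), `run_script`, and the placeholder reply are all hypothetical.

```python
# A plain dict stands in for st.session_state: it survives between script runs.
fake_session_state = {}


def run_script(session_state, user_input=None):
    """Simulate one top-to-bottom run and return the (role, text) pairs rendered."""
    rendered = []

    # Initialize the history only on the very first run.
    if "messages" not in session_state:
        session_state["messages"] = []

    # Replay everything stored so far, as display_chat_history does.
    for message in session_state["messages"]:
        rendered.append((message["role"], message["content"]))

    if user_input:
        # Show and store the user's message.
        rendered.append(("user", user_input))
        session_state["messages"].append({"role": "user", "content": user_input})

        # Placeholder for the real summary produced by the dialog manager.
        reply = f"(summary for: {user_input})"
        rendered.append(("assistant", reply))
        session_state["messages"].append({"role": "assistant", "content": reply})

    return rendered


first_run = run_script(fake_session_state, "climate change")  # user submits a topic
second_run = run_script(fake_session_state)                   # rerun with no new input
```

The second run renders the same two messages purely from stored history, which is the behavior the session-state list is there to provide.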
legal_llama/dialog_management.py ADDED
@@ -0,0 +1,83 @@
+ from legal_llama.bill_retrieval import BillRetriever
+ from legal_llama.summarizer import BillSummarizer
+
+
+ class DialogManager:
+     """
+     A class for managing conversation frames.
+     """
+
+     def __init__(self):
+         """
+         Initialize the DialogManager with predefined frames.
+         """
+         self.frames = {
+             "bill_summarization": {
+                 "intent": "bill_summarization",
+                 "bill_query": None,
+             },
+             # Add more frames here as needed
+         }
+         self.current_frame = None
+
+     def set_frame(self, intent, slot):
+         """
+         Set the current frame based on the recognized intent and provided slot value.
+
+         Parameters:
+             intent (str): The recognized intent.
+             slot (str): The value of the slot provided by the user.
+         """
+         # Update this function in the future to check for intent.
+         self.current_frame = dict(self.frames[intent]) if intent in self.frames else None
+         if self.current_frame is not None:
+             self.update_slot('bill_query', slot)
+         else:
+             print(f"Unrecognized intent: {intent}")
+
+     def update_slot(self, slot_name, slot_value):
+         """
+         Update the value of a slot in the current frame.
+
+         Parameters:
+             slot_name (str): The name of the slot.
+             slot_value (str): The new value of the slot.
+         """
+         if self.current_frame is not None and slot_name in self.current_frame:
+             # If the current frame is set and the slot name exists in the frame, update the slot value
+             self.current_frame[slot_name] = slot_value
+         else:
+             print(f"Cannot update slot '{slot_name}' - no current frame or slot does not exist")
+
+     def generate_response(self):
+         """
+         Generate a response based on the current frame.
+
+         Returns:
+             str: The generated response.
+         """
+         # Check if a frame has been set
+         if self.current_frame is None:
+             print("No frame has been set")
+             return None
+
+         frame = self.current_frame
+         if frame['intent'] == 'bill_summarization':
+             # Extract the bill's text
+             bill_retriever = BillRetriever()
+             bill_text = bill_retriever.get_bill_by_query(frame['bill_query'])
+             if bill_text is None:
+                 print("Unable to retrieve bill text")
+                 return None
+
+             # Summarize the bill's text
+             summarizer = BillSummarizer()
+             summary = summarizer.summarize(bill_text)
+             if summary is None:
+                 print("Unable to summarize bill text")
+                 return None
+
+             return summary
+         else:
+             print(f"Unrecognized frame intent: {frame['intent']}")
+             return None
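
The frame-and-slot idea behind `DialogManager` can be shown in isolation: each intent maps to a template dict whose `None` values are slots still to be filled. The following is a minimal standalone sketch, not the project's implementation; `FRAMES`, `set_frame`, and `is_complete` are hypothetical names. One point it illustrates is that a lookup such as `dict.get(intent, {})` returns an empty dict for an unknown intent, which is not `None`, so an explicit `None` check on the template is the safer way to detect an unrecognized intent.

```python
import copy

# Frame templates keyed by intent; None marks a slot that is not filled yet.
FRAMES = {
    "bill_summarization": {"intent": "bill_summarization", "bill_query": None},
}


def set_frame(intent, slot_value, frames=FRAMES):
    """Return a filled copy of the frame for `intent`, or None if the intent is unknown."""
    template = frames.get(intent)  # no default: unknown intents yield None
    if template is None:
        return None
    frame = copy.deepcopy(template)  # never mutate the shared template
    frame["bill_query"] = slot_value
    return frame


def is_complete(frame):
    """A frame is ready to act on once it exists and every slot has a value."""
    return frame is not None and all(value is not None for value in frame.values())


filled = set_frame("bill_summarization", "prison reform")
unknown = set_frame("weather_report", "Paris")

print(filled)
print(unknown)
```

Copying the template keeps one conversation's slot values from leaking into the next, and the `is_complete` check gives a single place to decide whether the bot should act or ask a follow-up question.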
legal_llama/summarizer.py ADDED
@@ -0,0 +1,54 @@
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+ import streamlit as st
+
+
+ @st.cache_resource
+ def load_model():
+     tokenizers = AutoTokenizer.from_pretrained("nsi319/legal-led-base-16384")
+     model = AutoModelForSeq2SeqLM.from_pretrained("nsi319/legal-led-base-16384")
+     return tokenizers, model
+
+
+ class BillSummarizer:
+     def __init__(self):
+         """
+         Initialize a BillSummarizer, which uses the Hugging Face transformers library to summarize bills.
+         """
+         try:
+             self.tokenizer, self.model = load_model()
+         except Exception as e:
+             print(f"Error initializing summarizer pipeline: {e}")
+
+     def summarize(self, bill_text):
+         """
+         Summarize a bill's text using the summarization pipeline.
+
+         Parameters:
+             bill_text (str): The text of the bill to be summarized.
+
+         Returns:
+             str: The summarized text.
+         """
+         try:
+             input_tokenized = self.tokenizer.encode(bill_text, return_tensors='pt',
+                                                     padding="max_length",
+                                                     pad_to_max_length=True,
+                                                     max_length=6144,
+                                                     truncation=True)
+
+             summary_ids = self.model.generate(input_tokenized,
+                                               num_beams=4,
+                                               no_repeat_ngram_size=3,
+                                               length_penalty=2,
+                                               min_length=350,
+                                               max_length=500)
+
+             summary = [self.tokenizer.decode(g,
+                                              skip_special_tokens=True,
+                                              clean_up_tokenization_spaces=False)
+                        for g in summary_ids][0]
+
+             return summary
+         except Exception as e:
+             print(f"Error summarizing text: {e}")
+             return "Sorry, I couldn't summarize this bill. Please try again."
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ transformers~=4.31.0
+ torch~=2.0.1
+ streamlit~=1.24.1
+ requests~=2.31.0