yjernite (HF Staff) committed
Commit e672262 · verified · 1 Parent(s): f93d330

Upload 2 files

Files changed (2)
  1. utils/interface_utils.py +136 -0
  2. utils/llm_utils.py +315 -0
utils/interface_utils.py ADDED
@@ -0,0 +1,136 @@
+ import difflib
+ import html
+ import re
+ from typing import List, Tuple
+
+
+ # --- Helper Function for Markdown Highlighting ---
+ def generate_highlighted_markdown(text, spans_with_info):
+     """Applies highlighting spans with hover info to text for Markdown output."""
+     # Ensure spans are sorted by start index and valid.
+     # Expects spans_with_info to be a list of (start, end, hover_text_string).
+     valid_spans = sorted(
+         [
+             (s, e, info)
+             for s, e, info in spans_with_info  # Unpack the tuple
+             if isinstance(s, int) and isinstance(e, int) and 0 <= s <= e <= len(text)
+         ],
+         key=lambda x: x[0],
+     )
+
+     highlighted_parts = []
+     current_pos = 0
+     # Iterate through sorted spans with info
+     for start, end, hover_text in valid_spans:
+         # Add text before the current span (no HTML escaping)
+         if start > current_pos:
+             highlighted_parts.append(text[current_pos:start])
+         # Add the highlighted span with a title attribute
+         if start < end:
+             # Escape hover text for the title attribute
+             escaped_hover_text = html.escape(hover_text, quote=True)
+             # Escape span content for display
+             escaped_content = html.escape(text[start:end])
+             highlighted_parts.append(
+                 f"<span style='background-color: lightgreen;' title='{escaped_hover_text}'>{escaped_content}</span>"
+             )
+         # Update current position, ensuring it doesn't go backward in case of overlap
+         current_pos = max(current_pos, end)
+
+     # Add any remaining text after the last span (no HTML escaping)
+     if current_pos < len(text):
+         highlighted_parts.append(text[current_pos:])
+
+     return "".join(highlighted_parts)
+
+
+ # --- Citation Span Matching Function ---
+ def find_citation_spans(document: str, citation: str) -> List[Tuple[int, int]]:
+     """
+     Finds character spans in the document that likely form the citation,
+     allowing for fragments and minor differences. Uses SequenceMatcher
+     on alphanumeric words and maps back to character indices.
+     Follows a greedy iterative strategy, repeatedly taking the longest remaining match, to handle cases where citation fragments are reordered.
+
+     Args:
+         document: The source document string.
+         citation: The citation string, potentially with fragments/typos.
+
+     Returns:
+         A list of (start, end) character tuples from the document,
+         representing the most likely origins of the citation fragments.
+     """
+     # 1. Tokenize document and citation into alphanumeric words with char spans
+     doc_tokens = [
+         (m.group(0), m.start(), m.end()) for m in re.finditer(r"[a-zA-Z0-9]+", document)
+     ]
+     cite_tokens = [
+         (m.group(0), m.start(), m.end()) for m in re.finditer(r"[a-zA-Z0-9]+", citation)
+     ]
+     if not doc_tokens or not cite_tokens:
+         return []
+
+     doc_words = [t[0].lower() for t in doc_tokens]
+     cite_words = [t[0].lower() for t in cite_tokens]
+
+     # 2. Find longest common blocks of words using SequenceMatcher
+     matcher = difflib.SequenceMatcher(None, doc_words, cite_words, autojunk=False)
+     matching_blocks = []
+     matched_tokens = 0
+
+     unmatched_doc_words = [(0, len(doc_words))]
+     unmatched_cite_words = [(0, len(cite_words))]
+
+     while matched_tokens < len(cite_words):
+         next_match_candidates = []
+         for da, db in unmatched_doc_words:
+             for ca, cb in unmatched_cite_words:
+                 match = matcher.find_longest_match(da, db, ca, cb)
+                 if match.size > 0:
+                     next_match_candidates.append(match)
+         if len(next_match_candidates) == 0:
+             break
+         next_match = max(next_match_candidates, key=lambda x: x.size)
+         matching_blocks.append(next_match)
+         matched_tokens += next_match.size
+
+         # Update the unmatched regions: drop the matched segment, keeping the
+         # pieces before and after it, and leave non-overlapping segments as-is.
+         new_unmatched_docs = []
+         for da, db in unmatched_doc_words:
+             # Check if this doc segment overlaps with the match
+             if next_match.a < db and next_match.a + next_match.size > da:
+                 # Add segment before the match
+                 if next_match.a > da:
+                     new_unmatched_docs.append((da, next_match.a))
+                 # Add segment after the match
+                 if next_match.a + next_match.size < db:
+                     new_unmatched_docs.append((next_match.a + next_match.size, db))
+             else:
+                 new_unmatched_docs.append((da, db))  # Keep non-overlapping segment
+         unmatched_doc_words = new_unmatched_docs
+
+         new_unmatched_cites = []
+         for ca, cb in unmatched_cite_words:
+             if next_match.b < cb and next_match.b + next_match.size > ca:
+                 if next_match.b > ca:
+                     new_unmatched_cites.append((ca, next_match.b))
+                 if next_match.b + next_match.size < cb:
+                     new_unmatched_cites.append((next_match.b + next_match.size, cb))
+             else:
+                 new_unmatched_cites.append((ca, cb))
+         unmatched_cite_words = new_unmatched_cites
+
+     # 3. Convert matching word blocks back to character spans
+     char_spans = []
+     for i, j, n in sorted(matching_blocks, key=lambda x: x.a):
+         if n == 0:
+             continue
+         start_char = doc_tokens[i][1]
+         end_char = doc_tokens[i + n - 1][2]
+         # Merge spans that touch or overlap the previous one
+         if char_spans and char_spans[-1][1] >= start_char - 1:
+             char_spans[-1] = (char_spans[-1][0], max(char_spans[-1][1], end_char))
+         else:
+             char_spans.append((start_char, end_char))
+
+     return char_spans
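
The two helpers above are designed to compose: `find_citation_spans` maps a possibly fragmented quote back to character offsets in the source passage, and `generate_highlighted_markdown` renders those offsets as hover-annotated `<span>` highlights. Below is a minimal usage sketch, not part of the commit; the document, citation, and hover text are made-up examples.

```python
from utils.interface_utils import find_citation_spans, generate_highlighted_markdown

# Hypothetical inputs, purely for illustration
document = "The model was trained on 10k curated examples. It reached 92% accuracy on the held-out test set."
citation = "trained on 10k curated examples ... 92% accuracy"

# Map the (fragmented) quote back to character offsets in the document
spans = find_citation_spans(document, citation)  # list of (start, end) tuples

# Attach hover text to each span and render the highlighted output
spans_with_info = [(start, end, "Quoted by the LLM for this answer") for start, end in spans]
print(generate_highlighted_markdown(document, spans_with_info))
```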
utils/llm_utils.py ADDED
@@ -0,0 +1,315 @@
+ import json
+ import logging
+ import os
+
+ from huggingface_hub import HfApi, InferenceClient
+
+ import utils.interface_utils as interface_utils
+
+ # Default endpoint URL, used as a fallback when HF_LLM_ENDPOINT_URL is not set
+ DEFAULT_LLM_ENDPOINT_URL = (
+     "https://r5lahjemc2zuajga.us-east-1.aws.endpoints.huggingface.cloud"
+ )
+
+ # Endpoint name, read from the environment with a default fallback
+ LLM_ENDPOINT_NAME = os.getenv("HF_LLM_ENDPOINT_NAME", "phi-4-max")
+
+ RETRIEVAL_SYSTEM_PROMPT = """**Instructions:**
+ You are a helpful assistant presented with document excerpts and a question.
+ Your job is to retrieve the most relevant passages from the provided document excerpt that help answer the question.
+
+ For each passage retrieved from the documents, provide:
+ - a brief summary of the context leading up to the passage (2 sentences max)
+ - the supporting passage quoted exactly
+ - a brief summary of how the points in the passage are relevant to the question (2 sentences max)
+
+ The supporting passages should be a JSON-formatted list of dictionaries with the keys 'context', 'quote', and 'relevance'.
+ Provide up to 4 different supporting passages covering as many different aspects of the topic in question as possible.
+ Only include passages that are relevant to the question. If there are fewer or no relevant passages in the document, just return a shorter or empty list.
+ """
+
+ QA_RETRIEVAL_PROMPT = """Find passages from the following documents that help answer the question.
+
+ **Document Content:**
+ ```markdown
+ {document}
+ ```
+
+ **Question:**
+ {question}
+
+ JSON Output:"""
+
+ ANSWER_SYSTEM_PROMPT = """**Instructions:**
+ You are a helpful assistant presented with a list of snippets extracted from documents and a question.
+ The snippets are presented in a JSON-formatted list that includes a unique id (`id`), context, relevance, and the exact quote.
+ Your job is to answer the question based *only* on the most relevant provided snippet quotes, citing the snippets used for each sentence.
+
+ **Output Format:**
+ Your response *must* be a JSON-formatted list of dictionaries. Each dictionary represents a sentence in your answer and must have the following keys:
+ - `sentence`: A string containing the sentence.
+ - `citations`: A list of integers, where each integer is the `id` of a snippet that supports the sentence.
+
+ **Example Output:**
+ ```json
+ [
+   {
+     "sentence": "This is the first sentence of the answer.",
+     "citations": [1, 3]
+   },
+   {
+     "sentence": "This is the second sentence, supported by another snippet.",
+     "citations": [5]
+   }
+ ]
+ ```
+
+ **Constraints:**
+ - Base your answer *only* on the information within the provided snippets.
+ - Do *not* use external knowledge.
+ - The sentences should flow together coherently.
+ - A single sentence can cite multiple snippets.
+ - The final answer should be no more than 5-6 sentences long.
+ - Ensure the output is valid JSON.
+ """
+
+ ANSWER_PROMPT = """
+ Given the following snippets, answer the question.
+ ```json
+ {snippets}
+ ```
+
+ **Question:**
+ {question}
+
+ JSON Output:"""
+
+ # Initialize client using the token from environment variables
+ client = InferenceClient(token=os.getenv("HF_TOKEN"))
+
+
+ # --- Endpoint Status Check Function ---
+ def check_endpoint_status(token: str | None, endpoint_name: str = LLM_ENDPOINT_NAME):
+     """Checks the Inference Endpoint status and returns a status dict."""
+     # Assumes logging has been configured by the caller (this originally lived in app.py)
+     logging.info(f"Checking endpoint status for '{endpoint_name}'...")
+     if not token:
+         logging.warning("HF Token not available, cannot check endpoint status.")
+         return {
+             "status": "ready",
+             "warning": "HF Token not available for status check.",
+         }
+     try:
+         api = HfApi(token=token)
+         endpoint = api.get_inference_endpoint(name=endpoint_name, token=token)
+         status = endpoint.status
+         logging.info(f"Endpoint '{endpoint_name}' status: {status}")
+         if status == "running":
+             return {"status": "ready"}
+         elif status == "scaledToZero":
+             logging.info(
+                 f"Endpoint '{endpoint_name}' is scaled to zero. Attempting to resume..."
+             )
+             try:
+                 endpoint.resume()
+                 user_message = f"The required LLM endpoint ('{endpoint_name}') was scaled to zero and is **now restarting**. Please wait a few minutes and try submitting your query again."
+                 logging.info(f"Resume command sent for '{endpoint_name}'.")
+                 return {"status": "error", "ui_message": user_message}
+             except Exception as resume_error:
+                 logging.error(
+                     f"Failed to resume endpoint '{endpoint_name}': {resume_error}"
+                 )
+                 user_message = f"The required LLM endpoint ('{endpoint_name}') is scaled to zero. An attempt to automatically resume it failed: {resume_error}. Please check the endpoint status on Hugging Face."
+                 return {"status": "error", "ui_message": user_message}
+         else:
+             user_message = f"The required LLM endpoint ('{endpoint_name}') is currently **{status}**. Analysis cannot proceed until it is running. Please check the endpoint status on Hugging Face."
+             logging.warning(
+                 f"Endpoint '{endpoint_name}' is not ready (Status: {status})."
+             )
+             return {"status": "error", "ui_message": user_message}
+     except Exception as e:
+         error_msg = f"Error checking endpoint status for {endpoint_name}: {e}"
+         logging.error(error_msg)
+         return {
+             "status": "error",
+             "ui_message": f"Failed to check endpoint status. Please verify the endpoint name ('{endpoint_name}') and your token. Error: {e}",
+         }
+
+
+ def retrieve_passages(
+     query, doc_embeds, passages, processed_docs, embed_model, max_docs=3
+ ):
+     """Retrieves relevant passages based on embedding similarity, limited by max_docs."""
+     queries = [query]
+     query_embeddings = embed_model.encode(queries, prompt_name="query")
+     scores = embed_model.similarity(query_embeddings, doc_embeds)
+     sorted_scores = scores.sort(descending=True)
+     sorted_vals = sorted_scores.values[0].tolist()
+     sorted_idx = sorted_scores.indices[0].tolist()
+     results = [
+         {
+             "passage_id": i,
+             "document_id": passages[i][0],
+             "chunk_id": passages[i][1],
+             "document_url": processed_docs[passages[i][0]]["url"],
+             "passage_text": passages[i][2],
+             "relevance": v,
+         }
+         for i, v in zip(sorted_idx, sorted_vals)
+     ]
+     # Keep only the top max_docs passages
+     return results[:max_docs]
+
+
+ # --- Excerpt Processing Function ---
+ def process_single_excerpt(
+     excerpt_index: int, excerpt: dict, query: str, hf_client: InferenceClient
+ ):
+     """Processes a single retrieved excerpt using an LLM to find citations and spans."""
+
+     passage_text = excerpt.get("passage_text", "")
+     if not passage_text:
+         return {
+             "citations": [],
+             "all_spans": [],
+             "parse_successful": False,
+             "raw_error_response": "Empty passage text",
+         }
+
+     citations = []
+     all_spans = []
+     is_parse_successful = False
+     raw_error_response = None
+
+     try:
+         retrieval_prompt = QA_RETRIEVAL_PROMPT.format(
+             document=passage_text, question=query
+         )
+         response = hf_client.chat_completion(
+             messages=[
+                 {"role": "system", "content": RETRIEVAL_SYSTEM_PROMPT},
+                 {"role": "user", "content": retrieval_prompt},
+             ],
+             model=os.getenv("HF_LLM_ENDPOINT_URL", DEFAULT_LLM_ENDPOINT_URL),
+             max_tokens=2048,
+             temperature=0.01,
+         )
+
+         # Attempt to parse the JSON block from the response
+         response_content = response.choices[0].message.content.strip()
+         try:
+             json_match = response_content.split("```json", 1)
+             if len(json_match) > 1:
+                 json_str = json_match[1].split("```", 1)[0]
+                 parsed_json = json.loads(json_str)
+                 citations = parsed_json
+                 is_parse_successful = True
+                 # Find character spans in the passage for each citation quote
+                 for cit in citations:
+                     quote = cit.get("quote", "")
+                     if quote:
+                         spans = interface_utils.find_citation_spans(
+                             document=passage_text, citation=quote
+                         )
+                         cit["char_spans"] = spans  # Store spans in the citation dict
+                         all_spans.extend(spans)
+             else:
+                 raise ValueError("No ```json block found in response")
+         except (json.JSONDecodeError, ValueError, IndexError) as json_e:
+             print(f"Error parsing JSON for excerpt {excerpt_index}: {json_e}")
+             is_parse_successful = False
+             raw_error_response = f"LLM Response (failed to parse): {response_content}"
+
+     except Exception as llm_e:
+         print(f"Error during LLM call for excerpt {excerpt_index}: {llm_e}")
+         is_parse_successful = False
+         raw_error_response = f"LLM API Error: {llm_e}"
+
+     return {
+         "citations": citations,
+         "all_spans": all_spans,
+         "parse_successful": is_parse_successful,
+         "raw_error_response": raw_error_response,
+     }
+
+
+ def generate_summary_answer(snippets: list, query: str, hf_client: InferenceClient):
+     """Generates a summarized answer based on provided snippets using an LLM."""
+     # The endpoint URL is read from the environment, with a default fallback
+     endpoint_url = os.getenv("HF_LLM_ENDPOINT_URL", DEFAULT_LLM_ENDPOINT_URL)
+     if not snippets:
+         return {
+             "answer_sentences": [],
+             "parse_successful": False,
+             "raw_error_response": "No snippets provided for summarization.",
+         }
+
+     try:
+         # Format the snippets as a JSON string for the prompt
+         snippets_json_string = json.dumps(snippets, indent=2)
+
+         answer_prompt_formatted = ANSWER_PROMPT.format(
+             snippets=snippets_json_string, question=query
+         )
+
+         response = hf_client.chat_completion(
+             messages=[
+                 {"role": "system", "content": ANSWER_SYSTEM_PROMPT},
+                 {"role": "user", "content": answer_prompt_formatted},
+             ],
+             model=endpoint_url,
+             max_tokens=512,
+             temperature=0.01,
+         )
+
+         # Attempt to parse the JSON response
+         response_content = response.choices[0].message.content.strip()
+         try:
+             # Find the JSON block (it may be wrapped in ```json ... ```)
+             json_match = response_content.split("```json", 1)
+             if len(json_match) > 1:
+                 json_str = json_match[1].split("```", 1)[0]
+             else:  # Assume the response *is* the JSON if no backticks are found
+                 json_str = response_content
+
+             parsed_json = json.loads(json_str)
+
+             # Basic validation: check that it is a list of dicts with the expected keys
+             if isinstance(parsed_json, list) and all(
+                 isinstance(item, dict) and "sentence" in item and "citations" in item
+                 for item in parsed_json
+             ):
+                 return {
+                     "answer_sentences": parsed_json,
+                     "parse_successful": True,
+                     "raw_error_response": None,
+                 }
+             else:
+                 raise ValueError(
+                     "Parsed JSON does not match expected format (list of {'sentence': ..., 'citations': ...})"
+                 )
+
+         except (json.JSONDecodeError, ValueError, IndexError) as json_e:
+             print(f"Error parsing summary JSON: {json_e}")
+             return {
+                 "answer_sentences": [],
+                 "parse_successful": False,
+                 "raw_error_response": f"LLM Response (failed to parse summary): {response_content}",
+             }
+
+     except Exception as llm_e:
+         print(f"Error during LLM summary call: {llm_e}")
+         return {
+             "answer_sentences": [],
+             "parse_successful": False,
+             "raw_error_response": f"LLM API Error during summary generation: {llm_e}",
+         }
+
+
+ # NOTE: make_supporting_snippets(...) was removed; excerpts are now processed
+ # one by one in app.py via process_single_excerpt.
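
Taken together, `llm_utils` covers a retrieval-then-answer flow: `retrieve_passages` ranks passages by embedding similarity, `process_single_excerpt` asks the LLM for exact quotes and locates them with `find_citation_spans`, and `generate_summary_answer` writes a short answer whose sentences cite snippet ids. The sketch below is illustrative only and not part of the commit: the passage text and query are made up, `HF_TOKEN` and the endpoint URL must be configured in the environment, and the snippet re-packaging with an `id` field follows the format described by ANSWER_SYSTEM_PROMPT rather than any helper in this commit.

```python
import os

from huggingface_hub import InferenceClient

from utils.llm_utils import generate_summary_answer, process_single_excerpt

hf_client = InferenceClient(token=os.getenv("HF_TOKEN"))
query = "What data was the model trained on?"

# One retrieved excerpt, shaped like a retrieve_passages result (text shortened here)
excerpt = {"passage_text": "The model was trained on 10k curated examples collected in 2023 ..."}

# Ask the LLM for exact quotes and their character spans within the passage
result = process_single_excerpt(0, excerpt, query, hf_client)

# Re-package the quotes as numbered snippets for the answer step
snippets = [
    {"id": i, "context": c.get("context"), "quote": c.get("quote"), "relevance": c.get("relevance")}
    for i, c in enumerate(result["citations"])
]

# Generate a short, citation-annotated answer from the snippets
answer = generate_summary_answer(snippets, query, hf_client)
print(answer["answer_sentences"] if answer["parse_successful"] else answer["raw_error_response"])
```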