import pandas as pd
from groq import Groq, RateLimitError
import instructor
from pydantic import BaseModel
import os
# Ensure GROQ_API_KEY is set in your environment variables
api_key = os.getenv('GROQ_API_KEY')
if not api_key:
    raise ValueError("GROQ_API_KEY environment variable not set.")
# Create single patched Groq client with instructor for structured output
# Using Mode.JSON for structured output based on Pydantic models
client = instructor.from_groq(Groq(api_key=api_key), mode=instructor.Mode.JSON)
# Pydantic model for summarization output
class SummaryOutput(BaseModel):
    summary: str
# Pydantic model for classification output
class ClassificationOutput(BaseModel):
    category: str
# Define model names (as per your original code)
PRIMARY_SUMMARIZER_MODEL = "deepseek-r1-distill-llama-70b"
FALLBACK_SUMMARIZER_MODEL = "llama-3.3-70b-versatile"
CLASSIFICATION_MODEL = "meta-llama/llama-4-maverick-17b-128e-instruct" # Or your preferred classification model
# Define the standard list of categories, including "None"
CLASSIFICATION_LABELS = [
    "Company Culture and Values",
    "Employee Stories and Spotlights",
    "Work-Life Balance, Flexibility, and Well-being",
    "Diversity, Equity, and Inclusion (DEI)",
    "Professional Development and Growth Opportunities",
    "Mission, Vision, and Social Responsibility",
    "None"  # Represents no applicable category or cases where classification isn't possible
]
def summarize_post(text: str | None) -> str | None:
    """
    Summarizes the given post text using a primary model with a fallback.
    Returns the summary string or None if summarization fails or input is invalid.
    """
    # Check for NaN, None, or empty/whitespace-only string
    if pd.isna(text) or text is None or not str(text).strip():
        print("Summarizer: Input text is empty or None. Returning None.")
        return None

    # Truncate to the first 500 characters to limit prompt size and cost
    processed_text = str(text)[:500]

    prompt = f"""
    Summarize the following LinkedIn post in 5 to 10 words.
    Only return the summary inside a JSON field called 'summary'.

    Post Text:
    \"\"\"{processed_text}\"\"\"
    """

    try:
        # Attempt with primary model
        print(f"Attempting summarization with primary model: {PRIMARY_SUMMARIZER_MODEL}")
        response = client.chat.completions.create(
            model=PRIMARY_SUMMARIZER_MODEL,
            response_model=SummaryOutput,
            messages=[
                {"role": "system", "content": "You are a precise summarizer. Only return a JSON object with a 'summary' string."},
                {"role": "user", "content": prompt}
            ],
            temperature=0.3
        )
        return response.summary
    except RateLimitError:
        print(f"Rate limit hit for primary summarizer model: {PRIMARY_SUMMARIZER_MODEL}. Trying fallback: {FALLBACK_SUMMARIZER_MODEL}")
        try:
            # Attempt with fallback model
            response = client.chat.completions.create(
                model=FALLBACK_SUMMARIZER_MODEL,
                response_model=SummaryOutput,
                messages=[
                    {"role": "system", "content": "You are a precise summarizer. Only return a JSON object with a 'summary' string."},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.3
            )
            print(f"Summarization successful with fallback model: {FALLBACK_SUMMARIZER_MODEL}")
            return response.summary
        except RateLimitError as rle_fallback:
            print(f"Rate limit hit for fallback summarizer model ({FALLBACK_SUMMARIZER_MODEL}): {rle_fallback}. Summarization failed.")
            return None
        except Exception as e_fallback:
            print(f"Error during summarization with fallback model ({FALLBACK_SUMMARIZER_MODEL}): {e_fallback}")
            return None
    except Exception as e_primary:
        print(f"Error during summarization with primary model ({PRIMARY_SUMMARIZER_MODEL}): {e_primary}")
        # The fallback is only attempted on rate limits; any other error returns None directly.
        return None
def classify_post(summary: str | None, labels: list[str]) -> str:
    """
    Classifies the post summary into one of the provided labels.
    Ensures the returned category is one of the labels, defaulting to "None".
    """
    # If the summary is None (e.g., from a failed summarization or empty input),
    # or if the summary is an empty string after stripping, classify as "None".
    if pd.isna(summary) or summary is None or not str(summary).strip():
        print("Classifier: Input summary is empty or None. Returning 'None' category.")
        return "None"  # Return the string "None" to match the label

    # Join labels for the prompt so the LLM sees the exact expected strings
    labels_string = "', '".join(labels)

    prompt = f"""
    Post Summary: "{summary}"

    Available Categories:
    '{labels_string}'

    Task: Choose the single most relevant category from the list above that applies to this summary.
    Return ONLY ONE category string in a structured JSON format under the field 'category'.
    The category MUST be one of the following: '{labels_string}'.
    If no specific category applies, or if you are unsure, return "None".
    """

    try:
        system_message = (
            f"You are a very strict classifier. Your ONLY job is to return a JSON object "
            f"with a 'category' field. The value of 'category' MUST be one of these "
            f"exact strings: '{labels_string}'."
        )
        result = client.chat.completions.create(
            model=CLASSIFICATION_MODEL,
            response_model=ClassificationOutput,
            messages=[
                {"role": "system", "content": system_message},
                {"role": "user", "content": prompt}
            ],
            temperature=0  # Temperature 0 for deterministic classification
        )
        returned_category = result.category

        # Validate the output against the provided labels
        if returned_category not in labels:
            print(f"Warning: Classifier returned '{returned_category}', which is not in the predefined labels. Forcing to 'None'. Summary: '{summary}'")
            return "None"  # Force to "None" if the LLM returns an unexpected category
        return returned_category
    except Exception as e:
        print(f"Classification error: {e}. Summary: '{summary}'. Defaulting to 'None' category.")
        return "None"  # Default to "None" on any exception during classification
def summarize_and_classify_post(text: str | None, labels: list[str]) -> dict:
    """
    Summarizes and then classifies a single post text.
    Handles cases where text is None or summarization fails.
    """
    summary = summarize_post(text)  # This can return None

    # If summarization didn't produce a result (e.g. empty input, error),
    # or if the summary itself is effectively empty, the category is "None".
    if summary is None or not summary.strip():
        category = "None"
    else:
        # If we have a valid summary, try to classify it.
        # classify_post is designed to return one of the labels or "None".
        category = classify_post(summary, labels)

    return {
        "summary": summary,    # This can be None
        "category": category  # This will be one of the labels or "None"
    }
def batch_summarize_and_classify(posts_data: list[dict]) -> list[dict]:
    """
    Processes a batch of posts, performing summarization and classification for each.
    Expects posts_data to be a list of dictionaries, each with at least 'id' and 'text' keys.
    Returns a list of dictionaries, each with 'id', 'summary', and 'category'.
    """
    results = []
    if not posts_data:
        print("Input 'posts_data' is empty. Returning empty results.")
        return results

    for i, post_item in enumerate(posts_data):
        if not isinstance(post_item, dict):
            print(f"Warning: Item at index {i} is not a dictionary. Skipping.")
            continue

        post_id = post_item.get("id")
        text_to_process = post_item.get("text")  # This text is passed to summarize_and_classify_post

        print(f"\nProcessing Post ID: {post_id if post_id else 'N/A (ID missing)'}, Text: '{str(text_to_process)[:50]}...'")

        # summarize_and_classify_post handles None/empty text internally
        # and ensures the category is "None" in such cases.
        summary_and_category_result = summarize_and_classify_post(text_to_process, CLASSIFICATION_LABELS)

        results.append({
            "id": post_id,  # Include the ID for mapping back to original data
            "summary": summary_and_category_result["summary"],
            "category": summary_and_category_result["category"]  # This is now validated
        })
        print(f"Result for Post ID {post_id}: Summary='{summary_and_category_result['summary']}', Category='{summary_and_category_result['category']}'")

    return results
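

# --- Illustrative usage sketch (an assumption, not part of the original Space code). ---
# It only shows the input shape expected by batch_summarize_and_classify and one way to
# collect the results into a DataFrame; the sample posts below are hypothetical, and
# running it requires a valid GROQ_API_KEY and access to the models configured above.
if __name__ == "__main__":
    sample_posts = [
        {"id": 1, "text": "Proud to share how our mentorship program helps new hires grow into leadership roles."},
        {"id": 2, "text": None},  # Missing text should yield summary=None and category="None"
    ]

    batch_results = batch_summarize_and_classify(sample_posts)

    # Results keep the original 'id', so they can be merged back onto the source data.
    results_df = pd.DataFrame(batch_results)
    print(results_df)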