baconnier committed on
Commit
49175b2
1 Parent(s): 169974c

Update prompt_refiner.py

Browse files
Files changed (1) hide show
  1. prompt_refiner.py +139 -59
prompt_refiner.py CHANGED
@@ -1,28 +1,112 @@
1
  import json
2
  import re
3
- from typing import Optional, Dict, Any, Tuple
4
  from pydantic import BaseModel, Field, validator
5
  from huggingface_hub import InferenceClient
6
  from huggingface_hub.errors import HfHubHTTPError
7
- from variables import meta_prompts, prompt_refiner_model
8
 
9
  class LLMResponse(BaseModel):
10
  initial_prompt_evaluation: str = Field(..., description="Evaluation of the initial prompt")
11
  refined_prompt: str = Field(..., description="The refined version of the prompt")
12
- explanation_of_refinements: str = Field(..., description="Explanation of the refinements made")
13
  response_content: Optional[Dict[str, Any]] = Field(None, description="Raw response content")
14
 
15
- @validator('initial_prompt_evaluation', 'refined_prompt', 'explanation_of_refinements')
16
  def clean_text_fields(cls, v):
17
  if isinstance(v, str):
18
  return v.strip().replace('\\n', '\n').replace('\\"', '"')
19
  return v
20
 
 
 
 
 
 
 
 
 
 
21
  class PromptRefiner:
22
  def __init__(self, api_token: str, meta_prompts: dict):
23
  self.client = InferenceClient(token=api_token, timeout=120)
24
  self.meta_prompts = meta_prompts
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  def refine_prompt(self, prompt: str, meta_prompt_choice: str) -> Tuple[str, str, str, dict]:
27
  """Refine the given prompt using the selected meta prompt."""
28
  try:
@@ -69,60 +153,6 @@ class PromptRefiner:
69
  except Exception as e:
70
  return self._create_error_response(f"Unexpected error: {str(e)}")
71
 
72
- def _parse_response(self, response_content: str) -> dict:
73
- """Parse the LLM response content."""
74
- try:
75
- # Try to extract JSON from <json> tags
76
- json_match = re.search(r'<json>\s*(.*?)\s*</json>', response_content, re.DOTALL)
77
- if json_match:
78
- json_str = json_match.group(1).strip()
79
- # Clean up the JSON string
80
- json_str = re.sub(r'\s+', ' ', json_str)
81
- json_str = json_str.replace('•', '*') # Replace bullet points
82
-
83
- try:
84
- parsed_json = json.loads(json_str)
85
- if isinstance(parsed_json, str):
86
- parsed_json = json.loads(parsed_json)
87
-
88
- return {
89
- "initial_prompt_evaluation": parsed_json.get("initial_prompt_evaluation", ""),
90
- "refined_prompt": parsed_json.get("refined_prompt", ""),
91
- "explanation_of_refinements": parsed_json.get("explanation_of_refinements", ""),
92
- "response_content": parsed_json
93
- }
94
- except json.JSONDecodeError as e:
95
- print(f"JSON parsing error: {e}")
96
- return self._create_error_dict(str(e))
97
-
98
- # Fallback to regex parsing if JSON extraction fails
99
- return self._parse_with_regex(response_content)
100
-
101
- except Exception as e:
102
- print(f"Error parsing response: {e}")
103
- print(f"Raw content: {response_content}")
104
- return self._create_error_dict(str(e))
105
-
106
- def _parse_with_regex(self, content: str) -> dict:
107
- """Parse content using regex patterns when JSON parsing fails."""
108
- output = {}
109
- for key in ["initial_prompt_evaluation", "refined_prompt", "explanation_of_refinements"]:
110
- pattern = rf'"{key}":\s*"(.*?)"(?:,|\}})'
111
- match = re.search(pattern, content, re.DOTALL)
112
- output[key] = match.group(1) if match else ""
113
-
114
- output["response_content"] = content
115
- return output
116
-
117
- def _create_error_dict(self, error_message: str) -> dict:
118
- """Create a standardized error response dictionary."""
119
- return {
120
- "initial_prompt_evaluation": f"Error parsing response: {error_message}",
121
- "refined_prompt": "",
122
- "explanation_of_refinements": "",
123
- "response_content": {"error": error_message}
124
- }
125
-
126
  def _create_error_response(self, error_message: str) -> Tuple[str, str, str, dict]:
127
  """Create a standardized error response tuple."""
128
  return (
@@ -130,4 +160,54 @@ class PromptRefiner:
130
  "The selected model is currently unavailable.",
131
  "An error occurred during processing.",
132
  {"error": error_message}
133
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import json
2
  import re
3
+ from typing import Optional, Dict, Any, Union, List, Tuple
4
  from pydantic import BaseModel, Field, validator
5
  from huggingface_hub import InferenceClient
6
  from huggingface_hub.errors import HfHubHTTPError
7
+ from variables import *
8
 
9
class LLMResponse(BaseModel):
    """Structured result of one prompt-refinement LLM call."""

    initial_prompt_evaluation: str = Field(..., description="Evaluation of the initial prompt")
    refined_prompt: str = Field(..., description="The refined version of the prompt")
    explanation_of_refinements: Union[str, List[str]] = Field(..., description="Explanation of the refinements made")
    response_content: Optional[Dict[str, Any]] = Field(None, description="Raw response content")

    @validator('initial_prompt_evaluation', 'refined_prompt')
    def clean_text_fields(cls, v):
        """Trim whitespace and unescape literal \\n / \\" sequences in text fields."""
        if not isinstance(v, str):
            return v
        return v.strip().replace('\\n', '\n').replace('\\"', '"')

    @validator('explanation_of_refinements')
    def clean_refinements(cls, v):
        """Normalise the refinements explanation, which may arrive as str or list."""
        if isinstance(v, str):
            return v.strip().replace('\\n', '\n').replace('\\"', '"')
        if isinstance(v, list):
            # Non-string entries are dropped, matching the original filter.
            cleaned = []
            for item in v:
                if isinstance(item, str):
                    cleaned.append(
                        item.strip().replace('\\n', '\n').replace('\\"', '"').replace('•', '-')
                    )
            return cleaned
        return v
30
  class PromptRefiner:
31
  def __init__(self, api_token: str, meta_prompts: dict):
32
  self.client = InferenceClient(token=api_token, timeout=120)
33
  self.meta_prompts = meta_prompts
34
 
35
+ def _clean_json_string(self, content: str) -> str:
36
+ """Clean and prepare JSON string for parsing."""
37
+ content = content.replace('•', '-') # Replace bullet points
38
+ content = re.sub(r'\s+', ' ', content) # Normalize whitespace
39
+ content = content.replace('\\"', '"') # Fix escaped quotes
40
+ return content.strip()
41
+
42
+ def _parse_response(self, response_content: str) -> dict:
43
+ """Parse the LLM response with enhanced error handling."""
44
+ try:
45
+ # Extract content between <json> tags
46
+ json_match = re.search(r'<json>\s*(.*?)\s*</json>', response_content, re.DOTALL)
47
+ if json_match:
48
+ json_str = self._clean_json_string(json_match.group(1))
49
+ try:
50
+ # Try parsing the cleaned JSON
51
+ parsed_json = json.loads(json_str)
52
+ if isinstance(parsed_json, str):
53
+ parsed_json = json.loads(parsed_json)
54
+
55
+ return {
56
+ "initial_prompt_evaluation": parsed_json.get("initial_prompt_evaluation", ""),
57
+ "refined_prompt": parsed_json.get("refined_prompt", ""),
58
+ "explanation_of_refinements": parsed_json.get("explanation_of_refinements", ""),
59
+ "response_content": parsed_json
60
+ }
61
+ except json.JSONDecodeError:
62
+ # If JSON parsing fails, try regex parsing
63
+ return self._parse_with_regex(json_str)
64
+
65
+ # If no JSON tags found, try regex parsing
66
+ return self._parse_with_regex(response_content)
67
+
68
+ except Exception as e:
69
+ print(f"Error parsing response: {str(e)}")
70
+ print(f"Raw content: {response_content}")
71
+ return self._create_error_dict(str(e))
72
+
73
+ def _parse_with_regex(self, content: str) -> dict:
74
+ """Parse content using regex when JSON parsing fails."""
75
+ output = {}
76
+
77
+ # Handle explanation_of_refinements list format
78
+ refinements_match = re.search(r'"explanation_of_refinements":\s*\[(.*?)\]', content, re.DOTALL)
79
+ if refinements_match:
80
+ refinements_str = refinements_match.group(1)
81
+ refinements = [
82
+ item.strip().strip('"').strip("'").replace('•', '-')
83
+ for item in re.findall(r'[•"]([^"•]+)[•"]', refinements_str)
84
+ ]
85
+ output["explanation_of_refinements"] = refinements
86
+ else:
87
+ # Try single string format
88
+ pattern = r'"explanation_of_refinements":\s*"(.*?)"(?:,|\})'
89
+ match = re.search(pattern, content, re.DOTALL)
90
+ output["explanation_of_refinements"] = match.group(1).strip() if match else ""
91
+
92
+ # Extract other fields
93
+ for key in ["initial_prompt_evaluation", "refined_prompt"]:
94
+ pattern = rf'"{key}":\s*"(.*?)"(?:,|\}})'
95
+ match = re.search(pattern, content, re.DOTALL)
96
+ output[key] = match.group(1).strip() if match else ""
97
+
98
+ output["response_content"] = content
99
+ return output
100
+
101
+ def _create_error_dict(self, error_message: str) -> dict:
102
+ """Create a standardized error response dictionary."""
103
+ return {
104
+ "initial_prompt_evaluation": f"Error parsing response: {error_message}",
105
+ "refined_prompt": "",
106
+ "explanation_of_refinements": "",
107
+ "response_content": {"error": error_message}
108
+ }
109
+
110
  def refine_prompt(self, prompt: str, meta_prompt_choice: str) -> Tuple[str, str, str, dict]:
111
  """Refine the given prompt using the selected meta prompt."""
112
  try:
 
153
  except Exception as e:
154
  return self._create_error_response(f"Unexpected error: {str(e)}")
155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  def _create_error_response(self, error_message: str) -> Tuple[str, str, str, dict]:
157
  """Create a standardized error response tuple."""
158
  return (
 
160
  "The selected model is currently unavailable.",
161
  "An error occurred during processing.",
162
  {"error": error_message}
163
+ )
164
+
165
+ def apply_prompt(self, prompt: str, model: str) -> str:
166
+ """Apply formatting to the prompt using the specified model."""
167
+ try:
168
+ messages = [
169
+ {
170
+ "role": "system",
171
+ "content": """You are a markdown formatting expert. Format your responses with proper spacing and structure following these rules:
172
+ 1. Paragraph Spacing:
173
+ - Add TWO blank lines between major sections (##)
174
+ - Add ONE blank line between subsections (###)
175
+ - Add ONE blank line between paragraphs within sections
176
+ - Add ONE blank line before and after lists
177
+ - Add ONE blank line before and after code blocks
178
+ - Add ONE blank line before and after blockquotes
179
+
180
+ 2. Section Formatting:
181
+ # Title
182
+
183
+ ## Major Section
184
+
185
+ [blank line]
186
+ Content paragraph 1
187
+ [blank line]
188
+ Content paragraph 2
189
+ [blank line]"""
190
+ },
191
+ {
192
+ "role": "user",
193
+ "content": prompt
194
+ }
195
+ ]
196
+
197
+ response = self.client.chat_completion(
198
+ model=model,
199
+ messages=messages,
200
+ max_tokens=3000,
201
+ temperature=0.8,
202
+ stream=True
203
+ )
204
+
205
+ full_response = ""
206
+ for chunk in response:
207
+ if chunk.choices[0].delta.content is not None:
208
+ full_response += chunk.choices[0].delta.content
209
+
210
+ return full_response.replace('\n\n', '\n').strip()
211
+
212
+ except Exception as e:
213
+ return f"Error: {str(e)}"