baconnier committed
Commit 169974c
1 Parent(s): 44f6549

Update prompt_refiner.py

Files changed (1)
  1. prompt_refiner.py +83 -54
prompt_refiner.py CHANGED
@@ -1,88 +1,110 @@
 import json
 import re
-from typing import Optional, Dict, Any, Union
+from typing import Optional, Dict, Any, Tuple
 from pydantic import BaseModel, Field, validator
 from huggingface_hub import InferenceClient
 from huggingface_hub.errors import HfHubHTTPError
-from variables import *
+from variables import meta_prompts, prompt_refiner_model

 class LLMResponse(BaseModel):
     initial_prompt_evaluation: str = Field(..., description="Evaluation of the initial prompt")
     refined_prompt: str = Field(..., description="The refined version of the prompt")
-    explanation_of_refinements: Union[str, list] = Field(..., description="Explanation of the refinements made")
+    explanation_of_refinements: str = Field(..., description="Explanation of the refinements made")
     response_content: Optional[Dict[str, Any]] = Field(None, description="Raw response content")

-    @validator('initial_prompt_evaluation', 'refined_prompt')
+    @validator('initial_prompt_evaluation', 'refined_prompt', 'explanation_of_refinements')
     def clean_text_fields(cls, v):
         if isinstance(v, str):
             return v.strip().replace('\\n', '\n').replace('\\"', '"')
         return v

-    @validator('explanation_of_refinements')
-    def clean_refinements(cls, v):
-        if isinstance(v, str):
-            return v.strip().replace('\\n', '\n').replace('\\"', '"')
-        elif isinstance(v, list):
-            return [item.strip().replace('\\n', '\n').replace('\\"', '"') if isinstance(item, str) else item for item in v]
-        return v
-
 class PromptRefiner:
-    def __init__(self, api_token: str, meta_prompts):
+    def __init__(self, api_token: str, meta_prompts: dict):
         self.client = InferenceClient(token=api_token, timeout=120)
         self.meta_prompts = meta_prompts

-    def _sanitize_json_string(self, json_str: str) -> str:
-        """Clean and prepare JSON string for parsing."""
-        json_str = json_str.lstrip('\ufeff').strip()
-        json_str = json_str.replace('\n', ' ')
-        json_str = re.sub(r'\s+', ' ', json_str)
-        json_str = json_str.replace('•', '*')
-        return json_str
+    def refine_prompt(self, prompt: str, meta_prompt_choice: str) -> Tuple[str, str, str, dict]:
+        """Refine the given prompt using the selected meta prompt."""
+        try:
+            selected_meta_prompt = self.meta_prompts.get(
+                meta_prompt_choice,
+                self.meta_prompts["star"]
+            )
+
+            messages = [
+                {
+                    "role": "system",
+                    "content": 'You are an expert at refining and extending prompts. Given a basic prompt, provide a more relevant and detailed prompt.'
+                },
+                {
+                    "role": "user",
+                    "content": selected_meta_prompt.replace("[Insert initial prompt here]", prompt)
+                }
+            ]
+
+            response = self.client.chat_completion(
+                model=prompt_refiner_model,
+                messages=messages,
+                max_tokens=3000,
+                temperature=0.8
+            )
+
+            response_content = response.choices[0].message.content.strip()
+            result = self._parse_response(response_content)
+
+            try:
+                llm_response = LLMResponse(**result)
+                return (
+                    llm_response.initial_prompt_evaluation,
+                    llm_response.refined_prompt,
+                    llm_response.explanation_of_refinements,
+                    llm_response.dict()
+                )
+            except Exception as e:
+                print(f"Error creating LLMResponse: {e}")
+                return self._create_error_response(f"Error validating response: {str(e)}")

-    def _extract_json_content(self, content: str) -> str:
-        """Extract JSON content from between <json> tags."""
-        json_match = re.search(r'<json>\s*(.*?)\s*</json>', content, re.DOTALL)
-        if json_match:
-            return self._sanitize_json_string(json_match.group(1))
-        return content
+        except HfHubHTTPError as e:
+            return self._create_error_response("Model timeout. Please try again later.")
+        except Exception as e:
+            return self._create_error_response(f"Unexpected error: {str(e)}")

     def _parse_response(self, response_content: str) -> dict:
+        """Parse the LLM response content."""
         try:
-            # First attempt: Try to parse the entire content as JSON
-            cleaned_content = self._sanitize_json_string(response_content)
-            try:
-                parsed_json = json.loads(cleaned_content)
-                if isinstance(parsed_json, str):
-                    parsed_json = json.loads(parsed_json)
-                return self._normalize_json_output(parsed_json)
-            except json.JSONDecodeError:
-                # Second attempt: Try to extract JSON from <json> tags
-                json_content = self._extract_json_content(response_content)
+            # Try to extract JSON from <json> tags
+            json_match = re.search(r'<json>\s*(.*?)\s*</json>', response_content, re.DOTALL)
+            if json_match:
+                json_str = json_match.group(1).strip()
+                # Clean up the JSON string
+                json_str = re.sub(r'\s+', ' ', json_str)
+                json_str = json_str.replace('•', '*')  # Replace bullet points
+
                 try:
-                    parsed_json = json.loads(json_content)
+                    parsed_json = json.loads(json_str)
                     if isinstance(parsed_json, str):
                         parsed_json = json.loads(parsed_json)
-                    return self._normalize_json_output(parsed_json)
-                except json.JSONDecodeError:
-                    # Third attempt: Try to parse using regex
-                    return self._parse_with_regex(response_content)
+
+                    return {
+                        "initial_prompt_evaluation": parsed_json.get("initial_prompt_evaluation", ""),
+                        "refined_prompt": parsed_json.get("refined_prompt", ""),
+                        "explanation_of_refinements": parsed_json.get("explanation_of_refinements", ""),
+                        "response_content": parsed_json
+                    }
+                except json.JSONDecodeError as e:
+                    print(f"JSON parsing error: {e}")
+                    return self._create_error_dict(str(e))
+
+            # Fallback to regex parsing if JSON extraction fails
+            return self._parse_with_regex(response_content)

         except Exception as e:
-            print(f"Error parsing response: {str(e)}")
+            print(f"Error parsing response: {e}")
             print(f"Raw content: {response_content}")
             return self._create_error_dict(str(e))

-    def _normalize_json_output(self, json_output: dict) -> dict:
-        """Normalize JSON output to expected format."""
-        return {
-            "initial_prompt_evaluation": json_output.get("initial_prompt_evaluation", ""),
-            "refined_prompt": json_output.get("refined_prompt", ""),
-            "explanation_of_refinements": json_output.get("explanation_of_refinements", ""),
-            "response_content": json_output
-        }
-
     def _parse_with_regex(self, content: str) -> dict:
-        """Parse content using regex patterns."""
+        """Parse content using regex patterns when JSON parsing fails."""
         output = {}
         for key in ["initial_prompt_evaluation", "refined_prompt", "explanation_of_refinements"]:
             pattern = rf'"{key}":\s*"(.*?)"(?:,|\}})'
@@ -93,7 +115,7 @@ class PromptRefiner:
         return output

     def _create_error_dict(self, error_message: str) -> dict:
-        """Create standardized error response dictionary."""
+        """Create a standardized error response dictionary."""
         return {
             "initial_prompt_evaluation": f"Error parsing response: {error_message}",
             "refined_prompt": "",
@@ -101,4 +123,11 @@ class PromptRefiner:
             "response_content": {"error": error_message}
         }

-    # Rest of your code remains the same...
+    def _create_error_response(self, error_message: str) -> Tuple[str, str, str, dict]:
+        """Create a standardized error response tuple."""
+        return (
+            f"Error: {error_message}",
+            "The selected model is currently unavailable.",
+            "An error occurred during processing.",
+            {"error": error_message}
+        )
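
For context, a minimal usage sketch of the refactored class (not part of the commit). It assumes variables.py exposes the meta_prompts dict (with at least a "star" key) and a prompt_refiner_model identifier, as imported above, and that a valid Hugging Face API token is available in the HF_TOKEN environment variable.

# Usage sketch (not part of the commit). Assumes variables.py defines
# `meta_prompts` (a dict with at least a "star" key) and
# `prompt_refiner_model`, and that HF_TOKEN holds a valid API token.
import os

from prompt_refiner import PromptRefiner
from variables import meta_prompts

refiner = PromptRefiner(api_token=os.environ["HF_TOKEN"], meta_prompts=meta_prompts)

# refine_prompt now returns a fixed 4-tuple; unknown meta prompt
# choices fall back to meta_prompts["star"].
evaluation, refined, refinements, raw = refiner.refine_prompt(
    prompt="Write a poem about the sea.",
    meta_prompt_choice="star",
)
print(refined)

Note that the error paths also return a 4-tuple via _create_error_response, so callers can unpack the result unconditionally.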