baconnier committed
Commit
51d80c4
1 Parent(s): 0b54c30

Update prompt_refiner.py

Files changed (1)
  1. prompt_refiner.py +54 -34
prompt_refiner.py CHANGED
@@ -1,11 +1,25 @@
 import json
 import re
+from typing import Optional, Dict, Any
+from pydantic import BaseModel, Field, validator
 from huggingface_hub import InferenceClient
 from huggingface_hub.errors import HfHubHTTPError
-from variables import meta_prompts, prompt_refiner_model
+from variables import *
+
+class LLMResponse(BaseModel):
+    initial_prompt_evaluation: str = Field(..., description="Evaluation of the initial prompt")
+    refined_prompt: str = Field(..., description="The refined version of the prompt")
+    explanation_of_refinements: str = Field(..., description="Explanation of the refinements made")
+    response_content: Optional[Dict[str, Any]] = Field(None, description="Raw response content")
+
+    @validator('initial_prompt_evaluation', 'refined_prompt', 'explanation_of_refinements')
+    def clean_text_fields(cls, v):
+        if isinstance(v, str):
+            return v.strip().replace('\\n', '\n').replace('\\"', '"')
+        return v
 
 class PromptRefiner:
-    def __init__(self, api_token: str,meta_prompts):
+    def __init__(self, api_token: str, meta_prompts):
         self.client = InferenceClient(token=api_token, timeout=120)
         self.meta_prompts = meta_prompts
 
@@ -13,11 +27,9 @@ class PromptRefiner:
         try:
             selected_meta_prompt = self.meta_prompts.get(
                 meta_prompt_choice,
-                self.meta_prompts["star"]  # Default to "star" if choice not found
+                self.meta_prompts["star"]
             )
-            #print('_'*100)
-            #print(selected_meta_prompt)
-            #print('°°'*100)
+
             messages = [
                 {
                     "role": "system",
@@ -28,8 +40,7 @@
                     "content": selected_meta_prompt.replace("[Insert initial prompt here]", prompt)
                 }
             ]
-            #print(messages)
-            #print('°°'*100)
+
             response = self.client.chat_completion(
                 model=prompt_refiner_model,
                 messages=messages,
@@ -38,33 +49,40 @@
             )
 
             response_content = response.choices[0].message.content.strip()
-
             result = self._parse_response(response_content)
 
+            # Create and validate LLMResponse
+            llm_response = LLMResponse(**result)
+
             return (
-                result.get('initial_prompt_evaluation', ''),
-                result.get('refined_prompt', ''),
-                result.get('explanation_of_refinements', ''),
-                result
+                llm_response.initial_prompt_evaluation,
+                llm_response.refined_prompt,
+                llm_response.explanation_of_refinements,
+                llm_response.dict()
             )
 
         except HfHubHTTPError as e:
-            return (
-                "Error: Model timeout. Please try again later.",
-                "The selected model is currently experiencing high traffic.",
-                "The selected model is currently experiencing high traffic.",
-                {}
-            )
+            return self._create_error_response("Model timeout. Please try again later.")
         except Exception as e:
-            return (
-                f"Error: {str(e)}",
-                "",
-                "An unexpected error occurred.",
-                {}
-            )
+            return self._create_error_response(f"Unexpected error: {str(e)}")
+
+    def _create_error_response(self, error_message: str) -> tuple:
+        error_response = LLMResponse(
+            initial_prompt_evaluation=f"Error: {error_message}",
+            refined_prompt="The selected model is currently unavailable.",
+            explanation_of_refinements="An error occurred during processing.",
+            response_content={"error": error_message}
+        )
+        return (
+            error_response.initial_prompt_evaluation,
+            error_response.refined_prompt,
+            error_response.explanation_of_refinements,
+            error_response.dict()
+        )
 
     def _parse_response(self, response_content: str) -> dict:
         try:
+            # First attempt: Try to extract JSON from <json> tags
             json_match = re.search(r'<json>\s*(.*?)\s*</json>', response_content, re.DOTALL)
             if json_match:
                 json_str = json_match.group(1)
@@ -74,19 +92,22 @@
 
                 if isinstance(json_output, str):
                     json_output = json.loads(json_output)
-                output = {
-                    key: value.replace('\\"', '"') if isinstance(value, str) else value
-                    for key, value in json_output.items()
+
+                return {
+                    "initial_prompt_evaluation": json_output.get("initial_prompt_evaluation", ""),
+                    "refined_prompt": json_output.get("refined_prompt", ""),
+                    "explanation_of_refinements": json_output.get("explanation_of_refinements", ""),
+                    "response_content": json_output
                 }
-                output['response_content'] = json_output
-                return output
 
+            # Second attempt: Try to extract fields using regex
             output = {}
             for key in ["initial_prompt_evaluation", "refined_prompt", "explanation_of_refinements"]:
                 pattern = rf'"{key}":\s*"(.*?)"(?:,|\}})'
                 match = re.search(pattern, response_content, re.DOTALL)
-                output[key] = match.group(1).replace('\\n', '\n').replace('\\"', '"') if match else ""
-                output['response_content'] = response_content
+                output[key] = match.group(1) if match else ""
+
+            output["response_content"] = response_content
             return output
 
         except (json.JSONDecodeError, ValueError) as e:
@@ -96,7 +117,7 @@
                 "initial_prompt_evaluation": "Error parsing response",
                 "refined_prompt": "",
                 "explanation_of_refinements": str(e),
-                'response_content': str(e)
+                "response_content": str(e)
             }
 
     def apply_prompt(self, prompt: str, model: str) -> str:
@@ -140,7 +161,6 @@
             )
 
             full_response = ""
-
             for chunk in response:
                 if chunk.choices[0].delta.content is not None:
                     full_response += chunk.choices[0].delta.content
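
The new `LLMResponse` model moves the string cleanup out of `_parse_response` and into a pydantic validator: `clean_text_fields` strips the three text fields and unescapes `\n` and `\"` before the values reach the caller. A minimal sketch of that behaviour in isolation, assuming `prompt_refiner.py` is on the import path and a pydantic release that still provides the v1-style `validator` decorator:

```python
# Sketch only, not part of the commit: exercises the LLMResponse validator on sample data.
from prompt_refiner import LLMResponse

raw = {
    "initial_prompt_evaluation": 'The prompt is \\"too vague\\".\\nNo audience is specified.',
    "refined_prompt": "Explain X to a beginner in three bullet points.",
    "explanation_of_refinements": "Added an audience and a format constraint.",
    "response_content": {"source": "parsed <json> block"},  # optional raw payload
}

resp = LLMResponse(**raw)
print(resp.initial_prompt_evaluation)  # escaped \n and \" are unescaped by clean_text_fields
print(resp.dict())                     # the dict shape the refiner now returns as the fourth tuple element
```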
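
When the model reply contains no `<json>` block, `_parse_response` falls back to pulling each expected field out of the raw text with a per-key regex. A standalone sketch of what that pattern (the one visible in the diff) recovers from a plain JSON-ish reply:

```python
# Sketch only, not part of the commit: the regex fallback from _parse_response, run standalone.
import re

response_content = '{"initial_prompt_evaluation": "Too vague", "refined_prompt": "Explain X in 3 bullets", "explanation_of_refinements": "Added a format constraint"}'

output = {}
for key in ["initial_prompt_evaluation", "refined_prompt", "explanation_of_refinements"]:
    # Non-greedy capture of the quoted value, terminated by a comma or closing brace.
    pattern = rf'"{key}":\s*"(.*?)"(?:,|\}})'
    match = re.search(pattern, response_content, re.DOTALL)
    output[key] = match.group(1) if match else ""

print(output)
# {'initial_prompt_evaluation': 'Too vague', 'refined_prompt': 'Explain X in 3 bullets',
#  'explanation_of_refinements': 'Added a format constraint'}
```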