Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	Added Salesforce endpoint
Browse files
- gen_api_answer.py +51 -11
    	
        gen_api_answer.py
    CHANGED
    
    | @@ -23,7 +23,7 @@ together_client = Together() | |
| 23 | 
             
            hf_api_key = os.getenv("HF_API_KEY")
         | 
| 24 | 
             
            flow_judge_api_key = os.getenv("FLOW_JUDGE_API_KEY")
         | 
| 25 | 
             
            cohere_client = cohere.ClientV2(os.getenv("CO_API_KEY"))
         | 
| 26 | 
            -
             | 
| 27 | 
             
            def get_openai_response(model_name, prompt, system_prompt=JUDGE_SYSTEM_PROMPT, max_tokens=500, temperature=0):
         | 
| 28 | 
             
                """Get response from OpenAI API"""
         | 
| 29 | 
             
                try:
         | 
| @@ -195,6 +195,36 @@ def get_cohere_response(model_name, prompt, system_prompt=JUDGE_SYSTEM_PROMPT, m | |
| 195 | 
             
                except Exception as e:
         | 
| 196 | 
             
                    return f"Error with Cohere model {model_name}: {str(e)}"
         | 
| 197 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 198 | 
             
            def get_model_response(
         | 
| 199 | 
             
                model_name,
         | 
| 200 | 
             
                model_info,
         | 
| @@ -210,24 +240,25 @@ def get_model_response( | |
| 210 | 
             
                api_model = model_info["api_model"]
         | 
| 211 | 
             
                organization = model_info["organization"]
         | 
| 212 |  | 
| 213 | 
            -
                # Determine if model is Prometheus  | 
| 214 | 
             
                is_prometheus = (organization == "Prometheus")
         | 
| 215 | 
             
                is_atla = (organization == "Atla")
         | 
| 216 | 
             
                is_flow_judge = (organization == "Flow AI")
         | 
| 217 | 
            -
                 | 
| 218 | 
            -
                 | 
|  | |
|  | |
| 219 |  | 
| 220 | 
             
                # Select the appropriate base prompt
         | 
| 221 | 
            -
             | 
| 222 | 
            -
                if is_atla:
         | 
| 223 | 
             
                    base_prompt = ATLA_PROMPT_WITH_REFERENCE if use_reference else ATLA_PROMPT
         | 
| 224 | 
             
                elif is_flow_judge:
         | 
| 225 | 
             
                    base_prompt = FLOW_JUDGE_PROMPT
         | 
| 226 | 
             
                else:
         | 
| 227 | 
             
                    base_prompt = PROMETHEUS_PROMPT_WITH_REFERENCE if use_reference else PROMETHEUS_PROMPT
         | 
| 228 |  | 
| 229 | 
            -
                # For non-Prometheus/non-Atla models, replace the  | 
| 230 | 
            -
                if not (is_prometheus or is_atla or is_flow_judge):
         | 
| 231 | 
             
                    base_prompt = base_prompt.replace(
         | 
| 232 | 
             
                        '3. The output format should look as follows: "Feedback: (write a feedback for criteria) [RESULT] (an integer number between 1 and 5)"',
         | 
| 233 | 
             
                        '3. Your output format should strictly adhere to JSON as follows: {{"feedback": "<write feedback>", "result": <numerical score>}}. Ensure the output is valid JSON, without additional formatting or explanations.'
         | 
| @@ -247,7 +278,6 @@ def get_model_response( | |
| 247 | 
             
                            score4_desc=prompt_data['score4_desc'],
         | 
| 248 | 
             
                            score5_desc=prompt_data['score5_desc']
         | 
| 249 | 
             
                        )
         | 
| 250 | 
            -
             | 
| 251 | 
             
                    else:
         | 
| 252 | 
             
                        human_input = f"<user_input>\n{prompt_data['human_input']}\n</user_input>"
         | 
| 253 | 
             
                        ai_response = f"<response>\n{prompt_data['ai_response']}\n</response>"
         | 
| @@ -300,8 +330,13 @@ def get_model_response( | |
| 300 | 
             
                        )
         | 
| 301 | 
             
                    elif organization == "Flow AI":
         | 
| 302 | 
             
                        return get_flow_judge_response(
         | 
| 303 | 
            -
                            api_model, final_prompt | 
| 304 | 
             
                        )
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 305 | 
             
                    else:
         | 
| 306 | 
             
                        # All other organizations use Together API
         | 
| 307 | 
             
                        return get_together_response(
         | 
| @@ -324,7 +359,12 @@ def parse_model_response(response): | |
| 324 | 
             
                        data = json.loads(response)
         | 
| 325 | 
             
                        return str(data.get("result", "N/A")), data.get("feedback", "N/A")
         | 
| 326 | 
             
                    except json.JSONDecodeError:
         | 
| 327 | 
            -
                        # If that fails  | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 328 | 
             
                        json_match = re.search(r"{.*}", response, re.DOTALL)
         | 
| 329 | 
             
                        if json_match:
         | 
| 330 | 
             
                            data = json.loads(json_match.group(0))
         | 
|  | |
| 23 | 
             
            hf_api_key = os.getenv("HF_API_KEY")
         | 
| 24 | 
             
            flow_judge_api_key = os.getenv("FLOW_JUDGE_API_KEY")
         | 
| 25 | 
             
            cohere_client = cohere.ClientV2(os.getenv("CO_API_KEY"))
         | 
| 26 | 
            +
            salesforce_api_key = os.getenv("SALESFORCE_API_KEY")
         | 
| 27 | 
             
            def get_openai_response(model_name, prompt, system_prompt=JUDGE_SYSTEM_PROMPT, max_tokens=500, temperature=0):
         | 
| 28 | 
             
                """Get response from OpenAI API"""
         | 
| 29 | 
             
                try:
         | 
|  | |
| 195 | 
             
                except Exception as e:
         | 
| 196 | 
             
                    return f"Error with Cohere model {model_name}: {str(e)}"
         | 
| 197 |  | 
| 198 | 
            +
            def get_salesforce_response(model_name, prompt, system_prompt=None, max_tokens=2048, temperature=0, timeout=60):
                """Get response from Salesforce Research API.

                Args:
                    model_name: Used only to label the error string; it is not
                        part of the request payload.
                    prompt: Text sent as the single "user" message.
                    system_prompt: Accepted so the signature lines up with the
                        other provider helpers; it is not sent to the API.
                    max_tokens: Forwarded as "max_tokens" in the request body.
                    temperature: Forwarded as "temperature" in the request body.
                    timeout: Seconds to wait for the gateway before giving up.

                Returns:
                    The first element of the "result" field of the JSON response,
                    or a string starting with "Error with Salesforce model" if
                    anything goes wrong (missing key, HTTP error, timeout,
                    unexpected response shape).
                """
                try:
                    # Fail early with a clear message instead of sending an
                    # unauthenticated request when the key is not configured.
                    if not salesforce_api_key:
                        raise ValueError("SALESFORCE_API_KEY is not set")

                    headers = {
                        'accept': 'application/json',
                        "content-type": "application/json",
                        "X-Api-Key": salesforce_api_key,
                    }

                    json_data = {
                        "prompts": [{"role": "user", "content": prompt}],
                        "temperature": temperature,
                        "top_p": 1,
                        "max_tokens": max_tokens,
                    }

                    # requests applies no timeout by default; without one a
                    # stalled gateway would block the caller indefinitely.
                    response = requests.post(
                        'https://gateway.salesforceresearch.ai/sfr-judge/process',
                        headers=headers,
                        json=json_data,
                        timeout=timeout,
                    )
                    response.raise_for_status()
                    return response.json()['result'][0]
                except Exception as e:
                    # Errors are returned as text rather than raised, matching
                    # the other provider helpers in this file.
                    return f"Error with Salesforce model {model_name}: {str(e)}"
         | 
| 227 | 
            +
             | 
| 228 | 
             
            def get_model_response(
         | 
| 229 | 
             
                model_name,
         | 
| 230 | 
             
                model_info,
         | 
|  | |
| 240 | 
             
                api_model = model_info["api_model"]
         | 
| 241 | 
             
                organization = model_info["organization"]
         | 
| 242 |  | 
| 243 | 
            +
                # Determine if model is Prometheus, Atla, Flow Judge, or Salesforce
         | 
| 244 | 
             
                is_prometheus = (organization == "Prometheus")
         | 
| 245 | 
             
                is_atla = (organization == "Atla")
         | 
| 246 | 
             
                is_flow_judge = (organization == "Flow AI")
         | 
| 247 | 
            +
                is_salesforce = (organization == "Salesforce")  
         | 
| 248 | 
            +
                
         | 
| 249 | 
            +
                # For non-Prometheus/Atla/Flow Judge/Salesforce models, use the Judge system prompt
         | 
| 250 | 
            +
                system_prompt = None if (is_prometheus or is_atla or is_flow_judge or is_salesforce) else JUDGE_SYSTEM_PROMPT
         | 
| 251 |  | 
| 252 | 
             
                # Select the appropriate base prompt
         | 
| 253 | 
            +
                if is_atla or is_salesforce:  # Use same prompt for Atla and Salesforce
         | 
|  | |
| 254 | 
             
                    base_prompt = ATLA_PROMPT_WITH_REFERENCE if use_reference else ATLA_PROMPT
         | 
| 255 | 
             
                elif is_flow_judge:
         | 
| 256 | 
             
                    base_prompt = FLOW_JUDGE_PROMPT
         | 
| 257 | 
             
                else:
         | 
| 258 | 
             
                    base_prompt = PROMETHEUS_PROMPT_WITH_REFERENCE if use_reference else PROMETHEUS_PROMPT
         | 
| 259 |  | 
| 260 | 
            +
                # For models other than Prometheus/Atla/Flow Judge/Salesforce, use the Prometheus prompt but replace the output format with JSON
         | 
| 261 | 
            +
                if not (is_prometheus or is_atla or is_flow_judge or is_salesforce):
         | 
| 262 | 
             
                    base_prompt = base_prompt.replace(
         | 
| 263 | 
             
                        '3. The output format should look as follows: "Feedback: (write a feedback for criteria) [RESULT] (an integer number between 1 and 5)"',
         | 
| 264 | 
             
                        '3. Your output format should strictly adhere to JSON as follows: {{"feedback": "<write feedback>", "result": <numerical score>}}. Ensure the output is valid JSON, without additional formatting or explanations.'
         | 
|  | |
| 278 | 
             
                            score4_desc=prompt_data['score4_desc'],
         | 
| 279 | 
             
                            score5_desc=prompt_data['score5_desc']
         | 
| 280 | 
             
                        )
         | 
|  | |
| 281 | 
             
                    else:
         | 
| 282 | 
             
                        human_input = f"<user_input>\n{prompt_data['human_input']}\n</user_input>"
         | 
| 283 | 
             
                        ai_response = f"<response>\n{prompt_data['ai_response']}\n</response>"
         | 
|  | |
| 330 | 
             
                        )
         | 
| 331 | 
             
                    elif organization == "Flow AI":
         | 
| 332 | 
             
                        return get_flow_judge_response(
         | 
| 333 | 
            +
                            api_model, final_prompt
         | 
| 334 | 
             
                        )
         | 
| 335 | 
            +
                    elif organization == "Salesforce":
         | 
| 336 | 
            +
                        response = get_salesforce_response(
         | 
| 337 | 
            +
                            api_model, final_prompt, system_prompt, max_tokens, temperature
         | 
| 338 | 
            +
                        )           
         | 
| 339 | 
            +
                        return response
         | 
| 340 | 
             
                    else:
         | 
| 341 | 
             
                        # All other organizations use Together API
         | 
| 342 | 
             
                        return get_together_response(
         | 
|  | |
| 359 | 
             
                        data = json.loads(response)
         | 
| 360 | 
             
                        return str(data.get("result", "N/A")), data.get("feedback", "N/A")
         | 
| 361 | 
             
                    except json.JSONDecodeError:
         | 
| 362 | 
            +
                        # If that fails, check if this is a Salesforce response (which uses ATLA format)
         | 
| 363 | 
            +
                        if "**Reasoning:**" in response or "**Result:**" in response:
         | 
| 364 | 
            +
                            # Use ATLA parser for Salesforce responses
         | 
| 365 | 
            +
                            return atla_parse_model_response(response)
         | 
| 366 | 
            +
                            
         | 
| 367 | 
            +
                        # Otherwise try to find JSON within the response
         | 
| 368 | 
             
                        json_match = re.search(r"{.*}", response, re.DOTALL)
         | 
| 369 | 
             
                        if json_match:
         | 
| 370 | 
             
                            data = json.loads(json_match.group(0))
         | 

