oceansweep
commited on
Upload ms_g_eval.py
Browse files
App_Function_Libraries/ms_g_eval.py
CHANGED
@@ -7,13 +7,13 @@
|
|
7 |
# Scripts taken from https://github.com/microsoft/promptflow/tree/main/examples/flows/evaluation/eval-summarization and modified.
|
8 |
#
|
9 |
import configparser
|
10 |
-
import json
|
11 |
-
|
12 |
-
import gradio as gr
|
13 |
import inspect
|
|
|
14 |
import logging
|
15 |
import re
|
16 |
from typing import Dict, Callable, List, Any
|
|
|
|
|
17 |
from tenacity import (
|
18 |
RetryError,
|
19 |
Retrying,
|
@@ -23,12 +23,7 @@ from tenacity import (
|
|
23 |
wait_random_exponential,
|
24 |
)
|
25 |
|
26 |
-
from App_Function_Libraries.
|
27 |
-
summarize_with_oobabooga, summarize_with_tabbyapi, summarize_with_vllm, summarize_with_local_llm, \
|
28 |
-
summarize_with_ollama
|
29 |
-
from App_Function_Libraries.Summarization_General_Lib import summarize_with_openai, summarize_with_anthropic, \
|
30 |
-
summarize_with_cohere, summarize_with_groq, summarize_with_openrouter, summarize_with_deepseek, \
|
31 |
-
summarize_with_huggingface, summarize_with_mistral
|
32 |
|
33 |
#
|
34 |
#######################################################################################################################
|
@@ -290,11 +285,10 @@ def parse_output(output: str, max: float) -> float:
|
|
290 |
def geval_summarization(
|
291 |
prompt_with_src_and_gen: str,
|
292 |
max_score: float,
|
293 |
-
|
294 |
api_key: str,
|
295 |
) -> float:
|
296 |
-
|
297 |
-
model = get_model_from_config(api_name)
|
298 |
|
299 |
try:
|
300 |
for attempt in Retrying(
|
@@ -305,9 +299,16 @@ def geval_summarization(
|
|
305 |
stop=stop_after_attempt(10),
|
306 |
):
|
307 |
with attempt:
|
308 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
309 |
except RetryError:
|
310 |
-
logger.exception(f"geval {
|
311 |
raise
|
312 |
|
313 |
try:
|
@@ -319,31 +320,6 @@ def geval_summarization(
|
|
319 |
return score
|
320 |
|
321 |
|
322 |
-
def get_summarize_function(api_name: str):
|
323 |
-
summarize_functions = {
|
324 |
-
"openai": summarize_with_openai,
|
325 |
-
"anthropic": summarize_with_anthropic,
|
326 |
-
"cohere": summarize_with_cohere,
|
327 |
-
"groq": summarize_with_groq,
|
328 |
-
"openrouter": summarize_with_openrouter,
|
329 |
-
"deepseek": summarize_with_deepseek,
|
330 |
-
"huggingface": summarize_with_huggingface,
|
331 |
-
"mistral": summarize_with_mistral,
|
332 |
-
"llama.cpp": summarize_with_llama,
|
333 |
-
"kobold": summarize_with_kobold,
|
334 |
-
"ooba": summarize_with_oobabooga,
|
335 |
-
"tabbyapi": summarize_with_tabbyapi,
|
336 |
-
"vllm": summarize_with_vllm,
|
337 |
-
"local-llm": summarize_with_local_llm,
|
338 |
-
"ollama": summarize_with_ollama
|
339 |
-
}
|
340 |
-
api_name_lower = api_name.lower()
|
341 |
-
if api_name_lower not in summarize_functions:
|
342 |
-
raise ValueError(f"Unsupported API: {api_name}")
|
343 |
-
return summarize_functions[api_name_lower]
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
def get_model_from_config(api_name: str) -> str:
|
348 |
model = config.get('models', api_name)
|
349 |
if isinstance(model, dict):
|
@@ -397,8 +373,6 @@ def validate_inputs(document: str, summary: str, api_name: str, api_key: str) ->
|
|
397 |
if api_name.lower() not in ["openai", "anthropic", "cohere", "groq", "openrouter", "deepseek", "huggingface",
|
398 |
"mistral", "llama.cpp", "kobold", "ooba", "tabbyapi", "vllm", "local-llm", "ollama"]:
|
399 |
raise ValueError(f"Unsupported API: {api_name}")
|
400 |
-
if not api_key.strip() and api_name.lower() not in ["local-llm", "ollama"]:
|
401 |
-
raise ValueError("API key is required for non-local APIs")
|
402 |
|
403 |
|
404 |
def detailed_api_error(api_name: str, error: Exception) -> str:
|
@@ -430,6 +404,10 @@ def save_eval_results(results: Dict[str, Any], filename: str = "geval_results.js
|
|
430 |
print(f"Results saved to {filename}")
|
431 |
|
432 |
|
|
|
|
|
|
|
|
|
433 |
#######################################################################################################################
|
434 |
#
|
435 |
# Taken from: https://github.com/microsoft/promptflow/blob/b5a68f45e4c3818a29e2f79a76f2e73b8ea6be44/src/promptflow-core/promptflow/_core/metric_logger.py
|
|
|
7 |
# Scripts taken from https://github.com/microsoft/promptflow/tree/main/examples/flows/evaluation/eval-summarization and modified.
|
8 |
#
|
9 |
import configparser
|
|
|
|
|
|
|
10 |
import inspect
|
11 |
+
import json
|
12 |
import logging
|
13 |
import re
|
14 |
from typing import Dict, Callable, List, Any
|
15 |
+
|
16 |
+
import gradio as gr
|
17 |
from tenacity import (
|
18 |
RetryError,
|
19 |
Retrying,
|
|
|
23 |
wait_random_exponential,
|
24 |
)
|
25 |
|
26 |
+
from App_Function_Libraries.Chat import chat_api_call
|
|
|
|
|
|
|
|
|
|
|
27 |
|
28 |
#
|
29 |
#######################################################################################################################
|
|
|
285 |
def geval_summarization(
|
286 |
prompt_with_src_and_gen: str,
|
287 |
max_score: float,
|
288 |
+
api_endpoint: str,
|
289 |
api_key: str,
|
290 |
) -> float:
|
291 |
+
model = get_model_from_config(api_endpoint)
|
|
|
292 |
|
293 |
try:
|
294 |
for attempt in Retrying(
|
|
|
299 |
stop=stop_after_attempt(10),
|
300 |
):
|
301 |
with attempt:
|
302 |
+
system_message="You are a helpful AI assistant"
|
303 |
+
# TEMP setting for Confabulation check
|
304 |
+
temp = 0.7
|
305 |
+
logging.info(f"Debug - geval_summarization Function - API Endpoint: {api_endpoint}")
|
306 |
+
try:
|
307 |
+
response = chat_api_call(api_endpoint, api_key, prompt_with_src_and_gen, "", temp, system_message)
|
308 |
+
except Exception as e:
|
309 |
+
raise ValueError(f"Unsupported API endpoint: {api_endpoint}")
|
310 |
except RetryError:
|
311 |
+
logger.exception(f"geval {api_endpoint} call failed\nInput prompt was: {prompt_with_src_and_gen}")
|
312 |
raise
|
313 |
|
314 |
try:
|
|
|
320 |
return score
|
321 |
|
322 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
323 |
def get_model_from_config(api_name: str) -> str:
|
324 |
model = config.get('models', api_name)
|
325 |
if isinstance(model, dict):
|
|
|
373 |
if api_name.lower() not in ["openai", "anthropic", "cohere", "groq", "openrouter", "deepseek", "huggingface",
|
374 |
"mistral", "llama.cpp", "kobold", "ooba", "tabbyapi", "vllm", "local-llm", "ollama"]:
|
375 |
raise ValueError(f"Unsupported API: {api_name}")
|
|
|
|
|
376 |
|
377 |
|
378 |
def detailed_api_error(api_name: str, error: Exception) -> str:
|
|
|
404 |
print(f"Results saved to {filename}")
|
405 |
|
406 |
|
407 |
+
|
408 |
+
|
409 |
+
#
|
410 |
+
#
|
411 |
#######################################################################################################################
|
412 |
#
|
413 |
# Taken from: https://github.com/microsoft/promptflow/blob/b5a68f45e4c3818a29e2f79a76f2e73b8ea6be44/src/promptflow-core/promptflow/_core/metric_logger.py
|