oceansweep committed on
Commit
17c7477
·
verified ·
1 Parent(s): 623c851

Upload ms_g_eval.py

Browse files
Files changed (1) hide show
  1. App_Function_Libraries/ms_g_eval.py +19 -41
App_Function_Libraries/ms_g_eval.py CHANGED
@@ -7,13 +7,13 @@
7
  # Scripts taken from https://github.com/microsoft/promptflow/tree/main/examples/flows/evaluation/eval-summarization and modified.
8
  #
9
  import configparser
10
- import json
11
-
12
- import gradio as gr
13
  import inspect
 
14
  import logging
15
  import re
16
  from typing import Dict, Callable, List, Any
 
 
17
  from tenacity import (
18
  RetryError,
19
  Retrying,
@@ -23,12 +23,7 @@ from tenacity import (
23
  wait_random_exponential,
24
  )
25
 
26
- from App_Function_Libraries.Local_Summarization_Lib import summarize_with_llama, summarize_with_kobold, \
27
- summarize_with_oobabooga, summarize_with_tabbyapi, summarize_with_vllm, summarize_with_local_llm, \
28
- summarize_with_ollama
29
- from App_Function_Libraries.Summarization_General_Lib import summarize_with_openai, summarize_with_anthropic, \
30
- summarize_with_cohere, summarize_with_groq, summarize_with_openrouter, summarize_with_deepseek, \
31
- summarize_with_huggingface, summarize_with_mistral
32
 
33
  #
34
  #######################################################################################################################
@@ -290,11 +285,10 @@ def parse_output(output: str, max: float) -> float:
290
  def geval_summarization(
291
  prompt_with_src_and_gen: str,
292
  max_score: float,
293
- api_name: str,
294
  api_key: str,
295
  ) -> float:
296
- summarize_function = get_summarize_function(api_name)
297
- model = get_model_from_config(api_name)
298
 
299
  try:
300
  for attempt in Retrying(
@@ -305,9 +299,16 @@ def geval_summarization(
305
  stop=stop_after_attempt(10),
306
  ):
307
  with attempt:
308
- response = summarize_function(api_key, prompt_with_src_and_gen, "", temp=0.7, system_prompt="You are a helpful AI assistant")
 
 
 
 
 
 
 
309
  except RetryError:
310
- logger.exception(f"geval {api_name} call failed\nInput prompt was: {prompt_with_src_and_gen}")
311
  raise
312
 
313
  try:
@@ -319,31 +320,6 @@ def geval_summarization(
319
  return score
320
 
321
 
322
- def get_summarize_function(api_name: str):
323
- summarize_functions = {
324
- "openai": summarize_with_openai,
325
- "anthropic": summarize_with_anthropic,
326
- "cohere": summarize_with_cohere,
327
- "groq": summarize_with_groq,
328
- "openrouter": summarize_with_openrouter,
329
- "deepseek": summarize_with_deepseek,
330
- "huggingface": summarize_with_huggingface,
331
- "mistral": summarize_with_mistral,
332
- "llama.cpp": summarize_with_llama,
333
- "kobold": summarize_with_kobold,
334
- "ooba": summarize_with_oobabooga,
335
- "tabbyapi": summarize_with_tabbyapi,
336
- "vllm": summarize_with_vllm,
337
- "local-llm": summarize_with_local_llm,
338
- "ollama": summarize_with_ollama
339
- }
340
- api_name_lower = api_name.lower()
341
- if api_name_lower not in summarize_functions:
342
- raise ValueError(f"Unsupported API: {api_name}")
343
- return summarize_functions[api_name_lower]
344
-
345
-
346
-
347
  def get_model_from_config(api_name: str) -> str:
348
  model = config.get('models', api_name)
349
  if isinstance(model, dict):
@@ -397,8 +373,6 @@ def validate_inputs(document: str, summary: str, api_name: str, api_key: str) ->
397
  if api_name.lower() not in ["openai", "anthropic", "cohere", "groq", "openrouter", "deepseek", "huggingface",
398
  "mistral", "llama.cpp", "kobold", "ooba", "tabbyapi", "vllm", "local-llm", "ollama"]:
399
  raise ValueError(f"Unsupported API: {api_name}")
400
- if not api_key.strip() and api_name.lower() not in ["local-llm", "ollama"]:
401
- raise ValueError("API key is required for non-local APIs")
402
 
403
 
404
  def detailed_api_error(api_name: str, error: Exception) -> str:
@@ -430,6 +404,10 @@ def save_eval_results(results: Dict[str, Any], filename: str = "geval_results.js
430
  print(f"Results saved to {filename}")
431
 
432
 
 
 
 
 
433
  #######################################################################################################################
434
  #
435
  # Taken from: https://github.com/microsoft/promptflow/blob/b5a68f45e4c3818a29e2f79a76f2e73b8ea6be44/src/promptflow-core/promptflow/_core/metric_logger.py
 
7
  # Scripts taken from https://github.com/microsoft/promptflow/tree/main/examples/flows/evaluation/eval-summarization and modified.
8
  #
9
  import configparser
 
 
 
10
  import inspect
11
+ import json
12
  import logging
13
  import re
14
  from typing import Dict, Callable, List, Any
15
+
16
+ import gradio as gr
17
  from tenacity import (
18
  RetryError,
19
  Retrying,
 
23
  wait_random_exponential,
24
  )
25
 
26
+ from App_Function_Libraries.Chat import chat_api_call
 
 
 
 
 
27
 
28
  #
29
  #######################################################################################################################
 
285
  def geval_summarization(
286
  prompt_with_src_and_gen: str,
287
  max_score: float,
288
+ api_endpoint: str,
289
  api_key: str,
290
  ) -> float:
291
+ model = get_model_from_config(api_endpoint)
 
292
 
293
  try:
294
  for attempt in Retrying(
 
299
  stop=stop_after_attempt(10),
300
  ):
301
  with attempt:
302
+ system_message="You are a helpful AI assistant"
303
+ # TEMP setting for Confabulation check
304
+ temp = 0.7
305
+ logging.info(f"Debug - geval_summarization Function - API Endpoint: {api_endpoint}")
306
+ try:
307
+ response = chat_api_call(api_endpoint, api_key, prompt_with_src_and_gen, "", temp, system_message)
308
+ except Exception as e:
309
+ raise ValueError(f"Unsupported API endpoint: {api_endpoint}")
310
  except RetryError:
311
+ logger.exception(f"geval {api_endpoint} call failed\nInput prompt was: {prompt_with_src_and_gen}")
312
  raise
313
 
314
  try:
 
320
  return score
321
 
322
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
323
  def get_model_from_config(api_name: str) -> str:
324
  model = config.get('models', api_name)
325
  if isinstance(model, dict):
 
373
  if api_name.lower() not in ["openai", "anthropic", "cohere", "groq", "openrouter", "deepseek", "huggingface",
374
  "mistral", "llama.cpp", "kobold", "ooba", "tabbyapi", "vllm", "local-llm", "ollama"]:
375
  raise ValueError(f"Unsupported API: {api_name}")
 
 
376
 
377
 
378
  def detailed_api_error(api_name: str, error: Exception) -> str:
 
404
  print(f"Results saved to {filename}")
405
 
406
 
407
+
408
+
409
+ #
410
+ #
411
  #######################################################################################################################
412
  #
413
  # Taken from: https://github.com/microsoft/promptflow/blob/b5a68f45e4c3818a29e2f79a76f2e73b8ea6be44/src/promptflow-core/promptflow/_core/metric_logger.py