Spaces:
Sleeping
Sleeping
import gradio as gr | |
import jiwer | |
import pandas as pd | |
import logging | |
from typing import List, Optional, Tuple, Dict | |
from llama_cpp import Llama | |
import os | |
# Set up logging | |
logging.basicConfig( | |
level=logging.INFO, | |
format='%(asctime)s - %(levelname)s - %(message)s', | |
force=True, | |
handlers=[logging.StreamHandler()] | |
) | |
logger = logging.getLogger(__name__) | |
# Initialize LLM | |
MODEL_PATH = "./DeepSeek-R1-Distill-Qwen-1.5B-Q3_K_M.gguf" | |
try: | |
llm = Llama( | |
model_path=MODEL_PATH, | |
n_ctx=2048, # Context window | |
n_threads=4, # CPU threads | |
n_batch=512, # Batch size | |
verbose=False # Disable verbose output | |
) | |
logger.info("LLM initialized successfully") | |
except Exception as e: | |
logger.error(f"Failed to initialize LLM: {str(e)}") | |
llm = None | |
def calculate_wer_metrics( | |
hypothesis: str, | |
reference: str, | |
normalize: bool = True, | |
words_to_filter: Optional[List[str]] = None | |
) -> Dict: | |
""" | |
Calculate WER metrics between hypothesis and reference texts. | |
Args: | |
hypothesis (str): The hypothesis text | |
reference (str): The reference text | |
normalize (bool): Whether to normalize texts before comparison | |
words_to_filter (List[str], optional): Words to filter out before comparison | |
Returns: | |
dict: Dictionary containing WER metrics | |
Raises: | |
ValueError: If inputs are invalid or result in empty text after processing | |
""" | |
logger.info(f"Calculating WER metrics with inputs - Hypothesis: {hypothesis}, Reference: {reference}") | |
# Validate inputs | |
if not hypothesis.strip() or not reference.strip(): | |
raise ValueError("Both hypothesis and reference texts must contain non-empty strings") | |
if normalize: | |
# Define basic transformations | |
basic_transform = jiwer.Compose([ | |
jiwer.ExpandCommonEnglishContractions(), | |
jiwer.ToLowerCase(), | |
jiwer.RemoveMultipleSpaces(), | |
jiwer.RemovePunctuation(), | |
jiwer.Strip(), | |
jiwer.ReduceToListOfListOfWords() | |
]) | |
if words_to_filter and any(words_to_filter): | |
def filter_words_transform(words: List[str]) -> List[str]: | |
filtered = [word for word in words | |
if word.lower() not in [w.lower() for w in words_to_filter]] | |
if not filtered: | |
raise ValueError("Text is empty after filtering words") | |
return filtered | |
transformation = jiwer.Compose([ | |
basic_transform, | |
filter_words_transform | |
]) | |
else: | |
transformation = basic_transform | |
# Pre-check the transformed text | |
try: | |
transformed_ref = transformation(reference) | |
transformed_hyp = transformation(hypothesis) | |
if not transformed_ref or not transformed_hyp: | |
raise ValueError("Text is empty after normalization") | |
logger.debug(f"Transformed reference: {transformed_ref}") | |
logger.debug(f"Transformed hypothesis: {transformed_hyp}") | |
except Exception as e: | |
logger.error(f"Transformation error: {str(e)}") | |
raise ValueError(f"Error during text transformation: {str(e)}") | |
measures = jiwer.compute_measures( | |
truth=reference, | |
hypothesis=hypothesis, | |
truth_transform=transformation, | |
hypothesis_transform=transformation | |
) | |
else: | |
measures = jiwer.compute_measures( | |
truth=reference, | |
hypothesis=hypothesis | |
) | |
return measures | |
def extract_medical_terms(text: str) -> List[str]: | |
"""Extract medical terms from text using Qwen model.""" | |
if llm is None: | |
logger.error("LLM not initialized") | |
return [] | |
prompt = f"""Extract all medical terms from the following text. | |
Return only the medical terms as a comma-separated list. | |
Text: {text}""" | |
try: | |
response = llm( | |
prompt, | |
max_tokens=256, | |
temperature=0.1, | |
stop=["Text:", "\n\n"], | |
echo=False | |
) | |
response_text = response['choices'][0]['text'].strip() | |
# Remove thinking process if present | |
if '<think>' in response_text and '</think>' in response_text: | |
medical_terms_text = response_text.split('</think>')[-1].strip() | |
else: | |
medical_terms_text = response_text | |
medical_terms = [term.strip() for term in medical_terms_text.split(',')] | |
return [term for term in medical_terms if term and not term.startswith('<') and not term.endswith('>')] | |
except Exception as e: | |
logger.error(f"Error in medical term extraction: {str(e)}") | |
return [] | |
def calculate_medical_recall( | |
hypothesis_terms: List[str], | |
reference_terms: List[str] | |
) -> float: | |
""" | |
Calculate medical term recall rate. | |
Args: | |
hypothesis_terms (List[str]): Medical terms from hypothesis | |
reference_terms (List[str]): Medical terms from reference | |
Returns: | |
float: Recall rate | |
""" | |
if not reference_terms: | |
return 1.0 if not hypothesis_terms else 0.0 | |
correct_terms = set(hypothesis_terms) & set(reference_terms) | |
return len(correct_terms) / len(set(reference_terms)) | |
def process_inputs( | |
reference: str, | |
hypothesis: str, | |
normalize: bool, | |
words_to_filter: str | |
) -> Tuple[str, str, str, str]: | |
""" | |
Process inputs and calculate both WER and medical term recall metrics. | |
Args: | |
reference (str): Reference text | |
hypothesis (str): Hypothesis text | |
normalize (bool): Whether to normalize text | |
words_to_filter (str): Comma-separated words to filter | |
Returns: | |
Tuple[str, str, str, str]: HTML formatted main metrics, error analysis, | |
and explanations | |
""" | |
if not reference or not hypothesis: | |
return "Please provide both reference and hypothesis texts.", "", "", "" | |
try: | |
# Extract medical terms | |
logger.info("Extracting medical terms from reference text...") | |
reference_terms = extract_medical_terms(reference) | |
logger.info(f"Reference terms extracted: {reference_terms}") | |
logger.info("Extracting medical terms from hypothesis text...") | |
hypothesis_terms = extract_medical_terms(hypothesis) | |
logger.info(f"Hypothesis terms extracted: {hypothesis_terms}") | |
# Calculate medical recall | |
med_recall = calculate_medical_recall(hypothesis_terms, reference_terms) | |
# Calculate WER metrics | |
filter_words = [word.strip() for word in words_to_filter.split(",")] if words_to_filter else None | |
measures = calculate_wer_metrics( | |
hypothesis=hypothesis, | |
reference=reference, | |
normalize=normalize, | |
words_to_filter=filter_words | |
) | |
# Format metrics | |
metrics_df = pd.DataFrame({ | |
'Metric': ['WER', 'MER', 'WIL', 'WIP', 'Medical Term Recall'], | |
'Value': [ | |
f"{measures['wer']:.3f}", | |
f"{measures['mer']:.3f}", | |
f"{measures['wil']:.3f}", | |
f"{measures['wip']:.3f}", | |
f"{med_recall:.3f}" | |
] | |
}) | |
# Format error analysis | |
error_df = pd.DataFrame({ | |
'Metric': ['Substitutions', 'Deletions', 'Insertions', 'Hits'], | |
'Count': [ | |
measures['substitutions'], | |
measures['deletions'], | |
measures['insertions'], | |
measures['hits'] | |
] | |
}) | |
# Format medical terms comparison | |
med_terms_df = pd.DataFrame({ | |
'Source': ['Reference', 'Hypothesis'], | |
'Medical Terms': [ | |
', '.join(reference_terms), | |
', '.join(hypothesis_terms) | |
] | |
}) | |
metrics_html = metrics_df.to_html(index=False) | |
error_html = error_df.to_html(index=False) | |
med_terms_html = med_terms_df.to_html(index=False) | |
explanation = f""" | |
<h3>Metrics Explanation:</h3> | |
<ul> | |
<li><b>WER (Word Error Rate)</b>: The percentage of words that were incorrectly predicted</li> | |
<li><b>MER (Match Error Rate)</b>: The percentage of words that were incorrectly matched</li> | |
<li><b>WIL (Word Information Lost)</b>: The percentage of word information that was lost</li> | |
<li><b>WIP (Word Information Preserved)</b>: The percentage of word information that was preserved</li> | |
<li><b>Medical Term Recall</b>: The proportion of reference medical terms that were correctly identified in the hypothesis</li> | |
</ul> | |
<h3>Extracted Medical Terms:</h3> | |
{med_terms_html} | |
""" | |
return metrics_html, error_html, explanation, "" | |
except Exception as e: | |
error_msg = f"Error in processing: {str(e)}" | |
logger.error(error_msg) | |
return "", "", "", error_msg | |
def load_example() -> Tuple[str, str]: | |
"""Load example texts for demonstration.""" | |
return ( | |
"The patient shows signs of heart attack and hypertension.", | |
"The patient shows signs of heart attack and high blood pressure." | |
) | |
def create_interface() -> gr.Blocks: | |
"""Create the Gradio interface.""" | |
with gr.Blocks(title="WER Evaluation Tool") as interface: | |
gr.Markdown("# Word Error Rate (WER) Evaluation Tool") | |
gr.Markdown( | |
"This tool helps you evaluate the Word Error Rate (WER) between a reference " | |
"text and a hypothesis text. WER is commonly used in speech recognition and " | |
"machine translation evaluation." | |
) | |
with gr.Row(): | |
with gr.Column(): | |
reference = gr.Textbox( | |
label="Reference Text", | |
placeholder="Enter the reference text here...", | |
lines=5 | |
) | |
with gr.Column(): | |
hypothesis = gr.Textbox( | |
label="Hypothesis Text", | |
placeholder="Enter the hypothesis text here...", | |
lines=5 | |
) | |
with gr.Row(): | |
normalize = gr.Checkbox( | |
label="Normalize text (lowercase, remove punctuation)", | |
value=True | |
) | |
words_to_filter = gr.Textbox( | |
label="Words to filter (comma-separated)", | |
placeholder="e.g., um, uh, ah" | |
) | |
with gr.Row(): | |
example_btn = gr.Button("Load Example") | |
calculate_btn = gr.Button("Calculate WER", variant="primary") | |
with gr.Row(): | |
metrics_output = gr.HTML(label="Main Metrics") | |
error_output = gr.HTML(label="Error Analysis") | |
explanation_output = gr.HTML() | |
error_msg_output = gr.HTML() | |
# Event handlers | |
example_btn.click( | |
load_example, | |
outputs=[reference, hypothesis] | |
) | |
calculate_btn.click( | |
process_inputs, | |
inputs=[reference, hypothesis, normalize, words_to_filter], | |
outputs=[metrics_output, error_output, explanation_output, error_msg_output] | |
) | |
return interface | |
if __name__ == "__main__": | |
logger.info("Application started") | |
try: | |
app = create_interface() | |
app.launch( | |
server_name="0.0.0.0", | |
server_port=7860, | |
share=False, | |
debug=True | |
) | |
except Exception as e: | |
logger.error(f"Failed to launch application: {str(e)}") | |
raise | |