import atexit
import functools
import base64
import io
import re
import os
import tempfile
from queue import Queue
from threading import Event, Thread

import numpy as np
from paddleocr import PaddleOCR, draw_ocr
from PIL import Image
import gradio as gr
import fasttext

# Load the fasttext language identification model.
# The model file is downloaded automatically on the first run.
try:
    # Check whether the model file already exists
    model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "lid.176.bin")
    if not os.path.exists(model_path):
        # Download the model if it is not present
        import urllib.request
        print("Downloading the fasttext language identification model...")
        urllib.request.urlretrieve(
            "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin",
            model_path
        )
    # Load the model
    lang_model = fasttext.load_model(model_path)
    print("fasttext language identification model loaded successfully")
except Exception as e:
    print(f"Warning: failed to load the fasttext model: {e}")
    lang_model = None

LANG_CONFIG = {
    "ch": {"num_workers": 2},
    "en": {"num_workers": 2},
    "fr": {"num_workers": 1},
    "german": {"num_workers": 1},
    "korean": {"num_workers": 1},
    "japan": {"num_workers": 1},
}

# Display names for the supported languages
LANG_MAP = {
    "ch": "Chinese",
    "en": "English",
    "fr": "French",
    "german": "German",
    "korean": "Korean",
    "japan": "Japanese",
}

# Mapping from fasttext language codes to PaddleOCR language codes
FASTTEXT_TO_PADDLE = {
    "zh": "ch",      # Chinese
    "en": "en",      # English
    "fr": "fr",      # French
    "de": "german",  # German
    "ko": "korean",  # Korean
    "ja": "japan",   # Japanese
}

# Characteristic character sets per language - used as a fallback language detector
LANG_FEATURES = {
    "ch": set("的一是不了人我在有他这为之大来以个中上们"),
    "en": set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"),
    "fr": set("àâäæçéèêëîïôœùûüÿÀÂÄÆÇÉÈÊËÎÏÔŒÙÛÜŸ"),
    "german": set("äöüßÄÖÜ"),
    "japan": set("あいうえおかきくけこさしすせそたちつてとなにぬねのはひふへほまみむめもやゆよらりるれろわをんアイウエオカキクケコサシスセソタチツテト"),
}

CONCURRENCY_LIMIT = 8

class PaddleOCRModelManager(object):
    def __init__(self,
                 num_workers,
                 model_factory):
        super().__init__()
        self._model_factory = model_factory
        self._queue = Queue()
        self._workers = []
        self._model_initialized_event = Event()
        for _ in range(num_workers):
            worker = Thread(target=self._worker, daemon=False)
            worker.start()
            self._model_initialized_event.wait()
            self._model_initialized_event.clear()
            self._workers.append(worker)

    def infer(self, *args, **kwargs):
        # XXX: Should I use a more lightweight data structure, say, a future?
        result_queue = Queue(maxsize=1)
        self._queue.put((args, kwargs, result_queue))
        success, payload = result_queue.get()
        if success:
            return payload
        else:
            raise payload

    def close(self):
        for _ in self._workers:
            self._queue.put(None)
        for worker in self._workers:
            worker.join()

    def _worker(self):
        model = self._model_factory()
        self._model_initialized_event.set()
        while True:
            item = self._queue.get()
            if item is None:
                break
            args, kwargs, result_queue = item
            try:
                result = model.ocr(*args, **kwargs)
                result_queue.put((True, result))
            except Exception as e:
                result_queue.put((False, e))
            finally:
                self._queue.task_done()

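# Minimal usage sketch for the manager above (illustrative comments only, not
# executed here; the image path is a placeholder):
#
#   manager = PaddleOCRModelManager(num_workers=1,
#                                   model_factory=lambda: PaddleOCR(lang="en"))
#   result = manager.infer("some_image.png", cls=True)
#   manager.close()
#
# Each infer() call blocks until one of the worker threads has run model.ocr()
# and posted either the result or the raised exception back on a per-call queue.
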
def create_model(lang):
    # All languages use the same settings; PaddleOCR's built-in default
    # dictionaries are used (no custom dictionary path is supplied).
    return PaddleOCR(lang=lang, use_angle_cls=True, use_gpu=False)

# Preload the OCR models for all supported languages
print("Initializing the multilingual OCR models...")
model_managers = {}
for lang, config in LANG_CONFIG.items():
    print(f"Loading the {LANG_MAP.get(lang, lang)} model...")
    model_manager = PaddleOCRModelManager(config["num_workers"], functools.partial(create_model, lang=lang))
    model_managers[lang] = model_manager
print("All OCR models loaded")

def close_model_managers():
    for manager in model_managers.values():
        manager.close()

# XXX: Not sure if gradio allows adding custom teardown logic
atexit.register(close_model_managers)

def detect_language_by_features(text):
    """Detect the language based on characteristic character sets."""
    if not text:
        return "en"
    # Compute, for each language, the fraction of its characteristic characters in the text
    lang_scores = {}
    for lang, char_set in LANG_FEATURES.items():
        if not char_set:
            continue
        # Count how many characters of the text belong to this language's set
        count = sum(1 for char in text if char in char_set)
        if count > 0:
            lang_scores[lang] = count / len(text)
    # Korean is handled separately via the Hangul syllables Unicode range
    korean_count = sum(1 for char in text if '\uac00' <= char <= '\ud7a3')
    if korean_count > 0:
        lang_scores["korean"] = korean_count / len(text)
    # Default to English if no language features were detected
    if not lang_scores:
        return "en"
    # Return the language with the highest feature ratio
    return max(lang_scores.items(), key=lambda x: x[1])[0]

def detect_language_with_fasttext(text):
    """Detect the language using fasttext."""
    if not text or not text.strip():
        return "en"
    if lang_model is None:
        # Fall back to feature-based detection if the fasttext model failed to load
        return detect_language_by_features(text)
    try:
        # Truncate the text to keep prediction fast
        text = text[:1000]
        # Predict the language with fasttext
        predictions = lang_model.predict(text.replace('\n', ' '))
        lang_code = predictions[0][0].replace('__label__', '')
        # Map the result to a language supported by PaddleOCR
        paddle_lang = FASTTEXT_TO_PADDLE.get(lang_code, None)
        # If the code cannot be mapped, fall back to feature-based detection
        if paddle_lang is None:
            return detect_language_by_features(text)
        return paddle_lang
    except Exception as e:
        print(f"Language detection error: {e}")
        # Fall back to feature-based detection on errors
        return detect_language_by_features(text)

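# For reference, fasttext's predict() returns parallel tuples of labels and
# probabilities, roughly like this (shown for illustration, assuming the
# lid.176.bin model is loaded):
#
#   lang_model.predict("Bonjour tout le monde")
#   -> (('__label__fr',), array([0.99...]))
#
# which is why the '__label__' prefix is stripped above before the code is
# mapped to a PaddleOCR language name.
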
def try_all_languages(image_path):
    """Run OCR with every supported language and return the best result."""
    best_result = None
    best_lang = "en"
    max_text_length = 0
    # Try every language
    for lang in LANG_CONFIG.keys():
        try:
            ocr = model_managers[lang]
            result = ocr.infer(image_path, cls=True)[0]
            if result:
                # Concatenate all recognized text
                all_text = " ".join([line[1][0] for line in result])
                text_length = len(all_text.strip())
                # The language that extracts the most text is assumed to be the right one
                if text_length > max_text_length:
                    max_text_length = text_length
                    best_result = result
                    best_lang = lang
                # For Chinese, return immediately once enough text has been extracted
                if lang == "ch" and text_length > 10:
                    return result, lang
        except Exception as e:
            print(f"OCR error ({lang}): {e}")
            continue
    return best_result, best_lang

def auto_detect_language(image_path):
    """Detect the language by letting several models vote."""
    # Try the Chinese and English models first
    languages_to_try = ["ch", "en"]
    results = {}
    detected_texts = {}
    for lang in languages_to_try:
        try:
            ocr = model_managers[lang]
            result = ocr.infer(image_path, cls=True)[0]
            if result:
                # Concatenate all recognized text
                all_text = " ".join([line[1][0] for line in result])
                detected_texts[lang] = all_text
                if all_text.strip():
                    # Detect the language of the recognized text with fasttext
                    detected = detect_language_with_fasttext(all_text)
                    results[detected] = results.get(detected, 0) + 1
        except Exception as e:
            print(f"OCR error ({lang}): {e}")
            continue
    # Check whether the Chinese result contains enough Chinese characters
    if "ch" in detected_texts:
        chinese_chars = sum(1 for char in detected_texts["ch"] if '\u4e00' <= char <= '\u9fff')
        if chinese_chars > 5:  # more than 5 Chinese characters
            return "ch"
    # If there is no result, or it is unreliable, try every language
    if not results:
        print("Language could not be detected reliably, trying all languages...")
        _, best_lang = try_all_languages(image_path)
        return best_lang
    # Return the language with the most votes
    return max(results.items(), key=lambda x: x[1])[0]

def save_base64_to_temp_file(base64_string):
    """Save a Base64-encoded image to a temporary file."""
    try:
        # Strip a possible data URI prefix
        if "base64," in base64_string:
            base64_string = base64_string.split("base64,")[1]
        # Decode the Base64 payload
        image_data = base64.b64decode(base64_string)
        # Write it to a temporary file
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
        temp_file.write(image_data)
        temp_file.close()
        return temp_file.name
    except Exception as e:
        raise ValueError(f"Error while processing the Base64 image: {str(e)}")

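# For reference, a payload accepted by the function above can be produced like
# this (illustrative comments only; the file name is a placeholder):
#
#   with open("en_example.jpg", "rb") as f:
#       payload = base64.b64encode(f.read()).decode("ascii")
#   tmp_path = save_base64_to_temp_file(payload)
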
def inference(img, return_text_only=True):
    """Run OCR on the input, detecting the language automatically."""
    temp_file = None
    try:
        # Normalize the input image
        if isinstance(img, str):
            if img.startswith("data:") or re.match(r'^[A-Za-z0-9+/=]+$', img):
                # Base64 input
                temp_file = save_base64_to_temp_file(img)
                img_path = temp_file
            else:
                # File path input
                img_path = img
        else:
            # Any other input type is passed through unchanged
            img_path = img
        # Detect the language automatically
        lang = auto_detect_language(img_path)
        print(f"Detected language: {LANG_MAP.get(lang, lang)}")
        # Run OCR with the detected language
        ocr = model_managers[lang]
        # Guard against a None/empty result (no text detected)
        result = ocr.infer(img_path, cls=True)[0] or []
        # If the result is empty or very short, try every language
        all_text = " ".join([line[1][0] for line in result])
        if len(all_text.strip()) < 5:
            print("Too little text recognized, trying all languages...")
            result, lang = try_all_languages(img_path)
            result = result or []
            print(f"Best language: {LANG_MAP.get(lang, lang)}")
        # Extract the text, boxes and confidence scores
        boxes = [line[0] for line in result]
        txts = [line[1][0] for line in result]
        scores = [line[1][1] for line in result]
        # Load the image for drawing the annotations
        pil_img = Image.open(img_path).convert("RGB")
        # Look for the simfang.ttf font file
        font_path = "./simfang.ttf"
        if not os.path.exists(font_path):
            # Try a few other likely locations
            possible_paths = [
                "./doc/fonts/simfang.ttf",
                "/usr/local/lib/python3.10/site-packages/paddleocr/doc/fonts/simfang.ttf",
                "/usr/local/lib/python3.10/site-packages/paddleocr/ppocr/utils/fonts/simfang.ttf"
            ]
            for path in possible_paths:
                if os.path.exists(path):
                    font_path = path
                    break
        if return_text_only:
            # Return the text only
            return "\n".join(txts), LANG_MAP.get(lang, lang)
        else:
            # Return the annotated image as well
            try:
                im_show = draw_ocr(pil_img, boxes, txts, scores, font_path=font_path)
                return im_show, "\n".join(txts), LANG_MAP.get(lang, lang)
            except Exception as e:
                print(f"Error while drawing the OCR result: {e}")
                # If drawing fails, return the original image and the text
                return pil_img, "\n".join(txts), LANG_MAP.get(lang, lang)
    finally:
        # Clean up the temporary file
        if temp_file and os.path.exists(temp_file):
            try:
                os.unlink(temp_file)
            except OSError:
                pass

def inference_with_image(img):
    """Return the annotated image together with the recognized text."""
    im_show, text, lang = inference(img, return_text_only=False)
    return im_show, text, lang

def inference_text_only(img):
    """Return the recognized text only."""
    text, lang = inference(img, return_text_only=True)
    return text, lang

def inference_base64(base64_string):
    """Run OCR on a Base64-encoded image and return the result."""
    if not base64_string or base64_string.strip() == "":
        return "Please provide a valid Base64-encoded image string", ""
    try:
        text, lang = inference(base64_string, return_text_only=True)
        return text, lang
    except Exception as e:
        return f"Error while processing the Base64 image: {str(e)}", ""

title = '🔍 PaddleOCR Intelligent Text Recognition'
description = '''
### Features
- Intelligent text recognition for Chinese, English, French, German, Korean and Japanese
- Automatic language detection, no manual selection needed
- Support for Base64-encoded images
- Returns both the recognized text and an annotated image

### How to use
- Upload an image or provide Base64-encoded image data
- The system detects the language automatically and runs OCR
- Inspect the recognized text and the annotated image
'''
examples = [
    ['en_example.jpg'],
    ['cn_example.jpg'],
    ['jp_example.jpg'],
]

# Custom CSS to polish the interface
css = """
.gradio-container {
    font-family: 'Roboto', 'Microsoft YaHei', sans-serif;
}
.output_image, .input_image {
    height: 30rem !important;
    width: 100% !important;
    object-fit: contain;
    border-radius: 8px;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.tabs {
    margin-top: 0.5rem;
}
.output-text {
    font-family: 'Courier New', monospace;
    line-height: 1.5;
    padding: 1rem;
    border-radius: 8px;
    background-color: #f8f9fa;
    border: 1px solid #e9ecef;
}
.detected-lang {
    font-weight: bold;
    color: #4285f4;
    margin-bottom: 0.5rem;
}
"""

# Build a richer interface with Gradio Blocks
with gr.Blocks(title=title, css=css) as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(description)
    with gr.Tabs() as tabs:
        # Image upload tab
        with gr.TabItem("Image upload"):
            with gr.Row():
                with gr.Column(scale=1):
                    image_input = gr.Image(label="Upload image", type="filepath")
                    image_submit = gr.Button("Run OCR", variant="primary")
                with gr.Column(scale=2):
                    with gr.Row():
                        image_output = gr.Image(label="Annotated result", type="pil")
                    with gr.Row():
                        detected_lang = gr.Textbox(label="Detected language", lines=1)
                    with gr.Row():
                        text_output = gr.Textbox(label="Recognized text", lines=10, elem_classes=["output-text"])
        # Base64 tab
        with gr.TabItem("Base64 image"):
            with gr.Row():
                with gr.Column(scale=1):
                    base64_input = gr.Textbox(
                        label="Base64-encoded image data",
                        lines=8,
                        placeholder="Paste Base64-encoded image data here..."
                    )
                    base64_submit = gr.Button("Run OCR", variant="primary")
                with gr.Column(scale=2):
                    base64_lang = gr.Textbox(label="Detected language", lines=1)
                    base64_output = gr.Textbox(
                        label="Recognized text",
                        lines=15,
                        elem_classes=["output-text"]
                    )
    # API usage notes
    with gr.Accordion("API usage", open=False):
        gr.Markdown("""
        ## How to call the API

        ### 1. Image upload API
        ```bash
        curl -X POST "http://localhost:7860/api/predict" \\
             -F "fn_index=0" \\
             -F "data=@/path/to/your/image.jpg"
        ```

        ### 2. Base64 image API
        ```bash
        curl -X POST "http://localhost:7860/api/predict" \\
             -H "Content-Type: application/json" \\
             -d '{
               "fn_index": 1,
               "data": ["YOUR_BASE64_STRING_HERE"]
             }'
        ```
        """)
    # Wire up the event handlers
    image_submit.click(
        fn=inference_with_image,
        inputs=[image_input],
        outputs=[image_output, text_output, detected_lang]
    )
    base64_submit.click(
        fn=inference_base64,
        inputs=[base64_input],
        outputs=[base64_output, base64_lang]
    )

# Launch the Gradio app
demo.launch(debug=False, share=False)
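
# Once the app is running, it can also be driven from Python with the
# `gradio_client` package (a minimal sketch; the URL and api_name below are
# assumptions and may need adjusting to the endpoints Gradio actually exposes
# for this Blocks app):
#
#   import base64
#   from gradio_client import Client
#
#   with open("en_example.jpg", "rb") as f:
#       payload = base64.b64encode(f.read()).decode("ascii")
#   client = Client("http://localhost:7860")
#   text, lang = client.predict(payload, api_name="/inference_base64")
#   print(lang)
#   print(text)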