Spaces:
Build error
Build error
| import json | |
| import os | |
| from typing import List | |
| import requests | |
| import easyocr | |
| from revChatGPT.V3 import Chatbot | |
| import time | |
# Holds easyocr Reader objects, keyed by the ', '-joined EasyOCR language
# codes they were built for (populated by get_ocr_model).
ocr_model = dict()
def covert_language(lng):
    """Translate a locale code into the matching EasyOCR language code.

    Only ``'zh-CN'`` needs translating (to ``'ch_sim'``, simplified Chinese);
    every other value is returned unchanged.
    """
    return 'ch_sim' if lng == 'zh-CN' else lng
def get_ocr_model(languages: List[str]):
    """Return an easyocr Reader for *languages*, building it only on first use.

    Language codes are normalised with ``covert_language`` first. Readers are
    stored in the module-level ``ocr_model`` dict under the ', '-joined codes
    and reused on later calls, so each combination is loaded only once.

    If constructing a reader for the requested languages raises, the error is
    printed and a reader for ``['ch_sim', 'en']`` is returned instead. An
    exception while building that fallback reader propagates to the caller.
    """
    languages = list(map(covert_language, languages))
    key = ', '.join(languages)
    if key in ocr_model:
        return ocr_model[key]
    try:
        ocr_model[key] = easyocr.Reader(languages)
        return ocr_model[key]
    except Exception as ex:
        # Presumably an unsupported/incompatible language combination;
        # log it and fall back to the default pair below.
        print(ex)

    fallback = ['ch_sim', 'en']
    fallback_key = ', '.join(fallback)
    if fallback_key not in ocr_model:
        ocr_model[fallback_key] = easyocr.Reader(fallback)
    return ocr_model[fallback_key]
def init_chatgpt():
    """Create a ChatGPT client whose ``'ocr'`` conversation has a system prompt.

    The API key comes from the ``CHATGPT_API_KEY`` environment variable; if it
    is unset, ``None`` is passed straight through to ``Chatbot``.
    """
    system_message = {
        "role": "system",
        "content": "You are a ocr bot, Collating the following ocr result output to turn them into dialog context",
    }
    api_key = os.getenv("CHATGPT_API_KEY")
    bot = Chatbot(api_key=api_key)
    # Seed the dedicated 'ocr' conversation with only the system message.
    bot.conversation['ocr'] = [system_message]
    return bot
# Keywords that get_tags looks for verbatim in a ChatGPT reply.
tags = [
    '猫娘',      # "catgirl"
    'girl',
    '霸道总裁',  # "domineering CEO" (romance-novel trope)
    '魅魔',      # "succubus"
    '小说',      # "novel"
]
def get_tags(txt, candidates=None):
    """Return the known tags that occur verbatim (substring match) in *txt*.

    Args:
        txt: Text to scan; ``ocr`` passes the raw ChatGPT reply.
        candidates: Optional iterable of tags to look for. Defaults to the
            module-level ``tags`` list.

    Returns:
        A list of ``{'tag': tag, 'confidence': 1}`` dicts, in the order the
        tags appear in *candidates*. Empty tags are skipped: ``'' in txt`` is
        always true, so an empty tag would otherwise match every input.
    """
    if candidates is None:
        candidates = tags
    return [
        {'tag': tag, 'confidence': 1}
        for tag in candidates
        if tag and tag in txt
    ]
def ocr(reader, chatbot, file_path: str, retries: int = 3):
    """OCR an image and ask ChatGPT to collate the raw boxes into Q/A dialog.

    Args:
        reader: OCR reader exposing ``readtext`` (e.g. from ``get_ocr_model``).
        chatbot: revChatGPT ``Chatbot`` used with the ``'ocr'`` conversation
            (e.g. from ``init_chatgpt``).
        file_path: Local path or http(s) URL; resolved through ``read_file``.
        retries: Maximum number of ChatGPT attempts.

    Returns:
        ``(items, extract_tags)`` on the first attempt whose reply parses as
        non-empty JSON. ``items`` is the parsed reply (the prompt asks for a
        list of objects with 'question'/'answer'). ``extract_tags`` is
        ``get_tags`` of the raw reply. If every attempt fails or parses to an
        empty value, returns ``[]`` (not a tuple), as the original did.
    """
    start_time = time.time()
    doc = reader.readtext(read_file(file_path))
    print("ocr cost time: ", time.time() - start_time)
    doc = str(doc)
    # Sentences are explicitly separated. The original concatenation ran them
    # together ("fix themthe input ...").
    start_txt = (
        "Collating the following ocr result output to turn them into conversation. "
        "Be careful that long text block may be multiple box, you should merge them. "
        "And the ocr text maybe misspelled, you should fix them. "
        "The input is the ocr box and text. "
        "The output format should be a valid json array, each item is a object, "
        "include 'question' and 'answer' field. "
        "If the text is new line, you should add '\\n' in text\n"
    )
    for _ in range(retries):
        try:
            start_time = time.time()
            resp = chatbot.ask(prompt=start_txt + doc, convo_id='ocr')
            print("chatgpt cost time: ", time.time() - start_time)
            extract_tags = get_tags(resp)
            ret = json.loads(resp, strict=False)
        except Exception as ex:
            # e.g. an API error from ask() or a reply that is not valid JSON.
            # Skip to the next attempt instead of reading unbound results.
            print(ex)
            continue
        finally:
            # Reset the conversation after every attempt, successful or not,
            # so retries do not accumulate earlier prompts and replies.
            chatbot.conversation['ocr'] = [
                {
                    "role": "system",
                    "content": "You are a ocr bot"
                },
            ]
        if ret:
            return ret, extract_tags
    return []
# HTTP headers sent by read_file when downloading remote images.
headers = {
    # Firefox-on-Ubuntu browser User-Agent string.
    "User-Agent": (
        "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:47.0) "
        "Gecko/20100101 Firefox/47.0"
    ),
}
def read_file(file_path: str):
    """Return a local path for *file_path*, downloading it first if it is a URL.

    Local paths (anything not starting with ``http://`` or ``https://``) are
    returned unchanged. URLs are downloaded once into the system temp
    directory, and the cached file is reused on later calls.

    The cache file name includes a hash of the full URL, query string
    included. Different URLs that share a basename, or differ only in their
    query, therefore do not collide on the same file.

    Raises:
        requests.HTTPError: if the server returns an error status. This keeps
            an error page from being cached and later fed to OCR as an image.
    """
    if not file_path.startswith(("http://", "https://")):
        return file_path

    import hashlib
    import tempfile

    name = file_path.split("/")[-1].split("?")[0]
    digest = hashlib.sha256(file_path.encode("utf-8")).hexdigest()[:16]
    tmp_path = os.path.join(tempfile.gettempdir(), f"img_{digest}_{name}")
    if not os.path.exists(tmp_path):
        # The timeout keeps a stalled server from hanging the caller forever.
        resp = requests.get(file_path, headers=headers, timeout=30)
        resp.raise_for_status()
        part_path = tmp_path + ".part"
        with open(part_path, "wb") as f:
            f.write(resp.content)
        # Move into place only after a complete write, so an interrupted
        # download never leaves a truncated file that would be reused.
        os.replace(part_path, tmp_path)
    return tmp_path