import gradio as gr import torch import torch.nn as nn import torch.nn.functional as FU from torchvision import transforms import numpy as np from ultralytics import YOLO from huggingface_hub import hf_hub_download import os import torch from pypinyin import pinyin from PIL import Image, ImageDraw, ImageEnhance import json import os import torchvision.transforms.functional as F from statistics import mean import logging # Set up logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # Initialize LLM as None - will be loaded lazily llm = None def load_llm(): """Lazy loading of LLM to avoid startup delays""" global llm if llm is None: try: logger.info("Loading LLM model...") from llama_cpp import Llama # Check if model exists locally first model_filename = "Yi-1.5-9B-Chat-Q6_K.gguf" local_model_path = os.path.join("./models", model_filename) if not os.path.exists(local_model_path): logger.info("Downloading LLM model from HuggingFace...") model_path = hf_hub_download( repo_id="IncreasingLoss/FineTunedTranslation_Yi-1.5-9B-Chat-Q6_K", filename=model_filename, local_dir="./models", local_dir_use_symlinks=False ) else: model_path = local_model_path logger.info(f"Using existing model at: {model_path}") # Initialize with conservative settings for HF Spaces llm = Llama( model_path=model_path, n_ctx=2048, # Reduced context size n_gpu_layers=0, # CPU only for HF Spaces compatibility verbose=False, n_threads=2, # Limit threads for shared environment use_mmap=True, # Memory mapping for efficiency use_mlock=False # Don't lock memory ) logger.info("LLM model loaded successfully!") except Exception as e: logger.error(f"Failed to load LLM: {e}") llm = None return llm """yolo model""" logger.info("Loading YOLO model...") user_device = "cuda" if torch.cuda.is_available() else "cpu" logger.info(f"Using device: {user_device}") try: detection_model = YOLO("model/yolo_chinese_m.pt").to(user_device).eval() logger.info("YOLO model loaded successfully!") except Exception as e: logger.error(f"Failed to load YOLO model: {e}") raise """LW-Vit Classifier""" class HSwish(nn.Module): """Hard Swish activation function""" def forward(self, x): out = x * FU.relu6(x + 3, inplace=True) / 6 return out class MV2_Block(nn.Module): """MobileNetV2 Inverted Residual Block with h-swish activation""" def __init__(self, in_channels, expand_channels, out_channels, stride): super().__init__() self.stride = stride self.use_res_connect = self.stride == 1 and in_channels == out_channels self.expand = nn.Sequential( nn.Conv2d(in_channels, expand_channels, kernel_size=1, stride=1, padding=0, bias=False), nn.BatchNorm2d(expand_channels), HSwish() ) self.depthwise = nn.Sequential( nn.Conv2d(expand_channels, expand_channels, kernel_size=3, stride=stride, padding=1, groups=expand_channels, bias=False), nn.BatchNorm2d(expand_channels), HSwish() ) self.project = nn.Sequential( nn.Conv2d(expand_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False), nn.BatchNorm2d(out_channels) ) def forward(self, x): if self.use_res_connect: return x + self.project(self.depthwise(self.expand(x))) else: return self.project(self.depthwise(self.expand(x))) class LW_ViT_Transformer_Block(nn.Module): """Lightweight Vision Transformer Block""" def __init__(self, in_channels, out_channels, patch_size=2, heads=4, dim=128): super().__init__() self.patch_size = patch_size self.in_channels = in_channels self.out_channels = out_channels self.dim = dim # Downsampling and channel adjustment self.downsample = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=1) self.channel_adjust = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1) # Calculate patch dimension patch_dim = out_channels * patch_size * patch_size # Add projection layer if needed self.projection = nn.Identity() if patch_dim == dim else nn.Linear(patch_dim, dim) # Transformer components self.norm1 = nn.LayerNorm(dim) self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=heads, batch_first=True) self.norm2 = nn.LayerNorm(dim) self.ffn = nn.Sequential( nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim) ) # Final processing self.final_conv = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False) self.norm_out = nn.BatchNorm2d(out_channels) self.act = HSwish() def forward(self, x): # Initial downsampling and channel adjustment x = self.downsample(x) x = self.channel_adjust(x) B, C, H, W = x.shape # Convert feature map to patches and embed # Reshape to [B, C, H/P, P, W/P, P] x_reshaped = x.reshape(B, C, H // self.patch_size, self.patch_size, W // self.patch_size, self.patch_size) # Permute to [B, H/P, W/P, C, P, P] x_permuted = x_reshaped.permute(0, 2, 4, 1, 3, 5) # Reshape to [B, H/P * W/P, C * P * P] patches = x_permuted.reshape(B, (H // self.patch_size) * (W // self.patch_size), -1) # Apply projection if needed patches = self.projection(patches) # Apply transformer operations normed_patches = self.norm1(patches) attn_out, _ = self.attn(normed_patches, normed_patches, normed_patches) patches = patches + attn_out normed_patches = self.norm2(patches) ffn_out = self.ffn(normed_patches) patches = patches + ffn_out # Reshape back to feature map # Reshape to [B, H/P, W/P, C, P, P] x_back = patches.reshape(B, H // self.patch_size, W // self.patch_size, C, self.patch_size, self.patch_size) # Permute to [B, C, H/P, P, W/P, P] x_back = x_back.permute(0, 3, 1, 4, 2, 5) # Reshape to [B, C, H, W] x_out = x_back.reshape(B, C, H, W) # Final processing out = self.final_conv(x_out) out = self.norm_out(out) out = self.act(out) return out class LW_ViT(nn.Module): """Lightweight Vision Transformer for Chinese Character Recognition""" def __init__(self, base_channels=16, num_classes=3892): super().__init__() # Stem layer self.stem = nn.Sequential( nn.Conv2d(3, base_channels, 3, stride=2, padding=1, bias=False), nn.BatchNorm2d(base_channels), HSwish() ) # Feature extraction layers self.features = nn.Sequential( # MV2 blocks MV2_Block(base_channels, 4*base_channels, base_channels, stride=1), MV2_Block(base_channels, 4*base_channels, base_channels, stride=2), MV2_Block(base_channels, 4*base_channels, base_channels, stride=1), MV2_Block(base_channels, 4*base_channels, base_channels, stride=2), # Transformer block LW_ViT_Transformer_Block( in_channels=base_channels, out_channels=2*base_channels, patch_size=2, heads=4, dim=128 ), # Final convolutional layer nn.Conv2d(2*base_channels, 4*base_channels, kernel_size=1, stride=1), nn.BatchNorm2d(4*base_channels), HSwish() ) # Global pooling and classifier self.global_pool = nn.AdaptiveAvgPool2d(1) self.dropout = nn.Dropout(0.1) # Added dropout for regularization self.classifier = nn.Sequential( nn.Linear(4*base_channels, 8*base_channels), nn.BatchNorm1d(8*base_channels), HSwish(), nn.Dropout(0.1), # Added dropout for regularization nn.Linear(8*base_channels, num_classes) ) def forward(self, x): x = self.stem(x) x = self.features(x) x = self.global_pool(x) x = torch.flatten(x, 1) x = self.dropout(x) x = self.classifier(x) return x logger.info("Loading classifier model...") try: classifier = LW_ViT() classifier.load_state_dict(torch.load("model/chinese_Character_Classification_handwritten_LW_ViT.pth", weights_only=True, map_location=user_device)) classifier = classifier.to(user_device) classifier.eval() logger.info("Classifier model loaded successfully!") except Exception as e: logger.error(f"Failed to load classifier: {e}") raise """load classes dict""" logger.info("Loading classes dictionary...") try: with open("model/chinese_classes.json", "r", encoding="utf-8") as f: classes_dict = json.load(f) logger.info(f"Loaded {len(classes_dict)} classes") except Exception as e: logger.error(f"Failed to load classes dictionary: {e}") raise """transforms""" class ToSquare(object): """ Transform to make images square by padding the shorter dimension """ def __init__(self, fill=0): self.fill = fill # Fill value for padding (0 = black) def __call__(self, img): w, h = img.size # If already square, return as is if w == h: return img # Calculate target size (max dimension) max_dim = max(w, h) # Calculate padding pad_w = (max_dim - w) // 2 pad_h = (max_dim - h) // 2 # Handle odd dimensions (extra pixel on one side) pad_w_extra = (max_dim - w) % 2 pad_h_extra = (max_dim - h) % 2 # Create padding list (left, top, right, bottom) padding = [pad_w, pad_h, pad_w + pad_w_extra, pad_h + pad_h_extra] # Create new padded image padded_img = F.pad(img, padding, self.fill) return padded_img class ConvertToRGB(object): """Convert image to RGB mode""" def __call__(self, img): # Convert any image mode to RGB (including RGBA) return img.convert('RGB') class Contrast(object): """Randomly adjust contrast""" def __init__(self, p=0.3, factor_range=(0.5, 1.5)): self.p = p self.factor_range = factor_range def __call__(self, img): enhancer = ImageEnhance.Contrast(img) return enhancer.enhance(1.5) test_transforms = transforms.Compose([ ToSquare(fill=255), transforms.Resize(128), ConvertToRGB(), transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5]) ]) """app functions""" def convert_to_pil(inputimg): # convert input to pil image if inputimg is None: return None if isinstance(inputimg, np.ndarray): pil_img = Image.fromarray(inputimg) elif hasattr(inputimg, 'convert'): # Check if it's already a PIL Image pil_img = inputimg else: pil_img = Image.fromarray(np.array(inputimg)) pil_img = pil_img.convert("RGB") return pil_img def draw_bboxes(img, results): # Create a copy of the input image to draw on draw_img = img.copy() draw = ImageDraw.Draw(draw_img) # Draw bounding boxes for detected objects for box in results[0].boxes.xyxy: x1, y1, x2, y2 = map(int, box) # Convert to integers draw.rectangle([x1, y1, x2, y2], outline=(255, 0, 0), width=2) return draw_img def crop_bboxes_by_reading_order(img, bboxes, written_vertical=False): """ Crop and return image regions in proper reading order. For vertical Chinese: - Group into columns by x-coordinate (clusters whose centers are within avg_width/2) - Order columns right-to-left - Within each column, order top-to-bottom by y For horizontal: - Group into rows by y-coordinate (clusters within avg_height/2) - Order rows top-to-bottom - Within each row, order left-to-right by x """ # 1) If bboxes is a Tensor, convert to Python list if hasattr(bboxes, "tolist"): bboxes = bboxes.tolist() # 2) Single box -> trivial crop if len(bboxes) <= 1: return [img.crop(tuple(map(int, bbox))) for bbox in bboxes] # 3) Compute centers and average glyph size centers = [((x1+x2)/2, (y1+y2)/2) for x1,y1,x2,y2 in bboxes] widths = [x2 - x1 for x1,y1,x2,y2 in bboxes] heights = [y2 - y1 for x1,y1,x2,y2 in bboxes] avg_w = mean(widths) avg_h = mean(heights) # 4) Choose clustering axis and threshold if written_vertical: coords = [c[0] for c in centers] # x-coords threshold = avg_w / 2 else: coords = [c[1] for c in centers] # y-coords threshold = avg_h / 2 # 5) Sort glyphs by the clustering axis order = sorted(range(len(coords)), key=lambda i: coords[i]) # 6) Build clusters clusters = [] current = [order[0]] current_mean = coords[order[0]] for idx in order[1:]: c = coords[idx] # if within threshold of this cluster's mean, add; else start new if abs(c - current_mean) <= threshold: current.append(idx) # update cluster mean incrementally current_mean = mean([coords[i] for i in current]) else: clusters.append(current) current = [idx] current_mean = c clusters.append(current) # 7) Order clusters if written_vertical: # Chinese vertical: rightmost column first clusters.sort( key=lambda grp: mean(coords[i] for i in grp), reverse=True ) else: # Horizontal: top row first clusters.sort( key=lambda grp: mean(coords[i] for i in grp) ) # 8) Within each cluster, sort by the orthogonal axis final_indices = [] for grp in clusters: if written_vertical: # sort top-to-bottom by y grp.sort(key=lambda i: centers[i][1]) else: # sort left-to-right by x grp.sort(key=lambda i: centers[i][0]) final_indices.extend(grp) # 9) Crop and return crops = [] for i in final_indices: x1, y1, x2, y2 = map(int, bboxes[i]) crops.append(img.crop((x1, y1, x2, y2))) return crops def move_slider(threshhold_slider, input_img): pil_img = convert_to_pil(input_img) if pil_img is None: return None with torch.inference_mode(): results = detection_model(source=pil_img, conf=threshhold_slider) draw_img = draw_bboxes(pil_img, results) return draw_img def select_image(evt: gr.SelectData): selected_index = evt.index return example_images[selected_index] def translate_text(threshhold_slider, is_vertical, input_img): # convert input to pil image pil_img = convert_to_pil(input_img) if pil_img is None: return "", "", "Please upload an image first." try: with torch.inference_mode(): results = detection_model(source=pil_img, conf=threshhold_slider) sorted_cropped_images = crop_bboxes_by_reading_order(img=pil_img, bboxes=results[0].boxes.xyxy, written_vertical=is_vertical) chinese_text = "" for crop in sorted_cropped_images: crop = test_transforms(crop) crop = crop.to(user_device) # Move tensor to device crop_batch = crop.unsqueeze(dim=0) # Add batch dimension classifier.to(user_device) logits = classifier(crop_batch) class_idx = logits.argmax(dim=1) class_idx_cpu = class_idx.cpu().item() # Convert to CPU integer chinese_charcter = classes_dict[class_idx_cpu] # Use integer index for list chinese_text += chinese_charcter # Generate pinyin pinyin_text = pinyin(chinese_text) pinyin_sentence = " " for pin in pinyin_text: pinyin_sentence += f"{pin[0]} " # Load LLM lazily only when needed for translation current_llm = load_llm() if current_llm is None: return chinese_text, pinyin_sentence, "Translation service unavailable - LLM failed to load." # Generate translation prompt = f""" You are a professional Chinese to English translator. 1. Translate the following Chinese text to natural, fluent English: "{chinese_text}" 2. Respond only with the translated English text. English translation: """ try: output = current_llm( prompt, max_tokens=min(len(chinese_text)*3, 256), # Conservative token limit stop=["", "\n\n", "Chinese:", "Chinese text:"], echo=False, temperature=0.3, frequency_penalty=0.5, presence_penalty=0.5, top_p=0.9, stream=False ) # Extract the translation translation = output["choices"][0]["text"].strip() # Clean up translation if it contains quotes try: if '"' in translation: start = translation.index('"') + 1 end = translation.rindex('"') translation = translation[start:end] except: pass return chinese_text, pinyin_sentence, translation except Exception as e: logger.error(f"Translation failed: {e}") return chinese_text, pinyin_sentence, f"Translation failed: {str(e)}" except Exception as e: logger.error(f"OCR processing failed: {e}") return "", "", f"OCR processing failed: {str(e)}" css = """ .centered-examples { margin: 0 auto !important; justify-content: center !important; gap: 8px !important; min-height: 150px !important; } .centered-examples .thumb { height: 100px !important; width: 100px !important; object-fit: cover !important; margin: 5px !important; } #my_media_gallery { min-height: 0 !important; max-height: none !important; height: auto !important; } #my_media_gallery * { min-height: 0 !important; } """ """gradio app""" logger.info("Setting up Gradio interface...") # Check if examples directory exists example_dir = "examples" example_images = [] if os.path.exists(example_dir): example_images = [os.path.join(example_dir, f) for f in os.listdir(example_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.webp'))] logger.info(f"Found {len(example_images)} example images") else: logger.warning(f"Examples directory '{example_dir}' not found") with gr.Blocks(css=css, title="DeepTranslate: Chinese OCR") as program: gr.Markdown("## DeepTranslate: Chinese OCR with translation to English") gr.Markdown("Upload or select an image and move the slider to detect characters.") gr.Markdown("Make sure that the input image is high resolution and not rotated in any way!") gr.Markdown("Spaces is very slow since its running on a 2 core cpu, expect translation times of 2-4 minutes. (12 seconds on a 4080)") #inputs with gr.Column(scale=1): if example_images: gallery = gr.Gallery(value=example_images, label="Example Images (Click to Select)", columns=6, height = "auto", allow_preview=False, elem_id="my_media_gallery", elem_classes=["centered-examples"]) with gr.Row(scale=2): input_img = gr.Image(label="Input Image ") detection_img = gr.Image(label="Detection Image", interactive=False) # slider and button with gr.Column(scale=1): with gr.Row(scale=3): threshhold_slider = gr.Slider(value=0.25, minimum=0, maximum=0.75, label="Detection Threshold", step=0.01) translate_button = gr.Button("Translate To English", variant="primary") is_vertical = gr.Checkbox(value=False, label="Vertical Chinese Text?", interactive=True) # outputs with gr.Column(scale=1): with gr.Row(scale=3): chinese_text = gr.TextArea(label="Chinese Text", max_lines=1000, interactive=False) pinyin_text = gr.TextArea(label="Chinese Pinyin", max_lines=1000, interactive=False) english_text = gr.TextArea(label="English Text", max_lines=1000, interactive=False) # function calling threshhold_slider.change(fn=move_slider, inputs=[threshhold_slider, input_img], outputs=[detection_img]) translate_button.click(fn=translate_text, inputs=[threshhold_slider, is_vertical, input_img], outputs=[chinese_text, pinyin_text, english_text]) if example_images: gallery.select(fn=select_image, inputs=None, outputs=input_img) logger.info("Gradio interface ready!") if __name__ == "__main__": program.launch(share=False, server_name="0.0.0.0", show_error=True)