import json
import logging
import os
import random
from statistics import mean

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as FU
import torchvision.transforms.functional as F
from huggingface_hub import hf_hub_download
from PIL import Image, ImageDraw, ImageEnhance
from pypinyin import pinyin
from torchvision import transforms
from ultralytics import YOLO
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize LLM as None - will be loaded lazily
llm = None
def load_llm():
"""Lazy loading of LLM to avoid startup delays"""
global llm
if llm is None:
try:
logger.info("Loading LLM model...")
from llama_cpp import Llama
# Check if model exists locally first
model_filename = "Yi-1.5-9B-Chat-Q6_K.gguf"
local_model_path = os.path.join("./models", model_filename)
if not os.path.exists(local_model_path):
logger.info("Downloading LLM model from HuggingFace...")
model_path = hf_hub_download(
repo_id="IncreasingLoss/FineTunedTranslation_Yi-1.5-9B-Chat-Q6_K",
filename=model_filename,
local_dir="./models",
local_dir_use_symlinks=False
)
else:
model_path = local_model_path
logger.info(f"Using existing model at: {model_path}")
# Initialize with conservative settings for HF Spaces
llm = Llama(
model_path=model_path,
n_ctx=2048, # Reduced context size
n_gpu_layers=0, # CPU only for HF Spaces compatibility
verbose=False,
n_threads=2, # Limit threads for shared environment
use_mmap=True, # Memory mapping for efficiency
use_mlock=False # Don't lock memory
)
logger.info("LLM model loaded successfully!")
except Exception as e:
logger.error(f"Failed to load LLM: {e}")
llm = None
return llm
"""yolo model"""
logger.info("Loading YOLO model...")
user_device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {user_device}")
try:
detection_model = YOLO("model/yolo_chinese_m.pt").to(user_device).eval()
logger.info("YOLO model loaded successfully!")
except Exception as e:
logger.error(f"Failed to load YOLO model: {e}")
raise
"""LW-Vit Classifier"""
class HSwish(nn.Module):
"""Hard Swish activation function"""
def forward(self, x):
out = x * FU.relu6(x + 3, inplace=True) / 6
return out
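# h-swish(x) = x * relu6(x + 3) / 6, a cheap piecewise approximation of swish (x * sigmoid(x))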
class MV2_Block(nn.Module):
"""MobileNetV2 Inverted Residual Block with h-swish activation"""
def __init__(self, in_channels, expand_channels, out_channels, stride):
super().__init__()
self.stride = stride
self.use_res_connect = self.stride == 1 and in_channels == out_channels
self.expand = nn.Sequential(
nn.Conv2d(in_channels, expand_channels, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(expand_channels),
HSwish()
)
self.depthwise = nn.Sequential(
nn.Conv2d(expand_channels, expand_channels, kernel_size=3, stride=stride,
padding=1, groups=expand_channels, bias=False),
nn.BatchNorm2d(expand_channels),
HSwish()
)
self.project = nn.Sequential(
nn.Conv2d(expand_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(out_channels)
)
def forward(self, x):
if self.use_res_connect:
return x + self.project(self.depthwise(self.expand(x)))
else:
return self.project(self.depthwise(self.expand(x)))
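# Illustrative behavior of MV2_Block with the arguments LW_ViT passes below
# (base_channels=16): stride=1 with in_channels == out_channels keeps the shape
# ([B, 16, 64, 64] -> [B, 16, 64, 64]) and adds the residual shortcut; stride=2
# halves the spatial size ([B, 16, 64, 64] -> [B, 16, 32, 32]) with no shortcut.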
class LW_ViT_Transformer_Block(nn.Module):
"""Lightweight Vision Transformer Block"""
def __init__(self, in_channels, out_channels, patch_size=2, heads=4, dim=128):
super().__init__()
self.patch_size = patch_size
self.in_channels = in_channels
self.out_channels = out_channels
self.dim = dim
# Downsampling and channel adjustment
self.downsample = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=1)
self.channel_adjust = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1)
# Calculate patch dimension
patch_dim = out_channels * patch_size * patch_size
# Add projection layer if needed
self.projection = nn.Identity() if patch_dim == dim else nn.Linear(patch_dim, dim)
# Transformer components
self.norm1 = nn.LayerNorm(dim)
self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=heads, batch_first=True)
self.norm2 = nn.LayerNorm(dim)
self.ffn = nn.Sequential(
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Linear(dim * 4, dim)
)
# Final processing
self.final_conv = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
self.norm_out = nn.BatchNorm2d(out_channels)
self.act = HSwish()
def forward(self, x):
# Initial downsampling and channel adjustment
x = self.downsample(x)
x = self.channel_adjust(x)
B, C, H, W = x.shape
# Convert feature map to patches and embed
# Reshape to [B, C, H/P, P, W/P, P]
x_reshaped = x.reshape(B, C, H // self.patch_size, self.patch_size,
W // self.patch_size, self.patch_size)
# Permute to [B, H/P, W/P, C, P, P]
x_permuted = x_reshaped.permute(0, 2, 4, 1, 3, 5)
# Reshape to [B, H/P * W/P, C * P * P]
patches = x_permuted.reshape(B, (H // self.patch_size) * (W // self.patch_size), -1)
# Apply projection if needed
patches = self.projection(patches)
# Apply transformer operations
normed_patches = self.norm1(patches)
attn_out, _ = self.attn(normed_patches, normed_patches, normed_patches)
patches = patches + attn_out
normed_patches = self.norm2(patches)
ffn_out = self.ffn(normed_patches)
patches = patches + ffn_out
# Reshape back to feature map
# Reshape to [B, H/P, W/P, C, P, P]
x_back = patches.reshape(B, H // self.patch_size, W // self.patch_size,
C, self.patch_size, self.patch_size)
# Permute to [B, C, H/P, P, W/P, P]
x_back = x_back.permute(0, 3, 1, 4, 2, 5)
# Reshape to [B, C, H, W]
x_out = x_back.reshape(B, C, H, W)
# Final processing
out = self.final_conv(x_out)
out = self.norm_out(out)
out = self.act(out)
return out
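# Illustrative shape trace for the block above, using the configuration LW_ViT passes in
# (in_channels=16, out_channels=32, patch_size=2, dim=128):
#   input         [B, 16, 16, 16]
#   downsample -> [B, 16, 8, 8]    (stride-2 3x3 conv)
#   adjust     -> [B, 32, 8, 8]    (1x1 conv)
#   patchify   -> [B, 16, 128]     (4x4 grid of 2x2 patches; 32*2*2 = 128 = dim, so projection is Identity)
#   attention and FFN keep [B, 16, 128]; the patches are then folded back to [B, 32, 8, 8]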
class LW_ViT(nn.Module):
"""Lightweight Vision Transformer for Chinese Character Recognition"""
def __init__(self, base_channels=16, num_classes=3892):
super().__init__()
# Stem layer
self.stem = nn.Sequential(
nn.Conv2d(3, base_channels, 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(base_channels),
HSwish()
)
# Feature extraction layers
self.features = nn.Sequential(
# MV2 blocks
MV2_Block(base_channels, 4*base_channels, base_channels, stride=1),
MV2_Block(base_channels, 4*base_channels, base_channels, stride=2),
MV2_Block(base_channels, 4*base_channels, base_channels, stride=1),
MV2_Block(base_channels, 4*base_channels, base_channels, stride=2),
# Transformer block
LW_ViT_Transformer_Block(
in_channels=base_channels,
out_channels=2*base_channels,
patch_size=2,
heads=4,
dim=128
),
# Final convolutional layer
nn.Conv2d(2*base_channels, 4*base_channels, kernel_size=1, stride=1),
nn.BatchNorm2d(4*base_channels),
HSwish()
)
# Global pooling and classifier
self.global_pool = nn.AdaptiveAvgPool2d(1)
self.dropout = nn.Dropout(0.1) # Added dropout for regularization
self.classifier = nn.Sequential(
nn.Linear(4*base_channels, 8*base_channels),
nn.BatchNorm1d(8*base_channels),
HSwish(),
nn.Dropout(0.1), # Added dropout for regularization
nn.Linear(8*base_channels, num_classes)
)
def forward(self, x):
x = self.stem(x)
x = self.features(x)
x = self.global_pool(x)
x = torch.flatten(x, 1)
x = self.dropout(x)
x = self.classifier(x)
return x
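# Illustrative end-to-end shape trace (assumes a 3x128x128 input and the default base_channels=16):
#   stem (stride 2)        -> [B, 16, 64, 64]
#   MV2 blocks (s=1,2,1,2) -> [B, 16, 16, 16]
#   transformer block      -> [B, 32,  8,  8]
#   1x1 conv               -> [B, 64,  8,  8]
#   global pool + flatten  -> [B, 64]
#   classifier             -> [B, 3892]  (one logit per character class)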
logger.info("Loading classifier model...")
try:
classifier = LW_ViT()
classifier.load_state_dict(torch.load("model/chinese_Character_Classification_handwritten_LW_ViT.pth", weights_only=True, map_location=user_device))
classifier = classifier.to(user_device)
classifier.eval()
logger.info("Classifier model loaded successfully!")
except Exception as e:
logger.error(f"Failed to load classifier: {e}")
raise
"""load classes dict"""
logger.info("Loading classes dictionary...")
try:
with open("model/chinese_classes.json", "r", encoding="utf-8") as f:
classes_dict = json.load(f)
logger.info(f"Loaded {len(classes_dict)} classes")
except Exception as e:
logger.error(f"Failed to load classes dictionary: {e}")
raise
"""transforms"""
class ToSquare(object):
"""
Transform to make images square by padding the shorter dimension
"""
def __init__(self, fill=0):
self.fill = fill # Fill value for padding (0 = black)
def __call__(self, img):
w, h = img.size
# If already square, return as is
if w == h:
return img
# Calculate target size (max dimension)
max_dim = max(w, h)
# Calculate padding
pad_w = (max_dim - w) // 2
pad_h = (max_dim - h) // 2
# Handle odd dimensions (extra pixel on one side)
pad_w_extra = (max_dim - w) % 2
pad_h_extra = (max_dim - h) % 2
# Create padding list (left, top, right, bottom)
padding = [pad_w, pad_h, pad_w + pad_w_extra, pad_h + pad_h_extra]
# Create new padded image
padded_img = F.pad(img, padding, self.fill)
return padded_img
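# Worked example for ToSquare (illustrative): a 60x100 (WxH) crop has max_dim=100,
# pad_w=(100-60)//2=20 and pad_h=0, so padding=[20, 0, 20, 0] and the result is a
# centered 100x100 image; test_transforms below uses fill=255 for a white background.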
class ConvertToRGB(object):
"""Convert image to RGB mode"""
def __call__(self, img):
# Convert any image mode to RGB (including RGBA)
return img.convert('RGB')
class Contrast(object):
    """With probability p, scale contrast by a factor drawn uniformly from factor_range"""
    def __init__(self, p=0.3, factor_range=(0.5, 1.5)):
        self.p = p
        self.factor_range = factor_range
    def __call__(self, img):
        # Randomly adjust contrast with probability p; otherwise return the image unchanged
        if random.random() >= self.p:
            return img
        factor = random.uniform(*self.factor_range)
        return ImageEnhance.Contrast(img).enhance(factor)
test_transforms = transforms.Compose([
ToSquare(fill=255),
transforms.Resize(128),
ConvertToRGB(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5], std=[0.5])
])
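# Note on the pipeline above: Normalize with single-element mean/std broadcasts the same
# (0.5, 0.5) normalization across all three RGB channels, mapping pixels from [0, 1] to [-1, 1].
# Illustrative example (assumed input, not part of the app flow):
#   test_transforms(Image.new("RGB", (60, 100), "white")) -> tensor of shape [3, 128, 128]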
"""app functions"""
def convert_to_pil(inputimg):
# convert input to pil image
if inputimg is None:
return None
if isinstance(inputimg, np.ndarray):
pil_img = Image.fromarray(inputimg)
elif hasattr(inputimg, 'convert'): # Check if it's already a PIL Image
pil_img = inputimg
else:
pil_img = Image.fromarray(np.array(inputimg))
pil_img = pil_img.convert("RGB")
return pil_img
def draw_bboxes(img, results):
# Create a copy of the input image to draw on
draw_img = img.copy()
draw = ImageDraw.Draw(draw_img)
# Draw bounding boxes for detected objects
for box in results[0].boxes.xyxy:
x1, y1, x2, y2 = map(int, box) # Convert to integers
draw.rectangle([x1, y1, x2, y2], outline=(255, 0, 0), width=2)
return draw_img
def crop_bboxes_by_reading_order(img, bboxes, written_vertical=False):
"""
Crop and return image regions in proper reading order.
For vertical Chinese:
- Group into columns by x-coordinate (clusters whose centers are within avg_width/2)
- Order columns right-to-left
- Within each column, order top-to-bottom by y
For horizontal:
- Group into rows by y-coordinate (clusters within avg_height/2)
- Order rows top-to-bottom
- Within each row, order left-to-right by x
"""
# 1) If bboxes is a Tensor, convert to Python list
if hasattr(bboxes, "tolist"):
bboxes = bboxes.tolist()
# 2) Single box -> trivial crop
if len(bboxes) <= 1:
return [img.crop(tuple(map(int, bbox))) for bbox in bboxes]
# 3) Compute centers and average glyph size
centers = [((x1+x2)/2, (y1+y2)/2) for x1,y1,x2,y2 in bboxes]
widths = [x2 - x1 for x1,y1,x2,y2 in bboxes]
heights = [y2 - y1 for x1,y1,x2,y2 in bboxes]
avg_w = mean(widths)
avg_h = mean(heights)
# 4) Choose clustering axis and threshold
if written_vertical:
coords = [c[0] for c in centers] # x-coords
threshold = avg_w / 2
else:
coords = [c[1] for c in centers] # y-coords
threshold = avg_h / 2
# 5) Sort glyphs by the clustering axis
order = sorted(range(len(coords)), key=lambda i: coords[i])
# 6) Build clusters
clusters = []
current = [order[0]]
current_mean = coords[order[0]]
for idx in order[1:]:
c = coords[idx]
# if within threshold of this cluster's mean, add; else start new
if abs(c - current_mean) <= threshold:
current.append(idx)
# update cluster mean incrementally
current_mean = mean([coords[i] for i in current])
else:
clusters.append(current)
current = [idx]
current_mean = c
clusters.append(current)
# 7) Order clusters
if written_vertical:
# Chinese vertical: rightmost column first
clusters.sort(
key=lambda grp: mean(coords[i] for i in grp),
reverse=True
)
else:
# Horizontal: top row first
clusters.sort(
key=lambda grp: mean(coords[i] for i in grp)
)
# 8) Within each cluster, sort by the orthogonal axis
final_indices = []
for grp in clusters:
if written_vertical:
# sort top-to-bottom by y
grp.sort(key=lambda i: centers[i][1])
else:
# sort left-to-right by x
grp.sort(key=lambda i: centers[i][0])
final_indices.extend(grp)
# 9) Crop and return
crops = []
for i in final_indices:
x1, y1, x2, y2 = map(int, bboxes[i])
crops.append(img.crop((x1, y1, x2, y2)))
return crops
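# Illustrative sketch (an addition, not part of the original pipeline): a tiny demo of the
# reading-order logic on synthetic boxes. The function name and box coordinates are
# hypothetical, and the app never calls it.
def _demo_reading_order():
    """Expected ordering for two rows of horizontal text, assuming 20x20 glyph boxes."""
    canvas = Image.new("RGB", (100, 60), "white")
    boxes = [
        [40, 5, 60, 25],   # row 1, middle glyph
        [5, 35, 25, 55],   # row 2, only glyph
        [5, 5, 25, 25],    # row 1, left glyph
        [70, 5, 90, 25],   # row 1, right glyph
    ]
    # Boxes are clustered into rows by y-center, rows read top-to-bottom, each row
    # left-to-right: expected crop order is indices 2, 0, 3, then 1.
    return crop_bboxes_by_reading_order(canvas, boxes, written_vertical=False)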
def move_slider(threshhold_slider, input_img):
pil_img = convert_to_pil(input_img)
if pil_img is None:
return None
with torch.inference_mode():
results = detection_model(source=pil_img, conf=threshhold_slider)
draw_img = draw_bboxes(pil_img, results)
return draw_img
def select_image(evt: gr.SelectData):
selected_index = evt.index
return example_images[selected_index]
def translate_text(threshhold_slider, is_vertical, input_img):
# convert input to pil image
pil_img = convert_to_pil(input_img)
if pil_img is None:
return "", "", "Please upload an image first."
    try:
        with torch.inference_mode():
            results = detection_model(source=pil_img, conf=threshhold_slider)
            sorted_cropped_images = crop_bboxes_by_reading_order(img=pil_img, bboxes=results[0].boxes.xyxy, written_vertical=is_vertical)
            chinese_text = ""
            for crop in sorted_cropped_images:
                crop_tensor = test_transforms(crop).to(user_device)  # Preprocess and move to device
                crop_batch = crop_tensor.unsqueeze(dim=0)  # Add batch dimension
                logits = classifier(crop_batch)  # Classifier was already moved to user_device at load time
                class_idx = logits.argmax(dim=1).cpu().item()  # Convert to a plain Python int
                chinese_character = classes_dict[class_idx]  # classes_dict is a list indexed by class id
                chinese_text += chinese_character
        # Generate pinyin: pinyin() returns one [syllable] list per character
        pinyin_sentence = " ".join(pin[0] for pin in pinyin(chinese_text))
# Load LLM lazily only when needed for translation
current_llm = load_llm()
if current_llm is None:
return chinese_text, pinyin_sentence, "Translation service unavailable - LLM failed to load."
# Generate translation
prompt = f""" You are a professional Chinese to English translator.
1. Translate the following Chinese text to natural, fluent English:
"{chinese_text}"
2. Respond only with the translated English text.
English translation: """
try:
output = current_llm(
prompt,
max_tokens=min(len(chinese_text)*3, 256), # Conservative token limit
stop=["</s>", "\n\n", "Chinese:", "Chinese text:"],
echo=False,
temperature=0.3,
frequency_penalty=0.5,
presence_penalty=0.5,
top_p=0.9,
stream=False
)
# Extract the translation
translation = output["choices"][0]["text"].strip()
            # Strip a pair of surrounding quotes from the translation, if present
            if translation.count('"') >= 2:
                start = translation.index('"') + 1
                end = translation.rindex('"')
                translation = translation[start:end]
return chinese_text, pinyin_sentence, translation
except Exception as e:
logger.error(f"Translation failed: {e}")
return chinese_text, pinyin_sentence, f"Translation failed: {str(e)}"
except Exception as e:
logger.error(f"OCR processing failed: {e}")
return "", "", f"OCR processing failed: {str(e)}"
css = """
.centered-examples {
margin: 0 auto !important;
justify-content: center !important;
gap: 8px !important;
min-height: 150px !important;
}
.centered-examples .thumb {
height: 100px !important;
width: 100px !important;
object-fit: cover !important;
margin: 5px !important;
}
#my_media_gallery {
min-height: 0 !important;
max-height: none !important;
height: auto !important;
}
#my_media_gallery * {
min-height: 0 !important;
}
"""
"""gradio app"""
logger.info("Setting up Gradio interface...")
# Check if examples directory exists
example_dir = "examples"
example_images = []
if os.path.exists(example_dir):
example_images = [os.path.join(example_dir, f) for f in os.listdir(example_dir)
if f.lower().endswith(('.png', '.jpg', '.jpeg', '.webp'))]
logger.info(f"Found {len(example_images)} example images")
else:
logger.warning(f"Examples directory '{example_dir}' not found")
with gr.Blocks(css=css, title="DeepTranslate: Chinese OCR") as program:
gr.Markdown("## DeepTranslate: Chinese OCR with translation to English")
gr.Markdown("Upload or select an image and move the slider to detect characters.")
gr.Markdown("Make sure that the input image is high resolution and not rotated in any way!")
gr.Markdown("Spaces is very slow since its running on a 2 core cpu, expect translation times of 2-4 minutes. (12 seconds on a 4080)")
#inputs
with gr.Column(scale=1):
if example_images:
gallery = gr.Gallery(value=example_images,
label="Example Images (Click to Select)",
columns=6,
                             height="auto",
allow_preview=False,
elem_id="my_media_gallery",
elem_classes=["centered-examples"])
with gr.Row(scale=2):
input_img = gr.Image(label="Input Image ")
detection_img = gr.Image(label="Detection Image", interactive=False)
# slider and button
with gr.Column(scale=1):
with gr.Row(scale=3):
threshhold_slider = gr.Slider(value=0.25, minimum=0, maximum=0.75, label="Detection Threshold", step=0.01)
translate_button = gr.Button("Translate To English", variant="primary")
is_vertical = gr.Checkbox(value=False, label="Vertical Chinese Text?", interactive=True)
# outputs
with gr.Column(scale=1):
with gr.Row(scale=3):
chinese_text = gr.TextArea(label="Chinese Text", max_lines=1000, interactive=False)
pinyin_text = gr.TextArea(label="Chinese Pinyin", max_lines=1000, interactive=False)
english_text = gr.TextArea(label="English Text", max_lines=1000, interactive=False)
# function calling
threshhold_slider.change(fn=move_slider, inputs=[threshhold_slider, input_img], outputs=[detection_img])
translate_button.click(fn=translate_text, inputs=[threshhold_slider, is_vertical, input_img], outputs=[chinese_text, pinyin_text, english_text])
if example_images:
gallery.select(fn=select_image, inputs=None, outputs=input_img)
logger.info("Gradio interface ready!")
if __name__ == "__main__":
program.launch(share=False, server_name="0.0.0.0", show_error=True)