| import os |
| import cv2 |
| import tempfile |
| import numpy as np |
| from PIL import Image, UnidentifiedImageError |
| import torch |
| from torchvision import models, transforms |
| from ultralytics import YOLO |
| import gradio as gr |
| import torch.nn as nn |
| import pandas as pd |
| from io import BytesIO |
|
|
| |
| |
| |
| |
|
|
| |
# Run inference on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _load_models(target_device):
    """Build the grain detector and the 3-class classifier; raise on any failure."""
    detector = YOLO('best.pt')

    # ResNet50 backbone without pretrained weights; the final layer is swapped
    # for a 3-output head before the saved state dict is loaded into it.
    classifier = models.resnet50(weights=None)
    classifier.fc = nn.Linear(classifier.fc.in_features, 3)
    # NOTE(review): torch.load unpickles the checkpoint - only load trusted files.
    state_dict = torch.load('rice_resnet_model.pth', map_location=target_device)
    classifier.load_state_dict(state_dict)
    classifier = classifier.to(target_device)
    classifier.eval()
    return detector, classifier


# Load both models once at import time. Any failure is reported on stdout and
# leaves the app in a "models unavailable" state instead of crashing the import.
try:
    detection_model, classifier_network = _load_models(device)
except Exception as e:
    print(f"Model initialization failed: {e}")
    detection_model, classifier_network = None, None
    models_loaded = False
else:
    models_loaded = True
|
|
| |
# One row per grain variety. The row position is the class index looked up
# after the classifier's argmax; the triple is the annotation colour. Read as
# (R, G, B), the triples match the red / blue / green legend shown in the UI.
_VARIETY_TABLE = (
    ("C9 Premium", (255, 100, 100)),
    ("Kant Special", (100, 100, 255)),
    ("Superfine Grade", (100, 255, 100)),
)

# Class index -> variety label.
VARIETY_MAP = {index: label for index, (label, _) in enumerate(_VARIETY_TABLE)}

# Variety label -> annotation colour triple.
VARIETY_COLORS = dict(_VARIETY_TABLE)
|
|
| |
# Per-channel mean / std used to normalise classifier inputs (these are the
# values commonly used for ImageNet-style torchvision models).
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

# Classifier input pipeline: resize to 224x224, convert to a tensor, normalise.
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
]
image_preprocessor = transforms.Compose(_preprocessing_steps)
|
|
| |
| |
| |
|
|
def classify_grain(grain_image):
    """
    Predict the variety of one cropped grain image.

    The crop is run through ``image_preprocessor``, given a leading batch
    dimension, moved to ``device`` and fed to ``classifier_network``; the
    index of the highest score is mapped to a label via ``VARIETY_MAP``.

    Returns the variety label, or the string "System Error" when the
    models failed to load.
    """
    if not models_loaded:
        return "System Error"

    single_item_batch = image_preprocessor(grain_image).unsqueeze(0)
    single_item_batch = single_item_batch.to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        scores = classifier_network(single_item_batch)
        predicted_index = scores.argmax(dim=1).item()

    return VARIETY_MAP[predicted_index]
|
|
def generate_distribution_report(variety_counts):
    """
    Build the Markdown distribution summary for a ``{variety: count}`` mapping.

    The report lists the total, one line per variety (highest count first)
    with its percentage and a 20-cell bar, and finally the dominant variety.
    Returns a fixed "no grains" message when the counts sum to zero.
    """
    grand_total = sum(variety_counts.values())
    if grand_total == 0:
        return "No grains detected for analysis."

    # Stable descending sort: ties keep the mapping's own order, so the first
    # entry is the same one max() would pick.
    ranked = sorted(variety_counts.items(), key=lambda item: item[1], reverse=True)

    lines = [
        "## Grain Distribution Report\n",
        f"Total Grains Detected: **{grand_total}**\n\n",
        "### Breakdown by Variety:\n",
    ]

    for name, grain_count in ranked:
        share = (grain_count / grand_total) * 100
        filled_cells = int(share / 5)  # one cell per 5 percentage points
        gauge = "█" * filled_cells + "░" * (20 - filled_cells)
        lines.append(f"- {name}: {grain_count} ({share:.1f}%) {gauge}\n")

    top_name = ranked[0][0]
    lines.append(f"\nDominant Variety: **{top_name}**\n")
    return "".join(lines)
|
|
def generate_csv_export(grain_details):
    """
    Write the per-grain results to a temporary CSV file for download.

    Args:
        grain_details: list of per-grain dicts (one CSV row each).

    Returns:
        Path of the created ``.csv`` file, or ``None`` when there is nothing
        to export. The file is created with ``delete=False`` so it survives
        after this function returns; removing it later is the caller's job.

    Raises:
        Whatever ``DataFrame.to_csv`` raises; the partially written temp
        file is removed before the error propagates.
    """
    if not grain_details:
        return None

    frame = pd.DataFrame(grain_details)

    # newline="" lets pandas control line endings itself (no doubled "\r" on
    # Windows); utf-8 matches pandas' default when it is given a path.
    tmp = tempfile.NamedTemporaryFile(
        mode="w", suffix=".csv", delete=False, newline="", encoding="utf-8"
    )
    try:
        # Write through the already-open handle instead of reopening the file
        # by name, and let the context manager close it even on failure.
        with tmp:
            frame.to_csv(tmp, index=False)
    except Exception:
        # Do not leave a half-written file behind in the temp directory.
        os.unlink(tmp.name)
        raise
    return tmp.name
|
|
def load_image_safe(input_image):
    """
    Load and validate an image from one of the supported input types.

    Args:
        input_image: a file path string, a PIL ``Image``, or a numpy array.

    Returns:
        A PIL image in RGB mode.

    Raises:
        gr.Error: if the input is ``None``, the path does not exist, the type
            is unsupported, or the image cannot be decoded. The underlying
            exception is chained as ``__cause__`` for debugging.
    """
    if input_image is None:
        raise gr.Error("Please upload an image to start analysis.")

    try:
        if isinstance(input_image, str):
            if not os.path.exists(input_image):
                raise gr.Error(f"Image file not found: {input_image}")
            # Image.open is lazy and keeps the file open; convert() forces the
            # pixel data to load and returns an independent image, so the
            # source handle can be closed deterministically right after.
            with Image.open(input_image) as source:
                return source.convert("RGB")

        if isinstance(input_image, Image.Image):
            return input_image.convert("RGB")

        if isinstance(input_image, np.ndarray):
            return Image.fromarray(input_image).convert("RGB")

        raise gr.Error(f"Unsupported image type: {type(input_image)}")

    except gr.Error:
        # Already a user-facing error raised above; pass it through untouched.
        raise
    except UnidentifiedImageError as exc:
        raise gr.Error(
            "Could not read the image file. It may be corrupted or in an unsupported format."
        ) from exc
    except Exception as exc:
        # Boundary handler: surface any other failure as a UI error message.
        raise gr.Error(f"Image loading failed: {exc}") from exc
|
|
def _draw_grain_annotation(canvas, corners, label, color):
    """Draw one grain's bounding box and filled caption onto ``canvas`` in place.

    ``color`` is interpreted in the channel order of ``canvas``.
    """
    x1, y1, x2, y2 = corners
    font, scale, thickness = cv2.FONT_HERSHEY_SIMPLEX, 0.7, 2

    cv2.rectangle(canvas, (x1, y1), (x2, y2), color, 3)

    (text_w, text_h), _ = cv2.getTextSize(label, font, scale, thickness)
    # The caption normally sits directly above the box. When the box touches
    # the top edge there is no room, so push the caption down far enough to
    # stay inside the image instead of drawing it at negative coordinates.
    caption_bottom = max(y1, text_h + 10)
    cv2.rectangle(
        canvas,
        (x1, caption_bottom - text_h - 10),
        (x1 + text_w, caption_bottom),
        color,
        cv2.FILLED,
    )
    cv2.putText(
        canvas, label, (x1, caption_bottom - 5), font, scale, (255, 255, 255), thickness
    )


def analyze_rice_image(input_image):
    """
    Full analysis pipeline:
    1. Validate and load image
    2. Detect grains
    3. Classify each grain
    4. Annotate image
    5. Generate distribution report
    6. Generate CSV export

    Args:
        input_image: anything accepted by ``load_image_safe``.

    Returns:
        A 3-tuple ``(annotated PIL image, Markdown report text, CSV path)``.
        When no grains are detected the original image, a hint message and
        ``None`` (no CSV) are returned instead.

    Raises:
        gr.Error: if the models are not loaded or the image cannot be loaded.
    """
    if not models_loaded:
        raise gr.Error("Analysis engine not available. Check model files.")

    pil_image = load_image_safe(input_image)

    # Pristine RGB pixels are kept for cropping; the detector is given a
    # BGR copy, as the original pipeline did.
    rgb_pixels = np.array(pil_image)
    bgr_pixels = cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2BGR)

    results = detection_model(bgr_pixels, verbose=False)[0]
    boxes = results.boxes.xyxy.cpu().numpy()

    if len(boxes) == 0:
        return (
            pil_image,
            "No grains detected. Try a clearer image.",
            None
        )

    variety_counts = {v: 0 for v in VARIETY_MAP.values()}
    grain_details = []

    # Annotations go onto a separate RGB canvas. Drawing in RGB makes the
    # VARIETY_COLORS triples render as the red/blue/green stated in the UI
    # legend (on a BGR canvas red and blue would be swapped), and cropping
    # from the untouched pixels keeps earlier boxes and captions out of the
    # classifier input of later grains.
    annotated = rgb_pixels.copy()
    img_h, img_w = rgb_pixels.shape[:2]

    for idx, box in enumerate(boxes):
        x1, y1, x2, y2 = map(int, box[:4])
        # Clamp to the image: a negative start index would otherwise wrap
        # around and slice from the opposite edge.
        x1, y1 = max(x1, 0), max(y1, 0)
        x2, y2 = min(x2, img_w), min(y2, img_h)
        if x2 <= x1 or y2 <= y1:
            # Degenerate box: nothing to classify or draw.
            continue

        grain_crop = Image.fromarray(rgb_pixels[y1:y2, x1:x2])
        variety_label = classify_grain(grain_crop)
        variety_counts[variety_label] += 1

        grain_details.append({
            "Grain_ID": f"G{idx+1:04d}",
            "Variety": variety_label,
            "X_center": (x1 + x2) // 2,
            "Y_center": (y1 + y2) // 2
        })

        _draw_grain_annotation(
            annotated, (x1, y1, x2, y2), variety_label, VARIETY_COLORS[variety_label]
        )

    report_text = generate_distribution_report(variety_counts)
    csv_path = generate_csv_export(grain_details)

    return (
        Image.fromarray(annotated),
        report_text,
        csv_path
    )
|
|
| |
| |
| |
|
|
# Stylesheet handed to gr.Blocks(css=...): sets the app-wide font and styles
# the "header-box" banner used by the HTML header at the top of the page.
custom_css = """
.gradio-container {
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
.header-box {
    background: linear-gradient(135deg, #1e5631 0%, #4c9a2a 100%);
    padding: 25px;
    border-radius: 12px;
    color: white;
    text-align: center;
    margin-bottom: 20px;
}
"""
|
|
| |
# Candidate gallery images, in the order they should be offered.
_all_samples = [
    "samples/rice3.jpg",
    "samples/rice2.jpg",
    "samples/rice4.jpg",
    "samples/rice5.jpg",
]
# Only offer samples that actually exist (paths are relative to the working
# directory), so a missing file never reaches the gallery.
sample_images = list(filter(os.path.exists, _all_samples))
|
|
# UI layout: a header banner, an "Analysis" tab (upload -> annotated image,
# report, CSV download), a "Documentation" tab, and an optional sample gallery.
with gr.Blocks(css=custom_css, title="Rice Classifier") as app:

    gr.HTML("""
        <div class="header-box">
            <h1>Rice Analyzer Pro</h1>
            <p>Advanced Grain Classification | Rice Grain Locator</p>
        </div>
    """)

    with gr.Tabs():

        with gr.Tab("Analysis"):
            gr.Markdown("""
                ### How to Use
                1. Upload a clear image of rice grains.
                2. Click **Start Analysis**.
                3. Review annotated results, distribution report, and download CSV.

                **Color Coding:** Red = C9 Premium Blue = Kant Special Green = Superfine Grade
            """)

            with gr.Row():
                with gr.Column(scale=1):
                    # type="pil" delivers the upload to the handler as a PIL image.
                    image_input = gr.Image(type="pil", label="Upload Sample Image", height=600, width=600)
                    start_btn = gr.Button("Start Analysis", variant="primary", size="lg")

                with gr.Column(scale=1):
                    annotated_output = gr.Image(label="Annotated Results", height=600, width=600)

            with gr.Row():
                report_output = gr.Markdown(label="Distribution Report")

            with gr.Row():
                csv_output = gr.File(label="Download CSV Export")

            # The handler returns (image, report, csv path) in this output order.
            start_btn.click(
                fn=analyze_rice_image,
                inputs=image_input,
                outputs=[annotated_output, report_output, csv_output]
            )

        with gr.Tab("Documentation"):
            gr.Markdown("""
                ## System Overview

                Rice Classifier uses a deep learning pipeline:

                1. **Grain Detection:** YOLO-based model identifies rice grains.
                2. **Grain Classification:** ResNet50 model classifies grains into three varieties.
                3. **CSV Export:** Detailed grain data available for download.

                ### Supported Varieties
                | Variety | Description |
                |---------|-------------|
                | C9 Premium | High-quality long grain |
                | Kant Special | Medium grain specialty |
                | Superfine Grade | Ultra-refined grain |

                ### Best Practices
                - Use well-lit images without shadows
                - Keep grains separated
                - Use plain backgrounds
                - Resolution: 1024x1024 or higher for best results

                ### Technical Details
                - Detection: YOLOv8
                - Classification: ResNet50 fine-tuned
                - GPU recommended for faster processing
            """)

    gr.Markdown("---")

    if sample_images:
        gr.Markdown("### Sample Gallery")
        gr.Examples(
            examples=sample_images,
            inputs=image_input,
            outputs=[annotated_output, report_output, csv_output],
            fn=analyze_rice_image,
            cache_examples=False,
            # Without caching, a click only fills the input unless
            # run_on_click is set; the label below promises a full run.
            run_on_click=True,
            label="Click any sample to run analysis"
        )
    else:
        gr.Markdown("*No sample images found. Add images to the `samples/` folder.*")
|
|
if __name__ == "__main__":
    # Enable the request queue, then start the web server (script runs only).
    app.queue().launch()