detectweed / app.py
Esmaeilkiani's picture
Update app.py
d281fdb verified
from PIL import Image
import torchvision.transforms as transforms
# لود مدل YOLOv5
model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
# پیش‌پردازش تصویر
def preprocess_image(image_path):
image = Image.open(image_path)
transform = transforms.Compose([
transforms.Resize((640, 640)),
transforms.ToTensor()
])
return transform(image).unsqueeze(0)
# آموزش مدل
def train_model(data_dir, epochs=10):
# آماده‌سازی داده‌ها
dataset = ... # خواندن داده‌ها از data_dir
dataloader = ... # ایجاد DataLoader
# تنظیم پارامترهای آموزش
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()
for epoch in range(epochs):
for images, labels in dataloader:
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
print(f'Epoch {epoch+1}/{epochs}, Loss: {loss.item()}')
# تشخیص مناطق دارای گپ
def detect_gaps(image_path):
image = preprocess_image(image_path)
results = model(image)
return results
# مثال استفاده
image_path = '/content/Sugarcane-Cultivation-in-Tamil-Nadu-1.jpg'
results = detect_gaps(image_path)
print(results)