Amine-0047 commited on
Commit
733b6a1
1 Parent(s): 55cc802

Upload 2 files

Browse files
Files changed (2) hide show
  1. main.py +68 -0
  2. requirements.txt +5 -0
main.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel, VitsModel, AutoTokenizer
4
+ import torch
5
+ import yolov5
6
+
7
# Load YOLOv5 model
@st.cache_resource
def load_model():
    """Load and cache the YOLOv5 license-plate detector.

    Returns:
        The yolov5 model object loaded from the
        'keremberke/yolov5m-license-plate' checkpoint.
    """
    # st.cache_resource replaces the deprecated
    # st.cache(allow_output_mutation=True) idiom for unserializable
    # resources such as ML models (Streamlit >= 1.18; file pins 1.32.2).
    return yolov5.load('keremberke/yolov5m-license-plate')
11
+
12
# Load TR-OCR model
@st.cache_resource
def load_ocr_model():
    """Load and cache the TrOCR processor and model for text recognition.

    Returns:
        tuple: (processor, model) — the TrOCRProcessor and
        VisionEncoderDecoderModel for 'microsoft/trocr-base-handwritten'.
    """
    # st.cache_resource replaces the deprecated
    # st.cache(allow_output_mutation=True) idiom for model objects.
    processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
    model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
    return processor, model
18
+
19
# Load TTS model
@st.cache_resource
def load_tts_model():
    """Load and cache the MMS text-to-speech model and its tokenizer.

    Returns:
        tuple: (model, tokenizer) — the VitsModel and AutoTokenizer for
        'facebook/mms-tts-eng'.
    """
    # st.cache_resource replaces the deprecated
    # st.cache(allow_output_mutation=True) idiom for model objects.
    model = VitsModel.from_pretrained("facebook/mms-tts-eng")
    tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
    return model, tokenizer
25
+
26
# Main function for Streamlit app
def main():
    """Streamlit entry point: detect a license plate in an uploaded image,
    OCR its text with TrOCR, and read it aloud with MMS-TTS.
    """
    st.title("License Plate Recognition App")

    # Upload file
    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

    # Load models on startup (cached by the loaders, so cheap after first run)
    model = load_model()
    processor, ocr_model = load_ocr_model()
    tts_model, tokenizer = load_tts_model()

    # Guard clauses instead of nested ifs: nothing to do until an image
    # is uploaded and the button is pressed.
    if uploaded_file is None:
        return

    img = Image.open(uploaded_file)
    st.image(img, caption='Uploaded Image', use_column_width=True)

    if not st.button("Run Inference"):
        return

    # Run YOLOv5 detection at 640px input size.
    results = model(img, size=640)
    detections = results.xyxy[0]  # one row per box: x1, y1, x2, y2, score, class

    # BUG FIX: the original indexed detections[0] unconditionally and
    # raised IndexError when no plate was found; report it instead.
    if len(detections) == 0:
        st.warning("No license plate detected in the image.")
        return

    # Crop the image of the license plate (first detection's bounding box).
    x1, y1, x2, y2 = detections[0, :4].tolist()
    cropped_image = img.crop((x1, y1, x2, y2))
    st.image(cropped_image, caption='Plate detected')

    # Extract text from the image with TrOCR.
    pixel_values = processor(cropped_image, return_tensors="pt").pixel_values
    generated_ids = ocr_model.generate(pixel_values)
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    st.write("Detected License Plate Text:", generated_text)

    # Convert the text to audio; no_grad avoids building an autograd graph
    # for pure inference.
    inputs = tokenizer(generated_text, return_tensors="pt")
    with torch.no_grad():
        output = tts_model(**inputs).waveform
    st.audio(output.numpy(), format="audio/wav", sample_rate=tts_model.config.sampling_rate)


if __name__ == "__main__":
    main()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ Pillow==10.2.0
2
+ streamlit==1.32.2
3
+ torch==2.2.1
4
+ transformers==4.40.0
5
+ yolov5==7.0.13