musadac committed on
Commit
2c72c7b
1 Parent(s): 36debf9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -4
app.py CHANGED
@@ -76,19 +76,25 @@ tokenizer = MBartTokenizer.from_pretrained(
76
  'facebook/mbart-large-50'
77
  )
78
  processortext2 = CustomOCRProcessor(image_processor,tokenizer)
79
- import os
80
 
 
81
  huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")
82
- model2 = VisionEncoderDecoderModel.from_pretrained("musadac/vilanocr-single-urdu", use_auth_token=huggingface_token)
 
 
 
 
 
83
  st.title("Image OCR with musadac/vilanocr")
 
84
  uploaded_file = st.file_uploader("Choose an image", type=["jpg", "jpeg", "png"])
85
-
86
  if uploaded_file is not None:
 
87
  img = Image.open(uploaded_file).convert("RGB")
88
  pixel_values = processortext2(img.convert("RGB"), return_tensors="pt").pixel_values
89
 
90
  with torch.no_grad():
91
- generated_ids = model2.generate(pixel_values)
92
 
93
  result = processortext2.batch_decode(generated_ids, skip_special_tokens=True)[0]
94
  st.write("OCR Result:")
 
76
  'facebook/mbart-large-50'
77
  )
78
  processortext2 = CustomOCRProcessor(image_processor,tokenizer)
 
79
 
80
+ import os
81
  huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")
82
+ model = {}
83
+ model['single-urdu'] = VisionEncoderDecoderModel.from_pretrained("musadac/vilanocr-single-urdu", use_auth_token=huggingface_token)
84
+ model['multi-urdu'] = VisionEncoderDecoderModel.from_pretrained("musadac/ViLanOCR", use_auth_token=huggingface_token)
85
+ model['medical'] = VisionEncoderDecoderModel.from_pretrained("musadac/vilanocr-multi-medical", use_auth_token=huggingface_token)
86
+ model['chinese'] = VisionEncoderDecoderModel.from_pretrained("musadac/vilanocr-single-chinese", use_auth_token=huggingface_token)
87
+
88
  st.title("Image OCR with musadac/vilanocr")
89
+ model_name = st.selectbox("Choose an OCR model", ["single-urdu", "multi-urdu", "medical","chinese" ])
90
  uploaded_file = st.file_uploader("Choose an image", type=["jpg", "jpeg", "png"])
 
91
  if uploaded_file is not None:
92
+
93
  img = Image.open(uploaded_file).convert("RGB")
94
  pixel_values = processortext2(img.convert("RGB"), return_tensors="pt").pixel_values
95
 
96
  with torch.no_grad():
97
+ generated_ids = model[model_name].generate(pixel_values)
98
 
99
  result = processortext2.batch_decode(generated_ids, skip_special_tokens=True)[0]
100
  st.write("OCR Result:")