File size: 1,665 Bytes
dde79d4
 
 
 
 
0693cef
5c7cd99
 
 
0693cef
dde79d4
238c064
dde79d4
5c7cd99
 
 
dde79d4
 
 
0693cef
1778e34
0693cef
 
 
 
 
 
7b9bad8
1778e34
 
 
 
 
 
 
4fcf9ac
238c064
5c7cd99
 
 
 
0cdadfe
 
 
 
 
 
 
 
238c064
449029d
86ffcfa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import streamlit as st
from PIL import Image
import torch
from transformers import AutoImageProcessor
import pandas as pd
from transformers import ViTForImageClassification
from transformers import VitsModel, AutoTokenizer
import torch
from IPython.display import Audio

# Streamlit application title
st.title("Phonically describe traffic signs")

# Hugging Face checkpoints and local label file used by this app.
CLASSIFIER_CHECKPOINT = "Rae1230/Traffic_Signs_Classification"
PROCESSOR_CHECKPOINT = "google/vit-base-patch16-224"
TTS_CHECKPOINT = "facebook/mms-tts-eng"
LABELS_PATH = "labels.csv"


# Streamlit re-executes this whole script on every widget interaction, so the
# heavy model loads below are cached once per server process instead of being
# repeated on each rerun.
@st.cache_resource
def _load_classifier():
    """Load the ViT traffic-sign classifier and its image processor (cached)."""
    classifier = ViTForImageClassification.from_pretrained(CLASSIFIER_CHECKPOINT)
    image_processor = AutoImageProcessor.from_pretrained(PROCESSOR_CHECKPOINT)
    return classifier, image_processor


@st.cache_resource
def _load_tts():
    """Load the MMS English text-to-speech model and its tokenizer (cached)."""
    tts_model = VitsModel.from_pretrained(TTS_CHECKPOINT)
    tts_tokenizer = AutoTokenizer.from_pretrained(TTS_CHECKPOINT)
    return tts_model, tts_tokenizer


@st.cache_data
def _load_labels(path=LABELS_PATH):
    """Read the label table (expects 'ClassId' and 'Name' columns) from *path*."""
    return pd.read_csv(path)


def _lookup_label(labels, class_idx):
    """Return the first 'Name' whose 'ClassId' equals *class_idx*, or None if absent."""
    matches = labels.loc[labels["ClassId"] == class_idx, "Name"]
    return None if matches.empty else matches.iloc[0]


def _classify(img):
    """Return the predicted class index (argmax of the logits) for a PIL image."""
    classifier, image_processor = _load_classifier()
    pixel_inputs = image_processor(img.convert("RGB"), return_tensors="pt")
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = classifier(**pixel_inputs).logits
    return logits.argmax(-1).item()


def _synthesize(text):
    """Return (waveform, sampling_rate) for *text* spoken by the MMS TTS model."""
    tts_model, tts_tokenizer = _load_tts()
    token_inputs = tts_tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        waveform = tts_model(**token_inputs).waveform
    # A single string was tokenized, so take the only batch item to get a
    # 1-D mono signal for st.audio.
    return waveform[0].numpy(), tts_model.config.sampling_rate


# Traffic Sign Classification (kept as module-level names, as in the original script)
model, processor = _load_classifier()

uploaded_file = st.file_uploader("Choose a PNG image...", type="png", accept_multiple_files=False)
if uploaded_file is not None:
    img = Image.open(uploaded_file)
    st.image(img, caption='Uploaded Image.', use_column_width=True)

    img_class_idx = _classify(img)
    text_value = _lookup_label(_load_labels(), img_class_idx)
    if text_value is None:
        # Previously an uncaught IndexError; report it in the UI instead.
        st.error(f"Predicted class {img_class_idx} has no entry in {LABELS_PATH}.")
        st.stop()

    st.header("Classified traffic sign:")
    st.write(text_value)

    # Speak the traffic sign name
    speech, sampling_rate = _synthesize(text_value)
    st.header("Phonically describe this traffic sign:")
    st.audio(speech, sample_rate=sampling_rate)