File size: 6,673 Bytes
0a86b5f
 
a26b24a
0a86b5f
 
ce0d2f1
 
0a86b5f
a26b24a
 
 
 
ce0d2f1
0a86b5f
 
a26b24a
0a86b5f
 
a26b24a
ce0d2f1
 
a26b24a
ce0d2f1
 
 
8cb1722
ce0d2f1
8cb1722
 
a26b24a
 
ce0d2f1
 
a26b24a
 
 
0a86b5f
a26b24a
ce0d2f1
a26b24a
 
 
 
ce0d2f1
 
 
 
 
 
 
a26b24a
ce0d2f1
 
 
a26b24a
0a86b5f
 
ce0d2f1
a26b24a
 
 
ce0d2f1
a26b24a
ce0d2f1
a26b24a
 
0a86b5f
a26b24a
 
0a86b5f
ce0d2f1
 
 
 
a26b24a
ce0d2f1
 
a26b24a
 
 
 
 
 
 
 
 
 
ce0d2f1
a26b24a
0a86b5f
ce0d2f1
 
 
a26b24a
 
 
 
 
 
 
ce0d2f1
 
a26b24a
ce0d2f1
0a86b5f
a26b24a
 
 
0a86b5f
9406f1c
795e7cf
a26b24a
0a86b5f
 
 
f24a63b
 
 
0a86b5f
 
f24a63b
 
0a86b5f
 
 
795e7cf
0a86b5f
 
ce5a771
f24a63b
0a86b5f
 
795e7cf
a26b24a
 
 
 
 
 
 
795e7cf
a26b24a
 
 
 
 
795e7cf
a26b24a
 
6cc43c1
 
a26b24a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a86b5f
 
f24a63b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
import os
import re
import tempfile
from collections import Counter
from difflib import SequenceMatcher

import numpy as np
import streamlit as st
import torch
import torchaudio
from transformers import pipeline

# -----------------------------
# 1) Model loading and utility functions
# -----------------------------

# Run on the GPU when PyTorch reports one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Whisper checkpoint used for Cantonese ASR and the language passed to the decoder prompt
MODEL_NAME = "alvanlii/whisper-small-cantonese"
language = "zh"

# Generation options handed to the pipeline through generate_kwargs
_asr_generate_kwargs = {
    "no_repeat_ngram_size": 3,
    "repetition_penalty": 1.15,
    "temperature": 0.7,
    "top_p": 0.97,
    "top_k": 40,
    "max_new_tokens": 400,
    "do_sample": True,
}

asr_pipe = pipeline(
    task="automatic-speech-recognition",
    model=MODEL_NAME,
    chunk_length_s=30,  # Adjust chunk size for memory handling
    device=device,
    generate_kwargs=_asr_generate_kwargs,
)

# Pin the decoder prompt to the configured language with the "transcribe" task
_decoder_prompt_ids = asr_pipe.tokenizer.get_decoder_prompt_ids(
    language=language,
    task="transcribe",
)
asr_pipe.model.config.forced_decoder_ids = _decoder_prompt_ids

# Remove repeated sentences that are highly similar
def remove_repeated_phrases(text):
    """Drop sentences that nearly duplicate the sentence kept just before them.

    The text is split after each CJK sentence-final mark (ideographic full
    stop, full-width exclamation mark, full-width question mark). A sentence
    is discarded when its ``SequenceMatcher`` ratio against the previously
    kept sentence exceeds 0.9. Empty fragments produced by the split (for
    example after a trailing punctuation mark) are skipped.

    Args:
        text: Transcript text, punctuation still present.

    Returns:
        The surviving sentences, stripped and joined with single spaces.
    """
    def is_similar(a, b):
        return SequenceMatcher(None, a, b).ratio() > 0.9

    # Code points are spelled out as escapes so the pattern survives
    # encoding round-trips: U+3002 (。), U+FF01 (！), U+FF1F (？).
    sentences = re.split(r'(?<=[\u3002\uff01\uff1f])', text)

    cleaned_sentences = []
    for sentence in sentences:
        sentence = sentence.strip()
        if not sentence:
            # Zero-width splitting leaves an empty tail after final punctuation.
            continue
        if cleaned_sentences and is_similar(sentence, cleaned_sentences[-1]):
            continue
        cleaned_sentences.append(sentence)
    return " ".join(cleaned_sentences)

# Remove punctuation from text
def remove_punctuation(text):
    """Return ``text`` with every character that is neither a word character nor whitespace deleted."""
    non_word_or_space = re.compile(r'[^\w\s]')
    stripped = non_word_or_space.sub('', text)
    return stripped

# Transcribe the audio using Whisper
def transcribe_audio(audio_path):
    """Transcribe an audio file and return de-duplicated, punctuation-free text.

    The file is loaded with ``torchaudio.load``; multi-channel audio is
    averaged down to one channel. Audio of 60 seconds or less is sent to
    ``asr_pipe`` in one call. Longer audio is cut into 55-second windows that
    advance by 50 seconds (5 seconds of overlap), each window is transcribed,
    and the raw transcripts are joined.

    Punctuation is removed only after ``remove_repeated_phrases`` has run,
    because that function relies on sentence punctuation to find repeats.

    Args:
        audio_path: Path to an audio file readable by ``torchaudio.load``.

    Returns:
        The cleaned transcript, or an empty string when the file holds no samples.
    """
    waveform, sample_rate = torchaudio.load(audio_path)

    # Convert multi-channel audio to mono if necessary
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    waveform = waveform.squeeze(0).numpy()
    total_samples = waveform.shape[0]
    if total_samples == 0:
        # Nothing to transcribe; avoid handing an empty array to the model.
        return ""

    duration = total_samples / sample_rate
    if duration <= 60:
        transcript = asr_pipe({"sampling_rate": sample_rate, "raw": waveform})["text"]
        return remove_punctuation(remove_repeated_phrases(transcript))

    # For audio longer than 60 seconds, process in overlapping chunks
    chunk_size = sample_rate * 55
    step_size = sample_rate * 50
    transcripts = []
    for start in range(0, total_samples, step_size):
        chunk = waveform[start:start + chunk_size]
        transcripts.append(asr_pipe({"sampling_rate": sample_rate, "raw": chunk})["text"])
        if start + chunk_size >= total_samples:
            # This window already reached the end of the audio; a further
            # window would only repeat samples from the overlap region.
            break

    # Keep punctuation until duplicates are removed, then strip it once.
    return remove_punctuation(remove_repeated_phrases(" ".join(transcripts)))

# Text-classification pipeline used to score transcripts
SENTIMENT_MODEL_NAME = "MonkeyDLLLLLLuffy/CustomModel-multilingual-sentiment-analysis-enhanced"
sentiment_pipe = pipeline(
    task="text-classification",
    model=SENTIMENT_MODEL_NAME,
    device=device,
)

# Perform sentiment analysis in chunks (at most 512 characters each)
def rate_quality(text):
    """Classify a transcript's sentiment and map it to a service-quality label.

    The text is sliced into 512-character pieces, each piece is classified by
    ``sentiment_pipe``, the model labels are translated through ``label_map``
    (unrecognised labels become ``"Unknown"``), and the most frequent mapped
    label wins. When several labels tie, the one that appeared first is
    returned, so the result does not vary between runs.

    Args:
        text: Transcript text to rate.

    Returns:
        One of "Very Poor", "Poor", "Neutral", "Good", "Very Good" or
        "Unknown". Empty or whitespace-only text yields "Unknown".
    """
    if not text or not text.strip():
        # No content to classify; a vote over zero results would fail.
        return "Unknown"

    chunks = [text[i:i + 512] for i in range(0, len(text), 512)]
    # 512 characters is not 512 tokens: truncation keeps each piece within
    # the model's maximum sequence length after special tokens are added.
    results = sentiment_pipe(chunks, batch_size=4, truncation=True)

    label_map = {
        "Very Negative": "Very Poor",
        "Negative": "Poor",
        "Neutral": "Neutral",
        "Positive": "Good",
        "Very Positive": "Very Good"
    }
    processed_results = [label_map.get(res["label"], "Unknown") for res in results]
    if not processed_results:
        return "Unknown"

    # Majority vote; Counter orders equal counts by first appearance.
    return Counter(processed_results).most_common(1)[0][0]

# -----------------------------
# 2) Main Streamlit application
# -----------------------------
def main():
    """Render the Streamlit page: upload audio, transcribe it, rate it, show results.

    Results are kept in ``st.session_state`` under ``"transcript"``,
    ``"quality_rating"`` and ``"uploaded_filename"`` so that reruns of the
    script do not repeat the model calls for a file that was already
    processed. The filename is recorded only after processing succeeds, so a
    failed run is retried on the next rerun.
    """
    st.set_page_config(page_title="Customer Service Analyzer", page_icon="🎙️")

    # Custom CSS styling
    st.markdown("""
    <style>
    .header {
        background: linear-gradient(90deg, #4B79A1, #283E51);
        border-radius: 10px;
        padding: 1.5rem;
        text-align: center;
        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
        margin-bottom: 1.5rem;
        color: white;
    }
    </style>
    """, unsafe_allow_html=True)

    st.markdown("""
    <div class="header">
        <h1 style='margin:0;'>🎙️ Customer Service Quality Analyzer</h1>
        <p>Evaluate the service quality with simple uploading!</p>
    </div>
    """, unsafe_allow_html=True)

    # Initialize session state to store results
    if "transcript" not in st.session_state:
        st.session_state["transcript"] = ""
    if "quality_rating" not in st.session_state:
        st.session_state["quality_rating"] = ""
    if "uploaded_filename" not in st.session_state:
        st.session_state["uploaded_filename"] = ""

    # File uploader
    uploaded_file = st.file_uploader(
        "📤 Please upload your Cantonese customer service audio file",
        type=["wav", "mp3", "flac"]
    )

    if uploaded_file is not None:
        # Play back with the upload's own MIME type so mp3/flac files are
        # not mislabeled as WAV.
        st.audio(uploaded_file, format=uploaded_file.type or "audio/wav")

        # Only run the model again if a new file is uploaded
        if st.session_state["uploaded_filename"] != uploaded_file.name:
            # Keep the original extension and use a unique path so that
            # concurrent sessions never overwrite each other's upload.
            suffix = os.path.splitext(uploaded_file.name)[1] or ".wav"
            temp_audio_path = None
            try:
                with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
                    tmp.write(uploaded_file.getbuffer())
                    temp_audio_path = tmp.name

                # Process the audio
                with st.spinner('🔄 Processing your audio, please wait...'):
                    transcript = transcribe_audio(temp_audio_path)
                    # Skip sentiment scoring when nothing was transcribed.
                    quality_rating = rate_quality(transcript) if transcript.strip() else ""
            finally:
                # Remove the temporary file even when processing raised
                if temp_audio_path and os.path.exists(temp_audio_path):
                    os.remove(temp_audio_path)

            # Store results in session state only after a successful run
            st.session_state["transcript"] = transcript
            st.session_state["quality_rating"] = quality_rating
            st.session_state["uploaded_filename"] = uploaded_file.name

            if not transcript.strip():
                st.warning("No speech could be transcribed from this audio file.")

    # Display results if available
    if st.session_state["transcript"]:
        st.write("**Transcript:**", st.session_state["transcript"])
        st.write("**Sentiment Analysis Result:**", st.session_state["quality_rating"])

        # Prepare download content
        result_text = (
            f"Transcript:\n{st.session_state['transcript']}\n\n"
            f"Sentiment Analysis Result: {st.session_state['quality_rating']}"
        )
        # Download button for the analysis report
        st.download_button(
            label="📥 Download Analysis Report",
            data=result_text,
            file_name="analysis_report.txt"
        )

    st.markdown(
        "❓If you encounter any issues, please contact customer support: "
        "📧 **example@hellotoby.com**"
    )

# Run the app only when this file is executed as the entry script
if __name__ == "__main__":
    main()