import os
# Disable hf_transfer for safer downloading
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
import gradio as gr
import requests
from sentence_transformers import SentenceTransformer, util
import torch
import json
import urllib.parse
import soundfile as sf
import time

# Fetch the Hugging Face API token securely from the "HF" environment variable
HF_API_TOKEN = os.getenv("HF")

# Inference API endpoints for Whisper (speech-to-text) and Tamil LLaMA (text generation)
WHISPER_API_URL = "https://api-inference.huggingface.co/models/openai/whisper-small"
LLAMA_API_URL = "https://api-inference.huggingface.co/models/abhinand/tamil-llama-instruct-v0.2"

# Load SentenceTransformer model for retrieval
retriever_model = SentenceTransformer("distiluse-base-multilingual-cased-v2")

# Load dataset
with open("qa_dataset.json", "r", encoding="utf-8") as f:
    qa_data = json.load(f)
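
# Pre-compute embeddings for all stored questions once at startup, so each incoming
# query only needs to encode itself (assumes every entry has "question" and "answer" keys)
qa_embeddings = retriever_model.encode([qa["question"] for qa in qa_data], convert_to_tensor=True)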

# Poll the Inference API until the model reports it is loaded (or the timeout expires)
def wait_for_model_ready(model_url, headers, timeout=300):
    start_time = time.time()
    while time.time() - start_time < timeout:
        # Probe the endpoint; an "error: ... loading" payload means the model is still warming up
        response = requests.get(model_url, headers=headers)
        result = response.json()

        if not ("error" in result and "loading" in result["error"].lower()):
            print("✅ Model is ready!")
            return True

        print("⏳ Model is still loading, waiting 10 seconds...")
        time.sleep(10)

    print("❌ Model did not become ready in time.")
    return False  # timeout
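
# Note: the hosted Inference API also supports an "x-wait-for-model: true" request header
# that blocks server-side until the model is loaded, which would make this polling helper
# unnecessary; it is kept here so progress can be logged while waiting.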

def transcribe_audio(audio_file):
    headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}

    # Wait for Whisper model to be ready
    if not wait_for_model_ready(WHISPER_API_URL, headers):
        return "Error: Whisper model did not load in time. Please try again later."

    # Now send the audio after model is ready
    with open(audio_file, "rb") as f:
        response = requests.post(WHISPER_API_URL, headers=headers, data=f)

    result = response.json()
    print("[DEBUG] Whisper response:", result)

    return result.get("text", "Error: No transcription text returned.")

# Function to generate TTS audio URL (Google Translate API for Tamil Voice)
def get_tts_audio_url(text, lang="ta"):
    # URL encode the text to ensure special characters are handled
    safe_text = urllib.parse.quote(text)
    return f"https://translate.google.com/translate_tts?ie=UTF-8&q={safe_text}&tl={lang}&client=tw-ob"

# Retrieve the closest Q&A pair via embedding similarity, then ask LLaMA to answer with it as context
def get_bot_response(query):
    print("[INFO] User Query:", query)

    query_embedding = retriever_model.encode(query, convert_to_tensor=True)

    # Cosine similarity between the query and the pre-computed question embeddings
    scores = util.pytorch_cos_sim(query_embedding, qa_embeddings)
    best_idx = int(torch.argmax(scores))

    top_qa = qa_data[best_idx]
    print("[INFO] Best Match Question:", top_qa['question'])
    print("[INFO] Best Match Answer:", top_qa['answer'])

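    # Tamil prompt, roughly: "You are a knowledgeable Tamil assistant. Information:
    # Question: {...} / Answer: {...}. Using the information above, provide a clear
    # and concise answer. User question: {query} Answer:"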
    prompt = f"""நீ ஒரு அறிவாளியான தமிழ் உதவியாளர்.
    தகவல்கள்:
    கேள்வி: {top_qa['question']}
    பதில்: {top_qa['answer']}
    மேலே உள்ள தகவல்களைப் பயன்படுத்தி, தெளிவான மற்றும் சுருக்கமான பதிலை வழங்கவும்.
    உயர்கட்ட கேள்வி: {query}
    பதில்:"""

    print("[DEBUG] Prompt Sent to LLaMA:\n", prompt)

    headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
    payload = {
        "inputs": prompt,
        "parameters": {
            "temperature": 0.7,
            "max_new_tokens": 150,
            "return_full_text": False
        },
    }
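    # return_full_text=False asks the API to return only the generated continuation,
    # not the prompt echoed back in front of it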

    # Retry the request until the model returns a prediction or max_wait_seconds elapses.
    # The POST must be re-sent on every retry — re-parsing a single "model loading"
    # response would loop forever.
    start_time = time.time()
    max_wait_seconds = 180
    while True:
        try:
            response = requests.post(LLAMA_API_URL, headers=headers, json=payload, timeout=300)
            result = response.json()
            print("[DEBUG] Raw response from LLaMA:", result)

            if isinstance(result, list) and "generated_text" in result[0]:
                return result[0]["generated_text"]
            elif "error" in result and "loading" in result["error"].lower():
                if time.time() - start_time > max_wait_seconds:
                    return f"Error: Timeout while waiting for model prediction after {max_wait_seconds} seconds."
                print("⏳ Model is loading, waiting 10 seconds...")
                time.sleep(10)
            else:
                # Tamil: "Sorry, I could not answer this question."
                return "மன்னிக்கவும், நான் இந்த கேள்விக்கு பதில் தர முடியவில்லை."

        except Exception as e:
            print("[ERROR] Exception during LLaMA call:", str(e))
            if time.time() - start_time > max_wait_seconds:
                return f"Error: Timeout while waiting for model prediction after {max_wait_seconds} seconds."
            time.sleep(5)


# Gradio interface function.
# Note: system_message, max_tokens, temperature, and top_p come from the UI controls but
# are not yet wired into get_bot_response, which uses fixed generation parameters.
def chatbot(audio, message, system_message, max_tokens, temperature, top_p):
    if audio is not None:
        sample_rate, audio_data = audio  # gr.Audio(type="numpy") yields a (sample_rate, data) tuple
        sf.write("temp.wav", audio_data, sample_rate)  # save to a temporary WAV for the API call
        try:
            transcript = transcribe_audio("temp.wav")
            message = transcript  # Use transcribed text
        except Exception as e:
            return f"Audio transcription failed: {str(e)}", None

    try:
        response = get_bot_response(message)
        audio_url = get_tts_audio_url(response)
        return response, audio_url
    except Exception as e:
        return f"Error in generating response: {str(e)}", None


# Define Gradio interface
demo = gr.Interface(
    fn=chatbot,
    inputs=[
        gr.Audio(type="numpy", label="Speak to the Bot"),  # numpy input arrives as (sample_rate, data); defaults allow microphone or upload
        gr.Textbox(value="How can I help you?", label="Text Input (optional)"),
        gr.Textbox(value="You are a friendly chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
    outputs=[gr.Textbox(label="Response"), gr.Audio(label="Bot's Voice Response (Tamil)")],
    live=True,  # live mode re-runs the pipeline on every input change, including keystrokes
)

if __name__ == "__main__":
    demo.launch()