"""Mental health chatbot: OpenVINO Mistral-7B with sentiment-aware, streaming replies via Gradio."""

import gc
import os
from concurrent.futures import ThreadPoolExecutor
from queue import Empty, Queue
from threading import Event, Lock

import cpuinfo
import gradio as gr
import openvino_genai
from huggingface_hub import snapshot_download
from transformers import pipeline
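# Assumed pip dependencies (unpinned in the original): gradio, openvino-genai,
# huggingface_hub, transformers (plus torch, which the sentiment pipeline needs),
# and py-cpuinfo.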


class ChatbotSystem:
    def __init__(self):
        # Serializes generate() calls on the shared pipeline across requests.
        self.pipe_lock = Lock()
        self.mistral_pipe = None
        # Single worker: generation runs off the request thread so the token
        # queue can be drained and streamed to the UI as results arrive.
        self.generation_executor = ThreadPoolExecutor(max_workers=1)

        self.initialize_model()

        # Default sentiment-analysis pipeline (DistilBERT fine-tuned on SST-2;
        # it emits only POSITIVE/NEGATIVE labels).
        self.sentiment_analyzer = pipeline("sentiment-analysis")

    def initialize_model(self):
        """Download and initialize the Mistral model."""
        if not os.path.exists("mistral-ov"):
            snapshot_download(
                repo_id="OpenVINO/mistral-7b-instruct-v0.1-int8-ov",
                local_dir="mistral-ov",
            )

        # cpuinfo reports granular flags such as "avx512f", so match AVX-512 by prefix.
        cpu_features = cpuinfo.get_cpu_info()["flags"]
        config_options = {}
        if any(flag.startswith("avx512") for flag in cpu_features):
            # ENFORCE_BF16 was the pre-API-2.0 spelling of this hint; current
            # OpenVINO expects INFERENCE_PRECISION_HINT instead.
            config_options["INFERENCE_PRECISION_HINT"] = "bf16"
        elif "avx2" in cpu_features:
            config_options["INFERENCE_PRECISION_HINT"] = "f32"

        self.mistral_pipe = openvino_genai.LLMPipeline(
            "mistral-ov",
            device="CPU",
            config={"PERFORMANCE_HINT": "THROUGHPUT", **config_options},
        )

    def analyze_sentiment(self, text: str) -> str:
        """Map the detected sentiment to a supportive opening message."""
        # Truncate so the input stays within the classifier's context window.
        result = self.sentiment_analyzer(text[:512])[0]
        label, score = result["label"], result["score"]

        if label == "NEGATIVE" and score > 0.7:
            return "It sounds like you might be going through something difficult. Remember, you are not alone. 💙"
        elif label == "POSITIVE":
            return "I’m glad to hear that! Keep up the positive energy 🌟"
        else:
            return "I understand. Please feel free to share more about how you’re feeling."

    def generate_text_stream(self, prompt: str, max_tokens: int):
        """Generate text, yielding the accumulated response as it streams."""
        response_queue = Queue()
        completion_event = Event()
        error = [None]

        config = openvino_genai.GenerationConfig(
            max_new_tokens=max_tokens,
            do_sample=True,  # temperature/top_p are ignored unless sampling is enabled
            temperature=0.7,
            top_p=0.9,
        )

        def callback(subword):
            # The streamer receives each newly decoded subword; hand it to the
            # consuming thread via the queue.
            response_queue.put(subword)
            return openvino_genai.StreamingStatus.RUNNING

        def generate():
            try:
                with self.pipe_lock:
                    self.mistral_pipe.generate(prompt, config, callback)
            except Exception as e:
                error[0] = str(e)
            finally:
                completion_event.set()

        self.generation_executor.submit(generate)

        accumulated = []
        while not completion_event.is_set() or not response_queue.empty():
            if error[0]:
                yield f"❌ Error: {error[0]}"
                return
            try:
                subword = response_queue.get(timeout=0.1)
                accumulated.append(subword)
                yield "".join(accumulated)
            except Empty:
                continue

        # Re-check after the loop: generation may have failed before producing any tokens.
        if error[0]:
            yield f"❌ Error: {error[0]}"
            return

        yield "".join(accumulated)
        # Release generation temporaries once the stream has completed.
        gc.collect()
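
# Minimal smoke test, assuming the model files already exist in ./mistral-ov
# (hypothetical usage, not part of the Gradio app below):
#
#     for partial in ChatbotSystem().generate_text_stream("Hello", max_tokens=32):
#         print(partial)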


chatbot_system = ChatbotSystem()

with gr.Blocks(title="Mental Health Chatbot (Mistral-OV)") as demo:
    gr.Markdown("# 💬 Mental Health Chatbot (Mistral-7B-OV + Sentiment Analysis)")
    chatbot = gr.Chatbot(label="Conversation", height=500)
    user_input = gr.Textbox(placeholder="Type your message...", label="Your Message")
    max_tokens = gr.Slider(minimum=50, maximum=1024, value=256, step=50, label="Max Tokens")
    send_btn = gr.Button("Send", variant="primary")

    def respond(message, history, max_tokens):
        if not message.strip():
            # Generator handlers must yield; a bare `return value` is swallowed by Gradio.
            yield history, ""
            return

        # Surface the supportive opener immediately, then stream the model's reply beneath it.
        supportive_message = chatbot_system.analyze_sentiment(message)
        history = history + [[message, supportive_message]]
        yield history, ""

        for chunk in chatbot_system.generate_text_stream(message, max_tokens):
            history[-1][1] = f"{supportive_message}\n\n{chunk}"
            yield history, ""

    send_btn.click(
        fn=respond,
        inputs=[user_input, chatbot, max_tokens],
        outputs=[chatbot, user_input],
    )

    user_input.submit(
        fn=respond,
        inputs=[user_input, chatbot, max_tokens],
        outputs=[chatbot, user_input],
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)