# app.py
import os
from pathlib import Path

import torch
from threading import Event, Thread
from typing import List, Tuple

# Importing necessary packages
from transformers import AutoConfig, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from langchain_community.tools import DuckDuckGoSearchRun
from optimum.intel.openvino import OVModelForCausalLM
import openvino as ov
import openvino.properties as props
import openvino.properties.hint as hints
import openvino.properties.streams as streams

from gradio_helper import make_demo  # UI logic import
from llm_config import SUPPORTED_LLM_MODELS

# Model configuration setup
max_new_tokens = 256
model_language_value = "English"
model_id_value = "qwen2.5-0.5b-instruct"
prepare_int4_model_value = True
enable_awq_value = False
device_value = "CPU"
model_to_run_value = "INT4"

pt_model_id = SUPPORTED_LLM_MODELS[model_language_value][model_id_value]["model_id"]
pt_model_name = model_id_value.split("-")[0]
int4_model_dir = Path(model_id_value) / "INT4_compressed_weights"
int4_weights = int4_model_dir / "openvino_model.bin"

model_configuration = SUPPORTED_LLM_MODELS[model_language_value][model_id_value]
model_name = model_configuration["model_id"]
start_message = model_configuration["start_message"]
history_template = model_configuration.get("history_template")
has_chat_template = model_configuration.get("has_chat_template", history_template is None)
current_message_template = model_configuration.get("current_message_template")
stop_tokens = model_configuration.get("stop_tokens")
tokenizer_kwargs = model_configuration.get("tokenizer_kwargs", {})
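# Optional sanity check (not part of the original script): fail early with a clear
# error if the INT4 weights have not been exported yet. It relies only on the
# int4_weights path defined above; the export itself (e.g. the weight-compression
# step of the optimum-intel workflow) is assumed to have been run beforehand.
if not int4_weights.exists():
    raise FileNotFoundError(
        f"INT4 weights not found at {int4_weights}. Export the compressed model before launching the app."
    )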
# Model loading
core = ov.Core()

ov_config = {
    hints.performance_mode(): hints.PerformanceMode.LATENCY,
    streams.num(): "1",
    props.cache_dir(): "",
}

tok = AutoTokenizer.from_pretrained(int4_model_dir, trust_remote_code=True)
ov_model = OVModelForCausalLM.from_pretrained(
    int4_model_dir,
    device=device_value,
    ov_config=ov_config,
    config=AutoConfig.from_pretrained(int4_model_dir, trust_remote_code=True),
    trust_remote_code=True,
)


# Define stopping criteria for specific token sequences
class StopOnTokens(StoppingCriteria):
    def __init__(self, token_ids):
        self.token_ids = token_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        return any(input_ids[0][-1] == stop_id for stop_id in self.token_ids)


if stop_tokens is not None:
    if isinstance(stop_tokens[0], str):
        stop_tokens = tok.convert_tokens_to_ids(stop_tokens)
    stop_tokens = [StopOnTokens(stop_tokens)]


# Helper function for partial text update
def default_partial_text_processor(partial_text: str, new_text: str) -> str:
    return partial_text + new_text


text_processor = model_configuration.get("partial_text_processor", default_partial_text_processor)


# Convert conversation history to tokens based on the model template
def convert_history_to_token(history: List[Tuple[str, str]]):
    if pt_model_name == "baichuan2":
        system_tokens = tok.encode(start_message)
        history_tokens = []
        for old_query, response in history[:-1]:
            round_tokens = [195] + tok.encode(old_query) + [196] + tok.encode(response)
            history_tokens = round_tokens + history_tokens
        input_tokens = system_tokens + history_tokens + [195] + tok.encode(history[-1][0]) + [196]
        input_token = torch.LongTensor([input_tokens])
    elif history_template is None or has_chat_template:
        messages = [{"role": "system", "content": start_message}]
        for idx, (user_msg, model_msg) in enumerate(history):
            if idx == len(history) - 1 and not model_msg:
                messages.append({"role": "user", "content": user_msg})
                break
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if model_msg:
                messages.append({"role": "assistant", "content": model_msg})
        input_token = tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_tensors="pt")
    else:
        text = start_message + "".join(
            [history_template.format(num=round, user=item[0], assistant=item[1]) for round, item in enumerate(history[:-1])]
        )
        text += current_message_template.format(num=len(history) + 1, user=history[-1][0], assistant=history[-1][1])
        input_token = tok(text, return_tensors="pt", **tokenizer_kwargs).input_ids
    return input_token


# Initialize search tool
search = DuckDuckGoSearchRun()


# Determine if a search is needed based on the query
def should_use_search(query: str) -> bool:
    search_keywords = [
        "latest", "news", "update", "which", "who", "what", "when", "why", "how",
        "recent", "current", "announcement", "bulletin", "report", "brief",
        "insight", "disclosure", "release", "memo", "headline", "ongoing",
        "fresh", "upcoming", "immediate", "recently", "new", "now", "in-progress",
        "inquiry", "query", "ask", "investigate", "explore", "seek", "clarify",
        "confirm", "discover", "learn", "describe", "define", "illustrate",
        "outline", "interpret", "expound", "detail", "summarize", "elucidate",
        "break down", "outcome", "effect", "consequence", "finding",
        "achievement", "conclusion", "product", "performance", "resolution",
    ]
    return any(keyword in query.lower() for keyword in search_keywords)


# Construct the prompt with optional search context
def construct_model_prompt(user_query: str, search_context: str, history: List[Tuple[str, str]]) -> str:
    instructions = (
        "Based on the information provided below, deliver an accurate, concise, and easily understandable answer. "
        "If relevant information is missing, draw on your general knowledge and mention the absence of specific details."
    )
    prompt = f"{instructions}\n\n{search_context if search_context else ''}\n\n{user_query} ?\n\n"
    return prompt


# Fetch search results for a query
def fetch_search_results(query: str) -> str:
    search_results = search.invoke(query)
    print("Search results:", search_results)  # Optional: Debugging output
    return f"Relevant and recent information:\n{search_results}"


# Main chatbot function
def bot(history, temperature, top_p, top_k, repetition_penalty, conversation_id):
    user_query = history[-1][0]
    search_context = fetch_search_results(user_query) if should_use_search(user_query) else ""
    prompt = construct_model_prompt(user_query, search_context, history)
    input_ids = (
        tok(prompt, return_tensors="pt", truncation=True, max_length=2500).input_ids
        if search_context
        else convert_history_to_token(history)
    )

    # Limit input length to avoid exceeding the token limit
    if input_ids.shape[1] > 2000:
        history = [history[-1]]

    # Configure response streaming
    streamer = TextIteratorStreamer(tok, timeout=4600.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids,
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "do_sample": temperature > 0.0,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
        "streamer": streamer,
        "stopping_criteria": StoppingCriteriaList(stop_tokens) if stop_tokens is not None else None,
    }

    # Signal completion
    stream_complete = Event()

    def generate_and_signal_complete():
        try:
            ov_model.generate(**generate_kwargs)
        except RuntimeError as e:
            # Check if the error message indicates the request was canceled
            if "Infer Request was canceled" in str(e):
                print("Generation request was canceled.")
            else:
                # If it's a different RuntimeError, re-raise it
                raise
        finally:
            # Signal completion of the stream
            stream_complete.set()

    t1 = Thread(target=generate_and_signal_complete)
    t1.start()

    partial_text = ""
    for new_text in streamer:
        partial_text = text_processor(partial_text, new_text)
        history[-1] = (user_query, partial_text)
        yield history


def request_cancel():
    ov_model.request.cancel()


# Gradio setup and launch
demo = make_demo(run_fn=bot, title="OpenVINO Search & Reasoning Chatbot", language=model_language_value)

if __name__ == "__main__":
    demo.launch(debug=True, share=True, server_name="0.0.0.0", server_port=7860)
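# Usage note (not part of the original source): with the INT4 weights exported to
# ./qwen2.5-0.5b-instruct/INT4_compressed_weights (see int4_model_dir above), start
# the app with `python app.py` and open http://localhost:7860; because share=True
# is passed to launch(), Gradio also prints a temporary public link in the console.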