Spaces:
Running
Running
# app.py | |
import os | |
from pathlib import Path | |
import torch | |
from threading import Event, Thread | |
from typing import List, Tuple | |
# Importing necessary packages | |
from transformers import AutoConfig, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer | |
from optimum.intel.openvino import OVModelForCausalLM | |
import openvino as ov | |
import openvino.properties as props | |
import openvino.properties.hint as hints | |
import openvino.properties.streams as streams | |
from gradio_helper import make_demo # UI logic import | |
from llm_config import SUPPORTED_LLM_MODELS | |
# Model configuration setup
max_new_tokens = 256
model_language_value = "English"
model_id_value = "qwen2.5-0.5b-instruct"
prepare_int4_model_value = True
enable_awq_value = False
device_value = "CPU"
model_to_run_value = "INT4"

# Entry for the selected model in llm_config.SUPPORTED_LLM_MODELS; every
# model-specific setting below is read from this one dict.
model_configuration = SUPPORTED_LLM_MODELS[model_language_value][model_id_value]
pt_model_id = model_configuration["model_id"]
model_name = pt_model_id
# Short family name (text before the first "-"), e.g. "qwen2.5"; used to pick
# model-specific prompt formatting.
pt_model_name = model_id_value.partition("-")[0]

# Location of the locally exported INT4 OpenVINO model.
int4_model_dir = Path(model_id_value, "INT4_compressed_weights")
int4_weights = int4_model_dir.joinpath("openvino_model.bin")

# Prompt-format settings; optional keys fall back to defaults.
start_message = model_configuration["start_message"]
history_template = model_configuration.get("history_template")
# Without an explicit history template, default to the tokenizer's chat template.
has_chat_template = model_configuration.get("has_chat_template", history_template is None)
current_message_template = model_configuration.get("current_message_template")
stop_tokens = model_configuration.get("stop_tokens")
tokenizer_kwargs = model_configuration.get("tokenizer_kwargs", {})
# Model loading
core = ov.Core()

# OpenVINO runtime properties: latency-oriented performance hint, a single
# inference stream, and an empty cache_dir value.
ov_config = {
    hints.performance_mode(): hints.PerformanceMode.LATENCY,
    streams.num(): "1",
    props.cache_dir(): "",
}

# The tokenizer and the model are both loaded from the local INT4 export
# directory; trust_remote_code allows model-repo custom code.
tok = AutoTokenizer.from_pretrained(int4_model_dir, trust_remote_code=True)
ov_model = OVModelForCausalLM.from_pretrained(
    int4_model_dir,
    device=device_value,
    ov_config=ov_config,
    config=AutoConfig.from_pretrained(int4_model_dir, trust_remote_code=True),
    trust_remote_code=True,
)
# Stopping criteria for token generation
class StopOnTokens(StoppingCriteria):
    """Stop generation when the newest token of the first sequence is a stop id.

    Only row 0 of ``input_ids`` is inspected.
    """

    def __init__(self, token_ids):
        # Token ids that end generation when produced.
        self.token_ids = token_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        newest_token = input_ids[0][-1]
        for candidate in self.token_ids:
            if newest_token == candidate:
                return True
        return False
# Functions for chatbot logic
def convert_history_to_token(history: List[Tuple[str, str]]):
    """
    Convert the dialogue history into input token ids using the conversation
    format expected by the configured model.

    Params:
      history: dialogue history as (user message, assistant message) pairs;
               the last pair holds the current user message (its assistant
               part may be empty).
    Returns:
      ``torch.LongTensor`` of input ids with a leading batch dimension of 1.
    """
    if pt_model_name == "baichuan2":
        # NOTE(review): 195 / 196 are presumably Baichuan2's user / assistant
        # marker token ids — confirm against the model's generation config.
        user_token, assistant_token = 195, 196
        input_tokens = tok.encode(start_message)
        # Earlier rounds are appended oldest-first so the model sees the
        # conversation in chronological order.
        for old_query, response in history[:-1]:
            input_tokens.append(user_token)
            input_tokens.extend(tok.encode(old_query))
            input_tokens.append(assistant_token)
            input_tokens.extend(tok.encode(response))
        input_tokens.append(user_token)
        input_tokens.extend(tok.encode(history[-1][0]))
        # Trailing assistant marker prompts the model to answer.
        input_tokens.append(assistant_token)
        return torch.LongTensor([input_tokens])

    if history_template is None or has_chat_template:
        messages = [{"role": "system", "content": start_message}]
        for idx, (user_msg, model_msg) in enumerate(history):
            # The pending turn (no reply yet) contributes only the user message.
            if idx == len(history) - 1 and not model_msg:
                messages.append({"role": "user", "content": user_msg})
                break
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if model_msg:
                messages.append({"role": "assistant", "content": model_msg})
        return tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_tensors="pt")

    # Model-specific string templates from llm_config.
    text = start_message + "".join(
        history_template.format(num=round_idx, user=user_msg, assistant=model_msg)
        for round_idx, (user_msg, model_msg) in enumerate(history[:-1])
    )
    # NOTE(review): earlier rounds are numbered from 0 but the current round is
    # numbered len(history) + 1, so the numbering skips; kept as-is — confirm
    # against the template the model expects.
    text += current_message_template.format(
        num=len(history) + 1,
        user=history[-1][0],
        assistant=history[-1][1],
    )
    return tok(text, return_tensors="pt", **tokenizer_kwargs).input_ids
# Initialize the search tool
# NOTE(review): DuckDuckGoSearchRun is not imported anywhere in this file, so
# this line raises NameError at import time. It presumably comes from
# `langchain_community.tools` — TODO confirm and add the import.
search = DuckDuckGoSearchRun()


def fetch_search_results(query: str) -> str:
    """
    Run a web search for ``query`` and format the results as model context.

    Returns:
      The results prefixed with a short header, or an empty string when the
      search call raises or returns nothing, so callers can treat the result
      as "no search context".
    """
    try:
        search_results = search.invoke(query)
    except Exception as err:  # external service: a failed search must not kill the chat turn
        print("Search failed: ", err)
        return ""
    # Displaying search results for debugging
    print("Search results: ", search_results)
    if not search_results or not str(search_results).strip():
        return ""
    return f"Relevant and recent information:\n{search_results}"
# Function to decide if a search is needed based on the user query
def should_use_search(query: str) -> bool:
    """
    Heuristically decide whether ``query`` needs fresh web-search context.

    Returns True when any trigger keyword appears at the start of a word in
    the query (case-insensitive). Anchoring at word starts avoids false
    positives such as "now" inside "know" or "how" inside "show", while still
    matching inflections such as "news" or "explained".
    """
    import re

    # Simple heuristic, can be extended with more advanced intent analysis.
    search_keywords = (
        "latest", "news", "update", "which", "who", "what", "when", "why", "how",
        "recent", "result", "tell", "explain", "announcement", "bulletin", "report",
        "brief", "insight", "disclosure", "release", "memo", "headline", "current",
        "ongoing", "fresh", "upcoming", "immediate", "recently", "new", "now",
        "in-progress", "inquiry", "query", "ask", "investigate", "explore", "seek",
        "clarify", "confirm", "discover", "learn", "describe", "define", "illustrate",
        "outline", "interpret", "expound", "detail", "summarize", "elucidate",
        "break down", "outcome", "effect", "consequence", "finding", "achievement",
        "conclusion", "product", "performance", "resolution",
    )
    pattern = r"\b(?:" + "|".join(map(re.escape, search_keywords)) + ")"
    return re.search(pattern, query.lower()) is not None
# Generate prompt for model with optional search context
def construct_model_prompt(user_query: str, search_context: str, history: List[Tuple[str, str]]) -> str:
    """
    Build a plain-text prompt from fixed instructions, the optional search
    context, and the user query.

    Sections appear in that order, each followed by a blank line; the search
    context is included only when non-empty, and the query is suffixed with
    " ?". ``history`` is accepted for interface compatibility but is not
    currently used.
    """
    instructions = (
        "If relevant information is provided below, use it to give an accurate and concise answer. If there is no relevant information available, please rely on your general knowledge and indicate that no recent or specific information is available to answer."
    )
    sections = [instructions]
    if search_context:
        # Search results sit between the instructions and the question.
        sections.append(search_context)
    sections.append(f"{user_query} ?")
    return "".join(f"{section}\n\n" for section in sections)
def _default_partial_text_processor(partial_text: str, new_text: str) -> str:
    """Append a newly streamed text chunk to the response accumulated so far."""
    return partial_text + new_text


def _search_input_ids(user_query: str, history, max_tokens: int):
    """Tokenize a search-grounded prompt, keeping it within ``max_tokens``.

    When the prompt is too long, the search results are trimmed rather than
    letting right-side truncation cut off the user's question at the end.
    """
    search_context = fetch_search_results(user_query)
    prompt = construct_model_prompt(user_query, search_context, history)
    overflow = len(tok(prompt).input_ids) - max_tokens
    if overflow > 0 and search_context:
        context_ids = tok.encode(search_context, add_special_tokens=False)
        kept_ids = context_ids[: max(len(context_ids) - overflow, 0)]
        search_context = tok.decode(kept_ids, skip_special_tokens=True)
        prompt = construct_model_prompt(user_query, search_context, history)
    # Hard cap as a final safeguard: decode/encode may not round-trip exactly.
    return tok(prompt, return_tensors="pt", truncation=True, max_length=max_tokens).input_ids


def bot(history, temperature, top_p, top_k, repetition_penalty, conversation_id):
    """
    Main callback function for running chatbot on submit button click.

    Streams the model's reply to the latest user message, grounding it in web
    search results when ``should_use_search`` says so.

    Params:
      history: list of (user message, assistant message) pairs; the last pair
               holds the pending user query.
      temperature: sampling temperature; 0 disables sampling (greedy decoding).
      top_p, top_k, repetition_penalty: forwarded to ``ov_model.generate``.
      conversation_id: not used by this function.
    Yields:
      ``history`` with its last entry replaced by (user query, response so far).
    """
    user_query = history[-1][0]

    # Decide if search is required based on the user query
    if should_use_search(user_query):
        input_ids = _search_input_ids(user_query, history, max_tokens=2500)
    else:
        input_ids = convert_history_to_token(history)
        # Keep the input under 2000 tokens by falling back to the latest turn.
        # Only the model input is shortened; the displayed history is untouched.
        if input_ids.shape[1] > 2000:
            input_ids = convert_history_to_token([history[-1]])

    # Streamer for model response generation
    streamer = TextIteratorStreamer(tok, timeout=4600.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=temperature > 0.0,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        streamer=streamer,
    )
    if stop_tokens:
        # Configured stop tokens may be token strings or ids; StopOnTokens
        # compares against ids, so convert strings first.
        stop_ids = [tok.convert_tokens_to_ids(t) if isinstance(t, str) else t for t in stop_tokens]
        generate_kwargs["stopping_criteria"] = StoppingCriteriaList([StopOnTokens(stop_ids)])

    def run_generation():
        try:
            ov_model.generate(**generate_kwargs)
        except Exception:
            # Release the consumer loop below instead of leaving it blocked
            # until the streamer timeout, then let the error surface.
            streamer.end()
            raise

    Thread(target=run_generation).start()

    # NOTE(review): assumes llm_config entries may provide a custom
    # "partial_text_processor"; plain concatenation is used otherwise.
    text_processor = model_configuration.get("partial_text_processor", _default_partial_text_processor)
    partial_text = ""
    for new_text in streamer:
        partial_text = text_processor(partial_text, new_text)
        # Update the last entry in the original history with the response
        history[-1] = (user_query, partial_text)
        yield history
def request_cancel():
    """Cancel the model's current OpenVINO inference request (passed to the UI as stop_fn)."""
    inference_request = ov_model.request
    inference_request.cancel()
# Gradio setup and launch: bot streams replies, request_cancel backs the stop action.
demo = make_demo(
    run_fn=bot,
    stop_fn=request_cancel,
    title="OpenVINO Search & Reasoning Chatbot",
    language=model_language_value,
)

if __name__ == "__main__":
    demo.launch(debug=True, share=True, server_name="0.0.0.0", server_port=7860)