import os
import threading

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from transformers.generation.utils import DynamicCache

# Compatibility shim: some model code (loaded via trust_remote_code) still calls the
# older DynamicCache.get_max_length, which newer transformers releases renamed to
# get_max_cache_shape. Alias the old name so both code paths keep working.
DynamicCache.get_max_length = DynamicCache.get_max_cache_shape


# Check if llama-cpp-python is available
def check_llamacpp_available():
    try:
        import llama_cpp
        return True
    except ImportError:
        return False


# Global cache for model and tokenizer
MODEL_CACHE = {}


def load_text_model(model_name, quantize=False):
    """
    Load a text model with the appropriate configuration for CPU or GPU.

    Args:
        model_name (str): Hugging Face model ID
        quantize (bool): Whether to use 4-bit quantization (only works with a GPU)

    Returns:
        tuple: (model, tokenizer)
    """
    # Check cache first
    cache_key = f"{model_name}_{quantize}"
    if cache_key in MODEL_CACHE:
        return MODEL_CACHE[cache_key]

    # Check CUDA availability
    cuda_available = torch.cuda.is_available()

    # Only try quantization if CUDA is available
    if quantize and cuda_available:
        try:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
            )
        except Exception as e:
            print(f"Quantization config creation failed: {e}")
            quantization_config = None
            quantize = False
    else:
        quantization_config = None
        quantize = False

    # Try loading the model
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Fix for the attention mask warning: make sure a pad token is defined
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Try with quantization first if requested and available
        if quantize and quantization_config:
            try:
                model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    quantization_config=quantization_config,
                    device_map="auto",
                    trust_remote_code=True,
                )
            except Exception as e:
                print(f"Failed to load with quantization: {e}")
                quantize = False

        # If quantization is not used or failed, try standard loading
        if not quantize:
            if not cuda_available:
                # For CPU, just load without specifying a dtype
                model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    device_map="auto",
                    trust_remote_code=True,
                )
            else:
                # Try progressively wider dtypes on GPU
                for dtype in (torch.float16, torch.float32):
                    try:
                        model = AutoModelForCausalLM.from_pretrained(
                            model_name,
                            torch_dtype=dtype,
                            device_map="auto",
                            trust_remote_code=True,
                        )
                        break
                    except Exception:
                        if dtype == torch.float32:
                            # Last resort: try without specifying a dtype
                            model = AutoModelForCausalLM.from_pretrained(
                                model_name,
                                device_map="auto",
                                trust_remote_code=True,
                            )

        # Cache the loaded model and tokenizer
        MODEL_CACHE[cache_key] = (model, tokenizer)
        return model, tokenizer
    except Exception as e:
        raise RuntimeError(f"Failed to load model {model_name}: {e}")
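

# Usage sketch (illustrative only; not called anywhere in this module). It shows the
# intended call pattern for load_text_model. The model ID below is an assumption made
# for the example, not something this module requires.
def _example_load_text_model():
    model, tokenizer = load_text_model("microsoft/Phi-3-mini-4k-instruct", quantize=False)
    print(f"Loaded {model.config.model_type} with vocab size {tokenizer.vocab_size}")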


def format_prompt(tokenizer, query):
    """
    Format the prompt according to the model's requirements.

    Args:
        tokenizer: The model tokenizer
        query (str): User query

    Returns:
        str: Formatted prompt
    """
    enhanced_query = f"Please answer this question about pharmaceuticals or medical topics.\n\nQuestion: {query}"

    # Use the tokenizer's chat template if available
    if hasattr(tokenizer, "apply_chat_template") and callable(getattr(tokenizer, "apply_chat_template")):
        messages = [{"role": "user", "content": enhanced_query}]
        try:
            return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        except Exception:
            # Fall back to simple formatting if the chat template fails
            pass

    # Simple formatting fallback
    return f"User: {enhanced_query}\nAssistant:"


def generate_text_with_transformers(model, tokenizer, query, max_tokens=512, temperature=0.7,
                                    top_p=0.9, repetition_penalty=1.1, cancel_event=None,
                                    progress_callback=None):
    """
    Generate text using the transformers model.

    Args:
        model: The language model
        tokenizer: The tokenizer
        query (str): User query
        max_tokens (int): Maximum tokens to generate
        temperature (float): Temperature for sampling
        top_p (float): Top-p sampling parameter
        repetition_penalty (float): Penalty for repetition
        cancel_event (threading.Event): Event to signal cancellation
        progress_callback (callable): Function to report progress

    Returns:
        str: Generated response
    """
    # Format the prompt
    prompt = format_prompt(tokenizer, query)

    # Prepare inputs
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Update progress
    if progress_callback:
        progress_callback(0.2, "Starting generation...")

    try:
        from transformers import TextIteratorStreamer

        # Set up a streamer for token-by-token generation
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        # Prepare generation parameters
        generation_kwargs = {
            "input_ids": inputs.input_ids,
            "attention_mask": inputs.attention_mask,  # Explicitly provide the attention mask
            "max_new_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
            "do_sample": temperature > 0.1,
            "streamer": streamer,
        }

        # Start generation in a separate thread
        generation_thread = threading.Thread(
            target=model.generate,
            kwargs=generation_kwargs,
        )
        generation_thread.start()

        # Collect tokens as they are generated
        response_text = ""
        for i, new_text in enumerate(streamer):
            if cancel_event and cancel_event.is_set():
                break
            response_text += new_text

            # Update progress periodically
            if progress_callback and i % 5 == 0:
                progress_callback(0.3 + min(0.6, len(response_text) / 500), "Generating response...")

        return response_text
    except Exception as e:
        print(f"Streaming generation failed, falling back to standard generation: {e}")

        # Fall back to standard (non-streaming) generation
        try:
            outputs = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                do_sample=temperature > 0.1,
            )

            # Decode and strip the prompt tokens
            prompt_length = inputs.input_ids.shape[1]
            response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
            return response
        except Exception as e2:
            return f"Error in text generation: {e2}"


# Global llama.cpp model cache
LLAMA_MODEL = None

# llama-cpp-python is optional; keep the module importable when it is missing
try:
    from llama_cpp import Llama
except ImportError:
    Llama = None

from huggingface_hub import hf_hub_download


def load_llamacpp_model(model_path=None):
    """Load the llama.cpp model, downloading it from the HF Hub if needed."""
    global LLAMA_MODEL

    # Return the cached model if available
    if LLAMA_MODEL is not None:
        return LLAMA_MODEL

    if Llama is None:
        raise RuntimeError("llama-cpp-python is not installed")

    # 1) Look for an existing file on disk
    if model_path is None:
        possible_paths = [
            "models/Phi-3-mini-4k-instruct.Q4_K_M.gguf",
            os.path.join(os.path.dirname(os.path.dirname(__file__)), "models/Phi-3-mini-4k-instruct.Q4_K_M.gguf"),
            "/models/Phi-3-mini-4k-instruct.Q4_K_M.gguf",
            os.path.expanduser("~/.cache/huggingface/hub/models/Phi-3-mini-4k-instruct.Q4_K_M.gguf"),
        ]
        for p in possible_paths:
            if os.path.exists(p):
                model_path = p
                break

    # 2) If still not found, download into models/
    if model_path is None:
        print("→ GGUF not found locally, downloading from HF Hub…")
        model_path = hf_hub_download(
            repo_id="MohammedSameerSyed/phi3-gguf",  # HF repo hosting the .gguf file
            filename="Phi-3-mini-4k-instruct.Q4_K_M.gguf",
            cache_dir="models",  # creates models/ if needed
        )

    # 3) Finally, load with llama.cpp
    try:
        LLAMA_MODEL = Llama(
            model_path=model_path,
            n_ctx=4096,  # full 4K context
            n_batch=512,
            n_threads=4,
            n_gpu_layers=0,
        )
        return LLAMA_MODEL
    except Exception as e:
        raise RuntimeError(f"Failed to load llama.cpp model: {e}")


def generate_text_with_llamacpp(query, max_tokens=512, temperature=0.7, top_p=0.9,
                                stop=None, cancel_event=None, progress_callback=None, model_path=None):
    """
    Generate text using llama.cpp.

    Args:
        query (str): User query
        max_tokens (int): Maximum tokens to generate
        temperature (float): Temperature for sampling
        top_p (float): Top-p sampling parameter
        stop (list): List of stop sequences
        cancel_event (threading.Event): Event to signal cancellation
        progress_callback (callable): Function to report progress
        model_path (str): Path to the GGUF model file (optional)

    Returns:
        str: Generated response
    """
    if progress_callback:
        progress_callback(0.1, "Loading llama.cpp model...")

    # Load the model
    try:
        model = load_llamacpp_model(model_path)
    except Exception as e:
        raise RuntimeError(f"Failed to load llama.cpp model: {e}")

    if progress_callback:
        progress_callback(0.3, "Starting generation...")

    # Format the prompt
    prompt = f"You are a helpful pharmaceutical assistant. Please answer this question about medications or medical topics.\n\nQuestion: {query}\n\nAnswer:"

    # Define stop sequences if not provided
    if stop is None:
        stop = ["Question:", "\n\n"]

    try:
        # Newer llama-cpp-python versions expose create_completion with streaming support
        if hasattr(model, "create_completion"):
            response_text = ""

            # Generate the completion with streaming
            stream = model.create_completion(
                prompt,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=40,
                stop=stop,
                stream=True,
            )

            # Process the stream chunk by chunk
            for i, chunk in enumerate(stream):
                if cancel_event and cancel_event.is_set():
                    break
                text_chunk = chunk["choices"][0]["text"]
                response_text += text_chunk

                # Update progress periodically
                if progress_callback and i % 5 == 0:
                    progress_callback(0.4 + min(0.5, len(response_text) / 500), "Generating response...")

            return response_text.strip()
        else:
            # Fall back to the older __call__ interface
            result = model(
                prompt,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=40,
                stop=stop,
                echo=False,
            )

            if progress_callback:
                progress_callback(0.9, "Finalizing...")

            return result["choices"][0]["text"].strip()
    except Exception as e:
        raise RuntimeError(f"Error in llama.cpp generation: {e}")