import logging

import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    CLIPProcessor,
    CLIPModel,
)
from PIL import Image
import spaces
import numpy

# Setup logging
logging.basicConfig(level=logging.INFO)

# Hub id of the CLIP checkpoint used for both the processor and the model.
CLIP_MODEL_ID = "openai/clip-vit-base-patch32"


class LLaVAPhiModel:
    """Chat wrapper around the LLaVA-Phi causal LM plus a CLIP image encoder.

    The tokenizer and CLIP processor are loaded in ``__init__``; the main LM
    and the CLIP model are loaded lazily by ``ensure_models_loaded`` inside a
    ``spaces.GPU`` context.

    NOTE(review): ``self.history`` lives on the single instance created by
    ``create_demo``, so it is shared by every browser session of the demo.
    """

    # Placeholder inserted into the prompt when an image was encoded.
    # NOTE(review): in the original both branches of the conditional were ''
    # (the literal appears to have been stripped); "<image>" is the usual
    # LLaVA convention -- confirm against the model's training prompt format.
    IMAGE_TOKEN = "<image>"
    # Number of previous (user, bot) turns prepended to each prompt.
    HISTORY_TURNS = 5
    # Maximum prompt length (in tokens) handed to the tokenizer.
    MAX_INPUT_TOKENS = 1024

    def __init__(self, model_id="sagar007/Lava_phi"):
        self.device = "cuda"
        self.model_id = model_id
        logging.info("Initializing LLaVA-Phi model...")

        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Tokenizers truncate on the right by default, which would cut off the
        # *current* message and the trailing "gpt:" cue once the history grows
        # past MAX_INPUT_TOKENS. Drop the oldest context instead.
        self.tokenizer.truncation_side = "left"

        try:
            self.processor = CLIPProcessor.from_pretrained(CLIP_MODEL_ID)
            logging.info("Successfully loaded CLIP processor")
        except Exception as e:
            logging.error(f"Failed to load CLIP processor: {str(e)}")
            self.processor = None

        self.history = []
        self.model = None
        self.clip = None

        # Sampling parameters used by generate_response; adjustable at runtime
        # through update_generation_params().
        self.temperature = 0.3
        self.top_p = 0.92
        self.top_k = 50
        self.repetition_penalty = 1.2

    @spaces.GPU
    def ensure_models_loaded(self):
        """Load the main LM and the CLIP model if they are not loaded yet.

        Raises:
            Whatever ``AutoModelForCausalLM.from_pretrained`` raises; a CLIP
            load failure is only logged and leaves ``self.clip`` as None.
        """
        if self.model is None:
            # 8-bit weights. The original also passed bnb_8bit_compute_dtype /
            # bnb_8bit_use_double_quant; BitsAndBytesConfig has no such options
            # (only bnb_4bit_* variants exist), so they were ignored.
            quantization_config = BitsAndBytesConfig(load_in_8bit=True)
            try:
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_id,
                    quantization_config=quantization_config,
                    device_map="auto",
                    torch_dtype=torch.bfloat16,
                    trust_remote_code=True,
                )
                self.model.config.pad_token_id = self.tokenizer.eos_token_id
                logging.info("Successfully loaded main model")
            except Exception as e:
                logging.error(f"Failed to load main model: {str(e)}")
                raise

        if self.clip is None:
            try:
                self.clip = CLIPModel.from_pretrained(CLIP_MODEL_ID).to(self.device)
                logging.info("Successfully loaded CLIP model")
            except Exception as e:
                logging.error(f"Failed to load CLIP model: {str(e)}")
                self.clip = None

    @spaces.GPU
    def process_image(self, image):
        """Encode *image* with CLIP.

        Args:
            image: a PIL image, a file path (str), or a numpy array.

        Returns:
            The CLIP image features, or None if CLIP is unavailable or any step
            fails (errors are logged, never raised).
        """
        try:
            self.ensure_models_loaded()

            if self.clip is None or self.processor is None:
                logging.warning("CLIP model or processor not available")
                return None

            # Normalize the input to an RGB PIL image.
            if isinstance(image, str):
                image = Image.open(image)
            elif isinstance(image, numpy.ndarray):
                image = Image.fromarray(image)
            if image.mode != 'RGB':
                image = image.convert('RGB')

            with torch.no_grad():
                try:
                    image_inputs = self.processor(images=image, return_tensors="pt")
                    image_features = self.clip.get_image_features(
                        pixel_values=image_inputs.pixel_values.to(self.device)
                    )
                    logging.info("Successfully processed image through CLIP")
                    return image_features
                except Exception as e:
                    logging.error(f"Error during image processing: {str(e)}")
                    return None
        except Exception as e:
            logging.error(f"Error in process_image: {str(e)}")
            return None

    @spaces.GPU(duration=120)
    def generate_response(self, message, image=None):
        """Generate a reply to *message*, optionally conditioned on *image*.

        Appends ``(message, response)`` to ``self.history`` on success.

        Returns:
            The cleaned model reply, or ``"Error: ..."`` if anything raised
            (the exception is logged with its traceback, not propagated).
        """
        try:
            self.ensure_models_loaded()

            if image is not None:
                image_features = self.process_image(image)
                has_image = image_features is not None

                if not has_image:
                    message = "Note: Image processing is not available - continuing with text only.\n" + message

                image_marker = self.IMAGE_TOKEN if has_image else ""
                prompt = f"human: {image_marker}\n{message}\ngpt:"
                extra_inputs = {"image_features": image_features} if has_image else {}
                outputs = self._generate(
                    prompt, extra_inputs, max_new_tokens=256, no_repeat_ngram_size=3
                )
            else:
                prompt = f"human: {message}\ngpt:"
                outputs = self._generate(
                    prompt, {}, max_new_tokens=200, no_repeat_ngram_size=4
                )

            response = self._clean_response(
                self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            )
            self.history.append((message, response))
            return response

        except Exception as e:
            logging.exception(f"Error generating response: {str(e)}")
            return f"Error: {str(e)}"

    def _generate(self, prompt, extra_inputs, max_new_tokens, no_repeat_ngram_size):
        """Prepend recent history to *prompt*, tokenize it, and sample a continuation.

        *extra_inputs* is merged into the model inputs (used for image features).
        """
        context = "".join(
            f"human: {user_turn}\ngpt: {bot_turn}\n"
            for user_turn, bot_turn in self.history[-self.HISTORY_TURNS:]
        )
        inputs = self.tokenizer(
            context + prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.MAX_INPUT_TOKENS,
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        # NOTE(review): a stock causal-LM generate() rejects unknown model
        # kwargs such as image_features; this presumably relies on the model's
        # remote code accepting it -- verify.
        inputs.update(extra_inputs)

        with torch.no_grad():
            return self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                # min_length counts prompt tokens as well (unlike min_new_tokens).
                min_length=20,
                temperature=self.temperature,
                do_sample=True,
                top_p=self.top_p,
                top_k=self.top_k,
                repetition_penalty=self.repetition_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                use_cache=True,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

    def _clean_response(self, response):
        """Keep only the final bot turn of the decoded text.

        Drops the echoed prompt, any invented follow-up user turn, and image
        tokens.
        """
        # The decoded output includes the prompt, so keep text after the last cue.
        if "gpt:" in response:
            response = response.split("gpt:")[-1].strip()
        # Cut off any follow-up user turn the model invented.
        if "human:" in response:
            response = response.split("human:")[0].strip()
        if self.IMAGE_TOKEN in response:
            response = response.replace(self.IMAGE_TOKEN, "")
        return response.strip()

    def clear_history(self):
        """Forget all previous turns."""
        self.history = []
        return None

    def update_generation_params(self, temperature=0.3, top_p=0.92, top_k=50, repetition_penalty=1.2):
        """Set the sampling parameters used by subsequent generate_response calls.

        Returns:
            A human-readable confirmation string.
        """
        self.temperature = float(temperature)
        self.top_p = float(top_p)
        # transformers requires top_k to be a positive int; UI sliders may
        # deliver floats.
        self.top_k = int(top_k)
        self.repetition_penalty = float(repetition_penalty)
        return (
            f"Generation parameters updated: temp={self.temperature}, "
            f"top_p={self.top_p}, top_k={self.top_k}, "
            f"rep_penalty={self.repetition_penalty}"
        )


def create_demo():
    """Build the Gradio Blocks UI around a single LLaVAPhiModel instance.

    Raises:
        Any exception raised while constructing the model or the UI (logged
        first).
    """
    try:
        model = LLaVAPhiModel()

        with gr.Blocks(css="footer {visibility: hidden}") as demo:
            gr.Markdown(
                """
            # LLaVA-Phi Demo (Optimized for Accuracy)
            Chat with a vision-language model that can understand both text and images.
            """
            )

            chatbot = gr.Chatbot(height=400)
            with gr.Row():
                # Integer scales in the original 0.7 : 0.15 : 0.15 ratio; Gradio
                # expects integer `scale` values.
                with gr.Column(scale=14):
                    msg = gr.Textbox(
                        show_label=False,
                        placeholder="Enter text and/or upload an image",
                        container=False,
                    )
                with gr.Column(scale=3, min_width=0):
                    clear = gr.Button("Clear")
                with gr.Column(scale=3, min_width=0):
                    submit = gr.Button("Submit", variant="primary")

            image = gr.Image(type="pil", label="Upload Image (Optional)")

            # Generation parameter controls
            with gr.Accordion("Advanced Settings", open=False):
                gr.Markdown("Adjust these parameters to control hallucination tendency")
                temp_slider = gr.Slider(0.1, 1.0, value=0.3, step=0.1, label="Temperature (lower = more factual)")
                top_p_slider = gr.Slider(0.5, 1.0, value=0.92, step=0.01, label="Top-p (nucleus sampling)")
                top_k_slider = gr.Slider(10, 100, value=50, step=5, label="Top-k")
                rep_penalty_slider = gr.Slider(1.0, 2.0, value=1.2, step=0.1, label="Repetition Penalty")
                update_params = gr.Button("Update Parameters")
                # Shows the confirmation string from update_generation_params.
                params_status = gr.Markdown()

            def respond(message, chat_history, image):
                chat_history = chat_history or []
                if not message and image is None:
                    # Two outputs (msg, chatbot) are wired, so always return two values.
                    return "", chat_history

                response = model.generate_response(message, image)
                chat_history.append((message, response))
                return "", chat_history

            def clear_chat():
                model.clear_history()
                return None, None

            def update_params_fn(temp, top_p, top_k, rep_penalty):
                return model.update_generation_params(temp, top_p, top_k, rep_penalty)

            submit.click(
                respond,
                [msg, chatbot, image],
                [msg, chatbot],
            )

            clear.click(
                clear_chat,
                None,
                [chatbot, image],
            )

            msg.submit(
                respond,
                [msg, chatbot, image],
                [msg, chatbot],
            )

            update_params.click(
                update_params_fn,
                [temp_slider, top_p_slider, top_k_slider, rep_penalty_slider],
                params_status,
            )

        return demo
    except Exception as e:
        logging.error(f"Error creating demo: {str(e)}")
        raise


if __name__ == "__main__":
    demo = create_demo()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
    )