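"""Gradio chat interface for the FrameRateTech/DamageScan-llama-8b-instruct-merged model.

The app loads the merged checkpoint with device-appropriate precision, formats the chat
history into a plain-text prompt, generates replies through a text-generation pipeline
(or a direct ``model.generate`` fallback), and serves the result in a Gradio Blocks UI
with adjustable sampling controls.
"""
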
import gc
import logging
import os
import time
import traceback

import gradio as gr
import torch
import transformers
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
)

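# Log timestamped INFO-level records to the console.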
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()],
)
logger = logging.getLogger("DamageScan-App")

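# Hugging Face Hub checkpoint to serve, and the default system prompt pre-filled in the UI.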
MODEL_ID = "FrameRateTech/DamageScan-llama-8b-instruct-merged"
DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful, and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question is not clear or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""


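# Called after each generated reply and when the chat is cleared, to release cached GPU memory.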
def optimize_memory():
    """Optimize memory usage by clearing caches and forcing garbage collection."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
    logger.info("Memory optimized: caches cleared and garbage collected")


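# Model loading prefers CUDA (float16, device_map="auto"), then Apple MPS (float16), then CPU (float32).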
def load_model_and_tokenizer():
    """Load the model and tokenizer with comprehensive error handling and logging."""
    logger.info(f"Loading model: {MODEL_ID}")
    logger.info(f"Transformers version: {transformers.__version__}")
    logger.info(f"PyTorch version: {torch.__version__}")

    device_info = {
        "cuda_available": torch.cuda.is_available(),
        "device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0,
        "mps_available": hasattr(torch.backends, "mps") and torch.backends.mps.is_available(),
    }
    logger.info(f"Device information: {device_info}")

    try:
        logger.info("Loading base Llama tokenizer for pipeline...")
        tokenizer = AutoTokenizer.from_pretrained(
            "meta-llama/Llama-3.1-8B-Instruct",
            trust_remote_code=True,
        )
        logger.info(f"Base tokenizer loaded: {type(tokenizer).__name__}")
    except Exception as e:
        logger.warning(f"Could not load base tokenizer: {str(e)}")
        logger.warning("Will try to initialize the pipeline without an explicit tokenizer")
        tokenizer = None

    try:
        logger.info("Loading model...")
        model_start = time.time()

        if device_info["cuda_available"]:
            device_map = "auto"
            torch_dtype = torch.float16
            logger.info("Using 'auto' device map for CUDA with float16 precision")
        elif device_info["mps_available"]:
            device_map = {"": "mps"}
            torch_dtype = torch.float16
            logger.info("Using MPS device with float16 precision")
        else:
            device_map = {"": "cpu"}
            torch_dtype = torch.float32
            logger.info("Using CPU with float32 precision")

        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=torch_dtype,
            device_map=device_map,
            trust_remote_code=True,
        )
        model.eval()
        model_load_time = time.time() - model_start
        logger.info(f"Model loaded successfully in {model_load_time:.2f} seconds")

        try:
            model_info = {
                "model_type": model.config.model_type,
                "hidden_size": model.config.hidden_size,
                "vocab_size": model.config.vocab_size,
                "num_hidden_layers": model.config.num_hidden_layers,
            }
            logger.info(f"Model properties: {model_info}")
        except Exception as e:
            logger.warning(f"Could not log all model properties: {str(e)}")

    except Exception as e:
        logger.error(f"Failed to load model: {str(e)}")
        logger.error(traceback.format_exc())
        raise RuntimeError(f"Failed to load model: {str(e)}") from e

    return model, tokenizer


def format_prompt(messages, system_prompt=DEFAULT_SYSTEM_PROMPT):
    """
    Format messages into a simplified plain-text prompt for the model.

    System entries inside ``messages`` are skipped; the ``system_prompt`` argument
    is always used as the SYSTEM block.
    """
    logger.info(f"Formatting prompt with {len(messages)} messages")

    prompt = f"SYSTEM: {system_prompt}\n\n"

    for msg in messages:
        role = msg["role"] if isinstance(msg, dict) else msg[0]
        content = msg["content"] if isinstance(msg, dict) else msg[1]

        if role.lower() == "system":
            # The system prompt is supplied separately, so skip system entries here.
            continue
        elif role.lower() in ("user", "human"):
            prompt += f"USER: {content}\n\n"
        elif role.lower() in ("assistant", "ai"):
            prompt += f"ASSISTANT: {content}\n\n"

    # Cue the model to continue as the assistant.
    prompt += "ASSISTANT: "

    logger.info(f"Formatted prompt (length: {len(prompt)})")
    return prompt


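# Generation goes through a transformers text-generation pipeline when a tokenizer is
# available; otherwise it falls back to the minimal tokenize/decode helpers attached in main().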
def generate_text(model, tokenizer, prompt, temperature=0.7, top_p=0.9, max_new_tokens=256):
    """
    Generate text through a text-generation pipeline when a tokenizer is available,
    or by calling model.generate directly otherwise.
    """
    logger.info(f"Generating text with temp={temperature}, top_p={top_p}, max_tokens={max_new_tokens}")

    try:
        logger.info(f"Input prompt length: {len(prompt)}")

        gen_config = {
            "temperature": temperature,
            "top_p": top_p,
            "do_sample": True,
            "max_new_tokens": max_new_tokens,
            "repetition_penalty": 1.1,
        }
        logger.info(f"Generation config: {gen_config}")

        if tokenizer:
            logger.info("Creating pipeline with explicit tokenizer")
            # The model was already placed via device_map at load time, so the pipeline
            # is given the loaded instances as-is rather than a device argument.
            pipe = transformers.pipeline(
                "text-generation",
                model=model,
                tokenizer=tokenizer,
            )
        else:
            logger.info("No tokenizer available, using direct model.generate")

            generation_start = time.time()

            # These helpers are attached to the model in main() as a last-resort fallback.
            inputs = model.tokenize_using_default(prompt)
            inputs = {k: v.to(model.device) if torch.is_tensor(v) else v for k, v in inputs.items()}

            with torch.no_grad():
                outputs = model.generate(**inputs, **gen_config)

            generated_text = model.decode_using_default(outputs[0])

            generation_time = time.time() - generation_start
            logger.info(f"Direct generation completed in {generation_time:.2f} seconds")

            # Strip the prompt prefix only if the decoded text actually contains it.
            if generated_text.startswith(prompt):
                response = generated_text[len(prompt):].strip()
            else:
                response = generated_text.strip()
            logger.info(f"Generated response length: {len(response)}")

            return response

        generation_start = time.time()
        outputs = pipe(
            prompt,
            return_full_text=True,
            **gen_config,
        )
        generation_time = time.time() - generation_start
        logger.info(f"Pipeline generation completed in {generation_time:.2f} seconds")

        generated_text = outputs[0]["generated_text"]

        # With return_full_text=True the output starts with the prompt, so cut it off.
        response = generated_text[len(prompt):].strip()
        logger.info(f"Generated response length: {len(response)}")

        return response

    except Exception as e:
        logger.error(f"Error in generate_text: {e}")
        logger.error(traceback.format_exc())

        # There is no second generation path to fall back to, so return a friendly error message.
        return (
            "I'm having trouble generating a response right now. "
            "Please try again with different parameters or a different question."
        )


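# The Chatbot component uses type="messages", so the callbacks exchange history as lists of
# {"role": ..., "content": ...} dicts; legacy (role, content) tuples are tolerated defensively.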
def build_gradio_interface(model, tokenizer):
    """Build the Gradio interface and return the Blocks app."""
    logger.info("Building Gradio interface")

    def user_submit(message_history, user_text, temp, top_p, max_tokens, system_message):
        """Handle a user message: build the prompt, generate a reply, and update history."""
        logger.info(f"Received user message: '{user_text[:50]}...' (length: {len(user_text)})")

        if not user_text.strip():
            logger.warning("Empty user message, skipping processing")
            return message_history, ""

        try:
            if message_history is None:
                message_history = []

            # Normalize history into a list of {"role", "content"} dicts.
            formatted_history = []
            for msg in message_history:
                if isinstance(msg, tuple):
                    role = "user" if msg[0] in ("user", "human") else "assistant"
                    formatted_history.append({"role": role, "content": msg[1]})
                elif isinstance(msg, dict):
                    formatted_history.append(msg)

            # Make sure the conversation starts with the system message from the UI.
            if not formatted_history or formatted_history[0]["role"] != "system":
                formatted_history.insert(0, {"role": "system", "content": system_message})

            formatted_history.append({"role": "user", "content": user_text})

            # Pass the UI's system message explicitly; format_prompt skips system
            # entries in the history and only uses its system_prompt argument.
            prompt = format_prompt(formatted_history, system_prompt=system_message)

            assistant_response = generate_text(
                model,
                tokenizer,
                prompt,
                temperature=temp,
                top_p=top_p,
                max_new_tokens=max_tokens,
            )

            formatted_history.append({"role": "assistant", "content": assistant_response})

            # Build the display history without the system message.
            display_history = []
            for msg in formatted_history:
                if msg["role"] == "system":
                    continue
                display_history.append({"role": msg["role"], "content": msg["content"]})

            logger.info(f"Added assistant response (length: {len(assistant_response)})")

            optimize_memory()

            return display_history, ""

        except Exception as e:
            logger.error(f"Error in user_submit: {str(e)}")
            logger.error(traceback.format_exc())

            error_msg = "I encountered an error processing your request. Please try again."

            if message_history is None:
                return [
                    {"role": "user", "content": user_text},
                    {"role": "assistant", "content": error_msg},
                ], ""
            else:
                try:
                    if message_history and isinstance(message_history[0], dict):
                        message_history.append({"role": "user", "content": user_text})
                        message_history.append({"role": "assistant", "content": error_msg})
                    else:
                        # Convert any legacy tuple entries before appending the error exchange.
                        new_history = []
                        for msg in message_history:
                            if isinstance(msg, tuple):
                                role = "user" if msg[0] == "user" else "assistant"
                                new_history.append({"role": role, "content": msg[1]})
                            else:
                                new_history.append(msg)
                        new_history.append({"role": "user", "content": user_text})
                        new_history.append({"role": "assistant", "content": error_msg})
                        message_history = new_history

                    return message_history, ""
                except Exception:
                    # If the history itself is malformed, fall back to a fresh two-message exchange.
                    return [
                        {"role": "user", "content": user_text},
                        {"role": "assistant", "content": error_msg},
                    ], ""

    def clear_chat():
        """Clear the chat history"""
        logger.info("Clearing chat history")
        optimize_memory()
        return [], ""

    with gr.Blocks(css="footer {visibility: hidden}") as demo:
        gr.Markdown("<h1 align='center'>DamageScan 8B Instruct Chatbot</h1>")
        gr.Markdown("<p align='center'>Powered by FrameRateTech/DamageScan-llama-8b-instruct-merged</p>")

        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(
                    label="Chat History",
                    height=600,
                    type="messages",
                    avatar_images=(None, "https://huggingface.co/spaces/FrameRateTech/DamageScan-8b-instruct-chat/resolve/main/avatar.png"),
                )

                with gr.Row():
                    with gr.Column(scale=8):
                        user_input = gr.Textbox(
                            lines=3,
                            label="Your Message",
                            placeholder="Type your message here...",
                            show_copy_button=True,
                        )
                    with gr.Column(scale=1, min_width=50):
                        submit_btn = gr.Button("Send", variant="primary")
                        clear_btn = gr.Button("Clear Chat")

            with gr.Column(scale=1):
                gr.Markdown("### System Prompt")
                system_prompt_input = gr.Textbox(
                    lines=5,
                    label="System Instructions",
                    value=DEFAULT_SYSTEM_PROMPT,
                    show_copy_button=True,
                )

                gr.Markdown("### Generation Settings")
                temperature_slider = gr.Slider(
                    minimum=0.1, maximum=1.5, value=0.7, step=0.1, label="Temperature",
                    info="Higher values make output more random, lower values more deterministic",
                )
                top_p_slider = gr.Slider(
                    minimum=0.5, maximum=1.0, value=0.9, step=0.05, label="Top-p",
                    info="Controls diversity via nucleus sampling",
                )
                max_tokens_slider = gr.Slider(
                    minimum=64, maximum=1024, value=256, step=64, label="Max New Tokens",
                    info="Maximum length of generated response",
                )

                gr.Markdown("### Tips")
                gr.Markdown("""
                * Lower temperature (0.1-0.3) for factual responses
                * Higher temperature (0.7-1.0) for creative tasks
                * Reduce max tokens if responses are too long
                * Clear chat if the model gets confused
                """)

        submit_btn.click(
            user_submit,
            inputs=[chatbot, user_input, temperature_slider, top_p_slider, max_tokens_slider, system_prompt_input],
            outputs=[chatbot, user_input],
        )
        user_input.submit(
            user_submit,
            inputs=[chatbot, user_input, temperature_slider, top_p_slider, max_tokens_slider, system_prompt_input],
            outputs=[chatbot, user_input],
        )
        clear_btn.click(
            clear_chat,
            outputs=[chatbot, user_input],
        )

        gr.Examples(
            examples=[
                ["Can you explain how the Large Hadron Collider works?"],
                ["Write a short story about a robot who learns to paint"],
                ["What are three ways to improve productivity when working from home?"],
                ["Explain quantum computing to me like I'm 10 years old"],
            ],
            inputs=user_input,
            label="Example Prompts",
        )

    return demo


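# main() attaches minimal tokenize/decode stand-ins to the model so that generate_text's
# no-tokenizer branch cannot raise AttributeError; those stand-ins only produce placeholder output.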
def main():
    """Main application entry point"""
    try:
        logger.info("Starting DamageScan 8B Instruct application")
        logger.info(f"Environment: CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')}")

        model, tokenizer = load_model_and_tokenizer()

        if not hasattr(model, "tokenize_using_default"):
            logger.info("Adding default tokenization methods to model")

            def tokenize_using_default(text):
                """Very basic tokenization that just returns a dummy input"""
                logger.info("Using minimal default tokenization")
                return {"input_ids": torch.tensor([[1]]).to(model.device)}

            def decode_using_default(token_ids):
                """Very basic decoding that just returns a placeholder message"""
                logger.info("Using minimal default decoding")
                return "I'm having trouble generating a proper response."

            model.tokenize_using_default = tokenize_using_default
            model.decode_using_default = decode_using_default

        demo = build_gradio_interface(model, tokenizer)

        logger.info("Launching Gradio interface")
        demo.queue().launch(
            share=False,
            debug=False,
            show_error=True,
            favicon_path="https://huggingface.co/spaces/FrameRateTech/DamageScan-8b-instruct-chat/resolve/main/favicon.ico",
        )

    except Exception as e:
        logger.error(f"Application startup failed: {str(e)}")
        logger.error(traceback.format_exc())

        # Show a minimal error page instead of crashing the app outright.
        with gr.Blocks() as fallback_demo:
            gr.Markdown("# ⚠️ DamageScan 8B Application Error")
            gr.Markdown(f"The application encountered an error during startup:\n\n```\n{str(e)}\n```")
            gr.Markdown("Please check the logs for more details or try again later.")

        fallback_demo.launch()


if __name__ == "__main__":
    main()