"""Gradio demo for a LoRA-adapted Llama-2 storytelling model.

The app has two modes that share one model and tokenizer:

* generate -- the input is appended to the chat history as a user turn and
  the model's reply is stored as the assistant turn;
* edit -- the model is run over the existing history plus the edit request,
  without modifying the history.
"""

import os
import re
from typing import Iterator, List, Tuple

import gradio as gr
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Constants
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
DEFAULT_MAX_NEW_TOKENS = 930

# Model Configuration for Generating Mode
model_id = "meta-llama/Llama-2-7b-hf"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=False,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    quantization_config=bnb_config,
)
model_generate = PeftModel.from_pretrained(base_model, "ranamhamoud/storytell")
tokenizer = AutoTokenizer.from_pretrained(model_id)
# The tokenizer is given the EOS token as its padding token.
tokenizer.pad_token = tokenizer.eos_token

# Editing mode uses the same tokenizer but might use a simpler or different
# model setup. For simplicity, the same model is used for editing here.
model_edit = model_generate

_WORD_RE = re.compile(r"\S+")


# Helper Functions
def _format_context(turns: List[Tuple[str, str]]) -> str:
    """Render ``(speaker, text)`` turns as one ``speaker: text`` line each."""
    return "\n".join(f"{speaker}: {text}" for speaker, text in turns)


def _generate_reply(model, context: str, max_new_tokens: int) -> str:
    """Sample a continuation of *context* and return only the new text.

    The prompt is truncated to its last ``MAX_INPUT_TOKEN_LENGTH`` tokens so
    that the most recent turns are the ones kept when the history grows long.
    """
    encoded = tokenizer(context, return_tensors="pt")
    input_ids = encoded.input_ids[:, -MAX_INPUT_TOKEN_LENGTH:].to(model.device)
    attention_mask = encoded.attention_mask[:, -MAX_INPUT_TOKEN_LENGTH:].to(
        model.device
    )

    # Keyword arguments only: NOTE(review) some PEFT releases reject
    # positional arguments to ``generate`` -- confirm for the pinned version.
    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,
    )

    # The returned sequence starts with the prompt; decode only what follows
    # so the prompt is neither shown to the user nor stored as the reply.
    new_tokens = outputs[0][input_ids.shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


def _growing_prefixes(text: str) -> Iterator[str]:
    """Yield successively longer prefixes of *text*, one more word each time.

    Gradio replaces the output component's value on every yield, so each
    yielded item has to be the full text so far rather than a single word.
    Original whitespace (including newlines) inside the text is preserved.
    """
    for match in _WORD_RE.finditer(text):
        yield text[: match.end()]


def generate_text(
    input_text: str,
    chat_history: List[Tuple[str, str]],
    max_tokens: int = DEFAULT_MAX_NEW_TOKENS,
) -> Iterator[str]:
    """Generate a reply to *input_text* in the context of *chat_history*.

    Args:
        input_text: The new user message.
        chat_history: ``(speaker, text)`` turns; mutated in place with the
            user message and the assistant reply.
        max_tokens: Maximum number of new tokens to sample.

    Yields:
        Progressively longer prefixes of the reply, ending with the full reply.
    """
    # Append the new message to the chat history for context.
    chat_history.append(("user", input_text))
    reply = _generate_reply(
        model_generate, _format_context(chat_history), max_tokens
    )
    # Record the reply before streaming so the history is complete even if
    # the consumer stops iterating early.
    chat_history.append(("assistant", reply))
    yield from _growing_prefixes(reply)


def edit_text(input_text: str, chat_history: List[Tuple[str, str]]) -> Iterator[str]:
    """Run the editing model over the history plus the edit request.

    Args:
        input_text: The edit request; included in the prompt when non-empty.
        chat_history: ``(speaker, text)`` turns; not modified.

    Yields:
        Progressively longer prefixes of the model output.
    """
    turns = list(chat_history)
    if input_text:
        turns.append(("user", input_text))
    reply = _generate_reply(
        model_edit, _format_context(turns), DEFAULT_MAX_NEW_TOKENS
    )
    yield from _growing_prefixes(reply)


# Gradio Interface
def switch_mode(
    is_editing: bool,
    input_text: str,
    chat_history: List[Tuple[str, str]],
) -> Iterator[str]:
    """Dispatch to generation or editing depending on the checkbox state.

    This function is a generator, so the selected mode must be delegated to
    with ``yield from``; a plain ``return`` of the inner iterator would end
    this generator without producing any output.

    Yields:
        The streamed output of the selected mode, or an error message when
        editing is requested with an empty history.
    """
    if not is_editing:
        yield from generate_text(input_text, chat_history)
    elif chat_history:
        yield from edit_text(input_text, chat_history)
    else:
        yield "Chat history is empty, cannot edit."


with gr.Blocks() as demo:
    with gr.Row():
        input_text = gr.Textbox(label="Input Text")
        is_editing = gr.Checkbox(label="Editing Mode", value=False)
    output_text = gr.Textbox(label="Output", interactive=True)
    # Using State to maintain chat history.
    # NOTE(review): the history is updated by mutating this list in place
    # inside the handler; confirm that persists across calls for the Gradio
    # version in use, or return the state as an additional output.
    chat_history = gr.State([])
    generate_button = gr.Button("Generate/Edit")
    generate_button.click(
        switch_mode,
        inputs=[is_editing, input_text, chat_history],
        outputs=output_text,
    )

# Main Execution
if __name__ == "__main__":
    demo.launch()