# NOTE: "Spaces: Paused" — Hugging Face Spaces status banner captured in the
# paste; kept as a comment so the file parses as Python.
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
import gradio as gr
from typing import Iterator, List, Tuple
# Constants | |
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096")) | |
DEFAULT_MAX_NEW_TOKENS = 930 | |
# Model Configuration for Generating Mode | |
model_id = "meta-llama/Llama-2-7b-hf" | |
bnb_config = BitsAndBytesConfig( | |
load_in_4bit=True, | |
bnb_4bit_use_double_quant=False, | |
bnb_4bit_quant_type="nf4", | |
bnb_4bit_compute_dtype=torch.bfloat16 | |
) | |
base_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", quantization_config=bnb_config) | |
model_generate = PeftModel.from_pretrained(base_model, "ranamhamoud/storytell") | |
tokenizer = AutoTokenizer.from_pretrained(model_id) | |
tokenizer.pad_token = tokenizer.eos_token | |
# Editing mode uses the same tokenizer but might use a simpler or different model setup | |
model_edit = model_generate # For simplicity, using the same model setup for editing in this example | |
# Helper Functions | |
def generate_text(input_text: str, chat_history: List[Tuple[str, str]], max_tokens: int = DEFAULT_MAX_NEW_TOKENS) -> Iterator[str]: | |
# Append the new message to the chat history for context | |
chat_history.append(("user", input_text)) | |
# Prepare the input with the conversation context | |
context = "\n".join([f"{speaker}: {text}" for speaker, text in chat_history]) | |
input_ids = tokenizer(context, return_tensors="pt").input_ids.to(model_generate.device) | |
outputs = model_generate.generate(input_ids, max_length=input_ids.shape[1] + max_tokens, do_sample=True) | |
for output in tokenizer.decode(outputs[0], skip_special_tokens=True).split(): | |
yield output | |
chat_history.append(("assistant", tokenizer.decode(outputs[0], skip_special_tokens=True))) | |
def edit_text(input_text: str, chat_history: List[Tuple[str, str]]) -> Iterator[str]: | |
context = "\n".join([f"{speaker}: {text}" for speaker, text in chat_history]) | |
input_ids = tokenizer(context, return_tensors="pt").input_ids.to(model_edit.device) | |
outputs = model_edit.generate(input_ids, max_length=input_ids.shape[1] + DEFAULT_MAX_NEW_TOKENS, do_sample=True) | |
for output in tokenizer.decode(outputs[0], skip_special_tokens=True).split(): | |
yield output | |
# Gradio Interface | |
def switch_mode(is_editing: bool, input_text: str, chat_history: List[Tuple[str, str]]) -> Iterator[str]: | |
if is_editing and chat_history: | |
return edit_text(input_text, chat_history) | |
elif not is_editing: | |
return generate_text(input_text, chat_history) | |
else: | |
yield "Chat history is empty, cannot edit." | |
with gr.Blocks() as demo: | |
with gr.Row(): | |
input_text = gr.Textbox(label="Input Text") | |
is_editing = gr.Checkbox(label="Editing Mode", value=False) | |
output_text = gr.Textbox(label="Output", interactive=True) | |
chat_history = gr.State([]) # Using State to maintain chat history | |
generate_button = gr.Button("Generate/Edit") | |
generate_button.click(switch_mode, inputs=[is_editing, input_text, chat_history], outputs=output_text) | |
# Main Execution | |
if __name__ == "__main__": | |
demo.launch() | |