# webpluging/app.py
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
import gradio as gr
from typing import Iterator, List, Tuple
# Constants
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
DEFAULT_MAX_NEW_TOKENS = 930
# Model Configuration for Generation Mode
model_id = "meta-llama/Llama-2-7b-hf"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=False,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
base_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", quantization_config=bnb_config)
model_generate = PeftModel.from_pretrained(base_model, "ranamhamoud/storytell")
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token  # Llama-2 ships without a pad token; reuse EOS for padding
# Editing mode reuses the same tokenizer; a dedicated editing model could be swapped in here.
model_edit = model_generate  # for simplicity, reuse the generation model for editing in this example
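
# A sketch of how a separate editing adapter could be plugged in instead of
# reusing model_generate. The adapter id below is a placeholder, not a real
# checkpoint; substitute an actual PEFT adapter trained for editing if one exists.
def load_edit_adapter(adapter_id: str) -> PeftModel:
    """Attach a (hypothetical) editing adapter to the shared 4-bit base model."""
    return PeftModel.from_pretrained(base_model, adapter_id)

# Example (left disabled because the adapter id is hypothetical):
# model_edit = load_edit_adapter("your-org/storytell-edit")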
# Helper Functions
def generate_text(input_text: str, chat_history: List[Tuple[str, str]], max_tokens: int = DEFAULT_MAX_NEW_TOKENS) -> Iterator[str]:
    # Append the new message to the chat history so the model sees the full conversation.
    chat_history.append(("user", input_text))
    context = "\n".join(f"{speaker}: {text}" for speaker, text in chat_history)
    input_ids = tokenizer(context, return_tensors="pt").input_ids.to(model_generate.device)
    # Keep only the most recent tokens so the prompt fits the context window.
    input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
    outputs = model_generate.generate(
        input_ids, max_new_tokens=max_tokens, do_sample=True, pad_token_id=tokenizer.eos_token_id
    )
    # Decode only the newly generated tokens, not the echoed prompt.
    response = tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
    # Stream progressively; Gradio replaces the output with each yielded value.
    partial = ""
    for word in response.split():
        partial = f"{partial} {word}".strip()
        yield partial
    chat_history.append(("assistant", response))
def edit_text(input_text: str, chat_history: List[Tuple[str, str]]) -> Iterator[str]:
    # Include the edit request in the context so the model knows what to change.
    context = "\n".join(f"{speaker}: {text}" for speaker, text in chat_history + [("user", input_text)])
    input_ids = tokenizer(context, return_tensors="pt").input_ids.to(model_edit.device)
    input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
    outputs = model_edit.generate(input_ids, max_new_tokens=DEFAULT_MAX_NEW_TOKENS, do_sample=True, pad_token_id=tokenizer.eos_token_id)
    response = tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
    partial = ""
    for word in response.split():
        partial = f"{partial} {word}".strip()
        yield partial
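
# The loops above stream word by word from an already-finished generation. For
# true incremental streaming, transformers provides TextIteratorStreamer: run
# generate() in a background thread and read text as tokens are produced. A
# minimal sketch, not wired into the UI, assuming the same tokenizer and model:
def generate_text_streaming(prompt: str, max_tokens: int = DEFAULT_MAX_NEW_TOKENS) -> Iterator[str]:
    from threading import Thread
    from transformers import TextIteratorStreamer

    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model_generate.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    kwargs = dict(input_ids=input_ids, max_new_tokens=max_tokens, do_sample=True, streamer=streamer)
    Thread(target=model_generate.generate, kwargs=kwargs).start()
    partial = ""
    for new_text in streamer:  # yields decoded text chunks as tokens arrive
        partial += new_text
        yield partial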
# Gradio Interface
def switch_mode(is_editing: bool, input_text: str, chat_history: List[Tuple[str, str]]) -> Iterator[str]:
    # A function containing `yield` is a generator, so delegate with `yield from`;
    # a plain `return edit_text(...)` would end the generator without streaming anything.
    if is_editing and chat_history:
        yield from edit_text(input_text, chat_history)
    elif not is_editing:
        yield from generate_text(input_text, chat_history)
    else:
        yield "Chat history is empty, cannot edit."
with gr.Blocks() as demo:
    with gr.Row():
        input_text = gr.Textbox(label="Input Text")
        is_editing = gr.Checkbox(label="Editing Mode", value=False)
    output_text = gr.Textbox(label="Output", interactive=True)
    chat_history = gr.State([])  # State keeps the chat history across calls
    generate_button = gr.Button("Generate/Edit")
    generate_button.click(switch_mode, inputs=[is_editing, input_text, chat_history], outputs=output_text)
# Main Execution
if __name__ == "__main__":
    demo.queue().launch()  # queue() enables streaming from generator callbacks