# Final-Llama / app.py
import os
from threading import Thread
from typing import Iterator
import time
import textwrap
import nltk
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import tiktoken
import numpy as np
import multiprocessing
nltk.download('punkt')
sentence_tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
num_processes_tone = multiprocessing.cpu_count()
# Load the model and tokenizer only when a GPU is available; generate() below
# assumes they exist.
if torch.cuda.is_available():
    model_id = "daryl149/llama-2-7b-chat-hf"
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.use_default_system_prompt = False
def num_tokens_from_string(string: str, model_name: str = "gpt-3.5-turbo") -> int:
    """Count tokens in a string with tiktoken, plus a fixed 7-token allowance."""
    encoding = tiktoken.encoding_for_model(model_name)
    num_tokens = len(encoding.encode(string))
    return num_tokens + 7
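# Illustrative note (an assumption about intent, not part of the original file):
# the fixed +7 appears to leave headroom for per-chunk prompt overhead, e.g.
#     num_tokens_from_string("Change the tone of this sentence.")
# returns the tiktoken token count of the string plus 7.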
def merge_sentences(sentences):
    """Greedily merge sentences into chunks of at most ~500 tokens each."""
    merged_list = []
    current_sentence = ""
    for sentence in sentences:
        if num_tokens_from_string(current_sentence + sentence) <= 500:
            if current_sentence != "":
                current_sentence += " "
            current_sentence += sentence
        else:
            merged_list.append(current_sentence)
            current_sentence = sentence
    if current_sentence:
        merged_list.append(current_sentence)
    return merged_list
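# Illustrative example (not from the original file): with the 500-token budget,
#     merge_sentences(["First sentence.", "Second sentence."])
# returns ["First sentence. Second sentence."] because both fit in one chunk,
# while a sentence that would push a chunk past the budget starts a new chunk.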
def split_into_sentences(text):
    sentences = sentence_tokenizer.tokenize(text)
    return sentences
def corrected_tone(message):
    """Rewrite each ~500-token chunk of `message` in a formal tone."""
    output_prompts = []
    split_sentences = split_into_sentences(message)
    token_safe_sentences = merge_sentences(split_sentences)
    for chunk in token_safe_sentences:
        prompt = (
            "You are going to act as a storyteller. Read the text below and change "
            "the text tone to a new formal tone. Important: do not change its meaning "
            "or content. Return only the output."
        )
        prompt += "\n"
        prompt += chunk.strip()
        output_prompts.append(prompt)
    with multiprocessing.Pool(processes=num_processes_tone) as pool:
        results = pool.map(langchain_function, output_prompts)
    output_text = " ".join(results)
    return output_text
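# NOTE: langchain_function is not defined or imported anywhere in this file, so
# corrected_tone() would raise a NameError if called. A minimal placeholder,
# assuming the intent is simply to run each rewrite prompt through the local
# model, could look like the sketch below (an assumption, not the original
# implementation; it must also live at module level so pool.map can pickle it):
#
#     def langchain_function(prompt: str) -> str:
#         return generate(prompt)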
@spaces.GPU
def generate(prompt):
    max_new_tokens = 1024
    temperature = 0.6
    top_p = 0.9
    top_k = 50
    repetition_penalty = 1.2

    conversation = [{"role": "user", "content": prompt}]
    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # Run generation in a background thread; the streamer yields decoded text as it is produced.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
    return "".join(outputs)
with gr.Blocks(css="style.css") as demo:
    input_prompt = gr.Textbox(label="Prompt")
    output = gr.Textbox(label="Output")
    btn = gr.Button("Generate")
    btn.click(generate, inputs=input_prompt, outputs=output)
if __name__ == "__main__":
    demo.launch(share=True)