import os
import multiprocessing
from multiprocessing.pool import ThreadPool
from threading import Thread

import gradio as gr
import nltk
import spaces
import tiktoken
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Sentence tokenizer used to split long inputs into token-safe chunks.
nltk.download('punkt')
sentence_tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

num_processes_tone = multiprocessing.cpu_count()

# The model is only loaded when a GPU is available; generate() below requires it.
if torch.cuda.is_available():
    model_id = "daryl149/llama-2-7b-chat-hf"
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.float16, device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.use_default_system_prompt = False


def num_tokens_from_string(string: str, encoding_name='gpt-3.5-turbo'):
    """Count tokens with tiktoken, plus a small fixed buffer for prompt overhead."""
    encoding = tiktoken.encoding_for_model(encoding_name)
    num_tokens = len(encoding.encode(string))
    return num_tokens + 7


def merge_sentences(sentences):
    """Greedily pack sentences into chunks of at most ~500 tokens each."""
    merged_list = []
    current_chunk = ""
    for sentence in sentences:
        if num_tokens_from_string(current_chunk + sentence) <= 500:
            if current_chunk:
                current_chunk += " "
            current_chunk += sentence
        else:
            merged_list.append(current_chunk)
            current_chunk = sentence
    if current_chunk:
        merged_list.append(current_chunk)
    return merged_list


def split_into_sentences(text):
    return sentence_tokenizer.tokenize(text)


def corrected_tone(message):
    """Rewrite `message` in a formal tone, one token-safe chunk at a time."""
    output_prompts = []
    token_safe_sentences = merge_sentences(split_into_sentences(message))
    for chunk in token_safe_sentences:
        prompt = (
            "You are going to act as a storyteller. Read the text below and "
            "change the text tone to a new formal tone. Important: Do not "
            "change its meaning or content. Return only the output.\n"
        )
        prompt += chunk.strip()
        output_prompts.append(prompt)
    # The original mapped an undefined `langchain_function` over an undefined
    # `prompt_list`. Here generate() is mapped over the prompts built above,
    # via a thread pool: forked worker processes cannot share a CUDA model.
    with ThreadPool(processes=num_processes_tone) as pool:
        results = pool.map(generate, output_prompts)
    return ' '.join(results)


@spaces.GPU
def generate(prompt):
    max_new_tokens = 1024
    temperature = 0.6
    top_p = 0.9
    top_k = 50
    repetition_penalty = 1.2

    conversation = [{"role": "user", "content": prompt}]
    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the most recent tokens if the prompt is too long.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # Run generation on a background thread and collect the streamed text.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    outputs = []
    for text in streamer:
        outputs.append(text)
    return "".join(outputs)


with gr.Blocks(css="style.css") as demo:
    input_prompt = gr.Textbox()
    output = gr.Textbox()
    btn = gr.Button("Generate")
    btn.click(generate, inputs=input_prompt, outputs=output)

if __name__ == "__main__":
    demo.launch(share=True)
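# A minimal sketch (hypothetical, not part of the original UI): corrected_tone()
# is defined above but never wired into the Blocks demo. If desired, it could be
# exposed with a second button placed inside the `with gr.Blocks(...)` block,
# before demo.launch():
#
#     tone_btn = gr.Button("Formalize tone")
#     tone_btn.click(corrected_tone, inputs=input_prompt, outputs=output)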