# compare-llms / app.py (Hugging Face Space by playgrdstar)
# Commit b3577aa: "Add check truncation and remove server port"
import os
import time

import gradio as gr
import requests
# Hugging Face access token, sent as a Bearer token by the Inference API calls
# below. Read once at import time: if the variable is unset this raises KeyError.
HF_READ_API_KEY = os.environ["HF_READ_API_KEY"]
### This code loads the models and undertakes inference locally ###
# Local inference (load_model / load_and_generate) is disabled in favour of the
# hosted Inference API. To re-enable it, uncomment:
# from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM

# Hugging Face Hub model ids offered in the "Select model" dropdown. Each id is
# interpolated into the Inference API URL, so it must name an existing repo.
model_list = [
    # FLAN-T5 (encoder-decoder).
    'google/flan-t5-small', 'google/flan-t5-base', 'google/flan-t5-large',
    'google/flan-t5-xl', 'google/flan-t5-xxl',
    # GPT-2.
    'gpt2-medium', 'gpt2-large', 'gpt2-xl',
    # EleutherAI. Their 6B model is GPT-J; there is no 'gpt-neo-6b' repo.
    'EleutherAI/gpt-neo-1.3B', 'EleutherAI/gpt-neo-2.7B', 'EleutherAI/gpt-j-6b',
    'EleutherAI/gpt-neox-20b',
    # BLOOM.
    'bigscience/bloom-1b7', 'bigscience/bloom-3b', 'bigscience/bloom-7b1',
]
def load_model(model_name):
    """Load *model_name* and its tokenizer for local inference.

    Among the entries in ``model_list``, only the T5 family
    (``google/flan-t5-*``) is an encoder-decoder model and needs a seq2seq head.
    Every other entry (GPT-2, GPT-Neo/J/NeoX, BLOOM) is loaded as a causal LM.
    The previous check only routed three hard-coded names to the causal head,
    so e.g. ``gpt2-xl`` or the BLOOM models were loaded with the wrong class.

    NOTE(review): requires the ``transformers`` imports that are commented out
    at the top of this file.

    Args:
        model_name: Hugging Face Hub model id.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    if "t5" in model_name.lower():
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    else:
        model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Some tokenizers (e.g. GPT-2's) define no pad token; fall back to EOS so
    # padding works, but keep a pad token the tokenizer already defines.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # tokenizer.padding_side = "left"
    return model, tokenizer
def maybe_is_truncated(s, punct=(".", "!", "?", '"')):
    """Return True if *s* looks like it stops mid-sentence.

    Trailing whitespace is ignored, so ``"Done.\\n"`` counts as complete.
    Empty or all-whitespace text is reported as truncated instead of raising
    ``IndexError``.

    Args:
        s: Generated text to inspect.
        punct: Characters that may end a complete output.

    Returns:
        False if the last non-whitespace character is in *punct*, else True.
    """
    stripped = s.rstrip()
    return not stripped or stripped[-1] not in punct
def load_and_generate(model_name, prompt, temperature=0.25, max_times=20):
    """Generate text for *prompt* with a model loaded locally via ``load_model``.

    The first pass samples up to ``len(prompt tokens) + 5`` new tokens. While
    the decoded text still looks truncated (per ``maybe_is_truncated``), it is
    re-encoded and extended by a capped pass, at most *max_times* times.

    NOTE(review): only the commented-out click handler at the bottom of the
    file references this, and it needs the commented-out ``transformers``
    imports.

    Args:
        model_name: Hub model id forwarded to ``load_model``.
        prompt: Input text.
        temperature: Sampling temperature (``do_sample`` is always on).
        max_times: Maximum number of extension passes.

    Returns:
        The decoded generation with special tokens removed.
    """
    model, tokenizer = load_model(model_name)

    def _sample(text, extend):
        # Run one sampling pass over *text* and return the decoded output.
        tokens = tokenizer(text, return_tensors="pt")
        budget = len(tokens.input_ids[0]) + 5
        # First pass: up to `budget` *new* tokens. Extension passes cap the
        # *total* length at `budget`, i.e. only a few tokens more per pass.
        if extend:
            length_kwargs = {"max_length": budget}
        else:
            length_kwargs = {"max_new_tokens": budget}
        # see huggingface.co/docs/transformers/main_classes/text_generation
        gen_tokens = model.generate(
            input_ids=tokens.input_ids,
            attention_mask=tokens.attention_mask,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            temperature=temperature,
            **length_kwargs,
        )
        # Decoding with special tokens left in ends the text with a marker such
        # as "</s>", whose final ">" made maybe_is_truncated() always True.
        return tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)[0]

    gen_text = _sample(prompt, extend=False)
    for _ in range(max_times):
        if not maybe_is_truncated(gen_text):
            break
        gen_text = _sample(gen_text, extend=True)
    return gen_text.replace("<pad>", "").replace("</s>", "")
### This code for the inference api ###
def generate_from_api(query, model_name, temperature, max_tokens,
                      max_retries=5, timeout=300):
    """Generate text for *query* with a hosted model via the HF Inference API.

    Generation settings are sent in the JSON body (``parameters`` /
    ``options``). They were previously sent as HTTP headers, which the API
    ignores, so the temperature and max-token sliders had no effect.

    Args:
        query: Prompt text, sent as ``inputs``.
        model_name: Hugging Face Hub model id.
        temperature: Sampling temperature.
        max_tokens: Maximum number of new tokens to generate.
        max_retries: Extra attempts after a non-200 response (with backoff).
        timeout: Seconds ``requests`` waits on the connection for each attempt.

    Returns:
        The ``generated_text`` of the first result.

    Raises:
        requests.HTTPError: if no attempt returned HTTP 200.
    """
    model_api_url = f"https://api-inference.huggingface.co/models/{model_name}"
    headers = {"Authorization": f"Bearer {HF_READ_API_KEY}"}
    payload = {
        "inputs": query,
        "parameters": {
            "temperature": float(temperature),
            "max_new_tokens": int(max_tokens),
            "max_time": 120.0,
        },
        "options": {"wait_for_model": True},
    }
    for attempt in range(max_retries + 1):
        response = requests.post(model_api_url, headers=headers, json=payload,
                                 timeout=timeout)
        if response.status_code == 200:
            return response.json()[0]['generated_text']
        if attempt < max_retries:
            # Back off instead of hammering the API (previously an endless loop).
            time.sleep(2 ** attempt)
    raise requests.HTTPError(
        f"Inference API returned {response.status_code}: {response.text}",
        response=response,
    )
def generate_from_api_check(query, model_name, temperature, max_tokens,
                            max_times=20):
    """Like ``generate_from_api``, but keep extending output that looks cut off.

    After the first call, while ``maybe_is_truncated`` reports the text as
    incomplete, the model is queried again, at most *max_times* extra times:

    * If the response already starts with the prompt (the API echoed it), the
      response itself is sent back so the model continues from where it
      stopped. Previously the prompt was prepended again, duplicating it.
    * Otherwise the prompt and the partial answer are sent together, and the
      new response replaces the old one.

    Each pass requests *max_tokens* new tokens. The old ``max_tokens +
    len(gen_text)`` mixed characters with tokens and could exceed the API's
    limit. The loop also stops early when a pass returns unchanged text.

    Args:
        query: Prompt text.
        model_name: Hugging Face Hub model id.
        temperature: Sampling temperature.
        max_tokens: Maximum number of new tokens per request.
        max_times: Maximum number of extension requests.

    Returns:
        The final generated text.
    """
    gen_text = generate_from_api(query, model_name, temperature, max_tokens)
    for _ in range(max_times):
        if not maybe_is_truncated(gen_text):
            break
        if gen_text.startswith(query):
            next_input = gen_text
        else:
            next_input = query + ' ' + gen_text
        new_text = generate_from_api(next_input, model_name, temperature, max_tokens)
        if new_text == gen_text:
            # No progress (e.g. a cached response); further calls won't help.
            break
        gen_text = new_text
    return gen_text
def _generate_for_ui(query, model, temp, n_tokens, check):
    """Route a Generate click to the plain or truncation-checking API call.

    *check* is the current value of the "Check for truncated output"
    checkbox, passed in by Gradio because the component is a click input.
    """
    if check:
        return generate_from_api_check(query, model, temp, n_tokens)
    return generate_from_api(query, model, temp, n_tokens)


with gr.Blocks(css='style.css') as demo:
    gr.HTML("""
<div style="text-align: center; max-width: 1240px; margin: 0 auto;">
<h1 style="font-weight: 200; font-size: 20px; margin-bottom:8px; margin-top:0px;">
Different Strokes (Prompts) for Different Folks (LLMs)
</h1>
<hr style="margin-bottom:5px; margin-top:5px;">
<h4 style="font-weight: 50; font-size: 14px; margin-bottom:0px; margin-top:0px;">
After reading <a href="https://github.com/dair-ai/Prompt-Engineering-Guide">Prompt Engineering Guide</a>, which is a good guide when starting to learn about prompts for large language models (LLMs), specifically OpenAI's LLMs, I was interested in seeing the results with for other LLMs. Hence, did up a simple demonstration of different prompts for different popular LLMs of different sizes. The prompt examples are taken from the Prompt Engineering Guide, and the LLMs that you can select below are all available on Hugging Face. If you are interested in comparing them with the prompts from OpenAI's model, you can refer to the writeup in the <a href="https://github.com/dair-ai/Prompt-Engineering-Guide">Prompt Engineering Guide</a> itself.
</h4>
<hr style="margin-bottom:5px; margin-top:5px;">
<h5 style="font-weight: 50; font-size: 12px; margin-bottom:0px; margin-top:0px;">
Note: Larger models will take a while, especially on the first run.
</h5>
</div>
""")

    with gr.Column(elem_id="col-container"):
        # Generation controls.
        with gr.Row(variant="compact"):
            model_name = gr.Dropdown(
                model_list,
                label="Select model",
                value=model_list[0],
            ).style(
                container=False,
            )
            temperature = gr.Slider(
                0.1, 100.0, value=1.0, label="Temperature",
            ).style(
                container=False,
            )
            max_tokens = gr.Slider(
                10, 250, step=1, value=100, label="Max. tokens (in output)",
            ).style(
                container=False,
            )
            check_truncated = gr.Checkbox(
                label="Check for truncated output",
                value=False,
            ).style(
                container=False,
            )

        # Prompt input and trigger.
        with gr.Row(variant="compact"):
            prompt = gr.Textbox(
                label="Enter your prompt",
                show_label=False,
                # max_lines=2,
                placeholder="Select your prompt from the examples below",
            ).style(
                container=False,
            )
            process = gr.Button("Generate").style(full_width=False)

        with gr.Row():
            output = gr.Textbox(
                label="LLM output",
                show_label=True)

        gr.HTML("""
<div>
<h4 style="font-weight: 50; font-size: 14px; margin-bottom:0px; margin-top:0px;">
Prompt examples. Select the prompt you would like to test, and it will appear (properly formatted) in the input box above.
</h4>
</div>
""")

        # Clicking an example copies it into the prompt box.
        with gr.Tab("Introduction"):
            example_set_1 = gr.Examples(
                label='Simple Prompt vs. Instruct then Prompt.',
                examples=["The sky is ", "Complete the following sentence: The sky is ",],
                inputs=[prompt])
            example_set_2 = gr.Examples(
                label='Few Shot Prompt.',
                examples=["This is awesome! // Positive\nThis is bad! // Negative\nWow that movie was rad! // Positive\nWhat a horrible show! //",],
                inputs=[prompt])
            example_set_3 = gr.Examples(
                label='Explicitly Specify the Instruction',
                examples=["### Instruction ###\nTranslate the text below to Spanish:\nText: 'hello!'",],
                inputs=[prompt])
            example_set_4 = gr.Examples(
                label='Be Very Specific',
                examples=["Extract the name of places in the following text.\nDesired format:\nPlace: <comma_separated_list_of_company_names>\nInput: 'Although these developments are encouraging to researchers, much is still a mystery. “We often have a black box between the brain and the effect we see in the periphery,” says Henrique Veiga-Fernandes, a neuroimmunologist at the Champalimaud Centre for the Unknown in Lisbon. “If we want to use it in the therapeutic context, we actually need to understand the mechanism.'",],
                inputs=[prompt])
            example_set_5 = gr.Examples(
                label='Precision',
                examples=["Explain the concept of deep learning. Keep the explanation short, only a few sentences, and don't be too descriptive.", "Use 2-3 sentences to explain the concept of deep learning to a high school student."],
                inputs=[prompt])
            example_set_6 = gr.Examples(
                label='Focus on What LLM Should Do',
                examples=["The following is an agent that recommends movies to a customer. The agent is responsible to recommend a movie from the top global trending movies. It should refrain from asking users for their preferences and avoid asking for personal information. If the agent doesn't have a movie to recommend, it should respond 'Sorry, couldn't find a movie to recommend today.'.\nCustomer: Please recommend a movie based on my interests.\nAgent:"],
                inputs=[prompt])

        with gr.Tab("Basic Tasks"):
            example_set_7 = gr.Examples(
                label='Explain vs. Summarize',
                examples=["Explain antibiotics.\nA:", "Antibiotics are a type of medication used to treat bacterial infections. They work by either killing the bacteria or preventing them from reproducing, allowing the body’s immune system to fight off the infection. Antibiotics are usually taken orally in the form of pills, capsules, or liquid solutions, or sometimes administered intravenously. They are not effective against viral infections, and using them inappropriately can lead to antibiotic resistance.\nExplain the above in one sentence:",],
                inputs=[prompt])
            example_set_8 = gr.Examples(
                label='Information Extraction',
                examples=["Author-contribution statements and acknowledgements in research papers should state clearly and specifically whether, and to what extent, the authors used AI technologies such as ChatGPT in the preparation of their manuscript and analysis. They should also indicate which LLMs were used. This will alert editors and reviewers to scrutinize manuscripts more carefully for potential biases, inaccuracies and improper source crediting. Likewise, scientific journals should be transparent about their use of LLMs, for example when selecting submitted manuscripts.\nMention the large language model based product mentioned in the paragraph above:",],
                inputs=[prompt])
            example_set_9 = gr.Examples(
                label='Question and Answer',
                examples=["Answer the question based on the context below. Keep the answer short and concise. Respond 'Unsure about answer' if not sure about the answer.\nContext: Teplizumab traces its roots to a New Jersey drug company called Ortho Pharmaceutical. There, scientists generated an early version of the antibody, dubbed OKT3. Originally sourced from mice, the molecule was able to bind to the surface of T cells and limit their cell-killing potential. In 1986, it was approved to help prevent organ rejection after kidney transplants, making it the first therapeutic antibody allowed for human use.\nQuestion: What was OKT3 originally sourced from?\nAnswer:",],
                inputs=[prompt])
            example_set_10 = gr.Examples(
                label='Text Classification',
                examples=["Classify the text into neutral, negative or positive.\nText: I think the food was okay.\nSentiment:","Classify the text into neutral, negative or positive.\nText: I think the vacation is okay.\nSentiment: neutral\nText: I think the food was okay.\nSentiment:"],
                inputs=[prompt])
            example_set_11 = gr.Examples(
                label='Conversation',
                examples=["The following is a conversation with an AI research assistant. The assistant tone is technical and scientific.\nHuman: Hello, who are you?\nAI: Greeting! I am an AI research assistant. How can I help you today?\nHuman: Can you tell me about the creation of blackholes?\nAI:", "The following is a conversation with an AI research assistant. The assistant answers should be easy to understand even by primary school students.\nHuman: Hello, who are you?\nAI: Greeting! I am an AI research assistant. How can I help you today?\nHuman: Can you tell me about the creation of black holes?\nAI: "],
                inputs=[prompt])
            example_set_12 = gr.Examples(
                label='Reasoning',
                examples=["The odd numbers in this group add up to an even number: 15, 32, 5, 13, 82, 7, 1.\nA: ", "The odd numbers in this group add up to an even number: 15, 32, 5, 13, 82, 7, 1.\nSolve by breaking the problem into steps. First, identify the odd numbers, add them, and indicate whether the result is odd or even."],
                inputs=[prompt])

        with gr.Tab("Interesting Techniques"):
            example_set_13 = gr.Examples(
                label='Zero Shot, i.e., no examples at all',
                examples=["Classify the text into neutral, negative or positive.\nText: I think the vacation is okay.\nSentiment:",],
                inputs=[prompt])
            example_set_14 = gr.Examples(
                label='Few Shot, i.e., only a few examples',
                examples=["The odd numbers in this group add up to an even number: 4, 8, 9, 15, 12, 2, 1.\nA: The answer is False.\n\nThe odd numbers in this group add up to an even number: 17, 10, 19, 4, 8, 12, 24.\nA: The answer is True.\n\nThe odd numbers in this group add up to an even number: 16, 11, 14, 4, 8, 13, 24.\nA: The answer is True.\n\nThe odd numbers in this group add up to an even number: 17, 9, 10, 12, 13, 4, 2.\nA: The answer is False.\n\nThe odd numbers in this group add up to an even number: 15, 32, 5, 13, 82, 7, 1.\nA: ",],
                inputs=[prompt])
            example_set_15 = gr.Examples(
                label='Chain of Thought, i.e., go through a series of rational steps',
                examples=["The odd numbers in this group add up to an even number: 4, 8, 9, 15, 12, 2, 1.\nA: Adding all the odd numbers (9, 15, 1) gives 25. The answer is False.\n\nThe odd numbers in this group add up to an even number: 15, 32, 5, 13, 82, 7, 1.\nA:",],
                inputs=[prompt])
            example_set_16 = gr.Examples(
                label='Zero Shot Chain of Thought, i.e., think step by step, but no examples provided',
                examples=["I went to the market and bought 10 apples. I gave 2 apples to the neighbor and 2 to the repairman. I then went and bought 5 more apples and ate 1. How many apples did I remain with?\nLet's think step by step.",],
                inputs=[prompt])
            example_set_17 = gr.Examples(
                label='Self Consistency, i.e., give examples to encourage the model to be consistent',
                examples=["Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done,there will be 21 trees. How many trees did the grove workers plant today?\nA: We start with 15 trees. Later we have 21 trees. The difference must be the number of trees they planted.\nSo, they must have planted 21 - 15 = 6 trees. The answer is 6.\n\nQ: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?\nA: There are 3 cars in the parking lot already. 2 more arrive. Now there are 3 + 2 = 5 cars. The answer is 5.\n\nQ: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?\nA: She bought 5 bagels for $3 each. This means she spent 5\n\nQ: When I was 6 my sister was half my age. Now I’m 70 how old is my sister?\nA:",],
                inputs=[prompt])
            example_set_18 = gr.Examples(
                label='Generating Knowledge, i.e., use examples to generate knowledge',
                examples=["Input: Greece is larger than mexico.\nKnowledge: Greece is approximately 131,957 sq km, while Mexico is approximately 1,964,375 sq km, making Mexico 1,389% larger than Greece.\n\nInput: Glasses always fog up.\nKnowledge: Condensation occurs on eyeglass lenses when water vapor from your sweat, breath, and ambient humidity lands on a cold surface, cools, and then changes into tiny drops of liquid, forming a film that you see as fog. Your lenses will be relatively cool compared to your breath, especially when the outside air is cold.\n\nInput: A fish is capable of thinking.\nKnowledge: Fish are more intelligent than they appear. In many areas, such as memory, their cognitive powers match or exceed those of ’higher’ vertebrates including non-human primates. Fish’s long-term memories help them keep track of complex social relationships.\n\nInput: A common effect of smoking lots of cigarettes in one’s lifetime is a higher than normal chance of getting lung cancer.\nKnowledge: Those who consistently averaged less than one cigarette per day over their lifetime had nine times the risk of dying from lung cancer than never smokers. Among people who smoked between one and 10 cigarettes per day, the risk of dying from lung cancer was nearly 12 times higher than that of never smokers.\n\nInput: Part of golf is trying to get a higher point total than others.\nKnowledge:",],
                inputs=[prompt])

    # process.click(load_and_generate, inputs=[model_name, prompt], outputs=[output])
    # The checkbox is passed as a click *input* so its current value is read on
    # every click. Testing the component object at build time (it is always
    # truthy) made every click take the truncation-checking path.
    process.click(
        _generate_for_ui,
        inputs=[prompt, model_name, temperature, max_tokens, check_truncated],
        outputs=[output],
    )
# Start the Gradio server only when run as a script (as Spaces does with
# `python app.py`), so importing this module does not launch a server.
if __name__ == "__main__":
    demo.launch()