# Source: AdamNovotnyCom — Hugging Face Space app (commit fc25e76, "refactor")
import gradio as gr
import logging
import os
import torch
import transformers
from transformers import AutoTokenizer
# Configure root logging and record which device type torch detected.
logging.basicConfig(level=logging.INFO)
device = "GPU" if torch.cuda.is_available() else "CPU"
logging.info("Running on %s", device)
# Language model selection: the MODEL environment variable chooses which
# Hugging Face pipeline backs the UI. Every branch defines:
#   model      -- the HF model identifier (also shown in the output label)
#   model_func -- the callable Gradio invokes: (input_text, request) -> str
model_env = os.environ.get("MODEL")
if "googleflan" == model_env:
    ### Fast/small model used to debug UI on local machine
    model = "google/flan-t5-small"
    pipeline = transformers.pipeline("text2text-generation", model=model)
    def model_func(input_text, request: gr.Request):
        # NOTE(review): returns the raw pipeline output (a list of dicts), so
        # the textbox shows its repr — acceptable for this debug-only branch.
        return pipeline(input_text)
elif "summary_bart" == model_env:
    model = "facebook/bart-large-cnn"
    summarizer = transformers.pipeline("summarization", model=model)
    def model_func(input_text, request: gr.Request):
        # Extract just the summary string from the first (and only) result.
        return summarizer(input_text, max_length=130, min_length=30, do_sample=False)[0]["summary_text"]
elif "llama" == model_env:
    ### Works on CPU but runtime is > 4 minutes
    model = "meta-llama/Llama-2-7b-chat-hf"
    tokenizer = AutoTokenizer.from_pretrained(
        model,
        token=os.environ.get("HF_TOKEN"),  # gated model: access token required
    )
    pipeline = transformers.pipeline(
        "text-generation",
        model=model,
        torch_dtype=torch.float32,
        device_map="auto",
        token=os.environ.get("HF_TOKEN"),
    )
    def model_func(input_text, request: gr.Request):
        sequences = pipeline(
            input_text,
            do_sample=True,
            top_k=10,
            num_return_sequences=1,
            eos_token_id=tokenizer.eos_token_id,
            max_length=200,
        )
        # Prefix the output with "<name>:" when a ?name=... query param is present.
        if "name" in list(request.query_params):
            output_text = f"{request.query_params['name']}:\n"
        else:
            output_text = ""
        for seq in sequences:
            output_text += seq["generated_text"] + "\n"
        return output_text
else:
    # Fail fast with a clear message; previously an unset/unknown MODEL left
    # `model` and `model_func` undefined, causing a confusing NameError later.
    raise ValueError(
        f"Unsupported MODEL environment variable: {model_env!r}; "
        "expected one of 'googleflan', 'summary_bart', 'llama'."
    )
# UI: Gradio — one textbox in, one textbox out, served on all interfaces.
input_label = "How can I help?"
# Default to "" so an unset MODEL does not raise
# TypeError ("argument of type 'NoneType' is not iterable").
if "summary" in os.environ.get("MODEL", ""):
    input_label = "Enter text to summarize"
demo = gr.Interface(
    fn=model_func,
    inputs=gr.Textbox(
        label=input_label,
        lines=3,
        value="",
    ),
    outputs=gr.Textbox(
        label=f"Model: {model}",
        lines=5,
        value="",
    ),
    # Gradio documents "never"/"auto"/"manual" here; the boolean False is
    # an undocumented/deprecated spelling of "never".
    allow_flagging="never",
    theme=gr.themes.Default(primary_hue="blue", secondary_hue="pink")
)
# 0.0.0.0:7860 is the conventional bind for containerized / HF Space deployments.
demo.launch(server_name="0.0.0.0", server_port=7860)