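# Gradio demo: serve one of several Hugging Face models behind a single text UI.
# The backing model is selected via the MODEL environment variable
# ("googleflan", "summary_bart", or "llama"); the llama branch additionally
# reads a Hugging Face access token from HF_TOKEN.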
import gradio as gr
import logging
import os
import torch
import transformers
from transformers import AutoTokenizer
logging.basicConfig(level=logging.INFO)
if torch.cuda.is_available():
    logging.info("Running on GPU")
else:
    logging.info("Running on CPU")
# Language model: one of three pipelines, chosen by the MODEL environment variable
if "googleflan" == os.environ.get("MODEL"):
    ### Fast/small model used to debug UI on local machine
    model = "google/flan-t5-small"
    pipeline = transformers.pipeline("text2text-generation", model=model)

    def model_func(input_text, request: gr.Request):
        # The pipeline returns a list of dicts; extract the text itself
        return pipeline(input_text)[0]["generated_text"]
elif "summary_bart" == os.environ.get("MODEL"):
    model = "facebook/bart-large-cnn"
    summarizer = transformers.pipeline("summarization", model=model)

    def model_func(input_text, request: gr.Request):
        return summarizer(input_text, max_length=130, min_length=30, do_sample=False)[0]["summary_text"]
elif "llama" == os.environ.get("MODEL"):
    ### Works on CPU but runtime is > 4 minutes
    model = "meta-llama/Llama-2-7b-chat-hf"
    tokenizer = AutoTokenizer.from_pretrained(
        model,
        token=os.environ.get("HF_TOKEN"),
    )
    pipeline = transformers.pipeline(
        "text-generation",
        model=model,
        torch_dtype=torch.float32,
        device_map="auto",
        token=os.environ.get("HF_TOKEN"),
    )

    def model_func(input_text, request: gr.Request):
        sequences = pipeline(
            input_text,
            do_sample=True,
            top_k=10,
            num_return_sequences=1,
            eos_token_id=tokenizer.eos_token_id,
            max_length=200,
        )
        # Prefix the response with the caller's name if the page was opened
        # with a ?name=... query parameter
        if "name" in request.query_params:
            output_text = f"{request.query_params['name']}:\n"
        else:
            output_text = ""
        for seq in sequences:
            output_text += seq["generated_text"] + "\n"
        return output_text
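else:
    # Added guard (not in the original): fail fast with a clear message when
    # MODEL is unset or unrecognized, instead of hitting a NameError later
    # when `model` and `model_func` are referenced by the Gradio UI below.
    raise ValueError(
        f"Unsupported MODEL: {os.environ.get('MODEL')!r}; "
        "expected 'googleflan', 'summary_bart', or 'llama'."
    )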
# UI: Gradio
input_label = "How can I help?"
if "summary" in os.environ.get("MODEL", ""):
    input_label = "Enter text to summarize"
demo = gr.Interface(
    fn=model_func,
    inputs=gr.Textbox(
        label=input_label,
        lines=3,
        value="",
    ),
    outputs=gr.Textbox(
        label=f"Model: {model}",
        lines=5,
        value="",
    ),
    allow_flagging="never",
    theme=gr.themes.Default(primary_hue="blue", secondary_hue="pink"),
)
demo.launch(server_name="0.0.0.0", server_port=7860)
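# Example usage (entry-point file name assumed; adjust to the actual script):
#   MODEL=googleflan python app.py
#   MODEL=summary_bart python app.py
#   MODEL=llama HF_TOKEN=<token> python app.py   # gated model; needs Llama 2 access
# Then open http://localhost:7860 (the llama branch also accepts a ?name=<you>
# query parameter to prefix its responses).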