import gradio as gr
import logging
import os
import torch
import transformers
from transformers import AutoTokenizer
logging.basicConfig(level=logging.INFO)
if torch.cuda.is_available():
    logging.info("Running on GPU")
else:
    logging.info("Running on CPU")
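# The MODEL environment variable selects the backend below ("googleflan" or
# "llama"); any other value leaves `model` and `model_func` undefined, so the
# gr.Interface construction further down would fail with a NameError.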
if "googleflan" == os.environ.get("MODEL"):
model = "google/flan-t5-small"
pipeline = transformers.pipeline("text2text-generation", model=model)
def model_func(input_text, request: gr.Request):
return pipeline(input_text)
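    # Note: the text2text-generation pipeline returns a list like
    # [{"generated_text": "..."}], so the output Textbox shows that raw
    # structure rather than only the generated string.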
elif "llama" == os.environ.get("MODEL"):
model = "meta-llama/Llama-2-7b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(
model,
token=os.environ.get("HF_TOKEN"),
)
pipeline = transformers.pipeline(
"text-generation",
model=model,
torch_dtype=torch.float16,
# torch_dtype="auto",
low_cpu_mem_usage=True,
device_map="auto",
token=os.environ.get("HF_TOKEN"),
)
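    # device_map="auto" relies on the accelerate package to place the weights
    # across the available devices; torch.float16 roughly halves the memory
    # needed compared to loading the checkpoint in float32.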
    def model_func(input_text, request: gr.Request):
        sequences = pipeline(
            input_text,
            do_sample=True,
            top_k=10,
            num_return_sequences=1,
            eos_token_id=tokenizer.eos_token_id,
            max_length=200,
        )
        if "name" in request.query_params:
            output_text = f"{request.query_params['name']}:\n"
        else:
            output_text = ""
        for seq in sequences:
            output_text += seq["generated_text"] + "\n"
        return output_text
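# Example: opening the Space with ?name=Alice appended to the URL makes the
# llama branch prefix its reply with "Alice:"; the flan branch ignores
# query parameters.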
demo = gr.Interface(
    fn=model_func,
    inputs=gr.Textbox(
        label="How can I help?",
        lines=3,
        value="",
    ),
    outputs=gr.Textbox(
        label=f"Model: {model}",
        lines=5,
        value="",
    ),
    allow_flagging="never",
    theme=gr.themes.Default(primary_hue="blue", secondary_hue="pink"),
)
demo.launch(server_name="0.0.0.0", server_port=7860)
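# Local usage sketch (assumes this file is the Space's app.py; adjust as needed):
#   MODEL=googleflan python app.py
#   MODEL=llama HF_TOKEN=<your token> python app.py   # gated model, needs GPU memory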