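# Gradio app that serves google/flan-t5-small for simple text-to-text generation.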
import gradio as gr
import logging
import os
# torch and AutoTokenizer are only used by the disabled Llama 2 block below.
import torch
import transformers
from transformers import AutoTokenizer
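
# Emit INFO-level messages so the request logs below are visible in the Space logs.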
logging.basicConfig(level=logging.INFO)
print("APP startup")
pipe_flan = transformers.pipeline("text2text-generation", model="google/flan-t5-small")

def google_flan(input_text, request: gr.Request):
    logging.info("Handling new request")
    logging.info(request.query_params)
    # Log only a short prefix of the token to confirm it is set without leaking
    # it, and guard against the variable being unset.
    logging.info((os.environ.get("HF_TOKEN") or "")[:5])
    # The pipeline returns a list of dicts; return just the generated text.
    return pipe_flan(input_text)[0]["generated_text"]

# Alternative backend (disabled): Llama-2-7b-chat, a gated model that requires
# an HF_TOKEN with approved access and a GPU for reasonable latency.
# model = "meta-llama/Llama-2-7b-chat-hf"
# tokenizer = AutoTokenizer.from_pretrained(
#     model,
#     token=os.environ["HF_TOKEN"],
# )
# pipeline = transformers.pipeline(
#     "text-generation",
#     model=model,
#     torch_dtype=torch.float16,
#     device_map="auto",
#     token=os.environ["HF_TOKEN"],
#     # Model-loading options must go through model_kwargs, not pipeline kwargs.
#     model_kwargs={"low_cpu_mem_usage": True},
# )
# def llama2(input_text):
#     sequences = pipeline(
#         input_text,
#         do_sample=True,
#         top_k=10,
#         num_return_sequences=1,
#         eos_token_id=tokenizer.eos_token_id,
#         max_length=200,
#     )
#     output_text = ""
#     for seq in sequences:
#         output_text += seq["generated_text"] + "\n"
#     return output_text
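
# Minimal text-in/text-out UI wired to the FLAN-T5 handler.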
demo = gr.Interface(
    fn=google_flan,
    inputs="text",
    outputs="text",
    allow_flagging="never",
    title="How can I help?",
    theme=gr.themes.Default(primary_hue="blue", secondary_hue="pink"),
)
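# Bind to all interfaces on port 7860, the port Hugging Face Spaces expects.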
demo.launch(server_name="0.0.0.0", server_port=7860)