import logging
import os

import gradio as gr
import torch
import transformers
from transformers import AutoTokenizer

logging.basicConfig(level=logging.INFO)

if "googleflan" == os.environ.get("MODEL"):
    # Small text2text model served through a transformers pipeline.
    model = "google/flan-t5-small"
    logging.info(f"APP startup. Model {model}")
    pipe_flan = transformers.pipeline("text2text-generation", model=model)

    def model_func(input_text, request: gr.Request):
        # Log the request and (for debugging) only the first characters of the HF token.
        print(f"Input request: {input_text}")
        print(request.query_params)
        print((os.environ.get("HF_TOKEN") or "")[:5])
        logging.info((os.environ.get("HF_TOKEN") or "")[:5])
        # The pipeline returns a list of dicts; return only the generated text.
        return pipe_flan(input_text)[0]["generated_text"]

elif "llama" == os.environ.get("MODEL"):
    # Gated chat model: both the tokenizer and the pipeline need the HF token.
    model = "meta-llama/Llama-2-7b-chat-hf"
    logging.info(f"APP startup. Model {model}")
    tokenizer = AutoTokenizer.from_pretrained(
        model,
        token=os.environ.get("HF_TOKEN"),
    )
    pipe_llama = transformers.pipeline(
        "text-generation",
        model=model,
        torch_dtype=torch.float16,
        device_map="auto",
        token=os.environ.get("HF_TOKEN"),
    )

    def model_func(input_text, request: gr.Request):
        sequences = pipe_llama(
            input_text,
            do_sample=True,
            top_k=10,
            num_return_sequences=1,
            eos_token_id=tokenizer.eos_token_id,
            max_length=200,
        )
        # Optionally prefix the output with the caller's name from the query string.
        if "name" in request.query_params:
            output_text = f"{request.query_params['name']}:\n"
        else:
            output_text = ""
        for seq in sequences:
            output_text += seq["generated_text"] + "\n"
        return output_text

else:
    # Fail early with a clear message instead of a NameError on model_func below.
    raise ValueError("Set the MODEL environment variable to 'googleflan' or 'llama'.")

demo = gr.Interface(
    fn=model_func,
    inputs=gr.Textbox(
        label="How can I help?",
        lines=3,
        value="",
    ),
    outputs=gr.Textbox(
        label="LLM",
        lines=5,
        value="",
    ),
    allow_flagging="never",
    theme=gr.themes.Default(primary_hue="blue", secondary_hue="pink"),
)

demo.launch(server_name="0.0.0.0", server_port=7860)
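
# Usage sketch (assumptions: the script is saved as app.py; host and port match
# the launch() call above):
#
#   MODEL=googleflan python app.py                    # small flan-t5 text2text model
#   MODEL=llama HF_TOKEN=<your-token> python app.py   # gated Llama 2 model, token required
#
# The Gradio UI is then served on http://0.0.0.0:7860. With the llama model, a
# "name" query parameter (e.g. http://localhost:7860/?name=Alice) is prefixed
# to the generated output.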