import gradio as gr
import logging
import os
import torch
import transformers
from transformers import AutoTokenizer

logging.basicConfig(level=logging.INFO)

if torch.cuda.is_available():
    logging.info("Running on GPU")
else:
    logging.info("Running on CPU")

# Select the model backend via the MODEL environment variable.
if os.environ.get("MODEL") == "googleflan":
    model = "google/flan-t5-small"
    pipeline = transformers.pipeline("text2text-generation", model=model)

    def model_func(input_text, request: gr.Request):
        # The pipeline returns a list of dicts; return just the generated text.
        return pipeline(input_text)[0]["generated_text"]

elif os.environ.get("MODEL") == "llama":
    model = "meta-llama/Llama-2-7b-chat-hf"
    # Llama 2 is a gated model, so a Hugging Face access token is required.
    tokenizer = AutoTokenizer.from_pretrained(
        model,
        token=os.environ.get("HF_TOKEN"),
    )
    pipeline = transformers.pipeline(
        "text-generation",
        model=model,
        torch_dtype=torch.float16,
        # torch_dtype="auto",
        low_cpu_mem_usage=True,
        device_map="auto",
        token=os.environ.get("HF_TOKEN"),
    )

    def model_func(input_text, request: gr.Request):
        sequences = pipeline(
            input_text,
            do_sample=True,
            top_k=10,
            num_return_sequences=1,
            eos_token_id=tokenizer.eos_token_id,
            max_length=200,
        )
        # Prefix the output with the caller's name if it was passed as a
        # query parameter, e.g. appending ?name=Alice to the app URL.
        if "name" in request.query_params:
            output_text = f"{request.query_params['name']}:\n"
        else:
            output_text = ""
        for seq in sequences:
            output_text += seq["generated_text"] + "\n"
        return output_text

else:
    raise ValueError(
        "Set the MODEL environment variable to 'googleflan' or 'llama'."
    )

demo = gr.Interface(
    fn=model_func,
    inputs=gr.Textbox(
        label="How can I help?",
        lines=3,
        value="",
    ),
    outputs=gr.Textbox(
        label=f"Model: {model}",
        lines=5,
        value="",
    ),
    allow_flagging="never",
    theme=gr.themes.Default(primary_hue="blue", secondary_hue="pink"),
)

demo.launch(server_name="0.0.0.0", server_port=7860)
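# Usage sketch (the filename app.py is an assumption, not given in the source):
#   MODEL=googleflan python app.py
#   MODEL=llama HF_TOKEN=<your Hugging Face token> python app.py
# Either way, the Gradio UI is served on port 7860 on all interfaces.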