"""Gradio demo that serves one of several Hugging Face models, selected via the MODEL env var."""

import logging
import os

import gradio as gr
import torch
import transformers
from transformers import AutoTokenizer

logging.basicConfig(level=logging.INFO)

if torch.cuda.is_available():
    logging.info("Running on GPU")
else:
    logging.info("Running on CPU")

# Language model: chosen once at startup from the MODEL environment variable.
MODEL = os.environ.get("MODEL")

if MODEL == "googleflan":
    # Fast/small model used to debug the UI on a local machine.
    model = "google/flan-t5-small"
    pipeline = transformers.pipeline("text2text-generation", model=model)

    def model_func(input_text, request: gr.Request):
        return pipeline(input_text)[0]["generated_text"]

elif MODEL == "summary_bart":
    model = "facebook/bart-large-cnn"
    summarizer = transformers.pipeline("summarization", model=model)

    def model_func(input_text, request: gr.Request):
        return summarizer(
            input_text, max_length=130, min_length=30, do_sample=False
        )[0]["summary_text"]

elif MODEL == "llama":
    # Works on CPU, but runtime is > 4 minutes per request.
    model = "meta-llama/Llama-2-7b-chat-hf"
    tokenizer = AutoTokenizer.from_pretrained(
        model,
        token=os.environ.get("HF_TOKEN"),
    )
    pipeline = transformers.pipeline(
        "text-generation",
        model=model,
        torch_dtype=torch.float32,
        device_map="auto",
        token=os.environ.get("HF_TOKEN"),
    )

    def model_func(input_text, request: gr.Request):
        sequences = pipeline(
            input_text,
            do_sample=True,
            top_k=10,
            num_return_sequences=1,
            eos_token_id=tokenizer.eos_token_id,
            max_length=200,
        )
        # Prefix the output with the caller's name if it was passed as a query parameter.
        if "name" in list(request.query_params):
            output_text = f"{request.query_params['name']}:\n"
        else:
            output_text = ""
        for seq in sequences:
            output_text += seq["generated_text"] + "\n"
        return output_text

else:
    raise ValueError(f"Unsupported or missing MODEL environment variable: {MODEL!r}")

# UI: Gradio
input_label = "How can I help?"
if "summary" in MODEL:
    input_label = "Enter text to summarize"

demo = gr.Interface(
    fn=model_func,
    inputs=gr.Textbox(
        label=input_label,
        lines=3,
        value="",
    ),
    outputs=gr.Textbox(
        label=f"Model: {model}",
        lines=5,
        value="",
    ),
    allow_flagging="never",
    theme=gr.themes.Default(primary_hue="blue", secondary_hue="pink"),
)

demo.launch(server_name="0.0.0.0", server_port=7860)
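
# Example invocations (a sketch; the filename app.py and placeholder token below are
# assumptions, not part of the original script):
#   MODEL=googleflan python app.py                  # small Flan-T5, quick UI debugging
#   MODEL=summary_bart python app.py                # BART summarization
#   MODEL=llama HF_TOKEN=<your-token> python app.py # Llama 2 chat (slow on CPU)
# Once running, the interface is served at http://0.0.0.0:7860.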