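"""Gradio demo serving one of several Hugging Face models.

The model is selected with the MODEL environment variable. Assuming this
file is saved as app.py (file name assumed), typical invocations are:

    MODEL=googleflan python app.py          # small flan-t5 model for local UI debugging
    MODEL=summary_bart python app.py        # BART summarization
    MODEL=llama HF_TOKEN=... python app.py  # Llama-2 chat (gated; token required)
"""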
import gradio as gr
import logging
import os
import torch
import transformers
from transformers import AutoTokenizer

logging.basicConfig(level=logging.INFO)
if torch.cuda.is_available():
    logging.info("Running on GPU")
else:
    logging.info("Running on CPU")

# Language model: selected via the MODEL environment variable
if "googleflan" == os.environ.get("MODEL"):
    ### Fast/small model used to debug UI on local machine
    model = "google/flan-t5-small"
    pipeline = transformers.pipeline("text2text-generation", model=model)
    def model_func(input_text, request: gr.Request):
        # The pipeline returns a list of dicts; show only the generated text
        return pipeline(input_text)[0]["generated_text"]
elif "summary_bart" == os.environ.get("MODEL"):
    model="facebook/bart-large-cnn"
    summarizer = transformers.pipeline("summarization", model=model)
    def model_func(input_text, request: gr.Request):
        return summarizer(input_text, max_length=130, min_length=30, do_sample=False)[0]["summary_text"]
elif "llama" == os.environ.get("MODEL"):
    ### Works on CPU but runtime is > 4 minutes
    model = "meta-llama/Llama-2-7b-chat-hf"
    tokenizer = AutoTokenizer.from_pretrained(
        model,
        token=os.environ.get("HF_TOKEN"),
    )
    pipeline = transformers.pipeline(
        "text-generation",
        model=model,
        torch_dtype=torch.float32,
        device_map="auto",
        token=os.environ.get("HF_TOKEN"),
    )

    def model_func(input_text, request: gr.Request):
        sequences = pipeline(
            input_text,
            do_sample=True,
            top_k=10,
            num_return_sequences=1,
            eos_token_id=tokenizer.eos_token_id,
            max_length=200,
        )
        if "name" in list(request.query_params):
            output_text = f"{request.query_params['name']}:\n"
        else:
            output_text = ""
        for seq in sequences:
            output_text += seq["generated_text"] + "\n"
        return output_text
else:
    # No recognized MODEL value: fail fast with a clear error instead of a
    # NameError when model_func is referenced below.
    raise ValueError(f"Unsupported MODEL environment variable: {os.environ.get('MODEL')!r}")

# UI: Gradio
input_label = "How can I help?"
if "summary" in os.environ.get("MODEL"):
    input_label = "Enter text to summarize"
demo = gr.Interface(
    fn=model_func,
    inputs=gr.Textbox(
            label=input_label,
            lines=3,
            value="",
    ),
    outputs=gr.Textbox(
            label=f"Model: {model}",
            lines=5,
            value="",
    ),
    allow_flagging="never",
    theme=gr.themes.Default(primary_hue="blue", secondary_hue="pink"),
)

demo.launch(server_name="0.0.0.0", server_port=7860)
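
# Note: server_name="0.0.0.0" binds to all interfaces so the app is reachable
# from outside a container; 7860 is Gradio's default port.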