jimmy624135 committed
Commit 1bc3ff2
1 Parent(s): d8ce7ae

Create app.py

Files changed (1)
  1. app.py +153 -0
app.py ADDED
@@ -0,0 +1,153 @@
+ import os
+ from threading import Thread
+ from typing import Iterator
+
+ import gradio as gr
+ import spaces
+ import torch
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     TextIteratorStreamer,
+     LlamaTokenizer,
+ )
+
+ MAX_MAX_NEW_TOKENS = 1024
+ DEFAULT_MAX_NEW_TOKENS = 50
+ MAX_INPUT_TOKEN_LENGTH = 512
+
+ DESCRIPTION = """\
+ # OpenELM-270M-Instruct -- Running on CPU
+ This Space demonstrates [apple/OpenELM-270M-Instruct](https://huggingface.co/apple/OpenELM-270M-Instruct) by Apple. Please check the original model card for details.
+ For additional detail on the model, including a link to the arXiv paper, refer to the [Hugging Face paper page for OpenELM](https://huggingface.co/papers/2404.14619).
+ For details on pre-training, instruction tuning, and parameter-efficient finetuning of the model, refer to the [OpenELM page in the CoreNet GitHub repository](https://github.com/apple/corenet/tree/main/projects/openelm).
+ """
+
+ LICENSE = """
+ <p/>
+ ---
+ As a derivative work of [apple/OpenELM-270M-Instruct](https://huggingface.co/apple/OpenELM-270M-Instruct) by Apple,
+ this demo is governed by the original [license](https://huggingface.co/apple/OpenELM-270M-Instruct/blob/main/LICENSE).
+ Based on the [Norod78/OpenELM_3B_Demo](https://huggingface.co/spaces/Norod78/OpenELM_3B_Demo) space.
+ """
+
+
+ model = AutoModelForCausalLM.from_pretrained(
+     "apple/OpenELM-270M-Instruct",
+     revision="eb111ff",
+     trust_remote_code=True,
+ )
+ tokenizer = AutoTokenizer.from_pretrained(
+     "meta-llama/Llama-2-7b-hf",
+     revision="01c7f73",
+     trust_remote_code=True,
+     tokenizer_class=LlamaTokenizer,
+ )
+
+ if tokenizer.pad_token is None:
+     tokenizer.pad_token = tokenizer.eos_token
+     tokenizer.pad_token_id = tokenizer.eos_token_id
+     model.config.pad_token_id = tokenizer.eos_token_id
+
+ def generate(
+     message: str,
+     chat_history: list[tuple[str, str]],
+     max_new_tokens: int = 1024,
+     temperature: float = 0.1,
+     top_p: float = 0.5,
+     top_k: int = 3,
+     repetition_penalty: float = 1.4,
+ ) -> Iterator[str]:
+
+     historical_text = ""
+     # Prepend the entire chat history to the message, with a newline between each turn
+     for user, assistant in chat_history:
+         historical_text += f"\n{user}\n{assistant}"
+
+     if len(historical_text) > 0:
+         message = historical_text + f"\n{message}"
+     input_ids = tokenizer([message], return_tensors="pt").input_ids
+     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]  # keep only the most recent tokens
+         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+     input_ids = input_ids.to(model.device)
+
+     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+     generate_kwargs = dict(
+         {"input_ids": input_ids},
+         streamer=streamer,
+         max_new_tokens=max_new_tokens,
+         do_sample=True,
+         top_p=top_p,
+         top_k=top_k,
+         temperature=temperature,
+         num_beams=1,
+         pad_token_id=tokenizer.eos_token_id,
+         repetition_penalty=repetition_penalty,
+         no_repeat_ngram_size=5,
+         early_stopping=False,
+     )
+     t = Thread(target=model.generate, kwargs=generate_kwargs)  # run generation in a background thread
+     t.start()
+
+     outputs = []
+     for text in streamer:  # stream partial output as tokens arrive
+         outputs.append(text)
+         yield "".join(outputs)
+
+
+ chat_interface = gr.ChatInterface(
+     fn=generate,
+     additional_inputs=[
+         gr.Slider(
+             label="Max new tokens",
+             minimum=1,
+             maximum=MAX_MAX_NEW_TOKENS,
+             step=1,
+             value=DEFAULT_MAX_NEW_TOKENS,
+         ),
+         gr.Slider(
+             label="Temperature",
+             minimum=0.0,
+             maximum=4.0,
+             step=0.1,
+             value=0.1,
+         ),
+         gr.Slider(
+             label="Top-p (nucleus sampling)",
+             minimum=0.05,
+             maximum=1.0,
+             step=0.05,
+             value=0.5,
+         ),
+         gr.Slider(
+             label="Top-k",
+             minimum=1,
+             maximum=1000,
+             step=1,
+             value=3,
+         ),
+         gr.Slider(
+             label="Repetition penalty",
+             minimum=1.0,
+             maximum=2.0,
+             step=0.05,
+             value=1.4,
+         ),
+     ],
+     stop_btn="Stop",
+     cache_examples=False,
+     examples=[
+         ["You are three years old. Count from one to ten."],
+         ["Explain quantum physics in 5 words or less:"],
+         ["Question: What do you call a bear with no teeth?\nAnswer:"],
+     ],
+ )
+
+ with gr.Blocks(css="style.css") as demo:
+     gr.Markdown(DESCRIPTION)
+     chat_interface.render()
+     gr.Markdown(LICENSE)
+
+ if __name__ == "__main__":
+     demo.queue(max_size=20).launch()
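
For a quick sanity check outside the Gradio UI, the generate() generator in app.py can be driven directly from a Python shell. The snippet below is a minimal sketch, assuming the Space's dependencies (gradio, spaces, torch, transformers) are installed and that it is run from the directory containing app.py; importing the module triggers the model and tokenizer loading at import time. It is a hypothetical helper script, not part of this commit.

# Minimal sketch: call generate() directly and print the final streamed text.
import app  # executes the module-level model/tokenizer loading in app.py

prompt = "Question: What do you call a bear with no teeth?\nAnswer:"
final_text = ""
for partial in app.generate(prompt, chat_history=[], max_new_tokens=50):
    final_text = partial  # each yield is the full text generated so far

print(final_text)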