taufiqdp committed
Commit a2cb894 • 1 Parent(s): 865b4b0

Upload folder using huggingface_hub

Files changed (3)
  1. README.md +8 -6
  2. app.py +105 -0
  3. requirements.txt +4 -0
README.md CHANGED
@@ -1,12 +1,14 @@
 ---
-title: Gemma 1.1 7b It
-emoji: 🏆
-colorFrom: indigo
-colorTo: blue
+title: gemma-1.1-7b-it
+emoji: 👑
+colorFrom: purple
+colorTo: yellow
 sdk: gradio
-sdk_version: 4.25.0
+sdk_version: 4.23.0
 app_file: app.py
 pinned: false
+license: apache-2.0
+header: mini
 ---

-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,105 @@
+import os
+import json
+import subprocess
+from threading import Thread
+
+import torch
+import spaces
+import gradio as gr
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
+
+subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
+
+MODEL_ID = os.environ.get("MODEL_ID")
+CHAT_TEMPLATE = os.environ.get("CHAT_TEMPLATE")
+MODEL_NAME = MODEL_ID.split("/")[-1]
+CONTEXT_LENGTH = int(os.environ.get("CONTEXT_LENGTH"))
+COLOR = os.environ.get("COLOR")
+EMOJI = os.environ.get("EMOJI")
+DESCRIPTION = os.environ.get("DESCRIPTION")
+
+
+@spaces.GPU()
+def predict(message, history, system_prompt, temperature, max_new_tokens, top_k, repetition_penalty, top_p):
+    # Format history with a given chat template
+    if CHAT_TEMPLATE == "ChatML":
+        stop_tokens = ["<|endoftext|>", "<|im_end|>"]
+        instruction = '<|im_start|>system\n' + system_prompt + '\n<|im_end|>\n'
+        for human, assistant in history:
+            instruction += '<|im_start|>user\n' + human + '\n<|im_end|>\n<|im_start|>assistant\n' + assistant
+        instruction += '\n<|im_start|>user\n' + message + '\n<|im_end|>\n<|im_start|>assistant\n'
+    elif CHAT_TEMPLATE == "Mistral Instruct":
+        stop_tokens = ["</s>", "[INST]", "[INST] ", "<s>", "[/INST]", "[/INST] "]
+        instruction = '<s>[INST] ' + system_prompt
+        for human, assistant in history:
+            instruction += human + ' [/INST] ' + assistant + '</s>[INST]'
+        instruction += ' ' + message + ' [/INST]'
+    else:
+        raise Exception("Incorrect chat template, select 'ChatML' or 'Mistral Instruct'")
+    print(instruction)
+
+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+    enc = tokenizer([instruction], return_tensors="pt", padding=True, truncation=True)
+    input_ids, attention_mask = enc.input_ids, enc.attention_mask
+
+    if input_ids.shape[1] > CONTEXT_LENGTH:
+        input_ids = input_ids[:, -CONTEXT_LENGTH:]
+
+    generate_kwargs = dict(
+        {"input_ids": input_ids.to(device), "attention_mask": attention_mask.to(device)},
+        streamer=streamer,
+        do_sample=True,
+        temperature=temperature,
+        max_new_tokens=max_new_tokens,
+        top_k=top_k,
+        repetition_penalty=repetition_penalty,
+        top_p=top_p
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+    outputs = []
+    for new_token in streamer:
+        outputs.append(new_token)
+        if new_token in stop_tokens:
+            break
+        yield "".join(outputs)
+
+
+# Load model
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+quantization_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_compute_dtype=torch.bfloat16
+)
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_ID,
+    device_map="auto",
+    quantization_config=quantization_config,
+    attn_implementation="flash_attention_2",
+)
+
+# Create Gradio interface
+gr.ChatInterface(
+    predict,
+    title=EMOJI + " " + MODEL_NAME,
+    description=DESCRIPTION,
+    examples=[
+        ["Can you solve the equation 2x + 3 = 11 for x?"],
+        ["Write an epic poem about Ancient Rome."],
+        ["Who was the first person to walk on the Moon?"],
+        ["Use a list comprehension to create a list of squares for numbers from 1 to 10."],
+        ["Recommend some popular science fiction books."],
+        ["Can you write a short story about a time-traveling detective?"]
+    ],
+    additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False),
+    additional_inputs=[
+        gr.Textbox("Perform the task to the best of your ability.", label="System prompt"),
+        gr.Slider(0, 1, 0.8, label="Temperature"),
+        gr.Slider(128, 4096, 1024, label="Max new tokens"),
+        gr.Slider(1, 80, 40, label="Top K sampling"),
+        gr.Slider(0, 2, 1.1, label="Repetition penalty"),
+        gr.Slider(0, 1, 0.95, label="Top P sampling"),
+    ],
+    theme=gr.themes.Soft(primary_hue=COLOR),
+).queue().launch()
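
For reference, the ChatML branch of predict() can be exercised on its own to see the exact prompt string the model receives. The sketch below copies the template logic verbatim from app.py above; the system prompt, history, and message are made-up example values, not data from the Space.

# Standalone sketch of the ChatML formatting in predict().
# All values here are illustrative.
system_prompt = "Perform the task to the best of your ability."
history = [("What is 2 + 2?", "2 + 2 = 4.")]
message = "And 3 + 3?"

instruction = '<|im_start|>system\n' + system_prompt + '\n<|im_end|>\n'
for human, assistant in history:
    instruction += '<|im_start|>user\n' + human + '\n<|im_end|>\n<|im_start|>assistant\n' + assistant
instruction += '\n<|im_start|>user\n' + message + '\n<|im_end|>\n<|im_start|>assistant\n'

print(instruction)

Note that this template closes user turns with <|im_end|> but ends earlier assistant turns with a bare newline before the next <|im_start|>user tag; at generation time, the streaming loop stops once it sees <|im_end|> or <|endoftext|> among the decoded tokens.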
requirements.txt ADDED
@@ -0,0 +1,4 @@
+transformers==4.38.2
+accelerate
+bitsandbytes
+optimum
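
flash-attn is absent from requirements.txt because app.py installs it at runtime via the subprocess call, with FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE so a prebuilt wheel is used instead of compiling the CUDA kernels; bitsandbytes backs the 4-bit BitsAndBytesConfig used when loading the model. The app also reads MODEL_ID, CHAT_TEMPLATE, CONTEXT_LENGTH, COLOR, EMOJI, and DESCRIPTION from the environment, values that live in the Space's settings and are not part of this commit. A minimal sketch of plausible values for a local run, all of them assumptions:

# Hypothetical environment for running app.py locally; the real Space
# variables are not shown in this commit, so every value is a guess.
import os

os.environ["MODEL_ID"] = "google/gemma-1.1-7b-it"  # guessed from the Space title
os.environ["CHAT_TEMPLATE"] = "ChatML"             # predict() accepts only "ChatML" or "Mistral Instruct"
os.environ["CONTEXT_LENGTH"] = "8192"              # made-up cap; parsed with int() in app.py
os.environ["COLOR"] = "purple"                     # passed to gr.themes.Soft(primary_hue=...)
os.environ["EMOJI"] = "👑"                         # prefixed to the model name in the title
os.environ["DESCRIPTION"] = "4-bit chat demo"      # shown under the title

These must be set before app.py is imported, since it reads them at module load time.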