ranamhamoud committed
Commit b75125a • 1 Parent(s): cf9b3fc

Update app.py

Files changed (1)
  1. app.py +102 -50
app.py CHANGED
@@ -1,73 +1,125 @@
 
 
  import torch
- from transformers import AutoTokenizer, AutoModelForCausalLM
  from threading import Thread
- from typing import Iterator, List, Tuple
  import gradio as gr

  # Constants
- MAX_INPUT_TOKEN_LENGTH = 4096
  DEFAULT_MAX_NEW_TOKENS = 930

- # Load Models and Tokenizers
- model_id = "meta-llama/Llama-2-7b-hf"
- tokenizer = AutoTokenizer.from_pretrained(model_id)
- tokenizer.pad_token = tokenizer.eos_token
- model_generate = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
- model_edit = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")  # Assuming a different setup or hyperparameters

- # Helper function to process text
- def process_text(text: str) -> str:
-     return text.replace("\n", " ").strip()

- def run_model(input_ids, model, max_new_tokens, top_p, top_k, temperature, repetition_penalty):
-     return model.generate(
-         input_ids=input_ids,
-         max_length=input_ids.shape[1] + max_new_tokens,
-         do_sample=True,
-         top_p=top_p,
-         top_k=top_k,
-         temperature=temperature,
-         num_beams=1,
-         repetition_penalty=repetition_penalty
      )

- def generate_text(mode: str, message: str, chat_history: List[Tuple[str, str]], max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
-                   temperature: float = 0.6, top_p: float = 0.7, top_k: int = 20, repetition_penalty: float = 1.0) -> Iterator[str]:
-     if chat_history is None:
-         chat_history = []
-     conversation = [{"role": "user", "content": user} for user, _ in chat_history]
-     conversation.extend([{"role": "assistant", "content": assistant} for _, assistant in chat_history])
      conversation.append({"role": "user", "content": message})

-     context = "\n".join(f"{entry['role']}: {entry['content']}" for entry in conversation)
-     input_ids = tokenizer(context, return_tensors="pt", padding=True, truncation=True).input_ids.to(model_generate.device)
      if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
          input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
          gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")

-     model = model_edit if mode == 'edit' else model_generate
-     outputs = []
-     t = Thread(target=lambda: outputs.extend(run_model(input_ids, model, max_new_tokens, top_p, top_k, temperature, repetition_penalty)))
      t.start()
-     t.join()

-     for output in outputs:
-         for text in tokenizer.decode(output, skip_special_tokens=True).split():
-             processed_text = process_text(text)
-             yield processed_text

- # Gradio Interface
- def switch_mode(mode: str, message: str, chat_history: List[Tuple[str, str]]):
-     return list(generate_text(mode, message, chat_history))

- with gr.Blocks() as demo:
-     with gr.Row():
-         mode_selector = gr.Radio(["generate", "edit"], label="Mode", value="generate")
-         input_text = gr.Textbox(label="Input Text")
-         output_text = gr.Textbox(label="Output")
-     chat_history = gr.State()  # Corrected 'default' keyword
-
-     generate_button = gr.Button("Generate/Edit")
-     generate_button.click(switch_mode, inputs=[mode_selector, input_text, chat_history], outputs=output_text)

- demo.launch()
+ import os
+ import re
  import torch
  from threading import Thread
+ from typing import Iterator
+ from mongoengine import connect, Document, StringField, SequenceField
  import gradio as gr
+ import spaces
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
+ from peft import PeftModel

  # Constants
+ MAX_MAX_NEW_TOKENS = 2048
  DEFAULT_MAX_NEW_TOKENS = 930
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

+ LICENSE = """
+ ---
+ As a derivative work of [Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) by Meta,
+ this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/USE_POLICY.md).
+ """

+ DESCRIPTION = ""  # placeholder: the committed code appends to DESCRIPTION below without defining it first
+ if not torch.cuda.is_available():
+     DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

+ if torch.cuda.is_available():
+     modelA_id = "meta-llama/Llama-2-7b-chat-hf"
+     bnb_config = BitsAndBytesConfig(
+         load_in_4bit=True,
+         bnb_4bit_use_double_quant=False,
+         bnb_4bit_quant_type="nf4",
+         bnb_4bit_compute_dtype=torch.bfloat16
      )
+     base_model = AutoModelForCausalLM.from_pretrained(modelA_id, device_map="auto", quantization_config=bnb_config)
+     modelA = PeftModel.from_pretrained(base_model, "ranamhamoud/storytell")
+     tokenizerA = AutoTokenizer.from_pretrained(modelA_id)
+     tokenizerA.pad_token = tokenizerA.eos_token
+
+     modelB_id = "meta-llama/Llama-2-7b-chat-hf"
+     modelB = AutoModelForCausalLM.from_pretrained(modelB_id, torch_dtype=torch.float16, device_map="auto")
+     tokenizerB = AutoTokenizer.from_pretrained(modelB_id)
+     tokenizerB.use_default_system_prompt = False
+
+ def make_prompt(entry):
+     return f"### Human: Don't repeat the assessments, limit to 500 words {entry} ### Assistant:"
+
+ @spaces.GPU
+ def generate(
+     message: str,
+     chat_history: list[tuple[str, str]],
+     model: str = "A",
+     system_prompt: str = "",
+     max_new_tokens: int = 1024,
+     temperature: float = 0.6,
+     top_p: float = 0.9,
+     top_k: int = 50,
+     repetition_penalty: float = 1.2,
+ ) -> Iterator[str]:
+     # Argument order follows gr.ChatInterface's fn(message, history, *additional_inputs) convention.
+     if model == "A":
+         model = modelA
+         tokenizer = tokenizerA
+         enc = tokenizer(make_prompt(message), return_tensors="pt", padding=True, truncation=True)
+         input_ids = enc.input_ids.to(model.device)
+     else:
+         model = modelB
+         tokenizer = tokenizerB
+         conversation = []
+         if system_prompt:
+             conversation.append({"role": "system", "content": system_prompt})
+         for user, assistant in chat_history:
+             conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
          conversation.append({"role": "user", "content": message})
+         # Tokenize only after the conversation list is fully built.
+         input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")

      if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
          input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
          gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+     input_ids = input_ids.to(model.device)

+     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+     generate_kwargs = dict(
+         {"input_ids": input_ids},
+         streamer=streamer,
+         max_new_tokens=max_new_tokens,
+         do_sample=True,
+         top_p=top_p,
+         top_k=top_k,
+         temperature=temperature,
+         num_beams=1,
+         repetition_penalty=repetition_penalty,
+     )
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
      t.start()

+     outputs = []
+     for text in streamer:
+         outputs.append(text)
+         yield "".join(outputs)

+ # Gradio Interface Setup
+ chat_interface = gr.ChatInterface(
+     fn=generate,
+     additional_inputs=[gr.Dropdown(choices=["A", "B"], label="Model", value="A")],
+     fill_height=True,
+     stop_btn=None,
+     examples=[
+         ["Can you explain briefly to me what is the Python programming language?"],
+         ["Could you please provide an explanation about the concept of recursion?"],
+         ["Could you explain what a URL is?"]
+     ],
+     theme='shivi/calm_seafoam'
+ )

+ # Gradio Web Interface
+ with gr.Blocks(theme='shivi/calm_seafoam', fill_height=True) as demo:
+     # gr.Markdown(DESCRIPTION)
+     chat_interface.render()
+     gr.Markdown(LICENSE)

+ # Main Execution
+ if __name__ == "__main__":
+     demo.queue(max_size=20)
+     demo.launch(share=True)
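
The core change in this commit swaps the old collect-then-`t.join()` flow for token streaming: `model.generate` runs in a background `Thread` and pushes decoded text into a `TextIteratorStreamer` that the caller iterates. A minimal, self-contained sketch of that pattern is below; the model id and prompt are illustrative only, not taken from this commit.

```python
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "meta-llama/Llama-2-7b-chat-hf"  # illustrative; same base model family as the Space
tok = AutoTokenizer.from_pretrained(model_id)
lm = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

# The streamer decodes tokens as generate() produces them in the background thread.
streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
inputs = tok("Tell me a short story.", return_tensors="pt").to(lm.device)
Thread(target=lm.generate, kwargs=dict(**inputs, streamer=streamer, max_new_tokens=64)).start()

# Iterating the streamer yields text chunks as they arrive, which is what
# generate() in app.py forwards to the Gradio chat UI.
for piece in streamer:
    print(piece, end="", flush=True)
```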