ranamhamoud committed
Commit 4d5d8af • 1 Parent(s): 85f58d4

Update app.py

Files changed (1)
  1. app.py +82 -56
app.py CHANGED
@@ -1,6 +1,5 @@
 import os
 import re
-import logging
 import torch
 from threading import Thread
 from typing import Iterator
@@ -12,7 +11,7 @@ from peft import PeftModel
 
 # Constants
 MAX_MAX_NEW_TOKENS = 2048
-DEFAULT_MAX_NEW_TOKENS = 930
+DEFAULT_MAX_NEW_TOKENS = 1024
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
 LICENSE = """
@@ -21,92 +20,120 @@ As a derivative work of [Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-
 this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/USE_POLICY.md).
 """
 
+# GPU Check and add CPU warning
 if not torch.cuda.is_available():
     DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
 
-if torch.cuda.is_available():
-    modelA_id = "meta-llama/Llama-2-7b-hf"
-    bnb_config = BitsAndBytesConfig(
-        load_in_4bit=True,
-        bnb_4bit_use_double_quant=False,
-        bnb_4bit_quant_type="nf4",
-        bnb_4bit_compute_dtype=torch.bfloat16
-    )
-    base_model = AutoModelForCausalLM.from_pretrained(modelA_id, device_map="auto", quantization_config=bnb_config)
-    modelA = PeftModel.from_pretrained(base_model, "ranamhamoud/storytell")
-    tokenizerA = AutoTokenizer.from_pretrained(modelA_id)
-    tokenizerA.pad_token = tokenizerA.eos_token
+if torch.cuda.is_available():
+
+    # Model and Tokenizer Configuration
+    model_id = "meta-llama/Llama-2-7b-chat-hf"
+    bnb_config = BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_use_double_quant=False,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_compute_dtype=torch.bfloat16
+    )
+    base_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", quantization_config=bnb_config)
+    model = PeftModel.from_pretrained(base_model, "ranamhamoud/storytell")
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    tokenizer.pad_token = tokenizer.eos_token
 
-    modelB_id = "meta-llama/Llama-2-7b-chat-hf"
-    modelB = AutoModelForCausalLM.from_pretrained(modelB_id, torch_dtype=torch.float16, device_map="auto")
-    tokenizerB = AutoTokenizer.from_pretrained(modelB_id)
-    tokenizerB.use_default_system_prompt = False
-    tokenizerB.pad_token = tokenizerB.eos_token
+# # MongoDB Connection
+# PASSWORD = os.environ.get("MONGO_PASS")
+# connect(host=f"mongodb+srv://ranamhammoud11:{PASSWORD}@stories.zf5v52a.mongodb.net/")
 
-
+# # MongoDB Document
+# class Story(Document):
+#     message = StringField()
+#     content = StringField()
+#     story_id = SequenceField(primary_key=True)
 
+# Utility function for prompts
 def make_prompt(entry):
     return f"### Human: Don't repeat the assesments, limit to 500 words {entry} ### Assistant:"
+    # f"TELL A STORY, RELATE TO COMPUTER SCIENCE, INCLUDE ASSESMENTS. MAKE IT REALISTIC AND AROUND 800 WORDS, END THE STORY WITH "THE END.": {entry}"
+
+def process_text(text):
+    # First, handle the specific case for [answer:]
+    # This replaces [answer:] with "Answer:" and keeps the content after it on the same line.
+    text = re.sub(r'\[answer:\]\s*', 'Answer: ', text)
+
+    # Now, remove all other content within brackets.
+    # This regex looks for square brackets and any content inside them, excluding those that start with "Answer: " already modified.
+    text = re.sub(r'\[.*?\](?<!Answer: )', '', text)
+
+    return text
+custom_css = """
+body, input, button, textarea, label {
+    font-family: Arial, sans-serif;
+    font-size: 24px;
+}
+.gr-chat-interface .gr-chat-message-container {
+    font-size: 14px;
+}
+.gr-button {
+    font-size: 14px;
+    padding: 12px 24px;
+}
+.gr-input {
+    font-size: 14px;
+}
+"""
 
+# Gradio Function
 @spaces.GPU
 def generate(
-    model: str,
     message: str,
     chat_history: list[tuple[str, str]],
-    max_new_tokens: int = 1024,
-    # temperature: float = 0.6,
-    # top_p: float = 0.9,
-    # top_k: int = 50,
-    # repetition_penalty: float = 1.2,
+    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
+    temperature: float = 0.6,
+    top_p: float = 0.7,
+    top_k: int = 20,
+    repetition_penalty: float = 1.0,
 ) -> Iterator[str]:
-    if chat_history is None:
-        logging.error("chat_history is None, initializing to empty list.")
-        chat_history = [] # Initialize to an empty list if None is passed
-
     conversation = []
     for user, assistant in chat_history:
         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
-    conversation.append({"role": "user", "content": message})
-    if model == "A":
-        model = modelA
-        tokenizer = tokenizerA
-    else:
-        model = modelB
-        tokenizer = tokenizerB
-
+    conversation.append({"role": "user", "content": make_prompt(message)})
     enc = tokenizer(make_prompt(message), return_tensors="pt", padding=True, truncation=True)
-    input_ids = enc.input_ids.to(model.device)
-
+    input_ids = enc.input_ids.to(model.device)
     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-    input_ids = input_ids.to(model.device)
-
-    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+
+    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=False)
     generate_kwargs = dict(
         {"input_ids": input_ids},
         streamer=streamer,
         max_new_tokens=max_new_tokens,
         do_sample=True,
-        # top_p=top_p,
-        # top_k=top_k,
-        # temperature=temperature,
-        # num_beams=1,
-        # repetition_penalty=repetition_penalty,
+        top_p=top_p,
+        top_k=top_k,
+        temperature=temperature,
+        num_beams=1,
+        repetition_penalty=repetition_penalty,
     )
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()
 
     outputs = []
     for text in streamer:
-        outputs.append(text)
-        yield "".join(outputs)
-    logging.basicConfig(level=logging.DEBUG)
+        processed_text = process_text(text)
+        outputs.append(processed_text)
+        output = "".join(outputs)
+        yield output
+
+    # final_story = "".join(outputs)
+    # try:
+    #     saved_story = Story(message=message, content=final_story).save()
+    #     yield f"{final_story}\n\n Story saved with ID: {saved_story.story_id}"
+    # except Exception as e:
+    #     yield f"Failed to save story: {str(e)}"
 
 # Gradio Interface Setup
 chat_interface = gr.ChatInterface(
     fn=generate,
-    additional_inputs=[gr.Dropdown(["A", "B"],label="Model", info="Will add more animals later!")],
     fill_height=True,
     stop_btn=None,
     examples=[
@@ -118,13 +145,12 @@ chat_interface = gr.ChatInterface(
 )
 
 # Gradio Web Interface
-with gr.Blocks(theme='shivi/calm_seafoam',fill_height=True) as demo:
-    # gr.Markdown(DESCRIPTION)
+with gr.Blocks(css=custom_css,theme='shivi/calm_seafoam',fill_height=True) as demo:
     chat_interface.render()
-    gr.Markdown(LICENSE)
+    # gr.Markdown(LICENSE)
 
 
 # Main Execution
 if __name__ == "__main__":
     demo.queue(max_size=20)
-    demo.launch(share=True)
+    demo.launch(share=True)
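
For reference, the process_text helper added in this commit can be exercised on its own. Below is a minimal, self-contained sketch: the function body is copied from the diff above, the inline comments are paraphrased, and the sample input string is hypothetical, invented only to illustrate how the two regular expressions behave.

import re

# process_text as added in this commit (body copied from app.py; comments paraphrased).
def process_text(text):
    # Replace the "[answer:]" tag with "Answer: ", keeping the text that follows it.
    text = re.sub(r'\[answer:\]\s*', 'Answer: ', text)
    # Drop any other bracketed span; the whitespace around it is left as-is.
    text = re.sub(r'\[.*?\](?<!Answer: )', '', text)
    return text

if __name__ == "__main__":
    # Hypothetical streamed chunk, for illustration only.
    sample = "Sam traced the bug [hint: check the loop bounds] and found it. [answer:] the loop was off by one."
    print(process_text(sample))
    # Prints: "Sam traced the bug  and found it. Answer: the loop was off by one."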