ranamhamoud committed
Commit 297485e (1 parent: 1661753)

Update app.py

Files changed (1):
  app.py (+41, -38)
app.py CHANGED
@@ -1,82 +1,82 @@
 import os
+import torch
 from threading import Thread
 from typing import Iterator
-
 from mongoengine import connect, Document, StringField, SequenceField
-
 import gradio as gr
 import spaces
-import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 from peft import PeftModel
 
+# Constants
 MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 1024
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
-DESCRIPTION = """\
+# Description and License Texts
+DESCRIPTION = """
 # ✨Storytell AI🧑🏽‍💻
-Welcome to the **Storytell AI** space, crafted with care by Ranam & George. Dive into the world of educational storytelling with our [Storytell](https://huggingface.co/ranamhamoud/storytell) model. This iteration of the Llama 2 model with 7 billion parameters is fine-tuned to generate educational stories that engage and educate. Enjoy a journey of discovery and creativity—your storytelling lesson begins here! You can prompt this model to explain any computer science concept. **Please check the examples below**.
+Welcome to the **Storytell AI** space, crafted with care by Ranam & George. Dive into the world of educational storytelling with our model. This iteration of the Llama 2 model with 7 billion parameters is fine-tuned to generate educational stories that engage and educate. Enjoy a journey of discovery and creativity—your storytelling lesson begins here! You can prompt this model to explain any computer science concept. **Please check the examples below**.
 """
-
-
 LICENSE = """
-<p/>
 ---
-As a derivate work of [Llama-2-7b-chat](https://huggingface.co/meta-llama/Llama-2-7b-chat) by Meta,
+As a derivative work of [Llama-2-7b-chat](https://huggingface.co/meta-llama/Llama-2-7b-chat) by Meta,
 this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/USE_POLICY.md).
 """
 
+# GPU Check and add CPU warning
 if not torch.cuda.is_available():
     DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
 
-if torch.cuda.is_available():
-    bnb_config = BitsAndBytesConfig(
-        load_in_8bit=True,
-        bnb_4bit_compute_dtype=torch.float16,
-    )
-    model_id = "meta-llama/Llama-2-7b-chat-hf"
-    base_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto",quantization_config=bnb_config)
-    model = PeftModel.from_pretrained(base_model,"ranamhamoud/storytell")
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
-    tokenizer.pad_token = tokenizer.eos_token
-
+# Model and Tokenizer Configuration
+model_id = "meta-llama/Llama-2-7b-chat-hf"
+bnb_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_use_double_quant=False,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_compute_dtype=torch.bfloat16
+)
+base_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", quantization_config=bnb_config)
+model = PeftModel.from_pretrained(base_model, "ranamhamoud/storytell")
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+tokenizer.pad_token = tokenizer.eos_token
 
+# MongoDB Connection
 PASSWORD = os.environ.get("MONGO_PASS")
-connect(host = f"mongodb+srv://ranamhammoud11:{PASSWORD}@stories.zf5v52a.mongodb.net/")
+connect(host=f"mongodb+srv://ranamhammoud11:{PASSWORD}@stories.zf5v52a.mongodb.net/")
 
+# MongoDB Document
 class Story(Document):
     message = StringField()
     content = StringField()
     story_id = SequenceField(primary_key=True)
-
+
+# Utility function for prompts
 def make_prompt(entry):
-    return f"TELL A STORY,RELATE TO COMPUTER SCIENCE,INCLUDE ASSESMENTS. MAKE IT REALISTIC AND AROUND 800 WORDS: {entry} "
-
+    return f"### Human: {entry} ### Assistant:"
+    # f"TELL A STORY, RELATE TO COMPUTER SCIENCE, INCLUDE ASSESMENTS. MAKE IT REALISTIC AND AROUND 800 WORDS: {entry}"
+
+# Gradio Function
 @spaces.GPU
 def generate(
     message: str,
     chat_history: list[tuple[str, str]],
-    max_new_tokens: int = 1024,
-    temperature: float = 0.1,
-    top_p: float = 0.6,
-    top_k: int = 20,
+    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
+    temperature: float = 0.3,
+    top_p: float = 0.7,
+    top_k: int = 20,
     repetition_penalty: float = 1.0,
 ) -> Iterator[str]:
     conversation = []
     for user, assistant in chat_history:
         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
     conversation.append({"role": "user", "content": make_prompt(message)})
-
     enc = tokenizer(make_prompt(message), return_tensors="pt", padding=True, truncation=True)
-
-
-    input_ids = enc.input_ids
+    input_ids = enc.input_ids.to(model.device)
     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-    input_ids = input_ids.to(model.device)
-
+
     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
         {"input_ids": input_ids},
@@ -98,26 +98,29 @@ def generate(
         yield "".join(outputs)
     final_story = "".join(outputs)
     try:
-        saved_story = Story(message=message, content=final_story).save()
+        saved_story = Story(message=message, content=final_story).save()
         yield f"{final_story}\n\n Story saved with ID: {saved_story.story_id}"
     except Exception as e:
         yield f"Failed to save story: {str(e)}"
 
+# Gradio Interface Setup
 chat_interface = gr.ChatInterface(
     fn=generate,
     stop_btn=None,
     examples=[
         ["Can you explain briefly to me what is the Python programming language?"],
-        ["Could you please provide an explanation about the concept of recursion?"],
+        ["Could you please provide an explanation about the concept of recursion?"],
        ["Could you explain what a URL is?"]
     ],
 )
 
+# Gradio Web Interface
 with gr.Blocks(css="style.css") as demo:
     gr.Markdown(DESCRIPTION)
     chat_interface.render()
     gr.Markdown(LICENSE)
 
+# Main Execution
 if __name__ == "__main__":
     demo.queue(max_size=20)
-    demo.launch(share=True)
+    demo.launch(share=True)
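
The substantive change in this commit is the quantization setup. The old code loaded the base model in 8-bit (load_in_8bit=True) while passing a bnb_4bit_compute_dtype option that 8-bit mode ignores, and it only built the model inside an if torch.cuda.is_available() guard, leaving model and tokenizer undefined on CPU. The new code loads unconditionally with 4-bit NF4 weights and bfloat16 compute. Below is a minimal sketch of the new configuration in isolation, assuming a CUDA machine, bitsandbytes installed, and access to the gated meta-llama/Llama-2-7b-chat-hf weights; the footprint figures are rough expectations, not measurements.

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model_id = "meta-llama/Llama-2-7b-chat-hf"  # gated repo: requires an accepted license and HF token

# The commit's new config: weights stored as 4-bit NF4, matmuls computed in bfloat16.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=False,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    quantization_config=nf4_config,
)
# Ballpark: a 7B model takes roughly ~7 GB in 8-bit vs ~4 GB in 4-bit NF4.
print(f"footprint: {model.get_memory_footprint() / 1e9:.1f} GB")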
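The prompt template also changed: make_prompt no longer prepends the all-caps storytelling instruction (kept as a comment in the new code) and instead wraps the message in a "### Human: / ### Assistant:" turn format, presumably the template the ranamhamoud/storytell adapter was fine-tuned on. A quick illustration of what the model now receives:

def make_prompt(entry):
    return f"### Human: {entry} ### Assistant:"

print(make_prompt("Could you explain what a URL is?"))
# prints: ### Human: Could you explain what a URL is? ### Assistant: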
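One regression to note: the new import at line 8 drops TextIteratorStreamer even though line 80 still constructs one, so the file as committed would raise a NameError; TextIteratorStreamer needs to stay in the transformers import. The generation call itself sits in the elided lines 83-97. The sketch below is hypothetical (stream_completion is an invented name, and the body follows the stock Llama 2 chat Space streaming pattern rather than the file's actual hidden code): model.generate blocks, so it runs on a worker thread while the streamer yields decoded chunks back to the Gradio generator.

from threading import Thread
from transformers import TextIteratorStreamer  # must remain imported; used at line 80

def stream_completion(model, tokenizer, input_ids, max_new_tokens=1024,
                      temperature=0.3, top_p=0.7, top_k=20, repetition_penalty=1.0):
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        {"input_ids": input_ids},
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        num_beams=1,
    )
    # generate() blocks until done, so run it off-thread and stream from here.
    Thread(target=model.generate, kwargs=generate_kwargs).start()
    outputs = []
    for text in streamer:  # yields decoded text chunks as tokens arrive
        outputs.append(text)
        yield "".join(outputs)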