ranamhamoud committed
Commit f317c15
1 parent: bcb8e3f

Update app.py
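
This commit replaces the streaming Storytell chat app (a gr.ChatInterface backed by TextIteratorStreamer generation and MongoDB story persistence) with a simpler two-mode Gradio demo: a generate mode and an edit mode that, for now, share the same PEFT model, switched by a checkbox. It also changes the base checkpoint from meta-llama/Llama-2-7b-chat-hf to meta-llama/Llama-2-7b-hf.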

Files changed (1): app.py (+45 -131)
app.py CHANGED
@@ -1,36 +1,16 @@
  import os
- import re
  import torch
- from threading import Thread
- from typing import Iterator
- from mongoengine import connect, Document, StringField, SequenceField
- import gradio as gr
- import spaces
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
  from peft import PeftModel

  # Constants
- MAX_MAX_NEW_TOKENS = 2048
- DEFAULT_MAX_NEW_TOKENS = 930
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+ DEFAULT_MAX_NEW_TOKENS = 930
+ import gradio as gr
+ from typing import Iterator, List, Tuple

- # # Description and License Texts
- # DESCRIPTION = """
- # # ✨Storytell AI🧑🏽‍💻
- # Welcome to the **Storytell AI** space, crafted with care by Ranam & George. Dive into the world of educational storytelling with our model. This iteration of the Llama 2 model with 7 billion parameters is fine-tuned to generate educational stories that engage and educate. Enjoy a journey of discovery and creativity—your storytelling lesson begins here! You can prompt this model to explain any computer science concept. **Please check the examples below**.
- # """
- LICENSE = """
- ---
- As a derivative work of [Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) by Meta,
- this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/USE_POLICY.md).
- """
-
- # GPU Check and add CPU warning
- if not torch.cuda.is_available():
-     DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
-
- # Model and Tokenizer Configuration
- model_id = "meta-llama/Llama-2-7b-chat-hf"
+ # Model Configuration for Generating Mode
+ model_id = "meta-llama/Llama-2-7b-hf"
  bnb_config = BitsAndBytesConfig(
      load_in_4bit=True,
      bnb_4bit_use_double_quant=False,
@@ -38,117 +18,51 @@ bnb_config = BitsAndBytesConfig(
      bnb_4bit_compute_dtype=torch.bfloat16
  )
  base_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", quantization_config=bnb_config)
- model = PeftModel.from_pretrained(base_model, "ranamhamoud/storytell")
+ model_generate = PeftModel.from_pretrained(base_model, "ranamhamoud/storytell")
  tokenizer = AutoTokenizer.from_pretrained(model_id)
  tokenizer.pad_token = tokenizer.eos_token

- # MongoDB Connection
- PASSWORD = os.environ.get("MONGO_PASS")
- connect(host=f"mongodb+srv://ranamhammoud11:{PASSWORD}@stories.zf5v52a.mongodb.net/")
-
- # MongoDB Document
- class Story(Document):
-     message = StringField()
-     content = StringField()
-     story_id = SequenceField(primary_key=True)
-
- # Utility function for prompts
- def make_prompt(entry):
-     return f"### Human: Don't repeat the assesments, limit to 500 words {entry} ### Assistant:"
-     # f"TELL A STORY, RELATE TO COMPUTER SCIENCE, INCLUDE ASSESMENTS. MAKE IT REALISTIC AND AROUND 800 WORDS, END THE STORY WITH "THE END.": {entry}"
-
- def process_text(text):
-     text = re.sub(r'\[.*?\]', '', text, flags=re.DOTALL)
-
-     return text
- custom_css = """
- body, input, button, textarea, label {
-     font-family: Arial, sans-serif;
-     font-size: 24px;
- }
- .gr-chat-interface .gr-chat-message-container {
-     font-size: 14px;
- }
- .gr-button {
-     font-size: 14px;
-     padding: 12px 24px;
- }
- .gr-input {
-     font-size: 14px;
- }
- """
-
- # Gradio Function
- @spaces.GPU
- def generate(
-     message: str,
-     chat_history: list[tuple[str, str]],
-     max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
-     temperature: float = 0.6,
-     top_p: float = 0.7,
-     top_k: int = 20,
-     repetition_penalty: float = 1.0,
- ) -> Iterator[str]:
-     conversation = []
-     for user, assistant in chat_history:
-         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
-     conversation.append({"role": "user", "content": make_prompt(message)})
-     enc = tokenizer(make_prompt(message), return_tensors="pt", padding=True, truncation=True)
-     input_ids = enc.input_ids.to(model.device)
-     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-
-     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=False)
-     generate_kwargs = dict(
-         {"input_ids": input_ids},
-         streamer=streamer,
-         max_new_tokens=max_new_tokens,
-         do_sample=True,
-         top_p=top_p,
-         top_k=top_k,
-         temperature=temperature,
-         num_beams=1,
-         repetition_penalty=repetition_penalty,
-     )
-     t = Thread(target=model.generate, kwargs=generate_kwargs)
-     t.start()
-
-     outputs = []
-     for text in streamer:
-         processed_text = process_text(text)
-         outputs.append(processed_text)
-         output = "".join(outputs)
+ # Editing mode uses the same tokenizer but might use a simpler or different model setup
+ model_edit = model_generate # For simplicity, using the same model setup for editing in this example
+
+ # Helper Functions
+ def generate_text(input_text: str, chat_history: List[Tuple[str, str]], max_tokens: int = DEFAULT_MAX_NEW_TOKENS) -> Iterator[str]:
+     # Append the new message to the chat history for context
+     chat_history.append(("user", input_text))
+     # Prepare the input with the conversation context
+     context = "\n".join([f"{speaker}: {text}" for speaker, text in chat_history])
+     input_ids = tokenizer(context, return_tensors="pt").input_ids.to(model_generate.device)
+     outputs = model_generate.generate(input_ids, max_length=input_ids.shape[1] + max_tokens, do_sample=True)
+     for output in tokenizer.decode(outputs[0], skip_special_tokens=True).split():
          yield output
+     chat_history.append(("assistant", tokenizer.decode(outputs[0], skip_special_tokens=True)))

-     final_story = "".join(outputs)
-     try:
-         saved_story = Story(message=message, content=final_story).save()
-         yield f"{final_story}\n\n Story saved with ID: {saved_story.story_id}"
-     except Exception as e:
-         yield f"Failed to save story: {str(e)}"
-
- # Gradio Interface Setup
- chat_interface = gr.ChatInterface(
-     fn=generate,
-     fill_height=True,
-     stop_btn=None,
-     examples=[
-         ["Can you explain briefly to me what is the Python programming language?"],
-         ["Could you please provide an explanation about the concept of recursion?"],
-         ["Could you explain what a URL is?"]
-     ],
-     theme='shivi/calm_seafoam'
- )
-
- # Gradio Web Interface
- with gr.Blocks(css=custom_css,theme='shivi/calm_seafoam',fill_height=True) as demo:
-     # gr.Markdown(DESCRIPTION)
-     chat_interface.render()
-     gr.Markdown(LICENSE)
+ def edit_text(input_text: str, chat_history: List[Tuple[str, str]]) -> Iterator[str]:
+     context = "\n".join([f"{speaker}: {text}" for speaker, text in chat_history])
+     input_ids = tokenizer(context, return_tensors="pt").input_ids.to(model_edit.device)
+     outputs = model_edit.generate(input_ids, max_length=input_ids.shape[1] + DEFAULT_MAX_NEW_TOKENS, do_sample=True)
+     for output in tokenizer.decode(outputs[0], skip_special_tokens=True).split():
+         yield output

+ # Gradio Interface
+ def switch_mode(is_editing: bool, input_text: str, chat_history: List[Tuple[str, str]]) -> Iterator[str]:
+     if is_editing and chat_history:
+         return edit_text(input_text, chat_history)
+     elif not is_editing:
+         return generate_text(input_text, chat_history)
+     else:
+         yield "Chat history is empty, cannot edit."
+
+ with gr.Blocks() as demo:
+     with gr.Row():
+         input_text = gr.Textbox(label="Input Text")
+         is_editing = gr.Checkbox(label="Editing Mode", value=False)
+     output_text = gr.Textbox(label="Output", interactive=True)
+     chat_history = gr.State([]) # Using State to maintain chat history
+
+     generate_button = gr.Button("Generate/Edit")
+     generate_button.click(switch_mode, inputs=[is_editing, input_text, chat_history], outputs=output_text)

  # Main Execution
  if __name__ == "__main__":
-     demo.queue(max_size=20)
-     demo.launch(share=True)
+     demo.launch()
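
A note on the new control flow: because `yield` appears in the body of `switch_mode`, Python treats the whole function as a generator, so the `return edit_text(...)` and `return generate_text(...)` branches end the generator without emitting anything; only the empty-history branch ever produces output for the Textbox. A minimal sketch of a fix, assuming the names from this commit, is to delegate with `yield from`:

def switch_mode(is_editing: bool, input_text: str, chat_history: List[Tuple[str, str]]) -> Iterator[str]:
    # A bare `return <generator>` inside a generator function only stops
    # iteration; `yield from` forwards the inner generator's values to
    # the Gradio output as they are produced.
    if is_editing and chat_history:
        yield from edit_text(input_text, chat_history)
    elif not is_editing:
        yield from generate_text(input_text, chat_history)
    else:
        yield "Chat history is empty, cannot edit."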
 
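Relatedly, `generate_text` and `edit_text` stream by iterating over `tokenizer.decode(...).split()`: whitespace and line breaks are discarded, each yield replaces the Textbox contents with a single bare word, and the decoded text echoes the prompt context because the whole output sequence is decoded. A sketch of one way to stream readable text under the same setup (`generate_text_streamed` is a hypothetical name; `max_new_tokens` is the standard `generate` argument that bounds only the newly generated tokens):

def generate_text_streamed(input_text: str, chat_history: List[Tuple[str, str]], max_tokens: int = DEFAULT_MAX_NEW_TOKENS) -> Iterator[str]:
    chat_history.append(("user", input_text))
    context = "\n".join(f"{speaker}: {text}" for speaker, text in chat_history)
    input_ids = tokenizer(context, return_tensors="pt").input_ids.to(model_generate.device)
    outputs = model_generate.generate(input_ids, max_new_tokens=max_tokens, do_sample=True)
    # Decode only the newly generated tokens so the prompt context is not echoed back.
    reply = tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
    words = reply.split()
    for i in range(1, len(words) + 1):
        # Yield the accumulated text, since each yield replaces the Textbox value.
        yield " ".join(words[:i])
    chat_history.append(("assistant", reply))

For true token-level streaming, the `TextIteratorStreamer` plus background `Thread` pattern that this commit removes would still work here. Note also that `MAX_INPUT_TOKEN_LENGTH` remains defined but is no longer enforced anywhere in the new version.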