ranamhamoud committed
Commit 4656d45 • 1 Parent(s): e174867

Update app.py

Files changed (1): app.py +25 -52
app.py CHANGED
@@ -29,26 +29,19 @@ this demo is governed by the original [license](https://huggingface.co/spaces/hu
 if not torch.cuda.is_available():
     DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
 
-if torch.cuda.is_available():
-    model_id = "meta-llama/Llama-2-7b-hf"
-    bnb_config = BitsAndBytesConfig(
-        load_in_4bit=True,
-        bnb_4bit_use_double_quant=False,
-        bnb_4bit_quant_type="nf4",
-        bnb_4bit_compute_dtype=torch.bfloat16
-    )
-    base_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", quantization_config=bnb_config)
-    storytell_model = PeftModel.from_pretrained(base_model, "ranamhamoud/storytell")
-    storytell_tokenizer = AutoTokenizer.from_pretrained(model_id)
-    storytell_tokenizer.pad_token = storytell_tokenizer.eos_token
-
-
-    editing_model_id = "meta-llama/Llama-2-7b-chat-hf"
-    editing_model = AutoModelForCausalLM.from_pretrained(editing_model_id, torch_dtype=torch.float16, device_map="auto")
-    editing_tokenizer = AutoTokenizer.from_pretrained(model_id)
-    editing_tokenizer.use_default_system_prompt = False
+# Model and Tokenizer Configuration
+model_id = "meta-llama/Llama-2-7b-hf"
+bnb_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_use_double_quant=False,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_compute_dtype=torch.bfloat16
+)
+base_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", quantization_config=bnb_config)
+model = PeftModel.from_pretrained(base_model, "ranamhamoud/storytell")
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+tokenizer.pad_token = tokenizer.eos_token
 
-
 # MongoDB Connection
 PASSWORD = os.environ.get("MONGO_PASS")
 connect(host=f"mongodb+srv://ranamhammoud11:{PASSWORD}@stories.zf5v52a.mongodb.net/")
@@ -69,10 +62,9 @@ def process_text(text):
 
     return text
 
-
+# Gradio Function
 @spaces.GPU
 def generate(
-    model_choice: str,
     message: str,
     chat_history: list[tuple[str, str]],
     max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
@@ -81,38 +73,19 @@ def generate(
     top_k: int = 20,
     repetition_penalty: float = 1.0,
 ) -> Iterator[str]:
-    if chat_history is None:
-        chat_history = []
-
     conversation = []
-    if model_choice == "Storytell":
-        model = storytell_model
-        tokenizer = storytell_tokenizer
-    else:
-        model = editing_model
-        tokenizer = editing_tokenizer
-
-    # Checking each tuple in chat_history to ensure it has exactly two elements
-    for item in chat_history:
-        if isinstance(item, tuple) and len(item) == 2:
-            user, assistant = item
-            conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
-        else:
-            print(f"Error in chat history item: {item}. Each item must be a tuple with exactly two elements.")
-            continue  # Skip this item or handle appropriately
-
-    # Append the current user message
-    conversation.append({"role": "user", "content": message})
-
-    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
+    for user, assistant in chat_history:
+        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+    conversation.append({"role": "user", "content": make_prompt(message)})
+    enc = tokenizer(make_prompt(message), return_tensors="pt", padding=True, truncation=True)
+    input_ids = enc.input_ids.to(model.device)
     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-    input_ids = input_ids.to(model.device)
-
-    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+
+    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=False)
     generate_kwargs = dict(
-        input_ids=input_ids,
+        {"input_ids": input_ids},
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
@@ -127,9 +100,10 @@ def generate(
 
     outputs = []
     for text in streamer:
-        outputs.append(text)
-        yield "".join(outputs)
-
+        processed_text = process_text(text)
+        outputs.append(processed_text)
+        output = "".join(outputs)
+        yield output
 
     final_story = "".join(outputs)
     try:
@@ -142,7 +116,6 @@ def generate(
 chat_interface = gr.ChatInterface(
     fn=generate,
     stop_btn=None,
-    additional_inputs=[gr.Dropdown(["Storytell", "HF Meta Llama 7b Chat"], label="Choose Model")],
     examples=[
         ["Can you explain briefly to me what is the Python programming language?"],
         ["Could you please provide an explanation about the concept of recursion?"],
 
 
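The new code also calls make_prompt and process_text, both defined outside the visible hunks. The stand-ins below are hypothetical guesses at their roles, written only so the diff can be read end to end; they are not the Space's implementations.

# Hypothetical stand-ins for helpers defined outside the visible hunks.
import re

def make_prompt(message: str) -> str:
    # Guess: wrap the raw message in Llama-2-style instruction tags.
    return f"[INST] {message} [/INST]"

def process_text(text: str) -> str:
    # Guess: scrub special tokens and instruction tags from streamed chunks,
    # since this commit sets skip_special_tokens=False on the streamer.
    return re.sub(r"</?s>|\[/?INST\]", "", text)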