shljessie committed
Commit 6d36ca7
1 Parent(s): bbf1954

update model

Files changed (2):
  1. app.py +84 -48
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,43 +1,24 @@
 import os
-import threading
+from threading import Thread
+from typing import Iterator
 import gradio as gr
+from typing import List, Tuple
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+import spaces

-# # Check if CUDA is available
-# if not torch.cuda.is_available():
-#     raise EnvironmentError("CUDA is not available. This script requires a GPU.")
+MAX_INPUT_TOKEN_LENGTH = 50

-# Model Configuration
-# MODEL_ID = "meta-llama/Llama-2-7b-chat"
-# MAX_INPUT_TOKEN_LENGTH = 4096
-# MAX_NEW_TOKENS = 1024
-# TEMPERATURE = 0.6
-# TOP_P = 0.9
-# TOP_K = 50
-# REPETITION_PENALTY = 1.2
-
-# # Load the model and tokenizer
-# model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float16, device_map="auto")
-# tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
-
-# def generate_response(user_input):
-#     """
-#     Generate a response to the user input using the Llama-2 7B model.
-#     """
-#     input_ids = tokenizer.encode(user_input, return_tensors="pt")
-#     input_ids = input_ids.to(model.device)
-
-#     # Generate a response
-#     output = model.generate(input_ids, max_length=MAX_INPUT_TOKEN_LENGTH + len(input_ids[0]),
-#                             max_new_tokens=MAX_NEW_TOKENS, temperature=TEMPERATURE,
-#                             top_k=TOP_K, top_p=TOP_P, repetition_penalty=REPETITION_PENALTY)
-
-#     response = tokenizer.decode(output[0], skip_special_tokens=True)
-#     return response
-
-# def chatbot_interface(user_input):
-#     return generate_response(user_input)
+LICENSE = """
+<p/>
+---
+A derivative work of [ConsistentAgents]() by Seonghee Lee.
+"""
+if torch.cuda.is_available():
+    model_id = "shljessie/profile-model-69"
+    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    tokenizer.use_default_system_prompt = False

 def yes_man(message, history):
     if message.endswith("?"):
@@ -45,20 +26,75 @@ def yes_man(message, history):
     else:
         return "Ask me anything!"

+@spaces.GPU
+def generate(
+    message: str,
+    chat_history: List[Tuple[str, str]],
+    system_prompt: str,
+    max_new_tokens: int = 1024,
+    temperature: float = 0.6,
+    top_p: float = 0.9,
+    top_k: int = 50,
+    repetition_penalty: float = 1.2,
+) -> Iterator[str]:
+    conversation = []
+    if system_prompt:
+        conversation.append({"role": "system", "content": system_prompt})
+    for user, assistant in chat_history:
+        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+    conversation.append({"role": "user", "content": message})
+
+    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
+    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+    input_ids = input_ids.to(model.device)
+
+    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+    generate_kwargs = dict(
+        {"input_ids": input_ids},
+        streamer=streamer,
+        max_new_tokens=max_new_tokens,
+        do_sample=True,
+        top_p=top_p,
+        top_k=top_k,
+        temperature=temperature,
+        num_beams=1,
+        repetition_penalty=repetition_penalty,
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+
+    outputs = []
+    for text in streamer:
+        outputs.append(text)
+        yield "".join(outputs)


-
 # Create the Gradio interface
-gr.ChatInterface(
-    yes_man,
-    chatbot=gr.Chatbot(height=300),
-    textbox=gr.Textbox(placeholder="Ask me a yes or no question", container=False, scale=7),
-    title="Yes Man",
-    description="Ask Yes Man any question",
-    theme="soft",
-    examples=["Hello", "Am I cool?", "Are tomatoes vegetables?"],
-    cache_examples=True,
-    retry_btn=None,
-    undo_btn="Delete Previous",
-    clear_btn="Clear",
-).launch()
+# gr.ChatInterface(
+#     yes_man,
+#     chatbot=gr.Chatbot(height=300),
+#     textbox=gr.Textbox(placeholder="Ask me a yes or no question", container=False, scale=7),
+#     title="Yes Man",
+#     description="Ask Yes Man any question",
+#     theme="soft",
+#     examples=["Hello", "Am I cool?", "Are tomatoes vegetables?"],
+#     cache_examples=True,
+#     retry_btn=None,
+#     undo_btn="Delete Previous",
+#     clear_btn="Clear",
+# ).launch()
+
+
+chat_interface = gr.ChatInterface(
+    fn=generate,
+    additional_inputs=[
+        gr.Textbox(label="System prompt", lines=6),
+    ],
+)
+
+with gr.Blocks(css="style.css") as demo:
+    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
+    chat_interface.render()
+    gr.Markdown(LICENSE)
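The heart of this change is the streaming pattern in `generate`: `model.generate` blocks until it finishes, so it runs on a worker thread while `TextIteratorStreamer` yields decoded text as tokens arrive, and each accumulated partial string is yielded to `gr.ChatInterface`. Below is a minimal standalone sketch of that pattern; the tiny model `sshleifer/tiny-gpt2` is an assumption for illustration only, not the Space's `shljessie/profile-model-69`.

```python
# Sketch only: the thread + TextIteratorStreamer pattern from app.py above,
# shrunk to a standalone script. The model is a stand-in, not the Space's model.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "sshleifer/tiny-gpt2"  # assumption: any small causal LM illustrates the pattern
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

input_ids = tokenizer("Ask me anything", return_tensors="pt").input_ids

# The streamer is an iterator of decoded text chunks; skip_prompt drops the echoed input.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until done, so it runs on a worker thread while the main
# thread consumes the stream; this is what lets the Gradio UI render partial output.
thread = Thread(
    target=model.generate,
    kwargs=dict(input_ids=input_ids, streamer=streamer, max_new_tokens=20, do_sample=True),
)
thread.start()

partial = []
for chunk in streamer:
    partial.append(chunk)
    print("".join(partial))  # app.py yields this accumulated string instead
thread.join()
```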
requirements.txt CHANGED
@@ -1,2 +1,3 @@
 torch
-transformers
+transformers
+spaces
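The new `spaces` dependency exists for the `@spaces.GPU` decorator added in app.py: on a ZeroGPU Space, a GPU is attached only for the duration of a decorated call. A minimal sketch, assuming the package's documented behavior of acting as a pass-through no-op outside a Hugging Face Space:

```python
# Sketch of the ZeroGPU decorator the new `spaces` requirement provides.
# Assumption: outside a Hugging Face Space, @spaces.GPU is a harmless no-op.
import spaces  # imported before any CUDA initialization, as ZeroGPU expects
import torch

@spaces.GPU  # on ZeroGPU hardware, CUDA is available only inside this call
def device_report() -> str:
    return "cuda" if torch.cuda.is_available() else "cpu"

print(device_report())
```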