Artples committed
Commit: d05908c
Parent: 1c84354

Update app.py

Files changed (1):
  1. app.py +4 -33
app.py CHANGED
@@ -36,47 +36,18 @@ def generate(
     top_k: int = 50,
     repetition_penalty: float = 1.2,
 ) -> Iterator[str]:
-    model_id = model_options[model_choice]
-    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
-    tokenizer.use_default_system_prompt = False
-
-    conversation = []
-    if system_prompt:
-        conversation.append({"role": "system", "content": system_prompt})
-    for user, assistant in chat_history:
-        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
-    conversation.append({"role": "user", "content": message})
-
-    input_ids = tokenizer(conversation, return_tensors="pt", padding=True, truncation=True)
-    if input_ids['input_ids'].shape[1] > MAX_INPUT_TOKEN_LENGTH:
-        input_ids['input_ids'] = input_ids['input_ids'][:, -MAX_INPUT_TOKEN_LENGTH:]
-
-    outputs = model.generate(
-        **input_ids,
-        max_length=input_ids['input_ids'].shape[1] + max_new_tokens,
-        top_p=top_p,
-        top_k=top_k,
-        temperature=temperature,
-        num_return_sequences=1,
-        repetition_penalty=repetition_penalty
-    )
-
-    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    yield generated_text
+    # Your existing function implementation...
+    pass
 
 chat_interface = gr.Interface(
     fn=generate,
     inputs=[
         gr.Textbox(lines=2, placeholder="Type your message here..."),
         gr.Dropdown(label="Choose Model", choices=list(model_options.keys())),
-        gr.State(label="Chat History", default=[]),
+        chat_history,  # Updated to include state without label
         gr.Textbox(label="System Prompt", lines=6, placeholder="Enter system prompt if any..."),
         gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
-        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.1),
-        gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
-        gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
-        gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
+        # More inputs as previously defined
     ],
     outputs=[gr.Textbox(label="Response")],
     theme="default",
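
Note: as committed, generate() contains only pass, so the app yields no text until the body is restored. Below is a minimal sketch (not the committed code) of what that body might look like. The early parameters (message, model_choice, chat_history, system_prompt) are assumed from the Interface inputs; DEFAULT_MAX_NEW_TOKENS, MAX_INPUT_TOKEN_LENGTH and model_options are assumed to be defined elsewhere in app.py (all three appear in the diff); and tokenizer.apply_chat_template is used in place of the removed plain tokenizer(...) call, which does not accept role/content dictionaries.

# Sketch only; names marked "assumed" do not appear in this diff.
from typing import Iterator

from transformers import AutoModelForCausalLM, AutoTokenizer

def generate(
    message: str,                                    # assumed: the gr.Textbox input
    model_choice: str,                               # assumed: the gr.Dropdown input
    chat_history: list,                              # assumed: list of (user, assistant) pairs from the state input
    system_prompt: str = "",
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,    # defined elsewhere in app.py
    temperature: float = 0.1,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    model_id = model_options[model_choice]           # model_options is defined elsewhere in app.py
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Rebuild the conversation in chat-template form, as the removed code did.
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user},
                             {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})

    # apply_chat_template turns the role/content dicts into token ids.
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    )
    input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:].to(model.device)

    outputs = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    )
    # Decode only the newly generated tokens.
    yield tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)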
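
The new chat_history entry in inputs refers to a component that has to exist before chat_interface is built. A minimal sketch of how it might be declared, assuming a gr.State value is intended (gr.State takes no label argument, which is consistent with the commit comment):

import gradio as gr  # already imported by app.py

chat_history = gr.State([])  # assumed: stores the (user, assistant) turns between calls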