Remostart committed
Commit 18231b8 · verified · 1 Parent(s): 129037c

Update app.py

Files changed (1)
app.py  +17 -9
app.py CHANGED
@@ -2,10 +2,15 @@ import gradio as gr
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
 
-# Load the fine-tuned Llama-3-8B model and tokenizer
-model_name = "ubiodee/Test_Plutus"
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
+# Load the fine-tuned Llama-3-8B model and tokenizer for ubiodee/plutus_llm
+model_name = "ubiodee/plutus_llm"
+tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)  # Safeguard against fast tokenizer issues
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    torch_dtype=torch.float16,
+    device_map="auto",
+    load_in_8bit=True  # Enable 8-bit quantization as per model specs
+)
 
 # Set padding token if not already set
 if tokenizer.pad_token is None:
@@ -23,25 +28,28 @@ def generate_text(prompt, max_length=200, temperature=0.7, top_p=0.9):
         temperature=temperature,
         top_p=top_p,
         do_sample=True,
-        num_return_sequences=1
+        num_return_sequences=1,
+        pad_token_id=tokenizer.eos_token_id
     )
 
     # Decode the generated text
     generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    # Remove the input prompt from the output for a cleaner response
+    generated_text = generated_text[len(prompt):].strip()
     return generated_text
 
 # Create Gradio interface
 demo = gr.Interface(
     fn=generate_text,
     inputs=[
-        gr.Textbox(label="Input Prompt", placeholder="Enter your prompt here..."),
+        gr.Textbox(label="Input Prompt", placeholder="Enter your prompt here...", lines=3),
         gr.Slider(label="Max Length", minimum=50, maximum=500, value=200, step=10),
         gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, value=0.7, step=0.1),
         gr.Slider(label="Top P", minimum=0.1, maximum=1.0, value=0.9, step=0.05)
     ],
-    outputs=gr.Textbox(label="Generated Text"),
-    title="Fine-Tuned Llama-3-8B Demo",
-    description="Interact with the fine-tuned Llama-3-8B model (ubiodee/Test_Plutus) to generate text based on your prompt."
+    outputs=gr.Textbox(label="Generated Text", lines=10),
+    title="Plutus LLM Demo (ubiodee/plutus_llm)",
+    description="Interact with the fine-tuned Llama-3-8B model using LoRA and 8-bit quantization. This is based on ubiodee/plutus_llm."
 )
 
 if __name__ == "__main__":
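
One caveat on the new loading block: recent transformers releases deprecate passing load_in_8bit directly to from_pretrained in favor of a BitsAndBytesConfig. A minimal equivalent sketch, assuming bitsandbytes is installed and a CUDA GPU is available:

# Sketch: the same 8-bit load expressed via BitsAndBytesConfig, which newer
# transformers versions prefer over the bare load_in_8bit flag.
# Requires the bitsandbytes package and a CUDA-capable GPU.
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_name = "ubiodee/plutus_llm"
bnb_config = BitsAndBytesConfig(load_in_8bit=True)

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,  # replaces the bare load_in_8bit flag
    device_map="auto",
)

With 8-bit quantization active, the torch_dtype argument only affects the modules that stay unquantized, so it can be dropped from this variant.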
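The pad-token check appears only as hunk context, so its body is not shown in the diff. A sketch of the usual fallback for Llama-style tokenizers, which ship without a pad token (an assumption here, not taken from the commit):

# Assumed body of the pad-token check shown as context in the hunk:
# Llama tokenizers have no pad token, so fall back to the EOS token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

Passing pad_token_id=tokenizer.eos_token_id to generate(), as this commit adds, suppresses the same missing-pad-token warning at generation time.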
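The new prompt-stripping line slices by character count, which only works when decode() reproduces the prompt exactly; tokenizer normalization can shift the offset and clip the first words of the response. A token-level alternative sketch; the tokenization and generate() call shown here are assumptions, since the diff does not include that part of generate_text:

# Sketch: strip the prompt at the token level rather than by character count.
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
    **inputs,
    max_new_tokens=max_length,
    temperature=temperature,
    top_p=top_p,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id,
)
new_token_ids = outputs[0][inputs["input_ids"].shape[1]:]  # generated ids only
generated_text = tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()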