shubhamrooter committed on
Commit
121060a
·
verified ·
1 Parent(s): 72e07ea

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -27
app.py CHANGED
@@ -5,44 +5,74 @@ import torch
5
  # Load model and tokenizer
6
  model_name = "0dAI/0dAI-8x7b-0761"
7
 
8
# The original `@cache_resource` decorator is an undefined name in this file
# (it is a Streamlit API, and this is a Gradio app), so use the stdlib
# equivalent to memoize the expensive model load instead.
import functools


@functools.lru_cache(maxsize=1)
def load_model():
    """Load and cache the tokenizer and model for ``model_name``.

    Returns:
        tuple: ``(tokenizer, model)`` — loaded on first call, memoized after.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,  # halve memory footprint vs fp32
        device_map="auto",          # place/shard across available devices
        trust_remote_code=True,
    )
    return tokenizer, model
 
 
 
 
 
18
 
19
def generate_text(prompt, max_length=512):
    """Generate a sampled completion for ``prompt``.

    Args:
        prompt: User text to complete.
        max_length: Maximum number of *new* tokens to generate.

    Returns:
        str: The decoded model output (prompt included).
    """
    tokenizer, model = load_model()

    inputs = tokenizer(prompt, return_tensors="pt")
    # The model was loaded with device_map="auto" and may live on GPU;
    # the freshly tokenized tensors start on CPU, so move them over or
    # generate() fails with a device-mismatch error.
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            # max_new_tokens bounds only the *generated* text; the original
            # max_length counted prompt tokens too, so long prompts could
            # silently leave no room for a reply.
            max_new_tokens=max_length,
            num_return_sequences=1,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
# Create Gradio interface: one prompt textbox plus a length slider feeding
# generate_text, one textbox for the model's reply.
iface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Input Prompt", lines=3),
        gr.Slider(100, 1024, value=512, label="Max Length"),
    ],
    outputs=gr.Textbox(label="Generated Text", lines=5),
    title="0dAI 8x7B Model Demo",
    description="Interactive demo for the 0dAI/0dAI-8x7b-0761 model",
)

# Guard the launch so importing this module (e.g. from tests or tooling)
# does not start the web server as an import-time side effect.
if __name__ == "__main__":
    iface.launch()
 
 
5
  # Load model and tokenizer
6
  model_name = "0dAI/0dAI-8x7b-0761"
7
 
8
# Gradio has no `cache_resource` decorator (`gr.cache_resource` raises
# AttributeError at import — that API belongs to Streamlit). Cache manually
# so a successful load is reused, while a transient failure is NOT cached
# and a later call can retry.
_model_cache = None


def load_model():
    """Load the tokenizer and model once; reuse them on later calls.

    Returns:
        tuple: ``(tokenizer, model)`` on success, ``(None, None)`` on failure.
    """
    global _model_cache
    if _model_cache is not None:
        return _model_cache
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,  # halve memory footprint vs fp32
            device_map="auto",          # place/shard across available devices
            trust_remote_code=True,
            low_cpu_mem_usage=True,     # stream weights instead of full copy
        )
        _model_cache = (tokenizer, model)
        return _model_cache
    except Exception as e:
        # Callers (generate_text) check for None and surface a UI message.
        print(f"Error loading model: {e}")
        return None, None
23
 
24
def generate_text(prompt, max_length=512):
    """Generate a completion for ``prompt`` using the cached model.

    Args:
        prompt: User text to complete.
        max_length: Maximum number of new tokens to generate.

    Returns:
        str: Only the generated continuation (prompt stripped), or a
        human-readable error string if loading/generation failed.
    """
    tokenizer, model = load_model()

    if tokenizer is None or model is None:
        return "Error: Model failed to load. Please check the logs."

    try:
        # Tokenize input
        inputs = tokenizer(prompt, return_tensors="pt")

        # Move to same device as model (device_map="auto" may use GPU)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        # Generate response
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_length,
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.1,
            )

        # Strip the prompt at the *token* level. Slicing the decoded string
        # with len(prompt) is unreliable: detokenization does not round-trip
        # exactly, so the decoded prompt's character length can differ from
        # the raw prompt's, clipping output or leaking prompt text.
        prompt_token_count = inputs["input_ids"].shape[1]
        generated_ids = outputs[0][prompt_token_count:]
        return tokenizer.decode(generated_ids, skip_special_tokens=True)

    except Exception as e:
        # Surface the failure in the UI instead of crashing the app.
        return f"Error during generation: {str(e)}"
54
 
55
# Create Gradio interface: prompt textbox + token-count slider feeding
# generate_text, one textbox for the model's reply.
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(
            label="Input Prompt",
            lines=3,
            placeholder="Enter your prompt here...",
            value="Hello, how are you?",
        ),
        gr.Slider(50, 1024, value=256, label="Max New Tokens"),
    ],
    outputs=gr.Textbox(label="Generated Text", lines=8),
    title="0dAI 8x7B Model Demo",
    description="Interactive demo for the 0dAI/0dAI-8x7b-0761 model. This is a large model, so initial loading may take a few minutes.",
    # Each example must supply a value for *every* input component. The
    # interface has two inputs (prompt + max tokens), so single-element
    # examples would make Gradio raise at startup.
    examples=[
        ["Explain quantum computing in simple terms.", 256],
        ["Write a short story about a robot learning to paint.", 256],
        ["What are the benefits of renewable energy?", 256],
    ],
)

if __name__ == "__main__":
    # Bind to 0.0.0.0 so the server is reachable from outside the container
    # (required on Hugging Face Spaces); 7860 is the Spaces default port.
    demo.launch(server_name="0.0.0.0", server_port=7860)