sagar007 commited on
Commit
ec014a4
1 Parent(s): 76cf633

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -20
app.py CHANGED
@@ -3,19 +3,16 @@ import torch.nn as nn
3
  from torch.nn import functional as F
4
  import tiktoken
5
  import gradio as gr
6
- import asyncio
7
 
8
- # Try to import spaces, use a dummy decorator if not available
9
  try:
10
  import spaces
11
  use_spaces_gpu = True
12
  except ImportError:
13
  use_spaces_gpu = False
14
- # Dummy decorator in case spaces is not available
15
  def dummy_gpu_decorator(func):
16
  return func
17
  spaces = type('', (), {'GPU': dummy_gpu_decorator})()
18
-
19
  # Define the GPTConfig class
20
  class GPTConfig:
21
  def __init__(self):
@@ -131,10 +128,10 @@ def load_model(model_path):
131
 
132
  enc = tiktoken.get_encoding('gpt2')
133
 
 
134
  # Update the generate_text function
135
- @spaces.GPU(duration=60)
136
- async def generate_text(prompt, max_length=432, temperature=0.8, top_k=40):
137
- # Load the model inside the GPU-decorated function
138
  model = load_model('gpt_model.pth')
139
  device = next(model.parameters()).device
140
  input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0).to(device)
@@ -153,24 +150,15 @@ async def generate_text(prompt, max_length=432, temperature=0.8, top_k=40):
153
  input_ids = torch.cat([input_ids, next_token], dim=-1)
154
  generated.append(next_token.item())
155
 
156
- next_token_str = enc.decode([next_token.item()])
157
- yield next_token_str
158
-
159
  if next_token.item() == enc.encode('\n')[0] and len(generated) > 100:
160
  break
161
-
162
- await asyncio.sleep(0.02)
163
 
164
- if len(generated) == max_length:
165
- yield "... (output truncated due to length)"
166
 
167
  # Add the gradio_generate function
168
- @spaces.GPU(duration=60)
169
- async def gradio_generate(prompt, max_length, temperature, top_k):
170
- output = ""
171
- async for token in generate_text(prompt, max_length, temperature, top_k):
172
- output += token
173
- yield output
174
 
175
 
176
  # # Your existing imports and model code here...
 
3
  from torch.nn import functional as F
4
  import tiktoken
5
  import gradio as gr
 
6
 
 
7
  try:
8
  import spaces
9
  use_spaces_gpu = True
10
  except ImportError:
11
  use_spaces_gpu = False
 
12
  def dummy_gpu_decorator(func):
13
  return func
14
  spaces = type('', (), {'GPU': dummy_gpu_decorator})()
15
+
16
  # Define the GPTConfig class
17
  class GPTConfig:
18
  def __init__(self):
 
128
 
129
  enc = tiktoken.get_encoding('gpt2')
130
 
131
+
132
  # Update the generate_text function
133
+ @spaces.GPU
134
+ def generate_text(prompt, max_length=432, temperature=0.8, top_k=40):
 
135
  model = load_model('gpt_model.pth')
136
  device = next(model.parameters()).device
137
  input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0).to(device)
 
150
  input_ids = torch.cat([input_ids, next_token], dim=-1)
151
  generated.append(next_token.item())
152
 
 
 
 
153
  if next_token.item() == enc.encode('\n')[0] and len(generated) > 100:
154
  break
 
 
155
 
156
+ return enc.decode(generated)
 
157
 
158
  # Add the gradio_generate function
159
+ @spaces.GPU
160
+ def gradio_generate(prompt, max_length, temperature, top_k):
161
+ return generate_text(prompt, max_length, temperature, top_k)
 
 
 
162
 
163
 
164
  # # Your existing imports and model code here...