yoonusajwardapiit commited on
Commit
9d37c49
·
verified ·
1 Parent(s): ea9f47a

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -2
app.py CHANGED
@@ -82,7 +82,7 @@ class BigramLanguageModel(nn.Module):
82
 
83
  def generate(self, idx, max_new_tokens):
84
  for _ in range(max_new_tokens):
85
- idx_cond = idx[:, -32:]
86
  logits, _ = self(idx_cond)
87
  logits = logits[:, -1, :]
88
  probs = nn.functional.softmax(logits, dim=-1)
@@ -112,11 +112,19 @@ decode = lambda l: ''.join([itos[i] for i in l])
112
  def generate_text(prompt):
113
  try:
114
  print(f"Received prompt: {prompt}")
115
- context = torch.tensor([encode(prompt)], dtype=torch.long)
 
 
 
 
 
 
116
  print(f"Encoded prompt: {context}")
 
117
  with torch.no_grad():
118
  generated = model.generate(context, max_new_tokens=250) # Adjust as needed
119
  print(f"Generated tensor: {generated}")
 
120
  result = decode(generated[0].tolist())
121
  print(f"Decoded result: {result}")
122
  return result
 
82
 
83
  def generate(self, idx, max_new_tokens):
84
  for _ in range(max_new_tokens):
85
+ idx_cond = idx[:, -32:] # Ensure context length does not exceed block size
86
  logits, _ = self(idx_cond)
87
  logits = logits[:, -1, :]
88
  probs = nn.functional.softmax(logits, dim=-1)
 
112
  def generate_text(prompt):
113
  try:
114
  print(f"Received prompt: {prompt}")
115
+ encoded_prompt = encode(prompt)
116
+
117
+ # Ensure the prompt length fits within the block size
118
+ if len(encoded_prompt) > 32:
119
+ encoded_prompt = encoded_prompt[:32] # Truncate to fit block size
120
+
121
+ context = torch.tensor([encoded_prompt], dtype=torch.long)
122
  print(f"Encoded prompt: {context}")
123
+
124
  with torch.no_grad():
125
  generated = model.generate(context, max_new_tokens=250) # Adjust as needed
126
  print(f"Generated tensor: {generated}")
127
+
128
  result = decode(generated[0].tolist())
129
  print(f"Decoded result: {result}")
130
  return result