yoonusajwardapiit committed on
Commit
474ebab
1 Parent(s): 141eb85

Upload app.py

Files changed (1)
  1. app.py +8 -3
app.py CHANGED
@@ -81,11 +81,12 @@ class BigramLanguageModel(nn.Module):
         logits = self.lm_head(x)
         return logits, None
 
-    def generate(self, idx, max_new_tokens):
+    def generate(self, idx, max_new_tokens, temperature=0.7):
         for _ in range(max_new_tokens):
             idx_cond = idx[:, -32:]  # Truncate to the latest 32 tokens
             logits, _ = self(idx_cond)
             logits = logits[:, -1, :]  # Get the logits for the last token
+            logits = logits / temperature  # Apply temperature control
             probs = nn.functional.softmax(logits, dim=-1)
             idx_next = torch.multinomial(probs, num_samples=1)
             idx_next = torch.clamp(idx_next, min=0, max=60)  # Strictly enforce index range [0, 60]
@@ -129,13 +130,17 @@ def generate_text(prompt):
         print(f"Encoded prompt: {context}")
 
         with torch.no_grad():
-            generated = model.generate(context, max_new_tokens=20)  # Reduced tokens to speed up
+            generated = model.generate(context, max_new_tokens=20, temperature=0.7)  # Adjust temperature
         print(f"Generated tensor: {generated}")
 
         result = decode(generated[0].tolist())
         print(f"Decoded result: {result}")
+
+        # Post-process to clean up and make output more readable
+        cleaned_result = result.replace('\n', ' ').strip()
+        print(f"Cleaned result: {cleaned_result}")
         print(f"Processing time: {time.time() - start_time:.2f}s")
-        return result
+        return cleaned_result
     except Exception as e:
         print(f"Error during generation: {e}")
         return f"Error: {str(e)}"
 
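Note on the first hunk: dividing the logits by a temperature below 1 sharpens the softmax distribution so sampling favors the top tokens, while a temperature above 1 flattens it. A minimal standalone sketch of the effect, with made-up logit values (not taken from app.py):

import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.5])  # illustrative logits for three tokens

for t in (1.0, 0.7):
    probs = F.softmax(logits / t, dim=-1)
    print(f"temperature={t}: {probs.tolist()}")

# At temperature 1.0 the raw distribution is kept; at 0.7 the most likely
# token gains probability mass, so sampling picks it more often.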
 