ninagala commited on
Commit
ee15bd8
·
verified ·
1 Parent(s): 5873e46

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -17
app.py CHANGED
@@ -104,15 +104,12 @@ class TransformerDecoder(nn.Module):
104
  return output
105
 
106
  @classmethod
107
- def from_pretrained(cls, model_path: str, device: str = 'cpu'):
108
- """Load a pretrained model from a directory"""
109
  try:
110
- # Load config
111
- config_path = os.path.join(model_path, "config.json")
112
- if not os.path.exists(config_path):
113
- raise FileNotFoundError(f"Config not found at {config_path}")
114
-
115
- with open(config_path) as f:
116
  config = json.load(f)
117
 
118
  # Create model instance
@@ -126,25 +123,23 @@ class TransformerDecoder(nn.Module):
126
  dropout=config.get('dropout', 0.1)
127
  )
128
 
129
- # Load weights
130
- weights_path = os.path.join(model_path, "pytorch_model.bin")
131
- if not os.path.exists(weights_path):
132
- raise FileNotFoundError(f"Weights not found at {weights_path}")
133
-
134
- state_dict = torch.load(weights_path, map_location=device)
135
  model.load_state_dict(state_dict)
136
 
137
  return model.to(device)
138
 
139
  except Exception as e:
140
- raise Exception(f"Error loading model from {model_path}: {str(e)}")
141
 
142
  def generate_text(prompt, max_length=100, temperature=0.7):
143
  try:
144
  # Load model and tokenizer from Hugging Face Hub
145
  model_id = "ninagala/shakespeare-model"
146
- tokenizer_file = hf_hub_download(repo_id=model_id, filename="tokenizer.json")
147
 
 
 
148
  model = TransformerDecoder.from_pretrained(model_id)
149
  tokenizer = Tokenizer.from_file(tokenizer_file)
150
 
@@ -153,6 +148,8 @@ def generate_text(prompt, max_length=100, temperature=0.7):
153
  tokens = tokenizer.encode(prompt).ids
154
  input_ids = torch.tensor(tokens).unsqueeze(0)
155
 
 
 
156
  with torch.no_grad():
157
  for _ in range(max_length):
158
  outputs = model(input_ids)
@@ -161,7 +158,10 @@ def generate_text(prompt, max_length=100, temperature=0.7):
161
  next_token = torch.multinomial(probs, num_samples=1)
162
  input_ids = torch.cat([input_ids, next_token], dim=1)
163
 
164
- if next_token.item() == tokenizer.token_to_id("[EOS]"):
 
 
 
165
  break
166
 
167
  return tokenizer.decode(input_ids[0].tolist())
 
104
  return output
105
 
106
  @classmethod
107
+ def from_pretrained(cls, model_id: str, device: str = 'cpu'):
108
+ """Load a pretrained model from Hugging Face Hub"""
109
  try:
110
+ # Download config
111
+ config_file = hf_hub_download(repo_id=model_id, filename="config.json")
112
+ with open(config_file) as f:
 
 
 
113
  config = json.load(f)
114
 
115
  # Create model instance
 
123
  dropout=config.get('dropout', 0.1)
124
  )
125
 
126
+ # Download and load weights
127
+ weights_file = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin")
128
+ state_dict = torch.load(weights_file, map_location=device)
 
 
 
129
  model.load_state_dict(state_dict)
130
 
131
  return model.to(device)
132
 
133
  except Exception as e:
134
+ raise Exception(f"Error loading model from {model_id}: {str(e)}")
135
 
136
  def generate_text(prompt, max_length=100, temperature=0.7):
137
  try:
138
  # Load model and tokenizer from Hugging Face Hub
139
  model_id = "ninagala/shakespeare-model"
 
140
 
141
+ # Download files from hub
142
+ tokenizer_file = hf_hub_download(repo_id=model_id, filename="tokenizer.json")
143
  model = TransformerDecoder.from_pretrained(model_id)
144
  tokenizer = Tokenizer.from_file(tokenizer_file)
145
 
 
148
  tokens = tokenizer.encode(prompt).ids
149
  input_ids = torch.tensor(tokens).unsqueeze(0)
150
 
151
+ generated_tokens = []
152
+
153
  with torch.no_grad():
154
  for _ in range(max_length):
155
  outputs = model(input_ids)
 
158
  next_token = torch.multinomial(probs, num_samples=1)
159
  input_ids = torch.cat([input_ids, next_token], dim=1)
160
 
161
+ token_id = next_token.item()
162
+ generated_tokens.append(token_id)
163
+
164
+ if token_id == tokenizer.token_to_id("[EOS]"):
165
  break
166
 
167
  return tokenizer.decode(input_ids[0].tolist())