damerajee commited on
Commit
755d66e
·
verified ·
1 Parent(s): f9dc32d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -22
app.py CHANGED
@@ -2,10 +2,10 @@ import gradio as gr
2
  import torch
3
  from mingru_lm import MinGRU_LM
4
 
5
-
6
  model = MinGRU_LM(dim=512, num_tokens=256, num_layers=6)
7
  pt_model = "best_model.pt"
8
- checkpoint = torch.load(pt_model,map_location=torch.device('cpu'))
9
  model.load_state_dict(checkpoint['model_state_dict'])
10
 
11
  # Move model to GPU if available
@@ -25,36 +25,52 @@ def generate_text(start_text, max_length, temperature):
25
  input_tensor = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(device) # Ensure long tensor
26
 
27
  generated_tokens = tokens.copy()
28
-
29
- with torch.no_grad():
30
- for _ in range(max_length):
31
- _, logits = model(input_tensor, labels=None)
 
32
 
33
  last_token_logits = logits[0, -1, :] / temperature
34
  probs = torch.softmax(last_token_logits, dim=-1)
 
 
35
  next_token = torch.multinomial(probs, num_samples=1).item()
36
 
37
- # Only append if it's within the 256-character ASCII range
38
  if next_token < 256:
39
  generated_tokens.append(next_token)
40
  input_tensor = torch.cat([input_tensor, torch.tensor([[next_token]], device=device)], dim=1)
 
 
41
  else:
42
  continue # Skip tokens outside ASCII range
43
 
44
- return decode_tokens(generated_tokens)
 
 
 
 
 
 
45
 
46
  # Gradio interface
47
- iface = gr.Interface(
48
- fn=generate_text,
49
- inputs=[
50
- gr.Textbox(lines=3, label="Enter your prompt", value="Once upon a time"),
51
- gr.Slider(minimum=10, maximum=500, value=200, step=1, label="Max Length"),
52
- gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
53
- ],
54
- outputs=gr.Textbox(lines=10, label="Generated Text"),
55
- title="Text Generation with MinGRU_LM",
56
- description="Enter a prompt and adjust parameters to generate text using the MinGRU_LM model."
57
- )
58
-
59
- if __name__ == "__main__":
60
- iface.launch(show_api=False, server_name="0.0.0.0")
 
 
 
 
 
 
2
  import torch
3
  from mingru_lm import MinGRU_LM
4
 
5
+ # Load the model
6
  model = MinGRU_LM(dim=512, num_tokens=256, num_layers=6)
7
  pt_model = "best_model.pt"
8
+ checkpoint = torch.load(pt_model, map_location=torch.device('cpu'))
9
  model.load_state_dict(checkpoint['model_state_dict'])
10
 
11
  # Move model to GPU if available
 
25
  input_tensor = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(device) # Ensure long tensor
26
 
27
  generated_tokens = tokens.copy()
28
+
29
+ # Use a generator to yield tokens one by one
30
+ for _ in range(max_length):
31
+ with torch.no_grad():
32
+ logits = model(input_tensor, labels=None)[1] # Get logits directly
33
 
34
  last_token_logits = logits[0, -1, :] / temperature
35
  probs = torch.softmax(last_token_logits, dim=-1)
36
+
37
+ # Sample the next token
38
  next_token = torch.multinomial(probs, num_samples=1).item()
39
 
40
+ # Only append valid tokens
41
  if next_token < 256:
42
  generated_tokens.append(next_token)
43
  input_tensor = torch.cat([input_tensor, torch.tensor([[next_token]], device=device)], dim=1)
44
+
45
+ yield decode_tokens(generated_tokens)
46
  else:
47
  continue # Skip tokens outside ASCII range
48
 
49
+ yield decode_tokens(generated_tokens)
50
+
51
+ def wrapper_generate_text(start_text, max_length, temperature):
52
+ async_gen = generate_text(start_text, max_length, temperature)
53
+
54
+ for output in async_gen:
55
+ yield output
56
 
57
  # Gradio interface
58
+ with gr.Blocks() as iface:
59
+ gr.Markdown("### Please be patient, generating text will take some time...")
60
+
61
+ with gr.Row():
62
+ textbox = gr.Textbox(lines=3, label="Enter your prompt", value="Once upon a time")
63
+ max_length = gr.Slider(minimum=10, maximum=500, value=200, step=1, label="Max Length")
64
+ temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
65
+
66
+ output_textbox = gr.Textbox(lines=10, label="Generated Text")
67
+
68
+ btn = gr.Button("Generate Text")
69
+
70
+ btn.click(
71
+ wrapper_generate_text,
72
+ inputs=[textbox, max_length, temperature],
73
+ outputs=output_textbox
74
+ )
75
+
76
+ iface.launch(show_api=False, server_name="0.0.0.0")