Shriti09 committed
Commit 5e3a92a · verified · 1 Parent(s): 04d9c75

Create app.py

Files changed (1):
  app.py +92 -0
app.py ADDED
@@ -0,0 +1,92 @@
+ import torch
+ import gradio as gr
+ from transformers import AutoTokenizer
+ from model_smol2 import LlamaForCausalLM, config_model
+
+ # Instantiate the model architecture
+ model = LlamaForCausalLM(config_model)
+
+ # Load the trained weights from the checkpoint
+ checkpoint_path = "final_checkpoint.pt"
+ checkpoint = torch.load(checkpoint_path, map_location="cpu")
+ model.load_state_dict(checkpoint['model_state_dict'])
+ model.eval()
+
+ # Load the tokenizer (swap in your own if you trained a custom one)
+ TOKENIZER_PATH = "HuggingFaceTB/cosmo2-tokenizer"
+ tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)
+ if tokenizer.pad_token is None:
+     tokenizer.pad_token = tokenizer.eos_token if tokenizer.eos_token else "[PAD]"
+
+
+ # Token-by-token sampling with top-k, temperature, repetition penalty,
+ # and n-gram blocking
+ def generate_text(
+     prompt, max_length=50, temperature=0.7, top_k=50, repetition_penalty=1.2, n_gram_block=2
+ ):
+     input_ids = tokenizer.encode(prompt, return_tensors="pt")
+     generated_tokens = input_ids[0].tolist()
+
+     with torch.no_grad():
+         for _ in range(max_length):
+             outputs = model(input_ids)
+
+             # The model may return a dict with logits or a plain tensor
+             if isinstance(outputs, dict) and 'logits' in outputs:
+                 logits = outputs['logits'][:, -1, :]
+             else:
+                 logits = outputs[:, -1, :]
+
+             # Repetition penalty: down-weight every token already generated
+             for token_id in set(generated_tokens):
+                 logits[:, token_id] /= repetition_penalty
+
+             # n-gram blocking: forbid any token that would complete an
+             # n-gram already present in the generated sequence
+             if n_gram_block > 1 and len(generated_tokens) >= n_gram_block:
+                 prefix = tuple(generated_tokens[-(n_gram_block - 1):])
+                 for i in range(len(generated_tokens) - n_gram_block + 1):
+                     if tuple(generated_tokens[i:i + n_gram_block - 1]) == prefix:
+                         logits[:, generated_tokens[i + n_gram_block - 1]] -= 1e9
+
+             # Temperature scaling, then sample from the top-k candidates
+             logits /= temperature
+             top_k_logits, top_k_indices = torch.topk(logits, top_k, dim=-1)
+             probs = torch.softmax(top_k_logits, dim=-1)
+
+             next_token_idx = torch.multinomial(probs, num_samples=1)
+             next_token = top_k_indices[0, next_token_idx[0]]
+
+             generated_tokens.append(next_token.item())
+             input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
+
+             if next_token.item() == tokenizer.eos_token_id:
+                 break
+
+     return tokenizer.decode(generated_tokens, skip_special_tokens=True)
+
+
+ # Gradio UI
+ def generate_response(prompt, max_length, temperature, top_k, repetition_penalty, n_gram_block):
+     # Sliders deliver floats; cast the integer-valued controls before use
+     return generate_text(
+         prompt, int(max_length), temperature, int(top_k), repetition_penalty, int(n_gram_block)
+     )
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# Smol2 Text Generator")
+     with gr.Row():
+         with gr.Column():
+             prompt_input = gr.Textbox(label="Input Prompt", placeholder="Enter your text prompt here...")
+             max_length = gr.Slider(label="Max Length", minimum=10, maximum=200, value=50, step=1)
+             temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.5, value=0.7, step=0.1)
+             top_k = gr.Slider(label="Top K", minimum=10, maximum=100, value=50, step=1)
+             repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.2, step=0.1)
+             n_gram_block = gr.Slider(label="N-Gram Blocking", minimum=1, maximum=5, value=2, step=1)
+             generate_button = gr.Button("Generate Text")
+         with gr.Column():
+             output_text = gr.Textbox(label="Generated Text", lines=10)
+
+     generate_button.click(
+         generate_response,
+         inputs=[prompt_input, max_length, temperature, top_k, repetition_penalty, n_gram_block],
+         outputs=[output_text],
+     )
+
+ demo.launch()
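
A quick way to sanity-check the checkpoint and tokenizer outside the Gradio UI is a one-step greedy decode. This is a minimal sketch, assuming the same files the commit expects (final_checkpoint.pt, model_smol2, and the cosmo2 tokenizer) are available locally; it is illustrative only and not part of the commit:

    import torch
    from transformers import AutoTokenizer
    from model_smol2 import LlamaForCausalLM, config_model

    # Rebuild the model exactly as app.py does
    model = LlamaForCausalLM(config_model)
    state = torch.load("final_checkpoint.pt", map_location="cpu")
    model.load_state_dict(state["model_state_dict"])
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
    ids = tokenizer.encode("Once upon a time", return_tensors="pt")

    with torch.no_grad():
        out = model(ids)
    logits = out["logits"] if isinstance(out, dict) else out

    # Greedy pick of the single most likely next token
    next_id = int(torch.argmax(logits[:, -1, :], dim=-1))
    print(tokenizer.decode(ids[0].tolist() + [next_id]))

If this prints the prompt plus one plausible continuation token, the checkpoint, config, and tokenizer are wired together correctly and the full sampling loop in app.py should run.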