Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -15,10 +15,13 @@ model_name = "unsloth/Llama-3.2-1B-Instruct"
 device = torch.device('cuda')
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 
-
+model1 = AutoModelForCausalLM.from_pretrained(model_name).to("cuda")
 
-
-
+model2 = AutoModelForCausalLM.from_pretrained(model_name)
+device = torch.device('cuda')
+
+strategy = CreativeWritingStrategy()
+provider = TransformersProvider(model2, tokenizer, device)
 creative_sampler = BacktrackSampler(strategy, provider)
 
 def create_chat_template_messages(history, prompt):
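For context, here is roughly how the updated setup reads once assembled. The import block sits above the hunk and is not visible in the diff, so the import paths below are assumptions based on the class names the code uses:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Assumed import paths for the backtrack_sampler package; the real
# imports live above the hunk and are not shown in this diff.
from backtrack_sampler import BacktrackSampler, CreativeWritingStrategy
from backtrack_sampler.provider.transformers_provider import TransformersProvider

model_name = "unsloth/Llama-3.2-1B-Instruct"

device = torch.device('cuda')
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Two copies of the checkpoint: model1 serves the plain transformers
# generate() baseline on the GPU, while model2 backs the provider that
# BacktrackSampler drives.
model1 = AutoModelForCausalLM.from_pretrained(model_name).to("cuda")
model2 = AutoModelForCausalLM.from_pretrained(model_name)
device = torch.device('cuda')  # redundant with line 15, kept as committed

strategy = CreativeWritingStrategy()
provider = TransformersProvider(model2, tokenizer, device)
creative_sampler = BacktrackSampler(strategy, provider)

Loading the model twice keeps the baseline path and the backtracking path independent of each other, at the cost of holding two copies of the weights in memory.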
@@ -45,7 +48,7 @@ def generate_responses(prompt, history):
         return tokenizer.decode(generated_list, skip_special_tokens=True)
 
     custom_output = asyncio.run(custom_sampler_task())
-    standard_output =
+    standard_output = model1.generate(inputs, max_length=2048, temperature=1)
    standard_response = tokenizer.decode(standard_output[0][len(inputs[0]):], skip_special_tokens=True)
 
     return standard_response.strip(), custom_output.strip()
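Only fragments of generate_responses are visible in this hunk, so the following is a hypothetical reconstruction of the comparison flow. The body of custom_sampler_task, the construction of inputs, and the creative_sampler.generate call are assumptions (the asyncio.run wrapper suggests the sampler yields tokens asynchronously); only the lines that appear in the hunk are confirmed by the diff:

import asyncio

def generate_responses(prompt, history):
    # Assumed prompt preparation; apply_chat_template is standard
    # transformers API, but this exact call is not shown in the diff.
    messages = create_chat_template_messages(history, prompt)
    inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(device)

    async def custom_sampler_task():
        # Assumption: the sampler yields token ids one at a time, which
        # are collected and decoded in a single pass (the decode line is
        # confirmed by the hunk).
        generated_list = []
        async for token in creative_sampler.generate(prompt, max_length=2048):
            generated_list.append(token)
        return tokenizer.decode(generated_list, skip_special_tokens=True)

    custom_output = asyncio.run(custom_sampler_task())

    # Baseline generation with model1, as added in this commit.
    standard_output = model1.generate(inputs, max_length=2048, temperature=1)
    standard_response = tokenizer.decode(
        standard_output[0][len(inputs[0]):], skip_special_tokens=True
    )

    return standard_response.strip(), custom_output.strip()

Note that temperature=1 matches the transformers default and, with sampling left at its default of off, model1.generate decodes greedily; the argument is effectively a no-op as written.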