File size: 1,716 Bytes
29b560e
314c465
a22e0d4
3b47068
e73ae0d
29b560e
a22e0d4
 
 
 
06e297b
a22e0d4
06e297b
a22e0d4
 
 
 
 
 
 
 
e73ae0d
73c4071
 
41caafb
a22e0d4
 
47728bf
e73ae0d
0580074
74704c7
e73ae0d
874ae6d
47728bf
dde5ace
dc9192f
e73ae0d
7dfeb8a
e73ae0d
 
31bbafb
4952bf3
 
47728bf
a22e0d4
 
 
 
 
 
 
 
29b560e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread

# Model identifier handed to ``from_pretrained`` for both tokenizer and model.
MODEL = "NTQAI/Nxcode-CQ-7B-orpo"

# System turn that frames the assistant's task for every translation request.
system_message = (
    "You are a computer programmer that can translate python code to C++ "
    "in order to improve performance"
)

def user_prompt_for(python):
    """Build the user-turn prompt asking the model to translate *python* to C++.

    The Python source is appended verbatim after a blank line. The previous
    version used backslash line-continuations inside the f-string, which
    leaked the function's indentation into the prompt. That added stray
    runs of spaces and indented the first line of the user's code by
    4 extra spaces, corrupting Python's significant indentation.

    Args:
        python: Python source code to be translated.

    Returns:
        The prompt text, ending with the unmodified Python source.
    """
    return (
        "Rewrite this python code to C++. You must search for the maximum performance. "
        "Format your response in Markdown. This is the python Code:\n\n"
        f"{python}"
    )

def messages_for(python):
    """Return the two-turn chat transcript (system, then user) for *python*.

    The system turn carries ``system_message``; the user turn carries the
    prompt built by ``user_prompt_for``.
    """
    system_turn = {"role": "system", "content": system_message}
    user_turn = {"role": "user", "content": user_prompt_for(python)}
    return [system_turn, user_turn]

# Load the tokenizer and model once, at import time. torch_dtype="auto" and
# device_map="auto" let Transformers pick the dtype and device placement.
tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForCausalLM.from_pretrained(MODEL, torch_dtype="auto", device_map="auto")

# Options forwarded to tokenizer.decode() when streaming generated text.
decode_kwargs = dict(skip_special_tokens=True)
# TextIteratorStreamer takes decode options as **kwargs. Passing them as a
# ``decode_kwargs=`` keyword (as before) forwards a literal ``decode_kwargs``
# argument to tokenizer.decode(), so skip_special_tokens never applied and
# special tokens leaked into the streamed output.
# NOTE(review): this module-level streamer is shared state; anything that runs
# generations concurrently should create its own streamer per request.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, **decode_kwargs)

# NOTE(review): unused in this file -- translate() assigns its own local
# ``cplusplus``. Kept in case external code imports it.
cplusplus = None
def translate(python):
    """Stream a C++ translation of *python* generated by the model.

    This is a generator: each yielded value is the full text generated so
    far (not just the newest chunk), so the caller can re-render it whole.

    Args:
        python: Python source code to translate.

    Yields:
        The accumulated model output after each decoded chunk.

    Raises:
        Exception: whatever ``model.generate`` raised in the worker thread,
            re-raised here once the stream has been closed.
    """
    # Use a fresh streamer per call. A single shared streamer would interleave
    # tokens from overlapping requests and keep leftovers from an abandoned
    # stream for the next caller.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, **decode_kwargs)

    inputs = tokenizer.apply_chat_template(
        messages_for(python),
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    generation_kwargs = dict(
        input_ids=inputs,
        streamer=streamer,
        max_new_tokens=512,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
    )

    errors = []

    def _generate():
        # Run generation on a worker thread so chunks can be consumed below
        # while tokens are still being produced.
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            # generate() normally closes the stream itself; if it fails, close
            # it here so the consuming loop does not block forever, and hand
            # the error back to the caller.
            errors.append(exc)
            streamer.end()

    thread = Thread(target=_generate)
    thread.start()

    cplusplus = ""
    for chunk in streamer:
        cplusplus += chunk
        yield cplusplus

    thread.join()
    if errors:
        raise errors[0]

# Web UI: a code input and a Markdown output. ``translate`` yields
# repeatedly, so the output is updated as new text is produced.
demo = gr.Interface(
    fn=translate,
    inputs="code",
    outputs="markdown",
)
demo.launch()