# Hugging Face Spaces page capture; Space status at capture time: "Build error".
import time

import gradio as gr
import torch
import torch._dynamo as dynamo
# Load the serialized model and tokenizer. The paths are relative, so they are
# resolved against the process's working directory.
#
# SECURITY: torch.load unpickles arbitrary Python objects, which can execute
# code while loading. Only ever point these calls at trusted files.
# weights_only=False is passed explicitly because these files are read back as
# whole objects (the results are used directly as a model and a tokenizer), and
# newer PyTorch releases default to weights_only=True, which rejects that.
model = torch.load("GPT2Model.pt", weights_only=False)

# Wrap the eager model with TorchDynamo using the "inductor" backend.
optimized_model = dynamo.optimize("inductor")(model)

tokenizer = torch.load("GPT2Tokenizer.pt", weights_only=False)
def timed(fn):
    """Call ``fn`` once with no arguments and measure how long it takes.

    Args:
        fn: Zero-argument callable to execute.

    Returns:
        A ``(result, seconds)`` tuple: the value ``fn`` returned and the
        elapsed time in seconds, rounded to five decimal places.
    """
    # perf_counter() is monotonic and high-resolution, so the interval cannot
    # go negative or jump if the system clock is adjusted mid-call.
    start = time.perf_counter()
    result = fn()
    elapsed = time.perf_counter() - start
    return result, round(elapsed, 5)
def gpt2(prompt):
    """Generate from ``prompt`` with both models and report the timings.

    The prompt is tokenized with the module-level ``tokenizer``, then
    ``model`` and ``optimized_model`` each run one greedy ``generate`` call
    (``do_sample=False``, ``max_length=30``) under ``timed``.

    Args:
        prompt: Text passed to the tokenizer.

    Returns:
        A multi-line string with the eager time, the Dynamo time, the
        eager/Dynamo speedup ratio, and either the decoded Dynamo output or a
        mismatch message when the two generations differ.
    """
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    generate_kwargs = {"do_sample": False, "max_length": 30}

    eager_outputs, eager_time = timed(
        lambda: model.generate(input_ids, **generate_kwargs)
    )
    # NOTE(review): the first call through the optimized model presumably
    # includes compilation time, which would skew this measurement - confirm.
    dynamo_outputs, dynamo_time = timed(
        lambda: optimized_model.generate(input_ids, **generate_kwargs)
    )

    # torch.equal returns False on a shape mismatch (e.g. generations of
    # different lengths) instead of raising like torch.allclose does.
    if torch.equal(eager_outputs, dynamo_outputs):
        actual_output = tokenizer.batch_decode(
            dynamo_outputs, skip_special_tokens=True
        )[0]
    else:
        actual_output = "Result is not correct between dynamo and eager!"

    # timed() rounds to five decimals, so a very fast call can report 0.0;
    # avoid a ZeroDivisionError in that case.
    if dynamo_time:
        speedup = "{:.2f}".format(eager_time / dynamo_time) + "x"
    else:
        speedup = "n/a"

    return (
        f"Torch eager takes: {eager_time} \nDynamo takes: {dynamo_time} \nSpeedup: "
        f"{speedup} \nOutput: {actual_output}"
    )
# Build a text-in / text-out Gradio interface around gpt2() and start serving
# it as soon as this module is executed.
demo = gr.Interface(fn=gpt2, inputs="text", outputs="text")
demo.launch()