# Hugging Face Spaces status at time of capture: Runtime error
import functools

import gradio as gr
import torch
import transformers
# Checkpoint loading
def load_model(model_path, base_model_name='gogamza/kobart-base-v1'):
    """Load a fine-tuned KoBART checkpoint together with its tokenizer.

    Args:
        model_path: Path to a file written with ``torch.save`` holding a dict
            whose ``"model"`` entry is a state dict for
            ``BartForConditionalGeneration``.
        base_model_name: Hugging Face model id used both for the tokenizer and
            for the base architecture the fine-tuned weights are loaded into.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # NOTE(review): torch.load unpickles the file, so only point it at trusted
    # checkpoints. Newer torch releases default ``weights_only=True``, which
    # rejects non-tensor objects stored alongside the weights (e.g. a saved
    # training config) -- confirm this checkpoint loads under the installed torch.
    saved_data = torch.load(model_path, map_location="cpu")
    tokenizer = transformers.PreTrainedTokenizerFast.from_pretrained(base_model_name)
    # Build the base architecture, then overwrite its weights with the fine-tuned ones.
    model = transformers.BartForConditionalGeneration.from_pretrained(base_model_name)
    model.load_state_dict(saved_data["model"])
    return model, tokenizer
# Inference entry point
@functools.lru_cache(maxsize=1)
def _load_cached(model_path):
    """Return ``load_model(model_path)``, loading it only on the first call per path."""
    return load_model(model_path=model_path)


def inference(prompt):
    """Generate text for *prompt* with the fine-tuned KoBART model.

    The model and tokenizer are loaded lazily on the first non-empty request
    and then reused, instead of being rebuilt on every call.

    Args:
        prompt: Input text coming from the Gradio textbox.

    Returns:
        The decoded generation with special tokens removed, or ``""`` when the
        prompt is empty or tokenizes to no ids.
    """
    if not prompt:
        return ""
    model_path = "./kobart-model-logical.pth"
    model, tokenizer = _load_cached(model_path)
    token_ids = tokenizer.encode(prompt)
    if not token_ids:
        # An empty id list would become a float32 tensor with zero length,
        # which generate() cannot embed.
        return ""
    # Wrap in a list to add a batch dimension of 1; ids must be integer tensors.
    input_ids = torch.tensor([token_ids], dtype=torch.long)
    output = model.generate(input_ids)
    return tokenizer.decode(output[0], skip_special_tokens=True)
# Gradio UI: one text input -> one text output (the return value of inference).
demo = gr.Interface(
    fn=inference,
    inputs="text",
    outputs="text",
)

if __name__ == "__main__":
    # share=True additionally creates a public link reachable from outside the
    # local machine. Launch exactly once; launch() serves the app until stopped.
    demo.launch(share=True)