File size: 956 Bytes
808ec88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2560311
 
808ec88
 
934a315
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
import gradio as gr

import data_utils
from gpt_language_model import GPTLanguageModel

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

inference_model = GPTLanguageModel()

# Deserialize the checkpoint onto the CPU first so it loads even on machines
# without CUDA; the weights are moved to `device` right below.
inference_model.load_state_dict(torch.load('model/friendsGPT.pth', map_location=torch.device('cpu')))
# Module.to() and Module.eval() both return the same module, so
# `inference_model` and `inferenceModel` refer to one object (`generate` reads
# the latter). eval() is required for inference: it turns off training-only
# behaviour such as dropout, if the model uses any.
inferenceModel = inference_model.to(device).eval()

def generate(max_new_tokens=500):
    """Sample a fresh piece of text from the model, starting from a one-token context.

    Args:
        max_new_tokens: Number of tokens to sample. Defaults to 500 (the value
            previously hard-coded). The Gradio interface calls this with no
            arguments, so the default is what the UI uses.

    Returns:
        The sampled token ids decoded to a string by ``data_utils.decode``.
    """
    # Seed generation with a (1, 1) context holding token id 0.
    # NOTE(review): presumably id 0 is a neutral start token in the vocab;
    # confirm against data_utils.
    context = torch.zeros((1, 1), dtype=torch.long, device=device)
    # Pure inference: don't record an autograd graph across the sampling loop.
    with torch.no_grad():
        tokens = inferenceModel.generate(context, max_new_tokens=max_new_tokens)
    # Take the first (only) sequence in the batch and map the ids back to text.
    return data_utils.decode(tokens[0].tolist())
    

# Zero-input UI: each run calls generate() with no arguments and shows the
# returned string in a text output box.
demo = gr.Interface(
    fn=generate,
    inputs=None,
    outputs="text",
    title="F.R.I.E.N.D.S GPT",
    thumbnail="https://people.com/thmb/FcIy814mvCkqy62dKwv-z-CWcAk=/1500x0/filters:no_upscale():max_bytes(150000):strip_icc():focal(1499x0:1501x2)/matthew-perry-friends-cast-tribute-103023-83fecd2734cc44c18c47bfb9ec0a694b.jpg",
)
    
if __name__ == "__main__":
    # Serve the interface only when run as a script (not on import).
    # share=True asks Gradio for a public share link; show_api=False hides
    # the auto-generated API usage page.
    demo.launch(share=True, show_api=False)