File size: 967 Bytes
8589c7a
 
54c64de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import gradio as gr
import torch
import transformers  # needed below to reach transformers.models.gptj for the monkey-patch
from transformers import AutoTokenizer, GPTJForCausalLM

from Utils import GPTJBlock  # Assuming Utils.py is in the same directory

# Run on GPU when one is available; fall back to CPU otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Monkey-patch GPT-J so the 8-bit quantized checkpoint below loads with the
# custom block implementation from Utils. This was previously a NameError:
# `transformers` was referenced here without ever being imported.
transformers.models.gptj.modeling_gptj.GPTJBlock = GPTJBlock

# 8-bit quantized Spanish GPT-J checkpoint from the Hugging Face Hub.
ckpt = "mrm8488/bertin-gpt-j-6B-ES-8bit"

tokenizer = AutoTokenizer.from_pretrained(ckpt)
# low_cpu_mem_usage avoids materializing a full fp32 copy while loading;
# pad_token_id is set so generate() does not warn about a missing pad token.
model = GPTJForCausalLM.from_pretrained(ckpt, pad_token_id=tokenizer.eos_token_id, low_cpu_mem_usage=True).to(device)

def generate_text(prompt):
    """Sample a continuation (up to 64 tokens) for *prompt* from the model."""
    encoded = tokenizer(prompt, return_tensors='pt')
    # Move every input tensor onto the same device as the model.
    inputs = {name: tensor.to(device) for name, tensor in encoded.items()}
    output_ids = model.generate(**inputs, max_length=64, do_sample=True)
    return tokenizer.decode(output_ids[0])

# Expose the generator through a minimal text-in / text-out Gradio UI.
# live=True re-runs generation as the user types instead of waiting for submit.
iface = gr.Interface(
    fn=generate_text,
    inputs="text",
    outputs="text",
    live=True,
)

iface.launch()