This checkpoint is a states-tuning file for RWKV-6-7B. Please download the base model from https://huggingface.co/BlinkDL/rwkv-6-world/tree/main . Usage:
- Update the rwkv package to the latest version: pip install --upgrade rwkv
- Download the base model and the states file. You can use either the states file in the root directory or the one in the epoch_2 directory; test which one works better for you (for example, with huggingface_hub as sketched below).
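As a minimal download sketch using huggingface_hub: the base-model file name matches the path used in the code below, while REPO_ID for the states file is a placeholder for this repository's id (pick whichever states file you want to try).

from huggingface_hub import hf_hub_download

# Base model from https://huggingface.co/BlinkDL/rwkv-6-world
base_model_path = hf_hub_download(repo_id='BlinkDL/rwkv-6-world',
                                  filename='RWKV-x060-World-7B-v2.1-20240507-ctx4096.pth')

# States file: replace REPO_ID with this repository's id and choose either the
# root-level states file or the one under epoch_2/.
# states_path = hf_hub_download(repo_id='REPO_ID',
#                               filename='epoch_2/RWKV-x060-World-7B-v2.1-20240507-ctx4096.pth.pth')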
- Follow the code below:
- Loading the model and states
from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS
import torch
# download models: https://huggingface.co/BlinkDL
# adjust the paths below to wherever you saved the base model and the states file
model = RWKV(model='/media/yueyulin/KINGSTON/models/rwkv6/RWKV-x060-World-7B-v2.1-20240507-ctx4096.pth', strategy='cuda fp16')
print(model.args)
pipeline = PIPELINE(model, "rwkv_vocab_v20230424") # use the "rwkv_vocab_v20230424" vocab for RWKV "world" models
states_file = '/media/yueyulin/data_4t/models/states_tuning/custom_trainer/epoch_2/RWKV-x060-World-7B-v2.1-20240507-ctx4096.pth.pth'
states = torch.load(states_file)
states_value = []
device = 'cuda'
n_head = model.args.n_head
head_size = model.args.n_embd//model.args.n_head
# Build the runtime state list. The rwkv package expects, per layer,
# [attention token-shift state, attention (wkv) state, FFN token-shift state];
# the tuned time_state fills the attention state, the other two start at zero.
for i in range(model.args.n_layer):
    key = f'blocks.{i}.att.time_state'
    value = states[key]
    prev_x = torch.zeros(model.args.n_embd,device=device,dtype=torch.float16)
    prev_states = value.clone().detach().to(device=device,dtype=torch.float16).transpose(1,2)
    prev_ffn = torch.zeros(model.args.n_embd,device=device,dtype=torch.float16)
    states_value.append(prev_x)
    states_value.append(prev_states)
    states_value.append(prev_ffn)
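The same state list is rebuilt before every prompt in the examples below, so you may prefer to wrap the loop above in a small helper. This is a convenience sketch only; build_states is not part of the rwkv package.

def build_states(model, states, device='cuda', dtype=torch.float16):
    # Per layer: [token-shift state, tuned attention state, FFN token-shift state]
    states_value = []
    for i in range(model.args.n_layer):
        value = states[f'blocks.{i}.att.time_state']
        prev_x = torch.zeros(model.args.n_embd, device=device, dtype=dtype)
        prev_states = value.clone().detach().to(device=device, dtype=dtype).transpose(1, 2)
        prev_ffn = torch.zeros(model.args.n_embd, device=device, dtype=dtype)
        states_value.extend([prev_x, prev_states, prev_ffn])
    return states_value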
- Try Chinese IE (information extraction)
cat_char = '🐱'
bot_char = '🤖'
# The Chinese instruction says: "You are an expert in entity extraction. Extract the entities that match the schema from the input; return an empty list for entity types that are not present. Answer in the format of a JSON string."
instruction ='你是专门进行实体抽取的专家。请从input中抽取出符合schema定义的实体,不存在的实体类型返回空列表。请按照JSON字符串的格式回答。'
input_text = '{\"input\":\"6 月 17 日,广发证券研报指出,近期大飞机各项进展持续推进。6 月 14 日,东航 C919 机型开启第四条商业定期航线——上海虹桥往返广州白云。\
\
工业和信息化部、国家自然科学基金委员会 6 月 14 日签署合作协议,共同设立大飞机基础研究联合基金。\
\
全球积压飞机订单超 1.4 万架,当前全球航空业因零部件供应短缺、交付周期变长等问题面临供应链威胁,或为国内航空航发产业链相关企业带来航空出海业务新增量。\",\
\"schema\":[\"地理位置\",\"组织机构\",\"气候类型\",\"时间\"]}'
ctx = f'{cat_char}:{instruction}\n{input_text}\n{bot_char}:'
print(ctx)
def my_print(s):
    print(s, end='', flush=True)
# For alpha_frequency and alpha_presence, see "Frequency and presence penalties":
# https://platform.openai.com/docs/api-reference/parameter-details
args = PIPELINE_ARGS(temperature = 1.0, top_p = 0, top_k = 0, # top_k = 0 means top-k filtering is disabled
                     alpha_frequency = 0.25,
                     alpha_presence = 0.25,
                     alpha_decay = 0.996, # gradually decay the penalty
                     token_ban = [0], # ban the generation of some tokens
                     token_stop = [0,1], # stop generation whenever any of these tokens appears
                     chunk_len = 256) # split input into chunks to save VRAM (shorter -> slower)
pipeline.generate(ctx, token_count=200, args=args, callback=my_print,state=states_value)
print('\n')
The output looks like:
🐱:你是专门进行实体抽取的专家。请从input中抽取出符合schema定义的实体,不存在的实体类型返回空列表。请按照JSON字符串的格式回答。
{"input":"6 月 17 日,广发证券研报指出,近期大飞机各项进展持续推进。6 月 14 日,东航 C919 机型开启第四条商业定期航线——上海虹桥往返广州白云。工业和信息化部、国家自然科学基金委员会 6 月 14 日签署合作协议,共同设立大飞机基础研究联合基金。全球积压飞机订单超 1.4 万架,当前全球航空业因零部件供应短缺、交付周期变长等问题面临供应链威胁,或为国内航空航发产业链相关企业带来航空出海业务新增量。","schema":["地理位置","组织机构","气候类型","时间"]}
🤖:
{"地理位置": ["上海", "广州", "白云"], "组织机构": ["广发证券", "工业和信息化部", "国家自然科学基金委员会"], "气候类型": [], "时间": ["6 月 14 日"]}
- Try English IE
instruction = "You are an expert in named entity recognition. Please extract entities that match the schema definition from the input. Return an empty list if the entity type does not exist. Please respond in the format of a JSON string."
input_text = "{\"input\":\"Mumtaz Mahal died in 1631 in Burhanpur, Deccan (present-day Madhya Pradesh) during the birth of her 14th child, a daughter named Gauhar Ara Begum.[20] Shah Jahan had the Taj Mahal built as a tomb for her, which is considered to be a monument of undying love. As with other Mughal royal ladies, no contemporary likenesses of her are accepted, but imagined portraits were created from the 19th century onwards. \",\"schema\":[\"location\",\"time\",\"person\",\"organization\"]}"
ctx = f'{cat_char}:{instruction}\n{input_text}\n{bot_char}:'
print(ctx)
# Rebuild the state list so the next generation starts from the tuned state again
states_value = []
for i in range(model.args.n_layer):
    key = f'blocks.{i}.att.time_state'
    value = states[key]
    prev_x = torch.zeros(model.args.n_embd,device=device,dtype=torch.float16)
    prev_states = value.clone().detach().to(device=device,dtype=torch.float16).transpose(1,2)
    prev_ffn = torch.zeros(model.args.n_embd,device=device,dtype=torch.float16)
    states_value.append(prev_x)
    states_value.append(prev_states)
    states_value.append(prev_ffn)
pipeline.generate(ctx, token_count=200, args=args, callback=my_print,state=states_value)
print('\n')
The output should look like:
🐱:You are an expert in named entity recognition. Please extract entities that match the schema definition from the input. Return an empty list if the entity type does not exist. Please respond in the format of a JSON string.
{"input":"Mumtaz Mahal died in 1631 in Burhanpur, Deccan (present-day Madhya Pradesh) during the birth of her 14th child, a daughter named Gauhar Ara Begum.[20] Shah Jahan had the Taj Mahal built as a tomb for her, which is considered to be a monument of undying love. As with other Mughal royal ladies, no contemporary likenesses of her are accepted, but imagined portraits were created from the 19th century onwards. ","schema":["location","time","person","organization"]}
🤖:
{"location": ["Burhanpur", "Deccan", "Madhya Pradesh"], "time": ["1631"], "person": ["Mumtaz Mahal", "Gauhar Ara Begum", "Shah Jahan"], "organization": []}
- Try a Chinese and English combination
instruction ='你是专门进行实体抽取的专家。请从input中抽取出符合schema定义的实体,不存在的实体类型返回空列表。请按照JSON字符串的格式回答。'
input_text = '{\"input\":\"马拉维共和国(英语:Republic of Malawi;齐切瓦语:Dziko la Malaŵi),通称马拉维(齐切瓦语:Malaŵi;英语:Malawi),是一个位于非洲东南部的内陆国家,邻接赞比亚、莫桑比克及坦桑尼亚。国土位于南纬9°45\'至17°16\'、东经32°35\'-35°24\'之间。\
其首都里朗威位于马拉维的中部。 \",\"schema\":[\"country\",\"person\",\"time\",\"毗邻国家\"]}'
ctx = f'{cat_char}:{instruction}\n{input_text}\n{bot_char}:'
print(ctx)
# Rebuild the state list so the next generation starts from the tuned state again
states_value = []
for i in range(model.args.n_layer):
    key = f'blocks.{i}.att.time_state'
    value = states[key]
    prev_x = torch.zeros(model.args.n_embd,device=device,dtype=torch.float16)
    prev_states = value.clone().detach().to(device=device,dtype=torch.float16).transpose(1,2)
    prev_ffn = torch.zeros(model.args.n_embd,device=device,dtype=torch.float16)
    states_value.append(prev_x)
    states_value.append(prev_states)
    states_value.append(prev_ffn)
pipeline.generate(ctx, token_count=200, args=args, callback=my_print,state=states_value)
print('\n')
The output looks like:
🐱:你是专门进行实体抽取的专家。请从input中抽取出符合schema定义的实体,不存在的实体类型返回空列表。请按照JSON字符串的格式回答。
{"input":"马拉维共和国(英语:Republic of Malawi;齐切瓦语:Dziko la Malaŵi),通称马拉维(齐切瓦语:Malaŵi;英语:Malawi),是一个位于非洲东南部的内陆国家,邻接赞比亚、莫桑比克及坦桑尼亚。国土位于南纬9°45'至17°16'、东经32°35'-35°24'之间。 其首都里朗威位于马拉维的中部。 ","schema":["country","person","time","毗邻国家"]}
🤖:
{"country": ["马拉维共和国", "马拉维", "齐切瓦语:Dziko la Malaŵi", "英语:Republic of Malawi", "Malawi"], "person": [], "time": [], "毗邻国家": ["赞比亚", "莫桑比克", "坦桑尼亚"]}