import torch

from models.base import HFModel


class Qwen(HFModel):

    def __init__(self, model_path):
        super().__init__(model_path)

    def generate(self, input_text, stop_words=[]):
        # Always stop at the chat-format end-of-turn token.
        im_end = '<|im_end|>'
        if im_end not in stop_words:
            stop_words = stop_words + [im_end]
        stop_words_ids = [self.tokenizer.encode(w) for w in stop_words]

        input_ids = torch.tensor([self.tokenizer.encode(input_text)
                                  ]).to(self.model.device)
        output = self.model.generate(input_ids, stop_words_ids=stop_words_ids)
        output = output.tolist()[0]
        output = self.tokenizer.decode(output, errors='ignore')

        # The model echoes the prompt; strip it and the special tokens
        # before returning the completion.
        assert output.startswith(input_text)
        output = output[len(input_text):].replace('<|endoftext|>',
                                                  '').replace(im_end, '')
        return output
class QwenVL(HFModel):

    def __init__(self, model_path):
        super().__init__(model_path)

    def generate(self, inputs: list):
        # `inputs` is a list of dicts such as [{'image': ...}, {'text': ...}];
        # from_list_format() packs them into a single multimodal query string.
        query = self.tokenizer.from_list_format(inputs)
        response, _ = self.model.chat(self.tokenizer, query=query, history=None)
        return response
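

# A minimal usage sketch, not part of the original module. It assumes that
# HFModel(model_path) loads `self.tokenizer` and `self.model` from the given
# checkpoint name/path, and that callers of Qwen.generate() pass a
# ChatML-style prompt (consistent with the '<|im_end|>' stop word above).
# The model paths and inputs below are placeholders.
if __name__ == '__main__':
    llm = Qwen('Qwen/Qwen-7B-Chat')
    prompt = '<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n'
    print(llm.generate(prompt))

    vlm = QwenVL('Qwen/Qwen-VL-Chat')
    print(vlm.generate([{'image': 'demo.jpeg'},
                        {'text': 'Describe this image.'}]))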