Edit model card

Usage

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
model_name = 'csdc-atl/doc2query'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
def create_queries(history, next_question):
    inputs_ids = []
    for line in history:
        inputs_ids.extend([32127]+tokenizer.encode(line[0], add_special_tokens=False)+[32126]+tokenizer.encode(line[1], add_special_tokens=False))
    inputs_ids.extend([32127]+tokenizer.encode(next_question, add_special_tokens=False))
    inputs_ids = inputs_ids + [1]
    inputs_ids = torch.Tensor([inputs_ids]).long()
    with torch.no_grad():
        sampling_outputs = model.generate(
            input_ids=inputs_ids,
            max_length=512,
            do_sample=True,
            top_p=0.95,
            top_k=10
            )
    print("\nSampling Outputs:")
    for i in range(len(sampling_outputs)):
        rewrite_question = tokenizer.decode(sampling_outputs[i], skip_special_tokens=True)
        print(f'{i + 1}: {rewrite_question}')
history = [['loghub是什么', 'AWS 上的loghub解决方案可帮助组织在单个控制面板上收集、分析和显示 Amazon CloudWatch Logs。该解决方案可整合、管理和分析来自各种来源的日志文件,例如访问、配置更改和计费事件的审计日志。您也可以从多个账户和 AWS 区域收集 Amazon CloudWatch Logs。']]
next_question = '它的优点是什么?'
create_queries(history, next_question)
# 1: loghub解决方案的优点是什么?
Downloads last month
33