---
inference:
  parameters:
    max_length: 512
    temperature: 0.7
    top_p: 1
widget:
  - text: >-
      <extra_id_0> loghub是什么<extra_id_1> AWS 上的loghub解决方案可帮助组织在单个控制面板上收集、分析和显示
      Amazon CloudWatch
      Logs。该解决方案可整合、管理和分析来自各种来源的日志文件,例如访问、配置更改和计费事件的审计日志。您也可以从多个账户和 AWS 区域收集
      Amazon CloudWatch Logs。<extra_id_0> 它的优点是什么?
  - text: <extra_id_0> 基督山伯爵讲的什么故事<extra_id_1> 电影版的基督山伯爵里面的台词太经典了<extra_id_0> 是呢剧情是啥
  - text: <extra_id_0> 你知道明朝那些事儿吗<extra_id_1> 有趣有趣寓教于乐的典型相当不错啊<extra_id_0> 它都讲了什么故事呀
language:
  - en
  - zh
---

# dialogue-rewriter

## Usage
```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

model_name = 'csdc-atl/doc2query'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

def create_queries(history, next_question):
    # Token ids 32127 and 32126 are the <extra_id_0> and <extra_id_1>
    # sentinel tokens, which mark user and assistant turns respectively
    # (matching the widget examples above).
    input_ids = []
    for user_turn, assistant_turn in history:
        input_ids.extend(
            [32127] + tokenizer.encode(user_turn, add_special_tokens=False)
            + [32126] + tokenizer.encode(assistant_turn, add_special_tokens=False)
        )
    input_ids.extend([32127] + tokenizer.encode(next_question, add_special_tokens=False))
    input_ids.append(1)  # append the EOS token id
    input_ids = torch.tensor([input_ids], dtype=torch.long)
    with torch.no_grad():
        sampling_outputs = model.generate(
            input_ids=input_ids,
            max_length=512,
            do_sample=True,
            top_p=0.95,
            top_k=10,
        )
    print("\nSampling Outputs:")
    for i, output in enumerate(sampling_outputs):
        rewritten_question = tokenizer.decode(output, skip_special_tokens=True)
        print(f'{i + 1}: {rewritten_question}')

history = [['loghub是什么', 'AWS 上的loghub解决方案可帮助组织在单个控制面板上收集、分析和显示 Amazon CloudWatch Logs。该解决方案可整合、管理和分析来自各种来源的日志文件,例如访问、配置更改和计费事件的审计日志。您也可以从多个账户和 AWS 区域收集 Amazon CloudWatch Logs。']]
next_question = '它的优点是什么?'
create_queries(history, next_question)
# 1: loghub解决方案的优点是什么?
#    ("What are the advantages of the loghub solution?")
```
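
The hard-coded ids 32127 and 32126 above correspond to the `<extra_id_0>` and `<extra_id_1>` sentinel strings shown in the widget examples. As a minimal sketch (the helper name `build_prompt` is my own, not part of this model's API), the same input can be assembled as a plain string and tokenized in one call, which avoids the magic numbers — assuming the tokenizer maps those sentinels to the same ids:

```python
def build_prompt(history, next_question):
    # <extra_id_0> precedes each user turn and <extra_id_1> each assistant
    # turn, mirroring the widget examples in the metadata above.
    parts = []
    for user_turn, assistant_turn in history:
        parts.append(f"<extra_id_0> {user_turn}<extra_id_1> {assistant_turn}")
    parts.append(f"<extra_id_0> {next_question}")
    return "".join(parts)

prompt = build_prompt(
    [['loghub是什么', 'AWS 上的loghub解决方案可帮助组织收集、分析和显示 Amazon CloudWatch Logs。']],
    '它的优点是什么?',
)
# tokenizer(prompt, return_tensors='pt').input_ids can then be passed to
# model.generate with the same sampling arguments as above.
```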