File size: 2,515 Bytes
ba34bf3
 
 
 
 
 
 
 
 
 
 
7a8d345
ba34bf3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4949853
ba34bf3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e839ef
6f624b8
 
ba34bf3
 
 
f19fc86
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
import shutil

model_name = "minoD/JURAN"

# モデルのロード
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="cpu"
)

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

# プロンプトテンプレートの準備
def generate_prompt(F):
    # input キーの代わりに Q と F を使用
    result = f"""### 指示:あなたは企業の面接官です.就活生のエントリーシートを元に質問を行ってください.

### 質問:
{F}

### 回答:
""" # 回答セクションを追加
    # 改行→<NL>
    result = result.replace('\n', '<NL>')
    return result

# テキスト生成関数の定義
def generate2(F=None, maxTokens=256):
    # 推論
    prompt = generate_prompt(F)
    input_ids = tokenizer(prompt,
                          return_tensors="pt",
                          truncation=True,
                          add_special_tokens=False).input_ids
    outputs = model.generate(
        input_ids=input_ids,
        max_new_tokens=maxTokens,
        do_sample=True,
        temperature=0.7,
        top_p=0.75,
        top_k=40,
        no_repeat_ngram_size=2,
    )
    outputs = outputs[0].tolist()
    decoded = tokenizer.decode(outputs)

    # EOSトークンにヒットしたらデコード完了
    if tokenizer.eos_token_id in outputs:
        eos_index = outputs.index(tokenizer.eos_token_id)
        decoded = tokenizer.decode(outputs[:eos_index])

        # レスポンス内容のみ抽出
        sentinel = "### 回答:"
        sentinelLoc = decoded.find(sentinel)
        if sentinelLoc >= 0:
            result = decoded[sentinelLoc + len(sentinel):]
            return result.replace("<NL>", "\n")  # <NL>→改行
        else:
            return 'Warning: Expected prompt template to be emitted.  Ignoring output.'
    else:
       return 'Warning: no <eos> detected ignoring output'
    
def inference(input_text):
  return generate2(input_text)


iface = gr.Interface(
    fn=inference,
    inputs=gr.Textbox(lines=5, label="学生時代に打ち込んだこと、研究、ESを入力", placeholder="半導体の研究に打ち込んだ"),
    outputs=gr.Textbox(label="想定される質問"),
    title="JURAN🌺",
    description="面接官モデルが回答を生成します。",
    api_name="ask",
    allow_flagging="never"
)

if __name__ == "__main__":
     iface.launch(share=True)