|
import gradio as gr |
|
import requests |
|
import time |
|
from ast import literal_eval |
|
from datetime import datetime |
|
|
|
def to_md(text):
    """Render plain text for HTML display by turning each newline into a ``<br />`` tag."""
    return "<br />".join(text.split("\n"))
|
|
|
def infer(
    prompt,
    model_name,
    max_new_tokens=10,
    temperature=0.1,
    top_p=1.0,
    top_k=40,
    num_completions=1,
    seed=42,
    stop="\n"
):
    """Query the Together staging inference API and return the generated text.

    Parameters
    ----------
    prompt : str
        Prompt text. An empty prompt is replaced with a single space so the
        API always receives non-empty input.
    model_name : str
        UI-facing model name; translated to the API model id via
        ``model_name_map``. Unknown names fall back to the only deployed model.
    max_new_tokens : int
        Number of tokens to generate, in [1, 256].
    temperature : float
        Sampling temperature in [0.0, 10.0]; 0.0 is bumped to 0.01 because the
        backend does not accept exactly zero.
    top_p : float
        Nucleus-sampling probability mass, in [0.0, 1.0].
    top_k : int
        Top-k sampling cutoff, in [1, 1000].
    num_completions : int
        Accepted and range-checked, but NOTE: not forwarded to the API by this
        implementation.
    seed : int
        Accepted for interface compatibility, but NOTE: not forwarded to the API.
    stop : str
        Semicolon-separated list of stop sequences, e.g. ``"\\n;###"``.

    Returns
    -------
    str
        The first completion's text, truncated at the earliest stop sequence.

    Raises
    ------
    ValueError
        If any sampling parameter is outside its accepted range.
    """
    # Map UI model names to the API's model identifiers.
    model_name_map = {
        "GPT-JT-6B-v1": "Together-gpt-JT-6B-v1",
    }

    # Normalize types: Gradio sliders deliver floats even for integer fields.
    max_new_tokens = int(max_new_tokens)
    num_completions = int(num_completions)
    temperature = float(temperature)
    top_p = float(top_p)
    top_k = int(top_k)
    stop = stop.split(";")

    # Validate with explicit exceptions rather than `assert`, which is
    # silently stripped under `python -O`.
    if not 1 <= max_new_tokens <= 256:
        raise ValueError("max_new_tokens must be in [1, 256]")
    if not 1 <= num_completions <= 5:
        raise ValueError("num_completions must be in [1, 5]")
    if not 0.0 <= temperature <= 10.0:
        raise ValueError("temperature must be in [0.0, 10.0]")
    if not 0.0 <= top_p <= 1.0:
        raise ValueError("top_p must be in [0.0, 1.0]")
    if not 1 <= top_k <= 1000:
        raise ValueError("top_k must be in [1, 1000]")

    # The backend rejects a temperature of exactly zero; use a tiny value
    # to approximate greedy decoding.
    if temperature == 0.0:
        temperature = 0.01
    # The backend rejects empty prompts.
    if prompt == "":
        prompt = " "

    my_post_dict = {
        # Use the selected model; fall back to the default deployment when the
        # name is unknown (or None, if the UI dropdown had no selection).
        "model": model_name_map.get(model_name, "Together-gpt-JT-6B-v1"),
        "prompt": prompt,
        "top_p": top_p,
        "top_k": top_k,
        "temperature": temperature,
        "max_tokens": max_new_tokens,
        "stop": stop,
    }

    print(f"send: {datetime.now()}")
    # A timeout prevents the UI from hanging forever if the endpoint stalls.
    response = requests.get(
        "https://staging.together.xyz/api/inference",
        params=my_post_dict,
        timeout=60,
    ).json()
    generated_text = response['output']['choices'][0]['text']
    print(f"recv: {datetime.now()}")

    # Truncate at the earliest occurrence of any stop sequence.
    for stop_word in stop:
        if stop_word != '' and stop_word in generated_text:
            generated_text = generated_text[:generated_text.find(stop_word)]

    return generated_text
|
|
|
def main():
    """Build and launch the Gradio demo UI wrapping :func:`infer`.

    Blocks in ``launch(debug=True)`` until the server is stopped.
    """
    iface = gr.Interface(
        fn=infer,
        inputs=[
            gr.Textbox(lines=20, label="Prompt"),
            # Provide a default selection so infer() never receives None.
            gr.Dropdown(["GPT-JT-6B-v1"], value="GPT-JT-6B-v1", label="Model"),
            # Capped at 256: infer() rejects max_new_tokens > 256, so a larger
            # slider range (previously 1000) let users trigger an error.
            gr.Slider(10, 256, value=200, label="Max new tokens"),
            # NOTE(review): upper bound 0.1 looks like a typo for 1.0
            # (infer() accepts up to 10.0) — confirm with the author.
            gr.Slider(0.0, 0.1, value=0.1, label="Temperature"),
            gr.Slider(0.0, 1.0, value=1.0, label="Top-p"),
            gr.Slider(0, 100, value=40, label="Top-k"),
        ],
        outputs=gr.Textbox(lines=7, label="Generated text"),
    )

    iface.launch(debug=True)
|
|
|
# Launch the demo only when executed as a script (no-op on import).
if __name__ == '__main__':

    main()