dolly-v0-70m / app.py
import numpy as np
import pandas as pd
import requests
import os
import gradio as gr
import json
from dotenv import load_dotenv, find_dotenv
# Load the Databricks access token from a local .env file.
_ = load_dotenv(find_dotenv())
databricks_token = os.getenv('DATABRICKS_TOKEN')

# Databricks Model Serving endpoint for the hosted model.
model_uri = "https://dbc-eb788f31-6c73.cloud.databricks.com/serving-endpoints/Mpt-7b-tester/invocations"

def extract_json(gen_text, n_shot_learning=0):
    # The generated text follows the instruction template, so the JSON payload
    # starts right after a "### Response:" marker. The offset of 14
    # (= len("### Response:\n")) lands on the opening "{".
    start_index = gen_text.index("### Response:\n{") + 14
    if n_shot_learning > 0:
        # Skip the responses that belong to the few-shot examples so that only
        # the final (real) response is extracted.
        for i in range(0, n_shot_learning):
            gen_text = gen_text[start_index:]
            start_index = gen_text.index("### Response:\n{") + 14
    # The JSON object ends at the closing "}" before the next "###" section.
    end_index = gen_text.index("}\n\n### ") + 1
    return gen_text[start_index:end_index]
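
# Note: score_model() below references create_tf_serving_json(), which is not
# defined in the original file. The sketch below follows the standard helper
# from the Databricks model-serving example; with the DataFrame built inside
# score_model() it is never actually reached, but it keeps the name defined.
def create_tf_serving_json(data):
    # Convert a dict of arrays (or a bare array) into the TF-serving "inputs" format.
    return {'inputs': {name: data[name].tolist() for name in data.keys()} if isinstance(data, dict) else data.tolist()}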

def score_model(model_uri, databricks_token, prompt):
    # Wrap the prompt and generation parameters in a single-row DataFrame,
    # which is the input format the serving endpoint expects.
    dataset = pd.DataFrame({
        "prompt": [prompt],
        "temperature": [0.1],
        "max_tokens": [500]})
    headers = {
        "Authorization": f"Bearer {databricks_token}",
        "Content-Type": "application/json",
    }
    ds_dict = {'dataframe_split': dataset.to_dict(orient='split')} if isinstance(dataset, pd.DataFrame) else create_tf_serving_json(dataset)
    data_json = json.dumps(ds_dict, allow_nan=True)
    # Debug output of the request payload.
    print("***ds_dict: ")
    print(ds_dict)
    print("***data_json: ")
    print(data_json)
    response = requests.request(method='POST', headers=headers, url=model_uri, data=data_json)
    if response.status_code != 200:
        raise Exception(f"Request failed with status {response.status_code}, {response.text}")
    return response.json()

def get_completion(prompt):
    return score_model(model_uri, databricks_token, prompt)

def greet(input):
    response = get_completion(input)
    gen_text = response["predictions"][0]["generated_text"]
    # The hard-coded 3 skips three earlier "### Response:" blocks (the few-shot
    # examples embedded in the prompt) before extracting the final JSON.
    return extract_json(gen_text, 3)
    #return json.dumps(response)

iface = gr.Interface(fn=greet, inputs=[gr.Textbox(label="Prompt", lines=3)], outputs="json")
iface.launch()