dolly-v0-70m / app.py
debisoft's picture
t
e6296d8
raw
history blame
1.69 kB
import numpy as np
import pandas as pd
import requests
import os
import gradio as gr
import json
from dotenv import load_dotenv, find_dotenv
# Load environment variables from a .env file if one exists (no-op otherwise).
_ = load_dotenv(find_dotenv())
# Bearer token for the Databricks serving endpoint; None if unset —
# requests below will then fail with 401/403 rather than at startup.
databricks_token = os.getenv('DATABRICKS_TOKEN')
# Databricks model-serving invocation URL for the deployed MPT-7B endpoint.
model_uri = "https://dbc-eb788f31-6c73.cloud.databricks.com/serving-endpoints/Mpt-7b-tester/invocations"
def extract_json(gen_text):
    """Extract the JSON object embedded in a model completion.

    Expects *gen_text* to contain a section of the form::

        ### Response:
        { ... }

        ### End

    and returns the substring from the opening ``{`` through the closing
    ``}`` (inclusive).

    Raises:
        ValueError: if either marker is missing (propagated from str.index).
    """
    # Compute the offset instead of the original hard-coded magic number 14,
    # which was silently len("### Response:\n").
    response_marker = "### Response:\n"
    start_index = gen_text.index(response_marker + "{") + len(response_marker)
    # +1 so the closing brace itself is included in the slice.
    end_index = gen_text.index("}\n\n### End") + 1
    return gen_text[start_index:end_index]
def score_model(model_uri, databricks_token, prompt):
    """POST *prompt* to a Databricks model-serving endpoint and return its JSON reply.

    Args:
        model_uri: Full invocation URL of the serving endpoint.
        databricks_token: Bearer token used for authentication.
        prompt: Prompt text sent as a single-row dataframe payload.

    Returns:
        The decoded JSON response body (dict) from the endpoint.

    Raises:
        Exception: if the endpoint responds with a non-200 status code.
    """
    # Single-row dataframe carrying the prompt plus fixed sampling parameters.
    dataset = pd.DataFrame({
        "prompt": [prompt],
        "temperature": [0.5],
        "max_tokens": [500]})
    headers = {
        "Authorization": f"Bearer {databricks_token}",
        "Content-Type": "application/json",
    }
    # NOTE(review): the original had a dead `else create_tf_serving_json(dataset)`
    # branch referencing an undefined helper; `dataset` is always a DataFrame
    # here, so the branch was unreachable and has been removed.
    ds_dict = {'dataframe_split': dataset.to_dict(orient='split')}
    # allow_nan=True preserves the original serialization behavior for NaN values.
    data_json = json.dumps(ds_dict, allow_nan=True)
    response = requests.post(model_uri, headers=headers, data=data_json)
    if response.status_code != 200:
        raise Exception(f"Request failed with status {response.status_code}, {response.text}")
    return response.json()
def get_completion(prompt):
    """Send *prompt* to the configured Databricks endpoint and return the raw JSON response."""
    result = score_model(model_uri, databricks_token, prompt)
    return result
def greet(input):
    """Gradio handler: run the prompt through the model and return the extracted JSON.

    Args:
        input: User-supplied prompt text from the Gradio textbox.
               (NOTE(review): the name shadows the builtin ``input``; kept
               unchanged to preserve the public interface.)

    Returns:
        The JSON object string extracted from the model's generated text.
    """
    response = get_completion(input)
    # Assumes the endpoint returns {"predictions": [{"generated_text": ...}]} —
    # TODO confirm against the actual serving-endpoint response schema.
    gen_text = response["predictions"][0]["generated_text"]
    # Removed commented-out debugging alternative (`return json.dumps(response)`).
    return extract_json(gen_text)
# Wire the handler to a simple one-textbox-in / text-out Gradio UI and start it.
iface = gr.Interface(fn=greet, inputs=[gr.Textbox(label="Prompt", lines=3)], outputs="text")
iface.launch()