# phi2v2 / phi_predict.py
# Uploaded by NobodyExistsOnTheInternet ("Upload 22 files", commit 859da2a).
# Hugging Face file page: 600 bytes, security scan reported no virus.
import pandas as pd
def predict(data, task, model, tokenizer, config, **kwargs):
    """Generate a text completion for each input prompt.

    Each prompt is tokenized on its own (``return_tensors="pt"``, no attention
    mask) and passed to ``model.generate``. The first decoded sequence of each
    result is collected.

    Args:
        data: Either a ``pandas.DataFrame``, in which case only its first
            column is used as the prompts, or an iterable of prompts.
        task: Unused here; accepted for interface compatibility.
        model: Object exposing ``generate(**inputs, **gen_kwargs)``.
            Presumably a Hugging Face ``transformers`` model (TODO confirm).
        tokenizer: Callable returning a mapping of model inputs and exposing
            ``batch_decode``. Presumably the tokenizer matching ``model``.
        config: Unused here; accepted for interface compatibility.
        **kwargs: ``addn_args`` (dict or None, optional) holds extra keyword
            arguments forwarded to ``model.generate``. Its entries override
            the default ``max_length=50``.

    Returns:
        A DataFrame with a single ``output`` column when ``data`` was a
        DataFrame, otherwise a dict ``{"output": [str, ...]}``.
    """
    # Decide the output format up front. The flag used to be set only in the
    # DataFrame branch, so list input crashed with UnboundLocalError.
    is_df = isinstance(data, pd.DataFrame)
    if is_df:
        data = data[data.columns[0]].tolist()

    # Merge caller options over the default. Passing max_length both
    # explicitly and via **addn_args would raise TypeError.
    gen_kwargs = {"max_length": 50, **(kwargs.get("addn_args") or {})}

    results = []
    for prompt in data:
        inputs = tokenizer(prompt, return_tensors="pt", return_attention_mask=False)
        outputs = model.generate(**inputs, **gen_kwargs)
        # Only the first decoded sequence is kept for each prompt.
        results.append(tokenizer.batch_decode(outputs)[0])

    if is_df:
        return pd.DataFrame(results, columns=["output"])
    return {"output": results}