Text Generation
Transformers
Safetensors
Finnish
llama
finnish
conversational
text-generation-inference
Ahma-7B / EasyLM /serving.py
aapot
Add training codes
a85f909
raw
history blame
20.5 kB
import dataclasses
import pprint
from functools import partial
import re
import os
from threading import Lock
import urllib
import time
from typing import List, Optional, Union
from pydantic import BaseModel
import absl.logging
from tqdm import tqdm, trange
import numpy as np
import mlxu
from ml_collections import ConfigDict
import uvicorn
from fastapi import FastAPI
import gradio as gr
import requests
from requests.exceptions import Timeout, ConnectionError
class InferenceRequest(BaseModel):
    """Request body shared by the /loglikelihood, /loglikelihood-rolling,
    /generate and /greedy-until endpoints.

    Every field is optional in the schema; each endpoint reads only the
    fields it uses.
    """
    # Conditioning prefixes, one per example.
    prefix_text: Union[List[str], None] = None
    # Texts to score, one per example (log-likelihood endpoints).
    text: Union[List[str], None] = None
    # Stop sequences for /greedy-until: a single string or a list of strings
    # per example.
    until: Union[List[str], List[List[str]], None] = None
    # Sampling temperature for /generate; when omitted the server falls back
    # to its configured default temperature.
    temperature: Union[float, None] = None
class ChatRequest(BaseModel):
    """Request body for the /chat endpoint."""
    # The new user message.
    prompt: str
    # Conversation so far, as returned in the `context` field of a previous
    # /chat response; empty for a fresh conversation.
    context: str = ""
    # Sampling temperature; None means the server's default temperature.
    temperature: Union[float, None] = None
class LMServer(object):
    """ HTTP server for serving language models.

    Subclasses supply the model by overriding the static hooks
    ``loglikelihood``, ``loglikelihood_rolling``, ``generate`` and
    ``greedy_until`` (the base versions raise ``NotImplementedError``).
    ``__init__`` exposes them as FastAPI endpoints and mounts a Gradio chat
    UI at ``/``; ``run`` optionally warms the hooks up and then starts the
    server with uvicorn.
    """

    @staticmethod
    def get_default_config(updates=None):
        """Return the default server ``ConfigDict``, optionally updated.

        Args:
            updates: optional mapping/ConfigDict whose entries override the
                defaults set here.
        """
        config = ConfigDict()
        config.host = '0.0.0.0'
        config.port = 5007
        # Examples per model call; loglikelihood and generate batches are
        # padded up to exactly this size.
        config.batch_size = 1
        config.logging = False
        # '' (none), 'all', or a comma-separated subset of 'loglikelihood',
        # 'generate', 'greedy_until', 'chat': hooks called once with dummy
        # inputs in `run` before the server starts.
        config.pre_compile = 'loglikelihood'
        config.default_temperature = 1.0
        config.greedy_until_max_length = 5000
        # Strings wrapped around every prefix / text before it reaches the model.
        config.prepend_to_prefix = ''
        config.append_to_prefix = ''
        config.prepend_to_text = ''
        config.append_to_text = ''
        # Chat template pieces; see `process_chat`.
        config.chat_prepend_text = ''
        config.chat_user_prefix = ''
        config.chat_user_suffix = ''
        config.chat_lm_prefix = ''
        config.chat_lm_suffix = ''
        # Markdown shown under the title of the Gradio chat UI.
        config.notes = ''
        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, config):
        self.config = self.get_default_config(config)
        # Serializes model calls: every serving path (HTTP endpoints and the
        # Gradio chat UI) acquires it around the model hooks.
        self.lock = Lock()
        self.app = FastAPI()
        self.app.post('/loglikelihood')(self.serve_loglikelihood)
        self.app.post('/loglikelihood-rolling')(self.serve_loglikelihood_rolling)
        self.app.post('/generate')(self.serve_generate)
        self.app.post('/greedy-until')(self.serve_greedy_until)
        self.app.post('/chat')(self.serve_chat)
        self.app.get('/ready')(self.serve_ready)
        self.app = gr.mount_gradio_app(self.app, self.create_chat_app(), '/')

    @staticmethod
    def loglikelihood(prefix_text, text):
        """Model hook: score a batch of ``text`` conditioned on ``prefix_text``.

        Must return ``(log_likelihood, is_greedy)``, each a per-example
        sequence (a list, or an array exposing ``tolist``).
        """
        raise NotImplementedError()

    @staticmethod
    def loglikelihood_rolling(text):
        """Model hook: score a batch of ``text`` on its own.

        Must return ``(log_likelihood, is_greedy)`` like ``loglikelihood``.
        """
        raise NotImplementedError()

    @staticmethod
    def generate(text, temperature):
        """Model hook: sample one continuation per prefix in ``text``.

        Must return a per-example sequence of generated strings.
        """
        raise NotImplementedError()

    @staticmethod
    def greedy_until(prefix_text, until, max_length):
        """Model hook: greedily continue each prefix until a stop sequence in
        ``until`` or ``max_length``.

        Must return a per-example sequence of generated strings.
        """
        raise NotImplementedError()

    @staticmethod
    def to_list(x):
        """Convert an array-like hook result into a plain Python list.

        Anything with a callable ``tolist`` (NumPy arrays, and likewise JAX or
        PyTorch arrays) is converted so the result is JSON-serializable;
        other values are returned unchanged.
        """
        tolist = getattr(x, 'tolist', None)
        if callable(tolist):
            return tolist()
        return x

    def _log(self, title, payload):
        """Log ``payload`` under a banner when ``config.logging`` is enabled."""
        if self.config.logging:
            absl.logging.info(
                f'\n========= {title} ========= \n'
                + pprint.pformat(payload) + '\n'
            )

    def _wrap_prefix(self, prefixes):
        """Surround each prefix with the configured prepend/append strings."""
        return [
            self.config.prepend_to_prefix + p + self.config.append_to_prefix
            for p in prefixes
        ]

    def _wrap_text(self, texts):
        """Surround each text with the configured prepend/append strings."""
        return [
            self.config.prepend_to_text + t + self.config.append_to_text
            for t in texts
        ]

    def _pad_batch(self, batch):
        """Pad ``batch`` in place up to ``config.batch_size`` with dummy 'a' entries.

        Returns the number of real entries so outputs produced for the
        padding can be sliced off.
        """
        real_size = len(batch)
        # A negative count yields an empty range, i.e. no padding.
        batch.extend('a' for _ in range(self.config.batch_size - real_size))
        return real_size

    def serve_ready(self):
        """GET /ready: liveness probe used by ``LMClient.wait_for_ready``."""
        return 'Ready!\n'

    def serve_loglikelihood(self, data: InferenceRequest):
        """POST /loglikelihood: score ``data.text`` given ``data.prefix_text``.

        Missing prefixes default to empty strings.

        Returns:
            dict echoing ``prefix_text`` and ``text`` together with
            per-example ``log_likelihood`` and ``is_greedy`` lists.
        """
        with self.lock:
            self._log('Serving Log Likelihood Request', data)
            if data.prefix_text is None:
                data.prefix_text = ['' for _ in data.text]
            prefix_text = self._wrap_prefix(data.prefix_text)
            text = self._wrap_text(data.text)

            log_likelihood = []
            is_greedy = []
            for i in trange(0, len(text), self.config.batch_size, ncols=0):
                batch_prefix_text = prefix_text[i:i + self.config.batch_size]
                batch_text = text[i:i + self.config.batch_size]
                batch_size = self._pad_batch(batch_text)
                self._pad_batch(batch_prefix_text)
                batch_log_likelihood, batch_is_greedy = self.loglikelihood(
                    batch_prefix_text, batch_text
                )
                log_likelihood.extend(self.to_list(batch_log_likelihood)[:batch_size])
                is_greedy.extend(self.to_list(batch_is_greedy)[:batch_size])

            output = {
                'prefix_text': data.prefix_text,
                'text': data.text,
                'log_likelihood': log_likelihood,
                'is_greedy': is_greedy,
            }
            self._log('Output', output)
            return output

    def serve_loglikelihood_rolling(self, data: InferenceRequest):
        """POST /loglikelihood-rolling: score each ``data.text`` on its own.

        Returns:
            dict echoing ``text`` together with per-example ``log_likelihood``
            and ``is_greedy`` lists.
        """
        with self.lock:
            self._log('Serving Log Likelihood Rolling Request', data)
            text = self._wrap_text(data.text)

            log_likelihood = []
            is_greedy = []
            for i in trange(0, len(text), self.config.batch_size, ncols=0):
                batch_text = text[i:i + self.config.batch_size]
                batch_size = self._pad_batch(batch_text)
                batch_log_likelihood, batch_is_greedy = self.loglikelihood_rolling(
                    batch_text
                )
                log_likelihood.extend(self.to_list(batch_log_likelihood)[:batch_size])
                is_greedy.extend(self.to_list(batch_is_greedy)[:batch_size])

            output = {
                'text': data.text,
                'log_likelihood': log_likelihood,
                'is_greedy': is_greedy,
            }
            self._log('Output', output)
            return output

    def serve_generate(self, data: InferenceRequest):
        """POST /generate: sample a continuation for each ``data.prefix_text``.

        A missing temperature is replaced by ``config.default_temperature``.

        Returns:
            dict echoing ``prefix_text`` and the temperature used, plus the
            per-example ``output_text`` list.
        """
        with self.lock:
            self._log('Serving Generate Request', data)
            prefix_text = self._wrap_prefix(data.prefix_text)
            if data.temperature is None:
                data.temperature = self.config.default_temperature

            output_text = []
            for i in trange(0, len(prefix_text), self.config.batch_size, ncols=0):
                batch_prefix_text = prefix_text[i:i + self.config.batch_size]
                batch_size = self._pad_batch(batch_prefix_text)
                batch_output_text = self.generate(
                    batch_prefix_text,
                    temperature=data.temperature,
                )
                output_text.extend(self.to_list(batch_output_text)[:batch_size])

            output = {
                'prefix_text': data.prefix_text,
                'output_text': output_text,
                'temperature': data.temperature,
            }
            self._log('Output', output)
            return output

    def serve_greedy_until(self, data: InferenceRequest):
        """POST /greedy-until: greedy continuation of each prefix until its
        stop sequence(s) in ``data.until`` or ``config.greedy_until_max_length``.

        Unlike the other handlers, batches are not padded to
        ``config.batch_size``.

        Returns:
            dict echoing ``prefix_text``, ``until`` and ``max_length`` plus
            the per-example ``output_text`` list.
        """
        with self.lock:
            self._log('Serving Greedy Until Request', data)
            prefix_text = self._wrap_prefix(data.prefix_text)
            until = data.until
            max_length = self.config.greedy_until_max_length

            output_text = []
            for i in range(0, len(prefix_text), self.config.batch_size):
                batch_prefix_text = prefix_text[i:i + self.config.batch_size]
                batch_until = until[i:i + self.config.batch_size]
                batch_size = len(batch_prefix_text)
                batch_output_text = self.greedy_until(
                    batch_prefix_text, batch_until, max_length
                )
                output_text.extend(self.to_list(batch_output_text)[:batch_size])

            output = {
                'prefix_text': data.prefix_text,
                'until': data.until,
                'max_length': max_length,
                'output_text': output_text,
            }
            self._log('Output', output)
            return output

    def process_chat(self, prompt, context, temperature):
        """Run one chat turn.

        Appends ``prompt`` to ``context`` using the configured chat template,
        generates the reply and returns ``(response, new_context)``.
        ``config.chat_prepend_text`` is prepended only to the model input, not
        to the returned context.
        """
        context = (
            context + self.config.chat_user_prefix
            + prompt + self.config.chat_user_suffix
            + self.config.chat_lm_prefix
        )
        # Hold the model lock like every other serving path, so /chat and the
        # Gradio UI cannot run the model concurrently with the HTTP endpoints.
        with self.lock:
            response = self.generate(
                [self.config.chat_prepend_text + context],
                temperature=float(temperature),
            )[0]
        context = context + response + self.config.chat_lm_suffix
        return response, context

    def serve_chat(self, data: ChatRequest):
        """POST /chat: answer ``data.prompt`` given the prior ``data.context``.

        Returns:
            dict with the model ``response``, the updated ``context`` to send
            back on the next turn, and the temperature used.
        """
        if data.temperature is None:
            data.temperature = self.config.default_temperature
        response, context = self.process_chat(
            data.prompt, data.context,
            temperature=data.temperature,
        )
        return {
            'response': response,
            'context': context,
            'temperature': data.temperature,
        }

    def create_chat_app(self):
        """Build the Gradio Blocks chat UI mounted at ``/``.

        ``context_state`` holds ``[context before the pending turn, latest
        context]``, so "Regenerate" can re-run the last turn from the context
        that preceded it.
        """
        with gr.Blocks(analytics_enabled=False, title='EasyLM Chat') as gradio_chatbot:
            gr.Markdown('# Chatbot Powered by [EasyLM](https://github.com/young-geng/EasyLM)')
            gr.Markdown(self.config.notes)
            chatbot = gr.Chatbot(label='Chat history')
            msg = gr.Textbox(
                placeholder='Type your message here...',
                show_label=False
            )
            with gr.Row():
                send = gr.Button('Send')
                regenerate = gr.Button('Regenerate', interactive=False)
                clear = gr.Button('Reset')

            temp_slider = gr.Slider(
                label='Temperature', minimum=0, maximum=2.0,
                value=self.config.default_temperature
            )

            context_state = gr.State(['', ''])

            def user_fn(user_message, history, context):
                # Disable the inputs, show the message with a pending reply and
                # make the latest context the base for this turn.
                return {
                    msg: gr.update(value='', interactive=False),
                    clear: gr.update(interactive=False),
                    send: gr.update(interactive=False),
                    regenerate: gr.update(interactive=False),
                    chatbot: history + [[user_message, None]],
                    context_state: [context[1], context[1]],
                }

            def model_fn(history, context, temperature):
                # Generate from the base context and keep that base so the
                # turn can be regenerated.
                history[-1][1], new_context = self.process_chat(
                    history[-1][0], context[0], temperature
                )
                return {
                    msg: gr.update(value='', interactive=True),
                    clear: gr.update(interactive=True),
                    send: gr.update(interactive=True),
                    chatbot: history,
                    context_state: [context[0], new_context],
                    regenerate: gr.update(interactive=True),
                }

            def regenerate_fn():
                # Disable the inputs while the last turn is re-generated.
                return {
                    msg: gr.update(value='', interactive=False),
                    clear: gr.update(interactive=False),
                    send: gr.update(interactive=False),
                    regenerate: gr.update(interactive=False),
                }

            def clear_fn():
                # Reset the history and both context slots.
                return {
                    chatbot: None,
                    msg: '',
                    context_state: ['', ''],
                    regenerate: gr.update(interactive=False),
                }

            msg.submit(
                user_fn,
                inputs=[msg, chatbot, context_state],
                outputs=[msg, clear, send, chatbot, context_state, regenerate],
                queue=False
            ).then(
                model_fn,
                inputs=[chatbot, context_state, temp_slider],
                outputs=[msg, clear, send, chatbot, context_state, regenerate],
                queue=True
            )

            send.click(
                user_fn,
                inputs=[msg, chatbot, context_state],
                outputs=[msg, clear, send, chatbot, context_state, regenerate],
                queue=False
            ).then(
                model_fn,
                inputs=[chatbot, context_state, temp_slider],
                outputs=[msg, clear, send, chatbot, context_state, regenerate],
                queue=True
            )

            regenerate.click(
                regenerate_fn,
                inputs=None,
                outputs=[msg, clear, send, regenerate],
                queue=False
            ).then(
                model_fn,
                inputs=[chatbot, context_state, temp_slider],
                outputs=[msg, clear, send, chatbot, context_state, regenerate],
                queue=True
            )

            clear.click(
                clear_fn,
                inputs=None,
                outputs=[chatbot, msg, context_state, regenerate],
                queue=False
            )

        gradio_chatbot.queue(concurrency_count=1)
        return gradio_chatbot

    def run(self):
        """Call the configured warm-up hooks, then serve ``self.app`` with uvicorn.

        Raises:
            ValueError: if ``config.pre_compile`` names an unknown task.
        """
        if self.config.pre_compile != '':
            if self.config.pre_compile == 'all':
                pre_compile = ['loglikelihood', 'generate', 'greedy_until', 'chat']
            else:
                # Tolerate spaces and empty entries, e.g. 'generate, chat,'.
                pre_compile = [
                    task.strip() for task in self.config.pre_compile.split(',')
                    if task.strip()
                ]

            pre_compile_data = ['a' for _ in range(self.config.batch_size)]
            for task in pre_compile:
                if task == 'loglikelihood':
                    self.loglikelihood(pre_compile_data, pre_compile_data)
                    self.loglikelihood_rolling(pre_compile_data)
                elif task == 'generate':
                    self.generate(pre_compile_data, 1.0)
                elif task == 'greedy_until':
                    self.greedy_until(
                        pre_compile_data, pre_compile_data,
                        self.config.greedy_until_max_length
                    )
                elif task == 'chat':
                    self.process_chat('a', 'a', 1.0)
                else:
                    raise ValueError(f'Invalid precompile task: {task}!')

        uvicorn.run(self.app, host=self.config.host, port=self.config.port)
class LMClient(object):
    """ A simple HTTP client for the LM server (``LMServer``).

    Each method splits its inputs into batches of ``config.batch_size``,
    posts them to the matching server endpoint and concatenates the results.
    With ``config.dummy`` set, no requests are made and placeholder results
    are returned instead.
    """

    @staticmethod
    def get_default_config(updates=None):
        """Return the default client ``ConfigDict``, optionally updated.

        Args:
            updates: optional mapping/ConfigDict whose entries override the
                defaults set here.
        """
        config = ConfigDict()
        config.url = 'http://localhost:5007'
        config.batch_size = 1
        config.wait_for_ready = True
        config.dummy = False
        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, config=None):
        self.config = self.get_default_config(config)
        if self.config.wait_for_ready:
            self.wait_for_ready()

    def _url(self, endpoint):
        """Return the absolute URL of ``endpoint`` on the configured server."""
        # `import urllib` alone does not guarantee the `parse` submodule is loaded.
        import urllib.parse
        base = self.config.url
        # urljoin replaces the last path segment unless the base ends with '/',
        # which would drop a path prefix such as 'http://host/api'.
        if not base.endswith('/'):
            base += '/'
        return urllib.parse.urljoin(base, endpoint)

    def _post(self, endpoint, payload):
        """POST ``payload`` as JSON to ``endpoint`` and return the decoded reply.

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        response = requests.post(self._url(endpoint), json=payload)
        # Surface server errors directly instead of as a KeyError on the
        # error body further down.
        response.raise_for_status()
        return response.json()

    @staticmethod
    def _check_same_length(name_a, seq_a, name_b, seq_b):
        """Raise ValueError unless two parallel input lists have equal length."""
        if len(seq_a) != len(seq_b):
            raise ValueError(
                f'{name_a} and {name_b} must have the same length, '
                f'got {len(seq_a)} and {len(seq_b)}'
            )

    def wait_for_ready(self):
        """Block until the server's /ready endpoint answers.

        Retries every 10 seconds on connection errors and timeouts; returns
        immediately in dummy mode.
        """
        if self.config.dummy:
            return
        while True:
            try:
                # requests waits indefinitely without a timeout, in which case
                # the Timeout handler below could never fire.
                requests.get(self._url('ready'), timeout=10)
                return
            except (Timeout, ConnectionError):
                time.sleep(10)

    @staticmethod
    def batched(iterator, batch_size):
        """Yield consecutive lists of up to ``batch_size`` items from ``iterator``.

        The final batch may be shorter; nothing is yielded for empty input.
        """
        batch = []
        for example in iterator:
            batch.append(example)
            if len(batch) == batch_size:
                yield batch
                batch = []
        if len(batch) > 0:
            yield batch

    def loglikelihood(self, prefix, text):
        """Return per-example log likelihoods of ``text`` given ``prefix``.

        Args:
            prefix: iterable of conditioning strings.
            text: iterable of strings to score, parallel to ``prefix``.

        Returns:
            ``(log_likelihood, is_greedy)`` lists from the /loglikelihood
            endpoint (``-1.0`` / ``False`` placeholders in dummy mode).

        Raises:
            ValueError: if ``prefix`` and ``text`` differ in length.
        """
        prefix, text = list(prefix), list(text)
        self._check_same_length('prefix', prefix, 'text', text)
        if self.config.dummy:
            return [-1.0 for _ in text], [False for _ in text]
        log_likelihood = []
        is_greedy = []
        batched_iterator = list(zip(
            self.batched(prefix, self.config.batch_size),
            self.batched(text, self.config.batch_size),
        ))
        for batch_prefix, batch_text in tqdm(batched_iterator, ncols=0):
            response = self._post(
                'loglikelihood',
                {'prefix_text': batch_prefix, 'text': batch_text},
            )
            log_likelihood.extend(response['log_likelihood'])
            is_greedy.extend(response['is_greedy'])
        return log_likelihood, is_greedy

    def loglikelihood_rolling(self, text):
        """Return per-example log likelihoods of each ``text`` on its own.

        Returns:
            ``(log_likelihood, is_greedy)`` lists from the
            /loglikelihood-rolling endpoint (placeholders in dummy mode).
        """
        text = list(text)
        if self.config.dummy:
            return [-1.0 for _ in text], [False for _ in text]
        log_likelihood = []
        is_greedy = []
        batched_iterator = list(self.batched(text, self.config.batch_size))
        for batch_text in tqdm(batched_iterator, ncols=0):
            response = self._post('loglikelihood-rolling', {'text': batch_text})
            log_likelihood.extend(response['log_likelihood'])
            is_greedy.extend(response['is_greedy'])
        return log_likelihood, is_greedy

    def greedy_until(self, prefix, until):
        """Greedily continue each prefix until its stop sequence(s).

        Args:
            prefix: iterable of prompt strings.
            until: parallel iterable; each item is a stop string or a list of
                stop strings.

        Returns:
            list of generated strings from the /greedy-until endpoint. In dummy
            mode each entry is 'dummy text ' followed by the (first) stop string.

        Raises:
            ValueError: if ``prefix`` and ``until`` differ in length.
        """
        prefix, until = list(prefix), list(until)
        self._check_same_length('prefix', prefix, 'until', until)
        if self.config.dummy:
            results = []
            for u in until:
                if isinstance(u, str):
                    stop = u
                else:
                    # An empty stop list has no first element to echo.
                    stop = u[0] if u else ''
                results.append('dummy text ' + stop)
            return results

        batched_iterator = list(zip(
            self.batched(prefix, self.config.batch_size),
            self.batched(until, self.config.batch_size),
        ))
        output_text = []
        for batch_prefix, batch_until in tqdm(batched_iterator, ncols=0):
            response = self._post(
                'greedy-until',
                {'prefix_text': batch_prefix, 'until': batch_until},
            )
            output_text.extend(response['output_text'])
        return output_text

    def generate(self, prefix, temperature=None):
        """Sample one continuation per prefix via the /generate endpoint.

        A ``None`` temperature is sent as-is; the server then substitutes its
        configured default. Returns empty strings in dummy mode.
        """
        prefix = list(prefix)
        if self.config.dummy:
            return ['' for _ in prefix]
        output_text = []
        batched_iterator = list(self.batched(prefix, self.config.batch_size))
        for batch_prefix in tqdm(batched_iterator, ncols=0):
            response = self._post(
                'generate',
                {'prefix_text': batch_prefix, 'temperature': temperature},
            )
            output_text.extend(response['output_text'])
        return output_text

    def chat(self, prompt, context, temperature=None):
        """Send one chat turn to the /chat endpoint.

        Returns:
            ``(response, context)``: the model reply and the updated
            conversation context to pass back on the next turn. In dummy mode
            returns ``('', context)``.
        """
        if self.config.dummy:
            # Same (response, context) shape as the real path, so callers can
            # unpack the result in either mode.
            return '', context
        response = self._post(
            'chat',
            {'prompt': prompt, 'context': context, 'temperature': temperature},
        )
        return response['response'], response['context']