# -*- coding: utf-8 -*-
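"""OpenAI-compatible FastAPI front end for search-augmented chat.

Exposes a model list and a streaming chat-completions endpoint under
/deem/v1/, plus endpoints under /api/search/ that fetch web references
and stream an answer grounded in them. All model and search work is
delegated to the free_ask_internet helper module.

Example request against a local instance (a sketch, assuming the server
runs on its default port 8000):

    curl -N http://localhost:8000/deem/v1/chat/completions \
         -H 'Content-Type: application/json' \
         -d '{"model": "gpt-3.5-turbo", "stream": true,
              "messages": [{"role": "user", "content": "@hello"}]}'
"""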

import os
import time
from typing import List, Literal, Optional, Union

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse

import free_ask_internet


app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Module-level store shared between /get_search_refs and /stream; defined here
# so /stream does not raise NameError when called before a search has run.
search_results = []


class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "owner"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = None


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = []


class ChatMessage(BaseModel):
    role: Literal["user", "assistant", "system"]
    content: str


class DeltaMessage(BaseModel):
    role: Optional[Literal["user", "assistant", "system"]] = None
    content: Optional[str] = None


class QueryRequest(BaseModel):
    query: str
    model: str
    ask_type: Literal["search", "llm"]
    llm_auth_token: Optional[str] = os.environ.get('OPENAI_API_KEY')
    llm_base_url: Optional[str] = os.environ.get('OPENAI_BASE_URL')
    using_custom_llm: Optional[bool] = False
    lang: Optional[str] = "zh-CN"


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    max_length: Optional[int] = None
    stream: Optional[bool] = False


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Literal["stop", "length"]


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length"]] = None


class ChatCompletionResponse(BaseModel):
    model: str
    object: Literal["chat.completion", "chat.completion.chunk"]
    choices: List[Union[ChatCompletionResponseChoice,
                        ChatCompletionResponseStreamChoice]]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))


class SearchItem(BaseModel):
    url: str
    icon_url: str
    site_name: str
    snippet: str
    title: str


class SearchItemList(BaseModel):
    search_items: List[SearchItem] = []


class SearchResp(BaseModel):
    code: int
    msg: str
    data: List[SearchItem] = []
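

# A single gpt-3.5-turbo card is advertised so that off-the-shelf
# OpenAI-compatible clients can discover the service.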
@app.get("/deem/v1/models", response_model=ModelList)
async def list_models():
    model_card = ModelCard(id="gpt-3.5-turbo")
    return ModelList(data=[model_card])
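

# Routing rule for chat completions: a query starting with '@' is answered by
# the LLM directly; anything else goes through the web-search pipeline, with
# multi-turn histories first condensed into one standalone question.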
@app.post("/deem/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    if request.messages[-1].role != "user":
        raise HTTPException(status_code=400, detail="Invalid request")
    query = request.messages[-1].content
    if query.startswith('@'):
        # print("current question: gpt ---> {}".format(query))
        generate = askgpt(request.messages, "", request.model)
    else:
        if len(request.messages) > 2:
            # Condense the multi-turn history into one standalone question.
            message = '\n'.join([msg.content for msg in request.messages])
            query = free_ask_internet.summary_gpt(
                message + "\nBased on the content above, summarize what the "
                "question \"" + query + "\" is asking. Avoid vague pronouns "
                "such as he/she, keep the full context, and make the subject "
                "of the question explicit; output only the summarized, "
                "context-complete question, with no extra text!",
                request.model)
        # print("current question: net ---> {}".format(query))
        generate = predict(query, "", request.model)
    return EventSourceResponse(generate, media_type="text/event-stream")
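

# askgpt and predict below emit the OpenAI streaming protocol: an opening
# chunk carrying the assistant role, one chunk per new text fragment, a
# closing chunk with finish_reason="stop", and finally the literal '[DONE]'
# sentinel. Each yielded string is a JSON chat.completion.chunk, roughly:
#   {"model": "gpt-3.5-turbo", "object": "chat.completion.chunk",
#    "choices": [{"index": 0,
#                 "delta": {"role": "assistant", "content": "Hi"},
#                 "finish_reason": null}]}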
def askgpt(messages, history, model_id):
    # Opening chunk: announce the assistant role before any content arrives.
    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(role="assistant"),
        finish_reason=None
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data],
                                   object="chat.completion.chunk")
    yield chunk.json(exclude_unset=True)
    new_response = ""
    current_length = 0
    for token in free_ask_internet.ask_gpt(messages, model_id):
        new_response += token
        if len(new_response) == current_length:
            continue
        # Forward only the text that arrived since the last chunk.
        new_text = new_response[current_length:]
        current_length = len(new_response)
        choice_data = ChatCompletionResponseStreamChoice(
            index=0,
            delta=DeltaMessage(content=new_text, role="assistant"),
            finish_reason=None
        )
        chunk = ChatCompletionResponse(model=model_id, choices=[choice_data],
                                       object="chat.completion.chunk")
        yield chunk.json(exclude_unset=True)
    # Closing chunk with finish_reason="stop", then the stream terminator.
    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(),
        finish_reason="stop"
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data],
                                   object="chat.completion.chunk")
    yield chunk.json(exclude_unset=True)
    yield '[DONE]'
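

# predict mirrors askgpt but sources its tokens from the search-augmented
# free_ask_internet.ask_internet pipeline instead of a plain LLM call.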
def predict(query: str, history, model_id: str):
    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(role="assistant"),
        finish_reason=None
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data],
                                   object="chat.completion.chunk")
    yield chunk.json(exclude_unset=True)
    new_response = ""
    current_length = 0
    for token in free_ask_internet.ask_internet(query=query, model=model_id):
        new_response += token
        if len(new_response) == current_length:
            continue
        new_text = new_response[current_length:]
        current_length = len(new_response)
        choice_data = ChatCompletionResponseStreamChoice(
            index=0,
            delta=DeltaMessage(content=new_text, role="assistant"),
            finish_reason=None
        )
        chunk = ChatCompletionResponse(model=model_id, choices=[choice_data],
                                       object="chat.completion.chunk")
        yield chunk.json(exclude_unset=True)
    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(),
        finish_reason="stop"
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data],
                                   object="chat.completion.chunk")
    yield chunk.json(exclude_unset=True)
    yield '[DONE]'
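

# Runs the web search and publishes the raw results into the module-level
# search_results store, where the /stream endpoint picks them up.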
@app.post("/api/search/get_search_refs", response_model=SearchResp)
async def get_search_refs(request: QueryRequest):
    global search_results
    search_results = []
    search_item_list = []
    if request.ask_type == "search":
        search_links, search_results = free_ask_internet.search_web_ref(request.query)
        for search_item in search_links:
            si = SearchItem(
                snippet=search_item.get("snippet"),
                url=search_item.get("url"),
                icon_url=search_item.get("icon_url"),
                site_name=search_item.get("site_name"),
                title=search_item.get("title"),
            )
            search_item_list.append(si)
    return SearchResp(code=0, msg="success", data=search_item_list)
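

# Thin streaming adapter over free_ask_internet.chat; when the search failed
# it yields a single error message instead.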
def generator(prompt: str, model: str, llm_auth_token: str, llm_base_url: str,
              using_custom_llm=False, is_failed=False):
    if is_failed:
        yield "Search failed: no results were returned."
    else:
        for token in free_ask_internet.chat(prompt=prompt, model=model,
                                            llm_auth_token=llm_auth_token,
                                            llm_base_url=llm_base_url,
                                            using_custom_llm=using_custom_llm,
                                            stream=True):
            yield token
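

# Streams the final answer. In "llm" mode the question is sent straight to the
# model; in "search" mode the endpoint polls the shared search_results store
# for up to ~10 seconds before giving up and streaming a failure message.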
@app.post("/api/search/stream/{search_uuid}")
async def stream(search_uuid: str, request: QueryRequest):
    global search_results
    if request.ask_type == "llm":
        answer_language = 'Simplified Chinese'
        if request.lang == "zh-TW":
            answer_language = 'Traditional Chinese'
        elif request.lang == "en-US":
            answer_language = 'English'
        prompt = ('You are a large language AI assistant developed by nash_su. '
                  'Answer the user question in ' + answer_language +
                  '. And here is the user question: ' + request.query)
        generate = generator(prompt, model=request.model,
                             llm_auth_token=request.llm_auth_token,
                             llm_base_url=request.llm_base_url,
                             using_custom_llm=request.using_custom_llm)
    else:
        # Wait for get_search_refs to publish results (up to ~10 seconds).
        prompt = None
        limit_count = 10
        while limit_count > 0:
            try:
                if len(search_results) > 0:
                    prompt = free_ask_internet.gen_prompt(
                        request.query, search_results,
                        lang=request.lang, context_length_limit=8000)
                    break
                limit_count -= 1
                time.sleep(1)
            except Exception:
                limit_count -= 1
                time.sleep(1)
        if prompt:
            generate = generator(prompt, model=request.model,
                                 llm_auth_token=request.llm_auth_token,
                                 llm_base_url=request.llm_base_url,
                                 using_custom_llm=request.using_custom_llm)
        else:
            # No search results arrived; stream a failure message instead.
            generate = generator(prompt, model=request.model,
                                 llm_auth_token=request.llm_auth_token,
                                 llm_base_url=request.llm_base_url,
                                 using_custom_llm=request.using_custom_llm,
                                 is_failed=True)
    # return EventSourceResponse(generate, media_type="text/event-stream")
    return StreamingResponse(generate, media_type="text/event-stream")


def main():
    port = 8000
    uvicorn.run(app, host='0.0.0.0', port=port, workers=1)


if __name__ == "__main__":
    main()