Nausea582 committed on
Commit
23d1c46
·
verified ·
1 Parent(s): 4c23624
Files changed (2) hide show
  1. app.py +80 -10
  2. requirements +6 -0
app.py CHANGED
@@ -1,10 +1,80 @@
1
- import gradio as gr
2
-
3
- with gr.Blocks(fill_height=True) as demo:
4
- with gr.Sidebar():
5
- gr.Markdown("# Inference Provider")
6
- gr.Markdown("This Space showcases the WiroAI/WiroAI-Finance-Qwen-1.5B model, served by the hf-inference API. Sign in with your Hugging Face account to use this API.")
7
- button = gr.LoginButton("Sign in")
8
- gr.load("models/WiroAI/WiroAI-Finance-Qwen-1.5B", accept_token=button, provider="hf-inference")
9
-
10
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from pydantic import BaseModel
4
+ import transformers
5
+ import torch
6
+
7
# ---- Model initialisation ----
model_id = "WiroAI/WiroAI-Finance-Qwen-1.5B"

# Slot for the text-generation pipeline; stays None until the pipeline is
# first built, then keeps it so the model is not loaded again.
_pipeline = None
12
+
13
+
14
def get_pipeline():
    """Return the shared text-generation pipeline, building it on first use.

    The pipeline is stored in the module-level ``_pipeline`` variable, so
    every later call returns the same object without reloading the model.

    Returns:
        The ``transformers`` text-generation pipeline for ``model_id``.
    """
    global _pipeline

    # Fast path: already built by an earlier call.
    if _pipeline is not None:
        return _pipeline

    # Let accelerate place the weights when a GPU is visible; otherwise
    # pass no device map at all.
    placement = "auto" if torch.cuda.is_available() else None

    _pipeline = transformers.pipeline(
        "text-generation",
        model=model_id,
        model_kwargs={"torch_dtype": torch.bfloat16},
        device_map=placement,
    )
    # Switch the underlying model to evaluation mode.
    _pipeline.model.eval()
    return _pipeline
25
+
26
+
27
# ---- FastAPI application ----
app = FastAPI(title="WiroAI Finance Chat API")

# Cross-origin access (important!): every origin, HTTP method and request
# header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
37
+
38
+
39
# ---- Request / response models ----
class ChatRequest(BaseModel):
    """JSON body accepted by the ``/chat`` endpoint."""

    # The user's message; sent to the model as the single user turn.
    message: str
    # Forwarded to the pipeline as ``max_new_tokens``.
    max_new_tokens: int = 512
    # Forwarded to the pipeline as the sampling ``temperature``.
    temperature: float = 0.9
44
+
45
+
46
class ChatResponse(BaseModel):
    """JSON body returned by the ``/chat`` endpoint."""

    # Text of the generated reply.
    content: str
48
+
49
+
50
# ---- API endpoint ----
@app.post("/chat", response_model=ChatResponse)
async def generate_response(request: ChatRequest):
    """Generate one assistant reply for the submitted user message.

    The message is wrapped in a two-turn chat (fixed system prompt plus the
    user turn) and run through the shared text-generation pipeline with
    sampling enabled.

    Args:
        request: Parsed request body; supplies ``message``,
            ``max_new_tokens`` and ``temperature``.

    Returns:
        A dict with a single ``content`` key holding the reply text, i.e.
        the ``content`` of the last message in the pipeline's
        ``generated_text``.

    Raises:
        HTTPException: Status 500 with the error text as ``detail`` when
            loading the model or generating fails.
    """
    # NOTE(review): the pipeline call below is synchronous and is not
    # awaited, so this coroutine does not yield while a reply is generated.
    try:
        pipeline = get_pipeline()

        # Conversation sent to the model.
        messages = [
            {"role": "system", "content": "You are a finance chatbot developed by Wiro AI"},
            {"role": "user", "content": request.message},
        ]

        # Stop-token candidates.
        # NOTE(review): the second lookup uses an empty token string, which
        # looks like a lost special-token literal -- confirm the intended
        # stop token. A failed lookup can come back as None, which must not
        # be handed to generation, so None and duplicate ids are dropped.
        candidate_ids = (
            pipeline.tokenizer.eos_token_id,
            pipeline.tokenizer.convert_tokens_to_ids(""),
        )
        terminators = []
        for token_id in candidate_ids:
            if token_id is not None and token_id not in terminators:
                terminators.append(token_id)

        generation_kwargs = {
            "max_new_tokens": request.max_new_tokens,
            "do_sample": True,
            "temperature": request.temperature,
        }
        # Only override the stop tokens when at least one valid id exists;
        # otherwise leave the model's own generation defaults in place.
        if terminators:
            generation_kwargs["eos_token_id"] = terminators

        # Generate the reply.
        outputs = pipeline(messages, **generation_kwargs)

        return {"content": outputs[0]["generated_text"][-1]["content"]}

    except Exception as e:
        # Chain the original exception so the cause stays in the traceback.
        raise HTTPException(status_code=500, detail=str(e)) from e
requirements ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ fastapi>=0.68.0
2
+ uvicorn>=0.15.0
3
+ transformers>=4.40.0
4
+ torch>=2.3.0
5
+ accelerate>=0.30.0
6
+ pydantic>=2.0