Juna190825 committed on
Commit 3f44f2a · verified · 1 Parent(s): 31eb54c

Create app.py

Files changed (1)
  1. app.py +113 -0
app.py ADDED
@@ -0,0 +1,113 @@
+ from fastapi import FastAPI, Request
+ from fastapi.responses import JSONResponse
+ from fastapi.staticfiles import StaticFiles
+ from fastapi.middleware.cors import CORSMiddleware
+ import gradio as gr
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from transformers import pipeline
+ import torch
+ import os
+ import time
+
+ # Initialize FastAPI app
+ app = FastAPI()
+
+ # Create the Gradio Blocks app (mounted onto FastAPI further below)
+ gradio_app = gr.Blocks()
+
+ # Model loading function
+ def load_model():
+     model_name = "trillionlabs/Trillion-7B-preview-AWQ"
+
+     # Load tokenizer
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+     # Load model with CPU support
+     model = AutoModelForCausalLM.from_pretrained(
+         model_name,
+         device_map="cpu",
+         torch_dtype=torch.float32
+     )
+
+     # Create text generation pipeline
+     text_generator = pipeline(
+         "text-generation",
+         model=model,
+         tokenizer=tokenizer,
+         device="cpu"
+     )
+
+     return text_generator
+
+ # Load model (this will happen when the server starts)
+ text_generator = load_model()
+
+ # API endpoint for text generation
+ @app.post("/api/generate")
+ async def generate_text(request: Request):
+     try:
+         data = await request.json()
+         prompt = data.get("prompt", "")
+         max_length = data.get("max_length", 100)
+
+         # Generate text
+         start_time = time.time()
+         outputs = text_generator(
+             prompt,
+             max_length=max_length,
+             do_sample=True,
+             temperature=0.7,
+             top_k=50,
+             top_p=0.95
+         )
+         generation_time = time.time() - start_time
+
+         return JSONResponse({
+             "generated_text": outputs[0]["generated_text"],
+             "generation_time": generation_time,
+             "model": "trillionlabs/Trillion-7B-preview-AWQ",
+             "device": "cpu"
+         })
+     except Exception as e:
+         return JSONResponse({"error": str(e)}, status_code=500)
+
+ # Gradio interface
+ def gradio_generate(prompt, max_length=100):
+     outputs = text_generator(
+         prompt,
+         max_length=max_length,
+         do_sample=True,
+         temperature=0.7,
+         top_k=50,
+         top_p=0.95
+     )
+     return outputs[0]["generated_text"]
+
+ with gradio_app:
+     gr.Markdown("# Trillion-7B-preview-AWQ Demo (CPU)")
+     gr.Markdown("This is a CPU-only demo of the Trillion-7B-preview-AWQ model running with 16GB RAM.")
+
+     with gr.Row():
+         input_prompt = gr.Textbox(label="Input Prompt", lines=5)
+         output_text = gr.Textbox(label="Generated Text", lines=5)
+
+     length_slider = gr.Slider(50, 500, value=100, label="Max Length")
+     generate_btn = gr.Button("Generate")
+
+     generate_btn.click(
+         fn=gradio_generate,
+         inputs=[input_prompt, length_slider],
+         outputs=output_text
+     )
+
+ # Mount the Gradio app onto FastAPI at the root path
+ app = gr.mount_gradio_app(app, gradio_app, path="/")
+
+ # CORS middleware
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
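
For reference, the `/api/generate` endpoint added in this file accepts a JSON body with `prompt` and `max_length` and returns the generated text plus timing metadata. A minimal client sketch (not part of this commit) could look like the following, assuming the app is reachable at a placeholder base URL; the actual Space URL or local host/port would need to be substituted:

import requests

# Placeholder address; replace with the actual Space URL or local host/port.
BASE_URL = "http://localhost:7860"

payload = {
    "prompt": "Once upon a time",
    "max_length": 120,  # read by generate_text() via data.get("max_length", 100)
}

resp = requests.post(f"{BASE_URL}/api/generate", json=payload, timeout=600)
resp.raise_for_status()

result = resp.json()
print(result["generated_text"])
print(f'Generated on {result["device"]} in {result["generation_time"]:.1f}s')

Since the 7B model runs on CPU here, a single request can take a while, hence the generous client-side timeout in the sketch.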