cahya commited on
Commit
4ac6ada
1 Parent(s): e84c607

add bloomz

Browse files
Files changed (3) hide show
  1. Dockerfile +1 -0
  2. app/api.py +19 -7
  3. app/config.json +4 -1
Dockerfile CHANGED
@@ -8,6 +8,7 @@ RUN apt-get update && apt-get install -y \
8
  libxmlsec1-dev libffi-dev liblzma-dev git-lfs ffmpeg libsm6 libxext6 cmake \
9
  libgl1-mesa-glx curl nginx espeak-ng openssl libssl-dev libbz2-dev \
10
  libncurses5-dev libreadline-dev \
 
11
  && rm -rf /var/lib/apt/lists/* && git lfs install
12
 
13
  RUN wget https://github.com/tsl0922/ttyd/releases/download/1.7.3/ttyd.x86_64 -O /usr/local/bin/ttyd && \
8
  libxmlsec1-dev libffi-dev liblzma-dev git-lfs ffmpeg libsm6 libxext6 cmake \
9
  libgl1-mesa-glx curl nginx espeak-ng openssl libssl-dev libbz2-dev \
10
  libncurses5-dev libreadline-dev \
11
+ vim lynx haproxy \
12
  && rm -rf /var/lib/apt/lists/* && git lfs install
13
 
14
  RUN wget https://github.com/tsl0922/ttyd/releases/download/1.7.3/ttyd.x86_64 -O /usr/local/bin/ttyd && \
app/api.py CHANGED
@@ -1,7 +1,7 @@
1
  from fastapi import FastAPI, WebSocket
2
  from fastapi.responses import HTMLResponse
3
  from fastapi import Form, Depends, HTTPException, status
4
- from transformers import pipeline, set_seed, AutoConfig, AutoTokenizer, GPT2LMHeadModel
5
  import torch
6
  import os
7
  import time
@@ -68,7 +68,13 @@ async def websocket_endpoint(websocket: WebSocket):
68
 
69
 
70
  @app.post("/api/indochat/v1")
71
- async def indochat(
 
 
 
 
 
 
72
  text: str = Form(default="", description="The Prompt"),
73
  decoding_method: str = Form(default="Sampling", description="Decoding method"),
74
  min_length: int = Form(default=50, description="Minimal length of the generated text"),
@@ -102,13 +108,13 @@ async def indochat(
102
  max_penalty = 1.5
103
  repetition_penalty = max(min_penalty + (1.0 - temperature) * (max_penalty - min_penalty), 0.8)
104
  prompt = f"User: {text}\nAssistant: "
105
- input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)
106
- model.eval()
107
  print("Generating text...")
108
  print(f"max_length: {max_length}, do_sample: {do_sample}, top_k: {top_k}, top_p: {top_p}, "
109
  f"temperature: {temperature}, repetition_penalty: {repetition_penalty}, penalty_alpha: {penalty_alpha}")
110
  time_start = time.time()
111
- sample_outputs = model.generate(input_ids,
112
  penalty_alpha=penalty_alpha,
113
  do_sample=do_sample,
114
  num_beams=num_beams,
@@ -134,7 +140,7 @@ def get_text_generator(model_name: str, device: str = "cpu"):
134
  print(f"hf_auth_token: {hf_auth_token}")
135
  print(f"Loading model with device: {device}...")
136
  tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_auth_token)
137
- model = GPT2LMHeadModel.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id,
138
  use_auth_token=hf_auth_token)
139
  model.to(device)
140
  print("Model loaded")
@@ -147,4 +153,10 @@ def get_config():
147
 
148
  config = get_config()
149
  device = "cuda" if torch.cuda.is_available() else "cpu"
150
- model, tokenizer = get_text_generator(model_name=config["model_name"], device=device)
 
 
 
 
 
 
1
  from fastapi import FastAPI, WebSocket
2
  from fastapi.responses import HTMLResponse
3
  from fastapi import Form, Depends, HTTPException, status
4
+ from transformers import pipeline, set_seed, AutoConfig, AutoTokenizer, AutoModelForCausalLM
5
  import torch
6
  import os
7
  import time
68
 
69
 
70
  @app.post("/api/indochat/v1")
71
+ async def indochat(**kwargs):
72
+ return text_generate("indochat-tiny", kwargs)
73
+
74
+
75
+ @app.post("/api/text-generator/v1")
76
+ async def text_generate(
77
+ model_name: str = Form(default="", description="The model name"),
78
  text: str = Form(default="", description="The Prompt"),
79
  decoding_method: str = Form(default="Sampling", description="Decoding method"),
80
  min_length: int = Form(default=50, description="Minimal length of the generated text"),
108
  max_penalty = 1.5
109
  repetition_penalty = max(min_penalty + (1.0 - temperature) * (max_penalty - min_penalty), 0.8)
110
  prompt = f"User: {text}\nAssistant: "
111
+ input_ids = text_generator[model_name]["tokenizer"](prompt, return_tensors='pt').input_ids.to(device)
112
+ text_generator[model_name]["model"].eval()
113
  print("Generating text...")
114
  print(f"max_length: {max_length}, do_sample: {do_sample}, top_k: {top_k}, top_p: {top_p}, "
115
  f"temperature: {temperature}, repetition_penalty: {repetition_penalty}, penalty_alpha: {penalty_alpha}")
116
  time_start = time.time()
117
+ sample_outputs = text_generator[model_name]["model"].generate(input_ids,
118
  penalty_alpha=penalty_alpha,
119
  do_sample=do_sample,
120
  num_beams=num_beams,
140
  print(f"hf_auth_token: {hf_auth_token}")
141
  print(f"Loading model with device: {device}...")
142
  tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_auth_token)
143
+ model = AutoModelForCausalLM.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id,
144
  use_auth_token=hf_auth_token)
145
  model.to(device)
146
  print("Model loaded")
153
 
154
  config = get_config()
155
  device = "cuda" if torch.cuda.is_available() else "cpu"
156
+ text_generator = {}
157
+ for model_name in config["text-generator"]:
158
+ model, tokenizer = get_text_generator(model_name=config["text-generator"][model_name], device=device)
159
+ text_generator[model_name] = {
160
+ "model": model,
161
+ "tokenizer": tokenizer
162
+ }
app/config.json CHANGED
@@ -1,3 +1,6 @@
1
  {
2
- "model_name": "cahya/indochat-tiny"
 
 
 
3
  }
1
  {
2
+ "text-generator": {
3
+ "indochat-tiny": "cahya/indochat-tiny",
4
+ "bloomz-1b1-instruct": "cahya/bloomz-1b1-instruct"
5
+ }
6
  }