HinaCortus commited on
Commit
5f0a349
1 Parent(s): b1d7fd2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -23
app.py CHANGED
@@ -1,23 +1,26 @@
1
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
2
- from transformers import TextClassificationPipeline
3
- from shiny import App, Inputs
4
-
5
- www_dir = Path(__file__).parent.resolve() / "www"
6
-
7
- def classification(input: Inputs):
8
- model_name = 'lincoln/flaubert-mlsum-topic-classification'
9
-
10
- loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
11
- loaded_model = AutoModelForSequenceClassification.from_pretrained(model_name)
12
-
13
- nlp = TextClassificationPipeline(model=loaded_model, tokenizer=loaded_tokenizer)
14
- result = nlp(Inputs, truncation=True)
15
- print(result)
16
- return result
17
-
18
- app = App(
19
- print('Processing'),
20
- Inputs = "Le Bayern Munich prend la grenadine.",
21
- classification(Inputs),
22
- static_assets=str(www_dir),
23
- )
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ from app.model.model import predict_pipeline
4
+ from app.model.model import __version__ as model_version
5
+
6
+
7
+ app = FastAPI()
8
+
9
+
10
+ class TextIn(BaseModel):
11
+ text: str
12
+
13
+
14
+ class PredictionOut(BaseModel):
15
+ language: str
16
+
17
+
18
+ @app.get("/")
19
+ def home():
20
+ return {"health_check": "OK", "model_version": model_version}
21
+
22
+
23
+ @app.post("/predict", response_model=PredictionOut)
24
+ def predict(payload: TextIn):
25
+ language = predict_pipeline(payload.text)
26
+ return {"language": language}