dipankardas011 commited on
Commit
7852f50
·
unverified ·
1 Parent(s): b51ba4d

New commits

Browse files

Signed-off-by: Dipankar Das <dipankardas0115@gmail.com>

Files changed (2) hide show
  1. app.py +26 -27
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,37 +1,36 @@
1
  from fastapi import FastAPI
2
  from fastapi.responses import RedirectResponse
3
- from transformers import pipeline
4
-
5
- # Create a new FastAPI app instance
 
 
 
 
 
6
 
7
  app = FastAPI()
8
 
9
  @app.get("/")
10
  async def docs_redirect():
11
  return RedirectResponse(url='/docs')
12
- # Initialize the text generation pipeline
13
- # This function will be able to generate text
14
- # given an input.
15
- pipe = pipeline("text2text-generation",
16
- model="google/flan-t5-small")
17
-
18
- # Define a function to handle the GET request at `/generate`
19
- # The generate() function is defined as a FastAPI route that takes a
20
- # string parameter called text. The function generates text based on the # input using the pipeline() object, and returns a JSON response
21
- # containing the generated text under the key "output"
22
- # @app.get("/")
23
- # async def root():
24
- # return {"message": "Hello World"}
25
 
26
  @app.get("/generate")
27
- def generate(text: str):
28
- """
29
- Using the text2text-generation pipeline from `transformers`, generate text
30
- from the given input text. The model used is `google/flan-t5-small`, which
31
- can be found [here](<https://huggingface.co/google/flan-t5-small>).
32
- """
33
- # Use the pipeline to generate text from the given input text
34
- output = pipe(text)
35
-
36
- # Return the generated text in a JSON response
37
- return {"output": output[0]["generated_text"]}
 
 
 
 
 
 
 
 
1
  from fastapi import FastAPI
2
  from fastapi.responses import RedirectResponse
3
+ # from transformers import pipeline
4
+ from PIL import Image
5
+ import requests
6
+ from transformers import AutoProcessor, Pix2StructForConditionalGeneration
7
+
8
+ processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base")
9
+ model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")
10
+
11
 
12
  app = FastAPI()
13
 
14
  @app.get("/")
15
  async def docs_redirect():
16
  return RedirectResponse(url='/docs')
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  @app.get("/generate")
19
+ def generate():
20
+ url = "https://www.ilankelman.org/stopsigns/australia.jpg"
21
+ image = Image.open(requests.get(url, stream=True).raw)
22
+
23
+ inputs = processor(images=image, return_tensors="pt")
24
+
25
+ # autoregressive generation
26
+ generated_ids = model.generate(**inputs, max_new_tokens=50)
27
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
28
+ print(generated_text)
29
+
30
+ # conditional generation
31
+ text = "A picture of"
32
+ inputs = processor(text=text, images=image, return_tensors="pt", add_special_tokens=False)
33
+
34
+ generated_ids = model.generate(**inputs, max_new_tokens=50)
35
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
36
+ print(generated_text)
requirements.txt CHANGED
@@ -4,3 +4,4 @@ uvicorn[standard]==0.17.*
4
  sentencepiece==0.1.*
5
  torch==1.11.*
6
  transformers==4.*
 
 
4
  sentencepiece==0.1.*
5
  torch==1.11.*
6
  transformers==4.*
7
+ Pillow