jframed281's picture
modify app.py
dca8200
# -*- coding: utf-8 -*-
"""imagetortext.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1UXh8tivt-4vHaBeXgfyq-TYLvpJAMZVV
"""
# Dependencies (install them before running, e.g. via requirements.txt).
# The notebook export used IPython "!pip" shell escapes here; those are a
# syntax error in a plain Python file, so the commands are kept as comments:
#   pip install transformers dataset
#   pip install -q gradio
#   pip install sentencepiece
#   pip install googletrans==3.1.0a0
# NOTE(review): "dataset" may be a typo for Hugging Face "datasets" - confirm.
import gradio as gr
import requests
from PIL import Image
from torchvision import transforms
from transformers import pipeline
import torch
import sentencepiece
import re
import googletrans
from googletrans import Translator
# Loads a pretrained ResNet-18 through torch.hub when this module is imported
# and switches it to eval mode.
# NOTE(review): nothing in the visible code reads this module-level `model`
# (loadImageToText loads its own classifier), so this looks like a leftover
# from the notebook - confirm before removing.
model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True).eval()
from transformers import SegformerFeatureExtractor, SegformerForImageClassification, T5Tokenizer, T5Model
from PIL import Image
import requests
# Lazily populated holder so the heavy models are built once per process
# instead of on every request.
_LOADED_MODELS = {}


def _get_models():
    """Return the (feature extractor, classifier, story generator) triple.

    The objects are created on first use and reused by later calls.
    """
    models = _LOADED_MODELS.get("models")
    if models is None:
        models = (
            SegformerFeatureExtractor.from_pretrained("nvidia/mit-b2"),
            SegformerForImageClassification.from_pretrained("nvidia/mit-b2"),
            pipeline("text-generation", "pranavpsv/gpt2-genre-story-generator"),
        )
        # Single assignment: a concurrent caller never sees a partial entry.
        _LOADED_MODELS["models"] = models
    return models


def _genre_tags(argument):
    """Turn free-form text such as "animal, super" into "<animal><super>".

    Returns an empty string when ``argument`` is None or blank, so that no
    empty "<>" tag is sent to the story generator.
    """
    cleaned = (argument or "").strip()
    if not cleaned:
        return ""
    # Every run of separator characters that is followed by a word character
    # closes the current tag and opens the next one. Raw string: the pattern
    # is unchanged, but "\w" is no longer an invalid string escape.
    return "<" + re.sub(r"[^(\w|<|>)]+(?=\w)", "><", cleaned) + ">"


def loadImageToText(image, argument):
    """Generate a short story in Spanish from an image and genre keywords.

    The image is classified with the "nvidia/mit-b2" checkpoint, the predicted
    label is appended to the genre tags built from ``argument``, that prompt is
    fed to the "pranavpsv/gpt2-genre-story-generator" text-generation
    pipeline, and the generated English text is translated to Spanish.

    Args:
        image: Image as delivered by the Gradio image input.
        argument: Free-form keywords (e.g. "animal, super") used as genre tags.

    Returns:
        The translated story text, with the genre tags removed.

    Raises:
        ValueError: If no image was provided.
    """
    if image is None:
        raise ValueError("No image provided.")

    feature_extractor, classifier, story_gen = _get_models()

    inputs = feature_extractor(images=image, return_tensors="pt")
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = classifier(**inputs).logits
    # model predicts one of the 1000 ImageNet classes
    predicted_class_idx = logits.argmax(-1).item()
    label = classifier.config.id2label[predicted_class_idx]

    tags = _genre_tags(argument)
    generated = story_gen(tags + label)[0]["generated_text"]
    # The generated text starts with the prompt; drop only the tag prefix so
    # the predicted label stays as the opening of the story.
    story_en = generated[len(tags):]

    translation = Translator().translate(story_en, src='en', dest='es')
    return translation.text
# print(loadImageToText("animal,super"))  # To try it from here, remove the 'image' argument and the variables inside the function that depend on it.
# Web UI: an image plus a free-text box of base keywords go in, and the text
# returned by loadImageToText comes out.
image_input = gr.Image()
keywords_input = gr.Text(
    label="Argumentos base",
    placeholder="Verano, película,playa, superhéroe, animal",
)
demo = gr.Interface(
    fn=loadImageToText,
    inputs=[image_input, keywords_input],
    outputs="text",
)
demo.launch()