dominguezdaniel commited on
Commit
31a0b27
·
verified ·
1 Parent(s): 759f312

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -10
app.py CHANGED
@@ -1,21 +1,34 @@
1
  import gradio as gr
2
- from transformers import pipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  def predict(image):
5
- model_id = "google/vit-base-patch16-224"
6
- classifier = pipeline("image-classification", model=model_id)
7
  predictions = classifier(image)
8
  # Sort predictions based on confidence and select the top one
9
  top_prediction = sorted(predictions, key=lambda x: x['score'], reverse=True)[0]
 
10
 
11
- # Generate a promotional tweet based on the top prediction
12
- tweet_template = "Check out this amazing {label}! 📸✨ Explore more about it and let your curiosity lead you to discover wonders."
13
- tweet_text = tweet_template.format(label=top_prediction['label'].split(',')[0]) # Using split to clean up label if necessary
14
- return tweet_text
15
 
16
- title = "Image Classifier to Promotional Tweet"
17
- description = "This demo recognizes and classifies images using the 'google/vit-base-patch16-224' model. Below, you'll see a generated promotional tweet based on the top prediction. Your task: Upload an image, and let's write a tweet about it!"
18
  input_component = gr.Image(type="pil", label="Upload an image here")
19
- output_component = gr.Textbox(label="Generated Promotional Tweet", placeholder="Write a tweet about the image")
20
 
21
  gr.Interface(fn=predict, inputs=input_component, outputs=output_component, title=title, description=description).launch()
 
1
  import gradio as gr
2
+ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
3
+
4
+ # Initialize the image classification pipeline
5
+ classifier = pipeline("image-classification", model="google/vit-base-patch16-224")
6
+
7
+ # Initialize the tokenizer and model for the generative text (GPT-like model)
8
+ tokenizer = AutoTokenizer.from_pretrained("gpt2") # Example model, replace with your choice
9
+ model = AutoModelForCausalLM.from_pretrained("gpt2")
10
+
11
+ def generate_tweet(label):
12
+ # Generate a promotional tweet using a GPT-like model
13
+ prompt = f"Write a creative and promotional tweet about {label}:"
14
+ inputs = tokenizer.encode(prompt, return_tensors="pt")
15
+ outputs = model.generate(inputs, max_length=280, num_return_sequences=1)
16
+ tweet = tokenizer.decode(outputs[0], skip_special_tokens=True)
17
+ return tweet
18
 
19
  def predict(image):
 
 
20
  predictions = classifier(image)
21
  # Sort predictions based on confidence and select the top one
22
  top_prediction = sorted(predictions, key=lambda x: x['score'], reverse=True)[0]
23
+ label = top_prediction['label'].split(',')[0] # Clean up label if necessary
24
 
25
+ # Generate the tweet
26
+ tweet = generate_tweet(label)
27
+ return tweet
 
28
 
29
+ title = "Image Classifier to Generative Tweet"
30
+ description = "This demo recognizes and classifies images using the 'google/vit-base-patch16-224' model and generates a creative promotional tweet about the top prediction using a generative text model."
31
  input_component = gr.Image(type="pil", label="Upload an image here")
32
+ output_component = gr.Textbox(label="Generated Promotional Tweet")
33
 
34
  gr.Interface(fn=predict, inputs=input_component, outputs=output_component, title=title, description=description).launch()