File size: 1,697 Bytes
1207df4
 
 
 
 
 
843fc81
1207df4
9c16e7d
1207df4
 
 
 
 
 
 
843fc81
f042920
 
1207df4
843fc81
 
1207df4
 
 
f042920
1207df4
 
f042920
843fc81
1207df4
 
 
 
 
 
 
 
 
 
 
 
 
843fc81
2e8ce81
1207df4
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import streamlit as st
from transformers import pipeline
import torch
from diffusers import DiffusionPipeline

# Negative prompt for the image model (unchanged from the original inline literal).
_NEGATIVE_PROMPT = "nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]"


@st.cache_resource
def _load_models():
    """Load the classifier, summarizer and image pipeline once per server process.

    Streamlit re-executes the whole script on every widget interaction; without
    ``st.cache_resource`` all three models would be reloaded (and the diffusion
    model re-moved onto the device) on every rerun.

    Returns:
        tuple: ``(classifier, summarizer, painter)``.
    """
    classifier = pipeline(
        "text-classification",
        model="lori0330/BART_FineTuned_ZeroShotClassification",
    )
    summarizer = pipeline("summarization", model="facebook/bart-large-cnn")

    # float16 halves GPU memory for SDXL but is poorly supported on CPU, so
    # fall back to float32 there instead of crashing on `.to('cuda')`
    # (CPU generation will be very slow).
    if torch.cuda.is_available():
        device, dtype = "cuda", torch.float16
    else:
        device, dtype = "cpu", torch.float32
    painter = DiffusionPipeline.from_pretrained(
        "cagliostrolab/animagine-xl-3.1",
        torch_dtype=dtype,
        use_safetensors=True,
    )
    painter.to(device)
    return classifier, summarizer, painter


def main():
    """Render the Brief Report Generator page.

    Classifies the pasted text, summarizes it, and shows an illustration
    generated from the predicted label and the summary.
    """
    # Prepare pipelines (cached across Streamlit reruns).
    classifier, summarizer, painter = _load_models()

    # Edit the space
    st.title("Brief Report Generator")
    st.write("Copy the text here:")
    # Streamlit warns on an empty label (accessibility); use a real label and
    # hide it, since the line above already acts as the visible caption.
    user_input = st.text_input("Text to analyze", label_visibility="collapsed")

    # Check input: ignore empty / whitespace-only submissions.
    text = user_input.strip()
    if not text:
        return

    # truncation=True: BART models accept at most 1024 tokens; longer input
    # would otherwise raise instead of being cut to fit.
    label = classifier(text, truncation=True)[0]["label"]
    st.write(f"Label of this text: {label}")

    summary = summarizer(
        text, max_length=100, min_length=30, do_sample=False, truncation=True
    )[0]["summary_text"]
    # Markdown needs a blank line for a visible break; a single "\n" renders
    # as a space.
    st.write(f"The summary of this text:\n\n{summary}")

    # NOTE(review): SDXL's CLIP text encoders read at most 77 tokens, so a
    # long summary is truncated by diffusers (it logs a warning).
    description = f"This is mainly about {label}: {summary}"
    image = painter(
        description,
        negative_prompt=_NEGATIVE_PROMPT,
        width=832,
        height=1216,
        guidance_scale=7,
        num_inference_steps=28,
    ).images[0]

    st.write("The attached image:")
    st.image(image)


if __name__ == "__main__":
    main()