import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from transformers import LlamaTokenizer, LlamaForCausalLM  # only needed for the commented-out Alpaca pipeline below
from langchain.prompts import PromptTemplate
from langchain.text_splitter import CharacterTextSplitter
from langchain.chains.summarize import load_summarize_chain
from langchain.docstore.document import Document
from langchain.llms import HuggingFacePipeline

print("Loading Pipeline Dolly...")
    # print("Loading Pipeline...", str(File.name))

tokenizer = AutoTokenizer.from_pretrained("databricks/dolly-v2-3b", padding_side="left")
base_model = AutoModelForCausalLM.from_pretrained("databricks/dolly-v2-3b", device_map="auto", trust_remote_code=True, torch_dtype=torch.bfloat16)
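# Sampling settings: temperature/top_p control randomness, repetition_penalty discourages the
# model from looping, and max_length caps prompt + generated tokens at Dolly's 2048-token context.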
instruct_pipeline = pipeline(
    "text-generation",
    model=base_model,
    tokenizer=tokenizer,
    max_length=2048,
    temperature=0.6,
    pad_token_id=tokenizer.eos_token_id,
    top_p=0.95,
    repetition_penalty=1.2
)
# instruct_pipeline = pipeline(model="databricks/dolly-v2-3b",  torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto")
# print(instruct_pipeline)
print("Dolly Pipeline Loaded!")
llm_dolly = HuggingFacePipeline(pipeline=instruct_pipeline)


# print("Loading Pipeline Alpaca...")
# tokenizer_alpaca = LlamaTokenizer.from_pretrained('minlik/chinese-alpaca-plus-7b-merged')
# model_alpaca = LlamaForCausalLM.from_pretrained('minlik/chinese-alpaca-plus-7b-merged')
# instruct_pipeline_alpaca = pipeline(
#     "text-generation",
#     model=model_alpaca,
#     tokenizer=tokenizer_alpaca,
#     max_length=1024,
#     temperature=0.6,
#     pad_token_id=tokenizer_alpaca.eos_token_id,
#     top_p=0.95,
#     repetition_penalty=1.2,
#     device_map= "auto"
# )
# print("Pipeline Loaded Alpaca!")
# llm_alpaca = HuggingFacePipeline(pipeline=instruct_pipeline_alpaca)

def summarize(Model, File, Input_text):
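    """Summarize an uploaded .txt file (or pasted text) with a LangChain refine summarization chain."""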
    prompt_template = """Write a concise summary of the following:


    {text}
    
    Summary in English:
    """

    # PromptTemplate substitutes each chunk into the {text} slot before it is sent to the model.
    PROMPT = PromptTemplate(template=prompt_template, input_variables=["text"])


    # CharacterTextSplitter breaks the raw text into chunks small enough to summarize one at a time.
    text_splitter = CharacterTextSplitter()
    if File:
        # Prefer the uploaded file; otherwise fall back to the pasted text.
        with open(str(File.name)) as f:
            text = f.read()
    else:
        text = Input_text
    print(text)
    texts = text_splitter.split_text(text)

    # Only the first three chunks are wrapped as Documents and passed to the chain.
    docs = [Document(page_content=t) for t in texts[:3]]
    print("Printing Docs-------")
    print(docs)
    print("-----------------\n\n")
    if Model == 'Dolly':
        chain = load_summarize_chain(llm_dolly, chain_type="refine", question_prompt=PROMPT)
    else:
        # The Alpaca pipeline above is commented out, so the 'Alpaca' choice currently
        # falls back to the same Dolly LLM.
        chain = load_summarize_chain(llm_dolly, chain_type="refine", question_prompt=PROMPT)
    summary_text = chain({"input_documents": docs}, return_only_outputs=True)
    print(summary_text["output_text"])
    return summary_text["output_text"]

# Leftover "hello world" helper; not wired into the interface below.
def greet(name):
    return "Hello " + name + "!"

# with gr.Blocks() as demo:
#     a = gr.File()
#     gr.Interface(fn=summarize, inputs = [gr.inputs.Dropdown(["Dolly", "Alpaca"]), a , "text"], outputs="text", title="Summarization Tool")

demo = gr.Interface(
    fn=summarize,
    inputs=[gr.Dropdown(["Dolly", "Alpaca"], label="Model"), gr.File(label="Upload .txt file"), "text"],
    outputs="text",
    title="Summarization Tool",
)

# queue() lets long-running summarization requests finish without hitting HTTP timeouts.
demo.queue().launch()