amankumar commited on
Commit
d0da8a7
β€’
1 Parent(s): 37b4cf0

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -0
app.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
+ from langchain import OpenAI, PromptTemplate, LLMChain
5
+ from langchain.text_splitter import CharacterTextSplitter
6
+ from langchain.chains.mapreduce import MapReduceChain
7
+ from langchain.prompts import PromptTemplate
8
+ from langchain.chains.summarize import map_reduce_prompt, refine_prompts, stuff_prompt
9
+ # from langchain.chains import LLMChain
10
+ from langchain.chains.summarize import load_summarize_chain
11
+ from langchain.docstore.document import Document
12
+ from langchain.llms import HuggingFacePipeline
13
+ from transformers import LlamaTokenizer, LlamaForCausalLM
14
+ import gradio as gr
15
+
16
+ print("Loading Pipeline Dolly...")
17
+ # print("Loading Pipeline...", str(File.name))
18
+
19
+ tokenizer = AutoTokenizer.from_pretrained("databricks/dolly-v2-3b", padding_side="left")
20
+ base_model = AutoModelForCausalLM.from_pretrained("databricks/dolly-v2-3b", device_map="auto", trust_remote_code=True, torch_dtype=torch.bfloat16)
21
+ instruct_pipeline = pipeline(
22
+ "text-generation",
23
+ model=base_model,
24
+ tokenizer=tokenizer,
25
+ max_length=2048,
26
+ temperature=0.6,
27
+ pad_token_id=tokenizer.eos_token_id,
28
+ top_p=0.95,
29
+ repetition_penalty=1.2
30
+ )
31
+ # instruct_pipeline = pipeline(model="databricks/dolly-v2-3b", torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto")
32
+ # print(instruct_pipeline)
33
+ print("Dolly Pipeline Loaded!")
34
+ llm_dolly = HuggingFacePipeline(pipeline=instruct_pipeline)
35
+
36
+
37
+ print("Loading Pipeline Alpaca...")
38
+ tokenizer_alpaca = LlamaTokenizer.from_pretrained('minlik/chinese-alpaca-plus-7b-merged')
39
+ model_alpaca = LlamaForCausalLM.from_pretrained('minlik/chinese-alpaca-plus-7b-merged')
40
+ instruct_pipeline_alpaca = pipeline(
41
+ "text-generation",
42
+ model=model_alpaca,
43
+ tokenizer=tokenizer_alpaca,
44
+ max_length=1024,
45
+ temperature=0.6,
46
+ pad_token_id=tokenizer_alpaca.eos_token_id,
47
+ top_p=0.95,
48
+ repetition_penalty=1.2,
49
+ device_map= "auto"
50
+ )
51
+ print("Pipeline Loaded Alpaca!")
52
+ llm_alpaca = HuggingFacePipeline(pipeline=instruct_pipeline_alpaca)
53
+
54
+ def summarize(Model, File, Input_text):
55
+ prompt_template = """Write a concise summary of the following:
56
+
57
+
58
+ {text}
59
+
60
+ Summary in English:
61
+ """
62
+
63
+ PROMPT = PromptTemplate(template=prompt_template, input_variables=["text"])
64
+
65
+
66
+ text_splitter = CharacterTextSplitter()
67
+ if File:
68
+ with open(str(File.name)) as f:
69
+ state_of_the_union = f.read()
70
+ text = state_of_the_union
71
+ else:
72
+ text = Input_text
73
+ print(text)
74
+ texts = text_splitter.split_text(text)
75
+
76
+ docs = [Document(page_content=t) for t in texts[:3]]
77
+ print("Printing Docs-------")
78
+ print(docs)
79
+ print("-----------------\n\n")
80
+ if Model=='Dolly':
81
+ chain = load_summarize_chain(llm_dolly, chain_type="refine", question_prompt=PROMPT)
82
+ else:
83
+ chain = load_summarize_chain(llm_alpaca, chain_type="refine", question_prompt=PROMPT)
84
+ summary_text = chain({"input_documents": docs}, return_only_outputs=True)
85
+ print(summary_text["output_text"])
86
+ return summary_text["output_text"]
87
+
88
+ def greet(name):
89
+ return "Hello " + name + "!"
90
+
91
+ # with gr.Blocks() as demo:
92
+ # a = gr.File()
93
+ # gr.Interface(fn=summarize, inputs = [gr.inputs.Dropdown(["Dolly", "Alpaca"]), a , "text"], outputs="text", title="Summarization Tool")
94
+
95
+ demo = gr.Interface(fn=summarize, inputs = [gr.inputs.Dropdown(["Dolly", "Alpaca"]),gr.inputs.File(label="Upload .txt file"), "text"], outputs="text", title="Summarization Tool")
96
+
97
+ demo.queue().launch(share = True)