Spaces: Running on Zero
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
def generate_answer(llm_name, texts, query, queries, mode='validate'):
    # Load the requested instruction-tuned model and its tokenizer.
    if llm_name == 'solar':
        tokenizer = AutoTokenizer.from_pretrained("Upstage/SOLAR-10.7B-Instruct-v1.0", use_fast=True)
        llm_model = AutoModelForCausalLM.from_pretrained(
            "Upstage/SOLAR-10.7B-Instruct-v1.0",
            device_map="auto",  # device_map="cuda"
            # torch_dtype=torch.float16,
        )
    elif llm_name == 'mistral':
        tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2", use_fast=True)
        llm_model = AutoModelForCausalLM.from_pretrained(
            "mistralai/Mistral-7B-Instruct-v0.2",
            # device_map="auto",
            device_map="cuda",
            torch_dtype=torch.float16,
        )
    elif llm_name == 'phi3mini':
        tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct", use_fast=True)
        llm_model = AutoModelForCausalLM.from_pretrained(
            "microsoft/Phi-3-mini-128k-instruct",
            device_map="auto",
            torch_dtype="auto",
            trust_remote_code=True,
        )
    # Number the candidate documents so the prompt can reference them.
    template_texts = ""
    for i, text in enumerate(texts):
        template_texts += f'{i+1}. {text} \n'

    # Build a single-turn chat prompt for the requested mode.
    if mode == 'validate':
        conversation = [{'role': 'user', 'content': f'Given the following query: "{query}"? \nIs the following document relevant to answer this query?\n{template_texts} \nResponse: Yes / No'}]
    elif mode == 'summarize':
        conversation = [{'role': 'user', 'content': f'For the following query and documents, try to answer the given query based on the documents.\nQuery: {query} \nDocuments: {template_texts}.'}]
    elif mode == 'h_summarize':
        conversation = [{'role': 'user', 'content': f'The documents below describe a developing disaster event. Based on these documents, write a brief summary in the form of a paragraph, highlighting the most crucial information. \nDocuments: {template_texts}'}]
    elif mode == "multi_summarize":
        # conversation = [{'role': 'user', 'content': f'For the following queries and documents, try to answer the given queries based on the documents. Also, return the top 5 unaltered documents that answer the queries.\nQueries: {queries} \nDocuments: {template_texts}.'}]
        conversation = [{'role': 'user', 'content': f'For the following queries and documents, in a brief paragraph try to answer the given queries based on the documents. Then, return the top 5 documents as provided that answer the queries.\nQueries: {queries} \nDocuments: {template_texts}.'}]
    prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(llm_model.device)
    outputs = llm_model.generate(**inputs, use_cache=True, max_length=4096, do_sample=True, temperature=0.7, top_p=0.95, top_k=10, repetition_penalty=1.1)
    output_text = tokenizer.decode(outputs[0])

    # Strip the prompt and keep only the assistant's reply, using each model's chat markers.
    if llm_name == "solar":
        assistant_respond = output_text.split("Assistant:")[1]
    elif llm_name == "phi3mini":
        assistant_respond = output_text.split("<|assistant|>")[1]
        assistant_respond = assistant_respond[:-7]  # drop the trailing end-of-turn marker (e.g. "<|end|>")
    else:
        assistant_respond = output_text.split("[/INST]")[1]
    if mode == 'validate':
        if 'Yes' in assistant_respond:
            return True
        else:
            return False
    elif mode == 'summarize':
        return assistant_respond
    elif mode == 'h_summarize':
        return assistant_respond
    elif mode == 'multi_summarize':
        return assistant_respond